FunctionTool partial support (#5183)

<!-- Thank you for your contribution! Please review
https://microsoft.github.io/autogen/docs/Contribute before opening a
pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have the access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Why are these changes needed?

FunctionTool supports passing in a partial

## Related issue number

Closes #5151 

## Checks

- [x] I've included any doc changes needed for
https://microsoft.github.io/autogen/. See
https://microsoft.github.io/autogen/docs/Contribute#documentation to
build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [x] I've made sure all auto checks have passed.
This commit is contained in:
Nour Bouzid 2025-01-29 19:02:18 +01:00 committed by GitHub
parent 2f1684b698
commit 02e968a531
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 66 additions and 2 deletions

View File

@ -3,6 +3,7 @@
import inspect
import typing
from functools import partial
from logging import getLogger
from typing import (
Annotated,
@ -41,7 +42,8 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
"""
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
type_hints = typing.get_type_hints(call, globalns, include_extras=True)
func_call = call.func if isinstance(call, partial) else call
type_hints = typing.get_type_hints(func_call, globalns, include_extras=True)
typed_params = [
inspect.Parameter(
name=param.name,

View File

@ -88,7 +88,7 @@ class FunctionTool(BaseTool[BaseModel, BaseModel], Component[FunctionToolConfig]
self._func = func
self._global_imports = global_imports
signature = get_typed_signature(func)
func_name = name or func.__name__
func_name = name or func.func.__name__ if isinstance(func, functools.partial) else name or func.__name__
args_model = args_base_model_from_signature(func_name + "args", signature)
return_type = signature.return_annotation
self._has_cancellation_support = "cancellation_token" in signature.parameters

View File

@ -1,4 +1,5 @@
import inspect
from functools import partial
from typing import Annotated, List
import pytest
@ -109,6 +110,67 @@ def test_func_tool_schema_generation_only_default_arg() -> None:
assert "required" not in schema["parameters"]
def test_func_tool_with_partial_positional_arguments_schema_generation() -> None:
"""Test correct schema generation for a partial function with positional arguments."""
def get_weather(country: str, city: str) -> str:
return f"The temperature in {city}, {country} is 75°"
partial_function = partial(get_weather, "Germany")
tool = FunctionTool(partial_function, description="Partial function tool.")
schema = tool.schema
assert schema["name"] == "get_weather"
assert "description" in schema
assert schema["description"] == "Partial function tool."
assert "parameters" in schema
assert schema["parameters"]["type"] == "object"
assert schema["parameters"]["properties"].keys() == {"city"}
assert schema["parameters"]["properties"]["city"]["type"] == "string"
assert schema["parameters"]["properties"]["city"]["description"] == "city"
assert "required" in schema["parameters"]
assert schema["parameters"]["required"] == ["city"]
assert "country" not in schema["parameters"]["properties"] # check country not in schema params
assert len(schema["parameters"]["properties"]) == 1
def test_func_call_tool_with_kwargs_schema_generation() -> None:
"""Test correct schema generation for a partial function with kwargs."""
def get_weather(country: str, city: str) -> str:
return f"The temperature in {city}, {country} is 75°"
partial_function = partial(get_weather, country="Germany")
tool = FunctionTool(partial_function, description="Partial function tool.")
schema = tool.schema
assert schema["name"] == "get_weather"
assert "description" in schema
assert schema["description"] == "Partial function tool."
assert "parameters" in schema
assert schema["parameters"]["type"] == "object"
assert schema["parameters"]["properties"].keys() == {"country", "city"}
assert schema["parameters"]["properties"]["city"]["type"] == "string"
assert schema["parameters"]["properties"]["country"]["type"] == "string"
assert "required" in schema["parameters"]
assert schema["parameters"]["required"] == ["city"] # only city is required
assert len(schema["parameters"]["properties"]) == 2
@pytest.mark.asyncio
async def test_run_func_call_tool_with_kwargs_and_args() -> None:
"""Test run partial function with kwargs and args."""
def get_weather(country: str, city: str, unit: str = "Celsius") -> str:
return f"The temperature in {city}, {country} is 75° {unit}"
partial_function = partial(get_weather, "Germany", unit="Fahrenheit")
tool = FunctionTool(partial_function, description="Partial function tool.")
result = await tool.run_json({"city": "Berlin"}, CancellationToken())
assert isinstance(result, str)
assert result == "The temperature in Berlin, Germany is 75° Fahrenheit"
@pytest.mark.asyncio
async def test_tool_run() -> None:
tool = MyTool()