mirror of https://github.com/microsoft/autogen.git
FunctionTool partial support (#5183)
<!-- Thank you for your contribution! Please review https://microsoft.github.io/autogen/docs/Contribute before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have the access to it, we will shortly find a reviewer and assign them to your PR. --> ## Why are these changes needed? FunctionTool supports passing in a partial ## Related issue number Closes #5151 ## Checks - [x] I've included any doc changes needed for https://microsoft.github.io/autogen/. See https://microsoft.github.io/autogen/docs/Contribute#documentation to build and test documentation locally. - [x] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [x] I've made sure all auto checks have passed.
This commit is contained in:
parent
2f1684b698
commit
02e968a531
|
@ -3,6 +3,7 @@
|
|||
|
||||
import inspect
|
||||
import typing
|
||||
from functools import partial
|
||||
from logging import getLogger
|
||||
from typing import (
|
||||
Annotated,
|
||||
|
@ -41,7 +42,8 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
|||
"""
|
||||
signature = inspect.signature(call)
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
type_hints = typing.get_type_hints(call, globalns, include_extras=True)
|
||||
func_call = call.func if isinstance(call, partial) else call
|
||||
type_hints = typing.get_type_hints(func_call, globalns, include_extras=True)
|
||||
typed_params = [
|
||||
inspect.Parameter(
|
||||
name=param.name,
|
||||
|
|
|
@ -88,7 +88,7 @@ class FunctionTool(BaseTool[BaseModel, BaseModel], Component[FunctionToolConfig]
|
|||
self._func = func
|
||||
self._global_imports = global_imports
|
||||
signature = get_typed_signature(func)
|
||||
func_name = name or func.__name__
|
||||
func_name = name or func.func.__name__ if isinstance(func, functools.partial) else name or func.__name__
|
||||
args_model = args_base_model_from_signature(func_name + "args", signature)
|
||||
return_type = signature.return_annotation
|
||||
self._has_cancellation_support = "cancellation_token" in signature.parameters
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import inspect
|
||||
from functools import partial
|
||||
from typing import Annotated, List
|
||||
|
||||
import pytest
|
||||
|
@ -109,6 +110,67 @@ def test_func_tool_schema_generation_only_default_arg() -> None:
|
|||
assert "required" not in schema["parameters"]
|
||||
|
||||
|
||||
def test_func_tool_with_partial_positional_arguments_schema_generation() -> None:
|
||||
"""Test correct schema generation for a partial function with positional arguments."""
|
||||
|
||||
def get_weather(country: str, city: str) -> str:
|
||||
return f"The temperature in {city}, {country} is 75°"
|
||||
|
||||
partial_function = partial(get_weather, "Germany")
|
||||
tool = FunctionTool(partial_function, description="Partial function tool.")
|
||||
schema = tool.schema
|
||||
|
||||
assert schema["name"] == "get_weather"
|
||||
assert "description" in schema
|
||||
assert schema["description"] == "Partial function tool."
|
||||
assert "parameters" in schema
|
||||
assert schema["parameters"]["type"] == "object"
|
||||
assert schema["parameters"]["properties"].keys() == {"city"}
|
||||
assert schema["parameters"]["properties"]["city"]["type"] == "string"
|
||||
assert schema["parameters"]["properties"]["city"]["description"] == "city"
|
||||
assert "required" in schema["parameters"]
|
||||
assert schema["parameters"]["required"] == ["city"]
|
||||
assert "country" not in schema["parameters"]["properties"] # check country not in schema params
|
||||
assert len(schema["parameters"]["properties"]) == 1
|
||||
|
||||
|
||||
def test_func_call_tool_with_kwargs_schema_generation() -> None:
|
||||
"""Test correct schema generation for a partial function with kwargs."""
|
||||
|
||||
def get_weather(country: str, city: str) -> str:
|
||||
return f"The temperature in {city}, {country} is 75°"
|
||||
|
||||
partial_function = partial(get_weather, country="Germany")
|
||||
tool = FunctionTool(partial_function, description="Partial function tool.")
|
||||
schema = tool.schema
|
||||
|
||||
assert schema["name"] == "get_weather"
|
||||
assert "description" in schema
|
||||
assert schema["description"] == "Partial function tool."
|
||||
assert "parameters" in schema
|
||||
assert schema["parameters"]["type"] == "object"
|
||||
assert schema["parameters"]["properties"].keys() == {"country", "city"}
|
||||
assert schema["parameters"]["properties"]["city"]["type"] == "string"
|
||||
assert schema["parameters"]["properties"]["country"]["type"] == "string"
|
||||
assert "required" in schema["parameters"]
|
||||
assert schema["parameters"]["required"] == ["city"] # only city is required
|
||||
assert len(schema["parameters"]["properties"]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_func_call_tool_with_kwargs_and_args() -> None:
|
||||
"""Test run partial function with kwargs and args."""
|
||||
|
||||
def get_weather(country: str, city: str, unit: str = "Celsius") -> str:
|
||||
return f"The temperature in {city}, {country} is 75° {unit}"
|
||||
|
||||
partial_function = partial(get_weather, "Germany", unit="Fahrenheit")
|
||||
tool = FunctionTool(partial_function, description="Partial function tool.")
|
||||
result = await tool.run_json({"city": "Berlin"}, CancellationToken())
|
||||
assert isinstance(result, str)
|
||||
assert result == "The temperature in Berlin, Germany is 75° Fahrenheit"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_run() -> None:
|
||||
tool = MyTool()
|
||||
|
|
Loading…
Reference in New Issue