Adding declarative HTTP tools to autogen ext (#5181)

## Why are these changes needed?
These changes are needed because currently there's no generic way to add
`tools` to autogen studio workflows using the existing DSL and schema
other than inline python.

This API will be quite verbose, and lacks a discovery mechanism, but it
unlocks a lot of programmatic use-cases.

## Related issue number
https://github.com/microsoft/autogen/issues/5170

Co-authored-by: Victor Dibia <victordibia@microsoft.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Eitan Yarmush 2025-02-10 15:27:27 -05:00 committed by GitHub
parent 9e15e9529c
commit 8a9f452136
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 565 additions and 5 deletions

View File

@ -46,17 +46,17 @@ python/autogen_ext.agents.web_surfer
python/autogen_ext.agents.file_surfer
python/autogen_ext.agents.video_surfer
python/autogen_ext.agents.video_surfer.tools
python/autogen_ext.auth.azure
python/autogen_ext.teams.magentic_one
python/autogen_ext.models.cache
python/autogen_ext.models.openai
python/autogen_ext.models.replay
python/autogen_ext.models.azure
python/autogen_ext.models.semantic_kernel
python/autogen_ext.tools.code_execution
python/autogen_ext.tools.graphrag
python/autogen_ext.tools.http
python/autogen_ext.tools.langchain
python/autogen_ext.tools.mcp
python/autogen_ext.tools.graphrag
python/autogen_ext.tools.code_execution
python/autogen_ext.tools.semantic_kernel
python/autogen_ext.code_executors.local
python/autogen_ext.code_executors.docker
@ -65,4 +65,5 @@ python/autogen_ext.code_executors.azure
python/autogen_ext.cache_store.diskcache
python/autogen_ext.cache_store.redis
python/autogen_ext.runtimes.grpc
python/autogen_ext.auth.azure
```

View File

@ -0,0 +1,8 @@
autogen\_ext.tools.http
=======================
.. automodule:: autogen_ext.tools.http
:members:
:undoc-members:
:show-inheritance:

View File

@ -106,6 +106,11 @@ semantic-kernel-dapr = [
"semantic-kernel[dapr]>=1.17.1",
]
http-tool = [
"httpx>=0.27.0",
"json-schema-to-pydantic>=0.2.0"
]
semantic-kernel-all = [
"semantic-kernel[google,hugging_face,mistralai,ollama,onnx,anthropic,usearch,pandas,aws,dapr]>=1.17.1",
]

View File

@ -0,0 +1,3 @@
from ._http_tool import HttpTool
__all__ = ["HttpTool"]

View File

@ -0,0 +1,233 @@
import re
from typing import Any, Literal, Optional, Type
import httpx
from autogen_core import CancellationToken, Component
from autogen_core.tools import BaseTool
from json_schema_to_pydantic import create_model
from pydantic import BaseModel, Field
from typing_extensions import Self
class HttpToolConfig(BaseModel):
name: str
"""
The name of the tool.
"""
description: Optional[str]
"""
A description of the tool.
"""
scheme: Literal["http", "https"] = "http"
"""
The scheme to use for the request.
"""
host: str
"""
The URL to send the request to.
"""
port: int
"""
The port to send the request to.
"""
path: str = Field(default="/")
"""
The path to send the request to. defaults to "/"
The path can accept parameters, e.g. "/{param1}/{param2}".
These parameters will be templated from the inputs args, any additional parameters will be added as query parameters or the body of the request.
"""
method: Optional[Literal["GET", "POST", "PUT", "DELETE", "PATCH"]] = "POST"
"""
The HTTP method to use, will default to POST if not provided.
"""
headers: Optional[dict[str, Any]]
"""
A dictionary of headers to send with the request.
"""
json_schema: dict[str, Any]
"""
A JSON Schema object defining the expected parameters for the tool.
Path parameters MUST also be included in the json_schema. They must also MUST be set to string
"""
return_type: Optional[Literal["text", "json"]] = "text"
"""
The type of response to return from the tool.
"""
class HttpTool(BaseTool[BaseModel, Any], Component[HttpToolConfig]):
"""A wrapper for using an HTTP server as a tool.
Args:
name (str): The name of the tool.
description (str, optional): A description of the tool.
scheme (str): The scheme to use for the request. Must be either "http" or "https".
host (str): The host to send the request to.
port (int): The port to send the request to.
path (str, optional): The path to send the request to. Defaults to "/".
Can include path parameters like "/{param1}/{param2}" which will be templated from input args.
method (str, optional): The HTTP method to use, will default to POST if not provided.
Must be one of "GET", "POST", "PUT", "DELETE", "PATCH".
headers (dict[str, Any], optional): A dictionary of headers to send with the request.
json_schema (dict[str, Any]): A JSON Schema object defining the expected parameters for the tool.
Path parameters must also be included in the schema and must be strings.
return_type (Literal["text", "json"], optional): The type of response to return from the tool.
Defaults to "text".
.. note::
This tool requires the :code:`http-tool` extra for the :code:`autogen-ext` package.
To install:
.. code-block:: bash
pip install -U "autogen-agentchat" "autogen-ext[http-tool]"
Example:
Simple use case::
import asyncio
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_core import CancellationToken
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.tools.http import HttpTool
# Define a JSON schema for a base64 decode tool
base64_schema = {
"type": "object",
"properties": {
"value": {"type": "string", "description": "The base64 value to decode"},
},
"required": ["value"],
}
# Create an HTTP tool for the httpbin API
base64_tool = HttpTool(
name="base64_decode",
description="base64 decode a value",
scheme="https",
host="httpbin.org",
port=443,
path="/base64/{value}",
method="GET",
json_schema=base64_schema,
)
async def main():
# Create an assistant with the base64 tool
model = OpenAIChatCompletionClient(model="gpt-4")
assistant = AssistantAgent("base64_assistant", model_client=model, tools=[base64_tool])
# The assistant can now use the base64 tool to decode the string
response = await assistant.on_messages(
[TextMessage(content="Can you base64 decode the value 'YWJjZGU=', please?", source="user")],
CancellationToken(),
)
print(response.chat_message.content)
asyncio.run(main())
"""
component_type = "tool"
component_provider_override = "autogen_ext.tools.http.HttpTool"
component_config_schema = HttpToolConfig
def __init__(
self,
name: str,
host: str,
port: int,
json_schema: dict[str, Any],
headers: Optional[dict[str, Any]] = None,
description: str = "HTTP tool",
path: str = "/",
scheme: Literal["http", "https"] = "http",
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "POST",
return_type: Literal["text", "json"] = "text",
) -> None:
self.server_params = HttpToolConfig(
name=name,
description=description,
host=host,
port=port,
path=path,
scheme=scheme,
method=method,
headers=headers,
json_schema=json_schema,
return_type=return_type,
)
# Use regex to find all path parameters, we will need those later to template the path
path_params = {match.group(1) for match in re.finditer(r"{([^}]*)}", path)}
self._path_params = path_params
# Create the input model from the modified schema
input_model = create_model(json_schema)
# Use Any as return type since HTTP responses can vary
base_return_type: Type[Any] = object
super().__init__(input_model, base_return_type, name, description)
def _to_config(self) -> HttpToolConfig:
copied_config = self.server_params.model_copy()
return copied_config
@classmethod
def _from_config(cls, config: HttpToolConfig) -> Self:
copied_config = config.model_copy().model_dump()
return cls(**copied_config)
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
"""Execute the HTTP tool with the given arguments.
Args:
args: The validated input arguments
cancellation_token: Token for cancelling the operation
Returns:
The response body from the HTTP call in JSON format
Raises:
Exception: If tool execution fails
"""
model_dump = args.model_dump()
path_params = {k: v for k, v in model_dump.items() if k in self._path_params}
# Remove path params from the model dump
for k in self._path_params:
model_dump.pop(k)
path = self.server_params.path.format(**path_params)
url = httpx.URL(
scheme=self.server_params.scheme,
host=self.server_params.host,
port=self.server_params.port,
path=path,
)
async with httpx.AsyncClient() as client:
match self.server_params.method:
case "GET":
response = await client.get(url, params=model_dump)
case "PUT":
response = await client.put(url, json=model_dump)
case "DELETE":
response = await client.delete(url, params=model_dump)
case "PATCH":
response = await client.patch(url, json=model_dump)
case _: # Default case POST
response = await client.post(url, json=model_dump)
match self.server_params.return_type:
case "text":
return response.text
case "json":
return response.json()
case _:
raise ValueError(f"Invalid return type: {self.server_params.return_type}")

View File

@ -29,7 +29,8 @@ from autogen_test_utils import (
MessageType,
NoopAgent,
)
from protos.serialization_test_pb2 import ProtoMessage
from .protos.serialization_test_pb2 import ProtoMessage
@pytest.mark.grpc
@ -423,7 +424,7 @@ class ProtoReceivingAgent(RoutedAgent):
self.received_messages: list[Any] = []
@event
async def on_new_message(self, message: ProtoMessage, ctx: MessageContext) -> None:
async def on_new_message(self, message: ProtoMessage, ctx: MessageContext) -> None: # type: ignore
self.num_calls += 1
self.received_messages.append(message)

View File

@ -0,0 +1,101 @@
import asyncio
from typing import Any, AsyncGenerator, Dict
import pytest
import pytest_asyncio
import uvicorn
from autogen_core import ComponentModel
from fastapi import FastAPI
from pydantic import BaseModel, Field
class TestArgs(BaseModel):
query: str = Field(description="The test query")
value: int = Field(description="A test value")
class TestResponse(BaseModel):
result: str = Field(description="The test result")
# Create a test FastAPI app
app = FastAPI()
@app.post("/test")
async def test_endpoint(body: TestArgs) -> TestResponse:
return TestResponse(result=f"Received: {body.query} with value {body.value}")
@app.post("/test/{query}/{value}")
async def test_path_params_endpoint(query: str, value: int) -> TestResponse:
return TestResponse(result=f"Received: {query} with value {value}")
@app.put("/test/{query}/{value}")
async def test_path_params_and_body_endpoint(query: str, value: int, body: Dict[str, Any]) -> TestResponse:
return TestResponse(result=f"Received: {query} with value {value} and extra {body.get('extra')}") # type: ignore
@app.get("/test")
async def test_get_endpoint(query: str, value: int) -> TestResponse:
return TestResponse(result=f"Received: {query} with value {value}")
@app.put("/test")
async def test_put_endpoint(body: TestArgs) -> TestResponse:
return TestResponse(result=f"Received: {body.query} with value {body.value}")
@app.delete("/test")
async def test_delete_endpoint(query: str, value: int) -> TestResponse:
return TestResponse(result=f"Received: {query} with value {value}")
@app.patch("/test")
async def test_patch_endpoint(body: TestArgs) -> TestResponse:
return TestResponse(result=f"Received: {body.query} with value {body.value}")
@pytest.fixture
def test_config() -> ComponentModel:
return ComponentModel(
provider="autogen_ext.tools.http.HttpTool",
config={
"name": "TestHttpTool",
"description": "A test HTTP tool",
"scheme": "http",
"path": "/test",
"host": "localhost",
"port": 8000,
"method": "POST",
"headers": {"Content-Type": "application/json"},
"json_schema": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "The test query"},
"value": {"type": "integer", "description": "A test value"},
},
"required": ["query", "value"],
},
},
)
@pytest_asyncio.fixture(scope="function") # type: ignore
async def test_server() -> AsyncGenerator[None, None]:
# Start the test server
config = uvicorn.Config(app, host="127.0.0.1", port=8000, log_level="error")
server = uvicorn.Server(config)
# Create a task for the server
server_task = asyncio.create_task(server.serve())
# Wait a bit for server to start
await asyncio.sleep(0.5) # Increased sleep time to ensure server is ready
yield
# Cleanup
server.should_exit = True
await server_task

View File

@ -0,0 +1,202 @@
import json
import httpx
import pytest
from autogen_core import CancellationToken, Component, ComponentModel
from autogen_ext.tools.http import HttpTool
from pydantic import ValidationError
def test_tool_schema_generation(test_config: ComponentModel) -> None:
tool = HttpTool.load_component(test_config)
schema = tool.schema
assert schema["name"] == "TestHttpTool"
assert "description" in schema
assert schema["description"] == "A test HTTP tool"
assert "parameters" in schema
assert schema["parameters"]["type"] == "object"
assert "properties" in schema["parameters"]
assert schema["parameters"]["properties"]["query"]["description"] == "The test query"
assert schema["parameters"]["properties"]["query"]["type"] == "string"
assert schema["parameters"]["properties"]["value"]["description"] == "A test value"
assert schema["parameters"]["properties"]["value"]["type"] == "integer"
assert "required" in schema["parameters"]
assert set(schema["parameters"]["required"]) == {"query", "value"}
def test_tool_properties(test_config: ComponentModel) -> None:
tool = HttpTool.load_component(test_config)
assert tool.name == "TestHttpTool"
assert tool.description == "A test HTTP tool"
assert tool.server_params.host == "localhost"
assert tool.server_params.port == 8000
assert tool.server_params.path == "/test"
assert tool.server_params.scheme == "http"
assert tool.server_params.method == "POST"
def test_component_base_class(test_config: ComponentModel) -> None:
tool = HttpTool.load_component(test_config)
assert tool.dump_component() is not None
assert HttpTool.load_component(tool.dump_component(), HttpTool) is not None
assert isinstance(tool, Component)
@pytest.mark.asyncio
async def test_post_request(test_config: ComponentModel, test_server: None) -> None:
tool = HttpTool.load_component(test_config)
result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken())
assert isinstance(result, str)
assert json.loads(result)["result"] == "Received: test query with value 42"
@pytest.mark.asyncio
async def test_post_request_json_return(test_config: ComponentModel, test_server: None) -> None:
# Modify config to use json return type
config = test_config.model_copy()
config.config["return_type"] = "json"
tool = HttpTool.load_component(config)
result = await tool.run_json({"query": "test query", "value": 45}, CancellationToken())
assert isinstance(result, dict)
assert result["result"] == "Received: test query with value 45"
@pytest.mark.asyncio
async def test_get_request(test_config: ComponentModel, test_server: None) -> None:
# Modify config for GET request
config = test_config.model_copy()
config.config["method"] = "GET"
tool = HttpTool.load_component(config)
result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken())
assert isinstance(result, str)
assert json.loads(result)["result"] == "Received: test query with value 42"
@pytest.mark.asyncio
async def test_put_request(test_config: ComponentModel, test_server: None) -> None:
# Modify config for PUT request
config = test_config.model_copy()
config.config["method"] = "PUT"
tool = HttpTool.load_component(config)
result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken())
assert isinstance(result, str)
assert json.loads(result)["result"] == "Received: test query with value 42"
@pytest.mark.asyncio
async def test_path_params(test_config: ComponentModel, test_server: None) -> None:
# Modify config to use path parameters
config = test_config.model_copy()
config.config["path"] = "/test/{query}/{value}"
tool = HttpTool.load_component(config)
result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken())
assert isinstance(result, str)
assert json.loads(result)["result"] == "Received: test query with value 42"
@pytest.mark.asyncio
async def test_path_params_and_body(test_config: ComponentModel, test_server: None) -> None:
# Modify config to use path parameters and include body parameters
config = test_config.model_copy()
config.config["method"] = "PUT"
config.config["path"] = "/test/{query}/{value}"
config.config["json_schema"] = {
"type": "object",
"properties": {
"query": {"type": "string", "description": "The test query"},
"value": {"type": "integer", "description": "A test value"},
"extra": {"type": "string", "description": "Extra body parameter"},
},
"required": ["query", "value", "extra"],
}
tool = HttpTool.load_component(config)
result = await tool.run_json({"query": "test query", "value": 42, "extra": "extra data"}, CancellationToken())
assert isinstance(result, str)
assert json.loads(result)["result"] == "Received: test query with value 42 and extra extra data"
@pytest.mark.asyncio
async def test_delete_request(test_config: ComponentModel, test_server: None) -> None:
# Modify config for DELETE request
config = test_config.model_copy()
config.config["method"] = "DELETE"
tool = HttpTool.load_component(config)
result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken())
assert isinstance(result, str)
assert json.loads(result)["result"] == "Received: test query with value 42"
@pytest.mark.asyncio
async def test_patch_request(test_config: ComponentModel, test_server: None) -> None:
# Modify config for PATCH request
config = test_config.model_copy()
config.config["method"] = "PATCH"
tool = HttpTool.load_component(config)
result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken())
assert isinstance(result, str)
assert json.loads(result)["result"] == "Received: test query with value 42"
@pytest.mark.asyncio
async def test_invalid_schema(test_config: ComponentModel, test_server: None) -> None:
# Create an invalid schema missing required properties
config: ComponentModel = test_config.model_copy()
config.config["host"] = True # Incorrect type
with pytest.raises(ValidationError):
# Should fail when trying to create model from invalid schema
HttpTool.load_component(config)
@pytest.mark.asyncio
async def test_invalid_request(test_config: ComponentModel, test_server: None) -> None:
# Use an invalid URL
config = test_config.model_copy()
config.config["host"] = "fake"
tool = HttpTool.load_component(config)
with pytest.raises(httpx.ConnectError):
await tool.run_json({"query": "test query", "value": 42}, CancellationToken())
def test_config_serialization(test_config: ComponentModel) -> None:
tool = HttpTool.load_component(test_config)
config = tool.dump_component()
assert config.config["name"] == test_config.config["name"]
assert config.config["description"] == test_config.config["description"]
assert config.config["host"] == test_config.config["host"]
assert config.config["port"] == test_config.config["port"]
assert config.config["path"] == test_config.config["path"]
assert config.config["scheme"] == test_config.config["scheme"]
assert config.config["method"] == test_config.config["method"]
assert config.config["headers"] == test_config.config["headers"]
def test_config_deserialization(test_config: ComponentModel) -> None:
tool = HttpTool.load_component(test_config)
assert tool.name == test_config.config["name"]
assert tool.description == test_config.config["description"]
assert tool.server_params.host == test_config.config["host"]
assert tool.server_params.port == test_config.config["port"]
assert tool.server_params.path == test_config.config["path"]
assert tool.server_params.scheme == test_config.config["scheme"]
assert tool.server_params.method == test_config.config["method"]
assert tool.server_params.headers == test_config.config["headers"]

View File

@ -599,6 +599,10 @@ graphrag = [
grpc = [
{ name = "grpcio" },
]
http-tool = [
{ name = "httpx" },
{ name = "json-schema-to-pydantic" },
]
jupyter-executor = [
{ name = "ipykernel" },
{ name = "nbclient" },
@ -698,7 +702,9 @@ requires-dist = [
{ name = "ffmpeg-python", marker = "extra == 'video-surfer'" },
{ name = "graphrag", marker = "extra == 'graphrag'", specifier = ">=1.0.1" },
{ name = "grpcio", marker = "extra == 'grpc'", specifier = "~=1.70.0" },
{ name = "httpx", marker = "extra == 'http-tool'", specifier = ">=0.27.0" },
{ name = "ipykernel", marker = "extra == 'jupyter-executor'", specifier = ">=6.29.5" },
{ name = "json-schema-to-pydantic", marker = "extra == 'http-tool'", specifier = ">=0.2.0" },
{ name = "json-schema-to-pydantic", marker = "extra == 'mcp'", specifier = ">=0.2.2" },
{ name = "langchain-core", marker = "extra == 'langchain'", specifier = "~=0.3.3" },
{ name = "markitdown", marker = "extra == 'file-surfer'", specifier = ">=0.0.1a2" },