mirror of https://github.com/microsoft/autogen.git
Adding declarative HTTP tools to autogen ext (#5181)
## Why are these changes needed? These changes are needed because currently there's no generic way to add `tools` to autogen studio workflows using the existing DSL and schema other than inline python. This API will be quite verbose, and lacks a discovery mechanism, but it unlocks a lot of programmatic use-cases. ## Related issue number https://github.com/microsoft/autogen/issues/5170 Co-authored-by: Victor Dibia <victordibia@microsoft.com> Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
9e15e9529c
commit
8a9f452136
|
@ -46,17 +46,17 @@ python/autogen_ext.agents.web_surfer
|
|||
python/autogen_ext.agents.file_surfer
|
||||
python/autogen_ext.agents.video_surfer
|
||||
python/autogen_ext.agents.video_surfer.tools
|
||||
python/autogen_ext.auth.azure
|
||||
python/autogen_ext.teams.magentic_one
|
||||
python/autogen_ext.models.cache
|
||||
python/autogen_ext.models.openai
|
||||
python/autogen_ext.models.replay
|
||||
python/autogen_ext.models.azure
|
||||
python/autogen_ext.models.semantic_kernel
|
||||
python/autogen_ext.tools.code_execution
|
||||
python/autogen_ext.tools.graphrag
|
||||
python/autogen_ext.tools.http
|
||||
python/autogen_ext.tools.langchain
|
||||
python/autogen_ext.tools.mcp
|
||||
python/autogen_ext.tools.graphrag
|
||||
python/autogen_ext.tools.code_execution
|
||||
python/autogen_ext.tools.semantic_kernel
|
||||
python/autogen_ext.code_executors.local
|
||||
python/autogen_ext.code_executors.docker
|
||||
|
@ -65,4 +65,5 @@ python/autogen_ext.code_executors.azure
|
|||
python/autogen_ext.cache_store.diskcache
|
||||
python/autogen_ext.cache_store.redis
|
||||
python/autogen_ext.runtimes.grpc
|
||||
python/autogen_ext.auth.azure
|
||||
```
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
autogen\_ext.tools.http
|
||||
=======================
|
||||
|
||||
|
||||
.. automodule:: autogen_ext.tools.http
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
|
@ -106,6 +106,11 @@ semantic-kernel-dapr = [
|
|||
"semantic-kernel[dapr]>=1.17.1",
|
||||
]
|
||||
|
||||
http-tool = [
|
||||
"httpx>=0.27.0",
|
||||
"json-schema-to-pydantic>=0.2.0"
|
||||
]
|
||||
|
||||
semantic-kernel-all = [
|
||||
"semantic-kernel[google,hugging_face,mistralai,ollama,onnx,anthropic,usearch,pandas,aws,dapr]>=1.17.1",
|
||||
]
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
from ._http_tool import HttpTool
|
||||
|
||||
__all__ = ["HttpTool"]
|
|
@ -0,0 +1,233 @@
|
|||
import re
|
||||
from typing import Any, Literal, Optional, Type
|
||||
|
||||
import httpx
|
||||
from autogen_core import CancellationToken, Component
|
||||
from autogen_core.tools import BaseTool
|
||||
from json_schema_to_pydantic import create_model
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class HttpToolConfig(BaseModel):
|
||||
name: str
|
||||
"""
|
||||
The name of the tool.
|
||||
"""
|
||||
description: Optional[str]
|
||||
"""
|
||||
A description of the tool.
|
||||
"""
|
||||
scheme: Literal["http", "https"] = "http"
|
||||
"""
|
||||
The scheme to use for the request.
|
||||
"""
|
||||
host: str
|
||||
"""
|
||||
The URL to send the request to.
|
||||
"""
|
||||
port: int
|
||||
"""
|
||||
The port to send the request to.
|
||||
"""
|
||||
path: str = Field(default="/")
|
||||
"""
|
||||
The path to send the request to. defaults to "/"
|
||||
The path can accept parameters, e.g. "/{param1}/{param2}".
|
||||
These parameters will be templated from the inputs args, any additional parameters will be added as query parameters or the body of the request.
|
||||
"""
|
||||
method: Optional[Literal["GET", "POST", "PUT", "DELETE", "PATCH"]] = "POST"
|
||||
"""
|
||||
The HTTP method to use, will default to POST if not provided.
|
||||
"""
|
||||
headers: Optional[dict[str, Any]]
|
||||
"""
|
||||
A dictionary of headers to send with the request.
|
||||
"""
|
||||
json_schema: dict[str, Any]
|
||||
"""
|
||||
A JSON Schema object defining the expected parameters for the tool.
|
||||
Path parameters MUST also be included in the json_schema. They must also MUST be set to string
|
||||
"""
|
||||
return_type: Optional[Literal["text", "json"]] = "text"
|
||||
"""
|
||||
The type of response to return from the tool.
|
||||
"""
|
||||
|
||||
|
||||
class HttpTool(BaseTool[BaseModel, Any], Component[HttpToolConfig]):
|
||||
"""A wrapper for using an HTTP server as a tool.
|
||||
|
||||
Args:
|
||||
name (str): The name of the tool.
|
||||
description (str, optional): A description of the tool.
|
||||
scheme (str): The scheme to use for the request. Must be either "http" or "https".
|
||||
host (str): The host to send the request to.
|
||||
port (int): The port to send the request to.
|
||||
path (str, optional): The path to send the request to. Defaults to "/".
|
||||
Can include path parameters like "/{param1}/{param2}" which will be templated from input args.
|
||||
method (str, optional): The HTTP method to use, will default to POST if not provided.
|
||||
Must be one of "GET", "POST", "PUT", "DELETE", "PATCH".
|
||||
headers (dict[str, Any], optional): A dictionary of headers to send with the request.
|
||||
json_schema (dict[str, Any]): A JSON Schema object defining the expected parameters for the tool.
|
||||
Path parameters must also be included in the schema and must be strings.
|
||||
return_type (Literal["text", "json"], optional): The type of response to return from the tool.
|
||||
Defaults to "text".
|
||||
|
||||
.. note::
|
||||
This tool requires the :code:`http-tool` extra for the :code:`autogen-ext` package.
|
||||
|
||||
To install:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U "autogen-agentchat" "autogen-ext[http-tool]"
|
||||
|
||||
Example:
|
||||
Simple use case::
|
||||
|
||||
import asyncio
|
||||
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.messages import TextMessage
|
||||
from autogen_core import CancellationToken
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
from autogen_ext.tools.http import HttpTool
|
||||
|
||||
# Define a JSON schema for a base64 decode tool
|
||||
base64_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {"type": "string", "description": "The base64 value to decode"},
|
||||
},
|
||||
"required": ["value"],
|
||||
}
|
||||
|
||||
# Create an HTTP tool for the httpbin API
|
||||
base64_tool = HttpTool(
|
||||
name="base64_decode",
|
||||
description="base64 decode a value",
|
||||
scheme="https",
|
||||
host="httpbin.org",
|
||||
port=443,
|
||||
path="/base64/{value}",
|
||||
method="GET",
|
||||
json_schema=base64_schema,
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
# Create an assistant with the base64 tool
|
||||
model = OpenAIChatCompletionClient(model="gpt-4")
|
||||
assistant = AssistantAgent("base64_assistant", model_client=model, tools=[base64_tool])
|
||||
|
||||
# The assistant can now use the base64 tool to decode the string
|
||||
response = await assistant.on_messages(
|
||||
[TextMessage(content="Can you base64 decode the value 'YWJjZGU=', please?", source="user")],
|
||||
CancellationToken(),
|
||||
)
|
||||
print(response.chat_message.content)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_type = "tool"
|
||||
component_provider_override = "autogen_ext.tools.http.HttpTool"
|
||||
component_config_schema = HttpToolConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
host: str,
|
||||
port: int,
|
||||
json_schema: dict[str, Any],
|
||||
headers: Optional[dict[str, Any]] = None,
|
||||
description: str = "HTTP tool",
|
||||
path: str = "/",
|
||||
scheme: Literal["http", "https"] = "http",
|
||||
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "POST",
|
||||
return_type: Literal["text", "json"] = "text",
|
||||
) -> None:
|
||||
self.server_params = HttpToolConfig(
|
||||
name=name,
|
||||
description=description,
|
||||
host=host,
|
||||
port=port,
|
||||
path=path,
|
||||
scheme=scheme,
|
||||
method=method,
|
||||
headers=headers,
|
||||
json_schema=json_schema,
|
||||
return_type=return_type,
|
||||
)
|
||||
|
||||
# Use regex to find all path parameters, we will need those later to template the path
|
||||
path_params = {match.group(1) for match in re.finditer(r"{([^}]*)}", path)}
|
||||
self._path_params = path_params
|
||||
|
||||
# Create the input model from the modified schema
|
||||
input_model = create_model(json_schema)
|
||||
|
||||
# Use Any as return type since HTTP responses can vary
|
||||
base_return_type: Type[Any] = object
|
||||
|
||||
super().__init__(input_model, base_return_type, name, description)
|
||||
|
||||
def _to_config(self) -> HttpToolConfig:
|
||||
copied_config = self.server_params.model_copy()
|
||||
return copied_config
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: HttpToolConfig) -> Self:
|
||||
copied_config = config.model_copy().model_dump()
|
||||
return cls(**copied_config)
|
||||
|
||||
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
|
||||
"""Execute the HTTP tool with the given arguments.
|
||||
|
||||
Args:
|
||||
args: The validated input arguments
|
||||
cancellation_token: Token for cancelling the operation
|
||||
|
||||
Returns:
|
||||
The response body from the HTTP call in JSON format
|
||||
|
||||
Raises:
|
||||
Exception: If tool execution fails
|
||||
"""
|
||||
|
||||
model_dump = args.model_dump()
|
||||
path_params = {k: v for k, v in model_dump.items() if k in self._path_params}
|
||||
# Remove path params from the model dump
|
||||
for k in self._path_params:
|
||||
model_dump.pop(k)
|
||||
|
||||
path = self.server_params.path.format(**path_params)
|
||||
|
||||
url = httpx.URL(
|
||||
scheme=self.server_params.scheme,
|
||||
host=self.server_params.host,
|
||||
port=self.server_params.port,
|
||||
path=path,
|
||||
)
|
||||
async with httpx.AsyncClient() as client:
|
||||
match self.server_params.method:
|
||||
case "GET":
|
||||
response = await client.get(url, params=model_dump)
|
||||
case "PUT":
|
||||
response = await client.put(url, json=model_dump)
|
||||
case "DELETE":
|
||||
response = await client.delete(url, params=model_dump)
|
||||
case "PATCH":
|
||||
response = await client.patch(url, json=model_dump)
|
||||
case _: # Default case POST
|
||||
response = await client.post(url, json=model_dump)
|
||||
|
||||
match self.server_params.return_type:
|
||||
case "text":
|
||||
return response.text
|
||||
case "json":
|
||||
return response.json()
|
||||
case _:
|
||||
raise ValueError(f"Invalid return type: {self.server_params.return_type}")
|
|
@ -29,7 +29,8 @@ from autogen_test_utils import (
|
|||
MessageType,
|
||||
NoopAgent,
|
||||
)
|
||||
from protos.serialization_test_pb2 import ProtoMessage
|
||||
|
||||
from .protos.serialization_test_pb2 import ProtoMessage
|
||||
|
||||
|
||||
@pytest.mark.grpc
|
||||
|
@ -423,7 +424,7 @@ class ProtoReceivingAgent(RoutedAgent):
|
|||
self.received_messages: list[Any] = []
|
||||
|
||||
@event
|
||||
async def on_new_message(self, message: ProtoMessage, ctx: MessageContext) -> None:
|
||||
async def on_new_message(self, message: ProtoMessage, ctx: MessageContext) -> None: # type: ignore
|
||||
self.num_calls += 1
|
||||
self.received_messages.append(message)
|
||||
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
import asyncio
|
||||
from typing import Any, AsyncGenerator, Dict
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import uvicorn
|
||||
from autogen_core import ComponentModel
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TestArgs(BaseModel):
|
||||
query: str = Field(description="The test query")
|
||||
value: int = Field(description="A test value")
|
||||
|
||||
|
||||
class TestResponse(BaseModel):
|
||||
result: str = Field(description="The test result")
|
||||
|
||||
|
||||
# Create a test FastAPI app
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.post("/test")
|
||||
async def test_endpoint(body: TestArgs) -> TestResponse:
|
||||
return TestResponse(result=f"Received: {body.query} with value {body.value}")
|
||||
|
||||
|
||||
@app.post("/test/{query}/{value}")
|
||||
async def test_path_params_endpoint(query: str, value: int) -> TestResponse:
|
||||
return TestResponse(result=f"Received: {query} with value {value}")
|
||||
|
||||
|
||||
@app.put("/test/{query}/{value}")
|
||||
async def test_path_params_and_body_endpoint(query: str, value: int, body: Dict[str, Any]) -> TestResponse:
|
||||
return TestResponse(result=f"Received: {query} with value {value} and extra {body.get('extra')}") # type: ignore
|
||||
|
||||
|
||||
@app.get("/test")
|
||||
async def test_get_endpoint(query: str, value: int) -> TestResponse:
|
||||
return TestResponse(result=f"Received: {query} with value {value}")
|
||||
|
||||
|
||||
@app.put("/test")
|
||||
async def test_put_endpoint(body: TestArgs) -> TestResponse:
|
||||
return TestResponse(result=f"Received: {body.query} with value {body.value}")
|
||||
|
||||
|
||||
@app.delete("/test")
|
||||
async def test_delete_endpoint(query: str, value: int) -> TestResponse:
|
||||
return TestResponse(result=f"Received: {query} with value {value}")
|
||||
|
||||
|
||||
@app.patch("/test")
|
||||
async def test_patch_endpoint(body: TestArgs) -> TestResponse:
|
||||
return TestResponse(result=f"Received: {body.query} with value {body.value}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_config() -> ComponentModel:
|
||||
return ComponentModel(
|
||||
provider="autogen_ext.tools.http.HttpTool",
|
||||
config={
|
||||
"name": "TestHttpTool",
|
||||
"description": "A test HTTP tool",
|
||||
"scheme": "http",
|
||||
"path": "/test",
|
||||
"host": "localhost",
|
||||
"port": 8000,
|
||||
"method": "POST",
|
||||
"headers": {"Content-Type": "application/json"},
|
||||
"json_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "The test query"},
|
||||
"value": {"type": "integer", "description": "A test value"},
|
||||
},
|
||||
"required": ["query", "value"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function") # type: ignore
|
||||
async def test_server() -> AsyncGenerator[None, None]:
|
||||
# Start the test server
|
||||
config = uvicorn.Config(app, host="127.0.0.1", port=8000, log_level="error")
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
# Create a task for the server
|
||||
server_task = asyncio.create_task(server.serve())
|
||||
|
||||
# Wait a bit for server to start
|
||||
await asyncio.sleep(0.5) # Increased sleep time to ensure server is ready
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
server.should_exit = True
|
||||
await server_task
|
|
@ -0,0 +1,202 @@
|
|||
import json
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from autogen_core import CancellationToken, Component, ComponentModel
|
||||
from autogen_ext.tools.http import HttpTool
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
def test_tool_schema_generation(test_config: ComponentModel) -> None:
|
||||
tool = HttpTool.load_component(test_config)
|
||||
schema = tool.schema
|
||||
|
||||
assert schema["name"] == "TestHttpTool"
|
||||
assert "description" in schema
|
||||
assert schema["description"] == "A test HTTP tool"
|
||||
assert "parameters" in schema
|
||||
assert schema["parameters"]["type"] == "object"
|
||||
assert "properties" in schema["parameters"]
|
||||
assert schema["parameters"]["properties"]["query"]["description"] == "The test query"
|
||||
assert schema["parameters"]["properties"]["query"]["type"] == "string"
|
||||
assert schema["parameters"]["properties"]["value"]["description"] == "A test value"
|
||||
assert schema["parameters"]["properties"]["value"]["type"] == "integer"
|
||||
assert "required" in schema["parameters"]
|
||||
assert set(schema["parameters"]["required"]) == {"query", "value"}
|
||||
|
||||
|
||||
def test_tool_properties(test_config: ComponentModel) -> None:
|
||||
tool = HttpTool.load_component(test_config)
|
||||
|
||||
assert tool.name == "TestHttpTool"
|
||||
assert tool.description == "A test HTTP tool"
|
||||
assert tool.server_params.host == "localhost"
|
||||
assert tool.server_params.port == 8000
|
||||
assert tool.server_params.path == "/test"
|
||||
assert tool.server_params.scheme == "http"
|
||||
assert tool.server_params.method == "POST"
|
||||
|
||||
|
||||
def test_component_base_class(test_config: ComponentModel) -> None:
|
||||
tool = HttpTool.load_component(test_config)
|
||||
assert tool.dump_component() is not None
|
||||
assert HttpTool.load_component(tool.dump_component(), HttpTool) is not None
|
||||
assert isinstance(tool, Component)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_request(test_config: ComponentModel, test_server: None) -> None:
|
||||
tool = HttpTool.load_component(test_config)
|
||||
result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken())
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert json.loads(result)["result"] == "Received: test query with value 42"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_request_json_return(test_config: ComponentModel, test_server: None) -> None:
|
||||
# Modify config to use json return type
|
||||
config = test_config.model_copy()
|
||||
config.config["return_type"] = "json"
|
||||
tool = HttpTool.load_component(config)
|
||||
result = await tool.run_json({"query": "test query", "value": 45}, CancellationToken())
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["result"] == "Received: test query with value 45"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_request(test_config: ComponentModel, test_server: None) -> None:
|
||||
# Modify config for GET request
|
||||
config = test_config.model_copy()
|
||||
config.config["method"] = "GET"
|
||||
tool = HttpTool.load_component(config)
|
||||
|
||||
result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken())
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert json.loads(result)["result"] == "Received: test query with value 42"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_put_request(test_config: ComponentModel, test_server: None) -> None:
|
||||
# Modify config for PUT request
|
||||
config = test_config.model_copy()
|
||||
config.config["method"] = "PUT"
|
||||
tool = HttpTool.load_component(config)
|
||||
|
||||
result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken())
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert json.loads(result)["result"] == "Received: test query with value 42"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_path_params(test_config: ComponentModel, test_server: None) -> None:
|
||||
# Modify config to use path parameters
|
||||
config = test_config.model_copy()
|
||||
config.config["path"] = "/test/{query}/{value}"
|
||||
tool = HttpTool.load_component(config)
|
||||
|
||||
result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken())
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert json.loads(result)["result"] == "Received: test query with value 42"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_path_params_and_body(test_config: ComponentModel, test_server: None) -> None:
|
||||
# Modify config to use path parameters and include body parameters
|
||||
config = test_config.model_copy()
|
||||
config.config["method"] = "PUT"
|
||||
config.config["path"] = "/test/{query}/{value}"
|
||||
config.config["json_schema"] = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "The test query"},
|
||||
"value": {"type": "integer", "description": "A test value"},
|
||||
"extra": {"type": "string", "description": "Extra body parameter"},
|
||||
},
|
||||
"required": ["query", "value", "extra"],
|
||||
}
|
||||
tool = HttpTool.load_component(config)
|
||||
|
||||
result = await tool.run_json({"query": "test query", "value": 42, "extra": "extra data"}, CancellationToken())
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert json.loads(result)["result"] == "Received: test query with value 42 and extra extra data"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_request(test_config: ComponentModel, test_server: None) -> None:
|
||||
# Modify config for DELETE request
|
||||
config = test_config.model_copy()
|
||||
config.config["method"] = "DELETE"
|
||||
tool = HttpTool.load_component(config)
|
||||
|
||||
result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken())
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert json.loads(result)["result"] == "Received: test query with value 42"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_request(test_config: ComponentModel, test_server: None) -> None:
|
||||
# Modify config for PATCH request
|
||||
config = test_config.model_copy()
|
||||
config.config["method"] = "PATCH"
|
||||
tool = HttpTool.load_component(config)
|
||||
|
||||
result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken())
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert json.loads(result)["result"] == "Received: test query with value 42"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_schema(test_config: ComponentModel, test_server: None) -> None:
|
||||
# Create an invalid schema missing required properties
|
||||
config: ComponentModel = test_config.model_copy()
|
||||
config.config["host"] = True # Incorrect type
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
# Should fail when trying to create model from invalid schema
|
||||
HttpTool.load_component(config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_request(test_config: ComponentModel, test_server: None) -> None:
|
||||
# Use an invalid URL
|
||||
config = test_config.model_copy()
|
||||
config.config["host"] = "fake"
|
||||
tool = HttpTool.load_component(config)
|
||||
|
||||
with pytest.raises(httpx.ConnectError):
|
||||
await tool.run_json({"query": "test query", "value": 42}, CancellationToken())
|
||||
|
||||
|
||||
def test_config_serialization(test_config: ComponentModel) -> None:
|
||||
tool = HttpTool.load_component(test_config)
|
||||
config = tool.dump_component()
|
||||
|
||||
assert config.config["name"] == test_config.config["name"]
|
||||
assert config.config["description"] == test_config.config["description"]
|
||||
assert config.config["host"] == test_config.config["host"]
|
||||
assert config.config["port"] == test_config.config["port"]
|
||||
assert config.config["path"] == test_config.config["path"]
|
||||
assert config.config["scheme"] == test_config.config["scheme"]
|
||||
assert config.config["method"] == test_config.config["method"]
|
||||
assert config.config["headers"] == test_config.config["headers"]
|
||||
|
||||
|
||||
def test_config_deserialization(test_config: ComponentModel) -> None:
|
||||
tool = HttpTool.load_component(test_config)
|
||||
|
||||
assert tool.name == test_config.config["name"]
|
||||
assert tool.description == test_config.config["description"]
|
||||
assert tool.server_params.host == test_config.config["host"]
|
||||
assert tool.server_params.port == test_config.config["port"]
|
||||
assert tool.server_params.path == test_config.config["path"]
|
||||
assert tool.server_params.scheme == test_config.config["scheme"]
|
||||
assert tool.server_params.method == test_config.config["method"]
|
||||
assert tool.server_params.headers == test_config.config["headers"]
|
|
@ -599,6 +599,10 @@ graphrag = [
|
|||
grpc = [
|
||||
{ name = "grpcio" },
|
||||
]
|
||||
http-tool = [
|
||||
{ name = "httpx" },
|
||||
{ name = "json-schema-to-pydantic" },
|
||||
]
|
||||
jupyter-executor = [
|
||||
{ name = "ipykernel" },
|
||||
{ name = "nbclient" },
|
||||
|
@ -698,7 +702,9 @@ requires-dist = [
|
|||
{ name = "ffmpeg-python", marker = "extra == 'video-surfer'" },
|
||||
{ name = "graphrag", marker = "extra == 'graphrag'", specifier = ">=1.0.1" },
|
||||
{ name = "grpcio", marker = "extra == 'grpc'", specifier = "~=1.70.0" },
|
||||
{ name = "httpx", marker = "extra == 'http-tool'", specifier = ">=0.27.0" },
|
||||
{ name = "ipykernel", marker = "extra == 'jupyter-executor'", specifier = ">=6.29.5" },
|
||||
{ name = "json-schema-to-pydantic", marker = "extra == 'http-tool'", specifier = ">=0.2.0" },
|
||||
{ name = "json-schema-to-pydantic", marker = "extra == 'mcp'", specifier = ">=0.2.2" },
|
||||
{ name = "langchain-core", marker = "extra == 'langchain'", specifier = "~=0.3.3" },
|
||||
{ name = "markitdown", marker = "extra == 'file-surfer'", specifier = ">=0.0.1a2" },
|
||||
|
|
Loading…
Reference in New Issue