mirror of https://github.com/microsoft/autogen.git
Add output_format to AssistantAgent for structured output (#6071)
Resolves #5934

This PR adds the ability for `AssistantAgent` to generate a `StructuredMessage[T]`, where `T` is the content type defined as a Pydantic base model.

How to use?

```python
from typing import Literal

from pydantic import BaseModel

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import StructuredMessage
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient


# The response format for the agent as a Pydantic base model.
class AgentResponse(BaseModel):
    thoughts: str
    response: Literal["happy", "sad", "neutral"]


# Create an agent that uses the OpenAI GPT-4o model, which supports structured output.
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent(
    "assistant",
    model_client=model_client,
    system_message="Categorize the input as happy, sad, or neutral following the JSON format.",
    # Setting output_content_type to AgentResponse forces the agent to produce a JSON string as the response.
    output_content_type=AgentResponse,
)

result = await Console(agent.run_stream(task="I am happy."))

# Check the last message in the result, validate its type, and print the thoughts and response.
assert isinstance(result.messages[-1], StructuredMessage)
assert isinstance(result.messages[-1].content, AgentResponse)
print("Thought: ", result.messages[-1].content.thoughts)
print("Response: ", result.messages[-1].content.response)
await model_client.close()
```

```
---------- user ----------
I am happy.
---------- assistant ----------
{
  "thoughts": "The user explicitly states they are happy.",
  "response": "happy"
}
Thought: The user explicitly states they are happy.
Response: happy
```

---------

Co-authored-by: Victor Dibia <victordibia@microsoft.com>
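As a companion to the example above, here is a minimal sketch of the same setup in streaming mode, reusing `AgentResponse`, `model_client`, and the imports from the snippet. This mirrors the PR's streaming test rather than any documented example, so treat it as a sketch under those assumptions:

```python
# Sketch: with model_client_stream=True, run_stream() also yields
# ModelClientStreamingChunkEvent chunks as the model produces output;
# the final message is still a StructuredMessage[AgentResponse].
streaming_agent = AssistantAgent(
    "assistant",
    model_client=model_client,
    model_client_stream=True,
    output_content_type=AgentResponse,
)
result = await Console(streaming_agent.run_stream(task="I am sad."))
assert isinstance(result.messages[-1], StructuredMessage)
```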
This commit is contained in:
parent 9915b65929
commit 86237c9fdf
@@ -45,6 +45,7 @@ from ..messages import (
     HandoffMessage,
     MemoryQueryEvent,
     ModelClientStreamingChunkEvent,
+    StructuredMessage,
     TextMessage,
     ThoughtEvent,
     ToolCallExecutionEvent,
@@ -102,12 +103,25 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):

     .. image:: ../../images/assistant-agent.svg

-    Tool call behavior:
+    **Structured output:**
+
+    If the `output_content_type` is set, the agent will respond with a :class:`~autogen_agentchat.messages.StructuredMessage`
+    instead of a :class:`~autogen_agentchat.messages.TextMessage` in the final response by default.
+
+    .. note::
+
+        Currently, setting `output_content_type` prevents the agent from being
+        able to call `load_component` and `dump_component` methods for serializable
+        configuration. This will be fixed in a future release.
+
+    **Tool call behavior:**

-    * If the model returns no tool call, then the response is immediately returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`.
+    * If the model returns no tool call, then the response is immediately returned as a :class:`~autogen_agentchat.messages.TextMessage` or a :class:`~autogen_agentchat.messages.StructuredMessage` (when using structured output) in :attr:`~autogen_agentchat.base.Response.chat_message`.
     * When the model returns tool calls, they will be executed right away:
-        - When `reflect_on_tool_use` is False (default), the tool call results are returned as a :class:`~autogen_agentchat.messages.ToolCallSummaryMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. `tool_call_summary_format` can be used to customize the tool call summary.
-        - When `reflect_on_tool_use` is True, another model inference is made using the tool calls and results, and the text response is returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`.
+        - When `reflect_on_tool_use` is False, the tool call results are returned as a :class:`~autogen_agentchat.messages.ToolCallSummaryMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. `tool_call_summary_format` can be used to customize the tool call summary.
+        - When `reflect_on_tool_use` is True, another model inference is made using the tool calls and results, and the final response is returned as a :class:`~autogen_agentchat.messages.TextMessage` or a :class:`~autogen_agentchat.messages.StructuredMessage` (when using structured output) in :attr:`~autogen_agentchat.base.Response.chat_message`.
+        - `reflect_on_tool_use` is set to `True` by default when `output_content_type` is set.
+        - `reflect_on_tool_use` is set to `False` by default when `output_content_type` is not set.
     * If the model returns multiple tool calls, they will be executed concurrently. To disable parallel tool calls you need to configure the model client. For example, set `parallel_tool_calls=False` for :class:`~autogen_ext.models.openai.OpenAIChatCompletionClient` and :class:`~autogen_ext.models.openai.AzureOpenAIChatCompletionClient`.

     .. tip::
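A minimal sketch of the defaults described in the docstring hunk above; the model name is illustrative, and `reflect_on_tool_use` is deliberately left unset so the documented default applies:

```python
from pydantic import BaseModel

from autogen_agentchat.agents import AssistantAgent
from autogen_ext.models.openai import OpenAIChatCompletionClient


class AgentResponse(BaseModel):
    answer: str


model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent(
    "assistant",
    model_client=model_client,
    output_content_type=AgentResponse,
    # reflect_on_tool_use is left unset: it defaults to True here because
    # output_content_type is set; without output_content_type it defaults to False.
)
```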
@@ -116,7 +130,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         especially if another agent is expecting them in a specific format.
         Use `tool_call_summary_format` to customize the tool call summary, if needed.

-    Hand off behavior:
+    **Handoff behavior:**

     * If a handoff is triggered, a :class:`~autogen_agentchat.messages.HandoffMessage` will be returned in :attr:`~autogen_agentchat.base.Response.chat_message`.
     * If there are tool calls, they will also be executed right away before returning the handoff.
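For illustration, a sketch of declaring a handoff on the agent; the target name is arbitrary, and the `Handoff` import path is assumed from the test file changed in this PR:

```python
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import Handoff
from autogen_ext.models.openai import OpenAIChatCompletionClient

model_client = OpenAIChatCompletionClient(model="gpt-4o")
# If the model triggers the handoff, the final chat_message is a HandoffMessage
# targeting "agent2"; any tool calls in the same turn are executed first.
agent = AssistantAgent(
    "agent1",
    model_client=model_client,
    handoffs=[Handoff(target="agent2")],
)
```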
@@ -128,16 +142,18 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         To avoid this, disable parallel tool calls in the model client configuration.

-    Limit context size sent to the model:
+    **Limit context size sent to the model:**

     You can limit the number of messages sent to the model by setting
     the `model_context` parameter to a :class:`~autogen_core.model_context.BufferedChatCompletionContext`.
     This will limit the number of recent messages sent to the model and can be useful
     when the model has a limit on the number of tokens it can process.
     Another option is to use a :class:`~autogen_core.model_context.TokenLimitedChatCompletionContext`
     which will limit the number of tokens sent to the model.
     You can also create your own model context by subclassing
     :class:`~autogen_core.model_context.ChatCompletionContext`.

-    Streaming mode:
+    **Streaming mode:**

     The assistant agent can be used in streaming mode by setting `model_client_stream=True`.
     In this mode, the :meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield
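A short sketch of the context-limiting option described above, using `BufferedChatCompletionContext`; the buffer size is arbitrary:

```python
from autogen_agentchat.agents import AssistantAgent
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_ext.models.openai import OpenAIChatCompletionClient

model_client = OpenAIChatCompletionClient(model="gpt-4o")
# Only the 10 most recent messages are sent to the model on each inference.
agent = AssistantAgent(
    "assistant",
    model_client=model_client,
    model_context=BufferedChatCompletionContext(buffer_size=10),
)
```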
@@ -161,8 +177,14 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             :meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`
             messages as the model client produces chunks of response. Defaults to `False`.
         reflect_on_tool_use (bool, optional): If `True`, the agent will make another model inference using the tool call and result
-            to generate a response. If `False`, the tool call result will be returned as the response. Defaults to `False`.
-        tool_call_summary_format (str, optional): The format string used to create a tool call summary for every tool call result.
+            to generate a response. If `False`, the tool call result will be returned as the response. By default, if `output_content_type` is set, this will be `True`;
+            if `output_content_type` is not set, this will be `False`.
+        output_content_type (type[BaseModel] | None, optional): The output content type for :class:`~autogen_agentchat.messages.StructuredMessage` response as a Pydantic model.
+            This will be used with the model client to generate structured output.
+            If this is set, the agent will respond with a :class:`~autogen_agentchat.messages.StructuredMessage` instead of a :class:`~autogen_agentchat.messages.TextMessage`
+            in the final response, unless `reflect_on_tool_use` is `False` and a tool call is made.
+        tool_call_summary_format (str, optional): The format string used to create the content for a :class:`~autogen_agentchat.messages.ToolCallSummaryMessage` response.
+            The format string is used to format the tool call summary for every tool call result.
+            Defaults to "{result}".
+            When `reflect_on_tool_use` is `False`, a concatenation of all the tool call summaries, separated by a newline character ('\\n'),
+            will be returned as the response.
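To illustrate the `tool_call_summary_format` parameter documented above, a sketch with a hypothetical tool; the `{tool_name}` placeholder is assumed to be supported alongside the documented `{result}`:

```python
from autogen_agentchat.agents import AssistantAgent
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def get_weather(city: str) -> str:
    """Hypothetical tool used only for illustration."""
    return f"Sunny in {city}."


model_client = OpenAIChatCompletionClient(model="gpt-4o")
# With reflect_on_tool_use=False, the response is a ToolCallSummaryMessage whose
# content is each tool result rendered with the format string, joined by '\n'.
agent = AssistantAgent(
    "assistant",
    model_client=model_client,
    tools=[get_weather],
    reflect_on_tool_use=False,
    tool_call_summary_format="{tool_name}: {result}",
)
```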
@@ -348,10 +370,9 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
            # which is required for structured output mode.
            tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True)

-           # Create an OpenAIChatCompletionClient instance that uses the structured output format.
+           # Create an OpenAIChatCompletionClient instance that supports structured output.
            model_client = OpenAIChatCompletionClient(
                model="gpt-4o-mini",
-               response_format=AgentResponse,  # type: ignore
            )

            # Create an AssistantAgent instance that uses the tool and model client.
@@ -360,7 +381,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
                model_client=model_client,
                tools=[tool],
                system_message="Use the tool to analyze sentiment.",
-               reflect_on_tool_use=True,  # Use reflection to have the agent generate a formatted response.
+               output_content_type=AgentResponse,
            )
@@ -611,25 +632,17 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             str | None
         ) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
         model_client_stream: bool = False,
-        reflect_on_tool_use: bool = False,
+        reflect_on_tool_use: bool | None = None,
         tool_call_summary_format: str = "{result}",
+        output_content_type: type[BaseModel] | None = None,
         memory: Sequence[Memory] | None = None,
         metadata: Dict[str, str] | None = None,
     ):
         super().__init__(name=name, description=description)
         self._metadata = metadata or {}
-        if reflect_on_tool_use and ModelFamily.is_claude(model_client.model_info["family"]):
-            warnings.warn(
-                "Claude models may not work with reflection on tool use because Claude requires that any requests including a previous tool use or tool result must include the original tools definition. "
-                "Consider setting reflect_on_tool_use to False. "
-                "As an alternative, consider calling the agent in a loop until it stops producing tool calls. "
-                "See [Single-Agent Team](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/teams.html#single-agent-team) "
-                "for more details.",
-                UserWarning,
-                stacklevel=2,
-            )
         self._model_client = model_client
         self._model_client_stream = model_client_stream
+        self._output_content_type: type[BaseModel] | None = output_content_type
         self._memory = None
         if memory is not None:
             if isinstance(memory, list):
@@ -692,17 +705,37 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         else:
             self._model_context = UnboundedChatCompletionContext()

-        self._reflect_on_tool_use = reflect_on_tool_use
+        if self._output_content_type is not None and reflect_on_tool_use is None:
+            # If output_content_type is set, we need to reflect on tool use by default.
+            self._reflect_on_tool_use = True
+        elif reflect_on_tool_use is None:
+            self._reflect_on_tool_use = False
+        else:
+            self._reflect_on_tool_use = reflect_on_tool_use
+        if self._reflect_on_tool_use and ModelFamily.is_claude(model_client.model_info["family"]):
+            warnings.warn(
+                "Claude models may not work with reflection on tool use because Claude requires that any requests including a previous tool use or tool result must include the original tools definition. "
+                "Consider setting reflect_on_tool_use to False. "
+                "As an alternative, consider calling the agent in a loop until it stops producing tool calls. "
+                "See [Single-Agent Team](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/teams.html#single-agent-team) "
+                "for more details.",
+                UserWarning,
+                stacklevel=2,
+            )
         self._tool_call_summary_format = tool_call_summary_format
         self._is_running = False

     @property
     def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
-        message_types: List[type[BaseChatMessage]] = [TextMessage]
+        message_types: List[type[BaseChatMessage]] = []
         if self._handoffs:
             message_types.append(HandoffMessage)
         if self._tools:
             message_types.append(ToolCallSummaryMessage)
+        if self._output_content_type:
+            message_types.append(StructuredMessage[self._output_content_type])  # type: ignore[name-defined]
+        else:
+            message_types.append(TextMessage)
         return tuple(message_types)

     @property
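The new `produced_message_types` behavior can be sketched as follows, mirroring the assertions in the PR's tests; the model name is illustrative:

```python
from pydantic import BaseModel

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import StructuredMessage, TextMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


class AgentResponse(BaseModel):
    answer: str


model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent("assistant", model_client=model_client, output_content_type=AgentResponse)
# TextMessage is replaced by the parameterized StructuredMessage type.
assert StructuredMessage[AgentResponse] in agent.produced_message_types
assert TextMessage not in agent.produced_message_types
```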
@@ -737,6 +770,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         model_client_stream = self._model_client_stream
         reflect_on_tool_use = self._reflect_on_tool_use
         tool_call_summary_format = self._tool_call_summary_format
+        output_content_type = self._output_content_type

         # STEP 1: Add new user/handoff messages to the model context
         await self._add_messages_to_context(
@@ -765,6 +799,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             handoff_tools=handoff_tools,
             agent_name=agent_name,
             cancellation_token=cancellation_token,
+            output_content_type=output_content_type,
         ):
             if isinstance(inference_output, CreateResult):
                 model_result = inference_output
@@ -804,6 +839,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             model_client_stream=model_client_stream,
             reflect_on_tool_use=reflect_on_tool_use,
             tool_call_summary_format=tool_call_summary_format,
+            output_content_type=output_content_type,
         ):
             yield output_event
@@ -853,6 +889,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         handoff_tools: List[BaseTool[Any, Any]],
         agent_name: str,
         cancellation_token: CancellationToken,
+        output_content_type: type[BaseModel] | None,
     ) -> AsyncGenerator[Union[CreateResult, ModelClientStreamingChunkEvent], None]:
         """
         Perform a model inference and yield either streaming chunk events or the final CreateResult.
@@ -865,7 +902,10 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         if model_client_stream:
             model_result: Optional[CreateResult] = None
             async for chunk in model_client.create_stream(
-                llm_messages, tools=all_tools, cancellation_token=cancellation_token
+                llm_messages,
+                tools=all_tools,
+                json_output=output_content_type,
+                cancellation_token=cancellation_token,
             ):
                 if isinstance(chunk, CreateResult):
                     model_result = chunk
@@ -878,7 +918,10 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             yield model_result
         else:
             model_result = await model_client.create(
-                llm_messages, tools=all_tools, cancellation_token=cancellation_token
+                llm_messages,
+                tools=all_tools,
+                cancellation_token=cancellation_token,
+                json_output=output_content_type,
             )
             yield model_result
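Both branches above now forward `output_content_type` to the client as `json_output`. A sketch of the equivalent direct client call, assuming `json_output` accepts a Pydantic model type as it does elsewhere in this PR:

```python
import asyncio

from pydantic import BaseModel

from autogen_core.models import UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


class AgentResponse(BaseModel):
    thoughts: str
    response: str


async def main() -> None:
    model_client = OpenAIChatCompletionClient(model="gpt-4o")
    # json_output=<Pydantic model> asks the client for structured output;
    # result.content is then a JSON string conforming to the model's schema.
    result = await model_client.create(
        [UserMessage(content="I am happy.", source="user")],
        json_output=AgentResponse,
    )
    assert isinstance(result.content, str)
    print(AgentResponse.model_validate_json(result.content))


asyncio.run(main())
```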
@@ -898,6 +941,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         model_client_stream: bool,
         reflect_on_tool_use: bool,
         tool_call_summary_format: str,
+        output_content_type: type[BaseModel] | None,
     ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
         """
         Handle final or partial responses from model_result, including tool calls, handoffs,
@@ -906,14 +950,25 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):

         # If direct text response (string)
         if isinstance(model_result.content, str):
-            yield Response(
-                chat_message=TextMessage(
-                    content=model_result.content,
-                    source=agent_name,
-                    models_usage=model_result.usage,
-                ),
-                inner_messages=inner_messages,
-            )
+            if output_content_type:
+                content = output_content_type.model_validate_json(model_result.content)
+                yield Response(
+                    chat_message=StructuredMessage[output_content_type](  # type: ignore[valid-type]
+                        content=content,
+                        source=agent_name,
+                        models_usage=model_result.usage,
+                    ),
+                    inner_messages=inner_messages,
+                )
+            else:
+                yield Response(
+                    chat_message=TextMessage(
+                        content=model_result.content,
+                        source=agent_name,
+                        models_usage=model_result.usage,
+                    ),
+                    inner_messages=inner_messages,
+                )
             return

         # Otherwise, we have function calls
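The parsing step in the branch above, sketched in isolation with an illustrative model and JSON payload:

```python
from pydantic import BaseModel

from autogen_agentchat.messages import StructuredMessage


class AgentResponse(BaseModel):
    response: str
    status: str


# The JSON string returned by the model is validated into the Pydantic type,
# then wrapped in a StructuredMessage parameterized by that type.
content = AgentResponse.model_validate_json('{"response": "Hello", "status": "success"}')
message = StructuredMessage[AgentResponse](content=content, source="assistant")
print(message.content.status)  # "success"
```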
@@ -977,6 +1032,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
                 model_context=model_context,
                 agent_name=agent_name,
                 inner_messages=inner_messages,
+                output_content_type=output_content_type,
             ):
                 yield reflection_response
         else:
@@ -1062,6 +1118,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         model_context: ChatCompletionContext,
         agent_name: str,
         inner_messages: List[BaseAgentEvent | BaseChatMessage],
+        output_content_type: type[BaseModel] | None,
     ) -> AsyncGenerator[Response | ModelClientStreamingChunkEvent | ThoughtEvent, None]:
         """
         If reflect_on_tool_use=True, we do another inference based on tool results
@@ -1073,7 +1130,10 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         reflection_result: Optional[CreateResult] = None

         if model_client_stream:
-            async for chunk in model_client.create_stream(llm_messages):
+            async for chunk in model_client.create_stream(
+                llm_messages,
+                json_output=output_content_type,
+            ):
                 if isinstance(chunk, CreateResult):
                     reflection_result = chunk
                 elif isinstance(chunk, str):
@@ -1081,7 +1141,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             else:
                 raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
         else:
-            reflection_result = await model_client.create(llm_messages)
+            reflection_result = await model_client.create(llm_messages, json_output=output_content_type)

         if not reflection_result or not isinstance(reflection_result.content, str):
             raise RuntimeError("Reflect on tool use produced no valid text response.")
@@ -1101,14 +1161,25 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             )
         )

-        yield Response(
-            chat_message=TextMessage(
-                content=reflection_result.content,
-                source=agent_name,
-                models_usage=reflection_result.usage,
-            ),
-            inner_messages=inner_messages,
-        )
+        if output_content_type:
+            content = output_content_type.model_validate_json(reflection_result.content)
+            yield Response(
+                chat_message=StructuredMessage[output_content_type](  # type: ignore[valid-type]
+                    content=content,
+                    source=agent_name,
+                    models_usage=reflection_result.usage,
+                ),
+                inner_messages=inner_messages,
+            )
+        else:
+            yield Response(
+                chat_message=TextMessage(
+                    content=reflection_result.content,
+                    source=agent_name,
+                    models_usage=reflection_result.usage,
+                ),
+                inner_messages=inner_messages,
+            )

     @staticmethod
     def _summarize_tool_use(
@@ -1206,6 +1277,9 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
     def _to_config(self) -> AssistantAgentConfig:
         """Convert the assistant agent to a declarative config."""

+        if self._output_content_type:
+            raise ValueError("AssistantAgent with output_content_type does not support declarative config.")
+
         return AssistantAgentConfig(
             name=self.name,
             model_client=self._model_client.dump_component(),
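A sketch of the guard above from the caller's perspective; the model name is illustrative, and the error text is taken from the diff:

```python
from pydantic import BaseModel

from autogen_agentchat.agents import AssistantAgent
from autogen_ext.models.openai import OpenAIChatCompletionClient


class AgentResponse(BaseModel):
    answer: str


model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent("assistant", model_client=model_client, output_content_type=AgentResponse)
try:
    agent.dump_component()  # raises: declarative config unsupported with output_content_type
except ValueError as e:
    print(e)
```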
@@ -1226,6 +1300,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
     @classmethod
     def _from_config(cls, config: AssistantAgentConfig) -> Self:
         """Create an assistant agent from a declarative config."""
+
         return cls(
             name=config.name,
             model_client=ChatCompletionClient.load_component(config.model_client),
@@ -401,6 +401,147 @@ async def test_run_with_parallel_tools_with_empty_call_ids() -> None:
     assert state == state2


+@pytest.mark.asyncio
+async def test_output_format() -> None:
+    class AgentResponse(BaseModel):
+        response: str
+        status: str
+
+    model_client = ReplayChatCompletionClient(
+        [
+            CreateResult(
+                finish_reason="stop",
+                content=AgentResponse(response="Hello", status="success").model_dump_json(),
+                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
+                cached=False,
+            ),
+        ]
+    )
+    agent = AssistantAgent(
+        "test_agent",
+        model_client=model_client,
+        output_content_type=AgentResponse,
+    )
+    assert StructuredMessage[AgentResponse] in agent.produced_message_types
+    assert TextMessage not in agent.produced_message_types
+
+    result = await agent.run()
+    assert len(result.messages) == 1
+    assert isinstance(result.messages[0], StructuredMessage)
+    assert isinstance(result.messages[0].content, AgentResponse)  # type: ignore[reportUnknownMemberType]
+    assert result.messages[0].content.response == "Hello"
+    assert result.messages[0].content.status == "success"
+
+    # Test streaming.
+    agent = AssistantAgent(
+        "test_agent",
+        model_client=model_client,
+        model_client_stream=True,
+        output_content_type=AgentResponse,
+    )
+    model_client.reset()
+    stream = agent.run_stream()
+    stream_result: TaskResult | None = None
+    async for message in stream:
+        if isinstance(message, TaskResult):
+            stream_result = message
+    assert stream_result is not None
+    assert len(stream_result.messages) == 1
+    assert isinstance(stream_result.messages[0], StructuredMessage)
+    assert isinstance(stream_result.messages[0].content, AgentResponse)  # type: ignore[reportUnknownMemberType]
+    assert stream_result.messages[0].content.response == "Hello"
+    assert stream_result.messages[0].content.status == "success"
+
+
+@pytest.mark.asyncio
+async def test_reflection_output_format() -> None:
+    class AgentResponse(BaseModel):
+        response: str
+        status: str
+
+    model_client = ReplayChatCompletionClient(
+        [
+            CreateResult(
+                finish_reason="function_calls",
+                content=[FunctionCall(id="1", arguments=json.dumps({"input": "task"}), name="_pass_function")],
+                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
+                cached=False,
+            ),
+            AgentResponse(response="Hello", status="success").model_dump_json(),
+        ],
+        model_info={
+            "function_calling": True,
+            "vision": True,
+            "json_output": True,
+            "family": ModelFamily.GPT_4O,
+            "structured_output": True,
+        },
+    )
+    agent = AssistantAgent(
+        "test_agent",
+        model_client=model_client,
+        output_content_type=AgentResponse,
+        # reflect_on_tool_use defaults to True because output_content_type is set.
+        tools=[
+            _pass_function,
+            _fail_function,
+        ],
+    )
+    result = await agent.run()
+    assert len(result.messages) == 3
+    assert isinstance(result.messages[0], ToolCallRequestEvent)
+    assert isinstance(result.messages[1], ToolCallExecutionEvent)
+    assert isinstance(result.messages[2], StructuredMessage)
+    assert isinstance(result.messages[2].content, AgentResponse)  # type: ignore[reportUnknownMemberType]
+    assert result.messages[2].content.response == "Hello"
+    assert result.messages[2].content.status == "success"
+
+    # Test streaming.
+    agent = AssistantAgent(
+        "test_agent",
+        model_client=model_client,
+        model_client_stream=True,
+        output_content_type=AgentResponse,
+        # reflect_on_tool_use defaults to True because output_content_type is set.
+        tools=[
+            _pass_function,
+            _fail_function,
+        ],
+    )
+    model_client.reset()
+    stream = agent.run_stream()
+    stream_result: TaskResult | None = None
+    async for message in stream:
+        if isinstance(message, TaskResult):
+            stream_result = message
+    assert stream_result is not None
+    assert len(stream_result.messages) == 3
+    assert isinstance(stream_result.messages[0], ToolCallRequestEvent)
+    assert isinstance(stream_result.messages[1], ToolCallExecutionEvent)
+    assert isinstance(stream_result.messages[2], StructuredMessage)
+    assert isinstance(stream_result.messages[2].content, AgentResponse)  # type: ignore[reportUnknownMemberType]
+    assert stream_result.messages[2].content.response == "Hello"
+    assert stream_result.messages[2].content.status == "success"
+
+    # Test when reflect_on_tool_use is False.
+    model_client.reset()
+    agent = AssistantAgent(
+        "test_agent",
+        model_client=model_client,
+        output_content_type=AgentResponse,
+        reflect_on_tool_use=False,
+        tools=[
+            _pass_function,
+            _fail_function,
+        ],
+    )
+    result = await agent.run()
+    assert len(result.messages) == 3
+    assert isinstance(result.messages[0], ToolCallRequestEvent)
+    assert isinstance(result.messages[1], ToolCallExecutionEvent)
+    assert isinstance(result.messages[2], ToolCallSummaryMessage)
+
+
 @pytest.mark.asyncio
 async def test_handoffs() -> None:
     handoff = Handoff(target="agent2")
File diff suppressed because one or more lines are too long