Add output_format to AssistantAgent for structured output (#6071)

Resolves #5934

This PR adds the ability for `AssistantAgent` to generate a
`StructuredMessage[T]`, where `T` is a Pydantic `BaseModel` used as the message content type.
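
For context, `StructuredMessage[T]` is a generic chat message whose `content` field holds a validated instance of `T` rather than a plain string. A minimal sketch (the `AgentResponse` fields here simply mirror the example below):

```python
from pydantic import BaseModel

from autogen_agentchat.messages import StructuredMessage


class AgentResponse(BaseModel):
    thoughts: str
    response: str


# The content is a validated Pydantic instance, not a raw JSON string.
message = StructuredMessage[AgentResponse](
    content=AgentResponse(thoughts="The user sounds cheerful.", response="happy"),
    source="assistant",
)
print(message.content.response)  # happy
```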

How to use?

```python
from typing import Literal

from pydantic import BaseModel

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import StructuredMessage
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient

# The response format for the agent as a Pydantic base model.
class AgentResponse(BaseModel):
    thoughts: str
    response: Literal["happy", "sad", "neutral"]


# Create an agent that uses the OpenAI GPT-4o model which supports structured output.
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent(
    "assistant",
    model_client=model_client,
    system_message="Categorize the input as happy, sad, or neutral following the JSON format.",
    # Setting output_content_type to AgentResponse makes the agent respond with a StructuredMessage[AgentResponse].
    output_content_type=AgentResponse,
)

result = await Console(agent.run_stream(task="I am happy."))

# Check the last message in the result, validate its type, and print the thoughts and response.
assert isinstance(result.messages[-1], StructuredMessage)
assert isinstance(result.messages[-1].content, AgentResponse)
print("Thought: ", result.messages[-1].content.thoughts)
print("Response: ", result.messages[-1].content.response)
await model_client.close()
```

```
---------- user ----------
I am happy.
---------- assistant ----------
{
  "thoughts": "The user explicitly states they are happy.",
  "response": "happy"
}
Thought:  The user explicitly states they are happy.
Response:  happy
```
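
The new option also interacts with tool use: when `output_content_type` is set, `reflect_on_tool_use` now defaults to `True`, so after executing tool calls the agent makes one more model inference and returns the structured result. A sketch in the spirit of the updated docstring example; the `sentiment_analysis` tool body here is illustrative:

```python
from typing import Literal

from pydantic import BaseModel

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import StructuredMessage
from autogen_agentchat.ui import Console
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient


class AgentResponse(BaseModel):
    thoughts: str
    response: Literal["happy", "sad", "neutral"]


# Illustrative tool; strict=True is required when the model client uses structured output.
async def sentiment_analysis(text: str) -> str:
    """Given a text, return its sentiment."""
    return "happy" if "happy" in text else "sad" if "sad" in text else "neutral"


tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True)

model_client = OpenAIChatCompletionClient(model="gpt-4o-mini")
agent = AssistantAgent(
    "assistant",
    model_client=model_client,
    tools=[tool],
    system_message="Use the tool to analyze sentiment.",
    # reflect_on_tool_use is left unset: it defaults to True because
    # output_content_type is set, so the final response after the tool call
    # is a StructuredMessage[AgentResponse].
    output_content_type=AgentResponse,
)

result = await Console(agent.run_stream(task="I am happy today!"))
assert isinstance(result.messages[-1], StructuredMessage)
await model_client.close()
```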

---------

Co-authored-by: Victor Dibia <victordibia@microsoft.com>
Eric Zhu 2025-04-01 13:11:01 -07:00 committed by GitHub
parent 9915b65929
commit 86237c9fdf
3 changed files with 1120 additions and 890 deletions


@ -45,6 +45,7 @@ from ..messages import (
HandoffMessage,
MemoryQueryEvent,
ModelClientStreamingChunkEvent,
StructuredMessage,
TextMessage,
ThoughtEvent,
ToolCallExecutionEvent,
@ -102,12 +103,25 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
.. image:: ../../images/assistant-agent.svg
Tool call behavior:
**Structured output:**
* If the model returns no tool call, then the response is immediately returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`.
If the `output_content_type` is set, the agent will respond with a :class:`~autogen_agentchat.messages.StructuredMessage`
instead of a :class:`~autogen_agentchat.messages.TextMessage` in the final response by default.
.. note::
Currently, setting `output_content_type` prevents the agent from being
able to use the `load_component` and `dump_component` methods for serializable
configuration. This will be fixed in a future release.
**Tool call behavior:**
* If the model returns no tool call, then the response is immediately returned as a :class:`~autogen_agentchat.messages.TextMessage` or a :class:`~autogen_agentchat.messages.StructuredMessage` (when using structured output) in :attr:`~autogen_agentchat.base.Response.chat_message`.
* When the model returns tool calls, they will be executed right away:
- When `reflect_on_tool_use` is False (default), the tool call results are returned as a :class:`~autogen_agentchat.messages.ToolCallSummaryMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. `tool_call_summary_format` can be used to customize the tool call summary.
- When `reflect_on_tool_use` is True, the another model inference is made using the tool calls and results, and the text response is returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`.
- When `reflect_on_tool_use` is False, the tool call results are returned as a :class:`~autogen_agentchat.messages.ToolCallSummaryMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. `tool_call_summary_format` can be used to customize the tool call summary.
- When `reflect_on_tool_use` is True, another model inference is made using the tool calls and results, and the final response is returned as a :class:`~autogen_agentchat.messages.TextMessage` or a :class:`~autogen_agentchat.messages.StructuredMessage` (when using structured output) in :attr:`~autogen_agentchat.base.Response.chat_message`.
- `reflect_on_tool_use` is set to `True` by default when `output_content_type` is set.
- `reflect_on_tool_use` is set to `False` by default when `output_content_type` is not set.
* If the model returns multiple tool calls, they will be executed concurrently. To disable parallel tool calls you need to configure the model client. For example, set `parallel_tool_calls=False` for :class:`~autogen_ext.models.openai.OpenAIChatCompletionClient` and :class:`~autogen_ext.models.openai.AzureOpenAIChatCompletionClient`.
.. tip::
@ -116,7 +130,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
especially if another agent is expecting them in a specific format.
Use `tool_call_summary_format` to customize the tool call summary, if needed.
Hand off behavior:
**Hand off behavior:**
* If a handoff is triggered, a :class:`~autogen_agentchat.messages.HandoffMessage` will be returned in :attr:`~autogen_agentchat.base.Response.chat_message`.
* If there are tool calls, they will also be executed right away before returning the handoff.
@ -128,16 +142,18 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
To avoid this, disable parallel tool calls in the model client configuration.
Limit context size sent to the model:
**Limit context size sent to the model:**
You can limit the number of messages sent to the model by setting
the `model_context` parameter to a :class:`~autogen_core.model_context.BufferedChatCompletionContext`.
This will limit the number of recent messages sent to the model and can be useful
when the model has a limit on the number of tokens it can process.
Another option is to use a :class:`~autogen_core.model_context.TokenLimitedChatCompletionContext`
which will limit the number of tokens sent to the model.
You can also create your own model context by subclassing
:class:`~autogen_core.model_context.ChatCompletionContext`.
Streaming mode:
**Streaming mode:**
The assistant agent can be used in streaming mode by setting `model_client_stream=True`.
In this mode, the :meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield
@ -161,8 +177,14 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
:meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`
messages as the model client produces chunks of response. Defaults to `False`.
reflect_on_tool_use (bool, optional): If `True`, the agent will make another model inference using the tool call and result
to generate a response. If `False`, the tool call result will be returned as the response. Defaults to `False`.
tool_call_summary_format (str, optional): The format string used to create a tool call summary for every tool call result.
to generate a response. If `False`, the tool call result will be returned as the response. By default, if `output_content_type` is set, this will be `True`;
if `output_content_type` is not set, this will be `False`.
output_content_type (type[BaseModel] | None, optional): The output content type for :class:`~autogen_agentchat.messages.StructuredMessage` response as a Pydantic model.
This will be used with the model client to generate structured output.
If this is set, the agent will respond with a :class:`~autogen_agentchat.messages.StructuredMessage` instead of a :class:`~autogen_agentchat.messages.TextMessage`
in the final response, unless `reflect_on_tool_use` is `False` and a tool call is made.
tool_call_summary_format (str, optional): The format string used to create the content for a :class:`~autogen_agentchat.messages.ToolCallSummaryMessage` response.
The format string is used to format the tool call summary for every tool call result.
Defaults to "{result}".
When `reflect_on_tool_use` is `False`, a concatenation of all the tool call summaries, separated by a new line character ('\\n')
will be returned as the response.
@ -348,10 +370,9 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
# which is required for structured output mode.
tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True)
# Create an OpenAIChatCompletionClient instance that uses the structured output format.
# Create an OpenAIChatCompletionClient instance that supports structured output.
model_client = OpenAIChatCompletionClient(
model="gpt-4o-mini",
response_format=AgentResponse, # type: ignore
)
# Create an AssistantAgent instance that uses the tool and model client.
@ -360,7 +381,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_client=model_client,
tools=[tool],
system_message="Use the tool to analyze sentiment.",
reflect_on_tool_use=True, # Use reflection to have the agent generate a formatted response.
output_content_type=AgentResponse,
)
@ -611,25 +632,17 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
str | None
) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
model_client_stream: bool = False,
reflect_on_tool_use: bool = False,
reflect_on_tool_use: bool | None = None,
tool_call_summary_format: str = "{result}",
output_content_type: type[BaseModel] | None = None,
memory: Sequence[Memory] | None = None,
metadata: Dict[str, str] | None = None,
):
super().__init__(name=name, description=description)
self._metadata = metadata or {}
if reflect_on_tool_use and ModelFamily.is_claude(model_client.model_info["family"]):
warnings.warn(
"Claude models may not work with reflection on tool use because Claude requires that any requests including a previous tool use or tool result must include the original tools definition."
"Consider setting reflect_on_tool_use to False. "
"As an alternative, consider calling the agent in a loop until it stops producing tool calls. "
"See [Single-Agent Team](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/teams.html#single-agent-team) "
"for more details.",
UserWarning,
stacklevel=2,
)
self._model_client = model_client
self._model_client_stream = model_client_stream
self._output_content_type: type[BaseModel] | None = output_content_type
self._memory = None
if memory is not None:
if isinstance(memory, list):
@ -692,17 +705,37 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
else:
self._model_context = UnboundedChatCompletionContext()
self._reflect_on_tool_use = reflect_on_tool_use
if self._output_content_type is not None and reflect_on_tool_use is None:
# If output_content_type is set, we need to reflect on tool use by default.
self._reflect_on_tool_use = True
elif reflect_on_tool_use is None:
self._reflect_on_tool_use = False
else:
self._reflect_on_tool_use = reflect_on_tool_use
if self._reflect_on_tool_use and ModelFamily.is_claude(model_client.model_info["family"]):
warnings.warn(
"Claude models may not work with reflection on tool use because Claude requires that any requests including a previous tool use or tool result must include the original tools definition."
"Consider setting reflect_on_tool_use to False. "
"As an alternative, consider calling the agent in a loop until it stops producing tool calls. "
"See [Single-Agent Team](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/teams.html#single-agent-team) "
"for more details.",
UserWarning,
stacklevel=2,
)
self._tool_call_summary_format = tool_call_summary_format
self._is_running = False
@property
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
message_types: List[type[BaseChatMessage]] = [TextMessage]
message_types: List[type[BaseChatMessage]] = []
if self._handoffs:
message_types.append(HandoffMessage)
if self._tools:
message_types.append(ToolCallSummaryMessage)
if self._output_content_type:
message_types.append(StructuredMessage[self._output_content_type]) # type: ignore[name-defined]
else:
message_types.append(TextMessage)
return tuple(message_types)
@property
@ -737,6 +770,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_client_stream = self._model_client_stream
reflect_on_tool_use = self._reflect_on_tool_use
tool_call_summary_format = self._tool_call_summary_format
output_content_type = self._output_content_type
# STEP 1: Add new user/handoff messages to the model context
await self._add_messages_to_context(
@ -765,6 +799,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
handoff_tools=handoff_tools,
agent_name=agent_name,
cancellation_token=cancellation_token,
output_content_type=output_content_type,
):
if isinstance(inference_output, CreateResult):
model_result = inference_output
@ -804,6 +839,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_client_stream=model_client_stream,
reflect_on_tool_use=reflect_on_tool_use,
tool_call_summary_format=tool_call_summary_format,
output_content_type=output_content_type,
):
yield output_event
@ -853,6 +889,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
handoff_tools: List[BaseTool[Any, Any]],
agent_name: str,
cancellation_token: CancellationToken,
output_content_type: type[BaseModel] | None,
) -> AsyncGenerator[Union[CreateResult, ModelClientStreamingChunkEvent], None]:
"""
Perform a model inference and yield either streaming chunk events or the final CreateResult.
@ -865,7 +902,10 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
if model_client_stream:
model_result: Optional[CreateResult] = None
async for chunk in model_client.create_stream(
llm_messages, tools=all_tools, cancellation_token=cancellation_token
llm_messages,
tools=all_tools,
json_output=output_content_type,
cancellation_token=cancellation_token,
):
if isinstance(chunk, CreateResult):
model_result = chunk
@ -878,7 +918,10 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
yield model_result
else:
model_result = await model_client.create(
llm_messages, tools=all_tools, cancellation_token=cancellation_token
llm_messages,
tools=all_tools,
cancellation_token=cancellation_token,
json_output=output_content_type,
)
yield model_result
@ -898,6 +941,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_client_stream: bool,
reflect_on_tool_use: bool,
tool_call_summary_format: str,
output_content_type: type[BaseModel] | None,
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
"""
Handle final or partial responses from model_result, including tool calls, handoffs,
@ -906,14 +950,25 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
# If direct text response (string)
if isinstance(model_result.content, str):
yield Response(
chat_message=TextMessage(
content=model_result.content,
source=agent_name,
models_usage=model_result.usage,
),
inner_messages=inner_messages,
)
if output_content_type:
content = output_content_type.model_validate_json(model_result.content)
yield Response(
chat_message=StructuredMessage[output_content_type]( # type: ignore[valid-type]
content=content,
source=agent_name,
models_usage=model_result.usage,
),
inner_messages=inner_messages,
)
else:
yield Response(
chat_message=TextMessage(
content=model_result.content,
source=agent_name,
models_usage=model_result.usage,
),
inner_messages=inner_messages,
)
return
# Otherwise, we have function calls
@ -977,6 +1032,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_context=model_context,
agent_name=agent_name,
inner_messages=inner_messages,
output_content_type=output_content_type,
):
yield reflection_response
else:
@ -1062,6 +1118,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_context: ChatCompletionContext,
agent_name: str,
inner_messages: List[BaseAgentEvent | BaseChatMessage],
output_content_type: type[BaseModel] | None,
) -> AsyncGenerator[Response | ModelClientStreamingChunkEvent | ThoughtEvent, None]:
"""
If reflect_on_tool_use=True, we do another inference based on tool results
@ -1073,7 +1130,10 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
reflection_result: Optional[CreateResult] = None
if model_client_stream:
async for chunk in model_client.create_stream(llm_messages):
async for chunk in model_client.create_stream(
llm_messages,
json_output=output_content_type,
):
if isinstance(chunk, CreateResult):
reflection_result = chunk
elif isinstance(chunk, str):
@ -1081,7 +1141,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
else:
raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
else:
reflection_result = await model_client.create(llm_messages)
reflection_result = await model_client.create(llm_messages, json_output=output_content_type)
if not reflection_result or not isinstance(reflection_result.content, str):
raise RuntimeError("Reflect on tool use produced no valid text response.")
@ -1101,14 +1161,25 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
)
)
yield Response(
chat_message=TextMessage(
content=reflection_result.content,
source=agent_name,
models_usage=reflection_result.usage,
),
inner_messages=inner_messages,
)
if output_content_type:
content = output_content_type.model_validate_json(reflection_result.content)
yield Response(
chat_message=StructuredMessage[output_content_type]( # type: ignore[valid-type]
content=content,
source=agent_name,
models_usage=reflection_result.usage,
),
inner_messages=inner_messages,
)
else:
yield Response(
chat_message=TextMessage(
content=reflection_result.content,
source=agent_name,
models_usage=reflection_result.usage,
),
inner_messages=inner_messages,
)
@staticmethod
def _summarize_tool_use(
@ -1206,6 +1277,9 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
def _to_config(self) -> AssistantAgentConfig:
"""Convert the assistant agent to a declarative config."""
if self._output_content_type:
raise ValueError("AssistantAgent with output_content_type does not support declarative config.")
return AssistantAgentConfig(
name=self.name,
model_client=self._model_client.dump_component(),
@ -1226,6 +1300,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
@classmethod
def _from_config(cls, config: AssistantAgentConfig) -> Self:
"""Create an assistant agent from a declarative config."""
return cls(
name=config.name,
model_client=ChatCompletionClient.load_component(config.model_client),


@ -401,6 +401,147 @@ async def test_run_with_parallel_tools_with_empty_call_ids() -> None:
assert state == state2
@pytest.mark.asyncio
async def test_output_format() -> None:
class AgentResponse(BaseModel):
response: str
status: str
model_client = ReplayChatCompletionClient(
[
CreateResult(
finish_reason="stop",
content=AgentResponse(response="Hello", status="success").model_dump_json(),
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
cached=False,
),
]
)
agent = AssistantAgent(
"test_agent",
model_client=model_client,
output_content_type=AgentResponse,
)
assert StructuredMessage[AgentResponse] in agent.produced_message_types
assert TextMessage not in agent.produced_message_types
result = await agent.run()
assert len(result.messages) == 1
assert isinstance(result.messages[0], StructuredMessage)
assert isinstance(result.messages[0].content, AgentResponse) # type: ignore[reportUnknownMemberType]
assert result.messages[0].content.response == "Hello"
assert result.messages[0].content.status == "success"
# Test streaming.
agent = AssistantAgent(
"test_agent",
model_client=model_client,
model_client_stream=True,
output_content_type=AgentResponse,
)
model_client.reset()
stream = agent.run_stream()
stream_result: TaskResult | None = None
async for message in stream:
if isinstance(message, TaskResult):
stream_result = message
assert stream_result is not None
assert len(stream_result.messages) == 1
assert isinstance(stream_result.messages[0], StructuredMessage)
assert isinstance(stream_result.messages[0].content, AgentResponse) # type: ignore[reportUnknownMemberType]
assert stream_result.messages[0].content.response == "Hello"
assert stream_result.messages[0].content.status == "success"
@pytest.mark.asyncio
async def test_reflection_output_format() -> None:
class AgentResponse(BaseModel):
response: str
status: str
model_client = ReplayChatCompletionClient(
[
CreateResult(
finish_reason="function_calls",
content=[FunctionCall(id="1", arguments=json.dumps({"input": "task"}), name="_pass_function")],
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
cached=False,
),
AgentResponse(response="Hello", status="success").model_dump_json(),
],
model_info={
"function_calling": True,
"vision": True,
"json_output": True,
"family": ModelFamily.GPT_4O,
"structured_output": True,
},
)
agent = AssistantAgent(
"test_agent",
model_client=model_client,
output_content_type=AgentResponse,
# reflect_on_tool_use=True,
tools=[
_pass_function,
_fail_function,
],
)
result = await agent.run()
assert len(result.messages) == 3
assert isinstance(result.messages[0], ToolCallRequestEvent)
assert isinstance(result.messages[1], ToolCallExecutionEvent)
assert isinstance(result.messages[2], StructuredMessage)
assert isinstance(result.messages[2].content, AgentResponse) # type: ignore[reportUnknownMemberType]
assert result.messages[2].content.response == "Hello"
assert result.messages[2].content.status == "success"
# Test streaming.
agent = AssistantAgent(
"test_agent",
model_client=model_client,
model_client_stream=True,
output_content_type=AgentResponse,
# reflect_on_tool_use=True,
tools=[
_pass_function,
_fail_function,
],
)
model_client.reset()
stream = agent.run_stream()
stream_result: TaskResult | None = None
async for message in stream:
if isinstance(message, TaskResult):
stream_result = message
assert stream_result is not None
assert len(stream_result.messages) == 3
assert isinstance(stream_result.messages[0], ToolCallRequestEvent)
assert isinstance(stream_result.messages[1], ToolCallExecutionEvent)
assert isinstance(stream_result.messages[2], StructuredMessage)
assert isinstance(stream_result.messages[2].content, AgentResponse) # type: ignore[reportUnknownMemberType]
assert stream_result.messages[2].content.response == "Hello"
assert stream_result.messages[2].content.status == "success"
# Test when reflect_on_tool_use is False
model_client.reset()
agent = AssistantAgent(
"test_agent",
model_client=model_client,
output_content_type=AgentResponse,
reflect_on_tool_use=False,
tools=[
_pass_function,
_fail_function,
],
)
result = await agent.run()
assert len(result.messages) == 3
assert isinstance(result.messages[0], ToolCallRequestEvent)
assert isinstance(result.messages[1], ToolCallExecutionEvent)
assert isinstance(result.messages[2], ToolCallSummaryMessage)
@pytest.mark.asyncio
async def test_handoffs() -> None:
handoff = Handoff(target="agent2")

File diff suppressed because one or more lines are too long