feat: introduce ModelClientStreamingChunkEvent for streaming model output and update handling in agents and console (#5208)

Resolves #3983

* Introduce a `model_client_stream` parameter in `AssistantAgent` to
enable token-level streaming output.
* Introduce `ModelClientStreamingChunkEvent` as a type of `AgentEvent`
to pass the streaming chunks to the application via `run_stream` and
`on_messages_stream`. These chunk events are not included in the inner
messages list of the final `Response` or `TaskResult`.
* Handle the new message type in `Console`.
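Taken together, these changes let an application render partial model output as it is produced. Below is a minimal usage sketch based on the documentation added in this commit; it assumes an OpenAI API key is available in the environment, and `gpt-4o` is only an example model:

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    model_client = OpenAIChatCompletionClient(model="gpt-4o")
    agent = AssistantAgent(
        name="assistant",
        model_client=model_client,
        system_message="You are a helpful assistant.",
        model_client_stream=True,  # Yield ModelClientStreamingChunkEvent as tokens arrive.
    )
    # Console prints each chunk as it arrives, then the final TaskResult.
    await Console(agent.run_stream(task="Name two cities in South America."))


asyncio.run(main())
```

With `model_client_stream=False` (the default) the agent behaves as before and only yields complete messages.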
Eric Zhu 2025-01-28 18:49:02 -08:00 committed by GitHub
parent 8a0daf8285
commit 225eb9d0b2
13 changed files with 330 additions and 32 deletions

View File

@@ -22,13 +22,14 @@ from autogen_core.model_context import (
 from autogen_core.models import (
     AssistantMessage,
     ChatCompletionClient,
+    CreateResult,
     FunctionExecutionResult,
     FunctionExecutionResultMessage,
     LLMMessage,
     SystemMessage,
     UserMessage,
 )
-from autogen_core.tools import FunctionTool, BaseTool
+from autogen_core.tools import BaseTool, FunctionTool
 from pydantic import BaseModel
 from typing_extensions import Self
@@ -40,6 +41,7 @@ from ..messages import (
     ChatMessage,
     HandoffMessage,
     MemoryQueryEvent,
+    ModelClientStreamingChunkEvent,
     MultiModalMessage,
     TextMessage,
     ToolCallExecutionEvent,
@@ -62,6 +64,7 @@ class AssistantAgentConfig(BaseModel):
     model_context: ComponentModel | None = None
     description: str
     system_message: str | None = None
+    model_client_stream: bool
     reflect_on_tool_use: bool
     tool_call_summary_format: str
@@ -126,6 +129,14 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         This will limit the number of recent messages sent to the model and can be useful
         when the model has a limit on the number of tokens it can process.
 
+    Streaming mode:
+
+        The assistant agent can be used in streaming mode by setting `model_client_stream=True`.
+        In this mode, the :meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield
+        :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`
+        messages as the model client produces chunks of response.
+        The chunk messages will not be included in the final response's inner messages.
+
     Args:
         name (str): The name of the agent.
@@ -138,6 +149,9 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving :class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages. The initial messages will be cleared when the agent is reset.
         description (str, optional): The description of the agent.
         system_message (str, optional): The system message for the model. If provided, it will be prepended to the messages in the model context when making an inference. Set to `None` to disable.
+        model_client_stream (bool, optional): If `True`, the model client will be used in streaming mode.
+            :meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`
+            messages as the model client produces chunks of response. Defaults to `False`.
         reflect_on_tool_use (bool, optional): If `True`, the agent will make another model inference using the tool call and result
             to generate a response. If `False`, the tool call result will be returned as the response. Defaults to `False`.
         tool_call_summary_format (str, optional): The format string used to create a tool call summary for every tool call result.
@@ -268,12 +282,14 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         system_message: (
             str | None
         ) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
+        model_client_stream: bool = False,
         reflect_on_tool_use: bool = False,
         tool_call_summary_format: str = "{result}",
         memory: Sequence[Memory] | None = None,
     ):
         super().__init__(name=name, description=description)
         self._model_client = model_client
+        self._model_client_stream = model_client_stream
         self._memory = None
         if memory is not None:
             if isinstance(memory, list):
@@ -340,7 +356,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
     @property
     def produced_message_types(self) -> Sequence[type[ChatMessage]]:
-        """The types of messages that the assistant agent produces."""
+        """The types of final response messages that the assistant agent produces."""
         message_types: List[type[ChatMessage]] = [TextMessage]
         if self._handoffs:
             message_types.append(HandoffMessage)
@@ -383,9 +399,23 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         # Generate an inference result based on the current model context.
         llm_messages = self._system_messages + await self._model_context.get_messages()
-        model_result = await self._model_client.create(
-            llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
-        )
+        model_result: CreateResult | None = None
+        if self._model_client_stream:
+            # Stream the model client.
+            async for chunk in self._model_client.create_stream(
+                llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
+            ):
+                if isinstance(chunk, CreateResult):
+                    model_result = chunk
+                elif isinstance(chunk, str):
+                    yield ModelClientStreamingChunkEvent(content=chunk, source=self.name)
+                else:
+                    raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
+            assert isinstance(model_result, CreateResult)
+        else:
+            model_result = await self._model_client.create(
+                llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
+            )
 
         # Add the response to the model context.
         await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
@@ -465,14 +495,34 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         if self._reflect_on_tool_use:
             # Generate another inference result based on the tool call and result.
             llm_messages = self._system_messages + await self._model_context.get_messages()
-            model_result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
-            assert isinstance(model_result.content, str)
+            reflection_model_result: CreateResult | None = None
+            if self._model_client_stream:
+                # Stream the model client.
+                async for chunk in self._model_client.create_stream(
+                    llm_messages, cancellation_token=cancellation_token
+                ):
+                    if isinstance(chunk, CreateResult):
+                        reflection_model_result = chunk
+                    elif isinstance(chunk, str):
+                        yield ModelClientStreamingChunkEvent(content=chunk, source=self.name)
+                    else:
+                        raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
+                assert isinstance(reflection_model_result, CreateResult)
+            else:
+                reflection_model_result = await self._model_client.create(
+                    llm_messages, cancellation_token=cancellation_token
+                )
+            assert isinstance(reflection_model_result.content, str)
             # Add the response to the model context.
-            await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
+            await self._model_context.add_message(
+                AssistantMessage(content=reflection_model_result.content, source=self.name)
+            )
             # Yield the response.
             yield Response(
                 chat_message=TextMessage(
-                    content=model_result.content, source=self.name, models_usage=model_result.usage
+                    content=reflection_model_result.content,
+                    source=self.name,
+                    models_usage=reflection_model_result.usage,
                 ),
                 inner_messages=inner_messages,
             )
@@ -538,6 +588,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             system_message=self._system_messages[0].content
             if self._system_messages and isinstance(self._system_messages[0].content, str)
             else None,
+            model_client_stream=self._model_client_stream,
             reflect_on_tool_use=self._reflect_on_tool_use,
             tool_call_summary_format=self._tool_call_summary_format,
         )
@@ -553,6 +604,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             model_context=None,
             description=config.description,
             system_message=config.system_message,
+            model_client_stream=config.model_client_stream,
             reflect_on_tool_use=config.reflect_on_tool_use,
             tool_call_summary_format=config.tool_call_summary_format,
         )

View File

@@ -9,6 +9,7 @@ from ..messages import (
     AgentEvent,
     BaseChatMessage,
     ChatMessage,
+    ModelClientStreamingChunkEvent,
     TextMessage,
 )
 from ..state import BaseState
@@ -178,8 +179,11 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
                 output_messages.append(message.chat_message)
                 yield TaskResult(messages=output_messages)
             else:
-                output_messages.append(message)
                 yield message
+                if isinstance(message, ModelClientStreamingChunkEvent):
+                    # Skip the model client streaming chunk events.
+                    continue
+                output_messages.append(message)
 
     @abstractmethod
     async def on_reset(self, cancellation_token: CancellationToken) -> None:

View File

@@ -13,6 +13,7 @@ from ..messages import (
     AgentEvent,
     BaseChatMessage,
     ChatMessage,
+    ModelClientStreamingChunkEvent,
     TextMessage,
 )
 from ._base_chat_agent import BaseChatAgent
@@ -150,6 +151,9 @@ class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]):
                     # Skip the task messages.
                     continue
                 yield inner_msg
+                if isinstance(inner_msg, ModelClientStreamingChunkEvent):
+                    # Skip the model client streaming chunk events.
+                    continue
                 inner_messages.append(inner_msg)
         assert result is not None

View File

@@ -1,7 +1,7 @@
 import logging
 from typing import Any, Dict
 
-from autogen_core.tools import FunctionTool, BaseTool
+from autogen_core.tools import BaseTool, FunctionTool
 from pydantic import BaseModel, Field, model_validator
 
 from .. import EVENT_LOGGER_NAME

View File

@@ -128,6 +128,15 @@ class MemoryQueryEvent(BaseAgentEvent):
     type: Literal["MemoryQueryEvent"] = "MemoryQueryEvent"
 
 
+class ModelClientStreamingChunkEvent(BaseAgentEvent):
+    """An event signaling a text output chunk from a model client in streaming mode."""
+
+    content: str
+    """The partial text chunk."""
+
+    type: Literal["ModelClientStreamingChunkEvent"] = "ModelClientStreamingChunkEvent"
+
+
 ChatMessage = Annotated[
     TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
 ]
@@ -135,7 +144,11 @@ ChatMessage = Annotated[
 
 AgentEvent = Annotated[
-    ToolCallRequestEvent | ToolCallExecutionEvent | MemoryQueryEvent | UserInputRequestedEvent,
+    ToolCallRequestEvent
+    | ToolCallExecutionEvent
+    | MemoryQueryEvent
+    | UserInputRequestedEvent
+    | ModelClientStreamingChunkEvent,
     Field(discriminator="type"),
 ]
 """Events emitted by agents and teams when they work, not used for agent-to-agent communication."""
@@ -154,4 +167,5 @@ __all__ = [
     "ToolCallSummaryMessage",
     "MemoryQueryEvent",
     "UserInputRequestedEvent",
+    "ModelClientStreamingChunkEvent",
 ]
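Because `ModelClientStreamingChunkEvent` is now part of the `AgentEvent` union, application code that consumes `on_messages_stream` can branch on it directly. A small sketch, assuming `agent` is an `AssistantAgent` created with `model_client_stream=True`:

```python
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import ModelClientStreamingChunkEvent, TextMessage
from autogen_core import CancellationToken


async def consume(agent: AssistantAgent) -> None:
    async for item in agent.on_messages_stream(
        [TextMessage(content="Name two cities in South America.", source="user")],
        cancellation_token=CancellationToken(),
    ):
        if isinstance(item, ModelClientStreamingChunkEvent):
            print(item.content, end="", flush=True)  # Partial text chunk.
        elif isinstance(item, Response):
            # The complete message arrives after the last chunk; chunks are not
            # repeated in the response's inner messages.
            print("\nFinal:", item.chat_message.content)
```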

View File

@@ -21,7 +21,7 @@ from pydantic import BaseModel
 
 from ... import EVENT_LOGGER_NAME
 from ...base import ChatAgent, TaskResult, Team, TerminationCondition
-from ...messages import AgentEvent, BaseChatMessage, ChatMessage, TextMessage
+from ...messages import AgentEvent, BaseChatMessage, ChatMessage, ModelClientStreamingChunkEvent, TextMessage
 from ...state import TeamState
 from ._chat_agent_container import ChatAgentContainer
 from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
@@ -190,6 +190,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
            and it may not reset the termination condition.
            To gracefully stop the team, use :class:`~autogen_agentchat.conditions.ExternalTermination` instead.
 
+        Returns:
+            result: The result of the task as :class:`~autogen_agentchat.base.TaskResult`. The result contains the messages produced by the team and the stop reason.
+
         Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:
@@ -279,9 +282,15 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
         cancellation_token: CancellationToken | None = None,
     ) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
         """Run the team and produces a stream of messages and the final result
-        of the type :class:`TaskResult` as the last item in the stream. Once the
+        of the type :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream. Once the
         team is stopped, the termination condition is reset.
 
+        .. note::
+
+            If an agent produces :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`,
+            the message will be yielded in the stream but it will not be included in the
+            :attr:`~autogen_agentchat.base.TaskResult.messages`.
+
         Args:
             task (str | ChatMessage | Sequence[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage`, or a list of :class:`ChatMessage`.
             cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
@@ -289,6 +298,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
            and it may not reset the termination condition.
            To gracefully stop the team, use :class:`~autogen_agentchat.conditions.ExternalTermination` instead.
 
+        Returns:
+            stream: an :class:`~collections.abc.AsyncGenerator` that yields :class:`~autogen_agentchat.messages.AgentEvent`, :class:`~autogen_agentchat.messages.ChatMessage`, and the final result :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream.
+
         Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:
 
         .. code-block:: python
@@ -422,6 +434,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
                 if message is None:
                     break
                 yield message
+                if isinstance(message, ModelClientStreamingChunkEvent):
+                    # Skip the model client streaming chunk events.
+                    continue
                 output_messages.append(message)
 
             # Yield the final result.
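The same rule applies at the team level: chunk events flow through `run_stream` but are skipped before messages are collected into the final result. A sketch, assuming a streaming-enabled `AssistantAgent` and the existing `RoundRobinGroupChat` and `MaxMessageTermination` APIs:

```python
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import TaskResult
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_agentchat.messages import ModelClientStreamingChunkEvent
from autogen_agentchat.teams import RoundRobinGroupChat


async def run_team(agent: AssistantAgent) -> None:
    # Stop after the task message and one agent response.
    team = RoundRobinGroupChat([agent], termination_condition=MaxMessageTermination(max_messages=2))
    async for item in team.run_stream(task="Name two cities in South America."):
        if isinstance(item, ModelClientStreamingChunkEvent):
            print(item.content, end="")  # Streamed through the team...
        elif isinstance(item, TaskResult):
            # ...but not recorded in the result's message list.
            print([type(m).__name__ for m in item.messages])
```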

View File

@@ -10,7 +10,13 @@ from autogen_core.models import RequestUsage
 
 from autogen_agentchat.agents import UserProxyAgent
 from autogen_agentchat.base import Response, TaskResult
-from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, UserInputRequestedEvent
+from autogen_agentchat.messages import (
+    AgentEvent,
+    ChatMessage,
+    ModelClientStreamingChunkEvent,
+    MultiModalMessage,
+    UserInputRequestedEvent,
+)
 
 
 def _is_running_in_iterm() -> bool:
@@ -106,6 +112,8 @@ async def Console(
     last_processed: Optional[T] = None
 
+    streaming_chunks: List[str] = []
+
     async for message in stream:
         if isinstance(message, TaskResult):
             duration = time.time() - start_time
@@ -159,13 +167,28 @@ async def Console(
         else:
             # Cast required for mypy to be happy
             message = cast(AgentEvent | ChatMessage, message)  # type: ignore
-            output = f"{'-' * 10} {message.source} {'-' * 10}\n{_message_to_str(message, render_image_iterm=render_image_iterm)}\n"
-            if message.models_usage:
-                if output_stats:
-                    output += f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]\n"
-                total_usage.completion_tokens += message.models_usage.completion_tokens
-                total_usage.prompt_tokens += message.models_usage.prompt_tokens
-            await aprint(output, end="")
+            if not streaming_chunks:
+                # Print message sender.
+                await aprint(f"{'-' * 10} {message.source} {'-' * 10}", end="\n")
+            if isinstance(message, ModelClientStreamingChunkEvent):
+                await aprint(message.content, end="")
+                streaming_chunks.append(message.content)
+            else:
+                if streaming_chunks:
+                    streaming_chunks.clear()
+                    # Chunked messages are already printed, so we just print a newline.
+                    await aprint("", end="\n")
+                else:
+                    # Print message content.
+                    await aprint(_message_to_str(message, render_image_iterm=render_image_iterm), end="\n")
+                if message.models_usage:
+                    if output_stats:
+                        await aprint(
+                            f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]",
+                            end="\n",
+                        )
+                    total_usage.completion_tokens += message.models_usage.completion_tokens
+                    total_usage.prompt_tokens += message.models_usage.prompt_tokens
 
     if last_processed is None:
         raise ValueError("No TaskResult or Response was processed.")

View File

@@ -11,6 +11,7 @@ from autogen_agentchat.messages import (
     ChatMessage,
     HandoffMessage,
     MemoryQueryEvent,
+    ModelClientStreamingChunkEvent,
     MultiModalMessage,
     TextMessage,
     ToolCallExecutionEvent,
@@ -20,10 +21,11 @@ from autogen_agentchat.messages import (
 from autogen_core import FunctionCall, Image
 from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
 from autogen_core.model_context import BufferedChatCompletionContext
-from autogen_core.models import FunctionExecutionResult, LLMMessage
+from autogen_core.models import CreateResult, FunctionExecutionResult, LLMMessage, RequestUsage
 from autogen_core.models._model_client import ModelFamily
 from autogen_core.tools import FunctionTool
 from autogen_ext.models.openai import OpenAIChatCompletionClient
+from autogen_ext.models.replay import ReplayChatCompletionClient
 from openai.resources.chat.completions import AsyncCompletions
 from openai.types.chat.chat_completion import ChatCompletion, Choice
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -776,3 +778,65 @@ async def test_assistant_agent_declarative(monkeypatch: pytest.MonkeyPatch) -> None:
     )
     agent3_config = agent3.dump_component()
     assert agent3_config.provider == "autogen_agentchat.agents.AssistantAgent"
+
+
+@pytest.mark.asyncio
+async def test_model_client_stream() -> None:
+    mock_client = ReplayChatCompletionClient(
+        [
+            "Response to message 3",
+        ]
+    )
+    agent = AssistantAgent(
+        "test_agent",
+        model_client=mock_client,
+        model_client_stream=True,
+    )
+    chunks: List[str] = []
+    async for message in agent.run_stream(task="task"):
+        if isinstance(message, TaskResult):
+            assert message.messages[-1].content == "Response to message 3"
+        elif isinstance(message, ModelClientStreamingChunkEvent):
+            chunks.append(message.content)
+    assert "".join(chunks) == "Response to message 3"
+
+
+@pytest.mark.asyncio
+async def test_model_client_stream_with_tool_calls() -> None:
+    mock_client = ReplayChatCompletionClient(
+        [
+            CreateResult(
+                content=[
+                    FunctionCall(id="1", name="_pass_function", arguments=r'{"input": "task"}'),
+                    FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'),
+                ],
+                finish_reason="function_calls",
+                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
+                cached=False,
+            ),
+            "Example response 2 to task",
+        ]
+    )
+    mock_client._model_info["function_calling"] = True  # pyright: ignore
+    agent = AssistantAgent(
+        "test_agent",
+        model_client=mock_client,
+        model_client_stream=True,
+        reflect_on_tool_use=True,
+        tools=[_pass_function, _echo_function],
+    )
+    chunks: List[str] = []
+    async for message in agent.run_stream(task="task"):
+        if isinstance(message, TaskResult):
+            assert message.messages[-1].content == "Example response 2 to task"
+            assert message.messages[1].content == [
+                FunctionCall(id="1", name="_pass_function", arguments=r'{"input": "task"}'),
+                FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'),
+            ]
+            assert message.messages[2].content == [
+                FunctionExecutionResult(call_id="1", content="pass"),
+                FunctionExecutionResult(call_id="3", content="task"),
+            ]
+        elif isinstance(message, ModelClientStreamingChunkEvent):
+            chunks.append(message.content)
+    assert "".join(chunks) == "Example response 2 to task"

View File

@@ -403,6 +403,117 @@
     "await Console(agent.run_stream(task=\"I am happy.\"))"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Streaming Tokens\n",
+    "\n",
+    "You can stream the tokens generated by the model client by setting `model_client_stream=True`.\n",
+    "This will cause the agent to yield {py:class}`~autogen_agentchat.messages.ModelClientStreamingChunkEvent` messages\n",
+    "in {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream` and {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run_stream`.\n",
+    "\n",
+    "The underlying model API must support streaming tokens for this to work.\n",
+    "Please check with your model provider to see if this is supported."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "source='assistant' models_usage=None content='Two' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' cities' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' South' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' America' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' are' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' Buenos' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' Aires' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' Argentina' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' and' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' São' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' Paulo' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' Brazil' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content='.' type='ModelClientStreamingChunkEvent'\n",
+      "Response(chat_message=TextMessage(source='assistant', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), content='Two cities in South America are Buenos Aires in Argentina and São Paulo in Brazil.', type='TextMessage'), inner_messages=[])\n"
+     ]
+    }
+   ],
+   "source": [
+    "model_client = OpenAIChatCompletionClient(model=\"gpt-4o\")\n",
+    "\n",
+    "streaming_assistant = AssistantAgent(\n",
+    "    name=\"assistant\",\n",
+    "    model_client=model_client,\n",
+    "    system_message=\"You are a helpful assistant.\",\n",
+    "    model_client_stream=True,  # Enable streaming tokens.\n",
+    ")\n",
+    "\n",
+    "# Use an async function and asyncio.run() in a script.\n",
+    "async for message in streaming_assistant.on_messages_stream(  # type: ignore\n",
+    "    [TextMessage(content=\"Name two cities in South America\", source=\"user\")],\n",
+    "    cancellation_token=CancellationToken(),\n",
+    "):\n",
+    "    print(message)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You can see the streaming chunks in the output above.\n",
+    "The chunks are generated by the model client and are yielded by the agent as they are received.\n",
+    "The final response, the concatenation of all the chunks, is yielded right after the last chunk.\n",
+    "\n",
+    "Similarly, {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run_stream` will also yield the same streaming chunks,\n",
+    "followed by a full text message right after the last chunk."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "source='user' models_usage=None content='Name two cities in North America.' type='TextMessage'\n",
+      "source='assistant' models_usage=None content='Two' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' cities' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' North' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' America' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' are' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' New' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' York' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' City' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' the' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' United' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' States' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' and' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' Toronto' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' Canada' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content='.' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0) content='Two cities in North America are New York City in the United States and Toronto in Canada.' type='TextMessage'\n",
+      "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Name two cities in North America.', type='TextMessage'), TextMessage(source='assistant', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), content='Two cities in North America are New York City in the United States and Toronto in Canada.', type='TextMessage')], stop_reason=None)\n"
+     ]
+    }
+   ],
+   "source": [
+    "async for message in streaming_assistant.run_stream(task=\"Name two cities in North America.\"):  # type: ignore\n",
+    "    print(message)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},

View File

@@ -185,6 +185,9 @@ class ReplayChatCompletionClient(ChatCompletionClient):
                     yield token + " "
                 else:
                     yield token
+            yield CreateResult(
+                finish_reason="stop", content=response, usage=self._cur_usage, cached=self._cached_bool_value
+            )
             self._update_total_usage()
         else:
             self._cur_usage = RequestUsage(
@@ -226,7 +229,7 @@ class ReplayChatCompletionClient(ChatCompletionClient):
                 total_tokens += len(tokens)
                 all_tokens.extend(tokens)
             else:
-                logger.warning("Token count has been done only on string content", RuntimeWarning)
+                logger.warning("Token count has been done only on string content")
         elif isinstance(messages, Sequence):
             for message in messages:
                 if isinstance(message.content, str):  # type: ignore [reportAttributeAccessIssue, union-attr]
@@ -234,7 +237,7 @@ class ReplayChatCompletionClient(ChatCompletionClient):
                     total_tokens += len(tokens)
                     all_tokens.extend(tokens)
                 else:
-                    logger.warning("Token count has been done only on string content", RuntimeWarning)
+                    logger.warning("Token count has been done only on string content")
         return all_tokens, total_tokens
 
     def _update_total_usage(self) -> None:
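After this change, `create_stream` on the replay client yields the string tokens first and then a final `CreateResult`, matching the contract of the real model clients. A small consumer sketch; it assumes the constructor takes a list of canned responses, as in the tests elsewhere in this commit:

```python
import asyncio

from autogen_core.models import CreateResult, UserMessage
from autogen_ext.models.replay import ReplayChatCompletionClient


async def main() -> None:
    client = ReplayChatCompletionClient(["Hello from the replay client"])
    chunks: list[str] = []
    final: CreateResult | None = None
    async for item in client.create_stream([UserMessage(content="hi", source="user")]):
        if isinstance(item, CreateResult):
            final = item  # The complete result arrives after the last token.
        else:
            chunks.append(item)
    assert final is not None
    assert "".join(chunks) == final.content  # Chunks concatenate to the full response.


asyncio.run(main())
```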

View File

@@ -16,6 +16,7 @@ from autogen_agentchat.base import Response, TaskResult
 from autogen_agentchat.messages import (
     AgentEvent,
     ChatMessage,
+    ModelClientStreamingChunkEvent,
     MultiModalMessage,
     UserInputRequestedEvent,
 )
@@ -185,6 +186,9 @@ async def RichConsole(
         elif isinstance(message, UserInputRequestedEvent):
             if user_input_manager is not None:
                 user_input_manager.notify_event_received(message.request_id)
+        elif isinstance(message, ModelClientStreamingChunkEvent):
+            # TODO: Handle model client streaming chunk events.
+            pass
         else:
             # Cast required for mypy to be happy
             message = cast(AgentEvent | ChatMessage, message)  # type: ignore

View File

@@ -107,14 +107,14 @@ async def test_cache_create_stream() -> None:
     async for completion in cached_client.create_stream(
         [system_prompt, UserMessage(content=prompts[0], source="user")]
     ):
-        original_streamed_results.append(completion)
+        original_streamed_results.append(copy.copy(completion))
     total_usage0 = copy.copy(cached_client.total_usage())
 
     cached_completion_results: List[Union[str, CreateResult]] = []
     async for completion in cached_client.create_stream(
         [system_prompt, UserMessage(content=prompts[0], source="user")]
     ):
-        cached_completion_results.append(completion)
+        cached_completion_results.append(copy.copy(completion))
     total_usage1 = copy.copy(cached_client.total_usage())
 
     assert total_usage1.prompt_tokens == total_usage0.prompt_tokens

View File

@@ -67,12 +67,16 @@ async def test_replay_chat_completion_client_create_stream() -> None:
     reply_model_client = ReplayChatCompletionClient(messages)
 
     for i in range(num_messages):
-        result: List[str] = []
+        chunks: List[str] = []
+        result: CreateResult | None = None
         async for completion in reply_model_client.create_stream([UserMessage(content="dummy", source="_")]):
-            text = completion.content if isinstance(completion, CreateResult) else completion
-            assert isinstance(text, str)
-            result.append(text)
-        assert "".join(result) == messages[i]
+            if isinstance(completion, CreateResult):
+                result = completion
+            else:
+                assert isinstance(completion, str)
+                chunks.append(completion)
+        assert result is not None
+        assert "".join(chunks) == messages[i] == result.content
 
     with pytest.raises(ValueError, match="No more mock responses available"):
         await reply_model_client.create([UserMessage(content="dummy", source="_")])