feat: introduce ModelClientStreamingChunkEvent for streaming model output and update handling in agents and console (#5208)

Resolves #3983

* Introduce the `model_client_stream` parameter in `AssistantAgent` to
enable token-level streaming output (see the usage sketch below).
* Introduce `ModelClientStreamingChunkEvent` as a type of `AgentEvent`
that carries the streaming chunks to the application via `run_stream` and
`on_messages_stream`. These chunk events are not added to the inner
messages of the final `Response` or `TaskResult`.
* Handle the new message type in `Console`.
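A minimal usage sketch of the new streaming mode, assuming the OpenAI model client from `autogen_ext` is installed (the model name and task are illustrative):

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    agent = AssistantAgent(
        name="assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o"),
        model_client_stream=True,  # Yield ModelClientStreamingChunkEvent as tokens arrive.
    )
    # Console prints each chunk as it arrives, then the final TaskResult.
    await Console(agent.run_stream(task="Name two cities in South America."))


asyncio.run(main())
```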
Eric Zhu 2025-01-28 18:49:02 -08:00 committed by GitHub
parent 8a0daf8285
commit 225eb9d0b2
13 changed files with 330 additions and 32 deletions


@ -22,13 +22,14 @@ from autogen_core.model_context import (
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
CreateResult,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
SystemMessage,
UserMessage,
)
from autogen_core.tools import FunctionTool, BaseTool
from autogen_core.tools import BaseTool, FunctionTool
from pydantic import BaseModel
from typing_extensions import Self
@ -40,6 +41,7 @@ from ..messages import (
ChatMessage,
HandoffMessage,
MemoryQueryEvent,
ModelClientStreamingChunkEvent,
MultiModalMessage,
TextMessage,
ToolCallExecutionEvent,
@ -62,6 +64,7 @@ class AssistantAgentConfig(BaseModel):
model_context: ComponentModel | None = None
description: str
system_message: str | None = None
model_client_stream: bool
reflect_on_tool_use: bool
tool_call_summary_format: str
@ -126,6 +129,14 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
This will limit the number of recent messages sent to the model and can be useful
when the model has a limit on the number of tokens it can process.
Streaming mode:
The assistant agent can be used in streaming mode by setting `model_client_stream=True`.
In this mode, the :meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield
:class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`
messages as the model client produces chunks of response.
The chunk messages will not be included in the final response's inner messages.
Args:
name (str): The name of the agent.
@ -138,6 +149,9 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving :class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages. The initial messages will be cleared when the agent is reset.
description (str, optional): The description of the agent.
system_message (str, optional): The system message for the model. If provided, it will be prepended to the messages in the model context when making an inference. Set to `None` to disable.
model_client_stream (bool, optional): If `True`, the model client will be used in streaming mode.
:meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`
messages as the model client produces chunks of response. Defaults to `False`.
reflect_on_tool_use (bool, optional): If `True`, the agent will make another model inference using the tool call and result
to generate a response. If `False`, the tool call result will be returned as the response. Defaults to `False`.
tool_call_summary_format (str, optional): The format string used to create a tool call summary for every tool call result.
@ -268,12 +282,14 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
system_message: (
str | None
) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
model_client_stream: bool = False,
reflect_on_tool_use: bool = False,
tool_call_summary_format: str = "{result}",
memory: Sequence[Memory] | None = None,
):
super().__init__(name=name, description=description)
self._model_client = model_client
self._model_client_stream = model_client_stream
self._memory = None
if memory is not None:
if isinstance(memory, list):
@ -340,7 +356,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
"""The types of messages that the assistant agent produces."""
"""The types of final response messages that the assistant agent produces."""
message_types: List[type[ChatMessage]] = [TextMessage]
if self._handoffs:
message_types.append(HandoffMessage)
@ -383,9 +399,23 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
# Generate an inference result based on the current model context.
llm_messages = self._system_messages + await self._model_context.get_messages()
model_result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)
model_result: CreateResult | None = None
if self._model_client_stream:
# Stream the model client.
async for chunk in self._model_client.create_stream(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
):
if isinstance(chunk, CreateResult):
model_result = chunk
elif isinstance(chunk, str):
yield ModelClientStreamingChunkEvent(content=chunk, source=self.name)
else:
raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
assert isinstance(model_result, CreateResult)
else:
model_result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)
# Add the response to the model context.
await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
@ -465,14 +495,34 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
if self._reflect_on_tool_use:
# Generate another inference result based on the tool call and result.
llm_messages = self._system_messages + await self._model_context.get_messages()
model_result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
assert isinstance(model_result.content, str)
reflection_model_result: CreateResult | None = None
if self._model_client_stream:
# Stream the model client.
async for chunk in self._model_client.create_stream(
llm_messages, cancellation_token=cancellation_token
):
if isinstance(chunk, CreateResult):
reflection_model_result = chunk
elif isinstance(chunk, str):
yield ModelClientStreamingChunkEvent(content=chunk, source=self.name)
else:
raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
assert isinstance(reflection_model_result, CreateResult)
else:
reflection_model_result = await self._model_client.create(
llm_messages, cancellation_token=cancellation_token
)
assert isinstance(reflection_model_result.content, str)
# Add the response to the model context.
await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
await self._model_context.add_message(
AssistantMessage(content=reflection_model_result.content, source=self.name)
)
# Yield the response.
yield Response(
chat_message=TextMessage(
content=model_result.content, source=self.name, models_usage=model_result.usage
content=reflection_model_result.content,
source=self.name,
models_usage=reflection_model_result.usage,
),
inner_messages=inner_messages,
)
@ -538,6 +588,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
system_message=self._system_messages[0].content
if self._system_messages and isinstance(self._system_messages[0].content, str)
else None,
model_client_stream=self._model_client_stream,
reflect_on_tool_use=self._reflect_on_tool_use,
tool_call_summary_format=self._tool_call_summary_format,
)
@ -553,6 +604,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_context=None,
description=config.description,
system_message=config.system_message,
model_client_stream=config.model_client_stream,
reflect_on_tool_use=config.reflect_on_tool_use,
tool_call_summary_format=config.tool_call_summary_format,
)
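For reference, a minimal standalone sketch of the consumption pattern used above: in streaming mode, `create_stream` yields `str` chunks and finishes with a single `CreateResult`. The helper function and its arguments are illustrative, not part of this change:

```python
from autogen_core.models import ChatCompletionClient, CreateResult, LLMMessage


async def consume_stream(model_client: ChatCompletionClient, llm_messages: list[LLMMessage]) -> CreateResult:
    """Drain a streaming completion: str chunks first, a final CreateResult last."""
    model_result: CreateResult | None = None
    async for chunk in model_client.create_stream(llm_messages):
        if isinstance(chunk, str):
            # Partial text; the agent wraps each of these in a ModelClientStreamingChunkEvent.
            print(chunk, end="", flush=True)
        elif isinstance(chunk, CreateResult):
            # The final item carries the full content, usage, and finish reason.
            model_result = chunk
    assert model_result is not None
    return model_result
```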


@ -9,6 +9,7 @@ from ..messages import (
AgentEvent,
BaseChatMessage,
ChatMessage,
ModelClientStreamingChunkEvent,
TextMessage,
)
from ..state import BaseState
@ -178,8 +179,11 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
output_messages.append(message.chat_message)
yield TaskResult(messages=output_messages)
else:
output_messages.append(message)
yield message
if isinstance(message, ModelClientStreamingChunkEvent):
# Skip the model client streaming chunk events.
continue
output_messages.append(message)
@abstractmethod
async def on_reset(self, cancellation_token: CancellationToken) -> None:


@ -13,6 +13,7 @@ from ..messages import (
AgentEvent,
BaseChatMessage,
ChatMessage,
ModelClientStreamingChunkEvent,
TextMessage,
)
from ._base_chat_agent import BaseChatAgent
@ -150,6 +151,9 @@ class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]):
# Skip the task messages.
continue
yield inner_msg
if isinstance(inner_msg, ModelClientStreamingChunkEvent):
# Skip the model client streaming chunk events.
continue
inner_messages.append(inner_msg)
assert result is not None


@ -1,7 +1,7 @@
import logging
from typing import Any, Dict
from autogen_core.tools import FunctionTool, BaseTool
from autogen_core.tools import BaseTool, FunctionTool
from pydantic import BaseModel, Field, model_validator
from .. import EVENT_LOGGER_NAME


@ -128,6 +128,15 @@ class MemoryQueryEvent(BaseAgentEvent):
type: Literal["MemoryQueryEvent"] = "MemoryQueryEvent"
class ModelClientStreamingChunkEvent(BaseAgentEvent):
"""An event signaling a text output chunk from a model client in streaming mode."""
content: str
"""The partial text chunk."""
type: Literal["ModelClientStreamingChunkEvent"] = "ModelClientStreamingChunkEvent"
ChatMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
]
@ -135,7 +144,11 @@ ChatMessage = Annotated[
AgentEvent = Annotated[
ToolCallRequestEvent | ToolCallExecutionEvent | MemoryQueryEvent | UserInputRequestedEvent,
ToolCallRequestEvent
| ToolCallExecutionEvent
| MemoryQueryEvent
| UserInputRequestedEvent
| ModelClientStreamingChunkEvent,
Field(discriminator="type"),
]
"""Events emitted by agents and teams when they work, not used for agent-to-agent communication."""
@ -154,4 +167,5 @@ __all__ = [
"ToolCallSummaryMessage",
"MemoryQueryEvent",
"UserInputRequestedEvent",
"ModelClientStreamingChunkEvent",
]
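A small sketch of the new event type in use; the `source` value and the `render` helper are illustrative:

```python
from autogen_agentchat.messages import AgentEvent, ChatMessage, ModelClientStreamingChunkEvent

chunk = ModelClientStreamingChunkEvent(content="Hel", source="assistant")
print(chunk.type)  # "ModelClientStreamingChunkEvent"


def render(message: AgentEvent | ChatMessage) -> None:
    # Chunk events are transient partial output; everything else is a complete message.
    if isinstance(message, ModelClientStreamingChunkEvent):
        print(message.content, end="")
    else:
        print(message)
```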


@ -21,7 +21,7 @@ from pydantic import BaseModel
from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import AgentEvent, BaseChatMessage, ChatMessage, TextMessage
from ...messages import AgentEvent, BaseChatMessage, ChatMessage, ModelClientStreamingChunkEvent, TextMessage
from ...state import TeamState
from ._chat_agent_container import ChatAgentContainer
from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
@ -190,6 +190,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
and it may not reset the termination condition.
To gracefully stop the team, use :class:`~autogen_agentchat.conditions.ExternalTermination` instead.
Returns:
result: The result of the task as :class:`~autogen_agentchat.base.TaskResult`. The result contains the messages produced by the team and the stop reason.
Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:
@ -279,9 +282,15 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
"""Run the team and produces a stream of messages and the final result
of the type :class:`TaskResult` as the last item in the stream. Once the
of the type :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream. Once the
team is stopped, the termination condition is reset.
.. note::
If an agent produces :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`,
the message will be yielded in the stream but it will not be included in the
:attr:`~autogen_agentchat.base.TaskResult.messages`.
Args:
task (str | ChatMessage | Sequence[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
@ -289,6 +298,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
and it may not reset the termination condition.
To gracefully stop the team, use :class:`~autogen_agentchat.conditions.ExternalTermination` instead.
Returns:
stream: an :class:`~collections.abc.AsyncGenerator` that yields :class:`~autogen_agentchat.messages.AgentEvent`, :class:`~autogen_agentchat.messages.ChatMessage`, and the final result :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream.
Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:
.. code-block:: python
@ -422,6 +434,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
if message is None:
break
yield message
if isinstance(message, ModelClientStreamingChunkEvent):
# Skip the model client streaming chunk events.
continue
output_messages.append(message)
# Yield the final result.
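A sketch of consuming a team stream that may contain chunk events, per the note above; the team construction is omitted and the task string is illustrative:

```python
from autogen_agentchat.base import TaskResult, Team
from autogen_agentchat.messages import ModelClientStreamingChunkEvent


async def run_team(team: Team) -> TaskResult:
    async for item in team.run_stream(task="Write a haiku about streaming."):
        if isinstance(item, TaskResult):
            # Final item in the stream; chunk events are not included in item.messages.
            return item
        if isinstance(item, ModelClientStreamingChunkEvent):
            print(item.content, end="", flush=True)  # Print partial tokens inline.
        else:
            print(item)  # A complete AgentEvent or ChatMessage.
    raise RuntimeError("Stream ended without a TaskResult.")
```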


@ -10,7 +10,13 @@ from autogen_core.models import RequestUsage
from autogen_agentchat.agents import UserProxyAgent
from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, UserInputRequestedEvent
from autogen_agentchat.messages import (
AgentEvent,
ChatMessage,
ModelClientStreamingChunkEvent,
MultiModalMessage,
UserInputRequestedEvent,
)
def _is_running_in_iterm() -> bool:
@ -106,6 +112,8 @@ async def Console(
last_processed: Optional[T] = None
streaming_chunks: List[str] = []
async for message in stream:
if isinstance(message, TaskResult):
duration = time.time() - start_time
@ -159,13 +167,28 @@ async def Console(
else:
# Cast required for mypy to be happy
message = cast(AgentEvent | ChatMessage, message) # type: ignore
output = f"{'-' * 10} {message.source} {'-' * 10}\n{_message_to_str(message, render_image_iterm=render_image_iterm)}\n"
if message.models_usage:
if output_stats:
output += f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]\n"
total_usage.completion_tokens += message.models_usage.completion_tokens
total_usage.prompt_tokens += message.models_usage.prompt_tokens
await aprint(output, end="")
if not streaming_chunks:
# Print message sender.
await aprint(f"{'-' * 10} {message.source} {'-' * 10}", end="\n")
if isinstance(message, ModelClientStreamingChunkEvent):
await aprint(message.content, end="")
streaming_chunks.append(message.content)
else:
if streaming_chunks:
streaming_chunks.clear()
# Chunked messages are already printed, so we just print a newline.
await aprint("", end="\n")
else:
# Print message content.
await aprint(_message_to_str(message, render_image_iterm=render_image_iterm), end="\n")
if message.models_usage:
if output_stats:
await aprint(
f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]",
end="\n",
)
total_usage.completion_tokens += message.models_usage.completion_tokens
total_usage.prompt_tokens += message.models_usage.prompt_tokens
if last_processed is None:
raise ValueError("No TaskResult or Response was processed.")


@ -11,6 +11,7 @@ from autogen_agentchat.messages import (
ChatMessage,
HandoffMessage,
MemoryQueryEvent,
ModelClientStreamingChunkEvent,
MultiModalMessage,
TextMessage,
ToolCallExecutionEvent,
@ -20,10 +21,11 @@ from autogen_agentchat.messages import (
from autogen_core import FunctionCall, Image
from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.models import FunctionExecutionResult, LLMMessage
from autogen_core.models import CreateResult, FunctionExecutionResult, LLMMessage, RequestUsage
from autogen_core.models._model_client import ModelFamily
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.models.replay import ReplayChatCompletionClient
from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@ -776,3 +778,65 @@ async def test_assistant_agent_declarative(monkeypatch: pytest.MonkeyPatch) -> N
)
agent3_config = agent3.dump_component()
assert agent3_config.provider == "autogen_agentchat.agents.AssistantAgent"
@pytest.mark.asyncio
async def test_model_client_stream() -> None:
mock_client = ReplayChatCompletionClient(
[
"Response to message 3",
]
)
agent = AssistantAgent(
"test_agent",
model_client=mock_client,
model_client_stream=True,
)
chunks: List[str] = []
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message.messages[-1].content == "Response to message 3"
elif isinstance(message, ModelClientStreamingChunkEvent):
chunks.append(message.content)
assert "".join(chunks) == "Response to message 3"
@pytest.mark.asyncio
async def test_model_client_stream_with_tool_calls() -> None:
mock_client = ReplayChatCompletionClient(
[
CreateResult(
content=[
FunctionCall(id="1", name="_pass_function", arguments=r'{"input": "task"}'),
FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'),
],
finish_reason="function_calls",
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
cached=False,
),
"Example response 2 to task",
]
)
mock_client._model_info["function_calling"] = True # pyright: ignore
agent = AssistantAgent(
"test_agent",
model_client=mock_client,
model_client_stream=True,
reflect_on_tool_use=True,
tools=[_pass_function, _echo_function],
)
chunks: List[str] = []
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message.messages[-1].content == "Example response 2 to task"
assert message.messages[1].content == [
FunctionCall(id="1", name="_pass_function", arguments=r'{"input": "task"}'),
FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'),
]
assert message.messages[2].content == [
FunctionExecutionResult(call_id="1", content="pass"),
FunctionExecutionResult(call_id="3", content="task"),
]
elif isinstance(message, ModelClientStreamingChunkEvent):
chunks.append(message.content)
assert "".join(chunks) == "Example response 2 to task"


@ -403,6 +403,117 @@
"await Console(agent.run_stream(task=\"I am happy.\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Streaming Tokens\n",
"\n",
"You can stream the tokens generated by the model client by setting `model_client_stream=True`.\n",
"This will cause the agent to yield {py:class}`~autogen_agentchat.messages.ModelClientStreamingChunkEvent` messages\n",
"in {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream` and {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run_stream`.\n",
"\n",
"The underlying model API must support streaming tokens for this to work.\n",
"Please check with your model provider to see if this is supported."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"source='assistant' models_usage=None content='Two' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' cities' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' South' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' America' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' are' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' Buenos' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' Aires' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' Argentina' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' and' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' São' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' Paulo' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' Brazil' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content='.' type='ModelClientStreamingChunkEvent'\n",
"Response(chat_message=TextMessage(source='assistant', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), content='Two cities in South America are Buenos Aires in Argentina and São Paulo in Brazil.', type='TextMessage'), inner_messages=[])\n"
]
}
],
"source": [
"model_client = OpenAIChatCompletionClient(model=\"gpt-4o\")\n",
"\n",
"streaming_assistant = AssistantAgent(\n",
" name=\"assistant\",\n",
" model_client=model_client,\n",
" system_message=\"You are a helpful assistant.\",\n",
" model_client_stream=True, # Enable streaming tokens.\n",
")\n",
"\n",
"# Use an async function and asyncio.run() in a script.\n",
"async for message in streaming_assistant.on_messages_stream( # type: ignore\n",
" [TextMessage(content=\"Name two cities in South America\", source=\"user\")],\n",
" cancellation_token=CancellationToken(),\n",
"):\n",
" print(message)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can see the streaming chunks in the output above.\n",
"The chunks are generated by the model client and are yielded by the agent as they are received.\n",
"The final response, the concatenation of all the chunks, is yielded right after the last chunk.\n",
"\n",
"Similarly, {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run_stream` will also yield the same streaming chunks,\n",
"followed by a full text message right after the last chunk."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"source='user' models_usage=None content='Name two cities in North America.' type='TextMessage'\n",
"source='assistant' models_usage=None content='Two' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' cities' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' North' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' America' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' are' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' New' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' York' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' City' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' the' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' United' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' States' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' and' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' Toronto' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' Canada' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content='.' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0) content='Two cities in North America are New York City in the United States and Toronto in Canada.' type='TextMessage'\n",
"TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Name two cities in North America.', type='TextMessage'), TextMessage(source='assistant', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), content='Two cities in North America are New York City in the United States and Toronto in Canada.', type='TextMessage')], stop_reason=None)\n"
]
}
],
"source": [
"async for message in streaming_assistant.run_stream(task=\"Name two cities in North America.\"): # type: ignore\n",
" print(message)"
]
},
{
"cell_type": "markdown",
"metadata": {},


@ -185,6 +185,9 @@ class ReplayChatCompletionClient(ChatCompletionClient):
yield token + " "
else:
yield token
yield CreateResult(
finish_reason="stop", content=response, usage=self._cur_usage, cached=self._cached_bool_value
)
self._update_total_usage()
else:
self._cur_usage = RequestUsage(
@ -226,7 +229,7 @@ class ReplayChatCompletionClient(ChatCompletionClient):
total_tokens += len(tokens)
all_tokens.extend(tokens)
else:
logger.warning("Token count has been done only on string content", RuntimeWarning)
logger.warning("Token count has been done only on string content")
elif isinstance(messages, Sequence):
for message in messages:
if isinstance(message.content, str): # type: ignore [reportAttributeAccessIssue, union-attr]
@ -234,7 +237,7 @@ class ReplayChatCompletionClient(ChatCompletionClient):
total_tokens += len(tokens)
all_tokens.extend(tokens)
else:
logger.warning("Token count has been done only on string content", RuntimeWarning)
logger.warning("Token count has been done only on string content")
return all_tokens, total_tokens
def _update_total_usage(self) -> None:
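With this change, the replay client's `create_stream` yields the text chunks and then a final `CreateResult`. A minimal sketch, with an illustrative canned response:

```python
import asyncio

from autogen_core.models import CreateResult, UserMessage
from autogen_ext.models.replay import ReplayChatCompletionClient


async def main() -> None:
    client = ReplayChatCompletionClient(["Hello from the replay client"])
    async for item in client.create_stream([UserMessage(content="hi", source="user")]):
        if isinstance(item, CreateResult):
            print("\nfinal:", item.content)  # Full text, yielded after the chunks.
        else:
            print("chunk:", repr(item))


asyncio.run(main())
```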


@ -16,6 +16,7 @@ from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.messages import (
AgentEvent,
ChatMessage,
ModelClientStreamingChunkEvent,
MultiModalMessage,
UserInputRequestedEvent,
)
@ -185,6 +186,9 @@ async def RichConsole(
elif isinstance(message, UserInputRequestedEvent):
if user_input_manager is not None:
user_input_manager.notify_event_received(message.request_id)
elif isinstance(message, ModelClientStreamingChunkEvent):
# TODO: Handle model client streaming chunk events.
pass
else:
# Cast required for mypy to be happy
message = cast(AgentEvent | ChatMessage, message) # type: ignore


@ -107,14 +107,14 @@ async def test_cache_create_stream() -> None:
async for completion in cached_client.create_stream(
[system_prompt, UserMessage(content=prompts[0], source="user")]
):
original_streamed_results.append(completion)
original_streamed_results.append(copy.copy(completion))
total_usage0 = copy.copy(cached_client.total_usage())
cached_completion_results: List[Union[str, CreateResult]] = []
async for completion in cached_client.create_stream(
[system_prompt, UserMessage(content=prompts[0], source="user")]
):
cached_completion_results.append(completion)
cached_completion_results.append(copy.copy(completion))
total_usage1 = copy.copy(cached_client.total_usage())
assert total_usage1.prompt_tokens == total_usage0.prompt_tokens


@ -67,12 +67,16 @@ async def test_replay_chat_completion_client_create_stream() -> None:
reply_model_client = ReplayChatCompletionClient(messages)
for i in range(num_messages):
result: List[str] = []
chunks: List[str] = []
result: CreateResult | None = None
async for completion in reply_model_client.create_stream([UserMessage(content="dummy", source="_")]):
text = completion.content if isinstance(completion, CreateResult) else completion
assert isinstance(text, str)
result.append(text)
assert "".join(result) == messages[i]
if isinstance(completion, CreateResult):
result = completion
else:
assert isinstance(completion, str)
chunks.append(completion)
assert result is not None
assert "".join(chunks) == messages[i] == result.content
with pytest.raises(ValueError, match="No more mock responses available"):
await reply_model_client.create([UserMessage(content="dummy", source="_")])