mirror of https://github.com/microsoft/autogen.git
feat: introduce ModelClientStreamingChunkEvent for streaming model output and update handling in agents and console (#5208)
Resolves #3983

* Introduce a `model_client_stream` parameter in `AssistantAgent` to enable token-level streaming output.
* Introduce `ModelClientStreamingChunkEvent` as a type of `AgentEvent` that passes streaming chunks to the application via `run_stream` and `on_messages_stream`. These chunk events are not added to the inner messages list in the final `Response` or `TaskResult`.
* Handle the new message type in `Console`.
parent 8a0daf8285 · commit 225eb9d0b2
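For orientation, here is a minimal sketch of the usage pattern this change enables. It mirrors the added `test_model_client_stream` test and uses `ReplayChatCompletionClient` so it runs offline; the agent name and the replayed text are illustrative.

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import ModelClientStreamingChunkEvent
from autogen_ext.models.replay import ReplayChatCompletionClient


async def main() -> None:
    # Replay client returns canned completions, so no API key is needed.
    model_client = ReplayChatCompletionClient(["Hello from the streaming agent."])
    agent = AssistantAgent(
        "assistant",
        model_client=model_client,
        model_client_stream=True,  # Enable token-level streaming.
    )
    chunks: list[str] = []
    async for message in agent.run_stream(task="Say hello."):
        if isinstance(message, ModelClientStreamingChunkEvent):
            # Partial text produced by the model client as it streams.
            chunks.append(message.content)
        elif isinstance(message, TaskResult):
            # Chunk events are not included in the final messages list.
            print("Final:", message.messages[-1].content)
    print("Streamed:", "".join(chunks))


asyncio.run(main())
```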
@@ -22,13 +22,14 @@ from autogen_core.model_context import (
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
CreateResult,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
SystemMessage,
UserMessage,
)
from autogen_core.tools import FunctionTool, BaseTool
from autogen_core.tools import BaseTool, FunctionTool
from pydantic import BaseModel
from typing_extensions import Self

@@ -40,6 +41,7 @@ from ..messages import (
ChatMessage,
HandoffMessage,
MemoryQueryEvent,
ModelClientStreamingChunkEvent,
MultiModalMessage,
TextMessage,
ToolCallExecutionEvent,

@@ -62,6 +64,7 @@ class AssistantAgentConfig(BaseModel):
model_context: ComponentModel | None = None
description: str
system_message: str | None = None
model_client_stream: bool
reflect_on_tool_use: bool
tool_call_summary_format: str

@@ -126,6 +129,14 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
This will limit the number of recent messages sent to the model and can be useful
when the model has a limit on the number of tokens it can process.

Streaming mode:

The assistant agent can be used in streaming mode by setting `model_client_stream=True`.
In this mode, the :meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield
:class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`
messages as the model client produces chunks of response.
The chunk messages will not be included in the final response's inner messages.

Args:
name (str): The name of the agent.

@@ -138,6 +149,9 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving :class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages. The initial messages will be cleared when the agent is reset.
description (str, optional): The description of the agent.
system_message (str, optional): The system message for the model. If provided, it will be prepended to the messages in the model context when making an inference. Set to `None` to disable.
model_client_stream (bool, optional): If `True`, the model client will be used in streaming mode.
:meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`
messages as the model client produces chunks of response. Defaults to `False`.
reflect_on_tool_use (bool, optional): If `True`, the agent will make another model inference using the tool call and result
to generate a response. If `False`, the tool call result will be returned as the response. Defaults to `False`.
tool_call_summary_format (str, optional): The format string used to create a tool call summary for every tool call result.

@@ -268,12 +282,14 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
system_message: (
str | None
) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
model_client_stream: bool = False,
reflect_on_tool_use: bool = False,
tool_call_summary_format: str = "{result}",
memory: Sequence[Memory] | None = None,
):
super().__init__(name=name, description=description)
self._model_client = model_client
self._model_client_stream = model_client_stream
self._memory = None
if memory is not None:
if isinstance(memory, list):

@@ -340,7 +356,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
"""The types of messages that the assistant agent produces."""
"""The types of final response messages that the assistant agent produces."""
message_types: List[type[ChatMessage]] = [TextMessage]
if self._handoffs:
message_types.append(HandoffMessage)

@@ -383,9 +399,23 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
# Generate an inference result based on the current model context.
llm_messages = self._system_messages + await self._model_context.get_messages()
model_result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)
model_result: CreateResult | None = None
if self._model_client_stream:
# Stream the model client.
async for chunk in self._model_client.create_stream(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
):
if isinstance(chunk, CreateResult):
model_result = chunk
elif isinstance(chunk, str):
yield ModelClientStreamingChunkEvent(content=chunk, source=self.name)
else:
raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
assert isinstance(model_result, CreateResult)
else:
model_result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)

# Add the response to the model context.
await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))

@@ -465,14 +495,34 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
if self._reflect_on_tool_use:
# Generate another inference result based on the tool call and result.
llm_messages = self._system_messages + await self._model_context.get_messages()
model_result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
assert isinstance(model_result.content, str)
reflection_model_result: CreateResult | None = None
if self._model_client_stream:
# Stream the model client.
async for chunk in self._model_client.create_stream(
llm_messages, cancellation_token=cancellation_token
):
if isinstance(chunk, CreateResult):
reflection_model_result = chunk
elif isinstance(chunk, str):
yield ModelClientStreamingChunkEvent(content=chunk, source=self.name)
else:
raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
assert isinstance(reflection_model_result, CreateResult)
else:
reflection_model_result = await self._model_client.create(
llm_messages, cancellation_token=cancellation_token
)
assert isinstance(reflection_model_result.content, str)
# Add the response to the model context.
await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
await self._model_context.add_message(
AssistantMessage(content=reflection_model_result.content, source=self.name)
)
# Yield the response.
yield Response(
chat_message=TextMessage(
content=model_result.content, source=self.name, models_usage=model_result.usage
content=reflection_model_result.content,
source=self.name,
models_usage=reflection_model_result.usage,
),
inner_messages=inner_messages,
)

@@ -538,6 +588,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
system_message=self._system_messages[0].content
if self._system_messages and isinstance(self._system_messages[0].content, str)
else None,
model_client_stream=self._model_client_stream,
reflect_on_tool_use=self._reflect_on_tool_use,
tool_call_summary_format=self._tool_call_summary_format,
)

@@ -553,6 +604,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_context=None,
description=config.description,
system_message=config.system_message,
model_client_stream=config.model_client_stream,
reflect_on_tool_use=config.reflect_on_tool_use,
tool_call_summary_format=config.tool_call_summary_format,
)

@@ -9,6 +9,7 @@ from ..messages import (
AgentEvent,
BaseChatMessage,
ChatMessage,
ModelClientStreamingChunkEvent,
TextMessage,
)
from ..state import BaseState

@@ -178,8 +179,11 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
output_messages.append(message.chat_message)
yield TaskResult(messages=output_messages)
else:
output_messages.append(message)
yield message
if isinstance(message, ModelClientStreamingChunkEvent):
# Skip the model client streaming chunk events.
continue
output_messages.append(message)

@abstractmethod
async def on_reset(self, cancellation_token: CancellationToken) -> None:

@@ -13,6 +13,7 @@ from ..messages import (
AgentEvent,
BaseChatMessage,
ChatMessage,
ModelClientStreamingChunkEvent,
TextMessage,
)
from ._base_chat_agent import BaseChatAgent

@@ -150,6 +151,9 @@ class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]):
# Skip the task messages.
continue
yield inner_msg
if isinstance(inner_msg, ModelClientStreamingChunkEvent):
# Skip the model client streaming chunk events.
continue
inner_messages.append(inner_msg)
assert result is not None

@@ -1,7 +1,7 @@
import logging
from typing import Any, Dict

from autogen_core.tools import FunctionTool, BaseTool
from autogen_core.tools import BaseTool, FunctionTool
from pydantic import BaseModel, Field, model_validator

from .. import EVENT_LOGGER_NAME

@@ -128,6 +128,15 @@ class MemoryQueryEvent(BaseAgentEvent):
type: Literal["MemoryQueryEvent"] = "MemoryQueryEvent"


class ModelClientStreamingChunkEvent(BaseAgentEvent):
"""An event signaling a text output chunk from a model client in streaming mode."""

content: str
"""The partial text chunk."""

type: Literal["ModelClientStreamingChunkEvent"] = "ModelClientStreamingChunkEvent"


ChatMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
]

@@ -135,7 +144,11 @@ ChatMessage = Annotated[
AgentEvent = Annotated[
ToolCallRequestEvent | ToolCallExecutionEvent | MemoryQueryEvent | UserInputRequestedEvent,
ToolCallRequestEvent
| ToolCallExecutionEvent
| MemoryQueryEvent
| UserInputRequestedEvent
| ModelClientStreamingChunkEvent,
Field(discriminator="type"),
]
"""Events emitted by agents and teams when they work, not used for agent-to-agent communication."""

@@ -154,4 +167,5 @@ __all__ = [
"ToolCallSummaryMessage",
"MemoryQueryEvent",
"UserInputRequestedEvent",
"ModelClientStreamingChunkEvent",
]

@@ -21,7 +21,7 @@ from pydantic import BaseModel

from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import AgentEvent, BaseChatMessage, ChatMessage, TextMessage
from ...messages import AgentEvent, BaseChatMessage, ChatMessage, ModelClientStreamingChunkEvent, TextMessage
from ...state import TeamState
from ._chat_agent_container import ChatAgentContainer
from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination

@@ -190,6 +190,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
and it may not reset the termination condition.
To gracefully stop the team, use :class:`~autogen_agentchat.conditions.ExternalTermination` instead.

Returns:
result: The result of the task as :class:`~autogen_agentchat.base.TaskResult`. The result contains the messages produced by the team and the stop reason.

Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:

@@ -279,9 +282,15 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
"""Run the team and produces a stream of messages and the final result
of the type :class:`TaskResult` as the last item in the stream. Once the
of the type :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream. Once the
team is stopped, the termination condition is reset.

.. note::

If an agent produces :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`,
the message will be yielded in the stream but it will not be included in the
:attr:`~autogen_agentchat.base.TaskResult.messages`.

Args:
task (str | ChatMessage | Sequence[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.

@@ -289,6 +298,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
and it may not reset the termination condition.
To gracefully stop the team, use :class:`~autogen_agentchat.conditions.ExternalTermination` instead.

Returns:
stream: an :class:`~collections.abc.AsyncGenerator` that yields :class:`~autogen_agentchat.messages.AgentEvent`, :class:`~autogen_agentchat.messages.ChatMessage`, and the final result :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream.

Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:

.. code-block:: python

@@ -422,6 +434,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
if message is None:
break
yield message
if isinstance(message, ModelClientStreamingChunkEvent):
# Skip the model client streaming chunk events.
continue
output_messages.append(message)

# Yield the final result.

@@ -10,7 +10,13 @@ from autogen_core.models import RequestUsage

from autogen_agentchat.agents import UserProxyAgent
from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, UserInputRequestedEvent
from autogen_agentchat.messages import (
AgentEvent,
ChatMessage,
ModelClientStreamingChunkEvent,
MultiModalMessage,
UserInputRequestedEvent,
)


def _is_running_in_iterm() -> bool:

@@ -106,6 +112,8 @@ async def Console(
last_processed: Optional[T] = None

streaming_chunks: List[str] = []

async for message in stream:
if isinstance(message, TaskResult):
duration = time.time() - start_time

@@ -159,13 +167,28 @@ async def Console(
else:
# Cast required for mypy to be happy
message = cast(AgentEvent | ChatMessage, message) # type: ignore
output = f"{'-' * 10} {message.source} {'-' * 10}\n{_message_to_str(message, render_image_iterm=render_image_iterm)}\n"
if message.models_usage:
if output_stats:
output += f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]\n"
total_usage.completion_tokens += message.models_usage.completion_tokens
total_usage.prompt_tokens += message.models_usage.prompt_tokens
await aprint(output, end="")
if not streaming_chunks:
# Print message sender.
await aprint(f"{'-' * 10} {message.source} {'-' * 10}", end="\n")
if isinstance(message, ModelClientStreamingChunkEvent):
await aprint(message.content, end="")
streaming_chunks.append(message.content)
else:
if streaming_chunks:
streaming_chunks.clear()
# Chunked messages are already printed, so we just print a newline.
await aprint("", end="\n")
else:
# Print message content.
await aprint(_message_to_str(message, render_image_iterm=render_image_iterm), end="\n")
if message.models_usage:
if output_stats:
await aprint(
f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]",
end="\n",
)
total_usage.completion_tokens += message.models_usage.completion_tokens
total_usage.prompt_tokens += message.models_usage.prompt_tokens

if last_processed is None:
raise ValueError("No TaskResult or Response was processed.")
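To illustrate the updated `Console` behavior above, a hedged sketch of wiring it to a streaming agent; the `autogen_agentchat.ui.Console` import path and the `gpt-4o` model are assumptions carried over from the tutorial notebook in this commit, and the OpenAI client needs an API key.

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    model_client = OpenAIChatCompletionClient(model="gpt-4o")  # Assumes OPENAI_API_KEY is set.
    agent = AssistantAgent(
        "assistant",
        model_client=model_client,
        model_client_stream=True,  # Chunks are printed as they arrive.
    )
    # Console prints the sender header once, appends each chunk without a
    # newline, and closes the line when a non-chunk message arrives.
    await Console(agent.run_stream(task="Name two cities in South America."))


asyncio.run(main())
```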
@@ -11,6 +11,7 @@ from autogen_agentchat.messages import (
ChatMessage,
HandoffMessage,
MemoryQueryEvent,
ModelClientStreamingChunkEvent,
MultiModalMessage,
TextMessage,
ToolCallExecutionEvent,

@@ -20,10 +21,11 @@ from autogen_agentchat.messages import (
from autogen_core import FunctionCall, Image
from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.models import FunctionExecutionResult, LLMMessage
from autogen_core.models import CreateResult, FunctionExecutionResult, LLMMessage, RequestUsage
from autogen_core.models._model_client import ModelFamily
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.models.replay import ReplayChatCompletionClient
from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

@@ -776,3 +778,65 @@ async def test_assistant_agent_declarative(monkeypatch: pytest.MonkeyPatch) -> N
)
agent3_config = agent3.dump_component()
assert agent3_config.provider == "autogen_agentchat.agents.AssistantAgent"


@pytest.mark.asyncio
async def test_model_client_stream() -> None:
mock_client = ReplayChatCompletionClient(
[
"Response to message 3",
]
)
agent = AssistantAgent(
"test_agent",
model_client=mock_client,
model_client_stream=True,
)
chunks: List[str] = []
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message.messages[-1].content == "Response to message 3"
elif isinstance(message, ModelClientStreamingChunkEvent):
chunks.append(message.content)
assert "".join(chunks) == "Response to message 3"


@pytest.mark.asyncio
async def test_model_client_stream_with_tool_calls() -> None:
mock_client = ReplayChatCompletionClient(
[
CreateResult(
content=[
FunctionCall(id="1", name="_pass_function", arguments=r'{"input": "task"}'),
FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'),
],
finish_reason="function_calls",
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
cached=False,
),
"Example response 2 to task",
]
)
mock_client._model_info["function_calling"] = True # pyright: ignore
agent = AssistantAgent(
"test_agent",
model_client=mock_client,
model_client_stream=True,
reflect_on_tool_use=True,
tools=[_pass_function, _echo_function],
)
chunks: List[str] = []
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message.messages[-1].content == "Example response 2 to task"
assert message.messages[1].content == [
FunctionCall(id="1", name="_pass_function", arguments=r'{"input": "task"}'),
FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'),
]
assert message.messages[2].content == [
FunctionExecutionResult(call_id="1", content="pass"),
FunctionExecutionResult(call_id="3", content="task"),
]
elif isinstance(message, ModelClientStreamingChunkEvent):
chunks.append(message.content)
assert "".join(chunks) == "Example response 2 to task"

@@ -403,6 +403,117 @@
"await Console(agent.run_stream(task=\"I am happy.\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Streaming Tokens\n",
"\n",
"You can stream the tokens generated by the model client by setting `model_client_stream=True`.\n",
"This will cause the agent to yield {py:class}`~autogen_agentchat.messages.ModelClientStreamingChunkEvent` messages\n",
"in {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream` and {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run_stream`.\n",
"\n",
"The underlying model API must support streaming tokens for this to work.\n",
"Please check with your model provider to see if this is supported."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"source='assistant' models_usage=None content='Two' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' cities' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' South' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' America' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' are' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' Buenos' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' Aires' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' Argentina' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' and' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' São' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' Paulo' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' Brazil' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content='.' type='ModelClientStreamingChunkEvent'\n",
"Response(chat_message=TextMessage(source='assistant', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), content='Two cities in South America are Buenos Aires in Argentina and São Paulo in Brazil.', type='TextMessage'), inner_messages=[])\n"
]
}
],
"source": [
"model_client = OpenAIChatCompletionClient(model=\"gpt-4o\")\n",
"\n",
"streaming_assistant = AssistantAgent(\n",
" name=\"assistant\",\n",
" model_client=model_client,\n",
" system_message=\"You are a helpful assistant.\",\n",
" model_client_stream=True, # Enable streaming tokens.\n",
")\n",
"\n",
"# Use an async function and asyncio.run() in a script.\n",
"async for message in streaming_assistant.on_messages_stream( # type: ignore\n",
" [TextMessage(content=\"Name two cities in South America\", source=\"user\")],\n",
" cancellation_token=CancellationToken(),\n",
"):\n",
" print(message)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can see the streaming chunks in the output above.\n",
"The chunks are generated by the model client and are yielded by the agent as they are received.\n",
"The final response, the concatenation of all the chunks, is yielded right after the last chunk.\n",
"\n",
"Similarly, {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run_stream` will also yield the same streaming chunks,\n",
"followed by a full text message right after the last chunk."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"source='user' models_usage=None content='Name two cities in North America.' type='TextMessage'\n",
"source='assistant' models_usage=None content='Two' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' cities' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' North' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' America' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' are' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' New' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' York' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' City' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' the' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' United' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' States' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' and' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' Toronto' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content=' Canada' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=None content='.' type='ModelClientStreamingChunkEvent'\n",
"source='assistant' models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0) content='Two cities in North America are New York City in the United States and Toronto in Canada.' type='TextMessage'\n",
"TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Name two cities in North America.', type='TextMessage'), TextMessage(source='assistant', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), content='Two cities in North America are New York City in the United States and Toronto in Canada.', type='TextMessage')], stop_reason=None)\n"
]
}
],
"source": [
"async for message in streaming_assistant.run_stream(task=\"Name two cities in North America.\"): # type: ignore\n",
" print(message)"
]
},
{
"cell_type": "markdown",
"metadata": {},

@@ -185,6 +185,9 @@ class ReplayChatCompletionClient(ChatCompletionClient):
yield token + " "
else:
yield token
yield CreateResult(
finish_reason="stop", content=response, usage=self._cur_usage, cached=self._cached_bool_value
)
self._update_total_usage()
else:
self._cur_usage = RequestUsage(

@@ -226,7 +229,7 @@ class ReplayChatCompletionClient(ChatCompletionClient):
total_tokens += len(tokens)
all_tokens.extend(tokens)
else:
logger.warning("Token count has been done only on string content", RuntimeWarning)
logger.warning("Token count has been done only on string content")
elif isinstance(messages, Sequence):
for message in messages:
if isinstance(message.content, str): # type: ignore [reportAttributeAccessIssue, union-attr]

@@ -234,7 +237,7 @@ class ReplayChatCompletionClient(ChatCompletionClient):
total_tokens += len(tokens)
all_tokens.extend(tokens)
else:
logger.warning("Token count has been done only on string content", RuntimeWarning)
logger.warning("Token count has been done only on string content")
return all_tokens, total_tokens

def _update_total_usage(self) -> None:

@@ -16,6 +16,7 @@ from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.messages import (
AgentEvent,
ChatMessage,
ModelClientStreamingChunkEvent,
MultiModalMessage,
UserInputRequestedEvent,
)

@@ -185,6 +186,9 @@ async def RichConsole(
elif isinstance(message, UserInputRequestedEvent):
if user_input_manager is not None:
user_input_manager.notify_event_received(message.request_id)
elif isinstance(message, ModelClientStreamingChunkEvent):
# TODO: Handle model client streaming chunk events.
pass
else:
# Cast required for mypy to be happy
message = cast(AgentEvent | ChatMessage, message) # type: ignore

@@ -107,14 +107,14 @@ async def test_cache_create_stream() -> None:
async for completion in cached_client.create_stream(
[system_prompt, UserMessage(content=prompts[0], source="user")]
):
original_streamed_results.append(completion)
original_streamed_results.append(copy.copy(completion))
total_usage0 = copy.copy(cached_client.total_usage())

cached_completion_results: List[Union[str, CreateResult]] = []
async for completion in cached_client.create_stream(
[system_prompt, UserMessage(content=prompts[0], source="user")]
):
cached_completion_results.append(completion)
cached_completion_results.append(copy.copy(completion))
total_usage1 = copy.copy(cached_client.total_usage())

assert total_usage1.prompt_tokens == total_usage0.prompt_tokens

@@ -67,12 +67,16 @@ async def test_replay_chat_completion_client_create_stream() -> None:
reply_model_client = ReplayChatCompletionClient(messages)

for i in range(num_messages):
result: List[str] = []
chunks: List[str] = []
result: CreateResult | None = None
async for completion in reply_model_client.create_stream([UserMessage(content="dummy", source="_")]):
text = completion.content if isinstance(completion, CreateResult) else completion
assert isinstance(text, str)
result.append(text)
assert "".join(result) == messages[i]
if isinstance(completion, CreateResult):
result = completion
else:
assert isinstance(completion, str)
chunks.append(completion)
assert result is not None
assert "".join(chunks) == messages[i] == result.content

with pytest.raises(ValueError, match="No more mock responses available"):
await reply_model_client.create([UserMessage(content="dummy", source="_")])
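As a standalone illustration of the consumption pattern the updated test encodes (the replayed string is illustrative): `create_stream` yields string chunks and finishes with a `CreateResult` whose `content` equals the concatenated chunks.

```python
import asyncio

from autogen_core.models import CreateResult, UserMessage
from autogen_ext.models.replay import ReplayChatCompletionClient


async def main() -> None:
    client = ReplayChatCompletionClient(["mock response number one"])
    chunks: list[str] = []
    final: CreateResult | None = None
    async for item in client.create_stream([UserMessage(content="hi", source="user")]):
        if isinstance(item, CreateResult):
            final = item  # Final result carries the full content and usage.
        else:
            chunks.append(item)  # Intermediate string chunk.
    assert final is not None
    assert "".join(chunks) == final.content


asyncio.run(main())
```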