mirror of https://github.com/microsoft/autogen.git
feat: introduce ModelClientStreamingChunkEvent for streaming model output and update handling in agents and console (#5208)
Resolves #3983

* Introduce `model_client_stream` parameter in `AssistantAgent` to enable token-level streaming output.
* Introduce `ModelClientStreamingChunkEvent` as a type of `AgentEvent` to pass the streaming chunks to the application via `run_stream` and `on_messages_stream`. These chunk events are not included in the inner messages list of the final `Response` or `TaskResult`.
* Handle this new message type in `Console`.
Parent: 8a0daf8285
Commit: 225eb9d0b2
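As context for the diff below, here is a minimal usage sketch of the new streaming mode, adapted from the notebook example and tests added in this commit; the model name and task string are placeholders.

import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import ModelClientStreamingChunkEvent
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    agent = AssistantAgent(
        name="assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o"),
        model_client_stream=True,  # Enable token-level streaming output.
    )
    async for message in agent.run_stream(task="Name two cities in South America."):
        if isinstance(message, ModelClientStreamingChunkEvent):
            # Partial text from the model client; not included in TaskResult.messages.
            print(message.content, end="", flush=True)
        elif isinstance(message, TaskResult):
            print("\nMessages in result:", len(message.messages))


asyncio.run(main())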
@@ -22,13 +22,14 @@ from autogen_core.model_context import (
 from autogen_core.models import (
     AssistantMessage,
     ChatCompletionClient,
+    CreateResult,
     FunctionExecutionResult,
     FunctionExecutionResultMessage,
     LLMMessage,
     SystemMessage,
     UserMessage,
 )
-from autogen_core.tools import FunctionTool, BaseTool
+from autogen_core.tools import BaseTool, FunctionTool
 from pydantic import BaseModel
 from typing_extensions import Self
 
@@ -40,6 +41,7 @@ from ..messages import (
     ChatMessage,
     HandoffMessage,
     MemoryQueryEvent,
+    ModelClientStreamingChunkEvent,
     MultiModalMessage,
     TextMessage,
     ToolCallExecutionEvent,
@@ -62,6 +64,7 @@ class AssistantAgentConfig(BaseModel):
     model_context: ComponentModel | None = None
     description: str
     system_message: str | None = None
+    model_client_stream: bool
     reflect_on_tool_use: bool
     tool_call_summary_format: str
 
@@ -126,6 +129,14 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         This will limit the number of recent messages sent to the model and can be useful
         when the model has a limit on the number of tokens it can process.
 
+    Streaming mode:
+
+        The assistant agent can be used in streaming mode by setting `model_client_stream=True`.
+        In this mode, the :meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield
+        :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`
+        messages as the model client produces chunks of response.
+        The chunk messages will not be included in the final response's inner messages.
+
 
     Args:
         name (str): The name of the agent.
@@ -138,6 +149,9 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
        model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving :class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages. The initial messages will be cleared when the agent is reset.
        description (str, optional): The description of the agent.
        system_message (str, optional): The system message for the model. If provided, it will be prepended to the messages in the model context when making an inference. Set to `None` to disable.
+        model_client_stream (bool, optional): If `True`, the model client will be used in streaming mode.
+            :meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`
+            messages as the model client produces chunks of response. Defaults to `False`.
        reflect_on_tool_use (bool, optional): If `True`, the agent will make another model inference using the tool call and result
            to generate a response. If `False`, the tool call result will be returned as the response. Defaults to `False`.
        tool_call_summary_format (str, optional): The format string used to create a tool call summary for every tool call result.
@@ -268,12 +282,14 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         system_message: (
             str | None
         ) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
+        model_client_stream: bool = False,
         reflect_on_tool_use: bool = False,
         tool_call_summary_format: str = "{result}",
         memory: Sequence[Memory] | None = None,
     ):
         super().__init__(name=name, description=description)
         self._model_client = model_client
+        self._model_client_stream = model_client_stream
         self._memory = None
         if memory is not None:
             if isinstance(memory, list):
@@ -340,7 +356,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
 
     @property
     def produced_message_types(self) -> Sequence[type[ChatMessage]]:
-        """The types of messages that the assistant agent produces."""
+        """The types of final response messages that the assistant agent produces."""
         message_types: List[type[ChatMessage]] = [TextMessage]
         if self._handoffs:
             message_types.append(HandoffMessage)
@@ -383,9 +399,23 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
 
         # Generate an inference result based on the current model context.
         llm_messages = self._system_messages + await self._model_context.get_messages()
-        model_result = await self._model_client.create(
-            llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
-        )
+        model_result: CreateResult | None = None
+        if self._model_client_stream:
+            # Stream the model client.
+            async for chunk in self._model_client.create_stream(
+                llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
+            ):
+                if isinstance(chunk, CreateResult):
+                    model_result = chunk
+                elif isinstance(chunk, str):
+                    yield ModelClientStreamingChunkEvent(content=chunk, source=self.name)
+                else:
+                    raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
+            assert isinstance(model_result, CreateResult)
+        else:
+            model_result = await self._model_client.create(
+                llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
+            )
 
         # Add the response to the model context.
         await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
@@ -465,14 +495,34 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         if self._reflect_on_tool_use:
             # Generate another inference result based on the tool call and result.
             llm_messages = self._system_messages + await self._model_context.get_messages()
-            model_result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
-            assert isinstance(model_result.content, str)
+            reflection_model_result: CreateResult | None = None
+            if self._model_client_stream:
+                # Stream the model client.
+                async for chunk in self._model_client.create_stream(
+                    llm_messages, cancellation_token=cancellation_token
+                ):
+                    if isinstance(chunk, CreateResult):
+                        reflection_model_result = chunk
+                    elif isinstance(chunk, str):
+                        yield ModelClientStreamingChunkEvent(content=chunk, source=self.name)
+                    else:
+                        raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
+                assert isinstance(reflection_model_result, CreateResult)
+            else:
+                reflection_model_result = await self._model_client.create(
+                    llm_messages, cancellation_token=cancellation_token
+                )
+            assert isinstance(reflection_model_result.content, str)
             # Add the response to the model context.
-            await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
+            await self._model_context.add_message(
+                AssistantMessage(content=reflection_model_result.content, source=self.name)
+            )
             # Yield the response.
             yield Response(
                 chat_message=TextMessage(
-                    content=model_result.content, source=self.name, models_usage=model_result.usage
+                    content=reflection_model_result.content,
+                    source=self.name,
+                    models_usage=reflection_model_result.usage,
                 ),
                 inner_messages=inner_messages,
             )
@@ -538,6 +588,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             system_message=self._system_messages[0].content
            if self._system_messages and isinstance(self._system_messages[0].content, str)
            else None,
+            model_client_stream=self._model_client_stream,
            reflect_on_tool_use=self._reflect_on_tool_use,
            tool_call_summary_format=self._tool_call_summary_format,
        )
@@ -553,6 +604,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
            model_context=None,
            description=config.description,
            system_message=config.system_message,
+            model_client_stream=config.model_client_stream,
            reflect_on_tool_use=config.reflect_on_tool_use,
            tool_call_summary_format=config.tool_call_summary_format,
        )

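The core of the change above is the branch on `self._model_client_stream`: `create_stream` yields string chunks followed by a final `CreateResult`, and the agent re-emits each string chunk as a `ModelClientStreamingChunkEvent`. Below is a standalone sketch of that consumption pattern; the helper name `collect_stream` is illustrative and not part of the API.

from typing import List

from autogen_core.models import ChatCompletionClient, CreateResult, LLMMessage


async def collect_stream(client: ChatCompletionClient, messages: List[LLMMessage]) -> CreateResult:
    # String chunks arrive first; the last item is a CreateResult carrying the full
    # content and usage, mirroring how AssistantAgent handles create_stream above.
    final: CreateResult | None = None
    async for chunk in client.create_stream(messages):
        if isinstance(chunk, CreateResult):
            final = chunk
        elif isinstance(chunk, str):
            print(chunk, end="")  # The agent yields a ModelClientStreamingChunkEvent here instead.
        else:
            raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
    assert final is not None
    return final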
@@ -9,6 +9,7 @@ from ..messages import (
     AgentEvent,
     BaseChatMessage,
     ChatMessage,
+    ModelClientStreamingChunkEvent,
     TextMessage,
 )
 from ..state import BaseState
@@ -178,8 +179,11 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
                 output_messages.append(message.chat_message)
                 yield TaskResult(messages=output_messages)
             else:
-                output_messages.append(message)
                 yield message
+                if isinstance(message, ModelClientStreamingChunkEvent):
+                    # Skip the model client streaming chunk events.
+                    continue
+                output_messages.append(message)
 
     @abstractmethod
     async def on_reset(self, cancellation_token: CancellationToken) -> None:

@@ -13,6 +13,7 @@ from ..messages import (
     AgentEvent,
     BaseChatMessage,
     ChatMessage,
+    ModelClientStreamingChunkEvent,
     TextMessage,
 )
 from ._base_chat_agent import BaseChatAgent
@@ -150,6 +151,9 @@ class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]):
                     # Skip the task messages.
                     continue
                 yield inner_msg
+                if isinstance(inner_msg, ModelClientStreamingChunkEvent):
+                    # Skip the model client streaming chunk events.
+                    continue
                 inner_messages.append(inner_msg)
             assert result is not None
 
@@ -1,7 +1,7 @@
 import logging
 from typing import Any, Dict
 
-from autogen_core.tools import FunctionTool, BaseTool
+from autogen_core.tools import BaseTool, FunctionTool
 from pydantic import BaseModel, Field, model_validator
 
 from .. import EVENT_LOGGER_NAME

@@ -128,6 +128,15 @@ class MemoryQueryEvent(BaseAgentEvent):
     type: Literal["MemoryQueryEvent"] = "MemoryQueryEvent"
 
 
+class ModelClientStreamingChunkEvent(BaseAgentEvent):
+    """An event signaling a text output chunk from a model client in streaming mode."""
+
+    content: str
+    """The partial text chunk."""
+
+    type: Literal["ModelClientStreamingChunkEvent"] = "ModelClientStreamingChunkEvent"
+
+
 ChatMessage = Annotated[
     TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
 ]
@@ -135,7 +144,11 @@ ChatMessage = Annotated[
 
 
 AgentEvent = Annotated[
-    ToolCallRequestEvent | ToolCallExecutionEvent | MemoryQueryEvent | UserInputRequestedEvent,
+    ToolCallRequestEvent
+    | ToolCallExecutionEvent
+    | MemoryQueryEvent
+    | UserInputRequestedEvent
+    | ModelClientStreamingChunkEvent,
     Field(discriminator="type"),
 ]
 """Events emitted by agents and teams when they work, not used for agent-to-agent communication."""
@@ -154,4 +167,5 @@ __all__ = [
     "ToolCallSummaryMessage",
     "MemoryQueryEvent",
     "UserInputRequestedEvent",
+    "ModelClientStreamingChunkEvent",
 ]

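The new event is a plain Pydantic model discriminated by its `type` field, so it serializes and round-trips like the other `AgentEvent` members. An illustrative construction follows; the field values are arbitrary.

from autogen_agentchat.messages import ModelClientStreamingChunkEvent

chunk_event = ModelClientStreamingChunkEvent(content="Hello, ", source="assistant")
print(chunk_event.type)          # "ModelClientStreamingChunkEvent"
print(chunk_event.model_dump())  # Serializes like the other AgentEvent variants.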
@@ -21,7 +21,7 @@ from pydantic import BaseModel
 
 from ... import EVENT_LOGGER_NAME
 from ...base import ChatAgent, TaskResult, Team, TerminationCondition
-from ...messages import AgentEvent, BaseChatMessage, ChatMessage, TextMessage
+from ...messages import AgentEvent, BaseChatMessage, ChatMessage, ModelClientStreamingChunkEvent, TextMessage
 from ...state import TeamState
 from ._chat_agent_container import ChatAgentContainer
 from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
@@ -190,6 +190,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
                 and it may not reset the termination condition.
                 To gracefully stop the team, use :class:`~autogen_agentchat.conditions.ExternalTermination` instead.
 
+        Returns:
+            result: The result of the task as :class:`~autogen_agentchat.base.TaskResult`. The result contains the messages produced by the team and the stop reason.
+
         Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:
 
 
@@ -279,9 +282,15 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
         cancellation_token: CancellationToken | None = None,
     ) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
         """Run the team and produces a stream of messages and the final result
-        of the type :class:`TaskResult` as the last item in the stream. Once the
+        of the type :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream. Once the
         team is stopped, the termination condition is reset.
 
+        .. note::
+
+            If an agent produces :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`,
+            the message will be yielded in the stream but it will not be included in the
+            :attr:`~autogen_agentchat.base.TaskResult.messages`.
+
         Args:
             task (str | ChatMessage | Sequence[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
             cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
@@ -289,6 +298,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
                 and it may not reset the termination condition.
                 To gracefully stop the team, use :class:`~autogen_agentchat.conditions.ExternalTermination` instead.
 
+        Returns:
+            stream: an :class:`~collections.abc.AsyncGenerator` that yields :class:`~autogen_agentchat.messages.AgentEvent`, :class:`~autogen_agentchat.messages.ChatMessage`, and the final result :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream.
+
         Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:
 
         .. code-block:: python
@@ -422,6 +434,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
                 if message is None:
                     break
                 yield message
+                if isinstance(message, ModelClientStreamingChunkEvent):
+                    # Skip the model client streaming chunk events.
+                    continue
                 output_messages.append(message)
 
             # Yield the final result.

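As the note added to `run_stream` states, a team streams chunk events through but excludes them from the final `TaskResult.messages`. A sketch of a consumer relying on that contract; team construction is omitted and assumed to contain agents created with `model_client_stream=True`.

from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import ModelClientStreamingChunkEvent
from autogen_agentchat.teams import RoundRobinGroupChat


async def consume(team: RoundRobinGroupChat, task: str) -> None:
    async for item in team.run_stream(task=task):
        if isinstance(item, ModelClientStreamingChunkEvent):
            print(item.content, end="", flush=True)  # Live token output.
        elif isinstance(item, TaskResult):
            # Chunk events were skipped when output_messages was collected above.
            assert not any(isinstance(m, ModelClientStreamingChunkEvent) for m in item.messages)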
@@ -10,7 +10,13 @@ from autogen_core.models import RequestUsage
 
 from autogen_agentchat.agents import UserProxyAgent
 from autogen_agentchat.base import Response, TaskResult
-from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, UserInputRequestedEvent
+from autogen_agentchat.messages import (
+    AgentEvent,
+    ChatMessage,
+    ModelClientStreamingChunkEvent,
+    MultiModalMessage,
+    UserInputRequestedEvent,
+)
 
 
 def _is_running_in_iterm() -> bool:
@@ -106,6 +112,8 @@ async def Console(
 
     last_processed: Optional[T] = None
 
+    streaming_chunks: List[str] = []
+
     async for message in stream:
         if isinstance(message, TaskResult):
             duration = time.time() - start_time
@@ -159,13 +167,28 @@ async def Console(
         else:
             # Cast required for mypy to be happy
             message = cast(AgentEvent | ChatMessage, message)  # type: ignore
-            output = f"{'-' * 10} {message.source} {'-' * 10}\n{_message_to_str(message, render_image_iterm=render_image_iterm)}\n"
-            if message.models_usage:
-                if output_stats:
-                    output += f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]\n"
-                total_usage.completion_tokens += message.models_usage.completion_tokens
-                total_usage.prompt_tokens += message.models_usage.prompt_tokens
-            await aprint(output, end="")
+            if not streaming_chunks:
+                # Print message sender.
+                await aprint(f"{'-' * 10} {message.source} {'-' * 10}", end="\n")
+            if isinstance(message, ModelClientStreamingChunkEvent):
+                await aprint(message.content, end="")
+                streaming_chunks.append(message.content)
+            else:
+                if streaming_chunks:
+                    streaming_chunks.clear()
+                    # Chunked messages are already printed, so we just print a newline.
+                    await aprint("", end="\n")
+                else:
+                    # Print message content.
+                    await aprint(_message_to_str(message, render_image_iterm=render_image_iterm), end="\n")
+                if message.models_usage:
+                    if output_stats:
+                        await aprint(
+                            f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]",
+                            end="\n",
+                        )
+                    total_usage.completion_tokens += message.models_usage.completion_tokens
+                    total_usage.prompt_tokens += message.models_usage.prompt_tokens
 
     if last_processed is None:
         raise ValueError("No TaskResult or Response was processed.")

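With these changes, `Console` prints chunk contents inline as they arrive, buffers them in `streaming_chunks`, and skips reprinting the full message that follows the last chunk. Usage stays the same as before; a brief sketch, assuming `streaming_agent` was created with `model_client_stream=True`.

from autogen_agentchat.ui import Console


async def show(streaming_agent) -> None:
    # Tokens render incrementally as ModelClientStreamingChunkEvent messages arrive.
    await Console(streaming_agent.run_stream(task="Write a short greeting."))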
@@ -11,6 +11,7 @@ from autogen_agentchat.messages import (
     ChatMessage,
     HandoffMessage,
     MemoryQueryEvent,
+    ModelClientStreamingChunkEvent,
     MultiModalMessage,
     TextMessage,
     ToolCallExecutionEvent,
@@ -20,10 +21,11 @@ from autogen_agentchat.messages import (
 from autogen_core import FunctionCall, Image
 from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
 from autogen_core.model_context import BufferedChatCompletionContext
-from autogen_core.models import FunctionExecutionResult, LLMMessage
+from autogen_core.models import CreateResult, FunctionExecutionResult, LLMMessage, RequestUsage
 from autogen_core.models._model_client import ModelFamily
 from autogen_core.tools import FunctionTool
 from autogen_ext.models.openai import OpenAIChatCompletionClient
+from autogen_ext.models.replay import ReplayChatCompletionClient
 from openai.resources.chat.completions import AsyncCompletions
 from openai.types.chat.chat_completion import ChatCompletion, Choice
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -776,3 +778,65 @@ async def test_assistant_agent_declarative(monkeypatch: pytest.MonkeyPatch) -> None:
     )
     agent3_config = agent3.dump_component()
     assert agent3_config.provider == "autogen_agentchat.agents.AssistantAgent"
+
+
+@pytest.mark.asyncio
+async def test_model_client_stream() -> None:
+    mock_client = ReplayChatCompletionClient(
+        [
+            "Response to message 3",
+        ]
+    )
+    agent = AssistantAgent(
+        "test_agent",
+        model_client=mock_client,
+        model_client_stream=True,
+    )
+    chunks: List[str] = []
+    async for message in agent.run_stream(task="task"):
+        if isinstance(message, TaskResult):
+            assert message.messages[-1].content == "Response to message 3"
+        elif isinstance(message, ModelClientStreamingChunkEvent):
+            chunks.append(message.content)
+    assert "".join(chunks) == "Response to message 3"
+
+
+@pytest.mark.asyncio
+async def test_model_client_stream_with_tool_calls() -> None:
+    mock_client = ReplayChatCompletionClient(
+        [
+            CreateResult(
+                content=[
+                    FunctionCall(id="1", name="_pass_function", arguments=r'{"input": "task"}'),
+                    FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'),
+                ],
+                finish_reason="function_calls",
+                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
+                cached=False,
+            ),
+            "Example response 2 to task",
+        ]
+    )
+    mock_client._model_info["function_calling"] = True  # pyright: ignore
+    agent = AssistantAgent(
+        "test_agent",
+        model_client=mock_client,
+        model_client_stream=True,
+        reflect_on_tool_use=True,
+        tools=[_pass_function, _echo_function],
+    )
+    chunks: List[str] = []
+    async for message in agent.run_stream(task="task"):
+        if isinstance(message, TaskResult):
+            assert message.messages[-1].content == "Example response 2 to task"
+            assert message.messages[1].content == [
+                FunctionCall(id="1", name="_pass_function", arguments=r'{"input": "task"}'),
+                FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'),
+            ]
+            assert message.messages[2].content == [
+                FunctionExecutionResult(call_id="1", content="pass"),
+                FunctionExecutionResult(call_id="3", content="task"),
+            ]
+        elif isinstance(message, ModelClientStreamingChunkEvent):
+            chunks.append(message.content)
+    assert "".join(chunks) == "Example response 2 to task"

@@ -403,6 +403,117 @@
     "await Console(agent.run_stream(task=\"I am happy.\"))"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Streaming Tokens\n",
+    "\n",
+    "You can stream the tokens generated by the model client by setting `model_client_stream=True`.\n",
+    "This will cause the agent to yield {py:class}`~autogen_agentchat.messages.ModelClientStreamingChunkEvent` messages\n",
+    "in {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream` and {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run_stream`.\n",
+    "\n",
+    "The underlying model API must support streaming tokens for this to work.\n",
+    "Please check with your model provider to see if this is supported."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "source='assistant' models_usage=None content='Two' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' cities' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' South' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' America' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' are' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' Buenos' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' Aires' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' Argentina' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' and' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' São' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' Paulo' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' Brazil' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content='.' type='ModelClientStreamingChunkEvent'\n",
+      "Response(chat_message=TextMessage(source='assistant', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), content='Two cities in South America are Buenos Aires in Argentina and São Paulo in Brazil.', type='TextMessage'), inner_messages=[])\n"
+     ]
+    }
+   ],
+   "source": [
+    "model_client = OpenAIChatCompletionClient(model=\"gpt-4o\")\n",
+    "\n",
+    "streaming_assistant = AssistantAgent(\n",
+    "    name=\"assistant\",\n",
+    "    model_client=model_client,\n",
+    "    system_message=\"You are a helpful assistant.\",\n",
+    "    model_client_stream=True,  # Enable streaming tokens.\n",
+    ")\n",
+    "\n",
+    "# Use an async function and asyncio.run() in a script.\n",
+    "async for message in streaming_assistant.on_messages_stream(  # type: ignore\n",
+    "    [TextMessage(content=\"Name two cities in South America\", source=\"user\")],\n",
+    "    cancellation_token=CancellationToken(),\n",
+    "):\n",
+    "    print(message)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You can see the streaming chunks in the output above.\n",
+    "The chunks are generated by the model client and are yielded by the agent as they are received.\n",
+    "The final response, the concatenation of all the chunks, is yielded right after the last chunk.\n",
+    "\n",
+    "Similarly, {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run_stream` will also yield the same streaming chunks,\n",
+    "followed by a full text message right after the last chunk."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "source='user' models_usage=None content='Name two cities in North America.' type='TextMessage'\n",
+      "source='assistant' models_usage=None content='Two' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' cities' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' North' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' America' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' are' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' New' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' York' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' City' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' the' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' United' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' States' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' and' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' Toronto' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content=' Canada' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=None content='.' type='ModelClientStreamingChunkEvent'\n",
+      "source='assistant' models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0) content='Two cities in North America are New York City in the United States and Toronto in Canada.' type='TextMessage'\n",
+      "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Name two cities in North America.', type='TextMessage'), TextMessage(source='assistant', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), content='Two cities in North America are New York City in the United States and Toronto in Canada.', type='TextMessage')], stop_reason=None)\n"
+     ]
+    }
+   ],
+   "source": [
+    "async for message in streaming_assistant.run_stream(task=\"Name two cities in North America.\"):  # type: ignore\n",
+    "    print(message)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},

@@ -185,6 +185,9 @@ class ReplayChatCompletionClient(ChatCompletionClient):
                     yield token + " "
                 else:
                     yield token
+            yield CreateResult(
+                finish_reason="stop", content=response, usage=self._cur_usage, cached=self._cached_bool_value
+            )
             self._update_total_usage()
         else:
             self._cur_usage = RequestUsage(
@@ -226,7 +229,7 @@ class ReplayChatCompletionClient(ChatCompletionClient):
                 total_tokens += len(tokens)
                 all_tokens.extend(tokens)
             else:
-                logger.warning("Token count has been done only on string content", RuntimeWarning)
+                logger.warning("Token count has been done only on string content")
         elif isinstance(messages, Sequence):
             for message in messages:
                 if isinstance(message.content, str):  # type: ignore [reportAttributeAccessIssue, union-attr]
@@ -234,7 +237,7 @@ class ReplayChatCompletionClient(ChatCompletionClient):
                 total_tokens += len(tokens)
                 all_tokens.extend(tokens)
             else:
-                logger.warning("Token count has been done only on string content", RuntimeWarning)
+                logger.warning("Token count has been done only on string content")
         return all_tokens, total_tokens
 
     def _update_total_usage(self) -> None:

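`ReplayChatCompletionClient.create_stream` now finishes with a `CreateResult` after the string tokens, matching the protocol the streaming agent relies on. A sketch mirroring the updated replay-client test below; the replay content is a placeholder.

from autogen_core.models import CreateResult, UserMessage
from autogen_ext.models.replay import ReplayChatCompletionClient


async def demo() -> None:
    client = ReplayChatCompletionClient(["Hello world from replay"])
    chunks: list[str] = []
    final: CreateResult | None = None
    async for item in client.create_stream([UserMessage(content="hi", source="user")]):
        if isinstance(item, CreateResult):
            final = item  # Last item carries the full response and usage.
        else:
            chunks.append(item)
    assert final is not None
    assert "".join(chunks) == final.content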
@@ -16,6 +16,7 @@ from autogen_agentchat.base import Response, TaskResult
 from autogen_agentchat.messages import (
     AgentEvent,
     ChatMessage,
+    ModelClientStreamingChunkEvent,
     MultiModalMessage,
     UserInputRequestedEvent,
 )
@@ -185,6 +186,9 @@ async def RichConsole(
         elif isinstance(message, UserInputRequestedEvent):
             if user_input_manager is not None:
                 user_input_manager.notify_event_received(message.request_id)
+        elif isinstance(message, ModelClientStreamingChunkEvent):
+            # TODO: Handle model client streaming chunk events.
+            pass
         else:
             # Cast required for mypy to be happy
             message = cast(AgentEvent | ChatMessage, message)  # type: ignore

@@ -107,14 +107,14 @@ async def test_cache_create_stream() -> None:
     async for completion in cached_client.create_stream(
         [system_prompt, UserMessage(content=prompts[0], source="user")]
     ):
-        original_streamed_results.append(completion)
+        original_streamed_results.append(copy.copy(completion))
     total_usage0 = copy.copy(cached_client.total_usage())
 
     cached_completion_results: List[Union[str, CreateResult]] = []
     async for completion in cached_client.create_stream(
         [system_prompt, UserMessage(content=prompts[0], source="user")]
     ):
-        cached_completion_results.append(completion)
+        cached_completion_results.append(copy.copy(completion))
     total_usage1 = copy.copy(cached_client.total_usage())
 
     assert total_usage1.prompt_tokens == total_usage0.prompt_tokens

@@ -67,12 +67,16 @@ async def test_replay_chat_completion_client_create_stream() -> None:
     reply_model_client = ReplayChatCompletionClient(messages)
 
     for i in range(num_messages):
-        result: List[str] = []
+        chunks: List[str] = []
+        result: CreateResult | None = None
         async for completion in reply_model_client.create_stream([UserMessage(content="dummy", source="_")]):
-            text = completion.content if isinstance(completion, CreateResult) else completion
-            assert isinstance(text, str)
-            result.append(text)
-        assert "".join(result) == messages[i]
+            if isinstance(completion, CreateResult):
+                result = completion
+            else:
+                assert isinstance(completion, str)
+                chunks.append(completion)
+        assert result is not None
+        assert "".join(chunks) == messages[i] == result.content
 
     with pytest.raises(ValueError, match="No more mock responses available"):
         await reply_model_client.create([UserMessage(content="dummy", source="_")])