mirror of https://github.com/microsoft/autogen.git
Use class hierarchy to organize AgentChat message types and introduce StructuredMessage type (#5998)
This PR refactored the `AgentEvent` and `ChatMessage` union types into abstract base classes. This allows user-defined message types that subclass one of the base classes to be used in AgentChat. To support a unified interface for working with the messages, the base classes add abstract methods to:

- Convert content to string
- Convert content to a `UserMessage` for a model client
- Convert content for rendering in the console
- Dump into a dictionary
- Load and create a new instance from a dictionary

This way, all agents such as `AssistantAgent` and `SocietyOfMindAgent` can use the unified interface to work with any built-in or user-defined message type.

This PR also introduces a new message type, `StructuredMessage`, for AgentChat (Resolves #5131). It is a generic type that requires a user-specified content type. You can create a `StructuredMessage` as follows:

```python
class MessageType(BaseModel):
    data: str
    references: List[str]


message = StructuredMessage[MessageType](content=MessageType(data="data", references=["a", "b"]), source="user")

# message.content is of type `MessageType`.
```

This PR addresses the receiving side of this message type. Producing this message type from `AssistantAgent` continues in #5934.

Added unit tests to verify this message type works with agents and teams.

This commit is contained in:
parent
8a5ee3de6a
commit
025490a1bd
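The unified interface described above can be sketched end to end. The snippet below is illustrative rather than part of the commit: it reuses the `MessageType` model from the commit message and exercises the conversion and serialization methods (`to_text`, `to_model_text`, `to_model_message`, `dump`, `load`) as they are defined in this diff.

```python
from typing import List

from pydantic import BaseModel

from autogen_agentchat.messages import StructuredMessage


class MessageType(BaseModel):
    data: str
    references: List[str]


message = StructuredMessage[MessageType](
    content=MessageType(data="data", references=["a", "b"]),
    source="user",
)

print(message.to_text())           # pretty-printed JSON for console rendering
print(message.to_model_text())     # compact JSON used as model-facing text
print(message.to_model_message())  # a UserMessage for a ChatCompletionClient

data = message.dump()  # JSON-serializable dict; includes a "type" field
restored = StructuredMessage[MessageType].load(data)
assert restored.content.data == "data"
```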
@@ -31,7 +31,6 @@ from autogen_core.models import (
     LLMMessage,
     ModelFamily,
     SystemMessage,
-    UserMessage,
 )
 from autogen_core.tools import BaseTool, FunctionTool
 from pydantic import BaseModel
@@ -814,14 +813,13 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         messages: Sequence[ChatMessage],
     ) -> None:
         """
-        Add incoming user (and possibly handoff) messages to the model context.
+        Add incoming messages to the model context.
         """
         for msg in messages:
             if isinstance(msg, HandoffMessage):
-                # Add handoff context to the model context.
-                for context_msg in msg.context:
-                    await model_context.add_message(context_msg)
-            await model_context.add_message(UserMessage(content=msg.content, source=msg.source))
+                for llm_msg in msg.context:
+                    await model_context.add_message(llm_msg)
+            await model_context.add_message(msg.to_model_message())

     @staticmethod
     async def _update_model_context_with_memory(
@@ -7,7 +7,6 @@ from pydantic import BaseModel
 from ..base import ChatAgent, Response, TaskResult
 from ..messages import (
     AgentEvent,
-    BaseChatMessage,
     ChatMessage,
     ModelClientStreamingChunkEvent,
     TextMessage,

@@ -121,7 +120,7 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
             text_msg = TextMessage(content=task, source="user")
             input_messages.append(text_msg)
             output_messages.append(text_msg)
-        elif isinstance(task, BaseChatMessage):
+        elif isinstance(task, ChatMessage):
             input_messages.append(task)
             output_messages.append(task)
         else:

@@ -129,7 +128,7 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
                 raise ValueError("Task list cannot be empty.")
             # Task is a sequence of messages.
             for msg in task:
-                if isinstance(msg, BaseChatMessage):
+                if isinstance(msg, ChatMessage):
                     input_messages.append(msg)
                     output_messages.append(msg)
                 else:

@@ -159,7 +158,7 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
             input_messages.append(text_msg)
             output_messages.append(text_msg)
             yield text_msg
-        elif isinstance(task, BaseChatMessage):
+        elif isinstance(task, ChatMessage):
             input_messages.append(task)
             output_messages.append(task)
             yield task

@@ -167,7 +166,7 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
             if not task:
                 raise ValueError("Task list cannot be empty.")
             for msg in task:
-                if isinstance(msg, BaseChatMessage):
+                if isinstance(msg, ChatMessage):
                     input_messages.append(msg)
                     output_messages.append(msg)
                     yield msg
@@ -21,7 +21,9 @@ class CodeExecutorAgentConfig(BaseModel):


 class CodeExecutorAgent(BaseChatAgent, Component[CodeExecutorAgentConfig]):
-    """An agent that extracts and executes code snippets found in received messages and returns the output.
+    """An agent that extracts and executes code snippets found in received
+    :class:`~autogen_agentchat.messages.TextMessage` messages and returns the output
+    of the code execution.

     It is typically used within a team with another agent that generates code snippets to be executed.
@@ -1,7 +1,7 @@
 from typing import Any, AsyncGenerator, List, Mapping, Sequence

 from autogen_core import CancellationToken, Component, ComponentModel
-from autogen_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
+from autogen_core.models import ChatCompletionClient, LLMMessage, SystemMessage
 from pydantic import BaseModel
 from typing_extensions import Self

@@ -11,7 +11,6 @@ from autogen_agentchat.state import SocietyOfMindAgentState
 from ..base import TaskResult, Team
 from ..messages import (
     AgentEvent,
-    BaseChatMessage,
     ChatMessage,
     ModelClientStreamingChunkEvent,
     TextMessage,

@@ -167,13 +166,9 @@ class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]):
         else:
             # Generate a response using the model client.
             llm_messages: List[LLMMessage] = [SystemMessage(content=self._instruction)]
-            llm_messages.extend(
-                [
-                    UserMessage(content=message.content, source=message.source)
-                    for message in inner_messages
-                    if isinstance(message, BaseChatMessage)
-                ]
-            )
+            for message in messages:
+                if isinstance(message, ChatMessage):
+                    llm_messages.append(message.to_model_message())
             llm_messages.append(SystemMessage(content=self._response_prompt))
             completion = await self._model_client.create(messages=llm_messages, cancellation_token=cancellation_token)
             assert isinstance(completion.content, str)
@@ -82,6 +82,7 @@ class UserProxyAgent(BaseChatAgent, Component[UserProxyAgentConfig]):
                     cancellation_token=CancellationToken(),
                 )
             )
+            assert isinstance(response.chat_message, TextMessage)
             print(f"Your name is {response.chat_message.content}")

     Example:

@@ -117,6 +118,7 @@ class UserProxyAgent(BaseChatAgent, Component[UserProxyAgentConfig]):
                     )
                 )
                 response = await agent_task
+                assert isinstance(response.chat_message, TextMessage)
                 print(f"Your name is {response.chat_message.content}")
             except Exception as e:
                 print(f"Exception: {e}")
@@ -11,7 +11,6 @@ from ..messages import (
-    BaseChatMessage,
     ChatMessage,
     HandoffMessage,
     MultiModalMessage,
     StopMessage,
     TextMessage,
     ToolCallExecutionEvent,

@@ -137,18 +136,12 @@ class TextMentionTermination(TerminationCondition, Component[TextMentionTerminationConfig]):
             if self._sources is not None and message.source not in self._sources:
                 continue

-            if isinstance(message.content, str) and self._termination_text in message.content:
+            content = message.to_text()
+            if self._termination_text in content:
                 self._terminated = True
                 return StopMessage(
                     content=f"Text '{self._termination_text}' mentioned", source="TextMentionTermination"
                 )
-            elif isinstance(message, MultiModalMessage):
-                for item in message.content:
-                    if isinstance(item, str) and self._termination_text in item:
-                        self._terminated = True
-                        return StopMessage(
-                            content=f"Text '{self._termination_text}' mentioned", source="TextMentionTermination"
-                        )
         return None

     async def reset(self) -> None:
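Because `TextMentionTermination` now matches on each message's `to_text()` rather than on the raw `content` field, it works uniformly across built-in and custom message types. A minimal sketch of the condition in use (the event-loop scaffolding is illustrative, not part of this diff):

```python
import asyncio

from autogen_agentchat.conditions import TextMentionTermination
from autogen_agentchat.messages import TextMessage


async def main() -> None:
    termination = TextMentionTermination("TERMINATE")
    # The condition scans the rendered text of each delivered message.
    stop = await termination([TextMessage(content="All done. TERMINATE", source="assistant")])
    print(stop)  # a StopMessage once the termination text is mentioned
    await termination.reset()


asyncio.run(main())
```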
@@ -1,21 +1,69 @@
 """
 This module defines various message types used for agent-to-agent communication.
-Each message type inherits either from the BaseChatMessage class or BaseAgentEvent
+Each message type inherits either from the ChatMessage class or BaseAgentEvent
 class and includes specific fields relevant to the type of message being sent.
 """

-from abc import ABC
-from typing import Dict, List, Literal
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Generic, List, Literal, Mapping, TypeVar

 from autogen_core import FunctionCall, Image
 from autogen_core.memory import MemoryContent
-from autogen_core.models import FunctionExecutionResult, LLMMessage, RequestUsage
-from pydantic import BaseModel, ConfigDict, Field
-from typing_extensions import Annotated
+from autogen_core.models import FunctionExecutionResult, LLMMessage, RequestUsage, UserMessage
+from pydantic import BaseModel, ConfigDict, computed_field
+from typing_extensions import Self


 class BaseMessage(BaseModel, ABC):
-    """Base class for all message types."""
+    """Base class for all message types in AgentChat. This is an abstract class
+    with default implementations for serialization and deserialization.
+
+    .. warning::
+
+        If you want to create a new message type, do not inherit from this class.
+        Instead, inherit from :class:`ChatMessage` or :class:`AgentEvent`
+        to clarify the purpose of the message type.
+
+    """
+
+    @computed_field
+    def type(self) -> str:
+        """The class name of this message."""
+        return self.__class__.__name__
+
+    def dump(self) -> Mapping[str, Any]:
+        """Convert the message to a JSON-serializable dictionary.
+
+        The default implementation uses the Pydantic model's `model_dump` method.
+
+        If you want to customize the serialization, override this method.
+        """
+        return self.model_dump()
+
+    @classmethod
+    def load(cls, data: Mapping[str, Any]) -> Self:
+        """Create a message from a dictionary of JSON-serializable data.
+
+        The default implementation uses the Pydantic model's `model_validate` method.
+        If you want to customize the deserialization, override this method.
+        """
+        return cls.model_validate(data)
+
+
+class ChatMessage(BaseMessage, ABC):
+    """Base class for chat messages.
+
+    .. note::
+
+        If you want to create a new message type that is used for agent-to-agent
+        communication, inherit from this class, or simply use
+        :class:`StructuredMessage` if your content type is a subclass of
+        Pydantic BaseModel.
+
+    This class is used for messages that are sent between agents in a chat
+    conversation. Agents are expected to process the content of the
+    message using models and return a response as another :class:`ChatMessage`.
+    """
+
+    source: str
+    """The name of the agent that sent this message."""
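To illustrate the hierarchy being introduced here, below is a sketch of a user-defined `ChatMessage` subtype implementing the abstract methods defined in the next hunk. `BookingRequest` and its fields are made up for the example; only the base class and `UserMessage` come from this diff.

```python
from autogen_core.models import UserMessage

from autogen_agentchat.messages import ChatMessage


class BookingRequest(ChatMessage):
    destination: str
    nights: int

    def to_text(self) -> str:
        # Rendered in the console and inspected by users and conditions.
        return f"Booking request: {self.destination} for {self.nights} nights"

    def to_model_text(self) -> str:
        # Text-only content used when building model input.
        return self.to_text()

    def to_model_message(self) -> UserMessage:
        # A complete UserMessage for a model client.
        return UserMessage(content=self.to_model_text(), source=self.source)
```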
@@ -28,89 +76,231 @@ class BaseMessage(BaseModel, ABC):

     model_config = ConfigDict(arbitrary_types_allowed=True)

+    @abstractmethod
+    def to_text(self) -> str:
+        """Convert the content of the message to a string-only representation
+        that can be rendered in the console and inspected by the user or conditions.

-class BaseChatMessage(BaseMessage, ABC):
-    """Base class for chat messages."""
+        This is not used for creating text-only content for models.
+        For :class:`ChatMessage` types, use :meth:`to_model_text` instead."""
+        ...

-    pass
+    @abstractmethod
+    def to_model_text(self) -> str:
+        """Convert the content of the message to text-only representation.
+        This is used for creating text-only content for models.
+
+        This is not used for rendering the message in console. For that, use
+        :meth:`~BaseMessage.to_text`.
+
+        The difference between this and :meth:`to_model_message` is that this
+        is used to construct parts of a message for the model client,
+        while :meth:`to_model_message` is used to create a complete message
+        for the model client.
+        """
+        ...
+
+    @abstractmethod
+    def to_model_message(self) -> UserMessage:
+        """Convert the message content to a :class:`~autogen_core.models.UserMessage`
+        for use with model client, e.g., :class:`~autogen_core.models.ChatCompletionClient`."""
+        ...


-class BaseAgentEvent(BaseMessage, ABC):
-    """Base class for agent events."""
+class TextChatMessage(ChatMessage, ABC):
+    """Base class for all text-only :class:`ChatMessage` types.
+    It has implementations for :meth:`to_text`, :meth:`to_model_text`,
+    and :meth:`to_model_message` methods.

-    pass
+    Inherit from this class if your message content type is a string.
+    """

-
-class TextMessage(BaseChatMessage):
-    """A text message."""
-
     content: str
     """The content of the message."""

-    type: Literal["TextMessage"] = "TextMessage"
+    def to_text(self) -> str:
+        return self.content
+
+    def to_model_text(self) -> str:
+        return self.content
+
+    def to_model_message(self) -> UserMessage:
+        return UserMessage(content=self.content, source=self.source)


-class MultiModalMessage(BaseChatMessage):
+class AgentEvent(BaseMessage, ABC):
+    """Base class for agent events.
+
+    .. note::
+
+        If you want to create a new message type for signaling observable events
+        to user and application, inherit from this class.
+
+    Agent events are used to signal actions and thoughts produced by agents
+    and teams to user and applications. They are not used for agent-to-agent
+    communication and are not expected to be processed by other agents.
+
+    You should override the :meth:`to_text` method if you want to provide
+    a custom rendering of the content.
+    """
+
+    source: str
+    """The name of the agent that sent this message."""
+
+    models_usage: RequestUsage | None = None
+    """The model client usage incurred when producing this message."""
+
+    metadata: Dict[str, str] = {}
+    """Additional metadata about the message."""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    @abstractmethod
+    def to_text(self) -> str:
+        """Convert the content of the message to a string-only representation
+        that can be rendered in the console and inspected by the user.
+
+        This is not used for creating text-only content for models.
+        For :class:`ChatMessage` types, use :meth:`to_model_text` instead."""
+        ...
+
+
+StructuredContentType = TypeVar("StructuredContentType", bound=BaseModel, covariant=True)
+"""Type variable for structured content types."""
+
+
+class StructuredMessage(ChatMessage, Generic[StructuredContentType]):
+    """A :class:`ChatMessage` type with an unspecified content type.
+
+    To create a new structured message type, specify the content type
+    as a subclass of `Pydantic BaseModel <https://docs.pydantic.dev/latest/concepts/models/>`_.
+
+    .. code-block:: python
+
+        from pydantic import BaseModel
+        from autogen_agentchat.messages import StructuredMessage
+
+
+        class MyMessageContent(BaseModel):
+            text: str
+            number: int
+
+
+        message = StructuredMessage[MyMessageContent](
+            content=MyMessageContent(text="Hello", number=42),
+            source="agent1",
+        )
+
+        print(message.to_text())  # {"text": "Hello", "number": 42}
+
+    """
+
+    content: StructuredContentType
+    """The content of the message. Must be a subclass of
+    `Pydantic BaseModel <https://docs.pydantic.dev/latest/concepts/models/>`_."""
+
+    def to_text(self) -> str:
+        return self.content.model_dump_json(indent=2)
+
+    def to_model_text(self) -> str:
+        return self.content.model_dump_json()
+
+    def to_model_message(self) -> UserMessage:
+        return UserMessage(
+            content=self.content.model_dump_json(),
+            source=self.source,
+        )
+
+
+class TextMessage(TextChatMessage):
+    """A text message with string-only content."""
+
+    ...
+
+
+class MultiModalMessage(ChatMessage):
     """A multimodal message."""

     content: List[str | Image]
     """The content of the message."""

-    type: Literal["MultiModalMessage"] = "MultiModalMessage"
+    def to_model_text(self, image_placeholder: str | None = "[image]") -> str:
+        """Convert the content of the message to a string-only representation.
+        If an image is present, it is replaced with the image placeholder
+        by default; if the placeholder is set to None, it is replaced with
+        the image's base64 string.
+        """
+        text = ""
+        for c in self.content:
+            if isinstance(c, str):
+                text += c
+            elif isinstance(c, Image):
+                if image_placeholder is not None:
+                    text += f" {image_placeholder}"
+                else:
+                    text += f" {c.to_base64()}"
+        return text
+
+    def to_text(self, iterm: bool = False) -> str:
+        result: List[str] = []
+        for c in self.content:
+            if isinstance(c, str):
+                result.append(c)
+            else:
+                if iterm:
+                    # iTerm2 image rendering protocol: https://iterm2.com/documentation-images.html
+                    image_data = c.to_base64()
+                    result.append(f"\033]1337;File=inline=1:{image_data}\a\n")
+                else:
+                    result.append("<image>")
+        return "\n".join(result)
+
+    def to_model_message(self) -> UserMessage:
+        return UserMessage(content=self.content, source=self.source)


-class StopMessage(BaseChatMessage):
+class StopMessage(TextChatMessage):
     """A message requesting stop of a conversation."""

-    content: str
-    """The content for the stop message."""
-
-    type: Literal["StopMessage"] = "StopMessage"
+    ...


-class HandoffMessage(BaseChatMessage):
+class HandoffMessage(TextChatMessage):
     """A message requesting handoff of a conversation to another agent."""

     target: str
     """The name of the target agent to handoff to."""

     content: str
     """The handoff message to the target agent."""

     context: List[LLMMessage] = []
     """The model context to be passed to the target agent."""

-    type: Literal["HandoffMessage"] = "HandoffMessage"
+
+class ToolCallSummaryMessage(TextChatMessage):
+    """A message signaling the summary of tool call results."""
+
+    ...


-class ToolCallRequestEvent(BaseAgentEvent):
+class ToolCallRequestEvent(AgentEvent):
     """An event signaling a request to use tools."""

     content: List[FunctionCall]
     """The tool calls."""

-    type: Literal["ToolCallRequestEvent"] = "ToolCallRequestEvent"
+    def to_text(self) -> str:
+        return str(self.content)


-class ToolCallExecutionEvent(BaseAgentEvent):
+class ToolCallExecutionEvent(AgentEvent):
     """An event signaling the execution of tool calls."""

     content: List[FunctionExecutionResult]
     """The tool call results."""

-    type: Literal["ToolCallExecutionEvent"] = "ToolCallExecutionEvent"
+    def to_text(self) -> str:
+        return str(self.content)


-class ToolCallSummaryMessage(BaseChatMessage):
-    """A message signaling the summary of tool call results."""
-
-    content: str
-    """Summary of the tool call results."""
-
-    type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage"
-
-
-class UserInputRequestedEvent(BaseAgentEvent):
+class UserInputRequestedEvent(AgentEvent):
     """An event signaling that the user proxy has requested user input. Published prior to invoking the input callback."""

     request_id: str
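The two rendering paths on `MultiModalMessage` above differ on purpose: `to_model_text` produces model-facing text, while `to_text` produces console output. A short illustration with text-only content (an `Image` item would render as the placeholder or escape sequence noted in the code; the message itself is made up):

```python
from autogen_agentchat.messages import MultiModalMessage

msg = MultiModalMessage(content=["Describe this chart:"], source="user")

print(msg.to_model_text())      # model-facing text; images become "[image]" by default
print(msg.to_text())            # console rendering; images become "<image>"
print(msg.to_text(iterm=True))  # images become iTerm2 inline-image escape codes
```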
@@ -119,60 +309,117 @@ class UserInputRequestedEvent(BaseAgentEvent):
     content: Literal[""] = ""
     """Empty content for compat with consumers expecting a content field."""

-    type: Literal["UserInputRequestedEvent"] = "UserInputRequestedEvent"
+    def to_text(self) -> str:
+        return str(self.content)


-class MemoryQueryEvent(BaseAgentEvent):
+class MemoryQueryEvent(AgentEvent):
     """An event signaling the results of memory queries."""

     content: List[MemoryContent]
     """The memory query results."""

-    type: Literal["MemoryQueryEvent"] = "MemoryQueryEvent"
+    def to_text(self) -> str:
+        return str(self.content)


-class ModelClientStreamingChunkEvent(BaseAgentEvent):
+class ModelClientStreamingChunkEvent(AgentEvent):
     """An event signaling a text output chunk from a model client in streaming mode."""

     content: str
-    """The partial text chunk."""
+    """A string chunk from the model client."""

-    type: Literal["ModelClientStreamingChunkEvent"] = "ModelClientStreamingChunkEvent"
+    def to_text(self) -> str:
+        return self.content


-class ThoughtEvent(BaseAgentEvent):
-    """An event signaling the thought process of an agent.
+class ThoughtEvent(AgentEvent):
+    """An event signaling the thought process of a model.
     It is used to communicate the reasoning tokens generated by a reasoning model,
     or the extra text content generated by a function call."""

     content: str
-    """The thought process."""
+    """The thought process of the model."""

-    type: Literal["ThoughtEvent"] = "ThoughtEvent"
+    def to_text(self) -> str:
+        return self.content


-ChatMessage = Annotated[
-    TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
-]
-"""Messages for agent-to-agent communication only."""
+class MessageFactory:
+    """:meta private:
+
+    A factory for creating messages from JSON-serializable dictionaries.
+
+    This is useful for deserializing messages from JSON data.
+    """
+
+    def __init__(self) -> None:
+        self._message_types: Dict[str, type[AgentEvent | ChatMessage]] = {}
+        # Register all message types.
+        self._message_types[TextMessage.__name__] = TextMessage
+        self._message_types[MultiModalMessage.__name__] = MultiModalMessage
+        self._message_types[StopMessage.__name__] = StopMessage
+        self._message_types[ToolCallSummaryMessage.__name__] = ToolCallSummaryMessage
+        self._message_types[HandoffMessage.__name__] = HandoffMessage
+        self._message_types[ToolCallRequestEvent.__name__] = ToolCallRequestEvent
+        self._message_types[ToolCallExecutionEvent.__name__] = ToolCallExecutionEvent
+        self._message_types[MemoryQueryEvent.__name__] = MemoryQueryEvent
+        self._message_types[UserInputRequestedEvent.__name__] = UserInputRequestedEvent
+        self._message_types[ModelClientStreamingChunkEvent.__name__] = ModelClientStreamingChunkEvent
+        self._message_types[ThoughtEvent.__name__] = ThoughtEvent
+
+    def is_registered(self, message_type: type[AgentEvent | ChatMessage]) -> bool:
+        """Check if a message type is registered with the factory."""
+        # Get the class name of the message type.
+        class_name = message_type.__name__
+        # Check if the class name is already registered.
+        return class_name in self._message_types
+
+    def register(self, message_type: type[AgentEvent | ChatMessage]) -> None:
+        """Register a new message type with the factory."""
+        if self.is_registered(message_type):
+            raise ValueError(f"Message type {message_type} is already registered.")
+        if not issubclass(message_type, ChatMessage) and not issubclass(message_type, AgentEvent):
+            raise ValueError(f"Message type {message_type} must be a subclass of ChatMessage or AgentEvent.")
+        # Get the class name of the message type.
+        class_name = message_type.__name__
+        # Check if the class name is already registered.
+        # Register the message type.
+        self._message_types[class_name] = message_type
+
+    def create(self, data: Mapping[str, Any]) -> AgentEvent | ChatMessage:
+        """Create a message from a dictionary of JSON-serializable data."""
+        # Get the type of the message from the dictionary.
+        message_type = data.get("type")
+        if message_type not in self._message_types:
+            raise ValueError(f"Unknown message type: {message_type}")
+        if not isinstance(message_type, str):
+            raise ValueError(f"Message type must be a string, got {type(message_type)}")
+
+        # Get the class for the message type.
+        message_class = self._message_types[message_type]
+
+        # Create an instance of the message class.
+        assert issubclass(message_class, ChatMessage) or issubclass(message_class, AgentEvent)
+        return message_class.load(data)


-AgentEvent = Annotated[
-    ToolCallRequestEvent
-    | ToolCallExecutionEvent
-    | MemoryQueryEvent
-    | UserInputRequestedEvent
-    | ModelClientStreamingChunkEvent
-    | ThoughtEvent,
-    Field(discriminator="type"),
-]
-"""Events emitted by agents and teams when they work, not used for agent-to-agent communication."""
+# For backward compatibility
+BaseAgentEvent = AgentEvent
+BaseChatMessage = ChatMessage


 __all__ = [
-    "AgentEvent",
     "BaseMessage",
-    "ChatMessage",
     "BaseChatMessage",
     "BaseAgentEvent",
+    "AgentEvent",
+    "TextChatMessage",
+    "ChatMessage",
+    "StructuredContentType",
+    "StructuredMessage",
     "HandoffMessage",
     "MultiModalMessage",
     "StopMessage",
@@ -184,4 +431,5 @@ __all__ = [
     "UserInputRequestedEvent",
     "ModelClientStreamingChunkEvent",
     "ThoughtEvent",
+    "MessageFactory",
 ]
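The `MessageFactory` above pairs with `BaseMessage.dump` and `load` to round-trip messages through JSON-serializable data, which is how the group chat state handling later in this diff restores saved message threads. A minimal sketch of that round trip:

```python
from autogen_agentchat.messages import MessageFactory, TextMessage

factory = MessageFactory()  # built-in message types are pre-registered

msg = TextMessage(content="Hello", source="user")
data = msg.dump()  # JSON-serializable dict; includes a "type" discriminator
restored = factory.create(data)

assert isinstance(restored, TextMessage)
assert restored.content == "Hello"
```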
@@ -1,15 +1,7 @@
-from typing import Annotated, Any, List, Mapping, Optional
+from typing import Any, List, Mapping, Optional

 from pydantic import BaseModel, Field

-from ..messages import (
-    AgentEvent,
-    ChatMessage,
-)
-
-# Ensures pydantic can distinguish between types of events & messages.
-_AgentMessage = Annotated[AgentEvent | ChatMessage, Field(discriminator="type")]
-

 class BaseState(BaseModel):
     """Base class for all saveable state"""

@@ -35,7 +27,7 @@ class TeamState(BaseState):
 class BaseGroupChatManagerState(BaseState):
     """Base state for all group chat managers."""

-    message_thread: List[_AgentMessage] = Field(default_factory=list)
+    message_thread: List[Mapping[str, Any]] = Field(default_factory=list)
     current_turn: int = Field(default=0)
     type: str = Field(default="BaseGroupChatManagerState")

@@ -44,7 +36,7 @@ class ChatAgentContainerState(BaseState):
     """State for a container of chat agents."""

     agent_state: Mapping[str, Any] = Field(default_factory=dict)
-    message_buffer: List[ChatMessage] = Field(default_factory=list)
+    message_buffer: List[Mapping[str, Any]] = Field(default_factory=list)
     type: str = Field(default="ChatAgentContainerState")
@@ -19,8 +19,8 @@ from ... import EVENT_LOGGER_NAME
 from ...base import ChatAgent, TaskResult, Team, TerminationCondition
 from ...messages import (
     AgentEvent,
-    BaseChatMessage,
     ChatMessage,
+    MessageFactory,
     ModelClientStreamingChunkEvent,
     StopMessage,
     TextMessage,

@@ -50,6 +50,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
         termination_condition: TerminationCondition | None = None,
         max_turns: int | None = None,
         runtime: AgentRuntime | None = None,
+        custom_message_types: List[type[AgentEvent | ChatMessage]] | None = None,
     ):
         if len(participants) == 0:
             raise ValueError("At least one participant is required.")

@@ -59,6 +60,10 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
         self._base_group_chat_manager_class = group_chat_manager_class
         self._termination_condition = termination_condition
         self._max_turns = max_turns
+        self._message_factory = MessageFactory()
+        if custom_message_types is not None:
+            for message_type in custom_message_types:
+                self._message_factory.register(message_type)

         # The team ID is a UUID that is used to identify the team and its participants
         # in the agent runtime. It is used to create unique topic types for each participant.

@@ -115,6 +120,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
         output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
         termination_condition: TerminationCondition | None,
         max_turns: int | None,
+        message_factory: MessageFactory,
     ) -> Callable[[], SequentialRoutedAgent]: ...

     def _create_participant_factory(

@@ -122,9 +128,10 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
         parent_topic_type: str,
         output_topic_type: str,
         agent: ChatAgent,
+        message_factory: MessageFactory,
     ) -> Callable[[], ChatAgentContainer]:
         def _factory() -> ChatAgentContainer:
-            container = ChatAgentContainer(parent_topic_type, output_topic_type, agent)
+            container = ChatAgentContainer(parent_topic_type, output_topic_type, agent, message_factory)
             return container

         return _factory

@@ -140,7 +147,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
         await ChatAgentContainer.register(
             runtime,
             type=agent_type,
-            factory=self._create_participant_factory(self._group_topic_type, self._output_topic_type, participant),
+            factory=self._create_participant_factory(
+                self._group_topic_type, self._output_topic_type, participant, self._message_factory
+            ),
         )
         # Add subscriptions for the participant.
         # The participant should be able to receive messages from its own topic.

@@ -162,6 +171,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
                 output_message_queue=self._output_message_queue,
                 termination_condition=self._termination_condition,
                 max_turns=self._max_turns,
+                message_factory=self._message_factory,
             ),
         )
         # Add subscriptions for the group chat manager.
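Wiring a custom message type into a team goes through the new `custom_message_types` parameter shown above; each listed type is registered with the team's `MessageFactory`. A sketch, reusing the hypothetical `BookingRequest` from earlier (the agents are assumed to exist and are not part of this diff):

```python
from typing import List

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat


def make_team(agents: List[AssistantAgent]) -> RoundRobinGroupChat:
    # Each registered type can then be buffered by participants, logged,
    # validated as a task, and re-created when loading saved state.
    return RoundRobinGroupChat(
        participants=agents,
        custom_message_types=[BookingRequest],  # hypothetical type from the earlier sketch
    )
```

Without this registration, passing such a message as a task raises the "not registered with the message factory" error added in the next hunk.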
@@ -393,16 +403,27 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
             pass
         elif isinstance(task, str):
             messages = [TextMessage(content=task, source="user")]
-        elif isinstance(task, BaseChatMessage):
+        elif isinstance(task, ChatMessage):
             messages = [task]
-        else:
+        elif isinstance(task, list):
             if not task:
                 raise ValueError("Task list cannot be empty.")
             messages = []
             for msg in task:
-                if not isinstance(msg, BaseChatMessage):
+                if not isinstance(msg, ChatMessage):
                     raise ValueError("All messages in task list must be valid ChatMessage types")
                 messages.append(msg)
+        else:
+            raise ValueError("Task must be a string, a ChatMessage, or a list of ChatMessage.")
+        # Check if the message types are registered with the message factory.
+        if messages is not None:
+            for msg in messages:
+                if not self._message_factory.is_registered(msg.__class__):
+                    raise ValueError(
+                        f"Message type {msg.__class__} is not registered with the message factory. "
+                        "Please register it with the message factory by adding it to the "
+                        "custom_message_types list when creating the team."
+                    )

         if self._is_running:
             raise ValueError("The team is already running, it cannot run again until it is stopped.")
@@ -5,7 +5,7 @@ from typing import Any, List
 from autogen_core import DefaultTopicId, MessageContext, event, rpc

 from ...base import TerminationCondition
-from ...messages import AgentEvent, ChatMessage, StopMessage
+from ...messages import AgentEvent, ChatMessage, MessageFactory, StopMessage
 from ._events import (
     GroupChatAgentResponse,
     GroupChatMessage,

@@ -40,8 +40,9 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
         participant_names: List[str],
         participant_descriptions: List[str],
         output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
-        termination_condition: TerminationCondition | None = None,
-        max_turns: int | None = None,
+        termination_condition: TerminationCondition | None,
+        max_turns: int | None,
+        message_factory: MessageFactory,
     ):
         super().__init__(
             description="Group chat manager",

@@ -73,6 +74,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
             raise ValueError("The maximum number of turns must be greater than 0.")
         self._max_turns = max_turns
         self._current_turn = 0
+        self._message_factory = message_factory

     @rpc
     async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
@@ -2,8 +2,9 @@ from typing import Any, List, Mapping

 from autogen_core import DefaultTopicId, MessageContext, event, rpc

+from autogen_agentchat.messages import AgentEvent, ChatMessage, MessageFactory
+
 from ...base import ChatAgent, Response
-from ...messages import ChatMessage
 from ...state import ChatAgentContainerState
 from ._events import (
     GroupChatAgentResponse,

@@ -26,9 +27,13 @@ class ChatAgentContainer(SequentialRoutedAgent):
         parent_topic_type (str): The topic type of the parent orchestrator.
         output_topic_type (str): The topic type for the output.
         agent (ChatAgent): The agent to delegate message handling to.
+        message_factory (MessageFactory): The message factory to use for
+            creating messages from JSON data.
     """

-    def __init__(self, parent_topic_type: str, output_topic_type: str, agent: ChatAgent) -> None:
+    def __init__(
+        self, parent_topic_type: str, output_topic_type: str, agent: ChatAgent, message_factory: MessageFactory
+    ) -> None:
         super().__init__(
             description=agent.description,
             sequential_message_types=[

@@ -42,17 +47,19 @@ class ChatAgentContainer(SequentialRoutedAgent):
         self._output_topic_type = output_topic_type
         self._agent = agent
         self._message_buffer: List[ChatMessage] = []
+        self._message_factory = message_factory

     @event
     async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
         """Handle a start event by appending the content to the buffer."""
         if message.messages is not None:
-            self._message_buffer.extend(message.messages)
+            for msg in message.messages:
+                self._buffer_message(msg)

     @event
     async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
         """Handle an agent response event by appending the content to the buffer."""
-        self._message_buffer.append(message.agent_response.chat_message)
+        self._buffer_message(message.agent_response.chat_message)

     @rpc
     async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None:

@@ -68,17 +75,10 @@ class ChatAgentContainer(SequentialRoutedAgent):
         response: Response | None = None
         async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token):
             if isinstance(msg, Response):
-                # Log the response.
-                await self.publish_message(
-                    GroupChatMessage(message=msg.chat_message),
-                    topic_id=DefaultTopicId(type=self._output_topic_type),
-                )
+                await self._log_message(msg.chat_message)
                 response = msg
             else:
-                # Log the message.
-                await self.publish_message(
-                    GroupChatMessage(message=msg), topic_id=DefaultTopicId(type=self._output_topic_type)
-                )
+                await self._log_message(msg)
         if response is None:
             raise ValueError("The agent did not produce a final response. Check the agent's on_messages_stream method.")

@@ -90,6 +90,21 @@ class ChatAgentContainer(SequentialRoutedAgent):
             cancellation_token=ctx.cancellation_token,
         )

+    def _buffer_message(self, message: ChatMessage) -> None:
+        if not self._message_factory.is_registered(message.__class__):
+            raise ValueError(f"Message type {message.__class__} is not registered.")
+        # Buffer the message.
+        self._message_buffer.append(message)
+
+    async def _log_message(self, message: AgentEvent | ChatMessage) -> None:
+        if not self._message_factory.is_registered(message.__class__):
+            raise ValueError(f"Message type {message.__class__} is not registered.")
+        # Log the message.
+        await self.publish_message(
+            GroupChatMessage(message=message),
+            topic_id=DefaultTopicId(type=self._output_topic_type),
+        )
+
     @rpc
     async def handle_pause(self, message: GroupChatPause, ctx: MessageContext) -> None:
         """Handle a pause event by pausing the agent."""

@@ -105,10 +120,18 @@ class ChatAgentContainer(SequentialRoutedAgent):

     async def save_state(self) -> Mapping[str, Any]:
         agent_state = await self._agent.save_state()
-        state = ChatAgentContainerState(agent_state=agent_state, message_buffer=list(self._message_buffer))
+        state = ChatAgentContainerState(
+            agent_state=agent_state, message_buffer=[message.dump() for message in self._message_buffer]
+        )
         return state.model_dump()

     async def load_state(self, state: Mapping[str, Any]) -> None:
         container_state = ChatAgentContainerState.model_validate(state)
-        self._message_buffer = list(container_state.message_buffer)
+        self._message_buffer = []
+        for message_data in container_state.message_buffer:
+            message = self._message_factory.create(message_data)
+            if isinstance(message, ChatMessage):
+                self._message_buffer.append(message)
+            else:
+                raise ValueError(f"Invalid message type in message buffer: {type(message)}")
         await self._agent.load_state(container_state.agent_state)
@@ -9,7 +9,7 @@ from typing_extensions import Self

 from .... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
 from ....base import ChatAgent, TerminationCondition
-from ....messages import AgentEvent, ChatMessage
+from ....messages import AgentEvent, ChatMessage, MessageFactory
 from .._base_group_chat import BaseGroupChat
 from .._events import GroupChatTermination
 from ._magentic_one_orchestrator import MagenticOneOrchestrator

@@ -131,6 +131,7 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig]):
         output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
         termination_condition: TerminationCondition | None,
         max_turns: int | None,
+        message_factory: MessageFactory,
     ) -> Callable[[], MagenticOneOrchestrator]:
         return lambda: MagenticOneOrchestrator(
             name,

@@ -140,6 +141,7 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig]):
             participant_names,
             participant_descriptions,
             max_turns,
+            message_factory,
             self._model_client,
             self._max_stalls,
             self._final_answer_prompt,

@@ -18,6 +18,7 @@ from ....messages import (
     AgentEvent,
     ChatMessage,
     HandoffMessage,
+    MessageFactory,
     MultiModalMessage,
     StopMessage,
     TextMessage,

@@ -26,7 +27,7 @@ from ....messages import (
     ToolCallSummaryMessage,
 )
 from ....state import MagenticOneOrchestratorState
-from ....utils import content_to_str, remove_images
+from ....utils import remove_images
 from .._base_group_chat_manager import BaseGroupChatManager
 from ._events import (
     GroupChatAgentResponse,

@@ -61,6 +62,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
         participant_names: List[str],
         participant_descriptions: List[str],
         max_turns: int | None,
+        message_factory: MessageFactory,
         model_client: ChatCompletionClient,
         max_stalls: int,
         final_answer_prompt: str,

@@ -77,6 +79,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
             output_message_queue,
             termination_condition,
             max_turns,
+            message_factory,
         )
         self._model_client = model_client
         self._max_stalls = max_stalls

@@ -147,7 +150,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
         # Create the initial task ledger
         #################################
         # Combine all message contents for task
-        self._task = " ".join([content_to_str(msg.content) for msg in message.messages])
+        self._task = " ".join([msg.to_model_text() for msg in message.messages])
         planning_conversation: List[LLMMessage] = []

         # 1. GATHER FACTS

@@ -203,7 +206,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):

     async def save_state(self) -> Mapping[str, Any]:
         state = MagenticOneOrchestratorState(
-            message_thread=list(self._message_thread),
+            message_thread=[msg.dump() for msg in self._message_thread],
             current_turn=self._current_turn,
             task=self._task,
             facts=self._facts,

@@ -215,7 +218,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):

     async def load_state(self, state: Mapping[str, Any]) -> None:
         orchestrator_state = MagenticOneOrchestratorState.model_validate(state)
-        self._message_thread = orchestrator_state.message_thread
+        self._message_thread = [self._message_factory.create(message) for message in orchestrator_state.message_thread]
         self._current_turn = orchestrator_state.current_turn
         self._task = orchestrator_state.task
         self._facts = orchestrator_state.facts
@@ -6,7 +6,7 @@ from pydantic import BaseModel
 from typing_extensions import Self

 from ...base import ChatAgent, TerminationCondition
-from ...messages import AgentEvent, ChatMessage
+from ...messages import AgentEvent, ChatMessage, MessageFactory
 from ...state import RoundRobinManagerState
 from ._base_group_chat import BaseGroupChat
 from ._base_group_chat_manager import BaseGroupChatManager

@@ -26,7 +26,8 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
         participant_descriptions: List[str],
         output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
         termination_condition: TerminationCondition | None,
-        max_turns: int | None = None,
+        max_turns: int | None,
+        message_factory: MessageFactory,
     ) -> None:
         super().__init__(
             name,

@@ -38,6 +39,7 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
             output_message_queue,
             termination_condition,
             max_turns,
+            message_factory,
         )
         self._next_speaker_index = 0

@@ -53,7 +55,7 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):

     async def save_state(self) -> Mapping[str, Any]:
         state = RoundRobinManagerState(
-            message_thread=list(self._message_thread),
+            message_thread=[message.dump() for message in self._message_thread],
             current_turn=self._current_turn,
             next_speaker_index=self._next_speaker_index,
         )

@@ -61,7 +63,7 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):

     async def load_state(self, state: Mapping[str, Any]) -> None:
         round_robin_state = RoundRobinManagerState.model_validate(state)
-        self._message_thread = list(round_robin_state.message_thread)
+        self._message_thread = [self._message_factory.create(message) for message in round_robin_state.message_thread]
         self._current_turn = round_robin_state.current_turn
         self._next_speaker_index = round_robin_state.next_speaker_index

@@ -164,6 +166,7 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
         termination_condition: TerminationCondition | None = None,
         max_turns: int | None = None,
         runtime: AgentRuntime | None = None,
+        custom_message_types: List[type[AgentEvent | ChatMessage]] | None = None,
     ) -> None:
         super().__init__(
             participants,

@@ -172,6 +175,7 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
             termination_condition=termination_condition,
             max_turns=max_turns,
             runtime=runtime,
+            custom_message_types=custom_message_types,
         )

     def _create_group_chat_manager_factory(

@@ -185,6 +189,7 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
         output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
         termination_condition: TerminationCondition | None,
         max_turns: int | None,
+        message_factory: MessageFactory,
     ) -> Callable[[], RoundRobinGroupChatManager]:
         def _factory() -> RoundRobinGroupChatManager:
             return RoundRobinGroupChatManager(

@@ -197,6 +202,7 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
                 output_message_queue,
                 termination_condition,
                 max_turns,
+                message_factory,
             )

         return _factory
@@ -14,9 +14,8 @@ from ...agents import BaseChatAgent
 from ...base import ChatAgent, TerminationCondition
 from ...messages import (
     AgentEvent,
-    BaseAgentEvent,
     ChatMessage,
-    MultiModalMessage,
+    MessageFactory,
 )
 from ...state import SelectorManagerState
 from ._base_group_chat import BaseGroupChat

@@ -49,6 +48,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
         output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
         termination_condition: TerminationCondition | None,
         max_turns: int | None,
+        message_factory: MessageFactory,
         model_client: ChatCompletionClient,
         selector_prompt: str,
         allow_repeated_speaker: bool,

@@ -66,6 +66,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
             output_message_queue,
             termination_condition,
             max_turns,
+            message_factory,
         )
         self._model_client = model_client
         self._selector_prompt = selector_prompt

@@ -89,7 +90,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):

     async def save_state(self) -> Mapping[str, Any]:
         state = SelectorManagerState(
-            message_thread=list(self._message_thread),
+            message_thread=[msg.dump() for msg in self._message_thread],
             current_turn=self._current_turn,
             previous_speaker=self._previous_speaker,
         )

@@ -97,7 +98,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):

     async def load_state(self, state: Mapping[str, Any]) -> None:
         selector_state = SelectorManagerState.model_validate(state)
-        self._message_thread = list(selector_state.message_thread)
+        self._message_thread = [self._message_factory.create(msg) for msg in selector_state.message_thread]
         self._current_turn = selector_state.current_turn
         self._previous_speaker = selector_state.previous_speaker

@@ -152,20 +153,10 @@ class SelectorGroupChatManager(BaseGroupChatManager):
         # Construct the history of the conversation.
         history_messages: List[str] = []
         for msg in thread:
-            if isinstance(msg, BaseAgentEvent):
-                # Ignore agent events.
+            if not isinstance(msg, ChatMessage):
+                # Only process chat messages.
                 continue
-            message = f"{msg.source}:"
-            if isinstance(msg.content, str):
-                message += f" {msg.content}"
-            elif isinstance(msg, MultiModalMessage):
-                for item in msg.content:
-                    if isinstance(item, str):
-                        message += f" {item}"
-                    else:
-                        message += " [Image]"
-            else:
-                raise ValueError(f"Unexpected message type in selector: {type(msg)}")
+            message = f"{msg.source}: {msg.to_model_text()}"
             history_messages.append(
                 message.rstrip() + "\n\n"
             )  # Create some consistency for how messages are separated in the transcript

@@ -414,7 +405,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
         )

         def selector_func(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:
-            if len(messages) == 1 or messages[-1].content == "Incorrect!":
+            if len(messages) == 1 or messages[-1].to_text() == "Incorrect!":
                 return "Agent1"
             if messages[-1].source == "Agent1":
                 return "Agent2"

@@ -457,6 +448,7 @@ Read the above conversation. Then select the next role from {participants} to play. Only return the role.
         max_selector_attempts: int = 3,
         selector_func: Optional[SelectorFuncType] = None,
         candidate_func: Optional[CandidateFuncType] = None,
+        custom_message_types: List[type[AgentEvent | ChatMessage]] | None = None,
     ):
         super().__init__(
             participants,

@@ -465,6 +457,7 @@ Read the above conversation. Then select the next role from {participants} to play. Only return the role.
             termination_condition=termination_condition,
             max_turns=max_turns,
             runtime=runtime,
+            custom_message_types=custom_message_types,
         )
         # Validate the participants.
         if len(participants) < 2:

@@ -487,6 +480,7 @@ Read the above conversation. Then select the next role from {participants} to play. Only return the role.
         output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
         termination_condition: TerminationCondition | None,
         max_turns: int | None,
+        message_factory: MessageFactory,
     ) -> Callable[[], BaseGroupChatManager]:
         return lambda: SelectorGroupChatManager(
             name,

@@ -498,6 +492,7 @@ Read the above conversation. Then select the next role from {participants} to play. Only return the role.
             output_message_queue,
             termination_condition,
             max_turns,
+            message_factory,
             self._model_client,
             self._selector_prompt,
             self._allow_repeated_speaker,
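The selector's history construction above now delegates to each message's `to_model_text()`, so any `ChatMessage` subtype, including custom ones, can appear in the transcript without a per-type branch. A small illustration of the same pattern in isolation (the messages are hypothetical):

```python
from autogen_agentchat.messages import TextMessage

thread = [
    TextMessage(content="What is 2 + 2?", source="user"),
    TextMessage(content="4", source="assistant"),
]
# Mirrors the transcript formatting used by the selector above.
history = [f"{m.source}: {m.to_model_text()}".rstrip() + "\n\n" for m in thread]
print("".join(history))
```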
@@ -5,7 +5,7 @@ from autogen_core import AgentRuntime, Component, ComponentModel
 from pydantic import BaseModel

 from ...base import ChatAgent, TerminationCondition
-from ...messages import AgentEvent, ChatMessage, HandoffMessage
+from ...messages import AgentEvent, ChatMessage, HandoffMessage, MessageFactory
 from ...state import SwarmManagerState
 from ._base_group_chat import BaseGroupChat
 from ._base_group_chat_manager import BaseGroupChatManager

@@ -26,6 +26,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):
         output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
         termination_condition: TerminationCondition | None,
         max_turns: int | None,
+        message_factory: MessageFactory,
     ) -> None:
         super().__init__(
             name,

@@ -37,6 +38,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):
             output_message_queue,
             termination_condition,
             max_turns,
+            message_factory,
         )
         self._current_speaker = self._participant_names[0]

@@ -90,7 +92,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):

     async def save_state(self) -> Mapping[str, Any]:
         state = SwarmManagerState(
-            message_thread=list(self._message_thread),
+            message_thread=[msg.dump() for msg in self._message_thread],
             current_turn=self._current_turn,
             current_speaker=self._current_speaker,
         )

@@ -98,7 +100,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):

     async def load_state(self, state: Mapping[str, Any]) -> None:
         swarm_state = SwarmManagerState.model_validate(state)
-        self._message_thread = list(swarm_state.message_thread)
+        self._message_thread = [self._message_factory.create(message) for message in swarm_state.message_thread]
         self._current_turn = swarm_state.current_turn
         self._current_speaker = swarm_state.current_speaker

@@ -210,6 +212,7 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
         termination_condition: TerminationCondition | None = None,
         max_turns: int | None = None,
         runtime: AgentRuntime | None = None,
+        custom_message_types: List[type[AgentEvent | ChatMessage]] | None = None,
     ) -> None:
         super().__init__(
             participants,

@@ -218,6 +221,7 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
             termination_condition=termination_condition,
             max_turns=max_turns,
             runtime=runtime,
+            custom_message_types=custom_message_types,
         )
         # The first participant must be able to produce handoff messages.
         first_participant = self._participants[0]

@@ -235,6 +239,7 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
         output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
         termination_condition: TerminationCondition | None,
         max_turns: int | None,
+        message_factory: MessageFactory,
     ) -> Callable[[], SwarmGroupChatManager]:
         def _factory() -> SwarmGroupChatManager:
             return SwarmGroupChatManager(

@@ -247,6 +252,7 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
                 output_message_queue,
                 termination_condition,
                 max_turns,
+                message_factory,
             )

         return _factory
@@ -5,7 +5,7 @@ import time
 from inspect import iscoroutinefunction
 from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union, cast

-from autogen_core import CancellationToken, Image
+from autogen_core import CancellationToken
 from autogen_core.models import RequestUsage

 from autogen_agentchat.agents import UserProxyAgent

@@ -135,7 +135,11 @@ async def Console(
             duration = time.time() - start_time

             # Print final response.
-            output = f"{'-' * 10} {message.chat_message.source} {'-' * 10}\n{_message_to_str(message.chat_message, render_image_iterm=render_image_iterm)}\n"
+            if isinstance(message.chat_message, MultiModalMessage):
+                final_content = message.chat_message.to_text(iterm=render_image_iterm)
+            else:
+                final_content = message.chat_message.to_text()
+            output = f"{'-' * 10} {message.chat_message.source} {'-' * 10}\n{final_content}\n"
             if message.chat_message.models_usage:
                 if output_stats:
                     output += f"[Prompt tokens: {message.chat_message.models_usage.prompt_tokens}, Completion tokens: {message.chat_message.models_usage.completion_tokens}]\n"

@@ -171,16 +175,17 @@ async def Console(
                 # Print message sender.
                 await aprint(f"{'-' * 10} {message.source} {'-' * 10}", end="\n", flush=True)
                 if isinstance(message, ModelClientStreamingChunkEvent):
-                    await aprint(message.content, end="")
+                    await aprint(message.to_text(), end="")
                     streaming_chunks.append(message.content)
                 else:
                     if streaming_chunks:
                         streaming_chunks.clear()
                         # Chunked messages are already printed, so we just print a newline.
                         await aprint("", end="\n", flush=True)
+                    elif isinstance(message, MultiModalMessage):
+                        await aprint(message.to_text(iterm=render_image_iterm), end="\n", flush=True)
                     else:
                         # Print message content.
-                        await aprint(_message_to_str(message, render_image_iterm=render_image_iterm), end="\n", flush=True)
+                        await aprint(message.to_text(), end="\n", flush=True)
                     if message.models_usage:
                         if output_stats:
                             await aprint(
@ -195,25 +200,3 @@ async def Console(
|
|||
raise ValueError("No TaskResult or Response was processed.")
|
||||
|
||||
return last_processed
|
||||
|
||||
|
||||
# iTerm2 image rendering protocol: https://iterm2.com/documentation-images.html
|
||||
def _image_to_iterm(image: Image) -> str:
|
||||
image_data = image.to_base64()
|
||||
return f"\033]1337;File=inline=1:{image_data}\a\n"
|
||||
|
||||
|
||||
def _message_to_str(message: AgentEvent | ChatMessage, *, render_image_iterm: bool = False) -> str:
|
||||
if isinstance(message, MultiModalMessage):
|
||||
result: List[str] = []
|
||||
for c in message.content:
|
||||
if isinstance(c, str):
|
||||
result.append(c)
|
||||
else:
|
||||
if render_image_iterm:
|
||||
result.append(_image_to_iterm(c))
|
||||
else:
|
||||
result.append("<image>")
|
||||
return "\n".join(result)
|
||||
else:
|
||||
return f"{message.content}"
|
||||
|
|
|
@ -2,18 +2,24 @@ from typing import List, Union
|
|||
|
||||
from autogen_core import FunctionCall, Image
|
||||
from autogen_core.models import FunctionExecutionResult, LLMMessage, UserMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Type aliases for convenience
|
||||
_StructuredContent = BaseModel
|
||||
_UserContent = Union[str, List[Union[str, Image]]]
|
||||
_AssistantContent = Union[str, List[FunctionCall]]
|
||||
_FunctionExecutionContent = List[FunctionExecutionResult]
|
||||
_SystemContent = str
|
||||
|
||||
|
||||
def content_to_str(content: _UserContent | _AssistantContent | _FunctionExecutionContent | _SystemContent) -> str:
|
||||
def content_to_str(
|
||||
content: _UserContent | _AssistantContent | _FunctionExecutionContent | _SystemContent | _StructuredContent,
|
||||
) -> str:
|
||||
"""Convert the content of an LLMMessage to a string."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, BaseModel):
|
||||
return content.model_dump_json()
|
||||
else:
|
||||
result: List[str] = []
|
||||
for c in content:
|
||||
|
|
|
@ -12,6 +12,7 @@ from autogen_agentchat.messages import (
|
|||
MemoryQueryEvent,
|
||||
ModelClientStreamingChunkEvent,
|
||||
MultiModalMessage,
|
||||
StructuredMessage,
|
||||
TextMessage,
|
||||
ThoughtEvent,
|
||||
ToolCallExecutionEvent,
|
||||
|
@ -624,6 +625,23 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
assert len(result.messages) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_structured_task() -> None:
|
||||
class InputTask(BaseModel):
|
||||
input: str
|
||||
data: List[str]
|
||||
|
||||
model_client = ReplayChatCompletionClient(["Hello"])
|
||||
agent = AssistantAgent(
|
||||
name="assistant",
|
||||
model_client=model_client,
|
||||
)
|
||||
|
||||
task = StructuredMessage[InputTask](content=InputTask(input="Test", data=["Test1", "Test2"]), source="user")
|
||||
result = await agent.run(task=task)
|
||||
assert len(result.messages) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_model_capabilities() -> None:
|
||||
model = "random-model"
|
||||
|
@ -896,6 +914,7 @@ async def test_model_client_stream() -> None:
|
|||
chunks: List[str] = []
|
||||
async for message in agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
assert isinstance(message.messages[-1], TextMessage)
|
||||
assert message.messages[-1].content == "Response to message 3"
|
||||
elif isinstance(message, ModelClientStreamingChunkEvent):
|
||||
chunks.append(message.content)
|
||||
|
@ -929,11 +948,14 @@ async def test_model_client_stream_with_tool_calls() -> None:
|
|||
chunks: List[str] = []
|
||||
async for message in agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
assert isinstance(message.messages[-1], TextMessage)
|
||||
assert isinstance(message.messages[1], ToolCallRequestEvent)
|
||||
assert message.messages[-1].content == "Example response 2 to task"
|
||||
assert message.messages[1].content == [
|
||||
FunctionCall(id="1", name="_pass_function", arguments=r'{"input": "task"}'),
|
||||
FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'),
|
||||
]
|
||||
assert isinstance(message.messages[2], ToolCallExecutionEvent)
|
||||
assert message.messages[2].content == [
|
||||
FunctionExecutionResult(call_id="1", content="pass", is_error=False, name="_pass_function"),
|
||||
FunctionExecutionResult(call_id="3", content="task", is_error=False, name="_echo_function"),
|
||||
|
|
|
@ -20,6 +20,7 @@ from autogen_agentchat.messages import (
|
|||
HandoffMessage,
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
StructuredMessage,
|
||||
TextMessage,
|
||||
ToolCallExecutionEvent,
|
||||
ToolCallRequestEvent,
|
||||
|
@ -44,6 +45,7 @@ from autogen_core.tools import FunctionTool
|
|||
from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
from autogen_ext.models.replay import ReplayChatCompletionClient
|
||||
from pydantic import BaseModel
|
||||
from utils import FileLogHandler
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
@ -101,6 +103,34 @@ class _FlakyAgent(BaseChatAgent):
|
|||
self._last_message = None
|
||||
|
||||
|
||||
class _UnknownMessageType(ChatMessage):
|
||||
content: str
|
||||
|
||||
def to_model_message(self) -> UserMessage:
|
||||
raise NotImplementedError("This message type is not supported.")
|
||||
|
||||
def to_model_text(self) -> str:
|
||||
raise NotImplementedError("This message type is not supported.")
|
||||
|
||||
def to_text(self) -> str:
|
||||
raise NotImplementedError("This message type is not supported.")
|
||||
|
||||
|
||||
class _UnknownMessageTypeAgent(BaseChatAgent):
|
||||
def __init__(self, name: str, description: str) -> None:
|
||||
super().__init__(name, description)
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
return (_UnknownMessageType,)
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
return Response(chat_message=_UnknownMessageType(content="Unknown message type", source=self.name))
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class _StopAgent(_EchoAgent):
|
||||
def __init__(self, name: str, description: str, *, stop_at: int = 1) -> None:
|
||||
super().__init__(name, description)
|
||||
|
@ -122,6 +152,19 @@ def _pass_function(input: str) -> str:
|
|||
return "pass"
|
||||
|
||||
|
||||
class _InputTask1(BaseModel):
|
||||
task: str
|
||||
data: List[str]
|
||||
|
||||
|
||||
class _InputTask2(BaseModel):
|
||||
task: str
|
||||
data: str
|
||||
|
||||
|
||||
TaskType = str | List[ChatMessage] | ChatMessage
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(params=["single_threaded", "embedded"]) # type: ignore
|
||||
async def runtime(request: pytest.FixtureRequest) -> AsyncGenerator[AgentRuntime | None, None]:
|
||||
if request.param == "single_threaded":
|
||||
|
@ -164,14 +207,11 @@ async def test_round_robin_group_chat(runtime: AgentRuntime | None) -> None:
|
|||
"Hello, world!",
|
||||
"TERMINATE",
|
||||
]
|
||||
# Normalize the messages to remove \r\n and any leading/trailing whitespace.
|
||||
normalized_messages = [
|
||||
msg.content.replace("\r\n", "\n").rstrip("\n") if isinstance(msg.content, str) else msg.content
|
||||
for msg in result.messages
|
||||
]
|
||||
|
||||
# Assert that all expected messages are in the collected messages
|
||||
assert normalized_messages == expected_messages
|
||||
for i in range(len(expected_messages)):
|
||||
produced_message = result.messages[i]
|
||||
assert isinstance(produced_message, TextMessage)
|
||||
content = produced_message.content.replace("\r\n", "\n").rstrip("\n")
|
||||
assert content == expected_messages[i]
|
||||
|
||||
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"
|
||||
|
||||
|
@ -202,28 +242,89 @@ async def test_round_robin_group_chat(runtime: AgentRuntime | None) -> None:
|
|||
model_client.reset()
|
||||
index = 0
|
||||
await team.reset()
|
||||
result_2 = await team.run(
|
||||
task=MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")
|
||||
)
|
||||
assert result.messages[0].content == result_2.messages[0].content[0]
|
||||
task = MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")
|
||||
result_2 = await team.run(task=task)
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result_2.messages[0], MultiModalMessage)
|
||||
assert result.messages[0].content == task.content[0]
|
||||
assert result.messages[1:] == result_2.messages[1:]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_round_robin_group_chat_state(runtime: AgentRuntime | None) -> None:
|
||||
async def test_round_robin_group_chat_unknown_task_message_type(runtime: AgentRuntime | None) -> None:
|
||||
model_client = ReplayChatCompletionClient([])
|
||||
agent1 = AssistantAgent("agent1", model_client=model_client)
|
||||
agent2 = AssistantAgent("agent2", model_client=model_client)
|
||||
termination = TextMentionTermination("TERMINATE")
|
||||
team1 = RoundRobinGroupChat(
|
||||
participants=[agent1, agent2],
|
||||
termination_condition=termination,
|
||||
runtime=runtime,
|
||||
custom_message_types=[StructuredMessage[_InputTask2]],
|
||||
)
|
||||
with pytest.raises(ValueError, match=r"Message type .*StructuredMessage\[_InputTask1\].* is not registered"):
|
||||
await team1.run(
|
||||
task=StructuredMessage[_InputTask1](
|
||||
content=_InputTask1(task="Write a program that prints 'Hello, world!'", data=["a", "b", "c"]),
|
||||
source="user",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_round_robin_group_chat_unknown_agent_message_type() -> None:
|
||||
model_client = ReplayChatCompletionClient(["Hello"])
|
||||
agent1 = AssistantAgent("agent1", model_client=model_client)
|
||||
agent2 = _UnknownMessageTypeAgent("agent2", "I am an unknown message type agent")
|
||||
termination = TextMentionTermination("TERMINATE")
|
||||
team1 = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination)
|
||||
with pytest.raises(ValueError, match="Message type .*UnknownMessageType.* not registered"):
|
||||
await team1.run(task=TextMessage(content="Write a program that prints 'Hello, world!'", source="user"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"task",
|
||||
[
|
||||
"Write a program that prints 'Hello, world!'",
|
||||
[TextMessage(content="Write a program that prints 'Hello, world!'", source="user")],
|
||||
[MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")],
|
||||
[
|
||||
StructuredMessage[_InputTask1](
|
||||
content=_InputTask1(task="Write a program that prints 'Hello, world!'", data=["a", "b", "c"]),
|
||||
source="user",
|
||||
),
|
||||
StructuredMessage[_InputTask2](
|
||||
content=_InputTask2(task="Write a program that prints 'Hello, world!'", data="a"), source="user"
|
||||
),
|
||||
],
|
||||
],
|
||||
ids=["text", "text_message", "multi_modal_message", "structured_message"],
|
||||
)
|
||||
async def test_round_robin_group_chat_state(task: TaskType, runtime: AgentRuntime | None) -> None:
|
||||
model_client = ReplayChatCompletionClient(
|
||||
["No facts", "No plan", "print('Hello, world!')", "TERMINATE"],
|
||||
)
|
||||
agent1 = AssistantAgent("agent1", model_client=model_client)
|
||||
agent2 = AssistantAgent("agent2", model_client=model_client)
|
||||
termination = TextMentionTermination("TERMINATE")
|
||||
team1 = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination, runtime=runtime)
|
||||
await team1.run(task="Write a program that prints 'Hello, world!'")
|
||||
team1 = RoundRobinGroupChat(
|
||||
participants=[agent1, agent2],
|
||||
termination_condition=termination,
|
||||
runtime=runtime,
|
||||
custom_message_types=[StructuredMessage[_InputTask1], StructuredMessage[_InputTask2]],
|
||||
)
|
||||
await team1.run(task=task)
|
||||
state = await team1.save_state()
|
||||
|
||||
agent3 = AssistantAgent("agent1", model_client=model_client)
|
||||
agent4 = AssistantAgent("agent2", model_client=model_client)
|
||||
team2 = RoundRobinGroupChat(participants=[agent3, agent4], termination_condition=termination, runtime=runtime)
|
||||
team2 = RoundRobinGroupChat(
|
||||
participants=[agent3, agent4],
|
||||
termination_condition=termination,
|
||||
runtime=runtime,
|
||||
custom_message_types=[StructuredMessage[_InputTask1], StructuredMessage[_InputTask2]],
|
||||
)
|
||||
await team2.load_state(state)
|
||||
state2 = await team2.save_state()
|
||||
assert state == state2
|
||||
|
@ -453,6 +554,7 @@ async def test_selector_group_chat(runtime: AgentRuntime | None) -> None:
|
|||
task="Write a program that prints 'Hello, world!'",
|
||||
)
|
||||
assert len(result.messages) == 6
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
||||
assert result.messages[1].source == "agent3"
|
||||
assert result.messages[2].source == "agent2"
|
||||
|
@ -485,7 +587,25 @@ async def test_selector_group_chat(runtime: AgentRuntime | None) -> None:
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_selector_group_chat_state(runtime: AgentRuntime | None) -> None:
|
||||
@pytest.mark.parametrize(
|
||||
"task",
|
||||
[
|
||||
"Write a program that prints 'Hello, world!'",
|
||||
[TextMessage(content="Write a program that prints 'Hello, world!'", source="user")],
|
||||
[MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")],
|
||||
[
|
||||
StructuredMessage[_InputTask1](
|
||||
content=_InputTask1(task="Write a program that prints 'Hello, world!'", data=["a", "b", "c"]),
|
||||
source="user",
|
||||
),
|
||||
StructuredMessage[_InputTask2](
|
||||
content=_InputTask2(task="Write a program that prints 'Hello, world!'", data="a"), source="user"
|
||||
),
|
||||
],
|
||||
],
|
||||
ids=["text", "text_message", "multi_modal_message", "structured_message"],
|
||||
)
|
||||
async def test_selector_group_chat_state(task: TaskType, runtime: AgentRuntime | None) -> None:
|
||||
model_client = ReplayChatCompletionClient(
|
||||
["agent1", "No facts", "agent2", "No plan", "agent1", "print('Hello, world!')", "agent2", "TERMINATE"],
|
||||
)
|
||||
|
@ -497,14 +617,18 @@ async def test_selector_group_chat_state(runtime: AgentRuntime | None) -> None:
|
|||
termination_condition=termination,
|
||||
model_client=model_client,
|
||||
runtime=runtime,
|
||||
custom_message_types=[StructuredMessage[_InputTask1], StructuredMessage[_InputTask2]],
|
||||
)
|
||||
await team1.run(task="Write a program that prints 'Hello, world!'")
|
||||
await team1.run(task=task)
|
||||
state = await team1.save_state()
|
||||
|
||||
agent3 = AssistantAgent("agent1", model_client=model_client)
|
||||
agent4 = AssistantAgent("agent2", model_client=model_client)
|
||||
team2 = SelectorGroupChat(
|
||||
participants=[agent3, agent4], termination_condition=termination, model_client=model_client
|
||||
participants=[agent3, agent4],
|
||||
termination_condition=termination,
|
||||
model_client=model_client,
|
||||
custom_message_types=[StructuredMessage[_InputTask1], StructuredMessage[_InputTask2]],
|
||||
)
|
||||
await team2.load_state(state)
|
||||
state2 = await team2.save_state()
|
||||
|
@ -545,6 +669,7 @@ async def test_selector_group_chat_two_speakers(runtime: AgentRuntime | None) ->
|
|||
task="Write a program that prints 'Hello, world!'",
|
||||
)
|
||||
assert len(result.messages) == 5
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
||||
assert result.messages[1].source == "agent2"
|
||||
assert result.messages[2].source == "agent1"
|
||||
|
@ -594,6 +719,7 @@ async def test_selector_group_chat_two_speakers_allow_repeated(runtime: AgentRun
|
|||
)
|
||||
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
||||
assert len(result.messages) == 4
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
||||
assert result.messages[1].source == "agent2"
|
||||
assert result.messages[2].source == "agent2"
|
||||
|
@ -635,6 +761,7 @@ async def test_selector_group_chat_succcess_after_2_attempts(runtime: AgentRunti
|
|||
)
|
||||
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
||||
assert len(result.messages) == 2
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
||||
assert result.messages[1].source == "agent2"
|
||||
|
||||
|
@ -659,6 +786,7 @@ async def test_selector_group_chat_fall_back_to_first_after_3_attempts(runtime:
|
|||
)
|
||||
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
||||
assert len(result.messages) == 2
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
||||
assert result.messages[1].source == "agent1"
|
||||
|
||||
|
@ -679,6 +807,7 @@ async def test_selector_group_chat_fall_back_to_previous_after_3_attempts(runtim
|
|||
)
|
||||
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
||||
assert len(result.messages) == 3
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
||||
assert result.messages[1].source == "agent2"
|
||||
assert result.messages[2].source == "agent2"
|
||||
|
@ -796,6 +925,12 @@ async def test_swarm_handoff(runtime: AgentRuntime | None) -> None:
|
|||
team = Swarm([second_agent, first_agent, third_agent], termination_condition=termination, runtime=runtime)
|
||||
result = await team.run(task="task")
|
||||
assert len(result.messages) == 6
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result.messages[1], HandoffMessage)
|
||||
assert isinstance(result.messages[2], HandoffMessage)
|
||||
assert isinstance(result.messages[3], HandoffMessage)
|
||||
assert isinstance(result.messages[4], HandoffMessage)
|
||||
assert isinstance(result.messages[5], HandoffMessage)
|
||||
assert result.messages[0].content == "task"
|
||||
assert result.messages[1].content == "Transferred to third_agent."
|
||||
assert result.messages[2].content == "Transferred to first_agent."
|
||||
|
@ -839,6 +974,65 @@ async def test_swarm_handoff(runtime: AgentRuntime | None) -> None:
|
|||
assert manager_1._current_speaker == manager_2._current_speaker # pyright: ignore
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"task",
|
||||
[
|
||||
"Write a program that prints 'Hello, world!'",
|
||||
[TextMessage(content="Write a program that prints 'Hello, world!'", source="user")],
|
||||
[MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")],
|
||||
[
|
||||
StructuredMessage[_InputTask1](
|
||||
content=_InputTask1(task="Write a program that prints 'Hello, world!'", data=["a", "b", "c"]),
|
||||
source="user",
|
||||
),
|
||||
StructuredMessage[_InputTask2](
|
||||
content=_InputTask2(task="Write a program that prints 'Hello, world!'", data="a"), source="user"
|
||||
),
|
||||
],
|
||||
],
|
||||
ids=["text", "text_message", "multi_modal_message", "structured_message"],
|
||||
)
|
||||
async def test_swarm_handoff_state(task: TaskType, runtime: AgentRuntime | None) -> None:
|
||||
first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
|
||||
second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
|
||||
third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
|
||||
|
||||
termination = MaxMessageTermination(6)
|
||||
team1 = Swarm(
|
||||
[second_agent, first_agent, third_agent],
|
||||
termination_condition=termination,
|
||||
runtime=runtime,
|
||||
custom_message_types=[StructuredMessage[_InputTask1], StructuredMessage[_InputTask2]],
|
||||
)
|
||||
await team1.run(task=task)
|
||||
state = await team1.save_state()
|
||||
|
||||
first_agent2 = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
|
||||
second_agent2 = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
|
||||
third_agent2 = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
|
||||
team2 = Swarm(
|
||||
[second_agent2, first_agent2, third_agent2],
|
||||
termination_condition=termination,
|
||||
runtime=runtime,
|
||||
custom_message_types=[StructuredMessage[_InputTask1], StructuredMessage[_InputTask2]],
|
||||
)
|
||||
await team2.load_state(state)
|
||||
state2 = await team2.save_state()
|
||||
assert state == state2
|
||||
|
||||
manager_1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
||||
AgentId(f"{team1._group_chat_manager_name}_{team1._team_id}", team1._team_id), # pyright: ignore
|
||||
SwarmGroupChatManager, # pyright: ignore
|
||||
)
|
||||
manager_2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
||||
AgentId(f"{team2._group_chat_manager_name}_{team2._team_id}", team2._team_id), # pyright: ignore
|
||||
SwarmGroupChatManager, # pyright: ignore
|
||||
)
|
||||
assert manager_1._message_thread == manager_2._message_thread # pyright: ignore
|
||||
assert manager_1._current_speaker == manager_2._current_speaker # pyright: ignore
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_swarm_handoff_using_tool_calls(runtime: AgentRuntime | None) -> None:
|
||||
model_client = ReplayChatCompletionClient(
|
||||
|
@ -870,9 +1064,14 @@ async def test_swarm_handoff_using_tool_calls(runtime: AgentRuntime | None) -> N
|
|||
team = Swarm([agent1, agent2], termination_condition=termination, runtime=runtime)
|
||||
result = await team.run(task="task")
|
||||
assert len(result.messages) == 7
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert result.messages[0].content == "task"
|
||||
assert isinstance(result.messages[1], ToolCallRequestEvent)
|
||||
assert isinstance(result.messages[2], ToolCallExecutionEvent)
|
||||
assert isinstance(result.messages[3], HandoffMessage)
|
||||
assert isinstance(result.messages[4], HandoffMessage)
|
||||
assert isinstance(result.messages[5], TextMessage)
|
||||
assert isinstance(result.messages[6], TextMessage)
|
||||
assert result.messages[3].content == "handoff to agent2"
|
||||
assert result.messages[4].content == "Transferred to agent1."
|
||||
assert result.messages[5].content == "Hello"
|
||||
|
@ -910,18 +1109,23 @@ async def test_swarm_pause_and_resume(runtime: AgentRuntime | None) -> None:
|
|||
team = Swarm([second_agent, first_agent, third_agent], max_turns=1, runtime=runtime)
|
||||
result = await team.run(task="task")
|
||||
assert len(result.messages) == 2
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result.messages[1], HandoffMessage)
|
||||
assert result.messages[0].content == "task"
|
||||
assert result.messages[1].content == "Transferred to third_agent."
|
||||
|
||||
# Resume with a new task.
|
||||
result = await team.run(task="new task")
|
||||
assert len(result.messages) == 2
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result.messages[1], HandoffMessage)
|
||||
assert result.messages[0].content == "new task"
|
||||
assert result.messages[1].content == "Transferred to first_agent."
|
||||
|
||||
# Resume with the same task.
|
||||
result = await team.run()
|
||||
assert len(result.messages) == 1
|
||||
assert isinstance(result.messages[0], HandoffMessage)
|
||||
assert result.messages[0].content == "Transferred to second_agent."
|
||||
|
||||
|
||||
|
@ -996,8 +1200,10 @@ async def test_swarm_with_parallel_tool_calls(runtime: AgentRuntime | None) -> N
|
|||
source="agent1",
|
||||
context=expected_handoff_context,
|
||||
)
|
||||
assert isinstance(result.messages[4], TextMessage)
|
||||
assert result.messages[4].content == "Hello"
|
||||
assert result.messages[4].source == "agent2"
|
||||
assert isinstance(result.messages[5], TextMessage)
|
||||
assert result.messages[5].content == "TERMINATE"
|
||||
assert result.messages[5].source == "agent2"
|
||||
|
||||
|
@ -1020,17 +1226,26 @@ async def test_swarm_with_handoff_termination(runtime: AgentRuntime | None) -> N
|
|||
# Start
|
||||
result = await team.run(task="task")
|
||||
assert len(result.messages) == 2
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result.messages[1], HandoffMessage)
|
||||
assert result.messages[0].content == "task"
|
||||
assert result.messages[1].content == "Transferred to third_agent."
|
||||
# Resume existing.
|
||||
result = await team.run()
|
||||
assert len(result.messages) == 3
|
||||
assert isinstance(result.messages[0], HandoffMessage)
|
||||
assert isinstance(result.messages[1], HandoffMessage)
|
||||
assert isinstance(result.messages[2], HandoffMessage)
|
||||
assert result.messages[0].content == "Transferred to first_agent."
|
||||
assert result.messages[1].content == "Transferred to second_agent."
|
||||
assert result.messages[2].content == "Transferred to third_agent."
|
||||
# Resume new task.
|
||||
result = await team.run(task="new task")
|
||||
assert len(result.messages) == 4
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result.messages[1], HandoffMessage)
|
||||
assert isinstance(result.messages[2], HandoffMessage)
|
||||
assert isinstance(result.messages[3], HandoffMessage)
|
||||
assert result.messages[0].content == "new task"
|
||||
assert result.messages[1].content == "Transferred to first_agent."
|
||||
assert result.messages[2].content == "Transferred to second_agent."
|
||||
|
@ -1043,6 +1258,9 @@ async def test_swarm_with_handoff_termination(runtime: AgentRuntime | None) -> N
|
|||
# Start
|
||||
result = await team.run(task="task")
|
||||
assert len(result.messages) == 3
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result.messages[1], HandoffMessage)
|
||||
assert isinstance(result.messages[2], HandoffMessage)
|
||||
assert result.messages[0].content == "task"
|
||||
assert result.messages[1].content == "Transferred to third_agent."
|
||||
assert result.messages[2].content == "Transferred to non_existing_agent."
|
||||
|
@ -1055,6 +1273,10 @@ async def test_swarm_with_handoff_termination(runtime: AgentRuntime | None) -> N
|
|||
# Resume with a HandoffMessage
|
||||
result = await team.run(task=HandoffMessage(content="Handoff to first_agent.", target="first_agent", source="user"))
|
||||
assert len(result.messages) == 4
|
||||
assert isinstance(result.messages[0], HandoffMessage)
|
||||
assert isinstance(result.messages[1], HandoffMessage)
|
||||
assert isinstance(result.messages[2], HandoffMessage)
|
||||
assert isinstance(result.messages[3], HandoffMessage)
|
||||
assert result.messages[0].content == "Handoff to first_agent."
|
||||
assert result.messages[1].content == "Transferred to second_agent."
|
||||
assert result.messages[2].content == "Transferred to third_agent."
|
||||
|
@ -1081,6 +1303,10 @@ async def test_round_robin_group_chat_with_message_list(runtime: AgentRuntime |
|
|||
|
||||
# Verify the messages were processed in order
|
||||
assert len(result.messages) == 4 # Initial messages + echo until termination
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result.messages[1], TextMessage)
|
||||
assert isinstance(result.messages[2], TextMessage)
|
||||
assert isinstance(result.messages[3], TextMessage)
|
||||
assert result.messages[0].content == "Message 1" # First message
|
||||
assert result.messages[1].content == "Message 2" # Second message
|
||||
assert result.messages[2].content == "Message 3" # Third message
|
||||
|
|
|
@ -4,10 +4,7 @@ from typing import List, Sequence
|
|||
import pytest
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.base import TaskResult
|
||||
from autogen_agentchat.messages import (
|
||||
AgentEvent,
|
||||
ChatMessage,
|
||||
)
|
||||
from autogen_agentchat.messages import AgentEvent, ChatMessage
|
||||
from autogen_agentchat.teams import SelectorGroupChat
|
||||
from autogen_agentchat.ui import Console
|
||||
from autogen_core.models import ChatCompletionClient
|
||||
|
|
|
@ -134,8 +134,8 @@ async def test_magentic_one_group_chat_basic(runtime: AgentRuntime | None) -> No
|
|||
)
|
||||
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
||||
assert len(result.messages) == 5
|
||||
assert result.messages[2].content == "Continue task"
|
||||
assert result.messages[4].content == "print('Hello, world!')"
|
||||
assert result.messages[2].to_text() == "Continue task"
|
||||
assert result.messages[4].to_text() == "print('Hello, world!')"
|
||||
assert result.stop_reason is not None and result.stop_reason == "Because"
|
||||
|
||||
# Test save and load.
|
||||
|
@ -214,8 +214,8 @@ async def test_magentic_one_group_chat_with_stalls(runtime: AgentRuntime | None)
|
|||
)
|
||||
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
||||
assert len(result.messages) == 6
|
||||
assert isinstance(result.messages[1].content, str)
|
||||
assert isinstance(result.messages[1], TextMessage)
|
||||
assert result.messages[1].content.startswith("\nWe are working to address the following user request:")
|
||||
assert isinstance(result.messages[4].content, str)
|
||||
assert isinstance(result.messages[4], TextMessage)
|
||||
assert result.messages[4].content.startswith("\nWe are working to address the following user request:")
|
||||
assert result.stop_reason is not None and result.stop_reason == "test"
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
import pytest
|
||||
from autogen_agentchat.messages import HandoffMessage, MessageFactory, StructuredMessage, TextMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TestContent(BaseModel):
|
||||
"""Test content model."""
|
||||
|
||||
field1: str
|
||||
field2: int
|
||||
|
||||
|
||||
def test_structured_message() -> None:
|
||||
# Create a structured message with the test content
|
||||
message = StructuredMessage[TestContent](
|
||||
source="test_agent",
|
||||
content=TestContent(field1="test", field2=42),
|
||||
)
|
||||
|
||||
# Check that the message type is correct
|
||||
assert message.type == "StructuredMessage[TestContent]" # type: ignore
|
||||
|
||||
# Check that the content is of the correct type
|
||||
assert isinstance(message.content, TestContent)
|
||||
|
||||
# Check that the content fields are set correctly
|
||||
assert message.content.field1 == "test"
|
||||
assert message.content.field2 == 42
|
||||
|
||||
# Check that model_dump works correctly
|
||||
dumped_message = message.model_dump()
|
||||
assert dumped_message["source"] == "test_agent"
|
||||
assert dumped_message["content"]["field1"] == "test"
|
||||
assert dumped_message["content"]["field2"] == 42
|
||||
assert dumped_message["type"] == "StructuredMessage[TestContent]"
|
||||
|
||||
|
||||
def test_message_factory() -> None:
|
||||
factory = MessageFactory()
|
||||
|
||||
# Text message data
|
||||
text_data = {
|
||||
"type": "TextMessage",
|
||||
"source": "test_agent",
|
||||
"content": "Hello, world!",
|
||||
}
|
||||
|
||||
# Create a TextMessage instance
|
||||
text_message = factory.create(text_data)
|
||||
assert isinstance(text_message, TextMessage)
|
||||
assert text_message.source == "test_agent"
|
||||
assert text_message.content == "Hello, world!"
|
||||
assert text_message.type == "TextMessage" # type: ignore
|
||||
|
||||
# Handoff message data
|
||||
handoff_data = {
|
||||
"type": "HandoffMessage",
|
||||
"source": "test_agent",
|
||||
"content": "handoff to another agent",
|
||||
"target": "target_agent",
|
||||
}
|
||||
|
||||
# Create a HandoffMessage instance
|
||||
handoff_message = factory.create(handoff_data)
|
||||
assert isinstance(handoff_message, HandoffMessage)
|
||||
assert handoff_message.source == "test_agent"
|
||||
assert handoff_message.content == "handoff to another agent"
|
||||
assert handoff_message.target == "target_agent"
|
||||
assert handoff_message.type == "HandoffMessage" # type: ignore
|
||||
|
||||
# Structured message data
|
||||
structured_data = {
|
||||
"type": "StructuredMessage[TestContent]",
|
||||
"source": "test_agent",
|
||||
"content": {
|
||||
"field1": "test",
|
||||
"field2": 42,
|
||||
},
|
||||
}
|
||||
# Create a StructuredMessage instance -- this will fail because the type
|
||||
# is not registered in the factory.
|
||||
with pytest.raises(ValueError):
|
||||
structured_message = factory.create(structured_data)
|
||||
# Register the StructuredMessage type in the factory
|
||||
factory.register(StructuredMessage[TestContent])
|
||||
# Create a StructuredMessage instance
|
||||
structured_message = factory.create(structured_data)
|
||||
assert isinstance(structured_message, StructuredMessage)
|
||||
assert isinstance(structured_message.content, TestContent) # type: ignore
|
||||
assert structured_message.source == "test_agent"
|
||||
assert structured_message.content.field1 == "test"
|
||||
assert structured_message.content.field2 == 42
|
||||
assert structured_message.type == "StructuredMessage[TestContent]" # type: ignore
|
File diff suppressed because it is too large
|
@ -691,7 +691,7 @@ async def main() -> None:
|
|||
if user_input == "exit":
|
||||
break
|
||||
response = await assistant.on_messages([TextMessage(content=user_input, source="user")], CancellationToken())
|
||||
print("Assistant:", response.chat_message.content)
|
||||
print("Assistant:", response.chat_message.to_text())
|
||||
await model_client.close()
|
||||
|
||||
asyncio.run(main())
|
||||
|
@ -1331,7 +1331,7 @@ async def main() -> None:
|
|||
if user_input == "exit":
|
||||
break
|
||||
response = await assistant.on_messages([TextMessage(content=user_input, source="user")], CancellationToken())
|
||||
print("Assistant:", response.chat_message.content)
|
||||
print("Assistant:", response.chat_message.to_text())
|
||||
|
||||
await model_client.close()
|
||||
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -1,403 +1,402 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Tracing and Observability\n",
|
||||
"\n",
|
||||
"AutoGen has [built-in support for tracing](https://microsoft.github.io/autogen/dev/user-guide/core-user-guide/framework/telemetry.html) and observability for collecting comprehensive records on the execution of your application. This feature is useful for debugging, performance analysis, and understanding the flow of your application.\n",
|
||||
"\n",
|
||||
"This capability is powered by the [OpenTelemetry](https://opentelemetry.io/) library, which means you can use any OpenTelemetry-compatible backend to collect and analyze traces.\n",
|
||||
"\n",
|
||||
"## Setup\n",
|
||||
"\n",
|
||||
"To begin, you need to install the OpenTelemetry Python package. You can do this using pip:\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"pip install opentelemetry-sdk\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Once you have the SDK installed, the simplest way to set up tracing in AutoGen is to:\n",
|
||||
"\n",
|
||||
"1. Configure an OpenTelemetry tracer provider\n",
|
||||
"2. Set up an exporter to send traces to your backend\n",
|
||||
"3. Connect the tracer provider to the AutoGen runtime\n",
|
||||
"\n",
|
||||
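"\n",
|
||||
"For step 3, a minimal sketch (assuming the `tracer_provider` configured in the next cell) is to pass the provider to the runtime your agents and teams run on:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"from autogen_core import SingleThreadedAgentRuntime\n",
|
||||
"\n",
|
||||
"# a sketch: attach the tracer provider to an embedded runtime;\n",
|
||||
"# teams accept this runtime through their `runtime` parameter.\n",
|
||||
"runtime = SingleThreadedAgentRuntime(tracer_provider=tracer_provider)\n",
|
||||
"```\n",
|
||||
"\n",
|
||||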
"## Telemetry Backend\n",
|
||||
"\n",
|
||||
"To collect and view traces, you need to set up a telemetry backend. Several open-source options are available, including Jaeger, Zipkin. For this example, we will use Jaeger as our telemetry backend.\n",
|
||||
"\n",
|
||||
"For a quick start, you can run Jaeger locally using Docker:\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"docker run -d --name jaeger \\\n",
|
||||
" -e COLLECTOR_OTLP_ENABLED=true \\\n",
|
||||
" -p 16686:16686 \\\n",
|
||||
" -p 4317:4317 \\\n",
|
||||
" -p 4318:4318 \\\n",
|
||||
" jaegertracing/all-in-one:latest\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"This command starts a Jaeger instance that listens on port 16686 for the Jaeger UI and port 4317 for the OpenTelemetry collector. You can access the Jaeger UI at `http://localhost:16686`.\n",
|
||||
"\n",
|
||||
"## Instrumenting an AgentChat Team\n",
|
||||
"\n",
|
||||
"In the following section, we will review how to enable tracing with an AutoGen GroupChat team. The AutoGen runtime already supports open telemetry (automatically logging message metadata). To begin, we will create a tracing service that will be used to instrument the AutoGen runtime. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from opentelemetry import trace\n",
|
||||
"from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter\n",
|
||||
"from opentelemetry.sdk.resources import Resource\n",
|
||||
"from opentelemetry.sdk.trace import TracerProvider\n",
|
||||
"from opentelemetry.sdk.trace.export import BatchSpanProcessor\n",
|
||||
"\n",
|
||||
"otel_exporter = OTLPSpanExporter(endpoint=\"http://localhost:4317\", insecure=True)\n",
|
||||
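"# 4317 is the OTLP gRPC port exposed by the Jaeger container above\n",
|
||||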
"tracer_provider = TracerProvider(resource=Resource({\"service.name\": \"autogen-test-agentchat\"}))\n",
|
||||
"span_processor = BatchSpanProcessor(otel_exporter)\n",
|
||||
"tracer_provider.add_span_processor(span_processor)\n",
|
||||
"trace.set_tracer_provider(tracer_provider)\n",
|
||||
"\n",
|
||||
"# we will get reference this tracer later using its service name\n",
|
||||
"# tracer = trace.get_tracer(\"autogen-test-agentchat\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"All of the code to create a [team](./tutorial/teams.ipynb) should already be familiar to you. An important note here is that all AgentChat agents and teams are run using the AutoGen core API runtime. In turn, the runtime is already instrumented to log [runtime messaging events (metadata)] (https://github.com/microsoft/autogen/blob/main/python/packages/autogen-core/src/autogen_core/_telemetry/_tracing_config.py) including:\n",
|
||||
"\n",
|
||||
"- **create**: When a message is created\n",
|
||||
"- **send**: When a message is sent\n",
|
||||
"- **publish**: When a message is published\n",
|
||||
"- **receive**: When a message is received\n",
|
||||
"- **intercept**: When a message is intercepted\n",
|
||||
"- **process**: When a message is processed\n",
|
||||
"- **ack**: When a message is acknowledged \n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from autogen_agentchat.agents import AssistantAgent\n",
|
||||
"from autogen_agentchat.conditions import MaxMessageTermination, TextMentionTermination\n",
|
||||
"from autogen_agentchat.teams import SelectorGroupChat\n",
|
||||
"from autogen_agentchat.ui import Console\n",
|
||||
"from autogen_core import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def search_web_tool(query: str) -> str:\n",
|
||||
" if \"2006-2007\" in query:\n",
|
||||
" return \"\"\"Here are the total points scored by Miami Heat players in the 2006-2007 season:\n",
|
||||
" Udonis Haslem: 844 points\n",
|
||||
" Dwayne Wade: 1397 points\n",
|
||||
" James Posey: 550 points\n",
|
||||
" ...\n",
|
||||
" \"\"\"\n",
|
||||
" elif \"2007-2008\" in query:\n",
|
||||
" return \"The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.\"\n",
|
||||
" elif \"2008-2009\" in query:\n",
|
||||
" return \"The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.\"\n",
|
||||
" return \"No data found.\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def percentage_change_tool(start: float, end: float) -> float:\n",
|
||||
" return ((end - start) / start) * 100\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def main() -> None:\n",
|
||||
" model_client = OpenAIChatCompletionClient(model=\"gpt-4o\")\n",
|
||||
"\n",
|
||||
" planning_agent = AssistantAgent(\n",
|
||||
" \"PlanningAgent\",\n",
|
||||
" description=\"An agent for planning tasks, this agent should be the first to engage when given a new task.\",\n",
|
||||
" model_client=model_client,\n",
|
||||
" system_message=\"\"\"\n",
|
||||
" You are a planning agent.\n",
|
||||
" Your job is to break down complex tasks into smaller, manageable subtasks.\n",
|
||||
" Your team members are:\n",
|
||||
" WebSearchAgent: Searches for information\n",
|
||||
" DataAnalystAgent: Performs calculations\n",
|
||||
"\n",
|
||||
" You only plan and delegate tasks - you do not execute them yourself.\n",
|
||||
"\n",
|
||||
" When assigning tasks, use this format:\n",
|
||||
" 1. <agent> : <task>\n",
|
||||
"\n",
|
||||
" After all tasks are complete, summarize the findings and end with \"TERMINATE\".\n",
|
||||
" \"\"\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" web_search_agent = AssistantAgent(\n",
|
||||
" \"WebSearchAgent\",\n",
|
||||
" description=\"An agent for searching information on the web.\",\n",
|
||||
" tools=[search_web_tool],\n",
|
||||
" model_client=model_client,\n",
|
||||
" system_message=\"\"\"\n",
|
||||
" You are a web search agent.\n",
|
||||
" Your only tool is search_tool - use it to find information.\n",
|
||||
" You make only one search call at a time.\n",
|
||||
" Once you have the results, you never do calculations based on them.\n",
|
||||
" \"\"\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" data_analyst_agent = AssistantAgent(\n",
|
||||
" \"DataAnalystAgent\",\n",
|
||||
" description=\"An agent for performing calculations.\",\n",
|
||||
" model_client=model_client,\n",
|
||||
" tools=[percentage_change_tool],\n",
|
||||
" system_message=\"\"\"\n",
|
||||
" You are a data analyst.\n",
|
||||
" Given the tasks you have been assigned, you should analyze the data and provide results using the tools provided.\n",
|
||||
" If you have not seen the data, ask for it.\n",
|
||||
" \"\"\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" text_mention_termination = TextMentionTermination(\"TERMINATE\")\n",
|
||||
" max_messages_termination = MaxMessageTermination(max_messages=25)\n",
|
||||
" termination = text_mention_termination | max_messages_termination\n",
|
||||
"\n",
|
||||
" selector_prompt = \"\"\"Select an agent to perform task.\n",
|
||||
"\n",
|
||||
" {roles}\n",
|
||||
"\n",
|
||||
" Current conversation context:\n",
|
||||
" {history}\n",
|
||||
"\n",
|
||||
" Read the above conversation, then select an agent from {participants} to perform the next task.\n",
|
||||
" Make sure the planner agent has assigned tasks before other agents start working.\n",
|
||||
" Only select one agent.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" task = \"Who was the Miami Heat player with the highest points in the 2006-2007 season, and what was the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons?\"\n",
|
||||
"\n",
|
||||
" tracer = trace.get_tracer(\"autogen-test-agentchat\")\n",
|
||||
" with tracer.start_as_current_span(\"runtime\"):\n",
|
||||
" team = SelectorGroupChat(\n",
|
||||
" [planning_agent, web_search_agent, data_analyst_agent],\n",
|
||||
" model_client=model_client,\n",
|
||||
" termination_condition=termination,\n",
|
||||
" selector_prompt=selector_prompt,\n",
|
||||
" allow_repeated_speaker=True,\n",
|
||||
" )\n",
|
||||
" await Console(team.run_stream(task=task))\n",
|
||||
"\n",
|
||||
" await model_client.close()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# asyncio.run(main())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"---------- user ----------\n",
|
||||
"Who was the Miami Heat player with the highest points in the 2006-2007 season, and what was the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons?\n",
|
||||
"---------- PlanningAgent ----------\n",
|
||||
"To accomplish this, we can break down the tasks as follows:\n",
|
||||
"\n",
|
||||
"1. WebSearchAgent: Search for the Miami Heat player with the highest points during the 2006-2007 NBA season.\n",
|
||||
"2. WebSearchAgent: Find the total rebounds for the identified player in both the 2007-2008 and 2008-2009 NBA seasons.\n",
|
||||
"3. DataAnalystAgent: Calculate the percentage change in total rebounds for the player between the 2007-2008 and 2008-2009 seasons.\n",
|
||||
"\n",
|
||||
"Once these tasks are complete, I will summarize the findings.\n",
|
||||
"---------- WebSearchAgent ----------\n",
|
||||
"[FunctionCall(id='call_PUhxZyR0CTlWCY4uwd5Zh3WO', arguments='{\"query\":\"Miami Heat highest points scorer 2006-2007 season\"}', name='search_web_tool')]\n",
|
||||
"---------- WebSearchAgent ----------\n",
|
||||
"[FunctionExecutionResult(content='Here are the total points scored by Miami Heat players in the 2006-2007 season:\\n Udonis Haslem: 844 points\\n Dwayne Wade: 1397 points\\n James Posey: 550 points\\n ...\\n ', name='search_web_tool', call_id='call_PUhxZyR0CTlWCY4uwd5Zh3WO', is_error=False)]\n",
|
||||
"---------- WebSearchAgent ----------\n",
|
||||
"Here are the total points scored by Miami Heat players in the 2006-2007 season:\n",
|
||||
" Udonis Haslem: 844 points\n",
|
||||
" Dwayne Wade: 1397 points\n",
|
||||
" James Posey: 550 points\n",
|
||||
" ...\n",
|
||||
" \n",
|
||||
"---------- WebSearchAgent ----------\n",
|
||||
"Dwyane Wade was the Miami Heat player with the highest points in the 2006-2007 season, scoring 1,397 points. Now, let's find his total rebounds for the 2007-2008 and 2008-2009 NBA seasons.\n",
|
||||
"---------- WebSearchAgent ----------\n",
|
||||
"[FunctionCall(id='call_GL7KkWKj9ejIM8FfpgXe2dPk', arguments='{\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}', name='search_web_tool'), FunctionCall(id='call_X81huZoiA30zIjSAIDgb8ebe', arguments='{\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}', name='search_web_tool')]\n",
|
||||
"---------- WebSearchAgent ----------\n",
|
||||
"[FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.', name='search_web_tool', call_id='call_GL7KkWKj9ejIM8FfpgXe2dPk', is_error=False), FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.', name='search_web_tool', call_id='call_X81huZoiA30zIjSAIDgb8ebe', is_error=False)]\n",
|
||||
"---------- WebSearchAgent ----------\n",
|
||||
"The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.\n",
|
||||
"The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.\n",
|
||||
"---------- DataAnalystAgent ----------\n",
|
||||
"[FunctionCall(id='call_kB50RkFVqHptA7FOf0lL2RS8', arguments='{\"start\":214,\"end\":398}', name='percentage_change_tool')]\n",
|
||||
"---------- DataAnalystAgent ----------\n",
|
||||
"[FunctionExecutionResult(content='85.98130841121495', name='percentage_change_tool', call_id='call_kB50RkFVqHptA7FOf0lL2RS8', is_error=False)]\n",
|
||||
"---------- DataAnalystAgent ----------\n",
|
||||
"85.98130841121495\n",
|
||||
"---------- PlanningAgent ----------\n",
|
||||
"The Miami Heat player with the highest points during the 2006-2007 NBA season was Dwayne Wade, who scored 1,397 points. The percentage increase in his total rebounds from the 2007-2008 season (214 rebounds) to the 2008-2009 season (398 rebounds) was approximately 86%.\n",
|
||||
"\n",
|
||||
"TERMINATE\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"await main()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can then use the Jaeger UI to view the traces collected from the application run above. \n",
|
||||
"\n",
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Custom Traces \n",
|
||||
"\n",
|
||||
"So far, we are logging only the default events that are generated by the AutoGen runtime (message created, publish etc). However, you can also create custom spans to log specific events in your application. \n",
|
||||
"\n",
|
||||
"In the example below, we will show how to log messages from the `RoundRobinGroupChat` team as they are generated by adding custom spans around the team to log runtime events and spans to log messages generated by the team.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"-- primary_agent -- : Leaves cascade like gold, \n",
|
||||
"Whispering winds cool the earth.\n",
|
||||
"primary_agent: Leaves cascade like gold, \n",
|
||||
"Whispering winds cool the earth.\n",
|
||||
"\n",
|
||||
"-- critic_agent -- : Your haiku beautifully captures the essence of the fall season with vivid imagery. However, it appears to have six syllables in the second line, which should traditionally be five. Here's a revised version keeping the 5-7-5 syllable structure:\n",
|
||||
"\n",
|
||||
"Leaves cascade like gold, \n",
|
||||
"Whispering winds cool the air. \n",
|
||||
"\n",
|
||||
"Please adjust the second line to reflect a five-syllable count. Thank you!\n",
|
||||
"critic_agent: Your haiku beautifully captures the essence of the fall season with vivid imagery. However, it appears to have six syllables in the second line, which should traditionally be five. Here's a revised version keeping the 5-7-5 syllable structure:\n",
|
||||
"\n",
|
||||
"Leaves cascade like gold, \n",
|
||||
"Whispering winds cool the air. \n",
|
||||
"\n",
|
||||
"Please adjust the second line to reflect a five-syllable count. Thank you!\n",
|
||||
"\n",
|
||||
"-- primary_agent -- : Leaves cascade like gold, \n",
|
||||
"Whispering winds cool the air.\n",
|
||||
"primary_agent: Leaves cascade like gold, \n",
|
||||
"Whispering winds cool the air.\n",
|
||||
"\n",
|
||||
"-- critic_agent -- : APPROVE\n",
|
||||
"critic_agent: APPROVE\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from autogen_agentchat.base import TaskResult\n",
|
||||
"from autogen_agentchat.conditions import ExternalTermination\n",
|
||||
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
|
||||
"from autogen_core import CancellationToken\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def run_agents() -> None:\n",
|
||||
" # Create an OpenAI model client.\n",
|
||||
" model_client = OpenAIChatCompletionClient(model=\"gpt-4o-2024-08-06\")\n",
|
||||
"\n",
|
||||
" # Create the primary agent.\n",
|
||||
" primary_agent = AssistantAgent(\n",
|
||||
" \"primary_agent\",\n",
|
||||
" model_client=model_client,\n",
|
||||
" system_message=\"You are a helpful AI assistant.\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Create the critic agent.\n",
|
||||
" critic_agent = AssistantAgent(\n",
|
||||
" \"critic_agent\",\n",
|
||||
" model_client=model_client,\n",
|
||||
" system_message=\"Provide constructive feedback. Respond with 'APPROVE' to when your feedbacks are addressed.\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Define a termination condition that stops the task if the critic approves.\n",
|
||||
" text_termination = TextMentionTermination(\"APPROVE\")\n",
|
||||
"\n",
|
||||
" tracer = trace.get_tracer(\"autogen-test-agentchat\")\n",
|
||||
" with tracer.start_as_current_span(\"runtime_round_robin_events\"):\n",
|
||||
" team = RoundRobinGroupChat([primary_agent, critic_agent], termination_condition=text_termination)\n",
|
||||
"\n",
|
||||
" response_stream = team.run_stream(task=\"Write a 2 line haiku about the fall season\")\n",
|
||||
" async for response in response_stream:\n",
|
||||
" async for response in response_stream:\n",
|
||||
" if not isinstance(response, TaskResult):\n",
|
||||
" print(f\"\\n-- {response.source} -- : {response.content}\")\n",
|
||||
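" # create a child span per streamed message, tagged below with sender and content\n",
|
||||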
" with tracer.start_as_current_span(f\"agent_message.{response.source}\") as message_span:\n",
|
||||
" content = response.content if isinstance(response.content, str) else str(response.content)\n",
|
||||
" message_span.set_attribute(\"agent.name\", response.source)\n",
|
||||
" message_span.set_attribute(\"message.content\", content)\n",
|
||||
" print(f\"{response.source}: {response.content}\")\n",
|
||||
"\n",
|
||||
" await model_client.close()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"await run_agents()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"In the code above, we create a new span for each message sent by the agent. We set attributes on the span to include the agent's name and the message content. This allows us to trace the flow of messages through our application and understand how they are processed."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Tracing and Observability\n",
|
||||
"\n",
|
||||
"AutoGen has [built-in support for tracing](https://microsoft.github.io/autogen/dev/user-guide/core-user-guide/framework/telemetry.html) and observability for collecting comprehensive records on the execution of your application. This feature is useful for debugging, performance analysis, and understanding the flow of your application.\n",
|
||||
"\n",
|
||||
"This capability is powered by the [OpenTelemetry](https://opentelemetry.io/) library, which means you can use any OpenTelemetry-compatible backend to collect and analyze traces.\n",
|
||||
"\n",
|
||||
"## Setup\n",
|
||||
"\n",
|
||||
"To begin, you need to install the OpenTelemetry Python package. You can do this using pip:\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"pip install opentelemetry-sdk\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Once you have the SDK installed, the simplest way to set up tracing in AutoGen is to:\n",
|
||||
"\n",
|
||||
"1. Configure an OpenTelemetry tracer provider\n",
|
||||
"2. Set up an exporter to send traces to your backend\n",
|
||||
"3. Connect the tracer provider to the AutoGen runtime\n",
|
||||
"\n",
|
||||
"## Telemetry Backend\n",
|
||||
"\n",
|
||||
"To collect and view traces, you need to set up a telemetry backend. Several open-source options are available, including Jaeger, Zipkin. For this example, we will use Jaeger as our telemetry backend.\n",
|
||||
"\n",
|
||||
"For a quick start, you can run Jaeger locally using Docker:\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"docker run -d --name jaeger \\\n",
|
||||
" -e COLLECTOR_OTLP_ENABLED=true \\\n",
|
||||
" -p 16686:16686 \\\n",
|
||||
" -p 4317:4317 \\\n",
|
||||
" -p 4318:4318 \\\n",
|
||||
" jaegertracing/all-in-one:latest\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"This command starts a Jaeger instance that listens on port 16686 for the Jaeger UI and port 4317 for the OpenTelemetry collector. You can access the Jaeger UI at `http://localhost:16686`.\n",
|
||||
"\n",
|
||||
"## Instrumenting an AgentChat Team\n",
|
||||
"\n",
|
||||
"In the following section, we will review how to enable tracing with an AutoGen GroupChat team. The AutoGen runtime already supports open telemetry (automatically logging message metadata). To begin, we will create a tracing service that will be used to instrument the AutoGen runtime. "
|
||||
]
|
||||
},
|
||||
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from opentelemetry import trace\n",
"from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter\n",
"from opentelemetry.sdk.resources import Resource\n",
"from opentelemetry.sdk.trace import TracerProvider\n",
"from opentelemetry.sdk.trace.export import BatchSpanProcessor\n",
"\n",
"otel_exporter = OTLPSpanExporter(endpoint=\"http://localhost:4317\", insecure=True)\n",
"tracer_provider = TracerProvider(resource=Resource({\"service.name\": \"autogen-test-agentchat\"}))\n",
"span_processor = BatchSpanProcessor(otel_exporter)\n",
"tracer_provider.add_span_processor(span_processor)\n",
"trace.set_tracer_provider(tracer_provider)\n",
"\n",
"# We will get a reference to this tracer later using its name:\n",
"# tracer = trace.get_tracer(\"autogen-test-agentchat\")"
]
},
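{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that spans are exported in batches by the `BatchSpanProcessor`, so a short-lived script may exit before buffered spans reach Jaeger. A minimal sketch of flushing before shutdown, using the `tracer_provider` defined above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: flush any spans still buffered by the BatchSpanProcessor.\n",
"tracer_provider.force_flush()\n",
"\n",
"# At final shutdown, this also flushes and releases exporter resources:\n",
"# tracer_provider.shutdown()"
]
},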
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"\n",
"All of the code to create a [team](./tutorial/teams.ipynb) should already be familiar to you. An important note here is that all AgentChat agents and teams are run using the AutoGen core API runtime. In turn, the runtime is already instrumented to log [runtime messaging events (metadata)](https://github.com/microsoft/autogen/blob/main/python/packages/autogen-core/src/autogen_core/_telemetry/_tracing_config.py), including:\n",
"\n",
"- **create**: When a message is created\n",
"- **send**: When a message is sent\n",
"- **publish**: When a message is published\n",
"- **receive**: When a message is received\n",
"- **intercept**: When a message is intercepted\n",
"- **process**: When a message is processed\n",
"- **ack**: When a message is acknowledged\n"
]
},
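{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you drive agents through the core API directly, you can also hand the tracer provider to the runtime explicitly instead of relying on the global provider. This is only a sketch; it assumes your installed `autogen-core` version exposes a `tracer_provider` argument on {py:class}`~autogen_core.SingleThreadedAgentRuntime`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from autogen_core import SingleThreadedAgentRuntime\n",
"\n",
"# Assumption: the runtime accepts a tracer_provider argument.\n",
"runtime = SingleThreadedAgentRuntime(tracer_provider=tracer_provider)\n",
"runtime.start()\n",
"# ... register and run your agents on `runtime` here ...\n",
"await runtime.stop()"
]
},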
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from autogen_agentchat.agents import AssistantAgent\n",
|
||||
"from autogen_agentchat.conditions import MaxMessageTermination, TextMentionTermination\n",
|
||||
"from autogen_agentchat.teams import SelectorGroupChat\n",
|
||||
"from autogen_agentchat.ui import Console\n",
|
||||
"from autogen_core import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def search_web_tool(query: str) -> str:\n",
|
||||
" if \"2006-2007\" in query:\n",
|
||||
" return \"\"\"Here are the total points scored by Miami Heat players in the 2006-2007 season:\n",
|
||||
" Udonis Haslem: 844 points\n",
|
||||
" Dwayne Wade: 1397 points\n",
|
||||
" James Posey: 550 points\n",
|
||||
" ...\n",
|
||||
" \"\"\"\n",
|
||||
" elif \"2007-2008\" in query:\n",
|
||||
" return \"The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.\"\n",
|
||||
" elif \"2008-2009\" in query:\n",
|
||||
" return \"The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.\"\n",
|
||||
" return \"No data found.\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def percentage_change_tool(start: float, end: float) -> float:\n",
|
||||
" return ((end - start) / start) * 100\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def main() -> None:\n",
|
||||
" model_client = OpenAIChatCompletionClient(model=\"gpt-4o\")\n",
|
||||
"\n",
|
||||
" planning_agent = AssistantAgent(\n",
|
||||
" \"PlanningAgent\",\n",
|
||||
" description=\"An agent for planning tasks, this agent should be the first to engage when given a new task.\",\n",
|
||||
" model_client=model_client,\n",
|
||||
" system_message=\"\"\"\n",
|
||||
" You are a planning agent.\n",
|
||||
" Your job is to break down complex tasks into smaller, manageable subtasks.\n",
|
||||
" Your team members are:\n",
|
||||
" WebSearchAgent: Searches for information\n",
|
||||
" DataAnalystAgent: Performs calculations\n",
|
||||
"\n",
|
||||
" You only plan and delegate tasks - you do not execute them yourself.\n",
|
||||
"\n",
|
||||
" When assigning tasks, use this format:\n",
|
||||
" 1. <agent> : <task>\n",
|
||||
"\n",
|
||||
" After all tasks are complete, summarize the findings and end with \"TERMINATE\".\n",
|
||||
" \"\"\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" web_search_agent = AssistantAgent(\n",
|
||||
" \"WebSearchAgent\",\n",
|
||||
" description=\"An agent for searching information on the web.\",\n",
|
||||
" tools=[search_web_tool],\n",
|
||||
" model_client=model_client,\n",
|
||||
" system_message=\"\"\"\n",
|
||||
" You are a web search agent.\n",
|
||||
" Your only tool is search_tool - use it to find information.\n",
|
||||
" You make only one search call at a time.\n",
|
||||
" Once you have the results, you never do calculations based on them.\n",
|
||||
" \"\"\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" data_analyst_agent = AssistantAgent(\n",
|
||||
" \"DataAnalystAgent\",\n",
|
||||
" description=\"An agent for performing calculations.\",\n",
|
||||
" model_client=model_client,\n",
|
||||
" tools=[percentage_change_tool],\n",
|
||||
" system_message=\"\"\"\n",
|
||||
" You are a data analyst.\n",
|
||||
" Given the tasks you have been assigned, you should analyze the data and provide results using the tools provided.\n",
|
||||
" If you have not seen the data, ask for it.\n",
|
||||
" \"\"\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" text_mention_termination = TextMentionTermination(\"TERMINATE\")\n",
|
||||
" max_messages_termination = MaxMessageTermination(max_messages=25)\n",
|
||||
" termination = text_mention_termination | max_messages_termination\n",
|
||||
"\n",
|
||||
" selector_prompt = \"\"\"Select an agent to perform task.\n",
|
||||
"\n",
|
||||
" {roles}\n",
|
||||
"\n",
|
||||
" Current conversation context:\n",
|
||||
" {history}\n",
|
||||
"\n",
|
||||
" Read the above conversation, then select an agent from {participants} to perform the next task.\n",
|
||||
" Make sure the planner agent has assigned tasks before other agents start working.\n",
|
||||
" Only select one agent.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" task = \"Who was the Miami Heat player with the highest points in the 2006-2007 season, and what was the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons?\"\n",
|
||||
"\n",
|
||||
" tracer = trace.get_tracer(\"autogen-test-agentchat\")\n",
|
||||
" with tracer.start_as_current_span(\"runtime\"):\n",
|
||||
" team = SelectorGroupChat(\n",
|
||||
" [planning_agent, web_search_agent, data_analyst_agent],\n",
|
||||
" model_client=model_client,\n",
|
||||
" termination_condition=termination,\n",
|
||||
" selector_prompt=selector_prompt,\n",
|
||||
" allow_repeated_speaker=True,\n",
|
||||
" )\n",
|
||||
" await Console(team.run_stream(task=task))\n",
|
||||
"\n",
|
||||
" await model_client.close()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# asyncio.run(main())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"---------- user ----------\n",
|
||||
"Who was the Miami Heat player with the highest points in the 2006-2007 season, and what was the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons?\n",
|
||||
"---------- PlanningAgent ----------\n",
|
||||
"To accomplish this, we can break down the tasks as follows:\n",
|
||||
"\n",
|
||||
"1. WebSearchAgent: Search for the Miami Heat player with the highest points during the 2006-2007 NBA season.\n",
|
||||
"2. WebSearchAgent: Find the total rebounds for the identified player in both the 2007-2008 and 2008-2009 NBA seasons.\n",
|
||||
"3. DataAnalystAgent: Calculate the percentage change in total rebounds for the player between the 2007-2008 and 2008-2009 seasons.\n",
|
||||
"\n",
|
||||
"Once these tasks are complete, I will summarize the findings.\n",
|
||||
"---------- WebSearchAgent ----------\n",
|
||||
"[FunctionCall(id='call_PUhxZyR0CTlWCY4uwd5Zh3WO', arguments='{\"query\":\"Miami Heat highest points scorer 2006-2007 season\"}', name='search_web_tool')]\n",
|
||||
"---------- WebSearchAgent ----------\n",
|
||||
"[FunctionExecutionResult(content='Here are the total points scored by Miami Heat players in the 2006-2007 season:\\n Udonis Haslem: 844 points\\n Dwayne Wade: 1397 points\\n James Posey: 550 points\\n ...\\n ', name='search_web_tool', call_id='call_PUhxZyR0CTlWCY4uwd5Zh3WO', is_error=False)]\n",
|
||||
"---------- WebSearchAgent ----------\n",
|
||||
"Here are the total points scored by Miami Heat players in the 2006-2007 season:\n",
|
||||
" Udonis Haslem: 844 points\n",
|
||||
" Dwayne Wade: 1397 points\n",
|
||||
" James Posey: 550 points\n",
|
||||
" ...\n",
|
||||
" \n",
|
||||
"---------- WebSearchAgent ----------\n",
|
||||
"Dwyane Wade was the Miami Heat player with the highest points in the 2006-2007 season, scoring 1,397 points. Now, let's find his total rebounds for the 2007-2008 and 2008-2009 NBA seasons.\n",
|
||||
"---------- WebSearchAgent ----------\n",
|
||||
"[FunctionCall(id='call_GL7KkWKj9ejIM8FfpgXe2dPk', arguments='{\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}', name='search_web_tool'), FunctionCall(id='call_X81huZoiA30zIjSAIDgb8ebe', arguments='{\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}', name='search_web_tool')]\n",
|
||||
"---------- WebSearchAgent ----------\n",
|
||||
"[FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.', name='search_web_tool', call_id='call_GL7KkWKj9ejIM8FfpgXe2dPk', is_error=False), FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.', name='search_web_tool', call_id='call_X81huZoiA30zIjSAIDgb8ebe', is_error=False)]\n",
|
||||
"---------- WebSearchAgent ----------\n",
|
||||
"The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.\n",
|
||||
"The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.\n",
|
||||
"---------- DataAnalystAgent ----------\n",
|
||||
"[FunctionCall(id='call_kB50RkFVqHptA7FOf0lL2RS8', arguments='{\"start\":214,\"end\":398}', name='percentage_change_tool')]\n",
|
||||
"---------- DataAnalystAgent ----------\n",
|
||||
"[FunctionExecutionResult(content='85.98130841121495', name='percentage_change_tool', call_id='call_kB50RkFVqHptA7FOf0lL2RS8', is_error=False)]\n",
|
||||
"---------- DataAnalystAgent ----------\n",
|
||||
"85.98130841121495\n",
|
||||
"---------- PlanningAgent ----------\n",
|
||||
"The Miami Heat player with the highest points during the 2006-2007 NBA season was Dwayne Wade, who scored 1,397 points. The percentage increase in his total rebounds from the 2007-2008 season (214 rebounds) to the 2008-2009 season (398 rebounds) was approximately 86%.\n",
|
||||
"\n",
|
||||
"TERMINATE\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"await main()"
|
||||
]
|
||||
},
|
||||
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can then use the Jaeger UI to view the traces collected from the application run above.\n",
"\n",
""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Custom Traces\n",
"\n",
"So far, we are logging only the default events generated by the AutoGen runtime (message created, publish, etc.). However, you can also create custom spans to log specific events in your application.\n",
"\n",
"In the example below, we show how to log messages from the `RoundRobinGroupChat` team as they are generated, by adding a custom span around the team run and a span for each message the team produces.\n"
]
},
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"-- primary_agent -- : Leaves cascade like gold, \n",
|
||||
"Whispering winds cool the earth.\n",
|
||||
"primary_agent: Leaves cascade like gold, \n",
|
||||
"Whispering winds cool the earth.\n",
|
||||
"\n",
|
||||
"-- critic_agent -- : Your haiku beautifully captures the essence of the fall season with vivid imagery. However, it appears to have six syllables in the second line, which should traditionally be five. Here's a revised version keeping the 5-7-5 syllable structure:\n",
|
||||
"\n",
|
||||
"Leaves cascade like gold, \n",
|
||||
"Whispering winds cool the air. \n",
|
||||
"\n",
|
||||
"Please adjust the second line to reflect a five-syllable count. Thank you!\n",
|
||||
"critic_agent: Your haiku beautifully captures the essence of the fall season with vivid imagery. However, it appears to have six syllables in the second line, which should traditionally be five. Here's a revised version keeping the 5-7-5 syllable structure:\n",
|
||||
"\n",
|
||||
"Leaves cascade like gold, \n",
|
||||
"Whispering winds cool the air. \n",
|
||||
"\n",
|
||||
"Please adjust the second line to reflect a five-syllable count. Thank you!\n",
|
||||
"\n",
|
||||
"-- primary_agent -- : Leaves cascade like gold, \n",
|
||||
"Whispering winds cool the air.\n",
|
||||
"primary_agent: Leaves cascade like gold, \n",
|
||||
"Whispering winds cool the air.\n",
|
||||
"\n",
|
||||
"-- critic_agent -- : APPROVE\n",
|
||||
"critic_agent: APPROVE\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from autogen_agentchat.base import TaskResult\n",
|
||||
"from autogen_agentchat.conditions import ExternalTermination\n",
|
||||
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
|
||||
"from autogen_core import CancellationToken\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def run_agents() -> None:\n",
|
||||
" # Create an OpenAI model client.\n",
|
||||
" model_client = OpenAIChatCompletionClient(model=\"gpt-4o-2024-08-06\")\n",
|
||||
"\n",
|
||||
" # Create the primary agent.\n",
|
||||
" primary_agent = AssistantAgent(\n",
|
||||
" \"primary_agent\",\n",
|
||||
" model_client=model_client,\n",
|
||||
" system_message=\"You are a helpful AI assistant.\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Create the critic agent.\n",
|
||||
" critic_agent = AssistantAgent(\n",
|
||||
" \"critic_agent\",\n",
|
||||
" model_client=model_client,\n",
|
||||
" system_message=\"Provide constructive feedback. Respond with 'APPROVE' to when your feedbacks are addressed.\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Define a termination condition that stops the task if the critic approves.\n",
|
||||
" text_termination = TextMentionTermination(\"APPROVE\")\n",
|
||||
"\n",
|
||||
" tracer = trace.get_tracer(\"autogen-test-agentchat\")\n",
|
||||
" with tracer.start_as_current_span(\"runtime_round_robin_events\"):\n",
|
||||
" team = RoundRobinGroupChat([primary_agent, critic_agent], termination_condition=text_termination)\n",
|
||||
"\n",
|
||||
" response_stream = team.run_stream(task=\"Write a 2 line haiku about the fall season\")\n",
|
||||
" async for response in response_stream:\n",
|
||||
" async for response in response_stream:\n",
|
||||
" if not isinstance(response, TaskResult):\n",
|
||||
" print(f\"\\n-- {response.source} -- : {response.to_text()}\")\n",
|
||||
" with tracer.start_as_current_span(f\"agent_message.{response.source}\") as message_span:\n",
|
||||
" message_span.set_attribute(\"agent.name\", response.source)\n",
|
||||
" message_span.set_attribute(\"message.content\", response.to_text())\n",
|
||||
" print(f\"{response.source}: {response.to_text()}\")\n",
|
||||
"\n",
|
||||
" await model_client.close()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"await run_agents()"
|
||||
]
|
||||
},
|
||||
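{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an alternative to creating one span per message, you can record each message as an event on a single enclosing span. The following is a sketch of the general OpenTelemetry API, not code from the notebook above, and it assumes a `tracer` obtained via `trace.get_tracer(...)` as in the earlier cells:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tracer = trace.get_tracer(\"autogen-test-agentchat\")\n",
"\n",
"with tracer.start_as_current_span(\"team_run\") as run_span:\n",
"    # Attach each message as a span event instead of a child span.\n",
"    run_span.add_event(\n",
"        \"agent_message\",\n",
"        attributes={\"agent.name\": \"primary_agent\", \"message.content\": \"...\"},\n",
"    )"
]
},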
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"In the code above, we create a new span for each message sent by the agent. We set attributes on the span to include the agent's name and the message content. This allows us to trace the flow of messages through our application and understand how they are processed."
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -1,359 +1,359 @@
{
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Managing State \n",
|
||||
"\n",
|
||||
"So far, we have discussed how to build components in a multi-agent application - agents, teams, termination conditions. In many cases, it is useful to save the state of these components to disk and load them back later. This is particularly useful in a web application where stateless endpoints respond to requests and need to load the state of the application from persistent storage.\n",
|
||||
"\n",
|
||||
"In this notebook, we will discuss how to save and load the state of agents, teams, and termination conditions. \n",
|
||||
" \n",
|
||||
"\n",
|
||||
"## Saving and Loading Agents\n",
|
||||
"\n",
|
||||
"We can get the state of an agent by calling {py:meth}`~autogen_agentchat.agents.AssistantAgent.save_state` method on \n",
|
||||
"an {py:class}`~autogen_agentchat.agents.AssistantAgent`. "
|
||||
]
|
||||
},
|
||||
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"In Tanganyika's embrace so wide and deep, \n",
"Ancient waters cradle secrets they keep, \n",
"Echoes of time where horizons sleep. \n"
]
}
],
"source": [
"from autogen_agentchat.agents import AssistantAgent\n",
"from autogen_agentchat.conditions import MaxMessageTermination\n",
"from autogen_agentchat.messages import TextMessage\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_agentchat.ui import Console\n",
"from autogen_core import CancellationToken\n",
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
"\n",
"model_client = OpenAIChatCompletionClient(model=\"gpt-4o-2024-08-06\")\n",
"\n",
"assistant_agent = AssistantAgent(\n",
"    name=\"assistant_agent\",\n",
"    system_message=\"You are a helpful assistant\",\n",
"    model_client=model_client,\n",
")\n",
"\n",
"# Use asyncio.run(...) when running in a script.\n",
"response = await assistant_agent.on_messages(\n",
"    [TextMessage(content=\"Write a 3 line poem on lake tangayika\", source=\"user\")], CancellationToken()\n",
")\n",
"print(response.chat_message)\n",
"await model_client.close()"
]
},
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'type': 'AssistantAgentState', 'version': '1.0.0', 'llm_messages': [{'content': 'Write a 3 line poem on lake tangayika', 'source': 'user', 'type': 'UserMessage'}, {'content': \"In Tanganyika's embrace so wide and deep, \\nAncient waters cradle secrets they keep, \\nEchoes of time where horizons sleep. \", 'source': 'assistant_agent', 'type': 'AssistantMessage'}]}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent_state = await assistant_agent.save_state()\n",
|
||||
"print(agent_state)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The last line of the poem was: \"Echoes of time where horizons sleep.\"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model_client = OpenAIChatCompletionClient(model=\"gpt-4o-2024-08-06\")\n",
|
||||
"\n",
|
||||
"new_assistant_agent = AssistantAgent(\n",
|
||||
" name=\"assistant_agent\",\n",
|
||||
" system_message=\"You are a helpful assistant\",\n",
|
||||
" model_client=model_client,\n",
|
||||
")\n",
|
||||
"await new_assistant_agent.load_state(agent_state)\n",
|
||||
"\n",
|
||||
"# Use asyncio.run(...) when running in a script.\n",
|
||||
"response = await new_assistant_agent.on_messages(\n",
|
||||
" [TextMessage(content=\"What was the last line of the previous poem you wrote\", source=\"user\")], CancellationToken()\n",
|
||||
")\n",
|
||||
"print(response.chat_message)\n",
|
||||
"await model_client.close()"
|
||||
]
|
||||
},
|
||||
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{note}\n",
"For {py:class}`~autogen_agentchat.agents.AssistantAgent`, its state consists of the model context.\n",
"If you write your own custom agent, consider overriding the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.save_state` and {py:meth}`~autogen_agentchat.agents.BaseChatAgent.load_state` methods to customize the behavior. The default implementations save and load an empty state.\n",
"```"
]
},
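{
"cell_type": "markdown",
"metadata": {},
"source": [
"For illustration, here is a minimal sketch of such an override. The `CountingAgent` below is hypothetical (not part of this notebook), and the elided methods are the usual {py:class}`~autogen_agentchat.agents.BaseChatAgent` abstract methods:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, Mapping\n",
"\n",
"from autogen_agentchat.agents import BaseChatAgent\n",
"\n",
"\n",
"class CountingAgent(BaseChatAgent):\n",
"    # Hypothetical agent that tracks a single counter; on_messages,\n",
"    # on_reset, and produced_message_types are elided for brevity.\n",
"\n",
"    async def save_state(self) -> Mapping[str, Any]:\n",
"        # Persist whatever the agent needs in order to resume later.\n",
"        return {\"count\": self._count}\n",
"\n",
"    async def load_state(self, state: Mapping[str, Any]) -> None:\n",
"        self._count = state[\"count\"]"
]
},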
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Saving and Loading Teams\n",
"\n",
"We can get the state of a team by calling the `save_state` method on the team, and load it back by calling the `load_state` method on the team.\n",
"\n",
"When we call `save_state` on a team, it saves the state of all the agents in the team.\n",
"\n",
"We will begin by creating a simple {py:class}`~autogen_agentchat.teams.RoundRobinGroupChat` team with a single agent and asking it to write a poem."
]
},
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"---------- user ----------\n",
|
||||
"Write a beautiful poem 3-line about lake tangayika\n",
|
||||
"---------- assistant_agent ----------\n",
|
||||
"In Tanganyika's gleam, beneath the azure skies, \n",
|
||||
"Whispers of ancient waters, in tranquil guise, \n",
|
||||
"Nature's mirror, where dreams and serenity lie.\n",
|
||||
"[Prompt tokens: 29, Completion tokens: 34]\n",
|
||||
"---------- Summary ----------\n",
|
||||
"Number of messages: 2\n",
|
||||
"Finish reason: Maximum number of messages 2 reached, current message count: 2\n",
|
||||
"Total prompt tokens: 29\n",
|
||||
"Total completion tokens: 34\n",
|
||||
"Duration: 0.71 seconds\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model_client = OpenAIChatCompletionClient(model=\"gpt-4o-2024-08-06\")\n",
|
||||
"\n",
|
||||
"# Define a team.\n",
|
||||
"assistant_agent = AssistantAgent(\n",
|
||||
" name=\"assistant_agent\",\n",
|
||||
" system_message=\"You are a helpful assistant\",\n",
|
||||
" model_client=model_client,\n",
|
||||
")\n",
|
||||
"agent_team = RoundRobinGroupChat([assistant_agent], termination_condition=MaxMessageTermination(max_messages=2))\n",
|
||||
"\n",
|
||||
"# Run the team and stream messages to the console.\n",
|
||||
"stream = agent_team.run_stream(task=\"Write a beautiful poem 3-line about lake tangayika\")\n",
|
||||
"\n",
|
||||
"# Use asyncio.run(...) when running in a script.\n",
|
||||
"await Console(stream)\n",
|
||||
"\n",
|
||||
"# Save the state of the agent team.\n",
|
||||
"team_state = await agent_team.save_state()"
|
||||
]
|
||||
},
|
||||
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we reset the team (simulating instantiation of a new team) and ask the question `What was the last line of the poem you wrote?`, we see that the team is unable to answer, as there is no reference to the previous run."
]
},
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"---------- user ----------\n",
|
||||
"What was the last line of the poem you wrote?\n",
|
||||
"---------- assistant_agent ----------\n",
|
||||
"I'm sorry, but I am unable to recall or access previous interactions, including any specific poem I may have composed in our past conversations. If you like, I can write a new poem for you.\n",
|
||||
"[Prompt tokens: 28, Completion tokens: 40]\n",
|
||||
"---------- Summary ----------\n",
|
||||
"Number of messages: 2\n",
|
||||
"Finish reason: Maximum number of messages 2 reached, current message count: 2\n",
|
||||
"Total prompt tokens: 28\n",
|
||||
"Total completion tokens: 40\n",
|
||||
"Duration: 0.70 seconds\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What was the last line of the poem you wrote?', type='TextMessage'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=28, completion_tokens=40), content=\"I'm sorry, but I am unable to recall or access previous interactions, including any specific poem I may have composed in our past conversations. If you like, I can write a new poem for you.\", type='TextMessage')], stop_reason='Maximum number of messages 2 reached, current message count: 2')"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"await agent_team.reset()\n",
|
||||
"stream = agent_team.run_stream(task=\"What was the last line of the poem you wrote?\")\n",
|
||||
"await Console(stream)"
|
||||
]
|
||||
},
|
||||
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we load the state of the team and ask the same question. We see that the team is able to accurately return the last line of the poem it wrote."
]
},
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'type': 'TeamState', 'version': '1.0.0', 'agent_states': {'group_chat_manager/a55364ad-86fd-46ab-9449-dcb5260b1e06': {'type': 'RoundRobinManagerState', 'version': '1.0.0', 'message_thread': [{'source': 'user', 'models_usage': None, 'content': 'Write a beautiful poem 3-line about lake tangayika', 'type': 'TextMessage'}, {'source': 'assistant_agent', 'models_usage': {'prompt_tokens': 29, 'completion_tokens': 34}, 'content': \"In Tanganyika's gleam, beneath the azure skies, \\nWhispers of ancient waters, in tranquil guise, \\nNature's mirror, where dreams and serenity lie.\", 'type': 'TextMessage'}], 'current_turn': 0, 'next_speaker_index': 0}, 'collect_output_messages/a55364ad-86fd-46ab-9449-dcb5260b1e06': {}, 'assistant_agent/a55364ad-86fd-46ab-9449-dcb5260b1e06': {'type': 'ChatAgentContainerState', 'version': '1.0.0', 'agent_state': {'type': 'AssistantAgentState', 'version': '1.0.0', 'llm_messages': [{'content': 'Write a beautiful poem 3-line about lake tangayika', 'source': 'user', 'type': 'UserMessage'}, {'content': \"In Tanganyika's gleam, beneath the azure skies, \\nWhispers of ancient waters, in tranquil guise, \\nNature's mirror, where dreams and serenity lie.\", 'source': 'assistant_agent', 'type': 'AssistantMessage'}]}, 'message_buffer': []}}, 'team_id': 'a55364ad-86fd-46ab-9449-dcb5260b1e06'}\n",
|
||||
"---------- user ----------\n",
|
||||
"What was the last line of the poem you wrote?\n",
|
||||
"---------- assistant_agent ----------\n",
|
||||
"The last line of the poem I wrote is: \n",
|
||||
"\"Nature's mirror, where dreams and serenity lie.\"\n",
|
||||
"[Prompt tokens: 86, Completion tokens: 22]\n",
|
||||
"---------- Summary ----------\n",
|
||||
"Number of messages: 2\n",
|
||||
"Finish reason: Maximum number of messages 2 reached, current message count: 2\n",
|
||||
"Total prompt tokens: 86\n",
|
||||
"Total completion tokens: 22\n",
|
||||
"Duration: 0.96 seconds\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What was the last line of the poem you wrote?', type='TextMessage'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=86, completion_tokens=22), content='The last line of the poem I wrote is: \\n\"Nature\\'s mirror, where dreams and serenity lie.\"', type='TextMessage')], stop_reason='Maximum number of messages 2 reached, current message count: 2')"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(team_state)\n",
|
||||
"\n",
|
||||
"# Load team state.\n",
|
||||
"await agent_team.load_state(team_state)\n",
|
||||
"stream = agent_team.run_stream(task=\"What was the last line of the poem you wrote?\")\n",
|
||||
"await Console(stream)"
|
||||
]
|
||||
},
|
||||
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Persisting State (File or Database)\n",
"\n",
"In many cases, we may want to persist the state of the team to disk (or a database) and load it back later. State is a dictionary that can be serialized to a file or written to a database."
]
},
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"---------- user ----------\n",
|
||||
"What was the last line of the poem you wrote?\n",
|
||||
"---------- assistant_agent ----------\n",
|
||||
"The last line of the poem I wrote is: \n",
|
||||
"\"Nature's mirror, where dreams and serenity lie.\"\n",
|
||||
"[Prompt tokens: 86, Completion tokens: 22]\n",
|
||||
"---------- Summary ----------\n",
|
||||
"Number of messages: 2\n",
|
||||
"Finish reason: Maximum number of messages 2 reached, current message count: 2\n",
|
||||
"Total prompt tokens: 86\n",
|
||||
"Total completion tokens: 22\n",
|
||||
"Duration: 0.72 seconds\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What was the last line of the poem you wrote?', type='TextMessage'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=86, completion_tokens=22), content='The last line of the poem I wrote is: \\n\"Nature\\'s mirror, where dreams and serenity lie.\"', type='TextMessage')], stop_reason='Maximum number of messages 2 reached, current message count: 2')"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"\n",
|
||||
"## save state to disk\n",
|
||||
"\n",
|
||||
"with open(\"coding/team_state.json\", \"w\") as f:\n",
|
||||
" json.dump(team_state, f)\n",
|
||||
"\n",
|
||||
"## load state from disk\n",
|
||||
"with open(\"coding/team_state.json\", \"r\") as f:\n",
|
||||
" team_state = json.load(f)\n",
|
||||
"\n",
|
||||
"new_agent_team = RoundRobinGroupChat([assistant_agent], termination_condition=MaxMessageTermination(max_messages=2))\n",
|
||||
"await new_agent_team.load_state(team_state)\n",
|
||||
"stream = new_agent_team.run_stream(task=\"What was the last line of the poem you wrote?\")\n",
|
||||
"await Console(stream)\n",
|
||||
"await model_client.close()"
|
||||
]
|
||||
}
|
||||
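,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same dictionary can be written to a database instead of a file. Below is a minimal sketch, assuming a local SQLite database and a hypothetical `team_states` table (neither is part of this notebook); adapt it to your own storage layer:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import sqlite3\n",
"\n",
"# Hypothetical example: persist the team state in a SQLite table keyed by a team id.\n",
"conn = sqlite3.connect(\"team_state.db\")\n",
"conn.execute(\"CREATE TABLE IF NOT EXISTS team_states (id TEXT PRIMARY KEY, state TEXT)\")\n",
"conn.execute(\n",
"    \"INSERT OR REPLACE INTO team_states (id, state) VALUES (?, ?)\",\n",
"    (\"my_team\", json.dumps(team_state)),\n",
")\n",
"conn.commit()\n",
"\n",
"# Later: load the state back and restore it into a team with load_state().\n",
"row = conn.execute(\"SELECT state FROM team_states WHERE id = ?\", (\"my_team\",)).fetchone()\n",
"restored_state = json.loads(row[0])\n",
"conn.close()"
]
}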
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "agnext",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
@ -133,7 +133,7 @@
" response = await self._delegate.on_messages(\n",
" [TextMessage(content=message.content, source=\"user\")], ctx.cancellation_token\n",
" )\n",
" print(f\"{self.id.type} responded: {response.chat_message.content}\")"
" print(f\"{self.id.type} responded: {response.chat_message}\")"
]
},
{
@ -7,7 +7,6 @@ from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import (
    ChatMessage,
    MultiModalMessage,
    TextMessage,
)
from autogen_agentchat.utils import remove_images

@ -90,11 +89,7 @@ class FileSurfer(BaseChatAgent, Component[FileSurferConfig]):

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        for chat_message in messages:
            if isinstance(chat_message, TextMessage | MultiModalMessage):
                self._chat_history.append(UserMessage(content=chat_message.content, source=chat_message.source))
            else:
                raise ValueError(f"Unexpected message in FileSurfer: {chat_message}")

            self._chat_history.append(chat_message.to_model_message())
        try:
            _, content = await self._generate_reply(cancellation_token=cancellation_token)
            self._chat_history.append(AssistantMessage(content=content, source=self.name))
@ -26,16 +26,12 @@ from autogen_agentchat.base import Response
from autogen_agentchat.messages import (
    AgentEvent,
    ChatMessage,
    HandoffMessage,
    MultiModalMessage,
    StopMessage,
    TextMessage,
    ToolCallExecutionEvent,
    ToolCallRequestEvent,
)
from autogen_core import CancellationToken, FunctionCall
from autogen_core.models._model_client import ChatCompletionClient
from autogen_core.models._types import FunctionExecutionResult
from autogen_core import CancellationToken, FunctionCall, Image
from autogen_core.models import ChatCompletionClient, FunctionExecutionResult
from autogen_core.tools import FunctionTool, Tool
from pydantic import BaseModel, Field

@ -52,6 +48,12 @@ from openai.types.beta.file_search_tool_param import FileSearchToolParam
from openai.types.beta.function_tool_param import FunctionToolParam
from openai.types.beta.thread import Thread, ToolResources, ToolResourcesCodeInterpreter
from openai.types.beta.threads import Message, MessageDeleted, Run
from openai.types.beta.threads.image_url_content_block_param import ImageURLContentBlockParam
from openai.types.beta.threads.image_url_param import ImageURLParam
from openai.types.beta.threads.message_content_part_param import (
    MessageContentPartParam,
)
from openai.types.beta.threads.text_content_block_param import TextContentBlockParam
from openai.types.shared_params.function_definition import FunctionDefinition
from openai.types.vector_store import VectorStore

@ -406,10 +408,7 @@ class OpenAIAssistantAgent(BaseChatAgent):

        # Process all messages in sequence
        for message in messages:
            if isinstance(message, (TextMessage, MultiModalMessage)):
                await self.handle_text_message(str(message.content), cancellation_token)
            elif isinstance(message, (StopMessage, HandoffMessage)):
                await self.handle_text_message(message.content, cancellation_token)
            await self.handle_incoming_message(message, cancellation_token)

        # Inner messages for tool calls
        inner_messages: List[AgentEvent | ChatMessage] = []

@ -519,8 +518,21 @@ class OpenAIAssistantAgent(BaseChatAgent):
            chat_message = TextMessage(source=self.name, content=text_content[0].text.value)
        yield Response(chat_message=chat_message, inner_messages=inner_messages)

    async def handle_text_message(self, content: str, cancellation_token: CancellationToken) -> None:
    async def handle_incoming_message(self, message: ChatMessage, cancellation_token: CancellationToken) -> None:
        """Handle regular text messages by adding them to the thread."""
        content: str | List[MessageContentPartParam] | None = None
        llm_message = message.to_model_message()
        if isinstance(llm_message.content, str):
            content = llm_message.content
        else:
            content = []
            for c in llm_message.content:
                if isinstance(c, str):
                    content.append(TextContentBlockParam(text=c, type="text"))
                elif isinstance(c, Image):
                    content.append(ImageURLContentBlockParam(image_url=ImageURLParam(url=c.data_uri), type="image_url"))
                else:
                    raise ValueError(f"Unsupported content type: {type(c)} in {message}")
        await cancellation_token.link_future(
            asyncio.ensure_future(
                self._client.beta.threads.messages.create(
@ -432,10 +432,8 @@ class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]):
        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
    ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
        for chat_message in messages:
            if isinstance(chat_message, TextMessage | MultiModalMessage):
                self._chat_history.append(UserMessage(content=chat_message.content, source=chat_message.source))
            else:
                raise ValueError(f"Unexpected message in MultiModalWebSurfer: {chat_message}")
            self._chat_history.append(chat_message.to_model_message())

        self.inner_messages: List[AgentEvent | ChatMessage] = []
        self.model_usage: List[RequestUsage] = []
        try:
@ -192,7 +192,7 @@ In responding to every user message, you follow the same multi-step process give
        task_result: TaskResult = await assistant_agent.run(task=TextMessage(content=task, source="User"))
        messages: Sequence[AgentEvent | ChatMessage] = task_result.messages
        message: AgentEvent | ChatMessage = messages[-1]
        response_str = message.content
        response_str = message.to_text()

        # Log the model call
        self.logger.log_model_task(

@ -245,12 +245,7 @@ In responding to every user message, you follow the same multi-step process give

        response_str_list: List[str] = []
        for message in messages:
            content = message.content
            if isinstance(content, str):
                content_str = content
            else:
                content_str = "Not a string."
            response_str_list.append(content_str)
            response_str_list.append(message.to_text())
        response_str = "\n".join(response_str_list)

        self.logger.info("\n----- RESPONSE -----\n\n{}\n".format(response_str))
@ -345,7 +345,7 @@ class PageLogger:

        messages: Sequence[AgentEvent | ChatMessage] = task_result.messages
        message = messages[-1]
        response_str = message.content
        response_str = message.to_text()
        if not isinstance(response_str, str):
            response_str = "??"

@ -126,7 +126,7 @@ class HttpTool(BaseTool[BaseModel, Any], Component[HttpToolConfig]):
        [TextMessage(content="Can you base64 decode the value 'YWJjZGU=', please?", source="user")],
        CancellationToken(),
    )
    print(response.chat_message.content)
    print(response.chat_message)


asyncio.run(main())
@@ -105,7 +105,7 @@ async def mcp_server_tools(

             # Let the agent fetch the content of a URL and summarize it.
             result = await agent.run(task="Summarize the content of https://en.wikipedia.org/wiki/Seattle")
-            print(result.messages[-1].content)
+            print(result.messages[-1])


         asyncio.run(main())
@@ -61,7 +61,7 @@ def _extract_message_content(message: AgentEvent | ChatMessage) -> Tuple[List[st
         text_parts = [item for item in message.content if isinstance(item, str)]
         image_parts = [item for item in message.content if isinstance(item, Image)]
     else:
-        text_parts = [str(message.content)]
+        text_parts = [message.to_text()]
         image_parts = []
     return text_parts, image_parts
@@ -8,6 +8,7 @@ from typing import Any, AsyncGenerator, List

 import aiofiles
 import pytest
 from autogen_agentchat import EVENT_LOGGER_NAME
+from autogen_agentchat.messages import TextMessage
 from autogen_ext.agents.file_surfer import FileSurfer
 from autogen_ext.models.openai import OpenAIChatCompletionClient
 from openai.resources.chat.completions import AsyncCompletions
@@ -140,9 +141,11 @@ async def test_run_filesurfer(monkeypatch: pytest.MonkeyPatch) -> None:
     # Get the FileSurfer to read the file, and the directory
     assert agent._name == "FileSurfer"  # pyright: ignore[reportPrivateUsage]
     result = await agent.run(task="Please read the test file")
+    assert isinstance(result.messages[1], TextMessage)
     assert "# FileSurfer test H1" in result.messages[1].content

     result = await agent.run(task="Please read the test directory")
+    assert isinstance(result.messages[1], TextMessage)
     assert "# Index of " in result.messages[1].content
     assert "test_filesurfer_agent.html" in result.messages[1].content
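The added asserts are not only runtime checks: `assert isinstance(...)` narrows the message union for static type checkers, so the subsequent `.content` access is known to be on a `str`. A self-contained sketch of the pattern, with a hand-built message list standing in for `TaskResult.messages`:

```python
# Sketch: isinstance-narrowing a message before touching .content.
from autogen_agentchat.messages import TextMessage

messages = [TextMessage(content="# Index of /tmp", source="FileSurfer")]
msg = messages[0]
assert isinstance(msg, TextMessage)
print(msg.content.upper())  # content is statically known to be str here
```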
@@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock

 import aiofiles
 import pytest
-from autogen_agentchat.messages import ChatMessage, TextMessage
+from autogen_agentchat.messages import ChatMessage, TextMessage, ToolCallRequestEvent
 from autogen_core import CancellationToken
 from autogen_core.tools._base import BaseTool, Tool
 from autogen_ext.agents.openai import OpenAIAssistantAgent
@@ -250,8 +250,7 @@ async def test_file_retrieval(
     message = TextMessage(source="user", content="What is the first sentence of the jungle scout book?")
     response = await agent.on_messages([message], cancellation_token)

-    assert response.chat_message.content is not None
-    assert isinstance(response.chat_message.content, str)
+    assert isinstance(response.chat_message, TextMessage)
     assert len(response.chat_message.content) > 0

     await agent.delete_uploaded_files(cancellation_token)
@@ -271,8 +270,7 @@ async def test_code_interpreter(
     message = TextMessage(source="user", content="I need to solve the equation `3x + 11 = 14`. Can you help me?")
     response = await agent.on_messages([message], cancellation_token)

-    assert response.chat_message.content is not None
-    assert isinstance(response.chat_message.content, str)
+    assert isinstance(response.chat_message, TextMessage)
     assert len(response.chat_message.content) > 0
     assert "x = 1" in response.chat_message.content.lower()
@@ -326,12 +324,11 @@ async def test_quiz_creation(
     response = await agent.on_messages([message], cancellation_token)

     # Check that the final response has non-empty inner messages (i.e. tool call events).
-    assert response.chat_message.content is not None
-    assert isinstance(response.chat_message.content, str)
+    assert isinstance(response.chat_message, TextMessage)
     assert len(response.chat_message.content) > 0
     assert isinstance(response.inner_messages, list)
     # Ensure that at least one inner message has non-empty content.
-    assert any(hasattr(tool_msg, "content") and tool_msg.content for tool_msg in response.inner_messages)
+    assert any(isinstance(msg, ToolCallRequestEvent) for msg in response.inner_messages)

     await agent.delete_assistant(cancellation_token)
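Checking for the concrete event type is stricter than the old `hasattr` probe, which any message carrying a `content` attribute would satisfy. A self-contained sketch of the difference, building a `ToolCallRequestEvent` by hand from `autogen_core.FunctionCall`:

```python
# Sketch: every message type has .content, so hasattr() proves little;
# isinstance() pins down the actual tool-call request event.
from autogen_core import FunctionCall
from autogen_agentchat.messages import ToolCallRequestEvent

event = ToolCallRequestEvent(
    content=[FunctionCall(id="1", name="create_quiz", arguments="{}")],
    source="assistant",
)
assert isinstance(event, ToolCallRequestEvent)  # what the new test checks
assert hasattr(event, "content")                # the old probe: far weaker
```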
@@ -357,14 +354,14 @@ async def test_on_reset_behavior(client: AsyncOpenAI, cancellation_token: Cancel

     message1 = TextMessage(source="user", content="What is my name?")
     response1 = await agent.on_messages([message1], cancellation_token)
-    assert isinstance(response1.chat_message.content, str)
+    assert isinstance(response1.chat_message, TextMessage)
     assert "john" in response1.chat_message.content.lower()

     await agent.on_reset(cancellation_token)

     message2 = TextMessage(source="user", content="What is my name?")
     response2 = await agent.on_messages([message2], cancellation_token)
-    assert isinstance(response2.chat_message.content, str)
+    assert isinstance(response2.chat_message, TextMessage)
     assert "john" in response2.chat_message.content.lower()

     await agent.delete_assistant(cancellation_token)
@@ -1,5 +1,6 @@
 import argparse
 import asyncio
+from autogen_agentchat.messages import TextMessage
 import yaml
 import random
@@ -78,11 +79,10 @@ async def get_ai_move(board: chess.Board, player: AssistantAgent, max_tries: int
     while count < max_tries:
         result = await Console(player.run_stream(task=task))
         count += 1
-        response = result.messages[-1].content
-        assert isinstance(response, str)
+        assert isinstance(result.messages[-1], TextMessage)
         # Check if the response is a valid UCI move.
         try:
-            move = chess.Move.from_uci(extract_move(response))
+            move = chess.Move.from_uci(extract_move(result.messages[-1].content))
         except (ValueError, IndexError):
             task = "Invalid format. Please read instruction.\n" + get_ai_prompt(board)
             continue
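For context on the `except ValueError` branch: `chess.Move.from_uci` rejects strings that are not valid UCI notation with a `ValueError`, which is what drives the retry prompt above. A sketch assuming the `python-chess` package (the diff only shows the call site):

```python
# Sketch: UCI parsing and validation with python-chess.
import chess

board = chess.Board()
move = chess.Move.from_uci("e2e4")  # parses fine
assert move in board.legal_moves    # legal from the starting position
try:
    chess.Move.from_uci("not-a-move")
except ValueError:
    print("invalid UCI string rejected")
```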
@@ -22,5 +22,5 @@ class Agent:
             [TextMessage(content=prompt, source="user")],
             CancellationToken(),
         )
-        assert isinstance(response.chat_message.content, str)
+        assert isinstance(response.chat_message, TextMessage)
         return response.chat_message.content