Rename to use BaseChatMessage and BaseAgentEvent. Bring back union types. (#6144)

Rename the `ChatMessage` and `AgentEvent` base classes to `BaseChatMessage` and `BaseAgentEvent`. 

Bring back `ChatMessage` and `AgentEvent` as unions of the built-in concrete types, to avoid breaking existing applications that depend on Pydantic serialization.

Why?

Much existing code uses container models like this:

```python
from pydantic import BaseModel

from autogen_agentchat.messages import ChatMessage


class AppMessage(BaseModel):
    name: str
    message: ChatMessage


# Serialization looks like this:
m = AppMessage(...)
m.model_dump_json()

# If ChatMessage were only a base class, fields of concrete subclasses,
# such as HandoffMessage.target, would be lost: Pydantic would serialize
# against the declared base type, which has no content or target fields.
```
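
For illustration, here is a minimal sketch of the round trip that the restored union preserves. It assumes the discriminated `ChatMessage` union defined in this PR; the `AppMessage` container and the field values are hypothetical:

```python
from pydantic import BaseModel

from autogen_agentchat.messages import ChatMessage, HandoffMessage


class AppMessage(BaseModel):
    name: str
    message: ChatMessage  # discriminated union of built-in concrete types


m = AppMessage(
    name="example",
    message=HandoffMessage(source="agent_a", target="agent_b", content="take over"),
)
data = m.model_dump_json()

# The "type" discriminator recovers the concrete class on deserialization,
# so HandoffMessage.target survives the round trip.
restored = AppMessage.model_validate_json(data)
assert isinstance(restored.message, HandoffMessage)
assert restored.message.target == "agent_b"
```

With the base class as the field annotation instead, `model_dump_json` would emit only the base fields and the round trip would fail.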

The assumption that `ChatMessage` and `AgentEvent` are unions of concrete types is likely baked into many existing code bases. This PR therefore brings back the union types, while keeping method type hints, such as those on `on_messages`, on the `BaseChatMessage` and `BaseAgentEvent` base classes for flexibility.
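
To show what the base-class hints enable, here is a sketch of a custom agent, assuming the `BaseChatAgent` signatures shown in the diffs below; `EchoAgent` and its behavior are illustrative only:

```python
from typing import Sequence

from autogen_core import CancellationToken

from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import BaseChatMessage, TextMessage


class EchoAgent(BaseChatAgent):
    """A hypothetical agent that echoes the last message it receives."""

    @property
    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        return (TextMessage,)

    async def on_messages(
        self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
    ) -> Response:
        # to_text() is defined on every BaseChatMessage subclass, so this
        # also works for custom message types such as StructuredMessage.
        last = messages[-1].to_text() if messages else ""
        return Response(chat_message=TextMessage(content=last, source=self.name))

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        pass  # No internal state to reset in this sketch.
```

An `EchoAgent("echo", "Echoes the last message.")` then runs like any built-in agent, and user-defined `BaseChatMessage` subclasses pass through `on_messages` without widening the union.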
Eric Zhu 2025-03-30 09:34:40 -07:00 committed by GitHub
parent e686342f53
commit 7615c7b83b
49 changed files with 1532 additions and 1442 deletions

View File

@ -16,7 +16,7 @@ from autogen_core.models import ChatCompletionClient
from autogen_ext.agents.web_surfer import MultimodalWebSurfer
from autogen_ext.agents.file_surfer import FileSurfer
from autogen_agentchat.agents import CodeExecutorAgent
from autogen_agentchat.messages import TextMessage, AgentEvent, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage
from autogen_agentchat.messages import TextMessage, BaseAgentEvent, BaseChatMessage, HandoffMessage, MultiModalMessage, StopMessage
from autogen_core.models import LLMMessage, UserMessage, AssistantMessage
# Suppress warnings about the requests.Session() not being closed
@ -141,7 +141,7 @@ class LLMTermination(TerminationCondition):
def terminated(self) -> bool:
return self._terminated
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
if self._terminated:
raise TerminatedException("Termination condition has already been reached")

View File

@ -1,4 +1,4 @@
# __init__.py
from ._base import Code, Document, CodedDocument, BaseQualitativeCoder
from ._base import BaseQualitativeCoder, Code, CodedDocument, Document
__all__ = ["Code", "Document", "CodedDocument", "BaseQualitativeCoder"]

View File

@ -1,7 +1,8 @@
import json
import hashlib
import json
import re
from typing import Protocol, List, Set, Optional
from typing import List, Optional, Protocol, Set
from pydantic import BaseModel, Field

View File

@ -1,8 +1,10 @@
import os
import argparse
from typing import List, Sequence, Optional
import os
from typing import List, Optional, Sequence
from openai import OpenAI
from ._base import Document, CodedDocument
from ._base import CodedDocument, Document
from .coders.oai_coder import OAIQualitativeCoder

View File

@ -1,13 +1,11 @@
import os
import re
from typing import List, Set, Optional
from pydantic import BaseModel
from typing import List, Optional, Set
from openai import OpenAI
from pydantic import BaseModel
from .._base import CodedDocument, Document, Code
from .._base import BaseQualitativeCoder
from .._base import BaseQualitativeCoder, Code, CodedDocument, Document
class CodeList(BaseModel):

View File

@ -40,8 +40,8 @@ from .. import EVENT_LOGGER_NAME
from ..base import Handoff as HandoffBase
from ..base import Response
from ..messages import (
AgentEvent,
ChatMessage,
BaseAgentEvent,
BaseChatMessage,
HandoffMessage,
MemoryQueryEvent,
ModelClientStreamingChunkEvent,
@ -697,8 +697,8 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
self._is_running = False
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
message_types: List[type[ChatMessage]] = [TextMessage]
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
message_types: List[type[BaseChatMessage]] = [TextMessage]
if self._handoffs:
message_types.append(HandoffMessage)
if self._tools:
@ -712,15 +712,15 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
"""
return self._model_context
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):
if isinstance(message, Response):
return message
raise AssertionError("The stream should have returned the final result.")
async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
"""
Process the incoming messages with the assistant agent and yield events/responses as they happen.
"""
@ -745,7 +745,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
)
# STEP 2: Update model context with any relevant memory
inner_messages: List[AgentEvent | ChatMessage] = []
inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
for event_msg in await self._update_model_context_with_memory(
memory=memory,
model_context=model_context,
@ -810,7 +810,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
@staticmethod
async def _add_messages_to_context(
model_context: ChatCompletionContext,
messages: Sequence[ChatMessage],
messages: Sequence[BaseChatMessage],
) -> None:
"""
Add incoming messages to the model context.
@ -886,7 +886,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
async def _process_model_result(
cls,
model_result: CreateResult,
inner_messages: List[AgentEvent | ChatMessage],
inner_messages: List[BaseAgentEvent | BaseChatMessage],
cancellation_token: CancellationToken,
agent_name: str,
system_messages: List[SystemMessage],
@ -898,7 +898,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_client_stream: bool,
reflect_on_tool_use: bool,
tool_call_summary_format: str,
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
"""
Handle final or partial responses from model_result, including tool calls, handoffs,
and reflection if needed.
@ -992,7 +992,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
def _check_and_handle_handoff(
model_result: CreateResult,
executed_calls_and_results: List[Tuple[FunctionCall, FunctionExecutionResult]],
inner_messages: List[AgentEvent | ChatMessage],
inner_messages: List[BaseAgentEvent | BaseChatMessage],
handoffs: Dict[str, HandoffBase],
agent_name: str,
) -> Optional[Response]:
@ -1061,7 +1061,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_client_stream: bool,
model_context: ChatCompletionContext,
agent_name: str,
inner_messages: List[AgentEvent | ChatMessage],
inner_messages: List[BaseAgentEvent | BaseChatMessage],
) -> AsyncGenerator[Response | ModelClientStreamingChunkEvent | ThoughtEvent, None]:
"""
If reflect_on_tool_use=True, we do another inference based on tool results
@ -1113,7 +1113,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
@staticmethod
def _summarize_tool_use(
executed_calls_and_results: List[Tuple[FunctionCall, FunctionExecutionResult]],
inner_messages: List[AgentEvent | ChatMessage],
inner_messages: List[BaseAgentEvent | BaseChatMessage],
handoffs: Dict[str, HandoffBase],
tool_call_summary_format: str,
agent_name: str,

View File

@ -6,8 +6,8 @@ from pydantic import BaseModel
from ..base import ChatAgent, Response, TaskResult
from ..messages import (
AgentEvent,
ChatMessage,
BaseAgentEvent,
BaseChatMessage,
ModelClientStreamingChunkEvent,
TextMessage,
)
@ -59,13 +59,13 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
@property
@abstractmethod
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
"""The types of messages that the agent produces in the
:attr:`Response.chat_message` field. They must be :class:`ChatMessage` types."""
:attr:`Response.chat_message` field. They must be :class:`BaseChatMessage` types."""
...
@abstractmethod
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
"""Handles incoming messages and returns a response.
.. note::
@ -81,8 +81,8 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
...
async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
"""Handles incoming messages and returns a stream of messages and
and the final item is the response. The base implementation in
:class:`BaseChatAgent` simply calls :meth:`on_messages` and yields
@ -106,21 +106,21 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
async def run(
self,
*,
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the agent with the given task and return the result."""
if cancellation_token is None:
cancellation_token = CancellationToken()
input_messages: List[ChatMessage] = []
output_messages: List[AgentEvent | ChatMessage] = []
input_messages: List[BaseChatMessage] = []
output_messages: List[BaseAgentEvent | BaseChatMessage] = []
if task is None:
pass
elif isinstance(task, str):
text_msg = TextMessage(content=task, source="user")
input_messages.append(text_msg)
output_messages.append(text_msg)
elif isinstance(task, ChatMessage):
elif isinstance(task, BaseChatMessage):
input_messages.append(task)
output_messages.append(task)
else:
@ -128,7 +128,7 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
raise ValueError("Task list cannot be empty.")
# Task is a sequence of messages.
for msg in task:
if isinstance(msg, ChatMessage):
if isinstance(msg, BaseChatMessage):
input_messages.append(msg)
output_messages.append(msg)
else:
@ -142,15 +142,15 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
async def run_stream(
self,
*,
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]:
"""Run the agent with the given task and return a stream of messages
and the final task result as the last item in the stream."""
if cancellation_token is None:
cancellation_token = CancellationToken()
input_messages: List[ChatMessage] = []
output_messages: List[AgentEvent | ChatMessage] = []
input_messages: List[BaseChatMessage] = []
output_messages: List[BaseAgentEvent | BaseChatMessage] = []
if task is None:
pass
elif isinstance(task, str):
@ -158,7 +158,7 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
input_messages.append(text_msg)
output_messages.append(text_msg)
yield text_msg
elif isinstance(task, ChatMessage):
elif isinstance(task, BaseChatMessage):
input_messages.append(task)
output_messages.append(task)
yield task
@ -166,7 +166,7 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
if not task:
raise ValueError("Task list cannot be empty.")
for msg in task:
if isinstance(msg, ChatMessage):
if isinstance(msg, BaseChatMessage):
input_messages.append(msg)
output_messages.append(msg)
yield msg

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel
from typing_extensions import Self
from ..base import Response
from ..messages import ChatMessage, TextMessage
from ..messages import BaseChatMessage, TextMessage
from ._base_chat_agent import BaseChatAgent
@ -119,11 +119,11 @@ class CodeExecutorAgent(BaseChatAgent, Component[CodeExecutorAgentConfig]):
self._sources = sources
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
"""The types of messages that the code executor agent produces."""
return (TextMessage,)
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
# Extract code blocks from the messages.
code_blocks: List[CodeBlock] = []
for msg in messages:

View File

@ -10,8 +10,8 @@ from autogen_agentchat.state import SocietyOfMindAgentState
from ..base import TaskResult, Team
from ..messages import (
AgentEvent,
ChatMessage,
BaseAgentEvent,
BaseChatMessage,
ModelClientStreamingChunkEvent,
TextMessage,
)
@ -122,10 +122,10 @@ class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]):
self._response_prompt = response_prompt
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
return (TextMessage,)
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
# Call the stream method and collect the messages.
response: Response | None = None
async for msg in self.on_messages_stream(messages, cancellation_token):
@ -135,14 +135,14 @@ class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]):
return response
async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
# Prepare the task for the team of agents.
task = list(messages)
# Run the team of agents.
result: TaskResult | None = None
inner_messages: List[AgentEvent | ChatMessage] = []
inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
count = 0
async for inner_msg in self._team.run_stream(task=task, cancellation_token=cancellation_token):
if isinstance(inner_msg, TaskResult):
@ -167,7 +167,7 @@ class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]):
# Generate a response using the model client.
llm_messages: List[LLMMessage] = [SystemMessage(content=self._instruction)]
for message in messages:
if isinstance(message, ChatMessage):
if isinstance(message, BaseChatMessage):
llm_messages.append(message.to_model_message())
llm_messages.append(SystemMessage(content=self._response_prompt))
completion = await self._model_client.create(messages=llm_messages, cancellation_token=cancellation_token)

View File

@ -10,7 +10,7 @@ from pydantic import BaseModel
from typing_extensions import Self
from ..base import Response
from ..messages import AgentEvent, ChatMessage, HandoffMessage, TextMessage, UserInputRequestedEvent
from ..messages import BaseAgentEvent, BaseChatMessage, HandoffMessage, TextMessage, UserInputRequestedEvent
from ._base_chat_agent import BaseChatAgent
SyncInputFunc = Callable[[str], str]
@ -170,11 +170,11 @@ class UserProxyAgent(BaseChatAgent, Component[UserProxyAgentConfig]):
self._is_async = iscoroutinefunction(self.input_func)
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
"""Message types this agent can produce."""
return (TextMessage, HandoffMessage)
def _get_latest_handoff(self, messages: Sequence[ChatMessage]) -> Optional[HandoffMessage]:
def _get_latest_handoff(self, messages: Sequence[BaseChatMessage]) -> Optional[HandoffMessage]:
"""Find the HandoffMessage in the message sequence that addresses this agent."""
if len(messages) > 0 and isinstance(messages[-1], HandoffMessage):
if messages[-1].target == self.name:
@ -201,15 +201,15 @@ class UserProxyAgent(BaseChatAgent, Component[UserProxyAgentConfig]):
except Exception as e:
raise RuntimeError(f"Failed to get user input: {str(e)}") from e
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):
if isinstance(message, Response):
return message
raise AssertionError("The stream should have returned the final result.")
async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
"""Handle incoming messages by requesting user input."""
try:
# Check for handoff first

View File

@ -5,7 +5,7 @@ from typing import Any, AsyncGenerator, Mapping, Sequence
from autogen_core import CancellationToken, ComponentBase
from pydantic import BaseModel
from ..messages import AgentEvent, ChatMessage
from ..messages import BaseAgentEvent, BaseChatMessage
from ._task import TaskRunner
@ -13,12 +13,12 @@ from ._task import TaskRunner
class Response:
"""A response from calling :meth:`ChatAgent.on_messages`."""
chat_message: ChatMessage
chat_message: BaseChatMessage
"""A chat message produced by the agent as the response."""
inner_messages: Sequence[AgentEvent | ChatMessage] | None = None
"""Inner messages produced by the agent, they can be :class:`AgentEvent`
or :class:`ChatMessage`."""
inner_messages: Sequence[BaseAgentEvent | BaseChatMessage] | None = None
"""Inner messages produced by the agent, they can be :class:`BaseAgentEvent`
or :class:`BaseChatMessage`."""
class ChatAgent(ABC, TaskRunner, ComponentBase[BaseModel]):
@ -43,20 +43,20 @@ class ChatAgent(ABC, TaskRunner, ComponentBase[BaseModel]):
@property
@abstractmethod
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
"""The types of messages that the agent produces in the
:attr:`Response.chat_message` field. They must be :class:`ChatMessage` types."""
:attr:`Response.chat_message` field. They must be :class:`BaseChatMessage` types."""
...
@abstractmethod
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
"""Handles incoming messages and returns a response."""
...
@abstractmethod
def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
"""Handles incoming messages and returns a stream of inner messages and
and the final item is the response."""
...

View File

@ -3,14 +3,14 @@ from typing import AsyncGenerator, Protocol, Sequence
from autogen_core import CancellationToken
from ..messages import AgentEvent, ChatMessage
from ..messages import BaseAgentEvent, BaseChatMessage
@dataclass
class TaskResult:
"""Result of running a task."""
messages: Sequence[AgentEvent | ChatMessage]
messages: Sequence[BaseAgentEvent | BaseChatMessage]
"""Messages produced by the task."""
stop_reason: str | None = None
@ -23,7 +23,7 @@ class TaskRunner(Protocol):
async def run(
self,
*,
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the task and return the result.
@ -38,9 +38,9 @@ class TaskRunner(Protocol):
def run_stream(
self,
*,
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]:
"""Run the task and produces a stream of messages and the final result
:class:`TaskResult` as the last item in the stream.

View File

@ -6,7 +6,7 @@ from autogen_core import Component, ComponentBase, ComponentModel
from pydantic import BaseModel
from typing_extensions import Self
from ..messages import AgentEvent, ChatMessage, StopMessage
from ..messages import BaseAgentEvent, BaseChatMessage, StopMessage
class TerminatedException(BaseException): ...
@ -15,7 +15,7 @@ class TerminatedException(BaseException): ...
class TerminationCondition(ABC, ComponentBase[BaseModel]):
"""A stateful condition that determines when a conversation should be terminated.
A termination condition is a callable that takes a sequence of ChatMessage objects
A termination condition is a callable that takes a sequence of BaseChatMessage objects
since the last time the condition was called, and returns a StopMessage if the
conversation should be terminated, or None otherwise.
Once a termination condition has been reached, it must be reset before it can be used again.
@ -56,7 +56,7 @@ class TerminationCondition(ABC, ComponentBase[BaseModel]):
...
@abstractmethod
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
"""Check if the conversation should be terminated based on the messages received
since the last time the condition was called.
Return a StopMessage if the conversation should be terminated, or None otherwise.
@ -102,7 +102,7 @@ class AndTerminationCondition(TerminationCondition, Component[AndTerminationCond
def terminated(self) -> bool:
return all(condition.terminated for condition in self._conditions)
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
if self.terminated:
raise TerminatedException("Termination condition has already been reached.")
# Check all remaining conditions.
@ -153,7 +153,7 @@ class OrTerminationCondition(TerminationCondition, Component[OrTerminationCondit
def terminated(self) -> bool:
return any(condition.terminated for condition in self._conditions)
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
if self.terminated:
raise RuntimeError("Termination condition has already been reached")
stop_messages = await asyncio.gather(*[condition(messages) for condition in self._conditions])

View File

@ -7,9 +7,8 @@ from typing_extensions import Self
from ..base import TerminatedException, TerminationCondition
from ..messages import (
AgentEvent,
BaseAgentEvent,
BaseChatMessage,
ChatMessage,
HandoffMessage,
StopMessage,
TextMessage,
@ -34,7 +33,7 @@ class StopMessageTermination(TerminationCondition, Component[StopMessageTerminat
def terminated(self) -> bool:
return self._terminated
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
if self._terminated:
raise TerminatedException("Termination condition has already been reached")
for message in messages:
@ -64,8 +63,8 @@ class MaxMessageTermination(TerminationCondition, Component[MaxMessageTerminatio
Args:
max_messages: The maximum number of messages allowed in the conversation.
include_agent_event: If True, include :class:`~autogen_agentchat.messages.AgentEvent` in the message count.
Otherwise, only include :class:`~autogen_agentchat.messages.ChatMessage`. Defaults to False.
include_agent_event: If True, include :class:`~autogen_agentchat.messages.BaseAgentEvent` in the message count.
Otherwise, only include :class:`~autogen_agentchat.messages.BaseChatMessage`. Defaults to False.
"""
component_config_schema = MaxMessageTerminationConfig
@ -80,7 +79,7 @@ class MaxMessageTermination(TerminationCondition, Component[MaxMessageTerminatio
def terminated(self) -> bool:
return self._message_count >= self._max_messages
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
if self.terminated:
raise TerminatedException("Termination condition has already been reached")
self._message_count += len([m for m in messages if self._include_agent_event or isinstance(m, BaseChatMessage)])
@ -129,7 +128,7 @@ class TextMentionTermination(TerminationCondition, Component[TextMentionTerminat
def terminated(self) -> bool:
return self._terminated
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
if self._terminated:
raise TerminatedException("Termination condition has already been reached")
for message in messages:
@ -201,7 +200,7 @@ class TokenUsageTermination(TerminationCondition, Component[TokenUsageTerminatio
or (self._max_completion_token is not None and self._completion_token_count >= self._max_completion_token)
)
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
if self.terminated:
raise TerminatedException("Termination condition has already been reached")
for message in messages:
@ -258,7 +257,7 @@ class HandoffTermination(TerminationCondition, Component[HandoffTerminationConfi
def terminated(self) -> bool:
return self._terminated
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
if self._terminated:
raise TerminatedException("Termination condition has already been reached")
for message in messages:
@ -303,7 +302,7 @@ class TimeoutTermination(TerminationCondition, Component[TimeoutTerminationConfi
def terminated(self) -> bool:
return self._terminated
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
if self._terminated:
raise TerminatedException("Termination condition has already been reached")
@ -365,7 +364,7 @@ class ExternalTermination(TerminationCondition, Component[ExternalTerminationCon
"""Set the termination condition to terminated."""
self._setted = True
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
if self._terminated:
raise TerminatedException("Termination condition has already been reached")
if self._setted:
@ -410,7 +409,7 @@ class SourceMatchTermination(TerminationCondition, Component[SourceMatchTerminat
def terminated(self) -> bool:
return self._terminated
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
if self._terminated:
raise TerminatedException("Termination condition has already been reached")
if not messages:
@ -463,7 +462,7 @@ class TextMessageTermination(TerminationCondition, Component[TextMessageTerminat
def terminated(self) -> bool:
return self._terminated
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
if self._terminated:
raise TerminatedException("Termination condition has already been reached")
for message in messages:
@ -513,7 +512,7 @@ class FunctionCallTermination(TerminationCondition, Component[FunctionCallTermin
def terminated(self) -> bool:
return self._terminated
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
if self._terminated:
raise TerminatedException("Termination condition has already been reached")
for message in messages:

View File

@ -1,6 +1,6 @@
"""
This module defines various message types used for agent-to-agent communication.
Each message type inherits either from the ChatMessage class or BaseAgentEvent
Each message type inherits either from the BaseChatMessage class or BaseAgentEvent
class and includes specific fields relevant to the type of message being sent.
"""
@ -10,33 +10,36 @@ from typing import Any, Dict, Generic, List, Literal, Mapping, TypeVar
from autogen_core import FunctionCall, Image
from autogen_core.memory import MemoryContent
from autogen_core.models import FunctionExecutionResult, LLMMessage, RequestUsage, UserMessage
from pydantic import BaseModel, ConfigDict, computed_field
from typing_extensions import Self
from pydantic import BaseModel, Field, computed_field
from typing_extensions import Annotated, Self
class BaseMessage(BaseModel, ABC):
"""Base class for all message types in AgentChat. This is an abstract class
with default implementations for serialization and deserialization.
"""Abstract base class for all message types in AgentChat.
.. warning::
If you want to create a new message type, do not inherit from this class.
Instead, inherit from :class:`ChatMessage` or :class:`AgentEvent`
Instead, inherit from :class:`BaseChatMessage` or :class:`BaseAgentEvent`
to clarify the purpose of the message type.
"""
@computed_field
def type(self) -> str:
"""The class name of this message."""
return self.__class__.__name__
@abstractmethod
def to_text(self) -> str:
"""Convert the message content to a string-only representation
that can be rendered in the console and inspected by the user or conditions.
This is not used for creating text-only content for models.
For :class:`BaseChatMessage` types, use :meth:`to_model_text` instead."""
...
def dump(self) -> Mapping[str, Any]:
"""Convert the message to a JSON-serializable dictionary.
The default implementation uses the Pydantic model's `model_dump` method.
If you want to customize the serialization, override this method.
The default implementation uses the Pydantic model's
:meth:`model_dump` method to convert the message to a dictionary.
Override this method if you want to customize the serialization
process or add additional fields to the output.
"""
return self.model_dump()
@ -44,14 +47,15 @@ class BaseMessage(BaseModel, ABC):
def load(cls, data: Mapping[str, Any]) -> Self:
"""Create a message from a dictionary of JSON-serializable data.
The default implementation uses the Pydantic model's `model_validate` method.
If you want to customize the deserialization, override this method.
"""
The default implementation uses the Pydantic model's
:meth:`model_validate` method to create the message from the data.
Override this method if you want to customize the deserialization
process or add additional fields to the input data."""
return cls.model_validate(data)
class ChatMessage(BaseMessage, ABC):
"""Base class for chat messages.
class BaseChatMessage(BaseMessage, ABC):
"""Abstract base class for chat messages.
.. note::
@ -62,7 +66,7 @@ class ChatMessage(BaseMessage, ABC):
This class is used for messages that are sent between agents in a chat
conversation. Agents are expected to process the content of the
message using models and return a response as another :class:`ChatMessage`.
message using models and return a response as another :class:`BaseChatMessage`.
"""
source: str
@ -74,17 +78,6 @@ class ChatMessage(BaseMessage, ABC):
metadata: Dict[str, str] = {}
"""Additional metadata about the message."""
model_config = ConfigDict(arbitrary_types_allowed=True)
@abstractmethod
def to_text(self) -> str:
"""Convert the content of the message to a string-only representation
that can be rendered in the console and inspected by the user or conditions.
This is not used for creating text-only content for models.
For :class:`ChatMessage` types, use :meth:`to_model_text` instead."""
...
@abstractmethod
def to_model_text(self) -> str:
"""Convert the content of the message to text-only representation.
@ -107,8 +100,8 @@ class ChatMessage(BaseMessage, ABC):
...
class TextChatMessage(ChatMessage, ABC):
"""Base class for all text-only :class:`ChatMessage` types.
class BaseTextChatMessage(BaseChatMessage, ABC):
"""Base class for all text-only :class:`BaseChatMessage` types.
It has implementations for :meth:`to_text`, :meth:`to_model_text`,
and :meth:`to_model_message` methods.
@ -128,7 +121,7 @@ class TextChatMessage(ChatMessage, ABC):
return UserMessage(content=self.content, source=self.source)
class AgentEvent(BaseMessage, ABC):
class BaseAgentEvent(BaseMessage, ABC):
"""Base class for agent events.
.. note::
@ -153,24 +146,13 @@ class AgentEvent(BaseMessage, ABC):
metadata: Dict[str, str] = {}
"""Additional metadata about the message."""
model_config = ConfigDict(arbitrary_types_allowed=True)
@abstractmethod
def to_text(self) -> str:
"""Convert the content of the message to a string-only representation
that can be rendered in the console and inspected by the user.
This is not used for creating text-only content for models.
For :class:`ChatMessage` types, use :meth:`to_model_text` instead."""
...
StructuredContentType = TypeVar("StructuredContentType", bound=BaseModel, covariant=True)
"""Type variable for structured content types."""
class StructuredMessage(ChatMessage, Generic[StructuredContentType]):
"""A :class:`ChatMessage` type with an unspecified content type.
class StructuredMessage(BaseChatMessage, Generic[StructuredContentType]):
"""A :class:`BaseChatMessage` type with an unspecified content type.
To create a new structured message type, specify the content type
as a subclass of `Pydantic BaseModel <https://docs.pydantic.dev/latest/concepts/models/>`_.
@ -199,6 +181,10 @@ class StructuredMessage(ChatMessage, Generic[StructuredContentType]):
"""The content of the message. Must be a subclass of
`Pydantic BaseModel <https://docs.pydantic.dev/latest/concepts/models/>`_."""
@computed_field
def type(self) -> str:
return self.__class__.__name__
def to_text(self) -> str:
return self.content.model_dump_json(indent=2)
@ -212,18 +198,20 @@ class StructuredMessage(ChatMessage, Generic[StructuredContentType]):
)
class TextMessage(TextChatMessage):
class TextMessage(BaseTextChatMessage):
"""A text message with string-only content."""
...
type: Literal["TextMessage"] = "TextMessage"
class MultiModalMessage(ChatMessage):
class MultiModalMessage(BaseChatMessage):
"""A multimodal message."""
content: List[str | Image]
"""The content of the message."""
type: Literal["MultiModalMessage"] = "MultiModalMessage"
def to_model_text(self, image_placeholder: str | None = "[image]") -> str:
"""Convert the content of the message to a string-only representation.
If an image is present, it will be replaced with the image placeholder
@ -258,13 +246,13 @@ class MultiModalMessage(ChatMessage):
return UserMessage(content=self.content, source=self.source)
class StopMessage(TextChatMessage):
class StopMessage(BaseTextChatMessage):
"""A message requesting stop of a conversation."""
...
type: Literal["StopMessage"] = "StopMessage"
class HandoffMessage(TextChatMessage):
class HandoffMessage(BaseTextChatMessage):
"""A message requesting handoff of a conversation to another agent."""
target: str
@ -273,34 +261,40 @@ class HandoffMessage(TextChatMessage):
context: List[LLMMessage] = []
"""The model context to be passed to the target agent."""
type: Literal["HandoffMessage"] = "HandoffMessage"
class ToolCallSummaryMessage(TextChatMessage):
class ToolCallSummaryMessage(BaseTextChatMessage):
"""A message signaling the summary of tool call results."""
...
type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage"
class ToolCallRequestEvent(AgentEvent):
class ToolCallRequestEvent(BaseAgentEvent):
"""An event signaling a request to use tools."""
content: List[FunctionCall]
"""The tool calls."""
type: Literal["ToolCallRequestEvent"] = "ToolCallRequestEvent"
def to_text(self) -> str:
return str(self.content)
class ToolCallExecutionEvent(AgentEvent):
class ToolCallExecutionEvent(BaseAgentEvent):
"""An event signaling the execution of tool calls."""
content: List[FunctionExecutionResult]
"""The tool call results."""
type: Literal["ToolCallExecutionEvent"] = "ToolCallExecutionEvent"
def to_text(self) -> str:
return str(self.content)
class UserInputRequestedEvent(AgentEvent):
class UserInputRequestedEvent(BaseAgentEvent):
"""An event signaling a that the user proxy has requested user input. Published prior to invoking the input callback."""
request_id: str
@ -309,31 +303,37 @@ class UserInputRequestedEvent(AgentEvent):
content: Literal[""] = ""
"""Empty content for compat with consumers expecting a content field."""
type: Literal["UserInputRequestedEvent"] = "UserInputRequestedEvent"
def to_text(self) -> str:
return str(self.content)
class MemoryQueryEvent(AgentEvent):
class MemoryQueryEvent(BaseAgentEvent):
"""An event signaling the results of memory queries."""
content: List[MemoryContent]
"""The memory query results."""
type: Literal["MemoryQueryEvent"] = "MemoryQueryEvent"
def to_text(self) -> str:
return str(self.content)
class ModelClientStreamingChunkEvent(AgentEvent):
class ModelClientStreamingChunkEvent(BaseAgentEvent):
"""An event signaling a text output chunk from a model client in streaming mode."""
content: str
"""A string chunk from the model client."""
type: Literal["ModelClientStreamingChunkEvent"] = "ModelClientStreamingChunkEvent"
def to_text(self) -> str:
return self.content
class ThoughtEvent(AgentEvent):
class ThoughtEvent(BaseAgentEvent):
"""An event signaling the thought process of a model.
It is used to communicate the reasoning tokens generated by a reasoning model,
or the extra text content generated by a function call."""
@ -341,6 +341,8 @@ class ThoughtEvent(AgentEvent):
content: str
"""The thought process of the model."""
type: Literal["ThoughtEvent"] = "ThoughtEvent"
def to_text(self) -> str:
return self.content
@ -354,7 +356,7 @@ class MessageFactory:
"""
def __init__(self) -> None:
self._message_types: Dict[str, type[AgentEvent | ChatMessage]] = {}
self._message_types: Dict[str, type[BaseAgentEvent | BaseChatMessage]] = {}
# Register all message types.
self._message_types[TextMessage.__name__] = TextMessage
self._message_types[MultiModalMessage.__name__] = MultiModalMessage
@ -368,29 +370,31 @@ class MessageFactory:
self._message_types[ModelClientStreamingChunkEvent.__name__] = ModelClientStreamingChunkEvent
self._message_types[ThoughtEvent.__name__] = ThoughtEvent
def is_registered(self, message_type: type[AgentEvent | ChatMessage]) -> bool:
def is_registered(self, message_type: type[BaseAgentEvent | BaseChatMessage]) -> bool:
"""Check if a message type is registered with the factory."""
# Get the class name of the message type.
class_name = message_type.__name__
# Check if the class name is already registered.
return class_name in self._message_types
def register(self, message_type: type[AgentEvent | ChatMessage]) -> None:
def register(self, message_type: type[BaseAgentEvent | BaseChatMessage]) -> None:
"""Register a new message type with the factory."""
if self.is_registered(message_type):
raise ValueError(f"Message type {message_type} is already registered.")
if not issubclass(message_type, ChatMessage) and not issubclass(message_type, AgentEvent):
raise ValueError(f"Message type {message_type} must be a subclass of ChatMessage or AgentEvent.")
if not issubclass(message_type, BaseChatMessage) and not issubclass(message_type, BaseAgentEvent):
raise ValueError(f"Message type {message_type} must be a subclass of BaseChatMessage or BaseAgentEvent.")
# Get the class name of the message type.
class_name = message_type.__name__
# Check if the class name is already registered.
# Register the message type.
self._message_types[class_name] = message_type
def create(self, data: Mapping[str, Any]) -> AgentEvent | ChatMessage:
def create(self, data: Mapping[str, Any]) -> BaseAgentEvent | BaseChatMessage:
"""Create a message from a dictionary of JSON-serializable data."""
# Get the type of the message from the dictionary.
message_type = data.get("type")
if message_type is None:
raise ValueError("Field 'type' is required in the message data to recover the message type.")
if message_type not in self._message_types:
raise ValueError(f"Unknown message type: {message_type}")
if not isinstance(message_type, str):
@ -400,14 +404,26 @@ class MessageFactory:
message_class = self._message_types[message_type]
# Create an instance of the message class.
assert issubclass(message_class, ChatMessage) or issubclass(message_class, AgentEvent)
assert issubclass(message_class, BaseChatMessage) or issubclass(message_class, BaseAgentEvent)
return message_class.load(data)
# For backward compatibility
BaseAgentEvent = AgentEvent
BaseChatMessage = ChatMessage
ChatMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
]
"""The union type of all built-in concrete subclasses of :class:`BaseChatMessage`.
It does not include :class:`StructuredMessage` types."""
AgentEvent = Annotated[
ToolCallRequestEvent
| ToolCallExecutionEvent
| MemoryQueryEvent
| UserInputRequestedEvent
| ModelClientStreamingChunkEvent
| ThoughtEvent,
Field(discriminator="type"),
]
"""The union type of all built-in concrete subclasses of :class:`BaseAgentEvent`."""
__all__ = [
"AgentEvent",
@ -415,9 +431,8 @@ __all__ = [
"ChatMessage",
"BaseChatMessage",
"BaseAgentEvent",
"AgentEvent",
"TextChatMessage",
"ChatMessage",
"BaseTextChatMessage",
"BaseChatMessage",
"StructuredContentType",
"StructuredMessage",
"HandoffMessage",

View File

@ -18,8 +18,8 @@ from pydantic import BaseModel, ValidationError
from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import (
AgentEvent,
ChatMessage,
BaseAgentEvent,
BaseChatMessage,
MessageFactory,
ModelClientStreamingChunkEvent,
StopMessage,
@ -50,7 +50,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
termination_condition: TerminationCondition | None = None,
max_turns: int | None = None,
runtime: AgentRuntime | None = None,
custom_message_types: List[type[AgentEvent | ChatMessage]] | None = None,
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
):
if len(participants) == 0:
raise ValueError("At least one participant is required.")
@ -90,7 +90,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
self._output_topic_type = f"output_topic_{self._team_id}"
# The queue for collecting the output messages.
self._output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination] = asyncio.Queue()
self._output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination] = (
asyncio.Queue()
)
# Create a runtime for the team.
if runtime is not None:
@ -117,7 +119,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
max_turns: int | None,
message_factory: MessageFactory,
@ -195,7 +197,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
async def run(
self,
*,
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the team and return the result. The base implementation uses
@ -203,7 +205,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
Once the team is stopped, the termination condition is reset.
Args:
task (str | ChatMessage | Sequence[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
task (str | BaseChatMessage | Sequence[BaseChatMessage] | None): The task to run the team with. Can be a string, a single :class:`BaseChatMessage` , or a list of :class:`BaseChatMessage`.
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
Setting the cancellation token potentially puts the team in an inconsistent state,
and it may not reset the termination condition.
@ -297,9 +299,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
async def run_stream(
self,
*,
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]:
"""Run the team and produces a stream of messages and the final result
of the type :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream. Once the
team is stopped, the termination condition is reset.
@ -311,14 +313,14 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
:attr:`~autogen_agentchat.base.TaskResult.messages`.
Args:
task (str | ChatMessage | Sequence[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
task (str | BaseChatMessage | Sequence[BaseChatMessage] | None): The task to run the team with. Can be a string, a single :class:`BaseChatMessage` , or a list of :class:`BaseChatMessage`.
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
Setting the cancellation token potentially puts the team in an inconsistent state,
and it may not reset the termination condition.
To gracefully stop the team, use :class:`~autogen_agentchat.conditions.ExternalTermination` instead.
Returns:
stream: an :class:`~collections.abc.AsyncGenerator` that yields :class:`~autogen_agentchat.messages.AgentEvent`, :class:`~autogen_agentchat.messages.ChatMessage`, and the final result :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream.
stream: an :class:`~collections.abc.AsyncGenerator` that yields :class:`~autogen_agentchat.messages.BaseAgentEvent`, :class:`~autogen_agentchat.messages.BaseChatMessage`, and the final result :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream.
Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:
@ -398,23 +400,23 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
"""
# Create the messages list if the task is a string or a chat message.
messages: List[ChatMessage] | None = None
messages: List[BaseChatMessage] | None = None
if task is None:
pass
elif isinstance(task, str):
messages = [TextMessage(content=task, source="user")]
elif isinstance(task, ChatMessage):
elif isinstance(task, BaseChatMessage):
messages = [task]
elif isinstance(task, list):
if not task:
raise ValueError("Task list cannot be empty.")
messages = []
for msg in task:
if not isinstance(msg, ChatMessage):
raise ValueError("All messages in task list must be valid ChatMessage types")
if not isinstance(msg, BaseChatMessage):
raise ValueError("All messages in task list must be valid BaseChatMessage types")
messages.append(msg)
else:
raise ValueError("Task must be a string, a ChatMessage, or a list of ChatMessage.")
raise ValueError("Task must be a string, a BaseChatMessage, or a list of BaseChatMessage.")
# Check if the messages types are registered with the message factory.
if messages is not None:
for msg in messages:
@ -469,7 +471,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
cancellation_token=cancellation_token,
)
# Collect the output messages in order.
output_messages: List[AgentEvent | ChatMessage] = []
output_messages: List[BaseAgentEvent | BaseChatMessage] = []
stop_reason: str | None = None
# Yield the messages until the queue is empty.
while True:

View File

@ -5,7 +5,7 @@ from typing import Any, List
from autogen_core import DefaultTopicId, MessageContext, event, rpc
from ...base import TerminationCondition
from ...messages import AgentEvent, ChatMessage, MessageFactory, StopMessage
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, StopMessage
from ._events import (
GroupChatAgentResponse,
GroupChatMessage,
@ -39,7 +39,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
max_turns: int | None,
message_factory: MessageFactory,
@ -67,7 +67,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
name: topic_type for name, topic_type in zip(participant_names, participant_topic_types, strict=True)
}
self._participant_descriptions = participant_descriptions
self._message_thread: List[AgentEvent | ChatMessage] = []
self._message_thread: List[BaseAgentEvent | BaseChatMessage] = []
self._output_message_queue = output_message_queue
self._termination_condition = termination_condition
if max_turns is not None and max_turns <= 0:
@ -141,7 +141,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
@event
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
# Append the message to the message thread and construct the delta.
delta: List[AgentEvent | ChatMessage] = []
delta: List[BaseAgentEvent | BaseChatMessage] = []
if message.agent_response.inner_messages is not None:
for inner_message in message.agent_response.inner_messages:
self._message_thread.append(inner_message)
@ -225,7 +225,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
pass
@abstractmethod
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
"""Validate the state of the group chat given the start messages.
This is executed when the group chat manager receives a GroupChatStart event.
@ -235,7 +235,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
...
@abstractmethod
async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
"""Select a speaker from the participants and return the
topic type of the selected speaker."""
...

View File

@ -2,7 +2,7 @@ from typing import Any, List, Mapping
from autogen_core import DefaultTopicId, MessageContext, event, rpc
from autogen_agentchat.messages import AgentEvent, ChatMessage, MessageFactory
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, MessageFactory
from ...base import ChatAgent, Response
from ...state import ChatAgentContainerState
@ -46,7 +46,7 @@ class ChatAgentContainer(SequentialRoutedAgent):
self._parent_topic_type = parent_topic_type
self._output_topic_type = output_topic_type
self._agent = agent
self._message_buffer: List[ChatMessage] = []
self._message_buffer: List[BaseChatMessage] = []
self._message_factory = message_factory
@event
@ -90,13 +90,13 @@ class ChatAgentContainer(SequentialRoutedAgent):
cancellation_token=ctx.cancellation_token,
)
def _buffer_message(self, message: ChatMessage) -> None:
def _buffer_message(self, message: BaseChatMessage) -> None:
if not self._message_factory.is_registered(message.__class__):
raise ValueError(f"Message type {message.__class__} is not registered.")
# Buffer the message.
self._message_buffer.append(message)
async def _log_message(self, message: AgentEvent | ChatMessage) -> None:
async def _log_message(self, message: BaseAgentEvent | BaseChatMessage) -> None:
if not self._message_factory.is_registered(message.__class__):
raise ValueError(f"Message type {message.__class__} is not registered.")
# Log the message.
@ -130,7 +130,7 @@ class ChatAgentContainer(SequentialRoutedAgent):
self._message_buffer = []
for message_data in container_state.message_buffer:
message = self._message_factory.create(message_data)
if isinstance(message, ChatMessage):
if isinstance(message, BaseChatMessage):
self._message_buffer.append(message)
else:
raise ValueError(f"Invalid message type in message buffer: {type(message)}")

View File

@ -3,13 +3,13 @@ from typing import List
from pydantic import BaseModel
from ...base import Response
from ...messages import AgentEvent, ChatMessage, StopMessage
from ...messages import BaseAgentEvent, BaseChatMessage, StopMessage
class GroupChatStart(BaseModel):
"""A request to start a group chat."""
messages: List[ChatMessage] | None = None
messages: List[BaseChatMessage] | None = None
"""An optional list of messages to start the group chat."""
@ -29,7 +29,7 @@ class GroupChatRequestPublish(BaseModel):
class GroupChatMessage(BaseModel):
"""A message from a group chat."""
message: AgentEvent | ChatMessage
message: BaseAgentEvent | BaseChatMessage
"""The message that was published."""

View File

@ -9,7 +9,7 @@ from typing_extensions import Self
from .... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
from ....base import ChatAgent, TerminationCondition
from ....messages import AgentEvent, ChatMessage, MessageFactory
from ....messages import BaseAgentEvent, BaseChatMessage, MessageFactory
from .._base_group_chat import BaseGroupChat
from .._events import GroupChatTermination
from ._magentic_one_orchestrator import MagenticOneOrchestrator
@ -128,7 +128,7 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig])
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
max_turns: int | None,
message_factory: MessageFactory,

View File

@ -15,8 +15,8 @@ from autogen_core.models import (
from .... import TRACE_LOGGER_NAME
from ....base import Response, TerminationCondition
from ....messages import (
AgentEvent,
ChatMessage,
BaseAgentEvent,
BaseChatMessage,
HandoffMessage,
MessageFactory,
MultiModalMessage,
@ -66,7 +66,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
model_client: ChatCompletionClient,
max_stalls: int,
final_answer_prompt: str,
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
):
super().__init__(
@ -184,7 +184,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
@event
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None: # type: ignore
delta: List[AgentEvent | ChatMessage] = []
delta: List[BaseAgentEvent | BaseChatMessage] = []
if message.agent_response.inner_messages is not None:
for inner_message in message.agent_response.inner_messages:
delta.append(inner_message)
@ -201,7 +201,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
return
await self._orchestrate_step(ctx.cancellation_token)
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
pass
async def save_state(self) -> Mapping[str, Any]:
@ -226,7 +226,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
self._n_rounds = orchestrator_state.n_rounds
self._n_stalls = orchestrator_state.n_stalls
async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
"""Not used in this orchestrator, we select next speaker in _orchestrate_step."""
return ""


@ -6,7 +6,7 @@ from pydantic import BaseModel
from typing_extensions import Self
from ...base import ChatAgent, TerminationCondition
from ...messages import AgentEvent, ChatMessage, MessageFactory
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory
from ...state import RoundRobinManagerState
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
@ -24,7 +24,7 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
max_turns: int | None,
message_factory: MessageFactory,
@ -43,7 +43,7 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
)
self._next_speaker_index = 0
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
pass
async def reset(self) -> None:
@ -67,7 +67,7 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
self._current_turn = round_robin_state.current_turn
self._next_speaker_index = round_robin_state.next_speaker_index
async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
"""Select a speaker from the participants in a round-robin fashion."""
current_speaker_index = self._next_speaker_index
self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_names)
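The rotation is plain modular arithmetic over the participant list; a tiny self-contained sketch of the cycling behavior:

```python
participants = ["alice", "bob", "carol"]


def make_round_robin():
    next_index = 0

    def select() -> str:
        nonlocal next_index
        current = next_index
        next_index = (current + 1) % len(participants)  # wrap around after the last participant
        return participants[current]

    return select


select = make_round_robin()
assert [select() for _ in range(5)] == ["alice", "bob", "carol", "alice", "bob"]
```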
@ -166,7 +166,7 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
termination_condition: TerminationCondition | None = None,
max_turns: int | None = None,
runtime: AgentRuntime | None = None,
custom_message_types: List[type[AgentEvent | ChatMessage]] | None = None,
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
) -> None:
super().__init__(
participants,
@ -186,7 +186,7 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
max_turns: int | None,
message_factory: MessageFactory,


@ -13,8 +13,8 @@ from ... import TRACE_LOGGER_NAME
from ...agents import BaseChatAgent
from ...base import ChatAgent, TerminationCondition
from ...messages import (
AgentEvent,
ChatMessage,
BaseAgentEvent,
BaseChatMessage,
MessageFactory,
)
from ...state import SelectorManagerState
@ -24,12 +24,12 @@ from ._events import GroupChatTermination
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
SyncSelectorFunc = Callable[[Sequence[AgentEvent | ChatMessage]], str | None]
AsyncSelectorFunc = Callable[[Sequence[AgentEvent | ChatMessage]], Awaitable[str | None]]
SyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None]
AsyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | None]]
SelectorFuncType = Union[SyncSelectorFunc | AsyncSelectorFunc]
SyncCandidateFunc = Callable[[Sequence[AgentEvent | ChatMessage]], List[str]]
AsyncCandidateFunc = Callable[[Sequence[AgentEvent | ChatMessage]], Awaitable[List[str]]]
SyncCandidateFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], List[str]]
AsyncCandidateFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[List[str]]]
CandidateFuncType = Union[SyncCandidateFunc | AsyncCandidateFunc]
@ -45,7 +45,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
max_turns: int | None,
message_factory: MessageFactory,
@ -78,7 +78,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
self._candidate_func = candidate_func
self._is_candidate_func_async = iscoroutinefunction(self._candidate_func)
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
pass
async def reset(self) -> None:
@ -102,7 +102,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
self._current_turn = selector_state.current_turn
self._previous_speaker = selector_state.previous_speaker
async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
"""Selects the next speaker in a group chat using a ChatCompletion client,
with the selector function as override if it returns a speaker name.
@ -153,7 +153,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
# Construct the history of the conversation.
history_messages: List[str] = []
for msg in thread:
if not isinstance(msg, ChatMessage):
if not isinstance(msg, BaseChatMessage):
# Only process chat messages.
continue
message = f"{msg.source}: {msg.to_model_text()}"
@ -299,11 +299,11 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
max_selector_attempts (int, optional): The maximum number of attempts to select a speaker using the model. Defaults to 3.
If the model fails to select a speaker after the maximum number of attempts, the previous speaker will be used if available,
otherwise the first participant will be used.
selector_func (Callable[[Sequence[AgentEvent | ChatMessage]], str | None], Callable[[Sequence[AgentEvent | ChatMessage]], Awaitable[str | None]], optional): A custom selector
selector_func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None], Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | None]], optional): A custom selector
function that takes the conversation history and returns the name of the next speaker.
If provided, this function will be used to override the model to select the next speaker.
If the function returns None, the model will be used to select the next speaker.
candidate_func (Callable[[Sequence[AgentEvent | ChatMessage]], List[str]], Callable[[Sequence[AgentEvent | ChatMessage]], Awaitable[List[str]]], optional):
candidate_func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], List[str]], Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[List[str]]], optional):
A custom function that takes the conversation history and returns a filtered list of candidates for the next speaker
selection using the model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`.
This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set.
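As an illustration, a hedged sketch of a `candidate_func` matching the signature above (the agent names are placeholders): it pins the planner right after a user turn and otherwise leaves the full pool to the model:

```python
from typing import List, Sequence

from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage


def candidate_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str]:
    # Right after the user speaks, only the planner should be considered.
    if not messages or messages[-1].source == "user":
        return ["planning_agent"]
    # Otherwise let the model pick among all participants.
    return ["planning_agent", "web_search_agent", "data_analyst"]
```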
@ -378,7 +378,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.conditions import TextMentionTermination
from autogen_agentchat.ui import Console
from autogen_agentchat.messages import AgentEvent, ChatMessage
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage
async def main() -> None:
@ -404,7 +404,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
system_message="Check the answer and respond with 'Correct!' or 'Incorrect!'",
)
def selector_func(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:
def selector_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> str | None:
if len(messages) == 1 or messages[-1].to_text() == "Incorrect!":
return "Agent1"
if messages[-1].source == "Agent1":
@ -448,7 +448,7 @@ Read the above conversation. Then select the next role from {participants} to pl
max_selector_attempts: int = 3,
selector_func: Optional[SelectorFuncType] = None,
candidate_func: Optional[CandidateFuncType] = None,
custom_message_types: List[type[AgentEvent | ChatMessage]] | None = None,
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
):
super().__init__(
participants,
@ -477,7 +477,7 @@ Read the above conversation. Then select the next role from {participants} to pl
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
max_turns: int | None,
message_factory: MessageFactory,
@ -525,7 +525,7 @@ Read the above conversation. Then select the next role from {participants} to pl
selector_prompt=config.selector_prompt,
allow_repeated_speaker=config.allow_repeated_speaker,
max_selector_attempts=config.max_selector_attempts,
# selector_func=ComponentLoader.load_component(config.selector_func, Callable[[Sequence[AgentEvent | ChatMessage]], str | None])
# selector_func=ComponentLoader.load_component(config.selector_func, Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None])
# if config.selector_func
# else None,
)


@ -5,7 +5,7 @@ from autogen_core import AgentRuntime, Component, ComponentModel
from pydantic import BaseModel
from ...base import ChatAgent, TerminationCondition
from ...messages import AgentEvent, ChatMessage, HandoffMessage, MessageFactory
from ...messages import BaseAgentEvent, BaseChatMessage, HandoffMessage, MessageFactory
from ...state import SwarmManagerState
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
@ -23,7 +23,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
max_turns: int | None,
message_factory: MessageFactory,
@ -42,7 +42,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):
)
self._current_speaker = self._participant_names[0]
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
"""Validate the start messages for the group chat."""
# Check if any of the start messages is a handoff message.
if messages:
@ -77,7 +77,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):
await self._termination_condition.reset()
self._current_speaker = self._participant_names[0]
async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
"""Select a speaker from the participants based on handoff message.
Looks for the last handoff message in the thread to determine the next speaker."""
if len(thread) == 0:
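The rule is a reverse scan of the thread for the most recent `HandoffMessage`. A self-contained sketch (keeping the current speaker when no handoff exists is an assumption of this sketch, not necessarily the manager's exact fallback):

```python
from typing import List

from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, HandoffMessage, TextMessage


def last_handoff_target(thread: List[BaseAgentEvent | BaseChatMessage], current: str) -> str:
    for message in reversed(thread):  # newest first
        if isinstance(message, HandoffMessage):
            return message.target
    return current  # assumed fallback: no handoff yet


thread: List[BaseAgentEvent | BaseChatMessage] = [
    TextMessage(content="start", source="user"),
    HandoffMessage(content="Transferred to bob.", source="alice", target="bob"),
]
assert last_handoff_target(thread, current="alice") == "bob"
```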
@ -212,7 +212,7 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
termination_condition: TerminationCondition | None = None,
max_turns: int | None = None,
runtime: AgentRuntime | None = None,
custom_message_types: List[type[AgentEvent | ChatMessage]] | None = None,
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
) -> None:
super().__init__(
participants,
@ -236,7 +236,7 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
max_turns: int | None,
message_factory: MessageFactory,


@ -11,8 +11,8 @@ from autogen_core.models import RequestUsage
from autogen_agentchat.agents import UserProxyAgent
from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.messages import (
AgentEvent,
ChatMessage,
BaseAgentEvent,
BaseChatMessage,
ModelClientStreamingChunkEvent,
MultiModalMessage,
UserInputRequestedEvent,
@ -80,7 +80,7 @@ def aprint(output: str, end: str = "\n", flush: bool = False) -> Awaitable[None]
async def Console(
stream: AsyncGenerator[AgentEvent | ChatMessage | T, None],
stream: AsyncGenerator[BaseAgentEvent | BaseChatMessage | T, None],
*,
no_inline_images: bool = False,
output_stats: bool = False,
@ -97,7 +97,7 @@ async def Console(
It will be improved in future releases.
Args:
stream (AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None] | AsyncGenerator[AgentEvent | ChatMessage | Response, None]): Message stream to render.
stream (AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None] | AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]): Message stream to render.
This can be from :meth:`~autogen_agentchat.base.TaskRunner.run_stream` or :meth:`~autogen_agentchat.base.ChatAgent.on_messages_stream`.
no_inline_images (bool, optional): If the terminal is iTerm2, images will be rendered inline. Use this to disable that behavior. Defaults to False.
output_stats (bool, optional): (Experimental) If True, will output a summary of the messages and inline token usage info. Defaults to False.
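A minimal usage sketch, with an assumed model name and task (any `run_stream` producing the `BaseAgentEvent | BaseChatMessage` stream works):

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    model_client = OpenAIChatCompletionClient(model="gpt-4o")  # assumed model name
    agent = AssistantAgent("assistant", model_client=model_client)
    # Console consumes the message stream and renders each item as it arrives.
    await Console(agent.run_stream(task="Say hello."))


asyncio.run(main())
```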
@ -170,7 +170,7 @@ async def Console(
user_input_manager.notify_event_received(message.request_id)
else:
# Cast required for mypy to be happy
message = cast(AgentEvent | ChatMessage, message) # type: ignore
message = cast(BaseAgentEvent | BaseChatMessage, message) # type: ignore
if not streaming_chunks:
# Print message sender.
await aprint(f"{'-' * 10} {message.source} {'-' * 10}", end="\n", flush=True)


@ -7,7 +7,7 @@ from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import Handoff, TaskResult
from autogen_agentchat.messages import (
ChatMessage,
BaseChatMessage,
HandoffMessage,
MemoryQueryEvent,
ModelClientStreamingChunkEvent,
@ -737,7 +737,7 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
)
# Create a list of chat messages
messages: List[ChatMessage] = [
messages: List[BaseChatMessage] = [
TextMessage(content="Message 1", source="user"),
TextMessage(content="Message 2", source="user"),
]


@ -2,7 +2,7 @@ import asyncio
import json
import logging
import tempfile
from typing import AsyncGenerator, List, Sequence
from typing import Any, AsyncGenerator, List, Mapping, Sequence
import pytest
import pytest_asyncio
@ -15,8 +15,8 @@ from autogen_agentchat.agents import (
from autogen_agentchat.base import Handoff, Response, TaskResult
from autogen_agentchat.conditions import HandoffTermination, MaxMessageTermination, TextMentionTermination
from autogen_agentchat.messages import (
AgentEvent,
ChatMessage,
BaseAgentEvent,
BaseChatMessage,
HandoffMessage,
MultiModalMessage,
StopMessage,
@ -60,14 +60,14 @@ class _EchoAgent(BaseChatAgent):
self._total_messages = 0
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
return (TextMessage,)
@property
def total_messages(self) -> int:
return self._total_messages
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
if len(messages) > 0:
assert isinstance(messages[0], TextMessage)
self._last_message = messages[0].content
@ -89,21 +89,21 @@ class _FlakyAgent(BaseChatAgent):
self._total_messages = 0
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
return (TextMessage,)
@property
def total_messages(self) -> int:
return self._total_messages
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
raise ValueError("I am a flaky agent...")
async def on_reset(self, cancellation_token: CancellationToken) -> None:
self._last_message = None
class _UnknownMessageType(ChatMessage):
class _UnknownMessageType(BaseChatMessage):
content: str
def to_model_message(self) -> UserMessage:
@ -115,16 +115,23 @@ class _UnknownMessageType(ChatMessage):
def to_text(self) -> str:
raise NotImplementedError("This message type is not supported.")
def dump(self) -> Mapping[str, Any]:
return {}
@classmethod
def load(cls, data: Mapping[str, Any]) -> "_UnknownMessageType":
return cls(**data)
class _UnknownMessageTypeAgent(BaseChatAgent):
def __init__(self, name: str, description: str) -> None:
super().__init__(name, description)
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
return (_UnknownMessageType,)
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
return Response(chat_message=_UnknownMessageType(content="Unknown message type", source=self.name))
async def on_reset(self, cancellation_token: CancellationToken) -> None:
@ -138,10 +145,10 @@ class _StopAgent(_EchoAgent):
self._stop_at = stop_at
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
return (TextMessage, StopMessage)
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
self._count += 1
if self._count < self._stop_at:
return await super().on_messages(messages, cancellation_token)
@ -162,7 +169,7 @@ class _InputTask2(BaseModel):
data: str
TaskType = str | List[ChatMessage] | ChatMessage
TaskType = str | List[BaseChatMessage] | BaseChatMessage
@pytest_asyncio.fixture(params=["single_threaded", "embedded"]) # type: ignore
@ -821,7 +828,7 @@ async def test_selector_group_chat_custom_selector(runtime: AgentRuntime | None)
agent3 = _EchoAgent("agent3", description="echo agent 3")
agent4 = _EchoAgent("agent4", description="echo agent 4")
def _select_agent(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:
def _select_agent(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> str | None:
if len(messages) == 0:
return "agent1"
elif messages[-1].source == "agent1":
@ -862,7 +869,7 @@ async def test_selector_group_chat_custom_candidate_func(runtime: AgentRuntime |
agent3 = _EchoAgent("agent3", description="echo agent 3")
agent4 = _EchoAgent("agent4", description="echo agent 4")
def _candidate_func(messages: Sequence[AgentEvent | ChatMessage]) -> List[str]:
def _candidate_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str]:
if len(messages) == 0:
return ["agent1"]
elif messages[-1].source == "agent1":
@ -901,10 +908,10 @@ class _HandOffAgent(BaseChatAgent):
self._next_agent = next_agent
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
return (HandoffMessage,)
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
return Response(
chat_message=HandoffMessage(
content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name
@ -1292,7 +1299,7 @@ async def test_round_robin_group_chat_with_message_list(runtime: AgentRuntime |
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination, runtime=runtime)
# Create a list of messages
messages: List[ChatMessage] = [
messages: List[BaseChatMessage] = [
TextMessage(content="Message 1", source="user"),
TextMessage(content="Message 2", source="user"),
TextMessage(content="Message 3", source="user"),
@ -1324,7 +1331,7 @@ async def test_round_robin_group_chat_with_message_list(runtime: AgentRuntime |
index += 1
# Test with invalid message list
with pytest.raises(ValueError, match="All messages in task list must be valid ChatMessage types"):
with pytest.raises(ValueError, match="All messages in task list must be valid BaseChatMessage types"):
await team.run(task=["not a message"]) # type: ignore[list-item, arg-type] # intentionally testing invalid input
# Test with empty message list


@ -4,7 +4,7 @@ from typing import List, Sequence
import pytest
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import AgentEvent, ChatMessage
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.ui import Console
from autogen_core.models import ChatCompletionClient
@ -33,7 +33,7 @@ async def _test_selector_group_chat(model_client: ChatCompletionClient) -> None:
async def _test_selector_group_chat_with_candidate_func(model_client: ChatCompletionClient) -> None:
filtered_participants = ["developer", "tester"]
def dummy_candidate_func(thread: Sequence[AgentEvent | ChatMessage]) -> List[str]:
def dummy_candidate_func(thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str]:
# Dummy candidate function that always returns
# only the filtered participants: developer and tester.
return filtered_participants


@ -5,7 +5,7 @@ import pytest
import pytest_asyncio
from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import ChatMessage, TextMessage
from autogen_agentchat.messages import BaseChatMessage, TextMessage
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_core import AgentRuntime, CancellationToken, SingleThreadedAgentRuntime
@ -20,10 +20,10 @@ class TestAgent(BaseChatAgent):
self.counter = 0
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
return [TextMessage]
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
assert not self._is_paused, "Agent is paused"
async def _process() -> None:


@ -11,7 +11,7 @@ from autogen_agentchat.agents import (
)
from autogen_agentchat.base import Response
from autogen_agentchat.messages import (
ChatMessage,
BaseChatMessage,
TextMessage,
)
from autogen_agentchat.teams import (
@ -34,14 +34,14 @@ class _EchoAgent(BaseChatAgent):
self._total_messages = 0
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
return (TextMessage,)
@property
def total_messages(self) -> int:
return self._total_messages
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
if len(messages) > 0:
assert isinstance(messages[0], TextMessage)
self._last_message = messages[0].content


@ -1,5 +1,22 @@
import json
from typing import List
import pytest
from autogen_agentchat.messages import HandoffMessage, MessageFactory, StructuredMessage, TextMessage
from autogen_agentchat.messages import (
AgentEvent,
ChatMessage,
HandoffMessage,
MessageFactory,
ModelClientStreamingChunkEvent,
MultiModalMessage,
StopMessage,
StructuredMessage,
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
)
from autogen_core import FunctionCall
from autogen_core.models import FunctionExecutionResult
from pydantic import BaseModel
@ -18,7 +35,7 @@ def test_structured_message() -> None:
)
# Check that the message type is correct
assert message.type == "StructuredMessage[TestContent]" # type: ignore
assert message.type == "StructuredMessage[TestContent]" # type: ignore[comparison-overlap]
# Check that the content is of the correct type
assert isinstance(message.content, TestContent)
@ -50,7 +67,7 @@ def test_message_factory() -> None:
assert isinstance(text_message, TextMessage)
assert text_message.source == "test_agent"
assert text_message.content == "Hello, world!"
assert text_message.type == "TextMessage" # type: ignore
assert text_message.type == "TextMessage" # type: ignore[comparison-overlap]
# Handoff message data
handoff_data = {
@ -66,7 +83,7 @@ def test_message_factory() -> None:
assert handoff_message.source == "test_agent"
assert handoff_message.content == "handoff to another agent"
assert handoff_message.target == "target_agent"
assert handoff_message.type == "HandoffMessage" # type: ignore
assert handoff_message.type == "HandoffMessage" # type: ignore[comparison-overlap]
# Structured message data
structured_data = {
@ -86,8 +103,48 @@ def test_message_factory() -> None:
# Create a StructuredMessage instance
structured_message = factory.create(structured_data)
assert isinstance(structured_message, StructuredMessage)
assert isinstance(structured_message.content, TestContent) # type: ignore
assert isinstance(structured_message.content, TestContent)  # type: ignore[reportUnknownMemberType]
assert structured_message.source == "test_agent"
assert structured_message.content.field1 == "test"
assert structured_message.content.field2 == 42
assert structured_message.type == "StructuredMessage[TestContent]" # type: ignore
assert structured_message.type == "StructuredMessage[TestContent]" # type: ignore[comparison-overlap]
class TestContainer(BaseModel):
chat_messages: List[ChatMessage]
agent_events: List[AgentEvent]
def test_union_types() -> None:
# Create a few messages.
chat_messages: List[ChatMessage] = [
TextMessage(source="user", content="Hello!"),
MultiModalMessage(source="user", content=["Hello!", "World!"]),
HandoffMessage(source="user", content="handoff to another agent", target="target_agent"),
StopMessage(source="user", content="stop"),
]
# Create a few agent events.
agent_events: List[AgentEvent] = [
ModelClientStreamingChunkEvent(source="user", content="Hello!"),
ToolCallRequestEvent(
content=[
FunctionCall(id="1", name="test_function", arguments=json.dumps({"arg1": "value1", "arg2": "value2"}))
],
source="user",
),
ToolCallExecutionEvent(
content=[FunctionExecutionResult(call_id="1", content="result", name="test")], source="user"
),
]
# Create a container with the messages.
container = TestContainer(chat_messages=chat_messages, agent_events=agent_events)
# Dump the container to JSON.
data = container.model_dump()
# Load the container from JSON.
loaded_container = TestContainer.model_validate(data)
assert loaded_container.chat_messages == chat_messages
assert loaded_container.agent_events == agent_events
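The same round trip holds over JSON, which is the path application containers typically exercise; a short sketch with a minimal container re-declared for self-containment:

```python
from typing import List

from pydantic import BaseModel

from autogen_agentchat.messages import ChatMessage, HandoffMessage, TextMessage


class Conversation(BaseModel):
    chat_messages: List[ChatMessage]


conversation = Conversation(
    chat_messages=[
        TextMessage(source="user", content="Hello!"),
        HandoffMessage(source="user", content="handoff", target="target_agent"),
    ]
)

# Concrete fields such as HandoffMessage.target survive because ChatMessage
# is a union of the built-in concrete types, not a bare base class.
restored = Conversation.model_validate_json(conversation.model_dump_json())
assert restored.chat_messages == conversation.chat_messages
```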


@ -4,7 +4,7 @@ from typing import Optional, Sequence
import pytest
from autogen_agentchat.agents import UserProxyAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import ChatMessage, HandoffMessage, TextMessage
from autogen_agentchat.messages import BaseChatMessage, HandoffMessage, TextMessage
from autogen_core import CancellationToken
@ -53,7 +53,7 @@ async def test_handoff_handling() -> None:
agent = UserProxyAgent(name="test_user", input_func=custom_input)
messages: Sequence[ChatMessage] = [
messages: Sequence[BaseChatMessage] = [
TextMessage(content="Initial message", source="assistant"),
HandoffMessage(content="Handing off to user for confirmation", source="assistant", target="test_user"),
]


@ -14,7 +14,7 @@
"\n",
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages`: The abstract method that defines the behavior of the agent in response to messages. This method is called when the agent is asked to provide a response in {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run`. It returns a {py:class}`~autogen_agentchat.base.Response` object.\n",
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_reset`: The abstract method that resets the agent to its initial state. This method is called when the agent is asked to reset itself.\n",
"- {py:attr}`~autogen_agentchat.agents.BaseChatAgent.produced_message_types`: The list of possible {py:class}`~autogen_agentchat.messages.ChatMessage` message types the agent can produce in its response.\n",
"- {py:attr}`~autogen_agentchat.agents.BaseChatAgent.produced_message_types`: The list of possible {py:class}`~autogen_agentchat.messages.BaseChatMessage` message types the agent can produce in its response.\n",
"\n",
"Optionally, you can implement the the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream` method to stream messages as they are generated by the agent. If this method is not implemented, the agent\n",
"uses the default implementation of {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream`\n",
@ -53,7 +53,7 @@
"\n",
"from autogen_agentchat.agents import BaseChatAgent\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.messages import AgentEvent, ChatMessage, TextMessage\n",
"from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, TextMessage\n",
"from autogen_core import CancellationToken\n",
"\n",
"\n",
@ -63,10 +63,10 @@
" self._count = count\n",
"\n",
" @property\n",
" def produced_message_types(self) -> Sequence[type[ChatMessage]]:\n",
" def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:\n",
" return (TextMessage,)\n",
"\n",
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
" async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:\n",
" # Calls the on_messages_stream.\n",
" response: Response | None = None\n",
" async for message in self.on_messages_stream(messages, cancellation_token):\n",
@ -76,9 +76,9 @@
" return response\n",
"\n",
" async def on_messages_stream(\n",
" self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken\n",
" ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:\n",
" inner_messages: List[AgentEvent | ChatMessage] = []\n",
" self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken\n",
" ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:\n",
" inner_messages: List[BaseAgentEvent | BaseChatMessage] = []\n",
" for i in range(self._count, 0, -1):\n",
" msg = TextMessage(content=f\"{i}...\", source=self.name)\n",
" inner_messages.append(msg)\n",
@ -135,7 +135,7 @@
"from autogen_agentchat.agents import BaseChatAgent\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.conditions import MaxMessageTermination\n",
"from autogen_agentchat.messages import ChatMessage\n",
"from autogen_agentchat.messages import BaseChatMessage\n",
"from autogen_agentchat.teams import SelectorGroupChat\n",
"from autogen_agentchat.ui import Console\n",
"from autogen_core import CancellationToken\n",
@ -146,13 +146,13 @@
" def __init__(self, name: str, description: str, operator_func: Callable[[int], int]) -> None:\n",
" super().__init__(name, description=description)\n",
" self._operator_func = operator_func\n",
" self._message_history: List[ChatMessage] = []\n",
" self._message_history: List[BaseChatMessage] = []\n",
"\n",
" @property\n",
" def produced_message_types(self) -> Sequence[type[ChatMessage]]:\n",
" def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:\n",
" return (TextMessage,)\n",
"\n",
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
" async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:\n",
" # Update the message history.\n",
" # NOTE: it is possible the messages is an empty list, which means the agent was selected previously.\n",
" self._message_history.extend(messages)\n",
@ -268,7 +268,7 @@
" )\n",
"\n",
" # Run the selector group chat with a given task and stream the response.\n",
" task: List[ChatMessage] = [\n",
" task: List[BaseChatMessage] = [\n",
" TextMessage(content=\"Apply the operations to turn the given number into 25.\", source=\"user\"),\n",
" TextMessage(content=\"10\", source=\"user\"),\n",
" ]\n",
@ -319,7 +319,7 @@
"\n",
"from autogen_agentchat.agents import BaseChatAgent\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.messages import AgentEvent, ChatMessage\n",
"from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage\n",
"from autogen_core import CancellationToken\n",
"from autogen_core.model_context import UnboundedChatCompletionContext\n",
"from autogen_core.models import AssistantMessage, RequestUsage, UserMessage\n",
@ -344,10 +344,10 @@
" self._model = model\n",
"\n",
" @property\n",
" def produced_message_types(self) -> Sequence[type[ChatMessage]]:\n",
" def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:\n",
" return (TextMessage,)\n",
"\n",
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
" async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:\n",
" final_response = None\n",
" async for message in self.on_messages_stream(messages, cancellation_token):\n",
" if isinstance(message, Response):\n",
@ -359,8 +359,8 @@
" return final_response\n",
"\n",
" async def on_messages_stream(\n",
" self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken\n",
" ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:\n",
" self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken\n",
" ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:\n",
" # Add messages to the model context\n",
" for msg in messages:\n",
" await self._model_context.add_message(msg.to_model_message())\n",
@ -550,7 +550,7 @@
"\n",
"from autogen_agentchat.agents import BaseChatAgent\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.messages import AgentEvent, ChatMessage\n",
"from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage\n",
"from autogen_core import CancellationToken, Component\n",
"from pydantic import BaseModel\n",
"from typing_extensions import Self\n",
@ -583,10 +583,10 @@
" self._model = model\n",
"\n",
" @property\n",
" def produced_message_types(self) -> Sequence[type[ChatMessage]]:\n",
" def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:\n",
" return (TextMessage,)\n",
"\n",
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
" async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:\n",
" final_response = None\n",
" async for message in self.on_messages_stream(messages, cancellation_token):\n",
" if isinstance(message, Response):\n",
@ -598,8 +598,8 @@
" return final_response\n",
"\n",
" async def on_messages_stream(\n",
" self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken\n",
" ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:\n",
" self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken\n",
" ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:\n",
" # Add messages to the model context\n",
" for msg in messages:\n",
" await self._model_context.add_message(msg.to_model_message())\n",

File diff suppressed because one or more lines are too long


@ -462,18 +462,18 @@ and implement the `on_messages`, `on_reset`, and `produced_message_types` method
from typing import Sequence
from autogen_core import CancellationToken
from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.messages import TextMessage, ChatMessage
from autogen_agentchat.messages import TextMessage, BaseChatMessage
from autogen_agentchat.base import Response
class CustomAgent(BaseChatAgent):
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
return Response(chat_message=TextMessage(content="Custom reply", source=self.name))
async def on_reset(self, cancellation_token: CancellationToken) -> None:
pass
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
return (TextMessage,)
```
@ -742,8 +742,8 @@ You can use the following conversion functions to convert between a v0.4 message
from typing import Any, Dict, List, Literal
from autogen_agentchat.messages import (
AgentEvent,
ChatMessage,
BaseAgentEvent,
BaseChatMessage,
HandoffMessage,
MultiModalMessage,
StopMessage,
@ -757,14 +757,14 @@ from autogen_core.models import FunctionExecutionResult
def convert_to_v02_message(
message: AgentEvent | ChatMessage,
message: BaseAgentEvent | BaseChatMessage,
role: Literal["assistant", "user", "tool"],
image_detail: Literal["auto", "high", "low"] = "auto",
) -> Dict[str, Any]:
"""Convert a v0.4 AgentChat message to a v0.2 message.
Args:
message (AgentEvent | ChatMessage): The message to convert.
message (BaseAgentEvent | BaseChatMessage): The message to convert.
role (Literal["assistant", "user", "tool"]): The role of the message.
image_detail (Literal["auto", "high", "low"], optional): The detail level of image content in multi-modal message. Defaults to "auto".
@ -810,7 +810,7 @@ def convert_to_v02_message(
return v02_message
def convert_to_v04_message(message: Dict[str, Any]) -> AgentEvent | ChatMessage:
def convert_to_v04_message(message: Dict[str, Any]) -> BaseAgentEvent | BaseChatMessage:
"""Convert a v0.2 message to a v0.4 AgentChat message."""
if "tool_calls" in message:
tool_calls: List[FunctionCall] = []
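A brief usage sketch of the two converters, assuming the definitions above are in scope (the v0.2 dict keys here follow the v0.2 message format: content, role, name):

```python
from autogen_agentchat.messages import TextMessage

# v0.4 -> v0.2: a typed AgentChat message becomes a plain dict.
v02_message = convert_to_v02_message(
    TextMessage(content="Hello!", source="assistant"),
    role="assistant",
)

# v0.2 -> v0.4: a plain dict becomes a typed AgentChat message.
v04_message = convert_to_v04_message({"content": "Hello!", "role": "user", "name": "user"})
```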
@ -1065,7 +1065,7 @@ import asyncio
from typing import Sequence
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import MaxMessageTermination, TextMentionTermination
from autogen_agentchat.messages import AgentEvent, ChatMessage
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient
@ -1141,7 +1141,7 @@ def create_team(model_client : OpenAIChatCompletionClient) -> SelectorGroupChat:
# The selector function is a function that takes the current message thread of the group chat
# and returns the next speaker's name. If None is returned, the LLM-based selection method will be used.
def selector_func(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:
def selector_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> str | None:
if messages[-1].source != planning_agent.name:
return planning_agent.name # Always return to the planning agent after the other agents have spoken.
return None
@ -1190,12 +1190,12 @@ from typing import Sequence
from autogen_core import CancellationToken
from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.messages import TextMessage, ChatMessage
from autogen_agentchat.messages import TextMessage, BaseChatMessage
from autogen_agentchat.base import Response
class CountingAgent(BaseChatAgent):
"""An agent that returns a new number by adding 1 to the last number in the input messages."""
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
if len(messages) == 0:
last_number = 0 # Start from 0 if no messages are given.
else:
@ -1207,7 +1207,7 @@ class CountingAgent(BaseChatAgent):
pass
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
return (TextMessage,)
class NestedCountingAgent(BaseChatAgent):
@ -1217,7 +1217,7 @@ class NestedCountingAgent(BaseChatAgent):
super().__init__(name, description="An agent that counts numbers.")
self._counting_team = counting_team
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
# Run the inner team with the given messages and return the last message produced by the team.
result = await self._counting_team.run(task=messages, cancellation_token=cancellation_token)
# To stream the inner messages, implement `on_messages_stream` and use that to implement `on_messages`.
@ -1229,7 +1229,7 @@ class NestedCountingAgent(BaseChatAgent):
await self._counting_team.reset()
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
return (TextMessage,)
async def main() -> None:


@ -61,7 +61,7 @@
"\n",
"from autogen_agentchat.agents import AssistantAgent, UserProxyAgent\n",
"from autogen_agentchat.conditions import MaxMessageTermination, TextMentionTermination\n",
"from autogen_agentchat.messages import AgentEvent, ChatMessage\n",
"from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage\n",
"from autogen_agentchat.teams import SelectorGroupChat\n",
"from autogen_agentchat.ui import Console\n",
"from autogen_ext.models.openai import OpenAIChatCompletionClient"
@ -511,7 +511,7 @@
}
],
"source": [
"def selector_func(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:\n",
"def selector_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> str | None:\n",
" if messages[-1].source != planning_agent.name:\n",
" return planning_agent.name\n",
" return None\n",
@ -655,7 +655,7 @@
}
],
"source": [
"def candidate_func(messages: Sequence[AgentEvent | ChatMessage]) -> List[str]:\n",
"def candidate_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str]:\n",
" # keep planning_agent first one to plan out the tasks\n",
" if messages[-1].source == \"user\":\n",
" return [planning_agent.name]\n",
@ -813,7 +813,7 @@
"user_proxy_agent = UserProxyAgent(\"UserProxyAgent\", description=\"A proxy for the user to approve or disapprove tasks.\")\n",
"\n",
"\n",
"def selector_func_with_user_proxy(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:\n",
"def selector_func_with_user_proxy(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> str | None:\n",
" if messages[-1].source != planning_agent.name and messages[-1].source != user_proxy_agent.name:\n",
" # Planning agent should be the first to engage when given a new task, or check progress.\n",
" return planning_agent.name\n",
@ -1018,7 +1018,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.12.9"
}
},
"nbformat": 4,


@ -11,8 +11,8 @@
"\n",
"- {py:attr}`~autogen_agentchat.agents.BaseChatAgent.name`: The unique name of the agent.\n",
"- {py:attr}`~autogen_agentchat.agents.BaseChatAgent.description`: The description of the agent in text.\n",
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages`: Send the agent a sequence of {py:class}`~autogen_agentchat.messages.ChatMessage` and get a {py:class}`~autogen_agentchat.base.Response`. **It is important to note that agents are expected to be stateful and this method is expected to be called with new messages, not the complete history**.\n",
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream`: Same as {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages` but returns an iterator of {py:class}`~autogen_agentchat.messages.AgentEvent` or {py:class}`~autogen_agentchat.messages.ChatMessage` followed by a {py:class}`~autogen_agentchat.base.Response` as the last item.\n",
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages`: Send the agent a sequence of {py:class}`~autogen_agentchat.messages.BaseChatMessage` and get a {py:class}`~autogen_agentchat.base.Response`. **It is important to note that agents are expected to be stateful and this method is expected to be called with new messages, not the complete history**.\n",
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream`: Same as {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages` but returns an iterator of {py:class}`~autogen_agentchat.messages.BaseAgentEvent` or {py:class}`~autogen_agentchat.messages.BaseChatMessage` followed by a {py:class}`~autogen_agentchat.base.Response` as the last item.\n",
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_reset`: Reset the agent to its initial state.\n",
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run` and {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run_stream`: convenience methods that call {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages` and {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream` respectively but offer the same interface as [Teams](./teams.ipynb).\n",
"\n",
@ -840,7 +840,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.12.9"
}
},
"nbformat": 4,


@ -3,12 +3,11 @@ import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypeVar, cast, runtime_checkable
from typing_extensions import TypedDict
import jsonref
from opentelemetry.trace import get_tracer
from pydantic import BaseModel
from typing_extensions import NotRequired
from typing_extensions import NotRequired, TypedDict
from .. import EVENT_LOGGER_NAME, CancellationToken
from .._component_config import ComponentBase


@ -6,7 +6,7 @@ from typing import List, Sequence, Tuple
from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import (
ChatMessage,
BaseChatMessage,
TextMessage,
)
from autogen_agentchat.utils import remove_images
@ -84,10 +84,10 @@ class FileSurfer(BaseChatAgent, Component[FileSurferConfig]):
self._browser = MarkdownFileBrowser(viewport_size=1024 * 5, base_path=base_path)
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
return (TextMessage,)
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
for chat_message in messages:
self._chat_history.append(chat_message.to_model_message())
try:


@ -24,8 +24,8 @@ from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import (
AgentEvent,
ChatMessage,
BaseAgentEvent,
BaseChatMessage,
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
@ -353,7 +353,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
self._initial_message_ids = initial_message_ids
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
"""The types of messages that the assistant agent produces."""
return (TextMessage,)
@ -392,7 +392,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
result = await tool.run_json(arguments, cancellation_token)
return tool.return_value_as_string(result)
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
"""Handle incoming messages and return a response."""
async for message in self.on_messages_stream(messages, cancellation_token):
@ -401,8 +401,8 @@ class OpenAIAssistantAgent(BaseChatAgent):
raise AssertionError("The stream should have returned the final result.")
async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
"""Handle incoming messages and return a response."""
await self._ensure_initialized()
@ -411,7 +411,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
await self.handle_incoming_message(message, cancellation_token)
# Inner messages for tool calls
inner_messages: List[AgentEvent | ChatMessage] = []
inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
# Create and start a run
run: Run = await cancellation_token.link_future(
@ -518,7 +518,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
chat_message = TextMessage(source=self.name, content=text_content[0].text.value)
yield Response(chat_message=chat_message, inner_messages=inner_messages)
async def handle_incoming_message(self, message: ChatMessage, cancellation_token: CancellationToken) -> None:
async def handle_incoming_message(self, message: BaseChatMessage, cancellation_token: CancellationToken) -> None:
"""Handle regular text messages by adding them to the thread."""
content: str | List[MessageContentPartParam] | None = None
llm_message = message.to_model_message()


@ -24,7 +24,7 @@ import aiofiles
import PIL.Image
from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, TextMessage
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, MultiModalMessage, TextMessage
from autogen_agentchat.utils import content_to_str, remove_images
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, Component, ComponentModel, FunctionCall
from autogen_core import Image as AGImage
@ -385,7 +385,7 @@ class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]):
)
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
return (MultiModalMessage,)
async def on_reset(self, cancellation_token: CancellationToken) -> None:
@ -422,19 +422,19 @@ class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]):
)
)
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):
if isinstance(message, Response):
return message
raise AssertionError("The stream should have returned the final result.")
async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
for chat_message in messages:
self._chat_history.append(chat_message.to_model_message())
self.inner_messages: List[AgentEvent | ChatMessage] = []
self.inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
self.model_usage: List[RequestUsage] = []
try:
content = await self._generate_reply(cancellation_token=cancellation_token)


@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, List, Sequence, Tuple, TypedDict
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import AgentEvent, ChatMessage, TextMessage
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, TextMessage
from autogen_core.models import (
ChatCompletionClient,
LLMMessage,
@ -190,8 +190,8 @@ In responding to every user message, you follow the same multi-step process give
# Get the agent's response to the task.
task_result: TaskResult = await assistant_agent.run(task=TextMessage(content=task, source="User"))
messages: Sequence[AgentEvent | ChatMessage] = task_result.messages
message: AgentEvent | ChatMessage = messages[-1]
messages: Sequence[BaseAgentEvent | BaseChatMessage] = task_result.messages
message: BaseAgentEvent | BaseChatMessage = messages[-1]
response_str = message.to_text()
# Log the model call


@ -5,7 +5,7 @@ import shutil
from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import AgentEvent, ChatMessage
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage
from autogen_core import Image
from autogen_core.models import (
AssistantMessage,
@ -343,7 +343,7 @@ class PageLogger:
if self.level > self.levels["INFO"]:
return None
messages: Sequence[AgentEvent | ChatMessage] = task_result.messages
messages: Sequence[BaseAgentEvent | BaseChatMessage] = task_result.messages
message = messages[-1]
response_str = message.to_text()
if not isinstance(response_str, str):


@ -14,8 +14,8 @@ from typing import (
from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.messages import (
AgentEvent,
ChatMessage,
BaseAgentEvent,
BaseChatMessage,
ModelClientStreamingChunkEvent,
MultiModalMessage,
UserInputRequestedEvent,
@ -56,7 +56,7 @@ def aprint(output: str, end: str = "\n") -> Awaitable[None]:
return asyncio.to_thread(print, output, end=end)
def _extract_message_content(message: AgentEvent | ChatMessage) -> Tuple[List[str], List[Image]]:
def _extract_message_content(message: BaseAgentEvent | BaseChatMessage) -> Tuple[List[str], List[Image]]:
if isinstance(message, MultiModalMessage):
text_parts = [item for item in message.content if isinstance(item, str)]
image_parts = [item for item in message.content if isinstance(item, Image)]
@ -100,7 +100,7 @@ async def _aprint_message_content(
async def RichConsole(
stream: AsyncGenerator[AgentEvent | ChatMessage | T, None],
stream: AsyncGenerator[BaseAgentEvent | BaseChatMessage | T, None],
*,
no_inline_images: bool = False,
output_stats: bool = False,
@ -117,7 +117,7 @@ async def RichConsole(
It will be improved in future releases.
Args:
stream (AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None] | AsyncGenerator[AgentEvent | ChatMessage | Response, None]): Message stream to render.
stream (AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None] | AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]): Message stream to render.
This can be from :meth:`~autogen_agentchat.base.TaskRunner.run_stream` or :meth:`~autogen_agentchat.base.ChatAgent.on_messages_stream`.
no_inline_images (bool, optional): If the terminal is iTerm2, images will be rendered inline. Use this to disable that behavior. Defaults to False.
output_stats (bool, optional): (Experimental) If True, will output a summary of the messages and inline token usage info. Defaults to False.
@ -191,7 +191,7 @@ async def RichConsole(
pass
else:
# Cast required for mypy to be happy
message = cast(AgentEvent | ChatMessage, message) # type: ignore
message = cast(BaseAgentEvent | BaseChatMessage, message) # type: ignore
text_parts, image_parts = _extract_message_content(message)
# Add usage stats if needed


@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock
import aiofiles
import pytest
from autogen_agentchat.messages import ChatMessage, TextMessage, ToolCallRequestEvent
from autogen_agentchat.messages import BaseChatMessage, TextMessage, ToolCallRequestEvent
from autogen_core import CancellationToken
from autogen_core.tools._base import BaseTool, Tool
from autogen_ext.agents.openai import OpenAIAssistantAgent
@ -81,7 +81,7 @@ class FakeMessage:
class FakeCursorPage:
def __init__(self, data: List[ChatMessage | FakeMessage]) -> None:
def __init__(self, data: List[BaseChatMessage | FakeMessage]) -> None:
self.data = data
def has_next_page(self) -> bool:


@ -10,7 +10,7 @@ import aiofiles
import yaml
from autogen_agentchat.agents import UserProxyAgent
from autogen_agentchat.base import TaskResult, Team
from autogen_agentchat.messages import AgentEvent, ChatMessage
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage
from autogen_agentchat.teams import BaseGroupChat
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, Component, ComponentModel
from autogen_core.logging import LLMCallEvent
@ -102,7 +102,7 @@ class TeamManager:
input_func: Optional[Callable] = None,
cancellation_token: Optional[CancellationToken] = None,
env_vars: Optional[List[EnvironmentVariable]] = None,
) -> AsyncGenerator[Union[AgentEvent | ChatMessage | LLMCallEvent, ChatMessage, TeamResult], None]:
) -> AsyncGenerator[Union[BaseAgentEvent | BaseChatMessage | LLMCallEvent, BaseChatMessage, TeamResult], None]:
"""Stream team execution results"""
start_time = time.time()
team = None


@ -6,8 +6,8 @@ from typing import Any, Callable, Dict, Optional, Union
from autogen_agentchat.base._task import TaskResult
from autogen_agentchat.messages import (
AgentEvent,
ChatMessage,
BaseAgentEvent,
BaseChatMessage,
HandoffMessage,
ModelClientStreamingChunkEvent,
MultiModalMessage,
@ -160,7 +160,9 @@ class WebSocketManager:
finally:
self._cancellation_tokens.pop(run_id, None)
async def _save_message(self, run_id: int, message: Union[AgentEvent | ChatMessage, ChatMessage]) -> None:
async def _save_message(
self, run_id: int, message: Union[BaseAgentEvent | BaseChatMessage, BaseChatMessage]
) -> None:
"""Save a message to the database"""
run = await self._get_run(run_id)