mirror of https://github.com/microsoft/autogen.git
Rename to use BaseChatMessage and BaseAgentEvent. Bring back union types. (#6144)
Rename the `ChatMessage` and `AgentEvent` base classes to `BaseChatMessage` and `BaseAgentEvent`. Bring back the `ChatMessage` and `AgentEvent` as union of built-in concrete types to avoid breaking existing applications that depends on Pydantic serialization. Why? Many existing code uses containers like this: ```python class AppMessage(BaseModel): name: str message: ChatMessage # Serialization is this: m = AppMessage(...) m.model_dump_json() # Fields like HandoffMessage.target will be lost because it is now treated as a base class without content or target fields. ``` The assumption on `ChatMessage` or `AgentEvent` to be a union of concrete types could be in many existing code bases. So this PR brings back the union types, while keep method type hints such as those on `on_messages` to use the `BaseChatMessage` and `BaseAgentEvent` base classes for flexibility.
This commit is contained in:
parent
e686342f53
commit
7615c7b83b
|
@ -16,7 +16,7 @@ from autogen_core.models import ChatCompletionClient
|
|||
from autogen_ext.agents.web_surfer import MultimodalWebSurfer
|
||||
from autogen_ext.agents.file_surfer import FileSurfer
|
||||
from autogen_agentchat.agents import CodeExecutorAgent
|
||||
from autogen_agentchat.messages import TextMessage, AgentEvent, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage
|
||||
from autogen_agentchat.messages import TextMessage, BaseAgentEvent, BaseChatMessage, HandoffMessage, MultiModalMessage, StopMessage
|
||||
from autogen_core.models import LLMMessage, UserMessage, AssistantMessage
|
||||
|
||||
# Suppress warnings about the requests.Session() not being closed
|
||||
|
@ -141,7 +141,7 @@ class LLMTermination(TerminationCondition):
|
|||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# __init__.py
|
||||
from ._base import Code, Document, CodedDocument, BaseQualitativeCoder
|
||||
from ._base import BaseQualitativeCoder, Code, CodedDocument, Document
|
||||
|
||||
__all__ = ["Code", "Document", "CodedDocument", "BaseQualitativeCoder"]
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import json
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from typing import Protocol, List, Set, Optional
|
||||
from typing import List, Optional, Protocol, Set
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import os
|
||||
import argparse
|
||||
from typing import List, Sequence, Optional
|
||||
import os
|
||||
from typing import List, Optional, Sequence
|
||||
|
||||
from openai import OpenAI
|
||||
from ._base import Document, CodedDocument
|
||||
|
||||
from ._base import CodedDocument, Document
|
||||
from .coders.oai_coder import OAIQualitativeCoder
|
||||
|
||||
|
||||
|
|
|
@ -1,13 +1,11 @@
|
|||
import os
|
||||
import re
|
||||
|
||||
from typing import List, Set, Optional
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Set
|
||||
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .._base import CodedDocument, Document, Code
|
||||
from .._base import BaseQualitativeCoder
|
||||
from .._base import BaseQualitativeCoder, Code, CodedDocument, Document
|
||||
|
||||
|
||||
class CodeList(BaseModel):
|
||||
|
|
|
@ -40,8 +40,8 @@ from .. import EVENT_LOGGER_NAME
|
|||
from ..base import Handoff as HandoffBase
|
||||
from ..base import Response
|
||||
from ..messages import (
|
||||
AgentEvent,
|
||||
ChatMessage,
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
HandoffMessage,
|
||||
MemoryQueryEvent,
|
||||
ModelClientStreamingChunkEvent,
|
||||
|
@ -697,8 +697,8 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
|||
self._is_running = False
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
message_types: List[type[ChatMessage]] = [TextMessage]
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
message_types: List[type[BaseChatMessage]] = [TextMessage]
|
||||
if self._handoffs:
|
||||
message_types.append(HandoffMessage)
|
||||
if self._tools:
|
||||
|
@ -712,15 +712,15 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
|||
"""
|
||||
return self._model_context
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async for message in self.on_messages_stream(messages, cancellation_token):
|
||||
if isinstance(message, Response):
|
||||
return message
|
||||
raise AssertionError("The stream should have returned the final result.")
|
||||
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
"""
|
||||
Process the incoming messages with the assistant agent and yield events/responses as they happen.
|
||||
"""
|
||||
|
@ -745,7 +745,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
|||
)
|
||||
|
||||
# STEP 2: Update model context with any relevant memory
|
||||
inner_messages: List[AgentEvent | ChatMessage] = []
|
||||
inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
for event_msg in await self._update_model_context_with_memory(
|
||||
memory=memory,
|
||||
model_context=model_context,
|
||||
|
@ -810,7 +810,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
|||
@staticmethod
|
||||
async def _add_messages_to_context(
|
||||
model_context: ChatCompletionContext,
|
||||
messages: Sequence[ChatMessage],
|
||||
messages: Sequence[BaseChatMessage],
|
||||
) -> None:
|
||||
"""
|
||||
Add incoming messages to the model context.
|
||||
|
@ -886,7 +886,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
|||
async def _process_model_result(
|
||||
cls,
|
||||
model_result: CreateResult,
|
||||
inner_messages: List[AgentEvent | ChatMessage],
|
||||
inner_messages: List[BaseAgentEvent | BaseChatMessage],
|
||||
cancellation_token: CancellationToken,
|
||||
agent_name: str,
|
||||
system_messages: List[SystemMessage],
|
||||
|
@ -898,7 +898,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
|||
model_client_stream: bool,
|
||||
reflect_on_tool_use: bool,
|
||||
tool_call_summary_format: str,
|
||||
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
"""
|
||||
Handle final or partial responses from model_result, including tool calls, handoffs,
|
||||
and reflection if needed.
|
||||
|
@ -992,7 +992,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
|||
def _check_and_handle_handoff(
|
||||
model_result: CreateResult,
|
||||
executed_calls_and_results: List[Tuple[FunctionCall, FunctionExecutionResult]],
|
||||
inner_messages: List[AgentEvent | ChatMessage],
|
||||
inner_messages: List[BaseAgentEvent | BaseChatMessage],
|
||||
handoffs: Dict[str, HandoffBase],
|
||||
agent_name: str,
|
||||
) -> Optional[Response]:
|
||||
|
@ -1061,7 +1061,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
|||
model_client_stream: bool,
|
||||
model_context: ChatCompletionContext,
|
||||
agent_name: str,
|
||||
inner_messages: List[AgentEvent | ChatMessage],
|
||||
inner_messages: List[BaseAgentEvent | BaseChatMessage],
|
||||
) -> AsyncGenerator[Response | ModelClientStreamingChunkEvent | ThoughtEvent, None]:
|
||||
"""
|
||||
If reflect_on_tool_use=True, we do another inference based on tool results
|
||||
|
@ -1113,7 +1113,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
|||
@staticmethod
|
||||
def _summarize_tool_use(
|
||||
executed_calls_and_results: List[Tuple[FunctionCall, FunctionExecutionResult]],
|
||||
inner_messages: List[AgentEvent | ChatMessage],
|
||||
inner_messages: List[BaseAgentEvent | BaseChatMessage],
|
||||
handoffs: Dict[str, HandoffBase],
|
||||
tool_call_summary_format: str,
|
||||
agent_name: str,
|
||||
|
|
|
@ -6,8 +6,8 @@ from pydantic import BaseModel
|
|||
|
||||
from ..base import ChatAgent, Response, TaskResult
|
||||
from ..messages import (
|
||||
AgentEvent,
|
||||
ChatMessage,
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
ModelClientStreamingChunkEvent,
|
||||
TextMessage,
|
||||
)
|
||||
|
@ -59,13 +59,13 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
|
|||
|
||||
@property
|
||||
@abstractmethod
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
"""The types of messages that the agent produces in the
|
||||
:attr:`Response.chat_message` field. They must be :class:`ChatMessage` types."""
|
||||
:attr:`Response.chat_message` field. They must be :class:`BaseChatMessage` types."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
"""Handles incoming messages and returns a response.
|
||||
|
||||
.. note::
|
||||
|
@ -81,8 +81,8 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
|
|||
...
|
||||
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
"""Handles incoming messages and returns a stream of messages and
|
||||
and the final item is the response. The base implementation in
|
||||
:class:`BaseChatAgent` simply calls :meth:`on_messages` and yields
|
||||
|
@ -106,21 +106,21 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
|
|||
async def run(
|
||||
self,
|
||||
*,
|
||||
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
|
||||
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> TaskResult:
|
||||
"""Run the agent with the given task and return the result."""
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
input_messages: List[ChatMessage] = []
|
||||
output_messages: List[AgentEvent | ChatMessage] = []
|
||||
input_messages: List[BaseChatMessage] = []
|
||||
output_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
if task is None:
|
||||
pass
|
||||
elif isinstance(task, str):
|
||||
text_msg = TextMessage(content=task, source="user")
|
||||
input_messages.append(text_msg)
|
||||
output_messages.append(text_msg)
|
||||
elif isinstance(task, ChatMessage):
|
||||
elif isinstance(task, BaseChatMessage):
|
||||
input_messages.append(task)
|
||||
output_messages.append(task)
|
||||
else:
|
||||
|
@ -128,7 +128,7 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
|
|||
raise ValueError("Task list cannot be empty.")
|
||||
# Task is a sequence of messages.
|
||||
for msg in task:
|
||||
if isinstance(msg, ChatMessage):
|
||||
if isinstance(msg, BaseChatMessage):
|
||||
input_messages.append(msg)
|
||||
output_messages.append(msg)
|
||||
else:
|
||||
|
@ -142,15 +142,15 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
|
|||
async def run_stream(
|
||||
self,
|
||||
*,
|
||||
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
|
||||
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]:
|
||||
"""Run the agent with the given task and return a stream of messages
|
||||
and the final task result as the last item in the stream."""
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
input_messages: List[ChatMessage] = []
|
||||
output_messages: List[AgentEvent | ChatMessage] = []
|
||||
input_messages: List[BaseChatMessage] = []
|
||||
output_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
if task is None:
|
||||
pass
|
||||
elif isinstance(task, str):
|
||||
|
@ -158,7 +158,7 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
|
|||
input_messages.append(text_msg)
|
||||
output_messages.append(text_msg)
|
||||
yield text_msg
|
||||
elif isinstance(task, ChatMessage):
|
||||
elif isinstance(task, BaseChatMessage):
|
||||
input_messages.append(task)
|
||||
output_messages.append(task)
|
||||
yield task
|
||||
|
@ -166,7 +166,7 @@ class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
|
|||
if not task:
|
||||
raise ValueError("Task list cannot be empty.")
|
||||
for msg in task:
|
||||
if isinstance(msg, ChatMessage):
|
||||
if isinstance(msg, BaseChatMessage):
|
||||
input_messages.append(msg)
|
||||
output_messages.append(msg)
|
||||
yield msg
|
||||
|
|
|
@ -7,7 +7,7 @@ from pydantic import BaseModel
|
|||
from typing_extensions import Self
|
||||
|
||||
from ..base import Response
|
||||
from ..messages import ChatMessage, TextMessage
|
||||
from ..messages import BaseChatMessage, TextMessage
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
|
||||
|
||||
|
@ -119,11 +119,11 @@ class CodeExecutorAgent(BaseChatAgent, Component[CodeExecutorAgentConfig]):
|
|||
self._sources = sources
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
"""The types of messages that the code executor agent produces."""
|
||||
return (TextMessage,)
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
# Extract code blocks from the messages.
|
||||
code_blocks: List[CodeBlock] = []
|
||||
for msg in messages:
|
||||
|
|
|
@ -10,8 +10,8 @@ from autogen_agentchat.state import SocietyOfMindAgentState
|
|||
|
||||
from ..base import TaskResult, Team
|
||||
from ..messages import (
|
||||
AgentEvent,
|
||||
ChatMessage,
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
ModelClientStreamingChunkEvent,
|
||||
TextMessage,
|
||||
)
|
||||
|
@ -122,10 +122,10 @@ class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]):
|
|||
self._response_prompt = response_prompt
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return (TextMessage,)
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
# Call the stream method and collect the messages.
|
||||
response: Response | None = None
|
||||
async for msg in self.on_messages_stream(messages, cancellation_token):
|
||||
|
@ -135,14 +135,14 @@ class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]):
|
|||
return response
|
||||
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
# Prepare the task for the team of agents.
|
||||
task = list(messages)
|
||||
|
||||
# Run the team of agents.
|
||||
result: TaskResult | None = None
|
||||
inner_messages: List[AgentEvent | ChatMessage] = []
|
||||
inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
count = 0
|
||||
async for inner_msg in self._team.run_stream(task=task, cancellation_token=cancellation_token):
|
||||
if isinstance(inner_msg, TaskResult):
|
||||
|
@ -167,7 +167,7 @@ class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]):
|
|||
# Generate a response using the model client.
|
||||
llm_messages: List[LLMMessage] = [SystemMessage(content=self._instruction)]
|
||||
for message in messages:
|
||||
if isinstance(message, ChatMessage):
|
||||
if isinstance(message, BaseChatMessage):
|
||||
llm_messages.append(message.to_model_message())
|
||||
llm_messages.append(SystemMessage(content=self._response_prompt))
|
||||
completion = await self._model_client.create(messages=llm_messages, cancellation_token=cancellation_token)
|
||||
|
|
|
@ -10,7 +10,7 @@ from pydantic import BaseModel
|
|||
from typing_extensions import Self
|
||||
|
||||
from ..base import Response
|
||||
from ..messages import AgentEvent, ChatMessage, HandoffMessage, TextMessage, UserInputRequestedEvent
|
||||
from ..messages import BaseAgentEvent, BaseChatMessage, HandoffMessage, TextMessage, UserInputRequestedEvent
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
|
||||
SyncInputFunc = Callable[[str], str]
|
||||
|
@ -170,11 +170,11 @@ class UserProxyAgent(BaseChatAgent, Component[UserProxyAgentConfig]):
|
|||
self._is_async = iscoroutinefunction(self.input_func)
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
"""Message types this agent can produce."""
|
||||
return (TextMessage, HandoffMessage)
|
||||
|
||||
def _get_latest_handoff(self, messages: Sequence[ChatMessage]) -> Optional[HandoffMessage]:
|
||||
def _get_latest_handoff(self, messages: Sequence[BaseChatMessage]) -> Optional[HandoffMessage]:
|
||||
"""Find the HandoffMessage in the message sequence that addresses this agent."""
|
||||
if len(messages) > 0 and isinstance(messages[-1], HandoffMessage):
|
||||
if messages[-1].target == self.name:
|
||||
|
@ -201,15 +201,15 @@ class UserProxyAgent(BaseChatAgent, Component[UserProxyAgentConfig]):
|
|||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to get user input: {str(e)}") from e
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async for message in self.on_messages_stream(messages, cancellation_token):
|
||||
if isinstance(message, Response):
|
||||
return message
|
||||
raise AssertionError("The stream should have returned the final result.")
|
||||
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
"""Handle incoming messages by requesting user input."""
|
||||
try:
|
||||
# Check for handoff first
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import Any, AsyncGenerator, Mapping, Sequence
|
|||
from autogen_core import CancellationToken, ComponentBase
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..messages import AgentEvent, ChatMessage
|
||||
from ..messages import BaseAgentEvent, BaseChatMessage
|
||||
from ._task import TaskRunner
|
||||
|
||||
|
||||
|
@ -13,12 +13,12 @@ from ._task import TaskRunner
|
|||
class Response:
|
||||
"""A response from calling :meth:`ChatAgent.on_messages`."""
|
||||
|
||||
chat_message: ChatMessage
|
||||
chat_message: BaseChatMessage
|
||||
"""A chat message produced by the agent as the response."""
|
||||
|
||||
inner_messages: Sequence[AgentEvent | ChatMessage] | None = None
|
||||
"""Inner messages produced by the agent, they can be :class:`AgentEvent`
|
||||
or :class:`ChatMessage`."""
|
||||
inner_messages: Sequence[BaseAgentEvent | BaseChatMessage] | None = None
|
||||
"""Inner messages produced by the agent, they can be :class:`BaseAgentEvent`
|
||||
or :class:`BaseChatMessage`."""
|
||||
|
||||
|
||||
class ChatAgent(ABC, TaskRunner, ComponentBase[BaseModel]):
|
||||
|
@ -43,20 +43,20 @@ class ChatAgent(ABC, TaskRunner, ComponentBase[BaseModel]):
|
|||
|
||||
@property
|
||||
@abstractmethod
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
"""The types of messages that the agent produces in the
|
||||
:attr:`Response.chat_message` field. They must be :class:`ChatMessage` types."""
|
||||
:attr:`Response.chat_message` field. They must be :class:`BaseChatMessage` types."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
"""Handles incoming messages and returns a response."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def on_messages_stream(
|
||||
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
"""Handles incoming messages and returns a stream of inner messages and
|
||||
and the final item is the response."""
|
||||
...
|
||||
|
|
|
@ -3,14 +3,14 @@ from typing import AsyncGenerator, Protocol, Sequence
|
|||
|
||||
from autogen_core import CancellationToken
|
||||
|
||||
from ..messages import AgentEvent, ChatMessage
|
||||
from ..messages import BaseAgentEvent, BaseChatMessage
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskResult:
|
||||
"""Result of running a task."""
|
||||
|
||||
messages: Sequence[AgentEvent | ChatMessage]
|
||||
messages: Sequence[BaseAgentEvent | BaseChatMessage]
|
||||
"""Messages produced by the task."""
|
||||
|
||||
stop_reason: str | None = None
|
||||
|
@ -23,7 +23,7 @@ class TaskRunner(Protocol):
|
|||
async def run(
|
||||
self,
|
||||
*,
|
||||
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
|
||||
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> TaskResult:
|
||||
"""Run the task and return the result.
|
||||
|
@ -38,9 +38,9 @@ class TaskRunner(Protocol):
|
|||
def run_stream(
|
||||
self,
|
||||
*,
|
||||
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
|
||||
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]:
|
||||
"""Run the task and produces a stream of messages and the final result
|
||||
:class:`TaskResult` as the last item in the stream.
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from autogen_core import Component, ComponentBase, ComponentModel
|
|||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..messages import AgentEvent, ChatMessage, StopMessage
|
||||
from ..messages import BaseAgentEvent, BaseChatMessage, StopMessage
|
||||
|
||||
|
||||
class TerminatedException(BaseException): ...
|
||||
|
@ -15,7 +15,7 @@ class TerminatedException(BaseException): ...
|
|||
class TerminationCondition(ABC, ComponentBase[BaseModel]):
|
||||
"""A stateful condition that determines when a conversation should be terminated.
|
||||
|
||||
A termination condition is a callable that takes a sequence of ChatMessage objects
|
||||
A termination condition is a callable that takes a sequence of BaseChatMessage objects
|
||||
since the last time the condition was called, and returns a StopMessage if the
|
||||
conversation should be terminated, or None otherwise.
|
||||
Once a termination condition has been reached, it must be reset before it can be used again.
|
||||
|
@ -56,7 +56,7 @@ class TerminationCondition(ABC, ComponentBase[BaseModel]):
|
|||
...
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
"""Check if the conversation should be terminated based on the messages received
|
||||
since the last time the condition was called.
|
||||
Return a StopMessage if the conversation should be terminated, or None otherwise.
|
||||
|
@ -102,7 +102,7 @@ class AndTerminationCondition(TerminationCondition, Component[AndTerminationCond
|
|||
def terminated(self) -> bool:
|
||||
return all(condition.terminated for condition in self._conditions)
|
||||
|
||||
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self.terminated:
|
||||
raise TerminatedException("Termination condition has already been reached.")
|
||||
# Check all remaining conditions.
|
||||
|
@ -153,7 +153,7 @@ class OrTerminationCondition(TerminationCondition, Component[OrTerminationCondit
|
|||
def terminated(self) -> bool:
|
||||
return any(condition.terminated for condition in self._conditions)
|
||||
|
||||
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self.terminated:
|
||||
raise RuntimeError("Termination condition has already been reached")
|
||||
stop_messages = await asyncio.gather(*[condition(messages) for condition in self._conditions])
|
||||
|
|
|
@ -7,9 +7,8 @@ from typing_extensions import Self
|
|||
|
||||
from ..base import TerminatedException, TerminationCondition
|
||||
from ..messages import (
|
||||
AgentEvent,
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
|
@ -34,7 +33,7 @@ class StopMessageTermination(TerminationCondition, Component[StopMessageTerminat
|
|||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
for message in messages:
|
||||
|
@ -64,8 +63,8 @@ class MaxMessageTermination(TerminationCondition, Component[MaxMessageTerminatio
|
|||
|
||||
Args:
|
||||
max_messages: The maximum number of messages allowed in the conversation.
|
||||
include_agent_event: If True, include :class:`~autogen_agentchat.messages.AgentEvent` in the message count.
|
||||
Otherwise, only include :class:`~autogen_agentchat.messages.ChatMessage`. Defaults to False.
|
||||
include_agent_event: If True, include :class:`~autogen_agentchat.messages.BaseAgentEvent` in the message count.
|
||||
Otherwise, only include :class:`~autogen_agentchat.messages.BaseChatMessage`. Defaults to False.
|
||||
"""
|
||||
|
||||
component_config_schema = MaxMessageTerminationConfig
|
||||
|
@ -80,7 +79,7 @@ class MaxMessageTermination(TerminationCondition, Component[MaxMessageTerminatio
|
|||
def terminated(self) -> bool:
|
||||
return self._message_count >= self._max_messages
|
||||
|
||||
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self.terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
self._message_count += len([m for m in messages if self._include_agent_event or isinstance(m, BaseChatMessage)])
|
||||
|
@ -129,7 +128,7 @@ class TextMentionTermination(TerminationCondition, Component[TextMentionTerminat
|
|||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
for message in messages:
|
||||
|
@ -201,7 +200,7 @@ class TokenUsageTermination(TerminationCondition, Component[TokenUsageTerminatio
|
|||
or (self._max_completion_token is not None and self._completion_token_count >= self._max_completion_token)
|
||||
)
|
||||
|
||||
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self.terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
for message in messages:
|
||||
|
@ -258,7 +257,7 @@ class HandoffTermination(TerminationCondition, Component[HandoffTerminationConfi
|
|||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
for message in messages:
|
||||
|
@ -303,7 +302,7 @@ class TimeoutTermination(TerminationCondition, Component[TimeoutTerminationConfi
|
|||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
|
||||
|
@ -365,7 +364,7 @@ class ExternalTermination(TerminationCondition, Component[ExternalTerminationCon
|
|||
"""Set the termination condition to terminated."""
|
||||
self._setted = True
|
||||
|
||||
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
if self._setted:
|
||||
|
@ -410,7 +409,7 @@ class SourceMatchTermination(TerminationCondition, Component[SourceMatchTerminat
|
|||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
if not messages:
|
||||
|
@ -463,7 +462,7 @@ class TextMessageTermination(TerminationCondition, Component[TextMessageTerminat
|
|||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
for message in messages:
|
||||
|
@ -513,7 +512,7 @@ class FunctionCallTermination(TerminationCondition, Component[FunctionCallTermin
|
|||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
for message in messages:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
This module defines various message types used for agent-to-agent communication.
|
||||
Each message type inherits either from the ChatMessage class or BaseAgentEvent
|
||||
Each message type inherits either from the BaseChatMessage class or BaseAgentEvent
|
||||
class and includes specific fields relevant to the type of message being sent.
|
||||
"""
|
||||
|
||||
|
@ -10,33 +10,36 @@ from typing import Any, Dict, Generic, List, Literal, Mapping, TypeVar
|
|||
from autogen_core import FunctionCall, Image
|
||||
from autogen_core.memory import MemoryContent
|
||||
from autogen_core.models import FunctionExecutionResult, LLMMessage, RequestUsage, UserMessage
|
||||
from pydantic import BaseModel, ConfigDict, computed_field
|
||||
from typing_extensions import Self
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
from typing_extensions import Annotated, Self
|
||||
|
||||
|
||||
class BaseMessage(BaseModel, ABC):
|
||||
"""Base class for all message types in AgentChat. This is an abstract class
|
||||
with default implementations for serialization and deserialization.
|
||||
"""Abstract base class for all message types in AgentChat.
|
||||
|
||||
.. warning::
|
||||
|
||||
If you want to create a new message type, do not inherit from this class.
|
||||
Instead, inherit from :class:`ChatMessage` or :class:`AgentEvent`
|
||||
Instead, inherit from :class:`BaseChatMessage` or :class:`BaseAgentEvent`
|
||||
to clarify the purpose of the message type.
|
||||
|
||||
"""
|
||||
|
||||
@computed_field
|
||||
def type(self) -> str:
|
||||
"""The class name of this message."""
|
||||
return self.__class__.__name__
|
||||
@abstractmethod
|
||||
def to_text(self) -> str:
|
||||
"""Convert the message content to a string-only representation
|
||||
that can be rendered in the console and inspected by the user or conditions.
|
||||
This is not used for creating text-only content for models.
|
||||
For :class:`BaseChatMessage` types, use :meth:`to_model_text` instead."""
|
||||
...
|
||||
|
||||
def dump(self) -> Mapping[str, Any]:
|
||||
"""Convert the message to a JSON-serializable dictionary.
|
||||
|
||||
The default implementation uses the Pydantic model's `model_dump` method.
|
||||
|
||||
If you want to customize the serialization, override this method.
|
||||
The default implementation uses the Pydantic model's
|
||||
:meth:`model_dump` method to convert the message to a dictionary.
|
||||
Override this method if you want to customize the serialization
|
||||
process or add additional fields to the output.
|
||||
"""
|
||||
return self.model_dump()
|
||||
|
||||
|
@ -44,14 +47,15 @@ class BaseMessage(BaseModel, ABC):
|
|||
def load(cls, data: Mapping[str, Any]) -> Self:
|
||||
"""Create a message from a dictionary of JSON-serializable data.
|
||||
|
||||
The default implementation uses the Pydantic model's `model_validate` method.
|
||||
If you want to customize the deserialization, override this method.
|
||||
"""
|
||||
The default implementation uses the Pydantic model's
|
||||
:meth:`model_validate` method to create the message from the data.
|
||||
Override this method if you want to customize the deserialization
|
||||
process or add additional fields to the input data."""
|
||||
return cls.model_validate(data)
|
||||
|
||||
|
||||
class ChatMessage(BaseMessage, ABC):
|
||||
"""Base class for chat messages.
|
||||
class BaseChatMessage(BaseMessage, ABC):
|
||||
"""Abstract base class for chat messages.
|
||||
|
||||
.. note::
|
||||
|
||||
|
@ -62,7 +66,7 @@ class ChatMessage(BaseMessage, ABC):
|
|||
|
||||
This class is used for messages that are sent between agents in a chat
|
||||
conversation. Agents are expected to process the content of the
|
||||
message using models and return a response as another :class:`ChatMessage`.
|
||||
message using models and return a response as another :class:`BaseChatMessage`.
|
||||
"""
|
||||
|
||||
source: str
|
||||
|
@ -74,17 +78,6 @@ class ChatMessage(BaseMessage, ABC):
|
|||
metadata: Dict[str, str] = {}
|
||||
"""Additional metadata about the message."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@abstractmethod
|
||||
def to_text(self) -> str:
|
||||
"""Convert the content of the message to a string-only representation
|
||||
that can be rendered in the console and inspected by the user or conditions.
|
||||
|
||||
This is not used for creating text-only content for models.
|
||||
For :class:`ChatMessage` types, use :meth:`to_model_text` instead."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def to_model_text(self) -> str:
|
||||
"""Convert the content of the message to text-only representation.
|
||||
|
@ -107,8 +100,8 @@ class ChatMessage(BaseMessage, ABC):
|
|||
...
|
||||
|
||||
|
||||
class TextChatMessage(ChatMessage, ABC):
|
||||
"""Base class for all text-only :class:`ChatMessage` types.
|
||||
class BaseTextChatMessage(BaseChatMessage, ABC):
|
||||
"""Base class for all text-only :class:`BaseChatMessage` types.
|
||||
It has implementations for :meth:`to_text`, :meth:`to_model_text`,
|
||||
and :meth:`to_model_message` methods.
|
||||
|
||||
|
@ -128,7 +121,7 @@ class TextChatMessage(ChatMessage, ABC):
|
|||
return UserMessage(content=self.content, source=self.source)
|
||||
|
||||
|
||||
class AgentEvent(BaseMessage, ABC):
|
||||
class BaseAgentEvent(BaseMessage, ABC):
|
||||
"""Base class for agent events.
|
||||
|
||||
.. note::
|
||||
|
@ -153,24 +146,13 @@ class AgentEvent(BaseMessage, ABC):
|
|||
metadata: Dict[str, str] = {}
|
||||
"""Additional metadata about the message."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@abstractmethod
|
||||
def to_text(self) -> str:
|
||||
"""Convert the content of the message to a string-only representation
|
||||
that can be rendered in the console and inspected by the user.
|
||||
|
||||
This is not used for creating text-only content for models.
|
||||
For :class:`ChatMessage` types, use :meth:`to_model_text` instead."""
|
||||
...
|
||||
|
||||
|
||||
StructuredContentType = TypeVar("StructuredContentType", bound=BaseModel, covariant=True)
|
||||
"""Type variable for structured content types."""
|
||||
|
||||
|
||||
class StructuredMessage(ChatMessage, Generic[StructuredContentType]):
|
||||
"""A :class:`ChatMessage` type with an unspecified content type.
|
||||
class StructuredMessage(BaseChatMessage, Generic[StructuredContentType]):
|
||||
"""A :class:`BaseChatMessage` type with an unspecified content type.
|
||||
|
||||
To create a new structured message type, specify the content type
|
||||
as a subclass of `Pydantic BaseModel <https://docs.pydantic.dev/latest/concepts/models/>`_.
|
||||
|
@ -199,6 +181,10 @@ class StructuredMessage(ChatMessage, Generic[StructuredContentType]):
|
|||
"""The content of the message. Must be a subclass of
|
||||
`Pydantic BaseModel <https://docs.pydantic.dev/latest/concepts/models/>`_."""
|
||||
|
||||
@computed_field
|
||||
def type(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
def to_text(self) -> str:
|
||||
return self.content.model_dump_json(indent=2)
|
||||
|
||||
|
@ -212,18 +198,20 @@ class StructuredMessage(ChatMessage, Generic[StructuredContentType]):
|
|||
)
|
||||
|
||||
|
||||
class TextMessage(TextChatMessage):
|
||||
class TextMessage(BaseTextChatMessage):
|
||||
"""A text message with string-only content."""
|
||||
|
||||
...
|
||||
type: Literal["TextMessage"] = "TextMessage"
|
||||
|
||||
|
||||
class MultiModalMessage(ChatMessage):
|
||||
class MultiModalMessage(BaseChatMessage):
|
||||
"""A multimodal message."""
|
||||
|
||||
content: List[str | Image]
|
||||
"""The content of the message."""
|
||||
|
||||
type: Literal["MultiModalMessage"] = "MultiModalMessage"
|
||||
|
||||
def to_model_text(self, image_placeholder: str | None = "[image]") -> str:
|
||||
"""Convert the content of the message to a string-only representation.
|
||||
If an image is present, it will be replaced with the image placeholder
|
||||
|
@ -258,13 +246,13 @@ class MultiModalMessage(ChatMessage):
|
|||
return UserMessage(content=self.content, source=self.source)
|
||||
|
||||
|
||||
class StopMessage(TextChatMessage):
|
||||
class StopMessage(BaseTextChatMessage):
|
||||
"""A message requesting stop of a conversation."""
|
||||
|
||||
...
|
||||
type: Literal["StopMessage"] = "StopMessage"
|
||||
|
||||
|
||||
class HandoffMessage(TextChatMessage):
|
||||
class HandoffMessage(BaseTextChatMessage):
|
||||
"""A message requesting handoff of a conversation to another agent."""
|
||||
|
||||
target: str
|
||||
|
@ -273,34 +261,40 @@ class HandoffMessage(TextChatMessage):
|
|||
context: List[LLMMessage] = []
|
||||
"""The model context to be passed to the target agent."""
|
||||
|
||||
type: Literal["HandoffMessage"] = "HandoffMessage"
|
||||
|
||||
class ToolCallSummaryMessage(TextChatMessage):
|
||||
|
||||
class ToolCallSummaryMessage(BaseTextChatMessage):
|
||||
"""A message signaling the summary of tool call results."""
|
||||
|
||||
...
|
||||
type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage"
|
||||
|
||||
|
||||
class ToolCallRequestEvent(AgentEvent):
|
||||
class ToolCallRequestEvent(BaseAgentEvent):
|
||||
"""An event signaling a request to use tools."""
|
||||
|
||||
content: List[FunctionCall]
|
||||
"""The tool calls."""
|
||||
|
||||
type: Literal["ToolCallRequestEvent"] = "ToolCallRequestEvent"
|
||||
|
||||
def to_text(self) -> str:
|
||||
return str(self.content)
|
||||
|
||||
|
||||
class ToolCallExecutionEvent(AgentEvent):
|
||||
class ToolCallExecutionEvent(BaseAgentEvent):
|
||||
"""An event signaling the execution of tool calls."""
|
||||
|
||||
content: List[FunctionExecutionResult]
|
||||
"""The tool call results."""
|
||||
|
||||
type: Literal["ToolCallExecutionEvent"] = "ToolCallExecutionEvent"
|
||||
|
||||
def to_text(self) -> str:
|
||||
return str(self.content)
|
||||
|
||||
|
||||
class UserInputRequestedEvent(AgentEvent):
|
||||
class UserInputRequestedEvent(BaseAgentEvent):
|
||||
"""An event signaling a that the user proxy has requested user input. Published prior to invoking the input callback."""
|
||||
|
||||
request_id: str
|
||||
|
@ -309,31 +303,37 @@ class UserInputRequestedEvent(AgentEvent):
|
|||
content: Literal[""] = ""
|
||||
"""Empty content for compat with consumers expecting a content field."""
|
||||
|
||||
type: Literal["UserInputRequestedEvent"] = "UserInputRequestedEvent"
|
||||
|
||||
def to_text(self) -> str:
|
||||
return str(self.content)
|
||||
|
||||
|
||||
class MemoryQueryEvent(AgentEvent):
|
||||
class MemoryQueryEvent(BaseAgentEvent):
|
||||
"""An event signaling the results of memory queries."""
|
||||
|
||||
content: List[MemoryContent]
|
||||
"""The memory query results."""
|
||||
|
||||
type: Literal["MemoryQueryEvent"] = "MemoryQueryEvent"
|
||||
|
||||
def to_text(self) -> str:
|
||||
return str(self.content)
|
||||
|
||||
|
||||
class ModelClientStreamingChunkEvent(AgentEvent):
|
||||
class ModelClientStreamingChunkEvent(BaseAgentEvent):
|
||||
"""An event signaling a text output chunk from a model client in streaming mode."""
|
||||
|
||||
content: str
|
||||
"""A string chunk from the model client."""
|
||||
|
||||
type: Literal["ModelClientStreamingChunkEvent"] = "ModelClientStreamingChunkEvent"
|
||||
|
||||
def to_text(self) -> str:
|
||||
return self.content
|
||||
|
||||
|
||||
class ThoughtEvent(AgentEvent):
|
||||
class ThoughtEvent(BaseAgentEvent):
|
||||
"""An event signaling the thought process of a model.
|
||||
It is used to communicate the reasoning tokens generated by a reasoning model,
|
||||
or the extra text content generated by a function call."""
|
||||
|
@ -341,6 +341,8 @@ class ThoughtEvent(AgentEvent):
|
|||
content: str
|
||||
"""The thought process of the model."""
|
||||
|
||||
type: Literal["ThoughtEvent"] = "ThoughtEvent"
|
||||
|
||||
def to_text(self) -> str:
|
||||
return self.content
|
||||
|
||||
|
@ -354,7 +356,7 @@ class MessageFactory:
|
|||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._message_types: Dict[str, type[AgentEvent | ChatMessage]] = {}
|
||||
self._message_types: Dict[str, type[BaseAgentEvent | BaseChatMessage]] = {}
|
||||
# Register all message types.
|
||||
self._message_types[TextMessage.__name__] = TextMessage
|
||||
self._message_types[MultiModalMessage.__name__] = MultiModalMessage
|
||||
|
@ -368,29 +370,31 @@ class MessageFactory:
|
|||
self._message_types[ModelClientStreamingChunkEvent.__name__] = ModelClientStreamingChunkEvent
|
||||
self._message_types[ThoughtEvent.__name__] = ThoughtEvent
|
||||
|
||||
def is_registered(self, message_type: type[AgentEvent | ChatMessage]) -> bool:
|
||||
def is_registered(self, message_type: type[BaseAgentEvent | BaseChatMessage]) -> bool:
|
||||
"""Check if a message type is registered with the factory."""
|
||||
# Get the class name of the message type.
|
||||
class_name = message_type.__name__
|
||||
# Check if the class name is already registered.
|
||||
return class_name in self._message_types
|
||||
|
||||
def register(self, message_type: type[AgentEvent | ChatMessage]) -> None:
|
||||
def register(self, message_type: type[BaseAgentEvent | BaseChatMessage]) -> None:
|
||||
"""Register a new message type with the factory."""
|
||||
if self.is_registered(message_type):
|
||||
raise ValueError(f"Message type {message_type} is already registered.")
|
||||
if not issubclass(message_type, ChatMessage) and not issubclass(message_type, AgentEvent):
|
||||
raise ValueError(f"Message type {message_type} must be a subclass of ChatMessage or AgentEvent.")
|
||||
if not issubclass(message_type, BaseChatMessage) and not issubclass(message_type, BaseAgentEvent):
|
||||
raise ValueError(f"Message type {message_type} must be a subclass of BaseChatMessage or BaseAgentEvent.")
|
||||
# Get the class name of the
|
||||
class_name = message_type.__name__
|
||||
# Check if the class name is already registered.
|
||||
# Register the message type.
|
||||
self._message_types[class_name] = message_type
|
||||
|
||||
def create(self, data: Mapping[str, Any]) -> AgentEvent | ChatMessage:
|
||||
def create(self, data: Mapping[str, Any]) -> BaseAgentEvent | BaseChatMessage:
|
||||
"""Create a message from a dictionary of JSON-serializable data."""
|
||||
# Get the type of the message from the dictionary.
|
||||
message_type = data.get("type")
|
||||
if message_type is None:
|
||||
raise ValueError("Field 'type' is required in the message data to recover the message type.")
|
||||
if message_type not in self._message_types:
|
||||
raise ValueError(f"Unknown message type: {message_type}")
|
||||
if not isinstance(message_type, str):
|
||||
|
@ -400,14 +404,26 @@ class MessageFactory:
|
|||
message_class = self._message_types[message_type]
|
||||
|
||||
# Create an instance of the message class.
|
||||
assert issubclass(message_class, ChatMessage) or issubclass(message_class, AgentEvent)
|
||||
assert issubclass(message_class, BaseChatMessage) or issubclass(message_class, BaseAgentEvent)
|
||||
return message_class.load(data)
|
||||
|
||||
|
||||
# For backward compatibility
|
||||
BaseAgentEvent = AgentEvent
|
||||
BaseChatMessage = ChatMessage
|
||||
ChatMessage = Annotated[
|
||||
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
|
||||
]
|
||||
"""The union type of all built-in concrete subclasses of :class:`BaseChatMessage`.
|
||||
It does not include :class:`StructuredMessage` types."""
|
||||
|
||||
AgentEvent = Annotated[
|
||||
ToolCallRequestEvent
|
||||
| ToolCallExecutionEvent
|
||||
| MemoryQueryEvent
|
||||
| UserInputRequestedEvent
|
||||
| ModelClientStreamingChunkEvent
|
||||
| ThoughtEvent,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
"""The union type of all built-in concrete subclasses of :class:`BaseAgentEvent`."""
|
||||
|
||||
__all__ = [
|
||||
"AgentEvent",
|
||||
|
@ -415,9 +431,8 @@ __all__ = [
|
|||
"ChatMessage",
|
||||
"BaseChatMessage",
|
||||
"BaseAgentEvent",
|
||||
"AgentEvent",
|
||||
"TextChatMessage",
|
||||
"ChatMessage",
|
||||
"BaseTextChatMessage",
|
||||
"BaseChatMessage",
|
||||
"StructuredContentType",
|
||||
"StructuredMessage",
|
||||
"HandoffMessage",
|
||||
|
|
|
@ -18,8 +18,8 @@ from pydantic import BaseModel, ValidationError
|
|||
from ... import EVENT_LOGGER_NAME
|
||||
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
|
||||
from ...messages import (
|
||||
AgentEvent,
|
||||
ChatMessage,
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
MessageFactory,
|
||||
ModelClientStreamingChunkEvent,
|
||||
StopMessage,
|
||||
|
@ -50,7 +50,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
|||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
custom_message_types: List[type[AgentEvent | ChatMessage]] | None = None,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
):
|
||||
if len(participants) == 0:
|
||||
raise ValueError("At least one participant is required.")
|
||||
|
@ -90,7 +90,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
|||
self._output_topic_type = f"output_topic_{self._team_id}"
|
||||
|
||||
# The queue for collecting the output messages.
|
||||
self._output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination] = asyncio.Queue()
|
||||
self._output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination] = (
|
||||
asyncio.Queue()
|
||||
)
|
||||
|
||||
# Create a runtime for the team.
|
||||
if runtime is not None:
|
||||
|
@ -117,7 +119,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
|||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
|
@ -195,7 +197,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
|||
async def run(
|
||||
self,
|
||||
*,
|
||||
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
|
||||
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> TaskResult:
|
||||
"""Run the team and return the result. The base implementation uses
|
||||
|
@ -203,7 +205,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
|||
Once the team is stopped, the termination condition is reset.
|
||||
|
||||
Args:
|
||||
task (str | ChatMessage | Sequence[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
|
||||
task (str | BaseChatMessage | Sequence[BaseChatMessage] | None): The task to run the team with. Can be a string, a single :class:`BaseChatMessage` , or a list of :class:`BaseChatMessage`.
|
||||
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
|
||||
Setting the cancellation token potentially put the team in an inconsistent state,
|
||||
and it may not reset the termination condition.
|
||||
|
@ -297,9 +299,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
|||
async def run_stream(
|
||||
self,
|
||||
*,
|
||||
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
|
||||
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]:
|
||||
"""Run the team and produces a stream of messages and the final result
|
||||
of the type :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream. Once the
|
||||
team is stopped, the termination condition is reset.
|
||||
|
@ -311,14 +313,14 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
|||
:attr:`~autogen_agentchat.base.TaskResult.messages`.
|
||||
|
||||
Args:
|
||||
task (str | ChatMessage | Sequence[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
|
||||
task (str | BaseChatMessage | Sequence[BaseChatMessage] | None): The task to run the team with. Can be a string, a single :class:`BaseChatMessage` , or a list of :class:`BaseChatMessage`.
|
||||
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
|
||||
Setting the cancellation token potentially put the team in an inconsistent state,
|
||||
and it may not reset the termination condition.
|
||||
To gracefully stop the team, use :class:`~autogen_agentchat.conditions.ExternalTermination` instead.
|
||||
|
||||
Returns:
|
||||
stream: an :class:`~collections.abc.AsyncGenerator` that yields :class:`~autogen_agentchat.messages.AgentEvent`, :class:`~autogen_agentchat.messages.ChatMessage`, and the final result :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream.
|
||||
stream: an :class:`~collections.abc.AsyncGenerator` that yields :class:`~autogen_agentchat.messages.BaseAgentEvent`, :class:`~autogen_agentchat.messages.BaseChatMessage`, and the final result :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream.
|
||||
|
||||
Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:
|
||||
|
||||
|
@ -398,23 +400,23 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
|||
"""
|
||||
|
||||
# Create the messages list if the task is a string or a chat message.
|
||||
messages: List[ChatMessage] | None = None
|
||||
messages: List[BaseChatMessage] | None = None
|
||||
if task is None:
|
||||
pass
|
||||
elif isinstance(task, str):
|
||||
messages = [TextMessage(content=task, source="user")]
|
||||
elif isinstance(task, ChatMessage):
|
||||
elif isinstance(task, BaseChatMessage):
|
||||
messages = [task]
|
||||
elif isinstance(task, list):
|
||||
if not task:
|
||||
raise ValueError("Task list cannot be empty.")
|
||||
messages = []
|
||||
for msg in task:
|
||||
if not isinstance(msg, ChatMessage):
|
||||
raise ValueError("All messages in task list must be valid ChatMessage types")
|
||||
if not isinstance(msg, BaseChatMessage):
|
||||
raise ValueError("All messages in task list must be valid BaseChatMessage types")
|
||||
messages.append(msg)
|
||||
else:
|
||||
raise ValueError("Task must be a string, a ChatMessage, or a list of ChatMessage.")
|
||||
raise ValueError("Task must be a string, a BaseChatMessage, or a list of BaseChatMessage.")
|
||||
# Check if the messages types are registered with the message factory.
|
||||
if messages is not None:
|
||||
for msg in messages:
|
||||
|
@ -469,7 +471,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
|||
cancellation_token=cancellation_token,
|
||||
)
|
||||
# Collect the output messages in order.
|
||||
output_messages: List[AgentEvent | ChatMessage] = []
|
||||
output_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
stop_reason: str | None = None
|
||||
# Yield the messsages until the queue is empty.
|
||||
while True:
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import Any, List
|
|||
from autogen_core import DefaultTopicId, MessageContext, event, rpc
|
||||
|
||||
from ...base import TerminationCondition
|
||||
from ...messages import AgentEvent, ChatMessage, MessageFactory, StopMessage
|
||||
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, StopMessage
|
||||
from ._events import (
|
||||
GroupChatAgentResponse,
|
||||
GroupChatMessage,
|
||||
|
@ -39,7 +39,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
|||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
|
@ -67,7 +67,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
|||
name: topic_type for name, topic_type in zip(participant_names, participant_topic_types, strict=True)
|
||||
}
|
||||
self._participant_descriptions = participant_descriptions
|
||||
self._message_thread: List[AgentEvent | ChatMessage] = []
|
||||
self._message_thread: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
self._output_message_queue = output_message_queue
|
||||
self._termination_condition = termination_condition
|
||||
if max_turns is not None and max_turns <= 0:
|
||||
|
@ -141,7 +141,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
|||
@event
|
||||
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
|
||||
# Append the message to the message thread and construct the delta.
|
||||
delta: List[AgentEvent | ChatMessage] = []
|
||||
delta: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
if message.agent_response.inner_messages is not None:
|
||||
for inner_message in message.agent_response.inner_messages:
|
||||
self._message_thread.append(inner_message)
|
||||
|
@ -225,7 +225,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
"""Validate the state of the group chat given the start messages.
|
||||
This is executed when the group chat manager receives a GroupChatStart event.
|
||||
|
||||
|
@ -235,7 +235,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
|||
...
|
||||
|
||||
@abstractmethod
|
||||
async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
|
||||
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
|
||||
"""Select a speaker from the participants and return the
|
||||
topic type of the selected speaker."""
|
||||
...
|
||||
|
|
|
@ -2,7 +2,7 @@ from typing import Any, List, Mapping
|
|||
|
||||
from autogen_core import DefaultTopicId, MessageContext, event, rpc
|
||||
|
||||
from autogen_agentchat.messages import AgentEvent, ChatMessage, MessageFactory
|
||||
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, MessageFactory
|
||||
|
||||
from ...base import ChatAgent, Response
|
||||
from ...state import ChatAgentContainerState
|
||||
|
@ -46,7 +46,7 @@ class ChatAgentContainer(SequentialRoutedAgent):
|
|||
self._parent_topic_type = parent_topic_type
|
||||
self._output_topic_type = output_topic_type
|
||||
self._agent = agent
|
||||
self._message_buffer: List[ChatMessage] = []
|
||||
self._message_buffer: List[BaseChatMessage] = []
|
||||
self._message_factory = message_factory
|
||||
|
||||
@event
|
||||
|
@ -90,13 +90,13 @@ class ChatAgentContainer(SequentialRoutedAgent):
|
|||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
|
||||
def _buffer_message(self, message: ChatMessage) -> None:
|
||||
def _buffer_message(self, message: BaseChatMessage) -> None:
|
||||
if not self._message_factory.is_registered(message.__class__):
|
||||
raise ValueError(f"Message type {message.__class__} is not registered.")
|
||||
# Buffer the message.
|
||||
self._message_buffer.append(message)
|
||||
|
||||
async def _log_message(self, message: AgentEvent | ChatMessage) -> None:
|
||||
async def _log_message(self, message: BaseAgentEvent | BaseChatMessage) -> None:
|
||||
if not self._message_factory.is_registered(message.__class__):
|
||||
raise ValueError(f"Message type {message.__class__} is not registered.")
|
||||
# Log the message.
|
||||
|
@ -130,7 +130,7 @@ class ChatAgentContainer(SequentialRoutedAgent):
|
|||
self._message_buffer = []
|
||||
for message_data in container_state.message_buffer:
|
||||
message = self._message_factory.create(message_data)
|
||||
if isinstance(message, ChatMessage):
|
||||
if isinstance(message, BaseChatMessage):
|
||||
self._message_buffer.append(message)
|
||||
else:
|
||||
raise ValueError(f"Invalid message type in message buffer: {type(message)}")
|
||||
|
|
|
@ -3,13 +3,13 @@ from typing import List
|
|||
from pydantic import BaseModel
|
||||
|
||||
from ...base import Response
|
||||
from ...messages import AgentEvent, ChatMessage, StopMessage
|
||||
from ...messages import BaseAgentEvent, BaseChatMessage, StopMessage
|
||||
|
||||
|
||||
class GroupChatStart(BaseModel):
|
||||
"""A request to start a group chat."""
|
||||
|
||||
messages: List[ChatMessage] | None = None
|
||||
messages: List[BaseChatMessage] | None = None
|
||||
"""An optional list of messages to start the group chat."""
|
||||
|
||||
|
||||
|
@ -29,7 +29,7 @@ class GroupChatRequestPublish(BaseModel):
|
|||
class GroupChatMessage(BaseModel):
|
||||
"""A message from a group chat."""
|
||||
|
||||
message: AgentEvent | ChatMessage
|
||||
message: BaseAgentEvent | BaseChatMessage
|
||||
"""The message that was published."""
|
||||
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing_extensions import Self
|
|||
|
||||
from .... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
|
||||
from ....base import ChatAgent, TerminationCondition
|
||||
from ....messages import AgentEvent, ChatMessage, MessageFactory
|
||||
from ....messages import BaseAgentEvent, BaseChatMessage, MessageFactory
|
||||
from .._base_group_chat import BaseGroupChat
|
||||
from .._events import GroupChatTermination
|
||||
from ._magentic_one_orchestrator import MagenticOneOrchestrator
|
||||
|
@ -128,7 +128,7 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig])
|
|||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
|
|
|
@ -15,8 +15,8 @@ from autogen_core.models import (
|
|||
from .... import TRACE_LOGGER_NAME
|
||||
from ....base import Response, TerminationCondition
|
||||
from ....messages import (
|
||||
AgentEvent,
|
||||
ChatMessage,
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
HandoffMessage,
|
||||
MessageFactory,
|
||||
MultiModalMessage,
|
||||
|
@ -66,7 +66,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
|||
model_client: ChatCompletionClient,
|
||||
max_stalls: int,
|
||||
final_answer_prompt: str,
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
):
|
||||
super().__init__(
|
||||
|
@ -184,7 +184,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
|||
|
||||
@event
|
||||
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None: # type: ignore
|
||||
delta: List[AgentEvent | ChatMessage] = []
|
||||
delta: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
if message.agent_response.inner_messages is not None:
|
||||
for inner_message in message.agent_response.inner_messages:
|
||||
delta.append(inner_message)
|
||||
|
@ -201,7 +201,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
|||
return
|
||||
await self._orchestrate_step(ctx.cancellation_token)
|
||||
|
||||
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
pass
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
|
@ -226,7 +226,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
|||
self._n_rounds = orchestrator_state.n_rounds
|
||||
self._n_stalls = orchestrator_state.n_stalls
|
||||
|
||||
async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
|
||||
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
|
||||
"""Not used in this orchestrator, we select next speaker in _orchestrate_step."""
|
||||
return ""
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from pydantic import BaseModel
|
|||
from typing_extensions import Self
|
||||
|
||||
from ...base import ChatAgent, TerminationCondition
|
||||
from ...messages import AgentEvent, ChatMessage, MessageFactory
|
||||
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory
|
||||
from ...state import RoundRobinManagerState
|
||||
from ._base_group_chat import BaseGroupChat
|
||||
from ._base_group_chat_manager import BaseGroupChatManager
|
||||
|
@ -24,7 +24,7 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
|
|||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
|
@ -43,7 +43,7 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
|
|||
)
|
||||
self._next_speaker_index = 0
|
||||
|
||||
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
pass
|
||||
|
||||
async def reset(self) -> None:
|
||||
|
@ -67,7 +67,7 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
|
|||
self._current_turn = round_robin_state.current_turn
|
||||
self._next_speaker_index = round_robin_state.next_speaker_index
|
||||
|
||||
async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
|
||||
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
|
||||
"""Select a speaker from the participants in a round-robin fashion."""
|
||||
current_speaker_index = self._next_speaker_index
|
||||
self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_names)
|
||||
|
@ -166,7 +166,7 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
|
|||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
custom_message_types: List[type[AgentEvent | ChatMessage]] | None = None,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
participants,
|
||||
|
@ -186,7 +186,7 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
|
|||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
|
|
|
@ -13,8 +13,8 @@ from ... import TRACE_LOGGER_NAME
|
|||
from ...agents import BaseChatAgent
|
||||
from ...base import ChatAgent, TerminationCondition
|
||||
from ...messages import (
|
||||
AgentEvent,
|
||||
ChatMessage,
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
MessageFactory,
|
||||
)
|
||||
from ...state import SelectorManagerState
|
||||
|
@ -24,12 +24,12 @@ from ._events import GroupChatTermination
|
|||
|
||||
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
|
||||
|
||||
SyncSelectorFunc = Callable[[Sequence[AgentEvent | ChatMessage]], str | None]
|
||||
AsyncSelectorFunc = Callable[[Sequence[AgentEvent | ChatMessage]], Awaitable[str | None]]
|
||||
SyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None]
|
||||
AsyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | None]]
|
||||
SelectorFuncType = Union[SyncSelectorFunc | AsyncSelectorFunc]
|
||||
|
||||
SyncCandidateFunc = Callable[[Sequence[AgentEvent | ChatMessage]], List[str]]
|
||||
AsyncCandidateFunc = Callable[[Sequence[AgentEvent | ChatMessage]], Awaitable[List[str]]]
|
||||
SyncCandidateFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], List[str]]
|
||||
AsyncCandidateFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[List[str]]]
|
||||
CandidateFuncType = Union[SyncCandidateFunc | AsyncCandidateFunc]
|
||||
|
||||
|
||||
|
@ -45,7 +45,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
|||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
|
@ -78,7 +78,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
|||
self._candidate_func = candidate_func
|
||||
self._is_candidate_func_async = iscoroutinefunction(self._candidate_func)
|
||||
|
||||
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
pass
|
||||
|
||||
async def reset(self) -> None:
|
||||
|
@ -102,7 +102,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
|||
self._current_turn = selector_state.current_turn
|
||||
self._previous_speaker = selector_state.previous_speaker
|
||||
|
||||
async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
|
||||
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
|
||||
"""Selects the next speaker in a group chat using a ChatCompletion client,
|
||||
with the selector function as override if it returns a speaker name.
|
||||
|
||||
|
@ -153,7 +153,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
|||
# Construct the history of the conversation.
|
||||
history_messages: List[str] = []
|
||||
for msg in thread:
|
||||
if not isinstance(msg, ChatMessage):
|
||||
if not isinstance(msg, BaseChatMessage):
|
||||
# Only process chat messages.
|
||||
continue
|
||||
message = f"{msg.source}: {msg.to_model_text()}"
|
||||
|
@ -299,11 +299,11 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
|
|||
max_selector_attempts (int, optional): The maximum number of attempts to select a speaker using the model. Defaults to 3.
|
||||
If the model fails to select a speaker after the maximum number of attempts, the previous speaker will be used if available,
|
||||
otherwise the first participant will be used.
|
||||
selector_func (Callable[[Sequence[AgentEvent | ChatMessage]], str | None], Callable[[Sequence[AgentEvent | ChatMessage]], Awaitable[str | None]], optional): A custom selector
|
||||
selector_func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None], Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | None]], optional): A custom selector
|
||||
function that takes the conversation history and returns the name of the next speaker.
|
||||
If provided, this function will be used to override the model to select the next speaker.
|
||||
If the function returns None, the model will be used to select the next speaker.
|
||||
candidate_func (Callable[[Sequence[AgentEvent | ChatMessage]], List[str]], Callable[[Sequence[AgentEvent | ChatMessage]], Awaitable[List[str]]], optional):
|
||||
candidate_func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], List[str]], Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[List[str]]], optional):
|
||||
A custom function that takes the conversation history and returns a filtered list of candidates for the next speaker
|
||||
selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`.
|
||||
This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set.
|
||||
|
@ -378,7 +378,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
|
|||
from autogen_agentchat.teams import SelectorGroupChat
|
||||
from autogen_agentchat.conditions import TextMentionTermination
|
||||
from autogen_agentchat.ui import Console
|
||||
from autogen_agentchat.messages import AgentEvent, ChatMessage
|
||||
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
|
@ -404,7 +404,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
|
|||
system_message="Check the answer and respond with 'Correct!' or 'Incorrect!'",
|
||||
)
|
||||
|
||||
def selector_func(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:
|
||||
def selector_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> str | None:
|
||||
if len(messages) == 1 or messages[-1].to_text() == "Incorrect!":
|
||||
return "Agent1"
|
||||
if messages[-1].source == "Agent1":
|
||||
|
@ -448,7 +448,7 @@ Read the above conversation. Then select the next role from {participants} to pl
|
|||
max_selector_attempts: int = 3,
|
||||
selector_func: Optional[SelectorFuncType] = None,
|
||||
candidate_func: Optional[CandidateFuncType] = None,
|
||||
custom_message_types: List[type[AgentEvent | ChatMessage]] | None = None,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
participants,
|
||||
|
@ -477,7 +477,7 @@ Read the above conversation. Then select the next role from {participants} to pl
|
|||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
|
@ -525,7 +525,7 @@ Read the above conversation. Then select the next role from {participants} to pl
|
|||
selector_prompt=config.selector_prompt,
|
||||
allow_repeated_speaker=config.allow_repeated_speaker,
|
||||
max_selector_attempts=config.max_selector_attempts,
|
||||
# selector_func=ComponentLoader.load_component(config.selector_func, Callable[[Sequence[AgentEvent | ChatMessage]], str | None])
|
||||
# selector_func=ComponentLoader.load_component(config.selector_func, Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None])
|
||||
# if config.selector_func
|
||||
# else None,
|
||||
)
|
||||
|
|
|
@ -5,7 +5,7 @@ from autogen_core import AgentRuntime, Component, ComponentModel
|
|||
from pydantic import BaseModel
|
||||
|
||||
from ...base import ChatAgent, TerminationCondition
|
||||
from ...messages import AgentEvent, ChatMessage, HandoffMessage, MessageFactory
|
||||
from ...messages import BaseAgentEvent, BaseChatMessage, HandoffMessage, MessageFactory
|
||||
from ...state import SwarmManagerState
|
||||
from ._base_group_chat import BaseGroupChat
|
||||
from ._base_group_chat_manager import BaseGroupChatManager
|
||||
|
@ -23,7 +23,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):
|
|||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
|
@ -42,7 +42,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):
|
|||
)
|
||||
self._current_speaker = self._participant_names[0]
|
||||
|
||||
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
"""Validate the start messages for the group chat."""
|
||||
# Check if any of the start messages is a handoff message.
|
||||
if messages:
|
||||
|
@ -77,7 +77,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):
|
|||
await self._termination_condition.reset()
|
||||
self._current_speaker = self._participant_names[0]
|
||||
|
||||
async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
|
||||
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
|
||||
"""Select a speaker from the participants based on handoff message.
|
||||
Looks for the last handoff message in the thread to determine the next speaker."""
|
||||
if len(thread) == 0:
|
||||
|
@ -212,7 +212,7 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
|
|||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
custom_message_types: List[type[AgentEvent | ChatMessage]] | None = None,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
participants,
|
||||
|
@ -236,7 +236,7 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
|
|||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
|
|
|
@ -11,8 +11,8 @@ from autogen_core.models import RequestUsage
|
|||
from autogen_agentchat.agents import UserProxyAgent
|
||||
from autogen_agentchat.base import Response, TaskResult
|
||||
from autogen_agentchat.messages import (
|
||||
AgentEvent,
|
||||
ChatMessage,
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
ModelClientStreamingChunkEvent,
|
||||
MultiModalMessage,
|
||||
UserInputRequestedEvent,
|
||||
|
@ -80,7 +80,7 @@ def aprint(output: str, end: str = "\n", flush: bool = False) -> Awaitable[None]
|
|||
|
||||
|
||||
async def Console(
|
||||
stream: AsyncGenerator[AgentEvent | ChatMessage | T, None],
|
||||
stream: AsyncGenerator[BaseAgentEvent | BaseChatMessage | T, None],
|
||||
*,
|
||||
no_inline_images: bool = False,
|
||||
output_stats: bool = False,
|
||||
|
@ -97,7 +97,7 @@ async def Console(
|
|||
It will be improved in future releases.
|
||||
|
||||
Args:
|
||||
stream (AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None] | AsyncGenerator[AgentEvent | ChatMessage | Response, None]): Message stream to render.
|
||||
stream (AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None] | AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]): Message stream to render.
|
||||
This can be from :meth:`~autogen_agentchat.base.TaskRunner.run_stream` or :meth:`~autogen_agentchat.base.ChatAgent.on_messages_stream`.
|
||||
no_inline_images (bool, optional): If terminal is iTerm2 will render images inline. Use this to disable this behavior. Defaults to False.
|
||||
output_stats (bool, optional): (Experimental) If True, will output a summary of the messages and inline token usage info. Defaults to False.
|
||||
|
@ -170,7 +170,7 @@ async def Console(
|
|||
user_input_manager.notify_event_received(message.request_id)
|
||||
else:
|
||||
# Cast required for mypy to be happy
|
||||
message = cast(AgentEvent | ChatMessage, message) # type: ignore
|
||||
message = cast(BaseAgentEvent | BaseChatMessage, message) # type: ignore
|
||||
if not streaming_chunks:
|
||||
# Print message sender.
|
||||
await aprint(f"{'-' * 10} {message.source} {'-' * 10}", end="\n", flush=True)
|
||||
|
|
|
@ -7,7 +7,7 @@ from autogen_agentchat import EVENT_LOGGER_NAME
|
|||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.base import Handoff, TaskResult
|
||||
from autogen_agentchat.messages import (
|
||||
ChatMessage,
|
||||
BaseChatMessage,
|
||||
HandoffMessage,
|
||||
MemoryQueryEvent,
|
||||
ModelClientStreamingChunkEvent,
|
||||
|
@ -737,7 +737,7 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
)
|
||||
|
||||
# Create a list of chat messages
|
||||
messages: List[ChatMessage] = [
|
||||
messages: List[BaseChatMessage] = [
|
||||
TextMessage(content="Message 1", source="user"),
|
||||
TextMessage(content="Message 2", source="user"),
|
||||
]
|
||||
|
|
|
@ -2,7 +2,7 @@ import asyncio
|
|||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from typing import AsyncGenerator, List, Sequence
|
||||
from typing import Any, AsyncGenerator, List, Mapping, Sequence
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
@ -15,8 +15,8 @@ from autogen_agentchat.agents import (
|
|||
from autogen_agentchat.base import Handoff, Response, TaskResult
|
||||
from autogen_agentchat.conditions import HandoffTermination, MaxMessageTermination, TextMentionTermination
|
||||
from autogen_agentchat.messages import (
|
||||
AgentEvent,
|
||||
ChatMessage,
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
HandoffMessage,
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
|
@ -60,14 +60,14 @@ class _EchoAgent(BaseChatAgent):
|
|||
self._total_messages = 0
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return (TextMessage,)
|
||||
|
||||
@property
|
||||
def total_messages(self) -> int:
|
||||
return self._total_messages
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
if len(messages) > 0:
|
||||
assert isinstance(messages[0], TextMessage)
|
||||
self._last_message = messages[0].content
|
||||
|
@ -89,21 +89,21 @@ class _FlakyAgent(BaseChatAgent):
|
|||
self._total_messages = 0
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return (TextMessage,)
|
||||
|
||||
@property
|
||||
def total_messages(self) -> int:
|
||||
return self._total_messages
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
raise ValueError("I am a flaky agent...")
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
self._last_message = None
|
||||
|
||||
|
||||
class _UnknownMessageType(ChatMessage):
|
||||
class _UnknownMessageType(BaseChatMessage):
|
||||
content: str
|
||||
|
||||
def to_model_message(self) -> UserMessage:
|
||||
|
@ -115,16 +115,23 @@ class _UnknownMessageType(ChatMessage):
|
|||
def to_text(self) -> str:
|
||||
raise NotImplementedError("This message type is not supported.")
|
||||
|
||||
def dump(self) -> Mapping[str, Any]:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def load(cls, data: Mapping[str, Any]) -> "_UnknownMessageType":
|
||||
return cls(**data)
|
||||
|
||||
|
||||
class _UnknownMessageTypeAgent(BaseChatAgent):
|
||||
def __init__(self, name: str, description: str) -> None:
|
||||
super().__init__(name, description)
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return (_UnknownMessageType,)
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
return Response(chat_message=_UnknownMessageType(content="Unknown message type", source=self.name))
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
|
@ -138,10 +145,10 @@ class _StopAgent(_EchoAgent):
|
|||
self._stop_at = stop_at
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return (TextMessage, StopMessage)
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
self._count += 1
|
||||
if self._count < self._stop_at:
|
||||
return await super().on_messages(messages, cancellation_token)
|
||||
|
@ -162,7 +169,7 @@ class _InputTask2(BaseModel):
|
|||
data: str
|
||||
|
||||
|
||||
TaskType = str | List[ChatMessage] | ChatMessage
|
||||
TaskType = str | List[BaseChatMessage] | BaseChatMessage
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(params=["single_threaded", "embedded"]) # type: ignore
|
||||
|
@ -821,7 +828,7 @@ async def test_selector_group_chat_custom_selector(runtime: AgentRuntime | None)
|
|||
agent3 = _EchoAgent("agent3", description="echo agent 3")
|
||||
agent4 = _EchoAgent("agent4", description="echo agent 4")
|
||||
|
||||
def _select_agent(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:
|
||||
def _select_agent(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> str | None:
|
||||
if len(messages) == 0:
|
||||
return "agent1"
|
||||
elif messages[-1].source == "agent1":
|
||||
|
@ -862,7 +869,7 @@ async def test_selector_group_chat_custom_candidate_func(runtime: AgentRuntime |
|
|||
agent3 = _EchoAgent("agent3", description="echo agent 3")
|
||||
agent4 = _EchoAgent("agent4", description="echo agent 4")
|
||||
|
||||
def _candidate_func(messages: Sequence[AgentEvent | ChatMessage]) -> List[str]:
|
||||
def _candidate_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str]:
|
||||
if len(messages) == 0:
|
||||
return ["agent1"]
|
||||
elif messages[-1].source == "agent1":
|
||||
|
@ -901,10 +908,10 @@ class _HandOffAgent(BaseChatAgent):
|
|||
self._next_agent = next_agent
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return (HandoffMessage,)
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
return Response(
|
||||
chat_message=HandoffMessage(
|
||||
content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name
|
||||
|
@ -1292,7 +1299,7 @@ async def test_round_robin_group_chat_with_message_list(runtime: AgentRuntime |
|
|||
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination, runtime=runtime)
|
||||
|
||||
# Create a list of messages
|
||||
messages: List[ChatMessage] = [
|
||||
messages: List[BaseChatMessage] = [
|
||||
TextMessage(content="Message 1", source="user"),
|
||||
TextMessage(content="Message 2", source="user"),
|
||||
TextMessage(content="Message 3", source="user"),
|
||||
|
@ -1324,7 +1331,7 @@ async def test_round_robin_group_chat_with_message_list(runtime: AgentRuntime |
|
|||
index += 1
|
||||
|
||||
# Test with invalid message list
|
||||
with pytest.raises(ValueError, match="All messages in task list must be valid ChatMessage types"):
|
||||
with pytest.raises(ValueError, match="All messages in task list must be valid BaseChatMessage types"):
|
||||
await team.run(task=["not a message"]) # type: ignore[list-item, arg-type] # intentionally testing invalid input
|
||||
|
||||
# Test with empty message list
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import List, Sequence
|
|||
import pytest
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.base import TaskResult
|
||||
from autogen_agentchat.messages import AgentEvent, ChatMessage
|
||||
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage
|
||||
from autogen_agentchat.teams import SelectorGroupChat
|
||||
from autogen_agentchat.ui import Console
|
||||
from autogen_core.models import ChatCompletionClient
|
||||
|
@ -33,7 +33,7 @@ async def _test_selector_group_chat(model_client: ChatCompletionClient) -> None:
|
|||
async def _test_selector_group_chat_with_candidate_func(model_client: ChatCompletionClient) -> None:
|
||||
filtered_participants = ["developer", "tester"]
|
||||
|
||||
def dummy_candidate_func(thread: Sequence[AgentEvent | ChatMessage]) -> List[str]:
|
||||
def dummy_candidate_func(thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str]:
|
||||
# Dummy candidate function that will return
|
||||
# only return developer and reviewer
|
||||
return filtered_participants
|
||||
|
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
import pytest_asyncio
|
||||
from autogen_agentchat.agents import BaseChatAgent
|
||||
from autogen_agentchat.base import Response
|
||||
from autogen_agentchat.messages import ChatMessage, TextMessage
|
||||
from autogen_agentchat.messages import BaseChatMessage, TextMessage
|
||||
from autogen_agentchat.teams import RoundRobinGroupChat
|
||||
from autogen_core import AgentRuntime, CancellationToken, SingleThreadedAgentRuntime
|
||||
|
||||
|
@ -20,10 +20,10 @@ class TestAgent(BaseChatAgent):
|
|||
self.counter = 0
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return [TextMessage]
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
assert not self._is_paused, "Agent is paused"
|
||||
|
||||
async def _process() -> None:
|
||||
|
|
|
@ -11,7 +11,7 @@ from autogen_agentchat.agents import (
|
|||
)
|
||||
from autogen_agentchat.base import Response
|
||||
from autogen_agentchat.messages import (
|
||||
ChatMessage,
|
||||
BaseChatMessage,
|
||||
TextMessage,
|
||||
)
|
||||
from autogen_agentchat.teams import (
|
||||
|
@ -34,14 +34,14 @@ class _EchoAgent(BaseChatAgent):
|
|||
self._total_messages = 0
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return (TextMessage,)
|
||||
|
||||
@property
|
||||
def total_messages(self) -> int:
|
||||
return self._total_messages
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
if len(messages) > 0:
|
||||
assert isinstance(messages[0], TextMessage)
|
||||
self._last_message = messages[0].content
|
||||
|
|
|
@ -1,5 +1,22 @@
|
|||
import json
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from autogen_agentchat.messages import HandoffMessage, MessageFactory, StructuredMessage, TextMessage
|
||||
from autogen_agentchat.messages import (
|
||||
AgentEvent,
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
MessageFactory,
|
||||
ModelClientStreamingChunkEvent,
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
StructuredMessage,
|
||||
TextMessage,
|
||||
ToolCallExecutionEvent,
|
||||
ToolCallRequestEvent,
|
||||
)
|
||||
from autogen_core import FunctionCall
|
||||
from autogen_core.models import FunctionExecutionResult
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
@ -18,7 +35,7 @@ def test_structured_message() -> None:
|
|||
)
|
||||
|
||||
# Check that the message type is correct
|
||||
assert message.type == "StructuredMessage[TestContent]" # type: ignore
|
||||
assert message.type == "StructuredMessage[TestContent]" # type: ignore[comparison-overlap]
|
||||
|
||||
# Check that the content is of the correct type
|
||||
assert isinstance(message.content, TestContent)
|
||||
|
@ -50,7 +67,7 @@ def test_message_factory() -> None:
|
|||
assert isinstance(text_message, TextMessage)
|
||||
assert text_message.source == "test_agent"
|
||||
assert text_message.content == "Hello, world!"
|
||||
assert text_message.type == "TextMessage" # type: ignore
|
||||
assert text_message.type == "TextMessage" # type: ignore[comparison-overlap]
|
||||
|
||||
# Handoff message data
|
||||
handoff_data = {
|
||||
|
@ -66,7 +83,7 @@ def test_message_factory() -> None:
|
|||
assert handoff_message.source == "test_agent"
|
||||
assert handoff_message.content == "handoff to another agent"
|
||||
assert handoff_message.target == "target_agent"
|
||||
assert handoff_message.type == "HandoffMessage" # type: ignore
|
||||
assert handoff_message.type == "HandoffMessage" # type: ignore[comparison-overlap]
|
||||
|
||||
# Structured message data
|
||||
structured_data = {
|
||||
|
@ -86,8 +103,48 @@ def test_message_factory() -> None:
|
|||
# Create a StructuredMessage instance
|
||||
structured_message = factory.create(structured_data)
|
||||
assert isinstance(structured_message, StructuredMessage)
|
||||
assert isinstance(structured_message.content, TestContent) # type: ignore
|
||||
assert isinstance(structured_message.content, TestContent) # type: ignore[reportUnkownMemberType]
|
||||
assert structured_message.source == "test_agent"
|
||||
assert structured_message.content.field1 == "test"
|
||||
assert structured_message.content.field2 == 42
|
||||
assert structured_message.type == "StructuredMessage[TestContent]" # type: ignore
|
||||
assert structured_message.type == "StructuredMessage[TestContent]" # type: ignore[comparison-overlap]
|
||||
|
||||
|
||||
class TestContainer(BaseModel):
|
||||
chat_messages: List[ChatMessage]
|
||||
agent_events: List[AgentEvent]
|
||||
|
||||
|
||||
def test_union_types() -> None:
|
||||
# Create a few messages.
|
||||
chat_messages: List[ChatMessage] = [
|
||||
TextMessage(source="user", content="Hello!"),
|
||||
MultiModalMessage(source="user", content=["Hello!", "World!"]),
|
||||
HandoffMessage(source="user", content="handoff to another agent", target="target_agent"),
|
||||
StopMessage(source="user", content="stop"),
|
||||
]
|
||||
|
||||
# Create a few agent events.
|
||||
agent_events: List[AgentEvent] = [
|
||||
ModelClientStreamingChunkEvent(source="user", content="Hello!"),
|
||||
ToolCallRequestEvent(
|
||||
content=[
|
||||
FunctionCall(id="1", name="test_function", arguments=json.dumps({"arg1": "value1", "arg2": "value2"}))
|
||||
],
|
||||
source="user",
|
||||
),
|
||||
ToolCallExecutionEvent(
|
||||
content=[FunctionExecutionResult(call_id="1", content="result", name="test")], source="user"
|
||||
),
|
||||
]
|
||||
|
||||
# Create a container with the messages.
|
||||
container = TestContainer(chat_messages=chat_messages, agent_events=agent_events)
|
||||
|
||||
# Dump the container to JSON.
|
||||
data = container.model_dump()
|
||||
|
||||
# Load the container from JSON.
|
||||
loaded_container = TestContainer.model_validate(data)
|
||||
assert loaded_container.chat_messages == chat_messages
|
||||
assert loaded_container.agent_events == agent_events
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import Optional, Sequence
|
|||
import pytest
|
||||
from autogen_agentchat.agents import UserProxyAgent
|
||||
from autogen_agentchat.base import Response
|
||||
from autogen_agentchat.messages import ChatMessage, HandoffMessage, TextMessage
|
||||
from autogen_agentchat.messages import BaseChatMessage, HandoffMessage, TextMessage
|
||||
from autogen_core import CancellationToken
|
||||
|
||||
|
||||
|
@ -53,7 +53,7 @@ async def test_handoff_handling() -> None:
|
|||
|
||||
agent = UserProxyAgent(name="test_user", input_func=custom_input)
|
||||
|
||||
messages: Sequence[ChatMessage] = [
|
||||
messages: Sequence[BaseChatMessage] = [
|
||||
TextMessage(content="Initial message", source="assistant"),
|
||||
HandoffMessage(content="Handing off to user for confirmation", source="assistant", target="test_user"),
|
||||
]
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
"\n",
|
||||
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages`: The abstract method that defines the behavior of the agent in response to messages. This method is called when the agent is asked to provide a response in {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run`. It returns a {py:class}`~autogen_agentchat.base.Response` object.\n",
|
||||
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_reset`: The abstract method that resets the agent to its initial state. This method is called when the agent is asked to reset itself.\n",
|
||||
"- {py:attr}`~autogen_agentchat.agents.BaseChatAgent.produced_message_types`: The list of possible {py:class}`~autogen_agentchat.messages.ChatMessage` message types the agent can produce in its response.\n",
|
||||
"- {py:attr}`~autogen_agentchat.agents.BaseChatAgent.produced_message_types`: The list of possible {py:class}`~autogen_agentchat.messages.BaseChatMessage` message types the agent can produce in its response.\n",
|
||||
"\n",
|
||||
"Optionally, you can implement the the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream` method to stream messages as they are generated by the agent. If this method is not implemented, the agent\n",
|
||||
"uses the default implementation of {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream`\n",
|
||||
|
@ -53,7 +53,7 @@
|
|||
"\n",
|
||||
"from autogen_agentchat.agents import BaseChatAgent\n",
|
||||
"from autogen_agentchat.base import Response\n",
|
||||
"from autogen_agentchat.messages import AgentEvent, ChatMessage, TextMessage\n",
|
||||
"from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, TextMessage\n",
|
||||
"from autogen_core import CancellationToken\n",
|
||||
"\n",
|
||||
"\n",
|
||||
|
@ -63,10 +63,10 @@
|
|||
" self._count = count\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def produced_message_types(self) -> Sequence[type[ChatMessage]]:\n",
|
||||
" def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:\n",
|
||||
" return (TextMessage,)\n",
|
||||
"\n",
|
||||
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
|
||||
" async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:\n",
|
||||
" # Calls the on_messages_stream.\n",
|
||||
" response: Response | None = None\n",
|
||||
" async for message in self.on_messages_stream(messages, cancellation_token):\n",
|
||||
|
@ -76,9 +76,9 @@
|
|||
" return response\n",
|
||||
"\n",
|
||||
" async def on_messages_stream(\n",
|
||||
" self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken\n",
|
||||
" ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:\n",
|
||||
" inner_messages: List[AgentEvent | ChatMessage] = []\n",
|
||||
" self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken\n",
|
||||
" ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:\n",
|
||||
" inner_messages: List[BaseAgentEvent | BaseChatMessage] = []\n",
|
||||
" for i in range(self._count, 0, -1):\n",
|
||||
" msg = TextMessage(content=f\"{i}...\", source=self.name)\n",
|
||||
" inner_messages.append(msg)\n",
|
||||
|
@ -135,7 +135,7 @@
|
|||
"from autogen_agentchat.agents import BaseChatAgent\n",
|
||||
"from autogen_agentchat.base import Response\n",
|
||||
"from autogen_agentchat.conditions import MaxMessageTermination\n",
|
||||
"from autogen_agentchat.messages import ChatMessage\n",
|
||||
"from autogen_agentchat.messages import BaseChatMessage\n",
|
||||
"from autogen_agentchat.teams import SelectorGroupChat\n",
|
||||
"from autogen_agentchat.ui import Console\n",
|
||||
"from autogen_core import CancellationToken\n",
|
||||
|
@ -146,13 +146,13 @@
|
|||
" def __init__(self, name: str, description: str, operator_func: Callable[[int], int]) -> None:\n",
|
||||
" super().__init__(name, description=description)\n",
|
||||
" self._operator_func = operator_func\n",
|
||||
" self._message_history: List[ChatMessage] = []\n",
|
||||
" self._message_history: List[BaseChatMessage] = []\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def produced_message_types(self) -> Sequence[type[ChatMessage]]:\n",
|
||||
" def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:\n",
|
||||
" return (TextMessage,)\n",
|
||||
"\n",
|
||||
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
|
||||
" async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:\n",
|
||||
" # Update the message history.\n",
|
||||
" # NOTE: it is possible the messages is an empty list, which means the agent was selected previously.\n",
|
||||
" self._message_history.extend(messages)\n",
|
||||
|
@ -268,7 +268,7 @@
|
|||
" )\n",
|
||||
"\n",
|
||||
" # Run the selector group chat with a given task and stream the response.\n",
|
||||
" task: List[ChatMessage] = [\n",
|
||||
" task: List[BaseChatMessage] = [\n",
|
||||
" TextMessage(content=\"Apply the operations to turn the given number into 25.\", source=\"user\"),\n",
|
||||
" TextMessage(content=\"10\", source=\"user\"),\n",
|
||||
" ]\n",
|
||||
|
@ -319,7 +319,7 @@
|
|||
"\n",
|
||||
"from autogen_agentchat.agents import BaseChatAgent\n",
|
||||
"from autogen_agentchat.base import Response\n",
|
||||
"from autogen_agentchat.messages import AgentEvent, ChatMessage\n",
|
||||
"from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage\n",
|
||||
"from autogen_core import CancellationToken\n",
|
||||
"from autogen_core.model_context import UnboundedChatCompletionContext\n",
|
||||
"from autogen_core.models import AssistantMessage, RequestUsage, UserMessage\n",
|
||||
|
@ -344,10 +344,10 @@
|
|||
" self._model = model\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def produced_message_types(self) -> Sequence[type[ChatMessage]]:\n",
|
||||
" def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:\n",
|
||||
" return (TextMessage,)\n",
|
||||
"\n",
|
||||
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
|
||||
" async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:\n",
|
||||
" final_response = None\n",
|
||||
" async for message in self.on_messages_stream(messages, cancellation_token):\n",
|
||||
" if isinstance(message, Response):\n",
|
||||
|
@ -359,8 +359,8 @@
|
|||
" return final_response\n",
|
||||
"\n",
|
||||
" async def on_messages_stream(\n",
|
||||
" self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken\n",
|
||||
" ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:\n",
|
||||
" self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken\n",
|
||||
" ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:\n",
|
||||
" # Add messages to the model context\n",
|
||||
" for msg in messages:\n",
|
||||
" await self._model_context.add_message(msg.to_model_message())\n",
|
||||
|
@ -550,7 +550,7 @@
|
|||
"\n",
|
||||
"from autogen_agentchat.agents import BaseChatAgent\n",
|
||||
"from autogen_agentchat.base import Response\n",
|
||||
"from autogen_agentchat.messages import AgentEvent, ChatMessage\n",
|
||||
"from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage\n",
|
||||
"from autogen_core import CancellationToken, Component\n",
|
||||
"from pydantic import BaseModel\n",
|
||||
"from typing_extensions import Self\n",
|
||||
|
@ -583,10 +583,10 @@
|
|||
" self._model = model\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def produced_message_types(self) -> Sequence[type[ChatMessage]]:\n",
|
||||
" def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:\n",
|
||||
" return (TextMessage,)\n",
|
||||
"\n",
|
||||
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
|
||||
" async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:\n",
|
||||
" final_response = None\n",
|
||||
" async for message in self.on_messages_stream(messages, cancellation_token):\n",
|
||||
" if isinstance(message, Response):\n",
|
||||
|
@ -598,8 +598,8 @@
|
|||
" return final_response\n",
|
||||
"\n",
|
||||
" async def on_messages_stream(\n",
|
||||
" self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken\n",
|
||||
" ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:\n",
|
||||
" self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken\n",
|
||||
" ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:\n",
|
||||
" # Add messages to the model context\n",
|
||||
" for msg in messages:\n",
|
||||
" await self._model_context.add_message(msg.to_model_message())\n",
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -462,18 +462,18 @@ and implement the `on_messages`, `on_reset`, and `produced_message_types` method
|
|||
from typing import Sequence
|
||||
from autogen_core import CancellationToken
|
||||
from autogen_agentchat.agents import BaseChatAgent
|
||||
from autogen_agentchat.messages import TextMessage, ChatMessage
|
||||
from autogen_agentchat.messages import TextMessage, BaseChatMessage
|
||||
from autogen_agentchat.base import Response
|
||||
|
||||
class CustomAgent(BaseChatAgent):
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
return Response(chat_message=TextMessage(content="Custom reply", source=self.name))
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return (TextMessage,)
|
||||
```
|
||||
|
||||
|
@ -742,8 +742,8 @@ You can use the following conversion functions to convert between a v0.4 message
|
|||
from typing import Any, Dict, List, Literal
|
||||
|
||||
from autogen_agentchat.messages import (
|
||||
AgentEvent,
|
||||
ChatMessage,
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
HandoffMessage,
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
|
@ -757,14 +757,14 @@ from autogen_core.models import FunctionExecutionResult
|
|||
|
||||
|
||||
def convert_to_v02_message(
|
||||
message: AgentEvent | ChatMessage,
|
||||
message: BaseAgentEvent | BaseChatMessage,
|
||||
role: Literal["assistant", "user", "tool"],
|
||||
image_detail: Literal["auto", "high", "low"] = "auto",
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert a v0.4 AgentChat message to a v0.2 message.
|
||||
|
||||
Args:
|
||||
message (AgentEvent | ChatMessage): The message to convert.
|
||||
message (BaseAgentEvent | BaseChatMessage): The message to convert.
|
||||
role (Literal["assistant", "user", "tool"]): The role of the message.
|
||||
image_detail (Literal["auto", "high", "low"], optional): The detail level of image content in multi-modal message. Defaults to "auto".
|
||||
|
||||
|
@ -810,7 +810,7 @@ def convert_to_v02_message(
|
|||
return v02_message
|
||||
|
||||
|
||||
def convert_to_v04_message(message: Dict[str, Any]) -> AgentEvent | ChatMessage:
|
||||
def convert_to_v04_message(message: Dict[str, Any]) -> BaseAgentEvent | BaseChatMessage:
|
||||
"""Convert a v0.2 message to a v0.4 AgentChat message."""
|
||||
if "tool_calls" in message:
|
||||
tool_calls: List[FunctionCall] = []
|
||||
|
@ -1065,7 +1065,7 @@ import asyncio
|
|||
from typing import Sequence
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.conditions import MaxMessageTermination, TextMentionTermination
|
||||
from autogen_agentchat.messages import AgentEvent, ChatMessage
|
||||
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage
|
||||
from autogen_agentchat.teams import SelectorGroupChat
|
||||
from autogen_agentchat.ui import Console
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
|
@ -1141,7 +1141,7 @@ def create_team(model_client : OpenAIChatCompletionClient) -> SelectorGroupChat:
|
|||
|
||||
# The selector function is a function that takes the current message thread of the group chat
|
||||
# and returns the next speaker's name. If None is returned, the LLM-based selection method will be used.
|
||||
def selector_func(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:
|
||||
def selector_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> str | None:
|
||||
if messages[-1].source != planning_agent.name:
|
||||
return planning_agent.name # Always return to the planning agent after the other agents have spoken.
|
||||
return None
|
||||
|
@ -1190,12 +1190,12 @@ from typing import Sequence
|
|||
from autogen_core import CancellationToken
|
||||
from autogen_agentchat.agents import BaseChatAgent
|
||||
from autogen_agentchat.teams import RoundRobinGroupChat
|
||||
from autogen_agentchat.messages import TextMessage, ChatMessage
|
||||
from autogen_agentchat.messages import TextMessage, BaseChatMessage
|
||||
from autogen_agentchat.base import Response
|
||||
|
||||
class CountingAgent(BaseChatAgent):
|
||||
"""An agent that returns a new number by adding 1 to the last number in the input messages."""
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
if len(messages) == 0:
|
||||
last_number = 0 # Start from 0 if no messages are given.
|
||||
else:
|
||||
|
@ -1207,7 +1207,7 @@ class CountingAgent(BaseChatAgent):
|
|||
pass
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return (TextMessage,)
|
||||
|
||||
class NestedCountingAgent(BaseChatAgent):
|
||||
|
@ -1217,7 +1217,7 @@ class NestedCountingAgent(BaseChatAgent):
|
|||
super().__init__(name, description="An agent that counts numbers.")
|
||||
self._counting_team = counting_team
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
# Run the inner team with the given messages and returns the last message produced by the team.
|
||||
result = await self._counting_team.run(task=messages, cancellation_token=cancellation_token)
|
||||
# To stream the inner messages, implement `on_messages_stream` and use that to implement `on_messages`.
|
||||
|
@ -1229,7 +1229,7 @@ class NestedCountingAgent(BaseChatAgent):
|
|||
await self._counting_team.reset()
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return (TextMessage,)
|
||||
|
||||
async def main() -> None:
|
||||
|
|
|
@ -61,7 +61,7 @@
|
|||
"\n",
|
||||
"from autogen_agentchat.agents import AssistantAgent, UserProxyAgent\n",
|
||||
"from autogen_agentchat.conditions import MaxMessageTermination, TextMentionTermination\n",
|
||||
"from autogen_agentchat.messages import AgentEvent, ChatMessage\n",
|
||||
"from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage\n",
|
||||
"from autogen_agentchat.teams import SelectorGroupChat\n",
|
||||
"from autogen_agentchat.ui import Console\n",
|
||||
"from autogen_ext.models.openai import OpenAIChatCompletionClient"
|
||||
|
@ -511,7 +511,7 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"def selector_func(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:\n",
|
||||
"def selector_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> str | None:\n",
|
||||
" if messages[-1].source != planning_agent.name:\n",
|
||||
" return planning_agent.name\n",
|
||||
" return None\n",
|
||||
|
@ -655,7 +655,7 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"def candidate_func(messages: Sequence[AgentEvent | ChatMessage]) -> List[str]:\n",
|
||||
"def candidate_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str]:\n",
|
||||
" # keep planning_agent first one to plan out the tasks\n",
|
||||
" if messages[-1].source == \"user\":\n",
|
||||
" return [planning_agent.name]\n",
|
||||
|
@ -813,7 +813,7 @@
|
|||
"user_proxy_agent = UserProxyAgent(\"UserProxyAgent\", description=\"A proxy for the user to approve or disapprove tasks.\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def selector_func_with_user_proxy(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:\n",
|
||||
"def selector_func_with_user_proxy(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> str | None:\n",
|
||||
" if messages[-1].source != planning_agent.name and messages[-1].source != user_proxy_agent.name:\n",
|
||||
" # Planning agent should be the first to engage when given a new task, or check progress.\n",
|
||||
" return planning_agent.name\n",
|
||||
|
@ -1018,7 +1018,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
"version": "3.12.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@ -11,8 +11,8 @@
|
|||
"\n",
|
||||
"- {py:attr}`~autogen_agentchat.agents.BaseChatAgent.name`: The unique name of the agent.\n",
|
||||
"- {py:attr}`~autogen_agentchat.agents.BaseChatAgent.description`: The description of the agent in text.\n",
|
||||
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages`: Send the agent a sequence of {py:class}`~autogen_agentchat.messages.ChatMessage` and get a {py:class}`~autogen_agentchat.base.Response`. **It is important to note that agents are expected to be stateful and this method is expected to be called with new messages, not the complete history**.\n",
|
||||
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream`: Same as {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages` but returns an iterator of {py:class}`~autogen_agentchat.messages.AgentEvent` or {py:class}`~autogen_agentchat.messages.ChatMessage` followed by a {py:class}`~autogen_agentchat.base.Response` as the last item.\n",
|
||||
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages`: Send the agent a sequence of {py:class}`~autogen_agentchat.messages.BaseChatMessage` and get a {py:class}`~autogen_agentchat.base.Response`. **It is important to note that agents are expected to be stateful and this method is expected to be called with new messages, not the complete history**.\n",
|
||||
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream`: Same as {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages` but returns an iterator of {py:class}`~autogen_agentchat.messages.BaseAgentEvent` or {py:class}`~autogen_agentchat.messages.BaseChatMessage` followed by a {py:class}`~autogen_agentchat.base.Response` as the last item.\n",
|
||||
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_reset`: Reset the agent to its initial state.\n",
|
||||
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run` and {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run_stream`: convenience methods that call {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages` and {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream` respectively but offer the same interface as [Teams](./teams.ipynb).\n",
|
||||
"\n",
|
||||
|
@ -840,7 +840,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
"version": "3.12.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
|
@ -3,12 +3,11 @@ import logging
|
|||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypeVar, cast, runtime_checkable
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import jsonref
|
||||
from opentelemetry.trace import get_tracer
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import NotRequired
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from .. import EVENT_LOGGER_NAME, CancellationToken
|
||||
from .._component_config import ComponentBase
|
||||
|
|
|
@ -6,7 +6,7 @@ from typing import List, Sequence, Tuple
|
|||
from autogen_agentchat.agents import BaseChatAgent
|
||||
from autogen_agentchat.base import Response
|
||||
from autogen_agentchat.messages import (
|
||||
ChatMessage,
|
||||
BaseChatMessage,
|
||||
TextMessage,
|
||||
)
|
||||
from autogen_agentchat.utils import remove_images
|
||||
|
@ -84,10 +84,10 @@ class FileSurfer(BaseChatAgent, Component[FileSurferConfig]):
|
|||
self._browser = MarkdownFileBrowser(viewport_size=1024 * 5, base_path=base_path)
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return (TextMessage,)
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
for chat_message in messages:
|
||||
self._chat_history.append(chat_message.to_model_message())
|
||||
try:
|
||||
|
|
|
@ -24,8 +24,8 @@ from autogen_agentchat import EVENT_LOGGER_NAME
|
|||
from autogen_agentchat.agents import BaseChatAgent
|
||||
from autogen_agentchat.base import Response
|
||||
from autogen_agentchat.messages import (
|
||||
AgentEvent,
|
||||
ChatMessage,
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
TextMessage,
|
||||
ToolCallExecutionEvent,
|
||||
ToolCallRequestEvent,
|
||||
|
@ -353,7 +353,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||
self._initial_message_ids = initial_message_ids
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
"""The types of messages that the assistant agent produces."""
|
||||
return (TextMessage,)
|
||||
|
||||
|
@ -392,7 +392,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||
result = await tool.run_json(arguments, cancellation_token)
|
||||
return tool.return_value_as_string(result)
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
"""Handle incoming messages and return a response."""
|
||||
|
||||
async for message in self.on_messages_stream(messages, cancellation_token):
|
||||
|
@ -401,8 +401,8 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||
raise AssertionError("The stream should have returned the final result.")
|
||||
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
"""Handle incoming messages and return a response."""
|
||||
await self._ensure_initialized()
|
||||
|
||||
|
@ -411,7 +411,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||
await self.handle_incoming_message(message, cancellation_token)
|
||||
|
||||
# Inner messages for tool calls
|
||||
inner_messages: List[AgentEvent | ChatMessage] = []
|
||||
inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
|
||||
# Create and start a run
|
||||
run: Run = await cancellation_token.link_future(
|
||||
|
@ -518,7 +518,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||
chat_message = TextMessage(source=self.name, content=text_content[0].text.value)
|
||||
yield Response(chat_message=chat_message, inner_messages=inner_messages)
|
||||
|
||||
async def handle_incoming_message(self, message: ChatMessage, cancellation_token: CancellationToken) -> None:
|
||||
async def handle_incoming_message(self, message: BaseChatMessage, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle regular text messages by adding them to the thread."""
|
||||
content: str | List[MessageContentPartParam] | None = None
|
||||
llm_message = message.to_model_message()
|
||||
|
|
|
@ -24,7 +24,7 @@ import aiofiles
|
|||
import PIL.Image
|
||||
from autogen_agentchat.agents import BaseChatAgent
|
||||
from autogen_agentchat.base import Response
|
||||
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, TextMessage
|
||||
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, MultiModalMessage, TextMessage
|
||||
from autogen_agentchat.utils import content_to_str, remove_images
|
||||
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, Component, ComponentModel, FunctionCall
|
||||
from autogen_core import Image as AGImage
|
||||
|
@ -385,7 +385,7 @@ class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]):
|
|||
)
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return (MultiModalMessage,)
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
|
@ -422,19 +422,19 @@ class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]):
|
|||
)
|
||||
)
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async for message in self.on_messages_stream(messages, cancellation_token):
|
||||
if isinstance(message, Response):
|
||||
return message
|
||||
raise AssertionError("The stream should have returned the final result.")
|
||||
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
for chat_message in messages:
|
||||
self._chat_history.append(chat_message.to_model_message())
|
||||
|
||||
self.inner_messages: List[AgentEvent | ChatMessage] = []
|
||||
self.inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
self.model_usage: List[RequestUsage] = []
|
||||
try:
|
||||
content = await self._generate_reply(cancellation_token=cancellation_token)
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, List, Sequence, Tuple, TypedDict
|
|||
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.base import TaskResult
|
||||
from autogen_agentchat.messages import AgentEvent, ChatMessage, TextMessage
|
||||
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, TextMessage
|
||||
from autogen_core.models import (
|
||||
ChatCompletionClient,
|
||||
LLMMessage,
|
||||
|
@ -190,8 +190,8 @@ In responding to every user message, you follow the same multi-step process give
|
|||
|
||||
# Get the agent's response to the task.
|
||||
task_result: TaskResult = await assistant_agent.run(task=TextMessage(content=task, source="User"))
|
||||
messages: Sequence[AgentEvent | ChatMessage] = task_result.messages
|
||||
message: AgentEvent | ChatMessage = messages[-1]
|
||||
messages: Sequence[BaseAgentEvent | BaseChatMessage] = task_result.messages
|
||||
message: BaseAgentEvent | BaseChatMessage = messages[-1]
|
||||
response_str = message.to_text()
|
||||
|
||||
# Log the model call
|
||||
|
|
|
@ -5,7 +5,7 @@ import shutil
|
|||
from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict
|
||||
|
||||
from autogen_agentchat.base import TaskResult
|
||||
from autogen_agentchat.messages import AgentEvent, ChatMessage
|
||||
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage
|
||||
from autogen_core import Image
|
||||
from autogen_core.models import (
|
||||
AssistantMessage,
|
||||
|
@ -343,7 +343,7 @@ class PageLogger:
|
|||
if self.level > self.levels["INFO"]:
|
||||
return None
|
||||
|
||||
messages: Sequence[AgentEvent | ChatMessage] = task_result.messages
|
||||
messages: Sequence[BaseAgentEvent | BaseChatMessage] = task_result.messages
|
||||
message = messages[-1]
|
||||
response_str = message.to_text()
|
||||
if not isinstance(response_str, str):
|
||||
|
|
|
@ -14,8 +14,8 @@ from typing import (
|
|||
|
||||
from autogen_agentchat.base import Response, TaskResult
|
||||
from autogen_agentchat.messages import (
|
||||
AgentEvent,
|
||||
ChatMessage,
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
ModelClientStreamingChunkEvent,
|
||||
MultiModalMessage,
|
||||
UserInputRequestedEvent,
|
||||
|
@ -56,7 +56,7 @@ def aprint(output: str, end: str = "\n") -> Awaitable[None]:
|
|||
return asyncio.to_thread(print, output, end=end)
|
||||
|
||||
|
||||
def _extract_message_content(message: AgentEvent | ChatMessage) -> Tuple[List[str], List[Image]]:
|
||||
def _extract_message_content(message: BaseAgentEvent | BaseChatMessage) -> Tuple[List[str], List[Image]]:
|
||||
if isinstance(message, MultiModalMessage):
|
||||
text_parts = [item for item in message.content if isinstance(item, str)]
|
||||
image_parts = [item for item in message.content if isinstance(item, Image)]
|
||||
|
@ -100,7 +100,7 @@ async def _aprint_message_content(
|
|||
|
||||
|
||||
async def RichConsole(
|
||||
stream: AsyncGenerator[AgentEvent | ChatMessage | T, None],
|
||||
stream: AsyncGenerator[BaseAgentEvent | BaseChatMessage | T, None],
|
||||
*,
|
||||
no_inline_images: bool = False,
|
||||
output_stats: bool = False,
|
||||
|
@ -117,7 +117,7 @@ async def RichConsole(
|
|||
It will be improved in future releases.
|
||||
|
||||
Args:
|
||||
stream (AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None] | AsyncGenerator[AgentEvent | ChatMessage | Response, None]): Message stream to render.
|
||||
stream (AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None] | AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]): Message stream to render.
|
||||
This can be from :meth:`~autogen_agentchat.base.TaskRunner.run_stream` or :meth:`~autogen_agentchat.base.ChatAgent.on_messages_stream`.
|
||||
no_inline_images (bool, optional): If terminal is iTerm2 will render images inline. Use this to disable this behavior. Defaults to False.
|
||||
output_stats (bool, optional): (Experimental) If True, will output a summary of the messages and inline token usage info. Defaults to False.
|
||||
|
@ -191,7 +191,7 @@ async def RichConsole(
|
|||
pass
|
||||
else:
|
||||
# Cast required for mypy to be happy
|
||||
message = cast(AgentEvent | ChatMessage, message) # type: ignore
|
||||
message = cast(BaseAgentEvent | BaseChatMessage, message) # type: ignore
|
||||
|
||||
text_parts, image_parts = _extract_message_content(message)
|
||||
# Add usage stats if needed
|
||||
|
|
|
@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
|
||||
import aiofiles
|
||||
import pytest
|
||||
from autogen_agentchat.messages import ChatMessage, TextMessage, ToolCallRequestEvent
|
||||
from autogen_agentchat.messages import BaseChatMessage, TextMessage, ToolCallRequestEvent
|
||||
from autogen_core import CancellationToken
|
||||
from autogen_core.tools._base import BaseTool, Tool
|
||||
from autogen_ext.agents.openai import OpenAIAssistantAgent
|
||||
|
@ -81,7 +81,7 @@ class FakeMessage:
|
|||
|
||||
|
||||
class FakeCursorPage:
|
||||
def __init__(self, data: List[ChatMessage | FakeMessage]) -> None:
|
||||
def __init__(self, data: List[BaseChatMessage | FakeMessage]) -> None:
|
||||
self.data = data
|
||||
|
||||
def has_next_page(self) -> bool:
|
||||
|
|
|
@ -10,7 +10,7 @@ import aiofiles
|
|||
import yaml
|
||||
from autogen_agentchat.agents import UserProxyAgent
|
||||
from autogen_agentchat.base import TaskResult, Team
|
||||
from autogen_agentchat.messages import AgentEvent, ChatMessage
|
||||
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage
|
||||
from autogen_agentchat.teams import BaseGroupChat
|
||||
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, Component, ComponentModel
|
||||
from autogen_core.logging import LLMCallEvent
|
||||
|
@ -102,7 +102,7 @@ class TeamManager:
|
|||
input_func: Optional[Callable] = None,
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
env_vars: Optional[List[EnvironmentVariable]] = None,
|
||||
) -> AsyncGenerator[Union[AgentEvent | ChatMessage | LLMCallEvent, ChatMessage, TeamResult], None]:
|
||||
) -> AsyncGenerator[Union[BaseAgentEvent | BaseChatMessage | LLMCallEvent, BaseChatMessage, TeamResult], None]:
|
||||
"""Stream team execution results"""
|
||||
start_time = time.time()
|
||||
team = None
|
||||
|
|
|
@ -6,8 +6,8 @@ from typing import Any, Callable, Dict, Optional, Union
|
|||
|
||||
from autogen_agentchat.base._task import TaskResult
|
||||
from autogen_agentchat.messages import (
|
||||
AgentEvent,
|
||||
ChatMessage,
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
HandoffMessage,
|
||||
ModelClientStreamingChunkEvent,
|
||||
MultiModalMessage,
|
||||
|
@ -160,7 +160,9 @@ class WebSocketManager:
|
|||
finally:
|
||||
self._cancellation_tokens.pop(run_id, None)
|
||||
|
||||
async def _save_message(self, run_id: int, message: Union[AgentEvent | ChatMessage, ChatMessage]) -> None:
|
||||
async def _save_message(
|
||||
self, run_id: int, message: Union[BaseAgentEvent | BaseChatMessage, BaseChatMessage]
|
||||
) -> None:
|
||||
"""Save a message to the database"""
|
||||
|
||||
run = await self._get_run(run_id)
|
||||
|
|
Loading…
Reference in New Issue