Stop run when an error occured in a group chat (#6141)

Resolves #5851

* Added GroupChatError event type and terminate a run when an error
occurs in either a participant or the group chat manager
* Raise a RuntimeError from the error message within the group chat run
This commit is contained in:
Eric Zhu 2025-04-01 13:17:50 -07:00 committed by GitHub
parent 86237c9fdf
commit aec04e76ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 220 additions and 88 deletions

View File

@ -1,5 +1,4 @@
import asyncio
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, Dict, List, Mapping, Sequence
@ -15,7 +14,6 @@ from autogen_core import (
)
from pydantic import BaseModel, ValidationError
from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import (
BaseAgentEvent,
@ -27,11 +25,16 @@ from ...messages import (
)
from ...state import TeamState
from ._chat_agent_container import ChatAgentContainer
from ._events import GroupChatPause, GroupChatReset, GroupChatResume, GroupChatStart, GroupChatTermination
from ._events import (
GroupChatPause,
GroupChatReset,
GroupChatResume,
GroupChatStart,
GroupChatTermination,
SerializableException,
)
from ._sequential_routed_agent import SequentialRoutedAgent
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
"""The base class for group chat teams.
@ -447,13 +450,26 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
try:
# This will propagate any exceptions raised.
await self._runtime.stop_when_idle()
finally:
# Stop the consumption of messages and end the stream.
# NOTE: we also need to put a GroupChatTermination event here because when the group chat
# has an exception, the group chat manager may not be able to put a GroupChatTermination event in the queue.
# Put a termination message in the queue to indicate that the group chat is stopped for whatever reason
# but not due to an exception.
await self._output_message_queue.put(
GroupChatTermination(
message=StopMessage(content="Exception occurred.", source=self._group_chat_manager_name)
message=StopMessage(
content="The group chat is stopped.", source=self._group_chat_manager_name
)
)
)
except Exception as e:
# Stop the consumption of messages and end the stream.
# NOTE: we also need to put a GroupChatTermination event here because when the runtime
# has an exception, the group chat manager may not be able to put a GroupChatTermination event in the queue.
# This may not be necessary if the group chat manager is able to handle the exception and put the event in the queue.
await self._output_message_queue.put(
GroupChatTermination(
message=StopMessage(
content="An exception occurred in the runtime.", source=self._group_chat_manager_name
),
error=SerializableException.from_exception(e),
)
)
@ -481,11 +497,10 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
# Wait for the next message, this will raise an exception if the task is cancelled.
message = await message_future
if isinstance(message, GroupChatTermination):
# If the message is None, it means the group chat has terminated.
# TODO: how do we handle termination when the runtime is not embedded
# and there is an exception in the group chat?
# The group chat manager may not be able to put a GroupChatTermination event in the queue,
# and this loop will never end.
# If the message contains an error, we need to raise it here.
# This will stop the team and propagate the error.
if message.error is not None:
raise RuntimeError(str(message.error))
stop_reason = message.message.content
break
yield message

View File

@ -8,6 +8,7 @@ from ...base import TerminationCondition
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, StopMessage
from ._events import (
GroupChatAgentResponse,
GroupChatError,
GroupChatMessage,
GroupChatPause,
GroupChatRequestPublish,
@ -15,6 +16,7 @@ from ._events import (
GroupChatResume,
GroupChatStart,
GroupChatTermination,
SerializableException,
)
from ._sequential_routed_agent import SequentialRoutedAgent
@ -140,58 +142,65 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
@event
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
# Append the message to the message thread and construct the delta.
delta: List[BaseAgentEvent | BaseChatMessage] = []
if message.agent_response.inner_messages is not None:
for inner_message in message.agent_response.inner_messages:
self._message_thread.append(inner_message)
delta.append(inner_message)
self._message_thread.append(message.agent_response.chat_message)
delta.append(message.agent_response.chat_message)
try:
# Append the message to the message thread and construct the delta.
delta: List[BaseAgentEvent | BaseChatMessage] = []
if message.agent_response.inner_messages is not None:
for inner_message in message.agent_response.inner_messages:
self._message_thread.append(inner_message)
delta.append(inner_message)
self._message_thread.append(message.agent_response.chat_message)
delta.append(message.agent_response.chat_message)
# Check if the conversation should be terminated.
if self._termination_condition is not None:
stop_message = await self._termination_condition(delta)
if stop_message is not None:
# Reset the termination conditions and turn count.
await self._termination_condition.reset()
self._current_turn = 0
# Signal termination to the caller of the team.
await self._signal_termination(stop_message)
# Stop the group chat.
return
# Increment the turn count.
self._current_turn += 1
# Check if the maximum number of turns has been reached.
if self._max_turns is not None:
if self._current_turn >= self._max_turns:
stop_message = StopMessage(
content=f"Maximum number of turns {self._max_turns} reached.",
source=self._name,
)
# Reset the termination conditions and turn count.
if self._termination_condition is not None:
# Check if the conversation should be terminated.
if self._termination_condition is not None:
stop_message = await self._termination_condition(delta)
if stop_message is not None:
# Reset the termination conditions and turn count.
await self._termination_condition.reset()
self._current_turn = 0
# Signal termination to the caller of the team.
await self._signal_termination(stop_message)
# Stop the group chat.
return
self._current_turn = 0
# Signal termination to the caller of the team.
await self._signal_termination(stop_message)
# Stop the group chat.
return
# Select a speaker to continue the conversation.
speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
# Link the select speaker future to the cancellation token.
ctx.cancellation_token.link_future(speaker_name_future)
speaker_name = await speaker_name_future
if speaker_name not in self._participant_name_to_topic_type:
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
await self.publish_message(
GroupChatRequestPublish(),
topic_id=DefaultTopicId(type=speaker_topic_type),
cancellation_token=ctx.cancellation_token,
)
# Increment the turn count.
self._current_turn += 1
# Check if the maximum number of turns has been reached.
if self._max_turns is not None:
if self._current_turn >= self._max_turns:
stop_message = StopMessage(
content=f"Maximum number of turns {self._max_turns} reached.",
source=self._name,
)
# Reset the termination conditions and turn count.
if self._termination_condition is not None:
await self._termination_condition.reset()
self._current_turn = 0
# Signal termination to the caller of the team.
await self._signal_termination(stop_message)
# Stop the group chat.
return
# Select a speaker to continue the conversation.
speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
# Link the select speaker future to the cancellation token.
ctx.cancellation_token.link_future(speaker_name_future)
speaker_name = await speaker_name_future
if speaker_name not in self._participant_name_to_topic_type:
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
await self.publish_message(
GroupChatRequestPublish(),
topic_id=DefaultTopicId(type=speaker_topic_type),
cancellation_token=ctx.cancellation_token,
)
except Exception as e:
# Handle the exception and signal termination with an error.
error = SerializableException.from_exception(e)
await self._signal_termination_with_error(error)
# Raise the exception to the runtime.
raise
async def _signal_termination(self, message: StopMessage) -> None:
termination_event = GroupChatTermination(message=message)
@ -203,11 +212,28 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
# Put the termination event in the output message queue.
await self._output_message_queue.put(termination_event)
async def _signal_termination_with_error(self, error: SerializableException) -> None:
termination_event = GroupChatTermination(
message=StopMessage(content="An error occurred in the group chat.", source=self._name), error=error
)
# Log the termination event.
await self.publish_message(
termination_event,
topic_id=DefaultTopicId(type=self._output_topic_type),
)
# Put the termination event in the output message queue.
await self._output_message_queue.put(termination_event)
@event
async def handle_group_chat_message(self, message: GroupChatMessage, ctx: MessageContext) -> None:
"""Handle a group chat message by appending the content to its output message queue."""
await self._output_message_queue.put(message.message)
@event
async def handle_group_chat_error(self, message: GroupChatError, ctx: MessageContext) -> None:
"""Handle a group chat error by logging the error and signaling termination."""
await self._signal_termination_with_error(message.error)
@rpc
async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None:
"""Reset the group chat manager. Calling :meth:`reset` to reset the group chat manager

View File

@ -8,12 +8,14 @@ from ...base import ChatAgent, Response
from ...state import ChatAgentContainerState
from ._events import (
GroupChatAgentResponse,
GroupChatError,
GroupChatMessage,
GroupChatPause,
GroupChatRequestPublish,
GroupChatReset,
GroupChatResume,
GroupChatStart,
SerializableException,
)
from ._sequential_routed_agent import SequentialRoutedAgent
@ -71,24 +73,36 @@ class ChatAgentContainer(SequentialRoutedAgent):
async def handle_request(self, message: GroupChatRequestPublish, ctx: MessageContext) -> None:
"""Handle a content request event by passing the messages in the buffer
to the delegate agent and publish the response."""
# Pass the messages in the buffer to the delegate agent.
response: Response | None = None
async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token):
if isinstance(msg, Response):
await self._log_message(msg.chat_message)
response = msg
else:
await self._log_message(msg)
if response is None:
raise ValueError("The agent did not produce a final response. Check the agent's on_messages_stream method.")
# Publish the response to the group chat.
self._message_buffer.clear()
await self.publish_message(
GroupChatAgentResponse(agent_response=response),
topic_id=DefaultTopicId(type=self._parent_topic_type),
cancellation_token=ctx.cancellation_token,
)
try:
# Pass the messages in the buffer to the delegate agent.
response: Response | None = None
async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token):
if isinstance(msg, Response):
await self._log_message(msg.chat_message)
response = msg
else:
await self._log_message(msg)
if response is None:
raise ValueError(
"The agent did not produce a final response. Check the agent's on_messages_stream method."
)
# Publish the response to the group chat.
self._message_buffer.clear()
await self.publish_message(
GroupChatAgentResponse(agent_response=response),
topic_id=DefaultTopicId(type=self._parent_topic_type),
cancellation_token=ctx.cancellation_token,
)
except Exception as e:
# Publish the error to the group chat.
error_message = SerializableException.from_exception(e)
await self.publish_message(
GroupChatError(error=error_message),
topic_id=DefaultTopicId(type=self._parent_topic_type),
cancellation_token=ctx.cancellation_token,
)
# Raise the error to the runtime.
raise
def _buffer_message(self, message: BaseChatMessage) -> None:
if not self._message_factory.is_registered(message.__class__):

View File

@ -1,3 +1,4 @@
import traceback
from typing import List
from pydantic import BaseModel
@ -6,6 +7,34 @@ from ...base import Response
from ...messages import BaseAgentEvent, BaseChatMessage, StopMessage
class SerializableException(BaseModel):
"""A serializable exception."""
error_type: str
"""The type of error that occurred."""
error_message: str
"""The error message that describes the error."""
traceback: str | None = None
"""The traceback of the error, if available."""
@classmethod
def from_exception(cls, exc: Exception) -> "SerializableException":
"""Create a GroupChatError from an exception."""
return cls(
error_type=type(exc).__name__,
error_message=str(exc),
traceback="\n".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
)
def __str__(self) -> str:
"""Return a string representation of the error, including the traceback if available."""
if self.traceback:
return f"{self.error_type}: {self.error_message}\nTraceback:\n{self.traceback}"
return f"{self.error_type}: {self.error_message}"
class GroupChatStart(BaseModel):
"""A request to start a group chat."""
@ -39,6 +68,9 @@ class GroupChatTermination(BaseModel):
message: StopMessage
"""The stop message that indicates the reason of termination."""
error: SerializableException | None = None
"""The error that occurred, if any."""
class GroupChatReset(BaseModel):
"""A request to reset the agents in the group chat."""
@ -56,3 +88,10 @@ class GroupChatResume(BaseModel):
"""A request to resume the group chat."""
...
class GroupChatError(BaseModel):
"""A message indicating that an error occurred in the group chat."""
error: SerializableException
"""The error that occurred."""

View File

@ -12,7 +12,7 @@ from autogen_agentchat.agents import (
BaseChatAgent,
CodeExecutorAgent,
)
from autogen_agentchat.base import Handoff, Response, TaskResult
from autogen_agentchat.base import Handoff, Response, TaskResult, TerminationCondition
from autogen_agentchat.conditions import HandoffTermination, MaxMessageTermination, TextMentionTermination
from autogen_agentchat.messages import (
BaseAgentEvent,
@ -103,6 +103,26 @@ class _FlakyAgent(BaseChatAgent):
self._last_message = None
class _FlakyTermination(TerminationCondition):
def __init__(self, raise_on_count: int) -> None:
self._raise_on_count = raise_on_count
self._count = 0
@property
def terminated(self) -> bool:
"""Check if the termination condition has been reached"""
return False
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
self._count += 1
if self._count == self._raise_on_count:
raise ValueError("I am a flaky termination...")
return None
async def reset(self) -> None:
pass
class _UnknownMessageType(BaseChatMessage):
content: str
@ -285,7 +305,7 @@ async def test_round_robin_group_chat_unknown_agent_message_type() -> None:
agent2 = _UnknownMessageTypeAgent("agent2", "I am an unknown message type agent")
termination = TextMentionTermination("TERMINATE")
team1 = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination)
with pytest.raises(ValueError, match="Message type .*UnknownMessageType.* not registered"):
with pytest.raises(RuntimeError, match=".* Message type .*UnknownMessageType.* not registered"):
await team1.run(task=TextMessage(content="Write a program that prints 'Hello, world!'", source="user"))
@ -457,10 +477,8 @@ async def test_round_robin_group_chat_with_resume_and_reset(runtime: AgentRuntim
assert result.stop_reason is not None
# TODO: add runtime fixture for testing with custom runtime once the issue regarding
# hanging on exception is resolved.
@pytest.mark.asyncio
async def test_round_robin_group_chat_with_exception_raised() -> None:
async def test_round_robin_group_chat_with_exception_raised_from_agent(runtime: AgentRuntime | None) -> None:
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
agent_2 = _FlakyAgent("agent_2", description="echo agent 2")
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
@ -468,9 +486,29 @@ async def test_round_robin_group_chat_with_exception_raised() -> None:
team = RoundRobinGroupChat(
participants=[agent_1, agent_2, agent_3],
termination_condition=termination,
runtime=runtime,
)
with pytest.raises(ValueError, match="I am a flaky agent..."):
with pytest.raises(RuntimeError, match="I am a flaky agent..."):
await team.run(
task="Write a program that prints 'Hello, world!'",
)
@pytest.mark.asyncio
async def test_round_robin_group_chat_with_exception_raised_from_termination_condition(
runtime: AgentRuntime | None,
) -> None:
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
agent_2 = _FlakyAgent("agent_2", description="echo agent 2")
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
team = RoundRobinGroupChat(
participants=[agent_1, agent_2, agent_3],
termination_condition=_FlakyTermination(raise_on_count=1),
runtime=runtime,
)
with pytest.raises(Exception, match="I am a flaky termination..."):
await team.run(
task="Write a program that prints 'Hello, world!'",
)