mirror of https://github.com/microsoft/autogen.git
Stop run when an error occurred in a group chat (#6141)
Resolves #5851 * Add a GroupChatError event type and terminate the run when an error occurs in either a participant or the group chat manager * Raise a RuntimeError built from the error message within the group chat run
This commit is contained in:
parent
86237c9fdf
commit
aec04e76ec
|
@ -1,5 +1,4 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Callable, Dict, List, Mapping, Sequence
|
||||
|
@ -15,7 +14,6 @@ from autogen_core import (
|
|||
)
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from ... import EVENT_LOGGER_NAME
|
||||
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
|
||||
from ...messages import (
|
||||
BaseAgentEvent,
|
||||
|
@ -27,11 +25,16 @@ from ...messages import (
|
|||
)
|
||||
from ...state import TeamState
|
||||
from ._chat_agent_container import ChatAgentContainer
|
||||
from ._events import GroupChatPause, GroupChatReset, GroupChatResume, GroupChatStart, GroupChatTermination
|
||||
from ._events import (
|
||||
GroupChatPause,
|
||||
GroupChatReset,
|
||||
GroupChatResume,
|
||||
GroupChatStart,
|
||||
GroupChatTermination,
|
||||
SerializableException,
|
||||
)
|
||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
"""The base class for group chat teams.
|
||||
|
@ -447,13 +450,26 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
|||
try:
|
||||
# This will propagate any exceptions raised.
|
||||
await self._runtime.stop_when_idle()
|
||||
finally:
|
||||
# Stop the consumption of messages and end the stream.
|
||||
# NOTE: we also need to put a GroupChatTermination event here because when the group chat
|
||||
# has an exception, the group chat manager may not be able to put a GroupChatTermination event in the queue.
|
||||
# Put a termination message in the queue to indicate that the group chat is stopped for whatever reason
|
||||
# but not due to an exception.
|
||||
await self._output_message_queue.put(
|
||||
GroupChatTermination(
|
||||
message=StopMessage(content="Exception occurred.", source=self._group_chat_manager_name)
|
||||
message=StopMessage(
|
||||
content="The group chat is stopped.", source=self._group_chat_manager_name
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# Stop the consumption of messages and end the stream.
|
||||
# NOTE: we also need to put a GroupChatTermination event here because when the runtime
|
||||
# has an exception, the group chat manager may not be able to put a GroupChatTermination event in the queue.
|
||||
# This may not be necessary if the group chat manager is able to handle the exception and put the event in the queue.
|
||||
await self._output_message_queue.put(
|
||||
GroupChatTermination(
|
||||
message=StopMessage(
|
||||
content="An exception occurred in the runtime.", source=self._group_chat_manager_name
|
||||
),
|
||||
error=SerializableException.from_exception(e),
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -481,11 +497,10 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
|||
# Wait for the next message, this will raise an exception if the task is cancelled.
|
||||
message = await message_future
|
||||
if isinstance(message, GroupChatTermination):
|
||||
# If the message is None, it means the group chat has terminated.
|
||||
# TODO: how do we handle termination when the runtime is not embedded
|
||||
# and there is an exception in the group chat?
|
||||
# The group chat manager may not be able to put a GroupChatTermination event in the queue,
|
||||
# and this loop will never end.
|
||||
# If the message contains an error, we need to raise it here.
|
||||
# This will stop the team and propagate the error.
|
||||
if message.error is not None:
|
||||
raise RuntimeError(str(message.error))
|
||||
stop_reason = message.message.content
|
||||
break
|
||||
yield message
|
||||
|
|
|
@ -8,6 +8,7 @@ from ...base import TerminationCondition
|
|||
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, StopMessage
|
||||
from ._events import (
|
||||
GroupChatAgentResponse,
|
||||
GroupChatError,
|
||||
GroupChatMessage,
|
||||
GroupChatPause,
|
||||
GroupChatRequestPublish,
|
||||
|
@ -15,6 +16,7 @@ from ._events import (
|
|||
GroupChatResume,
|
||||
GroupChatStart,
|
||||
GroupChatTermination,
|
||||
SerializableException,
|
||||
)
|
||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||
|
||||
|
@ -140,58 +142,65 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
|||
|
||||
@event
|
||||
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
|
||||
# Append the message to the message thread and construct the delta.
|
||||
delta: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
if message.agent_response.inner_messages is not None:
|
||||
for inner_message in message.agent_response.inner_messages:
|
||||
self._message_thread.append(inner_message)
|
||||
delta.append(inner_message)
|
||||
self._message_thread.append(message.agent_response.chat_message)
|
||||
delta.append(message.agent_response.chat_message)
|
||||
try:
|
||||
# Append the message to the message thread and construct the delta.
|
||||
delta: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
if message.agent_response.inner_messages is not None:
|
||||
for inner_message in message.agent_response.inner_messages:
|
||||
self._message_thread.append(inner_message)
|
||||
delta.append(inner_message)
|
||||
self._message_thread.append(message.agent_response.chat_message)
|
||||
delta.append(message.agent_response.chat_message)
|
||||
|
||||
# Check if the conversation should be terminated.
|
||||
if self._termination_condition is not None:
|
||||
stop_message = await self._termination_condition(delta)
|
||||
if stop_message is not None:
|
||||
# Reset the termination conditions and turn count.
|
||||
await self._termination_condition.reset()
|
||||
self._current_turn = 0
|
||||
# Signal termination to the caller of the team.
|
||||
await self._signal_termination(stop_message)
|
||||
# Stop the group chat.
|
||||
return
|
||||
|
||||
# Increment the turn count.
|
||||
self._current_turn += 1
|
||||
# Check if the maximum number of turns has been reached.
|
||||
if self._max_turns is not None:
|
||||
if self._current_turn >= self._max_turns:
|
||||
stop_message = StopMessage(
|
||||
content=f"Maximum number of turns {self._max_turns} reached.",
|
||||
source=self._name,
|
||||
)
|
||||
# Reset the termination conditions and turn count.
|
||||
if self._termination_condition is not None:
|
||||
# Check if the conversation should be terminated.
|
||||
if self._termination_condition is not None:
|
||||
stop_message = await self._termination_condition(delta)
|
||||
if stop_message is not None:
|
||||
# Reset the termination conditions and turn count.
|
||||
await self._termination_condition.reset()
|
||||
self._current_turn = 0
|
||||
# Signal termination to the caller of the team.
|
||||
await self._signal_termination(stop_message)
|
||||
# Stop the group chat.
|
||||
return
|
||||
self._current_turn = 0
|
||||
# Signal termination to the caller of the team.
|
||||
await self._signal_termination(stop_message)
|
||||
# Stop the group chat.
|
||||
return
|
||||
|
||||
# Select a speaker to continue the conversation.
|
||||
speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
|
||||
# Link the select speaker future to the cancellation token.
|
||||
ctx.cancellation_token.link_future(speaker_name_future)
|
||||
speaker_name = await speaker_name_future
|
||||
if speaker_name not in self._participant_name_to_topic_type:
|
||||
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
|
||||
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
|
||||
await self.publish_message(
|
||||
GroupChatRequestPublish(),
|
||||
topic_id=DefaultTopicId(type=speaker_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
# Increment the turn count.
|
||||
self._current_turn += 1
|
||||
# Check if the maximum number of turns has been reached.
|
||||
if self._max_turns is not None:
|
||||
if self._current_turn >= self._max_turns:
|
||||
stop_message = StopMessage(
|
||||
content=f"Maximum number of turns {self._max_turns} reached.",
|
||||
source=self._name,
|
||||
)
|
||||
# Reset the termination conditions and turn count.
|
||||
if self._termination_condition is not None:
|
||||
await self._termination_condition.reset()
|
||||
self._current_turn = 0
|
||||
# Signal termination to the caller of the team.
|
||||
await self._signal_termination(stop_message)
|
||||
# Stop the group chat.
|
||||
return
|
||||
|
||||
# Select a speaker to continue the conversation.
|
||||
speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
|
||||
# Link the select speaker future to the cancellation token.
|
||||
ctx.cancellation_token.link_future(speaker_name_future)
|
||||
speaker_name = await speaker_name_future
|
||||
if speaker_name not in self._participant_name_to_topic_type:
|
||||
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
|
||||
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
|
||||
await self.publish_message(
|
||||
GroupChatRequestPublish(),
|
||||
topic_id=DefaultTopicId(type=speaker_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle the exception and signal termination with an error.
|
||||
error = SerializableException.from_exception(e)
|
||||
await self._signal_termination_with_error(error)
|
||||
# Raise the exception to the runtime.
|
||||
raise
|
||||
|
||||
async def _signal_termination(self, message: StopMessage) -> None:
|
||||
termination_event = GroupChatTermination(message=message)
|
||||
|
@ -203,11 +212,28 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
|||
# Put the termination event in the output message queue.
|
||||
await self._output_message_queue.put(termination_event)
|
||||
|
||||
async def _signal_termination_with_error(self, error: SerializableException) -> None:
|
||||
termination_event = GroupChatTermination(
|
||||
message=StopMessage(content="An error occurred in the group chat.", source=self._name), error=error
|
||||
)
|
||||
# Log the termination event.
|
||||
await self.publish_message(
|
||||
termination_event,
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
# Put the termination event in the output message queue.
|
||||
await self._output_message_queue.put(termination_event)
|
||||
|
||||
@event
|
||||
async def handle_group_chat_message(self, message: GroupChatMessage, ctx: MessageContext) -> None:
|
||||
"""Handle a group chat message by appending the content to its output message queue."""
|
||||
await self._output_message_queue.put(message.message)
|
||||
|
||||
@event
|
||||
async def handle_group_chat_error(self, message: GroupChatError, ctx: MessageContext) -> None:
|
||||
"""Handle a group chat error by logging the error and signaling termination."""
|
||||
await self._signal_termination_with_error(message.error)
|
||||
|
||||
@rpc
|
||||
async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None:
|
||||
"""Reset the group chat manager. Calling :meth:`reset` to reset the group chat manager
|
||||
|
|
|
@ -8,12 +8,14 @@ from ...base import ChatAgent, Response
|
|||
from ...state import ChatAgentContainerState
|
||||
from ._events import (
|
||||
GroupChatAgentResponse,
|
||||
GroupChatError,
|
||||
GroupChatMessage,
|
||||
GroupChatPause,
|
||||
GroupChatRequestPublish,
|
||||
GroupChatReset,
|
||||
GroupChatResume,
|
||||
GroupChatStart,
|
||||
SerializableException,
|
||||
)
|
||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||
|
||||
|
@ -71,24 +73,36 @@ class ChatAgentContainer(SequentialRoutedAgent):
|
|||
async def handle_request(self, message: GroupChatRequestPublish, ctx: MessageContext) -> None:
|
||||
"""Handle a content request event by passing the messages in the buffer
|
||||
to the delegate agent and publish the response."""
|
||||
# Pass the messages in the buffer to the delegate agent.
|
||||
response: Response | None = None
|
||||
async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token):
|
||||
if isinstance(msg, Response):
|
||||
await self._log_message(msg.chat_message)
|
||||
response = msg
|
||||
else:
|
||||
await self._log_message(msg)
|
||||
if response is None:
|
||||
raise ValueError("The agent did not produce a final response. Check the agent's on_messages_stream method.")
|
||||
|
||||
# Publish the response to the group chat.
|
||||
self._message_buffer.clear()
|
||||
await self.publish_message(
|
||||
GroupChatAgentResponse(agent_response=response),
|
||||
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
try:
|
||||
# Pass the messages in the buffer to the delegate agent.
|
||||
response: Response | None = None
|
||||
async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token):
|
||||
if isinstance(msg, Response):
|
||||
await self._log_message(msg.chat_message)
|
||||
response = msg
|
||||
else:
|
||||
await self._log_message(msg)
|
||||
if response is None:
|
||||
raise ValueError(
|
||||
"The agent did not produce a final response. Check the agent's on_messages_stream method."
|
||||
)
|
||||
# Publish the response to the group chat.
|
||||
self._message_buffer.clear()
|
||||
await self.publish_message(
|
||||
GroupChatAgentResponse(agent_response=response),
|
||||
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
except Exception as e:
|
||||
# Publish the error to the group chat.
|
||||
error_message = SerializableException.from_exception(e)
|
||||
await self.publish_message(
|
||||
GroupChatError(error=error_message),
|
||||
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
# Raise the error to the runtime.
|
||||
raise
|
||||
|
||||
def _buffer_message(self, message: BaseChatMessage) -> None:
|
||||
if not self._message_factory.is_registered(message.__class__):
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import traceback
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
@ -6,6 +7,34 @@ from ...base import Response
|
|||
from ...messages import BaseAgentEvent, BaseChatMessage, StopMessage
|
||||
|
||||
|
||||
class SerializableException(BaseModel):
    """A serializable snapshot of an exception.

    Stores the exception's type name, message and formatted traceback as
    plain strings so the error can be embedded in other pydantic models.
    """

    error_type: str
    """The type of error that occurred."""

    error_message: str
    """The error message that describes the error."""

    traceback: str | None = None
    """The traceback of the error, if available."""

    @classmethod
    def from_exception(cls, exc: Exception) -> "SerializableException":
        """Create a SerializableException from an exception.

        Args:
            exc: The exception to capture.

        Returns:
            A new instance holding the exception's class name, its ``str()``
            message and its formatted traceback.
        """
        # Inside this method ``traceback`` resolves to the imported module,
        # not the field declared above: class-body names are not visible
        # from method scope.
        #
        # ``format_exception`` returns strings that each already end with a
        # newline, so they are concatenated without a separator; joining on
        # "\n" would insert a blank line between every traceback entry.
        formatted = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
        return cls(
            error_type=type(exc).__name__,
            error_message=str(exc),
            traceback=formatted,
        )

    def __str__(self) -> str:
        """Return a string representation of the error, including the traceback if available."""
        if self.traceback:
            return f"{self.error_type}: {self.error_message}\nTraceback:\n{self.traceback}"
        return f"{self.error_type}: {self.error_message}"
|
||||
|
||||
|
||||
class GroupChatStart(BaseModel):
|
||||
"""A request to start a group chat."""
|
||||
|
||||
|
@ -39,6 +68,9 @@ class GroupChatTermination(BaseModel):
|
|||
message: StopMessage
|
||||
"""The stop message that indicates the reason of termination."""
|
||||
|
||||
error: SerializableException | None = None
|
||||
"""The error that occurred, if any."""
|
||||
|
||||
|
||||
class GroupChatReset(BaseModel):
|
||||
"""A request to reset the agents in the group chat."""
|
||||
|
@ -56,3 +88,10 @@ class GroupChatResume(BaseModel):
|
|||
"""A request to resume the group chat."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class GroupChatError(BaseModel):
|
||||
"""A message indicating that an error occurred in the group chat."""
|
||||
|
||||
error: SerializableException
|
||||
"""The error that occurred."""
|
||||
|
|
|
@ -12,7 +12,7 @@ from autogen_agentchat.agents import (
|
|||
BaseChatAgent,
|
||||
CodeExecutorAgent,
|
||||
)
|
||||
from autogen_agentchat.base import Handoff, Response, TaskResult
|
||||
from autogen_agentchat.base import Handoff, Response, TaskResult, TerminationCondition
|
||||
from autogen_agentchat.conditions import HandoffTermination, MaxMessageTermination, TextMentionTermination
|
||||
from autogen_agentchat.messages import (
|
||||
BaseAgentEvent,
|
||||
|
@ -103,6 +103,26 @@ class _FlakyAgent(BaseChatAgent):
|
|||
self._last_message = None
|
||||
|
||||
|
||||
class _FlakyTermination(TerminationCondition):
|
||||
def __init__(self, raise_on_count: int) -> None:
|
||||
self._raise_on_count = raise_on_count
|
||||
self._count = 0
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
"""Check if the termination condition has been reached"""
|
||||
return False
|
||||
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
self._count += 1
|
||||
if self._count == self._raise_on_count:
|
||||
raise ValueError("I am a flaky termination...")
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class _UnknownMessageType(BaseChatMessage):
|
||||
content: str
|
||||
|
||||
|
@ -285,7 +305,7 @@ async def test_round_robin_group_chat_unknown_agent_message_type() -> None:
|
|||
agent2 = _UnknownMessageTypeAgent("agent2", "I am an unknown message type agent")
|
||||
termination = TextMentionTermination("TERMINATE")
|
||||
team1 = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination)
|
||||
with pytest.raises(ValueError, match="Message type .*UnknownMessageType.* not registered"):
|
||||
with pytest.raises(RuntimeError, match=".* Message type .*UnknownMessageType.* not registered"):
|
||||
await team1.run(task=TextMessage(content="Write a program that prints 'Hello, world!'", source="user"))
|
||||
|
||||
|
||||
|
@ -457,10 +477,8 @@ async def test_round_robin_group_chat_with_resume_and_reset(runtime: AgentRuntim
|
|||
assert result.stop_reason is not None
|
||||
|
||||
|
||||
# TODO: add runtime fixture for testing with custom runtime once the issue regarding
|
||||
# hanging on exception is resolved.
|
||||
@pytest.mark.asyncio
|
||||
async def test_round_robin_group_chat_with_exception_raised() -> None:
|
||||
async def test_round_robin_group_chat_with_exception_raised_from_agent(runtime: AgentRuntime | None) -> None:
|
||||
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
|
||||
agent_2 = _FlakyAgent("agent_2", description="echo agent 2")
|
||||
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
|
||||
|
@ -468,9 +486,29 @@ async def test_round_robin_group_chat_with_exception_raised() -> None:
|
|||
team = RoundRobinGroupChat(
|
||||
participants=[agent_1, agent_2, agent_3],
|
||||
termination_condition=termination,
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="I am a flaky agent..."):
|
||||
with pytest.raises(RuntimeError, match="I am a flaky agent..."):
|
||||
await team.run(
|
||||
task="Write a program that prints 'Hello, world!'",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_round_robin_group_chat_with_exception_raised_from_termination_condition(
|
||||
runtime: AgentRuntime | None,
|
||||
) -> None:
|
||||
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
|
||||
agent_2 = _FlakyAgent("agent_2", description="echo agent 2")
|
||||
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
|
||||
team = RoundRobinGroupChat(
|
||||
participants=[agent_1, agent_2, agent_3],
|
||||
termination_condition=_FlakyTermination(raise_on_count=1),
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="I am a flaky termination..."):
|
||||
await team.run(
|
||||
task="Write a program that prints 'Hello, world!'",
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue