This commit is contained in:
EeS 2025-04-16 00:29:59 +01:00 committed by GitHub
commit 0b2a764cab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 94 additions and 4 deletions

View File

@ -376,6 +376,18 @@ class ThoughtEvent(BaseAgentEvent):
return self.content
class SelectSpeakerEvent(BaseAgentEvent):
    """An event signaling the selection of a speaker for a conversation."""

    content: str
    """The name of the selected speaker."""

    # Pydantic discriminator value; keeps this event distinguishable in the
    # AgentEvent annotated union (Field(discriminator="type")).
    type: Literal["SelectSpeakerEvent"] = "SelectSpeakerEvent"

    def to_text(self) -> str:
        """Return the selected speaker's name as the text rendering of this event."""
        return self.content
class MessageFactory:
""":meta private:
@ -398,6 +410,7 @@ class MessageFactory:
self._message_types[UserInputRequestedEvent.__name__] = UserInputRequestedEvent
self._message_types[ModelClientStreamingChunkEvent.__name__] = ModelClientStreamingChunkEvent
self._message_types[ThoughtEvent.__name__] = ThoughtEvent
self._message_types[SelectSpeakerEvent.__name__] = SelectSpeakerEvent
self._message_types[CodeGenerationEvent.__name__] = CodeGenerationEvent
self._message_types[CodeExecutionEvent.__name__] = CodeExecutionEvent
@ -453,6 +466,7 @@ AgentEvent = Annotated[
| UserInputRequestedEvent
| ModelClientStreamingChunkEvent
| ThoughtEvent
| SelectSpeakerEvent
| CodeGenerationEvent
| CodeExecutionEvent,
Field(discriminator="type"),
@ -479,6 +493,7 @@ __all__ = [
"UserInputRequestedEvent",
"ModelClientStreamingChunkEvent",
"ThoughtEvent",
"SelectSpeakerEvent",
"MessageFactory",
"CodeGenerationEvent",
"CodeExecutionEvent",

View File

@ -54,6 +54,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
max_turns: int | None = None,
runtime: AgentRuntime | None = None,
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
emit_team_events: bool = False,
):
if len(participants) == 0:
raise ValueError("At least one participant is required.")
@ -113,6 +114,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
# Flag to track if the group chat is running.
self._is_running = False
# Flag to track if the team events should be emitted.
self._emit_team_events = emit_team_events
@abstractmethod
def _create_group_chat_manager_factory(
self,

View File

@ -5,7 +5,7 @@ from typing import Any, List
from autogen_core import DefaultTopicId, MessageContext, event, rpc
from ...base import TerminationCondition
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, StopMessage
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, SelectSpeakerEvent, StopMessage
from ._events import (
GroupChatAgentResponse,
GroupChatError,
@ -45,6 +45,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
termination_condition: TerminationCondition | None,
max_turns: int | None,
message_factory: MessageFactory,
emit_team_events: bool = False,
):
super().__init__(
description="Group chat manager",
@ -77,6 +78,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
self._max_turns = max_turns
self._current_turn = 0
self._message_factory = message_factory
self._emit_team_events = emit_team_events
@rpc
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
@ -139,6 +141,14 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
topic_id=DefaultTopicId(type=speaker_topic_type),
cancellation_token=ctx.cancellation_token,
)
# Send the message to the next speaker
if self._emit_team_events:
select_msg = SelectSpeakerEvent(content=speaker_name, source=self._name)
await self.publish_message(
GroupChatMessage(message=select_msg),
topic_id=DefaultTopicId(type=self._output_topic_type),
)
await self._output_message_queue.put(select_msg)
@event
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
@ -195,6 +205,14 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
topic_id=DefaultTopicId(type=speaker_topic_type),
cancellation_token=ctx.cancellation_token,
)
# Send the message to the next speakers
if self._emit_team_events:
select_msg = SelectSpeakerEvent(content=speaker_name, source=self._name)
await self.publish_message(
GroupChatMessage(message=select_msg),
topic_id=DefaultTopicId(type=self._output_topic_type),
)
await self._output_message_queue.put(select_msg)
except Exception as e:
# Handle the exception and signal termination with an error.
error = SerializableException.from_exception(e)

View File

@ -28,6 +28,7 @@ class MagenticOneGroupChatConfig(BaseModel):
max_turns: int | None = None
max_stalls: int
final_answer_prompt: str
emit_team_events: bool = False
class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig]):
@ -46,6 +47,7 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig])
max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to 20.
max_stalls (int, optional): The maximum number of stalls allowed before re-planning. Defaults to 3.
final_answer_prompt (str, optional): The LLM prompt used to generate the final answer or response from the team's transcript. A default (sensible for GPT-4o class models) is provided.
emit_team_events (bool, optional): Whether to emit team events. Defaults to False.
Raises:
ValueError: In orchestration logic if progress ledger does not have required keys or if next speaker is not valid.
@ -103,6 +105,7 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig])
runtime: AgentRuntime | None = None,
max_stalls: int = 3,
final_answer_prompt: str = ORCHESTRATOR_FINAL_ANSWER_PROMPT,
emit_team_events: bool = False,
):
super().__init__(
participants,
@ -111,6 +114,7 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig])
termination_condition=termination_condition,
max_turns=max_turns,
runtime=runtime,
emit_team_events=emit_team_events,
)
# Validate the participants.
@ -147,6 +151,7 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig])
self._final_answer_prompt,
output_message_queue,
termination_condition,
self._emit_team_events,
)
def _to_config(self) -> MagenticOneGroupChatConfig:
@ -159,6 +164,7 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig])
max_turns=self._max_turns,
max_stalls=self._max_stalls,
final_answer_prompt=self._final_answer_prompt,
emit_team_events=self._emit_team_events,
)
@classmethod
@ -175,4 +181,5 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig])
max_turns=config.max_turns,
max_stalls=config.max_stalls,
final_answer_prompt=config.final_answer_prompt,
emit_team_events=config.emit_team_events,
)

View File

@ -20,6 +20,7 @@ from ....messages import (
HandoffMessage,
MessageFactory,
MultiModalMessage,
SelectSpeakerEvent,
StopMessage,
TextMessage,
ToolCallExecutionEvent,
@ -68,6 +69,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
final_answer_prompt: str,
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
emit_team_events: bool = False,
):
super().__init__(
name,
@ -80,6 +82,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
termination_condition,
max_turns,
message_factory,
emit_team_events=emit_team_events,
)
self._model_client = model_client
self._max_stalls = max_stalls
@ -405,6 +408,15 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
cancellation_token=cancellation_token,
)
# Send the message to the next speaker
if self._emit_team_events:
select_msg = SelectSpeakerEvent(content=next_speaker, source=self._name)
await self.publish_message(
GroupChatMessage(message=select_msg),
topic_id=DefaultTopicId(type=self._output_topic_type),
)
await self._output_message_queue.put(select_msg)
async def _update_task_ledger(self, cancellation_token: CancellationToken) -> None:
"""Update the task ledger (outer loop) with the latest facts and plan."""
context = self._thread_to_context()

View File

@ -28,6 +28,7 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
termination_condition: TerminationCondition | None,
max_turns: int | None,
message_factory: MessageFactory,
emit_team_events: bool = False,
) -> None:
super().__init__(
name,
@ -40,6 +41,7 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
termination_condition,
max_turns,
message_factory,
emit_team_events,
)
self._next_speaker_index = 0
@ -81,6 +83,7 @@ class RoundRobinGroupChatConfig(BaseModel):
participants: List[ComponentModel]
termination_condition: ComponentModel | None = None
max_turns: int | None = None
emit_team_events: bool = False
class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
@ -94,6 +97,7 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
Without a termination condition, the group chat will run indefinitely.
max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit.
emit_team_events (bool, optional): Whether to emit team events. Defaults to False.
Raises:
ValueError: If no participants are provided or if participant names are not unique.
@ -167,6 +171,7 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
max_turns: int | None = None,
runtime: AgentRuntime | None = None,
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
emit_team_events: bool = False,
) -> None:
super().__init__(
participants,
@ -176,6 +181,7 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
max_turns=max_turns,
runtime=runtime,
custom_message_types=custom_message_types,
emit_team_events=emit_team_events,
)
def _create_group_chat_manager_factory(
@ -203,6 +209,7 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
termination_condition,
max_turns,
message_factory,
self._emit_team_events,
)
return _factory
@ -214,6 +221,7 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
participants=participants,
termination_condition=termination_condition,
max_turns=self._max_turns,
emit_team_events=self._emit_team_events,
)
@classmethod
@ -222,4 +230,9 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
termination_condition = (
TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None
)
return cls(participants, termination_condition=termination_condition, max_turns=config.max_turns)
return cls(
participants,
termination_condition=termination_condition,
max_turns=config.max_turns,
emit_team_events=config.emit_team_events,
)

View File

@ -55,6 +55,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
selector_func: Optional[SelectorFuncType],
max_selector_attempts: int,
candidate_func: Optional[CandidateFuncType],
emit_team_events: bool = False,
) -> None:
super().__init__(
name,
@ -67,6 +68,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
termination_condition,
max_turns,
message_factory,
emit_team_events,
)
self._model_client = model_client
self._selector_prompt = selector_prompt
@ -278,6 +280,7 @@ class SelectorGroupChatConfig(BaseModel):
allow_repeated_speaker: bool
# selector_func: ComponentModel | None
max_selector_attempts: int = 3
emit_team_events: bool = False
class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
@ -307,7 +310,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
A custom function that takes the conversation history and returns a filtered list of candidates for the next speaker
selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`.
This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set.
emit_team_events (bool, optional): Whether to emit team events. Defaults to False.
Raises:
ValueError: If the number of participants is less than two or if the selector prompt is invalid.
@ -449,6 +452,7 @@ Read the above conversation. Then select the next role from {participants} to pl
selector_func: Optional[SelectorFuncType] = None,
candidate_func: Optional[CandidateFuncType] = None,
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
emit_team_events: bool = False,
):
super().__init__(
participants,
@ -458,6 +462,7 @@ Read the above conversation. Then select the next role from {participants} to pl
max_turns=max_turns,
runtime=runtime,
custom_message_types=custom_message_types,
emit_team_events=emit_team_events,
)
# Validate the participants.
if len(participants) < 2:
@ -499,6 +504,7 @@ Read the above conversation. Then select the next role from {participants} to pl
self._selector_func,
self._max_selector_attempts,
self._candidate_func,
self._emit_team_events,
)
def _to_config(self) -> SelectorGroupChatConfig:
@ -511,6 +517,7 @@ Read the above conversation. Then select the next role from {participants} to pl
allow_repeated_speaker=self._allow_repeated_speaker,
max_selector_attempts=self._max_selector_attempts,
# selector_func=self._selector_func.dump_component() if self._selector_func else None,
emit_team_events=self._emit_team_events,
)
@classmethod
@ -528,4 +535,5 @@ Read the above conversation. Then select the next role from {participants} to pl
# selector_func=ComponentLoader.load_component(config.selector_func, Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None])
# if config.selector_func
# else None,
emit_team_events=config.emit_team_events,
)

View File

@ -27,6 +27,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):
termination_condition: TerminationCondition | None,
max_turns: int | None,
message_factory: MessageFactory,
emit_team_events: bool,
) -> None:
super().__init__(
name,
@ -39,6 +40,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):
termination_condition,
max_turns,
message_factory,
emit_team_events,
)
self._current_speaker = self._participant_names[0]
@ -111,6 +113,7 @@ class SwarmConfig(BaseModel):
participants: List[ComponentModel]
termination_condition: ComponentModel | None = None
max_turns: int | None = None
emit_team_events: bool = False
class Swarm(BaseGroupChat, Component[SwarmConfig]):
@ -126,6 +129,7 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
Without a termination condition, the group chat will run indefinitely.
max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit.
emit_team_events (bool, optional): Whether to emit team events. Defaults to False.
Basic example:
@ -213,6 +217,7 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
max_turns: int | None = None,
runtime: AgentRuntime | None = None,
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
emit_team_events: bool = False,
) -> None:
super().__init__(
participants,
@ -222,6 +227,7 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
max_turns=max_turns,
runtime=runtime,
custom_message_types=custom_message_types,
emit_team_events=emit_team_events,
)
# The first participant must be able to produce handoff messages.
first_participant = self._participants[0]
@ -253,6 +259,7 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
termination_condition,
max_turns,
message_factory,
self._emit_team_events,
)
return _factory
@ -264,6 +271,7 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
participants=participants,
termination_condition=termination_condition,
max_turns=self._max_turns,
emit_team_events=self._emit_team_events,
)
@classmethod
@ -272,4 +280,9 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
termination_condition = (
TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None
)
return cls(participants, termination_condition=termination_condition, max_turns=config.max_turns)
return cls(
participants,
termination_condition=termination_condition,
max_turns=config.max_turns,
emit_team_events=config.emit_team_events,
)