mirror of https://github.com/microsoft/autogen.git
Merge fabb85a543
into 71b7429a42
This commit is contained in:
commit
0b2a764cab
|
@ -376,6 +376,18 @@ class ThoughtEvent(BaseAgentEvent):
|
|||
return self.content
|
||||
|
||||
|
||||
class SelectSpeakerEvent(BaseAgentEvent):
|
||||
"""An event signaling the selection of a speaker for a conversation."""
|
||||
|
||||
content: str
|
||||
"""The name of the selected speaker."""
|
||||
|
||||
type: Literal["SelectSpeakerEvent"] = "SelectSpeakerEvent"
|
||||
|
||||
def to_text(self) -> str:
|
||||
return self.content
|
||||
|
||||
|
||||
class MessageFactory:
|
||||
""":meta private:
|
||||
|
||||
|
@ -398,6 +410,7 @@ class MessageFactory:
|
|||
self._message_types[UserInputRequestedEvent.__name__] = UserInputRequestedEvent
|
||||
self._message_types[ModelClientStreamingChunkEvent.__name__] = ModelClientStreamingChunkEvent
|
||||
self._message_types[ThoughtEvent.__name__] = ThoughtEvent
|
||||
self._message_types[SelectSpeakerEvent.__name__] = SelectSpeakerEvent
|
||||
self._message_types[CodeGenerationEvent.__name__] = CodeGenerationEvent
|
||||
self._message_types[CodeExecutionEvent.__name__] = CodeExecutionEvent
|
||||
|
||||
|
@ -453,6 +466,7 @@ AgentEvent = Annotated[
|
|||
| UserInputRequestedEvent
|
||||
| ModelClientStreamingChunkEvent
|
||||
| ThoughtEvent
|
||||
| SelectSpeakerEvent
|
||||
| CodeGenerationEvent
|
||||
| CodeExecutionEvent,
|
||||
Field(discriminator="type"),
|
||||
|
@ -479,6 +493,7 @@ __all__ = [
|
|||
"UserInputRequestedEvent",
|
||||
"ModelClientStreamingChunkEvent",
|
||||
"ThoughtEvent",
|
||||
"SelectSpeakerEvent",
|
||||
"MessageFactory",
|
||||
"CodeGenerationEvent",
|
||||
"CodeExecutionEvent",
|
||||
|
|
|
@ -54,6 +54,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
|||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
emit_team_events: bool = False,
|
||||
):
|
||||
if len(participants) == 0:
|
||||
raise ValueError("At least one participant is required.")
|
||||
|
@ -113,6 +114,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
|||
# Flag to track if the group chat is running.
|
||||
self._is_running = False
|
||||
|
||||
# Flag to track if the team events should be emitted.
|
||||
self._emit_team_events = emit_team_events
|
||||
|
||||
@abstractmethod
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import Any, List
|
|||
from autogen_core import DefaultTopicId, MessageContext, event, rpc
|
||||
|
||||
from ...base import TerminationCondition
|
||||
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, StopMessage
|
||||
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, SelectSpeakerEvent, StopMessage
|
||||
from ._events import (
|
||||
GroupChatAgentResponse,
|
||||
GroupChatError,
|
||||
|
@ -45,6 +45,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
|||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
emit_team_events: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
description="Group chat manager",
|
||||
|
@ -77,6 +78,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
|||
self._max_turns = max_turns
|
||||
self._current_turn = 0
|
||||
self._message_factory = message_factory
|
||||
self._emit_team_events = emit_team_events
|
||||
|
||||
@rpc
|
||||
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
|
||||
|
@ -139,6 +141,14 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
|||
topic_id=DefaultTopicId(type=speaker_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
# Send the message to the next speaker
|
||||
if self._emit_team_events:
|
||||
select_msg = SelectSpeakerEvent(content=speaker_name, source=self._name)
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=select_msg),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
await self._output_message_queue.put(select_msg)
|
||||
|
||||
@event
|
||||
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
|
||||
|
@ -195,6 +205,14 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
|||
topic_id=DefaultTopicId(type=speaker_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
# Send the message to the next speakers
|
||||
if self._emit_team_events:
|
||||
select_msg = SelectSpeakerEvent(content=speaker_name, source=self._name)
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=select_msg),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
await self._output_message_queue.put(select_msg)
|
||||
except Exception as e:
|
||||
# Handle the exception and signal termination with an error.
|
||||
error = SerializableException.from_exception(e)
|
||||
|
|
|
@ -28,6 +28,7 @@ class MagenticOneGroupChatConfig(BaseModel):
|
|||
max_turns: int | None = None
|
||||
max_stalls: int
|
||||
final_answer_prompt: str
|
||||
emit_team_events: bool = False
|
||||
|
||||
|
||||
class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig]):
|
||||
|
@ -46,6 +47,7 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig])
|
|||
max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to 20.
|
||||
max_stalls (int, optional): The maximum number of stalls allowed before re-planning. Defaults to 3.
|
||||
final_answer_prompt (str, optional): The LLM prompt used to generate the final answer or response from the team's transcript. A default (sensible for GPT-4o class models) is provided.
|
||||
emit_team_events (bool, optional): Whether to emit team events. Defaults to False.
|
||||
|
||||
Raises:
|
||||
ValueError: In orchestration logic if progress ledger does not have required keys or if next speaker is not valid.
|
||||
|
@ -103,6 +105,7 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig])
|
|||
runtime: AgentRuntime | None = None,
|
||||
max_stalls: int = 3,
|
||||
final_answer_prompt: str = ORCHESTRATOR_FINAL_ANSWER_PROMPT,
|
||||
emit_team_events: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
participants,
|
||||
|
@ -111,6 +114,7 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig])
|
|||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
runtime=runtime,
|
||||
emit_team_events=emit_team_events,
|
||||
)
|
||||
|
||||
# Validate the participants.
|
||||
|
@ -147,6 +151,7 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig])
|
|||
self._final_answer_prompt,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
self._emit_team_events,
|
||||
)
|
||||
|
||||
def _to_config(self) -> MagenticOneGroupChatConfig:
|
||||
|
@ -159,6 +164,7 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig])
|
|||
max_turns=self._max_turns,
|
||||
max_stalls=self._max_stalls,
|
||||
final_answer_prompt=self._final_answer_prompt,
|
||||
emit_team_events=self._emit_team_events,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -175,4 +181,5 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig])
|
|||
max_turns=config.max_turns,
|
||||
max_stalls=config.max_stalls,
|
||||
final_answer_prompt=config.final_answer_prompt,
|
||||
emit_team_events=config.emit_team_events,
|
||||
)
|
||||
|
|
|
@ -20,6 +20,7 @@ from ....messages import (
|
|||
HandoffMessage,
|
||||
MessageFactory,
|
||||
MultiModalMessage,
|
||||
SelectSpeakerEvent,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
ToolCallExecutionEvent,
|
||||
|
@ -68,6 +69,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
|||
final_answer_prompt: str,
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
emit_team_events: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
name,
|
||||
|
@ -80,6 +82,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
|||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
emit_team_events=emit_team_events,
|
||||
)
|
||||
self._model_client = model_client
|
||||
self._max_stalls = max_stalls
|
||||
|
@ -405,6 +408,15 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
|||
cancellation_token=cancellation_token,
|
||||
)
|
||||
|
||||
# Send the message to the next speaker
|
||||
if self._emit_team_events:
|
||||
select_msg = SelectSpeakerEvent(content=next_speaker, source=self._name)
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=select_msg),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
await self._output_message_queue.put(select_msg)
|
||||
|
||||
async def _update_task_ledger(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Update the task ledger (outer loop) with the latest facts and plan."""
|
||||
context = self._thread_to_context()
|
||||
|
|
|
@ -28,6 +28,7 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
|
|||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
emit_team_events: bool = True,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name,
|
||||
|
@ -40,6 +41,7 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
|
|||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
emit_team_events,
|
||||
)
|
||||
self._next_speaker_index = 0
|
||||
|
||||
|
@ -81,6 +83,7 @@ class RoundRobinGroupChatConfig(BaseModel):
|
|||
participants: List[ComponentModel]
|
||||
termination_condition: ComponentModel | None = None
|
||||
max_turns: int | None = None
|
||||
emit_team_events: bool = False
|
||||
|
||||
|
||||
class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
|
||||
|
@ -94,6 +97,7 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
|
|||
termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
|
||||
Without a termination condition, the group chat will run indefinitely.
|
||||
max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit.
|
||||
emit_team_events (bool, optinal): Whether to emit team events. Defaults to False.
|
||||
|
||||
Raises:
|
||||
ValueError: If no participants are provided or if participant names are not unique.
|
||||
|
@ -167,6 +171,7 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
|
|||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
emit_team_events: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
participants,
|
||||
|
@ -176,6 +181,7 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
|
|||
max_turns=max_turns,
|
||||
runtime=runtime,
|
||||
custom_message_types=custom_message_types,
|
||||
emit_team_events=emit_team_events,
|
||||
)
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
|
@ -203,6 +209,7 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
|
|||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
self._emit_team_events,
|
||||
)
|
||||
|
||||
return _factory
|
||||
|
@ -214,6 +221,7 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
|
|||
participants=participants,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=self._max_turns,
|
||||
emit_team_events=self._emit_team_events,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -222,4 +230,9 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
|
|||
termination_condition = (
|
||||
TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None
|
||||
)
|
||||
return cls(participants, termination_condition=termination_condition, max_turns=config.max_turns)
|
||||
return cls(
|
||||
participants,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=config.max_turns,
|
||||
emit_team_events=config.emit_team_events,
|
||||
)
|
||||
|
|
|
@ -55,6 +55,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
|||
selector_func: Optional[SelectorFuncType],
|
||||
max_selector_attempts: int,
|
||||
candidate_func: Optional[CandidateFuncType],
|
||||
emit_team_events: bool = True,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name,
|
||||
|
@ -67,6 +68,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
|||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
emit_team_events,
|
||||
)
|
||||
self._model_client = model_client
|
||||
self._selector_prompt = selector_prompt
|
||||
|
@ -278,6 +280,7 @@ class SelectorGroupChatConfig(BaseModel):
|
|||
allow_repeated_speaker: bool
|
||||
# selector_func: ComponentModel | None
|
||||
max_selector_attempts: int = 3
|
||||
emit_team_events: bool = False
|
||||
|
||||
|
||||
class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
|
||||
|
@ -307,7 +310,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
|
|||
A custom function that takes the conversation history and returns a filtered list of candidates for the next speaker
|
||||
selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`.
|
||||
This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set.
|
||||
|
||||
emit_team_events (bool, optional): Whether to emit team events. Defaults to False.
|
||||
|
||||
Raises:
|
||||
ValueError: If the number of participants is less than two or if the selector prompt is invalid.
|
||||
|
@ -449,6 +452,7 @@ Read the above conversation. Then select the next role from {participants} to pl
|
|||
selector_func: Optional[SelectorFuncType] = None,
|
||||
candidate_func: Optional[CandidateFuncType] = None,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
emit_team_events: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
participants,
|
||||
|
@ -458,6 +462,7 @@ Read the above conversation. Then select the next role from {participants} to pl
|
|||
max_turns=max_turns,
|
||||
runtime=runtime,
|
||||
custom_message_types=custom_message_types,
|
||||
emit_team_events=emit_team_events,
|
||||
)
|
||||
# Validate the participants.
|
||||
if len(participants) < 2:
|
||||
|
@ -499,6 +504,7 @@ Read the above conversation. Then select the next role from {participants} to pl
|
|||
self._selector_func,
|
||||
self._max_selector_attempts,
|
||||
self._candidate_func,
|
||||
self._emit_team_events,
|
||||
)
|
||||
|
||||
def _to_config(self) -> SelectorGroupChatConfig:
|
||||
|
@ -511,6 +517,7 @@ Read the above conversation. Then select the next role from {participants} to pl
|
|||
allow_repeated_speaker=self._allow_repeated_speaker,
|
||||
max_selector_attempts=self._max_selector_attempts,
|
||||
# selector_func=self._selector_func.dump_component() if self._selector_func else None,
|
||||
emit_team_events=self._emit_team_events,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -528,4 +535,5 @@ Read the above conversation. Then select the next role from {participants} to pl
|
|||
# selector_func=ComponentLoader.load_component(config.selector_func, Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None])
|
||||
# if config.selector_func
|
||||
# else None,
|
||||
emit_team_events=config.emit_team_events,
|
||||
)
|
||||
|
|
|
@ -27,6 +27,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):
|
|||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
emit_team_events: bool,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name,
|
||||
|
@ -39,6 +40,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):
|
|||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
emit_team_events,
|
||||
)
|
||||
self._current_speaker = self._participant_names[0]
|
||||
|
||||
|
@ -111,6 +113,7 @@ class SwarmConfig(BaseModel):
|
|||
participants: List[ComponentModel]
|
||||
termination_condition: ComponentModel | None = None
|
||||
max_turns: int | None = None
|
||||
emit_team_events: bool = False
|
||||
|
||||
|
||||
class Swarm(BaseGroupChat, Component[SwarmConfig]):
|
||||
|
@ -126,6 +129,7 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
|
|||
termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
|
||||
Without a termination condition, the group chat will run indefinitely.
|
||||
max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit.
|
||||
emit_team_events (bool, optional): Whether to emit team events. Defaults to False.
|
||||
|
||||
Basic example:
|
||||
|
||||
|
@ -213,6 +217,7 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
|
|||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
emit_team_events: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
participants,
|
||||
|
@ -222,6 +227,7 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
|
|||
max_turns=max_turns,
|
||||
runtime=runtime,
|
||||
custom_message_types=custom_message_types,
|
||||
emit_team_events=emit_team_events,
|
||||
)
|
||||
# The first participant must be able to produce handoff messages.
|
||||
first_participant = self._participants[0]
|
||||
|
@ -253,6 +259,7 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
|
|||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
self._emit_team_events,
|
||||
)
|
||||
|
||||
return _factory
|
||||
|
@ -264,6 +271,7 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
|
|||
participants=participants,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=self._max_turns,
|
||||
emit_team_events=self._emit_team_events,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -272,4 +280,9 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
|
|||
termination_condition = (
|
||||
TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None
|
||||
)
|
||||
return cls(participants, termination_condition=termination_condition, max_turns=config.max_turns)
|
||||
return cls(
|
||||
participants,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=config.max_turns,
|
||||
emit_team_events=config.emit_team_events,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue