From 965d9a461b85aff849dd09e77479527529bdcf30 Mon Sep 17 00:00:00 2001 From: "chiyoung.song" Date: Sat, 12 Apr 2025 14:43:39 +0900 Subject: [PATCH] FEAT: select group chat could using stream --- .../teams/_group_chat/_selector_group_chat.py | 31 +++++++++++++++++-- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 1aa5aa337..cc03b9e6b 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -5,7 +5,15 @@ from inspect import iscoroutinefunction from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast from autogen_core import AgentRuntime, Component, ComponentModel -from autogen_core.models import AssistantMessage, ChatCompletionClient, ModelFamily, SystemMessage, UserMessage +from autogen_core.logging import LLMStreamEndEvent +from autogen_core.models import ( + AssistantMessage, + ChatCompletionClient, + CreateResult, + ModelFamily, + SystemMessage, + UserMessage, +) from pydantic import BaseModel from typing_extensions import Self @@ -55,6 +63,7 @@ class SelectorGroupChatManager(BaseGroupChatManager): selector_func: Optional[SelectorFuncType], max_selector_attempts: int, candidate_func: Optional[CandidateFuncType], + streaming: bool = False, ) -> None: super().__init__( name, @@ -77,6 +86,7 @@ class SelectorGroupChatManager(BaseGroupChatManager): self._max_selector_attempts = max_selector_attempts self._candidate_func = candidate_func self._is_candidate_func_async = iscoroutinefunction(self._candidate_func) + self._streaming = streaming async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None: pass @@ -192,7 +202,18 @@ class SelectorGroupChatManager(BaseGroupChatManager): num_attempts = 0 while num_attempts < max_attempts: num_attempts += 1 - response = await self._model_client.create(messages=select_speaker_messages) + if self._streaming: + message: CreateResult | str = "" + async for _message in self._model_client.create_stream(messages=select_speaker_messages): + if isinstance(_message, LLMStreamEndEvent): + break + message = _message + if isinstance(message, CreateResult): + response = message + else: + raise ValueError("Model failed to select a speaker.") + else: + response = await self._model_client.create(messages=select_speaker_messages) assert isinstance(response.content, str) select_speaker_messages.append(AssistantMessage(content=response.content, source="selector")) # NOTE: we use all participant names to check for mentions, even if the previous speaker is not allowed. @@ -278,6 +299,7 @@ class SelectorGroupChatConfig(BaseModel): allow_repeated_speaker: bool # selector_func: ComponentModel | None max_selector_attempts: int = 3 + streaming: bool = False class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]): @@ -307,7 +329,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]): A custom function that takes the conversation history and returns a filtered list of candidates for the next speaker selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`. This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set. - + streaming (bool, optional): Whether to use streaming for the model.(Only use for specify case e.g. QwQ) Defaults to False. Raises: ValueError: If the number of participants is less than two or if the selector prompt is invalid. @@ -449,6 +471,7 @@ Read the above conversation. Then select the next role from {participants} to pl selector_func: Optional[SelectorFuncType] = None, candidate_func: Optional[CandidateFuncType] = None, custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None, + streaming: bool = False, ): super().__init__( participants, @@ -468,6 +491,7 @@ Read the above conversation. Then select the next role from {participants} to pl self._selector_func = selector_func self._max_selector_attempts = max_selector_attempts self._candidate_func = candidate_func + self._streaming = streaming def _create_group_chat_manager_factory( self, @@ -499,6 +523,7 @@ Read the above conversation. Then select the next role from {participants} to pl self._selector_func, self._max_selector_attempts, self._candidate_func, + self._streaming, ) def _to_config(self) -> SelectorGroupChatConfig: