This commit is contained in:
EeS 2025-04-12 14:49:56 +09:00 committed by GitHub
commit a63f1dd7bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 28 additions and 3 deletions

View File

@ -5,7 +5,15 @@ from inspect import iscoroutinefunction
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast
from autogen_core import AgentRuntime, Component, ComponentModel from autogen_core import AgentRuntime, Component, ComponentModel
from autogen_core.models import AssistantMessage, ChatCompletionClient, ModelFamily, SystemMessage, UserMessage from autogen_core.logging import LLMStreamEndEvent
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
CreateResult,
ModelFamily,
SystemMessage,
UserMessage,
)
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import Self from typing_extensions import Self
@ -55,6 +63,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
selector_func: Optional[SelectorFuncType], selector_func: Optional[SelectorFuncType],
max_selector_attempts: int, max_selector_attempts: int,
candidate_func: Optional[CandidateFuncType], candidate_func: Optional[CandidateFuncType],
streaming: bool = False,
) -> None: ) -> None:
super().__init__( super().__init__(
name, name,
@ -77,6 +86,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
self._max_selector_attempts = max_selector_attempts self._max_selector_attempts = max_selector_attempts
self._candidate_func = candidate_func self._candidate_func = candidate_func
self._is_candidate_func_async = iscoroutinefunction(self._candidate_func) self._is_candidate_func_async = iscoroutinefunction(self._candidate_func)
self._streaming = streaming
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None: async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
pass pass
@ -192,7 +202,18 @@ class SelectorGroupChatManager(BaseGroupChatManager):
num_attempts = 0 num_attempts = 0
while num_attempts < max_attempts: while num_attempts < max_attempts:
num_attempts += 1 num_attempts += 1
response = await self._model_client.create(messages=select_speaker_messages) if self._streaming:
message: CreateResult | str = ""
async for _message in self._model_client.create_stream(messages=select_speaker_messages):
if isinstance(_message, LLMStreamEndEvent):
break
message = _message
if isinstance(message, CreateResult):
response = message
else:
raise ValueError("Model failed to select a speaker.")
else:
response = await self._model_client.create(messages=select_speaker_messages)
assert isinstance(response.content, str) assert isinstance(response.content, str)
select_speaker_messages.append(AssistantMessage(content=response.content, source="selector")) select_speaker_messages.append(AssistantMessage(content=response.content, source="selector"))
# NOTE: we use all participant names to check for mentions, even if the previous speaker is not allowed. # NOTE: we use all participant names to check for mentions, even if the previous speaker is not allowed.
@ -278,6 +299,7 @@ class SelectorGroupChatConfig(BaseModel):
allow_repeated_speaker: bool allow_repeated_speaker: bool
# selector_func: ComponentModel | None # selector_func: ComponentModel | None
max_selector_attempts: int = 3 max_selector_attempts: int = 3
streaming: bool = False
class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]): class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
@ -307,7 +329,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
A custom function that takes the conversation history and returns a filtered list of candidates for the next speaker A custom function that takes the conversation history and returns a filtered list of candidates for the next speaker
selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`. selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`.
This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set. This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set.
streaming (bool, optional): Whether to use streaming for the model.(Only use for specify case e.g. QwQ) Defaults to False.
Raises: Raises:
ValueError: If the number of participants is less than two or if the selector prompt is invalid. ValueError: If the number of participants is less than two or if the selector prompt is invalid.
@ -449,6 +471,7 @@ Read the above conversation. Then select the next role from {participants} to pl
selector_func: Optional[SelectorFuncType] = None, selector_func: Optional[SelectorFuncType] = None,
candidate_func: Optional[CandidateFuncType] = None, candidate_func: Optional[CandidateFuncType] = None,
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None, custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
streaming: bool = False,
): ):
super().__init__( super().__init__(
participants, participants,
@ -468,6 +491,7 @@ Read the above conversation. Then select the next role from {participants} to pl
self._selector_func = selector_func self._selector_func = selector_func
self._max_selector_attempts = max_selector_attempts self._max_selector_attempts = max_selector_attempts
self._candidate_func = candidate_func self._candidate_func = candidate_func
self._streaming = streaming
def _create_group_chat_manager_factory( def _create_group_chat_manager_factory(
self, self,
@ -499,6 +523,7 @@ Read the above conversation. Then select the next role from {participants} to pl
self._selector_func, self._selector_func,
self._max_selector_attempts, self._max_selector_attempts,
self._candidate_func, self._candidate_func,
self._streaming,
) )
def _to_config(self) -> SelectorGroupChatConfig: def _to_config(self) -> SelectorGroupChatConfig: