This commit is contained in:
EeS 2025-04-12 14:49:56 +09:00 committed by GitHub
commit a63f1dd7bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 28 additions and 3 deletions

View File

@ -5,7 +5,15 @@ from inspect import iscoroutinefunction
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast
from autogen_core import AgentRuntime, Component, ComponentModel
from autogen_core.models import AssistantMessage, ChatCompletionClient, ModelFamily, SystemMessage, UserMessage
from autogen_core.logging import LLMStreamEndEvent
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
CreateResult,
ModelFamily,
SystemMessage,
UserMessage,
)
from pydantic import BaseModel
from typing_extensions import Self
@ -55,6 +63,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
selector_func: Optional[SelectorFuncType],
max_selector_attempts: int,
candidate_func: Optional[CandidateFuncType],
streaming: bool = False,
) -> None:
super().__init__(
name,
@ -77,6 +86,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
self._max_selector_attempts = max_selector_attempts
self._candidate_func = candidate_func
self._is_candidate_func_async = iscoroutinefunction(self._candidate_func)
self._streaming = streaming
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
pass
@ -192,7 +202,18 @@ class SelectorGroupChatManager(BaseGroupChatManager):
num_attempts = 0
while num_attempts < max_attempts:
num_attempts += 1
response = await self._model_client.create(messages=select_speaker_messages)
if self._streaming:
message: CreateResult | str = ""
async for _message in self._model_client.create_stream(messages=select_speaker_messages):
if isinstance(_message, LLMStreamEndEvent):
break
message = _message
if isinstance(message, CreateResult):
response = message
else:
raise ValueError("Model failed to select a speaker.")
else:
response = await self._model_client.create(messages=select_speaker_messages)
assert isinstance(response.content, str)
select_speaker_messages.append(AssistantMessage(content=response.content, source="selector"))
# NOTE: we use all participant names to check for mentions, even if the previous speaker is not allowed.
@ -278,6 +299,7 @@ class SelectorGroupChatConfig(BaseModel):
allow_repeated_speaker: bool
# selector_func: ComponentModel | None
max_selector_attempts: int = 3
streaming: bool = False
class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
@ -307,7 +329,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
A custom function that takes the conversation history and returns a filtered list of candidates for the next speaker
selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`.
This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set.
streaming (bool, optional): Whether to use streaming for the model.(Only use for specify case e.g. QwQ) Defaults to False.
Raises:
ValueError: If the number of participants is less than two or if the selector prompt is invalid.
@ -449,6 +471,7 @@ Read the above conversation. Then select the next role from {participants} to pl
selector_func: Optional[SelectorFuncType] = None,
candidate_func: Optional[CandidateFuncType] = None,
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
streaming: bool = False,
):
super().__init__(
participants,
@ -468,6 +491,7 @@ Read the above conversation. Then select the next role from {participants} to pl
self._selector_func = selector_func
self._max_selector_attempts = max_selector_attempts
self._candidate_func = candidate_func
self._streaming = streaming
def _create_group_chat_manager_factory(
self,
@ -499,6 +523,7 @@ Read the above conversation. Then select the next role from {participants} to pl
self._selector_func,
self._max_selector_attempts,
self._candidate_func,
self._streaming,
)
def _to_config(self) -> SelectorGroupChatConfig: