mirror of https://github.com/microsoft/autogen.git
Merge 965d9a461b
into eca80ff663
This commit is contained in:
commit
a63f1dd7bc
|
@ -5,7 +5,15 @@ from inspect import iscoroutinefunction
|
||||||
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast
|
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast
|
||||||
|
|
||||||
from autogen_core import AgentRuntime, Component, ComponentModel
|
from autogen_core import AgentRuntime, Component, ComponentModel
|
||||||
from autogen_core.models import AssistantMessage, ChatCompletionClient, ModelFamily, SystemMessage, UserMessage
|
from autogen_core.logging import LLMStreamEndEvent
|
||||||
|
from autogen_core.models import (
|
||||||
|
AssistantMessage,
|
||||||
|
ChatCompletionClient,
|
||||||
|
CreateResult,
|
||||||
|
ModelFamily,
|
||||||
|
SystemMessage,
|
||||||
|
UserMessage,
|
||||||
|
)
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
@ -55,6 +63,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||||
selector_func: Optional[SelectorFuncType],
|
selector_func: Optional[SelectorFuncType],
|
||||||
max_selector_attempts: int,
|
max_selector_attempts: int,
|
||||||
candidate_func: Optional[CandidateFuncType],
|
candidate_func: Optional[CandidateFuncType],
|
||||||
|
streaming: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
name,
|
name,
|
||||||
|
@ -77,6 +86,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||||
self._max_selector_attempts = max_selector_attempts
|
self._max_selector_attempts = max_selector_attempts
|
||||||
self._candidate_func = candidate_func
|
self._candidate_func = candidate_func
|
||||||
self._is_candidate_func_async = iscoroutinefunction(self._candidate_func)
|
self._is_candidate_func_async = iscoroutinefunction(self._candidate_func)
|
||||||
|
self._streaming = streaming
|
||||||
|
|
||||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||||
pass
|
pass
|
||||||
|
@ -192,7 +202,18 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||||
num_attempts = 0
|
num_attempts = 0
|
||||||
while num_attempts < max_attempts:
|
while num_attempts < max_attempts:
|
||||||
num_attempts += 1
|
num_attempts += 1
|
||||||
response = await self._model_client.create(messages=select_speaker_messages)
|
if self._streaming:
|
||||||
|
message: CreateResult | str = ""
|
||||||
|
async for _message in self._model_client.create_stream(messages=select_speaker_messages):
|
||||||
|
if isinstance(_message, LLMStreamEndEvent):
|
||||||
|
break
|
||||||
|
message = _message
|
||||||
|
if isinstance(message, CreateResult):
|
||||||
|
response = message
|
||||||
|
else:
|
||||||
|
raise ValueError("Model failed to select a speaker.")
|
||||||
|
else:
|
||||||
|
response = await self._model_client.create(messages=select_speaker_messages)
|
||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
select_speaker_messages.append(AssistantMessage(content=response.content, source="selector"))
|
select_speaker_messages.append(AssistantMessage(content=response.content, source="selector"))
|
||||||
# NOTE: we use all participant names to check for mentions, even if the previous speaker is not allowed.
|
# NOTE: we use all participant names to check for mentions, even if the previous speaker is not allowed.
|
||||||
|
@ -278,6 +299,7 @@ class SelectorGroupChatConfig(BaseModel):
|
||||||
allow_repeated_speaker: bool
|
allow_repeated_speaker: bool
|
||||||
# selector_func: ComponentModel | None
|
# selector_func: ComponentModel | None
|
||||||
max_selector_attempts: int = 3
|
max_selector_attempts: int = 3
|
||||||
|
streaming: bool = False
|
||||||
|
|
||||||
|
|
||||||
class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
|
class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
|
||||||
|
@ -307,7 +329,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
|
||||||
A custom function that takes the conversation history and returns a filtered list of candidates for the next speaker
|
A custom function that takes the conversation history and returns a filtered list of candidates for the next speaker
|
||||||
selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`.
|
selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`.
|
||||||
This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set.
|
This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set.
|
||||||
|
streaming (bool, optional): Whether to use streaming for the model.(Only use for specify case e.g. QwQ) Defaults to False.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the number of participants is less than two or if the selector prompt is invalid.
|
ValueError: If the number of participants is less than two or if the selector prompt is invalid.
|
||||||
|
@ -449,6 +471,7 @@ Read the above conversation. Then select the next role from {participants} to pl
|
||||||
selector_func: Optional[SelectorFuncType] = None,
|
selector_func: Optional[SelectorFuncType] = None,
|
||||||
candidate_func: Optional[CandidateFuncType] = None,
|
candidate_func: Optional[CandidateFuncType] = None,
|
||||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||||
|
streaming: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
participants,
|
participants,
|
||||||
|
@ -468,6 +491,7 @@ Read the above conversation. Then select the next role from {participants} to pl
|
||||||
self._selector_func = selector_func
|
self._selector_func = selector_func
|
||||||
self._max_selector_attempts = max_selector_attempts
|
self._max_selector_attempts = max_selector_attempts
|
||||||
self._candidate_func = candidate_func
|
self._candidate_func = candidate_func
|
||||||
|
self._streaming = streaming
|
||||||
|
|
||||||
def _create_group_chat_manager_factory(
|
def _create_group_chat_manager_factory(
|
||||||
self,
|
self,
|
||||||
|
@ -499,6 +523,7 @@ Read the above conversation. Then select the next role from {participants} to pl
|
||||||
self._selector_func,
|
self._selector_func,
|
||||||
self._max_selector_attempts,
|
self._max_selector_attempts,
|
||||||
self._candidate_func,
|
self._candidate_func,
|
||||||
|
self._streaming,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _to_config(self) -> SelectorGroupChatConfig:
|
def _to_config(self) -> SelectorGroupChatConfig:
|
||||||
|
|
Loading…
Reference in New Issue