# autogen/python/samples/core_distributed-group-chat/_agents.py

import asyncio
import random
from typing import Awaitable, Callable, List
from uuid import uuid4
from _types import GroupChatMessage, MessageChunk, RequestToSpeak, UIAgentConfig
from autogen_core import DefaultTopicId, MessageContext, RoutedAgent, message_handler
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
LLMMessage,
SystemMessage,
UserMessage,
)
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
from rich.console import Console
from rich.markdown import Markdown


class BaseGroupChatAgent(RoutedAgent):
"""A group chat participant using an LLM."""

    def __init__(
self,
description: str,
group_chat_topic_type: str,
model_client: ChatCompletionClient,
system_message: str,
ui_config: UIAgentConfig,
) -> None:
super().__init__(description=description)
self._group_chat_topic_type = group_chat_topic_type
self._model_client = model_client
self._system_message = SystemMessage(content=system_message)
self._chat_history: List[LLMMessage] = []
self._ui_config = ui_config
self.console = Console()

    @message_handler
async def handle_message(self, message: GroupChatMessage, ctx: MessageContext) -> None:
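        """Record a message broadcast on the group-chat topic in this participant's local chat history."""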
self._chat_history.extend(
[
UserMessage(content=f"Transferred to {message.body.source}", source="system"), # type: ignore[union-attr]
message.body,
]
)

    @message_handler
async def handle_request_to_speak(self, message: RequestToSpeak, ctx: MessageContext) -> None:
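        """Take a turn: nudge the model to adopt this agent's persona, generate a reply from the
        shared history, record it, print it to the console, and publish it to the UI and group chat.
        """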
self._chat_history.append(
UserMessage(content=f"Transferred to {self.id.type}, adopt the persona immediately.", source="system")
)
        completion = await self._model_client.create(
            [self._system_message] + self._chat_history, cancellation_token=ctx.cancellation_token
        )
assert isinstance(completion.content, str)
self._chat_history.append(AssistantMessage(content=completion.content, source=self.id.type))
console_message = f"\n{'-'*80}\n**{self.id.type}**: {completion.content}"
self.console.print(Markdown(console_message))
await publish_message_to_ui_and_backend(
runtime=self,
source=self.id.type,
user_message=completion.content,
ui_config=self._ui_config,
group_chat_topic_type=self._group_chat_topic_type,
)


class GroupChatManager(RoutedAgent):
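    """Orchestrates turn-taking for the group chat: on each group-chat message it asks the model to
    pick the next speaker from the participant topic types (or 'FINISH') and publishes a
    RequestToSpeak to the selected participant's topic.
    """
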
def __init__(
self,
model_client: ChatCompletionClient,
participant_topic_types: List[str],
participant_descriptions: List[str],
ui_config: UIAgentConfig,
max_rounds: int = 3,
) -> None:
super().__init__("Group chat manager")
self._model_client = model_client
self._num_rounds = 0
self._participant_topic_types = participant_topic_types
self._chat_history: List[GroupChatMessage] = []
self._max_rounds = max_rounds
self.console = Console()
self._participant_descriptions = participant_descriptions
self._previous_participant_topic_type: str | None = None
self._ui_config = ui_config

    @message_handler
async def handle_message(self, message: GroupChatMessage, ctx: MessageContext) -> None:
assert isinstance(message.body, UserMessage)
self._chat_history.append(message.body) # type: ignore[reportargumenttype,arg-type]
# Format message history.
messages: List[str] = []
for msg in self._chat_history:
if isinstance(msg.content, str): # type: ignore[attr-defined]
messages.append(f"{msg.source}: {msg.content}") # type: ignore[attr-defined]
elif isinstance(msg.content, list): # type: ignore[attr-defined]
messages.append(f"{msg.source}: {', '.join(msg.content)}") # type: ignore[attr-defined,reportUnknownArgumentType]
history = "\n".join(messages)
# Format roles.
roles = "\n".join(
[
f"{topic_type}: {description}".strip()
for topic_type, description in zip(
self._participant_topic_types, self._participant_descriptions, strict=True
)
if topic_type != self._previous_participant_topic_type
]
)
participants = str(
[
topic_type
for topic_type in self._participant_topic_types
if topic_type != self._previous_participant_topic_type
]
)
selector_prompt = f"""You are in a role play game. The following roles are available:
{roles}.
Read the following conversation. Then select the next role from {participants} to play. Only return the role.
{history}
Read the above conversation. Then select the next role from {participants} to play. If you think there has been enough discussion (for example, they have talked for {self._max_rounds} rounds), return 'FINISH'.
"""
system_message = SystemMessage(content=selector_prompt)
completion = await self._model_client.create([system_message], cancellation_token=ctx.cancellation_token)
assert isinstance(
completion.content, str
), f"Completion content must be a string, but is: {type(completion.content)}"
if completion.content.upper() == "FINISH":
finish_msg = "I think it's enough iterations on the story! Thanks for collaborating!"
manager_message = f"\n{'-'*80}\n Manager ({id(self)}): {finish_msg}"
await publish_message_to_ui(
runtime=self, source=self.id.type, user_message=finish_msg, ui_config=self._ui_config
)
self.console.print(Markdown(manager_message))
return
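        # Otherwise, match the model's reply against the known participant topic types and hand over the turn.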
selected_topic_type: str
for topic_type in self._participant_topic_types:
if topic_type.lower() in completion.content.lower():
selected_topic_type = topic_type
self._previous_participant_topic_type = selected_topic_type
self.console.print(
Markdown(f"\n{'-'*80}\n Manager ({id(self)}): Asking `{selected_topic_type}` to speak")
)
await self.publish_message(RequestToSpeak(), DefaultTopicId(type=selected_topic_type))
return
raise ValueError(f"Invalid role selected: {completion.content}")


class UIAgent(RoutedAgent):
"""Handles UI-related tasks and message processing for the distributed group chat system."""

    def __init__(self, on_message_chunk_func: Callable[[MessageChunk], Awaitable[None]]) -> None:
super().__init__("UI Agent")
self._on_message_chunk_func = on_message_chunk_func

    @message_handler
async def handle_message_chunk(self, message: MessageChunk, ctx: MessageContext) -> None:
await self._on_message_chunk_func(message)


async def publish_message_to_ui(
runtime: RoutedAgent | GrpcWorkerAgentRuntime,
source: str,
user_message: str,
ui_config: UIAgentConfig,
) -> None:
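    """Simulate streaming in the UI: publish the message word by word as MessageChunks on the UI topic,
    pausing for a random delay between chunks, then publish a final chunk marked as finished.
    """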
message_id = str(uuid4())
# Stream the message to UI
message_chunks = (
MessageChunk(message_id=message_id, text=token + " ", author=source, finished=False)
for token in user_message.split()
)
for chunk in message_chunks:
await runtime.publish_message(
chunk,
DefaultTopicId(type=ui_config.topic_type),
)
await asyncio.sleep(random.uniform(ui_config.min_delay, ui_config.max_delay))
await runtime.publish_message(
MessageChunk(message_id=message_id, text=" ", author=source, finished=True),
DefaultTopicId(type=ui_config.topic_type),
)


async def publish_message_to_ui_and_backend(
runtime: RoutedAgent | GrpcWorkerAgentRuntime,
source: str,
user_message: str,
ui_config: UIAgentConfig,
group_chat_topic_type: str,
) -> None:
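    """Publish a message twice: streamed as chunks to the UI topic, and as a GroupChatMessage to the group-chat topic."""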
# Publish messages for ui
await publish_message_to_ui(
runtime=runtime,
source=source,
user_message=user_message,
ui_config=ui_config,
)
# Publish message to backend
await runtime.publish_message(
GroupChatMessage(body=UserMessage(content=user_message, source=source)),
topic_id=DefaultTopicId(type=group_chat_topic_type),
)
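

# The function below is a minimal, hypothetical wiring sketch (not used by the sample's run scripts)
# showing how a participant defined in this module might be attached to an already-started
# GrpcWorkerAgentRuntime: register a BaseGroupChatAgent under its own topic type, then subscribe it to
# that topic (for RequestToSpeak messages) and to the shared group-chat topic (so it sees the
# conversation). The topic names, description, and system message are placeholders, and the exact
# runtime setup/teardown calls may differ between autogen versions.
async def _example_register_participant(
    runtime: GrpcWorkerAgentRuntime,
    model_client: ChatCompletionClient,
    ui_config: UIAgentConfig,
) -> None:
    from autogen_core import TypeSubscription  # local import keeps the sketch self-contained

    participant_topic_type = "editor"  # placeholder topic type
    group_chat_topic_type = "group_chat"  # placeholder topic type

    agent_type = await BaseGroupChatAgent.register(
        runtime,
        participant_topic_type,
        lambda: BaseGroupChatAgent(
            description="An editor that refines the story draft.",  # placeholder
            group_chat_topic_type=group_chat_topic_type,
            model_client=model_client,
            system_message="You are an editor. Improve the draft you are given.",  # placeholder
            ui_config=ui_config,
        ),
    )
    # Deliver RequestToSpeak messages addressed to this participant's own topic ...
    await runtime.add_subscription(TypeSubscription(topic_type=participant_topic_type, agent_type=agent_type.type))
    # ... and fan group-chat messages out to it so its local history stays current.
    await runtime.add_subscription(TypeSubscription(topic_type=group_chat_topic_type, agent_type=agent_type.type))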