Return message history in agentchat (#661)

* update TeamRunResult

* fix line ending in test

* lint

* update team result to list[chatmessage]

---------

Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Leonardo Pinheiro 2024-10-01 10:03:20 +10:00 committed by Jack Gerrits
parent e7342d558c
commit 7fade2d5e7
3 changed files with 28 additions and 22 deletions

View File

@ -1,10 +1,12 @@
from dataclasses import dataclass
from typing import Protocol
from typing import List, Protocol
from autogen_agentchat.agents._base_chat_agent import ChatMessage
@dataclass
class TeamRunResult:
result: str
messages: List[ChatMessage]
class BaseTeam(Protocol):

View File

@ -1,7 +1,7 @@
import asyncio
import uuid
from typing import Callable, List
from autogen_agentchat.agents._base_chat_agent import ChatMessage
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, AgentInstantiationContext, AgentRuntime, AgentType, MessageContext, TopicId
from autogen_core.components import ClosureAgent, TypeSubscription
@ -132,19 +132,20 @@ class RoundRobinGroupChat(BaseTeam):
TypeSubscription(topic_type=team_topic_type, agent_type=group_chat_manager_agent_type.type)
)
# Create a closure agent to recieve the final result.
team_messages = asyncio.Queue[ContentPublishEvent]()
group_chat_messages: List[ChatMessage] = []
async def output_result(
async def collect_group_chat_messages(
_runtime: AgentRuntime, id: AgentId, message: ContentPublishEvent, ctx: MessageContext
) -> None:
await team_messages.put(message)
group_chat_messages.append(message.agent_message)
await ClosureAgent.register(
runtime,
type="output_result",
closure=output_result,
subscriptions=lambda: [TypeSubscription(topic_type=team_topic_type, agent_type="output_result")],
type="collect_group_chat_messages",
closure=collect_group_chat_messages,
subscriptions=lambda: [
TypeSubscription(topic_type=group_topic_type, agent_type="collect_group_chat_messages")
],
)
# Start the runtime.
@ -162,14 +163,4 @@ class RoundRobinGroupChat(BaseTeam):
# Wait for the runtime to stop.
await runtime.stop_when_idle()
# Get the last message from the team.
last_message = None
while not team_messages.empty():
last_message = await team_messages.get()
assert (
last_message is not None
and isinstance(last_message.agent_message, TextMessage)
and isinstance(last_message.agent_message.content, str)
)
return TeamRunResult(last_message.agent_message.content)
return TeamRunResult(messages=group_chat_messages)

View File

@ -95,7 +95,20 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
)
team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent])
result = await team.run("Write a program that prints 'Hello, world!'")
assert result.result == "TERMINATE"
expected_messages = [
"Write a program that prints 'Hello, world!'",
'Here is the program\n ```python\nprint("Hello, world!")\n```',
"Hello, world!",
"TERMINATE",
]
# Normalize the messages to remove \r\n and any leading/trailing whitespace.
normalized_messages = [
msg.content.replace("\r\n", "\n").rstrip("\n") if isinstance(msg.content, str) else msg.content
for msg in result.messages
]
# Assert that all expected messages are in the collected messages
assert normalized_messages == expected_messages
@pytest.mark.asyncio