mirror of https://github.com/microsoft/autogen.git
Return message history in agentchat (#661)
* update TeamRunResult * fix line ending in test * lint * update team result to list[chatmessage] --------- Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com> Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
e7342d558c
commit
7fade2d5e7
|
@ -1,10 +1,12 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
from typing import List, Protocol
|
||||
|
||||
from autogen_agentchat.agents._base_chat_agent import ChatMessage
|
||||
|
||||
|
||||
@dataclass
|
||||
class TeamRunResult:
|
||||
result: str
|
||||
messages: List[ChatMessage]
|
||||
|
||||
|
||||
class BaseTeam(Protocol):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
import uuid
|
||||
from typing import Callable, List
|
||||
|
||||
from autogen_agentchat.agents._base_chat_agent import ChatMessage
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core.base import AgentId, AgentInstantiationContext, AgentRuntime, AgentType, MessageContext, TopicId
|
||||
from autogen_core.components import ClosureAgent, TypeSubscription
|
||||
|
@ -132,19 +132,20 @@ class RoundRobinGroupChat(BaseTeam):
|
|||
TypeSubscription(topic_type=team_topic_type, agent_type=group_chat_manager_agent_type.type)
|
||||
)
|
||||
|
||||
# Create a closure agent to recieve the final result.
|
||||
team_messages = asyncio.Queue[ContentPublishEvent]()
|
||||
group_chat_messages: List[ChatMessage] = []
|
||||
|
||||
async def output_result(
|
||||
async def collect_group_chat_messages(
|
||||
_runtime: AgentRuntime, id: AgentId, message: ContentPublishEvent, ctx: MessageContext
|
||||
) -> None:
|
||||
await team_messages.put(message)
|
||||
group_chat_messages.append(message.agent_message)
|
||||
|
||||
await ClosureAgent.register(
|
||||
runtime,
|
||||
type="output_result",
|
||||
closure=output_result,
|
||||
subscriptions=lambda: [TypeSubscription(topic_type=team_topic_type, agent_type="output_result")],
|
||||
type="collect_group_chat_messages",
|
||||
closure=collect_group_chat_messages,
|
||||
subscriptions=lambda: [
|
||||
TypeSubscription(topic_type=group_topic_type, agent_type="collect_group_chat_messages")
|
||||
],
|
||||
)
|
||||
|
||||
# Start the runtime.
|
||||
|
@ -162,14 +163,4 @@ class RoundRobinGroupChat(BaseTeam):
|
|||
# Wait for the runtime to stop.
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
# Get the last message from the team.
|
||||
last_message = None
|
||||
while not team_messages.empty():
|
||||
last_message = await team_messages.get()
|
||||
|
||||
assert (
|
||||
last_message is not None
|
||||
and isinstance(last_message.agent_message, TextMessage)
|
||||
and isinstance(last_message.agent_message.content, str)
|
||||
)
|
||||
return TeamRunResult(last_message.agent_message.content)
|
||||
return TeamRunResult(messages=group_chat_messages)
|
||||
|
|
|
@ -95,7 +95,20 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
)
|
||||
team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent])
|
||||
result = await team.run("Write a program that prints 'Hello, world!'")
|
||||
assert result.result == "TERMINATE"
|
||||
expected_messages = [
|
||||
"Write a program that prints 'Hello, world!'",
|
||||
'Here is the program\n ```python\nprint("Hello, world!")\n```',
|
||||
"Hello, world!",
|
||||
"TERMINATE",
|
||||
]
|
||||
# Normalize the messages to remove \r\n and any leading/trailing whitespace.
|
||||
normalized_messages = [
|
||||
msg.content.replace("\r\n", "\n").rstrip("\n") if isinstance(msg.content, str) else msg.content
|
||||
for msg in result.messages
|
||||
]
|
||||
|
||||
# Assert that all expected messages are in the collected messages
|
||||
assert normalized_messages == expected_messages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
Loading…
Reference in New Issue