Fix chess sample (#4932)

---------

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
Eric Zhu 2025-01-07 16:06:14 -08:00 committed by GitHub
parent f113c9a959
commit 52c2a70e95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 121 deletions

View File

@ -1,6 +1,7 @@
"""This is an example of simulating a chess game with two agents
that play against each other, using tools to reason about the game state
and make moves, and using a group chat manager to orchestrate the conversation."""
and make moves. The agents subscribe to the default topic and publish their
moves to the default topic."""
import argparse
import asyncio
@ -19,7 +20,12 @@ from autogen_core import (
message_handler,
)
from autogen_core.model_context import BufferedChatCompletionContext, ChatCompletionContext
from autogen_core.models import AssistantMessage, ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
from autogen_core.models import (
ChatCompletionClient,
LLMMessage,
SystemMessage,
UserMessage,
)
from autogen_core.tool_agent import ToolAgent, tool_agent_caller_loop
from autogen_core.tools import FunctionTool, Tool, ToolSchema
from chess import BLACK, SQUARE_NAMES, WHITE, Board, Move
@ -33,7 +39,7 @@ class TextMessage(BaseModel):
@default_subscription
class ToolUseAgent(RoutedAgent):
class PlayerAgent(RoutedAgent):
def __init__(
self,
description: str,
@ -59,14 +65,15 @@ class ToolUseAgent(RoutedAgent):
self,
tool_agent_id=self._tool_agent_id,
model_client=self._model_client,
input_messages=(await self._model_context.get_messages()),
input_messages=self._system_messages + (await self._model_context.get_messages()),
tool_schema=self._tool_schema,
cancellation_token=ctx.cancellation_token,
)
assert isinstance(messages[-1].content, str)
# Add the assistant message to the model context.
await self._model_context.add_message(AssistantMessage(content=messages[-1].content, source=self.id.type))
for msg in messages:
await self._model_context.add_message(msg)
# Publish the final response.
assert isinstance(messages[-1].content, str)
await self.publish_message(TextMessage(content=messages[-1].content, source=self.id.type), DefaultTopicId())
@ -203,39 +210,39 @@ async def chess_game(runtime: AgentRuntime, model_config: Dict[str, Any]) -> Non
# Register the agents.
await ToolAgent.register(
runtime,
"ToolAgent",
lambda: ToolAgent(description="Tool agent for chess game.", tools=black_tools + white_tools),
"PlayerBlackToolAgent",
lambda: ToolAgent(description="Tool agent for chess game.", tools=black_tools),
)
await ToolUseAgent.register(
await ToolAgent.register(
runtime,
"PlayerWhiteToolAgent",
lambda: ToolAgent(description="Tool agent for chess game.", tools=white_tools),
)
await PlayerAgent.register(
runtime,
"PlayerBlack",
lambda: ToolUseAgent(
lambda: PlayerAgent(
description="Player playing black.",
instructions="You are a chess player and you play as black. "
"Use get_legal_moves() to get list of legal moves. "
"Use get_board() to get the current board state. "
"Think about your strategy and call make_move(thinking, move) to make a move.",
instructions="You are a chess player and you play as black. Use the tool 'get_board' and 'get_legal_moves' to get the legal moves and 'make_move' to make a move.",
model_client=model_client,
model_context=BufferedChatCompletionContext(buffer_size=10),
tool_schema=[tool.schema for tool in black_tools],
tool_agent_type="ToolAgent",
tool_agent_type="PlayerBlackToolAgent",
),
)
await ToolUseAgent.register(
await PlayerAgent.register(
runtime,
"PlayerWhite",
lambda: ToolUseAgent(
lambda: PlayerAgent(
description="Player playing white.",
instructions="You are a chess player and you play as white. "
"Use get_legal_moves() to get list of legal moves. "
"Use get_board() to get the current board state. "
"Think about your strategy and call make_move(thinking, move) to make a move.",
instructions="You are a chess player and you play as white. Use the tool 'get_board' and 'get_legal_moves' to get the legal moves and 'make_move' to make a move.",
model_client=model_client,
model_context=BufferedChatCompletionContext(buffer_size=10),
tool_schema=[tool.schema for tool in white_tools],
tool_agent_type="ToolAgent",
tool_agent_type="PlayerWhiteToolAgent",
),
)
@ -249,7 +256,7 @@ async def main(model_config: Dict[str, Any]) -> None:
# orchestration.
# Send an initial message to player white to start the game.
await runtime.send_message(
TextMessage(content="Game started.", source="System"),
TextMessage(content="Game started, white player your move.", source="System"),
AgentId("PlayerWhite", "default"),
)
await runtime.stop_when_idle()

View File

@ -1,98 +0,0 @@
from typing import List, Optional, Union
from autogen_core.models import (
AssistantMessage,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
UserMessage,
)
from typing_extensions import Literal
from .messages import (
FunctionCallMessage,
Message,
MultiModalMessage,
TextMessage,
)
def convert_content_message_to_assistant_message(
message: Union[TextMessage, MultiModalMessage, FunctionCallMessage],
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> Optional[AssistantMessage]:
match message:
case TextMessage() | FunctionCallMessage():
return AssistantMessage(content=message.content, source=message.source)
case MultiModalMessage():
if handle_unrepresentable == "error":
raise ValueError("Cannot represent multimodal message as AssistantMessage")
elif handle_unrepresentable == "ignore":
return None
elif handle_unrepresentable == "try_slice":
return AssistantMessage(
content="".join([x for x in message.content if isinstance(x, str)]),
source=message.source,
)
def convert_content_message_to_user_message(
message: Union[TextMessage, MultiModalMessage, FunctionCallMessage],
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> Optional[UserMessage]:
match message:
case TextMessage() | MultiModalMessage():
return UserMessage(content=message.content, source=message.source)
case FunctionCallMessage():
if handle_unrepresentable == "error":
raise ValueError("Cannot represent multimodal message as UserMessage")
elif handle_unrepresentable == "ignore":
return None
elif handle_unrepresentable == "try_slice":
# TODO: what is a sliced function call?
raise NotImplementedError("Sliced function calls not yet implemented")
def convert_tool_call_response_message(
message: FunctionExecutionResultMessage,
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> Optional[FunctionExecutionResultMessage]:
match message:
case FunctionExecutionResultMessage():
return FunctionExecutionResultMessage(
content=[FunctionExecutionResult(content=x.content, call_id=x.call_id) for x in message.content]
)
def convert_messages_to_llm_messages(
messages: List[Message],
self_name: str,
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> List[LLMMessage]:
result: List[LLMMessage] = []
for message in messages:
match message:
case (
TextMessage(content=_, source=source)
| MultiModalMessage(content=_, source=source)
| FunctionCallMessage(content=_, source=source)
) if source == self_name:
converted_message_1 = convert_content_message_to_assistant_message(message, handle_unrepresentable)
if converted_message_1 is not None:
result.append(converted_message_1)
case (
TextMessage(content=_, source=source)
| MultiModalMessage(content=_, source=source)
| FunctionCallMessage(content=_, source=source)
) if source != self_name:
converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
if converted_message_2 is not None:
result.append(converted_message_2)
case FunctionExecutionResultMessage(content=_):
converted_message_3 = convert_tool_call_response_message(message, handle_unrepresentable)
if converted_message_3 is not None:
result.append(converted_message_3)
case _:
raise AssertionError("unreachable")
return result