Fix chess sample (#4932)

---------

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
Eric Zhu 2025-01-07 16:06:14 -08:00 committed by GitHub
parent f113c9a959
commit 52c2a70e95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 121 deletions

View File

@ -1,6 +1,7 @@
"""This is an example of simulating a chess game with two agents """This is an example of simulating a chess game with two agents
that play against each other, using tools to reason about the game state that play against each other, using tools to reason about the game state
and make moves, and using a group chat manager to orchestrate the conversation.""" and make moves. The agents subscribe to the default topic and publish their
moves to the default topic."""
import argparse import argparse
import asyncio import asyncio
@ -19,7 +20,12 @@ from autogen_core import (
message_handler, message_handler,
) )
from autogen_core.model_context import BufferedChatCompletionContext, ChatCompletionContext from autogen_core.model_context import BufferedChatCompletionContext, ChatCompletionContext
from autogen_core.models import AssistantMessage, ChatCompletionClient, LLMMessage, SystemMessage, UserMessage from autogen_core.models import (
ChatCompletionClient,
LLMMessage,
SystemMessage,
UserMessage,
)
from autogen_core.tool_agent import ToolAgent, tool_agent_caller_loop from autogen_core.tool_agent import ToolAgent, tool_agent_caller_loop
from autogen_core.tools import FunctionTool, Tool, ToolSchema from autogen_core.tools import FunctionTool, Tool, ToolSchema
from chess import BLACK, SQUARE_NAMES, WHITE, Board, Move from chess import BLACK, SQUARE_NAMES, WHITE, Board, Move
@ -33,7 +39,7 @@ class TextMessage(BaseModel):
@default_subscription @default_subscription
class ToolUseAgent(RoutedAgent): class PlayerAgent(RoutedAgent):
def __init__( def __init__(
self, self,
description: str, description: str,
@ -59,14 +65,15 @@ class ToolUseAgent(RoutedAgent):
self, self,
tool_agent_id=self._tool_agent_id, tool_agent_id=self._tool_agent_id,
model_client=self._model_client, model_client=self._model_client,
input_messages=(await self._model_context.get_messages()), input_messages=self._system_messages + (await self._model_context.get_messages()),
tool_schema=self._tool_schema, tool_schema=self._tool_schema,
cancellation_token=ctx.cancellation_token, cancellation_token=ctx.cancellation_token,
) )
assert isinstance(messages[-1].content, str)
# Add the assistant message to the model context. # Add the assistant message to the model context.
await self._model_context.add_message(AssistantMessage(content=messages[-1].content, source=self.id.type)) for msg in messages:
await self._model_context.add_message(msg)
# Publish the final response. # Publish the final response.
assert isinstance(messages[-1].content, str)
await self.publish_message(TextMessage(content=messages[-1].content, source=self.id.type), DefaultTopicId()) await self.publish_message(TextMessage(content=messages[-1].content, source=self.id.type), DefaultTopicId())
@ -203,39 +210,39 @@ async def chess_game(runtime: AgentRuntime, model_config: Dict[str, Any]) -> Non
# Register the agents. # Register the agents.
await ToolAgent.register( await ToolAgent.register(
runtime, runtime,
"ToolAgent", "PlayerBlackToolAgent",
lambda: ToolAgent(description="Tool agent for chess game.", tools=black_tools + white_tools), lambda: ToolAgent(description="Tool agent for chess game.", tools=black_tools),
) )
await ToolUseAgent.register( await ToolAgent.register(
runtime,
"PlayerWhiteToolAgent",
lambda: ToolAgent(description="Tool agent for chess game.", tools=white_tools),
)
await PlayerAgent.register(
runtime, runtime,
"PlayerBlack", "PlayerBlack",
lambda: ToolUseAgent( lambda: PlayerAgent(
description="Player playing black.", description="Player playing black.",
instructions="You are a chess player and you play as black. " instructions="You are a chess player and you play as black. Use the tool 'get_board' and 'get_legal_moves' to get the legal moves and 'make_move' to make a move.",
"Use get_legal_moves() to get list of legal moves. "
"Use get_board() to get the current board state. "
"Think about your strategy and call make_move(thinking, move) to make a move.",
model_client=model_client, model_client=model_client,
model_context=BufferedChatCompletionContext(buffer_size=10), model_context=BufferedChatCompletionContext(buffer_size=10),
tool_schema=[tool.schema for tool in black_tools], tool_schema=[tool.schema for tool in black_tools],
tool_agent_type="ToolAgent", tool_agent_type="PlayerBlackToolAgent",
), ),
) )
await ToolUseAgent.register( await PlayerAgent.register(
runtime, runtime,
"PlayerWhite", "PlayerWhite",
lambda: ToolUseAgent( lambda: PlayerAgent(
description="Player playing white.", description="Player playing white.",
instructions="You are a chess player and you play as white. " instructions="You are a chess player and you play as white. Use the tool 'get_board' and 'get_legal_moves' to get the legal moves and 'make_move' to make a move.",
"Use get_legal_moves() to get list of legal moves. "
"Use get_board() to get the current board state. "
"Think about your strategy and call make_move(thinking, move) to make a move.",
model_client=model_client, model_client=model_client,
model_context=BufferedChatCompletionContext(buffer_size=10), model_context=BufferedChatCompletionContext(buffer_size=10),
tool_schema=[tool.schema for tool in white_tools], tool_schema=[tool.schema for tool in white_tools],
tool_agent_type="ToolAgent", tool_agent_type="PlayerWhiteToolAgent",
), ),
) )
@ -249,7 +256,7 @@ async def main(model_config: Dict[str, Any]) -> None:
# orchestration. # orchestration.
# Send an initial message to player white to start the game. # Send an initial message to player white to start the game.
await runtime.send_message( await runtime.send_message(
TextMessage(content="Game started.", source="System"), TextMessage(content="Game started, white player your move.", source="System"),
AgentId("PlayerWhite", "default"), AgentId("PlayerWhite", "default"),
) )
await runtime.stop_when_idle() await runtime.stop_when_idle()

View File

@ -1,98 +0,0 @@
from typing import List, Optional, Union
from autogen_core.models import (
AssistantMessage,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
UserMessage,
)
from typing_extensions import Literal
from .messages import (
FunctionCallMessage,
Message,
MultiModalMessage,
TextMessage,
)
def convert_content_message_to_assistant_message(
message: Union[TextMessage, MultiModalMessage, FunctionCallMessage],
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> Optional[AssistantMessage]:
match message:
case TextMessage() | FunctionCallMessage():
return AssistantMessage(content=message.content, source=message.source)
case MultiModalMessage():
if handle_unrepresentable == "error":
raise ValueError("Cannot represent multimodal message as AssistantMessage")
elif handle_unrepresentable == "ignore":
return None
elif handle_unrepresentable == "try_slice":
return AssistantMessage(
content="".join([x for x in message.content if isinstance(x, str)]),
source=message.source,
)
def convert_content_message_to_user_message(
message: Union[TextMessage, MultiModalMessage, FunctionCallMessage],
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> Optional[UserMessage]:
match message:
case TextMessage() | MultiModalMessage():
return UserMessage(content=message.content, source=message.source)
case FunctionCallMessage():
if handle_unrepresentable == "error":
raise ValueError("Cannot represent multimodal message as UserMessage")
elif handle_unrepresentable == "ignore":
return None
elif handle_unrepresentable == "try_slice":
# TODO: what is a sliced function call?
raise NotImplementedError("Sliced function calls not yet implemented")
def convert_tool_call_response_message(
message: FunctionExecutionResultMessage,
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> Optional[FunctionExecutionResultMessage]:
match message:
case FunctionExecutionResultMessage():
return FunctionExecutionResultMessage(
content=[FunctionExecutionResult(content=x.content, call_id=x.call_id) for x in message.content]
)
def convert_messages_to_llm_messages(
messages: List[Message],
self_name: str,
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> List[LLMMessage]:
result: List[LLMMessage] = []
for message in messages:
match message:
case (
TextMessage(content=_, source=source)
| MultiModalMessage(content=_, source=source)
| FunctionCallMessage(content=_, source=source)
) if source == self_name:
converted_message_1 = convert_content_message_to_assistant_message(message, handle_unrepresentable)
if converted_message_1 is not None:
result.append(converted_message_1)
case (
TextMessage(content=_, source=source)
| MultiModalMessage(content=_, source=source)
| FunctionCallMessage(content=_, source=source)
) if source != self_name:
converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
if converted_message_2 is not None:
result.append(converted_message_2)
case FunctionExecutionResultMessage(content=_):
converted_message_3 = convert_tool_call_response_message(message, handle_unrepresentable)
if converted_message_3 is not None:
result.append(converted_message_3)
case _:
raise AssertionError("unreachable")
return result