From 52c2a70e95df2006e0094e96ad192243148ec4bb Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Tue, 7 Jan 2025 16:06:14 -0800 Subject: [PATCH] Fix chess sample (#4932) --------- Co-authored-by: Jack Gerrits --- python/samples/core_chess_game/main.py | 53 +++++++------ python/samples/core_chess_game/utils.py | 98 ------------------------- 2 files changed, 30 insertions(+), 121 deletions(-) delete mode 100644 python/samples/core_chess_game/utils.py diff --git a/python/samples/core_chess_game/main.py b/python/samples/core_chess_game/main.py index 78b478775..ccc77deba 100644 --- a/python/samples/core_chess_game/main.py +++ b/python/samples/core_chess_game/main.py @@ -1,6 +1,7 @@ """This is an example of simulating a chess game with two agents that play against each other, using tools to reason about the game state -and make moves, and using a group chat manager to orchestrate the conversation.""" +and make moves. The agents subscribe to the default topic and publish their +moves to the default topic.""" import argparse import asyncio @@ -19,7 +20,12 @@ from autogen_core import ( message_handler, ) from autogen_core.model_context import BufferedChatCompletionContext, ChatCompletionContext -from autogen_core.models import AssistantMessage, ChatCompletionClient, LLMMessage, SystemMessage, UserMessage +from autogen_core.models import ( + ChatCompletionClient, + LLMMessage, + SystemMessage, + UserMessage, +) from autogen_core.tool_agent import ToolAgent, tool_agent_caller_loop from autogen_core.tools import FunctionTool, Tool, ToolSchema from chess import BLACK, SQUARE_NAMES, WHITE, Board, Move @@ -33,7 +39,7 @@ class TextMessage(BaseModel): @default_subscription -class ToolUseAgent(RoutedAgent): +class PlayerAgent(RoutedAgent): def __init__( self, description: str, @@ -59,14 +65,15 @@ class ToolUseAgent(RoutedAgent): self, tool_agent_id=self._tool_agent_id, model_client=self._model_client, - input_messages=(await self._model_context.get_messages()), + input_messages=self._system_messages + (await self._model_context.get_messages()), tool_schema=self._tool_schema, cancellation_token=ctx.cancellation_token, ) - assert isinstance(messages[-1].content, str) # Add the assistant message to the model context. - await self._model_context.add_message(AssistantMessage(content=messages[-1].content, source=self.id.type)) + for msg in messages: + await self._model_context.add_message(msg) # Publish the final response. + assert isinstance(messages[-1].content, str) await self.publish_message(TextMessage(content=messages[-1].content, source=self.id.type), DefaultTopicId()) @@ -203,39 +210,39 @@ async def chess_game(runtime: AgentRuntime, model_config: Dict[str, Any]) -> Non # Register the agents. await ToolAgent.register( runtime, - "ToolAgent", - lambda: ToolAgent(description="Tool agent for chess game.", tools=black_tools + white_tools), + "PlayerBlackToolAgent", + lambda: ToolAgent(description="Tool agent for chess game.", tools=black_tools), ) - await ToolUseAgent.register( + await ToolAgent.register( + runtime, + "PlayerWhiteToolAgent", + lambda: ToolAgent(description="Tool agent for chess game.", tools=white_tools), + ) + + await PlayerAgent.register( runtime, "PlayerBlack", - lambda: ToolUseAgent( + lambda: PlayerAgent( description="Player playing black.", - instructions="You are a chess player and you play as black. " - "Use get_legal_moves() to get list of legal moves. " - "Use get_board() to get the current board state. " - "Think about your strategy and call make_move(thinking, move) to make a move.", + instructions="You are a chess player and you play as black. Use the tool 'get_board' and 'get_legal_moves' to get the legal moves and 'make_move' to make a move.", model_client=model_client, model_context=BufferedChatCompletionContext(buffer_size=10), tool_schema=[tool.schema for tool in black_tools], - tool_agent_type="ToolAgent", + tool_agent_type="PlayerBlackToolAgent", ), ) - await ToolUseAgent.register( + await PlayerAgent.register( runtime, "PlayerWhite", - lambda: ToolUseAgent( + lambda: PlayerAgent( description="Player playing white.", - instructions="You are a chess player and you play as white. " - "Use get_legal_moves() to get list of legal moves. " - "Use get_board() to get the current board state. " - "Think about your strategy and call make_move(thinking, move) to make a move.", + instructions="You are a chess player and you play as white. Use the tool 'get_board' and 'get_legal_moves' to get the legal moves and 'make_move' to make a move.", model_client=model_client, model_context=BufferedChatCompletionContext(buffer_size=10), tool_schema=[tool.schema for tool in white_tools], - tool_agent_type="ToolAgent", + tool_agent_type="PlayerWhiteToolAgent", ), ) @@ -249,7 +256,7 @@ async def main(model_config: Dict[str, Any]) -> None: # orchestration. # Send an initial message to player white to start the game. await runtime.send_message( - TextMessage(content="Game started.", source="System"), + TextMessage(content="Game started, white player your move.", source="System"), AgentId("PlayerWhite", "default"), ) await runtime.stop_when_idle() diff --git a/python/samples/core_chess_game/utils.py b/python/samples/core_chess_game/utils.py deleted file mode 100644 index 5fc21dc20..000000000 --- a/python/samples/core_chess_game/utils.py +++ /dev/null @@ -1,98 +0,0 @@ -from typing import List, Optional, Union - -from autogen_core.models import ( - AssistantMessage, - FunctionExecutionResult, - FunctionExecutionResultMessage, - LLMMessage, - UserMessage, -) -from typing_extensions import Literal - -from .messages import ( - FunctionCallMessage, - Message, - MultiModalMessage, - TextMessage, -) - - -def convert_content_message_to_assistant_message( - message: Union[TextMessage, MultiModalMessage, FunctionCallMessage], - handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error", -) -> Optional[AssistantMessage]: - match message: - case TextMessage() | FunctionCallMessage(): - return AssistantMessage(content=message.content, source=message.source) - case MultiModalMessage(): - if handle_unrepresentable == "error": - raise ValueError("Cannot represent multimodal message as AssistantMessage") - elif handle_unrepresentable == "ignore": - return None - elif handle_unrepresentable == "try_slice": - return AssistantMessage( - content="".join([x for x in message.content if isinstance(x, str)]), - source=message.source, - ) - - -def convert_content_message_to_user_message( - message: Union[TextMessage, MultiModalMessage, FunctionCallMessage], - handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error", -) -> Optional[UserMessage]: - match message: - case TextMessage() | MultiModalMessage(): - return UserMessage(content=message.content, source=message.source) - case FunctionCallMessage(): - if handle_unrepresentable == "error": - raise ValueError("Cannot represent multimodal message as UserMessage") - elif handle_unrepresentable == "ignore": - return None - elif handle_unrepresentable == "try_slice": - # TODO: what is a sliced function call? - raise NotImplementedError("Sliced function calls not yet implemented") - - -def convert_tool_call_response_message( - message: FunctionExecutionResultMessage, - handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error", -) -> Optional[FunctionExecutionResultMessage]: - match message: - case FunctionExecutionResultMessage(): - return FunctionExecutionResultMessage( - content=[FunctionExecutionResult(content=x.content, call_id=x.call_id) for x in message.content] - ) - - -def convert_messages_to_llm_messages( - messages: List[Message], - self_name: str, - handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error", -) -> List[LLMMessage]: - result: List[LLMMessage] = [] - for message in messages: - match message: - case ( - TextMessage(content=_, source=source) - | MultiModalMessage(content=_, source=source) - | FunctionCallMessage(content=_, source=source) - ) if source == self_name: - converted_message_1 = convert_content_message_to_assistant_message(message, handle_unrepresentable) - if converted_message_1 is not None: - result.append(converted_message_1) - case ( - TextMessage(content=_, source=source) - | MultiModalMessage(content=_, source=source) - | FunctionCallMessage(content=_, source=source) - ) if source != self_name: - converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable) - if converted_message_2 is not None: - result.append(converted_message_2) - case FunctionExecutionResultMessage(content=_): - converted_message_3 = convert_tool_call_response_message(message, handle_unrepresentable) - if converted_message_3 is not None: - result.append(converted_message_3) - case _: - raise AssertionError("unreachable") - - return result