mirror of https://github.com/microsoft/autogen.git
Fix chess sample (#4932)
--------- Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
parent
f113c9a959
commit
52c2a70e95
|
@ -1,6 +1,7 @@
|
||||||
"""This is an example of simulating a chess game with two agents
|
"""This is an example of simulating a chess game with two agents
|
||||||
that play against each other, using tools to reason about the game state
|
that play against each other, using tools to reason about the game state
|
||||||
and make moves, and using a group chat manager to orchestrate the conversation."""
|
and make moves. The agents subscribe to the default topic and publish their
|
||||||
|
moves to the default topic."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
|
@ -19,7 +20,12 @@ from autogen_core import (
|
||||||
message_handler,
|
message_handler,
|
||||||
)
|
)
|
||||||
from autogen_core.model_context import BufferedChatCompletionContext, ChatCompletionContext
|
from autogen_core.model_context import BufferedChatCompletionContext, ChatCompletionContext
|
||||||
from autogen_core.models import AssistantMessage, ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
|
from autogen_core.models import (
|
||||||
|
ChatCompletionClient,
|
||||||
|
LLMMessage,
|
||||||
|
SystemMessage,
|
||||||
|
UserMessage,
|
||||||
|
)
|
||||||
from autogen_core.tool_agent import ToolAgent, tool_agent_caller_loop
|
from autogen_core.tool_agent import ToolAgent, tool_agent_caller_loop
|
||||||
from autogen_core.tools import FunctionTool, Tool, ToolSchema
|
from autogen_core.tools import FunctionTool, Tool, ToolSchema
|
||||||
from chess import BLACK, SQUARE_NAMES, WHITE, Board, Move
|
from chess import BLACK, SQUARE_NAMES, WHITE, Board, Move
|
||||||
|
@ -33,7 +39,7 @@ class TextMessage(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@default_subscription
|
@default_subscription
|
||||||
class ToolUseAgent(RoutedAgent):
|
class PlayerAgent(RoutedAgent):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
description: str,
|
description: str,
|
||||||
|
@ -59,14 +65,15 @@ class ToolUseAgent(RoutedAgent):
|
||||||
self,
|
self,
|
||||||
tool_agent_id=self._tool_agent_id,
|
tool_agent_id=self._tool_agent_id,
|
||||||
model_client=self._model_client,
|
model_client=self._model_client,
|
||||||
input_messages=(await self._model_context.get_messages()),
|
input_messages=self._system_messages + (await self._model_context.get_messages()),
|
||||||
tool_schema=self._tool_schema,
|
tool_schema=self._tool_schema,
|
||||||
cancellation_token=ctx.cancellation_token,
|
cancellation_token=ctx.cancellation_token,
|
||||||
)
|
)
|
||||||
assert isinstance(messages[-1].content, str)
|
|
||||||
# Add the assistant message to the model context.
|
# Add the assistant message to the model context.
|
||||||
await self._model_context.add_message(AssistantMessage(content=messages[-1].content, source=self.id.type))
|
for msg in messages:
|
||||||
|
await self._model_context.add_message(msg)
|
||||||
# Publish the final response.
|
# Publish the final response.
|
||||||
|
assert isinstance(messages[-1].content, str)
|
||||||
await self.publish_message(TextMessage(content=messages[-1].content, source=self.id.type), DefaultTopicId())
|
await self.publish_message(TextMessage(content=messages[-1].content, source=self.id.type), DefaultTopicId())
|
||||||
|
|
||||||
|
|
||||||
|
@ -203,39 +210,39 @@ async def chess_game(runtime: AgentRuntime, model_config: Dict[str, Any]) -> Non
|
||||||
# Register the agents.
|
# Register the agents.
|
||||||
await ToolAgent.register(
|
await ToolAgent.register(
|
||||||
runtime,
|
runtime,
|
||||||
"ToolAgent",
|
"PlayerBlackToolAgent",
|
||||||
lambda: ToolAgent(description="Tool agent for chess game.", tools=black_tools + white_tools),
|
lambda: ToolAgent(description="Tool agent for chess game.", tools=black_tools),
|
||||||
)
|
)
|
||||||
|
|
||||||
await ToolUseAgent.register(
|
await ToolAgent.register(
|
||||||
|
runtime,
|
||||||
|
"PlayerWhiteToolAgent",
|
||||||
|
lambda: ToolAgent(description="Tool agent for chess game.", tools=white_tools),
|
||||||
|
)
|
||||||
|
|
||||||
|
await PlayerAgent.register(
|
||||||
runtime,
|
runtime,
|
||||||
"PlayerBlack",
|
"PlayerBlack",
|
||||||
lambda: ToolUseAgent(
|
lambda: PlayerAgent(
|
||||||
description="Player playing black.",
|
description="Player playing black.",
|
||||||
instructions="You are a chess player and you play as black. "
|
instructions="You are a chess player and you play as black. Use the tool 'get_board' and 'get_legal_moves' to get the legal moves and 'make_move' to make a move.",
|
||||||
"Use get_legal_moves() to get list of legal moves. "
|
|
||||||
"Use get_board() to get the current board state. "
|
|
||||||
"Think about your strategy and call make_move(thinking, move) to make a move.",
|
|
||||||
model_client=model_client,
|
model_client=model_client,
|
||||||
model_context=BufferedChatCompletionContext(buffer_size=10),
|
model_context=BufferedChatCompletionContext(buffer_size=10),
|
||||||
tool_schema=[tool.schema for tool in black_tools],
|
tool_schema=[tool.schema for tool in black_tools],
|
||||||
tool_agent_type="ToolAgent",
|
tool_agent_type="PlayerBlackToolAgent",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
await ToolUseAgent.register(
|
await PlayerAgent.register(
|
||||||
runtime,
|
runtime,
|
||||||
"PlayerWhite",
|
"PlayerWhite",
|
||||||
lambda: ToolUseAgent(
|
lambda: PlayerAgent(
|
||||||
description="Player playing white.",
|
description="Player playing white.",
|
||||||
instructions="You are a chess player and you play as white. "
|
instructions="You are a chess player and you play as white. Use the tool 'get_board' and 'get_legal_moves' to get the legal moves and 'make_move' to make a move.",
|
||||||
"Use get_legal_moves() to get list of legal moves. "
|
|
||||||
"Use get_board() to get the current board state. "
|
|
||||||
"Think about your strategy and call make_move(thinking, move) to make a move.",
|
|
||||||
model_client=model_client,
|
model_client=model_client,
|
||||||
model_context=BufferedChatCompletionContext(buffer_size=10),
|
model_context=BufferedChatCompletionContext(buffer_size=10),
|
||||||
tool_schema=[tool.schema for tool in white_tools],
|
tool_schema=[tool.schema for tool in white_tools],
|
||||||
tool_agent_type="ToolAgent",
|
tool_agent_type="PlayerWhiteToolAgent",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -249,7 +256,7 @@ async def main(model_config: Dict[str, Any]) -> None:
|
||||||
# orchestration.
|
# orchestration.
|
||||||
# Send an initial message to player white to start the game.
|
# Send an initial message to player white to start the game.
|
||||||
await runtime.send_message(
|
await runtime.send_message(
|
||||||
TextMessage(content="Game started.", source="System"),
|
TextMessage(content="Game started, white player your move.", source="System"),
|
||||||
AgentId("PlayerWhite", "default"),
|
AgentId("PlayerWhite", "default"),
|
||||||
)
|
)
|
||||||
await runtime.stop_when_idle()
|
await runtime.stop_when_idle()
|
||||||
|
|
|
@ -1,98 +0,0 @@
|
||||||
from typing import List, Optional, Union
|
|
||||||
|
|
||||||
from autogen_core.models import (
|
|
||||||
AssistantMessage,
|
|
||||||
FunctionExecutionResult,
|
|
||||||
FunctionExecutionResultMessage,
|
|
||||||
LLMMessage,
|
|
||||||
UserMessage,
|
|
||||||
)
|
|
||||||
from typing_extensions import Literal
|
|
||||||
|
|
||||||
from .messages import (
|
|
||||||
FunctionCallMessage,
|
|
||||||
Message,
|
|
||||||
MultiModalMessage,
|
|
||||||
TextMessage,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_content_message_to_assistant_message(
|
|
||||||
message: Union[TextMessage, MultiModalMessage, FunctionCallMessage],
|
|
||||||
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
|
|
||||||
) -> Optional[AssistantMessage]:
|
|
||||||
match message:
|
|
||||||
case TextMessage() | FunctionCallMessage():
|
|
||||||
return AssistantMessage(content=message.content, source=message.source)
|
|
||||||
case MultiModalMessage():
|
|
||||||
if handle_unrepresentable == "error":
|
|
||||||
raise ValueError("Cannot represent multimodal message as AssistantMessage")
|
|
||||||
elif handle_unrepresentable == "ignore":
|
|
||||||
return None
|
|
||||||
elif handle_unrepresentable == "try_slice":
|
|
||||||
return AssistantMessage(
|
|
||||||
content="".join([x for x in message.content if isinstance(x, str)]),
|
|
||||||
source=message.source,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_content_message_to_user_message(
|
|
||||||
message: Union[TextMessage, MultiModalMessage, FunctionCallMessage],
|
|
||||||
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
|
|
||||||
) -> Optional[UserMessage]:
|
|
||||||
match message:
|
|
||||||
case TextMessage() | MultiModalMessage():
|
|
||||||
return UserMessage(content=message.content, source=message.source)
|
|
||||||
case FunctionCallMessage():
|
|
||||||
if handle_unrepresentable == "error":
|
|
||||||
raise ValueError("Cannot represent multimodal message as UserMessage")
|
|
||||||
elif handle_unrepresentable == "ignore":
|
|
||||||
return None
|
|
||||||
elif handle_unrepresentable == "try_slice":
|
|
||||||
# TODO: what is a sliced function call?
|
|
||||||
raise NotImplementedError("Sliced function calls not yet implemented")
|
|
||||||
|
|
||||||
|
|
||||||
def convert_tool_call_response_message(
|
|
||||||
message: FunctionExecutionResultMessage,
|
|
||||||
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
|
|
||||||
) -> Optional[FunctionExecutionResultMessage]:
|
|
||||||
match message:
|
|
||||||
case FunctionExecutionResultMessage():
|
|
||||||
return FunctionExecutionResultMessage(
|
|
||||||
content=[FunctionExecutionResult(content=x.content, call_id=x.call_id) for x in message.content]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_messages_to_llm_messages(
|
|
||||||
messages: List[Message],
|
|
||||||
self_name: str,
|
|
||||||
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
|
|
||||||
) -> List[LLMMessage]:
|
|
||||||
result: List[LLMMessage] = []
|
|
||||||
for message in messages:
|
|
||||||
match message:
|
|
||||||
case (
|
|
||||||
TextMessage(content=_, source=source)
|
|
||||||
| MultiModalMessage(content=_, source=source)
|
|
||||||
| FunctionCallMessage(content=_, source=source)
|
|
||||||
) if source == self_name:
|
|
||||||
converted_message_1 = convert_content_message_to_assistant_message(message, handle_unrepresentable)
|
|
||||||
if converted_message_1 is not None:
|
|
||||||
result.append(converted_message_1)
|
|
||||||
case (
|
|
||||||
TextMessage(content=_, source=source)
|
|
||||||
| MultiModalMessage(content=_, source=source)
|
|
||||||
| FunctionCallMessage(content=_, source=source)
|
|
||||||
) if source != self_name:
|
|
||||||
converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
|
|
||||||
if converted_message_2 is not None:
|
|
||||||
result.append(converted_message_2)
|
|
||||||
case FunctionExecutionResultMessage(content=_):
|
|
||||||
converted_message_3 = convert_tool_call_response_message(message, handle_unrepresentable)
|
|
||||||
if converted_message_3 is not None:
|
|
||||||
result.append(converted_message_3)
|
|
||||||
case _:
|
|
||||||
raise AssertionError("unreachable")
|
|
||||||
|
|
||||||
return result
|
|
Loading…
Reference in New Issue