autogen/python/packages/autogen-agentchat/tests/test_assistant_agent.py

1107 lines
40 KiB
Python

import json
import logging
from typing import Dict, List
import pytest
from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import Handoff, TaskResult
from autogen_agentchat.messages import (
BaseChatMessage,
HandoffMessage,
MemoryQueryEvent,
ModelClientStreamingChunkEvent,
MultiModalMessage,
StructuredMessage,
TextMessage,
ThoughtEvent,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from autogen_core import ComponentModel, FunctionCall, Image
from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.models import (
AssistantMessage,
CreateResult,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
RequestUsage,
SystemMessage,
UserMessage,
)
from autogen_core.models._model_client import ModelFamily
from autogen_core.tools import BaseTool, FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.models.replay import ReplayChatCompletionClient
from pydantic import BaseModel
from utils import FileLogHandler
# Module-level logging setup for this test file: the AgentChat event logger
# (EVENT_LOGGER_NAME) is set to DEBUG and given a FileLogHandler (from the local
# `utils` test helper) named "test_assistant_agent.log".
# NOTE(review): presumably FileLogHandler writes records to that file — confirm in utils.
logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.setLevel(logging.DEBUG)
logger.addHandler(FileLogHandler("test_assistant_agent.log"))
def _pass_function(input: str) -> str:
    # Synchronous tool stub used by the tests below: the argument is ignored
    # and the constant "pass" is always returned.
    outcome = "pass"
    return outcome
async def _fail_function(input: str) -> str:
    # Coroutine tool stub: the argument is ignored and the coroutine
    # always resolves to the constant "fail".
    outcome = "fail"
    return outcome
async def _echo_function(input: str) -> str:
    # Coroutine tool stub: resolves to exactly the string it was given.
    echoed = input
    return echoed
@pytest.mark.asyncio
async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
    """A single tool call without reflection ends with a tool-call summary.

    Also verifies the messages sent to the model, that ``run_stream`` replays the
    same messages as ``run``, and that agent state survives a save/load round trip.
    """
    replay = ReplayChatCompletionClient(
        [
            CreateResult(
                finish_reason="function_calls",
                content=[FunctionCall(id="1", arguments=json.dumps({"input": "task"}), name="_pass_function")],
                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
                thought="Calling pass function",
                cached=False,
            ),
            "pass",
            "TERMINATE",
        ],
        model_info={
            "function_calling": True,
            "vision": True,
            "json_output": True,
            "family": ModelFamily.GPT_4O,
            "structured_output": True,
        },
    )
    agent = AssistantAgent(
        "tool_use_agent",
        model_client=replay,
        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
    )
    result = await agent.run(task="task")

    # Exactly one model call, carrying the system prompt followed by the task.
    assert len(replay.create_calls) == 1
    sent = replay.create_calls[0]["messages"]
    assert len(sent) == 2
    system_msg, user_msg = sent
    assert isinstance(system_msg, SystemMessage)
    assert system_msg.content == agent._system_messages[0].content  # type: ignore
    assert isinstance(user_msg, UserMessage)
    assert user_msg.content == "task"

    # Emitted sequence: task, thought, tool request, tool execution, summary.
    produced = result.messages
    assert len(produced) == 5
    task_msg, thought, request, execution, summary = produced
    assert isinstance(task_msg, TextMessage)
    assert task_msg.models_usage is None
    assert isinstance(thought, ThoughtEvent)
    assert thought.content == "Calling pass function"
    assert isinstance(request, ToolCallRequestEvent)
    assert request.models_usage is not None
    assert request.models_usage.completion_tokens == 5
    assert request.models_usage.prompt_tokens == 10
    assert isinstance(execution, ToolCallExecutionEvent)
    assert execution.models_usage is None
    assert isinstance(summary, ToolCallSummaryMessage)
    assert summary.content == "pass"
    assert summary.models_usage is None

    # Streaming must replay the same messages, then the same TaskResult.
    replay.reset()
    position = 0
    async for streamed in agent.run_stream(task="task"):
        if not isinstance(streamed, TaskResult):
            assert streamed == produced[position]
            position += 1
            continue
        assert streamed == result

    # Saved state loads into a fresh agent and saves back unchanged.
    saved = await agent.save_state()
    restored = AssistantAgent(
        "tool_use_agent",
        model_client=replay,
        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
    )
    await restored.load_state(saved)
    assert saved == await restored.save_state()
@pytest.mark.asyncio
async def test_run_with_tools_and_reflection() -> None:
    """With ``reflect_on_tool_use=True`` the model is called again after the tool runs.

    Verifies the messages sent on both model calls, the emitted events (the final
    reply is a TextMessage rather than a tool-call summary), streaming parity with
    ``run``, and a save/load state round trip into an identically configured agent.
    """
    model_client = ReplayChatCompletionClient(
        [
            CreateResult(
                finish_reason="function_calls",
                content=[FunctionCall(id="1", arguments=json.dumps({"input": "task"}), name="_pass_function")],
                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
                cached=False,
            ),
            CreateResult(
                finish_reason="stop",
                content="Hello",
                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
                cached=False,
            ),
            CreateResult(
                finish_reason="stop",
                content="TERMINATE",
                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
                cached=False,
            ),
        ],
        model_info={
            "function_calling": True,
            "vision": True,
            "json_output": True,
            "family": ModelFamily.GPT_4O,
            "structured_output": True,
        },
    )
    agent = AssistantAgent(
        "tool_use_agent",
        model_client=model_client,
        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
        reflect_on_tool_use=True,
    )
    result = await agent.run(task="task")

    # Two model calls: the tool-call request and the reflection call.
    assert len(model_client.create_calls) == 2
    # First call: system prompt + task.
    llm_messages = model_client.create_calls[0]["messages"]
    assert len(llm_messages) == 2
    assert isinstance(llm_messages[0], SystemMessage)
    assert llm_messages[0].content == agent._system_messages[0].content  # type: ignore
    assert isinstance(llm_messages[1], UserMessage)
    assert llm_messages[1].content == "task"
    # Second call: additionally carries the tool call and its execution result.
    llm_messages = model_client.create_calls[1]["messages"]
    assert len(llm_messages) == 4
    assert isinstance(llm_messages[0], SystemMessage)
    assert llm_messages[0].content == agent._system_messages[0].content  # type: ignore
    assert isinstance(llm_messages[1], UserMessage)
    assert llm_messages[1].content == "task"
    assert isinstance(llm_messages[2], AssistantMessage)
    assert isinstance(llm_messages[3], FunctionExecutionResultMessage)

    # Emitted sequence: task, tool request, tool execution, reflected reply.
    assert len(result.messages) == 4
    assert isinstance(result.messages[0], TextMessage)
    assert result.messages[0].models_usage is None
    assert isinstance(result.messages[1], ToolCallRequestEvent)
    assert result.messages[1].models_usage is not None
    assert result.messages[1].models_usage.completion_tokens == 5
    assert result.messages[1].models_usage.prompt_tokens == 10
    assert isinstance(result.messages[2], ToolCallExecutionEvent)
    assert result.messages[2].models_usage is None
    assert isinstance(result.messages[3], TextMessage)
    assert result.messages[3].content == "Hello"
    assert result.messages[3].models_usage is not None
    assert result.messages[3].models_usage.completion_tokens == 5
    assert result.messages[3].models_usage.prompt_tokens == 10

    # Test streaming.
    model_client.reset()
    index = 0
    async for message in agent.run_stream(task="task"):
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
            index += 1

    # Test state saving and loading. The restoring agent uses the same
    # configuration (including reflect_on_tool_use) as the agent that saved it.
    state = await agent.save_state()
    agent2 = AssistantAgent(
        "tool_use_agent",
        model_client=model_client,
        tools=[
            _pass_function,
            _fail_function,
            FunctionTool(_echo_function, description="Echo"),
        ],
        reflect_on_tool_use=True,
    )
    await agent2.load_state(state)
    state2 = await agent2.save_state()
    assert state == state2
@pytest.mark.asyncio
async def test_run_with_parallel_tools() -> None:
    """Several tool calls returned in one model response are all executed.

    The summary joins the individual results with newlines. Streaming must replay
    the same messages, and state must survive a save/load round trip.
    """
    parallel_calls = [
        FunctionCall(id="1", arguments=json.dumps({"input": "task1"}), name="_pass_function"),
        FunctionCall(id="2", arguments=json.dumps({"input": "task2"}), name="_pass_function"),
        FunctionCall(id="3", arguments=json.dumps({"input": "task3"}), name="_echo_function"),
    ]
    replay = ReplayChatCompletionClient(
        [
            CreateResult(
                finish_reason="function_calls",
                content=parallel_calls,
                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
                thought="Calling pass and echo functions",
                cached=False,
            ),
            "pass",
            "TERMINATE",
        ],
        model_info={
            "function_calling": True,
            "vision": True,
            "json_output": True,
            "family": ModelFamily.GPT_4O,
            "structured_output": True,
        },
    )
    agent = AssistantAgent(
        "tool_use_agent",
        model_client=replay,
        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
    )
    result = await agent.run(task="task")

    # Emitted sequence: task, thought, tool request, tool execution, summary.
    assert len(result.messages) == 5
    task_msg, thought, request, execution, summary = result.messages
    assert isinstance(task_msg, TextMessage)
    assert task_msg.models_usage is None
    assert isinstance(thought, ThoughtEvent)
    assert thought.content == "Calling pass and echo functions"
    assert isinstance(request, ToolCallRequestEvent)
    assert request.content == [
        FunctionCall(id="1", arguments=r'{"input": "task1"}', name="_pass_function"),
        FunctionCall(id="2", arguments=r'{"input": "task2"}', name="_pass_function"),
        FunctionCall(id="3", arguments=r'{"input": "task3"}', name="_echo_function"),
    ]
    assert request.models_usage is not None
    assert request.models_usage.completion_tokens == 5
    assert request.models_usage.prompt_tokens == 10
    assert isinstance(execution, ToolCallExecutionEvent)
    # Results are checked by membership, not position.
    missing = [
        wanted
        for wanted in (
            FunctionExecutionResult(call_id="1", content="pass", is_error=False, name="_pass_function"),
            FunctionExecutionResult(call_id="2", content="pass", is_error=False, name="_pass_function"),
            FunctionExecutionResult(call_id="3", content="task3", is_error=False, name="_echo_function"),
        )
        if wanted not in execution.content
    ]
    assert not missing
    assert execution.models_usage is None
    assert isinstance(summary, ToolCallSummaryMessage)
    assert summary.content == "pass\npass\ntask3"
    assert summary.models_usage is None

    # Streaming must replay the same messages, then the same TaskResult.
    replay.reset()
    position = 0
    async for streamed in agent.run_stream(task="task"):
        if not isinstance(streamed, TaskResult):
            assert streamed == result.messages[position]
            position += 1
            continue
        assert streamed == result

    # Saved state loads into a fresh agent and saves back unchanged.
    saved = await agent.save_state()
    restored = AssistantAgent(
        "tool_use_agent",
        model_client=replay,
        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
    )
    await restored.load_state(saved)
    assert saved == await restored.save_state()
@pytest.mark.asyncio
async def test_run_with_parallel_tools_with_empty_call_ids() -> None:
    """Parallel tool calls whose ``FunctionCall.id`` is the empty string.

    Three calls with ``id=""`` are replayed (no ``thought`` is set, so no
    ThoughtEvent is expected). Checks the tool request/execution events, the
    newline-joined summary, streaming parity with ``run``, and a save/load
    state round trip.
    """
    # First replayed response requests three tool calls, all with an empty id.
    model_client = ReplayChatCompletionClient(
        [
            CreateResult(
                finish_reason="function_calls",
                content=[
                    FunctionCall(id="", arguments=json.dumps({"input": "task1"}), name="_pass_function"),
                    FunctionCall(id="", arguments=json.dumps({"input": "task2"}), name="_pass_function"),
                    FunctionCall(id="", arguments=json.dumps({"input": "task3"}), name="_echo_function"),
                ],
                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
                cached=False,
            ),
            "pass",
            "TERMINATE",
        ],
        model_info={
            "function_calling": True,
            "vision": True,
            "json_output": True,
            "family": ModelFamily.GPT_4O,
            "structured_output": True,
        },
    )
    agent = AssistantAgent(
        "tool_use_agent",
        model_client=model_client,
        tools=[
            _pass_function,
            _fail_function,
            FunctionTool(_echo_function, description="Echo"),
        ],
    )
    result = await agent.run(task="task")
    # Expected sequence: task, tool request, tool execution, summary.
    assert len(result.messages) == 4
    assert isinstance(result.messages[0], TextMessage)
    assert result.messages[0].models_usage is None
    assert isinstance(result.messages[1], ToolCallRequestEvent)
    assert result.messages[1].content == [
        FunctionCall(id="", arguments=r'{"input": "task1"}', name="_pass_function"),
        FunctionCall(id="", arguments=r'{"input": "task2"}', name="_pass_function"),
        FunctionCall(id="", arguments=r'{"input": "task3"}', name="_echo_function"),
    ]
    assert result.messages[1].models_usage is not None
    assert result.messages[1].models_usage.completion_tokens == 5
    assert result.messages[1].models_usage.prompt_tokens == 10
    assert isinstance(result.messages[2], ToolCallExecutionEvent)
    # With empty call ids the two _pass_function results are indistinguishable,
    # so each expected result is only checked for membership.
    expected_content = [
        FunctionExecutionResult(call_id="", content="pass", is_error=False, name="_pass_function"),
        FunctionExecutionResult(call_id="", content="pass", is_error=False, name="_pass_function"),
        FunctionExecutionResult(call_id="", content="task3", is_error=False, name="_echo_function"),
    ]
    for expected in expected_content:
        assert expected in result.messages[2].content
    assert result.messages[2].models_usage is None
    assert isinstance(result.messages[3], ToolCallSummaryMessage)
    # The summary is the tool results joined by newlines.
    assert result.messages[3].content == "pass\npass\ntask3"
    assert result.messages[3].models_usage is None
    # Test streaming: same messages in order, then the same TaskResult.
    model_client.reset()
    index = 0
    async for message in agent.run_stream(task="task"):
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
            index += 1
    # Test state saving and loading.
    state = await agent.save_state()
    agent2 = AssistantAgent(
        "tool_use_agent",
        model_client=model_client,
        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
    )
    await agent2.load_state(state)
    state2 = await agent2.save_state()
    assert state == state2
@pytest.mark.asyncio
async def test_output_format() -> None:
    """With ``output_content_type`` set, the reply is a typed StructuredMessage.

    Checked for a plain run and again with model-client streaming enabled.
    """

    class AgentResponse(BaseModel):
        response: str
        status: str

    def check_structured_reply(outcome: TaskResult) -> None:
        # Exactly one message: the reply parsed into AgentResponse.
        assert len(outcome.messages) == 1
        reply = outcome.messages[0]
        assert isinstance(reply, StructuredMessage)
        assert isinstance(reply.content, AgentResponse)  # type: ignore[reportUnknownMemberType]
        assert reply.content.response == "Hello"
        assert reply.content.status == "success"

    replay = ReplayChatCompletionClient(
        [
            CreateResult(
                finish_reason="stop",
                content=AgentResponse(response="Hello", status="success").model_dump_json(),
                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
                cached=False,
            ),
        ]
    )
    plain_agent = AssistantAgent(
        "test_agent",
        model_client=replay,
        output_content_type=AgentResponse,
    )
    produced_types = plain_agent.produced_message_types
    assert StructuredMessage[AgentResponse] in produced_types
    assert TextMessage not in produced_types
    check_structured_reply(await plain_agent.run())

    # Same expectations when the model client streams its output.
    streaming_agent = AssistantAgent(
        "test_agent",
        model_client=replay,
        model_client_stream=True,
        output_content_type=AgentResponse,
    )
    replay.reset()
    final: TaskResult | None = None
    async for item in streaming_agent.run_stream():
        if isinstance(item, TaskResult):
            final = item
    assert final is not None
    check_structured_reply(final)
@pytest.mark.asyncio
async def test_reflection_output_format() -> None:
    """Structured output after a tool call, with and without reflection.

    With ``output_content_type`` set and ``reflect_on_tool_use`` left unset, the
    replayed second response (JSON) is expected back as a ``StructuredMessage``,
    both for ``run`` and with streaming enabled. With
    ``reflect_on_tool_use=False`` the tool result is instead returned as a
    ``ToolCallSummaryMessage``.
    """

    class AgentResponse(BaseModel):
        response: str
        status: str

    model_client = ReplayChatCompletionClient(
        [
            CreateResult(
                finish_reason="function_calls",
                content=[FunctionCall(id="1", arguments=json.dumps({"input": "task"}), name="_pass_function")],
                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
                cached=False,
            ),
            AgentResponse(response="Hello", status="success").model_dump_json(),
        ],
        model_info={
            "function_calling": True,
            "vision": True,
            "json_output": True,
            "family": ModelFamily.GPT_4O,
            "structured_output": True,
        },
    )
    # reflect_on_tool_use is deliberately left unset: the assertions below expect
    # the second replayed (JSON) response to come back as a StructuredMessage.
    agent = AssistantAgent(
        "test_agent",
        model_client=model_client,
        output_content_type=AgentResponse,
        tools=[
            _pass_function,
            _fail_function,
        ],
    )
    result = await agent.run()
    assert len(result.messages) == 3
    assert isinstance(result.messages[0], ToolCallRequestEvent)
    assert isinstance(result.messages[1], ToolCallExecutionEvent)
    assert isinstance(result.messages[2], StructuredMessage)
    assert isinstance(result.messages[2].content, AgentResponse)  # type: ignore[reportUnknownMemberType]
    assert result.messages[2].content.response == "Hello"
    assert result.messages[2].content.status == "success"

    # Test streaming (reflect_on_tool_use again left unset).
    agent = AssistantAgent(
        "test_agent",
        model_client=model_client,
        model_client_stream=True,
        output_content_type=AgentResponse,
        tools=[
            _pass_function,
            _fail_function,
        ],
    )
    model_client.reset()
    stream = agent.run_stream()
    stream_result: TaskResult | None = None
    async for message in stream:
        if isinstance(message, TaskResult):
            stream_result = message
    assert stream_result is not None
    assert len(stream_result.messages) == 3
    assert isinstance(stream_result.messages[0], ToolCallRequestEvent)
    assert isinstance(stream_result.messages[1], ToolCallExecutionEvent)
    assert isinstance(stream_result.messages[2], StructuredMessage)
    assert isinstance(stream_result.messages[2].content, AgentResponse)  # type: ignore[reportUnknownMemberType]
    assert stream_result.messages[2].content.response == "Hello"
    assert stream_result.messages[2].content.status == "success"

    # With reflect_on_tool_use=False the tool result itself is returned as a
    # ToolCallSummaryMessage rather than a StructuredMessage.
    model_client.reset()
    agent = AssistantAgent(
        "test_agent",
        model_client=model_client,
        output_content_type=AgentResponse,
        reflect_on_tool_use=False,
        tools=[
            _pass_function,
            _fail_function,
        ],
    )
    result = await agent.run()
    assert len(result.messages) == 3
    assert isinstance(result.messages[0], ToolCallRequestEvent)
    assert isinstance(result.messages[1], ToolCallExecutionEvent)
    assert isinstance(result.messages[2], ToolCallSummaryMessage)
    # _pass_function always returns "pass"; the summary carries that result.
    assert result.messages[2].content == "pass"
@pytest.mark.asyncio
async def test_handoffs() -> None:
    """A handoff tool call ends the turn with a HandoffMessage to the target.

    Also verifies token usage on the request event and streaming parity with ``run``.
    """
    handoff = Handoff(target="agent2")
    handoff_call = FunctionCall(id="1", arguments=json.dumps({}), name=handoff.name)
    replay = ReplayChatCompletionClient(
        [
            CreateResult(
                finish_reason="function_calls",
                content=[handoff_call],
                usage=RequestUsage(prompt_tokens=42, completion_tokens=43),
                cached=False,
            )
        ],
        model_info={
            "function_calling": True,
            "vision": True,
            "json_output": True,
            "family": ModelFamily.GPT_4O,
            "structured_output": True,
        },
    )
    tool_use_agent = AssistantAgent(
        "tool_use_agent",
        model_client=replay,
        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
        handoffs=[handoff],
    )
    assert HandoffMessage in tool_use_agent.produced_message_types

    result = await tool_use_agent.run(task="task")
    # Emitted sequence: task, tool request, tool execution, handoff.
    assert len(result.messages) == 4
    task_msg, request, execution, handoff_msg = result.messages
    assert isinstance(task_msg, TextMessage)
    assert task_msg.models_usage is None
    assert isinstance(request, ToolCallRequestEvent)
    assert request.models_usage is not None
    assert request.models_usage.completion_tokens == 43
    assert request.models_usage.prompt_tokens == 42
    assert isinstance(execution, ToolCallExecutionEvent)
    assert execution.models_usage is None
    assert isinstance(handoff_msg, HandoffMessage)
    assert handoff_msg.content == handoff.message
    assert handoff_msg.target == handoff.target
    assert handoff_msg.models_usage is None

    # Streaming must replay the same messages, then the same TaskResult.
    replay.reset()
    position = 0
    async for streamed in tool_use_agent.run_stream(task="task"):
        if not isinstance(streamed, TaskResult):
            assert streamed == result.messages[position]
            position += 1
            continue
        assert streamed == result
@pytest.mark.asyncio
async def test_custom_handoffs() -> None:
    """A Handoff subclass can supply its own handoff tool.

    The custom tool returns its ``action`` argument, and the resulting
    HandoffMessage content is expected to equal that returned string.
    """
    name = "transfer_to_agent2"
    description = "Handoff to agent2."
    next_action = "next_action"

    # Handoff whose tool is a strict FunctionTool that echoes its `action` argument.
    class TextCommandHandOff(Handoff):
        @property
        def handoff_tool(self) -> BaseTool[BaseModel, BaseModel]:
            """Create a handoff tool from this handoff configuration."""

            def _next_action(action: str) -> str:
                """Returns the action you want the user to perform"""
                return action

            return FunctionTool(_next_action, name=self.name, description=self.description, strict=True)

    handoff = TextCommandHandOff(name=name, description=description, target="agent2")
    # Replay a single response that calls the custom handoff tool.
    model_client = ReplayChatCompletionClient(
        [
            CreateResult(
                finish_reason="function_calls",
                content=[
                    FunctionCall(id="1", arguments=json.dumps({"action": next_action}), name=handoff.name),
                ],
                usage=RequestUsage(prompt_tokens=42, completion_tokens=43),
                cached=False,
            )
        ],
        model_info={
            "function_calling": True,
            "vision": True,
            "json_output": True,
            "family": ModelFamily.GPT_4O,
            "structured_output": True,
        },
    )
    tool_use_agent = AssistantAgent(
        "tool_use_agent",
        model_client=model_client,
        tools=[
            _pass_function,
            _fail_function,
            FunctionTool(_echo_function, description="Echo"),
        ],
        handoffs=[handoff],
    )
    assert HandoffMessage in tool_use_agent.produced_message_types
    result = await tool_use_agent.run(task="task")
    # Expected sequence: task, tool request, tool execution, handoff.
    assert len(result.messages) == 4
    assert isinstance(result.messages[0], TextMessage)
    assert result.messages[0].models_usage is None
    assert isinstance(result.messages[1], ToolCallRequestEvent)
    assert result.messages[1].models_usage is not None
    assert result.messages[1].models_usage.completion_tokens == 43
    assert result.messages[1].models_usage.prompt_tokens == 42
    assert isinstance(result.messages[2], ToolCallExecutionEvent)
    assert result.messages[2].models_usage is None
    assert isinstance(result.messages[3], HandoffMessage)
    # The handoff content is the custom tool's return value, not handoff.message.
    assert result.messages[3].content == next_action
    assert result.messages[3].target == handoff.target
    assert result.messages[3].models_usage is None
    # Test streaming: same messages in order, then the same TaskResult.
    model_client.reset()
    index = 0
    async for message in tool_use_agent.run_stream(task="task"):
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
            index += 1
@pytest.mark.asyncio
async def test_custom_object_handoffs() -> None:
    """Test a handoff whose tool returns an object (a dict) instead of a string.

    The resulting HandoffMessage content is expected to be ``str()`` of the
    dict returned by the tool.
    """
    name = "transfer_to_agent2"
    description = "Handoff to agent2."
    # Expected tool return value: a dict, not a str.
    next_action: Dict[str, str] = {"action": "next_action"}

    # Handoff whose tool wraps its `action` argument in a dict.
    class DictCommandHandOff(Handoff):
        @property
        def handoff_tool(self) -> BaseTool[BaseModel, BaseModel]:
            """Create a handoff tool from this handoff configuration."""

            def _next_action(action: str) -> Dict[str, str]:
                """Returns the action you want the user to perform"""
                return {"action": action}

            return FunctionTool(_next_action, name=self.name, description=self.description, strict=True)

    handoff = DictCommandHandOff(name=name, description=description, target="agent2")
    model_client = ReplayChatCompletionClient(
        [
            CreateResult(
                finish_reason="function_calls",
                content=[
                    # The tool wraps its argument back into {"action": ...}, so passing
                    # next_action as the arguments makes the tool return next_action.
                    FunctionCall(id="1", arguments=json.dumps(next_action), name=handoff.name),
                ],
                usage=RequestUsage(prompt_tokens=42, completion_tokens=43),
                cached=False,
            )
        ],
        model_info={
            "function_calling": True,
            "vision": True,
            "json_output": True,
            "family": ModelFamily.GPT_4O,
            "structured_output": True,
        },
    )
    tool_use_agent = AssistantAgent(
        "tool_use_agent",
        model_client=model_client,
        tools=[
            _pass_function,
            _fail_function,
            FunctionTool(_echo_function, description="Echo"),
        ],
        handoffs=[handoff],
    )
    assert HandoffMessage in tool_use_agent.produced_message_types
    result = await tool_use_agent.run(task="task")
    # Expected sequence: task, tool request, tool execution, handoff.
    assert len(result.messages) == 4
    assert isinstance(result.messages[0], TextMessage)
    assert result.messages[0].models_usage is None
    assert isinstance(result.messages[1], ToolCallRequestEvent)
    assert result.messages[1].models_usage is not None
    assert result.messages[1].models_usage.completion_tokens == 43
    assert result.messages[1].models_usage.prompt_tokens == 42
    assert isinstance(result.messages[2], ToolCallExecutionEvent)
    assert result.messages[2].models_usage is None
    assert isinstance(result.messages[3], HandoffMessage)
    # The handoff content is the string form of the dict the tool returned.
    assert result.messages[3].content == str(next_action)
    assert result.messages[3].target == handoff.target
    assert result.messages[3].models_usage is None
    # Test streaming: same messages in order, then the same TaskResult.
    model_client.reset()
    index = 0
    async for message in tool_use_agent.run_stream(task="task"):
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
            index += 1
@pytest.mark.asyncio
async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
    """A MultiModalMessage (text + image) is accepted as a task and answered."""
    model_client = ReplayChatCompletionClient(["Hello"])
    agent = AssistantAgent(
        name="assistant",
        model_client=model_client,
    )
    # A fixed, hard-coded base64-encoded 1x1 PNG (not randomly generated).
    img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
    result = await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))
    # The task message followed by the single replayed reply.
    assert len(result.messages) == 2
    assert isinstance(result.messages[0], MultiModalMessage)
    assert isinstance(result.messages[1], TextMessage)
    assert result.messages[1].content == "Hello"
@pytest.mark.asyncio
async def test_run_with_structured_task() -> None:
    """A StructuredMessage can be passed as the task; the run yields task + reply."""

    class InputTask(BaseModel):
        input: str
        data: List[str]

    agent = AssistantAgent(
        name="assistant",
        model_client=ReplayChatCompletionClient(["Hello"]),
    )
    payload = InputTask(input="Test", data=["Test1", "Test2"])
    outcome = await agent.run(task=StructuredMessage[InputTask](content=payload, source="user"))
    assert len(outcome.messages) == 2
@pytest.mark.asyncio
async def test_invalid_model_capabilities() -> None:
    """Tools or handoffs on a client without function calling raise ValueError."""
    client = OpenAIChatCompletionClient(
        model="random-model",
        api_key="",
        model_info={
            "vision": False,
            "function_calling": False,
            "json_output": False,
            "family": ModelFamily.UNKNOWN,
            "structured_output": False,
        },
    )
    # Each factory builds an agent with tools or handoffs against the client above
    # (whose model_info disables function calling). Construction is deferred so
    # that it happens inside the pytest.raises block, as does the run.
    agent_factories = [
        lambda: AssistantAgent(
            name="assistant",
            model_client=client,
            tools=[
                _pass_function,
                _fail_function,
                FunctionTool(_echo_function, description="Echo"),
            ],
        ),
        lambda: AssistantAgent(name="assistant", model_client=client, handoffs=["agent2"]),
    ]
    for make_agent in agent_factories:
        with pytest.raises(ValueError):
            agent = make_agent()
            await agent.run(task=TextMessage(source="user", content="Test"))
@pytest.mark.asyncio
async def test_remove_images() -> None:
    """Images are stripped from the model context only for non-vision models.

    ``_get_compatible_context`` is called directly with a multimodal user
    message: for a client with ``vision=False`` the message content becomes a
    string, while for ``vision=True`` the list content (with the image) is kept.
    """
    model = "random-model"
    model_client_1 = OpenAIChatCompletionClient(
        model=model,
        api_key="",
        model_info={
            "vision": False,
            "function_calling": False,
            "json_output": False,
            "family": ModelFamily.UNKNOWN,
            "structured_output": False,
        },
    )
    model_client_2 = OpenAIChatCompletionClient(
        model=model,
        api_key="",
        model_info={
            "vision": True,
            "function_calling": False,
            "json_output": False,
            "family": ModelFamily.UNKNOWN,
            "structured_output": False,
        },
    )
    # A fixed, hard-coded base64-encoded 1x1 PNG.
    img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
    messages: List[LLMMessage] = [
        SystemMessage(content="System.1"),
        UserMessage(content=["User.1", Image.from_base64(img_base64)], source="user.1"),
        AssistantMessage(content="Assistant.1", source="assistant.1"),
        UserMessage(content="User.2", source="user.2"),
    ]

    # Non-vision model: the multimodal user message is flattened to text,
    # keeping its text part.
    agent_1 = AssistantAgent(name="assistant_1", model_client=model_client_1)
    result = agent_1._get_compatible_context(model_client_1, messages)  # type: ignore
    assert len(result) == 4
    assert isinstance(result[1].content, str)
    assert "User.1" in result[1].content

    # Vision model: the list content, including the image, is preserved.
    agent_2 = AssistantAgent(name="assistant_2", model_client=model_client_2)
    result = agent_2._get_compatible_context(model_client_2, messages)  # type: ignore
    assert len(result) == 4
    assert isinstance(result[1].content, list)
    assert any(isinstance(part, Image) for part in result[1].content)
@pytest.mark.asyncio
async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
    """A list of chat messages can be passed as the task to ``run`` and ``run_stream``."""
    replay = ReplayChatCompletionClient(
        [
            CreateResult(
                finish_reason="stop",
                content="Response to message 1",
                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
                cached=False,
            )
        ]
    )
    agent = AssistantAgent(
        "test_agent",
        model_client=replay,
    )
    task_texts = ["Message 1", "Message 2"]
    messages: List[BaseChatMessage] = [TextMessage(content=text, source="user") for text in task_texts]

    result = await agent.run(task=messages)
    # Both input messages come back first, followed by the single reply.
    assert len(result.messages) == 3
    for position, text in enumerate(task_texts):
        echoed = result.messages[position]
        assert isinstance(echoed, TextMessage)
        assert echoed.content == text
        assert echoed.source == "user"
    reply = result.messages[2]
    assert isinstance(reply, TextMessage)
    assert reply.content == "Response to message 1"
    assert reply.source == "test_agent"
    assert reply.models_usage is not None
    assert reply.models_usage.completion_tokens == 5
    assert reply.models_usage.prompt_tokens == 10

    # Streaming with the same task list replays the same messages and TaskResult.
    replay.reset()
    cursor = 0
    async for item in agent.run_stream(task=messages):
        if not isinstance(item, TaskResult):
            assert item == result.messages[cursor]
            cursor += 1
            continue
        assert item == result
@pytest.mark.asyncio
async def test_model_context(monkeypatch: pytest.MonkeyPatch) -> None:
    """A BufferedChatCompletionContext limits which messages reach the model.

    With ``buffer_size=2`` and three task messages, the single model call should
    carry the system message plus only the last two task messages.
    """
    model_client = ReplayChatCompletionClient(["Response to message 3"])
    model_context = BufferedChatCompletionContext(buffer_size=2)
    agent = AssistantAgent(
        "test_agent",
        model_client=model_client,
        model_context=model_context,
    )
    messages = [
        TextMessage(content="Message 1", source="user"),
        TextMessage(content="Message 2", source="user"),
        TextMessage(content="Message 3", source="user"),
    ]
    await agent.run(task=messages)
    # The model_context property returns the context passed to the constructor.
    assert agent.model_context == model_context
    # The model is called once, with only the last two task messages.
    assert len(model_client.create_calls) == 1
    sent = model_client.create_calls[0]["messages"]
    # 2 messages from the buffered context + 1 system message.
    assert len(sent) == 3
    assert isinstance(sent[0], SystemMessage)
    assert isinstance(sent[1], UserMessage)
    assert sent[1].content == "Message 2"
    assert isinstance(sent[2], UserMessage)
    assert sent[2].content == "Message 3"
@pytest.mark.asyncio
async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None:
    """Exercise ListMemory directly and as an AssistantAgent memory.

    Covers:
    - updating a context from an empty memory;
    - adding text, JSON and image content, then querying it;
    - clearing and closing the memory;
    - rejection of a non-Memory ``memory`` argument (TypeError);
    - component dump/load of an agent with memory attached;
    - the MemoryQueryEvent emitted during a run;
    - the ``Memory`` protocol ``isinstance`` check.
    """
    model_client = ReplayChatCompletionClient(["Hello"])
    # Hard-coded base64-encoded PNG used as image memory content.
    b64_image_str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
    # Test basic memory properties and empty context
    memory = ListMemory(name="test_memory")
    assert memory.name == "test_memory"
    empty_context = BufferedChatCompletionContext(buffer_size=2)
    # A memory with no content contributes no results when updating a context.
    empty_results = await memory.update_context(empty_context)
    assert len(empty_results.memories.results) == 0
    # Test various content types
    memory = ListMemory()
    await memory.add(MemoryContent(content="text content", mime_type=MemoryMimeType.TEXT))
    await memory.add(MemoryContent(content={"key": "value"}, mime_type=MemoryMimeType.JSON))
    await memory.add(MemoryContent(content=Image.from_base64(b64_image_str), mime_type=MemoryMimeType.IMAGE))
    # Test query functionality
    query_result = await memory.query(MemoryContent(content="", mime_type=MemoryMimeType.TEXT))
    assert isinstance(query_result, MemoryQueryResult)
    # Should have all three memories we added
    assert len(query_result.results) == 3
    # Test clear and cleanup
    await memory.clear()
    empty_query = await memory.query(MemoryContent(content="", mime_type=MemoryMimeType.TEXT))
    assert len(empty_query.results) == 0
    await memory.close()  # Should not raise
    # Test invalid memory type
    with pytest.raises(TypeError):
        AssistantAgent(
            "test_agent",
            model_client=model_client,
            memory="invalid",  # type: ignore
        )
    # Test with agent
    memory2 = ListMemory()
    await memory2.add(MemoryContent(content="test instruction", mime_type=MemoryMimeType.TEXT))
    agent = AssistantAgent("test_agent", model_client=model_client, memory=[memory2])
    # Test dump and load component with memory
    agent_config: ComponentModel = agent.dump_component()
    assert agent_config.provider == "autogen_agentchat.agents.AssistantAgent"
    agent2 = AssistantAgent.load_component(agent_config)
    # Running the reloaded agent should emit a MemoryQueryEvent with the stored content.
    result = await agent2.run(task="test task")
    assert len(result.messages) > 0
    memory_event = next((msg for msg in result.messages if isinstance(msg, MemoryQueryEvent)), None)
    assert memory_event is not None
    assert len(memory_event.content) > 0
    assert isinstance(memory_event.content[0], MemoryContent)

    # Test memory protocol: an unrelated class is not a Memory; ListMemory is.
    class BadMemory:
        pass

    assert not isinstance(BadMemory(), Memory)
    assert isinstance(ListMemory(), Memory)
@pytest.mark.asyncio
async def test_assistant_agent_declarative() -> None:
    """AssistantAgent can be dumped to, and reloaded from, a component config.

    Covers an agent with a buffered model context and memory (round-tripped) and an
    agent with tools (dumped only).
    """
    replay = ReplayChatCompletionClient(
        ["Response to message 3"],
        model_info={
            "function_calling": True,
            "vision": True,
            "json_output": True,
            "family": ModelFamily.GPT_4O,
            "structured_output": True,
        },
    )
    buffered = BufferedChatCompletionContext(buffer_size=2)
    expected_provider = "autogen_agentchat.agents.AssistantAgent"

    with_memory = AssistantAgent(
        "test_agent",
        model_client=replay,
        model_context=buffered,
        memory=[ListMemory(name="test_memory")],
    )
    memory_config: ComponentModel = with_memory.dump_component()
    assert memory_config.provider == expected_provider
    reloaded = AssistantAgent.load_component(memory_config)
    assert reloaded.name == with_memory.name

    with_tools = AssistantAgent(
        "test_agent",
        model_client=replay,
        model_context=buffered,
        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
    )
    assert with_tools.dump_component().provider == expected_provider
@pytest.mark.asyncio
async def test_model_client_stream() -> None:
    """Streaming mode emits chunk events that reassemble into the final reply."""
    mock_client = ReplayChatCompletionClient(
        [
            "Response to message 3",
        ]
    )
    agent = AssistantAgent(
        "test_agent",
        model_client=mock_client,
        model_client_stream=True,
    )
    chunks: List[str] = []
    final_result: TaskResult | None = None
    async for message in agent.run_stream(task="task"):
        if isinstance(message, TaskResult):
            final_result = message
        elif isinstance(message, ModelClientStreamingChunkEvent):
            chunks.append(message.content)
    # Assert outside the loop so the test fails if no TaskResult was yielded.
    assert final_result is not None
    assert isinstance(final_result.messages[-1], TextMessage)
    assert final_result.messages[-1].content == "Response to message 3"
    assert "".join(chunks) == "Response to message 3"
@pytest.mark.asyncio
async def test_model_client_stream_with_tool_calls() -> None:
    """Streaming with tool calls and reflection.

    The first replayed response requests two tool calls; with reflection the second
    response is streamed as chunks and becomes the final TextMessage.
    """
    mock_client = ReplayChatCompletionClient(
        [
            CreateResult(
                content=[
                    FunctionCall(id="1", name="_pass_function", arguments=r'{"input": "task"}'),
                    FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'),
                ],
                finish_reason="function_calls",
                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
                cached=False,
            ),
            "Example response 2 to task",
        ]
    )
    # Enable tool use on the replay client (it is constructed without model_info here).
    mock_client._model_info["function_calling"] = True  # pyright: ignore
    agent = AssistantAgent(
        "test_agent",
        model_client=mock_client,
        model_client_stream=True,
        reflect_on_tool_use=True,
        tools=[_pass_function, _echo_function],
    )
    chunks: List[str] = []
    final_result: TaskResult | None = None
    async for message in agent.run_stream(task="task"):
        if isinstance(message, TaskResult):
            final_result = message
        elif isinstance(message, ModelClientStreamingChunkEvent):
            chunks.append(message.content)
    # Assert outside the loop so the test fails if no TaskResult was yielded.
    assert final_result is not None
    assert isinstance(final_result.messages[-1], TextMessage)
    assert isinstance(final_result.messages[1], ToolCallRequestEvent)
    assert final_result.messages[-1].content == "Example response 2 to task"
    assert final_result.messages[1].content == [
        FunctionCall(id="1", name="_pass_function", arguments=r'{"input": "task"}'),
        FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'),
    ]
    assert isinstance(final_result.messages[2], ToolCallExecutionEvent)
    assert final_result.messages[2].content == [
        FunctionExecutionResult(call_id="1", content="pass", is_error=False, name="_pass_function"),
        FunctionExecutionResult(call_id="3", content="task", is_error=False, name="_echo_function"),
    ]
    # Only the reflected reply is streamed as chunks.
    assert "".join(chunks) == "Example response 2 to task"