Declarative BaseChat Agents (#5055)

* v1, make assistant agent declarative

* make head tail context declarative

* update and formatting

* update assistant, format updates

* make websurfer declarative

* update formatting

* move declarative docs to advanced section

* remove tools until implemented

* minor updates to termination conditions

* update docs
This commit is contained in:
Victor Dibia 2025-01-16 22:29:40 -08:00 committed by GitHub
parent 1f22a7b7a1
commit c2a43e84a2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 524 additions and 144 deletions

View File

@ -13,7 +13,7 @@ from typing import (
Sequence,
)
from autogen_core import CancellationToken, FunctionCall
from autogen_core import CancellationToken, Component, ComponentModel, FunctionCall
from autogen_core.memory import Memory
from autogen_core.model_context import (
ChatCompletionContext,
@ -28,6 +28,8 @@ from autogen_core.models import (
UserMessage,
)
from autogen_core.tools import FunctionTool, Tool
from pydantic import BaseModel
from typing_extensions import Self
from .. import EVENT_LOGGER_NAME
from ..base import Handoff as HandoffBase
@ -49,7 +51,21 @@ from ._base_chat_agent import BaseChatAgent
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
class AssistantAgent(BaseChatAgent):
class AssistantAgentConfig(BaseModel):
    """The declarative configuration for the assistant agent."""

    # Name of the agent.
    name: str
    # Serialized component spec of the model client.
    model_client: ComponentModel
    # tools: List[Any] | None = None # TBD
    # Handoff definitions, given either as handoff objects or as plain strings.
    handoffs: List[HandoffBase | str] | None = None
    # Serialized component spec of the model context, if any.
    model_context: ComponentModel | None = None
    # Description of the agent.
    description: str
    # System message text; None when the agent has no plain-string system message.
    system_message: str | None = None
    # Value of the agent's reflect-on-tool-use flag.
    reflect_on_tool_use: bool
    # Format string used for tool call summaries.
    tool_call_summary_format: str
class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
"""An agent that provides assistance with tool use.
The :meth:`on_messages` returns a :class:`~autogen_agentchat.base.Response`
@ -229,6 +245,9 @@ class AssistantAgent(BaseChatAgent):
See `o1 beta limitations <https://platform.openai.com/docs/guides/reasoning#beta-limitations>`_ for more details.
"""
component_config_schema = AssistantAgentConfig
component_provider_override = "autogen_agentchat.agents.AssistantAgent"
def __init__(
self,
name: str,
@ -462,3 +481,40 @@ class AssistantAgent(BaseChatAgent):
assistant_agent_state = AssistantAgentState.model_validate(state)
# Load the model context state.
await self._model_context.load_state(assistant_agent_state.llm_context)
def _to_config(self) -> AssistantAgentConfig:
    """Build the declarative config that describes this assistant agent.

    Returns:
        AssistantAgentConfig: The serializable configuration.

    Raises:
        NotImplementedError: If the agent has any tools, since tool
            serialization is not implemented yet.
    """
    # TBD: implement tool serialization and remove this guard.
    if self._tools:
        raise NotImplementedError("Serializing tools is not implemented yet.")

    # Only the first system message is exported, and only when its content
    # is a plain string.
    exported_system_message = None
    if self._system_messages:
        first_content = self._system_messages[0].content
        if isinstance(first_content, str):
            exported_system_message = first_content

    return AssistantAgentConfig(
        name=self.name,
        model_client=self._model_client.dump_component(),
        # tools=[], # TBD
        handoffs=list(self._handoffs.values()),
        model_context=self._model_context.dump_component(),
        description=self.description,
        system_message=exported_system_message,
        reflect_on_tool_use=self._reflect_on_tool_use,
        tool_call_summary_format=self._tool_call_summary_format,
    )
@classmethod
def _from_config(cls, config: AssistantAgentConfig) -> Self:
    """Create an assistant agent from a declarative config.

    Args:
        config: The declarative configuration to build the agent from.

    Returns:
        Self: A new agent built from ``config``.
    """
    # Restore the serialized model context instead of discarding it, so a
    # dump/load round-trip keeps the context type and its settings. A config
    # without a model context still passes None to the constructor.
    restored_context = (
        ChatCompletionContext.load_component(config.model_context) if config.model_context is not None else None
    )
    return cls(
        name=config.name,
        model_client=ChatCompletionClient.load_component(config.model_client),
        # tools=[], # TBD
        handoffs=config.handoffs,
        model_context=restored_context,
        description=config.description,
        system_message=config.system_message,
        reflect_on_tool_use=config.reflect_on_tool_use,
        tool_call_summary_format=config.tool_call_summary_format,
    )

View File

@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, List, Mapping, Sequence
from autogen_core import CancellationToken
from autogen_core import CancellationToken, ComponentBase
from pydantic import BaseModel
from ..base import ChatAgent, Response, TaskResult
from ..messages import (
@ -13,7 +14,7 @@ from ..messages import (
from ..state import BaseState
class BaseChatAgent(ChatAgent, ABC):
class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
"""Base class for a chat agent.
This abstract class provides a base implementation for a :class:`ChatAgent`.
@ -35,6 +36,8 @@ class BaseChatAgent(ChatAgent, ABC):
This design principle must be followed when creating a new agent.
"""
component_type = "agent"
def __init__(self, name: str, description: str) -> None:
self._name = name
if self._name.isidentifier() is False:

View File

@ -5,7 +5,9 @@ from contextvars import ContextVar
from inspect import iscoroutinefunction
from typing import Any, AsyncGenerator, Awaitable, Callable, ClassVar, Generator, Optional, Sequence, Union, cast
from autogen_core import CancellationToken
from autogen_core import CancellationToken, Component
from pydantic import BaseModel
from typing_extensions import Self
from ..base import Response
from ..messages import AgentEvent, ChatMessage, HandoffMessage, TextMessage, UserInputRequestedEvent
@ -24,7 +26,15 @@ async def cancellable_input(prompt: str, cancellation_token: Optional[Cancellati
return await task
class UserProxyAgent(BaseChatAgent):
class UserProxyAgentConfig(BaseModel):
    """Declarative configuration for the UserProxyAgent."""

    # Name of the agent.
    name: str
    # Description of the agent.
    description: str = "A human user"
    # Placeholder for a serialized input function; the agent currently always
    # writes and reads this as None.
    input_func: str | None = None
class UserProxyAgent(BaseChatAgent, Component[UserProxyAgentConfig]):
"""An agent that can represent a human user through an input function.
This agent can be used to represent a human user in a chat system by providing a custom input function.
@ -109,6 +119,10 @@ class UserProxyAgent(BaseChatAgent):
print(f"BaseException: {e}")
"""
component_type = "agent"
component_provider_override = "autogen_agentchat.agents.UserProxyAgent"
component_config_schema = UserProxyAgentConfig
class InputRequestContext:
def __init__(self) -> None:
raise RuntimeError(
@ -218,3 +232,11 @@ class UserProxyAgent(BaseChatAgent):
async def on_reset(self, cancellation_token: Optional[CancellationToken] = None) -> None:
"""Reset agent state."""
pass
def _to_config(self) -> UserProxyAgentConfig:
    """Export this agent's name and description as a declarative config.

    The input function is not serialized, so ``input_func`` is always ``None``.
    """
    # TODO: Add ability to serialize input_func
    exported = UserProxyAgentConfig(
        name=self.name,
        description=self.description,
        input_func=None,
    )
    return exported
@classmethod
def _from_config(cls, config: UserProxyAgentConfig) -> Self:
    """Create a user proxy agent from a declarative config.

    Only ``name`` and ``description`` are taken from ``config``; the input
    function is always passed as ``None``.
    """
    agent = cls(
        name=config.name,
        description=config.description,
        input_func=None,
    )
    return agent

View File

@ -48,7 +48,6 @@ class TerminationCondition(ABC, ComponentBase[BaseModel]):
"""
component_type = "termination"
# component_config_schema = BaseModel # type: ignore
@property
@abstractmethod

View File

@ -16,7 +16,6 @@ class StopMessageTerminationConfig(BaseModel):
class StopMessageTermination(TerminationCondition, Component[StopMessageTerminationConfig]):
"""Terminate the conversation if a StopMessage is received."""
component_type = "termination"
component_config_schema = StopMessageTerminationConfig
component_provider_override = "autogen_agentchat.conditions.StopMessageTermination"
@ -58,7 +57,6 @@ class MaxMessageTermination(TerminationCondition, Component[MaxMessageTerminatio
max_messages: The maximum number of messages allowed in the conversation.
"""
component_type = "termination"
component_config_schema = MaxMessageTerminationConfig
component_provider_override = "autogen_agentchat.conditions.MaxMessageTermination"
@ -104,7 +102,6 @@ class TextMentionTermination(TerminationCondition, Component[TextMentionTerminat
text: The text to look for in the messages.
"""
component_type = "termination"
component_config_schema = TextMentionTerminationConfig
component_provider_override = "autogen_agentchat.conditions.TextMentionTermination"
@ -159,7 +156,6 @@ class TokenUsageTermination(TerminationCondition, Component[TokenUsageTerminatio
ValueError: If none of max_total_token, max_prompt_token, or max_completion_token is provided.
"""
component_type = "termination"
component_config_schema = TokenUsageTerminationConfig
component_provider_override = "autogen_agentchat.conditions.TokenUsageTermination"
@ -234,7 +230,6 @@ class HandoffTermination(TerminationCondition, Component[HandoffTerminationConfi
target (str): The target of the handoff message.
"""
component_type = "termination"
component_config_schema = HandoffTerminationConfig
component_provider_override = "autogen_agentchat.conditions.HandoffTermination"
@ -279,7 +274,6 @@ class TimeoutTermination(TerminationCondition, Component[TimeoutTerminationConfi
timeout_seconds: The maximum duration in seconds before terminating the conversation.
"""
component_type = "termination"
component_config_schema = TimeoutTerminationConfig
component_provider_override = "autogen_agentchat.conditions.TimeoutTermination"
@ -339,7 +333,6 @@ class ExternalTermination(TerminationCondition, Component[ExternalTerminationCon
"""
component_type = "termination"
component_config_schema = ExternalTerminationConfig
component_provider_override = "autogen_agentchat.conditions.ExternalTermination"
@ -389,7 +382,6 @@ class SourceMatchTermination(TerminationCondition, Component[SourceMatchTerminat
TerminatedException: If the termination condition has already been reached.
"""
component_type = "termination"
component_config_schema = SourceMatchTerminationConfig
component_provider_override = "autogen_agentchat.conditions.SourceMatchTermination"

View File

@ -592,3 +592,51 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None:
assert not isinstance(BadMemory(), Memory)
assert isinstance(ListMemory(), Memory)
@pytest.mark.asyncio
async def test_assistant_agent_declarative(monkeypatch: pytest.MonkeyPatch) -> None:
    """Round-trip an AssistantAgent through dump/load; dumping with tools must fail."""
    model = "gpt-4o-2024-05-13"

    # Canned completion returned by the mocked OpenAI endpoint.
    usage = CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15)
    reply = ChatCompletionMessage(content="Response to message 3", role="assistant")
    canned_completion = ChatCompletion(
        id="id1",
        choices=[Choice(finish_reason="stop", index=0, message=reply)],
        created=0,
        model=model,
        object="chat.completion",
        usage=usage,
    )
    mock = _MockChatCompletion([canned_completion])
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

    model_context = BufferedChatCompletionContext(buffer_size=2)

    # Agent without tools: serialization and deserialization are supported.
    agent = AssistantAgent(
        "test_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        model_context=model_context,
    )
    agent_config = agent.dump_component()
    assert agent_config.provider == "autogen_agentchat.agents.AssistantAgent"

    agent2 = AssistantAgent.load_component(agent_config)
    assert agent2.name == agent.name

    # Agent with tools: serialization is expected to be rejected.
    tools = [
        _pass_function,
        _fail_function,
        FunctionTool(_echo_function, description="Echo"),
    ]
    agent3 = AssistantAgent(
        "test_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        model_context=model_context,
        tools=tools,
    )
    with pytest.raises(NotImplementedError):
        agent3.dump_component()

View File

@ -11,6 +11,11 @@ from autogen_agentchat.conditions import (
TokenUsageTermination,
)
from autogen_core import ComponentLoader, ComponentModel
from autogen_core.model_context import (
BufferedChatCompletionContext,
HeadAndTailChatCompletionContext,
UnboundedChatCompletionContext,
)
@pytest.mark.asyncio
@ -92,3 +97,35 @@ async def test_termination_declarative() -> None:
# Test loading complex composition
loaded_composite = ComponentLoader.load_component(composite_config)
assert isinstance(loaded_composite, AndTerminationCondition)
@pytest.mark.asyncio
async def test_chat_completion_context_declarative() -> None:
    """Dump and reload each chat completion context type, checking provider and config."""
    # Unbounded context.
    unbounded_config = UnboundedChatCompletionContext().dump_component()
    assert unbounded_config.provider == "autogen_core.model_context.UnboundedChatCompletionContext"
    assert isinstance(
        ComponentLoader.load_component(unbounded_config, UnboundedChatCompletionContext),
        UnboundedChatCompletionContext,
    )

    # Buffered context.
    buffered_config = BufferedChatCompletionContext(buffer_size=5).dump_component()
    assert buffered_config.provider == "autogen_core.model_context.BufferedChatCompletionContext"
    assert buffered_config.config["buffer_size"] == 5
    assert isinstance(
        ComponentLoader.load_component(buffered_config, BufferedChatCompletionContext),
        BufferedChatCompletionContext,
    )

    # Head-and-tail context.
    head_tail_config = HeadAndTailChatCompletionContext(head_size=3, tail_size=2).dump_component()
    assert head_tail_config.provider == "autogen_core.model_context.HeadAndTailChatCompletionContext"
    assert head_tail_config.config["head_size"] == 3
    assert head_tail_config.config["tail_size"] == 2
    assert isinstance(
        ComponentLoader.load_component(head_tail_config, HeadAndTailChatCompletionContext),
        HeadAndTailChatCompletionContext,
    )

View File

@ -66,6 +66,18 @@ Sample code and use cases
How to migrate from AutoGen 0.2.x to 0.4.x.
:::
:::{grid-item-card} {fas}`save;pst-color-primary` Serialize Components
:link: ./serialize-components.html
Serialize and deserialize components
:::
:::{grid-item-card} {fas}`brain;pst-color-primary` Memory
:link: ./memory.html
Add memory capabilities to your agents
:::
::::
```{toctree}
@ -91,8 +103,7 @@ tutorial/human-in-the-loop
tutorial/termination
tutorial/custom-agents
tutorial/state
tutorial/declarative
tutorial/memory
```
```{toctree}
@ -103,6 +114,8 @@ tutorial/memory
selector-group-chat
swarm
magentic-one
memory
serialize-components
```
```{toctree}

View File

@ -0,0 +1,171 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Serializing Components \n",
"\n",
"AutoGen provides a {py:class}`~autogen_core.Component` configuration class that defines behaviours for serializing/deserializing components into declarative specifications. This is useful for debugging, visualizing, and even for sharing your work with others. In this notebook, we will demonstrate how to serialize multiple components to a declarative specification like a JSON file. \n",
"\n",
"\n",
"```{note}\n",
"This is work in progress\n",
"``` \n",
"\n",
"We will be implementing declarative support for the following components:\n",
"\n",
"- Termination conditions ✔️\n",
"- Tools \n",
"- Agents \n",
"- Teams \n",
"\n",
"\n",
"### Termination Condition Example \n",
"\n",
"In the example below, we will define termination conditions (a part of an agent team) in python, export this to a dictionary/json and also demonstrate how the termination condition object can be loaded from the dictionary/json. \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Config: {\"provider\":\"autogen_agentchat.base.OrTerminationCondition\",\"component_type\":\"termination\",\"version\":1,\"component_version\":1,\"description\":null,\"config\":{\"conditions\":[{\"provider\":\"autogen_agentchat.conditions.MaxMessageTermination\",\"component_type\":\"termination\",\"version\":1,\"component_version\":1,\"config\":{\"max_messages\":5}},{\"provider\":\"autogen_agentchat.conditions.StopMessageTermination\",\"component_type\":\"termination\",\"version\":1,\"component_version\":1,\"config\":{}}]}}\n"
]
}
],
"source": [
"from autogen_agentchat.conditions import MaxMessageTermination, StopMessageTermination\n",
"\n",
"max_termination = MaxMessageTermination(5)\n",
"stop_termination = StopMessageTermination()\n",
"\n",
"or_termination = max_termination | stop_termination\n",
"\n",
"or_term_config = or_termination.dump_component()\n",
"print(\"Config: \", or_term_config.model_dump_json())\n",
"\n",
"new_or_termination = or_termination.load_component(or_term_config)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Agent Example \n",
"\n",
"In the example below, we will define an agent in python, export this to a dictionary/json and also demonstrate how the agent object can be loaded from the dictionary/json."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from autogen_agentchat.agents import AssistantAgent, UserProxyAgent\n",
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
"\n",
"# Create an agent that uses the OpenAI GPT-4o model.\n",
"model_client = OpenAIChatCompletionClient(\n",
" model=\"gpt-4o\",\n",
" # api_key=\"YOUR_API_KEY\",\n",
")\n",
"agent = AssistantAgent(\n",
" name=\"assistant\",\n",
" model_client=model_client,\n",
" handoffs=[\"flights_refunder\", \"user\"],\n",
" # tools=[], # serializing tools is not yet supported\n",
" system_message=\"Use tools to solve tasks.\",\n",
")\n",
"user_proxy = UserProxyAgent(name=\"user\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\"provider\":\"autogen_agentchat.agents.UserProxyAgent\",\"component_type\":\"agent\",\"version\":1,\"component_version\":1,\"description\":null,\"config\":{\"name\":\"user\",\"description\":\"A human user\"}}\n"
]
}
],
"source": [
"user_proxy_config = user_proxy.dump_component() # dump component\n",
"print(user_proxy_config.model_dump_json())\n",
"up_new = user_proxy.load_component(user_proxy_config) # load component"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\"provider\":\"autogen_agentchat.agents.AssistantAgent\",\"component_type\":\"agent\",\"version\":1,\"component_version\":1,\"description\":null,\"config\":{\"name\":\"assistant\",\"model_client\":{\"provider\":\"autogen_ext.models.openai.OpenAIChatCompletionClient\",\"component_type\":\"model\",\"version\":1,\"component_version\":1,\"config\":{\"model\":\"gpt-4o\"}},\"handoffs\":[{\"target\":\"flights_refunder\",\"description\":\"Handoff to flights_refunder.\",\"name\":\"transfer_to_flights_refunder\",\"message\":\"Transferred to flights_refunder, adopting the role of flights_refunder immediately.\"},{\"target\":\"user\",\"description\":\"Handoff to user.\",\"name\":\"transfer_to_user\",\"message\":\"Transferred to user, adopting the role of user immediately.\"}],\"model_context\":{\"provider\":\"autogen_core.model_context.UnboundedChatCompletionContext\",\"component_type\":\"chat_completion_context\",\"version\":1,\"component_version\":1,\"config\":{}},\"description\":\"An agent that provides assistance with ability to use tools.\",\"system_message\":\"Use tools to solve tasks.\",\"reflect_on_tool_use\":false,\"tool_call_summary_format\":\"{result}\"}}\n"
]
}
],
"source": [
"agent_config = agent.dump_component() # dump component\n",
"print(agent_config.model_dump_json())\n",
"agent_new = agent.load_component(agent_config) # load component"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A similar approach can be used to serialize the `MultimodalWebSurfer` agent.\n",
"\n",
"```python\n",
"from autogen_ext.agents.web_surfer import MultimodalWebSurfer\n",
"\n",
"agent = MultimodalWebSurfer(\n",
" name=\"web_surfer\",\n",
" model_client=model_client,\n",
" headless=False,\n",
")\n",
"\n",
"web_surfer_config = agent.dump_component() # dump component\n",
"print(web_surfer_config.model_dump_json())\n",
"\n",
"```"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -1,119 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Declarative Components \n",
"\n",
"AutoGen provides a declarative {py:class}`~autogen_core.Component` configuration class that defines behaviours for declarative import/export. This is useful for debugging, visualizing, and even for sharing your work with others. In this notebook, we will demonstrate how to export a declarative representation of a multiagent team in the form of a JSON file. \n",
"\n",
"\n",
"```{note}\n",
"This is work in progress\n",
"``` \n",
"\n",
"We will be implementing declarative support for the following components:\n",
"\n",
"- Termination conditions ✔️\n",
"- Tools \n",
"- Agents \n",
"- Teams \n",
"\n",
"\n",
"### Termination Condition Example \n",
"\n",
"In the example below, we will define termination conditions (a part of an agent team) in python, export this to a dictionary/json and also demonstrate how the termination condition object can be loaded from the dictionary/json. \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from autogen_agentchat.conditions import MaxMessageTermination, StopMessageTermination\n",
"\n",
"max_termination = MaxMessageTermination(5)\n",
"stop_termination = StopMessageTermination()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"provider='autogen_agentchat.conditions.MaxMessageTermination' component_type='termination' version=1 component_version=1 description=None config={'max_messages': 5}\n"
]
}
],
"source": [
"print(max_termination.dump_component())"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'provider': 'autogen_agentchat.conditions.MaxMessageTermination', 'component_type': 'termination', 'version': 1, 'component_version': 1, 'description': None, 'config': {'max_messages': 5}}\n"
]
}
],
"source": [
"print(max_termination.dump_component().model_dump())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ComponentModel(provider='autogen_agentchat.base.OrTerminationCondition', component_type='termination', version=1, component_version=1, description=None, config={'conditions': [{'provider': 'autogen_agentchat.conditions.MaxMessageTermination', 'component_type': 'termination', 'version': 1, 'component_version': 1, 'config': {'max_messages': 5}}, {'provider': 'autogen_agentchat.conditions.StopMessageTermination', 'component_type': 'termination', 'version': 1, 'component_version': 1, 'config': {}}]})"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"or_termination = max_termination | stop_termination\n",
"or_termination.dump_component()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -1,10 +1,19 @@
from typing import List
from pydantic import BaseModel
from typing_extensions import Self
from .._component_config import Component
from ..models import FunctionExecutionResultMessage, LLMMessage
from ._chat_completion_context import ChatCompletionContext
class BufferedChatCompletionContext(ChatCompletionContext):
class BufferedChatCompletionContextConfig(BaseModel):
    """Declarative configuration for the buffered chat completion context."""

    # The buffer size passed to the context's constructor.
    buffer_size: int
    # Messages to seed the context with, if any.
    initial_messages: List[LLMMessage] | None = None
class BufferedChatCompletionContext(ChatCompletionContext, Component[BufferedChatCompletionContextConfig]):
"""A buffered chat completion context that keeps a view of the last n messages,
where n is the buffer size. The buffer size is set at initialization.
@ -13,6 +22,9 @@ class BufferedChatCompletionContext(ChatCompletionContext):
initial_messages (List[LLMMessage] | None): The initial messages.
"""
component_config_schema = BufferedChatCompletionContextConfig
component_provider_override = "autogen_core.model_context.BufferedChatCompletionContext"
def __init__(self, buffer_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
super().__init__(initial_messages)
if buffer_size <= 0:
@ -27,3 +39,10 @@ class BufferedChatCompletionContext(ChatCompletionContext):
# Remove the first message from the list.
messages = messages[1:]
return messages
def _to_config(self) -> BufferedChatCompletionContextConfig:
    """Export the buffer size and the currently stored messages as a config."""
    exported = BufferedChatCompletionContextConfig(
        buffer_size=self._buffer_size,
        initial_messages=self._messages,
    )
    return exported
@classmethod
def _from_config(cls, config: BufferedChatCompletionContextConfig) -> Self:
    """Create a buffered context from a declarative config.

    Args:
        config: The declarative configuration to build the context from.

    Returns:
        Self: A new context with the configured buffer size and messages.
    """
    # Pass the validated fields directly. ``config.model_dump()`` would
    # recursively convert the message models into plain dicts, so the context
    # would be seeded with dicts instead of LLMMessage objects.
    return cls(buffer_size=config.buffer_size, initial_messages=config.initial_messages)

View File

@ -3,10 +3,11 @@ from typing import Any, List, Mapping
from pydantic import BaseModel, Field
from .._component_config import ComponentBase
from ..models import LLMMessage
class ChatCompletionContext(ABC):
class ChatCompletionContext(ABC, ComponentBase[BaseModel]):
"""An abstract base class for defining the interface of a chat completion context.
A chat completion context lets agents store and retrieve LLM messages.
It can be implemented with different recall strategies.
@ -15,6 +16,8 @@ class ChatCompletionContext(ABC):
initial_messages (List[LLMMessage] | None): The initial messages.
"""
component_type = "chat_completion_context"
def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None:
self._messages: List[LLMMessage] = initial_messages or []

View File

@ -1,11 +1,21 @@
from typing import List
from pydantic import BaseModel
from typing_extensions import Self
from .._component_config import Component
from .._types import FunctionCall
from ..models import AssistantMessage, FunctionExecutionResultMessage, LLMMessage, UserMessage
from ._chat_completion_context import ChatCompletionContext
class HeadAndTailChatCompletionContext(ChatCompletionContext):
class HeadAndTailChatCompletionContextConfig(BaseModel):
    """Declarative configuration for the head-and-tail chat completion context."""

    # The head size passed to the context's constructor.
    head_size: int
    # The tail size passed to the context's constructor.
    tail_size: int
    # Messages to seed the context with, if any.
    initial_messages: List[LLMMessage] | None = None
class HeadAndTailChatCompletionContext(ChatCompletionContext, Component[HeadAndTailChatCompletionContextConfig]):
"""A chat completion context that keeps a view of the first n and last m messages,
where n is the head size and m is the tail size. The head and tail sizes
are set at initialization.
@ -16,6 +26,9 @@ class HeadAndTailChatCompletionContext(ChatCompletionContext):
initial_messages (List[LLMMessage] | None): The initial messages.
"""
component_config_schema = HeadAndTailChatCompletionContextConfig
component_provider_override = "autogen_core.model_context.HeadAndTailChatCompletionContext"
def __init__(self, head_size: int, tail_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
super().__init__(initial_messages)
if head_size <= 0:
@ -52,3 +65,12 @@ class HeadAndTailChatCompletionContext(ChatCompletionContext):
placeholder_messages = [UserMessage(content=f"Skipped {num_skipped} messages.", source="System")]
return head_messages + placeholder_messages + tail_messages
def _to_config(self) -> HeadAndTailChatCompletionContextConfig:
    """Export the head/tail sizes and the currently stored messages as a config."""
    exported = HeadAndTailChatCompletionContextConfig(
        head_size=self._head_size,
        tail_size=self._tail_size,
        initial_messages=self._messages,
    )
    return exported
@classmethod
def _from_config(cls, config: HeadAndTailChatCompletionContextConfig) -> Self:
    """Create a head-and-tail context from a declarative config."""
    context = cls(
        head_size=config.head_size,
        tail_size=config.tail_size,
        initial_messages=config.initial_messages,
    )
    return context

View File

@ -1,12 +1,30 @@
from typing import List
from pydantic import BaseModel
from typing_extensions import Self
from .._component_config import Component
from ..models import LLMMessage
from ._chat_completion_context import ChatCompletionContext
class UnboundedChatCompletionContext(ChatCompletionContext):
class UnboundedChatCompletionContextConfig(BaseModel):
    """Declarative configuration for the unbounded chat completion context."""

    # Messages to seed the context with. Optional, so a config serialized
    # without this field (an empty ``{}``) still loads.
    initial_messages: List[LLMMessage] | None = None


class UnboundedChatCompletionContext(ChatCompletionContext, Component[UnboundedChatCompletionContextConfig]):
    """An unbounded chat completion context that keeps a view of all the messages."""

    component_config_schema = UnboundedChatCompletionContextConfig
    component_provider_override = "autogen_core.model_context.UnboundedChatCompletionContext"

    async def get_messages(self) -> List[LLMMessage]:
        """Get all the messages stored in the context."""
        return self._messages

    def _to_config(self) -> UnboundedChatCompletionContextConfig:
        """Export the currently stored messages as a declarative config."""
        # Include the stored messages so a dump/load round-trip does not
        # silently drop them, matching the buffered and head-and-tail contexts.
        return UnboundedChatCompletionContextConfig(initial_messages=self._messages)

    @classmethod
    def _from_config(cls, config: UnboundedChatCompletionContextConfig) -> Self:
        """Create an unbounded context seeded with the config's messages, if any."""
        return cls(initial_messages=config.initial_messages)

View File

@ -24,7 +24,7 @@ import PIL.Image
from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, TextMessage
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, FunctionCall
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, Component, ComponentModel, FunctionCall
from autogen_core import Image as AGImage
from autogen_core.models import (
AssistantMessage,
@ -36,6 +36,8 @@ from autogen_core.models import (
)
from PIL import Image
from playwright.async_api import BrowserContext, Download, Page, Playwright, async_playwright
from pydantic import BaseModel
from typing_extensions import Self
from ._events import WebSurferEvent
from ._prompts import WEB_SURFER_OCR_PROMPT, WEB_SURFER_QA_PROMPT, WEB_SURFER_QA_SYSTEM_MESSAGE, WEB_SURFER_TOOL_PROMPT
@ -58,7 +60,23 @@ from ._utils import message_content_to_str
from .playwright_controller import PlaywrightController
class MultimodalWebSurfer(BaseChatAgent):
class MultimodalWebSurferConfig(BaseModel):
    """Declarative configuration for the MultimodalWebSurfer agent."""

    # Name of the agent.
    name: str
    # Serialized component spec of the model client.
    model_client: ComponentModel
    downloads_folder: str | None = None
    # When empty or None, loading falls back to the agent's DEFAULT_DESCRIPTION.
    description: str | None = None
    # Directory joined with screenshot file names when screenshots are saved.
    debug_dir: str | None = None
    headless: bool = True
    # When empty or None, loading falls back to the agent's DEFAULT_START_PAGE.
    start_page: str | None = "https://www.bing.com/"
    animate_actions: bool = False
    to_save_screenshots: bool = False
    use_ocr: bool = False
    browser_channel: str | None = None
    browser_data_dir: str | None = None
    to_resize_viewport: bool = True
class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]):
"""
MultimodalWebSurfer is a multimodal agent that acts as a web surfer that can search the web and visit web pages.
@ -144,6 +162,10 @@ class MultimodalWebSurfer(BaseChatAgent):
asyncio.run(main())
"""
component_type = "agent"
component_config_schema = MultimodalWebSurferConfig
component_provider_override = "autogen_ext.agents.web_surfer.MultimodalWebSurfer"
DEFAULT_DESCRIPTION = """
A helpful assistant with access to a web browser.
Ask them to perform web searches, open pages, and interact with content (e.g., clicking links, scrolling the viewport, etc., filling in form fields, etc.).
@ -242,7 +264,8 @@ class MultimodalWebSurfer(BaseChatAgent):
TOOL_SLEEP,
TOOL_HOVER,
]
self.n_lines_page_text = 50 # Number of lines of text to extract from the page in the absence of OCR
# Number of lines of text to extract from the page in the absence of OCR
self.n_lines_page_text = 50
self.did_lazy_init = False # flag to check if we have initialized the browser
async def _lazy_init(
@ -317,7 +340,8 @@ class MultimodalWebSurfer(BaseChatAgent):
if self.to_save_screenshots:
current_timestamp = "_" + int(time.time()).__str__()
screenshot_png_name = "screenshot" + current_timestamp + ".png"
await self._page.screenshot(path=os.path.join(self.debug_dir, screenshot_png_name))
await self._page.screenshot(path=os.path.join(self.debug_dir, screenshot_png_name)) # type: ignore
self.logger.info(
WebSurferEvent(
source=self.name,
@ -346,6 +370,7 @@ class MultimodalWebSurfer(BaseChatAgent):
if self.to_save_screenshots:
current_timestamp = "_" + int(time.time()).__str__()
screenshot_png_name = "screenshot" + current_timestamp + ".png"
await self._page.screenshot(path=os.path.join(self.debug_dir, screenshot_png_name)) # type: ignore
self.logger.info(
WebSurferEvent(
@ -704,6 +729,7 @@ class MultimodalWebSurfer(BaseChatAgent):
if self.to_save_screenshots:
current_timestamp = "_" + int(time.time()).__str__()
screenshot_png_name = "screenshot" + current_timestamp + ".png"
async with aiofiles.open(os.path.join(self.debug_dir, screenshot_png_name), "wb") as file: # type: ignore
await file.write(new_screenshot) # type: ignore
self.logger.info(
@ -861,3 +887,38 @@ class MultimodalWebSurfer(BaseChatAgent):
scaled_screenshot.close()
assert isinstance(response.content, str)
return response.content
def _to_config(self) -> MultimodalWebSurferConfig:
    """Export this web surfer's settings as a declarative config.

    The model client is serialized via ``dump_component``; every other field
    is copied from the matching instance attribute.
    """
    client_spec = self._model_client.dump_component()
    exported = MultimodalWebSurferConfig(
        name=self.name,
        model_client=client_spec,
        downloads_folder=self.downloads_folder,
        description=self.description,
        debug_dir=self.debug_dir,
        headless=self.headless,
        start_page=self.start_page,
        animate_actions=self.animate_actions,
        to_save_screenshots=self.to_save_screenshots,
        use_ocr=self.use_ocr,
        browser_channel=self.browser_channel,
        browser_data_dir=self.browser_data_dir,
        to_resize_viewport=self.to_resize_viewport,
    )
    return exported
@classmethod
def _from_config(cls, config: MultimodalWebSurferConfig) -> Self:
    """Create a web surfer agent from a declarative config.

    Empty or missing ``description`` and ``start_page`` values are replaced
    by the class defaults.
    """
    client = ChatCompletionClient.load_component(config.model_client)
    # Fall back to the class-level defaults when the config leaves these unset.
    description = config.description or cls.DEFAULT_DESCRIPTION
    start_page = config.start_page or cls.DEFAULT_START_PAGE
    return cls(
        name=config.name,
        model_client=client,
        downloads_folder=config.downloads_folder,
        description=description,
        debug_dir=config.debug_dir,
        headless=config.headless,
        start_page=start_page,
        animate_actions=config.animate_actions,
        to_save_screenshots=config.to_save_screenshots,
        use_ocr=config.use_ocr,
        browser_channel=config.browser_channel,
        browser_data_dir=config.browser_data_dir,
        to_resize_viewport=config.to_resize_viewport,
    )

View File

@ -145,3 +145,38 @@ async def test_run_websurfer(monkeypatch: pytest.MonkeyPatch) -> None:
) # type: ignore
url_after_sleep = agent._page.url # type: ignore
assert url_after_no_tool == url_after_sleep
@pytest.mark.asyncio
async def test_run_websurfer_declarative(monkeypatch: pytest.MonkeyPatch) -> None:
    """Round-trip a MultimodalWebSurfer through dump/load and check its identity."""
    model = "gpt-4o-2024-05-13"

    # Canned completion returned by the mocked OpenAI endpoint.
    usage = CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15)
    reply = ChatCompletionMessage(content="Response to message 3", role="assistant")
    canned_completion = ChatCompletion(
        id="id1",
        choices=[Choice(finish_reason="stop", index=0, message=reply)],
        created=0,
        model=model,
        object="chat.completion",
        usage=usage,
    )
    mock = _MockChatCompletion([canned_completion])
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

    client = OpenAIChatCompletionClient(model=model, api_key="")
    agent = MultimodalWebSurfer("WebSurfer", model_client=client, use_ocr=False)

    # Serialization: provider path and agent name are captured.
    agent_config = agent.dump_component()
    assert agent_config.provider == "autogen_ext.agents.web_surfer.MultimodalWebSurfer"
    assert agent_config.config["name"] == "WebSurfer"

    # Deserialization: the same agent type and name come back.
    loaded_agent = MultimodalWebSurfer.load_component(agent_config)
    assert isinstance(loaded_agent, MultimodalWebSurfer)
    assert loaded_agent.name == "WebSurfer"