diff --git a/python/packages/autogen-agentchat/tests/test_declarative_components.py b/python/packages/autogen-agentchat/tests/test_declarative_components.py
index 09054cdaf..69e5c4ae9 100644
--- a/python/packages/autogen-agentchat/tests/test_declarative_components.py
+++ b/python/packages/autogen-agentchat/tests/test_declarative_components.py
@@ -14,9 +14,10 @@ from autogen_core import ComponentLoader, ComponentModel
 from autogen_core.model_context import (
     BufferedChatCompletionContext,
     HeadAndTailChatCompletionContext,
-    UnboundedChatCompletionContext,
     TokenLimitedChatCompletionContext,
+    UnboundedChatCompletionContext,
 )
+from autogen_ext.models.openai import OpenAIChatCompletionClient
 
 
 @pytest.mark.asyncio
@@ -105,7 +106,8 @@ async def test_chat_completion_context_declarative() -> None:
     unbounded_context = UnboundedChatCompletionContext()
     buffered_context = BufferedChatCompletionContext(buffer_size=5)
     head_tail_context = HeadAndTailChatCompletionContext(head_size=3, tail_size=2)
-    token_limited_context = TokenLimitedChatCompletionContext(token_limit=5, model="gpt-4o")
+    model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="test_key")
+    token_limited_context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=5)
 
     # Test serialization
     unbounded_config = unbounded_context.dump_component()
@@ -123,7 +125,10 @@ async def test_chat_completion_context_declarative() -> None:
     token_limited_config = token_limited_context.dump_component()
     assert token_limited_config.provider == "autogen_core.model_context.TokenLimitedChatCompletionContext"
     assert token_limited_config.config["token_limit"] == 5
-    assert token_limited_config.config["model"] == "gpt-4o"
+    assert (
+        token_limited_config.config["model_client"]["provider"]
+        == "autogen_ext.models.openai.OpenAIChatCompletionClient"
+    )
 
     # Test deserialization
     loaded_unbounded = ComponentLoader.load_component(unbounded_config, UnboundedChatCompletionContext)
diff --git a/python/packages/autogen-core/src/autogen_core/model_context/__init__.py b/python/packages/autogen-core/src/autogen_core/model_context/__init__.py
index 513613e6e..b6898614e 100644
--- a/python/packages/autogen-core/src/autogen_core/model_context/__init__.py
+++ b/python/packages/autogen-core/src/autogen_core/model_context/__init__.py
@@ -1,7 +1,7 @@
 from ._buffered_chat_completion_context import BufferedChatCompletionContext
-from ._token_limited_chat_completion_context import TokenLimitedChatCompletionContext
 from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState
 from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
+from ._token_limited_chat_completion_context import TokenLimitedChatCompletionContext
 from ._unbounded_chat_completion_context import (
     UnboundedChatCompletionContext,
 )
diff --git a/python/packages/autogen-core/src/autogen_core/model_context/_buffered_chat_completion_context.py b/python/packages/autogen-core/src/autogen_core/model_context/_buffered_chat_completion_context.py
index dcece60b1..5d23f818a 100644
--- a/python/packages/autogen-core/src/autogen_core/model_context/_buffered_chat_completion_context.py
+++ b/python/packages/autogen-core/src/autogen_core/model_context/_buffered_chat_completion_context.py
@@ -41,7 +41,9 @@ class BufferedChatCompletionContext(ChatCompletionContext, Component[BufferedCha
         return messages
 
     def _to_config(self) -> BufferedChatCompletionContextConfig:
-        return BufferedChatCompletionContextConfig(buffer_size=self._buffer_size, initial_messages=self._messages)
+        return BufferedChatCompletionContextConfig(
+            buffer_size=self._buffer_size, initial_messages=self._initial_messages
+        )
 
     @classmethod
     def _from_config(cls, config: BufferedChatCompletionContextConfig) -> Self:
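For reference, a minimal sketch of the declarative round trip that the updated test above exercises; the model name and api_key are placeholder values, and the provider assertion mirrors the one added in the test:

from autogen_core import ComponentLoader
from autogen_core.model_context import TokenLimitedChatCompletionContext
from autogen_ext.models.openai import OpenAIChatCompletionClient

# Placeholder client; the key is never used because only token counting happens locally.
model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="test_key")
context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=5)

config = context.dump_component()
# The model client is now serialized as a nested component, not a bare model name.
assert config.config["model_client"]["provider"] == "autogen_ext.models.openai.OpenAIChatCompletionClient"

loaded = ComponentLoader.load_component(config, TokenLimitedChatCompletionContext)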
diff --git a/python/packages/autogen-core/src/autogen_core/model_context/_chat_completion_context.py b/python/packages/autogen-core/src/autogen_core/model_context/_chat_completion_context.py
index 1b14b8e85..84871f154 100644
--- a/python/packages/autogen-core/src/autogen_core/model_context/_chat_completion_context.py
+++ b/python/packages/autogen-core/src/autogen_core/model_context/_chat_completion_context.py
@@ -47,7 +47,10 @@ class ChatCompletionContext(ABC, ComponentBase[BaseModel]):
     component_type = "chat_completion_context"
 
     def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None:
-        self._messages: List[LLMMessage] = initial_messages or []
+        self._messages: List[LLMMessage] = []
+        if initial_messages is not None:
+            self._messages.extend(initial_messages)
+        self._initial_messages = initial_messages
 
     async def add_message(self, message: LLMMessage) -> None:
         """Add a message to the context."""
diff --git a/python/packages/autogen-core/src/autogen_core/model_context/_head_and_tail_chat_completion_context.py b/python/packages/autogen-core/src/autogen_core/model_context/_head_and_tail_chat_completion_context.py
index a37d5927b..75493618e 100644
--- a/python/packages/autogen-core/src/autogen_core/model_context/_head_and_tail_chat_completion_context.py
+++ b/python/packages/autogen-core/src/autogen_core/model_context/_head_and_tail_chat_completion_context.py
@@ -68,7 +68,7 @@ class HeadAndTailChatCompletionContext(ChatCompletionContext, Component[HeadAndT
 
     def _to_config(self) -> HeadAndTailChatCompletionContextConfig:
         return HeadAndTailChatCompletionContextConfig(
-            head_size=self._head_size, tail_size=self._tail_size, initial_messages=self._messages
+            head_size=self._head_size, tail_size=self._tail_size, initial_messages=self._initial_messages
         )
 
     @classmethod
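The base-class change above separates the mutable message list from the constructor argument: _messages keeps growing as messages are added, while _initial_messages records only what was passed at construction, and that is what the _to_config methods now serialize. A small sketch of the assumed behaviour (not part of this patch; it presumes initial_messages is surfaced in the dumped config as the test changes suggest):

import asyncio

from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.models import UserMessage


async def main() -> None:
    context = BufferedChatCompletionContext(
        buffer_size=5,
        initial_messages=[UserMessage(content="Hello!", source="user")],
    )
    await context.add_message(UserMessage(content="A follow-up message", source="user"))

    config = context.dump_component()
    # Only the constructor-supplied message is serialized; the follow-up is runtime state.
    assert len(config.config["initial_messages"]) == 1


asyncio.run(main())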
diff --git a/python/packages/autogen-core/src/autogen_core/model_context/_token_limited_chat_completion_context.py b/python/packages/autogen-core/src/autogen_core/model_context/_token_limited_chat_completion_context.py
index 12816755a..b8a0258a4 100644
--- a/python/packages/autogen-core/src/autogen_core/model_context/_token_limited_chat_completion_context.py
+++ b/python/packages/autogen-core/src/autogen_core/model_context/_token_limited_chat_completion_context.py
@@ -1,83 +1,94 @@
-from typing import List, Sequence
-from autogen_core.tools import Tool, ToolSchema
+from typing import List
 
 from pydantic import BaseModel
 from typing_extensions import Self
-import tiktoken
 
-from .._component_config import Component
-from ..models import FunctionExecutionResultMessage, LLMMessage
+from .._component_config import Component, ComponentModel
+from ..models import ChatCompletionClient, FunctionExecutionResultMessage, LLMMessage
+from ..tools import ToolSchema
 from ._chat_completion_context import ChatCompletionContext
 
-from autogen_ext.models.ollama._ollama_client import count_tokens_ollama
-from autogen_ext.models.openai._openai_client import count_tokens_openai
-
 
 class TokenLimitedChatCompletionContextConfig(BaseModel):
-    token_limit: int
-    model: str
+    model_client: ComponentModel
+    token_limit: int | None = None
+    tool_schema: List[ToolSchema] | None = None
     initial_messages: List[LLMMessage] | None = None
 
 
 class TokenLimitedChatCompletionContext(ChatCompletionContext, Component[TokenLimitedChatCompletionContextConfig]):
-    """A token based chat completion context maintains a view of the context up to a token limit,
-    where n is the token limit. The token limit is set at initialization.
+    """(Experimental) A token based chat completion context maintains a view of the context up to a token limit.
+
+    .. note::
+
+        Added in v0.4.10. This is an experimental component and may change in the future.
 
     Args:
-        token_limit (int): Max tokens for context.
-        initial_messages (List[LLMMessage] | None): The initial messages.
+        model_client (ChatCompletionClient): The model client to use for token counting.
+            The model client must implement the :meth:`~autogen_core.models.ChatCompletionClient.count_tokens`
+            and :meth:`~autogen_core.models.ChatCompletionClient.remaining_tokens` methods.
+        token_limit (int | None): The maximum number of tokens to keep in the context
+            using the :meth:`~autogen_core.models.ChatCompletionClient.count_tokens` method.
+            If None, the context will be limited by the model client using the
+            :meth:`~autogen_core.models.ChatCompletionClient.remaining_tokens` method.
+        tools (List[ToolSchema] | None): A list of tool schema to use in the context.
+        initial_messages (List[LLMMessage] | None): A list of initial messages to include in the context.
+
     """
 
     component_config_schema = TokenLimitedChatCompletionContextConfig
     component_provider_override = "autogen_core.model_context.TokenLimitedChatCompletionContext"
 
-    def __init__(self, token_limit: int, model: str, initial_messages: List[LLMMessage] | None = None) -> None:
+    def __init__(
+        self,
+        model_client: ChatCompletionClient,
+        *,
+        token_limit: int | None = None,
+        tool_schema: List[ToolSchema] | None = None,
+        initial_messages: List[LLMMessage] | None = None,
+    ) -> None:
         super().__init__(initial_messages)
-        if token_limit <= 0:
+        if token_limit is not None and token_limit <= 0:
             raise ValueError("token_limit must be greater than 0.")
         self._token_limit = token_limit
-        self._model = model
+        self._model_client = model_client
+        self._tool_schema = tool_schema or []
 
     async def get_messages(self) -> List[LLMMessage]:
-        """Get at most `token_limit` tokens in recent messages."""
-        token_count = count_chat_tokens(self._messages, self._model)
-        while token_count > self._token_limit:
-            middle_index = len(self._messages) // 2
-            self._messages.pop(middle_index)
-            token_count = count_chat_tokens(self._messages, self._model)
-        messages = self._messages
-        # Handle the first message is a function call result message.
+        """Get at most `token_limit` tokens in recent messages. If the token limit is not
+        provided, then return as many messages as the remaining token allowed by the model client."""
+        messages = list(self._messages)
+        if self._token_limit is None:
+            remaining_tokens = self._model_client.remaining_tokens(messages, tools=self._tool_schema)
+            while remaining_tokens < 0 and len(messages) > 0:
+                middle_index = len(messages) // 2
+                messages.pop(middle_index)
+                remaining_tokens = self._model_client.remaining_tokens(messages, tools=self._tool_schema)
+        else:
+            token_count = self._model_client.count_tokens(messages, tools=self._tool_schema)
+            while token_count > self._token_limit and len(messages) > 0:
+                middle_index = len(messages) // 2
+                messages.pop(middle_index)
+                token_count = self._model_client.count_tokens(messages, tools=self._tool_schema)
         if messages and isinstance(messages[0], FunctionExecutionResultMessage):
+            # Handle the first message is a function call result message.
             # Remove the first message from the list.
             messages = messages[1:]
         return messages
 
     def _to_config(self) -> TokenLimitedChatCompletionContextConfig:
         return TokenLimitedChatCompletionContextConfig(
-            token_limit=self._token_limit, model=self._model, initial_messages=self._messages
+            model_client=self._model_client.dump_component(),
+            token_limit=self._token_limit,
+            tool_schema=self._tool_schema,
+            initial_messages=self._initial_messages,
         )
 
     @classmethod
     def _from_config(cls, config: TokenLimitedChatCompletionContextConfig) -> Self:
-        return cls(**config.model_dump())
-
-
-def count_chat_tokens(
-    messages: Sequence[LLMMessage], model: str = "gpt-4o", *, tools: Sequence[Tool | ToolSchema] = []
-) -> int:
-    """Count tokens for a list of messages using the appropriate client based on the model."""
-    # Check if the model is an OpenAI model
-    if "openai" in model.lower():
-        return count_tokens_openai(messages, model)
-
-    # Check if the model is an Ollama model
-    elif "llama" in model.lower():
-        return count_tokens_ollama(messages, model)
-
-    # Fallback to cl100k_base encoding if the model is unrecognized
-    else:
-        encoding = tiktoken.get_encoding("cl100k_base")
-        total_tokens = 0
-        for message in messages:
-            total_tokens += len(encoding.encode(str(message.content)))
-        return total_tokens
+        return cls(
+            model_client=ChatCompletionClient.load_component(config.model_client),
+            token_limit=config.token_limit,
+            tool_schema=config.tool_schema,
+            initial_messages=config.initial_messages,
+        )
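A minimal usage sketch of the new constructor, following the pattern in the tests later in this diff; the model name and api_key are placeholder values:

import asyncio

from autogen_core.model_context import TokenLimitedChatCompletionContext
from autogen_core.models import UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    client = OpenAIChatCompletionClient(model="gpt-4o", api_key="test_key")

    # Explicit limit: trims with count_tokens until the view fits under token_limit.
    bounded = TokenLimitedChatCompletionContext(model_client=client, token_limit=30)

    # No limit: trims with remaining_tokens, i.e. whatever the model window still allows.
    unbounded = TokenLimitedChatCompletionContext(model_client=client)

    for context in (bounded, unbounded):
        await context.add_message(UserMessage(content="Hello!", source="user"))
        view = await context.get_messages()  # messages are dropped from the middle, not the ends
        print(len(view))


asyncio.run(main())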
diff --git a/python/packages/autogen-core/src/autogen_core/model_context/_unbounded_chat_completion_context.py b/python/packages/autogen-core/src/autogen_core/model_context/_unbounded_chat_completion_context.py
index 4bc26db46..a2f409719 100644
--- a/python/packages/autogen-core/src/autogen_core/model_context/_unbounded_chat_completion_context.py
+++ b/python/packages/autogen-core/src/autogen_core/model_context/_unbounded_chat_completion_context.py
@@ -9,7 +9,7 @@ from ._chat_completion_context import ChatCompletionContext
 
 
 class UnboundedChatCompletionContextConfig(BaseModel):
-    pass
+    initial_messages: List[LLMMessage] | None = None
 
 
 class UnboundedChatCompletionContext(ChatCompletionContext, Component[UnboundedChatCompletionContextConfig]):
@@ -23,8 +23,8 @@ class UnboundedChatCompletionContext(ChatCompletionContext, Component[UnboundedC
         return self._messages
 
     def _to_config(self) -> UnboundedChatCompletionContextConfig:
-        return UnboundedChatCompletionContextConfig()
+        return UnboundedChatCompletionContextConfig(initial_messages=self._initial_messages)
 
     @classmethod
     def _from_config(cls, config: UnboundedChatCompletionContextConfig) -> Self:
-        return cls()
+        return cls(initial_messages=config.initial_messages)
diff --git a/python/packages/autogen-core/src/autogen_core/tools/_base.py b/python/packages/autogen-core/src/autogen_core/tools/_base.py
index 1843f246f..813065939 100644
--- a/python/packages/autogen-core/src/autogen_core/tools/_base.py
+++ b/python/packages/autogen-core/src/autogen_core/tools/_base.py
@@ -2,7 +2,8 @@ import json
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
-from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar, cast, runtime_checkable
+from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypeVar, cast, runtime_checkable
+from typing_extensions import TypedDict
 
 import jsonref
 from opentelemetry.trace import get_tracer
diff --git a/python/packages/autogen-core/tests/test_component_config.py b/python/packages/autogen-core/tests/test_component_config.py
index 36125d128..9527c1951 100644
--- a/python/packages/autogen-core/tests/test_component_config.py
+++ b/python/packages/autogen-core/tests/test_component_config.py
@@ -361,7 +361,6 @@ async def test_function_tool() -> None:
     await loaded_async.run_json({"x": 1.0, "y": 2.0}, cancelled_token)
 
 
-@pytest.mark.asyncio
 def test_component_descriptions() -> None:
     """Test different ways of setting component descriptions."""
     assert MyComponent("test").dump_component().description is None
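With initial_messages added to its config, UnboundedChatCompletionContext can now round-trip its seed messages as well; a short sketch, assuming the same loader pattern used in the tests:

from autogen_core import ComponentLoader
from autogen_core.model_context import UnboundedChatCompletionContext
from autogen_core.models import UserMessage

# Seed messages survive dump_component/load_component instead of being dropped.
context = UnboundedChatCompletionContext(initial_messages=[UserMessage(content="Hello!", source="user")])
config = context.dump_component()
loaded = ComponentLoader.load_component(config, UnboundedChatCompletionContext)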
diff --git a/python/packages/autogen-core/tests/test_model_context.py b/python/packages/autogen-core/tests/test_model_context.py
index 2901eef7d..fbb984090 100644
--- a/python/packages/autogen-core/tests/test_model_context.py
+++ b/python/packages/autogen-core/tests/test_model_context.py
@@ -4,10 +4,18 @@ import pytest
 from autogen_core.model_context import (
     BufferedChatCompletionContext,
     HeadAndTailChatCompletionContext,
-    UnboundedChatCompletionContext,
     TokenLimitedChatCompletionContext,
+    UnboundedChatCompletionContext,
 )
-from autogen_core.models import AssistantMessage, LLMMessage, UserMessage, FunctionExecutionResultMessage
+from autogen_core.models import (
+    AssistantMessage,
+    ChatCompletionClient,
+    FunctionExecutionResultMessage,
+    LLMMessage,
+    UserMessage,
+)
+from autogen_ext.models.ollama import OllamaChatCompletionClient
+from autogen_ext.models.openai import OpenAIChatCompletionClient
 
 
 @pytest.mark.asyncio
@@ -108,8 +116,18 @@ async def test_unbounded_model_context() -> None:
 
 
 @pytest.mark.asyncio
-async def test_token_limited_model_context_openai() -> None:
-    model_context = TokenLimitedChatCompletionContext(token_limit=20, model="gpt-4o")
+@pytest.mark.parametrize(
+    "model_client,token_limit",
+    [
+        (OpenAIChatCompletionClient(model="gpt-4o", temperature=0.0, api_key="test"), 30),
+        (OllamaChatCompletionClient(model="llama3.3"), 20),
+    ],
+    ids=["openai", "ollama"],
+)
+async def test_token_limited_model_context_with_token_limit(
+    model_client: ChatCompletionClient, token_limit: int
+) -> None:
+    model_context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=token_limit)
     messages: List[LLMMessage] = [
         UserMessage(content="Hello!", source="user"),
         AssistantMessage(content="What can I do for you?", source="assistant"),
@@ -119,37 +137,7 @@ async def test_token_limited_model_context_openai() -> None:
         await model_context.add_message(msg)
 
     retrieved = await model_context.get_messages()
-    assert len(retrieved) == 2  # Token limit set very low, will remove 1 of the messages
-    assert retrieved != messages  # Will not be equal to the original messages
-
-    await model_context.clear()
-    retrieved = await model_context.get_messages()
-    assert len(retrieved) == 0
-
-    # Test saving and loading state.
-    for msg in messages:
-        await model_context.add_message(msg)
-    state = await model_context.save_state()
-    await model_context.clear()
-    await model_context.load_state(state)
-    retrieved = await model_context.get_messages()
-    assert len(retrieved) == 2
-    assert retrieved != messages
-
-
-@pytest.mark.asyncio
-async def test_token_limited_model_context_llama() -> None:
-    model_context = TokenLimitedChatCompletionContext(token_limit=20, model="llama2-7b")
-    messages: List[LLMMessage] = [
-        UserMessage(content="Hello!", source="user"),
-        AssistantMessage(content="What can I do for you?", source="assistant"),
-        UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
-    ]
-    for msg in messages:
-        await model_context.add_message(msg)
-
-    retrieved = await model_context.get_messages()
-    assert len(retrieved) == 1  # Token limit set very low, will remove two of the messages
+    assert len(retrieved) == 1  # Token limit set very low, will remove 2 of the messages
     assert retrieved != messages  # Will not be equal to the original messages
 
     await model_context.clear()
@@ -168,8 +156,41 @@ async def test_token_limited_model_context_llama() -> None:
 
 
 @pytest.mark.asyncio
-async def test_token_limited_model_context_openai_with_function_result() -> None:
-    model_context = TokenLimitedChatCompletionContext(token_limit=1000, model="gpt-4o")
+@pytest.mark.parametrize(
+    "model_client",
+    [
+        OpenAIChatCompletionClient(model="gpt-4o", temperature=0.0, api_key="test_key"),
+        OllamaChatCompletionClient(model="llama3.3"),
+    ],
+    ids=["openai", "ollama"],
+)
+async def test_token_limited_model_context_without_token_limit(model_client: ChatCompletionClient) -> None:
+    model_context = TokenLimitedChatCompletionContext(model_client=model_client)
+    messages: List[LLMMessage] = [
+        UserMessage(content="Hello!", source="user"),
+        AssistantMessage(content="What can I do for you?", source="assistant"),
+        UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
+    ]
+    for msg in messages:
+        await model_context.add_message(msg)
+
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 3
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_client,token_limit",
+    [
+        (OpenAIChatCompletionClient(model="gpt-4o", temperature=0.0, api_key="test"), 60),
+        (OllamaChatCompletionClient(model="llama3.3"), 50),
+    ],
+    ids=["openai", "ollama"],
+)
+async def test_token_limited_model_context_openai_with_function_result(
+    model_client: ChatCompletionClient, token_limit: int
+) -> None:
+    model_context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=token_limit)
     messages: List[LLMMessage] = [
         FunctionExecutionResultMessage(content=[]),
         UserMessage(content="Hello!", source="user"),
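The last test above seeds the context with a FunctionExecutionResultMessage in the first position; per the get_messages implementation earlier in this diff, a leading function-result message is dropped from the returned view so the context never starts with an orphaned tool result. A sketch of that behaviour, reusing the Ollama client and limit from the test:

import asyncio

from autogen_core.model_context import TokenLimitedChatCompletionContext
from autogen_core.models import FunctionExecutionResultMessage, UserMessage
from autogen_ext.models.ollama import OllamaChatCompletionClient


async def main() -> None:
    context = TokenLimitedChatCompletionContext(
        model_client=OllamaChatCompletionClient(model="llama3.3"), token_limit=50
    )
    await context.add_message(FunctionExecutionResultMessage(content=[]))
    await context.add_message(UserMessage(content="Hello!", source="user"))

    view = await context.get_messages()
    # The orphaned function-result message at the head is removed from the view.
    assert not isinstance(view[0], FunctionExecutionResultMessage)


asyncio.run(main())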