From e686342f53a2ccc10c40b3e88c0e5da8bc66c825 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Fri, 28 Mar 2025 10:24:41 -0700 Subject: [PATCH] Fix token limited model context (#6137) Token limited model context is currently broken because it is importing from extensions. This fix removed the imports and updated the model context implementation to use model client directly. In the future, the model client's token counting should cache results from model API to provide accurate counting. --- .../tests/test_declarative_components.py | 11 +- .../autogen_core/model_context/__init__.py | 2 +- .../_buffered_chat_completion_context.py | 4 +- .../model_context/_chat_completion_context.py | 5 +- .../_head_and_tail_chat_completion_context.py | 2 +- .../_token_limited_chat_completion_context.py | 107 ++++++++++-------- .../_unbounded_chat_completion_context.py | 6 +- .../src/autogen_core/tools/_base.py | 3 +- .../tests/test_component_config.py | 1 - .../autogen-core/tests/test_model_context.py | 95 ++++++++++------ 10 files changed, 139 insertions(+), 97 deletions(-) diff --git a/python/packages/autogen-agentchat/tests/test_declarative_components.py b/python/packages/autogen-agentchat/tests/test_declarative_components.py index 09054cdaf..69e5c4ae9 100644 --- a/python/packages/autogen-agentchat/tests/test_declarative_components.py +++ b/python/packages/autogen-agentchat/tests/test_declarative_components.py @@ -14,9 +14,10 @@ from autogen_core import ComponentLoader, ComponentModel from autogen_core.model_context import ( BufferedChatCompletionContext, HeadAndTailChatCompletionContext, - UnboundedChatCompletionContext, TokenLimitedChatCompletionContext, + UnboundedChatCompletionContext, ) +from autogen_ext.models.openai import OpenAIChatCompletionClient @pytest.mark.asyncio @@ -105,7 +106,8 @@ async def test_chat_completion_context_declarative() -> None: unbounded_context = UnboundedChatCompletionContext() buffered_context = BufferedChatCompletionContext(buffer_size=5) head_tail_context = HeadAndTailChatCompletionContext(head_size=3, tail_size=2) - token_limited_context = TokenLimitedChatCompletionContext(token_limit=5, model="gpt-4o") + model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="test_key") + token_limited_context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=5) # Test serialization unbounded_config = unbounded_context.dump_component() @@ -123,7 +125,10 @@ async def test_chat_completion_context_declarative() -> None: token_limited_config = token_limited_context.dump_component() assert token_limited_config.provider == "autogen_core.model_context.TokenLimitedChatCompletionContext" assert token_limited_config.config["token_limit"] == 5 - assert token_limited_config.config["model"] == "gpt-4o" + assert ( + token_limited_config.config["model_client"]["provider"] + == "autogen_ext.models.openai.OpenAIChatCompletionClient" + ) # Test deserialization loaded_unbounded = ComponentLoader.load_component(unbounded_config, UnboundedChatCompletionContext) diff --git a/python/packages/autogen-core/src/autogen_core/model_context/__init__.py b/python/packages/autogen-core/src/autogen_core/model_context/__init__.py index 513613e6e..b6898614e 100644 --- a/python/packages/autogen-core/src/autogen_core/model_context/__init__.py +++ b/python/packages/autogen-core/src/autogen_core/model_context/__init__.py @@ -1,7 +1,7 @@ from ._buffered_chat_completion_context import BufferedChatCompletionContext -from ._token_limited_chat_completion_context import TokenLimitedChatCompletionContext from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext +from ._token_limited_chat_completion_context import TokenLimitedChatCompletionContext from ._unbounded_chat_completion_context import ( UnboundedChatCompletionContext, ) diff --git a/python/packages/autogen-core/src/autogen_core/model_context/_buffered_chat_completion_context.py b/python/packages/autogen-core/src/autogen_core/model_context/_buffered_chat_completion_context.py index dcece60b1..5d23f818a 100644 --- a/python/packages/autogen-core/src/autogen_core/model_context/_buffered_chat_completion_context.py +++ b/python/packages/autogen-core/src/autogen_core/model_context/_buffered_chat_completion_context.py @@ -41,7 +41,9 @@ class BufferedChatCompletionContext(ChatCompletionContext, Component[BufferedCha return messages def _to_config(self) -> BufferedChatCompletionContextConfig: - return BufferedChatCompletionContextConfig(buffer_size=self._buffer_size, initial_messages=self._messages) + return BufferedChatCompletionContextConfig( + buffer_size=self._buffer_size, initial_messages=self._initial_messages + ) @classmethod def _from_config(cls, config: BufferedChatCompletionContextConfig) -> Self: diff --git a/python/packages/autogen-core/src/autogen_core/model_context/_chat_completion_context.py b/python/packages/autogen-core/src/autogen_core/model_context/_chat_completion_context.py index 1b14b8e85..84871f154 100644 --- a/python/packages/autogen-core/src/autogen_core/model_context/_chat_completion_context.py +++ b/python/packages/autogen-core/src/autogen_core/model_context/_chat_completion_context.py @@ -47,7 +47,10 @@ class ChatCompletionContext(ABC, ComponentBase[BaseModel]): component_type = "chat_completion_context" def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None: - self._messages: List[LLMMessage] = initial_messages or [] + self._messages: List[LLMMessage] = [] + if initial_messages is not None: + self._messages.extend(initial_messages) + self._initial_messages = initial_messages async def add_message(self, message: LLMMessage) -> None: """Add a message to the context.""" diff --git a/python/packages/autogen-core/src/autogen_core/model_context/_head_and_tail_chat_completion_context.py b/python/packages/autogen-core/src/autogen_core/model_context/_head_and_tail_chat_completion_context.py index a37d5927b..75493618e 100644 --- a/python/packages/autogen-core/src/autogen_core/model_context/_head_and_tail_chat_completion_context.py +++ b/python/packages/autogen-core/src/autogen_core/model_context/_head_and_tail_chat_completion_context.py @@ -68,7 +68,7 @@ class HeadAndTailChatCompletionContext(ChatCompletionContext, Component[HeadAndT def _to_config(self) -> HeadAndTailChatCompletionContextConfig: return HeadAndTailChatCompletionContextConfig( - head_size=self._head_size, tail_size=self._tail_size, initial_messages=self._messages + head_size=self._head_size, tail_size=self._tail_size, initial_messages=self._initial_messages ) @classmethod diff --git a/python/packages/autogen-core/src/autogen_core/model_context/_token_limited_chat_completion_context.py b/python/packages/autogen-core/src/autogen_core/model_context/_token_limited_chat_completion_context.py index 12816755a..b8a0258a4 100644 --- a/python/packages/autogen-core/src/autogen_core/model_context/_token_limited_chat_completion_context.py +++ b/python/packages/autogen-core/src/autogen_core/model_context/_token_limited_chat_completion_context.py @@ -1,83 +1,94 @@ -from typing import List, Sequence -from autogen_core.tools import Tool, ToolSchema +from typing import List from pydantic import BaseModel from typing_extensions import Self -import tiktoken -from .._component_config import Component -from ..models import FunctionExecutionResultMessage, LLMMessage +from .._component_config import Component, ComponentModel +from ..models import ChatCompletionClient, FunctionExecutionResultMessage, LLMMessage +from ..tools import ToolSchema from ._chat_completion_context import ChatCompletionContext -from autogen_ext.models.ollama._ollama_client import count_tokens_ollama -from autogen_ext.models.openai._openai_client import count_tokens_openai - class TokenLimitedChatCompletionContextConfig(BaseModel): - token_limit: int - model: str + model_client: ComponentModel + token_limit: int | None = None + tool_schema: List[ToolSchema] | None = None initial_messages: List[LLMMessage] | None = None class TokenLimitedChatCompletionContext(ChatCompletionContext, Component[TokenLimitedChatCompletionContextConfig]): - """A token based chat completion context maintains a view of the context up to a token limit, - where n is the token limit. The token limit is set at initialization. + """(Experimental) A token based chat completion context maintains a view of the context up to a token limit. + + .. note:: + + Added in v0.4.10. This is an experimental component and may change in the future. Args: - token_limit (int): Max tokens for context. - initial_messages (List[LLMMessage] | None): The initial messages. + model_client (ChatCompletionClient): The model client to use for token counting. + The model client must implement the :meth:`~autogen_core.models.ChatCompletionClient.count_tokens` + and :meth:`~autogen_core.models.ChatCompletionClient.remaining_tokens` methods. + token_limit (int | None): The maximum number of tokens to keep in the context + using the :meth:`~autogen_core.models.ChatCompletionClient.count_tokens` method. + If None, the context will be limited by the model client using the + :meth:`~autogen_core.models.ChatCompletionClient.remaining_tokens` method. + tools (List[ToolSchema] | None): A list of tool schema to use in the context. + initial_messages (List[LLMMessage] | None): A list of initial messages to include in the context. + """ component_config_schema = TokenLimitedChatCompletionContextConfig component_provider_override = "autogen_core.model_context.TokenLimitedChatCompletionContext" - def __init__(self, token_limit: int, model: str, initial_messages: List[LLMMessage] | None = None) -> None: + def __init__( + self, + model_client: ChatCompletionClient, + *, + token_limit: int | None = None, + tool_schema: List[ToolSchema] | None = None, + initial_messages: List[LLMMessage] | None = None, + ) -> None: super().__init__(initial_messages) - if token_limit <= 0: + if token_limit is not None and token_limit <= 0: raise ValueError("token_limit must be greater than 0.") self._token_limit = token_limit - self._model = model + self._model_client = model_client + self._tool_schema = tool_schema or [] async def get_messages(self) -> List[LLMMessage]: - """Get at most `token_limit` tokens in recent messages.""" - token_count = count_chat_tokens(self._messages, self._model) - while token_count > self._token_limit: - middle_index = len(self._messages) // 2 - self._messages.pop(middle_index) - token_count = count_chat_tokens(self._messages, self._model) - messages = self._messages - # Handle the first message is a function call result message. + """Get at most `token_limit` tokens in recent messages. If the token limit is not + provided, then return as many messages as the remaining token allowed by the model client.""" + messages = list(self._messages) + if self._token_limit is None: + remaining_tokens = self._model_client.remaining_tokens(messages, tools=self._tool_schema) + while remaining_tokens < 0 and len(messages) > 0: + middle_index = len(messages) // 2 + messages.pop(middle_index) + remaining_tokens = self._model_client.remaining_tokens(messages, tools=self._tool_schema) + else: + token_count = self._model_client.count_tokens(messages, tools=self._tool_schema) + while token_count > self._token_limit and len(messages) > 0: + middle_index = len(messages) // 2 + messages.pop(middle_index) + token_count = self._model_client.count_tokens(messages, tools=self._tool_schema) if messages and isinstance(messages[0], FunctionExecutionResultMessage): + # Handle the first message is a function call result message. # Remove the first message from the list. messages = messages[1:] return messages def _to_config(self) -> TokenLimitedChatCompletionContextConfig: return TokenLimitedChatCompletionContextConfig( - token_limit=self._token_limit, model=self._model, initial_messages=self._messages + model_client=self._model_client.dump_component(), + token_limit=self._token_limit, + tool_schema=self._tool_schema, + initial_messages=self._initial_messages, ) @classmethod def _from_config(cls, config: TokenLimitedChatCompletionContextConfig) -> Self: - return cls(**config.model_dump()) - - -def count_chat_tokens( - messages: Sequence[LLMMessage], model: str = "gpt-4o", *, tools: Sequence[Tool | ToolSchema] = [] -) -> int: - """Count tokens for a list of messages using the appropriate client based on the model.""" - # Check if the model is an OpenAI model - if "openai" in model.lower(): - return count_tokens_openai(messages, model) - - # Check if the model is an Ollama model - elif "llama" in model.lower(): - return count_tokens_ollama(messages, model) - - # Fallback to cl100k_base encoding if the model is unrecognized - else: - encoding = tiktoken.get_encoding("cl100k_base") - total_tokens = 0 - for message in messages: - total_tokens += len(encoding.encode(str(message.content))) - return total_tokens + return cls( + model_client=ChatCompletionClient.load_component(config.model_client), + token_limit=config.token_limit, + tool_schema=config.tool_schema, + initial_messages=config.initial_messages, + ) diff --git a/python/packages/autogen-core/src/autogen_core/model_context/_unbounded_chat_completion_context.py b/python/packages/autogen-core/src/autogen_core/model_context/_unbounded_chat_completion_context.py index 4bc26db46..a2f409719 100644 --- a/python/packages/autogen-core/src/autogen_core/model_context/_unbounded_chat_completion_context.py +++ b/python/packages/autogen-core/src/autogen_core/model_context/_unbounded_chat_completion_context.py @@ -9,7 +9,7 @@ from ._chat_completion_context import ChatCompletionContext class UnboundedChatCompletionContextConfig(BaseModel): - pass + initial_messages: List[LLMMessage] | None = None class UnboundedChatCompletionContext(ChatCompletionContext, Component[UnboundedChatCompletionContextConfig]): @@ -23,8 +23,8 @@ class UnboundedChatCompletionContext(ChatCompletionContext, Component[UnboundedC return self._messages def _to_config(self) -> UnboundedChatCompletionContextConfig: - return UnboundedChatCompletionContextConfig() + return UnboundedChatCompletionContextConfig(initial_messages=self._initial_messages) @classmethod def _from_config(cls, config: UnboundedChatCompletionContextConfig) -> Self: - return cls() + return cls(initial_messages=config.initial_messages) diff --git a/python/packages/autogen-core/src/autogen_core/tools/_base.py b/python/packages/autogen-core/src/autogen_core/tools/_base.py index 1843f246f..813065939 100644 --- a/python/packages/autogen-core/src/autogen_core/tools/_base.py +++ b/python/packages/autogen-core/src/autogen_core/tools/_base.py @@ -2,7 +2,8 @@ import json import logging from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar, cast, runtime_checkable +from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypeVar, cast, runtime_checkable +from typing_extensions import TypedDict import jsonref from opentelemetry.trace import get_tracer diff --git a/python/packages/autogen-core/tests/test_component_config.py b/python/packages/autogen-core/tests/test_component_config.py index 36125d128..9527c1951 100644 --- a/python/packages/autogen-core/tests/test_component_config.py +++ b/python/packages/autogen-core/tests/test_component_config.py @@ -361,7 +361,6 @@ async def test_function_tool() -> None: await loaded_async.run_json({"x": 1.0, "y": 2.0}, cancelled_token) -@pytest.mark.asyncio def test_component_descriptions() -> None: """Test different ways of setting component descriptions.""" assert MyComponent("test").dump_component().description is None diff --git a/python/packages/autogen-core/tests/test_model_context.py b/python/packages/autogen-core/tests/test_model_context.py index 2901eef7d..fbb984090 100644 --- a/python/packages/autogen-core/tests/test_model_context.py +++ b/python/packages/autogen-core/tests/test_model_context.py @@ -4,10 +4,18 @@ import pytest from autogen_core.model_context import ( BufferedChatCompletionContext, HeadAndTailChatCompletionContext, - UnboundedChatCompletionContext, TokenLimitedChatCompletionContext, + UnboundedChatCompletionContext, ) -from autogen_core.models import AssistantMessage, LLMMessage, UserMessage, FunctionExecutionResultMessage +from autogen_core.models import ( + AssistantMessage, + ChatCompletionClient, + FunctionExecutionResultMessage, + LLMMessage, + UserMessage, +) +from autogen_ext.models.ollama import OllamaChatCompletionClient +from autogen_ext.models.openai import OpenAIChatCompletionClient @pytest.mark.asyncio @@ -108,8 +116,18 @@ async def test_unbounded_model_context() -> None: @pytest.mark.asyncio -async def test_token_limited_model_context_openai() -> None: - model_context = TokenLimitedChatCompletionContext(token_limit=20, model="gpt-4o") +@pytest.mark.parametrize( + "model_client,token_limit", + [ + (OpenAIChatCompletionClient(model="gpt-4o", temperature=0.0, api_key="test"), 30), + (OllamaChatCompletionClient(model="llama3.3"), 20), + ], + ids=["openai", "ollama"], +) +async def test_token_limited_model_context_with_token_limit( + model_client: ChatCompletionClient, token_limit: int +) -> None: + model_context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=token_limit) messages: List[LLMMessage] = [ UserMessage(content="Hello!", source="user"), AssistantMessage(content="What can I do for you?", source="assistant"), @@ -119,37 +137,7 @@ async def test_token_limited_model_context_openai() -> None: await model_context.add_message(msg) retrieved = await model_context.get_messages() - assert len(retrieved) == 2 # Token limit set very low, will remove 1 of the messages - assert retrieved != messages # Will not be equal to the original messages - - await model_context.clear() - retrieved = await model_context.get_messages() - assert len(retrieved) == 0 - - # Test saving and loading state. - for msg in messages: - await model_context.add_message(msg) - state = await model_context.save_state() - await model_context.clear() - await model_context.load_state(state) - retrieved = await model_context.get_messages() - assert len(retrieved) == 2 - assert retrieved != messages - - -@pytest.mark.asyncio -async def test_token_limited_model_context_llama() -> None: - model_context = TokenLimitedChatCompletionContext(token_limit=20, model="llama2-7b") - messages: List[LLMMessage] = [ - UserMessage(content="Hello!", source="user"), - AssistantMessage(content="What can I do for you?", source="assistant"), - UserMessage(content="Tell what are some fun things to do in seattle.", source="user"), - ] - for msg in messages: - await model_context.add_message(msg) - - retrieved = await model_context.get_messages() - assert len(retrieved) == 1 # Token limit set very low, will remove two of the messages + assert len(retrieved) == 1 # Token limit set very low, will remove 2 of the messages assert retrieved != messages # Will not be equal to the original messages await model_context.clear() @@ -168,8 +156,41 @@ async def test_token_limited_model_context_llama() -> None: @pytest.mark.asyncio -async def test_token_limited_model_context_openai_with_function_result() -> None: - model_context = TokenLimitedChatCompletionContext(token_limit=1000, model="gpt-4o") +@pytest.mark.parametrize( + "model_client", + [ + OpenAIChatCompletionClient(model="gpt-4o", temperature=0.0, api_key="test_key"), + OllamaChatCompletionClient(model="llama3.3"), + ], + ids=["openai", "ollama"], +) +async def test_token_limited_model_context_without_token_limit(model_client: ChatCompletionClient) -> None: + model_context = TokenLimitedChatCompletionContext(model_client=model_client) + messages: List[LLMMessage] = [ + UserMessage(content="Hello!", source="user"), + AssistantMessage(content="What can I do for you?", source="assistant"), + UserMessage(content="Tell what are some fun things to do in seattle.", source="user"), + ] + for msg in messages: + await model_context.add_message(msg) + + retrieved = await model_context.get_messages() + assert len(retrieved) == 3 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_client,token_limit", + [ + (OpenAIChatCompletionClient(model="gpt-4o", temperature=0.0, api_key="test"), 60), + (OllamaChatCompletionClient(model="llama3.3"), 50), + ], + ids=["openai", "ollama"], +) +async def test_token_limited_model_context_openai_with_function_result( + model_client: ChatCompletionClient, token_limit: int +) -> None: + model_context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=token_limit) messages: List[LLMMessage] = [ FunctionExecutionResultMessage(content=[]), UserMessage(content="Hello!", source="user"),