Fix token limited model context (#6137)

The token-limited model context is currently broken because it imports
from the extensions package (autogen_ext).

This fix removes those imports and updates the model context
implementation to use the model client directly.

In the future, the model client's token counting should cache results
from the model API to provide accurate counts.
Eric Zhu 2025-03-28 10:24:41 -07:00 committed by GitHub
parent 0cd3ff46fa
commit e686342f53
10 changed files with 139 additions and 97 deletions
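
For reference, constructing the context after this change looks roughly like the sketch below (based on the updated tests in this commit; the API key value is a placeholder):

from autogen_core.model_context import TokenLimitedChatCompletionContext
from autogen_ext.models.openai import OpenAIChatCompletionClient

# The context now takes a model client instead of a model name string.
model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="YOUR_API_KEY")  # placeholder key

# With an explicit token limit, trimming uses model_client.count_tokens(...).
limited_context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=1000)

# Without a token limit, trimming falls back to model_client.remaining_tokens(...).
unlimited_context = TokenLimitedChatCompletionContext(model_client=model_client)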

View File

@@ -14,9 +14,10 @@ from autogen_core import ComponentLoader, ComponentModel
 from autogen_core.model_context import (
     BufferedChatCompletionContext,
     HeadAndTailChatCompletionContext,
-    UnboundedChatCompletionContext,
     TokenLimitedChatCompletionContext,
+    UnboundedChatCompletionContext,
 )
+from autogen_ext.models.openai import OpenAIChatCompletionClient
 
 
 @pytest.mark.asyncio
@@ -105,7 +106,8 @@ async def test_chat_completion_context_declarative() -> None:
     unbounded_context = UnboundedChatCompletionContext()
     buffered_context = BufferedChatCompletionContext(buffer_size=5)
     head_tail_context = HeadAndTailChatCompletionContext(head_size=3, tail_size=2)
-    token_limited_context = TokenLimitedChatCompletionContext(token_limit=5, model="gpt-4o")
+    model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="test_key")
+    token_limited_context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=5)
 
     # Test serialization
     unbounded_config = unbounded_context.dump_component()
@@ -123,7 +125,10 @@ async def test_chat_completion_context_declarative() -> None:
     token_limited_config = token_limited_context.dump_component()
     assert token_limited_config.provider == "autogen_core.model_context.TokenLimitedChatCompletionContext"
     assert token_limited_config.config["token_limit"] == 5
-    assert token_limited_config.config["model"] == "gpt-4o"
+    assert (
+        token_limited_config.config["model_client"]["provider"]
+        == "autogen_ext.models.openai.OpenAIChatCompletionClient"
+    )
 
     # Test deserialization
     loaded_unbounded = ComponentLoader.load_component(unbounded_config, UnboundedChatCompletionContext)

View File

@@ -1,7 +1,7 @@
 from ._buffered_chat_completion_context import BufferedChatCompletionContext
-from ._token_limited_chat_completion_context import TokenLimitedChatCompletionContext
 from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState
 from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
+from ._token_limited_chat_completion_context import TokenLimitedChatCompletionContext
 from ._unbounded_chat_completion_context import (
     UnboundedChatCompletionContext,
 )

View File

@@ -41,7 +41,9 @@ class BufferedChatCompletionContext(ChatCompletionContext, Component[BufferedCha
         return messages
 
     def _to_config(self) -> BufferedChatCompletionContextConfig:
-        return BufferedChatCompletionContextConfig(buffer_size=self._buffer_size, initial_messages=self._messages)
+        return BufferedChatCompletionContextConfig(
+            buffer_size=self._buffer_size, initial_messages=self._initial_messages
+        )
 
     @classmethod
     def _from_config(cls, config: BufferedChatCompletionContextConfig) -> Self:

View File

@@ -47,7 +47,10 @@ class ChatCompletionContext(ABC, ComponentBase[BaseModel]):
     component_type = "chat_completion_context"
 
     def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None:
-        self._messages: List[LLMMessage] = initial_messages or []
+        self._messages: List[LLMMessage] = []
+        if initial_messages is not None:
+            self._messages.extend(initial_messages)
+        self._initial_messages = initial_messages
 
     async def add_message(self, message: LLMMessage) -> None:
         """Add a message to the context."""

View File

@@ -68,7 +68,7 @@ class HeadAndTailChatCompletionContext(ChatCompletionContext, Component[HeadAndT
 
     def _to_config(self) -> HeadAndTailChatCompletionContextConfig:
         return HeadAndTailChatCompletionContextConfig(
-            head_size=self._head_size, tail_size=self._tail_size, initial_messages=self._messages
+            head_size=self._head_size, tail_size=self._tail_size, initial_messages=self._initial_messages
         )
 
     @classmethod

View File

@@ -1,83 +1,94 @@
-from typing import List, Sequence
+from typing import List
 
-from autogen_core.tools import Tool, ToolSchema
 from pydantic import BaseModel
 from typing_extensions import Self
-import tiktoken
 
-from .._component_config import Component
-from ..models import FunctionExecutionResultMessage, LLMMessage
+from .._component_config import Component, ComponentModel
+from ..models import ChatCompletionClient, FunctionExecutionResultMessage, LLMMessage
+from ..tools import ToolSchema
 from ._chat_completion_context import ChatCompletionContext
-from autogen_ext.models.ollama._ollama_client import count_tokens_ollama
-from autogen_ext.models.openai._openai_client import count_tokens_openai
 
 
 class TokenLimitedChatCompletionContextConfig(BaseModel):
-    token_limit: int
-    model: str
+    model_client: ComponentModel
+    token_limit: int | None = None
+    tool_schema: List[ToolSchema] | None = None
     initial_messages: List[LLMMessage] | None = None
 
 
 class TokenLimitedChatCompletionContext(ChatCompletionContext, Component[TokenLimitedChatCompletionContextConfig]):
-    """A token based chat completion context maintains a view of the context up to a token limit,
-    where n is the token limit. The token limit is set at initialization.
+    """(Experimental) A token based chat completion context maintains a view of the context up to a token limit.
+
+    .. note::
+
+        Added in v0.4.10. This is an experimental component and may change in the future.
 
     Args:
-        token_limit (int): Max tokens for context.
-        initial_messages (List[LLMMessage] | None): The initial messages.
+        model_client (ChatCompletionClient): The model client to use for token counting.
+            The model client must implement the :meth:`~autogen_core.models.ChatCompletionClient.count_tokens`
+            and :meth:`~autogen_core.models.ChatCompletionClient.remaining_tokens` methods.
+        token_limit (int | None): The maximum number of tokens to keep in the context
+            using the :meth:`~autogen_core.models.ChatCompletionClient.count_tokens` method.
+            If None, the context will be limited by the model client using the
+            :meth:`~autogen_core.models.ChatCompletionClient.remaining_tokens` method.
+        tools (List[ToolSchema] | None): A list of tool schema to use in the context.
+        initial_messages (List[LLMMessage] | None): A list of initial messages to include in the context.
     """
 
     component_config_schema = TokenLimitedChatCompletionContextConfig
     component_provider_override = "autogen_core.model_context.TokenLimitedChatCompletionContext"
 
-    def __init__(self, token_limit: int, model: str, initial_messages: List[LLMMessage] | None = None) -> None:
+    def __init__(
+        self,
+        model_client: ChatCompletionClient,
+        *,
+        token_limit: int | None = None,
+        tool_schema: List[ToolSchema] | None = None,
+        initial_messages: List[LLMMessage] | None = None,
+    ) -> None:
         super().__init__(initial_messages)
-        if token_limit <= 0:
+        if token_limit is not None and token_limit <= 0:
             raise ValueError("token_limit must be greater than 0.")
         self._token_limit = token_limit
-        self._model = model
+        self._model_client = model_client
+        self._tool_schema = tool_schema or []
 
     async def get_messages(self) -> List[LLMMessage]:
-        """Get at most `token_limit` tokens in recent messages."""
-        token_count = count_chat_tokens(self._messages, self._model)
-        while token_count > self._token_limit:
-            middle_index = len(self._messages) // 2
-            self._messages.pop(middle_index)
-            token_count = count_chat_tokens(self._messages, self._model)
-        messages = self._messages
-        # Handle the first message is a function call result message.
+        """Get at most `token_limit` tokens in recent messages. If the token limit is not
+        provided, then return as many messages as the remaining token allowed by the model client."""
+        messages = list(self._messages)
+        if self._token_limit is None:
+            remaining_tokens = self._model_client.remaining_tokens(messages, tools=self._tool_schema)
+            while remaining_tokens < 0 and len(messages) > 0:
+                middle_index = len(messages) // 2
+                messages.pop(middle_index)
+                remaining_tokens = self._model_client.remaining_tokens(messages, tools=self._tool_schema)
+        else:
+            token_count = self._model_client.count_tokens(messages, tools=self._tool_schema)
+            while token_count > self._token_limit and len(messages) > 0:
+                middle_index = len(messages) // 2
+                messages.pop(middle_index)
+                token_count = self._model_client.count_tokens(messages, tools=self._tool_schema)
         if messages and isinstance(messages[0], FunctionExecutionResultMessage):
+            # Handle the first message is a function call result message.
             # Remove the first message from the list.
             messages = messages[1:]
         return messages
 
     def _to_config(self) -> TokenLimitedChatCompletionContextConfig:
         return TokenLimitedChatCompletionContextConfig(
-            token_limit=self._token_limit, model=self._model, initial_messages=self._messages
+            model_client=self._model_client.dump_component(),
+            token_limit=self._token_limit,
+            tool_schema=self._tool_schema,
+            initial_messages=self._initial_messages,
         )
 
     @classmethod
     def _from_config(cls, config: TokenLimitedChatCompletionContextConfig) -> Self:
-        return cls(**config.model_dump())
-
-
-def count_chat_tokens(
-    messages: Sequence[LLMMessage], model: str = "gpt-4o", *, tools: Sequence[Tool | ToolSchema] = []
-) -> int:
-    """Count tokens for a list of messages using the appropriate client based on the model."""
-    # Check if the model is an OpenAI model
-    if "openai" in model.lower():
-        return count_tokens_openai(messages, model)
-    # Check if the model is an Ollama model
-    elif "llama" in model.lower():
-        return count_tokens_ollama(messages, model)
-    # Fallback to cl100k_base encoding if the model is unrecognized
-    else:
-        encoding = tiktoken.get_encoding("cl100k_base")
-        total_tokens = 0
-        for message in messages:
-            total_tokens += len(encoding.encode(str(message.content)))
-        return total_tokens
+        return cls(
+            model_client=ChatCompletionClient.load_component(config.model_client),
+            token_limit=config.token_limit,
+            tool_schema=config.tool_schema,
+            initial_messages=config.initial_messages,
+        )
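
For context, because the config now embeds the model client as a nested component, serialization round-trips through the component system; a minimal sketch mirroring the declarative test in this commit (the API key is a dummy value, only used to build the config):

from autogen_core import ComponentLoader
from autogen_core.model_context import TokenLimitedChatCompletionContext
from autogen_ext.models.openai import OpenAIChatCompletionClient

model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="test_key")  # dummy key
context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=5)

# dump_component() serializes the model client as a nested ComponentModel.
config = context.dump_component()
assert config.config["token_limit"] == 5
assert config.config["model_client"]["provider"] == "autogen_ext.models.openai.OpenAIChatCompletionClient"

# Loading rebuilds the model client via ChatCompletionClient.load_component(...).
loaded = ComponentLoader.load_component(config, TokenLimitedChatCompletionContext)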

View File

@@ -9,7 +9,7 @@ from ._chat_completion_context import ChatCompletionContext
 
 
 class UnboundedChatCompletionContextConfig(BaseModel):
-    pass
+    initial_messages: List[LLMMessage] | None = None
 
 
 class UnboundedChatCompletionContext(ChatCompletionContext, Component[UnboundedChatCompletionContextConfig]):
@@ -23,8 +23,8 @@ class UnboundedChatCompletionContext(ChatCompletionContext, Component[UnboundedC
         return self._messages
 
     def _to_config(self) -> UnboundedChatCompletionContextConfig:
-        return UnboundedChatCompletionContextConfig()
+        return UnboundedChatCompletionContextConfig(initial_messages=self._initial_messages)
 
     @classmethod
     def _from_config(cls, config: UnboundedChatCompletionContextConfig) -> Self:
-        return cls()
+        return cls(initial_messages=config.initial_messages)

View File

@@ -2,7 +2,8 @@ import json
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
-from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar, cast, runtime_checkable
+from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypeVar, cast, runtime_checkable
+from typing_extensions import TypedDict
 
 import jsonref
 from opentelemetry.trace import get_tracer

View File

@@ -361,7 +361,6 @@ async def test_function_tool() -> None:
     await loaded_async.run_json({"x": 1.0, "y": 2.0}, cancelled_token)
 
 
-@pytest.mark.asyncio
 def test_component_descriptions() -> None:
     """Test different ways of setting component descriptions."""
     assert MyComponent("test").dump_component().description is None

View File

@@ -4,10 +4,18 @@ import pytest
 from autogen_core.model_context import (
     BufferedChatCompletionContext,
     HeadAndTailChatCompletionContext,
-    UnboundedChatCompletionContext,
     TokenLimitedChatCompletionContext,
+    UnboundedChatCompletionContext,
 )
-from autogen_core.models import AssistantMessage, LLMMessage, UserMessage, FunctionExecutionResultMessage
+from autogen_core.models import (
+    AssistantMessage,
+    ChatCompletionClient,
+    FunctionExecutionResultMessage,
+    LLMMessage,
+    UserMessage,
+)
+from autogen_ext.models.ollama import OllamaChatCompletionClient
+from autogen_ext.models.openai import OpenAIChatCompletionClient
 
 
 @pytest.mark.asyncio
@@ -108,8 +116,18 @@ async def test_unbounded_model_context() -> None:
 
 @pytest.mark.asyncio
-async def test_token_limited_model_context_openai() -> None:
-    model_context = TokenLimitedChatCompletionContext(token_limit=20, model="gpt-4o")
+@pytest.mark.parametrize(
+    "model_client,token_limit",
+    [
+        (OpenAIChatCompletionClient(model="gpt-4o", temperature=0.0, api_key="test"), 30),
+        (OllamaChatCompletionClient(model="llama3.3"), 20),
+    ],
+    ids=["openai", "ollama"],
+)
+async def test_token_limited_model_context_with_token_limit(
+    model_client: ChatCompletionClient, token_limit: int
+) -> None:
+    model_context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=token_limit)
     messages: List[LLMMessage] = [
         UserMessage(content="Hello!", source="user"),
         AssistantMessage(content="What can I do for you?", source="assistant"),
@@ -119,37 +137,7 @@ async def test_token_limited_model_context_openai() -> None:
         await model_context.add_message(msg)
     retrieved = await model_context.get_messages()
-    assert len(retrieved) == 2  # Token limit set very low, will remove 1 of the messages
-    assert retrieved != messages  # Will not be equal to the original messages
-    await model_context.clear()
-    retrieved = await model_context.get_messages()
-    assert len(retrieved) == 0
-
-    # Test saving and loading state.
-    for msg in messages:
-        await model_context.add_message(msg)
-    state = await model_context.save_state()
-    await model_context.clear()
-    await model_context.load_state(state)
-    retrieved = await model_context.get_messages()
-    assert len(retrieved) == 2
-    assert retrieved != messages
-
-
-@pytest.mark.asyncio
-async def test_token_limited_model_context_llama() -> None:
-    model_context = TokenLimitedChatCompletionContext(token_limit=20, model="llama2-7b")
-    messages: List[LLMMessage] = [
-        UserMessage(content="Hello!", source="user"),
-        AssistantMessage(content="What can I do for you?", source="assistant"),
-        UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
-    ]
-    for msg in messages:
-        await model_context.add_message(msg)
-    retrieved = await model_context.get_messages()
-    assert len(retrieved) == 1  # Token limit set very low, will remove two of the messages
+    assert len(retrieved) == 1  # Token limit set very low, will remove 2 of the messages
     assert retrieved != messages  # Will not be equal to the original messages
     await model_context.clear()
@@ -168,8 +156,41 @@ async def test_token_limited_model_context_llama() -> None:
 
 @pytest.mark.asyncio
-async def test_token_limited_model_context_openai_with_function_result() -> None:
-    model_context = TokenLimitedChatCompletionContext(token_limit=1000, model="gpt-4o")
+@pytest.mark.parametrize(
+    "model_client",
+    [
+        OpenAIChatCompletionClient(model="gpt-4o", temperature=0.0, api_key="test_key"),
+        OllamaChatCompletionClient(model="llama3.3"),
+    ],
+    ids=["openai", "ollama"],
+)
+async def test_token_limited_model_context_without_token_limit(model_client: ChatCompletionClient) -> None:
+    model_context = TokenLimitedChatCompletionContext(model_client=model_client)
+    messages: List[LLMMessage] = [
+        UserMessage(content="Hello!", source="user"),
+        AssistantMessage(content="What can I do for you?", source="assistant"),
+        UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
+    ]
+    for msg in messages:
+        await model_context.add_message(msg)
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 3
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_client,token_limit",
+    [
+        (OpenAIChatCompletionClient(model="gpt-4o", temperature=0.0, api_key="test"), 60),
+        (OllamaChatCompletionClient(model="llama3.3"), 50),
+    ],
+    ids=["openai", "ollama"],
+)
+async def test_token_limited_model_context_openai_with_function_result(
+    model_client: ChatCompletionClient, token_limit: int
+) -> None:
+    model_context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=token_limit)
     messages: List[LLMMessage] = [
         FunctionExecutionResultMessage(content=[]),
         UserMessage(content="Hello!", source="user"),