Fix token limited model context (#6137)

Token limited model context is currently broken because it imports from
the extensions package (autogen_ext).

This fix removes those imports and updates the model context
implementation to use the model client directly.

In the future, the model client's token counting should cache results
from the model API to provide accurate counting.
Eric Zhu 2025-03-28 10:24:41 -07:00 committed by GitHub
parent 0cd3ff46fa
commit e686342f53
10 changed files with 139 additions and 97 deletions
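For illustration, a minimal sketch of the updated API, mirroring the new tests in this change; the gpt-4o model name and API key are placeholders, and token counting is performed by the client (no completion request is needed for this sketch):

import asyncio

from autogen_core import ComponentLoader
from autogen_core.model_context import TokenLimitedChatCompletionContext
from autogen_core.models import UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="test_key")
    # The context now takes a model client instead of a model name string.
    context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=30)
    await context.add_message(UserMessage(content="Hello!", source="user"))
    print(await context.get_messages())

    # The model client is serialized as a nested component inside the context config.
    config = context.dump_component()
    print(config.config["model_client"]["provider"])  # autogen_ext.models.openai.OpenAIChatCompletionClient
    loaded = ComponentLoader.load_component(config, TokenLimitedChatCompletionContext)
    print(type(loaded).__name__)


asyncio.run(main())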

View File

@@ -14,9 +14,10 @@ from autogen_core import ComponentLoader, ComponentModel
from autogen_core.model_context import (
BufferedChatCompletionContext,
HeadAndTailChatCompletionContext,
UnboundedChatCompletionContext,
TokenLimitedChatCompletionContext,
UnboundedChatCompletionContext,
)
from autogen_ext.models.openai import OpenAIChatCompletionClient
@pytest.mark.asyncio
@@ -105,7 +106,8 @@ async def test_chat_completion_context_declarative() -> None:
unbounded_context = UnboundedChatCompletionContext()
buffered_context = BufferedChatCompletionContext(buffer_size=5)
head_tail_context = HeadAndTailChatCompletionContext(head_size=3, tail_size=2)
token_limited_context = TokenLimitedChatCompletionContext(token_limit=5, model="gpt-4o")
model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="test_key")
token_limited_context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=5)
# Test serialization
unbounded_config = unbounded_context.dump_component()
@@ -123,7 +125,10 @@ async def test_chat_completion_context_declarative() -> None:
token_limited_config = token_limited_context.dump_component()
assert token_limited_config.provider == "autogen_core.model_context.TokenLimitedChatCompletionContext"
assert token_limited_config.config["token_limit"] == 5
assert token_limited_config.config["model"] == "gpt-4o"
assert (
token_limited_config.config["model_client"]["provider"]
== "autogen_ext.models.openai.OpenAIChatCompletionClient"
)
# Test deserialization
loaded_unbounded = ComponentLoader.load_component(unbounded_config, UnboundedChatCompletionContext)

View File

@@ -1,7 +1,7 @@
from ._buffered_chat_completion_context import BufferedChatCompletionContext
from ._token_limited_chat_completion_context import TokenLimitedChatCompletionContext
from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState
from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
from ._token_limited_chat_completion_context import TokenLimitedChatCompletionContext
from ._unbounded_chat_completion_context import (
UnboundedChatCompletionContext,
)

View File

@@ -41,7 +41,9 @@ class BufferedChatCompletionContext(ChatCompletionContext, Component[BufferedCha
return messages
def _to_config(self) -> BufferedChatCompletionContextConfig:
return BufferedChatCompletionContextConfig(buffer_size=self._buffer_size, initial_messages=self._messages)
return BufferedChatCompletionContextConfig(
buffer_size=self._buffer_size, initial_messages=self._initial_messages
)
@classmethod
def _from_config(cls, config: BufferedChatCompletionContextConfig) -> Self:

View File

@@ -47,7 +47,10 @@ class ChatCompletionContext(ABC, ComponentBase[BaseModel]):
component_type = "chat_completion_context"
def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None:
self._messages: List[LLMMessage] = initial_messages or []
self._messages: List[LLMMessage] = []
if initial_messages is not None:
self._messages.extend(initial_messages)
self._initial_messages = initial_messages
async def add_message(self, message: LLMMessage) -> None:
"""Add a message to the context."""

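The base class now records the construction-time messages in a separate _initial_messages attribute, so the _to_config implementations in this change (buffered, head-and-tail, token-limited, unbounded) serialize what the context was created with rather than the current message list. A small sketch of that distinction, reusing the BufferedChatCompletionContext from this diff:

import asyncio

from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.models import UserMessage


async def main() -> None:
    context = BufferedChatCompletionContext(
        buffer_size=5,
        initial_messages=[UserMessage(content="Hi", source="user")],
    )
    await context.add_message(UserMessage(content="A later message", source="user"))
    config = context.dump_component()
    # Only the construction-time message is recorded in the serialized config,
    # not the message added afterwards.
    print(len(config.config["initial_messages"]))  # 1


asyncio.run(main())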
View File

@@ -68,7 +68,7 @@ class HeadAndTailChatCompletionContext(ChatCompletionContext, Component[HeadAndT
def _to_config(self) -> HeadAndTailChatCompletionContextConfig:
return HeadAndTailChatCompletionContextConfig(
head_size=self._head_size, tail_size=self._tail_size, initial_messages=self._messages
head_size=self._head_size, tail_size=self._tail_size, initial_messages=self._initial_messages
)
@classmethod

View File

@@ -1,83 +1,94 @@
from typing import List, Sequence
from autogen_core.tools import Tool, ToolSchema
from typing import List
from pydantic import BaseModel
from typing_extensions import Self
import tiktoken
from .._component_config import Component
from ..models import FunctionExecutionResultMessage, LLMMessage
from .._component_config import Component, ComponentModel
from ..models import ChatCompletionClient, FunctionExecutionResultMessage, LLMMessage
from ..tools import ToolSchema
from ._chat_completion_context import ChatCompletionContext
from autogen_ext.models.ollama._ollama_client import count_tokens_ollama
from autogen_ext.models.openai._openai_client import count_tokens_openai
class TokenLimitedChatCompletionContextConfig(BaseModel):
token_limit: int
model: str
model_client: ComponentModel
token_limit: int | None = None
tool_schema: List[ToolSchema] | None = None
initial_messages: List[LLMMessage] | None = None
class TokenLimitedChatCompletionContext(ChatCompletionContext, Component[TokenLimitedChatCompletionContextConfig]):
"""A token based chat completion context maintains a view of the context up to a token limit,
where n is the token limit. The token limit is set at initialization.
"""(Experimental) A token based chat completion context maintains a view of the context up to a token limit.
.. note::
Added in v0.4.10. This is an experimental component and may change in the future.
Args:
token_limit (int): Max tokens for context.
initial_messages (List[LLMMessage] | None): The initial messages.
model_client (ChatCompletionClient): The model client to use for token counting.
The model client must implement the :meth:`~autogen_core.models.ChatCompletionClient.count_tokens`
and :meth:`~autogen_core.models.ChatCompletionClient.remaining_tokens` methods.
token_limit (int | None): The maximum number of tokens to keep in the context
using the :meth:`~autogen_core.models.ChatCompletionClient.count_tokens` method.
If None, the context will be limited by the model client using the
:meth:`~autogen_core.models.ChatCompletionClient.remaining_tokens` method.
tools (List[ToolSchema] | None): A list of tool schema to use in the context.
initial_messages (List[LLMMessage] | None): A list of initial messages to include in the context.
"""
component_config_schema = TokenLimitedChatCompletionContextConfig
component_provider_override = "autogen_core.model_context.TokenLimitedChatCompletionContext"
def __init__(self, token_limit: int, model: str, initial_messages: List[LLMMessage] | None = None) -> None:
def __init__(
self,
model_client: ChatCompletionClient,
*,
token_limit: int | None = None,
tool_schema: List[ToolSchema] | None = None,
initial_messages: List[LLMMessage] | None = None,
) -> None:
super().__init__(initial_messages)
if token_limit <= 0:
if token_limit is not None and token_limit <= 0:
raise ValueError("token_limit must be greater than 0.")
self._token_limit = token_limit
self._model = model
self._model_client = model_client
self._tool_schema = tool_schema or []
async def get_messages(self) -> List[LLMMessage]:
"""Get at most `token_limit` tokens in recent messages."""
token_count = count_chat_tokens(self._messages, self._model)
while token_count > self._token_limit:
middle_index = len(self._messages) // 2
self._messages.pop(middle_index)
token_count = count_chat_tokens(self._messages, self._model)
messages = self._messages
# Handle the first message is a function call result message.
"""Get at most `token_limit` tokens in recent messages. If the token limit is not
provided, then return as many messages as the remaining token allowed by the model client."""
messages = list(self._messages)
if self._token_limit is None:
remaining_tokens = self._model_client.remaining_tokens(messages, tools=self._tool_schema)
while remaining_tokens < 0 and len(messages) > 0:
middle_index = len(messages) // 2
messages.pop(middle_index)
remaining_tokens = self._model_client.remaining_tokens(messages, tools=self._tool_schema)
else:
token_count = self._model_client.count_tokens(messages, tools=self._tool_schema)
while token_count > self._token_limit and len(messages) > 0:
middle_index = len(messages) // 2
messages.pop(middle_index)
token_count = self._model_client.count_tokens(messages, tools=self._tool_schema)
if messages and isinstance(messages[0], FunctionExecutionResultMessage):
# Handle the first message is a function call result message.
# Remove the first message from the list.
messages = messages[1:]
return messages
def _to_config(self) -> TokenLimitedChatCompletionContextConfig:
return TokenLimitedChatCompletionContextConfig(
token_limit=self._token_limit, model=self._model, initial_messages=self._messages
model_client=self._model_client.dump_component(),
token_limit=self._token_limit,
tool_schema=self._tool_schema,
initial_messages=self._initial_messages,
)
@classmethod
def _from_config(cls, config: TokenLimitedChatCompletionContextConfig) -> Self:
return cls(**config.model_dump())
def count_chat_tokens(
messages: Sequence[LLMMessage], model: str = "gpt-4o", *, tools: Sequence[Tool | ToolSchema] = []
) -> int:
"""Count tokens for a list of messages using the appropriate client based on the model."""
# Check if the model is an OpenAI model
if "openai" in model.lower():
return count_tokens_openai(messages, model)
# Check if the model is an Ollama model
elif "llama" in model.lower():
return count_tokens_ollama(messages, model)
# Fallback to cl100k_base encoding if the model is unrecognized
else:
encoding = tiktoken.get_encoding("cl100k_base")
total_tokens = 0
for message in messages:
total_tokens += len(encoding.encode(str(message.content)))
return total_tokens
return cls(
model_client=ChatCompletionClient.load_component(config.model_client),
token_limit=config.token_limit,
tool_schema=config.tool_schema,
initial_messages=config.initial_messages,
)
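For orientation, a minimal sketch of the two modes described in the docstring above, reusing the messages from the updated tests; the API key is a placeholder and the exact counts depend on the client's token counting:

import asyncio
from typing import List

from autogen_core.model_context import TokenLimitedChatCompletionContext
from autogen_core.models import AssistantMessage, LLMMessage, UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    client = OpenAIChatCompletionClient(model="gpt-4o", api_key="test_key")
    messages: List[LLMMessage] = [
        UserMessage(content="Hello!", source="user"),
        AssistantMessage(content="What can I do for you?", source="assistant"),
        UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
    ]

    # With an explicit token_limit, middle messages are dropped until
    # client.count_tokens(messages) fits under the limit.
    limited = TokenLimitedChatCompletionContext(model_client=client, token_limit=30)
    for msg in messages:
        await limited.add_message(msg)
    print(len(await limited.get_messages()))  # 2 in the updated test with this limit

    # Without a token_limit, the view is bounded by client.remaining_tokens(messages).
    unbounded_by_limit = TokenLimitedChatCompletionContext(model_client=client)
    for msg in messages:
        await unbounded_by_limit.add_message(msg)
    print(len(await unbounded_by_limit.get_messages()))  # 3, everything fits


asyncio.run(main())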

View File

@@ -9,7 +9,7 @@ from ._chat_completion_context import ChatCompletionContext
class UnboundedChatCompletionContextConfig(BaseModel):
pass
initial_messages: List[LLMMessage] | None = None
class UnboundedChatCompletionContext(ChatCompletionContext, Component[UnboundedChatCompletionContextConfig]):
@@ -23,8 +23,8 @@ class UnboundedChatCompletionContext(ChatCompletionContext, Component[UnboundedC
return self._messages
def _to_config(self) -> UnboundedChatCompletionContextConfig:
return UnboundedChatCompletionContextConfig()
return UnboundedChatCompletionContextConfig(initial_messages=self._initial_messages)
@classmethod
def _from_config(cls, config: UnboundedChatCompletionContextConfig) -> Self:
return cls()
return cls(initial_messages=config.initial_messages)

View File

@@ -2,7 +2,8 @@ import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar, cast, runtime_checkable
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypeVar, cast, runtime_checkable
from typing_extensions import TypedDict
import jsonref
from opentelemetry.trace import get_tracer

View File

@@ -361,7 +361,6 @@ async def test_function_tool() -> None:
await loaded_async.run_json({"x": 1.0, "y": 2.0}, cancelled_token)
@pytest.mark.asyncio
def test_component_descriptions() -> None:
"""Test different ways of setting component descriptions."""
assert MyComponent("test").dump_component().description is None

View File

@@ -4,10 +4,18 @@ import pytest
from autogen_core.model_context import (
BufferedChatCompletionContext,
HeadAndTailChatCompletionContext,
UnboundedChatCompletionContext,
TokenLimitedChatCompletionContext,
UnboundedChatCompletionContext,
)
from autogen_core.models import AssistantMessage, LLMMessage, UserMessage, FunctionExecutionResultMessage
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
FunctionExecutionResultMessage,
LLMMessage,
UserMessage,
)
from autogen_ext.models.ollama import OllamaChatCompletionClient
from autogen_ext.models.openai import OpenAIChatCompletionClient
@pytest.mark.asyncio
@@ -108,8 +116,18 @@ async def test_unbounded_model_context() -> None:
@pytest.mark.asyncio
async def test_token_limited_model_context_openai() -> None:
model_context = TokenLimitedChatCompletionContext(token_limit=20, model="gpt-4o")
@pytest.mark.parametrize(
"model_client,token_limit",
[
(OpenAIChatCompletionClient(model="gpt-4o", temperature=0.0, api_key="test"), 30),
(OllamaChatCompletionClient(model="llama3.3"), 20),
],
ids=["openai", "ollama"],
)
async def test_token_limited_model_context_with_token_limit(
model_client: ChatCompletionClient, token_limit: int
) -> None:
model_context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=token_limit)
messages: List[LLMMessage] = [
UserMessage(content="Hello!", source="user"),
AssistantMessage(content="What can I do for you?", source="assistant"),
@@ -119,37 +137,7 @@ async def test_token_limited_model_context_openai() -> None:
await model_context.add_message(msg)
retrieved = await model_context.get_messages()
assert len(retrieved) == 2 # Token limit set very low, will remove 1 of the messages
assert retrieved != messages # Will not be equal to the original messages
await model_context.clear()
retrieved = await model_context.get_messages()
assert len(retrieved) == 0
# Test saving and loading state.
for msg in messages:
await model_context.add_message(msg)
state = await model_context.save_state()
await model_context.clear()
await model_context.load_state(state)
retrieved = await model_context.get_messages()
assert len(retrieved) == 2
assert retrieved != messages
@pytest.mark.asyncio
async def test_token_limited_model_context_llama() -> None:
model_context = TokenLimitedChatCompletionContext(token_limit=20, model="llama2-7b")
messages: List[LLMMessage] = [
UserMessage(content="Hello!", source="user"),
AssistantMessage(content="What can I do for you?", source="assistant"),
UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
]
for msg in messages:
await model_context.add_message(msg)
retrieved = await model_context.get_messages()
assert len(retrieved) == 1 # Token limit set very low, will remove two of the messages
assert len(retrieved) == 1 # Token limit set very low, will remove 2 of the messages
assert retrieved != messages # Will not be equal to the original messages
await model_context.clear()
@@ -168,8 +156,41 @@ async def test_token_limited_model_context_llama() -> None:
@pytest.mark.asyncio
async def test_token_limited_model_context_openai_with_function_result() -> None:
model_context = TokenLimitedChatCompletionContext(token_limit=1000, model="gpt-4o")
@pytest.mark.parametrize(
"model_client",
[
OpenAIChatCompletionClient(model="gpt-4o", temperature=0.0, api_key="test_key"),
OllamaChatCompletionClient(model="llama3.3"),
],
ids=["openai", "ollama"],
)
async def test_token_limited_model_context_without_token_limit(model_client: ChatCompletionClient) -> None:
model_context = TokenLimitedChatCompletionContext(model_client=model_client)
messages: List[LLMMessage] = [
UserMessage(content="Hello!", source="user"),
AssistantMessage(content="What can I do for you?", source="assistant"),
UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
]
for msg in messages:
await model_context.add_message(msg)
retrieved = await model_context.get_messages()
assert len(retrieved) == 3
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_client,token_limit",
[
(OpenAIChatCompletionClient(model="gpt-4o", temperature=0.0, api_key="test"), 60),
(OllamaChatCompletionClient(model="llama3.3"), 50),
],
ids=["openai", "ollama"],
)
async def test_token_limited_model_context_openai_with_function_result(
model_client: ChatCompletionClient, token_limit: int
) -> None:
model_context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=token_limit)
messages: List[LLMMessage] = [
FunctionExecutionResultMessage(content=[]),
UserMessage(content="Hello!", source="user"),