mirror of https://github.com/microsoft/autogen.git
Fix token limited model context (#6137)
Token limited model context is currently broken because it is importing from extensions. This fix removed the imports and updated the model context implementation to use model client directly. In the future, the model client's token counting should cache results from model API to provide accurate counting.
This commit is contained in:
parent
0cd3ff46fa
commit
e686342f53
|
@ -14,9 +14,10 @@ from autogen_core import ComponentLoader, ComponentModel
|
|||
from autogen_core.model_context import (
|
||||
BufferedChatCompletionContext,
|
||||
HeadAndTailChatCompletionContext,
|
||||
UnboundedChatCompletionContext,
|
||||
TokenLimitedChatCompletionContext,
|
||||
UnboundedChatCompletionContext,
|
||||
)
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -105,7 +106,8 @@ async def test_chat_completion_context_declarative() -> None:
|
|||
unbounded_context = UnboundedChatCompletionContext()
|
||||
buffered_context = BufferedChatCompletionContext(buffer_size=5)
|
||||
head_tail_context = HeadAndTailChatCompletionContext(head_size=3, tail_size=2)
|
||||
token_limited_context = TokenLimitedChatCompletionContext(token_limit=5, model="gpt-4o")
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="test_key")
|
||||
token_limited_context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=5)
|
||||
|
||||
# Test serialization
|
||||
unbounded_config = unbounded_context.dump_component()
|
||||
|
@ -123,7 +125,10 @@ async def test_chat_completion_context_declarative() -> None:
|
|||
token_limited_config = token_limited_context.dump_component()
|
||||
assert token_limited_config.provider == "autogen_core.model_context.TokenLimitedChatCompletionContext"
|
||||
assert token_limited_config.config["token_limit"] == 5
|
||||
assert token_limited_config.config["model"] == "gpt-4o"
|
||||
assert (
|
||||
token_limited_config.config["model_client"]["provider"]
|
||||
== "autogen_ext.models.openai.OpenAIChatCompletionClient"
|
||||
)
|
||||
|
||||
# Test deserialization
|
||||
loaded_unbounded = ComponentLoader.load_component(unbounded_config, UnboundedChatCompletionContext)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from ._buffered_chat_completion_context import BufferedChatCompletionContext
|
||||
from ._token_limited_chat_completion_context import TokenLimitedChatCompletionContext
|
||||
from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState
|
||||
from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
|
||||
from ._token_limited_chat_completion_context import TokenLimitedChatCompletionContext
|
||||
from ._unbounded_chat_completion_context import (
|
||||
UnboundedChatCompletionContext,
|
||||
)
|
||||
|
|
|
@ -41,7 +41,9 @@ class BufferedChatCompletionContext(ChatCompletionContext, Component[BufferedCha
|
|||
return messages
|
||||
|
||||
def _to_config(self) -> BufferedChatCompletionContextConfig:
|
||||
return BufferedChatCompletionContextConfig(buffer_size=self._buffer_size, initial_messages=self._messages)
|
||||
return BufferedChatCompletionContextConfig(
|
||||
buffer_size=self._buffer_size, initial_messages=self._initial_messages
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: BufferedChatCompletionContextConfig) -> Self:
|
||||
|
|
|
@ -47,7 +47,10 @@ class ChatCompletionContext(ABC, ComponentBase[BaseModel]):
|
|||
component_type = "chat_completion_context"
|
||||
|
||||
def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None:
|
||||
self._messages: List[LLMMessage] = initial_messages or []
|
||||
self._messages: List[LLMMessage] = []
|
||||
if initial_messages is not None:
|
||||
self._messages.extend(initial_messages)
|
||||
self._initial_messages = initial_messages
|
||||
|
||||
async def add_message(self, message: LLMMessage) -> None:
|
||||
"""Add a message to the context."""
|
||||
|
|
|
@ -68,7 +68,7 @@ class HeadAndTailChatCompletionContext(ChatCompletionContext, Component[HeadAndT
|
|||
|
||||
def _to_config(self) -> HeadAndTailChatCompletionContextConfig:
|
||||
return HeadAndTailChatCompletionContextConfig(
|
||||
head_size=self._head_size, tail_size=self._tail_size, initial_messages=self._messages
|
||||
head_size=self._head_size, tail_size=self._tail_size, initial_messages=self._initial_messages
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -1,83 +1,94 @@
|
|||
from typing import List, Sequence
|
||||
from autogen_core.tools import Tool, ToolSchema
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
import tiktoken
|
||||
|
||||
from .._component_config import Component
|
||||
from ..models import FunctionExecutionResultMessage, LLMMessage
|
||||
from .._component_config import Component, ComponentModel
|
||||
from ..models import ChatCompletionClient, FunctionExecutionResultMessage, LLMMessage
|
||||
from ..tools import ToolSchema
|
||||
from ._chat_completion_context import ChatCompletionContext
|
||||
|
||||
from autogen_ext.models.ollama._ollama_client import count_tokens_ollama
|
||||
from autogen_ext.models.openai._openai_client import count_tokens_openai
|
||||
|
||||
|
||||
class TokenLimitedChatCompletionContextConfig(BaseModel):
|
||||
token_limit: int
|
||||
model: str
|
||||
model_client: ComponentModel
|
||||
token_limit: int | None = None
|
||||
tool_schema: List[ToolSchema] | None = None
|
||||
initial_messages: List[LLMMessage] | None = None
|
||||
|
||||
|
||||
class TokenLimitedChatCompletionContext(ChatCompletionContext, Component[TokenLimitedChatCompletionContextConfig]):
|
||||
"""A token based chat completion context maintains a view of the context up to a token limit,
|
||||
where n is the token limit. The token limit is set at initialization.
|
||||
"""(Experimental) A token based chat completion context maintains a view of the context up to a token limit.
|
||||
|
||||
.. note::
|
||||
|
||||
Added in v0.4.10. This is an experimental component and may change in the future.
|
||||
|
||||
Args:
|
||||
token_limit (int): Max tokens for context.
|
||||
initial_messages (List[LLMMessage] | None): The initial messages.
|
||||
model_client (ChatCompletionClient): The model client to use for token counting.
|
||||
The model client must implement the :meth:`~autogen_core.models.ChatCompletionClient.count_tokens`
|
||||
and :meth:`~autogen_core.models.ChatCompletionClient.remaining_tokens` methods.
|
||||
token_limit (int | None): The maximum number of tokens to keep in the context
|
||||
using the :meth:`~autogen_core.models.ChatCompletionClient.count_tokens` method.
|
||||
If None, the context will be limited by the model client using the
|
||||
:meth:`~autogen_core.models.ChatCompletionClient.remaining_tokens` method.
|
||||
tools (List[ToolSchema] | None): A list of tool schema to use in the context.
|
||||
initial_messages (List[LLMMessage] | None): A list of initial messages to include in the context.
|
||||
|
||||
"""
|
||||
|
||||
component_config_schema = TokenLimitedChatCompletionContextConfig
|
||||
component_provider_override = "autogen_core.model_context.TokenLimitedChatCompletionContext"
|
||||
|
||||
def __init__(self, token_limit: int, model: str, initial_messages: List[LLMMessage] | None = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
model_client: ChatCompletionClient,
|
||||
*,
|
||||
token_limit: int | None = None,
|
||||
tool_schema: List[ToolSchema] | None = None,
|
||||
initial_messages: List[LLMMessage] | None = None,
|
||||
) -> None:
|
||||
super().__init__(initial_messages)
|
||||
if token_limit <= 0:
|
||||
if token_limit is not None and token_limit <= 0:
|
||||
raise ValueError("token_limit must be greater than 0.")
|
||||
self._token_limit = token_limit
|
||||
self._model = model
|
||||
self._model_client = model_client
|
||||
self._tool_schema = tool_schema or []
|
||||
|
||||
async def get_messages(self) -> List[LLMMessage]:
|
||||
"""Get at most `token_limit` tokens in recent messages."""
|
||||
token_count = count_chat_tokens(self._messages, self._model)
|
||||
while token_count > self._token_limit:
|
||||
middle_index = len(self._messages) // 2
|
||||
self._messages.pop(middle_index)
|
||||
token_count = count_chat_tokens(self._messages, self._model)
|
||||
messages = self._messages
|
||||
# Handle the first message is a function call result message.
|
||||
"""Get at most `token_limit` tokens in recent messages. If the token limit is not
|
||||
provided, then return as many messages as the remaining token allowed by the model client."""
|
||||
messages = list(self._messages)
|
||||
if self._token_limit is None:
|
||||
remaining_tokens = self._model_client.remaining_tokens(messages, tools=self._tool_schema)
|
||||
while remaining_tokens < 0 and len(messages) > 0:
|
||||
middle_index = len(messages) // 2
|
||||
messages.pop(middle_index)
|
||||
remaining_tokens = self._model_client.remaining_tokens(messages, tools=self._tool_schema)
|
||||
else:
|
||||
token_count = self._model_client.count_tokens(messages, tools=self._tool_schema)
|
||||
while token_count > self._token_limit and len(messages) > 0:
|
||||
middle_index = len(messages) // 2
|
||||
messages.pop(middle_index)
|
||||
token_count = self._model_client.count_tokens(messages, tools=self._tool_schema)
|
||||
if messages and isinstance(messages[0], FunctionExecutionResultMessage):
|
||||
# Handle the first message is a function call result message.
|
||||
# Remove the first message from the list.
|
||||
messages = messages[1:]
|
||||
return messages
|
||||
|
||||
def _to_config(self) -> TokenLimitedChatCompletionContextConfig:
|
||||
return TokenLimitedChatCompletionContextConfig(
|
||||
token_limit=self._token_limit, model=self._model, initial_messages=self._messages
|
||||
model_client=self._model_client.dump_component(),
|
||||
token_limit=self._token_limit,
|
||||
tool_schema=self._tool_schema,
|
||||
initial_messages=self._initial_messages,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: TokenLimitedChatCompletionContextConfig) -> Self:
|
||||
return cls(**config.model_dump())
|
||||
|
||||
|
||||
def count_chat_tokens(
|
||||
messages: Sequence[LLMMessage], model: str = "gpt-4o", *, tools: Sequence[Tool | ToolSchema] = []
|
||||
) -> int:
|
||||
"""Count tokens for a list of messages using the appropriate client based on the model."""
|
||||
# Check if the model is an OpenAI model
|
||||
if "openai" in model.lower():
|
||||
return count_tokens_openai(messages, model)
|
||||
|
||||
# Check if the model is an Ollama model
|
||||
elif "llama" in model.lower():
|
||||
return count_tokens_ollama(messages, model)
|
||||
|
||||
# Fallback to cl100k_base encoding if the model is unrecognized
|
||||
else:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
total_tokens = 0
|
||||
for message in messages:
|
||||
total_tokens += len(encoding.encode(str(message.content)))
|
||||
return total_tokens
|
||||
return cls(
|
||||
model_client=ChatCompletionClient.load_component(config.model_client),
|
||||
token_limit=config.token_limit,
|
||||
tool_schema=config.tool_schema,
|
||||
initial_messages=config.initial_messages,
|
||||
)
|
||||
|
|
|
@ -9,7 +9,7 @@ from ._chat_completion_context import ChatCompletionContext
|
|||
|
||||
|
||||
class UnboundedChatCompletionContextConfig(BaseModel):
|
||||
pass
|
||||
initial_messages: List[LLMMessage] | None = None
|
||||
|
||||
|
||||
class UnboundedChatCompletionContext(ChatCompletionContext, Component[UnboundedChatCompletionContextConfig]):
|
||||
|
@ -23,8 +23,8 @@ class UnboundedChatCompletionContext(ChatCompletionContext, Component[UnboundedC
|
|||
return self._messages
|
||||
|
||||
def _to_config(self) -> UnboundedChatCompletionContextConfig:
|
||||
return UnboundedChatCompletionContextConfig()
|
||||
return UnboundedChatCompletionContextConfig(initial_messages=self._initial_messages)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: UnboundedChatCompletionContextConfig) -> Self:
|
||||
return cls()
|
||||
return cls(initial_messages=config.initial_messages)
|
||||
|
|
|
@ -2,7 +2,8 @@ import json
|
|||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar, cast, runtime_checkable
|
||||
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypeVar, cast, runtime_checkable
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import jsonref
|
||||
from opentelemetry.trace import get_tracer
|
||||
|
|
|
@ -361,7 +361,6 @@ async def test_function_tool() -> None:
|
|||
await loaded_async.run_json({"x": 1.0, "y": 2.0}, cancelled_token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_component_descriptions() -> None:
|
||||
"""Test different ways of setting component descriptions."""
|
||||
assert MyComponent("test").dump_component().description is None
|
||||
|
|
|
@ -4,10 +4,18 @@ import pytest
|
|||
from autogen_core.model_context import (
|
||||
BufferedChatCompletionContext,
|
||||
HeadAndTailChatCompletionContext,
|
||||
UnboundedChatCompletionContext,
|
||||
TokenLimitedChatCompletionContext,
|
||||
UnboundedChatCompletionContext,
|
||||
)
|
||||
from autogen_core.models import AssistantMessage, LLMMessage, UserMessage, FunctionExecutionResultMessage
|
||||
from autogen_core.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from autogen_ext.models.ollama import OllamaChatCompletionClient
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -108,8 +116,18 @@ async def test_unbounded_model_context() -> None:
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_limited_model_context_openai() -> None:
|
||||
model_context = TokenLimitedChatCompletionContext(token_limit=20, model="gpt-4o")
|
||||
@pytest.mark.parametrize(
|
||||
"model_client,token_limit",
|
||||
[
|
||||
(OpenAIChatCompletionClient(model="gpt-4o", temperature=0.0, api_key="test"), 30),
|
||||
(OllamaChatCompletionClient(model="llama3.3"), 20),
|
||||
],
|
||||
ids=["openai", "ollama"],
|
||||
)
|
||||
async def test_token_limited_model_context_with_token_limit(
|
||||
model_client: ChatCompletionClient, token_limit: int
|
||||
) -> None:
|
||||
model_context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=token_limit)
|
||||
messages: List[LLMMessage] = [
|
||||
UserMessage(content="Hello!", source="user"),
|
||||
AssistantMessage(content="What can I do for you?", source="assistant"),
|
||||
|
@ -119,37 +137,7 @@ async def test_token_limited_model_context_openai() -> None:
|
|||
await model_context.add_message(msg)
|
||||
|
||||
retrieved = await model_context.get_messages()
|
||||
assert len(retrieved) == 2 # Token limit set very low, will remove 1 of the messages
|
||||
assert retrieved != messages # Will not be equal to the original messages
|
||||
|
||||
await model_context.clear()
|
||||
retrieved = await model_context.get_messages()
|
||||
assert len(retrieved) == 0
|
||||
|
||||
# Test saving and loading state.
|
||||
for msg in messages:
|
||||
await model_context.add_message(msg)
|
||||
state = await model_context.save_state()
|
||||
await model_context.clear()
|
||||
await model_context.load_state(state)
|
||||
retrieved = await model_context.get_messages()
|
||||
assert len(retrieved) == 2
|
||||
assert retrieved != messages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_limited_model_context_llama() -> None:
|
||||
model_context = TokenLimitedChatCompletionContext(token_limit=20, model="llama2-7b")
|
||||
messages: List[LLMMessage] = [
|
||||
UserMessage(content="Hello!", source="user"),
|
||||
AssistantMessage(content="What can I do for you?", source="assistant"),
|
||||
UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
|
||||
]
|
||||
for msg in messages:
|
||||
await model_context.add_message(msg)
|
||||
|
||||
retrieved = await model_context.get_messages()
|
||||
assert len(retrieved) == 1 # Token limit set very low, will remove two of the messages
|
||||
assert len(retrieved) == 1 # Token limit set very low, will remove 2 of the messages
|
||||
assert retrieved != messages # Will not be equal to the original messages
|
||||
|
||||
await model_context.clear()
|
||||
|
@ -168,8 +156,41 @@ async def test_token_limited_model_context_llama() -> None:
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_limited_model_context_openai_with_function_result() -> None:
|
||||
model_context = TokenLimitedChatCompletionContext(token_limit=1000, model="gpt-4o")
|
||||
@pytest.mark.parametrize(
|
||||
"model_client",
|
||||
[
|
||||
OpenAIChatCompletionClient(model="gpt-4o", temperature=0.0, api_key="test_key"),
|
||||
OllamaChatCompletionClient(model="llama3.3"),
|
||||
],
|
||||
ids=["openai", "ollama"],
|
||||
)
|
||||
async def test_token_limited_model_context_without_token_limit(model_client: ChatCompletionClient) -> None:
|
||||
model_context = TokenLimitedChatCompletionContext(model_client=model_client)
|
||||
messages: List[LLMMessage] = [
|
||||
UserMessage(content="Hello!", source="user"),
|
||||
AssistantMessage(content="What can I do for you?", source="assistant"),
|
||||
UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
|
||||
]
|
||||
for msg in messages:
|
||||
await model_context.add_message(msg)
|
||||
|
||||
retrieved = await model_context.get_messages()
|
||||
assert len(retrieved) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_client,token_limit",
|
||||
[
|
||||
(OpenAIChatCompletionClient(model="gpt-4o", temperature=0.0, api_key="test"), 60),
|
||||
(OllamaChatCompletionClient(model="llama3.3"), 50),
|
||||
],
|
||||
ids=["openai", "ollama"],
|
||||
)
|
||||
async def test_token_limited_model_context_openai_with_function_result(
|
||||
model_client: ChatCompletionClient, token_limit: int
|
||||
) -> None:
|
||||
model_context = TokenLimitedChatCompletionContext(model_client=model_client, token_limit=token_limit)
|
||||
messages: List[LLMMessage] = [
|
||||
FunctionExecutionResultMessage(content=[]),
|
||||
UserMessage(content="Hello!", source="user"),
|
||||
|
|
Loading…
Reference in New Issue