Add ChatCompletionCache along with AbstractStore for caching completions (#4924)

* Add ChatCompletionCache along with AbstractStore for caching completions

* Addressing comments

* Improve interface for cachestore

* Improve documentation & revert protocol

* Make cache store typed, and improve docs

* remove unnecessary casts
This commit is contained in:
Sachin Joglekar 2025-01-16 15:47:38 -08:00 committed by GitHub
parent 2e1a9c737a
commit 8bd65c672f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 802 additions and 18 deletions

View File

@ -48,6 +48,7 @@ python/autogen_ext.agents.video_surfer
python/autogen_ext.agents.video_surfer.tools
python/autogen_ext.auth.azure
python/autogen_ext.teams.magentic_one
python/autogen_ext.models.cache
python/autogen_ext.models.openai
python/autogen_ext.models.replay
python/autogen_ext.tools.langchain
@ -56,5 +57,7 @@ python/autogen_ext.tools.code_execution
python/autogen_ext.code_executors.local
python/autogen_ext.code_executors.docker
python/autogen_ext.code_executors.azure
python/autogen_ext.cache_store.diskcache
python/autogen_ext.cache_store.redis
python/autogen_ext.runtimes.grpc
```

View File

@ -0,0 +1,8 @@
autogen\_ext.cache_store.diskcache
==================================
.. automodule:: autogen_ext.cache_store.diskcache
:members:
:undoc-members:
:show-inheritance:

View File

@ -0,0 +1,8 @@
autogen\_ext.cache_store.redis
==============================
.. automodule:: autogen_ext.cache_store.redis
:members:
:undoc-members:
:show-inheritance:

View File

@ -0,0 +1,8 @@
autogen\_ext.models.cache
=========================
.. automodule:: autogen_ext.models.cache
:members:
:undoc-members:
:show-inheritance:

View File

@ -1,8 +1,8 @@
autogen\_ext.models.replay
==========================
.. automodule:: autogen_ext.models.replay
:members:
:undoc-members:
:show-inheritance:
autogen\_ext.models.replay
==========================
.. automodule:: autogen_ext.models.replay
:members:
:undoc-members:
:show-inheritance:

View File

@ -6,7 +6,11 @@
"source": [
"# Models\n",
"\n",
"In many cases, agents need access to LLM model services such as OpenAI, Azure OpenAI, or local models. Since there are many different providers with different APIs, `autogen-core` implements a protocol for [model clients](../../core-user-guide/framework/model-clients.ipynb) and `autogen-ext` implements a set of model clients for popular model services. AgentChat can use these model clients to interact with model services. "
"In many cases, agents need access to LLM model services such as OpenAI, Azure OpenAI, or local models. Since there are many different providers with different APIs, `autogen-core` implements a protocol for [model clients](../../core-user-guide/framework/model-clients.ipynb) and `autogen-ext` implements a set of model clients for popular model services. AgentChat can use these model clients to interact with model services. \n",
"\n",
"```{note}\n",
"See {py:class}`~autogen_ext.models.cache.ChatCompletionCache` for a caching wrapper to use with the following clients.\n",
"```"
]
},
{

View File

@ -96,7 +96,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Default [Model Capabilities](../faqs.md#what-are-model-capabilities-and-how-do-i-specify-them) may be overridden should the need arise.\n",
"Default [Model Capabilities](../faqs.md#what-are-model-capabilities-and-how-do-i-specify-them) may be overridden should the need arise.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"\n",
"### Streaming Response\n",
@ -315,6 +321,84 @@
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Caching Wrapper\n",
"\n",
"`autogen_ext` implements {py:class}`~autogen_ext.models.cache.ChatCompletionCache` that can wrap any {py:class}`~autogen_core.models.ChatCompletionClient`. Using this wrapper avoids incurring token usage when querying the underlying client with the same prompt multiple times.\n",
"\n",
"{py:class}`~autogen_core.models.ChatCompletionCache` uses a {py:class}`~autogen_core.CacheStore` protocol. We have implemented some useful variants of {py:class}`~autogen_core.CacheStore` including {py:class}`~autogen_ext.cache_store.diskcache.DiskCacheStore` and {py:class}`~autogen_ext.cache_store.redis.RedisStore`.\n",
"\n",
"Here's an example of using `diskcache` for local caching:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# pip install -U \"autogen-ext[openai, diskcache]\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n"
]
}
],
"source": [
"import asyncio\n",
"import tempfile\n",
"\n",
"from autogen_core.models import UserMessage\n",
"from autogen_ext.cache_store.diskcache import DiskCacheStore\n",
"from autogen_ext.models.cache import CHAT_CACHE_VALUE_TYPE, ChatCompletionCache\n",
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
"from diskcache import Cache\n",
"\n",
"\n",
"async def main() -> None:\n",
" with tempfile.TemporaryDirectory() as tmpdirname:\n",
" # Initialize the original client\n",
" openai_model_client = OpenAIChatCompletionClient(model=\"gpt-4o\")\n",
"\n",
" # Then initialize the CacheStore, in this case with diskcache.Cache.\n",
" # You can also use redis like:\n",
" # from autogen_ext.cache_store.redis import RedisStore\n",
" # import redis\n",
" # redis_instance = redis.Redis()\n",
" # cache_store = RedisCacheStore[CHAT_CACHE_VALUE_TYPE](redis_instance)\n",
" cache_store = DiskCacheStore[CHAT_CACHE_VALUE_TYPE](Cache(tmpdirname))\n",
" cache_client = ChatCompletionCache(openai_model_client, cache_store)\n",
"\n",
" response = await cache_client.create([UserMessage(content=\"Hello, how are you?\", source=\"user\")])\n",
" print(response) # Should print response from OpenAI\n",
" response = await cache_client.create([UserMessage(content=\"Hello, how are you?\", source=\"user\")])\n",
" print(response) # Should print cached response\n",
"\n",
"\n",
"asyncio.run(main())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Inspecting `cached_client.total_usage()` (or `model_client.total_usage()`) before and after a cached response should yield idential counts.\n",
"\n",
"Note that the caching is sensitive to the exact arguments provided to `cached_client.create` or `cached_client.create_stream`, so changing `tools` or `json_output` arguments might lead to a cache miss."
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -615,7 +699,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.12.1"
}
},
"nbformat": 4,

View File

@ -72,6 +72,8 @@ dev = [
"autogen_ext==0.4.3",
# Documentation tooling
"diskcache",
"redis",
"sphinx-autobuild",
]

View File

@ -10,6 +10,7 @@ from ._agent_proxy import AgentProxy
from ._agent_runtime import AgentRuntime
from ._agent_type import AgentType
from ._base_agent import BaseAgent
from ._cache_store import CacheStore, InMemoryStore
from ._cancellation_token import CancellationToken
from ._closure_agent import ClosureAgent, ClosureContext
from ._component_config import (
@ -85,6 +86,8 @@ __all__ = [
"AgentMetadata",
"AgentRuntime",
"BaseAgent",
"CacheStore",
"InMemoryStore",
"CancellationToken",
"AgentInstantiationContext",
"TopicId",

View File

@ -0,0 +1,46 @@
from typing import Dict, Generic, Optional, Protocol, TypeVar
T = TypeVar("T")
class CacheStore(Protocol, Generic[T]):
"""
This protocol defines the basic interface for store/cache operations.
Sub-classes should handle the lifecycle of underlying storage.
"""
def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
"""
Retrieve an item from the store.
Args:
key: The key identifying the item in the store.
default (optional): The default value to return if the key is not found.
Defaults to None.
Returns:
The value associated with the key if found, else the default value.
"""
...
def set(self, key: str, value: T) -> None:
"""
Set an item in the store.
Args:
key: The key under which the item is to be stored.
value: The value to be stored in the store.
"""
...
class InMemoryStore(CacheStore[T]):
def __init__(self) -> None:
self.store: Dict[str, T] = {}
def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
return self.store.get(key, default)
def set(self, key: str, value: T) -> None:
self.store[key] = value

View File

@ -0,0 +1,48 @@
from unittest.mock import Mock
from autogen_core import CacheStore, InMemoryStore
def test_set_and_get_object_key_value() -> None:
mock_store = Mock(spec=CacheStore)
test_key = "test_key"
test_value = object()
mock_store.set(test_key, test_value)
mock_store.get.return_value = test_value
mock_store.set.assert_called_with(test_key, test_value)
assert mock_store.get(test_key) == test_value
def test_get_non_existent_key() -> None:
mock_store = Mock(spec=CacheStore)
key = "non_existent_key"
mock_store.get.return_value = None
assert mock_store.get(key) is None
def test_set_overwrite_existing_key() -> None:
mock_store = Mock(spec=CacheStore)
key = "test_key"
initial_value = "initial_value"
new_value = "new_value"
mock_store.set(key, initial_value)
mock_store.set(key, new_value)
mock_store.get.return_value = new_value
mock_store.set.assert_called_with(key, new_value)
assert mock_store.get(key) == new_value
def test_inmemory_store() -> None:
store = InMemoryStore[int]()
test_key = "test_key"
test_value = 42
store.set(test_key, test_value)
assert store.get(test_key) == test_value
new_value = 2
store.set(test_key, new_value)
assert store.get(test_key) == new_value
key = "non_existent_key"
default_value = 99
assert store.get(key, default_value) == default_value

View File

@ -46,6 +46,12 @@ video-surfer = [
"ffmpeg-python",
"openai-whisper",
]
diskcache = [
"diskcache>=5.6.3"
]
redis = [
"redis>=5.2.1"
]
grpc = [
"grpcio~=1.62.0", # TODO: update this once we have a stable version.

View File

@ -0,0 +1,26 @@
from typing import Any, Optional, TypeVar, cast
import diskcache
from autogen_core import CacheStore
T = TypeVar("T")
class DiskCacheStore(CacheStore[T]):
"""
A typed CacheStore implementation that uses diskcache as the underlying storage.
See :class:`~autogen_ext.models.cache.ChatCompletionCache` for an example of usage.
Args:
cache_instance: An instance of diskcache.Cache.
The user is responsible for managing the DiskCache instance's lifetime.
"""
def __init__(self, cache_instance: diskcache.Cache): # type: ignore[no-any-unimported]
self.cache = cache_instance
def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
return cast(Optional[T], self.cache.get(key, default)) # type: ignore[reportUnknownMemberType]
def set(self, key: str, value: T) -> None:
self.cache.set(key, cast(Any, value)) # type: ignore[reportUnknownMemberType]

View File

@ -0,0 +1,29 @@
from typing import Any, Optional, TypeVar, cast
import redis
from autogen_core import CacheStore
T = TypeVar("T")
class RedisStore(CacheStore[T]):
"""
A typed CacheStore implementation that uses redis as the underlying storage.
See :class:`~autogen_ext.models.cache.ChatCompletionCache` for an example of usage.
Args:
cache_instance: An instance of `redis.Redis`.
The user is responsible for managing the Redis instance's lifetime.
"""
def __init__(self, redis_instance: redis.Redis):
self.cache = redis_instance
def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
value = cast(Optional[T], self.cache.get(key))
if value is None:
return default
return value
def set(self, key: str, value: T) -> None:
self.cache.set(key, cast(Any, value))

View File

@ -0,0 +1,6 @@
from ._chat_completion_cache import CHAT_CACHE_VALUE_TYPE, ChatCompletionCache
__all__ = [
"CHAT_CACHE_VALUE_TYPE",
"ChatCompletionCache",
]

View File

@ -0,0 +1,210 @@
import hashlib
import json
import warnings
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union, cast
from autogen_core import CacheStore, CancellationToken
from autogen_core.models import (
ChatCompletionClient,
CreateResult,
LLMMessage,
ModelCapabilities, # type: ignore
ModelInfo,
RequestUsage,
)
from autogen_core.tools import Tool, ToolSchema
CHAT_CACHE_VALUE_TYPE = Union[CreateResult, List[Union[str, CreateResult]]]
class ChatCompletionCache(ChatCompletionClient):
"""
A wrapper around a :class:`~autogen_ext.models.cache.ChatCompletionClient` that caches
creation results from an underlying client.
Cache hits do not contribute to token usage of the original client.
Typical Usage:
Lets use caching on disk with `openai` client as an example.
First install `autogen-ext` with the required packages:
.. code-block:: bash
pip install -U "autogen-ext[openai, diskcache]"
And use it as:
.. code-block:: python
import asyncio
import tempfile
from autogen_core.models import UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.models.cache import ChatCompletionCache, CHAT_CACHE_VALUE_TYPE
from autogen_ext.cache_store.diskcache import DiskCacheStore
from diskcache import Cache
async def main():
with tempfile.TemporaryDirectory() as tmpdirname:
# Initialize the original client
openai_model_client = OpenAIChatCompletionClient(model="gpt-4o")
# Then initialize the CacheStore, in this case with diskcache.Cache.
# You can also use redis like:
# from autogen_ext.cache_store.redis import RedisStore
# import redis
# redis_instance = redis.Redis()
# cache_store = RedisCacheStore[CHAT_CACHE_VALUE_TYPE](redis_instance)
cache_store = DiskCacheStore[CHAT_CACHE_VALUE_TYPE](Cache(tmpdirname))
cache_client = ChatCompletionCache(openai_model_client, cache_store)
response = await cache_client.create([UserMessage(content="Hello, how are you?", source="user")])
print(response) # Should print response from OpenAI
response = await cache_client.create([UserMessage(content="Hello, how are you?", source="user")])
print(response) # Should print cached response
asyncio.run(main())
You can now use the `cached_client` as you would the original client, but with caching enabled.
Args:
client (ChatCompletionClient): The original ChatCompletionClient to wrap.
store (CacheStore): A store object that implements get and set methods.
The user is responsible for managing the store's lifecycle & clearing it (if needed).
"""
def __init__(self, client: ChatCompletionClient, store: CacheStore[CHAT_CACHE_VALUE_TYPE]):
self.client = client
self.store = store
def _check_cache(
self,
messages: Sequence[LLMMessage],
tools: Sequence[Tool | ToolSchema],
json_output: Optional[bool],
extra_create_args: Mapping[str, Any],
) -> tuple[Optional[Union[CreateResult, List[Union[str, CreateResult]]]], str]:
"""
Helper function to check the cache for a result.
Returns a tuple of (cached_result, cache_key).
"""
data = {
"messages": [message.model_dump() for message in messages],
"tools": [(tool.schema if isinstance(tool, Tool) else tool) for tool in tools],
"json_output": json_output,
"extra_create_args": extra_create_args,
}
serialized_data = json.dumps(data, sort_keys=True)
cache_key = hashlib.sha256(serialized_data.encode()).hexdigest()
cached_result = cast(Optional[CreateResult], self.store.get(cache_key))
if cached_result is not None:
return cached_result, cache_key
return None, cache_key
async def create(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> CreateResult:
"""
Cached version of ChatCompletionClient.create.
If the result of a call to create has been cached, it will be returned immediately
without invoking the underlying client.
NOTE: cancellation_token is ignored for cached results.
"""
cached_result, cache_key = self._check_cache(messages, tools, json_output, extra_create_args)
if cached_result:
assert isinstance(cached_result, CreateResult)
cached_result.cached = True
return cached_result
result = await self.client.create(
messages,
tools=tools,
json_output=json_output,
extra_create_args=extra_create_args,
cancellation_token=cancellation_token,
)
self.store.set(cache_key, result)
return result
def create_stream(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> AsyncGenerator[Union[str, CreateResult], None]:
"""
Cached version of ChatCompletionClient.create_stream.
If the result of a call to create_stream has been cached, it will be returned
without streaming from the underlying client.
NOTE: cancellation_token is ignored for cached results.
"""
async def _generator() -> AsyncGenerator[Union[str, CreateResult], None]:
cached_result, cache_key = self._check_cache(
messages,
tools,
json_output,
extra_create_args,
)
if cached_result:
assert isinstance(cached_result, list)
for result in cached_result:
if isinstance(result, CreateResult):
result.cached = True
yield result
return
result_stream = self.client.create_stream(
messages,
tools=tools,
json_output=json_output,
extra_create_args=extra_create_args,
cancellation_token=cancellation_token,
)
output_results: List[Union[str, CreateResult]] = []
self.store.set(cache_key, output_results)
async for result in result_stream:
output_results.append(result)
yield result
return _generator()
def actual_usage(self) -> RequestUsage:
return self.client.actual_usage()
def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
return self.client.count_tokens(messages, tools=tools)
@property
def capabilities(self) -> ModelCapabilities: # type: ignore
warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
return self.client.capabilities
@property
def model_info(self) -> ModelInfo:
return self.client.model_info
def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
return self.client.remaining_tokens(messages, tools=tools)
def total_usage(self) -> RequestUsage:
return self.client.total_usage()

View File

@ -40,8 +40,8 @@ class ReplayChatCompletionClient(ChatCompletionClient):
.. code-block:: python
from autogen_ext.models.replay import ReplayChatCompletionClient
from autogen_core.models import UserMessage
from autogen_ext.models.replay import ReplayChatCompletionClient
async def example():
@ -60,8 +60,8 @@ class ReplayChatCompletionClient(ChatCompletionClient):
.. code-block:: python
import asyncio
from autogen_ext.models.replay import ReplayChatCompletionClient
from autogen_core.models import UserMessage
from autogen_ext.models.replay import ReplayChatCompletionClient
async def example():
@ -86,8 +86,8 @@ class ReplayChatCompletionClient(ChatCompletionClient):
.. code-block:: python
import asyncio
from autogen_ext.models.replay import ReplayChatCompletionClient
from autogen_core.models import UserMessage
from autogen_ext.models.replay import ReplayChatCompletionClient
async def example():
@ -129,6 +129,7 @@ class ReplayChatCompletionClient(ChatCompletionClient):
self._cur_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
self._current_index = 0
self._cached_bool_value = True
async def create(
self,
@ -148,7 +149,9 @@ class ReplayChatCompletionClient(ChatCompletionClient):
if isinstance(response, str):
_, output_token_count = self._tokenize(response)
self._cur_usage = RequestUsage(prompt_tokens=prompt_token_count, completion_tokens=output_token_count)
response = CreateResult(finish_reason="stop", content=response, usage=self._cur_usage, cached=True)
response = CreateResult(
finish_reason="stop", content=response, usage=self._cur_usage, cached=self._cached_bool_value
)
else:
self._cur_usage = RequestUsage(
prompt_tokens=prompt_token_count, completion_tokens=response.usage.completion_tokens
@ -207,6 +210,9 @@ class ReplayChatCompletionClient(ChatCompletionClient):
0, self._total_available_tokens - self._total_usage.prompt_tokens - self._total_usage.completion_tokens
)
def set_cached_bool_value(self, value: bool) -> None:
self._cached_bool_value = value
def _tokenize(self, messages: Union[str, LLMMessage, Sequence[LLMMessage]]) -> tuple[list[str], int]:
total_tokens = 0
all_tokens: List[str] = []

View File

@ -0,0 +1,48 @@
import tempfile
import pytest
diskcache = pytest.importorskip("diskcache")
def test_diskcache_store_basic() -> None:
from autogen_ext.cache_store.diskcache import DiskCacheStore
from diskcache import Cache
with tempfile.TemporaryDirectory() as temp_dir:
cache = Cache(temp_dir)
store = DiskCacheStore[int](cache)
test_key = "test_key"
test_value = 42
store.set(test_key, test_value)
assert store.get(test_key) == test_value
new_value = 2
store.set(test_key, new_value)
assert store.get(test_key) == new_value
key = "non_existent_key"
default_value = 99
assert store.get(key, default_value) == default_value
def test_diskcache_with_different_instances() -> None:
from autogen_ext.cache_store.diskcache import DiskCacheStore
from diskcache import Cache
with tempfile.TemporaryDirectory() as temp_dir_1, tempfile.TemporaryDirectory() as temp_dir_2:
cache_1 = Cache(temp_dir_1)
cache_2 = Cache(temp_dir_2)
store_1 = DiskCacheStore[int](cache_1)
store_2 = DiskCacheStore[int](cache_2)
test_key = "test_key"
test_value_1 = 5
test_value_2 = 6
store_1.set(test_key, test_value_1)
assert store_1.get(test_key) == test_value_1
store_2.set(test_key, test_value_2)
assert store_2.get(test_key) == test_value_2

View File

@ -0,0 +1,53 @@
from unittest.mock import MagicMock
import pytest
redis = pytest.importorskip("redis")
def test_redis_store_basic() -> None:
from autogen_ext.cache_store.redis import RedisStore
redis_instance = MagicMock()
store = RedisStore[int](redis_instance)
test_key = "test_key"
test_value = 42
store.set(test_key, test_value)
redis_instance.set.assert_called_with(test_key, test_value)
redis_instance.get.return_value = test_value
assert store.get(test_key) == test_value
new_value = 2
store.set(test_key, new_value)
redis_instance.set.assert_called_with(test_key, new_value)
redis_instance.get.return_value = new_value
assert store.get(test_key) == new_value
key = "non_existent_key"
default_value = 99
redis_instance.get.return_value = None
assert store.get(key, default_value) == default_value
def test_redis_with_different_instances() -> None:
from autogen_ext.cache_store.redis import RedisStore
redis_instance_1 = MagicMock()
redis_instance_2 = MagicMock()
store_1 = RedisStore[int](redis_instance_1)
store_2 = RedisStore[int](redis_instance_2)
test_key = "test_key"
test_value_1 = 5
test_value_2 = 6
store_1.set(test_key, test_value_1)
redis_instance_1.set.assert_called_with(test_key, test_value_1)
redis_instance_1.get.return_value = test_value_1
assert store_1.get(test_key) == test_value_1
store_2.set(test_key, test_value_2)
redis_instance_2.set.assert_called_with(test_key, test_value_2)
redis_instance_2.get.return_value = test_value_2
assert store_2.get(test_key) == test_value_2

View File

@ -0,0 +1,133 @@
import copy
from typing import List, Tuple, Union
import pytest
from autogen_core import InMemoryStore
from autogen_core.models import (
ChatCompletionClient,
CreateResult,
LLMMessage,
SystemMessage,
UserMessage,
)
from autogen_ext.models.cache import CHAT_CACHE_VALUE_TYPE, ChatCompletionCache
from autogen_ext.models.replay import ReplayChatCompletionClient
def get_test_data() -> Tuple[list[str], list[str], SystemMessage, ChatCompletionClient, ChatCompletionCache]:
num_messages = 3
responses = [f"This is dummy message number {i}" for i in range(num_messages)]
prompts = [f"This is dummy prompt number {i}" for i in range(num_messages)]
system_prompt = SystemMessage(content="This is a system prompt")
replay_client = ReplayChatCompletionClient(responses)
replay_client.set_cached_bool_value(False)
store = InMemoryStore[CHAT_CACHE_VALUE_TYPE]()
cached_client = ChatCompletionCache(replay_client, store)
return responses, prompts, system_prompt, replay_client, cached_client
@pytest.mark.asyncio
async def test_cache_basic_with_args() -> None:
responses, prompts, system_prompt, _, cached_client = get_test_data()
response0 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
assert isinstance(response0, CreateResult)
assert not response0.cached
assert response0.content == responses[0]
response1 = await cached_client.create([system_prompt, UserMessage(content=prompts[1], source="user")])
assert not response1.cached
assert response1.content == responses[1]
# Cached output.
response0_cached = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
assert isinstance(response0, CreateResult)
assert response0_cached.cached
assert response0_cached.content == responses[0]
# Cache miss if args change.
response2 = await cached_client.create(
[system_prompt, UserMessage(content=prompts[0], source="user")], json_output=True
)
assert isinstance(response2, CreateResult)
assert not response2.cached
assert response2.content == responses[2]
@pytest.mark.asyncio
async def test_cache_model_and_count_api() -> None:
_, prompts, system_prompt, replay_client, cached_client = get_test_data()
assert replay_client.model_info == cached_client.model_info
assert replay_client.capabilities == cached_client.capabilities
messages: List[LLMMessage] = [system_prompt, UserMessage(content=prompts[0], source="user")]
assert replay_client.count_tokens(messages) == cached_client.count_tokens(messages)
assert replay_client.remaining_tokens(messages) == cached_client.remaining_tokens(messages)
@pytest.mark.asyncio
async def test_cache_token_usage() -> None:
responses, prompts, system_prompt, replay_client, cached_client = get_test_data()
response0 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
assert isinstance(response0, CreateResult)
assert not response0.cached
assert response0.content == responses[0]
actual_usage0 = copy.copy(cached_client.actual_usage())
total_usage0 = copy.copy(cached_client.total_usage())
response1 = await cached_client.create([system_prompt, UserMessage(content=prompts[1], source="user")])
assert not response1.cached
assert response1.content == responses[1]
actual_usage1 = copy.copy(cached_client.actual_usage())
total_usage1 = copy.copy(cached_client.total_usage())
assert total_usage1.prompt_tokens > total_usage0.prompt_tokens
assert total_usage1.completion_tokens > total_usage0.completion_tokens
assert actual_usage1.prompt_tokens == actual_usage0.prompt_tokens
assert actual_usage1.completion_tokens == actual_usage0.completion_tokens
# Cached output.
response0_cached = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
assert isinstance(response0, CreateResult)
assert response0_cached.cached
assert response0_cached.content == responses[0]
total_usage2 = copy.copy(cached_client.total_usage())
assert total_usage2.prompt_tokens == total_usage1.prompt_tokens
assert total_usage2.completion_tokens == total_usage1.completion_tokens
assert cached_client.actual_usage() == replay_client.actual_usage()
assert cached_client.total_usage() == replay_client.total_usage()
@pytest.mark.asyncio
async def test_cache_create_stream() -> None:
_, prompts, system_prompt, _, cached_client = get_test_data()
original_streamed_results: List[Union[str, CreateResult]] = []
async for completion in cached_client.create_stream(
[system_prompt, UserMessage(content=prompts[0], source="user")]
):
original_streamed_results.append(completion)
total_usage0 = copy.copy(cached_client.total_usage())
cached_completion_results: List[Union[str, CreateResult]] = []
async for completion in cached_client.create_stream(
[system_prompt, UserMessage(content=prompts[0], source="user")]
):
cached_completion_results.append(completion)
total_usage1 = copy.copy(cached_client.total_usage())
assert total_usage1.prompt_tokens == total_usage0.prompt_tokens
assert total_usage1.completion_tokens == total_usage0.completion_tokens
for original, cached in zip(original_streamed_results, cached_completion_results, strict=False):
if isinstance(original, str):
assert original == cached
elif isinstance(original, CreateResult) and isinstance(cached, CreateResult):
assert original.content == cached.content
assert cached.cached
assert not original.cached
else:
raise ValueError(f"Unexpected types : {type(original)} and {type(cached)}")

View File

@ -130,7 +130,7 @@ source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiohappyeyeballs" },
{ name = "aiosignal" },
{ name = "async-timeout", marker = "python_full_version < '3.11'" },
{ name = "async-timeout", version = "4.0.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "attrs" },
{ name = "frozenlist" },
{ name = "multidict" },
@ -317,11 +317,30 @@ wheels = [
name = "async-timeout"
version = "4.0.3"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version < '3.11' and sys_platform == 'darwin'",
"python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
]
sdist = { url = "https://files.pythonhosted.org/packages/87/d6/21b30a550dafea84b1b8eee21b5e23fa16d010ae006011221f33dcd8d7f8/async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f", size = 8345 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a7/fa/e01228c2938de91d47b307831c62ab9e4001e747789d0b05baf779a6488c/async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028", size = 5721 },
]
[[package]]
name = "async-timeout"
version = "5.0.1"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version == '3.11.*' and sys_platform == 'darwin'",
"python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')",
]
sdist = { url = "https://files.pythonhosted.org/packages/a5/ae/136395dfbfe00dfc94da3f3e136d0b13f394cba8f4841120e34226265780/async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3", size = 9274 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c", size = 6233 },
]
[[package]]
name = "asyncer"
version = "0.0.7"
@ -399,6 +418,7 @@ dev = [
{ name = "azure-identity" },
{ name = "chess" },
{ name = "colorama" },
{ name = "diskcache" },
{ name = "langchain-openai" },
{ name = "langgraph" },
{ name = "llama-index" },
@ -416,6 +436,7 @@ dev = [
{ name = "pydata-sphinx-theme" },
{ name = "pygments" },
{ name = "python-dotenv" },
{ name = "redis" },
{ name = "requests" },
{ name = "sphinx" },
{ name = "sphinx-autobuild" },
@ -455,6 +476,7 @@ dev = [
{ name = "azure-identity" },
{ name = "chess" },
{ name = "colorama" },
{ name = "diskcache" },
{ name = "langchain-openai" },
{ name = "langgraph" },
{ name = "llama-index" },
@ -472,6 +494,7 @@ dev = [
{ name = "pydata-sphinx-theme", specifier = "==0.15.4" },
{ name = "pygments" },
{ name = "python-dotenv" },
{ name = "redis" },
{ name = "requests" },
{ name = "sphinx" },
{ name = "sphinx-autobuild" },
@ -504,6 +527,9 @@ azure = [
{ name = "azure-core" },
{ name = "azure-identity" },
]
diskcache = [
{ name = "diskcache" },
]
docker = [
{ name = "docker" },
]
@ -531,6 +557,9 @@ openai = [
{ name = "openai" },
{ name = "tiktoken" },
]
redis = [
{ name = "redis" },
]
video-surfer = [
{ name = "autogen-agentchat" },
{ name = "ffmpeg-python" },
@ -561,6 +590,7 @@ requires-dist = [
{ name = "autogen-core", editable = "packages/autogen-core" },
{ name = "azure-core", marker = "extra == 'azure'" },
{ name = "azure-identity", marker = "extra == 'azure'" },
{ name = "diskcache", marker = "extra == 'diskcache'", specifier = ">=5.6.3" },
{ name = "docker", marker = "extra == 'docker'", specifier = "~=7.0" },
{ name = "ffmpeg-python", marker = "extra == 'video-surfer'" },
{ name = "graphrag", marker = "extra == 'graphrag'", specifier = ">=1.0.1" },
@ -576,6 +606,7 @@ requires-dist = [
{ name = "pillow", marker = "extra == 'web-surfer'", specifier = ">=11.0.0" },
{ name = "playwright", marker = "extra == 'magentic-one'", specifier = ">=1.48.0" },
{ name = "playwright", marker = "extra == 'web-surfer'", specifier = ">=1.48.0" },
{ name = "redis", marker = "extra == 'redis'", specifier = ">=5.2.1" },
{ name = "tiktoken", marker = "extra == 'openai'", specifier = ">=0.8.0" },
]
@ -1379,6 +1410,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/68/69/1bcf70f81de1b4a9f21b3a62ec0c83bdff991c88d6cc2267d02408457e88/dirtyjson-1.0.8-py3-none-any.whl", hash = "sha256:125e27248435a58acace26d5c2c4c11a1c0de0a9c5124c5a94ba78e517d74f53", size = 25197 },
]
[[package]]
name = "diskcache"
version = "5.6.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/3f/21/1c1ffc1a039ddcc459db43cc108658f32c57d271d7289a2794e401d0fdb6/diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc", size = 67916 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3f/27/4570e78fc0bf5ea0ca45eb1de3818a23787af9b390c0b0a0033a1b8236f9/diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19", size = 45550 },
]
[[package]]
name = "distro"
version = "1.9.0"
@ -2345,7 +2385,7 @@ version = "0.3.14"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiohttp" },
{ name = "async-timeout", marker = "python_full_version < '3.11'" },
{ name = "async-timeout", version = "4.0.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "langchain-core" },
{ name = "langchain-text-splitters" },
{ name = "langsmith" },
@ -4895,6 +4935,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ec/d2/3b2ab40f455a256cb6672186bea95cd97b459ce4594050132d71e76f0d6f/pyzmq-26.2.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:90412f2db8c02a3864cbfc67db0e3dcdbda336acf1c469526d3e869394fe001c", size = 550762 },
]
[[package]]
name = "redis"
version = "5.2.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "async-timeout", version = "4.0.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "async-timeout", version = "5.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and python_full_version < '3.11.3'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/47/da/d283a37303a995cd36f8b92db85135153dc4f7a8e4441aa827721b442cfb/redis-5.2.1.tar.gz", hash = "sha256:16f2e22dff21d5125e8481515e386711a34cbec50f0e44413dd7d9c060a54e0f", size = 4608355 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3c/5f/fa26b9b2672cbe30e07d9a5bdf39cf16e3b80b42916757c5f92bca88e4ba/redis-5.2.1-py3-none-any.whl", hash = "sha256:ee7e1056b9aea0f04c6c2ed59452947f34c4940ee025f5dd83e6a6418b6989e4", size = 261502 },
]
[[package]]
name = "referencing"
version = "0.35.1"
@ -5917,7 +5970,7 @@ name = "triton"
version = "3.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "filelock" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/98/29/69aa56dc0b2eb2602b553881e34243475ea2afd9699be042316842788ff5/triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b0dd10a925263abbe9fa37dcde67a5e9b2383fc269fdf59f5657cac38c5d1d8", size = 209460013 },