5663 ollama client host (#5674)

@ekzhu should likely be assigned as reviewer

## Why are these changes needed?

These changes address the bug reported in #5663. Prevents TypeError from
being thrown at inference time by ollama AsyncClient when `host` (and
other) kwargs are passed to autogen OllamaChatCompletionClient
constructor.

It also adds ollama as a named optional extra so that the ollama
requirements can be installed alongside autogen-ext (e.g. `pip install
autogen-ext[ollama]`).

@ekzhu, I will need some help or guidance to ensure that the associated
test (which requires ollama and tiktoken as dependencies of the
OllamaChatCompletionClient) can run successfully in autogen's test
execution environment.

I have also left the "I've made sure all auto checks have passed" check
below unchecked as this PR is coming from my fork. (UPDATE: auto checks
appear to have passed after opening PR, so I have checked box below)

## Related issue number

Intended to close #5663 

## Checks

- [x] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [x] I've made sure all auto checks have passed.

---------

Co-authored-by: Ryan Stewart <ryanstewart@Ryans-MacBook-Pro.local>
Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
Co-authored-by: peterychang <49209570+peterychang@users.noreply.github.com>
This commit is contained in:
rylativity 2025-02-26 11:02:48 -05:00 committed by GitHub
parent 05fc763b8a
commit 5615f40a30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 74 additions and 1 deletions

View File

@ -26,6 +26,7 @@ azure = [
"azure-identity",
]
docker = ["docker~=7.0", "asyncio_atexit>=1.0.1"]
ollama = ["ollama>=0.4.7", "tiktoken>=0.8.0"]
openai = ["openai>=1.52.2", "tiktoken>=0.8.0", "aiofiles"]
file-surfer = [
"autogen-agentchat==0.4.8",

View File

@ -1,4 +1,5 @@
import asyncio
import inspect
import json
import logging
import math
@ -46,6 +47,7 @@ from autogen_core.tools import Tool, ToolSchema
from ollama import AsyncClient, ChatResponse, Message
from ollama import Image as OllamaImage
from ollama import Tool as OllamaTool
from ollama._types import ChatRequest
from pydantic import BaseModel
from typing_extensions import Self, Unpack
@ -75,8 +77,20 @@ def _ollama_client_from_config(config: Mapping[str, Any]) -> AsyncClient:
return AsyncClient(**ollama_config)
# Fields accepted by ollama's ChatRequest model. Accessing ``model_fields``
# directly on the class is the documented pydantic v2 API; the previous
# ``inspect.getmembers`` scan fetched the same attribute indirectly.
ollama_chat_request_fields: dict[str, Any] = dict(ChatRequest.model_fields)

# Keyword arguments that may safely be forwarded to AsyncClient.chat().
# Anything outside this set (e.g. ``host``, which belongs to the AsyncClient
# constructor) must be dropped before inference to avoid a TypeError (#5663).
OLLAMA_VALID_CREATE_KWARGS_KEYS = set(ollama_chat_request_fields.keys()) | {
    "model",
    "messages",
    "tools",
    "stream",
    "format",
    "options",
    "keep_alive",
}
def _create_args_from_config(config: Mapping[str, Any]) -> Dict[str, Any]:
    """Filter *config* down to the kwargs ollama's chat API accepts.

    Keys are lower-cased and checked against ``OLLAMA_VALID_CREATE_KWARGS_KEYS``;
    unrecognized keys (e.g. ``host``, which belongs to the AsyncClient
    constructor) are dropped and logged so that ``AsyncClient.chat()`` never
    receives kwargs it would reject with a TypeError (see #5663).

    Args:
        config: Arbitrary keyword arguments passed to the client constructor.

    Returns:
        A new dict containing only the recognized create-call arguments.
    """
    # NOTE: the stale `return dict(config).copy()` left over from the previous
    # implementation made everything below unreachable; it has been removed.
    create_args = {k.lower(): v for k, v in config.items() if k.lower() in OLLAMA_VALID_CREATE_KWARGS_KEYS}
    dropped_keys = [k for k in config.keys() if k.lower() not in OLLAMA_VALID_CREATE_KWARGS_KEYS]
    if dropped_keys:
        # Log only when something was actually discarded to avoid noise.
        logger.info(f"Dropped the following unrecognized keys from create_args: {dropped_keys}")
    return create_args
@ -374,6 +388,9 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
def create_from_config(cls, config: Dict[str, Any]) -> ChatCompletionClient:
    """Construct an :class:`OllamaChatCompletionClient` from a configuration dict.

    NOTE(review): takes ``cls`` so this is presumably decorated with
    ``@classmethod`` just above this view — confirm against the full file.
    """
    return OllamaChatCompletionClient(**config)
def get_create_args(self) -> Mapping[str, Any]:
    """Expose the create-call arguments this client was configured with."""
    create_args: Mapping[str, Any] = self._create_args
    return create_args
async def create(
self,
messages: Sequence[LLMMessage],

View File

@ -0,0 +1,49 @@
from typing import Any, Mapping
import pytest
from autogen_core.models._types import UserMessage
from autogen_ext.models.ollama import OllamaChatCompletionClient
from autogen_ext.models.ollama._ollama_client import OLLAMA_VALID_CREATE_KWARGS_KEYS
from httpx import Response
from ollama import AsyncClient
def _mock_request(*args: Any, **kwargs: Any) -> Response:
    """Stand-in for AsyncClient._request: always returns a canned 200 response."""
    canned_body = "{'response': 'Hello world!'}"
    return Response(status_code=200, content=canned_body)
@pytest.mark.asyncio
async def test_ollama_chat_completion_client_doesnt_error_with_host_kwarg(monkeypatch: pytest.MonkeyPatch) -> None:
    """Regression test for #5663: constructor-only kwargs like ``host`` must not leak into chat()."""
    monkeypatch.setattr(AsyncClient, "_request", _mock_request)
    ollama_client = OllamaChatCompletionClient(model="llama3.1", host="http://testyhostname:11434")
    # The create call may raise for other reasons (the mocked transport returns
    # a canned response), but it must never raise the TypeError that
    # AsyncClient.chat() emits for unrecognized kwargs -- that would mean the
    # extra constructor kwargs were not filtered out of the create args.
    try:
        await ollama_client.create([UserMessage(content="hi", source="user")])
    except TypeError as exc:
        assert "AsyncClient.chat() got an unexpected keyword argument" not in exc.args[0]
def test_create_args_from_config_drops_unexpected_kwargs() -> None:
    """Recognized create kwargs must survive filtering; unrecognized ones must be dropped.

    The original version only checked that every surviving key was valid, which
    would pass vacuously if *all* keys were dropped; the extra assertions below
    pin both directions of the filtering.
    """
    test_config: Mapping[str, Any] = {
        "model": "llama3.1",
        "messages": [],
        "tools": [],
        "stream": False,
        "format": "json",
        "options": {},
        "keep_alive": 100,
        "extra_unexpected_kwarg": "value",
        "another_extra_unexpected_kwarg": "another_value",
    }
    client = OllamaChatCompletionClient(**test_config)
    final_create_args = client.get_create_args()
    # Every surviving key must be one ollama's chat API accepts...
    for arg in final_create_args.keys():
        assert arg in OLLAMA_VALID_CREATE_KWARGS_KEYS
    # ...the unrecognized keys must actually have been removed...
    assert "extra_unexpected_kwarg" not in final_create_args
    assert "another_extra_unexpected_kwarg" not in final_create_args
    # ...and recognized keys must not have been dropped along with them.
    assert final_create_args["model"] == "llama3.1"

View File

@ -623,6 +623,10 @@ mcp = [
{ name = "json-schema-to-pydantic" },
{ name = "mcp" },
]
ollama = [
{ name = "ollama" },
{ name = "tiktoken" },
]
openai = [
{ name = "aiofiles" },
{ name = "openai" },
@ -716,6 +720,7 @@ requires-dist = [
{ name = "markitdown", marker = "extra == 'web-surfer'", specifier = ">=0.0.1a2" },
{ name = "mcp", marker = "extra == 'mcp'", specifier = ">=1.1.3" },
{ name = "nbclient", marker = "extra == 'jupyter-executor'", specifier = ">=0.10.2" },
{ name = "ollama", marker = "extra == 'ollama'", specifier = ">=0.4.7" },
{ name = "openai", marker = "extra == 'openai'", specifier = ">=1.52.2" },
{ name = "openai-whisper", marker = "extra == 'video-surfer'" },
{ name = "opencv-python", marker = "extra == 'video-surfer'", specifier = ">=4.5" },
@ -736,6 +741,7 @@ requires-dist = [
{ name = "semantic-kernel", extras = ["ollama"], marker = "extra == 'semantic-kernel-ollama'", specifier = ">=1.17.1" },
{ name = "semantic-kernel", extras = ["onnx"], marker = "extra == 'semantic-kernel-onnx'", specifier = ">=1.17.1" },
{ name = "semantic-kernel", extras = ["pandas"], marker = "extra == 'semantic-kernel-pandas'", specifier = ">=1.17.1" },
{ name = "tiktoken", marker = "extra == 'ollama'", specifier = ">=0.8.0" },
{ name = "tiktoken", marker = "extra == 'openai'", specifier = ">=0.8.0" },
]