mirror of https://github.com/microsoft/autogen.git
5663 ollama client host (#5674)
@ekzhu should likely be assigned as reviewer.

## Why are these changes needed?

These changes address the bug reported in #5663. They prevent a TypeError from being thrown at inference time by ollama's AsyncClient when `host` (and other) kwargs are passed to the autogen OllamaChatCompletionClient constructor. They also add ollama as a named optional extra so that the ollama requirements can be installed alongside autogen-ext (e.g. `pip install autogen-ext[ollama]`).

@ekzhu, I will need some help or guidance to ensure that the associated test (which requires ollama and tiktoken as dependencies of the OllamaChatCompletionClient) can run successfully in autogen's test execution environment. I have also left the "I've made sure all auto checks have passed" check below unchecked as this PR is coming from my fork. (UPDATE: auto checks appear to have passed after opening the PR, so I have checked the box below.)

## Related issue number

Intended to close #5663

## Checks

- [x] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes introduced in this PR.
- [x] I've made sure all auto checks have passed.

---------

Co-authored-by: Ryan Stewart <ryanstewart@Ryans-MacBook-Pro.local>
Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
Co-authored-by: peterychang <49209570+peterychang@users.noreply.github.com>
This commit is contained in: parent 05fc763b8a → commit 5615f40a30
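For context, a minimal sketch of the usage this PR fixes (the model name and host are illustrative, and a reachable Ollama server is assumed):

```python
import asyncio

from autogen_core.models import UserMessage
from autogen_ext.models.ollama import OllamaChatCompletionClient

# Before this fix, passing `host` here led ollama's AsyncClient.chat() to
# raise a TypeError at inference time, because the extra kwarg was forwarded
# along with the chat request arguments.
client = OllamaChatCompletionClient(model="llama3.1", host="http://localhost:11434")


async def main() -> None:
    result = await client.create([UserMessage(content="Hello!", source="user")])
    print(result.content)


asyncio.run(main())
```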
pyproject.toml (autogen-ext):

```diff
@@ -26,6 +26,7 @@ azure = [
     "azure-identity",
 ]
 docker = ["docker~=7.0", "asyncio_atexit>=1.0.1"]
+ollama = ["ollama>=0.4.7", "tiktoken>=0.8.0"]
 openai = ["openai>=1.52.2", "tiktoken>=0.8.0", "aiofiles"]
 file-surfer = [
     "autogen-agentchat==0.4.8",
```
autogen_ext/models/ollama/_ollama_client.py:

```diff
@@ -1,4 +1,5 @@
 import asyncio
+import inspect
 import json
 import logging
 import math
```
```diff
@@ -46,6 +47,7 @@ from autogen_core.tools import Tool, ToolSchema
 from ollama import AsyncClient, ChatResponse, Message
 from ollama import Image as OllamaImage
 from ollama import Tool as OllamaTool
+from ollama._types import ChatRequest
 from pydantic import BaseModel
 from typing_extensions import Self, Unpack
```
```diff
@@ -75,8 +77,20 @@ def _ollama_client_from_config(config: Mapping[str, Any]) -> AsyncClient:
     return AsyncClient(**ollama_config)
 
 
+ollama_chat_request_fields: dict[str, Any] = [m for m in inspect.getmembers(ChatRequest) if m[0] == "model_fields"][0][
+    1
+]
+OLLAMA_VALID_CREATE_KWARGS_KEYS = set(ollama_chat_request_fields.keys()) | set(
+    ("model", "messages", "tools", "stream", "format", "options", "keep_alive")
+)
+
+
 def _create_args_from_config(config: Mapping[str, Any]) -> Dict[str, Any]:
-    return dict(config).copy()
+    create_args = {k.lower(): v for k, v in config.items() if k.lower() in OLLAMA_VALID_CREATE_KWARGS_KEYS}
+    dropped_keys = [k for k in config.keys() if k.lower() not in OLLAMA_VALID_CREATE_KWARGS_KEYS]
+    logger.info(f"Dropped the following unrecognized keys from create_args: {dropped_keys}")
+
+    return create_args
     # create_args = {k: v for k, v in config.items() if k in create_kwargs}
     # create_args_keys = set(create_args.keys())
     # if not required_create_args.issubset(create_args_keys):
```
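The key-filtering pattern introduced above can be illustrated in isolation — a minimal sketch, with a hypothetical pydantic model standing in for ollama's ChatRequest:

```python
import inspect
from typing import Any, Mapping

from pydantic import BaseModel


class FakeChatRequest(BaseModel):
    # Stand-in for ollama._types.ChatRequest (hypothetical, for illustration).
    model: str = ""
    messages: Any = None
    options: Any = None


# Collect the pydantic model's declared field names via inspect.getmembers,
# mirroring how OLLAMA_VALID_CREATE_KWARGS_KEYS is derived from ChatRequest.
fields = [m for m in inspect.getmembers(FakeChatRequest) if m[0] == "model_fields"][0][1]
VALID_KEYS = set(fields.keys())


def filter_create_args(config: Mapping[str, Any]) -> dict[str, Any]:
    # Keep only keys the request type understands; report and drop the rest.
    kept = {k.lower(): v for k, v in config.items() if k.lower() in VALID_KEYS}
    dropped = [k for k in config if k.lower() not in VALID_KEYS]
    print(f"Dropped unrecognized keys: {dropped}")
    return kept


print(filter_create_args({"model": "llama3.1", "host": "http://localhost:11434"}))
# Output:
#   Dropped unrecognized keys: ['host']
#   {'model': 'llama3.1'}
```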
```diff
@@ -374,6 +388,9 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
     def create_from_config(cls, config: Dict[str, Any]) -> ChatCompletionClient:
         return OllamaChatCompletionClient(**config)
 
+    def get_create_args(self) -> Mapping[str, Any]:
+        return self._create_args
+
     async def create(
         self,
         messages: Sequence[LLMMessage],
```
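The new get_create_args() accessor makes the filtering observable — a small sketch of the behavior the tests below rely on (the unknown kwarg is hypothetical):

```python
from autogen_ext.models.ollama import OllamaChatCompletionClient

client = OllamaChatCompletionClient(
    model="llama3.1",
    host="http://localhost:11434",  # consumed by the underlying ollama.AsyncClient
    some_unknown_kwarg="ignored",  # hypothetical kwarg; dropped with an INFO log
)

# The filtered create args contain "model" but not "some_unknown_kwarg".
print(client.get_create_args())
```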
New test file (OllamaChatCompletionClient kwargs handling):

```diff
@@ -0,0 +1,49 @@
+from typing import Any, Mapping
+
+import pytest
+from autogen_core.models._types import UserMessage
+from autogen_ext.models.ollama import OllamaChatCompletionClient
+from autogen_ext.models.ollama._ollama_client import OLLAMA_VALID_CREATE_KWARGS_KEYS
+from httpx import Response
+from ollama import AsyncClient
+
+
+def _mock_request(*args: Any, **kwargs: Any) -> Response:
+    return Response(status_code=200, content="{'response': 'Hello world!'}")
+
+
+@pytest.mark.asyncio
+async def test_ollama_chat_completion_client_doesnt_error_with_host_kwarg(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr(AsyncClient, "_request", _mock_request)
+
+    client = OllamaChatCompletionClient(model="llama3.1", host="http://testyhostname:11434")
+
+    # The call to client.create below will throw a ConnectionError,
+    # but that will only occur if the AsyncClient's .chat() method does not
+    # receive unexpected kwargs and so does not throw a TypeError
+    # (i.e. the extra unrecognized kwargs have been successfully removed).
+    try:
+        await client.create([UserMessage(content="hi", source="user")])
+    except TypeError as e:
+        assert "AsyncClient.chat() got an unexpected keyword argument" not in e.args[0]
+
+
+def test_create_args_from_config_drops_unexpected_kwargs() -> None:
+    test_config: Mapping[str, Any] = {
+        "model": "llama3.1",
+        "messages": [],
+        "tools": [],
+        "stream": False,
+        "format": "json",
+        "options": {},
+        "keep_alive": 100,
+        "extra_unexpected_kwarg": "value",
+        "another_extra_unexpected_kwarg": "another_value",
+    }
+
+    client = OllamaChatCompletionClient(**test_config)
+
+    final_create_args = client.get_create_args()
+
+    for arg in final_create_args.keys():
+        assert arg in OLLAMA_VALID_CREATE_KWARGS_KEYS
```
uv.lock:

```diff
@@ -623,6 +623,10 @@ mcp = [
     { name = "json-schema-to-pydantic" },
     { name = "mcp" },
 ]
+ollama = [
+    { name = "ollama" },
+    { name = "tiktoken" },
+]
 openai = [
     { name = "aiofiles" },
     { name = "openai" },
```
```diff
@@ -716,6 +720,7 @@ requires-dist = [
     { name = "markitdown", marker = "extra == 'web-surfer'", specifier = ">=0.0.1a2" },
     { name = "mcp", marker = "extra == 'mcp'", specifier = ">=1.1.3" },
     { name = "nbclient", marker = "extra == 'jupyter-executor'", specifier = ">=0.10.2" },
+    { name = "ollama", marker = "extra == 'ollama'", specifier = ">=0.4.7" },
     { name = "openai", marker = "extra == 'openai'", specifier = ">=1.52.2" },
     { name = "openai-whisper", marker = "extra == 'video-surfer'" },
     { name = "opencv-python", marker = "extra == 'video-surfer'", specifier = ">=4.5" },
```
```diff
@@ -736,6 +741,7 @@ requires-dist = [
     { name = "semantic-kernel", extras = ["ollama"], marker = "extra == 'semantic-kernel-ollama'", specifier = ">=1.17.1" },
     { name = "semantic-kernel", extras = ["onnx"], marker = "extra == 'semantic-kernel-onnx'", specifier = ">=1.17.1" },
     { name = "semantic-kernel", extras = ["pandas"], marker = "extra == 'semantic-kernel-pandas'", specifier = ">=1.17.1" },
+    { name = "tiktoken", marker = "extra == 'ollama'", specifier = ">=0.8.0" },
     { name = "tiktoken", marker = "extra == 'openai'", specifier = ">=0.8.0" },
 ]
```