Feature/azure ai inference client (#5153)

* Rebase to latest main branch

* Moved _azure module to azure

* Validate extra_create_args in and json response

* Added Support for Github Models

* Added normalize_name and assert_valid name

* Added Tests for AzureAIChatCompletionClient

* WIP: Azure AI Client

* Added: object-level usage data
* Added: doc string
* Added: check existing response_format value
* Added: _validate_config and _create_client

* lint

* merge dependencies

* add tests for img and function calling

* support actual tests through env vars

* address mypy errors

* doc example fix

* fmt

* fix doc fmt

* Update python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py

---------

Co-authored-by: Rohan Thacker <thackerrohan4@gmail.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
This commit is contained in:
Leonardo Pinheiro 2025-01-25 08:26:48 +10:00 committed by GitHub
parent 1982f1b0ec
commit db2410c705
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 885 additions and 6 deletions

View File

@ -51,6 +51,7 @@ python/autogen_ext.teams.magentic_one
python/autogen_ext.models.cache
python/autogen_ext.models.openai
python/autogen_ext.models.replay
python/autogen_ext.models.azure
python/autogen_ext.models.semantic_kernel
python/autogen_ext.tools.langchain
python/autogen_ext.tools.graphrag

View File

@ -0,0 +1,8 @@
autogen\_ext.models.azure
==========================
.. automodule:: autogen_ext.models.azure
:members:
:undoc-members:
:show-inheritance:

View File

@ -20,7 +20,11 @@ dependencies = [
[project.optional-dependencies]
langchain = ["langchain_core~= 0.3.3"]
azure = ["azure-core", "azure-identity"]
azure = [
"azure-ai-inference>=1.0.0b7",
"azure-core",
"azure-identity",
]
docker = ["docker~=7.0"]
openai = ["openai>=1.52.2", "tiktoken>=0.8.0", "aiofiles"]
file-surfer = [
@ -56,6 +60,7 @@ redis = [
grpc = [
"grpcio~=1.62.0", # TODO: update this once we have a stable version.
]
jupyter-executor = [
"ipykernel>=6.29.5",
"nbclient>=0.10.2",

View File

@ -0,0 +1,4 @@
from ._azure_ai_client import AzureAIChatCompletionClient
from .config import AzureAIChatCompletionClientConfig
__all__ = ["AzureAIChatCompletionClient", "AzureAIChatCompletionClientConfig"]

View File

@ -0,0 +1,503 @@
import asyncio
import re
from asyncio import Task
from inspect import getfullargspec
from typing import Any, Dict, List, Mapping, Optional, Sequence, cast
from autogen_core import CancellationToken, FunctionCall, Image
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
CreateResult,
FinishReasons,
FunctionExecutionResultMessage,
LLMMessage,
ModelInfo,
RequestUsage,
SystemMessage,
UserMessage,
)
from autogen_core.tools import Tool, ToolSchema
from azure.ai.inference.aio import ChatCompletionsClient
from azure.ai.inference.models import (
AssistantMessage as AzureAssistantMessage,
)
from azure.ai.inference.models import (
ChatCompletions,
ChatCompletionsToolCall,
ChatCompletionsToolDefinition,
CompletionsFinishReason,
ContentItem,
FunctionDefinition,
ImageContentItem,
ImageDetailLevel,
ImageUrl,
StreamingChatChoiceUpdate,
StreamingChatCompletionsUpdate,
TextContentItem,
)
from azure.ai.inference.models import (
FunctionCall as AzureFunctionCall,
)
from azure.ai.inference.models import (
SystemMessage as AzureSystemMessage,
)
from azure.ai.inference.models import (
ToolMessage as AzureToolMessage,
)
from azure.ai.inference.models import (
UserMessage as AzureUserMessage,
)
from typing_extensions import AsyncGenerator, Union, Unpack
from autogen_ext.models.azure.config import (
GITHUB_MODELS_ENDPOINT,
AzureAIChatCompletionClientConfig,
)
create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs)
AzureMessage = Union[AzureSystemMessage, AzureUserMessage, AzureAssistantMessage, AzureToolMessage]
def _is_github_model(endpoint: str) -> bool:
return endpoint == GITHUB_MODELS_ENDPOINT
def convert_tools(tools: Sequence[Tool | ToolSchema]) -> List[ChatCompletionsToolDefinition]:
result: List[ChatCompletionsToolDefinition] = []
for tool in tools:
if isinstance(tool, Tool):
tool_schema = tool.schema.copy()
else:
assert isinstance(tool, dict)
tool_schema = tool.copy()
if "parameters" in tool_schema:
for value in tool_schema["parameters"]["properties"].values():
if "title" in value.keys():
del value["title"]
function_def: Dict[str, Any] = dict(name=tool_schema["name"])
if "description" in tool_schema:
function_def["description"] = tool_schema["description"]
if "parameters" in tool_schema:
function_def["parameters"] = tool_schema["parameters"]
result.append(
ChatCompletionsToolDefinition(
function=FunctionDefinition(**function_def),
),
)
return result
def _func_call_to_azure(message: FunctionCall) -> ChatCompletionsToolCall:
return ChatCompletionsToolCall(
id=message.id,
function=AzureFunctionCall(arguments=message.arguments, name=message.name),
)
def _system_message_to_azure(message: SystemMessage) -> AzureSystemMessage:
return AzureSystemMessage(content=message.content)
def _user_message_to_azure(message: UserMessage) -> AzureUserMessage:
assert_valid_name(message.source)
if isinstance(message.content, str):
return AzureUserMessage(content=message.content)
else:
parts: List[ContentItem] = []
for part in message.content:
if isinstance(part, str):
parts.append(TextContentItem(text=part))
elif isinstance(part, Image):
# TODO: support url based images
# TODO: support specifying details
parts.append(ImageContentItem(image_url=ImageUrl(url=part.data_uri, detail=ImageDetailLevel.AUTO)))
else:
raise ValueError(f"Unknown content type: {message.content}")
return AzureUserMessage(content=parts)
def _assistant_message_to_azure(message: AssistantMessage) -> AzureAssistantMessage:
assert_valid_name(message.source)
if isinstance(message.content, list):
return AzureAssistantMessage(
tool_calls=[_func_call_to_azure(x) for x in message.content],
)
else:
return AzureAssistantMessage(content=message.content)
def _tool_message_to_azure(message: FunctionExecutionResultMessage) -> Sequence[AzureToolMessage]:
return [AzureToolMessage(content=x.content, tool_call_id=x.call_id) for x in message.content]
def to_azure_message(message: LLMMessage) -> Sequence[AzureMessage]:
if isinstance(message, SystemMessage):
return [_system_message_to_azure(message)]
elif isinstance(message, UserMessage):
return [_user_message_to_azure(message)]
elif isinstance(message, AssistantMessage):
return [_assistant_message_to_azure(message)]
else:
return _tool_message_to_azure(message)
def normalize_name(name: str) -> str:
"""
LLMs sometimes ask functions while ignoring their own format requirements, this function should be used to replace invalid characters with "_".
Prefer _assert_valid_name for validating user configuration or input
"""
return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]
def assert_valid_name(name: str) -> str:
"""
Ensure that configured names are valid, raises ValueError if not.
For munging LLM responses use _normalize_name to ensure LLM specified names don't break the API.
"""
if not re.match(r"^[a-zA-Z0-9_-]+$", name):
raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
if len(name) > 64:
raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
return name
class AzureAIChatCompletionClient(ChatCompletionClient):
"""
Chat completion client for models hosted on Azure AI Foundry or GitHub Models.
See `here <https://learn.microsoft.com/en-us/azure/ai-studio/reference/reference-model-inference-chat-completions>`_ for more info.
Args:
endpoint (str): The endpoint to use. **Required.**
credentials (union, AzureKeyCredential, AsyncTokenCredential): The credentials to use. **Required**
model_info (ModelInfo): The model family and capabilities of the model. **Required.**
model (str): The name of the model. **Required if model is hosted on GitHub Models.**
frequency_penalty: (optional,float)
presence_penalty: (optional,float)
temperature: (optional,float)
top_p: (optional,float)
max_tokens: (optional,int)
response_format: (optional,ChatCompletionsResponseFormat)
stop: (optional,List[str])
tools: (optional,List[ChatCompletionsToolDefinition])
tool_choice: (optional,Union[str, ChatCompletionsToolChoicePreset, ChatCompletionsNamedToolChoice]])
seed: (optional,int)
model_extras: (optional,Dict[str, Any])
To use this client, you must install the `azure-ai-inference` extension:
.. code-block:: bash
pip install "autogen-ext[azure]"
The following code snippet shows how to use the client:
.. code-block:: python
import asyncio
from azure.core.credentials import AzureKeyCredential
from autogen_ext.models.azure import AzureAIChatCompletionClient
from autogen_core.models import UserMessage
async def main():
client = AzureAIChatCompletionClient(
endpoint="endpoint",
credential=AzureKeyCredential("api_key"),
model_info={
"json_output": False,
"function_calling": False,
"vision": False,
"family": "unknown",
},
)
result = await client.create([UserMessage(content="What is the capital of France?", source="user")])
print(result)
if __name__ == "__main__":
asyncio.run(main())
"""
def __init__(self, **kwargs: Unpack[AzureAIChatCompletionClientConfig]):
config = self._validate_config(kwargs) # type: ignore
self._model_info = config["model_info"] # type: ignore
self._client = self._create_client(config)
self._create_args = self._prepare_create_args(config)
self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
@staticmethod
def _validate_config(config: Dict[str, Any]) -> AzureAIChatCompletionClientConfig:
if "endpoint" not in config:
raise ValueError("endpoint is required for AzureAIChatCompletionClient")
if "credential" not in config:
raise ValueError("credential is required for AzureAIChatCompletionClient")
if "model_info" not in config:
raise ValueError("model_info is required for AzureAIChatCompletionClient")
if _is_github_model(config["endpoint"]) and "model" not in config:
raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient")
return cast(AzureAIChatCompletionClientConfig, config)
@staticmethod
def _create_client(config: AzureAIChatCompletionClientConfig) -> ChatCompletionsClient:
return ChatCompletionsClient(**config)
@staticmethod
def _prepare_create_args(config: Mapping[str, Any]) -> Dict[str, Any]:
create_args = {k: v for k, v in config.items() if k in create_kwargs}
return create_args
def add_usage(self, usage: RequestUsage) -> None:
self._total_usage = RequestUsage(
self._total_usage.prompt_tokens + usage.prompt_tokens,
self._total_usage.completion_tokens + usage.completion_tokens,
)
def _validate_model_info(
self,
messages: Sequence[LLMMessage],
tools: Sequence[Tool | ToolSchema],
json_output: Optional[bool],
create_args: Dict[str, Any],
) -> None:
if self.model_info["vision"] is False:
for message in messages:
if isinstance(message, UserMessage):
if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
raise ValueError("Model does not support vision and image was provided")
if json_output is not None:
if self.model_info["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output")
if json_output is True and "response_format" not in create_args:
create_args["response_format"] = "json_object"
if self.model_info["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output")
if self.model_info["function_calling"] is False and len(tools) > 0:
raise ValueError("Model does not support function calling")
async def create(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> CreateResult:
extra_create_args_keys = set(extra_create_args.keys())
if not create_kwargs.issuperset(extra_create_args_keys):
raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
# Copy the create args and overwrite anything in extra_create_args
create_args = self._create_args.copy()
create_args.update(extra_create_args)
self._validate_model_info(messages, tools, json_output, create_args)
azure_messages_nested = [to_azure_message(msg) for msg in messages]
azure_messages = [item for sublist in azure_messages_nested for item in sublist]
task: Task[ChatCompletions]
if len(tools) > 0:
converted_tools = convert_tools(tools)
task = asyncio.create_task( # type: ignore
self._client.complete(messages=azure_messages, tools=converted_tools, **create_args) # type: ignore
)
else:
task = asyncio.create_task( # type: ignore
self._client.complete( # type: ignore
messages=azure_messages,
**create_args,
)
)
if cancellation_token is not None:
cancellation_token.link_future(task)
result: ChatCompletions = await task
usage = RequestUsage(
prompt_tokens=result.usage.prompt_tokens if result.usage else 0,
completion_tokens=result.usage.completion_tokens if result.usage else 0,
)
choice = result.choices[0]
if choice.finish_reason == CompletionsFinishReason.TOOL_CALLS:
assert choice.message.tool_calls is not None
content: Union[str, List[FunctionCall]] = [
FunctionCall(
id=x.id,
arguments=x.function.arguments,
name=normalize_name(x.function.name),
)
for x in choice.message.tool_calls
]
finish_reason = "function_calls"
else:
if isinstance(choice.finish_reason, CompletionsFinishReason):
finish_reason = choice.finish_reason.value
else:
finish_reason = choice.finish_reason # type: ignore
content = choice.message.content or ""
response = CreateResult(
finish_reason=finish_reason, # type: ignore
content=content,
usage=usage,
cached=False,
)
self.add_usage(usage)
return response
async def create_stream(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> AsyncGenerator[Union[str, CreateResult], None]:
extra_create_args_keys = set(extra_create_args.keys())
if not create_kwargs.issuperset(extra_create_args_keys):
raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
create_args: Dict[str, Any] = self._create_args.copy()
create_args.update(extra_create_args)
self._validate_model_info(messages, tools, json_output, create_args)
# azure_messages = [to_azure_message(m) for m in messages]
azure_messages_nested = [to_azure_message(msg) for msg in messages]
azure_messages = [item for sublist in azure_messages_nested for item in sublist]
if len(tools) > 0:
converted_tools = convert_tools(tools)
task = asyncio.create_task(
self._client.complete(messages=azure_messages, tools=converted_tools, stream=True, **create_args)
)
else:
task = asyncio.create_task(
self._client.complete(messages=azure_messages, max_tokens=20, stream=True, **create_args)
)
if cancellation_token is not None:
cancellation_token.link_future(task)
# result: ChatCompletions = await task
finish_reason: Optional[FinishReasons] = None
content_deltas: List[str] = []
full_tool_calls: Dict[str, FunctionCall] = {}
prompt_tokens = 0
completion_tokens = 0
chunk: Optional[StreamingChatCompletionsUpdate] = None
choice: Optional[StreamingChatChoiceUpdate] = None
async for chunk in await task: # type: ignore
assert isinstance(chunk, StreamingChatCompletionsUpdate)
choice = chunk.choices[0] if len(chunk.choices) > 0 else None
if choice and choice.finish_reason is not None:
if isinstance(choice.finish_reason, CompletionsFinishReason):
finish_reason = cast(FinishReasons, choice.finish_reason.value)
else:
if choice.finish_reason in ["stop", "length", "function_calls", "content_filter", "unknown"]:
finish_reason = choice.finish_reason # type: ignore
else:
raise ValueError(f"Unexpected finish reason: {choice.finish_reason}")
# We first try to load the content
if choice and choice.delta.content is not None:
content_deltas.append(choice.delta.content)
yield choice.delta.content
# Otherwise, we try to load the tool calls
if choice and choice.delta.tool_calls is not None:
for tool_call_chunk in choice.delta.tool_calls:
# print(tool_call_chunk)
if "index" in tool_call_chunk:
idx = tool_call_chunk["index"]
else:
idx = tool_call_chunk.id
if idx not in full_tool_calls:
full_tool_calls[idx] = FunctionCall(id="", arguments="", name="")
full_tool_calls[idx].id += tool_call_chunk.id
full_tool_calls[idx].name += tool_call_chunk.function.name
full_tool_calls[idx].arguments += tool_call_chunk.function.arguments
if chunk and chunk.usage:
prompt_tokens = chunk.usage.prompt_tokens
if finish_reason is None:
raise ValueError("No stop reason found")
if choice and choice.finish_reason is CompletionsFinishReason.TOOL_CALLS:
finish_reason = "function_calls"
content: Union[str, List[FunctionCall]]
if len(content_deltas) > 1:
content = "".join(content_deltas)
if chunk and chunk.usage:
completion_tokens = chunk.usage.completion_tokens
else:
completion_tokens = 0
else:
content = list(full_tool_calls.values())
usage = RequestUsage(
completion_tokens=completion_tokens,
prompt_tokens=prompt_tokens,
)
result = CreateResult(
finish_reason=finish_reason,
content=content,
usage=usage,
cached=False,
)
self.add_usage(usage)
yield result
def actual_usage(self) -> RequestUsage:
return self._actual_usage
def total_usage(self) -> RequestUsage:
return self._total_usage
def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
return 0
def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
return 0
@property
def model_info(self) -> ModelInfo:
return self._model_info
@property
def capabilities(self) -> ModelInfo:
return self.model_info
def __del__(self) -> None:
# TODO: This is a hack to close the open client
try:
asyncio.get_running_loop().create_task(self._client.close())
except RuntimeError:
asyncio.run(self._client.close())

View File

@ -0,0 +1,46 @@
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
from autogen_core.models import ModelInfo
from azure.ai.inference.models import (
ChatCompletionsNamedToolChoice,
ChatCompletionsToolChoicePreset,
ChatCompletionsToolDefinition,
)
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential
GITHUB_MODELS_ENDPOINT = "https://models.inference.ai.azure.com"
class JsonSchemaFormat(TypedDict, total=False):
"""Represents the same fields as azure.ai.inference.models.JsonSchemaFormat."""
name: str
schema: Dict[str, Any]
description: Optional[str]
strict: Optional[bool]
class AzureAIClientArguments(TypedDict, total=False):
endpoint: str
credential: Union[AzureKeyCredential, AsyncTokenCredential]
model_info: ModelInfo
class AzureAICreateArguments(TypedDict, total=False):
frequency_penalty: Optional[float]
presence_penalty: Optional[float]
temperature: Optional[float]
top_p: Optional[float]
max_tokens: Optional[int]
response_format: Optional[Literal["text", "json_object"]]
stop: Optional[List[str]]
tools: Optional[List[ChatCompletionsToolDefinition]]
tool_choice: Optional[Union[str, ChatCompletionsToolChoicePreset, ChatCompletionsNamedToolChoice]]
seed: Optional[int]
model: Optional[str]
model_extras: Optional[Dict[str, Any]]
class AzureAIChatCompletionClientConfig(AzureAIClientArguments, AzureAICreateArguments):
pass

View File

@ -0,0 +1,297 @@
import asyncio
import os
from datetime import datetime
from typing import Any, AsyncGenerator, List
import pytest
from autogen_core import CancellationToken, FunctionCall, Image
from autogen_core.models import CreateResult, UserMessage
from autogen_ext.models.azure import AzureAIChatCompletionClient
from azure.ai.inference.aio import (
ChatCompletionsClient,
)
from azure.ai.inference.models import (
ChatChoice,
ChatCompletions,
ChatCompletionsToolCall,
ChatResponseMessage,
CompletionsFinishReason,
CompletionsUsage,
StreamingChatChoiceUpdate,
StreamingChatCompletionsUpdate,
StreamingChatResponseMessageUpdate,
)
from azure.ai.inference.models import (
FunctionCall as AzureFunctionCall,
)
from azure.core.credentials import AzureKeyCredential
async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]:
mock_chunks_content = ["Hello", " Another Hello", " Yet Another Hello"]
mock_chunks = [
StreamingChatChoiceUpdate(
index=0,
finish_reason="stop",
delta=StreamingChatResponseMessageUpdate(role="assistant", content=chunk_content),
)
for chunk_content in mock_chunks_content
]
for mock_chunk in mock_chunks:
await asyncio.sleep(0.1)
yield StreamingChatCompletionsUpdate(
id="id",
choices=[mock_chunk],
created=datetime.now(),
model="model",
usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
)
async def _mock_create(
*args: Any, **kwargs: Any
) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]:
stream = kwargs.get("stream", False)
if not stream:
await asyncio.sleep(0.1)
return ChatCompletions(
id="id",
created=datetime.now(),
model="model",
choices=[
ChatChoice(
index=0, finish_reason="stop", message=ChatResponseMessage(content="Hello", role="assistant")
)
],
usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
)
else:
return _mock_create_stream(*args, **kwargs)
@pytest.fixture
def azure_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient:
endpoint = os.getenv("AZURE_AI_INFERENCE_ENDPOINT")
api_key = os.getenv("AZURE_AI_INFERENCE_API_KEY")
if endpoint and api_key:
return AzureAIChatCompletionClient(
endpoint=endpoint,
credential=AzureKeyCredential(api_key),
model_info={
"json_output": False,
"function_calling": False,
"vision": False,
"family": "unknown",
},
model="model",
)
monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
return AzureAIChatCompletionClient(
endpoint="endpoint",
credential=AzureKeyCredential("api_key"),
model_info={
"json_output": False,
"function_calling": False,
"vision": False,
"family": "unknown",
},
model="model",
)
@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client(azure_client: AzureAIChatCompletionClient) -> None:
assert azure_client
@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_create(azure_client: AzureAIChatCompletionClient) -> None:
result = await azure_client.create(messages=[UserMessage(content="Hello", source="user")])
assert result.content == "Hello"
@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_create_stream(azure_client: AzureAIChatCompletionClient) -> None:
chunks: List[str | CreateResult] = []
async for chunk in azure_client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
chunks.append(chunk)
assert chunks[0] == "Hello"
assert chunks[1] == " Another Hello"
assert chunks[2] == " Yet Another Hello"
@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_create_cancel(azure_client: AzureAIChatCompletionClient) -> None:
cancellation_token = CancellationToken()
task = asyncio.create_task(
azure_client.create(
messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token
)
)
cancellation_token.cancel()
with pytest.raises(asyncio.CancelledError):
await task
@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_create_stream_cancel(azure_client: AzureAIChatCompletionClient) -> None:
cancellation_token = CancellationToken()
stream = azure_client.create_stream(
messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token
)
cancellation_token.cancel()
with pytest.raises(asyncio.CancelledError):
async for _ in stream:
pass
@pytest.fixture
def function_calling_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient:
"""
Returns a client that supports function calling.
"""
async def _mock_function_call_create(*args: Any, **kwargs: Any) -> ChatCompletions:
await asyncio.sleep(0.01)
return ChatCompletions(
id="id",
created=datetime.now(),
model="model",
choices=[
ChatChoice(
index=0,
finish_reason=CompletionsFinishReason.TOOL_CALLS,
message=ChatResponseMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionsToolCall(
id="tool_call_id",
function=AzureFunctionCall(name="some_function", arguments='{"foo": "bar"}'),
)
],
),
)
],
usage=CompletionsUsage(prompt_tokens=5, completion_tokens=2, total_tokens=7),
)
monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_function_call_create)
return AzureAIChatCompletionClient(
endpoint="endpoint",
credential=AzureKeyCredential("api_key"),
model_info={
"json_output": False,
"function_calling": True,
"vision": False,
"family": "function_calling_model",
},
model="model",
)
@pytest.mark.asyncio
async def test_function_calling_not_supported(azure_client: AzureAIChatCompletionClient) -> None:
"""
Ensures error is raised if we pass tools but the model_info doesn't support function calling.
"""
with pytest.raises(ValueError) as exc:
await azure_client.create(
messages=[UserMessage(content="Hello", source="user")],
tools=[{"name": "dummy_tool"}],
)
assert "Model does not support function calling" in str(exc.value)
@pytest.mark.asyncio
async def test_function_calling_success(function_calling_client: AzureAIChatCompletionClient) -> None:
"""
Ensures function calling works and returns FunctionCall content.
"""
result = await function_calling_client.create(
messages=[UserMessage(content="Please call a function", source="user")],
tools=[{"name": "test_tool"}],
)
assert result.finish_reason == "function_calls"
assert isinstance(result.content, list)
assert isinstance(result.content[0], FunctionCall)
assert result.content[0].name == "some_function"
assert result.content[0].arguments == '{"foo": "bar"}'
@pytest.mark.asyncio
async def test_multimodal_unsupported_raises_error(azure_client: AzureAIChatCompletionClient) -> None:
"""
If model does not support vision, providing an image should raise ValueError.
"""
with pytest.raises(ValueError) as exc:
await azure_client.create(
messages=[
UserMessage(
content=[ # type: ignore
Image.from_base64(
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGNgAAIAAAUAAen6L8YAAAAASUVORK5CYII="
)
],
source="user",
)
]
)
assert "does not support vision and image was provided" in str(exc.value)
@pytest.mark.asyncio
async def test_multimodal_supported(monkeypatch: pytest.MonkeyPatch) -> None:
"""
If model supports vision, providing an image should not raise.
"""
async def _mock_create_noop(*args: Any, **kwargs: Any) -> ChatCompletions:
await asyncio.sleep(0.01)
return ChatCompletions(
id="id",
created=datetime.now(),
model="model",
choices=[
ChatChoice(
index=0,
finish_reason="stop",
message=ChatResponseMessage(content="Handled image", role="assistant"),
)
],
usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
)
monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create_noop)
client = AzureAIChatCompletionClient(
endpoint="endpoint",
credential=AzureKeyCredential("api_key"),
model_info={
"json_output": False,
"function_calling": False,
"vision": True,
"family": "vision_model",
},
model="model",
)
result = await client.create(
messages=[
UserMessage(
content=[
Image.from_base64(
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGNgAAIAAAUAAen6L8YAAAAASUVORK5CYII="
)
],
source="user",
)
]
)
assert result.content == "Handled image"

File diff suppressed because one or more lines are too long

View File

@ -561,6 +561,7 @@ dependencies = [
[package.optional-dependencies]
azure = [
{ name = "azure-ai-inference" },
{ name = "azure-core" },
{ name = "azure-identity" },
]
@ -665,6 +666,7 @@ requires-dist = [
{ name = "autogen-agentchat", marker = "extra == 'video-surfer'", editable = "packages/autogen-agentchat" },
{ name = "autogen-agentchat", marker = "extra == 'web-surfer'", editable = "packages/autogen-agentchat" },
{ name = "autogen-core", editable = "packages/autogen-core" },
{ name = "azure-ai-inference", marker = "extra == 'azure'", specifier = ">=1.0.0b7" },
{ name = "azure-core", marker = "extra == 'azure'" },
{ name = "azure-identity", marker = "extra == 'azure'" },
{ name = "diskcache", marker = "extra == 'diskcache'", specifier = ">=5.6.3" },
@ -878,6 +880,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/9e/43/53afb8ba17218f19b77c7834128566c5bbb100a0ad9ba2e8e89d089d7079/autopep8-2.3.2-py2.py3-none-any.whl", hash = "sha256:ce8ad498672c845a0c3de2629c15b635ec2b05ef8177a6e7c91c74f3e9b51128", size = 45807 },
]
[[package]]
name = "azure-ai-inference"
version = "1.0.0b7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "azure-core" },
{ name = "isodate" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/af/37/233eee0bebbf631d2f911a9f1ebbc3784b100d9bfb84efc275e71c1ea636/azure_ai_inference-1.0.0b7.tar.gz", hash = "sha256:bd912f71f7f855036ca46c9a21439f290eed5e61da418fd26bbb32e3c68bcce3", size = 175883 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cd/b6/5ba830eddc59f820c654694d476c14a3dd9c1f828ff9b48eb8b21dfd5f01/azure_ai_inference-1.0.0b7-py3-none-any.whl", hash = "sha256:59bb6a9ee62bd7654a69ca2bf12fe9335d7045df95b491cb8b5f9e3791c86175", size = 123030 },
]
[[package]]
name = "azure-common"
version = "1.1.28"