mirror of https://github.com/microsoft/autogen.git
Feature/azure ai inference client (#5153)
* Rebase to latest main branch
* Moved _azure module to azure
* Validate extra_create_args and JSON response
* Added support for GitHub Models
* Added normalize_name and assert_valid_name
* Added tests for AzureAIChatCompletionClient
* WIP: Azure AI Client
* Added: object-level usage data
* Added: doc string
* Added: check existing response_format value
* Added: _validate_config and _create_client
* lint
* merge dependencies
* add tests for img and function calling
* support actual tests through env vars
* address mypy errors
* doc example fix
* fmt
* fix doc fmt
* Update python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py

---------

Co-authored-by: Rohan Thacker <thackerrohan4@gmail.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
parent 1982f1b0ec
commit db2410c705
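The commit message above mentions GitHub Models support. For orientation, here is a minimal sketch (not part of the diff) of how the new client can be pointed at GitHub Models, based on the code in this commit: the config module defines GITHUB_MODELS_ENDPOINT and _validate_config requires "model" when that endpoint is used. The model name and token below are placeholders, and the model_info flags must reflect the chosen model's real capabilities.

import asyncio

from autogen_core.models import UserMessage
from autogen_ext.models.azure import AzureAIChatCompletionClient
from azure.core.credentials import AzureKeyCredential


async def main() -> None:
    # GitHub Models serves many models behind one shared endpoint, so "model"
    # is required here (enforced by _validate_config via _is_github_model).
    client = AzureAIChatCompletionClient(
        endpoint="https://models.inference.ai.azure.com",
        credential=AzureKeyCredential("<your-token>"),  # placeholder credential
        model="<model-name>",  # hypothetical model name, required for GitHub Models
        model_info={  # must match the actual model's capabilities
            "json_output": False,
            "function_calling": False,
            "vision": False,
            "family": "unknown",
        },
    )
    result = await client.create([UserMessage(content="What is the capital of France?", source="user")])
    print(result.content)


asyncio.run(main())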

@@ -51,6 +51,7 @@ python/autogen_ext.teams.magentic_one
python/autogen_ext.models.cache
python/autogen_ext.models.openai
python/autogen_ext.models.replay
python/autogen_ext.models.azure
python/autogen_ext.models.semantic_kernel
python/autogen_ext.tools.langchain
python/autogen_ext.tools.graphrag

@@ -0,0 +1,8 @@
autogen\_ext.models.azure
==========================

.. automodule:: autogen_ext.models.azure
    :members:
    :undoc-members:
    :show-inheritance:

@@ -20,7 +20,11 @@ dependencies = [

[project.optional-dependencies]
langchain = ["langchain_core~= 0.3.3"]
azure = ["azure-core", "azure-identity"]
azure = [
    "azure-ai-inference>=1.0.0b7",
    "azure-core",
    "azure-identity",
]
docker = ["docker~=7.0"]
openai = ["openai>=1.52.2", "tiktoken>=0.8.0", "aiofiles"]
file-surfer = [

@@ -56,6 +60,7 @@ redis = [
grpc = [
    "grpcio~=1.62.0", # TODO: update this once we have a stable version.
]

jupyter-executor = [
    "ipykernel>=6.29.5",
    "nbclient>=0.10.2",

@@ -0,0 +1,4 @@
from ._azure_ai_client import AzureAIChatCompletionClient
from .config import AzureAIChatCompletionClientConfig

__all__ = ["AzureAIChatCompletionClient", "AzureAIChatCompletionClientConfig"]

@@ -0,0 +1,503 @@
import asyncio
import re
from asyncio import Task
from inspect import getfullargspec
from typing import Any, Dict, List, Mapping, Optional, Sequence, cast

from autogen_core import CancellationToken, FunctionCall, Image
from autogen_core.models import (
    AssistantMessage,
    ChatCompletionClient,
    CreateResult,
    FinishReasons,
    FunctionExecutionResultMessage,
    LLMMessage,
    ModelInfo,
    RequestUsage,
    SystemMessage,
    UserMessage,
)
from autogen_core.tools import Tool, ToolSchema
from azure.ai.inference.aio import ChatCompletionsClient
from azure.ai.inference.models import (
    AssistantMessage as AzureAssistantMessage,
)
from azure.ai.inference.models import (
    ChatCompletions,
    ChatCompletionsToolCall,
    ChatCompletionsToolDefinition,
    CompletionsFinishReason,
    ContentItem,
    FunctionDefinition,
    ImageContentItem,
    ImageDetailLevel,
    ImageUrl,
    StreamingChatChoiceUpdate,
    StreamingChatCompletionsUpdate,
    TextContentItem,
)
from azure.ai.inference.models import (
    FunctionCall as AzureFunctionCall,
)
from azure.ai.inference.models import (
    SystemMessage as AzureSystemMessage,
)
from azure.ai.inference.models import (
    ToolMessage as AzureToolMessage,
)
from azure.ai.inference.models import (
    UserMessage as AzureUserMessage,
)
from typing_extensions import AsyncGenerator, Union, Unpack

from autogen_ext.models.azure.config import (
    GITHUB_MODELS_ENDPOINT,
    AzureAIChatCompletionClientConfig,
)

create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs)
AzureMessage = Union[AzureSystemMessage, AzureUserMessage, AzureAssistantMessage, AzureToolMessage]


def _is_github_model(endpoint: str) -> bool:
    return endpoint == GITHUB_MODELS_ENDPOINT


def convert_tools(tools: Sequence[Tool | ToolSchema]) -> List[ChatCompletionsToolDefinition]:
    result: List[ChatCompletionsToolDefinition] = []
    for tool in tools:
        if isinstance(tool, Tool):
            tool_schema = tool.schema.copy()
        else:
            assert isinstance(tool, dict)
            tool_schema = tool.copy()

        if "parameters" in tool_schema:
            # Strip the "title" keys that pydantic adds to property schemas.
            for value in tool_schema["parameters"]["properties"].values():
                if "title" in value.keys():
                    del value["title"]

        function_def: Dict[str, Any] = dict(name=tool_schema["name"])
        if "description" in tool_schema:
            function_def["description"] = tool_schema["description"]
        if "parameters" in tool_schema:
            function_def["parameters"] = tool_schema["parameters"]

        result.append(
            ChatCompletionsToolDefinition(
                function=FunctionDefinition(**function_def),
            ),
        )
    return result


def _func_call_to_azure(message: FunctionCall) -> ChatCompletionsToolCall:
    return ChatCompletionsToolCall(
        id=message.id,
        function=AzureFunctionCall(arguments=message.arguments, name=message.name),
    )


def _system_message_to_azure(message: SystemMessage) -> AzureSystemMessage:
    return AzureSystemMessage(content=message.content)


def _user_message_to_azure(message: UserMessage) -> AzureUserMessage:
    assert_valid_name(message.source)
    if isinstance(message.content, str):
        return AzureUserMessage(content=message.content)
    else:
        parts: List[ContentItem] = []
        for part in message.content:
            if isinstance(part, str):
                parts.append(TextContentItem(text=part))
            elif isinstance(part, Image):
                # TODO: support url based images
                # TODO: support specifying details
                parts.append(ImageContentItem(image_url=ImageUrl(url=part.data_uri, detail=ImageDetailLevel.AUTO)))
            else:
                raise ValueError(f"Unknown content type: {message.content}")
        return AzureUserMessage(content=parts)


def _assistant_message_to_azure(message: AssistantMessage) -> AzureAssistantMessage:
    assert_valid_name(message.source)
    if isinstance(message.content, list):
        return AzureAssistantMessage(
            tool_calls=[_func_call_to_azure(x) for x in message.content],
        )
    else:
        return AzureAssistantMessage(content=message.content)


def _tool_message_to_azure(message: FunctionExecutionResultMessage) -> Sequence[AzureToolMessage]:
    return [AzureToolMessage(content=x.content, tool_call_id=x.call_id) for x in message.content]


def to_azure_message(message: LLMMessage) -> Sequence[AzureMessage]:
    if isinstance(message, SystemMessage):
        return [_system_message_to_azure(message)]
    elif isinstance(message, UserMessage):
        return [_user_message_to_azure(message)]
    elif isinstance(message, AssistantMessage):
        return [_assistant_message_to_azure(message)]
    else:
        return _tool_message_to_azure(message)


def normalize_name(name: str) -> str:
    """
    LLMs sometimes produce function names that ignore their own format requirements; this function replaces invalid characters with "_" and truncates to 64 characters.

    Prefer assert_valid_name for validating user configuration or input.
    """
    return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]


def assert_valid_name(name: str) -> str:
    """
    Ensure that configured names are valid; raises ValueError if not.

    For munging LLM responses use normalize_name to ensure LLM-specified names don't break the API.
    """
    if not re.match(r"^[a-zA-Z0-9_-]+$", name):
        raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
    if len(name) > 64:
        raise ValueError(f"Invalid name: {name}. Name must be no longer than 64 characters.")
    return name


class AzureAIChatCompletionClient(ChatCompletionClient):
    """
    Chat completion client for models hosted on Azure AI Foundry or GitHub Models.
    See `here <https://learn.microsoft.com/en-us/azure/ai-studio/reference/reference-model-inference-chat-completions>`_ for more info.

    Args:
        endpoint (str): The endpoint to use. **Required.**
        credential (Union[AzureKeyCredential, AsyncTokenCredential]): The credential to use. **Required.**
        model_info (ModelInfo): The model family and capabilities of the model. **Required.**
        model (str): The name of the model. **Required if the model is hosted on GitHub Models.**
        frequency_penalty (optional, float):
        presence_penalty (optional, float):
        temperature (optional, float):
        top_p (optional, float):
        max_tokens (optional, int):
        response_format (optional, ChatCompletionsResponseFormat):
        stop (optional, List[str]):
        tools (optional, List[ChatCompletionsToolDefinition]):
        tool_choice (optional, Union[str, ChatCompletionsToolChoicePreset, ChatCompletionsNamedToolChoice]):
        seed (optional, int):
        model_extras (optional, Dict[str, Any]):

    To use this client, you must install the `azure` extra:

    .. code-block:: bash

        pip install "autogen-ext[azure]"

    The following code snippet shows how to use the client:

    .. code-block:: python

        import asyncio
        from azure.core.credentials import AzureKeyCredential
        from autogen_ext.models.azure import AzureAIChatCompletionClient
        from autogen_core.models import UserMessage


        async def main():
            client = AzureAIChatCompletionClient(
                endpoint="endpoint",
                credential=AzureKeyCredential("api_key"),
                model_info={
                    "json_output": False,
                    "function_calling": False,
                    "vision": False,
                    "family": "unknown",
                },
            )

            result = await client.create([UserMessage(content="What is the capital of France?", source="user")])
            print(result)


        if __name__ == "__main__":
            asyncio.run(main())

    """

    def __init__(self, **kwargs: Unpack[AzureAIChatCompletionClientConfig]):
        config = self._validate_config(kwargs)  # type: ignore
        self._model_info = config["model_info"]  # type: ignore
        self._client = self._create_client(config)
        self._create_args = self._prepare_create_args(config)

        self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
        self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)

    @staticmethod
    def _validate_config(config: Dict[str, Any]) -> AzureAIChatCompletionClientConfig:
        if "endpoint" not in config:
            raise ValueError("endpoint is required for AzureAIChatCompletionClient")
        if "credential" not in config:
            raise ValueError("credential is required for AzureAIChatCompletionClient")
        if "model_info" not in config:
            raise ValueError("model_info is required for AzureAIChatCompletionClient")
        if _is_github_model(config["endpoint"]) and "model" not in config:
            raise ValueError("model is required when using a GitHub model with AzureAIChatCompletionClient")
        return cast(AzureAIChatCompletionClientConfig, config)

    @staticmethod
    def _create_client(config: AzureAIChatCompletionClientConfig) -> ChatCompletionsClient:
        return ChatCompletionsClient(**config)

    @staticmethod
    def _prepare_create_args(config: Mapping[str, Any]) -> Dict[str, Any]:
        # Keep only the keys that ChatCompletionsClient.complete accepts as keyword arguments.
        create_args = {k: v for k, v in config.items() if k in create_kwargs}
        return create_args

    def add_usage(self, usage: RequestUsage) -> None:
        self._total_usage = RequestUsage(
            self._total_usage.prompt_tokens + usage.prompt_tokens,
            self._total_usage.completion_tokens + usage.completion_tokens,
        )

    def _validate_model_info(
        self,
        messages: Sequence[LLMMessage],
        tools: Sequence[Tool | ToolSchema],
        json_output: Optional[bool],
        create_args: Dict[str, Any],
    ) -> None:
        if self.model_info["vision"] is False:
            for message in messages:
                if isinstance(message, UserMessage):
                    if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
                        raise ValueError("Model does not support vision and image was provided")

        if json_output is not None:
            if self.model_info["json_output"] is False and json_output is True:
                raise ValueError("Model does not support JSON output")

            if json_output is True and "response_format" not in create_args:
                create_args["response_format"] = "json_object"

        if self.model_info["json_output"] is False and json_output is True:
            raise ValueError("Model does not support JSON output")
        if self.model_info["function_calling"] is False and len(tools) > 0:
            raise ValueError("Model does not support function calling")

    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
        extra_create_args_keys = set(extra_create_args.keys())
        if not create_kwargs.issuperset(extra_create_args_keys):
            raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")

        # Copy the create args and overwrite anything in extra_create_args
        create_args = self._create_args.copy()
        create_args.update(extra_create_args)

        self._validate_model_info(messages, tools, json_output, create_args)

        azure_messages_nested = [to_azure_message(msg) for msg in messages]
        azure_messages = [item for sublist in azure_messages_nested for item in sublist]

        task: Task[ChatCompletions]

        if len(tools) > 0:
            converted_tools = convert_tools(tools)
            task = asyncio.create_task(  # type: ignore
                self._client.complete(messages=azure_messages, tools=converted_tools, **create_args)  # type: ignore
            )
        else:
            task = asyncio.create_task(  # type: ignore
                self._client.complete(  # type: ignore
                    messages=azure_messages,
                    **create_args,
                )
            )

        if cancellation_token is not None:
            cancellation_token.link_future(task)

        result: ChatCompletions = await task

        usage = RequestUsage(
            prompt_tokens=result.usage.prompt_tokens if result.usage else 0,
            completion_tokens=result.usage.completion_tokens if result.usage else 0,
        )

        choice = result.choices[0]
        if choice.finish_reason == CompletionsFinishReason.TOOL_CALLS:
            assert choice.message.tool_calls is not None

            content: Union[str, List[FunctionCall]] = [
                FunctionCall(
                    id=x.id,
                    arguments=x.function.arguments,
                    name=normalize_name(x.function.name),
                )
                for x in choice.message.tool_calls
            ]
            finish_reason = "function_calls"
        else:
            if isinstance(choice.finish_reason, CompletionsFinishReason):
                finish_reason = choice.finish_reason.value
            else:
                finish_reason = choice.finish_reason  # type: ignore
            content = choice.message.content or ""

        response = CreateResult(
            finish_reason=finish_reason,  # type: ignore
            content=content,
            usage=usage,
            cached=False,
        )

        self.add_usage(usage)

        return response

    async def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        extra_create_args_keys = set(extra_create_args.keys())
        if not create_kwargs.issuperset(extra_create_args_keys):
            raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")

        # Copy the create args and overwrite anything in extra_create_args
        create_args: Dict[str, Any] = self._create_args.copy()
        create_args.update(extra_create_args)

        self._validate_model_info(messages, tools, json_output, create_args)

        azure_messages_nested = [to_azure_message(msg) for msg in messages]
        azure_messages = [item for sublist in azure_messages_nested for item in sublist]

        if len(tools) > 0:
            converted_tools = convert_tools(tools)
            task = asyncio.create_task(
                self._client.complete(messages=azure_messages, tools=converted_tools, stream=True, **create_args)
            )
        else:
            task = asyncio.create_task(
                self._client.complete(messages=azure_messages, stream=True, **create_args)
            )

        if cancellation_token is not None:
            cancellation_token.link_future(task)

        finish_reason: Optional[FinishReasons] = None
        content_deltas: List[str] = []
        full_tool_calls: Dict[str, FunctionCall] = {}
        prompt_tokens = 0
        completion_tokens = 0
        chunk: Optional[StreamingChatCompletionsUpdate] = None
        choice: Optional[StreamingChatChoiceUpdate] = None
        async for chunk in await task:  # type: ignore
            assert isinstance(chunk, StreamingChatCompletionsUpdate)
            choice = chunk.choices[0] if len(chunk.choices) > 0 else None
            if choice and choice.finish_reason is not None:
                if isinstance(choice.finish_reason, CompletionsFinishReason):
                    finish_reason = cast(FinishReasons, choice.finish_reason.value)
                else:
                    if choice.finish_reason in ["stop", "length", "function_calls", "content_filter", "unknown"]:
                        finish_reason = choice.finish_reason  # type: ignore
                    else:
                        raise ValueError(f"Unexpected finish reason: {choice.finish_reason}")

            # We first try to load the content
            if choice and choice.delta.content is not None:
                content_deltas.append(choice.delta.content)
                yield choice.delta.content
            # Otherwise, we try to load the tool calls
            if choice and choice.delta.tool_calls is not None:
                for tool_call_chunk in choice.delta.tool_calls:
                    # Tool-call fragments are accumulated per index (or id) until the stream ends.
                    if "index" in tool_call_chunk:
                        idx = tool_call_chunk["index"]
                    else:
                        idx = tool_call_chunk.id
                    if idx not in full_tool_calls:
                        full_tool_calls[idx] = FunctionCall(id="", arguments="", name="")

                    full_tool_calls[idx].id += tool_call_chunk.id
                    full_tool_calls[idx].name += tool_call_chunk.function.name
                    full_tool_calls[idx].arguments += tool_call_chunk.function.arguments

            if chunk and chunk.usage:
                prompt_tokens = chunk.usage.prompt_tokens

        if finish_reason is None:
            raise ValueError("No stop reason found")

        if choice and choice.finish_reason is CompletionsFinishReason.TOOL_CALLS:
            finish_reason = "function_calls"

        content: Union[str, List[FunctionCall]]

        if len(content_deltas) > 0:
            content = "".join(content_deltas)
            if chunk and chunk.usage:
                completion_tokens = chunk.usage.completion_tokens
            else:
                completion_tokens = 0
        else:
            content = list(full_tool_calls.values())

        usage = RequestUsage(
            completion_tokens=completion_tokens,
            prompt_tokens=prompt_tokens,
        )

        result = CreateResult(
            finish_reason=finish_reason,
            content=content,
            usage=usage,
            cached=False,
        )

        self.add_usage(usage)

        yield result

    def actual_usage(self) -> RequestUsage:
        return self._actual_usage

    def total_usage(self) -> RequestUsage:
        return self._total_usage

    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        return 0

    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        return 0

    @property
    def model_info(self) -> ModelInfo:
        return self._model_info

    @property
    def capabilities(self) -> ModelInfo:
        return self.model_info

    def __del__(self) -> None:
        # TODO: This is a hack to close the open client
        try:
            asyncio.get_running_loop().create_task(self._client.close())
        except RuntimeError:
            asyncio.run(self._client.close())
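For readers of the client above: create_stream yields string deltas as they arrive and finishes with a single CreateResult. A minimal consumption sketch (client construction is elided; any client built as in the docstring example works):

from autogen_core.models import CreateResult, UserMessage
from autogen_ext.models.azure import AzureAIChatCompletionClient


async def consume_stream(client: AzureAIChatCompletionClient) -> None:
    # str items are incremental content deltas; the final item is a
    # CreateResult carrying the aggregated content, finish_reason, and usage.
    async for item in client.create_stream([UserMessage(content="Hello", source="user")]):
        if isinstance(item, CreateResult):
            print("\nfinish_reason:", item.finish_reason, "usage:", item.usage)
        else:
            print(item, end="", flush=True)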

@@ -0,0 +1,46 @@
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union

from autogen_core.models import ModelInfo
from azure.ai.inference.models import (
    ChatCompletionsNamedToolChoice,
    ChatCompletionsToolChoicePreset,
    ChatCompletionsToolDefinition,
)
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential

GITHUB_MODELS_ENDPOINT = "https://models.inference.ai.azure.com"


class JsonSchemaFormat(TypedDict, total=False):
    """Represents the same fields as azure.ai.inference.models.JsonSchemaFormat."""

    name: str
    schema: Dict[str, Any]
    description: Optional[str]
    strict: Optional[bool]


class AzureAIClientArguments(TypedDict, total=False):
    endpoint: str
    credential: Union[AzureKeyCredential, AsyncTokenCredential]
    model_info: ModelInfo


class AzureAICreateArguments(TypedDict, total=False):
    frequency_penalty: Optional[float]
    presence_penalty: Optional[float]
    temperature: Optional[float]
    top_p: Optional[float]
    max_tokens: Optional[int]
    response_format: Optional[Literal["text", "json_object"]]
    stop: Optional[List[str]]
    tools: Optional[List[ChatCompletionsToolDefinition]]
    tool_choice: Optional[Union[str, ChatCompletionsToolChoicePreset, ChatCompletionsNamedToolChoice]]
    seed: Optional[int]
    model: Optional[str]
    model_extras: Optional[Dict[str, Any]]


class AzureAIChatCompletionClientConfig(AzureAIClientArguments, AzureAICreateArguments):
    pass
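Since AzureAIChatCompletionClientConfig merges the client arguments and create arguments above, create-time defaults can be passed at construction, and _prepare_create_args forwards only the keys that ChatCompletionsClient.complete accepts. A sketch with placeholder endpoint and key (not from the diff):

from autogen_ext.models.azure import AzureAIChatCompletionClient
from azure.core.credentials import AzureKeyCredential

client = AzureAIChatCompletionClient(
    endpoint="<your-endpoint>",  # placeholder
    credential=AzureKeyCredential("<your-api-key>"),  # placeholder
    model_info={"json_output": False, "function_calling": False, "vision": False, "family": "unknown"},
    # AzureAICreateArguments fields become per-request defaults:
    temperature=0.2,
    max_tokens=256,
    seed=42,
)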

@@ -0,0 +1,297 @@
import asyncio
import os
from datetime import datetime
from typing import Any, AsyncGenerator, List

import pytest
from autogen_core import CancellationToken, FunctionCall, Image
from autogen_core.models import CreateResult, UserMessage
from autogen_ext.models.azure import AzureAIChatCompletionClient
from azure.ai.inference.aio import (
    ChatCompletionsClient,
)
from azure.ai.inference.models import (
    ChatChoice,
    ChatCompletions,
    ChatCompletionsToolCall,
    ChatResponseMessage,
    CompletionsFinishReason,
    CompletionsUsage,
    StreamingChatChoiceUpdate,
    StreamingChatCompletionsUpdate,
    StreamingChatResponseMessageUpdate,
)
from azure.ai.inference.models import (
    FunctionCall as AzureFunctionCall,
)
from azure.core.credentials import AzureKeyCredential


async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]:
    mock_chunks_content = ["Hello", " Another Hello", " Yet Another Hello"]

    mock_chunks = [
        StreamingChatChoiceUpdate(
            index=0,
            finish_reason="stop",
            delta=StreamingChatResponseMessageUpdate(role="assistant", content=chunk_content),
        )
        for chunk_content in mock_chunks_content
    ]

    for mock_chunk in mock_chunks:
        await asyncio.sleep(0.1)
        yield StreamingChatCompletionsUpdate(
            id="id",
            choices=[mock_chunk],
            created=datetime.now(),
            model="model",
            usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        )


async def _mock_create(
    *args: Any, **kwargs: Any
) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]:
    stream = kwargs.get("stream", False)

    if not stream:
        await asyncio.sleep(0.1)
        return ChatCompletions(
            id="id",
            created=datetime.now(),
            model="model",
            choices=[
                ChatChoice(
                    index=0, finish_reason="stop", message=ChatResponseMessage(content="Hello", role="assistant")
                )
            ],
            usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        )
    else:
        return _mock_create_stream(*args, **kwargs)


@pytest.fixture
def azure_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient:
    endpoint = os.getenv("AZURE_AI_INFERENCE_ENDPOINT")
    api_key = os.getenv("AZURE_AI_INFERENCE_API_KEY")

    # Run against a live endpoint when credentials are provided via env vars;
    # otherwise fall back to the mocked ChatCompletionsClient.complete.
    if endpoint and api_key:
        return AzureAIChatCompletionClient(
            endpoint=endpoint,
            credential=AzureKeyCredential(api_key),
            model_info={
                "json_output": False,
                "function_calling": False,
                "vision": False,
                "family": "unknown",
            },
            model="model",
        )

    monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
    return AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "json_output": False,
            "function_calling": False,
            "vision": False,
            "family": "unknown",
        },
        model="model",
    )


@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client(azure_client: AzureAIChatCompletionClient) -> None:
    assert azure_client


@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_create(azure_client: AzureAIChatCompletionClient) -> None:
    result = await azure_client.create(messages=[UserMessage(content="Hello", source="user")])
    assert result.content == "Hello"


@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_create_stream(azure_client: AzureAIChatCompletionClient) -> None:
    chunks: List[str | CreateResult] = []
    async for chunk in azure_client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
        chunks.append(chunk)

    assert chunks[0] == "Hello"
    assert chunks[1] == " Another Hello"
    assert chunks[2] == " Yet Another Hello"


@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_create_cancel(azure_client: AzureAIChatCompletionClient) -> None:
    cancellation_token = CancellationToken()
    task = asyncio.create_task(
        azure_client.create(
            messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token
        )
    )
    cancellation_token.cancel()
    with pytest.raises(asyncio.CancelledError):
        await task


@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_create_stream_cancel(azure_client: AzureAIChatCompletionClient) -> None:
    cancellation_token = CancellationToken()
    stream = azure_client.create_stream(
        messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token
    )
    cancellation_token.cancel()
    with pytest.raises(asyncio.CancelledError):
        async for _ in stream:
            pass


@pytest.fixture
def function_calling_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient:
    """
    Returns a client that supports function calling.
    """

    async def _mock_function_call_create(*args: Any, **kwargs: Any) -> ChatCompletions:
        await asyncio.sleep(0.01)
        return ChatCompletions(
            id="id",
            created=datetime.now(),
            model="model",
            choices=[
                ChatChoice(
                    index=0,
                    finish_reason=CompletionsFinishReason.TOOL_CALLS,
                    message=ChatResponseMessage(
                        role="assistant",
                        content="",
                        tool_calls=[
                            ChatCompletionsToolCall(
                                id="tool_call_id",
                                function=AzureFunctionCall(name="some_function", arguments='{"foo": "bar"}'),
                            )
                        ],
                    ),
                )
            ],
            usage=CompletionsUsage(prompt_tokens=5, completion_tokens=2, total_tokens=7),
        )

    monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_function_call_create)
    return AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "json_output": False,
            "function_calling": True,
            "vision": False,
            "family": "function_calling_model",
        },
        model="model",
    )


@pytest.mark.asyncio
async def test_function_calling_not_supported(azure_client: AzureAIChatCompletionClient) -> None:
    """
    Ensures an error is raised if we pass tools but the model_info doesn't support function calling.
    """
    with pytest.raises(ValueError) as exc:
        await azure_client.create(
            messages=[UserMessage(content="Hello", source="user")],
            tools=[{"name": "dummy_tool"}],
        )
    assert "Model does not support function calling" in str(exc.value)


@pytest.mark.asyncio
async def test_function_calling_success(function_calling_client: AzureAIChatCompletionClient) -> None:
    """
    Ensures function calling works and returns FunctionCall content.
    """
    result = await function_calling_client.create(
        messages=[UserMessage(content="Please call a function", source="user")],
        tools=[{"name": "test_tool"}],
    )
    assert result.finish_reason == "function_calls"
    assert isinstance(result.content, list)
    assert isinstance(result.content[0], FunctionCall)
    assert result.content[0].name == "some_function"
    assert result.content[0].arguments == '{"foo": "bar"}'


@pytest.mark.asyncio
async def test_multimodal_unsupported_raises_error(azure_client: AzureAIChatCompletionClient) -> None:
    """
    If the model does not support vision, providing an image should raise ValueError.
    """
    with pytest.raises(ValueError) as exc:
        await azure_client.create(
            messages=[
                UserMessage(
                    content=[  # type: ignore
                        Image.from_base64(
                            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGNgAAIAAAUAAen6L8YAAAAASUVORK5CYII="
                        )
                    ],
                    source="user",
                )
            ]
        )
    assert "does not support vision and image was provided" in str(exc.value)


@pytest.mark.asyncio
async def test_multimodal_supported(monkeypatch: pytest.MonkeyPatch) -> None:
    """
    If the model supports vision, providing an image should not raise.
    """

    async def _mock_create_noop(*args: Any, **kwargs: Any) -> ChatCompletions:
        await asyncio.sleep(0.01)
        return ChatCompletions(
            id="id",
            created=datetime.now(),
            model="model",
            choices=[
                ChatChoice(
                    index=0,
                    finish_reason="stop",
                    message=ChatResponseMessage(content="Handled image", role="assistant"),
                )
            ],
            usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        )

    monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create_noop)

    client = AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "json_output": False,
            "function_calling": False,
            "vision": True,
            "family": "vision_model",
        },
        model="model",
    )

    result = await client.create(
        messages=[
            UserMessage(
                content=[
                    Image.from_base64(
                        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGNgAAIAAAUAAen6L8YAAAAASUVORK5CYII="
                    )
                ],
                source="user",
            )
        ]
    )
    assert result.content == "Handled image"

File diff suppressed because one or more lines are too long

@@ -561,6 +561,7 @@ dependencies = [

[package.optional-dependencies]
azure = [
    { name = "azure-ai-inference" },
    { name = "azure-core" },
    { name = "azure-identity" },
]

@@ -665,6 +666,7 @@ requires-dist = [
    { name = "autogen-agentchat", marker = "extra == 'video-surfer'", editable = "packages/autogen-agentchat" },
    { name = "autogen-agentchat", marker = "extra == 'web-surfer'", editable = "packages/autogen-agentchat" },
    { name = "autogen-core", editable = "packages/autogen-core" },
    { name = "azure-ai-inference", marker = "extra == 'azure'", specifier = ">=1.0.0b7" },
    { name = "azure-core", marker = "extra == 'azure'" },
    { name = "azure-identity", marker = "extra == 'azure'" },
    { name = "diskcache", marker = "extra == 'diskcache'", specifier = ">=5.6.3" },

@@ -878,6 +880,20 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/9e/43/53afb8ba17218f19b77c7834128566c5bbb100a0ad9ba2e8e89d089d7079/autopep8-2.3.2-py2.py3-none-any.whl", hash = "sha256:ce8ad498672c845a0c3de2629c15b635ec2b05ef8177a6e7c91c74f3e9b51128", size = 45807 },
]

[[package]]
name = "azure-ai-inference"
version = "1.0.0b7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "azure-core" },
    { name = "isodate" },
    { name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/af/37/233eee0bebbf631d2f911a9f1ebbc3784b100d9bfb84efc275e71c1ea636/azure_ai_inference-1.0.0b7.tar.gz", hash = "sha256:bd912f71f7f855036ca46c9a21439f290eed5e61da418fd26bbb32e3c68bcce3", size = 175883 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/cd/b6/5ba830eddc59f820c654694d476c14a3dd9c1f828ff9b48eb8b21dfd5f01/azure_ai_inference-1.0.0b7-py3-none-any.whl", hash = "sha256:59bb6a9ee62bd7654a69ca2bf12fe9335d7045df95b491cb8b5f9e3791c86175", size = 123030 },
]

[[package]]
name = "azure-common"
version = "1.1.28"