mirror of https://github.com/microsoft/autogen.git
2452 lines
92 KiB
Python
2452 lines
92 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Tuple, TypeVar
|
|
from unittest.mock import MagicMock
|
|
|
|
import httpx
|
|
import pytest
|
|
from autogen_core import CancellationToken, FunctionCall, Image
|
|
from autogen_core.models import (
|
|
AssistantMessage,
|
|
CreateResult,
|
|
FunctionExecutionResult,
|
|
FunctionExecutionResultMessage,
|
|
LLMMessage,
|
|
ModelInfo,
|
|
RequestUsage,
|
|
SystemMessage,
|
|
UserMessage,
|
|
)
|
|
from autogen_core.models._model_client import ModelFamily
|
|
from autogen_core.tools import BaseTool, FunctionTool
|
|
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient
|
|
from autogen_ext.models.openai._model_info import resolve_model
|
|
from autogen_ext.models.openai._openai_client import (
|
|
BaseOpenAIChatCompletionClient,
|
|
calculate_vision_tokens,
|
|
convert_tools,
|
|
to_oai_type,
|
|
)
|
|
from autogen_ext.models.openai._transformation import TransformerMap, get_transformer
|
|
from openai.resources.beta.chat.completions import ( # type: ignore
|
|
AsyncChatCompletionStreamManager as BetaAsyncChatCompletionStreamManager, # type: ignore
|
|
)
|
|
|
|
# type: ignore
|
|
from openai.resources.beta.chat.completions import (
|
|
AsyncCompletions as BetaAsyncCompletions,
|
|
)
|
|
from openai.resources.chat.completions import AsyncCompletions
|
|
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
|
from openai.types.chat.chat_completion_chunk import (
|
|
ChatCompletionChunk,
|
|
ChoiceDelta,
|
|
ChoiceDeltaToolCall,
|
|
ChoiceDeltaToolCallFunction,
|
|
)
|
|
from openai.types.chat.chat_completion_chunk import (
|
|
Choice as ChunkChoice,
|
|
)
|
|
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
|
from openai.types.chat.chat_completion_message_tool_call import (
|
|
ChatCompletionMessageToolCall,
|
|
Function,
|
|
)
|
|
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion, ParsedChatCompletionMessage, ParsedChoice
|
|
from openai.types.chat.parsed_function_tool_call import ParsedFunction, ParsedFunctionToolCall
|
|
from openai.types.completion_usage import CompletionUsage
|
|
from pydantic import BaseModel, Field
|
|
|
|
ResponseFormatT = TypeVar("ResponseFormatT", bound=BaseModel)
|
|
|
|
|
|
def _pass_function(input: str) -> str:
|
|
return "pass"
|
|
|
|
|
|
async def _fail_function(input: str) -> str:
|
|
return "fail"
|
|
|
|
|
|
async def _echo_function(input: str) -> str:
|
|
return input
|
|
|
|
|
|
class MyResult(BaseModel):
|
|
result: str = Field(description="The other description.")
|
|
|
|
|
|
class MyArgs(BaseModel):
|
|
query: str = Field(description="The description.")
|
|
|
|
|
|
class MockChunkDefinition(BaseModel):
|
|
# defining elements for diffentiating mocking chunks
|
|
chunk_choice: ChunkChoice
|
|
usage: CompletionUsage | None
|
|
|
|
|
|
class MockChunkEvent(BaseModel):
|
|
type: Literal["chunk"]
|
|
chunk: ChatCompletionChunk
|
|
|
|
|
|
async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
|
|
model = resolve_model(kwargs.get("model", "gpt-4o"))
|
|
mock_chunks_content = ["Hello", " Another Hello", " Yet Another Hello"]
|
|
|
|
# The openai api implementations (OpenAI and Litellm) stream chunks of tokens
|
|
# with content as string, and then at the end a token with stop set and finally if
|
|
# usage requested with `"stream_options": {"include_usage": True}` a chunk with the usage data
|
|
mock_chunks = [
|
|
# generate the list of mock chunk content
|
|
MockChunkDefinition(
|
|
chunk_choice=ChunkChoice(
|
|
finish_reason=None,
|
|
index=0,
|
|
delta=ChoiceDelta(
|
|
content=mock_chunk_content,
|
|
role="assistant",
|
|
),
|
|
),
|
|
usage=None,
|
|
)
|
|
for mock_chunk_content in mock_chunks_content
|
|
] + [
|
|
# generate the stop chunk
|
|
MockChunkDefinition(
|
|
chunk_choice=ChunkChoice(
|
|
finish_reason="stop",
|
|
index=0,
|
|
delta=ChoiceDelta(
|
|
content=None,
|
|
role="assistant",
|
|
),
|
|
),
|
|
usage=None,
|
|
)
|
|
]
|
|
# generate the usage chunk if configured
|
|
if kwargs.get("stream_options", {}).get("include_usage") is True:
|
|
mock_chunks = mock_chunks + [
|
|
# ---- API differences
|
|
# OPENAI API does NOT create a choice
|
|
# LITELLM (proxy) DOES create a choice
|
|
# Not simulating all the API options, just implementing the LITELLM variant
|
|
MockChunkDefinition(
|
|
chunk_choice=ChunkChoice(
|
|
finish_reason=None,
|
|
index=0,
|
|
delta=ChoiceDelta(
|
|
content=None,
|
|
role="assistant",
|
|
),
|
|
),
|
|
usage=CompletionUsage(prompt_tokens=3, completion_tokens=3, total_tokens=6),
|
|
)
|
|
]
|
|
elif kwargs.get("stream_options", {}).get("include_usage") is False:
|
|
pass
|
|
else:
|
|
pass
|
|
|
|
for mock_chunk in mock_chunks:
|
|
await asyncio.sleep(0.1)
|
|
yield ChatCompletionChunk(
|
|
id="id",
|
|
choices=[mock_chunk.chunk_choice],
|
|
created=0,
|
|
model=model,
|
|
object="chat.completion.chunk",
|
|
usage=mock_chunk.usage,
|
|
)
|
|
|
|
|
|
async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
|
|
stream = kwargs.get("stream", False)
|
|
model = resolve_model(kwargs.get("model", "gpt-4o"))
|
|
if not stream:
|
|
await asyncio.sleep(0.1)
|
|
return ChatCompletion(
|
|
id="id",
|
|
choices=[
|
|
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
|
|
],
|
|
created=0,
|
|
model=model,
|
|
object="chat.completion",
|
|
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
|
|
)
|
|
else:
|
|
return _mock_create_stream(*args, **kwargs)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_openai_chat_completion_client() -> None:
|
|
client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
|
|
assert client
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_openai_chat_completion_client_with_gemini_model() -> None:
|
|
client = OpenAIChatCompletionClient(model="gemini-1.5-flash", api_key="api_key")
|
|
assert client
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_openai_chat_completion_client_serialization() -> None:
|
|
client = OpenAIChatCompletionClient(model="gpt-4o", api_key="sk-password")
|
|
assert client
|
|
config = client.dump_component()
|
|
assert config
|
|
assert "sk-password" not in str(config)
|
|
serialized_config = config.model_dump_json()
|
|
assert serialized_config
|
|
assert "sk-password" not in serialized_config
|
|
client2 = OpenAIChatCompletionClient.load_component(config)
|
|
assert client2
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_openai_chat_completion_client_raise_on_unknown_model() -> None:
|
|
with pytest.raises(ValueError, match="model_info is required"):
|
|
_ = OpenAIChatCompletionClient(model="unknown", api_key="api_key")
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_custom_model_with_capabilities() -> None:
|
|
with pytest.raises(ValueError, match="model_info is required"):
|
|
client = OpenAIChatCompletionClient(model="dummy_model", base_url="https://api.dummy.com/v0", api_key="api_key")
|
|
|
|
client = OpenAIChatCompletionClient(
|
|
model="dummy_model",
|
|
base_url="https://api.dummy.com/v0",
|
|
api_key="api_key",
|
|
model_info={
|
|
"vision": False,
|
|
"function_calling": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.UNKNOWN,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
assert client
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_azure_openai_chat_completion_client() -> None:
|
|
client = AzureOpenAIChatCompletionClient(
|
|
azure_deployment="gpt-4o-1",
|
|
model="gpt-4o",
|
|
api_key="api_key",
|
|
api_version="2020-08-04",
|
|
azure_endpoint="https://dummy.com",
|
|
model_info={
|
|
"vision": True,
|
|
"function_calling": True,
|
|
"json_output": True,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": True,
|
|
},
|
|
)
|
|
assert client
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_openai_chat_completion_client_create(
|
|
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
|
|
) -> None:
|
|
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
|
|
with caplog.at_level(logging.INFO):
|
|
client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
|
|
result = await client.create(messages=[UserMessage(content="Hello", source="user")])
|
|
assert result.content == "Hello"
|
|
assert "LLMCall" in caplog.text and "Hello" in caplog.text
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_openai_chat_completion_client_create_stream_with_usage(
|
|
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
|
|
) -> None:
|
|
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
|
|
client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
|
|
chunks: List[str | CreateResult] = []
|
|
with caplog.at_level(logging.INFO):
|
|
async for chunk in client.create_stream(
|
|
messages=[UserMessage(content="Hello", source="user")],
|
|
# include_usage not the default of the OPENAI API and must be explicitly set
|
|
extra_create_args={"stream_options": {"include_usage": True}},
|
|
):
|
|
chunks.append(chunk)
|
|
|
|
assert "LLMStreamStart" in caplog.text
|
|
assert "LLMStreamEnd" in caplog.text
|
|
|
|
assert chunks[0] == "Hello"
|
|
assert chunks[1] == " Another Hello"
|
|
assert chunks[2] == " Yet Another Hello"
|
|
assert isinstance(chunks[-1], CreateResult)
|
|
assert isinstance(chunks[-1].content, str)
|
|
assert chunks[-1].content == "Hello Another Hello Yet Another Hello"
|
|
assert chunks[-1].content in caplog.text
|
|
assert chunks[-1].usage == RequestUsage(prompt_tokens=3, completion_tokens=3)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_openai_chat_completion_client_create_stream_no_usage_default(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
|
|
client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
|
|
chunks: List[str | CreateResult] = []
|
|
async for chunk in client.create_stream(
|
|
messages=[UserMessage(content="Hello", source="user")],
|
|
# include_usage not the default of the OPENAI APIis ,
|
|
# it can be explicitly set
|
|
# or just not declared which is the default
|
|
# extra_create_args={"stream_options": {"include_usage": False}},
|
|
):
|
|
chunks.append(chunk)
|
|
assert chunks[0] == "Hello"
|
|
assert chunks[1] == " Another Hello"
|
|
assert chunks[2] == " Yet Another Hello"
|
|
assert isinstance(chunks[-1], CreateResult)
|
|
assert chunks[-1].content == "Hello Another Hello Yet Another Hello"
|
|
assert chunks[-1].usage == RequestUsage(prompt_tokens=0, completion_tokens=0)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_openai_chat_completion_client_create_stream_no_usage_explicit(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
|
|
client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
|
|
chunks: List[str | CreateResult] = []
|
|
async for chunk in client.create_stream(
|
|
messages=[UserMessage(content="Hello", source="user")],
|
|
# include_usage is not the default of the OPENAI API ,
|
|
# it can be explicitly set
|
|
# or just not declared which is the default
|
|
extra_create_args={"stream_options": {"include_usage": False}},
|
|
):
|
|
chunks.append(chunk)
|
|
assert chunks[0] == "Hello"
|
|
assert chunks[1] == " Another Hello"
|
|
assert chunks[2] == " Yet Another Hello"
|
|
assert isinstance(chunks[-1], CreateResult)
|
|
assert chunks[-1].content == "Hello Another Hello Yet Another Hello"
|
|
assert chunks[-1].usage == RequestUsage(prompt_tokens=0, completion_tokens=0)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_openai_chat_completion_client_create_cancel(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
|
|
client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
|
|
cancellation_token = CancellationToken()
|
|
task = asyncio.create_task(
|
|
client.create(messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token)
|
|
)
|
|
cancellation_token.cancel()
|
|
with pytest.raises(asyncio.CancelledError):
|
|
await task
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_openai_chat_completion_client_create_stream_cancel(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
|
|
client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
|
|
cancellation_token = CancellationToken()
|
|
stream = client.create_stream(
|
|
messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token
|
|
)
|
|
assert await anext(stream)
|
|
cancellation_token.cancel()
|
|
with pytest.raises(asyncio.CancelledError):
|
|
async for _ in stream:
|
|
pass
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_openai_chat_completion_client_count_tokens(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
|
|
messages: List[LLMMessage] = [
|
|
SystemMessage(content="Hello"),
|
|
UserMessage(content="Hello", source="user"),
|
|
AssistantMessage(content="Hello", source="assistant"),
|
|
UserMessage(
|
|
content=[
|
|
"str1",
|
|
Image.from_base64(
|
|
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4z8AAAAMBAQDJ/pLvAAAAAElFTkSuQmCC"
|
|
),
|
|
],
|
|
source="user",
|
|
),
|
|
FunctionExecutionResultMessage(
|
|
content=[FunctionExecutionResult(content="Hello", call_id="1", is_error=False, name="tool1")]
|
|
),
|
|
]
|
|
|
|
def tool1(test: str, test2: str) -> str:
|
|
return test + test2
|
|
|
|
def tool2(test1: int, test2: List[int]) -> str:
|
|
return str(test1) + str(test2)
|
|
|
|
tools = [FunctionTool(tool1, description="example tool 1"), FunctionTool(tool2, description="example tool 2")]
|
|
|
|
mockcalculate_vision_tokens = MagicMock()
|
|
monkeypatch.setattr("autogen_ext.models.openai._openai_client.calculate_vision_tokens", mockcalculate_vision_tokens)
|
|
|
|
num_tokens = client.count_tokens(messages, tools=tools)
|
|
assert num_tokens
|
|
|
|
# Check that calculate_vision_tokens was called
|
|
mockcalculate_vision_tokens.assert_called_once()
|
|
|
|
remaining_tokens = client.remaining_tokens(messages, tools=tools)
|
|
assert remaining_tokens
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"mock_size, expected_num_tokens",
|
|
[
|
|
((1, 1), 255),
|
|
((512, 512), 255),
|
|
((2048, 512), 765),
|
|
((2048, 2048), 765),
|
|
((512, 1024), 425),
|
|
],
|
|
)
|
|
def test_openai_count_image_tokens(mock_size: Tuple[int, int], expected_num_tokens: int) -> None:
|
|
# Step 1: Mock the Image class with only the 'image' attribute
|
|
mock_image_attr = MagicMock()
|
|
mock_image_attr.size = mock_size
|
|
|
|
mock_image = MagicMock()
|
|
mock_image.image = mock_image_attr
|
|
|
|
# Directly call calculate_vision_tokens and check the result
|
|
calculated_tokens = calculate_vision_tokens(mock_image, detail="auto")
|
|
assert calculated_tokens == expected_num_tokens
|
|
|
|
|
|
def test_convert_tools_accepts_both_func_tool_and_schema() -> None:
|
|
def my_function(arg: str, other: Annotated[int, "int arg"], nonrequired: int = 5) -> MyResult:
|
|
return MyResult(result="test")
|
|
|
|
tool = FunctionTool(my_function, description="Function tool.")
|
|
schema = tool.schema
|
|
|
|
converted_tool_schema = convert_tools([tool, schema])
|
|
|
|
assert len(converted_tool_schema) == 2
|
|
assert converted_tool_schema[0] == converted_tool_schema[1]
|
|
|
|
|
|
def test_convert_tools_accepts_both_tool_and_schema() -> None:
|
|
class MyTool(BaseTool[MyArgs, MyResult]):
|
|
def __init__(self) -> None:
|
|
super().__init__(
|
|
args_type=MyArgs,
|
|
return_type=MyResult,
|
|
name="TestTool",
|
|
description="Description of test tool.",
|
|
)
|
|
|
|
async def run(self, args: MyArgs, cancellation_token: CancellationToken) -> MyResult:
|
|
return MyResult(result="value")
|
|
|
|
tool = MyTool()
|
|
schema = tool.schema
|
|
|
|
converted_tool_schema = convert_tools([tool, schema])
|
|
|
|
assert len(converted_tool_schema) == 2
|
|
assert converted_tool_schema[0] == converted_tool_schema[1]
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_json_mode(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
model = "gpt-4o-2024-11-20"
|
|
|
|
called_args = {}
|
|
|
|
async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion:
|
|
# Capture the arguments passed to the function
|
|
called_args["kwargs"] = kwargs
|
|
return ChatCompletion(
|
|
id="id1",
|
|
choices=[
|
|
Choice(
|
|
finish_reason="stop",
|
|
index=0,
|
|
message=ChatCompletionMessage(
|
|
content=json.dumps({"thoughts": "happy", "response": "happy"}),
|
|
role="assistant",
|
|
),
|
|
)
|
|
],
|
|
created=0,
|
|
model=model,
|
|
object="chat.completion",
|
|
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
|
)
|
|
|
|
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
|
|
model_client = OpenAIChatCompletionClient(model=model, api_key="")
|
|
|
|
# Test that the openai client was called with the correct response format.
|
|
create_result = await model_client.create(
|
|
messages=[UserMessage(content="I am happy.", source="user")], json_output=True
|
|
)
|
|
assert isinstance(create_result.content, str)
|
|
response = json.loads(create_result.content)
|
|
assert response["thoughts"] == "happy"
|
|
assert response["response"] == "happy"
|
|
assert called_args["kwargs"]["response_format"] == {"type": "json_object"}
|
|
|
|
# Make sure that the response format is set to json_object when json_output is True, regardless of the extra_create_args.
|
|
create_result = await model_client.create(
|
|
messages=[UserMessage(content="I am happy.", source="user")],
|
|
json_output=True,
|
|
extra_create_args={"response_format": "json_object"},
|
|
)
|
|
assert isinstance(create_result.content, str)
|
|
response = json.loads(create_result.content)
|
|
assert response["thoughts"] == "happy"
|
|
assert response["response"] == "happy"
|
|
assert called_args["kwargs"]["response_format"] == {"type": "json_object"}
|
|
|
|
create_result = await model_client.create(
|
|
messages=[UserMessage(content="I am happy.", source="user")],
|
|
json_output=True,
|
|
extra_create_args={"response_format": "text"},
|
|
)
|
|
assert isinstance(create_result.content, str)
|
|
response = json.loads(create_result.content)
|
|
assert response["thoughts"] == "happy"
|
|
assert response["response"] == "happy"
|
|
# Check that the openai client was called with the correct response format.
|
|
assert called_args["kwargs"]["response_format"] == {"type": "json_object"}
|
|
|
|
# Make sure when json_output is set to False, the response format is always set to text.
|
|
create_result = await model_client.create(
|
|
messages=[UserMessage(content="I am happy.", source="user")],
|
|
json_output=False,
|
|
extra_create_args={"response_format": "text"},
|
|
)
|
|
assert called_args["kwargs"]["response_format"] == {"type": "text"}
|
|
|
|
create_result = await model_client.create(
|
|
messages=[UserMessage(content="I am happy.", source="user")],
|
|
json_output=False,
|
|
extra_create_args={"response_format": "json_object"},
|
|
)
|
|
assert called_args["kwargs"]["response_format"] == {"type": "text"}
|
|
|
|
# Make sure when response_format is set it is used when json_output is not set.
|
|
create_result = await model_client.create(
|
|
messages=[UserMessage(content="I am happy.", source="user")],
|
|
extra_create_args={"response_format": {"type": "json_object"}},
|
|
)
|
|
assert isinstance(create_result.content, str)
|
|
response = json.loads(create_result.content)
|
|
assert response["thoughts"] == "happy"
|
|
assert response["response"] == "happy"
|
|
assert called_args["kwargs"]["response_format"] == {"type": "json_object"}
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_structured_output_using_response_format(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
class AgentResponse(BaseModel):
|
|
thoughts: str
|
|
response: Literal["happy", "sad", "neutral"]
|
|
|
|
model = "gpt-4o-2024-11-20"
|
|
|
|
called_args = {}
|
|
|
|
async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion:
|
|
# Capture the arguments passed to the function
|
|
called_args["kwargs"] = kwargs
|
|
return ChatCompletion(
|
|
id="id1",
|
|
choices=[
|
|
Choice(
|
|
finish_reason="stop",
|
|
index=0,
|
|
message=ChatCompletionMessage(
|
|
content=json.dumps({"thoughts": "happy", "response": "happy"}),
|
|
role="assistant",
|
|
),
|
|
)
|
|
],
|
|
created=0,
|
|
model=model,
|
|
object="chat.completion",
|
|
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
|
)
|
|
|
|
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
|
|
|
|
# Scenario 1: response_format is set to constructor.
|
|
model_client = OpenAIChatCompletionClient(
|
|
model=model,
|
|
api_key="",
|
|
response_format={
|
|
"type": "json_schema",
|
|
"json_schema": {
|
|
"name": "test",
|
|
"description": "test",
|
|
"schema": AgentResponse.model_json_schema(),
|
|
},
|
|
},
|
|
)
|
|
|
|
create_result = await model_client.create(
|
|
messages=[UserMessage(content="I am happy.", source="user")],
|
|
)
|
|
assert isinstance(create_result.content, str)
|
|
response = json.loads(create_result.content)
|
|
assert response["thoughts"] == "happy"
|
|
assert response["response"] == "happy"
|
|
assert called_args["kwargs"]["response_format"]["type"] == "json_schema"
|
|
|
|
# Test the response format can be serailized and deserialized.
|
|
config = model_client.dump_component()
|
|
assert config
|
|
loaded_client = OpenAIChatCompletionClient.load_component(config)
|
|
|
|
create_result = await loaded_client.create(
|
|
messages=[UserMessage(content="I am happy.", source="user")],
|
|
)
|
|
assert isinstance(create_result.content, str)
|
|
response = json.loads(create_result.content)
|
|
assert response["thoughts"] == "happy"
|
|
assert response["response"] == "happy"
|
|
assert called_args["kwargs"]["response_format"]["type"] == "json_schema"
|
|
|
|
# Scenario 2: response_format is set to a extra_create_args.
|
|
model_client = OpenAIChatCompletionClient(model=model, api_key="")
|
|
create_result = await model_client.create(
|
|
messages=[UserMessage(content="I am happy.", source="user")],
|
|
extra_create_args={
|
|
"response_format": {
|
|
"type": "json_schema",
|
|
"json_schema": {
|
|
"name": "test",
|
|
"description": "test",
|
|
"schema": AgentResponse.model_json_schema(),
|
|
},
|
|
}
|
|
},
|
|
)
|
|
assert isinstance(create_result.content, str)
|
|
response = json.loads(create_result.content)
|
|
assert response["thoughts"] == "happy"
|
|
assert response["response"] == "happy"
|
|
assert called_args["kwargs"]["response_format"]["type"] == "json_schema"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
class AgentResponse(BaseModel):
|
|
thoughts: str
|
|
response: Literal["happy", "sad", "neutral"]
|
|
|
|
model = "gpt-4o-2024-11-20"
|
|
|
|
async def _mock_parse(*args: Any, **kwargs: Any) -> ParsedChatCompletion[AgentResponse]:
|
|
return ParsedChatCompletion(
|
|
id="id1",
|
|
choices=[
|
|
ParsedChoice(
|
|
finish_reason="stop",
|
|
index=0,
|
|
message=ParsedChatCompletionMessage(
|
|
content=json.dumps(
|
|
{
|
|
"thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
|
|
"response": "happy",
|
|
}
|
|
),
|
|
role="assistant",
|
|
),
|
|
)
|
|
],
|
|
created=0,
|
|
model=model,
|
|
object="chat.completion",
|
|
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
|
)
|
|
|
|
monkeypatch.setattr(BetaAsyncCompletions, "parse", _mock_parse)
|
|
|
|
model_client = OpenAIChatCompletionClient(
|
|
model=model,
|
|
api_key="",
|
|
)
|
|
|
|
# Test that the openai client was called with the correct response format.
|
|
create_result = await model_client.create(
|
|
messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
|
|
)
|
|
assert isinstance(create_result.content, str)
|
|
response = AgentResponse.model_validate(json.loads(create_result.content))
|
|
assert (
|
|
response.thoughts
|
|
== "The user explicitly states that they are happy without any indication of sadness or neutrality."
|
|
)
|
|
assert response.response == "happy"
|
|
|
|
# Test that a warning will be raise if response_format is set to a dict.
|
|
with pytest.warns(
|
|
UserWarning,
|
|
match="response_format is found in extra_create_args while json_output is set to a Pydantic model class.",
|
|
):
|
|
create_result = await model_client.create(
|
|
messages=[UserMessage(content="I am happy.", source="user")],
|
|
json_output=AgentResponse,
|
|
extra_create_args={"response_format": {"type": "json_object"}},
|
|
)
|
|
|
|
# Test that a warning will be raised if response_format is set to a pydantic model.
|
|
with pytest.warns(
|
|
DeprecationWarning,
|
|
match="Using response_format to specify the BaseModel for structured output type will be deprecated.",
|
|
):
|
|
create_result = await model_client.create(
|
|
messages=[UserMessage(content="I am happy.", source="user")],
|
|
extra_create_args={"response_format": AgentResponse},
|
|
)
|
|
|
|
# Test that a ValueError will be raised if response_format and json_output are set to a pydantic model.
|
|
with pytest.raises(
|
|
ValueError, match="response_format and json_output cannot be set to a Pydantic model class at the same time."
|
|
):
|
|
create_result = await model_client.create(
|
|
messages=[UserMessage(content="I am happy.", source="user")],
|
|
json_output=AgentResponse,
|
|
extra_create_args={"response_format": AgentResponse},
|
|
)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_structured_output_with_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
class AgentResponse(BaseModel):
|
|
thoughts: str
|
|
response: Literal["happy", "sad", "neutral"]
|
|
|
|
model = "gpt-4o-2024-11-20"
|
|
|
|
async def _mock_parse(*args: Any, **kwargs: Any) -> ParsedChatCompletion[AgentResponse]:
|
|
return ParsedChatCompletion(
|
|
id="id1",
|
|
choices=[
|
|
ParsedChoice(
|
|
finish_reason="tool_calls",
|
|
index=0,
|
|
message=ParsedChatCompletionMessage(
|
|
content=json.dumps(
|
|
{
|
|
"thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
|
|
"response": "happy",
|
|
}
|
|
),
|
|
role="assistant",
|
|
tool_calls=[
|
|
ParsedFunctionToolCall(
|
|
id="1",
|
|
type="function",
|
|
function=ParsedFunction(
|
|
name="_pass_function",
|
|
arguments=json.dumps({"input": "happy"}),
|
|
),
|
|
)
|
|
],
|
|
),
|
|
)
|
|
],
|
|
created=0,
|
|
model=model,
|
|
object="chat.completion",
|
|
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
|
)
|
|
|
|
monkeypatch.setattr(BetaAsyncCompletions, "parse", _mock_parse)
|
|
|
|
model_client = OpenAIChatCompletionClient(
|
|
model=model,
|
|
api_key="",
|
|
)
|
|
|
|
# Test that the openai client was called with the correct response format.
|
|
create_result = await model_client.create(
|
|
messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
|
|
)
|
|
assert isinstance(create_result.content, list)
|
|
assert len(create_result.content) == 1
|
|
assert create_result.content[0] == FunctionCall(
|
|
id="1", name="_pass_function", arguments=json.dumps({"input": "happy"})
|
|
)
|
|
assert isinstance(create_result.thought, str)
|
|
response = AgentResponse.model_validate(json.loads(create_result.thought))
|
|
assert (
|
|
response.thoughts
|
|
== "The user explicitly states that they are happy without any indication of sadness or neutrality."
|
|
)
|
|
assert response.response == "happy"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_structured_output_with_streaming(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
class AgentResponse(BaseModel):
|
|
thoughts: str
|
|
response: Literal["happy", "sad", "neutral"]
|
|
|
|
raw_content = json.dumps(
|
|
{
|
|
"thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
|
|
"response": "happy",
|
|
}
|
|
)
|
|
chunked_content = [raw_content[i : i + 5] for i in range(0, len(raw_content), 5)]
|
|
assert "".join(chunked_content) == raw_content
|
|
|
|
model = "gpt-4o-2024-11-20"
|
|
mock_chunk_events = [
|
|
MockChunkEvent(
|
|
type="chunk",
|
|
chunk=ChatCompletionChunk(
|
|
id="id",
|
|
choices=[
|
|
ChunkChoice(
|
|
finish_reason=None,
|
|
index=0,
|
|
delta=ChoiceDelta(
|
|
content=mock_chunk_content,
|
|
role="assistant",
|
|
),
|
|
)
|
|
],
|
|
created=0,
|
|
model=model,
|
|
object="chat.completion.chunk",
|
|
usage=None,
|
|
),
|
|
)
|
|
for mock_chunk_content in chunked_content
|
|
]
|
|
|
|
async def _mock_create_stream(*args: Any) -> AsyncGenerator[MockChunkEvent, None]:
|
|
async def _stream() -> AsyncGenerator[MockChunkEvent, None]:
|
|
for mock_chunk_event in mock_chunk_events:
|
|
await asyncio.sleep(0.1)
|
|
yield mock_chunk_event
|
|
|
|
return _stream()
|
|
|
|
# Mock the context manager __aenter__ method which returns the stream.
|
|
monkeypatch.setattr(BetaAsyncChatCompletionStreamManager, "__aenter__", _mock_create_stream)
|
|
|
|
model_client = OpenAIChatCompletionClient(
|
|
model=model,
|
|
api_key="",
|
|
)
|
|
|
|
# Test that the openai client was called with the correct response format.
|
|
chunks: List[str | CreateResult] = []
|
|
async for chunk in model_client.create_stream(
|
|
messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
|
|
):
|
|
chunks.append(chunk)
|
|
assert len(chunks) > 0
|
|
assert isinstance(chunks[-1], CreateResult)
|
|
assert isinstance(chunks[-1].content, str)
|
|
response = AgentResponse.model_validate(json.loads(chunks[-1].content))
|
|
assert (
|
|
response.thoughts
|
|
== "The user explicitly states that they are happy without any indication of sadness or neutrality."
|
|
)
|
|
assert response.response == "happy"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_structured_output_with_streaming_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
class AgentResponse(BaseModel):
|
|
thoughts: str
|
|
response: Literal["happy", "sad", "neutral"]
|
|
|
|
raw_content = json.dumps(
|
|
{
|
|
"thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
|
|
"response": "happy",
|
|
}
|
|
)
|
|
chunked_content = [raw_content[i : i + 5] for i in range(0, len(raw_content), 5)]
|
|
assert "".join(chunked_content) == raw_content
|
|
|
|
model = "gpt-4o-2024-11-20"
|
|
|
|
# generate the list of mock chunk content
|
|
mock_chunk_events = [
|
|
MockChunkEvent(
|
|
type="chunk",
|
|
chunk=ChatCompletionChunk(
|
|
id="id",
|
|
choices=[
|
|
ChunkChoice(
|
|
finish_reason=None,
|
|
index=0,
|
|
delta=ChoiceDelta(
|
|
content=mock_chunk_content,
|
|
role="assistant",
|
|
),
|
|
)
|
|
],
|
|
created=0,
|
|
model=model,
|
|
object="chat.completion.chunk",
|
|
usage=None,
|
|
),
|
|
)
|
|
for mock_chunk_content in chunked_content
|
|
]
|
|
|
|
# add the tool call chunk.
|
|
mock_chunk_events += [
|
|
MockChunkEvent(
|
|
type="chunk",
|
|
chunk=ChatCompletionChunk(
|
|
id="id",
|
|
choices=[
|
|
ChunkChoice(
|
|
finish_reason="tool_calls",
|
|
index=0,
|
|
delta=ChoiceDelta(
|
|
content=None,
|
|
role="assistant",
|
|
tool_calls=[
|
|
ChoiceDeltaToolCall(
|
|
id="1",
|
|
index=0,
|
|
type="function",
|
|
function=ChoiceDeltaToolCallFunction(
|
|
name="_pass_function",
|
|
arguments=json.dumps({"input": "happy"}),
|
|
),
|
|
)
|
|
],
|
|
),
|
|
)
|
|
],
|
|
created=0,
|
|
model=model,
|
|
object="chat.completion.chunk",
|
|
usage=None,
|
|
),
|
|
)
|
|
]
|
|
|
|
async def _mock_create_stream(*args: Any) -> AsyncGenerator[MockChunkEvent, None]:
|
|
async def _stream() -> AsyncGenerator[MockChunkEvent, None]:
|
|
for mock_chunk_event in mock_chunk_events:
|
|
await asyncio.sleep(0.1)
|
|
yield mock_chunk_event
|
|
|
|
return _stream()
|
|
|
|
# Mock the context manager __aenter__ method which returns the stream.
|
|
monkeypatch.setattr(BetaAsyncChatCompletionStreamManager, "__aenter__", _mock_create_stream)
|
|
|
|
model_client = OpenAIChatCompletionClient(
|
|
model=model,
|
|
api_key="",
|
|
)
|
|
|
|
# Test that the openai client was called with the correct response format.
|
|
chunks: List[str | CreateResult] = []
|
|
async for chunk in model_client.create_stream(
|
|
messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
|
|
):
|
|
chunks.append(chunk)
|
|
assert len(chunks) > 0
|
|
assert isinstance(chunks[-1], CreateResult)
|
|
assert isinstance(chunks[-1].content, list)
|
|
assert len(chunks[-1].content) == 1
|
|
assert chunks[-1].content[0] == FunctionCall(
|
|
id="1", name="_pass_function", arguments=json.dumps({"input": "happy"})
|
|
)
|
|
assert isinstance(chunks[-1].thought, str)
|
|
response = AgentResponse.model_validate(json.loads(chunks[-1].thought))
|
|
assert (
|
|
response.thoughts
|
|
== "The user explicitly states that they are happy without any indication of sadness or neutrality."
|
|
)
|
|
assert response.response == "happy"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_r1_reasoning_content(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
"""Test handling of reasoning_content in R1 model. Testing create without streaming."""
|
|
|
|
async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion:
|
|
return ChatCompletion(
|
|
id="test_id",
|
|
model="r1",
|
|
object="chat.completion",
|
|
created=1234567890,
|
|
choices=[
|
|
Choice(
|
|
index=0,
|
|
message=ChatCompletionMessage(
|
|
role="assistant",
|
|
content="This is the main content",
|
|
# The reasoning content is included in model_extra for hosted R1 models.
|
|
reasoning_content="This is the reasoning content", # type: ignore
|
|
),
|
|
finish_reason="stop",
|
|
)
|
|
],
|
|
usage=CompletionUsage(
|
|
prompt_tokens=10,
|
|
completion_tokens=10,
|
|
total_tokens=20,
|
|
),
|
|
)
|
|
|
|
# Patch the client creation
|
|
|
|
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
|
|
|
|
# Create the client
|
|
model_client = OpenAIChatCompletionClient(
|
|
model="r1",
|
|
api_key="",
|
|
model_info={
|
|
"family": ModelFamily.R1,
|
|
"vision": False,
|
|
"function_calling": False,
|
|
"json_output": False,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
# Test the create method
|
|
result = await model_client.create([UserMessage(content="Test message", source="user")])
|
|
|
|
# Verify that the content and thought are as expected
|
|
assert result.content == "This is the main content"
|
|
assert result.thought == "This is the reasoning content"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_r1_reasoning_content_streaming(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
"""Test that reasoning_content in model_extra is correctly extracted and streamed."""
|
|
|
|
async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
|
|
contentChunks = [None, None, "This is the main content"]
|
|
reasoningChunks = ["This is the reasoning content 1", "This is the reasoning content 2", None]
|
|
for i in range(len(contentChunks)):
|
|
await asyncio.sleep(0.1)
|
|
yield ChatCompletionChunk(
|
|
id="id",
|
|
choices=[
|
|
ChunkChoice(
|
|
finish_reason="stop" if i == len(contentChunks) - 1 else None,
|
|
index=0,
|
|
delta=ChoiceDelta(
|
|
content=contentChunks[i],
|
|
# The reasoning content is included in model_extra for hosted R1 models.
|
|
reasoning_content=reasoningChunks[i], # type: ignore
|
|
role="assistant",
|
|
),
|
|
),
|
|
],
|
|
created=0,
|
|
model="r1",
|
|
object="chat.completion.chunk",
|
|
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
|
|
)
|
|
|
|
async def _mock_create(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
|
|
return _mock_create_stream(*args, **kwargs)
|
|
|
|
# Patch the client creation
|
|
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
|
|
# Create the client
|
|
model_client = OpenAIChatCompletionClient(
|
|
model="r1",
|
|
api_key="",
|
|
model_info={
|
|
"family": ModelFamily.R1,
|
|
"vision": False,
|
|
"function_calling": False,
|
|
"json_output": False,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
# Test the create_stream method
|
|
chunks: List[str | CreateResult] = []
|
|
async for chunk in model_client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
|
|
chunks.append(chunk)
|
|
|
|
# Verify that the chunks first stream the reasoning content and then the main content
|
|
# Then verify that the final result has the correct content and thought
|
|
assert len(chunks) == 5
|
|
assert chunks[0] == "<think>This is the reasoning content 1"
|
|
assert chunks[1] == "This is the reasoning content 2"
|
|
assert chunks[2] == "</think>"
|
|
assert chunks[3] == "This is the main content"
|
|
assert isinstance(chunks[4], CreateResult)
|
|
assert chunks[4].content == "This is the main content"
|
|
assert chunks[4].thought == "This is the reasoning content 1This is the reasoning content 2"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_r1_think_field(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
|
|
chunks = ["<think> Hello</think>", " Another Hello", " Yet Another Hello"]
|
|
for i, chunk in enumerate(chunks):
|
|
await asyncio.sleep(0.1)
|
|
yield ChatCompletionChunk(
|
|
id="id",
|
|
choices=[
|
|
ChunkChoice(
|
|
finish_reason="stop" if i == len(chunks) - 1 else None,
|
|
index=0,
|
|
delta=ChoiceDelta(
|
|
content=chunk,
|
|
role="assistant",
|
|
),
|
|
),
|
|
],
|
|
created=0,
|
|
model="r1",
|
|
object="chat.completion.chunk",
|
|
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
|
|
)
|
|
|
|
async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
|
|
stream = kwargs.get("stream", False)
|
|
if not stream:
|
|
await asyncio.sleep(0.1)
|
|
return ChatCompletion(
|
|
id="id",
|
|
choices=[
|
|
Choice(
|
|
finish_reason="stop",
|
|
index=0,
|
|
message=ChatCompletionMessage(
|
|
content="<think> Hello</think> Another Hello Yet Another Hello", role="assistant"
|
|
),
|
|
)
|
|
],
|
|
created=0,
|
|
model="r1",
|
|
object="chat.completion",
|
|
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
|
|
)
|
|
else:
|
|
return _mock_create_stream(*args, **kwargs)
|
|
|
|
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
|
|
|
|
model_client = OpenAIChatCompletionClient(
|
|
model="r1",
|
|
api_key="",
|
|
model_info={
|
|
"family": ModelFamily.R1,
|
|
"vision": False,
|
|
"function_calling": False,
|
|
"json_output": False,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
# Successful completion with think field.
|
|
create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")])
|
|
assert create_result.content == "Another Hello Yet Another Hello"
|
|
assert create_result.finish_reason == "stop"
|
|
assert not create_result.cached
|
|
assert create_result.thought == "Hello"
|
|
|
|
# Stream completion with think field.
|
|
chunks: List[str | CreateResult] = []
|
|
async for chunk in model_client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
|
|
chunks.append(chunk)
|
|
assert len(chunks) > 0
|
|
assert isinstance(chunks[-1], CreateResult)
|
|
assert chunks[-1].content == "Another Hello Yet Another Hello"
|
|
assert chunks[-1].thought == "Hello"
|
|
assert not chunks[-1].cached
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_r1_think_field_not_present(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
|
|
chunks = ["Hello", " Another Hello", " Yet Another Hello"]
|
|
for i, chunk in enumerate(chunks):
|
|
await asyncio.sleep(0.1)
|
|
yield ChatCompletionChunk(
|
|
id="id",
|
|
choices=[
|
|
ChunkChoice(
|
|
finish_reason="stop" if i == len(chunks) - 1 else None,
|
|
index=0,
|
|
delta=ChoiceDelta(
|
|
content=chunk,
|
|
role="assistant",
|
|
),
|
|
),
|
|
],
|
|
created=0,
|
|
model="r1",
|
|
object="chat.completion.chunk",
|
|
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
|
|
)
|
|
|
|
async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
|
|
stream = kwargs.get("stream", False)
|
|
if not stream:
|
|
await asyncio.sleep(0.1)
|
|
return ChatCompletion(
|
|
id="id",
|
|
choices=[
|
|
Choice(
|
|
finish_reason="stop",
|
|
index=0,
|
|
message=ChatCompletionMessage(
|
|
content="Hello Another Hello Yet Another Hello", role="assistant"
|
|
),
|
|
)
|
|
],
|
|
created=0,
|
|
model="r1",
|
|
object="chat.completion",
|
|
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
|
|
)
|
|
else:
|
|
return _mock_create_stream(*args, **kwargs)
|
|
|
|
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
|
|
|
|
model_client = OpenAIChatCompletionClient(
|
|
model="r1",
|
|
api_key="",
|
|
model_info={
|
|
"family": ModelFamily.R1,
|
|
"vision": False,
|
|
"function_calling": False,
|
|
"json_output": False,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
# Warning completion when think field is not present.
|
|
with pytest.warns(UserWarning, match="Could not find <think>..</think> field in model response content."):
|
|
create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")])
|
|
assert create_result.content == "Hello Another Hello Yet Another Hello"
|
|
assert create_result.finish_reason == "stop"
|
|
assert not create_result.cached
|
|
assert create_result.thought is None
|
|
|
|
# Stream completion with think field.
|
|
with pytest.warns(UserWarning, match="Could not find <think>..</think> field in model response content."):
|
|
chunks: List[str | CreateResult] = []
|
|
async for chunk in model_client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
|
|
chunks.append(chunk)
|
|
assert len(chunks) > 0
|
|
assert isinstance(chunks[-1], CreateResult)
|
|
assert chunks[-1].content == "Hello Another Hello Yet Another Hello"
|
|
assert chunks[-1].thought is None
|
|
assert not chunks[-1].cached
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
model = "gpt-4o-2024-05-13"
|
|
chat_completions = [
|
|
# Successful completion, single tool call
|
|
ChatCompletion(
|
|
id="id1",
|
|
choices=[
|
|
Choice(
|
|
finish_reason="tool_calls",
|
|
index=0,
|
|
message=ChatCompletionMessage(
|
|
content=None,
|
|
tool_calls=[
|
|
ChatCompletionMessageToolCall(
|
|
id="1",
|
|
type="function",
|
|
function=Function(
|
|
name="_pass_function",
|
|
arguments=json.dumps({"input": "task"}),
|
|
),
|
|
)
|
|
],
|
|
role="assistant",
|
|
),
|
|
)
|
|
],
|
|
created=0,
|
|
model=model,
|
|
object="chat.completion",
|
|
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
|
),
|
|
# Successful completion, parallel tool calls
|
|
ChatCompletion(
|
|
id="id2",
|
|
choices=[
|
|
Choice(
|
|
finish_reason="tool_calls",
|
|
index=0,
|
|
message=ChatCompletionMessage(
|
|
content=None,
|
|
tool_calls=[
|
|
ChatCompletionMessageToolCall(
|
|
id="1",
|
|
type="function",
|
|
function=Function(
|
|
name="_pass_function",
|
|
arguments=json.dumps({"input": "task"}),
|
|
),
|
|
),
|
|
ChatCompletionMessageToolCall(
|
|
id="2",
|
|
type="function",
|
|
function=Function(
|
|
name="_fail_function",
|
|
arguments=json.dumps({"input": "task"}),
|
|
),
|
|
),
|
|
ChatCompletionMessageToolCall(
|
|
id="3",
|
|
type="function",
|
|
function=Function(
|
|
name="_echo_function",
|
|
arguments=json.dumps({"input": "task"}),
|
|
),
|
|
),
|
|
],
|
|
role="assistant",
|
|
),
|
|
)
|
|
],
|
|
created=0,
|
|
model=model,
|
|
object="chat.completion",
|
|
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
|
),
|
|
# Warning completion when finish reason is not tool_calls.
|
|
ChatCompletion(
|
|
id="id3",
|
|
choices=[
|
|
Choice(
|
|
finish_reason="stop",
|
|
index=0,
|
|
message=ChatCompletionMessage(
|
|
content=None,
|
|
tool_calls=[
|
|
ChatCompletionMessageToolCall(
|
|
id="1",
|
|
type="function",
|
|
function=Function(
|
|
name="_pass_function",
|
|
arguments=json.dumps({"input": "task"}),
|
|
),
|
|
)
|
|
],
|
|
role="assistant",
|
|
),
|
|
)
|
|
],
|
|
created=0,
|
|
model=model,
|
|
object="chat.completion",
|
|
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
|
),
|
|
# Thought field is populated when content is not None.
|
|
ChatCompletion(
|
|
id="id4",
|
|
choices=[
|
|
Choice(
|
|
finish_reason="tool_calls",
|
|
index=0,
|
|
message=ChatCompletionMessage(
|
|
content="I should make a tool call.",
|
|
tool_calls=[
|
|
ChatCompletionMessageToolCall(
|
|
id="1",
|
|
type="function",
|
|
function=Function(
|
|
name="_pass_function",
|
|
arguments=json.dumps({"input": "task"}),
|
|
),
|
|
)
|
|
],
|
|
role="assistant",
|
|
),
|
|
)
|
|
],
|
|
created=0,
|
|
model=model,
|
|
object="chat.completion",
|
|
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
|
),
|
|
# Should not be returning tool calls when the tool_calls are empty
|
|
ChatCompletion(
|
|
id="id5",
|
|
choices=[
|
|
Choice(
|
|
finish_reason="stop",
|
|
index=0,
|
|
message=ChatCompletionMessage(
|
|
content="I should make a tool call.",
|
|
tool_calls=[],
|
|
role="assistant",
|
|
),
|
|
)
|
|
],
|
|
created=0,
|
|
model=model,
|
|
object="chat.completion",
|
|
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
|
),
|
|
# Should raise warning when function arguments is not a string.
|
|
ChatCompletion(
|
|
id="id6",
|
|
choices=[
|
|
Choice(
|
|
finish_reason="tool_calls",
|
|
index=0,
|
|
message=ChatCompletionMessage(
|
|
content=None,
|
|
tool_calls=[
|
|
ChatCompletionMessageToolCall(
|
|
id="1",
|
|
type="function",
|
|
function=Function.construct(name="_pass_function", arguments={"input": "task"}), # type: ignore
|
|
)
|
|
],
|
|
role="assistant",
|
|
),
|
|
)
|
|
],
|
|
created=0,
|
|
model=model,
|
|
object="chat.completion",
|
|
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
|
),
|
|
]
|
|
|
|
class _MockChatCompletion:
|
|
def __init__(self, completions: List[ChatCompletion]):
|
|
self.completions = list(completions)
|
|
self.calls: List[Dict[str, Any]] = []
|
|
|
|
async def mock_create(
|
|
self, *args: Any, **kwargs: Any
|
|
) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
|
|
if kwargs.get("stream", False):
|
|
raise NotImplementedError("Streaming not supported in this test.")
|
|
self.calls.append(kwargs)
|
|
return self.completions.pop(0)
|
|
|
|
mock = _MockChatCompletion(chat_completions)
|
|
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
|
pass_tool = FunctionTool(_pass_function, description="pass tool.")
|
|
fail_tool = FunctionTool(_fail_function, description="fail tool.")
|
|
echo_tool = FunctionTool(_echo_function, description="echo tool.")
|
|
model_client = OpenAIChatCompletionClient(model=model, api_key="")
|
|
|
|
# Single tool call
|
|
create_result = await model_client.create(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
|
|
assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
|
|
# Verify that the tool schema was passed to the model client.
|
|
kwargs = mock.calls[0]
|
|
assert kwargs["tools"] == [{"function": pass_tool.schema, "type": "function"}]
|
|
# Verify finish reason
|
|
assert create_result.finish_reason == "function_calls"
|
|
|
|
# Parallel tool calls
|
|
create_result = await model_client.create(
|
|
messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool, fail_tool, echo_tool]
|
|
)
|
|
assert create_result.content == [
|
|
FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function"),
|
|
FunctionCall(id="2", arguments=r'{"input": "task"}', name="_fail_function"),
|
|
FunctionCall(id="3", arguments=r'{"input": "task"}', name="_echo_function"),
|
|
]
|
|
# Verify that the tool schema was passed to the model client.
|
|
kwargs = mock.calls[1]
|
|
assert kwargs["tools"] == [
|
|
{"function": pass_tool.schema, "type": "function"},
|
|
{"function": fail_tool.schema, "type": "function"},
|
|
{"function": echo_tool.schema, "type": "function"},
|
|
]
|
|
# Verify finish reason
|
|
assert create_result.finish_reason == "function_calls"
|
|
|
|
# Warning completion when finish reason is not tool_calls.
|
|
with pytest.warns(UserWarning, match="Finish reason mismatch"):
|
|
create_result = await model_client.create(
|
|
messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
|
|
)
|
|
assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
|
|
assert create_result.finish_reason == "function_calls"
|
|
|
|
# Thought field is populated when content is not None.
|
|
create_result = await model_client.create(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
|
|
assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
|
|
assert create_result.finish_reason == "function_calls"
|
|
assert create_result.thought == "I should make a tool call."
|
|
|
|
# Should not be returning tool calls when the tool_calls are empty
|
|
create_result = await model_client.create(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
|
|
assert create_result.content == "I should make a tool call."
|
|
assert create_result.finish_reason == "stop"
|
|
|
|
# Should raise warning when function arguments is not a string.
|
|
with pytest.warns(UserWarning, match="Tool call function arguments field is not a string"):
|
|
create_result = await model_client.create(
|
|
messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
|
|
)
|
|
assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
|
|
assert create_result.finish_reason == "function_calls"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_tool_calling_with_stream(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
|
|
model = resolve_model(kwargs.get("model", "gpt-4o"))
|
|
mock_chunks_content = ["Hello", " Another Hello", " Yet Another Hello"]
|
|
mock_chunks = [
|
|
# generate the list of mock chunk content
|
|
MockChunkDefinition(
|
|
chunk_choice=ChunkChoice(
|
|
finish_reason=None,
|
|
index=0,
|
|
delta=ChoiceDelta(
|
|
content=mock_chunk_content,
|
|
role="assistant",
|
|
),
|
|
),
|
|
usage=None,
|
|
)
|
|
for mock_chunk_content in mock_chunks_content
|
|
] + [
|
|
# generate the function call chunk
|
|
MockChunkDefinition(
|
|
chunk_choice=ChunkChoice(
|
|
finish_reason="tool_calls",
|
|
index=0,
|
|
delta=ChoiceDelta(
|
|
content=None,
|
|
role="assistant",
|
|
tool_calls=[
|
|
ChoiceDeltaToolCall(
|
|
index=0,
|
|
id="1",
|
|
type="function",
|
|
function=ChoiceDeltaToolCallFunction(
|
|
name="_pass_function",
|
|
arguments=json.dumps({"input": "task"}),
|
|
),
|
|
)
|
|
],
|
|
),
|
|
),
|
|
usage=None,
|
|
)
|
|
]
|
|
for mock_chunk in mock_chunks:
|
|
await asyncio.sleep(0.1)
|
|
yield ChatCompletionChunk(
|
|
id="id",
|
|
choices=[mock_chunk.chunk_choice],
|
|
created=0,
|
|
model=model,
|
|
object="chat.completion.chunk",
|
|
usage=mock_chunk.usage,
|
|
)
|
|
|
|
async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
|
|
stream = kwargs.get("stream", False)
|
|
if not stream:
|
|
raise ValueError("Stream is not False")
|
|
else:
|
|
return _mock_create_stream(*args, **kwargs)
|
|
|
|
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
|
|
|
|
model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="")
|
|
pass_tool = FunctionTool(_pass_function, description="pass tool.")
|
|
stream = model_client.create_stream(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
|
|
chunks: List[str | CreateResult] = []
|
|
async for chunk in stream:
|
|
chunks.append(chunk)
|
|
assert chunks[0] == "Hello"
|
|
assert chunks[1] == " Another Hello"
|
|
assert chunks[2] == " Yet Another Hello"
|
|
assert isinstance(chunks[-1], CreateResult)
|
|
assert chunks[-1].content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
|
|
assert chunks[-1].finish_reason == "function_calls"
|
|
assert chunks[-1].thought == "Hello Another Hello Yet Another Hello"
|
|
|
|
|
|
@pytest.fixture()
|
|
def openai_client(request: pytest.FixtureRequest) -> OpenAIChatCompletionClient:
|
|
model = request.node.callspec.params["model"] # type: ignore
|
|
assert isinstance(model, str)
|
|
if model.startswith("gemini"):
|
|
api_key = os.getenv("GEMINI_API_KEY")
|
|
if not api_key:
|
|
pytest.skip("GEMINI_API_KEY not found in environment variables")
|
|
elif model.startswith("claude"):
|
|
api_key = os.getenv("ANTHROPIC_API_KEY")
|
|
if not api_key:
|
|
pytest.skip("ANTHROPIC_API_KEY not found in environment variables")
|
|
else:
|
|
api_key = os.getenv("OPENAI_API_KEY")
|
|
if not api_key:
|
|
pytest.skip("OPENAI_API_KEY not found in environment variables")
|
|
model_client = OpenAIChatCompletionClient(
|
|
model=model,
|
|
api_key=api_key,
|
|
)
|
|
return model_client
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.parametrize(
|
|
"model",
|
|
["gpt-4o-mini", "gemini-1.5-flash", "claude-3-5-haiku-20241022"],
|
|
)
|
|
async def test_model_client_basic_completion(model: str, openai_client: OpenAIChatCompletionClient) -> None:
|
|
# Test basic completion
|
|
create_result = await openai_client.create(
|
|
messages=[
|
|
SystemMessage(content="You are a helpful assistant."),
|
|
UserMessage(content="Explain to me how AI works.", source="user"),
|
|
]
|
|
)
|
|
assert isinstance(create_result.content, str)
|
|
assert len(create_result.content) > 0
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model",
    ["gpt-4o-mini", "gemini-1.5-flash", "claude-3-5-haiku-20241022"],
)
async def test_model_client_with_function_calling(model: str, openai_client: OpenAIChatCompletionClient) -> None:
    # Test tool calling
    pass_tool = FunctionTool(_pass_function, name="pass_tool", description="pass session.")
    fail_tool = FunctionTool(_fail_function, name="fail_tool", description="fail session.")
    messages: List[LLMMessage] = [UserMessage(content="Call the pass tool with input 'task'", source="user")]
    create_result = await openai_client.create(messages=messages, tools=[pass_tool, fail_tool])
    assert isinstance(create_result.content, list)
    assert len(create_result.content) == 1
    assert isinstance(create_result.content[0], FunctionCall)
    assert create_result.content[0].name == "pass_tool"
    assert json.loads(create_result.content[0].arguments) == {"input": "task"}
    assert create_result.finish_reason == "function_calls"
    assert create_result.usage is not None

    # Test reflection on tool call response.
    messages.append(AssistantMessage(content=create_result.content, source="assistant"))
    messages.append(
        FunctionExecutionResultMessage(
            content=[
                FunctionExecutionResult(
                    content="passed",
                    call_id=create_result.content[0].id,
                    is_error=False,
                    name=create_result.content[0].name,
                )
            ]
        )
    )
    create_result = await openai_client.create(messages=messages)
    assert isinstance(create_result.content, str)
    assert len(create_result.content) > 0

    # Test parallel tool calling
    messages = [
        UserMessage(
            content="Call both the pass tool with input 'task' and the fail tool also with input 'task'", source="user"
        )
    ]
    create_result = await openai_client.create(messages=messages, tools=[pass_tool, fail_tool])
    assert isinstance(create_result.content, list)
    assert len(create_result.content) == 2
    assert isinstance(create_result.content[0], FunctionCall)
    assert create_result.content[0].name == "pass_tool"
    assert json.loads(create_result.content[0].arguments) == {"input": "task"}
    assert isinstance(create_result.content[1], FunctionCall)
    assert create_result.content[1].name == "fail_tool"
    assert json.loads(create_result.content[1].arguments) == {"input": "task"}
    assert create_result.finish_reason == "function_calls"
    assert create_result.usage is not None

    # Test reflection on parallel tool call response.
    messages.append(AssistantMessage(content=create_result.content, source="assistant"))
    messages.append(
        FunctionExecutionResultMessage(
            content=[
                FunctionExecutionResult(
                    content="passed", call_id=create_result.content[0].id, is_error=False, name="pass_tool"
                ),
                FunctionExecutionResult(
                    content="failed", call_id=create_result.content[1].id, is_error=True, name="fail_tool"
                ),
            ]
        )
    )
    create_result = await openai_client.create(messages=messages)
    assert isinstance(create_result.content, str)
    assert len(create_result.content) > 0


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model",
    ["gpt-4o-mini", "gemini-1.5-flash"],
)
async def test_openai_structured_output_using_response_format(
    model: str, openai_client: OpenAIChatCompletionClient
) -> None:
    class AgentResponse(BaseModel):
        thoughts: str
        response: Literal["happy", "sad", "neutral"]

    create_result = await openai_client.create(
        messages=[UserMessage(content="I am happy.", source="user")],
        extra_create_args={
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "AgentResponse",
                    "description": "Agent response",
                    "schema": AgentResponse.model_json_schema(),
                },
            }
        },
    )

    assert isinstance(create_result.content, str)
    assert len(create_result.content) > 0
    response = AgentResponse.model_validate(json.loads(create_result.content))
    assert response.thoughts
    assert response.response in ["happy", "sad", "neutral"]


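# For reference, a completion that satisfies the AgentResponse schema used above (and in the
# json_output tests below) would be a JSON string of roughly this shape; the "thoughts" text is
# model-dependent and shown here only as an illustration:
#   {"thoughts": "The user says they are happy.", "response": "happy"}
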
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model",
    ["gpt-4o-mini", "gemini-1.5-flash"],
)
async def test_openai_structured_output(model: str, openai_client: OpenAIChatCompletionClient) -> None:
    class AgentResponse(BaseModel):
        thoughts: str
        response: Literal["happy", "sad", "neutral"]

    # Test that the openai client was called with the correct response format.
    create_result = await openai_client.create(
        messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
    )
    assert isinstance(create_result.content, str)
    response = AgentResponse.model_validate(json.loads(create_result.content))
    assert response.thoughts
    assert response.response in ["happy", "sad", "neutral"]


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model",
    ["gpt-4o-mini", "gemini-1.5-flash"],
)
async def test_openai_structured_output_with_streaming(model: str, openai_client: OpenAIChatCompletionClient) -> None:
    class AgentResponse(BaseModel):
        thoughts: str
        response: Literal["happy", "sad", "neutral"]

    # Test that the openai client was called with the correct response format.
    stream = openai_client.create_stream(
        messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
    )
    chunks: List[str | CreateResult] = []
    async for chunk in stream:
        chunks.append(chunk)
    assert len(chunks) > 0
    assert isinstance(chunks[-1], CreateResult)
    assert isinstance(chunks[-1].content, str)
    response = AgentResponse.model_validate(json.loads(chunks[-1].content))
    assert response.thoughts
    assert response.response in ["happy", "sad", "neutral"]


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model",
    [
        "gpt-4o-mini",
        # "gemini-1.5-flash", # Gemini models do not support structured output with tool calls from model client.
    ],
)
async def test_openai_structured_output_with_tool_calls(model: str, openai_client: OpenAIChatCompletionClient) -> None:
    class AgentResponse(BaseModel):
        thoughts: str
        response: Literal["happy", "sad", "neutral"]

    def sentiment_analysis(text: str) -> str:
        """Given a text, return the sentiment."""
        return "happy" if "happy" in text else "sad" if "sad" in text else "neutral"

    tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True)

    extra_create_args = {"tool_choice": "required"}

    response1 = await openai_client.create(
        messages=[
            SystemMessage(content="Analyze input text sentiment using the tool provided."),
            UserMessage(content="I am happy.", source="user"),
        ],
        tools=[tool],
        extra_create_args=extra_create_args,
        json_output=AgentResponse,
    )
    assert isinstance(response1.content, list)
    assert len(response1.content) == 1
    assert isinstance(response1.content[0], FunctionCall)
    assert response1.content[0].name == "sentiment_analysis"
    assert json.loads(response1.content[0].arguments) == {"text": "I am happy."}
    assert response1.finish_reason == "function_calls"

    response2 = await openai_client.create(
        messages=[
            SystemMessage(content="Analyze input text sentiment using the tool provided."),
            UserMessage(content="I am happy.", source="user"),
            AssistantMessage(content=response1.content, source="assistant"),
            FunctionExecutionResultMessage(
                content=[
                    FunctionExecutionResult(
                        content="happy", call_id=response1.content[0].id, is_error=False, name=tool.name
                    )
                ]
            ),
        ],
        json_output=AgentResponse,
    )
    assert isinstance(response2.content, str)
    parsed_response = AgentResponse.model_validate(json.loads(response2.content))
    assert parsed_response.thoughts
    assert parsed_response.response in ["happy", "sad", "neutral"]


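# Note on the two-call pattern above: the first create() is expected to return FunctionCall
# items (finish_reason "function_calls"), and only the follow-up call that includes the
# AssistantMessage with those calls plus the FunctionExecutionResultMessage produces the JSON
# matching AgentResponse. The streaming test below exercises the same flow via create_stream().
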
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model",
    [
        "gpt-4o-mini",
        # "gemini-1.5-flash", # Gemini models do not support structured output with tool calls from model client.
    ],
)
async def test_openai_structured_output_with_streaming_tool_calls(
    model: str, openai_client: OpenAIChatCompletionClient
) -> None:
    class AgentResponse(BaseModel):
        thoughts: str
        response: Literal["happy", "sad", "neutral"]

    def sentiment_analysis(text: str) -> str:
        """Given a text, return the sentiment."""
        return "happy" if "happy" in text else "sad" if "sad" in text else "neutral"

    tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True)

    extra_create_args = {"tool_choice": "required"}

    chunks1: List[str | CreateResult] = []
    stream1 = openai_client.create_stream(
        messages=[
            SystemMessage(content="Analyze input text sentiment using the tool provided."),
            UserMessage(content="I am happy.", source="user"),
        ],
        tools=[tool],
        extra_create_args=extra_create_args,
        json_output=AgentResponse,
    )
    async for chunk in stream1:
        chunks1.append(chunk)
    assert len(chunks1) > 0
    create_result1 = chunks1[-1]
    assert isinstance(create_result1, CreateResult)
    assert isinstance(create_result1.content, list)
    assert len(create_result1.content) == 1
    assert isinstance(create_result1.content[0], FunctionCall)
    assert create_result1.content[0].name == "sentiment_analysis"
    assert json.loads(create_result1.content[0].arguments) == {"text": "I am happy."}
    assert create_result1.finish_reason == "function_calls"

    stream2 = openai_client.create_stream(
        messages=[
            SystemMessage(content="Analyze input text sentiment using the tool provided."),
            UserMessage(content="I am happy.", source="user"),
            AssistantMessage(content=create_result1.content, source="assistant"),
            FunctionExecutionResultMessage(
                content=[
                    FunctionExecutionResult(
                        content="happy", call_id=create_result1.content[0].id, is_error=False, name=tool.name
                    )
                ]
            ),
        ],
        json_output=AgentResponse,
    )
    chunks2: List[str | CreateResult] = []
    async for chunk in stream2:
        chunks2.append(chunk)
    assert len(chunks2) > 0
    create_result2 = chunks2[-1]
    assert isinstance(create_result2, CreateResult)
    assert isinstance(create_result2.content, str)
    parsed_response = AgentResponse.model_validate(json.loads(create_result2.content))
    assert parsed_response.thoughts
    assert parsed_response.response in ["happy", "sad", "neutral"]


@pytest.mark.asyncio
async def test_hugging_face() -> None:
    api_key = os.getenv("HF_TOKEN")
    if not api_key:
        pytest.skip("HF_TOKEN not found in environment variables")

    model_client = OpenAIChatCompletionClient(
        model="microsoft/Phi-3.5-mini-instruct",
        api_key=api_key,
        base_url="https://api-inference.huggingface.co/v1/",
        model_info={
            "function_calling": False,
            "json_output": False,
            "vision": False,
            "family": ModelFamily.UNKNOWN,
            "structured_output": False,
        },
    )

    # Test basic completion
    create_result = await model_client.create(
        messages=[
            SystemMessage(content="You are a helpful assistant."),
            UserMessage(content="Explain to me how AI works.", source="user"),
        ]
    )
    assert isinstance(create_result.content, str)
    assert len(create_result.content) > 0


@pytest.mark.asyncio
async def test_ollama() -> None:
    model = "deepseek-r1:1.5b"
    model_info: ModelInfo = {
        "function_calling": False,
        "json_output": False,
        "vision": False,
        "family": ModelFamily.R1,
        "structured_output": False,
    }
    # Check if the model is running locally.
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(f"http://localhost:11434/v1/models/{model}")
            response.raise_for_status()
    except httpx.HTTPStatusError as e:
        pytest.skip(f"{model} model is not running locally: {e}")
    except httpx.ConnectError as e:
        pytest.skip(f"Ollama is not running locally: {e}")

    model_client = OpenAIChatCompletionClient(
        model=model,
        api_key="placeholder",
        base_url="http://localhost:11434/v1",
        model_info=model_info,
    )

    # Test basic completion with the Ollama deepseek-r1:1.5b model.
    create_result = await model_client.create(
        messages=[
            UserMessage(
                content="Taking two balls from a bag of 10 green balls and 20 red balls, "
                "what is the probability of getting a green and a red ball?",
                source="user",
            ),
        ]
    )
    assert isinstance(create_result.content, str)
    assert len(create_result.content) > 0
    assert create_result.finish_reason == "stop"
    assert create_result.usage is not None
    if model_info["family"] == ModelFamily.R1:
        assert create_result.thought is not None

    # Test streaming completion with the Ollama deepseek-r1:1.5b model.
    chunks: List[str | CreateResult] = []
    async for chunk in model_client.create_stream(
        messages=[
            UserMessage(
                content="Taking two balls from a bag of 10 green balls and 20 red balls, "
                "what is the probability of getting a green and a red ball?",
                source="user",
            ),
        ]
    ):
        chunks.append(chunk)
    assert len(chunks) > 0
    assert isinstance(chunks[-1], CreateResult)
    assert chunks[-1].finish_reason == "stop"
    assert len(chunks[-1].content) > 0
    assert chunks[-1].usage is not None
    if model_info["family"] == ModelFamily.R1:
        assert chunks[-1].thought is not None


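# For reference, the probability question used in test_ollama above has a closed-form answer the
# model should converge on: drawing two balls without replacement from 10 green and 20 red gives
# P(one green and one red) = (10 * 20) / C(30, 2) = 200 / 435 = 40/87 ≈ 0.46. The assertions
# deliberately check only that a non-empty answer (and, for R1-family models, a thought) is
# produced, not the numeric value itself.
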
@pytest.mark.asyncio
async def test_add_name_prefixes(monkeypatch: pytest.MonkeyPatch) -> None:
    sys_message = SystemMessage(content="You are a helpful AI agent, and you answer questions in a friendly way.")
    assistant_message = AssistantMessage(content="Hello, how can I help you?", source="Assistant")
    user_text_message = UserMessage(content="Hello, I am from Seattle.", source="Adam")
    user_mm_message = UserMessage(
        content=[
            "Here is a postcard from Seattle:",
            Image.from_base64(
                "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4z8AAAAMBAQDJ/pLvAAAAAElFTkSuQmCC"
            ),
        ],
        source="Adam",
    )

    # Default conversion
    oai_sys = to_oai_type(sys_message)[0]
    oai_asst = to_oai_type(assistant_message)[0]
    oai_text = to_oai_type(user_text_message)[0]
    oai_mm = to_oai_type(user_mm_message)[0]

    converted_sys = to_oai_type(sys_message, prepend_name=True)[0]
    converted_asst = to_oai_type(assistant_message, prepend_name=True)[0]
    converted_text = to_oai_type(user_text_message, prepend_name=True)[0]
    converted_mm = to_oai_type(user_mm_message, prepend_name=True)[0]

    # Invariants
    assert "content" in oai_sys
    assert "content" in oai_asst
    assert "content" in oai_text
    assert "content" in oai_mm
    assert "content" in converted_sys
    assert "content" in converted_asst
    assert "content" in converted_text
    assert "content" in converted_mm
    assert oai_sys["role"] == converted_sys["role"]
    assert oai_sys["content"] == converted_sys["content"]
    assert oai_asst["role"] == converted_asst["role"]
    assert oai_asst["content"] == converted_asst["content"]
    assert oai_text["role"] == converted_text["role"]
    assert oai_mm["role"] == converted_mm["role"]
    assert isinstance(oai_mm["content"], list)
    assert isinstance(converted_mm["content"], list)
    assert len(oai_mm["content"]) == len(converted_mm["content"])
    assert "text" in converted_mm["content"][0]
    assert "text" in oai_mm["content"][0]

    # Name prepended
    assert str(converted_text["content"]) == "Adam said:\n" + str(oai_text["content"])
    assert str(converted_mm["content"][0]["text"]) == "Adam said:\n" + str(oai_mm["content"][0]["text"])


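# For reference, what test_add_name_prefixes verifies: with prepend_name=True the plain user
# message content "Hello, I am from Seattle." becomes "Adam said:\nHello, I am from Seattle.",
# the text part of the multimodal user message gets the same "Adam said:\n" prefix, and the
# system and assistant message contents are left unchanged.
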
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model",
    [
        "gpt-4o-mini",
        "gemini-1.5-flash",
        "claude-3-5-haiku-20241022",
    ],
)
async def test_multiple_system_message(model: str, openai_client: OpenAIChatCompletionClient) -> None:
    """Test multiple system messages in a single request."""

    # Test multiple system messages
    messages: List[LLMMessage] = [
        SystemMessage(content="When you say anything Start with 'FOO'"),
        SystemMessage(content="When you say anything End with 'BAR'"),
        UserMessage(content="Just say '.'", source="user"),
    ]

    result = await openai_client.create(messages=messages)
    result_content = result.content
    assert isinstance(result_content, str)
    result_content = result_content.strip()
    assert result_content[:3] == "FOO"
    assert result_content[-3:] == "BAR"


@pytest.mark.asyncio
async def test_system_message_merge_for_gemini_models() -> None:
    """Tests that system messages are merged correctly for Gemini models."""
    # Create a mock client
    mock_client = MagicMock()
    client = BaseOpenAIChatCompletionClient(
        client=mock_client,
        create_args={"model": "gemini-1.5-flash"},
        model_info={
            "vision": False,
            "function_calling": False,
            "json_output": False,
            "family": "unknown",
            "structured_output": False,
        },
    )

    # Create two system messages
    messages: List[LLMMessage] = [
        SystemMessage(content="I am system message 1"),
        SystemMessage(content="I am system message 2"),
        UserMessage(content="Hello", source="user"),
    ]

    # Process the messages
    # pylint: disable=protected-access
    # The method is protected, but we need to test it
    create_params = client._process_create_args(  # pyright: ignore[reportPrivateUsage]
        messages=messages,
        tools=[],
        json_output=None,
        extra_create_args={},
    )

    # Extract the actual messages from the result
    oai_messages = create_params.messages

    # Check that there is only one system message and it contains the merged content
    system_messages = [msg for msg in oai_messages if msg["role"] == "system"]
    assert len(system_messages) == 1
    assert system_messages[0]["content"] == "I am system message 1\nI am system message 2"

    # Check that the user message is preserved
    user_messages = [msg for msg in oai_messages if msg["role"] == "user"]
    assert len(user_messages) == 1
    assert user_messages[0]["content"] == "Hello"


@pytest.mark.asyncio
async def test_system_message_merge_with_non_continuous_messages() -> None:
    """Tests that an error is raised when non-continuous system messages are provided."""
    # Create a mock client
    mock_client = MagicMock()
    client = BaseOpenAIChatCompletionClient(
        client=mock_client,
        create_args={"model": "gemini-1.5-flash"},
        model_info={
            "vision": False,
            "function_calling": False,
            "json_output": False,
            "family": "unknown",
            "structured_output": False,
        },
    )

    # Create non-continuous system messages
    messages: List[LLMMessage] = [
        SystemMessage(content="I am system message 1"),
        UserMessage(content="Hello", source="user"),
        SystemMessage(content="I am system message 2"),
    ]

    # Process should raise ValueError
    with pytest.raises(ValueError, match="Multiple and Not continuous system messages are not supported"):
        # pylint: disable=protected-access
        # The method is protected, but we need to test it
        client._process_create_args(  # pyright: ignore[reportPrivateUsage]
            messages=messages,
            tools=[],
            json_output=None,
            extra_create_args={},
        )


@pytest.mark.asyncio
async def test_system_message_not_merged_for_non_gemini_models() -> None:
    """Tests that system messages aren't modified for non-Gemini models."""
    # Create a mock client
    mock_client = MagicMock()
    client = BaseOpenAIChatCompletionClient(
        client=mock_client,
        create_args={"model": "gpt-4o"},
        model_info={
            "vision": False,
            "function_calling": False,
            "json_output": False,
            "family": "unknown",
            "structured_output": False,
        },
    )

    # Create two system messages
    messages: List[LLMMessage] = [
        SystemMessage(content="I am system message 1"),
        SystemMessage(content="I am system message 2"),
        UserMessage(content="Hello", source="user"),
    ]

    # Process the messages
    # pylint: disable=protected-access
    # The method is protected, but we need to test it
    create_params = client._process_create_args(  # pyright: ignore[reportPrivateUsage]
        messages=messages,
        tools=[],
        json_output=None,
        extra_create_args={},
    )

    # Extract the actual messages from the result
    oai_messages = create_params.messages

    # Check that there are two system messages preserved
    system_messages = [msg for msg in oai_messages if msg["role"] == "system"]
    assert len(system_messages) == 2
    assert system_messages[0]["content"] == "I am system message 1"
    assert system_messages[1]["content"] == "I am system message 2"


@pytest.mark.asyncio
async def test_no_system_messages_for_gemini_model() -> None:
    """Tests behavior when no system messages are provided to a Gemini model."""
    # Create a mock client
    mock_client = MagicMock()
    client = BaseOpenAIChatCompletionClient(
        client=mock_client,
        create_args={"model": "gemini-1.5-flash"},
        model_info={
            "vision": False,
            "function_calling": False,
            "json_output": False,
            "family": "unknown",
            "structured_output": False,
        },
    )

    # Create messages with no system message
    messages: List[LLMMessage] = [
        UserMessage(content="Hello", source="user"),
        AssistantMessage(content="Hi there", source="assistant"),
    ]

    # Process the messages
    # pylint: disable=protected-access
    # The method is protected, but we need to test it
    create_params = client._process_create_args(  # pyright: ignore[reportPrivateUsage]
        messages=messages,
        tools=[],
        json_output=None,
        extra_create_args={},
    )

    # Extract the actual messages from the result
    oai_messages = create_params.messages

    # Check that there are no system messages
    system_messages = [msg for msg in oai_messages if msg["role"] == "system"]
    assert len(system_messages) == 0

    # Check that other messages are preserved
    user_messages = [msg for msg in oai_messages if msg["role"] == "user"]
    assistant_messages = [msg for msg in oai_messages if msg["role"] == "assistant"]
    assert len(user_messages) == 1
    assert len(assistant_messages) == 1


@pytest.mark.asyncio
async def test_single_system_message_for_gemini_model() -> None:
    """Tests that a single system message is preserved for Gemini models."""
    # Create a mock client
    mock_client = MagicMock()
    client = BaseOpenAIChatCompletionClient(
        client=mock_client,
        create_args={"model": "gemini-1.5-flash"},
        model_info={
            "vision": False,
            "function_calling": False,
            "json_output": False,
            "family": "unknown",
            "structured_output": False,
        },
    )

    # Create messages with a single system message
    messages: List[LLMMessage] = [
        SystemMessage(content="I am the only system message"),
        UserMessage(content="Hello", source="user"),
    ]

    # Process the messages
    # pylint: disable=protected-access
    # The method is protected, but we need to test it
    create_params = client._process_create_args(  # pyright: ignore[reportPrivateUsage]
        messages=messages,
        tools=[],
        json_output=None,
        extra_create_args={},
    )

    # Extract the actual messages from the result
    oai_messages = create_params.messages

    # Check that there is exactly one system message with the correct content
    system_messages = [msg for msg in oai_messages if msg["role"] == "system"]
    assert len(system_messages) == 1
    assert system_messages[0]["content"] == "I am the only system message"


def noop(input: str) -> str:
    return "done"


@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["gemini-1.5-flash"])
async def test_empty_assistant_content_with_gemini(model: str, openai_client: OpenAIChatCompletionClient) -> None:
    # Test tool calling
    tool = FunctionTool(noop, name="noop", description="No-op tool")
    messages: List[LLMMessage] = [UserMessage(content="Call noop", source="user")]
    result = await openai_client.create(messages=messages, tools=[tool])
    assert isinstance(result.content, list)
    tool_call = result.content[0]
    assert isinstance(tool_call, FunctionCall)

    # Reply with an empty string as the thought (which becomes the assistant content).
    messages.append(AssistantMessage(content=result.content, thought="", source="assistant"))
    messages.append(
        FunctionExecutionResultMessage(
            content=[
                FunctionExecutionResult(
                    content="done",
                    call_id=tool_call.id,
                    is_error=False,
                    name=tool_call.name,
                )
            ]
        )
    )

    # This will crash if _set_empty_to_whitespace is not applied to "thought"
    result = await openai_client.create(messages=messages)
    assert isinstance(result.content, str)
    assert result.content.strip() != "" or result.content == " "


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model",
    [
        "gpt-4o-mini",
        "gemini-1.5-flash",
        "claude-3-5-haiku-20241022",
    ],
)
async def test_empty_assistant_content_string_with_some_model(
    model: str, openai_client: OpenAIChatCompletionClient
) -> None:
    # Message sequence where the final user message has empty content.
    messages: list[LLMMessage] = [
        UserMessage(content="Say something", source="user"),
        AssistantMessage(content="test", source="assistant"),
        UserMessage(content="", source="user"),
    ]

    # This will crash if _set_empty_to_whitespace is not applied to "content"
    result = await openai_client.create(messages=messages)
    assert isinstance(result.content, str)


def test_openai_model_registry_find_well() -> None:
    model = "gpt-4o"
    client1 = OpenAIChatCompletionClient(model=model, api_key="test")
    client2 = OpenAIChatCompletionClient(
        model=model,
        model_info={
            "vision": False,
            "function_calling": False,
            "json_output": False,
            "structured_output": False,
            "family": ModelFamily.UNKNOWN,
        },
        api_key="test",
    )

    def get_registered_transformer(client: OpenAIChatCompletionClient) -> TransformerMap:
        model_name = client._create_args["model"]  # pyright: ignore[reportPrivateUsage]
        model_family = client.model_info["family"]
        return get_transformer("openai", model_name, model_family)

    assert get_registered_transformer(client1) == get_registered_transformer(client2)


def test_openai_model_registry_find_wrong() -> None:
    with pytest.raises(ValueError, match="No transformer found for model family"):
        get_transformer("openai", "gpt-7", "foobar")


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model",
    [
        "gpt-4o-mini",
    ],
)
async def test_openai_model_unknown_message_type(model: str, openai_client: OpenAIChatCompletionClient) -> None:
    class WrongMessage:
        content = "foo"
        source = "bar"

    messages: List[WrongMessage] = [WrongMessage()]
    with pytest.raises(ValueError, match="Unknown message type"):
        await openai_client.create(messages=messages)  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model",
    [
        "claude-3-5-haiku-20241022",
    ],
)
async def test_claude_trailing_whitespace_at_last_assistant_content(
    model: str, openai_client: OpenAIChatCompletionClient
) -> None:
    messages: list[LLMMessage] = [
        UserMessage(content="foo", source="user"),
        UserMessage(content="bar", source="user"),
        AssistantMessage(content="foobar ", source="assistant"),
    ]

    result = await openai_client.create(messages=messages)
    assert isinstance(result.content, str)


def test_rstrip_trailing_whitespace_at_last_assistant_content() -> None:
    messages: list[LLMMessage] = [
        UserMessage(content="foo", source="user"),
        UserMessage(content="bar", source="user"),
        AssistantMessage(content="foobar ", source="assistant"),
    ]

    # This will crash if _rstrip_last_assistant_message is not applied to "content"
    dummy_client = OpenAIChatCompletionClient(model="claude-3-5-haiku-20241022", api_key="dummy-key")
    result = dummy_client._rstrip_last_assistant_message(messages)  # pyright: ignore[reportPrivateUsage]

    assert isinstance(result[-1].content, str)
    assert result[-1].content == "foobar"


# TODO: add integration tests for Azure OpenAI using AAD token.