mirror of https://github.com/microsoft/autogen.git
FIX: Anthropic and Gemini could take multiple system message (#6118)
Anthropic SDK could not takes multiple system messages. However some autogen Agent(e.g. SocietyOfMindAgent) makes multiple system messages. And... Gemini with OpenaiSDK do not take error. However is not working mulitple system messages. (Just last one is working) So, I simple change of, "merge multiple system message" at these cases. ## Related issue number Closes #6116 Closes #6117 --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
c24eba6ae1
commit
0cd3ff46fa
|
@ -408,6 +408,35 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
|
|||
self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
|
||||
self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
|
||||
|
||||
def _merge_system_messages(self, messages: Sequence[LLMMessage]) -> Sequence[LLMMessage]:
|
||||
"""
|
||||
Merge continuous system messages into a single message.
|
||||
"""
|
||||
_messages: List[LLMMessage] = []
|
||||
system_message_content = ""
|
||||
_first_system_message_idx = -1
|
||||
_last_system_message_idx = -1
|
||||
# Index of the first system message for adding the merged system message at the correct position
|
||||
for idx, message in enumerate(messages):
|
||||
if isinstance(message, SystemMessage):
|
||||
if _first_system_message_idx == -1:
|
||||
_first_system_message_idx = idx
|
||||
elif _last_system_message_idx + 1 != idx:
|
||||
# That case, system message is not continuous
|
||||
# Merge system messages only contiues system messages
|
||||
raise ValueError("Multiple and Not continuous system messages are not supported")
|
||||
system_message_content += message.content + "\n"
|
||||
_last_system_message_idx = idx
|
||||
else:
|
||||
_messages.append(message)
|
||||
system_message_content = system_message_content.rstrip()
|
||||
if system_message_content != "":
|
||||
system_message = SystemMessage(content=system_message_content)
|
||||
_messages.insert(_first_system_message_idx, system_message)
|
||||
messages = _messages
|
||||
|
||||
return messages
|
||||
|
||||
async def create(
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
|
@ -442,9 +471,12 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
|
|||
system_message = None
|
||||
anthropic_messages: List[MessageParam] = []
|
||||
|
||||
# Merge continuous system messages into a single message
|
||||
messages = self._merge_system_messages(messages)
|
||||
for message in messages:
|
||||
if isinstance(message, SystemMessage):
|
||||
if system_message is not None:
|
||||
# if that case, system message is must only one
|
||||
raise ValueError("Multiple system messages are not supported")
|
||||
system_message = to_anthropic_type(message)
|
||||
else:
|
||||
|
@ -604,9 +636,12 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
|
|||
system_message = None
|
||||
anthropic_messages: List[MessageParam] = []
|
||||
|
||||
# Merge continuous system messages into a single message
|
||||
messages = self._merge_system_messages(messages)
|
||||
for message in messages:
|
||||
if isinstance(message, SystemMessage):
|
||||
if system_message is not None:
|
||||
# if that case, system message is must only one
|
||||
raise ValueError("Multiple system messages are not supported")
|
||||
system_message = to_anthropic_type(message)
|
||||
else:
|
||||
|
|
|
@ -612,6 +612,32 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
|||
if self.model_info["json_output"] is False and json_output is True:
|
||||
raise ValueError("Model does not support JSON output.")
|
||||
|
||||
if create_args.get("model", "unknown").startswith("gemini-"):
|
||||
# Gemini models accept only one system message(else, it will read only the last one)
|
||||
# So, merge system messages into one
|
||||
system_message_content = ""
|
||||
_messages: List[LLMMessage] = []
|
||||
_first_system_message_idx = -1
|
||||
_last_system_message_idx = -1
|
||||
# Index of the first system message for adding the merged system message at the correct position
|
||||
for idx, message in enumerate(messages):
|
||||
if isinstance(message, SystemMessage):
|
||||
if _first_system_message_idx == -1:
|
||||
_first_system_message_idx = idx
|
||||
elif _last_system_message_idx + 1 != idx:
|
||||
# That case, system message is not continuous
|
||||
# Merge system messages only contiues system messages
|
||||
raise ValueError("Multiple and Not continuous system messages are not supported")
|
||||
system_message_content += message.content + "\n"
|
||||
_last_system_message_idx = idx
|
||||
else:
|
||||
_messages.append(message)
|
||||
system_message_content = system_message_content.rstrip()
|
||||
if system_message_content != "":
|
||||
system_message = SystemMessage(content=system_message_content)
|
||||
_messages.insert(_first_system_message_idx, system_message)
|
||||
messages = _messages
|
||||
|
||||
oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
|
||||
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
|
||||
|
||||
|
|
|
@ -334,3 +334,229 @@ async def test_anthropic_serialization() -> None:
|
|||
loaded_model_client = AnthropicChatCompletionClient.load_component(model_client_config)
|
||||
assert loaded_model_client is not None
|
||||
assert isinstance(loaded_model_client, AnthropicChatCompletionClient)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anthropic_muliple_system_message() -> None:
|
||||
"""Test multiple system messages in a single request."""
|
||||
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if not api_key:
|
||||
pytest.skip("ANTHROPIC_API_KEY not found in environment variables")
|
||||
|
||||
client = AnthropicChatCompletionClient(
|
||||
model="claude-3-haiku-20240307",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
# Test multiple system messages
|
||||
messages: List[LLMMessage] = [
|
||||
SystemMessage(content="When you say anything Start with 'FOO'"),
|
||||
SystemMessage(content="When you say anything End with 'BAR'"),
|
||||
UserMessage(content="Just say '.'", source="user"),
|
||||
]
|
||||
|
||||
result = await client.create(messages=messages)
|
||||
result_content = result.content
|
||||
assert isinstance(result_content, str)
|
||||
result_content = result_content.strip()
|
||||
assert result_content[:3] == "FOO"
|
||||
assert result_content[-3:] == "BAR"
|
||||
|
||||
|
||||
def test_merge_continuous_system_messages() -> None:
|
||||
"""Tests merging of continuous system messages."""
|
||||
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
|
||||
|
||||
messages: List[LLMMessage] = [
|
||||
SystemMessage(content="System instruction 1"),
|
||||
SystemMessage(content="System instruction 2"),
|
||||
UserMessage(content="User question", source="user"),
|
||||
]
|
||||
|
||||
merged_messages = client._merge_system_messages(messages) # pyright: ignore[reportPrivateUsage]
|
||||
# The method is protected, but we need to test it
|
||||
|
||||
# 병합 후 2개 메시지만 남아야 함 (시스템 1개, 사용자 1개)
|
||||
assert len(merged_messages) == 2
|
||||
|
||||
# 첫 번째 메시지는 병합된 시스템 메시지여야 함
|
||||
assert isinstance(merged_messages[0], SystemMessage)
|
||||
assert merged_messages[0].content == "System instruction 1\nSystem instruction 2"
|
||||
|
||||
# 두 번째 메시지는 사용자 메시지여야 함
|
||||
assert isinstance(merged_messages[1], UserMessage)
|
||||
assert merged_messages[1].content == "User question"
|
||||
|
||||
|
||||
def test_merge_single_system_message() -> None:
|
||||
"""Tests that a single system message remains unchanged."""
|
||||
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
|
||||
|
||||
messages: List[LLMMessage] = [
|
||||
SystemMessage(content="Single system instruction"),
|
||||
UserMessage(content="User question", source="user"),
|
||||
]
|
||||
|
||||
merged_messages = client._merge_system_messages(messages) # pyright: ignore[reportPrivateUsage]
|
||||
# The method is protected, but we need to test it
|
||||
|
||||
# 메시지 개수는 변하지 않아야 함
|
||||
assert len(merged_messages) == 2
|
||||
|
||||
# 시스템 메시지 내용은 변하지 않아야 함
|
||||
assert isinstance(merged_messages[0], SystemMessage)
|
||||
assert merged_messages[0].content == "Single system instruction"
|
||||
|
||||
|
||||
def test_merge_no_system_messages() -> None:
|
||||
"""Tests behavior when there are no system messages."""
|
||||
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
|
||||
|
||||
messages: List[LLMMessage] = [
|
||||
UserMessage(content="User question without system", source="user"),
|
||||
]
|
||||
|
||||
merged_messages = client._merge_system_messages(messages) # pyright: ignore[reportPrivateUsage]
|
||||
# The method is protected, but we need to test it
|
||||
|
||||
# 메시지 개수는 변하지 않아야 함
|
||||
assert len(merged_messages) == 1
|
||||
|
||||
# 유일한 메시지는 사용자 메시지여야 함
|
||||
assert isinstance(merged_messages[0], UserMessage)
|
||||
assert merged_messages[0].content == "User question without system"
|
||||
|
||||
|
||||
def test_merge_non_continuous_system_messages() -> None:
|
||||
"""Tests that an error is raised for non-continuous system messages."""
|
||||
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
|
||||
|
||||
messages: List[LLMMessage] = [
|
||||
SystemMessage(content="First group 1"),
|
||||
SystemMessage(content="First group 2"),
|
||||
UserMessage(content="Middle user message", source="user"),
|
||||
SystemMessage(content="Second group 1"),
|
||||
SystemMessage(content="Second group 2"),
|
||||
]
|
||||
|
||||
# 연속적이지 않은 시스템 메시지는 에러를 발생시켜야 함
|
||||
with pytest.raises(ValueError, match="Multiple and Not continuous system messages are not supported"):
|
||||
client._merge_system_messages(messages) # pyright: ignore[reportPrivateUsage]
|
||||
# The method is protected, but we need to test it
|
||||
|
||||
|
||||
def test_merge_system_messages_empty() -> None:
|
||||
"""Tests that empty message list is handled properly."""
|
||||
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
|
||||
|
||||
merged_messages = client._merge_system_messages([]) # pyright: ignore[reportPrivateUsage]
|
||||
# The method is protected, but we need to test it
|
||||
assert len(merged_messages) == 0
|
||||
|
||||
|
||||
def test_merge_system_messages_with_special_characters() -> None:
|
||||
"""Tests system message merging with special characters and formatting."""
|
||||
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
|
||||
|
||||
messages: List[LLMMessage] = [
|
||||
SystemMessage(content="Line 1\nWith newline"),
|
||||
SystemMessage(content="Line 2 with *formatting*"),
|
||||
SystemMessage(content="Line 3 with `code`"),
|
||||
UserMessage(content="Question", source="user"),
|
||||
]
|
||||
|
||||
merged_messages = client._merge_system_messages(messages) # pyright: ignore[reportPrivateUsage]
|
||||
# The method is protected, but we need to test it
|
||||
assert len(merged_messages) == 2
|
||||
|
||||
system_message = merged_messages[0]
|
||||
assert isinstance(system_message, SystemMessage)
|
||||
assert system_message.content == "Line 1\nWith newline\nLine 2 with *formatting*\nLine 3 with `code`"
|
||||
|
||||
|
||||
def test_merge_system_messages_with_whitespace() -> None:
|
||||
"""Tests system message merging with extra whitespace."""
|
||||
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
|
||||
|
||||
messages: List[LLMMessage] = [
|
||||
SystemMessage(content=" Message with leading spaces "),
|
||||
SystemMessage(content="\nMessage with leading newline\n"),
|
||||
UserMessage(content="Question", source="user"),
|
||||
]
|
||||
|
||||
merged_messages = client._merge_system_messages(messages) # pyright: ignore[reportPrivateUsage]
|
||||
# The method is protected, but we need to test it
|
||||
assert len(merged_messages) == 2
|
||||
|
||||
system_message = merged_messages[0]
|
||||
assert isinstance(system_message, SystemMessage)
|
||||
# strip()은 내부에서 발생하지 않지만 최종 결과에서는 줄바꿈이 유지됨
|
||||
assert system_message.content == " Message with leading spaces \n\nMessage with leading newline"
|
||||
|
||||
|
||||
def test_merge_system_messages_message_order() -> None:
|
||||
"""Tests that message order is preserved after merging."""
|
||||
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
|
||||
|
||||
messages: List[LLMMessage] = [
|
||||
UserMessage(content="Question 1", source="user"),
|
||||
SystemMessage(content="Instruction 1"),
|
||||
SystemMessage(content="Instruction 2"),
|
||||
UserMessage(content="Question 2", source="user"),
|
||||
AssistantMessage(content="Answer", source="assistant"),
|
||||
]
|
||||
|
||||
merged_messages = client._merge_system_messages(messages) # pyright: ignore[reportPrivateUsage]
|
||||
# The method is protected, but we need to test it
|
||||
assert len(merged_messages) == 4
|
||||
|
||||
# 첫 번째 메시지는 UserMessage여야 함
|
||||
assert isinstance(merged_messages[0], UserMessage)
|
||||
assert merged_messages[0].content == "Question 1"
|
||||
|
||||
# 두 번째 메시지는 병합된 SystemMessage여야 함
|
||||
assert isinstance(merged_messages[1], SystemMessage)
|
||||
assert merged_messages[1].content == "Instruction 1\nInstruction 2"
|
||||
|
||||
# 나머지 메시지는 순서대로 유지되어야 함
|
||||
assert isinstance(merged_messages[2], UserMessage)
|
||||
assert merged_messages[2].content == "Question 2"
|
||||
assert isinstance(merged_messages[3], AssistantMessage)
|
||||
assert merged_messages[3].content == "Answer"
|
||||
|
||||
|
||||
def test_merge_system_messages_multiple_groups() -> None:
|
||||
"""Tests that multiple separate groups of system messages raise an error."""
|
||||
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
|
||||
|
||||
# 연속되지 않은 시스템 메시지: 사용자 메시지로 분리된 두 그룹
|
||||
messages: List[LLMMessage] = [
|
||||
SystemMessage(content="Group 1 - message 1"),
|
||||
UserMessage(content="Interrupting user message", source="user"),
|
||||
SystemMessage(content="Group 2 - message 1"),
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="Multiple and Not continuous system messages are not supported"):
|
||||
client._merge_system_messages(messages) # pyright: ignore[reportPrivateUsage]
|
||||
# The method is protected, but we need to test it
|
||||
|
||||
|
||||
def test_merge_system_messages_no_duplicates() -> None:
|
||||
"""Tests that identical system messages are still merged properly."""
|
||||
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
|
||||
|
||||
messages: List[LLMMessage] = [
|
||||
SystemMessage(content="Same instruction"),
|
||||
SystemMessage(content="Same instruction"), # 중복된 내용
|
||||
UserMessage(content="Question", source="user"),
|
||||
]
|
||||
|
||||
merged_messages = client._merge_system_messages(messages) # pyright: ignore[reportPrivateUsage]
|
||||
# The method is protected, but we need to test it
|
||||
assert len(merged_messages) == 2
|
||||
|
||||
# 첫 번째 메시지는 병합된 시스템 메시지여야 함
|
||||
assert isinstance(merged_messages[0], SystemMessage)
|
||||
# 중복된 내용도 그대로 병합됨
|
||||
assert merged_messages[0].content == "Same instruction\nSame instruction"
|
||||
|
|
|
@ -24,6 +24,7 @@ from autogen_core.tools import BaseTool, FunctionTool
|
|||
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient
|
||||
from autogen_ext.models.openai._model_info import resolve_model
|
||||
from autogen_ext.models.openai._openai_client import (
|
||||
BaseOpenAIChatCompletionClient,
|
||||
calculate_vision_tokens,
|
||||
convert_tools,
|
||||
to_oai_type,
|
||||
|
@ -2058,4 +2059,248 @@ async def test_add_name_prefixes(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
assert str(converted_mm["content"][0]["text"]) == "Adam said:\n" + str(oai_mm["content"][0]["text"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"gpt-4o-mini",
|
||||
"gemini-1.5-flash",
|
||||
# TODO: Add anthropic models when available.
|
||||
],
|
||||
)
|
||||
async def test_muliple_system_message(model: str, openai_client: OpenAIChatCompletionClient) -> None:
|
||||
"""Test multiple system messages in a single request."""
|
||||
|
||||
# Test multiple system messages
|
||||
messages: List[LLMMessage] = [
|
||||
SystemMessage(content="When you say anything Start with 'FOO'"),
|
||||
SystemMessage(content="When you say anything End with 'BAR'"),
|
||||
UserMessage(content="Just say '.'", source="user"),
|
||||
]
|
||||
|
||||
result = await openai_client.create(messages=messages)
|
||||
result_content = result.content
|
||||
assert isinstance(result_content, str)
|
||||
result_content = result_content.strip()
|
||||
assert result_content[:3] == "FOO"
|
||||
assert result_content[-3:] == "BAR"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_message_merge_for_gemini_models() -> None:
|
||||
"""Tests that system messages are merged correctly for Gemini models."""
|
||||
# Create a mock client
|
||||
mock_client = MagicMock()
|
||||
client = BaseOpenAIChatCompletionClient(
|
||||
client=mock_client,
|
||||
create_args={"model": "gemini-1.5-flash"},
|
||||
model_info={
|
||||
"vision": False,
|
||||
"function_calling": False,
|
||||
"json_output": False,
|
||||
"family": "unknown",
|
||||
"structured_output": False,
|
||||
},
|
||||
)
|
||||
|
||||
# Create two system messages
|
||||
messages: List[LLMMessage] = [
|
||||
SystemMessage(content="I am system message 1"),
|
||||
SystemMessage(content="I am system message 2"),
|
||||
UserMessage(content="Hello", source="user"),
|
||||
]
|
||||
|
||||
# Process the messages
|
||||
# pylint: disable=protected-access
|
||||
# The method is protected, but we need to test it
|
||||
create_params = client._process_create_args( # pyright: ignore[reportPrivateUsage]
|
||||
messages=messages,
|
||||
tools=[],
|
||||
json_output=None,
|
||||
extra_create_args={},
|
||||
)
|
||||
|
||||
# Extract the actual messages from the result
|
||||
oai_messages = create_params.messages
|
||||
|
||||
# Check that there is only one system message and it contains the merged content
|
||||
system_messages = [msg for msg in oai_messages if msg["role"] == "system"]
|
||||
assert len(system_messages) == 1
|
||||
assert system_messages[0]["content"] == "I am system message 1\nI am system message 2"
|
||||
|
||||
# Check that the user message is preserved
|
||||
user_messages = [msg for msg in oai_messages if msg["role"] == "user"]
|
||||
assert len(user_messages) == 1
|
||||
assert user_messages[0]["content"] == "Hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_message_merge_with_non_continuous_messages() -> None:
|
||||
"""Tests that an error is raised when non-continuous system messages are provided."""
|
||||
# Create a mock client
|
||||
mock_client = MagicMock()
|
||||
client = BaseOpenAIChatCompletionClient(
|
||||
client=mock_client,
|
||||
create_args={"model": "gemini-1.5-flash"},
|
||||
model_info={
|
||||
"vision": False,
|
||||
"function_calling": False,
|
||||
"json_output": False,
|
||||
"family": "unknown",
|
||||
"structured_output": False,
|
||||
},
|
||||
)
|
||||
|
||||
# Create non-continuous system messages
|
||||
messages: List[LLMMessage] = [
|
||||
SystemMessage(content="I am system message 1"),
|
||||
UserMessage(content="Hello", source="user"),
|
||||
SystemMessage(content="I am system message 2"),
|
||||
]
|
||||
|
||||
# Process should raise ValueError
|
||||
with pytest.raises(ValueError, match="Multiple and Not continuous system messages are not supported"):
|
||||
# pylint: disable=protected-access
|
||||
# The method is protected, but we need to test it
|
||||
client._process_create_args( # pyright: ignore[reportPrivateUsage]
|
||||
messages=messages,
|
||||
tools=[],
|
||||
json_output=None,
|
||||
extra_create_args={},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_message_not_merged_for_non_gemini_models() -> None:
|
||||
"""Tests that system messages aren't modified for non-Gemini models."""
|
||||
# Create a mock client
|
||||
mock_client = MagicMock()
|
||||
client = BaseOpenAIChatCompletionClient(
|
||||
client=mock_client,
|
||||
create_args={"model": "gpt-4o"},
|
||||
model_info={
|
||||
"vision": False,
|
||||
"function_calling": False,
|
||||
"json_output": False,
|
||||
"family": "unknown",
|
||||
"structured_output": False,
|
||||
},
|
||||
)
|
||||
|
||||
# Create two system messages
|
||||
messages: List[LLMMessage] = [
|
||||
SystemMessage(content="I am system message 1"),
|
||||
SystemMessage(content="I am system message 2"),
|
||||
UserMessage(content="Hello", source="user"),
|
||||
]
|
||||
|
||||
# Process the messages
|
||||
# pylint: disable=protected-access
|
||||
# The method is protected, but we need to test it
|
||||
create_params = client._process_create_args( # pyright: ignore[reportPrivateUsage]
|
||||
messages=messages,
|
||||
tools=[],
|
||||
json_output=None,
|
||||
extra_create_args={},
|
||||
)
|
||||
|
||||
# Extract the actual messages from the result
|
||||
oai_messages = create_params.messages
|
||||
|
||||
# Check that there are two system messages preserved
|
||||
system_messages = [msg for msg in oai_messages if msg["role"] == "system"]
|
||||
assert len(system_messages) == 2
|
||||
assert system_messages[0]["content"] == "I am system message 1"
|
||||
assert system_messages[1]["content"] == "I am system message 2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_system_messages_for_gemini_model() -> None:
|
||||
"""Tests behavior when no system messages are provided to a Gemini model."""
|
||||
# Create a mock client
|
||||
mock_client = MagicMock()
|
||||
client = BaseOpenAIChatCompletionClient(
|
||||
client=mock_client,
|
||||
create_args={"model": "gemini-1.5-flash"},
|
||||
model_info={
|
||||
"vision": False,
|
||||
"function_calling": False,
|
||||
"json_output": False,
|
||||
"family": "unknown",
|
||||
"structured_output": False,
|
||||
},
|
||||
)
|
||||
|
||||
# Create messages with no system message
|
||||
messages: List[LLMMessage] = [
|
||||
UserMessage(content="Hello", source="user"),
|
||||
AssistantMessage(content="Hi there", source="assistant"),
|
||||
]
|
||||
|
||||
# Process the messages
|
||||
# pylint: disable=protected-access
|
||||
# The method is protected, but we need to test it
|
||||
create_params = client._process_create_args( # pyright: ignore[reportPrivateUsage]
|
||||
messages=messages,
|
||||
tools=[],
|
||||
json_output=None,
|
||||
extra_create_args={},
|
||||
)
|
||||
|
||||
# Extract the actual messages from the result
|
||||
oai_messages = create_params.messages
|
||||
|
||||
# Check that there are no system messages
|
||||
system_messages = [msg for msg in oai_messages if msg["role"] == "system"]
|
||||
assert len(system_messages) == 0
|
||||
|
||||
# Check that other messages are preserved
|
||||
user_messages = [msg for msg in oai_messages if msg["role"] == "user"]
|
||||
assistant_messages = [msg for msg in oai_messages if msg["role"] == "assistant"]
|
||||
assert len(user_messages) == 1
|
||||
assert len(assistant_messages) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_system_message_for_gemini_model() -> None:
|
||||
"""Tests that a single system message is preserved for Gemini models."""
|
||||
# Create a mock client
|
||||
mock_client = MagicMock()
|
||||
client = BaseOpenAIChatCompletionClient(
|
||||
client=mock_client,
|
||||
create_args={"model": "gemini-1.5-flash"},
|
||||
model_info={
|
||||
"vision": False,
|
||||
"function_calling": False,
|
||||
"json_output": False,
|
||||
"family": "unknown",
|
||||
"structured_output": False,
|
||||
},
|
||||
)
|
||||
|
||||
# Create messages with a single system message
|
||||
messages: List[LLMMessage] = [
|
||||
SystemMessage(content="I am the only system message"),
|
||||
UserMessage(content="Hello", source="user"),
|
||||
]
|
||||
|
||||
# Process the messages
|
||||
# pylint: disable=protected-access
|
||||
# The method is protected, but we need to test it
|
||||
create_params = client._process_create_args( # pyright: ignore[reportPrivateUsage]
|
||||
messages=messages,
|
||||
tools=[],
|
||||
json_output=None,
|
||||
extra_create_args={},
|
||||
)
|
||||
|
||||
# Extract the actual messages from the result
|
||||
oai_messages = create_params.messages
|
||||
|
||||
# Check that there is exactly one system message with the correct content
|
||||
system_messages = [msg for msg in oai_messages if msg["role"] == "system"]
|
||||
assert len(system_messages) == 1
|
||||
assert system_messages[0]["content"] == "I am the only system message"
|
||||
|
||||
|
||||
# TODO: add integration tests for Azure OpenAI using AAD token.
|
||||
|
|
Loading…
Reference in New Issue