FIX: Allow Anthropic and Gemini clients to take multiple system messages (#6118)

The Anthropic SDK could not take multiple system messages.
However, some autogen agents (e.g. SocietyOfMindAgent) produce multiple
system messages.

Gemini used through the OpenAI SDK does not raise an error, but multiple
system messages do not work correctly:
only the last one takes effect.

So this change simply merges multiple system messages in these cases.

## Related issue number
Closes #6116
Closes #6117


---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
EeS 2025-03-29 01:05:54 +09:00 committed by GitHub
parent c24eba6ae1
commit 0cd3ff46fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 532 additions and 0 deletions

View File

@ -408,6 +408,35 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
def _merge_system_messages(self, messages: Sequence[LLMMessage]) -> Sequence[LLMMessage]:
    """Collapse a contiguous run of SystemMessages into a single SystemMessage.

    The Anthropic API accepts only one system prompt per request, so
    consecutive SystemMessages are joined with newlines and re-inserted at
    the position of the first one. All other messages keep their order.

    Raises:
        ValueError: If SystemMessages appear in more than one contiguous run.
    """
    other_messages: List[LLMMessage] = []
    system_chunks: List[str] = []
    first_system_idx = -1  # where the merged system message will be re-inserted
    prev_system_idx = -1  # used to detect a gap between system messages
    for idx, msg in enumerate(messages):
        if not isinstance(msg, SystemMessage):
            other_messages.append(msg)
            continue
        if first_system_idx == -1:
            first_system_idx = idx
        elif prev_system_idx + 1 != idx:
            # A non-system message interrupted the run of system messages;
            # only one contiguous run is supported.
            raise ValueError("Multiple and Not continuous system messages are not supported")
        system_chunks.append(msg.content)
        prev_system_idx = idx

    merged_content = "\n".join(system_chunks).rstrip()
    if merged_content != "":
        # Every non-system message that precedes the run keeps its position,
        # so the original index of the first system message is still valid.
        other_messages.insert(first_system_idx, SystemMessage(content=merged_content))
    return other_messages
async def create(
self,
messages: Sequence[LLMMessage],
@ -442,9 +471,12 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
system_message = None
anthropic_messages: List[MessageParam] = []
# Merge continuous system messages into a single message
messages = self._merge_system_messages(messages)
for message in messages:
if isinstance(message, SystemMessage):
if system_message is not None:
# if that case, system message is must only one
raise ValueError("Multiple system messages are not supported")
system_message = to_anthropic_type(message)
else:
@ -604,9 +636,12 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
system_message = None
anthropic_messages: List[MessageParam] = []
# Merge continuous system messages into a single message
messages = self._merge_system_messages(messages)
for message in messages:
if isinstance(message, SystemMessage):
if system_message is not None:
# if that case, system message is must only one
raise ValueError("Multiple system messages are not supported")
system_message = to_anthropic_type(message)
else:

View File

@ -612,6 +612,32 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
if self.model_info["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output.")
if create_args.get("model", "unknown").startswith("gemini-"):
# Gemini models accept only one system message(else, it will read only the last one)
# So, merge system messages into one
system_message_content = ""
_messages: List[LLMMessage] = []
_first_system_message_idx = -1
_last_system_message_idx = -1
# Index of the first system message for adding the merged system message at the correct position
for idx, message in enumerate(messages):
if isinstance(message, SystemMessage):
if _first_system_message_idx == -1:
_first_system_message_idx = idx
elif _last_system_message_idx + 1 != idx:
# That case, system message is not continuous
# Merge system messages only contiues system messages
raise ValueError("Multiple and Not continuous system messages are not supported")
system_message_content += message.content + "\n"
_last_system_message_idx = idx
else:
_messages.append(message)
system_message_content = system_message_content.rstrip()
if system_message_content != "":
system_message = SystemMessage(content=system_message_content)
_messages.insert(_first_system_message_idx, system_message)
messages = _messages
oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
oai_messages = [item for sublist in oai_messages_nested for item in sublist]

View File

@ -334,3 +334,229 @@ async def test_anthropic_serialization() -> None:
loaded_model_client = AnthropicChatCompletionClient.load_component(model_client_config)
assert loaded_model_client is not None
assert isinstance(loaded_model_client, AnthropicChatCompletionClient)
@pytest.mark.asyncio
async def test_anthropic_muliple_system_message() -> None:
    """Test multiple system messages in a single request.

    Integration test: requires ANTHROPIC_API_KEY in the environment and is
    skipped otherwise.
    """
    api_key = os.getenv("ANTHROPIC_API_KEY")
    if not api_key:
        pytest.skip("ANTHROPIC_API_KEY not found in environment variables")

    client = AnthropicChatCompletionClient(
        model="claude-3-haiku-20240307",
        api_key=api_key,
    )

    # Test multiple system messages
    messages: List[LLMMessage] = [
        SystemMessage(content="When you say anything Start with 'FOO'"),
        SystemMessage(content="When you say anything End with 'BAR'"),
        UserMessage(content="Just say '.'", source="user"),
    ]
    result = await client.create(messages=messages)
    result_content = result.content
    assert isinstance(result_content, str)
    result_content = result_content.strip()
    # Both system instructions should have taken effect after the merge.
    assert result_content[:3] == "FOO"
    assert result_content[-3:] == "BAR"
def test_merge_continuous_system_messages() -> None:
    """Tests merging of continuous system messages."""
    client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
    messages: List[LLMMessage] = [
        SystemMessage(content="System instruction 1"),
        SystemMessage(content="System instruction 2"),
        UserMessage(content="User question", source="user"),
    ]
    merged_messages = client._merge_system_messages(messages)  # pyright: ignore[reportPrivateUsage]
    # The method is protected, but we need to test it

    # After merging, only two messages should remain (one system, one user).
    assert len(merged_messages) == 2
    # The first message should be the merged system message.
    assert isinstance(merged_messages[0], SystemMessage)
    assert merged_messages[0].content == "System instruction 1\nSystem instruction 2"
    # The second message should be the user message, unchanged.
    assert isinstance(merged_messages[1], UserMessage)
    assert merged_messages[1].content == "User question"
def test_merge_single_system_message() -> None:
    """Tests that a single system message remains unchanged."""
    client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
    messages: List[LLMMessage] = [
        SystemMessage(content="Single system instruction"),
        UserMessage(content="User question", source="user"),
    ]
    merged_messages = client._merge_system_messages(messages)  # pyright: ignore[reportPrivateUsage]
    # The method is protected, but we need to test it

    # The number of messages should not change.
    assert len(merged_messages) == 2
    # The system message content should be unchanged.
    assert isinstance(merged_messages[0], SystemMessage)
    assert merged_messages[0].content == "Single system instruction"
def test_merge_no_system_messages() -> None:
    """Tests behavior when there are no system messages."""
    client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
    messages: List[LLMMessage] = [
        UserMessage(content="User question without system", source="user"),
    ]
    merged_messages = client._merge_system_messages(messages)  # pyright: ignore[reportPrivateUsage]
    # The method is protected, but we need to test it

    # The number of messages should not change.
    assert len(merged_messages) == 1
    # The only message should be the user message.
    assert isinstance(merged_messages[0], UserMessage)
    assert merged_messages[0].content == "User question without system"
def test_merge_non_continuous_system_messages() -> None:
    """Tests that an error is raised for non-continuous system messages."""
    client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
    messages: List[LLMMessage] = [
        SystemMessage(content="First group 1"),
        SystemMessage(content="First group 2"),
        UserMessage(content="Middle user message", source="user"),
        SystemMessage(content="Second group 1"),
        SystemMessage(content="Second group 2"),
    ]
    # Non-continuous system messages should raise an error.
    # The method is protected, but we need to test it
    with pytest.raises(ValueError, match="Multiple and Not continuous system messages are not supported"):
        client._merge_system_messages(messages)  # pyright: ignore[reportPrivateUsage]
def test_merge_system_messages_empty() -> None:
    """Tests that empty message list is handled properly."""
    client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
    merged_messages = client._merge_system_messages([])  # pyright: ignore[reportPrivateUsage]
    # The method is protected, but we need to test it

    # An empty input should produce an empty output.
    assert len(merged_messages) == 0
def test_merge_system_messages_with_special_characters() -> None:
    """Tests system message merging with special characters and formatting."""
    client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
    messages: List[LLMMessage] = [
        SystemMessage(content="Line 1\nWith newline"),
        SystemMessage(content="Line 2 with *formatting*"),
        SystemMessage(content="Line 3 with `code`"),
        UserMessage(content="Question", source="user"),
    ]
    merged_messages = client._merge_system_messages(messages)  # pyright: ignore[reportPrivateUsage]
    # The method is protected, but we need to test it

    assert len(merged_messages) == 2
    system_message = merged_messages[0]
    assert isinstance(system_message, SystemMessage)
    # Content is joined with single newlines; embedded markup is untouched.
    assert system_message.content == "Line 1\nWith newline\nLine 2 with *formatting*\nLine 3 with `code`"
def test_merge_system_messages_with_whitespace() -> None:
    """Tests system message merging with extra whitespace."""
    client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
    messages: List[LLMMessage] = [
        SystemMessage(content="  Message with leading spaces  "),
        SystemMessage(content="\nMessage with leading newline\n"),
        UserMessage(content="Question", source="user"),
    ]
    merged_messages = client._merge_system_messages(messages)  # pyright: ignore[reportPrivateUsage]
    # The method is protected, but we need to test it

    assert len(merged_messages) == 2
    system_message = merged_messages[0]
    assert isinstance(system_message, SystemMessage)
    # Individual messages are not stripped; only the final merged string is
    # right-stripped, so interior whitespace and newlines are preserved.
    assert system_message.content == "  Message with leading spaces  \n\nMessage with leading newline"
def test_merge_system_messages_message_order() -> None:
    """Tests that message order is preserved after merging."""
    client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
    messages: List[LLMMessage] = [
        UserMessage(content="Question 1", source="user"),
        SystemMessage(content="Instruction 1"),
        SystemMessage(content="Instruction 2"),
        UserMessage(content="Question 2", source="user"),
        AssistantMessage(content="Answer", source="assistant"),
    ]
    merged_messages = client._merge_system_messages(messages)  # pyright: ignore[reportPrivateUsage]
    # The method is protected, but we need to test it

    assert len(merged_messages) == 4
    # The first message should be the UserMessage.
    assert isinstance(merged_messages[0], UserMessage)
    assert merged_messages[0].content == "Question 1"
    # The second message should be the merged SystemMessage, re-inserted at
    # the position of the first system message.
    assert isinstance(merged_messages[1], SystemMessage)
    assert merged_messages[1].content == "Instruction 1\nInstruction 2"
    # The remaining messages should keep their original order.
    assert isinstance(merged_messages[2], UserMessage)
    assert merged_messages[2].content == "Question 2"
    assert isinstance(merged_messages[3], AssistantMessage)
    assert merged_messages[3].content == "Answer"
def test_merge_system_messages_multiple_groups() -> None:
    """Tests that multiple separate groups of system messages raise an error."""
    client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
    # Non-continuous system messages: two groups separated by a user message.
    messages: List[LLMMessage] = [
        SystemMessage(content="Group 1 - message 1"),
        UserMessage(content="Interrupting user message", source="user"),
        SystemMessage(content="Group 2 - message 1"),
    ]
    # The method is protected, but we need to test it
    with pytest.raises(ValueError, match="Multiple and Not continuous system messages are not supported"):
        client._merge_system_messages(messages)  # pyright: ignore[reportPrivateUsage]
def test_merge_system_messages_no_duplicates() -> None:
    """Tests that identical system messages are still merged properly."""
    client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
    messages: List[LLMMessage] = [
        SystemMessage(content="Same instruction"),
        SystemMessage(content="Same instruction"),  # duplicate content
        UserMessage(content="Question", source="user"),
    ]
    merged_messages = client._merge_system_messages(messages)  # pyright: ignore[reportPrivateUsage]
    # The method is protected, but we need to test it

    assert len(merged_messages) == 2
    # The first message should be the merged system message.
    assert isinstance(merged_messages[0], SystemMessage)
    # Duplicate content is merged as-is, without deduplication.
    assert merged_messages[0].content == "Same instruction\nSame instruction"

View File

@ -24,6 +24,7 @@ from autogen_core.tools import BaseTool, FunctionTool
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient
from autogen_ext.models.openai._model_info import resolve_model
from autogen_ext.models.openai._openai_client import (
BaseOpenAIChatCompletionClient,
calculate_vision_tokens,
convert_tools,
to_oai_type,
@ -2058,4 +2059,248 @@ async def test_add_name_prefixes(monkeypatch: pytest.MonkeyPatch) -> None:
assert str(converted_mm["content"][0]["text"]) == "Adam said:\n" + str(oai_mm["content"][0]["text"])
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model",
    [
        "gpt-4o-mini",
        "gemini-1.5-flash",
        # TODO: Add anthropic models when available.
    ],
)
async def test_muliple_system_message(model: str, openai_client: OpenAIChatCompletionClient) -> None:
    """Test multiple system messages in a single request.

    Integration test: uses the ``openai_client`` fixture, parametrized by model.
    """
    # Test multiple system messages
    messages: List[LLMMessage] = [
        SystemMessage(content="When you say anything Start with 'FOO'"),
        SystemMessage(content="When you say anything End with 'BAR'"),
        UserMessage(content="Just say '.'", source="user"),
    ]
    result = await openai_client.create(messages=messages)
    result_content = result.content
    assert isinstance(result_content, str)
    result_content = result_content.strip()
    # Both system instructions should have taken effect.
    assert result_content[:3] == "FOO"
    assert result_content[-3:] == "BAR"
@pytest.mark.asyncio
async def test_system_message_merge_for_gemini_models() -> None:
    """Tests that system messages are merged correctly for Gemini models."""
    # Create a mock client; no network calls are made by _process_create_args.
    mock_client = MagicMock()
    client = BaseOpenAIChatCompletionClient(
        client=mock_client,
        create_args={"model": "gemini-1.5-flash"},
        model_info={
            "vision": False,
            "function_calling": False,
            "json_output": False,
            "family": "unknown",
            "structured_output": False,
        },
    )

    # Create two system messages
    messages: List[LLMMessage] = [
        SystemMessage(content="I am system message 1"),
        SystemMessage(content="I am system message 2"),
        UserMessage(content="Hello", source="user"),
    ]

    # Process the messages
    # pylint: disable=protected-access
    # The method is protected, but we need to test it
    create_params = client._process_create_args(  # pyright: ignore[reportPrivateUsage]
        messages=messages,
        tools=[],
        json_output=None,
        extra_create_args={},
    )

    # Extract the actual messages from the result
    oai_messages = create_params.messages

    # Check that there is only one system message and it contains the merged content
    system_messages = [msg for msg in oai_messages if msg["role"] == "system"]
    assert len(system_messages) == 1
    assert system_messages[0]["content"] == "I am system message 1\nI am system message 2"

    # Check that the user message is preserved
    user_messages = [msg for msg in oai_messages if msg["role"] == "user"]
    assert len(user_messages) == 1
    assert user_messages[0]["content"] == "Hello"
@pytest.mark.asyncio
async def test_system_message_merge_with_non_continuous_messages() -> None:
    """Tests that an error is raised when non-continuous system messages are provided."""
    # Create a mock client; no network calls are made by _process_create_args.
    mock_client = MagicMock()
    client = BaseOpenAIChatCompletionClient(
        client=mock_client,
        create_args={"model": "gemini-1.5-flash"},
        model_info={
            "vision": False,
            "function_calling": False,
            "json_output": False,
            "family": "unknown",
            "structured_output": False,
        },
    )

    # Create non-continuous system messages (separated by a user message)
    messages: List[LLMMessage] = [
        SystemMessage(content="I am system message 1"),
        UserMessage(content="Hello", source="user"),
        SystemMessage(content="I am system message 2"),
    ]

    # Process should raise ValueError
    with pytest.raises(ValueError, match="Multiple and Not continuous system messages are not supported"):
        # pylint: disable=protected-access
        # The method is protected, but we need to test it
        client._process_create_args(  # pyright: ignore[reportPrivateUsage]
            messages=messages,
            tools=[],
            json_output=None,
            extra_create_args={},
        )
@pytest.mark.asyncio
async def test_system_message_not_merged_for_non_gemini_models() -> None:
    """Tests that system messages aren't modified for non-Gemini models."""
    # Create a mock client; no network calls are made by _process_create_args.
    mock_client = MagicMock()
    client = BaseOpenAIChatCompletionClient(
        client=mock_client,
        create_args={"model": "gpt-4o"},
        model_info={
            "vision": False,
            "function_calling": False,
            "json_output": False,
            "family": "unknown",
            "structured_output": False,
        },
    )

    # Create two system messages
    messages: List[LLMMessage] = [
        SystemMessage(content="I am system message 1"),
        SystemMessage(content="I am system message 2"),
        UserMessage(content="Hello", source="user"),
    ]

    # Process the messages
    # pylint: disable=protected-access
    # The method is protected, but we need to test it
    create_params = client._process_create_args(  # pyright: ignore[reportPrivateUsage]
        messages=messages,
        tools=[],
        json_output=None,
        extra_create_args={},
    )

    # Extract the actual messages from the result
    oai_messages = create_params.messages

    # Check that both system messages are preserved (merging is Gemini-only)
    system_messages = [msg for msg in oai_messages if msg["role"] == "system"]
    assert len(system_messages) == 2
    assert system_messages[0]["content"] == "I am system message 1"
    assert system_messages[1]["content"] == "I am system message 2"
@pytest.mark.asyncio
async def test_no_system_messages_for_gemini_model() -> None:
    """Tests behavior when no system messages are provided to a Gemini model."""
    # Create a mock client; no network calls are made by _process_create_args.
    mock_client = MagicMock()
    client = BaseOpenAIChatCompletionClient(
        client=mock_client,
        create_args={"model": "gemini-1.5-flash"},
        model_info={
            "vision": False,
            "function_calling": False,
            "json_output": False,
            "family": "unknown",
            "structured_output": False,
        },
    )

    # Create messages with no system message
    messages: List[LLMMessage] = [
        UserMessage(content="Hello", source="user"),
        AssistantMessage(content="Hi there", source="assistant"),
    ]

    # Process the messages
    # pylint: disable=protected-access
    # The method is protected, but we need to test it
    create_params = client._process_create_args(  # pyright: ignore[reportPrivateUsage]
        messages=messages,
        tools=[],
        json_output=None,
        extra_create_args={},
    )

    # Extract the actual messages from the result
    oai_messages = create_params.messages

    # Check that there are no system messages (none should be synthesized)
    system_messages = [msg for msg in oai_messages if msg["role"] == "system"]
    assert len(system_messages) == 0

    # Check that other messages are preserved
    user_messages = [msg for msg in oai_messages if msg["role"] == "user"]
    assistant_messages = [msg for msg in oai_messages if msg["role"] == "assistant"]
    assert len(user_messages) == 1
    assert len(assistant_messages) == 1
@pytest.mark.asyncio
async def test_single_system_message_for_gemini_model() -> None:
    """Tests that a single system message is preserved for Gemini models."""
    # Create a mock client; no network calls are made by _process_create_args.
    mock_client = MagicMock()
    client = BaseOpenAIChatCompletionClient(
        client=mock_client,
        create_args={"model": "gemini-1.5-flash"},
        model_info={
            "vision": False,
            "function_calling": False,
            "json_output": False,
            "family": "unknown",
            "structured_output": False,
        },
    )

    # Create messages with a single system message
    messages: List[LLMMessage] = [
        SystemMessage(content="I am the only system message"),
        UserMessage(content="Hello", source="user"),
    ]

    # Process the messages
    # pylint: disable=protected-access
    # The method is protected, but we need to test it
    create_params = client._process_create_args(  # pyright: ignore[reportPrivateUsage]
        messages=messages,
        tools=[],
        json_output=None,
        extra_create_args={},
    )

    # Extract the actual messages from the result
    oai_messages = create_params.messages

    # Check that there is exactly one system message with the correct content
    system_messages = [msg for msg in oai_messages if msg["role"] == "system"]
    assert len(system_messages) == 1
    assert system_messages[0]["content"] == "I am the only system message"
# TODO: add integration tests for Azure OpenAI using AAD token.