mirror of https://github.com/microsoft/autogen.git
[BugFix][Refactor] Modular Transformer Pipeline and Fix Gemini/Anthropic Empty Content Handling (#6063)
## Why are these changes needed? This change addresses a compatibility issue when using Google Gemini models with AutoGen. Specifically, Gemini returns a 400 INVALID_ARGUMENT error when receiving a response with an empty "text" parameter. The root cause is that Gemini does not accept empty string values (e.g., "") as valid inputs in the history of the conversation. To fix this, if the content field is falsy (e.g., None, "", etc.), it is explicitly replaced with a single whitespace (" "), which prevents the Gemini model from rejecting the request. - **Gemini API compatibility:** Gemini models reject empty assistant messages (e.g., `""`), causing runtime errors. This PR ensures such messages are safely replaced with whitespace where appropriate. - **Avoiding regressions:** Applying the empty content workaround **only to Gemini**, and **only to valid message types**, avoids breaking OpenAI or other models. - **Reducing duplication:** Previously, message transformation logic was scattered and repeated across different message types and models. Modularizing this pipeline removes that redundancy. - **Improved maintainability:** With future model variants likely to introduce more constraints, this modular structure makes it easier to adapt transformations without writing ad-hoc code each time. - **Testing for correctness:** The new structure is verified with tests, ensuring the bug fix is effective and non-intrusive. ## Summary This PR introduces a **modular transformer pipeline** for message conversion and **fixes a Gemini-specific bug** related to empty assistant message content. ### Key Changes - **[Refactor]** Extracted message transformation logic into a unified pipeline to: - Reduce code duplication - Improve maintainability - Simplify debugging and extension for future model-specific logic - **[BugFix]** Gemini models do not accept empty assistant message content. - Introduced `_set_empty_to_whitespace` transformer to replace empty strings with `" "` only where needed - Applied it **only** to `"text"` and `"thought"` message types, not to `"tools"` to avoid serialization errors - **Improved structure for model-specific handling** - Transformer functions are now grouped and conditionally applied based on message type and model family - This design makes it easier to support future models or combinations (e.g., Gemini + R1) - **Test coverage added** - Added dedicated tests to verify that empty assistant content causes errors for Gemini - Ensured the fix resolves the issue without affecting OpenAI models --- ## Motivation Originally, Gemini-compatible endpoints would fail when receiving assistant messages with empty content (`""`). This issue required special handling without introducing brittle, ad-hoc patches. In addressing this, I also saw an opportunity to **modularize** the message transformation logic across models. This improves clarity, avoids duplication, and simplifies future adaptations (e.g., different constraints across model families). --- ## 📘 AutoGen Modular Message Transformer: Design & Usage Guide This document introduces the **new modular transformer system** used in AutoGen for converting `LLMMessage` instances to SDK-specific message formats (e.g., OpenAI-style `ChatCompletionMessageParam`). The design improves **reusability, extensibility**, and **maintainability** across different model families. --- ### 🚀 Overview Instead of scattering model-specific message conversion logic across the codebase, the new design introduces: - Modular transformer **functions** for each message type - Per-model **transformer maps** (e.g., for OpenAI-compatible models) - Optional **conditional transformers** for multimodal/text hybrid models - Clear separation between **message adaptation logic** and **SDK-specific builder** (e.g., `ChatCompletionUserMessageParam`) --- ### 🧱 1. Define Transform Functions Each transformer function takes: - `LLMMessage`: a structured AutoGen message - `context: dict`: metadata passed through the builder pipeline And returns: - A dictionary of keyword arguments for the target message constructor (e.g., `{"content": ..., "name": ..., "role": ...}`) ```python def _set_thought_as_content_gemini(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, str | None]: assert isinstance(message, AssistantMessage) return {"content": message.thought or " "} ``` --- ### 🪢 2. Compose Transformer Pipelines Multiple transformer functions are composed into a pipeline using `build_transformer_func()`: ```python base_user_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = [ _assert_valid_name, _set_name, _set_role("user"), ] user_transformer = build_transformer_func( funcs=base_user_transformer_funcs, message_param_func=ChatCompletionUserMessageParam ) ``` - The `message_param_func` is the actual constructor for the target message class (usually from the SDK). - The pipeline is **ordered** — each function adds or overrides keys in the builder kwargs. --- ### 🗂️ 3. Register Transformer Map Each model family maintains a `TransformerMap`, which maps `LLMMessage` types to transformers: ```python __BASE_TRANSFORMER_MAP: TransformerMap = { SystemMessage: system_transformer, UserMessage: user_transformer, AssistantMessage: assistant_transformer, } register_transformer("openai", model_name_or_family, __BASE_TRANSFORMER_MAP) ``` - `"openai"` is currently required (as only OpenAI-compatible format is supported now). - Registration ensures AutoGen knows how to transform each message type for that model. --- ### 🔁 4. Conditional Transformers (Optional) When message construction depends on runtime conditions (e.g., `"text"` vs. `"multimodal"`), use: ```python conditional_transformer = build_conditional_transformer_func( funcs_map=user_transformer_funcs_claude, message_param_func_map=user_transformer_constructors, condition_func=user_condition, ) ``` Where: - `funcs_map`: maps condition label → list of transformer functions ```python user_transformer_funcs_claude = { "text": text_transformers + [_set_empty_to_whitespace], "multimodal": multimodal_transformers + [_set_empty_to_whitespace], } ``` - `message_param_func_map`: maps condition label → message builder ```python user_transformer_constructors = { "text": ChatCompletionUserMessageParam, "multimodal": ChatCompletionUserMessageParam, } ``` - `condition_func`: determines which transformer to apply at runtime ```python def user_condition(message: LLMMessage, context: Dict[str, Any]) -> str: if isinstance(message.content, str): return "text" return "multimodal" ``` --- ### 🧪 Example Flow ```python llm_message = AssistantMessage(name="a", thought="let’s go") model_family = "openai" model_name = "claude-3-opus" transformer = get_transformer(model_family, model_name, type(llm_message)) sdk_message = transformer(llm_message, context={}) ``` --- ### 🎯 Design Benefits | Feature | Benefit | |--------|---------| | 🧱 Function-based modular design | Easy to compose and test | | 🧩 Per-model registry | Clean separation across model families | | ⚖️ Conditional support | Allows multimodal / dynamic adaptation | | 🔄 Reuse-friendly | Shared logic (e.g., `_set_name`) is DRY | | 📦 SDK-specific | Keeps message adaptation aligned to builder interface | --- ### 🔮 Future Direction - Support more SDKs and formats by introducing new message_param_func - Global registry integration (currently `"openai"`-scoped) - Class-based transformer variant if complexity grows --- ## Related issue number Closes #5762 ## Checks - [ ] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [x] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [ v ] I've made sure all auto checks have passed. --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
7615c7b83b
commit
fbdd89b46b
|
@ -31,9 +31,9 @@ class ModelFamily:
|
|||
CLAUDE_3_HAIKU = "claude-3-haiku"
|
||||
CLAUDE_3_SONNET = "claude-3-sonnet"
|
||||
CLAUDE_3_OPUS = "claude-3-opus"
|
||||
CLAUDE_3_5_HAIKU = "claude-3.5-haiku"
|
||||
CLAUDE_3_5_SONNET = "claude-3.5-sonnet"
|
||||
CLAUDE_3_7_SONNET = "claude-3.7-sonnet"
|
||||
CLAUDE_3_5_HAIKU = "claude-3-5-haiku"
|
||||
CLAUDE_3_5_SONNET = "claude-3-5-sonnet"
|
||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
ANY: TypeAlias = Literal[
|
||||
|
@ -50,8 +50,9 @@ class ModelFamily:
|
|||
"claude-3-haiku",
|
||||
"claude-3-sonnet",
|
||||
"claude-3-opus",
|
||||
"claude-3.5-haiku",
|
||||
"claude-3.5-sonnet",
|
||||
"claude-3-5-haiku",
|
||||
"claude-3-5-sonnet",
|
||||
"claude-3-7-sonnet",
|
||||
"unknown",
|
||||
]
|
||||
|
||||
|
@ -66,6 +67,7 @@ class ModelFamily:
|
|||
ModelFamily.CLAUDE_3_OPUS,
|
||||
ModelFamily.CLAUDE_3_5_HAIKU,
|
||||
ModelFamily.CLAUDE_3_5_SONNET,
|
||||
ModelFamily.CLAUDE_3_7_SONNET,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -12,6 +12,7 @@ from typing import (
|
|||
AsyncGenerator,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
|
@ -20,6 +21,7 @@ from typing import (
|
|||
Set,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
import tiktoken
|
||||
|
@ -142,20 +144,41 @@ def get_mime_type_from_image(image: Image) -> Literal["image/jpeg", "image/png",
|
|||
return "image/jpeg"
|
||||
|
||||
|
||||
@overload
|
||||
def __empty_content_to_whitespace(content: str) -> str: ...
|
||||
|
||||
|
||||
@overload
|
||||
def __empty_content_to_whitespace(content: List[Any]) -> Iterable[Any]: ...
|
||||
|
||||
|
||||
def __empty_content_to_whitespace(
|
||||
content: Union[str, List[Union[str, Image]]],
|
||||
) -> Union[str, Iterable[Any]]:
|
||||
if isinstance(content, str) and not content.strip():
|
||||
return " "
|
||||
elif isinstance(content, list) and not any(isinstance(x, str) and not x.strip() for x in content):
|
||||
for idx, message in enumerate(content):
|
||||
if isinstance(message, str) and not message.strip():
|
||||
content[idx] = " "
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def user_message_to_anthropic(message: UserMessage) -> MessageParam:
|
||||
assert_valid_name(message.source)
|
||||
|
||||
if isinstance(message.content, str):
|
||||
return {
|
||||
"role": "user",
|
||||
"content": message.content,
|
||||
"content": __empty_content_to_whitespace(message.content),
|
||||
}
|
||||
else:
|
||||
blocks: List[Union[TextBlockParam, ImageBlockParam]] = []
|
||||
|
||||
for part in message.content:
|
||||
if isinstance(part, str):
|
||||
blocks.append(TextBlockParam(type="text", text=part))
|
||||
blocks.append(TextBlockParam(type="text", text=__empty_content_to_whitespace(part)))
|
||||
elif isinstance(part, Image):
|
||||
blocks.append(
|
||||
ImageBlockParam(
|
||||
|
@ -177,7 +200,7 @@ def user_message_to_anthropic(message: UserMessage) -> MessageParam:
|
|||
|
||||
|
||||
def system_message_to_anthropic(message: SystemMessage) -> str:
|
||||
return message.content
|
||||
return __empty_content_to_whitespace(message.content)
|
||||
|
||||
|
||||
def assistant_message_to_anthropic(message: AssistantMessage) -> MessageParam:
|
||||
|
@ -190,6 +213,7 @@ def assistant_message_to_anthropic(message: AssistantMessage) -> MessageParam:
|
|||
for func_call in message.content:
|
||||
# Parse the arguments and convert to dict if it's a JSON string
|
||||
args = func_call.arguments
|
||||
args = __empty_content_to_whitespace(args)
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args_dict = json.loads(args)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from . import _message_transform
|
||||
from ._openai_client import (
|
||||
AZURE_OPENAI_USER_AGENT,
|
||||
AzureOpenAIChatCompletionClient,
|
||||
|
@ -20,4 +21,5 @@ __all__ = [
|
|||
"BaseOpenAIClientConfigurationConfigModel",
|
||||
"CreateArgumentsConfigModel",
|
||||
"AZURE_OPENAI_USER_AGENT",
|
||||
"_message_transform",
|
||||
]
|
||||
|
|
|
@ -0,0 +1,357 @@
|
|||
from typing import Any, Callable, Dict, List, cast, get_args
|
||||
|
||||
from autogen_core import (
|
||||
FunctionCall,
|
||||
Image,
|
||||
)
|
||||
from autogen_core.models import (
|
||||
AssistantMessage,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
ModelFamily,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
|
||||
from ._transformation import (
|
||||
LLMMessageContent,
|
||||
TransformerMap,
|
||||
TrasformerReturnType,
|
||||
build_conditional_transformer_func,
|
||||
build_transformer_func,
|
||||
register_transformer,
|
||||
)
|
||||
from ._utils import assert_valid_name
|
||||
|
||||
EMPTY: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def func_call_to_oai(message: FunctionCall) -> ChatCompletionMessageToolCallParam:
|
||||
return ChatCompletionMessageToolCallParam(
|
||||
id=message.id,
|
||||
function={
|
||||
"arguments": message.arguments,
|
||||
"name": message.name,
|
||||
},
|
||||
type="function",
|
||||
)
|
||||
|
||||
|
||||
# ===Mini Transformers===
|
||||
def _assert_valid_name(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, None]:
|
||||
assert isinstance(message, (UserMessage, AssistantMessage))
|
||||
assert_valid_name(message.source)
|
||||
return EMPTY
|
||||
|
||||
|
||||
def _set_role(role: str) -> Callable[[LLMMessage, Dict[str, Any]], Dict[str, str]]:
|
||||
def inner(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, str]:
|
||||
return {"role": role}
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def _set_name(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, str]:
|
||||
assert isinstance(message, (UserMessage, AssistantMessage))
|
||||
assert_valid_name(message.source)
|
||||
return {"name": message.source}
|
||||
|
||||
|
||||
def _set_content_direct(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, LLMMessageContent]:
|
||||
return {"content": message.content}
|
||||
|
||||
|
||||
def _set_prepend_text_content(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, str]:
|
||||
assert isinstance(message, (UserMessage, AssistantMessage))
|
||||
assert isinstance(message.content, str)
|
||||
prepend = context.get("prepend_name", False)
|
||||
prefix = f"{message.source} said:\n" if prepend else ""
|
||||
return {"content": prefix + message.content}
|
||||
|
||||
|
||||
def _set_multimodal_content(
|
||||
message: LLMMessage, context: Dict[str, Any]
|
||||
) -> Dict[str, List[ChatCompletionContentPartParam]]:
|
||||
assert isinstance(message, (UserMessage, AssistantMessage))
|
||||
prepend = context.get("prepend_name", False)
|
||||
parts: List[ChatCompletionContentPartParam] = []
|
||||
|
||||
for idx, part in enumerate(message.content):
|
||||
if isinstance(part, str):
|
||||
# If prepend, Append the name to the first text part
|
||||
text = f"{message.source} said:\n" + part if prepend and idx == 0 else part
|
||||
parts.append(ChatCompletionContentPartTextParam(type="text", text=text))
|
||||
elif isinstance(part, Image):
|
||||
# TODO: support url based images
|
||||
# TODO: support specifying details
|
||||
parts.append(cast(ChatCompletionContentPartImageParam, part.to_openai_format()))
|
||||
else:
|
||||
raise ValueError(f"Unknown content part: {part}")
|
||||
|
||||
return {"content": parts}
|
||||
|
||||
|
||||
def _set_tool_calls(
|
||||
message: LLMMessage, context: Dict[str, Any]
|
||||
) -> Dict[str, List[ChatCompletionMessageToolCallParam]]:
|
||||
assert isinstance(message.content, list)
|
||||
assert isinstance(message, AssistantMessage)
|
||||
return {
|
||||
"tool_calls": [func_call_to_oai(x) for x in message.content],
|
||||
}
|
||||
|
||||
|
||||
def _set_thought_as_content(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, str | None]:
|
||||
assert isinstance(message, AssistantMessage)
|
||||
return {"content": message.thought}
|
||||
|
||||
|
||||
def _set_thought_as_content_gemini(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, str | None]:
|
||||
assert isinstance(message, AssistantMessage)
|
||||
return {"content": message.thought or " "}
|
||||
|
||||
|
||||
def _set_empty_to_whitespace(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, LLMMessageContent]:
|
||||
return {"content": message.content or " "}
|
||||
|
||||
|
||||
def _set_pass_message_when_whitespace(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, bool]:
|
||||
if isinstance(message.content, str) and (message.content.isspace() or not message.content):
|
||||
return {"pass_message": True}
|
||||
return {}
|
||||
|
||||
|
||||
# === Base Transformers list ===
|
||||
base_system_message_transformers: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = [
|
||||
_set_content_direct,
|
||||
_set_role("system"),
|
||||
]
|
||||
|
||||
base_user_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = [
|
||||
_assert_valid_name,
|
||||
_set_name,
|
||||
_set_role("user"),
|
||||
]
|
||||
|
||||
base_assistant_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = [
|
||||
_assert_valid_name,
|
||||
_set_name,
|
||||
_set_role("assistant"),
|
||||
]
|
||||
|
||||
|
||||
# === Transformers list ===
|
||||
system_message_transformers: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
|
||||
base_system_message_transformers
|
||||
)
|
||||
|
||||
single_user_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
|
||||
base_user_transformer_funcs
|
||||
+ [
|
||||
_set_prepend_text_content,
|
||||
]
|
||||
)
|
||||
|
||||
multimodal_user_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
|
||||
base_user_transformer_funcs
|
||||
+ [
|
||||
_set_multimodal_content,
|
||||
]
|
||||
)
|
||||
|
||||
single_assistant_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
|
||||
base_assistant_transformer_funcs
|
||||
+ [
|
||||
_set_content_direct,
|
||||
]
|
||||
)
|
||||
|
||||
tools_assistant_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
|
||||
base_assistant_transformer_funcs
|
||||
+ [
|
||||
_set_tool_calls,
|
||||
]
|
||||
)
|
||||
|
||||
thought_assistant_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
|
||||
tools_assistant_transformer_funcs
|
||||
+ [
|
||||
_set_thought_as_content,
|
||||
]
|
||||
)
|
||||
|
||||
thought_assistant_transformer_funcs_gemini: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
|
||||
tools_assistant_transformer_funcs
|
||||
+ [
|
||||
_set_thought_as_content_gemini,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# === Specific message param functions ===
|
||||
|
||||
|
||||
# === Transformer maps ===
|
||||
user_transformer_funcs: Dict[str, List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]]] = {
|
||||
"text": single_user_transformer_funcs,
|
||||
"multimodal": multimodal_user_transformer_funcs,
|
||||
}
|
||||
user_transformer_constructors: Dict[str, Callable[..., Any]] = {
|
||||
"text": ChatCompletionUserMessageParam,
|
||||
"multimodal": ChatCompletionUserMessageParam,
|
||||
}
|
||||
|
||||
|
||||
def user_condition(message: LLMMessage, context: Dict[str, Any]) -> str:
|
||||
if isinstance(message.content, str):
|
||||
return "text"
|
||||
else:
|
||||
return "multimodal"
|
||||
|
||||
|
||||
assistant_transformer_funcs: Dict[str, List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]]] = {
|
||||
"text": single_assistant_transformer_funcs,
|
||||
"tools": tools_assistant_transformer_funcs,
|
||||
"thought": thought_assistant_transformer_funcs,
|
||||
}
|
||||
assistant_transformer_constructors: Dict[str, Callable[..., Any]] = {
|
||||
"text": ChatCompletionAssistantMessageParam,
|
||||
"tools": ChatCompletionAssistantMessageParam,
|
||||
"thought": ChatCompletionAssistantMessageParam,
|
||||
}
|
||||
|
||||
|
||||
def assistant_condition(message: LLMMessage, context: Dict[str, Any]) -> str:
|
||||
assert isinstance(message, AssistantMessage)
|
||||
if isinstance(message.content, list):
|
||||
if message.thought is not None:
|
||||
return "thought"
|
||||
else:
|
||||
return "tools"
|
||||
else:
|
||||
return "text"
|
||||
|
||||
|
||||
user_transformer_funcs_gemini: Dict[str, List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]]] = {
|
||||
"text": single_user_transformer_funcs + [_set_empty_to_whitespace],
|
||||
"multimodal": multimodal_user_transformer_funcs + [_set_empty_to_whitespace],
|
||||
}
|
||||
|
||||
|
||||
assistant_transformer_funcs_gemini: Dict[str, List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]]] = {
|
||||
"text": single_assistant_transformer_funcs + [_set_empty_to_whitespace],
|
||||
"tools": tools_assistant_transformer_funcs, # that case, message.content is a list of FunctionCall
|
||||
"thought": thought_assistant_transformer_funcs_gemini, # that case, message.content is a list of FunctionCall
|
||||
}
|
||||
|
||||
|
||||
user_transformer_funcs_claude: Dict[str, List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]]] = {
|
||||
"text": single_user_transformer_funcs + [_set_pass_message_when_whitespace],
|
||||
"multimodal": multimodal_user_transformer_funcs + [_set_pass_message_when_whitespace],
|
||||
}
|
||||
|
||||
|
||||
assistant_transformer_funcs_claude: Dict[str, List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]]] = {
|
||||
"text": single_assistant_transformer_funcs + [_set_pass_message_when_whitespace],
|
||||
"tools": tools_assistant_transformer_funcs, # that case, message.content is a list of FunctionCall
|
||||
"thought": thought_assistant_transformer_funcs_gemini, # that case, message.content is a list of FunctionCall
|
||||
}
|
||||
|
||||
|
||||
def function_execution_result_message(message: LLMMessage, context: Dict[str, Any]) -> TrasformerReturnType:
|
||||
assert isinstance(message, FunctionExecutionResultMessage)
|
||||
return [
|
||||
ChatCompletionToolMessageParam(content=x.content, role="tool", tool_call_id=x.call_id) for x in message.content
|
||||
]
|
||||
|
||||
|
||||
# === Transformers ===
|
||||
|
||||
__BASE_TRANSFORMER_MAP: TransformerMap = {
|
||||
SystemMessage: build_transformer_func(
|
||||
funcs=system_message_transformers,
|
||||
message_param_func=ChatCompletionSystemMessageParam,
|
||||
),
|
||||
UserMessage: build_conditional_transformer_func(
|
||||
funcs_map=user_transformer_funcs,
|
||||
message_param_func_map=user_transformer_constructors,
|
||||
condition_func=user_condition,
|
||||
),
|
||||
AssistantMessage: build_conditional_transformer_func(
|
||||
funcs_map=assistant_transformer_funcs,
|
||||
message_param_func_map=assistant_transformer_constructors,
|
||||
condition_func=assistant_condition,
|
||||
),
|
||||
FunctionExecutionResultMessage: function_execution_result_message,
|
||||
}
|
||||
|
||||
__GEMINI_TRANSFORMER_MAP: TransformerMap = {
|
||||
SystemMessage: build_transformer_func(
|
||||
funcs=system_message_transformers + [_set_empty_to_whitespace],
|
||||
message_param_func=ChatCompletionSystemMessageParam,
|
||||
),
|
||||
UserMessage: build_conditional_transformer_func(
|
||||
funcs_map=user_transformer_funcs_gemini,
|
||||
message_param_func_map=user_transformer_constructors,
|
||||
condition_func=user_condition,
|
||||
),
|
||||
AssistantMessage: build_conditional_transformer_func(
|
||||
funcs_map=assistant_transformer_funcs_gemini,
|
||||
message_param_func_map=assistant_transformer_constructors,
|
||||
condition_func=assistant_condition,
|
||||
),
|
||||
FunctionExecutionResultMessage: function_execution_result_message,
|
||||
}
|
||||
|
||||
__CLAUDE_TRANSFORMER_MAP: TransformerMap = {
|
||||
SystemMessage: build_transformer_func(
|
||||
funcs=system_message_transformers + [_set_empty_to_whitespace],
|
||||
message_param_func=ChatCompletionSystemMessageParam,
|
||||
),
|
||||
UserMessage: build_conditional_transformer_func(
|
||||
funcs_map=user_transformer_funcs_claude,
|
||||
message_param_func_map=user_transformer_constructors,
|
||||
condition_func=user_condition,
|
||||
),
|
||||
AssistantMessage: build_conditional_transformer_func(
|
||||
funcs_map=assistant_transformer_funcs_claude,
|
||||
message_param_func_map=assistant_transformer_constructors,
|
||||
condition_func=assistant_condition,
|
||||
),
|
||||
FunctionExecutionResultMessage: function_execution_result_message,
|
||||
}
|
||||
|
||||
|
||||
# set openai models to use the transformer map
|
||||
total_models = get_args(ModelFamily.ANY)
|
||||
__openai_models = [model for model in total_models if ModelFamily.is_openai(model)]
|
||||
|
||||
__claude_models = [model for model in total_models if ModelFamily.is_claude(model)]
|
||||
|
||||
__gemini_models = [model for model in total_models if ModelFamily.is_gemini(model)]
|
||||
|
||||
__unknown_models = list(set(total_models) - set(__openai_models) - set(__claude_models) - set(__gemini_models))
|
||||
|
||||
for model in __openai_models:
|
||||
register_transformer("openai", model, __BASE_TRANSFORMER_MAP)
|
||||
|
||||
for model in __claude_models:
|
||||
register_transformer("openai", model, __CLAUDE_TRANSFORMER_MAP)
|
||||
|
||||
for model in __gemini_models:
|
||||
register_transformer("openai", model, __GEMINI_TRANSFORMER_MAP)
|
||||
|
||||
for model in __unknown_models:
|
||||
register_transformer("openai", model, __BASE_TRANSFORMER_MAP)
|
||||
|
||||
register_transformer("openai", "default", __BASE_TRANSFORMER_MAP)
|
|
@ -21,9 +21,9 @@ _MODEL_POINTERS = {
|
|||
"claude-3-haiku": "claude-3-haiku-20240307",
|
||||
"claude-3-sonnet": "claude-3-sonnet-20240229",
|
||||
"claude-3-opus": "claude-3-opus-20240229",
|
||||
"claude-3.5-haiku": "claude-3-5-haiku-20241022",
|
||||
"claude-3.5-sonnet": "claude-3-5-sonnet-20241022",
|
||||
"claude-3.7-sonnet": "claude-3-7-sonnet-20250219",
|
||||
"claude-3-5-haiku": "claude-3-5-haiku-20241022",
|
||||
"claude-3-5-sonnet": "claude-3-5-sonnet-20241022",
|
||||
"claude-3-7-sonnet": "claude-3-7-sonnet-20250219",
|
||||
}
|
||||
|
||||
_MODEL_INFO: Dict[str, ModelInfo] = {
|
||||
|
|
|
@ -12,6 +12,7 @@ from importlib.metadata import PackageNotFoundError, version
|
|||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
|
@ -38,7 +39,6 @@ from autogen_core.models import (
|
|||
ChatCompletionClient,
|
||||
ChatCompletionTokenLogprob,
|
||||
CreateResult,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
ModelCapabilities, # type: ignore
|
||||
ModelFamily,
|
||||
|
@ -53,18 +53,11 @@ from autogen_core.tools import Tool, ToolSchema
|
|||
from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionRole,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
ParsedChatCompletion,
|
||||
ParsedChoice,
|
||||
completion_create_params,
|
||||
|
@ -82,6 +75,10 @@ from typing_extensions import Self, Unpack
|
|||
from .._utils.normalize_stop_reason import normalize_stop_reason
|
||||
from .._utils.parse_r1_content import parse_r1_content
|
||||
from . import _model_info
|
||||
from ._transformation import (
|
||||
get_transformer,
|
||||
)
|
||||
from ._utils import assert_valid_name
|
||||
from .config import (
|
||||
AzureOpenAIClientConfiguration,
|
||||
AzureOpenAIClientConfigurationConfigModel,
|
||||
|
@ -164,105 +161,22 @@ def type_to_role(message: LLMMessage) -> ChatCompletionRole:
|
|||
return "tool"
|
||||
|
||||
|
||||
def user_message_to_oai(message: UserMessage, prepend_name: bool = False) -> ChatCompletionUserMessageParam:
|
||||
assert_valid_name(message.source)
|
||||
if isinstance(message.content, str):
|
||||
return ChatCompletionUserMessageParam(
|
||||
content=(f"{message.source} said:\n" if prepend_name else "") + message.content,
|
||||
role="user",
|
||||
name=message.source,
|
||||
)
|
||||
else:
|
||||
parts: List[ChatCompletionContentPartParam] = []
|
||||
for part in message.content:
|
||||
if isinstance(part, str):
|
||||
if prepend_name:
|
||||
# Append the name to the first text part
|
||||
oai_part = ChatCompletionContentPartTextParam(
|
||||
text=f"{message.source} said:\n" + part,
|
||||
type="text",
|
||||
)
|
||||
prepend_name = False
|
||||
else:
|
||||
oai_part = ChatCompletionContentPartTextParam(
|
||||
text=part,
|
||||
type="text",
|
||||
)
|
||||
parts.append(oai_part)
|
||||
elif isinstance(part, Image):
|
||||
# TODO: support url based images
|
||||
# TODO: support specifying details
|
||||
parts.append(cast(ChatCompletionContentPartImageParam, part.to_openai_format()))
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {part}")
|
||||
return ChatCompletionUserMessageParam(
|
||||
content=parts,
|
||||
role="user",
|
||||
name=message.source,
|
||||
)
|
||||
def to_oai_type(
|
||||
message: LLMMessage, prepend_name: bool = False, model_family: str = "gpt-4o"
|
||||
) -> Sequence[ChatCompletionMessageParam]:
|
||||
context = {
|
||||
"prepend_name": prepend_name,
|
||||
}
|
||||
transformers = get_transformer("openai", model_family)
|
||||
|
||||
def raise_value_error(message: LLMMessage, context: Dict[str, Any]) -> Sequence[ChatCompletionMessageParam]:
|
||||
raise ValueError(f"Unknown message type: {type(message)}")
|
||||
|
||||
def system_message_to_oai(message: SystemMessage) -> ChatCompletionSystemMessageParam:
|
||||
return ChatCompletionSystemMessageParam(
|
||||
content=message.content,
|
||||
role="system",
|
||||
transformer: Callable[[LLMMessage, Dict[str, Any]], Sequence[ChatCompletionMessageParam]] = transformers.get(
|
||||
type(message), raise_value_error
|
||||
)
|
||||
|
||||
|
||||
def func_call_to_oai(message: FunctionCall) -> ChatCompletionMessageToolCallParam:
|
||||
return ChatCompletionMessageToolCallParam(
|
||||
id=message.id,
|
||||
function={
|
||||
"arguments": message.arguments,
|
||||
"name": message.name,
|
||||
},
|
||||
type="function",
|
||||
)
|
||||
|
||||
|
||||
def tool_message_to_oai(
|
||||
message: FunctionExecutionResultMessage,
|
||||
) -> Sequence[ChatCompletionToolMessageParam]:
|
||||
return [
|
||||
ChatCompletionToolMessageParam(content=x.content, role="tool", tool_call_id=x.call_id) for x in message.content
|
||||
]
|
||||
|
||||
|
||||
def assistant_message_to_oai(
|
||||
message: AssistantMessage,
|
||||
) -> ChatCompletionAssistantMessageParam:
|
||||
assert_valid_name(message.source)
|
||||
if isinstance(message.content, list):
|
||||
if message.thought is not None:
|
||||
return ChatCompletionAssistantMessageParam(
|
||||
content=message.thought,
|
||||
tool_calls=[func_call_to_oai(x) for x in message.content],
|
||||
role="assistant",
|
||||
name=message.source,
|
||||
)
|
||||
else:
|
||||
return ChatCompletionAssistantMessageParam(
|
||||
tool_calls=[func_call_to_oai(x) for x in message.content],
|
||||
role="assistant",
|
||||
name=message.source,
|
||||
)
|
||||
else:
|
||||
return ChatCompletionAssistantMessageParam(
|
||||
content=message.content,
|
||||
role="assistant",
|
||||
name=message.source,
|
||||
)
|
||||
|
||||
|
||||
def to_oai_type(message: LLMMessage, prepend_name: bool = False) -> Sequence[ChatCompletionMessageParam]:
|
||||
if isinstance(message, SystemMessage):
|
||||
return [system_message_to_oai(message)]
|
||||
elif isinstance(message, UserMessage):
|
||||
return [user_message_to_oai(message, prepend_name)]
|
||||
elif isinstance(message, AssistantMessage):
|
||||
return [assistant_message_to_oai(message)]
|
||||
else:
|
||||
return tool_message_to_oai(message)
|
||||
result = transformer(message, context)
|
||||
return result
|
||||
|
||||
|
||||
def calculate_vision_tokens(image: Image, detail: str = "auto") -> int:
|
||||
|
@ -360,19 +274,6 @@ def normalize_name(name: str) -> str:
|
|||
return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]
|
||||
|
||||
|
||||
def assert_valid_name(name: str) -> str:
|
||||
"""
|
||||
Ensure that configured names are valid, raises ValueError if not.
|
||||
|
||||
For munging LLM responses use _normalize_name to ensure LLM specified names don't break the API.
|
||||
"""
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", name):
|
||||
raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
|
||||
if len(name) > 64:
|
||||
raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def count_tokens_openai(
|
||||
messages: Sequence[LLMMessage],
|
||||
model: str,
|
||||
|
@ -392,7 +293,7 @@ def count_tokens_openai(
|
|||
# Message tokens.
|
||||
for message in messages:
|
||||
num_tokens += tokens_per_message
|
||||
oai_message = to_oai_type(message, prepend_name=add_name_prefixes)
|
||||
oai_message = to_oai_type(message, prepend_name=add_name_prefixes, model_family=model)
|
||||
for oai_message_part in oai_message:
|
||||
for key, value in oai_message_part.items():
|
||||
if value is None:
|
||||
|
@ -638,7 +539,11 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
|||
_messages.insert(_first_system_message_idx, system_message)
|
||||
messages = _messages
|
||||
|
||||
oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
|
||||
oai_messages_nested = [
|
||||
to_oai_type(m, prepend_name=self._add_name_prefixes, model_family=create_args.get("model", "unknown"))
|
||||
for m in messages
|
||||
]
|
||||
|
||||
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
|
||||
|
||||
if self.model_info["function_calling"] is False and len(tools) > 0:
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
from .registry import (
|
||||
MESSAGE_TRANSFORMERS,
|
||||
build_conditional_transformer_func,
|
||||
build_transformer_func,
|
||||
get_transformer,
|
||||
register_transformer,
|
||||
)
|
||||
from .types import (
|
||||
LLMMessageContent,
|
||||
MessageParam,
|
||||
TransformerFunc,
|
||||
TransformerMap,
|
||||
TrasformerReturnType,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"register_transformer",
|
||||
"get_transformer",
|
||||
"build_transformer_func",
|
||||
"build_conditional_transformer_func",
|
||||
"MESSAGE_TRANSFORMERS",
|
||||
"TransformerMap",
|
||||
"TransformerFunc",
|
||||
"MessageParam",
|
||||
"LLMMessageContent",
|
||||
"TrasformerReturnType",
|
||||
]
|
|
@ -0,0 +1,117 @@
|
|||
from collections import defaultdict
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
from autogen_core.models import LLMMessage
|
||||
|
||||
from .types import (
|
||||
TransformerFunc,
|
||||
TransformerMap,
|
||||
)
|
||||
|
||||
# Global registry of model family → message transformer map
|
||||
# Each model family (e.g. "gpt-4o", "gemini-1.5-flash") maps to a dict of LLMMessage type → transformer function
|
||||
MESSAGE_TRANSFORMERS: Dict[str, Dict[str, TransformerMap]] = defaultdict(dict)
|
||||
|
||||
|
||||
def build_transformer_func(
|
||||
funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]], message_param_func: Callable[..., Any]
|
||||
) -> TransformerFunc:
|
||||
"""
|
||||
Combines multiple transformer functions into a single transformer.
|
||||
|
||||
Each `func` must accept a message and a context dict, and return a partial dict
|
||||
of keyword arguments. These are merged and passed to `message_param_func`.
|
||||
|
||||
This structure allows flexible transformation pipelines and future extensibility
|
||||
(e.g., prepend name, insert metadata, etc).
|
||||
|
||||
message_param_func: A model-specific constructor (e.g. ChatCompletionMessageParam).
|
||||
Signature is intentionally open: Callable[..., Any].
|
||||
"""
|
||||
|
||||
def transformer_func(message: LLMMessage, context: Any) -> Any:
|
||||
kwargs: Dict[str, Any] = {}
|
||||
for func in funcs:
|
||||
kwargs.update(func(message, context))
|
||||
return [message_param_func(**kwargs)]
|
||||
|
||||
return transformer_func
|
||||
|
||||
|
||||
def build_conditional_transformer_func(
|
||||
funcs_map: Dict[str, List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]]],
|
||||
message_param_func_map: Dict[str, Callable[..., Any]],
|
||||
condition_func: Callable[[LLMMessage, Dict[str, Any]], str],
|
||||
) -> TransformerFunc:
|
||||
"""
|
||||
Combines multiple transformer functions into a single transformer, with a conditional constructor.
|
||||
|
||||
Each `func` must accept a message and a context dict, and return a partial dict
|
||||
of keyword arguments. These are merged and passed to the constructor selected by `condition_func`.
|
||||
|
||||
This structure allows flexible transformation pipelines and future extensibility
|
||||
(e.g., prepend name, insert metadata, etc).
|
||||
|
||||
message_param_func_map: A mapping of condition → constructor function.
|
||||
condition_func: A function that returns the condition for selecting the constructor.
|
||||
"""
|
||||
|
||||
def transformer(message: LLMMessage, context: Dict[str, Any]) -> Any:
|
||||
condition = condition_func(message, context)
|
||||
message_param_func = message_param_func_map[condition]
|
||||
kwargs: Dict[str, Any] = {}
|
||||
for func in funcs_map[condition]:
|
||||
kwargs.update(func(message, context))
|
||||
if kwargs.get("pass_message", False):
|
||||
return []
|
||||
return [message_param_func(**kwargs)]
|
||||
|
||||
return transformer
|
||||
|
||||
|
||||
def register_transformer(api: str, model_family: str, transformer_map: TransformerMap) -> None:
|
||||
"""
|
||||
Registers a transformer map for a given model family.
|
||||
|
||||
Example:
|
||||
register_transformer("gpt-4o", {
|
||||
UserMessage: user_message_to_oai,
|
||||
SystemMessage: system_message_to_oai,
|
||||
})
|
||||
"""
|
||||
MESSAGE_TRANSFORMERS[api][model_family] = transformer_map
|
||||
|
||||
|
||||
def _find_model_family(api: str, model: str) -> str:
|
||||
"""
|
||||
Finds the best matching model family for the given model.
|
||||
Search via prefix matching (e.g. "gpt-4o" → "gpt-4o-1.0").
|
||||
"""
|
||||
for family in MESSAGE_TRANSFORMERS[api].keys():
|
||||
if model.startswith(family):
|
||||
return family
|
||||
return "default"
|
||||
|
||||
|
||||
def get_transformer(api: str, model_family: str) -> TransformerMap:
|
||||
"""
|
||||
Returns the registered transformer map for the given model family.
|
||||
|
||||
This is a thin wrapper around `MESSAGE_TRANSFORMERS.get(...)`, but serves as
|
||||
an abstraction layer to allow future enhancements such as:
|
||||
|
||||
- Providing fallback transformers for unknown model families
|
||||
- Injecting mock transformers during testing
|
||||
- Adding logging, metrics, or versioning later
|
||||
|
||||
Keeping this as a function (instead of direct dict access) improves long-term flexibility.
|
||||
"""
|
||||
|
||||
model = _find_model_family(api, model_family)
|
||||
|
||||
transformer = MESSAGE_TRANSFORMERS.get(api, {}).get(model, {})
|
||||
|
||||
if not transformer:
|
||||
raise ValueError(f"No transformer found for model family '{model_family}'")
|
||||
|
||||
return transformer
|
|
@ -0,0 +1,22 @@
|
|||
from typing import Any, Callable, Dict, List, Sequence, Type, Union
|
||||
|
||||
from autogen_core import FunctionCall, Image
|
||||
from autogen_core.models import LLMMessage
|
||||
from autogen_core.models._types import FunctionExecutionResult
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
MessageParam = Union[ChatCompletionMessageParam] # If that transformation move to global, add other message params here
|
||||
TrasformerReturnType = Sequence[MessageParam]
|
||||
TransformerFunc = Callable[[LLMMessage, Dict[str, Any]], TrasformerReturnType]
|
||||
TransformerMap = Dict[Type[LLMMessage], TransformerFunc]
|
||||
|
||||
LLMMessageContent = Union[
|
||||
# SystemMessage.content
|
||||
str,
|
||||
# UserMessage.content
|
||||
List[Union[str, Image]],
|
||||
# AssistantMessage.content
|
||||
List[FunctionCall],
|
||||
# FunctionExecutionResultMessage.content
|
||||
List[FunctionExecutionResult],
|
||||
]
|
|
@ -0,0 +1,14 @@
|
|||
import re
|
||||
|
||||
|
||||
def assert_valid_name(name: str) -> str:
|
||||
"""
|
||||
Ensure that configured names are valid, raises ValueError if not.
|
||||
|
||||
For munging LLM responses use _normalize_name to ensure LLM specified names don't break the API.
|
||||
"""
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", name):
|
||||
raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
|
||||
if len(name) > 64:
|
||||
raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
|
||||
return name
|
|
@ -339,7 +339,6 @@ async def test_anthropic_serialization() -> None:
|
|||
@pytest.mark.asyncio
|
||||
async def test_anthropic_muliple_system_message() -> None:
|
||||
"""Test multiple system messages in a single request."""
|
||||
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if not api_key:
|
||||
pytest.skip("ANTHROPIC_API_KEY not found in environment variables")
|
||||
|
@ -560,3 +559,28 @@ def test_merge_system_messages_no_duplicates() -> None:
|
|||
assert isinstance(merged_messages[0], SystemMessage)
|
||||
# 중복된 내용도 그대로 병합됨
|
||||
assert merged_messages[0].content == "Same instruction\nSame instruction"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_assistant_content_string_with_anthropic() -> None:
|
||||
"""Test that an empty assistant content string is handled correctly."""
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if not api_key:
|
||||
pytest.skip("ANTHROPIC_API_KEY not found in environment variables")
|
||||
|
||||
client = AnthropicChatCompletionClient(
|
||||
model="claude-3-haiku-20240307",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
# Test empty assistant content string
|
||||
result = await client.create(
|
||||
messages=[
|
||||
UserMessage(content="Say something", source="user"),
|
||||
AssistantMessage(content="", source="assistant"),
|
||||
]
|
||||
)
|
||||
|
||||
# Verify we got a response
|
||||
assert isinstance(result.content, str)
|
||||
assert len(result.content) > 0
|
||||
|
|
|
@ -1602,6 +1602,10 @@ def openai_client(request: pytest.FixtureRequest) -> OpenAIChatCompletionClient:
|
|||
api_key = os.getenv("GEMINI_API_KEY")
|
||||
if not api_key:
|
||||
pytest.skip("GEMINI_API_KEY not found in environment variables")
|
||||
elif model.startswith("claude"):
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if not api_key:
|
||||
pytest.skip("ANTHROPIC_API_KEY not found in environment variables")
|
||||
else:
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
|
@ -1616,7 +1620,7 @@ def openai_client(request: pytest.FixtureRequest) -> OpenAIChatCompletionClient:
|
|||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
["gpt-4o-mini", "gemini-1.5-flash"],
|
||||
["gpt-4o-mini", "gemini-1.5-flash", "claude-3-5-haiku-20241022"],
|
||||
)
|
||||
async def test_model_client_basic_completion(model: str, openai_client: OpenAIChatCompletionClient) -> None:
|
||||
# Test basic completion
|
||||
|
@ -1633,7 +1637,7 @@ async def test_model_client_basic_completion(model: str, openai_client: OpenAICh
|
|||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
["gpt-4o-mini", "gemini-1.5-flash"],
|
||||
["gpt-4o-mini", "gemini-1.5-flash", "claude-3-5-haiku-20241022"],
|
||||
)
|
||||
async def test_model_client_with_function_calling(model: str, openai_client: OpenAIChatCompletionClient) -> None:
|
||||
# Test tool calling
|
||||
|
@ -2065,7 +2069,7 @@ async def test_add_name_prefixes(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
[
|
||||
"gpt-4o-mini",
|
||||
"gemini-1.5-flash",
|
||||
# TODO: Add anthropic models when available.
|
||||
"claude-3-5-haiku-20241022",
|
||||
],
|
||||
)
|
||||
async def test_muliple_system_message(model: str, openai_client: OpenAIChatCompletionClient) -> None:
|
||||
|
@ -2303,4 +2307,64 @@ async def test_single_system_message_for_gemini_model() -> None:
|
|||
assert system_messages[0]["content"] == "I am the only system message"
|
||||
|
||||
|
||||
def noop(input: str) -> str:
|
||||
return "done"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model", ["gemini-1.5-flash"])
|
||||
async def test_empty_assistant_content_with_gemini(model: str, openai_client: OpenAIChatCompletionClient) -> None:
|
||||
# Test tool calling
|
||||
tool = FunctionTool(noop, name="noop", description="No-op tool")
|
||||
messages: List[LLMMessage] = [UserMessage(content="Call noop", source="user")]
|
||||
result = await openai_client.create(messages=messages, tools=[tool])
|
||||
assert isinstance(result.content, list)
|
||||
tool_call = result.content[0]
|
||||
assert isinstance(tool_call, FunctionCall)
|
||||
|
||||
# reply with empty string as thought (== content)
|
||||
messages.append(AssistantMessage(content=result.content, thought="", source="assistant"))
|
||||
messages.append(
|
||||
FunctionExecutionResultMessage(
|
||||
content=[
|
||||
FunctionExecutionResult(
|
||||
content="done",
|
||||
call_id=tool_call.id,
|
||||
is_error=False,
|
||||
name=tool_call.name,
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# This will crash if _set_empty_to_whitespace is not applied to "thought"
|
||||
result = await openai_client.create(messages=messages)
|
||||
assert isinstance(result.content, str)
|
||||
assert result.content.strip() != "" or result.content == " "
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"gpt-4o-mini",
|
||||
"gemini-1.5-flash",
|
||||
"claude-3-5-haiku-20241022",
|
||||
],
|
||||
)
|
||||
async def test_empty_assistant_content_string_with_some_model(
|
||||
model: str, openai_client: OpenAIChatCompletionClient
|
||||
) -> None:
|
||||
# message: assistant is response empty content
|
||||
messages: list[LLMMessage] = [
|
||||
UserMessage(content="Say something", source="user"),
|
||||
AssistantMessage(content="test", source="assistant"),
|
||||
UserMessage(content="", source="user"),
|
||||
]
|
||||
|
||||
# This will crash if _set_empty_to_whitespace is not applied to "content"
|
||||
result = await openai_client.create(messages=messages)
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
|
||||
# TODO: add integration tests for Azure OpenAI using AAD token.
|
||||
|
|
Loading…
Reference in New Issue