[Refactor] model family resolution to support non-prefixed names like Mistral (#6158)

This PR improves how model_family is resolved when selecting a
transformer from the registry.
Previously, model families were inferred using a simple prefix-based
match like:
```
if model.startswith(family): ...
```
This works for cleanly prefixed models (e.g., `gpt-4o`, `claude-3`) but
fails for models like `mistral-large-latest`, `codestral-latest`, etc.,
where prefix-based matching is ambiguous or misleading.

To address this:
- `model_family` can now be passed explicitly (e.g., via `ModelInfo`)
- `_find_model_family()` is only used as a fallback when the value is
`ModelFamily.UNKNOWN`
- Transformer lookup is now more robust and predictable
- Example integration in `to_oai_type()` demonstrates this pattern using
`self._model_info["family"]`

This change is required for safe support of models like Mistral and
other future models that do not follow standard naming conventions.

Linked to discussion in
[#6151](https://github.com/microsoft/autogen/issues/6151)
Related: #6011

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
EeS 2025-04-03 07:08:17 +09:00 committed by GitHub
parent 9143e58ef1
commit 27da37efc0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 92 additions and 13 deletions

View File

@ -1,10 +1,16 @@
import logging
from typing import Dict
from autogen_core import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
from autogen_core.models import ModelFamily, ModelInfo
logger = logging.getLogger(EVENT_LOGGER_NAME)
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
# Based on: https://platform.openai.com/docs/models/continuous-model-upgrades
# This is a moving target, so correctness is checked by the model value returned by openai against expected values at runtime
_MODEL_POINTERS = {
# OpenAI models
"o3-mini": "o3-mini-2025-01-31",
"o1": "o1-2024-12-17",
"o1-preview": "o1-preview-2024-09-12",
@ -18,6 +24,7 @@ _MODEL_POINTERS = {
"gpt-4-32k": "gpt-4-32k-0613",
"gpt-3.5-turbo": "gpt-3.5-turbo-0125",
"gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k-0613",
# Anthropic models
"claude-3-haiku": "claude-3-haiku-20240307",
"claude-3-sonnet": "claude-3-sonnet-20240229",
"claude-3-opus": "claude-3-opus-20240229",
@ -291,8 +298,24 @@ def resolve_model(model: str) -> str:
def get_info(model: str) -> ModelInfo:
# If this function is called, it means the config does not provide a custom model_info
resolved_model = resolve_model(model)
return _MODEL_INFO[resolved_model]
model_info: ModelInfo = _MODEL_INFO.get(
resolved_model,
{
"vision": False,
"function_calling": False,
"json_output": False,
"family": "FAILED",
"structured_output": False,
},
)
if model_info.get("family") == "FAILED":
raise ValueError("model_info is required when model name is not a valid OpenAI model")
if model_info.get("family") == ModelFamily.UNKNOWN:
trace_logger.warning(f"Model info not found for model: {model}")
return model_info
def get_token_limit(model: str) -> int:

View File

@ -162,12 +162,12 @@ def type_to_role(message: LLMMessage) -> ChatCompletionRole:
def to_oai_type(
message: LLMMessage, prepend_name: bool = False, model_family: str = "gpt-4o"
message: LLMMessage, prepend_name: bool = False, model: str = "unknown", model_family: str = ModelFamily.UNKNOWN
) -> Sequence[ChatCompletionMessageParam]:
context = {
"prepend_name": prepend_name,
}
transformers = get_transformer("openai", model_family)
transformers = get_transformer("openai", model, model_family)
def raise_value_error(message: LLMMessage, context: Dict[str, Any]) -> Sequence[ChatCompletionMessageParam]:
raise ValueError(f"Unknown message type: {type(message)}")
@ -280,6 +280,7 @@ def count_tokens_openai(
*,
add_name_prefixes: bool = False,
tools: Sequence[Tool | ToolSchema] = [],
model_family: str = ModelFamily.UNKNOWN,
) -> int:
try:
encoding = tiktoken.encoding_for_model(model)
@ -293,7 +294,7 @@ def count_tokens_openai(
# Message tokens.
for message in messages:
num_tokens += tokens_per_message
oai_message = to_oai_type(message, prepend_name=add_name_prefixes, model_family=model)
oai_message = to_oai_type(message, prepend_name=add_name_prefixes, model=model, model_family=model_family)
for oai_message_part in oai_message:
for key, value in oai_message_part.items():
if value is None:
@ -556,7 +557,12 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
messages = self._rstrip_last_assistant_message(messages)
oai_messages_nested = [
to_oai_type(m, prepend_name=self._add_name_prefixes, model_family=create_args.get("model", "unknown"))
to_oai_type(
m,
prepend_name=self._add_name_prefixes,
model=create_args.get("model", "unknown"),
model_family=self._model_info["family"],
)
for m in messages
]
@ -1049,6 +1055,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
self._create_args["model"],
add_name_prefixes=self._add_name_prefixes,
tools=tools,
model_family=self._model_info["family"],
)
def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:

View File

@ -1,7 +1,7 @@
from collections import defaultdict
from typing import Any, Callable, Dict, List
from autogen_core.models import LLMMessage
from autogen_core.models import LLMMessage, ModelFamily
from .types import (
TransformerFunc,
@ -87,13 +87,14 @@ def _find_model_family(api: str, model: str) -> str:
Finds the best matching model family for the given model.
Search via prefix matching (e.g. family "gpt-4o" matches model "gpt-4o-1.0").
"""
for family in MESSAGE_TRANSFORMERS[api].keys():
if model.startswith(family):
return family
return "default"
family = ModelFamily.UNKNOWN
for _family in MESSAGE_TRANSFORMERS[api].keys():
if model.startswith(_family):
family = _family
return family
def get_transformer(api: str, model_family: str) -> TransformerMap:
def get_transformer(api: str, model: str, model_family: str) -> TransformerMap:
"""
Returns the registered transformer map for the given model family.
@ -107,9 +108,11 @@ def get_transformer(api: str, model_family: str) -> TransformerMap:
Keeping this as a function (instead of direct dict access) improves long-term flexibility.
"""
model = _find_model_family(api, model_family)
if model_family == ModelFamily.UNKNOWN:
# fallback to finding the best matching model family
model_family = _find_model_family(api, model)
transformer = MESSAGE_TRANSFORMERS.get(api, {}).get(model, {})
transformer = MESSAGE_TRANSFORMERS.get(api, {}).get(model_family, {})
if not transformer:
raise ValueError(f"No transformer found for model family '{model_family}'")

View File

@ -29,6 +29,7 @@ from autogen_ext.models.openai._openai_client import (
convert_tools,
to_oai_type,
)
from autogen_ext.models.openai._transformation import TransformerMap, get_transformer
from openai.resources.beta.chat.completions import ( # type: ignore
AsyncChatCompletionStreamManager as BetaAsyncChatCompletionStreamManager, # type: ignore
)
@ -2367,6 +2368,51 @@ async def test_empty_assistant_content_string_with_some_model(
assert isinstance(result.content, str)
def test_openai_model_registry_find_well() -> None:
model = "gpt-4o"
client1 = OpenAIChatCompletionClient(model=model, api_key="test")
client2 = OpenAIChatCompletionClient(
model=model,
model_info={
"vision": False,
"function_calling": False,
"json_output": False,
"structured_output": False,
"family": ModelFamily.UNKNOWN,
},
api_key="test",
)
def get_regitered_transformer(client: OpenAIChatCompletionClient) -> TransformerMap:
model_name = client._create_args["model"] # pyright: ignore[reportPrivateUsage]
model_family = client.model_info["family"]
return get_transformer("openai", model_name, model_family)
assert get_regitered_transformer(client1) == get_regitered_transformer(client2)
def test_openai_model_registry_find_wrong() -> None:
with pytest.raises(ValueError, match="No transformer found for model family"):
get_transformer("openai", "gpt-7", "foobar")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model",
[
"gpt-4o-mini",
],
)
async def test_openai_model_unknown_message_type(model: str, openai_client: OpenAIChatCompletionClient) -> None:
class WrongMessage:
content = "foo"
source = "bar"
messages: List[WrongMessage] = [WrongMessage()]
with pytest.raises(ValueError, match="Unknown message type"):
await openai_client.create(messages=messages) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model",