mirror of https://github.com/microsoft/autogen.git
[Refactor] model family resolution to support non-prefixed names like Mistral (#6158)
This PR improves how `model_family` is resolved when selecting a transformer from the registry. Previously, model families were inferred using a simple prefix-based match like:

```
if model.startswith(family): ...
```

This works for cleanly prefixed models (e.g., `gpt-4o`, `claude-3`) but fails for models like `mistral-large-latest`, `codestral-latest`, etc., where prefix-based matching is ambiguous or misleading.

To address this:

• model_family can now be passed explicitly (e.g., via ModelInfo)
• _find_model_family() is only used as a fallback when the value is "unknown"
• Transformer lookup is now more robust and predictable
• Example integration in to_oai_type() demonstrates this pattern using self._model_info["family"]

This change is required for safe support of models like Mistral and other future models that do not follow standard naming conventions.

Linked to discussion in [#6151](https://github.com/microsoft/autogen/issues/6151)
Related: #6011

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
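For orientation before the diff: a minimal, runnable sketch of the lookup order this PR establishes. The registry layout and names here (`MESSAGE_TRANSFORMERS`, `find_model_family`, `get_transformer`) are simplified stand-ins for the real `_transformation` module, not its actual contents:

```python
from collections import defaultdict
from typing import Dict

UNKNOWN = "unknown"

# Hypothetical registry contents, for illustration only.
MESSAGE_TRANSFORMERS: Dict[str, Dict[str, str]] = defaultdict(dict)
MESSAGE_TRANSFORMERS["openai"]["gpt-4o"] = "gpt-4o-transformer"
MESSAGE_TRANSFORMERS["openai"]["mistral"] = "mistral-transformer"


def find_model_family(api: str, model: str) -> str:
    """Fallback only: prefix-match the model name against registered families."""
    family = UNKNOWN
    for candidate in MESSAGE_TRANSFORMERS[api]:
        if model.startswith(candidate):
            family = candidate
    return family


def get_transformer(api: str, model: str, model_family: str) -> str:
    # An explicitly supplied family takes precedence; prefix matching is
    # only consulted when the caller left the family as "unknown".
    if model_family == UNKNOWN:
        model_family = find_model_family(api, model)
    transformer = MESSAGE_TRANSFORMERS.get(api, {}).get(model_family, "")
    if not transformer:
        raise ValueError(f"No transformer found for model family '{model_family}'")
    return transformer


# "mistral-large-latest" resolves via the explicit family, not its prefix.
print(get_transformer("openai", "mistral-large-latest", "mistral"))
```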
This commit is contained in:
parent 9143e58ef1
commit 27da37efc0
```diff
@@ -1,10 +1,16 @@
+import logging
 from typing import Dict
 
+from autogen_core import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
 from autogen_core.models import ModelFamily, ModelInfo
 
+logger = logging.getLogger(EVENT_LOGGER_NAME)
+trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
+
 # Based on: https://platform.openai.com/docs/models/continuous-model-upgrades
 # This is a moving target, so correctness is checked by the model value returned by openai against expected values at runtime
 _MODEL_POINTERS = {
+    # OpenAI models
     "o3-mini": "o3-mini-2025-01-31",
     "o1": "o1-2024-12-17",
     "o1-preview": "o1-preview-2024-09-12",
```
```diff
@@ -18,6 +24,7 @@ _MODEL_POINTERS = {
     "gpt-4-32k": "gpt-4-32k-0613",
     "gpt-3.5-turbo": "gpt-3.5-turbo-0125",
     "gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k-0613",
+    # Anthropic models
     "claude-3-haiku": "claude-3-haiku-20240307",
     "claude-3-sonnet": "claude-3-sonnet-20240229",
     "claude-3-opus": "claude-3-opus-20240229",
```
```diff
@@ -291,8 +298,24 @@ def resolve_model(model: str) -> str:
 
 
 def get_info(model: str) -> ModelInfo:
+    # If this is called, it means the config does not provide a custom model_info
     resolved_model = resolve_model(model)
-    return _MODEL_INFO[resolved_model]
+    model_info: ModelInfo = _MODEL_INFO.get(
+        resolved_model,
+        {
+            "vision": False,
+            "function_calling": False,
+            "json_output": False,
+            "family": "FAILED",
+            "structured_output": False,
+        },
+    )
+    if model_info.get("family") == "FAILED":
+        raise ValueError("model_info is required when model name is not a valid OpenAI model")
+    if model_info.get("family") == ModelFamily.UNKNOWN:
+        trace_logger.warning(f"Model info not found for model: {model}")
+
+    return model_info
 
 
 def get_token_limit(model: str) -> int:
```
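The practical effect of the new `get_info`, as a hedged usage sketch (`_model_info` is an internal module, so the import path is an assumption and not a stable public API):

```python
# Sketch of the new get_info behavior; _model_info is internal to autogen-ext,
# so this import path is an assumption that may change between versions.
from autogen_ext.models.openai._model_info import get_info

info = get_info("gpt-4o")  # known model: resolved via _MODEL_POINTERS
print(info["family"])      # a concrete family, e.g. "gpt-4o"

try:
    get_info("mistral-large-latest")  # not a valid OpenAI model name
except ValueError as e:
    # "model_info is required when model name is not a valid OpenAI model"
    print(e)
```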
```diff
@@ -162,12 +162,12 @@ def type_to_role(message: LLMMessage) -> ChatCompletionRole:
 
 
 def to_oai_type(
-    message: LLMMessage, prepend_name: bool = False, model_family: str = "gpt-4o"
+    message: LLMMessage, prepend_name: bool = False, model: str = "unknown", model_family: str = ModelFamily.UNKNOWN
 ) -> Sequence[ChatCompletionMessageParam]:
     context = {
         "prepend_name": prepend_name,
     }
-    transformers = get_transformer("openai", model_family)
+    transformers = get_transformer("openai", model, model_family)
 
     def raise_value_error(message: LLMMessage, context: Dict[str, Any]) -> Sequence[ChatCompletionMessageParam]:
         raise ValueError(f"Unknown message type: {type(message)}")
```
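A hedged usage sketch of the new signature (`to_oai_type` is an internal helper, so the import path is an assumption; `ModelFamily.GPT_4O` is a real `autogen_core` constant):

```python
# Sketch: calling to_oai_type with an explicit family so no prefix guessing
# is needed. to_oai_type is internal; the import path is an assumption.
from autogen_core.models import ModelFamily, UserMessage
from autogen_ext.models.openai._openai_client import to_oai_type

msg = UserMessage(content="hello", source="user")
params = to_oai_type(msg, model="gpt-4o-2024-08-06", model_family=ModelFamily.GPT_4O)
print(params)  # e.g. [{'role': 'user', 'content': 'hello'}]
```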
```diff
@@ -280,6 +280,7 @@ def count_tokens_openai(
     *,
     add_name_prefixes: bool = False,
     tools: Sequence[Tool | ToolSchema] = [],
+    model_family: str = ModelFamily.UNKNOWN,
 ) -> int:
     try:
         encoding = tiktoken.encoding_for_model(model)
@@ -293,7 +294,7 @@ def count_tokens_openai(
     # Message tokens.
     for message in messages:
         num_tokens += tokens_per_message
-        oai_message = to_oai_type(message, prepend_name=add_name_prefixes, model_family=model)
+        oai_message = to_oai_type(message, prepend_name=add_name_prefixes, model=model, model_family=model_family)
         for oai_message_part in oai_message:
             for key, value in oai_message_part.items():
                 if value is None:
```
```diff
@@ -556,7 +557,12 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
             messages = self._rstrip_last_assistant_message(messages)
 
         oai_messages_nested = [
-            to_oai_type(m, prepend_name=self._add_name_prefixes, model_family=create_args.get("model", "unknown"))
+            to_oai_type(
+                m,
+                prepend_name=self._add_name_prefixes,
+                model=create_args.get("model", "unknown"),
+                model_family=self._model_info["family"],
+            )
             for m in messages
         ]
 
```
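This is the integration the PR description refers to: the client forwards the family from its configured `ModelInfo` rather than guessing from the model string. A hedged construction sketch; the `"mistral"` family string and the credentials are illustrative placeholders, and whether a transformer is registered for that family depends on the installed version:

```python
# Sketch: supplying model_info explicitly so self._model_info["family"] is
# authoritative. The family string and api_key are illustrative placeholders.
from autogen_ext.models.openai import OpenAIChatCompletionClient

client = OpenAIChatCompletionClient(
    model="mistral-large-latest",
    api_key="test",
    model_info={
        "vision": False,
        "function_calling": True,
        "json_output": True,
        "structured_output": False,
        "family": "mistral",  # assumed family label; not inferred from the name
    },
)
```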
```diff
@@ -1049,6 +1055,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
             self._create_args["model"],
             add_name_prefixes=self._add_name_prefixes,
             tools=tools,
+            model_family=self._model_info["family"],
         )
 
     def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
```
```diff
@@ -1,7 +1,7 @@
 from collections import defaultdict
 from typing import Any, Callable, Dict, List
 
-from autogen_core.models import LLMMessage
+from autogen_core.models import LLMMessage, ModelFamily
 
 from .types import (
     TransformerFunc,
```
```diff
@@ -87,13 +87,14 @@ def _find_model_family(api: str, model: str) -> str:
     Finds the best matching model family for the given model.
     Search via prefix matching (e.g. "gpt-4o" → "gpt-4o-1.0").
     """
-    for family in MESSAGE_TRANSFORMERS[api].keys():
-        if model.startswith(family):
-            return family
-    return "default"
+    family = ModelFamily.UNKNOWN
+    for _family in MESSAGE_TRANSFORMERS[api].keys():
+        if model.startswith(_family):
+            family = _family
+    return family
 
 
-def get_transformer(api: str, model_family: str) -> TransformerMap:
+def get_transformer(api: str, model: str, model_family: str) -> TransformerMap:
     """
     Returns the registered transformer map for the given model family.
 
@@ -107,9 +108,11 @@ def get_transformer(api: str, model_family: str) -> TransformerMap:
     Keeping this as a function (instead of direct dict access) improves long-term flexibility.
     """
 
-    model = _find_model_family(api, model_family)
+    if model_family == ModelFamily.UNKNOWN:
+        # fallback to finding the best matching model family
+        model_family = _find_model_family(api, model)
 
-    transformer = MESSAGE_TRANSFORMERS.get(api, {}).get(model, {})
+    transformer = MESSAGE_TRANSFORMERS.get(api, {}).get(model_family, {})
 
     if not transformer:
         raise ValueError(f"No transformer found for model family '{model_family}'")
```
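One subtlety in the rewritten fallback: the old loop returned on the first prefix match, while the new loop keeps overwriting `family`, so the last registered match wins, and a miss now yields `ModelFamily.UNKNOWN` rather than `"default"`. A self-contained sketch (the registered families here are hypothetical):

```python
# Hypothetical registration order; only illustrates last-match-wins semantics.
registered_families = ["gpt-4", "gpt-4o"]


def find_model_family(model: str) -> str:
    family = "unknown"
    for candidate in registered_families:
        if model.startswith(candidate):
            family = candidate  # no early return: a later match overwrites
    return family


print(find_model_family("gpt-4o-mini"))           # "gpt-4o", not "gpt-4"
print(find_model_family("mistral-large-latest"))  # "unknown" -> caller supplies it
```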
```diff
@@ -29,6 +29,7 @@ from autogen_ext.models.openai._openai_client import (
     convert_tools,
     to_oai_type,
 )
+from autogen_ext.models.openai._transformation import TransformerMap, get_transformer
 from openai.resources.beta.chat.completions import (  # type: ignore
     AsyncChatCompletionStreamManager as BetaAsyncChatCompletionStreamManager,  # type: ignore
 )
```
```diff
@@ -2367,6 +2368,51 @@ async def test_empty_assistant_content_string_with_some_model(
     assert isinstance(result.content, str)
 
 
+def test_openai_model_registry_find_well() -> None:
+    model = "gpt-4o"
+    client1 = OpenAIChatCompletionClient(model=model, api_key="test")
+    client2 = OpenAIChatCompletionClient(
+        model=model,
+        model_info={
+            "vision": False,
+            "function_calling": False,
+            "json_output": False,
+            "structured_output": False,
+            "family": ModelFamily.UNKNOWN,
+        },
+        api_key="test",
+    )
+
+    def get_registered_transformer(client: OpenAIChatCompletionClient) -> TransformerMap:
+        model_name = client._create_args["model"]  # pyright: ignore[reportPrivateUsage]
+        model_family = client.model_info["family"]
+        return get_transformer("openai", model_name, model_family)
+
+    assert get_registered_transformer(client1) == get_registered_transformer(client2)
+
+
+def test_openai_model_registry_find_wrong() -> None:
+    with pytest.raises(ValueError, match="No transformer found for model family"):
+        get_transformer("openai", "gpt-7", "foobar")
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model",
+    [
+        "gpt-4o-mini",
+    ],
+)
+async def test_openai_model_unknown_message_type(model: str, openai_client: OpenAIChatCompletionClient) -> None:
+    class WrongMessage:
+        content = "foo"
+        source = "bar"
+
+    messages: List[WrongMessage] = [WrongMessage()]
+    with pytest.raises(ValueError, match="Unknown message type"):
+        await openai_client.create(messages=messages)  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model",
```