This commit is contained in:
abhinav-aegis 2025-04-12 22:44:38 +10:00 committed by GitHub
commit 723abfcf54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 1306 additions and 4 deletions

View File

@ -5,14 +5,15 @@ class and includes specific fields relevant to the type of message being sent.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Literal, Mapping, TypeVar
from typing import Any, Dict, Generic, List, Literal, Mapping, TypeVar, Optional, Type
from autogen_core import FunctionCall, Image
from autogen_core.memory import MemoryContent
from autogen_core.models import FunctionExecutionResult, LLMMessage, RequestUsage, UserMessage
from pydantic import BaseModel, Field, computed_field
from typing_extensions import Annotated, Self
from autogen_core import Component, ComponentBase
from autogen_core.utils import schema_to_pydantic_model
class BaseMessage(BaseModel, ABC):
"""Abstract base class for all message types in AgentChat.
@ -175,21 +176,45 @@ class StructuredMessage(BaseChatMessage, Generic[StructuredContentType]):
print(message.to_text()) # {"text": "Hello", "number": 42}
.. code-block:: python
from pydantic import BaseModel
from autogen_agentchat.messages import StructuredMessage
class MyMessageContent(BaseModel):
text: str
number: int
message = StructuredMessage[MyMessageContent](
content=MyMessageContent(text="Hello", number=42),
source="agent",
format_string="Hello, {text} {number}!",
)
print(message.to_text()) # Hello, agent 42!
"""
content: StructuredContentType
"""The content of the message. Must be a subclass of
`Pydantic BaseModel <https://docs.pydantic.dev/latest/concepts/models/>`_."""
format_string: Optional[str] = None
@computed_field
def type(self) -> str:
return self.__class__.__name__
def to_text(self) -> str:
return self.content.model_dump_json(indent=2)
if self.format_string is not None:
return self.format_string.format(**self.content.model_dump())
else:
return self.content.model_dump_json()
def to_model_text(self) -> str:
return self.content.model_dump_json()
if self.format_string is not None:
return self.format_string.format(**self.content.model_dump())
else:
return self.content.model_dump_json()
def to_model_message(self) -> UserMessage:
return UserMessage(
@ -197,6 +222,97 @@ class StructuredMessage(BaseChatMessage, Generic[StructuredContentType]):
source=self.source,
)
class StructureMessageConfig(BaseModel):
"""The declarative configuration for the structured input."""
json_schema: dict
format_string: Optional[str] = None
content_model_name: str
class StructuredMessageComponent(ComponentBase[StructureMessageConfig], Component[StructureMessageConfig]):
"""
A component that creates structured chat messages from Pydantic models or JSON schemas.
This component helps you generate strongly-typed chat messages with content defined using a Pydantic model.
It can be used in declarative workflows where message structure must be validated, formatted, and serialized.
You can initialize the component directly using a `BaseModel` subclass, or dynamically from a configuration
object (e.g., loaded from disk or a database).
### Example 1: Create from a Pydantic Model
.. code-block:: python
from pydantic import BaseModel
from autogen_agentchat.messages import StructuredMessageComponent
class TestContent(BaseModel):
field1: str
field2: int
format_string = "This is a string {field1} and this is an int {field2}"
sm_component = StructuredMessageComponent(input_model=TestContent, format_string=format_string)
message = sm_component.StructuredMessage(
source="test_agent",
content=TestContent(field1="Hello", field2=42),
format_string=format_string
)
print(message.to_model_text()) # Output: This is a string Hello and this is an int 42
config = sm_component._to_config()
s_m_dyn = StructuredMessageComponent._from_config(config)
message = s_m_dyn.StructuredMessage(source="test_agent", content=s_m_dyn.ContentModel(field1="dyn agent", field2=43), format_string=s_m_dyn.format_string)
print(type(message)) # StructuredMessage[GeneratedModel]
print(message.to_model_text()) # Output: This is a string dyn agent and this is an int 43
Attributes:
component_config_schema (StructureMessageConfig): Defines the configuration structure for this component.
component_provider_override (str): Path used to reference this component in external tooling.
component_type (str): Identifier used for categorization (e.g., "structured_message").
Raises:
ValueError: If neither `json_schema` nor `input_model` is provided.
Args:
json_schema (Optional[str]): JSON schema to dynamically create a Pydantic model.
input_model (Optional[Type[BaseModel]]): A subclass of `BaseModel` that defines the expected message structure.
format_string (Optional[str]): Optional string to render content into a human-readable format.
content_model_name (Optional[str]): Optional name for the generated Pydantic model.
"""
component_config_schema = StructureMessageConfig
component_provider_override = "autogen_agentchat.messages.StructuredMessageComponent"
component_type = "structured_message"
def __init__(self, json_schema: Optional[str]=None, input_model: Optional[Type[BaseModel]] = None, format_string: Optional[str] = None, content_model_name: Optional[str] = None) -> None:
self.format_string = format_string
if not json_schema and not input_model:
raise ValueError("Either `input_json_schema` or `input_model` must be provided.")
if input_model:
self.ContentModel = input_model
else:
self.ContentModel = schema_to_pydantic_model(json_schema, model_name=content_model_name or "GeneratedContentModel")
self.StructuredMessage = StructuredMessage[self.ContentModel]
def _to_config(self) -> StructureMessageConfig:
return StructureMessageConfig(
json_schema=self.ContentModel.model_json_schema(),
format_string=self.format_string,
content_model_name=self.ContentModel.__name__
)
@classmethod
def _from_config(cls, config: StructureMessageConfig) -> "StructuredMessageComponent":
return cls(
json_schema=config.json_schema,
format_string=config.format_string,
content_model_name=config.content_model_name
)
class TextMessage(BaseTextChatMessage):
"""A text message with string-only content."""

View File

@ -11,6 +11,7 @@ from autogen_agentchat.messages import (
MultiModalMessage,
StopMessage,
StructuredMessage,
StructuredMessageComponent,
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
@ -51,6 +52,24 @@ def test_structured_message() -> None:
assert dumped_message["content"]["field2"] == 42
assert dumped_message["type"] == "StructuredMessage[TestContent]"
def test_structured_message_component() -> None:
# Create a structured message with the test contentformat_string="this is a string {field1} and this is an int {field2}"
format_string="this is a string {field1} and this is an int {field2}"
s_m = StructuredMessageComponent(input_model=TestContent, format_string=format_string)
config = s_m._to_config()
s_m_dyn = StructuredMessageComponent._from_config(config)
message = s_m_dyn.StructuredMessage(source="test_agent", content=s_m_dyn.ContentModel(field1="test", field2=42), format_string=s_m_dyn.format_string)
assert isinstance(message.content, s_m_dyn.ContentModel)
assert not isinstance(message.content, TestContent)
assert message.content.field1 == "test"
assert message.content.field2 == 42
dumped_message = message.model_dump()
assert dumped_message["source"] == "test_agent"
assert dumped_message["content"]["field1"] == "test"
assert dumped_message["content"]["field2"] == 42
assert message.to_model_text() == format_string.format(field1="test", field2=42)
def test_message_factory() -> None:
factory = MessageFactory()

View File

@ -0,0 +1,3 @@
from ._json_to_pydantic import schema_to_pydantic_model
__all__ = ["schema_to_pydantic_model"]

View File

@ -0,0 +1,508 @@
import datetime
from typing import Annotated, Any, Dict, ForwardRef, List, Literal, Optional, Type, Union
from pydantic import (
UUID1,
UUID3,
UUID4,
UUID5,
AnyUrl,
BaseModel,
EmailStr,
Field,
IPvAnyAddress,
conbytes,
confloat,
conint,
conlist,
constr,
create_model,
)
class SchemaConversionError(Exception):
"""Base class for schema conversion exceptions."""
pass
class ReferenceNotFoundError(SchemaConversionError):
"""Raised when a $ref cannot be resolved."""
pass
class FormatNotSupportedError(SchemaConversionError):
"""Raised when a format is not supported."""
pass
class UnsupportedKeywordError(SchemaConversionError):
"""Raised when an unsupported JSON Schema keyword is encountered."""
pass
TYPE_MAPPING: Dict[str, Any] = {
"string": str,
"integer": int,
"boolean": bool,
"number": float,
"array": List,
"object": dict,
"null": None,
}
FORMAT_MAPPING: Dict[str, Any] = {
"uuid": UUID4,
"uuid1": UUID1,
"uuid2": UUID4,
"uuid3": UUID3,
"uuid4": UUID4,
"uuid5": UUID5,
"email": EmailStr,
"uri": AnyUrl,
"hostname": constr(strict=True),
"ipv4": IPvAnyAddress,
"ipv6": IPvAnyAddress,
"ipv4-network": IPvAnyAddress,
"ipv6-network": IPvAnyAddress,
"date-time": datetime.datetime,
"date": datetime.date,
"time": datetime.time,
"duration": datetime.timedelta,
"int32": conint(strict=True, ge=-(2**31), le=2**31 - 1),
"int64": conint(strict=True, ge=-(2**63), le=2**63 - 1),
"float": confloat(strict=True),
"double": float,
"decimal": float,
"byte": conbytes(strict=True),
"binary": conbytes(strict=True),
"password": str,
"path": str,
}
class _JSONSchemaToPydantic:
def __init__(self):
self._model_cache = {}
def _resolve_ref(self, ref: str, schema: Dict[str, Any]) -> Dict[str, Any]:
ref_key = ref.split("/")[-1]
definitions = schema.get("$defs", {})
if ref_key not in definitions:
raise ReferenceNotFoundError(
f"Reference `{ref}` not found in `$defs`. Available keys: {list(definitions.keys())}"
)
return definitions[ref_key]
def get_ref(self, ref_name: str) -> Any:
if ref_name not in self._model_cache:
raise ReferenceNotFoundError(
f"Reference `{ref_name}` not found in cache. Available: {list(self._model_cache.keys())}"
)
if self._model_cache[ref_name] is None:
return ForwardRef(ref_name)
return self._model_cache[ref_name]
def _process_definitions(self, root_schema: Dict[str, Any]):
if "$defs" in root_schema:
for model_name in root_schema["$defs"]:
if model_name not in self._model_cache:
self._model_cache[model_name] = None
for model_name, model_schema in root_schema["$defs"].items():
if self._model_cache[model_name] is None:
self._model_cache[model_name] = self.json_schema_to_pydantic(model_schema, model_name, root_schema)
def json_schema_to_pydantic(
self, schema: Dict[str, Any], model_name: str = "GeneratedModel", root_schema: Optional[Dict[str, Any]] = None
) -> Type[BaseModel]:
if root_schema is None:
root_schema = schema
self._process_definitions(root_schema)
if "$ref" in schema:
resolved = self._resolve_ref(schema["$ref"], root_schema)
schema = {**resolved, **{k: v for k, v in schema.items() if k != "$ref"}}
if "allOf" in schema:
merged = {"type": "object", "properties": {}, "required": []}
for s in schema["allOf"]:
part = self._resolve_ref(s["$ref"], root_schema) if "$ref" in s else s
merged["properties"].update(part.get("properties", {}))
merged["required"].extend(part.get("required", []))
for k, v in schema.items():
if k not in {"allOf", "properties", "required"}:
merged[k] = v
merged["required"] = list(set(merged["required"]))
schema = merged
return self._json_schema_to_model(schema, model_name, root_schema)
def _resolve_union_types(self, schemas: List[Dict[str, Any]]) -> List[Any]:
types = []
for s in schemas:
if "$ref" in s:
types.append(self.get_ref(s["$ref"].split("/")[-1]))
elif "enum" in s:
types.append(Literal[tuple(s["enum"])] if len(s["enum"]) > 0 else Any)
else:
json_type = s.get("type")
if json_type not in TYPE_MAPPING:
raise UnsupportedKeywordError(f"Unsupported or missing type `{json_type}` in union")
types.append(TYPE_MAPPING[json_type])
return types
def _extract_field_type(self, key: str, value: Dict[str, Any], model_name: str, root_schema: Dict[str, Any]) -> Any:
json_type = value.get("type")
if json_type not in TYPE_MAPPING:
raise UnsupportedKeywordError(
f"Unsupported or missing type `{json_type}` for field `{key}` in `{model_name}`"
)
base_type = TYPE_MAPPING[json_type]
constraints = {}
if json_type == "string":
if "minLength" in value:
constraints["min_length"] = value["minLength"]
if "maxLength" in value:
constraints["max_length"] = value["maxLength"]
if "pattern" in value:
constraints["pattern"] = value["pattern"]
if constraints:
base_type = constr(**constraints)
elif json_type == "integer":
if "minimum" in value:
constraints["ge"] = value["minimum"]
if "maximum" in value:
constraints["le"] = value["maximum"]
if "exclusiveMinimum" in value:
constraints["gt"] = value["exclusiveMinimum"]
if "exclusiveMaximum" in value:
constraints["lt"] = value["exclusiveMaximum"]
if constraints:
base_type = conint(**constraints)
elif json_type == "number":
if "minimum" in value:
constraints["ge"] = value["minimum"]
if "maximum" in value:
constraints["le"] = value["maximum"]
if "exclusiveMinimum" in value:
constraints["gt"] = value["exclusiveMinimum"]
if "exclusiveMaximum" in value:
constraints["lt"] = value["exclusiveMaximum"]
if constraints:
base_type = confloat(**constraints)
elif json_type == "array":
if "minItems" in value:
constraints["min_length"] = value["minItems"]
if "maxItems" in value:
constraints["max_length"] = value["maxItems"]
item_schema = value.get("items", {"type": "string"})
if "$ref" in item_schema:
item_type = self.get_ref(item_schema["$ref"].split("/")[-1])
else:
item_type_name = item_schema.get("type")
if item_type_name not in TYPE_MAPPING:
raise UnsupportedKeywordError(
f"Unsupported or missing item type `{item_type_name}` for array field `{key}` in `{model_name}`"
)
item_type = TYPE_MAPPING[item_type_name]
base_type = conlist(item_type, **constraints) if constraints else List[item_type]
if "format" in value:
format_type = FORMAT_MAPPING.get(value["format"])
if format_type is None:
raise FormatNotSupportedError(f"Unknown format `{value['format']}` for `{key}` in `{model_name}`")
if not isinstance(format_type, type):
return format_type
if not issubclass(format_type, str):
return format_type
return format_type
return base_type
def _json_schema_to_model(
self, schema: Dict[str, Any], model_name: str, root_schema: Dict[str, Any]
) -> Type[BaseModel]:
if "allOf" in schema:
merged = {"type": "object", "properties": {}, "required": []}
for s in schema["allOf"]:
part = self._resolve_ref(s["$ref"], root_schema) if "$ref" in s else s
merged["properties"].update(part.get("properties", {}))
merged["required"].extend(part.get("required", []))
for k, v in schema.items():
if k not in {"allOf", "properties", "required"}:
merged[k] = v
merged["required"] = list(set(merged["required"]))
schema = merged
fields = {}
required_fields = set(schema.get("required", []))
for key, value in schema.get("properties", {}).items():
if "$ref" in value:
ref_name = value["$ref"].split("/")[-1]
field_type = self.get_ref(ref_name)
elif "anyOf" in value:
sub_models = self._resolve_union_types(value["anyOf"])
field_type = Union[tuple(sub_models)]
elif "oneOf" in value:
sub_models = self._resolve_union_types(value["oneOf"])
field_type = Union[tuple(sub_models)]
if "discriminator" in value:
discriminator = value["discriminator"]["propertyName"]
field_type = Annotated[field_type, Field(discriminator=discriminator)]
elif "enum" in value:
field_type = Literal[tuple(value["enum"])]
elif "allOf" in value:
merged = {"type": "object", "properties": {}, "required": []}
for s in value["allOf"]:
part = self._resolve_ref(s["$ref"], root_schema) if "$ref" in s else s
merged["properties"].update(part.get("properties", {}))
merged["required"].extend(part.get("required", []))
for k, v in value.items():
if k not in {"allOf", "properties", "required"}:
merged[k] = v
merged["required"] = list(set(merged["required"]))
field_type = self._json_schema_to_model(merged, f"{model_name}_{key}", root_schema)
elif value.get("type") == "object" and "properties" in value:
field_type = self._json_schema_to_model(value, f"{model_name}_{key}", root_schema)
else:
field_type = self._extract_field_type(key, value, model_name, root_schema)
if field_type is None:
raise UnsupportedKeywordError(f"Unsupported or missing type for field `{key}` in `{model_name}`")
default_value = value.get("default")
is_required = key in required_fields
if not is_required and default_value is None:
field_type = Optional[field_type]
field_args = {
"default": default_value if not is_required else ...,
}
if "title" in value:
field_args["title"] = value["title"]
if "description" in value:
field_args["description"] = value["description"]
fields[key] = (field_type, Field(**field_args))
model = create_model(model_name, **fields)
model.model_rebuild()
return model
def schema_to_pydantic_model(schema: Dict[str, Any], model_name: str = "GeneratedModel") -> Type[BaseModel]:
"""
Convert a JSON Schema dictionary to a fully-typed Pydantic model.
This function handles schema translation and validation logic to produce
a Pydantic model.
**Supported JSON Schema Features**
- **Primitive types**: `string`, `integer`, `number`, `boolean`, `object`, `array`, `null`
- **String formats**:
- `email`, `uri`, `uuid`, `uuid1`, `uuid3`, `uuid4`, `uuid5`
- `hostname`, `ipv4`, `ipv6`, `ipv4-network`, `ipv6-network`
- `date`, `time`, `date-time`, `duration`
- `byte`, `binary`, `password`, `path`
- **String constraints**:
- `minLength`, `maxLength`, `pattern`
- **Numeric constraints**:
- `minimum`, `maximum`, `exclusiveMinimum`, `exclusiveMaximum`
- **Array constraints**:
- `minItems`, `maxItems`, `items`
- **Object schema support**:
- `properties`, `required`, `title`, `description`, `default`
- **Enums**:
- Converted to Python `Literal` type
- **Union types**:
- `anyOf`, `oneOf` supported with optional `discriminator`
- **Inheritance and composition**:
- `allOf` merges multiple schemas into one model
- **$ref and $defs resolution**:
- Supports references to sibling definitions and self-referencing schemas
.. code-block:: python
from json_schema_to_pydantic import schema_to_pydantic_model
# Example 1: Simple user model
schema = {
"title": "User",
"type": "object",
"properties": {
"name": {"type": "string"},
"email": {"type": "string", "format": "email"},
"age": {"type": "integer", "minimum": 0},
},
"required": ["name", "email"],
}
UserModel = schema_to_pydantic_model(schema)
user = UserModel(name="Alice", email="alice@example.com", age=30)
.. code-block:: python
# Example 2: Nested model
schema = {
"title": "BlogPost",
"type": "object",
"properties": {
"title": {"type": "string"},
"tags": {"type": "array", "items": {"type": "string"}},
"author": {
"type": "object",
"properties": {"name": {"type": "string"}, "email": {"type": "string", "format": "email"}},
"required": ["name"],
},
},
"required": ["title", "author"],
}
BlogPost = schema_to_pydantic_model(schema)
.. code-block:: python
# Example 3: allOf merging with $refs
schema = {
"title": "EmployeeWithDepartment",
"allOf": [{"$ref": "#/$defs/Employee"}, {"$ref": "#/$defs/Department"}],
"$defs": {
"Employee": {
"type": "object",
"properties": {"id": {"type": "string"}, "name": {"type": "string"}},
"required": ["id", "name"],
},
"Department": {
"type": "object",
"properties": {"department": {"type": "string"}},
"required": ["department"],
},
},
}
Model = schema_to_pydantic_model(schema)
.. code-block:: python
# Example 4: Self-referencing (recursive) model
schema = {
"title": "Category",
"type": "object",
"properties": {
"name": {"type": "string"},
"subcategories": {"type": "array", "items": {"$ref": "#/$defs/Category"}},
},
"required": ["name"],
"$defs": {
"Category": {
"type": "object",
"properties": {
"name": {"type": "string"},
"subcategories": {"type": "array", "items": {"$ref": "#/$defs/Category"}},
},
"required": ["name"],
}
},
}
Category = schema_to_pydantic_model(schema)
.. code-block:: python
# Example 5: Serializing and deserializing with Pydantic
from uuid import uuid4
from pydantic import BaseModel, EmailStr, Field
from typing import Optional, List, Dict, Any
from autogen_core.utils import schema_to_pydantic_model
class Address(BaseModel):
street: str
city: str
zipcode: str
class User(BaseModel):
id: str
name: str
email: EmailStr
age: int = Field(..., ge=18)
address: Address
class Employee(BaseModel):
id: str
name: str
manager: Optional["Employee"] = None
class Department(BaseModel):
name: str
employees: List[Employee]
class ComplexModel(BaseModel):
user: User
extra_info: Optional[Dict[str, Any]] = None
sub_items: List[Employee]
# Convert ComplexModel to JSON schema
complex_schema = ComplexModel.model_json_schema()
# Rebuild a new Pydantic model from JSON schema
ReconstructedModel = schema_to_pydantic_model(complex_schema, "ComplexModel")
# Instantiate reconstructed model
reconstructed = ReconstructedModel(
user={
"id": str(uuid4()),
"name": "Alice",
"email": "alice@example.com",
"age": 30,
"address": {"street": "123 Main St", "city": "Wonderland", "zipcode": "12345"},
},
sub_items=[{"id": str(uuid4()), "name": "Bob", "manager": {"id": str(uuid4()), "name": "Eve"}}],
)
print(reconstructed.model_dump())
Args:
schema (Dict[str, Any]): A valid JSON Schema dictionary.
model_name (str, optional): The name of the root model. Defaults to "GeneratedModel".
Returns:
Type[BaseModel]: A dynamically generated Pydantic model class.
Raises:
ReferenceNotFoundError: If a `$ref` key references a missing entry.
FormatNotSupportedError: If a `format` keyword is unknown or unsupported.
UnsupportedKeywordError: If the schema contains an unsupported `type`.
See Also:
- :class:`pydantic.BaseModel`
- :func:`pydantic.create_model`
- https://json-schema.org/
"""
...
return _JSONSchemaToPydantic().json_schema_to_pydantic(schema, model_name)

View File

@ -0,0 +1,656 @@
from typing import Any, Dict, List, Literal, Optional
from uuid import UUID, uuid4
import pytest
from autogen_core.utils._json_to_pydantic import (
FormatNotSupportedError,
ReferenceNotFoundError,
UnsupportedKeywordError,
_JSONSchemaToPydantic,
)
from pydantic import BaseModel, EmailStr, Field, ValidationError
# ✅ Define Pydantic models for testing
class Address(BaseModel):
street: str
city: str
zipcode: str
class User(BaseModel):
id: UUID
name: str
email: EmailStr
age: int = Field(..., ge=18) # Minimum age = 18
address: Address
class Employee(BaseModel):
id: UUID
name: str
manager: Optional["Employee"] = None # Recursive self-reference
class Department(BaseModel):
name: str
employees: List[Employee] # Array of objects
class ComplexModel(BaseModel):
user: User
extra_info: Optional[Dict[str, Any]] = None # Optional dictionary
sub_items: List[Employee] # List of Employees
@pytest.fixture
def converter():
"""Fixture to create a fresh instance of JSONSchemaToPydantic for every test."""
return _JSONSchemaToPydantic()
@pytest.fixture
def sample_json_schema():
"""Fixture that returns a JSON schema dynamically using model_json_schema()."""
return User.model_json_schema()
@pytest.fixture
def sample_json_schema_recursive():
"""Fixture that returns a self-referencing JSON schema."""
return Employee.model_json_schema()
@pytest.fixture
def sample_json_schema_nested():
"""Fixture that returns a nested schema with arrays of objects."""
return Department.model_json_schema()
@pytest.fixture
def sample_json_schema_complex():
"""Fixture that returns a complex schema with multiple structures."""
return ComplexModel.model_json_schema()
@pytest.mark.parametrize(
"schema_fixture, model_name, expected_fields",
[
(sample_json_schema, "User", ["id", "name", "email", "age", "address"]),
(sample_json_schema_recursive, "Employee", ["id", "name", "manager"]),
(sample_json_schema_nested, "Department", ["name", "employees"]),
(sample_json_schema_complex, "ComplexModel", ["user", "extra_info", "sub_items"]),
],
)
def test_json_schema_to_pydantic(converter, schema_fixture, model_name, expected_fields, request):
"""Test conversion of JSON Schema to Pydantic model using the class instance."""
schema = request.getfixturevalue(schema_fixture.__name__)
Model = converter.json_schema_to_pydantic(schema, model_name)
for field in expected_fields:
assert field in Model.__annotations__, f"Expected '{field}' missing in {model_name}Model"
# ✅ **Valid Data Tests**
@pytest.mark.parametrize(
"schema_fixture, model_name, valid_data",
[
(
sample_json_schema,
"User",
{
"id": str(uuid4()),
"name": "Alice",
"email": "alice@example.com",
"age": 25,
"address": {"street": "123 Main St", "city": "Metropolis", "zipcode": "12345"},
},
),
(
sample_json_schema_recursive,
"Employee",
{
"id": str(uuid4()),
"name": "Alice",
"manager": {
"id": str(uuid4()),
"name": "Bob",
},
},
),
(
sample_json_schema_nested,
"Department",
{
"name": "Engineering",
"employees": [
{
"id": str(uuid4()),
"name": "Alice",
"manager": {
"id": str(uuid4()),
"name": "Bob",
},
}
],
},
),
(
sample_json_schema_complex,
"ComplexModel",
{
"user": {
"id": str(uuid4()),
"name": "Charlie",
"email": "charlie@example.com",
"age": 30,
"address": {"street": "456 Side St", "city": "Gotham", "zipcode": "67890"},
},
"extra_info": {"hobby": "Chess", "level": "Advanced"},
"sub_items": [
{"id": str(uuid4()), "name": "Eve"},
{"id": str(uuid4()), "name": "David", "manager": {"id": str(uuid4()), "name": "Frank"}},
],
},
),
],
)
def test_valid_data_model(converter, schema_fixture, model_name, valid_data, request):
"""Test that valid data is accepted by the generated model."""
schema = request.getfixturevalue(schema_fixture.__name__)
Model = converter.json_schema_to_pydantic(schema, model_name)
instance = Model(**valid_data)
assert instance
dumped = instance.model_dump(mode="json", exclude_none=True)
assert dumped == valid_data, f"Model output mismatch.\nExpected: {valid_data}\nGot: {dumped}"
# ✅ **Invalid Data Tests**
@pytest.mark.parametrize(
"schema_fixture, model_name, invalid_data",
[
(
sample_json_schema,
"User",
{
"id": "not-a-uuid", # Invalid UUID
"name": "Alice",
"email": "not-an-email", # Invalid email
"age": 17, # Below minimum
"address": {"street": "123 Main St", "city": "Metropolis"},
},
),
(
sample_json_schema_recursive,
"Employee",
{
"id": str(uuid4()),
"name": "Alice",
"manager": {
"id": "not-a-uuid", # Invalid UUID
"name": "Bob",
},
},
),
(
sample_json_schema_nested,
"Department",
{
"name": "Engineering",
"employees": [
{
"id": "not-a-uuid", # Invalid UUID
"name": "Alice",
"manager": {
"id": str(uuid4()),
"name": "Bob",
},
}
],
},
),
(
sample_json_schema_complex,
"ComplexModel",
{
"user": {
"id": str(uuid4()),
"name": "Charlie",
"email": "charlie@example.com",
"age": "thirty", # Invalid: Should be an int
"address": {"street": "456 Side St", "city": "Gotham", "zipcode": "67890"},
},
"extra_info": "should-be-dictionary", # Invalid type
"sub_items": [
{"id": "invalid-uuid", "name": "Eve"}, # Invalid UUID
{"id": str(uuid4()), "name": 123}, # Invalid name type
],
},
),
],
)
def test_invalid_data_model(converter, schema_fixture, model_name, invalid_data, request):
"""Test that invalid data raises ValidationError."""
schema = request.getfixturevalue(schema_fixture.__name__)
Model = converter.json_schema_to_pydantic(schema, model_name)
with pytest.raises(ValidationError):
Model(**invalid_data)
class ListDictModel(BaseModel):
"""Example for `List[Dict[str, Any]]`"""
data: List[Dict[str, Any]]
class DictListModel(BaseModel):
"""Example for `Dict[str, List[Any]]`"""
mapping: Dict[str, List[Any]]
class NestedListModel(BaseModel):
"""Example for `List[List[str]]`"""
matrix: List[List[str]]
@pytest.fixture
def sample_json_schema_list_dict():
"""Fixture for `List[Dict[str, Any]]`"""
return ListDictModel.model_json_schema()
@pytest.fixture
def sample_json_schema_dict_list():
"""Fixture for `Dict[str, List[Any]]`"""
return DictListModel.model_json_schema()
@pytest.fixture
def sample_json_schema_nested_list():
"""Fixture for `List[List[str]]`"""
return NestedListModel.model_json_schema()
@pytest.mark.parametrize(
"schema_fixture, model_name, expected_fields",
[
(sample_json_schema_list_dict, "ListDictModel", ["data"]),
(sample_json_schema_dict_list, "DictListModel", ["mapping"]),
(sample_json_schema_nested_list, "NestedListModel", ["matrix"]),
],
)
def test_json_schema_to_pydantic_nested(converter, schema_fixture, model_name, expected_fields, request):
"""Test conversion of JSON Schema to Pydantic model using the class instance."""
schema = request.getfixturevalue(schema_fixture.__name__)
Model = converter.json_schema_to_pydantic(schema, model_name)
for field in expected_fields:
assert field in Model.__annotations__, f"Expected '{field}' missing in {model_name}Model"
# ✅ **Valid Data Tests**
@pytest.mark.parametrize(
"schema_fixture, model_name, valid_data",
[
(
sample_json_schema_list_dict,
"ListDictModel",
{
"data": [
{"key1": "value1", "key2": 10},
{"another_key": False, "nested": {"subkey": "data"}},
]
},
),
(
sample_json_schema_dict_list,
"DictListModel",
{
"mapping": {
"first": ["a", "b", "c"],
"second": [1, 2, 3, 4],
"third": [True, False, True],
}
},
),
(
sample_json_schema_nested_list,
"NestedListModel",
{"matrix": [["A", "B"], ["C", "D"], ["E", "F"]]},
),
],
)
def test_valid_data_model_nested(converter, schema_fixture, model_name, valid_data, request):
"""Test that valid data is accepted by the generated model."""
schema = request.getfixturevalue(schema_fixture.__name__)
Model = converter.json_schema_to_pydantic(schema, model_name)
instance = Model(**valid_data)
assert instance
for field, value in valid_data.items():
assert (
getattr(instance, field) == value
), f"Mismatch in field `{field}`: expected `{value}`, got `{getattr(instance, field)}`"
# ✅ **Invalid Data Tests**
@pytest.mark.parametrize(
"schema_fixture, model_name, invalid_data",
[
(
sample_json_schema_list_dict,
"ListDictModel",
{
"data": "should-be-a-list", # ❌ Should be a list of dicts
},
),
(
sample_json_schema_dict_list,
"DictListModel",
{
"mapping": [
"should-be-a-dictionary", # ❌ Should be a dict of lists
]
},
),
(
sample_json_schema_nested_list,
"NestedListModel",
{"matrix": [["A", "B"], "C", ["D", "E"]]}, # ❌ "C" is not a list
),
],
)
def test_invalid_data_model_nested(converter, schema_fixture, model_name, invalid_data, request):
"""Test that invalid data raises ValidationError."""
schema = request.getfixturevalue(schema_fixture.__name__)
Model = converter.json_schema_to_pydantic(schema, model_name)
with pytest.raises(ValidationError):
Model(**invalid_data)
def test_reference_not_found(converter):
schema = {"type": "object", "properties": {"manager": {"$ref": "#/$defs/MissingRef"}}}
with pytest.raises(ReferenceNotFoundError):
converter.json_schema_to_pydantic(schema, "MissingRefModel")
def test_format_not_supported(converter):
schema = {"type": "object", "properties": {"custom_field": {"type": "string", "format": "unsupported-format"}}}
with pytest.raises(FormatNotSupportedError):
converter.json_schema_to_pydantic(schema, "UnsupportedFormatModel")
def test_unsupported_keyword(converter):
schema = {"type": "object", "properties": {"broken_field": {"title": "Missing type"}}}
with pytest.raises(UnsupportedKeywordError):
converter.json_schema_to_pydantic(schema, "MissingTypeModel")
def test_enum_field_schema():
schema = {
"type": "object",
"properties": {
"status": {"type": "string", "enum": ["pending", "approved", "rejected"]},
"priority": {"type": "integer", "enum": [1, 2, 3]},
},
"required": ["status"],
}
converter = _JSONSchemaToPydantic()
Model = converter.json_schema_to_pydantic(schema, "Task")
assert Model.model_fields["status"].annotation == Literal["pending", "approved", "rejected"]
assert Model.model_fields["priority"].annotation == Optional[Literal[1, 2, 3]]
instance = Model(status="approved", priority=2)
assert instance.status == "approved"
assert instance.priority == 2
def test_metadata_title_description(converter):
schema = {
"title": "CustomerProfile",
"description": "A profile containing personal and contact info",
"type": "object",
"properties": {
"first_name": {"type": "string", "title": "First Name", "description": "Given name of the user"},
"age": {"type": "integer", "title": "Age", "description": "Age in years"},
"contact": {
"type": "object",
"title": "Contact Information",
"description": "How to reach the user",
"properties": {
"email": {
"type": "string",
"format": "email",
"title": "Email Address",
"description": "Primary email",
}
},
},
},
"required": ["first_name"],
}
Model = converter.json_schema_to_pydantic(schema, "CustomerProfile")
generated_schema = Model.model_json_schema()
assert generated_schema["title"] == "CustomerProfile"
props = generated_schema["properties"]
assert props["first_name"]["title"] == "First Name"
assert props["first_name"]["description"] == "Given name of the user"
assert props["age"]["title"] == "Age"
assert props["age"]["description"] == "Age in years"
contact = props["contact"]
assert contact["title"] == "Contact Information"
assert contact["description"] == "How to reach the user"
# Follow the $ref
ref_key = contact["anyOf"][0]["$ref"].split("/")[-1]
contact_def = generated_schema["$defs"][ref_key]
email = contact_def["properties"]["email"]
assert email["title"] == "Email Address"
assert email["description"] == "Primary email"
def test_oneof_with_discriminator(converter):
schema = {
"title": "PetWrapper",
"type": "object",
"properties": {
"pet": {
"oneOf": [{"$ref": "#/$defs/Cat"}, {"$ref": "#/$defs/Dog"}],
"discriminator": {"propertyName": "pet_type"},
}
},
"required": ["pet"],
"$defs": {
"Cat": {
"type": "object",
"properties": {"pet_type": {"type": "string", "enum": ["cat"]}, "hunting_skill": {"type": "string"}},
"required": ["pet_type", "hunting_skill"],
"title": "Cat",
},
"Dog": {
"type": "object",
"properties": {"pet_type": {"type": "string", "enum": ["dog"]}, "pack_size": {"type": "integer"}},
"required": ["pet_type", "pack_size"],
"title": "Dog",
},
},
}
Model = converter.json_schema_to_pydantic(schema, "PetWrapper")
# Instantiate with a Cat
cat = Model(pet={"pet_type": "cat", "hunting_skill": "expert"})
assert cat.pet.pet_type == "cat"
# Instantiate with a Dog
dog = Model(pet={"pet_type": "dog", "pack_size": 4})
assert dog.pet.pet_type == "dog"
# Check round-trip schema includes discriminator
model_schema = Model.model_json_schema()
assert "discriminator" in model_schema["properties"]["pet"]
assert model_schema["properties"]["pet"]["discriminator"]["propertyName"] == "pet_type"
def test_allof_merging_with_refs(converter):
schema = {
"title": "EmployeeWithDepartment",
"allOf": [{"$ref": "#/$defs/Employee"}, {"$ref": "#/$defs/Department"}],
"$defs": {
"Employee": {
"type": "object",
"properties": {"id": {"type": "string"}, "name": {"type": "string"}},
"required": ["id", "name"],
"title": "Employee",
},
"Department": {
"type": "object",
"properties": {"department": {"type": "string"}},
"required": ["department"],
"title": "Department",
},
},
}
Model = converter.json_schema_to_pydantic(schema, "EmployeeWithDepartment")
instance = Model(id="123", name="Alice", department="Engineering")
assert instance.id == "123"
assert instance.name == "Alice"
assert instance.department == "Engineering"
dumped = instance.model_dump()
assert dumped == {"id": "123", "name": "Alice", "department": "Engineering"}
def test_nested_allof_merging(converter):
schema = {
"title": "ContainerModel",
"type": "object",
"properties": {
"nested": {
"type": "object",
"properties": {
"data": {
"allOf": [
{"$ref": "#/$defs/Base"},
{"type": "object", "properties": {"extra": {"type": "string"}}, "required": ["extra"]},
]
}
},
"required": ["data"],
}
},
"required": ["nested"],
"$defs": {
"Base": {
"type": "object",
"properties": {"base_field": {"type": "string"}},
"required": ["base_field"],
"title": "Base",
}
},
}
Model = converter.json_schema_to_pydantic(schema, "ContainerModel")
instance = Model(nested={"data": {"base_field": "abc", "extra": "xyz"}})
assert instance.nested.data.base_field == "abc"
assert instance.nested.data.extra == "xyz"
@pytest.mark.parametrize(
"schema, field_name, valid_values, invalid_values",
[
# String constraints
(
{
"type": "object",
"properties": {
"username": {"type": "string", "minLength": 3, "maxLength": 10, "pattern": "^[a-zA-Z0-9_]+$"}
},
"required": ["username"],
},
"username",
["user_123", "abc", "Name2023"],
["", "ab", "toolongusername123", "invalid!char"],
),
# Integer constraints
(
{
"type": "object",
"properties": {"age": {"type": "integer", "minimum": 18, "maximum": 99}},
"required": ["age"],
},
"age",
[18, 25, 99],
[17, 100, -1],
),
# Float constraints
(
{
"type": "object",
"properties": {"score": {"type": "number", "minimum": 0.0, "exclusiveMaximum": 1.0}},
"required": ["score"],
},
"score",
[0.0, 0.5, 0.999],
[-0.1, 1.0, 2.5],
),
# Array constraints
(
{
"type": "object",
"properties": {"tags": {"type": "array", "items": {"type": "string"}, "minItems": 1, "maxItems": 3}},
"required": ["tags"],
},
"tags",
[["a"], ["a", "b"], ["x", "y", "z"]],
[[], ["one", "two", "three", "four"]],
),
],
)
def test_field_constraints(schema, field_name, valid_values, invalid_values):
converter = _JSONSchemaToPydantic()
Model = converter.json_schema_to_pydantic(schema, "ConstraintModel")
import json
for value in valid_values:
instance = Model(**{field_name: value})
assert getattr(instance, field_name) == value
for value in invalid_values:
with pytest.raises(ValidationError):
Model(**{field_name: value})
@pytest.mark.parametrize(
"schema",
[
# Top-level field
{"type": "object", "properties": {"weird": {"type": "abc"}}, "required": ["weird"]},
# Inside array items
{"type": "object", "properties": {"items": {"type": "array", "items": {"type": "abc"}}}, "required": ["items"]},
# Inside anyOf
{
"type": "object",
"properties": {"choice": {"anyOf": [{"type": "string"}, {"type": "abc"}]}},
"required": ["choice"],
},
],
)
def test_unknown_type_raises(schema):
converter = _JSONSchemaToPydantic()
with pytest.raises(UnsupportedKeywordError):
converter.json_schema_to_pydantic(schema, "UnknownTypeModel")