# autogen/python/packages/autogen-core/tests/test_serialization.py
# (listing metadata: 203 lines, 6.6 KiB, Python)

from dataclasses import dataclass
from typing import Union
import pytest
from autogen_core import Image
from autogen_core._serialization import (
JSON_DATA_CONTENT_TYPE,
PROTOBUF_DATA_CONTENT_TYPE,
DataclassJsonMessageSerializer,
MessageSerializer,
PydanticJsonMessageSerializer,
SerializationRegistry,
try_get_known_serializers_for_type,
)
from PIL import Image as PILImage
from protos.serialization_test_pb2 import NestingProtoMessage, ProtoMessage
from pydantic import BaseModel
class PydanticMessage(BaseModel):
    """Flat Pydantic payload carrying a single string field."""

    message: str
class NestingPydanticMessage(BaseModel):
    """Pydantic payload that embeds a ``PydanticMessage`` under ``nested``."""

    # Field order matters: the nested-pydantic test pins the exact JSON key order.
    message: str
    nested: PydanticMessage
@dataclass()
class DataclassMessage:
    """Flat dataclass payload carrying a single string field."""

    message: str
@dataclass()
class NestingDataclassMessage:
    """Dataclass whose ``nested`` field is itself a dataclass.

    The test suite expects serializer lookup for this type to raise ``ValueError``.
    """

    message: str
    nested: DataclassMessage
@dataclass()
class NestingPydanticDataclassMessage:
    """Dataclass whose ``nested`` field is a Pydantic model.

    The test suite expects serializer lookup for this type to raise ``ValueError``.
    """

    message: str
    nested: PydanticMessage
def test_pydantic() -> None:
    """Round-trip a flat Pydantic model through the registry as JSON."""
    registry = SerializationRegistry()
    registry.add_serializer(try_get_known_serializers_for_type(PydanticMessage))

    original = PydanticMessage(message="hello")
    type_name = registry.type_name(original)
    payload = registry.serialize(original, type_name=type_name, data_content_type=JSON_DATA_CONTENT_TYPE)

    assert type_name == "PydanticMessage"
    # Compact separators: no space after the colon.
    assert payload == b'{"message":"hello"}'

    restored = registry.deserialize(payload, type_name=type_name, data_content_type=JSON_DATA_CONTENT_TYPE)
    assert restored == original
def test_nested_pydantic() -> None:
    """Round-trip a Pydantic model that nests another Pydantic model, as JSON.

    Also checks the registry type name, matching what ``test_pydantic`` checks
    for the flat model.
    """
    serde = SerializationRegistry()
    serde.add_serializer(try_get_known_serializers_for_type(NestingPydanticMessage))
    message = NestingPydanticMessage(message="hello", nested=PydanticMessage(message="world"))
    name = serde.type_name(message)
    payload = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
    # Pydantic types register under their class name (see test_pydantic).
    assert name == "NestingPydanticMessage"
    # The nested model is serialized inline as a JSON object.
    assert payload == b'{"message":"hello","nested":{"message":"world"}}'
    deserialized = serde.deserialize(payload, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
    assert deserialized == message
def test_dataclass() -> None:
    """Round-trip a flat dataclass through the registry as JSON."""
    registry = SerializationRegistry()
    registry.add_serializer(try_get_known_serializers_for_type(DataclassMessage))

    original = DataclassMessage(message="hello")
    type_name = registry.type_name(original)

    payload = registry.serialize(original, type_name=type_name, data_content_type=JSON_DATA_CONTENT_TYPE)
    # Unlike the Pydantic path, the dataclass JSON keeps a space after the colon.
    assert payload == b'{"message": "hello"}'

    restored = registry.deserialize(payload, type_name=type_name, data_content_type=JSON_DATA_CONTENT_TYPE)
    assert restored == original
def test_nesting_dataclass_dataclass() -> None:
    """Serializer lookup/registration for a dataclass nesting a dataclass raises ``ValueError``."""
    registry = SerializationRegistry()
    with pytest.raises(ValueError):
        # Both calls stay inside the context: either one may be the one that raises.
        found = try_get_known_serializers_for_type(NestingDataclassMessage)
        registry.add_serializer(found)
def test_proto() -> None:
    """Round-trip a flat protobuf message using the protobuf content type."""
    registry = SerializationRegistry()
    registry.add_serializer(try_get_known_serializers_for_type(ProtoMessage))

    original = ProtoMessage(message="hello")
    type_name = registry.type_name(original)
    encoded = registry.serialize(original, type_name=type_name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
    assert type_name == "agents.ProtoMessage"

    restored = registry.deserialize(encoded, type_name=type_name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
    assert restored.message == original.message
def test_nested_proto() -> None:
    """Round-trip a protobuf message that nests another protobuf message.

    Also checks the registry type name, as ``test_proto`` does for the flat message.
    """
    serde = SerializationRegistry()
    serde.add_serializer(try_get_known_serializers_for_type(NestingProtoMessage))
    message = NestingProtoMessage(message="hello", nested=ProtoMessage(message="world"))
    name = serde.type_name(message)
    # Same value test_type_name_for_protos expects for a NestingProtoMessage instance.
    assert name == "agents.NestingProtoMessage"
    data = serde.serialize(message, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
    deserialized = serde.deserialize(data, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
    assert deserialized.message == message.message
    assert deserialized.nested.message == message.nested.message
@dataclass()
class DataclassNestedUnionSyntaxOldMessage:
    """Dataclass with a union field spelled with ``typing.Union``.

    The test suite expects ``DataclassJsonMessageSerializer`` to reject it with ``ValueError``.
    """

    # Deliberately the pre-PEP 604 spelling; do not modernize.
    message: Union[str, int]
@dataclass
class DataclassNestedUnionSyntaxNewMessage:
message: str | int
@pytest.mark.parametrize("cls", [DataclassNestedUnionSyntaxOldMessage, DataclassNestedUnionSyntaxNewMessage])
def test_nesting_union_old_syntax_dataclass(
    cls: type[DataclassNestedUnionSyntaxOldMessage | DataclassNestedUnionSyntaxNewMessage],
) -> None:
    """Union-typed dataclass fields are rejected by ``DataclassJsonMessageSerializer``.

    Despite "old_syntax" in the name, this is parametrized over both the
    ``typing.Union`` spelling and the PEP 604 ``X | Y`` spelling.
    """
    with pytest.raises(ValueError):
        DataclassJsonMessageSerializer(cls)
def test_nesting_dataclass_pydantic() -> None:
    """Serializer lookup/registration for a dataclass nesting a Pydantic model raises ``ValueError``."""
    registry = SerializationRegistry()
    with pytest.raises(ValueError):
        # Both calls stay inside the context: either one may be the one that raises.
        found = try_get_known_serializers_for_type(NestingPydanticDataclassMessage)
        registry.add_serializer(found)
def test_invalid_type() -> None:
    """``str`` is not a supported message type.

    The lookup may either raise ``ValueError("Unsupported type <class 'str'>")``
    or return no serializers. The previous try/except had no failure path when
    nothing was raised, so it could never fail. It now also asserts that no
    serializer is produced for ``str``.
    """
    serde = SerializationRegistry()
    try:
        serializers = try_get_known_serializers_for_type(str)
        # AssertionError is not a ValueError, so the except clause below cannot swallow it.
        assert not serializers, f"unexpected serializers for str: {serializers!r}"
        serde.add_serializer(serializers)
    except ValueError as e:
        assert str(e) == "Unsupported type <class 'str'>"
def test_custom_type() -> None:
    """A user-defined ``MessageSerializer`` can be registered and used for a round-trip."""

    class CustomStringTypeSerializer(MessageSerializer[str]):
        """Serializes a ``str`` by wrapping it in double quotes (UTF-8 encoded)."""

        @property
        def data_content_type(self) -> str:
            return "str"

        @property
        def type_name(self) -> str:
            return "custom_str"

        def deserialize(self, payload: bytes) -> str:
            # Strip the first and last character: the quotes added by serialize().
            return payload.decode("utf-8")[1:-1]

        def serialize(self, message: str) -> bytes:
            return f'"{message}"'.encode("utf-8")

    registry = SerializationRegistry()
    registry.add_serializer(CustomStringTypeSerializer())

    original = "hello"
    encoded = registry.serialize(original, type_name="custom_str", data_content_type="str")
    assert encoded == b'"hello"'

    decoded = registry.deserialize(encoded, type_name="custom_str", data_content_type="str")
    assert decoded == original
def test_image_type() -> None:
    """A Pydantic model with an ``Image`` field survives a JSON round-trip."""

    class PydanticImageMessage(BaseModel):
        image: Image

    source = Image(PILImage.new("RGB", (100, 100)))
    serializer = PydanticJsonMessageSerializer(PydanticImageMessage)

    payload = serializer.serialize(PydanticImageMessage(image=source))
    # Unwrap model -> autogen Image -> underlying PIL image.
    restored = serializer.deserialize(payload).image.image

    assert restored.size == (100, 100)
    assert restored.mode == "RGB"
    assert restored == source.image
def test_type_name_for_protos() -> None:
    """``type_name`` gives the same name for a protobuf instance and for its class."""
    expectations = [
        (ProtoMessage(), "agents.ProtoMessage"),
        (ProtoMessage, "agents.ProtoMessage"),
        (NestingProtoMessage(), "agents.NestingProtoMessage"),
        (NestingProtoMessage, "agents.NestingProtoMessage"),
    ]
    for subject, expected in expectations:
        # A fresh registry each time: no serializer registration is needed for naming.
        assert SerializationRegistry().type_name(subject) == expected