mirror of https://github.com/microsoft/autogen.git
542 lines
20 KiB
Python
542 lines
20 KiB
Python
import asyncio
|
|
import logging
|
|
import os
|
|
from typing import Any, List
|
|
|
|
import pytest
|
|
from autogen_core import (
|
|
PROTOBUF_DATA_CONTENT_TYPE,
|
|
AgentId,
|
|
AgentType,
|
|
DefaultTopicId,
|
|
MessageContext,
|
|
RoutedAgent,
|
|
Subscription,
|
|
TopicId,
|
|
TypeSubscription,
|
|
default_subscription,
|
|
event,
|
|
try_get_known_serializers_for_type,
|
|
type_subscription,
|
|
)
|
|
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime, GrpcWorkerAgentRuntimeHost
|
|
from autogen_test_utils import (
|
|
CascadingAgent,
|
|
CascadingMessageType,
|
|
ContentMessage,
|
|
LoopbackAgent,
|
|
LoopbackAgentWithDefaultSubscription,
|
|
MessageType,
|
|
NoopAgent,
|
|
)
|
|
|
|
from .protos.serialization_test_pb2 import ProtoMessage
|
|
|
|
|
|
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_agent_types_must_be_unique_single_worker() -> None:
    """Registering the same agent type twice on one worker must raise ``ValueError``.

    A different agent type (``name4``) can still be registered afterwards. The worker and host
    are stopped in ``finally`` so a failing assertion does not leave the gRPC host bound to its port.
    """
    host_address = "localhost:50051"
    host = GrpcWorkerAgentRuntimeHost(address=host_address)
    host.start()

    worker = GrpcWorkerAgentRuntime(host_address=host_address)
    worker.start()

    try:
        await worker.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)

        with pytest.raises(ValueError):
            await worker.register_factory(
                type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent
            )

        await worker.register_factory(type=AgentType("name4"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
    finally:
        # Always release the host, even if stopping the worker fails.
        try:
            await worker.stop()
        finally:
            await host.stop()
|
|
|
|
|
|
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_agent_types_must_be_unique_multiple_workers() -> None:
    """An agent type registered on one worker cannot be registered again from another worker.

    The second registration of ``name1`` (from ``worker2``) must fail with an error matching
    "Agent type name1 already registered". ``worker2`` can still register a fresh type
    (``name4``). All runtimes are stopped in ``finally`` so a failure does not leak the host port.
    """
    host_address = "localhost:50052"
    host = GrpcWorkerAgentRuntimeHost(address=host_address)
    host.start()

    worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
    worker1.start()
    worker2 = GrpcWorkerAgentRuntime(host_address=host_address)
    worker2.start()

    try:
        await worker1.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)

        with pytest.raises(Exception, match="Agent type name1 already registered"):
            await worker2.register_factory(
                type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent
            )

        await worker2.register_factory(type=AgentType("name4"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
    finally:
        # Always release the host, even if stopping a worker fails.
        try:
            await worker1.stop()
            await worker2.stop()
        finally:
            await host.stop()
|
|
|
|
|
|
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_register_receives_publish() -> None:
    """A message published on one worker reaches matching subscribers on every connected worker.

    Each of two workers hosts one LoopbackAgent type subscribed to topic type ``default``.
    After ``worker1`` publishes to ``TopicId("default", "default")``, the instances keyed
    ``default`` must have been called once, and the instances keyed ``other`` not at all.
    """
    address = "localhost:50053"
    host = GrpcWorkerAgentRuntimeHost(address=address)
    host.start()

    async def start_worker(agent_type: str) -> GrpcWorkerAgentRuntime:
        # Start a worker hosting a single LoopbackAgent type subscribed to topic type "default".
        runtime = GrpcWorkerAgentRuntime(host_address=address)
        runtime.start()
        runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
        await runtime.register_factory(
            type=AgentType(agent_type), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
        )
        await runtime.add_subscription(TypeSubscription("default", agent_type))
        return runtime

    worker1 = await start_worker("name1")
    worker2 = await start_worker("name2")

    # Publish from worker1, then give the host time to fan the message out.
    await worker1.publish_message(MessageType(), topic_id=TopicId("default", "default"))
    await asyncio.sleep(2)

    # (runtime, agent type, agent key, expected calls): keys matching the topic source ("default")
    # received the message; keys for another source ("other") did not.
    expectations = [
        (worker1, "name1", "default", 1),
        (worker2, "name2", "default", 1),
        (worker1, "name1", "other", 0),
        (worker2, "name2", "other", 0),
    ]
    for runtime, agent_type, key, expected_calls in expectations:
        agent = await runtime.try_get_underlying_agent_instance(AgentId(agent_type, key), LoopbackAgent)
        assert agent.num_calls == expected_calls

    for runtime in (worker1, worker2):
        await runtime.stop()
    await host.stop()
|
|
|
|
|
|
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_register_receives_publish_cascade_single_worker() -> None:
    """Cascading messages among agents hosted on one worker yield the expected per-agent call count.

    The expected count per agent is the sum over rounds ``r`` in ``range(max_rounds)`` of
    ``num_initial_messages * (num_agents - 1) ** r``. The test therefore assumes each round fans
    out to the other ``num_agents - 1`` agents. CascadingAgent itself lives in autogen_test_utils.
    """
    address = "localhost:50054"
    host = GrpcWorkerAgentRuntimeHost(address=address)
    host.start()
    runtime = GrpcWorkerAgentRuntime(host_address=address)
    runtime.start()

    num_agents = 5
    num_initial_messages = 5
    max_rounds = 5
    expected_calls_per_agent = sum(
        num_initial_messages * (num_agents - 1) ** round_idx for round_idx in range(max_rounds)
    )

    agent_types = [f"name{idx}" for idx in range(num_agents)]

    # Register one CascadingAgent type per name.
    for agent_type in agent_types:
        await CascadingAgent.register(runtime, agent_type, lambda: CascadingAgent(max_rounds))

    # Seed the cascade.
    for _ in range(num_initial_messages):
        await runtime.publish_message(CascadingMessageType(round=1), topic_id=DefaultTopicId())

    # Give the cascade time to settle before inspecting the agents.
    await asyncio.sleep(10)

    for agent_type in agent_types:
        agent = await runtime.try_get_underlying_agent_instance(AgentId(agent_type, "default"), CascadingAgent)
        assert agent.num_calls == expected_calls_per_agent

    await runtime.stop()
    await host.stop()
|
|
|
|
|
|
@pytest.mark.grpc
@pytest.mark.skip(reason="Fix flakiness")
@pytest.mark.asyncio
async def test_register_receives_publish_cascade_multiple_workers() -> None:
    """Cascade test with one worker per agent plus a separate publisher worker (currently skipped).

    The expected per-agent count uses the same formula as the single-worker cascade test.
    """
    logging.basicConfig(level=logging.DEBUG)
    address = "localhost:50055"
    host = GrpcWorkerAgentRuntimeHost(address=address)
    host.start()

    # TODO: Increasing num_initial_messages or max_rounds to 2 causes the test to fail.
    num_agents = 2
    num_initial_messages = 1
    max_rounds = 1
    expected_calls_per_agent = sum(
        num_initial_messages * (num_agents - 1) ** round_idx for round_idx in range(max_rounds)
    )

    # One dedicated worker per agent type.
    workers: List[GrpcWorkerAgentRuntime] = []
    for idx in range(num_agents):
        worker = GrpcWorkerAgentRuntime(host_address=address)
        worker.start()
        await CascadingAgent.register(worker, f"name{idx}", lambda: CascadingAgent(max_rounds))
        workers.append(worker)

    # A separate runtime seeds the cascade.
    publisher = GrpcWorkerAgentRuntime(host_address=address)
    publisher.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
    publisher.start()
    for _ in range(num_initial_messages):
        await publisher.publish_message(CascadingMessageType(round=1), topic_id=DefaultTopicId())

    await asyncio.sleep(20)

    for idx, worker in enumerate(workers):
        agent = await worker.try_get_underlying_agent_instance(AgentId(f"name{idx}", "default"), CascadingAgent)
        assert agent.num_calls == expected_calls_per_agent

    for worker in workers:
        await worker.stop()
    await publisher.stop()
    await host.stop()
|
|
|
|
|
|
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_default_subscription() -> None:
    """An agent with a default subscription receives messages published to ``DefaultTopicId()``.

    Only the instance keyed ``default`` should be called; the instance keyed ``other`` must not be.
    """
    address = "localhost:50056"
    host = GrpcWorkerAgentRuntimeHost(address=address)
    host.start()
    worker = GrpcWorkerAgentRuntime(host_address=address)
    worker.start()
    publisher = GrpcWorkerAgentRuntime(host_address=address)
    publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
    publisher.start()

    await LoopbackAgentWithDefaultSubscription.register(worker, "name", lambda: LoopbackAgentWithDefaultSubscription())

    await publisher.publish_message(MessageType(), topic_id=DefaultTopicId())
    await asyncio.sleep(2)

    # Instance keyed by the default topic source gets the message; another source does not.
    for key, expected_calls in (("default", 1), ("other", 0)):
        agent = await worker.try_get_underlying_agent_instance(
            AgentId("name", key), type=LoopbackAgentWithDefaultSubscription
        )
        assert agent.num_calls == expected_calls

    for runtime in (worker, publisher):
        await runtime.stop()
    await host.stop()
|
|
|
|
|
|
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_default_subscription_other_source() -> None:
    """A message published to the default topic type with source ``other`` reaches only the ``other`` instance.

    The agent is registered with a default subscription. Because the message is published with
    ``DefaultTopicId(source="other")``, the instance keyed ``other`` must be called once and the
    instance keyed ``default`` must not be called. Runtimes are stopped in ``finally`` so a failing
    assertion does not leave the host bound to its port.
    """
    host_address = "localhost:50057"
    host = GrpcWorkerAgentRuntimeHost(address=host_address)
    host.start()
    runtime = GrpcWorkerAgentRuntime(host_address=host_address)
    runtime.start()
    publisher = GrpcWorkerAgentRuntime(host_address=host_address)
    publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
    publisher.start()

    try:
        await LoopbackAgentWithDefaultSubscription.register(runtime, "name", lambda: LoopbackAgentWithDefaultSubscription())

        await publisher.publish_message(MessageType(), topic_id=DefaultTopicId(source="other"))

        await asyncio.sleep(2)

        # The agent keyed "default" must NOT have received the message: it was published with source "other".
        long_running_agent = await runtime.try_get_underlying_agent_instance(
            AgentId("name", "default"), type=LoopbackAgentWithDefaultSubscription
        )
        assert long_running_agent.num_calls == 0

        # The agent keyed "other" matches the topic source and should have received the message once.
        other_long_running_agent = await runtime.try_get_underlying_agent_instance(
            AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription
        )
        assert other_long_running_agent.num_calls == 1
    finally:
        # Always release the host, even if stopping a worker fails.
        try:
            await runtime.stop()
            await publisher.stop()
        finally:
            await host.stop()
|
|
|
|
|
|
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_type_subscription() -> None:
    """An agent class decorated with ``@type_subscription("Other")`` receives messages on topic type ``Other``.

    A message published to ``TopicId(type="Other", source="default")`` must reach the instance
    keyed ``default`` and not the instance keyed ``other``.
    """
    address = "localhost:50058"
    host = GrpcWorkerAgentRuntimeHost(address=address)
    host.start()
    worker = GrpcWorkerAgentRuntime(host_address=address)
    worker.start()
    publisher = GrpcWorkerAgentRuntime(host_address=address)
    publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
    publisher.start()

    @type_subscription("Other")
    class LoopbackAgentWithSubscription(LoopbackAgent): ...

    await LoopbackAgentWithSubscription.register(worker, "name", lambda: LoopbackAgentWithSubscription())

    await publisher.publish_message(MessageType(), topic_id=TopicId(type="Other", source="default"))
    await asyncio.sleep(2)

    # Instance keyed by the published topic source gets the message; another source does not.
    for key, expected_calls in (("default", 1), ("other", 0)):
        agent = await worker.try_get_underlying_agent_instance(
            AgentId("name", key), type=LoopbackAgentWithSubscription
        )
        assert agent.num_calls == expected_calls

    for runtime in (worker, publisher):
        await runtime.stop()
    await host.stop()
|
|
|
|
|
|
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_duplicate_subscription() -> None:
    """Registering an agent type that another worker already owns is rejected.

    While ``worker1`` is connected, registering type ``worker1`` on ``worker1_2`` fails with an
    error matching "Agent type worker1 already registered". After ``worker1`` stops, a second
    attempt on ``worker1_2`` is still expected to raise ``ValueError``. Presumably this happens
    because ``worker1_2`` recorded the type locally during its first, rejected attempt; confirm
    against ``GrpcWorkerAgentRuntime.register_factory``.
    """
    host_address = "localhost:50059"
    host = GrpcWorkerAgentRuntimeHost(address=host_address)
    worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
    worker1_2 = GrpcWorkerAgentRuntime(host_address=host_address)
    host.start()
    try:
        worker1.start()
        await NoopAgent.register(worker1, "worker1", lambda: NoopAgent())

        worker1_2.start()

        # Note: This passes because worker1 is still running
        with pytest.raises(Exception, match="Agent type worker1 already registered"):
            await NoopAgent.register(worker1_2, "worker1", lambda: NoopAgent())

        # test_disconnected_agent partly covers this too, since stop() also disconnects the worker.
        # Both are kept for now because the way a disconnect is simulated may change.
        await worker1.stop()

        with pytest.raises(ValueError):
            await NoopAgent.register(worker1_2, "worker1", lambda: NoopAgent())
    finally:
        # Exceptions propagate on their own; the finally only guarantees cleanup,
        # and the host is released even if stopping worker1_2 fails.
        try:
            await worker1_2.stop()
        finally:
            await host.stop()
|
|
|
|
|
|
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_disconnected_agent() -> None:
    """A dropped worker loses its subscriptions, and a new worker can re-register the same agent type.

    Steps:
    1. ``worker1`` registers ``worker1`` (default subscription); the host then holds 2 subscriptions.
    2. Closing ``worker1``'s host connection must clear all subscriptions and recipients.
    3. ``worker1_2`` re-registers the type. That yields 2 fresh subscriptions (new ids) and a
       duplicate-free recipient list containing ``AgentId("worker1", "default")``.

    All runtimes, including the host, are stopped in ``finally``.
    """
    host_address = "localhost:50060"
    host = GrpcWorkerAgentRuntimeHost(address=host_address)
    host.start()
    worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
    worker1_2 = GrpcWorkerAgentRuntime(host_address=host_address)

    # TODO: Implementing `get_current_subscriptions` and `get_subscribed_recipients` requires access
    # to some private properties. This needs to be updated once they are available publicly

    def get_current_subscriptions() -> List[Subscription]:
        return host._servicer._subscription_manager._subscriptions  # type: ignore[reportPrivateUsage]

    async def get_subscribed_recipients() -> List[AgentId]:
        return await host._servicer._subscription_manager.get_subscribed_recipients(DefaultTopicId())  # type: ignore[reportPrivateUsage]

    try:
        worker1.start()
        await LoopbackAgentWithDefaultSubscription.register(
            worker1, "worker1", lambda: LoopbackAgentWithDefaultSubscription()
        )

        subscriptions1 = get_current_subscriptions()
        assert len(subscriptions1) == 2
        recipients1 = await get_subscribed_recipients()
        assert AgentId(type="worker1", key="default") in recipients1

        first_subscription_id = subscriptions1[0].id

        await worker1.publish_message(ContentMessage(content="Hello!"), DefaultTopicId())
        # Simulate a worker disconnect by closing worker1's host connection directly.
        if worker1._host_connection is not None:  # type: ignore[reportPrivateUsage]
            try:
                await worker1._host_connection.close()  # type: ignore[reportPrivateUsage]
            except asyncio.CancelledError:
                pass

        await asyncio.sleep(1)

        subscriptions2 = get_current_subscriptions()
        assert len(subscriptions2) == 0
        recipients2 = await get_subscribed_recipients()
        assert len(recipients2) == 0
        await asyncio.sleep(1)

        worker1_2.start()
        await LoopbackAgentWithDefaultSubscription.register(
            worker1_2, "worker1", lambda: LoopbackAgentWithDefaultSubscription()
        )

        subscriptions3 = get_current_subscriptions()
        assert len(subscriptions3) == 2
        assert first_subscription_id not in [x.id for x in subscriptions3]

        recipients3 = await get_subscribed_recipients()
        # Re-registration must not produce duplicate recipients. This checks recipients3;
        # recipients2 was already asserted empty, so checking it would be vacuous.
        assert len(set(recipients3)) == len(recipients3)
        assert AgentId(type="worker1", key="default") in recipients3
    finally:
        # Stop both workers, and always release the host's port even if a worker stop fails.
        try:
            await worker1.stop()
            await worker1_2.stop()
        finally:
            await host.stop()
|
|
|
|
|
|
@default_subscription
class ProtoReceivingAgent(RoutedAgent):
    """Test agent that records every ``ProtoMessage`` event delivered to it.

    Attributes are documented here rather than inline:
        num_calls: how many ``ProtoMessage`` events have been handled.
        received_messages: the handled messages, in the order they were handled.
    """

    def __init__(self) -> None:
        super().__init__("A loop back agent.")
        self.received_messages: list[Any] = []
        self.num_calls = 0

    @event
    async def on_new_message(self, message: ProtoMessage, ctx: MessageContext) -> None:  # type: ignore
        # Record the message and bump the call counter.
        self.received_messages.append(message)
        self.num_calls += 1
|
|
|
|
|
|
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_proto_payloads() -> None:
    """Protobuf-serialized payloads are delivered between workers intact.

    Both runtimes use ``PROTOBUF_DATA_CONTENT_TYPE``. A ``ProtoMessage`` published to the default
    topic must reach the ``default``-keyed ProtoReceivingAgent with its ``message`` field intact,
    and must not reach the ``other``-keyed instance. Runtimes are stopped in ``finally``.
    """
    # Every test in this module binds its own port. 50057 is already used by
    # test_default_subscription_other_source, and a host leaked by a failure there would make this
    # test fail to bind.
    host_address = "localhost:50062"
    host = GrpcWorkerAgentRuntimeHost(address=host_address)
    host.start()
    receiver_runtime = GrpcWorkerAgentRuntime(
        host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE
    )
    receiver_runtime.start()
    publisher_runtime = GrpcWorkerAgentRuntime(
        host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE
    )
    publisher_runtime.add_message_serializer(try_get_known_serializers_for_type(ProtoMessage))
    publisher_runtime.start()

    try:
        await ProtoReceivingAgent.register(receiver_runtime, "name", ProtoReceivingAgent)

        await publisher_runtime.publish_message(ProtoMessage(message="Hello!"), topic_id=DefaultTopicId())

        await asyncio.sleep(2)

        # The agent keyed by the default topic source should have received the message.
        long_running_agent = await receiver_runtime.try_get_underlying_agent_instance(
            AgentId("name", "default"), type=ProtoReceivingAgent
        )
        assert long_running_agent.num_calls == 1
        assert long_running_agent.received_messages[0].message == "Hello!"

        # The agent keyed by another source should not have received the message.
        other_long_running_agent = await receiver_runtime.try_get_underlying_agent_instance(
            AgentId("name", key="other"), type=ProtoReceivingAgent
        )
        assert other_long_running_agent.num_calls == 0
        assert len(other_long_running_agent.received_messages) == 0
    finally:
        # Always release the host, even if stopping a worker fails.
        try:
            await receiver_runtime.stop()
            await publisher_runtime.stop()
        finally:
            await host.stop()
|
|
|
|
|
|
# TODO add tests for failure to deserialize
|
|
|
|
|
|
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_grpc_max_message_size() -> None:
    """Oversized messages reach only workers configured to accept them.

    The host, ``worker1`` (the publisher) and ``worker3`` set ``grpc.max_send_message_length``
    and ``grpc.max_receive_message_length`` to ``2 * 2**22``. ``worker2`` keeps the gRPC
    defaults; the test treats ``2**22`` bytes as the default limit. A small message must reach
    both ``worker2`` and ``worker3``. A message whose content alone exceeds ``2**22`` bytes must
    reach only ``worker3``.
    """
    default_max_size = 2**22
    new_max_size = default_max_size * 2
    small_message = ContentMessage(content="small message")
    big_message = ContentMessage(content="." * (default_max_size + 1))

    extra_grpc_config = [
        ("grpc.max_send_message_length", new_max_size),
        ("grpc.max_receive_message_length", new_max_size),
    ]
    host_address = "localhost:50061"
    host = GrpcWorkerAgentRuntimeHost(address=host_address, extra_grpc_config=extra_grpc_config)
    worker1 = GrpcWorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config)
    worker2 = GrpcWorkerAgentRuntime(host_address=host_address)
    worker3 = GrpcWorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config)

    try:
        host.start()
        worker1.start()
        worker2.start()
        worker3.start()
        await LoopbackAgentWithDefaultSubscription.register(
            worker1, "worker1", lambda: LoopbackAgentWithDefaultSubscription()
        )
        await LoopbackAgentWithDefaultSubscription.register(
            worker2, "worker2", lambda: LoopbackAgentWithDefaultSubscription()
        )
        await LoopbackAgentWithDefaultSubscription.register(
            worker3, "worker3", lambda: LoopbackAgentWithDefaultSubscription()
        )

        # A small message fits within every limit, so both receivers get it.
        await worker1.publish_message(small_message, DefaultTopicId())
        await asyncio.sleep(1)
        agent_instance_2 = await worker2.try_get_underlying_agent_instance(
            AgentId("worker2", key="default"), type=LoopbackAgent
        )
        agent_instance_3 = await worker3.try_get_underlying_agent_instance(
            AgentId("worker3", key="default"), type=LoopbackAgent
        )
        assert agent_instance_2.num_calls == 1
        assert agent_instance_3.num_calls == 1

        await worker1.publish_message(big_message, DefaultTopicId())
        await asyncio.sleep(2)
        assert agent_instance_2.num_calls == 1  # worker2 keeps the default limit, so it won't receive the big message
        assert agent_instance_3.num_calls == 2  # worker3 raised its limit, so it receives the big message
    finally:
        try:
            await worker1.stop()
            # worker2 is intentionally not stopped: it ends up in a broken state (presumably after
            # the oversized message) and cannot be stopped cleanly.
            await worker3.stop()
        finally:
            # Always release the host, even if stopping a worker fails.
            await host.stop()
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual debugging entry point: request verbose gRPC core logging, then run selected tests.
    os.environ.update({"GRPC_VERBOSITY": "DEBUG", "GRPC_TRACE": "all"})

    for manual_test in (test_disconnected_agent, test_grpc_max_message_size):
        asyncio.run(manual_test())
|