Fix deprecated usages (#4374)

This commit is contained in:
Jack Gerrits 2024-11-26 19:31:23 -05:00 committed by GitHub
parent fe96f7de24
commit 45f16f534b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 46 additions and 113 deletions

View File

@ -29,6 +29,7 @@ include = ["src/**", "tests/*.py"]
[tool.pyright]
extends = "../../pyproject.toml"
include = ["src", "tests"]
reportDeprecated = true
[tool.pytest.ini_options]
minversion = "6.0"

View File

@ -92,7 +92,7 @@ include = ["src/**", "samples/*.py", "docs/**/*.ipynb", "tests/**"]
extends = "../../pyproject.toml"
include = ["src", "tests", "samples"]
exclude = ["src/autogen_core/application/protos", "tests/protos"]
reportDeprecated = false
reportDeprecated = true
[tool.pytest.ini_options]
minversion = "6.0"

View File

@ -33,7 +33,13 @@ from typing import Any, Mapping, Optional
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, CancellationToken, MessageContext
from autogen_core.base.intervention import DefaultInterventionHandler
from autogen_core.components import DefaultSubscription, DefaultTopicId, FunctionCall, RoutedAgent, message_handler
from autogen_core.components import (
DefaultTopicId,
FunctionCall,
RoutedAgent,
message_handler,
type_subscription,
)
from autogen_core.components.model_context import BufferedChatCompletionContext
from autogen_core.components.models import (
AssistantMessage,
@ -81,6 +87,7 @@ class MockPersistence:
state_persister = MockPersistence()
@type_subscription("scheduling_assistant_conversation")
class SlowUserProxyAgent(RoutedAgent):
def __init__(
self,
@ -132,6 +139,7 @@ class ScheduleMeetingTool(BaseTool[ScheduleMeetingInput, ScheduleMeetingOutput])
return ScheduleMeetingOutput()
@type_subscription("scheduling_assistant_conversation")
class SchedulingAssistantAgent(RoutedAgent):
def __init__(
self,
@ -256,16 +264,13 @@ async def main(latest_user_input: Optional[str] = None) -> None | str:
needs_user_input_handler = NeedsUserInputHandler()
runtime = SingleThreadedAgentRuntime(intervention_handlers=[needs_user_input_handler, termination_handler])
await runtime.register(
"User",
lambda: SlowUserProxyAgent("User", "I am a user"),
subscriptions=lambda: [DefaultSubscription("scheduling_assistant_conversation")],
)
await SlowUserProxyAgent.register(runtime, "User", lambda: SlowUserProxyAgent("User", "I am a user"))
initial_schedule_assistant_message = AssistantTextMessage(
content="Hi! How can I help you? I can help schedule meetings", source="User"
)
await runtime.register(
await SchedulingAssistantAgent.register(
runtime,
"SchedulingAssistant",
lambda: SchedulingAssistantAgent(
"SchedulingAssistant",
@ -273,7 +278,6 @@ async def main(latest_user_input: Optional[str] = None) -> None | str:
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
initial_message=initial_schedule_assistant_message,
),
subscriptions=lambda: [DefaultSubscription("scheduling_assistant_conversation")],
)
if latest_user_input is not None:

View File

@ -59,7 +59,7 @@ class NestingLongRunningAgent(RoutedAgent):
async def test_cancellation_with_token() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("long_running", LongRunningAgent)
await LongRunningAgent.register(runtime, "long_running", LongRunningAgent)
agent_id = AgentId("long_running", key="default")
token = CancellationToken()
response = asyncio.create_task(runtime.send_message(MessageType(), recipient=agent_id, cancellation_token=token))
@ -85,8 +85,9 @@ async def test_cancellation_with_token() -> None:
async def test_nested_cancellation_only_outer_called() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("long_running", LongRunningAgent)
await runtime.register(
await LongRunningAgent.register(runtime, "long_running", LongRunningAgent)
await NestingLongRunningAgent.register(
runtime,
"nested",
lambda: NestingLongRunningAgent(AgentId("long_running", key=AgentInstantiationContext.current_agent_id().key)),
)
@ -119,8 +120,9 @@ async def test_nested_cancellation_only_outer_called() -> None:
async def test_nested_cancellation_inner_called() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("long_running", LongRunningAgent)
await runtime.register(
await LongRunningAgent.register(runtime, "long_running", LongRunningAgent)
await NestingLongRunningAgent.register(
runtime,
"nested",
lambda: NestingLongRunningAgent(AgentId("long_running", key=AgentInstantiationContext.current_agent_id().key)),
)

View File

@ -18,7 +18,7 @@ async def test_intervention_count_messages() -> None:
handler = DebugInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler])
await runtime.register("name", LoopbackAgent)
await LoopbackAgent.register(runtime, "name", LoopbackAgent)
loopback = AgentId("name", key="default")
runtime.start()
@ -42,7 +42,7 @@ async def test_intervention_drop_send() -> None:
handler = DropSendInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler])
await runtime.register("name", LoopbackAgent)
await LoopbackAgent.register(runtime, "name", LoopbackAgent)
loopback = AgentId("name", key="default")
runtime.start()
@ -66,7 +66,7 @@ async def test_intervention_drop_response() -> None:
handler = DropResponseInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler])
await runtime.register("name", LoopbackAgent)
await LoopbackAgent.register(runtime, "name", LoopbackAgent)
loopback = AgentId("name", key="default")
runtime.start()
@ -90,7 +90,7 @@ async def test_intervention_raise_exception_on_send() -> None:
handler = ExceptionInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler])
await runtime.register("name", LoopbackAgent)
await LoopbackAgent.register(runtime, "name", LoopbackAgent)
loopback = AgentId("name", key="default")
runtime.start()
@ -117,7 +117,7 @@ async def test_intervention_raise_exception_on_respond() -> None:
handler = ExceptionInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler])
await runtime.register("name", LoopbackAgent)
await LoopbackAgent.register(runtime, "name", LoopbackAgent)
loopback = AgentId("name", key="default")
runtime.start()
with pytest.raises(InterventionException):

View File

@ -37,7 +37,8 @@ class CounterAgent(RoutedAgent):
async def test_routed_agent(caplog: pytest.LogCaptureFixture) -> None:
runtime = SingleThreadedAgentRuntime()
with caplog.at_level(logging.INFO):
await runtime.register("loopback", LoopbackAgent, lambda: [TypeSubscription("default", "loopback")])
await LoopbackAgent.register(runtime, "loopback", LoopbackAgent)
await runtime.add_subscription(TypeSubscription("default", "loopback"))
runtime.start()
await runtime.publish_message(UnhandledMessageType(), topic_id=TopicId("default", "default"))
await runtime.stop_when_idle()
@ -47,7 +48,8 @@ async def test_routed_agent(caplog: pytest.LogCaptureFixture) -> None:
@pytest.mark.asyncio
async def test_message_handler_router() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("counter", CounterAgent, lambda: [TypeSubscription("default", "counter")])
await CounterAgent.register(runtime, "counter", CounterAgent)
await runtime.add_subscription(TypeSubscription("default", "counter"))
agent_id = AgentId(type="counter", key="default")
# Send a broadcast message.
@ -94,7 +96,7 @@ class RoutedAgentMessageCustomMatch(RoutedAgent):
@pytest.mark.asyncio
async def test_routed_agent_message_matching() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("message_match", RoutedAgentMessageCustomMatch)
await RoutedAgentMessageCustomMatch.register(runtime, "message_match", RoutedAgentMessageCustomMatch)
agent_id = AgentId(type="message_match", key="default")
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch)
@ -134,7 +136,8 @@ class EventAgent(RoutedAgent):
@pytest.mark.asyncio
async def test_event() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("counter", EventAgent, lambda: [TypeSubscription("default", "counter")])
await EventAgent.register(runtime, "counter", EventAgent)
await runtime.add_subscription(TypeSubscription("default", "counter"))
agent_id = AgentId(type="counter", key="default")
# Send a broadcast message.
@ -181,7 +184,8 @@ class RPCAgent(RoutedAgent):
@pytest.mark.asyncio
async def test_rpc() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("counter", RPCAgent, lambda: [TypeSubscription("default", "counter")])
await RPCAgent.register(runtime, "counter", RPCAgent)
await runtime.add_subscription(TypeSubscription("default", "counter"))
agent_id = AgentId(type="counter", key="default")
# Send an RPC message.

View File

@ -1,4 +1,3 @@
import asyncio
import logging
import pytest
@ -7,16 +6,10 @@ from autogen_core.base import (
AgentId,
AgentInstantiationContext,
AgentType,
Subscription,
SubscriptionInstantiationContext,
TopicId,
try_get_known_serializers_for_type,
)
from autogen_core.components import (
DefaultTopicId,
TypeSubscription,
type_subscription,
)
from autogen_core.components import DefaultTopicId, TypeSubscription, type_subscription
from opentelemetry.sdk.trace import TracerProvider
from test_utils import (
CascadingAgent,
@ -146,82 +139,9 @@ async def test_register_receives_publish_cascade() -> None:
async def test_register_factory_explicit_name() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("name", LoopbackAgent, lambda: [TypeSubscription("default", "name")])
runtime.start()
agent_id = AgentId("name", key="default")
topic_id = TopicId("default", "default")
await runtime.publish_message(MessageType(), topic_id=topic_id)
await LoopbackAgent.register(runtime, "name", LoopbackAgent)
await runtime.add_subscription(TypeSubscription("default", "name"))
await runtime.stop_when_idle()
# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 1
# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgent
)
assert other_long_running_agent.num_calls == 0
@pytest.mark.asyncio
async def test_register_factory_context_var_name() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register(
"name", LoopbackAgent, lambda: [TypeSubscription("default", SubscriptionInstantiationContext.agent_type().type)]
)
runtime.start()
agent_id = AgentId("name", key="default")
topic_id = TopicId("default", "default")
await runtime.publish_message(MessageType(), topic_id=topic_id)
await runtime.stop_when_idle()
# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 1
# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgent
)
assert other_long_running_agent.num_calls == 0
@pytest.mark.asyncio
async def test_register_factory_async() -> None:
runtime = SingleThreadedAgentRuntime()
async def sub_factory() -> list[Subscription]:
await asyncio.sleep(0.1)
return [TypeSubscription("default", SubscriptionInstantiationContext.agent_type().type)]
await runtime.register("name", LoopbackAgent, sub_factory)
runtime.start()
agent_id = AgentId("name", key="default")
topic_id = TopicId("default", "default")
await runtime.publish_message(MessageType(), topic_id=topic_id)
await runtime.stop_when_idle()
# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 1
# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgent
)
assert other_long_running_agent.num_calls == 0
@pytest.mark.asyncio
async def test_register_factory_direct_list() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("name", LoopbackAgent, [TypeSubscription("default", "name")])
runtime.start()
agent_id = AgentId("name", key="default")
topic_id = TopicId("default", "default")

View File

@ -24,7 +24,7 @@ class StatefulAgent(BaseAgent):
async def test_agent_can_save_state() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("name1", StatefulAgent)
await StatefulAgent.register(runtime, "name1", StatefulAgent)
agent1_id = AgentId("name1", key="default")
agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent)
assert agent1.state == 0
@ -44,7 +44,7 @@ async def test_agent_can_save_state() -> None:
async def test_runtime_can_save_state() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("name1", StatefulAgent)
await StatefulAgent.register(runtime, "name1", StatefulAgent)
agent1_id = AgentId("name1", key="default")
agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent)
assert agent1.state == 0
@ -54,7 +54,7 @@ async def test_runtime_can_save_state() -> None:
runtime_state = await runtime.save_state()
runtime2 = SingleThreadedAgentRuntime()
await runtime2.register("name1", StatefulAgent)
await StatefulAgent.register(runtime2, "name1", StatefulAgent)
agent2_id = AgentId("name1", key="default")
agent2: StatefulAgent = await runtime2.try_get_underlying_agent_instance(agent2_id, type=StatefulAgent)

View File

@ -27,7 +27,7 @@ def test_type_subscription_map() -> None:
async def test_non_default_default_subscription() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("MyAgent", LoopbackAgent)
await LoopbackAgent.register(runtime, "MyAgent", LoopbackAgent, skip_class_subscriptions=True)
runtime.start()
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
await runtime.stop_when_idle()

View File

@ -43,7 +43,8 @@ async def _async_sleep_function(input: str) -> str:
@pytest.mark.asyncio
async def test_tool_agent() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register(
await ToolAgent.register(
runtime,
"tool_agent",
lambda: ToolAgent(
description="Tool agent",
@ -143,7 +144,8 @@ async def test_caller_loop() -> None:
client = MockChatCompletionClient()
tools: List[Tool] = [FunctionTool(_pass_function, name="pass", description="Pass function")]
runtime = SingleThreadedAgentRuntime()
await runtime.register(
await ToolAgent.register(
runtime,
"tool_agent",
lambda: ToolAgent(
description="Tool agent",