mirror of https://github.com/microsoft/autogen.git
Fix deprecated usages (#4374)
This commit is contained in:
parent
fe96f7de24
commit
45f16f534b
|
@ -29,6 +29,7 @@ include = ["src/**", "tests/*.py"]
|
|||
[tool.pyright]
|
||||
extends = "../../pyproject.toml"
|
||||
include = ["src", "tests"]
|
||||
reportDeprecated = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
minversion = "6.0"
|
||||
|
|
|
@ -92,7 +92,7 @@ include = ["src/**", "samples/*.py", "docs/**/*.ipynb", "tests/**"]
|
|||
extends = "../../pyproject.toml"
|
||||
include = ["src", "tests", "samples"]
|
||||
exclude = ["src/autogen_core/application/protos", "tests/protos"]
|
||||
reportDeprecated = false
|
||||
reportDeprecated = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
minversion = "6.0"
|
||||
|
|
|
@ -33,7 +33,13 @@ from typing import Any, Mapping, Optional
|
|||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core.base import AgentId, CancellationToken, MessageContext
|
||||
from autogen_core.base.intervention import DefaultInterventionHandler
|
||||
from autogen_core.components import DefaultSubscription, DefaultTopicId, FunctionCall, RoutedAgent, message_handler
|
||||
from autogen_core.components import (
|
||||
DefaultTopicId,
|
||||
FunctionCall,
|
||||
RoutedAgent,
|
||||
message_handler,
|
||||
type_subscription,
|
||||
)
|
||||
from autogen_core.components.model_context import BufferedChatCompletionContext
|
||||
from autogen_core.components.models import (
|
||||
AssistantMessage,
|
||||
|
@ -81,6 +87,7 @@ class MockPersistence:
|
|||
state_persister = MockPersistence()
|
||||
|
||||
|
||||
@type_subscription("scheduling_assistant_conversation")
|
||||
class SlowUserProxyAgent(RoutedAgent):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -132,6 +139,7 @@ class ScheduleMeetingTool(BaseTool[ScheduleMeetingInput, ScheduleMeetingOutput])
|
|||
return ScheduleMeetingOutput()
|
||||
|
||||
|
||||
@type_subscription("scheduling_assistant_conversation")
|
||||
class SchedulingAssistantAgent(RoutedAgent):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -256,16 +264,13 @@ async def main(latest_user_input: Optional[str] = None) -> None | str:
|
|||
needs_user_input_handler = NeedsUserInputHandler()
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handlers=[needs_user_input_handler, termination_handler])
|
||||
|
||||
await runtime.register(
|
||||
"User",
|
||||
lambda: SlowUserProxyAgent("User", "I am a user"),
|
||||
subscriptions=lambda: [DefaultSubscription("scheduling_assistant_conversation")],
|
||||
)
|
||||
await SlowUserProxyAgent.register(runtime, "User", lambda: SlowUserProxyAgent("User", "I am a user"))
|
||||
|
||||
initial_schedule_assistant_message = AssistantTextMessage(
|
||||
content="Hi! How can I help you? I can help schedule meetings", source="User"
|
||||
)
|
||||
await runtime.register(
|
||||
await SchedulingAssistantAgent.register(
|
||||
runtime,
|
||||
"SchedulingAssistant",
|
||||
lambda: SchedulingAssistantAgent(
|
||||
"SchedulingAssistant",
|
||||
|
@ -273,7 +278,6 @@ async def main(latest_user_input: Optional[str] = None) -> None | str:
|
|||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
|
||||
initial_message=initial_schedule_assistant_message,
|
||||
),
|
||||
subscriptions=lambda: [DefaultSubscription("scheduling_assistant_conversation")],
|
||||
)
|
||||
|
||||
if latest_user_input is not None:
|
||||
|
|
|
@ -59,7 +59,7 @@ class NestingLongRunningAgent(RoutedAgent):
|
|||
async def test_cancellation_with_token() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
await runtime.register("long_running", LongRunningAgent)
|
||||
await LongRunningAgent.register(runtime, "long_running", LongRunningAgent)
|
||||
agent_id = AgentId("long_running", key="default")
|
||||
token = CancellationToken()
|
||||
response = asyncio.create_task(runtime.send_message(MessageType(), recipient=agent_id, cancellation_token=token))
|
||||
|
@ -85,8 +85,9 @@ async def test_cancellation_with_token() -> None:
|
|||
async def test_nested_cancellation_only_outer_called() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
await runtime.register("long_running", LongRunningAgent)
|
||||
await runtime.register(
|
||||
await LongRunningAgent.register(runtime, "long_running", LongRunningAgent)
|
||||
await NestingLongRunningAgent.register(
|
||||
runtime,
|
||||
"nested",
|
||||
lambda: NestingLongRunningAgent(AgentId("long_running", key=AgentInstantiationContext.current_agent_id().key)),
|
||||
)
|
||||
|
@ -119,8 +120,9 @@ async def test_nested_cancellation_only_outer_called() -> None:
|
|||
async def test_nested_cancellation_inner_called() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
await runtime.register("long_running", LongRunningAgent)
|
||||
await runtime.register(
|
||||
await LongRunningAgent.register(runtime, "long_running", LongRunningAgent)
|
||||
await NestingLongRunningAgent.register(
|
||||
runtime,
|
||||
"nested",
|
||||
lambda: NestingLongRunningAgent(AgentId("long_running", key=AgentInstantiationContext.current_agent_id().key)),
|
||||
)
|
||||
|
|
|
@ -18,7 +18,7 @@ async def test_intervention_count_messages() -> None:
|
|||
|
||||
handler = DebugInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler])
|
||||
await runtime.register("name", LoopbackAgent)
|
||||
await LoopbackAgent.register(runtime, "name", LoopbackAgent)
|
||||
loopback = AgentId("name", key="default")
|
||||
runtime.start()
|
||||
|
||||
|
@ -42,7 +42,7 @@ async def test_intervention_drop_send() -> None:
|
|||
handler = DropSendInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler])
|
||||
|
||||
await runtime.register("name", LoopbackAgent)
|
||||
await LoopbackAgent.register(runtime, "name", LoopbackAgent)
|
||||
loopback = AgentId("name", key="default")
|
||||
runtime.start()
|
||||
|
||||
|
@ -66,7 +66,7 @@ async def test_intervention_drop_response() -> None:
|
|||
handler = DropResponseInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler])
|
||||
|
||||
await runtime.register("name", LoopbackAgent)
|
||||
await LoopbackAgent.register(runtime, "name", LoopbackAgent)
|
||||
loopback = AgentId("name", key="default")
|
||||
runtime.start()
|
||||
|
||||
|
@ -90,7 +90,7 @@ async def test_intervention_raise_exception_on_send() -> None:
|
|||
handler = ExceptionInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler])
|
||||
|
||||
await runtime.register("name", LoopbackAgent)
|
||||
await LoopbackAgent.register(runtime, "name", LoopbackAgent)
|
||||
loopback = AgentId("name", key="default")
|
||||
runtime.start()
|
||||
|
||||
|
@ -117,7 +117,7 @@ async def test_intervention_raise_exception_on_respond() -> None:
|
|||
handler = ExceptionInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler])
|
||||
|
||||
await runtime.register("name", LoopbackAgent)
|
||||
await LoopbackAgent.register(runtime, "name", LoopbackAgent)
|
||||
loopback = AgentId("name", key="default")
|
||||
runtime.start()
|
||||
with pytest.raises(InterventionException):
|
||||
|
|
|
@ -37,7 +37,8 @@ class CounterAgent(RoutedAgent):
|
|||
async def test_routed_agent(caplog: pytest.LogCaptureFixture) -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
with caplog.at_level(logging.INFO):
|
||||
await runtime.register("loopback", LoopbackAgent, lambda: [TypeSubscription("default", "loopback")])
|
||||
await LoopbackAgent.register(runtime, "loopback", LoopbackAgent)
|
||||
await runtime.add_subscription(TypeSubscription("default", "loopback"))
|
||||
runtime.start()
|
||||
await runtime.publish_message(UnhandledMessageType(), topic_id=TopicId("default", "default"))
|
||||
await runtime.stop_when_idle()
|
||||
|
@ -47,7 +48,8 @@ async def test_routed_agent(caplog: pytest.LogCaptureFixture) -> None:
|
|||
@pytest.mark.asyncio
|
||||
async def test_message_handler_router() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
await runtime.register("counter", CounterAgent, lambda: [TypeSubscription("default", "counter")])
|
||||
await CounterAgent.register(runtime, "counter", CounterAgent)
|
||||
await runtime.add_subscription(TypeSubscription("default", "counter"))
|
||||
agent_id = AgentId(type="counter", key="default")
|
||||
|
||||
# Send a broadcast message.
|
||||
|
@ -94,7 +96,7 @@ class RoutedAgentMessageCustomMatch(RoutedAgent):
|
|||
@pytest.mark.asyncio
|
||||
async def test_routed_agent_message_matching() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
await runtime.register("message_match", RoutedAgentMessageCustomMatch)
|
||||
await RoutedAgentMessageCustomMatch.register(runtime, "message_match", RoutedAgentMessageCustomMatch)
|
||||
agent_id = AgentId(type="message_match", key="default")
|
||||
|
||||
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch)
|
||||
|
@ -134,7 +136,8 @@ class EventAgent(RoutedAgent):
|
|||
@pytest.mark.asyncio
|
||||
async def test_event() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
await runtime.register("counter", EventAgent, lambda: [TypeSubscription("default", "counter")])
|
||||
await EventAgent.register(runtime, "counter", EventAgent)
|
||||
await runtime.add_subscription(TypeSubscription("default", "counter"))
|
||||
agent_id = AgentId(type="counter", key="default")
|
||||
|
||||
# Send a broadcast message.
|
||||
|
@ -181,7 +184,8 @@ class RPCAgent(RoutedAgent):
|
|||
@pytest.mark.asyncio
|
||||
async def test_rpc() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
await runtime.register("counter", RPCAgent, lambda: [TypeSubscription("default", "counter")])
|
||||
await RPCAgent.register(runtime, "counter", RPCAgent)
|
||||
await runtime.add_subscription(TypeSubscription("default", "counter"))
|
||||
agent_id = AgentId(type="counter", key="default")
|
||||
|
||||
# Send an RPC message.
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
@ -7,16 +6,10 @@ from autogen_core.base import (
|
|||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
AgentType,
|
||||
Subscription,
|
||||
SubscriptionInstantiationContext,
|
||||
TopicId,
|
||||
try_get_known_serializers_for_type,
|
||||
)
|
||||
from autogen_core.components import (
|
||||
DefaultTopicId,
|
||||
TypeSubscription,
|
||||
type_subscription,
|
||||
)
|
||||
from autogen_core.components import DefaultTopicId, TypeSubscription, type_subscription
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from test_utils import (
|
||||
CascadingAgent,
|
||||
|
@ -146,82 +139,9 @@ async def test_register_receives_publish_cascade() -> None:
|
|||
async def test_register_factory_explicit_name() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, lambda: [TypeSubscription("default", "name")])
|
||||
runtime.start()
|
||||
agent_id = AgentId("name", key="default")
|
||||
topic_id = TopicId("default", "default")
|
||||
await runtime.publish_message(MessageType(), topic_id=topic_id)
|
||||
await LoopbackAgent.register(runtime, "name", LoopbackAgent)
|
||||
await runtime.add_subscription(TypeSubscription("default", "name"))
|
||||
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgent
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_factory_context_var_name() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
await runtime.register(
|
||||
"name", LoopbackAgent, lambda: [TypeSubscription("default", SubscriptionInstantiationContext.agent_type().type)]
|
||||
)
|
||||
runtime.start()
|
||||
agent_id = AgentId("name", key="default")
|
||||
topic_id = TopicId("default", "default")
|
||||
await runtime.publish_message(MessageType(), topic_id=topic_id)
|
||||
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgent
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_factory_async() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
async def sub_factory() -> list[Subscription]:
|
||||
await asyncio.sleep(0.1)
|
||||
return [TypeSubscription("default", SubscriptionInstantiationContext.agent_type().type)]
|
||||
|
||||
await runtime.register("name", LoopbackAgent, sub_factory)
|
||||
runtime.start()
|
||||
agent_id = AgentId("name", key="default")
|
||||
topic_id = TopicId("default", "default")
|
||||
await runtime.publish_message(MessageType(), topic_id=topic_id)
|
||||
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgent
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_factory_direct_list() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, [TypeSubscription("default", "name")])
|
||||
runtime.start()
|
||||
agent_id = AgentId("name", key="default")
|
||||
topic_id = TopicId("default", "default")
|
||||
|
|
|
@ -24,7 +24,7 @@ class StatefulAgent(BaseAgent):
|
|||
async def test_agent_can_save_state() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
await runtime.register("name1", StatefulAgent)
|
||||
await StatefulAgent.register(runtime, "name1", StatefulAgent)
|
||||
agent1_id = AgentId("name1", key="default")
|
||||
agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent)
|
||||
assert agent1.state == 0
|
||||
|
@ -44,7 +44,7 @@ async def test_agent_can_save_state() -> None:
|
|||
async def test_runtime_can_save_state() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
await runtime.register("name1", StatefulAgent)
|
||||
await StatefulAgent.register(runtime, "name1", StatefulAgent)
|
||||
agent1_id = AgentId("name1", key="default")
|
||||
agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent)
|
||||
assert agent1.state == 0
|
||||
|
@ -54,7 +54,7 @@ async def test_runtime_can_save_state() -> None:
|
|||
runtime_state = await runtime.save_state()
|
||||
|
||||
runtime2 = SingleThreadedAgentRuntime()
|
||||
await runtime2.register("name1", StatefulAgent)
|
||||
await StatefulAgent.register(runtime2, "name1", StatefulAgent)
|
||||
agent2_id = AgentId("name1", key="default")
|
||||
agent2: StatefulAgent = await runtime2.try_get_underlying_agent_instance(agent2_id, type=StatefulAgent)
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ def test_type_subscription_map() -> None:
|
|||
async def test_non_default_default_subscription() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
await runtime.register("MyAgent", LoopbackAgent)
|
||||
await LoopbackAgent.register(runtime, "MyAgent", LoopbackAgent, skip_class_subscriptions=True)
|
||||
runtime.start()
|
||||
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
|
||||
await runtime.stop_when_idle()
|
||||
|
|
|
@ -43,7 +43,8 @@ async def _async_sleep_function(input: str) -> str:
|
|||
@pytest.mark.asyncio
|
||||
async def test_tool_agent() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
await runtime.register(
|
||||
await ToolAgent.register(
|
||||
runtime,
|
||||
"tool_agent",
|
||||
lambda: ToolAgent(
|
||||
description="Tool agent",
|
||||
|
@ -143,7 +144,8 @@ async def test_caller_loop() -> None:
|
|||
client = MockChatCompletionClient()
|
||||
tools: List[Tool] = [FunctionTool(_pass_function, name="pass", description="Pass function")]
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
await runtime.register(
|
||||
await ToolAgent.register(
|
||||
runtime,
|
||||
"tool_agent",
|
||||
lambda: ToolAgent(
|
||||
description="Tool agent",
|
||||
|
|
Loading…
Reference in New Issue