Register returns AgentType (#382)

This commit is contained in:
Jack Gerrits 2024-08-20 17:38:36 -04:00 committed by GitHub
parent e1a823fb6d
commit 29088d67a4
6 changed files with 23 additions and 6 deletions

View File

@ -12,7 +12,7 @@ from dataclasses import dataclass
from enum import Enum
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
from agnext.core import Subscription, TopicId
from agnext.core import AgentType, Subscription, TopicId
from ..core import (
Agent,
@ -445,10 +445,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self,
type: str,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
) -> None:
) -> AgentType:
if type in self._agent_factories:
raise ValueError(f"Agent with type {type} already exists.")
self._agent_factories[type] = agent_factory
return AgentType(type)
async def _invoke_agent_factory(
self,

View File

@ -28,7 +28,7 @@ import grpc
from grpc.aio import StreamStreamCall
from typing_extensions import Self
from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext, Subscription, TopicId
from agnext.core import MESSAGE_TYPE_REGISTRY, AgentType, MessageContext, Subscription, TopicId
from ..core import Agent, AgentId, AgentInstantiationContext, AgentMetadata, AgentRuntime, CancellationToken
from .protos import AgentId as AgentIdProto
@ -352,7 +352,7 @@ class WorkerAgentRuntime(AgentRuntime):
self,
type: str,
agent_factory: Callable[[], T | Awaitable[T]],
) -> None:
) -> AgentType:
if type in self._agent_factories:
raise ValueError(f"Agent with type {type} already exists.")
self._agent_factories[type] = agent_factory
@ -361,6 +361,7 @@ class WorkerAgentRuntime(AgentRuntime):
message = Message(registerAgentType=RegisterAgentType(type=type))
await self._host_connection.send(message)
logger.info("Sent registerAgentType message for %s", type)
return AgentType(type)
async def _invoke_agent_factory(
self,

View File

@ -9,6 +9,7 @@ from ._agent_metadata import AgentMetadata
from ._agent_props import AgentChildren
from ._agent_proxy import AgentProxy
from ._agent_runtime import AgentRuntime
from ._agent_type import AgentType
from ._base_agent import BaseAgent
from ._cancellation_token import CancellationToken
from ._message_context import MessageContext
@ -33,4 +34,5 @@ __all__ = [
"Subscription",
"MessageContext",
"Serialization",
"AgentType",
]

View File

@ -1,8 +1,13 @@
from typing_extensions import Self
from ._agent_type import AgentType
class AgentId:
def __init__(self, type: str, key: str) -> None:
def __init__(self, type: str | AgentType, key: str) -> None:
if isinstance(type, AgentType):
type = type.type
if type.isidentifier() is False:
raise ValueError(f"Invalid type: {type}")

View File

@ -5,6 +5,7 @@ from typing import Any, Awaitable, Callable, Mapping, Protocol, Type, TypeVar, r
from ._agent import Agent
from ._agent_id import AgentId
from ._agent_metadata import AgentMetadata
from ._agent_type import AgentType
from ._cancellation_token import CancellationToken
from ._subscription import Subscription
from ._topic import TopicId
@ -70,7 +71,7 @@ class AgentRuntime(Protocol):
self,
type: str,
agent_factory: Callable[[], T | Awaitable[T]],
) -> None:
) -> AgentType:
"""Register an agent factory with the runtime associated with a specific type. The type must be unique.
Args:

View File

@ -0,0 +1,7 @@
from dataclasses import dataclass
@dataclass(eq=True, frozen=True)
class AgentType:
type: str
"""String representation of this agent type."""