Make grpc an optional dependency (#4315)

* Make grpc an optional dependency

* add note to the runtime docs

* update version

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Jack Gerrits 2024-11-24 01:36:30 -05:00 committed by GitHub
parent 02ef110e10
commit 01dc56b244
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 46 additions and 15 deletions

View File

@ -21,6 +21,13 @@
"It also advertises the agents which they support to the host service,\n", "It also advertises the agents which they support to the host service,\n",
"so the host service can deliver messages to the correct worker.\n", "so the host service can deliver messages to the correct worker.\n",
"\n", "\n",
"````{note}\n",
"The distributed agent runtime requires extra dependencies, install them using:\n",
"```bash\n",
"pip install autogen-core[grpc]==0.4.0.dev6\n",
"```\n",
"````\n",
"\n",
"We can start a host service using {py:class}`~autogen_core.application.WorkerAgentRuntimeHost`." "We can start a host service using {py:class}`~autogen_core.application.WorkerAgentRuntimeHost`."
] ]
}, },

View File

@ -20,7 +20,6 @@ dependencies = [
"aiohttp", "aiohttp",
"typing-extensions", "typing-extensions",
"pydantic<3.0.0,>=2.0.0", "pydantic<3.0.0,>=2.0.0",
"grpcio~=1.62.0",
"protobuf~=4.25.1", "protobuf~=4.25.1",
"tiktoken", "tiktoken",
"opentelemetry-api~=1.27.0", "opentelemetry-api~=1.27.0",
@ -28,6 +27,11 @@ dependencies = [
"jsonref~=1.1.0", "jsonref~=1.1.0",
] ]
[project.optional-dependencies]
grpc = [
"grpcio~=1.62.0",
]
[tool.uv] [tool.uv]
dev-dependencies = [ dev-dependencies = [
"aiofiles", "aiofiles",

View File

@ -0,0 +1,3 @@
GRPC_IMPORT_ERROR_STR = (
"Distributed runtime features require additional dependencies. Install them with: pip install autogen-core[grpc]"
)

View File

@ -27,16 +27,11 @@ from typing import (
cast, cast,
) )
import grpc
from grpc.aio import StreamStreamCall
from opentelemetry.trace import TracerProvider from opentelemetry.trace import TracerProvider
from typing_extensions import Self, deprecated from typing_extensions import Self, deprecated
from autogen_core.base import JSON_DATA_CONTENT_TYPE
from autogen_core.base._serialization import MessageSerializer, SerializationRegistry
from autogen_core.base._type_helpers import ChannelArgumentType
from ..base import ( from ..base import (
JSON_DATA_CONTENT_TYPE,
Agent, Agent,
AgentId, AgentId,
AgentInstantiationContext, AgentInstantiationContext,
@ -50,11 +45,19 @@ from ..base import (
SubscriptionInstantiationContext, SubscriptionInstantiationContext,
TopicId, TopicId,
) )
from ..base._serialization import MessageSerializer, SerializationRegistry
from ..base._type_helpers import ChannelArgumentType
from ..components import TypeSubscription from ..components import TypeSubscription
from ._helpers import SubscriptionManager, get_impl from ._helpers import SubscriptionManager, get_impl
from ._utils import GRPC_IMPORT_ERROR_STR
from .protos import agent_worker_pb2, agent_worker_pb2_grpc from .protos import agent_worker_pb2, agent_worker_pb2_grpc
from .telemetry import MessageRuntimeTracingConfig, TraceHelper, get_telemetry_grpc_metadata from .telemetry import MessageRuntimeTracingConfig, TraceHelper, get_telemetry_grpc_metadata
try:
import grpc.aio
except ImportError as e:
raise ImportError(GRPC_IMPORT_ERROR_STR) from e
if TYPE_CHECKING: if TYPE_CHECKING:
from .protos.agent_worker_pb2_grpc import AgentRpcAsyncStub from .protos.agent_worker_pb2_grpc import AgentRpcAsyncStub
@ -140,6 +143,8 @@ class HostConnection:
) -> None: ) -> None:
stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel) # type: ignore stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel) # type: ignore
from grpc.aio import StreamStreamCall
# TODO: where do exceptions from reading the iterable go? How do we recover from those? # TODO: where do exceptions from reading the iterable go? How do we recover from those?
recv_stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel( # type: ignore recv_stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel( # type: ignore
QueueAsyncIterable(send_queue) QueueAsyncIterable(send_queue)

View File

@ -3,11 +3,14 @@ import logging
import signal import signal
from typing import Optional, Sequence from typing import Optional, Sequence
import grpc from ..base._type_helpers import ChannelArgumentType
from ._utils import GRPC_IMPORT_ERROR_STR
from autogen_core.base._type_helpers import ChannelArgumentType
from ._worker_runtime_host_servicer import WorkerAgentRuntimeHostServicer from ._worker_runtime_host_servicer import WorkerAgentRuntimeHostServicer
try:
import grpc
except ImportError as e:
raise ImportError(GRPC_IMPORT_ERROR_STR) from e
from .protos import agent_worker_pb2_grpc from .protos import agent_worker_pb2_grpc
logger = logging.getLogger("autogen_core") logger = logging.getLogger("autogen_core")

View File

@ -4,11 +4,16 @@ from _collections_abc import AsyncIterator, Iterator
from asyncio import Future, Task from asyncio import Future, Task
from typing import Any, Dict, Set from typing import Any, Dict, Set
import grpc
from ..base import TopicId from ..base import TopicId
from ..components import TypeSubscription from ..components import TypeSubscription
from ._helpers import SubscriptionManager from ._helpers import SubscriptionManager
from ._utils import GRPC_IMPORT_ERROR_STR
try:
import grpc
except ImportError as e:
raise ImportError(GRPC_IMPORT_ERROR_STR) from e
from .protos import agent_worker_pb2, agent_worker_pb2_grpc from .protos import agent_worker_pb2, agent_worker_pb2_grpc
logger = logging.getLogger("autogen_core") logger = logging.getLogger("autogen_core")

View File

@ -334,7 +334,6 @@ source = { editable = "packages/autogen-core" }
dependencies = [ dependencies = [
{ name = "aiohttp" }, { name = "aiohttp" },
{ name = "asyncio-atexit" }, { name = "asyncio-atexit" },
{ name = "grpcio" },
{ name = "jsonref" }, { name = "jsonref" },
{ name = "openai" }, { name = "openai" },
{ name = "opentelemetry-api" }, { name = "opentelemetry-api" },
@ -345,6 +344,11 @@ dependencies = [
{ name = "typing-extensions" }, { name = "typing-extensions" },
] ]
[package.optional-dependencies]
grpc = [
{ name = "grpcio" },
]
[package.dev-dependencies] [package.dev-dependencies]
dev = [ dev = [
{ name = "aiofiles" }, { name = "aiofiles" },
@ -390,7 +394,7 @@ dev = [
requires-dist = [ requires-dist = [
{ name = "aiohttp" }, { name = "aiohttp" },
{ name = "asyncio-atexit" }, { name = "asyncio-atexit" },
{ name = "grpcio", specifier = "~=1.62.0" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = "~=1.62.0" },
{ name = "jsonref", specifier = "~=1.1.0" }, { name = "jsonref", specifier = "~=1.1.0" },
{ name = "openai", specifier = ">=1.3" }, { name = "openai", specifier = ">=1.3" },
{ name = "opentelemetry-api", specifier = "~=1.27.0" }, { name = "opentelemetry-api", specifier = "~=1.27.0" },