AutoGPT/autogpt_platform/backend/backend/executor/manager.py


import atexit
import logging
import multiprocessing
import os
import signal
import sys
import threading
from concurrent.futures import Future, ProcessPoolExecutor
from contextlib import contextmanager
from multiprocessing.pool import AsyncResult, Pool
from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast
from redis.lock import Lock as RedisLock
if TYPE_CHECKING:
from backend.executor import DatabaseManager
from autogpt_libs.utils.cache import thread_cached
from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis
from backend.data.block import (
Block,
BlockData,
BlockInput,
BlockSchema,
BlockType,
get_block,
)
from backend.data.execution import (
ExecutionQueue,
ExecutionResult,
ExecutionStatus,
GraphExecutionEntry,
NodeExecutionEntry,
merge_execution_input,
parse_execution_output,
)
from backend.data.graph import GraphModel, Link, Node
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.decorator import error_logged, time_measured
from backend.util.logging import configure_logging
from backend.util.process import set_service_name
from backend.util.service import (
AppService,
close_service_client,
expose,
get_service_client,
)
from backend.util.settings import Settings
from backend.util.type import convert
logger = logging.getLogger(__name__)
settings = Settings()
class LogMetadata:
def __init__(
self,
user_id: str,
graph_eid: str,
graph_id: str,
node_eid: str,
node_id: str,
block_name: str,
):
self.metadata = {
"component": "ExecutionManager",
"user_id": user_id,
"graph_eid": graph_eid,
"graph_id": graph_id,
"node_eid": node_eid,
"node_id": node_id,
"block_name": block_name,
}
        self.prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}|geid:{graph_eid}|neid:{node_eid}|{block_name}]"
def info(self, msg: str, **extra):
msg = self._wrap(msg, **extra)
logger.info(msg, extra={"json_fields": {**self.metadata, **extra}})
def warning(self, msg: str, **extra):
msg = self._wrap(msg, **extra)
logger.warning(msg, extra={"json_fields": {**self.metadata, **extra}})
def error(self, msg: str, **extra):
msg = self._wrap(msg, **extra)
logger.error(msg, extra={"json_fields": {**self.metadata, **extra}})
def debug(self, msg: str, **extra):
msg = self._wrap(msg, **extra)
logger.debug(msg, extra={"json_fields": {**self.metadata, **extra}})
def exception(self, msg: str, **extra):
msg = self._wrap(msg, **extra)
logger.exception(msg, extra={"json_fields": {**self.metadata, **extra}})
def _wrap(self, msg: str, **extra):
return f"{self.prefix} {msg} {extra}"
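
# Usage sketch for LogMetadata (illustrative; all IDs below are hypothetical):
#
#   log = LogMetadata(user_id="u1", graph_eid="ge-1", graph_id="g-1",
#                     node_eid="ne-1", node_id="n-1", block_name="ExampleBlock")
#   log.info("Node started", attempt=1)
#
# Each call emits the prefixed human-readable message and attaches the metadata
# plus any **extra kwargs as structured `json_fields` on the log record.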
T = TypeVar("T")
ExecutionStream = Generator[NodeExecutionEntry, None, None]
def execute_node(
db_client: "DatabaseManager",
creds_manager: IntegrationCredentialsManager,
data: NodeExecutionEntry,
execution_stats: dict[str, Any] | None = None,
) -> ExecutionStream:
"""
    Execute a node in the graph. This triggers a block execution on the node,
    persists the execution result, and yields the subsequent nodes to be executed.
Args:
db_client: The client to send execution updates to the server.
creds_manager: The manager to acquire and release credentials.
data: The execution data for executing the current node.
execution_stats: The execution statistics to be updated.
    Yields:
        The subsequent node executions to be enqueued, if any.
"""
user_id = data.user_id
graph_exec_id = data.graph_exec_id
graph_id = data.graph_id
node_exec_id = data.node_exec_id
node_id = data.node_id
def update_execution(status: ExecutionStatus) -> ExecutionResult:
exec_update = db_client.update_execution_status(node_exec_id, status)
db_client.send_execution_update(exec_update)
return exec_update
node = db_client.get_node(node_id)
node_block = get_block(node.block_id)
if not node_block:
logger.error(f"Block {node.block_id} not found.")
return
log_metadata = LogMetadata(
user_id=user_id,
graph_eid=graph_exec_id,
graph_id=graph_id,
node_eid=node_exec_id,
node_id=node_id,
block_name=node_block.name,
)
# Sanity check: validate the execution input.
input_data, error = validate_exec(node, data.data, resolve_input=False)
if input_data is None:
log_metadata.error(f"Skip execution, input validation error: {error}")
db_client.upsert_execution_output(node_exec_id, "error", error)
update_execution(ExecutionStatus.FAILED)
return
    # Re-shape the input data for the agent block.
    # AgentExecutorBlock keeps the node's runtime input_data separate from its input_default.
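    # Illustrative (hypothetical values): runtime input {"q": 1} becomes
    # {**node.input_default, "data": {"q": 1}}.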
if isinstance(node_block, AgentExecutorBlock):
input_data = {**node.input_default, "data": input_data}
# Execute the node
input_data_str = json.dumps(input_data)
input_size = len(input_data_str)
log_metadata.info("Executed node with input", input=input_data_str)
update_execution(ExecutionStatus.RUNNING)
extra_exec_kwargs = {}
    # Fetch credentials at the last minute and acquire a system-wide read-write
    # lock to prevent changes during execution. ⚠️ This means a set of credentials
    # can only be used by one (running) block at a time; simultaneous execution of
    # blocks using the same credentials is not supported.
    creds_locks: list[RedisLock] = []
    input_model = cast(type[BlockSchema], node_block.input_schema)
    for field_name, input_type in input_model.get_credentials_fields().items():
        credentials_meta = input_type(**input_data[field_name])
        credentials, creds_lock = creds_manager.acquire(user_id, credentials_meta.id)
        creds_locks.append(creds_lock)
        extra_exec_kwargs[field_name] = credentials
output_size = 0
end_status = ExecutionStatus.COMPLETED
credit = db_client.get_or_refill_credit(user_id)
if credit < 0:
raise ValueError(f"Insufficient credit: {credit}")
try:
for output_name, output_data in node_block.execute(
input_data, **extra_exec_kwargs
):
output_size += len(json.dumps(output_data))
log_metadata.info("Node produced output", **{output_name: output_data})
db_client.upsert_execution_output(node_exec_id, output_name, output_data)
for execution in _enqueue_next_nodes(
db_client=db_client,
node=node,
output=(output_name, output_data),
user_id=user_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
log_metadata=log_metadata,
):
yield execution
except Exception as e:
end_status = ExecutionStatus.FAILED
error_msg = str(e)
log_metadata.exception(f"Node execution failed with error {error_msg}")
db_client.upsert_execution_output(node_exec_id, "error", error_msg)
for execution in _enqueue_next_nodes(
db_client=db_client,
node=node,
output=("error", error_msg),
user_id=user_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
log_metadata=log_metadata,
):
yield execution
raise e
finally:
        # Ensure all credentials locks are released even if execution fails.
        # (A block can have multiple credentials fields, hence multiple locks.)
        for creds_lock in creds_locks:
            try:
                creds_lock.release()
            except Exception as e:
                log_metadata.error(f"Failed to release credentials lock: {e}")
# Update execution status and spend credits
res = update_execution(end_status)
if end_status == ExecutionStatus.COMPLETED:
s = input_size + output_size
t = (
(res.end_time - res.start_time).total_seconds()
if res.end_time and res.start_time
else 0
)
db_client.spend_credits(user_id, credit, node_block.id, input_data, s, t)
# Update execution stats
if execution_stats is not None:
execution_stats.update(node_block.execution_stats)
execution_stats["input_size"] = input_size
execution_stats["output_size"] = output_size
def _enqueue_next_nodes(
db_client: "DatabaseManager",
node: Node,
output: BlockData,
user_id: str,
graph_exec_id: str,
graph_id: str,
log_metadata: LogMetadata,
) -> list[NodeExecutionEntry]:
def add_enqueued_execution(
node_exec_id: str, node_id: str, data: BlockInput
) -> NodeExecutionEntry:
exec_update = db_client.update_execution_status(
node_exec_id, ExecutionStatus.QUEUED, data
)
db_client.send_execution_update(exec_update)
return NodeExecutionEntry(
user_id=user_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
node_exec_id=node_exec_id,
node_id=node_id,
data=data,
)
def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
enqueued_executions = []
next_output_name = node_link.source_name
next_input_name = node_link.sink_name
next_node_id = node_link.sink_id
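        # parse_execution_output returns None when `output` does not feed this
        # link's source pin, e.g. (illustrative) an "error" output while the
        # link listens on "result"; in that case there is nothing to enqueue.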
next_data = parse_execution_output(output, next_output_name)
if next_data is None:
return enqueued_executions
next_node = db_client.get_node(next_node_id)
        # Multiple nodes can register the same next node, so this must be atomic
        # to prevent the same execution from being enqueued multiple times,
        # or the same input from being consumed multiple times.
with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"):
# Add output data to the earliest incomplete execution, or create a new one.
next_node_exec_id, next_node_input = db_client.upsert_execution_input(
node_id=next_node_id,
graph_exec_id=graph_exec_id,
input_name=next_input_name,
input_data=next_data,
)
# Complete missing static input pins data using the last execution input.
static_link_names = {
link.sink_name
for link in next_node.input_links
if link.is_static and link.sink_name not in next_node_input
}
if static_link_names and (
latest_execution := db_client.get_latest_execution(
next_node_id, graph_exec_id
)
):
for name in static_link_names:
next_node_input[name] = latest_execution.input_data.get(name)
# Validate the input data for the next node.
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
suffix = f"{next_output_name}>{next_input_name}~{next_node_exec_id}:{validation_msg}"
# Incomplete input data, skip queueing the execution.
if not next_node_input:
log_metadata.warning(f"Skipped queueing {suffix}")
return enqueued_executions
# Input is complete, enqueue the execution.
log_metadata.info(f"Enqueued {suffix}")
enqueued_executions.append(
add_enqueued_execution(next_node_exec_id, next_node_id, next_node_input)
)
# Next execution stops here if the link is not static.
if not node_link.is_static:
return enqueued_executions
            # If the link is static, there could be some incomplete executions waiting for it.
            # Load them, fill in the missing input data, and try to re-enqueue them.
for iexec in db_client.get_incomplete_executions(
next_node_id, graph_exec_id
):
idata = iexec.input_data
ineid = iexec.node_exec_id
static_link_names = {
link.sink_name
for link in next_node.input_links
if link.is_static and link.sink_name not in idata
}
for input_name in static_link_names:
idata[input_name] = next_node_input[input_name]
idata, msg = validate_exec(next_node, idata)
suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
if not idata:
log_metadata.info(f"Enqueueing static-link skipped: {suffix}")
continue
log_metadata.info(f"Enqueueing static-link execution {suffix}")
enqueued_executions.append(
add_enqueued_execution(iexec.node_exec_id, next_node_id, idata)
)
return enqueued_executions
return [
execution
for link in node.output_links
for execution in register_next_executions(link)
]
def validate_exec(
node: Node,
data: BlockInput,
resolve_input: bool = True,
) -> tuple[BlockInput | None, str]:
"""
Validate the input data for a node execution.
Args:
node: The node to execute.
data: The input data for the node execution.
resolve_input: Whether to resolve dynamic pins into dict/list/object.
Returns:
A tuple of the validated data and the block name.
If the data is invalid, the first element will be None, and the second element
will be an error message.
If the data is valid, the first element will be the resolved input data, and
the second element will be the block name.
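
    Example (illustrative; assumes a block whose schema requires {"a", "b"}):
        validate_exec(node, {"a": 1})            -> (None, "Input data missing ...")
        validate_exec(node, {"a": 1, "b": 2})    -> ({"a": 1, "b": 2}, "<block name>")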
"""
node_block: Block | None = get_block(node.block_id)
if not node_block:
return None, f"Block for {node.block_id} not found."
if isinstance(node_block, AgentExecutorBlock):
# Validate the execution metadata for the agent executor block.
try:
exec_data = AgentExecutorBlock.Input(**node.input_default)
except Exception as e:
return None, f"Input data doesn't match {node_block.name}: {str(e)}"
        # Validate the input.
input_schema = exec_data.input_schema
required_fields = set(input_schema["required"])
input_default = exec_data.data
else:
# Convert non-matching data types to the expected input schema.
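        # e.g. (illustrative): the string "5" supplied for an int-annotated
        # field is coerced to 5 via convert().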
for name, data_type in node_block.input_schema.__annotations__.items():
if (value := data.get(name)) and (type(value) is not data_type):
data[name] = convert(value, data_type)
        # Validate the input.
input_schema = node_block.input_schema.jsonschema()
required_fields = node_block.input_schema.get_required_fields()
input_default = node.input_default
# Input data (without default values) should contain all required fields.
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
input_fields_from_nodes = {link.sink_name for link in node.input_links}
if not input_fields_from_nodes.issubset(data):
return None, f"{error_prefix} {input_fields_from_nodes - set(data)}"
# Merge input data with default values and resolve dynamic dict/list/object pins.
data = {**input_default, **data}
if resolve_input:
data = merge_execution_input(data)
# Input data post-merge should contain all required fields from the schema.
if not required_fields.issubset(data):
return None, f"{error_prefix} {required_fields - set(data)}"
# Last validation: Validate the input values against the schema.
if error := json.validate_with_jsonschema(schema=input_schema, data=data):
error_message = f"{error_prefix} {error}"
logger.error(error_message)
return None, error_message
return data, node_block.name
class Executor:
"""
This class contains event handlers for the process pool executor events.
The main events are:
on_node_executor_start: Initialize the process that executes the node.
on_node_execution: Execution logic for a node.
on_graph_executor_start: Initialize the process that executes the graph.
on_graph_execution: Execution logic for a graph.
The execution flow:
1. Graph execution request is added to the queue.
2. Graph executor loop picks the request from the queue.
3. Graph executor loop submits the graph execution request to the executor pool.
[on_graph_execution]
    4. Graph executor initializes the node execution queue.
5. Graph executor adds the starting nodes to the node execution queue.
6. Graph executor waits for all nodes to be executed.
[on_node_execution]
7. Node executor picks the node execution request from the queue.
8. Node executor executes the node.
    9. Node executor enqueues the next nodes to be executed into the node execution queue.
"""
@classmethod
def on_node_executor_start(cls):
configure_logging()
set_service_name("NodeExecutor")
redis.connect()
cls.pid = os.getpid()
cls.db_client = get_db_client()
cls.creds_manager = IntegrationCredentialsManager()
# Set up shutdown handlers
cls.shutdown_lock = threading.Lock()
atexit.register(cls.on_node_executor_stop) # handle regular shutdown
signal.signal( # handle termination
signal.SIGTERM, lambda _, __: cls.on_node_executor_sigterm()
)
@classmethod
def on_node_executor_stop(cls):
if not cls.shutdown_lock.acquire(blocking=False):
return # already shutting down
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
cls.creds_manager.release_all_locks()
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
redis.disconnect()
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB manager...")
close_service_client(cls.db_client)
logger.info(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
@classmethod
def on_node_executor_sigterm(cls):
llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ SIGTERM received")
if not cls.shutdown_lock.acquire(blocking=False):
return # already shutting down
llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
cls.creds_manager.release_all_locks()
llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
redis.disconnect()
llprint(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
sys.exit(0)
@classmethod
@error_logged
def on_node_execution(
cls,
q: ExecutionQueue[NodeExecutionEntry],
node_exec: NodeExecutionEntry,
) -> dict[str, Any]:
log_metadata = LogMetadata(
user_id=node_exec.user_id,
graph_eid=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_eid=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_name="-",
)
execution_stats = {}
timing_info, _ = cls._on_node_execution(
q, node_exec, log_metadata, execution_stats
)
execution_stats["walltime"] = timing_info.wall_time
execution_stats["cputime"] = timing_info.cpu_time
cls.db_client.update_node_execution_stats(
node_exec.node_exec_id, execution_stats
)
return execution_stats
@classmethod
@time_measured
def _on_node_execution(
cls,
q: ExecutionQueue[NodeExecutionEntry],
node_exec: NodeExecutionEntry,
log_metadata: LogMetadata,
stats: dict[str, Any] | None = None,
):
try:
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
for execution in execute_node(
cls.db_client, cls.creds_manager, node_exec, stats
):
q.add(execution)
log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
except Exception as e:
log_metadata.exception(
f"Failed node execution {node_exec.node_exec_id}: {e}"
)
@classmethod
def on_graph_executor_start(cls):
configure_logging()
set_service_name("GraphExecutor")
cls.db_client = get_db_client()
cls.pool_size = settings.config.num_node_workers
cls.pid = os.getpid()
cls._init_node_executor_pool()
logger.info(
f"Graph executor {cls.pid} started with {cls.pool_size} node workers"
)
# Set up shutdown handler
atexit.register(cls.on_graph_executor_stop)
@classmethod
def on_graph_executor_stop(cls):
prefix = f"[on_graph_executor_stop {cls.pid}]"
logger.info(f"{prefix} ⏳ Terminating node executor pool...")
cls.executor.terminate()
logger.info(f"{prefix} ⏳ Disconnecting DB manager...")
close_service_client(cls.db_client)
logger.info(f"{prefix} ✅ Finished cleanup")
@classmethod
def _init_node_executor_pool(cls):
cls.executor = Pool(
processes=cls.pool_size,
initializer=cls.on_node_executor_start,
)
@classmethod
@error_logged
def on_graph_execution(
cls, graph_exec: GraphExecutionEntry, cancel: threading.Event
):
log_metadata = LogMetadata(
user_id=graph_exec.user_id,
graph_eid=graph_exec.graph_exec_id,
graph_id=graph_exec.graph_id,
node_id="*",
node_eid="*",
block_name="-",
)
timing_info, (exec_stats, status, error) = cls._on_graph_execution(
graph_exec, cancel, log_metadata
)
exec_stats["walltime"] = timing_info.wall_time
exec_stats["cputime"] = timing_info.cpu_time
exec_stats["error"] = str(error) if error else None
result = cls.db_client.update_graph_execution_stats(
graph_exec_id=graph_exec.graph_exec_id,
status=status,
stats=exec_stats,
)
cls.db_client.send_execution_update(result)
@classmethod
@time_measured
def _on_graph_execution(
cls,
graph_exec: GraphExecutionEntry,
cancel: threading.Event,
log_metadata: LogMetadata,
) -> tuple[dict[str, Any], ExecutionStatus, Exception | None]:
"""
Returns:
dict: The execution statistics of the graph execution.
ExecutionStatus: The final status of the graph execution.
Exception | None: The error that occurred during the execution, if any.
"""
log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}")
exec_stats = {
"nodes_walltime": 0,
"nodes_cputime": 0,
"node_count": 0,
}
error = None
finished = False
def cancel_handler():
while not cancel.is_set():
cancel.wait(1)
if finished:
return
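            # Pool.terminate() kills the node-executor worker processes mid-flight,
            # so a fresh pool is created below for subsequent graph runs.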
cls.executor.terminate()
log_metadata.info(f"Terminated graph execution {graph_exec.graph_exec_id}")
cls._init_node_executor_pool()
cancel_thread = threading.Thread(target=cancel_handler)
cancel_thread.start()
try:
queue = ExecutionQueue[NodeExecutionEntry]()
for node_exec in graph_exec.start_node_execs:
queue.add(node_exec)
running_executions: dict[str, AsyncResult] = {}
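
            # Factory function: binds exec_data (and thus node_id) per dispatch.
            # A bare lambda defined in the loop below would capture the loop
            # variable by reference and report stats against the wrong node.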
def make_exec_callback(exec_data: NodeExecutionEntry):
node_id = exec_data.node_id
def callback(result: object):
running_executions.pop(node_id)
nonlocal exec_stats
if isinstance(result, dict):
exec_stats["node_count"] += 1
exec_stats["nodes_cputime"] += result.get("cputime", 0)
exec_stats["nodes_walltime"] += result.get("walltime", 0)
return callback
while not queue.empty():
if cancel.is_set():
return exec_stats, ExecutionStatus.TERMINATED, error
exec_data = queue.get()
# Avoid parallel execution of the same node.
execution = running_executions.get(exec_data.node_id)
if execution and not execution.ready():
                    # TODO (performance improvement):
                    #   Waiting for the completion of the same node's execution is blocking.
                    #   To improve this we need a separate queue for each node.
                    #   Re-enqueueing the data back to the queue would disrupt the order.
execution.wait()
log_metadata.debug(
f"Dispatching node execution {exec_data.node_exec_id} "
f"for node {exec_data.node_id}",
)
running_executions[exec_data.node_id] = cls.executor.apply_async(
cls.on_node_execution,
(queue, exec_data),
callback=make_exec_callback(exec_data),
)
# Avoid terminating graph execution when some nodes are still running.
while queue.empty() and running_executions:
log_metadata.debug(
f"Queue empty; running nodes: {list(running_executions.keys())}"
)
for node_id, execution in list(running_executions.items()):
if cancel.is_set():
return exec_stats, ExecutionStatus.TERMINATED, error
if not queue.empty():
break # yield to parent loop to execute new queue items
log_metadata.debug(f"Waiting on execution of node {node_id}")
execution.wait(3)
log_metadata.info(f"Finished graph execution {graph_exec.graph_exec_id}")
except Exception as e:
log_metadata.exception(
f"Failed graph execution {graph_exec.graph_exec_id}: {e}"
)
error = e
finally:
if not cancel.is_set():
finished = True
cancel.set()
cancel_thread.join()
return (
exec_stats,
ExecutionStatus.FAILED if error else ExecutionStatus.COMPLETED,
error,
)
class ExecutionManager(AppService):
def __init__(self):
super().__init__()
self.use_redis = True
self.use_supabase = True
self.pool_size = settings.config.num_graph_workers
self.queue = ExecutionQueue[GraphExecutionEntry]()
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
@classmethod
def get_port(cls) -> int:
return settings.config.execution_manager_port
def run_service(self):
from backend.integrations.credentials_store import IntegrationCredentialsStore
self.credentials_store = IntegrationCredentialsStore()
self.executor = ProcessPoolExecutor(
max_workers=self.pool_size,
initializer=Executor.on_graph_executor_start,
)
sync_manager = multiprocessing.Manager()
logger.info(
f"[{self.service_name}] Started with max-{self.pool_size} graph workers"
)
while True:
graph_exec_data = self.queue.get()
graph_exec_id = graph_exec_data.graph_exec_id
logger.debug(
f"[ExecutionManager] Dispatching graph execution {graph_exec_id}"
)
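            # A Manager-proxied Event is used because it can be shared with
            # ProcessPoolExecutor workers; a plain multiprocessing.Event cannot
            # be pickled for submission to an already-running pool.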
cancel_event = sync_manager.Event()
future = self.executor.submit(
Executor.on_graph_execution, graph_exec_data, cancel_event
)
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
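            # Bind graph_exec_id via a default argument; a bare closure would
            # capture the loop variable and could pop the wrong run when the
            # callback fires on a later iteration.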
            future.add_done_callback(
                lambda _, exec_id=graph_exec_id: self.active_graph_runs.pop(exec_id, None)
            )
def cleanup(self):
logger.info(f"[{__class__.__name__}] ⏳ Shutting down graph executor pool...")
self.executor.shutdown(cancel_futures=True)
super().cleanup()
@property
def db_client(self) -> "DatabaseManager":
return get_db_client()
@expose
def add_execution(
self,
graph_id: str,
data: BlockInput,
user_id: str,
graph_version: int,
preset_id: str | None = None,
) -> GraphExecutionEntry:
graph: GraphModel | None = self.db_client.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
if not graph:
raise ValueError(f"Graph #{graph_id} not found.")
graph.validate_graph(for_run=True)
self._validate_node_input_credentials(graph, user_id)
nodes_input = []
for node in graph.starting_nodes:
input_data = {}
block = get_block(node.block_id)
            # Invalid blocks & Note blocks should never be executed.
if not block or block.block_type == BlockType.NOTE:
continue
# Extract request input data, and assign it to the input pin.
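            # e.g. (illustrative): an input node whose input_default names it
            # "query" picks data == {"query": "hello"} up as {"value": "hello"}.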
if block.block_type == BlockType.INPUT:
name = node.input_default.get("name")
if name and name in data:
input_data = {"value": data[name]}
# Extract webhook payload, and assign it to the input pin
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
if (
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
and node.webhook_id
):
if webhook_payload_key not in data:
raise ValueError(
f"Node {block.name} #{node.id} webhook payload is missing"
)
input_data = {"payload": data[webhook_payload_key]}
input_data, error = validate_exec(node, input_data)
if input_data is None:
raise ValueError(error)
else:
nodes_input.append((node.id, input_data))
graph_exec_id, node_execs = self.db_client.create_graph_execution(
graph_id=graph_id,
graph_version=graph.version,
nodes_input=nodes_input,
user_id=user_id,
preset_id=preset_id,
)
starting_node_execs = []
for node_exec in node_execs:
starting_node_execs.append(
NodeExecutionEntry(
user_id=user_id,
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_exec_id=node_exec.node_exec_id,
node_id=node_exec.node_id,
data=node_exec.input_data,
)
)
exec_update = self.db_client.update_execution_status(
node_exec.node_exec_id, ExecutionStatus.QUEUED, node_exec.input_data
)
self.db_client.send_execution_update(exec_update)
graph_exec = GraphExecutionEntry(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=graph_exec_id,
start_node_execs=starting_node_execs,
)
self.queue.add(graph_exec)
return graph_exec
@expose
def cancel_execution(self, graph_exec_id: str) -> None:
"""
Mechanism:
1. Set the cancel event
2. Graph executor's cancel handler thread detects the event, terminates workers,
reinitializes worker pool, and returns.
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
"""
if graph_exec_id not in self.active_graph_runs:
raise Exception(
f"Graph execution #{graph_exec_id} not active/running: "
"possibly already completed/cancelled."
)
future, cancel_event = self.active_graph_runs[graph_exec_id]
if cancel_event.is_set():
return
cancel_event.set()
future.result()
# Update the status of the unfinished node executions
node_execs = self.db_client.get_execution_results(graph_exec_id)
for node_exec in node_execs:
if node_exec.status not in (
ExecutionStatus.COMPLETED,
ExecutionStatus.FAILED,
):
exec_update = self.db_client.update_execution_status(
node_exec.node_exec_id, ExecutionStatus.TERMINATED
)
self.db_client.send_execution_update(exec_update)
def _validate_node_input_credentials(self, graph: GraphModel, user_id: str):
"""Checks all credentials for all nodes of the graph"""
for node in graph.nodes:
block = get_block(node.block_id)
if not block:
raise ValueError(f"Unknown block {node.block_id} for node #{node.id}")
# Find any fields of type CredentialsMetaInput
credentials_fields = cast(
type[BlockSchema], block.input_schema
).get_credentials_fields()
if not credentials_fields:
continue
for field_name, credentials_meta_type in credentials_fields.items():
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
# Fetch the corresponding Credentials and perform sanity checks
credentials = self.credentials_store.get_creds_by_id(
user_id, credentials_meta.id
)
if not credentials:
raise ValueError(
f"Unknown credentials #{credentials_meta.id} "
f"for node #{node.id} input '{field_name}'"
)
if (
credentials.provider != credentials_meta.provider
or credentials.type != credentials_meta.type
):
logger.warning(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch: "
f"{credentials_meta.type}<>{credentials.type};"
f"{credentials_meta.provider}<>{credentials.provider}"
)
raise ValueError(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch"
)
# ------- UTILITIES ------- #
@thread_cached
def get_db_client() -> "DatabaseManager":
from backend.executor import DatabaseManager
return get_service_client(DatabaseManager)
@contextmanager
def synchronized(key: str, timeout: int = 60):
lock: RedisLock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)
try:
lock.acquire()
yield
finally:
if lock.locked():
lock.release()
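
# Usage sketch (illustrative), mirroring the call in _enqueue_next_nodes:
#
#   with synchronized(f"upsert_input-{node_id}-{graph_exec_id}"):
#       ...  # critical section guarded by a cross-process Redis lock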
def llprint(message: str):
"""
Low-level print/log helper function for use in signal handlers.
    Regular log/print statements are not safe to use inside signal handlers.
"""
if logger.getEffectiveLevel() == logging.DEBUG:
os.write(sys.stdout.fileno(), (message + "\n").encode())
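
# Note: os.write() on the raw stdout fd is a single syscall and safe to use in
# signal handlers, unlike the buffered print()/logging machinery, which takes
# locks and is not re-entrant.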