import asyncio
import logging
import uuid
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any, Literal, Optional, Type

import prisma
from prisma.models import (
    AgentGraph,
    AgentGraphExecution,
    AgentNode,
    AgentNodeLink,
    StoreListingVersion,
)
from prisma.types import AgentGraphWhereInput
from pydantic.fields import computed_field

from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock
from backend.util import json

from .block import BlockInput, BlockType, get_block, get_blocks
from .db import BaseDbModel, transaction
from .execution import ExecutionStatus
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
from .integrations import Webhook

logger = logging.getLogger(__name__)


class Link(BaseDbModel):
    source_id: str
    sink_id: str
    source_name: str
    sink_name: str
    is_static: bool = False

    @staticmethod
    def from_db(link: AgentNodeLink):
        return Link(
            id=link.id,
            source_name=link.sourceName,
            source_id=link.agentNodeSourceId,
            sink_name=link.sinkName,
            sink_id=link.agentNodeSinkId,
            is_static=link.isStatic,
        )

    def __hash__(self):
        return hash((self.source_id, self.sink_id, self.source_name, self.sink_name))


class Node(BaseDbModel):
    block_id: str
    input_default: BlockInput = {}  # dict[input_name, default_value]
    metadata: dict[str, Any] = {}
    input_links: list[Link] = []
    output_links: list[Link] = []

    webhook_id: Optional[str] = None


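# Illustrative sketch (not part of the original module): wiring two nodes with a
# single link. The block ids and pin names below are hypothetical placeholders.
#
#   producer = Node(id="n1", block_id="<source-block-id>")
#   consumer = Node(id="n2", block_id="<sink-block-id>")
#   link = Link(
#       id="l1",
#       source_id=producer.id, source_name="result",
#       sink_id=consumer.id, sink_name="input",
#   )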
class NodeModel(Node):
    graph_id: str
    graph_version: int

    webhook: Optional[Webhook] = None

    @staticmethod
    def from_db(node: AgentNode):
        if not node.AgentBlock:
            raise ValueError(f"Invalid node {node.id}, invalid AgentBlock.")
        obj = NodeModel(
            id=node.id,
            block_id=node.AgentBlock.id,
            input_default=json.loads(node.constantInput, target_type=dict[str, Any]),
            metadata=json.loads(node.metadata, target_type=dict[str, Any]),
            graph_id=node.agentGraphId,
            graph_version=node.agentGraphVersion,
            webhook_id=node.webhookId,
            webhook=Webhook.from_db(node.Webhook) if node.Webhook else None,
        )
        obj.input_links = [Link.from_db(link) for link in node.Input or []]
        obj.output_links = [Link.from_db(link) for link in node.Output or []]
        return obj

    def is_triggered_by_event_type(self, event_type: str) -> bool:
        if not (block := get_block(self.block_id)):
            raise ValueError(f"Block #{self.block_id} not found for node #{self.id}")
        if not block.webhook_config:
            raise TypeError("This method can't be used on non-webhook blocks")
        if not block.webhook_config.event_filter_input:
            return True
        event_filter = self.input_default.get(block.webhook_config.event_filter_input)
        if not event_filter:
            raise ValueError(f"Event filter is not configured on node #{self.id}")
        return event_type in [
            block.webhook_config.event_format.format(event=k)
            for k in event_filter
            if event_filter[k] is True
        ]


# Fix 2-way reference Node <-> Webhook
Webhook.model_rebuild()


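# Illustrative sketch (not part of the original module): how a webhook node's event
# filter interacts with `NodeModel.is_triggered_by_event_type`. The block id, the
# filter field name ("events"), and the event format ("pull_request.{event}") are
# hypothetical placeholders, not real registry values.
#
#   node = NodeModel(
#       id="n1", block_id="<webhook-block-id>", graph_id="g1", graph_version=1,
#       input_default={"events": {"opened": True, "closed": False}},
#   )
#   node.is_triggered_by_event_type("pull_request.opened")  # -> True
#   node.is_triggered_by_event_type("pull_request.closed")  # -> False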
class GraphExecution(BaseDbModel):
    execution_id: str
    started_at: datetime
    ended_at: datetime
    duration: float
    total_run_time: float
    status: ExecutionStatus
    graph_id: str
    graph_version: int

    @staticmethod
    def from_db(execution: AgentGraphExecution):
        now = datetime.now(timezone.utc)
        start_time = execution.startedAt or execution.createdAt
        end_time = execution.updatedAt or now
        duration = (end_time - start_time).total_seconds()
        total_run_time = duration

        try:
            stats = json.loads(execution.stats or "{}", target_type=dict[str, Any])
        except ValueError:
            stats = {}

        duration = stats.get("walltime", duration)
        total_run_time = stats.get("nodes_walltime", total_run_time)

        return GraphExecution(
            id=execution.id,
            execution_id=execution.id,
            started_at=start_time,
            ended_at=end_time,
            duration=duration,
            total_run_time=total_run_time,
            status=ExecutionStatus(execution.executionStatus),
            graph_id=execution.agentGraphId,
            graph_version=execution.agentGraphVersion,
        )


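# Illustrative note (not part of the original module): `GraphExecution.from_db` derives
# duration from the DB timestamps and only overrides it when the execution's `stats`
# JSON provides values. A stats payload of the assumed shape
#
#   {"walltime": 12.5, "nodes_walltime": 9.8}
#
# yields duration=12.5 and total_run_time=9.8; other keys are ignored here.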
class Graph(BaseDbModel):
    version: int = 1
    is_active: bool = True
    is_template: bool = False
    name: str
    description: str
    nodes: list[Node] = []
    links: list[Link] = []

    @computed_field
    @property
    def input_schema(self) -> dict[str, Any]:
        return self._generate_schema(
            AgentInputBlock.Input,
            [
                node.input_default
                for node in self.nodes
                if (b := get_block(node.block_id))
                and b.block_type == BlockType.INPUT
                and "name" in node.input_default
            ],
        )

    @computed_field
    @property
    def output_schema(self) -> dict[str, Any]:
        return self._generate_schema(
            AgentOutputBlock.Input,
            [
                node.input_default
                for node in self.nodes
                if (b := get_block(node.block_id))
                and b.block_type == BlockType.OUTPUT
                and "name" in node.input_default
            ],
        )

    @staticmethod
    def _generate_schema(
        type_class: Type[AgentInputBlock.Input] | Type[AgentOutputBlock.Input],
        data: list[dict],
    ) -> dict[str, Any]:
        props = []
        for p in data:
            try:
                props.append(type_class(**p))
            except Exception as e:
                logger.warning(f"Invalid {type_class}: {p}, {e}")

        return {
            "type": "object",
            "properties": {
                p.name: {
                    "secret": p.secret,
                    # Only mark a field as advanced if it also has a default value set.
                    "advanced": p.advanced and p.value is not None,
                    "title": p.title or p.name,
                    **({"description": p.description} if p.description else {}),
                    **({"default": p.value} if p.value is not None else {}),
                }
                for p in props
            },
            "required": [p.name for p in props if p.value is None],
        }


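# Illustrative sketch (not part of the original module): the shape produced by
# `Graph.input_schema`. Assuming a single AgentInputBlock node with
# input_default={"name": "topic", "description": "Search topic"} and the block's
# default flags (secret=False, no default value), the computed field is roughly:
#
#   {
#       "type": "object",
#       "properties": {
#           "topic": {
#               "secret": False,
#               "advanced": False,
#               "title": "topic",
#               "description": "Search topic",
#           }
#       },
#       "required": ["topic"],
#   }
#
# A field with a default value would instead be optional and carry a "default" key.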
class GraphModel(Graph):
    user_id: str
    nodes: list[NodeModel] = []  # type: ignore

    @property
    def starting_nodes(self) -> list[Node]:
        outbound_nodes = {link.sink_id for link in self.links}
        input_nodes = {
            v.id
            for v in self.nodes
            if (b := get_block(v.block_id)) and b.block_type == BlockType.INPUT
        }
        return [
            node
            for node in self.nodes
            if node.id not in outbound_nodes or node.id in input_nodes
        ]

    def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
        """
        Reassigns all IDs in the graph to new UUIDs.
        This method can be used before storing a new graph to the database.
        """

        # Reassign Graph ID
        id_map = {node.id: str(uuid.uuid4()) for node in self.nodes}
        if reassign_graph_id:
            self.id = str(uuid.uuid4())

        # Reassign Node IDs
        for node in self.nodes:
            node.id = id_map[node.id]

        # Reassign Link IDs
        for link in self.links:
            link.source_id = id_map[link.source_id]
            link.sink_id = id_map[link.sink_id]

        # Reassign User IDs for agent blocks
        for node in self.nodes:
            if node.block_id != AgentExecutorBlock().id:
                continue
            node.input_default["user_id"] = user_id
            node.input_default.setdefault("data", {})

        self.validate_graph()

    def validate_graph(self, for_run: bool = False):
        def sanitize(name):
            return name.split("_#_")[0].split("_@_")[0].split("_$_")[0]

        input_links = defaultdict(list)
        for link in self.links:
            input_links[link.sink_id].append(link)

        # Nodes: required fields are filled or connected and dependencies are satisfied
        for node in self.nodes:
            block = get_block(node.block_id)
            if block is None:
                raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")

            provided_inputs = set(
                [sanitize(name) for name in node.input_default]
                + [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
            )
            for name in block.input_schema.get_required_fields():
                if (
                    name not in provided_inputs
                    and not (
                        name == "payload"
                        and block.block_type
                        in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
                    )
                    and (
                        for_run  # Input completeness is only enforced at execution time.
                        or block.block_type == BlockType.INPUT
                        or block.block_type == BlockType.OUTPUT
                        or block.block_type == BlockType.AGENT
                    )
                ):
                    raise ValueError(
                        f"Node {block.name} #{node.id} required input missing: `{name}`"
                    )

            # Get input schema properties and check dependencies
            input_schema = block.input_schema.model_fields
            required_fields = block.input_schema.get_required_fields()

            def has_value(name):
                return (
                    node is not None
                    and name in node.input_default
                    and node.input_default[name] is not None
                    and str(node.input_default[name]).strip() != ""
                ) or (name in input_schema and input_schema[name].default is not None)

            # Validate dependencies between fields
            for field_name, field_info in input_schema.items():
                # Only validate input dependencies at run time, and only for fields with depends_on
                json_schema_extra = field_info.json_schema_extra or {}
                dependencies = json_schema_extra.get("depends_on", [])
                if not for_run or not dependencies:
                    continue

                # Check whether the dependent field has a value in input_default
                field_has_value = has_value(field_name)
                field_is_required = field_name in required_fields

                # Check for missing dependencies when the dependent field is present
                missing_deps = [dep for dep in dependencies if not has_value(dep)]
                if missing_deps and (field_has_value or field_is_required):
                    raise ValueError(
                        f"Node {block.name} #{node.id}: Field `{field_name}` requires [{', '.join(missing_deps)}] to be set"
                    )

        node_map = {v.id: v for v in self.nodes}

        def is_static_output_block(nid: str) -> bool:
            bid = node_map[nid].block_id
            b = get_block(bid)
            return b.static_output if b else False

        # Links: both ends must be connected and the connected pin data types must be compatible.
        for link in self.links:
            source = (link.source_id, link.source_name)
            sink = (link.sink_id, link.sink_name)
            suffix = f"Link {source} <-> {sink}"

            for i, (node_id, name) in enumerate([source, sink]):
                node = node_map.get(node_id)
                if not node:
                    raise ValueError(
                        f"{suffix}, {node_id} is an invalid node id, available nodes: {node_map.keys()}"
                    )

                block = get_block(node.block_id)
                if not block:
                    blocks = {v().id: v().name for v in get_blocks().values()}
                    raise ValueError(
                        f"{suffix}, {node.block_id} is an invalid block id, available blocks: {blocks}"
                    )

                sanitized_name = sanitize(name)
                vals = node.input_default
                if i == 0:
                    fields = (
                        block.output_schema.get_fields()
                        if block.block_type != BlockType.AGENT
                        else vals.get("output_schema", {}).get("properties", {}).keys()
                    )
                else:
                    fields = (
                        block.input_schema.get_fields()
                        if block.block_type != BlockType.AGENT
                        else vals.get("input_schema", {}).get("properties", {}).keys()
                    )
                if sanitized_name not in fields:
                    fields_msg = f"Allowed fields: {fields}"
                    raise ValueError(f"{suffix}, `{name}` invalid, {fields_msg}")

            if is_static_output_block(link.source_id):
                link.is_static = True  # Each value block output should be static.

    @staticmethod
    def from_db(graph: AgentGraph, for_export: bool = False):
        return GraphModel(
            id=graph.id,
            user_id=graph.userId,
            version=graph.version,
            is_active=graph.isActive,
            is_template=graph.isTemplate,
            name=graph.name or "",
            description=graph.description or "",
            nodes=[
                NodeModel.from_db(GraphModel._process_node(node, for_export))
                for node in graph.AgentNodes or []
            ],
            links=list(
                {
                    Link.from_db(link)
                    for node in graph.AgentNodes or []
                    for link in (node.Input or []) + (node.Output or [])
                }
            ),
        )

    @staticmethod
    def _process_node(node: AgentNode, for_export: bool) -> AgentNode:
        if for_export:
            # Remove credentials from node input
            if node.constantInput:
                constant_input = json.loads(
                    node.constantInput, target_type=dict[str, Any]
                )
                constant_input = GraphModel._hide_node_input_credentials(constant_input)
                node.constantInput = json.dumps(constant_input)

            # Remove webhook info
            node.webhookId = None
            node.Webhook = None

        return node

    @staticmethod
    def _hide_node_input_credentials(input_data: dict[str, Any]) -> dict[str, Any]:
        sensitive_keys = ["credentials", "api_key", "password", "token", "secret"]
        result = {}
        for key, value in input_data.items():
            if isinstance(value, dict):
                result[key] = GraphModel._hide_node_input_credentials(value)
            elif isinstance(value, str) and any(
                sensitive_key in key.lower() for sensitive_key in sensitive_keys
            ):
                # Skip this key-value pair in the result
                continue
            else:
                result[key] = value
        return result

    def clean_graph(self):
        blocks = [block() for block in get_blocks().values()]

        input_blocks = [
            node
            for node in self.nodes
            if next(
                (
                    b
                    for b in blocks
                    if b.id == node.block_id and b.block_type == BlockType.INPUT
                ),
                None,
            )
        ]

        for node in self.nodes:
            if any(input_block.id == node.id for input_block in input_blocks):
                node.input_default["value"] = ""


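# Illustrative sketch (not part of the original module): the typical flow for
# persisting a user-submitted graph with the helpers defined below. `incoming_graph`
# is a hypothetical `Graph` built from request data.
#
#   graph = make_graph_model(incoming_graph, user_id)
#   graph.reassign_ids(user_id, reassign_graph_id=True)  # also runs validate_graph()
#   created = await create_graph(graph, user_id)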
# --------------------- CRUD functions --------------------- #


async def get_node(node_id: str) -> NodeModel:
    node = await AgentNode.prisma().find_unique_or_raise(
        where={"id": node_id},
        include=AGENT_NODE_INCLUDE,
    )
    return NodeModel.from_db(node)


async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
    node = await AgentNode.prisma().update(
        where={"id": node_id},
        data=(
            {"Webhook": {"connect": {"id": webhook_id}}}
            if webhook_id
            else {"Webhook": {"disconnect": True}}
        ),
        include=AGENT_NODE_INCLUDE,
    )
    if not node:
        raise ValueError(f"Node #{node_id} not found")
    return NodeModel.from_db(node)


async def get_graphs(
    user_id: str,
    filter_by: Literal["active", "template"] | None = "active",
) -> list[GraphModel]:
    """
    Retrieves graph metadata objects.
    Default behaviour is to get all currently active graphs.

    Args:
        filter_by: An optional filter to either select templates or active graphs.
        user_id: The ID of the user that owns the graph.

    Returns:
        list[GraphModel]: A list of objects representing the retrieved graphs.
    """
    where_clause: AgentGraphWhereInput = {"userId": user_id}

    if filter_by == "active":
        where_clause["isActive"] = True
    elif filter_by == "template":
        where_clause["isTemplate"] = True

    graphs = await AgentGraph.prisma().find_many(
        where=where_clause,
        distinct=["id"],
        order={"version": "desc"},
        include=AGENT_GRAPH_INCLUDE,
    )

    graph_models = []
    for graph in graphs:
        try:
            graph_models.append(GraphModel.from_db(graph))
        except Exception as e:
            logger.error(f"Error processing graph {graph.id}: {e}")
            continue

    return graph_models


async def get_executions(user_id: str) -> list[GraphExecution]:
    executions = await AgentGraphExecution.prisma().find_many(
        where={"userId": user_id},
        order={"createdAt": "desc"},
    )
    return [GraphExecution.from_db(execution) for execution in executions]


async def get_execution(user_id: str, execution_id: str) -> GraphExecution | None:
    execution = await AgentGraphExecution.prisma().find_first(
        where={"id": execution_id, "userId": user_id}
    )
    return GraphExecution.from_db(execution) if execution else None


async def get_graph(
    graph_id: str,
    version: int | None = None,
    user_id: str | None = None,
    for_export: bool = False,
) -> GraphModel | None:
    """
    Retrieves a graph from the DB.
    Defaults to the currently active version if `version` is not passed.

    Returns `None` if the record is not found, or if it is neither owned by the
    user nor part of an approved store listing.
    """
    where_clause: AgentGraphWhereInput = {
        "id": graph_id,
    }

    if version is not None:
        where_clause["version"] = version
    else:
        where_clause["isActive"] = True

    graph = await AgentGraph.prisma().find_first(
        where=where_clause,
        include=AGENT_GRAPH_INCLUDE,
        order={"version": "desc"},
    )

    # The graph has to be owned by the user or be part of an approved store listing.
    if (
        graph is None
        or graph.userId != user_id
        and not (
            await StoreListingVersion.prisma().find_first(
                where=prisma.types.StoreListingVersionWhereInput(
                    agentId=graph_id,
                    agentVersion=version or graph.version,
                    isDeleted=False,
                    StoreListing={"is": {"isApproved": True}},
                )
            )
        )
    ):
        return None

    return GraphModel.from_db(graph, for_export)


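# Illustrative usage (not part of the original module): fetching the active version
# vs. a pinned version of a graph. The identifiers are placeholders.
#
#   active = await get_graph("graph-id", user_id="user-id")             # isActive version
#   pinned = await get_graph("graph-id", version=3, user_id="user-id")
#   export = await get_graph("graph-id", user_id="user-id", for_export=True)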
async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
    # Activate the requested version if it exists and is owned by the user.
    updated_count = await AgentGraph.prisma().update_many(
        data={"isActive": True},
        where={
            "id": graph_id,
            "version": version,
            "userId": user_id,
        },
    )
    if updated_count == 0:
        raise Exception(f"Graph #{graph_id} v{version} not found or not owned by user")

    # Deactivate all other versions.
    await AgentGraph.prisma().update_many(
        data={"isActive": False},
        where={
            "id": graph_id,
            "version": {"not": version},
            "userId": user_id,
            "isActive": True,
        },
    )


async def get_graph_all_versions(graph_id: str, user_id: str) -> list[GraphModel]:
    graph_versions = await AgentGraph.prisma().find_many(
        where={"id": graph_id, "userId": user_id},
        order={"version": "desc"},
        include=AGENT_GRAPH_INCLUDE,
    )

    if not graph_versions:
        return []

    return [GraphModel.from_db(graph) for graph in graph_versions]


async def delete_graph(graph_id: str, user_id: str) -> int:
    entries_count = await AgentGraph.prisma().delete_many(
        where={"id": graph_id, "userId": user_id}
    )
    if entries_count:
        logger.info(f"Deleted {entries_count} graph entries for Graph #{graph_id}")
    return entries_count


async def create_graph(graph: Graph, user_id: str) -> GraphModel:
    async with transaction() as tx:
        await __create_graph(tx, graph, user_id)

    if created_graph := await get_graph(graph.id, graph.version, user_id=user_id):
        return created_graph

    raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")


async def __create_graph(tx, graph: Graph, user_id: str):
    await AgentGraph.prisma(tx).create(
        data={
            "id": graph.id,
            "version": graph.version,
            "name": graph.name,
            "description": graph.description,
            "isTemplate": graph.is_template,
            "isActive": graph.is_active,
            "userId": user_id,
            "AgentNodes": {
                "create": [
                    {
                        "id": node.id,
                        "agentBlockId": node.block_id,
                        "constantInput": json.dumps(node.input_default),
                        "metadata": json.dumps(node.metadata),
                    }
                    for node in graph.nodes
                ]
            },
        }
    )

    await asyncio.gather(
        *[
            AgentNodeLink.prisma(tx).create(
                {
                    "id": str(uuid.uuid4()),
                    "sourceName": link.source_name,
                    "sinkName": link.sink_name,
                    "agentNodeSourceId": link.source_id,
                    "agentNodeSinkId": link.sink_id,
                    "isStatic": link.is_static,
                }
            )
            for link in graph.links
        ]
    )


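# Illustrative usage (not part of the original module): creating a new graph version
# inside a transaction and then making it the active one. `graph` is a hypothetical
# Graph/GraphModel that has already passed `validate_graph()`.
#
#   created = await create_graph(graph, user_id)
#   await set_graph_active_version(created.id, created.version, user_id)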
# ------------------------ UTILITIES ------------------------ #


def make_graph_model(creatable_graph: Graph, user_id: str) -> GraphModel:
    """
    Convert a Graph to a GraphModel, setting graph_id and graph_version on all nodes.

    Args:
        creatable_graph (Graph): The creatable graph to convert.
        user_id (str): The ID of the user creating the graph.

    Returns:
        GraphModel: The converted GraphModel object.
    """
    # Create a new Graph object, inheriting properties from CreatableGraph
    return GraphModel(
        **creatable_graph.model_dump(exclude={"nodes"}),
        user_id=user_id,
        nodes=[
            NodeModel(
                **creatable_node.model_dump(),
                graph_id=creatable_graph.id,
                graph_version=creatable_graph.version,
            )
            for creatable_node in creatable_graph.nodes
        ],
    )


async def fix_llm_provider_credentials():
    """Fix node credentials with provider `llm`"""
    from backend.integrations.credentials_store import IntegrationCredentialsStore

    from .user import get_user_integrations

    store = IntegrationCredentialsStore()

    broken_nodes = await prisma.get_client().query_raw(
        """
        SELECT graph."userId"       user_id,
               node.id              node_id,
               node."constantInput" node_preset_input
        FROM platform."AgentNode" node
        LEFT JOIN platform."AgentGraph" graph
        ON node."agentGraphId" = graph.id
        WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
        ORDER BY graph."userId";
        """
    )
    logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes")

    user_id: str = ""
    user_integrations = None
    for node in broken_nodes:
        if node["user_id"] != user_id:
            # Save queries by only fetching once per user
            user_id = node["user_id"]
            user_integrations = await get_user_integrations(user_id)
        elif not user_integrations:
            raise RuntimeError(f"Impossible state while processing node {node}")

        node_id: str = node["node_id"]
        node_preset_input: dict = json.loads(node["node_preset_input"])
        credentials_meta: dict = node_preset_input["credentials"]

        credentials = next(
            (
                c
                for c in user_integrations.credentials
                if c.id == credentials_meta["id"]
            ),
            None,
        )
        if not credentials:
            continue
        if credentials.type != "api_key":
            logger.warning(
                f"User {user_id} credentials {credentials.id} with provider 'llm' "
                f"has invalid type '{credentials.type}'"
            )
            continue

        api_key = credentials.api_key.get_secret_value()
        if api_key.startswith("sk-ant-api03-"):
            credentials.provider = credentials_meta["provider"] = "anthropic"
        elif api_key.startswith("sk-"):
            credentials.provider = credentials_meta["provider"] = "openai"
        elif api_key.startswith("gsk_"):
            credentials.provider = credentials_meta["provider"] = "groq"
        else:
            logger.warning(
                f"Could not identify provider from key prefix {api_key[:13]}*****"
            )
            continue

        store.update_creds(user_id, credentials)
        await AgentNode.prisma().update(
            where={"id": node_id},
            data={"constantInput": json.dumps(node_preset_input)},
        )
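# Illustrative note (not part of the original module): `fix_llm_provider_credentials`
# infers the real provider purely from the API-key prefix:
#   "sk-ant-api03-" -> anthropic, "sk-" -> openai, "gsk_" -> groq.
# Keys with any other prefix are logged with a warning and left untouched.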