chore(platform): Refactor GraphExecution naming clash and remove unused Graph Execution functions (#8939)

This is a follow-up of
https://github.com/Significant-Gravitas/AutoGPT/pull/8752

There are several APIs and functions related to graph execution that are
unused now.
There is also a naming clash: a class called `GraphExecution` exists in
both graph.py and execution.py, which causes confusion.

### Changes 🏗️

* Renamed `GraphExecution` in `execution.py` to `GraphExecutionEntry`,
this is only used as a queue entry for execution.
* Removed unused `get_graph_execution` & `list_executions` in
`execution.py`.
* Removed `with_run` option on `get_graph` function in `graph.py`.
* Removed `GraphMetaWithRuns`
* Removed exposed functions only for testing.
* Removed `executions` fields in Graph model.

### Checklist 📋

#### For code changes:
- [ ] I have clearly listed my changes in the PR description
- [ ] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [ ] ...

<details>
  <summary>Example test plan</summary>
  
  - [ ] Create from scratch and execute an agent with at least 3 blocks
  - [ ] Import an agent from file upload, and confirm it executes correctly
  - [ ] Upload agent to marketplace
  - [ ] Import an agent from marketplace and confirm it executes correctly
  - [ ] Edit an agent from monitor, and confirm it executes correctly
</details>

#### For configuration changes:
- [ ] `.env.example` is updated or already compatible with my changes
- [ ] `docker-compose.yml` is updated or already compatible with my changes
- [ ] I have included a list of my configuration changes in the PR description (under **Changes**)

<details>
  <summary>Examples of configuration changes</summary>

  - Changing ports
  - Adding new services that need to communicate with each other
  - Secrets or environment variable changes
  - New or infrastructure changes such as databases
</details>

---------

Co-authored-by: Krzysztof Czerwinski <34861343+kcze@users.noreply.github.com>
This commit is contained in:
Zamil Majdy 2024-12-11 09:41:15 -06:00 committed by GitHub
parent 7a9115db18
commit 6490b4e188
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 49 additions and 164 deletions

View File

@ -9,7 +9,6 @@ from prisma.models import (
AgentNodeExecution,
AgentNodeExecutionInputOutput,
)
from prisma.types import AgentGraphExecutionWhereInput
from pydantic import BaseModel
from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
@ -19,14 +18,14 @@ from backend.util import json, mock
from backend.util.settings import Config
class GraphExecution(BaseModel):
class GraphExecutionEntry(BaseModel):
user_id: str
graph_exec_id: str
graph_id: str
start_node_execs: list["NodeExecution"]
start_node_execs: list["NodeExecutionEntry"]
class NodeExecution(BaseModel):
class NodeExecutionEntry(BaseModel):
user_id: str
graph_exec_id: str
graph_id: str
@ -325,34 +324,6 @@ async def update_execution_status(
return ExecutionResult.from_db(res)
async def get_graph_execution(
graph_exec_id: str, user_id: str
) -> AgentGraphExecution | None:
"""
Retrieve a specific graph execution by its ID.
Args:
graph_exec_id (str): The ID of the graph execution to retrieve.
user_id (str): The ID of the user to whom the graph (execution) belongs.
Returns:
AgentGraphExecution | None: The graph execution if found, None otherwise.
"""
execution = await AgentGraphExecution.prisma().find_first(
where={"id": graph_exec_id, "userId": user_id},
include=GRAPH_EXECUTION_INCLUDE,
)
return execution
async def list_executions(graph_id: str, graph_version: int | None = None) -> list[str]:
where: AgentGraphExecutionWhereInput = {"agentGraphId": graph_id}
if graph_version is not None:
where["agentGraphVersion"] = graph_version
executions = await AgentGraphExecution.prisma().find_many(where=where)
return [execution.id for execution in executions]
async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
executions = await AgentNodeExecution.prisma().find_many(
where={"agentGraphExecutionId": graph_exec_id},

View File

@ -7,7 +7,7 @@ from typing import Any, Literal, Optional, Type
import prisma
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
from prisma.types import AgentGraphExecutionWhereInput, AgentGraphWhereInput
from prisma.types import AgentGraphWhereInput
from pydantic.fields import computed_field
from backend.blocks.agent import AgentExecutorBlock
@ -143,7 +143,6 @@ class Graph(BaseDbModel):
is_template: bool = False
name: str
description: str
executions: list[GraphExecution] = []
nodes: list[Node] = []
links: list[Link] = []
@ -329,11 +328,6 @@ class GraphModel(Graph):
@staticmethod
def from_db(graph: AgentGraph, hide_credentials: bool = False):
executions = [
GraphExecution.from_db(execution)
for execution in graph.AgentGraphExecution or []
]
return GraphModel(
id=graph.id,
user_id=graph.userId,
@ -342,7 +336,6 @@ class GraphModel(Graph):
is_template=graph.isTemplate,
name=graph.name or "",
description=graph.description or "",
executions=executions,
nodes=[
GraphModel._process_node(node, hide_credentials)
for node in graph.AgentNodes or []
@ -412,7 +405,6 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
async def get_graphs(
user_id: str,
include_executions: bool = False,
filter_by: Literal["active", "template"] | None = "active",
) -> list[GraphModel]:
"""
@ -420,7 +412,6 @@ async def get_graphs(
Default behaviour is to get all currently active graphs.
Args:
include_executions: Whether to include executions in the graph metadata.
filter_by: An optional filter to either select templates or active graphs.
user_id: The ID of the user that owns the graph.
@ -434,30 +425,31 @@ async def get_graphs(
elif filter_by == "template":
where_clause["isTemplate"] = True
graph_include = AGENT_GRAPH_INCLUDE
graph_include["AgentGraphExecution"] = include_executions
graphs = await AgentGraph.prisma().find_many(
where=where_clause,
distinct=["id"],
order={"version": "desc"},
include=graph_include,
include=AGENT_GRAPH_INCLUDE,
)
return [GraphModel.from_db(graph) for graph in graphs]
async def get_executions(user_id: str) -> list[GraphExecution]:
where_clause: AgentGraphExecutionWhereInput = {"userId": user_id}
executions = await AgentGraphExecution.prisma().find_many(
where=where_clause,
where={"userId": user_id},
order={"createdAt": "desc"},
)
return [GraphExecution.from_db(execution) for execution in executions]
async def get_execution(user_id: str, execution_id: str) -> GraphExecution | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "userId": user_id}
)
return GraphExecution.from_db(execution) if execution else None
async def get_graph(
graph_id: str,
version: int | None = None,

View File

@ -25,8 +25,8 @@ from backend.data.execution import (
ExecutionQueue,
ExecutionResult,
ExecutionStatus,
GraphExecution,
NodeExecution,
GraphExecutionEntry,
NodeExecutionEntry,
merge_execution_input,
parse_execution_output,
)
@ -96,13 +96,13 @@ class LogMetadata:
T = TypeVar("T")
ExecutionStream = Generator[NodeExecution, None, None]
ExecutionStream = Generator[NodeExecutionEntry, None, None]
def execute_node(
db_client: "DatabaseManager",
creds_manager: IntegrationCredentialsManager,
data: NodeExecution,
data: NodeExecutionEntry,
execution_stats: dict[str, Any] | None = None,
) -> ExecutionStream:
"""
@ -252,15 +252,15 @@ def _enqueue_next_nodes(
graph_exec_id: str,
graph_id: str,
log_metadata: LogMetadata,
) -> list[NodeExecution]:
) -> list[NodeExecutionEntry]:
def add_enqueued_execution(
node_exec_id: str, node_id: str, data: BlockInput
) -> NodeExecution:
) -> NodeExecutionEntry:
exec_update = db_client.update_execution_status(
node_exec_id, ExecutionStatus.QUEUED, data
)
db_client.send_execution_update(exec_update)
return NodeExecution(
return NodeExecutionEntry(
user_id=user_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
@ -269,7 +269,7 @@ def _enqueue_next_nodes(
data=data,
)
def register_next_executions(node_link: Link) -> list[NodeExecution]:
def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
enqueued_executions = []
next_output_name = node_link.source_name
next_input_name = node_link.sink_name
@ -501,8 +501,8 @@ class Executor:
@error_logged
def on_node_execution(
cls,
q: ExecutionQueue[NodeExecution],
node_exec: NodeExecution,
q: ExecutionQueue[NodeExecutionEntry],
node_exec: NodeExecutionEntry,
) -> dict[str, Any]:
log_metadata = LogMetadata(
user_id=node_exec.user_id,
@ -529,8 +529,8 @@ class Executor:
@time_measured
def _on_node_execution(
cls,
q: ExecutionQueue[NodeExecution],
node_exec: NodeExecution,
q: ExecutionQueue[NodeExecutionEntry],
node_exec: NodeExecutionEntry,
log_metadata: LogMetadata,
stats: dict[str, Any] | None = None,
):
@ -580,7 +580,9 @@ class Executor:
@classmethod
@error_logged
def on_graph_execution(cls, graph_exec: GraphExecution, cancel: threading.Event):
def on_graph_execution(
cls, graph_exec: GraphExecutionEntry, cancel: threading.Event
):
log_metadata = LogMetadata(
user_id=graph_exec.user_id,
graph_eid=graph_exec.graph_exec_id,
@ -605,7 +607,7 @@ class Executor:
@time_measured
def _on_graph_execution(
cls,
graph_exec: GraphExecution,
graph_exec: GraphExecutionEntry,
cancel: threading.Event,
log_metadata: LogMetadata,
) -> tuple[dict[str, Any], Exception | None]:
@ -636,13 +638,13 @@ class Executor:
cancel_thread.start()
try:
queue = ExecutionQueue[NodeExecution]()
queue = ExecutionQueue[NodeExecutionEntry]()
for node_exec in graph_exec.start_node_execs:
queue.add(node_exec)
running_executions: dict[str, AsyncResult] = {}
def make_exec_callback(exec_data: NodeExecution):
def make_exec_callback(exec_data: NodeExecutionEntry):
node_id = exec_data.node_id
def callback(result: object):
@ -717,7 +719,7 @@ class ExecutionManager(AppService):
self.use_redis = True
self.use_supabase = True
self.pool_size = settings.config.num_graph_workers
self.queue = ExecutionQueue[GraphExecution]()
self.queue = ExecutionQueue[GraphExecutionEntry]()
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
@classmethod
@ -768,7 +770,7 @@ class ExecutionManager(AppService):
data: BlockInput,
user_id: str,
graph_version: int | None = None,
) -> GraphExecution:
) -> GraphExecutionEntry:
graph: GraphModel | None = self.db_client.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
@ -818,7 +820,7 @@ class ExecutionManager(AppService):
starting_node_execs = []
for node_exec in node_execs:
starting_node_execs.append(
NodeExecution(
NodeExecutionEntry(
user_id=user_id,
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
@ -832,7 +834,7 @@ class ExecutionManager(AppService):
)
self.db_client.send_execution_update(exec_update)
graph_exec = GraphExecution(
graph_exec = GraphExecutionEntry(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=graph_exec_id,

View File

@ -117,17 +117,17 @@ class AgentServer(backend.util.service.AppProcess):
async def test_create_graph(
create_graph: backend.server.routers.v1.CreateGraph,
user_id: str,
is_template=False,
):
return await backend.server.routers.v1.create_new_graph(create_graph, user_id)
@staticmethod
async def test_get_graph_run_status(
graph_id: str, graph_exec_id: str, user_id: str
):
return await backend.server.routers.v1.get_graph_run_status(
graph_id, graph_exec_id, user_id
async def test_get_graph_run_status(graph_exec_id: str, user_id: str):
execution = await backend.data.graph.get_execution(
user_id=user_id, execution_id=graph_exec_id
)
if not execution:
raise ValueError(f"Execution {graph_exec_id} not found")
return execution.status
@staticmethod
async def test_get_graph_run_node_execution_results(

View File

@ -149,12 +149,9 @@ class DeleteGraphResponse(TypedDict):
@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
async def get_graphs(
user_id: Annotated[str, Depends(get_user_id)],
with_runs: bool = False,
user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[graph_db.Graph]:
return await graph_db.get_graphs(
include_executions=with_runs, filter_by="active", user_id=user_id
)
return await graph_db.get_graphs(filter_by="active", user_id=user_id)
@v1_router.get(
@ -386,7 +383,7 @@ def execute_graph(
async def stop_graph_run(
graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[execution_db.ExecutionResult]:
if not await execution_db.get_graph_execution(graph_exec_id, user_id):
if not await graph_db.get_execution(user_id=user_id, execution_id=graph_exec_id):
raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")
await asyncio.to_thread(
@ -408,26 +405,6 @@ async def get_executions(
return await graph_db.get_executions(user_id=user_id)
@v1_router.get(
path="/graphs/{graph_id}/executions",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
async def list_graph_runs(
graph_id: str,
user_id: Annotated[str, Depends(get_user_id)],
graph_version: int | None = None,
) -> Sequence[str]:
graph = await graph_db.get_graph(graph_id, graph_version, user_id=user_id)
if not graph:
rev = "" if graph_version is None else f" v{graph_version}"
raise HTTPException(
status_code=404, detail=f"Agent #{graph_id}{rev} not found."
)
return await execution_db.list_executions(graph_id, graph_version)
@v1_router.get(
path="/graphs/{graph_id}/executions/{graph_exec_id}",
tags=["graphs"],
@ -445,25 +422,6 @@ async def get_graph_run_node_execution_results(
return await execution_db.get_execution_results(graph_exec_id)
# NOTE: This is used for testing
async def get_graph_run_status(
graph_id: str,
graph_exec_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> execution_db.ExecutionStatus:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
execution = await execution_db.get_graph_execution(graph_exec_id, user_id)
if not execution:
raise HTTPException(
status_code=404, detail=f"Execution #{graph_exec_id} not found."
)
return execution.executionStatus
########################################################
##################### Templates ########################
########################################################

View File

@ -60,9 +60,7 @@ async def wait_execution(
timeout: int = 20,
) -> Sequence[ExecutionResult]:
async def is_execution_completed():
status = await AgentServer().test_get_graph_run_status(
graph_id, graph_exec_id, user_id
)
status = await AgentServer().test_get_graph_run_status(graph_exec_id, user_id)
log.info(f"Execution status: {status}")
if status == ExecutionStatus.FAILED:
log.info("Execution failed")

View File

@ -14,7 +14,6 @@ async def test_agent_schedule(server: SpinTestServer):
test_user = await create_test_user()
test_graph = await server.agent_server.test_create_graph(
create_graph=CreateGraph(graph=create_test_graph()),
is_template=False,
user_id=test_user.id,
)

View File

@ -87,9 +87,7 @@ const Monitor = () => {
selectedFlow={selectedFlow}
onSelectFlow={(f) => {
setSelectedRun(null);
setSelectedFlow(
f.id == selectedFlow?.id ? null : (f as GraphMetaWithRuns),
);
setSelectedFlow(f.id == selectedFlow?.id ? null : (f as GraphMeta));
}}
/>
<FlowRunsList

View File

@ -18,13 +18,13 @@ export function withFeatureFlag<P extends object>(
if (flags && flagKey in flags) {
setHasFlagLoaded(true);
}
}, [flags, flagKey]);
}, [flags]);
useEffect(() => {
if (hasFlagLoaded && !flags[flagKey]) {
router.push("/404");
}
}, [hasFlagLoaded, flags, flagKey, router]);
}, [hasFlagLoaded, flags, router]);
// Show loading state until flags loaded
if (!hasFlagLoaded) {

View File

@ -12,7 +12,6 @@ import {
GraphCreatable,
GraphExecuteResponse,
GraphMeta,
GraphMetaWithRuns,
GraphUpdateable,
NodeExecutionResult,
OAuth2Credentials,
@ -69,11 +68,6 @@ export default class BaseAutoGPTServerAPI {
return this._get(`/graphs`);
}
async listGraphsWithRuns(): Promise<GraphMetaWithRuns[]> {
let graphs = await this._get(`/graphs?with_runs=true`);
return graphs.map(parseGraphMetaWithRuns);
}
getExecutions(): Promise<GraphExecution[]> {
return this._get(`/executions`);
}
@ -163,12 +157,6 @@ export default class BaseAutoGPTServerAPI {
return this._request("POST", `/graphs/${id}/execute`, inputData);
}
listGraphRunIDs(graphID: string, graphVersion?: number): Promise<string[]> {
const query =
graphVersion !== undefined ? `?graph_version=${graphVersion}` : "";
return this._get(`/graphs/${graphID}/executions` + query);
}
async getGraphExecutionInfo(
graphID: string,
runID: string,
@ -521,20 +509,3 @@ function parseNodeExecutionResultTimestamps(result: any): NodeExecutionResult {
end_time: result.end_time ? new Date(result.end_time) : undefined,
};
}
function parseGraphMetaWithRuns(result: any): GraphMetaWithRuns {
return {
...result,
executions: result.executions
? result.executions.map(parseExecutionMetaTimestamps)
: [],
};
}
function parseExecutionMetaTimestamps(result: any): GraphExecution {
return {
...result,
started_at: new Date(result.started_at).getTime(),
ended_at: result.ended_at ? new Date(result.ended_at).getTime() : undefined,
};
}

View File

@ -200,7 +200,6 @@ export type GraphExecution = {
graph_version: number;
};
/* backend/data/graph.py:Graph = GraphMeta & GraphMetaWithRuns & Graph */
export type GraphMeta = {
id: string;
version: number;
@ -212,10 +211,7 @@ export type GraphMeta = {
output_schema: BlockIOObjectSubSchema;
};
export type GraphMetaWithRuns = GraphMeta & {
executions: GraphExecution[];
};
/* Mirror of backend/data/graph.py:Graph */
export type Graph = GraphMeta & {
nodes: Array<Node>;
links: Array<Link>;