mirror of https://github.com/langgenius/dify.git
This commit is contained in:
parent
3340775052
commit
9f8e05d9f0
|
@ -222,11 +222,28 @@ class HostedFetchAppTemplateConfig(BaseSettings):
|
|||
)
|
||||
|
||||
|
||||
class HostedFetchPipelineTemplateConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for fetching pipeline templates
|
||||
"""
|
||||
|
||||
HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
|
||||
description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
|
||||
default="remote",
|
||||
)
|
||||
|
||||
HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
|
||||
description="Domain for fetching remote pipeline templates",
|
||||
default="https://tmpl.dify.ai",
|
||||
)
|
||||
|
||||
|
||||
class HostedServiceConfig(
|
||||
# place the configs in alphabet order
|
||||
HostedAnthropicConfig,
|
||||
HostedAzureOpenAiConfig,
|
||||
HostedFetchAppTemplateConfig,
|
||||
HostedFetchPipelineTemplateConfig,
|
||||
HostedMinmaxConfig,
|
||||
HostedOpenAiConfig,
|
||||
HostedSparkConfig,
|
||||
|
|
|
@ -101,3 +101,9 @@ class ChildChunkDeleteIndexError(BaseHTTPException):
|
|||
error_code = "child_chunk_delete_index_error"
|
||||
description = "Delete child chunk index failed: {message}"
|
||||
code = 500
|
||||
|
||||
|
||||
class PipelineNotFoundError(BaseHTTPException):
|
||||
error_code = "pipeline_not_found"
|
||||
description = "Pipeline not found."
|
||||
code = 404
|
|
@ -1,19 +1,49 @@
|
|||
from flask import request
|
||||
from flask_login import current_user # type: ignore # type: ignore
|
||||
from flask_restful import Resource, marshal # type: ignore
|
||||
import json
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
ConversationCompletedError,
|
||||
DraftWorkflowNotExist,
|
||||
DraftWorkflowNotSync,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
|
@ -37,7 +67,699 @@ class PipelineTemplateListApi(Resource):
|
|||
type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"])
|
||||
language = request.args.get("language", default="en-US", type=str)
|
||||
# get pipeline templates
|
||||
return response, 200
|
||||
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
|
||||
return pipeline_templates, 200
|
||||
|
||||
|
||||
api.add_resource(PipelineTemplateListApi, "/rag/pipeline/templates")
|
||||
class PipelineTemplateDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self, pipeline_id: str):
|
||||
pipeline_template = RagPipelineService.get_pipeline_template_detail(pipeline_id)
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
class CustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, template_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
pipeline_template_info = PipelineTemplateInfoEntity(**args)
|
||||
pipeline_template = RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||
return pipeline_template, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def delete(self, template_id: str):
|
||||
RagPipelineService.delete_customized_pipeline_template(template_id)
|
||||
return 200
|
||||
|
||||
|
||||
class DraftRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
@marshal_with(workflow_fields)
|
||||
def get(self, pipeline: Pipeline):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
|
||||
|
||||
if not workflow:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
# return workflow, if not found, return None (initiate graph by frontend)
|
||||
return workflow
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline):
|
||||
"""
|
||||
Sync draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
if "application/json" in content_type:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("hash", type=str, required=False, location="json")
|
||||
parser.add_argument("environment_variables", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
elif "text/plain" in content_type:
|
||||
try:
|
||||
data = json.loads(request.data.decode("utf-8"))
|
||||
if "graph" not in data or "features" not in data:
|
||||
raise ValueError("graph or features not found in data")
|
||||
|
||||
if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
|
||||
raise ValueError("graph or features is not a dict")
|
||||
|
||||
args = {
|
||||
"graph": data.get("graph"),
|
||||
"features": data.get("features"),
|
||||
"hash": data.get("hash"),
|
||||
"environment_variables": data.get("environment_variables"),
|
||||
"conversation_variables": data.get("conversation_variables"),
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
else:
|
||||
abort(415)
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
try:
|
||||
environment_variables_list = args.get("environment_variables") or []
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
conversation_variables_list = args.get("conversation_variables") or []
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow = rag_pipeline_service.sync_draft_workflow(
|
||||
pipeline=pipeline,
|
||||
graph=args["graph"],
|
||||
features=args["features"],
|
||||
unique_hash=args.get("hash"),
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"hash": workflow.unique_hash,
|
||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
}
|
||||
|
||||
class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run draft workflow iteration node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_iteration(
|
||||
pipeline=pipeline, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_loop(
|
||||
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class DraftRagPipelineRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Run draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
|
||||
|
||||
class RagPipelineTaskStopApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, task_id: str):
|
||||
"""
|
||||
Stop workflow task
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class RagPipelineNodeRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_node_execution_fields)
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs == None:
|
||||
raise ValueError("missing inputs")
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow_node_execution = workflow_service.run_draft_workflow_node(
|
||||
app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user
|
||||
)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
|
||||
class PublishedRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
@marshal_with(workflow_fields)
|
||||
def get(self, pipeline: Pipeline):
|
||||
"""
|
||||
Get published pipeline
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# fetch published workflow by pipeline
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow = rag_pipeline_service.get_published_workflow(pipeline=pipeline)
|
||||
|
||||
# return workflow, if not found, return None
|
||||
return workflow
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline):
|
||||
"""
|
||||
Publish workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
|
||||
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
raise ValueError("Marked name cannot exceed 20 characters")
|
||||
if args.marked_comment and len(args.marked_comment) > 100:
|
||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
with Session(db.engine) as session:
|
||||
workflow = rag_pipeline_service.publish_workflow(
|
||||
session=session,
|
||||
pipeline=pipeline,
|
||||
account=current_user,
|
||||
marked_name=args.marked_name or "",
|
||||
marked_comment=args.marked_comment or "",
|
||||
)
|
||||
|
||||
pipeline.workflow_id = workflow.id
|
||||
db.session.commit()
|
||||
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"created_at": workflow_created_at,
|
||||
}
|
||||
|
||||
|
||||
class DefaultRagPipelineBlockConfigsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def get(self, pipeline: Pipeline):
|
||||
"""
|
||||
Get default block config
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Get default block configs
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
return rag_pipeline_service.get_default_block_configs()
|
||||
|
||||
|
||||
class DefaultRagPipelineBlockConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def get(self, pipeline: Pipeline, block_type: str):
|
||||
"""
|
||||
Get default block config
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("q", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
q = args.get("q")
|
||||
|
||||
filters = None
|
||||
if q:
|
||||
try:
|
||||
filters = json.loads(args.get("q", ""))
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid filters")
|
||||
|
||||
# Get default block configs
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||
|
||||
|
||||
class ConvertToRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline):
|
||||
"""
|
||||
Convert basic mode of chatbot app to workflow mode
|
||||
Convert expert mode of chatbot app to workflow mode
|
||||
Convert Completion App to Workflow App
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
if request.data:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("icon", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
else:
|
||||
args = {}
|
||||
|
||||
# convert to workflow mode
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
new_app_model = rag_pipeline_service.convert_to_workflow(pipeline=pipeline, account=current_user, args=args)
|
||||
|
||||
# return app id
|
||||
return {
|
||||
"new_app_id": new_app_model.id,
|
||||
}
|
||||
|
||||
|
||||
class RagPipelineConfigApi(Resource):
|
||||
"""Resource for rag pipeline configuration."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
return {
|
||||
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
|
||||
}
|
||||
|
||||
|
||||
class PublishedAllRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
@marshal_with(workflow_pagination_fields)
|
||||
def get(self, pipeline: Pipeline):
|
||||
"""
|
||||
Get published workflows
|
||||
"""
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument("user_id", type=str, required=False, location="args")
|
||||
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||
args = parser.parse_args()
|
||||
page = int(args.get("page", 1))
|
||||
limit = int(args.get("limit", 10))
|
||||
user_id = args.get("user_id")
|
||||
named_only = args.get("named_only", False)
|
||||
|
||||
if user_id:
|
||||
if user_id != current_user.id:
|
||||
raise Forbidden()
|
||||
user_id = cast(str, user_id)
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
with Session(db.engine) as session:
|
||||
workflows, has_more = rag_pipeline_service.get_all_published_workflow(
|
||||
session=session,
|
||||
pipeline=pipeline,
|
||||
page=page,
|
||||
limit=limit,
|
||||
user_id=user_id,
|
||||
named_only=named_only,
|
||||
)
|
||||
|
||||
return {
|
||||
"items": workflows,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"has_more": has_more,
|
||||
}
|
||||
|
||||
|
||||
class RagPipelineByIdApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
@marshal_with(workflow_fields)
|
||||
def patch(self, pipeline: Pipeline, workflow_id: str):
|
||||
"""
|
||||
Update workflow attributes
|
||||
"""
|
||||
# Check permission
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("marked_name", type=str, required=False, location="json")
|
||||
parser.add_argument("marked_comment", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
raise ValueError("Marked name cannot exceed 20 characters")
|
||||
if args.marked_comment and len(args.marked_comment) > 100:
|
||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Prepare update data
|
||||
update_data = {}
|
||||
if args.get("marked_name") is not None:
|
||||
update_data["marked_name"] = args["marked_name"]
|
||||
if args.get("marked_comment") is not None:
|
||||
update_data["marked_comment"] = args["marked_comment"]
|
||||
|
||||
if not update_data:
|
||||
return {"message": "No valid fields to update"}, 400
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
||||
# Create a session and manage the transaction
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow = rag_pipeline_service.update_workflow(
|
||||
session=session,
|
||||
workflow_id=workflow_id,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
account_id=current_user.id,
|
||||
data=update_data,
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
raise NotFound("Workflow not found")
|
||||
|
||||
# Commit the transaction in the controller
|
||||
session.commit()
|
||||
|
||||
return workflow
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def delete(self, pipeline: Pipeline, workflow_id: str):
|
||||
"""
|
||||
Delete workflow
|
||||
"""
|
||||
# Check permission
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
||||
# Create a session and manage the transaction
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
rag_pipeline_service.delete_workflow(
|
||||
session=session, workflow_id=workflow_id, tenant_id=pipeline.tenant_id
|
||||
)
|
||||
# Commit the transaction in the controller
|
||||
session.commit()
|
||||
except WorkflowInUseError as e:
|
||||
abort(400, description=str(e))
|
||||
except DraftWorkflowDeletionError as e:
|
||||
abort(400, description=str(e))
|
||||
except ValueError as e:
|
||||
raise NotFound(str(e))
|
||||
|
||||
return None, 204
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DraftRagPipelineApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineConfigApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/config",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftRagPipelineRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineTaskStopApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineNodeRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineDraftRunIterationNodeApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineDraftRunLoopNodeApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
PublishedRagPipelineApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/publish",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedAllRagPipelineApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows",
|
||||
)
|
||||
api.add_resource(
|
||||
DefaultRagPipelineBlockConfigsApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs",
|
||||
)
|
||||
api.add_resource(
|
||||
DefaultRagPipelineBlockConfigApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>",
|
||||
)
|
||||
api.add_resource(
|
||||
ConvertToRagPipelineApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/convert-to-workflow",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineByIdApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
PipelineTemplateListApi,
|
||||
"/rag/pipeline/templates",
|
||||
)
|
||||
api.add_resource(
|
||||
PipelineTemplateDetailApi,
|
||||
"/rag/pipeline/templates/<string:pipeline_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
CustomizedPipelineTemplateApi,
|
||||
"/rag/pipeline/templates/<string:template_id>",
|
||||
)
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from controllers.console.datasets.error import PipelineNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.dataset import Pipeline
|
||||
|
||||
|
||||
def get_rag_pipeline(view: Optional[Callable] = None,):
|
||||
def decorator(view_func):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
if not kwargs.get("pipeline_id"):
|
||||
raise ValueError("missing pipeline_id in path parameters")
|
||||
|
||||
pipeline_id = kwargs.get("pipeline_id")
|
||||
pipeline_id = str(pipeline_id)
|
||||
|
||||
del kwargs["pipeline_id"]
|
||||
|
||||
pipeline = (
|
||||
db.session.query(Pipeline)
|
||||
.filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not pipeline:
|
||||
raise PipelineNotFoundError()
|
||||
|
||||
kwargs["pipeline"] = pipeline
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
|
@ -1139,11 +1139,10 @@ class DatasetMetadataBinding(db.Model): # type: ignore[name-defined]
|
|||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
|
||||
|
||||
class PipelineBuiltInTemplate(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "pipeline_built_in_templates"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),
|
||||
)
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
pipeline_id = db.Column(StringUUID, nullable=False)
|
||||
|
@ -1181,9 +1180,7 @@ class PipelineCustomizedTemplate(db.Model): # type: ignore[name-defined]
|
|||
|
||||
class Pipeline(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "pipelines"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="pipeline_pkey"),
|
||||
)
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_pkey"),)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
|
|
|
@ -38,7 +38,8 @@ class WorkflowType(Enum):
|
|||
|
||||
WORKFLOW = "workflow"
|
||||
CHAT = "chat"
|
||||
|
||||
RAG_PIPELINE = "rag_pipeline"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "WorkflowType":
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class IconInfo(BaseModel):
|
||||
icon: str
|
||||
icon_background: Optional[str] = None
|
||||
icon_type: Optional[str] = None
|
||||
icon_url: Optional[str] = None
|
||||
|
||||
|
||||
class PipelineTemplateInfoEntity(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
icon_info: IconInfo
|
|
@ -0,0 +1,64 @@
|
|||
import json
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
|
||||
|
||||
|
||||
class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
"""
|
||||
Retrieval pipeline template from built-in, the location is constants/pipeline_templates.json
|
||||
"""
|
||||
|
||||
builtin_data: Optional[dict] = None
|
||||
|
||||
def get_type(self) -> str:
|
||||
return PipelineTemplateType.BUILTIN
|
||||
|
||||
def get_pipeline_templates(self, language: str) -> dict:
|
||||
result = self.fetch_pipeline_templates_from_builtin(language)
|
||||
return result
|
||||
|
||||
def get_pipeline_template_detail(self, pipeline_id: str):
|
||||
result = self.fetch_pipeline_template_detail_from_builtin(pipeline_id)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _get_builtin_data(cls) -> dict:
|
||||
"""
|
||||
Get builtin data.
|
||||
:return:
|
||||
"""
|
||||
if cls.builtin_data:
|
||||
return cls.builtin_data
|
||||
|
||||
root_path = current_app.root_path
|
||||
cls.builtin_data = json.loads(
|
||||
Path(path.join(root_path, "constants", "pipeline_templates.json")).read_text(encoding="utf-8")
|
||||
)
|
||||
|
||||
return cls.builtin_data or {}
|
||||
|
||||
@classmethod
|
||||
def fetch_pipeline_templates_from_builtin(cls, language: str) -> dict:
|
||||
"""
|
||||
Fetch pipeline templates from builtin.
|
||||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
|
||||
return builtin_data.get("pipeline_templates", {}).get(language, {})
|
||||
|
||||
@classmethod
|
||||
def fetch_pipeline_template_detail_from_builtin(cls, pipeline_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Fetch pipeline template detail from builtin.
|
||||
:param pipeline_id: Pipeline ID
|
||||
:return:
|
||||
"""
|
||||
builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
|
||||
return builtin_data.get("pipeline_templates", {}).get(pipeline_id)
|
|
@ -1,8 +1,9 @@
|
|||
from typing import Optional
|
||||
|
||||
from constants.languages import languages
|
||||
from flask_login import current_user
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, RecommendedApp
|
||||
from models.dataset import Pipeline, PipelineCustomizedTemplate
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
|
||||
|
@ -14,92 +15,57 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
|||
"""
|
||||
|
||||
def get_pipeline_templates(self, language: str) -> dict:
|
||||
result = self.fetch_pipeline_templates_from_db(language)
|
||||
result = self.fetch_pipeline_templates_from_customized(
|
||||
tenant_id=current_user.current_tenant_id, language=language
|
||||
)
|
||||
return result
|
||||
|
||||
def get_pipeline_template_detail(self, pipeline_id: str):
|
||||
result = self.fetch_pipeline_template_detail_from_db(pipeline_id)
|
||||
def get_pipeline_template_detail(self, template_id: str):
|
||||
result = self.fetch_pipeline_template_detail_from_db(template_id)
|
||||
return result
|
||||
|
||||
def get_type(self) -> str:
|
||||
return PipelineTemplateType.CUSTOMIZED
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_apps_from_db(cls, language: str) -> dict:
|
||||
def fetch_pipeline_templates_from_customized(cls, tenant_id: str, language: str) -> dict:
|
||||
"""
|
||||
Fetch recommended apps from db.
|
||||
Fetch pipeline templates from db.
|
||||
:param tenant_id: tenant id
|
||||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
recommended_apps = (
|
||||
db.session.query(RecommendedApp)
|
||||
.filter(RecommendedApp.is_listed == True, RecommendedApp.language == language)
|
||||
pipeline_templates = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
.filter(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(recommended_apps) == 0:
|
||||
recommended_apps = (
|
||||
db.session.query(RecommendedApp)
|
||||
.filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
|
||||
.all()
|
||||
)
|
||||
|
||||
categories = set()
|
||||
recommended_apps_result = []
|
||||
for recommended_app in recommended_apps:
|
||||
app = recommended_app.app
|
||||
if not app or not app.is_public:
|
||||
continue
|
||||
|
||||
site = app.site
|
||||
if not site:
|
||||
continue
|
||||
|
||||
recommended_app_result = {
|
||||
"id": recommended_app.id,
|
||||
"app": recommended_app.app,
|
||||
"app_id": recommended_app.app_id,
|
||||
"description": site.description,
|
||||
"copyright": site.copyright,
|
||||
"privacy_policy": site.privacy_policy,
|
||||
"custom_disclaimer": site.custom_disclaimer,
|
||||
"category": recommended_app.category,
|
||||
"position": recommended_app.position,
|
||||
"is_listed": recommended_app.is_listed,
|
||||
}
|
||||
recommended_apps_result.append(recommended_app_result)
|
||||
|
||||
categories.add(recommended_app.category)
|
||||
|
||||
return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}
|
||||
return {"pipeline_templates": pipeline_templates}
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]:
|
||||
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Fetch recommended app detail from db.
|
||||
:param app_id: App ID
|
||||
Fetch pipeline template detail from db.
|
||||
:param template_id: Template ID
|
||||
:return:
|
||||
"""
|
||||
# is in public recommended list
|
||||
recommended_app = (
|
||||
db.session.query(RecommendedApp)
|
||||
.filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id)
|
||||
.first()
|
||||
pipeline_template = (
|
||||
db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
|
||||
)
|
||||
|
||||
if not recommended_app:
|
||||
if not pipeline_template:
|
||||
return None
|
||||
|
||||
# get app detail
|
||||
app_model = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app_model or not app_model.is_public:
|
||||
# get pipeline detail
|
||||
pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first()
|
||||
if not pipeline or not pipeline.is_public:
|
||||
return None
|
||||
|
||||
return {
|
||||
"id": app_model.id,
|
||||
"name": app_model.name,
|
||||
"icon": app_model.icon,
|
||||
"icon_background": app_model.icon_background,
|
||||
"mode": app_model.mode,
|
||||
"export_data": AppDslService.export_dsl(app_model=app_model),
|
||||
"id": pipeline.id,
|
||||
"name": pipeline.name,
|
||||
"icon": pipeline.icon,
|
||||
"mode": pipeline.mode,
|
||||
"export_data": AppDslService.export_dsl(app_model=pipeline),
|
||||
}
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from typing import Optional
|
||||
|
||||
from constants.languages import languages
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, RecommendedApp
|
||||
from models.dataset import Pipeline, PipelineBuiltInTemplate
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase
|
||||
from services.recommend_app.recommend_app_type import RecommendAppType
|
||||
|
@ -10,96 +9,57 @@ from services.recommend_app.recommend_app_type import RecommendAppType
|
|||
|
||||
class DatabasePipelineTemplateRetrieval(RecommendAppRetrievalBase):
|
||||
"""
|
||||
Retrieval recommended app from database
|
||||
Retrieval pipeline template from database
|
||||
"""
|
||||
|
||||
def get_recommended_apps_and_categories(self, language: str) -> dict:
|
||||
result = self.fetch_recommended_apps_from_db(language)
|
||||
def get_pipeline_templates(self, language: str) -> dict:
|
||||
result = self.fetch_pipeline_templates_from_db(language)
|
||||
return result
|
||||
|
||||
def get_recommend_app_detail(self, app_id: str):
|
||||
result = self.fetch_recommended_app_detail_from_db(app_id)
|
||||
def get_pipeline_template_detail(self, pipeline_id: str):
|
||||
result = self.fetch_pipeline_template_detail_from_db(pipeline_id)
|
||||
return result
|
||||
|
||||
def get_type(self) -> str:
|
||||
return RecommendAppType.DATABASE
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_apps_from_db(cls, language: str) -> dict:
|
||||
def fetch_pipeline_templates_from_db(cls, language: str) -> dict:
|
||||
"""
|
||||
Fetch recommended apps from db.
|
||||
Fetch pipeline templates from db.
|
||||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
recommended_apps = (
|
||||
db.session.query(RecommendedApp)
|
||||
.filter(RecommendedApp.is_listed == True, RecommendedApp.language == language)
|
||||
.all()
|
||||
pipeline_templates = (
|
||||
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all()
|
||||
)
|
||||
|
||||
if len(recommended_apps) == 0:
|
||||
recommended_apps = (
|
||||
db.session.query(RecommendedApp)
|
||||
.filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
|
||||
.all()
|
||||
)
|
||||
|
||||
categories = set()
|
||||
recommended_apps_result = []
|
||||
for recommended_app in recommended_apps:
|
||||
app = recommended_app.app
|
||||
if not app or not app.is_public:
|
||||
continue
|
||||
|
||||
site = app.site
|
||||
if not site:
|
||||
continue
|
||||
|
||||
recommended_app_result = {
|
||||
"id": recommended_app.id,
|
||||
"app": recommended_app.app,
|
||||
"app_id": recommended_app.app_id,
|
||||
"description": site.description,
|
||||
"copyright": site.copyright,
|
||||
"privacy_policy": site.privacy_policy,
|
||||
"custom_disclaimer": site.custom_disclaimer,
|
||||
"category": recommended_app.category,
|
||||
"position": recommended_app.position,
|
||||
"is_listed": recommended_app.is_listed,
|
||||
}
|
||||
recommended_apps_result.append(recommended_app_result)
|
||||
|
||||
categories.add(recommended_app.category)
|
||||
|
||||
return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}
|
||||
return {"pipeline_templates": pipeline_templates}
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]:
|
||||
def fetch_pipeline_template_detail_from_db(cls, pipeline_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Fetch recommended app detail from db.
|
||||
:param app_id: App ID
|
||||
Fetch pipeline template detail from db.
|
||||
:param pipeline_id: Pipeline ID
|
||||
:return:
|
||||
"""
|
||||
# is in public recommended list
|
||||
recommended_app = (
|
||||
db.session.query(RecommendedApp)
|
||||
.filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id)
|
||||
.first()
|
||||
pipeline_template = (
|
||||
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first()
|
||||
)
|
||||
|
||||
if not recommended_app:
|
||||
if not pipeline_template:
|
||||
return None
|
||||
|
||||
# get app detail
|
||||
app_model = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app_model or not app_model.is_public:
|
||||
pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first()
|
||||
if not pipeline or not pipeline.is_public:
|
||||
return None
|
||||
|
||||
return {
|
||||
"id": app_model.id,
|
||||
"name": app_model.name,
|
||||
"icon": app_model.icon,
|
||||
"icon_background": app_model.icon_background,
|
||||
"mode": app_model.mode,
|
||||
"export_data": AppDslService.export_dsl(app_model=app_model),
|
||||
"id": pipeline.id,
|
||||
"name": pipeline.name,
|
||||
"icon": pipeline.icon,
|
||||
"mode": pipeline.mode,
|
||||
"export_data": AppDslService.export_dsl(app_model=pipeline),
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class PipelineTemplateRetrievalBase(ABC):
|
||||
|
@ -9,7 +10,7 @@ class PipelineTemplateRetrievalBase(ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_pipeline_template_detail(self, pipeline_id: str):
|
||||
def get_pipeline_template_detail(self, template_id: str) -> Optional[dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
from services.rag_pipeline.pipeline_template.built_in.built_in_retrieval import BuiltInPipelineTemplateRetrieval
|
||||
from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
|
||||
from services.rag_pipeline.pipeline_template.remote.remote_retrieval import RemotePipelineTemplateRetrieval
|
||||
|
||||
|
||||
|
||||
class RecommendAppRetrievalFactory:
|
||||
class PipelineTemplateRetrievalFactory:
|
||||
@staticmethod
|
||||
def get_pipeline_template_factory(mode: str) -> type[PipelineTemplateRetrievalBase]:
|
||||
match mode:
|
||||
|
@ -11,11 +13,13 @@ class RecommendAppRetrievalFactory:
|
|||
return RemotePipelineTemplateRetrieval
|
||||
case PipelineTemplateType.CUSTOMIZED:
|
||||
return DatabasePipelineTemplateRetrieval
|
||||
case PipelineTemplateType.BUILTIN:
|
||||
return BuildInPipelineTemplateRetrieval
|
||||
case PipelineTemplateType.DATABASE:
|
||||
return DatabasePipelineTemplateRetrieval
|
||||
case PipelineTemplateType.BUILT_IN:
|
||||
return BuiltInPipelineTemplateRetrieval
|
||||
case _:
|
||||
raise ValueError(f"invalid fetch recommended apps mode: {mode}")
|
||||
|
||||
@staticmethod
|
||||
def get_buildin_recommend_app_retrieval():
|
||||
return BuildInRecommendAppRetrieval
|
||||
def get_built_in_pipeline_template_retrieval():
|
||||
return BuiltInPipelineTemplateRetrieval
|
||||
|
|
|
@ -3,5 +3,6 @@ from enum import StrEnum
|
|||
|
||||
class PipelineTemplateType(StrEnum):
|
||||
REMOTE = "remote"
|
||||
BUILTIN = "builtin"
|
||||
DATABASE = "database"
|
||||
CUSTOMIZED = "customized"
|
||||
BUILTIN = "builtin"
|
||||
|
|
|
@ -4,9 +4,10 @@ from typing import Optional
|
|||
import requests
|
||||
|
||||
from configs import dify_config
|
||||
from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval
|
||||
from services.recommend_app.recommend_app_type import RecommendAppType
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
|
||||
from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -23,26 +24,26 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
|||
result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin(pipeline_id)
|
||||
return result
|
||||
|
||||
def get_recommended_apps_and_categories(self, language: str) -> dict:
|
||||
def get_pipeline_templates(self, language: str) -> dict:
|
||||
try:
|
||||
result = self.fetch_recommended_apps_from_dify_official(language)
|
||||
result = self.fetch_pipeline_templates_from_dify_official(language)
|
||||
except Exception as e:
|
||||
logger.warning(f"fetch recommended apps from dify official failed: {e}, switch to built-in.")
|
||||
logger.warning(f"fetch pipeline templates from dify official failed: {e}, switch to built-in.")
|
||||
result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin(language)
|
||||
return result
|
||||
|
||||
def get_type(self) -> str:
|
||||
return RecommendAppType.REMOTE
|
||||
return PipelineTemplateType.REMOTE
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optional[dict]:
|
||||
def fetch_pipeline_template_detail_from_dify_official(cls, pipeline_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Fetch recommended app detail from dify official.
|
||||
:param app_id: App ID
|
||||
Fetch pipeline template detail from dify official.
|
||||
:param pipeline_id: Pipeline ID
|
||||
:return:
|
||||
"""
|
||||
domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
|
||||
url = f"{domain}/apps/{app_id}"
|
||||
domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN
|
||||
url = f"{domain}/pipelines/{pipeline_id}"
|
||||
response = requests.get(url, timeout=(3, 10))
|
||||
if response.status_code != 200:
|
||||
return None
|
||||
|
@ -50,21 +51,18 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
|||
return data
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
|
||||
def fetch_pipeline_templates_from_dify_official(cls, language: str) -> dict:
|
||||
"""
|
||||
Fetch recommended apps from dify official.
|
||||
Fetch pipeline templates from dify official.
|
||||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
|
||||
url = f"{domain}/apps?language={language}"
|
||||
domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN
|
||||
url = f"{domain}/pipelines?language={language}"
|
||||
response = requests.get(url, timeout=(3, 10))
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}")
|
||||
raise ValueError(f"fetch pipeline templates failed, status code: {response.status_code}")
|
||||
|
||||
result: dict = response.json()
|
||||
|
||||
if "categories" in result:
|
||||
result["categories"] = sorted(result["categories"])
|
||||
|
||||
return result
|
||||
|
|
|
@ -1,52 +1,575 @@
|
|||
import datetime
|
||||
import hashlib
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, List, Literal, Union
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable, Generator, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.dataset import PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore
|
||||
from configs import dify_config
|
||||
from core.variables.variables import Variable
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.nodes.base.node import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event.types import NodeEvent
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.db import db
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore
|
||||
from models.workflow import Workflow, WorkflowNodeExecution, WorkflowType
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
|
||||
|
||||
|
||||
class RagPipelineService:
|
||||
@staticmethod
|
||||
def get_pipeline_templates(
|
||||
type: Literal["built-in", "customized"] = "built-in",
|
||||
type: Literal["built-in", "customized"] = "built-in", language: str = "en-US"
|
||||
) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]:
|
||||
if type == "built-in":
|
||||
return PipelineBuiltInTemplate.query.all()
|
||||
else:
|
||||
return PipelineCustomizedTemplate.query.all()
|
||||
|
||||
@staticmethod
|
||||
def get_pipeline_templates(cls, type: Literal["built-in", "customized"] = "built-in", language: str) -> dict:
|
||||
"""
|
||||
Get pipeline templates.
|
||||
:param type: type
|
||||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
|
||||
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
|
||||
result = retrieval_instance.get_recommended_apps_and_categories(language)
|
||||
if not result.get("recommended_apps") and language != "en-US":
|
||||
result = (
|
||||
RecommendAppRetrievalFactory.get_buildin_recommend_app_retrieval().fetch_recommended_apps_from_builtin(
|
||||
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)
|
||||
result = retrieval_instance.get_pipeline_templates(language)
|
||||
if not result.get("pipeline_templates") and language != "en-US":
|
||||
result = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval().fetch_pipeline_templates_from_builtin(
|
||||
"en-US"
|
||||
)
|
||||
)
|
||||
return result.get("pipeline_templates")
|
||||
else:
|
||||
mode = "customized"
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)
|
||||
result = retrieval_instance.get_pipeline_templates(language)
|
||||
return result.get("pipeline_templates")
|
||||
|
||||
@classmethod
|
||||
def get_pipeline_template_detail(cls, pipeline_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Get pipeline template detail.
|
||||
:param pipeline_id: pipeline id
|
||||
:return:
|
||||
"""
|
||||
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)
|
||||
result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(pipeline_id)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]:
|
||||
def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
|
||||
"""
|
||||
Get recommend app detail.
|
||||
:param app_id: app id
|
||||
Update pipeline template.
|
||||
:param template_id: template id
|
||||
:param template_info: template info
|
||||
"""
|
||||
customized_template: PipelineCustomizedTemplate | None = (
|
||||
db.query(PipelineCustomizedTemplate)
|
||||
.filter(
|
||||
PipelineCustomizedTemplate.id == template_id,
|
||||
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not customized_template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
customized_template.name = template_info.name
|
||||
customized_template.description = template_info.description
|
||||
customized_template.icon = template_info.icon_info.model_dump()
|
||||
db.commit()
|
||||
return customized_template
|
||||
|
||||
@classmethod
|
||||
def delete_customized_pipeline_template(cls, template_id: str):
|
||||
"""
|
||||
Delete customized pipeline template.
|
||||
"""
|
||||
customized_template: PipelineCustomizedTemplate | None = (
|
||||
db.query(PipelineCustomizedTemplate)
|
||||
.filter(
|
||||
PipelineCustomizedTemplate.id == template_id,
|
||||
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not customized_template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
db.delete(customized_template)
|
||||
db.commit()
|
||||
|
||||
|
||||
def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
# fetch draft workflow by rag pipeline
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.version == "draft",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# return draft workflow
|
||||
return workflow
|
||||
|
||||
def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
|
||||
"""
|
||||
Get published workflow
|
||||
"""
|
||||
|
||||
if not pipeline.workflow_id:
|
||||
return None
|
||||
|
||||
# fetch published workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.id == pipeline.workflow_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
return workflow
|
||||
|
||||
def get_all_published_workflow(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
pipeline: Pipeline,
|
||||
page: int,
|
||||
limit: int,
|
||||
user_id: str | None,
|
||||
named_only: bool = False,
|
||||
) -> tuple[Sequence[Workflow], bool]:
|
||||
"""
|
||||
Get published workflow with pagination
|
||||
"""
|
||||
if not pipeline.workflow_id:
|
||||
return [], False
|
||||
|
||||
stmt = (
|
||||
select(Workflow)
|
||||
.where(Workflow.app_id == pipeline.id)
|
||||
.order_by(Workflow.version.desc())
|
||||
.limit(limit + 1)
|
||||
.offset((page - 1) * limit)
|
||||
)
|
||||
|
||||
if user_id:
|
||||
stmt = stmt.where(Workflow.created_by == user_id)
|
||||
|
||||
if named_only:
|
||||
stmt = stmt.where(Workflow.marked_name != "")
|
||||
|
||||
workflows = session.scalars(stmt).all()
|
||||
|
||||
has_more = len(workflows) > limit
|
||||
if has_more:
|
||||
workflows = workflows[:-1]
|
||||
|
||||
return workflows, has_more
|
||||
|
||||
def sync_draft_workflow(
|
||||
self,
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
graph: dict,
|
||||
features: dict,
|
||||
unique_hash: Optional[str],
|
||||
account: Account,
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
) -> Workflow:
|
||||
"""
|
||||
Sync draft workflow
|
||||
:raises WorkflowHashNotEqualError
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
workflow = self.get_draft_workflow(pipeline=pipeline)
|
||||
|
||||
if workflow and workflow.unique_hash != unique_hash:
|
||||
raise WorkflowHashNotEqualError()
|
||||
|
||||
# validate features structure
|
||||
self.validate_features_structure(pipeline=pipeline, features=features)
|
||||
|
||||
# create draft workflow if not found
|
||||
if not workflow:
|
||||
workflow = Workflow(
|
||||
tenant_id=pipeline.tenant_id,
|
||||
app_id=pipeline.id,
|
||||
type=WorkflowType.RAG_PIPELINE.value,
|
||||
version="draft",
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps(features),
|
||||
created_by=account.id,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
db.session.add(workflow)
|
||||
# update draft workflow if found
|
||||
else:
|
||||
workflow.graph = json.dumps(graph)
|
||||
workflow.features = json.dumps(features)
|
||||
workflow.updated_by = account.id
|
||||
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow.environment_variables = environment_variables
|
||||
workflow.conversation_variables = conversation_variables
|
||||
|
||||
# commit db session changes
|
||||
db.session.commit()
|
||||
|
||||
# trigger app workflow events
|
||||
app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow)
|
||||
|
||||
# return draft workflow
|
||||
return workflow
|
||||
|
||||
def publish_workflow(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
pipeline: Pipeline,
|
||||
account: Account,
|
||||
marked_name: str = "",
|
||||
marked_comment: str = "",
|
||||
) -> Workflow:
|
||||
draft_workflow_stmt = select(Workflow).where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.version == "draft",
|
||||
)
|
||||
draft_workflow = session.scalar(draft_workflow_stmt)
|
||||
if not draft_workflow:
|
||||
raise ValueError("No valid workflow found.")
|
||||
|
||||
# create new workflow
|
||||
workflow = Workflow.new(
|
||||
tenant_id=pipeline.tenant_id,
|
||||
app_id=pipeline.id,
|
||||
type=draft_workflow.type,
|
||||
version=str(datetime.now(UTC).replace(tzinfo=None)),
|
||||
graph=draft_workflow.graph,
|
||||
features=draft_workflow.features,
|
||||
created_by=account.id,
|
||||
environment_variables=draft_workflow.environment_variables,
|
||||
conversation_variables=draft_workflow.conversation_variables,
|
||||
marked_name=marked_name,
|
||||
marked_comment=marked_comment,
|
||||
)
|
||||
|
||||
# commit db session changes
|
||||
session.add(workflow)
|
||||
|
||||
# trigger app workflow events
|
||||
app_published_workflow_was_updated.send(pipeline, published_workflow=workflow)
|
||||
|
||||
# return new workflow
|
||||
return workflow
|
||||
|
||||
def get_default_block_configs(self) -> list[dict]:
|
||||
"""
|
||||
Get default block configs
|
||||
"""
|
||||
# return default block config
|
||||
default_block_configs = []
|
||||
for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
|
||||
node_class = node_class_mapping[LATEST_VERSION]
|
||||
default_config = node_class.get_default_config()
|
||||
if default_config:
|
||||
default_block_configs.append(default_config)
|
||||
|
||||
return default_block_configs
|
||||
|
||||
def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param node_type: node type
|
||||
:param filters: filter by node config parameters.
|
||||
:return:
|
||||
"""
|
||||
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
|
||||
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
|
||||
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
|
||||
return result
|
||||
node_type_enum = NodeType(node_type)
|
||||
|
||||
# return default block config
|
||||
if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
|
||||
return None
|
||||
|
||||
node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
|
||||
default_config = node_class.get_default_config(filters=filters)
|
||||
if not default_config:
|
||||
return None
|
||||
|
||||
return default_config
|
||||
|
||||
def run_draft_workflow_node(
|
||||
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
draft_workflow = self.get_draft_workflow(pipeline=pipeline)
|
||||
if not draft_workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
# run draft workflow node
|
||||
start_at = time.perf_counter()
|
||||
|
||||
workflow_node_execution = self._handle_node_run_result(
|
||||
getter=lambda: WorkflowEntry.single_step_run(
|
||||
workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
user_id=account.id,
|
||||
),
|
||||
start_at=start_at,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
workflow_node_execution.app_id = pipeline.id
|
||||
workflow_node_execution.created_by = account.id
|
||||
workflow_node_execution.workflow_id = draft_workflow.id
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def run_free_workflow_node(
|
||||
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
# run draft workflow node
|
||||
start_at = time.perf_counter()
|
||||
|
||||
workflow_node_execution = self._handle_node_run_result(
|
||||
getter=lambda: WorkflowEntry.run_free_node(
|
||||
node_id=node_id,
|
||||
node_data=node_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
user_inputs=user_inputs,
|
||||
),
|
||||
start_at=start_at,
|
||||
tenant_id=tenant_id,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_node_run_result(
|
||||
self,
|
||||
getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
|
||||
start_at: float,
|
||||
tenant_id: str,
|
||||
node_id: str,
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Handle node run result
|
||||
|
||||
:param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]]
|
||||
:param start_at: float
|
||||
:param tenant_id: str
|
||||
:param node_id: str
|
||||
"""
|
||||
try:
|
||||
node_instance, generator = getter()
|
||||
|
||||
node_run_result: NodeRunResult | None = None
|
||||
for event in generator:
|
||||
if isinstance(event, RunCompletedEvent):
|
||||
node_run_result = event.run_result
|
||||
|
||||
# sign output files
|
||||
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
|
||||
break
|
||||
|
||||
if not node_run_result:
|
||||
raise ValueError("Node run failed with no run result")
|
||||
# single step debug mode error handling return
|
||||
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
|
||||
node_error_args: dict[str, Any] = {
|
||||
"status": WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
"error": node_run_result.error,
|
||||
"inputs": node_run_result.inputs,
|
||||
"metadata": {"error_strategy": node_instance.node_data.error_strategy},
|
||||
}
|
||||
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
|
||||
node_run_result = NodeRunResult(
|
||||
**node_error_args,
|
||||
outputs={
|
||||
**node_instance.node_data.default_value_dict,
|
||||
"error_message": node_run_result.error,
|
||||
"error_type": node_run_result.error_type,
|
||||
},
|
||||
)
|
||||
else:
|
||||
node_run_result = NodeRunResult(
|
||||
**node_error_args,
|
||||
outputs={
|
||||
"error_message": node_run_result.error,
|
||||
"error_type": node_run_result.error_type,
|
||||
},
|
||||
)
|
||||
run_succeeded = node_run_result.status in (
|
||||
WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
)
|
||||
error = node_run_result.error if not run_succeeded else None
|
||||
except WorkflowNodeRunFailedError as e:
|
||||
node_instance = e.node_instance
|
||||
run_succeeded = False
|
||||
node_run_result = None
|
||||
error = e.error
|
||||
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.id = str(uuid4())
|
||||
workflow_node_execution.tenant_id = tenant_id
|
||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
|
||||
workflow_node_execution.index = 1
|
||||
workflow_node_execution.node_id = node_id
|
||||
workflow_node_execution.node_type = node_instance.node_type
|
||||
workflow_node_execution.title = node_instance.node_data.title
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value
|
||||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
if run_succeeded and node_run_result:
|
||||
# create workflow node execution
|
||||
inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
|
||||
process_data = (
|
||||
WorkflowEntry.handle_special_values(node_run_result.process_data)
|
||||
if node_run_result.process_data
|
||||
else None
|
||||
)
|
||||
outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
|
||||
|
||||
workflow_node_execution.inputs = json.dumps(inputs)
|
||||
workflow_node_execution.process_data = json.dumps(process_data)
|
||||
workflow_node_execution.outputs = json.dumps(outputs)
|
||||
workflow_node_execution.execution_metadata = (
|
||||
json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
|
||||
)
|
||||
if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value
|
||||
workflow_node_execution.error = node_run_result.error
|
||||
else:
|
||||
# create workflow node execution
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
|
||||
"""
|
||||
Basic mode of chatbot app(expert mode) to workflow
|
||||
Completion App to Workflow App
|
||||
|
||||
:param app_model: App instance
|
||||
:param account: Account instance
|
||||
:param args: dict
|
||||
:return:
|
||||
"""
|
||||
# chatbot convert to workflow mode
|
||||
workflow_converter = WorkflowConverter()
|
||||
|
||||
if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}:
|
||||
raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")
|
||||
|
||||
# convert to workflow
|
||||
new_app: App = workflow_converter.convert_to_workflow(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
name=args.get("name", "Default Name"),
|
||||
icon_type=args.get("icon_type", "emoji"),
|
||||
icon=args.get("icon", "🤖"),
|
||||
icon_background=args.get("icon_background", "#FFEAD5"),
|
||||
)
|
||||
|
||||
return new_app
|
||||
|
||||
def validate_features_structure(self, app_model: App, features: dict) -> dict:
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
return AdvancedChatAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
return WorkflowAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid app mode: {app_model.mode}")
|
||||
|
||||
def update_workflow(
|
||||
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
|
||||
) -> Optional[Workflow]:
|
||||
"""
|
||||
Update workflow attributes
|
||||
|
||||
:param session: SQLAlchemy database session
|
||||
:param workflow_id: Workflow ID
|
||||
:param tenant_id: Tenant ID
|
||||
:param account_id: Account ID (for permission check)
|
||||
:param data: Dictionary containing fields to update
|
||||
:return: Updated workflow or None if not found
|
||||
"""
|
||||
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
|
||||
workflow = session.scalar(stmt)
|
||||
|
||||
if not workflow:
|
||||
return None
|
||||
|
||||
allowed_fields = ["marked_name", "marked_comment"]
|
||||
|
||||
for field, value in data.items():
|
||||
if field in allowed_fields:
|
||||
setattr(workflow, field, value)
|
||||
|
||||
workflow.updated_by = account_id
|
||||
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
return workflow
|
||||
|
||||
def delete_workflow(self, *, session: Session, workflow_id: str, tenant_id: str) -> bool:
|
||||
"""
|
||||
Delete a workflow
|
||||
|
||||
:param session: SQLAlchemy database session
|
||||
:param workflow_id: Workflow ID
|
||||
:param tenant_id: Tenant ID
|
||||
:return: True if successful
|
||||
:raises: ValueError if workflow not found
|
||||
:raises: WorkflowInUseError if workflow is in use
|
||||
:raises: DraftWorkflowDeletionError if workflow is a draft version
|
||||
"""
|
||||
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
|
||||
workflow = session.scalar(stmt)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow with ID {workflow_id} not found")
|
||||
|
||||
# Check if workflow is a draft version
|
||||
if workflow.version == "draft":
|
||||
raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")
|
||||
|
||||
# Check if this workflow is currently referenced by an app
|
||||
stmt = select(App).where(App.workflow_id == workflow_id)
|
||||
app = session.scalar(stmt)
|
||||
if app:
|
||||
# Cannot delete a workflow that's currently in use by an app
|
||||
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")
|
||||
|
||||
session.delete(workflow)
|
||||
return True
|
||||
|
|
Loading…
Reference in New Issue