From 5c4bf2a9e451643a8d72d70f3443cc6c2d4240e2 Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Thu, 17 Apr 2025 15:07:23 +0800
Subject: [PATCH] r2

---
 api/controllers/console/__init__.py            |   2 +-
 .../datasets/rag_pipeline/rag_pipeline.py      | 136 +++
 .../rag_pipeline_workflow.py}                  | 199 ++-
 api/core/datasource/__base/tool.py             | 222 +++++
 api/core/datasource/__base/tool_provider.py    | 109 +++
 api/core/datasource/__base/tool_runtime.py     |  36 +
 api/core/datasource/__init__.py                |   0
 .../datasource/entities/agent_entities.py      |   0
 api/core/datasource/entities/api_entities.py   |  72 ++
 .../datasource/entities/common_entities.py     |  23 +
 api/core/datasource/entities/constants.py      |   1 +
 api/core/datasource/entities/file_entities.py  |   1 +
 api/core/datasource/entities/tool_bundle.py    |  29 +
 api/core/datasource/entities/tool_entities.py  | 427 +++++++++
 api/core/datasource/entities/values.py         | 111 +++
 api/core/datasource/errors.py                  |  37 +
 api/core/datasource/plugin_tool/provider.py    |  79 ++
 api/core/datasource/plugin_tool/tool.py        |  89 ++
 api/core/datasource/tool_engine.py             | 357 +++++++
 api/core/datasource/tool_file_manager.py       | 234 +++++
 api/core/datasource/tool_label_manager.py      | 101 ++
 api/core/datasource/tool_manager.py            | 870 ++++++++++++++++++
 api/core/datasource/utils/__init__.py          |   0
 api/core/datasource/utils/configuration.py     | 265 ++++++
 .../dataset_multi_retriever_tool.py            | 199 ++++
 .../dataset_retriever_base_tool.py             |  33 +
 .../dataset_retriever_tool.py                  | 202 ++++
 .../utils/dataset_retriever_tool.py            | 134 +++
 .../datasource/utils/message_transformer.py    | 121 +++
 .../utils/model_invocation_utils.py            | 169 ++++
 api/core/datasource/utils/parser.py            | 389 ++++++++
 api/core/datasource/utils/rag_web_reader.py    |  17 +
 .../utils/text_processing_utils.py             |  17 +
 api/core/datasource/utils/uuid_utils.py        |   9 +
 api/core/datasource/utils/web_reader_tool.py   | 375 ++++++++
 .../utils/workflow_configuration_sync.py       |  43 +
 api/core/datasource/utils/yaml_utils.py        |  35 +
 api/core/workflow/constants.py                 |   1 +
 api/core/workflow/entities/node_entities.py    |   1 +
 .../workflow/nodes/datasource/__init__.py      |   3 +
 .../nodes/datasource/datasource_node.py        | 406 ++++++++
 .../workflow/nodes/datasource/entities.py      |  56 ++
 api/core/workflow/nodes/datasource/exc.py      |  16 +
 api/core/workflow/nodes/enums.py               |   1 +
 api/core/workflow/nodes/tool/tool_node.py      |   2 +-
 api/factories/variable_factory.py              |  10 +-
 api/fields/workflow_fields.py                  |  11 +
 api/models/workflow.py                         |  21 +
 api/services/rag_pipeline/rag_pipeline.py      |  60 +-
 49 files changed, 5609 insertions(+), 122 deletions(-)
 create mode 100644 api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
 rename api/controllers/console/datasets/{pipeline.py => rag_pipeline/rag_pipeline_workflow.py} (86%)
 create mode 100644 api/core/datasource/__base/tool.py
 create mode 100644 api/core/datasource/__base/tool_provider.py
 create mode 100644 api/core/datasource/__base/tool_runtime.py
 create mode 100644 api/core/datasource/__init__.py
 create mode 100644 api/core/datasource/entities/agent_entities.py
 create mode 100644 api/core/datasource/entities/api_entities.py
 create mode 100644 api/core/datasource/entities/common_entities.py
 create mode 100644 api/core/datasource/entities/constants.py
 create mode 100644 api/core/datasource/entities/file_entities.py
 create mode 100644 api/core/datasource/entities/tool_bundle.py
 create mode 100644 api/core/datasource/entities/tool_entities.py
 create mode 100644 api/core/datasource/entities/values.py
 create mode 100644 api/core/datasource/errors.py
 create mode 100644 api/core/datasource/plugin_tool/provider.py
 create mode 100644 api/core/datasource/plugin_tool/tool.py
 create mode 100644 api/core/datasource/tool_engine.py
 create mode 100644 api/core/datasource/tool_file_manager.py
 create mode 100644 api/core/datasource/tool_label_manager.py
 create mode 100644 api/core/datasource/tool_manager.py
 create mode 100644 api/core/datasource/utils/__init__.py
 create mode 100644 api/core/datasource/utils/configuration.py
 create mode 100644 api/core/datasource/utils/dataset_retriever/dataset_multi_retriever_tool.py
 create mode 100644 api/core/datasource/utils/dataset_retriever/dataset_retriever_base_tool.py
 create mode 100644 api/core/datasource/utils/dataset_retriever/dataset_retriever_tool.py
 create mode 100644 api/core/datasource/utils/dataset_retriever_tool.py
 create mode 100644 api/core/datasource/utils/message_transformer.py
 create mode 100644 api/core/datasource/utils/model_invocation_utils.py
 create mode 100644 api/core/datasource/utils/parser.py
 create mode 100644 api/core/datasource/utils/rag_web_reader.py
 create mode 100644 api/core/datasource/utils/text_processing_utils.py
 create mode 100644 api/core/datasource/utils/uuid_utils.py
 create mode 100644 api/core/datasource/utils/web_reader_tool.py
 create mode 100644 api/core/datasource/utils/workflow_configuration_sync.py
 create mode 100644 api/core/datasource/utils/yaml_utils.py
 create mode 100644 api/core/workflow/nodes/datasource/__init__.py
 create mode 100644 api/core/workflow/nodes/datasource/datasource_node.py
 create mode 100644 api/core/workflow/nodes/datasource/entities.py
 create mode 100644 api/core/workflow/nodes/datasource/exc.py

diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index a974c63e35..74e5da9435 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -1,5 +1,6 @@
 from flask import Blueprint

+from .datasets.rag_pipeline import data_source
 from libs.external_api import ExternalApi

 from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
@@ -75,7 +76,6 @@ from .billing import billing, compliance

 # Import datasets controllers
 from .datasets import (
-    data_source,
     datasets,
     datasets_document,
     datasets_segments,
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
new file mode 100644
index 0000000000..4ff2f07bb6
--- /dev/null
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
@@ -0,0 +1,136 @@
+import json
+import logging
+from typing import cast
+
+from flask import abort, request
+from flask_restful import Resource, inputs, marshal_with, reqparse  # type: ignore
+from sqlalchemy.orm import Session
+from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
+
+import services
+from configs import dify_config
+from controllers.console import api
+from controllers.console.app.error import (
+    ConversationCompletedError,
+    DraftWorkflowNotExist,
+    DraftWorkflowNotSync,
+)
+from controllers.console.app.wraps import get_app_model
+from controllers.console.datasets.wraps import get_rag_pipeline
+from controllers.console.wraps import (
+    account_initialization_required,
+    enterprise_license_required,
+    setup_required,
+)
+from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
+from core.app.apps.base_app_queue_manager import AppQueueManager
+from core.app.entities.app_invoke_entities import InvokeFrom
+from extensions.ext_database import db
+from factories import variable_factory
+from fields.workflow_fields import workflow_fields, workflow_pagination_fields
+from fields.workflow_run_fields import workflow_run_node_execution_fields
+from libs import helper
+from libs.helper import TimestampField
+from libs.login import current_user, login_required
+from models import App
+from models.account import Account
+from models.dataset import Pipeline
+from models.model import AppMode
+from services.app_generate_service import AppGenerateService
+from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
+from services.errors.app import WorkflowHashNotEqualError
+from services.errors.llm import InvokeRateLimitError
+from services.rag_pipeline.rag_pipeline import RagPipelineService
+from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
+
+logger = logging.getLogger(__name__)
+
+
+def _validate_name(name):
+    if not name or len(name) < 1 or len(name) > 40:
+        raise ValueError("Name must be between 1 and 40 characters.")
+    return name
+
+
+def _validate_description_length(description):
+    if len(description) > 400:
+        raise ValueError("Description cannot exceed 400 characters.")
+    return description
+
+
+class PipelineTemplateListApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @enterprise_license_required
+    def get(self):
+        type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"])
+        language = request.args.get("language", default="en-US", type=str)
+        # get pipeline templates
+        pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
+        return pipeline_templates, 200
+
+
+class PipelineTemplateDetailApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @enterprise_license_required
+    def get(self, pipeline_id: str):
+        pipeline_template = RagPipelineService.get_pipeline_template_detail(pipeline_id)
+        return pipeline_template, 200
+
+
+class CustomizedPipelineTemplateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @enterprise_license_required
+    def patch(self, template_id: str):
+        parser = reqparse.RequestParser()
+        parser.add_argument(
+            "name",
+            nullable=False,
+            required=True,
+            help="Name must be between 1 and 40 characters.",
+            type=_validate_name,
+        )
+        parser.add_argument(
+            "description",
+            type=str,
+            nullable=True,
+            required=False,
+            default="",
+        )
+        parser.add_argument(
+            "icon_info",
+            type=dict,
+            location="json",
+            nullable=True,
+        )
+        args = parser.parse_args()
+        pipeline_template_info = PipelineTemplateInfoEntity(**args)
+        pipeline_template = RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
+        return pipeline_template, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @enterprise_license_required
+    def delete(self, template_id: str):
+        RagPipelineService.delete_customized_pipeline_template(template_id)
+        return 200
+
+
+api.add_resource(
+    PipelineTemplateListApi,
+    "/rag/pipeline/templates",
+)
+api.add_resource(
+    PipelineTemplateDetailApi,
+    "/rag/pipeline/templates/<string:pipeline_id>",
+)
+api.add_resource(
+    CustomizedPipelineTemplateApi,
+    "/rag/pipeline/templates/<string:template_id>",
+)
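For orientation, the three routes registered above can be exercised as in the sketch below. The console mount point, host, and auth header are placeholders (this patch does not define them), and `my-template-id` stands in for a real customized template id.

    import requests

    BASE = "http://localhost:5001/console/api"  # placeholder console API root
    HEADERS = {"Authorization": "Bearer <console-access-token>"}  # placeholder auth

    # list built-in pipeline templates
    templates = requests.get(
        f"{BASE}/rag/pipeline/templates",
        params={"type": "built-in", "language": "en-US"},
        headers=HEADERS,
    ).json()

    # rename a customized template (name must be 1-40 characters)
    requests.patch(
        f"{BASE}/rag/pipeline/templates/my-template-id",
        json={"name": "Weekly crawl", "description": "", "icon_info": {"icon": "book"}},
        headers=HEADERS,
    )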
diff --git a/api/controllers/console/datasets/pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
similarity index 86%
rename from api/controllers/console/datasets/pipeline.py
rename to api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index 72e819fa12..d33531b447 100644
--- a/api/controllers/console/datasets/pipeline.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -15,11 +15,9 @@ from controllers.console.app.error import (
     DraftWorkflowNotExist,
     DraftWorkflowNotSync,
 )
-from controllers.console.app.wraps import get_app_model
 from controllers.console.datasets.wraps import get_rag_pipeline
 from controllers.console.wraps import (
     account_initialization_required,
-    enterprise_license_required,
     setup_required,
 )
 from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
@@ -32,96 +30,17 @@ from fields.workflow_run_fields import workflow_run_node_execution_fields
 from libs import helper
 from libs.helper import TimestampField
 from libs.login import current_user, login_required
-from models import App
 from models.account import Account
 from models.dataset import Pipeline
-from models.model import AppMode
 from services.app_generate_service import AppGenerateService
-from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
 from services.errors.app import WorkflowHashNotEqualError
 from services.errors.llm import InvokeRateLimitError
 from services.rag_pipeline.rag_pipeline import RagPipelineService
-from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
+from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError

 logger = logging.getLogger(__name__)


-def _validate_name(name):
-    if not name or len(name) < 1 or len(name) > 40:
-        raise ValueError("Name must be between 1 to 40 characters.")
-    return name
-
-
-def _validate_description_length(description):
-    if len(description) > 400:
-        raise ValueError("Description cannot exceed 400 characters.")
-    return description
-
-
-class PipelineTemplateListApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @enterprise_license_required
-    def get(self):
-        type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"])
-        language = request.args.get("language", default="en-US", type=str)
-        # get pipeline templates
-        pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
-        return pipeline_templates, 200
-
-
-class PipelineTemplateDetailApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @enterprise_license_required
-    def get(self, pipeline_id: str):
-        pipeline_template = RagPipelineService.get_pipeline_template_detail(pipeline_id)
-        return pipeline_template, 200
-
-
-class CustomizedPipelineTemplateApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @enterprise_license_required
-    def patch(self, template_id: str):
-        parser = reqparse.RequestParser()
-        parser.add_argument(
-            "name",
-            nullable=False,
-            required=True,
-            help="Name must be between 1 to 40 characters.",
-            type=_validate_name,
-        )
-        parser.add_argument(
-            "description",
-            type=str,
-            nullable=True,
-            required=False,
-            default="",
-        )
-        parser.add_argument(
-            "icon_info",
-            type=dict,
-            location="json",
-            nullable=True,
-        )
-        args = parser.parse_args()
-        pipeline_template_info = PipelineTemplateInfoEntity(**args)
-        pipeline_template = RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
-        return pipeline_template, 200
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @enterprise_license_required
-    def delete(self, template_id: str):
-        RagPipelineService.delete_customized_pipeline_template(template_id)
-        return 200
-
-
 class DraftRagPipelineApi(Resource):
     @setup_required
     @login_required
@@ -130,7 +49,7 @@ class DraftRagPipelineApi(Resource):
     @marshal_with(workflow_fields)
     def get(self, pipeline: Pipeline):
         """
-        Get draft workflow
+        Get draft rag pipeline's workflow
         """
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
@@ -167,6 +86,7 @@ class DraftRagPipelineApi(Resource):
             parser.add_argument("hash", type=str, required=False, location="json")
             parser.add_argument("environment_variables", type=list, required=False, location="json")
             parser.add_argument("conversation_variables", type=list, required=False, location="json")
+            parser.add_argument("pipeline_variables", type=dict, required=False, location="json")
             args = parser.parse_args()
         elif "text/plain" in content_type:
             try:
@@ -183,6 +103,7 @@ class DraftRagPipelineApi(Resource):
                     "hash": data.get("hash"),
                     "environment_variables": data.get("environment_variables"),
                     "conversation_variables": data.get("conversation_variables"),
+                    "pipeline_variables": data.get("pipeline_variables"),
                 }
             except json.JSONDecodeError:
                 return {"message": "Invalid JSON data"}, 400
@@ -192,8 +113,6 @@ class DraftRagPipelineApi(Resource):
         if not isinstance(current_user, Account):
             raise Forbidden()

-        workflow_service = WorkflowService()
-
         try:
             environment_variables_list = args.get("environment_variables") or []
             environment_variables = [
@@ -203,6 +122,11 @@ class DraftRagPipelineApi(Resource):
             conversation_variables = [
                 variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
             ]
+            pipeline_variables_list = args.get("pipeline_variables") or {}
+            pipeline_variables = {
+                k: [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in v]
+                for k, v in pipeline_variables_list.items()
+            }
             rag_pipeline_service = RagPipelineService()
             workflow = rag_pipeline_service.sync_draft_workflow(
                 pipeline=pipeline,
@@ -212,6 +136,7 @@ class DraftRagPipelineApi(Resource):
                 account=current_user,
                 environment_variables=environment_variables,
                 conversation_variables=conversation_variables,
+                pipeline_variables=pipeline_variables,
             )
         except WorkflowHashNotEqualError:
             raise DraftWorkflowNotSync()
@@ -263,8 +188,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @get_app_model(mode=[AppMode.WORKFLOW])
-    def post(self, app_model: App, node_id: str):
+    @get_rag_pipeline
+    def post(self, pipeline: Pipeline, node_id: str):
         """
         Run draft workflow loop node
         """
@@ -281,7 +206,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):

         try:
             response = AppGenerateService.generate_single_loop(
-                app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
+                pipeline=pipeline, user=current_user, node_id=node_id, args=args, streaming=True
             )

             return helper.compact_generate_response(response)
@@ -300,8 +225,8 @@ class DraftRagPipelineRunApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @get_app_model(mode=[AppMode.WORKFLOW])
-    def post(self, app_model: App):
+    @get_rag_pipeline
+    def post(self, pipeline: Pipeline):
         """
         Run draft workflow
         """
@@ -319,7 +244,7 @@ class DraftRagPipelineRunApi(Resource):

         try:
             response = AppGenerateService.generate(
-                app_model=app_model,
+                pipeline=pipeline,
                 user=current_user,
                 args=args,
                 invoke_from=InvokeFrom.DEBUGGER,
@@ -330,32 +255,45 @@ class DraftRagPipelineRunApi(Resource):
         except InvokeRateLimitError as ex:
             raise InvokeRateLimitHttpError(ex.description)

-
-class RagPipelineTaskStopApi(Resource):
+class RagPipelineDatasourceNodeRunApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
-    def post(self, app_model: App, task_id: str):
+    @get_rag_pipeline
+    def post(self, pipeline: Pipeline, node_id: str):
         """
-        Stop workflow task
+        Run rag pipeline datasource
         """
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()

-        AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
+        if not isinstance(current_user, Account):
+            raise Forbidden()

-        return {"result": "success"}
+        parser = reqparse.RequestParser()
+        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+        args = parser.parse_args()
+
+        inputs = args.get("inputs")
+        if inputs is None:
+            raise ValueError("missing inputs")
+
+        rag_pipeline_service = RagPipelineService()
+        workflow_node_execution = rag_pipeline_service.run_datasource_workflow_node(
+            pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user
+        )
+
+        return workflow_node_execution


-class RagPipelineNodeRunApi(Resource):
+class RagPipelineDraftNodeRunApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    @get_rag_pipeline
     @marshal_with(workflow_run_node_execution_fields)
-    def post(self, app_model: App, node_id: str):
+    def post(self, pipeline: Pipeline, node_id: str):
         """
         Run draft workflow node
         """
@@ -374,13 +312,29 @@

         inputs = args.get("inputs")
         if inputs == None:
             raise ValueError("missing inputs")

-        workflow_service = WorkflowService()
-        workflow_node_execution = workflow_service.run_draft_workflow_node(
-            app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user
+        rag_pipeline_service = RagPipelineService()
+        workflow_node_execution = rag_pipeline_service.run_draft_workflow_node(
+            pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user
         )

         return workflow_node_execution

+class RagPipelineTaskStopApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_rag_pipeline
+    def post(self, pipeline: Pipeline, task_id: str):
+        """
+        Stop workflow task
+        """
+        # The role of the current user in the ta table must be admin, owner, or editor
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
+
+        return {"result": "success"}

 class PublishedRagPipelineApi(Resource):
     @setup_required
@@ -695,6 +649,25 @@ class RagPipelineByIdApi(Resource):

         return None, 204

+class RagPipelineSecondStepApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_rag_pipeline
+    def get(self, pipeline: Pipeline):
+        """
+        Get second step parameters of rag pipeline
+        """
+        # The role of the current user in the ta table must be admin, owner, or editor
+        if not current_user.is_editor:
+            raise Forbidden()
+        datasource_provider = request.args.get("datasource_provider", required=True, type=str)
+
+        rag_pipeline_service = RagPipelineService()
+        return rag_pipeline_service.get_second_step_parameters(
+            pipeline=pipeline, datasource_provider=datasource_provider
+        )
+
 api.add_resource(
     DraftRagPipelineApi,
@@ -713,9 +686,13 @@ api.add_resource(
     "/rag/pipelines/<string:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop",
 )
 api.add_resource(
-    RagPipelineNodeRunApi,
+    RagPipelineDraftNodeRunApi,
     "/rag/pipelines/<string:pipeline_id>/workflows/draft/nodes/<string:node_id>/run",
 )
+api.add_resource(
+    RagPipelinePublishedNodeRunApi,
+    "/rag/pipelines/<string:pipeline_id>/workflows/published/nodes/<string:node_id>/run",
+)

 api.add_resource(
     RagPipelineDraftRunIterationNodeApi,
@@ -751,15 +728,3 @@ api.add_resource(
     "/rag/pipelines/<string:pipeline_id>/workflows/<string:workflow_id>",
 )

-api.add_resource(
-    PipelineTemplateListApi,
-    "/rag/pipeline/templates",
-)
-api.add_resource(
-    PipelineTemplateDetailApi,
-    "/rag/pipeline/templates/<string:pipeline_id>",
-)
-api.add_resource(
-    CustomizedPipelineTemplateApi,
-    "/rag/pipeline/templates/<string:template_id>",
-)
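The new `pipeline_variables` field accepted by the draft-sync endpoint above is a mapping from a key to a list of variable mappings, each fed through `variable_factory.build_pipeline_variable_from_mapping`. The keys that factory expects are not shown in this patch, so the inner dicts in this sketch of a request body are illustrative only:

    # sketch of a draft-sync request body; inner variable mappings are hypothetical
    payload = {
        "graph": {"nodes": [], "edges": []},
        "features": {},
        "environment_variables": [],
        "conversation_variables": [],
        "pipeline_variables": {
            "shared": [  # grouping key, e.g. a node id or scope name
                {"variable": "chunk_size", "value": 512},
            ],
        },
    }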
diff --git a/api/core/datasource/__base/tool.py b/api/core/datasource/__base/tool.py
new file mode 100644
index 0000000000..35e16b5c8f
--- /dev/null
+++ b/api/core/datasource/__base/tool.py
@@ -0,0 +1,222 @@
+from abc import ABC, abstractmethod
+from collections.abc import Generator
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Optional
+
+if TYPE_CHECKING:
+    from models.model import File
+
+from core.tools.__base.tool_runtime import ToolRuntime
+from core.tools.entities.tool_entities import (
+    ToolEntity,
+    ToolInvokeMessage,
+    ToolParameter,
+    ToolProviderType,
+)
+
+
+class Tool(ABC):
+    """
+    The base class of a tool
+    """
+
+    entity: ToolEntity
+    runtime: ToolRuntime
+
+    def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
+        self.entity = entity
+        self.runtime = runtime
+
+    def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
+        """
+        fork a new tool with metadata
+        :return: the new tool
+        """
+        return self.__class__(
+            entity=self.entity.model_copy(),
+            runtime=runtime,
+        )
+
+    @abstractmethod
+    def tool_provider_type(self) -> ToolProviderType:
+        """
+        get the tool provider type
+
+        :return: the tool provider type
+        """
+
+    def invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> Generator[ToolInvokeMessage]:
+        if self.runtime and self.runtime.runtime_parameters:
+            tool_parameters.update(self.runtime.runtime_parameters)
+
+        # try parse tool parameters into the correct type
+        tool_parameters = self._transform_tool_parameters_type(tool_parameters)
+
+        result = self._invoke(
+            user_id=user_id,
+            tool_parameters=tool_parameters,
+            conversation_id=conversation_id,
+            app_id=app_id,
+            message_id=message_id,
+        )
+
+        if isinstance(result, ToolInvokeMessage):
+
+            def single_generator() -> Generator[ToolInvokeMessage, None, None]:
+                yield result
+
+            return single_generator()
+        elif isinstance(result, list):
+
+            def generator() -> Generator[ToolInvokeMessage, None, None]:
+                yield from result
+
+            return generator()
+        else:
+            return result
+
+    def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
+        """
+        Transform tool parameters type
+        """
+        # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
+        result = deepcopy(tool_parameters)
+        for parameter in self.entity.parameters or []:
+            if parameter.name in tool_parameters:
+                result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
+
+        return result
+
+    @abstractmethod
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]:
+        pass
+
+    def get_runtime_parameters(
+        self,
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> list[ToolParameter]:
+        """
+        get the runtime parameters
+
+        interface for developers to dynamically change the parameters of a tool depending on the variables pool
+
+        :return: the runtime parameters
+        """
+        return self.entity.parameters
+
+    def get_merged_runtime_parameters(
+        self,
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> list[ToolParameter]:
+        """
+        get merged runtime parameters
+
+        :return: merged runtime parameters
+        """
+        parameters = self.entity.parameters
+        parameters = parameters.copy()
+        user_parameters = self.get_runtime_parameters() or []
+        user_parameters = user_parameters.copy()
+
+        # override parameters
+        for parameter in user_parameters:
+            # check if parameter in tool parameters
+            for tool_parameter in parameters:
+                if tool_parameter.name == parameter.name:
+                    # override parameter
+                    tool_parameter.type = parameter.type
+                    tool_parameter.form = parameter.form
+                    tool_parameter.required = parameter.required
+                    tool_parameter.default = parameter.default
+                    tool_parameter.options = parameter.options
+                    tool_parameter.llm_description = parameter.llm_description
+                    break
+            else:
+                # add new parameter
+                parameters.append(parameter)
+
+        return parameters
+
+    def create_image_message(
+        self,
+        image: str,
+    ) -> ToolInvokeMessage:
+        """
+        create an image message
+
+        :param image: the url of the image
+        :return: the image message
+        """
+        return ToolInvokeMessage(
+            type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
+        )
+
+    def create_file_message(self, file: "File") -> ToolInvokeMessage:
+        return ToolInvokeMessage(
+            type=ToolInvokeMessage.MessageType.FILE,
+            message=ToolInvokeMessage.FileMessage(),
+            meta={"file": file},
+        )
+
+    def create_link_message(self, link: str) -> ToolInvokeMessage:
+        """
+        create a link message
+
+        :param link: the url of the link
+        :return: the link message
+        """
+        return ToolInvokeMessage(
+            type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link)
+        )
+
+    def create_text_message(self, text: str) -> ToolInvokeMessage:
+        """
+        create a text message
+
+        :param text: the text
+        :return: the text message
+        """
+        return ToolInvokeMessage(
+            type=ToolInvokeMessage.MessageType.TEXT,
+            message=ToolInvokeMessage.TextMessage(text=text),
+        )
+
+    def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage:
+        """
+        create a blob message
+
+        :param blob: the blob
+        :param meta: the meta info of blob object
+        :return: the blob message
+        """
+        return ToolInvokeMessage(
+            type=ToolInvokeMessage.MessageType.BLOB,
+            message=ToolInvokeMessage.BlobMessage(blob=blob),
+            meta=meta,
+        )
+
+    def create_json_message(self, object: dict) -> ToolInvokeMessage:
+        """
+        create a json message
+        """
+        return ToolInvokeMessage(
+            type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
+        )
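As a usage sketch, a concrete `Tool` only has to supply `tool_provider_type` and `_invoke`; `invoke` then normalizes whatever `_invoke` returns (single message, list, or generator) into a generator. Module paths below follow where this patch places the files; the tool itself is hypothetical:

    from core.datasource.__base.tool import Tool
    from core.datasource.__base.tool_runtime import ToolRuntime
    from core.datasource.entities.common_entities import I18nObject
    from core.datasource.entities.tool_entities import ToolEntity, ToolIdentity, ToolProviderType


    class EchoTool(Tool):
        def tool_provider_type(self) -> ToolProviderType:
            return ToolProviderType.BUILT_IN

        def _invoke(self, user_id, tool_parameters, conversation_id=None, app_id=None, message_id=None):
            # one streamed text message built with the base-class helper
            yield self.create_text_message(str(tool_parameters.get("query", "")))


    entity = ToolEntity(
        identity=ToolIdentity(author="me", name="echo", label=I18nObject(en_US="Echo"), provider="demo"),
    )
    tool = EchoTool(entity=entity, runtime=ToolRuntime(tenant_id="tenant-1"))
    for message in tool.invoke(user_id="user-1", tool_parameters={"query": "hi"}):
        print(message.type, message.message)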
diff --git a/api/core/datasource/__base/tool_provider.py b/api/core/datasource/__base/tool_provider.py
new file mode 100644
index 0000000000..d096fc7df7
--- /dev/null
+++ b/api/core/datasource/__base/tool_provider.py
@@ -0,0 +1,109 @@
+from abc import ABC, abstractmethod
+from copy import deepcopy
+from typing import Any
+
+from core.entities.provider_entities import ProviderConfig
+from core.tools.__base.tool import Tool
+from core.tools.entities.tool_entities import (
+    ToolProviderEntity,
+    ToolProviderType,
+)
+from core.tools.errors import ToolProviderCredentialValidationError
+
+
+class ToolProviderController(ABC):
+    entity: ToolProviderEntity
+
+    def __init__(self, entity: ToolProviderEntity) -> None:
+        self.entity = entity
+
+    def get_credentials_schema(self) -> list[ProviderConfig]:
+        """
+        returns the credentials schema of the provider
+
+        :return: the credentials schema
+        """
+        return deepcopy(self.entity.credentials_schema)
+
+    @abstractmethod
+    def get_tool(self, tool_name: str) -> Tool:
+        """
+        returns a tool that the provider can provide
+
+        :return: tool
+        """
+        pass
+
+    @property
+    def provider_type(self) -> ToolProviderType:
+        """
+        returns the type of the provider
+
+        :return: type of the provider
+        """
+        return ToolProviderType.BUILT_IN
+
+    def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
+        """
+        validate the format of the credentials of the provider and set the default value if needed
+
+        :param credentials: the credentials of the tool
+        """
+        credentials_schema = dict[str, ProviderConfig]()
+        if credentials_schema is None:
+            return
+
+        for credential in self.entity.credentials_schema:
+            credentials_schema[credential.name] = credential
+
+        credentials_need_to_validate: dict[str, ProviderConfig] = {}
+        for credential_name in credentials_schema:
+            credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
+
+        for credential_name in credentials:
+            if credential_name not in credentials_need_to_validate:
+                raise ToolProviderCredentialValidationError(
+                    f"credential {credential_name} not found in provider {self.entity.identity.name}"
+                )
+
+            # check type
+            credential_schema = credentials_need_to_validate[credential_name]
+            if not credential_schema.required and credentials[credential_name] is None:
+                continue
+
+            if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
+                if not isinstance(credentials[credential_name], str):
+                    raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
+
+            elif credential_schema.type == ProviderConfig.Type.SELECT:
+                if not isinstance(credentials[credential_name], str):
+                    raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
+
+                options = credential_schema.options
+                if not isinstance(options, list):
+                    raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
+
+                if credentials[credential_name] not in [x.value for x in options]:
+                    raise ToolProviderCredentialValidationError(
+                        f"credential {credential_name} should be one of {options}"
+                    )
+
+            credentials_need_to_validate.pop(credential_name)
+
+        for credential_name in credentials_need_to_validate:
+            credential_schema = credentials_need_to_validate[credential_name]
+            if credential_schema.required:
+                raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
+
+            # the credential is not set currently, set the default value if needed
+            if credential_schema.default is not None:
+                default_value = credential_schema.default
+                # parse default value into the correct type
+                if credential_schema.type in {
+                    ProviderConfig.Type.SECRET_INPUT,
+                    ProviderConfig.Type.TEXT_INPUT,
+                    ProviderConfig.Type.SELECT,
+                }:
+                    default_value = str(default_value)
+
+                credentials[credential_name] = default_value
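`validate_credentials_format` walks the declared `credentials_schema`: unknown keys, wrong types, and missing required values raise `ToolProviderCredentialValidationError`, and omitted optional values are back-filled from their declared defaults. A minimal sketch, assuming the module paths where this patch places the files (the provider here declares no schema, so any unknown key is rejected):

    from core.datasource.__base.tool_provider import ToolProviderController
    from core.datasource.entities.common_entities import I18nObject
    from core.datasource.entities.tool_entities import ToolProviderEntity, ToolProviderIdentity
    from core.datasource.errors import ToolProviderCredentialValidationError


    class DemoProvider(ToolProviderController):
        def get_tool(self, tool_name):  # illustration only
            raise NotImplementedError


    provider = DemoProvider(
        ToolProviderEntity(
            identity=ToolProviderIdentity(
                author="me",
                name="demo",
                description=I18nObject(en_US="Demo"),
                icon="demo.svg",
                label=I18nObject(en_US="Demo"),
            )
        )
    )
    provider.validate_credentials_format({})  # passes: nothing declared, nothing required
    try:
        provider.validate_credentials_format({"unknown": "x"})  # raises: key not in schema
    except ToolProviderCredentialValidationError as e:
        print(e)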
diff --git a/api/core/datasource/__base/tool_runtime.py b/api/core/datasource/__base/tool_runtime.py
new file mode 100644
index 0000000000..c9e157cb77
--- /dev/null
+++ b/api/core/datasource/__base/tool_runtime.py
@@ -0,0 +1,36 @@
+from typing import Any, Optional
+
+from pydantic import BaseModel
+from pydantic import Field
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.tools.entities.tool_entities import ToolInvokeFrom
+
+
+class ToolRuntime(BaseModel):
+    """
+    Metadata of a tool call processing
+    """
+
+    tenant_id: str
+    tool_id: Optional[str] = None
+    invoke_from: Optional[InvokeFrom] = None
+    tool_invoke_from: Optional[ToolInvokeFrom] = None
+    credentials: dict[str, Any] = Field(default_factory=dict)
+    runtime_parameters: dict[str, Any] = Field(default_factory=dict)
+
+
+class FakeToolRuntime(ToolRuntime):
+    """
+    Fake tool runtime for testing
+    """
+
+    def __init__(self):
+        super().__init__(
+            tenant_id="fake_tenant_id",
+            tool_id="fake_tool_id",
+            invoke_from=InvokeFrom.DEBUGGER,
+            tool_invoke_from=ToolInvokeFrom.AGENT,
+            credentials={},
+            runtime_parameters={},
+        )
diff --git a/api/core/datasource/__init__.py b/api/core/datasource/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/datasource/entities/agent_entities.py b/api/core/datasource/entities/agent_entities.py
new file mode 100644
index 0000000000..e69de29bb2
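`ToolRuntime` is plain request-scoped metadata; only `tenant_id` is mandatory, and `credentials`/`runtime_parameters` default to empty dicts. A small construction sketch (all values are placeholders):

    from core.app.entities.app_invoke_entities import InvokeFrom
    from core.datasource.__base.tool_runtime import ToolRuntime
    from core.datasource.entities.tool_entities import ToolInvokeFrom

    runtime = ToolRuntime(
        tenant_id="tenant-1",
        invoke_from=InvokeFrom.DEBUGGER,
        tool_invoke_from=ToolInvokeFrom.WORKFLOW,
        credentials={"api_key": "dummy"},  # placeholder credential
    )
    print(runtime.runtime_parameters)  # {} by default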
+ "plugin_id": self.plugin_id, + "plugin_unique_identifier": self.plugin_unique_identifier, + "description": self.description.to_dict(), + "icon": self.icon, + "label": self.label.to_dict(), + "type": self.type.value, + "team_credentials": self.masked_credentials, + "is_team_authorization": self.is_team_authorization, + "allow_delete": self.allow_delete, + "tools": tools, + "labels": self.labels, + } diff --git a/api/core/datasource/entities/common_entities.py b/api/core/datasource/entities/common_entities.py new file mode 100644 index 0000000000..924e6fc0cf --- /dev/null +++ b/api/core/datasource/entities/common_entities.py @@ -0,0 +1,23 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class I18nObject(BaseModel): + """ + Model class for i18n object. + """ + + en_US: str + zh_Hans: Optional[str] = Field(default=None) + pt_BR: Optional[str] = Field(default=None) + ja_JP: Optional[str] = Field(default=None) + + def __init__(self, **data): + super().__init__(**data) + self.zh_Hans = self.zh_Hans or self.en_US + self.pt_BR = self.pt_BR or self.en_US + self.ja_JP = self.ja_JP or self.en_US + + def to_dict(self) -> dict: + return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP} diff --git a/api/core/datasource/entities/constants.py b/api/core/datasource/entities/constants.py new file mode 100644 index 0000000000..199c9f0d53 --- /dev/null +++ b/api/core/datasource/entities/constants.py @@ -0,0 +1 @@ +TOOL_SELECTOR_MODEL_IDENTITY = "__dify__tool_selector__" diff --git a/api/core/datasource/entities/file_entities.py b/api/core/datasource/entities/file_entities.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/core/datasource/entities/file_entities.py @@ -0,0 +1 @@ + diff --git a/api/core/datasource/entities/tool_bundle.py b/api/core/datasource/entities/tool_bundle.py new file mode 100644 index 0000000000..ffeeabbc1c --- /dev/null +++ b/api/core/datasource/entities/tool_bundle.py @@ -0,0 +1,29 @@ +from typing import Optional + +from pydantic import BaseModel + +from core.tools.entities.tool_entities import ToolParameter + + +class ApiToolBundle(BaseModel): + """ + This class is used to store the schema information of an api based tool. + such as the url, the method, the parameters, etc. 
+ """ + + # server_url + server_url: str + # method + method: str + # summary + summary: Optional[str] = None + # operation_id + operation_id: Optional[str] = None + # parameters + parameters: Optional[list[ToolParameter]] = None + # author + author: str + # icon + icon: Optional[str] = None + # openapi operation + openapi: dict diff --git a/api/core/datasource/entities/tool_entities.py b/api/core/datasource/entities/tool_entities.py new file mode 100644 index 0000000000..d756763137 --- /dev/null +++ b/api/core/datasource/entities/tool_entities.py @@ -0,0 +1,427 @@ +import base64 +import enum +from collections.abc import Mapping +from enum import Enum +from typing import Any, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator + +from core.entities.provider_entities import ProviderConfig +from core.plugin.entities.parameters import ( + PluginParameter, + PluginParameterOption, + PluginParameterType, + as_normal_type, + cast_parameter_value, + init_frontend_parameter, +) +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY + + +class ToolLabelEnum(Enum): + SEARCH = "search" + IMAGE = "image" + VIDEOS = "videos" + WEATHER = "weather" + FINANCE = "finance" + DESIGN = "design" + TRAVEL = "travel" + SOCIAL = "social" + NEWS = "news" + MEDICAL = "medical" + PRODUCTIVITY = "productivity" + EDUCATION = "education" + BUSINESS = "business" + ENTERTAINMENT = "entertainment" + UTILITIES = "utilities" + OTHER = "other" + + +class ToolProviderType(enum.StrEnum): + """ + Enum class for tool provider + """ + + PLUGIN = "plugin" + BUILT_IN = "builtin" + WORKFLOW = "workflow" + API = "api" + APP = "app" + DATASET_RETRIEVAL = "dataset-retrieval" + + @classmethod + def value_of(cls, value: str) -> "ToolProviderType": + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f"invalid mode value {value}") + + +class ApiProviderSchemaType(Enum): + """ + Enum class for api provider schema type. + """ + + OPENAPI = "openapi" + SWAGGER = "swagger" + OPENAI_PLUGIN = "openai_plugin" + OPENAI_ACTIONS = "openai_actions" + + @classmethod + def value_of(cls, value: str) -> "ApiProviderSchemaType": + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f"invalid mode value {value}") + + +class ApiProviderAuthType(Enum): + """ + Enum class for api provider auth type. + """ + + NONE = "none" + API_KEY = "api_key" + + @classmethod + def value_of(cls, value: str) -> "ApiProviderAuthType": + """ + Get value of given mode. 
+
+
+class ApiProviderAuthType(Enum):
+    """
+    Enum class for api provider auth type.
+    """
+
+    NONE = "none"
+    API_KEY = "api_key"
+
+    @classmethod
+    def value_of(cls, value: str) -> "ApiProviderAuthType":
+        """
+        Get value of given mode.
+
+        :param value: mode value
+        :return: mode
+        """
+        for mode in cls:
+            if mode.value == value:
+                return mode
+        raise ValueError(f"invalid mode value {value}")
+
+
+class ToolInvokeMessage(BaseModel):
+    class TextMessage(BaseModel):
+        text: str
+
+    class JsonMessage(BaseModel):
+        json_object: dict
+
+    class BlobMessage(BaseModel):
+        blob: bytes
+
+    class FileMessage(BaseModel):
+        pass
+
+    class VariableMessage(BaseModel):
+        variable_name: str = Field(..., description="The name of the variable")
+        variable_value: Any = Field(..., description="The value of the variable")
+        stream: bool = Field(default=False, description="Whether the variable is streamed")
+
+        @model_validator(mode="before")
+        @classmethod
+        def transform_variable_value(cls, values) -> Any:
+            """
+            Only basic types and lists are allowed.
+            """
+            value = values.get("variable_value")
+            if not isinstance(value, dict | list | str | int | float | bool):
+                raise ValueError("Only basic types and lists are allowed.")
+
+            # if stream is true, the value must be a string
+            if values.get("stream"):
+                if not isinstance(value, str):
+                    raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
+
+            return values
+
+        @field_validator("variable_name", mode="before")
+        @classmethod
+        def transform_variable_name(cls, value: str) -> str:
+            """
+            The variable name must be a string.
+            """
+            if value in {"json", "text", "files"}:
+                raise ValueError(f"The variable name '{value}' is reserved.")
+            return value
+
+    class LogMessage(BaseModel):
+        class LogStatus(Enum):
+            START = "start"
+            ERROR = "error"
+            SUCCESS = "success"
+
+        id: str
+        label: str = Field(..., description="The label of the log")
+        parent_id: Optional[str] = Field(default=None, description="Leave empty for root log")
+        error: Optional[str] = Field(default=None, description="The error message")
+        status: LogStatus = Field(..., description="The status of the log")
+        data: Mapping[str, Any] = Field(..., description="Detailed log data")
+        metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")
+
+    class MessageType(Enum):
+        TEXT = "text"
+        IMAGE = "image"
+        LINK = "link"
+        BLOB = "blob"
+        JSON = "json"
+        IMAGE_LINK = "image_link"
+        BINARY_LINK = "binary_link"
+        VARIABLE = "variable"
+        FILE = "file"
+        LOG = "log"
+
+    type: MessageType = MessageType.TEXT
+    """
+    plain text, image url or link url
+    """
+    message: JsonMessage | TextMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage
+    meta: dict[str, Any] | None = None
+
+    @field_validator("message", mode="before")
+    @classmethod
+    def decode_blob_message(cls, v):
+        if isinstance(v, dict) and "blob" in v:
+            try:
+                v["blob"] = base64.b64decode(v["blob"])
+            except Exception:
+                pass
+        return v
+
+    @field_serializer("message")
+    def serialize_message(self, v):
+        if isinstance(v, self.BlobMessage):
+            return {"blob": base64.b64encode(v.blob).decode("utf-8")}
+        return v
+
+
+class ToolInvokeMessageBinary(BaseModel):
+    mimetype: str = Field(..., description="The mimetype of the binary")
+    url: str = Field(..., description="The url of the binary")
+    file_var: Optional[dict[str, Any]] = None
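Note the paired `decode_blob_message` validator and `serialize_message` serializer above: blobs arrive base64-encoded, are held as raw bytes in memory, and are re-encoded on dump. A round-trip sketch:

    import base64

    from core.datasource.entities.tool_entities import ToolInvokeMessage

    msg = ToolInvokeMessage(
        type=ToolInvokeMessage.MessageType.BLOB,
        message={"blob": base64.b64encode(b"raw bytes").decode()},  # decoded by the validator
    )
    assert isinstance(msg.message, ToolInvokeMessage.BlobMessage)
    assert msg.message.blob == b"raw bytes"
    print(msg.model_dump()["message"])  # {'blob': 'cmF3IGJ5dGVz'} - re-encoded on serialization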
+
+
+class ToolParameter(PluginParameter):
+    """
+    Overrides type
+    """
+
+    class ToolParameterType(enum.StrEnum):
+        """
+        removes TOOLS_SELECTOR from PluginParameterType
+        """
+
+        STRING = PluginParameterType.STRING.value
+        NUMBER = PluginParameterType.NUMBER.value
+        BOOLEAN = PluginParameterType.BOOLEAN.value
+        SELECT = PluginParameterType.SELECT.value
+        SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
+        FILE = PluginParameterType.FILE.value
+        FILES = PluginParameterType.FILES.value
+        APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
+        MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
+
+        # deprecated, should not use.
+        SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
+
+        def as_normal_type(self):
+            return as_normal_type(self)
+
+        def cast_value(self, value: Any):
+            return cast_parameter_value(self, value)
+
+    class ToolParameterForm(Enum):
+        SCHEMA = "schema"  # should be set while adding tool
+        FORM = "form"  # should be set before invoking tool
+        LLM = "llm"  # will be set by LLM
+
+    type: ToolParameterType = Field(..., description="The type of the parameter")
+    human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
+    form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
+    llm_description: Optional[str] = None
+
+    @classmethod
+    def get_simple_instance(
+        cls,
+        name: str,
+        llm_description: str,
+        typ: ToolParameterType,
+        required: bool,
+        options: Optional[list[str]] = None,
+    ) -> "ToolParameter":
+        """
+        get a simple tool parameter
+
+        :param name: the name of the parameter
+        :param llm_description: the description presented to the LLM
+        :param typ: the type of the parameter
+        :param required: if the parameter is required
+        :param options: the options of the parameter
+        """
+        # convert options to ToolParameterOption
+        # FIXME fix the type error
+        if options:
+            option_objs = [
+                PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
+                for option in options
+            ]
+        else:
+            option_objs = []
+
+        return cls(
+            name=name,
+            label=I18nObject(en_US="", zh_Hans=""),
+            placeholder=None,
+            human_description=I18nObject(en_US="", zh_Hans=""),
+            type=typ,
+            form=cls.ToolParameterForm.LLM,
+            llm_description=llm_description,
+            required=required,
+            options=option_objs,
+        )
+
+    def init_frontend_parameter(self, value: Any):
+        return init_frontend_parameter(self, self.type, value)
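`get_simple_instance` fills in the boilerplate (empty labels, LLM form) so callers only name the parameter, its type, and its LLM-facing description. A brief sketch:

    from core.datasource.entities.tool_entities import ToolParameter

    param = ToolParameter.get_simple_instance(
        name="query",
        llm_description="The search query",
        typ=ToolParameter.ToolParameterType.STRING,
        required=True,
    )
    print(param.form)  # ToolParameterForm.LLM - simple instances are always LLM-filled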
+
+
+class ToolProviderIdentity(BaseModel):
+    author: str = Field(..., description="The author of the tool")
+    name: str = Field(..., description="The name of the tool")
+    description: I18nObject = Field(..., description="The description of the tool")
+    icon: str = Field(..., description="The icon of the tool")
+    label: I18nObject = Field(..., description="The label of the tool")
+    tags: Optional[list[ToolLabelEnum]] = Field(
+        default=[],
+        description="The tags of the tool",
+    )
+
+
+class ToolIdentity(BaseModel):
+    author: str = Field(..., description="The author of the tool")
+    name: str = Field(..., description="The name of the tool")
+    label: I18nObject = Field(..., description="The label of the tool")
+    provider: str = Field(..., description="The provider of the tool")
+    icon: Optional[str] = None
+
+
+class ToolDescription(BaseModel):
+    human: I18nObject = Field(..., description="The description presented to the user")
+    llm: str = Field(..., description="The description presented to the LLM")
+
+
+class ToolEntity(BaseModel):
+    identity: ToolIdentity
+    parameters: list[ToolParameter] = Field(default_factory=list)
+    description: Optional[ToolDescription] = None
+    output_schema: Optional[dict] = None
+    has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
+
+    # pydantic configs
+    model_config = ConfigDict(protected_namespaces=())
+
+    @field_validator("parameters", mode="before")
+    @classmethod
+    def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
+        return v or []
+
+
+class ToolProviderEntity(BaseModel):
+    identity: ToolProviderIdentity
+    plugin_id: Optional[str] = None
+    credentials_schema: list[ProviderConfig] = Field(default_factory=list)
+
+
+class ToolProviderEntityWithPlugin(ToolProviderEntity):
+    tools: list[ToolEntity] = Field(default_factory=list)
+
+
+class WorkflowToolParameterConfiguration(BaseModel):
+    """
+    Workflow tool configuration
+    """
+
+    name: str = Field(..., description="The name of the parameter")
+    description: str = Field(..., description="The description of the parameter")
+    form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
+
+
+class ToolInvokeMeta(BaseModel):
+    """
+    Tool invoke meta
+    """
+
+    time_cost: float = Field(..., description="The time cost of the tool invoke")
+    error: Optional[str] = None
+    tool_config: Optional[dict] = None
+
+    @classmethod
+    def empty(cls) -> "ToolInvokeMeta":
+        """
+        Get an empty instance of ToolInvokeMeta
+        """
+        return cls(time_cost=0.0, error=None, tool_config={})
+
+    @classmethod
+    def error_instance(cls, error: str) -> "ToolInvokeMeta":
+        """
+        Get an instance of ToolInvokeMeta with error
+        """
+        return cls(time_cost=0.0, error=error, tool_config={})
+
+    def to_dict(self) -> dict:
+        return {
+            "time_cost": self.time_cost,
+            "error": self.error,
+            "tool_config": self.tool_config,
+        }
+
+
+class ToolLabel(BaseModel):
+    """
+    Tool label
+    """
+
+    name: str = Field(..., description="The name of the tool")
+    label: I18nObject = Field(..., description="The label of the tool")
+    icon: str = Field(..., description="The icon of the tool")
+
+
+class ToolInvokeFrom(Enum):
+    """
+    Enum class for tool invoke
+    """
+
+    WORKFLOW = "workflow"
+    AGENT = "agent"
+    PLUGIN = "plugin"
+
+
+class ToolSelector(BaseModel):
+    dify_model_identity: str = TOOL_SELECTOR_MODEL_IDENTITY
+
+    class Parameter(BaseModel):
+        name: str = Field(..., description="The name of the parameter")
+        type: ToolParameter.ToolParameterType = Field(..., description="The type of the parameter")
+        required: bool = Field(..., description="Whether the parameter is required")
+        description: str = Field(..., description="The description of the parameter")
+        default: Optional[Union[int, float, str]] = None
+        options: Optional[list[PluginParameterOption]] = None
+
+    provider_id: str = Field(..., description="The id of the provider")
+    tool_name: str = Field(..., description="The name of the tool")
+    tool_description: str = Field(..., description="The description of the tool")
+    tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
+    tool_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm")
+
+    def to_plugin_parameter(self) -> dict[str, Any]:
+        return self.model_dump()
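`ToolInvokeMeta` is the accounting side of an invocation; a caller would typically time the call and fall back to `error_instance` on failure. A sketch of that pattern:

    import time

    from core.datasource.entities.tool_entities import ToolInvokeMeta

    started = time.perf_counter()
    try:
        pass  # the actual tool invocation would run here
        meta = ToolInvokeMeta(time_cost=time.perf_counter() - started, error=None, tool_config={})
    except Exception as e:
        meta = ToolInvokeMeta.error_instance(str(e))
    print(meta.to_dict())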
ToolLabelEnum.TRAVEL: """ + +""", # noqa: E501 + ToolLabelEnum.SOCIAL: """ + +""", # noqa: E501 + ToolLabelEnum.NEWS: """ + +""", # noqa: E501 + ToolLabelEnum.MEDICAL: """ + +""", # noqa: E501 + ToolLabelEnum.PRODUCTIVITY: """ + +""", # noqa: E501 + ToolLabelEnum.EDUCATION: """ + +""", # noqa: E501 + ToolLabelEnum.BUSINESS: """ + +""", # noqa: E501 + ToolLabelEnum.ENTERTAINMENT: """ + +""", # noqa: E501 + ToolLabelEnum.UTILITIES: """ + +""", # noqa: E501 + ToolLabelEnum.OTHER: """ + +""", # noqa: E501 +} + +default_tool_label_dict = { + ToolLabelEnum.SEARCH: ToolLabel( + name="search", label=I18nObject(en_US="Search", zh_Hans="搜索"), icon=ICONS[ToolLabelEnum.SEARCH] + ), + ToolLabelEnum.IMAGE: ToolLabel( + name="image", label=I18nObject(en_US="Image", zh_Hans="图片"), icon=ICONS[ToolLabelEnum.IMAGE] + ), + ToolLabelEnum.VIDEOS: ToolLabel( + name="videos", label=I18nObject(en_US="Videos", zh_Hans="视频"), icon=ICONS[ToolLabelEnum.VIDEOS] + ), + ToolLabelEnum.WEATHER: ToolLabel( + name="weather", label=I18nObject(en_US="Weather", zh_Hans="天气"), icon=ICONS[ToolLabelEnum.WEATHER] + ), + ToolLabelEnum.FINANCE: ToolLabel( + name="finance", label=I18nObject(en_US="Finance", zh_Hans="金融"), icon=ICONS[ToolLabelEnum.FINANCE] + ), + ToolLabelEnum.DESIGN: ToolLabel( + name="design", label=I18nObject(en_US="Design", zh_Hans="设计"), icon=ICONS[ToolLabelEnum.DESIGN] + ), + ToolLabelEnum.TRAVEL: ToolLabel( + name="travel", label=I18nObject(en_US="Travel", zh_Hans="旅行"), icon=ICONS[ToolLabelEnum.TRAVEL] + ), + ToolLabelEnum.SOCIAL: ToolLabel( + name="social", label=I18nObject(en_US="Social", zh_Hans="社交"), icon=ICONS[ToolLabelEnum.SOCIAL] + ), + ToolLabelEnum.NEWS: ToolLabel( + name="news", label=I18nObject(en_US="News", zh_Hans="新闻"), icon=ICONS[ToolLabelEnum.NEWS] + ), + ToolLabelEnum.MEDICAL: ToolLabel( + name="medical", label=I18nObject(en_US="Medical", zh_Hans="医疗"), icon=ICONS[ToolLabelEnum.MEDICAL] + ), + ToolLabelEnum.PRODUCTIVITY: ToolLabel( + name="productivity", + label=I18nObject(en_US="Productivity", zh_Hans="生产力"), + icon=ICONS[ToolLabelEnum.PRODUCTIVITY], + ), + ToolLabelEnum.EDUCATION: ToolLabel( + name="education", label=I18nObject(en_US="Education", zh_Hans="教育"), icon=ICONS[ToolLabelEnum.EDUCATION] + ), + ToolLabelEnum.BUSINESS: ToolLabel( + name="business", label=I18nObject(en_US="Business", zh_Hans="商业"), icon=ICONS[ToolLabelEnum.BUSINESS] + ), + ToolLabelEnum.ENTERTAINMENT: ToolLabel( + name="entertainment", + label=I18nObject(en_US="Entertainment", zh_Hans="娱乐"), + icon=ICONS[ToolLabelEnum.ENTERTAINMENT], + ), + ToolLabelEnum.UTILITIES: ToolLabel( + name="utilities", label=I18nObject(en_US="Utilities", zh_Hans="工具"), icon=ICONS[ToolLabelEnum.UTILITIES] + ), + ToolLabelEnum.OTHER: ToolLabel( + name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER] + ), +} + +default_tool_labels = [v for k, v in default_tool_label_dict.items()] +default_tool_label_name_list = [label.name for label in default_tool_labels] diff --git a/api/core/datasource/errors.py b/api/core/datasource/errors.py new file mode 100644 index 0000000000..c5f9ca4774 --- /dev/null +++ b/api/core/datasource/errors.py @@ -0,0 +1,37 @@ +from core.tools.entities.tool_entities import ToolInvokeMeta + + +class ToolProviderNotFoundError(ValueError): + pass + + +class ToolNotFoundError(ValueError): + pass + + +class ToolParameterValidationError(ValueError): + pass + + +class ToolProviderCredentialValidationError(ValueError): + pass + + +class ToolNotSupportedError(ValueError): + pass + + +class 
diff --git a/api/core/datasource/errors.py b/api/core/datasource/errors.py
new file mode 100644
index 0000000000..c5f9ca4774
--- /dev/null
+++ b/api/core/datasource/errors.py
@@ -0,0 +1,37 @@
+from core.tools.entities.tool_entities import ToolInvokeMeta
+
+
+class ToolProviderNotFoundError(ValueError):
+    pass
+
+
+class ToolNotFoundError(ValueError):
+    pass
+
+
+class ToolParameterValidationError(ValueError):
+    pass
+
+
+class ToolProviderCredentialValidationError(ValueError):
+    pass
+
+
+class ToolNotSupportedError(ValueError):
+    pass
+
+
+class ToolInvokeError(ValueError):
+    pass
+
+
+class ToolApiSchemaError(ValueError):
+    pass
+
+
+class ToolEngineInvokeError(Exception):
+    meta: ToolInvokeMeta
+
+    def __init__(self, meta, **kwargs):
+        self.meta = meta
+        super().__init__(**kwargs)
diff --git a/api/core/datasource/plugin_tool/provider.py b/api/core/datasource/plugin_tool/provider.py
new file mode 100644
index 0000000000..3616e426b9
--- /dev/null
+++ b/api/core/datasource/plugin_tool/provider.py
@@ -0,0 +1,79 @@
+from typing import Any
+
+from core.plugin.manager.tool import PluginToolManager
+from core.tools.__base.tool_runtime import ToolRuntime
+from core.tools.builtin_tool.provider import BuiltinToolProviderController
+from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.plugin_tool.tool import PluginTool
+
+
+class PluginToolProviderController(BuiltinToolProviderController):
+    entity: ToolProviderEntityWithPlugin
+    tenant_id: str
+    plugin_id: str
+    plugin_unique_identifier: str
+
+    def __init__(
+        self, entity: ToolProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
+    ) -> None:
+        self.entity = entity
+        self.tenant_id = tenant_id
+        self.plugin_id = plugin_id
+        self.plugin_unique_identifier = plugin_unique_identifier
+
+    @property
+    def provider_type(self) -> ToolProviderType:
+        """
+        returns the type of the provider
+
+        :return: type of the provider
+        """
+        return ToolProviderType.PLUGIN
+
+    def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
+        """
+        validate the credentials of the provider
+        """
+        manager = PluginToolManager()
+        if not manager.validate_provider_credentials(
+            tenant_id=self.tenant_id,
+            user_id=user_id,
+            provider=self.entity.identity.name,
+            credentials=credentials,
+        ):
+            raise ToolProviderCredentialValidationError("Invalid credentials")
+
+    def get_tool(self, tool_name: str) -> PluginTool:  # type: ignore
+        """
+        return tool with given name
+        """
+        tool_entity = next(
+            (tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None
+        )
+
+        if not tool_entity:
+            raise ValueError(f"Tool with name {tool_name} not found")
+
+        return PluginTool(
+            entity=tool_entity,
+            runtime=ToolRuntime(tenant_id=self.tenant_id),
+            tenant_id=self.tenant_id,
+            icon=self.entity.identity.icon,
+            plugin_unique_identifier=self.plugin_unique_identifier,
+        )
+
+    def get_tools(self) -> list[PluginTool]:  # type: ignore
+        """
+        get all tools
+        """
+        return [
+            PluginTool(
+                entity=tool_entity,
+                runtime=ToolRuntime(tenant_id=self.tenant_id),
+                tenant_id=self.tenant_id,
+                icon=self.entity.identity.icon,
+                plugin_unique_identifier=self.plugin_unique_identifier,
+            )
+            for tool_entity in self.entity.tools
+        ]
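To see the provider in action, a `ToolProviderEntityWithPlugin` can be hand-built and a tool resolved by name; in production the entity comes from the plugin manager, so every identifier below is a placeholder:

    from core.datasource.entities.common_entities import I18nObject
    from core.datasource.entities.tool_entities import (
        ToolEntity,
        ToolIdentity,
        ToolProviderEntityWithPlugin,
        ToolProviderIdentity,
    )
    from core.datasource.plugin_tool.provider import PluginToolProviderController

    entity = ToolProviderEntityWithPlugin(
        identity=ToolProviderIdentity(
            author="org",
            name="search",
            description=I18nObject(en_US="Search provider"),
            icon="icon.svg",
            label=I18nObject(en_US="Search"),
        ),
        tools=[
            ToolEntity(
                identity=ToolIdentity(
                    author="org", name="web_search", label=I18nObject(en_US="Web Search"), provider="search"
                ),
            )
        ],
    )
    provider = PluginToolProviderController(
        entity=entity,
        plugin_id="org/search",
        plugin_unique_identifier="org/search:0.0.1",
        tenant_id="tenant-1",
    )
    tool = provider.get_tool("web_search")  # raises ValueError for unknown names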
plugin_unique_identifier: str + runtime_parameters: Optional[list[ToolParameter]] + + def __init__( + self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str + ) -> None: + super().__init__(entity, runtime) + self.tenant_id = tenant_id + self.icon = icon + self.plugin_unique_identifier = plugin_unique_identifier + self.runtime_parameters = None + + def tool_provider_type(self) -> ToolProviderType: + return ToolProviderType.PLUGIN + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: Optional[str] = None, + app_id: Optional[str] = None, + message_id: Optional[str] = None, + ) -> Generator[ToolInvokeMessage, None, None]: + manager = PluginToolManager() + + tool_parameters = convert_parameters_to_plugin_format(tool_parameters) + + yield from manager.invoke( + tenant_id=self.tenant_id, + user_id=user_id, + tool_provider=self.entity.identity.provider, + tool_name=self.entity.identity.name, + credentials=self.runtime.credentials, + tool_parameters=tool_parameters, + conversation_id=conversation_id, + app_id=app_id, + message_id=message_id, + ) + + def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool": + return PluginTool( + entity=self.entity, + runtime=runtime, + tenant_id=self.tenant_id, + icon=self.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) + + def get_runtime_parameters( + self, + conversation_id: Optional[str] = None, + app_id: Optional[str] = None, + message_id: Optional[str] = None, + ) -> list[ToolParameter]: + """ + get the runtime parameters + """ + if not self.entity.has_runtime_parameters: + return self.entity.parameters + + if self.runtime_parameters is not None: + return self.runtime_parameters + + manager = PluginToolManager() + self.runtime_parameters = manager.get_runtime_parameters( + tenant_id=self.tenant_id, + user_id="", + provider=self.entity.identity.provider, + tool=self.entity.identity.name, + credentials=self.runtime.credentials, + conversation_id=conversation_id, + app_id=app_id, + message_id=message_id, + ) + + return self.runtime_parameters diff --git a/api/core/datasource/tool_engine.py b/api/core/datasource/tool_engine.py new file mode 100644 index 0000000000..ad0c62537c --- /dev/null +++ b/api/core/datasource/tool_engine.py @@ -0,0 +1,357 @@ +import json +from collections.abc import Generator, Iterable +from copy import deepcopy +from datetime import UTC, datetime +from mimetypes import guess_type +from typing import Any, Optional, Union, cast + +from yarl import URL + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.file import FileType +from core.file.models import FileTransferMethod +from core.ops.ops_trace_manager import TraceQueueManager +from core.tools.__base.tool import Tool +from core.tools.entities.tool_entities import ( + ToolInvokeMessage, + ToolInvokeMessageBinary, + ToolInvokeMeta, + ToolParameter, +) +from core.tools.errors import ( + ToolEngineInvokeError, + ToolInvokeError, + ToolNotFoundError, + ToolNotSupportedError, + ToolParameterValidationError, + ToolProviderCredentialValidationError, + ToolProviderNotFoundError, +) +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.tools.workflow_as_tool.tool import WorkflowTool +from extensions.ext_database import db +from models.enums import 
CreatedByRole
+from models.model import Message, MessageFile
+
+
+class ToolEngine:
+    """
+    Tool runtime engine that takes care of tool execution.
+    """
+
+    @staticmethod
+    def agent_invoke(
+        tool: Tool,
+        tool_parameters: Union[str, dict],
+        user_id: str,
+        tenant_id: str,
+        message: Message,
+        invoke_from: InvokeFrom,
+        agent_tool_callback: DifyAgentCallbackHandler,
+        trace_manager: Optional[TraceQueueManager] = None,
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> tuple[str, list[str], ToolInvokeMeta]:
+        """
+        Agent invokes the tool with the given arguments.
+        """
+        # check if the arguments were passed as a string
+        if isinstance(tool_parameters, str):
+            # check if this tool has only one parameter
+            parameters = [
+                parameter
+                for parameter in tool.get_runtime_parameters()
+                if parameter.form == ToolParameter.ToolParameterForm.LLM
+            ]
+            if parameters and len(parameters) == 1:
+                tool_parameters = {parameters[0].name: tool_parameters}
+            else:
+                try:
+                    tool_parameters = json.loads(tool_parameters)
+                except Exception:
+                    pass
+                if not isinstance(tool_parameters, dict):
+                    raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
+
+        try:
+            # hit the callback handler
+            agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
+
+            messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id)
+            invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
+
+            def message_callback(
+                invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]
+            ):
+                for message in messages:
+                    if isinstance(message, ToolInvokeMeta):
+                        invocation_meta_dict["meta"] = message
+                    else:
+                        yield message
+
+            messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
+                messages=message_callback(invocation_meta_dict, messages),
+                user_id=user_id,
+                tenant_id=tenant_id,
+                conversation_id=message.conversation_id,
+            )
+
+            message_list = list(messages)
+
+            # extract binary data from tool invoke message
+            binary_files = ToolEngine._extract_tool_response_binary_and_text(message_list)
+            # create message file
+            message_files = ToolEngine._create_message_files(
+                tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id
+            )
+
+            plain_text = ToolEngine._convert_tool_response_to_str(message_list)
+
+            meta = invocation_meta_dict["meta"]
+
+            # hit the callback handler
+            agent_tool_callback.on_tool_end(
+                tool_name=tool.entity.identity.name,
+                tool_inputs=tool_parameters,
+                tool_outputs=plain_text,
+                message_id=message.id,
+                trace_manager=trace_manager,
+            )
+
+            # return the LLM-friendly plain text, the created message file ids and the invocation meta
+            return plain_text, message_files, meta
+        except ToolProviderCredentialValidationError as e:
+            error_response = "Please check your tool provider credentials"
+            agent_tool_callback.on_tool_error(e)
+        except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e:
+            error_response = f"there is no tool named {tool.entity.identity.name}"
+            agent_tool_callback.on_tool_error(e)
+        except ToolParameterValidationError as e:
+            error_response = f"tool parameter validation error: {e}, please check your tool parameters"
+            agent_tool_callback.on_tool_error(e)
+        except ToolInvokeError as e:
+            error_response = f"tool invoke error: {e}"
+            agent_tool_callback.on_tool_error(e)
+        except ToolEngineInvokeError as e:
+            meta = e.meta
+            error_response = f"tool invoke error: {meta.error}"
+            agent_tool_callback.on_tool_error(e)
+            return error_response, [], meta
+        except Exception as e:
+            error_response = f"unknown error: {e}"
+            agent_tool_callback.on_tool_error(e)
+
+        return error_response, [], ToolInvokeMeta.error_instance(error_response)
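A minimal caller-side sketch of the contract above (not part of the patch; `tool`, `message`, `callback` and the ids are assumed to be in scope):

# hypothetical agent-runner snippet; only the return contract is taken from the patch
observation, message_file_ids, meta = ToolEngine.agent_invoke(
    tool=tool,
    tool_parameters='{"query": "dify"}',  # a raw LLM string is accepted
    user_id=user_id,
    tenant_id=tenant_id,
    message=message,
    invoke_from=InvokeFrom.DEBUGGER,
    agent_tool_callback=callback,
)
# tool-level failures do not raise here: the error text comes back as the
# observation, and meta.error carries the detail for tracing
if meta.error:
    print(f"tool failed: {meta.error}")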
+    @staticmethod
+    def workflow_invoke(
+        tool: Tool,
+        tool_parameters: dict[str, Any],
+        user_id: str,
+        workflow_tool_callback: DifyWorkflowCallbackHandler,
+        workflow_call_depth: int,
+        thread_pool_id: Optional[str] = None,
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> Generator[ToolInvokeMessage, None, None]:
+        """
+        Workflow invokes the tool with the given arguments.
+        """
+        try:
+            # hit the callback handler
+            workflow_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
+
+            if isinstance(tool, WorkflowTool):
+                tool.workflow_call_depth = workflow_call_depth + 1
+                tool.thread_pool_id = thread_pool_id
+
+            if tool.runtime and tool.runtime.runtime_parameters:
+                tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
+
+            response = tool.invoke(
+                user_id=user_id,
+                tool_parameters=tool_parameters,
+                conversation_id=conversation_id,
+                app_id=app_id,
+                message_id=message_id,
+            )
+
+            # hit the callback handler
+            response = workflow_tool_callback.on_tool_execution(
+                tool_name=tool.entity.identity.name,
+                tool_inputs=tool_parameters,
+                tool_outputs=response,
+            )
+
+            return response
+        except Exception as e:
+            workflow_tool_callback.on_tool_error(e)
+            raise
+
+    @staticmethod
+    def _invoke(
+        tool: Tool,
+        tool_parameters: dict,
+        user_id: str,
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
+        """
+        Invoke the tool with the given arguments.
+        """
+        started_at = datetime.now(UTC)
+        meta = ToolInvokeMeta(
+            time_cost=0.0,
+            error=None,
+            tool_config={
+                "tool_name": tool.entity.identity.name,
+                "tool_provider": tool.entity.identity.provider,
+                "tool_provider_type": tool.tool_provider_type().value,
+                "tool_parameters": deepcopy(tool.runtime.runtime_parameters),
+                "tool_icon": tool.entity.identity.icon,
+            },
+        )
+        try:
+            yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id)
+        except Exception as e:
+            meta.error = str(e)
+            raise ToolEngineInvokeError(meta)
+        finally:
+            ended_at = datetime.now(UTC)
+            meta.time_cost = (ended_at - started_at).total_seconds()
+            yield meta
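`_invoke` follows a trailer convention: the generator yields `ToolInvokeMessage` items and, via its `finally` block, always ends with a single `ToolInvokeMeta` (the `message_callback` helper in `agent_invoke` peels it off). A consumer sketch, with `tool`, `user_id` and `handle` assumed:

meta = None
for item in ToolEngine._invoke(tool, {"query": "dify"}, user_id):
    if isinstance(item, ToolInvokeMeta):
        meta = item  # the timing/config trailer, yielded last
    else:
        handle(item)  # a regular ToolInvokeMessage
assert meta is not None and meta.time_cost >= 0.0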
+    @staticmethod
+    def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
+        """
+        Convert a tool response into an LLM-friendly plain-text string
+        """
+        result = ""
+        for response in tool_response:
+            if response.type == ToolInvokeMessage.MessageType.TEXT:
+                result += cast(ToolInvokeMessage.TextMessage, response.message).text
+            elif response.type == ToolInvokeMessage.MessageType.LINK:
+                result += (
+                    f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}."
+                    + " please tell user to check it."
+                )
+            elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
+                result += (
+                    "image has been created and sent to user already, "
+                    + "you do not need to create it, just tell the user to check it now."
+                )
+            elif response.type == ToolInvokeMessage.MessageType.JSON:
+                # append rather than overwrite, so earlier text is not discarded
+                result += json.dumps(
+                    cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
+                )
+            else:
+                result += str(response.message)
+
+        return result
+
+    @staticmethod
+    def _extract_tool_response_binary_and_text(
+        tool_response: list[ToolInvokeMessage],
+    ) -> Generator[ToolInvokeMessageBinary, None, None]:
+        """
+        Extract binary messages from a tool response
+        """
+        for response in tool_response:
+            if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
+                mimetype = None
+                if not response.meta:
+                    raise ValueError("missing meta data")
+                if response.meta.get("mime_type"):
+                    mimetype = response.meta.get("mime_type")
+                else:
+                    try:
+                        url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
+                        extension = url.suffix
+                        guess_type_result, _ = guess_type(f"a{extension}")
+                        if guess_type_result:
+                            mimetype = guess_type_result
+                    except Exception:
+                        pass
+
+                if not mimetype:
+                    mimetype = "image/jpeg"
+
+                yield ToolInvokeMessageBinary(
+                    mimetype=mimetype,
+                    url=cast(ToolInvokeMessage.TextMessage, response.message).text,
+                )
+            elif response.type == ToolInvokeMessage.MessageType.BLOB:
+                if not response.meta:
+                    raise ValueError("missing meta data")
+
+                yield ToolInvokeMessageBinary(
+                    mimetype=response.meta.get("mime_type", "application/octet-stream"),
+                    url=cast(ToolInvokeMessage.TextMessage, response.message).text,
+                )
+            elif response.type == ToolInvokeMessage.MessageType.LINK:
+                # only yield a binary if the link carries a mime type in its meta
+                if response.meta and "mime_type" in response.meta:
+                    yield ToolInvokeMessageBinary(
+                        mimetype=response.meta.get("mime_type", "application/octet-stream"),
+                        url=cast(ToolInvokeMessage.TextMessage, response.message).text,
+                    )
+
+    @staticmethod
+    def _create_message_files(
+        tool_messages: Iterable[ToolInvokeMessageBinary],
+        agent_message: Message,
+        invoke_from: InvokeFrom,
+        user_id: str,
+    ) -> list[str]:
+        """
+        Create message file
+
+        :return: message file ids
+        """
+        result = []
+
+        for message in tool_messages:
+            if "image" in message.mimetype:
+                file_type = FileType.IMAGE
+            elif "video" in message.mimetype:
+                file_type = FileType.VIDEO
+            elif "audio" in message.mimetype:
+                file_type = FileType.AUDIO
+            elif "text" in message.mimetype or "pdf" in message.mimetype:
+                file_type = FileType.DOCUMENT
+            else:
+                file_type = FileType.CUSTOM
+
+            # extract tool file id from url
+            tool_file_id = message.url.split("/")[-1].split(".")[0]
+            message_file = MessageFile(
+                message_id=agent_message.id,
+                type=file_type,
+                transfer_method=FileTransferMethod.TOOL_FILE,
+                belongs_to="assistant",
+                url=message.url,
+                upload_file_id=tool_file_id,
+                created_by_role=(
+                    CreatedByRole.ACCOUNT
+                    if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
+                    else CreatedByRole.END_USER
+                ),
+                created_by=user_id,
+            )
+
+            db.session.add(message_file)
+            db.session.commit()
+            db.session.refresh(message_file)
+
+            result.append(message_file.id)
+
+        db.session.close()
+
+        return result
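The suffix-based mime-type fallback used in `_extract_tool_response_binary_and_text` can be reproduced standalone; nothing Dify-specific is assumed in this sketch, only stdlib `mimetypes` and `yarl` behaviour:

from mimetypes import guess_type
from yarl import URL

url = URL("https://example.com/files/chart.png?sign=abc")
mimetype, _ = guess_type(f"a{url.suffix}")  # url.suffix -> ".png"
print(mimetype or "image/jpeg")             # -> image/png; falls back to image/jpeg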
diff --git a/api/core/datasource/tool_file_manager.py b/api/core/datasource/tool_file_manager.py
new file mode 100644
index 0000000000..7e8d4280d4
--- /dev/null
+++ b/api/core/datasource/tool_file_manager.py
@@ -0,0 +1,234 @@
+import base64
+import hashlib
+import hmac
+import logging
+import os
+import time
+from mimetypes import guess_extension, guess_type
+from typing import Optional, Union
+from uuid import uuid4
+
+import httpx
+
+from configs import dify_config
+from core.helper import ssrf_proxy
+from extensions.ext_database import db
+from extensions.ext_storage import storage
+from models.model import MessageFile
+from models.tools import ToolFile
+
+logger = logging.getLogger(__name__)
+
+
+class ToolFileManager:
+    @staticmethod
+    def sign_file(tool_file_id: str, extension: str) -> str:
+        """
+        sign file to get a temporary url
+        """
+        base_url = dify_config.FILES_URL
+        file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
+
+        timestamp = str(int(time.time()))
+        nonce = os.urandom(16).hex()
+        data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
+        secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
+        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+        encoded_sign = base64.urlsafe_b64encode(sign).decode()
+
+        return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
+
+    @staticmethod
+    def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
+        """
+        verify signature
+        """
+        data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
+        secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
+        recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+        recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
+
+        # verify signature
+        if sign != recalculated_encoded_sign:
+            return False
+
+        current_time = int(time.time())
+        return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
+
+    @staticmethod
+    def create_file_by_raw(
+        *,
+        user_id: str,
+        tenant_id: str,
+        conversation_id: Optional[str],
+        file_binary: bytes,
+        mimetype: str,
+        filename: Optional[str] = None,
+    ) -> ToolFile:
+        extension = guess_extension(mimetype) or ".bin"
+        unique_name = uuid4().hex
+        unique_filename = f"{unique_name}{extension}"
+        # default just as before
+        present_filename = unique_filename
+        if filename is not None:
+            has_extension = len(filename.split(".")) > 1
+            # add the guessed extension if the provided filename lacks one
+            present_filename = filename if has_extension else f"{filename}{extension}"
+        filepath = f"tools/{tenant_id}/{unique_filename}"
+        storage.save(filepath, file_binary)
+
+        tool_file = ToolFile(
+            user_id=user_id,
+            tenant_id=tenant_id,
+            conversation_id=conversation_id,
+            file_key=filepath,
+            mimetype=mimetype,
+            name=present_filename,
+            size=len(file_binary),
+        )
+
+        db.session.add(tool_file)
+        db.session.commit()
+        db.session.refresh(tool_file)
+
+        return tool_file
+
+    @staticmethod
+    def create_file_by_url(
+        user_id: str,
+        tenant_id: str,
+        file_url: str,
+        conversation_id: Optional[str] = None,
+    ) -> ToolFile:
+        # try to download the file
+        try:
+            response = ssrf_proxy.get(file_url)
+            response.raise_for_status()
+            blob = response.content
+        except httpx.TimeoutException:
+            raise ValueError(f"timeout when downloading file from {file_url}")
+
+        mimetype = (
+            guess_type(file_url)[0]
+            or response.headers.get("Content-Type", "").split(";")[0].strip()
+            or "application/octet-stream"
+        )
+        extension = guess_extension(mimetype) or ".bin"
+        unique_name = uuid4().hex
+        filename = f"{unique_name}{extension}"
+        filepath = f"tools/{tenant_id}/{filename}"
+        storage.save(filepath, blob)
+
+        tool_file = ToolFile(
+            user_id=user_id,
+            tenant_id=tenant_id,
+            conversation_id=conversation_id,
+            file_key=filepath,
+            mimetype=mimetype,
+            original_url=file_url,
+            name=filename,
+            size=len(blob),
+        )
+
+        db.session.add(tool_file)
+        db.session.commit()
+
+        return tool_file
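The signing scheme above is plain HMAC-SHA256 over a pipe-delimited payload. A self-contained sketch with a made-up secret; the patch compares signatures with `!=`, while `hmac.compare_digest` is used below purely for illustration as the timing-safe variant:

import base64, hashlib, hmac, os, time

secret = b"example-secret-key"  # stands in for dify_config.SECRET_KEY
file_id = "abc123"
timestamp, nonce = str(int(time.time())), os.urandom(16).hex()
payload = f"file-preview|{file_id}|{timestamp}|{nonce}"

def hmac_sign(data: str) -> str:
    # urlsafe base64 of the raw HMAC-SHA256 digest, as in sign_file above
    return base64.urlsafe_b64encode(hmac.new(secret, data.encode(), hashlib.sha256).digest()).decode()

sign = hmac_sign(payload)
# verification recomputes the digest, compares, then enforces the timestamp window
assert hmac.compare_digest(sign, hmac_sign(payload))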
+    @staticmethod
+    def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
+        """
+        get file binary
+
+        :param id: the id of the file
+
+        :return: the binary of the file, mime type
+        """
+        tool_file: ToolFile | None = (
+            db.session.query(ToolFile)
+            .filter(
+                ToolFile.id == id,
+            )
+            .first()
+        )
+
+        if not tool_file:
+            return None
+
+        blob = storage.load_once(tool_file.file_key)
+
+        return blob, tool_file.mimetype
+
+    @staticmethod
+    def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
+        """
+        get file binary
+
+        :param id: the id of the message file
+
+        :return: the binary of the file, mime type
+        """
+        message_file: MessageFile | None = (
+            db.session.query(MessageFile)
+            .filter(
+                MessageFile.id == id,
+            )
+            .first()
+        )
+
+        # extract the tool file id from the message file url, if any;
+        # the url ends with "<tool_file_id>.<extension>"
+        tool_file_id = None
+        if message_file is not None and message_file.url is not None:
+            tool_file_id = message_file.url.split("/")[-1].split(".")[0]
+
+        tool_file: ToolFile | None = (
+            db.session.query(ToolFile)
+            .filter(
+                ToolFile.id == tool_file_id,
+            )
+            .first()
+        )
+
+        if not tool_file:
+            return None
+
+        blob = storage.load_once(tool_file.file_key)
+
+        return blob, tool_file.mimetype
+
+    @staticmethod
+    def get_file_generator_by_tool_file_id(tool_file_id: str):
+        """
+        get file stream
+
+        :param tool_file_id: the id of the tool file
+
+        :return: the stream of the file and the tool file record
+        """
+        tool_file: ToolFile | None = (
+            db.session.query(ToolFile)
+            .filter(
+                ToolFile.id == tool_file_id,
+            )
+            .first()
+        )
+
+        if not tool_file:
+            return None, None
+
+        stream = storage.load_stream(tool_file.file_key)
+
+        return stream, tool_file
+
+
+# init tool_file_parser
+from core.file.tool_file_parser import tool_file_manager
+
+tool_file_manager["manager"] = ToolFileManager
diff --git a/api/core/datasource/tool_label_manager.py b/api/core/datasource/tool_label_manager.py
new file mode 100644
index 0000000000..4787d7d79c
--- /dev/null
+++ b/api/core/datasource/tool_label_manager.py
@@ -0,0 +1,101 @@
+from core.tools.__base.tool_provider import ToolProviderController
+from core.tools.builtin_tool.provider import BuiltinToolProviderController
+from core.tools.custom_tool.provider import ApiToolProviderController
+from core.tools.entities.values import default_tool_label_name_list
+from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
+from extensions.ext_database import db
+from models.tools import ToolLabelBinding
+
+
+class ToolLabelManager:
+    @classmethod
+    def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
+        """
+        Filter tool labels
+        """
+        tool_labels = [label for label in tool_labels if label in default_tool_label_name_list]
+        return list(set(tool_labels))
+
+    @classmethod
+    def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
+        """
+        Update tool labels
+        """
+        labels = cls.filter_tool_labels(labels)
+
+        if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
+            provider_id = controller.provider_id
+        else:
+            raise ValueError("Unsupported tool type")
+
+        # delete old labels
+        db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete()
+
+        # insert new labels
+        for label in labels:
+            db.session.add(
+                ToolLabelBinding(
+                    tool_id=provider_id,
+                    tool_type=controller.provider_type.value,
+
label_name=label, + ) + ) + + db.session.commit() + + @classmethod + def get_tool_labels(cls, controller: ToolProviderController) -> list[str]: + """ + Get tool labels + """ + if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): + provider_id = controller.provider_id + elif isinstance(controller, BuiltinToolProviderController): + return controller.tool_labels + else: + raise ValueError("Unsupported tool type") + + labels = ( + db.session.query(ToolLabelBinding.label_name) + .filter( + ToolLabelBinding.tool_id == provider_id, + ToolLabelBinding.tool_type == controller.provider_type.value, + ) + .all() + ) + + return [label.label_name for label in labels] + + @classmethod + def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]: + """ + Get tools labels + + :param tool_providers: list of tool providers + + :return: dict of tool labels + :key: tool id + :value: list of tool labels + """ + if not tool_providers: + return {} + + for controller in tool_providers: + if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): + raise ValueError("Unsupported tool type") + + provider_ids = [] + for controller in tool_providers: + assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) + provider_ids.append(controller.provider_id) + + labels: list[ToolLabelBinding] = ( + db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() + ) + + tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} + + for label in labels: + tool_labels[label.tool_id].append(label.label_name) + + return tool_labels diff --git a/api/core/datasource/tool_manager.py b/api/core/datasource/tool_manager.py new file mode 100644 index 0000000000..f2d0b74f7c --- /dev/null +++ b/api/core/datasource/tool_manager.py @@ -0,0 +1,870 @@ +import json +import logging +import mimetypes +from collections.abc import Generator +from os import listdir, path +from threading import Lock +from typing import TYPE_CHECKING, Any, Union, cast + +from yarl import URL + +import contexts +from core.plugin.entities.plugin import ToolProviderID +from core.plugin.manager.tool import PluginToolManager +from core.tools.__base.tool_provider import ToolProviderController +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.plugin_tool.provider import PluginToolProviderController +from core.tools.plugin_tool.tool import PluginTool +from core.tools.workflow_as_tool.provider import WorkflowToolProviderController + +if TYPE_CHECKING: + from core.workflow.nodes.tool.entities import ToolEntity + + +from configs import dify_config +from core.agent.entities import AgentToolEntity +from core.app.entities.app_invoke_entities import InvokeFrom +from core.helper.module_import_helper import load_single_subclass_from_source +from core.helper.position_helper import is_filtered +from core.model_runtime.utils.encoders import jsonable_encoder +from core.tools.__base.tool import Tool +from core.tools.builtin_tool.provider import BuiltinToolProviderController +from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort +from core.tools.builtin_tool.tool import BuiltinTool +from core.tools.custom_tool.provider import ApiToolProviderController +from core.tools.custom_tool.tool import ApiTool +from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral +from core.tools.entities.common_entities import I18nObject +from 
core.tools.entities.tool_entities import ( + ApiProviderAuthType, + ToolInvokeFrom, + ToolParameter, + ToolProviderType, +) +from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError +from core.tools.tool_label_manager import ToolLabelManager +from core.tools.utils.configuration import ( + ProviderConfigEncrypter, + ToolParameterConfigurationManager, +) +from core.tools.workflow_as_tool.tool import WorkflowTool +from extensions.ext_database import db +from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider +from services.tools.tools_transform_service import ToolTransformService + +logger = logging.getLogger(__name__) + + +class ToolManager: + _builtin_provider_lock = Lock() + _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} + _builtin_providers_loaded = False + _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} + + @classmethod + def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController: + """ + get the hardcoded provider + """ + if len(cls._hardcoded_providers) == 0: + # init the builtin providers + cls.load_hardcoded_providers_cache() + + return cls._hardcoded_providers[provider] + + @classmethod + def get_builtin_provider( + cls, provider: str, tenant_id: str + ) -> BuiltinToolProviderController | PluginToolProviderController: + """ + get the builtin provider + + :param provider: the name of the provider + :param tenant_id: the id of the tenant + :return: the provider + """ + # split provider to + + if len(cls._hardcoded_providers) == 0: + # init the builtin providers + cls.load_hardcoded_providers_cache() + + if provider not in cls._hardcoded_providers: + # get plugin provider + plugin_provider = cls.get_plugin_provider(provider, tenant_id) + if plugin_provider: + return plugin_provider + + return cls._hardcoded_providers[provider] + + @classmethod + def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController: + """ + get the plugin provider + """ + # check if context is set + try: + contexts.plugin_tool_providers.get() + except LookupError: + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(Lock()) + + with contexts.plugin_tool_providers_lock.get(): + plugin_tool_providers = contexts.plugin_tool_providers.get() + if provider in plugin_tool_providers: + return plugin_tool_providers[provider] + + manager = PluginToolManager() + provider_entity = manager.fetch_tool_provider(tenant_id, provider) + if not provider_entity: + raise ToolProviderNotFoundError(f"plugin provider {provider} not found") + + controller = PluginToolProviderController( + entity=provider_entity.declaration, + plugin_id=provider_entity.plugin_id, + plugin_unique_identifier=provider_entity.plugin_unique_identifier, + tenant_id=tenant_id, + ) + + plugin_tool_providers[provider] = controller + + return controller + + @classmethod + def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None: + """ + get the builtin tool + + :param provider: the name of the provider + :param tool_name: the name of the tool + :param tenant_id: the id of the tenant + :return: the provider, the tool + """ + provider_controller = cls.get_builtin_provider(provider, tenant_id) + tool = provider_controller.get_tool(tool_name) + if tool is None: + raise ToolNotFoundError(f"tool {tool_name} not found") + + return tool + + @classmethod + def get_tool_runtime( + cls, + provider_type: ToolProviderType, + provider_id: str, + tool_name: str, + 
tenant_id: str, + invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, + ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]: + """ + get the tool runtime + + :param provider_type: the type of the provider + :param provider_id: the id of the provider + :param tool_name: the name of the tool + :param tenant_id: the tenant id + :param invoke_from: invoke from + :param tool_invoke_from: the tool invoke from + + :return: the tool + """ + if provider_type == ToolProviderType.BUILT_IN: + # check if the builtin tool need credentials + provider_controller = cls.get_builtin_provider(provider_id, tenant_id) + + builtin_tool = provider_controller.get_tool(tool_name) + if not builtin_tool: + raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found") + + if not provider_controller.need_credentials: + return cast( + BuiltinTool, + builtin_tool.fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) + ), + ) + + if isinstance(provider_controller, PluginToolProviderController): + provider_id_entity = ToolProviderID(provider_id) + # get credentials + builtin_provider: BuiltinToolProvider | None = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == str(provider_id_entity)) + | (BuiltinToolProvider.provider == provider_id_entity.provider_name), + ) + .first() + ) + + if builtin_provider is None: + raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") + else: + builtin_provider = ( + db.session.query(BuiltinToolProvider) + .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) + .first() + ) + + if builtin_provider is None: + raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") + + # decrypt the credentials + credentials = builtin_provider.credentials + tool_configuration = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.entity.identity.name, + ) + + decrypted_credentials = tool_configuration.decrypt(credentials) + + return cast( + BuiltinTool, + builtin_tool.fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials=decrypted_credentials, + runtime_parameters={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) + ), + ) + + elif provider_type == ToolProviderType.API: + api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) + + # decrypt the credentials + tool_configuration = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()], + provider_type=api_provider.provider_type.value, + provider_identity=api_provider.entity.identity.name, + ) + decrypted_credentials = tool_configuration.decrypt(credentials) + + return cast( + ApiTool, + api_provider.get_tool(tool_name).fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials=decrypted_credentials, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) + ), + ) + elif provider_type == ToolProviderType.WORKFLOW: + workflow_provider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == 
provider_id)
+                .first()
+            )
+
+            if workflow_provider is None:
+                raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
+
+            controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
+            controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
+            if controller_tools is None or len(controller_tools) == 0:
+                raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
+
+            return cast(
+                WorkflowTool,
+                controller_tools[0].fork_tool_runtime(
+                    runtime=ToolRuntime(
+                        tenant_id=tenant_id,
+                        credentials={},
+                        invoke_from=invoke_from,
+                        tool_invoke_from=tool_invoke_from,
+                    )
+                ),
+            )
+        elif provider_type == ToolProviderType.APP:
+            raise NotImplementedError("app provider not implemented")
+        elif provider_type == ToolProviderType.PLUGIN:
+            return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
+        else:
+            raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
+
+    @classmethod
+    def get_agent_tool_runtime(
+        cls,
+        tenant_id: str,
+        app_id: str,
+        agent_tool: AgentToolEntity,
+        invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
+    ) -> Tool:
+        """
+        get the agent tool runtime
+        """
+        tool_entity = cls.get_tool_runtime(
+            provider_type=agent_tool.provider_type,
+            provider_id=agent_tool.provider_id,
+            tool_name=agent_tool.tool_name,
+            tenant_id=tenant_id,
+            invoke_from=invoke_from,
+            tool_invoke_from=ToolInvokeFrom.AGENT,
+        )
+        runtime_parameters = {}
+        parameters = tool_entity.get_merged_runtime_parameters()
+        for parameter in parameters:
+            # check file types
+            if (
+                parameter.type
+                in {
+                    ToolParameter.ToolParameterType.SYSTEM_FILES,
+                    ToolParameter.ToolParameterType.FILE,
+                    ToolParameter.ToolParameterType.FILES,
+                }
+                and parameter.required
+            ):
+                raise ValueError(f"file type parameter {parameter.name} not supported in agent")
+
+            if parameter.form == ToolParameter.ToolParameterForm.FORM:
+                # save tool parameter to tool entity memory
+                value = parameter.init_frontend_parameter(agent_tool.tool_parameters.get(parameter.name))
+                runtime_parameters[parameter.name] = value
+
+        # decrypt runtime parameters
+        encryption_manager = ToolParameterConfigurationManager(
+            tenant_id=tenant_id,
+            tool_runtime=tool_entity,
+            provider_name=agent_tool.provider_id,
+            provider_type=agent_tool.provider_type,
+            identity_id=f"AGENT.{app_id}",
+        )
+        runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
+        if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
+            raise ValueError("runtime not found or runtime parameters not found")
+
+        tool_entity.runtime.runtime_parameters.update(runtime_parameters)
+        return tool_entity
+
+    @classmethod
+    def get_workflow_tool_runtime(
+        cls,
+        tenant_id: str,
+        app_id: str,
+        node_id: str,
+        workflow_tool: "ToolEntity",
+        invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
+    ) -> Tool:
+        """
+        get the workflow tool runtime
+        """
+        tool_runtime = cls.get_tool_runtime(
+            provider_type=workflow_tool.provider_type,
+            provider_id=workflow_tool.provider_id,
+            tool_name=workflow_tool.tool_name,
+            tenant_id=tenant_id,
+            invoke_from=invoke_from,
+            tool_invoke_from=ToolInvokeFrom.WORKFLOW,
+        )
+        runtime_parameters = {}
+        parameters = tool_runtime.get_merged_runtime_parameters()
+
+        for parameter in parameters:
+            # save tool parameter to tool entity memory
+            if parameter.form == ToolParameter.ToolParameterForm.FORM:
+                value =
parameter.init_frontend_parameter(workflow_tool.tool_configurations.get(parameter.name)) + runtime_parameters[parameter.name] = value + + # decrypt runtime parameters + encryption_manager = ToolParameterConfigurationManager( + tenant_id=tenant_id, + tool_runtime=tool_runtime, + provider_name=workflow_tool.provider_id, + provider_type=workflow_tool.provider_type, + identity_id=f"WORKFLOW.{app_id}.{node_id}", + ) + + if runtime_parameters: + runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + + tool_runtime.runtime.runtime_parameters.update(runtime_parameters) + return tool_runtime + + @classmethod + def get_tool_runtime_from_plugin( + cls, + tool_type: ToolProviderType, + tenant_id: str, + provider: str, + tool_name: str, + tool_parameters: dict[str, Any], + ) -> Tool: + """ + get tool runtime from plugin + """ + tool_entity = cls.get_tool_runtime( + provider_type=tool_type, + provider_id=provider, + tool_name=tool_name, + tenant_id=tenant_id, + invoke_from=InvokeFrom.SERVICE_API, + tool_invoke_from=ToolInvokeFrom.PLUGIN, + ) + runtime_parameters = {} + parameters = tool_entity.get_merged_runtime_parameters() + for parameter in parameters: + if parameter.form == ToolParameter.ToolParameterForm.FORM: + # save tool parameter to tool entity memory + value = parameter.init_frontend_parameter(tool_parameters.get(parameter.name)) + runtime_parameters[parameter.name] = value + + tool_entity.runtime.runtime_parameters.update(runtime_parameters) + return tool_entity + + @classmethod + def get_hardcoded_provider_icon(cls, provider: str) -> tuple[str, str]: + """ + get the absolute path of the icon of the hardcoded provider + + :param provider: the name of the provider + :return: the absolute path of the icon, the mime type of the icon + """ + # get provider + provider_controller = cls.get_hardcoded_provider(provider) + + absolute_path = path.join( + path.dirname(path.realpath(__file__)), + "builtin_tool", + "providers", + provider, + "_assets", + provider_controller.entity.identity.icon, + ) + # check if the icon exists + if not path.exists(absolute_path): + raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found") + + # get the mime type + mime_type, _ = mimetypes.guess_type(absolute_path) + mime_type = mime_type or "application/octet-stream" + + return absolute_path, mime_type + + @classmethod + def list_hardcoded_providers(cls): + # use cache first + if cls._builtin_providers_loaded: + yield from list(cls._hardcoded_providers.values()) + return + + with cls._builtin_provider_lock: + if cls._builtin_providers_loaded: + yield from list(cls._hardcoded_providers.values()) + return + + yield from cls._list_hardcoded_providers() + + @classmethod + def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]: + """ + list all the plugin providers + """ + manager = PluginToolManager() + provider_entities = manager.fetch_tool_providers(tenant_id) + return [ + PluginToolProviderController( + entity=provider.declaration, + plugin_id=provider.plugin_id, + plugin_unique_identifier=provider.plugin_unique_identifier, + tenant_id=tenant_id, + ) + for provider in provider_entities + ] + + @classmethod + def list_builtin_providers( + cls, tenant_id: str + ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]: + """ + list all the builtin providers + """ + yield from cls.list_hardcoded_providers() + # get plugin providers + yield from cls.list_plugin_providers(tenant_id) + + @classmethod + def 
_list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
+        """
+        list all the hardcoded providers
+        """
+        for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")):
+            if provider_path.startswith("__"):
+                continue
+
+            if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)):
+                # init provider
+                try:
+                    provider_class = load_single_subclass_from_source(
+                        module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
+                        script_path=path.join(
+                            path.dirname(path.realpath(__file__)),
+                            "builtin_tool",
+                            "providers",
+                            provider_path,
+                            f"{provider_path}.py",
+                        ),
+                        parent_type=BuiltinToolProviderController,
+                    )
+                    provider: BuiltinToolProviderController = provider_class()
+                    cls._hardcoded_providers[provider.entity.identity.name] = provider
+                    for tool in provider.get_tools():
+                        cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
+                    yield provider
+
+                except Exception:
+                    logger.exception(f"load builtin provider {provider_path}")
+                    continue
+        # set builtin providers loaded
+        cls._builtin_providers_loaded = True
+
+    @classmethod
+    def load_hardcoded_providers_cache(cls):
+        for _ in cls.list_hardcoded_providers():
+            pass
+
+    @classmethod
+    def clear_hardcoded_providers_cache(cls):
+        cls._hardcoded_providers = {}
+        cls._builtin_providers_loaded = False
+
+    @classmethod
+    def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
+        """
+        get the tool label
+
+        :param tool_name: the name of the tool
+
+        :return: the label of the tool
+        """
+        if len(cls._builtin_tools_labels) == 0:
+            # init the builtin providers
+            cls.load_hardcoded_providers_cache()
+
+        if tool_name not in cls._builtin_tools_labels:
+            return None
+
+        return cls._builtin_tools_labels[tool_name]
+
+    @classmethod
+    def list_providers_from_api(
+        cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
+    ) -> list[ToolProviderApiEntity]:
+        result_providers: dict[str, ToolProviderApiEntity] = {}
+
+        filters = []
+        if not typ:
+            filters.extend(["builtin", "api", "workflow"])
+        else:
+            filters.append(typ)
+
+        with db.session.no_autoflush:
+            if "builtin" in filters:
+                # get builtin providers
+                builtin_providers = cls.list_builtin_providers(tenant_id)
+
+                # get db builtin providers
+                db_builtin_providers: list[BuiltinToolProvider] = (
+                    db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
+                )
+
+                # rewrite db_builtin_providers
+                for db_provider in db_builtin_providers:
+                    tool_provider_id = str(ToolProviderID(db_provider.provider))
+                    db_provider.provider = tool_provider_id
+
+                def find_db_builtin_provider(provider):
+                    return next((x for x in db_builtin_providers if x.provider == provider), None)
+
+                # append builtin providers
+                for provider in builtin_providers:
+                    # handle include, exclude
+                    if is_filtered(
+                        include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
+                        exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
+                        data=provider,
+                        name_func=lambda x: x.identity.name,
+                    ):
+                        continue
+
+                    user_provider = ToolTransformService.builtin_provider_to_user_provider(
+                        provider_controller=provider,
+                        db_provider=find_db_builtin_provider(provider.entity.identity.name),
+                        decrypt_credentials=False,
+                    )
+
+                    if isinstance(provider, PluginToolProviderController):
+                        result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
+                    else:
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider + + # get db api providers + + if "api" in filters: + db_api_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() + ) + + api_provider_controllers: list[dict[str, Any]] = [ + {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} + for provider in db_api_providers + ] + + # get labels + labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers]) + + for api_provider_controller in api_provider_controllers: + user_provider = ToolTransformService.api_provider_to_user_provider( + provider_controller=api_provider_controller["controller"], + db_provider=api_provider_controller["provider"], + decrypt_credentials=False, + labels=labels.get(api_provider_controller["controller"].provider_id, []), + ) + result_providers[f"api_provider.{user_provider.name}"] = user_provider + + if "workflow" in filters: + # get workflow providers + workflow_providers: list[WorkflowToolProvider] = ( + db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() + ) + + workflow_provider_controllers: list[WorkflowToolProviderController] = [] + for provider in workflow_providers: + try: + workflow_provider_controllers.append( + ToolTransformService.workflow_provider_to_controller(db_provider=provider) + ) + except Exception: + # app has been deleted + pass + + labels = ToolLabelManager.get_tools_labels( + [cast(ToolProviderController, controller) for controller in workflow_provider_controllers] + ) + + for provider_controller in workflow_provider_controllers: + user_provider = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=provider_controller, + labels=labels.get(provider_controller.provider_id, []), + ) + result_providers[f"workflow_provider.{user_provider.name}"] = user_provider + + return BuiltinToolProviderSort.sort(list(result_providers.values())) + + @classmethod + def get_api_provider_controller( + cls, tenant_id: str, provider_id: str + ) -> tuple[ApiToolProviderController, dict[str, Any]]: + """ + get the api provider + + :param tenant_id: the id of the tenant + :param provider_id: the id of the provider + + :return: the provider controller, the credentials + """ + provider: ApiToolProvider | None = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.id == provider_id, + ApiToolProvider.tenant_id == tenant_id, + ) + .first() + ) + + if provider is None: + raise ToolProviderNotFoundError(f"api provider {provider_id} not found") + + controller = ApiToolProviderController.from_db( + provider, + ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, + ) + controller.load_bundled_tools(provider.tools) + + return controller, provider.credentials + + @classmethod + def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: + """ + get api provider + """ + """ + get tool provider + """ + provider_name = provider + provider_obj: ApiToolProvider | None = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider, + ) + .first() + ) + + if provider_obj is None: + raise ValueError(f"you have not added provider {provider_name}") + + try: + credentials = json.loads(provider_obj.credentials_str) or {} + except Exception: + credentials = {} + + # package tool provider controller + controller = 
+    @classmethod
+    def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
+        """
+        get the api tool provider
+        """
+        provider_name = provider
+        provider_obj: ApiToolProvider | None = (
+            db.session.query(ApiToolProvider)
+            .filter(
+                ApiToolProvider.tenant_id == tenant_id,
+                ApiToolProvider.name == provider,
+            )
+            .first()
+        )
+
+        if provider_obj is None:
+            raise ValueError(f"you have not added provider {provider_name}")
+
+        try:
+            credentials = json.loads(provider_obj.credentials_str) or {}
+        except Exception:
+            credentials = {}
+
+        # package tool provider controller; .get() avoids a KeyError when credentials are empty
+        controller = ApiToolProviderController.from_db(
+            provider_obj,
+            ApiProviderAuthType.API_KEY if credentials.get("auth_type") == "api_key" else ApiProviderAuthType.NONE,
+        )
+        # init tool configuration
+        tool_configuration = ProviderConfigEncrypter(
+            tenant_id=tenant_id,
+            config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
+            provider_type=controller.provider_type.value,
+            provider_identity=controller.entity.identity.name,
+        )
+
+        decrypted_credentials = tool_configuration.decrypt(credentials)
+        masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
+
+        try:
+            icon = json.loads(provider_obj.icon)
+        except Exception:
+            icon = {"background": "#252525", "content": "\ud83d\ude01"}
+
+        # add tool labels
+        labels = ToolLabelManager.get_tool_labels(controller)
+
+        return cast(
+            dict,
+            jsonable_encoder(
+                {
+                    "schema_type": provider_obj.schema_type,
+                    "schema": provider_obj.schema,
+                    "tools": provider_obj.tools,
+                    "icon": icon,
+                    "description": provider_obj.description,
+                    "credentials": masked_credentials,
+                    "privacy_policy": provider_obj.privacy_policy,
+                    "custom_disclaimer": provider_obj.custom_disclaimer,
+                    "labels": labels,
+                }
+            ),
+        )
+
+    @classmethod
+    def generate_builtin_tool_icon_url(cls, provider_id: str) -> str:
+        return str(
+            URL(dify_config.CONSOLE_API_URL or "/")
+            / "console"
+            / "api"
+            / "workspaces"
+            / "current"
+            / "tool-provider"
+            / "builtin"
+            / provider_id
+            / "icon"
+        )
+
+    @classmethod
+    def generate_plugin_tool_icon_url(cls, tenant_id: str, filename: str) -> str:
+        return str(
+            URL(dify_config.CONSOLE_API_URL or "/")
+            / "console"
+            / "api"
+            / "workspaces"
+            / "current"
+            / "plugin"
+            / "icon"
+            % {"tenant_id": tenant_id, "filename": filename}
+        )
+
+    @classmethod
+    def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
+        try:
+            workflow_provider: WorkflowToolProvider | None = (
+                db.session.query(WorkflowToolProvider)
+                .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
+                .first()
+            )
+
+            if workflow_provider is None:
+                raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
+
+            icon: dict = json.loads(workflow_provider.icon)
+            return icon
+        except Exception:
+            return {"background": "#252525", "content": "\ud83d\ude01"}
+
+    @classmethod
+    def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
+        try:
+            api_provider: ApiToolProvider | None = (
+                db.session.query(ApiToolProvider)
+                .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
+                .first()
+            )
+
+            if api_provider is None:
+                raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
+
+            icon: dict = json.loads(api_provider.icon)
+            return icon
+        except Exception:
+            return {"background": "#252525", "content": "\ud83d\ude01"}
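The icon URLs above are composed with yarl, where `/` appends a path segment and `%` is shorthand for `update_query`; a standalone example with a hypothetical base URL in place of `dify_config.CONSOLE_API_URL`:

from yarl import URL

base = URL("https://console.example.com")  # hypothetical CONSOLE_API_URL
icon_url = base / "console" / "api" / "workspaces" / "current" / "plugin" / "icon" % {
    "tenant_id": "t-123",
    "filename": "icon.svg",
}
print(str(icon_url))
# https://console.example.com/console/api/workspaces/current/plugin/icon?tenant_id=t-123&filename=icon.svg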
"#252525", "content": "\ud83d\ude01"} + return cls.generate_builtin_tool_icon_url(provider_id) + elif provider_type == ToolProviderType.API: + return cls.generate_api_tool_icon_url(tenant_id, provider_id) + elif provider_type == ToolProviderType.WORKFLOW: + return cls.generate_workflow_tool_icon_url(tenant_id, provider_id) + elif provider_type == ToolProviderType.PLUGIN: + provider = ToolManager.get_builtin_provider(provider_id, tenant_id) + if isinstance(provider, PluginToolProviderController): + try: + return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) + except Exception: + return {"background": "#252525", "content": "\ud83d\ude01"} + raise ValueError(f"plugin provider {provider_id} not found") + else: + raise ValueError(f"provider type {provider_type} not found") + + +ToolManager.load_hardcoded_providers_cache() diff --git a/api/core/datasource/utils/__init__.py b/api/core/datasource/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/datasource/utils/configuration.py b/api/core/datasource/utils/configuration.py new file mode 100644 index 0000000000..6a5fba65bd --- /dev/null +++ b/api/core/datasource/utils/configuration.py @@ -0,0 +1,265 @@ +from copy import deepcopy +from typing import Any + +from pydantic import BaseModel + +from core.entities.provider_entities import BasicProviderConfig +from core.helper import encrypter +from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType +from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType +from core.tools.__base.tool import Tool +from core.tools.entities.tool_entities import ( + ToolParameter, + ToolProviderType, +) + + +class ProviderConfigEncrypter(BaseModel): + tenant_id: str + config: list[BasicProviderConfig] + provider_type: str + provider_identity: str + + def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: + """ + deep copy data + """ + return deepcopy(data) + + def encrypt(self, data: dict[str, str]) -> dict[str, str]: + """ + encrypt tool credentials with tenant id + + return a deep copy of credentials with encrypted values + """ + data = self._deep_copy(data) + + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "") + data[field_name] = encrypted + + return data + + def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: + """ + mask tool credentials + + return a deep copy of credentials with masked values + """ + data = self._deep_copy(data) + + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + if len(data[field_name]) > 6: + data[field_name] = ( + data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:] + ) + else: + data[field_name] = "*" * len(data[field_name]) + + return data + + def decrypt(self, data: dict[str, str]) -> dict[str, str]: + """ + decrypt tool credentials with tenant id + + return a deep copy of credentials with decrypted values + """ + cache = ToolProviderCredentialsCache( + 
tenant_id=self.tenant_id, + identity_id=f"{self.provider_type}.{self.provider_identity}", + cache_type=ToolProviderCredentialsCacheType.PROVIDER, + ) + cached_credentials = cache.get() + if cached_credentials: + return cached_credentials + data = self._deep_copy(data) + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + try: + # if the value is None or empty string, skip decrypt + if not data[field_name]: + continue + + data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) + except Exception: + pass + + cache.set(data) + return data + + def delete_tool_credentials_cache(self): + cache = ToolProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=f"{self.provider_type}.{self.provider_identity}", + cache_type=ToolProviderCredentialsCacheType.PROVIDER, + ) + cache.delete() + + +class ToolParameterConfigurationManager: + """ + Tool parameter configuration manager + """ + + tenant_id: str + tool_runtime: Tool + provider_name: str + provider_type: ToolProviderType + identity_id: str + + def __init__( + self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str + ) -> None: + self.tenant_id = tenant_id + self.tool_runtime = tool_runtime + self.provider_name = provider_name + self.provider_type = provider_type + self.identity_id = identity_id + + def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + deep copy parameters + """ + return deepcopy(parameters) + + def _merge_parameters(self) -> list[ToolParameter]: + """ + merge parameters + """ + # get tool parameters + tool_parameters = self.tool_runtime.entity.parameters or [] + # get tool runtime parameters + runtime_parameters = self.tool_runtime.get_runtime_parameters() + # override parameters + current_parameters = tool_parameters.copy() + for runtime_parameter in runtime_parameters: + found = False + for index, parameter in enumerate(current_parameters): + if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form: + current_parameters[index] = runtime_parameter + found = True + break + + if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: + current_parameters.append(runtime_parameter) + + return current_parameters + + def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + mask tool parameters + + return a deep copy of parameters with masked values + """ + parameters = self._deep_copy(parameters) + + # override parameters + current_parameters = self._merge_parameters() + + for parameter in current_parameters: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): + if parameter.name in parameters: + if len(parameters[parameter.name]) > 6: + parameters[parameter.name] = ( + parameters[parameter.name][:2] + + "*" * (len(parameters[parameter.name]) - 4) + + parameters[parameter.name][-2:] + ) + else: + parameters[parameter.name] = "*" * len(parameters[parameter.name]) + + return parameters + + def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + encrypt tool parameters with tenant id + + return a deep copy of parameters with encrypted values + """ + # override parameters + current_parameters = 
self._merge_parameters() + + parameters = self._deep_copy(parameters) + + for parameter in current_parameters: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): + if parameter.name in parameters: + encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name]) + parameters[parameter.name] = encrypted + + return parameters + + def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + decrypt tool parameters with tenant id + + return a deep copy of parameters with decrypted values + """ + + cache = ToolParameterCache( + tenant_id=self.tenant_id, + provider=f"{self.provider_type.value}.{self.provider_name}", + tool_name=self.tool_runtime.entity.identity.name, + cache_type=ToolParameterCacheType.PARAMETER, + identity_id=self.identity_id, + ) + cached_parameters = cache.get() + if cached_parameters: + return cached_parameters + + # override parameters + current_parameters = self._merge_parameters() + has_secret_input = False + + for parameter in current_parameters: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): + if parameter.name in parameters: + try: + has_secret_input = True + parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name]) + except Exception: + pass + + if has_secret_input: + cache.set(parameters) + + return parameters + + def delete_tool_parameters_cache(self): + cache = ToolParameterCache( + tenant_id=self.tenant_id, + provider=f"{self.provider_type.value}.{self.provider_name}", + tool_name=self.tool_runtime.entity.identity.name, + cache_type=ToolParameterCacheType.PARAMETER, + identity_id=self.identity_id, + ) + cache.delete() diff --git a/api/core/datasource/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/datasource/utils/dataset_retriever/dataset_multi_retriever_tool.py new file mode 100644 index 0000000000..032274b87e --- /dev/null +++ b/api/core/datasource/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -0,0 +1,199 @@ +import threading +from typing import Any + +from flask import Flask, current_app +from pydantic import BaseModel, Field + +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.models.document import Document as RagDocument +from core.rag.rerank.rerank_model import RerankModelRunner +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool +from extensions.ext_database import db +from models.dataset import Dataset, Document, DocumentSegment + +default_retrieval_model: dict[str, Any] = { + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, +} + + +class DatasetMultiRetrieverToolInput(BaseModel): + query: str = Field(..., description="dataset multi retriever and rerank") + + +class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): + """Tool for querying multi dataset.""" + + name: str = "dataset_" + args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput + 
description: str = "dataset multi retriever and rerank. " + dataset_ids: list[str] + reranking_provider_name: str + reranking_model_name: str + + @classmethod + def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs): + return cls( + name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs + ) + + def _run(self, query: str) -> str: + threads = [] + all_documents: list[RagDocument] = [] + for dataset_id in self.dataset_ids: + retrieval_thread = threading.Thread( + target=self._retriever, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "dataset_id": dataset_id, + "query": query, + "all_documents": all_documents, + "hit_callbacks": self.hit_callbacks, + }, + ) + threads.append(retrieval_thread) + retrieval_thread.start() + for thread in threads: + thread.join() + # do rerank for searched documents + model_manager = ModelManager() + rerank_model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, + provider=self.reranking_provider_name, + model_type=ModelType.RERANK, + model=self.reranking_model_name, + ) + + rerank_runner = RerankModelRunner(rerank_model_instance) + all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k) + + for hit_callback in self.hit_callbacks: + hit_callback.on_tool_end(all_documents) + + document_score_list = {} + for item in all_documents: + if item.metadata and item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] + + document_context_list = [] + index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] + segments = DocumentSegment.query.filter( + DocumentSegment.dataset_id.in_(self.dataset_ids), + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids), + ).all() + + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) + for segment in sorted_segments: + if segment.answer: + document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") + else: + document_context_list.append(segment.get_sign_content()) + if self.return_resource: + context_list = [] + resource_number = 1 + for segment in sorted_segments: + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = Document.query.filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() + if dataset and document: + source = { + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": self.retriever_from, + "score": document_score_list.get(segment.index_node_id, None), + "doc_metadata": document.doc_metadata, + } + + if self.retriever_from == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash + if segment.answer: + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" + else: + source["content"] = segment.content + context_list.append(source) + 
+                        resource_number += 1
+
+                for hit_callback in self.hit_callbacks:
+                    hit_callback.return_retriever_resource_info(context_list)
+
+                return "\n".join(document_context_list)
+            return ""
+
+        raise RuntimeError("no segments found")
+
+    def _retriever(
+        self,
+        flask_app: Flask,
+        dataset_id: str,
+        query: str,
+        all_documents: list,
+        hit_callbacks: list[DatasetIndexToolCallbackHandler],
+    ):
+        with flask_app.app_context():
+            dataset = (
+                db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
+            )
+
+            if not dataset:
+                return []
+
+            for hit_callback in hit_callbacks:
+                hit_callback.on_query(query, dataset.id)
+
+            # get the retrieval model; if it is not set, use the default
+            retrieval_model = dataset.retrieval_model or default_retrieval_model
+
+            if dataset.indexing_technique == "economy":
+                # use keyword table query
+                documents = RetrievalService.retrieve(
+                    retrieval_method="keyword_search",
+                    dataset_id=dataset.id,
+                    query=query,
+                    top_k=retrieval_model.get("top_k") or 2,
+                )
+                if documents:
+                    all_documents.extend(documents)
+            else:
+                if self.top_k > 0:
+                    # retrieval source
+                    documents = RetrievalService.retrieve(
+                        retrieval_method=retrieval_model["search_method"],
+                        dataset_id=dataset.id,
+                        query=query,
+                        top_k=retrieval_model.get("top_k") or 2,
+                        score_threshold=retrieval_model.get("score_threshold", 0.0)
+                        if retrieval_model["score_threshold_enabled"]
+                        else 0.0,
+                        reranking_model=retrieval_model.get("reranking_model", None)
+                        if retrieval_model["reranking_enable"]
+                        else None,
+                        reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
+                        weights=retrieval_model.get("weights", None),
+                    )
+
+                    all_documents.extend(documents)
diff --git a/api/core/datasource/utils/dataset_retriever/dataset_retriever_base_tool.py b/api/core/datasource/utils/dataset_retriever/dataset_retriever_base_tool.py
new file mode 100644
index 0000000000..a4d2de3b1c
--- /dev/null
+++ b/api/core/datasource/utils/dataset_retriever/dataset_retriever_base_tool.py
@@ -0,0 +1,32 @@
+from abc import ABC, abstractmethod
+from typing import Any, Optional
+
+from pydantic import BaseModel, ConfigDict
+
+from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+
+
+class DatasetRetrieverBaseTool(BaseModel, ABC):
+    """Tool for querying a Dataset."""
+
+    name: str = "dataset"
+    description: str = "use this to retrieve a dataset. "
+    tenant_id: str
+    top_k: int = 2
+    score_threshold: Optional[float] = None
+    hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
+    return_resource: bool
+    retriever_from: str
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    @abstractmethod
+    def _run(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:
+        """Use the tool.
+
+        Add run_manager: Optional[CallbackManagerForToolRun] = None
+        to child implementations to enable tracing.
+        """
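DatasetRetrieverBaseTool leaves `_run` as the single abstract hook; the shared fields above (top_k, score_threshold, hit_callbacks, return_resource, retriever_from) are configuration every concrete retriever inherits. A minimal sketch of a subclass, illustrative only and not part of this patch (the class name and its echo behavior are invented for the example):

    class EchoRetrieverTool(DatasetRetrieverBaseTool):
        """Hypothetical retriever that notifies hit callbacks, then echoes the query."""

        def _run(self, query: str) -> str:
            # on_query(query, dataset_id) is the same callback hook the real
            # retrievers above invoke; "echo-dataset" is a placeholder id.
            for hit_callback in self.hit_callbacks:
                hit_callback.on_query(query, "echo-dataset")
            return query

A concrete instance still has to supply the required fields, e.g. EchoRetrieverTool(tenant_id="...", return_resource=False, retriever_from="dev").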
diff --git a/api/core/datasource/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/datasource/utils/dataset_retriever/dataset_retriever_tool.py
new file mode 100644
index 0000000000..63260cfac3
--- /dev/null
+++ b/api/core/datasource/utils/dataset_retriever/dataset_retriever_tool.py
@@ -0,0 +1,202 @@
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.entities.context_entities import DocumentContext
+from core.rag.models.document import Document as RetrievalDocument
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
+from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
+from extensions.ext_database import db
+from models.dataset import Dataset
+from models.dataset import Document as DatasetDocument
+from services.external_knowledge_service import ExternalDatasetService
+
+default_retrieval_model = {
+    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
+    "reranking_enable": False,
+    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
+    "reranking_mode": "reranking_model",
+    "top_k": 2,
+    "score_threshold_enabled": False,
+}
+
+
+class DatasetRetrieverToolInput(BaseModel):
+    query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
+
+
+class DatasetRetrieverTool(DatasetRetrieverBaseTool):
+    """Tool for querying a Dataset."""
+
+    name: str = "dataset"
+    args_schema: type[BaseModel] = DatasetRetrieverToolInput
+    description: str = "use this to retrieve a dataset. "
+    dataset_id: str
+
+    @classmethod
+    def from_dataset(cls, dataset: Dataset, **kwargs):
+        description = dataset.description
+        if not description:
+            description = "useful for when you want to answer queries about the " + dataset.name
+
+        description = description.replace("\n", "").replace("\r", "")
+        return cls(
+            name=f"dataset_{dataset.id.replace('-', '_')}",
+            tenant_id=dataset.tenant_id,
+            dataset_id=dataset.id,
+            description=description,
+            **kwargs,
+        )
+
+    def _run(self, query: str) -> str:
+        dataset = (
+            db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
+        )
+
+        if not dataset:
+            return ""
+        for hit_callback in self.hit_callbacks:
+            hit_callback.on_query(query, dataset.id)
+        if dataset.provider == "external":
+            results = []
+            external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
+                tenant_id=dataset.tenant_id,
+                dataset_id=dataset.id,
+                query=query,
+                external_retrieval_parameters=dataset.retrieval_model,
+            )
+            for external_document in external_documents:
+                document = RetrievalDocument(
+                    page_content=external_document.get("content"),
+                    metadata=external_document.get("metadata"),
+                    provider="external",
+                )
+                if document.metadata is not None:
+                    document.metadata["score"] = external_document.get("score")
+                    document.metadata["title"] = external_document.get("title")
+                    document.metadata["dataset_id"] = dataset.id
+                    document.metadata["dataset_name"] = dataset.name
+                    results.append(document)
+            # deal with external documents
+            context_list = []
+            for position, item in enumerate(results, start=1):
+                if item.metadata is not None:
+                    source = {
+                        "position": position,
+                        "dataset_id": item.metadata.get("dataset_id"),
+                        "dataset_name": item.metadata.get("dataset_name"),
+                        "document_name":
item.metadata.get("title"), + "data_source_type": "external", + "retriever_from": self.retriever_from, + "score": item.metadata.get("score"), + "title": item.metadata.get("title"), + "content": item.page_content, + } + context_list.append(source) + for hit_callback in self.hit_callbacks: + hit_callback.return_retriever_resource_info(context_list) + + return str("\n".join([item.page_content for item in results])) + else: + # get retrieval model , if the model is not setting , using default + retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model + if dataset.indexing_technique == "economy": + # use keyword table query + documents = RetrievalService.retrieve( + retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k + ) + return str("\n".join([document.page_content for document in documents])) + else: + if self.top_k > 0: + # retrieval source + documents = RetrievalService.retrieve( + retrieval_method=retrieval_model.get("search_method", "semantic_search"), + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else 0.0, + reranking_model=retrieval_model.get("reranking_model") + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", + weights=retrieval_model.get("weights"), + ) + else: + documents = [] + for hit_callback in self.hit_callbacks: + hit_callback.on_tool_end(documents) + document_score_list = {} + if dataset.indexing_technique != "economy": + for item in documents: + if item.metadata is not None and item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] + document_context_list = [] + records = RetrievalService.format_retrieval_documents(documents) + if records: + for record in records: + segment = record.segment + if segment.answer: + document_context_list.append( + DocumentContext( + content=f"question:{segment.get_sign_content()} answer:{segment.answer}", + score=record.score, + ) + ) + else: + document_context_list.append( + DocumentContext( + content=segment.get_sign_content(), + score=record.score, + ) + ) + retrieval_resource_list = [] + if self.return_resource: + for record in records: + segment = record.segment + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = DatasetDocument.query.filter( + DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ).first() + if dataset and document: + source = { + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, # type: ignore + "document_name": document.name, # type: ignore + "data_source_type": document.data_source_type, # type: ignore + "segment_id": segment.id, + "retriever_from": self.retriever_from, + "score": record.score or 0.0, + "doc_metadata": document.doc_metadata, # type: ignore + } + + if self.retriever_from == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash + if segment.answer: + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" + else: + source["content"] = segment.content + retrieval_resource_list.append(source) + + if self.return_resource and retrieval_resource_list: + retrieval_resource_list = sorted( + retrieval_resource_list, + 
key=lambda x: x.get("score") or 0.0, + reverse=True, + ) + for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore + item["position"] = position # type: ignore + for hit_callback in self.hit_callbacks: + hit_callback.return_retriever_resource_info(retrieval_resource_list) + if document_context_list: + document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) + return str("\n".join([document_context.content for document_context in document_context_list])) + return "" diff --git a/api/core/datasource/utils/dataset_retriever_tool.py b/api/core/datasource/utils/dataset_retriever_tool.py new file mode 100644 index 0000000000..b73dec4ebc --- /dev/null +++ b/api/core/datasource/utils/dataset_retriever_tool.py @@ -0,0 +1,134 @@ +from collections.abc import Generator +from typing import Any, Optional + +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import InvokeFrom +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.tools.__base.tool import Tool +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolDescription, + ToolEntity, + ToolIdentity, + ToolInvokeMessage, + ToolParameter, + ToolProviderType, +) +from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool + + +class DatasetRetrieverTool(Tool): + retrieval_tool: DatasetRetrieverBaseTool + + def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None: + super().__init__(entity, runtime) + self.retrieval_tool = retrieval_tool + + @staticmethod + def get_dataset_tools( + tenant_id: str, + dataset_ids: list[str], + retrieve_config: DatasetRetrieveConfigEntity | None, + return_resource: bool, + invoke_from: InvokeFrom, + hit_callback: DatasetIndexToolCallbackHandler, + ) -> list["DatasetRetrieverTool"]: + """ + get dataset tool + """ + # check if retrieve_config is valid + if dataset_ids is None or len(dataset_ids) == 0: + return [] + if retrieve_config is None: + return [] + + feature = DatasetRetrieval() + + # save original retrieve strategy, and set retrieve strategy to SINGLE + # Agent only support SINGLE mode + original_retriever_mode = retrieve_config.retrieve_strategy + retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE + retrieval_tools = feature.to_dataset_retriever_tool( + tenant_id=tenant_id, + dataset_ids=dataset_ids, + retrieve_config=retrieve_config, + return_resource=return_resource, + invoke_from=invoke_from, + hit_callback=hit_callback, + ) + if retrieval_tools is None or len(retrieval_tools) == 0: + return [] + + # restore retrieve strategy + retrieve_config.retrieve_strategy = original_retriever_mode + + # convert retrieval tools to Tools + tools = [] + for retrieval_tool in retrieval_tools: + tool = DatasetRetrieverTool( + retrieval_tool=retrieval_tool, + entity=ToolEntity( + identity=ToolIdentity( + provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="") + ), + parameters=[], + description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description), + ), + runtime=ToolRuntime(tenant_id=tenant_id), + ) + + tools.append(tool) + + return tools + + def get_runtime_parameters( + 
self, + conversation_id: Optional[str] = None, + app_id: Optional[str] = None, + message_id: Optional[str] = None, + ) -> list[ToolParameter]: + return [ + ToolParameter( + name="query", + label=I18nObject(en_US="", zh_Hans=""), + human_description=I18nObject(en_US="", zh_Hans=""), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Query for the dataset to be used to retrieve the dataset.", + required=True, + default="", + placeholder=I18nObject(en_US="", zh_Hans=""), + ), + ] + + def tool_provider_type(self) -> ToolProviderType: + return ToolProviderType.DATASET_RETRIEVAL + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: Optional[str] = None, + app_id: Optional[str] = None, + message_id: Optional[str] = None, + ) -> Generator[ToolInvokeMessage, None, None]: + """ + invoke dataset retriever tool + """ + query = tool_parameters.get("query") + if not query: + yield self.create_text_message(text="please input query") + else: + # invoke dataset retriever tool + result = self.retrieval_tool._run(query=query) + yield self.create_text_message(text=result) + + def validate_credentials( + self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False + ) -> str | None: + """ + validate the credentials for dataset retriever tool + """ + pass diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py new file mode 100644 index 0000000000..6fd0c201e3 --- /dev/null +++ b/api/core/datasource/utils/message_transformer.py @@ -0,0 +1,121 @@ +import logging +from collections.abc import Generator +from mimetypes import guess_extension +from typing import Optional + +from core.file import File, FileTransferMethod, FileType +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool_file_manager import ToolFileManager + +logger = logging.getLogger(__name__) + + +class ToolFileMessageTransformer: + @classmethod + def transform_tool_invoke_messages( + cls, + messages: Generator[ToolInvokeMessage, None, None], + user_id: str, + tenant_id: str, + conversation_id: Optional[str] = None, + ) -> Generator[ToolInvokeMessage, None, None]: + """ + Transform tool message and handle file download + """ + for message in messages: + if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}: + yield message + elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance( + message.message, ToolInvokeMessage.TextMessage + ): + # try to download image + try: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + + file = ToolFileManager.create_file_by_url( + user_id=user_id, + tenant_id=tenant_id, + file_url=message.message.text, + conversation_id=conversation_id, + ) + + url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}" + + yield ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=ToolInvokeMessage.TextMessage(text=url), + meta=message.meta.copy() if message.meta is not None else {}, + ) + except Exception as e: + yield ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.TEXT, + message=ToolInvokeMessage.TextMessage( + text=f"Failed to download image: {message.message.text}: {e}" + ), + meta=message.meta.copy() if message.meta is not None else {}, + ) + elif message.type == ToolInvokeMessage.MessageType.BLOB: + # get mime type and save blob to storage + meta = message.meta or {} + + mimetype = 
meta.get("mime_type", "application/octet-stream") + # get filename from meta + filename = meta.get("file_name", None) + # if message is str, encode it to bytes + + if not isinstance(message.message, ToolInvokeMessage.BlobMessage): + raise ValueError("unexpected message type") + + # FIXME: should do a type check here. + assert isinstance(message.message.blob, bytes) + file = ToolFileManager.create_file_by_raw( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=conversation_id, + file_binary=message.message.blob, + mimetype=mimetype, + filename=filename, + ) + + url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype)) + + # check if file is image + if "image" in mimetype: + yield ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=ToolInvokeMessage.TextMessage(text=url), + meta=meta.copy() if meta is not None else {}, + ) + else: + yield ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.BINARY_LINK, + message=ToolInvokeMessage.TextMessage(text=url), + meta=meta.copy() if meta is not None else {}, + ) + elif message.type == ToolInvokeMessage.MessageType.FILE: + meta = message.meta or {} + file = meta.get("file", None) + if isinstance(file, File): + if file.transfer_method == FileTransferMethod.TOOL_FILE: + assert file.related_id is not None + url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) + if file.type == FileType.IMAGE: + yield ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=ToolInvokeMessage.TextMessage(text=url), + meta=meta.copy() if meta is not None else {}, + ) + else: + yield ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=ToolInvokeMessage.TextMessage(text=url), + meta=meta.copy() if meta is not None else {}, + ) + else: + yield message + else: + yield message + + @classmethod + def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str: + return f"/files/tools/{tool_file_id}{extension or '.bin'}" diff --git a/api/core/datasource/utils/model_invocation_utils.py b/api/core/datasource/utils/model_invocation_utils.py new file mode 100644 index 0000000000..3f59b3f472 --- /dev/null +++ b/api/core/datasource/utils/model_invocation_utils.py @@ -0,0 +1,169 @@ +""" +For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc. + +Therefore, a model manager is needed to list/invoke/validate models. 
+""" + +import json +from typing import Optional, cast + +from core.model_manager import ModelManager +from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.message_entities import PromptMessage +from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.utils.encoders import jsonable_encoder +from extensions.ext_database import db +from models.tools import ToolModelInvoke + + +class InvokeModelError(Exception): + pass + + +class ModelInvocationUtils: + @staticmethod + def get_max_llm_context_tokens( + tenant_id: str, + ) -> int: + """ + get max llm context tokens of the model + """ + model_manager = ModelManager() + model_instance = model_manager.get_default_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + ) + + if not model_instance: + raise InvokeModelError("Model not found") + + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) + schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + + if not schema: + raise InvokeModelError("No model schema found") + + max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) + if max_tokens is None: + return 2048 + + return max_tokens + + @staticmethod + def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int: + """ + calculate tokens from prompt messages and model parameters + """ + + # get model instance + model_manager = ModelManager() + model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM) + + if not model_instance: + raise InvokeModelError("Model not found") + + # get tokens + tokens = model_instance.get_llm_num_tokens(prompt_messages) + + return tokens + + @staticmethod + def invoke( + user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage] + ) -> LLMResult: + """ + invoke model with parameters in user's own context + + :param user_id: user id + :param tenant_id: tenant id, the tenant id of the creator of the tool + :param tool_type: tool type + :param tool_name: tool name + :param prompt_messages: prompt messages + :return: AssistantPromptMessage + """ + + # get model manager + model_manager = ModelManager() + # get model instance + model_instance = model_manager.get_default_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + ) + + # get prompt tokens + prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) + + model_parameters = { + "temperature": 0.8, + "top_p": 0.8, + } + + # create tool model invoke + tool_model_invoke = ToolModelInvoke( + user_id=user_id, + tenant_id=tenant_id, + provider=model_instance.provider, + tool_type=tool_type, + tool_name=tool_name, + model_parameters=json.dumps(model_parameters), + prompt_messages=json.dumps(jsonable_encoder(prompt_messages)), + model_response="", + prompt_tokens=prompt_tokens, + answer_tokens=0, + answer_unit_price=0, + answer_price_unit=0, + provider_response_latency=0, + total_price=0, + currency="USD", + ) + + db.session.add(tool_model_invoke) + db.session.commit() + + try: + response: LLMResult = cast( + LLMResult, + model_instance.invoke_llm( + 
prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=[], + stop=[], + stream=False, + user=user_id, + callbacks=[], + ), + ) + except InvokeRateLimitError as e: + raise InvokeModelError(f"Invoke rate limit error: {e}") + except InvokeBadRequestError as e: + raise InvokeModelError(f"Invoke bad request error: {e}") + except InvokeConnectionError as e: + raise InvokeModelError(f"Invoke connection error: {e}") + except InvokeAuthorizationError as e: + raise InvokeModelError("Invoke authorization error") + except InvokeServerUnavailableError as e: + raise InvokeModelError(f"Invoke server unavailable error: {e}") + except Exception as e: + raise InvokeModelError(f"Invoke error: {e}") + + # update tool model invoke + tool_model_invoke.model_response = response.message.content + if response.usage: + tool_model_invoke.answer_tokens = response.usage.completion_tokens + tool_model_invoke.answer_unit_price = response.usage.completion_unit_price + tool_model_invoke.answer_price_unit = response.usage.completion_price_unit + tool_model_invoke.provider_response_latency = response.usage.latency + tool_model_invoke.total_price = response.usage.total_price + tool_model_invoke.currency = response.usage.currency + + db.session.commit() + + return response diff --git a/api/core/datasource/utils/parser.py b/api/core/datasource/utils/parser.py new file mode 100644 index 0000000000..f72291783a --- /dev/null +++ b/api/core/datasource/utils/parser.py @@ -0,0 +1,389 @@ +import re +import uuid +from json import dumps as json_dumps +from json import loads as json_loads +from json.decoder import JSONDecodeError +from typing import Optional + +from flask import request +from requests import get +from yaml import YAMLError, safe_load # type: ignore + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter +from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError + + +class ApiBasedToolSchemaParser: + @staticmethod + def parse_openapi_to_tool_bundle( + openapi: dict, extra_info: dict | None = None, warning: dict | None = None + ) -> list[ApiToolBundle]: + warning = warning if warning is not None else {} + extra_info = extra_info if extra_info is not None else {} + + # set description to extra_info + extra_info["description"] = openapi["info"].get("description", "") + + if len(openapi["servers"]) == 0: + raise ToolProviderNotFoundError("No server found in the openapi yaml.") + + server_url = openapi["servers"][0]["url"] + request_env = request.headers.get("X-Request-Env") + if request_env: + matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env] + server_url = matched_servers[0] if matched_servers else server_url + + # list all interfaces + interfaces = [] + for path, path_item in openapi["paths"].items(): + methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"] + for method in methods: + if method in path_item: + interfaces.append( + { + "path": path, + "method": method, + "operation": path_item[method], + } + ) + + # get all parameters + bundles = [] + for interface in interfaces: + # convert parameters + parameters = [] + if "parameters" in interface["operation"]: + for parameter in interface["operation"]["parameters"]: + tool_parameter = ToolParameter( + name=parameter["name"], + label=I18nObject(en_US=parameter["name"], 
zh_Hans=parameter["name"]), + human_description=I18nObject( + en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") + ), + type=ToolParameter.ToolParameterType.STRING, + required=parameter.get("required", False), + form=ToolParameter.ToolParameterForm.LLM, + llm_description=parameter.get("description"), + default=parameter["schema"]["default"] + if "schema" in parameter and "default" in parameter["schema"] + else None, + placeholder=I18nObject( + en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") + ), + ) + + # check if there is a type + typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter) + if typ: + tool_parameter.type = typ + + parameters.append(tool_parameter) + # create tool bundle + # check if there is a request body + if "requestBody" in interface["operation"]: + request_body = interface["operation"]["requestBody"] + if "content" in request_body: + for content_type, content in request_body["content"].items(): + # if there is a reference, get the reference and overwrite the content + if "schema" not in content: + continue + + if "$ref" in content["schema"]: + # get the reference + root = openapi + reference = content["schema"]["$ref"].split("/")[1:] + for ref in reference: + root = root[ref] + # overwrite the content + interface["operation"]["requestBody"]["content"][content_type]["schema"] = root + + # parse body parameters + if "schema" in interface["operation"]["requestBody"]["content"][content_type]: + body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"] + required = body_schema.get("required", []) + properties = body_schema.get("properties", {}) + for name, property in properties.items(): + tool = ToolParameter( + name=name, + label=I18nObject(en_US=name, zh_Hans=name), + human_description=I18nObject( + en_US=property.get("description", ""), zh_Hans=property.get("description", "") + ), + type=ToolParameter.ToolParameterType.STRING, + required=name in required, + form=ToolParameter.ToolParameterForm.LLM, + llm_description=property.get("description", ""), + default=property.get("default", None), + placeholder=I18nObject( + en_US=property.get("description", ""), zh_Hans=property.get("description", "") + ), + ) + + # check if there is a type + typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property) + if typ: + tool.type = typ + + parameters.append(tool) + + # check if parameters is duplicated + parameters_count = {} + for parameter in parameters: + if parameter.name not in parameters_count: + parameters_count[parameter.name] = 0 + parameters_count[parameter.name] += 1 + for name, count in parameters_count.items(): + if count > 1: + warning["duplicated_parameter"] = f"Parameter {name} is duplicated." 
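+
+            # Illustrative example of the duplicate check above, not additional
+            # runtime behavior: an interface declaring a query parameter "q" and a
+            # request-body property also named "q" produces
+            # parameters_count == {"q": 2}, which records
+            # warning["duplicated_parameter"] = "Parameter q is duplicated."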
+
+            # check if there is an operation id; use $path_$method as the operation id if not
+            if "operationId" not in interface["operation"]:
+                # remove special characters like / to ensure the operation id is valid: ^[a-zA-Z0-9_-]{1,64}$
+                path = interface["path"]
+                if interface["path"].startswith("/"):
+                    path = interface["path"][1:]
+                path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
+                if not path:
+                    path = str(uuid.uuid4())
+
+                interface["operation"]["operationId"] = f"{path}_{interface['method']}"
+
+            bundles.append(
+                ApiToolBundle(
+                    server_url=server_url + interface["path"],
+                    method=interface["method"],
+                    summary=interface["operation"]["description"]
+                    if "description" in interface["operation"]
+                    else interface["operation"].get("summary", None),
+                    operation_id=interface["operation"]["operationId"],
+                    parameters=parameters,
+                    author="",
+                    icon=None,
+                    openapi=interface["operation"],
+                )
+            )
+
+        return bundles
+
+    @staticmethod
+    def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
+        parameter = parameter or {}
+        typ: Optional[str] = None
+        if parameter.get("format") == "binary":
+            return ToolParameter.ToolParameterType.FILE
+
+        if "type" in parameter:
+            typ = parameter["type"]
+        elif "schema" in parameter and "type" in parameter["schema"]:
+            typ = parameter["schema"]["type"]
+
+        if typ in {"integer", "number"}:
+            return ToolParameter.ToolParameterType.NUMBER
+        elif typ == "boolean":
+            return ToolParameter.ToolParameterType.BOOLEAN
+        elif typ == "string":
+            return ToolParameter.ToolParameterType.STRING
+        elif typ == "array":
+            items = parameter.get("items") or parameter.get("schema", {}).get("items")
+            return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
+        else:
+            return None
+
+    @staticmethod
+    def parse_openapi_yaml_to_tool_bundle(
+        yaml: str, extra_info: dict | None = None, warning: dict | None = None
+    ) -> list[ApiToolBundle]:
+        """
+        parse openapi yaml to tool bundle
+
+        :param yaml: the yaml string
+        :param extra_info: the extra info
+        :param warning: the warning message
+        :return: the tool bundle
+        """
+        warning = warning if warning is not None else {}
+        extra_info = extra_info if extra_info is not None else {}
+
+        openapi: dict = safe_load(yaml)
+        if openapi is None:
+            raise ToolApiSchemaError("Invalid openapi yaml.")
+        return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
+
+    @staticmethod
+    def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
+        """
+        parse swagger to openapi
+
+        :param swagger: the swagger dict
+        :param extra_info: the extra info
+        :param warning: the warning message
+        :return: the openapi dict
+        """
+        warning = warning or {}
+        # convert swagger to openapi
+        info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
+
+        servers = swagger.get("servers", [])
+
+        if len(servers) == 0:
+            raise ToolApiSchemaError("No server found in the swagger yaml.")
+
+        openapi = {
+            "openapi": "3.0.0",
+            "info": {
+                "title": info.get("title", "Swagger"),
+                "description": info.get("description", "Swagger"),
+                "version": info.get("version", "1.0.0"),
+            },
+            "servers": servers,
+            "paths": {},
+            "components": {"schemas": {}},
+        }
+
+        # check paths
+        if "paths" not in swagger or len(swagger["paths"]) == 0:
+            raise ToolApiSchemaError("No paths found in the swagger yaml.")
+
+        # convert paths
+        for path, path_item in
swagger["paths"].items(): + openapi["paths"][path] = {} + for method, operation in path_item.items(): + if "operationId" not in operation: + raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.") + + if ("summary" not in operation or len(operation["summary"]) == 0) and ( + "description" not in operation or len(operation["description"]) == 0 + ): + if warning is not None: + warning["missing_summary"] = f"No summary or description found in operation {method} {path}." + + openapi["paths"][path][method] = { + "operationId": operation["operationId"], + "summary": operation.get("summary", ""), + "description": operation.get("description", ""), + "parameters": operation.get("parameters", []), + "responses": operation.get("responses", {}), + } + + if "requestBody" in operation: + openapi["paths"][path][method]["requestBody"] = operation["requestBody"] + + # convert definitions + for name, definition in swagger["definitions"].items(): + openapi["components"]["schemas"][name] = definition + + return openapi + + @staticmethod + def parse_openai_plugin_json_to_tool_bundle( + json: str, extra_info: dict | None = None, warning: dict | None = None + ) -> list[ApiToolBundle]: + """ + parse openapi plugin yaml to tool bundle + + :param json: the json string + :param extra_info: the extra info + :param warning: the warning message + :return: the tool bundle + """ + warning = warning if warning is not None else {} + extra_info = extra_info if extra_info is not None else {} + + try: + openai_plugin = json_loads(json) + api = openai_plugin["api"] + api_url = api["url"] + api_type = api["type"] + except JSONDecodeError: + raise ToolProviderNotFoundError("Invalid openai plugin json.") + + if api_type != "openapi": + raise ToolNotSupportedError("Only openapi is supported now.") + + # get openapi yaml + response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5) + + if response.status_code != 200: + raise ToolProviderNotFoundError("cannot get openapi yaml from url.") + + return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle( + response.text, extra_info=extra_info, warning=warning + ) + + @staticmethod + def auto_parse_to_tool_bundle( + content: str, extra_info: dict | None = None, warning: dict | None = None + ) -> tuple[list[ApiToolBundle], str]: + """ + auto parse to tool bundle + + :param content: the content + :param extra_info: the extra info + :param warning: the warning message + :return: tools bundle, schema_type + """ + warning = warning if warning is not None else {} + extra_info = extra_info if extra_info is not None else {} + + content = content.strip() + loaded_content = None + json_error = None + yaml_error = None + + try: + loaded_content = json_loads(content) + except JSONDecodeError as e: + json_error = e + + if loaded_content is None: + try: + loaded_content = safe_load(content) + except YAMLError as e: + yaml_error = e + if loaded_content is None: + raise ToolApiSchemaError( + f"Invalid api schema, schema is neither json nor yaml. 
json error: {str(json_error)}," + f" yaml error: {str(yaml_error)}" + ) + + swagger_error = None + openapi_error = None + openapi_plugin_error = None + schema_type = None + + try: + openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( + loaded_content, extra_info=extra_info, warning=warning + ) + schema_type = ApiProviderSchemaType.OPENAPI.value + return openapi, schema_type + except ToolApiSchemaError as e: + openapi_error = e + + # openai parse error, fallback to swagger + try: + converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi( + loaded_content, extra_info=extra_info, warning=warning + ) + schema_type = ApiProviderSchemaType.SWAGGER.value + return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( + converted_swagger, extra_info=extra_info, warning=warning + ), schema_type + except ToolApiSchemaError as e: + swagger_error = e + + # swagger parse error, fallback to openai plugin + try: + openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( + json_dumps(loaded_content), extra_info=extra_info, warning=warning + ) + return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value + except ToolNotSupportedError as e: + # maybe it's not plugin at all + openapi_plugin_error = e + + raise ToolApiSchemaError( + f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}," + f" openapi plugin error: {str(openapi_plugin_error)}" + ) diff --git a/api/core/datasource/utils/rag_web_reader.py b/api/core/datasource/utils/rag_web_reader.py new file mode 100644 index 0000000000..22c47fa814 --- /dev/null +++ b/api/core/datasource/utils/rag_web_reader.py @@ -0,0 +1,17 @@ +import re + + +def get_image_upload_file_ids(content): + pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)" + matches = re.findall(pattern, content) + image_upload_file_ids = [] + for match in matches: + if match[1] == "file-preview": + content_pattern = r"files/([^/]+)/file-preview" + else: + content_pattern = r"files/([^/]+)/image-preview" + content_match = re.search(content_pattern, match[0]) + if content_match: + image_upload_file_id = content_match.group(1) + image_upload_file_ids.append(image_upload_file_id) + return image_upload_file_ids diff --git a/api/core/datasource/utils/text_processing_utils.py b/api/core/datasource/utils/text_processing_utils.py new file mode 100644 index 0000000000..105823f896 --- /dev/null +++ b/api/core/datasource/utils/text_processing_utils.py @@ -0,0 +1,17 @@ +import re + + +def remove_leading_symbols(text: str) -> str: + """ + Remove leading punctuation or symbols from the given text. + + Args: + text (str): The input text to process. + + Returns: + str: The text with leading punctuation or symbols removed. 
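+
+    Example (illustrative):
+        remove_leading_symbols("...hello") -> "hello"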
+ """ + # Match Unicode ranges for punctuation and symbols + # FIXME this pattern is confused quick fix for #11868 maybe refactor it later + pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+" + return re.sub(pattern, "", text) diff --git a/api/core/datasource/utils/uuid_utils.py b/api/core/datasource/utils/uuid_utils.py new file mode 100644 index 0000000000..3046c08c89 --- /dev/null +++ b/api/core/datasource/utils/uuid_utils.py @@ -0,0 +1,9 @@ +import uuid + + +def is_valid_uuid(uuid_str: str) -> bool: + try: + uuid.UUID(uuid_str) + return True + except Exception: + return False diff --git a/api/core/datasource/utils/web_reader_tool.py b/api/core/datasource/utils/web_reader_tool.py new file mode 100644 index 0000000000..d42fd99fce --- /dev/null +++ b/api/core/datasource/utils/web_reader_tool.py @@ -0,0 +1,375 @@ +import hashlib +import json +import mimetypes +import os +import re +import site +import subprocess +import tempfile +import unicodedata +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Literal, Optional, cast +from urllib.parse import unquote + +import chardet +import cloudscraper # type: ignore +from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore +from regex import regex # type: ignore + +from core.helper import ssrf_proxy +from core.rag.extractor import extract_processor +from core.rag.extractor.extract_processor import ExtractProcessor + +FULL_TEMPLATE = """ +TITLE: {title} +AUTHORS: {authors} +PUBLISH DATE: {publish_date} +TOP_IMAGE_URL: {top_image} +TEXT: + +{text} +""" + + +def page_result(text: str, cursor: int, max_length: int) -> str: + """Page through `text` and return a substring of `max_length` characters starting from `cursor`.""" + return text[cursor : cursor + max_length] + + +def get_url(url: str, user_agent: Optional[str] = None) -> str: + """Fetch URL and return the contents as a string.""" + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" + " Chrome/91.0.4472.124 Safari/537.36" + } + if user_agent: + headers["User-Agent"] = user_agent + + main_content_type = None + supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] + response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10)) + + if response.status_code == 200: + # check content-type + content_type = response.headers.get("Content-Type") + if content_type: + main_content_type = response.headers.get("Content-Type").split(";")[0].strip() + else: + content_disposition = response.headers.get("Content-Disposition", "") + filename_match = re.search(r'filename="([^"]+)"', content_disposition) + if filename_match: + filename = unquote(filename_match.group(1)) + extension = re.search(r"\.(\w+)$", filename) + if extension: + main_content_type = mimetypes.guess_type(filename)[0] + + if main_content_type not in supported_content_types: + return "Unsupported content-type [{}] of URL.".format(main_content_type) + + if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: + return cast(str, ExtractProcessor.load_from_url(url, return_text=True)) + + response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) + elif response.status_code == 403: + scraper = cloudscraper.create_scraper() + scraper.perform_request = ssrf_proxy.make_request + response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) + + if response.status_code != 200: + 
return "URL returned status code {}.".format(response.status_code) + + # Detect encoding using chardet + detected_encoding = chardet.detect(response.content) + encoding = detected_encoding["encoding"] + if encoding: + try: + content = response.content.decode(encoding) + except (UnicodeDecodeError, TypeError): + content = response.text + else: + content = response.text + + a = extract_using_readabilipy(content) + + if not a["plain_text"] or not a["plain_text"].strip(): + return "" + + res = FULL_TEMPLATE.format( + title=a["title"], + authors=a["byline"], + publish_date=a["date"], + top_image="", + text=a["plain_text"] or "", + ) + + return res + + +def extract_using_readabilipy(html): + with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html: + f_html.write(html) + f_html.close() + html_path = f_html.name + + # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file + article_json_path = html_path + ".json" + jsdir = os.path.join(find_module_path("readabilipy"), "javascript") + with chdir(jsdir): + subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path]) + + # Read output of call to Readability.parse() from JSON file and return as Python dictionary + input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8")) + + # Deleting files after processing + os.unlink(article_json_path) + os.unlink(html_path) + + article_json: dict[str, Any] = { + "title": None, + "byline": None, + "date": None, + "content": None, + "plain_content": None, + "plain_text": None, + } + # Populate article fields from readability fields where present + if input_json: + if input_json.get("title"): + article_json["title"] = input_json["title"] + if input_json.get("byline"): + article_json["byline"] = input_json["byline"] + if input_json.get("date"): + article_json["date"] = input_json["date"] + if input_json.get("content"): + article_json["content"] = input_json["content"] + article_json["plain_content"] = plain_content(article_json["content"], False, False) + article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"]) + if input_json.get("textContent"): + article_json["plain_text"] = input_json["textContent"] + article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"]) + + return article_json + + +def find_module_path(module_name): + for package_path in site.getsitepackages(): + potential_path = os.path.join(package_path, module_name) + if os.path.exists(potential_path): + return potential_path + + return None + + +@contextmanager +def chdir(path): + """Change directory in context and return to original on exit""" + # From https://stackoverflow.com/a/37996581, couldn't find a built-in + original_path = os.getcwd() + os.chdir(path) + try: + yield + finally: + os.chdir(original_path) + + +def extract_text_blocks_as_plain_text(paragraph_html): + # Load article as DOM + soup = BeautifulSoup(paragraph_html, "html.parser") + # Select all lists + list_elements = soup.find_all(["ul", "ol"]) + # Prefix text in all list items with "* " and make lists paragraphs + for list_element in list_elements: + plain_items = "".join( + list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")])) + ) + list_element.string = plain_items + list_element.name = "p" + # Select all text blocks + text_blocks = [s.parent for s in soup.find_all(string=True)] + text_blocks = [plain_text_leaf_node(block) for block in text_blocks] + # Drop empty paragraphs 
+ text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks)) + return text_blocks + + +def plain_text_leaf_node(element): + # Extract all text, stripped of any child HTML elements and normalize it + plain_text = normalize_text(element.get_text()) + if plain_text != "" and element.name == "li": + plain_text = "* {}, ".format(plain_text) + if plain_text == "": + plain_text = None + if "data-node-index" in element.attrs: + plain = {"node_index": element["data-node-index"], "text": plain_text} + else: + plain = {"text": plain_text} + return plain + + +def plain_content(readability_content, content_digests, node_indexes): + # Load article as DOM + soup = BeautifulSoup(readability_content, "html.parser") + # Make all elements plain + elements = plain_elements(soup.contents, content_digests, node_indexes) + if node_indexes: + # Add node index attributes to nodes + elements = [add_node_indexes(element) for element in elements] + # Replace article contents with plain elements + soup.contents = elements + return str(soup) + + +def plain_elements(elements, content_digests, node_indexes): + # Get plain content versions of all elements + elements = [plain_element(element, content_digests, node_indexes) for element in elements] + if content_digests: + # Add content digest attribute to nodes + elements = [add_content_digest(element) for element in elements] + return elements + + +def plain_element(element, content_digests, node_indexes): + # For lists, we make each item plain text + if is_leaf(element): + # For leaf node elements, extract the text content, discarding any HTML tags + # 1. Get element contents as text + plain_text = element.get_text() + # 2. Normalize the extracted text string to a canonical representation + plain_text = normalize_text(plain_text) + # 3. Update element content to be plain text + element.string = plain_text + elif is_text(element): + if is_non_printing(element): + # The simplified HTML may have come from Readability.js so might + # have non-printing text (e.g. Comment or CData). In this case, we + # keep the structure, but ensure that the string is empty. 
+ element = type(element)("") + else: + plain_text = element.string + plain_text = normalize_text(plain_text) + element = type(element)(plain_text) + else: + # If not a leaf node or leaf type call recursively on child nodes, replacing + element.contents = plain_elements(element.contents, content_digests, node_indexes) + return element + + +def add_node_indexes(element, node_index="0"): + # Can't add attributes to string types + if is_text(element): + return element + # Add index to current element + element["data-node-index"] = node_index + # Add index to child elements + for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1): + # Can't add attributes to leaf string types + child_index = "{stem}.{local}".format(stem=node_index, local=local_idx) + add_node_indexes(child, node_index=child_index) + return element + + +def normalize_text(text): + """Normalize unicode and whitespace.""" + # Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them + text = strip_control_characters(text) + text = normalize_unicode(text) + text = normalize_whitespace(text) + return text + + +def strip_control_characters(text): + """Strip out unicode control characters which might break the parsing.""" + # Unicode control characters + # [Cc]: Other, Control [includes new lines] + # [Cf]: Other, Format + # [Cn]: Other, Not Assigned + # [Co]: Other, Private Use + # [Cs]: Other, Surrogate + control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"} + retained_chars = ["\t", "\n", "\r", "\f"] + + # Remove non-printing control characters + return "".join( + [ + "" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char + for char in text + ] + ) + + +def normalize_unicode(text): + """Normalize unicode such that things that are visually equivalent map to the same unicode string where possible.""" + normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC" + text = unicodedata.normalize(normal_form, text) + return text + + +def normalize_whitespace(text): + """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed.""" + text = regex.sub(r"\s+", " ", text) + # Remove leading and trailing whitespace + text = text.strip() + return text + + +def is_leaf(element): + return element.name in {"p", "li"} + + +def is_text(element): + return isinstance(element, NavigableString) + + +def is_non_printing(element): + return any(isinstance(element, _e) for _e in [Comment, CData]) + + +def add_content_digest(element): + if not is_text(element): + element["data-content-digest"] = content_digest(element) + return element + + +def content_digest(element): + digest: Any + if is_text(element): + # Hash + trimmed_string = element.string.strip() + if trimmed_string == "": + digest = "" + else: + digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest() + else: + contents = element.contents + num_contents = len(contents) + if num_contents == 0: + # No hash when no child elements exist + digest = "" + elif num_contents == 1: + # If single child, use digest of child + digest = content_digest(contents[0]) + else: + # Build content digest from the "non-empty" digests of child nodes + digest = hashlib.sha256() + child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents])) + for child in child_digests: + digest.update(child.encode("utf-8")) + digest = digest.hexdigest() + return digest + + +def get_image_upload_file_ids(content): + 
pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)" + matches = re.findall(pattern, content) + image_upload_file_ids = [] + for match in matches: + if match[1] == "file-preview": + content_pattern = r"files/([^/]+)/file-preview" + else: + content_pattern = r"files/([^/]+)/image-preview" + content_match = re.search(content_pattern, match[0]) + if content_match: + image_upload_file_id = content_match.group(1) + image_upload_file_ids.append(image_upload_file_id) + return image_upload_file_ids diff --git a/api/core/datasource/utils/workflow_configuration_sync.py b/api/core/datasource/utils/workflow_configuration_sync.py new file mode 100644 index 0000000000..d16d6fc576 --- /dev/null +++ b/api/core/datasource/utils/workflow_configuration_sync.py @@ -0,0 +1,43 @@ +from collections.abc import Mapping, Sequence +from typing import Any + +from core.app.app_config.entities import VariableEntity +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration + + +class WorkflowToolConfigurationUtils: + @classmethod + def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]): + for configuration in configurations: + WorkflowToolParameterConfiguration.model_validate(configuration) + + @classmethod + def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]: + """ + get workflow graph variables + """ + nodes = graph.get("nodes", []) + start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None) + + if not start_node: + return [] + + return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])] + + @classmethod + def check_is_synced( + cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration] + ): + """ + check is synced + + raise ValueError if not synced + """ + variable_names = [variable.variable for variable in variables] + + if len(tool_configurations) != len(variables): + raise ValueError("parameter configuration mismatch, please republish the tool to update") + + for parameter in tool_configurations: + if parameter.name not in variable_names: + raise ValueError("parameter configuration mismatch, please republish the tool to update") diff --git a/api/core/datasource/utils/yaml_utils.py b/api/core/datasource/utils/yaml_utils.py new file mode 100644 index 0000000000..ee7ca11e05 --- /dev/null +++ b/api/core/datasource/utils/yaml_utils.py @@ -0,0 +1,35 @@ +import logging +from pathlib import Path +from typing import Any + +import yaml # type: ignore +from yaml import YAMLError + +logger = logging.getLogger(__name__) + + +def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any: + """ + Safe loading a YAML file + :param file_path: the path of the YAML file + :param ignore_error: + if True, return default_value if error occurs and the error will be logged in debug level + if False, raise error if error occurs + :param default_value: the value returned when errors ignored + :return: an object of the YAML content + """ + if not file_path or not Path(file_path).exists(): + if ignore_error: + return default_value + else: + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, encoding="utf-8") as yaml_file: + try: + yaml_content = yaml.safe_load(yaml_file) + return yaml_content or default_value + except Exception as e: + if ignore_error: + return default_value + else: + raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from 
diff --git a/api/core/workflow/constants.py b/api/core/workflow/constants.py
index e3fe17c284..e5deafc32f 100644
--- a/api/core/workflow/constants.py
+++ b/api/core/workflow/constants.py
@@ -1,3 +1,4 @@
 SYSTEM_VARIABLE_NODE_ID = "sys"
 ENVIRONMENT_VARIABLE_NODE_ID = "env"
 CONVERSATION_VARIABLE_NODE_ID = "conversation"
+PIPELINE_VARIABLE_NODE_ID = "pipeline"
\ No newline at end of file
diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py
index 82fd6cdc30..ecd2cfeabc 100644
--- a/api/core/workflow/entities/node_entities.py
+++ b/api/core/workflow/entities/node_entities.py
@@ -17,6 +17,7 @@ class NodeRunMetadataKey(StrEnum):
     TOTAL_PRICE = "total_price"
     CURRENCY = "currency"
     TOOL_INFO = "tool_info"
+    DATASOURCE_INFO = "datasource_info"
     AGENT_LOG = "agent_log"
     ITERATION_ID = "iteration_id"
     ITERATION_INDEX = "iteration_index"
diff --git a/api/core/workflow/nodes/datasource/__init__.py b/api/core/workflow/nodes/datasource/__init__.py
new file mode 100644
index 0000000000..cee9e5a895
--- /dev/null
+++ b/api/core/workflow/nodes/datasource/__init__.py
@@ -0,0 +1,3 @@
+from .datasource_node import DatasourceNode
+
+__all__ = ["DatasourceNode"]
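Review note: the new `PIPELINE_VARIABLE_NODE_ID` follows the same selector convention as the existing `sys`/`env`/`conversation` namespaces. A sketch of the selector shape it implies; the variable name is illustrative.

    # Illustrative only: a selector is a [node_id, variable_name] pair, so a
    # pipeline variable named "chunk_size" would be addressed as:
    from core.workflow.constants import PIPELINE_VARIABLE_NODE_ID

    selector = [PIPELINE_VARIABLE_NODE_ID, "chunk_size"]  # ["pipeline", "chunk_size"]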
diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py
new file mode 100644
index 0000000000..1752ba36fa
--- /dev/null
+++ b/api/core/workflow/nodes/datasource/datasource_node.py
@@ -0,0 +1,406 @@
+from collections.abc import Generator, Mapping, Sequence
+from typing import Any, cast
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
+from core.file import File, FileTransferMethod
+from core.plugin.manager.exc import PluginDaemonClientSideError
+from core.plugin.manager.plugin import PluginInstallationManager
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
+from core.tools.errors import ToolInvokeError
+from core.tools.tool_engine import ToolEngine
+from core.tools.utils.message_transformer import ToolFileMessageTransformer
+from core.variables.segments import ArrayAnySegment
+from core.variables.variables import ArrayAnyVariable
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.event import AgentLogEvent
+from core.workflow.nodes.base import BaseNode
+from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
+from core.workflow.utils.variable_template_parser import VariableTemplateParser
+from extensions.ext_database import db
+from factories import file_factory
+from models import ToolFile
+from models.workflow import WorkflowNodeExecutionStatus
+from services.tools.builtin_tools_manage_service import BuiltinToolManageService
+
+from .entities import DatasourceNodeData
+from .exc import (
+    ToolFileError,
+    ToolNodeError,
+    ToolParameterError,
+)
+
+
+class DatasourceNode(BaseNode[DatasourceNodeData]):
+    """
+    Datasource Node
+    """
+
+    _node_data_cls = DatasourceNodeData
+    _node_type = NodeType.DATASOURCE
+
+    def _run(self) -> Generator:
+        """
+        Run the datasource node
+        """
+
+        node_data = cast(DatasourceNodeData, self.node_data)
+
+        # assemble datasource info for run metadata and error reporting
+        datasource_info = {
+            "provider_type": node_data.provider_type.value,
+            "provider_id": node_data.provider_id,
+            "plugin_unique_identifier": node_data.plugin_unique_identifier,
+        }
+
+        # get datasource runtime
+        try:
+            from core.tools.tool_manager import ToolManager
+
+            tool_runtime = ToolManager.get_workflow_tool_runtime(
+                self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
+            )
+        except ToolNodeError as e:
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    inputs={},
+                    metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                    error=f"Failed to get datasource runtime: {str(e)}",
+                    error_type=type(e).__name__,
+                )
+            )
+            return
+
+        # get parameters
+        tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
+        parameters = self._generate_parameters(
+            tool_parameters=tool_parameters,
+            variable_pool=self.graph_runtime_state.variable_pool,
+            node_data=self.node_data,
+        )
+        parameters_for_log = self._generate_parameters(
+            tool_parameters=tool_parameters,
+            variable_pool=self.graph_runtime_state.variable_pool,
+            node_data=self.node_data,
+            for_log=True,
+        )
+
+        # get conversation id
+        conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
+
+        try:
+            message_stream = ToolEngine.generic_invoke(
+                tool=tool_runtime,
+                tool_parameters=parameters,
+                user_id=self.user_id,
+                workflow_tool_callback=DifyWorkflowCallbackHandler(),
+                workflow_call_depth=self.workflow_call_depth,
+                thread_pool_id=self.thread_pool_id,
+                app_id=self.app_id,
+                conversation_id=conversation_id.text if conversation_id else None,
+            )
+        except ToolNodeError as e:
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    inputs=parameters_for_log,
+                    metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                    error=f"Failed to invoke datasource: {str(e)}",
+                    error_type=type(e).__name__,
+                )
+            )
+            return
+
+        try:
+            # convert tool messages
+            yield from self._transform_message(message_stream, datasource_info, parameters_for_log)
+        except (PluginDaemonClientSideError, ToolInvokeError) as e:
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    inputs=parameters_for_log,
+                    metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                    error=f"Failed to transform datasource message: {str(e)}",
+                    error_type=type(e).__name__,
+                )
+            )
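Review note: parameter resolution in this node distinguishes three input types, resolved by `_generate_parameters` below. A sketch of what a node's `datasource_parameters` mapping might look like in the graph JSON; the parameter names and values are illustrative.

    # Illustrative node configuration; keys and values are examples only.
    datasource_parameters = {
        # "variable": the value is a selector into the variable pool
        "url": {"type": "variable", "value": ["start", "page_url"]},
        # "mixed": a template string that may interpolate {{#node.var#}} selectors
        "query": {"type": "mixed", "value": "site:{{#start.domain#}} release notes"},
        # "constant": a literal scalar passed through as-is
        "limit": {"type": "constant", "value": 10},
    }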
+ + """ + tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} + + result: dict[str, Any] = {} + for parameter_name in node_data.tool_parameters: + parameter = tool_parameters_dictionary.get(parameter_name) + if not parameter: + result[parameter_name] = None + continue + tool_input = node_data.tool_parameters[parameter_name] + if tool_input.type == "variable": + variable = variable_pool.get(tool_input.value) + if variable is None: + raise ToolParameterError(f"Variable {tool_input.value} does not exist") + parameter_value = variable.value + elif tool_input.type in {"mixed", "constant"}: + segment_group = variable_pool.convert_template(str(tool_input.value)) + parameter_value = segment_group.log if for_log else segment_group.text + else: + raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'") + result[parameter_name] = parameter_value + + return result + + def _fetch_files(self, variable_pool: VariablePool) -> list[File]: + variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) + assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) + return list(variable.value) if variable else [] + + def _transform_message( + self, + messages: Generator[ToolInvokeMessage, None, None], + tool_info: Mapping[str, Any], + parameters_for_log: dict[str, Any], + ) -> Generator: + """ + Convert ToolInvokeMessages into tuple[plain_text, files] + """ + # transform message and handle file storage + message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=messages, + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=None, + ) + + text = "" + files: list[File] = [] + json: list[dict] = [] + + agent_logs: list[AgentLogEvent] = [] + agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {} + + variables: dict[str, Any] = {} + + for message in message_stream: + if message.type in { + ToolInvokeMessage.MessageType.IMAGE_LINK, + ToolInvokeMessage.MessageType.BINARY_LINK, + ToolInvokeMessage.MessageType.IMAGE, + }: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + + url = message.message.text + if message.meta: + transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + else: + transfer_method = FileTransferMethod.TOOL_FILE + + tool_file_id = str(url).split("/")[-1].split(".")[0] + + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileError(f"Tool file {tool_file_id} does not exist") + + mapping = { + "tool_file_id": tool_file_id, + "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + files.append(file) + elif message.type == ToolInvokeMessage.MessageType.BLOB: + # get tool file id + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert message.meta + + tool_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileError(f"tool file {tool_file_id} not exists") + + mapping = { + "tool_file_id": tool_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + ) + 
+    def _transform_message(
+        self,
+        messages: Generator[ToolInvokeMessage, None, None],
+        datasource_info: Mapping[str, Any],
+        parameters_for_log: dict[str, Any],
+    ) -> Generator:
+        """
+        Convert ToolInvokeMessages into tuple[plain_text, files]
+        """
+        # transform message and handle file storage
+        message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
+            messages=messages,
+            user_id=self.user_id,
+            tenant_id=self.tenant_id,
+            conversation_id=None,
+        )
+
+        text = ""
+        files: list[File] = []
+        json: list[dict] = []
+
+        agent_logs: list[AgentLogEvent] = []
+        agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {}
+
+        variables: dict[str, Any] = {}
+
+        for message in message_stream:
+            if message.type in {
+                ToolInvokeMessage.MessageType.IMAGE_LINK,
+                ToolInvokeMessage.MessageType.BINARY_LINK,
+                ToolInvokeMessage.MessageType.IMAGE,
+            }:
+                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+
+                url = message.message.text
+                if message.meta:
+                    transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
+                else:
+                    transfer_method = FileTransferMethod.TOOL_FILE
+
+                tool_file_id = str(url).split("/")[-1].split(".")[0]
+
+                with Session(db.engine) as session:
+                    stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
+                    tool_file = session.scalar(stmt)
+                    if tool_file is None:
+                        raise ToolFileError(f"Tool file {tool_file_id} does not exist")
+
+                mapping = {
+                    "tool_file_id": tool_file_id,
+                    "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
+                    "transfer_method": transfer_method,
+                    "url": url,
+                }
+                file = file_factory.build_from_mapping(
+                    mapping=mapping,
+                    tenant_id=self.tenant_id,
+                )
+                files.append(file)
+            elif message.type == ToolInvokeMessage.MessageType.BLOB:
+                # get tool file id
+                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+                assert message.meta
+
+                tool_file_id = message.message.text.split("/")[-1].split(".")[0]
+                with Session(db.engine) as session:
+                    stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
+                    tool_file = session.scalar(stmt)
+                    if tool_file is None:
+                        raise ToolFileError(f"Tool file {tool_file_id} does not exist")
+
+                mapping = {
+                    "tool_file_id": tool_file_id,
+                    "transfer_method": FileTransferMethod.TOOL_FILE,
+                }
+
+                files.append(
+                    file_factory.build_from_mapping(
+                        mapping=mapping,
+                        tenant_id=self.tenant_id,
+                    )
+                )
+            elif message.type == ToolInvokeMessage.MessageType.TEXT:
+                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+                text += message.message.text
+                yield RunStreamChunkEvent(
+                    chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
+                )
+            elif message.type == ToolInvokeMessage.MessageType.JSON:
+                assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
+                if self.node_type == NodeType.AGENT:
+                    msg_metadata = message.message.json_object.pop("execution_metadata", {})
+                    agent_execution_metadata = {
+                        key: value
+                        for key, value in msg_metadata.items()
+                        if key in NodeRunMetadataKey.__members__.values()
+                    }
+                json.append(message.message.json_object)
+            elif message.type == ToolInvokeMessage.MessageType.LINK:
+                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+                stream_text = f"Link: {message.message.text}\n"
+                text += stream_text
+                yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
+            elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
+                assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
+                variable_name = message.message.variable_name
+                variable_value = message.message.variable_value
+                if message.message.stream:
+                    if not isinstance(variable_value, str):
+                        raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
+                    if variable_name not in variables:
+                        variables[variable_name] = ""
+                    variables[variable_name] += variable_value
+
+                    yield RunStreamChunkEvent(
+                        chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
+                    )
+                else:
+                    variables[variable_name] = variable_value
+            elif message.type == ToolInvokeMessage.MessageType.FILE:
+                assert message.meta is not None
+                files.append(message.meta["file"])
+            elif message.type == ToolInvokeMessage.MessageType.LOG:
+                assert isinstance(message.message, ToolInvokeMessage.LogMessage)
+                if message.message.metadata:
+                    icon = datasource_info.get("icon", "")
+                    dict_metadata = dict(message.message.metadata)
+                    if dict_metadata.get("provider"):
+                        manager = PluginInstallationManager()
+                        plugins = manager.list_plugins(self.tenant_id)
+                        try:
+                            current_plugin = next(
+                                plugin
+                                for plugin in plugins
+                                if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
+                            )
+                            icon = current_plugin.declaration.icon
+                        except StopIteration:
+                            pass
+                        try:
+                            builtin_tool = next(
+                                provider
+                                for provider in BuiltinToolManageService.list_builtin_tools(
+                                    self.user_id,
+                                    self.tenant_id,
+                                )
+                                if provider.name == dict_metadata["provider"]
+                            )
+                            icon = builtin_tool.icon
+                        except StopIteration:
+                            pass
+
+                        dict_metadata["icon"] = icon
+                        message.message.metadata = dict_metadata
+                agent_log = AgentLogEvent(
+                    id=message.message.id,
+                    node_execution_id=self.id,
+                    parent_id=message.message.parent_id,
+                    error=message.message.error,
+                    status=message.message.status.value,
+                    data=message.message.data,
+                    label=message.message.label,
+                    metadata=message.message.metadata,
+                    node_id=self.node_id,
+                )
+
+                # check if the agent log is already in the list
+                for log in agent_logs:
+                    if log.id == agent_log.id:
+                        # update the existing log entry in place
+                        log.data = agent_log.data
+                        log.status = agent_log.status
+                        log.error = agent_log.error
+                        log.label = agent_log.label
+                        log.metadata = agent_log.metadata
+                        break
+                else:
+                    agent_logs.append(agent_log)
+
+                yield agent_log
+
+        yield RunCompletedEvent(
+            run_result=NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                outputs={"text": text, "files": files, "json": json, **variables},
+                metadata={
+                    **agent_execution_metadata,
+                    NodeRunMetadataKey.DATASOURCE_INFO: datasource_info,
+                    NodeRunMetadataKey.AGENT_LOG: agent_logs,
+                },
+                inputs=parameters_for_log,
+            )
+        )
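Review note: a sketch of the outputs a successful run yields, as implied by the RunCompletedEvent constructed above; all values are illustrative.

    # Shape of a successful datasource node result (values are examples).
    outputs = {
        "text": "Link: https://example.com/doc\n",  # concatenated TEXT/LINK chunks
        "files": [],                                # File objects built from file messages
        "json": [{"status": "ok"}],                 # accumulated JSON messages
        # plus one key per VARIABLE message, merged last, which means a variable
        # named "text", "files", or "json" would shadow the built-in outputs
    }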
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        *,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: DatasourceNodeData,
+    ) -> Mapping[str, Sequence[str]]:
+        """
+        Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
+        :param node_data: node data
+        :return:
+        """
+        result = {}
+        for parameter_name in node_data.datasource_parameters:
+            input = node_data.datasource_parameters[parameter_name]
+            if input.type == "mixed":
+                assert isinstance(input.value, str)
+                selectors = VariableTemplateParser(input.value).extract_variable_selectors()
+                for selector in selectors:
+                    result[selector.variable] = selector.value_selector
+            elif input.type == "variable":
+                result[parameter_name] = input.value
+            elif input.type == "constant":
+                pass
+
+        result = {node_id + "." + key: value for key, value in result.items()}
+
+        return result
diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py
new file mode 100644
index 0000000000..66e8adc431
--- /dev/null
+++ b/api/core/workflow/nodes/datasource/entities.py
@@ -0,0 +1,56 @@
+from typing import Any, Literal, Union
+
+from pydantic import BaseModel, field_validator
+from pydantic_core.core_schema import ValidationInfo
+
+from core.tools.entities.tool_entities import ToolProviderType
+from core.workflow.nodes.base.entities import BaseNodeData
+
+
+class DatasourceEntity(BaseModel):
+    provider_id: str
+    provider_type: ToolProviderType
+    provider_name: str  # redundancy
+    tool_name: str
+    tool_label: str  # redundancy
+    tool_configurations: dict[str, Any]
+    plugin_unique_identifier: str | None = None  # redundancy
+
+    @field_validator("tool_configurations", mode="before")
+    @classmethod
+    def validate_tool_configurations(cls, value, values: ValidationInfo):
+        if not isinstance(value, dict):
+            raise ValueError("tool_configurations must be a dictionary")
+
+        for key, item in value.items():
+            if not isinstance(item, str | int | float | bool):
+                raise ValueError(f"{key} must be a string, int, float, or bool")
+
+        return value
+
+
+class DatasourceNodeData(BaseNodeData, DatasourceEntity):
+    class DatasourceInput(BaseModel):
+        # TODO: check this type
+        value: Union[Any, list[str]]
+        type: Literal["mixed", "variable", "constant"]
+
+        @field_validator("type", mode="before")
+        @classmethod
+        def check_type(cls, value, validation_info: ValidationInfo):
+            typ = value
+            value = validation_info.data.get("value")
+            if typ == "mixed" and not isinstance(value, str):
+                raise ValueError("value must be a string")
+            elif typ == "variable":
+                if not isinstance(value, list):
+                    raise ValueError("value must be a list")
+                for val in value:
+                    if not isinstance(val, str):
+                        raise ValueError("value must be a list of strings")
+            elif typ == "constant" and not isinstance(value, str | int | float | bool):
+                raise ValueError("value must be a string, int, float, or bool")
+            return typ
+
+    datasource_parameters: dict[str, DatasourceInput]
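Review note: a sketch of validating a node-data mapping against these models; the provider identifiers and the `title` field (assumed to come from BaseNodeData) are illustrative.

    # Illustrative validation round-trip; ids and names are made up.
    data = DatasourceNodeData.model_validate({
        "title": "my datasource",                       # assumed BaseNodeData field
        "provider_id": "langgenius/example_provider",
        "provider_type": "builtin",
        "provider_name": "example_provider",
        "tool_name": "crawl",
        "tool_label": "Crawl",
        "tool_configurations": {"max_depth": 2},
        "datasource_parameters": {
            "url": {"type": "variable", "value": ["start", "page_url"]},
        },
    })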
diff --git a/api/core/workflow/nodes/datasource/exc.py b/api/core/workflow/nodes/datasource/exc.py
new file mode 100644
index 0000000000..7212e8bfc0
--- /dev/null
+++ b/api/core/workflow/nodes/datasource/exc.py
@@ -0,0 +1,16 @@
+class ToolNodeError(ValueError):
+    """Base exception for datasource node errors."""
+
+    pass
+
+
+class ToolParameterError(ToolNodeError):
+    """Exception raised for errors in datasource parameters."""
+
+    pass
+
+
+class ToolFileError(ToolNodeError):
+    """Exception raised for errors related to tool files."""
+
+    pass
diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py
index 73b43eeaf7..673d0ba049 100644
--- a/api/core/workflow/nodes/enums.py
+++ b/api/core/workflow/nodes/enums.py
@@ -13,6 +13,7 @@ class NodeType(StrEnum):
     QUESTION_CLASSIFIER = "question-classifier"
     HTTP_REQUEST = "http-request"
     TOOL = "tool"
+    DATASOURCE = "datasource"
     VARIABLE_AGGREGATOR = "variable-aggregator"
     LEGACY_VARIABLE_AGGREGATOR = "variable-assigner"  # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
     LOOP = "loop"
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index 6f0cc3f6d2..08af6b0014 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -73,7 +73,7 @@ class ToolNode(BaseNode[ToolNodeData]):
                     metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
                     error=f"Failed to get tool runtime: {str(e)}",
                     error_type=type(e).__name__,
-            )
+                )
             )
             return
diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py
index bbca8448ec..bb6b366a81 100644
--- a/api/factories/variable_factory.py
+++ b/api/factories/variable_factory.py
@@ -36,7 +36,11 @@ from core.variables.variables import (
     StringVariable,
     Variable,
 )
-from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
+from core.workflow.constants import (
+    CONVERSATION_VARIABLE_NODE_ID,
+    ENVIRONMENT_VARIABLE_NODE_ID,
+    PIPELINE_VARIABLE_NODE_ID,
+)
 
 
 class InvalidSelectorError(ValueError):
@@ -74,6 +78,10 @@ def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Va
         raise VariableError("missing name")
     return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
 
+
+def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
+    if not mapping.get("name"):
+        raise VariableError("missing name")
+    return _build_variable_from_mapping(mapping=mapping, selector=[PIPELINE_VARIABLE_NODE_ID, mapping["name"]])
+
 
 def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
     """
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index 971e99c259..1bf70da9d9 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -40,6 +40,13 @@ conversation_variable_fields = {
     "description": fields.String,
 }
 
+pipeline_variable_fields = {
+    "id": fields.String,
+    "name": fields.String,
+    "value_type": fields.String(attribute="value_type.value"),
+    "value": fields.Raw,
+}
+
 workflow_fields = {
     "id": fields.String,
     "graph": fields.Raw(attribute="graph_dict"),
@@ -55,6 +62,10 @@ workflow_fields = {
     "tool_published": fields.Boolean,
     "environment_variables": fields.List(EnvironmentVariableField()),
     "conversation_variables": fields.List(fields.Nested(conversation_variable_fields)),
+    "pipeline_variables": fields.Dict(
+        keys=fields.String,
+        values=fields.List(fields.Nested(pipeline_variable_fields)),
+    ),
 }
 
 workflow_partial_fields = {
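Review note: as marshalled by the fields above, `pipeline_variables` is a mapping from a datasource provider key, or the "shared" group consumed later by `get_second_step_parameters`, to a list of variables. A sketch of the payload shape; all values are illustrative.

    # Illustrative marshalled shape of workflow.pipeline_variables.
    pipeline_variables_payload = {
        "shared": [
            {"id": "var-1", "name": "chunk_size", "value_type": "number", "value": 512},
        ],
        "notion": [
            {"id": "var-2", "name": "workspace", "value_type": "string", "value": "docs"},
        ],
    }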
server_default="{}" ) + _pipeline_variables: Mapped[str] = mapped_column( + "conversation_variables", db.Text, nullable=False, server_default="{}" + ) @classmethod def new( @@ -343,6 +346,24 @@ class Workflow(Base): ensure_ascii=False, ) + @property + def pipeline_variables(self) -> dict[str, Sequence[Variable]]: + # TODO: find some way to init `self._conversation_variables` when instance created. + if self._pipeline_variables is None: + self._pipeline_variables = "{}" + + variables_dict: dict[str, Any] = json.loads(self._pipeline_variables) + results = {} + for k, v in variables_dict.items(): + results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()] + return results + + @pipeline_variables.setter + def pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None: + self._pipeline_variables = json.dumps( + {k: {item.name: item.model_dump() for item in v} for k, v in values.items()}, + ensure_ascii=False, + ) class WorkflowRunStatus(StrEnum): """ diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index c2c9c56e9d..422f24d521 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -16,12 +16,13 @@ from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event.types import NodeEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.workflow_entry import WorkflowEntry -from extensions.db import db +from extensions.ext_database import db from models.account import Account from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore from models.workflow import Workflow, WorkflowNodeExecution, WorkflowType from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity from services.errors.app import WorkflowHashNotEqualError +from services.errors.workflow_service import DraftWorkflowDeletionError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory @@ -186,6 +187,7 @@ class RagPipelineService: account: Account, environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], + pipeline_variables: dict[str, Sequence[Variable]], ) -> Workflow: """ Sync draft workflow @@ -212,6 +214,7 @@ class RagPipelineService: created_by=account.id, environment_variables=environment_variables, conversation_variables=conversation_variables, + pipeline_variables=pipeline_variables, ) db.session.add(workflow) # update draft workflow if found @@ -222,7 +225,7 @@ class RagPipelineService: workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) workflow.environment_variables = environment_variables workflow.conversation_variables = conversation_variables - + workflow.pipeline_variables = pipeline_variables # commit db session changes db.session.commit() @@ -342,6 +345,41 @@ class RagPipelineService: db.session.commit() return workflow_node_execution + + def run_datasource_workflow_node( + self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account + ) -> WorkflowNodeExecution: + """ + Run published workflow datasource + """ + # fetch published workflow by app_model + published_workflow = self.get_published_workflow(pipeline=pipeline) + if not published_workflow: + raise ValueError("Workflow not initialized") + + # run draft workflow node + start_at = time.perf_counter() + + workflow_node_execution = 
@@ -342,6 +345,41 @@ class RagPipelineService:
         db.session.commit()
 
         return workflow_node_execution
+
+    def run_datasource_workflow_node(
+        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
+    ) -> WorkflowNodeExecution:
+        """
+        Run a datasource node of the published pipeline workflow
+        """
+        # fetch the published workflow for this pipeline
+        published_workflow = self.get_published_workflow(pipeline=pipeline)
+        if not published_workflow:
+            raise ValueError("Workflow not initialized")
+
+        # run the datasource node as a single step
+        start_at = time.perf_counter()
+
+        workflow_node_execution = self._handle_node_run_result(
+            getter=lambda: WorkflowEntry.single_step_run(
+                workflow=published_workflow,
+                node_id=node_id,
+                user_inputs=user_inputs,
+                user_id=account.id,
+            ),
+            start_at=start_at,
+            tenant_id=pipeline.tenant_id,
+            node_id=node_id,
+        )
+
+        workflow_node_execution.app_id = pipeline.id
+        workflow_node_execution.created_by = account.id
+        workflow_node_execution.workflow_id = published_workflow.id
+
+        db.session.add(workflow_node_execution)
+        db.session.commit()
+
+        return workflow_node_execution
 
     def run_free_workflow_node(
         self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
@@ -573,3 +611,21 @@ class RagPipelineService:
             session.delete(workflow)
 
         return True
+
+    def get_second_step_parameters(self, pipeline: Pipeline, datasource_provider: str) -> list[Variable]:
+        """
+        Get the second-step (document processing) parameters of a RAG pipeline
+        """
+        workflow = self.get_published_workflow(pipeline=pipeline)
+        if not workflow:
+            raise ValueError("Workflow not initialized")
+
+        # collect the variables configured for the selected datasource provider,
+        # plus the "shared" group that applies to every provider
+        pipeline_variables = workflow.pipeline_variables
+        if not pipeline_variables:
+            return []
+        datasource_provider_variables = pipeline_variables.get(datasource_provider, [])
+        shared_variables = pipeline_variables.get("shared", [])
+        return list(datasource_provider_variables) + list(shared_variables)
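Review note: a usage sketch of the new service method; the no-argument constructor, the `pipeline` object, and the provider key are assumptions for illustration.

    # Illustrative call; `pipeline` is assumed to be a Pipeline with a published workflow.
    service = RagPipelineService()
    params = service.get_second_step_parameters(pipeline=pipeline, datasource_provider="notion")
    # Provider-specific variables come first, followed by the "shared" group.
    for variable in params:
        print(variable.name, variable.value_type)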