mirror of https://github.com/langgenius/dify.git
This commit is contained in:
parent
9f8e05d9f0
commit
5c4bf2a9e4
|
@ -1,5 +1,6 @@
|
|||
from flask import Blueprint
|
||||
|
||||
from .datasets.rag_pipeline import data_source
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
|
||||
|
@ -75,7 +76,6 @@ from .billing import billing, compliance
|
|||
|
||||
# Import datasets controllers
|
||||
from .datasets import (
|
||||
data_source,
|
||||
datasets,
|
||||
datasets_document,
|
||||
datasets_segments,
|
||||
|
|
|
@ -0,0 +1,136 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
ConversationCompletedError,
|
||||
DraftWorkflowNotExist,
|
||||
DraftWorkflowNotSync,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
setup_required,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
class PipelineTemplateListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"])
|
||||
language = request.args.get("language", default="en-US", type=str)
|
||||
# get pipeline templates
|
||||
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
|
||||
return pipeline_templates, 200
|
||||
|
||||
|
||||
class PipelineTemplateDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self, pipeline_id: str):
|
||||
pipeline_template = RagPipelineService.get_pipeline_template_detail(pipeline_id)
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
class CustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, template_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
pipeline_template_info = PipelineTemplateInfoEntity(**args)
|
||||
pipeline_template = RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||
return pipeline_template, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def delete(self, template_id: str):
|
||||
RagPipelineService.delete_customized_pipeline_template(template_id)
|
||||
return 200
|
||||
|
||||
|
||||
api.add_resource(
|
||||
PipelineTemplateListApi,
|
||||
"/rag/pipeline/templates",
|
||||
)
|
||||
api.add_resource(
|
||||
PipelineTemplateDetailApi,
|
||||
"/rag/pipeline/templates/<string:pipeline_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
CustomizedPipelineTemplateApi,
|
||||
"/rag/pipeline/templates/<string:template_id>",
|
||||
)
|
|
@ -15,11 +15,9 @@ from controllers.console.app.error import (
|
|||
DraftWorkflowNotExist,
|
||||
DraftWorkflowNotSync,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
setup_required,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
|
@ -32,96 +30,17 @@ from fields.workflow_run_fields import workflow_run_node_execution_fields
|
|||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
class PipelineTemplateListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"])
|
||||
language = request.args.get("language", default="en-US", type=str)
|
||||
# get pipeline templates
|
||||
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
|
||||
return pipeline_templates, 200
|
||||
|
||||
|
||||
class PipelineTemplateDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self, pipeline_id: str):
|
||||
pipeline_template = RagPipelineService.get_pipeline_template_detail(pipeline_id)
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
class CustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, template_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
pipeline_template_info = PipelineTemplateInfoEntity(**args)
|
||||
pipeline_template = RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||
return pipeline_template, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def delete(self, template_id: str):
|
||||
RagPipelineService.delete_customized_pipeline_template(template_id)
|
||||
return 200
|
||||
|
||||
|
||||
class DraftRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
@ -130,7 +49,7 @@ class DraftRagPipelineApi(Resource):
|
|||
@marshal_with(workflow_fields)
|
||||
def get(self, pipeline: Pipeline):
|
||||
"""
|
||||
Get draft workflow
|
||||
Get draft rag pipeline's workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
|
@ -167,6 +86,7 @@ class DraftRagPipelineApi(Resource):
|
|||
parser.add_argument("hash", type=str, required=False, location="json")
|
||||
parser.add_argument("environment_variables", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||
parser.add_argument("pipeline_variables", type=dict, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
elif "text/plain" in content_type:
|
||||
try:
|
||||
|
@ -183,6 +103,7 @@ class DraftRagPipelineApi(Resource):
|
|||
"hash": data.get("hash"),
|
||||
"environment_variables": data.get("environment_variables"),
|
||||
"conversation_variables": data.get("conversation_variables"),
|
||||
"pipeline_variables": data.get("pipeline_variables"),
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
|
@ -192,8 +113,6 @@ class DraftRagPipelineApi(Resource):
|
|||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
try:
|
||||
environment_variables_list = args.get("environment_variables") or []
|
||||
environment_variables = [
|
||||
|
@ -203,6 +122,11 @@ class DraftRagPipelineApi(Resource):
|
|||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
pipeline_variables_list = args.get("pipeline_variables") or {}
|
||||
pipeline_variables = {
|
||||
k: [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in v]
|
||||
for k, v in pipeline_variables_list.items()
|
||||
}
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow = rag_pipeline_service.sync_draft_workflow(
|
||||
pipeline=pipeline,
|
||||
|
@ -212,6 +136,7 @@ class DraftRagPipelineApi(Resource):
|
|||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
pipeline_variables=pipeline_variables,
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
|
@ -263,8 +188,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
|
@ -281,7 +206,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
|||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_loop(
|
||||
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
pipeline=pipeline, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
|
@ -300,8 +225,8 @@ class DraftRagPipelineRunApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline):
|
||||
"""
|
||||
Run draft workflow
|
||||
"""
|
||||
|
@ -319,7 +244,7 @@ class DraftRagPipelineRunApi(Resource):
|
|||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
pipeline=pipeline,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
|
@ -330,32 +255,45 @@ class DraftRagPipelineRunApi(Resource):
|
|||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
|
||||
|
||||
class RagPipelineTaskStopApi(Resource):
|
||||
class RagPipelineDatasourceNodeRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, task_id: str):
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Stop workflow task
|
||||
Run rag pipeline datasource
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
return {"result": "success"}
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs == None:
|
||||
raise ValueError("missing inputs")
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow_node_execution = rag_pipeline_service.run_datasource_workflow_node(
|
||||
pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user
|
||||
)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
|
||||
class RagPipelineNodeRunApi(Resource):
|
||||
class RagPipelineDraftNodeRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_rag_pipeline
|
||||
@marshal_with(workflow_run_node_execution_fields)
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
|
@ -374,13 +312,29 @@ class RagPipelineNodeRunApi(Resource):
|
|||
if inputs == None:
|
||||
raise ValueError("missing inputs")
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow_node_execution = workflow_service.run_draft_workflow_node(
|
||||
app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow_node_execution = rag_pipeline_service.run_draft_workflow_node(
|
||||
pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user
|
||||
)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
class RagPipelineTaskStopApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, task_id: str):
|
||||
"""
|
||||
Stop workflow task
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
class PublishedRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
|
@ -695,6 +649,25 @@ class RagPipelineByIdApi(Resource):
|
|||
|
||||
return None, 204
|
||||
|
||||
class RagPipelineSecondStepApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def get(self, pipeline: Pipeline):
|
||||
"""
|
||||
Get second step parameters of rag pipeline
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
datasource_provider = request.args.get("datasource_provider", required=True, type=str)
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
return rag_pipeline_service.get_second_step_parameters(pipeline=pipeline,
|
||||
datasource_provider=datasource_provider
|
||||
)
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DraftRagPipelineApi,
|
||||
|
@ -713,9 +686,13 @@ api.add_resource(
|
|||
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineNodeRunApi,
|
||||
RagPipelineDraftNodeRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelinePublishedNodeRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/nodes/<string:node_id>/run",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineDraftRunIterationNodeApi,
|
||||
|
@ -751,15 +728,3 @@ api.add_resource(
|
|||
"/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
PipelineTemplateListApi,
|
||||
"/rag/pipeline/templates",
|
||||
)
|
||||
api.add_resource(
|
||||
PipelineTemplateDetailApi,
|
||||
"/rag/pipeline/templates/<string:pipeline_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
CustomizedPipelineTemplateApi,
|
||||
"/rag/pipeline/templates/<string:template_id>",
|
||||
)
|
|
@ -0,0 +1,222 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.model import File
|
||||
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolEntity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
"""
|
||||
The base class of a tool
|
||||
"""
|
||||
|
||||
entity: ToolEntity
|
||||
runtime: ToolRuntime
|
||||
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
|
||||
self.entity = entity
|
||||
self.runtime = runtime
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
|
||||
"""
|
||||
fork a new tool with metadata
|
||||
:return: the new tool
|
||||
"""
|
||||
return self.__class__(
|
||||
entity=self.entity.model_copy(),
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
get the tool provider type
|
||||
|
||||
:return: the tool provider type
|
||||
"""
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage]:
|
||||
if self.runtime and self.runtime.runtime_parameters:
|
||||
tool_parameters.update(self.runtime.runtime_parameters)
|
||||
|
||||
# try parse tool parameters into the correct type
|
||||
tool_parameters = self._transform_tool_parameters_type(tool_parameters)
|
||||
|
||||
result = self._invoke(
|
||||
user_id=user_id,
|
||||
tool_parameters=tool_parameters,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
if isinstance(result, ToolInvokeMessage):
|
||||
|
||||
def single_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield result
|
||||
|
||||
return single_generator()
|
||||
elif isinstance(result, list):
|
||||
|
||||
def generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield from result
|
||||
|
||||
return generator()
|
||||
else:
|
||||
return result
|
||||
|
||||
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Transform tool parameters type
|
||||
"""
|
||||
# Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
|
||||
result = deepcopy(tool_parameters)
|
||||
for parameter in self.entity.parameters or []:
|
||||
if parameter.name in tool_parameters:
|
||||
result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
|
||||
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]:
|
||||
pass
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
"""
|
||||
get the runtime parameters
|
||||
|
||||
interface for developer to dynamic change the parameters of a tool depends on the variables pool
|
||||
|
||||
:return: the runtime parameters
|
||||
"""
|
||||
return self.entity.parameters
|
||||
|
||||
def get_merged_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
"""
|
||||
get merged runtime parameters
|
||||
|
||||
:return: merged runtime parameters
|
||||
"""
|
||||
parameters = self.entity.parameters
|
||||
parameters = parameters.copy()
|
||||
user_parameters = self.get_runtime_parameters() or []
|
||||
user_parameters = user_parameters.copy()
|
||||
|
||||
# override parameters
|
||||
for parameter in user_parameters:
|
||||
# check if parameter in tool parameters
|
||||
for tool_parameter in parameters:
|
||||
if tool_parameter.name == parameter.name:
|
||||
# override parameter
|
||||
tool_parameter.type = parameter.type
|
||||
tool_parameter.form = parameter.form
|
||||
tool_parameter.required = parameter.required
|
||||
tool_parameter.default = parameter.default
|
||||
tool_parameter.options = parameter.options
|
||||
tool_parameter.llm_description = parameter.llm_description
|
||||
break
|
||||
else:
|
||||
# add new parameter
|
||||
parameters.append(parameter)
|
||||
|
||||
return parameters
|
||||
|
||||
def create_image_message(
|
||||
self,
|
||||
image: str,
|
||||
) -> ToolInvokeMessage:
|
||||
"""
|
||||
create an image message
|
||||
|
||||
:param image: the url of the image
|
||||
:return: the image message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
|
||||
)
|
||||
|
||||
def create_file_message(self, file: "File") -> ToolInvokeMessage:
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.FILE,
|
||||
message=ToolInvokeMessage.FileMessage(),
|
||||
meta={"file": file},
|
||||
)
|
||||
|
||||
def create_link_message(self, link: str) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a link message
|
||||
|
||||
:param link: the url of the link
|
||||
:return: the link message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link)
|
||||
)
|
||||
|
||||
def create_text_message(self, text: str) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a text message
|
||||
|
||||
:param text: the text
|
||||
:return: the text message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=ToolInvokeMessage.TextMessage(text=text),
|
||||
)
|
||||
|
||||
def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a blob message
|
||||
|
||||
:param blob: the blob
|
||||
:param meta: the meta info of blob object
|
||||
:return: the blob message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB,
|
||||
message=ToolInvokeMessage.BlobMessage(blob=blob),
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
def create_json_message(self, object: dict) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a json message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
|
||||
)
|
|
@ -0,0 +1,109 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolProviderEntity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class ToolProviderController(ABC):
|
||||
entity: ToolProviderEntity
|
||||
|
||||
def __init__(self, entity: ToolProviderEntity) -> None:
|
||||
self.entity = entity
|
||||
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
return deepcopy(self.entity.credentials_schema)
|
||||
|
||||
@abstractmethod
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
"""
|
||||
returns a tool that the provider can provide
|
||||
|
||||
:return: tool
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the format of the credentials of the provider and set the default value if needed
|
||||
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
credentials_schema = dict[str, ProviderConfig]()
|
||||
if credentials_schema is None:
|
||||
return
|
||||
|
||||
for credential in self.entity.credentials_schema:
|
||||
credentials_schema[credential.name] = credential
|
||||
|
||||
credentials_need_to_validate: dict[str, ProviderConfig] = {}
|
||||
for credential_name in credentials_schema:
|
||||
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
||||
|
||||
for credential_name in credentials:
|
||||
if credential_name not in credentials_need_to_validate:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} not found in provider {self.entity.identity.name}"
|
||||
)
|
||||
|
||||
# check type
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if not credential_schema.required and credentials[credential_name] is None:
|
||||
continue
|
||||
|
||||
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
elif credential_schema.type == ProviderConfig.Type.SELECT:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
options = credential_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
|
||||
|
||||
if credentials[credential_name] not in [x.value for x in options]:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} should be one of {options}"
|
||||
)
|
||||
|
||||
credentials_need_to_validate.pop(credential_name)
|
||||
|
||||
for credential_name in credentials_need_to_validate:
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if credential_schema.required:
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
|
||||
|
||||
# the credential is not set currently, set the default value if needed
|
||||
if credential_schema.default is not None:
|
||||
default_value = credential_schema.default
|
||||
# parse default value into the correct type
|
||||
if credential_schema.type in {
|
||||
ProviderConfig.Type.SECRET_INPUT,
|
||||
ProviderConfig.Type.TEXT_INPUT,
|
||||
ProviderConfig.Type.SELECT,
|
||||
}:
|
||||
default_value = str(default_value)
|
||||
|
||||
credentials[credential_name] = default_value
|
|
@ -0,0 +1,36 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from openai import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.entities.tool_entities import ToolInvokeFrom
|
||||
|
||||
|
||||
class ToolRuntime(BaseModel):
|
||||
"""
|
||||
Meta data of a tool call processing
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
tool_id: Optional[str] = None
|
||||
invoke_from: Optional[InvokeFrom] = None
|
||||
tool_invoke_from: Optional[ToolInvokeFrom] = None
|
||||
credentials: dict[str, Any] = Field(default_factory=dict)
|
||||
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class FakeToolRuntime(ToolRuntime):
|
||||
"""
|
||||
Fake tool runtime for testing
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
tenant_id="fake_tenant_id",
|
||||
tool_id="fake_tool_id",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||
credentials={},
|
||||
runtime_parameters={},
|
||||
)
|
|
@ -0,0 +1,72 @@
|
|||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool import ToolParameter
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
|
||||
|
||||
class ToolApiEntity(BaseModel):
|
||||
author: str
|
||||
name: str # identifier
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: Optional[list[ToolParameter]] = None
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
output_schema: Optional[dict] = None
|
||||
|
||||
|
||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
||||
|
||||
|
||||
class ToolProviderApiEntity(BaseModel):
|
||||
id: str
|
||||
author: str
|
||||
name: str # identifier
|
||||
description: I18nObject
|
||||
icon: str | dict
|
||||
label: I18nObject # label
|
||||
type: ToolProviderType
|
||||
masked_credentials: Optional[dict] = None
|
||||
original_credentials: Optional[dict] = None
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool")
|
||||
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool")
|
||||
tools: list[ToolApiEntity] = Field(default_factory=list)
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
def convert_none_to_empty_list(cls, v):
|
||||
return v if v is not None else []
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
# -------------
|
||||
# overwrite tool parameter types for temp fix
|
||||
tools = jsonable_encoder(self.tools)
|
||||
for tool in tools:
|
||||
if tool.get("parameters"):
|
||||
for parameter in tool.get("parameters"):
|
||||
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
|
||||
parameter["type"] = "files"
|
||||
# -------------
|
||||
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
"name": self.name,
|
||||
"plugin_id": self.plugin_id,
|
||||
"plugin_unique_identifier": self.plugin_unique_identifier,
|
||||
"description": self.description.to_dict(),
|
||||
"icon": self.icon,
|
||||
"label": self.label.to_dict(),
|
||||
"type": self.type.value,
|
||||
"team_credentials": self.masked_credentials,
|
||||
"is_team_authorization": self.is_team_authorization,
|
||||
"allow_delete": self.allow_delete,
|
||||
"tools": tools,
|
||||
"labels": self.labels,
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class I18nObject(BaseModel):
|
||||
"""
|
||||
Model class for i18n object.
|
||||
"""
|
||||
|
||||
en_US: str
|
||||
zh_Hans: Optional[str] = Field(default=None)
|
||||
pt_BR: Optional[str] = Field(default=None)
|
||||
ja_JP: Optional[str] = Field(default=None)
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self.zh_Hans = self.zh_Hans or self.en_US
|
||||
self.pt_BR = self.pt_BR or self.en_US
|
||||
self.ja_JP = self.ja_JP or self.en_US
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
|
|
@ -0,0 +1 @@
|
|||
TOOL_SELECTOR_MODEL_IDENTITY = "__dify__tool_selector__"
|
|
@ -0,0 +1 @@
|
|||
|
|
@ -0,0 +1,29 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
|
||||
|
||||
class ApiToolBundle(BaseModel):
|
||||
"""
|
||||
This class is used to store the schema information of an api based tool.
|
||||
such as the url, the method, the parameters, etc.
|
||||
"""
|
||||
|
||||
# server_url
|
||||
server_url: str
|
||||
# method
|
||||
method: str
|
||||
# summary
|
||||
summary: Optional[str] = None
|
||||
# operation_id
|
||||
operation_id: Optional[str] = None
|
||||
# parameters
|
||||
parameters: Optional[list[ToolParameter]] = None
|
||||
# author
|
||||
author: str
|
||||
# icon
|
||||
icon: Optional[str] = None
|
||||
# openapi operation
|
||||
openapi: dict
|
|
@ -0,0 +1,427 @@
|
|||
import base64
|
||||
import enum
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.parameters import (
|
||||
PluginParameter,
|
||||
PluginParameterOption,
|
||||
PluginParameterType,
|
||||
as_normal_type,
|
||||
cast_parameter_value,
|
||||
init_frontend_parameter,
|
||||
)
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
|
||||
|
||||
|
||||
class ToolLabelEnum(Enum):
|
||||
SEARCH = "search"
|
||||
IMAGE = "image"
|
||||
VIDEOS = "videos"
|
||||
WEATHER = "weather"
|
||||
FINANCE = "finance"
|
||||
DESIGN = "design"
|
||||
TRAVEL = "travel"
|
||||
SOCIAL = "social"
|
||||
NEWS = "news"
|
||||
MEDICAL = "medical"
|
||||
PRODUCTIVITY = "productivity"
|
||||
EDUCATION = "education"
|
||||
BUSINESS = "business"
|
||||
ENTERTAINMENT = "entertainment"
|
||||
UTILITIES = "utilities"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class ToolProviderType(enum.StrEnum):
|
||||
"""
|
||||
Enum class for tool provider
|
||||
"""
|
||||
|
||||
PLUGIN = "plugin"
|
||||
BUILT_IN = "builtin"
|
||||
WORKFLOW = "workflow"
|
||||
API = "api"
|
||||
APP = "app"
|
||||
DATASET_RETRIEVAL = "dataset-retrieval"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ToolProviderType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class ApiProviderSchemaType(Enum):
|
||||
"""
|
||||
Enum class for api provider schema type.
|
||||
"""
|
||||
|
||||
OPENAPI = "openapi"
|
||||
SWAGGER = "swagger"
|
||||
OPENAI_PLUGIN = "openai_plugin"
|
||||
OPENAI_ACTIONS = "openai_actions"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ApiProviderSchemaType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class ApiProviderAuthType(Enum):
|
||||
"""
|
||||
Enum class for api provider auth type.
|
||||
"""
|
||||
|
||||
NONE = "none"
|
||||
API_KEY = "api_key"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ApiProviderAuthType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class ToolInvokeMessage(BaseModel):
|
||||
class TextMessage(BaseModel):
|
||||
text: str
|
||||
|
||||
class JsonMessage(BaseModel):
|
||||
json_object: dict
|
||||
|
||||
class BlobMessage(BaseModel):
|
||||
blob: bytes
|
||||
|
||||
class FileMessage(BaseModel):
|
||||
pass
|
||||
|
||||
class VariableMessage(BaseModel):
|
||||
variable_name: str = Field(..., description="The name of the variable")
|
||||
variable_value: Any = Field(..., description="The value of the variable")
|
||||
stream: bool = Field(default=False, description="Whether the variable is streamed")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def transform_variable_value(cls, values) -> Any:
|
||||
"""
|
||||
Only basic types and lists are allowed.
|
||||
"""
|
||||
value = values.get("variable_value")
|
||||
if not isinstance(value, dict | list | str | int | float | bool):
|
||||
raise ValueError("Only basic types and lists are allowed.")
|
||||
|
||||
# if stream is true, the value must be a string
|
||||
if values.get("stream"):
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
|
||||
return values
|
||||
|
||||
@field_validator("variable_name", mode="before")
|
||||
@classmethod
|
||||
def transform_variable_name(cls, value: str) -> str:
|
||||
"""
|
||||
The variable name must be a string.
|
||||
"""
|
||||
if value in {"json", "text", "files"}:
|
||||
raise ValueError(f"The variable name '{value}' is reserved.")
|
||||
return value
|
||||
|
||||
class LogMessage(BaseModel):
|
||||
class LogStatus(Enum):
|
||||
START = "start"
|
||||
ERROR = "error"
|
||||
SUCCESS = "success"
|
||||
|
||||
id: str
|
||||
label: str = Field(..., description="The label of the log")
|
||||
parent_id: Optional[str] = Field(default=None, description="Leave empty for root log")
|
||||
error: Optional[str] = Field(default=None, description="The error message")
|
||||
status: LogStatus = Field(..., description="The status of the log")
|
||||
data: Mapping[str, Any] = Field(..., description="Detailed log data")
|
||||
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")
|
||||
|
||||
class MessageType(Enum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
LINK = "link"
|
||||
BLOB = "blob"
|
||||
JSON = "json"
|
||||
IMAGE_LINK = "image_link"
|
||||
BINARY_LINK = "binary_link"
|
||||
VARIABLE = "variable"
|
||||
FILE = "file"
|
||||
LOG = "log"
|
||||
|
||||
type: MessageType = MessageType.TEXT
|
||||
"""
|
||||
plain text, image url or link url
|
||||
"""
|
||||
message: JsonMessage | TextMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage
|
||||
meta: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("message", mode="before")
|
||||
@classmethod
|
||||
def decode_blob_message(cls, v):
|
||||
if isinstance(v, dict) and "blob" in v:
|
||||
try:
|
||||
v["blob"] = base64.b64decode(v["blob"])
|
||||
except Exception:
|
||||
pass
|
||||
return v
|
||||
|
||||
@field_serializer("message")
|
||||
def serialize_message(self, v):
|
||||
if isinstance(v, self.BlobMessage):
|
||||
return {"blob": base64.b64encode(v.blob).decode("utf-8")}
|
||||
return v
|
||||
|
||||
|
||||
class ToolInvokeMessageBinary(BaseModel):
|
||||
mimetype: str = Field(..., description="The mimetype of the binary")
|
||||
url: str = Field(..., description="The url of the binary")
|
||||
file_var: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class ToolParameter(PluginParameter):
|
||||
"""
|
||||
Overrides type
|
||||
"""
|
||||
|
||||
class ToolParameterType(enum.StrEnum):
|
||||
"""
|
||||
removes TOOLS_SELECTOR from PluginParameterType
|
||||
"""
|
||||
|
||||
STRING = PluginParameterType.STRING.value
|
||||
NUMBER = PluginParameterType.NUMBER.value
|
||||
BOOLEAN = PluginParameterType.BOOLEAN.value
|
||||
SELECT = PluginParameterType.SELECT.value
|
||||
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
|
||||
FILE = PluginParameterType.FILE.value
|
||||
FILES = PluginParameterType.FILES.value
|
||||
APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
|
||||
|
||||
def as_normal_type(self):
|
||||
return as_normal_type(self)
|
||||
|
||||
def cast_value(self, value: Any):
|
||||
return cast_parameter_value(self, value)
|
||||
|
||||
class ToolParameterForm(Enum):
|
||||
SCHEMA = "schema" # should be set while adding tool
|
||||
FORM = "form" # should be set before invoking tool
|
||||
LLM = "llm" # will be set by LLM
|
||||
|
||||
type: ToolParameterType = Field(..., description="The type of the parameter")
|
||||
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
|
||||
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
|
||||
llm_description: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def get_simple_instance(
|
||||
cls,
|
||||
name: str,
|
||||
llm_description: str,
|
||||
typ: ToolParameterType,
|
||||
required: bool,
|
||||
options: Optional[list[str]] = None,
|
||||
) -> "ToolParameter":
|
||||
"""
|
||||
get a simple tool parameter
|
||||
|
||||
:param name: the name of the parameter
|
||||
:param llm_description: the description presented to the LLM
|
||||
:param typ: the type of the parameter
|
||||
:param required: if the parameter is required
|
||||
:param options: the options of the parameter
|
||||
"""
|
||||
# convert options to ToolParameterOption
|
||||
# FIXME fix the type error
|
||||
if options:
|
||||
option_objs = [
|
||||
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
for option in options
|
||||
]
|
||||
else:
|
||||
option_objs = []
|
||||
|
||||
return cls(
|
||||
name=name,
|
||||
label=I18nObject(en_US="", zh_Hans=""),
|
||||
placeholder=None,
|
||||
human_description=I18nObject(en_US="", zh_Hans=""),
|
||||
type=typ,
|
||||
form=cls.ToolParameterForm.LLM,
|
||||
llm_description=llm_description,
|
||||
required=required,
|
||||
options=option_objs,
|
||||
)
|
||||
|
||||
def init_frontend_parameter(self, value: Any):
|
||||
return init_frontend_parameter(self, self.type, value)
|
||||
|
||||
|
||||
class ToolProviderIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
description: I18nObject = Field(..., description="The description of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
tags: Optional[list[ToolLabelEnum]] = Field(
|
||||
default=[],
|
||||
description="The tags of the tool",
|
||||
)
|
||||
|
||||
|
||||
class ToolIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
provider: str = Field(..., description="The provider of the tool")
|
||||
icon: Optional[str] = None
|
||||
|
||||
|
||||
class ToolDescription(BaseModel):
|
||||
human: I18nObject = Field(..., description="The description presented to the user")
|
||||
llm: str = Field(..., description="The description presented to the LLM")
|
||||
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
identity: ToolIdentity
|
||||
parameters: list[ToolParameter] = Field(default_factory=list)
|
||||
description: Optional[ToolDescription] = None
|
||||
output_schema: Optional[dict] = None
|
||||
has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
|
||||
return v or []
|
||||
|
||||
|
||||
class ToolProviderEntity(BaseModel):
|
||||
identity: ToolProviderIdentity
|
||||
plugin_id: Optional[str] = None
|
||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ToolProviderEntityWithPlugin(ToolProviderEntity):
|
||||
tools: list[ToolEntity] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WorkflowToolParameterConfiguration(BaseModel):
|
||||
"""
|
||||
Workflow tool configuration
|
||||
"""
|
||||
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
description: str = Field(..., description="The description of the parameter")
|
||||
form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
|
||||
|
||||
|
||||
class ToolInvokeMeta(BaseModel):
|
||||
"""
|
||||
Tool invoke meta
|
||||
"""
|
||||
|
||||
time_cost: float = Field(..., description="The time cost of the tool invoke")
|
||||
error: Optional[str] = None
|
||||
tool_config: Optional[dict] = None
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> "ToolInvokeMeta":
|
||||
"""
|
||||
Get an empty instance of ToolInvokeMeta
|
||||
"""
|
||||
return cls(time_cost=0.0, error=None, tool_config={})
|
||||
|
||||
@classmethod
|
||||
def error_instance(cls, error: str) -> "ToolInvokeMeta":
|
||||
"""
|
||||
Get an instance of ToolInvokeMeta with error
|
||||
"""
|
||||
return cls(time_cost=0.0, error=error, tool_config={})
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"time_cost": self.time_cost,
|
||||
"error": self.error,
|
||||
"tool_config": self.tool_config,
|
||||
}
|
||||
|
||||
|
||||
class ToolLabel(BaseModel):
|
||||
"""
|
||||
Tool label
|
||||
"""
|
||||
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
|
||||
|
||||
class ToolInvokeFrom(Enum):
|
||||
"""
|
||||
Enum class for tool invoke
|
||||
"""
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
AGENT = "agent"
|
||||
PLUGIN = "plugin"
|
||||
|
||||
|
||||
class ToolSelector(BaseModel):
|
||||
dify_model_identity: str = TOOL_SELECTOR_MODEL_IDENTITY
|
||||
|
||||
class Parameter(BaseModel):
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
type: ToolParameter.ToolParameterType = Field(..., description="The type of the parameter")
|
||||
required: bool = Field(..., description="Whether the parameter is required")
|
||||
description: str = Field(..., description="The description of the parameter")
|
||||
default: Optional[Union[int, float, str]] = None
|
||||
options: Optional[list[PluginParameterOption]] = None
|
||||
|
||||
provider_id: str = Field(..., description="The id of the provider")
|
||||
tool_name: str = Field(..., description="The name of the tool")
|
||||
tool_description: str = Field(..., description="The description of the tool")
|
||||
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
|
||||
tool_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm")
|
||||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
return self.model_dump()
|
|
@ -0,0 +1,111 @@
|
|||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolLabel, ToolLabelEnum
|
||||
|
||||
ICONS = {
|
||||
ToolLabelEnum.SEARCH: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M7.33398 1.3335C10.646 1.3335 13.334 4.0215 13.334 7.3335C13.334 10.6455 10.646 13.3335 7.33398 13.3335C4.02198 13.3335 1.33398 10.6455 1.33398 7.3335C1.33398 4.0215 4.02198 1.3335 7.33398 1.3335ZM7.33398 12.0002C9.91232 12.0002 12.0007 9.91183 12.0007 7.3335C12.0007 4.75516 9.91232 2.66683 7.33398 2.66683C4.75565 2.66683 2.66732 4.75516 2.66732 7.3335C2.66732 9.91183 4.75565 12.0002 7.33398 12.0002ZM12.9909 12.0476L14.8764 13.9332L13.9337 14.876L12.0481 12.9904L12.9909 12.0476Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.IMAGE: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M13.0514 9.71752L10.4718 7.13792C10.2115 6.87752 9.78932 6.87752 9.52898 7.13792L4.57721 12.0897C3.4097 11.1113 2.66732 9.64232 2.66732 7.99992C2.66732 5.0544 5.05513 2.66659 8.00065 2.66659C10.9462 2.66659 13.334 5.0544 13.334 7.99992C13.334 8.60085 13.2346 9.17852 13.0514 9.71752ZM5.72683 12.8257L10.0004 8.55212L12.4259 10.9777C11.4668 12.4001 9.84152 13.3331 8.00038 13.3331C7.18632 13.3331 6.41628 13.1511 5.72683 12.8257ZM8.00065 14.6666C11.6825 14.6666 14.6673 11.6818 14.6673 7.99992C14.6673 4.31802 11.6825 1.33325 8.00065 1.33325C4.31875 1.33325 1.33398 4.31802 1.33398 7.99992C1.33398 11.6818 4.31875 14.6666 8.00065 14.6666ZM7.33398 6.66658C7.33398 7.40299 6.73705 7.99992 6.00065 7.99992C5.26427 7.99992 4.66732 7.40299 4.66732 6.66658C4.66732 5.9302 5.26427 5.33325 6.00065 5.33325C6.73705 5.33325 7.33398 5.9302 7.33398 6.66658Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.VIDEOS: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00065 13.3333H13.334V14.6666H8.00065C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992C14.6673 9.50072 14.1714 10.8857 13.3345 11.9999H11.5284C12.6356 11.0227 13.334 9.59285 13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333ZM8.00065 6.66658C7.26425 6.66658 6.66732 6.06963 6.66732 5.33325C6.66732 4.59687 7.26425 3.99992 8.00065 3.99992C8.73705 3.99992 9.33398 4.59687 9.33398 5.33325C9.33398 6.06963 8.73705 6.66658 8.00065 6.66658ZM5.33398 9.33325C4.5976 9.33325 4.00065 8.73632 4.00065 7.99992C4.00065 7.26352 4.5976 6.66658 5.33398 6.66658C6.07036 6.66658 6.66732 7.26352 6.66732 7.99992C6.66732 8.73632 6.07036 9.33325 5.33398 9.33325ZM10.6673 9.33325C9.93092 9.33325 9.33398 8.73632 9.33398 7.99992C9.33398 7.26352 9.93092 6.66658 10.6673 6.66658C11.4037 6.66658 12.0007 7.26352 12.0007 7.99992C12.0007 8.73632 11.4037 9.33325 10.6673 9.33325ZM8.00065 11.9999C7.26425 11.9999 6.66732 11.403 6.66732 10.6666C6.66732 9.93018 7.26425 9.33325 8.00065 9.33325C8.73705 9.33325 9.33398 9.93018 9.33398 10.6666C9.33398 11.403 8.73705 11.9999 8.00065 11.9999Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.WEATHER: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M6.6553 3.37344C7.42088 2.1484 8.78162 1.3335 10.3327 1.3335C12.7259 1.3335 14.666 3.2736 14.666 5.66683C14.666 6.38704 14.4903 7.06623 14.1794 7.66383C14.8894 8.3325 15.3327 9.28123 15.3327 10.3335C15.3327 12.3586 13.6911 14.0002 11.666 14.0002H5.99935C3.05383 14.0002 0.666016 11.6124 0.666016 8.66683C0.666016 5.72131 3.05383 3.3335 5.99935 3.3335C6.22143 3.3335 6.44034 3.34707 6.6553 3.37344ZM8.03628 3.73629C9.37768 4.29108 10.4435 5.37735 10.9711 6.73256C11.1961 6.68943 11.4284 6.66683 11.666 6.66683C12.1561 6.66683 12.6237 6.76296 13.0511 6.93743C13.2317 6.55162 13.3327 6.12102 13.3327 5.66683C13.3327 4.00998 11.9895 2.66683 10.3327 2.66683C9.41115 2.66683 8.58662 3.08236 8.03628 3.73629ZM11.666 12.6668C12.9547 12.6668 13.9993 11.6222 13.9993 10.3335C13.9993 9.04483 12.9547 8.00016 11.666 8.00016C11.013 8.00016 10.4227 8.26836 9.99922 8.70063C9.99928 8.68936 9.99935 8.6781 9.99935 8.66683C9.99935 6.45769 8.20848 4.66683 5.99935 4.66683C3.79021 4.66683 1.99935 6.45769 1.99935 8.66683C1.99935 10.876 3.79021 12.6668 5.99935 12.6668H11.666Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.FINANCE: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00262 14.6685C4.32071 14.6685 1.33594 11.6838 1.33594 8.00184C1.33594 4.31997 4.32071 1.33521 8.00262 1.33521C11.6845 1.33521 14.6693 4.31997 14.6693 8.00184C14.6693 11.6838 11.6845 14.6685 8.00262 14.6685ZM8.00262 13.3352C10.9482 13.3352 13.336 10.9474 13.336 8.00184C13.336 5.05635 10.9482 2.66854 8.00262 2.66854C5.05708 2.66854 2.66927 5.05635 2.66927 8.00184C2.66927 10.9474 5.05708 13.3352 8.00262 13.3352ZM5.66927 9.33517H9.33595C9.52002 9.33517 9.66928 9.18597 9.66928 9.00184C9.66928 8.81777 9.52002 8.66851 9.33595 8.66851H6.66928C5.7488 8.66851 5.0026 7.92237 5.0026 7.00184C5.0026 6.08139 5.7488 5.33521 6.66928 5.33521H7.33595V4.00187H8.66928V5.33521H10.336V6.66851H6.66928C6.48518 6.66851 6.33594 6.81777 6.33594 7.00184C6.33594 7.18597 6.48518 7.33517 6.66928 7.33517H9.33595C10.2564 7.33517 11.0026 8.08137 11.0026 9.00184C11.0026 9.92237 10.2564 10.6685 9.33595 10.6685H8.66928V12.0018H7.33595V10.6685H5.66927V9.33517Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.DESIGN: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M4.70152 9.41416L3.2873 10.8284L5.17292 12.714L12.7154 5.17154L10.8298 3.28592L9.41557 4.70013L10.3584 5.64295L9.41557 6.58575L8.47277 5.64295L7.52997 6.58575L8.47277 7.52856L7.52997 8.47136L6.58713 7.52856L5.64433 8.47136L6.58713 9.41416L5.64433 10.357L4.70152 9.41416ZM11.3012 1.87171L14.1296 4.70013C14.39 4.96049 14.39 5.38259 14.1296 5.64295L5.64433 14.1282C5.38397 14.3886 4.96187 14.3886 4.70152 14.1282L1.87309 11.2998C1.61274 11.0394 1.61274 10.6174 1.87309 10.357L10.3584 1.87171C10.6187 1.61136 11.0408 1.61136 11.3012 1.87171ZM9.41557 12.2423L10.3584 11.2995L11.8534 12.7945H12.7962V11.8517L11.3012 10.3567L12.244 9.41383L14.0011 11.171V13.9999H11.1732L9.41557 12.2423ZM3.75861 6.58533L1.87299 4.69971C1.61265 4.43937 1.61265 4.01725 1.87299 3.75691L3.75861 1.87129C4.01896 1.61094 4.44107 1.61094 4.70142 1.87129L6.58704 3.75691L5.64423 4.69971L4.23002 3.2855L3.28721 4.22831L4.70142 5.64253L3.75861 6.58533Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.TRAVEL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M9.44839 2C9.80198 2 10.1411 2.14047 10.3912 2.39053L13.6101 5.60947C13.8602 5.85953 14.0007 6.19866 14.0007 6.55229V11.3333H15.334V12.6667L9.91652 12.6672C9.62032 13.8171 8.57638 14.6667 7.33398 14.6667C6.0916 14.6667 5.04766 13.8171 4.75146 12.6672L2.00065 12.6667C1.63246 12.6667 1.33398 12.3682 1.33398 12V3.33333C1.33398 2.59695 1.93094 2 2.66732 2H9.44839ZM7.33398 10.6667C6.5976 10.6667 6.00065 11.2636 6.00065 12C6.00065 12.7364 6.5976 13.3333 7.33398 13.3333C8.07038 13.3333 8.66732 12.7364 8.66732 12C8.66732 11.2636 8.07038 10.6667 7.33398 10.6667ZM9.44839 3.33333H2.66732V11.3333L4.75128 11.3335C5.04726 10.1833 6.09136 9.33333 7.33398 9.33333C8.57658 9.33333 9.62072 10.1833 9.91665 11.3335L12.6673 11.3333V6.55229L9.44839 3.33333ZM9.33398 4.66667V8.66667H4.00065V4.66667H9.33398ZM8.00065 6H5.33398V7.33333H8.00065V6Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.SOCIAL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333C9.09518 13.3333 10.1127 13.0035 10.9594 12.438L11.699 13.5475C10.6408 14.2545 9.36885 14.6666 8.00065 14.6666C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992V8.99992C14.6673 10.2886 13.6227 11.3333 12.334 11.3333C11.5312 11.3333 10.8231 10.9278 10.4032 10.3105C9.79678 10.9409 8.94452 11.3333 8.00065 11.3333C6.1597 11.3333 4.66732 9.84085 4.66732 7.99992C4.66732 6.15897 6.1597 4.66658 8.00065 4.66658C8.75118 4.66658 9.44378 4.91464 10.001 5.33325H11.334V8.99992C11.334 9.55219 11.7817 9.99992 12.334 9.99992C12.8863 9.99992 13.334 9.55219 13.334 8.99992V7.99992ZM8.00065 5.99992C6.89605 5.99992 6.00065 6.89532 6.00065 7.99992C6.00065 9.10452 6.89605 9.99992 8.00065 9.99992C9.10525 9.99992 10.0007 9.10452 10.0007 7.99992C10.0007 6.89532 9.10525 5.99992 8.00065 5.99992Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.NEWS: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M10.6673 13.3335V2.66683H2.66732V12.6668C2.66732 13.035 2.9658 13.3335 3.33398 13.3335H10.6673ZM12.6673 14.6668H3.33398C2.22942 14.6668 1.33398 13.7714 1.33398 12.6668V2.00016C1.33398 1.63198 1.63246 1.3335 2.00065 1.3335H11.334C11.7022 1.3335 12.0007 1.63198 12.0007 2.00016V6.66683H14.6673V12.6668C14.6673 13.7714 13.7719 14.6668 12.6673 14.6668ZM12.0007 8.00016V12.6668C12.0007 13.035 12.2991 13.3335 12.6673 13.3335C13.0355 13.3335 13.334 13.035 13.334 12.6668V8.00016H12.0007ZM4.00065 4.00016H8.00065V8.00016H4.00065V4.00016ZM5.33398 5.3335V6.66683H6.66732V5.3335H5.33398ZM4.00065 8.66683H9.33398V10.0002H4.00065V8.66683ZM4.00065 10.6668H9.33398V12.0002H4.00065V10.6668Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.MEDICAL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.79747 1.51186L10.9641 5.26464C11.1482 5.5835 11.0389 5.99122 10.7201 6.17532L9.85373 6.67474L10.5207 7.83001L9.366 8.49668L8.699 7.34141L7.83333 7.84201C7.51447 8.02608 7.10673 7.91681 6.92267 7.59794L5.69747 5.47632C4.32922 5.89145 3.33333 7.16268 3.33333 8.66654C3.33333 9.08348 3.40987 9.48248 3.54965 9.85034C4.06613 9.52254 4.67762 9.33321 5.33333 9.33321C6.45605 9.33321 7.44913 9.88828 8.05313 10.7389L13.1787 7.78014L13.8454 8.93488L8.5932 11.9672C8.64133 12.1927 8.66667 12.4267 8.66667 12.6665C8.66667 12.895 8.64367 13.1181 8.59993 13.3337L14 13.3332V14.6665L2.66703 14.6673C2.2482 14.1101 2 13.4173 2 12.6665C2 11.9951 2.19855 11.3699 2.54014 10.8467C2.19517 10.1964 2 9.45428 2 8.66654C2 6.66968 3.25421 4.96575 5.01785 4.29953L4.75598 3.84519C4.38779 3.20747 4.60629 2.39202 5.24402 2.02382L6.97607 1.02382C7.6138 0.655637 8.42927 0.874138 8.79747 1.51186ZM5.33333 10.6665C4.22877 10.6665 3.33333 11.562 3.33333 12.6665C3.33333 12.9003 3.37343 13.1247 3.44711 13.3331H7.21953C7.29327 13.1247 7.33333 12.9003 7.33333 12.6665C7.33333 11.562 6.4379 10.6665 5.33333 10.6665ZM7.64273 2.17852L5.91068 3.17852L7.744 6.35395L9.47607 5.35395L7.64273 2.17852Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.PRODUCTIVITY: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M6.64807 11.9999H9.35062C9.43862 11.1989 9.84742 10.5376 10.5111 9.81499C10.5858 9.73365 11.0652 9.23752 11.1221 9.16665C11.6872 8.46199 11.9993 7.58992 11.9993 6.66659C11.9993 4.45745 10.2085 2.66659 7.99935 2.66659C5.79021 2.66659 3.99935 4.45745 3.99935 6.66659C3.99935 7.58945 4.31118 8.46105 4.87576 9.16552C4.93271 9.23659 5.41322 9.73405 5.48704 9.81445C6.15112 10.5375 6.56004 11.1989 6.64807 11.9999ZM9.33268 13.3333H6.66602V13.9999H9.33268V13.3333ZM3.83532 9.99939C3.10365 9.08639 2.66602 7.92759 2.66602 6.66659C2.66602 3.72107 5.05383 1.33325 7.99935 1.33325C10.9449 1.33325 13.3327 3.72107 13.3327 6.66659C13.3327 7.92825 12.8945 9.08759 12.1622 10.0009C11.7487 10.5165 10.666 11.3333 10.666 12.3333V13.9999C10.666 14.7363 10.0691 15.3333 9.33268 15.3333H6.66602C5.92964 15.3333 5.33268 14.7363 5.33268 13.9999V12.3333C5.33268 11.3333 4.24907 10.5157 3.83532 9.99939ZM8.66602 6.66979H10.3327L7.33268 10.6698V8.00312H5.66602L8.66602 3.99992V6.66979Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.EDUCATION: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M14 2.66683H4.66667C3.93029 2.66683 3.33333 3.26378 3.33333 4.00016C3.33333 4.73654 3.93029 5.3335 4.66667 5.3335H14V14.0002C14 14.3684 13.7015 14.6668 13.3333 14.6668H4.66667C3.19391 14.6668 2 13.4729 2 12.0002V4.00016C2 2.5274 3.19391 1.3335 4.66667 1.3335H13.3333C13.7015 1.3335 14 1.63198 14 2.00016V2.66683ZM3.33333 12.0002C3.33333 12.7366 3.93029 13.3335 4.66667 13.3335H12.6667V6.66683H4.66667C4.18095 6.66683 3.72557 6.53697 3.33333 6.31008V12.0002ZM13.3333 4.66683H4.66667C4.29848 4.66683 4 4.36835 4 4.00016C4 3.63198 4.29848 3.3335 4.66667 3.3335H13.3333V4.66683Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.BUSINESS: """<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 14 14" fill="none">
|
||||
<path d="M3.66732 3.33341V1.33341C3.66732 0.965228 3.9658 0.666748 4.33398 0.666748H9.66732C10.0355 0.666748 10.334 0.965228 10.334 1.33341V3.33341H13.0007C13.3689 3.33341 13.6673 3.63189 13.6673 4.00008V13.3334C13.6673 13.7016 13.3689 14.0001 13.0007 14.0001H1.00065C0.632464 14.0001 0.333984 13.7016 0.333984 13.3334V4.00008C0.333984 3.63189 0.632464 3.33341 1.00065 3.33341H3.66732ZM12.334 8.66675H1.66732V12.6667H12.334V8.66675ZM12.334 4.66675H1.66732V7.33341H3.66732V6.00008H5.00065V7.33341H9.00065V6.00008H10.334V7.33341H12.334V4.66675ZM5.00065 2.00008V3.33341H9.00065V2.00008H5.00065Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.ENTERTAINMENT: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M11.3327 2.66675C13.5418 2.66675 15.3327 4.45761 15.3327 6.66675V9.33342C15.3327 11.5425 13.5418 13.3334 11.3327 13.3334H4.66602C2.45688 13.3334 0.666016 11.5425 0.666016 9.33342V6.66675C0.666016 4.45761 2.45688 2.66675 4.66602 2.66675H11.3327ZM11.3327 4.00008H4.66602C3.23788 4.00008 2.07196 5.12273 2.00262 6.53365L1.99935 6.66675V9.33342C1.99935 10.7615 3.122 11.9275 4.53292 11.9968L4.66602 12.0001H11.3327C12.7608 12.0001 13.9267 10.8774 13.9961 9.46648L13.9993 9.33342V6.66675C13.9993 5.23861 12.8767 4.07269 11.4657 4.00335L11.3327 4.00008ZM6.66602 6.00008V7.33342H7.99935V8.66675H6.66535L6.66602 10.0001H5.33268L5.33202 8.66675H3.99935V7.33342H5.33268V6.00008H6.66602ZM11.9993 8.66675V10.0001H10.666V8.66675H11.9993ZM10.666 6.00008V7.33342H9.33268V6.00008H10.666Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.UTILITIES: """<svg xmlns="http://www.w3.org/2000/svg" width="13" height="15" viewBox="0 0 13 15" fill="none">
|
||||
<path d="M12.3346 0.333252C12.7028 0.333252 13.0013 0.631732 13.0013 0.999919V4.33325C13.0013 4.70144 12.7028 4.99992 12.3346 4.99992H9.0013V13.6666C9.0013 14.0348 8.70284 14.3333 8.33463 14.3333H5.66797C5.29978 14.3333 5.0013 14.0348 5.0013 13.6666V4.99992H1.33464C0.966449 4.99992 0.667969 4.70144 0.667969 4.33325V2.74527C0.667969 2.49276 0.810635 2.26192 1.0365 2.14899L4.66797 0.333252H12.3346ZM9.0013 1.66659H4.98273L2.0013 3.1573V3.66659H6.33464V12.9999H7.66797V3.66659H9.0013V1.66659ZM11.668 1.66659H10.3346V3.66659H11.668V1.66659Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.OTHER: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00052 0.666748L4.00065 7.33342H12.0007L8.00052 0.666748ZM8.00052 3.25828L9.64572 6.00008H6.35553L8.00052 3.25828ZM4.50065 13.3334C3.48813 13.3334 2.66732 12.5126 2.66732 11.5001C2.66732 10.4875 3.48813 9.66675 4.50065 9.66675C5.51317 9.66675 6.33398 10.4875 6.33398 11.5001C6.33398 12.5126 5.51317 13.3334 4.50065 13.3334ZM4.50065 14.6667C6.24955 14.6667 7.66732 13.249 7.66732 11.5001C7.66732 9.75115 6.24955 8.33342 4.50065 8.33342C2.75175 8.33342 1.33398 9.75115 1.33398 11.5001C1.33398 13.249 2.75175 14.6667 4.50065 14.6667ZM10.0007 10.3334V13.0001H12.6673V10.3334H10.0007ZM8.66732 14.3334V9.00008H14.0007V14.3334H8.66732Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
}
|
||||
|
||||
default_tool_label_dict = {
|
||||
ToolLabelEnum.SEARCH: ToolLabel(
|
||||
name="search", label=I18nObject(en_US="Search", zh_Hans="搜索"), icon=ICONS[ToolLabelEnum.SEARCH]
|
||||
),
|
||||
ToolLabelEnum.IMAGE: ToolLabel(
|
||||
name="image", label=I18nObject(en_US="Image", zh_Hans="图片"), icon=ICONS[ToolLabelEnum.IMAGE]
|
||||
),
|
||||
ToolLabelEnum.VIDEOS: ToolLabel(
|
||||
name="videos", label=I18nObject(en_US="Videos", zh_Hans="视频"), icon=ICONS[ToolLabelEnum.VIDEOS]
|
||||
),
|
||||
ToolLabelEnum.WEATHER: ToolLabel(
|
||||
name="weather", label=I18nObject(en_US="Weather", zh_Hans="天气"), icon=ICONS[ToolLabelEnum.WEATHER]
|
||||
),
|
||||
ToolLabelEnum.FINANCE: ToolLabel(
|
||||
name="finance", label=I18nObject(en_US="Finance", zh_Hans="金融"), icon=ICONS[ToolLabelEnum.FINANCE]
|
||||
),
|
||||
ToolLabelEnum.DESIGN: ToolLabel(
|
||||
name="design", label=I18nObject(en_US="Design", zh_Hans="设计"), icon=ICONS[ToolLabelEnum.DESIGN]
|
||||
),
|
||||
ToolLabelEnum.TRAVEL: ToolLabel(
|
||||
name="travel", label=I18nObject(en_US="Travel", zh_Hans="旅行"), icon=ICONS[ToolLabelEnum.TRAVEL]
|
||||
),
|
||||
ToolLabelEnum.SOCIAL: ToolLabel(
|
||||
name="social", label=I18nObject(en_US="Social", zh_Hans="社交"), icon=ICONS[ToolLabelEnum.SOCIAL]
|
||||
),
|
||||
ToolLabelEnum.NEWS: ToolLabel(
|
||||
name="news", label=I18nObject(en_US="News", zh_Hans="新闻"), icon=ICONS[ToolLabelEnum.NEWS]
|
||||
),
|
||||
ToolLabelEnum.MEDICAL: ToolLabel(
|
||||
name="medical", label=I18nObject(en_US="Medical", zh_Hans="医疗"), icon=ICONS[ToolLabelEnum.MEDICAL]
|
||||
),
|
||||
ToolLabelEnum.PRODUCTIVITY: ToolLabel(
|
||||
name="productivity",
|
||||
label=I18nObject(en_US="Productivity", zh_Hans="生产力"),
|
||||
icon=ICONS[ToolLabelEnum.PRODUCTIVITY],
|
||||
),
|
||||
ToolLabelEnum.EDUCATION: ToolLabel(
|
||||
name="education", label=I18nObject(en_US="Education", zh_Hans="教育"), icon=ICONS[ToolLabelEnum.EDUCATION]
|
||||
),
|
||||
ToolLabelEnum.BUSINESS: ToolLabel(
|
||||
name="business", label=I18nObject(en_US="Business", zh_Hans="商业"), icon=ICONS[ToolLabelEnum.BUSINESS]
|
||||
),
|
||||
ToolLabelEnum.ENTERTAINMENT: ToolLabel(
|
||||
name="entertainment",
|
||||
label=I18nObject(en_US="Entertainment", zh_Hans="娱乐"),
|
||||
icon=ICONS[ToolLabelEnum.ENTERTAINMENT],
|
||||
),
|
||||
ToolLabelEnum.UTILITIES: ToolLabel(
|
||||
name="utilities", label=I18nObject(en_US="Utilities", zh_Hans="工具"), icon=ICONS[ToolLabelEnum.UTILITIES]
|
||||
),
|
||||
ToolLabelEnum.OTHER: ToolLabel(
|
||||
name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER]
|
||||
),
|
||||
}
|
||||
|
||||
default_tool_labels = [v for k, v in default_tool_label_dict.items()]
|
||||
default_tool_label_name_list = [label.name for label in default_tool_labels]
|
|
@ -0,0 +1,37 @@
|
|||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
|
||||
|
||||
class ToolProviderNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolParameterValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolProviderCredentialValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolNotSupportedError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolInvokeError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolApiSchemaError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolEngineInvokeError(Exception):
|
||||
meta: ToolInvokeMeta
|
||||
|
||||
def __init__(self, meta, **kwargs):
|
||||
self.meta = meta
|
||||
super().__init__(**kwargs)
|
|
@ -0,0 +1,79 @@
|
|||
from typing import Any
|
||||
|
||||
from core.plugin.manager.tool import PluginToolManager
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
|
||||
|
||||
class PluginToolProviderController(BuiltinToolProviderController):
|
||||
entity: ToolProviderEntityWithPlugin
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: ToolProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
self.entity = entity
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.PLUGIN
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
manager = PluginToolManager()
|
||||
if not manager.validate_provider_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
provider=self.entity.identity.name,
|
||||
credentials=credentials,
|
||||
):
|
||||
raise ToolProviderCredentialValidationError("Invalid credentials")
|
||||
|
||||
def get_tool(self, tool_name: str) -> PluginTool: # type: ignore
|
||||
"""
|
||||
return tool with given name
|
||||
"""
|
||||
tool_entity = next(
|
||||
(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None
|
||||
)
|
||||
|
||||
if not tool_entity:
|
||||
raise ValueError(f"Tool with name {tool_name} not found")
|
||||
|
||||
return PluginTool(
|
||||
entity=tool_entity,
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
def get_tools(self) -> list[PluginTool]: # type: ignore
|
||||
"""
|
||||
get all tools
|
||||
"""
|
||||
return [
|
||||
PluginTool(
|
||||
entity=tool_entity,
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
for tool_entity in self.entity.tools
|
||||
]
|
|
@ -0,0 +1,89 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.plugin.manager.tool import PluginToolManager
|
||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
|
||||
|
||||
class PluginTool(Tool):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
runtime_parameters: Optional[list[ToolParameter]]
|
||||
|
||||
def __init__(
|
||||
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
self.runtime_parameters = None
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.PLUGIN
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
manager = PluginToolManager()
|
||||
|
||||
tool_parameters = convert_parameters_to_plugin_format(tool_parameters)
|
||||
|
||||
yield from manager.invoke(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
tool_provider=self.entity.identity.provider,
|
||||
tool_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
tool_parameters=tool_parameters,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool":
|
||||
return PluginTool(
|
||||
entity=self.entity,
|
||||
runtime=runtime,
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
"""
|
||||
get the runtime parameters
|
||||
"""
|
||||
if not self.entity.has_runtime_parameters:
|
||||
return self.entity.parameters
|
||||
|
||||
if self.runtime_parameters is not None:
|
||||
return self.runtime_parameters
|
||||
|
||||
manager = PluginToolManager()
|
||||
self.runtime_parameters = manager.get_runtime_parameters(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="",
|
||||
provider=self.entity.identity.provider,
|
||||
tool=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
return self.runtime_parameters
|
|
@ -0,0 +1,357 @@
|
|||
import json
|
||||
from collections.abc import Generator, Iterable
|
||||
from copy import deepcopy
|
||||
from datetime import UTC, datetime
|
||||
from mimetypes import guess_type
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file import FileType
|
||||
from core.file.models import FileTransferMethod
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolInvokeMessage,
|
||||
ToolInvokeMessageBinary,
|
||||
ToolInvokeMeta,
|
||||
ToolParameter,
|
||||
)
|
||||
from core.tools.errors import (
|
||||
ToolEngineInvokeError,
|
||||
ToolInvokeError,
|
||||
ToolNotFoundError,
|
||||
ToolNotSupportedError,
|
||||
ToolParameterValidationError,
|
||||
ToolProviderCredentialValidationError,
|
||||
ToolProviderNotFoundError,
|
||||
)
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatedByRole
|
||||
from models.model import Message, MessageFile
|
||||
|
||||
|
||||
class ToolEngine:
|
||||
"""
|
||||
Tool runtime engine take care of the tool executions.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def agent_invoke(
|
||||
tool: Tool,
|
||||
tool_parameters: Union[str, dict],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
message: Message,
|
||||
invoke_from: InvokeFrom,
|
||||
agent_tool_callback: DifyAgentCallbackHandler,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> tuple[str, list[str], ToolInvokeMeta]:
|
||||
"""
|
||||
Agent invokes the tool with the given arguments.
|
||||
"""
|
||||
# check if arguments is a string
|
||||
if isinstance(tool_parameters, str):
|
||||
# check if this tool has only one parameter
|
||||
parameters = [
|
||||
parameter
|
||||
for parameter in tool.get_runtime_parameters()
|
||||
if parameter.form == ToolParameter.ToolParameterForm.LLM
|
||||
]
|
||||
if parameters and len(parameters) == 1:
|
||||
tool_parameters = {parameters[0].name: tool_parameters}
|
||||
else:
|
||||
try:
|
||||
tool_parameters = json.loads(tool_parameters)
|
||||
except Exception:
|
||||
pass
|
||||
if not isinstance(tool_parameters, dict):
|
||||
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
|
||||
|
||||
try:
|
||||
# hit the callback handler
|
||||
agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
|
||||
|
||||
messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id)
|
||||
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
|
||||
|
||||
def message_callback(
|
||||
invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]
|
||||
):
|
||||
for message in messages:
|
||||
if isinstance(message, ToolInvokeMeta):
|
||||
invocation_meta_dict["meta"] = message
|
||||
else:
|
||||
yield message
|
||||
|
||||
messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=message_callback(invocation_meta_dict, messages),
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=message.conversation_id,
|
||||
)
|
||||
|
||||
message_list = list(messages)
|
||||
|
||||
# extract binary data from tool invoke message
|
||||
binary_files = ToolEngine._extract_tool_response_binary_and_text(message_list)
|
||||
# create message file
|
||||
message_files = ToolEngine._create_message_files(
|
||||
tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id
|
||||
)
|
||||
|
||||
plain_text = ToolEngine._convert_tool_response_to_str(message_list)
|
||||
|
||||
meta = invocation_meta_dict["meta"]
|
||||
|
||||
# hit the callback handler
|
||||
agent_tool_callback.on_tool_end(
|
||||
tool_name=tool.entity.identity.name,
|
||||
tool_inputs=tool_parameters,
|
||||
tool_outputs=plain_text,
|
||||
message_id=message.id,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
# transform tool invoke message to get LLM friendly message
|
||||
return plain_text, message_files, meta
|
||||
except ToolProviderCredentialValidationError as e:
|
||||
error_response = "Please check your tool provider credentials"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e:
|
||||
error_response = f"there is not a tool named {tool.entity.identity.name}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except ToolParameterValidationError as e:
|
||||
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except ToolInvokeError as e:
|
||||
error_response = f"tool invoke error: {e}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except ToolEngineInvokeError as e:
|
||||
meta = e.meta
|
||||
error_response = f"tool invoke error: {meta.error}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
return error_response, [], meta
|
||||
except Exception as e:
|
||||
error_response = f"unknown error: {e}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
|
||||
return error_response, [], ToolInvokeMeta.error_instance(error_response)
|
||||
|
||||
@staticmethod
|
||||
def x(
|
||||
tool: Tool,
|
||||
tool_parameters: dict[str, Any],
|
||||
user_id: str,
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||
workflow_call_depth: int,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Workflow invokes the tool with the given arguments.
|
||||
"""
|
||||
try:
|
||||
# hit the callback handler
|
||||
workflow_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
|
||||
|
||||
if isinstance(tool, WorkflowTool):
|
||||
tool.workflow_call_depth = workflow_call_depth + 1
|
||||
tool.thread_pool_id = thread_pool_id
|
||||
|
||||
if tool.runtime and tool.runtime.runtime_parameters:
|
||||
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
|
||||
|
||||
response = tool.invoke(
|
||||
user_id=user_id,
|
||||
tool_parameters=tool_parameters,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
# hit the callback handler
|
||||
response = workflow_tool_callback.on_tool_execution(
|
||||
tool_name=tool.entity.identity.name,
|
||||
tool_inputs=tool_parameters,
|
||||
tool_outputs=response,
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
workflow_tool_callback.on_tool_error(e)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _invoke(
|
||||
tool: Tool,
|
||||
tool_parameters: dict,
|
||||
user_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
|
||||
"""
|
||||
Invoke the tool with the given arguments.
|
||||
"""
|
||||
started_at = datetime.now(UTC)
|
||||
meta = ToolInvokeMeta(
|
||||
time_cost=0.0,
|
||||
error=None,
|
||||
tool_config={
|
||||
"tool_name": tool.entity.identity.name,
|
||||
"tool_provider": tool.entity.identity.provider,
|
||||
"tool_provider_type": tool.tool_provider_type().value,
|
||||
"tool_parameters": deepcopy(tool.runtime.runtime_parameters),
|
||||
"tool_icon": tool.entity.identity.icon,
|
||||
},
|
||||
)
|
||||
try:
|
||||
yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id)
|
||||
except Exception as e:
|
||||
meta.error = str(e)
|
||||
raise ToolEngineInvokeError(meta)
|
||||
finally:
|
||||
ended_at = datetime.now(UTC)
|
||||
meta.time_cost = (ended_at - started_at).total_seconds()
|
||||
yield meta
|
||||
|
||||
@staticmethod
|
||||
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
|
||||
"""
|
||||
Handle tool response
|
||||
"""
|
||||
result = ""
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
result += cast(ToolInvokeMessage.TextMessage, response.message).text
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
result += (
|
||||
f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}."
|
||||
+ " please tell user to check it."
|
||||
)
|
||||
elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||
result += (
|
||||
"image has been created and sent to user already, "
|
||||
+ "you do not need to create it, just tell the user to check it now."
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
result = json.dumps(
|
||||
cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
|
||||
)
|
||||
else:
|
||||
result += str(response.message)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_response_binary_and_text(
|
||||
tool_response: list[ToolInvokeMessage],
|
||||
) -> Generator[ToolInvokeMessageBinary, None, None]:
|
||||
"""
|
||||
Extract tool response binary
|
||||
"""
|
||||
for response in tool_response:
|
||||
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||
mimetype = None
|
||||
if not response.meta:
|
||||
raise ValueError("missing meta data")
|
||||
if response.meta.get("mime_type"):
|
||||
mimetype = response.meta.get("mime_type")
|
||||
else:
|
||||
try:
|
||||
url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
|
||||
extension = url.suffix
|
||||
guess_type_result, _ = guess_type(f"a{extension}")
|
||||
if guess_type_result:
|
||||
mimetype = guess_type_result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not mimetype:
|
||||
mimetype = "image/jpeg"
|
||||
|
||||
yield ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get("mime_type", "image/jpeg"),
|
||||
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
if not response.meta:
|
||||
raise ValueError("missing meta data")
|
||||
|
||||
yield ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get("mime_type", "application/octet-stream"),
|
||||
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
# check if there is a mime type in meta
|
||||
if response.meta and "mime_type" in response.meta:
|
||||
yield ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get("mime_type", "application/octet-stream")
|
||||
if response.meta
|
||||
else "application/octet-stream",
|
||||
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_message_files(
|
||||
tool_messages: Iterable[ToolInvokeMessageBinary],
|
||||
agent_message: Message,
|
||||
invoke_from: InvokeFrom,
|
||||
user_id: str,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Create message file
|
||||
|
||||
:return: message file ids
|
||||
"""
|
||||
result = []
|
||||
|
||||
for message in tool_messages:
|
||||
if "image" in message.mimetype:
|
||||
file_type = FileType.IMAGE
|
||||
elif "video" in message.mimetype:
|
||||
file_type = FileType.VIDEO
|
||||
elif "audio" in message.mimetype:
|
||||
file_type = FileType.AUDIO
|
||||
elif "text" in message.mimetype or "pdf" in message.mimetype:
|
||||
file_type = FileType.DOCUMENT
|
||||
else:
|
||||
file_type = FileType.CUSTOM
|
||||
|
||||
# extract tool file id from url
|
||||
tool_file_id = message.url.split("/")[-1].split(".")[0]
|
||||
message_file = MessageFile(
|
||||
message_id=agent_message.id,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
belongs_to="assistant",
|
||||
url=message.url,
|
||||
upload_file_id=tool_file_id,
|
||||
created_by_role=(
|
||||
CreatedByRole.ACCOUNT
|
||||
if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||
else CreatedByRole.END_USER
|
||||
),
|
||||
created_by=user_id,
|
||||
)
|
||||
|
||||
db.session.add(message_file)
|
||||
db.session.commit()
|
||||
db.session.refresh(message_file)
|
||||
|
||||
result.append(message_file.id)
|
||||
|
||||
db.session.close()
|
||||
|
||||
return result
|
|
@ -0,0 +1,234 @@
|
|||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from mimetypes import guess_extension, guess_type
|
||||
from typing import Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import MessageFile
|
||||
from models.tools import ToolFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolFileManager:
|
||||
@staticmethod
|
||||
def sign_file(tool_file_id: str, extension: str) -> str:
|
||||
"""
|
||||
sign file to get a temporary url
|
||||
"""
|
||||
base_url = dify_config.FILES_URL
|
||||
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
|
||||
@staticmethod
|
||||
def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
"""
|
||||
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_raw(
|
||||
*,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: Optional[str],
|
||||
file_binary: bytes,
|
||||
mimetype: str,
|
||||
filename: Optional[str] = None,
|
||||
) -> ToolFile:
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
unique_name = uuid4().hex
|
||||
unique_filename = f"{unique_name}{extension}"
|
||||
# default just as before
|
||||
present_filename = unique_filename
|
||||
if filename is not None:
|
||||
has_extension = len(filename.split(".")) > 1
|
||||
# Add extension flexibly
|
||||
present_filename = filename if has_extension else f"{filename}{extension}"
|
||||
filepath = f"tools/{tenant_id}/{unique_filename}"
|
||||
storage.save(filepath, file_binary)
|
||||
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_key=filepath,
|
||||
mimetype=mimetype,
|
||||
name=present_filename,
|
||||
size=len(file_binary),
|
||||
)
|
||||
|
||||
db.session.add(tool_file)
|
||||
db.session.commit()
|
||||
db.session.refresh(tool_file)
|
||||
|
||||
return tool_file
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_url(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
file_url: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> ToolFile:
|
||||
# try to download image
|
||||
try:
|
||||
response = ssrf_proxy.get(file_url)
|
||||
response.raise_for_status()
|
||||
blob = response.content
|
||||
except httpx.TimeoutException:
|
||||
raise ValueError(f"timeout when downloading file from {file_url}")
|
||||
|
||||
mimetype = (
|
||||
guess_type(file_url)[0]
|
||||
or response.headers.get("Content-Type", "").split(";")[0].strip()
|
||||
or "application/octet-stream"
|
||||
)
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
unique_name = uuid4().hex
|
||||
filename = f"{unique_name}{extension}"
|
||||
filepath = f"tools/{tenant_id}/{filename}"
|
||||
storage.save(filepath, blob)
|
||||
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_key=filepath,
|
||||
mimetype=mimetype,
|
||||
original_url=file_url,
|
||||
name=filename,
|
||||
size=len(blob),
|
||||
)
|
||||
|
||||
db.session.add(tool_file)
|
||||
db.session.commit()
|
||||
|
||||
return tool_file
|
||||
|
||||
@staticmethod
|
||||
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
:param id: the id of the file
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
tool_file: ToolFile | None = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
|
||||
blob = storage.load_once(tool_file.file_key)
|
||||
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
@staticmethod
|
||||
def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
:param id: the id of the file
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
message_file: MessageFile | None = (
|
||||
db.session.query(MessageFile)
|
||||
.filter(
|
||||
MessageFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Check if message_file is not None
|
||||
if message_file is not None:
|
||||
# get tool file id
|
||||
if message_file.url is not None:
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split(".")[0]
|
||||
else:
|
||||
tool_file_id = None
|
||||
else:
|
||||
tool_file_id = None
|
||||
|
||||
tool_file: ToolFile | None = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
|
||||
blob = storage.load_once(tool_file.file_key)
|
||||
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
@staticmethod
|
||||
def get_file_generator_by_tool_file_id(tool_file_id: str):
|
||||
"""
|
||||
get file binary
|
||||
|
||||
:param tool_file_id: the id of the tool file
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
tool_file: ToolFile | None = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None, None
|
||||
|
||||
stream = storage.load_stream(tool_file.file_key)
|
||||
|
||||
return stream, tool_file
|
||||
|
||||
|
||||
# init tool_file_parser
|
||||
from core.file.tool_file_parser import tool_file_manager
|
||||
|
||||
tool_file_manager["manager"] = ToolFileManager
|
|
@ -0,0 +1,101 @@
|
|||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.entities.values import default_tool_label_name_list
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ToolLabelBinding
|
||||
|
||||
|
||||
class ToolLabelManager:
|
||||
@classmethod
|
||||
def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
|
||||
"""
|
||||
Filter tool labels
|
||||
"""
|
||||
tool_labels = [label for label in tool_labels if label in default_tool_label_name_list]
|
||||
return list(set(tool_labels))
|
||||
|
||||
@classmethod
|
||||
def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
|
||||
"""
|
||||
Update tool labels
|
||||
"""
|
||||
labels = cls.filter_tool_labels(labels)
|
||||
|
||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
provider_id = controller.provider_id
|
||||
else:
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
# delete old labels
|
||||
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete()
|
||||
|
||||
# insert new labels
|
||||
for label in labels:
|
||||
db.session.add(
|
||||
ToolLabelBinding(
|
||||
tool_id=provider_id,
|
||||
tool_type=controller.provider_type.value,
|
||||
label_name=label,
|
||||
)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
|
||||
"""
|
||||
Get tool labels
|
||||
"""
|
||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
provider_id = controller.provider_id
|
||||
elif isinstance(controller, BuiltinToolProviderController):
|
||||
return controller.tool_labels
|
||||
else:
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
labels = (
|
||||
db.session.query(ToolLabelBinding.label_name)
|
||||
.filter(
|
||||
ToolLabelBinding.tool_id == provider_id,
|
||||
ToolLabelBinding.tool_type == controller.provider_type.value,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [label.label_name for label in labels]
|
||||
|
||||
@classmethod
|
||||
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
|
||||
"""
|
||||
Get tools labels
|
||||
|
||||
:param tool_providers: list of tool providers
|
||||
|
||||
:return: dict of tool labels
|
||||
:key: tool id
|
||||
:value: list of tool labels
|
||||
"""
|
||||
if not tool_providers:
|
||||
return {}
|
||||
|
||||
for controller in tool_providers:
|
||||
if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
provider_ids = []
|
||||
for controller in tool_providers:
|
||||
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
|
||||
provider_ids.append(controller.provider_id)
|
||||
|
||||
labels: list[ToolLabelBinding] = (
|
||||
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
|
||||
)
|
||||
|
||||
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
|
||||
|
||||
for label in labels:
|
||||
tool_labels[label.tool_id].append(label.label_name)
|
||||
|
||||
return tool_labels
|
|
@ -0,0 +1,870 @@
|
|||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
from collections.abc import Generator
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Union, cast
|
||||
|
||||
from yarl import URL
|
||||
|
||||
import contexts
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.manager.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
|
||||
|
||||
from configs import dify_config
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.custom_tool.tool import ApiTool
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import (
|
||||
ProviderConfigEncrypter,
|
||||
ToolParameterConfigurationManager,
|
||||
)
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
|
||||
_builtin_providers_loaded = False
|
||||
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
|
||||
|
||||
@classmethod
|
||||
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
|
||||
"""
|
||||
get the hardcoded provider
|
||||
"""
|
||||
if len(cls._hardcoded_providers) == 0:
|
||||
# init the builtin providers
|
||||
cls.load_hardcoded_providers_cache()
|
||||
|
||||
return cls._hardcoded_providers[provider]
|
||||
|
||||
@classmethod
|
||||
def get_builtin_provider(
|
||||
cls, provider: str, tenant_id: str
|
||||
) -> BuiltinToolProviderController | PluginToolProviderController:
|
||||
"""
|
||||
get the builtin provider
|
||||
|
||||
:param provider: the name of the provider
|
||||
:param tenant_id: the id of the tenant
|
||||
:return: the provider
|
||||
"""
|
||||
# split provider to
|
||||
|
||||
if len(cls._hardcoded_providers) == 0:
|
||||
# init the builtin providers
|
||||
cls.load_hardcoded_providers_cache()
|
||||
|
||||
if provider not in cls._hardcoded_providers:
|
||||
# get plugin provider
|
||||
plugin_provider = cls.get_plugin_provider(provider, tenant_id)
|
||||
if plugin_provider:
|
||||
return plugin_provider
|
||||
|
||||
return cls._hardcoded_providers[provider]
|
||||
|
||||
@classmethod
|
||||
def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
|
||||
"""
|
||||
get the plugin provider
|
||||
"""
|
||||
# check if context is set
|
||||
try:
|
||||
contexts.plugin_tool_providers.get()
|
||||
except LookupError:
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(Lock())
|
||||
|
||||
with contexts.plugin_tool_providers_lock.get():
|
||||
plugin_tool_providers = contexts.plugin_tool_providers.get()
|
||||
if provider in plugin_tool_providers:
|
||||
return plugin_tool_providers[provider]
|
||||
|
||||
manager = PluginToolManager()
|
||||
provider_entity = manager.fetch_tool_provider(tenant_id, provider)
|
||||
if not provider_entity:
|
||||
raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
|
||||
|
||||
controller = PluginToolProviderController(
|
||||
entity=provider_entity.declaration,
|
||||
plugin_id=provider_entity.plugin_id,
|
||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
plugin_tool_providers[provider] = controller
|
||||
|
||||
return controller
|
||||
|
||||
@classmethod
|
||||
def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
|
||||
"""
|
||||
get the builtin tool
|
||||
|
||||
:param provider: the name of the provider
|
||||
:param tool_name: the name of the tool
|
||||
:param tenant_id: the id of the tenant
|
||||
:return: the provider, the tool
|
||||
"""
|
||||
provider_controller = cls.get_builtin_provider(provider, tenant_id)
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
if tool is None:
|
||||
raise ToolNotFoundError(f"tool {tool_name} not found")
|
||||
|
||||
return tool
|
||||
|
||||
@classmethod
|
||||
def get_tool_runtime(
|
||||
cls,
|
||||
provider_type: ToolProviderType,
|
||||
provider_id: str,
|
||||
tool_name: str,
|
||||
tenant_id: str,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
|
||||
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]:
|
||||
"""
|
||||
get the tool runtime
|
||||
|
||||
:param provider_type: the type of the provider
|
||||
:param provider_id: the id of the provider
|
||||
:param tool_name: the name of the tool
|
||||
:param tenant_id: the tenant id
|
||||
:param invoke_from: invoke from
|
||||
:param tool_invoke_from: the tool invoke from
|
||||
|
||||
:return: the tool
|
||||
"""
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
# check if the builtin tool need credentials
|
||||
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
|
||||
|
||||
builtin_tool = provider_controller.get_tool(tool_name)
|
||||
if not builtin_tool:
|
||||
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
|
||||
|
||||
if not provider_controller.need_credentials:
|
||||
return cast(
|
||||
BuiltinTool,
|
||||
builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
if isinstance(provider_controller, PluginToolProviderController):
|
||||
provider_id_entity = ToolProviderID(provider_id)
|
||||
# get credentials
|
||||
builtin_provider: BuiltinToolProvider | None = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
||||
else:
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
|
||||
.first()
|
||||
)
|
||||
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
||||
|
||||
# decrypt the credentials
|
||||
credentials = builtin_provider.credentials
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
||||
|
||||
return cast(
|
||||
BuiltinTool,
|
||||
builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials=decrypted_credentials,
|
||||
runtime_parameters={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
elif provider_type == ToolProviderType.API:
|
||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
||||
|
||||
# decrypt the credentials
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
|
||||
provider_type=api_provider.provider_type.value,
|
||||
provider_identity=api_provider.entity.identity.name,
|
||||
)
|
||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
||||
|
||||
return cast(
|
||||
ApiTool,
|
||||
api_provider.get_tool(tool_name).fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials=decrypted_credentials,
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
),
|
||||
)
|
||||
elif provider_type == ToolProviderType.WORKFLOW:
|
||||
workflow_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if workflow_provider is None:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
|
||||
controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
|
||||
if controller_tools is None or len(controller_tools) == 0:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
return cast(
|
||||
WorkflowTool,
|
||||
controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
),
|
||||
)
|
||||
elif provider_type == ToolProviderType.APP:
|
||||
raise NotImplementedError("app provider not implemented")
|
||||
elif provider_type == ToolProviderType.PLUGIN:
|
||||
return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
|
||||
else:
|
||||
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
|
||||
|
||||
@classmethod
|
||||
def get_agent_tool_runtime(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
agent_tool: AgentToolEntity,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
) -> Tool:
|
||||
"""
|
||||
get the agent tool runtime
|
||||
"""
|
||||
tool_entity = cls.get_tool_runtime(
|
||||
provider_type=agent_tool.provider_type,
|
||||
provider_id=agent_tool.provider_id,
|
||||
tool_name=agent_tool.tool_name,
|
||||
tenant_id=tenant_id,
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
# check file types
|
||||
if (
|
||||
parameter.type
|
||||
in {
|
||||
ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||
ToolParameter.ToolParameterType.FILE,
|
||||
ToolParameter.ToolParameterType.FILES,
|
||||
}
|
||||
and parameter.required
|
||||
):
|
||||
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
|
||||
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
# save tool parameter to tool entity memory
|
||||
value = parameter.init_frontend_parameter(agent_tool.tool_parameters.get(parameter.name))
|
||||
runtime_parameters[parameter.name] = value
|
||||
|
||||
# decrypt runtime parameters
|
||||
encryption_manager = ToolParameterConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
tool_runtime=tool_entity,
|
||||
provider_name=agent_tool.provider_id,
|
||||
provider_type=agent_tool.provider_type,
|
||||
identity_id=f"AGENT.{app_id}",
|
||||
)
|
||||
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
||||
if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
|
||||
raise ValueError("runtime not found or runtime parameters not found")
|
||||
|
||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_entity
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_runtime(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
workflow_tool: "ToolEntity",
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
) -> Tool:
|
||||
"""
|
||||
get the workflow tool runtime
|
||||
"""
|
||||
tool_runtime = cls.get_tool_runtime(
|
||||
provider_type=workflow_tool.provider_type,
|
||||
provider_id=workflow_tool.provider_id,
|
||||
tool_name=workflow_tool.tool_name,
|
||||
tenant_id=tenant_id,
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_runtime.get_merged_runtime_parameters()
|
||||
|
||||
for parameter in parameters:
|
||||
# save tool parameter to tool entity memory
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
value = parameter.init_frontend_parameter(workflow_tool.tool_configurations.get(parameter.name))
|
||||
runtime_parameters[parameter.name] = value
|
||||
|
||||
# decrypt runtime parameters
|
||||
encryption_manager = ToolParameterConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=workflow_tool.provider_id,
|
||||
provider_type=workflow_tool.provider_type,
|
||||
identity_id=f"WORKFLOW.{app_id}.{node_id}",
|
||||
)
|
||||
|
||||
if runtime_parameters:
|
||||
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
||||
|
||||
tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_runtime
|
||||
|
||||
@classmethod
|
||||
def get_tool_runtime_from_plugin(
|
||||
cls,
|
||||
tool_type: ToolProviderType,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
tool_name: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Tool:
|
||||
"""
|
||||
get tool runtime from plugin
|
||||
"""
|
||||
tool_entity = cls.get_tool_runtime(
|
||||
provider_type=tool_type,
|
||||
provider_id=provider,
|
||||
tool_name=tool_name,
|
||||
tenant_id=tenant_id,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
tool_invoke_from=ToolInvokeFrom.PLUGIN,
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
# save tool parameter to tool entity memory
|
||||
value = parameter.init_frontend_parameter(tool_parameters.get(parameter.name))
|
||||
runtime_parameters[parameter.name] = value
|
||||
|
||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_entity
|
||||
|
||||
@classmethod
|
||||
def get_hardcoded_provider_icon(cls, provider: str) -> tuple[str, str]:
|
||||
"""
|
||||
get the absolute path of the icon of the hardcoded provider
|
||||
|
||||
:param provider: the name of the provider
|
||||
:return: the absolute path of the icon, the mime type of the icon
|
||||
"""
|
||||
# get provider
|
||||
provider_controller = cls.get_hardcoded_provider(provider)
|
||||
|
||||
absolute_path = path.join(
|
||||
path.dirname(path.realpath(__file__)),
|
||||
"builtin_tool",
|
||||
"providers",
|
||||
provider,
|
||||
"_assets",
|
||||
provider_controller.entity.identity.icon,
|
||||
)
|
||||
# check if the icon exists
|
||||
if not path.exists(absolute_path):
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")
|
||||
|
||||
# get the mime type
|
||||
mime_type, _ = mimetypes.guess_type(absolute_path)
|
||||
mime_type = mime_type or "application/octet-stream"
|
||||
|
||||
return absolute_path, mime_type
|
||||
|
||||
@classmethod
|
||||
def list_hardcoded_providers(cls):
|
||||
# use cache first
|
||||
if cls._builtin_providers_loaded:
|
||||
yield from list(cls._hardcoded_providers.values())
|
||||
return
|
||||
|
||||
with cls._builtin_provider_lock:
|
||||
if cls._builtin_providers_loaded:
|
||||
yield from list(cls._hardcoded_providers.values())
|
||||
return
|
||||
|
||||
yield from cls._list_hardcoded_providers()
|
||||
|
||||
@classmethod
|
||||
def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
|
||||
"""
|
||||
list all the plugin providers
|
||||
"""
|
||||
manager = PluginToolManager()
|
||||
provider_entities = manager.fetch_tool_providers(tenant_id)
|
||||
return [
|
||||
PluginToolProviderController(
|
||||
entity=provider.declaration,
|
||||
plugin_id=provider.plugin_id,
|
||||
plugin_unique_identifier=provider.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for provider in provider_entities
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def list_builtin_providers(
|
||||
cls, tenant_id: str
|
||||
) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
|
||||
"""
|
||||
list all the builtin providers
|
||||
"""
|
||||
yield from cls.list_hardcoded_providers()
|
||||
# get plugin providers
|
||||
yield from cls.list_plugin_providers(tenant_id)
|
||||
|
||||
@classmethod
|
||||
def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
|
||||
"""
|
||||
list all the builtin providers
|
||||
"""
|
||||
for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")):
|
||||
if provider_path.startswith("__"):
|
||||
continue
|
||||
|
||||
if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)):
|
||||
if provider_path.startswith("__"):
|
||||
continue
|
||||
|
||||
# init provider
|
||||
try:
|
||||
provider_class = load_single_subclass_from_source(
|
||||
module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
|
||||
script_path=path.join(
|
||||
path.dirname(path.realpath(__file__)),
|
||||
"builtin_tool",
|
||||
"providers",
|
||||
provider_path,
|
||||
f"{provider_path}.py",
|
||||
),
|
||||
parent_type=BuiltinToolProviderController,
|
||||
)
|
||||
provider: BuiltinToolProviderController = provider_class()
|
||||
cls._hardcoded_providers[provider.entity.identity.name] = provider
|
||||
for tool in provider.get_tools():
|
||||
cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
|
||||
yield provider
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"load builtin provider {provider}")
|
||||
continue
|
||||
# set builtin providers loaded
|
||||
cls._builtin_providers_loaded = True
|
||||
|
||||
@classmethod
|
||||
def load_hardcoded_providers_cache(cls):
|
||||
for _ in cls.list_hardcoded_providers():
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def clear_hardcoded_providers_cache(cls):
|
||||
cls._hardcoded_providers = {}
|
||||
cls._builtin_providers_loaded = False
|
||||
|
||||
@classmethod
|
||||
def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
|
||||
"""
|
||||
get the tool label
|
||||
|
||||
:param tool_name: the name of the tool
|
||||
|
||||
:return: the label of the tool
|
||||
"""
|
||||
if len(cls._builtin_tools_labels) == 0:
|
||||
# init the builtin providers
|
||||
cls.load_hardcoded_providers_cache()
|
||||
|
||||
if tool_name not in cls._builtin_tools_labels:
|
||||
return None
|
||||
|
||||
return cls._builtin_tools_labels[tool_name]
|
||||
|
||||
@classmethod
|
||||
def list_providers_from_api(
|
||||
cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
|
||||
) -> list[ToolProviderApiEntity]:
|
||||
result_providers: dict[str, ToolProviderApiEntity] = {}
|
||||
|
||||
filters = []
|
||||
if not typ:
|
||||
filters.extend(["builtin", "api", "workflow"])
|
||||
else:
|
||||
filters.append(typ)
|
||||
|
||||
with db.session.no_autoflush:
|
||||
if "builtin" in filters:
|
||||
# get builtin providers
|
||||
builtin_providers = cls.list_builtin_providers(tenant_id)
|
||||
|
||||
# get db builtin providers
|
||||
db_builtin_providers: list[BuiltinToolProvider] = (
|
||||
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
# rewrite db_builtin_providers
|
||||
for db_provider in db_builtin_providers:
|
||||
tool_provider_id = str(ToolProviderID(db_provider.provider))
|
||||
db_provider.provider = tool_provider_id
|
||||
|
||||
def find_db_builtin_provider(provider):
|
||||
return next((x for x in db_builtin_providers if x.provider == provider), None)
|
||||
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
|
||||
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
|
||||
data=provider,
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
continue
|
||||
|
||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider,
|
||||
db_provider=find_db_builtin_provider(provider.entity.identity.name),
|
||||
decrypt_credentials=False,
|
||||
)
|
||||
|
||||
if isinstance(provider, PluginToolProviderController):
|
||||
result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
|
||||
else:
|
||||
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
# get db api providers
|
||||
|
||||
if "api" in filters:
|
||||
db_api_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
api_provider_controllers: list[dict[str, Any]] = [
|
||||
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
|
||||
for provider in db_api_providers
|
||||
]
|
||||
|
||||
# get labels
|
||||
labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
|
||||
|
||||
for api_provider_controller in api_provider_controllers:
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller=api_provider_controller["controller"],
|
||||
db_provider=api_provider_controller["provider"],
|
||||
decrypt_credentials=False,
|
||||
labels=labels.get(api_provider_controller["controller"].provider_id, []),
|
||||
)
|
||||
result_providers[f"api_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
if "workflow" in filters:
|
||||
# get workflow providers
|
||||
workflow_providers: list[WorkflowToolProvider] = (
|
||||
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
||||
for provider in workflow_providers:
|
||||
try:
|
||||
workflow_provider_controllers.append(
|
||||
ToolTransformService.workflow_provider_to_controller(db_provider=provider)
|
||||
)
|
||||
except Exception:
|
||||
# app has been deleted
|
||||
pass
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels(
|
||||
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
|
||||
)
|
||||
|
||||
for provider_controller in workflow_provider_controllers:
|
||||
user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
labels=labels.get(provider_controller.provider_id, []),
|
||||
)
|
||||
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
return BuiltinToolProviderSort.sort(list(result_providers.values()))
|
||||
|
||||
@classmethod
|
||||
def get_api_provider_controller(
|
||||
cls, tenant_id: str, provider_id: str
|
||||
) -> tuple[ApiToolProviderController, dict[str, Any]]:
|
||||
"""
|
||||
get the api provider
|
||||
|
||||
:param tenant_id: the id of the tenant
|
||||
:param provider_id: the id of the provider
|
||||
|
||||
:return: the provider controller, the credentials
|
||||
"""
|
||||
provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.id == provider_id,
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
|
||||
|
||||
controller = ApiToolProviderController.from_db(
|
||||
provider,
|
||||
ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
|
||||
)
|
||||
controller.load_bundled_tools(provider.tools)
|
||||
|
||||
return controller, provider.credentials
|
||||
|
||||
@classmethod
|
||||
def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
|
||||
"""
|
||||
get api provider
|
||||
"""
|
||||
"""
|
||||
get tool provider
|
||||
"""
|
||||
provider_name = provider
|
||||
provider_obj: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider_obj is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
try:
|
||||
credentials = json.loads(provider_obj.credentials_str) or {}
|
||||
except Exception:
|
||||
credentials = {}
|
||||
|
||||
# package tool provider controller
|
||||
controller = ApiToolProviderController.from_db(
|
||||
provider_obj,
|
||||
ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
|
||||
)
|
||||
# init tool configuration
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
|
||||
provider_type=controller.provider_type.value,
|
||||
provider_identity=controller.entity.identity.name,
|
||||
)
|
||||
|
||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
||||
|
||||
try:
|
||||
icon = json.loads(provider_obj.icon)
|
||||
except Exception:
|
||||
icon = {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
# add tool labels
|
||||
labels = ToolLabelManager.get_tool_labels(controller)
|
||||
|
||||
return cast(
|
||||
dict,
|
||||
jsonable_encoder(
|
||||
{
|
||||
"schema_type": provider_obj.schema_type,
|
||||
"schema": provider_obj.schema,
|
||||
"tools": provider_obj.tools,
|
||||
"icon": icon,
|
||||
"description": provider_obj.description,
|
||||
"credentials": masked_credentials,
|
||||
"privacy_policy": provider_obj.privacy_policy,
|
||||
"custom_disclaimer": provider_obj.custom_disclaimer,
|
||||
"labels": labels,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_builtin_tool_icon_url(cls, provider_id: str) -> str:
|
||||
return str(
|
||||
URL(dify_config.CONSOLE_API_URL or "/")
|
||||
/ "console"
|
||||
/ "api"
|
||||
/ "workspaces"
|
||||
/ "current"
|
||||
/ "tool-provider"
|
||||
/ "builtin"
|
||||
/ provider_id
|
||||
/ "icon"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_plugin_tool_icon_url(cls, tenant_id: str, filename: str) -> str:
|
||||
return str(
|
||||
URL(dify_config.CONSOLE_API_URL or "/")
|
||||
/ "console"
|
||||
/ "api"
|
||||
/ "workspaces"
|
||||
/ "current"
|
||||
/ "plugin"
|
||||
/ "icon"
|
||||
% {"tenant_id": tenant_id, "filename": filename}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
|
||||
try:
|
||||
workflow_provider: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if workflow_provider is None:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
icon: dict = json.loads(workflow_provider.icon)
|
||||
return icon
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
|
||||
try:
|
||||
api_provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if api_provider is None:
|
||||
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
|
||||
|
||||
icon: dict = json.loads(api_provider.icon)
|
||||
return icon
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
def get_tool_icon(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
provider_type: ToolProviderType,
|
||||
provider_id: str,
|
||||
) -> Union[str, dict]:
|
||||
"""
|
||||
get the tool icon
|
||||
|
||||
:param tenant_id: the id of the tenant
|
||||
:param provider_type: the type of the provider
|
||||
:param provider_id: the id of the provider
|
||||
:return:
|
||||
"""
|
||||
provider_type = provider_type
|
||||
provider_id = provider_id
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
|
||||
if isinstance(provider, PluginToolProviderController):
|
||||
try:
|
||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
return cls.generate_builtin_tool_icon_url(provider_id)
|
||||
elif provider_type == ToolProviderType.API:
|
||||
return cls.generate_api_tool_icon_url(tenant_id, provider_id)
|
||||
elif provider_type == ToolProviderType.WORKFLOW:
|
||||
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
|
||||
elif provider_type == ToolProviderType.PLUGIN:
|
||||
provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
|
||||
if isinstance(provider, PluginToolProviderController):
|
||||
try:
|
||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
raise ValueError(f"plugin provider {provider_id} not found")
|
||||
else:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
||||
|
||||
ToolManager.load_hardcoded_providers_cache()
|
|
@ -0,0 +1,265 @@
|
|||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
||||
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
|
||||
|
||||
class ProviderConfigEncrypter(BaseModel):
|
||||
tenant_id: str
|
||||
config: list[BasicProviderConfig]
|
||||
provider_type: str
|
||||
provider_identity: str
|
||||
|
||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
deep copy data
|
||||
"""
|
||||
return deepcopy(data)
|
||||
|
||||
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
encrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with encrypted values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||
data[field_name] = encrypted
|
||||
|
||||
return data
|
||||
|
||||
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool credentials
|
||||
|
||||
return a deep copy of credentials with masked values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
if len(data[field_name]) > 6:
|
||||
data[field_name] = (
|
||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||
)
|
||||
else:
|
||||
data[field_name] = "*" * len(data[field_name])
|
||||
|
||||
return data
|
||||
|
||||
def decrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
data = self._deep_copy(data)
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
try:
|
||||
# if the value is None or empty string, skip decrypt
|
||||
if not data[field_name]:
|
||||
continue
|
||||
|
||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cache.set(data)
|
||||
return data
|
||||
|
||||
def delete_tool_credentials_cache(self):
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cache.delete()
|
||||
|
||||
|
||||
class ToolParameterConfigurationManager:
|
||||
"""
|
||||
Tool parameter configuration manager
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
tool_runtime: Tool
|
||||
provider_name: str
|
||||
provider_type: ToolProviderType
|
||||
identity_id: str
|
||||
|
||||
def __init__(
|
||||
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
|
||||
) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.tool_runtime = tool_runtime
|
||||
self.provider_name = provider_name
|
||||
self.provider_type = provider_type
|
||||
self.identity_id = identity_id
|
||||
|
||||
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
deep copy parameters
|
||||
"""
|
||||
return deepcopy(parameters)
|
||||
|
||||
def _merge_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
merge parameters
|
||||
"""
|
||||
# get tool parameters
|
||||
tool_parameters = self.tool_runtime.entity.parameters or []
|
||||
# get tool runtime parameters
|
||||
runtime_parameters = self.tool_runtime.get_runtime_parameters()
|
||||
# override parameters
|
||||
current_parameters = tool_parameters.copy()
|
||||
for runtime_parameter in runtime_parameters:
|
||||
found = False
|
||||
for index, parameter in enumerate(current_parameters):
|
||||
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
|
||||
current_parameters[index] = runtime_parameter
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
current_parameters.append(runtime_parameter)
|
||||
|
||||
return current_parameters
|
||||
|
||||
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool parameters
|
||||
|
||||
return a deep copy of parameters with masked values
|
||||
"""
|
||||
parameters = self._deep_copy(parameters)
|
||||
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
|
||||
for parameter in current_parameters:
|
||||
if (
|
||||
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
if len(parameters[parameter.name]) > 6:
|
||||
parameters[parameter.name] = (
|
||||
parameters[parameter.name][:2]
|
||||
+ "*" * (len(parameters[parameter.name]) - 4)
|
||||
+ parameters[parameter.name][-2:]
|
||||
)
|
||||
else:
|
||||
parameters[parameter.name] = "*" * len(parameters[parameter.name])
|
||||
|
||||
return parameters
|
||||
|
||||
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
encrypt tool parameters with tenant id
|
||||
|
||||
return a deep copy of parameters with encrypted values
|
||||
"""
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
|
||||
parameters = self._deep_copy(parameters)
|
||||
|
||||
for parameter in current_parameters:
|
||||
if (
|
||||
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
parameters[parameter.name] = encrypted
|
||||
|
||||
return parameters
|
||||
|
||||
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
decrypt tool parameters with tenant id
|
||||
|
||||
return a deep copy of parameters with decrypted values
|
||||
"""
|
||||
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f"{self.provider_type.value}.{self.provider_name}",
|
||||
tool_name=self.tool_runtime.entity.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
identity_id=self.identity_id,
|
||||
)
|
||||
cached_parameters = cache.get()
|
||||
if cached_parameters:
|
||||
return cached_parameters
|
||||
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
has_secret_input = False
|
||||
|
||||
for parameter in current_parameters:
|
||||
if (
|
||||
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
try:
|
||||
has_secret_input = True
|
||||
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if has_secret_input:
|
||||
cache.set(parameters)
|
||||
|
||||
return parameters
|
||||
|
||||
def delete_tool_parameters_cache(self):
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f"{self.provider_type.value}.{self.provider_name}",
|
||||
tool_name=self.tool_runtime.entity.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
identity_id=self.identity_id,
|
||||
)
|
||||
cache.delete()
|
|
@ -0,0 +1,199 @@
|
|||
import threading
|
||||
from typing import Any
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.models.document import Document as RagDocument
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
class DatasetMultiRetrieverToolInput(BaseModel):
|
||||
query: str = Field(..., description="dataset multi retriever and rerank")
|
||||
|
||||
|
||||
class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
"""Tool for querying multi dataset."""
|
||||
|
||||
name: str = "dataset_"
|
||||
args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
|
||||
description: str = "dataset multi retriever and rerank. "
|
||||
dataset_ids: list[str]
|
||||
reranking_provider_name: str
|
||||
reranking_model_name: str
|
||||
|
||||
@classmethod
|
||||
def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
|
||||
return cls(
|
||||
name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs
|
||||
)
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
threads = []
|
||||
all_documents: list[RagDocument] = []
|
||||
for dataset_id in self.dataset_ids:
|
||||
retrieval_thread = threading.Thread(
|
||||
target=self._retriever,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"all_documents": all_documents,
|
||||
"hit_callbacks": self.hit_callbacks,
|
||||
},
|
||||
)
|
||||
threads.append(retrieval_thread)
|
||||
retrieval_thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
# do rerank for searched documents
|
||||
model_manager = ModelManager()
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=self.reranking_provider_name,
|
||||
model_type=ModelType.RERANK,
|
||||
model=self.reranking_model_name,
|
||||
)
|
||||
|
||||
rerank_runner = RerankModelRunner(rerank_model_instance)
|
||||
all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)
|
||||
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_tool_end(all_documents)
|
||||
|
||||
document_score_list = {}
|
||||
for item in all_documents:
|
||||
if item.metadata and item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
|
||||
segments = DocumentSegment.query.filter(
|
||||
DocumentSegment.dataset_id.in_(self.dataset_ids),
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids),
|
||||
).all()
|
||||
|
||||
if segments:
|
||||
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
|
||||
sorted_segments = sorted(
|
||||
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
|
||||
)
|
||||
for segment in sorted_segments:
|
||||
if segment.answer:
|
||||
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
|
||||
else:
|
||||
document_context_list.append(segment.get_sign_content())
|
||||
if self.return_resource:
|
||||
context_list = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||
document = Document.query.filter(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
).first()
|
||||
if dataset and document:
|
||||
source = {
|
||||
"position": resource_number,
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"data_source_type": document.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": document_score_list.get(segment.index_node_id, None),
|
||||
"doc_metadata": document.doc_metadata,
|
||||
}
|
||||
|
||||
if self.retriever_from == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
source["segment_position"] = segment.position
|
||||
source["index_node_hash"] = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.content
|
||||
context_list.append(source)
|
||||
resource_number += 1
|
||||
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
return ""
|
||||
|
||||
raise RuntimeError("not segments found")
|
||||
|
||||
def _retriever(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
all_documents: list,
|
||||
hit_callbacks: list[DatasetIndexToolCallbackHandler],
|
||||
):
|
||||
with flask_app.app_context():
|
||||
dataset = (
|
||||
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
return []
|
||||
|
||||
for hit_callback in hit_callbacks:
|
||||
hit_callback.on_query(query, dataset.id)
|
||||
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search",
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k") or 2,
|
||||
)
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
else:
|
||||
if self.top_k > 0:
|
||||
# retrieval source
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_model["search_method"],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k") or 2,
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
reranking_model=retrieval_model.get("reranking_model", None)
|
||||
if retrieval_model["reranking_enable"]
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
)
|
||||
|
||||
all_documents.extend(documents)
|
|
@ -0,0 +1,33 @@
|
|||
from abc import abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from msal_extensions.persistence import ABC # type: ignore
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
|
||||
|
||||
class DatasetRetrieverBaseTool(BaseModel, ABC):
|
||||
"""Tool for querying a Dataset."""
|
||||
|
||||
name: str = "dataset"
|
||||
description: str = "use this to retrieve a dataset. "
|
||||
tenant_id: str
|
||||
top_k: int = 2
|
||||
score_threshold: Optional[float] = None
|
||||
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
|
||||
return_resource: bool
|
||||
retriever_from: str
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@abstractmethod
|
||||
def _run(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use the tool.
|
||||
|
||||
Add run_manager: Optional[CallbackManagerForToolRun] = None
|
||||
to child implementations to enable tracing,
|
||||
"""
|
|
@ -0,0 +1,202 @@
|
|||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"reranking_mode": "reranking_model",
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
class DatasetRetrieverToolInput(BaseModel):
|
||||
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
|
||||
|
||||
|
||||
class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
"""Tool for querying a Dataset."""
|
||||
|
||||
name: str = "dataset"
|
||||
args_schema: type[BaseModel] = DatasetRetrieverToolInput
|
||||
description: str = "use this to retrieve a dataset. "
|
||||
dataset_id: str
|
||||
|
||||
@classmethod
|
||||
def from_dataset(cls, dataset: Dataset, **kwargs):
|
||||
description = dataset.description
|
||||
if not description:
|
||||
description = "useful for when you want to answer queries about the " + dataset.name
|
||||
|
||||
description = description.replace("\n", "").replace("\r", "")
|
||||
return cls(
|
||||
name=f"dataset_{dataset.id.replace('-', '_')}",
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
description=description,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
dataset = (
|
||||
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
return ""
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_query(query, dataset.id)
|
||||
if dataset.provider == "external":
|
||||
results = []
|
||||
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
external_retrieval_parameters=dataset.retrieval_model,
|
||||
)
|
||||
for external_document in external_documents:
|
||||
document = RetrievalDocument(
|
||||
page_content=external_document.get("content"),
|
||||
metadata=external_document.get("metadata"),
|
||||
provider="external",
|
||||
)
|
||||
if document.metadata is not None:
|
||||
document.metadata["score"] = external_document.get("score")
|
||||
document.metadata["title"] = external_document.get("title")
|
||||
document.metadata["dataset_id"] = dataset.id
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
results.append(document)
|
||||
# deal with external documents
|
||||
context_list = []
|
||||
for position, item in enumerate(results, start=1):
|
||||
if item.metadata is not None:
|
||||
source = {
|
||||
"position": position,
|
||||
"dataset_id": item.metadata.get("dataset_id"),
|
||||
"dataset_name": item.metadata.get("dataset_name"),
|
||||
"document_name": item.metadata.get("title"),
|
||||
"data_source_type": "external",
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": item.metadata.get("score"),
|
||||
"title": item.metadata.get("title"),
|
||||
"content": item.page_content,
|
||||
}
|
||||
context_list.append(source)
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join([item.page_content for item in results]))
|
||||
else:
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
|
||||
)
|
||||
return str("\n".join([document.page_content for document in documents]))
|
||||
else:
|
||||
if self.top_k > 0:
|
||||
# retrieval source
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k,
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
reranking_model=retrieval_model.get("reranking_model")
|
||||
if retrieval_model["reranking_enable"]
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights"),
|
||||
)
|
||||
else:
|
||||
documents = []
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_tool_end(documents)
|
||||
document_score_list = {}
|
||||
if dataset.indexing_technique != "economy":
|
||||
for item in documents:
|
||||
if item.metadata is not None and item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
document_context_list = []
|
||||
records = RetrievalService.format_retrieval_documents(documents)
|
||||
if records:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
if segment.answer:
|
||||
document_context_list.append(
|
||||
DocumentContext(
|
||||
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
|
||||
score=record.score,
|
||||
)
|
||||
)
|
||||
else:
|
||||
document_context_list.append(
|
||||
DocumentContext(
|
||||
content=segment.get_sign_content(),
|
||||
score=record.score,
|
||||
)
|
||||
)
|
||||
retrieval_resource_list = []
|
||||
if self.return_resource:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||
document = DatasetDocument.query.filter(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).first()
|
||||
if dataset and document:
|
||||
source = {
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id, # type: ignore
|
||||
"document_name": document.name, # type: ignore
|
||||
"data_source_type": document.data_source_type, # type: ignore
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": record.score or 0.0,
|
||||
"doc_metadata": document.doc_metadata, # type: ignore
|
||||
}
|
||||
|
||||
if self.retriever_from == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
source["segment_position"] = segment.position
|
||||
source["index_node_hash"] = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.content
|
||||
retrieval_resource_list.append(source)
|
||||
|
||||
if self.return_resource and retrieval_resource_list:
|
||||
retrieval_resource_list = sorted(
|
||||
retrieval_resource_list,
|
||||
key=lambda x: x.get("score") or 0.0,
|
||||
reverse=True,
|
||||
)
|
||||
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
|
||||
item["position"] = position # type: ignore
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
||||
if document_context_list:
|
||||
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
|
||||
return str("\n".join([document_context.content for document_context in document_context_list]))
|
||||
return ""
|
|
@ -0,0 +1,134 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
|
||||
|
||||
class DatasetRetrieverTool(Tool):
|
||||
retrieval_tool: DatasetRetrieverBaseTool
|
||||
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.retrieval_tool = retrieval_tool
|
||||
|
||||
@staticmethod
|
||||
def get_dataset_tools(
|
||||
tenant_id: str,
|
||||
dataset_ids: list[str],
|
||||
retrieve_config: DatasetRetrieveConfigEntity | None,
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
) -> list["DatasetRetrieverTool"]:
|
||||
"""
|
||||
get dataset tool
|
||||
"""
|
||||
# check if retrieve_config is valid
|
||||
if dataset_ids is None or len(dataset_ids) == 0:
|
||||
return []
|
||||
if retrieve_config is None:
|
||||
return []
|
||||
|
||||
feature = DatasetRetrieval()
|
||||
|
||||
# save original retrieve strategy, and set retrieve strategy to SINGLE
|
||||
# Agent only support SINGLE mode
|
||||
original_retriever_mode = retrieve_config.retrieve_strategy
|
||||
retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
|
||||
retrieval_tools = feature.to_dataset_retriever_tool(
|
||||
tenant_id=tenant_id,
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=retrieve_config,
|
||||
return_resource=return_resource,
|
||||
invoke_from=invoke_from,
|
||||
hit_callback=hit_callback,
|
||||
)
|
||||
if retrieval_tools is None or len(retrieval_tools) == 0:
|
||||
return []
|
||||
|
||||
# restore retrieve strategy
|
||||
retrieve_config.retrieve_strategy = original_retriever_mode
|
||||
|
||||
# convert retrieval tools to Tools
|
||||
tools = []
|
||||
for retrieval_tool in retrieval_tools:
|
||||
tool = DatasetRetrieverTool(
|
||||
retrieval_tool=retrieval_tool,
|
||||
entity=ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="")
|
||||
),
|
||||
parameters=[],
|
||||
description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description),
|
||||
),
|
||||
runtime=ToolRuntime(tenant_id=tenant_id),
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
return tools
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
return [
|
||||
ToolParameter(
|
||||
name="query",
|
||||
label=I18nObject(en_US="", zh_Hans=""),
|
||||
human_description=I18nObject(en_US="", zh_Hans=""),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description="Query for the dataset to be used to retrieve the dataset.",
|
||||
required=True,
|
||||
default="",
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
),
|
||||
]
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.DATASET_RETRIEVAL
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke dataset retriever tool
|
||||
"""
|
||||
query = tool_parameters.get("query")
|
||||
if not query:
|
||||
yield self.create_text_message(text="please input query")
|
||||
else:
|
||||
# invoke dataset retriever tool
|
||||
result = self.retrieval_tool._run(query=query)
|
||||
yield self.create_text_message(text=result)
|
||||
|
||||
def validate_credentials(
|
||||
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
|
||||
) -> str | None:
|
||||
"""
|
||||
validate the credentials for dataset retriever tool
|
||||
"""
|
||||
pass
|
|
@ -0,0 +1,121 @@
|
|||
import logging
|
||||
from collections.abc import Generator
|
||||
from mimetypes import guess_extension
|
||||
from typing import Optional
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolFileMessageTransformer:
|
||||
@classmethod
|
||||
def transform_tool_invoke_messages(
|
||||
cls,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Transform tool message and handle file download
|
||||
"""
|
||||
for message in messages:
|
||||
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
|
||||
yield message
|
||||
elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(
|
||||
message.message, ToolInvokeMessage.TextMessage
|
||||
):
|
||||
# try to download image
|
||||
try:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
file = ToolFileManager.create_file_by_url(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
file_url=message.message.text,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}"
|
||||
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
except Exception as e:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=ToolInvokeMessage.TextMessage(
|
||||
text=f"Failed to download image: {message.message.text}: {e}"
|
||||
),
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get mime type and save blob to storage
|
||||
meta = message.meta or {}
|
||||
|
||||
mimetype = meta.get("mime_type", "application/octet-stream")
|
||||
# get filename from meta
|
||||
filename = meta.get("file_name", None)
|
||||
# if message is str, encode it to bytes
|
||||
|
||||
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
|
||||
raise ValueError("unexpected message type")
|
||||
|
||||
# FIXME: should do a type check here.
|
||||
assert isinstance(message.message.blob, bytes)
|
||||
file = ToolFileManager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_binary=message.message.blob,
|
||||
mimetype=mimetype,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))
|
||||
|
||||
# check if file is image
|
||||
if "image" in mimetype:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BINARY_LINK,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
meta = message.meta or {}
|
||||
file = meta.get("file", None)
|
||||
if isinstance(file, File):
|
||||
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert file.related_id is not None
|
||||
url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
|
||||
if file.type == FileType.IMAGE:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield message
|
||||
else:
|
||||
yield message
|
||||
|
||||
@classmethod
|
||||
def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
|
||||
return f"/files/tools/{tool_file_id}{extension or '.bin'}"
|
|
@ -0,0 +1,169 @@
|
|||
"""
|
||||
For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc.
|
||||
|
||||
Therefore, a model manager is needed to list/invoke/validate models.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ToolModelInvoke
|
||||
|
||||
|
||||
class InvokeModelError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ModelInvocationUtils:
|
||||
@staticmethod
|
||||
def get_max_llm_context_tokens(
|
||||
tenant_id: str,
|
||||
) -> int:
|
||||
"""
|
||||
get max llm context tokens of the model
|
||||
"""
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
if not model_instance:
|
||||
raise InvokeModelError("Model not found")
|
||||
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
|
||||
if not schema:
|
||||
raise InvokeModelError("No model schema found")
|
||||
|
||||
max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
|
||||
if max_tokens is None:
|
||||
return 2048
|
||||
|
||||
return max_tokens
|
||||
|
||||
@staticmethod
|
||||
def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
|
||||
"""
|
||||
calculate tokens from prompt messages and model parameters
|
||||
"""
|
||||
|
||||
# get model instance
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)
|
||||
|
||||
if not model_instance:
|
||||
raise InvokeModelError("Model not found")
|
||||
|
||||
# get tokens
|
||||
tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
return tokens
|
||||
|
||||
@staticmethod
|
||||
def invoke(
|
||||
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
"""
|
||||
invoke model with parameters in user's own context
|
||||
|
||||
:param user_id: user id
|
||||
:param tenant_id: tenant id, the tenant id of the creator of the tool
|
||||
:param tool_type: tool type
|
||||
:param tool_name: tool name
|
||||
:param prompt_messages: prompt messages
|
||||
:return: AssistantPromptMessage
|
||||
"""
|
||||
|
||||
# get model manager
|
||||
model_manager = ModelManager()
|
||||
# get model instance
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
# get prompt tokens
|
||||
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
model_parameters = {
|
||||
"temperature": 0.8,
|
||||
"top_p": 0.8,
|
||||
}
|
||||
|
||||
# create tool model invoke
|
||||
tool_model_invoke = ToolModelInvoke(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
provider=model_instance.provider,
|
||||
tool_type=tool_type,
|
||||
tool_name=tool_name,
|
||||
model_parameters=json.dumps(model_parameters),
|
||||
prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
|
||||
model_response="",
|
||||
prompt_tokens=prompt_tokens,
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency="USD",
|
||||
)
|
||||
|
||||
db.session.add(tool_model_invoke)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
response: LLMResult = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=[],
|
||||
stop=[],
|
||||
stream=False,
|
||||
user=user_id,
|
||||
callbacks=[],
|
||||
),
|
||||
)
|
||||
except InvokeRateLimitError as e:
|
||||
raise InvokeModelError(f"Invoke rate limit error: {e}")
|
||||
except InvokeBadRequestError as e:
|
||||
raise InvokeModelError(f"Invoke bad request error: {e}")
|
||||
except InvokeConnectionError as e:
|
||||
raise InvokeModelError(f"Invoke connection error: {e}")
|
||||
except InvokeAuthorizationError as e:
|
||||
raise InvokeModelError("Invoke authorization error")
|
||||
except InvokeServerUnavailableError as e:
|
||||
raise InvokeModelError(f"Invoke server unavailable error: {e}")
|
||||
except Exception as e:
|
||||
raise InvokeModelError(f"Invoke error: {e}")
|
||||
|
||||
# update tool model invoke
|
||||
tool_model_invoke.model_response = response.message.content
|
||||
if response.usage:
|
||||
tool_model_invoke.answer_tokens = response.usage.completion_tokens
|
||||
tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
|
||||
tool_model_invoke.answer_price_unit = response.usage.completion_price_unit
|
||||
tool_model_invoke.provider_response_latency = response.usage.latency
|
||||
tool_model_invoke.total_price = response.usage.total_price
|
||||
tool_model_invoke.currency = response.usage.currency
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return response
|
|
@ -0,0 +1,389 @@
|
|||
import re
|
||||
import uuid
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from flask import request
|
||||
from requests import get
|
||||
from yaml import YAMLError, safe_load # type: ignore
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
|
||||
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(
|
||||
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
# set description to extra_info
|
||||
extra_info["description"] = openapi["info"].get("description", "")
|
||||
|
||||
if len(openapi["servers"]) == 0:
|
||||
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
|
||||
|
||||
server_url = openapi["servers"][0]["url"]
|
||||
request_env = request.headers.get("X-Request-Env")
|
||||
if request_env:
|
||||
matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
|
||||
server_url = matched_servers[0] if matched_servers else server_url
|
||||
|
||||
# list all interfaces
|
||||
interfaces = []
|
||||
for path, path_item in openapi["paths"].items():
|
||||
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
|
||||
for method in methods:
|
||||
if method in path_item:
|
||||
interfaces.append(
|
||||
{
|
||||
"path": path,
|
||||
"method": method,
|
||||
"operation": path_item[method],
|
||||
}
|
||||
)
|
||||
|
||||
# get all parameters
|
||||
bundles = []
|
||||
for interface in interfaces:
|
||||
# convert parameters
|
||||
parameters = []
|
||||
if "parameters" in interface["operation"]:
|
||||
for parameter in interface["operation"]["parameters"]:
|
||||
tool_parameter = ToolParameter(
|
||||
name=parameter["name"],
|
||||
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
|
||||
human_description=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=parameter.get("required", False),
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=parameter.get("description"),
|
||||
default=parameter["schema"]["default"]
|
||||
if "schema" in parameter and "default" in parameter["schema"]
|
||||
else None,
|
||||
placeholder=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
|
||||
if typ:
|
||||
tool_parameter.type = typ
|
||||
|
||||
parameters.append(tool_parameter)
|
||||
# create tool bundle
|
||||
# check if there is a request body
|
||||
if "requestBody" in interface["operation"]:
|
||||
request_body = interface["operation"]["requestBody"]
|
||||
if "content" in request_body:
|
||||
for content_type, content in request_body["content"].items():
|
||||
# if there is a reference, get the reference and overwrite the content
|
||||
if "schema" not in content:
|
||||
continue
|
||||
|
||||
if "$ref" in content["schema"]:
|
||||
# get the reference
|
||||
root = openapi
|
||||
reference = content["schema"]["$ref"].split("/")[1:]
|
||||
for ref in reference:
|
||||
root = root[ref]
|
||||
# overwrite the content
|
||||
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
|
||||
|
||||
# parse body parameters
|
||||
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
|
||||
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
tool = ToolParameter(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
human_description=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=name in required,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get("description", ""),
|
||||
default=property.get("default", None),
|
||||
placeholder=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
||||
if typ:
|
||||
tool.type = typ
|
||||
|
||||
parameters.append(tool)
|
||||
|
||||
# check if parameters is duplicated
|
||||
parameters_count = {}
|
||||
for parameter in parameters:
|
||||
if parameter.name not in parameters_count:
|
||||
parameters_count[parameter.name] = 0
|
||||
parameters_count[parameter.name] += 1
|
||||
for name, count in parameters_count.items():
|
||||
if count > 1:
|
||||
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
|
||||
|
||||
# check if there is a operation id, use $path_$method as operation id if not
|
||||
if "operationId" not in interface["operation"]:
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = interface["path"]
|
||||
if interface["path"].startswith("/"):
|
||||
path = interface["path"][1:]
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
|
||||
if not path:
|
||||
path = str(uuid.uuid4())
|
||||
|
||||
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
|
||||
|
||||
bundles.append(
|
||||
ApiToolBundle(
|
||||
server_url=server_url + interface["path"],
|
||||
method=interface["method"],
|
||||
summary=interface["operation"]["description"]
|
||||
if "description" in interface["operation"]
|
||||
else interface["operation"].get("summary", None),
|
||||
operation_id=interface["operation"]["operationId"],
|
||||
parameters=parameters,
|
||||
author="",
|
||||
icon=None,
|
||||
openapi=interface["operation"],
|
||||
)
|
||||
)
|
||||
|
||||
return bundles
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
|
||||
parameter = parameter or {}
|
||||
typ: Optional[str] = None
|
||||
if parameter.get("format") == "binary":
|
||||
return ToolParameter.ToolParameterType.FILE
|
||||
|
||||
if "type" in parameter:
|
||||
typ = parameter["type"]
|
||||
elif "schema" in parameter and "type" in parameter["schema"]:
|
||||
typ = parameter["schema"]["type"]
|
||||
|
||||
if typ in {"integer", "number"}:
|
||||
return ToolParameter.ToolParameterType.NUMBER
|
||||
elif typ == "boolean":
|
||||
return ToolParameter.ToolParameterType.BOOLEAN
|
||||
elif typ == "string":
|
||||
return ToolParameter.ToolParameterType.STRING
|
||||
elif typ == "array":
|
||||
items = parameter.get("items") or parameter.get("schema", {}).get("items")
|
||||
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_openapi_yaml_to_tool_bundle(
|
||||
yaml: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi yaml to tool bundle
|
||||
|
||||
:param yaml: the yaml string
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: the tool bundle
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
openapi: dict = safe_load(yaml)
|
||||
if openapi is None:
|
||||
raise ToolApiSchemaError("Invalid openapi yaml.")
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
|
||||
warning = warning or {}
|
||||
"""
|
||||
parse swagger to openapi
|
||||
|
||||
:param swagger: the swagger dict
|
||||
:return: the openapi dict
|
||||
"""
|
||||
# convert swagger to openapi
|
||||
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
|
||||
|
||||
servers = swagger.get("servers", [])
|
||||
|
||||
if len(servers) == 0:
|
||||
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
||||
|
||||
openapi = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": info.get("title", "Swagger"),
|
||||
"description": info.get("description", "Swagger"),
|
||||
"version": info.get("version", "1.0.0"),
|
||||
},
|
||||
"servers": swagger["servers"],
|
||||
"paths": {},
|
||||
"components": {"schemas": {}},
|
||||
}
|
||||
|
||||
# check paths
|
||||
if "paths" not in swagger or len(swagger["paths"]) == 0:
|
||||
raise ToolApiSchemaError("No paths found in the swagger yaml.")
|
||||
|
||||
# convert paths
|
||||
for path, path_item in swagger["paths"].items():
|
||||
openapi["paths"][path] = {}
|
||||
for method, operation in path_item.items():
|
||||
if "operationId" not in operation:
|
||||
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
|
||||
|
||||
if ("summary" not in operation or len(operation["summary"]) == 0) and (
|
||||
"description" not in operation or len(operation["description"]) == 0
|
||||
):
|
||||
if warning is not None:
|
||||
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
|
||||
|
||||
openapi["paths"][path][method] = {
|
||||
"operationId": operation["operationId"],
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
"parameters": operation.get("parameters", []),
|
||||
"responses": operation.get("responses", {}),
|
||||
}
|
||||
|
||||
if "requestBody" in operation:
|
||||
openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
|
||||
|
||||
# convert definitions
|
||||
for name, definition in swagger["definitions"].items():
|
||||
openapi["components"]["schemas"][name] = definition
|
||||
|
||||
return openapi
|
||||
|
||||
@staticmethod
|
||||
def parse_openai_plugin_json_to_tool_bundle(
|
||||
json: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi plugin yaml to tool bundle
|
||||
|
||||
:param json: the json string
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: the tool bundle
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
try:
|
||||
openai_plugin = json_loads(json)
|
||||
api = openai_plugin["api"]
|
||||
api_url = api["url"]
|
||||
api_type = api["type"]
|
||||
except JSONDecodeError:
|
||||
raise ToolProviderNotFoundError("Invalid openai plugin json.")
|
||||
|
||||
if api_type != "openapi":
|
||||
raise ToolNotSupportedError("Only openapi is supported now.")
|
||||
|
||||
# get openapi yaml
|
||||
response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
|
||||
|
||||
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
|
||||
response.text, extra_info=extra_info, warning=warning
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def auto_parse_to_tool_bundle(
|
||||
content: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> tuple[list[ApiToolBundle], str]:
|
||||
"""
|
||||
auto parse to tool bundle
|
||||
|
||||
:param content: the content
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: tools bundle, schema_type
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
content = content.strip()
|
||||
loaded_content = None
|
||||
json_error = None
|
||||
yaml_error = None
|
||||
|
||||
try:
|
||||
loaded_content = json_loads(content)
|
||||
except JSONDecodeError as e:
|
||||
json_error = e
|
||||
|
||||
if loaded_content is None:
|
||||
try:
|
||||
loaded_content = safe_load(content)
|
||||
except YAMLError as e:
|
||||
yaml_error = e
|
||||
if loaded_content is None:
|
||||
raise ToolApiSchemaError(
|
||||
f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
|
||||
f" yaml error: {str(yaml_error)}"
|
||||
)
|
||||
|
||||
swagger_error = None
|
||||
openapi_error = None
|
||||
openapi_plugin_error = None
|
||||
schema_type = None
|
||||
|
||||
try:
|
||||
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
|
||||
loaded_content, extra_info=extra_info, warning=warning
|
||||
)
|
||||
schema_type = ApiProviderSchemaType.OPENAPI.value
|
||||
return openapi, schema_type
|
||||
except ToolApiSchemaError as e:
|
||||
openapi_error = e
|
||||
|
||||
# openai parse error, fallback to swagger
|
||||
try:
|
||||
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
|
||||
loaded_content, extra_info=extra_info, warning=warning
|
||||
)
|
||||
schema_type = ApiProviderSchemaType.SWAGGER.value
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
|
||||
converted_swagger, extra_info=extra_info, warning=warning
|
||||
), schema_type
|
||||
except ToolApiSchemaError as e:
|
||||
swagger_error = e
|
||||
|
||||
# swagger parse error, fallback to openai plugin
|
||||
try:
|
||||
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
|
||||
json_dumps(loaded_content), extra_info=extra_info, warning=warning
|
||||
)
|
||||
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
|
||||
except ToolNotSupportedError as e:
|
||||
# maybe it's not plugin at all
|
||||
openapi_plugin_error = e
|
||||
|
||||
raise ToolApiSchemaError(
|
||||
f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
|
||||
f" openapi plugin error: {str(openapi_plugin_error)}"
|
||||
)
|
|
@ -0,0 +1,17 @@
|
|||
import re
|
||||
|
||||
|
||||
def get_image_upload_file_ids(content):
|
||||
pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)"
|
||||
matches = re.findall(pattern, content)
|
||||
image_upload_file_ids = []
|
||||
for match in matches:
|
||||
if match[1] == "file-preview":
|
||||
content_pattern = r"files/([^/]+)/file-preview"
|
||||
else:
|
||||
content_pattern = r"files/([^/]+)/image-preview"
|
||||
content_match = re.search(content_pattern, match[0])
|
||||
if content_match:
|
||||
image_upload_file_id = content_match.group(1)
|
||||
image_upload_file_ids.append(image_upload_file_id)
|
||||
return image_upload_file_ids
|
|
@ -0,0 +1,17 @@
|
|||
import re
|
||||
|
||||
|
||||
def remove_leading_symbols(text: str) -> str:
|
||||
"""
|
||||
Remove leading punctuation or symbols from the given text.
|
||||
|
||||
Args:
|
||||
text (str): The input text to process.
|
||||
|
||||
Returns:
|
||||
str: The text with leading punctuation or symbols removed.
|
||||
"""
|
||||
# Match Unicode ranges for punctuation and symbols
|
||||
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
|
||||
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
|
||||
return re.sub(pattern, "", text)
|
|
@ -0,0 +1,9 @@
|
|||
import uuid
|
||||
|
||||
|
||||
def is_valid_uuid(uuid_str: str) -> bool:
|
||||
try:
|
||||
uuid.UUID(uuid_str)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
|
@ -0,0 +1,375 @@
|
|||
import hashlib
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import site
|
||||
import subprocess
|
||||
import tempfile
|
||||
import unicodedata
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import chardet
|
||||
import cloudscraper # type: ignore
|
||||
from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore
|
||||
from regex import regex # type: ignore
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.extractor import extract_processor
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
|
||||
FULL_TEMPLATE = """
|
||||
TITLE: {title}
|
||||
AUTHORS: {authors}
|
||||
PUBLISH DATE: {publish_date}
|
||||
TOP_IMAGE_URL: {top_image}
|
||||
TEXT:
|
||||
|
||||
{text}
|
||||
"""
|
||||
|
||||
|
||||
def page_result(text: str, cursor: int, max_length: int) -> str:
|
||||
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
|
||||
return text[cursor : cursor + max_length]
|
||||
|
||||
|
||||
def get_url(url: str, user_agent: Optional[str] = None) -> str:
|
||||
"""Fetch URL and return the contents as a string."""
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
|
||||
" Chrome/91.0.4472.124 Safari/537.36"
|
||||
}
|
||||
if user_agent:
|
||||
headers["User-Agent"] = user_agent
|
||||
|
||||
main_content_type = None
|
||||
supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
|
||||
response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10))
|
||||
|
||||
if response.status_code == 200:
|
||||
# check content-type
|
||||
content_type = response.headers.get("Content-Type")
|
||||
if content_type:
|
||||
main_content_type = response.headers.get("Content-Type").split(";")[0].strip()
|
||||
else:
|
||||
content_disposition = response.headers.get("Content-Disposition", "")
|
||||
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
|
||||
if filename_match:
|
||||
filename = unquote(filename_match.group(1))
|
||||
extension = re.search(r"\.(\w+)$", filename)
|
||||
if extension:
|
||||
main_content_type = mimetypes.guess_type(filename)[0]
|
||||
|
||||
if main_content_type not in supported_content_types:
|
||||
return "Unsupported content-type [{}] of URL.".format(main_content_type)
|
||||
|
||||
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
|
||||
return cast(str, ExtractProcessor.load_from_url(url, return_text=True))
|
||||
|
||||
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
||||
elif response.status_code == 403:
|
||||
scraper = cloudscraper.create_scraper()
|
||||
scraper.perform_request = ssrf_proxy.make_request
|
||||
response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
||||
|
||||
if response.status_code != 200:
|
||||
return "URL returned status code {}.".format(response.status_code)
|
||||
|
||||
# Detect encoding using chardet
|
||||
detected_encoding = chardet.detect(response.content)
|
||||
encoding = detected_encoding["encoding"]
|
||||
if encoding:
|
||||
try:
|
||||
content = response.content.decode(encoding)
|
||||
except (UnicodeDecodeError, TypeError):
|
||||
content = response.text
|
||||
else:
|
||||
content = response.text
|
||||
|
||||
a = extract_using_readabilipy(content)
|
||||
|
||||
if not a["plain_text"] or not a["plain_text"].strip():
|
||||
return ""
|
||||
|
||||
res = FULL_TEMPLATE.format(
|
||||
title=a["title"],
|
||||
authors=a["byline"],
|
||||
publish_date=a["date"],
|
||||
top_image="",
|
||||
text=a["plain_text"] or "",
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def extract_using_readabilipy(html):
|
||||
with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html:
|
||||
f_html.write(html)
|
||||
f_html.close()
|
||||
html_path = f_html.name
|
||||
|
||||
# Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
|
||||
article_json_path = html_path + ".json"
|
||||
jsdir = os.path.join(find_module_path("readabilipy"), "javascript")
|
||||
with chdir(jsdir):
|
||||
subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
|
||||
|
||||
# Read output of call to Readability.parse() from JSON file and return as Python dictionary
|
||||
input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8"))
|
||||
|
||||
# Deleting files after processing
|
||||
os.unlink(article_json_path)
|
||||
os.unlink(html_path)
|
||||
|
||||
article_json: dict[str, Any] = {
|
||||
"title": None,
|
||||
"byline": None,
|
||||
"date": None,
|
||||
"content": None,
|
||||
"plain_content": None,
|
||||
"plain_text": None,
|
||||
}
|
||||
# Populate article fields from readability fields where present
|
||||
if input_json:
|
||||
if input_json.get("title"):
|
||||
article_json["title"] = input_json["title"]
|
||||
if input_json.get("byline"):
|
||||
article_json["byline"] = input_json["byline"]
|
||||
if input_json.get("date"):
|
||||
article_json["date"] = input_json["date"]
|
||||
if input_json.get("content"):
|
||||
article_json["content"] = input_json["content"]
|
||||
article_json["plain_content"] = plain_content(article_json["content"], False, False)
|
||||
article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
|
||||
if input_json.get("textContent"):
|
||||
article_json["plain_text"] = input_json["textContent"]
|
||||
article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"])
|
||||
|
||||
return article_json
|
||||
|
||||
|
||||
def find_module_path(module_name):
|
||||
for package_path in site.getsitepackages():
|
||||
potential_path = os.path.join(package_path, module_name)
|
||||
if os.path.exists(potential_path):
|
||||
return potential_path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def chdir(path):
|
||||
"""Change directory in context and return to original on exit"""
|
||||
# From https://stackoverflow.com/a/37996581, couldn't find a built-in
|
||||
original_path = os.getcwd()
|
||||
os.chdir(path)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
os.chdir(original_path)
|
||||
|
||||
|
||||
def extract_text_blocks_as_plain_text(paragraph_html):
|
||||
# Load article as DOM
|
||||
soup = BeautifulSoup(paragraph_html, "html.parser")
|
||||
# Select all lists
|
||||
list_elements = soup.find_all(["ul", "ol"])
|
||||
# Prefix text in all list items with "* " and make lists paragraphs
|
||||
for list_element in list_elements:
|
||||
plain_items = "".join(
|
||||
list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")]))
|
||||
)
|
||||
list_element.string = plain_items
|
||||
list_element.name = "p"
|
||||
# Select all text blocks
|
||||
text_blocks = [s.parent for s in soup.find_all(string=True)]
|
||||
text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
|
||||
# Drop empty paragraphs
|
||||
text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
|
||||
return text_blocks
|
||||
|
||||
|
||||
def plain_text_leaf_node(element):
|
||||
# Extract all text, stripped of any child HTML elements and normalize it
|
||||
plain_text = normalize_text(element.get_text())
|
||||
if plain_text != "" and element.name == "li":
|
||||
plain_text = "* {}, ".format(plain_text)
|
||||
if plain_text == "":
|
||||
plain_text = None
|
||||
if "data-node-index" in element.attrs:
|
||||
plain = {"node_index": element["data-node-index"], "text": plain_text}
|
||||
else:
|
||||
plain = {"text": plain_text}
|
||||
return plain
|
||||
|
||||
|
||||
def plain_content(readability_content, content_digests, node_indexes):
|
||||
# Load article as DOM
|
||||
soup = BeautifulSoup(readability_content, "html.parser")
|
||||
# Make all elements plain
|
||||
elements = plain_elements(soup.contents, content_digests, node_indexes)
|
||||
if node_indexes:
|
||||
# Add node index attributes to nodes
|
||||
elements = [add_node_indexes(element) for element in elements]
|
||||
# Replace article contents with plain elements
|
||||
soup.contents = elements
|
||||
return str(soup)
|
||||
|
||||
|
||||
def plain_elements(elements, content_digests, node_indexes):
|
||||
# Get plain content versions of all elements
|
||||
elements = [plain_element(element, content_digests, node_indexes) for element in elements]
|
||||
if content_digests:
|
||||
# Add content digest attribute to nodes
|
||||
elements = [add_content_digest(element) for element in elements]
|
||||
return elements
|
||||
|
||||
|
||||
def plain_element(element, content_digests, node_indexes):
|
||||
# For lists, we make each item plain text
|
||||
if is_leaf(element):
|
||||
# For leaf node elements, extract the text content, discarding any HTML tags
|
||||
# 1. Get element contents as text
|
||||
plain_text = element.get_text()
|
||||
# 2. Normalize the extracted text string to a canonical representation
|
||||
plain_text = normalize_text(plain_text)
|
||||
# 3. Update element content to be plain text
|
||||
element.string = plain_text
|
||||
elif is_text(element):
|
||||
if is_non_printing(element):
|
||||
# The simplified HTML may have come from Readability.js so might
|
||||
# have non-printing text (e.g. Comment or CData). In this case, we
|
||||
# keep the structure, but ensure that the string is empty.
|
||||
element = type(element)("")
|
||||
else:
|
||||
plain_text = element.string
|
||||
plain_text = normalize_text(plain_text)
|
||||
element = type(element)(plain_text)
|
||||
else:
|
||||
# If not a leaf node or leaf type call recursively on child nodes, replacing
|
||||
element.contents = plain_elements(element.contents, content_digests, node_indexes)
|
||||
return element
|
||||
|
||||
|
||||
def add_node_indexes(element, node_index="0"):
|
||||
# Can't add attributes to string types
|
||||
if is_text(element):
|
||||
return element
|
||||
# Add index to current element
|
||||
element["data-node-index"] = node_index
|
||||
# Add index to child elements
|
||||
for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1):
|
||||
# Can't add attributes to leaf string types
|
||||
child_index = "{stem}.{local}".format(stem=node_index, local=local_idx)
|
||||
add_node_indexes(child, node_index=child_index)
|
||||
return element
|
||||
|
||||
|
||||
def normalize_text(text):
|
||||
"""Normalize unicode and whitespace."""
|
||||
# Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them
|
||||
text = strip_control_characters(text)
|
||||
text = normalize_unicode(text)
|
||||
text = normalize_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def strip_control_characters(text):
|
||||
"""Strip out unicode control characters which might break the parsing."""
|
||||
# Unicode control characters
|
||||
# [Cc]: Other, Control [includes new lines]
|
||||
# [Cf]: Other, Format
|
||||
# [Cn]: Other, Not Assigned
|
||||
# [Co]: Other, Private Use
|
||||
# [Cs]: Other, Surrogate
|
||||
control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"}
|
||||
retained_chars = ["\t", "\n", "\r", "\f"]
|
||||
|
||||
# Remove non-printing control characters
|
||||
return "".join(
|
||||
[
|
||||
"" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char
|
||||
for char in text
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def normalize_unicode(text):
|
||||
"""Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
|
||||
normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC"
|
||||
text = unicodedata.normalize(normal_form, text)
|
||||
return text
|
||||
|
||||
|
||||
def normalize_whitespace(text):
|
||||
"""Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
|
||||
text = regex.sub(r"\s+", " ", text)
|
||||
# Remove leading and trailing whitespace
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def is_leaf(element):
|
||||
return element.name in {"p", "li"}
|
||||
|
||||
|
||||
def is_text(element):
|
||||
return isinstance(element, NavigableString)
|
||||
|
||||
|
||||
def is_non_printing(element):
|
||||
return any(isinstance(element, _e) for _e in [Comment, CData])
|
||||
|
||||
|
||||
def add_content_digest(element):
|
||||
if not is_text(element):
|
||||
element["data-content-digest"] = content_digest(element)
|
||||
return element
|
||||
|
||||
|
||||
def content_digest(element):
|
||||
digest: Any
|
||||
if is_text(element):
|
||||
# Hash
|
||||
trimmed_string = element.string.strip()
|
||||
if trimmed_string == "":
|
||||
digest = ""
|
||||
else:
|
||||
digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest()
|
||||
else:
|
||||
contents = element.contents
|
||||
num_contents = len(contents)
|
||||
if num_contents == 0:
|
||||
# No hash when no child elements exist
|
||||
digest = ""
|
||||
elif num_contents == 1:
|
||||
# If single child, use digest of child
|
||||
digest = content_digest(contents[0])
|
||||
else:
|
||||
# Build content digest from the "non-empty" digests of child nodes
|
||||
digest = hashlib.sha256()
|
||||
child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents]))
|
||||
for child in child_digests:
|
||||
digest.update(child.encode("utf-8"))
|
||||
digest = digest.hexdigest()
|
||||
return digest
|
||||
|
||||
|
||||
def get_image_upload_file_ids(content):
|
||||
pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)"
|
||||
matches = re.findall(pattern, content)
|
||||
image_upload_file_ids = []
|
||||
for match in matches:
|
||||
if match[1] == "file-preview":
|
||||
content_pattern = r"files/([^/]+)/file-preview"
|
||||
else:
|
||||
content_pattern = r"files/([^/]+)/image-preview"
|
||||
content_match = re.search(content_pattern, match[0])
|
||||
if content_match:
|
||||
image_upload_file_id = content_match.group(1)
|
||||
image_upload_file_ids.append(image_upload_file_id)
|
||||
return image_upload_file_ids
|
|
@ -0,0 +1,43 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
|
||||
for configuration in configurations:
|
||||
WorkflowToolParameterConfiguration.model_validate(configuration)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
|
||||
"""
|
||||
get workflow graph variables
|
||||
"""
|
||||
nodes = graph.get("nodes", [])
|
||||
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
|
||||
|
||||
if not start_node:
|
||||
return []
|
||||
|
||||
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
|
||||
|
||||
@classmethod
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
):
|
||||
"""
|
||||
check is synced
|
||||
|
||||
raise ValueError if not synced
|
||||
"""
|
||||
variable_names = [variable.variable for variable in variables]
|
||||
|
||||
if len(tool_configurations) != len(variables):
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||
|
||||
for parameter in tool_configurations:
|
||||
if parameter.name not in variable_names:
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
|
@ -0,0 +1,35 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml # type: ignore
|
||||
from yaml import YAMLError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
|
||||
"""
|
||||
Safe loading a YAML file
|
||||
:param file_path: the path of the YAML file
|
||||
:param ignore_error:
|
||||
if True, return default_value if error occurs and the error will be logged in debug level
|
||||
if False, raise error if error occurs
|
||||
:param default_value: the value returned when errors ignored
|
||||
:return: an object of the YAML content
|
||||
"""
|
||||
if not file_path or not Path(file_path).exists():
|
||||
if ignore_error:
|
||||
return default_value
|
||||
else:
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, encoding="utf-8") as yaml_file:
|
||||
try:
|
||||
yaml_content = yaml.safe_load(yaml_file)
|
||||
return yaml_content or default_value
|
||||
except Exception as e:
|
||||
if ignore_error:
|
||||
return default_value
|
||||
else:
|
||||
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
|
|
@ -1,3 +1,4 @@
|
|||
SYSTEM_VARIABLE_NODE_ID = "sys"
|
||||
ENVIRONMENT_VARIABLE_NODE_ID = "env"
|
||||
CONVERSATION_VARIABLE_NODE_ID = "conversation"
|
||||
PIPELINE_VARIABLE_NODE_ID = "pipeline"
|
|
@ -17,6 +17,7 @@ class NodeRunMetadataKey(StrEnum):
|
|||
TOTAL_PRICE = "total_price"
|
||||
CURRENCY = "currency"
|
||||
TOOL_INFO = "tool_info"
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
AGENT_LOG = "agent_log"
|
||||
ITERATION_ID = "iteration_id"
|
||||
ITERATION_INDEX = "iteration_index"
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
from .tool_node import ToolNode
|
||||
|
||||
__all__ = ["DatasourceNode"]
|
|
@ -0,0 +1,406 @@
|
|||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from core.plugin.manager.plugin import PluginInstallationManager
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import AgentLogEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import ToolFile
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .entities import DatasourceNodeData
|
||||
from .exc import (
|
||||
ToolFileError,
|
||||
ToolNodeError,
|
||||
ToolParameterError,
|
||||
)
|
||||
|
||||
|
||||
class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
"""
|
||||
Datasource Node
|
||||
"""
|
||||
|
||||
_node_data_cls = DatasourceNodeData
|
||||
_node_type = NodeType.DATASOURCE
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
Run the datasource node
|
||||
"""
|
||||
|
||||
node_data = cast(DatasourceNodeData, self.node_data)
|
||||
|
||||
# fetch datasource icon
|
||||
datasource_info = {
|
||||
"provider_type": node_data.provider_type.value,
|
||||
"provider_id": node_data.provider_id,
|
||||
"plugin_unique_identifier": node_data.plugin_unique_identifier,
|
||||
}
|
||||
|
||||
# get datasource runtime
|
||||
try:
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
error=f"Failed to get datasource runtime: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# get parameters
|
||||
tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
|
||||
parameters = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
)
|
||||
parameters_for_log = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
for_log=True,
|
||||
)
|
||||
|
||||
# get conversation id
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
try:
|
||||
message_stream = ToolEngine.generic_invoke(
|
||||
tool=tool_runtime,
|
||||
tool_parameters=parameters,
|
||||
user_id=self.user_id,
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
thread_pool_id=self.thread_pool_id,
|
||||
app_id=self.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to invoke tool: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# convert tool messages
|
||||
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
|
||||
except (PluginDaemonClientSideError, ToolInvokeError) as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to transform tool message: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
|
||||
def _generate_parameters(
|
||||
self,
|
||||
*,
|
||||
tool_parameters: Sequence[ToolParameter],
|
||||
variable_pool: VariablePool,
|
||||
node_data: ToolNodeData,
|
||||
for_log: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
|
||||
Args:
|
||||
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (ToolNodeData): The data associated with the tool node.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
parameter = tool_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
if variable is None:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type in {"mixed", "constant"}:
|
||||
segment_group = variable_pool.convert_template(str(tool_input.value))
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
else:
|
||||
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
|
||||
result[parameter_name] = parameter_value
|
||||
|
||||
return result
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any],
|
||||
) -> Generator:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json: list[dict] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {}
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
ToolInvokeMessage.MessageType.BINARY_LINK,
|
||||
ToolInvokeMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
if message.meta:
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
else:
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"tool file {tool_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if self.node_type == NodeType.AGENT:
|
||||
msg_metadata = message.message.json_object.pop("execution_metadata", {})
|
||||
agent_execution_metadata = {
|
||||
key: value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in NodeRunMetadataKey.__members__.values()
|
||||
}
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
files.append(message.meta["file"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.LOG:
|
||||
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
|
||||
if message.message.metadata:
|
||||
icon = tool_info.get("icon", "")
|
||||
dict_metadata = dict(message.message.metadata)
|
||||
if dict_metadata.get("provider"):
|
||||
manager = PluginInstallationManager()
|
||||
plugins = manager.list_plugins(self.tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
pass
|
||||
try:
|
||||
builtin_tool = next(
|
||||
provider
|
||||
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
self.user_id,
|
||||
self.tenant_id,
|
||||
)
|
||||
if provider.name == dict_metadata["provider"]
|
||||
)
|
||||
icon = builtin_tool.icon
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
dict_metadata["icon"] = icon
|
||||
message.message.metadata = dict_metadata
|
||||
agent_log = AgentLogEvent(
|
||||
id=message.message.id,
|
||||
node_execution_id=self.id,
|
||||
parent_id=message.message.parent_id,
|
||||
error=message.message.error,
|
||||
status=message.message.status.value,
|
||||
data=message.message.data,
|
||||
label=message.message.label,
|
||||
metadata=message.message.metadata,
|
||||
node_id=self.node_id,
|
||||
)
|
||||
|
||||
# check if the agent log is already in the list
|
||||
for log in agent_logs:
|
||||
if log.id == agent_log.id:
|
||||
# update the log
|
||||
log.data = agent_log.data
|
||||
log.status = agent_log.status
|
||||
log.error = agent_log.error
|
||||
log.label = agent_log.label
|
||||
log.metadata = agent_log.metadata
|
||||
break
|
||||
else:
|
||||
agent_logs.append(agent_log)
|
||||
|
||||
yield agent_log
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"text": text, "files": files, "json": json, **variables},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info,
|
||||
NodeRunMetadataKey.AGENT_LOG: agent_logs,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: ToolNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
result = {}
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
input = node_data.tool_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
|
@ -0,0 +1,56 @@
|
|||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
|
||||
|
||||
class DatasourceEntity(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: ToolProviderType
|
||||
provider_name: str # redundancy
|
||||
tool_name: str
|
||||
tool_label: str # redundancy
|
||||
tool_configurations: dict[str, Any]
|
||||
plugin_unique_identifier: str | None = None # redundancy
|
||||
|
||||
@field_validator("tool_configurations", mode="before")
|
||||
@classmethod
|
||||
def validate_tool_configurations(cls, value, values: ValidationInfo):
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError("tool_configurations must be a dictionary")
|
||||
|
||||
for key in values.data.get("tool_configurations", {}):
|
||||
value = values.data.get("tool_configurations", {}).get(key)
|
||||
if not isinstance(value, str | int | float | bool):
|
||||
raise ValueError(f"{key} must be a string")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
|
||||
class DatasourceInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def check_type(cls, value, validation_info: ValidationInfo):
|
||||
typ = value
|
||||
value = validation_info.data.get("value")
|
||||
if typ == "mixed" and not isinstance(value, str):
|
||||
raise ValueError("value must be a string")
|
||||
elif typ == "variable":
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("value must be a list")
|
||||
for val in value:
|
||||
if not isinstance(val, str):
|
||||
raise ValueError("value must be a list of strings")
|
||||
elif typ == "constant" and not isinstance(value, str | int | float | bool):
|
||||
raise ValueError("value must be a string, int, float, or bool")
|
||||
return typ
|
||||
|
||||
datasource_parameters: dict[str, DatasourceInput]
|
|
@ -0,0 +1,16 @@
|
|||
class ToolNodeError(ValueError):
|
||||
"""Base exception for tool node errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ToolParameterError(ToolNodeError):
|
||||
"""Exception raised for errors in tool parameters."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ToolFileError(ToolNodeError):
|
||||
"""Exception raised for errors related to tool files."""
|
||||
|
||||
pass
|
|
@ -13,6 +13,7 @@ class NodeType(StrEnum):
|
|||
QUESTION_CLASSIFIER = "question-classifier"
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
DATASOURCE = "datasource"
|
||||
VARIABLE_AGGREGATOR = "variable-aggregator"
|
||||
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
|
||||
LOOP = "loop"
|
||||
|
|
|
@ -36,7 +36,11 @@ from core.variables.variables import (
|
|||
StringVariable,
|
||||
Variable,
|
||||
)
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
ENVIRONMENT_VARIABLE_NODE_ID,
|
||||
PIPELINE_VARIABLE_NODE_ID,
|
||||
)
|
||||
|
||||
|
||||
class InvalidSelectorError(ValueError):
|
||||
|
@ -74,6 +78,10 @@ def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Va
|
|||
raise VariableError("missing name")
|
||||
return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
|
||||
|
||||
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
if not mapping.get("name"):
|
||||
raise VariableError("missing name")
|
||||
return _build_variable_from_mapping(mapping=mapping, selector=[PIPELINE_VARIABLE_NODE_ID, mapping["name"]])
|
||||
|
||||
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
|
||||
"""
|
||||
|
|
|
@ -40,6 +40,13 @@ conversation_variable_fields = {
|
|||
"description": fields.String,
|
||||
}
|
||||
|
||||
pipeline_variable_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"value_type": fields.String(attribute="value_type.value"),
|
||||
"value": fields.Raw,
|
||||
}
|
||||
|
||||
workflow_fields = {
|
||||
"id": fields.String,
|
||||
"graph": fields.Raw(attribute="graph_dict"),
|
||||
|
@ -55,6 +62,10 @@ workflow_fields = {
|
|||
"tool_published": fields.Boolean,
|
||||
"environment_variables": fields.List(EnvironmentVariableField()),
|
||||
"conversation_variables": fields.List(fields.Nested(conversation_variable_fields)),
|
||||
"pipeline_variables": fields.Dict(
|
||||
keys=fields.String,
|
||||
values=fields.List(fields.Nested(pipeline_variable_fields)),
|
||||
),
|
||||
}
|
||||
|
||||
workflow_partial_fields = {
|
||||
|
|
|
@ -130,6 +130,9 @@ class Workflow(Base):
|
|||
_conversation_variables: Mapped[str] = mapped_column(
|
||||
"conversation_variables", db.Text, nullable=False, server_default="{}"
|
||||
)
|
||||
_pipeline_variables: Mapped[str] = mapped_column(
|
||||
"conversation_variables", db.Text, nullable=False, server_default="{}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
|
@ -343,6 +346,24 @@ class Workflow(Base):
|
|||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def pipeline_variables(self) -> dict[str, Sequence[Variable]]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._pipeline_variables is None:
|
||||
self._pipeline_variables = "{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._pipeline_variables)
|
||||
results = {}
|
||||
for k, v in variables_dict.items():
|
||||
results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()]
|
||||
return results
|
||||
|
||||
@pipeline_variables.setter
|
||||
def pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None:
|
||||
self._pipeline_variables = json.dumps(
|
||||
{k: {item.name: item.model_dump() for item in v} for k, v in values.items()},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
class WorkflowRunStatus(StrEnum):
|
||||
"""
|
||||
|
|
|
@ -16,12 +16,13 @@ from core.workflow.nodes.enums import NodeType
|
|||
from core.workflow.nodes.event.types import NodeEvent
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.db import db
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore
|
||||
from models.workflow import Workflow, WorkflowNodeExecution, WorkflowType
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.workflow_service import DraftWorkflowDeletionError
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
|
||||
|
||||
|
||||
|
@ -186,6 +187,7 @@ class RagPipelineService:
|
|||
account: Account,
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
pipeline_variables: dict[str, Sequence[Variable]],
|
||||
) -> Workflow:
|
||||
"""
|
||||
Sync draft workflow
|
||||
|
@ -212,6 +214,7 @@ class RagPipelineService:
|
|||
created_by=account.id,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
pipeline_variables=pipeline_variables,
|
||||
)
|
||||
db.session.add(workflow)
|
||||
# update draft workflow if found
|
||||
|
@ -222,7 +225,7 @@ class RagPipelineService:
|
|||
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow.environment_variables = environment_variables
|
||||
workflow.conversation_variables = conversation_variables
|
||||
|
||||
workflow.pipeline_variables = pipeline_variables
|
||||
# commit db session changes
|
||||
db.session.commit()
|
||||
|
||||
|
@ -343,6 +346,41 @@ class RagPipelineService:
|
|||
|
||||
return workflow_node_execution
|
||||
|
||||
def run_datasource_workflow_node(
|
||||
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Run published workflow datasource
|
||||
"""
|
||||
# fetch published workflow by app_model
|
||||
published_workflow = self.get_published_workflow(pipeline=pipeline)
|
||||
if not published_workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
# run draft workflow node
|
||||
start_at = time.perf_counter()
|
||||
|
||||
workflow_node_execution = self._handle_node_run_result(
|
||||
getter=lambda: WorkflowEntry.single_step_run(
|
||||
workflow=published_workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
user_id=account.id,
|
||||
),
|
||||
start_at=start_at,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
workflow_node_execution.app_id = pipeline.id
|
||||
workflow_node_execution.created_by = account.id
|
||||
workflow_node_execution.workflow_id = published_workflow.id
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def run_free_workflow_node(
|
||||
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
|
||||
) -> WorkflowNodeExecution:
|
||||
|
@ -573,3 +611,21 @@ class RagPipelineService:
|
|||
|
||||
session.delete(workflow)
|
||||
return True
|
||||
|
||||
def get_second_step_parameters(self, pipeline: Pipeline, datasource_provider: str) -> dict:
|
||||
"""
|
||||
Get second step parameters of rag pipeline
|
||||
"""
|
||||
|
||||
workflow = self.get_published_workflow(pipeline=pipeline)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
# get second step node
|
||||
pipeline_variables = workflow.pipeline_variables
|
||||
if not pipeline_variables:
|
||||
return {}
|
||||
# get datasource provider
|
||||
datasource_provider_variables = pipeline_variables.get(datasource_provider, [])
|
||||
shared_variables = pipeline_variables.get("shared", [])
|
||||
return datasource_provider_variables + shared_variables
|
Loading…
Reference in New Issue