This commit is contained in:
jyong 2025-04-17 15:07:23 +08:00
parent 9f8e05d9f0
commit 5c4bf2a9e4
49 changed files with 5609 additions and 122 deletions

View File

@ -1,5 +1,6 @@
from flask import Blueprint
from .datasets.rag_pipeline import data_source
from libs.external_api import ExternalApi
from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
@ -75,7 +76,6 @@ from .billing import billing, compliance
# Import datasets controllers
from .datasets import (
data_source,
datasets,
datasets_document,
datasets_segments,

View File

@ -0,0 +1,136 @@
import json
import logging
from typing import cast
from flask import abort, request
from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from configs import dify_config
from controllers.console import api
from controllers.console.app.error import (
ConversationCompletedError,
DraftWorkflowNotExist,
DraftWorkflowNotSync,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
setup_required,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from factories import variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
from libs.helper import TimestampField
from libs.login import current_user, login_required
from models import App
from models.account import Account
from models.dataset import Pipeline
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
logger = logging.getLogger(__name__)
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class PipelineTemplateListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self):
type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"])
language = request.args.get("language", default="en-US", type=str)
# get pipeline templates
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
return pipeline_templates, 200
class PipelineTemplateDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self, pipeline_id: str):
pipeline_template = RagPipelineService.get_pipeline_template_detail(pipeline_id)
return pipeline_template, 200
class CustomizedPipelineTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def patch(self, template_id: str):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity(**args)
pipeline_template = RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return pipeline_template, 200
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def delete(self, template_id: str):
RagPipelineService.delete_customized_pipeline_template(template_id)
return 200
api.add_resource(
PipelineTemplateListApi,
"/rag/pipeline/templates",
)
api.add_resource(
PipelineTemplateDetailApi,
"/rag/pipeline/templates/<string:pipeline_id>",
)
api.add_resource(
CustomizedPipelineTemplateApi,
"/rag/pipeline/templates/<string:template_id>",
)

View File

@ -15,11 +15,9 @@ from controllers.console.app.error import (
DraftWorkflowNotExist,
DraftWorkflowNotSync,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
setup_required,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
@ -32,96 +30,17 @@ from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
from libs.helper import TimestampField
from libs.login import current_user, login_required
from models import App
from models.account import Account
from models.dataset import Pipeline
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
logger = logging.getLogger(__name__)
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class PipelineTemplateListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self):
type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"])
language = request.args.get("language", default="en-US", type=str)
# get pipeline templates
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
return pipeline_templates, 200
class PipelineTemplateDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self, pipeline_id: str):
pipeline_template = RagPipelineService.get_pipeline_template_detail(pipeline_id)
return pipeline_template, 200
class CustomizedPipelineTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def patch(self, template_id: str):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity(**args)
pipeline_template = RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return pipeline_template, 200
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def delete(self, template_id: str):
RagPipelineService.delete_customized_pipeline_template(template_id)
return 200
class DraftRagPipelineApi(Resource):
@setup_required
@login_required
@ -130,7 +49,7 @@ class DraftRagPipelineApi(Resource):
@marshal_with(workflow_fields)
def get(self, pipeline: Pipeline):
"""
Get draft workflow
Get draft rag pipeline's workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
@ -167,6 +86,7 @@ class DraftRagPipelineApi(Resource):
parser.add_argument("hash", type=str, required=False, location="json")
parser.add_argument("environment_variables", type=list, required=False, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json")
parser.add_argument("pipeline_variables", type=dict, required=False, location="json")
args = parser.parse_args()
elif "text/plain" in content_type:
try:
@ -183,6 +103,7 @@ class DraftRagPipelineApi(Resource):
"hash": data.get("hash"),
"environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"),
"pipeline_variables": data.get("pipeline_variables"),
}
except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400
@ -192,8 +113,6 @@ class DraftRagPipelineApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
workflow_service = WorkflowService()
try:
environment_variables_list = args.get("environment_variables") or []
environment_variables = [
@ -203,6 +122,11 @@ class DraftRagPipelineApi(Resource):
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
pipeline_variables_list = args.get("pipeline_variables") or {}
pipeline_variables = {
k: [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in v]
for k, v in pipeline_variables_list.items()
}
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline,
@ -212,6 +136,7 @@ class DraftRagPipelineApi(Resource):
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
pipeline_variables=pipeline_variables,
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
@ -263,8 +188,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def post(self, app_model: App, node_id: str):
@get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str):
"""
Run draft workflow loop node
"""
@ -281,7 +206,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
try:
response = AppGenerateService.generate_single_loop(
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
pipeline=pipeline, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
@ -300,8 +225,8 @@ class DraftRagPipelineRunApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def post(self, app_model: App):
@get_rag_pipeline
def post(self, pipeline: Pipeline):
"""
Run draft workflow
"""
@ -319,7 +244,7 @@ class DraftRagPipelineRunApi(Resource):
try:
response = AppGenerateService.generate(
app_model=app_model,
pipeline=pipeline,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
@ -330,32 +255,45 @@ class DraftRagPipelineRunApi(Resource):
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
class RagPipelineTaskStopApi(Resource):
class RagPipelineDatasourceNodeRunApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App, task_id: str):
@get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str):
"""
Stop workflow task
Run rag pipeline datasource
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
if not isinstance(current_user, Account):
raise Forbidden()
return {"result": "success"}
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
inputs = args.get("inputs")
if inputs == None:
raise ValueError("missing inputs")
rag_pipeline_service = RagPipelineService()
workflow_node_execution = rag_pipeline_service.run_datasource_workflow_node(
pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user
)
return workflow_node_execution
class RagPipelineNodeRunApi(Resource):
class RagPipelineDraftNodeRunApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@get_rag_pipeline
@marshal_with(workflow_run_node_execution_fields)
def post(self, app_model: App, node_id: str):
def post(self, pipeline: Pipeline, node_id: str):
"""
Run draft workflow node
"""
@ -374,13 +312,29 @@ class RagPipelineNodeRunApi(Resource):
if inputs == None:
raise ValueError("missing inputs")
workflow_service = WorkflowService()
workflow_node_execution = workflow_service.run_draft_workflow_node(
app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user
rag_pipeline_service = RagPipelineService()
workflow_node_execution = rag_pipeline_service.run_draft_workflow_node(
pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user
)
return workflow_node_execution
class RagPipelineTaskStopApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_rag_pipeline
def post(self, pipeline: Pipeline, task_id: str):
"""
Stop workflow task
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
return {"result": "success"}
class PublishedRagPipelineApi(Resource):
@setup_required
@ -695,6 +649,25 @@ class RagPipelineByIdApi(Resource):
return None, 204
class RagPipelineSecondStepApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_rag_pipeline
def get(self, pipeline: Pipeline):
"""
Get second step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
datasource_provider = request.args.get("datasource_provider", required=True, type=str)
rag_pipeline_service = RagPipelineService()
return rag_pipeline_service.get_second_step_parameters(pipeline=pipeline,
datasource_provider=datasource_provider
)
api.add_resource(
DraftRagPipelineApi,
@ -713,9 +686,13 @@ api.add_resource(
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop",
)
api.add_resource(
RagPipelineNodeRunApi,
RagPipelineDraftNodeRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run",
)
api.add_resource(
RagPipelinePublishedNodeRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/nodes/<string:node_id>/run",
)
api.add_resource(
RagPipelineDraftRunIterationNodeApi,
@ -751,15 +728,3 @@ api.add_resource(
"/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>",
)
api.add_resource(
PipelineTemplateListApi,
"/rag/pipeline/templates",
)
api.add_resource(
PipelineTemplateDetailApi,
"/rag/pipeline/templates/<string:pipeline_id>",
)
api.add_resource(
CustomizedPipelineTemplateApi,
"/rag/pipeline/templates/<string:template_id>",
)

View File

@ -0,0 +1,222 @@
from abc import ABC, abstractmethod
from collections.abc import Generator
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from models.model import File
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import (
ToolEntity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
class Tool(ABC):
"""
The base class of a tool
"""
entity: ToolEntity
runtime: ToolRuntime
def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
self.entity = entity
self.runtime = runtime
def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
"""
fork a new tool with metadata
:return: the new tool
"""
return self.__class__(
entity=self.entity.model_copy(),
runtime=runtime,
)
@abstractmethod
def tool_provider_type(self) -> ToolProviderType:
"""
get the tool provider type
:return: the tool provider type
"""
def invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage]:
if self.runtime and self.runtime.runtime_parameters:
tool_parameters.update(self.runtime.runtime_parameters)
# try parse tool parameters into the correct type
tool_parameters = self._transform_tool_parameters_type(tool_parameters)
result = self._invoke(
user_id=user_id,
tool_parameters=tool_parameters,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)
if isinstance(result, ToolInvokeMessage):
def single_generator() -> Generator[ToolInvokeMessage, None, None]:
yield result
return single_generator()
elif isinstance(result, list):
def generator() -> Generator[ToolInvokeMessage, None, None]:
yield from result
return generator()
else:
return result
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
"""
Transform tool parameters type
"""
# Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
result = deepcopy(tool_parameters)
for parameter in self.entity.parameters or []:
if parameter.name in tool_parameters:
result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
return result
@abstractmethod
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]:
pass
def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
"""
get the runtime parameters
interface for developer to dynamic change the parameters of a tool depends on the variables pool
:return: the runtime parameters
"""
return self.entity.parameters
def get_merged_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
"""
get merged runtime parameters
:return: merged runtime parameters
"""
parameters = self.entity.parameters
parameters = parameters.copy()
user_parameters = self.get_runtime_parameters() or []
user_parameters = user_parameters.copy()
# override parameters
for parameter in user_parameters:
# check if parameter in tool parameters
for tool_parameter in parameters:
if tool_parameter.name == parameter.name:
# override parameter
tool_parameter.type = parameter.type
tool_parameter.form = parameter.form
tool_parameter.required = parameter.required
tool_parameter.default = parameter.default
tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description
break
else:
# add new parameter
parameters.append(parameter)
return parameters
def create_image_message(
self,
image: str,
) -> ToolInvokeMessage:
"""
create an image message
:param image: the url of the image
:return: the image message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
)
def create_file_message(self, file: "File") -> ToolInvokeMessage:
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.FILE,
message=ToolInvokeMessage.FileMessage(),
meta={"file": file},
)
def create_link_message(self, link: str) -> ToolInvokeMessage:
"""
create a link message
:param link: the url of the link
:return: the link message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link)
)
def create_text_message(self, text: str) -> ToolInvokeMessage:
"""
create a text message
:param text: the text
:return: the text message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT,
message=ToolInvokeMessage.TextMessage(text=text),
)
def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage:
"""
create a blob message
:param blob: the blob
:param meta: the meta info of blob object
:return: the blob message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB,
message=ToolInvokeMessage.BlobMessage(blob=blob),
meta=meta,
)
def create_json_message(self, object: dict) -> ToolInvokeMessage:
"""
create a json message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
)

View File

@ -0,0 +1,109 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any
from core.entities.provider_entities import ProviderConfig
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolProviderEntity,
ToolProviderType,
)
from core.tools.errors import ToolProviderCredentialValidationError
class ToolProviderController(ABC):
entity: ToolProviderEntity
def __init__(self, entity: ToolProviderEntity) -> None:
self.entity = entity
def get_credentials_schema(self) -> list[ProviderConfig]:
"""
returns the credentials schema of the provider
:return: the credentials schema
"""
return deepcopy(self.entity.credentials_schema)
@abstractmethod
def get_tool(self, tool_name: str) -> Tool:
"""
returns a tool that the provider can provide
:return: tool
"""
pass
@property
def provider_type(self) -> ToolProviderType:
"""
returns the type of the provider
:return: type of the provider
"""
return ToolProviderType.BUILT_IN
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
"""
validate the format of the credentials of the provider and set the default value if needed
:param credentials: the credentials of the tool
"""
credentials_schema = dict[str, ProviderConfig]()
if credentials_schema is None:
return
for credential in self.entity.credentials_schema:
credentials_schema[credential.name] = credential
credentials_need_to_validate: dict[str, ProviderConfig] = {}
for credential_name in credentials_schema:
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
for credential_name in credentials:
if credential_name not in credentials_need_to_validate:
raise ToolProviderCredentialValidationError(
f"credential {credential_name} not found in provider {self.entity.identity.name}"
)
# check type
credential_schema = credentials_need_to_validate[credential_name]
if not credential_schema.required and credentials[credential_name] is None:
continue
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
elif credential_schema.type == ProviderConfig.Type.SELECT:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
options = credential_schema.options
if not isinstance(options, list):
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
if credentials[credential_name] not in [x.value for x in options]:
raise ToolProviderCredentialValidationError(
f"credential {credential_name} should be one of {options}"
)
credentials_need_to_validate.pop(credential_name)
for credential_name in credentials_need_to_validate:
credential_schema = credentials_need_to_validate[credential_name]
if credential_schema.required:
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
# the credential is not set currently, set the default value if needed
if credential_schema.default is not None:
default_value = credential_schema.default
# parse default value into the correct type
if credential_schema.type in {
ProviderConfig.Type.SECRET_INPUT,
ProviderConfig.Type.TEXT_INPUT,
ProviderConfig.Type.SELECT,
}:
default_value = str(default_value)
credentials[credential_name] = default_value

View File

@ -0,0 +1,36 @@
from typing import Any, Optional
from openai import BaseModel
from pydantic import Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ToolInvokeFrom
class ToolRuntime(BaseModel):
"""
Meta data of a tool call processing
"""
tenant_id: str
tool_id: Optional[str] = None
invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None
credentials: dict[str, Any] = Field(default_factory=dict)
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
class FakeToolRuntime(ToolRuntime):
"""
Fake tool runtime for testing
"""
def __init__(self):
super().__init__(
tenant_id="fake_tenant_id",
tool_id="fake_tool_id",
invoke_from=InvokeFrom.DEBUGGER,
tool_invoke_from=ToolInvokeFrom.AGENT,
credentials={},
runtime_parameters={},
)

View File

View File

@ -0,0 +1,72 @@
from typing import Literal, Optional
from pydantic import BaseModel, Field, field_validator
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
class ToolApiEntity(BaseModel):
author: str
name: str # identifier
label: I18nObject # label
description: I18nObject
parameters: Optional[list[ToolParameter]] = None
labels: list[str] = Field(default_factory=list)
output_schema: Optional[dict] = None
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
class ToolProviderApiEntity(BaseModel):
id: str
author: str
name: str # identifier
description: I18nObject
icon: str | dict
label: I18nObject # label
type: ToolProviderType
masked_credentials: Optional[dict] = None
original_credentials: Optional[dict] = None
is_team_authorization: bool = False
allow_delete: bool = True
plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool")
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool")
tools: list[ToolApiEntity] = Field(default_factory=list)
labels: list[str] = Field(default_factory=list)
@field_validator("tools", mode="before")
@classmethod
def convert_none_to_empty_list(cls, v):
return v if v is not None else []
def to_dict(self) -> dict:
# -------------
# overwrite tool parameter types for temp fix
tools = jsonable_encoder(self.tools)
for tool in tools:
if tool.get("parameters"):
for parameter in tool.get("parameters"):
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
parameter["type"] = "files"
# -------------
return {
"id": self.id,
"author": self.author,
"name": self.name,
"plugin_id": self.plugin_id,
"plugin_unique_identifier": self.plugin_unique_identifier,
"description": self.description.to_dict(),
"icon": self.icon,
"label": self.label.to_dict(),
"type": self.type.value,
"team_credentials": self.masked_credentials,
"is_team_authorization": self.is_team_authorization,
"allow_delete": self.allow_delete,
"tools": tools,
"labels": self.labels,
}

View File

@ -0,0 +1,23 @@
from typing import Optional
from pydantic import BaseModel, Field
class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
en_US: str
zh_Hans: Optional[str] = Field(default=None)
pt_BR: Optional[str] = Field(default=None)
ja_JP: Optional[str] = Field(default=None)
def __init__(self, **data):
super().__init__(**data)
self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
def to_dict(self) -> dict:
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

View File

@ -0,0 +1 @@
TOOL_SELECTOR_MODEL_IDENTITY = "__dify__tool_selector__"

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,29 @@
from typing import Optional
from pydantic import BaseModel
from core.tools.entities.tool_entities import ToolParameter
class ApiToolBundle(BaseModel):
"""
This class is used to store the schema information of an api based tool.
such as the url, the method, the parameters, etc.
"""
# server_url
server_url: str
# method
method: str
# summary
summary: Optional[str] = None
# operation_id
operation_id: Optional[str] = None
# parameters
parameters: Optional[list[ToolParameter]] = None
# author
author: str
# icon
icon: Optional[str] = None
# openapi operation
openapi: dict

View File

@ -0,0 +1,427 @@
import base64
import enum
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.parameters import (
PluginParameter,
PluginParameterOption,
PluginParameterType,
as_normal_type,
cast_parameter_value,
init_frontend_parameter,
)
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
class ToolLabelEnum(Enum):
SEARCH = "search"
IMAGE = "image"
VIDEOS = "videos"
WEATHER = "weather"
FINANCE = "finance"
DESIGN = "design"
TRAVEL = "travel"
SOCIAL = "social"
NEWS = "news"
MEDICAL = "medical"
PRODUCTIVITY = "productivity"
EDUCATION = "education"
BUSINESS = "business"
ENTERTAINMENT = "entertainment"
UTILITIES = "utilities"
OTHER = "other"
class ToolProviderType(enum.StrEnum):
"""
Enum class for tool provider
"""
PLUGIN = "plugin"
BUILT_IN = "builtin"
WORKFLOW = "workflow"
API = "api"
APP = "app"
DATASET_RETRIEVAL = "dataset-retrieval"
@classmethod
def value_of(cls, value: str) -> "ToolProviderType":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid mode value {value}")
class ApiProviderSchemaType(Enum):
"""
Enum class for api provider schema type.
"""
OPENAPI = "openapi"
SWAGGER = "swagger"
OPENAI_PLUGIN = "openai_plugin"
OPENAI_ACTIONS = "openai_actions"
@classmethod
def value_of(cls, value: str) -> "ApiProviderSchemaType":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid mode value {value}")
class ApiProviderAuthType(Enum):
"""
Enum class for api provider auth type.
"""
NONE = "none"
API_KEY = "api_key"
@classmethod
def value_of(cls, value: str) -> "ApiProviderAuthType":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid mode value {value}")
class ToolInvokeMessage(BaseModel):
class TextMessage(BaseModel):
text: str
class JsonMessage(BaseModel):
json_object: dict
class BlobMessage(BaseModel):
blob: bytes
class FileMessage(BaseModel):
pass
class VariableMessage(BaseModel):
variable_name: str = Field(..., description="The name of the variable")
variable_value: Any = Field(..., description="The value of the variable")
stream: bool = Field(default=False, description="Whether the variable is streamed")
@model_validator(mode="before")
@classmethod
def transform_variable_value(cls, values) -> Any:
"""
Only basic types and lists are allowed.
"""
value = values.get("variable_value")
if not isinstance(value, dict | list | str | int | float | bool):
raise ValueError("Only basic types and lists are allowed.")
# if stream is true, the value must be a string
if values.get("stream"):
if not isinstance(value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
return values
@field_validator("variable_name", mode="before")
@classmethod
def transform_variable_name(cls, value: str) -> str:
"""
The variable name must be a string.
"""
if value in {"json", "text", "files"}:
raise ValueError(f"The variable name '{value}' is reserved.")
return value
class LogMessage(BaseModel):
class LogStatus(Enum):
START = "start"
ERROR = "error"
SUCCESS = "success"
id: str
label: str = Field(..., description="The label of the log")
parent_id: Optional[str] = Field(default=None, description="Leave empty for root log")
error: Optional[str] = Field(default=None, description="The error message")
status: LogStatus = Field(..., description="The status of the log")
data: Mapping[str, Any] = Field(..., description="Detailed log data")
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")
class MessageType(Enum):
TEXT = "text"
IMAGE = "image"
LINK = "link"
BLOB = "blob"
JSON = "json"
IMAGE_LINK = "image_link"
BINARY_LINK = "binary_link"
VARIABLE = "variable"
FILE = "file"
LOG = "log"
type: MessageType = MessageType.TEXT
"""
plain text, image url or link url
"""
message: JsonMessage | TextMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage
meta: dict[str, Any] | None = None
@field_validator("message", mode="before")
@classmethod
def decode_blob_message(cls, v):
if isinstance(v, dict) and "blob" in v:
try:
v["blob"] = base64.b64decode(v["blob"])
except Exception:
pass
return v
@field_serializer("message")
def serialize_message(self, v):
if isinstance(v, self.BlobMessage):
return {"blob": base64.b64encode(v.blob).decode("utf-8")}
return v
class ToolInvokeMessageBinary(BaseModel):
mimetype: str = Field(..., description="The mimetype of the binary")
url: str = Field(..., description="The url of the binary")
file_var: Optional[dict[str, Any]] = None
class ToolParameter(PluginParameter):
"""
Overrides type
"""
class ToolParameterType(enum.StrEnum):
"""
removes TOOLS_SELECTOR from PluginParameterType
"""
STRING = PluginParameterType.STRING.value
NUMBER = PluginParameterType.NUMBER.value
BOOLEAN = PluginParameterType.BOOLEAN.value
SELECT = PluginParameterType.SELECT.value
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
FILE = PluginParameterType.FILE.value
FILES = PluginParameterType.FILES.value
APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
# deprecated, should not use.
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
def as_normal_type(self):
return as_normal_type(self)
def cast_value(self, value: Any):
return cast_parameter_value(self, value)
class ToolParameterForm(Enum):
SCHEMA = "schema" # should be set while adding tool
FORM = "form" # should be set before invoking tool
LLM = "llm" # will be set by LLM
type: ToolParameterType = Field(..., description="The type of the parameter")
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
llm_description: Optional[str] = None
@classmethod
def get_simple_instance(
cls,
name: str,
llm_description: str,
typ: ToolParameterType,
required: bool,
options: Optional[list[str]] = None,
) -> "ToolParameter":
"""
get a simple tool parameter
:param name: the name of the parameter
:param llm_description: the description presented to the LLM
:param typ: the type of the parameter
:param required: if the parameter is required
:param options: the options of the parameter
"""
# convert options to ToolParameterOption
# FIXME fix the type error
if options:
option_objs = [
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
for option in options
]
else:
option_objs = []
return cls(
name=name,
label=I18nObject(en_US="", zh_Hans=""),
placeholder=None,
human_description=I18nObject(en_US="", zh_Hans=""),
type=typ,
form=cls.ToolParameterForm.LLM,
llm_description=llm_description,
required=required,
options=option_objs,
)
def init_frontend_parameter(self, value: Any):
return init_frontend_parameter(self, self.type, value)
class ToolProviderIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
description: I18nObject = Field(..., description="The description of the tool")
icon: str = Field(..., description="The icon of the tool")
label: I18nObject = Field(..., description="The label of the tool")
tags: Optional[list[ToolLabelEnum]] = Field(
default=[],
description="The tags of the tool",
)
class ToolIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
label: I18nObject = Field(..., description="The label of the tool")
provider: str = Field(..., description="The provider of the tool")
icon: Optional[str] = None
class ToolDescription(BaseModel):
human: I18nObject = Field(..., description="The description presented to the user")
llm: str = Field(..., description="The description presented to the LLM")
class ToolEntity(BaseModel):
identity: ToolIdentity
parameters: list[ToolParameter] = Field(default_factory=list)
description: Optional[ToolDescription] = None
output_schema: Optional[dict] = None
has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@field_validator("parameters", mode="before")
@classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
return v or []
class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity
plugin_id: Optional[str] = None
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
class ToolProviderEntityWithPlugin(ToolProviderEntity):
tools: list[ToolEntity] = Field(default_factory=list)
class WorkflowToolParameterConfiguration(BaseModel):
"""
Workflow tool configuration
"""
name: str = Field(..., description="The name of the parameter")
description: str = Field(..., description="The description of the parameter")
form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
class ToolInvokeMeta(BaseModel):
"""
Tool invoke meta
"""
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: Optional[str] = None
tool_config: Optional[dict] = None
@classmethod
def empty(cls) -> "ToolInvokeMeta":
"""
Get an empty instance of ToolInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
def error_instance(cls, error: str) -> "ToolInvokeMeta":
"""
Get an instance of ToolInvokeMeta with error
"""
return cls(time_cost=0.0, error=error, tool_config={})
def to_dict(self) -> dict:
return {
"time_cost": self.time_cost,
"error": self.error,
"tool_config": self.tool_config,
}
class ToolLabel(BaseModel):
"""
Tool label
"""
name: str = Field(..., description="The name of the tool")
label: I18nObject = Field(..., description="The label of the tool")
icon: str = Field(..., description="The icon of the tool")
class ToolInvokeFrom(Enum):
"""
Enum class for tool invoke
"""
WORKFLOW = "workflow"
AGENT = "agent"
PLUGIN = "plugin"
class ToolSelector(BaseModel):
dify_model_identity: str = TOOL_SELECTOR_MODEL_IDENTITY
class Parameter(BaseModel):
name: str = Field(..., description="The name of the parameter")
type: ToolParameter.ToolParameterType = Field(..., description="The type of the parameter")
required: bool = Field(..., description="Whether the parameter is required")
description: str = Field(..., description="The description of the parameter")
default: Optional[Union[int, float, str]] = None
options: Optional[list[PluginParameterOption]] = None
provider_id: str = Field(..., description="The id of the provider")
tool_name: str = Field(..., description="The name of the tool")
tool_description: str = Field(..., description="The description of the tool")
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
tool_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm")
def to_plugin_parameter(self) -> dict[str, Any]:
return self.model_dump()

View File

@ -0,0 +1,111 @@
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolLabel, ToolLabelEnum
ICONS = {
ToolLabelEnum.SEARCH: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M7.33398 1.3335C10.646 1.3335 13.334 4.0215 13.334 7.3335C13.334 10.6455 10.646 13.3335 7.33398 13.3335C4.02198 13.3335 1.33398 10.6455 1.33398 7.3335C1.33398 4.0215 4.02198 1.3335 7.33398 1.3335ZM7.33398 12.0002C9.91232 12.0002 12.0007 9.91183 12.0007 7.3335C12.0007 4.75516 9.91232 2.66683 7.33398 2.66683C4.75565 2.66683 2.66732 4.75516 2.66732 7.3335C2.66732 9.91183 4.75565 12.0002 7.33398 12.0002ZM12.9909 12.0476L14.8764 13.9332L13.9337 14.876L12.0481 12.9904L12.9909 12.0476Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.IMAGE: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M13.0514 9.71752L10.4718 7.13792C10.2115 6.87752 9.78932 6.87752 9.52898 7.13792L4.57721 12.0897C3.4097 11.1113 2.66732 9.64232 2.66732 7.99992C2.66732 5.0544 5.05513 2.66659 8.00065 2.66659C10.9462 2.66659 13.334 5.0544 13.334 7.99992C13.334 8.60085 13.2346 9.17852 13.0514 9.71752ZM5.72683 12.8257L10.0004 8.55212L12.4259 10.9777C11.4668 12.4001 9.84152 13.3331 8.00038 13.3331C7.18632 13.3331 6.41628 13.1511 5.72683 12.8257ZM8.00065 14.6666C11.6825 14.6666 14.6673 11.6818 14.6673 7.99992C14.6673 4.31802 11.6825 1.33325 8.00065 1.33325C4.31875 1.33325 1.33398 4.31802 1.33398 7.99992C1.33398 11.6818 4.31875 14.6666 8.00065 14.6666ZM7.33398 6.66658C7.33398 7.40299 6.73705 7.99992 6.00065 7.99992C5.26427 7.99992 4.66732 7.40299 4.66732 6.66658C4.66732 5.9302 5.26427 5.33325 6.00065 5.33325C6.73705 5.33325 7.33398 5.9302 7.33398 6.66658Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.VIDEOS: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.00065 13.3333H13.334V14.6666H8.00065C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992C14.6673 9.50072 14.1714 10.8857 13.3345 11.9999H11.5284C12.6356 11.0227 13.334 9.59285 13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333ZM8.00065 6.66658C7.26425 6.66658 6.66732 6.06963 6.66732 5.33325C6.66732 4.59687 7.26425 3.99992 8.00065 3.99992C8.73705 3.99992 9.33398 4.59687 9.33398 5.33325C9.33398 6.06963 8.73705 6.66658 8.00065 6.66658ZM5.33398 9.33325C4.5976 9.33325 4.00065 8.73632 4.00065 7.99992C4.00065 7.26352 4.5976 6.66658 5.33398 6.66658C6.07036 6.66658 6.66732 7.26352 6.66732 7.99992C6.66732 8.73632 6.07036 9.33325 5.33398 9.33325ZM10.6673 9.33325C9.93092 9.33325 9.33398 8.73632 9.33398 7.99992C9.33398 7.26352 9.93092 6.66658 10.6673 6.66658C11.4037 6.66658 12.0007 7.26352 12.0007 7.99992C12.0007 8.73632 11.4037 9.33325 10.6673 9.33325ZM8.00065 11.9999C7.26425 11.9999 6.66732 11.403 6.66732 10.6666C6.66732 9.93018 7.26425 9.33325 8.00065 9.33325C8.73705 9.33325 9.33398 9.93018 9.33398 10.6666C9.33398 11.403 8.73705 11.9999 8.00065 11.9999Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.WEATHER: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M6.6553 3.37344C7.42088 2.1484 8.78162 1.3335 10.3327 1.3335C12.7259 1.3335 14.666 3.2736 14.666 5.66683C14.666 6.38704 14.4903 7.06623 14.1794 7.66383C14.8894 8.3325 15.3327 9.28123 15.3327 10.3335C15.3327 12.3586 13.6911 14.0002 11.666 14.0002H5.99935C3.05383 14.0002 0.666016 11.6124 0.666016 8.66683C0.666016 5.72131 3.05383 3.3335 5.99935 3.3335C6.22143 3.3335 6.44034 3.34707 6.6553 3.37344ZM8.03628 3.73629C9.37768 4.29108 10.4435 5.37735 10.9711 6.73256C11.1961 6.68943 11.4284 6.66683 11.666 6.66683C12.1561 6.66683 12.6237 6.76296 13.0511 6.93743C13.2317 6.55162 13.3327 6.12102 13.3327 5.66683C13.3327 4.00998 11.9895 2.66683 10.3327 2.66683C9.41115 2.66683 8.58662 3.08236 8.03628 3.73629ZM11.666 12.6668C12.9547 12.6668 13.9993 11.6222 13.9993 10.3335C13.9993 9.04483 12.9547 8.00016 11.666 8.00016C11.013 8.00016 10.4227 8.26836 9.99922 8.70063C9.99928 8.68936 9.99935 8.6781 9.99935 8.66683C9.99935 6.45769 8.20848 4.66683 5.99935 4.66683C3.79021 4.66683 1.99935 6.45769 1.99935 8.66683C1.99935 10.876 3.79021 12.6668 5.99935 12.6668H11.666Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.FINANCE: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.00262 14.6685C4.32071 14.6685 1.33594 11.6838 1.33594 8.00184C1.33594 4.31997 4.32071 1.33521 8.00262 1.33521C11.6845 1.33521 14.6693 4.31997 14.6693 8.00184C14.6693 11.6838 11.6845 14.6685 8.00262 14.6685ZM8.00262 13.3352C10.9482 13.3352 13.336 10.9474 13.336 8.00184C13.336 5.05635 10.9482 2.66854 8.00262 2.66854C5.05708 2.66854 2.66927 5.05635 2.66927 8.00184C2.66927 10.9474 5.05708 13.3352 8.00262 13.3352ZM5.66927 9.33517H9.33595C9.52002 9.33517 9.66928 9.18597 9.66928 9.00184C9.66928 8.81777 9.52002 8.66851 9.33595 8.66851H6.66928C5.7488 8.66851 5.0026 7.92237 5.0026 7.00184C5.0026 6.08139 5.7488 5.33521 6.66928 5.33521H7.33595V4.00187H8.66928V5.33521H10.336V6.66851H6.66928C6.48518 6.66851 6.33594 6.81777 6.33594 7.00184C6.33594 7.18597 6.48518 7.33517 6.66928 7.33517H9.33595C10.2564 7.33517 11.0026 8.08137 11.0026 9.00184C11.0026 9.92237 10.2564 10.6685 9.33595 10.6685H8.66928V12.0018H7.33595V10.6685H5.66927V9.33517Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.DESIGN: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M4.70152 9.41416L3.2873 10.8284L5.17292 12.714L12.7154 5.17154L10.8298 3.28592L9.41557 4.70013L10.3584 5.64295L9.41557 6.58575L8.47277 5.64295L7.52997 6.58575L8.47277 7.52856L7.52997 8.47136L6.58713 7.52856L5.64433 8.47136L6.58713 9.41416L5.64433 10.357L4.70152 9.41416ZM11.3012 1.87171L14.1296 4.70013C14.39 4.96049 14.39 5.38259 14.1296 5.64295L5.64433 14.1282C5.38397 14.3886 4.96187 14.3886 4.70152 14.1282L1.87309 11.2998C1.61274 11.0394 1.61274 10.6174 1.87309 10.357L10.3584 1.87171C10.6187 1.61136 11.0408 1.61136 11.3012 1.87171ZM9.41557 12.2423L10.3584 11.2995L11.8534 12.7945H12.7962V11.8517L11.3012 10.3567L12.244 9.41383L14.0011 11.171V13.9999H11.1732L9.41557 12.2423ZM3.75861 6.58533L1.87299 4.69971C1.61265 4.43937 1.61265 4.01725 1.87299 3.75691L3.75861 1.87129C4.01896 1.61094 4.44107 1.61094 4.70142 1.87129L6.58704 3.75691L5.64423 4.69971L4.23002 3.2855L3.28721 4.22831L4.70142 5.64253L3.75861 6.58533Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.TRAVEL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M9.44839 2C9.80198 2 10.1411 2.14047 10.3912 2.39053L13.6101 5.60947C13.8602 5.85953 14.0007 6.19866 14.0007 6.55229V11.3333H15.334V12.6667L9.91652 12.6672C9.62032 13.8171 8.57638 14.6667 7.33398 14.6667C6.0916 14.6667 5.04766 13.8171 4.75146 12.6672L2.00065 12.6667C1.63246 12.6667 1.33398 12.3682 1.33398 12V3.33333C1.33398 2.59695 1.93094 2 2.66732 2H9.44839ZM7.33398 10.6667C6.5976 10.6667 6.00065 11.2636 6.00065 12C6.00065 12.7364 6.5976 13.3333 7.33398 13.3333C8.07038 13.3333 8.66732 12.7364 8.66732 12C8.66732 11.2636 8.07038 10.6667 7.33398 10.6667ZM9.44839 3.33333H2.66732V11.3333L4.75128 11.3335C5.04726 10.1833 6.09136 9.33333 7.33398 9.33333C8.57658 9.33333 9.62072 10.1833 9.91665 11.3335L12.6673 11.3333V6.55229L9.44839 3.33333ZM9.33398 4.66667V8.66667H4.00065V4.66667H9.33398ZM8.00065 6H5.33398V7.33333H8.00065V6Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.SOCIAL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333C9.09518 13.3333 10.1127 13.0035 10.9594 12.438L11.699 13.5475C10.6408 14.2545 9.36885 14.6666 8.00065 14.6666C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992V8.99992C14.6673 10.2886 13.6227 11.3333 12.334 11.3333C11.5312 11.3333 10.8231 10.9278 10.4032 10.3105C9.79678 10.9409 8.94452 11.3333 8.00065 11.3333C6.1597 11.3333 4.66732 9.84085 4.66732 7.99992C4.66732 6.15897 6.1597 4.66658 8.00065 4.66658C8.75118 4.66658 9.44378 4.91464 10.001 5.33325H11.334V8.99992C11.334 9.55219 11.7817 9.99992 12.334 9.99992C12.8863 9.99992 13.334 9.55219 13.334 8.99992V7.99992ZM8.00065 5.99992C6.89605 5.99992 6.00065 6.89532 6.00065 7.99992C6.00065 9.10452 6.89605 9.99992 8.00065 9.99992C9.10525 9.99992 10.0007 9.10452 10.0007 7.99992C10.0007 6.89532 9.10525 5.99992 8.00065 5.99992Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.NEWS: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M10.6673 13.3335V2.66683H2.66732V12.6668C2.66732 13.035 2.9658 13.3335 3.33398 13.3335H10.6673ZM12.6673 14.6668H3.33398C2.22942 14.6668 1.33398 13.7714 1.33398 12.6668V2.00016C1.33398 1.63198 1.63246 1.3335 2.00065 1.3335H11.334C11.7022 1.3335 12.0007 1.63198 12.0007 2.00016V6.66683H14.6673V12.6668C14.6673 13.7714 13.7719 14.6668 12.6673 14.6668ZM12.0007 8.00016V12.6668C12.0007 13.035 12.2991 13.3335 12.6673 13.3335C13.0355 13.3335 13.334 13.035 13.334 12.6668V8.00016H12.0007ZM4.00065 4.00016H8.00065V8.00016H4.00065V4.00016ZM5.33398 5.3335V6.66683H6.66732V5.3335H5.33398ZM4.00065 8.66683H9.33398V10.0002H4.00065V8.66683ZM4.00065 10.6668H9.33398V12.0002H4.00065V10.6668Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.MEDICAL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.79747 1.51186L10.9641 5.26464C11.1482 5.5835 11.0389 5.99122 10.7201 6.17532L9.85373 6.67474L10.5207 7.83001L9.366 8.49668L8.699 7.34141L7.83333 7.84201C7.51447 8.02608 7.10673 7.91681 6.92267 7.59794L5.69747 5.47632C4.32922 5.89145 3.33333 7.16268 3.33333 8.66654C3.33333 9.08348 3.40987 9.48248 3.54965 9.85034C4.06613 9.52254 4.67762 9.33321 5.33333 9.33321C6.45605 9.33321 7.44913 9.88828 8.05313 10.7389L13.1787 7.78014L13.8454 8.93488L8.5932 11.9672C8.64133 12.1927 8.66667 12.4267 8.66667 12.6665C8.66667 12.895 8.64367 13.1181 8.59993 13.3337L14 13.3332V14.6665L2.66703 14.6673C2.2482 14.1101 2 13.4173 2 12.6665C2 11.9951 2.19855 11.3699 2.54014 10.8467C2.19517 10.1964 2 9.45428 2 8.66654C2 6.66968 3.25421 4.96575 5.01785 4.29953L4.75598 3.84519C4.38779 3.20747 4.60629 2.39202 5.24402 2.02382L6.97607 1.02382C7.6138 0.655637 8.42927 0.874138 8.79747 1.51186ZM5.33333 10.6665C4.22877 10.6665 3.33333 11.562 3.33333 12.6665C3.33333 12.9003 3.37343 13.1247 3.44711 13.3331H7.21953C7.29327 13.1247 7.33333 12.9003 7.33333 12.6665C7.33333 11.562 6.4379 10.6665 5.33333 10.6665ZM7.64273 2.17852L5.91068 3.17852L7.744 6.35395L9.47607 5.35395L7.64273 2.17852Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.PRODUCTIVITY: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M6.64807 11.9999H9.35062C9.43862 11.1989 9.84742 10.5376 10.5111 9.81499C10.5858 9.73365 11.0652 9.23752 11.1221 9.16665C11.6872 8.46199 11.9993 7.58992 11.9993 6.66659C11.9993 4.45745 10.2085 2.66659 7.99935 2.66659C5.79021 2.66659 3.99935 4.45745 3.99935 6.66659C3.99935 7.58945 4.31118 8.46105 4.87576 9.16552C4.93271 9.23659 5.41322 9.73405 5.48704 9.81445C6.15112 10.5375 6.56004 11.1989 6.64807 11.9999ZM9.33268 13.3333H6.66602V13.9999H9.33268V13.3333ZM3.83532 9.99939C3.10365 9.08639 2.66602 7.92759 2.66602 6.66659C2.66602 3.72107 5.05383 1.33325 7.99935 1.33325C10.9449 1.33325 13.3327 3.72107 13.3327 6.66659C13.3327 7.92825 12.8945 9.08759 12.1622 10.0009C11.7487 10.5165 10.666 11.3333 10.666 12.3333V13.9999C10.666 14.7363 10.0691 15.3333 9.33268 15.3333H6.66602C5.92964 15.3333 5.33268 14.7363 5.33268 13.9999V12.3333C5.33268 11.3333 4.24907 10.5157 3.83532 9.99939ZM8.66602 6.66979H10.3327L7.33268 10.6698V8.00312H5.66602L8.66602 3.99992V6.66979Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.EDUCATION: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M14 2.66683H4.66667C3.93029 2.66683 3.33333 3.26378 3.33333 4.00016C3.33333 4.73654 3.93029 5.3335 4.66667 5.3335H14V14.0002C14 14.3684 13.7015 14.6668 13.3333 14.6668H4.66667C3.19391 14.6668 2 13.4729 2 12.0002V4.00016C2 2.5274 3.19391 1.3335 4.66667 1.3335H13.3333C13.7015 1.3335 14 1.63198 14 2.00016V2.66683ZM3.33333 12.0002C3.33333 12.7366 3.93029 13.3335 4.66667 13.3335H12.6667V6.66683H4.66667C4.18095 6.66683 3.72557 6.53697 3.33333 6.31008V12.0002ZM13.3333 4.66683H4.66667C4.29848 4.66683 4 4.36835 4 4.00016C4 3.63198 4.29848 3.3335 4.66667 3.3335H13.3333V4.66683Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.BUSINESS: """<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 14 14" fill="none">
<path d="M3.66732 3.33341V1.33341C3.66732 0.965228 3.9658 0.666748 4.33398 0.666748H9.66732C10.0355 0.666748 10.334 0.965228 10.334 1.33341V3.33341H13.0007C13.3689 3.33341 13.6673 3.63189 13.6673 4.00008V13.3334C13.6673 13.7016 13.3689 14.0001 13.0007 14.0001H1.00065C0.632464 14.0001 0.333984 13.7016 0.333984 13.3334V4.00008C0.333984 3.63189 0.632464 3.33341 1.00065 3.33341H3.66732ZM12.334 8.66675H1.66732V12.6667H12.334V8.66675ZM12.334 4.66675H1.66732V7.33341H3.66732V6.00008H5.00065V7.33341H9.00065V6.00008H10.334V7.33341H12.334V4.66675ZM5.00065 2.00008V3.33341H9.00065V2.00008H5.00065Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.ENTERTAINMENT: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M11.3327 2.66675C13.5418 2.66675 15.3327 4.45761 15.3327 6.66675V9.33342C15.3327 11.5425 13.5418 13.3334 11.3327 13.3334H4.66602C2.45688 13.3334 0.666016 11.5425 0.666016 9.33342V6.66675C0.666016 4.45761 2.45688 2.66675 4.66602 2.66675H11.3327ZM11.3327 4.00008H4.66602C3.23788 4.00008 2.07196 5.12273 2.00262 6.53365L1.99935 6.66675V9.33342C1.99935 10.7615 3.122 11.9275 4.53292 11.9968L4.66602 12.0001H11.3327C12.7608 12.0001 13.9267 10.8774 13.9961 9.46648L13.9993 9.33342V6.66675C13.9993 5.23861 12.8767 4.07269 11.4657 4.00335L11.3327 4.00008ZM6.66602 6.00008V7.33342H7.99935V8.66675H6.66535L6.66602 10.0001H5.33268L5.33202 8.66675H3.99935V7.33342H5.33268V6.00008H6.66602ZM11.9993 8.66675V10.0001H10.666V8.66675H11.9993ZM10.666 6.00008V7.33342H9.33268V6.00008H10.666Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.UTILITIES: """<svg xmlns="http://www.w3.org/2000/svg" width="13" height="15" viewBox="0 0 13 15" fill="none">
<path d="M12.3346 0.333252C12.7028 0.333252 13.0013 0.631732 13.0013 0.999919V4.33325C13.0013 4.70144 12.7028 4.99992 12.3346 4.99992H9.0013V13.6666C9.0013 14.0348 8.70284 14.3333 8.33463 14.3333H5.66797C5.29978 14.3333 5.0013 14.0348 5.0013 13.6666V4.99992H1.33464C0.966449 4.99992 0.667969 4.70144 0.667969 4.33325V2.74527C0.667969 2.49276 0.810635 2.26192 1.0365 2.14899L4.66797 0.333252H12.3346ZM9.0013 1.66659H4.98273L2.0013 3.1573V3.66659H6.33464V12.9999H7.66797V3.66659H9.0013V1.66659ZM11.668 1.66659H10.3346V3.66659H11.668V1.66659Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.OTHER: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.00052 0.666748L4.00065 7.33342H12.0007L8.00052 0.666748ZM8.00052 3.25828L9.64572 6.00008H6.35553L8.00052 3.25828ZM4.50065 13.3334C3.48813 13.3334 2.66732 12.5126 2.66732 11.5001C2.66732 10.4875 3.48813 9.66675 4.50065 9.66675C5.51317 9.66675 6.33398 10.4875 6.33398 11.5001C6.33398 12.5126 5.51317 13.3334 4.50065 13.3334ZM4.50065 14.6667C6.24955 14.6667 7.66732 13.249 7.66732 11.5001C7.66732 9.75115 6.24955 8.33342 4.50065 8.33342C2.75175 8.33342 1.33398 9.75115 1.33398 11.5001C1.33398 13.249 2.75175 14.6667 4.50065 14.6667ZM10.0007 10.3334V13.0001H12.6673V10.3334H10.0007ZM8.66732 14.3334V9.00008H14.0007V14.3334H8.66732Z" fill="#344054"/>
</svg>""", # noqa: E501
}
default_tool_label_dict = {
ToolLabelEnum.SEARCH: ToolLabel(
name="search", label=I18nObject(en_US="Search", zh_Hans="搜索"), icon=ICONS[ToolLabelEnum.SEARCH]
),
ToolLabelEnum.IMAGE: ToolLabel(
name="image", label=I18nObject(en_US="Image", zh_Hans="图片"), icon=ICONS[ToolLabelEnum.IMAGE]
),
ToolLabelEnum.VIDEOS: ToolLabel(
name="videos", label=I18nObject(en_US="Videos", zh_Hans="视频"), icon=ICONS[ToolLabelEnum.VIDEOS]
),
ToolLabelEnum.WEATHER: ToolLabel(
name="weather", label=I18nObject(en_US="Weather", zh_Hans="天气"), icon=ICONS[ToolLabelEnum.WEATHER]
),
ToolLabelEnum.FINANCE: ToolLabel(
name="finance", label=I18nObject(en_US="Finance", zh_Hans="金融"), icon=ICONS[ToolLabelEnum.FINANCE]
),
ToolLabelEnum.DESIGN: ToolLabel(
name="design", label=I18nObject(en_US="Design", zh_Hans="设计"), icon=ICONS[ToolLabelEnum.DESIGN]
),
ToolLabelEnum.TRAVEL: ToolLabel(
name="travel", label=I18nObject(en_US="Travel", zh_Hans="旅行"), icon=ICONS[ToolLabelEnum.TRAVEL]
),
ToolLabelEnum.SOCIAL: ToolLabel(
name="social", label=I18nObject(en_US="Social", zh_Hans="社交"), icon=ICONS[ToolLabelEnum.SOCIAL]
),
ToolLabelEnum.NEWS: ToolLabel(
name="news", label=I18nObject(en_US="News", zh_Hans="新闻"), icon=ICONS[ToolLabelEnum.NEWS]
),
ToolLabelEnum.MEDICAL: ToolLabel(
name="medical", label=I18nObject(en_US="Medical", zh_Hans="医疗"), icon=ICONS[ToolLabelEnum.MEDICAL]
),
ToolLabelEnum.PRODUCTIVITY: ToolLabel(
name="productivity",
label=I18nObject(en_US="Productivity", zh_Hans="生产力"),
icon=ICONS[ToolLabelEnum.PRODUCTIVITY],
),
ToolLabelEnum.EDUCATION: ToolLabel(
name="education", label=I18nObject(en_US="Education", zh_Hans="教育"), icon=ICONS[ToolLabelEnum.EDUCATION]
),
ToolLabelEnum.BUSINESS: ToolLabel(
name="business", label=I18nObject(en_US="Business", zh_Hans="商业"), icon=ICONS[ToolLabelEnum.BUSINESS]
),
ToolLabelEnum.ENTERTAINMENT: ToolLabel(
name="entertainment",
label=I18nObject(en_US="Entertainment", zh_Hans="娱乐"),
icon=ICONS[ToolLabelEnum.ENTERTAINMENT],
),
ToolLabelEnum.UTILITIES: ToolLabel(
name="utilities", label=I18nObject(en_US="Utilities", zh_Hans="工具"), icon=ICONS[ToolLabelEnum.UTILITIES]
),
ToolLabelEnum.OTHER: ToolLabel(
name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER]
),
}
default_tool_labels = [v for k, v in default_tool_label_dict.items()]
default_tool_label_name_list = [label.name for label in default_tool_labels]

View File

@ -0,0 +1,37 @@
from core.tools.entities.tool_entities import ToolInvokeMeta
class ToolProviderNotFoundError(ValueError):
pass
class ToolNotFoundError(ValueError):
pass
class ToolParameterValidationError(ValueError):
pass
class ToolProviderCredentialValidationError(ValueError):
pass
class ToolNotSupportedError(ValueError):
pass
class ToolInvokeError(ValueError):
pass
class ToolApiSchemaError(ValueError):
pass
class ToolEngineInvokeError(Exception):
meta: ToolInvokeMeta
def __init__(self, meta, **kwargs):
self.meta = meta
super().__init__(**kwargs)

View File

@ -0,0 +1,79 @@
from typing import Any
from core.plugin.manager.tool import PluginToolManager
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.plugin_tool.tool import PluginTool
class PluginToolProviderController(BuiltinToolProviderController):
entity: ToolProviderEntityWithPlugin
tenant_id: str
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: ToolProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
self.entity = entity
self.tenant_id = tenant_id
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> ToolProviderType:
"""
returns the type of the provider
:return: type of the provider
"""
return ToolProviderType.PLUGIN
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
"""
manager = PluginToolManager()
if not manager.validate_provider_credentials(
tenant_id=self.tenant_id,
user_id=user_id,
provider=self.entity.identity.name,
credentials=credentials,
):
raise ToolProviderCredentialValidationError("Invalid credentials")
def get_tool(self, tool_name: str) -> PluginTool: # type: ignore
"""
return tool with given name
"""
tool_entity = next(
(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None
)
if not tool_entity:
raise ValueError(f"Tool with name {tool_name} not found")
return PluginTool(
entity=tool_entity,
runtime=ToolRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)
def get_tools(self) -> list[PluginTool]: # type: ignore
"""
get all tools
"""
return [
PluginTool(
entity=tool_entity,
runtime=ToolRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)
for tool_entity in self.entity.tools
]

View File

@ -0,0 +1,89 @@
from collections.abc import Generator
from typing import Any, Optional
from core.plugin.manager.tool import PluginToolManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
class PluginTool(Tool):
tenant_id: str
icon: str
plugin_unique_identifier: str
runtime_parameters: Optional[list[ToolParameter]]
def __init__(
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
self.runtime_parameters = None
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.PLUGIN
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]:
manager = PluginToolManager()
tool_parameters = convert_parameters_to_plugin_format(tool_parameters)
yield from manager.invoke(
tenant_id=self.tenant_id,
user_id=user_id,
tool_provider=self.entity.identity.provider,
tool_name=self.entity.identity.name,
credentials=self.runtime.credentials,
tool_parameters=tool_parameters,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)
def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool":
return PluginTool(
entity=self.entity,
runtime=runtime,
tenant_id=self.tenant_id,
icon=self.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)
def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
"""
get the runtime parameters
"""
if not self.entity.has_runtime_parameters:
return self.entity.parameters
if self.runtime_parameters is not None:
return self.runtime_parameters
manager = PluginToolManager()
self.runtime_parameters = manager.get_runtime_parameters(
tenant_id=self.tenant_id,
user_id="",
provider=self.entity.identity.provider,
tool=self.entity.identity.name,
credentials=self.runtime.credentials,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)
return self.runtime_parameters

View File

@ -0,0 +1,357 @@
import json
from collections.abc import Generator, Iterable
from copy import deepcopy
from datetime import UTC, datetime
from mimetypes import guess_type
from typing import Any, Optional, Union, cast
from yarl import URL
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import FileType
from core.file.models import FileTransferMethod
from core.ops.ops_trace_manager import TraceQueueManager
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolInvokeMessage,
ToolInvokeMessageBinary,
ToolInvokeMeta,
ToolParameter,
)
from core.tools.errors import (
ToolEngineInvokeError,
ToolInvokeError,
ToolNotFoundError,
ToolNotSupportedError,
ToolParameterValidationError,
ToolProviderCredentialValidationError,
ToolProviderNotFoundError,
)
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.enums import CreatedByRole
from models.model import Message, MessageFile
class ToolEngine:
"""
Tool runtime engine take care of the tool executions.
"""
@staticmethod
def agent_invoke(
tool: Tool,
tool_parameters: Union[str, dict],
user_id: str,
tenant_id: str,
message: Message,
invoke_from: InvokeFrom,
agent_tool_callback: DifyAgentCallbackHandler,
trace_manager: Optional[TraceQueueManager] = None,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> tuple[str, list[str], ToolInvokeMeta]:
"""
Agent invokes the tool with the given arguments.
"""
# check if arguments is a string
if isinstance(tool_parameters, str):
# check if this tool has only one parameter
parameters = [
parameter
for parameter in tool.get_runtime_parameters()
if parameter.form == ToolParameter.ToolParameterForm.LLM
]
if parameters and len(parameters) == 1:
tool_parameters = {parameters[0].name: tool_parameters}
else:
try:
tool_parameters = json.loads(tool_parameters)
except Exception:
pass
if not isinstance(tool_parameters, dict):
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
try:
# hit the callback handler
agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id)
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
def message_callback(
invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]
):
for message in messages:
if isinstance(message, ToolInvokeMeta):
invocation_meta_dict["meta"] = message
else:
yield message
messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=message_callback(invocation_meta_dict, messages),
user_id=user_id,
tenant_id=tenant_id,
conversation_id=message.conversation_id,
)
message_list = list(messages)
# extract binary data from tool invoke message
binary_files = ToolEngine._extract_tool_response_binary_and_text(message_list)
# create message file
message_files = ToolEngine._create_message_files(
tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id
)
plain_text = ToolEngine._convert_tool_response_to_str(message_list)
meta = invocation_meta_dict["meta"]
# hit the callback handler
agent_tool_callback.on_tool_end(
tool_name=tool.entity.identity.name,
tool_inputs=tool_parameters,
tool_outputs=plain_text,
message_id=message.id,
trace_manager=trace_manager,
)
# transform tool invoke message to get LLM friendly message
return plain_text, message_files, meta
except ToolProviderCredentialValidationError as e:
error_response = "Please check your tool provider credentials"
agent_tool_callback.on_tool_error(e)
except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e:
error_response = f"there is not a tool named {tool.entity.identity.name}"
agent_tool_callback.on_tool_error(e)
except ToolParameterValidationError as e:
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
agent_tool_callback.on_tool_error(e)
except ToolInvokeError as e:
error_response = f"tool invoke error: {e}"
agent_tool_callback.on_tool_error(e)
except ToolEngineInvokeError as e:
meta = e.meta
error_response = f"tool invoke error: {meta.error}"
agent_tool_callback.on_tool_error(e)
return error_response, [], meta
except Exception as e:
error_response = f"unknown error: {e}"
agent_tool_callback.on_tool_error(e)
return error_response, [], ToolInvokeMeta.error_instance(error_response)
@staticmethod
def x(
tool: Tool,
tool_parameters: dict[str, Any],
user_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler,
workflow_call_depth: int,
thread_pool_id: Optional[str] = None,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
Workflow invokes the tool with the given arguments.
"""
try:
# hit the callback handler
workflow_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
if isinstance(tool, WorkflowTool):
tool.workflow_call_depth = workflow_call_depth + 1
tool.thread_pool_id = thread_pool_id
if tool.runtime and tool.runtime.runtime_parameters:
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
response = tool.invoke(
user_id=user_id,
tool_parameters=tool_parameters,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)
# hit the callback handler
response = workflow_tool_callback.on_tool_execution(
tool_name=tool.entity.identity.name,
tool_inputs=tool_parameters,
tool_outputs=response,
)
return response
except Exception as e:
workflow_tool_callback.on_tool_error(e)
raise e
@staticmethod
def _invoke(
tool: Tool,
tool_parameters: dict,
user_id: str,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
"""
Invoke the tool with the given arguments.
"""
started_at = datetime.now(UTC)
meta = ToolInvokeMeta(
time_cost=0.0,
error=None,
tool_config={
"tool_name": tool.entity.identity.name,
"tool_provider": tool.entity.identity.provider,
"tool_provider_type": tool.tool_provider_type().value,
"tool_parameters": deepcopy(tool.runtime.runtime_parameters),
"tool_icon": tool.entity.identity.icon,
},
)
try:
yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id)
except Exception as e:
meta.error = str(e)
raise ToolEngineInvokeError(meta)
finally:
ended_at = datetime.now(UTC)
meta.time_cost = (ended_at - started_at).total_seconds()
yield meta
@staticmethod
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
"""
Handle tool response
"""
result = ""
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
result += cast(ToolInvokeMessage.TextMessage, response.message).text
elif response.type == ToolInvokeMessage.MessageType.LINK:
result += (
f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}."
+ " please tell user to check it."
)
elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
result += (
"image has been created and sent to user already, "
+ "you do not need to create it, just tell the user to check it now."
)
elif response.type == ToolInvokeMessage.MessageType.JSON:
result = json.dumps(
cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
)
else:
result += str(response.message)
return result
@staticmethod
def _extract_tool_response_binary_and_text(
tool_response: list[ToolInvokeMessage],
) -> Generator[ToolInvokeMessageBinary, None, None]:
"""
Extract tool response binary
"""
for response in tool_response:
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
mimetype = None
if not response.meta:
raise ValueError("missing meta data")
if response.meta.get("mime_type"):
mimetype = response.meta.get("mime_type")
else:
try:
url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
extension = url.suffix
guess_type_result, _ = guess_type(f"a{extension}")
if guess_type_result:
mimetype = guess_type_result
except Exception:
pass
if not mimetype:
mimetype = "image/jpeg"
yield ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "image/jpeg"),
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
if not response.meta:
raise ValueError("missing meta data")
yield ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "application/octet-stream"),
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
)
elif response.type == ToolInvokeMessage.MessageType.LINK:
# check if there is a mime type in meta
if response.meta and "mime_type" in response.meta:
yield ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "application/octet-stream")
if response.meta
else "application/octet-stream",
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
)
@staticmethod
def _create_message_files(
tool_messages: Iterable[ToolInvokeMessageBinary],
agent_message: Message,
invoke_from: InvokeFrom,
user_id: str,
) -> list[str]:
"""
Create message file
:return: message file ids
"""
result = []
for message in tool_messages:
if "image" in message.mimetype:
file_type = FileType.IMAGE
elif "video" in message.mimetype:
file_type = FileType.VIDEO
elif "audio" in message.mimetype:
file_type = FileType.AUDIO
elif "text" in message.mimetype or "pdf" in message.mimetype:
file_type = FileType.DOCUMENT
else:
file_type = FileType.CUSTOM
# extract tool file id from url
tool_file_id = message.url.split("/")[-1].split(".")[0]
message_file = MessageFile(
message_id=agent_message.id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
belongs_to="assistant",
url=message.url,
upload_file_id=tool_file_id,
created_by_role=(
CreatedByRole.ACCOUNT
if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else CreatedByRole.END_USER
),
created_by=user_id,
)
db.session.add(message_file)
db.session.commit()
db.session.refresh(message_file)
result.append(message_file.id)
db.session.close()
return result

View File

@ -0,0 +1,234 @@
import base64
import hashlib
import hmac
import logging
import os
import time
from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4
import httpx
from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import MessageFile
from models.tools import ToolFile
logger = logging.getLogger(__name__)
class ToolFileManager:
@staticmethod
def sign_file(tool_file_id: str, extension: str) -> str:
"""
sign file to get a temporary url
"""
base_url = dify_config.FILES_URL
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@staticmethod
def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
"""
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
@staticmethod
def create_file_by_raw(
*,
user_id: str,
tenant_id: str,
conversation_id: Optional[str],
file_binary: bytes,
mimetype: str,
filename: Optional[str] = None,
) -> ToolFile:
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
unique_filename = f"{unique_name}{extension}"
# default just as before
present_filename = unique_filename
if filename is not None:
has_extension = len(filename.split(".")) > 1
# Add extension flexibly
present_filename = filename if has_extension else f"{filename}{extension}"
filepath = f"tools/{tenant_id}/{unique_filename}"
storage.save(filepath, file_binary)
tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_key=filepath,
mimetype=mimetype,
name=present_filename,
size=len(file_binary),
)
db.session.add(tool_file)
db.session.commit()
db.session.refresh(tool_file)
return tool_file
@staticmethod
def create_file_by_url(
user_id: str,
tenant_id: str,
file_url: str,
conversation_id: Optional[str] = None,
) -> ToolFile:
# try to download image
try:
response = ssrf_proxy.get(file_url)
response.raise_for_status()
blob = response.content
except httpx.TimeoutException:
raise ValueError(f"timeout when downloading file from {file_url}")
mimetype = (
guess_type(file_url)[0]
or response.headers.get("Content-Type", "").split(";")[0].strip()
or "application/octet-stream"
)
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
filename = f"{unique_name}{extension}"
filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, blob)
tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_key=filepath,
mimetype=mimetype,
original_url=file_url,
name=filename,
size=len(blob),
)
db.session.add(tool_file)
db.session.commit()
return tool_file
@staticmethod
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
"""
get file binary
:param id: the id of the file
:return: the binary of the file, mime type
"""
tool_file: ToolFile | None = (
db.session.query(ToolFile)
.filter(
ToolFile.id == id,
)
.first()
)
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)
return blob, tool_file.mimetype
@staticmethod
def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
"""
get file binary
:param id: the id of the file
:return: the binary of the file, mime type
"""
message_file: MessageFile | None = (
db.session.query(MessageFile)
.filter(
MessageFile.id == id,
)
.first()
)
# Check if message_file is not None
if message_file is not None:
# get tool file id
if message_file.url is not None:
tool_file_id = message_file.url.split("/")[-1]
# trim extension
tool_file_id = tool_file_id.split(".")[0]
else:
tool_file_id = None
else:
tool_file_id = None
tool_file: ToolFile | None = (
db.session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
)
.first()
)
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)
return blob, tool_file.mimetype
@staticmethod
def get_file_generator_by_tool_file_id(tool_file_id: str):
"""
get file binary
:param tool_file_id: the id of the tool file
:return: the binary of the file, mime type
"""
tool_file: ToolFile | None = (
db.session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
)
.first()
)
if not tool_file:
return None, None
stream = storage.load_stream(tool_file.file_key)
return stream, tool_file
# init tool_file_parser
from core.file.tool_file_parser import tool_file_manager
tool_file_manager["manager"] = ToolFileManager

View File

@ -0,0 +1,101 @@
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.entities.values import default_tool_label_name_list
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from extensions.ext_database import db
from models.tools import ToolLabelBinding
class ToolLabelManager:
@classmethod
def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
"""
Filter tool labels
"""
tool_labels = [label for label in tool_labels if label in default_tool_label_name_list]
return list(set(tool_labels))
@classmethod
def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
"""
Update tool labels
"""
labels = cls.filter_tool_labels(labels)
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
provider_id = controller.provider_id
else:
raise ValueError("Unsupported tool type")
# delete old labels
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete()
# insert new labels
for label in labels:
db.session.add(
ToolLabelBinding(
tool_id=provider_id,
tool_type=controller.provider_type.value,
label_name=label,
)
)
db.session.commit()
@classmethod
def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
"""
Get tool labels
"""
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
provider_id = controller.provider_id
elif isinstance(controller, BuiltinToolProviderController):
return controller.tool_labels
else:
raise ValueError("Unsupported tool type")
labels = (
db.session.query(ToolLabelBinding.label_name)
.filter(
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value,
)
.all()
)
return [label.label_name for label in labels]
@classmethod
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
"""
Get tools labels
:param tool_providers: list of tool providers
:return: dict of tool labels
:key: tool id
:value: list of tool labels
"""
if not tool_providers:
return {}
for controller in tool_providers:
if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
raise ValueError("Unsupported tool type")
provider_ids = []
for controller in tool_providers:
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
provider_ids.append(controller.provider_id)
labels: list[ToolLabelBinding] = (
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
)
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
for label in labels:
tool_labels[label.tool_id].append(label.label_name)
return tool_labels

View File

@ -0,0 +1,870 @@
import json
import logging
import mimetypes
from collections.abc import Generator
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Union, cast
from yarl import URL
import contexts
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.manager.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import Tool
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.custom_tool.tool import ApiTool
from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolInvokeFrom,
ToolParameter,
ToolProviderType,
)
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
ProviderConfigEncrypter,
ToolParameterConfigurationManager,
)
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class ToolManager:
_builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
_builtin_providers_loaded = False
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
@classmethod
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
"""
get the hardcoded provider
"""
if len(cls._hardcoded_providers) == 0:
# init the builtin providers
cls.load_hardcoded_providers_cache()
return cls._hardcoded_providers[provider]
@classmethod
def get_builtin_provider(
cls, provider: str, tenant_id: str
) -> BuiltinToolProviderController | PluginToolProviderController:
"""
get the builtin provider
:param provider: the name of the provider
:param tenant_id: the id of the tenant
:return: the provider
"""
# split provider to
if len(cls._hardcoded_providers) == 0:
# init the builtin providers
cls.load_hardcoded_providers_cache()
if provider not in cls._hardcoded_providers:
# get plugin provider
plugin_provider = cls.get_plugin_provider(provider, tenant_id)
if plugin_provider:
return plugin_provider
return cls._hardcoded_providers[provider]
@classmethod
def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
"""
get the plugin provider
"""
# check if context is set
try:
contexts.plugin_tool_providers.get()
except LookupError:
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(Lock())
with contexts.plugin_tool_providers_lock.get():
plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers:
return plugin_tool_providers[provider]
manager = PluginToolManager()
provider_entity = manager.fetch_tool_provider(tenant_id, provider)
if not provider_entity:
raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
controller = PluginToolProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
plugin_tool_providers[provider] = controller
return controller
@classmethod
def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
"""
get the builtin tool
:param provider: the name of the provider
:param tool_name: the name of the tool
:param tenant_id: the id of the tenant
:return: the provider, the tool
"""
provider_controller = cls.get_builtin_provider(provider, tenant_id)
tool = provider_controller.get_tool(tool_name)
if tool is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
return tool
@classmethod
def get_tool_runtime(
cls,
provider_type: ToolProviderType,
provider_id: str,
tool_name: str,
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]:
"""
get the tool runtime
:param provider_type: the type of the provider
:param provider_id: the id of the provider
:param tool_name: the name of the tool
:param tenant_id: the tenant id
:param invoke_from: invoke from
:param tool_invoke_from: the tool invoke from
:return: the tool
"""
if provider_type == ToolProviderType.BUILT_IN:
# check if the builtin tool need credentials
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
builtin_tool = provider_controller.get_tool(tool_name)
if not builtin_tool:
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
if not provider_controller.need_credentials:
return cast(
BuiltinTool,
builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
),
)
if isinstance(provider_controller, PluginToolProviderController):
provider_id_entity = ToolProviderID(provider_id)
# get credentials
builtin_provider: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.first()
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
else:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.first()
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
# decrypt the credentials
credentials = builtin_provider.credentials
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
decrypted_credentials = tool_configuration.decrypt(credentials)
return cast(
BuiltinTool,
builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=decrypted_credentials,
runtime_parameters={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
),
)
elif provider_type == ToolProviderType.API:
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
# decrypt the credentials
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
provider_type=api_provider.provider_type.value,
provider_identity=api_provider.entity.identity.name,
)
decrypted_credentials = tool_configuration.decrypt(credentials)
return cast(
ApiTool,
api_provider.get_tool(tool_name).fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=decrypted_credentials,
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
),
)
elif provider_type == ToolProviderType.WORKFLOW:
workflow_provider = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
)
if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return cast(
WorkflowTool,
controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
),
)
elif provider_type == ToolProviderType.APP:
raise NotImplementedError("app provider not implemented")
elif provider_type == ToolProviderType.PLUGIN:
return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
else:
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
@classmethod
def get_agent_tool_runtime(
cls,
tenant_id: str,
app_id: str,
agent_tool: AgentToolEntity,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
) -> Tool:
"""
get the agent tool runtime
"""
tool_entity = cls.get_tool_runtime(
provider_type=agent_tool.provider_type,
provider_id=agent_tool.provider_id,
tool_name=agent_tool.tool_name,
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.AGENT,
)
runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters()
for parameter in parameters:
# check file types
if (
parameter.type
in {
ToolParameter.ToolParameterType.SYSTEM_FILES,
ToolParameter.ToolParameterType.FILE,
ToolParameter.ToolParameterType.FILES,
}
and parameter.required
):
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# save tool parameter to tool entity memory
value = parameter.init_frontend_parameter(agent_tool.tool_parameters.get(parameter.name))
runtime_parameters[parameter.name] = value
# decrypt runtime parameters
encryption_manager = ToolParameterConfigurationManager(
tenant_id=tenant_id,
tool_runtime=tool_entity,
provider_name=agent_tool.provider_id,
provider_type=agent_tool.provider_type,
identity_id=f"AGENT.{app_id}",
)
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
raise ValueError("runtime not found or runtime parameters not found")
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@classmethod
def get_workflow_tool_runtime(
cls,
tenant_id: str,
app_id: str,
node_id: str,
workflow_tool: "ToolEntity",
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
) -> Tool:
"""
get the workflow tool runtime
"""
tool_runtime = cls.get_tool_runtime(
provider_type=workflow_tool.provider_type,
provider_id=workflow_tool.provider_id,
tool_name=workflow_tool.tool_name,
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
)
runtime_parameters = {}
parameters = tool_runtime.get_merged_runtime_parameters()
for parameter in parameters:
# save tool parameter to tool entity memory
if parameter.form == ToolParameter.ToolParameterForm.FORM:
value = parameter.init_frontend_parameter(workflow_tool.tool_configurations.get(parameter.name))
runtime_parameters[parameter.name] = value
# decrypt runtime parameters
encryption_manager = ToolParameterConfigurationManager(
tenant_id=tenant_id,
tool_runtime=tool_runtime,
provider_name=workflow_tool.provider_id,
provider_type=workflow_tool.provider_type,
identity_id=f"WORKFLOW.{app_id}.{node_id}",
)
if runtime_parameters:
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
return tool_runtime
@classmethod
def get_tool_runtime_from_plugin(
cls,
tool_type: ToolProviderType,
tenant_id: str,
provider: str,
tool_name: str,
tool_parameters: dict[str, Any],
) -> Tool:
"""
get tool runtime from plugin
"""
tool_entity = cls.get_tool_runtime(
provider_type=tool_type,
provider_id=provider,
tool_name=tool_name,
tenant_id=tenant_id,
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.PLUGIN,
)
runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters()
for parameter in parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# save tool parameter to tool entity memory
value = parameter.init_frontend_parameter(tool_parameters.get(parameter.name))
runtime_parameters[parameter.name] = value
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@classmethod
def get_hardcoded_provider_icon(cls, provider: str) -> tuple[str, str]:
"""
get the absolute path of the icon of the hardcoded provider
:param provider: the name of the provider
:return: the absolute path of the icon, the mime type of the icon
"""
# get provider
provider_controller = cls.get_hardcoded_provider(provider)
absolute_path = path.join(
path.dirname(path.realpath(__file__)),
"builtin_tool",
"providers",
provider,
"_assets",
provider_controller.entity.identity.icon,
)
# check if the icon exists
if not path.exists(absolute_path):
raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")
# get the mime type
mime_type, _ = mimetypes.guess_type(absolute_path)
mime_type = mime_type or "application/octet-stream"
return absolute_path, mime_type
@classmethod
def list_hardcoded_providers(cls):
# use cache first
if cls._builtin_providers_loaded:
yield from list(cls._hardcoded_providers.values())
return
with cls._builtin_provider_lock:
if cls._builtin_providers_loaded:
yield from list(cls._hardcoded_providers.values())
return
yield from cls._list_hardcoded_providers()
@classmethod
def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
"""
list all the plugin providers
"""
manager = PluginToolManager()
provider_entities = manager.fetch_tool_providers(tenant_id)
return [
PluginToolProviderController(
entity=provider.declaration,
plugin_id=provider.plugin_id,
plugin_unique_identifier=provider.plugin_unique_identifier,
tenant_id=tenant_id,
)
for provider in provider_entities
]
@classmethod
def list_builtin_providers(
cls, tenant_id: str
) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
"""
list all the builtin providers
"""
yield from cls.list_hardcoded_providers()
# get plugin providers
yield from cls.list_plugin_providers(tenant_id)
@classmethod
def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
"""
list all the builtin providers
"""
for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")):
if provider_path.startswith("__"):
continue
if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)):
if provider_path.startswith("__"):
continue
# init provider
try:
provider_class = load_single_subclass_from_source(
module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
script_path=path.join(
path.dirname(path.realpath(__file__)),
"builtin_tool",
"providers",
provider_path,
f"{provider_path}.py",
),
parent_type=BuiltinToolProviderController,
)
provider: BuiltinToolProviderController = provider_class()
cls._hardcoded_providers[provider.entity.identity.name] = provider
for tool in provider.get_tools():
cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
yield provider
except Exception:
logger.exception(f"load builtin provider {provider}")
continue
# set builtin providers loaded
cls._builtin_providers_loaded = True
@classmethod
def load_hardcoded_providers_cache(cls):
for _ in cls.list_hardcoded_providers():
pass
@classmethod
def clear_hardcoded_providers_cache(cls):
cls._hardcoded_providers = {}
cls._builtin_providers_loaded = False
@classmethod
def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
"""
get the tool label
:param tool_name: the name of the tool
:return: the label of the tool
"""
if len(cls._builtin_tools_labels) == 0:
# init the builtin providers
cls.load_hardcoded_providers_cache()
if tool_name not in cls._builtin_tools_labels:
return None
return cls._builtin_tools_labels[tool_name]
@classmethod
def list_providers_from_api(
cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
) -> list[ToolProviderApiEntity]:
result_providers: dict[str, ToolProviderApiEntity] = {}
filters = []
if not typ:
filters.extend(["builtin", "api", "workflow"])
else:
filters.append(typ)
with db.session.no_autoflush:
if "builtin" in filters:
# get builtin providers
builtin_providers = cls.list_builtin_providers(tenant_id)
# get db builtin providers
db_builtin_providers: list[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
)
# rewrite db_builtin_providers
for db_provider in db_builtin_providers:
tool_provider_id = str(ToolProviderID(db_provider.provider))
db_provider.provider = tool_provider_id
def find_db_builtin_provider(provider):
return next((x for x in db_builtin_providers if x.provider == provider), None)
# append builtin providers
for provider in builtin_providers:
# handle include, exclude
if is_filtered(
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
data=provider,
name_func=lambda x: x.identity.name,
):
continue
user_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider,
db_provider=find_db_builtin_provider(provider.entity.identity.name),
decrypt_credentials=False,
)
if isinstance(provider, PluginToolProviderController):
result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
else:
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
# get db api providers
if "api" in filters:
db_api_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
)
api_provider_controllers: list[dict[str, Any]] = [
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
for provider in db_api_providers
]
# get labels
labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
for api_provider_controller in api_provider_controllers:
user_provider = ToolTransformService.api_provider_to_user_provider(
provider_controller=api_provider_controller["controller"],
db_provider=api_provider_controller["provider"],
decrypt_credentials=False,
labels=labels.get(api_provider_controller["controller"].provider_id, []),
)
result_providers[f"api_provider.{user_provider.name}"] = user_provider
if "workflow" in filters:
# get workflow providers
workflow_providers: list[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
)
workflow_provider_controllers: list[WorkflowToolProviderController] = []
for provider in workflow_providers:
try:
workflow_provider_controllers.append(
ToolTransformService.workflow_provider_to_controller(db_provider=provider)
)
except Exception:
# app has been deleted
pass
labels = ToolLabelManager.get_tools_labels(
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
)
for provider_controller in workflow_provider_controllers:
user_provider = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=provider_controller,
labels=labels.get(provider_controller.provider_id, []),
)
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
return BuiltinToolProviderSort.sort(list(result_providers.values()))
@classmethod
def get_api_provider_controller(
cls, tenant_id: str, provider_id: str
) -> tuple[ApiToolProviderController, dict[str, Any]]:
"""
get the api provider
:param tenant_id: the id of the tenant
:param provider_id: the id of the provider
:return: the provider controller, the credentials
"""
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.id == provider_id,
ApiToolProvider.tenant_id == tenant_id,
)
.first()
)
if provider is None:
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
controller = ApiToolProviderController.from_db(
provider,
ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
)
controller.load_bundled_tools(provider.tools)
return controller, provider.credentials
@classmethod
def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
"""
get api provider
"""
"""
get tool provider
"""
provider_name = provider
provider_obj: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider,
)
.first()
)
if provider_obj is None:
raise ValueError(f"you have not added provider {provider_name}")
try:
credentials = json.loads(provider_obj.credentials_str) or {}
except Exception:
credentials = {}
# package tool provider controller
controller = ApiToolProviderController.from_db(
provider_obj,
ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
)
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_type=controller.provider_type.value,
provider_identity=controller.entity.identity.name,
)
decrypted_credentials = tool_configuration.decrypt(credentials)
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
try:
icon = json.loads(provider_obj.icon)
except Exception:
icon = {"background": "#252525", "content": "\ud83d\ude01"}
# add tool labels
labels = ToolLabelManager.get_tool_labels(controller)
return cast(
dict,
jsonable_encoder(
{
"schema_type": provider_obj.schema_type,
"schema": provider_obj.schema,
"tools": provider_obj.tools,
"icon": icon,
"description": provider_obj.description,
"credentials": masked_credentials,
"privacy_policy": provider_obj.privacy_policy,
"custom_disclaimer": provider_obj.custom_disclaimer,
"labels": labels,
}
),
)
@classmethod
def generate_builtin_tool_icon_url(cls, provider_id: str) -> str:
return str(
URL(dify_config.CONSOLE_API_URL or "/")
/ "console"
/ "api"
/ "workspaces"
/ "current"
/ "tool-provider"
/ "builtin"
/ provider_id
/ "icon"
)
@classmethod
def generate_plugin_tool_icon_url(cls, tenant_id: str, filename: str) -> str:
return str(
URL(dify_config.CONSOLE_API_URL or "/")
/ "console"
/ "api"
/ "workspaces"
/ "current"
/ "plugin"
/ "icon"
% {"tenant_id": tenant_id, "filename": filename}
)
@classmethod
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
try:
workflow_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
)
if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
icon: dict = json.loads(workflow_provider.icon)
return icon
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
try:
api_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.first()
)
if api_provider is None:
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
icon: dict = json.loads(api_provider.icon)
return icon
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
def get_tool_icon(
cls,
tenant_id: str,
provider_type: ToolProviderType,
provider_id: str,
) -> Union[str, dict]:
"""
get the tool icon
:param tenant_id: the id of the tenant
:param provider_type: the type of the provider
:param provider_id: the id of the provider
:return:
"""
provider_type = provider_type
provider_id = provider_id
if provider_type == ToolProviderType.BUILT_IN:
provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
if isinstance(provider, PluginToolProviderController):
try:
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
return cls.generate_builtin_tool_icon_url(provider_id)
elif provider_type == ToolProviderType.API:
return cls.generate_api_tool_icon_url(tenant_id, provider_id)
elif provider_type == ToolProviderType.WORKFLOW:
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
elif provider_type == ToolProviderType.PLUGIN:
provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
if isinstance(provider, PluginToolProviderController):
try:
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
raise ValueError(f"plugin provider {provider_id} not found")
else:
raise ValueError(f"provider type {provider_type} not found")
ToolManager.load_hardcoded_providers_cache()

View File

View File

@ -0,0 +1,265 @@
from copy import deepcopy
from typing import Any
from pydantic import BaseModel
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
ToolProviderType,
)
class ProviderConfigEncrypter(BaseModel):
tenant_id: str
config: list[BasicProviderConfig]
provider_type: str
provider_identity: str
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def decrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cached_credentials = cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
try:
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception:
pass
cache.set(data)
return data
def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cache.delete()
class ToolParameterConfigurationManager:
"""
Tool parameter configuration manager
"""
tenant_id: str
tool_runtime: Tool
provider_name: str
provider_type: ToolProviderType
identity_id: str
def __init__(
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
) -> None:
self.tenant_id = tenant_id
self.tool_runtime = tool_runtime
self.provider_name = provider_name
self.provider_type = provider_type
self.identity_id = identity_id
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
deep copy parameters
"""
return deepcopy(parameters)
def _merge_parameters(self) -> list[ToolParameter]:
"""
merge parameters
"""
# get tool parameters
tool_parameters = self.tool_runtime.entity.parameters or []
# get tool runtime parameters
runtime_parameters = self.tool_runtime.get_runtime_parameters()
# override parameters
current_parameters = tool_parameters.copy()
for runtime_parameter in runtime_parameters:
found = False
for index, parameter in enumerate(current_parameters):
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
current_parameters[index] = runtime_parameter
found = True
break
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
return current_parameters
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
mask tool parameters
return a deep copy of parameters with masked values
"""
parameters = self._deep_copy(parameters)
# override parameters
current_parameters = self._merge_parameters()
for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
if len(parameters[parameter.name]) > 6:
parameters[parameter.name] = (
parameters[parameter.name][:2]
+ "*" * (len(parameters[parameter.name]) - 4)
+ parameters[parameter.name][-2:]
)
else:
parameters[parameter.name] = "*" * len(parameters[parameter.name])
return parameters
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
encrypt tool parameters with tenant id
return a deep copy of parameters with encrypted values
"""
# override parameters
current_parameters = self._merge_parameters()
parameters = self._deep_copy(parameters)
for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
parameters[parameter.name] = encrypted
return parameters
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
decrypt tool parameters with tenant id
return a deep copy of parameters with decrypted values
"""
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type.value}.{self.provider_name}",
tool_name=self.tool_runtime.entity.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id,
)
cached_parameters = cache.get()
if cached_parameters:
return cached_parameters
# override parameters
current_parameters = self._merge_parameters()
has_secret_input = False
for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
try:
has_secret_input = True
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
except Exception:
pass
if has_secret_input:
cache.set(parameters)
return parameters
def delete_tool_parameters_cache(self):
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type.value}.{self.provider_name}",
tool_name=self.tool_runtime.entity.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id,
)
cache.delete()

View File

@ -0,0 +1,199 @@
import threading
from typing import Any
from flask import Flask, current_app
from pydantic import BaseModel, Field
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
class DatasetMultiRetrieverToolInput(BaseModel):
query: str = Field(..., description="dataset multi retriever and rerank")
class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
"""Tool for querying multi dataset."""
name: str = "dataset_"
args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
description: str = "dataset multi retriever and rerank. "
dataset_ids: list[str]
reranking_provider_name: str
reranking_model_name: str
@classmethod
def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
return cls(
name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs
)
def _run(self, query: str) -> str:
threads = []
all_documents: list[RagDocument] = []
for dataset_id in self.dataset_ids:
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset_id,
"query": query,
"all_documents": all_documents,
"hit_callbacks": self.hit_callbacks,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
# do rerank for searched documents
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id,
provider=self.reranking_provider_name,
model_type=ModelType.RERANK,
model=self.reranking_model_name,
)
rerank_runner = RerankModelRunner(rerank_model_instance)
all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(all_documents)
document_score_list = {}
for item in all_documents:
if item.metadata and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
else:
document_context_list.append(segment.get_sign_content())
if self.return_resource:
context_list = []
resource_number = 1
for segment in sorted_segments:
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
if dataset and document:
source = {
"position": resource_number,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": document_score_list.get(segment.index_node_id, None),
"doc_metadata": document.doc_metadata,
}
if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
context_list.append(source)
resource_number += 1
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))
return ""
raise RuntimeError("not segments found")
def _retriever(
self,
flask_app: Flask,
dataset_id: str,
query: str,
all_documents: list,
hit_callbacks: list[DatasetIndexToolCallbackHandler],
):
with flask_app.app_context():
dataset = (
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
)
if not dataset:
return []
for hit_callback in hit_callbacks:
hit_callback.on_query(query, dataset.id)
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search",
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
)
if documents:
all_documents.extend(documents)
else:
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
)
all_documents.extend(documents)

View File

@ -0,0 +1,33 @@
from abc import abstractmethod
from typing import Any, Optional
from msal_extensions.persistence import ABC # type: ignore
from pydantic import BaseModel, ConfigDict
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
class DatasetRetrieverBaseTool(BaseModel, ABC):
"""Tool for querying a Dataset."""
name: str = "dataset"
description: str = "use this to retrieve a dataset. "
tenant_id: str
top_k: int = 2
score_threshold: Optional[float] = None
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
return_resource: bool
retriever_from: str
model_config = ConfigDict(arbitrary_types_allowed=True)
@abstractmethod
def _run(
self,
*args: Any,
**kwargs: Any,
) -> Any:
"""Use the tool.
Add run_manager: Optional[CallbackManagerForToolRun] = None
to child implementations to enable tracing,
"""

View File

@ -0,0 +1,202 @@
from typing import Any
from pydantic import BaseModel, Field
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"reranking_mode": "reranking_model",
"top_k": 2,
"score_threshold_enabled": False,
}
class DatasetRetrieverToolInput(BaseModel):
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
class DatasetRetrieverTool(DatasetRetrieverBaseTool):
"""Tool for querying a Dataset."""
name: str = "dataset"
args_schema: type[BaseModel] = DatasetRetrieverToolInput
description: str = "use this to retrieve a dataset. "
dataset_id: str
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
description = dataset.description
if not description:
description = "useful for when you want to answer queries about the " + dataset.name
description = description.replace("\n", "").replace("\r", "")
return cls(
name=f"dataset_{dataset.id.replace('-', '_')}",
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
description=description,
**kwargs,
)
def _run(self, query: str) -> str:
dataset = (
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
)
if not dataset:
return ""
for hit_callback in self.hit_callbacks:
hit_callback.on_query(query, dataset.id)
if dataset.provider == "external":
results = []
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
)
for external_document in external_documents:
document = RetrievalDocument(
page_content=external_document.get("content"),
metadata=external_document.get("metadata"),
provider="external",
)
if document.metadata is not None:
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset.id
document.metadata["dataset_name"] = dataset.name
results.append(document)
# deal with external documents
context_list = []
for position, item in enumerate(results, start=1):
if item.metadata is not None:
source = {
"position": position,
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": self.retriever_from,
"score": item.metadata.get("score"),
"title": item.metadata.get("title"),
"content": item.page_content,
}
context_list.append(source)
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join([item.page_content for item in results]))
else:
# get retrieval model , if the model is not setting , using default
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
)
return str("\n".join([document.page_content for document in documents]))
else:
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model")
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights"),
)
else:
documents = []
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
if item.metadata is not None and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
records = RetrievalService.format_retrieval_documents(documents)
if records:
for record in records:
segment = record.segment
if segment.answer:
document_context_list.append(
DocumentContext(
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
score=record.score,
)
)
else:
document_context_list.append(
DocumentContext(
content=segment.get_sign_content(),
score=record.score,
)
)
retrieval_resource_list = []
if self.return_resource:
for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).first()
if dataset and document:
source = {
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id, # type: ignore
"document_name": document.name, # type: ignore
"data_source_type": document.data_source_type, # type: ignore
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": record.score or 0.0,
"doc_metadata": document.doc_metadata, # type: ignore
}
if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
retrieval_resource_list.append(source)
if self.return_resource and retrieval_resource_list:
retrieval_resource_list = sorted(
retrieval_resource_list,
key=lambda x: x.get("score") or 0.0,
reverse=True,
)
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
item["position"] = position # type: ignore
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list:
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
return str("\n".join([document_context.content for document_context in document_context_list]))
return ""

View File

@ -0,0 +1,134 @@
from collections.abc import Generator
from typing import Any, Optional
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolDescription,
ToolEntity,
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
class DatasetRetrieverTool(Tool):
retrieval_tool: DatasetRetrieverBaseTool
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None:
super().__init__(entity, runtime)
self.retrieval_tool = retrieval_tool
@staticmethod
def get_dataset_tools(
tenant_id: str,
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity | None,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
) -> list["DatasetRetrieverTool"]:
"""
get dataset tool
"""
# check if retrieve_config is valid
if dataset_ids is None or len(dataset_ids) == 0:
return []
if retrieve_config is None:
return []
feature = DatasetRetrieval()
# save original retrieve strategy, and set retrieve strategy to SINGLE
# Agent only support SINGLE mode
original_retriever_mode = retrieve_config.retrieve_strategy
retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
retrieval_tools = feature.to_dataset_retriever_tool(
tenant_id=tenant_id,
dataset_ids=dataset_ids,
retrieve_config=retrieve_config,
return_resource=return_resource,
invoke_from=invoke_from,
hit_callback=hit_callback,
)
if retrieval_tools is None or len(retrieval_tools) == 0:
return []
# restore retrieve strategy
retrieve_config.retrieve_strategy = original_retriever_mode
# convert retrieval tools to Tools
tools = []
for retrieval_tool in retrieval_tools:
tool = DatasetRetrieverTool(
retrieval_tool=retrieval_tool,
entity=ToolEntity(
identity=ToolIdentity(
provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="")
),
parameters=[],
description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description),
),
runtime=ToolRuntime(tenant_id=tenant_id),
)
tools.append(tool)
return tools
def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
return [
ToolParameter(
name="query",
label=I18nObject(en_US="", zh_Hans=""),
human_description=I18nObject(en_US="", zh_Hans=""),
type=ToolParameter.ToolParameterType.STRING,
form=ToolParameter.ToolParameterForm.LLM,
llm_description="Query for the dataset to be used to retrieve the dataset.",
required=True,
default="",
placeholder=I18nObject(en_US="", zh_Hans=""),
),
]
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.DATASET_RETRIEVAL
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke dataset retriever tool
"""
query = tool_parameters.get("query")
if not query:
yield self.create_text_message(text="please input query")
else:
# invoke dataset retriever tool
result = self.retrieval_tool._run(query=query)
yield self.create_text_message(text=result)
def validate_credentials(
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
) -> str | None:
"""
validate the credentials for dataset retriever tool
"""
pass

View File

@ -0,0 +1,121 @@
import logging
from collections.abc import Generator
from mimetypes import guess_extension
from typing import Optional
from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager
logger = logging.getLogger(__name__)
class ToolFileMessageTransformer:
@classmethod
def transform_tool_invoke_messages(
cls,
messages: Generator[ToolInvokeMessage, None, None],
user_id: str,
tenant_id: str,
conversation_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
Transform tool message and handle file download
"""
for message in messages:
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
yield message
elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(
message.message, ToolInvokeMessage.TextMessage
):
# try to download image
try:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
file = ToolFileManager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
file_url=message.message.text,
conversation_id=conversation_id,
)
url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}"
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=ToolInvokeMessage.TextMessage(text=url),
meta=message.meta.copy() if message.meta is not None else {},
)
except Exception as e:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT,
message=ToolInvokeMessage.TextMessage(
text=f"Failed to download image: {message.message.text}: {e}"
),
meta=message.meta.copy() if message.meta is not None else {},
)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get mime type and save blob to storage
meta = message.meta or {}
mimetype = meta.get("mime_type", "application/octet-stream")
# get filename from meta
filename = meta.get("file_name", None)
# if message is str, encode it to bytes
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
raise ValueError("unexpected message type")
# FIXME: should do a type check here.
assert isinstance(message.message.blob, bytes)
file = ToolFileManager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_binary=message.message.blob,
mimetype=mimetype,
filename=filename,
)
url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))
# check if file is image
if "image" in mimetype:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=ToolInvokeMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BINARY_LINK,
message=ToolInvokeMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
elif message.type == ToolInvokeMessage.MessageType.FILE:
meta = message.meta or {}
file = meta.get("file", None)
if isinstance(file, File):
if file.transfer_method == FileTransferMethod.TOOL_FILE:
assert file.related_id is not None
url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
if file.type == FileType.IMAGE:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=ToolInvokeMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=ToolInvokeMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield message
else:
yield message
@classmethod
def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
return f"/files/tools/{tool_file_id}{extension or '.bin'}"

View File

@ -0,0 +1,169 @@
"""
For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc.
Therefore, a model manager is needed to list/invoke/validate models.
"""
import json
from typing import Optional, cast
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from models.tools import ToolModelInvoke
class InvokeModelError(Exception):
pass
class ModelInvocationUtils:
@staticmethod
def get_max_llm_context_tokens(
tenant_id: str,
) -> int:
"""
get max llm context tokens of the model
"""
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
)
if not model_instance:
raise InvokeModelError("Model not found")
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if not schema:
raise InvokeModelError("No model schema found")
max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
if max_tokens is None:
return 2048
return max_tokens
@staticmethod
def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
"""
calculate tokens from prompt messages and model parameters
"""
# get model instance
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)
if not model_instance:
raise InvokeModelError("Model not found")
# get tokens
tokens = model_instance.get_llm_num_tokens(prompt_messages)
return tokens
@staticmethod
def invoke(
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
invoke model with parameters in user's own context
:param user_id: user id
:param tenant_id: tenant id, the tenant id of the creator of the tool
:param tool_type: tool type
:param tool_name: tool name
:param prompt_messages: prompt messages
:return: AssistantPromptMessage
"""
# get model manager
model_manager = ModelManager()
# get model instance
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
)
# get prompt tokens
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
model_parameters = {
"temperature": 0.8,
"top_p": 0.8,
}
# create tool model invoke
tool_model_invoke = ToolModelInvoke(
user_id=user_id,
tenant_id=tenant_id,
provider=model_instance.provider,
tool_type=tool_type,
tool_name=tool_name,
model_parameters=json.dumps(model_parameters),
prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
model_response="",
prompt_tokens=prompt_tokens,
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
provider_response_latency=0,
total_price=0,
currency="USD",
)
db.session.add(tool_model_invoke)
db.session.commit()
try:
response: LLMResult = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
),
)
except InvokeRateLimitError as e:
raise InvokeModelError(f"Invoke rate limit error: {e}")
except InvokeBadRequestError as e:
raise InvokeModelError(f"Invoke bad request error: {e}")
except InvokeConnectionError as e:
raise InvokeModelError(f"Invoke connection error: {e}")
except InvokeAuthorizationError as e:
raise InvokeModelError("Invoke authorization error")
except InvokeServerUnavailableError as e:
raise InvokeModelError(f"Invoke server unavailable error: {e}")
except Exception as e:
raise InvokeModelError(f"Invoke error: {e}")
# update tool model invoke
tool_model_invoke.model_response = response.message.content
if response.usage:
tool_model_invoke.answer_tokens = response.usage.completion_tokens
tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
tool_model_invoke.answer_price_unit = response.usage.completion_price_unit
tool_model_invoke.provider_response_latency = response.usage.latency
tool_model_invoke.total_price = response.usage.total_price
tool_model_invoke.currency = response.usage.currency
db.session.commit()
return response

View File

@ -0,0 +1,389 @@
import re
import uuid
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
from typing import Optional
from flask import request
from requests import get
from yaml import YAMLError, safe_load # type: ignore
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
# set description to extra_info
extra_info["description"] = openapi["info"].get("description", "")
if len(openapi["servers"]) == 0:
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
server_url = openapi["servers"][0]["url"]
request_env = request.headers.get("X-Request-Env")
if request_env:
matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
server_url = matched_servers[0] if matched_servers else server_url
# list all interfaces
interfaces = []
for path, path_item in openapi["paths"].items():
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
for method in methods:
if method in path_item:
interfaces.append(
{
"path": path,
"method": method,
"operation": path_item[method],
}
)
# get all parameters
bundles = []
for interface in interfaces:
# convert parameters
parameters = []
if "parameters" in interface["operation"]:
for parameter in interface["operation"]["parameters"]:
tool_parameter = ToolParameter(
name=parameter["name"],
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
human_description=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=parameter.get("required", False),
form=ToolParameter.ToolParameterForm.LLM,
llm_description=parameter.get("description"),
default=parameter["schema"]["default"]
if "schema" in parameter and "default" in parameter["schema"]
else None,
placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
if typ:
tool_parameter.type = typ
parameters.append(tool_parameter)
# create tool bundle
# check if there is a request body
if "requestBody" in interface["operation"]:
request_body = interface["operation"]["requestBody"]
if "content" in request_body:
for content_type, content in request_body["content"].items():
# if there is a reference, get the reference and overwrite the content
if "schema" not in content:
continue
if "$ref" in content["schema"]:
# get the reference
root = openapi
reference = content["schema"]["$ref"].split("/")[1:]
for ref in reference:
root = root[ref]
# overwrite the content
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
# parse body parameters
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
required = body_schema.get("required", [])
properties = body_schema.get("properties", {})
for name, property in properties.items():
tool = ToolParameter(
name=name,
label=I18nObject(en_US=name, zh_Hans=name),
human_description=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=name in required,
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get("description", ""),
default=property.get("default", None),
placeholder=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
if typ:
tool.type = typ
parameters.append(tool)
# check if parameters is duplicated
parameters_count = {}
for parameter in parameters:
if parameter.name not in parameters_count:
parameters_count[parameter.name] = 0
parameters_count[parameter.name] += 1
for name, count in parameters_count.items():
if count > 1:
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
# check if there is a operation id, use $path_$method as operation id if not
if "operationId" not in interface["operation"]:
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
path = interface["path"]
if interface["path"].startswith("/"):
path = interface["path"][1:]
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
if not path:
path = str(uuid.uuid4())
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
bundles.append(
ApiToolBundle(
server_url=server_url + interface["path"],
method=interface["method"],
summary=interface["operation"]["description"]
if "description" in interface["operation"]
else interface["operation"].get("summary", None),
operation_id=interface["operation"]["operationId"],
parameters=parameters,
author="",
icon=None,
openapi=interface["operation"],
)
)
return bundles
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
parameter = parameter or {}
typ: Optional[str] = None
if parameter.get("format") == "binary":
return ToolParameter.ToolParameterType.FILE
if "type" in parameter:
typ = parameter["type"]
elif "schema" in parameter and "type" in parameter["schema"]:
typ = parameter["schema"]["type"]
if typ in {"integer", "number"}:
return ToolParameter.ToolParameterType.NUMBER
elif typ == "boolean":
return ToolParameter.ToolParameterType.BOOLEAN
elif typ == "string":
return ToolParameter.ToolParameterType.STRING
elif typ == "array":
items = parameter.get("items") or parameter.get("schema", {}).get("items")
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
else:
return None
@staticmethod
def parse_openapi_yaml_to_tool_bundle(
yaml: str, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
"""
parse openapi yaml to tool bundle
:param yaml: the yaml string
:param extra_info: the extra info
:param warning: the warning message
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
openapi: dict = safe_load(yaml)
if openapi is None:
raise ToolApiSchemaError("Invalid openapi yaml.")
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
warning = warning or {}
"""
parse swagger to openapi
:param swagger: the swagger dict
:return: the openapi dict
"""
# convert swagger to openapi
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
servers = swagger.get("servers", [])
if len(servers) == 0:
raise ToolApiSchemaError("No server found in the swagger yaml.")
openapi = {
"openapi": "3.0.0",
"info": {
"title": info.get("title", "Swagger"),
"description": info.get("description", "Swagger"),
"version": info.get("version", "1.0.0"),
},
"servers": swagger["servers"],
"paths": {},
"components": {"schemas": {}},
}
# check paths
if "paths" not in swagger or len(swagger["paths"]) == 0:
raise ToolApiSchemaError("No paths found in the swagger yaml.")
# convert paths
for path, path_item in swagger["paths"].items():
openapi["paths"][path] = {}
for method, operation in path_item.items():
if "operationId" not in operation:
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
if ("summary" not in operation or len(operation["summary"]) == 0) and (
"description" not in operation or len(operation["description"]) == 0
):
if warning is not None:
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
openapi["paths"][path][method] = {
"operationId": operation["operationId"],
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"parameters": operation.get("parameters", []),
"responses": operation.get("responses", {}),
}
if "requestBody" in operation:
openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
# convert definitions
for name, definition in swagger["definitions"].items():
openapi["components"]["schemas"][name] = definition
return openapi
@staticmethod
def parse_openai_plugin_json_to_tool_bundle(
json: str, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
"""
parse openapi plugin yaml to tool bundle
:param json: the json string
:param extra_info: the extra info
:param warning: the warning message
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
try:
openai_plugin = json_loads(json)
api = openai_plugin["api"]
api_url = api["url"]
api_type = api["type"]
except JSONDecodeError:
raise ToolProviderNotFoundError("Invalid openai plugin json.")
if api_type != "openapi":
raise ToolNotSupportedError("Only openapi is supported now.")
# get openapi yaml
response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)
if response.status_code != 200:
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
response.text, extra_info=extra_info, warning=warning
)
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
) -> tuple[list[ApiToolBundle], str]:
"""
auto parse to tool bundle
:param content: the content
:param extra_info: the extra info
:param warning: the warning message
:return: tools bundle, schema_type
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
content = content.strip()
loaded_content = None
json_error = None
yaml_error = None
try:
loaded_content = json_loads(content)
except JSONDecodeError as e:
json_error = e
if loaded_content is None:
try:
loaded_content = safe_load(content)
except YAMLError as e:
yaml_error = e
if loaded_content is None:
raise ToolApiSchemaError(
f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
f" yaml error: {str(yaml_error)}"
)
swagger_error = None
openapi_error = None
openapi_plugin_error = None
schema_type = None
try:
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.OPENAPI.value
return openapi, schema_type
except ToolApiSchemaError as e:
openapi_error = e
# openai parse error, fallback to swagger
try:
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.SWAGGER.value
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
converted_swagger, extra_info=extra_info, warning=warning
), schema_type
except ToolApiSchemaError as e:
swagger_error = e
# swagger parse error, fallback to openai plugin
try:
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
json_dumps(loaded_content), extra_info=extra_info, warning=warning
)
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
except ToolNotSupportedError as e:
# maybe it's not plugin at all
openapi_plugin_error = e
raise ToolApiSchemaError(
f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
f" openapi plugin error: {str(openapi_plugin_error)}"
)

View File

@ -0,0 +1,17 @@
import re
def get_image_upload_file_ids(content):
pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)"
matches = re.findall(pattern, content)
image_upload_file_ids = []
for match in matches:
if match[1] == "file-preview":
content_pattern = r"files/([^/]+)/file-preview"
else:
content_pattern = r"files/([^/]+)/image-preview"
content_match = re.search(content_pattern, match[0])
if content_match:
image_upload_file_id = content_match.group(1)
image_upload_file_ids.append(image_upload_file_id)
return image_upload_file_ids

View File

@ -0,0 +1,17 @@
import re
def remove_leading_symbols(text: str) -> str:
"""
Remove leading punctuation or symbols from the given text.
Args:
text (str): The input text to process.
Returns:
str: The text with leading punctuation or symbols removed.
"""
# Match Unicode ranges for punctuation and symbols
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
return re.sub(pattern, "", text)

View File

@ -0,0 +1,9 @@
import uuid
def is_valid_uuid(uuid_str: str) -> bool:
try:
uuid.UUID(uuid_str)
return True
except Exception:
return False

View File

@ -0,0 +1,375 @@
import hashlib
import json
import mimetypes
import os
import re
import site
import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Literal, Optional, cast
from urllib.parse import unquote
import chardet
import cloudscraper # type: ignore
from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore
from regex import regex # type: ignore
from core.helper import ssrf_proxy
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
FULL_TEMPLATE = """
TITLE: {title}
AUTHORS: {authors}
PUBLISH DATE: {publish_date}
TOP_IMAGE_URL: {top_image}
TEXT:
{text}
"""
def page_result(text: str, cursor: int, max_length: int) -> str:
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
return text[cursor : cursor + max_length]
def get_url(url: str, user_agent: Optional[str] = None) -> str:
"""Fetch URL and return the contents as a string."""
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
" Chrome/91.0.4472.124 Safari/537.36"
}
if user_agent:
headers["User-Agent"] = user_agent
main_content_type = None
supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10))
if response.status_code == 200:
# check content-type
content_type = response.headers.get("Content-Type")
if content_type:
main_content_type = response.headers.get("Content-Type").split(";")[0].strip()
else:
content_disposition = response.headers.get("Content-Disposition", "")
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
extension = re.search(r"\.(\w+)$", filename)
if extension:
main_content_type = mimetypes.guess_type(filename)[0]
if main_content_type not in supported_content_types:
return "Unsupported content-type [{}] of URL.".format(main_content_type)
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
return cast(str, ExtractProcessor.load_from_url(url, return_text=True))
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
elif response.status_code == 403:
scraper = cloudscraper.create_scraper()
scraper.perform_request = ssrf_proxy.make_request
response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
if response.status_code != 200:
return "URL returned status code {}.".format(response.status_code)
# Detect encoding using chardet
detected_encoding = chardet.detect(response.content)
encoding = detected_encoding["encoding"]
if encoding:
try:
content = response.content.decode(encoding)
except (UnicodeDecodeError, TypeError):
content = response.text
else:
content = response.text
a = extract_using_readabilipy(content)
if not a["plain_text"] or not a["plain_text"].strip():
return ""
res = FULL_TEMPLATE.format(
title=a["title"],
authors=a["byline"],
publish_date=a["date"],
top_image="",
text=a["plain_text"] or "",
)
return res
def extract_using_readabilipy(html):
with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html:
f_html.write(html)
f_html.close()
html_path = f_html.name
# Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
article_json_path = html_path + ".json"
jsdir = os.path.join(find_module_path("readabilipy"), "javascript")
with chdir(jsdir):
subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
# Read output of call to Readability.parse() from JSON file and return as Python dictionary
input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8"))
# Deleting files after processing
os.unlink(article_json_path)
os.unlink(html_path)
article_json: dict[str, Any] = {
"title": None,
"byline": None,
"date": None,
"content": None,
"plain_content": None,
"plain_text": None,
}
# Populate article fields from readability fields where present
if input_json:
if input_json.get("title"):
article_json["title"] = input_json["title"]
if input_json.get("byline"):
article_json["byline"] = input_json["byline"]
if input_json.get("date"):
article_json["date"] = input_json["date"]
if input_json.get("content"):
article_json["content"] = input_json["content"]
article_json["plain_content"] = plain_content(article_json["content"], False, False)
article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
if input_json.get("textContent"):
article_json["plain_text"] = input_json["textContent"]
article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"])
return article_json
def find_module_path(module_name):
for package_path in site.getsitepackages():
potential_path = os.path.join(package_path, module_name)
if os.path.exists(potential_path):
return potential_path
return None
@contextmanager
def chdir(path):
"""Change directory in context and return to original on exit"""
# From https://stackoverflow.com/a/37996581, couldn't find a built-in
original_path = os.getcwd()
os.chdir(path)
try:
yield
finally:
os.chdir(original_path)
def extract_text_blocks_as_plain_text(paragraph_html):
# Load article as DOM
soup = BeautifulSoup(paragraph_html, "html.parser")
# Select all lists
list_elements = soup.find_all(["ul", "ol"])
# Prefix text in all list items with "* " and make lists paragraphs
for list_element in list_elements:
plain_items = "".join(
list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")]))
)
list_element.string = plain_items
list_element.name = "p"
# Select all text blocks
text_blocks = [s.parent for s in soup.find_all(string=True)]
text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
# Drop empty paragraphs
text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
return text_blocks
def plain_text_leaf_node(element):
# Extract all text, stripped of any child HTML elements and normalize it
plain_text = normalize_text(element.get_text())
if plain_text != "" and element.name == "li":
plain_text = "* {}, ".format(plain_text)
if plain_text == "":
plain_text = None
if "data-node-index" in element.attrs:
plain = {"node_index": element["data-node-index"], "text": plain_text}
else:
plain = {"text": plain_text}
return plain
def plain_content(readability_content, content_digests, node_indexes):
# Load article as DOM
soup = BeautifulSoup(readability_content, "html.parser")
# Make all elements plain
elements = plain_elements(soup.contents, content_digests, node_indexes)
if node_indexes:
# Add node index attributes to nodes
elements = [add_node_indexes(element) for element in elements]
# Replace article contents with plain elements
soup.contents = elements
return str(soup)
def plain_elements(elements, content_digests, node_indexes):
# Get plain content versions of all elements
elements = [plain_element(element, content_digests, node_indexes) for element in elements]
if content_digests:
# Add content digest attribute to nodes
elements = [add_content_digest(element) for element in elements]
return elements
def plain_element(element, content_digests, node_indexes):
# For lists, we make each item plain text
if is_leaf(element):
# For leaf node elements, extract the text content, discarding any HTML tags
# 1. Get element contents as text
plain_text = element.get_text()
# 2. Normalize the extracted text string to a canonical representation
plain_text = normalize_text(plain_text)
# 3. Update element content to be plain text
element.string = plain_text
elif is_text(element):
if is_non_printing(element):
# The simplified HTML may have come from Readability.js so might
# have non-printing text (e.g. Comment or CData). In this case, we
# keep the structure, but ensure that the string is empty.
element = type(element)("")
else:
plain_text = element.string
plain_text = normalize_text(plain_text)
element = type(element)(plain_text)
else:
# If not a leaf node or leaf type call recursively on child nodes, replacing
element.contents = plain_elements(element.contents, content_digests, node_indexes)
return element
def add_node_indexes(element, node_index="0"):
# Can't add attributes to string types
if is_text(element):
return element
# Add index to current element
element["data-node-index"] = node_index
# Add index to child elements
for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1):
# Can't add attributes to leaf string types
child_index = "{stem}.{local}".format(stem=node_index, local=local_idx)
add_node_indexes(child, node_index=child_index)
return element
def normalize_text(text):
"""Normalize unicode and whitespace."""
# Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them
text = strip_control_characters(text)
text = normalize_unicode(text)
text = normalize_whitespace(text)
return text
def strip_control_characters(text):
"""Strip out unicode control characters which might break the parsing."""
# Unicode control characters
# [Cc]: Other, Control [includes new lines]
# [Cf]: Other, Format
# [Cn]: Other, Not Assigned
# [Co]: Other, Private Use
# [Cs]: Other, Surrogate
control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"}
retained_chars = ["\t", "\n", "\r", "\f"]
# Remove non-printing control characters
return "".join(
[
"" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char
for char in text
]
)
def normalize_unicode(text):
"""Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC"
text = unicodedata.normalize(normal_form, text)
return text
def normalize_whitespace(text):
"""Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
text = regex.sub(r"\s+", " ", text)
# Remove leading and trailing whitespace
text = text.strip()
return text
def is_leaf(element):
return element.name in {"p", "li"}
def is_text(element):
return isinstance(element, NavigableString)
def is_non_printing(element):
return any(isinstance(element, _e) for _e in [Comment, CData])
def add_content_digest(element):
if not is_text(element):
element["data-content-digest"] = content_digest(element)
return element
def content_digest(element):
digest: Any
if is_text(element):
# Hash
trimmed_string = element.string.strip()
if trimmed_string == "":
digest = ""
else:
digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest()
else:
contents = element.contents
num_contents = len(contents)
if num_contents == 0:
# No hash when no child elements exist
digest = ""
elif num_contents == 1:
# If single child, use digest of child
digest = content_digest(contents[0])
else:
# Build content digest from the "non-empty" digests of child nodes
digest = hashlib.sha256()
child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents]))
for child in child_digests:
digest.update(child.encode("utf-8"))
digest = digest.hexdigest()
return digest
def get_image_upload_file_ids(content):
pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)"
matches = re.findall(pattern, content)
image_upload_file_ids = []
for match in matches:
if match[1] == "file-preview":
content_pattern = r"files/([^/]+)/file-preview"
else:
content_pattern = r"files/([^/]+)/image-preview"
content_match = re.search(content_pattern, match[0])
if content_match:
image_upload_file_id = content_match.group(1)
image_upload_file_ids.append(image_upload_file_id)
return image_upload_file_ids

View File

@ -0,0 +1,43 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@classmethod
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
"""
get workflow graph variables
"""
nodes = graph.get("nodes", [])
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
if not start_node:
return []
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
@classmethod
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
):
"""
check is synced
raise ValueError if not synced
"""
variable_names = [variable.variable for variable in variables]
if len(tool_configurations) != len(variables):
raise ValueError("parameter configuration mismatch, please republish the tool to update")
for parameter in tool_configurations:
if parameter.name not in variable_names:
raise ValueError("parameter configuration mismatch, please republish the tool to update")

View File

@ -0,0 +1,35 @@
import logging
from pathlib import Path
from typing import Any
import yaml # type: ignore
from yaml import YAMLError
logger = logging.getLogger(__name__)
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
"""
Safe loading a YAML file
:param file_path: the path of the YAML file
:param ignore_error:
if True, return default_value if error occurs and the error will be logged in debug level
if False, raise error if error occurs
:param default_value: the value returned when errors ignored
:return: an object of the YAML content
"""
if not file_path or not Path(file_path).exists():
if ignore_error:
return default_value
else:
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, encoding="utf-8") as yaml_file:
try:
yaml_content = yaml.safe_load(yaml_file)
return yaml_content or default_value
except Exception as e:
if ignore_error:
return default_value
else:
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e

View File

@ -1,3 +1,4 @@
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
PIPELINE_VARIABLE_NODE_ID = "pipeline"

View File

@ -17,6 +17,7 @@ class NodeRunMetadataKey(StrEnum):
TOTAL_PRICE = "total_price"
CURRENCY = "currency"
TOOL_INFO = "tool_info"
DATASOURCE_INFO = "datasource_info"
AGENT_LOG = "agent_log"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"

View File

@ -0,0 +1,3 @@
from .tool_node import ToolNode
__all__ = ["DatasourceNode"]

View File

@ -0,0 +1,406 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.plugin.manager.plugin import PluginInstallationManager
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile
from models.workflow import WorkflowNodeExecutionStatus
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .entities import DatasourceNodeData
from .exc import (
ToolFileError,
ToolNodeError,
ToolParameterError,
)
class DatasourceNode(BaseNode[DatasourceNodeData]):
"""
Datasource Node
"""
_node_data_cls = DatasourceNodeData
_node_type = NodeType.DATASOURCE
def _run(self) -> Generator:
"""
Run the datasource node
"""
node_data = cast(DatasourceNodeData, self.node_data)
# fetch datasource icon
datasource_info = {
"provider_type": node_data.provider_type.value,
"provider_id": node_data.provider_id,
"plugin_unique_identifier": node_data.plugin_unique_identifier,
}
# get datasource runtime
try:
from core.tools.tool_manager import ToolManager
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
)
except ToolNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to get datasource runtime: {str(e)}",
error_type=type(e).__name__,
)
)
return
# get parameters
tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
for_log=True,
)
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
try:
message_stream = ToolEngine.generic_invoke(
tool=tool_runtime,
tool_parameters=parameters,
user_id=self.user_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
thread_pool_id=self.thread_pool_id,
app_id=self.app_id,
conversation_id=conversation_id.text if conversation_id else None,
)
except ToolNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to invoke tool: {str(e)}",
error_type=type(e).__name__,
)
)
return
try:
# convert tool messages
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
except (PluginDaemonClientSideError, ToolInvokeError) as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to transform tool message: {str(e)}",
error_type=type(e).__name__,
)
)
def _generate_parameters(
self,
*,
tool_parameters: Sequence[ToolParameter],
variable_pool: VariablePool,
node_data: ToolNodeData,
for_log: bool = False,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
Args:
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
variable_pool (VariablePool): The variable pool containing the variables.
node_data (ToolNodeData): The data associated with the tool node.
Returns:
Mapping[str, Any]: A dictionary containing the generated parameters.
"""
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
result: dict[str, Any] = {}
for parameter_name in node_data.tool_parameters:
parameter = tool_parameters_dictionary.get(parameter_name)
if not parameter:
result[parameter_name] = None
continue
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}:
segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.log if for_log else segment_group.text
else:
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
result[parameter_name] = parameter_value
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _transform_message(
self,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json: list[dict] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {}
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
ToolInvokeMessage.MessageType.IMAGE_LINK,
ToolInvokeMessage.MessageType.BINARY_LINK,
ToolInvokeMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
url = message.message.text
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"tool file {tool_file_id} not exists")
mapping = {
"tool_file_id": tool_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if self.node_type == NodeType.AGENT:
msg_metadata = message.message.json_object.pop("execution_metadata", {})
agent_execution_metadata = {
key: value
for key, value in msg_metadata.items()
if key in NodeRunMetadataKey.__members__.values()
}
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
)
else:
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
if message.message.metadata:
icon = tool_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstallationManager()
plugins = manager.list_plugins(self.tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
)
icon = current_plugin.declaration.icon
except StopIteration:
pass
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
self.user_id,
self.tenant_id,
)
if provider.name == dict_metadata["provider"]
)
icon = builtin_tool.icon
except StopIteration:
pass
dict_metadata["icon"] = icon
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
id=message.message.id,
node_execution_id=self.id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
node_id=self.node_id,
)
# check if the agent log is already in the list
for log in agent_logs:
if log.id == agent_log.id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
yield agent_log
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": files, "json": json, **variables},
metadata={
**agent_execution_metadata,
NodeRunMetadataKey.TOOL_INFO: tool_info,
NodeRunMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
)
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ToolNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
result = {}
for parameter_name in node_data.tool_parameters:
input = node_data.tool_parameters[parameter_name]
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
result[parameter_name] = input.value
elif input.type == "constant":
pass
result = {node_id + "." + key: value for key, value in result.items()}
return result

View File

@ -0,0 +1,56 @@
from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.base.entities import BaseNodeData
class DatasourceEntity(BaseModel):
provider_id: str
provider_type: ToolProviderType
provider_name: str # redundancy
tool_name: str
tool_label: str # redundancy
tool_configurations: dict[str, Any]
plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before")
@classmethod
def validate_tool_configurations(cls, value, values: ValidationInfo):
if not isinstance(value, dict):
raise ValueError("tool_configurations must be a dictionary")
for key in values.data.get("tool_configurations", {}):
value = values.data.get("tool_configurations", {}).get(key)
if not isinstance(value, str | int | float | bool):
raise ValueError(f"{key} must be a string")
return value
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
class DatasourceInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant"]
@field_validator("type", mode="before")
@classmethod
def check_type(cls, value, validation_info: ValidationInfo):
typ = value
value = validation_info.data.get("value")
if typ == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
elif typ == "variable":
if not isinstance(value, list):
raise ValueError("value must be a list")
for val in value:
if not isinstance(val, str):
raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, str | int | float | bool):
raise ValueError("value must be a string, int, float, or bool")
return typ
datasource_parameters: dict[str, DatasourceInput]

View File

@ -0,0 +1,16 @@
class ToolNodeError(ValueError):
"""Base exception for tool node errors."""
pass
class ToolParameterError(ToolNodeError):
"""Exception raised for errors in tool parameters."""
pass
class ToolFileError(ToolNodeError):
"""Exception raised for errors related to tool files."""
pass

View File

@ -13,6 +13,7 @@ class NodeType(StrEnum):
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"
DATASOURCE = "datasource"
VARIABLE_AGGREGATOR = "variable-aggregator"
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"

View File

@ -36,7 +36,11 @@ from core.variables.variables import (
StringVariable,
Variable,
)
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
PIPELINE_VARIABLE_NODE_ID,
)
class InvalidSelectorError(ValueError):
@ -74,6 +78,10 @@ def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Va
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[PIPELINE_VARIABLE_NODE_ID, mapping["name"]])
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
"""

View File

@ -40,6 +40,13 @@ conversation_variable_fields = {
"description": fields.String,
}
pipeline_variable_fields = {
"id": fields.String,
"name": fields.String,
"value_type": fields.String(attribute="value_type.value"),
"value": fields.Raw,
}
workflow_fields = {
"id": fields.String,
"graph": fields.Raw(attribute="graph_dict"),
@ -55,6 +62,10 @@ workflow_fields = {
"tool_published": fields.Boolean,
"environment_variables": fields.List(EnvironmentVariableField()),
"conversation_variables": fields.List(fields.Nested(conversation_variable_fields)),
"pipeline_variables": fields.Dict(
keys=fields.String,
values=fields.List(fields.Nested(pipeline_variable_fields)),
),
}
workflow_partial_fields = {

View File

@ -130,6 +130,9 @@ class Workflow(Base):
_conversation_variables: Mapped[str] = mapped_column(
"conversation_variables", db.Text, nullable=False, server_default="{}"
)
_pipeline_variables: Mapped[str] = mapped_column(
"conversation_variables", db.Text, nullable=False, server_default="{}"
)
@classmethod
def new(
@ -343,6 +346,24 @@ class Workflow(Base):
ensure_ascii=False,
)
@property
def pipeline_variables(self) -> dict[str, Sequence[Variable]]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._pipeline_variables is None:
self._pipeline_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._pipeline_variables)
results = {}
for k, v in variables_dict.items():
results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()]
return results
@pipeline_variables.setter
def pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None:
self._pipeline_variables = json.dumps(
{k: {item.name: item.model_dump() for item in v} for k, v in values.items()},
ensure_ascii=False,
)
class WorkflowRunStatus(StrEnum):
"""

View File

@ -16,12 +16,13 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.workflow_entry import WorkflowEntry
from extensions.db import db
from extensions.ext_database import db
from models.account import Account
from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore
from models.workflow import Workflow, WorkflowNodeExecution, WorkflowType
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.errors.app import WorkflowHashNotEqualError
from services.errors.workflow_service import DraftWorkflowDeletionError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
@ -186,6 +187,7 @@ class RagPipelineService:
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
pipeline_variables: dict[str, Sequence[Variable]],
) -> Workflow:
"""
Sync draft workflow
@ -212,6 +214,7 @@ class RagPipelineService:
created_by=account.id,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
pipeline_variables=pipeline_variables,
)
db.session.add(workflow)
# update draft workflow if found
@ -222,7 +225,7 @@ class RagPipelineService:
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
workflow.environment_variables = environment_variables
workflow.conversation_variables = conversation_variables
workflow.pipeline_variables = pipeline_variables
# commit db session changes
db.session.commit()
@ -343,6 +346,41 @@ class RagPipelineService:
return workflow_node_execution
def run_datasource_workflow_node(
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
) -> WorkflowNodeExecution:
"""
Run published workflow datasource
"""
# fetch published workflow by app_model
published_workflow = self.get_published_workflow(pipeline=pipeline)
if not published_workflow:
raise ValueError("Workflow not initialized")
# run draft workflow node
start_at = time.perf_counter()
workflow_node_execution = self._handle_node_run_result(
getter=lambda: WorkflowEntry.single_step_run(
workflow=published_workflow,
node_id=node_id,
user_inputs=user_inputs,
user_id=account.id,
),
start_at=start_at,
tenant_id=pipeline.tenant_id,
node_id=node_id,
)
workflow_node_execution.app_id = pipeline.id
workflow_node_execution.created_by = account.id
workflow_node_execution.workflow_id = published_workflow.id
db.session.add(workflow_node_execution)
db.session.commit()
return workflow_node_execution
def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> WorkflowNodeExecution:
@ -573,3 +611,21 @@ class RagPipelineService:
session.delete(workflow)
return True
def get_second_step_parameters(self, pipeline: Pipeline, datasource_provider: str) -> dict:
"""
Get second step parameters of rag pipeline
"""
workflow = self.get_published_workflow(pipeline=pipeline)
if not workflow:
raise ValueError("Workflow not initialized")
# get second step node
pipeline_variables = workflow.pipeline_variables
if not pipeline_variables:
return {}
# get datasource provider
datasource_provider_variables = pipeline_variables.get(datasource_provider, [])
shared_variables = pipeline_variables.get("shared", [])
return datasource_provider_variables + shared_variables