merge main

Joel 2025-04-08 14:17:07 +08:00
commit 358d5978cb
387 changed files with 9775 additions and 1432 deletions

View File

@ -7,7 +7,7 @@ pipx install poetry
echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
source /home/vscode/.bashrc

View File

@ -6,6 +6,7 @@ on:
- "main"
- "deploy/dev"
- "deploy/enterprise"
- release/1.1.3-fix1
tags:
- "*"

View File

@ -254,8 +254,6 @@ docker compose up -d
- [Discord](https://discord.gg/FngNHpbcY7). 👉: Share your applications and chat with the community.
- [X(Twitter)](https://twitter.com/dify_ai). 👉: Share your applications and chat with the community.
- [Business License](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry). 👉: Business inquiries about commercially licensing Dify.AI.
- [WeChat]() 👉: Scan the QR code below to add us on WeChat, mention "Dify", and we will invite you to join the Dify community.
<img src="./images/wechat.png" alt="wechat" width="100"/>
## Security Issues

View File

@ -189,6 +189,7 @@ TENCENT_VECTOR_DB_USERNAME=dify
TENCENT_VECTOR_DB_DATABASE=dify
TENCENT_VECTOR_DB_SHARD=1
TENCENT_VECTOR_DB_REPLICAS=2
TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH=false
# ElasticSearch configuration
ELASTICSEARCH_HOST=127.0.0.1

View File

@ -848,6 +848,11 @@ class AccountConfig(BaseSettings):
default=5,
)
EDUCATION_ENABLED: bool = Field(
description="whether to enable education identity",
default=False,
)
class FeatureConfig(
# place the configs in alphabet order

View File

@ -48,3 +48,8 @@ class TencentVectorDBConfig(BaseSettings):
description="Name of the specific Tencent Vector Database to connect to",
default=None,
)
TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH: bool = Field(
description="Enable hybrid search features",
default=False,
)
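
A minimal sketch (not part of this commit) of how the new flag is consumed: pydantic-settings reads `TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH` from the environment, so setting it in the API `.env` (see the env hunk above) is enough to opt in. The standalone class below is illustrative only.

```python
import os
from pydantic import Field
from pydantic_settings import BaseSettings

class TencentVectorDBConfig(BaseSettings):
    # Mirrors the field added above; hybrid search stays disabled by default.
    TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH: bool = Field(
        description="Enable hybrid search features",
        default=False,
    )

os.environ["TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH"] = "true"
print(TencentVectorDBConfig().TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH)  # True
```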

View File

@ -8,6 +8,7 @@ from werkzeug.exceptions import Forbidden
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
)
from extensions.ext_database import db
@ -23,6 +24,7 @@ class AppImportApi(Resource):
@login_required
@account_initialization_required
@marshal_with(app_import_fields)
@cloud_edition_billing_resource_check("apps")
def post(self):
# Check user role first
if not current_user.is_editor:

View File

@ -99,53 +99,64 @@ class ForgotPasswordResetApi(Resource):
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
args = parser.parse_args()
new_password = args["new_password"]
password_confirm = args["password_confirm"]
if str(new_password).strip() != str(password_confirm).strip():
# Validate passwords match
if args["new_password"] != args["password_confirm"]:
raise PasswordMismatchError()
token = args["token"]
reset_data = AccountService.get_reset_password_data(token)
if reset_data is None:
# Validate token and get reset data
reset_data = AccountService.get_reset_password_data(args["token"])
if not reset_data:
raise InvalidTokenError()
AccountService.revoke_reset_password_token(token)
# Revoke token to prevent reuse
AccountService.revoke_reset_password_token(args["token"])
# Generate secure salt and hash password
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
password_hashed = hash_password(args["new_password"], salt)
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
email = reset_data.get("email", "")
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=reset_data.get("email"))).scalar_one_or_none()
if account:
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
tenant = TenantService.get_join_tenants(account)
if not tenant and not FeatureService.get_system_features().is_allow_create_workspace:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
else:
try:
account = AccountService.create_account_and_tenant(
email=reset_data.get("email", ""),
name=reset_data.get("email", ""),
password=password_confirm,
interface_language=languages[0],
)
except WorkSpaceNotAllowedCreateError:
pass
except AccountRegisterError:
raise AccountInFreezeError()
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account:
self._update_existing_account(account, password_hashed, salt, session)
else:
self._create_new_account(email, args["password_confirm"])
return {"result": "success"}
def _update_existing_account(self, account, password_hashed, salt, session):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()
session.commit()
# Create workspace if needed
if (
not TenantService.get_join_tenants(account)
and FeatureService.get_system_features().is_allow_create_workspace
):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
def _create_new_account(self, email, password):
# Create new account if allowed
try:
AccountService.create_account_and_tenant(
email=email,
name=email,
password=password,
interface_language=languages[0],
)
except WorkSpaceNotAllowedCreateError:
pass
except AccountRegisterError:
raise AccountInFreezeError()
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
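
For reference, a self-contained sketch of the salt-and-hash flow the reset handler above relies on. `hash_password` here is a stand-in for Dify's `libs.password.hash_password`; the PBKDF2 parameters are an assumption, not taken from this commit.

```python
import base64
import hashlib
import secrets

def hash_password(password: str, salt: bytes) -> bytes:
    # Stand-in for libs.password.hash_password; algorithm and iteration count are assumed.
    return hashlib.pbkdf2_hmac("sha256", password.encode(), salt, 10_000)

salt = secrets.token_bytes(16)                 # per-account random salt
base64_salt = base64.b64encode(salt).decode()  # stored as account.password_salt
base64_password_hashed = base64.b64encode(hash_password("new-password", salt)).decode()  # account.password
```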

View File

@ -641,7 +641,6 @@ class DatasetRetrievalSettingApi(Resource):
VectorType.RELYT
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS
| VectorType.BAIDU
| VectorType.VIKINGDB
@ -665,6 +664,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.OPENGAUSS
| VectorType.OCEANBASE
| VectorType.TABLESTORE
| VectorType.TENCENT
):
return {
"retrieval_method": [
@ -688,7 +688,6 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.RELYT
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS
| VectorType.BAIDU
| VectorType.VIKINGDB
@ -710,6 +709,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.OPENGAUSS
| VectorType.OCEANBASE
| VectorType.TABLESTORE
| VectorType.TENCENT
):
return {
"retrieval_method": [

View File

@ -14,7 +14,12 @@ class WebsiteCrawlApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
"provider", type=str, choices=["firecrawl", "jinareader"], required=True, nullable=True, location="json"
"provider",
type=str,
choices=["firecrawl", "watercrawl", "jinareader"],
required=True,
nullable=True,
location="json",
)
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
@ -34,7 +39,9 @@ class WebsiteCrawlStatusApi(Resource):
@account_initialization_required
def get(self, job_id: str):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, choices=["firecrawl", "jinareader"], required=True, location="args")
parser.add_argument(
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
)
args = parser.parse_args()
# get crawl status
try:

View File

@ -103,6 +103,18 @@ class AccountInFreezeError(BaseHTTPException):
)
class EducationVerifyLimitError(BaseHTTPException):
error_code = "education_verify_limit"
description = "Rate limit exceeded"
code = 429
class EducationActivateLimitError(BaseHTTPException):
error_code = "education_activate_limit"
description = "Rate limit exceeded"
code = 429
class CompilanceRateLimitError(BaseHTTPException):
error_code = "compilance_rate_limit"
description = "Rate limit exceeded for downloading compliance report."

View File

@ -15,7 +15,13 @@ from controllers.console.workspace.error import (
InvalidInvitationCodeError,
RepeatPasswordNotMatchError,
)
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_enabled,
enterprise_license_required,
only_edition_cloud,
setup_required,
)
from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.helper import TimestampField, timezone
@ -292,6 +298,79 @@ class AccountDeleteUpdateFeedbackApi(Resource):
return {"result": "success"}
class EducationVerifyApi(Resource):
verify_fields = {
"token": fields.String,
}
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(verify_fields)
def get(self):
account = current_user
return BillingService.EducationIdentity.verify(account.id, account.email)
class EducationApi(Resource):
status_fields = {
"result": fields.Boolean,
}
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
def post(self):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, location="json")
parser.add_argument("institution", type=str, required=True, location="json")
parser.add_argument("role", type=str, required=True, location="json")
args = parser.parse_args()
return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(status_fields)
def get(self):
account = current_user
return BillingService.EducationIdentity.is_active(account.id)
class EducationAutoCompleteApi(Resource):
data_fields = {
"data": fields.List(fields.String),
"curr_page": fields.Integer,
"has_next": fields.Boolean,
}
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(data_fields)
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("keywords", type=str, required=True, location="args")
parser.add_argument("page", type=int, required=False, location="args", default=0)
parser.add_argument("limit", type=int, required=False, location="args", default=20)
args = parser.parse_args()
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
# Register API resources
api.add_resource(AccountInitApi, "/account/init")
api.add_resource(AccountProfileApi, "/account/profile")
@ -305,5 +384,8 @@ api.add_resource(AccountIntegrateApi, "/account/integrates")
api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify")
api.add_resource(AccountDeleteApi, "/account/delete")
api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback")
api.add_resource(EducationVerifyApi, "/account/education/verify")
api.add_resource(EducationApi, "/account/education")
api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete")
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')
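
A hedged usage sketch for the new education routes. These are console, cloud-edition, billing-enabled endpoints that require an authenticated console session; the base URL and bearer token below are placeholders, not values from this commit.

```python
import requests

BASE = "http://localhost:5001/console/api"            # assumed console API prefix
HEADERS = {"Authorization": "Bearer <console-access-token>"}  # placeholder

# Check whether the current account already has an active education identity
print(requests.get(f"{BASE}/account/education", headers=HEADERS).json())

# Search institutions for the autocomplete widget
print(requests.get(f"{BASE}/account/education/autocomplete",
                   params={"keywords": "university", "page": 0, "limit": 20},
                   headers=HEADERS).json())
```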

View File

@ -236,7 +236,7 @@ class PluginFetchManifestApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
@ -260,7 +260,7 @@ class PluginFetchInstallTasksApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
@ -281,7 +281,7 @@ class PluginFetchInstallTaskApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
@plugin_permission_required(install_required=True)
def get(self, task_id: str):
tenant_id = current_user.current_tenant_id
@ -295,7 +295,7 @@ class PluginDeleteInstallTaskApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
@plugin_permission_required(install_required=True)
def post(self, task_id: str):
tenant_id = current_user.current_tenant_id
@ -309,7 +309,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
@ -323,7 +323,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
@plugin_permission_required(install_required=True)
def post(self, task_id: str, identifier: str):
tenant_id = current_user.current_tenant_id
@ -337,7 +337,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
@ -360,7 +360,7 @@ class PluginUpgradeFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
@ -391,7 +391,7 @@ class PluginUninstallApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
@plugin_permission_required(install_required=True)
def post(self):
req = reqparse.RequestParser()
req.add_argument("plugin_installation_id", type=str, required=True, location="json")

View File

@ -216,6 +216,23 @@ class WebappLogoWorkspaceApi(Resource):
return {"id": upload_file.id}, 201
class WorkspaceInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
# Change workspace name
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404()
tenant.name = args["name"]
db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants
api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants
api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info
@ -223,3 +240,4 @@ api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated
api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant
api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config")
api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload")
api.add_resource(WorkspaceInfoApi, "/workspaces/info") # POST for changing workspace info
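
A hedged sketch of calling the new endpoint to rename the current workspace; the console API prefix and access token are assumptions for illustration.

```python
import requests

resp = requests.post(
    "http://localhost:5001/console/api/workspaces/info",   # assumed console prefix
    json={"name": "Team Workspace"},
    headers={"Authorization": "Bearer <console-access-token>"},  # placeholder
)
print(resp.json())  # expected: {"result": "success", "tenant": {...}}
```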

View File

@ -54,6 +54,17 @@ def only_edition_self_hosted(view):
return decorated
def cloud_edition_billing_enabled(view):
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled:
abort(403, "Billing feature is not enabled.")
return view(*args, **kwargs)
return decorated
def cloud_edition_billing_resource_check(resource: str):
def interceptor(view):
@wraps(view)

View File

@ -6,5 +6,6 @@ bp = Blueprint("service_api", __name__, url_prefix="/v1")
api = ExternalApi(bp)
from . import index
from .app import app, audio, completion, conversation, file, message, workflow
from .app import annotation, app, audio, completion, conversation, file, message, workflow
from .dataset import dataset, document, hit_testing, metadata, segment, upload_file
from .workspace import models

View File

@ -0,0 +1,107 @@
from flask import request
from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from controllers.service_api import api
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import (
annotation_fields,
)
from libs.login import current_user
from models.model import App, EndUser
from services.annotation_service import AppAnnotationService
class AnnotationReplyActionApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser, action):
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
args = parser.parse_args()
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
else:
raise ValueError("Unsupported annotation reply action")
return result, 200
class AnnotationReplyActionStatusApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, job_id, action):
job_id = str(job_id)
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None:
raise ValueError("The job does not exist.")
job_status = cache_result.decode()
error_msg = ""
if job_status == "error":
app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id))
error_msg = redis_client.get(app_annotation_error_key).decode()
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
class AnnotationListApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword)
response = {
"data": marshal(annotation_list, annotation_fields),
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(annotation_fields)
def post(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
return annotation
class AnnotationUpdateDeleteApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(annotation_fields)
def post(self, app_model: App, end_user: EndUser, annotation_id):
if not current_user.is_editor:
raise Forbidden()
annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
return annotation
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def delete(self, app_model: App, end_user: EndUser, annotation_id):
if not current_user.is_editor:
raise Forbidden()
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
return {"result": "success"}, 200
api.add_resource(AnnotationReplyActionApi, "/apps/annotation-reply/<string:action>")
api.add_resource(AnnotationReplyActionStatusApi, "/apps/annotation-reply/<string:action>/status/<uuid:job_id>")
api.add_resource(AnnotationListApi, "/apps/annotations")
api.add_resource(AnnotationUpdateDeleteApi, "/apps/annotations/<uuid:annotation_id>")
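
A hedged usage sketch for the new Service API annotation routes. The `/v1` prefix comes from the service_api blueprint shown below; the host and app API key are placeholders.

```python
import requests

BASE = "http://localhost:5001/v1"
HEADERS = {"Authorization": "Bearer app-xxxxxxxx"}  # app-level Service API key (placeholder)

# Create an annotation directly
requests.post(f"{BASE}/apps/annotations",
              json={"question": "What is Dify?", "answer": "An open-source LLM app platform."},
              headers=HEADERS)

# Page through existing annotations
print(requests.get(f"{BASE}/apps/annotations",
                   params={"page": 1, "limit": 20, "keyword": ""},
                   headers=HEADERS).json())
```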

View File

@ -1,3 +1,4 @@
import json
import logging
from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
@ -10,7 +11,7 @@ from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import message_file_fields
from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields
from fields.message_fields import agent_thought_fields, feedback_fields
from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from models.model import App, AppMode, EndUser
@ -19,6 +20,14 @@ from services.message_service import MessageService
class MessageListApi(Resource):
def get_retriever_resources(self):
try:
if self.message_metadata:
return json.loads(self.message_metadata).get("retriever_resources", [])
return []
except (json.JSONDecodeError, TypeError):
return []
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
@ -28,7 +37,7 @@ class MessageListApi(Resource):
"answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"retriever_resources": get_retriever_resources,
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"status": fields.String,

View File

@ -1,6 +1,6 @@
from flask import request
from flask_restful import marshal, reqparse # type: ignore
from werkzeug.exceptions import NotFound
from werkzeug.exceptions import Forbidden, NotFound
import services.dataset_service
from controllers.service_api import api
@ -12,7 +12,7 @@ from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from libs.login import current_user
from models.dataset import Dataset, DatasetPermissionEnum
from services.dataset_service import DatasetService
from services.dataset_service import DatasetPermissionService, DatasetService
def _validate_name(name):
@ -21,6 +21,12 @@ def _validate_name(name):
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class DatasetListApi(DatasetApiResource):
"""Resource for datasets."""
@ -137,6 +143,145 @@ class DatasetListApi(DatasetApiResource):
class DatasetApi(DatasetApiResource):
"""Resource for dataset."""
def get(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
model_names = []
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
if data["indexing_technique"] == "high_quality":
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names:
data["embedding_available"] = True
else:
data["embedding_available"] = False
else:
data["embedding_available"] = True
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
return data, 200
def patch(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
parser.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
parser.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
)
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
parser.add_argument(
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
)
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
parser.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
parser.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
parser.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
args = parser.parse_args()
data = request.get_json()
# check embedding model setting
if data.get("indexing_technique") == "high_quality":
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get("permission"), data.get("partial_member_list")
)
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
if dataset is None:
raise NotFound("Dataset not found.")
result_data = marshal(dataset, dataset_detail_fields)
tenant_id = current_user.current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get("partial_member_list")
)
# clear partial member list when permission is only_me or all_team_members
elif (
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
result_data.update({"partial_member_list": partial_member_list})
return result_data, 200
def delete(self, _, dataset_id):
"""
Deletes a dataset given its ID.
@ -158,6 +303,7 @@ class DatasetApi(DatasetApiResource):
try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return {"result": "success"}, 204
else:
raise NotFound("Dataset not found.")
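
A hedged sketch of exercising the new `patch` handler through the Knowledge API. The `/v1/datasets/<id>` route is assumed from the existing `DatasetApi` registration (not shown in this hunk); the key and dataset ID are placeholders.

```python
import requests

resp = requests.patch(
    "http://localhost:5001/v1/datasets/<dataset-id>",      # assumed DatasetApi route
    json={"name": "Product Docs", "permission": "only_me"},
    headers={"Authorization": "Bearer dataset-xxxxxxxx"},   # dataset API key (placeholder)
)
print(resp.status_code, resp.json())
```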

View File

@ -341,7 +341,7 @@ class DocumentListApi(DatasetApiResource):
search = f"%{search}%"
query = query.filter(Document.name.like(search))
query = query.order_by(desc(Document.created_at))
query = query.order_by(desc(Document.created_at), desc(Document.position))
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items

View File

@ -0,0 +1,21 @@
from flask_login import current_user # type: ignore
from flask_restful import Resource # type: ignore
from controllers.service_api import api
from controllers.service_api.wraps import validate_dataset_token
from core.model_runtime.utils.encoders import jsonable_encoder
from services.model_provider_service import ModelProviderService
class ModelProviderAvailableModelApi(Resource):
@validate_dataset_token
def get(self, _, model_type):
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
return jsonable_encoder({"data": models})
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
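
A hedged example of hitting the new route with a Knowledge/Dataset API key; the host and key are placeholders, and `text-embedding` is assumed to be a valid `model_type` segment.

```python
import requests

resp = requests.get(
    "http://localhost:5001/v1/workspaces/current/models/model-types/text-embedding",
    headers={"Authorization": "Bearer dataset-xxxxxxxx"},  # dataset API key (placeholder)
)
for model in resp.json()["data"]:
    print(model)
```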

View File

@ -59,6 +59,27 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("The workspace's status is archived.")
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == api_token.tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role.in_(["owner"]))
.filter(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = Account.query.filter_by(id=ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
else:
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")
kwargs["app_model"] = app_model
if fetch_user_arg:

View File

@ -19,6 +19,8 @@ class PassportResource(Resource):
def get(self):
system_features = FeatureService.get_system_features()
app_code = request.headers.get("X-App-Code")
user_id = request.args.get("user_id")
if app_code is None:
raise Unauthorized("X-App-Code header is missing.")
@ -36,16 +38,33 @@ class PassportResource(Resource):
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="browser",
is_anonymous=True,
session_id=generate_session_id(),
)
if user_id:
end_user = (
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
)
db.session.add(end_user)
db.session.commit()
if end_user:
pass
else:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="browser",
is_anonymous=True,
session_id=user_id,
)
db.session.add(end_user)
db.session.commit()
else:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="browser",
is_anonymous=True,
session_id=generate_session_id(),
)
db.session.add(end_user)
db.session.commit()
payload = {
"iss": site.app_id,

View File

@ -12,39 +12,45 @@ class CotAgentOutputParser:
def handle_react_stream_output(
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def parse_action(json_str):
try:
action = json.loads(json_str, strict=False)
action_name = None
action_input = None
def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]:
action_name = None
action_input = None
if isinstance(action, str):
try:
action = json.loads(action, strict=False)
except json.JSONDecodeError:
return action or ""
# cohere always returns a list
if isinstance(action, list) and len(action) == 1:
action = action[0]
# cohere always returns a list
if isinstance(action, list) and len(action) == 1:
action = action[0]
for key, value in action.items():
if "input" in key.lower():
action_input = value
else:
action_name = value
if action_name is not None and action_input is not None:
return AgentScratchpadUnit.Action(
action_name=action_name,
action_input=action_input,
)
for key, value in action.items():
if "input" in key.lower():
action_input = value
else:
return json_str or ""
except:
return json_str or ""
action_name = value
def extra_json_from_code_block(code_block) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
if not code_blocks:
return
for block in code_blocks:
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
yield parse_action(json_text)
if action_name is not None and action_input is not None:
return AgentScratchpadUnit.Action(
action_name=action_name,
action_input=action_input,
)
else:
return json.dumps(action)
def extra_json_from_code_block(code_block) -> list[Union[list, dict]]:
blocks = re.findall(r"```[json]*\s*([\[{].*[]}])\s*```", code_block, re.DOTALL | re.IGNORECASE)
if not blocks:
return []
try:
json_blocks = []
for block in blocks:
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
json_blocks.append(json.loads(json_text, strict=False))
return json_blocks
except:
return []
code_block_cache = ""
code_block_delimiter_count = 0
@ -78,7 +84,7 @@ class CotAgentOutputParser:
delta = response_content[index : index + steps]
yield_delta = False
if delta == "`":
if not in_json and delta == "`":
last_character = delta
code_block_cache += delta
code_block_delimiter_count += 1
@ -159,8 +165,14 @@ class CotAgentOutputParser:
if code_block_delimiter_count == 3:
if in_code_block:
last_character = delta
yield from extra_json_from_code_block(code_block_cache)
code_block_cache = ""
action_json_list = extra_json_from_code_block(code_block_cache)
if action_json_list:
for action_json in action_json_list:
yield parse_action(action_json)
code_block_cache = ""
else:
index += steps
continue
in_code_block = not in_code_block
code_block_delimiter_count = 0
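
To make the new code-block handling concrete, a standalone illustration of the regex now used by `extra_json_from_code_block` (the sample LLM output is invented):

```python
import json
import re

sample = 'Thought: search\n```json\n{"action": "web_search", "action_input": "dify"}\n```'
blocks = re.findall(r"```[json]*\s*([\[{].*[]}])\s*```", sample, re.DOTALL | re.IGNORECASE)
print([json.loads(b, strict=False) for b in blocks])
# [{'action': 'web_search', 'action_input': 'dify'}]
```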

View File

@ -70,11 +70,20 @@ class AgentStrategyIdentity(ToolIdentity):
pass
class AgentFeature(enum.StrEnum):
"""
Agent Feature, used to describe the features of the agent strategy.
"""
HISTORY_MESSAGES = "history-messages"
class AgentStrategyEntity(BaseModel):
identity: AgentStrategyIdentity
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The description of the agent strategy")
output_schema: Optional[dict] = None
features: Optional[list[AgentFeature]] = None
# pydantic configs
model_config = ConfigDict(protected_namespaces=())

View File

@ -146,6 +146,7 @@ class BasicProviderConfig(BaseModel):
BOOLEAN = CommonParameterType.BOOLEAN.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
@classmethod
def value_of(cls, value: str) -> "ProviderConfig.Type":

View File

@ -213,9 +213,24 @@ class LangFuseDataTrace(BaseTraceInstance):
if process_data and process_data.get("model_mode") == "chat":
total_token = metadata.get("total_tokens", 0)
prompt_tokens = 0
completion_tokens = 0
try:
if outputs.get("usage"):
prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
else:
prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
# add generation
generation_usage = GenerationUsage(
input=prompt_tokens,
output=completion_tokens,
total=total_token,
unit=UnitEnum.TOKENS,
)
node_generation_data = LangfuseGeneration(

View File

@ -199,6 +199,7 @@ class LangSmithDataTrace(BaseTraceInstance):
)
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
if process_data and process_data.get("model_mode") == "chat":
run_type = LangSmithRunType.llm
metadata.update(
@ -212,9 +213,23 @@ class LangSmithDataTrace(BaseTraceInstance):
else:
run_type = LangSmithRunType.tool
prompt_tokens = 0
completion_tokens = 0
try:
if outputs.get("usage"):
prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
else:
prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
node_dotted_order = generate_dotted_order(node_execution_id, created_at, workflow_dotted_order)
langsmith_run = LangSmithRunModel(
total_tokens=node_total_tokens,
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
name=node_type,
inputs=inputs,
run_type=run_type,

View File

@ -27,9 +27,26 @@ class CleanProcessor:
pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
text = re.sub(pattern, "", text)
# Remove URL
pattern = r"https?://[^\s]+"
text = re.sub(pattern, "", text)
# Remove URL but keep Markdown image URLs
# First, temporarily replace Markdown image URLs with a placeholder
markdown_image_pattern = r"!\[.*?\]\((https?://[^\s)]+)\)"
placeholders: list[str] = []
def replace_with_placeholder(match, placeholders=placeholders):
url = match.group(1)
placeholder = f"__MARKDOWN_IMAGE_URL_{len(placeholders)}__"
placeholders.append(url)
return f"![image]({placeholder})"
text = re.sub(markdown_image_pattern, replace_with_placeholder, text)
# Now remove all remaining URLs
url_pattern = r"https?://[^\s)]+"
text = re.sub(url_pattern, "", text)
# Finally, restore the Markdown image URLs
for i, url in enumerate(placeholders):
text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url)
return text
def filter_string(self, text):
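
A self-contained demo of the placeholder technique introduced above: Markdown image URLs survive the URL-stripping pass while bare URLs are removed (the sample input is invented).

```python
import re

def strip_urls_keep_images(text: str) -> str:
    placeholders: list[str] = []

    def keep_image(match: re.Match) -> str:
        placeholders.append(match.group(1))
        return f"![image](__MARKDOWN_IMAGE_URL_{len(placeholders) - 1}__)"

    text = re.sub(r"!\[.*?\]\((https?://[^\s)]+)\)", keep_image, text)  # protect image URLs
    text = re.sub(r"https?://[^\s)]+", "", text)                        # drop remaining URLs
    for i, url in enumerate(placeholders):                              # restore image URLs
        text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url)
    return text

print(strip_urls_keep_images("See https://example.com and ![logo](https://cdn.example.com/a.png)"))
# See  and ![image](https://cdn.example.com/a.png)
```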

View File

@ -1,10 +1,13 @@
import copy
import json
import logging
import time
from typing import Any, Optional
from opensearchpy import OpenSearch
from opensearchpy import OpenSearch, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
from tenacity import retry, stop_after_attempt, wait_exponential
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@ -77,31 +80,74 @@ class LindormVectorStore(BaseVector):
def refresh(self):
self._client.indices.refresh(index=self._collection_name)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
actions = []
def add_texts(
self,
documents: list[Document],
embeddings: list[list[float]],
batch_size: int = 64,
timeout: int = 60,
**kwargs,
):
logger.info(f"Total documents to add: {len(documents)}")
uuids = self._get_uuids(documents)
for i in range(len(documents)):
action_header = {
"index": {
"_index": self.collection_name.lower(),
"_id": uuids[i],
total_docs = len(documents)
num_batches = (total_docs + batch_size - 1) // batch_size
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
)
def _bulk_with_retry(actions):
try:
response = self._client.bulk(actions, timeout=timeout)
if response["errors"]:
error_items = [item for item in response["items"] if "error" in item["index"]]
error_msg = f"Bulk indexing had {len(error_items)} errors"
logger.exception(error_msg)
raise Exception(error_msg)
return response
except Exception:
logger.exception("Bulk indexing error")
raise
for batch_num in range(num_batches):
start_idx = batch_num * batch_size
end_idx = min((batch_num + 1) * batch_size, total_docs)
actions = []
for i in range(start_idx, end_idx):
action_header = {
"index": {
"_index": self.collection_name.lower(),
"_id": uuids[i],
}
}
}
action_values: dict[str, Any] = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
if self._routing_field is not None:
action_values[self._routing_field] = self._routing
actions.append(action_header)
actions.append(action_values)
response = self._client.bulk(actions)
if response["errors"]:
for item in response["items"]:
print(f"{item['index']['status']}: {item['index']['error']['type']}")
action_values: dict[str, Any] = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
if self._routing_field is not None:
action_values[self._routing_field] = self._routing
actions.append(action_header)
actions.append(action_values)
# logger.info(f"Processing batch {batch_num + 1}/{num_batches} (documents {start_idx + 1} to {end_idx})")
try:
_bulk_with_retry(actions)
# logger.info(f"Successfully processed batch {batch_num + 1}")
# simple latency to avoid too many requests in a short time
if batch_num < num_batches - 1:
time.sleep(0.5)
except Exception:
logger.exception(f"Failed to process batch {batch_num + 1}")
raise
def get_ids_by_metadata_field(self, key: str, value: str):
query: dict[str, Any] = {
@ -121,19 +167,51 @@ class LindormVectorStore(BaseVector):
self.delete_by_ids(ids)
def delete_by_ids(self, ids: list[str]) -> None:
params = {}
if self._using_ugc:
params["routing"] = self._routing
"""Delete documents by their IDs in batch.
Args:
ids: List of document IDs to delete
"""
if not ids:
return
params = {"routing": self._routing} if self._using_ugc else {}
# 1. First check if collection exists
if not self._client.indices.exists(index=self._collection_name):
logger.warning(f"Collection {self._collection_name} does not exist")
return
# 2. Batch process deletions
actions = []
for id in ids:
if self._client.exists(index=self._collection_name, id=id, params=params):
params = {}
if self._using_ugc:
params["routing"] = self._routing
self._client.delete(index=self._collection_name, id=id, params=params)
self.refresh()
actions.append(
{
"_op_type": "delete",
"_index": self._collection_name,
"_id": id,
**params, # Include routing if using UGC
}
)
else:
logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
# 3. Perform bulk deletion if there are valid documents to delete
if actions:
try:
helpers.bulk(self._client, actions)
except BulkIndexError as e:
for error in e.errors:
delete_error = error.get("delete", {})
status = delete_error.get("status")
doc_id = delete_error.get("_id")
if status == 404:
logger.warning(f"Document not found for deletion: {doc_id}")
else:
logger.exception(f"Error deleting document: {error}")
def delete(self) -> None:
if self._using_ugc:
routing_filter_query = {
@ -169,7 +247,7 @@ class LindormVectorStore(BaseVector):
document_ids_filter = kwargs.get("document_ids_filter")
filters = []
if document_ids_filter:
filters.append({"terms": {"metadata.document_id": document_ids_filter}})
filters.append({"terms": {"metadata.document_id.keyword": document_ids_filter}})
query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs)
try:
@ -212,7 +290,7 @@ class LindormVectorStore(BaseVector):
filters = kwargs.get("filter", [])
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
filters.append({"terms": {"metadata.document_id": document_ids_filter}})
filters.append({"terms": {"metadata.document_id.keyword": document_ids_filter}})
routing = self._routing
full_text_query = default_text_search_query(
query_text=query,
@ -226,6 +304,7 @@ class LindormVectorStore(BaseVector):
routing=routing,
routing_field=self._routing_field,
)
response = self._client.search(index=self._collection_name, body=full_text_query)
docs = []
for hit in response["hits"]["hits"]:
@ -435,7 +514,7 @@ def default_vector_search_query(
**kwargs,
) -> dict:
if filters is not None:
filter_type = "post_filter" if filter_type is None else filter_type
filter_type = "pre_filter" if filter_type is None else filter_type
if not isinstance(filters, list):
raise RuntimeError(f"unexpected filter with {type(filters)}")
final_ext: dict[str, Any] = {"lvector": {}}

View File

@ -1,12 +1,14 @@
import json
import logging
import math
from typing import Any, Optional
from pydantic import BaseModel
from tcvdb_text.encoder import BM25Encoder # type: ignore
from tcvectordb import RPCVectorDBClient, VectorDBException # type: ignore
from tcvectordb.model import document, enum # type: ignore
from tcvectordb.model import index as vdb_index # type: ignore
from tcvectordb.model.document import Filter # type: ignore
from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, WeightedRerank # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
@ -17,6 +19,8 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class TencentConfig(BaseModel):
url: str
@ -25,10 +29,11 @@ class TencentConfig(BaseModel):
username: Optional[str]
database: Optional[str]
index_type: str = "HNSW"
metric_type: str = "L2"
metric_type: str = "IP"
shard: int = 1
replicas: int = 2
max_upsert_batch_size: int = 128
enable_hybrid_search: bool = False # Flag to enable hybrid search
def to_tencent_params(self):
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
@ -44,6 +49,29 @@ class TencentVector(BaseVector):
super().__init__(collection_name)
self._client_config = config
self._client = RPCVectorDBClient(**self._client_config.to_tencent_params())
self._enable_hybrid_search = False
self._dimension = 1024
self._load_collection()
self._bm25 = BM25Encoder.default("zh")
def _load_collection(self):
"""
Check if the collection supports hybrid search.
"""
if self._client_config.enable_hybrid_search:
self._enable_hybrid_search = True
if self._has_collection():
coll = self._client.describe_collection(
database_name=self._client_config.database, collection_name=self.collection_name
)
has_hybrid_search = False
for idx in coll.indexes:
if idx.name == "sparse_vector":
has_hybrid_search = True
elif idx.name == "vector":
self._dimension = idx.dimension
if not has_hybrid_search:
self._enable_hybrid_search = False
def _init_database(self):
return self._client.create_database_if_not_exists(database_name=self._client_config.database)
@ -62,6 +90,7 @@ class TencentVector(BaseVector):
)
def _create_collection(self, dimension: int) -> None:
self._dimension = dimension
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
@ -84,18 +113,25 @@ class TencentVector(BaseVector):
if metric_type is None:
raise ValueError("unsupported metric_type")
params = vdb_index.HNSWParams(m=16, efconstruction=200)
index = vdb_index.Index(
vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
vdb_index.VectorIndex(
self.field_vector,
dimension,
index_type,
metric_type,
params,
),
vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER),
vdb_index.FilterIndex(self.field_metadata, enum.FieldType.Json, enum.IndexType.FILTER),
index_id = vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY)
index_vector = vdb_index.VectorIndex(
self.field_vector,
dimension,
index_type,
metric_type,
params,
)
index_text = vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER)
index_metadate = vdb_index.FilterIndex(self.field_metadata, enum.FieldType.Json, enum.IndexType.FILTER)
index_sparse_vector = vdb_index.SparseIndex(
name="sparse_vector",
field_type=enum.FieldType.SparseVector,
index_type=enum.IndexType.SPARSE_INVERTED,
metric_type=enum.MetricType.IP,
)
indexes = [index_id, index_vector, index_text, index_metadate]
if self._enable_hybrid_search:
indexes.append(index_sparse_vector)
try:
self._client.create_collection(
database_name=self._client_config.database,
@ -103,31 +139,25 @@ class TencentVector(BaseVector):
shard=self._client_config.shard,
replicas=self._client_config.replicas,
description="Collection for Dify",
index=index,
indexes=indexes,
)
except VectorDBException as e:
if "fieldType:json" not in e.message:
raise e
# vdb version not support json, use string
index = vdb_index.Index(
vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
vdb_index.VectorIndex(
self.field_vector,
dimension,
index_type,
metric_type,
params,
),
vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER),
vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER),
index_metadate = vdb_index.FilterIndex(
self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
)
indexes = [index_id, index_vector, index_text, index_metadate]
if self._enable_hybrid_search:
indexes.append(index_sparse_vector)
self._client.create_collection(
database_name=self._client_config.database,
collection_name=self._collection_name,
shard=self._client_config.shard,
replicas=self._client_config.replicas,
description="Collection for Dify",
index=index,
indexes=indexes,
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
@ -155,6 +185,8 @@ class TencentVector(BaseVector):
text=texts[i],
metadata=metadata,
)
if self._enable_hybrid_search:
doc.__dict__["sparse_vector"] = self._bm25.encode_texts(texts[i])
docs.append(doc)
self._client.upsert(
database_name=self._client_config.database,
@ -204,7 +236,32 @@ class TencentVector(BaseVector):
return self._get_search_res(res, score_threshold)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []
if not self._enable_hybrid_search:
return []
res = self._client.hybrid_search(
database_name=self._client_config.database,
collection_name=self.collection_name,
ann=[
AnnSearch(
field_name="vector",
data=[0.0] * self._dimension,
)
],
match=[
KeywordSearch(
field_name="sparse_vector",
data=self._bm25.encode_queries(query),
),
],
rerank=WeightedRerank(
field_list=["vector", "sparse_vector"],
weight=[0, 1],
),
retrieve_vector=False,
limit=kwargs.get("top_k", 4),
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold)
def _get_search_res(self, res: list | None, score_threshold: float) -> list[Document]:
docs: list[Document] = []
@ -213,7 +270,7 @@ class TencentVector(BaseVector):
for result in res[0]:
meta = result.get(self.field_metadata)
score = 1 - result.get("score", 0.0)
score = result.get("score", 0.0)
if score > score_threshold:
meta["score"] = score
doc = Document(page_content=result.get(self.field_text), metadata=meta)
@ -245,5 +302,6 @@ class TencentVectorFactory(AbstractVectorFactory):
database=dify_config.TENCENT_VECTOR_DB_DATABASE,
shard=dify_config.TENCENT_VECTOR_DB_SHARD,
replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS,
enable_hybrid_search=dify_config.TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH or False,
),
)

View File

@ -18,6 +18,7 @@ from core.rag.extractor.markdown_extractor import MarkdownExtractor
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.extractor.pdf_extractor import PdfExtractor
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor
from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
@ -25,6 +26,7 @@ from core.rag.extractor.unstructured.unstructured_msg_extractor import Unstructu
from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor
from core.rag.extractor.watercrawl.extractor import WaterCrawlWebExtractor
from core.rag.extractor.word_extractor import WordExtractor
from core.rag.models.document import Document
from extensions.ext_storage import storage
@ -104,7 +106,7 @@ class ExtractProcessor:
etl_type = dify_config.ETL_TYPE
extractor: Optional[BaseExtractor] = None
if etl_type == "Unstructured":
unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
unstructured_api_url = dify_config.UNSTRUCTURED_API_URL or ""
unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY or ""
if file_extension in {".xlsx", ".xls"}:
@ -121,6 +123,8 @@ class ExtractProcessor:
extractor = HtmlExtractor(file_path)
elif file_extension == ".docx":
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == ".doc":
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url, unstructured_api_key)
elif file_extension == ".csv":
extractor = CSVExtractor(file_path, autodetect_encoding=True)
elif file_extension == ".msg":
@ -180,6 +184,15 @@ class ExtractProcessor:
only_main_content=extract_setting.website_info.only_main_content,
)
return extractor.extract()
elif extract_setting.website_info.provider == "watercrawl":
extractor = WaterCrawlWebExtractor(
url=extract_setting.website_info.url,
job_id=extract_setting.website_info.job_id,
tenant_id=extract_setting.website_info.tenant_id,
mode=extract_setting.website_info.mode,
only_main_content=extract_setting.website_info.only_main_content,
)
return extractor.extract()
elif extract_setting.website_info.provider == "jinareader":
extractor = JinaReaderWebExtractor(
url=extract_setting.website_info.url,

View File

@ -10,14 +10,11 @@ logger = logging.getLogger(__name__)
class UnstructuredWordExtractor(BaseExtractor):
"""Loader that uses unstructured to load word documents."""
def __init__(
self,
file_path: str,
api_url: str,
):
def __init__(self, file_path: str, api_url: str, api_key: str = ""):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
self._api_key = api_key
def extract(self) -> list[Document]:
from unstructured.__version__ import __version__ as __unstructured_version__
@ -41,9 +38,10 @@ class UnstructuredWordExtractor(BaseExtractor):
)
if is_doc:
from unstructured.partition.doc import partition_doc
from unstructured.partition.api import partition_via_api
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
elements = partition_doc(filename=self._file_path)
else:
from unstructured.partition.docx import partition_docx

View File

@ -0,0 +1,161 @@
import json
from collections.abc import Generator
from typing import Union
from urllib.parse import urljoin
import requests
from requests import Response
class BaseAPIClient:
def __init__(self, api_key, base_url):
self.api_key = api_key
self.base_url = base_url
self.session = self.init_session()
def init_session(self):
session = requests.Session()
session.headers.update({"X-API-Key": self.api_key})
session.headers.update({"Content-Type": "application/json"})
session.headers.update({"Accept": "application/json"})
session.headers.update({"User-Agent": "WaterCrawl-Plugin"})
session.headers.update({"Accept-Language": "en-US"})
return session
def _get(self, endpoint: str, query_params: dict | None = None, **kwargs):
return self.session.get(urljoin(self.base_url, endpoint), params=query_params, **kwargs)
def _post(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
return self.session.post(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs)
def _put(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
return self.session.put(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs)
def _delete(self, endpoint: str, query_params: dict | None = None, **kwargs):
return self.session.delete(urljoin(self.base_url, endpoint), params=query_params, **kwargs)
def _patch(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
return self.session.patch(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs)
class WaterCrawlAPIClient(BaseAPIClient):
def __init__(self, api_key, base_url: str | None = "https://app.watercrawl.dev/"):
super().__init__(api_key, base_url)
def process_eventstream(self, response: Response, download: bool = False) -> Generator:
for line in response.iter_lines():
line = line.decode("utf-8")
if line.startswith("data:"):
line = line[5:].strip()
data = json.loads(line)
if data["type"] == "result" and download:
data["data"] = self.download_result(data["data"])
yield data
def process_response(self, response: Response) -> dict | bytes | list | None | Generator:
response.raise_for_status()
if response.status_code == 204:
return None
if response.headers.get("Content-Type") == "application/json":
return response.json() or {}
if response.headers.get("Content-Type") == "application/octet-stream":
return response.content
if response.headers.get("Content-Type") == "text/event-stream":
return self.process_eventstream(response)
raise Exception(f"Unknown response type: {response.headers.get('Content-Type')}")
def get_crawl_requests_list(self, page: int | None = None, page_size: int | None = None):
query_params = {"page": page or 1, "page_size": page_size or 10}
return self.process_response(
self._get(
"/api/v1/core/crawl-requests/",
query_params=query_params,
)
)
def get_crawl_request(self, item_id: str):
return self.process_response(
self._get(
f"/api/v1/core/crawl-requests/{item_id}/",
)
)
def create_crawl_request(
self,
url: Union[list, str] | None = None,
spider_options: dict | None = None,
page_options: dict | None = None,
plugin_options: dict | None = None,
):
data = {
# 'urls': url if isinstance(url, list) else [url],
"url": url,
"options": {
"spider_options": spider_options or {},
"page_options": page_options or {},
"plugin_options": plugin_options or {},
},
}
return self.process_response(
self._post(
"/api/v1/core/crawl-requests/",
data=data,
)
)
def stop_crawl_request(self, item_id: str):
return self.process_response(
self._delete(
f"/api/v1/core/crawl-requests/{item_id}/",
)
)
def download_crawl_request(self, item_id: str):
return self.process_response(
self._get(
f"/api/v1/core/crawl-requests/{item_id}/download/",
)
)
def monitor_crawl_request(self, item_id: str, prefetched=False) -> Generator:
query_params = {"prefetched": str(prefetched).lower()}
generator = self.process_response(
self._get(f"/api/v1/core/crawl-requests/{item_id}/status/", stream=True, query_params=query_params),
)
if not isinstance(generator, Generator):
raise ValueError("Generator expected")
yield from generator
def get_crawl_request_results(
self, item_id: str, page: int = 1, page_size: int = 25, query_params: dict | None = None
):
query_params = query_params or {}
query_params.update({"page": page or 1, "page_size": page_size or 25})
return self.process_response(
self._get(f"/api/v1/core/crawl-requests/{item_id}/results/", query_params=query_params)
)
def scrape_url(
self,
url: str,
page_options: dict | None = None,
plugin_options: dict | None = None,
sync: bool = True,
prefetched: bool = True,
):
response_result = self.create_crawl_request(url=url, page_options=page_options, plugin_options=plugin_options)
if not sync:
return response_result
for event_data in self.monitor_crawl_request(response_result["uuid"], prefetched):
if event_data["type"] == "result":
return event_data["data"]
def download_result(self, result_object: dict):
response = requests.get(result_object["result"])
response.raise_for_status()
result_object["result"] = response.json()
return result_object
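For context, a minimal usage sketch of the new WaterCrawlAPIClient added above (not part of the diff; the API key is a placeholder and the result keys are taken from the client code itself):
from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient

# Placeholder credentials; a real key comes from the tenant's data-source config.
client = WaterCrawlAPIClient(api_key="wc-xxxx", base_url="https://app.watercrawl.dev/")

# Synchronous scrape: create a crawl request and block until the prefetched result arrives.
scraped = client.scrape_url(url="https://example.com", sync=True, prefetched=True)

# Asynchronous crawl: create a request, then page through its results later.
request = client.create_crawl_request(url="https://example.com", spider_options={"page_limit": 5})
page = client.get_crawl_request_results(request["uuid"], page=1, page_size=25)
for item in page["results"]:
    print(item.get("url"))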

View File

@ -0,0 +1,57 @@
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from services.website_service import WebsiteService
class WaterCrawlWebExtractor(BaseExtractor):
"""
Crawl and scrape websites and return content in clean llm-ready markdown.
Args:
url: The URL to scrape.
job_id: The ID of the WaterCrawl crawl request that produced the page, if any.
tenant_id: The tenant that owns the crawl credentials.
mode: The mode of operation. Defaults to 'scrape'. Options are 'crawl', 'scrape' and 'crawl_return_urls'.
only_main_content: Only return the main content of the page excluding headers, navs, footers, etc.
"""
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True):
"""Initialize with url, api_key, base_url and mode."""
self._url = url
self.job_id = job_id
self.tenant_id = tenant_id
self.mode = mode
self.only_main_content = only_main_content
def extract(self) -> list[Document]:
"""Extract content from the URL."""
documents = []
if self.mode == "crawl":
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "watercrawl", self._url, self.tenant_id)
if crawl_data is None:
return []
document = Document(
page_content=crawl_data.get("markdown", ""),
metadata={
"source_url": crawl_data.get("source_url"),
"description": crawl_data.get("description"),
"title": crawl_data.get("title"),
},
)
documents.append(document)
elif self.mode == "scrape":
scrape_data = WebsiteService.get_scrape_url_data(
"watercrawl", self._url, self.tenant_id, self.only_main_content
)
document = Document(
page_content=scrape_data.get("markdown", ""),
metadata={
"source_url": scrape_data.get("source_url"),
"description": scrape_data.get("description"),
"title": scrape_data.get("title"),
},
)
documents.append(document)
return documents
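A hedged sketch of how the dataset pipeline is expected to drive this extractor (the ids are placeholders):
from core.rag.extractor.watercrawl.extractor import WaterCrawlWebExtractor

extractor = WaterCrawlWebExtractor(
    url="https://example.com/docs",
    job_id="crawl-request-uuid",  # the WaterCrawl job id returned by the crawl step
    tenant_id="tenant-uuid",
    mode="crawl",                 # or "scrape" to fetch the single URL directly
    only_main_content=True,
)
documents = extractor.extract()   # list[Document] carrying source_url/description/title metadata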

View File

@ -0,0 +1,117 @@
from collections.abc import Generator
from datetime import datetime
from typing import Any
from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient
class WaterCrawlProvider:
def __init__(self, api_key, base_url: str | None = None):
self.client = WaterCrawlAPIClient(api_key, base_url)
def crawl_url(self, url, options: dict | Any = None) -> dict:
options = options or {}
spider_options = {
"max_depth": 1,
"page_limit": 1,
"allowed_domains": [],
"exclude_paths": [],
"include_paths": [],
}
if options.get("crawl_sub_pages", True):
spider_options["page_limit"] = options.get("limit", 1)
spider_options["max_depth"] = options.get("depth", 1)
spider_options["include_paths"] = options.get("includes", "").split(",") if options.get("includes") else []
spider_options["exclude_paths"] = options.get("excludes", "").split(",") if options.get("excludes") else []
wait_time = options.get("wait_time", 1000)
page_options = {
"exclude_tags": options.get("exclude_tags", "").split(",") if options.get("exclude_tags") else [],
"include_tags": options.get("include_tags", "").split(",") if options.get("include_tags") else [],
"wait_time": max(1000, wait_time), # minimum wait time is 1 second
"include_html": False,
"only_main_content": options.get("only_main_content", True),
"include_links": False,
"timeout": 15000,
"accept_cookies_selector": "#cookies-accept",
"locale": "en-US",
"actions": [],
}
result = self.client.create_crawl_request(url=url, spider_options=spider_options, page_options=page_options)
return {"status": "active", "job_id": result.get("uuid")}
def get_crawl_status(self, crawl_request_id) -> dict:
response = self.client.get_crawl_request(crawl_request_id)
data = []
if response["status"] in ["new", "running"]:
status = "active"
else:
status = "completed"
data = list(self._get_results(crawl_request_id))
time_str = response.get("duration")
time_consuming: float = 0
if time_str:
time_obj = datetime.strptime(time_str, "%H:%M:%S.%f")
time_consuming = (
time_obj.hour * 3600 + time_obj.minute * 60 + time_obj.second + time_obj.microsecond / 1_000_000
)
return {
"status": status,
"job_id": response.get("uuid"),
"total": response.get("options", {}).get("spider_options", {}).get("page_limit", 1),
"current": response.get("number_of_documents", 0),
"data": data,
"time_consuming": time_consuming,
}
def get_crawl_url_data(self, job_id, url) -> dict | None:
if not job_id:
return self.scrape_url(url)
for result in self._get_results(
job_id,
{
# filter by url
"url": url
},
):
return result
return None
def scrape_url(self, url: str) -> dict:
response = self.client.scrape_url(url=url, sync=True, prefetched=True)
return self._structure_data(response)
def _structure_data(self, result_object: dict) -> dict:
if isinstance(result_object.get("result", {}), str):
raise ValueError("Invalid result object. Expected a dictionary.")
metadata = result_object.get("result", {}).get("metadata", {})
return {
"title": metadata.get("og:title") or metadata.get("title"),
"description": metadata.get("description"),
"source_url": result_object.get("url"),
"markdown": result_object.get("result", {}).get("markdown"),
}
def _get_results(self, crawl_request_id: str, query_params: dict | None = None) -> Generator[dict, None, None]:
page = 0
page_size = 100
query_params = query_params or {}
query_params.update({"prefetched": "true"})
while True:
page += 1
response = self.client.get_crawl_request_results(crawl_request_id, page, page_size, query_params)
if not response["results"]:
break
for result in response["results"]:
yield self._structure_data(result)
if response["next"] is None:
break
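An illustrative call into the provider, showing which option keys the mapping above consumes (the values are examples, not defaults from the diff):
from core.rag.extractor.watercrawl.provider import WaterCrawlProvider

provider = WaterCrawlProvider(api_key="wc-xxxx", base_url="https://app.watercrawl.dev")

job = provider.crawl_url(
    "https://example.com",
    options={
        "crawl_sub_pages": True,   # enables the page_limit/max_depth overrides below
        "limit": 10,               # -> spider_options["page_limit"]
        "depth": 2,                # -> spider_options["max_depth"]
        "includes": "/docs/*",     # comma-separated include paths
        "excludes": "/blog/*",
        "only_main_content": True,
        "wait_time": 1500,         # milliseconds; clamped to at least 1000
    },
)
status = provider.get_crawl_status(job["job_id"])
# status["data"] contains the structured pages once status["status"] == "completed"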

View File

@ -85,7 +85,7 @@ class WordExtractor(BaseExtractor):
if "image" in rel.target_ref:
image_count += 1
if rel.is_external:
url = rel.reltype
url = rel.target_ref
response = ssrf_proxy.get(url)
if response.status_code == 200:
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])

View File

@ -30,6 +30,7 @@ class NodeRunMetadataKey(StrEnum):
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
class NodeRunResult(BaseModel):

View File

@ -1,15 +1,18 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from typing import Any, Optional, cast
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.plugin.manager.plugin import PluginInstallationManager
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
from core.tools.tool_manager import ToolManager
from core.variables.segments import StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
@ -19,7 +22,9 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories.agent_factory import get_plugin_agent_strategy
from models.model import Conversation
from models.workflow import WorkflowNodeExecutionStatus
@ -233,17 +238,20 @@ class AgentNode(ToolNode):
value = tool_value
if parameter.type == "model-selector":
value = cast(dict[str, Any], value)
model_instance = ModelManager().get_model_instance(
tenant_id=self.tenant_id,
provider=value.get("provider", ""),
model_type=ModelType(value.get("model_type", "")),
model=value.get("model", ""),
)
models = model_instance.model_type_instance.plugin_model_provider.declaration.models
finded_model = next((model for model in models if model.model == value.get("model", "")), None)
value["entity"] = finded_model.model_dump(mode="json") if finded_model else None
model_instance, model_schema = self._fetch_model(value)
# memory config
history_prompt_messages = []
if node_data.memory:
memory = self._fetch_memory(model_instance)
if memory:
prompt_messages = memory.get_history_prompt_messages(
message_limit=node_data.memory.window.size if node_data.memory.window.size else None
)
history_prompt_messages = [
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
]
value["history_prompt_messages"] = history_prompt_messages
value["entity"] = model_schema.model_dump(mode="json") if model_schema else None
result[parameter_name] = value
return result
@ -297,3 +305,46 @@ class AgentNode(ToolNode):
except StopIteration:
icon = None
return icon
def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
# get conversation id
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID.value]
)
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
# get conversation
conversation = (
db.session.query(Conversation)
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
.first()
)
if not conversation:
return None
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
)
model_name = value.get("model", "")
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM, model=model_name
)
provider_name = provider_model_bundle.configuration.provider.provider
model_type_instance = provider_model_bundle.model_type_instance
model_instance = ModelManager().get_model_instance(
tenant_id=self.tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),
model=model_name,
)
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_instance, model_schema
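For reference, a rough sketch (keys taken from the code above, example values only) of the model-selector value after the node enriches it:
# Illustrative shape only; "entity" is model_schema.model_dump(mode="json") and may be None,
# and "history_prompt_messages" is filled only when node_data.memory is configured.
value = {
    "provider": "openai",   # example
    "model_type": "llm",
    "model": "gpt-4o",      # example
    "entity": None,
    "history_prompt_messages": [],
}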

View File

@ -3,6 +3,7 @@ from typing import Any, Literal, Union
from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from core.workflow.nodes.base.entities import BaseNodeData
@ -11,6 +12,7 @@ class AgentNodeData(BaseNodeData):
agent_strategy_provider_name: str # redundancy
agent_strategy_name: str
agent_strategy_label: str # redundancy
memory: MemoryConfig | None = None
class AgentInput(BaseModel):
value: Union[list[str], list[ToolSelector], Any]

View File

@ -6,6 +6,7 @@ from core.helper.code_executor.code_executor import CodeExecutionError, CodeExec
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.code.entities import CodeNodeData
@ -49,7 +50,10 @@ class CodeNode(BaseNode[CodeNodeData]):
for variable_selector in self.node_data.variables:
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
variables[variable_name] = variable.to_object() if variable else None
if isinstance(variable, ArrayFileSegment):
variables[variable_name] = [v.to_dict() for v in variable.value] if variable.value else None
else:
variables[variable_name] = variable.to_object() if variable else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(

View File

@ -17,6 +17,7 @@ class NodeType(StrEnum):
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"
LOOP_START = "loop-start"
LOOP_END = "loop-end"
ITERATION = "iteration"
ITERATION_START = "iteration-start" # Fake start node for iteration.
PARAMETER_EXTRACTOR = "parameter-extractor"

View File

@ -8,7 +8,7 @@ from core.workflow.utils.condition.entities import Condition
class IfElseNodeData(BaseNodeData):
"""
Answer Node Data.
If Else Node Data.
"""
class Case(BaseModel):

View File

@ -1,5 +1,6 @@
from .entities import LoopNodeData
from .loop_end_node import LoopEndNode
from .loop_node import LoopNode
from .loop_start_node import LoopStartNode
__all__ = ["LoopNode", "LoopNodeData", "LoopStartNode"]
__all__ = ["LoopEndNode", "LoopNode", "LoopNodeData", "LoopStartNode"]

View File

@ -1,11 +1,23 @@
from collections.abc import Mapping
from typing import Any, Literal, Optional
from pydantic import Field
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
from core.workflow.utils.condition.entities import Condition
class LoopVariableData(BaseModel):
"""
Loop Variable Data.
"""
label: str
var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
value_type: Literal["variable", "constant"]
value: Optional[Any | list[str]] = None
class LoopNodeData(BaseLoopNodeData):
"""
Loop Node Data.
@ -14,6 +26,8 @@ class LoopNodeData(BaseLoopNodeData):
loop_count: int # Maximum number of loops
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list)
outputs: Optional[Mapping[str, Any]] = None
class LoopStartNodeData(BaseNodeData):
@ -24,6 +38,14 @@ class LoopStartNodeData(BaseNodeData):
pass
class LoopEndNodeData(BaseNodeData):
"""
Loop End Node Data.
"""
pass
class LoopState(BaseLoopState):
"""
Loop State.

View File

@ -0,0 +1,20 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.loop.entities import LoopEndNodeData
from models.workflow import WorkflowNodeExecutionStatus
class LoopEndNode(BaseNode[LoopEndNodeData]):
"""
Loop End Node.
"""
_node_data_cls = LoopEndNodeData
_node_type = NodeType.LOOP_END
def _run(self) -> NodeRunResult:
"""
Run the node.
"""
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)

View File

@ -1,10 +1,20 @@
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, cast
from typing import TYPE_CHECKING, Any, Literal, cast
from configs import dify_config
from core.variables import IntegerSegment
from core.variables import (
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
IntegerSegment,
ObjectSegment,
Segment,
SegmentType,
StringSegment,
)
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
@ -29,6 +39,10 @@ from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.utils.condition.processor import ConditionProcessor
from models.workflow import WorkflowNodeExecutionStatus
if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
@ -61,6 +75,28 @@ class LoopNode(BaseNode[LoopNodeData]):
variable_pool = self.graph_runtime_state.variable_pool
variable_pool.add([self.node_id, "index"], 0)
# Initialize loop variables
loop_variable_selectors = {}
if self.node_data.loop_variables:
for loop_variable in self.node_data.loop_variables:
value_processor = {
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var=loop_variable: variable_pool.get(var.value),
}
if loop_variable.value_type not in value_processor:
raise ValueError(
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
)
processed_segment = value_processor[loop_variable.value_type]()
if not processed_segment:
raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
variable_selector = [self.node_id, loop_variable.label]
variable_pool.add(variable_selector, processed_segment.value)
loop_variable_selectors[loop_variable.label] = variable_selector
inputs[loop_variable.label] = processed_segment.value
from core.workflow.graph_engine.graph_engine import GraphEngine
graph_engine = GraphEngine(
@ -95,135 +131,51 @@ class LoopNode(BaseNode[LoopNodeData]):
predecessor_node_id=self.previous_node_id,
)
yield LoopRunNextEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
index=0,
pre_loop_output=None,
)
# yield LoopRunNextEvent(
# loop_id=self.id,
# loop_node_id=self.node_id,
# loop_node_type=self.node_type,
# loop_node_data=self.node_data,
# index=0,
# pre_loop_output=None,
# )
loop_duration_map = {}
single_loop_variable_map = {} # single loop variable output
try:
check_break_result = False
for i in range(loop_count):
# Run workflow
rst = graph_engine.run()
current_index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(current_index_variable, IntegerSegment):
raise ValueError(f"loop {self.node_id} current index not found")
current_index = current_index_variable.value
loop_start_time = datetime.now(UTC).replace(tzinfo=None)
# run single loop
loop_result = yield from self._run_single_loop(
graph_engine=graph_engine,
loop_graph=loop_graph,
variable_pool=variable_pool,
loop_variable_selectors=loop_variable_selectors,
break_conditions=break_conditions,
logical_operator=logical_operator,
condition_processor=condition_processor,
current_index=i,
start_at=start_at,
inputs=inputs,
)
loop_end_time = datetime.now(UTC).replace(tzinfo=None)
check_break_result = False
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:
event.in_loop_id = self.node_id
if (
isinstance(event, BaseNodeEvent)
and event.node_type == NodeType.LOOP_START
and not isinstance(event, NodeRunStreamChunkEvent)
):
continue
if isinstance(event, NodeRunSucceededEvent):
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
# Check if all variables in break conditions exist
exists_variable = False
for condition in break_conditions:
if not self.graph_runtime_state.variable_pool.get(condition.variable_selector):
exists_variable = False
break
else:
exists_variable = True
if exists_variable:
input_conditions, group_result, check_break_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)
if check_break_result:
break
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# Loop run failed
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
steps=i,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "error",
},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
},
)
)
return
elif isinstance(event, NodeRunFailedEvent):
# Loop run failed
yield event
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
steps=i,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "error",
},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
},
)
)
return
single_loop_variable = {}
for key, selector in loop_variable_selectors.items():
item = variable_pool.get(selector)
if item:
single_loop_variable[key] = item.value
else:
yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
single_loop_variable[key] = None
# Remove all nodes outputs from variable pool
for node_id in loop_graph.node_ids:
variable_pool.remove([node_id])
loop_duration_map[str(i)] = (loop_end_time - loop_start_time).total_seconds()
single_loop_variable_map[str(i)] = single_loop_variable
check_break_result = loop_result.get("check_break_result", False)
if check_break_result:
break
# Move to next loop
next_index = current_index + 1
variable_pool.add([self.node_id, "index"], next_index)
yield LoopRunNextEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
index=next_index,
pre_loop_output=None,
)
# Loop completed successfully
yield LoopRunSucceededEvent(
loop_id=self.id,
@ -232,17 +184,26 @@ class LoopNode(BaseNode[LoopNodeData]):
loop_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs=self.node_data.outputs,
steps=loop_count,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "loop_break" if check_break_result else "loop_completed",
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
outputs=self.node_data.outputs,
inputs=inputs,
)
)
@ -260,6 +221,8 @@ class LoopNode(BaseNode[LoopNodeData]):
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "error",
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
error=str(e),
)
@ -268,7 +231,11 @@ class LoopNode(BaseNode[LoopNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
)
)
@ -276,6 +243,159 @@ class LoopNode(BaseNode[LoopNodeData]):
# Clean up
variable_pool.remove([self.node_id, "index"])
def _run_single_loop(
self,
*,
graph_engine: "GraphEngine",
loop_graph: Graph,
variable_pool: "VariablePool",
loop_variable_selectors: dict,
break_conditions: list,
logical_operator: Literal["and", "or"],
condition_processor: ConditionProcessor,
current_index: int,
start_at: datetime,
inputs: dict,
) -> Generator[NodeEvent | InNodeEvent, None, dict]:
"""Run a single loop iteration.
Returns:
dict: {'check_break_result': bool}
"""
# Run workflow
rst = graph_engine.run()
current_index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(current_index_variable, IntegerSegment):
raise ValueError(f"loop {self.node_id} current index not found")
current_index = current_index_variable.value
check_break_result = False
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:
event.in_loop_id = self.node_id
if (
isinstance(event, BaseNodeEvent)
and event.node_type == NodeType.LOOP_START
and not isinstance(event, NodeRunStreamChunkEvent)
):
continue
if (
isinstance(event, NodeRunSucceededEvent)
and event.node_type == NodeType.LOOP_END
and not isinstance(event, NodeRunStreamChunkEvent)
):
check_break_result = True
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
break
if isinstance(event, NodeRunSucceededEvent):
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
# Check if all variables in break conditions exist
exists_variable = False
for condition in break_conditions:
if not self.graph_runtime_state.variable_pool.get(condition.variable_selector):
exists_variable = False
break
else:
exists_variable = True
if exists_variable:
input_conditions, group_result, check_break_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)
if check_break_result:
break
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# Loop run failed
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "error",
},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
)
)
return {"check_break_result": True}
elif isinstance(event, NodeRunFailedEvent):
# Loop run failed
yield event
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "error",
},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
)
)
return {"check_break_result": True}
else:
yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
# Remove all nodes outputs from variable pool
for node_id in loop_graph.node_ids:
variable_pool.remove([node_id])
_outputs = {}
for loop_variable_key, loop_variable_selector in loop_variable_selectors.items():
_loop_variable_segment = variable_pool.get(loop_variable_selector)
if _loop_variable_segment:
_outputs[loop_variable_key] = _loop_variable_segment.value
else:
_outputs[loop_variable_key] = None
_outputs["loop_round"] = current_index + 1
self.node_data.outputs = _outputs
if check_break_result:
return {"check_break_result": True}
# Move to next loop
next_index = current_index + 1
variable_pool.add([self.node_id, "index"], next_index)
yield LoopRunNextEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
index=next_index,
pre_loop_output=self.node_data.outputs,
)
return {"check_break_result": False}
def _handle_event_metadata(
self,
*,
@ -360,3 +480,25 @@ class LoopNode(BaseNode[LoopNodeData]):
}
return variable_mapping
@staticmethod
def _get_segment_for_constant(var_type: str, value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""
segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = {
"string": (StringSegment, SegmentType.STRING),
"number": (IntegerSegment, SegmentType.NUMBER),
"object": (ObjectSegment, SegmentType.OBJECT),
"array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING),
"array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER),
"array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT),
}
if var_type in ["array[string]", "array[number]", "array[object]"]:
if value:
value = json.loads(value)
else:
value = []
segment_info = segment_mapping.get(var_type)
if not segment_info:
raise ValueError(f"Invalid variable type: {var_type}")
segment_class, value_type = segment_info
return segment_class(value=value, value_type=value_type)
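A small illustration of how constant loop variables are coerced by the helper above (array constants arrive as JSON strings):
from core.workflow.nodes.loop import LoopNode

names = LoopNode._get_segment_for_constant("array[string]", '["alice", "bob"]')
count = LoopNode._get_segment_for_constant("number", 0)
# names.value == ["alice", "bob"]; count.value == 0; an unsupported var_type raises ValueError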

View File

@ -13,7 +13,7 @@ from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from core.workflow.nodes.list_operator import ListOperatorNode
from core.workflow.nodes.llm import LLMNode
from core.workflow.nodes.loop import LoopNode, LoopStartNode
from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from core.workflow.nodes.question_classifier import QuestionClassifierNode
from core.workflow.nodes.start import StartNode
@ -94,6 +94,10 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
LATEST_VERSION: LoopStartNode,
"1": LoopStartNode,
},
NodeType.LOOP_END: {
LATEST_VERSION: LoopEndNode,
"1": LoopEndNode,
},
NodeType.PARAMETER_EXTRACTOR: {
LATEST_VERSION: ParameterExtractorNode,
"1": ParameterExtractorNode,

View File

@ -4,7 +4,7 @@ from core.workflow.nodes.base import BaseNodeData
class TemplateTransformNodeData(BaseNodeData):
"""
Code Node Data.
Template Transform Node Data.
"""
variables: list[VariableSelector]

View File

@ -26,7 +26,7 @@ class AdvancedSettings(BaseModel):
class VariableAssignerNodeData(BaseNodeData):
"""
Knowledge retrieval Node Data.
Variable Assigner Node Data.
"""
type: str = "variable-assigner"

View File

@ -2,6 +2,7 @@ import json
from collections.abc import Sequence
from typing import Any, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
@ -123,13 +124,14 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise ConversationIDNotFoundError
if self.invoke_from != InvokeFrom.DEBUGGER:
raise ConversationIDNotFoundError
else:
conversation_id = conversation_id.value
common_helpers.update_conversation_variable(
conversation_id=cast(str, conversation_id),
variable=variable,
)
common_helpers.update_conversation_variable(
conversation_id=cast(str, conversation_id),
variable=variable,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,

View File

@ -134,8 +134,9 @@ def _build_from_local_file(
if row is None:
raise ValueError("Invalid upload file")
file_type = FileType(mapping.get("type", "custom"))
file_type = _standardize_file_type(file_type, extension="." + row.extension, mime_type=row.mime_type)
file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
if file_type.value != mapping.get("type", "custom"):
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
return File(
id=mapping.get("id"),
@ -173,10 +174,9 @@ def _build_from_remote_url(
if upload_file is None:
raise ValueError("Invalid upload file")
file_type = FileType(mapping.get("type", "custom"))
file_type = _standardize_file_type(
file_type, extension="." + upload_file.extension, mime_type=upload_file.mime_type
)
file_type = _standardize_file_type(extension="." + upload_file.extension, mime_type=upload_file.mime_type)
if file_type.value != mapping.get("type", "custom"):
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
return File(
id=mapping.get("id"),
@ -198,8 +198,9 @@ def _build_from_remote_url(
mime_type, filename, file_size = _get_remote_file_info(url)
extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin")
file_type = FileType(mapping.get("type", "custom"))
file_type = _standardize_file_type(file_type, extension=extension, mime_type=mime_type)
file_type = _standardize_file_type(extension=extension, mime_type=mime_type)
if file_type.value != mapping.get("type", "custom"):
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
return File(
id=mapping.get("id"),
@ -250,8 +251,8 @@ def _build_from_tool_file(
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
file_type = FileType(mapping.get("type", "custom"))
file_type = _standardize_file_type(file_type, extension=extension, mime_type=tool_file.mimetype)
file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
return File(
id=mapping.get("id"),
@ -302,12 +303,10 @@ def _is_file_valid_with_config(
return True
def _standardize_file_type(file_type: FileType, /, *, extension: str = "", mime_type: str = "") -> FileType:
def _standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType:
"""
If custom type, try to guess the file type by extension and mime_type.
Infer the possible actual type of the file based on the extension and mime_type
"""
if file_type != FileType.CUSTOM:
return FileType(file_type)
guessed_type = None
if extension:
guessed_type = _get_file_type_by_extension(extension)

api/poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
[[package]]
name = "aiofiles"
@ -501,7 +501,7 @@ description = "Timeout context manager for asyncio programs"
optional = false
python-versions = ">=3.8"
groups = ["main"]
markers = "python_full_version < \"3.11.3\""
markers = "python_version == \"3.11\" and python_full_version < \"3.11.3\""
files = [
{file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"},
{file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"},
@ -1199,7 +1199,7 @@ files = [
{file = "cffi-1.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662"},
{file = "cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824"},
]
markers = {storage = "platform_python_implementation != \"PyPy\"", vdb = "python_version < \"3.12\" or platform_python_implementation != \"PyPy\""}
markers = {storage = "platform_python_implementation != \"PyPy\"", vdb = "platform_python_implementation != \"PyPy\""}
[package.dependencies]
pycparser = "*"
@ -2947,6 +2947,7 @@ files = [
{file = "google_crc32c-1.7.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:6a40522958040051c755a173eb98c05ad4d64a6dd898888c3e5ccca2d1cbdcdc"},
{file = "google_crc32c-1.7.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f714fe5cdf5007d7064c57cf7471a99e0cbafda24ddfa829117fc3baafa424f7"},
{file = "google_crc32c-1.7.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f04e58dbe1bf0c9398e603a9be5aaa09e0ba7eb022a3293195d8749459a01069"},
{file = "google_crc32c-1.7.0-cp313-cp313-win_amd64.whl", hash = "sha256:e545b51ddf97f604d30114f7c23eecaf4c06cd6c023ff1ae0b80dcd99af32833"},
{file = "google_crc32c-1.7.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:364067b063664dd8d1fec75a3fe85edf05c46f688365269beccaf42ef5dfe889"},
{file = "google_crc32c-1.7.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1b0d6044799f6ac51d1cc2decb997280a83c448b3bef517a54b57a3b71921c0"},
{file = "google_crc32c-1.7.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:02bc3295d26cd7666521fd6d5b7b93923ae1eb4417ddd3bc57185a5881ad7b96"},
@ -6367,7 +6368,6 @@ files = [
{file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909"},
{file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1"},
{file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567"},
{file = "psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142"},
{file = "psycopg2_binary-2.9.10-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4"},
{file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8"},
{file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864"},
@ -6452,7 +6452,7 @@ files = [
{file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"},
{file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"},
]
markers = {storage = "platform_python_implementation != \"PyPy\"", vdb = "python_version < \"3.12\" or platform_python_implementation != \"PyPy\""}
markers = {storage = "platform_python_implementation != \"PyPy\"", vdb = "platform_python_implementation != \"PyPy\""}
[[package]]
name = "pycryptodome"
@ -7240,7 +7240,7 @@ files = [
{file = "pywin32-310-cp39-cp39-win32.whl", hash = "sha256:851c8d927af0d879221e616ae1f66145253537bbdd321a77e8ef701b443a9a1a"},
{file = "pywin32-310-cp39-cp39-win_amd64.whl", hash = "sha256:96867217335559ac619f00ad70e513c0fcf84b8a3af9fc2bba3b59b97da70475"},
]
markers = {main = "platform_system == \"Windows\" and platform_python_implementation != \"PyPy\"", vdb = "platform_system == \"Windows\""}
markers = {main = "platform_python_implementation != \"PyPy\" and platform_system == \"Windows\"", vdb = "platform_system == \"Windows\""}
[[package]]
name = "pyxlsb"
@ -9357,7 +9357,7 @@ description = "Fast implementation of asyncio event loop on top of libuv"
optional = false
python-versions = ">=3.8.0"
groups = ["vdb"]
markers = "sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\""
markers = "platform_python_implementation != \"PyPy\" and sys_platform != \"win32\" and sys_platform != \"cygwin\""
files = [
{file = "uvloop-0.21.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ec7e6b09a6fdded42403182ab6b832b71f4edaf7f37a9a0e371a01db5f0cb45f"},
{file = "uvloop-0.21.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:196274f2adb9689a289ad7d65700d37df0c0930fd8e4e743fa4834e850d7719d"},

View File

@ -1,3 +1,5 @@
import base64
import hashlib
import logging
import uuid
from collections.abc import Mapping
@ -7,6 +9,8 @@ from urllib.parse import urlparse
from uuid import uuid4
import yaml # type: ignore
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from packaging import version
from pydantic import BaseModel, Field
from sqlalchemy import select
@ -478,6 +482,15 @@ class AppDslService:
unique_hash = current_draft_workflow.unique_hash
else:
unique_hash = None
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
decrypted_id
for dataset_id in dataset_ids
if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id))
]
workflow_service.sync_draft_workflow(
app_model=app,
graph=workflow_data.get("graph", {}),
@ -552,7 +565,15 @@ class AppDslService:
if not workflow:
raise ValueError("Missing draft workflow configuration, please check.")
export_data["workflow"] = workflow.to_dict(include_secret=include_secret)
workflow_dict = workflow.to_dict(include_secret=include_secret)
for node in workflow_dict.get("graph", {}).get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id)
for dataset_id in dataset_ids
]
export_data["workflow"] = workflow_dict
dependencies = cls._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
@ -724,3 +745,29 @@ class AppDslService:
return []
return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
@staticmethod
def _generate_aes_key(tenant_id: str) -> bytes:
"""Generate AES key based on tenant_id"""
return hashlib.sha256(tenant_id.encode()).digest()
@classmethod
def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str:
"""Encrypt dataset_id using AES-CBC mode"""
key = cls._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size))
return base64.b64encode(ct_bytes).decode()
@classmethod
def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None:
"""AES decryption"""
try:
key = cls._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
return pt.decode()
except Exception:
return None
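A hedged roundtrip sketch for the new helpers (the import path and ids are assumptions). Because the IV is derived from the tenant key, the ciphertext is deterministic for a given tenant and dataset:
from services.app_dsl_service import AppDslService  # module path assumed

tenant_id = "tenant-uuid"     # placeholder
dataset_id = "dataset-uuid"   # placeholder

token = AppDslService.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=tenant_id)
assert AppDslService.decrypt_dataset_id(encrypted_data=token, tenant_id=tenant_id) == dataset_id
# Decrypting with a different tenant_id derives a different key, so unpadding
# normally fails and the helper returns None instead of raising.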

View File

@ -17,6 +17,10 @@ class ApiKeyAuthFactory:
from services.auth.firecrawl.firecrawl import FirecrawlAuth
return FirecrawlAuth
case AuthType.WATERCRAWL:
from services.auth.watercrawl.watercrawl import WatercrawlAuth
return WatercrawlAuth
case AuthType.JINA:
from services.auth.jina.jina import JinaAuth

View File

@ -3,4 +3,5 @@ from enum import StrEnum
class AuthType(StrEnum):
FIRECRAWL = "firecrawl"
WATERCRAWL = "watercrawl"
JINA = "jinareader"

View File

View File

@ -0,0 +1,44 @@
import json
from urllib.parse import urljoin
import requests
from services.auth.api_key_auth_base import ApiKeyAuthBase
class WatercrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "x-api-key":
raise ValueError("Invalid auth type, WaterCrawl auth type must be x-api-key")
self.api_key = credentials.get("config", {}).get("api_key", None)
self.base_url = credentials.get("config", {}).get("base_url", "https://app.watercrawl.dev")
if not self.api_key:
raise ValueError("No API key provided")
def validate_credentials(self):
headers = self._prepare_headers()
url = urljoin(self.base_url, "/api/v1/core/crawl-requests/")
response = self._get_request(url, headers)
if response.status_code == 200:
return True
else:
self._handle_error(response)
def _prepare_headers(self):
return {"Content-Type": "application/json", "X-API-KEY": self.api_key}
def _get_request(self, url, headers):
return requests.get(url, headers=headers)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
else:
if response.text:
error_message = json.loads(response.text).get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")

View File

@ -6,7 +6,7 @@ from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fix
from extensions.ext_database import db
from libs.helper import RateLimiter
from models.account import TenantAccountJoin, TenantAccountRole
from models.account import Account, TenantAccountJoin, TenantAccountRole
class BillingService:
@ -106,6 +106,48 @@ class BillingService:
json = {"email": email, "feedback": feedback}
return cls._send_request("POST", "/account/delete-feedback", json=json)
class EducationIdentity:
verification_rate_limit = RateLimiter(prefix="edu_verification_rate_limit", max_attempts=10, time_window=60)
activation_rate_limit = RateLimiter(prefix="edu_activation_rate_limit", max_attempts=10, time_window=60)
@classmethod
def verify(cls, account_id: str, account_email: str):
if cls.verification_rate_limit.is_rate_limited(account_email):
from controllers.console.error import EducationVerifyLimitError
raise EducationVerifyLimitError()
cls.verification_rate_limit.increment_rate_limit(account_email)
params = {"account_id": account_id}
return BillingService._send_request("GET", "/education/verify", params=params)
@classmethod
def is_active(cls, account_id: str):
params = {"account_id": account_id}
return BillingService._send_request("GET", "/education/status", params=params)
@classmethod
def activate(cls, account: Account, token: str, institution: str, role: str):
if cls.activation_rate_limit.is_rate_limited(account.email):
from controllers.console.error import EducationActivateLimitError
raise EducationActivateLimitError()
cls.activation_rate_limit.increment_rate_limit(account.email)
params = {"account_id": account.id, "curr_tenant_id": account.current_tenant_id}
json = {
"institution": institution,
"token": token,
"role": role,
}
return BillingService._send_request("POST", "/education/", json=json, params=params)
@classmethod
def autocomplete(cls, keywords: str, page: int = 0, limit: int = 20):
params = {"keywords": keywords, "page": page, "limit": limit}
return BillingService._send_request("GET", "/education/autocomplete", params=params)
@classmethod
def get_compliance_download_link(
cls,

View File

@ -880,6 +880,9 @@ class DocumentService:
website_info = knowledge_config.data_source.info_list.website_info_list
count = len(website_info.urls) # type: ignore
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if features.billing.subscription.plan == "sandbox" and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@ -1328,6 +1331,8 @@ class DocumentService:
website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
if website_info:
count = len(website_info.urls)
if features.billing.subscription.plan == "sandbox" and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@ -1663,6 +1668,7 @@ class SegmentService:
content=content,
word_count=len(content),
tokens=tokens,
keywords=segment_item.get("keywords", []),
status="completed",
indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
@ -1780,12 +1786,8 @@ class SegmentService:
)
elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
if args.enabled or keyword_changed:
VectorService.create_segments_vector(
[args.keywords] if args.keywords else None,
[segment],
dataset,
document.doc_form,
)
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
else:
segment_hash = helper.generate_text_hash(content)
tokens = 0

View File

@ -17,6 +17,11 @@ class BillingModel(BaseModel):
subscription: SubscriptionModel = SubscriptionModel()
class EducationModel(BaseModel):
enabled: bool = False
activated: bool = False
class LimitationModel(BaseModel):
size: int = 0
limit: int = 0
@ -38,6 +43,7 @@ class LicenseModel(BaseModel):
class FeatureModel(BaseModel):
billing: BillingModel = BillingModel()
education: EducationModel = EducationModel()
members: LimitationModel = LimitationModel(size=0, limit=1)
apps: LimitationModel = LimitationModel(size=0, limit=10)
vector_space: LimitationModel = LimitationModel(size=0, limit=5)
@ -128,6 +134,7 @@ class FeatureService:
features.can_replace_logo = dify_config.CAN_REPLACE_LOGO
features.model_load_balancing_enabled = dify_config.MODEL_LB_ENABLED
features.dataset_operator_enabled = dify_config.DATASET_OPERATOR_ENABLED
features.education.enabled = dify_config.EDUCATION_ENABLED
@classmethod
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
@ -136,6 +143,7 @@ class FeatureService:
features.billing.enabled = billing_info["enabled"]
features.billing.subscription.plan = billing_info["subscription"]["plan"]
features.billing.subscription.interval = billing_info["subscription"]["interval"]
features.education.activated = billing_info["subscription"].get("education", False)
if "members" in billing_info:
features.members.size = billing_info["members"]["size"]

View File

@ -20,7 +20,7 @@ class TagService:
)
if keyword:
query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
query = query.group_by(Tag.id, Tag.type, Tag.name)
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
results: list = query.order_by(Tag.created_at.desc()).all()
return results

View File

@ -7,6 +7,7 @@ from flask_login import current_user # type: ignore
from core.helper import encrypter
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from core.rag.extractor.watercrawl.provider import WaterCrawlProvider
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from services.auth.api_key_auth_service import ApiKeyAuthService
@ -59,6 +60,13 @@ class WebsiteService:
time = str(datetime.datetime.now().timestamp())
redis_client.setex(website_crawl_time_cache_key, 3600, time)
return {"status": "active", "job_id": job_id}
elif provider == "watercrawl":
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
)
return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).crawl_url(url, options)
elif provider == "jinareader":
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
@ -116,6 +124,14 @@ class WebsiteService:
time_consuming = abs(end_time - float(start_time))
crawl_status_data["time_consuming"] = f"{time_consuming:.2f}"
redis_client.delete(website_crawl_time_cache_key)
elif provider == "watercrawl":
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
)
crawl_status_data = WaterCrawlProvider(
api_key, credentials.get("config").get("base_url", None)
).get_crawl_status(job_id)
elif provider == "jinareader":
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
@ -180,6 +196,11 @@ class WebsiteService:
if item.get("source_url") == url:
return dict(item)
return None
elif provider == "watercrawl":
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).get_crawl_url_data(
job_id, url
)
elif provider == "jinareader":
if not job_id:
response = requests.get(
@ -223,5 +244,8 @@ class WebsiteService:
params = {"onlyMainContent": only_main_content}
result = firecrawl_app.scrape_url(url, params)
return result
elif provider == "watercrawl":
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).scrape_url(url)
else:
raise ValueError("Invalid provider")

View File

@ -4,7 +4,6 @@ import time
import click
from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@ -28,7 +27,9 @@ def add_document_to_index_task(dataset_document_id: str):
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first()
if not dataset_document:
raise NotFound("Document not found")
logging.info(click.style("Document not found: {}".format(dataset_document_id), fg="red"))
db.session.close()
return
if dataset_document.indexing_status != "completed":
return

View File

@ -6,6 +6,7 @@ from celery import shared_task # type: ignore
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@ -55,3 +56,5 @@ def add_annotation_to_index_task(
)
except Exception:
logging.exception("Build index for annotation failed")
finally:
db.session.close()
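
A pattern repeated through the task diffs below: each Celery task now logs and returns early instead of raising `NotFound`, and releases its SQLAlchemy session in a `finally` block so worker processes do not hold connections after a missing record or a failed index build. A minimal sketch of that shape; the task name, queue, and model are placeholders:

```python
import logging

import click
from celery import shared_task  # type: ignore

from extensions.ext_database import db
from models.dataset import Document  # placeholder model, for illustration only


@shared_task(queue="dataset")
def example_index_task(document_id: str):
    """Illustrative task showing the early-return and finally-close pattern."""
    document = db.session.query(Document).filter(Document.id == document_id).first()
    if not document:
        # log and return instead of raising NotFound so the task ends cleanly
        logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
        db.session.close()
        return
    try:
        ...  # build or update the index here
    except Exception:
        logging.exception("Build index failed")
    finally:
        # always hand the connection back to the pool
        db.session.close()
```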

View File

@ -88,3 +88,5 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id))
redis_client.setex(indexing_error_msg_key, 600, str(e))
logging.exception("Build index for batch import annotations failed")
finally:
db.session.close()

View File

@ -5,6 +5,7 @@ import click
from celery import shared_task # type: ignore
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@ -39,3 +40,5 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
)
except Exception as e:
logging.exception("Annotation deleted index failed")
finally:
db.session.close()

View File

@ -3,7 +3,6 @@ import time
import click
from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
@ -23,14 +22,18 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
annotations_count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).count()
if not app:
raise NotFound("App not found")
logging.info(click.style("App not found: {}".format(app_id), fg="red"))
db.session.close()
return
app_annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
)
if not app_annotation_setting:
raise NotFound("App annotation setting not found")
logging.info(click.style("App annotation setting not found: {}".format(app_id), fg="red"))
db.session.close()
return
disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id))
disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id))
@ -46,7 +49,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
try:
if annotations_count > 0:
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.delete_by_metadata_field("app_id", app_id)
vector.delete()
except Exception:
logging.exception("Delete annotation index failed when annotation deleted.")
redis_client.setex(disable_app_annotation_job_key, 600, "completed")
@ -66,3 +69,4 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
redis_client.setex(disable_app_annotation_error_key, 600, str(e))
finally:
redis_client.delete(disable_app_annotation_key)
db.session.close()

View File

@ -4,7 +4,6 @@ import time
import click
from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
@ -34,7 +33,9 @@ def enable_annotation_reply_task(
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
if not app:
raise NotFound("App not found")
logging.info(click.style("App not found: {}".format(app_id), fg="red"))
db.session.close()
return
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all()
enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id))
@ -49,6 +50,27 @@ def enable_annotation_reply_task(
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
)
if annotation_setting:
if dataset_collection_binding.id != annotation_setting.collection_binding_id:
old_dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
annotation_setting.collection_binding_id, "annotation"
)
)
if old_dataset_collection_binding and annotations:
old_dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=old_dataset_collection_binding.provider_name,
embedding_model=old_dataset_collection_binding.model_name,
collection_binding_id=old_dataset_collection_binding.id,
)
old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"])
try:
old_vector.delete()
except Exception as e:
logging.info(click.style("Delete annotation index error: {}".format(str(e)), fg="red"))
annotation_setting.score_threshold = score_threshold
annotation_setting.collection_binding_id = dataset_collection_binding.id
annotation_setting.updated_user_id = user_id
@ -100,3 +122,4 @@ def enable_annotation_reply_task(
db.session.rollback()
finally:
redis_client.delete(enable_app_annotation_key)
db.session.close()

View File

@ -6,6 +6,7 @@ from celery import shared_task # type: ignore
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@ -56,3 +57,5 @@ def update_annotation_to_index_task(
)
except Exception:
logging.exception("Build index for annotation failed")
finally:
db.session.close()

View File

@ -74,3 +74,5 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
)
except Exception:
logging.exception("Cleaned documents when documents deleted failed")
finally:
db.session.close()

View File

@ -127,3 +127,5 @@ def batch_create_segment_to_index_task(
except Exception:
logging.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")
finally:
db.session.close()

View File

@ -11,6 +11,8 @@ from extensions.ext_storage import storage
from models.dataset import (
AppDatasetJoin,
Dataset,
DatasetMetadata,
DatasetMetadataBinding,
DatasetProcessRule,
DatasetQuery,
Document,
@ -86,7 +88,9 @@ def clean_dataset_task(
db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete()
db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete()
db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete()
# delete dataset metadata
db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == dataset_id).delete()
db.session.query(DatasetMetadataBinding).filter(DatasetMetadataBinding.dataset_id == dataset_id).delete()
# delete files
if documents:
for document in documents:
@ -117,3 +121,5 @@ def clean_dataset_task(
)
except Exception:
logging.exception("Cleaned dataset when dataset deleted failed")
finally:
db.session.close()

View File

@ -9,7 +9,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
from core.tools.utils.rag_web_reader import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DocumentSegment
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
from models.model import UploadFile
@ -67,6 +67,12 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
db.session.delete(file)
db.session.commit()
# delete dataset metadata binding
db.session.query(DatasetMetadataBinding).filter(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
end_at = time.perf_counter()
logging.info(
click.style(
@ -76,3 +82,5 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
)
except Exception:
logging.exception("Cleaned document when document deleted failed")
finally:
db.session.close()

View File

@ -53,3 +53,5 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
)
except Exception:
logging.exception("Cleaned document when import form notion document deleted failed")
finally:
db.session.close()

View File

@ -5,7 +5,6 @@ from typing import Optional
import click
from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
@ -27,7 +26,9 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
if not segment:
raise NotFound("Segment not found")
logging.info(click.style("Segment not found: {}".format(segment_id), fg="red"))
db.session.close()
return
if segment.status != "waiting":
return
@ -93,3 +94,4 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()

View File

@ -167,3 +167,5 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
)
except Exception:
logging.exception("Deal dataset vector index failed")
finally:
db.session.close()

View File

@ -41,3 +41,5 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume
logging.info(click.style("Segment deleted from index latency: {}".format(end_at - start_at), fg="green"))
except Exception:
logging.exception("delete segment from index failed")
finally:
db.session.close()

View File

@ -3,7 +3,6 @@ import time
import click
from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@ -24,10 +23,14 @@ def disable_segment_from_index_task(segment_id: str):
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
if not segment:
raise NotFound("Segment not found")
logging.info(click.style("Segment not found: {}".format(segment_id), fg="red"))
db.session.close()
return
if segment.status != "completed":
raise NotFound("Segment is not completed , disable action is not allowed.")
logging.info(click.style("Segment is not completed, disable is not allowed: {}".format(segment_id), fg="red"))
db.session.close()
return
indexing_cache_key = "segment_{}_indexing".format(segment.id)
@ -62,3 +65,4 @@ def disable_segment_from_index_task(segment_id: str):
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()

View File

@ -26,15 +26,18 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan"))
db.session.close()
return
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
if not dataset_document:
logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan"))
db.session.close()
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan"))
db.session.close()
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
@ -50,6 +53,7 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
)
if not segments:
db.session.close()
return
try:
@ -76,3 +80,4 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
for segment in segments:
indexing_cache_key = "segment_{}_indexing".format(segment.id)
redis_client.delete(indexing_cache_key)
db.session.close()

View File

@ -4,7 +4,6 @@ import time
import click
from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.extractor.notion_extractor import NotionExtractor
@ -29,7 +28,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
raise NotFound("Document not found")
logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
db.session.close()
return
data_source_info = document.data_source_info_dict
if document.data_source_type == "notion_import":

View File

@ -27,6 +27,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Dataset is not found: {}".format(dataset_id), fg="yellow"))
db.session.close()
return
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
@ -35,6 +36,8 @@ def document_indexing_task(dataset_id: str, document_ids: list):
vector_space = features.vector_space
count = len(document_ids)
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if features.billing.subscription.plan == "sandbox" and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
if 0 < vector_space.limit <= vector_space.size:
@ -53,6 +56,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
db.session.close()
return
for document_id in document_ids:
@ -78,3 +82,5 @@ def document_indexing_task(dataset_id: str, document_ids: list):
logging.info(click.style(str(ex), fg="yellow"))
except Exception:
pass
finally:
db.session.close()
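
The billing check added above gates batch size by plan: sandbox tenants may only submit one document per indexing call, and every plan is still bound by `BATCH_UPLOAD_LIMIT`. A compact restatement of those checks as a standalone helper; the helper name is illustrative and the `dify_config` import path is assumed from the surrounding codebase:

```python
from configs import dify_config  # import path assumed


def validate_batch_upload(features, document_ids: list[str]) -> None:
    """Illustrative restatement of the limit checks added to document_indexing_task."""
    count = len(document_ids)
    batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
    if features.billing.enabled:
        # sandbox tenants are limited to a single document per call
        if features.billing.subscription.plan == "sandbox" and count > 1:
            raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
    if count > batch_upload_limit:
        raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
```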

View File

@ -4,7 +4,6 @@ import time
import click
from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@ -27,7 +26,9 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
raise NotFound("Document not found")
logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
db.session.close()
return
document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
@ -73,3 +74,5 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
logging.info(click.style(str(ex), fg="yellow"))
except Exception:
pass
finally:
db.session.close()

View File

@ -27,7 +27,9 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if dataset is None:
raise ValueError("Dataset not found")
logging.info(click.style("Dataset not found: {}".format(dataset_id), fg="red"))
db.session.close()
return
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
@ -35,6 +37,8 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
if features.billing.subscription.plan == "sandbox" and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@ -55,6 +59,8 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
db.session.add(document)
db.session.commit()
return
finally:
db.session.close()
for document_id in document_ids:
logging.info(click.style("Start process document: {}".format(document_id), fg="green"))
@ -94,3 +100,5 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
logging.info(click.style(str(ex), fg="yellow"))
except Exception:
pass
finally:
db.session.close()

View File

@ -4,7 +4,6 @@ import time
import click
from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@ -27,10 +26,14 @@ def enable_segment_to_index_task(segment_id: str):
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
if not segment:
raise NotFound("Segment not found")
logging.info(click.style("Segment not found: {}".format(segment_id), fg="red"))
db.session.close()
return
if segment.status != "completed":
raise NotFound("Segment is not completed, enable action is not allowed.")
logging.info(click.style("Segment is not completed, enable is not allowed: {}".format(segment_id), fg="red"))
db.session.close()
return
indexing_cache_key = "segment_{}_indexing".format(segment.id)
@ -94,3 +97,4 @@ def enable_segment_to_index_task(segment_id: str):
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()

View File

@ -34,9 +34,11 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
if not dataset_document:
logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan"))
db.session.close()
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan"))
db.session.close()
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
@ -51,6 +53,8 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
.all()
)
if not segments:
logging.info(click.style("Segments not found: {}".format(segment_ids), fg="cyan"))
db.session.close()
return
try:
@ -108,3 +112,4 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
for segment in segments:
indexing_cache_key = "segment_{}_indexing".format(segment.id)
redis_client.delete(indexing_cache_key)
db.session.close()

View File

@ -3,7 +3,6 @@ import time
import click
from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from extensions.ext_database import db
@ -25,7 +24,9 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
raise NotFound("Document not found")
logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
db.session.close()
return
try:
indexing_runner = IndexingRunner()
@ -43,3 +44,5 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
logging.info(click.style(str(ex), fg="yellow"))
except Exception:
pass
finally:
db.session.close()

View File

@ -4,7 +4,6 @@ import time
import click
from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@ -25,9 +24,13 @@ def remove_document_from_index_task(document_id: str):
document = db.session.query(Document).filter(Document.id == document_id).first()
if not document:
raise NotFound("Document not found")
logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
db.session.close()
return
if document.indexing_status != "completed":
logging.info(click.style("Document is not completed, remove is not allowed: {}".format(document_id), fg="red"))
db.session.close()
return
indexing_cache_key = "document_{}_indexing".format(document.id)
@ -71,3 +74,4 @@ def remove_document_from_index_task(document_id: str):
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()

View File

@ -27,7 +27,9 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset not found")
logging.info(click.style("Dataset not found: {}".format(dataset_id), fg="red"))
db.session.close()
return
for document_id in document_ids:
retry_indexing_cache_key = "document_{}_is_retried".format(document_id)
@ -52,6 +54,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
db.session.add(document)
db.session.commit()
redis_client.delete(retry_indexing_cache_key)
db.session.close()
return
logging.info(click.style("Start retry document: {}".format(document_id), fg="green"))
@ -60,6 +63,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
)
if not document:
logging.info(click.style("Document not found: {}".format(document_id), fg="yellow"))
db.session.close()
return
try:
# clean old data
@ -92,5 +96,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
logging.info(click.style(str(ex), fg="yellow"))
redis_client.delete(retry_indexing_cache_key)
pass
finally:
db.session.close()
end_at = time.perf_counter()
logging.info(click.style("Retry dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))

View File

@ -5,10 +5,11 @@ import pytest
from _pytest.monkeypatch import MonkeyPatch
from requests.adapters import HTTPAdapter
from tcvectordb import RPCVectorDBClient # type: ignore
from tcvectordb.model import enum
from tcvectordb.model.collection import FilterIndexConfig
from tcvectordb.model.document import Document, Filter # type: ignore
from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank # type: ignore
from tcvectordb.model.enum import ReadConsistency # type: ignore
from tcvectordb.model.index import Index, IndexField # type: ignore
from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex # type: ignore
from tcvectordb.rpc.model.collection import RPCCollection
from tcvectordb.rpc.model.database import RPCDatabase
from xinference_client.types import Embedding # type: ignore
@ -40,6 +41,30 @@ class MockTcvectordbClass:
def exists_collection(self, database_name: str, collection_name: str) -> bool:
return True
def describe_collection(
self, database_name: str, collection_name: str, timeout: Optional[float] = None
) -> RPCCollection:
index = Index(
FilterIndex("id", enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
VectorIndex(
"vector",
128,
enum.IndexType.HNSW,
enum.MetricType.IP,
HNSWParams(m=16, efconstruction=200),
),
FilterIndex("text", enum.FieldType.String, enum.IndexType.FILTER),
FilterIndex("metadata", enum.FieldType.String, enum.IndexType.FILTER),
)
return RPCCollection(
RPCDatabase(
name=database_name,
read_consistency=self._read_consistency,
),
collection_name,
index=index,
)
def create_collection(
self,
database_name: str,
@ -97,6 +122,23 @@ class MockTcvectordbClass:
) -> list[list[dict]]:
return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]
def collection_hybrid_search(
self,
database_name: str,
collection_name: str,
ann: Optional[Union[list[AnnSearch], AnnSearch]] = None,
match: Optional[Union[list[KeywordSearch], KeywordSearch]] = None,
filter: Union[Filter, str] = None,
rerank: Optional[Rerank] = None,
retrieve_vector: Optional[bool] = None,
output_fields: Optional[list[str]] = None,
limit: Optional[int] = None,
timeout: Optional[float] = None,
return_pd_object=False,
**kwargs,
) -> list[list[dict]]:
return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]
def collection_query(
self,
database_name: str,
@ -137,8 +179,10 @@ def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
)
monkeypatch.setattr(RPCVectorDBClient, "exists_collection", MockTcvectordbClass.exists_collection)
monkeypatch.setattr(RPCVectorDBClient, "create_collection", MockTcvectordbClass.create_collection)
monkeypatch.setattr(RPCVectorDBClient, "describe_collection", MockTcvectordbClass.describe_collection)
monkeypatch.setattr(RPCVectorDBClient, "upsert", MockTcvectordbClass.collection_upsert)
monkeypatch.setattr(RPCVectorDBClient, "search", MockTcvectordbClass.collection_search)
monkeypatch.setattr(RPCVectorDBClient, "hybrid_search", MockTcvectordbClass.collection_hybrid_search)
monkeypatch.setattr(RPCVectorDBClient, "query", MockTcvectordbClass.collection_query)
monkeypatch.setattr(RPCVectorDBClient, "delete", MockTcvectordbClass.collection_delete)
monkeypatch.setattr(RPCVectorDBClient, "drop_collection", MockTcvectordbClass.drop_collection)

View File

@ -21,6 +21,7 @@ class TencentVectorTest(AbstractVectorTest):
database="dify",
shard=1,
replicas=2,
enable_hybrid_search=True,
),
)
@ -30,7 +31,7 @@ class TencentVectorTest(AbstractVectorTest):
def search_by_full_text(self):
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0
assert len(hits_by_full_text) >= 0
def test_tencent_vector(setup_mock_redis, setup_tcvectordb_mock):
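
The mock above also registers `hybrid_search` on `RPCVectorDBClient`, which is what the new `enable_hybrid_search` option exercises. A hedged sketch of issuing such a query is below; the keyword parameters mirror the mocked signature, but the `AnnSearch`/`KeywordSearch` constructor arguments and the database/collection names are assumptions, not verified against the real tcvectordb SDK:

```python
from tcvectordb import RPCVectorDBClient  # type: ignore
from tcvectordb.model.document import AnnSearch, KeywordSearch  # type: ignore


def hybrid_query(client: RPCVectorDBClient, query_vector: list[float], query_text: str):
    """Sketch of a hybrid (vector + keyword) query; argument values are illustrative."""
    return client.hybrid_search(
        database_name="dify",                   # placeholder database name
        collection_name="example_collection",   # placeholder collection name
        ann=AnnSearch(field_name="vector", data=query_vector),  # constructor args assumed
        match=KeywordSearch(data=query_text),                   # constructor args assumed
        retrieve_vector=False,
        limit=4,
    )
```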

View File

@ -0,0 +1,70 @@
import json
from collections.abc import Generator
from core.agent.entities import AgentScratchpadUnit
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.model_runtime.entities.llm_entities import AssistantPromptMessage, LLMResultChunk, LLMResultChunkDelta
def mock_llm_response(text) -> Generator[LLMResultChunk, None, None]:
for i in range(len(text)):
yield LLMResultChunk(
model="model",
prompt_messages=[],
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=text[i], tool_calls=[])),
)
def test_cot_output_parser():
test_cases = [
{
"input": 'Through: abc\nAction: ```{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}```',
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
},
# code block with json
{
"input": 'Through: abc\nAction: ```json\n{"action": "Final Answer", "action_input": "```echarts\n {'
'}\n```"}```',
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
},
# code block with JSON
{
"input": 'Through: abc\nAction: ```JSON\n{"action": "Final Answer", "action_input": "```echarts\n {'
'}\n```"}```',
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
},
# list
{
"input": 'Through: abc\nAction: ```[{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}]```',
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
},
# no code block
{
"input": 'Through: abc\nAction: {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}',
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
},
# no code block and json
{"input": "Through: abc\nAction: efg", "action": {}, "output": "Through: abc\n efg"},
]
parser = CotAgentOutputParser()
usage_dict = {}
for test_case in test_cases:
# mock llm_response as a generator by text
llm_response: Generator[LLMResultChunk, None, None] = mock_llm_response(test_case["input"])
results = parser.handle_react_stream_output(llm_response, usage_dict)
output = ""
for result in results:
if isinstance(result, str):
output += result
elif isinstance(result, AgentScratchpadUnit.Action):
if test_case["action"]:
assert result.to_dict() == test_case["action"]
output += json.dumps(result.to_dict())
if test_case["output"]:
assert output == test_case["output"]

View File

@ -515,6 +515,7 @@ TENCENT_VECTOR_DB_USERNAME=dify
TENCENT_VECTOR_DB_DATABASE=dify
TENCENT_VECTOR_DB_SHARD=1
TENCENT_VECTOR_DB_REPLICAS=2
TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH=false
# ElasticSearch configuration, only available when VECTOR_STORE is `elasticsearch`
ELASTICSEARCH_HOST=0.0.0.0

View File

@ -36,7 +36,8 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T
- Navigate to the `docker` directory.
- Ensure the `middleware.env` file is created by running `cp middleware.env.example middleware.env` (refer to the `middleware.env.example` file).
2. **Running Middleware Services**:
- Execute `docker-compose -f docker-compose.middleware.yaml up --env-file middleware.env -d` to start the middleware services.
- Navigate to the `docker` directory.
- Execute `docker compose -f docker-compose.middleware.yaml --profile weaviate -p dify up -d` to start the middleware services. (Change the profile to another vector database if you are not using Weaviate.)
### Migration for Existing Users

View File

@ -223,6 +223,7 @@ x-shared-env: &shared-api-worker-env
TENCENT_VECTOR_DB_DATABASE: ${TENCENT_VECTOR_DB_DATABASE:-dify}
TENCENT_VECTOR_DB_SHARD: ${TENCENT_VECTOR_DB_SHARD:-1}
TENCENT_VECTOR_DB_REPLICAS: ${TENCENT_VECTOR_DB_REPLICAS:-2}
TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH: ${TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH:-false}
ELASTICSEARCH_HOST: ${ELASTICSEARCH_HOST:-0.0.0.0}
ELASTICSEARCH_PORT: ${ELASTICSEARCH_PORT:-9200}
ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic}

Binary image file not shown (before: 257 KiB).

View File

@ -42,8 +42,8 @@ ENV EDITION=SELF_HOSTED
ENV DEPLOY_ENV=PRODUCTION
ENV CONSOLE_API_URL=http://127.0.0.1:5001
ENV APP_API_URL=http://127.0.0.1:5001
ENV MARKETPLACE_API_URL=http://127.0.0.1:5001
ENV MARKETPLACE_URL=http://127.0.0.1:5001
ENV MARKETPLACE_API_URL=https://marketplace.dify.ai
ENV MARKETPLACE_URL=https://marketplace.dify.ai
ENV PORT=3000
ENV NEXT_TELEMETRY_DISABLED=1
ENV PM2_INSTANCES=2

View File

@ -1,7 +1,9 @@
'use client'
import { useCallback, useEffect, useRef, useState } from 'react'
import { useRouter } from 'next/navigation'
import {
useRouter,
} from 'next/navigation'
import useSWRInfinite from 'swr/infinite'
import { useTranslation } from 'react-i18next'
import { useDebounceFn } from 'ahooks'

View File

@ -7,9 +7,12 @@ import style from '../list.module.css'
import Apps from './Apps'
import AppContext from '@/context/app-context'
import { LicenseStatus } from '@/types/feature'
import { useEducationInit } from '@/app/education-apply/hooks'
const AppList = () => {
const { t } = useTranslation()
useEducationInit()
const systemFeatures = useContextSelector(AppContext, v => v.systemFeatures)
return (

View File

@ -38,6 +38,8 @@ const Container = () => {
const { showExternalApiPanel, setShowExternalApiPanel } = useExternalApiPanel()
const [includeAll, { toggle: toggleIncludeAll }] = useBoolean(false)
document.title = `${t('dataset.knowledge')} - Dify`
const options = useMemo(() => {
return [
{ value: 'dataset', text: t('dataset.datasets') },

Some files were not shown because too many files have changed in this diff.