Mirror of https://github.com/langgenius/dify.git

Commit 358d5978cb: merge main
@@ -7,7 +7,7 @@ pipx install poetry
echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc

source /home/vscode/.bashrc
@@ -6,6 +6,7 @@ on:
- "main"
- "deploy/dev"
- "deploy/enterprise"
- release/1.1.3-fix1
tags:
- "*"
@@ -254,8 +254,6 @@ docker compose up -d
- [Discord](https://discord.gg/FngNHpbcY7). 👉: Share your applications and chat with the community.
- [X (Twitter)](https://twitter.com/dify_ai). 👉: Share your applications and chat with the community.
- [Business License](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry). 👉: Business inquiries about licensing Dify.AI for commercial use.
- [WeChat]() 👉: Scan the QR code below to add us on WeChat, mention "Dify" in the note, and we will invite you to join the Dify community.
<img src="./images/wechat.png" alt="wechat" width="100"/>

## Security Issues
@@ -189,6 +189,7 @@ TENCENT_VECTOR_DB_USERNAME=dify
TENCENT_VECTOR_DB_DATABASE=dify
TENCENT_VECTOR_DB_SHARD=1
TENCENT_VECTOR_DB_REPLICAS=2
TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH=false

# ElasticSearch configuration
ELASTICSEARCH_HOST=127.0.0.1
@@ -848,6 +848,11 @@ class AccountConfig(BaseSettings):
default=5,
)

EDUCATION_ENABLED: bool = Field(
description="whether to enable education identity",
default=False,
)


class FeatureConfig(
# place the configs in alphabet order
@@ -48,3 +48,8 @@ class TencentVectorDBConfig(BaseSettings):
description="Name of the specific Tencent Vector Database to connect to",
default=None,
)

TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH: bool = Field(
description="Enable hybrid search features",
default=False,
)
@@ -8,6 +8,7 @@ from werkzeug.exceptions import Forbidden
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
)
from extensions.ext_database import db
@@ -23,6 +24,7 @@ class AppImportApi(Resource):
@login_required
@account_initialization_required
@marshal_with(app_import_fields)
@cloud_edition_billing_resource_check("apps")
def post(self):
# Check user role first
if not current_user.is_editor:
@@ -99,53 +99,64 @@ class ForgotPasswordResetApi(Resource):
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
args = parser.parse_args()

new_password = args["new_password"]
password_confirm = args["password_confirm"]

if str(new_password).strip() != str(password_confirm).strip():
# Validate passwords match
if args["new_password"] != args["password_confirm"]:
raise PasswordMismatchError()

token = args["token"]
reset_data = AccountService.get_reset_password_data(token)

if reset_data is None:
# Validate token and get reset data
reset_data = AccountService.get_reset_password_data(args["token"])
if not reset_data:
raise InvalidTokenError()

AccountService.revoke_reset_password_token(token)
# Revoke token to prevent reuse
AccountService.revoke_reset_password_token(args["token"])

# Generate secure salt and hash password
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
password_hashed = hash_password(args["new_password"], salt)

password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
email = reset_data.get("email", "")

with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=reset_data.get("email"))).scalar_one_or_none()
if account:
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
tenant = TenantService.get_join_tenants(account)
if not tenant and not FeatureService.get_system_features().is_allow_create_workspace:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
else:
try:
account = AccountService.create_account_and_tenant(
email=reset_data.get("email", ""),
name=reset_data.get("email", ""),
password=password_confirm,
interface_language=languages[0],
)
except WorkSpaceNotAllowedCreateError:
pass
except AccountRegisterError:
raise AccountInFreezeError()
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()

if account:
self._update_existing_account(account, password_hashed, salt, session)
else:
self._create_new_account(email, args["password_confirm"])

return {"result": "success"}

def _update_existing_account(self, account, password_hashed, salt, session):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()
session.commit()

# Create workspace if needed
if (
not TenantService.get_join_tenants(account)
and FeatureService.get_system_features().is_allow_create_workspace
):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)

def _create_new_account(self, email, password):
# Create new account if allowed
try:
AccountService.create_account_and_tenant(
email=email,
name=email,
password=password,
interface_language=languages[0],
)
except WorkSpaceNotAllowedCreateError:
pass
except AccountRegisterError:
raise AccountInFreezeError()


api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
@@ -641,7 +641,6 @@ class DatasetRetrievalSettingApi(Resource):
VectorType.RELYT
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS
| VectorType.BAIDU
| VectorType.VIKINGDB
@@ -665,6 +664,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.OPENGAUSS
| VectorType.OCEANBASE
| VectorType.TABLESTORE
| VectorType.TENCENT
):
return {
"retrieval_method": [
@@ -688,7 +688,6 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.RELYT
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS
| VectorType.BAIDU
| VectorType.VIKINGDB
@@ -710,6 +709,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.OPENGAUSS
| VectorType.OCEANBASE
| VectorType.TABLESTORE
| VectorType.TENCENT
):
return {
"retrieval_method": [
@@ -14,7 +14,12 @@ class WebsiteCrawlApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
"provider", type=str, choices=["firecrawl", "jinareader"], required=True, nullable=True, location="json"
"provider",
type=str,
choices=["firecrawl", "watercrawl", "jinareader"],
required=True,
nullable=True,
location="json",
)
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
@@ -34,7 +39,9 @@ class WebsiteCrawlStatusApi(Resource):
@account_initialization_required
def get(self, job_id: str):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, choices=["firecrawl", "jinareader"], required=True, location="args")
parser.add_argument(
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
)
args = parser.parse_args()
# get crawl status
try:
@@ -103,6 +103,18 @@ class AccountInFreezeError(BaseHTTPException):
)


class EducationVerifyLimitError(BaseHTTPException):
error_code = "education_verify_limit"
description = "Rate limit exceeded"
code = 429


class EducationActivateLimitError(BaseHTTPException):
error_code = "education_activate_limit"
description = "Rate limit exceeded"
code = 429


class CompilanceRateLimitError(BaseHTTPException):
error_code = "compilance_rate_limit"
description = "Rate limit exceeded for downloading compliance report."
@@ -15,7 +15,13 @@ from controllers.console.workspace.error import (
InvalidInvitationCodeError,
RepeatPasswordNotMatchError,
)
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_enabled,
enterprise_license_required,
only_edition_cloud,
setup_required,
)
from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.helper import TimestampField, timezone
@@ -292,6 +298,79 @@ class AccountDeleteUpdateFeedbackApi(Resource):
return {"result": "success"}


class EducationVerifyApi(Resource):
verify_fields = {
"token": fields.String,
}

@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(verify_fields)
def get(self):
account = current_user

return BillingService.EducationIdentity.verify(account.id, account.email)


class EducationApi(Resource):
status_fields = {
"result": fields.Boolean,
}

@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
def post(self):
account = current_user

parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, location="json")
parser.add_argument("institution", type=str, required=True, location="json")
parser.add_argument("role", type=str, required=True, location="json")
args = parser.parse_args()

return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])

@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(status_fields)
def get(self):
account = current_user

return BillingService.EducationIdentity.is_active(account.id)


class EducationAutoCompleteApi(Resource):
data_fields = {
"data": fields.List(fields.String),
"curr_page": fields.Integer,
"has_next": fields.Boolean,
}

@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(data_fields)
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("keywords", type=str, required=True, location="args")
parser.add_argument("page", type=int, required=False, location="args", default=0)
parser.add_argument("limit", type=int, required=False, location="args", default=20)
args = parser.parse_args()

return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])


# Register API resources
api.add_resource(AccountInitApi, "/account/init")
api.add_resource(AccountProfileApi, "/account/profile")
@@ -305,5 +384,8 @@ api.add_resource(AccountIntegrateApi, "/account/integrates")
api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify")
api.add_resource(AccountDeleteApi, "/account/delete")
api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback")
api.add_resource(EducationVerifyApi, "/account/education/verify")
api.add_resource(EducationApi, "/account/education")
api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete")
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')
@@ -236,7 +236,7 @@ class PluginFetchManifestApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id

@@ -260,7 +260,7 @@ class PluginFetchInstallTasksApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id

@@ -281,7 +281,7 @@ class PluginFetchInstallTaskApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
@plugin_permission_required(install_required=True)
def get(self, task_id: str):
tenant_id = current_user.current_tenant_id

@@ -295,7 +295,7 @@ class PluginDeleteInstallTaskApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
@plugin_permission_required(install_required=True)
def post(self, task_id: str):
tenant_id = current_user.current_tenant_id

@@ -309,7 +309,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id

@@ -323,7 +323,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
@plugin_permission_required(install_required=True)
def post(self, task_id: str, identifier: str):
tenant_id = current_user.current_tenant_id

@@ -337,7 +337,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id

@@ -360,7 +360,7 @@ class PluginUpgradeFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id

@@ -391,7 +391,7 @@ class PluginUninstallApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
@plugin_permission_required(install_required=True)
def post(self):
req = reqparse.RequestParser()
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
@@ -216,6 +216,23 @@ class WebappLogoWorkspaceApi(Resource):
return {"id": upload_file.id}, 201


class WorkspaceInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
# Change workspace name
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()

tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404()
tenant.name = args["name"]
db.session.commit()

return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}


api.add_resource(TenantListApi, "/workspaces")  # GET for getting all tenants
api.add_resource(WorkspaceListApi, "/all-workspaces")  # GET for getting all tenants
api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current")  # GET for getting current tenant info
@@ -223,3 +240,4 @@ api.add_resource(TenantApi, "/info", endpoint="info")  # Deprecated
api.add_resource(SwitchWorkspaceApi, "/workspaces/switch")  # POST for switching tenant
api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config")
api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload")
api.add_resource(WorkspaceInfoApi, "/workspaces/info")  # POST for changing workspace info
@@ -54,6 +54,17 @@ def only_edition_self_hosted(view):
return decorated


def cloud_edition_billing_enabled(view):
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled:
abort(403, "Billing feature is not enabled.")
return view(*args, **kwargs)

return decorated


def cloud_edition_billing_resource_check(resource: str):
def interceptor(view):
@wraps(view)
@@ -6,5 +6,6 @@ bp = Blueprint("service_api", __name__, url_prefix="/v1")
api = ExternalApi(bp)

from . import index
from .app import app, audio, completion, conversation, file, message, workflow
from .app import annotation, app, audio, completion, conversation, file, message, workflow
from .dataset import dataset, document, hit_testing, metadata, segment, upload_file
from .workspace import models
@@ -0,0 +1,107 @@
from flask import request
from flask_restful import Resource, marshal, marshal_with, reqparse  # type: ignore
from werkzeug.exceptions import Forbidden

from controllers.service_api import api
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import (
annotation_fields,
)
from libs.login import current_user
from models.model import App, EndUser
from services.annotation_service import AppAnnotationService


class AnnotationReplyActionApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser, action):
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
args = parser.parse_args()
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
else:
raise ValueError("Unsupported annotation reply action")
return result, 200


class AnnotationReplyActionStatusApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, job_id, action):
job_id = str(job_id)
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None:
raise ValueError("The job does not exist.")

job_status = cache_result.decode()
error_msg = ""
if job_status == "error":
app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id))
error_msg = redis_client.get(app_annotation_error_key).decode()

return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200


class AnnotationListApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str)

annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword)
response = {
"data": marshal(annotation_list, annotation_fields),
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200

@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(annotation_fields)
def post(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
return annotation


class AnnotationUpdateDeleteApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(annotation_fields)
def post(self, app_model: App, end_user: EndUser, annotation_id):
if not current_user.is_editor:
raise Forbidden()

annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
return annotation

@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def delete(self, app_model: App, end_user: EndUser, annotation_id):
if not current_user.is_editor:
raise Forbidden()

annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
return {"result": "success"}, 200


api.add_resource(AnnotationReplyActionApi, "/apps/annotation-reply/<string:action>")
api.add_resource(AnnotationReplyActionStatusApi, "/apps/annotation-reply/<string:action>/status/<uuid:job_id>")
api.add_resource(AnnotationListApi, "/apps/annotations")
api.add_resource(AnnotationUpdateDeleteApi, "/apps/annotations/<uuid:annotation_id>")
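For context, a minimal client-side sketch of how the new annotation endpoints above could be called over the Service API. The base URL, the app API key, and the example payload values are placeholder assumptions, not part of this commit.

```python
import requests

# Placeholder assumptions: adjust the host and use a real app API key.
BASE_URL = "https://api.dify.example/v1"
HEADERS = {"Authorization": "Bearer app-xxxxxxxx", "Content-Type": "application/json"}

# List annotations for the app bound to the token (AnnotationListApi.get).
resp = requests.get(f"{BASE_URL}/apps/annotations", headers=HEADERS, params={"page": 1, "limit": 20})
print(resp.json())

# Create an annotation directly (AnnotationListApi.post).
resp = requests.post(
    f"{BASE_URL}/apps/annotations",
    headers=HEADERS,
    json={"question": "What is Dify?", "answer": "An LLM app development platform."},
)
print(resp.json())

# Enable annotation reply with an embedding model (AnnotationReplyActionApi.post).
resp = requests.post(
    f"{BASE_URL}/apps/annotation-reply/enable",
    headers=HEADERS,
    json={
        "score_threshold": 0.9,
        "embedding_provider_name": "openai",  # assumption for illustration
        "embedding_model_name": "text-embedding-3-small",  # assumption for illustration
    },
)
print(resp.json())
```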
@@ -1,3 +1,4 @@
import json
import logging

from flask_restful import Resource, fields, marshal_with, reqparse  # type: ignore
@@ -10,7 +11,7 @@ from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import message_file_fields
from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields
from fields.message_fields import agent_thought_fields, feedback_fields
from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from models.model import App, AppMode, EndUser
@@ -19,6 +20,14 @@ from services.message_service import MessageService


class MessageListApi(Resource):
def get_retriever_resources(self):
try:
if self.message_metadata:
return json.loads(self.message_metadata).get("retriever_resources", [])
return []
except (json.JSONDecodeError, TypeError):
return []

message_fields = {
"id": fields.String,
"conversation_id": fields.String,
@@ -28,7 +37,7 @@ class MessageListApi(Resource):
"answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"retriever_resources": get_retriever_resources,
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"status": fields.String,
@@ -1,6 +1,6 @@
from flask import request
from flask_restful import marshal, reqparse  # type: ignore
from werkzeug.exceptions import NotFound
from werkzeug.exceptions import Forbidden, NotFound

import services.dataset_service
from controllers.service_api import api
@@ -12,7 +12,7 @@ from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from libs.login import current_user
from models.dataset import Dataset, DatasetPermissionEnum
from services.dataset_service import DatasetService
from services.dataset_service import DatasetPermissionService, DatasetService


def _validate_name(name):
@@ -21,6 +21,12 @@ def _validate_name(name):
return name


def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description


class DatasetListApi(DatasetApiResource):
"""Resource for datasets."""

@@ -137,6 +143,145 @@ class DatasetListApi(DatasetApiResource):
class DatasetApi(DatasetApiResource):
"""Resource for dataset."""

def get(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})

# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)

embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)

model_names = []
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")

if data["indexing_technique"] == "high_quality":
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names:
data["embedding_available"] = True
else:
data["embedding_available"] = False
else:
data["embedding_available"] = True

if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})

return data, 200

def patch(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")

parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
parser.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
parser.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
)
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
parser.add_argument(
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
)
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")

parser.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)

parser.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)

parser.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
args = parser.parse_args()
data = request.get_json()

# check embedding model setting
if data.get("indexing_technique") == "high_quality":
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
)

# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get("permission"), data.get("partial_member_list")
)

dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)

if dataset is None:
raise NotFound("Dataset not found.")

result_data = marshal(dataset, dataset_detail_fields)
tenant_id = current_user.current_tenant_id

if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get("partial_member_list")
)
# clear partial member list when permission is only_me or all_team_members
elif (
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)

partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
result_data.update({"partial_member_list": partial_member_list})

return result_data, 200

def delete(self, _, dataset_id):
"""
Deletes a dataset given its ID.
@@ -158,6 +303,7 @@ class DatasetApi(DatasetApiResource):

try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return {"result": "success"}, 204
else:
raise NotFound("Dataset not found.")
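To illustrate the new PATCH handler added to DatasetApi above, a hedged sketch of an update request via the Service API; the host, dataset API key, and dataset ID are placeholders, not part of this commit.

```python
import requests

# Placeholder assumptions: real deployments use their own host, dataset API key, and dataset ID.
BASE_URL = "https://api.dify.example/v1"
HEADERS = {"Authorization": "Bearer dataset-xxxxxxxx", "Content-Type": "application/json"}
dataset_id = "00000000-0000-0000-0000-000000000000"

# Update the dataset name and description; the new validator rejects descriptions over 400 characters.
resp = requests.patch(
    f"{BASE_URL}/datasets/{dataset_id}",
    headers=HEADERS,
    json={"name": "Product FAQ", "description": "Knowledge base for product FAQs."},
)
print(resp.status_code, resp.json())
```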
@@ -341,7 +341,7 @@ class DocumentListApi(DatasetApiResource):
search = f"%{search}%"
query = query.filter(Document.name.like(search))

query = query.order_by(desc(Document.created_at))
query = query.order_by(desc(Document.created_at), desc(Document.position))

paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
@@ -0,0 +1,21 @@
from flask_login import current_user  # type: ignore
from flask_restful import Resource  # type: ignore

from controllers.service_api import api
from controllers.service_api.wraps import validate_dataset_token
from core.model_runtime.utils.encoders import jsonable_encoder
from services.model_provider_service import ModelProviderService


class ModelProviderAvailableModelApi(Resource):
@validate_dataset_token
def get(self, _, model_type):
tenant_id = current_user.current_tenant_id

model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)

return jsonable_encoder({"data": models})


api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
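A small usage sketch for the new model-listing endpoint registered above; the host and dataset API key are assumptions used only for illustration.

```python
import requests

# Placeholder assumptions: substitute a real host and dataset API key.
BASE_URL = "https://api.dify.example/v1"
HEADERS = {"Authorization": "Bearer dataset-xxxxxxxx"}

# List the models of a given type available to the workspace bound to the token.
resp = requests.get(f"{BASE_URL}/workspaces/current/models/model-types/text-embedding", headers=HEADERS)
print(resp.json()["data"])
```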
@@ -59,6 +59,27 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("The workspace's status is archived.")

tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == api_token.tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role.in_(["owner"]))
.filter(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
)  # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = Account.query.filter_by(id=ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account)  # type: ignore
user_logged_in.send(current_app._get_current_object(), user=_get_user())  # type: ignore
else:
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")

kwargs["app_model"] = app_model

if fetch_user_arg:
@@ -19,6 +19,8 @@ class PassportResource(Resource):
def get(self):
system_features = FeatureService.get_system_features()
app_code = request.headers.get("X-App-Code")
user_id = request.args.get("user_id")

if app_code is None:
raise Unauthorized("X-App-Code header is missing.")

@@ -36,16 +38,33 @@ class PassportResource(Resource):
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()

end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="browser",
is_anonymous=True,
session_id=generate_session_id(),
)
if user_id:
end_user = (
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
)

db.session.add(end_user)
db.session.commit()
if end_user:
pass
else:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="browser",
is_anonymous=True,
session_id=user_id,
)
db.session.add(end_user)
db.session.commit()
else:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="browser",
is_anonymous=True,
session_id=generate_session_id(),
)
db.session.add(end_user)
db.session.commit()

payload = {
"iss": site.app_id,
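To illustrate the new optional user_id handling in the passport endpoint above, a hedged client sketch. The host, app code, and user_id value are placeholders, and the passport route is assumed to be exposed under the web API prefix; check your deployment for the exact path.

```python
import requests

# Placeholder assumptions: host, app code, and session identifier are illustrative only.
resp = requests.get(
    "https://dify.example/api/passport",
    headers={"X-App-Code": "your-webapp-code"},
    params={"user_id": "external-user-42"},  # reuses the EndUser with this session_id if it already exists
)
print(resp.json())  # expected to contain the access token for the web app session
```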
@@ -12,39 +12,45 @@ class CotAgentOutputParser:
def handle_react_stream_output(
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def parse_action(json_str):
try:
action = json.loads(json_str, strict=False)
action_name = None
action_input = None
def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]:
action_name = None
action_input = None
if isinstance(action, str):
try:
action = json.loads(action, strict=False)
except json.JSONDecodeError:
return action or ""

# cohere always returns a list
if isinstance(action, list) and len(action) == 1:
action = action[0]
# cohere always returns a list
if isinstance(action, list) and len(action) == 1:
action = action[0]

for key, value in action.items():
if "input" in key.lower():
action_input = value
else:
action_name = value

if action_name is not None and action_input is not None:
return AgentScratchpadUnit.Action(
action_name=action_name,
action_input=action_input,
)
for key, value in action.items():
if "input" in key.lower():
action_input = value
else:
return json_str or ""
except:
return json_str or ""
action_name = value

def extra_json_from_code_block(code_block) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
if not code_blocks:
return
for block in code_blocks:
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
yield parse_action(json_text)
if action_name is not None and action_input is not None:
return AgentScratchpadUnit.Action(
action_name=action_name,
action_input=action_input,
)
else:
return json.dumps(action)

def extra_json_from_code_block(code_block) -> list[Union[list, dict]]:
blocks = re.findall(r"```[json]*\s*([\[{].*[]}])\s*```", code_block, re.DOTALL | re.IGNORECASE)
if not blocks:
return []
try:
json_blocks = []
for block in blocks:
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
json_blocks.append(json.loads(json_text, strict=False))
return json_blocks
except:
return []

code_block_cache = ""
code_block_delimiter_count = 0
@@ -78,7 +84,7 @@ class CotAgentOutputParser:
delta = response_content[index : index + steps]
yield_delta = False

if delta == "`":
if not in_json and delta == "`":
last_character = delta
code_block_cache += delta
code_block_delimiter_count += 1
@@ -159,8 +165,14 @@ class CotAgentOutputParser:
if code_block_delimiter_count == 3:
if in_code_block:
last_character = delta
yield from extra_json_from_code_block(code_block_cache)
code_block_cache = ""
action_json_list = extra_json_from_code_block(code_block_cache)
if action_json_list:
for action_json in action_json_list:
yield parse_action(action_json)
code_block_cache = ""
else:
index += steps
continue

in_code_block = not in_code_block
code_block_delimiter_count = 0
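A standalone sketch (not part of the commit) of the parsing behaviour the refactored parse_action above aims for: it accepts a raw JSON string, a single-element list as returned by Cohere, or an already-decoded dict, and falls back to returning text when no tool name/input pair can be extracted. A plain dict stands in for AgentScratchpadUnit.Action.

```python
import json
from typing import Union


def parse_action_sketch(action) -> Union[str, dict]:
    """Illustrative stand-in for CotAgentOutputParser.parse_action."""
    action_name = None
    action_input = None
    if isinstance(action, str):
        try:
            action = json.loads(action, strict=False)
        except json.JSONDecodeError:
            return action or ""
    # Cohere-style responses wrap the action in a single-element list.
    if isinstance(action, list) and len(action) == 1:
        action = action[0]
    for key, value in action.items():
        if "input" in key.lower():
            action_input = value
        else:
            action_name = value
    if action_name is not None and action_input is not None:
        return {"action_name": action_name, "action_input": action_input}
    return json.dumps(action)


print(parse_action_sketch('{"tool": "search", "tool_input": {"query": "dify"}}'))
print(parse_action_sketch([{"tool": "search", "tool_input": "dify"}]))
print(parse_action_sketch("not json at all"))
```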
@@ -70,11 +70,20 @@ class AgentStrategyIdentity(ToolIdentity):
pass


class AgentFeature(enum.StrEnum):
"""
Agent Feature, used to describe the features of the agent strategy.
"""

HISTORY_MESSAGES = "history-messages"


class AgentStrategyEntity(BaseModel):
identity: AgentStrategyIdentity
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The description of the agent strategy")
output_schema: Optional[dict] = None
features: Optional[list[AgentFeature]] = None

# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@@ -146,6 +146,7 @@ class BasicProviderConfig(BaseModel):
BOOLEAN = CommonParameterType.BOOLEAN.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value

@classmethod
def value_of(cls, value: str) -> "ProviderConfig.Type":
@@ -213,9 +213,24 @@ class LangFuseDataTrace(BaseTraceInstance):

if process_data and process_data.get("model_mode") == "chat":
total_token = metadata.get("total_tokens", 0)
prompt_tokens = 0
completion_tokens = 0
try:
if outputs.get("usage"):
prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
else:
prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)

# add generation
generation_usage = GenerationUsage(
input=prompt_tokens,
output=completion_tokens,
total=total_token,
unit=UnitEnum.TOKENS,
)

node_generation_data = LangfuseGeneration(
@@ -199,6 +199,7 @@ class LangSmithDataTrace(BaseTraceInstance):
)

process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}

if process_data and process_data.get("model_mode") == "chat":
run_type = LangSmithRunType.llm
metadata.update(
@@ -212,9 +213,23 @@ class LangSmithDataTrace(BaseTraceInstance):
else:
run_type = LangSmithRunType.tool

prompt_tokens = 0
completion_tokens = 0
try:
if outputs.get("usage"):
prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
else:
prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)

node_dotted_order = generate_dotted_order(node_execution_id, created_at, workflow_dotted_order)
langsmith_run = LangSmithRunModel(
total_tokens=node_total_tokens,
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
name=node_type,
inputs=inputs,
run_type=run_type,
@@ -27,9 +27,26 @@ class CleanProcessor:
pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
text = re.sub(pattern, "", text)

# Remove URL
pattern = r"https?://[^\s]+"
text = re.sub(pattern, "", text)
# Remove URL but keep Markdown image URLs
# First, temporarily replace Markdown image URLs with a placeholder
markdown_image_pattern = r"!\[.*?\]\((https?://[^\s)]+)\)"
placeholders: list[str] = []

def replace_with_placeholder(match, placeholders=placeholders):
url = match.group(1)
placeholder = f"__MARKDOWN_IMAGE_URL_{len(placeholders)}__"
placeholders.append(url)
return f"![image]({placeholder})"

text = re.sub(markdown_image_pattern, replace_with_placeholder, text)

# Now remove all remaining URLs
url_pattern = r"https?://[^\s)]+"
text = re.sub(url_pattern, "", text)

# Finally, restore the Markdown image URLs
for i, url in enumerate(placeholders):
text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url)
return text

def filter_string(self, text):
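The placeholder trick used above can be demonstrated in isolation. The snippet below is an illustrative sketch with its own sample text and helper names, not the CleanProcessor code itself.

```python
import re

text = "See ![diagram](https://example.com/a.png) and read https://example.com/docs for details."

# Protect Markdown image URLs with placeholders before stripping bare URLs.
placeholders: list[str] = []


def protect(match):
    placeholders.append(match.group(1))
    return f"![image](__MARKDOWN_IMAGE_URL_{len(placeholders) - 1}__)"


text = re.sub(r"!\[.*?\]\((https?://[^\s)]+)\)", protect, text)
text = re.sub(r"https?://[^\s)]+", "", text)  # remove remaining bare URLs

# Restore the protected image URLs.
for i, url in enumerate(placeholders):
    text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url)

print(text)  # the image link survives; the bare documentation URL is stripped
```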
@@ -1,10 +1,13 @@
import copy
import json
import logging
import time
from typing import Any, Optional

from opensearchpy import OpenSearch
from opensearchpy import OpenSearch, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
from tenacity import retry, stop_after_attempt, wait_exponential

from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -77,31 +80,74 @@ class LindormVectorStore(BaseVector):
def refresh(self):
self._client.indices.refresh(index=self._collection_name)

def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
actions = []
def add_texts(
self,
documents: list[Document],
embeddings: list[list[float]],
batch_size: int = 64,
timeout: int = 60,
**kwargs,
):
logger.info(f"Total documents to add: {len(documents)}")
uuids = self._get_uuids(documents)
for i in range(len(documents)):
action_header = {
"index": {
"_index": self.collection_name.lower(),
"_id": uuids[i],

total_docs = len(documents)
num_batches = (total_docs + batch_size - 1) // batch_size

@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
)
def _bulk_with_retry(actions):
try:
response = self._client.bulk(actions, timeout=timeout)
if response["errors"]:
error_items = [item for item in response["items"] if "error" in item["index"]]
error_msg = f"Bulk indexing had {len(error_items)} errors"
logger.exception(error_msg)
raise Exception(error_msg)
return response
except Exception:
logger.exception("Bulk indexing error")
raise

for batch_num in range(num_batches):
start_idx = batch_num * batch_size
end_idx = min((batch_num + 1) * batch_size, total_docs)

actions = []
for i in range(start_idx, end_idx):
action_header = {
"index": {
"_index": self.collection_name.lower(),
"_id": uuids[i],
}
}
}
action_values: dict[str, Any] = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
if self._routing_field is not None:
action_values[self._routing_field] = self._routing
actions.append(action_header)
actions.append(action_values)
response = self._client.bulk(actions)
if response["errors"]:
for item in response["items"]:
print(f"{item['index']['status']}: {item['index']['error']['type']}")
action_values: dict[str, Any] = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
if self._routing_field is not None:
action_values[self._routing_field] = self._routing

actions.append(action_header)
actions.append(action_values)

# logger.info(f"Processing batch {batch_num + 1}/{num_batches} (documents {start_idx + 1} to {end_idx})")

try:
_bulk_with_retry(actions)
# logger.info(f"Successfully processed batch {batch_num + 1}")
# simple latency to avoid too many requests in a short time
if batch_num < num_batches - 1:
time.sleep(0.5)

except Exception:
logger.exception(f"Failed to process batch {batch_num + 1}")
raise

def get_ids_by_metadata_field(self, key: str, value: str):
query: dict[str, Any] = {
@@ -121,19 +167,51 @@ class LindormVectorStore(BaseVector):
self.delete_by_ids(ids)

def delete_by_ids(self, ids: list[str]) -> None:
params = {}
if self._using_ugc:
params["routing"] = self._routing
"""Delete documents by their IDs in batch.

Args:
ids: List of document IDs to delete
"""
if not ids:
return

params = {"routing": self._routing} if self._using_ugc else {}

# 1. First check if collection exists
if not self._client.indices.exists(index=self._collection_name):
logger.warning(f"Collection {self._collection_name} does not exist")
return

# 2. Batch process deletions
actions = []
for id in ids:
if self._client.exists(index=self._collection_name, id=id, params=params):
params = {}
if self._using_ugc:
params["routing"] = self._routing
self._client.delete(index=self._collection_name, id=id, params=params)
self.refresh()
actions.append(
{
"_op_type": "delete",
"_index": self._collection_name,
"_id": id,
**params,  # Include routing if using UGC
}
)
else:
logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")

# 3. Perform bulk deletion if there are valid documents to delete
if actions:
try:
helpers.bulk(self._client, actions)
except BulkIndexError as e:
for error in e.errors:
delete_error = error.get("delete", {})
status = delete_error.get("status")
doc_id = delete_error.get("_id")

if status == 404:
logger.warning(f"Document not found for deletion: {doc_id}")
else:
logger.exception(f"Error deleting document: {error}")

def delete(self) -> None:
if self._using_ugc:
routing_filter_query = {
@@ -169,7 +247,7 @@ class LindormVectorStore(BaseVector):
document_ids_filter = kwargs.get("document_ids_filter")
filters = []
if document_ids_filter:
filters.append({"terms": {"metadata.document_id": document_ids_filter}})
filters.append({"terms": {"metadata.document_id.keyword": document_ids_filter}})
query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs)

try:
@@ -212,7 +290,7 @@ class LindormVectorStore(BaseVector):
filters = kwargs.get("filter", [])
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
filters.append({"terms": {"metadata.document_id": document_ids_filter}})
filters.append({"terms": {"metadata.document_id.keyword": document_ids_filter}})
routing = self._routing
full_text_query = default_text_search_query(
query_text=query,
@@ -226,6 +304,7 @@ class LindormVectorStore(BaseVector):
routing=routing,
routing_field=self._routing_field,
)

response = self._client.search(index=self._collection_name, body=full_text_query)
docs = []
for hit in response["hits"]["hits"]:
@@ -435,7 +514,7 @@ def default_vector_search_query(
**kwargs,
) -> dict:
if filters is not None:
filter_type = "post_filter" if filter_type is None else filter_type
filter_type = "pre_filter" if filter_type is None else filter_type
if not isinstance(filters, list):
raise RuntimeError(f"unexpected filter with {type(filters)}")
final_ext: dict[str, Any] = {"lvector": {}}
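The batching-plus-retry pattern introduced in add_texts above can be summarised with a generic sketch; the send_batch function and the fixed batch size are assumptions used only for illustration of the design choice (bounded bulk requests, exponential backoff on failure).

```python
from tenacity import retry, stop_after_attempt, wait_exponential

BATCH_SIZE = 64  # assumption for illustration; the commit exposes batch_size as a parameter defaulting to 64


@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def send_batch(batch):
    # Hypothetical transport call; a real implementation would issue the bulk request here
    # and raise on partial failures so that tenacity retries the whole batch.
    print(f"sending {len(batch)} items")


def index_in_batches(items):
    for start in range(0, len(items), BATCH_SIZE):
        send_batch(items[start : start + BATCH_SIZE])


index_in_batches(list(range(200)))  # 200 items -> batches of 64, 64, 64, 8
```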
@@ -1,12 +1,14 @@
import json
import logging
import math
from typing import Any, Optional

from pydantic import BaseModel
from tcvdb_text.encoder import BM25Encoder  # type: ignore
from tcvectordb import RPCVectorDBClient, VectorDBException  # type: ignore
from tcvectordb.model import document, enum  # type: ignore
from tcvectordb.model import index as vdb_index  # type: ignore
from tcvectordb.model.document import Filter  # type: ignore
from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, WeightedRerank  # type: ignore

from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -17,6 +19,8 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

logger = logging.getLogger(__name__)


class TencentConfig(BaseModel):
url: str
@@ -25,10 +29,11 @@ class TencentConfig(BaseModel):
username: Optional[str]
database: Optional[str]
index_type: str = "HNSW"
metric_type: str = "L2"
metric_type: str = "IP"
shard: int = 1
replicas: int = 2
max_upsert_batch_size: int = 128
enable_hybrid_search: bool = False  # Flag to enable hybrid search

def to_tencent_params(self):
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
@@ -44,6 +49,29 @@ class TencentVector(BaseVector):
super().__init__(collection_name)
self._client_config = config
self._client = RPCVectorDBClient(**self._client_config.to_tencent_params())
self._enable_hybrid_search = False
self._dimension = 1024
self._load_collection()
self._bm25 = BM25Encoder.default("zh")

def _load_collection(self):
"""
Check if the collection supports hybrid search.
"""
if self._client_config.enable_hybrid_search:
self._enable_hybrid_search = True
if self._has_collection():
coll = self._client.describe_collection(
database_name=self._client_config.database, collection_name=self.collection_name
)
has_hybrid_search = False
for idx in coll.indexes:
if idx.name == "sparse_vector":
has_hybrid_search = True
elif idx.name == "vector":
self._dimension = idx.dimension
if not has_hybrid_search:
self._enable_hybrid_search = False

def _init_database(self):
return self._client.create_database_if_not_exists(database_name=self._client_config.database)
@@ -62,6 +90,7 @@ class TencentVector(BaseVector):
)

def _create_collection(self, dimension: int) -> None:
self._dimension = dimension
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
@@ -84,18 +113,25 @@ class TencentVector(BaseVector):
if metric_type is None:
raise ValueError("unsupported metric_type")
params = vdb_index.HNSWParams(m=16, efconstruction=200)
index = vdb_index.Index(
vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
vdb_index.VectorIndex(
self.field_vector,
dimension,
index_type,
metric_type,
params,
),
vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER),
vdb_index.FilterIndex(self.field_metadata, enum.FieldType.Json, enum.IndexType.FILTER),
index_id = vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY)
index_vector = vdb_index.VectorIndex(
self.field_vector,
dimension,
index_type,
metric_type,
params,
)
index_text = vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER)
index_metadate = vdb_index.FilterIndex(self.field_metadata, enum.FieldType.Json, enum.IndexType.FILTER)
index_sparse_vector = vdb_index.SparseIndex(
name="sparse_vector",
field_type=enum.FieldType.SparseVector,
index_type=enum.IndexType.SPARSE_INVERTED,
metric_type=enum.MetricType.IP,
)
indexes = [index_id, index_vector, index_text, index_metadate]
if self._enable_hybrid_search:
indexes.append(index_sparse_vector)
try:
self._client.create_collection(
database_name=self._client_config.database,
@@ -103,31 +139,25 @@ class TencentVector(BaseVector):
shard=self._client_config.shard,
replicas=self._client_config.replicas,
description="Collection for Dify",
index=index,
indexes=indexes,
)
except VectorDBException as e:
if "fieldType:json" not in e.message:
raise e
# vdb version not support json, use string
index = vdb_index.Index(
vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
vdb_index.VectorIndex(
self.field_vector,
dimension,
index_type,
metric_type,
params,
),
vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER),
vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER),
index_metadate = vdb_index.FilterIndex(
self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
)
indexes = [index_id, index_vector, index_text, index_metadate]
if self._enable_hybrid_search:
indexes.append(index_sparse_vector)
|
||||
self._client.create_collection(
|
||||
database_name=self._client_config.database,
|
||||
collection_name=self._collection_name,
|
||||
shard=self._client_config.shard,
|
||||
replicas=self._client_config.replicas,
|
||||
description="Collection for Dify",
|
||||
index=index,
|
||||
indexes=indexes,
|
||||
)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
|
@ -155,6 +185,8 @@ class TencentVector(BaseVector):
|
|||
text=texts[i],
|
||||
metadata=metadata,
|
||||
)
|
||||
if self._enable_hybrid_search:
|
||||
doc.__dict__["sparse_vector"] = self._bm25.encode_texts(texts[i])
|
||||
docs.append(doc)
|
||||
self._client.upsert(
|
||||
database_name=self._client_config.database,
|
||||
|
@@ -204,7 +236,32 @@ class TencentVector(BaseVector):
return self._get_search_res(res, score_threshold)

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
- return []
+ if not self._enable_hybrid_search:
+ return []
+ res = self._client.hybrid_search(
+ database_name=self._client_config.database,
+ collection_name=self.collection_name,
+ ann=[
+ AnnSearch(
+ field_name="vector",
+ data=[0.0] * self._dimension,
+ )
+ ],
+ match=[
+ KeywordSearch(
+ field_name="sparse_vector",
+ data=self._bm25.encode_queries(query),
+ ),
+ ],
+ rerank=WeightedRerank(
+ field_list=["vector", "sparse_vector"],
+ weight=[0, 1],
+ ),
+ retrieve_vector=False,
+ limit=kwargs.get("top_k", 4),
+ )
+ score_threshold = float(kwargs.get("score_threshold") or 0.0)
+ return self._get_search_res(res, score_threshold)

def _get_search_res(self, res: list | None, score_threshold: float) -> list[Document]:
docs: list[Document] = []

@@ -213,7 +270,7 @@ class TencentVector(BaseVector):

for result in res[0]:
meta = result.get(self.field_metadata)
- score = 1 - result.get("score", 0.0)
+ score = result.get("score", 0.0)
if score > score_threshold:
meta["score"] = score
doc = Document(page_content=result.get(self.field_text), metadata=meta)

@@ -245,5 +302,6 @@ class TencentVectorFactory(AbstractVectorFactory):
database=dify_config.TENCENT_VECTOR_DB_DATABASE,
shard=dify_config.TENCENT_VECTOR_DB_SHARD,
replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS,
+ enable_hybrid_search=dify_config.TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH or False,
),
)

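For reference, a minimal sketch of how the new enable_hybrid_search flag wired up above would be exercised. This is illustrative only and not part of the commit; the import path, constructor arguments and connection values are assumptions.

# Illustrative sketch (not part of the diff). Connection values are placeholders.
from core.rag.datasource.vdb.tencent.tencent_vector import TencentConfig, TencentVector

config = TencentConfig(
    url="http://127.0.0.1",     # placeholder Tencent VectorDB endpoint
    api_key="your-api-key",     # placeholder credential
    timeout=30,
    username="dify",
    database="dify",
    shard=1,
    replicas=2,
    enable_hybrid_search=True,  # mirrors TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH
)
vector = TencentVector("vector_index_demo_node", config)

# With the flag on and a sparse_vector index present on the collection, full-text
# queries go through hybrid_search() with BM25-encoded terms; with it off,
# search_by_full_text() keeps the previous behaviour and returns [].
docs = vector.search_by_full_text("tencent cloud vector database", top_k=4, score_threshold=0.5)
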
@@ -18,6 +18,7 @@ from core.rag.extractor.markdown_extractor import MarkdownExtractor
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.extractor.pdf_extractor import PdfExtractor
from core.rag.extractor.text_extractor import TextExtractor
+ from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor
from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor

@@ -25,6 +26,7 @@ from core.rag.extractor.unstructured.unstructured_msg_extractor import Unstructu
from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor
+ from core.rag.extractor.watercrawl.extractor import WaterCrawlWebExtractor
from core.rag.extractor.word_extractor import WordExtractor
from core.rag.models.document import Document
from extensions.ext_storage import storage

@@ -104,7 +106,7 @@ class ExtractProcessor:
etl_type = dify_config.ETL_TYPE
extractor: Optional[BaseExtractor] = None
if etl_type == "Unstructured":
- unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
+ unstructured_api_url = dify_config.UNSTRUCTURED_API_URL or ""
unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY or ""

if file_extension in {".xlsx", ".xls"}:

@@ -121,6 +123,8 @@ class ExtractProcessor:
extractor = HtmlExtractor(file_path)
elif file_extension == ".docx":
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
+ elif file_extension == ".doc":
+ extractor = UnstructuredWordExtractor(file_path, unstructured_api_url, unstructured_api_key)
elif file_extension == ".csv":
extractor = CSVExtractor(file_path, autodetect_encoding=True)
elif file_extension == ".msg":

@@ -180,6 +184,15 @@ class ExtractProcessor:
only_main_content=extract_setting.website_info.only_main_content,
)
return extractor.extract()
+ elif extract_setting.website_info.provider == "watercrawl":
+ extractor = WaterCrawlWebExtractor(
+ url=extract_setting.website_info.url,
+ job_id=extract_setting.website_info.job_id,
+ tenant_id=extract_setting.website_info.tenant_id,
+ mode=extract_setting.website_info.mode,
+ only_main_content=extract_setting.website_info.only_main_content,
+ )
+ return extractor.extract()
elif extract_setting.website_info.provider == "jinareader":
extractor = JinaReaderWebExtractor(
url=extract_setting.website_info.url,

@@ -10,14 +10,11 @@ logger = logging.getLogger(__name__)
class UnstructuredWordExtractor(BaseExtractor):
"""Loader that uses unstructured to load word documents."""

- def __init__(
- self,
- file_path: str,
- api_url: str,
- ):
+ def __init__(self, file_path: str, api_url: str, api_key: str = ""):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
+ self._api_key = api_key

def extract(self) -> list[Document]:
from unstructured.__version__ import __version__ as __unstructured_version__

@@ -41,9 +38,10 @@ class UnstructuredWordExtractor(BaseExtractor):
)

if is_doc:
- from unstructured.partition.doc import partition_doc
+ from unstructured.partition.api import partition_via_api

+ elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
- elements = partition_doc(filename=self._file_path)
else:
from unstructured.partition.docx import partition_docx

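With this change, legacy .doc files are no longer partitioned locally; they are handed to an Unstructured API server. A rough sketch of what the new code path amounts to (illustrative only; the endpoint, key and file path are placeholders):

# Illustrative sketch, not part of the diff.
from unstructured.partition.api import partition_via_api

elements = partition_via_api(
    filename="sample.doc",                                   # placeholder path
    api_url="http://unstructured:8000/general/v0/general",   # placeholder UNSTRUCTURED_API_URL
    api_key="",                                              # empty when the server requires no key
)
text = "\n\n".join(str(element) for element in elements)
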
@ -0,0 +1,161 @@
|
|||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
from requests import Response
|
||||
|
||||
|
||||
class BaseAPIClient:
|
||||
def __init__(self, api_key, base_url):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.session = self.init_session()
|
||||
|
||||
def init_session(self):
|
||||
session = requests.Session()
|
||||
session.headers.update({"X-API-Key": self.api_key})
|
||||
session.headers.update({"Content-Type": "application/json"})
|
||||
session.headers.update({"Accept": "application/json"})
|
||||
session.headers.update({"User-Agent": "WaterCrawl-Plugin"})
|
||||
session.headers.update({"Accept-Language": "en-US"})
|
||||
return session
|
||||
|
||||
def _get(self, endpoint: str, query_params: dict | None = None, **kwargs):
|
||||
return self.session.get(urljoin(self.base_url, endpoint), params=query_params, **kwargs)
|
||||
|
||||
def _post(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
|
||||
return self.session.post(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs)
|
||||
|
||||
def _put(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
|
||||
return self.session.put(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs)
|
||||
|
||||
def _delete(self, endpoint: str, query_params: dict | None = None, **kwargs):
|
||||
return self.session.delete(urljoin(self.base_url, endpoint), params=query_params, **kwargs)
|
||||
|
||||
def _patch(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
|
||||
return self.session.patch(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs)
|
||||
|
||||
|
||||
class WaterCrawlAPIClient(BaseAPIClient):
|
||||
def __init__(self, api_key, base_url: str | None = "https://app.watercrawl.dev/"):
|
||||
super().__init__(api_key, base_url)
|
||||
|
||||
def process_eventstream(self, response: Response, download: bool = False) -> Generator:
|
||||
for line in response.iter_lines():
|
||||
line = line.decode("utf-8")
|
||||
if line.startswith("data:"):
|
||||
line = line[5:].strip()
|
||||
data = json.loads(line)
|
||||
if data["type"] == "result" and download:
|
||||
data["data"] = self.download_result(data["data"])
|
||||
yield data
|
||||
|
||||
def process_response(self, response: Response) -> dict | bytes | list | None | Generator:
|
||||
response.raise_for_status()
|
||||
if response.status_code == 204:
|
||||
return None
|
||||
if response.headers.get("Content-Type") == "application/json":
|
||||
return response.json() or {}
|
||||
|
||||
if response.headers.get("Content-Type") == "application/octet-stream":
|
||||
return response.content
|
||||
|
||||
if response.headers.get("Content-Type") == "text/event-stream":
|
||||
return self.process_eventstream(response)
|
||||
|
||||
raise Exception(f"Unknown response type: {response.headers.get('Content-Type')}")
|
||||
|
||||
def get_crawl_requests_list(self, page: int | None = None, page_size: int | None = None):
|
||||
query_params = {"page": page or 1, "page_size": page_size or 10}
|
||||
return self.process_response(
|
||||
self._get(
|
||||
"/api/v1/core/crawl-requests/",
|
||||
query_params=query_params,
|
||||
)
|
||||
)
|
||||
|
||||
def get_crawl_request(self, item_id: str):
|
||||
return self.process_response(
|
||||
self._get(
|
||||
f"/api/v1/core/crawl-requests/{item_id}/",
|
||||
)
|
||||
)
|
||||
|
||||
def create_crawl_request(
|
||||
self,
|
||||
url: Union[list, str] | None = None,
|
||||
spider_options: dict | None = None,
|
||||
page_options: dict | None = None,
|
||||
plugin_options: dict | None = None,
|
||||
):
|
||||
data = {
|
||||
# 'urls': url if isinstance(url, list) else [url],
|
||||
"url": url,
|
||||
"options": {
|
||||
"spider_options": spider_options or {},
|
||||
"page_options": page_options or {},
|
||||
"plugin_options": plugin_options or {},
|
||||
},
|
||||
}
|
||||
return self.process_response(
|
||||
self._post(
|
||||
"/api/v1/core/crawl-requests/",
|
||||
data=data,
|
||||
)
|
||||
)
|
||||
|
||||
def stop_crawl_request(self, item_id: str):
|
||||
return self.process_response(
|
||||
self._delete(
|
||||
f"/api/v1/core/crawl-requests/{item_id}/",
|
||||
)
|
||||
)
|
||||
|
||||
def download_crawl_request(self, item_id: str):
|
||||
return self.process_response(
|
||||
self._get(
|
||||
f"/api/v1/core/crawl-requests/{item_id}/download/",
|
||||
)
|
||||
)
|
||||
|
||||
def monitor_crawl_request(self, item_id: str, prefetched=False) -> Generator:
|
||||
query_params = {"prefetched": str(prefetched).lower()}
|
||||
generator = self.process_response(
|
||||
self._get(f"/api/v1/core/crawl-requests/{item_id}/status/", stream=True, query_params=query_params),
|
||||
)
|
||||
if not isinstance(generator, Generator):
|
||||
raise ValueError("Generator expected")
|
||||
yield from generator
|
||||
|
||||
def get_crawl_request_results(
|
||||
self, item_id: str, page: int = 1, page_size: int = 25, query_params: dict | None = None
|
||||
):
|
||||
query_params = query_params or {}
|
||||
query_params.update({"page": page or 1, "page_size": page_size or 25})
|
||||
return self.process_response(
|
||||
self._get(f"/api/v1/core/crawl-requests/{item_id}/results/", query_params=query_params)
|
||||
)
|
||||
|
||||
def scrape_url(
|
||||
self,
|
||||
url: str,
|
||||
page_options: dict | None = None,
|
||||
plugin_options: dict | None = None,
|
||||
sync: bool = True,
|
||||
prefetched: bool = True,
|
||||
):
|
||||
response_result = self.create_crawl_request(url=url, page_options=page_options, plugin_options=plugin_options)
|
||||
if not sync:
|
||||
return response_result
|
||||
|
||||
for event_data in self.monitor_crawl_request(response_result["uuid"], prefetched):
|
||||
if event_data["type"] == "result":
|
||||
return event_data["data"]
|
||||
|
||||
def download_result(self, result_object: dict):
|
||||
response = requests.get(result_object["result"])
|
||||
response.raise_for_status()
|
||||
result_object["result"] = response.json()
|
||||
return result_object
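For reference, a short usage sketch of the WaterCrawlAPIClient defined above (illustrative only, not part of the commit; the API key is a placeholder):

# Illustrative sketch, not part of the diff.
from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient

client = WaterCrawlAPIClient(api_key="wc-...", base_url="https://app.watercrawl.dev/")

# Synchronous single-page scrape: create_crawl_request() plus monitor_crawl_request()
# under the hood, returning the prefetched result payload once a "result" event arrives.
page = client.scrape_url(url="https://example.com", page_options={"only_main_content": True})

# Asynchronous variant: returns the created crawl request immediately; its "uuid" can
# later be polled with get_crawl_request() / get_crawl_request_results().
request = client.scrape_url(url="https://example.com", sync=False)
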
@@ -0,0 +1,57 @@
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from services.website_service import WebsiteService


class WaterCrawlWebExtractor(BaseExtractor):
"""
Crawl and scrape websites and return content in clean llm-ready markdown.


Args:
url: The URL to scrape.
api_key: The API key for WaterCrawl.
base_url: The base URL for the Firecrawl API. Defaults to 'https://app.firecrawl.dev'.
mode: The mode of operation. Defaults to 'scrape'. Options are 'crawl', 'scrape' and 'crawl_return_urls'.
only_main_content: Only return the main content of the page excluding headers, navs, footers, etc.
"""

def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True):
"""Initialize with url, api_key, base_url and mode."""
self._url = url
self.job_id = job_id
self.tenant_id = tenant_id
self.mode = mode
self.only_main_content = only_main_content

def extract(self) -> list[Document]:
"""Extract content from the URL."""
documents = []
if self.mode == "crawl":
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "watercrawl", self._url, self.tenant_id)
if crawl_data is None:
return []
document = Document(
page_content=crawl_data.get("markdown", ""),
metadata={
"source_url": crawl_data.get("source_url"),
"description": crawl_data.get("description"),
"title": crawl_data.get("title"),
},
)
documents.append(document)
elif self.mode == "scrape":
scrape_data = WebsiteService.get_scrape_url_data(
"watercrawl", self._url, self.tenant_id, self.only_main_content
)

document = Document(
page_content=scrape_data.get("markdown", ""),
metadata={
"source_url": scrape_data.get("source_url"),
"description": scrape_data.get("description"),
"title": scrape_data.get("title"),
},
)
documents.append(document)
return documents

@ -0,0 +1,117 @@
|
|||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient
|
||||
|
||||
|
||||
class WaterCrawlProvider:
|
||||
def __init__(self, api_key, base_url: str | None = None):
|
||||
self.client = WaterCrawlAPIClient(api_key, base_url)
|
||||
|
||||
def crawl_url(self, url, options: dict | Any = None) -> dict:
|
||||
options = options or {}
|
||||
spider_options = {
|
||||
"max_depth": 1,
|
||||
"page_limit": 1,
|
||||
"allowed_domains": [],
|
||||
"exclude_paths": [],
|
||||
"include_paths": [],
|
||||
}
|
||||
if options.get("crawl_sub_pages", True):
|
||||
spider_options["page_limit"] = options.get("limit", 1)
|
||||
spider_options["max_depth"] = options.get("depth", 1)
|
||||
spider_options["include_paths"] = options.get("includes", "").split(",") if options.get("includes") else []
|
||||
spider_options["exclude_paths"] = options.get("excludes", "").split(",") if options.get("excludes") else []
|
||||
|
||||
wait_time = options.get("wait_time", 1000)
|
||||
page_options = {
|
||||
"exclude_tags": options.get("exclude_tags", "").split(",") if options.get("exclude_tags") else [],
|
||||
"include_tags": options.get("include_tags", "").split(",") if options.get("include_tags") else [],
|
||||
"wait_time": max(1000, wait_time), # minimum wait time is 1 second
|
||||
"include_html": False,
|
||||
"only_main_content": options.get("only_main_content", True),
|
||||
"include_links": False,
|
||||
"timeout": 15000,
|
||||
"accept_cookies_selector": "#cookies-accept",
|
||||
"locale": "en-US",
|
||||
"actions": [],
|
||||
}
|
||||
result = self.client.create_crawl_request(url=url, spider_options=spider_options, page_options=page_options)
|
||||
|
||||
return {"status": "active", "job_id": result.get("uuid")}
|
||||
|
||||
def get_crawl_status(self, crawl_request_id) -> dict:
|
||||
response = self.client.get_crawl_request(crawl_request_id)
|
||||
data = []
|
||||
if response["status"] in ["new", "running"]:
|
||||
status = "active"
|
||||
else:
|
||||
status = "completed"
|
||||
data = list(self._get_results(crawl_request_id))
|
||||
|
||||
time_str = response.get("duration")
|
||||
time_consuming: float = 0
|
||||
if time_str:
|
||||
time_obj = datetime.strptime(time_str, "%H:%M:%S.%f")
|
||||
time_consuming = (
|
||||
time_obj.hour * 3600 + time_obj.minute * 60 + time_obj.second + time_obj.microsecond / 1_000_000
|
||||
)
|
||||
|
||||
return {
|
||||
"status": status,
|
||||
"job_id": response.get("uuid"),
|
||||
"total": response.get("options", {}).get("spider_options", {}).get("page_limit", 1),
|
||||
"current": response.get("number_of_documents", 0),
|
||||
"data": data,
|
||||
"time_consuming": time_consuming,
|
||||
}
|
||||
|
||||
def get_crawl_url_data(self, job_id, url) -> dict | None:
|
||||
if not job_id:
|
||||
return self.scrape_url(url)
|
||||
|
||||
for result in self._get_results(
|
||||
job_id,
|
||||
{
|
||||
# filter by url
|
||||
"url": url
|
||||
},
|
||||
):
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
def scrape_url(self, url: str) -> dict:
|
||||
response = self.client.scrape_url(url=url, sync=True, prefetched=True)
|
||||
return self._structure_data(response)
|
||||
|
||||
def _structure_data(self, result_object: dict) -> dict:
|
||||
if isinstance(result_object.get("result", {}), str):
|
||||
raise ValueError("Invalid result object. Expected a dictionary.")
|
||||
|
||||
metadata = result_object.get("result", {}).get("metadata", {})
|
||||
return {
|
||||
"title": metadata.get("og:title") or metadata.get("title"),
|
||||
"description": metadata.get("description"),
|
||||
"source_url": result_object.get("url"),
|
||||
"markdown": result_object.get("result", {}).get("markdown"),
|
||||
}
|
||||
|
||||
def _get_results(self, crawl_request_id: str, query_params: dict | None = None) -> Generator[dict, None, None]:
|
||||
page = 0
|
||||
page_size = 100
|
||||
|
||||
query_params = query_params or {}
|
||||
query_params.update({"prefetched": "true"})
|
||||
while True:
|
||||
page += 1
|
||||
response = self.client.get_crawl_request_results(crawl_request_id, page, page_size, query_params)
|
||||
if not response["results"]:
|
||||
break
|
||||
|
||||
for result in response["results"]:
|
||||
yield self._structure_data(result)
|
||||
|
||||
if response["next"] is None:
|
||||
break
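The provider above drives that client for the dataset website flow: crawl_url() submits a bounded crawl, get_crawl_status() maps WaterCrawl states onto active/completed, and _get_results() pages through results 100 at a time until "next" is null. A brief usage sketch (illustrative only; the module path and API key are assumptions):

# Illustrative sketch, not part of the diff.
from core.rag.extractor.watercrawl.provider import WaterCrawlProvider  # path assumed

provider = WaterCrawlProvider(api_key="wc-...", base_url="https://app.watercrawl.dev/")

# Kick off a crawl limited to 10 pages and depth 2; returns {"status": "active", "job_id": "<uuid>"}.
job = provider.crawl_url("https://example.com", {"crawl_sub_pages": True, "limit": 10, "depth": 2})

# Poll until it completes; "data" carries structured pages
# (title / description / source_url / markdown) once "status" is "completed".
status = provider.get_crawl_status(job["job_id"])
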
@@ -85,7 +85,7 @@ class WordExtractor(BaseExtractor):
if "image" in rel.target_ref:
image_count += 1
if rel.is_external:
- url = rel.reltype
+ url = rel.target_ref
response = ssrf_proxy.get(url)
if response.status_code == 200:
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])

@@ -30,6 +30,7 @@ class NodeRunMetadataKey(StrEnum):
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
+ LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output


class NodeRunResult(BaseModel):

@ -1,15 +1,18 @@
|
|||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from core.plugin.manager.plugin import PluginInstallationManager
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.variables.segments import StringSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
|
@ -19,7 +22,9 @@ from core.workflow.nodes.enums import NodeType
|
|||
from core.workflow.nodes.event.event import RunCompletedEvent
|
||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
from models.model import Conversation
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
|
@ -233,17 +238,20 @@ class AgentNode(ToolNode):
|
|||
value = tool_value
|
||||
if parameter.type == "model-selector":
|
||||
value = cast(dict[str, Any], value)
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=value.get("provider", ""),
|
||||
model_type=ModelType(value.get("model_type", "")),
|
||||
model=value.get("model", ""),
|
||||
)
|
||||
models = model_instance.model_type_instance.plugin_model_provider.declaration.models
|
||||
finded_model = next((model for model in models if model.model == value.get("model", "")), None)
|
||||
|
||||
value["entity"] = finded_model.model_dump(mode="json") if finded_model else None
|
||||
|
||||
model_instance, model_schema = self._fetch_model(value)
|
||||
# memory config
|
||||
history_prompt_messages = []
|
||||
if node_data.memory:
|
||||
memory = self._fetch_memory(model_instance)
|
||||
if memory:
|
||||
prompt_messages = memory.get_history_prompt_messages(
|
||||
message_limit=node_data.memory.window.size if node_data.memory.window.size else None
|
||||
)
|
||||
history_prompt_messages = [
|
||||
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
|
||||
]
|
||||
value["history_prompt_messages"] = history_prompt_messages
|
||||
value["entity"] = model_schema.model_dump(mode="json") if model_schema else None
|
||||
result[parameter_name] = value
|
||||
|
||||
return result
|
||||
|
@ -297,3 +305,46 @@ class AgentNode(ToolNode):
|
|||
except StopIteration:
|
||||
icon = None
|
||||
return icon
|
||||
|
||||
def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
|
||||
# get conversation id
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID.value]
|
||||
)
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
# get conversation
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
return memory
|
||||
|
||||
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
|
||||
)
|
||||
model_name = value.get("model", "")
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM, model=model_name
|
||||
)
|
||||
provider_name = provider_model_bundle.configuration.provider.provider
|
||||
model_type_instance = provider_model_bundle.model_type_instance
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType(value.get("model_type", "")),
|
||||
model=model_name,
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
return model_instance, model_schema
|
||||
|
|
|
@@ -3,6 +3,7 @@ from typing import Any, Literal, Union

from pydantic import BaseModel

+ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from core.workflow.nodes.base.entities import BaseNodeData


@@ -11,6 +12,7 @@ class AgentNodeData(BaseNodeData):
agent_strategy_provider_name: str # redundancy
agent_strategy_name: str
agent_strategy_label: str # redundancy
+ memory: MemoryConfig | None = None

class AgentInput(BaseModel):
value: Union[list[str], list[ToolSelector], Any]

@@ -6,6 +6,7 @@ from core.helper.code_executor.code_executor import CodeExecutionError, CodeExec
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
+ from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.code.entities import CodeNodeData

@@ -49,7 +50,10 @@ class CodeNode(BaseNode[CodeNodeData]):
for variable_selector in self.node_data.variables:
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
- variables[variable_name] = variable.to_object() if variable else None
+ if isinstance(variable, ArrayFileSegment):
+ variables[variable_name] = [v.to_dict() for v in variable.value] if variable.value else None
+ else:
+ variables[variable_name] = variable.to_object() if variable else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(

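The branch above only changes how file-array inputs are serialized: each File in an ArrayFileSegment is turned into a plain dict via to_dict() so the sandboxed code receives JSON-serializable values, while every other segment keeps the old to_object() path. A hedged sketch of that conversion in isolation (illustrative only; the helper is hypothetical):

# Illustrative sketch, not part of the diff: a standalone version of the conversion rule.
def to_code_input(segment):
    # segment is whatever the variable pool returned for one selector (or None).
    if segment is None:
        return None
    value = getattr(segment, "value", None)
    if isinstance(value, list) and all(hasattr(item, "to_dict") for item in value):
        return [item.to_dict() for item in value]  # file arrays become plain dicts
    return segment.to_object()  # all other segments keep the previous behaviour
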
@@ -17,6 +17,7 @@ class NodeType(StrEnum):
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"
LOOP_START = "loop-start"
+ LOOP_END = "loop-end"
ITERATION = "iteration"
ITERATION_START = "iteration-start" # Fake start node for iteration.
PARAMETER_EXTRACTOR = "parameter-extractor"

@@ -8,7 +8,7 @@ from core.workflow.utils.condition.entities import Condition

class IfElseNodeData(BaseNodeData):
"""
- Answer Node Data.
+ If Else Node Data.
"""

class Case(BaseModel):

@@ -1,5 +1,6 @@
from .entities import LoopNodeData
+ from .loop_end_node import LoopEndNode
from .loop_node import LoopNode
from .loop_start_node import LoopStartNode

- __all__ = ["LoopNode", "LoopNodeData", "LoopStartNode"]
+ __all__ = ["LoopEndNode", "LoopNode", "LoopNodeData", "LoopStartNode"]

@@ -1,11 +1,23 @@
from collections.abc import Mapping
from typing import Any, Literal, Optional

- from pydantic import Field
+ from pydantic import BaseModel, Field

from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
from core.workflow.utils.condition.entities import Condition


+ class LoopVariableData(BaseModel):
+ """
+ Loop Variable Data.
+ """
+
+ label: str
+ var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
+ value_type: Literal["variable", "constant"]
+ value: Optional[Any | list[str]] = None
+
+
class LoopNodeData(BaseLoopNodeData):
"""
Loop Node Data.

@@ -14,6 +26,8 @@ class LoopNodeData(BaseLoopNodeData):
loop_count: int # Maximum number of loops
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
+ loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list)
+ outputs: Optional[Mapping[str, Any]] = None


class LoopStartNodeData(BaseNodeData):

@@ -24,6 +38,14 @@ class LoopStartNodeData(BaseNodeData):
pass


+ class LoopEndNodeData(BaseNodeData):
+ """
+ Loop End Node Data.
+ """
+
+ pass
+
+
class LoopState(BaseLoopState):
"""
Loop State.

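LoopVariableData describes one per-iteration variable: a label, its declared type, and either a constant literal or a selector that is resolved from the variable pool at run time. A minimal sketch of the two flavours (illustrative only; the selector path is a made-up example):

# Illustrative sketch, not part of the diff.
from core.workflow.nodes.loop.entities import LoopVariableData

counter = LoopVariableData(label="counter", var_type="number", value_type="constant", value=0)
seed_text = LoopVariableData(
    label="seed_text",
    var_type="string",
    value_type="variable",
    value=["start_node", "query"],  # hypothetical selector into the variable pool
)
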
@@ -0,0 +1,20 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.loop.entities import LoopEndNodeData
from models.workflow import WorkflowNodeExecutionStatus


class LoopEndNode(BaseNode[LoopEndNodeData]):
"""
Loop End Node.
"""

_node_data_cls = LoopEndNodeData
_node_type = NodeType.LOOP_END

def _run(self) -> NodeRunResult:
"""
Run the node.
"""
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)

@ -1,10 +1,20 @@
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.variables import IntegerSegment
|
||||
from core.variables import (
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayStringSegment,
|
||||
IntegerSegment,
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
SegmentType,
|
||||
StringSegment,
|
||||
)
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseGraphEvent,
|
||||
|
@ -29,6 +39,10 @@ from core.workflow.nodes.loop.entities import LoopNodeData
|
|||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -61,6 +75,28 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
variable_pool.add([self.node_id, "index"], 0)
|
||||
|
||||
# Initialize loop variables
|
||||
loop_variable_selectors = {}
|
||||
if self.node_data.loop_variables:
|
||||
for loop_variable in self.node_data.loop_variables:
|
||||
value_processor = {
|
||||
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
|
||||
"variable": lambda var=loop_variable: variable_pool.get(var.value),
|
||||
}
|
||||
|
||||
if loop_variable.value_type not in value_processor:
|
||||
raise ValueError(
|
||||
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
|
||||
)
|
||||
|
||||
processed_segment = value_processor[loop_variable.value_type]()
|
||||
if not processed_segment:
|
||||
raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
|
||||
variable_selector = [self.node_id, loop_variable.label]
|
||||
variable_pool.add(variable_selector, processed_segment.value)
|
||||
loop_variable_selectors[loop_variable.label] = variable_selector
|
||||
inputs[loop_variable.label] = processed_segment.value
|
||||
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
|
@ -95,135 +131,51 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||
predecessor_node_id=self.previous_node_id,
|
||||
)
|
||||
|
||||
yield LoopRunNextEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
index=0,
|
||||
pre_loop_output=None,
|
||||
)
|
||||
|
||||
# yield LoopRunNextEvent(
|
||||
# loop_id=self.id,
|
||||
# loop_node_id=self.node_id,
|
||||
# loop_node_type=self.node_type,
|
||||
# loop_node_data=self.node_data,
|
||||
# index=0,
|
||||
# pre_loop_output=None,
|
||||
# )
|
||||
loop_duration_map = {}
|
||||
single_loop_variable_map = {} # single loop variable output
|
||||
try:
|
||||
check_break_result = False
|
||||
for i in range(loop_count):
|
||||
# Run workflow
|
||||
rst = graph_engine.run()
|
||||
current_index_variable = variable_pool.get([self.node_id, "index"])
|
||||
if not isinstance(current_index_variable, IntegerSegment):
|
||||
raise ValueError(f"loop {self.node_id} current index not found")
|
||||
current_index = current_index_variable.value
|
||||
loop_start_time = datetime.now(UTC).replace(tzinfo=None)
|
||||
# run single loop
|
||||
loop_result = yield from self._run_single_loop(
|
||||
graph_engine=graph_engine,
|
||||
loop_graph=loop_graph,
|
||||
variable_pool=variable_pool,
|
||||
loop_variable_selectors=loop_variable_selectors,
|
||||
break_conditions=break_conditions,
|
||||
logical_operator=logical_operator,
|
||||
condition_processor=condition_processor,
|
||||
current_index=i,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
)
|
||||
loop_end_time = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
check_break_result = False
|
||||
|
||||
for event in rst:
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:
|
||||
event.in_loop_id = self.node_id
|
||||
|
||||
if (
|
||||
isinstance(event, BaseNodeEvent)
|
||||
and event.node_type == NodeType.LOOP_START
|
||||
and not isinstance(event, NodeRunStreamChunkEvent)
|
||||
):
|
||||
continue
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
|
||||
|
||||
# Check if all variables in break conditions exist
|
||||
exists_variable = False
|
||||
for condition in break_conditions:
|
||||
if not self.graph_runtime_state.variable_pool.get(condition.variable_selector):
|
||||
exists_variable = False
|
||||
break
|
||||
else:
|
||||
exists_variable = True
|
||||
if exists_variable:
|
||||
input_conditions, group_result, check_break_result = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=break_conditions,
|
||||
operator=logical_operator,
|
||||
)
|
||||
if check_break_result:
|
||||
break
|
||||
|
||||
elif isinstance(event, BaseGraphEvent):
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
# Loop run failed
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=i,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
},
|
||||
error=event.error,
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
|
||||
},
|
||||
)
|
||||
)
|
||||
return
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
# Loop run failed
|
||||
yield event
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=i,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
},
|
||||
error=event.error,
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
|
||||
},
|
||||
)
|
||||
)
|
||||
return
|
||||
single_loop_variable = {}
|
||||
for key, selector in loop_variable_selectors.items():
|
||||
item = variable_pool.get(selector)
|
||||
if item:
|
||||
single_loop_variable[key] = item.value
|
||||
else:
|
||||
yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
|
||||
single_loop_variable[key] = None
|
||||
|
||||
# Remove all nodes outputs from variable pool
|
||||
for node_id in loop_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
loop_duration_map[str(i)] = (loop_end_time - loop_start_time).total_seconds()
|
||||
single_loop_variable_map[str(i)] = single_loop_variable
|
||||
|
||||
check_break_result = loop_result.get("check_break_result", False)
|
||||
|
||||
if check_break_result:
|
||||
break
|
||||
|
||||
# Move to next loop
|
||||
next_index = current_index + 1
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
|
||||
yield LoopRunNextEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
index=next_index,
|
||||
pre_loop_output=None,
|
||||
)
|
||||
|
||||
# Loop completed successfully
|
||||
yield LoopRunSucceededEvent(
|
||||
loop_id=self.id,
|
||||
|
@ -232,17 +184,26 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||
loop_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs=self.node_data.outputs,
|
||||
steps=loop_count,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "loop_break" if check_break_result else "loop_completed",
|
||||
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
outputs=self.node_data.outputs,
|
||||
inputs=inputs,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -260,6 +221,8 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||
metadata={
|
||||
"total_tokens": graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
error=str(e),
|
||||
)
|
||||
|
@ -268,7 +231,11 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -276,6 +243,159 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||
# Clean up
|
||||
variable_pool.remove([self.node_id, "index"])
|
||||
|
||||
def _run_single_loop(
|
||||
self,
|
||||
*,
|
||||
graph_engine: "GraphEngine",
|
||||
loop_graph: Graph,
|
||||
variable_pool: "VariablePool",
|
||||
loop_variable_selectors: dict,
|
||||
break_conditions: list,
|
||||
logical_operator: Literal["and", "or"],
|
||||
condition_processor: ConditionProcessor,
|
||||
current_index: int,
|
||||
start_at: datetime,
|
||||
inputs: dict,
|
||||
) -> Generator[NodeEvent | InNodeEvent, None, dict]:
|
||||
"""Run a single loop iteration.
|
||||
Returns:
|
||||
dict: {'check_break_result': bool}
|
||||
"""
|
||||
# Run workflow
|
||||
rst = graph_engine.run()
|
||||
current_index_variable = variable_pool.get([self.node_id, "index"])
|
||||
if not isinstance(current_index_variable, IntegerSegment):
|
||||
raise ValueError(f"loop {self.node_id} current index not found")
|
||||
current_index = current_index_variable.value
|
||||
|
||||
check_break_result = False
|
||||
|
||||
for event in rst:
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:
|
||||
event.in_loop_id = self.node_id
|
||||
|
||||
if (
|
||||
isinstance(event, BaseNodeEvent)
|
||||
and event.node_type == NodeType.LOOP_START
|
||||
and not isinstance(event, NodeRunStreamChunkEvent)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
isinstance(event, NodeRunSucceededEvent)
|
||||
and event.node_type == NodeType.LOOP_END
|
||||
and not isinstance(event, NodeRunStreamChunkEvent)
|
||||
):
|
||||
check_break_result = True
|
||||
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
|
||||
break
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
|
||||
|
||||
# Check if all variables in break conditions exist
|
||||
exists_variable = False
|
||||
for condition in break_conditions:
|
||||
if not self.graph_runtime_state.variable_pool.get(condition.variable_selector):
|
||||
exists_variable = False
|
||||
break
|
||||
else:
|
||||
exists_variable = True
|
||||
if exists_variable:
|
||||
input_conditions, group_result, check_break_result = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=break_conditions,
|
||||
operator=logical_operator,
|
||||
)
|
||||
if check_break_result:
|
||||
break
|
||||
|
||||
elif isinstance(event, BaseGraphEvent):
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
# Loop run failed
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=current_index,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
},
|
||||
error=event.error,
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
|
||||
)
|
||||
)
|
||||
return {"check_break_result": True}
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
# Loop run failed
|
||||
yield event
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=current_index,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
},
|
||||
error=event.error,
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
|
||||
)
|
||||
)
|
||||
return {"check_break_result": True}
|
||||
else:
|
||||
yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
|
||||
|
||||
# Remove all nodes outputs from variable pool
|
||||
for node_id in loop_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
|
||||
_outputs = {}
|
||||
for loop_variable_key, loop_variable_selector in loop_variable_selectors.items():
|
||||
_loop_variable_segment = variable_pool.get(loop_variable_selector)
|
||||
if _loop_variable_segment:
|
||||
_outputs[loop_variable_key] = _loop_variable_segment.value
|
||||
else:
|
||||
_outputs[loop_variable_key] = None
|
||||
|
||||
_outputs["loop_round"] = current_index + 1
|
||||
self.node_data.outputs = _outputs
|
||||
|
||||
if check_break_result:
|
||||
return {"check_break_result": True}
|
||||
|
||||
# Move to next loop
|
||||
next_index = current_index + 1
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
|
||||
yield LoopRunNextEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
index=next_index,
|
||||
pre_loop_output=self.node_data.outputs,
|
||||
)
|
||||
|
||||
return {"check_break_result": False}
|
||||
|
||||
def _handle_event_metadata(
|
||||
self,
|
||||
*,
|
||||
|
@@ -360,3 +480,25 @@ class LoopNode(BaseNode[LoopNodeData]):
}

return variable_mapping

+ @staticmethod
+ def _get_segment_for_constant(var_type: str, value: Any) -> Segment:
+ """Get the appropriate segment type for a constant value."""
+ segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = {
+ "string": (StringSegment, SegmentType.STRING),
+ "number": (IntegerSegment, SegmentType.NUMBER),
+ "object": (ObjectSegment, SegmentType.OBJECT),
+ "array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING),
+ "array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER),
+ "array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT),
+ }
+ if var_type in ["array[string]", "array[number]", "array[object]"]:
+ if value:
+ value = json.loads(value)
+ else:
+ value = []
+ segment_info = segment_mapping.get(var_type)
+ if not segment_info:
+ raise ValueError(f"Invalid variable type: {var_type}")
+ segment_class, value_type = segment_info
+ return segment_class(value=value, value_type=value_type)

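Constant loop variables of the array types arrive as JSON strings and are decoded with json.loads before being wrapped in a segment; empty values fall back to an empty list. A small sketch of the expected behaviour (illustrative only):

# Illustrative sketch, not part of the diff.
from core.workflow.nodes.loop import LoopNode

numbers = LoopNode._get_segment_for_constant("array[number]", "[1, 2, 3]")
# ArrayNumberSegment wrapping [1, 2, 3]

greeting = LoopNode._get_segment_for_constant("string", "hello")
# StringSegment wrapping "hello"

empty = LoopNode._get_segment_for_constant("array[object]", "")
# ArrayObjectSegment wrapping []
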
@@ -13,7 +13,7 @@ from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from core.workflow.nodes.list_operator import ListOperatorNode
from core.workflow.nodes.llm import LLMNode
- from core.workflow.nodes.loop import LoopNode, LoopStartNode
+ from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from core.workflow.nodes.question_classifier import QuestionClassifierNode
from core.workflow.nodes.start import StartNode

@@ -94,6 +94,10 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
LATEST_VERSION: LoopStartNode,
"1": LoopStartNode,
},
+ NodeType.LOOP_END: {
+ LATEST_VERSION: LoopEndNode,
+ "1": LoopEndNode,
+ },
NodeType.PARAMETER_EXTRACTOR: {
LATEST_VERSION: ParameterExtractorNode,
"1": ParameterExtractorNode,

|
@ -4,7 +4,7 @@ from core.workflow.nodes.base import BaseNodeData
|
|||
|
||||
class TemplateTransformNodeData(BaseNodeData):
|
||||
"""
|
||||
Code Node Data.
|
||||
Template Transform Node Data.
|
||||
"""
|
||||
|
||||
variables: list[VariableSelector]
|
||||
|
|
|
@@ -26,7 +26,7 @@ class AdvancedSettings(BaseModel):

class VariableAssignerNodeData(BaseNodeData):
"""
- Knowledge retrieval Node Data.
+ Variable Assigner Node Data.
"""

type: str = "variable-assigner"

@@ -2,6 +2,7 @@ import json
from collections.abc import Sequence
from typing import Any, cast

+ from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult

@@ -123,13 +124,14 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
- raise ConversationIDNotFoundError
+ if self.invoke_from != InvokeFrom.DEBUGGER:
+ raise ConversationIDNotFoundError
+ else:
+ conversation_id = conversation_id.value
- common_helpers.update_conversation_variable(
- conversation_id=cast(str, conversation_id),
- variable=variable,
- )
+ common_helpers.update_conversation_variable(
+ conversation_id=cast(str, conversation_id),
+ variable=variable,
+ )

return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,

@@ -134,8 +134,9 @@ def _build_from_local_file(
if row is None:
raise ValueError("Invalid upload file")

file_type = FileType(mapping.get("type", "custom"))
file_type = _standardize_file_type(file_type, extension="." + row.extension, mime_type=row.mime_type)
file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
if file_type.value != mapping.get("type", "custom"):
raise ValueError("Detected file type does not match the specified type. Please verify the file.")

return File(
id=mapping.get("id"),

@@ -173,10 +174,9 @@ def _build_from_remote_url(
if upload_file is None:
raise ValueError("Invalid upload file")

file_type = FileType(mapping.get("type", "custom"))
file_type = _standardize_file_type(
file_type, extension="." + upload_file.extension, mime_type=upload_file.mime_type
)
file_type = _standardize_file_type(extension="." + upload_file.extension, mime_type=upload_file.mime_type)
if file_type.value != mapping.get("type", "custom"):
raise ValueError("Detected file type does not match the specified type. Please verify the file.")

return File(
id=mapping.get("id"),

@@ -198,8 +198,9 @@ def _build_from_remote_url(
mime_type, filename, file_size = _get_remote_file_info(url)
extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin")

file_type = FileType(mapping.get("type", "custom"))
file_type = _standardize_file_type(file_type, extension=extension, mime_type=mime_type)
file_type = _standardize_file_type(extension=extension, mime_type=mime_type)
if file_type.value != mapping.get("type", "custom"):
raise ValueError("Detected file type does not match the specified type. Please verify the file.")

return File(
id=mapping.get("id"),

@@ -250,8 +251,8 @@ def _build_from_tool_file(
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")

extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
file_type = FileType(mapping.get("type", "custom"))
file_type = _standardize_file_type(file_type, extension=extension, mime_type=tool_file.mimetype)
file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype)

return File(
id=mapping.get("id"),

@@ -302,12 +303,10 @@ def _is_file_valid_with_config(
return True


def _standardize_file_type(file_type: FileType, /, *, extension: str = "", mime_type: str = "") -> FileType:
def _standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType:
"""
If custom type, try to guess the file type by extension and mime_type.
Infer the possible actual type of the file based on the extension and mime_type
"""
if file_type != FileType.CUSTOM:
return FileType(file_type)
guessed_type = None
if extension:
guessed_type = _get_file_type_by_extension(extension)
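The refactored `_standardize_file_type` no longer takes the declared `FileType` positionally; it guesses the type from the extension first and the MIME type second, and the callers then compare the guess against the declared `"type"` in the mapping. A minimal, self-contained sketch of that guess-and-verify pattern (the `FileType` values and lookup tables below are simplified stand-ins, not Dify's actual definitions):

```python
from enum import StrEnum


class FileType(StrEnum):
    IMAGE = "image"
    DOCUMENT = "document"
    CUSTOM = "custom"


# Simplified stand-ins for Dify's extension/MIME lookup tables.
_EXTENSION_MAP = {".png": FileType.IMAGE, ".jpg": FileType.IMAGE, ".pdf": FileType.DOCUMENT}
_MIME_PREFIX_MAP = {"image/": FileType.IMAGE, "application/pdf": FileType.DOCUMENT}


def _standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType:
    """Guess the likely file type from the extension first, then the MIME type; fall back to CUSTOM."""
    guessed = _EXTENSION_MAP.get(extension.lower())
    if guessed is None:
        for prefix, file_type in _MIME_PREFIX_MAP.items():
            if mime_type.startswith(prefix):
                guessed = file_type
                break
    return guessed or FileType.CUSTOM


# Mirrors the caller-side check added in the diff: a file declared as "image"
# that is detected as something else gets rejected instead of silently accepted.
mapping = {"type": "image"}
detected = _standardize_file_type(extension=".pdf", mime_type="application/pdf")
if detected.value != mapping.get("type", "custom"):
    print(f"Detected file type {detected.value!r} does not match the specified type {mapping['type']!r}.")
```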
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiofiles"
|
||||
|
@ -501,7 +501,7 @@ description = "Timeout context manager for asyncio programs"
|
|||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
markers = "python_full_version < \"3.11.3\""
|
||||
markers = "python_version == \"3.11\" and python_full_version < \"3.11.3\""
|
||||
files = [
|
||||
{file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"},
|
||||
{file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"},
|
||||
|
@ -1199,7 +1199,7 @@ files = [
|
|||
{file = "cffi-1.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662"},
|
||||
{file = "cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824"},
|
||||
]
|
||||
markers = {storage = "platform_python_implementation != \"PyPy\"", vdb = "python_version < \"3.12\" or platform_python_implementation != \"PyPy\""}
|
||||
markers = {storage = "platform_python_implementation != \"PyPy\"", vdb = "platform_python_implementation != \"PyPy\""}
|
||||
|
||||
[package.dependencies]
|
||||
pycparser = "*"
|
||||
|
@ -2947,6 +2947,7 @@ files = [
|
|||
{file = "google_crc32c-1.7.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:6a40522958040051c755a173eb98c05ad4d64a6dd898888c3e5ccca2d1cbdcdc"},
|
||||
{file = "google_crc32c-1.7.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f714fe5cdf5007d7064c57cf7471a99e0cbafda24ddfa829117fc3baafa424f7"},
|
||||
{file = "google_crc32c-1.7.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f04e58dbe1bf0c9398e603a9be5aaa09e0ba7eb022a3293195d8749459a01069"},
|
||||
{file = "google_crc32c-1.7.0-cp313-cp313-win_amd64.whl", hash = "sha256:e545b51ddf97f604d30114f7c23eecaf4c06cd6c023ff1ae0b80dcd99af32833"},
|
||||
{file = "google_crc32c-1.7.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:364067b063664dd8d1fec75a3fe85edf05c46f688365269beccaf42ef5dfe889"},
|
||||
{file = "google_crc32c-1.7.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1b0d6044799f6ac51d1cc2decb997280a83c448b3bef517a54b57a3b71921c0"},
|
||||
{file = "google_crc32c-1.7.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:02bc3295d26cd7666521fd6d5b7b93923ae1eb4417ddd3bc57185a5881ad7b96"},
|
||||
|
@ -6367,7 +6368,6 @@ files = [
|
|||
{file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909"},
|
||||
{file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1"},
|
||||
{file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567"},
|
||||
{file = "psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142"},
|
||||
{file = "psycopg2_binary-2.9.10-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4"},
|
||||
{file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8"},
|
||||
{file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864"},
|
||||
|
@ -6452,7 +6452,7 @@ files = [
|
|||
{file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"},
|
||||
{file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"},
|
||||
]
|
||||
markers = {storage = "platform_python_implementation != \"PyPy\"", vdb = "python_version < \"3.12\" or platform_python_implementation != \"PyPy\""}
|
||||
markers = {storage = "platform_python_implementation != \"PyPy\"", vdb = "platform_python_implementation != \"PyPy\""}
|
||||
|
||||
[[package]]
|
||||
name = "pycryptodome"
|
||||
|
@ -7240,7 +7240,7 @@ files = [
|
|||
{file = "pywin32-310-cp39-cp39-win32.whl", hash = "sha256:851c8d927af0d879221e616ae1f66145253537bbdd321a77e8ef701b443a9a1a"},
|
||||
{file = "pywin32-310-cp39-cp39-win_amd64.whl", hash = "sha256:96867217335559ac619f00ad70e513c0fcf84b8a3af9fc2bba3b59b97da70475"},
|
||||
]
|
||||
markers = {main = "platform_system == \"Windows\" and platform_python_implementation != \"PyPy\"", vdb = "platform_system == \"Windows\""}
|
||||
markers = {main = "platform_python_implementation != \"PyPy\" and platform_system == \"Windows\"", vdb = "platform_system == \"Windows\""}
|
||||
|
||||
[[package]]
|
||||
name = "pyxlsb"
|
||||
|
@ -9357,7 +9357,7 @@ description = "Fast implementation of asyncio event loop on top of libuv"
|
|||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
groups = ["vdb"]
|
||||
markers = "sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\""
|
||||
markers = "platform_python_implementation != \"PyPy\" and sys_platform != \"win32\" and sys_platform != \"cygwin\""
|
||||
files = [
|
||||
{file = "uvloop-0.21.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ec7e6b09a6fdded42403182ab6b832b71f4edaf7f37a9a0e371a01db5f0cb45f"},
|
||||
{file = "uvloop-0.21.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:196274f2adb9689a289ad7d65700d37df0c0930fd8e4e743fa4834e850d7719d"},
|
||||
|
|
|
@@ -1,3 +1,5 @@
import base64
import hashlib
import logging
import uuid
from collections.abc import Mapping

@@ -7,6 +9,8 @@ from urllib.parse import urlparse
from uuid import uuid4

import yaml # type: ignore
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from packaging import version
from pydantic import BaseModel, Field
from sqlalchemy import select

@@ -478,6 +482,15 @@ class AppDslService:
unique_hash = current_draft_workflow.unique_hash
else:
unique_hash = None
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
decrypted_id
for dataset_id in dataset_ids
if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id))
]
workflow_service.sync_draft_workflow(
app_model=app,
graph=workflow_data.get("graph", {}),

@@ -552,7 +565,15 @@ class AppDslService:
if not workflow:
raise ValueError("Missing draft workflow configuration, please check.")

export_data["workflow"] = workflow.to_dict(include_secret=include_secret)
workflow_dict = workflow.to_dict(include_secret=include_secret)
for node in workflow_dict.get("graph", {}).get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id)
for dataset_id in dataset_ids
]
export_data["workflow"] = workflow_dict
dependencies = cls._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())

@@ -724,3 +745,29 @@ class AppDslService:
return []

return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)

@staticmethod
def _generate_aes_key(tenant_id: str) -> bytes:
"""Generate AES key based on tenant_id"""
return hashlib.sha256(tenant_id.encode()).digest()

@classmethod
def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str:
"""Encrypt dataset_id using AES-CBC mode"""
key = cls._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size))
return base64.b64encode(ct_bytes).decode()

@classmethod
def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None:
"""AES decryption"""
try:
key = cls._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
return pt.decode()
except Exception:
return None
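The DSL export path now encrypts knowledge-retrieval `dataset_ids` with AES-CBC, deriving the key from the SHA-256 digest of the tenant id and reusing its first 16 bytes as the IV, while import decrypts and silently drops any id that fails. A self-contained round-trip sketch of that scheme with pycryptodome (standalone functions for illustration, not the `AppDslService` classmethods themselves):

```python
import base64
import hashlib

from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad


def _generate_aes_key(tenant_id: str) -> bytes:
    # Key is the SHA-256 digest of the tenant id (32 bytes, i.e. AES-256).
    return hashlib.sha256(tenant_id.encode()).digest()


def encrypt_dataset_id(dataset_id: str, tenant_id: str) -> str:
    key = _generate_aes_key(tenant_id)
    cipher = AES.new(key, AES.MODE_CBC, key[:16])  # IV derived from the key, as in the diff
    return base64.b64encode(cipher.encrypt(pad(dataset_id.encode(), AES.block_size))).decode()


def decrypt_dataset_id(encrypted_data: str, tenant_id: str) -> str | None:
    try:
        key = _generate_aes_key(tenant_id)
        cipher = AES.new(key, AES.MODE_CBC, key[:16])
        return unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size).decode()
    except Exception:
        # A wrong tenant id or a tampered token fails padding/decoding, so the id is dropped on import.
        return None


token = encrypt_dataset_id("dataset-123", tenant_id="tenant-a")
assert decrypt_dataset_id(token, tenant_id="tenant-a") == "dataset-123"
print(decrypt_dataset_id(token, tenant_id="tenant-b"))  # almost always None: wrong key fails unpadding
```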
@@ -17,6 +17,10 @@ class ApiKeyAuthFactory:
from services.auth.firecrawl.firecrawl import FirecrawlAuth

return FirecrawlAuth
case AuthType.WATERCRAWL:
from services.auth.watercrawl.watercrawl import WatercrawlAuth

return WatercrawlAuth
case AuthType.JINA:
from services.auth.jina.jina import JinaAuth


@@ -3,4 +3,5 @@ from enum import StrEnum

class AuthType(StrEnum):
FIRECRAWL = "firecrawl"
WATERCRAWL = "watercrawl"
JINA = "jinareader"
@@ -0,0 +1,44 @@
import json
from urllib.parse import urljoin

import requests

from services.auth.api_key_auth_base import ApiKeyAuthBase


class WatercrawlAuth(ApiKeyAuthBase):
    def __init__(self, credentials: dict):
        super().__init__(credentials)
        auth_type = credentials.get("auth_type")
        if auth_type != "x-api-key":
            raise ValueError("Invalid auth type, WaterCrawl auth type must be x-api-key")
        self.api_key = credentials.get("config", {}).get("api_key", None)
        self.base_url = credentials.get("config", {}).get("base_url", "https://app.watercrawl.dev")

        if not self.api_key:
            raise ValueError("No API key provided")

    def validate_credentials(self):
        headers = self._prepare_headers()
        url = urljoin(self.base_url, "/api/v1/core/crawl-requests/")
        response = self._get_request(url, headers)
        if response.status_code == 200:
            return True
        else:
            self._handle_error(response)

    def _prepare_headers(self):
        return {"Content-Type": "application/json", "X-API-KEY": self.api_key}

    def _get_request(self, url, headers):
        return requests.get(url, headers=headers)

    def _handle_error(self, response):
        if response.status_code in {402, 409, 500}:
            error_message = response.json().get("error", "Unknown error occurred")
            raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
        else:
            if response.text:
                error_message = json.loads(response.text).get("error", "Unknown error occurred")
                raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
            raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")
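`validate_credentials` amounts to one authenticated GET against WaterCrawl's crawl-requests endpoint, with HTTP 200 treated as a valid key. A minimal standalone probe along the same lines (the credential shape and endpoint come from the code above; the API key value is a placeholder):

```python
from urllib.parse import urljoin

import requests

credentials = {
    "auth_type": "x-api-key",
    "config": {
        "api_key": "YOUR_WATERCRAWL_API_KEY",  # placeholder, not a real key
        "base_url": "https://app.watercrawl.dev",
    },
}

base_url = credentials["config"]["base_url"]
headers = {"Content-Type": "application/json", "X-API-KEY": credentials["config"]["api_key"]}

# Same probe WatercrawlAuth.validate_credentials performs: 200 means the key is usable.
response = requests.get(urljoin(base_url, "/api/v1/core/crawl-requests/"), headers=headers)
print("credentials valid" if response.status_code == 200 else f"rejected: HTTP {response.status_code}")
```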
@@ -6,7 +6,7 @@ from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fix

from extensions.ext_database import db
from libs.helper import RateLimiter
from models.account import TenantAccountJoin, TenantAccountRole
from models.account import Account, TenantAccountJoin, TenantAccountRole


class BillingService:

@@ -106,6 +106,48 @@ class BillingService:
json = {"email": email, "feedback": feedback}
return cls._send_request("POST", "/account/delete-feedback", json=json)

class EducationIdentity:
verification_rate_limit = RateLimiter(prefix="edu_verification_rate_limit", max_attempts=10, time_window=60)
activation_rate_limit = RateLimiter(prefix="edu_activation_rate_limit", max_attempts=10, time_window=60)

@classmethod
def verify(cls, account_id: str, account_email: str):
if cls.verification_rate_limit.is_rate_limited(account_email):
from controllers.console.error import EducationVerifyLimitError

raise EducationVerifyLimitError()

cls.verification_rate_limit.increment_rate_limit(account_email)

params = {"account_id": account_id}
return BillingService._send_request("GET", "/education/verify", params=params)

@classmethod
def is_active(cls, account_id: str):
params = {"account_id": account_id}
return BillingService._send_request("GET", "/education/status", params=params)

@classmethod
def activate(cls, account: Account, token: str, institution: str, role: str):
if cls.activation_rate_limit.is_rate_limited(account.email):
from controllers.console.error import EducationActivateLimitError

raise EducationActivateLimitError()

cls.activation_rate_limit.increment_rate_limit(account.email)
params = {"account_id": account.id, "curr_tenant_id": account.current_tenant_id}
json = {
"institution": institution,
"token": token,
"role": role,
}
return BillingService._send_request("POST", "/education/", json=json, params=params)

@classmethod
def autocomplete(cls, keywords: str, page: int = 0, limit: int = 20):
params = {"keywords": keywords, "page": page, "limit": limit}
return BillingService._send_request("GET", "/education/autocomplete", params=params)

@classmethod
def get_compliance_download_link(
cls,
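Both education endpoints are throttled by a `RateLimiter(prefix=..., max_attempts=10, time_window=60)` keyed on the account email. The real class lives in Dify's `libs.helper` and is Redis-backed; as a rough illustration of the `is_rate_limited` / `increment_rate_limit` interface used above, here is an in-memory sliding-window sketch (an assumption-laden stand-in, not the shipped implementation):

```python
import time
from collections import defaultdict


class RateLimiter:
    """In-memory stand-in: allow at most max_attempts per time_window seconds per key."""

    def __init__(self, prefix: str, max_attempts: int, time_window: int):
        self.prefix = prefix
        self.max_attempts = max_attempts
        self.time_window = time_window
        self._attempts: dict[str, list[float]] = defaultdict(list)

    def _prune(self, key: str) -> None:
        # Drop attempts that fell out of the window.
        cutoff = time.time() - self.time_window
        self._attempts[key] = [t for t in self._attempts[key] if t > cutoff]

    def is_rate_limited(self, key: str) -> bool:
        self._prune(key)
        return len(self._attempts[key]) >= self.max_attempts

    def increment_rate_limit(self, key: str) -> None:
        self._attempts[key].append(time.time())


limiter = RateLimiter(prefix="edu_verification_rate_limit", max_attempts=10, time_window=60)
for _ in range(10):
    limiter.increment_rate_limit("user@example.com")
print(limiter.is_rate_limited("user@example.com"))  # True: an 11th attempt within 60s would be rejected
```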
@@ -880,6 +880,9 @@ class DocumentService:
website_info = knowledge_config.data_source.info_list.website_info_list
count = len(website_info.urls) # type: ignore
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)

if features.billing.subscription.plan == "sandbox" and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

@@ -1328,6 +1331,8 @@ class DocumentService:
website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
if website_info:
count = len(website_info.urls)
if features.billing.subscription.plan == "sandbox" and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

@@ -1663,6 +1668,7 @@ class SegmentService:
content=content,
word_count=len(content),
tokens=tokens,
keywords=segment_item.get("keywords", []),
status="completed",
indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),

@@ -1780,12 +1786,8 @@ class SegmentService:
)
elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
if args.enabled or keyword_changed:
VectorService.create_segments_vector(
[args.keywords] if args.keywords else None,
[segment],
dataset,
document.doc_form,
)
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
else:
segment_hash = helper.generate_text_hash(content)
tokens = 0
@@ -17,6 +17,11 @@ class BillingModel(BaseModel):
subscription: SubscriptionModel = SubscriptionModel()


class EducationModel(BaseModel):
enabled: bool = False
activated: bool = False


class LimitationModel(BaseModel):
size: int = 0
limit: int = 0

@@ -38,6 +43,7 @@ class LicenseModel(BaseModel):

class FeatureModel(BaseModel):
billing: BillingModel = BillingModel()
education: EducationModel = EducationModel()
members: LimitationModel = LimitationModel(size=0, limit=1)
apps: LimitationModel = LimitationModel(size=0, limit=10)
vector_space: LimitationModel = LimitationModel(size=0, limit=5)

@@ -128,6 +134,7 @@ class FeatureService:
features.can_replace_logo = dify_config.CAN_REPLACE_LOGO
features.model_load_balancing_enabled = dify_config.MODEL_LB_ENABLED
features.dataset_operator_enabled = dify_config.DATASET_OPERATOR_ENABLED
features.education.enabled = dify_config.EDUCATION_ENABLED

@classmethod
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):

@@ -136,6 +143,7 @@ class FeatureService:
features.billing.enabled = billing_info["enabled"]
features.billing.subscription.plan = billing_info["subscription"]["plan"]
features.billing.subscription.interval = billing_info["subscription"]["interval"]
features.education.activated = billing_info["subscription"].get("education", False)

if "members" in billing_info:
features.members.size = billing_info["members"]["size"]
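`education.enabled` comes from the `EDUCATION_ENABLED` config flag, while `education.activated` is read from the billing subscription payload. A small sketch of how the two pydantic models compose and get filled, assuming pydantic v2 (the billing payload below is invented for illustration):

```python
from pydantic import BaseModel


class EducationModel(BaseModel):
    enabled: bool = False
    activated: bool = False


class FeatureModel(BaseModel):
    education: EducationModel = EducationModel()


features = FeatureModel()
features.education.enabled = True  # the service reads dify_config.EDUCATION_ENABLED here

billing_info = {"subscription": {"plan": "professional", "education": True}}  # hypothetical API response
features.education.activated = billing_info["subscription"].get("education", False)

print(features.model_dump())  # {'education': {'enabled': True, 'activated': True}}
```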
@@ -20,7 +20,7 @@ class TagService:
)
if keyword:
query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
query = query.group_by(Tag.id, Tag.type, Tag.name)
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
results: list = query.order_by(Tag.created_at.desc()).all()
return results
@@ -7,6 +7,7 @@ from flask_login import current_user # type: ignore

from core.helper import encrypter
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from core.rag.extractor.watercrawl.provider import WaterCrawlProvider
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from services.auth.api_key_auth_service import ApiKeyAuthService

@@ -59,6 +60,13 @@ class WebsiteService:
time = str(datetime.datetime.now().timestamp())
redis_client.setex(website_crawl_time_cache_key, 3600, time)
return {"status": "active", "job_id": job_id}
elif provider == "watercrawl":
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
)
return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).crawl_url(url, options)

elif provider == "jinareader":
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")

@@ -116,6 +124,14 @@ class WebsiteService:
time_consuming = abs(end_time - float(start_time))
crawl_status_data["time_consuming"] = f"{time_consuming:.2f}"
redis_client.delete(website_crawl_time_cache_key)
elif provider == "watercrawl":
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
)
crawl_status_data = WaterCrawlProvider(
api_key, credentials.get("config").get("base_url", None)
).get_crawl_status(job_id)
elif provider == "jinareader":
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")

@@ -180,6 +196,11 @@ class WebsiteService:
if item.get("source_url") == url:
return dict(item)
return None
elif provider == "watercrawl":
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).get_crawl_url_data(
job_id, url
)
elif provider == "jinareader":
if not job_id:
response = requests.get(

@@ -223,5 +244,8 @@ class WebsiteService:
params = {"onlyMainContent": only_main_content}
result = firecrawl_app.scrape_url(url, params)
return result
elif provider == "watercrawl":
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).scrape_url(url)
else:
raise ValueError("Invalid provider")
@ -4,7 +4,6 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
|
@ -28,7 +27,9 @@ def add_document_to_index_task(dataset_document_id: str):
|
|||
|
||||
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first()
|
||||
if not dataset_document:
|
||||
raise NotFound("Document not found")
|
||||
logging.info(click.style("Document not found: {}".format(dataset_document_id), fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
if dataset_document.indexing_status != "completed":
|
||||
return
|
||||
|
|
|
@ -6,6 +6,7 @@ from celery import shared_task # type: ignore
|
|||
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
|
@ -55,3 +56,5 @@ def add_annotation_to_index_task(
|
|||
)
|
||||
except Exception:
|
||||
logging.exception("Build index for annotation failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
|
|
|
@ -88,3 +88,5 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
|
|||
indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id))
|
||||
redis_client.setex(indexing_error_msg_key, 600, str(e))
|
||||
logging.exception("Build index for batch import annotations failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
|
|
|
@ -5,6 +5,7 @@ import click
|
|||
from celery import shared_task # type: ignore
|
||||
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
|
@ -39,3 +40,5 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
|
|||
)
|
||||
except Exception as e:
|
||||
logging.exception("Annotation deleted index failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
|
|
|
@ -3,7 +3,6 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from extensions.ext_database import db
|
||||
|
@ -23,14 +22,18 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
|
|||
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
|
||||
annotations_count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).count()
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
logging.info(click.style("App not found: {}".format(app_id), fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
app_annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
|
||||
)
|
||||
|
||||
if not app_annotation_setting:
|
||||
raise NotFound("App annotation setting not found")
|
||||
logging.info(click.style("App annotation setting not found: {}".format(app_id), fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id))
|
||||
disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id))
|
||||
|
@ -46,7 +49,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
|
|||
try:
|
||||
if annotations_count > 0:
|
||||
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
||||
vector.delete_by_metadata_field("app_id", app_id)
|
||||
vector.delete()
|
||||
except Exception:
|
||||
logging.exception("Delete annotation index failed when annotation deleted.")
|
||||
redis_client.setex(disable_app_annotation_job_key, 600, "completed")
|
||||
|
@ -66,3 +69,4 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
|
|||
redis_client.setex(disable_app_annotation_error_key, 600, str(e))
|
||||
finally:
|
||||
redis_client.delete(disable_app_annotation_key)
|
||||
db.session.close()
|
||||
|
|
|
@ -4,7 +4,6 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.models.document import Document
|
||||
|
@ -34,7 +33,9 @@ def enable_annotation_reply_task(
|
|||
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
logging.info(click.style("App not found: {}".format(app_id), fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all()
|
||||
enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id))
|
||||
|
@ -49,6 +50,27 @@ def enable_annotation_reply_task(
|
|||
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
|
||||
)
|
||||
if annotation_setting:
|
||||
if dataset_collection_binding.id != annotation_setting.collection_binding_id:
|
||||
old_dataset_collection_binding = (
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
annotation_setting.collection_binding_id, "annotation"
|
||||
)
|
||||
)
|
||||
if old_dataset_collection_binding and annotations:
|
||||
old_dataset = Dataset(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider=old_dataset_collection_binding.provider_name,
|
||||
embedding_model=old_dataset_collection_binding.model_name,
|
||||
collection_binding_id=old_dataset_collection_binding.id,
|
||||
)
|
||||
|
||||
old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
||||
try:
|
||||
old_vector.delete()
|
||||
except Exception as e:
|
||||
logging.info(click.style("Delete annotation index error: {}".format(str(e)), fg="red"))
|
||||
annotation_setting.score_threshold = score_threshold
|
||||
annotation_setting.collection_binding_id = dataset_collection_binding.id
|
||||
annotation_setting.updated_user_id = user_id
|
||||
|
@ -100,3 +122,4 @@ def enable_annotation_reply_task(
|
|||
db.session.rollback()
|
||||
finally:
|
||||
redis_client.delete(enable_app_annotation_key)
|
||||
db.session.close()
|
||||
|
|
|
@ -6,6 +6,7 @@ from celery import shared_task # type: ignore
|
|||
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
|
@ -56,3 +57,5 @@ def update_annotation_to_index_task(
|
|||
)
|
||||
except Exception:
|
||||
logging.exception("Build index for annotation failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
|
|
|
@ -74,3 +74,5 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
|
|||
)
|
||||
except Exception:
|
||||
logging.exception("Cleaned documents when documents deleted failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
|
|
|
@ -127,3 +127,5 @@ def batch_create_segment_to_index_task(
|
|||
except Exception:
|
||||
logging.exception("Segments batch created index failed")
|
||||
redis_client.setex(indexing_cache_key, 600, "error")
|
||||
finally:
|
||||
db.session.close()
|
||||
|
|
|
@ -11,6 +11,8 @@ from extensions.ext_storage import storage
|
|||
from models.dataset import (
|
||||
AppDatasetJoin,
|
||||
Dataset,
|
||||
DatasetMetadata,
|
||||
DatasetMetadataBinding,
|
||||
DatasetProcessRule,
|
||||
DatasetQuery,
|
||||
Document,
|
||||
|
@ -86,7 +88,9 @@ def clean_dataset_task(
|
|||
db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete()
|
||||
db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete()
|
||||
db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete()
|
||||
|
||||
# delete dataset metadata
|
||||
db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == dataset_id).delete()
|
||||
db.session.query(DatasetMetadataBinding).filter(DatasetMetadataBinding.dataset_id == dataset_id).delete()
|
||||
# delete files
|
||||
if documents:
|
||||
for document in documents:
|
||||
|
@ -117,3 +121,5 @@ def clean_dataset_task(
|
|||
)
|
||||
except Exception:
|
||||
logging.exception("Cleaned dataset when dataset deleted failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
|
|
|
@ -9,7 +9,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
|
|||
from core.tools.utils.rag_web_reader import get_image_upload_file_ids
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
|
||||
|
||||
|
@ -67,6 +67,12 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
|||
db.session.delete(file)
|
||||
db.session.commit()
|
||||
|
||||
# delete dataset metadata binding
|
||||
db.session.query(DatasetMetadataBinding).filter(
|
||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||
DatasetMetadataBinding.document_id == document_id,
|
||||
).delete()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style(
|
||||
|
@ -76,3 +82,5 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
|||
)
|
||||
except Exception:
|
||||
logging.exception("Cleaned document when document deleted failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
|
|
|
@ -53,3 +53,5 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
|
|||
)
|
||||
except Exception:
|
||||
logging.exception("Cleaned document when import form notion document deleted failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
|
|
|
@ -5,7 +5,6 @@ from typing import Optional
|
|||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.rag.models.document import Document
|
||||
|
@ -27,7 +26,9 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
|
|||
|
||||
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
|
||||
if not segment:
|
||||
raise NotFound("Segment not found")
|
||||
logging.info(click.style("Segment not found: {}".format(segment_id), fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
if segment.status != "waiting":
|
||||
return
|
||||
|
@ -93,3 +94,4 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
|
|||
db.session.commit()
|
||||
finally:
|
||||
redis_client.delete(indexing_cache_key)
|
||||
db.session.close()
|
||||
|
|
|
@ -167,3 +167,5 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
|||
)
|
||||
except Exception:
|
||||
logging.exception("Deal dataset vector index failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
|
|
|
@ -41,3 +41,5 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume
|
|||
logging.info(click.style("Segment deleted from index latency: {}".format(end_at - start_at), fg="green"))
|
||||
except Exception:
|
||||
logging.exception("delete segment from index failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
|
|
|
@ -3,7 +3,6 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
|
@ -24,10 +23,14 @@ def disable_segment_from_index_task(segment_id: str):
|
|||
|
||||
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
|
||||
if not segment:
|
||||
raise NotFound("Segment not found")
|
||||
logging.info(click.style("Segment not found: {}".format(segment_id), fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
if segment.status != "completed":
|
||||
raise NotFound("Segment is not completed , disable action is not allowed.")
|
||||
logging.info(click.style("Segment is not completed, disable is not allowed: {}".format(segment_id), fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
indexing_cache_key = "segment_{}_indexing".format(segment.id)
|
||||
|
||||
|
@ -62,3 +65,4 @@ def disable_segment_from_index_task(segment_id: str):
|
|||
db.session.commit()
|
||||
finally:
|
||||
redis_client.delete(indexing_cache_key)
|
||||
db.session.close()
|
||||
|
|
|
@ -26,15 +26,18 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
|
|||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
|
||||
|
||||
if not dataset_document:
|
||||
logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan"))
|
||||
db.session.close()
|
||||
return
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
|
||||
logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan"))
|
||||
db.session.close()
|
||||
return
|
||||
# sync index processor
|
||||
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
|
||||
|
@ -50,6 +53,7 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
|
|||
)
|
||||
|
||||
if not segments:
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
try:
|
||||
|
@ -76,3 +80,4 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
|
|||
for segment in segments:
|
||||
indexing_cache_key = "segment_{}_indexing".format(segment.id)
|
||||
redis_client.delete(indexing_cache_key)
|
||||
db.session.close()
|
||||
|
|
|
@ -4,7 +4,6 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
|
@ -29,7 +28,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
|||
document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
|
||||
if not document:
|
||||
raise NotFound("Document not found")
|
||||
logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
data_source_info = document.data_source_info_dict
|
||||
if document.data_source_type == "notion_import":
|
||||
|
|
|
@ -27,6 +27,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
|
|||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logging.info(click.style("Dataset is not found: {}".format(dataset_id), fg="yellow"))
|
||||
db.session.close()
|
||||
return
|
||||
# check document limit
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
|
@ -35,6 +36,8 @@ def document_indexing_task(dataset_id: str, document_ids: list):
|
|||
vector_space = features.vector_space
|
||||
count = len(document_ids)
|
||||
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
|
||||
if features.billing.subscription.plan == "sandbox" and count > 1:
|
||||
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
|
||||
if count > batch_upload_limit:
|
||||
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
|
@ -53,6 +56,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
|
|||
document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
for document_id in document_ids:
|
||||
|
@ -78,3 +82,5 @@ def document_indexing_task(dataset_id: str, document_ids: list):
|
|||
logging.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
db.session.close()
|
||||
|
|
|
@ -4,7 +4,6 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
|
@ -27,7 +26,9 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
|
|||
document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
|
||||
if not document:
|
||||
raise NotFound("Document not found")
|
||||
logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
|
@ -73,3 +74,5 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
|
|||
logging.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
db.session.close()
|
||||
|
|
|
@ -27,7 +27,9 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
|
|||
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if dataset is None:
|
||||
raise ValueError("Dataset not found")
|
||||
logging.info(click.style("Dataset not found: {}".format(dataset_id), fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
# check document limit
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
|
@ -35,6 +37,8 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
|
|||
if features.billing.enabled:
|
||||
vector_space = features.vector_space
|
||||
count = len(document_ids)
|
||||
if features.billing.subscription.plan == "sandbox" and count > 1:
|
||||
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
|
||||
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
|
||||
if count > batch_upload_limit:
|
||||
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
|
||||
|
@ -55,6 +59,8 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
|
|||
db.session.add(document)
|
||||
db.session.commit()
|
||||
return
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
for document_id in document_ids:
|
||||
logging.info(click.style("Start process document: {}".format(document_id), fg="green"))
|
||||
|
@ -94,3 +100,5 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
|
|||
logging.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
db.session.close()
|
||||
|
|
|
@ -4,7 +4,6 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
|
@ -27,10 +26,14 @@ def enable_segment_to_index_task(segment_id: str):
|
|||
|
||||
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
|
||||
if not segment:
|
||||
raise NotFound("Segment not found")
|
||||
logging.info(click.style("Segment not found: {}".format(segment_id), fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
if segment.status != "completed":
|
||||
raise NotFound("Segment is not completed, enable action is not allowed.")
|
||||
logging.info(click.style("Segment is not completed, enable is not allowed: {}".format(segment_id), fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
indexing_cache_key = "segment_{}_indexing".format(segment.id)
|
||||
|
||||
|
@ -94,3 +97,4 @@ def enable_segment_to_index_task(segment_id: str):
|
|||
db.session.commit()
|
||||
finally:
|
||||
redis_client.delete(indexing_cache_key)
|
||||
db.session.close()
|
||||
|
|
|
@ -34,9 +34,11 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
|
|||
|
||||
if not dataset_document:
|
||||
logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan"))
|
||||
db.session.close()
|
||||
return
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
|
||||
logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan"))
|
||||
db.session.close()
|
||||
return
|
||||
# sync index processor
|
||||
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
|
||||
|
@ -51,6 +53,8 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
|
|||
.all()
|
||||
)
|
||||
if not segments:
|
||||
logging.info(click.style("Segments not found: {}".format(segment_ids), fg="cyan"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
try:
|
||||
|
@ -108,3 +112,4 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
|
|||
for segment in segments:
|
||||
indexing_cache_key = "segment_{}_indexing".format(segment.id)
|
||||
redis_client.delete(indexing_cache_key)
|
||||
db.session.close()
|
||||
|
|
|
@ -3,7 +3,6 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from extensions.ext_database import db
|
||||
|
@ -25,7 +24,9 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
|
|||
document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
|
||||
if not document:
|
||||
raise NotFound("Document not found")
|
||||
logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
|
@ -43,3 +44,5 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
|
|||
logging.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
db.session.close()
|
||||
|
|
|
@ -4,7 +4,6 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
|
@ -25,9 +24,13 @@ def remove_document_from_index_task(document_id: str):
|
|||
|
||||
document = db.session.query(Document).filter(Document.id == document_id).first()
|
||||
if not document:
|
||||
raise NotFound("Document not found")
|
||||
logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
if document.indexing_status != "completed":
|
||||
logging.info(click.style("Document is not completed, remove is not allowed: {}".format(document_id), fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
indexing_cache_key = "document_{}_indexing".format(document.id)
|
||||
|
@ -71,3 +74,4 @@ def remove_document_from_index_task(document_id: str):
|
|||
db.session.commit()
|
||||
finally:
|
||||
redis_client.delete(indexing_cache_key)
|
||||
db.session.close()
|
||||
|
|
|
@ -27,7 +27,9 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
|
|||
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
logging.info(click.style("Dataset not found: {}".format(dataset_id), fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
for document_id in document_ids:
|
||||
retry_indexing_cache_key = "document_{}_is_retried".format(document_id)
|
||||
|
@ -52,6 +54,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
|
|||
db.session.add(document)
|
||||
db.session.commit()
|
||||
redis_client.delete(retry_indexing_cache_key)
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
logging.info(click.style("Start retry document: {}".format(document_id), fg="green"))
|
||||
|
@ -60,6 +63,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
|
|||
)
|
||||
if not document:
|
||||
logging.info(click.style("Document not found: {}".format(document_id), fg="yellow"))
|
||||
db.session.close()
|
||||
return
|
||||
try:
|
||||
# clean old data
|
||||
|
@ -92,5 +96,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
|
|||
logging.info(click.style(str(ex), fg="yellow"))
|
||||
redis_client.delete(retry_indexing_cache_key)
|
||||
pass
|
||||
finally:
|
||||
db.session.close()
|
||||
end_at = time.perf_counter()
|
||||
logging.info(click.style("Retry dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
|
||||
|
|
|
@ -5,10 +5,11 @@ import pytest
|
|||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from requests.adapters import HTTPAdapter
|
||||
from tcvectordb import RPCVectorDBClient # type: ignore
|
||||
from tcvectordb.model import enum
|
||||
from tcvectordb.model.collection import FilterIndexConfig
|
||||
from tcvectordb.model.document import Document, Filter # type: ignore
|
||||
from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank # type: ignore
|
||||
from tcvectordb.model.enum import ReadConsistency # type: ignore
|
||||
from tcvectordb.model.index import Index, IndexField # type: ignore
|
||||
from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex # type: ignore
|
||||
from tcvectordb.rpc.model.collection import RPCCollection
|
||||
from tcvectordb.rpc.model.database import RPCDatabase
|
||||
from xinference_client.types import Embedding # type: ignore
|
||||
|
@ -40,6 +41,30 @@ class MockTcvectordbClass:
|
|||
def exists_collection(self, database_name: str, collection_name: str) -> bool:
|
||||
return True
|
||||
|
||||
def describe_collection(
|
||||
self, database_name: str, collection_name: str, timeout: Optional[float] = None
|
||||
) -> RPCCollection:
|
||||
index = Index(
|
||||
FilterIndex("id", enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
|
||||
VectorIndex(
|
||||
"vector",
|
||||
128,
|
||||
enum.IndexType.HNSW,
|
||||
enum.MetricType.IP,
|
||||
HNSWParams(m=16, efconstruction=200),
|
||||
),
|
||||
FilterIndex("text", enum.FieldType.String, enum.IndexType.FILTER),
|
||||
FilterIndex("metadata", enum.FieldType.String, enum.IndexType.FILTER),
|
||||
)
|
||||
return RPCCollection(
|
||||
RPCDatabase(
|
||||
name=database_name,
|
||||
read_consistency=self._read_consistency,
|
||||
),
|
||||
collection_name,
|
||||
index=index,
|
||||
)
|
||||
|
||||
def create_collection(
|
||||
self,
|
||||
database_name: str,
|
||||
|
@ -97,6 +122,23 @@ class MockTcvectordbClass:
|
|||
) -> list[list[dict]]:
|
||||
return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]
|
||||
|
||||
def collection_hybrid_search(
|
||||
self,
|
||||
database_name: str,
|
||||
collection_name: str,
|
||||
ann: Optional[Union[list[AnnSearch], AnnSearch]] = None,
|
||||
match: Optional[Union[list[KeywordSearch], KeywordSearch]] = None,
|
||||
filter: Union[Filter, str] = None,
|
||||
rerank: Optional[Rerank] = None,
|
||||
retrieve_vector: Optional[bool] = None,
|
||||
output_fields: Optional[list[str]] = None,
|
||||
limit: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
return_pd_object=False,
|
||||
**kwargs,
|
||||
) -> list[list[dict]]:
|
||||
return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]
|
||||
|
||||
def collection_query(
|
||||
self,
|
||||
database_name: str,
|
||||
|
@ -137,8 +179,10 @@ def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
|
|||
)
|
||||
monkeypatch.setattr(RPCVectorDBClient, "exists_collection", MockTcvectordbClass.exists_collection)
|
||||
monkeypatch.setattr(RPCVectorDBClient, "create_collection", MockTcvectordbClass.create_collection)
|
||||
monkeypatch.setattr(RPCVectorDBClient, "describe_collection", MockTcvectordbClass.describe_collection)
|
||||
monkeypatch.setattr(RPCVectorDBClient, "upsert", MockTcvectordbClass.collection_upsert)
|
||||
monkeypatch.setattr(RPCVectorDBClient, "search", MockTcvectordbClass.collection_search)
|
||||
monkeypatch.setattr(RPCVectorDBClient, "hybrid_search", MockTcvectordbClass.collection_hybrid_search)
|
||||
monkeypatch.setattr(RPCVectorDBClient, "query", MockTcvectordbClass.collection_query)
|
||||
monkeypatch.setattr(RPCVectorDBClient, "delete", MockTcvectordbClass.collection_delete)
|
||||
monkeypatch.setattr(RPCVectorDBClient, "drop_collection", MockTcvectordbClass.drop_collection)
|
||||
|
|
|
@@ -21,6 +21,7 @@ class TencentVectorTest(AbstractVectorTest):
database="dify",
shard=1,
replicas=2,
enable_hybrid_search=True,
),
)

@@ -30,7 +31,7 @@ class TencentVectorTest(AbstractVectorTest):

def search_by_full_text(self):
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0
assert len(hits_by_full_text) >= 0


def test_tencent_vector(setup_mock_redis, setup_tcvectordb_mock):
@ -0,0 +1,70 @@
|
|||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.model_runtime.entities.llm_entities import AssistantPromptMessage, LLMResultChunk, LLMResultChunkDelta
|
||||
|
||||
|
||||
def mock_llm_response(text) -> Generator[LLMResultChunk, None, None]:
|
||||
for i in range(len(text)):
|
||||
yield LLMResultChunk(
|
||||
model="model",
|
||||
prompt_messages=[],
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=text[i], tool_calls=[])),
|
||||
)
|
||||
|
||||
|
||||
def test_cot_output_parser():
|
||||
test_cases = [
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# code block with json
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```json\n{"action": "Final Answer", "action_input": "```echarts\n {'
|
||||
'}\n```"}```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# code block with JSON
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```JSON\n{"action": "Final Answer", "action_input": "```echarts\n {'
|
||||
'}\n```"}```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# list
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```[{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}]```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# no code block
|
||||
{
|
||||
"input": 'Through: abc\nAction: {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# no code block and json
|
||||
{"input": "Through: abc\nAction: efg", "action": {}, "output": "Through: abc\n efg"},
|
||||
]
|
||||
|
||||
parser = CotAgentOutputParser()
|
||||
usage_dict = {}
|
||||
for test_case in test_cases:
|
||||
# mock llm_response as a generator by text
|
||||
llm_response: Generator[LLMResultChunk, None, None] = mock_llm_response(test_case["input"])
|
||||
results = parser.handle_react_stream_output(llm_response, usage_dict)
|
||||
output = ""
|
||||
for result in results:
|
||||
if isinstance(result, str):
|
||||
output += result
|
||||
elif isinstance(result, AgentScratchpadUnit.Action):
|
||||
if test_case["action"]:
|
||||
assert result.to_dict() == test_case["action"]
|
||||
output += json.dumps(result.to_dict())
|
||||
if test_case["output"]:
|
||||
assert output == test_case["output"]
|
|
@@ -515,6 +515,7 @@ TENCENT_VECTOR_DB_USERNAME=dify
TENCENT_VECTOR_DB_DATABASE=dify
TENCENT_VECTOR_DB_SHARD=1
TENCENT_VECTOR_DB_REPLICAS=2
TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH=false

# ElasticSearch configuration, only available when VECTOR_STORE is `elasticsearch`
ELASTICSEARCH_HOST=0.0.0.0
@@ -36,7 +36,8 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T
- Navigate to the `docker` directory.
- Ensure the `middleware.env` file is created by running `cp middleware.env.example middleware.env` (refer to the `middleware.env.example` file).
2. **Running Middleware Services**:
- Execute `docker-compose -f docker-compose.middleware.yaml up --env-file middleware.env -d` to start the middleware services.
- Navigate to the `docker` directory.
- Execute `docker compose -f docker-compose.middleware.yaml --profile weaviate -p dify up -d` to start the middleware services. (Change the profile to other vector database if you are not using weaviate)

### Migration for Existing Users
@@ -223,6 +223,7 @@ x-shared-env: &shared-api-worker-env
TENCENT_VECTOR_DB_DATABASE: ${TENCENT_VECTOR_DB_DATABASE:-dify}
TENCENT_VECTOR_DB_SHARD: ${TENCENT_VECTOR_DB_SHARD:-1}
TENCENT_VECTOR_DB_REPLICAS: ${TENCENT_VECTOR_DB_REPLICAS:-2}
TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH: ${TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH:-false}
ELASTICSEARCH_HOST: ${ELASTICSEARCH_HOST:-0.0.0.0}
ELASTICSEARCH_PORT: ${ELASTICSEARCH_PORT:-9200}
ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic}
(Binary image file removed, 257 KiB; content not shown.)
@@ -42,8 +42,8 @@ ENV EDITION=SELF_HOSTED
ENV DEPLOY_ENV=PRODUCTION
ENV CONSOLE_API_URL=http://127.0.0.1:5001
ENV APP_API_URL=http://127.0.0.1:5001
ENV MARKETPLACE_API_URL=http://127.0.0.1:5001
ENV MARKETPLACE_URL=http://127.0.0.1:5001
ENV MARKETPLACE_API_URL=https://marketplace.dify.ai
ENV MARKETPLACE_URL=https://marketplace.dify.ai
ENV PORT=3000
ENV NEXT_TELEMETRY_DISABLED=1
ENV PM2_INSTANCES=2
@@ -1,7 +1,9 @@
'use client'

import { useCallback, useEffect, useRef, useState } from 'react'
import { useRouter } from 'next/navigation'
import {
useRouter,
} from 'next/navigation'
import useSWRInfinite from 'swr/infinite'
import { useTranslation } from 'react-i18next'
import { useDebounceFn } from 'ahooks'
@@ -7,9 +7,12 @@ import style from '../list.module.css'
import Apps from './Apps'
import AppContext from '@/context/app-context'
import { LicenseStatus } from '@/types/feature'
import { useEducationInit } from '@/app/education-apply/hooks'

const AppList = () => {
const { t } = useTranslation()
useEducationInit()

const systemFeatures = useContextSelector(AppContext, v => v.systemFeatures)

return (
@@ -38,6 +38,8 @@ const Container = () => {
const { showExternalApiPanel, setShowExternalApiPanel } = useExternalApiPanel()
const [includeAll, { toggle: toggleIncludeAll }] = useBoolean(false)

document.title = `${t('dataset.knowledge')} - Dify`

const options = useMemo(() => {
return [
{ value: 'dataset', text: t('dataset.datasets') },