merge main

This commit is contained in:
zxhlyh 2025-04-18 10:30:28 +08:00
commit 628dd8fcd8
397 changed files with 16159 additions and 16507 deletions

View File

@ -2,10 +2,10 @@
npm add -g pnpm@10.8.0 npm add -g pnpm@10.8.0
cd web && pnpm install cd web && pnpm install
pipx install poetry pipx install uv
echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc

View File

@ -1,3 +1,3 @@
#!/bin/bash #!/bin/bash
cd api && poetry install cd api && uv sync

View File

@ -1,36 +0,0 @@
name: Setup Poetry and Python
inputs:
python-version:
description: Python version to use and the Poetry installed with
required: true
default: '3.11'
poetry-version:
description: Poetry version to set up
required: true
default: '2.0.1'
poetry-lockfile:
description: Path to the Poetry lockfile to restore cache from
required: true
default: ''
runs:
using: composite
steps:
- name: Set up Python ${{ inputs.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ inputs.python-version }}
cache: pip
- name: Install Poetry
shell: bash
run: pip install poetry==${{ inputs.poetry-version }}
- name: Restore Poetry cache
if: ${{ inputs.poetry-lockfile != '' }}
uses: actions/setup-python@v5
with:
python-version: ${{ inputs.python-version }}
cache: poetry
cache-dependency-path: ${{ inputs.poetry-lockfile }}

34
.github/actions/setup-uv/action.yml vendored Normal file
View File

@ -0,0 +1,34 @@
name: Setup UV and Python
inputs:
python-version:
description: Python version to use and the UV installed with
required: true
default: '3.12'
uv-version:
description: UV version to set up
required: true
default: '0.6.14'
uv-lockfile:
description: Path to the UV lockfile to restore cache from
required: true
default: ''
enable-cache:
required: true
default: true
runs:
using: composite
steps:
- name: Set up Python ${{ inputs.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ inputs.python-version }}
- name: Install uv
uses: astral-sh/setup-uv@v5
with:
version: ${{ inputs.uv-version }}
python-version: ${{ inputs.python-version }}
enable-cache: ${{ inputs.enable-cache }}
cache-dependency-glob: ${{ inputs.uv-lockfile }}

View File

@ -17,6 +17,9 @@ jobs:
test: test:
name: API Tests name: API Tests
runs-on: ubuntu-latest runs-on: ubuntu-latest
defaults:
run:
shell: bash
strategy: strategy:
matrix: matrix:
python-version: python-version:
@ -27,40 +30,44 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Setup Poetry and Python ${{ matrix.python-version }} - name: Setup UV and Python
uses: ./.github/actions/setup-poetry uses: ./.github/actions/setup-uv
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
poetry-lockfile: api/poetry.lock uv-lockfile: api/uv.lock
- name: Check Poetry lockfile - name: Check UV lockfile
run: | run: uv lock --project api --check
poetry check -C api --lock
poetry show -C api
- name: Install dependencies - name: Install dependencies
run: poetry install -C api --with dev run: uv sync --project api --dev
- name: Check dependencies in pyproject.toml
run: poetry run -P api bash dev/pytest/pytest_artifacts.sh
- name: Run Unit tests - name: Run Unit tests
run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report >> $GITHUB_STEP_SUMMARY
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
- name: Run dify config tests - name: Run dify config tests
run: poetry run -P api python dev/pytest/pytest_config_tests.py run: uv run --project api dev/pytest/pytest_config_tests.py
- name: Cache MyPy - name: MyPy Cache
uses: actions/cache@v4 uses: actions/cache@v4
with: with:
path: api/.mypy_cache path: api/.mypy_cache
key: mypy-${{ matrix.python-version }}-${{ runner.os }}-${{ hashFiles('api/poetry.lock') }} key: mypy-${{ matrix.python-version }}-${{ runner.os }}-${{ hashFiles('api/uv.lock') }}
- name: Run mypy - name: Run MyPy Checks
run: dev/run-mypy run: dev/mypy-check
- name: Set up dotenvs - name: Set up dotenvs
run: | run: |
@ -80,4 +87,4 @@ jobs:
ssrf_proxy ssrf_proxy
- name: Run Workflow - name: Run Workflow
run: poetry run -P api bash dev/pytest/pytest_workflow.sh run: uv run --project api bash dev/pytest/pytest_workflow.sh

View File

@ -24,13 +24,13 @@ jobs:
fetch-depth: 0 fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Setup Poetry and Python - name: Setup UV and Python
uses: ./.github/actions/setup-poetry uses: ./.github/actions/setup-uv
with: with:
poetry-lockfile: api/poetry.lock uv-lockfile: api/uv.lock
- name: Install dependencies - name: Install dependencies
run: poetry install -C api run: uv sync --project api
- name: Prepare middleware env - name: Prepare middleware env
run: | run: |
@ -54,6 +54,4 @@ jobs:
- name: Run DB Migration - name: Run DB Migration
env: env:
DEBUG: true DEBUG: true
run: | run: uv run --directory api flask upgrade-db
cd api
poetry run python -m flask upgrade-db

View File

@ -42,6 +42,7 @@ jobs:
with: with:
push: false push: false
context: "{{defaultContext}}:${{ matrix.context }}" context: "{{defaultContext}}:${{ matrix.context }}"
file: "${{ matrix.file }}"
platforms: ${{ matrix.platform }} platforms: ${{ matrix.platform }}
cache-from: type=gha cache-from: type=gha
cache-to: type=gha,mode=max cache-to: type=gha,mode=max

View File

@ -18,7 +18,6 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Check changed files - name: Check changed files
@ -29,24 +28,27 @@ jobs:
api/** api/**
.github/workflows/style.yml .github/workflows/style.yml
- name: Setup Poetry and Python - name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
uses: ./.github/actions/setup-poetry uses: ./.github/actions/setup-uv
with:
uv-lockfile: api/uv.lock
enable-cache: false
- name: Install dependencies - name: Install dependencies
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: poetry install -C api --only lint run: uv sync --project api --dev
- name: Ruff check - name: Ruff check
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: | run: |
poetry run -C api ruff --version uv run --directory api ruff --version
poetry run -C api ruff check ./ uv run --directory api ruff check ./
poetry run -C api ruff format --check ./ uv run --directory api ruff format --check ./
- name: Dotenv check - name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -P api dotenv-linter ./api/.env.example ./web/.env.example run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
- name: Lint hints - name: Lint hints
if: failure() if: failure()
@ -63,7 +65,6 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Check changed files - name: Check changed files
@ -102,7 +103,6 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Check changed files - name: Check changed files
@ -133,7 +133,6 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Check changed files - name: Check changed files

View File

@ -27,7 +27,6 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Use Node.js ${{ matrix.node-version }} - name: Use Node.js ${{ matrix.node-version }}

View File

@ -8,7 +8,7 @@ on:
- api/core/rag/datasource/** - api/core/rag/datasource/**
- docker/** - docker/**
- .github/workflows/vdb-tests.yml - .github/workflows/vdb-tests.yml
- api/poetry.lock - api/uv.lock
- api/pyproject.toml - api/pyproject.toml
concurrency: concurrency:
@ -29,22 +29,19 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Setup Poetry and Python ${{ matrix.python-version }} - name: Setup UV and Python
uses: ./.github/actions/setup-poetry uses: ./.github/actions/setup-uv
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
poetry-lockfile: api/poetry.lock uv-lockfile: api/uv.lock
- name: Check Poetry lockfile - name: Check UV lockfile
run: | run: uv lock --project api --check
poetry check -C api --lock
poetry show -C api
- name: Install dependencies - name: Install dependencies
run: poetry install -C api --with dev run: uv sync --project api --dev
- name: Set up dotenvs - name: Set up dotenvs
run: | run: |
@ -80,7 +77,7 @@ jobs:
elasticsearch elasticsearch
- name: Check TiDB Ready - name: Check TiDB Ready
run: poetry run -P api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores - name: Test Vector Stores
run: poetry run -P api bash dev/pytest/pytest_vdb.sh run: uv run --project api bash dev/pytest/pytest_vdb.sh

View File

@ -23,7 +23,6 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Check changed files - name: Check changed files

1
.gitignore vendored
View File

@ -46,6 +46,7 @@ htmlcov/
.cache .cache
nosetests.xml nosetests.xml
coverage.xml coverage.xml
coverage.json
*.cover *.cover
*.py,cover *.py,cover
.hypothesis/ .hypothesis/

View File

@ -165,6 +165,7 @@ MILVUS_URI=http://127.0.0.1:19530
MILVUS_TOKEN= MILVUS_TOKEN=
MILVUS_USER=root MILVUS_USER=root
MILVUS_PASSWORD=Milvus MILVUS_PASSWORD=Milvus
MILVUS_ANALYZER_PARAMS=
# MyScale configuration # MyScale configuration
MYSCALE_HOST=127.0.0.1 MYSCALE_HOST=127.0.0.1
@ -423,6 +424,12 @@ WORKFLOW_CALL_MAX_DEPTH=5
WORKFLOW_PARALLEL_DEPTH_LIMIT=3 WORKFLOW_PARALLEL_DEPTH_LIMIT=3
MAX_VARIABLE_SIZE=204800 MAX_VARIABLE_SIZE=204800
# Workflow storage configuration
# Options: rdbms, hybrid
# rdbms: Use only the relational database (default)
# hybrid: Save new data to object storage, read from both object storage and RDBMS
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# App configuration # App configuration
APP_MAX_EXECUTION_TIME=1200 APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0 APP_MAX_ACTIVE_REQUESTS=0
@ -463,3 +470,16 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
MAX_SUBMIT_COUNT=100 MAX_SUBMIT_COUNT=100
# Lockout duration in seconds # Lockout duration in seconds
LOGIN_LOCKOUT_DURATION=86400 LOGIN_LOCKOUT_DURATION=86400
# Enable OpenTelemetry
ENABLE_OTEL=false
OTLP_BASE_ENDPOINT=http://localhost:4318
OTLP_API_KEY=
OTEL_EXPORTER_TYPE=otlp
OTEL_SAMPLING_RATE=0.1
OTEL_BATCH_EXPORT_SCHEDULE_DELAY=5000
OTEL_MAX_QUEUE_SIZE=2048
OTEL_MAX_EXPORT_BATCH_SIZE=512
OTEL_METRIC_EXPORT_INTERVAL=60000
OTEL_BATCH_EXPORT_TIMEOUT=10000
OTEL_METRIC_EXPORT_TIMEOUT=30000

View File

@ -3,20 +3,11 @@ FROM python:3.12-slim-bookworm AS base
WORKDIR /app/api WORKDIR /app/api
# Install Poetry # Install uv
ENV POETRY_VERSION=2.0.1 ENV UV_VERSION=0.6.14
# if you located in China, you can use aliyun mirror to speed up RUN pip install --no-cache-dir uv==${UV_VERSION}
# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/
RUN pip install --no-cache-dir poetry==${POETRY_VERSION}
# Configure Poetry
ENV POETRY_CACHE_DIR=/tmp/poetry_cache
ENV POETRY_NO_INTERACTION=1
ENV POETRY_VIRTUALENVS_IN_PROJECT=true
ENV POETRY_VIRTUALENVS_CREATE=true
ENV POETRY_REQUESTS_TIMEOUT=15
FROM base AS packages FROM base AS packages
@ -27,8 +18,8 @@ RUN apt-get update \
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev && apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
# Install Python dependencies # Install Python dependencies
COPY pyproject.toml poetry.lock ./ COPY pyproject.toml uv.lock ./
RUN poetry install --sync --no-cache --no-root RUN uv sync --locked
# production stage # production stage
FROM base AS production FROM base AS production

View File

@ -3,7 +3,10 @@
## Usage ## Usage
> [!IMPORTANT] > [!IMPORTANT]
> In the v0.6.12 release, we deprecated `pip` as the package management tool for Dify API Backend service and replaced it with `poetry`. >
> In the v1.3.0 release, `poetry` has been replaced with
> [`uv`](https://docs.astral.sh/uv/) as the package manager
> for Dify API backend service.
1. Start the docker-compose stack 1. Start the docker-compose stack
@ -37,19 +40,19 @@
4. Create environment. 4. Create environment.
Dify API service uses [Poetry](https://python-poetry.org/docs/) to manage dependencies. First, you need to add the poetry shell plugin, if you don't have it already, in order to run in a virtual environment. [Note: Poetry shell is no longer a native command so you need to install the poetry plugin beforehand] Dify API service uses [UV](https://docs.astral.sh/uv/) to manage dependencies.
First, you need to add the uv package manager, if you don't have it already.
```bash ```bash
poetry self add poetry-plugin-shell pip install uv
# Or on macOS
brew install uv
``` ```
Then, You can execute `poetry shell` to activate the environment.
5. Install dependencies 5. Install dependencies
```bash ```bash
poetry env use 3.12 uv sync --dev
poetry install
``` ```
6. Run migrate 6. Run migrate
@ -57,21 +60,21 @@
Before the first launch, migrate the database to the latest version. Before the first launch, migrate the database to the latest version.
```bash ```bash
poetry run python -m flask db upgrade uv run flask db upgrade
``` ```
7. Start backend 7. Start backend
```bash ```bash
poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug uv run flask run --host 0.0.0.0 --port=5001 --debug
``` ```
8. Start Dify [web](../web) service. 8. Start Dify [web](../web) service.
9. Setup your application by visiting `http://localhost:3000`... 9. Setup your application by visiting `http://localhost:3000`.
10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. 10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash ```bash
poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
``` ```
## Testing ## Testing
@ -79,11 +82,11 @@
1. Install dependencies for both the backend and the test environment 1. Install dependencies for both the backend and the test environment
```bash ```bash
poetry install -C api --with dev uv sync --dev
``` ```
2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml` 2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
```bash ```bash
poetry run -P api bash dev/pytest/pytest_all_tests.sh uv run -P api bash dev/pytest/pytest_all_tests.sh
``` ```

View File

@ -51,8 +51,10 @@ def initialize_extensions(app: DifyApp):
ext_login, ext_login,
ext_mail, ext_mail,
ext_migrate, ext_migrate,
ext_otel,
ext_proxy_fix, ext_proxy_fix,
ext_redis, ext_redis,
ext_repositories,
ext_sentry, ext_sentry,
ext_set_secretkey, ext_set_secretkey,
ext_storage, ext_storage,
@ -73,6 +75,7 @@ def initialize_extensions(app: DifyApp):
ext_migrate, ext_migrate,
ext_redis, ext_redis,
ext_storage, ext_storage,
ext_repositories,
ext_celery, ext_celery,
ext_login, ext_login,
ext_mail, ext_mail,
@ -81,6 +84,7 @@ def initialize_extensions(app: DifyApp):
ext_proxy_fix, ext_proxy_fix,
ext_blueprints, ext_blueprints,
ext_commands, ext_commands,
ext_otel,
] ]
for ext in extensions: for ext in extensions:
short_name = ext.__name__.split(".")[-1] short_name = ext.__name__.split(".")[-1]

View File

@ -9,6 +9,7 @@ from .enterprise import EnterpriseFeatureConfig
from .extra import ExtraServiceConfig from .extra import ExtraServiceConfig
from .feature import FeatureConfig from .feature import FeatureConfig
from .middleware import MiddlewareConfig from .middleware import MiddlewareConfig
from .observability import ObservabilityConfig
from .packaging import PackagingInfo from .packaging import PackagingInfo
from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName
from .remote_settings_sources.apollo import ApolloSettingsSource from .remote_settings_sources.apollo import ApolloSettingsSource
@ -59,6 +60,8 @@ class DifyConfig(
MiddlewareConfig, MiddlewareConfig,
# Extra service configs # Extra service configs
ExtraServiceConfig, ExtraServiceConfig,
# Observability configs
ObservabilityConfig,
# Remote source configs # Remote source configs
RemoteSettingsSourceConfig, RemoteSettingsSourceConfig,
# Enterprise feature configs # Enterprise feature configs

View File

@ -12,7 +12,7 @@ from pydantic import (
) )
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
from configs.feature.hosted_service import HostedServiceConfig from .hosted_service import HostedServiceConfig
class SecurityConfig(BaseSettings): class SecurityConfig(BaseSettings):
@ -519,6 +519,11 @@ class WorkflowNodeExecutionConfig(BaseSettings):
default=100, default=100,
) )
WORKFLOW_NODE_EXECUTION_STORAGE: str = Field(
default="rdbms",
description="Storage backend for WorkflowNodeExecution. Options: 'rdbms', 'hybrid'",
)
class AuthConfig(BaseSettings): class AuthConfig(BaseSettings):
""" """

View File

@ -39,3 +39,8 @@ class MilvusConfig(BaseSettings):
"older versions", "older versions",
default=True, default=True,
) )
MILVUS_ANALYZER_PARAMS: Optional[str] = Field(
description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.',
default=None,
)

View File

@ -0,0 +1,9 @@
from configs.observability.otel.otel_config import OTelConfig
class ObservabilityConfig(OTelConfig):
"""
Observability configuration settings
"""
pass

View File

@ -0,0 +1,44 @@
from pydantic import Field
from pydantic_settings import BaseSettings
class OTelConfig(BaseSettings):
"""
OpenTelemetry configuration settings
"""
ENABLE_OTEL: bool = Field(
description="Whether to enable OpenTelemetry",
default=False,
)
OTLP_BASE_ENDPOINT: str = Field(
description="OTLP base endpoint",
default="http://localhost:4318",
)
OTLP_API_KEY: str = Field(
description="OTLP API key",
default="",
)
OTEL_EXPORTER_TYPE: str = Field(
description="OTEL exporter type",
default="otlp",
)
OTEL_SAMPLING_RATE: float = Field(default=0.1, description="Sampling rate for traces (0.0 to 1.0)")
OTEL_BATCH_EXPORT_SCHEDULE_DELAY: int = Field(
default=5000, description="Batch export schedule delay in milliseconds"
)
OTEL_MAX_QUEUE_SIZE: int = Field(default=2048, description="Maximum queue size for the batch span processor")
OTEL_MAX_EXPORT_BATCH_SIZE: int = Field(default=512, description="Maximum export batch size")
OTEL_METRIC_EXPORT_INTERVAL: int = Field(default=60000, description="Metric export interval in milliseconds")
OTEL_BATCH_EXPORT_TIMEOUT: int = Field(default=10000, description="Batch export timeout in milliseconds")
OTEL_METRIC_EXPORT_TIMEOUT: int = Field(default=30000, description="Metric export timeout in milliseconds")

View File

@ -270,7 +270,7 @@ class ApolloClient:
while not self._stopping: while not self._stopping:
for namespace in self._notification_map: for namespace in self._notification_map:
self._do_heart_beat(namespace) self._do_heart_beat(namespace)
time.sleep(60 * 10) # 10分钟 time.sleep(60 * 10) # 10 minutes
def _do_heart_beat(self, namespace): def _do_heart_beat(self, namespace):
url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip) url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip)

View File

@ -3,6 +3,8 @@ from configs import dify_config
HIDDEN_VALUE = "[__HIDDEN__]" HIDDEN_VALUE = "[__HIDDEN__]"
UUID_NIL = "00000000-0000-0000-0000-000000000000" UUID_NIL = "00000000-0000-0000-0000-000000000000"
DEFAULT_FILE_NUMBER_LIMITS = 3
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])

View File

@ -4,8 +4,6 @@ import platform
import re import re
import urllib.parse import urllib.parse
import warnings import warnings
from collections.abc import Mapping
from typing import Any
from uuid import uuid4 from uuid import uuid4
import httpx import httpx
@ -29,8 +27,6 @@ except ImportError:
from pydantic import BaseModel from pydantic import BaseModel
from configs import dify_config
class FileInfo(BaseModel): class FileInfo(BaseModel):
filename: str filename: str
@ -87,38 +83,3 @@ def guess_file_info_from_response(response: httpx.Response):
mimetype=mimetype, mimetype=mimetype,
size=int(response.headers.get("Content-Length", -1)), size=int(response.headers.get("Content-Length", -1)),
) )
def get_parameters_from_feature_dict(*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]):
return {
"opening_statement": features_dict.get("opening_statement"),
"suggested_questions": features_dict.get("suggested_questions", []),
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
"user_input_form": user_input_form,
"sensitive_word_avoidance": features_dict.get(
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
),
"file_upload": features_dict.get(
"file_upload",
{
"image": {
"enabled": False,
"number_limits": 3,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
},
}

View File

@ -89,7 +89,7 @@ class AnnotationReplyActionStatusApi(Resource):
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id)) app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
cache_result = redis_client.get(app_annotation_job_key) cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None: if cache_result is None:
raise ValueError("The job is not exist.") raise ValueError("The job does not exist.")
job_status = cache_result.decode() job_status = cache_result.decode()
error_msg = "" error_msg = ""
@ -226,7 +226,7 @@ class AnnotationBatchImportStatusApi(Resource):
indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
if cache_result is None: if cache_result is None:
raise ValueError("The job is not exist.") raise ValueError("The job does not exist.")
job_status = cache_result.decode() job_status = cache_result.decode()
error_msg = "" error_msg = ""
if job_status == "error": if job_status == "error":

View File

@ -1,5 +1,4 @@
from datetime import datetime from dateutil.parser import isoparse
from flask_restful import Resource, marshal_with, reqparse # type: ignore from flask_restful import Resource, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -41,10 +40,10 @@ class WorkflowAppLogApi(Resource):
args.status = WorkflowRunStatus(args.status) if args.status else None args.status = WorkflowRunStatus(args.status) if args.status else None
if args.created_at__before: if args.created_at__before:
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00")) args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after: if args.created_at__after:
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00")) args.created_at__after = isoparse(args.created_at__after)
# get paginate workflow app logs # get paginate workflow app logs
workflow_app_service = WorkflowAppService() workflow_app_service = WorkflowAppService()

View File

@ -74,7 +74,9 @@ class OAuthDataSourceBinding(Resource):
if not oauth_provider: if not oauth_provider:
return {"error": "Invalid provider"}, 400 return {"error": "Invalid provider"}, 400
if "code" in request.args: if "code" in request.args:
code = request.args.get("code") code = request.args.get("code", "")
if not code:
return {"error": "Invalid code"}, 400
try: try:
oauth_provider.get_access_token(code) oauth_provider.get_access_token(code)
except requests.exceptions.HTTPError as e: except requests.exceptions.HTTPError as e:

View File

@ -398,7 +398,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
indexing_cache_key = "segment_batch_import_{}".format(job_id) indexing_cache_key = "segment_batch_import_{}".format(job_id)
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
if cache_result is None: if cache_result is None:
raise ValueError("The job is not exist.") raise ValueError("The job does not exist.")
return {"job_id": job_id, "job_status": cache_result.decode()}, 200 return {"job_id": job_id, "job_status": cache_result.decode()}, 200

View File

@ -21,12 +21,6 @@ def _validate_name(name):
return name return name
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class ExternalApiTemplateListApi(Resource): class ExternalApiTemplateListApi(Resource):
@setup_required @setup_required
@login_required @login_required

View File

@ -14,18 +14,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
from services.metadata_service import MetadataService from services.metadata_service import MetadataService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class DatasetMetadataCreateApi(Resource): class DatasetMetadataCreateApi(Resource):
@setup_required @setup_required
@login_required @login_required

View File

@ -1,10 +1,10 @@
from flask_restful import marshal_with # type: ignore from flask_restful import marshal_with # type: ignore
from controllers.common import fields from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.console import api from controllers.console import api
from controllers.console.app.error import AppUnavailableError from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import AppMode, InstalledApp from models.model import AppMode, InstalledApp
from services.app_service import AppService from services.app_service import AppService
@ -36,9 +36,7 @@ class AppParameterApi(InstalledAppResource):
user_input_form = features_dict.get("user_input_form", []) user_input_form = features_dict.get("user_input_form", [])
return controller_helpers.get_parameters_from_feature_dict( return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
features_dict=features_dict, user_input_form=user_input_form
)
class ExploreAppMetaApi(InstalledAppResource): class ExploreAppMetaApi(InstalledAppResource):

View File

@ -286,8 +286,6 @@ class AccountDeleteApi(Resource):
class AccountDeleteUpdateFeedbackApi(Resource): class AccountDeleteUpdateFeedbackApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json") parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("feedback", type=str, required=True, location="json") parser.add_argument("feedback", type=str, required=True, location="json")

View File

@ -249,6 +249,31 @@ class PluginInstallFromMarketplaceApi(Resource):
return jsonable_encoder(response) return jsonable_encoder(response)
class PluginFetchMarketplacePkgApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
args = parser.parse_args()
try:
return jsonable_encoder(
{
"manifest": PluginService.fetch_marketplace_pkg(
tenant_id,
args["plugin_unique_identifier"],
)
}
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginFetchManifestApi(Resource): class PluginFetchManifestApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -488,6 +513,7 @@ api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all") api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>") api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall") api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marketplace/pkg")
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change") api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch") api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")

View File

@ -13,6 +13,7 @@ from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocatio
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
from core.plugin.entities.request import ( from core.plugin.entities.request import (
RequestFetchAppInfo,
RequestInvokeApp, RequestInvokeApp,
RequestInvokeEncrypt, RequestInvokeEncrypt,
RequestInvokeLLM, RequestInvokeLLM,
@ -278,6 +279,17 @@ class PluginUploadFileRequestApi(Resource):
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump() return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
class PluginFetchAppInfoApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestFetchAppInfo)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestFetchAppInfo):
return BaseBackwardsInvocationResponse(
data=PluginAppBackwardsInvocation.fetch_app_info(payload.app_id, tenant_model.id)
).model_dump()
api.add_resource(PluginInvokeLLMApi, "/invoke/llm") api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding") api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank") api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
@ -291,3 +303,4 @@ api.add_resource(PluginInvokeAppApi, "/invoke/app")
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt") api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary") api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
api.add_resource(PluginUploadFileRequestApi, "/upload/file/request") api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")
api.add_resource(PluginFetchAppInfoApi, "/fetch/app/info")

View File

@ -1,10 +1,10 @@
from flask_restful import Resource, marshal_with # type: ignore from flask_restful import Resource, marshal_with # type: ignore
from controllers.common import fields from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token from controllers.service_api.wraps import validate_app_token
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import App, AppMode from models.model import App, AppMode
from services.app_service import AppService from services.app_service import AppService
@ -32,9 +32,7 @@ class AppParameterApi(Resource):
user_input_form = features_dict.get("user_input_form", []) user_input_form = features_dict.get("user_input_form", [])
return controller_helpers.get_parameters_from_feature_dict( return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
features_dict=features_dict, user_input_form=user_input_form
)
class AppMetaApi(Resource): class AppMetaApi(Resource):

View File

@ -1,6 +1,6 @@
import logging import logging
from datetime import datetime
from dateutil.parser import isoparse
from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -140,10 +140,10 @@ class WorkflowAppLogApi(Resource):
args.status = WorkflowRunStatus(args.status) if args.status else None args.status = WorkflowRunStatus(args.status) if args.status else None
if args.created_at__before: if args.created_at__before:
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00")) args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after: if args.created_at__after:
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00")) args.created_at__after = isoparse(args.created_at__after)
# get paginate workflow app logs # get paginate workflow app logs
workflow_app_service = WorkflowAppService() workflow_app_service = WorkflowAppService()

View File

@ -13,6 +13,7 @@ from fields.dataset_fields import dataset_detail_fields
from libs.login import current_user from libs.login import current_user
from models.dataset import Dataset, DatasetPermissionEnum from models.dataset import Dataset, DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
def _validate_name(name): def _validate_name(name):
@ -120,8 +121,11 @@ class DatasetListApi(DatasetApiResource):
nullable=True, nullable=True,
required=False, required=False,
) )
args = parser.parse_args() parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
try: try:
dataset = DatasetService.create_empty_dataset( dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -133,6 +137,11 @@ class DatasetListApi(DatasetApiResource):
provider=args["provider"], provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"], external_knowledge_api_id=args["external_knowledge_api_id"],
external_knowledge_id=args["external_knowledge_id"], external_knowledge_id=args["external_knowledge_id"],
embedding_model_provider=args["embedding_model_provider"],
embedding_model_name=args["embedding_model"],
retrieval_model=RetrievalModel(**args["retrieval_model"])
if args["retrieval_model"] is not None
else None,
) )
except services.errors.dataset.DatasetNameDuplicateError: except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError() raise DatasetNameDuplicateError()

View File

@ -49,7 +49,9 @@ class DocumentAddByTextApi(DatasetApiResource):
parser.add_argument( parser.add_argument(
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
) )
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -57,7 +59,7 @@ class DocumentAddByTextApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset is not exist.") raise ValueError("Dataset does not exist.")
if not dataset.indexing_technique and not args["indexing_technique"]: if not dataset.indexing_technique and not args["indexing_technique"]:
raise ValueError("indexing_technique is required.") raise ValueError("indexing_technique is required.")
@ -114,7 +116,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset is not exist.") raise ValueError("Dataset does not exist.")
# indexing_technique is already set in dataset since this is an update # indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique args["indexing_technique"] = dataset.indexing_technique
@ -172,7 +174,7 @@ class DocumentAddByFileApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset is not exist.") raise ValueError("Dataset does not exist.")
if not dataset.indexing_technique and not args.get("indexing_technique"): if not dataset.indexing_technique and not args.get("indexing_technique"):
raise ValueError("indexing_technique is required.") raise ValueError("indexing_technique is required.")
@ -239,7 +241,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset is not exist.") raise ValueError("Dataset does not exist.")
# indexing_technique is already set in dataset since this is an update # indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique args["indexing_technique"] = dataset.indexing_technique
@ -303,7 +305,7 @@ class DocumentDeleteApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset is not exist.") raise ValueError("Dataset does not exist.")
document = DocumentService.get_document(dataset.id, document_id) document = DocumentService.get_document(dataset.id, document_id)

View File

@ -13,18 +13,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
from services.metadata_service import MetadataService from services.metadata_service import MetadataService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class DatasetMetadataCreateServiceApi(DatasetApiResource): class DatasetMetadataCreateServiceApi(DatasetApiResource):
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()

View File

@ -117,14 +117,13 @@ class SegmentApi(DatasetApiResource):
parser.add_argument("keyword", type=str, default=None, location="args") parser.add_argument("keyword", type=str, default=None, location="args")
args = parser.parse_args() args = parser.parse_args()
status_list = args["status"]
keyword = args["keyword"]
segments, total = SegmentService.get_segments( segments, total = SegmentService.get_segments(
document_id=document_id, document_id=document_id,
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
status_list=args["status"], status_list=args["status"],
keyword=args["keyword"], keyword=args["keyword"],
page=page,
limit=limit,
) )
response = { response = {

View File

@ -1,10 +1,10 @@
from flask_restful import marshal_with # type: ignore from flask_restful import marshal_with # type: ignore
from controllers.common import fields from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.web import api from controllers.web import api
from controllers.web.error import AppUnavailableError from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import App, AppMode from models.model import App, AppMode
from services.app_service import AppService from services.app_service import AppService
@ -31,9 +31,7 @@ class AppParameterApi(WebApiResource):
user_input_form = features_dict.get("user_input_form", []) user_input_form = features_dict.get("user_input_form", [])
return controller_helpers.get_parameters_from_feature_dict( return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
features_dict=features_dict, user_input_form=user_input_form
)
class AppMeta(WebApiResource): class AppMeta(WebApiResource):

View File

@ -46,6 +46,7 @@ class MessageListApi(WebApiResource):
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField, "created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String, "status": fields.String,
"error": fields.String, "error": fields.String,
} }

View File

@ -191,7 +191,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# action is final answer, return final answer directly # action is final answer, return final answer directly
try: try:
if isinstance(scratchpad.action.action_input, dict): if isinstance(scratchpad.action.action_input, dict):
final_answer = json.dumps(scratchpad.action.action_input) final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
elif isinstance(scratchpad.action.action_input, str): elif isinstance(scratchpad.action.action_input, str):
final_answer = scratchpad.action.action_input final_answer = scratchpad.action.action_input
else: else:

View File

@ -52,6 +52,7 @@ class AgentStrategyParameter(PluginParameter):
return cast_parameter_value(self, value) return cast_parameter_value(self, value)
type: AgentStrategyParameterType = Field(..., description="The type of the parameter") type: AgentStrategyParameterType = Field(..., description="The type of the parameter")
help: Optional[I18nObject] = None
def init_frontend_parameter(self, value: Any): def init_frontend_parameter(self, value: Any):
return init_frontend_parameter(self, self.type, value) return init_frontend_parameter(self, self.type, value)

View File

@ -0,0 +1,45 @@
from collections.abc import Mapping
from typing import Any
from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
def get_parameters_from_feature_dict(
*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
) -> Mapping[str, Any]:
"""
Mapping from feature dict to webapp parameters
"""
return {
"opening_statement": features_dict.get("opening_statement"),
"suggested_questions": features_dict.get("suggested_questions", []),
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
"user_input_form": user_input_form,
"sensitive_word_avoidance": features_dict.get(
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
),
"file_upload": features_dict.get(
"file_upload",
{
"image": {
"enabled": False,
"number_limits": DEFAULT_FILE_NUMBER_LIMITS,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
},
}

View File

@ -1,6 +1,7 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any from typing import Any
from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.file import FileUploadConfig from core.file import FileUploadConfig
@ -18,7 +19,7 @@ class FileUploadConfigManager:
if file_upload_dict.get("enabled"): if file_upload_dict.get("enabled"):
transform_methods = file_upload_dict.get("allowed_file_upload_methods", []) transform_methods = file_upload_dict.get("allowed_file_upload_methods", [])
file_upload_dict["image_config"] = { file_upload_dict["image_config"] = {
"number_limits": file_upload_dict.get("number_limits", 1), "number_limits": file_upload_dict.get("number_limits", DEFAULT_FILE_NUMBER_LIMITS),
"transfer_methods": transform_methods, "transfer_methods": transform_methods,
} }

View File

@ -320,10 +320,9 @@ class AdvancedChatAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id session=session, workflow_run_id=self._workflow_run_id
) )
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event workflow_run=workflow_run, event=event
) )
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -341,11 +340,10 @@ class AdvancedChatAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id session=session, workflow_run_id=self._workflow_run_id
) )
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event workflow_run=workflow_run, event=event
) )
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response( node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -363,11 +361,10 @@ class AdvancedChatAppGenerateTaskPipeline:
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event event=event
) )
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -383,18 +380,15 @@ class AdvancedChatAppGenerateTaskPipeline:
| QueueNodeInLoopFailedEvent | QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent, | QueueNodeExceptionEvent,
): ):
with Session(db.engine, expire_on_commit=False) as session: workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( event=event
session=session, event=event )
)
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session, event=event,
event=event, task_id=self._application_generate_entity.task_id,
task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution,
workflow_node_execution=workflow_node_execution, )
)
session.commit()
if node_finish_resp: if node_finish_resp:
yield node_finish_resp yield node_finish_resp

View File

@ -17,6 +17,7 @@ class BaseAppGenerator:
user_inputs: Optional[Mapping[str, Any]], user_inputs: Optional[Mapping[str, Any]],
variables: Sequence["VariableEntity"], variables: Sequence["VariableEntity"],
tenant_id: str, tenant_id: str,
strict_type_validation: bool = False,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
user_inputs = user_inputs or {} user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values # Filter input variables from form configuration, handle required fields, default values, and option values
@ -37,6 +38,7 @@ class BaseAppGenerator:
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods, allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
), ),
strict_type_validation=strict_type_validation,
) )
for k, v in user_inputs.items() for k, v in user_inputs.items()
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE

View File

@ -153,6 +153,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
query = application_generate_entity.query or "New conversation" query = application_generate_entity.query or "New conversation"
else: else:
query = next(iter(application_generate_entity.inputs.values()), "New conversation") query = next(iter(application_generate_entity.inputs.values()), "New conversation")
query = query or "New conversation"
conversation_name = (query[:20] + "") if len(query) > 20 else query conversation_name = (query[:20] + "") if len(query) > 20 else query
if not conversation: if not conversation:

View File

@ -92,6 +92,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
mappings=files, mappings=files,
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
config=file_extra_config, config=file_extra_config,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
) )
# convert to app config # convert to app config
@ -114,7 +115,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_config=app_config, app_config=app_config,
file_upload_config=file_extra_config, file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs( inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id user_inputs=inputs,
variables=app_config.variables,
tenant_id=app_model.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
), ),
files=list(system_files), files=list(system_files),
user_id=user.id, user_id=user.id,

View File

@ -279,10 +279,9 @@ class WorkflowAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id session=session, workflow_run_id=self._workflow_run_id
) )
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event workflow_run=workflow_run, event=event
) )
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -300,10 +299,9 @@ class WorkflowAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id session=session, workflow_run_id=self._workflow_run_id
) )
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event workflow_run=workflow_run, event=event
) )
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response( node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -313,17 +311,14 @@ class WorkflowAppGenerateTaskPipeline:
if node_start_response: if node_start_response:
yield node_start_response yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent): elif isinstance(event, QueueNodeSucceededEvent):
with Session(db.engine, expire_on_commit=False) as session: workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( event=event
session=session, event=event )
) node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( event=event,
session=session, task_id=self._application_generate_entity.task_id,
event=event, workflow_node_execution=workflow_node_execution,
task_id=self._application_generate_entity.task_id, )
workflow_node_execution=workflow_node_execution,
)
session.commit()
if node_success_response: if node_success_response:
yield node_success_response yield node_success_response
@ -334,18 +329,14 @@ class WorkflowAppGenerateTaskPipeline:
| QueueNodeInLoopFailedEvent | QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent, | QueueNodeExceptionEvent,
): ):
with Session(db.engine, expire_on_commit=False) as session: workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( event=event,
session=session, )
event=event, node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
) event=event,
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( task_id=self._application_generate_entity.task_id,
session=session, workflow_node_execution=workflow_node_execution,
event=event, )
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
if node_failed_response: if node_failed_response:
yield node_failed_response yield node_failed_response
@ -627,6 +618,7 @@ class WorkflowAppGenerateTaskPipeline:
workflow_app_log.created_by = self._user_id workflow_app_log.created_by = self._user_id
session.add(workflow_app_log) session.add(workflow_app_log)
session.commit()
def _text_chunk_to_stream_response( def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None self, text: str, from_variable_selector: Optional[list[str]] = None

View File

@ -6,7 +6,7 @@ from typing import Any, Optional, Union, cast
from uuid import uuid4 from uuid import uuid4
from sqlalchemy import func, select from sqlalchemy import func, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
@ -49,12 +49,14 @@ from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.repository import RepositoryFactory
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
from models.model import EndUser from models.model import EndUser
@ -80,6 +82,21 @@ class WorkflowCycleManage:
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables self._workflow_system_variables = workflow_system_variables
# Initialize the session factory and repository
# We use the global db engine instead of the session passed to methods
# Disable expire_on_commit to avoid the need for merging objects
self._session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": self._application_generate_entity.app_config.tenant_id,
"app_id": self._application_generate_entity.app_config.app_id,
"session_factory": self._session_factory,
}
)
# We'll still keep the cache for backward compatibility and performance
# but use the repository for database operations
def _handle_workflow_run_start( def _handle_workflow_run_start(
self, self,
*, *,
@ -254,19 +271,15 @@ class WorkflowCycleManage:
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count workflow_run.exceptions_count = exceptions_count
stmt = select(WorkflowNodeExecution.node_execution_id).where( # Use the instance repository to find running executions for a workflow run
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions(
WorkflowNodeExecution.app_id == workflow_run.app_id, workflow_run_id=workflow_run.id
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
) )
ids = session.scalars(stmt).all()
# Use self._get_workflow_node_execution here to make sure the cache is updated # Update the cache with the retrieved executions
running_workflow_node_executions = [ for execution in running_workflow_node_executions:
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id if execution.node_execution_id:
] self._workflow_node_executions[execution.node_execution_id] = execution
for workflow_node_execution in running_workflow_node_executions: for workflow_node_execution in running_workflow_node_executions:
now = datetime.now(UTC).replace(tzinfo=None) now = datetime.now(UTC).replace(tzinfo=None)
@ -288,7 +301,7 @@ class WorkflowCycleManage:
return workflow_run return workflow_run
def _handle_node_execution_start( def _handle_node_execution_start(
self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
) -> WorkflowNodeExecution: ) -> WorkflowNodeExecution:
workflow_node_execution = WorkflowNodeExecution() workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.id = str(uuid4()) workflow_node_execution.id = str(uuid4())
@ -315,17 +328,14 @@ class WorkflowCycleManage:
) )
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_node_execution) # Use the instance repository to save the workflow node execution
self._workflow_node_execution_repository.save(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_success( def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
self, *, session: Session, event: QueueNodeSucceededEvent workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
) -> WorkflowNodeExecution:
workflow_node_execution = self._get_workflow_node_execution(
session=session, node_execution_id=event.node_execution_id
)
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs) outputs = WorkflowEntry.handle_special_values(event.outputs)
@ -344,13 +354,13 @@ class WorkflowCycleManage:
workflow_node_execution.finished_at = finished_at workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution = session.merge(workflow_node_execution) # Use the instance repository to update the workflow node execution
self._workflow_node_execution_repository.update(workflow_node_execution)
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_failed( def _handle_workflow_node_execution_failed(
self, self,
*, *,
session: Session,
event: QueueNodeFailedEvent event: QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent | QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent | QueueNodeInLoopFailedEvent
@ -361,9 +371,7 @@ class WorkflowCycleManage:
:param event: queue node failed event :param event: queue node failed event
:return: :return:
""" """
workflow_node_execution = self._get_workflow_node_execution( workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
session=session, node_execution_id=event.node_execution_id
)
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = WorkflowEntry.handle_special_values(event.process_data)
@ -387,14 +395,14 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_retried( def _handle_workflow_node_execution_retried(
self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution: ) -> WorkflowNodeExecution:
""" """
Workflow node execution failed Workflow node execution failed
:param workflow_run: workflow run
:param event: queue node failed event :param event: queue node failed event
:return: :return:
""" """
@ -439,15 +447,12 @@ class WorkflowCycleManage:
workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution.index = event.node_run_index workflow_node_execution.index = event.node_run_index
session.add(workflow_node_execution) # Use the instance repository to save the workflow node execution
self._workflow_node_execution_repository.save(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution return workflow_node_execution
#################################################
# to stream responses #
#################################################
def _workflow_start_to_stream_response( def _workflow_start_to_stream_response(
self, self,
*, *,
@ -455,7 +460,6 @@ class WorkflowCycleManage:
task_id: str, task_id: str,
workflow_run: WorkflowRun, workflow_run: WorkflowRun,
) -> WorkflowStartStreamResponse: ) -> WorkflowStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return WorkflowStartStreamResponse( return WorkflowStartStreamResponse(
task_id=task_id, task_id=task_id,
@ -521,14 +525,10 @@ class WorkflowCycleManage:
def _workflow_node_start_to_stream_response( def _workflow_node_start_to_stream_response(
self, self,
*, *,
session: Session,
event: QueueNodeStartedEvent, event: QueueNodeStartedEvent,
task_id: str, task_id: str,
workflow_node_execution: WorkflowNodeExecution, workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeStartStreamResponse]: ) -> Optional[NodeStartStreamResponse]:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None return None
if not workflow_node_execution.workflow_run_id: if not workflow_node_execution.workflow_run_id:
@ -571,7 +571,6 @@ class WorkflowCycleManage:
def _workflow_node_finish_to_stream_response( def _workflow_node_finish_to_stream_response(
self, self,
*, *,
session: Session,
event: QueueNodeSucceededEvent event: QueueNodeSucceededEvent
| QueueNodeFailedEvent | QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent | QueueNodeInIterationFailedEvent
@ -580,8 +579,6 @@ class WorkflowCycleManage:
task_id: str, task_id: str,
workflow_node_execution: WorkflowNodeExecution, workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]: ) -> Optional[NodeFinishStreamResponse]:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None return None
if not workflow_node_execution.workflow_run_id: if not workflow_node_execution.workflow_run_id:
@ -621,13 +618,10 @@ class WorkflowCycleManage:
def _workflow_node_retry_to_stream_response( def _workflow_node_retry_to_stream_response(
self, self,
*, *,
session: Session,
event: QueueNodeRetryEvent, event: QueueNodeRetryEvent,
task_id: str, task_id: str,
workflow_node_execution: WorkflowNodeExecution, workflow_node_execution: WorkflowNodeExecution,
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None return None
if not workflow_node_execution.workflow_run_id: if not workflow_node_execution.workflow_run_id:
@ -668,7 +662,6 @@ class WorkflowCycleManage:
def _workflow_parallel_branch_start_to_stream_response( def _workflow_parallel_branch_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse: ) -> ParallelBranchStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return ParallelBranchStartStreamResponse( return ParallelBranchStartStreamResponse(
task_id=task_id, task_id=task_id,
@ -692,7 +685,6 @@ class WorkflowCycleManage:
workflow_run: WorkflowRun, workflow_run: WorkflowRun,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse: ) -> ParallelBranchFinishedStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return ParallelBranchFinishedStreamResponse( return ParallelBranchFinishedStreamResponse(
task_id=task_id, task_id=task_id,
@ -713,7 +705,6 @@ class WorkflowCycleManage:
def _workflow_iteration_start_to_stream_response( def _workflow_iteration_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse: ) -> IterationNodeStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return IterationNodeStartStreamResponse( return IterationNodeStartStreamResponse(
task_id=task_id, task_id=task_id,
@ -735,7 +726,6 @@ class WorkflowCycleManage:
def _workflow_iteration_next_to_stream_response( def _workflow_iteration_next_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
) -> IterationNodeNextStreamResponse: ) -> IterationNodeNextStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return IterationNodeNextStreamResponse( return IterationNodeNextStreamResponse(
task_id=task_id, task_id=task_id,
@ -759,7 +749,6 @@ class WorkflowCycleManage:
def _workflow_iteration_completed_to_stream_response( def _workflow_iteration_completed_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
) -> IterationNodeCompletedStreamResponse: ) -> IterationNodeCompletedStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return IterationNodeCompletedStreamResponse( return IterationNodeCompletedStreamResponse(
task_id=task_id, task_id=task_id,
@ -790,7 +779,6 @@ class WorkflowCycleManage:
def _workflow_loop_start_to_stream_response( def _workflow_loop_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent
) -> LoopNodeStartStreamResponse: ) -> LoopNodeStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return LoopNodeStartStreamResponse( return LoopNodeStartStreamResponse(
task_id=task_id, task_id=task_id,
@ -812,7 +800,6 @@ class WorkflowCycleManage:
def _workflow_loop_next_to_stream_response( def _workflow_loop_next_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent
) -> LoopNodeNextStreamResponse: ) -> LoopNodeNextStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return LoopNodeNextStreamResponse( return LoopNodeNextStreamResponse(
task_id=task_id, task_id=task_id,
@ -836,7 +823,6 @@ class WorkflowCycleManage:
def _workflow_loop_completed_to_stream_response( def _workflow_loop_completed_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent
) -> LoopNodeCompletedStreamResponse: ) -> LoopNodeCompletedStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return LoopNodeCompletedStreamResponse( return LoopNodeCompletedStreamResponse(
task_id=task_id, task_id=task_id,
@ -934,11 +920,22 @@ class WorkflowCycleManage:
return workflow_run return workflow_run
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
if node_execution_id not in self._workflow_node_executions: # First check the cache for performance
if node_execution_id in self._workflow_node_executions:
cached_execution = self._workflow_node_executions[node_execution_id]
# No need to merge with session since expire_on_commit=False
return cached_execution
# If not in cache, use the instance repository to get by node_execution_id
execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id)
if not execution:
raise ValueError(f"Workflow node execution not found: {node_execution_id}") raise ValueError(f"Workflow node execution not found: {node_execution_id}")
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
return session.merge(cached_workflow_node_execution) # Update cache
self._workflow_node_executions[node_execution_id] = execution
return execution
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
""" """

View File

@ -6,7 +6,6 @@ from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from models.model import DatasetRetrieverResource
class DatasetIndexToolCallbackHandler: class DatasetIndexToolCallbackHandler:
@ -71,29 +70,6 @@ class DatasetIndexToolCallbackHandler:
def return_retriever_resource_info(self, resource: list): def return_retriever_resource_info(self, resource: list):
"""Handle return_retriever_resource_info.""" """Handle return_retriever_resource_info."""
if resource and len(resource) > 0:
for item in resource:
dataset_retriever_resource = DatasetRetrieverResource(
message_id=self._message_id,
position=item.get("position") or 0,
dataset_id=item.get("dataset_id"),
dataset_name=item.get("dataset_name"),
document_id=item.get("document_id"),
document_name=item.get("document_name"),
data_source_type=item.get("data_source_type"),
segment_id=item.get("segment_id"),
score=item.get("score") if "score" in item else None,
hit_count=item.get("hit_count") if "hit_count" in item else None,
word_count=item.get("word_count") if "word_count" in item else None,
segment_position=item.get("segment_position") if "segment_position" in item else None,
index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
content=item.get("content"),
retriever_from=item.get("retriever_from"),
created_by=self._user_id,
)
db.session.add(dataset_retriever_resource)
db.session.commit()
self._queue_manager.publish( self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
) )

View File

@ -48,25 +48,26 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT, write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
) )
if "ssl_verify" not in kwargs:
kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY
ssl_verify = kwargs.pop("ssl_verify")
retries = 0 retries = 0
while retries <= max_retries: while retries <= max_retries:
try: try:
if dify_config.SSRF_PROXY_ALL_URL: if dify_config.SSRF_PROXY_ALL_URL:
with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client: with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=ssl_verify) as client:
response = client.request(method=method, url=url, **kwargs) response = client.request(method=method, url=url, **kwargs)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = { proxy_mounts = {
"http://": httpx.HTTPTransport( "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=ssl_verify),
proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY "https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=ssl_verify),
),
"https://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
),
} }
with httpx.Client(mounts=proxy_mounts, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client: with httpx.Client(mounts=proxy_mounts, verify=ssl_verify) as client:
response = client.request(method=method, url=url, **kwargs) response = client.request(method=method, url=url, **kwargs)
else: else:
with httpx.Client(verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client: with httpx.Client(verify=ssl_verify) as client:
response = client.request(method=method, url=url, **kwargs) response = client.request(method=method, url=url, **kwargs)
if response.status_code not in STATUS_FORCELIST: if response.status_code not in STATUS_FORCELIST:

View File

@ -44,6 +44,7 @@ class TokenBufferMemory:
Message.created_at, Message.created_at,
Message.workflow_run_id, Message.workflow_run_id,
Message.parent_message_id, Message.parent_message_id,
Message.answer_tokens,
) )
.filter( .filter(
Message.conversation_id == self.conversation.id, Message.conversation_id == self.conversation.id,
@ -63,7 +64,7 @@ class TokenBufferMemory:
thread_messages = extract_thread_messages(messages) thread_messages = extract_thread_messages(messages)
# for newly created message, its answer is temporarily empty, we don't need to add it to memory # for newly created message, its answer is temporarily empty, we don't need to add it to memory
if thread_messages and not thread_messages[0].answer: if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
thread_messages.pop(0) thread_messages.pop(0)
messages = list(reversed(thread_messages)) messages = list(reversed(thread_messages))

View File

@ -177,7 +177,7 @@ class ModelInstance:
) )
def get_llm_num_tokens( def get_llm_num_tokens(
self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None self, prompt_messages: Sequence[PromptMessage], tools: Optional[Sequence[PromptMessageTool]] = None
) -> int: ) -> int:
""" """
Get number of tokens for llm Get number of tokens for llm

View File

@ -10,7 +10,7 @@
- 支持 5 种模型类型的能力调用 - 支持 5 种模型类型的能力调用
- `LLM` - LLM 文本补全、对话,预计算 tokens 能力 - `LLM` - LLM 文本补全、对话,预计算 tokens 能力
- `Text Embedding Model` - 文本 Embedding ,预计算 tokens 能力 - `Text Embedding Model` - 文本 Embedding预计算 tokens 能力
- `Rerank Model` - 分段 Rerank 能力 - `Rerank Model` - 分段 Rerank 能力
- `Speech-to-text Model` - 语音转文本能力 - `Speech-to-text Model` - 语音转文本能力
- `Text-to-speech Model` - 文本转语音能力 - `Text-to-speech Model` - 文本转语音能力
@ -57,11 +57,11 @@ Model Runtime 分三层:
提供获取当前供应商模型列表、获取模型实例、供应商凭据鉴权、供应商配置规则信息,**可横向扩展**以支持不同的供应商。 提供获取当前供应商模型列表、获取模型实例、供应商凭据鉴权、供应商配置规则信息,**可横向扩展**以支持不同的供应商。
对于供应商/模型凭据,有两种情况 对于供应商/模型凭据,有两种情况
- 如OpenAI这类中心化供应商需要定义如**api_key**这类的鉴权凭据 - 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。 - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
![Alt text](docs/zh_Hans/images/index/image.png) ![Alt text](docs/zh_Hans/images/index/image.png)
当配置好凭据后就可以通过DifyRuntime的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。 当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。
- 最底层为模型层 - 最底层为模型层
@ -69,9 +69,9 @@ Model Runtime 分三层:
在这里我们需要先区分模型参数与模型凭据。 在这里我们需要先区分模型参数与模型凭据。
- 模型参数(**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等这些参数是由用户在前端页面上进行调整的因此需要在后端定义参数的规则以便前端页面进行展示和调整。在DifyRuntime中他们的参数名一般为**model_parameters: dict[str, any]**。 - 模型参数 (**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等,这些参数是由用户在前端页面上进行调整的,因此需要在后端定义参数的规则,以便前端页面进行展示和调整。在 DifyRuntime 中,他们的参数名一般为**model_parameters: dict[str, any]**。
- 模型凭据(**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在DifyRuntime中他们的参数名一般为**credentials: dict[str, any]**Provider层的credentials会直接被传递到这一层不需要再单独定义。 - 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。
## 下一步 ## 下一步
@ -81,7 +81,7 @@ Model Runtime 分三层:
![Alt text](docs/zh_Hans/images/index/image-1.png) ![Alt text](docs/zh_Hans/images/index/image-1.png)
### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#增加模型) ### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#增加模型)
当添加后对应供应商的模型列表中将会出现一个新的预定义模型供用户选择如GPT-3.5 GPT-4 ChatGLM3-6b等而对于支持自定义模型的供应商则不需要新增模型。 当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如 GPT-3.5 GPT-4 ChatGLM3-6b 等,而对于支持自定义模型的供应商,则不需要新增模型。
![Alt text](docs/zh_Hans/images/index/image-2.png) ![Alt text](docs/zh_Hans/images/index/image-2.png)

View File

@ -58,7 +58,7 @@ class Callback(ABC):
chunk: LLMResultChunk, chunk: LLMResultChunk,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
@ -88,7 +88,7 @@ class Callback(ABC):
result: LLMResult, result: LLMResult,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,

View File

@ -74,7 +74,7 @@ class LoggingCallback(Callback):
chunk: LLMResultChunk, chunk: LLMResultChunk,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
@ -104,7 +104,7 @@ class LoggingCallback(Callback):
result: LLMResult, result: LLMResult,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,

View File

@ -102,12 +102,12 @@ provider_credential_schema:
```yaml ```yaml
- variable: server_url - variable: server_url
label: label:
zh_Hans: 服务器URL zh_Hans: 服务器 URL
en_US: Server url en_US: Server url
type: text-input type: text-input
required: true required: true
placeholder: placeholder:
zh_Hans: 在此输入Xinference的服务器地址如 https://example.com/xxx zh_Hans: 在此输入 Xinference 的服务器地址,如 https://example.com/xxx
en_US: Enter the url of your Xinference, for example https://example.com/xxx en_US: Enter the url of your Xinference, for example https://example.com/xxx
``` ```
@ -116,12 +116,12 @@ provider_credential_schema:
```yaml ```yaml
- variable: model_uid - variable: model_uid
label: label:
zh_Hans: 模型UID zh_Hans: 模型 UID
en_US: Model uid en_US: Model uid
type: text-input type: text-input
required: true required: true
placeholder: placeholder:
zh_Hans: 在此输入您的Model UID zh_Hans: 在此输入您的 Model UID
en_US: Enter the model uid en_US: Enter the model uid
``` ```

View File

@ -367,7 +367,7 @@ Inherit the `__base.text2speech_model.Text2SpeechModel` base class and implement
- Returns - Returns
Text converted speech stream Text converted speech stream.
### Moderation ### Moderation

View File

@ -6,14 +6,14 @@
需要注意的是,对于自定义模型,每一个模型的接入都需要填写一个完整的供应商凭据。 需要注意的是,对于自定义模型,每一个模型的接入都需要填写一个完整的供应商凭据。
而不同于预定义模型自定义供应商接入时永远会拥有如下两个参数不需要在供应商yaml中定义。 而不同于预定义模型,自定义供应商接入时永远会拥有如下两个参数,不需要在供应商 yaml 中定义。
![Alt text](images/index/image-3.png) ![Alt text](images/index/image-3.png)
在前文中,我们已经知道了供应商无需实现`validate_provider_credential`Runtime会自行根据用户在此选择的模型类型和模型名称调用对应的模型层的`validate_credentials`来进行验证。 在前文中,我们已经知道了供应商无需实现`validate_provider_credential`Runtime 会自行根据用户在此选择的模型类型和模型名称调用对应的模型层的`validate_credentials`来进行验证。
### 编写供应商yaml ### 编写供应商 yaml
我们首先要确定,接入的这个供应商支持哪些类型的模型。 我们首先要确定,接入的这个供应商支持哪些类型的模型。
@ -26,7 +26,7 @@
- `tts` 文字转语音 - `tts` 文字转语音
- `moderation` 审查 - `moderation` 审查
`Xinference`支持`LLM`和`Text Embedding`和Rerank那么我们开始编写`xinference.yaml`。 `Xinference`支持`LLM`和`Text Embedding`和 Rerank那么我们开始编写`xinference.yaml`。
```yaml ```yaml
provider: xinference #确定供应商标识 provider: xinference #确定供应商标识
@ -42,17 +42,17 @@ help: # 帮助
zh_Hans: 如何部署 Xinference zh_Hans: 如何部署 Xinference
url: url:
en_US: https://github.com/xorbitsai/inference en_US: https://github.com/xorbitsai/inference
supported_model_types: # 支持的模型类型Xinference同时支持LLM/Text Embedding/Rerank supported_model_types: # 支持的模型类型Xinference 同时支持 LLM/Text Embedding/Rerank
- llm - llm
- text-embedding - text-embedding
- rerank - rerank
configurate_methods: # 因为Xinference为本地部署的供应商并且没有预定义模型需要用什么模型需要根据Xinference的文档自己部署所以这里只支持自定义模型 configurate_methods: # 因为 Xinference 为本地部署的供应商,并且没有预定义模型,需要用什么模型需要根据 Xinference 的文档自己部署,所以这里只支持自定义模型
- customizable-model - customizable-model
provider_credential_schema: provider_credential_schema:
credential_form_schemas: credential_form_schemas:
``` ```
随后我们需要思考在Xinference中定义一个模型需要哪些凭据 随后,我们需要思考在 Xinference 中定义一个模型需要哪些凭据
- 它支持三种不同的模型,因此,我们需要有`model_type`来指定这个模型的类型,它有三种类型,所以我们这么编写 - 它支持三种不同的模型,因此,我们需要有`model_type`来指定这个模型的类型,它有三种类型,所以我们这么编写
```yaml ```yaml
@ -88,28 +88,28 @@ provider_credential_schema:
zh_Hans: 填写模型名称 zh_Hans: 填写模型名称
en_US: Input model name en_US: Input model name
``` ```
- 填写Xinference本地部署的地址 - 填写 Xinference 本地部署的地址
```yaml ```yaml
- variable: server_url - variable: server_url
label: label:
zh_Hans: 服务器URL zh_Hans: 服务器 URL
en_US: Server url en_US: Server url
type: text-input type: text-input
required: true required: true
placeholder: placeholder:
zh_Hans: 在此输入Xinference的服务器地址如 https://example.com/xxx zh_Hans: 在此输入 Xinference 的服务器地址,如 https://example.com/xxx
en_US: Enter the url of your Xinference, for example https://example.com/xxx en_US: Enter the url of your Xinference, for example https://example.com/xxx
``` ```
- 每个模型都有唯一的model_uid因此需要在这里定义 - 每个模型都有唯一的 model_uid因此需要在这里定义
```yaml ```yaml
- variable: model_uid - variable: model_uid
label: label:
zh_Hans: 模型UID zh_Hans: 模型 UID
en_US: Model uid en_US: Model uid
type: text-input type: text-input
required: true required: true
placeholder: placeholder:
zh_Hans: 在此输入您的Model UID zh_Hans: 在此输入您的 Model UID
en_US: Enter the model uid en_US: Enter the model uid
``` ```
现在,我们就完成了供应商的基础定义。 现在,我们就完成了供应商的基础定义。
@ -145,7 +145,7 @@ provider_credential_schema:
""" """
``` ```
在实现时需要注意使用两个函数来返回数据分别用于处理同步返回和流式返回因为Python会将函数中包含 `yield` 关键字的函数识别为生成器函数,返回的数据类型固定为 `Generator`,因此同步和流式返回需要分别实现,就像下面这样(注意下面例子使用了简化参数,实际实现时需要按照上面的参数列表进行实现): 在实现时,需要注意使用两个函数来返回数据,分别用于处理同步返回和流式返回,因为 Python 会将函数中包含 `yield` 关键字的函数识别为生成器函数,返回的数据类型固定为 `Generator`,因此同步和流式返回需要分别实现,就像下面这样(注意下面例子使用了简化参数,实际实现时需要按照上面的参数列表进行实现):
```python ```python
def _invoke(self, stream: bool, **kwargs) \ def _invoke(self, stream: bool, **kwargs) \
@ -179,7 +179,7 @@ provider_credential_schema:
""" """
``` ```
有时候也许你不需要直接返回0所以你可以使用`self._get_num_tokens_by_gpt2(text: str)`来获取预计算的tokens并确保环境变量`PLUGIN_BASED_TOKEN_COUNTING_ENABLED`设置为`true`,这个方法位于`AIModel`基类中它会使用GPT2的Tokenizer进行计算但是只能作为替代方法并不完全准确。 有时候,也许你不需要直接返回 0所以你可以使用`self._get_num_tokens_by_gpt2(text: str)`来获取预计算的 tokens并确保环境变量`PLUGIN_BASED_TOKEN_COUNTING_ENABLED`设置为`true`,这个方法位于`AIModel`基类中,它会使用 GPT2 Tokenizer 进行计算,但是只能作为替代方法,并不完全准确。
- 模型凭据校验 - 模型凭据校验
@ -196,13 +196,13 @@ provider_credential_schema:
""" """
``` ```
- 模型参数Schema - 模型参数 Schema
与自定义类型不同由于没有在yaml文件中定义一个模型支持哪些参数因此我们需要动态时间模型参数的Schema。 与自定义类型不同,由于没有在 yaml 文件中定义一个模型支持哪些参数,因此,我们需要动态时间模型参数的 Schema。
如Xinference支持`max_tokens` `temperature` `top_p` 这三个模型参数。 Xinference 支持`max_tokens` `temperature` `top_p` 这三个模型参数。
但是有的供应商根据不同的模型支持不同的参数,如供应商`OpenLLM`支持`top_k`,但是并不是这个供应商提供的所有模型都支持`top_k`我们这里举例A模型支持`top_k`B模型不支持`top_k`那么我们需要在这里动态生成模型参数的Schema如下所示 但是有的供应商根据不同的模型支持不同的参数,如供应商`OpenLLM`支持`top_k`,但是并不是这个供应商提供的所有模型都支持`top_k`,我们这里举例 A 模型支持`top_k`B 模型不支持`top_k`,那么我们需要在这里动态生成模型参数的 Schema如下所示
```python ```python
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:

View File

@ -687,7 +687,7 @@ class LLMUsage(ModelUsage):
total_tokens: int # 总使用 token 数 total_tokens: int # 总使用 token 数
total_price: Decimal # 总费用 total_price: Decimal # 总费用
currency: str # 货币单位 currency: str # 货币单位
latency: float # 请求耗时(s) latency: float # 请求耗时 (s)
``` ```
--- ---
@ -717,7 +717,7 @@ class EmbeddingUsage(ModelUsage):
price_unit: Decimal # 价格单位,即单价基于多少 tokens price_unit: Decimal # 价格单位,即单价基于多少 tokens
total_price: Decimal # 总费用 total_price: Decimal # 总费用
currency: str # 货币单位 currency: str # 货币单位
latency: float # 请求耗时(s) latency: float # 请求耗时 (s)
``` ```
--- ---

View File

@ -95,7 +95,7 @@ pricing: # 价格信息
""" """
``` ```
在实现时需要注意使用两个函数来返回数据分别用于处理同步返回和流式返回因为Python会将函数中包含 `yield` 关键字的函数识别为生成器函数,返回的数据类型固定为 `Generator`,因此同步和流式返回需要分别实现,就像下面这样(注意下面例子使用了简化参数,实际实现时需要按照上面的参数列表进行实现): 在实现时,需要注意使用两个函数来返回数据,分别用于处理同步返回和流式返回,因为 Python 会将函数中包含 `yield` 关键字的函数识别为生成器函数,返回的数据类型固定为 `Generator`,因此同步和流式返回需要分别实现,就像下面这样(注意下面例子使用了简化参数,实际实现时需要按照上面的参数列表进行实现):
```python ```python
def _invoke(self, stream: bool, **kwargs) \ def _invoke(self, stream: bool, **kwargs) \

View File

@ -8,13 +8,13 @@
- `customizable-model` 自定义模型 - `customizable-model` 自定义模型
用户需要新增每个模型的凭据配置如Xinference它同时支持 LLM 和 Text Embedding但是每个模型都有唯一的**model_uid**,如果想要将两者同时接入,就需要为每个模型配置一个**model_uid**。 用户需要新增每个模型的凭据配置,如 Xinference它同时支持 LLM 和 Text Embedding但是每个模型都有唯一的**model_uid**,如果想要将两者同时接入,就需要为每个模型配置一个**model_uid**。
- `fetch-from-remote` 从远程获取 - `fetch-from-remote` 从远程获取
`predefined-model` 配置方式一致,只需要配置统一的供应商凭据即可,模型通过凭据信息从供应商获取。 `predefined-model` 配置方式一致,只需要配置统一的供应商凭据即可,模型通过凭据信息从供应商获取。
如OpenAI我们可以基于gpt-turbo-3.5来Fine Tune多个模型而他们都位于同一个**api_key**下,当配置为 `fetch-from-remote` 时,开发者只需要配置统一的**api_key**即可让DifyRuntime获取到开发者所有的微调模型并接入Dify。 OpenAI我们可以基于 gpt-turbo-3.5 Fine Tune 多个模型,而他们都位于同一个**api_key**下,当配置为 `fetch-from-remote` 时,开发者只需要配置统一的**api_key**即可让 DifyRuntime 获取到开发者所有的微调模型并接入 Dify。
这三种配置方式**支持共存**,即存在供应商支持 `predefined-model` + `customizable-model``predefined-model` + `fetch-from-remote` 等,也就是配置了供应商统一凭据可以使用预定义模型和从远程获取的模型,若新增了模型,则可以在此基础上额外使用自定义的模型。 这三种配置方式**支持共存**,即存在供应商支持 `predefined-model` + `customizable-model``predefined-model` + `fetch-from-remote` 等,也就是配置了供应商统一凭据可以使用预定义模型和从远程获取的模型,若新增了模型,则可以在此基础上额外使用自定义的模型。
@ -23,16 +23,16 @@
### 介绍 ### 介绍
#### 名词解释 #### 名词解释
- `module`: 一个`module`即为一个Python Package或者通俗一点称为一个文件夹里面包含了一个`__init__.py`文件,以及其他的`.py`文件。 - `module`: 一个`module`即为一个 Python Package或者通俗一点称为一个文件夹里面包含了一个`__init__.py`文件,以及其他的`.py`文件。
#### 步骤 #### 步骤
新增一个供应商主要分为几步,这里简单列出,帮助大家有一个大概的认识,具体的步骤会在下面详细介绍。 新增一个供应商主要分为几步,这里简单列出,帮助大家有一个大概的认识,具体的步骤会在下面详细介绍。
- 创建供应商yaml文件根据[ProviderSchema](./schema.md#provider)编写 - 创建供应商 yaml 文件,根据[ProviderSchema](./schema.md#provider)编写
- 创建供应商代码,实现一个`class`。 - 创建供应商代码,实现一个`class`。
- 根据模型类型,在供应商`module`下创建对应的模型类型 `module`,如`llm`或`text_embedding`。 - 根据模型类型,在供应商`module`下创建对应的模型类型 `module`,如`llm`或`text_embedding`。
- 根据模型类型,在对应的模型`module`下创建同名的代码文件,如`llm.py`,并实现一个`class`。 - 根据模型类型,在对应的模型`module`下创建同名的代码文件,如`llm.py`,并实现一个`class`。
- 如果有预定义模型根据模型名称创建同名的yaml文件在模型`module`下,如`claude-2.1.yaml`,根据[AIModelEntity](./schema.md#aimodelentity)编写。 - 如果有预定义模型,根据模型名称创建同名的 yaml 文件在模型`module`下,如`claude-2.1.yaml`,根据[AIModelEntity](./schema.md#aimodelentity)编写。
- 编写测试代码,确保功能可用。 - 编写测试代码,确保功能可用。
### 开始吧 ### 开始吧
@ -121,11 +121,11 @@ model_credential_schema:
#### 实现供应商代码 #### 实现供应商代码
我们需要在`model_providers`下创建一个同名的python文件如`anthropic.py`,并实现一个`class`,继承`__base.provider.Provider`基类,如`AnthropicProvider`。 我们需要在`model_providers`下创建一个同名的 python 文件,如`anthropic.py`,并实现一个`class`,继承`__base.provider.Provider`基类,如`AnthropicProvider`。
##### 自定义模型供应商 ##### 自定义模型供应商
当供应商为Xinference等自定义模型供应商时可跳过该步骤仅创建一个空的`XinferenceProvider`类即可,并实现一个空的`validate_provider_credentials`方法,该方法并不会被实际使用,仅用作避免抽象类无法实例化。 当供应商为 Xinference 等自定义模型供应商时,可跳过该步骤,仅创建一个空的`XinferenceProvider`类即可,并实现一个空的`validate_provider_credentials`方法,该方法并不会被实际使用,仅用作避免抽象类无法实例化。
```python ```python
class XinferenceProvider(Provider): class XinferenceProvider(Provider):
@ -155,7 +155,7 @@ def validate_provider_credentials(self, credentials: dict) -> None:
#### 增加模型 #### 增加模型
#### [增加预定义模型 👈🏻](./predefined_model_scale_out.md) #### [增加预定义模型 👈🏻](./predefined_model_scale_out.md)
对于预定义模型我们可以通过简单定义一个yaml并通过实现调用代码来接入。 对于预定义模型,我们可以通过简单定义一个 yaml并通过实现调用代码来接入。
#### [增加自定义模型 👈🏻](./customizable_model_scale_out.md) #### [增加自定义模型 👈🏻](./customizable_model_scale_out.md)
对于自定义模型,我们只需要实现调用代码即可接入,但是它需要处理的参数可能会更加复杂。 对于自定义模型,我们只需要实现调用代码即可接入,但是它需要处理的参数可能会更加复杂。

View File

@ -29,7 +29,7 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
"help": { "help": {
"en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options" "en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options"
" are considered.", " are considered.",
"zh_Hans": "通过核心采样控制多样性0.5表示考虑了一半的所有可能性加权选项。", "zh_Hans": "通过核心采样控制多样性0.5 表示考虑了一半的所有可能性加权选项。",
}, },
"required": False, "required": False,
"default": 1.0, "default": 1.0,
@ -111,7 +111,7 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
"help": { "help": {
"en_US": "Set a response format, ensure the output from llm is a valid code block as possible," "en_US": "Set a response format, ensure the output from llm is a valid code block as possible,"
" such as JSON, XML, etc.", " such as JSON, XML, etc.",
"zh_Hans": "设置一个返回格式,确保llm的输出尽可能是有效的代码块如JSON、XML等", "zh_Hans": "设置一个返回格式,确保 llm 的输出尽可能是有效的代码块,如 JSON、XML ",
}, },
"required": False, "required": False,
"options": ["JSON", "XML"], "options": ["JSON", "XML"],
@ -123,7 +123,7 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
"type": "text", "type": "text",
"help": { "help": {
"en_US": "Set a response json schema will ensure LLM to adhere it.", "en_US": "Set a response json schema will ensure LLM to adhere it.",
"zh_Hans": "设置返回的json schemallm将按照它返回", "zh_Hans": "设置返回的 json schemallm 将按照它返回",
}, },
"required": False, "required": False,
}, },

View File

@ -1,8 +1,9 @@
from collections.abc import Sequence
from decimal import Decimal from decimal import Decimal
from enum import StrEnum from enum import StrEnum
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel, Field
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo
@ -107,7 +108,7 @@ class LLMResult(BaseModel):
id: Optional[str] = None id: Optional[str] = None
model: str model: str
prompt_messages: list[PromptMessage] prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
message: AssistantPromptMessage message: AssistantPromptMessage
usage: LLMUsage usage: LLMUsage
system_fingerprint: Optional[str] = None system_fingerprint: Optional[str] = None
@ -130,7 +131,7 @@ class LLMResultChunk(BaseModel):
""" """
model: str model: str
prompt_messages: list[PromptMessage] prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
system_fingerprint: Optional[str] = None system_fingerprint: Optional[str] = None
delta: LLMResultChunkDelta delta: LLMResultChunkDelta

View File

@ -1,5 +1,6 @@
import logging import logging
import time import time
import uuid
from collections.abc import Generator, Sequence from collections.abc import Generator, Sequence
from typing import Optional, Union from typing import Optional, Union
@ -24,6 +25,58 @@ from core.plugin.manager.model import PluginModelManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _gen_tool_call_id() -> str:
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
def _increase_tool_call(
new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
):
"""
Merge incremental tool call updates into existing tool calls.
:param new_tool_calls: List of new tool call deltas to be merged.
:param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
"""
def get_tool_call(tool_call_id: str):
"""
Get or create a tool call by ID
:param tool_call_id: tool call ID
:return: existing or new tool call
"""
if not tool_call_id:
return existing_tools_calls[-1]
_tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None)
if _tool_call is None:
_tool_call = AssistantPromptMessage.ToolCall(
id=tool_call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
)
existing_tools_calls.append(_tool_call)
return _tool_call
for new_tool_call in new_tool_calls:
# generate ID for tool calls with function name but no ID to track them
if new_tool_call.function.name and not new_tool_call.id:
new_tool_call.id = _gen_tool_call_id()
# get tool call
tool_call = get_tool_call(new_tool_call.id)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
class LargeLanguageModel(AIModel): class LargeLanguageModel(AIModel):
""" """
Model class for large language model. Model class for large language model.
@ -45,7 +98,7 @@ class LargeLanguageModel(AIModel):
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) -> Union[LLMResult, Generator]: ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
""" """
Invoke large language model Invoke large language model
@ -109,44 +162,13 @@ class LargeLanguageModel(AIModel):
system_fingerprint = None system_fingerprint = None
tools_calls: list[AssistantPromptMessage.ToolCall] = [] tools_calls: list[AssistantPromptMessage.ToolCall] = []
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
def get_tool_call(tool_name: str):
if not tool_name:
return tools_calls[-1]
tool_call = next(
(tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None
)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id="",
type="",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
)
tools_calls.append(tool_call)
return tool_call
for new_tool_call in new_tool_calls:
# get tool call
tool_call = get_tool_call(new_tool_call.function.name)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
for chunk in result: for chunk in result:
if isinstance(chunk.delta.message.content, str): if isinstance(chunk.delta.message.content, str):
content += chunk.delta.message.content content += chunk.delta.message.content
elif isinstance(chunk.delta.message.content, list): elif isinstance(chunk.delta.message.content, list):
content_list.extend(chunk.delta.message.content) content_list.extend(chunk.delta.message.content)
if chunk.delta.message.tool_calls: if chunk.delta.message.tool_calls:
increase_tool_call(chunk.delta.message.tool_calls) _increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
usage = chunk.delta.usage or LLMUsage.empty_usage() usage = chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = chunk.system_fingerprint system_fingerprint = chunk.system_fingerprint
@ -205,22 +227,26 @@ class LargeLanguageModel(AIModel):
user=user, user=user,
callbacks=callbacks, callbacks=callbacks,
) )
# Following https://github.com/langgenius/dify/issues/17799,
return result # we removed the prompt_messages from the chunk on the plugin daemon side.
# To ensure compatibility, we add the prompt_messages back here.
result.prompt_messages = prompt_messages
return result
raise NotImplementedError("unsupported invoke result type", type(result))
def _invoke_result_generator( def _invoke_result_generator(
self, self,
model: str, model: str,
result: Generator, result: Generator,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) -> Generator: ) -> Generator[LLMResultChunk, None, None]:
""" """
Invoke result generator Invoke result generator
@ -235,6 +261,10 @@ class LargeLanguageModel(AIModel):
try: try:
for chunk in result: for chunk in result:
# Following https://github.com/langgenius/dify/issues/17799,
# we removed the prompt_messages from the chunk on the plugin daemon side.
# To ensure compatibility, we add the prompt_messages back here.
chunk.prompt_messages = prompt_messages
yield chunk yield chunk
self._trigger_new_chunk_callbacks( self._trigger_new_chunk_callbacks(
@ -403,7 +433,7 @@ class LargeLanguageModel(AIModel):
chunk: LLMResultChunk, chunk: LLMResultChunk,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
@ -450,7 +480,7 @@ class LargeLanguageModel(AIModel):
model: str, model: str,
result: LLMResult, result: LLMResult,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,

View File

@ -5,6 +5,7 @@ from datetime import datetime, timedelta
from typing import Optional from typing import Optional
from langfuse import Langfuse # type: ignore from langfuse import Langfuse # type: ignore
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.config_entity import LangfuseConfig
@ -28,9 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
UnitEnum, UnitEnum,
) )
from core.ops.utils import filter_none_values from core.ops.utils import filter_none_values
from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.model import EndUser from models.model import EndUser
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -110,36 +111,18 @@ class LangFuseDataTrace(BaseTraceInstance):
) )
self.add_trace(langfuse_trace_data=trace_data) self.add_trace(langfuse_trace_data=trace_data)
# through workflow_run_id get all_nodes_execution # through workflow_run_id get all_nodes_execution using repository
workflow_nodes_execution_id_records = ( session_factory = sessionmaker(bind=db.engine)
db.session.query(WorkflowNodeExecution.id) workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) params={"tenant_id": trace_info.tenant_id, "session_factory": session_factory},
.all()
) )
for node_execution_id_record in workflow_nodes_execution_id_records: # Get all executions for this workflow run
node_execution = ( workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
db.session.query( workflow_run_id=trace_info.workflow_run_id
WorkflowNodeExecution.id, )
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
)
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
)
if not node_execution:
continue
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id tenant_id = node_execution.tenant_id
app_id = node_execution.app_id app_id = node_execution.app_id

View File

@ -7,6 +7,7 @@ from typing import Optional, cast
from langsmith import Client from langsmith import Client
from langsmith.schemas import RunBase from langsmith.schemas import RunBase
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.config_entity import LangSmithConfig
@ -27,9 +28,9 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunUpdateModel, LangSmithRunUpdateModel,
) )
from core.ops.utils import filter_none_values, generate_dotted_order from core.ops.utils import filter_none_values, generate_dotted_order
from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.model import EndUser, MessageFile from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -134,36 +135,22 @@ class LangSmithDataTrace(BaseTraceInstance):
self.add_run(langsmith_run) self.add_run(langsmith_run)
# through workflow_run_id get all_nodes_execution # through workflow_run_id get all_nodes_execution using repository
workflow_nodes_execution_id_records = ( session_factory = sessionmaker(bind=db.engine)
db.session.query(WorkflowNodeExecution.id) workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) params={
.all() "tenant_id": trace_info.tenant_id,
"app_id": trace_info.metadata.get("app_id"),
"session_factory": session_factory,
},
) )
for node_execution_id_record in workflow_nodes_execution_id_records: # Get all executions for this workflow run
node_execution = ( workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
db.session.query( workflow_run_id=trace_info.workflow_run_id
WorkflowNodeExecution.id, )
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
)
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
)
if not node_execution:
continue
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id tenant_id = node_execution.tenant_id
app_id = node_execution.app_id app_id = node_execution.app_id

View File

@ -7,6 +7,7 @@ from typing import Optional, cast
from opik import Opik, Trace from opik import Opik, Trace
from opik.id_helpers import uuid4_to_uuid7 from opik.id_helpers import uuid4_to_uuid7
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import OpikConfig from core.ops.entities.config_entity import OpikConfig
@ -21,9 +22,9 @@ from core.ops.entities.trace_entity import (
TraceTaskName, TraceTaskName,
WorkflowTraceInfo, WorkflowTraceInfo,
) )
from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.model import EndUser, MessageFile from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -147,36 +148,22 @@ class OpikDataTrace(BaseTraceInstance):
} }
self.add_trace(trace_data) self.add_trace(trace_data)
# through workflow_run_id get all_nodes_execution # through workflow_run_id get all_nodes_execution using repository
workflow_nodes_execution_id_records = ( session_factory = sessionmaker(bind=db.engine)
db.session.query(WorkflowNodeExecution.id) workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) params={
.all() "tenant_id": trace_info.tenant_id,
"app_id": trace_info.metadata.get("app_id"),
"session_factory": session_factory,
},
) )
for node_execution_id_record in workflow_nodes_execution_id_records: # Get all executions for this workflow run
node_execution = ( workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
db.session.query( workflow_run_id=trace_info.workflow_run_id
WorkflowNodeExecution.id, )
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
)
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
)
if not node_execution:
continue
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id tenant_id = node_execution.tenant_id
app_id = node_execution.app_id app_id = node_execution.app_id

View File

@ -453,7 +453,7 @@ class TraceTask:
"version": workflow_run_version, "version": workflow_run_version,
"total_tokens": total_tokens, "total_tokens": total_tokens,
"file_list": file_list, "file_list": file_list,
"triggered_form": workflow_run.triggered_from, "triggered_from": workflow_run.triggered_from,
"user_id": user_id, "user_id": user_id,
} }

View File

@ -2,6 +2,7 @@ from collections.abc import Generator, Mapping
from typing import Optional, Union from typing import Optional, Union
from controllers.service_api.wraps import create_or_update_end_user_for_user_id from controllers.service_api.wraps import create_or_update_end_user_for_user_id
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.chat.app_generator import ChatAppGenerator from core.app.apps.chat.app_generator import ChatAppGenerator
@ -15,6 +16,34 @@ from models.model import App, AppMode, EndUser
class PluginAppBackwardsInvocation(BaseBackwardsInvocation): class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
@classmethod
def fetch_app_info(cls, app_id: str, tenant_id: str) -> Mapping:
"""
Fetch app info
"""
app = cls._get_app(app_id, tenant_id)
"""Retrieve app parameters."""
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app.workflow
if workflow is None:
raise ValueError("unexpected app type")
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app.app_model_config
if app_model_config is None:
raise ValueError("unexpected app type")
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
return {
"data": get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form),
}
@classmethod @classmethod
def invoke_app( def invoke_app(
cls, cls,

View File

@ -131,7 +131,7 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
raise ValueError("The selector must be a dictionary.") raise ValueError("The selector must be a dictionary.")
return value return value
case PluginParameterType.TOOLS_SELECTOR: case PluginParameterType.TOOLS_SELECTOR:
if not isinstance(value, list): if value and not isinstance(value, list):
raise ValueError("The tools selector must be a list.") raise ValueError("The tools selector must be a list.")
return value return value
case _: case _:
@ -147,7 +147,7 @@ def init_frontend_parameter(rule: PluginParameter, type: enum.StrEnum, value: An
init frontend parameter by rule init frontend parameter by rule
""" """
parameter_value = value parameter_value = value
if not parameter_value and parameter_value != 0 and type != PluginParameterType.TOOLS_SELECTOR: if not parameter_value and parameter_value != 0:
# get default value # get default value
parameter_value = rule.default parameter_value = rule.default
if not parameter_value and rule.required: if not parameter_value and rule.required:

View File

@ -70,6 +70,9 @@ class PluginDeclaration(BaseModel):
models: Optional[list[str]] = Field(default_factory=list) models: Optional[list[str]] = Field(default_factory=list)
endpoints: Optional[list[str]] = Field(default_factory=list) endpoints: Optional[list[str]] = Field(default_factory=list)
class Meta(BaseModel):
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
version: str = Field(..., pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$") version: str = Field(..., pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
author: Optional[str] = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$") author: Optional[str] = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$")
name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$") name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$")
@ -86,6 +89,7 @@ class PluginDeclaration(BaseModel):
model: Optional[ProviderEntity] = None model: Optional[ProviderEntity] = None
endpoint: Optional[EndpointProviderDeclaration] = None endpoint: Optional[EndpointProviderDeclaration] = None
agent_strategy: Optional[AgentStrategyProviderEntity] = None agent_strategy: Optional[AgentStrategyProviderEntity] = None
meta: Meta
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod

View File

@ -204,3 +204,11 @@ class RequestRequestUploadFile(BaseModel):
filename: str filename: str
mimetype: str mimetype: str
class RequestFetchAppInfo(BaseModel):
"""
Request to fetch app info
"""
app_id: str

View File

@ -82,7 +82,7 @@ class BasePluginManager:
Make a stream request to the plugin daemon inner API Make a stream request to the plugin daemon inner API
""" """
response = self._request(method, path, headers, data, params, files, stream=True) response = self._request(method, path, headers, data, params, files, stream=True)
for line in response.iter_lines(): for line in response.iter_lines(chunk_size=1024 * 8):
line = line.decode("utf-8").strip() line = line.decode("utf-8").strip()
if line.startswith("data:"): if line.startswith("data:"):
line = line[5:].strip() line = line[5:].strip()
@ -168,16 +168,18 @@ class BasePluginManager:
Make a stream request to the plugin daemon inner API and yield the response as a model. Make a stream request to the plugin daemon inner API and yield the response as a model.
""" """
for line in self._stream_request(method, path, params, headers, data, files): for line in self._stream_request(method, path, params, headers, data, files):
line_data = None
try: try:
line_data = json.loads(line) rep = PluginDaemonBasicResponse[type].model_validate_json(line) # type: ignore
rep = PluginDaemonBasicResponse[type](**line_data) # type: ignore except (ValueError, TypeError):
except Exception:
# TODO modify this when line_data has code and message # TODO modify this when line_data has code and message
if line_data and "error" in line_data: try:
raise ValueError(line_data["error"]) line_data = json.loads(line)
else: except (ValueError, TypeError):
raise ValueError(line) raise ValueError(line)
# If the dictionary contains the `error` key, use its value as the argument
# for `ValueError`.
# Otherwise, use the `line` to provide better contextual information about the error.
raise ValueError(line_data.get("error", line))
if rep.code != 0: if rep.code != 0:
if rep.code == -500: if rep.code == -500:

View File

@ -110,7 +110,62 @@ class PluginToolManager(BasePluginManager):
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
) )
return response
class FileChunk:
"""
Only used for internal processing.
"""
bytes_written: int
total_length: int
data: bytearray
def __init__(self, total_length: int):
self.bytes_written = 0
self.total_length = total_length
self.data = bytearray(total_length)
files: dict[str, FileChunk] = {}
for resp in response:
if resp.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
assert isinstance(resp.message, ToolInvokeMessage.BlobChunkMessage)
# Get blob chunk information
chunk_id = resp.message.id
total_length = resp.message.total_length
blob_data = resp.message.blob
is_end = resp.message.end
# Initialize buffer for this file if it doesn't exist
if chunk_id not in files:
files[chunk_id] = FileChunk(total_length)
# If this is the final chunk, yield a complete blob message
if is_end:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB,
message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data),
meta=resp.meta,
)
else:
# Check if file is too large (30MB limit)
if files[chunk_id].bytes_written + len(blob_data) > 30 * 1024 * 1024:
# Delete the file if it's too large
del files[chunk_id]
# Skip yielding this message
raise ValueError("File is too large which reached the limit of 30MB")
# Check if single chunk is too large (8KB limit)
if len(blob_data) > 8192:
# Skip yielding this message
raise ValueError("File chunk is too large which reached the limit of 8KB")
# Append the blob data to the buffer
files[chunk_id].data[
files[chunk_id].bytes_written : files[chunk_id].bytes_written + len(blob_data)
] = blob_data
files[chunk_id].bytes_written += len(blob_data)
else:
yield resp
def validate_provider_credentials( def validate_provider_credentials(
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any] self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]

View File

@ -28,7 +28,7 @@ BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
}, },
"conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"}, "conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"},
}, },
"stop": ["用户:"], "stop": ["用户"],
} }
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
@ -41,5 +41,5 @@ BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = {
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = { BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}}, "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}},
"stop": ["用户:"], "stop": ["用户"],
} }

View File

@ -124,6 +124,15 @@ class ProviderManager:
# Get All preferred provider types of the workspace # Get All preferred provider types of the workspace
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id) provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
# Ensure that both the original provider name and its ModelProviderID string representation
# are present in the dictionary to handle cases where either form might be used
for provider_name in list(provider_name_to_preferred_model_provider_records_dict.keys()):
provider_id = ModelProviderID(provider_name)
if str(provider_id) not in provider_name_to_preferred_model_provider_records_dict:
# Add the ModelProviderID string representation if it's not already present
provider_name_to_preferred_model_provider_records_dict[str(provider_id)] = (
provider_name_to_preferred_model_provider_records_dict[provider_name]
)
# Get All provider model settings # Get All provider model settings
provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id) provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)
@ -497,8 +506,8 @@ class ProviderManager:
@staticmethod @staticmethod
def _init_trial_provider_records( def _init_trial_provider_records(
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list] tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
) -> dict[str, list]: ) -> dict[str, list[Provider]]:
""" """
Initialize trial provider records if not exists. Initialize trial provider records if not exists.
@ -532,7 +541,7 @@ class ProviderManager:
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
try: try:
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic # FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
provider_record = Provider( new_provider_record = Provider(
tenant_id=tenant_id, tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration. # TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name, provider_name=ModelProviderID(provider_name).provider_name,
@ -542,11 +551,12 @@ class ProviderManager:
quota_used=0, quota_used=0,
is_valid=True, is_valid=True,
) )
db.session.add(provider_record) db.session.add(new_provider_record)
db.session.commit() db.session.commit()
provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
except IntegrityError: except IntegrityError:
db.session.rollback() db.session.rollback()
provider_record = ( existed_provider_record = (
db.session.query(Provider) db.session.query(Provider)
.filter( .filter(
Provider.tenant_id == tenant_id, Provider.tenant_id == tenant_id,
@ -556,11 +566,14 @@ class ProviderManager:
) )
.first() .first()
) )
if provider_record and not provider_record.is_valid: if not existed_provider_record:
provider_record.is_valid = True continue
if not existed_provider_record.is_valid:
existed_provider_record.is_valid = True
db.session.commit() db.session.commit()
provider_name_to_provider_records_dict[provider_name].append(provider_record) provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
return provider_name_to_provider_records_dict return provider_name_to_provider_records_dict

View File

@ -139,13 +139,17 @@ class AnalyticdbVectorBySql:
) )
if embedding_dimension is not None: if embedding_dimension is not None:
index_name = f"{self._collection_name}_embedding_idx" index_name = f"{self._collection_name}_embedding_idx"
cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN") try:
cur.execute( cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) " cur.execute(
f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', " f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
f"pq_enable=0, external_storage=0)" f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
) f"pq_enable=0, external_storage=0)"
cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)") )
cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
except Exception as e:
if "already exists" not in str(e):
raise e
redis_client.set(collection_exist_cache_key, 1, ex=3600) redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -177,9 +181,11 @@ class AnalyticdbVectorBySql:
return cur.fetchone() is not None return cur.fetchone() is not None
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
with self._get_cursor() as cur: with self._get_cursor() as cur:
try: try:
cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id IN %s", (tuple(ids),)) cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id = ANY(%s)", (ids,))
except Exception as e: except Exception as e:
if "does not exist" not in str(e): if "does not exist" not in str(e):
raise e raise e
@ -240,7 +246,7 @@ class AnalyticdbVectorBySql:
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
FROM {self.table_name} FROM {self.table_name}
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause} WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
ORDER BY score DESC ORDER BY score DESC, id DESC
LIMIT {top_k}""", LIMIT {top_k}""",
(f"'{query}'", f"'{query}'"), (f"'{query}'", f"'{query}'"),
) )

View File

@ -32,6 +32,7 @@ class MilvusConfig(BaseModel):
batch_size: int = 100 # Batch size for operations batch_size: int = 100 # Batch size for operations
database: str = "default" # Database name database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search enable_hybrid_search: bool = False # Flag to enable hybrid search
analyzer_params: Optional[str] = None # Analyzer params
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@ -58,6 +59,7 @@ class MilvusConfig(BaseModel):
"user": self.user, "user": self.user,
"password": self.password, "password": self.password,
"db_name": self.database, "db_name": self.database,
"analyzer_params": self.analyzer_params,
} }
@ -300,14 +302,19 @@ class MilvusVector(BaseVector):
# Create the text field, enable_analyzer will be set True to support milvus automatically # Create the text field, enable_analyzer will be set True to support milvus automatically
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md # transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
fields.append( content_field_kwargs: dict[str, Any] = {
FieldSchema( "max_length": 65_535,
Field.CONTENT_KEY.value, "enable_analyzer": self._hybrid_search_enabled,
DataType.VARCHAR, }
max_length=65_535, if (
enable_analyzer=self._hybrid_search_enabled, self._hybrid_search_enabled
) and self._client_config.analyzer_params is not None
) and self._client_config.analyzer_params.strip()
):
content_field_kwargs["analyzer_params"] = self._client_config.analyzer_params
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, **content_field_kwargs))
# Create the primary key field # Create the primary key field
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True)) fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors # Create the vector field, supports binary or float vectors
@ -383,5 +390,6 @@ class MilvusVectorFactory(AbstractVectorFactory):
password=dify_config.MILVUS_PASSWORD or "", password=dify_config.MILVUS_PASSWORD or "",
database=dify_config.MILVUS_DATABASE or "", database=dify_config.MILVUS_DATABASE or "",
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False, enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
analyzer_params=dify_config.MILVUS_ANALYZER_PARAMS or "",
), ),
) )

View File

@ -228,7 +228,7 @@ class OracleVector(BaseVector):
top_k = kwargs.get("top_k", 5) top_k = kwargs.get("top_k", 5)
# just not implement fetch by score_threshold now, may be later # just not implement fetch by score_threshold now, may be later
score_threshold = float(kwargs.get("score_threshold") or 0.0) # score_threshold = float(kwargs.get("score_threshold") or 0.0)
if len(query) > 0: if len(query) > 0:
# Check which language the query is in # Check which language the query is in
zh_pattern = re.compile("[\u4e00-\u9fa5]+") zh_pattern = re.compile("[\u4e00-\u9fa5]+")
@ -239,7 +239,7 @@ class OracleVector(BaseVector):
words = pseg.cut(query) words = pseg.cut(query)
current_entity = "" current_entity = ""
for word, pos in words: for word, pos in words:
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名 if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名ns: 地名,nt: 机构名
current_entity += word current_entity += word
else: else:
if current_entity: if current_entity:

View File

@ -444,7 +444,7 @@ class QdrantVectorFactory(AbstractVectorFactory):
if dataset_collection_binding: if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name collection_name = dataset_collection_binding.collection_name
else: else:
raise ValueError("Dataset Collection Bindings is not exist!") raise ValueError("Dataset Collection Bindings does not exist!")
else: else:
if dataset.index_struct_dict: if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]

View File

@ -65,8 +65,6 @@ class RelytVector(BaseVector):
return VectorType.RELYT return VectorType.RELYT
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> None: def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> None:
index_params: dict[str, Any] = {}
metadatas = [d.metadata for d in texts]
self.create_collection(len(embeddings[0])) self.create_collection(len(embeddings[0]))
self.embedding_dimension = len(embeddings[0]) self.embedding_dimension = len(embeddings[0])
self.add_texts(texts, embeddings) self.add_texts(texts, embeddings)

View File

@ -187,7 +187,6 @@ class TiDBVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4) top_k = kwargs.get("top_k", 4)
score_threshold = float(kwargs.get("score_threshold") or 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
filter = kwargs.get("filter")
distance = 1 - score_threshold distance = 1 - score_threshold
query_vector_str = ", ".join(format(x) for x in query_vector) query_vector_str = ", ".join(format(x) for x in query_vector)

View File

@ -206,6 +206,7 @@ class DatasetRetrieval:
source = { source = {
"dataset_id": item.metadata.get("dataset_id"), "dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"), "dataset_name": item.metadata.get("dataset_name"),
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
"document_name": item.metadata.get("title"), "document_name": item.metadata.get("title"),
"data_source_type": "external", "data_source_type": "external",
"retriever_from": invoke_from.to_source(), "retriever_from": invoke_from.to_source(),

View File

@ -39,6 +39,12 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
else: else:
return [GPT2Tokenizer.get_num_tokens(text) for text in texts] return [GPT2Tokenizer.get_num_tokens(text) for text in texts]
def _character_encoder(texts: list[str]) -> list[int]:
if not texts:
return []
return [len(text) for text in texts]
if issubclass(cls, TokenTextSplitter): if issubclass(cls, TokenTextSplitter):
extra_kwargs = { extra_kwargs = {
"model_name": embedding_model_instance.model if embedding_model_instance else "gpt2", "model_name": embedding_model_instance.model if embedding_model_instance else "gpt2",
@ -47,7 +53,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
} }
kwargs = {**kwargs, **extra_kwargs} kwargs = {**kwargs, **extra_kwargs}
return cls(length_function=_token_encoder, **kwargs) return cls(length_function=_character_encoder, **kwargs)
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter): class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
@ -103,7 +109,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
_good_splits_lengths = [] # cache the lengths of the splits _good_splits_lengths = [] # cache the lengths of the splits
_separator = "" if self._keep_separator else separator _separator = "" if self._keep_separator else separator
s_lens = self._length_function(splits) s_lens = self._length_function(splits)
if _separator != "": if separator != "":
for s, s_len in zip(splits, s_lens): for s, s_len in zip(splits, s_lens):
if s_len < self._chunk_size: if s_len < self._chunk_size:
_good_splits.append(s) _good_splits.append(s)

View File

@ -0,0 +1,15 @@
"""
Repository interfaces for data access.
This package contains repository interfaces that define the contract
for accessing and manipulating data, regardless of the underlying
storage mechanism.
"""
from core.repository.repository_factory import RepositoryFactory
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
__all__ = [
"RepositoryFactory",
"WorkflowNodeExecutionRepository",
]

View File

@ -0,0 +1,97 @@
"""
Repository factory for creating repository instances.
This module provides a simple factory interface for creating repository instances.
It does not contain any implementation details or dependencies on specific repositories.
"""
from collections.abc import Callable, Mapping
from typing import Any, Literal, Optional, cast
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
# Type for factory functions - takes a dict of parameters and returns any repository type
RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any]
# Type for workflow node execution factory function
WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository]
# Repository type literals
_RepositoryType = Literal["workflow_node_execution"]
class RepositoryFactory:
"""
Factory class for creating repository instances.
This factory delegates the actual repository creation to implementation-specific
factory functions that are registered with the factory at runtime.
"""
# Dictionary to store factory functions
_factory_functions: dict[str, RepositoryFactoryFunc] = {}
@classmethod
def _register_factory(cls, repository_type: _RepositoryType, factory_func: RepositoryFactoryFunc) -> None:
"""
Register a factory function for a specific repository type.
This is a private method and should not be called directly.
Args:
repository_type: The type of repository (e.g., 'workflow_node_execution')
factory_func: A function that takes parameters and returns a repository instance
"""
cls._factory_functions[repository_type] = factory_func
@classmethod
def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any:
"""
Create a new repository instance with the provided parameters.
This is a private method and should not be called directly.
Args:
repository_type: The type of repository to create
params: A dictionary of parameters to pass to the factory function
Returns:
A new instance of the requested repository
Raises:
ValueError: If no factory function is registered for the repository type
"""
if repository_type not in cls._factory_functions:
raise ValueError(f"No factory function registered for repository type '{repository_type}'")
# Use empty dict if params is None
params = params or {}
return cls._factory_functions[repository_type](params)
@classmethod
def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None:
"""
Register a factory function for the workflow node execution repository.
Args:
factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance
"""
cls._register_factory("workflow_node_execution", factory_func)
@classmethod
def create_workflow_node_execution_repository(
cls, params: Optional[Mapping[str, Any]] = None
) -> WorkflowNodeExecutionRepository:
"""
Create a new WorkflowNodeExecutionRepository instance with the provided parameters.
Args:
params: A dictionary of parameters to pass to the factory function
Returns:
A new instance of the WorkflowNodeExecutionRepository
Raises:
ValueError: If no factory function is registered for the workflow_node_execution repository type
"""
# We can safely cast here because we've registered a WorkflowNodeExecutionFactoryFunc
return cast(WorkflowNodeExecutionRepository, cls._create_repository("workflow_node_execution", params))

View File

@ -0,0 +1,88 @@
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal, Optional, Protocol
from models.workflow import WorkflowNodeExecution
@dataclass
class OrderConfig:
"""Configuration for ordering WorkflowNodeExecution instances."""
order_by: list[str]
order_direction: Optional[Literal["asc", "desc"]] = None
class WorkflowNodeExecutionRepository(Protocol):
"""
Repository interface for WorkflowNodeExecution.
This interface defines the contract for accessing and manipulating
WorkflowNodeExecution data, regardless of the underlying storage mechanism.
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
and trigger sources (triggered_from) should be handled at the implementation level, not in
the core interface. This keeps the core domain model clean and independent of specific
application domains or deployment scenarios.
"""
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save a WorkflowNodeExecution instance.
Args:
execution: The WorkflowNodeExecution instance to save
"""
...
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
"""
Retrieve a WorkflowNodeExecution by its node_execution_id.
Args:
node_execution_id: The node execution ID
Returns:
The WorkflowNodeExecution instance if found, None otherwise
"""
...
def get_by_workflow_run(
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
order_config.order_direction: Direction to order ("asc" or "desc")
Returns:
A list of WorkflowNodeExecution instances
"""
...
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
Returns:
A list of running WorkflowNodeExecution instances
"""
...
def update(self, execution: WorkflowNodeExecution) -> None:
"""
Update an existing WorkflowNodeExecution instance.
Args:
execution: The WorkflowNodeExecution instance to update
"""
...

View File

@ -8,7 +8,7 @@ identity:
description: description:
human: human:
en_US: Run code and get the result back. When you're using a lower quality model, please make sure there are some tips help LLM to understand how to write the code. en_US: Run code and get the result back. When you're using a lower quality model, please make sure there are some tips help LLM to understand how to write the code.
zh_Hans: 运行一段代码并返回结果。当您使用较低质量的模型时,请确保有一些提示帮助LLM理解如何编写代码。 zh_Hans: 运行一段代码并返回结果。当您使用较低质量的模型时,请确保有一些提示帮助 LLM 理解如何编写代码。
pt_BR: Execute um trecho de código e obtenha o resultado de volta. quando você estiver usando um modelo de qualidade inferior, certifique-se de que existam algumas dicas para ajudar o LLM a entender como escrever o código. pt_BR: Execute um trecho de código e obtenha o resultado de volta. quando você estiver usando um modelo de qualidade inferior, certifique-se de que existam algumas dicas para ajudar o LLM a entender como escrever o código.
llm: A tool for running code and getting the result back. Only native packages are allowed, network/IO operations are disabled. and you must use print() or console.log() to output the result or result will be empty. llm: A tool for running code and getting the result back. Only native packages are allowed, network/IO operations are disabled. and you must use print() or console.log() to output the result or result will be empty.
parameters: parameters:

View File

@ -19,7 +19,7 @@ parameters:
zh_Hans: 本地时间 zh_Hans: 本地时间
human_description: human_description:
en_US: localtime, such as 2024-1-1 0:0:0 en_US: localtime, such as 2024-1-1 0:0:0
zh_Hans: 本地时间, 比如2024-1-1 0:0:0 zh_Hans: 本地时间,比如 2024-1-1 0:0:0
- name: timezone - name: timezone
type: string type: string
required: false required: false
@ -29,5 +29,5 @@ parameters:
zh_Hans: 时区 zh_Hans: 时区
human_description: human_description:
en_US: Timezone, such as Asia/Shanghai en_US: Timezone, such as Asia/Shanghai
zh_Hans: 时区, 比如Asia/Shanghai zh_Hans: 时区,比如 Asia/Shanghai
default: Asia/Shanghai default: Asia/Shanghai

View File

@ -29,5 +29,5 @@ parameters:
zh_Hans: 时区 zh_Hans: 时区
human_description: human_description:
en_US: Timezone, such as Asia/Shanghai en_US: Timezone, such as Asia/Shanghai
zh_Hans: 时区, 比如Asia/Shanghai zh_Hans: 时区,比如 Asia/Shanghai
default: Asia/Shanghai default: Asia/Shanghai

View File

@ -19,7 +19,7 @@ parameters:
zh_Hans: 当前时间 zh_Hans: 当前时间
human_description: human_description:
en_US: current time, such as 2024-1-1 0:0:0 en_US: current time, such as 2024-1-1 0:0:0
zh_Hans: 当前时间, 比如2024-1-1 0:0:0 zh_Hans: 当前时间,比如 2024-1-1 0:0:0
- name: current_timezone - name: current_timezone
type: string type: string
required: true required: true
@ -29,7 +29,7 @@ parameters:
zh_Hans: 当前时区 zh_Hans: 当前时区
human_description: human_description:
en_US: Current Timezone, such as Asia/Shanghai en_US: Current Timezone, such as Asia/Shanghai
zh_Hans: 当前时区, 比如Asia/Shanghai zh_Hans: 当前时区,比如 Asia/Shanghai
default: Asia/Shanghai default: Asia/Shanghai
- name: target_timezone - name: target_timezone
type: string type: string
@ -40,5 +40,5 @@ parameters:
zh_Hans: 目标时区 zh_Hans: 目标时区
human_description: human_description:
en_US: Target Timezone, such as Asia/Tokyo en_US: Target Timezone, such as Asia/Tokyo
zh_Hans: 目标时区, 比如Asia/Tokyo zh_Hans: 目标时区,比如 Asia/Tokyo
default: Asia/Tokyo default: Asia/Tokyo

View File

@ -59,7 +59,7 @@ class ApiToolProviderController(ToolProviderController):
name="api_key_value", name="api_key_value",
required=True, required=True,
type=ProviderConfig.Type.SECRET_INPUT, type=ProviderConfig.Type.SECRET_INPUT,
help=I18nObject(en_US="The api key", zh_Hans="api key的值"), help=I18nObject(en_US="The api key", zh_Hans="api key 的值"),
), ),
ProviderConfig( ProviderConfig(
name="api_key_header_prefix", name="api_key_header_prefix",

View File

@ -120,6 +120,13 @@ class ToolInvokeMessage(BaseModel):
class BlobMessage(BaseModel): class BlobMessage(BaseModel):
blob: bytes blob: bytes
class BlobChunkMessage(BaseModel):
id: str = Field(..., description="The id of the blob")
sequence: int = Field(..., description="The sequence of the chunk")
total_length: int = Field(..., description="The total length of the blob")
blob: bytes = Field(..., description="The blob data of the chunk")
end: bool = Field(..., description="Whether the chunk is the last chunk")
class FileMessage(BaseModel): class FileMessage(BaseModel):
pass pass
@ -180,12 +187,15 @@ class ToolInvokeMessage(BaseModel):
VARIABLE = "variable" VARIABLE = "variable"
FILE = "file" FILE = "file"
LOG = "log" LOG = "log"
BLOB_CHUNK = "blob_chunk"
type: MessageType = MessageType.TEXT type: MessageType = MessageType.TEXT
""" """
plain text, image url or link url plain text, image url or link url
""" """
message: JsonMessage | TextMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage message: (
JsonMessage | TextMessage | BlobChunkMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage
)
meta: dict[str, Any] | None = None meta: dict[str, Any] | None = None
@field_validator("message", mode="before") @field_validator("message", mode="before")

View File

@ -86,6 +86,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
"position": position, "position": position,
"dataset_id": item.metadata.get("dataset_id"), "dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"), "dataset_name": item.metadata.get("dataset_name"),
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
"document_name": item.metadata.get("title"), "document_name": item.metadata.get("title"),
"data_source_type": "external", "data_source_type": "external",
"retriever_from": self.retriever_from, "retriever_from": self.retriever_from,

View File

@ -1,13 +0,0 @@
from typing import Optional
from pydantic import BaseModel
from core.workflow.graph_engine.entities.graph import GraphParallel
class NextGraphNode(BaseModel):
node_id: str
"""next node id"""
parallel: Optional[GraphParallel] = None
"""parallel"""

View File

@ -155,9 +155,28 @@ class AnswerStreamProcessor(StreamProcessor):
for answer_node_id, route_position in self.route_position.items(): for answer_node_id, route_position in self.route_position.items():
if answer_node_id not in self.rest_node_ids: if answer_node_id not in self.rest_node_ids:
continue continue
# exclude current node id # Remove current node id from answer dependencies to support stream output if it is a success branch
answer_dependencies = self.generate_routes.answer_dependencies answer_dependencies = self.generate_routes.answer_dependencies
if event.node_id in answer_dependencies[answer_node_id]: edge_mapping = self.graph.edge_mapping.get(event.node_id)
success_edge = (
next(
(
edge
for edge in edge_mapping
if edge.run_condition
and edge.run_condition.type == "branch_identify"
and edge.run_condition.branch_identify == "success-branch"
),
None,
)
if edge_mapping
else None
)
if (
event.node_id in answer_dependencies[answer_node_id]
and success_edge
and success_edge.target_node_id == answer_node_id
):
answer_dependencies[answer_node_id].remove(event.node_id) answer_dependencies[answer_node_id].remove(event.node_id)
answer_dependencies_ids = answer_dependencies.get(answer_node_id, []) answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
# all depends on answer node id not in rest node ids # all depends on answer node id not in rest node ids

Some files were not shown because too many files have changed in this diff Show More