mirror of https://github.com/langgenius/dify.git
merge main
This commit is contained in:
commit
628dd8fcd8
|
@ -2,10 +2,10 @@
|
||||||
|
|
||||||
npm add -g pnpm@10.8.0
|
npm add -g pnpm@10.8.0
|
||||||
cd web && pnpm install
|
cd web && pnpm install
|
||||||
pipx install poetry
|
pipx install uv
|
||||||
|
|
||||||
echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
|
echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
|
||||||
echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
|
echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
|
||||||
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
|
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
|
||||||
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
|
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
|
||||||
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
|
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
cd api && poetry install
|
cd api && uv sync
|
||||||
|
|
|
@ -1,36 +0,0 @@
|
||||||
name: Setup Poetry and Python
|
|
||||||
|
|
||||||
inputs:
|
|
||||||
python-version:
|
|
||||||
description: Python version to use and the Poetry installed with
|
|
||||||
required: true
|
|
||||||
default: '3.11'
|
|
||||||
poetry-version:
|
|
||||||
description: Poetry version to set up
|
|
||||||
required: true
|
|
||||||
default: '2.0.1'
|
|
||||||
poetry-lockfile:
|
|
||||||
description: Path to the Poetry lockfile to restore cache from
|
|
||||||
required: true
|
|
||||||
default: ''
|
|
||||||
|
|
||||||
runs:
|
|
||||||
using: composite
|
|
||||||
steps:
|
|
||||||
- name: Set up Python ${{ inputs.python-version }}
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: ${{ inputs.python-version }}
|
|
||||||
cache: pip
|
|
||||||
|
|
||||||
- name: Install Poetry
|
|
||||||
shell: bash
|
|
||||||
run: pip install poetry==${{ inputs.poetry-version }}
|
|
||||||
|
|
||||||
- name: Restore Poetry cache
|
|
||||||
if: ${{ inputs.poetry-lockfile != '' }}
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: ${{ inputs.python-version }}
|
|
||||||
cache: poetry
|
|
||||||
cache-dependency-path: ${{ inputs.poetry-lockfile }}
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
name: Setup UV and Python
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
python-version:
|
||||||
|
description: Python version to set up and to use with uv
|
||||||
|
required: true
|
||||||
|
default: '3.12'
|
||||||
|
uv-version:
|
||||||
|
description: UV version to set up
|
||||||
|
required: true
|
||||||
|
default: '0.6.14'
|
||||||
|
uv-lockfile:
|
||||||
|
description: Path to the UV lockfile to restore cache from
|
||||||
|
required: true
|
||||||
|
default: ''
|
||||||
|
enable-cache:
|
||||||
|
required: true
|
||||||
|
default: true
|
||||||
|
|
||||||
|
runs:
|
||||||
|
using: composite
|
||||||
|
steps:
|
||||||
|
- name: Set up Python ${{ inputs.python-version }}
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ inputs.python-version }}
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v5
|
||||||
|
with:
|
||||||
|
version: ${{ inputs.uv-version }}
|
||||||
|
python-version: ${{ inputs.python-version }}
|
||||||
|
enable-cache: ${{ inputs.enable-cache }}
|
||||||
|
cache-dependency-glob: ${{ inputs.uv-lockfile }}
|
|
@ -17,6 +17,9 @@ jobs:
|
||||||
test:
|
test:
|
||||||
name: API Tests
|
name: API Tests
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
defaults:
|
||||||
|
run:
|
||||||
|
shell: bash
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version:
|
python-version:
|
||||||
|
@ -27,40 +30,44 @@ jobs:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Setup Poetry and Python ${{ matrix.python-version }}
|
- name: Setup UV and Python
|
||||||
uses: ./.github/actions/setup-poetry
|
uses: ./.github/actions/setup-uv
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
poetry-lockfile: api/poetry.lock
|
uv-lockfile: api/uv.lock
|
||||||
|
|
||||||
- name: Check Poetry lockfile
|
- name: Check UV lockfile
|
||||||
run: |
|
run: uv lock --project api --check
|
||||||
poetry check -C api --lock
|
|
||||||
poetry show -C api
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: poetry install -C api --with dev
|
run: uv sync --project api --dev
|
||||||
|
|
||||||
- name: Check dependencies in pyproject.toml
|
|
||||||
run: poetry run -P api bash dev/pytest/pytest_artifacts.sh
|
|
||||||
|
|
||||||
- name: Run Unit tests
|
- name: Run Unit tests
|
||||||
run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh
|
run: |
|
||||||
|
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||||
|
# Extract coverage percentage and create a summary
|
||||||
|
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
|
||||||
|
|
||||||
|
# Create a detailed coverage summary
|
||||||
|
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
||||||
|
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
||||||
|
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
|
||||||
|
uv run --project api coverage report >> $GITHUB_STEP_SUMMARY
|
||||||
|
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
|
||||||
|
|
||||||
- name: Run dify config tests
|
- name: Run dify config tests
|
||||||
run: poetry run -P api python dev/pytest/pytest_config_tests.py
|
run: uv run --project api dev/pytest/pytest_config_tests.py
|
||||||
|
|
||||||
- name: Cache MyPy
|
- name: MyPy Cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: api/.mypy_cache
|
path: api/.mypy_cache
|
||||||
key: mypy-${{ matrix.python-version }}-${{ runner.os }}-${{ hashFiles('api/poetry.lock') }}
|
key: mypy-${{ matrix.python-version }}-${{ runner.os }}-${{ hashFiles('api/uv.lock') }}
|
||||||
|
|
||||||
- name: Run mypy
|
- name: Run MyPy Checks
|
||||||
run: dev/run-mypy
|
run: dev/mypy-check
|
||||||
|
|
||||||
- name: Set up dotenvs
|
- name: Set up dotenvs
|
||||||
run: |
|
run: |
|
||||||
|
@ -80,4 +87,4 @@ jobs:
|
||||||
ssrf_proxy
|
ssrf_proxy
|
||||||
|
|
||||||
- name: Run Workflow
|
- name: Run Workflow
|
||||||
run: poetry run -P api bash dev/pytest/pytest_workflow.sh
|
run: uv run --project api bash dev/pytest/pytest_workflow.sh
|
||||||
|
|
|
@ -24,13 +24,13 @@ jobs:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Setup Poetry and Python
|
- name: Setup UV and Python
|
||||||
uses: ./.github/actions/setup-poetry
|
uses: ./.github/actions/setup-uv
|
||||||
with:
|
with:
|
||||||
poetry-lockfile: api/poetry.lock
|
uv-lockfile: api/uv.lock
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: poetry install -C api
|
run: uv sync --project api
|
||||||
|
|
||||||
- name: Prepare middleware env
|
- name: Prepare middleware env
|
||||||
run: |
|
run: |
|
||||||
|
@ -54,6 +54,4 @@ jobs:
|
||||||
- name: Run DB Migration
|
- name: Run DB Migration
|
||||||
env:
|
env:
|
||||||
DEBUG: true
|
DEBUG: true
|
||||||
run: |
|
run: uv run --directory api flask upgrade-db
|
||||||
cd api
|
|
||||||
poetry run python -m flask upgrade-db
|
|
||||||
|
|
|
@ -42,6 +42,7 @@ jobs:
|
||||||
with:
|
with:
|
||||||
push: false
|
push: false
|
||||||
context: "{{defaultContext}}:${{ matrix.context }}"
|
context: "{{defaultContext}}:${{ matrix.context }}"
|
||||||
|
file: "${{ matrix.file }}"
|
||||||
platforms: ${{ matrix.platform }}
|
platforms: ${{ matrix.platform }}
|
||||||
cache-from: type=gha
|
cache-from: type=gha
|
||||||
cache-to: type=gha,mode=max
|
cache-to: type=gha,mode=max
|
||||||
|
|
|
@ -18,7 +18,6 @@ jobs:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Check changed files
|
- name: Check changed files
|
||||||
|
@ -29,24 +28,27 @@ jobs:
|
||||||
api/**
|
api/**
|
||||||
.github/workflows/style.yml
|
.github/workflows/style.yml
|
||||||
|
|
||||||
- name: Setup Poetry and Python
|
- name: Setup UV and Python
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
uses: ./.github/actions/setup-poetry
|
uses: ./.github/actions/setup-uv
|
||||||
|
with:
|
||||||
|
uv-lockfile: api/uv.lock
|
||||||
|
enable-cache: false
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
run: poetry install -C api --only lint
|
run: uv sync --project api --dev
|
||||||
|
|
||||||
- name: Ruff check
|
- name: Ruff check
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
run: |
|
run: |
|
||||||
poetry run -C api ruff --version
|
uv run --directory api ruff --version
|
||||||
poetry run -C api ruff check ./
|
uv run --directory api ruff check ./
|
||||||
poetry run -C api ruff format --check ./
|
uv run --directory api ruff format --check ./
|
||||||
|
|
||||||
- name: Dotenv check
|
- name: Dotenv check
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
run: poetry run -P api dotenv-linter ./api/.env.example ./web/.env.example
|
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
|
||||||
|
|
||||||
- name: Lint hints
|
- name: Lint hints
|
||||||
if: failure()
|
if: failure()
|
||||||
|
@ -63,7 +65,6 @@ jobs:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Check changed files
|
- name: Check changed files
|
||||||
|
@ -102,7 +103,6 @@ jobs:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Check changed files
|
- name: Check changed files
|
||||||
|
@ -133,7 +133,6 @@ jobs:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Check changed files
|
- name: Check changed files
|
||||||
|
|
|
@ -27,7 +27,6 @@ jobs:
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Use Node.js ${{ matrix.node-version }}
|
- name: Use Node.js ${{ matrix.node-version }}
|
||||||
|
|
|
@ -8,7 +8,7 @@ on:
|
||||||
- api/core/rag/datasource/**
|
- api/core/rag/datasource/**
|
||||||
- docker/**
|
- docker/**
|
||||||
- .github/workflows/vdb-tests.yml
|
- .github/workflows/vdb-tests.yml
|
||||||
- api/poetry.lock
|
- api/uv.lock
|
||||||
- api/pyproject.toml
|
- api/pyproject.toml
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
|
@ -29,22 +29,19 @@ jobs:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Setup Poetry and Python ${{ matrix.python-version }}
|
- name: Setup UV and Python
|
||||||
uses: ./.github/actions/setup-poetry
|
uses: ./.github/actions/setup-uv
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
poetry-lockfile: api/poetry.lock
|
uv-lockfile: api/uv.lock
|
||||||
|
|
||||||
- name: Check Poetry lockfile
|
- name: Check UV lockfile
|
||||||
run: |
|
run: uv lock --project api --check
|
||||||
poetry check -C api --lock
|
|
||||||
poetry show -C api
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: poetry install -C api --with dev
|
run: uv sync --project api --dev
|
||||||
|
|
||||||
- name: Set up dotenvs
|
- name: Set up dotenvs
|
||||||
run: |
|
run: |
|
||||||
|
@ -80,7 +77,7 @@ jobs:
|
||||||
elasticsearch
|
elasticsearch
|
||||||
|
|
||||||
- name: Check TiDB Ready
|
- name: Check TiDB Ready
|
||||||
run: poetry run -P api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||||
|
|
||||||
- name: Test Vector Stores
|
- name: Test Vector Stores
|
||||||
run: poetry run -P api bash dev/pytest/pytest_vdb.sh
|
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||||
|
|
|
@ -23,7 +23,6 @@ jobs:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Check changed files
|
- name: Check changed files
|
||||||
|
|
|
@ -46,6 +46,7 @@ htmlcov/
|
||||||
.cache
|
.cache
|
||||||
nosetests.xml
|
nosetests.xml
|
||||||
coverage.xml
|
coverage.xml
|
||||||
|
coverage.json
|
||||||
*.cover
|
*.cover
|
||||||
*.py,cover
|
*.py,cover
|
||||||
.hypothesis/
|
.hypothesis/
|
||||||
|
|
|
@ -165,6 +165,7 @@ MILVUS_URI=http://127.0.0.1:19530
|
||||||
MILVUS_TOKEN=
|
MILVUS_TOKEN=
|
||||||
MILVUS_USER=root
|
MILVUS_USER=root
|
||||||
MILVUS_PASSWORD=Milvus
|
MILVUS_PASSWORD=Milvus
|
||||||
|
MILVUS_ANALYZER_PARAMS=
|
||||||
|
|
||||||
# MyScale configuration
|
# MyScale configuration
|
||||||
MYSCALE_HOST=127.0.0.1
|
MYSCALE_HOST=127.0.0.1
|
||||||
|
@ -423,6 +424,12 @@ WORKFLOW_CALL_MAX_DEPTH=5
|
||||||
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
|
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
|
||||||
MAX_VARIABLE_SIZE=204800
|
MAX_VARIABLE_SIZE=204800
|
||||||
|
|
||||||
|
# Workflow storage configuration
|
||||||
|
# Options: rdbms, hybrid
|
||||||
|
# rdbms: Use only the relational database (default)
|
||||||
|
# hybrid: Save new data to object storage, read from both object storage and RDBMS
|
||||||
|
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
|
||||||
|
|
||||||
# App configuration
|
# App configuration
|
||||||
APP_MAX_EXECUTION_TIME=1200
|
APP_MAX_EXECUTION_TIME=1200
|
||||||
APP_MAX_ACTIVE_REQUESTS=0
|
APP_MAX_ACTIVE_REQUESTS=0
|
||||||
|
@ -463,3 +470,16 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
|
||||||
MAX_SUBMIT_COUNT=100
|
MAX_SUBMIT_COUNT=100
|
||||||
# Lockout duration in seconds
|
# Lockout duration in seconds
|
||||||
LOGIN_LOCKOUT_DURATION=86400
|
LOGIN_LOCKOUT_DURATION=86400
|
||||||
|
|
||||||
|
# Enable OpenTelemetry
|
||||||
|
ENABLE_OTEL=false
|
||||||
|
OTLP_BASE_ENDPOINT=http://localhost:4318
|
||||||
|
OTLP_API_KEY=
|
||||||
|
OTEL_EXPORTER_TYPE=otlp
|
||||||
|
OTEL_SAMPLING_RATE=0.1
|
||||||
|
OTEL_BATCH_EXPORT_SCHEDULE_DELAY=5000
|
||||||
|
OTEL_MAX_QUEUE_SIZE=2048
|
||||||
|
OTEL_MAX_EXPORT_BATCH_SIZE=512
|
||||||
|
OTEL_METRIC_EXPORT_INTERVAL=60000
|
||||||
|
OTEL_BATCH_EXPORT_TIMEOUT=10000
|
||||||
|
OTEL_METRIC_EXPORT_TIMEOUT=30000
|
|
@ -3,20 +3,11 @@ FROM python:3.12-slim-bookworm AS base
|
||||||
|
|
||||||
WORKDIR /app/api
|
WORKDIR /app/api
|
||||||
|
|
||||||
# Install Poetry
|
# Install uv
|
||||||
ENV POETRY_VERSION=2.0.1
|
ENV UV_VERSION=0.6.14
|
||||||
|
|
||||||
# if you located in China, you can use aliyun mirror to speed up
|
RUN pip install --no-cache-dir uv==${UV_VERSION}
|
||||||
# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/
|
|
||||||
|
|
||||||
RUN pip install --no-cache-dir poetry==${POETRY_VERSION}
|
|
||||||
|
|
||||||
# Configure Poetry
|
|
||||||
ENV POETRY_CACHE_DIR=/tmp/poetry_cache
|
|
||||||
ENV POETRY_NO_INTERACTION=1
|
|
||||||
ENV POETRY_VIRTUALENVS_IN_PROJECT=true
|
|
||||||
ENV POETRY_VIRTUALENVS_CREATE=true
|
|
||||||
ENV POETRY_REQUESTS_TIMEOUT=15
|
|
||||||
|
|
||||||
FROM base AS packages
|
FROM base AS packages
|
||||||
|
|
||||||
|
@ -27,8 +18,8 @@ RUN apt-get update \
|
||||||
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
|
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
|
||||||
|
|
||||||
# Install Python dependencies
|
# Install Python dependencies
|
||||||
COPY pyproject.toml poetry.lock ./
|
COPY pyproject.toml uv.lock ./
|
||||||
RUN poetry install --sync --no-cache --no-root
|
RUN uv sync --locked
|
||||||
|
|
||||||
# production stage
|
# production stage
|
||||||
FROM base AS production
|
FROM base AS production
|
||||||
|
|
|
@ -3,7 +3,10 @@
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
> [!IMPORTANT]
|
> [!IMPORTANT]
|
||||||
> In the v0.6.12 release, we deprecated `pip` as the package management tool for Dify API Backend service and replaced it with `poetry`.
|
>
|
||||||
|
> In the v1.3.0 release, `poetry` has been replaced with
|
||||||
|
> [`uv`](https://docs.astral.sh/uv/) as the package manager
|
||||||
|
> for Dify API backend service.
|
||||||
|
|
||||||
1. Start the docker-compose stack
|
1. Start the docker-compose stack
|
||||||
|
|
||||||
|
@ -37,19 +40,19 @@
|
||||||
|
|
||||||
4. Create environment.
|
4. Create environment.
|
||||||
|
|
||||||
Dify API service uses [Poetry](https://python-poetry.org/docs/) to manage dependencies. First, you need to add the poetry shell plugin, if you don't have it already, in order to run in a virtual environment. [Note: Poetry shell is no longer a native command so you need to install the poetry plugin beforehand]
|
Dify API service uses [UV](https://docs.astral.sh/uv/) to manage dependencies.
|
||||||
|
First, install the uv package manager if you don't have it already.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
poetry self add poetry-plugin-shell
|
pip install uv
|
||||||
|
# Or on macOS
|
||||||
|
brew install uv
|
||||||
```
|
```
|
||||||
|
|
||||||
Then, You can execute `poetry shell` to activate the environment.
|
|
||||||
|
|
||||||
5. Install dependencies
|
5. Install dependencies
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
poetry env use 3.12
|
uv sync --dev
|
||||||
poetry install
|
|
||||||
```
|
```
|
||||||
|
|
||||||
6. Run migrate
|
6. Run migrate
|
||||||
|
@ -57,21 +60,21 @@
|
||||||
Before the first launch, migrate the database to the latest version.
|
Before the first launch, migrate the database to the latest version.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
poetry run python -m flask db upgrade
|
uv run flask db upgrade
|
||||||
```
|
```
|
||||||
|
|
||||||
7. Start backend
|
7. Start backend
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug
|
uv run flask run --host 0.0.0.0 --port=5001 --debug
|
||||||
```
|
```
|
||||||
|
|
||||||
8. Start Dify [web](../web) service.
|
8. Start Dify [web](../web) service.
|
||||||
9. Setup your application by visiting `http://localhost:3000`...
|
9. Setup your application by visiting `http://localhost:3000`.
|
||||||
10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
|
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
|
||||||
```
|
```
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
@ -79,11 +82,11 @@
|
||||||
1. Install dependencies for both the backend and the test environment
|
1. Install dependencies for both the backend and the test environment
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
poetry install -C api --with dev
|
uv sync --dev
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
|
2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
poetry run -P api bash dev/pytest/pytest_all_tests.sh
|
uv run --project api bash dev/pytest/pytest_all_tests.sh
|
||||||
```
|
```
|
||||||
|
|
|
@ -51,8 +51,10 @@ def initialize_extensions(app: DifyApp):
|
||||||
ext_login,
|
ext_login,
|
||||||
ext_mail,
|
ext_mail,
|
||||||
ext_migrate,
|
ext_migrate,
|
||||||
|
ext_otel,
|
||||||
ext_proxy_fix,
|
ext_proxy_fix,
|
||||||
ext_redis,
|
ext_redis,
|
||||||
|
ext_repositories,
|
||||||
ext_sentry,
|
ext_sentry,
|
||||||
ext_set_secretkey,
|
ext_set_secretkey,
|
||||||
ext_storage,
|
ext_storage,
|
||||||
|
@ -73,6 +75,7 @@ def initialize_extensions(app: DifyApp):
|
||||||
ext_migrate,
|
ext_migrate,
|
||||||
ext_redis,
|
ext_redis,
|
||||||
ext_storage,
|
ext_storage,
|
||||||
|
ext_repositories,
|
||||||
ext_celery,
|
ext_celery,
|
||||||
ext_login,
|
ext_login,
|
||||||
ext_mail,
|
ext_mail,
|
||||||
|
@ -81,6 +84,7 @@ def initialize_extensions(app: DifyApp):
|
||||||
ext_proxy_fix,
|
ext_proxy_fix,
|
||||||
ext_blueprints,
|
ext_blueprints,
|
||||||
ext_commands,
|
ext_commands,
|
||||||
|
ext_otel,
|
||||||
]
|
]
|
||||||
for ext in extensions:
|
for ext in extensions:
|
||||||
short_name = ext.__name__.split(".")[-1]
|
short_name = ext.__name__.split(".")[-1]
|
||||||
|
|
|
@ -9,6 +9,7 @@ from .enterprise import EnterpriseFeatureConfig
|
||||||
from .extra import ExtraServiceConfig
|
from .extra import ExtraServiceConfig
|
||||||
from .feature import FeatureConfig
|
from .feature import FeatureConfig
|
||||||
from .middleware import MiddlewareConfig
|
from .middleware import MiddlewareConfig
|
||||||
|
from .observability import ObservabilityConfig
|
||||||
from .packaging import PackagingInfo
|
from .packaging import PackagingInfo
|
||||||
from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName
|
from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName
|
||||||
from .remote_settings_sources.apollo import ApolloSettingsSource
|
from .remote_settings_sources.apollo import ApolloSettingsSource
|
||||||
|
@ -59,6 +60,8 @@ class DifyConfig(
|
||||||
MiddlewareConfig,
|
MiddlewareConfig,
|
||||||
# Extra service configs
|
# Extra service configs
|
||||||
ExtraServiceConfig,
|
ExtraServiceConfig,
|
||||||
|
# Observability configs
|
||||||
|
ObservabilityConfig,
|
||||||
# Remote source configs
|
# Remote source configs
|
||||||
RemoteSettingsSourceConfig,
|
RemoteSettingsSourceConfig,
|
||||||
# Enterprise feature configs
|
# Enterprise feature configs
|
||||||
|
|
|
@ -12,7 +12,7 @@ from pydantic import (
|
||||||
)
|
)
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
from configs.feature.hosted_service import HostedServiceConfig
|
from .hosted_service import HostedServiceConfig
|
||||||
|
|
||||||
|
|
||||||
class SecurityConfig(BaseSettings):
|
class SecurityConfig(BaseSettings):
|
||||||
|
@ -519,6 +519,11 @@ class WorkflowNodeExecutionConfig(BaseSettings):
|
||||||
default=100,
|
default=100,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
WORKFLOW_NODE_EXECUTION_STORAGE: str = Field(
|
||||||
|
default="rdbms",
|
||||||
|
description="Storage backend for WorkflowNodeExecution. Options: 'rdbms', 'hybrid'",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AuthConfig(BaseSettings):
|
class AuthConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -39,3 +39,8 @@ class MilvusConfig(BaseSettings):
|
||||||
"older versions",
|
"older versions",
|
||||||
default=True,
|
default=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
MILVUS_ANALYZER_PARAMS: Optional[str] = Field(
|
||||||
|
description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.',
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
from configs.observability.otel.otel_config import OTelConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ObservabilityConfig(OTelConfig):
|
||||||
|
"""
|
||||||
|
Observability configuration settings
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
|
@ -0,0 +1,44 @@
|
||||||
|
from pydantic import Field
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class OTelConfig(BaseSettings):
|
||||||
|
"""
|
||||||
|
OpenTelemetry configuration settings
|
||||||
|
"""
|
||||||
|
|
||||||
|
ENABLE_OTEL: bool = Field(
|
||||||
|
description="Whether to enable OpenTelemetry",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
OTLP_BASE_ENDPOINT: str = Field(
|
||||||
|
description="OTLP base endpoint",
|
||||||
|
default="http://localhost:4318",
|
||||||
|
)
|
||||||
|
|
||||||
|
OTLP_API_KEY: str = Field(
|
||||||
|
description="OTLP API key",
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
|
||||||
|
OTEL_EXPORTER_TYPE: str = Field(
|
||||||
|
description="OTEL exporter type",
|
||||||
|
default="otlp",
|
||||||
|
)
|
||||||
|
|
||||||
|
OTEL_SAMPLING_RATE: float = Field(default=0.1, description="Sampling rate for traces (0.0 to 1.0)")
|
||||||
|
|
||||||
|
OTEL_BATCH_EXPORT_SCHEDULE_DELAY: int = Field(
|
||||||
|
default=5000, description="Batch export schedule delay in milliseconds"
|
||||||
|
)
|
||||||
|
|
||||||
|
OTEL_MAX_QUEUE_SIZE: int = Field(default=2048, description="Maximum queue size for the batch span processor")
|
||||||
|
|
||||||
|
OTEL_MAX_EXPORT_BATCH_SIZE: int = Field(default=512, description="Maximum export batch size")
|
||||||
|
|
||||||
|
OTEL_METRIC_EXPORT_INTERVAL: int = Field(default=60000, description="Metric export interval in milliseconds")
|
||||||
|
|
||||||
|
OTEL_BATCH_EXPORT_TIMEOUT: int = Field(default=10000, description="Batch export timeout in milliseconds")
|
||||||
|
|
||||||
|
OTEL_METRIC_EXPORT_TIMEOUT: int = Field(default=30000, description="Metric export timeout in milliseconds")
|
|
@ -270,7 +270,7 @@ class ApolloClient:
|
||||||
while not self._stopping:
|
while not self._stopping:
|
||||||
for namespace in self._notification_map:
|
for namespace in self._notification_map:
|
||||||
self._do_heart_beat(namespace)
|
self._do_heart_beat(namespace)
|
||||||
time.sleep(60 * 10) # 10分钟
|
time.sleep(60 * 10) # 10 minutes
|
||||||
|
|
||||||
def _do_heart_beat(self, namespace):
|
def _do_heart_beat(self, namespace):
|
||||||
url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip)
|
url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip)
|
||||||
|
|
|
@ -3,6 +3,8 @@ from configs import dify_config
|
||||||
HIDDEN_VALUE = "[__HIDDEN__]"
|
HIDDEN_VALUE = "[__HIDDEN__]"
|
||||||
UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
||||||
|
|
||||||
|
DEFAULT_FILE_NUMBER_LIMITS = 3
|
||||||
|
|
||||||
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
||||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||||
|
|
||||||
|
|
|
@ -4,8 +4,6 @@ import platform
|
||||||
import re
|
import re
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Mapping
|
|
||||||
from typing import Any
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
@ -29,8 +27,6 @@ except ImportError:
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from configs import dify_config
|
|
||||||
|
|
||||||
|
|
||||||
class FileInfo(BaseModel):
|
class FileInfo(BaseModel):
|
||||||
filename: str
|
filename: str
|
||||||
|
@ -87,38 +83,3 @@ def guess_file_info_from_response(response: httpx.Response):
|
||||||
mimetype=mimetype,
|
mimetype=mimetype,
|
||||||
size=int(response.headers.get("Content-Length", -1)),
|
size=int(response.headers.get("Content-Length", -1)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_parameters_from_feature_dict(*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]):
|
|
||||||
return {
|
|
||||||
"opening_statement": features_dict.get("opening_statement"),
|
|
||||||
"suggested_questions": features_dict.get("suggested_questions", []),
|
|
||||||
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
|
|
||||||
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
|
|
||||||
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
|
|
||||||
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
|
|
||||||
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
|
|
||||||
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
|
|
||||||
"user_input_form": user_input_form,
|
|
||||||
"sensitive_word_avoidance": features_dict.get(
|
|
||||||
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
|
|
||||||
),
|
|
||||||
"file_upload": features_dict.get(
|
|
||||||
"file_upload",
|
|
||||||
{
|
|
||||||
"image": {
|
|
||||||
"enabled": False,
|
|
||||||
"number_limits": 3,
|
|
||||||
"detail": "high",
|
|
||||||
"transfer_methods": ["remote_url", "local_file"],
|
|
||||||
}
|
|
||||||
},
|
|
||||||
),
|
|
||||||
"system_parameters": {
|
|
||||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
|
||||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
|
||||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
|
||||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
|
||||||
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
|
@ -89,7 +89,7 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||||
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
|
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
|
||||||
cache_result = redis_client.get(app_annotation_job_key)
|
cache_result = redis_client.get(app_annotation_job_key)
|
||||||
if cache_result is None:
|
if cache_result is None:
|
||||||
raise ValueError("The job is not exist.")
|
raise ValueError("The job does not exist.")
|
||||||
|
|
||||||
job_status = cache_result.decode()
|
job_status = cache_result.decode()
|
||||||
error_msg = ""
|
error_msg = ""
|
||||||
|
@ -226,7 +226,7 @@ class AnnotationBatchImportStatusApi(Resource):
|
||||||
indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
|
indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
|
||||||
cache_result = redis_client.get(indexing_cache_key)
|
cache_result = redis_client.get(indexing_cache_key)
|
||||||
if cache_result is None:
|
if cache_result is None:
|
||||||
raise ValueError("The job is not exist.")
|
raise ValueError("The job does not exist.")
|
||||||
job_status = cache_result.decode()
|
job_status = cache_result.decode()
|
||||||
error_msg = ""
|
error_msg = ""
|
||||||
if job_status == "error":
|
if job_status == "error":
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from datetime import datetime
|
from dateutil.parser import isoparse
|
||||||
|
|
||||||
from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||||
from flask_restful.inputs import int_range # type: ignore
|
from flask_restful.inputs import int_range # type: ignore
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
@ -41,10 +40,10 @@ class WorkflowAppLogApi(Resource):
|
||||||
|
|
||||||
args.status = WorkflowRunStatus(args.status) if args.status else None
|
args.status = WorkflowRunStatus(args.status) if args.status else None
|
||||||
if args.created_at__before:
|
if args.created_at__before:
|
||||||
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00"))
|
args.created_at__before = isoparse(args.created_at__before)
|
||||||
|
|
||||||
if args.created_at__after:
|
if args.created_at__after:
|
||||||
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00"))
|
args.created_at__after = isoparse(args.created_at__after)
|
||||||
|
|
||||||
# get paginate workflow app logs
|
# get paginate workflow app logs
|
||||||
workflow_app_service = WorkflowAppService()
|
workflow_app_service = WorkflowAppService()
|
||||||
|
|
|
@ -74,7 +74,9 @@ class OAuthDataSourceBinding(Resource):
|
||||||
if not oauth_provider:
|
if not oauth_provider:
|
||||||
return {"error": "Invalid provider"}, 400
|
return {"error": "Invalid provider"}, 400
|
||||||
if "code" in request.args:
|
if "code" in request.args:
|
||||||
code = request.args.get("code")
|
code = request.args.get("code", "")
|
||||||
|
if not code:
|
||||||
|
return {"error": "Invalid code"}, 400
|
||||||
try:
|
try:
|
||||||
oauth_provider.get_access_token(code)
|
oauth_provider.get_access_token(code)
|
||||||
except requests.exceptions.HTTPError as e:
|
except requests.exceptions.HTTPError as e:
|
||||||
|
|
|
@ -398,7 +398,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||||
indexing_cache_key = "segment_batch_import_{}".format(job_id)
|
indexing_cache_key = "segment_batch_import_{}".format(job_id)
|
||||||
cache_result = redis_client.get(indexing_cache_key)
|
cache_result = redis_client.get(indexing_cache_key)
|
||||||
if cache_result is None:
|
if cache_result is None:
|
||||||
raise ValueError("The job is not exist.")
|
raise ValueError("The job does not exist.")
|
||||||
|
|
||||||
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
|
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
|
||||||
|
|
||||||
|
|
|
@ -21,12 +21,6 @@ def _validate_name(name):
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
def _validate_description_length(description):
|
|
||||||
if description and len(description) > 400:
|
|
||||||
raise ValueError("Description cannot exceed 400 characters.")
|
|
||||||
return description
|
|
||||||
|
|
||||||
|
|
||||||
class ExternalApiTemplateListApi(Resource):
|
class ExternalApiTemplateListApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
|
|
@ -14,18 +14,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||||
from services.metadata_service import MetadataService
|
from services.metadata_service import MetadataService
|
||||||
|
|
||||||
|
|
||||||
def _validate_name(name):
|
|
||||||
if not name or len(name) < 1 or len(name) > 40:
|
|
||||||
raise ValueError("Name must be between 1 to 40 characters.")
|
|
||||||
return name
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_description_length(description):
|
|
||||||
if len(description) > 400:
|
|
||||||
raise ValueError("Description cannot exceed 400 characters.")
|
|
||||||
return description
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetMetadataCreateApi(Resource):
|
class DatasetMetadataCreateApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
from flask_restful import marshal_with # type: ignore
|
from flask_restful import marshal_with # type: ignore
|
||||||
|
|
||||||
from controllers.common import fields
|
from controllers.common import fields
|
||||||
from controllers.common import helpers as controller_helpers
|
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.app.error import AppUnavailableError
|
from controllers.console.app.error import AppUnavailableError
|
||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
|
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||||
from models.model import AppMode, InstalledApp
|
from models.model import AppMode, InstalledApp
|
||||||
from services.app_service import AppService
|
from services.app_service import AppService
|
||||||
|
|
||||||
|
@ -36,9 +36,7 @@ class AppParameterApi(InstalledAppResource):
|
||||||
|
|
||||||
user_input_form = features_dict.get("user_input_form", [])
|
user_input_form = features_dict.get("user_input_form", [])
|
||||||
|
|
||||||
return controller_helpers.get_parameters_from_feature_dict(
|
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||||
features_dict=features_dict, user_input_form=user_input_form
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ExploreAppMetaApi(InstalledAppResource):
|
class ExploreAppMetaApi(InstalledAppResource):
|
||||||
|
|
|
@ -286,8 +286,6 @@ class AccountDeleteApi(Resource):
|
||||||
class AccountDeleteUpdateFeedbackApi(Resource):
|
class AccountDeleteUpdateFeedbackApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
def post(self):
|
def post(self):
|
||||||
account = current_user
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("email", type=str, required=True, location="json")
|
parser.add_argument("email", type=str, required=True, location="json")
|
||||||
parser.add_argument("feedback", type=str, required=True, location="json")
|
parser.add_argument("feedback", type=str, required=True, location="json")
|
||||||
|
|
|
@ -249,6 +249,31 @@ class PluginInstallFromMarketplaceApi(Resource):
|
||||||
return jsonable_encoder(response)
|
return jsonable_encoder(response)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginFetchMarketplacePkgApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@plugin_permission_required(install_required=True)
|
||||||
|
def get(self):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
return jsonable_encoder(
|
||||||
|
{
|
||||||
|
"manifest": PluginService.fetch_marketplace_pkg(
|
||||||
|
tenant_id,
|
||||||
|
args["plugin_unique_identifier"],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
|
||||||
class PluginFetchManifestApi(Resource):
|
class PluginFetchManifestApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -488,6 +513,7 @@ api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<
|
||||||
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
|
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
|
||||||
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
|
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
|
||||||
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
|
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
|
||||||
|
api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marketplace/pkg")
|
||||||
|
|
||||||
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
|
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
|
||||||
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
|
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
|
||||||
|
|
|
@ -13,6 +13,7 @@ from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocatio
|
||||||
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
|
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
|
||||||
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
|
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
|
||||||
from core.plugin.entities.request import (
|
from core.plugin.entities.request import (
|
||||||
|
RequestFetchAppInfo,
|
||||||
RequestInvokeApp,
|
RequestInvokeApp,
|
||||||
RequestInvokeEncrypt,
|
RequestInvokeEncrypt,
|
||||||
RequestInvokeLLM,
|
RequestInvokeLLM,
|
||||||
|
@ -278,6 +279,17 @@ class PluginUploadFileRequestApi(Resource):
|
||||||
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
|
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
class PluginFetchAppInfoApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
|
@plugin_data(payload_type=RequestFetchAppInfo)
|
||||||
|
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestFetchAppInfo):
|
||||||
|
return BaseBackwardsInvocationResponse(
|
||||||
|
data=PluginAppBackwardsInvocation.fetch_app_info(payload.app_id, tenant_model.id)
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
|
api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
|
||||||
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
|
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
|
||||||
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
|
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
|
||||||
|
@ -291,3 +303,4 @@ api.add_resource(PluginInvokeAppApi, "/invoke/app")
|
||||||
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
|
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
|
||||||
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
|
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
|
||||||
api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")
|
api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")
|
||||||
|
api.add_resource(PluginFetchAppInfoApi, "/fetch/app/info")
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
from flask_restful import Resource, marshal_with # type: ignore
|
from flask_restful import Resource, marshal_with # type: ignore
|
||||||
|
|
||||||
from controllers.common import fields
|
from controllers.common import fields
|
||||||
from controllers.common import helpers as controller_helpers
|
|
||||||
from controllers.service_api import api
|
from controllers.service_api import api
|
||||||
from controllers.service_api.app.error import AppUnavailableError
|
from controllers.service_api.app.error import AppUnavailableError
|
||||||
from controllers.service_api.wraps import validate_app_token
|
from controllers.service_api.wraps import validate_app_token
|
||||||
|
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||||
from models.model import App, AppMode
|
from models.model import App, AppMode
|
||||||
from services.app_service import AppService
|
from services.app_service import AppService
|
||||||
|
|
||||||
|
@ -32,9 +32,7 @@ class AppParameterApi(Resource):
|
||||||
|
|
||||||
user_input_form = features_dict.get("user_input_form", [])
|
user_input_form = features_dict.get("user_input_form", [])
|
||||||
|
|
||||||
return controller_helpers.get_parameters_from_feature_dict(
|
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||||
features_dict=features_dict, user_input_form=user_input_form
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AppMetaApi(Resource):
|
class AppMetaApi(Resource):
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
|
from dateutil.parser import isoparse
|
||||||
from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
|
from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
|
||||||
from flask_restful.inputs import int_range # type: ignore
|
from flask_restful.inputs import int_range # type: ignore
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
@ -140,10 +140,10 @@ class WorkflowAppLogApi(Resource):
|
||||||
|
|
||||||
args.status = WorkflowRunStatus(args.status) if args.status else None
|
args.status = WorkflowRunStatus(args.status) if args.status else None
|
||||||
if args.created_at__before:
|
if args.created_at__before:
|
||||||
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00"))
|
args.created_at__before = isoparse(args.created_at__before)
|
||||||
|
|
||||||
if args.created_at__after:
|
if args.created_at__after:
|
||||||
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00"))
|
args.created_at__after = isoparse(args.created_at__after)
|
||||||
|
|
||||||
# get paginate workflow app logs
|
# get paginate workflow app logs
|
||||||
workflow_app_service = WorkflowAppService()
|
workflow_app_service = WorkflowAppService()
|
||||||
|
|
|
@ -13,6 +13,7 @@ from fields.dataset_fields import dataset_detail_fields
|
||||||
from libs.login import current_user
|
from libs.login import current_user
|
||||||
from models.dataset import Dataset, DatasetPermissionEnum
|
from models.dataset import Dataset, DatasetPermissionEnum
|
||||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||||
|
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||||
|
|
||||||
|
|
||||||
def _validate_name(name):
|
def _validate_name(name):
|
||||||
|
@ -120,8 +121,11 @@ class DatasetListApi(DatasetApiResource):
|
||||||
nullable=True,
|
nullable=True,
|
||||||
required=False,
|
required=False,
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
|
||||||
|
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||||
|
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
try:
|
try:
|
||||||
dataset = DatasetService.create_empty_dataset(
|
dataset = DatasetService.create_empty_dataset(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
|
@ -133,6 +137,11 @@ class DatasetListApi(DatasetApiResource):
|
||||||
provider=args["provider"],
|
provider=args["provider"],
|
||||||
external_knowledge_api_id=args["external_knowledge_api_id"],
|
external_knowledge_api_id=args["external_knowledge_api_id"],
|
||||||
external_knowledge_id=args["external_knowledge_id"],
|
external_knowledge_id=args["external_knowledge_id"],
|
||||||
|
embedding_model_provider=args["embedding_model_provider"],
|
||||||
|
embedding_model_name=args["embedding_model"],
|
||||||
|
retrieval_model=RetrievalModel(**args["retrieval_model"])
|
||||||
|
if args["retrieval_model"] is not None
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
except services.errors.dataset.DatasetNameDuplicateError:
|
except services.errors.dataset.DatasetNameDuplicateError:
|
||||||
raise DatasetNameDuplicateError()
|
raise DatasetNameDuplicateError()
|
||||||
|
|
|
@ -49,7 +49,9 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
|
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
|
||||||
)
|
)
|
||||||
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
|
||||||
|
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||||
|
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
|
@ -57,7 +59,7 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||||
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError("Dataset is not exist.")
|
raise ValueError("Dataset does not exist.")
|
||||||
|
|
||||||
if not dataset.indexing_technique and not args["indexing_technique"]:
|
if not dataset.indexing_technique and not args["indexing_technique"]:
|
||||||
raise ValueError("indexing_technique is required.")
|
raise ValueError("indexing_technique is required.")
|
||||||
|
@ -114,7 +116,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
||||||
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError("Dataset is not exist.")
|
raise ValueError("Dataset does not exist.")
|
||||||
|
|
||||||
# indexing_technique is already set in dataset since this is an update
|
# indexing_technique is already set in dataset since this is an update
|
||||||
args["indexing_technique"] = dataset.indexing_technique
|
args["indexing_technique"] = dataset.indexing_technique
|
||||||
|
@ -172,7 +174,7 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||||
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError("Dataset is not exist.")
|
raise ValueError("Dataset does not exist.")
|
||||||
if not dataset.indexing_technique and not args.get("indexing_technique"):
|
if not dataset.indexing_technique and not args.get("indexing_technique"):
|
||||||
raise ValueError("indexing_technique is required.")
|
raise ValueError("indexing_technique is required.")
|
||||||
|
|
||||||
|
@ -239,7 +241,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||||
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError("Dataset is not exist.")
|
raise ValueError("Dataset does not exist.")
|
||||||
|
|
||||||
# indexing_technique is already set in dataset since this is an update
|
# indexing_technique is already set in dataset since this is an update
|
||||||
args["indexing_technique"] = dataset.indexing_technique
|
args["indexing_technique"] = dataset.indexing_technique
|
||||||
|
@ -303,7 +305,7 @@ class DocumentDeleteApi(DatasetApiResource):
|
||||||
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError("Dataset is not exist.")
|
raise ValueError("Dataset does not exist.")
|
||||||
|
|
||||||
document = DocumentService.get_document(dataset.id, document_id)
|
document = DocumentService.get_document(dataset.id, document_id)
|
||||||
|
|
||||||
|
|
|
@ -13,18 +13,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||||
from services.metadata_service import MetadataService
|
from services.metadata_service import MetadataService
|
||||||
|
|
||||||
|
|
||||||
def _validate_name(name):
|
|
||||||
if not name or len(name) < 1 or len(name) > 40:
|
|
||||||
raise ValueError("Name must be between 1 to 40 characters.")
|
|
||||||
return name
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_description_length(description):
|
|
||||||
if len(description) > 400:
|
|
||||||
raise ValueError("Description cannot exceed 400 characters.")
|
|
||||||
return description
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||||
def post(self, tenant_id, dataset_id):
|
def post(self, tenant_id, dataset_id):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
|
@ -117,14 +117,13 @@ class SegmentApi(DatasetApiResource):
|
||||||
parser.add_argument("keyword", type=str, default=None, location="args")
|
parser.add_argument("keyword", type=str, default=None, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
status_list = args["status"]
|
|
||||||
keyword = args["keyword"]
|
|
||||||
|
|
||||||
segments, total = SegmentService.get_segments(
|
segments, total = SegmentService.get_segments(
|
||||||
document_id=document_id,
|
document_id=document_id,
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
status_list=args["status"],
|
status_list=args["status"],
|
||||||
keyword=args["keyword"],
|
keyword=args["keyword"],
|
||||||
|
page=page,
|
||||||
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
from flask_restful import marshal_with # type: ignore
|
from flask_restful import marshal_with # type: ignore
|
||||||
|
|
||||||
from controllers.common import fields
|
from controllers.common import fields
|
||||||
from controllers.common import helpers as controller_helpers
|
|
||||||
from controllers.web import api
|
from controllers.web import api
|
||||||
from controllers.web.error import AppUnavailableError
|
from controllers.web.error import AppUnavailableError
|
||||||
from controllers.web.wraps import WebApiResource
|
from controllers.web.wraps import WebApiResource
|
||||||
|
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||||
from models.model import App, AppMode
|
from models.model import App, AppMode
|
||||||
from services.app_service import AppService
|
from services.app_service import AppService
|
||||||
|
|
||||||
|
@ -31,9 +31,7 @@ class AppParameterApi(WebApiResource):
|
||||||
|
|
||||||
user_input_form = features_dict.get("user_input_form", [])
|
user_input_form = features_dict.get("user_input_form", [])
|
||||||
|
|
||||||
return controller_helpers.get_parameters_from_feature_dict(
|
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||||
features_dict=features_dict, user_input_form=user_input_form
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AppMeta(WebApiResource):
|
class AppMeta(WebApiResource):
|
||||||
|
|
|
@ -46,6 +46,7 @@ class MessageListApi(WebApiResource):
|
||||||
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
||||||
"created_at": TimestampField,
|
"created_at": TimestampField,
|
||||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
|
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
|
||||||
|
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||||
"status": fields.String,
|
"status": fields.String,
|
||||||
"error": fields.String,
|
"error": fields.String,
|
||||||
}
|
}
|
||||||
|
|
|
@ -191,7 +191,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||||
# action is final answer, return final answer directly
|
# action is final answer, return final answer directly
|
||||||
try:
|
try:
|
||||||
if isinstance(scratchpad.action.action_input, dict):
|
if isinstance(scratchpad.action.action_input, dict):
|
||||||
final_answer = json.dumps(scratchpad.action.action_input)
|
final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
|
||||||
elif isinstance(scratchpad.action.action_input, str):
|
elif isinstance(scratchpad.action.action_input, str):
|
||||||
final_answer = scratchpad.action.action_input
|
final_answer = scratchpad.action.action_input
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -52,6 +52,7 @@ class AgentStrategyParameter(PluginParameter):
|
||||||
return cast_parameter_value(self, value)
|
return cast_parameter_value(self, value)
|
||||||
|
|
||||||
type: AgentStrategyParameterType = Field(..., description="The type of the parameter")
|
type: AgentStrategyParameterType = Field(..., description="The type of the parameter")
|
||||||
|
help: Optional[I18nObject] = None
|
||||||
|
|
||||||
def init_frontend_parameter(self, value: Any):
|
def init_frontend_parameter(self, value: Any):
|
||||||
return init_frontend_parameter(self, self.type, value)
|
return init_frontend_parameter(self, self.type, value)
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||||
|
|
||||||
|
|
||||||
|
def get_parameters_from_feature_dict(
|
||||||
|
*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
"""
|
||||||
|
Mapping from feature dict to webapp parameters
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"opening_statement": features_dict.get("opening_statement"),
|
||||||
|
"suggested_questions": features_dict.get("suggested_questions", []),
|
||||||
|
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
|
||||||
|
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
|
||||||
|
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
|
||||||
|
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
|
||||||
|
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
|
||||||
|
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
|
||||||
|
"user_input_form": user_input_form,
|
||||||
|
"sensitive_word_avoidance": features_dict.get(
|
||||||
|
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
|
||||||
|
),
|
||||||
|
"file_upload": features_dict.get(
|
||||||
|
"file_upload",
|
||||||
|
{
|
||||||
|
"image": {
|
||||||
|
"enabled": False,
|
||||||
|
"number_limits": DEFAULT_FILE_NUMBER_LIMITS,
|
||||||
|
"detail": "high",
|
||||||
|
"transfer_methods": ["remote_url", "local_file"],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"system_parameters": {
|
||||||
|
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||||
|
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||||
|
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||||
|
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||||
|
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
|
||||||
|
},
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||||
from core.file import FileUploadConfig
|
from core.file import FileUploadConfig
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,7 +19,7 @@ class FileUploadConfigManager:
|
||||||
if file_upload_dict.get("enabled"):
|
if file_upload_dict.get("enabled"):
|
||||||
transform_methods = file_upload_dict.get("allowed_file_upload_methods", [])
|
transform_methods = file_upload_dict.get("allowed_file_upload_methods", [])
|
||||||
file_upload_dict["image_config"] = {
|
file_upload_dict["image_config"] = {
|
||||||
"number_limits": file_upload_dict.get("number_limits", 1),
|
"number_limits": file_upload_dict.get("number_limits", DEFAULT_FILE_NUMBER_LIMITS),
|
||||||
"transfer_methods": transform_methods,
|
"transfer_methods": transform_methods,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -320,10 +320,9 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||||
session=session, workflow_run_id=self._workflow_run_id
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
)
|
)
|
||||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
|
||||||
session=session, workflow_run=workflow_run, event=event
|
workflow_run=workflow_run, event=event
|
||||||
)
|
)
|
||||||
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
|
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
|
||||||
session=session,
|
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_node_execution=workflow_node_execution,
|
workflow_node_execution=workflow_node_execution,
|
||||||
|
@ -341,11 +340,10 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||||
session=session, workflow_run_id=self._workflow_run_id
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
)
|
)
|
||||||
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
|
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
|
||||||
session=session, workflow_run=workflow_run, event=event
|
workflow_run=workflow_run, event=event
|
||||||
)
|
)
|
||||||
|
|
||||||
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
|
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
|
||||||
session=session,
|
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_node_execution=workflow_node_execution,
|
workflow_node_execution=workflow_node_execution,
|
||||||
|
@ -363,11 +361,10 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||||
|
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
|
||||||
session=session, event=event
|
event=event
|
||||||
)
|
)
|
||||||
|
|
||||||
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||||
session=session,
|
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_node_execution=workflow_node_execution,
|
workflow_node_execution=workflow_node_execution,
|
||||||
|
@ -383,18 +380,15 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||||
| QueueNodeInLoopFailedEvent
|
| QueueNodeInLoopFailedEvent
|
||||||
| QueueNodeExceptionEvent,
|
| QueueNodeExceptionEvent,
|
||||||
):
|
):
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
event=event
|
||||||
session=session, event=event
|
)
|
||||||
)
|
|
||||||
|
|
||||||
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||||
session=session,
|
event=event,
|
||||||
event=event,
|
task_id=self._application_generate_entity.task_id,
|
||||||
task_id=self._application_generate_entity.task_id,
|
workflow_node_execution=workflow_node_execution,
|
||||||
workflow_node_execution=workflow_node_execution,
|
)
|
||||||
)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
if node_finish_resp:
|
if node_finish_resp:
|
||||||
yield node_finish_resp
|
yield node_finish_resp
|
||||||
|
|
|
@ -17,6 +17,7 @@ class BaseAppGenerator:
|
||||||
user_inputs: Optional[Mapping[str, Any]],
|
user_inputs: Optional[Mapping[str, Any]],
|
||||||
variables: Sequence["VariableEntity"],
|
variables: Sequence["VariableEntity"],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
|
strict_type_validation: bool = False,
|
||||||
) -> Mapping[str, Any]:
|
) -> Mapping[str, Any]:
|
||||||
user_inputs = user_inputs or {}
|
user_inputs = user_inputs or {}
|
||||||
# Filter input variables from form configuration, handle required fields, default values, and option values
|
# Filter input variables from form configuration, handle required fields, default values, and option values
|
||||||
|
@ -37,6 +38,7 @@ class BaseAppGenerator:
|
||||||
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
|
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
|
||||||
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
||||||
),
|
),
|
||||||
|
strict_type_validation=strict_type_validation,
|
||||||
)
|
)
|
||||||
for k, v in user_inputs.items()
|
for k, v in user_inputs.items()
|
||||||
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
|
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
|
||||||
|
|
|
@ -153,6 +153,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||||
query = application_generate_entity.query or "New conversation"
|
query = application_generate_entity.query or "New conversation"
|
||||||
else:
|
else:
|
||||||
query = next(iter(application_generate_entity.inputs.values()), "New conversation")
|
query = next(iter(application_generate_entity.inputs.values()), "New conversation")
|
||||||
|
query = query or "New conversation"
|
||||||
conversation_name = (query[:20] + "…") if len(query) > 20 else query
|
conversation_name = (query[:20] + "…") if len(query) > 20 else query
|
||||||
|
|
||||||
if not conversation:
|
if not conversation:
|
||||||
|
|
|
@ -92,6 +92,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||||
mappings=files,
|
mappings=files,
|
||||||
tenant_id=app_model.tenant_id,
|
tenant_id=app_model.tenant_id,
|
||||||
config=file_extra_config,
|
config=file_extra_config,
|
||||||
|
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# convert to app config
|
# convert to app config
|
||||||
|
@ -114,7 +115,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
file_upload_config=file_extra_config,
|
file_upload_config=file_extra_config,
|
||||||
inputs=self._prepare_user_inputs(
|
inputs=self._prepare_user_inputs(
|
||||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
user_inputs=inputs,
|
||||||
|
variables=app_config.variables,
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||||
),
|
),
|
||||||
files=list(system_files),
|
files=list(system_files),
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
|
|
|
@ -279,10 +279,9 @@ class WorkflowAppGenerateTaskPipeline:
|
||||||
session=session, workflow_run_id=self._workflow_run_id
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
)
|
)
|
||||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
|
||||||
session=session, workflow_run=workflow_run, event=event
|
workflow_run=workflow_run, event=event
|
||||||
)
|
)
|
||||||
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
|
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
|
||||||
session=session,
|
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_node_execution=workflow_node_execution,
|
workflow_node_execution=workflow_node_execution,
|
||||||
|
@ -300,10 +299,9 @@ class WorkflowAppGenerateTaskPipeline:
|
||||||
session=session, workflow_run_id=self._workflow_run_id
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
)
|
)
|
||||||
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
|
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
|
||||||
session=session, workflow_run=workflow_run, event=event
|
workflow_run=workflow_run, event=event
|
||||||
)
|
)
|
||||||
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
|
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
|
||||||
session=session,
|
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_node_execution=workflow_node_execution,
|
workflow_node_execution=workflow_node_execution,
|
||||||
|
@ -313,17 +311,14 @@ class WorkflowAppGenerateTaskPipeline:
|
||||||
if node_start_response:
|
if node_start_response:
|
||||||
yield node_start_response
|
yield node_start_response
|
||||||
elif isinstance(event, QueueNodeSucceededEvent):
|
elif isinstance(event, QueueNodeSucceededEvent):
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
|
||||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
|
event=event
|
||||||
session=session, event=event
|
)
|
||||||
)
|
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||||
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
event=event,
|
||||||
session=session,
|
task_id=self._application_generate_entity.task_id,
|
||||||
event=event,
|
workflow_node_execution=workflow_node_execution,
|
||||||
task_id=self._application_generate_entity.task_id,
|
)
|
||||||
workflow_node_execution=workflow_node_execution,
|
|
||||||
)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
if node_success_response:
|
if node_success_response:
|
||||||
yield node_success_response
|
yield node_success_response
|
||||||
|
@ -334,18 +329,14 @@ class WorkflowAppGenerateTaskPipeline:
|
||||||
| QueueNodeInLoopFailedEvent
|
| QueueNodeInLoopFailedEvent
|
||||||
| QueueNodeExceptionEvent,
|
| QueueNodeExceptionEvent,
|
||||||
):
|
):
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
event=event,
|
||||||
session=session,
|
)
|
||||||
event=event,
|
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||||
)
|
event=event,
|
||||||
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
task_id=self._application_generate_entity.task_id,
|
||||||
session=session,
|
workflow_node_execution=workflow_node_execution,
|
||||||
event=event,
|
)
|
||||||
task_id=self._application_generate_entity.task_id,
|
|
||||||
workflow_node_execution=workflow_node_execution,
|
|
||||||
)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
if node_failed_response:
|
if node_failed_response:
|
||||||
yield node_failed_response
|
yield node_failed_response
|
||||||
|
@ -627,6 +618,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||||
workflow_app_log.created_by = self._user_id
|
workflow_app_log.created_by = self._user_id
|
||||||
|
|
||||||
session.add(workflow_app_log)
|
session.add(workflow_app_log)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
def _text_chunk_to_stream_response(
|
def _text_chunk_to_stream_response(
|
||||||
self, text: str, from_variable_selector: Optional[list[str]] = None
|
self, text: str, from_variable_selector: Optional[list[str]] = None
|
||||||
|
|
|
@ -6,7 +6,7 @@ from typing import Any, Optional, Union, cast
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from sqlalchemy import func, select
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||||
from core.app.entities.queue_entities import (
|
from core.app.entities.queue_entities import (
|
||||||
|
@ -49,12 +49,14 @@ from core.file import FILE_MODEL_IDENTITY, File
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.ops.entities.trace_entity import TraceTaskName
|
from core.ops.entities.trace_entity import TraceTaskName
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||||
|
from core.repository import RepositoryFactory
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
|
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
@ -80,6 +82,21 @@ class WorkflowCycleManage:
|
||||||
self._application_generate_entity = application_generate_entity
|
self._application_generate_entity = application_generate_entity
|
||||||
self._workflow_system_variables = workflow_system_variables
|
self._workflow_system_variables = workflow_system_variables
|
||||||
|
|
||||||
|
# Initialize the session factory and repository
|
||||||
|
# We use the global db engine instead of the session passed to methods
|
||||||
|
# Disable expire_on_commit to avoid the need for merging objects
|
||||||
|
self._session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
self._workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
|
||||||
|
params={
|
||||||
|
"tenant_id": self._application_generate_entity.app_config.tenant_id,
|
||||||
|
"app_id": self._application_generate_entity.app_config.app_id,
|
||||||
|
"session_factory": self._session_factory,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# We'll still keep the cache for backward compatibility and performance
|
||||||
|
# but use the repository for database operations
|
||||||
|
|
||||||
def _handle_workflow_run_start(
|
def _handle_workflow_run_start(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
@ -254,19 +271,15 @@ class WorkflowCycleManage:
|
||||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
workflow_run.exceptions_count = exceptions_count
|
workflow_run.exceptions_count = exceptions_count
|
||||||
|
|
||||||
stmt = select(WorkflowNodeExecution.node_execution_id).where(
|
# Use the instance repository to find running executions for a workflow run
|
||||||
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
|
running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions(
|
||||||
WorkflowNodeExecution.app_id == workflow_run.app_id,
|
workflow_run_id=workflow_run.id
|
||||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
|
||||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
|
||||||
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
|
|
||||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
|
|
||||||
)
|
)
|
||||||
ids = session.scalars(stmt).all()
|
|
||||||
# Use self._get_workflow_node_execution here to make sure the cache is updated
|
# Update the cache with the retrieved executions
|
||||||
running_workflow_node_executions = [
|
for execution in running_workflow_node_executions:
|
||||||
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
|
if execution.node_execution_id:
|
||||||
]
|
self._workflow_node_executions[execution.node_execution_id] = execution
|
||||||
|
|
||||||
for workflow_node_execution in running_workflow_node_executions:
|
for workflow_node_execution in running_workflow_node_executions:
|
||||||
now = datetime.now(UTC).replace(tzinfo=None)
|
now = datetime.now(UTC).replace(tzinfo=None)
|
||||||
|
@ -288,7 +301,7 @@ class WorkflowCycleManage:
|
||||||
return workflow_run
|
return workflow_run
|
||||||
|
|
||||||
def _handle_node_execution_start(
|
def _handle_node_execution_start(
|
||||||
self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
|
self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
|
||||||
) -> WorkflowNodeExecution:
|
) -> WorkflowNodeExecution:
|
||||||
workflow_node_execution = WorkflowNodeExecution()
|
workflow_node_execution = WorkflowNodeExecution()
|
||||||
workflow_node_execution.id = str(uuid4())
|
workflow_node_execution.id = str(uuid4())
|
||||||
|
@ -315,17 +328,14 @@ class WorkflowCycleManage:
|
||||||
)
|
)
|
||||||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
|
|
||||||
session.add(workflow_node_execution)
|
# Use the instance repository to save the workflow node execution
|
||||||
|
self._workflow_node_execution_repository.save(workflow_node_execution)
|
||||||
|
|
||||||
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
def _handle_workflow_node_execution_success(
|
def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
|
||||||
self, *, session: Session, event: QueueNodeSucceededEvent
|
workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
|
||||||
) -> WorkflowNodeExecution:
|
|
||||||
workflow_node_execution = self._get_workflow_node_execution(
|
|
||||||
session=session, node_execution_id=event.node_execution_id
|
|
||||||
)
|
|
||||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||||
|
@ -344,13 +354,13 @@ class WorkflowCycleManage:
|
||||||
workflow_node_execution.finished_at = finished_at
|
workflow_node_execution.finished_at = finished_at
|
||||||
workflow_node_execution.elapsed_time = elapsed_time
|
workflow_node_execution.elapsed_time = elapsed_time
|
||||||
|
|
||||||
workflow_node_execution = session.merge(workflow_node_execution)
|
# Use the instance repository to update the workflow node execution
|
||||||
|
self._workflow_node_execution_repository.update(workflow_node_execution)
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
def _handle_workflow_node_execution_failed(
|
def _handle_workflow_node_execution_failed(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
session: Session,
|
|
||||||
event: QueueNodeFailedEvent
|
event: QueueNodeFailedEvent
|
||||||
| QueueNodeInIterationFailedEvent
|
| QueueNodeInIterationFailedEvent
|
||||||
| QueueNodeInLoopFailedEvent
|
| QueueNodeInLoopFailedEvent
|
||||||
|
@ -361,9 +371,7 @@ class WorkflowCycleManage:
|
||||||
:param event: queue node failed event
|
:param event: queue node failed event
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
workflow_node_execution = self._get_workflow_node_execution(
|
workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
|
||||||
session=session, node_execution_id=event.node_execution_id
|
|
||||||
)
|
|
||||||
|
|
||||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||||
|
@ -387,14 +395,14 @@ class WorkflowCycleManage:
|
||||||
workflow_node_execution.elapsed_time = elapsed_time
|
workflow_node_execution.elapsed_time = elapsed_time
|
||||||
workflow_node_execution.execution_metadata = execution_metadata
|
workflow_node_execution.execution_metadata = execution_metadata
|
||||||
|
|
||||||
workflow_node_execution = session.merge(workflow_node_execution)
|
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
def _handle_workflow_node_execution_retried(
|
def _handle_workflow_node_execution_retried(
|
||||||
self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
|
self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
|
||||||
) -> WorkflowNodeExecution:
|
) -> WorkflowNodeExecution:
|
||||||
"""
|
"""
|
||||||
Workflow node execution failed
|
Workflow node execution failed
|
||||||
|
:param workflow_run: workflow run
|
||||||
:param event: queue node failed event
|
:param event: queue node failed event
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
@ -439,15 +447,12 @@ class WorkflowCycleManage:
|
||||||
workflow_node_execution.execution_metadata = execution_metadata
|
workflow_node_execution.execution_metadata = execution_metadata
|
||||||
workflow_node_execution.index = event.node_run_index
|
workflow_node_execution.index = event.node_run_index
|
||||||
|
|
||||||
session.add(workflow_node_execution)
|
# Use the instance repository to save the workflow node execution
|
||||||
|
self._workflow_node_execution_repository.save(workflow_node_execution)
|
||||||
|
|
||||||
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
#################################################
|
|
||||||
# to stream responses #
|
|
||||||
#################################################
|
|
||||||
|
|
||||||
def _workflow_start_to_stream_response(
|
def _workflow_start_to_stream_response(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
@ -455,7 +460,6 @@ class WorkflowCycleManage:
|
||||||
task_id: str,
|
task_id: str,
|
||||||
workflow_run: WorkflowRun,
|
workflow_run: WorkflowRun,
|
||||||
) -> WorkflowStartStreamResponse:
|
) -> WorkflowStartStreamResponse:
|
||||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
|
||||||
_ = session
|
_ = session
|
||||||
return WorkflowStartStreamResponse(
|
return WorkflowStartStreamResponse(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
|
@ -521,14 +525,10 @@ class WorkflowCycleManage:
|
||||||
def _workflow_node_start_to_stream_response(
|
def _workflow_node_start_to_stream_response(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
session: Session,
|
|
||||||
event: QueueNodeStartedEvent,
|
event: QueueNodeStartedEvent,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
workflow_node_execution: WorkflowNodeExecution,
|
workflow_node_execution: WorkflowNodeExecution,
|
||||||
) -> Optional[NodeStartStreamResponse]:
|
) -> Optional[NodeStartStreamResponse]:
|
||||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
|
||||||
_ = session
|
|
||||||
|
|
||||||
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
||||||
return None
|
return None
|
||||||
if not workflow_node_execution.workflow_run_id:
|
if not workflow_node_execution.workflow_run_id:
|
||||||
|
@ -571,7 +571,6 @@ class WorkflowCycleManage:
|
||||||
def _workflow_node_finish_to_stream_response(
|
def _workflow_node_finish_to_stream_response(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
session: Session,
|
|
||||||
event: QueueNodeSucceededEvent
|
event: QueueNodeSucceededEvent
|
||||||
| QueueNodeFailedEvent
|
| QueueNodeFailedEvent
|
||||||
| QueueNodeInIterationFailedEvent
|
| QueueNodeInIterationFailedEvent
|
||||||
|
@ -580,8 +579,6 @@ class WorkflowCycleManage:
|
||||||
task_id: str,
|
task_id: str,
|
||||||
workflow_node_execution: WorkflowNodeExecution,
|
workflow_node_execution: WorkflowNodeExecution,
|
||||||
) -> Optional[NodeFinishStreamResponse]:
|
) -> Optional[NodeFinishStreamResponse]:
|
||||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
|
||||||
_ = session
|
|
||||||
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
||||||
return None
|
return None
|
||||||
if not workflow_node_execution.workflow_run_id:
|
if not workflow_node_execution.workflow_run_id:
|
||||||
|
@ -621,13 +618,10 @@ class WorkflowCycleManage:
|
||||||
def _workflow_node_retry_to_stream_response(
|
def _workflow_node_retry_to_stream_response(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
session: Session,
|
|
||||||
event: QueueNodeRetryEvent,
|
event: QueueNodeRetryEvent,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
workflow_node_execution: WorkflowNodeExecution,
|
workflow_node_execution: WorkflowNodeExecution,
|
||||||
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
|
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
|
||||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
|
||||||
_ = session
|
|
||||||
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
||||||
return None
|
return None
|
||||||
if not workflow_node_execution.workflow_run_id:
|
if not workflow_node_execution.workflow_run_id:
|
||||||
|
@ -668,7 +662,6 @@ class WorkflowCycleManage:
|
||||||
def _workflow_parallel_branch_start_to_stream_response(
|
def _workflow_parallel_branch_start_to_stream_response(
|
||||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
|
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
|
||||||
) -> ParallelBranchStartStreamResponse:
|
) -> ParallelBranchStartStreamResponse:
|
||||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
|
||||||
_ = session
|
_ = session
|
||||||
return ParallelBranchStartStreamResponse(
|
return ParallelBranchStartStreamResponse(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
|
@ -692,7 +685,6 @@ class WorkflowCycleManage:
|
||||||
workflow_run: WorkflowRun,
|
workflow_run: WorkflowRun,
|
||||||
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
|
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
|
||||||
) -> ParallelBranchFinishedStreamResponse:
|
) -> ParallelBranchFinishedStreamResponse:
|
||||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
|
||||||
_ = session
|
_ = session
|
||||||
return ParallelBranchFinishedStreamResponse(
|
return ParallelBranchFinishedStreamResponse(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
|
@ -713,7 +705,6 @@ class WorkflowCycleManage:
|
||||||
def _workflow_iteration_start_to_stream_response(
|
def _workflow_iteration_start_to_stream_response(
|
||||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
|
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
|
||||||
) -> IterationNodeStartStreamResponse:
|
) -> IterationNodeStartStreamResponse:
|
||||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
|
||||||
_ = session
|
_ = session
|
||||||
return IterationNodeStartStreamResponse(
|
return IterationNodeStartStreamResponse(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
|
@ -735,7 +726,6 @@ class WorkflowCycleManage:
|
||||||
def _workflow_iteration_next_to_stream_response(
|
def _workflow_iteration_next_to_stream_response(
|
||||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
|
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
|
||||||
) -> IterationNodeNextStreamResponse:
|
) -> IterationNodeNextStreamResponse:
|
||||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
|
||||||
_ = session
|
_ = session
|
||||||
return IterationNodeNextStreamResponse(
|
return IterationNodeNextStreamResponse(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
|
@ -759,7 +749,6 @@ class WorkflowCycleManage:
|
||||||
def _workflow_iteration_completed_to_stream_response(
|
def _workflow_iteration_completed_to_stream_response(
|
||||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
|
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
|
||||||
) -> IterationNodeCompletedStreamResponse:
|
) -> IterationNodeCompletedStreamResponse:
|
||||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
|
||||||
_ = session
|
_ = session
|
||||||
return IterationNodeCompletedStreamResponse(
|
return IterationNodeCompletedStreamResponse(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
|
@ -790,7 +779,6 @@ class WorkflowCycleManage:
|
||||||
def _workflow_loop_start_to_stream_response(
|
def _workflow_loop_start_to_stream_response(
|
||||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent
|
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent
|
||||||
) -> LoopNodeStartStreamResponse:
|
) -> LoopNodeStartStreamResponse:
|
||||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
|
||||||
_ = session
|
_ = session
|
||||||
return LoopNodeStartStreamResponse(
|
return LoopNodeStartStreamResponse(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
|
@ -812,7 +800,6 @@ class WorkflowCycleManage:
|
||||||
def _workflow_loop_next_to_stream_response(
|
def _workflow_loop_next_to_stream_response(
|
||||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent
|
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent
|
||||||
) -> LoopNodeNextStreamResponse:
|
) -> LoopNodeNextStreamResponse:
|
||||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
|
||||||
_ = session
|
_ = session
|
||||||
return LoopNodeNextStreamResponse(
|
return LoopNodeNextStreamResponse(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
|
@ -836,7 +823,6 @@ class WorkflowCycleManage:
|
||||||
def _workflow_loop_completed_to_stream_response(
|
def _workflow_loop_completed_to_stream_response(
|
||||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent
|
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent
|
||||||
) -> LoopNodeCompletedStreamResponse:
|
) -> LoopNodeCompletedStreamResponse:
|
||||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
|
||||||
_ = session
|
_ = session
|
||||||
return LoopNodeCompletedStreamResponse(
|
return LoopNodeCompletedStreamResponse(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
|
@ -934,11 +920,22 @@ class WorkflowCycleManage:
|
||||||
|
|
||||||
return workflow_run
|
return workflow_run
|
||||||
|
|
||||||
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
|
def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
|
||||||
if node_execution_id not in self._workflow_node_executions:
|
# First check the cache for performance
|
||||||
|
if node_execution_id in self._workflow_node_executions:
|
||||||
|
cached_execution = self._workflow_node_executions[node_execution_id]
|
||||||
|
# No need to merge with session since expire_on_commit=False
|
||||||
|
return cached_execution
|
||||||
|
|
||||||
|
# If not in cache, use the instance repository to get by node_execution_id
|
||||||
|
execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id)
|
||||||
|
|
||||||
|
if not execution:
|
||||||
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
|
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
|
||||||
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
|
|
||||||
return session.merge(cached_workflow_node_execution)
|
# Update cache
|
||||||
|
self._workflow_node_executions[node_execution_id] = execution
|
||||||
|
return execution
|
||||||
|
|
||||||
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -6,7 +6,6 @@ from core.rag.models.document import Document
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
|
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
from models.model import DatasetRetrieverResource
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetIndexToolCallbackHandler:
|
class DatasetIndexToolCallbackHandler:
|
||||||
|
@ -71,29 +70,6 @@ class DatasetIndexToolCallbackHandler:
|
||||||
|
|
||||||
def return_retriever_resource_info(self, resource: list):
|
def return_retriever_resource_info(self, resource: list):
|
||||||
"""Handle return_retriever_resource_info."""
|
"""Handle return_retriever_resource_info."""
|
||||||
if resource and len(resource) > 0:
|
|
||||||
for item in resource:
|
|
||||||
dataset_retriever_resource = DatasetRetrieverResource(
|
|
||||||
message_id=self._message_id,
|
|
||||||
position=item.get("position") or 0,
|
|
||||||
dataset_id=item.get("dataset_id"),
|
|
||||||
dataset_name=item.get("dataset_name"),
|
|
||||||
document_id=item.get("document_id"),
|
|
||||||
document_name=item.get("document_name"),
|
|
||||||
data_source_type=item.get("data_source_type"),
|
|
||||||
segment_id=item.get("segment_id"),
|
|
||||||
score=item.get("score") if "score" in item else None,
|
|
||||||
hit_count=item.get("hit_count") if "hit_count" in item else None,
|
|
||||||
word_count=item.get("word_count") if "word_count" in item else None,
|
|
||||||
segment_position=item.get("segment_position") if "segment_position" in item else None,
|
|
||||||
index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
|
|
||||||
content=item.get("content"),
|
|
||||||
retriever_from=item.get("retriever_from"),
|
|
||||||
created_by=self._user_id,
|
|
||||||
)
|
|
||||||
db.session.add(dataset_retriever_resource)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
self._queue_manager.publish(
|
self._queue_manager.publish(
|
||||||
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
|
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
|
||||||
)
|
)
|
||||||
|
|
|
@ -48,25 +48,26 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||||
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
|
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "ssl_verify" not in kwargs:
|
||||||
|
kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY
|
||||||
|
|
||||||
|
ssl_verify = kwargs.pop("ssl_verify")
|
||||||
|
|
||||||
retries = 0
|
retries = 0
|
||||||
while retries <= max_retries:
|
while retries <= max_retries:
|
||||||
try:
|
try:
|
||||||
if dify_config.SSRF_PROXY_ALL_URL:
|
if dify_config.SSRF_PROXY_ALL_URL:
|
||||||
with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client:
|
with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=ssl_verify) as client:
|
||||||
response = client.request(method=method, url=url, **kwargs)
|
response = client.request(method=method, url=url, **kwargs)
|
||||||
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
||||||
proxy_mounts = {
|
proxy_mounts = {
|
||||||
"http://": httpx.HTTPTransport(
|
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=ssl_verify),
|
||||||
proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
|
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=ssl_verify),
|
||||||
),
|
|
||||||
"https://": httpx.HTTPTransport(
|
|
||||||
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
with httpx.Client(mounts=proxy_mounts, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client:
|
with httpx.Client(mounts=proxy_mounts, verify=ssl_verify) as client:
|
||||||
response = client.request(method=method, url=url, **kwargs)
|
response = client.request(method=method, url=url, **kwargs)
|
||||||
else:
|
else:
|
||||||
with httpx.Client(verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client:
|
with httpx.Client(verify=ssl_verify) as client:
|
||||||
response = client.request(method=method, url=url, **kwargs)
|
response = client.request(method=method, url=url, **kwargs)
|
||||||
|
|
||||||
if response.status_code not in STATUS_FORCELIST:
|
if response.status_code not in STATUS_FORCELIST:
|
||||||
|
|
|
@ -44,6 +44,7 @@ class TokenBufferMemory:
|
||||||
Message.created_at,
|
Message.created_at,
|
||||||
Message.workflow_run_id,
|
Message.workflow_run_id,
|
||||||
Message.parent_message_id,
|
Message.parent_message_id,
|
||||||
|
Message.answer_tokens,
|
||||||
)
|
)
|
||||||
.filter(
|
.filter(
|
||||||
Message.conversation_id == self.conversation.id,
|
Message.conversation_id == self.conversation.id,
|
||||||
|
@ -63,7 +64,7 @@ class TokenBufferMemory:
|
||||||
thread_messages = extract_thread_messages(messages)
|
thread_messages = extract_thread_messages(messages)
|
||||||
|
|
||||||
# for newly created message, its answer is temporarily empty, we don't need to add it to memory
|
# for newly created message, its answer is temporarily empty, we don't need to add it to memory
|
||||||
if thread_messages and not thread_messages[0].answer:
|
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
|
||||||
thread_messages.pop(0)
|
thread_messages.pop(0)
|
||||||
|
|
||||||
messages = list(reversed(thread_messages))
|
messages = list(reversed(thread_messages))
|
||||||
|
|
|
@ -177,7 +177,7 @@ class ModelInstance:
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_llm_num_tokens(
|
def get_llm_num_tokens(
|
||||||
self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
|
self, prompt_messages: Sequence[PromptMessage], tools: Optional[Sequence[PromptMessageTool]] = None
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Get number of tokens for llm
|
Get number of tokens for llm
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
- 支持 5 种模型类型的能力调用
|
- 支持 5 种模型类型的能力调用
|
||||||
|
|
||||||
- `LLM` - LLM 文本补全、对话,预计算 tokens 能力
|
- `LLM` - LLM 文本补全、对话,预计算 tokens 能力
|
||||||
- `Text Embedding Model` - 文本 Embedding ,预计算 tokens 能力
|
- `Text Embedding Model` - 文本 Embedding,预计算 tokens 能力
|
||||||
- `Rerank Model` - 分段 Rerank 能力
|
- `Rerank Model` - 分段 Rerank 能力
|
||||||
- `Speech-to-text Model` - 语音转文本能力
|
- `Speech-to-text Model` - 语音转文本能力
|
||||||
- `Text-to-speech Model` - 文本转语音能力
|
- `Text-to-speech Model` - 文本转语音能力
|
||||||
|
@ -57,11 +57,11 @@ Model Runtime 分三层:
|
||||||
提供获取当前供应商模型列表、获取模型实例、供应商凭据鉴权、供应商配置规则信息,**可横向扩展**以支持不同的供应商。
|
提供获取当前供应商模型列表、获取模型实例、供应商凭据鉴权、供应商配置规则信息,**可横向扩展**以支持不同的供应商。
|
||||||
|
|
||||||
对于供应商/模型凭据,有两种情况
|
对于供应商/模型凭据,有两种情况
|
||||||
- 如OpenAI这类中心化供应商,需要定义如**api_key**这类的鉴权凭据
|
- 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据
|
||||||
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
|
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
|
||||||

|

|
||||||
|
|
||||||
当配置好凭据后,就可以通过DifyRuntime的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。
|
当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。
|
||||||
|
|
||||||
- 最底层为模型层
|
- 最底层为模型层
|
||||||
|
|
||||||
|
@ -69,9 +69,9 @@ Model Runtime 分三层:
|
||||||
|
|
||||||
在这里我们需要先区分模型参数与模型凭据。
|
在这里我们需要先区分模型参数与模型凭据。
|
||||||
|
|
||||||
- 模型参数(**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等,这些参数是由用户在前端页面上进行调整的,因此需要在后端定义参数的规则,以便前端页面进行展示和调整。在DifyRuntime中,他们的参数名一般为**model_parameters: dict[str, any]**。
|
- 模型参数 (**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等,这些参数是由用户在前端页面上进行调整的,因此需要在后端定义参数的规则,以便前端页面进行展示和调整。在 DifyRuntime 中,他们的参数名一般为**model_parameters: dict[str, any]**。
|
||||||
|
|
||||||
- 模型凭据(**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在DifyRuntime中,他们的参数名一般为**credentials: dict[str, any]**,Provider层的credentials会直接被传递到这一层,不需要再单独定义。
|
- 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**,Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。
|
||||||
|
|
||||||
## 下一步
|
## 下一步
|
||||||
|
|
||||||
|
@ -81,7 +81,7 @@ Model Runtime 分三层:
|
||||||

|

|
||||||
|
|
||||||
### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#增加模型)
|
### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#增加模型)
|
||||||
当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如GPT-3.5 GPT-4 ChatGLM3-6b等,而对于支持自定义模型的供应商,则不需要新增模型。
|
当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如 GPT-3.5 GPT-4 ChatGLM3-6b 等,而对于支持自定义模型的供应商,则不需要新增模型。
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
|
|
|
@ -58,7 +58,7 @@ class Callback(ABC):
|
||||||
chunk: LLMResultChunk,
|
chunk: LLMResultChunk,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict,
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: Sequence[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[Sequence[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
|
@ -88,7 +88,7 @@ class Callback(ABC):
|
||||||
result: LLMResult,
|
result: LLMResult,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict,
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: Sequence[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[Sequence[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
|
|
|
@ -74,7 +74,7 @@ class LoggingCallback(Callback):
|
||||||
chunk: LLMResultChunk,
|
chunk: LLMResultChunk,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict,
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: Sequence[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[Sequence[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
|
@ -104,7 +104,7 @@ class LoggingCallback(Callback):
|
||||||
result: LLMResult,
|
result: LLMResult,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict,
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: Sequence[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[Sequence[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
|
|
|
@ -102,12 +102,12 @@ provider_credential_schema:
|
||||||
```yaml
|
```yaml
|
||||||
- variable: server_url
|
- variable: server_url
|
||||||
label:
|
label:
|
||||||
zh_Hans: 服务器URL
|
zh_Hans: 服务器 URL
|
||||||
en_US: Server url
|
en_US: Server url
|
||||||
type: text-input
|
type: text-input
|
||||||
required: true
|
required: true
|
||||||
placeholder:
|
placeholder:
|
||||||
zh_Hans: 在此输入Xinference的服务器地址,如 https://example.com/xxx
|
zh_Hans: 在此输入 Xinference 的服务器地址,如 https://example.com/xxx
|
||||||
en_US: Enter the url of your Xinference, for example https://example.com/xxx
|
en_US: Enter the url of your Xinference, for example https://example.com/xxx
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -116,12 +116,12 @@ provider_credential_schema:
|
||||||
```yaml
|
```yaml
|
||||||
- variable: model_uid
|
- variable: model_uid
|
||||||
label:
|
label:
|
||||||
zh_Hans: 模型UID
|
zh_Hans: 模型 UID
|
||||||
en_US: Model uid
|
en_US: Model uid
|
||||||
type: text-input
|
type: text-input
|
||||||
required: true
|
required: true
|
||||||
placeholder:
|
placeholder:
|
||||||
zh_Hans: 在此输入您的Model UID
|
zh_Hans: 在此输入您的 Model UID
|
||||||
en_US: Enter the model uid
|
en_US: Enter the model uid
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -367,7 +367,7 @@ Inherit the `__base.text2speech_model.Text2SpeechModel` base class and implement
|
||||||
|
|
||||||
- Returns:
|
- Returns:
|
||||||
|
|
||||||
Text converted speech stream。
|
Text converted speech stream.
|
||||||
|
|
||||||
### Moderation
|
### Moderation
|
||||||
|
|
||||||
|
|
|
@ -6,14 +6,14 @@
|
||||||
|
|
||||||
需要注意的是,对于自定义模型,每一个模型的接入都需要填写一个完整的供应商凭据。
|
需要注意的是,对于自定义模型,每一个模型的接入都需要填写一个完整的供应商凭据。
|
||||||
|
|
||||||
而不同于预定义模型,自定义供应商接入时永远会拥有如下两个参数,不需要在供应商yaml中定义。
|
而不同于预定义模型,自定义供应商接入时永远会拥有如下两个参数,不需要在供应商 yaml 中定义。
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
|
|
||||||
在前文中,我们已经知道了供应商无需实现`validate_provider_credential`,Runtime会自行根据用户在此选择的模型类型和模型名称调用对应的模型层的`validate_credentials`来进行验证。
|
在前文中,我们已经知道了供应商无需实现`validate_provider_credential`,Runtime 会自行根据用户在此选择的模型类型和模型名称调用对应的模型层的`validate_credentials`来进行验证。
|
||||||
|
|
||||||
### 编写供应商yaml
|
### 编写供应商 yaml
|
||||||
|
|
||||||
我们首先要确定,接入的这个供应商支持哪些类型的模型。
|
我们首先要确定,接入的这个供应商支持哪些类型的模型。
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@
|
||||||
- `tts` 文字转语音
|
- `tts` 文字转语音
|
||||||
- `moderation` 审查
|
- `moderation` 审查
|
||||||
|
|
||||||
`Xinference`支持`LLM`和`Text Embedding`和Rerank,那么我们开始编写`xinference.yaml`。
|
`Xinference`支持`LLM`和`Text Embedding`和 Rerank,那么我们开始编写`xinference.yaml`。
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
provider: xinference #确定供应商标识
|
provider: xinference #确定供应商标识
|
||||||
|
@ -42,17 +42,17 @@ help: # 帮助
|
||||||
zh_Hans: 如何部署 Xinference
|
zh_Hans: 如何部署 Xinference
|
||||||
url:
|
url:
|
||||||
en_US: https://github.com/xorbitsai/inference
|
en_US: https://github.com/xorbitsai/inference
|
||||||
supported_model_types: # 支持的模型类型,Xinference同时支持LLM/Text Embedding/Rerank
|
supported_model_types: # 支持的模型类型,Xinference 同时支持 LLM/Text Embedding/Rerank
|
||||||
- llm
|
- llm
|
||||||
- text-embedding
|
- text-embedding
|
||||||
- rerank
|
- rerank
|
||||||
configurate_methods: # 因为Xinference为本地部署的供应商,并且没有预定义模型,需要用什么模型需要根据Xinference的文档自己部署,所以这里只支持自定义模型
|
configurate_methods: # 因为 Xinference 为本地部署的供应商,并且没有预定义模型,需要用什么模型需要根据 Xinference 的文档自己部署,所以这里只支持自定义模型
|
||||||
- customizable-model
|
- customizable-model
|
||||||
provider_credential_schema:
|
provider_credential_schema:
|
||||||
credential_form_schemas:
|
credential_form_schemas:
|
||||||
```
|
```
|
||||||
|
|
||||||
随后,我们需要思考在Xinference中定义一个模型需要哪些凭据
|
随后,我们需要思考在 Xinference 中定义一个模型需要哪些凭据
|
||||||
|
|
||||||
- 它支持三种不同的模型,因此,我们需要有`model_type`来指定这个模型的类型,它有三种类型,所以我们这么编写
|
- 它支持三种不同的模型,因此,我们需要有`model_type`来指定这个模型的类型,它有三种类型,所以我们这么编写
|
||||||
```yaml
|
```yaml
|
||||||
|
@ -88,28 +88,28 @@ provider_credential_schema:
|
||||||
zh_Hans: 填写模型名称
|
zh_Hans: 填写模型名称
|
||||||
en_US: Input model name
|
en_US: Input model name
|
||||||
```
|
```
|
||||||
- 填写Xinference本地部署的地址
|
- 填写 Xinference 本地部署的地址
|
||||||
```yaml
|
```yaml
|
||||||
- variable: server_url
|
- variable: server_url
|
||||||
label:
|
label:
|
||||||
zh_Hans: 服务器URL
|
zh_Hans: 服务器 URL
|
||||||
en_US: Server url
|
en_US: Server url
|
||||||
type: text-input
|
type: text-input
|
||||||
required: true
|
required: true
|
||||||
placeholder:
|
placeholder:
|
||||||
zh_Hans: 在此输入Xinference的服务器地址,如 https://example.com/xxx
|
zh_Hans: 在此输入 Xinference 的服务器地址,如 https://example.com/xxx
|
||||||
en_US: Enter the url of your Xinference, for example https://example.com/xxx
|
en_US: Enter the url of your Xinference, for example https://example.com/xxx
|
||||||
```
|
```
|
||||||
- 每个模型都有唯一的model_uid,因此需要在这里定义
|
- 每个模型都有唯一的 model_uid,因此需要在这里定义
|
||||||
```yaml
|
```yaml
|
||||||
- variable: model_uid
|
- variable: model_uid
|
||||||
label:
|
label:
|
||||||
zh_Hans: 模型UID
|
zh_Hans: 模型 UID
|
||||||
en_US: Model uid
|
en_US: Model uid
|
||||||
type: text-input
|
type: text-input
|
||||||
required: true
|
required: true
|
||||||
placeholder:
|
placeholder:
|
||||||
zh_Hans: 在此输入您的Model UID
|
zh_Hans: 在此输入您的 Model UID
|
||||||
en_US: Enter the model uid
|
en_US: Enter the model uid
|
||||||
```
|
```
|
||||||
现在,我们就完成了供应商的基础定义。
|
现在,我们就完成了供应商的基础定义。
|
||||||
|
@ -145,7 +145,7 @@ provider_credential_schema:
|
||||||
"""
|
"""
|
||||||
```
|
```
|
||||||
|
|
||||||
在实现时,需要注意使用两个函数来返回数据,分别用于处理同步返回和流式返回,因为Python会将函数中包含 `yield` 关键字的函数识别为生成器函数,返回的数据类型固定为 `Generator`,因此同步和流式返回需要分别实现,就像下面这样(注意下面例子使用了简化参数,实际实现时需要按照上面的参数列表进行实现):
|
在实现时,需要注意使用两个函数来返回数据,分别用于处理同步返回和流式返回,因为 Python 会将函数中包含 `yield` 关键字的函数识别为生成器函数,返回的数据类型固定为 `Generator`,因此同步和流式返回需要分别实现,就像下面这样(注意下面例子使用了简化参数,实际实现时需要按照上面的参数列表进行实现):
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def _invoke(self, stream: bool, **kwargs) \
|
def _invoke(self, stream: bool, **kwargs) \
|
||||||
|
@ -179,7 +179,7 @@ provider_credential_schema:
|
||||||
"""
|
"""
|
||||||
```
|
```
|
||||||
|
|
||||||
有时候,也许你不需要直接返回0,所以你可以使用`self._get_num_tokens_by_gpt2(text: str)`来获取预计算的tokens,并确保环境变量`PLUGIN_BASED_TOKEN_COUNTING_ENABLED`设置为`true`,这个方法位于`AIModel`基类中,它会使用GPT2的Tokenizer进行计算,但是只能作为替代方法,并不完全准确。
|
有时候,也许你不需要直接返回 0,所以你可以使用`self._get_num_tokens_by_gpt2(text: str)`来获取预计算的 tokens,并确保环境变量`PLUGIN_BASED_TOKEN_COUNTING_ENABLED`设置为`true`,这个方法位于`AIModel`基类中,它会使用 GPT2 的 Tokenizer 进行计算,但是只能作为替代方法,并不完全准确。
|
||||||
|
|
||||||
- 模型凭据校验
|
- 模型凭据校验
|
||||||
|
|
||||||
|
@ -196,13 +196,13 @@ provider_credential_schema:
|
||||||
"""
|
"""
|
||||||
```
|
```
|
||||||
|
|
||||||
- 模型参数Schema
|
- 模型参数 Schema
|
||||||
|
|
||||||
与自定义类型不同,由于没有在yaml文件中定义一个模型支持哪些参数,因此,我们需要动态时间模型参数的Schema。
|
与自定义类型不同,由于没有在 yaml 文件中定义一个模型支持哪些参数,因此,我们需要动态时间模型参数的 Schema。
|
||||||
|
|
||||||
如Xinference支持`max_tokens` `temperature` `top_p` 这三个模型参数。
|
如 Xinference 支持`max_tokens` `temperature` `top_p` 这三个模型参数。
|
||||||
|
|
||||||
但是有的供应商根据不同的模型支持不同的参数,如供应商`OpenLLM`支持`top_k`,但是并不是这个供应商提供的所有模型都支持`top_k`,我们这里举例A模型支持`top_k`,B模型不支持`top_k`,那么我们需要在这里动态生成模型参数的Schema,如下所示:
|
但是有的供应商根据不同的模型支持不同的参数,如供应商`OpenLLM`支持`top_k`,但是并不是这个供应商提供的所有模型都支持`top_k`,我们这里举例 A 模型支持`top_k`,B 模型不支持`top_k`,那么我们需要在这里动态生成模型参数的 Schema,如下所示:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||||
|
|
|
@ -687,7 +687,7 @@ class LLMUsage(ModelUsage):
|
||||||
total_tokens: int # 总使用 token 数
|
total_tokens: int # 总使用 token 数
|
||||||
total_price: Decimal # 总费用
|
total_price: Decimal # 总费用
|
||||||
currency: str # 货币单位
|
currency: str # 货币单位
|
||||||
latency: float # 请求耗时(s)
|
latency: float # 请求耗时 (s)
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
@ -717,7 +717,7 @@ class EmbeddingUsage(ModelUsage):
|
||||||
price_unit: Decimal # 价格单位,即单价基于多少 tokens
|
price_unit: Decimal # 价格单位,即单价基于多少 tokens
|
||||||
total_price: Decimal # 总费用
|
total_price: Decimal # 总费用
|
||||||
currency: str # 货币单位
|
currency: str # 货币单位
|
||||||
latency: float # 请求耗时(s)
|
latency: float # 请求耗时 (s)
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
|
@ -95,7 +95,7 @@ pricing: # 价格信息
|
||||||
"""
|
"""
|
||||||
```
|
```
|
||||||
|
|
||||||
在实现时,需要注意使用两个函数来返回数据,分别用于处理同步返回和流式返回,因为Python会将函数中包含 `yield` 关键字的函数识别为生成器函数,返回的数据类型固定为 `Generator`,因此同步和流式返回需要分别实现,就像下面这样(注意下面例子使用了简化参数,实际实现时需要按照上面的参数列表进行实现):
|
在实现时,需要注意使用两个函数来返回数据,分别用于处理同步返回和流式返回,因为 Python 会将函数中包含 `yield` 关键字的函数识别为生成器函数,返回的数据类型固定为 `Generator`,因此同步和流式返回需要分别实现,就像下面这样(注意下面例子使用了简化参数,实际实现时需要按照上面的参数列表进行实现):
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def _invoke(self, stream: bool, **kwargs) \
|
def _invoke(self, stream: bool, **kwargs) \
|
||||||
|
|
|
@ -8,13 +8,13 @@
|
||||||
|
|
||||||
- `customizable-model` 自定义模型
|
- `customizable-model` 自定义模型
|
||||||
|
|
||||||
用户需要新增每个模型的凭据配置,如Xinference,它同时支持 LLM 和 Text Embedding,但是每个模型都有唯一的**model_uid**,如果想要将两者同时接入,就需要为每个模型配置一个**model_uid**。
|
用户需要新增每个模型的凭据配置,如 Xinference,它同时支持 LLM 和 Text Embedding,但是每个模型都有唯一的**model_uid**,如果想要将两者同时接入,就需要为每个模型配置一个**model_uid**。
|
||||||
|
|
||||||
- `fetch-from-remote` 从远程获取
|
- `fetch-from-remote` 从远程获取
|
||||||
|
|
||||||
与 `predefined-model` 配置方式一致,只需要配置统一的供应商凭据即可,模型通过凭据信息从供应商获取。
|
与 `predefined-model` 配置方式一致,只需要配置统一的供应商凭据即可,模型通过凭据信息从供应商获取。
|
||||||
|
|
||||||
如OpenAI,我们可以基于gpt-turbo-3.5来Fine Tune多个模型,而他们都位于同一个**api_key**下,当配置为 `fetch-from-remote` 时,开发者只需要配置统一的**api_key**即可让DifyRuntime获取到开发者所有的微调模型并接入Dify。
|
如 OpenAI,我们可以基于 gpt-turbo-3.5 来 Fine Tune 多个模型,而他们都位于同一个**api_key**下,当配置为 `fetch-from-remote` 时,开发者只需要配置统一的**api_key**即可让 DifyRuntime 获取到开发者所有的微调模型并接入 Dify。
|
||||||
|
|
||||||
这三种配置方式**支持共存**,即存在供应商支持 `predefined-model` + `customizable-model` 或 `predefined-model` + `fetch-from-remote` 等,也就是配置了供应商统一凭据可以使用预定义模型和从远程获取的模型,若新增了模型,则可以在此基础上额外使用自定义的模型。
|
这三种配置方式**支持共存**,即存在供应商支持 `predefined-model` + `customizable-model` 或 `predefined-model` + `fetch-from-remote` 等,也就是配置了供应商统一凭据可以使用预定义模型和从远程获取的模型,若新增了模型,则可以在此基础上额外使用自定义的模型。
|
||||||
|
|
||||||
|
@ -23,16 +23,16 @@
|
||||||
### 介绍
|
### 介绍
|
||||||
|
|
||||||
#### 名词解释
|
#### 名词解释
|
||||||
- `module`: 一个`module`即为一个Python Package,或者通俗一点,称为一个文件夹,里面包含了一个`__init__.py`文件,以及其他的`.py`文件。
|
- `module`: 一个`module`即为一个 Python Package,或者通俗一点,称为一个文件夹,里面包含了一个`__init__.py`文件,以及其他的`.py`文件。
|
||||||
|
|
||||||
#### 步骤
|
#### 步骤
|
||||||
新增一个供应商主要分为几步,这里简单列出,帮助大家有一个大概的认识,具体的步骤会在下面详细介绍。
|
新增一个供应商主要分为几步,这里简单列出,帮助大家有一个大概的认识,具体的步骤会在下面详细介绍。
|
||||||
|
|
||||||
- 创建供应商yaml文件,根据[ProviderSchema](./schema.md#provider)编写
|
- 创建供应商 yaml 文件,根据[ProviderSchema](./schema.md#provider)编写
|
||||||
- 创建供应商代码,实现一个`class`。
|
- 创建供应商代码,实现一个`class`。
|
||||||
- 根据模型类型,在供应商`module`下创建对应的模型类型 `module`,如`llm`或`text_embedding`。
|
- 根据模型类型,在供应商`module`下创建对应的模型类型 `module`,如`llm`或`text_embedding`。
|
||||||
- 根据模型类型,在对应的模型`module`下创建同名的代码文件,如`llm.py`,并实现一个`class`。
|
- 根据模型类型,在对应的模型`module`下创建同名的代码文件,如`llm.py`,并实现一个`class`。
|
||||||
- 如果有预定义模型,根据模型名称创建同名的yaml文件在模型`module`下,如`claude-2.1.yaml`,根据[AIModelEntity](./schema.md#aimodelentity)编写。
|
- 如果有预定义模型,根据模型名称创建同名的 yaml 文件在模型`module`下,如`claude-2.1.yaml`,根据[AIModelEntity](./schema.md#aimodelentity)编写。
|
||||||
- 编写测试代码,确保功能可用。
|
- 编写测试代码,确保功能可用。
|
||||||
|
|
||||||
### 开始吧
|
### 开始吧
|
||||||
|
@ -121,11 +121,11 @@ model_credential_schema:
|
||||||
|
|
||||||
#### 实现供应商代码
|
#### 实现供应商代码
|
||||||
|
|
||||||
我们需要在`model_providers`下创建一个同名的python文件,如`anthropic.py`,并实现一个`class`,继承`__base.provider.Provider`基类,如`AnthropicProvider`。
|
我们需要在`model_providers`下创建一个同名的 python 文件,如`anthropic.py`,并实现一个`class`,继承`__base.provider.Provider`基类,如`AnthropicProvider`。
|
||||||
|
|
||||||
##### 自定义模型供应商
|
##### 自定义模型供应商
|
||||||
|
|
||||||
当供应商为Xinference等自定义模型供应商时,可跳过该步骤,仅创建一个空的`XinferenceProvider`类即可,并实现一个空的`validate_provider_credentials`方法,该方法并不会被实际使用,仅用作避免抽象类无法实例化。
|
当供应商为 Xinference 等自定义模型供应商时,可跳过该步骤,仅创建一个空的`XinferenceProvider`类即可,并实现一个空的`validate_provider_credentials`方法,该方法并不会被实际使用,仅用作避免抽象类无法实例化。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class XinferenceProvider(Provider):
|
class XinferenceProvider(Provider):
|
||||||
|
@ -155,7 +155,7 @@ def validate_provider_credentials(self, credentials: dict) -> None:
|
||||||
#### 增加模型
|
#### 增加模型
|
||||||
|
|
||||||
#### [增加预定义模型 👈🏻](./predefined_model_scale_out.md)
|
#### [增加预定义模型 👈🏻](./predefined_model_scale_out.md)
|
||||||
对于预定义模型,我们可以通过简单定义一个yaml,并通过实现调用代码来接入。
|
对于预定义模型,我们可以通过简单定义一个 yaml,并通过实现调用代码来接入。
|
||||||
|
|
||||||
#### [增加自定义模型 👈🏻](./customizable_model_scale_out.md)
|
#### [增加自定义模型 👈🏻](./customizable_model_scale_out.md)
|
||||||
对于自定义模型,我们只需要实现调用代码即可接入,但是它需要处理的参数可能会更加复杂。
|
对于自定义模型,我们只需要实现调用代码即可接入,但是它需要处理的参数可能会更加复杂。
|
||||||
|
|
|
@ -29,7 +29,7 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
||||||
"help": {
|
"help": {
|
||||||
"en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options"
|
"en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options"
|
||||||
" are considered.",
|
" are considered.",
|
||||||
"zh_Hans": "通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。",
|
"zh_Hans": "通过核心采样控制多样性:0.5 表示考虑了一半的所有可能性加权选项。",
|
||||||
},
|
},
|
||||||
"required": False,
|
"required": False,
|
||||||
"default": 1.0,
|
"default": 1.0,
|
||||||
|
@ -111,7 +111,7 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
||||||
"help": {
|
"help": {
|
||||||
"en_US": "Set a response format, ensure the output from llm is a valid code block as possible,"
|
"en_US": "Set a response format, ensure the output from llm is a valid code block as possible,"
|
||||||
" such as JSON, XML, etc.",
|
" such as JSON, XML, etc.",
|
||||||
"zh_Hans": "设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等",
|
"zh_Hans": "设置一个返回格式,确保 llm 的输出尽可能是有效的代码块,如 JSON、XML 等",
|
||||||
},
|
},
|
||||||
"required": False,
|
"required": False,
|
||||||
"options": ["JSON", "XML"],
|
"options": ["JSON", "XML"],
|
||||||
|
@ -123,7 +123,7 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"help": {
|
"help": {
|
||||||
"en_US": "Set a response json schema will ensure LLM to adhere it.",
|
"en_US": "Set a response json schema will ensure LLM to adhere it.",
|
||||||
"zh_Hans": "设置返回的json schema,llm将按照它返回",
|
"zh_Hans": "设置返回的 json schema,llm 将按照它返回",
|
||||||
},
|
},
|
||||||
"required": False,
|
"required": False,
|
||||||
},
|
},
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
|
from collections.abc import Sequence
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
|
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
|
||||||
from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo
|
from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo
|
||||||
|
@ -107,7 +108,7 @@ class LLMResult(BaseModel):
|
||||||
|
|
||||||
id: Optional[str] = None
|
id: Optional[str] = None
|
||||||
model: str
|
model: str
|
||||||
prompt_messages: list[PromptMessage]
|
prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
|
||||||
message: AssistantPromptMessage
|
message: AssistantPromptMessage
|
||||||
usage: LLMUsage
|
usage: LLMUsage
|
||||||
system_fingerprint: Optional[str] = None
|
system_fingerprint: Optional[str] = None
|
||||||
|
@ -130,7 +131,7 @@ class LLMResultChunk(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model: str
|
model: str
|
||||||
prompt_messages: list[PromptMessage]
|
prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
|
||||||
system_fingerprint: Optional[str] = None
|
system_fingerprint: Optional[str] = None
|
||||||
delta: LLMResultChunkDelta
|
delta: LLMResultChunkDelta
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from collections.abc import Generator, Sequence
|
from collections.abc import Generator, Sequence
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
@ -24,6 +25,58 @@ from core.plugin.manager.model import PluginModelManager
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_tool_call_id() -> str:
|
||||||
|
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
|
||||||
|
|
||||||
|
|
||||||
|
def _increase_tool_call(
|
||||||
|
new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Merge incremental tool call updates into existing tool calls.
|
||||||
|
|
||||||
|
:param new_tool_calls: List of new tool call deltas to be merged.
|
||||||
|
:param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_tool_call(tool_call_id: str):
|
||||||
|
"""
|
||||||
|
Get or create a tool call by ID
|
||||||
|
|
||||||
|
:param tool_call_id: tool call ID
|
||||||
|
:return: existing or new tool call
|
||||||
|
"""
|
||||||
|
if not tool_call_id:
|
||||||
|
return existing_tools_calls[-1]
|
||||||
|
|
||||||
|
_tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None)
|
||||||
|
if _tool_call is None:
|
||||||
|
_tool_call = AssistantPromptMessage.ToolCall(
|
||||||
|
id=tool_call_id,
|
||||||
|
type="function",
|
||||||
|
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
|
||||||
|
)
|
||||||
|
existing_tools_calls.append(_tool_call)
|
||||||
|
|
||||||
|
return _tool_call
|
||||||
|
|
||||||
|
for new_tool_call in new_tool_calls:
|
||||||
|
# generate ID for tool calls with function name but no ID to track them
|
||||||
|
if new_tool_call.function.name and not new_tool_call.id:
|
||||||
|
new_tool_call.id = _gen_tool_call_id()
|
||||||
|
# get tool call
|
||||||
|
tool_call = get_tool_call(new_tool_call.id)
|
||||||
|
# update tool call
|
||||||
|
if new_tool_call.id:
|
||||||
|
tool_call.id = new_tool_call.id
|
||||||
|
if new_tool_call.type:
|
||||||
|
tool_call.type = new_tool_call.type
|
||||||
|
if new_tool_call.function.name:
|
||||||
|
tool_call.function.name = new_tool_call.function.name
|
||||||
|
if new_tool_call.function.arguments:
|
||||||
|
tool_call.function.arguments += new_tool_call.function.arguments
|
||||||
|
|
||||||
|
|
||||||
class LargeLanguageModel(AIModel):
|
class LargeLanguageModel(AIModel):
|
||||||
"""
|
"""
|
||||||
Model class for large language model.
|
Model class for large language model.
|
||||||
|
@ -45,7 +98,7 @@ class LargeLanguageModel(AIModel):
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
) -> Union[LLMResult, Generator]:
|
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
|
|
||||||
|
@ -109,44 +162,13 @@ class LargeLanguageModel(AIModel):
|
||||||
system_fingerprint = None
|
system_fingerprint = None
|
||||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||||
|
|
||||||
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
|
|
||||||
def get_tool_call(tool_name: str):
|
|
||||||
if not tool_name:
|
|
||||||
return tools_calls[-1]
|
|
||||||
|
|
||||||
tool_call = next(
|
|
||||||
(tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None
|
|
||||||
)
|
|
||||||
if tool_call is None:
|
|
||||||
tool_call = AssistantPromptMessage.ToolCall(
|
|
||||||
id="",
|
|
||||||
type="",
|
|
||||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
|
|
||||||
)
|
|
||||||
tools_calls.append(tool_call)
|
|
||||||
|
|
||||||
return tool_call
|
|
||||||
|
|
||||||
for new_tool_call in new_tool_calls:
|
|
||||||
# get tool call
|
|
||||||
tool_call = get_tool_call(new_tool_call.function.name)
|
|
||||||
# update tool call
|
|
||||||
if new_tool_call.id:
|
|
||||||
tool_call.id = new_tool_call.id
|
|
||||||
if new_tool_call.type:
|
|
||||||
tool_call.type = new_tool_call.type
|
|
||||||
if new_tool_call.function.name:
|
|
||||||
tool_call.function.name = new_tool_call.function.name
|
|
||||||
if new_tool_call.function.arguments:
|
|
||||||
tool_call.function.arguments += new_tool_call.function.arguments
|
|
||||||
|
|
||||||
for chunk in result:
|
for chunk in result:
|
||||||
if isinstance(chunk.delta.message.content, str):
|
if isinstance(chunk.delta.message.content, str):
|
||||||
content += chunk.delta.message.content
|
content += chunk.delta.message.content
|
||||||
elif isinstance(chunk.delta.message.content, list):
|
elif isinstance(chunk.delta.message.content, list):
|
||||||
content_list.extend(chunk.delta.message.content)
|
content_list.extend(chunk.delta.message.content)
|
||||||
if chunk.delta.message.tool_calls:
|
if chunk.delta.message.tool_calls:
|
||||||
increase_tool_call(chunk.delta.message.tool_calls)
|
_increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
|
||||||
|
|
||||||
usage = chunk.delta.usage or LLMUsage.empty_usage()
|
usage = chunk.delta.usage or LLMUsage.empty_usage()
|
||||||
system_fingerprint = chunk.system_fingerprint
|
system_fingerprint = chunk.system_fingerprint
|
||||||
|
@ -205,22 +227,26 @@ class LargeLanguageModel(AIModel):
|
||||||
user=user,
|
user=user,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
|
# Following https://github.com/langgenius/dify/issues/17799,
|
||||||
return result
|
# we removed the prompt_messages from the chunk on the plugin daemon side.
|
||||||
|
# To ensure compatibility, we add the prompt_messages back here.
|
||||||
|
result.prompt_messages = prompt_messages
|
||||||
|
return result
|
||||||
|
raise NotImplementedError("unsupported invoke result type", type(result))
|
||||||
|
|
||||||
def _invoke_result_generator(
|
def _invoke_result_generator(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
result: Generator,
|
result: Generator,
|
||||||
credentials: dict,
|
credentials: dict,
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: Sequence[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[Sequence[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
) -> Generator:
|
) -> Generator[LLMResultChunk, None, None]:
|
||||||
"""
|
"""
|
||||||
Invoke result generator
|
Invoke result generator
|
||||||
|
|
||||||
|
@ -235,6 +261,10 @@ class LargeLanguageModel(AIModel):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for chunk in result:
|
for chunk in result:
|
||||||
|
# Following https://github.com/langgenius/dify/issues/17799,
|
||||||
|
# we removed the prompt_messages from the chunk on the plugin daemon side.
|
||||||
|
# To ensure compatibility, we add the prompt_messages back here.
|
||||||
|
chunk.prompt_messages = prompt_messages
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
self._trigger_new_chunk_callbacks(
|
self._trigger_new_chunk_callbacks(
|
||||||
|
@ -403,7 +433,7 @@ class LargeLanguageModel(AIModel):
|
||||||
chunk: LLMResultChunk,
|
chunk: LLMResultChunk,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict,
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: Sequence[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[Sequence[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
|
@ -450,7 +480,7 @@ class LargeLanguageModel(AIModel):
|
||||||
model: str,
|
model: str,
|
||||||
result: LLMResult,
|
result: LLMResult,
|
||||||
credentials: dict,
|
credentials: dict,
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: Sequence[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[Sequence[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
|
|
|
@ -5,6 +5,7 @@ from datetime import datetime, timedelta
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from langfuse import Langfuse # type: ignore
|
from langfuse import Langfuse # type: ignore
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from core.ops.base_trace_instance import BaseTraceInstance
|
from core.ops.base_trace_instance import BaseTraceInstance
|
||||||
from core.ops.entities.config_entity import LangfuseConfig
|
from core.ops.entities.config_entity import LangfuseConfig
|
||||||
|
@ -28,9 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
|
||||||
UnitEnum,
|
UnitEnum,
|
||||||
)
|
)
|
||||||
from core.ops.utils import filter_none_values
|
from core.ops.utils import filter_none_values
|
||||||
|
from core.repository.repository_factory import RepositoryFactory
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
from models.workflow import WorkflowNodeExecution
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -110,36 +111,18 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||||
)
|
)
|
||||||
self.add_trace(langfuse_trace_data=trace_data)
|
self.add_trace(langfuse_trace_data=trace_data)
|
||||||
|
|
||||||
# through workflow_run_id get all_nodes_execution
|
# through workflow_run_id get all_nodes_execution using repository
|
||||||
workflow_nodes_execution_id_records = (
|
session_factory = sessionmaker(bind=db.engine)
|
||||||
db.session.query(WorkflowNodeExecution.id)
|
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
|
||||||
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
|
params={"tenant_id": trace_info.tenant_id, "session_factory": session_factory},
|
||||||
.all()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for node_execution_id_record in workflow_nodes_execution_id_records:
|
# Get all executions for this workflow run
|
||||||
node_execution = (
|
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
|
||||||
db.session.query(
|
workflow_run_id=trace_info.workflow_run_id
|
||||||
WorkflowNodeExecution.id,
|
)
|
||||||
WorkflowNodeExecution.tenant_id,
|
|
||||||
WorkflowNodeExecution.app_id,
|
|
||||||
WorkflowNodeExecution.title,
|
|
||||||
WorkflowNodeExecution.node_type,
|
|
||||||
WorkflowNodeExecution.status,
|
|
||||||
WorkflowNodeExecution.inputs,
|
|
||||||
WorkflowNodeExecution.outputs,
|
|
||||||
WorkflowNodeExecution.created_at,
|
|
||||||
WorkflowNodeExecution.elapsed_time,
|
|
||||||
WorkflowNodeExecution.process_data,
|
|
||||||
WorkflowNodeExecution.execution_metadata,
|
|
||||||
)
|
|
||||||
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not node_execution:
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
for node_execution in workflow_node_executions:
|
||||||
node_execution_id = node_execution.id
|
node_execution_id = node_execution.id
|
||||||
tenant_id = node_execution.tenant_id
|
tenant_id = node_execution.tenant_id
|
||||||
app_id = node_execution.app_id
|
app_id = node_execution.app_id
|
||||||
|
|
|
@ -7,6 +7,7 @@ from typing import Optional, cast
|
||||||
|
|
||||||
from langsmith import Client
|
from langsmith import Client
|
||||||
from langsmith.schemas import RunBase
|
from langsmith.schemas import RunBase
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from core.ops.base_trace_instance import BaseTraceInstance
|
from core.ops.base_trace_instance import BaseTraceInstance
|
||||||
from core.ops.entities.config_entity import LangSmithConfig
|
from core.ops.entities.config_entity import LangSmithConfig
|
||||||
|
@ -27,9 +28,9 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
||||||
LangSmithRunUpdateModel,
|
LangSmithRunUpdateModel,
|
||||||
)
|
)
|
||||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||||
|
from core.repository.repository_factory import RepositoryFactory
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import EndUser, MessageFile
|
from models.model import EndUser, MessageFile
|
||||||
from models.workflow import WorkflowNodeExecution
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -134,36 +135,22 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||||
|
|
||||||
self.add_run(langsmith_run)
|
self.add_run(langsmith_run)
|
||||||
|
|
||||||
# through workflow_run_id get all_nodes_execution
|
# through workflow_run_id get all_nodes_execution using repository
|
||||||
workflow_nodes_execution_id_records = (
|
session_factory = sessionmaker(bind=db.engine)
|
||||||
db.session.query(WorkflowNodeExecution.id)
|
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
|
||||||
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
|
params={
|
||||||
.all()
|
"tenant_id": trace_info.tenant_id,
|
||||||
|
"app_id": trace_info.metadata.get("app_id"),
|
||||||
|
"session_factory": session_factory,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
for node_execution_id_record in workflow_nodes_execution_id_records:
|
# Get all executions for this workflow run
|
||||||
node_execution = (
|
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
|
||||||
db.session.query(
|
workflow_run_id=trace_info.workflow_run_id
|
||||||
WorkflowNodeExecution.id,
|
)
|
||||||
WorkflowNodeExecution.tenant_id,
|
|
||||||
WorkflowNodeExecution.app_id,
|
|
||||||
WorkflowNodeExecution.title,
|
|
||||||
WorkflowNodeExecution.node_type,
|
|
||||||
WorkflowNodeExecution.status,
|
|
||||||
WorkflowNodeExecution.inputs,
|
|
||||||
WorkflowNodeExecution.outputs,
|
|
||||||
WorkflowNodeExecution.created_at,
|
|
||||||
WorkflowNodeExecution.elapsed_time,
|
|
||||||
WorkflowNodeExecution.process_data,
|
|
||||||
WorkflowNodeExecution.execution_metadata,
|
|
||||||
)
|
|
||||||
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not node_execution:
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
for node_execution in workflow_node_executions:
|
||||||
node_execution_id = node_execution.id
|
node_execution_id = node_execution.id
|
||||||
tenant_id = node_execution.tenant_id
|
tenant_id = node_execution.tenant_id
|
||||||
app_id = node_execution.app_id
|
app_id = node_execution.app_id
|
||||||
|
|
|
@ -7,6 +7,7 @@ from typing import Optional, cast
|
||||||
|
|
||||||
from opik import Opik, Trace
|
from opik import Opik, Trace
|
||||||
from opik.id_helpers import uuid4_to_uuid7
|
from opik.id_helpers import uuid4_to_uuid7
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from core.ops.base_trace_instance import BaseTraceInstance
|
from core.ops.base_trace_instance import BaseTraceInstance
|
||||||
from core.ops.entities.config_entity import OpikConfig
|
from core.ops.entities.config_entity import OpikConfig
|
||||||
|
@ -21,9 +22,9 @@ from core.ops.entities.trace_entity import (
|
||||||
TraceTaskName,
|
TraceTaskName,
|
||||||
WorkflowTraceInfo,
|
WorkflowTraceInfo,
|
||||||
)
|
)
|
||||||
|
from core.repository.repository_factory import RepositoryFactory
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import EndUser, MessageFile
|
from models.model import EndUser, MessageFile
|
||||||
from models.workflow import WorkflowNodeExecution
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -147,36 +148,22 @@ class OpikDataTrace(BaseTraceInstance):
|
||||||
}
|
}
|
||||||
self.add_trace(trace_data)
|
self.add_trace(trace_data)
|
||||||
|
|
||||||
# through workflow_run_id get all_nodes_execution
|
# through workflow_run_id get all_nodes_execution using repository
|
||||||
workflow_nodes_execution_id_records = (
|
session_factory = sessionmaker(bind=db.engine)
|
||||||
db.session.query(WorkflowNodeExecution.id)
|
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
|
||||||
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
|
params={
|
||||||
.all()
|
"tenant_id": trace_info.tenant_id,
|
||||||
|
"app_id": trace_info.metadata.get("app_id"),
|
||||||
|
"session_factory": session_factory,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
for node_execution_id_record in workflow_nodes_execution_id_records:
|
# Get all executions for this workflow run
|
||||||
node_execution = (
|
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
|
||||||
db.session.query(
|
workflow_run_id=trace_info.workflow_run_id
|
||||||
WorkflowNodeExecution.id,
|
)
|
||||||
WorkflowNodeExecution.tenant_id,
|
|
||||||
WorkflowNodeExecution.app_id,
|
|
||||||
WorkflowNodeExecution.title,
|
|
||||||
WorkflowNodeExecution.node_type,
|
|
||||||
WorkflowNodeExecution.status,
|
|
||||||
WorkflowNodeExecution.inputs,
|
|
||||||
WorkflowNodeExecution.outputs,
|
|
||||||
WorkflowNodeExecution.created_at,
|
|
||||||
WorkflowNodeExecution.elapsed_time,
|
|
||||||
WorkflowNodeExecution.process_data,
|
|
||||||
WorkflowNodeExecution.execution_metadata,
|
|
||||||
)
|
|
||||||
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not node_execution:
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
for node_execution in workflow_node_executions:
|
||||||
node_execution_id = node_execution.id
|
node_execution_id = node_execution.id
|
||||||
tenant_id = node_execution.tenant_id
|
tenant_id = node_execution.tenant_id
|
||||||
app_id = node_execution.app_id
|
app_id = node_execution.app_id
|
||||||
|
|
|
@ -453,7 +453,7 @@ class TraceTask:
|
||||||
"version": workflow_run_version,
|
"version": workflow_run_version,
|
||||||
"total_tokens": total_tokens,
|
"total_tokens": total_tokens,
|
||||||
"file_list": file_list,
|
"file_list": file_list,
|
||||||
"triggered_form": workflow_run.triggered_from,
|
"triggered_from": workflow_run.triggered_from,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ from collections.abc import Generator, Mapping
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from controllers.service_api.wraps import create_or_update_end_user_for_user_id
|
from controllers.service_api.wraps import create_or_update_end_user_for_user_id
|
||||||
|
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||||
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
|
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
|
||||||
from core.app.apps.chat.app_generator import ChatAppGenerator
|
from core.app.apps.chat.app_generator import ChatAppGenerator
|
||||||
|
@ -15,6 +16,34 @@ from models.model import App, AppMode, EndUser
|
||||||
|
|
||||||
|
|
||||||
class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||||
|
@classmethod
|
||||||
|
def fetch_app_info(cls, app_id: str, tenant_id: str) -> Mapping:
|
||||||
|
"""
|
||||||
|
Fetch app info
|
||||||
|
"""
|
||||||
|
app = cls._get_app(app_id, tenant_id)
|
||||||
|
|
||||||
|
"""Retrieve app parameters."""
|
||||||
|
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||||
|
workflow = app.workflow
|
||||||
|
if workflow is None:
|
||||||
|
raise ValueError("unexpected app type")
|
||||||
|
|
||||||
|
features_dict = workflow.features_dict
|
||||||
|
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||||
|
else:
|
||||||
|
app_model_config = app.app_model_config
|
||||||
|
if app_model_config is None:
|
||||||
|
raise ValueError("unexpected app type")
|
||||||
|
|
||||||
|
features_dict = app_model_config.to_dict()
|
||||||
|
|
||||||
|
user_input_form = features_dict.get("user_input_form", [])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"data": get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form),
|
||||||
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def invoke_app(
|
def invoke_app(
|
||||||
cls,
|
cls,
|
||||||
|
|
|
@ -131,7 +131,7 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
|
||||||
raise ValueError("The selector must be a dictionary.")
|
raise ValueError("The selector must be a dictionary.")
|
||||||
return value
|
return value
|
||||||
case PluginParameterType.TOOLS_SELECTOR:
|
case PluginParameterType.TOOLS_SELECTOR:
|
||||||
if not isinstance(value, list):
|
if value and not isinstance(value, list):
|
||||||
raise ValueError("The tools selector must be a list.")
|
raise ValueError("The tools selector must be a list.")
|
||||||
return value
|
return value
|
||||||
case _:
|
case _:
|
||||||
|
@ -147,7 +147,7 @@ def init_frontend_parameter(rule: PluginParameter, type: enum.StrEnum, value: An
|
||||||
init frontend parameter by rule
|
init frontend parameter by rule
|
||||||
"""
|
"""
|
||||||
parameter_value = value
|
parameter_value = value
|
||||||
if not parameter_value and parameter_value != 0 and type != PluginParameterType.TOOLS_SELECTOR:
|
if not parameter_value and parameter_value != 0:
|
||||||
# get default value
|
# get default value
|
||||||
parameter_value = rule.default
|
parameter_value = rule.default
|
||||||
if not parameter_value and rule.required:
|
if not parameter_value and rule.required:
|
||||||
|
|
|
@ -70,6 +70,9 @@ class PluginDeclaration(BaseModel):
|
||||||
models: Optional[list[str]] = Field(default_factory=list)
|
models: Optional[list[str]] = Field(default_factory=list)
|
||||||
endpoints: Optional[list[str]] = Field(default_factory=list)
|
endpoints: Optional[list[str]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
class Meta(BaseModel):
|
||||||
|
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
|
||||||
|
|
||||||
version: str = Field(..., pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
|
version: str = Field(..., pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
|
||||||
author: Optional[str] = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$")
|
author: Optional[str] = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$")
|
||||||
name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$")
|
name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$")
|
||||||
|
@ -86,6 +89,7 @@ class PluginDeclaration(BaseModel):
|
||||||
model: Optional[ProviderEntity] = None
|
model: Optional[ProviderEntity] = None
|
||||||
endpoint: Optional[EndpointProviderDeclaration] = None
|
endpoint: Optional[EndpointProviderDeclaration] = None
|
||||||
agent_strategy: Optional[AgentStrategyProviderEntity] = None
|
agent_strategy: Optional[AgentStrategyProviderEntity] = None
|
||||||
|
meta: Meta
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -204,3 +204,11 @@ class RequestRequestUploadFile(BaseModel):
|
||||||
|
|
||||||
filename: str
|
filename: str
|
||||||
mimetype: str
|
mimetype: str
|
||||||
|
|
||||||
|
|
||||||
|
class RequestFetchAppInfo(BaseModel):
|
||||||
|
"""
|
||||||
|
Request to fetch app info
|
||||||
|
"""
|
||||||
|
|
||||||
|
app_id: str
|
||||||
|
|
|
@ -82,7 +82,7 @@ class BasePluginManager:
|
||||||
Make a stream request to the plugin daemon inner API
|
Make a stream request to the plugin daemon inner API
|
||||||
"""
|
"""
|
||||||
response = self._request(method, path, headers, data, params, files, stream=True)
|
response = self._request(method, path, headers, data, params, files, stream=True)
|
||||||
for line in response.iter_lines():
|
for line in response.iter_lines(chunk_size=1024 * 8):
|
||||||
line = line.decode("utf-8").strip()
|
line = line.decode("utf-8").strip()
|
||||||
if line.startswith("data:"):
|
if line.startswith("data:"):
|
||||||
line = line[5:].strip()
|
line = line[5:].strip()
|
||||||
|
@ -168,16 +168,18 @@ class BasePluginManager:
|
||||||
Make a stream request to the plugin daemon inner API and yield the response as a model.
|
Make a stream request to the plugin daemon inner API and yield the response as a model.
|
||||||
"""
|
"""
|
||||||
for line in self._stream_request(method, path, params, headers, data, files):
|
for line in self._stream_request(method, path, params, headers, data, files):
|
||||||
line_data = None
|
|
||||||
try:
|
try:
|
||||||
line_data = json.loads(line)
|
rep = PluginDaemonBasicResponse[type].model_validate_json(line) # type: ignore
|
||||||
rep = PluginDaemonBasicResponse[type](**line_data) # type: ignore
|
except (ValueError, TypeError):
|
||||||
except Exception:
|
|
||||||
# TODO modify this when line_data has code and message
|
# TODO modify this when line_data has code and message
|
||||||
if line_data and "error" in line_data:
|
try:
|
||||||
raise ValueError(line_data["error"])
|
line_data = json.loads(line)
|
||||||
else:
|
except (ValueError, TypeError):
|
||||||
raise ValueError(line)
|
raise ValueError(line)
|
||||||
|
# If the dictionary contains the `error` key, use its value as the argument
|
||||||
|
# for `ValueError`.
|
||||||
|
# Otherwise, use the `line` to provide better contextual information about the error.
|
||||||
|
raise ValueError(line_data.get("error", line))
|
||||||
|
|
||||||
if rep.code != 0:
|
if rep.code != 0:
|
||||||
if rep.code == -500:
|
if rep.code == -500:
|
||||||
|
|
|
@ -110,7 +110,62 @@ class PluginToolManager(BasePluginManager):
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return response
|
|
||||||
|
class FileChunk:
|
||||||
|
"""
|
||||||
|
Only used for internal processing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
bytes_written: int
|
||||||
|
total_length: int
|
||||||
|
data: bytearray
|
||||||
|
|
||||||
|
def __init__(self, total_length: int):
|
||||||
|
self.bytes_written = 0
|
||||||
|
self.total_length = total_length
|
||||||
|
self.data = bytearray(total_length)
|
||||||
|
|
||||||
|
files: dict[str, FileChunk] = {}
|
||||||
|
for resp in response:
|
||||||
|
if resp.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
|
||||||
|
assert isinstance(resp.message, ToolInvokeMessage.BlobChunkMessage)
|
||||||
|
# Get blob chunk information
|
||||||
|
chunk_id = resp.message.id
|
||||||
|
total_length = resp.message.total_length
|
||||||
|
blob_data = resp.message.blob
|
||||||
|
is_end = resp.message.end
|
||||||
|
|
||||||
|
# Initialize buffer for this file if it doesn't exist
|
||||||
|
if chunk_id not in files:
|
||||||
|
files[chunk_id] = FileChunk(total_length)
|
||||||
|
|
||||||
|
# If this is the final chunk, yield a complete blob message
|
||||||
|
if is_end:
|
||||||
|
yield ToolInvokeMessage(
|
||||||
|
type=ToolInvokeMessage.MessageType.BLOB,
|
||||||
|
message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data),
|
||||||
|
meta=resp.meta,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Check if file is too large (30MB limit)
|
||||||
|
if files[chunk_id].bytes_written + len(blob_data) > 30 * 1024 * 1024:
|
||||||
|
# Delete the file if it's too large
|
||||||
|
del files[chunk_id]
|
||||||
|
# Skip yielding this message
|
||||||
|
raise ValueError("File is too large which reached the limit of 30MB")
|
||||||
|
|
||||||
|
# Check if single chunk is too large (8KB limit)
|
||||||
|
if len(blob_data) > 8192:
|
||||||
|
# Skip yielding this message
|
||||||
|
raise ValueError("File chunk is too large which reached the limit of 8KB")
|
||||||
|
|
||||||
|
# Append the blob data to the buffer
|
||||||
|
files[chunk_id].data[
|
||||||
|
files[chunk_id].bytes_written : files[chunk_id].bytes_written + len(blob_data)
|
||||||
|
] = blob_data
|
||||||
|
files[chunk_id].bytes_written += len(blob_data)
|
||||||
|
else:
|
||||||
|
yield resp
|
||||||
|
|
||||||
def validate_provider_credentials(
|
def validate_provider_credentials(
|
||||||
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
|
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
|
||||||
|
|
|
@ -28,7 +28,7 @@ BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
|
||||||
},
|
},
|
||||||
"conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"},
|
"conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"},
|
||||||
},
|
},
|
||||||
"stop": ["用户:"],
|
"stop": ["用户:"],
|
||||||
}
|
}
|
||||||
|
|
||||||
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
|
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
|
||||||
|
@ -41,5 +41,5 @@ BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = {
|
||||||
|
|
||||||
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
|
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
|
||||||
"completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}},
|
"completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}},
|
||||||
"stop": ["用户:"],
|
"stop": ["用户:"],
|
||||||
}
|
}
|
||||||
|
|
|
@ -124,6 +124,15 @@ class ProviderManager:
|
||||||
|
|
||||||
# Get All preferred provider types of the workspace
|
# Get All preferred provider types of the workspace
|
||||||
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
|
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
|
||||||
|
# Ensure that both the original provider name and its ModelProviderID string representation
|
||||||
|
# are present in the dictionary to handle cases where either form might be used
|
||||||
|
for provider_name in list(provider_name_to_preferred_model_provider_records_dict.keys()):
|
||||||
|
provider_id = ModelProviderID(provider_name)
|
||||||
|
if str(provider_id) not in provider_name_to_preferred_model_provider_records_dict:
|
||||||
|
# Add the ModelProviderID string representation if it's not already present
|
||||||
|
provider_name_to_preferred_model_provider_records_dict[str(provider_id)] = (
|
||||||
|
provider_name_to_preferred_model_provider_records_dict[provider_name]
|
||||||
|
)
|
||||||
|
|
||||||
# Get All provider model settings
|
# Get All provider model settings
|
||||||
provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)
|
provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)
|
||||||
|
@ -497,8 +506,8 @@ class ProviderManager:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _init_trial_provider_records(
|
def _init_trial_provider_records(
|
||||||
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list]
|
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
|
||||||
) -> dict[str, list]:
|
) -> dict[str, list[Provider]]:
|
||||||
"""
|
"""
|
||||||
Initialize trial provider records if not exists.
|
Initialize trial provider records if not exists.
|
||||||
|
|
||||||
|
@ -532,7 +541,7 @@ class ProviderManager:
|
||||||
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
|
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
|
||||||
try:
|
try:
|
||||||
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
|
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
|
||||||
provider_record = Provider(
|
new_provider_record = Provider(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
# TODO: Use provider name with prefix after the data migration.
|
# TODO: Use provider name with prefix after the data migration.
|
||||||
provider_name=ModelProviderID(provider_name).provider_name,
|
provider_name=ModelProviderID(provider_name).provider_name,
|
||||||
|
@ -542,11 +551,12 @@ class ProviderManager:
|
||||||
quota_used=0,
|
quota_used=0,
|
||||||
is_valid=True,
|
is_valid=True,
|
||||||
)
|
)
|
||||||
db.session.add(provider_record)
|
db.session.add(new_provider_record)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
|
||||||
except IntegrityError:
|
except IntegrityError:
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
provider_record = (
|
existed_provider_record = (
|
||||||
db.session.query(Provider)
|
db.session.query(Provider)
|
||||||
.filter(
|
.filter(
|
||||||
Provider.tenant_id == tenant_id,
|
Provider.tenant_id == tenant_id,
|
||||||
|
@ -556,11 +566,14 @@ class ProviderManager:
|
||||||
)
|
)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if provider_record and not provider_record.is_valid:
|
if not existed_provider_record:
|
||||||
provider_record.is_valid = True
|
continue
|
||||||
|
|
||||||
|
if not existed_provider_record.is_valid:
|
||||||
|
existed_provider_record.is_valid = True
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
provider_name_to_provider_records_dict[provider_name].append(provider_record)
|
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
|
||||||
|
|
||||||
return provider_name_to_provider_records_dict
|
return provider_name_to_provider_records_dict
|
||||||
|
|
||||||
|
|
|
@ -139,13 +139,17 @@ class AnalyticdbVectorBySql:
|
||||||
)
|
)
|
||||||
if embedding_dimension is not None:
|
if embedding_dimension is not None:
|
||||||
index_name = f"{self._collection_name}_embedding_idx"
|
index_name = f"{self._collection_name}_embedding_idx"
|
||||||
cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
|
try:
|
||||||
cur.execute(
|
cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
|
||||||
f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
|
cur.execute(
|
||||||
f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
|
f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
|
||||||
f"pq_enable=0, external_storage=0)"
|
f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
|
||||||
)
|
f"pq_enable=0, external_storage=0)"
|
||||||
cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
|
)
|
||||||
|
cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
|
||||||
|
except Exception as e:
|
||||||
|
if "already exists" not in str(e):
|
||||||
|
raise e
|
||||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||||
|
|
||||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
|
@ -177,9 +181,11 @@ class AnalyticdbVectorBySql:
|
||||||
return cur.fetchone() is not None
|
return cur.fetchone() is not None
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
with self._get_cursor() as cur:
|
with self._get_cursor() as cur:
|
||||||
try:
|
try:
|
||||||
cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id IN %s", (tuple(ids),))
|
cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id = ANY(%s)", (ids,))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "does not exist" not in str(e):
|
if "does not exist" not in str(e):
|
||||||
raise e
|
raise e
|
||||||
|
@ -240,7 +246,7 @@ class AnalyticdbVectorBySql:
|
||||||
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
|
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
|
||||||
FROM {self.table_name}
|
FROM {self.table_name}
|
||||||
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
|
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC, id DESC
|
||||||
LIMIT {top_k}""",
|
LIMIT {top_k}""",
|
||||||
(f"'{query}'", f"'{query}'"),
|
(f"'{query}'", f"'{query}'"),
|
||||||
)
|
)
|
||||||
|
|
|
@ -32,6 +32,7 @@ class MilvusConfig(BaseModel):
|
||||||
batch_size: int = 100 # Batch size for operations
|
batch_size: int = 100 # Batch size for operations
|
||||||
database: str = "default" # Database name
|
database: str = "default" # Database name
|
||||||
enable_hybrid_search: bool = False # Flag to enable hybrid search
|
enable_hybrid_search: bool = False # Flag to enable hybrid search
|
||||||
|
analyzer_params: Optional[str] = None # Analyzer params
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -58,6 +59,7 @@ class MilvusConfig(BaseModel):
|
||||||
"user": self.user,
|
"user": self.user,
|
||||||
"password": self.password,
|
"password": self.password,
|
||||||
"db_name": self.database,
|
"db_name": self.database,
|
||||||
|
"analyzer_params": self.analyzer_params,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -300,14 +302,19 @@ class MilvusVector(BaseVector):
|
||||||
|
|
||||||
# Create the text field, enable_analyzer will be set True to support milvus automatically
|
# Create the text field, enable_analyzer will be set True to support milvus automatically
|
||||||
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
|
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
|
||||||
fields.append(
|
content_field_kwargs: dict[str, Any] = {
|
||||||
FieldSchema(
|
"max_length": 65_535,
|
||||||
Field.CONTENT_KEY.value,
|
"enable_analyzer": self._hybrid_search_enabled,
|
||||||
DataType.VARCHAR,
|
}
|
||||||
max_length=65_535,
|
if (
|
||||||
enable_analyzer=self._hybrid_search_enabled,
|
self._hybrid_search_enabled
|
||||||
)
|
and self._client_config.analyzer_params is not None
|
||||||
)
|
and self._client_config.analyzer_params.strip()
|
||||||
|
):
|
||||||
|
content_field_kwargs["analyzer_params"] = self._client_config.analyzer_params
|
||||||
|
|
||||||
|
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, **content_field_kwargs))
|
||||||
|
|
||||||
# Create the primary key field
|
# Create the primary key field
|
||||||
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
|
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
|
||||||
# Create the vector field, supports binary or float vectors
|
# Create the vector field, supports binary or float vectors
|
||||||
|
@ -383,5 +390,6 @@ class MilvusVectorFactory(AbstractVectorFactory):
|
||||||
password=dify_config.MILVUS_PASSWORD or "",
|
password=dify_config.MILVUS_PASSWORD or "",
|
||||||
database=dify_config.MILVUS_DATABASE or "",
|
database=dify_config.MILVUS_DATABASE or "",
|
||||||
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
|
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
|
||||||
|
analyzer_params=dify_config.MILVUS_ANALYZER_PARAMS or "",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
|
@ -228,7 +228,7 @@ class OracleVector(BaseVector):
|
||||||
|
|
||||||
top_k = kwargs.get("top_k", 5)
|
top_k = kwargs.get("top_k", 5)
|
||||||
# fetching by score_threshold is not implemented yet; it may be added later
|
# fetching by score_threshold is not implemented yet; it may be added later
|
||||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
# score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
if len(query) > 0:
|
if len(query) > 0:
|
||||||
# Check which language the query is in
|
# Check which language the query is in
|
||||||
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
|
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
|
||||||
|
@ -239,7 +239,7 @@ class OracleVector(BaseVector):
|
||||||
words = pseg.cut(query)
|
words = pseg.cut(query)
|
||||||
current_entity = ""
|
current_entity = ""
|
||||||
for word, pos in words:
|
for word, pos in words:
|
||||||
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名
|
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名,ns: 地名,nt: 机构名
|
||||||
current_entity += word
|
current_entity += word
|
||||||
else:
|
else:
|
||||||
if current_entity:
|
if current_entity:
|
||||||
|
|
|
@ -444,7 +444,7 @@ class QdrantVectorFactory(AbstractVectorFactory):
|
||||||
if dataset_collection_binding:
|
if dataset_collection_binding:
|
||||||
collection_name = dataset_collection_binding.collection_name
|
collection_name = dataset_collection_binding.collection_name
|
||||||
else:
|
else:
|
||||||
raise ValueError("Dataset Collection Bindings is not exist!")
|
raise ValueError("Dataset Collection Bindings does not exist!")
|
||||||
else:
|
else:
|
||||||
if dataset.index_struct_dict:
|
if dataset.index_struct_dict:
|
||||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||||
|
|
|
@ -65,8 +65,6 @@ class RelytVector(BaseVector):
|
||||||
return VectorType.RELYT
|
return VectorType.RELYT
|
||||||
|
|
||||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> None:
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> None:
|
||||||
index_params: dict[str, Any] = {}
|
|
||||||
metadatas = [d.metadata for d in texts]
|
|
||||||
self.create_collection(len(embeddings[0]))
|
self.create_collection(len(embeddings[0]))
|
||||||
self.embedding_dimension = len(embeddings[0])
|
self.embedding_dimension = len(embeddings[0])
|
||||||
self.add_texts(texts, embeddings)
|
self.add_texts(texts, embeddings)
|
||||||
|
|
|
@ -187,7 +187,6 @@ class TiDBVector(BaseVector):
|
||||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||||
top_k = kwargs.get("top_k", 4)
|
top_k = kwargs.get("top_k", 4)
|
||||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
filter = kwargs.get("filter")
|
|
||||||
distance = 1 - score_threshold
|
distance = 1 - score_threshold
|
||||||
|
|
||||||
query_vector_str = ", ".join(format(x) for x in query_vector)
|
query_vector_str = ", ".join(format(x) for x in query_vector)
|
||||||
|
|
|
@ -206,6 +206,7 @@ class DatasetRetrieval:
|
||||||
source = {
|
source = {
|
||||||
"dataset_id": item.metadata.get("dataset_id"),
|
"dataset_id": item.metadata.get("dataset_id"),
|
||||||
"dataset_name": item.metadata.get("dataset_name"),
|
"dataset_name": item.metadata.get("dataset_name"),
|
||||||
|
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
|
||||||
"document_name": item.metadata.get("title"),
|
"document_name": item.metadata.get("title"),
|
||||||
"data_source_type": "external",
|
"data_source_type": "external",
|
||||||
"retriever_from": invoke_from.to_source(),
|
"retriever_from": invoke_from.to_source(),
|
||||||
|
|
|
@ -39,6 +39,12 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||||
else:
|
else:
|
||||||
return [GPT2Tokenizer.get_num_tokens(text) for text in texts]
|
return [GPT2Tokenizer.get_num_tokens(text) for text in texts]
|
||||||
|
|
||||||
|
def _character_encoder(texts: list[str]) -> list[int]:
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [len(text) for text in texts]
|
||||||
|
|
||||||
if issubclass(cls, TokenTextSplitter):
|
if issubclass(cls, TokenTextSplitter):
|
||||||
extra_kwargs = {
|
extra_kwargs = {
|
||||||
"model_name": embedding_model_instance.model if embedding_model_instance else "gpt2",
|
"model_name": embedding_model_instance.model if embedding_model_instance else "gpt2",
|
||||||
|
@ -47,7 +53,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||||
}
|
}
|
||||||
kwargs = {**kwargs, **extra_kwargs}
|
kwargs = {**kwargs, **extra_kwargs}
|
||||||
|
|
||||||
return cls(length_function=_token_encoder, **kwargs)
|
return cls(length_function=_character_encoder, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
|
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
|
||||||
|
@ -103,7 +109,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
||||||
_good_splits_lengths = [] # cache the lengths of the splits
|
_good_splits_lengths = [] # cache the lengths of the splits
|
||||||
_separator = "" if self._keep_separator else separator
|
_separator = "" if self._keep_separator else separator
|
||||||
s_lens = self._length_function(splits)
|
s_lens = self._length_function(splits)
|
||||||
if _separator != "":
|
if separator != "":
|
||||||
for s, s_len in zip(splits, s_lens):
|
for s, s_len in zip(splits, s_lens):
|
||||||
if s_len < self._chunk_size:
|
if s_len < self._chunk_size:
|
||||||
_good_splits.append(s)
|
_good_splits.append(s)
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
"""
|
||||||
|
Repository interfaces for data access.
|
||||||
|
|
||||||
|
This package contains repository interfaces that define the contract
|
||||||
|
for accessing and manipulating data, regardless of the underlying
|
||||||
|
storage mechanism.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from core.repository.repository_factory import RepositoryFactory
|
||||||
|
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"RepositoryFactory",
|
||||||
|
"WorkflowNodeExecutionRepository",
|
||||||
|
]
|
|
@ -0,0 +1,97 @@
|
||||||
|
"""
|
||||||
|
Repository factory for creating repository instances.
|
||||||
|
|
||||||
|
This module provides a simple factory interface for creating repository instances.
|
||||||
|
It does not contain any implementation details or dependencies on specific repositories.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Callable, Mapping
|
||||||
|
from typing import Any, Literal, Optional, cast
|
||||||
|
|
||||||
|
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
|
|
||||||
|
# Type for factory functions - takes a dict of parameters and returns any repository type
|
||||||
|
RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any]
|
||||||
|
|
||||||
|
# Type for workflow node execution factory function
|
||||||
|
WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository]
|
||||||
|
|
||||||
|
# Repository type literals
|
||||||
|
_RepositoryType = Literal["workflow_node_execution"]
|
||||||
|
|
||||||
|
|
||||||
|
class RepositoryFactory:
|
||||||
|
"""
|
||||||
|
Factory class for creating repository instances.
|
||||||
|
|
||||||
|
This factory delegates the actual repository creation to implementation-specific
|
||||||
|
factory functions that are registered with the factory at runtime.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Dictionary to store factory functions
|
||||||
|
_factory_functions: dict[str, RepositoryFactoryFunc] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _register_factory(cls, repository_type: _RepositoryType, factory_func: RepositoryFactoryFunc) -> None:
|
||||||
|
"""
|
||||||
|
Register a factory function for a specific repository type.
|
||||||
|
This is a private method and should not be called directly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repository_type: The type of repository (e.g., 'workflow_node_execution')
|
||||||
|
factory_func: A function that takes parameters and returns a repository instance
|
||||||
|
"""
|
||||||
|
cls._factory_functions[repository_type] = factory_func
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any:
|
||||||
|
"""
|
||||||
|
Create a new repository instance with the provided parameters.
|
||||||
|
This is a private method and should not be called directly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repository_type: The type of repository to create
|
||||||
|
params: A dictionary of parameters to pass to the factory function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new instance of the requested repository
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no factory function is registered for the repository type
|
||||||
|
"""
|
||||||
|
if repository_type not in cls._factory_functions:
|
||||||
|
raise ValueError(f"No factory function registered for repository type '{repository_type}'")
|
||||||
|
|
||||||
|
# Use empty dict if params is None
|
||||||
|
params = params or {}
|
||||||
|
|
||||||
|
return cls._factory_functions[repository_type](params)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None:
|
||||||
|
"""
|
||||||
|
Register a factory function for the workflow node execution repository.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance
|
||||||
|
"""
|
||||||
|
cls._register_factory("workflow_node_execution", factory_func)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_workflow_node_execution_repository(
|
||||||
|
cls, params: Optional[Mapping[str, Any]] = None
|
||||||
|
) -> WorkflowNodeExecutionRepository:
|
||||||
|
"""
|
||||||
|
Create a new WorkflowNodeExecutionRepository instance with the provided parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: A dictionary of parameters to pass to the factory function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new instance of the WorkflowNodeExecutionRepository
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no factory function is registered for the workflow_node_execution repository type
|
||||||
|
"""
|
||||||
|
# We can safely cast here because we've registered a WorkflowNodeExecutionFactoryFunc
|
||||||
|
return cast(WorkflowNodeExecutionRepository, cls._create_repository("workflow_node_execution", params))
|
|
@ -0,0 +1,88 @@
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal, Optional, Protocol
|
||||||
|
|
||||||
|
from models.workflow import WorkflowNodeExecution
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OrderConfig:
|
||||||
|
"""Configuration for ordering WorkflowNodeExecution instances."""
|
||||||
|
|
||||||
|
order_by: list[str]
|
||||||
|
order_direction: Optional[Literal["asc", "desc"]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowNodeExecutionRepository(Protocol):
|
||||||
|
"""
|
||||||
|
Repository interface for WorkflowNodeExecution.
|
||||||
|
|
||||||
|
This interface defines the contract for accessing and manipulating
|
||||||
|
WorkflowNodeExecution data, regardless of the underlying storage mechanism.
|
||||||
|
|
||||||
|
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
|
||||||
|
and trigger sources (triggered_from) should be handled at the implementation level, not in
|
||||||
|
the core interface. This keeps the core domain model clean and independent of specific
|
||||||
|
application domains or deployment scenarios.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||||
|
"""
|
||||||
|
Save a WorkflowNodeExecution instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
execution: The WorkflowNodeExecution instance to save
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
||||||
|
"""
|
||||||
|
Retrieve a WorkflowNodeExecution by its node_execution_id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_execution_id: The node execution ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The WorkflowNodeExecution instance if found, None otherwise
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_by_workflow_run(
|
||||||
|
self,
|
||||||
|
workflow_run_id: str,
|
||||||
|
order_config: Optional[OrderConfig] = None,
|
||||||
|
) -> Sequence[WorkflowNodeExecution]:
|
||||||
|
"""
|
||||||
|
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: The workflow run ID
|
||||||
|
order_config: Optional configuration for ordering results
|
||||||
|
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
|
||||||
|
order_config.order_direction: Direction to order ("asc" or "desc")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of WorkflowNodeExecution instances
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||||
|
"""
|
||||||
|
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: The workflow run ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of running WorkflowNodeExecution instances
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def update(self, execution: WorkflowNodeExecution) -> None:
|
||||||
|
"""
|
||||||
|
Update an existing WorkflowNodeExecution instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
execution: The WorkflowNodeExecution instance to update
|
||||||
|
"""
|
||||||
|
...
|
|
@ -8,7 +8,7 @@ identity:
|
||||||
description:
|
description:
|
||||||
human:
|
human:
|
||||||
en_US: Run code and get the result back. When you're using a lower quality model, please make sure there are some tips help LLM to understand how to write the code.
|
en_US: Run code and get the result back. When you're using a lower quality model, please make sure there are some tips help LLM to understand how to write the code.
|
||||||
zh_Hans: 运行一段代码并返回结果。当您使用较低质量的模型时,请确保有一些提示帮助LLM理解如何编写代码。
|
zh_Hans: 运行一段代码并返回结果。当您使用较低质量的模型时,请确保有一些提示帮助 LLM 理解如何编写代码。
|
||||||
pt_BR: Execute um trecho de código e obtenha o resultado de volta. quando você estiver usando um modelo de qualidade inferior, certifique-se de que existam algumas dicas para ajudar o LLM a entender como escrever o código.
|
pt_BR: Execute um trecho de código e obtenha o resultado de volta. quando você estiver usando um modelo de qualidade inferior, certifique-se de que existam algumas dicas para ajudar o LLM a entender como escrever o código.
|
||||||
llm: A tool for running code and getting the result back. Only native packages are allowed, network/IO operations are disabled. and you must use print() or console.log() to output the result or result will be empty.
|
llm: A tool for running code and getting the result back. Only native packages are allowed, network/IO operations are disabled. and you must use print() or console.log() to output the result or result will be empty.
|
||||||
parameters:
|
parameters:
|
||||||
|
|
|
@ -19,7 +19,7 @@ parameters:
|
||||||
zh_Hans: 本地时间
|
zh_Hans: 本地时间
|
||||||
human_description:
|
human_description:
|
||||||
en_US: localtime, such as 2024-1-1 0:0:0
|
en_US: localtime, such as 2024-1-1 0:0:0
|
||||||
zh_Hans: 本地时间, 比如2024-1-1 0:0:0
|
zh_Hans: 本地时间,比如 2024-1-1 0:0:0
|
||||||
- name: timezone
|
- name: timezone
|
||||||
type: string
|
type: string
|
||||||
required: false
|
required: false
|
||||||
|
@ -29,5 +29,5 @@ parameters:
|
||||||
zh_Hans: 时区
|
zh_Hans: 时区
|
||||||
human_description:
|
human_description:
|
||||||
en_US: Timezone, such as Asia/Shanghai
|
en_US: Timezone, such as Asia/Shanghai
|
||||||
zh_Hans: 时区, 比如Asia/Shanghai
|
zh_Hans: 时区,比如 Asia/Shanghai
|
||||||
default: Asia/Shanghai
|
default: Asia/Shanghai
|
||||||
|
|
|
@ -29,5 +29,5 @@ parameters:
|
||||||
zh_Hans: 时区
|
zh_Hans: 时区
|
||||||
human_description:
|
human_description:
|
||||||
en_US: Timezone, such as Asia/Shanghai
|
en_US: Timezone, such as Asia/Shanghai
|
||||||
zh_Hans: 时区, 比如Asia/Shanghai
|
zh_Hans: 时区,比如 Asia/Shanghai
|
||||||
default: Asia/Shanghai
|
default: Asia/Shanghai
|
||||||
|
|
|
@ -19,7 +19,7 @@ parameters:
|
||||||
zh_Hans: 当前时间
|
zh_Hans: 当前时间
|
||||||
human_description:
|
human_description:
|
||||||
en_US: current time, such as 2024-1-1 0:0:0
|
en_US: current time, such as 2024-1-1 0:0:0
|
||||||
zh_Hans: 当前时间, 比如2024-1-1 0:0:0
|
zh_Hans: 当前时间,比如 2024-1-1 0:0:0
|
||||||
- name: current_timezone
|
- name: current_timezone
|
||||||
type: string
|
type: string
|
||||||
required: true
|
required: true
|
||||||
|
@ -29,7 +29,7 @@ parameters:
|
||||||
zh_Hans: 当前时区
|
zh_Hans: 当前时区
|
||||||
human_description:
|
human_description:
|
||||||
en_US: Current Timezone, such as Asia/Shanghai
|
en_US: Current Timezone, such as Asia/Shanghai
|
||||||
zh_Hans: 当前时区, 比如Asia/Shanghai
|
zh_Hans: 当前时区,比如 Asia/Shanghai
|
||||||
default: Asia/Shanghai
|
default: Asia/Shanghai
|
||||||
- name: target_timezone
|
- name: target_timezone
|
||||||
type: string
|
type: string
|
||||||
|
@ -40,5 +40,5 @@ parameters:
|
||||||
zh_Hans: 目标时区
|
zh_Hans: 目标时区
|
||||||
human_description:
|
human_description:
|
||||||
en_US: Target Timezone, such as Asia/Tokyo
|
en_US: Target Timezone, such as Asia/Tokyo
|
||||||
zh_Hans: 目标时区, 比如Asia/Tokyo
|
zh_Hans: 目标时区,比如 Asia/Tokyo
|
||||||
default: Asia/Tokyo
|
default: Asia/Tokyo
|
||||||
|
|
|
@ -59,7 +59,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||||
name="api_key_value",
|
name="api_key_value",
|
||||||
required=True,
|
required=True,
|
||||||
type=ProviderConfig.Type.SECRET_INPUT,
|
type=ProviderConfig.Type.SECRET_INPUT,
|
||||||
help=I18nObject(en_US="The api key", zh_Hans="api key的值"),
|
help=I18nObject(en_US="The api key", zh_Hans="api key 的值"),
|
||||||
),
|
),
|
||||||
ProviderConfig(
|
ProviderConfig(
|
||||||
name="api_key_header_prefix",
|
name="api_key_header_prefix",
|
||||||
|
|
|
@ -120,6 +120,13 @@ class ToolInvokeMessage(BaseModel):
|
||||||
class BlobMessage(BaseModel):
|
class BlobMessage(BaseModel):
|
||||||
blob: bytes
|
blob: bytes
|
||||||
|
|
||||||
|
class BlobChunkMessage(BaseModel):
|
||||||
|
id: str = Field(..., description="The id of the blob")
|
||||||
|
sequence: int = Field(..., description="The sequence of the chunk")
|
||||||
|
total_length: int = Field(..., description="The total length of the blob")
|
||||||
|
blob: bytes = Field(..., description="The blob data of the chunk")
|
||||||
|
end: bool = Field(..., description="Whether the chunk is the last chunk")
|
||||||
|
|
||||||
class FileMessage(BaseModel):
|
class FileMessage(BaseModel):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -180,12 +187,15 @@ class ToolInvokeMessage(BaseModel):
|
||||||
VARIABLE = "variable"
|
VARIABLE = "variable"
|
||||||
FILE = "file"
|
FILE = "file"
|
||||||
LOG = "log"
|
LOG = "log"
|
||||||
|
BLOB_CHUNK = "blob_chunk"
|
||||||
|
|
||||||
type: MessageType = MessageType.TEXT
|
type: MessageType = MessageType.TEXT
|
||||||
"""
|
"""
|
||||||
plain text, image url or link url
|
plain text, image url or link url
|
||||||
"""
|
"""
|
||||||
message: JsonMessage | TextMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage
|
message: (
|
||||||
|
JsonMessage | TextMessage | BlobChunkMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage
|
||||||
|
)
|
||||||
meta: dict[str, Any] | None = None
|
meta: dict[str, Any] | None = None
|
||||||
|
|
||||||
@field_validator("message", mode="before")
|
@field_validator("message", mode="before")
|
||||||
|
|
|
@ -86,6 +86,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||||
"position": position,
|
"position": position,
|
||||||
"dataset_id": item.metadata.get("dataset_id"),
|
"dataset_id": item.metadata.get("dataset_id"),
|
||||||
"dataset_name": item.metadata.get("dataset_name"),
|
"dataset_name": item.metadata.get("dataset_name"),
|
||||||
|
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
|
||||||
"document_name": item.metadata.get("title"),
|
"document_name": item.metadata.get("title"),
|
||||||
"data_source_type": "external",
|
"data_source_type": "external",
|
||||||
"retriever_from": self.retriever_from,
|
"retriever_from": self.retriever_from,
|
||||||
|
|
|
@ -1,13 +0,0 @@
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from core.workflow.graph_engine.entities.graph import GraphParallel
|
|
||||||
|
|
||||||
|
|
||||||
class NextGraphNode(BaseModel):
|
|
||||||
node_id: str
|
|
||||||
"""next node id"""
|
|
||||||
|
|
||||||
parallel: Optional[GraphParallel] = None
|
|
||||||
"""parallel"""
|
|
|
@ -155,9 +155,28 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||||
for answer_node_id, route_position in self.route_position.items():
|
for answer_node_id, route_position in self.route_position.items():
|
||||||
if answer_node_id not in self.rest_node_ids:
|
if answer_node_id not in self.rest_node_ids:
|
||||||
continue
|
continue
|
||||||
# exclude current node id
|
# Remove current node id from answer dependencies to support stream output if it is a success branch
|
||||||
answer_dependencies = self.generate_routes.answer_dependencies
|
answer_dependencies = self.generate_routes.answer_dependencies
|
||||||
if event.node_id in answer_dependencies[answer_node_id]:
|
edge_mapping = self.graph.edge_mapping.get(event.node_id)
|
||||||
|
success_edge = (
|
||||||
|
next(
|
||||||
|
(
|
||||||
|
edge
|
||||||
|
for edge in edge_mapping
|
||||||
|
if edge.run_condition
|
||||||
|
and edge.run_condition.type == "branch_identify"
|
||||||
|
and edge.run_condition.branch_identify == "success-branch"
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if edge_mapping
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
event.node_id in answer_dependencies[answer_node_id]
|
||||||
|
and success_edge
|
||||||
|
and success_edge.target_node_id == answer_node_id
|
||||||
|
):
|
||||||
answer_dependencies[answer_node_id].remove(event.node_id)
|
answer_dependencies[answer_node_id].remove(event.node_id)
|
||||||
answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
|
answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
|
||||||
# all depends on answer node id not in rest node ids
|
# all depends on answer node id not in rest node ids
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue