mirror of https://github.com/langgenius/dify.git
Merge branch 'feat/structured-output' into deploy/dev
commit 7936770810
@@ -327,7 +327,6 @@ UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
 MULTIMODAL_SEND_FORMAT=base64
 PROMPT_GENERATION_MAX_TOKENS=512
 CODE_GENERATION_MAX_TOKENS=1024
-STRUCTURED_OUTPUT_MAX_TOKENS=1024
 PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false

 # Mail configuration, support: resend, smtp
@@ -96,13 +96,11 @@ class RuleStructuredOutputGenerateApi(Resource):
         args = parser.parse_args()

         account = current_user
-        structured_output_max_tokens = int(os.getenv("STRUCTURED_OUTPUT_MAX_TOKENS", "1024"))
         try:
             structured_output = LLMGenerator.generate_structured_output(
                 tenant_id=account.current_tenant_id,
                 instruction=args["instruction"],
                 model_config=args["model_config"],
-                max_tokens=structured_output_max_tokens,
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
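With the env-derived token cap gone, the endpoint passes only the parsed request fields through. A hedged sketch of a request payload this controller would accept — "instruction" plus the two model_config keys the diff actually reads ("name", "model_parameters") are grounded in the hunks; the "provider" key and all values are illustrative assumptions:

    # Hypothetical payload; "provider" is an assumed key, not shown in this diff.
    payload = {
        "instruction": "a user with a name, an email, and an optional age",
        "model_config": {
            "provider": "openai",  # assumption
            "name": "gpt-4o-mini",  # read via model_config.get("name", "")
            "model_parameters": {"temperature": 0.1},  # read via model_config.get("model_parameters", {})
        },
    }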
@@ -10,7 +10,7 @@ from core.llm_generator.prompts import (
     GENERATOR_QA_PROMPT,
     JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
     PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
-    STRUCTURED_OUTPUT_GENERATE_TEMPLATE,
+    SYSTEM_STRUCTURED_OUTPUT_GENERATE,
     WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
 )
 from core.model_manager import ModelManager
@@ -343,16 +343,7 @@ class LLMGenerator:
         return answer.strip()

     @classmethod
-    def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict, max_tokens: int):
-        prompt_template = PromptTemplateParser(STRUCTURED_OUTPUT_GENERATE_TEMPLATE)
-
-        prompt = prompt_template.format(
-            inputs={
-                "INSTRUCTION": instruction,
-            },
-            remove_template_variables=False,
-        )
-
+    def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict):
         model_manager = ModelManager()
         model_instance = model_manager.get_model_instance(
             tenant_id=tenant_id,
@@ -361,7 +352,10 @@ class LLMGenerator:
             model=model_config.get("name", ""),
         )

-        prompt_messages = [UserPromptMessage(content=prompt)]
+        prompt_messages = [
+            SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
+            UserPromptMessage(content=instruction),
+        ]
         model_parameters = model_config.get("model_parameters", {})

         try:
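Taken together, the two LLMGenerator hunks replace single-template prompting (PromptTemplateParser splicing {{INSTRUCTION}} into STRUCTURED_OUTPUT_GENERATE_TEMPLATE) with a system/user message pair. A minimal stand-alone sketch of the new shape, using plain dataclasses in place of dify's prompt message entities:

    from dataclasses import dataclass

    # Stand-ins for dify's SystemPromptMessage/UserPromptMessage, for illustration only.
    @dataclass
    class SystemPromptMessage:
        content: str

    @dataclass
    class UserPromptMessage:
        content: str

    SYSTEM_STRUCTURED_OUTPUT_GENERATE = "Your task is to convert simple user descriptions into properly formatted JSON Schema definitions. ..."  # truncated

    def build_prompt_messages(instruction: str) -> list:
        # The instruction now travels verbatim as the user turn; no template splicing.
        return [
            SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
            UserPromptMessage(content=instruction),
        ]

    print(build_prompt_messages("a product with a title and a price"))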
@@ -221,7 +221,7 @@ Here is the task description: {{INPUT_TEXT}}
 You just need to generate the output
 """  # noqa: E501

-STRUCTURED_OUTPUT_GENERATE_TEMPLATE = """
+SYSTEM_STRUCTURED_OUTPUT_GENERATE = """
 Your task is to convert simple user descriptions into properly formatted JSON Schema definitions. When a user describes data fields they need, generate a complete, valid JSON Schema that accurately represents those fields with appropriate types and requirements.

 ## Instructions:
@@ -325,7 +325,5 @@ Your task is to convert simple user descriptions into properly formatted JSON Sc
   ]
 }

-Now, generate a JSON Schema based on my description:
-**User Input:** {{INSTRUCTION}}
-**JSON Schema Output:**
+Now, generate a JSON Schema based on my description
 """  # noqa: E501
@@ -202,12 +202,13 @@ class AIModelEntity(ProviderModel):
     def validate_model(self):
         supported_schema_keys = ["json_schema"]
         schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None)
-        if schema_key:
-            if self.features is None:
-                self.features = [ModelFeature.STRUCTURED_OUTPUT]
-            else:
-                if ModelFeature.STRUCTURED_OUTPUT not in self.features:
-                    self.features = [*self.features, ModelFeature.STRUCTURED_OUTPUT]
+        if not schema_key:
+            return self
+        if self.features is None:
+            self.features = [ModelFeature.STRUCTURED_OUTPUT]
+        else:
+            if ModelFeature.STRUCTURED_OUTPUT not in self.features:
+                self.features.append(ModelFeature.STRUCTURED_OUTPUT)
         return self
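Two things change here: the nested `if schema_key:` becomes a guard clause, and the feature list is mutated with `append` rather than rebuilt with `[*self.features, ...]`. A self-contained sketch of the resulting control flow (the free function and `rule_names` argument are illustrative, not dify's API):

    from enum import Enum

    class ModelFeature(Enum):
        STRUCTURED_OUTPUT = "structured-output"

    def tag_structured_output(rule_names, features):
        # Guard clause: nothing to do unless a json_schema parameter rule exists.
        schema_key = next((n for n in rule_names if n in ["json_schema"]), None)
        if not schema_key:
            return features
        if features is None:
            return [ModelFeature.STRUCTURED_OUTPUT]
        if ModelFeature.STRUCTURED_OUTPUT not in features:
            features.append(ModelFeature.STRUCTURED_OUTPUT)  # in place, unlike the old list rebuild
        return features

    print(tag_structured_output(["json_schema"], None))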
@@ -102,6 +102,12 @@ class LLMNode(BaseNode[LLMNodeData]):
     _node_type = NodeType.LLM

     def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
+        def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
+            """Process structured output if enabled"""
+            if not self.node_data.structured_output_enabled or not self.node_data.structured_output:
+                return None
+            return self._parse_structured_output(text)
+
         node_inputs: Optional[dict[str, Any]] = None
         process_data = None
         result_text = ""
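The closure defers to self._parse_structured_output, which the hunk at @@ -759,6 +754,21 @@ below extracts from the inline try/except that the next hunk removes.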
@@ -201,19 +207,8 @@ class LLMNode(BaseNode[LLMNodeData]):
                     self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                 break
         outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
-        if self.node_data.structured_output_enabled and self.node_data.structured_output:
-            structured_output: dict[str, Any] | list[Any] = {}
-            try:
-                parsed = json.loads(result_text)
-                if not isinstance(parsed, (dict | list)):
-                    raise LLMNodeError(f"Failed to parse structured output: {result_text}")
-                structured_output = parsed
-            except json.JSONDecodeError as e:
-                # if the result_text is not a valid json, try to repair it
-                parsed = json_repair.loads(result_text)
-                if not isinstance(parsed, (dict | list)):
-                    raise LLMNodeError(f"Failed to parse structured output: {result_text}")
-                structured_output = parsed
+        structured_output = process_structured_output(result_text)
+        if structured_output:
             outputs["structured_output"] = structured_output
         yield RunCompletedEvent(
             run_result=NodeRunResult(
@@ -759,6 +754,21 @@ class LLMNode(BaseNode[LLMNodeData]):
             stop = model_config.stop
         return filtered_prompt_messages, stop

+    def _parse_structured_output(self, result_text: str) -> dict[str, Any] | list[Any]:
+        structured_output: dict[str, Any] | list[Any] = {}
+        try:
+            parsed = json.loads(result_text)
+            if not isinstance(parsed, (dict | list)):
+                raise LLMNodeError(f"Failed to parse structured output: {result_text}")
+            structured_output = parsed
+        except json.JSONDecodeError as e:
+            # if the result_text is not a valid json, try to repair it
+            parsed = json_repair.loads(result_text)
+            if not isinstance(parsed, (dict | list)):
+                raise LLMNodeError(f"Failed to parse structured output: {result_text}")
+            structured_output = parsed
+        return structured_output
+
     @classmethod
     def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
         provider_model_bundle = model_instance.provider_model_bundle
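The extracted method keeps the two-stage parse: strict json.loads first, json_repair as the fallback for almost-valid model output (json_repair is the PyPI json-repair package; the pyproject hunk below shows it pinned at ~0.40.0 in the old poetry table). A self-contained sketch of the same technique, outside dify's error types:

    import json

    import json_repair  # pip install json-repair

    def parse_structured_output(result_text: str) -> dict | list:
        try:
            parsed = json.loads(result_text)
        except json.JSONDecodeError:
            # Not valid JSON; json_repair tolerates missing braces, trailing commas, etc.
            parsed = json_repair.loads(result_text)
        if not isinstance(parsed, (dict, list)):
            raise ValueError(f"Failed to parse structured output: {result_text}")
        return parsed

    # A truncated object, as produced when a model hits its token limit:
    print(parse_structured_output('{"name": "Ada", "age": 36'))  # -> {'name': 'Ada', 'age': 36}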
File diff suppressed because it is too large
@@ -90,158 +90,6 @@ dependencies = [
 [tool.uv]
 default-groups = ["storage", "tools", "vdb"]

-############################################################
-# [ Main ] Dependency group
-############################################################
-
-[tool.poetry.dependencies]
-authlib = "1.3.1"
-azure-identity = "1.16.1"
-beautifulsoup4 = "4.12.2"
-boto3 = "1.35.99"
-bs4 = "~0.0.1"
-cachetools = "~5.3.0"
-celery = "~5.4.0"
-chardet = "~5.1.0"
-flask = "~3.1.0"
-flask-compress = "~1.17"
-flask-cors = "~4.0.0"
-flask-login = "~0.6.3"
-flask-migrate = "~4.0.7"
-flask-restful = "~0.3.10"
-flask-sqlalchemy = "~3.1.1"
-gevent = "~24.11.1"
-gmpy2 = "~2.2.1"
-google-api-core = "2.18.0"
-google-api-python-client = "2.90.0"
-google-auth = "2.29.0"
-google-auth-httplib2 = "0.2.0"
-google-cloud-aiplatform = "1.49.0"
-googleapis-common-protos = "1.63.0"
-gunicorn = "~23.0.0"
-httpx = { version = "~0.27.0", extras = ["socks"] }
-jieba = "0.42.1"
-json-repair = "~0.40.0"
-langfuse = "~2.51.3"
-langsmith = "~0.1.77"
-mailchimp-transactional = "~1.0.50"
-markdown = "~3.5.1"
-numpy = "~1.26.4"
-oci = "~2.135.1"
-openai = "~1.61.0"
-openpyxl = "~3.1.5"
-opentelemetry-api = "1.27.0"
-opentelemetry-distro = "0.48b0"
-opentelemetry-exporter-otlp = "1.27.0"
-opentelemetry-exporter-otlp-proto-common = "1.27.0"
-opentelemetry-exporter-otlp-proto-grpc = "1.27.0"
-opentelemetry-exporter-otlp-proto-http = "1.27.0"
-opentelemetry-instrumentation = "0.48b0"
-opentelemetry-instrumentation-flask = "0.48b0"
-opentelemetry-instrumentation-sqlalchemy = "0.48b0"
-opentelemetry-propagator-b3 = "1.27.0"
-opentelemetry-proto = "1.27.0" # 1.28.0 depends on protobuf (>=5.0,<6.0), conflict with googleapis-common-protos (1.63.0)
-opentelemetry-sdk = "1.27.0"
-opentelemetry-semantic-conventions = "0.48b0"
-opentelemetry-util-http = "0.48b0"
-opik = "~1.3.4"
-pandas = { version = "~2.2.2", extras = [
-    "performance",
-    "excel",
-    "output-formatting",
-] }
-pandas-stubs = "~2.2.3.241009"
-pandoc = "~2.4"
-psycogreen = "~1.0.2"
-psycopg2-binary = "~2.9.6"
-pycryptodome = "3.19.1"
-pydantic = "~2.9.2"
-pydantic-settings = "~2.6.0"
-pydantic_extra_types = "~2.9.0"
-pyjwt = "~2.8.0"
-pypdfium2 = "~4.30.0"
-python = ">=3.11,<3.13"
-python-docx = "~1.1.0"
-python-dotenv = "1.0.1"
-pyyaml = "~6.0.1"
-readabilipy = "0.2.0"
-redis = { version = "~5.0.3", extras = ["hiredis"] }
-resend = "~0.7.0"
-sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
-sqlalchemy = "~2.0.29"
-starlette = "0.41.0"
-tiktoken = "~0.8.0"
-tokenizers = "~0.15.0"
-transformers = "~4.35.0"
-unstructured = { version = "~0.16.1", extras = [
-    "docx",
-    "epub",
-    "md",
-    "ppt",
-    "pptx",
-] }
-validators = "0.21.0"
-yarl = "~1.18.3"
-# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group.
-
-############################################################
-# [ Indirect ] dependency group
-# Related transparent dependencies with pinned version
-# required by main implementations
-############################################################
-[tool.poetry.group.indirect.dependencies]
-kaleido = "0.2.1"
-rank-bm25 = "~0.2.2"
-safetensors = "~0.4.3"
-
-############################################################
-# [ Tools ] dependency group
-############################################################
-[tool.poetry.group.tools.dependencies]
-cloudscraper = "1.2.71"
-nltk = "3.9.1"
-
-############################################################
-# [ Storage ] dependency group
-# Required for storage clients
-############################################################
-[tool.poetry.group.storage.dependencies]
-azure-storage-blob = "12.13.0"
-bce-python-sdk = "~0.9.23"
-cos-python-sdk-v5 = "1.9.30"
-esdk-obs-python = "3.24.6.1"
-google-cloud-storage = "2.16.0"
-opendal = "~0.45.16"
-oss2 = "2.18.5"
-supabase = "~2.8.1"
-tos = "~2.7.1"
-
-############################################################
-# [ VDB ] dependency group
-# Required by vector store clients
-############################################################
-[tool.poetry.group.vdb.dependencies]
-alibabacloud_gpdb20160503 = "~3.8.0"
-alibabacloud_tea_openapi = "~0.3.9"
-chromadb = "0.5.20"
-clickhouse-connect = "~0.7.16"
-couchbase = "~4.3.0"
-elasticsearch = "8.14.0"
-opensearch-py = "2.4.0"
-oracledb = "~2.2.1"
-pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
-pgvector = "0.2.5"
-pymilvus = "~2.5.0"
-pymochow = "1.3.1"
-pyobvector = "~0.1.6"
-qdrant-client = "1.7.3"
-tablestore = "6.1.0"
-tcvectordb = "~1.6.4"
-tidb-vector = "0.0.9"
-upstash-vector = "0.6.0"
-volcengine-compat = "~1.0.156"
-weaviate-client = "~3.21.0"
-xinference-client = "~1.2.2"
 [dependency-groups]

 ############################################################
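This is the poetry-to-uv migration: every [tool.poetry.*] table is dropped in favor of PEP 735 [dependency-groups], matching the default-groups = ["storage", "tools", "vdb"] already declared under [tool.uv]; the pins themselves reappear inside the groups, as the reformatted tools group in the next hunk shows.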
@@ -316,10 +164,7 @@ storage = [
 ############################################################
 # [ Tools ] dependency group
 ############################################################
-tools = [
-    "cloudscraper~=1.2.71",
-    "nltk~=3.9.1",
-]
+tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]

 ############################################################
 # [ VDB ] dependency group
@@ -623,11 +623,6 @@ PROMPT_GENERATION_MAX_TOKENS=512
 # Default: 1024 tokens.
 CODE_GENERATION_MAX_TOKENS=1024

-# The maximum number of tokens allowed for structured output.
-# This setting controls the upper limit of tokens that can be used by the LLM
-# when generating structured output in the structured output tool.
-# Default: 1024 tokens.
-STRUCTURED_OUTPUT_MAX_TOKENS=1024
 # Enable or disable plugin based token counting. If disabled, token counting will return 0.
 # This can improve performance by skipping token counting operations.
 # Default: false (disabled).
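Only the prompt and code generation caps stay documented here; the structured-output cap disappears along with its use in the controller. For reference, a sketch of the os.getenv fallback idiom the removed controller line used (the variable names below are illustrative):

    import os

    # Same pattern as the removed line:
    # int(os.getenv("STRUCTURED_OUTPUT_MAX_TOKENS", "1024"))
    prompt_cap = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512"))
    code_cap = int(os.getenv("CODE_GENERATION_MAX_TOKENS", "1024"))
    print(prompt_cap, code_cap)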
@@ -280,7 +280,6 @@ x-shared-env: &shared-api-worker-env
   SCARF_NO_ANALYTICS: ${SCARF_NO_ANALYTICS:-true}
   PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512}
   CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024}
-  STRUCTURED_OUTPUT_MAX_TOKENS: ${STRUCTURED_OUTPUT_MAX_TOKENS:-1024}
   PLUGIN_BASED_TOKEN_COUNTING_ENABLED: ${PLUGIN_BASED_TOKEN_COUNTING_ENABLED:-false}
   MULTIMODAL_SEND_FORMAT: ${MULTIMODAL_SEND_FORMAT:-base64}
   UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10}