Merge branch 'feat/structured-output' into deploy/dev

This commit is contained in:
Novice 2025-04-18 09:29:04 +08:00
commit 7936770810
10 changed files with 39 additions and 10050 deletions

View File

@ -327,7 +327,6 @@ UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024
STRUCTURED_OUTPUT_MAX_TOKENS=1024
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
# Mail configuration, support: resend, smtp

View File

@ -96,13 +96,11 @@ class RuleStructuredOutputGenerateApi(Resource):
args = parser.parse_args()
account = current_user
structured_output_max_tokens = int(os.getenv("STRUCTURED_OUTPUT_MAX_TOKENS", "1024"))
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=account.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
max_tokens=structured_output_max_tokens,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)

View File

@ -10,7 +10,7 @@ from core.llm_generator.prompts import (
GENERATOR_QA_PROMPT,
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
STRUCTURED_OUTPUT_GENERATE_TEMPLATE,
SYSTEM_STRUCTURED_OUTPUT_GENERATE,
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
)
from core.model_manager import ModelManager
@ -343,16 +343,7 @@ class LLMGenerator:
return answer.strip()
@classmethod
def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict, max_tokens: int):
prompt_template = PromptTemplateParser(STRUCTURED_OUTPUT_GENERATE_TEMPLATE)
prompt = prompt_template.format(
inputs={
"INSTRUCTION": instruction,
},
remove_template_variables=False,
)
def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict):
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
@ -361,7 +352,10 @@ class LLMGenerator:
model=model_config.get("name", ""),
)
prompt_messages = [UserPromptMessage(content=prompt)]
prompt_messages = [
SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
UserPromptMessage(content=instruction),
]
model_parameters = model_config.get("model_parameters", {})
try:

View File

@ -221,7 +221,7 @@ Here is the task description: {{INPUT_TEXT}}
You just need to generate the output
""" # noqa: E501
STRUCTURED_OUTPUT_GENERATE_TEMPLATE = """
SYSTEM_STRUCTURED_OUTPUT_GENERATE = """
Your task is to convert simple user descriptions into properly formatted JSON Schema definitions. When a user describes data fields they need, generate a complete, valid JSON Schema that accurately represents those fields with appropriate types and requirements.
## Instructions:
@ -325,7 +325,5 @@ Your task is to convert simple user descriptions into properly formatted JSON Sc
]
}
Now, generate a JSON Schema based on my description:
**User Input:** {{INSTRUCTION}}
**JSON Schema Output:**
Now, generate a JSON Schema based on my description
""" # noqa: E501

View File

@ -202,12 +202,13 @@ class AIModelEntity(ProviderModel):
def validate_model(self):
supported_schema_keys = ["json_schema"]
schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None)
if schema_key:
if not schema_key:
return self
if self.features is None:
self.features = [ModelFeature.STRUCTURED_OUTPUT]
else:
if ModelFeature.STRUCTURED_OUTPUT not in self.features:
self.features = [*self.features, ModelFeature.STRUCTURED_OUTPUT]
self.features.append(ModelFeature.STRUCTURED_OUTPUT)
return self

View File

@ -102,6 +102,12 @@ class LLMNode(BaseNode[LLMNodeData]):
_node_type = NodeType.LLM
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
"""Process structured output if enabled"""
if not self.node_data.structured_output_enabled or not self.node_data.structured_output:
return None
return self._parse_structured_output(text)
node_inputs: Optional[dict[str, Any]] = None
process_data = None
result_text = ""
@ -201,19 +207,8 @@ class LLMNode(BaseNode[LLMNodeData]):
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
break
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
if self.node_data.structured_output_enabled and self.node_data.structured_output:
structured_output: dict[str, Any] | list[Any] = {}
try:
parsed = json.loads(result_text)
if not isinstance(parsed, (dict | list)):
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
structured_output = parsed
except json.JSONDecodeError as e:
# if the result_text is not a valid json, try to repair it
parsed = json_repair.loads(result_text)
if not isinstance(parsed, (dict | list)):
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
structured_output = parsed
structured_output = process_structured_output(result_text)
if structured_output:
outputs["structured_output"] = structured_output
yield RunCompletedEvent(
run_result=NodeRunResult(
@ -759,6 +754,21 @@ class LLMNode(BaseNode[LLMNodeData]):
stop = model_config.stop
return filtered_prompt_messages, stop
def _parse_structured_output(self, result_text: str) -> dict[str, Any] | list[Any]:
structured_output: dict[str, Any] | list[Any] = {}
try:
parsed = json.loads(result_text)
if not isinstance(parsed, (dict | list)):
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
structured_output = parsed
except json.JSONDecodeError as e:
# if the result_text is not a valid json, try to repair it
parsed = json_repair.loads(result_text)
if not isinstance(parsed, (dict | list)):
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
structured_output = parsed
return structured_output
@classmethod
def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
provider_model_bundle = model_instance.provider_model_bundle

9850
api/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -90,158 +90,6 @@ dependencies = [
[tool.uv]
default-groups = ["storage", "tools", "vdb"]
############################################################
# [ Main ] Dependency group
############################################################
[tool.poetry.dependencies]
authlib = "1.3.1"
azure-identity = "1.16.1"
beautifulsoup4 = "4.12.2"
boto3 = "1.35.99"
bs4 = "~0.0.1"
cachetools = "~5.3.0"
celery = "~5.4.0"
chardet = "~5.1.0"
flask = "~3.1.0"
flask-compress = "~1.17"
flask-cors = "~4.0.0"
flask-login = "~0.6.3"
flask-migrate = "~4.0.7"
flask-restful = "~0.3.10"
flask-sqlalchemy = "~3.1.1"
gevent = "~24.11.1"
gmpy2 = "~2.2.1"
google-api-core = "2.18.0"
google-api-python-client = "2.90.0"
google-auth = "2.29.0"
google-auth-httplib2 = "0.2.0"
google-cloud-aiplatform = "1.49.0"
googleapis-common-protos = "1.63.0"
gunicorn = "~23.0.0"
httpx = { version = "~0.27.0", extras = ["socks"] }
jieba = "0.42.1"
json-repair = "~0.40.0"
langfuse = "~2.51.3"
langsmith = "~0.1.77"
mailchimp-transactional = "~1.0.50"
markdown = "~3.5.1"
numpy = "~1.26.4"
oci = "~2.135.1"
openai = "~1.61.0"
openpyxl = "~3.1.5"
opentelemetry-api = "1.27.0"
opentelemetry-distro = "0.48b0"
opentelemetry-exporter-otlp = "1.27.0"
opentelemetry-exporter-otlp-proto-common = "1.27.0"
opentelemetry-exporter-otlp-proto-grpc = "1.27.0"
opentelemetry-exporter-otlp-proto-http = "1.27.0"
opentelemetry-instrumentation = "0.48b0"
opentelemetry-instrumentation-flask = "0.48b0"
opentelemetry-instrumentation-sqlalchemy = "0.48b0"
opentelemetry-propagator-b3 = "1.27.0"
opentelemetry-proto = "1.27.0" # 1.28.0 depends on protobuf (>=5.0,<6.0), conflict with googleapis-common-protos (1.63.0)
opentelemetry-sdk = "1.27.0"
opentelemetry-semantic-conventions = "0.48b0"
opentelemetry-util-http = "0.48b0"
opik = "~1.3.4"
pandas = { version = "~2.2.2", extras = [
"performance",
"excel",
"output-formatting",
] }
pandas-stubs = "~2.2.3.241009"
pandoc = "~2.4"
psycogreen = "~1.0.2"
psycopg2-binary = "~2.9.6"
pycryptodome = "3.19.1"
pydantic = "~2.9.2"
pydantic-settings = "~2.6.0"
pydantic_extra_types = "~2.9.0"
pyjwt = "~2.8.0"
pypdfium2 = "~4.30.0"
python = ">=3.11,<3.13"
python-docx = "~1.1.0"
python-dotenv = "1.0.1"
pyyaml = "~6.0.1"
readabilipy = "0.2.0"
redis = { version = "~5.0.3", extras = ["hiredis"] }
resend = "~0.7.0"
sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
sqlalchemy = "~2.0.29"
starlette = "0.41.0"
tiktoken = "~0.8.0"
tokenizers = "~0.15.0"
transformers = "~4.35.0"
unstructured = { version = "~0.16.1", extras = [
"docx",
"epub",
"md",
"ppt",
"pptx",
] }
validators = "0.21.0"
yarl = "~1.18.3"
# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group.
############################################################
# [ Indirect ] dependency group
# Related transparent dependencies with pinned version
# required by main implementations
############################################################
[tool.poetry.group.indirect.dependencies]
kaleido = "0.2.1"
rank-bm25 = "~0.2.2"
safetensors = "~0.4.3"
############################################################
# [ Tools ] dependency group
############################################################
[tool.poetry.group.tools.dependencies]
cloudscraper = "1.2.71"
nltk = "3.9.1"
############################################################
# [ Storage ] dependency group
# Required for storage clients
############################################################
[tool.poetry.group.storage.dependencies]
azure-storage-blob = "12.13.0"
bce-python-sdk = "~0.9.23"
cos-python-sdk-v5 = "1.9.30"
esdk-obs-python = "3.24.6.1"
google-cloud-storage = "2.16.0"
opendal = "~0.45.16"
oss2 = "2.18.5"
supabase = "~2.8.1"
tos = "~2.7.1"
############################################################
# [ VDB ] dependency group
# Required by vector store clients
############################################################
[tool.poetry.group.vdb.dependencies]
alibabacloud_gpdb20160503 = "~3.8.0"
alibabacloud_tea_openapi = "~0.3.9"
chromadb = "0.5.20"
clickhouse-connect = "~0.7.16"
couchbase = "~4.3.0"
elasticsearch = "8.14.0"
opensearch-py = "2.4.0"
oracledb = "~2.2.1"
pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
pgvector = "0.2.5"
pymilvus = "~2.5.0"
pymochow = "1.3.1"
pyobvector = "~0.1.6"
qdrant-client = "1.7.3"
tablestore = "6.1.0"
tcvectordb = "~1.6.4"
tidb-vector = "0.0.9"
upstash-vector = "0.6.0"
volcengine-compat = "~1.0.156"
weaviate-client = "~3.21.0"
xinference-client = "~1.2.2"
[dependency-groups]
############################################################
@ -316,10 +164,7 @@ storage = [
############################################################
# [ Tools ] dependency group
############################################################
tools = [
"cloudscraper~=1.2.71",
"nltk~=3.9.1",
]
tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]
############################################################
# [ VDB ] dependency group

View File

@ -623,11 +623,6 @@ PROMPT_GENERATION_MAX_TOKENS=512
# Default: 1024 tokens.
CODE_GENERATION_MAX_TOKENS=1024
# The maximum number of tokens allowed for structured output.
# This setting controls the upper limit of tokens that can be used by the LLM
# when generating structured output in the structured output tool.
# Default: 1024 tokens.
STRUCTURED_OUTPUT_MAX_TOKENS=1024
# Enable or disable plugin based token counting. If disabled, token counting will return 0.
# This can improve performance by skipping token counting operations.
# Default: false (disabled).

View File

@ -280,7 +280,6 @@ x-shared-env: &shared-api-worker-env
SCARF_NO_ANALYTICS: ${SCARF_NO_ANALYTICS:-true}
PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512}
CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024}
STRUCTURED_OUTPUT_MAX_TOKENS: ${STRUCTURED_OUTPUT_MAX_TOKENS:-1024}
PLUGIN_BASED_TOKEN_COUNTING_ENABLED: ${PLUGIN_BASED_TOKEN_COUNTING_ENABLED:-false}
MULTIMODAL_SEND_FORMAT: ${MULTIMODAL_SEND_FORMAT:-base64}
UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10}