mirror of https://github.com/langgenius/dify.git
chore: fix the code review issues
parent 56ef622447
commit e6a7a1884c

@@ -326,7 +326,6 @@ UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
 MULTIMODAL_SEND_FORMAT=base64
 PROMPT_GENERATION_MAX_TOKENS=512
 CODE_GENERATION_MAX_TOKENS=1024
-STRUCTURED_OUTPUT_MAX_TOKENS=1024
 PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
 
 # Mail configuration, support: resend, smtp

@@ -96,13 +96,11 @@ class RuleStructuredOutputGenerateApi(Resource):
         args = parser.parse_args()
 
         account = current_user
-        structured_output_max_tokens = int(os.getenv("STRUCTURED_OUTPUT_MAX_TOKENS", "1024"))
         try:
             structured_output = LLMGenerator.generate_structured_output(
                 tenant_id=account.current_tenant_id,
                 instruction=args["instruction"],
                 model_config=args["model_config"],
-                max_tokens=structured_output_max_tokens,
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
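
Note: with STRUCTURED_OUTPUT_MAX_TOKENS removed, the controller no longer injects a token cap; any limit now has to travel inside model_config's model_parameters, which the generator forwards to the model (see the llm_generator.py hunks below). A hedged sketch of such a call; the provider and model names are invented examples, not values from this diff:

structured_output = LLMGenerator.generate_structured_output(
    tenant_id=account.current_tenant_id,
    instruction="Extract name and age for each person mentioned.",
    model_config={
        "provider": "openai",  # example only, not taken from this diff
        "name": "gpt-4o-mini",  # example only, not taken from this diff
        "model_parameters": {"max_tokens": 1024},  # optional cap, replaces the env var
    },
)
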
@@ -10,7 +10,7 @@ from core.llm_generator.prompts import (
     GENERATOR_QA_PROMPT,
     JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
     PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
-    STRUCTURED_OUTPUT_GENERATE_TEMPLATE,
+    SYSTEM_STRUCTURED_OUTPUT_GENERATE,
     WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
 )
 from core.model_manager import ModelManager

@@ -343,16 +343,7 @@ class LLMGenerator:
         return answer.strip()
 
     @classmethod
-    def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict, max_tokens: int):
-        prompt_template = PromptTemplateParser(STRUCTURED_OUTPUT_GENERATE_TEMPLATE)
-
-        prompt = prompt_template.format(
-            inputs={
-                "INSTRUCTION": instruction,
-            },
-            remove_template_variables=False,
-        )
-
+    def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict):
         model_manager = ModelManager()
         model_instance = model_manager.get_model_instance(
             tenant_id=tenant_id,

@@ -361,7 +352,10 @@ class LLMGenerator:
             model=model_config.get("name", ""),
         )
 
-        prompt_messages = [UserPromptMessage(content=prompt)]
+        prompt_messages = [
+            SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
+            UserPromptMessage(content=instruction),
+        ]
         model_parameters = model_config.get("model_parameters", {})
 
         try:
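
Note: the old code expanded STRUCTURED_OUTPUT_GENERATE_TEMPLATE with {{INSTRUCTION}} into a single user message; the new code sends a fixed system prompt and passes the instruction through untouched. A minimal sketch of how the resulting pair is consumed downstream; the invoke_llm keyword names follow the ModelInstance pattern used elsewhere in dify but are an assumption here, not a quote of this file:

response = model_instance.invoke_llm(
    prompt_messages=prompt_messages,  # [SystemPromptMessage, UserPromptMessage] from above
    model_parameters=model_parameters,
    stream=False,
)
generated_schema = response.message.content  # the model's JSON Schema text
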
@@ -325,7 +325,5 @@ Your task is to convert simple user descriptions into properly formatted JSON Schema
     ]
 }
 
-Now, generate a JSON Schema based on my description:
-**User Input:** {{INSTRUCTION}}
-**JSON Schema Output:**
+Now, generate a JSON Schema based on my description
 """  # noqa: E501

@@ -202,12 +202,13 @@ class AIModelEntity(ProviderModel):
     def validate_model(self):
         supported_schema_keys = ["json_schema"]
         schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None)
-        if schema_key:
-            if self.features is None:
-                self.features = [ModelFeature.STRUCTURED_OUTPUT]
-            else:
-                if ModelFeature.STRUCTURED_OUTPUT not in self.features:
-                    self.features = [*self.features, ModelFeature.STRUCTURED_OUTPUT]
+        if not schema_key:
+            return self
+        if self.features is None:
+            self.features = [ModelFeature.STRUCTURED_OUTPUT]
+        else:
+            if ModelFeature.STRUCTURED_OUTPUT not in self.features:
+                self.features.append(ModelFeature.STRUCTURED_OUTPUT)
         return self
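
Note: the rewrite is behavior-preserving; the guard clause flattens the nesting, and appending in place replaces rebuilding the list. A self-contained toy version (stand-in classes for illustration, not the real pydantic AIModelEntity) exercising the three cases:

from enum import Enum

class ModelFeature(str, Enum):
    STRUCTURED_OUTPUT = "structured-output"

class Entity:
    """Stand-in for AIModelEntity; parameter rules reduced to bare names."""
    def __init__(self, rule_names, features=None):
        self.rule_names = rule_names
        self.features = features

    def validate_model(self):
        if "json_schema" not in self.rule_names:  # guard clause: nothing to add
            return self
        if self.features is None:
            self.features = [ModelFeature.STRUCTURED_OUTPUT]
        elif ModelFeature.STRUCTURED_OUTPUT not in self.features:
            self.features.append(ModelFeature.STRUCTURED_OUTPUT)
        return self

assert Entity(["temperature"]).validate_model().features is None
assert Entity(["json_schema"]).validate_model().features == [ModelFeature.STRUCTURED_OUTPUT]
assert Entity(["json_schema"], []).validate_model().features == [ModelFeature.STRUCTURED_OUTPUT]
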
@@ -102,6 +102,12 @@ class LLMNode(BaseNode[LLMNodeData]):
     _node_type = NodeType.LLM
 
     def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
+        def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
+            """Process structured output if enabled"""
+            if not self.node_data.structured_output_enabled or not self.node_data.structured_output:
+                return None
+            return self._parse_structured_output(text)
+
         node_inputs: Optional[dict[str, Any]] = None
         process_data = None
         result_text = ""

@@ -201,19 +207,8 @@ class LLMNode(BaseNode[LLMNodeData]):
                     self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                 break
             outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
-            if self.node_data.structured_output_enabled and self.node_data.structured_output:
-                structured_output: dict[str, Any] | list[Any] = {}
-                try:
-                    parsed = json.loads(result_text)
-                    if not isinstance(parsed, (dict | list)):
-                        raise LLMNodeError(f"Failed to parse structured output: {result_text}")
-                    structured_output = parsed
-                except json.JSONDecodeError as e:
-                    # if the result_text is not a valid json, try to repair it
-                    parsed = json_repair.loads(result_text)
-                    if not isinstance(parsed, (dict | list)):
-                        raise LLMNodeError(f"Failed to parse structured output: {result_text}")
-                    structured_output = parsed
+            structured_output = process_structured_output(result_text)
+            if structured_output:
                 outputs["structured_output"] = structured_output
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
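
Note: the inline try/except moves into _parse_structured_output (added further below) and the branch collapses to one call plus a truthiness check. One subtlety: since {} and [] are falsy, a valid-but-empty parse no longer sets the key. The resulting outputs mapping then looks roughly like this (illustrative values; jsonable_encoder renders LLMUsage as a plain dict):

outputs = {
    "text": '{"name": "Ada"}',
    "usage": {"prompt_tokens": 42, "completion_tokens": 9, "total_tokens": 51},
    "finish_reason": "stop",
    # present only when structured output is enabled and parsing yields a non-empty dict/list
    "structured_output": {"name": "Ada"},
}
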
@@ -759,6 +754,21 @@ class LLMNode(BaseNode[LLMNodeData]):
         stop = model_config.stop
         return filtered_prompt_messages, stop
 
+    def _parse_structured_output(self, result_text: str) -> dict[str, Any] | list[Any]:
+        structured_output: dict[str, Any] | list[Any] = {}
+        try:
+            parsed = json.loads(result_text)
+            if not isinstance(parsed, (dict | list)):
+                raise LLMNodeError(f"Failed to parse structured output: {result_text}")
+            structured_output = parsed
+        except json.JSONDecodeError as e:
+            # if the result_text is not a valid json, try to repair it
+            parsed = json_repair.loads(result_text)
+            if not isinstance(parsed, (dict | list)):
+                raise LLMNodeError(f"Failed to parse structured output: {result_text}")
+            structured_output = parsed
+        return structured_output
+
     @classmethod
     def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
         provider_model_bundle = model_instance.provider_model_bundle
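
Note: the fallback path relies on the json-repair package (imported in this file as json_repair), whose loads() is a drop-in for json.loads that tolerates slightly malformed model output. A standalone sketch of the same parse-then-repair pattern:

import json

import json_repair  # pip install json-repair

def parse_structured(text: str) -> dict | list:
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # repair common LLM mistakes: single quotes, trailing commas, unclosed braces
        parsed = json_repair.loads(text)
    if not isinstance(parsed, (dict, list)):
        raise ValueError(f"Failed to parse structured output: {text}")
    return parsed

print(parse_structured("{'name': 'Ada', 'age': 36,}"))  # repaired to {'name': 'Ada', 'age': 36}
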
@@ -622,11 +622,6 @@ PROMPT_GENERATION_MAX_TOKENS=512
 # Default: 1024 tokens.
 CODE_GENERATION_MAX_TOKENS=1024
 
-# The maximum number of tokens allowed for structured output.
-# This setting controls the upper limit of tokens that can be used by the LLM
-# when generating structured output in the structured output tool.
-# Default: 1024 tokens.
-STRUCTURED_OUTPUT_MAX_TOKENS=1024
 # Enable or disable plugin based token counting. If disabled, token counting will return 0.
 # This can improve performance by skipping token counting operations.
 # Default: false (disabled).

@@ -279,7 +279,6 @@ x-shared-env: &shared-api-worker-env
   SCARF_NO_ANALYTICS: ${SCARF_NO_ANALYTICS:-true}
   PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512}
   CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024}
-  STRUCTURED_OUTPUT_MAX_TOKENS: ${STRUCTURED_OUTPUT_MAX_TOKENS:-1024}
   PLUGIN_BASED_TOKEN_COUNTING_ENABLED: ${PLUGIN_BASED_TOKEN_COUNTING_ENABLED:-false}
   MULTIMODAL_SEND_FORMAT: ${MULTIMODAL_SEND_FORMAT:-base64}
   UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10}