chore(api/core): apply ruff reformatting (#7624)

Bowen Liang, 2024-09-10 17:00:20 +08:00, committed by GitHub
parent 178730266d
commit 2cf1187b32
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
724 changed files with 21180 additions and 21123 deletions
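Every hunk below is a mechanical formatting pass with no intended behavior change. Two rewrites dominate: single-quoted strings become double-quoted, and multi-line literals or calls are collapsed onto one line when they fit the formatter's limit. A minimal before/after sketch of both, using a fragment that appears in the runners below (the repository's exact ruff configuration is not part of this diff, so the implied line-length setting is an assumption):

# Before: single quotes, a dict spread over three lines.
llm_usage = {
    'usage': None
}
final_answer = ''

# After ruff formatting: double quotes, one line when it fits.
llm_usage = {"usage": None}
final_answer = ""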

View File

@@ -1 +1 @@
-import core.moderation.base
+import core.moderation.base

View File

@@ -25,17 +25,19 @@ from models.model import Message

 class CotAgentRunner(BaseAgentRunner, ABC):
     _is_first_iteration = True
-    _ignore_observation_providers = ['wenxin']
+    _ignore_observation_providers = ["wenxin"]
     _historic_prompt_messages: list[PromptMessage] = None
     _agent_scratchpad: list[AgentScratchpadUnit] = None
     _instruction: str = None
     _query: str = None
     _prompt_messages_tools: list[PromptMessage] = None

-    def run(self, message: Message,
-            query: str,
-            inputs: dict[str, str],
-            ) -> Union[Generator, LLMResult]:
+    def run(
+        self,
+        message: Message,
+        query: str,
+        inputs: dict[str, str],
+    ) -> Union[Generator, LLMResult]:
         """
         Run Cot agent application
         """
@@ -46,17 +48,16 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         trace_manager = app_generate_entity.trace_manager

         # check model mode
-        if 'Observation' not in app_generate_entity.model_conf.stop:
+        if "Observation" not in app_generate_entity.model_conf.stop:
             if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
-                app_generate_entity.model_conf.stop.append('Observation')
+                app_generate_entity.model_conf.stop.append("Observation")

         app_config = self.app_config

         # init instruction
         inputs = inputs or {}
         instruction = app_config.prompt_template.simple_prompt_template
-        self._instruction = self._fill_in_inputs_from_external_data_tools(
-            instruction, inputs)
+        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)

         iteration_step = 1
         max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
@@ -65,16 +66,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         tool_instances, self._prompt_messages_tools = self._init_prompt_tools()

         function_call_state = True
-        llm_usage = {
-            'usage': None
-        }
-        final_answer = ''
+        llm_usage = {"usage": None}
+        final_answer = ""

         def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
-            if not final_llm_usage_dict['usage']:
-                final_llm_usage_dict['usage'] = usage
+            if not final_llm_usage_dict["usage"]:
+                final_llm_usage_dict["usage"] = usage
             else:
-                llm_usage = final_llm_usage_dict['usage']
+                llm_usage = final_llm_usage_dict["usage"]
                 llm_usage.prompt_tokens += usage.prompt_tokens
                 llm_usage.completion_tokens += usage.completion_tokens
                 llm_usage.prompt_price += usage.prompt_price
@@ -94,17 +93,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             message_file_ids = []
             agent_thought = self.create_agent_thought(
-                message_id=message.id,
-                message='',
-                tool_name='',
-                tool_input='',
-                messages_ids=message_file_ids
+                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
             )

             if iteration_step > 1:
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

             # recalc llm max tokens
             prompt_messages = self._organize_prompt_messages()
@@ -125,21 +120,20 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 raise ValueError("failed to invoke llm")

             usage_dict = {}
-            react_chunks = CotAgentOutputParser.handle_react_stream_output(
-                chunks, usage_dict)
+            react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
             scratchpad = AgentScratchpadUnit(
-                agent_response='',
-                thought='',
-                action_str='',
-                observation='',
+                agent_response="",
+                thought="",
+                action_str="",
+                observation="",
                 action=None,
             )

             # publish agent thought if it's first iteration
             if iteration_step == 1:
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

             for chunk in react_chunks:
                 if isinstance(chunk, AgentScratchpadUnit.Action):
@@ -154,61 +148,51 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     yield LLMResultChunk(
                         model=self.model_config.model,
                         prompt_messages=prompt_messages,
-                        system_fingerprint='',
-                        delta=LLMResultChunkDelta(
-                            index=0,
-                            message=AssistantPromptMessage(
-                                content=chunk
-                            ),
-                            usage=None
-                        )
+                        system_fingerprint="",
+                        delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
                     )

-            scratchpad.thought = scratchpad.thought.strip(
-            ) or 'I am thinking about how to help you'
+            scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
             self._agent_scratchpad.append(scratchpad)

             # get llm usage
-            if 'usage' in usage_dict:
-                increase_usage(llm_usage, usage_dict['usage'])
+            if "usage" in usage_dict:
+                increase_usage(llm_usage, usage_dict["usage"])
             else:
-                usage_dict['usage'] = LLMUsage.empty_usage()
+                usage_dict["usage"] = LLMUsage.empty_usage()

             self.save_agent_thought(
                 agent_thought=agent_thought,
-                tool_name=scratchpad.action.action_name if scratchpad.action else '',
-                tool_input={
-                    scratchpad.action.action_name: scratchpad.action.action_input
-                } if scratchpad.action else {},
+                tool_name=scratchpad.action.action_name if scratchpad.action else "",
+                tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
                 tool_invoke_meta={},
                 thought=scratchpad.thought,
-                observation='',
+                observation="",
                 answer=scratchpad.agent_response,
                 messages_ids=[],
-                llm_usage=usage_dict['usage']
+                llm_usage=usage_dict["usage"],
             )

             if not scratchpad.is_final():
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

             if not scratchpad.action:
                 # failed to extract action, return final answer directly
-                final_answer = ''
+                final_answer = ""
             else:
                 if scratchpad.action.action_name.lower() == "final answer":
                     # action is final answer, return final answer directly
                     try:
                         if isinstance(scratchpad.action.action_input, dict):
-                            final_answer = json.dumps(
-                                scratchpad.action.action_input)
+                            final_answer = json.dumps(scratchpad.action.action_input)
                         elif isinstance(scratchpad.action.action_input, str):
                             final_answer = scratchpad.action.action_input
                         else:
-                            final_answer = f'{scratchpad.action.action_input}'
+                            final_answer = f"{scratchpad.action.action_input}"
                     except json.JSONDecodeError:
-                        final_answer = f'{scratchpad.action.action_input}'
+                        final_answer = f"{scratchpad.action.action_input}"
                 else:
                     function_call_state = True
                     # action is tool call, invoke tool
@@ -224,21 +208,18 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     self.save_agent_thought(
                         agent_thought=agent_thought,
                         tool_name=scratchpad.action.action_name,
-                        tool_input={
-                            scratchpad.action.action_name: scratchpad.action.action_input},
+                        tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
                         thought=scratchpad.thought,
-                        observation={
-                            scratchpad.action.action_name: tool_invoke_response},
-                        tool_invoke_meta={
-                            scratchpad.action.action_name: tool_invoke_meta.to_dict()},
+                        observation={scratchpad.action.action_name: tool_invoke_response},
+                        tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
                         answer=scratchpad.agent_response,
                         messages_ids=message_file_ids,
-                        llm_usage=usage_dict['usage']
+                        llm_usage=usage_dict["usage"],
                     )

-                    self.queue_manager.publish(QueueAgentThoughtEvent(
-                        agent_thought_id=agent_thought.id
-                    ), PublishFrom.APPLICATION_MANAGER)
+                    self.queue_manager.publish(
+                        QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                    )

             # update prompt tool message
             for prompt_tool in self._prompt_messages_tools:
@@ -250,44 +231,45 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             model=model_instance.model,
             prompt_messages=prompt_messages,
             delta=LLMResultChunkDelta(
-                index=0,
-                message=AssistantPromptMessage(
-                    content=final_answer
-                ),
-                usage=llm_usage['usage']
+                index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
             ),
-            system_fingerprint=''
+            system_fingerprint="",
         )

         # save agent thought
         self.save_agent_thought(
             agent_thought=agent_thought,
-            tool_name='',
+            tool_name="",
             tool_input={},
             tool_invoke_meta={},
             thought=final_answer,
             observation={},
             answer=final_answer,
-            messages_ids=[]
+            messages_ids=[],
         )

         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
-        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
-            model=model_instance.model,
-            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=final_answer
-            ),
-            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
-            system_fingerprint=''
-        )), PublishFrom.APPLICATION_MANAGER)
+        self.queue_manager.publish(
+            QueueMessageEndEvent(
+                llm_result=LLMResult(
+                    model=model_instance.model,
+                    prompt_messages=prompt_messages,
+                    message=AssistantPromptMessage(content=final_answer),
+                    usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                    system_fingerprint="",
+                )
+            ),
+            PublishFrom.APPLICATION_MANAGER,
+        )

-    def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
-                              tool_instances: dict[str, Tool],
-                              message_file_ids: list[str],
-                              trace_manager: Optional[TraceQueueManager] = None
-                              ) -> tuple[str, ToolInvokeMeta]:
+    def _handle_invoke_action(
+        self,
+        action: AgentScratchpadUnit.Action,
+        tool_instances: dict[str, Tool],
+        message_file_ids: list[str],
+        trace_manager: Optional[TraceQueueManager] = None,
+    ) -> tuple[str, ToolInvokeMeta]:
         """
         handle invoke action
         :param action: action
@@ -326,13 +308,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         # publish files
         for message_file_id, save_as in message_files:
             if save_as:
-                self.variables_pool.set_file(
-                    tool_name=tool_call_name, value=message_file_id, name=save_as)
+                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)

             # publish message file
-            self.queue_manager.publish(QueueMessageFileEvent(
-                message_file_id=message_file_id
-            ), PublishFrom.APPLICATION_MANAGER)
+            self.queue_manager.publish(
+                QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
+            )
             # add message file ids
             message_file_ids.append(message_file_id)
@@ -342,10 +323,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         """
        convert dict to action
         """
-        return AgentScratchpadUnit.Action(
-            action_name=action['action'],
-            action_input=action['action_input']
-        )
+        return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])

     def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
         """
@@ -353,7 +331,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         """
         for key, value in inputs.items():
             try:
-                instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
+                instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
             except Exception as e:
                 continue
@@ -370,14 +348,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
     @abstractmethod
     def _organize_prompt_messages(self) -> list[PromptMessage]:
         """
        organize prompt messages
         """

     def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
         """
        format assistant message
         """
-        message = ''
+        message = ""
         for scratchpad in agent_scratchpad:
             if scratchpad.is_final():
                 message += f"Final Answer: {scratchpad.agent_response}"
@@ -390,9 +368,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         return message

-    def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _organize_historic_prompt_messages(
+        self, current_session_messages: list[PromptMessage] = None
+    ) -> list[PromptMessage]:
         """
        organize historic prompt messages
         """
         result: list[PromptMessage] = []
         scratchpads: list[AgentScratchpadUnit] = []
@@ -403,8 +383,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 if not current_scratchpad:
                     current_scratchpad = AgentScratchpadUnit(
                         agent_response=message.content,
-                        thought=message.content or 'I am thinking about how to help you',
-                        action_str='',
+                        thought=message.content or "I am thinking about how to help you",
+                        action_str="",
                         action=None,
                         observation=None,
                     )
@@ -413,12 +393,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     try:
                         current_scratchpad.action = AgentScratchpadUnit.Action(
                             action_name=message.tool_calls[0].function.name,
-                            action_input=json.loads(
-                                message.tool_calls[0].function.arguments)
+                            action_input=json.loads(message.tool_calls[0].function.arguments),
                         )
-                        current_scratchpad.action_str = json.dumps(
-                            current_scratchpad.action.to_dict()
-                        )
+                        current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
                     except:
                         pass
                 elif isinstance(message, ToolPromptMessage):
@@ -426,23 +403,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     current_scratchpad.observation = message.content
             elif isinstance(message, UserPromptMessage):
                 if scratchpads:
-                    result.append(AssistantPromptMessage(
-                        content=self._format_assistant_message(scratchpads)
-                    ))
+                    result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
                     scratchpads = []
                     current_scratchpad = None

                 result.append(message)

         if scratchpads:
-            result.append(AssistantPromptMessage(
-                content=self._format_assistant_message(scratchpads)
-            ))
+            result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))

         historic_prompts = AgentHistoryPromptTransform(
             model_config=self.model_config,
             prompt_messages=current_session_messages or [],
             history_messages=result,
-            memory=self.memory
+            memory=self.memory,
         ).get_prompt()
         return historic_prompts
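A standalone sketch of the usage-accumulation pattern in the hunks above: the single-key dict lets the nested increase_usage closure mutate shared state without a nonlocal declaration. LLMUsage here is a simplified stand-in; the real class (core.model_runtime.entities.llm_entities) also tracks prices and more fields.

from dataclasses import dataclass


@dataclass
class LLMUsage:  # simplified stand-in for the real entity
    prompt_tokens: int = 0
    completion_tokens: int = 0


def increase_usage(final_llm_usage_dict: dict, usage: LLMUsage) -> None:
    if not final_llm_usage_dict["usage"]:
        final_llm_usage_dict["usage"] = usage  # first chunk: adopt as-is
    else:
        acc = final_llm_usage_dict["usage"]
        acc.prompt_tokens += usage.prompt_tokens
        acc.completion_tokens += usage.completion_tokens


llm_usage = {"usage": None}
increase_usage(llm_usage, LLMUsage(prompt_tokens=10, completion_tokens=5))
increase_usage(llm_usage, LLMUsage(prompt_tokens=3, completion_tokens=2))
assert llm_usage["usage"].prompt_tokens == 13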

View File

@@ -19,14 +19,15 @@ class CotChatAgentRunner(CotAgentRunner):
         prompt_entity = self.app_config.agent.prompt
         first_prompt = prompt_entity.first_prompt

-        system_prompt = first_prompt \
-            .replace("{{instruction}}", self._instruction) \
-            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
-            .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
+        system_prompt = (
+            first_prompt.replace("{{instruction}}", self._instruction)
+            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
+            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
+        )

         return SystemPromptMessage(content=system_prompt)

     def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
         Organize user query
         """
@@ -43,7 +44,7 @@ class CotChatAgentRunner(CotAgentRunner):
     def _organize_prompt_messages(self) -> list[PromptMessage]:
         """
         Organize
         """
         # organize system prompt
         system_message = self._organize_system_prompt()
@@ -53,7 +54,7 @@ class CotChatAgentRunner(CotAgentRunner):
         if not agent_scratchpad:
             assistant_messages = []
         else:
-            assistant_message = AssistantPromptMessage(content='')
+            assistant_message = AssistantPromptMessage(content="")
             for unit in agent_scratchpad:
                 if unit.is_final():
                     assistant_message.content += f"Final Answer: {unit.agent_response}"
@@ -71,18 +72,15 @@ class CotChatAgentRunner(CotAgentRunner):
         if assistant_messages:
             # organize historic prompt messages
-            historic_messages = self._organize_historic_prompt_messages([
-                system_message,
-                *query_messages,
-                *assistant_messages,
-                UserPromptMessage(content='continue')
-            ])
+            historic_messages = self._organize_historic_prompt_messages(
+                [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
+            )
             messages = [
                 system_message,
                 *historic_messages,
                 *query_messages,
                 *assistant_messages,
-                UserPromptMessage(content='continue')
+                UserPromptMessage(content="continue"),
             ]
         else:
             # organize historic prompt messages

View File

@@ -13,10 +13,12 @@ class CotCompletionAgentRunner(CotAgentRunner):
         prompt_entity = self.app_config.agent.prompt
         first_prompt = prompt_entity.first_prompt

-        system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \
-            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
-            .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
+        system_prompt = (
+            first_prompt.replace("{{instruction}}", self._instruction)
+            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
+            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
+        )

         return system_prompt

     def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
@@ -46,7 +48,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
         # organize current assistant messages
         agent_scratchpad = self._agent_scratchpad
-        assistant_prompt = ''
+        assistant_prompt = ""
         for unit in agent_scratchpad:
             if unit.is_final():
                 assistant_prompt += f"Final Answer: {unit.agent_response}"
@@ -61,9 +63,10 @@ class CotCompletionAgentRunner(CotAgentRunner):
         query_prompt = f"Question: {self._query}"

         # join all messages
-        prompt = system_prompt \
-            .replace("{{historic_messages}}", historic_prompt) \
-            .replace("{{agent_scratchpad}}", assistant_prompt) \
-            .replace("{{query}}", query_prompt)
+        prompt = (
+            system_prompt.replace("{{historic_messages}}", historic_prompt)
+            .replace("{{agent_scratchpad}}", assistant_prompt)
+            .replace("{{query}}", query_prompt)
+        )

         return [UserPromptMessage(content=prompt)]
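Both runners fill their templates with chained str.replace over {{placeholder}} tokens rather than a templating engine, which is why ruff rewraps these chains in parentheses. A runnable illustration with made-up values:

system_prompt = "{{historic_messages}}\n{{agent_scratchpad}}\n{{query}}"
prompt = (
    system_prompt.replace("{{historic_messages}}", "Human: hi\nAssistant: hello")
    .replace("{{agent_scratchpad}}", "Thought: greet the user")
    .replace("{{query}}", "Question: how are you?")
)
print(prompt)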

View File

@@ -8,6 +8,7 @@ class AgentToolEntity(BaseModel):
     """
     Agent Tool Entity.
     """
+
     provider_type: Literal["builtin", "api", "workflow"]
     provider_id: str
     tool_name: str
@@ -18,6 +19,7 @@ class AgentPromptEntity(BaseModel):
     """
     Agent Prompt Entity.
     """
+
     first_prompt: str
     next_iteration: str
@@ -31,6 +33,7 @@ class AgentScratchpadUnit(BaseModel):
         """
         Action Entity.
         """
+
         action_name: str
         action_input: Union[dict, str]
@@ -39,8 +42,8 @@
             Convert to dictionary.
             """
             return {
-                'action': self.action_name,
-                'action_input': self.action_input,
+                "action": self.action_name,
+                "action_input": self.action_input,
             }

     agent_response: Optional[str] = None
@@ -54,10 +57,10 @@
         Check if the scratchpad unit is final.
         """
         return self.action is None or (
-            'final' in self.action.action_name.lower() and
-            'answer' in self.action.action_name.lower()
+            "final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
         )

 class AgentEntity(BaseModel):
     """
     Agent Entity.
@@ -67,8 +70,8 @@
         """
         Agent Strategy.
         """
-        CHAIN_OF_THOUGHT = 'chain-of-thought'
-        FUNCTION_CALLING = 'function-calling'
+
+        CHAIN_OF_THOUGHT = "chain-of-thought"
+        FUNCTION_CALLING = "function-calling"

     provider: str
     model: str
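The is_final() predicate reformatted above treats an action as terminal when its lowercased name contains both "final" and "answer" (or when there is no action at all). A minimal standalone check of that rule:

def is_final_action_name(action_name: str) -> bool:
    name = action_name.lower()
    return "final" in name and "answer" in name


assert is_final_action_name("Final Answer")
assert is_final_action_name("final_answer")
assert not is_final_action_name("web_search")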

View File

@@ -24,11 +24,9 @@ from models.model import Message

 logger = logging.getLogger(__name__)

 class FunctionCallAgentRunner(BaseAgentRunner):
-    def run(self,
-            message: Message, query: str, **kwargs: Any
-    ) -> Generator[LLMResultChunk, None, None]:
+    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
         """
         Run FunctionCall agent application
         """
@@ -45,19 +43,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         # continue to run until there is not any tool call
         function_call_state = True
-        llm_usage = {
-            'usage': None
-        }
-        final_answer = ''
+        llm_usage = {"usage": None}
+        final_answer = ""

         # get tracing instance
         trace_manager = app_generate_entity.trace_manager

         def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
-            if not final_llm_usage_dict['usage']:
-                final_llm_usage_dict['usage'] = usage
+            if not final_llm_usage_dict["usage"]:
+                final_llm_usage_dict["usage"] = usage
             else:
-                llm_usage = final_llm_usage_dict['usage']
+                llm_usage = final_llm_usage_dict["usage"]
                 llm_usage.prompt_tokens += usage.prompt_tokens
                 llm_usage.completion_tokens += usage.completion_tokens
                 llm_usage.prompt_price += usage.prompt_price
@@ -75,11 +71,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             message_file_ids = []
             agent_thought = self.create_agent_thought(
-                message_id=message.id,
-                message='',
-                tool_name='',
-                tool_input='',
-                messages_ids=message_file_ids
+                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
             )

             # recalc llm max tokens
@@ -99,11 +91,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             tool_calls: list[tuple[str, str, dict[str, Any]]] = []

             # save full response
-            response = ''
+            response = ""

             # save tool call names and inputs
-            tool_call_names = ''
-            tool_call_inputs = ''
+            tool_call_names = ""
+            tool_call_inputs = ""

             current_llm_usage = None
@@ -111,24 +103,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 is_first_chunk = True
                 for chunk in chunks:
                     if is_first_chunk:
-                        self.queue_manager.publish(QueueAgentThoughtEvent(
-                            agent_thought_id=agent_thought.id
-                        ), PublishFrom.APPLICATION_MANAGER)
+                        self.queue_manager.publish(
+                            QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                        )
                         is_first_chunk = False

                     # check if there is any tool call
                     if self.check_tool_calls(chunk):
                         function_call_state = True
                         tool_calls.extend(self.extract_tool_calls(chunk))
-                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                         try:
-                            tool_call_inputs = json.dumps({
-                                tool_call[1]: tool_call[2] for tool_call in tool_calls
-                            }, ensure_ascii=False)
+                            tool_call_inputs = json.dumps(
+                                {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
+                            )
                         except json.JSONDecodeError as e:
                             # ensure ascii to avoid encoding error
-                            tool_call_inputs = json.dumps({
-                                tool_call[1]: tool_call[2] for tool_call in tool_calls
-                            })
+                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                     if chunk.delta.message and chunk.delta.message.content:
                         if isinstance(chunk.delta.message.content, list):
@@ -148,16 +138,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 if self.check_blocking_tool_calls(result):
                     function_call_state = True
                     tool_calls.extend(self.extract_blocking_tool_calls(result))
-                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                    tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                     try:
-                        tool_call_inputs = json.dumps({
-                            tool_call[1]: tool_call[2] for tool_call in tool_calls
-                        }, ensure_ascii=False)
+                        tool_call_inputs = json.dumps(
+                            {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
+                        )
                     except json.JSONDecodeError as e:
                         # ensure ascii to avoid encoding error
-                        tool_call_inputs = json.dumps({
-                            tool_call[1]: tool_call[2] for tool_call in tool_calls
-                        })
+                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                 if result.usage:
                     increase_usage(llm_usage, result.usage)
@@ -171,12 +159,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         response += result.message.content

                 if not result.message.content:
-                    result.message.content = ''
+                    result.message.content = ""

-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

                 yield LLMResultChunk(
                     model=model_instance.model,
                     prompt_messages=result.prompt_messages,
@@ -185,32 +173,29 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         index=0,
                         message=result.message,
                         usage=result.usage,
-                    )
+                    ),
                 )

-            assistant_message = AssistantPromptMessage(
-                content='',
-                tool_calls=[]
-            )
+            assistant_message = AssistantPromptMessage(content="", tool_calls=[])
             if tool_calls:
-                assistant_message.tool_calls=[
+                assistant_message.tool_calls = [
                     AssistantPromptMessage.ToolCall(
                         id=tool_call[0],
-                        type='function',
+                        type="function",
                         function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                            name=tool_call[1],
-                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
-                        )
-                    ) for tool_call in tool_calls
+                            name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
+                        ),
+                    )
+                    for tool_call in tool_calls
                 ]
             else:
                 assistant_message.content = response

             self._current_thoughts.append(assistant_message)

             # save thought
             self.save_agent_thought(
                 agent_thought=agent_thought,
                 tool_name=tool_call_names,
                 tool_input=tool_call_inputs,
                 thought=response,
@@ -218,13 +203,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 observation=None,
                 answer=response,
                 messages_ids=[],
-                llm_usage=current_llm_usage
+                llm_usage=current_llm_usage,
             )
-            self.queue_manager.publish(QueueAgentThoughtEvent(
-                agent_thought_id=agent_thought.id
-            ), PublishFrom.APPLICATION_MANAGER)
+            self.queue_manager.publish(
+                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+            )

-            final_answer += response + '\n'
+            final_answer += response + "\n"

             # call tools
             tool_responses = []
@@ -235,7 +220,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         "tool_call_id": tool_call_id,
                         "tool_call_name": tool_call_name,
                         "tool_response": f"there is not a tool named {tool_call_name}",
-                        "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
+                        "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
                     }
                 else:
                     # invoke tool
@@ -255,50 +240,49 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                             self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)

                         # publish message file
-                        self.queue_manager.publish(QueueMessageFileEvent(
-                            message_file_id=message_file_id
-                        ), PublishFrom.APPLICATION_MANAGER)
+                        self.queue_manager.publish(
+                            QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
+                        )
                         # add message file ids
                         message_file_ids.append(message_file_id)

                     tool_response = {
                         "tool_call_id": tool_call_id,
                         "tool_call_name": tool_call_name,
                         "tool_response": tool_invoke_response,
-                        "meta": tool_invoke_meta.to_dict()
+                        "meta": tool_invoke_meta.to_dict(),
                     }

                 tool_responses.append(tool_response)
-                if tool_response['tool_response'] is not None:
+                if tool_response["tool_response"] is not None:
                     self._current_thoughts.append(
                         ToolPromptMessage(
-                            content=tool_response['tool_response'],
+                            content=tool_response["tool_response"],
                             tool_call_id=tool_call_id,
                             name=tool_call_name,
                         )
                     )

             if len(tool_responses) > 0:
                 # save agent thought
                 self.save_agent_thought(
                     agent_thought=agent_thought,
                     tool_name=None,
                     tool_input=None,
                     thought=None,
                     tool_invoke_meta={
-                        tool_response['tool_call_name']: tool_response['meta']
-                        for tool_response in tool_responses
+                        tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
                     },
                     observation={
-                        tool_response['tool_call_name']: tool_response['tool_response']
+                        tool_response["tool_call_name"]: tool_response["tool_response"]
                         for tool_response in tool_responses
                     },
                     answer=None,
-                    messages_ids=message_file_ids
+                    messages_ids=message_file_ids,
                 )
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

             # update prompt tool
             for prompt_tool in prompt_messages_tools:
@@ -308,15 +292,18 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
-        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
-            model=model_instance.model,
-            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=final_answer
-            ),
-            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
-            system_fingerprint=''
-        )), PublishFrom.APPLICATION_MANAGER)
+        self.queue_manager.publish(
+            QueueMessageEndEvent(
+                llm_result=LLMResult(
+                    model=model_instance.model,
+                    prompt_messages=prompt_messages,
+                    message=AssistantPromptMessage(content=final_answer),
+                    usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                    system_fingerprint="",
+                )
+            ),
+            PublishFrom.APPLICATION_MANAGER,
+        )

     def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
         """
@@ -325,7 +312,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         if llm_result_chunk.delta.message.tool_calls:
             return True
         return False

     def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
         """
         Check if there is any blocking tool call in llm result
@@ -334,7 +321,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             return True
         return False

-    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
+    def extract_tool_calls(
+        self, llm_result_chunk: LLMResultChunk
+    ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
         """
         Extract tool calls from llm result chunk
@@ -344,17 +333,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         tool_calls = []
         for prompt_message in llm_result_chunk.delta.message.tool_calls:
             args = {}
-            if prompt_message.function.arguments != '':
+            if prompt_message.function.arguments != "":
                 args = json.loads(prompt_message.function.arguments)

-            tool_calls.append((
-                prompt_message.id,
-                prompt_message.function.name,
-                args,
-            ))
+            tool_calls.append(
+                (
+                    prompt_message.id,
+                    prompt_message.function.name,
+                    args,
+                )
+            )

         return tool_calls

     def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
         """
         Extract blocking tool calls from llm result
@@ -365,18 +356,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         tool_calls = []
         for prompt_message in llm_result.message.tool_calls:
             args = {}
-            if prompt_message.function.arguments != '':
+            if prompt_message.function.arguments != "":
                 args = json.loads(prompt_message.function.arguments)

-            tool_calls.append((
-                prompt_message.id,
-                prompt_message.function.name,
-                args,
-            ))
+            tool_calls.append(
+                (
+                    prompt_message.id,
+                    prompt_message.function.name,
+                    args,
+                )
+            )

         return tool_calls

-    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _init_system_message(
+        self, prompt_template: str, prompt_messages: list[PromptMessage] = None
+    ) -> list[PromptMessage]:
         """
         Initialize system message
         """
@@ -384,13 +379,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             return [
                 SystemPromptMessage(content=prompt_template),
             ]

         if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
             prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

         return prompt_messages

     def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
         Organize user query
         """
@@ -404,7 +399,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         prompt_messages.append(UserPromptMessage(content=query))

         return prompt_messages

     def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
         """
         As for now, gpt supports both fc and vision at the first iteration.
@@ -415,17 +410,21 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         for prompt_message in prompt_messages:
             if isinstance(prompt_message, UserPromptMessage):
                 if isinstance(prompt_message.content, list):
-                    prompt_message.content = '\n'.join([
-                        content.data if content.type == PromptMessageContentType.TEXT else
-                        '[image]' if content.type == PromptMessageContentType.IMAGE else
-                        '[file]'
-                        for content in prompt_message.content
-                    ])
+                    prompt_message.content = "\n".join(
+                        [
+                            content.data
+                            if content.type == PromptMessageContentType.TEXT
+                            else "[image]"
+                            if content.type == PromptMessageContentType.IMAGE
+                            else "[file]"
+                            for content in prompt_message.content
+                        ]
+                    )

         return prompt_messages

     def _organize_prompt_messages(self):
-        prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
+        prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
         self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
         query_prompt_messages = self._organize_user_query(self.query, [])
@@ -433,14 +432,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             model_config=self.model_config,
             prompt_messages=[*query_prompt_messages, *self._current_thoughts],
             history_messages=self.history_prompt_messages,
-            memory=self.memory
+            memory=self.memory,
         ).get_prompt()

-        prompt_messages = [
-            *self.history_prompt_messages,
-            *query_prompt_messages,
-            *self._current_thoughts
-        ]
+        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
         if len(self._current_thoughts) != 0:
             # clear messages after the first iteration
             prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
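The tool-call serialization above prefers ensure_ascii=False so non-ASCII arguments stay readable, keeping a plain json.dumps as the fallback path. The difference, illustrated with a hypothetical tool call:

import json

tool_calls = [("call_1", "search", {"query": "天气"})]
inputs = {name: args for _, name, args in tool_calls}
print(json.dumps(inputs, ensure_ascii=False))  # {"search": {"query": "天气"}}
print(json.dumps(inputs))                      # {"search": {"query": "\u5929\u6c14"}}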

View File

@@ -9,8 +9,9 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk

 class CotAgentOutputParser:
     @classmethod
-    def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \
-            Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
+    def handle_react_stream_output(
+        cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
+    ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
         def parse_action(json_str):
             try:
                 action = json.loads(json_str)
@@ -22,7 +23,7 @@ class CotAgentOutputParser:
                     action = action[0]

                 for key, value in action.items():
-                    if 'input' in key.lower():
+                    if "input" in key.lower():
                         action_input = value
                     else:
                         action_name = value
@@ -33,37 +34,37 @@ class CotAgentOutputParser:
                         action_input=action_input,
                     )
                 else:
-                    return json_str or ''
+                    return json_str or ""
             except:
-                return json_str or ''
+                return json_str or ""

         def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
-            code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
+            code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
             if not code_blocks:
                 return
             for block in code_blocks:
-                json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
+                json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
                 yield parse_action(json_text)

-        code_block_cache = ''
+        code_block_cache = ""
         code_block_delimiter_count = 0
         in_code_block = False
-        json_cache = ''
+        json_cache = ""
         json_quote_count = 0
         in_json = False
         got_json = False

-        action_cache = ''
-        action_str = 'action:'
+        action_cache = ""
+        action_str = "action:"
         action_idx = 0

-        thought_cache = ''
-        thought_str = 'thought:'
+        thought_cache = ""
+        thought_str = "thought:"
         thought_idx = 0

         for response in llm_response:
             if response.delta.usage:
-                usage_dict['usage'] = response.delta.usage
+                usage_dict["usage"] = response.delta.usage
             response = response.delta.message.content
             if not isinstance(response, str):
                 continue
@@ -72,24 +73,24 @@ class CotAgentOutputParser:
             index = 0
             while index < len(response):
                 steps = 1
-                delta = response[index:index+steps]
-                last_character = response[index-1] if index > 0 else ''
+                delta = response[index : index + steps]
+                last_character = response[index - 1] if index > 0 else ""

-                if delta == '`':
+                if delta == "`":
                     code_block_cache += delta
                     code_block_delimiter_count += 1
                 else:
                     if not in_code_block:
                         if code_block_delimiter_count > 0:
                             yield code_block_cache
-                            code_block_cache = ''
+                            code_block_cache = ""
                     else:
                         code_block_cache += delta
                     code_block_delimiter_count = 0

                 if not in_code_block and not in_json:
                     if delta.lower() == action_str[action_idx] and action_idx == 0:
-                        if last_character not in ['\n', ' ', '']:
+                        if last_character not in ["\n", " ", ""]:
                             index += steps
                             yield delta
                             continue
@@ -97,7 +98,7 @@ class CotAgentOutputParser:
                         action_cache += delta
                         action_idx += 1
                         if action_idx == len(action_str):
-                            action_cache = ''
+                            action_cache = ""
                             action_idx = 0
                         index += steps
                         continue
@@ -105,18 +106,18 @@ class CotAgentOutputParser:
                         action_cache += delta
                         action_idx += 1
                         if action_idx == len(action_str):
-                            action_cache = ''
+                            action_cache = ""
                             action_idx = 0
                         index += steps
                         continue
                     else:
                         if action_cache:
                             yield action_cache
-                            action_cache = ''
+                            action_cache = ""
                             action_idx = 0

                     if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
-                        if last_character not in ['\n', ' ', '']:
+                        if last_character not in ["\n", " ", ""]:
                             index += steps
                             yield delta
                             continue
@@ -124,7 +125,7 @@ class CotAgentOutputParser:
                         thought_cache += delta
                         thought_idx += 1
                         if thought_idx == len(thought_str):
-                            thought_cache = ''
+                            thought_cache = ""
                             thought_idx = 0
                         index += steps
                         continue
@@ -132,31 +133,31 @@ class CotAgentOutputParser:
                         thought_cache += delta
                         thought_idx += 1
                         if thought_idx == len(thought_str):
-                            thought_cache = ''
+                            thought_cache = ""
                             thought_idx = 0
                         index += steps
                         continue
                     else:
                         if thought_cache:
                             yield thought_cache
-                            thought_cache = ''
+                            thought_cache = ""
                             thought_idx = 0

                 if code_block_delimiter_count == 3:
                     if in_code_block:
                         yield from extra_json_from_code_block(code_block_cache)
-                        code_block_cache = ''
+                        code_block_cache = ""

                     in_code_block = not in_code_block
                     code_block_delimiter_count = 0

                 if not in_code_block:
                     # handle single json
-                    if delta == '{':
+                    if delta == "{":
                         json_quote_count += 1
                         in_json = True
                         json_cache += delta
-                    elif delta == '}':
+                    elif delta == "}":
                         json_cache += delta
                         if json_quote_count > 0:
                             json_quote_count -= 1
@@ -172,12 +173,12 @@ class CotAgentOutputParser:
                 if got_json:
                     got_json = False
                     yield parse_action(json_cache)
-                    json_cache = ''
+                    json_cache = ""
                     json_quote_count = 0
                     in_json = False

                 if not in_code_block and not in_json:
-                    yield delta.replace('`', '')
+                    yield delta.replace("`", "")

                 index += steps
@@ -186,4 +187,3 @@ class CotAgentOutputParser:

         if json_cache:
             yield parse_action(json_cache)
-
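The code-block branch of the parser reformatted above boils down to: find ```-fenced blocks, strip a leading language tag, and JSON-parse the action. A self-contained sketch of that step (the sample text is invented):

import json
import re

text = 'Thought: use a tool\n```json\n{"action": "search", "action_input": {"q": "dify"}}\n```'
for block in re.findall(r"```(.*?)```", text, re.DOTALL):
    # drop a leading language tag such as "json"
    json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
    action = json.loads(json_text)
    print(action["action"], action["action_input"])  # search {'q': 'dify'}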

View File

@@ -91,14 +91,14 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use
 ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""

 REACT_PROMPT_TEMPLATES = {
-    'english': {
-        'chat': {
-            'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
-            'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES
+    "english": {
+        "chat": {
+            "prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
+            "agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES,
+        },
+        "completion": {
+            "prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
+            "agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES,
         },
-        'completion': {
-            'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
-            'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
-        }
     }
 }
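Downstream code (see AgentConfigManager below) indexes this table by language, then model mode, then part; the quote and trailing-comma changes here are purely cosmetic. A minimal stub showing the access pattern, with the template bodies elided:

REACT_PROMPT_TEMPLATES = {
    "english": {
        "chat": {"prompt": "<chat prompt>", "agent_scratchpad": "<scratchpad>"},
        "completion": {"prompt": "<completion prompt>", "agent_scratchpad": "<scratchpad>"},
    }
}
first_prompt = REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"]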

View File

@@ -26,34 +26,24 @@ class BaseAppConfigManager:
         config_dict = dict(config_dict.items())

         additional_features = AppAdditionalFeatures()
-        additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(
-            config=config_dict
-        )
+        additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)

         additional_features.file_upload = FileUploadConfigManager.convert(
-            config=config_dict,
-            is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
+            config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
         )

-        additional_features.opening_statement, additional_features.suggested_questions = \
-            OpeningStatementConfigManager.convert(
-                config=config_dict
-            )
+        additional_features.opening_statement, additional_features.suggested_questions = (
+            OpeningStatementConfigManager.convert(config=config_dict)
+        )

         additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert(
             config=config_dict
         )

-        additional_features.more_like_this = MoreLikeThisConfigManager.convert(
-            config=config_dict
-        )
+        additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict)

-        additional_features.speech_to_text = SpeechToTextConfigManager.convert(
-            config=config_dict
-        )
+        additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict)

-        additional_features.text_to_speech = TextToSpeechConfigManager.convert(
-            config=config_dict
-        )
+        additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict)

         return additional_features

View File

@@ -7,25 +7,24 @@ from core.moderation.factory import ModerationFactory

 class SensitiveWordAvoidanceConfigManager:
     @classmethod
     def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
-        sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance')
+        sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
         if not sensitive_word_avoidance_dict:
             return None

-        if sensitive_word_avoidance_dict.get('enabled'):
+        if sensitive_word_avoidance_dict.get("enabled"):
             return SensitiveWordAvoidanceEntity(
-                type=sensitive_word_avoidance_dict.get('type'),
-                config=sensitive_word_avoidance_dict.get('config'),
+                type=sensitive_word_avoidance_dict.get("type"),
+                config=sensitive_word_avoidance_dict.get("config"),
             )
         else:
             return None

     @classmethod
-    def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \
-            -> tuple[dict, list[str]]:
+    def validate_and_set_defaults(
+        cls, tenant_id, config: dict, only_structure_validate: bool = False
+    ) -> tuple[dict, list[str]]:
         if not config.get("sensitive_word_avoidance"):
-            config["sensitive_word_avoidance"] = {
-                "enabled": False
-            }
+            config["sensitive_word_avoidance"] = {"enabled": False}

         if not isinstance(config["sensitive_word_avoidance"], dict):
             raise ValueError("sensitive_word_avoidance must be of dict type")
@@ -41,10 +40,6 @@ class SensitiveWordAvoidanceConfigManager:
             typ = config["sensitive_word_avoidance"]["type"]
             sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]

-            ModerationFactory.validate_config(
-                name=typ,
-                tenant_id=tenant_id,
-                config=sensitive_word_avoidance_config
-            )
+            ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)

         return config, ["sensitive_word_avoidance"]

View File

@ -12,67 +12,70 @@ class AgentConfigManager:
         :param config: model config args
         """
-        if 'agent_mode' in config and config['agent_mode'] \
-                and 'enabled' in config['agent_mode']:
-
-            agent_dict = config.get('agent_mode', {})
-            agent_strategy = agent_dict.get('strategy', 'cot')
+        if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]:
+            agent_dict = config.get("agent_mode", {})
+            agent_strategy = agent_dict.get("strategy", "cot")
 
-            if agent_strategy == 'function_call':
+            if agent_strategy == "function_call":
                 strategy = AgentEntity.Strategy.FUNCTION_CALLING
-            elif agent_strategy == 'cot' or agent_strategy == 'react':
+            elif agent_strategy == "cot" or agent_strategy == "react":
                 strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
             else:
                 # old configs, try to detect default strategy
-                if config['model']['provider'] == 'openai':
+                if config["model"]["provider"] == "openai":
                     strategy = AgentEntity.Strategy.FUNCTION_CALLING
                 else:
                     strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
 
             agent_tools = []
-            for tool in agent_dict.get('tools', []):
+            for tool in agent_dict.get("tools", []):
                 keys = tool.keys()
                 if len(keys) >= 4:
                     if "enabled" not in tool or not tool["enabled"]:
                         continue
 
                     agent_tool_properties = {
-                        'provider_type': tool['provider_type'],
-                        'provider_id': tool['provider_id'],
-                        'tool_name': tool['tool_name'],
-                        'tool_parameters': tool.get('tool_parameters', {})
+                        "provider_type": tool["provider_type"],
+                        "provider_id": tool["provider_id"],
+                        "tool_name": tool["tool_name"],
+                        "tool_parameters": tool.get("tool_parameters", {}),
                     }
 
                     agent_tools.append(AgentToolEntity(**agent_tool_properties))
 
-            if 'strategy' in config['agent_mode'] and \
-                    config['agent_mode']['strategy'] not in ['react_router', 'router']:
-                agent_prompt = agent_dict.get('prompt', None) or {}
+            if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [
+                "react_router",
+                "router",
+            ]:
+                agent_prompt = agent_dict.get("prompt", None) or {}
                 # check model mode
-                model_mode = config.get('model', {}).get('mode', 'completion')
-                if model_mode == 'completion':
+                model_mode = config.get("model", {}).get("mode", "completion")
+                if model_mode == "completion":
                     agent_prompt_entity = AgentPromptEntity(
-                        first_prompt=agent_prompt.get('first_prompt',
-                                                      REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
-                        next_iteration=agent_prompt.get('next_iteration',
-                                                        REACT_PROMPT_TEMPLATES['english']['completion'][
-                                                            'agent_scratchpad']),
+                        first_prompt=agent_prompt.get(
+                            "first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"]
+                        ),
+                        next_iteration=agent_prompt.get(
+                            "next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"]
+                        ),
                     )
                 else:
                     agent_prompt_entity = AgentPromptEntity(
-                        first_prompt=agent_prompt.get('first_prompt',
-                                                      REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
-                        next_iteration=agent_prompt.get('next_iteration',
-                                                        REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
+                        first_prompt=agent_prompt.get(
+                            "first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"]
+                        ),
+                        next_iteration=agent_prompt.get(
+                            "next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"]
+                        ),
                     )
 
                 return AgentEntity(
-                    provider=config['model']['provider'],
-                    model=config['model']['name'],
+                    provider=config["model"]["provider"],
+                    model=config["model"]["name"],
                     strategy=strategy,
                     prompt=agent_prompt_entity,
                     tools=agent_tools,
-                    max_iteration=agent_dict.get('max_iteration', 5)
+                    max_iteration=agent_dict.get("max_iteration", 5),
                 )
 
         return None
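
Two layout rules drive the hunk above: a collection written with a trailing comma after its last element stays exploded one element per line (ruff's "magic trailing comma"), while a short literal without one is collapsed onto a single line. A sketch with illustrative values:

    # The trailing comma after "router" pins the exploded layout:
    excluded_strategies = [
        "react_router",
        "router",
    ]

    # Without a trailing comma, a literal that fits is collapsed:
    defaults = {"strategy": "cot", "max_iteration": 5}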

View File

@ -15,39 +15,38 @@ class DatasetConfigManager:
         :param config: model config args
         """
         dataset_ids = []
-        if 'datasets' in config.get('dataset_configs', {}):
-            datasets = config.get('dataset_configs', {}).get('datasets', {
-                'strategy': 'router',
-                'datasets': []
-            })
+        if "datasets" in config.get("dataset_configs", {}):
+            datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})
 
-            for dataset in datasets.get('datasets', []):
+            for dataset in datasets.get("datasets", []):
                 keys = list(dataset.keys())
-                if len(keys) == 0 or keys[0] != 'dataset':
+                if len(keys) == 0 or keys[0] != "dataset":
                     continue
 
-                dataset = dataset['dataset']
+                dataset = dataset["dataset"]
 
-                if 'enabled' not in dataset or not dataset['enabled']:
+                if "enabled" not in dataset or not dataset["enabled"]:
                     continue
 
-                dataset_id = dataset.get('id', None)
+                dataset_id = dataset.get("id", None)
                 if dataset_id:
                     dataset_ids.append(dataset_id)
 
-        if 'agent_mode' in config and config['agent_mode'] \
-                and 'enabled' in config['agent_mode'] \
-                and config['agent_mode']['enabled']:
-
-            agent_dict = config.get('agent_mode', {})
-
-            for tool in agent_dict.get('tools', []):
+        if (
+            "agent_mode" in config
+            and config["agent_mode"]
+            and "enabled" in config["agent_mode"]
+            and config["agent_mode"]["enabled"]
+        ):
+            agent_dict = config.get("agent_mode", {})
+
+            for tool in agent_dict.get("tools", []):
                 keys = tool.keys()
                 if len(keys) == 1:
                     # old standard
                     key = list(tool.keys())[0]
-                    if key != 'dataset':
+                    if key != "dataset":
                         continue
 
                     tool_item = tool[key]
@ -55,30 +54,28 @@ class DatasetConfigManager:
                     if "enabled" not in tool_item or not tool_item["enabled"]:
                         continue
 
-                    dataset_id = tool_item['id']
+                    dataset_id = tool_item["id"]
                     dataset_ids.append(dataset_id)
 
         if len(dataset_ids) == 0:
             return None
 
         # dataset configs
-        if 'dataset_configs' in config and config.get('dataset_configs'):
-            dataset_configs = config.get('dataset_configs')
+        if "dataset_configs" in config and config.get("dataset_configs"):
+            dataset_configs = config.get("dataset_configs")
         else:
-            dataset_configs = {
-                'retrieval_model': 'multiple'
-            }
-        query_variable = config.get('dataset_query_variable')
+            dataset_configs = {"retrieval_model": "multiple"}
+        query_variable = config.get("dataset_query_variable")
 
-        if dataset_configs['retrieval_model'] == 'single':
+        if dataset_configs["retrieval_model"] == "single":
             return DatasetEntity(
                 dataset_ids=dataset_ids,
                 retrieve_config=DatasetRetrieveConfigEntity(
                     query_variable=query_variable,
                     retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
-                        dataset_configs['retrieval_model']
-                    )
-                )
+                        dataset_configs["retrieval_model"]
+                    ),
+                ),
             )
         else:
             return DatasetEntity(
@ -86,15 +83,15 @@ class DatasetConfigManager:
                 retrieve_config=DatasetRetrieveConfigEntity(
                     query_variable=query_variable,
                     retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
-                        dataset_configs['retrieval_model']
+                        dataset_configs["retrieval_model"]
                     ),
-                    top_k=dataset_configs.get('top_k', 4),
-                    score_threshold=dataset_configs.get('score_threshold'),
-                    reranking_model=dataset_configs.get('reranking_model'),
-                    weights=dataset_configs.get('weights'),
-                    reranking_enabled=dataset_configs.get('reranking_enabled', True),
-                    rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'),
-                )
+                    top_k=dataset_configs.get("top_k", 4),
+                    score_threshold=dataset_configs.get("score_threshold"),
+                    reranking_model=dataset_configs.get("reranking_model"),
+                    weights=dataset_configs.get("weights"),
+                    reranking_enabled=dataset_configs.get("reranking_enabled", True),
+                    rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
+                ),
             )
 
     @classmethod
@ -111,13 +108,10 @@ class DatasetConfigManager:
         # dataset_configs
         if not config.get("dataset_configs"):
-            config["dataset_configs"] = {'retrieval_model': 'single'}
+            config["dataset_configs"] = {"retrieval_model": "single"}
 
         if not config["dataset_configs"].get("datasets"):
-            config["dataset_configs"]["datasets"] = {
-                "strategy": "router",
-                "datasets": []
-            }
+            config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
 
         if not isinstance(config["dataset_configs"], dict):
             raise ValueError("dataset_configs must be of object type")
@ -125,8 +119,9 @@ class DatasetConfigManager:
         if not isinstance(config["dataset_configs"], dict):
             raise ValueError("dataset_configs must be of object type")
 
-        need_manual_query_datasets = (config.get("dataset_configs")
-                                      and config["dataset_configs"].get("datasets", {}).get("datasets"))
+        need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
+            "datasets", {}
+        ).get("datasets")
 
         if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
             # Only check when mode is completion
@ -148,10 +143,7 @@ class DatasetConfigManager:
         """
         # Extract dataset config for legacy compatibility
         if not config.get("agent_mode"):
-            config["agent_mode"] = {
-                "enabled": False,
-                "tools": []
-            }
+            config["agent_mode"] = {"enabled": False, "tools": []}
 
         if not isinstance(config["agent_mode"], dict):
             raise ValueError("agent_mode must be of object type")
@ -188,7 +180,7 @@ class DatasetConfigManager:
                 if not isinstance(tool_item["enabled"], bool):
                     raise ValueError("enabled in agent_mode.tools must be of boolean type")
 
-                if 'id' not in tool_item:
+                if "id" not in tool_item:
                     raise ValueError("id is required in dataset")
 
                 try:
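
The agent_mode check above shows how the formatter replaces backslash line continuations with a parenthesized condition, one clause per line. A before/after sketch of the same shape (the config value is illustrative):

    config = {"agent_mode": {"enabled": True}}

    # Before (backslash continuations, as the old code was written):
    # if 'agent_mode' in config and config['agent_mode'] \
    #         and 'enabled' in config['agent_mode'] \
    #         and config['agent_mode']['enabled']:

    # After: parentheses carry the continuation, no backslashes needed.
    # The formatter keeps this split form only because the joined
    # condition would overrun the 120-character limit in the real file.
    if (
        "agent_mode" in config
        and config["agent_mode"]
        and "enabled" in config["agent_mode"]
        and config["agent_mode"]["enabled"]
    ):
        print("agent mode is enabled")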

View File

@ -11,9 +11,7 @@ from core.provider_manager import ProviderManager
 class ModelConfigConverter:
     @classmethod
-    def convert(cls, app_config: EasyUIBasedAppConfig,
-                skip_check: bool = False) \
-            -> ModelConfigWithCredentialsEntity:
+    def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
         """
         Convert app model config dict to entity.
         :param app_config: app config
@ -25,9 +23,7 @@ class ModelConfigConverter:
         provider_manager = ProviderManager()
         provider_model_bundle = provider_manager.get_provider_model_bundle(
-            tenant_id=app_config.tenant_id,
-            provider=model_config.provider,
-            model_type=ModelType.LLM
+            tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
         )
 
         provider_name = provider_model_bundle.configuration.provider.provider
@ -38,8 +34,7 @@ class ModelConfigConverter:
         # check model credentials
         model_credentials = provider_model_bundle.configuration.get_current_credentials(
-            model_type=ModelType.LLM,
-            model=model_config.model
+            model_type=ModelType.LLM, model=model_config.model
         )
 
         if model_credentials is None:
@ -51,8 +46,7 @@ class ModelConfigConverter:
         if not skip_check:
             # check model
             provider_model = provider_model_bundle.configuration.get_provider_model(
-                model=model_config.model,
-                model_type=ModelType.LLM
+                model=model_config.model, model_type=ModelType.LLM
             )
 
             if provider_model is None:
@ -69,24 +63,18 @@ class ModelConfigConverter:
         # model config
         completion_params = model_config.parameters
         stop = []
-        if 'stop' in completion_params:
-            stop = completion_params['stop']
-            del completion_params['stop']
+        if "stop" in completion_params:
+            stop = completion_params["stop"]
+            del completion_params["stop"]
 
         # get model mode
         model_mode = model_config.mode
         if not model_mode:
-            mode_enum = model_type_instance.get_model_mode(
-                model=model_config.model,
-                credentials=model_credentials
-            )
+            mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)
             model_mode = mode_enum.value
 
-        model_schema = model_type_instance.get_model_schema(
-            model_config.model,
-            model_credentials
-        )
+        model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
 
         if not skip_check and not model_schema:
             raise ValueError(f"Model {model_name} not exist.")

View File

@ -13,23 +13,23 @@ class ModelConfigManager:
         :param config: model config args
         """
         # model config
-        model_config = config.get('model')
+        model_config = config.get("model")
 
         if not model_config:
             raise ValueError("model is required")
 
-        completion_params = model_config.get('completion_params')
+        completion_params = model_config.get("completion_params")
         stop = []
-        if 'stop' in completion_params:
-            stop = completion_params['stop']
-            del completion_params['stop']
+        if "stop" in completion_params:
+            stop = completion_params["stop"]
+            del completion_params["stop"]
 
         # get model mode
-        model_mode = model_config.get('mode')
+        model_mode = model_config.get("mode")
 
         return ModelConfigEntity(
-            provider=config['model']['provider'],
-            model=config['model']['name'],
+            provider=config["model"]["provider"],
+            model=config["model"]["name"],
             mode=model_mode,
             parameters=completion_params,
             stop=stop,
@ -43,7 +43,7 @@ class ModelConfigManager:
         :param tenant_id: tenant id
         :param config: app model config args
         """
-        if 'model' not in config:
+        if "model" not in config:
             raise ValueError("model is required")
 
         if not isinstance(config["model"], dict):
@ -52,17 +52,16 @@ class ModelConfigManager:
         # model.provider
         provider_entities = model_provider_factory.get_providers()
         model_provider_names = [provider.provider for provider in provider_entities]
-        if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
+        if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
             raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
 
         # model.name
-        if 'name' not in config["model"]:
+        if "name" not in config["model"]:
             raise ValueError("model.name is required")
 
         provider_manager = ProviderManager()
         models = provider_manager.get_configurations(tenant_id).get_models(
-            provider=config["model"]["provider"],
-            model_type=ModelType.LLM
+            provider=config["model"]["provider"], model_type=ModelType.LLM
         )
 
         if not models:
@ -80,12 +79,12 @@ class ModelConfigManager:
         # model.mode
         if model_mode:
-            config['model']["mode"] = model_mode
+            config["model"]["mode"] = model_mode
         else:
-            config['model']["mode"] = "completion"
+            config["model"]["mode"] = "completion"
 
         # model.completion_params
-        if 'completion_params' not in config["model"]:
+        if "completion_params" not in config["model"]:
             raise ValueError("model.completion_params is required")
 
         config["model"]["completion_params"] = cls.validate_model_completion_params(
@ -101,7 +100,7 @@ class ModelConfigManager:
             raise ValueError("model.completion_params must be of object type")
 
         # stop
-        if 'stop' not in cp:
+        if "stop" not in cp:
             cp["stop"] = []
         elif not isinstance(cp["stop"], list):
             raise ValueError("stop in model.completion_params must be of list type")

View File

@ -14,39 +14,33 @@ class PromptTemplateConfigManager:
         if not config.get("prompt_type"):
             raise ValueError("prompt_type is required")
 
-        prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type'])
+        prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"])
         if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
             simple_prompt_template = config.get("pre_prompt", "")
-            return PromptTemplateEntity(
-                prompt_type=prompt_type,
-                simple_prompt_template=simple_prompt_template
-            )
+            return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template)
         else:
             advanced_chat_prompt_template = None
             chat_prompt_config = config.get("chat_prompt_config", {})
             if chat_prompt_config:
                 chat_prompt_messages = []
                 for message in chat_prompt_config.get("prompt", []):
-                    chat_prompt_messages.append({
-                        "text": message["text"],
-                        "role": PromptMessageRole.value_of(message["role"])
-                    })
+                    chat_prompt_messages.append(
+                        {"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
+                    )
 
-                advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
-                    messages=chat_prompt_messages
-                )
+                advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
 
             advanced_completion_prompt_template = None
             completion_prompt_config = config.get("completion_prompt_config", {})
             if completion_prompt_config:
                 completion_prompt_template_params = {
-                    'prompt': completion_prompt_config['prompt']['text'],
+                    "prompt": completion_prompt_config["prompt"]["text"],
                 }
-                if 'conversation_histories_role' in completion_prompt_config:
-                    completion_prompt_template_params['role_prefix'] = {
-                        'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
-                        'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
+                if "conversation_histories_role" in completion_prompt_config:
+                    completion_prompt_template_params["role_prefix"] = {
+                        "user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
+                        "assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
                     }
 
                 advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
@ -56,7 +50,7 @@ class PromptTemplateConfigManager:
             return PromptTemplateEntity(
                 prompt_type=prompt_type,
                 advanced_chat_prompt_template=advanced_chat_prompt_template,
-                advanced_completion_prompt_template=advanced_completion_prompt_template
+                advanced_completion_prompt_template=advanced_completion_prompt_template,
             )
 
     @classmethod
@ -72,7 +66,7 @@ class PromptTemplateConfigManager:
             config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
 
         prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
-        if config['prompt_type'] not in prompt_type_vals:
+        if config["prompt_type"] not in prompt_type_vals:
             raise ValueError(f"prompt_type must be in {prompt_type_vals}")
 
         # chat_prompt_config
@ -89,27 +83,28 @@ class PromptTemplateConfigManager:
             if not isinstance(config["completion_prompt_config"], dict):
                 raise ValueError("completion_prompt_config must be of object type")
 
-        if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value:
-            if not config['chat_prompt_config'] and not config['completion_prompt_config']:
-                raise ValueError("chat_prompt_config or completion_prompt_config is required "
-                                 "when prompt_type is advanced")
+        if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
+            if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
+                raise ValueError(
+                    "chat_prompt_config or completion_prompt_config is required " "when prompt_type is advanced"
+                )
 
             model_mode_vals = [mode.value for mode in ModelMode]
-            if config['model']["mode"] not in model_mode_vals:
+            if config["model"]["mode"] not in model_mode_vals:
                 raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
 
-            if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value:
-                user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
-                assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
+            if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
+                user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
+                assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]
 
                 if not user_prefix:
-                    config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
+                    config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human"
 
                 if not assistant_prefix:
-                    config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
+                    config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"
 
-            if config['model']["mode"] == ModelMode.CHAT.value:
-                prompt_list = config['chat_prompt_config']['prompt']
+            if config["model"]["mode"] == ModelMode.CHAT.value:
+                prompt_list = config["chat_prompt_config"]["prompt"]
 
                 if len(prompt_list) > 10:
                     raise ValueError("prompt messages must be less than 10")

View File

@ -16,32 +16,30 @@ class BasicVariablesConfigManager:
         variable_entities = []
 
         # old external_data_tools
-        external_data_tools = config.get('external_data_tools', [])
+        external_data_tools = config.get("external_data_tools", [])
         for external_data_tool in external_data_tools:
-            if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
+            if "enabled" not in external_data_tool or not external_data_tool["enabled"]:
                 continue
 
             external_data_variables.append(
                 ExternalDataVariableEntity(
-                    variable=external_data_tool['variable'],
-                    type=external_data_tool['type'],
-                    config=external_data_tool['config']
+                    variable=external_data_tool["variable"],
+                    type=external_data_tool["type"],
+                    config=external_data_tool["config"],
                 )
             )
 
         # variables and external_data_tools
-        for variables in config.get('user_input_form', []):
+        for variables in config.get("user_input_form", []):
             variable_type = list(variables.keys())[0]
             if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
                 variable = variables[variable_type]
-                if 'config' not in variable:
+                if "config" not in variable:
                     continue
 
                 external_data_variables.append(
                     ExternalDataVariableEntity(
-                        variable=variable['variable'],
-                        type=variable['type'],
-                        config=variable['config']
+                        variable=variable["variable"], type=variable["type"], config=variable["config"]
                     )
                 )
             elif variable_type in [
@ -54,13 +52,13 @@ class BasicVariablesConfigManager:
                 variable_entities.append(
                     VariableEntity(
                         type=variable_type,
-                        variable=variable.get('variable'),
-                        description=variable.get('description'),
-                        label=variable.get('label'),
-                        required=variable.get('required', False),
-                        max_length=variable.get('max_length'),
-                        options=variable.get('options'),
-                        default=variable.get('default'),
+                        variable=variable.get("variable"),
+                        description=variable.get("description"),
+                        label=variable.get("label"),
+                        required=variable.get("required", False),
+                        max_length=variable.get("max_length"),
+                        options=variable.get("options"),
+                        default=variable.get("default"),
                     )
                 )
@ -103,13 +101,13 @@ class BasicVariablesConfigManager:
                 raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
 
             form_item = item[key]
-            if 'label' not in form_item:
+            if "label" not in form_item:
                 raise ValueError("label is required in user_input_form")
 
             if not isinstance(form_item["label"], str):
                 raise ValueError("label in user_input_form must be of string type")
 
-            if 'variable' not in form_item:
+            if "variable" not in form_item:
                 raise ValueError("variable is required in user_input_form")
 
             if not isinstance(form_item["variable"], str):
@ -117,26 +115,24 @@ class BasicVariablesConfigManager:
             pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
             if pattern.match(form_item["variable"]) is None:
-                raise ValueError("variable in user_input_form must be a string, "
-                                 "and cannot start with a number")
+                raise ValueError("variable in user_input_form must be a string, " "and cannot start with a number")
 
             variables.append(form_item["variable"])
 
-            if 'required' not in form_item or not form_item["required"]:
+            if "required" not in form_item or not form_item["required"]:
                 form_item["required"] = False
 
             if not isinstance(form_item["required"], bool):
                 raise ValueError("required in user_input_form must be of boolean type")
 
             if key == "select":
-                if 'options' not in form_item or not form_item["options"]:
+                if "options" not in form_item or not form_item["options"]:
                     form_item["options"] = []
 
                 if not isinstance(form_item["options"], list):
                     raise ValueError("options in user_input_form must be a list of strings")
 
-                if "default" in form_item and form_item['default'] \
-                        and form_item["default"] not in form_item["options"]:
+                if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]:
                     raise ValueError("default value in user_input_form must be in the options list")
 
         return config, ["user_input_form"]
@ -168,10 +164,6 @@ class BasicVariablesConfigManager:
             typ = tool["type"]
             config = tool["config"]
 
-            ExternalDataToolFactory.validate_config(
-                name=typ,
-                tenant_id=tenant_id,
-                config=config
-            )
+            ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=config)
 
         return config, ["external_data_tools"]
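
Since every hunk in this commit is formatting-only, each file must parse to the same abstract syntax tree before and after. A small sketch of how that invariant can be checked; the one-line snippets are illustrative:

    import ast

    before = "x = {'enabled': False}"
    after = 'x = {"enabled": False}'

    # ast.dump normalizes quoting and layout away, so a pure reformat
    # leaves the dump identical.
    assert ast.dump(ast.parse(before)) == ast.dump(ast.parse(after))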

View File

@ -12,6 +12,7 @@ class ModelConfigEntity(BaseModel):
     """
     Model Config Entity.
     """
+
     provider: str
     model: str
     mode: Optional[str] = None
@ -23,6 +24,7 @@ class AdvancedChatMessageEntity(BaseModel):
     """
     Advanced Chat Message Entity.
     """
+
     text: str
     role: PromptMessageRole
@ -31,6 +33,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel):
     """
     Advanced Chat Prompt Template Entity.
     """
+
     messages: list[AdvancedChatMessageEntity]
@ -43,6 +46,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
         """
         Role Prefix Entity.
         """
+
         user: str
         assistant: str
@ -60,11 +64,12 @@ class PromptTemplateEntity(BaseModel):
         Prompt Type.
         'simple', 'advanced'
         """
-        SIMPLE = 'simple'
-        ADVANCED = 'advanced'
+
+        SIMPLE = "simple"
+        ADVANCED = "advanced"
 
         @classmethod
-        def value_of(cls, value: str) -> 'PromptType':
+        def value_of(cls, value: str) -> "PromptType":
             """
             Get value of given mode.
@ -74,7 +79,7 @@ class PromptTemplateEntity(BaseModel):
             for mode in cls:
                 if mode.value == value:
                     return mode
-            raise ValueError(f'invalid prompt type value {value}')
+            raise ValueError(f"invalid prompt type value {value}")
 
     prompt_type: PromptType
     simple_prompt_template: Optional[str] = None
@ -110,6 +115,7 @@ class ExternalDataVariableEntity(BaseModel):
     """
    External Data Variable Entity.
     """
+
     variable: str
     type: str
     config: dict[str, Any] = {}
@ -125,11 +131,12 @@ class DatasetRetrieveConfigEntity(BaseModel):
         Dataset Retrieve Strategy.
         'single' or 'multiple'
         """
-        SINGLE = 'single'
-        MULTIPLE = 'multiple'
+
+        SINGLE = "single"
+        MULTIPLE = "multiple"
 
         @classmethod
-        def value_of(cls, value: str) -> 'RetrieveStrategy':
+        def value_of(cls, value: str) -> "RetrieveStrategy":
             """
             Get value of given mode.
@ -139,25 +146,24 @@ class DatasetRetrieveConfigEntity(BaseModel):
             for mode in cls:
                 if mode.value == value:
                     return mode
-            raise ValueError(f'invalid retrieve strategy value {value}')
+            raise ValueError(f"invalid retrieve strategy value {value}")
 
     query_variable: Optional[str] = None  # Only when app mode is completion
     retrieve_strategy: RetrieveStrategy
     top_k: Optional[int] = None
-    score_threshold: Optional[float] = .0
-    rerank_mode: Optional[str] = 'reranking_model'
+    score_threshold: Optional[float] = 0.0
+    rerank_mode: Optional[str] = "reranking_model"
     reranking_model: Optional[dict] = None
     weights: Optional[dict] = None
     reranking_enabled: Optional[bool] = True
 
 
 class DatasetEntity(BaseModel):
     """
     Dataset Config Entity.
     """
+
     dataset_ids: list[str]
     retrieve_config: DatasetRetrieveConfigEntity
@ -166,6 +172,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
     """
     Sensitive Word Avoidance Entity.
     """
+
     type: str
     config: dict[str, Any] = {}
@ -174,6 +181,7 @@ class TextToSpeechEntity(BaseModel):
     """
     Sensitive Word Avoidance Entity.
     """
+
     enabled: bool
     voice: Optional[str] = None
     language: Optional[str] = None
@ -183,12 +191,11 @@ class TracingConfigEntity(BaseModel):
     """
     Tracing Config Entity.
     """
+
     enabled: bool
     tracing_provider: str
 
 
 class AppAdditionalFeatures(BaseModel):
     file_upload: Optional[FileExtraConfig] = None
     opening_statement: Optional[str] = None
@ -200,10 +207,12 @@ class AppAdditionalFeatures(BaseModel):
     text_to_speech: Optional[TextToSpeechEntity] = None
     trace_config: Optional[TracingConfigEntity] = None
 
 
 class AppConfig(BaseModel):
     """
     Application Config Entity.
     """
+
     tenant_id: str
     app_id: str
     app_mode: AppMode
@ -216,15 +225,17 @@ class EasyUIBasedAppModelConfigFrom(Enum):
     """
     App Model Config From.
     """
-    ARGS = 'args'
-    APP_LATEST_CONFIG = 'app-latest-config'
-    CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config'
+
+    ARGS = "args"
+    APP_LATEST_CONFIG = "app-latest-config"
+    CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"
 
 
 class EasyUIBasedAppConfig(AppConfig):
     """
     Easy UI Based App Config Entity.
     """
+
     app_model_config_from: EasyUIBasedAppModelConfigFrom
     app_model_config_id: str
     app_model_config_dict: dict
@ -238,4 +249,5 @@ class WorkflowUIBasedAppConfig(AppConfig):
     """
     Workflow UI Based App Config Entity.
     """
+
     workflow_id: str
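
Most of the churn in this entities file is one rule applied many times: the formatter inserts a blank line between a class docstring and the first field, and normalizes literals (.0 becomes 0.0, single quotes become double). A sketch with an illustrative model:

    from pydantic import BaseModel

    class ExampleEntity(BaseModel):
        """
        Example Entity.
        """

        # the blank line above is the one the formatter adds after a docstring
        enabled: bool = False
        score_threshold: float = 0.0  # written 0.0, never .0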

View File

@ -13,21 +13,19 @@ class FileUploadConfigManager:
         :param config: model config args
         :param is_vision: if True, the feature is vision feature
         """
-        file_upload_dict = config.get('file_upload')
+        file_upload_dict = config.get("file_upload")
         if file_upload_dict:
-            if file_upload_dict.get('image'):
-                if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
+            if file_upload_dict.get("image"):
+                if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
                     image_config = {
-                        'number_limits': file_upload_dict['image']['number_limits'],
-                        'transfer_methods': file_upload_dict['image']['transfer_methods']
+                        "number_limits": file_upload_dict["image"]["number_limits"],
+                        "transfer_methods": file_upload_dict["image"]["transfer_methods"],
                     }
 
                     if is_vision:
-                        image_config['detail'] = file_upload_dict['image']['detail']
+                        image_config["detail"] = file_upload_dict["image"]["detail"]
 
-                    return FileExtraConfig(
-                        image_config=image_config
-                    )
+                    return FileExtraConfig(image_config=image_config)
 
         return None
 
@ -49,21 +47,21 @@ class FileUploadConfigManager:
         if not config["file_upload"].get("image"):
             config["file_upload"]["image"] = {"enabled": False}
 
-        if config['file_upload']['image']['enabled']:
-            number_limits = config['file_upload']['image']['number_limits']
+        if config["file_upload"]["image"]["enabled"]:
+            number_limits = config["file_upload"]["image"]["number_limits"]
             if number_limits < 1 or number_limits > 6:
                 raise ValueError("number_limits must be in [1, 6]")
 
             if is_vision:
-                detail = config['file_upload']['image']['detail']
-                if detail not in ['high', 'low']:
+                detail = config["file_upload"]["image"]["detail"]
+                if detail not in ["high", "low"]:
                     raise ValueError("detail must be in ['high', 'low']")
 
-            transfer_methods = config['file_upload']['image']['transfer_methods']
+            transfer_methods = config["file_upload"]["image"]["transfer_methods"]
             if not isinstance(transfer_methods, list):
                 raise ValueError("transfer_methods must be of list type")
             for method in transfer_methods:
-                if method not in ['remote_url', 'local_file']:
+                if method not in ["remote_url", "local_file"]:
                     raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
 
         return config, ["file_upload"]

View File

@ -7,9 +7,9 @@ class MoreLikeThisConfigManager:
         :param config: model config args
         """
         more_like_this = False
-        more_like_this_dict = config.get('more_like_this')
+        more_like_this_dict = config.get("more_like_this")
         if more_like_this_dict:
-            if more_like_this_dict.get('enabled'):
+            if more_like_this_dict.get("enabled"):
                 more_like_this = True
 
         return more_like_this
 
@ -22,9 +22,7 @@ class MoreLikeThisConfigManager:
         :param config: app model config args
         """
         if not config.get("more_like_this"):
-            config["more_like_this"] = {
-                "enabled": False
-            }
+            config["more_like_this"] = {"enabled": False}
 
         if not isinstance(config["more_like_this"], dict):
             raise ValueError("more_like_this must be of dict type")

View File

@ -1,5 +1,3 @@
 class OpeningStatementConfigManager:
     @classmethod
     def convert(cls, config: dict) -> tuple[str, list]:
@ -9,10 +7,10 @@ class OpeningStatementConfigManager:
         :param config: model config args
         """
         # opening statement
-        opening_statement = config.get('opening_statement')
+        opening_statement = config.get("opening_statement")
 
         # suggested questions
-        suggested_questions_list = config.get('suggested_questions')
+        suggested_questions_list = config.get("suggested_questions")
 
         return opening_statement, suggested_questions_list

View File

@ -2,9 +2,9 @@ class RetrievalResourceConfigManager:
     @classmethod
     def convert(cls, config: dict) -> bool:
         show_retrieve_source = False
-        retriever_resource_dict = config.get('retriever_resource')
+        retriever_resource_dict = config.get("retriever_resource")
         if retriever_resource_dict:
-            if retriever_resource_dict.get('enabled'):
+            if retriever_resource_dict.get("enabled"):
                 show_retrieve_source = True
 
         return show_retrieve_source
 
@ -17,9 +17,7 @@ class RetrievalResourceConfigManager:
         :param config: app model config args
         """
         if not config.get("retriever_resource"):
-            config["retriever_resource"] = {
-                "enabled": False
-            }
+            config["retriever_resource"] = {"enabled": False}
 
         if not isinstance(config["retriever_resource"], dict):
             raise ValueError("retriever_resource must be of dict type")

View File

@ -7,9 +7,9 @@ class SpeechToTextConfigManager:
         :param config: model config args
         """
         speech_to_text = False
-        speech_to_text_dict = config.get('speech_to_text')
+        speech_to_text_dict = config.get("speech_to_text")
         if speech_to_text_dict:
-            if speech_to_text_dict.get('enabled'):
+            if speech_to_text_dict.get("enabled"):
                 speech_to_text = True
 
         return speech_to_text
 
@ -22,9 +22,7 @@ class SpeechToTextConfigManager:
         :param config: app model config args
         """
         if not config.get("speech_to_text"):
-            config["speech_to_text"] = {
-                "enabled": False
-            }
+            config["speech_to_text"] = {"enabled": False}
 
         if not isinstance(config["speech_to_text"], dict):
             raise ValueError("speech_to_text must be of dict type")

View File

@ -7,9 +7,9 @@ class SuggestedQuestionsAfterAnswerConfigManager:
         :param config: model config args
         """
         suggested_questions_after_answer = False
-        suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer')
+        suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer")
         if suggested_questions_after_answer_dict:
-            if suggested_questions_after_answer_dict.get('enabled'):
+            if suggested_questions_after_answer_dict.get("enabled"):
                 suggested_questions_after_answer = True
 
         return suggested_questions_after_answer
 
@ -22,15 +22,15 @@ class SuggestedQuestionsAfterAnswerConfigManager:
         :param config: app model config args
         """
         if not config.get("suggested_questions_after_answer"):
-            config["suggested_questions_after_answer"] = {
-                "enabled": False
-            }
+            config["suggested_questions_after_answer"] = {"enabled": False}
 
         if not isinstance(config["suggested_questions_after_answer"], dict):
             raise ValueError("suggested_questions_after_answer must be of dict type")
 
-        if "enabled" not in config["suggested_questions_after_answer"] or not \
-                config["suggested_questions_after_answer"]["enabled"]:
+        if (
+            "enabled" not in config["suggested_questions_after_answer"]
+            or not config["suggested_questions_after_answer"]["enabled"]
+        ):
             config["suggested_questions_after_answer"]["enabled"] = False
 
         if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):

View File

@ -10,13 +10,13 @@ class TextToSpeechConfigManager:
         :param config: model config args
         """
         text_to_speech = None
-        text_to_speech_dict = config.get('text_to_speech')
+        text_to_speech_dict = config.get("text_to_speech")
         if text_to_speech_dict:
-            if text_to_speech_dict.get('enabled'):
+            if text_to_speech_dict.get("enabled"):
                 text_to_speech = TextToSpeechEntity(
-                    enabled=text_to_speech_dict.get('enabled'),
-                    voice=text_to_speech_dict.get('voice'),
-                    language=text_to_speech_dict.get('language'),
+                    enabled=text_to_speech_dict.get("enabled"),
+                    voice=text_to_speech_dict.get("voice"),
+                    language=text_to_speech_dict.get("language"),
                 )
 
         return text_to_speech
 
@ -29,11 +29,7 @@ class TextToSpeechConfigManager:
         :param config: app model config args
         """
         if not config.get("text_to_speech"):
-            config["text_to_speech"] = {
-                "enabled": False,
-                "voice": "",
-                "language": ""
-            }
+            config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""}
 
         if not isinstance(config["text_to_speech"], dict):
             raise ValueError("text_to_speech must be of dict type")

View File

@ -1,4 +1,3 @@
 from core.app.app_config.base_app_config_manager import BaseAppConfigManager
 from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
 from core.app.app_config.entities import WorkflowUIBasedAppConfig
@ -19,13 +18,13 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
     """
     Advanced Chatbot App Config Entity.
     """
+
     pass
 
 
 class AdvancedChatAppConfigManager(BaseAppConfigManager):
     @classmethod
-    def get_app_config(cls, app_model: App,
-                       workflow: Workflow) -> AdvancedChatAppConfig:
+    def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
         features_dict = workflow.features_dict
 
         app_mode = AppMode.value_of(app_model.mode)
@ -34,13 +33,9 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
             app_id=app_model.id,
             app_mode=app_mode,
             workflow_id=workflow.id,
-            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
-                config=features_dict
-            ),
-            variables=WorkflowVariablesConfigManager.convert(
-                workflow=workflow
-            ),
-            additional_features=cls.convert_features(features_dict, app_mode)
+            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
+            variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
+            additional_features=cls.convert_features(features_dict, app_mode),
         )
 
         return app_config
@ -58,8 +53,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
         # file upload validation
         config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
-            config=config,
-            is_vision=False
+            config=config, is_vision=False
         )
         related_config_keys.extend(current_related_config_keys)
@ -69,7 +63,8 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
         # suggested_questions_after_answer
         config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
-            config)
+            config
+        )
         related_config_keys.extend(current_related_config_keys)
 
         # speech_to_text
@ -86,9 +81,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
         # moderation validation
         config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
-            tenant_id=tenant_id,
-            config=config,
-            only_structure_validate=only_structure_validate
+            tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
         )
         related_config_keys.extend(current_related_config_keys)
@ -98,4 +91,3 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
         filtered_config = {key: config.get(key) for key in related_config_keys}
 
         return filtered_config
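
The SuggestedQuestionsAfterAnswerConfigManager call above also shows how an overflowing call with a single argument is wrapped: the argument moves to its own line and the closing parenthesis returns to the call's indent, instead of the old hanging `config)`. A sketch with an illustrative name, not the real manager:

    def validate_and_set_defaults_for_suggested_questions_after_answer(config: dict) -> tuple[dict, list[str]]:
        # stand-in validator
        return config, ["suggested_questions_after_answer"]

    config, keys = validate_and_set_defaults_for_suggested_questions_after_answer(
        {"suggested_questions_after_answer": {"enabled": True}}
    )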

View File

@ -34,7 +34,8 @@ logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator): class AdvancedChatAppGenerator(MessageBasedAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: dict,
@ -44,7 +45,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: dict,
@ -53,14 +55,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
) -> dict: ... ) -> dict: ...
def generate( def generate(
self, self,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: dict,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: bool = True, stream: bool = True,
) -> dict[str, Any] | Generator[str, Any, None]: ) -> dict[str, Any] | Generator[str, Any, None]:
""" """
Generate App response. Generate App response.
@ -71,44 +73,37 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param stream: is stream :param stream: is stream
""" """
if not args.get('query'): if not args.get("query"):
raise ValueError('query is required') raise ValueError("query is required")
query = args['query'] query = args["query"]
if not isinstance(query, str): if not isinstance(query, str):
raise ValueError('query must be a string') raise ValueError("query must be a string")
query = query.replace('\x00', '') query = query.replace("\x00", "")
inputs = args['inputs'] inputs = args["inputs"]
extras = { extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)}
"auto_generate_conversation_name": args.get('auto_generate_name', False)
}
# get conversation # get conversation
conversation = None conversation = None
conversation_id = args.get('conversation_id') conversation_id = args.get("conversation_id")
if conversation_id: if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user) conversation = self._get_conversation_by_user(
app_model=app_model, conversation_id=conversation_id, user=user
)
# parse files # parse files
files = args['files'] if args.get('files') else [] files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config: if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg( file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
files,
file_extra_config,
user
)
else: else:
file_objs = [] file_objs = []
# convert to app config # convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config( app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
app_model=app_model,
workflow=workflow
)
# get tracing instance # get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id user_id = user.id if isinstance(user, Account) else user.session_id
@ -130,7 +125,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
stream=stream, stream=stream,
invoke_from=invoke_from, invoke_from=invoke_from,
extras=extras, extras=extras,
trace_manager=trace_manager trace_manager=trace_manager,
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@ -140,16 +135,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=invoke_from, invoke_from=invoke_from,
            application_generate_entity=application_generate_entity,
            conversation=conversation,
            stream=stream,
        )

    def single_iteration_generate(
        self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
    ) -> dict[str, Any] | Generator[str, Any, None]:
        """
        Generate App response.

@ -161,16 +152,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        :param stream: is stream
        """
        if not node_id:
            raise ValueError("node_id is required")

        if args.get("inputs") is None:
            raise ValueError("inputs is required")

        # convert to app config
        app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

        # init application generate entity
        application_generate_entity = AdvancedChatAppGenerateEntity(

@ -178,18 +166,15 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            app_config=app_config,
            conversation_id=None,
            inputs={},
            query="",
            files=[],
            user_id=user.id,
            stream=stream,
            invoke_from=InvokeFrom.DEBUGGER,
            extras={"auto_generate_conversation_name": False},
            single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
                node_id=node_id, inputs=args["inputs"]
            ),
        )
        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)

@ -199,17 +184,19 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            invoke_from=InvokeFrom.DEBUGGER,
            application_generate_entity=application_generate_entity,
            conversation=None,
            stream=stream,
        )

    def _generate(
        self,
        *,
        workflow: Workflow,
        user: Union[Account, EndUser],
        invoke_from: InvokeFrom,
        application_generate_entity: AdvancedChatAppGenerateEntity,
        conversation: Optional[Conversation] = None,
        stream: bool = True,
    ) -> dict[str, Any] | Generator[str, Any, None]:
        """
        Generate App response.

@ -225,10 +212,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        is_first_conversation = True

        # init generate records
        (conversation, message) = self._init_generate_records(application_generate_entity, conversation)

        if is_first_conversation:
            # update conversation features

@ -243,18 +227,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            invoke_from=application_generate_entity.invoke_from,
            conversation_id=conversation.id,
            app_mode=conversation.mode,
            message_id=message.id,
        )

        # new thread
        worker_thread = threading.Thread(
            target=self._generate_worker,
            kwargs={
                "flask_app": current_app._get_current_object(),  # type: ignore
                "application_generate_entity": application_generate_entity,
                "queue_manager": queue_manager,
                "conversation_id": conversation.id,
                "message_id": message.id,
                "context": contextvars.copy_context(),
            },
        )

        worker_thread.start()

@ -269,17 +256,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            stream=stream,
        )

        return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

    def _generate_worker(
        self,
        flask_app: Flask,
        application_generate_entity: AdvancedChatAppGenerateEntity,
        queue_manager: AppQueueManager,
        conversation_id: str,
        message_id: str,
        context: contextvars.Context,
    ) -> None:
        """
        Generate worker in a new thread.
        :param flask_app: Flask app

@ -302,7 +289,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
                    conversation=conversation,
                    message=message,
                )

                runner.run()

@ -310,14 +297,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                pass
            except InvokeAuthorizationError:
                queue_manager.publish_error(
                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                )
            except ValidationError as e:
                logger.exception("Validation Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except (ValueError, InvokeError) as e:
                if os.environ.get("DEBUG", "false").lower() == "true":
                    logger.exception("Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except Exception as e:
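As an aside, the `contextvars.copy_context()` handoff above is easy to get wrong. A minimal, self-contained sketch (hypothetical names; not part of this commit) of why the copied context is handed to the worker thread: values such as the tenant id set in the request thread would otherwise be invisible inside the new thread.

import contextvars
import threading

from flask import Flask

app = Flask(__name__)
tenant_id: contextvars.ContextVar[str] = contextvars.ContextVar("tenant_id")


def _worker(flask_app: Flask, context: contextvars.Context) -> None:
    # Re-apply each (ContextVar, value) pair captured in the request thread,
    # mirroring how the generator passes contextvars.copy_context() above.
    for var, value in context.items():
        var.set(value)
    with flask_app.app_context():
        # Without the copy, tenant_id.get() would raise LookupError here:
        # a freshly started thread begins with an empty context.
        print("worker sees tenant:", tenant_id.get())


tenant_id.set("tenant-123")
threading.Thread(
    target=_worker,
    kwargs={"flask_app": app, "context": contextvars.copy_context()},
).start()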
@ -25,10 +25,7 @@ def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
    if not text_content or text_content.isspace():
        return
    return model_instance.invoke_tts(
        content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
    )

@ -44,28 +41,26 @@ def _process_future(future_queue, audio_queue):
        except Exception as e:
            logging.getLogger(__name__).warning(e)
            break
    audio_queue.put(AudioTrunk("finish", b""))


class AppGeneratorTTSPublisher:
    def __init__(self, tenant_id: str, voice: str):
        self.logger = logging.getLogger(__name__)
        self.tenant_id = tenant_id
        self.msg_text = ""
        self._audio_queue = queue.Queue()
        self._msg_queue = queue.Queue()
        self.match = re.compile(r"[。.!?]")
        self.model_manager = ModelManager()
        self.model_instance = self.model_manager.get_default_model_instance(
            tenant_id=self.tenant_id, model_type=ModelType.TTS
        )
        self.voices = self.model_instance.get_tts_voices()
        values = [voice.get("value") for voice in self.voices]
        self.voice = voice
        if not voice or voice not in values:
            self.voice = self.voices[0].get("value")
        self.MAX_SENTENCE = 2
        self._last_audio_event = None
        self._runtime_thread = threading.Thread(target=self._runtime).start()

@ -85,8 +80,9 @@ class AppGeneratorTTSPublisher:
                message = self._msg_queue.get()
                if message is None:
                    if self.msg_text and len(self.msg_text.strip()) > 0:
                        futures_result = self.executor.submit(
                            _invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice
                        )
                        future_queue.put(futures_result)
                    break
                elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):

@ -94,21 +90,20 @@ class AppGeneratorTTSPublisher:
                elif isinstance(message.event, QueueTextChunkEvent):
                    self.msg_text += message.event.text
                elif isinstance(message.event, QueueNodeSucceededEvent):
                    self.msg_text += message.event.outputs.get("output", "")
                self.last_message = message
                sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
                if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
                    self.MAX_SENTENCE += 1
                    text_content = "".join(sentence_arr)
                    futures_result = self.executor.submit(
                        _invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice
                    )
                    future_queue.put(futures_result)
                    if text_tmp:
                        self.msg_text = text_tmp
                    else:
                        self.msg_text = ""
            except Exception as e:
                self.logger.warning(e)
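The batching above relies on `_extract_sentence`, which this diff does not show. A rough sketch of the assumed behavior (my reconstruction, not the committed code): split the accumulated text on the sentence-ending punctuation in `self.match`, and carry the unterminated tail over to the next chunk.

import re

SENTENCE_END = re.compile(r"[。.!?]")  # same pattern as self.match above


def extract_sentence(text: str) -> tuple[list[str], str]:
    # Return (complete sentences, trailing fragment without a terminator).
    sentences: list[str] = []
    start = 0
    for match in SENTENCE_END.finditer(text):
        sentences.append(text[start : match.end()])
        start = match.end()
    return sentences, text[start:]


print(extract_sentence("Hello. How are"))  # (['Hello.'], ' How are')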
@ -38,11 +38,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
    """

    def __init__(
        self,
        application_generate_entity: AdvancedChatAppGenerateEntity,
        queue_manager: AppQueueManager,
        conversation: Conversation,
        message: Message,
    ) -> None:
        """
        :param application_generate_entity: application generate entity

@ -66,11 +66,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
        app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
        if not app_record:
            raise ValueError("App not found")

        workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
        if not workflow:
            raise ValueError("Workflow not initialized")

        user_id = None
        if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:

@ -81,7 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
            user_id = self.application_generate_entity.user_id

        workflow_callbacks: list[WorkflowCallback] = []
        if bool(os.environ.get("DEBUG", "False").lower() == "true"):
            workflow_callbacks.append(WorkflowLoggingCallback())

        if self.application_generate_entity.single_iteration_run:

@ -89,7 +89,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                workflow=workflow,
                node_id=self.application_generate_entity.single_iteration_run.node_id,
                user_inputs=self.application_generate_entity.single_iteration_run.inputs,
            )
        else:
            inputs = self.application_generate_entity.inputs

@ -98,26 +98,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
            # moderation
            if self.handle_input_moderation(
                app_record=app_record,
                app_generate_entity=self.application_generate_entity,
                inputs=inputs,
                query=query,
                message_id=self.message.id,
            ):
                return

            # annotation reply
            if self.handle_annotation_reply(
                app_record=app_record,
                message=self.message,
                query=query,
                app_generate_entity=self.application_generate_entity,
            ):
                return

            # Init conversation variables
            stmt = select(ConversationVariable).where(
                ConversationVariable.app_id == self.conversation.app_id,
                ConversationVariable.conversation_id == self.conversation.id,
            )
            with Session(db.engine) as session:
                conversation_variables = session.scalars(stmt).all()

@ -190,12 +191,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
            self._handle_event(workflow_entry, event)

    def handle_input_moderation(
        self,
        app_record: App,
        app_generate_entity: AdvancedChatAppGenerateEntity,
        inputs: Mapping[str, Any],
        query: str,
        message_id: str,
    ) -> bool:
        """
        Handle input moderation

@ -217,18 +218,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
                message_id=message_id,
            )
        except ModerationException as e:
            self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
            return True

        return False

    def handle_annotation_reply(
        self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
    ) -> bool:
        """
        Handle annotation reply
        :param app_record: app record

@ -246,32 +243,21 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
        )

        if annotation_reply:
            self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id))

            self._complete_with_stream_output(
                text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
            )
            return True

        return False

    def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
        """
        Direct output
        :param text: text
        :return:
        """
        self._publish_event(QueueTextChunkEvent(text=text))

        self._publish_event(QueueStopEvent(stopped_by=stopped_by))
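The conversation-variable block above uses the SQLAlchemy 2.0 style: build a `select()` statement, then run it in a short-lived `Session`. A self-contained sketch of the same pattern (stand-in model and in-memory engine, not the project's real schema):

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class ConversationVariable(Base):
    __tablename__ = "conversation_variables"
    id: Mapped[int] = mapped_column(primary_key=True)
    app_id: Mapped[str]
    conversation_id: Mapped[str]


engine = create_engine("sqlite://")  # in-memory stand-in for db.engine
Base.metadata.create_all(engine)

stmt = select(ConversationVariable).where(
    ConversationVariable.app_id == "app-1",
    ConversationVariable.conversation_id == "conv-1",
)
with Session(engine) as session:
    # scalars() unwraps single-entity rows; all() materializes the result.
    conversation_variables = session.scalars(stmt).all()
    print(conversation_variables)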
@ -28,15 +28,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
        """
        blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
        response = {
            "event": "message",
            "task_id": blocking_response.task_id,
            "id": blocking_response.data.id,
            "message_id": blocking_response.data.message_id,
            "conversation_id": blocking_response.data.conversation_id,
            "mode": blocking_response.data.mode,
            "answer": blocking_response.data.answer,
            "metadata": blocking_response.data.metadata,
            "created_at": blocking_response.data.created_at,
        }

        return response

@ -50,13 +50,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
        """
        response = cls.convert_blocking_full_response(blocking_response)

        metadata = response.get("metadata", {})
        response["metadata"] = cls._get_simple_metadata(metadata)

        return response

    @classmethod
    def convert_stream_full_response(
        cls, stream_response: Generator[AppStreamResponse, None, None]
    ) -> Generator[str, Any, None]:
        """
        Convert stream full response.
        :param stream_response: stream response

@ -67,14 +69,14 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
            sub_stream_response = chunk.stream_response

            if isinstance(sub_stream_response, PingStreamResponse):
                yield "ping"
                continue

            response_chunk = {
                "event": sub_stream_response.event.value,
                "conversation_id": chunk.conversation_id,
                "message_id": chunk.message_id,
                "created_at": chunk.created_at,
            }

            if isinstance(sub_stream_response, ErrorStreamResponse):

@ -85,7 +87,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
            yield json.dumps(response_chunk)

    @classmethod
    def convert_stream_simple_response(
        cls, stream_response: Generator[AppStreamResponse, None, None]
    ) -> Generator[str, Any, None]:
        """
        Convert stream simple response.
        :param stream_response: stream response

@ -96,20 +100,20 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
            sub_stream_response = chunk.stream_response

            if isinstance(sub_stream_response, PingStreamResponse):
                yield "ping"
                continue

            response_chunk = {
                "event": sub_stream_response.event.value,
                "conversation_id": chunk.conversation_id,
                "message_id": chunk.message_id,
                "created_at": chunk.created_at,
            }

            if isinstance(sub_stream_response, MessageEndStreamResponse):
                sub_stream_response_dict = sub_stream_response.to_dict()
                metadata = sub_stream_response_dict.get("metadata", {})
                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                response_chunk.update(sub_stream_response_dict)
            if isinstance(sub_stream_response, ErrorStreamResponse):
                data = cls._error_to_stream_response(sub_stream_response.err)
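Both converters above yield either the literal string "ping" or a JSON-encoded chunk. A small consumer sketch (hypothetical stream; not part of this commit) showing that contract:

import json
from collections.abc import Generator


def consume(stream: Generator[str, None, None]) -> None:
    for item in stream:
        if item == "ping":
            continue  # keep-alive marker, carries no payload
        chunk = json.loads(item)
        print(chunk["event"], chunk.get("message_id"))


def fake_stream() -> Generator[str, None, None]:
    yield "ping"
    yield json.dumps(
        {"event": "message", "conversation_id": "c1", "message_id": "m1", "created_at": 0}
    )


consume(fake_stream())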
@ -65,6 +65,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
    """
    AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
    """

    _task_state: WorkflowTaskState
    _application_generate_entity: AdvancedChatAppGenerateEntity
    _workflow: Workflow

@ -72,14 +73,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
    _workflow_system_variables: dict[SystemVariableKey, Any]

    def __init__(
        self,
        application_generate_entity: AdvancedChatAppGenerateEntity,
        workflow: Workflow,
        queue_manager: AppQueueManager,
        conversation: Conversation,
        message: Message,
        user: Union[Account, EndUser],
        stream: bool,
    ) -> None:
        """
        Initialize AdvancedChatAppGenerateTaskPipeline.

@ -123,13 +124,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
        # start generate conversation name thread
        self._conversation_name_generate_thread = self._generate_conversation_name(
            self._conversation, self._application_generate_entity.query
        )

        generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)

        if self._stream:
            return self._to_stream_response(generator)

@ -147,7 +145,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
            elif isinstance(stream_response, MessageEndStreamResponse):
                extras = {}
                if stream_response.metadata:
                    extras["metadata"] = stream_response.metadata

                return ChatbotAppBlockingResponse(
                    task_id=stream_response.task_id,

@ -158,15 +156,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                        message_id=self._message.id,
                        answer=self._task_state.answer,
                        created_at=int(self._message.created_at.timestamp()),
                        **extras,
                    ),
                )
            else:
                continue

        raise Exception("Queue listening stopped unexpectedly.")

    def _to_stream_response(
        self, generator: Generator[StreamResponse, None, None]
    ) -> Generator[ChatbotAppStreamResponse, Any, None]:
        """
        To stream response.
        :return:

@ -176,7 +176,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                conversation_id=self._conversation.id,
                message_id=self._message.id,
                created_at=int(self._message.created_at.timestamp()),
                stream_response=stream_response,
            )

    def _listenAudioMsg(self, publisher, task_id: str):

@ -187,17 +187,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
        return None

    def _wrapper_process_stream_response(
        self, trace_manager: Optional[TraceQueueManager] = None
    ) -> Generator[StreamResponse, None, None]:
        tts_publisher = None
        task_id = self._application_generate_entity.task_id
        tenant_id = self._application_generate_entity.app_config.tenant_id
        features_dict = self._workflow.features_dict

        if (
            features_dict.get("text_to_speech")
            and features_dict["text_to_speech"].get("enabled")
            and features_dict["text_to_speech"].get("autoPlay") == "enabled"
        ):
            tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))

        for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
            while True:

@ -228,12 +231,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
            except Exception as e:
                logger.error(e)
                break
        yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

    def _process_stream_response(
        self,
        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> Generator[StreamResponse, None, None]:
        """
        Process stream response.

@ -267,22 +270,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                db.session.close()

                yield self._workflow_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                )
            elif isinstance(event, QueueNodeStartedEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)

                response = self._workflow_node_start_to_stream_response(
                    event=event,
                    task_id=self._application_generate_entity.task_id,
                    workflow_node_execution=workflow_node_execution,
                )

                if response:

@ -293,7 +292,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                response = self._workflow_node_finish_to_stream_response(
                    event=event,
                    task_id=self._application_generate_entity.task_id,
                    workflow_node_execution=workflow_node_execution,
                )

                if response:

@ -304,62 +303,52 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                response = self._workflow_node_finish_to_stream_response(
                    event=event,
                    task_id=self._application_generate_entity.task_id,
                    workflow_node_execution=workflow_node_execution,
                )

                if response:
                    yield response
            elif isinstance(event, QueueParallelBranchRunStartedEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_parallel_branch_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_parallel_branch_finished_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueIterationStartEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_iteration_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueIterationNextEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_iteration_next_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueIterationCompletedEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_iteration_completed_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueWorkflowSucceededEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                if not graph_runtime_state:
                    raise Exception("Graph runtime state not initialized.")

                workflow_run = self._handle_workflow_run_success(
                    workflow_run=workflow_run,

@ -372,20 +361,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                )

                yield self._workflow_finish_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                )

                self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
            elif isinstance(event, QueueWorkflowFailedEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                if not graph_runtime_state:
                    raise Exception("Graph runtime state not initialized.")

                workflow_run = self._handle_workflow_run_failed(
                    workflow_run=workflow_run,

@ -399,11 +384,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                )

                yield self._workflow_finish_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                )

                err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
                yield self._error_to_stream_response(self._handle_error(err_event, self._message))
                break
            elif isinstance(event, QueueStopEvent):

@ -420,8 +404,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                    )

                    yield self._workflow_finish_to_stream_response(
                        task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                    )

                # Save message

@ -434,8 +417,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                self._refetch_message()

                self._message.message_metadata = (
                    json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
                )

                db.session.commit()
                db.session.refresh(self._message)

@ -445,8 +429,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                self._refetch_message()

                self._message.message_metadata = (
                    json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
                )

                db.session.commit()
                db.session.refresh(self._message)

@ -472,7 +457,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                yield self._message_replace_to_stream_response(answer=event.text)
            elif isinstance(event, QueueAdvancedChatMessageEndEvent):
                if not graph_runtime_state:
                    raise Exception("Graph runtime state not initialized.")

                output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
                if output_moderation_answer:

@ -502,8 +487,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
        self._message.answer = self._task_state.answer
        self._message.provider_response_latency = time.perf_counter() - self._start_at
        self._message.message_metadata = (
            json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
        )

        if graph_runtime_state and graph_runtime_state.llm_usage:
            usage = graph_runtime_state.llm_usage

@ -523,7 +509,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                application_generate_entity=self._application_generate_entity,
                conversation=self._conversation,
                is_first_message=self._application_generate_entity.conversation_id is None,
                extras=self._application_generate_entity.extras,
            )

    def _message_end_to_stream_response(self) -> MessageEndStreamResponse:

@ -533,15 +519,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
        """
        extras = {}
        if self._task_state.metadata:
            extras["metadata"] = self._task_state.metadata.copy()

            if "annotation_reply" in extras["metadata"]:
                del extras["metadata"]["annotation_reply"]

        return MessageEndStreamResponse(
            task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
        )

    def _handle_output_moderation_chunk(self, text: str) -> bool:

@ -555,14 +539,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                # stop subscribe new token when output moderation should direct output
                self._task_state.answer = self._output_moderation_handler.get_final_output()
                self._queue_manager.publish(
                    QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
                )

                self._queue_manager.publish(
                    QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
                )
                return True
            else:
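The repeated `message_metadata` rewrite in this file is representative of the whole commit: ruff replaces backslash continuations with parenthesized expressions. A standalone before/after sketch with illustrative values:

import json

metadata = {"usage": {"total_tokens": 42}}

# Before: backslash continuation.
message_metadata = json.dumps(metadata) \
    if metadata else None

# After: parenthesized conditional expression, as applied throughout above.
message_metadata = (
    json.dumps(metadata) if metadata else None
)

print(message_metadata)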
@ -28,15 +28,19 @@ class AgentChatAppConfig(EasyUIBasedAppConfig):
    """
    Agent Chatbot App Config Entity.
    """

    agent: Optional[AgentEntity] = None


class AgentChatAppConfigManager(BaseAppConfigManager):
    @classmethod
    def get_app_config(
        cls,
        app_model: App,
        app_model_config: AppModelConfig,
        conversation: Optional[Conversation] = None,
        override_config_dict: Optional[dict] = None,
    ) -> AgentChatAppConfig:
        """
        Convert app model config to agent chat app config
        :param app_model: app model

@ -66,22 +70,12 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
            app_model_config_from=config_from,
            app_model_config_id=app_model_config.id,
            app_model_config_dict=config_dict,
            model=ModelConfigManager.convert(config=config_dict),
            prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
            dataset=DatasetConfigManager.convert(config=config_dict),
            agent=AgentConfigManager.convert(config=config_dict),
            additional_features=cls.convert_features(config_dict, app_mode),
        )

        app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(

@ -128,7 +122,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
        # suggested_questions_after_answer
        config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
            config
        )
        related_config_keys.extend(current_related_config_keys)

        # speech_to_text

@ -145,13 +140,15 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
        # dataset configs
        # dataset_query_variable
        config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
            tenant_id, app_mode, config
        )
        related_config_keys.extend(current_related_config_keys)

        # moderation validation
        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
            tenant_id, config
        )
        related_config_keys.extend(current_related_config_keys)

        related_config_keys = list(set(related_config_keys))

@ -170,10 +167,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
        :param config: app model config args
        """
        if not config.get("agent_mode"):
            config["agent_mode"] = {"enabled": False, "tools": []}

        if not isinstance(config["agent_mode"], dict):
            raise ValueError("agent_mode must be of object type")

@ -187,8 +181,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
        if not config["agent_mode"].get("strategy"):
            config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value

        if config["agent_mode"]["strategy"] not in [
            member.value for member in list(PlanningStrategy.__members__.values())
        ]:
            raise ValueError("strategy in agent_mode must be in the specified strategy list")

        if not config["agent_mode"].get("tools"):

@ -210,7 +205,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
                    raise ValueError("enabled in agent_mode.tools must be of boolean type")

                if key == "dataset":
                    if "id" not in tool_item:
                        raise ValueError("id is required in dataset")

                    try:
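The reformatted strategy check above still iterates `PlanningStrategy.__members__.values()`. A compact sketch with a stand-in enum (illustrative only; the real `PlanningStrategy` lives elsewhere in the codebase) showing the membership test it performs:

from enum import Enum


class PlanningStrategy(str, Enum):  # stand-in for the real enum
    ROUTER = "router"
    REACT_ROUTER = "react_router"


def validate_strategy(strategy: str) -> None:
    if strategy not in [member.value for member in list(PlanningStrategy.__members__.values())]:
        raise ValueError("strategy in agent_mode must be in the specified strategy list")


validate_strategy("router")  # passes
# validate_strategy("unknown") would raise ValueError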
@ -30,7 +30,8 @@ logger = logging.getLogger(__name__)
class AgentChatAppGenerator(MessageBasedAppGenerator):
    @overload
    def generate(
        self,
        app_model: App,
        user: Union[Account, EndUser],
        args: dict,
        invoke_from: InvokeFrom,

@ -39,19 +40,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):

    @overload
    def generate(
        self,
        app_model: App,
        user: Union[Account, EndUser],
        args: dict,
        invoke_from: InvokeFrom,
        stream: Literal[False] = False,
    ) -> dict: ...

    def generate(
        self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
    ) -> Union[dict, Generator[dict, None, None]]:
        """
        Generate App response.

@ -62,60 +61,48 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
        :param stream: is stream
        """
        if not stream:
            raise ValueError("Agent Chat App does not support blocking mode")

        if not args.get("query"):
            raise ValueError("query is required")

        query = args["query"]
        if not isinstance(query, str):
            raise ValueError("query must be a string")

        query = query.replace("\x00", "")
        inputs = args["inputs"]

        extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}

        # get conversation
        conversation = None
        if args.get("conversation_id"):
            conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)

        # get app model config
        app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)

        # validate override model config
        override_model_config_dict = None
        if args.get("model_config"):
            if invoke_from != InvokeFrom.DEBUGGER:
                raise ValueError("Only in App debug mode can override model config")

            # validate config
            override_model_config_dict = AgentChatAppConfigManager.config_validate(
                tenant_id=app_model.tenant_id, config=args.get("model_config")
            )

            # always enable retriever resource in debugger mode
            override_model_config_dict["retriever_resource"] = {"enabled": True}

        # parse files
        files = args["files"] if args.get("files") else []
        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
        file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
        if file_extra_config:
            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
        else:
            file_objs = []

@ -124,7 +111,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
            app_model=app_model,
            app_model_config=app_model_config,
            conversation=conversation,
            override_config_dict=override_model_config_dict,
        )

        # get tracing instance

@ -145,14 +132,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
            invoke_from=invoke_from,
            extras=extras,
            call_depth=0,
            trace_manager=trace_manager,
        )

        # init generate records
        (conversation, message) = self._init_generate_records(application_generate_entity, conversation)

        # init queue manager
        queue_manager = MessageBasedAppQueueManager(

@ -161,17 +145,20 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
            invoke_from=application_generate_entity.invoke_from,
            conversation_id=conversation.id,
            app_mode=conversation.mode,
            message_id=message.id,
        )

        # new thread
        worker_thread = threading.Thread(
            target=self._generate_worker,
            kwargs={
                "flask_app": current_app._get_current_object(),
                "application_generate_entity": application_generate_entity,
                "queue_manager": queue_manager,
                "conversation_id": conversation.id,
                "message_id": message.id,
            },
        )

        worker_thread.start()

@ -185,13 +172,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
            stream=stream,
        )

        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

    def _generate_worker(
        self,
        flask_app: Flask,
        application_generate_entity: AgentChatAppGenerateEntity,
        queue_manager: AppQueueManager,
        conversation_id: str,

@ -224,14 +209,13 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
                pass
            except InvokeAuthorizationError:
                queue_manager.publish_error(
                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                )
            except ValidationError as e:
                logger.exception("Validation Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except (ValueError, InvokeError) as e:
                if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
                    logger.exception("Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except Exception as e:
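The `@overload` pair above exists so type checkers can tie the return type to the `stream` flag. A condensed, runnable sketch of the pattern (illustrative free function, not the class above):

from collections.abc import Generator
from typing import Literal, Union, overload


@overload
def generate(stream: Literal[True] = ...) -> Generator[dict, None, None]: ...
@overload
def generate(stream: Literal[False]) -> dict: ...
def generate(stream: bool = True) -> Union[dict, Generator[dict, None, None]]:
    # Single runtime implementation; the overloads are erased at runtime.
    if stream:
        return ({"chunk": i} for i in range(3))
    return {"answer": "done"}


chunks = generate(stream=True)  # checker sees Generator[dict, None, None]
result = generate(stream=False)  # checker sees dict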
@ -30,7 +30,8 @@ class AgentChatAppRunner(AppRunner):
    """

    def run(
        self,
        application_generate_entity: AgentChatAppGenerateEntity,
        queue_manager: AppQueueManager,
        conversation: Conversation,
        message: Message,

@ -65,7 +66,7 @@ class AgentChatAppRunner(AppRunner):
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
        )

        memory = None

@ -73,13 +74,10 @@ class AgentChatAppRunner(AppRunner):
            # get memory of conversation (read-only)
            model_instance = ModelInstance(
                provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
                model=application_generate_entity.model_conf.model,
            )

            memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

        # organize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)

@ -91,7 +89,7 @@ class AgentChatAppRunner(AppRunner):
            inputs=inputs,
            files=files,
            query=query,
            memory=memory,
        )

        # moderation

@ -103,7 +101,7 @@ class AgentChatAppRunner(AppRunner):
                app_generate_entity=application_generate_entity,
                inputs=inputs,
                query=query,
                message_id=message.id,
            )
        except ModerationException as e:
            self.direct_output(

@ -111,7 +109,7 @@ class AgentChatAppRunner(AppRunner):
                app_generate_entity=application_generate_entity,
                prompt_messages=prompt_messages,
                text=str(e),
                stream=application_generate_entity.stream,
            )
            return

@ -122,13 +120,13 @@ class AgentChatAppRunner(AppRunner):
                message=message,
                query=query,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from,
            )

            if annotation_reply:
                queue_manager.publish(
                    QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
                    PublishFrom.APPLICATION_MANAGER,
                )

                self.direct_output(

@ -136,7 +134,7 @@ class AgentChatAppRunner(AppRunner):
                    app_generate_entity=application_generate_entity,
                    prompt_messages=prompt_messages,
                    text=annotation_reply.content,
                    stream=application_generate_entity.stream,
                )
                return

@ -148,7 +146,7 @@ class AgentChatAppRunner(AppRunner):
                app_id=app_record.id,
                external_data_tools=external_data_tools,
                inputs=inputs,
                query=query,
            )

        # reorganize all inputs and template to prompt messages

@ -161,14 +159,14 @@ class AgentChatAppRunner(AppRunner):
            inputs=inputs,
            files=files,
            query=query,
            memory=memory,
        )

        # check hosting moderation
        hosting_moderation_result = self.check_hosting_moderation(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            prompt_messages=prompt_messages,
        )

        if hosting_moderation_result:

@ -177,9 +175,9 @@ class AgentChatAppRunner(AppRunner):
        agent_entity = app_config.agent

        # load tool variables
        tool_conversation_variables = self._load_tool_variables(
            conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
        )

        # convert db variables to tool variables
        tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)

@ -187,7 +185,7 @@ class AgentChatAppRunner(AppRunner):
        # init model instance
        model_instance = ModelInstance(
            provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
            model=application_generate_entity.model_conf.model,
        )
        prompt_message, _ = self.organize_prompt_messages(
            app_record=app_record,

@ -238,7 +236,7 @@ class AgentChatAppRunner(AppRunner):
            prompt_messages=prompt_message,
            variables_pool=tool_variables,
            db_variables=tool_conversation_variables,
            model_instance=model_instance,
        )

        invoke_result = runner.run(

@ -252,17 +250,21 @@ class AgentChatAppRunner(AppRunner):
            invoke_result=invoke_result,
            queue_manager=queue_manager,
            stream=application_generate_entity.stream,
            agent=True,
        )

    def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
        """
        load tool variables from database
        """
        tool_variables: ToolConversationVariables = (
            db.session.query(ToolConversationVariables)
            .filter(
                ToolConversationVariables.conversation_id == conversation_id,
                ToolConversationVariables.tenant_id == tenant_id,
            )
            .first()
        )

        if tool_variables:
            # save tool variables to session, so that we can update it later

@ -273,34 +275,40 @@ class AgentChatAppRunner(AppRunner):
                conversation_id=conversation_id,
                user_id=user_id,
                tenant_id=tenant_id,
                variables_str="[]",
            )
            db.session.add(tool_variables)
            db.session.commit()

        return tool_variables

    def _convert_db_variables_to_tool_variables(
        self, db_variables: ToolConversationVariables
    ) -> ToolRuntimeVariablePool:
        """
        convert db variables to tool variables
""" """
return ToolRuntimeVariablePool(**{ return ToolRuntimeVariablePool(
'conversation_id': db_variables.conversation_id, **{
'user_id': db_variables.user_id, "conversation_id": db_variables.conversation_id,
'tenant_id': db_variables.tenant_id, "user_id": db_variables.user_id,
'pool': db_variables.variables "tenant_id": db_variables.tenant_id,
}) "pool": db_variables.variables,
}
)
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity, def _get_usage_of_all_agent_thoughts(
message: Message) -> LLMUsage: self, model_config: ModelConfigWithCredentialsEntity, message: Message
) -> LLMUsage:
""" """
Get usage of all agent thoughts Get usage of all agent thoughts
:param model_config: model config :param model_config: model config
:param message: message :param message: message
:return: :return:
""" """
agent_thoughts = (db.session.query(MessageAgentThought) agent_thoughts = (
.filter(MessageAgentThought.message_id == message.id).all()) db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all()
)
all_message_tokens = 0 all_message_tokens = 0
all_answer_tokens = 0 all_answer_tokens = 0
@ -312,8 +320,5 @@ class AgentChatAppRunner(AppRunner):
model_type_instance = cast(LargeLanguageModel, model_type_instance) model_type_instance = cast(LargeLanguageModel, model_type_instance)
return model_type_instance._calc_response_usage( return model_type_instance._calc_response_usage(
model_config.model, model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens
model_config.credentials,
all_message_tokens,
all_answer_tokens
) )
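The `_get_usage_of_all_agent_thoughts` hunk above reduces to summing per-thought token counts before pricing them once through `_calc_response_usage`. A minimal self-contained sketch of that accumulation, with a hypothetical `AgentThought` dataclass standing in for the real `MessageAgentThought` model:

from dataclasses import dataclass

@dataclass
class AgentThought:
    # stand-in for MessageAgentThought; only the token fields matter here
    message_tokens: int
    answer_tokens: int

def sum_agent_thought_tokens(thoughts: list[AgentThought]) -> tuple[int, int]:
    # mirror of the loop in _get_usage_of_all_agent_thoughts: accumulate
    # prompt-side and completion-side tokens across every thought
    all_message_tokens = sum(t.message_tokens for t in thoughts)
    all_answer_tokens = sum(t.answer_tokens for t in thoughts)
    return all_message_tokens, all_answer_tokens

thoughts = [AgentThought(120, 30), AgentThought(180, 55)]
print(sum_agent_thought_tokens(thoughts))  # (300, 85)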
View File
@ -23,15 +23,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
:return: :return:
""" """
response = { response = {
'event': 'message', "event": "message",
'task_id': blocking_response.task_id, "task_id": blocking_response.task_id,
'id': blocking_response.data.id, "id": blocking_response.data.id,
'message_id': blocking_response.data.message_id, "message_id": blocking_response.data.message_id,
'conversation_id': blocking_response.data.conversation_id, "conversation_id": blocking_response.data.conversation_id,
'mode': blocking_response.data.mode, "mode": blocking_response.data.mode,
'answer': blocking_response.data.answer, "answer": blocking_response.data.answer,
'metadata': blocking_response.data.metadata, "metadata": blocking_response.data.metadata,
'created_at': blocking_response.data.created_at "created_at": blocking_response.data.created_at,
} }
return response return response
@ -45,14 +45,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
""" """
response = cls.convert_blocking_full_response(blocking_response) response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get('metadata', {}) metadata = response.get("metadata", {})
response['metadata'] = cls._get_simple_metadata(metadata) response["metadata"] = cls._get_simple_metadata(metadata)
return response return response
@classmethod @classmethod
def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ def convert_stream_full_response(
-> Generator[str, None, None]: cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[str, None, None]:
""" """
Convert stream full response. Convert stream full response.
:param stream_response: stream response :param stream_response: stream response
@ -63,14 +64,14 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse): if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping' yield "ping"
continue continue
response_chunk = { response_chunk = {
'event': sub_stream_response.event.value, "event": sub_stream_response.event.value,
'conversation_id': chunk.conversation_id, "conversation_id": chunk.conversation_id,
'message_id': chunk.message_id, "message_id": chunk.message_id,
'created_at': chunk.created_at "created_at": chunk.created_at,
} }
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):
@ -81,8 +82,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk) yield json.dumps(response_chunk)
@classmethod @classmethod
def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ def convert_stream_simple_response(
-> Generator[str, None, None]: cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.
:param stream_response: stream response :param stream_response: stream response
@ -93,20 +95,20 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse): if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping' yield "ping"
continue continue
response_chunk = { response_chunk = {
'event': sub_stream_response.event.value, "event": sub_stream_response.event.value,
'conversation_id': chunk.conversation_id, "conversation_id": chunk.conversation_id,
'message_id': chunk.message_id, "message_id": chunk.message_id,
'created_at': chunk.created_at "created_at": chunk.created_at,
} }
if isinstance(sub_stream_response, MessageEndStreamResponse): if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict() sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {}) metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict) response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
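Each streamed event in the converter above is wrapped in the same small envelope before event-specific fields are merged in and the result is serialized. A sketch of that envelope, using plain values in place of the real `ChatbotAppStreamResponse` objects:

import json

def to_response_chunk(event: str, conversation_id: str, message_id: str, created_at: int) -> str:
    # the common envelope built for every non-ping chunk; event-specific
    # fields (answer text, metadata, error details) are update()d in after
    chunk = {
        "event": event,
        "conversation_id": conversation_id,
        "message_id": message_id,
        "created_at": created_at,
    }
    return json.dumps(chunk)

print(to_response_chunk("agent_message", "conv-1", "msg-1", 1718000000))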
View File
@ -13,32 +13,33 @@ class AppGenerateResponseConverter(ABC):
_blocking_response_type: type[AppBlockingResponse] _blocking_response_type: type[AppBlockingResponse]
@classmethod @classmethod
def convert(cls, response: Union[ def convert(
AppBlockingResponse, cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
Generator[AppStreamResponse, Any, None] ) -> dict[str, Any] | Generator[str, Any, None]:
], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]:
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
if isinstance(response, AppBlockingResponse): if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response) return cls.convert_blocking_full_response(response)
else: else:
def _generate_full_response() -> Generator[str, Any, None]: def _generate_full_response() -> Generator[str, Any, None]:
for chunk in cls.convert_stream_full_response(response): for chunk in cls.convert_stream_full_response(response):
if chunk == 'ping': if chunk == "ping":
yield f'event: {chunk}\n\n' yield f"event: {chunk}\n\n"
else: else:
yield f'data: {chunk}\n\n' yield f"data: {chunk}\n\n"
return _generate_full_response() return _generate_full_response()
else: else:
if isinstance(response, AppBlockingResponse): if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response) return cls.convert_blocking_simple_response(response)
else: else:
def _generate_simple_response() -> Generator[str, Any, None]: def _generate_simple_response() -> Generator[str, Any, None]:
for chunk in cls.convert_stream_simple_response(response): for chunk in cls.convert_stream_simple_response(response):
if chunk == 'ping': if chunk == "ping":
yield f'event: {chunk}\n\n' yield f"event: {chunk}\n\n"
else: else:
yield f'data: {chunk}\n\n' yield f"data: {chunk}\n\n"
return _generate_simple_response() return _generate_simple_response()
@ -54,14 +55,16 @@ class AppGenerateResponseConverter(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \ def convert_stream_full_response(
-> Generator[str, None, None]: cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, None, None]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
@abstractmethod @abstractmethod
def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \ def convert_stream_simple_response(
-> Generator[str, None, None]: cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, None, None]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
@ -72,24 +75,26 @@ class AppGenerateResponseConverter(ABC):
:return: :return:
""" """
# show_retrieve_source # show_retrieve_source
if 'retriever_resources' in metadata: if "retriever_resources" in metadata:
metadata['retriever_resources'] = [] metadata["retriever_resources"] = []
for resource in metadata['retriever_resources']: for resource in metadata["retriever_resources"]:
metadata['retriever_resources'].append({ metadata["retriever_resources"].append(
'segment_id': resource['segment_id'], {
'position': resource['position'], "segment_id": resource["segment_id"],
'document_name': resource['document_name'], "position": resource["position"],
'score': resource['score'], "document_name": resource["document_name"],
'content': resource['content'], "score": resource["score"],
}) "content": resource["content"],
}
)
# show annotation reply # show annotation reply
if 'annotation_reply' in metadata: if "annotation_reply" in metadata:
del metadata['annotation_reply'] del metadata["annotation_reply"]
# show usage # show usage
if 'usage' in metadata: if "usage" in metadata:
del metadata['usage'] del metadata["usage"]
return metadata return metadata
@ -101,16 +106,16 @@ class AppGenerateResponseConverter(ABC):
:return: :return:
""" """
error_responses = { error_responses = {
ValueError: {'code': 'invalid_param', 'status': 400}, ValueError: {"code": "invalid_param", "status": 400},
ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400}, ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
QuotaExceededError: { QuotaExceededError: {
'code': 'provider_quota_exceeded', "code": "provider_quota_exceeded",
'message': "Your quota for Dify Hosted Model Provider has been exhausted. " "message": "Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials.", "Please go to Settings -> Model Provider to complete your own provider credentials.",
'status': 400 "status": 400,
}, },
ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, ModelCurrentlyNotSupportError: {"code": "model_currently_not_support", "status": 400},
InvokeError: {'code': 'completion_request_error', 'status': 400} InvokeError: {"code": "completion_request_error", "status": 400},
} }
# Determine the response based on the type of exception # Determine the response based on the type of exception
@ -120,13 +125,13 @@ class AppGenerateResponseConverter(ABC):
data = v data = v
if data: if data:
data.setdefault('message', getattr(e, 'description', str(e))) data.setdefault("message", getattr(e, "description", str(e)))
else: else:
logging.error(e) logging.error(e)
data = { data = {
'code': 'internal_server_error', "code": "internal_server_error",
'message': 'Internal Server Error, please contact support.', "message": "Internal Server Error, please contact support.",
'status': 500 "status": 500,
} }
return data return data
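The `_generate_full_response` and `_generate_simple_response` closures above frame each converted chunk as a Server-Sent Events line: a ping becomes an `event:` line, everything else a `data:` line. A standalone sketch of that framing rule:

from collections.abc import Iterable, Iterator

def frame_sse(chunks: Iterable[str]) -> Iterator[str]:
    # mirrors the two closures in convert(): "ping" is an SSE event,
    # any other chunk is a JSON payload on a data line
    for chunk in chunks:
        if chunk == "ping":
            yield f"event: {chunk}\n\n"
        else:
            yield f"data: {chunk}\n\n"

print(list(frame_sse(["ping", '{"event": "message"}'])))
# ['event: ping\n\n', 'data: {"event": "message"}\n\n']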
View File
@ -16,10 +16,10 @@ class BaseAppGenerator:
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.variable) user_input_value = inputs.get(var.variable)
if var.required and not user_input_value: if var.required and not user_input_value:
raise ValueError(f'{var.variable} is required in input form') raise ValueError(f"{var.variable} is required in input form")
if not var.required and not user_input_value: if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None? # TODO: should we return None here if the default value is None?
return var.default or '' return var.default or ""
if ( if (
var.type var.type
in ( in (
@ -34,7 +34,7 @@ class BaseAppGenerator:
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str): if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number # may raise ValueError if user_input_value is not a valid number
try: try:
if '.' in user_input_value: if "." in user_input_value:
return float(user_input_value) return float(user_input_value)
else: else:
return int(user_input_value) return int(user_input_value)
@ -43,14 +43,14 @@ class BaseAppGenerator:
if var.type == VariableEntityType.SELECT: if var.type == VariableEntityType.SELECT:
options = var.options or [] options = var.options or []
if user_input_value not in options: if user_input_value not in options:
raise ValueError(f'{var.variable} in input form must be one of the following: {options}') raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
if var.max_length and user_input_value and len(user_input_value) > var.max_length: if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters') raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
return user_input_value return user_input_value
def _sanitize_value(self, value: Any) -> Any: def _sanitize_value(self, value: Any) -> Any:
if isinstance(value, str): if isinstance(value, str):
return value.replace('\x00', '') return value.replace("\x00", "")
return value return value
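The NUMBER branch of `_validate_input` above picks the numeric type from the raw string: a decimal point selects float parsing, otherwise the value must parse as an int (a malformed string raises `ValueError` either way). As a standalone function:

def parse_number_input(user_input_value: str) -> float | int:
    # same rule as the NUMBER branch of _validate_input: "." selects float,
    # otherwise int; int("3.14") or float("abc") raises ValueError for us
    if "." in user_input_value:
        return float(user_input_value)
    return int(user_input_value)

assert parse_number_input("3.14") == 3.14
assert parse_number_input("42") == 42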
View File
@ -24,9 +24,7 @@ class PublishFrom(Enum):
class AppQueueManager: class AppQueueManager:
def __init__(self, task_id: str, def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None:
user_id: str,
invoke_from: InvokeFrom) -> None:
if not user_id: if not user_id:
raise ValueError("user is required") raise ValueError("user is required")
@ -34,9 +32,10 @@ class AppQueueManager:
self._user_id = user_id self._user_id = user_id
self._invoke_from = invoke_from self._invoke_from = invoke_from
user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, redis_client.setex(
f"{user_prefix}-{self._user_id}") AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
)
q = queue.Queue() q = queue.Queue()
@ -66,8 +65,7 @@ class AppQueueManager:
# publish two messages to make sure the client can receive the stop signal # publish two messages to make sure the client can receive the stop signal
# and stop listening after the stop signal processed # and stop listening after the stop signal processed
self.publish( self.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE
PublishFrom.TASK_PIPELINE
) )
if elapsed_time // 10 > last_ping_time: if elapsed_time // 10 > last_ping_time:
@ -88,9 +86,7 @@ class AppQueueManager:
:param pub_from: publish from :param pub_from: publish from
:return: :return:
""" """
self.publish(QueueErrorEvent( self.publish(QueueErrorEvent(error=e), pub_from)
error=e
), pub_from)
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
""" """
@ -122,8 +118,8 @@ class AppQueueManager:
if result is None: if result is None:
return return
user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
if result.decode('utf-8') != f"{user_prefix}-{user_id}": if result.decode("utf-8") != f"{user_prefix}-{user_id}":
return return
stopped_cache_key = cls._generate_stopped_cache_key(task_id) stopped_cache_key = cls._generate_stopped_cache_key(task_id)
@ -168,9 +164,11 @@ class AppQueueManager:
for item in data: for item in data:
self._check_for_sqlalchemy_models(item) self._check_for_sqlalchemy_models(item)
else: else:
if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'): if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
raise TypeError("Critical Error: Passing SQLAlchemy Model instances " raise TypeError(
"that cause thread safety issues is not allowed.") "Critical Error: Passing SQLAlchemy Model instances "
"that cause thread safety issues is not allowed."
)
class GenerateTaskStoppedException(Exception): class GenerateTaskStoppedException(Exception):
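The thread-safety guard at the end of this file walks nested containers and rejects anything that looks like a SQLAlchemy model before it crosses the queue into another thread. A simplified sketch that keeps only the `_sa_instance_state` probe (the real guard also checks `isinstance(data, DeclarativeMeta)`):

def check_for_sqlalchemy_models(data) -> None:
    # recursive walk matching _check_for_sqlalchemy_models: dicts and lists
    # are descended into, anything carrying _sa_instance_state is rejected
    if isinstance(data, dict):
        for value in data.values():
            check_for_sqlalchemy_models(value)
    elif isinstance(data, list):
        for item in data:
            check_for_sqlalchemy_models(item)
    elif hasattr(data, "_sa_instance_state"):
        raise TypeError(
            "Critical Error: Passing SQLAlchemy Model instances "
            "that cause thread safety issues is not allowed."
        )

check_for_sqlalchemy_models({"ok": [1, "two", {"three": 3}]})  # passes silently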
View File
@ -31,12 +31,15 @@ if TYPE_CHECKING:
class AppRunner: class AppRunner:
def get_pre_calculate_rest_tokens(self, app_record: App, def get_pre_calculate_rest_tokens(
model_config: ModelConfigWithCredentialsEntity, self,
prompt_template_entity: PromptTemplateEntity, app_record: App,
inputs: dict[str, str], model_config: ModelConfigWithCredentialsEntity,
files: list["FileVar"], prompt_template_entity: PromptTemplateEntity,
query: Optional[str] = None) -> int: inputs: dict[str, str],
files: list["FileVar"],
query: Optional[str] = None,
) -> int:
""" """
Get pre calculate rest tokens Get pre calculate rest tokens
:param app_record: app record :param app_record: app record
@ -49,18 +52,20 @@ class AppRunner:
""" """
# Invoke model # Invoke model
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
model=model_config.model
) )
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0 max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules: for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens' if parameter_rule.name == "max_tokens" or (
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
max_tokens = (model_config.parameters.get(parameter_rule.name) ):
or model_config.parameters.get(parameter_rule.use_template)) or 0 max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
) or 0
if model_context_tokens is None: if model_context_tokens is None:
return -1 return -1
@ -75,36 +80,39 @@ class AppRunner:
prompt_template_entity=prompt_template_entity, prompt_template_entity=prompt_template_entity,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query query=query,
) )
prompt_tokens = model_instance.get_llm_num_tokens( prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
prompt_messages
)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens rest_tokens = model_context_tokens - max_tokens - prompt_tokens
if rest_tokens < 0: if rest_tokens < 0:
raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, " raise InvokeBadRequestError(
"or shrink the max token, or switch to a llm with a larger token limit size.") "Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size."
)
return rest_tokens return rest_tokens
def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity, def recalc_llm_max_tokens(
prompt_messages: list[PromptMessage]): self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
model=model_config.model
) )
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0 max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules: for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens' if parameter_rule.name == "max_tokens" or (
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
max_tokens = (model_config.parameters.get(parameter_rule.name) ):
or model_config.parameters.get(parameter_rule.use_template)) or 0 max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
) or 0
if model_context_tokens is None: if model_context_tokens is None:
return -1 return -1
@ -112,27 +120,28 @@ class AppRunner:
if max_tokens is None: if max_tokens is None:
max_tokens = 0 max_tokens = 0
prompt_tokens = model_instance.get_llm_num_tokens( prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
prompt_messages
)
if prompt_tokens + max_tokens > model_context_tokens: if prompt_tokens + max_tokens > model_context_tokens:
max_tokens = max(model_context_tokens - prompt_tokens, 16) max_tokens = max(model_context_tokens - prompt_tokens, 16)
for parameter_rule in model_config.model_schema.parameter_rules: for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens' if parameter_rule.name == "max_tokens" or (
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
model_config.parameters[parameter_rule.name] = max_tokens model_config.parameters[parameter_rule.name] = max_tokens
def organize_prompt_messages(self, app_record: App, def organize_prompt_messages(
model_config: ModelConfigWithCredentialsEntity, self,
prompt_template_entity: PromptTemplateEntity, app_record: App,
inputs: dict[str, str], model_config: ModelConfigWithCredentialsEntity,
files: list["FileVar"], prompt_template_entity: PromptTemplateEntity,
query: Optional[str] = None, inputs: dict[str, str],
context: Optional[str] = None, files: list["FileVar"],
memory: Optional[TokenBufferMemory] = None) \ query: Optional[str] = None,
-> tuple[list[PromptMessage], Optional[list[str]]]: context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
""" """
Organize prompt messages Organize prompt messages
:param context: :param context:
@ -152,60 +161,54 @@ class AppRunner:
app_mode=AppMode.value_of(app_record.mode), app_mode=AppMode.value_of(app_record.mode),
prompt_template_entity=prompt_template_entity, prompt_template_entity=prompt_template_entity,
inputs=inputs, inputs=inputs,
query=query if query else '', query=query if query else "",
files=files, files=files,
context=context, context=context,
memory=memory, memory=memory,
model_config=model_config model_config=model_config,
) )
else: else:
memory_config = MemoryConfig( memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
window=MemoryConfig.WindowConfig(
enabled=False
)
)
model_mode = ModelMode.value_of(model_config.mode) model_mode = ModelMode.value_of(model_config.mode)
if model_mode == ModelMode.COMPLETION: if model_mode == ModelMode.COMPLETION:
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
prompt_template = CompletionModelPromptTemplate( prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt)
text=advanced_completion_prompt_template.prompt
)
if advanced_completion_prompt_template.role_prefix: if advanced_completion_prompt_template.role_prefix:
memory_config.role_prefix = MemoryConfig.RolePrefix( memory_config.role_prefix = MemoryConfig.RolePrefix(
user=advanced_completion_prompt_template.role_prefix.user, user=advanced_completion_prompt_template.role_prefix.user,
assistant=advanced_completion_prompt_template.role_prefix.assistant assistant=advanced_completion_prompt_template.role_prefix.assistant,
) )
else: else:
prompt_template = [] prompt_template = []
for message in prompt_template_entity.advanced_chat_prompt_template.messages: for message in prompt_template_entity.advanced_chat_prompt_template.messages:
prompt_template.append(ChatModelMessage( prompt_template.append(ChatModelMessage(text=message.text, role=message.role))
text=message.text,
role=message.role
))
prompt_transform = AdvancedPromptTransform() prompt_transform = AdvancedPromptTransform()
prompt_messages = prompt_transform.get_prompt( prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template, prompt_template=prompt_template,
inputs=inputs, inputs=inputs,
query=query if query else '', query=query if query else "",
files=files, files=files,
context=context, context=context,
memory_config=memory_config, memory_config=memory_config,
memory=memory, memory=memory,
model_config=model_config model_config=model_config,
) )
stop = model_config.stop stop = model_config.stop
return prompt_messages, stop return prompt_messages, stop
def direct_output(self, queue_manager: AppQueueManager, def direct_output(
app_generate_entity: EasyUIBasedAppGenerateEntity, self,
prompt_messages: list, queue_manager: AppQueueManager,
text: str, app_generate_entity: EasyUIBasedAppGenerateEntity,
stream: bool, prompt_messages: list,
usage: Optional[LLMUsage] = None) -> None: text: str,
stream: bool,
usage: Optional[LLMUsage] = None,
) -> None:
""" """
Direct output Direct output
:param queue_manager: application queue manager :param queue_manager: application queue manager
@ -222,17 +225,10 @@ class AppRunner:
chunk = LLMResultChunk( chunk = LLMResultChunk(
model=app_generate_entity.model_conf.model, model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
delta=LLMResultChunkDelta( delta=LLMResultChunkDelta(index=index, message=AssistantPromptMessage(content=token)),
index=index,
message=AssistantPromptMessage(content=token)
)
) )
queue_manager.publish( queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
QueueLLMChunkEvent(
chunk=chunk
), PublishFrom.APPLICATION_MANAGER
)
index += 1 index += 1
time.sleep(0.01) time.sleep(0.01)
@ -242,15 +238,19 @@ class AppRunner:
model=app_generate_entity.model_conf.model, model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text), message=AssistantPromptMessage(content=text),
usage=usage if usage else LLMUsage.empty_usage() usage=usage if usage else LLMUsage.empty_usage(),
), ),
), PublishFrom.APPLICATION_MANAGER ),
PublishFrom.APPLICATION_MANAGER,
) )
def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], def _handle_invoke_result(
queue_manager: AppQueueManager, self,
stream: bool, invoke_result: Union[LLMResult, Generator],
agent: bool = False) -> None: queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
) -> None:
""" """
Handle invoke result Handle invoke result
:param invoke_result: invoke result :param invoke_result: invoke result
@ -260,21 +260,13 @@ class AppRunner:
:return: :return:
""" """
if not stream: if not stream:
self._handle_invoke_result_direct( self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent
)
else: else:
self._handle_invoke_result_stream( self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent
)
def _handle_invoke_result_direct(self, invoke_result: LLMResult, def _handle_invoke_result_direct(
queue_manager: AppQueueManager, self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
agent: bool) -> None: ) -> None:
""" """
Handle invoke result direct Handle invoke result direct
:param invoke_result: invoke result :param invoke_result: invoke result
@ -285,12 +277,13 @@ class AppRunner:
queue_manager.publish( queue_manager.publish(
QueueMessageEndEvent( QueueMessageEndEvent(
llm_result=invoke_result, llm_result=invoke_result,
), PublishFrom.APPLICATION_MANAGER ),
PublishFrom.APPLICATION_MANAGER,
) )
def _handle_invoke_result_stream(self, invoke_result: Generator, def _handle_invoke_result_stream(
queue_manager: AppQueueManager, self, invoke_result: Generator, queue_manager: AppQueueManager, agent: bool
agent: bool) -> None: ) -> None:
""" """
Handle invoke result Handle invoke result
:param invoke_result: invoke result :param invoke_result: invoke result
@ -300,21 +293,13 @@ class AppRunner:
""" """
model = None model = None
prompt_messages = [] prompt_messages = []
text = '' text = ""
usage = None usage = None
for result in invoke_result: for result in invoke_result:
if not agent: if not agent:
queue_manager.publish( queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
QueueLLMChunkEvent(
chunk=result
), PublishFrom.APPLICATION_MANAGER
)
else: else:
queue_manager.publish( queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
QueueAgentMessageEvent(
chunk=result
), PublishFrom.APPLICATION_MANAGER
)
text += result.delta.message.content text += result.delta.message.content
@ -331,25 +316,24 @@ class AppRunner:
usage = LLMUsage.empty_usage() usage = LLMUsage.empty_usage()
llm_result = LLMResult( llm_result = LLMResult(
model=model, model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage
) )
queue_manager.publish( queue_manager.publish(
QueueMessageEndEvent( QueueMessageEndEvent(
llm_result=llm_result, llm_result=llm_result,
), PublishFrom.APPLICATION_MANAGER ),
PublishFrom.APPLICATION_MANAGER,
) )
def moderation_for_inputs( def moderation_for_inputs(
self, app_id: str, self,
tenant_id: str, app_id: str,
app_generate_entity: AppGenerateEntity, tenant_id: str,
inputs: Mapping[str, Any], app_generate_entity: AppGenerateEntity,
query: str, inputs: Mapping[str, Any],
message_id: str, query: str,
message_id: str,
) -> tuple[bool, dict, str]: ) -> tuple[bool, dict, str]:
""" """
Process sensitive_word_avoidance. Process sensitive_word_avoidance.
@ -367,14 +351,17 @@ class AppRunner:
tenant_id=tenant_id, tenant_id=tenant_id,
app_config=app_generate_entity.app_config, app_config=app_generate_entity.app_config,
inputs=inputs, inputs=inputs,
query=query if query else '', query=query if query else "",
message_id=message_id, message_id=message_id,
trace_manager=app_generate_entity.trace_manager trace_manager=app_generate_entity.trace_manager,
) )
def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity, def check_hosting_moderation(
queue_manager: AppQueueManager, self,
prompt_messages: list[PromptMessage]) -> bool: application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager,
prompt_messages: list[PromptMessage],
) -> bool:
""" """
Check hosting moderation Check hosting moderation
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@ -384,8 +371,7 @@ class AppRunner:
""" """
hosting_moderation_feature = HostingModerationFeature() hosting_moderation_feature = HostingModerationFeature()
moderation_result = hosting_moderation_feature.check( moderation_result = hosting_moderation_feature.check(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity, prompt_messages=prompt_messages
prompt_messages=prompt_messages
) )
if moderation_result: if moderation_result:
@ -393,18 +379,20 @@ class AppRunner:
queue_manager=queue_manager, queue_manager=queue_manager,
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
text="I apologize for any confusion, " \ text="I apologize for any confusion, " "but I'm an AI assistant to be helpful, harmless, and honest.",
"but I'm an AI assistant to be helpful, harmless, and honest.", stream=application_generate_entity.stream,
stream=application_generate_entity.stream
) )
return moderation_result return moderation_result
def fill_in_inputs_from_external_data_tools(self, tenant_id: str, def fill_in_inputs_from_external_data_tools(
app_id: str, self,
external_data_tools: list[ExternalDataVariableEntity], tenant_id: str,
inputs: dict, app_id: str,
query: str) -> dict: external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
query: str,
) -> dict:
""" """
Fill in variable inputs from external data tools if exists. Fill in variable inputs from external data tools if exists.
@ -417,18 +405,12 @@ class AppRunner:
""" """
external_data_fetch_feature = ExternalDataFetch() external_data_fetch_feature = ExternalDataFetch()
return external_data_fetch_feature.fetch( return external_data_fetch_feature.fetch(
tenant_id=tenant_id, tenant_id=tenant_id, app_id=app_id, external_data_tools=external_data_tools, inputs=inputs, query=query
app_id=app_id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
) )
def query_app_annotations_to_reply(self, app_record: App, def query_app_annotations_to_reply(
message: Message, self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
query: str, ) -> Optional[MessageAnnotation]:
user_id: str,
invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
""" """
Query app annotations to reply Query app annotations to reply
:param app_record: app record :param app_record: app record
@ -440,9 +422,5 @@ class AppRunner:
""" """
annotation_reply_feature = AnnotationReplyFeature() annotation_reply_feature = AnnotationReplyFeature()
return annotation_reply_feature.query( return annotation_reply_feature.query(
app_record=app_record, app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
message=message,
query=query,
user_id=user_id,
invoke_from=invoke_from
) )
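The budget rule in `get_pre_calculate_rest_tokens` above is plain arithmetic: whatever the context window has left after reserving the completion budget and the prompt itself, with a hard failure when the result goes negative. A sketch of just that arithmetic (raising `ValueError` where the runner raises `InvokeBadRequestError`):

def pre_calculate_rest_tokens(model_context_tokens: int, max_tokens: int, prompt_tokens: int) -> int:
    # context window minus the reserved completion budget minus the prompt;
    # a negative remainder means the request cannot fit at all
    rest_tokens = model_context_tokens - max_tokens - prompt_tokens
    if rest_tokens < 0:
        raise ValueError(
            "Query or prefix prompt is too long, you can reduce the prefix prompt, "
            "or shrink the max token, or switch to a llm with a larger token limit size."
        )
    return rest_tokens

print(pre_calculate_rest_tokens(8192, 1024, 2100))  # 5068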
View File
@ -22,15 +22,19 @@ class ChatAppConfig(EasyUIBasedAppConfig):
""" """
Chatbot App Config Entity. Chatbot App Config Entity.
""" """
pass pass
class ChatAppConfigManager(BaseAppConfigManager): class ChatAppConfigManager(BaseAppConfigManager):
@classmethod @classmethod
def get_app_config(cls, app_model: App, def get_app_config(
app_model_config: AppModelConfig, cls,
conversation: Optional[Conversation] = None, app_model: App,
override_config_dict: Optional[dict] = None) -> ChatAppConfig: app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None,
) -> ChatAppConfig:
""" """
Convert app model config to chat app config Convert app model config to chat app config
:param app_model: app model :param app_model: app model
@ -51,7 +55,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
config_dict = app_model_config_dict.copy() config_dict = app_model_config_dict.copy()
else: else:
if not override_config_dict: if not override_config_dict:
raise Exception('override_config_dict is required when config_from is ARGS') raise Exception("override_config_dict is required when config_from is ARGS")
config_dict = override_config_dict config_dict = override_config_dict
@ -63,19 +67,11 @@ class ChatAppConfigManager(BaseAppConfigManager):
app_model_config_from=config_from, app_model_config_from=config_from,
app_model_config_id=app_model_config.id, app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict, app_model_config_dict=config_dict,
model=ModelConfigManager.convert( model=ModelConfigManager.convert(config=config_dict),
config=config_dict prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
), sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert( dataset=DatasetConfigManager.convert(config=config_dict),
config=config_dict additional_features=cls.convert_features(config_dict, app_mode),
),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=config_dict
),
dataset=DatasetConfigManager.convert(
config=config_dict
),
additional_features=cls.convert_features(config_dict, app_mode)
) )
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@ -113,8 +109,9 @@ class ChatAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# dataset_query_variable # dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
config) tenant_id, app_mode, config
)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# opening_statement # opening_statement
@ -123,7 +120,8 @@ class ChatAppConfigManager(BaseAppConfigManager):
# suggested_questions_after_answer # suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config) config
)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# speech_to_text # speech_to_text
@ -139,8 +137,9 @@ class ChatAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# moderation validation # moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
config) tenant_id, config
)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys)) related_config_keys = list(set(related_config_keys))
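`config_validate` above repeats one accumulation pattern per feature: each validator returns the (possibly defaulted) config plus the keys it owns, and duplicates are dropped at the end. A generic sketch of that pattern, assuming validators are plain `(tenant_id, config)` callables rather than the real config-manager classmethods:

from collections.abc import Callable

Validator = Callable[[str, dict], tuple[dict, list[str]]]

def collect_related_keys(validators: list[Validator], tenant_id: str, config: dict) -> tuple[dict, list[str]]:
    # thread the config through every validator, gathering the keys each
    # one claims, then deduplicate exactly as config_validate does
    related_config_keys: list[str] = []
    for validate in validators:
        config, current_related_config_keys = validate(tenant_id, config)
        related_config_keys.extend(current_related_config_keys)
    return config, list(set(related_config_keys))

noop: Validator = lambda tenant_id, config: (config, ["model"])
print(collect_related_keys([noop, noop], "t-1", {})[1])  # ['model']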
View File
@ -30,7 +30,8 @@ logger = logging.getLogger(__name__)
class ChatAppGenerator(MessageBasedAppGenerator): class ChatAppGenerator(MessageBasedAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: Any, args: Any,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
@ -39,7 +40,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: Any, args: Any,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
@ -47,7 +49,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
) -> dict: ... ) -> dict: ...
def generate( def generate(
self, app_model: App, self,
app_model: App,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: Any, args: Any,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
@ -62,58 +65,46 @@ class ChatAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param stream: is stream :param stream: is stream
""" """
if not args.get('query'): if not args.get("query"):
raise ValueError('query is required') raise ValueError("query is required")
query = args['query'] query = args["query"]
if not isinstance(query, str): if not isinstance(query, str):
raise ValueError('query must be a string') raise ValueError("query must be a string")
query = query.replace('\x00', '') query = query.replace("\x00", "")
inputs = args['inputs'] inputs = args["inputs"]
extras = { extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
"auto_generate_conversation_name": args.get('auto_generate_name', True)
}
# get conversation # get conversation
conversation = None conversation = None
if args.get('conversation_id'): if args.get("conversation_id"):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
# get app model config # get app model config
app_model_config = self._get_app_model_config( app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
app_model=app_model,
conversation=conversation
)
# validate override model config # validate override model config
override_model_config_dict = None override_model_config_dict = None
if args.get('model_config'): if args.get("model_config"):
if invoke_from != InvokeFrom.DEBUGGER: if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config') raise ValueError("Only in App debug mode can override model config")
# validate config # validate config
override_model_config_dict = ChatAppConfigManager.config_validate( override_model_config_dict = ChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id, config=args.get("model_config")
config=args.get('model_config')
) )
# always enable retriever resource in debugger mode # always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = { override_model_config_dict["retriever_resource"] = {"enabled": True}
"enabled": True
}
# parse files # parse files
files = args['files'] if args.get('files') else [] files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config: if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg( file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
files,
file_extra_config,
user
)
else: else:
file_objs = [] file_objs = []
@ -122,7 +113,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
app_model=app_model, app_model=app_model,
app_model_config=app_model_config, app_model_config=app_model_config,
conversation=conversation, conversation=conversation,
override_config_dict=override_model_config_dict override_config_dict=override_model_config_dict,
) )
# get tracing instance # get tracing instance
@ -141,14 +132,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
stream=stream, stream=stream,
invoke_from=invoke_from, invoke_from=invoke_from,
extras=extras, extras=extras,
trace_manager=trace_manager trace_manager=trace_manager,
) )
# init generate records # init generate records
( (conversation, message) = self._init_generate_records(application_generate_entity, conversation)
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager # init queue manager
queue_manager = MessageBasedAppQueueManager( queue_manager = MessageBasedAppQueueManager(
@ -157,17 +145,20 @@ class ChatAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id, conversation_id=conversation.id,
app_mode=conversation.mode, app_mode=conversation.mode,
message_id=message.id message_id=message.id,
) )
# new thread # new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={ worker_thread = threading.Thread(
'flask_app': current_app._get_current_object(), target=self._generate_worker,
'application_generate_entity': application_generate_entity, kwargs={
'queue_manager': queue_manager, "flask_app": current_app._get_current_object(),
'conversation_id': conversation.id, "application_generate_entity": application_generate_entity,
'message_id': message.id, "queue_manager": queue_manager,
}) "conversation_id": conversation.id,
"message_id": message.id,
},
)
worker_thread.start() worker_thread.start()
@ -181,16 +172,16 @@ class ChatAppGenerator(MessageBasedAppGenerator):
stream=stream, stream=stream,
) )
return ChatAppGenerateResponseConverter.convert( return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
response=response,
invoke_from=invoke_from
)
def _generate_worker(self, flask_app: Flask, def _generate_worker(
application_generate_entity: ChatAppGenerateEntity, self,
queue_manager: AppQueueManager, flask_app: Flask,
conversation_id: str, application_generate_entity: ChatAppGenerateEntity,
message_id: str) -> None: queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
) -> None:
""" """
Generate worker in a new thread. Generate worker in a new thread.
:param flask_app: Flask app :param flask_app: Flask app
@ -212,20 +203,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
conversation=conversation, conversation=conversation,
message=message message=message,
) )
except GenerateTaskStoppedException: except GenerateTaskStoppedException:
pass pass
except InvokeAuthorizationError: except InvokeAuthorizationError:
queue_manager.publish_error( queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'), InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
PublishFrom.APPLICATION_MANAGER
) )
except ValidationError as e: except ValidationError as e:
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e: except Exception as e:
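The worker thread above receives `current_app._get_current_object()` rather than the `current_app` proxy, because Flask's thread-locals do not cross thread boundaries; the spawned worker then re-enters the app context itself. A dependency-free sketch of that handoff, with a stand-in app object in place of a real Flask app:

import threading
from contextlib import contextmanager

class FakeFlaskApp:
    # stand-in for the real Flask app; only app_context() matters here
    @contextmanager
    def app_context(self):
        yield self

def generate_worker(flask_app: FakeFlaskApp, message_id: str) -> None:
    # the worker re-enters the app context because proxies like current_app
    # are thread-local and unusable from the new thread
    with flask_app.app_context():
        print(f"generating for {message_id}")

app = FakeFlaskApp()
worker_thread = threading.Thread(
    target=generate_worker,
    kwargs={"flask_app": app, "message_id": "msg-1"},
)
worker_thread.start()
worker_thread.join()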
View File
@ -24,10 +24,13 @@ class ChatAppRunner(AppRunner):
Chat Application Runner Chat Application Runner
""" """
def run(self, application_generate_entity: ChatAppGenerateEntity, def run(
queue_manager: AppQueueManager, self,
conversation: Conversation, application_generate_entity: ChatAppGenerateEntity,
message: Message) -> None: queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
""" """
Run application Run application
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@ -58,7 +61,7 @@ class ChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query query=query,
) )
memory = None memory = None
@ -66,13 +69,10 @@ class ChatAppRunner(AppRunner):
# get memory of conversation (read-only) # get memory of conversation (read-only)
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model model=application_generate_entity.model_conf.model,
) )
memory = TokenBufferMemory( memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
conversation=conversation,
model_instance=model_instance
)
# organize all inputs and template to prompt messages # organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional) # Include: prompt template, inputs, query(optional), files(optional)
@ -84,7 +84,7 @@ class ChatAppRunner(AppRunner):
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query, query=query,
memory=memory memory=memory,
) )
# moderation # moderation
@ -96,7 +96,7 @@ class ChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,
inputs=inputs, inputs=inputs,
query=query, query=query,
message_id=message.id message_id=message.id,
) )
except ModerationException as e: except ModerationException as e:
self.direct_output( self.direct_output(
@ -104,7 +104,7 @@ class ChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
text=str(e), text=str(e),
stream=application_generate_entity.stream stream=application_generate_entity.stream,
) )
return return
@ -115,13 +115,13 @@ class ChatAppRunner(AppRunner):
message=message, message=message,
query=query, query=query,
user_id=application_generate_entity.user_id, user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from invoke_from=application_generate_entity.invoke_from,
) )
if annotation_reply: if annotation_reply:
queue_manager.publish( queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER PublishFrom.APPLICATION_MANAGER,
) )
self.direct_output( self.direct_output(
@ -129,7 +129,7 @@ class ChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
text=annotation_reply.content, text=annotation_reply.content,
stream=application_generate_entity.stream stream=application_generate_entity.stream,
) )
return return
@ -141,7 +141,7 @@ class ChatAppRunner(AppRunner):
app_id=app_record.id, app_id=app_record.id,
external_data_tools=external_data_tools, external_data_tools=external_data_tools,
inputs=inputs, inputs=inputs,
query=query query=query,
) )
# get context from datasets # get context from datasets
@ -152,7 +152,7 @@ class ChatAppRunner(AppRunner):
app_record.id, app_record.id,
message.id, message.id,
application_generate_entity.user_id, application_generate_entity.user_id,
application_generate_entity.invoke_from application_generate_entity.invoke_from,
) )
dataset_retrieval = DatasetRetrieval(application_generate_entity) dataset_retrieval = DatasetRetrieval(application_generate_entity)
@ -181,29 +181,26 @@ class ChatAppRunner(AppRunner):
files=files, files=files,
query=query, query=query,
context=context, context=context,
memory=memory memory=memory,
) )
# check hosting moderation # check hosting moderation
hosting_moderation_result = self.check_hosting_moderation( hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
prompt_messages=prompt_messages prompt_messages=prompt_messages,
) )
if hosting_moderation_result: if hosting_moderation_result:
return return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens( self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
model_config=application_generate_entity.model_conf,
prompt_messages=prompt_messages
)
# Invoke model # Invoke model
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model model=application_generate_entity.model_conf.model,
) )
db.session.close() db.session.close()
@ -218,7 +215,5 @@ class ChatAppRunner(AppRunner):
# handle invoke result # handle invoke result
self._handle_invoke_result( self._handle_invoke_result(
invoke_result=invoke_result, invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
queue_manager=queue_manager,
stream=application_generate_entity.stream
) )
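The two changes that dominate these hunks are ruff-format's double-quote normalization and its "magic trailing comma": a trailing comma after the last argument keeps a call expanded one-argument-per-line, while its absence lets a short call collapse onto one line. A minimal, self-contained illustration (the function here is invented for the example):

    def greet(name: str, punctuation: str = "!") -> str:
        return f"Hello, {name}{punctuation}"

    # No trailing comma and the call fits the line limit, so ruff collapses it:
    short = greet(name="world", punctuation="?")

    # A trailing comma after the last argument pins the expanded layout,
    # which is why calls like direct_output(...) above stay multi-line:
    long = greet(
        name="world",
        punctuation="!",
    )

    print(short, long)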
View File
@ -23,15 +23,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
:return: :return:
""" """
response = { response = {
'event': 'message', "event": "message",
'task_id': blocking_response.task_id, "task_id": blocking_response.task_id,
'id': blocking_response.data.id, "id": blocking_response.data.id,
'message_id': blocking_response.data.message_id, "message_id": blocking_response.data.message_id,
'conversation_id': blocking_response.data.conversation_id, "conversation_id": blocking_response.data.conversation_id,
'mode': blocking_response.data.mode, "mode": blocking_response.data.mode,
'answer': blocking_response.data.answer, "answer": blocking_response.data.answer,
'metadata': blocking_response.data.metadata, "metadata": blocking_response.data.metadata,
'created_at': blocking_response.data.created_at "created_at": blocking_response.data.created_at,
} }
return response return response
@ -45,14 +45,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
""" """
response = cls.convert_blocking_full_response(blocking_response) response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get('metadata', {}) metadata = response.get("metadata", {})
response['metadata'] = cls._get_simple_metadata(metadata) response["metadata"] = cls._get_simple_metadata(metadata)
return response return response
@classmethod @classmethod
def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ def convert_stream_full_response(
-> Generator[str, None, None]: cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[str, None, None]:
""" """
Convert stream full response. Convert stream full response.
:param stream_response: stream response :param stream_response: stream response
@ -63,14 +64,14 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse): if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping' yield "ping"
continue continue
response_chunk = { response_chunk = {
'event': sub_stream_response.event.value, "event": sub_stream_response.event.value,
'conversation_id': chunk.conversation_id, "conversation_id": chunk.conversation_id,
'message_id': chunk.message_id, "message_id": chunk.message_id,
'created_at': chunk.created_at "created_at": chunk.created_at,
} }
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):
@ -81,8 +82,9 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk) yield json.dumps(response_chunk)
@classmethod @classmethod
def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ def convert_stream_simple_response(
-> Generator[str, None, None]: cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.
:param stream_response: stream response :param stream_response: stream response
@ -93,20 +95,20 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse): if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping' yield "ping"
continue continue
response_chunk = { response_chunk = {
'event': sub_stream_response.event.value, "event": sub_stream_response.event.value,
'conversation_id': chunk.conversation_id, "conversation_id": chunk.conversation_id,
'message_id': chunk.message_id, "message_id": chunk.message_id,
'created_at': chunk.created_at "created_at": chunk.created_at,
} }
if isinstance(sub_stream_response, MessageEndStreamResponse): if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict() sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {}) metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict) response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
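The converter above follows one streaming pattern throughout: ping frames are forwarded as a bare "ping" token, and every other chunk is serialized as a JSON envelope of event metadata. A condensed sketch with stand-in types (PingLike and ChunkLike are illustrative, not the project's classes):

    import json
    from collections.abc import Generator
    from dataclasses import dataclass

    @dataclass
    class PingLike:
        pass

    @dataclass
    class ChunkLike:
        event: str
        message_id: str
        created_at: int

    def convert_stream(chunks: list) -> Generator[str, None, None]:
        for chunk in chunks:
            if isinstance(chunk, PingLike):
                yield "ping"  # keep-alive frames pass through as a bare token
                continue
            yield json.dumps(
                {
                    "event": chunk.event,
                    "message_id": chunk.message_id,
                    "created_at": chunk.created_at,
                }
            )

    for line in convert_stream([PingLike(), ChunkLike("message", "m-1", 1700000000)]):
        print(line)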
View File
@ -17,14 +17,15 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
""" """
Completion App Config Entity. Completion App Config Entity.
""" """
pass pass
class CompletionAppConfigManager(BaseAppConfigManager): class CompletionAppConfigManager(BaseAppConfigManager):
@classmethod @classmethod
def get_app_config(cls, app_model: App, def get_app_config(
app_model_config: AppModelConfig, cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None
override_config_dict: Optional[dict] = None) -> CompletionAppConfig: ) -> CompletionAppConfig:
""" """
Convert app model config to completion app config Convert app model config to completion app config
:param app_model: app model :param app_model: app model
@ -51,19 +52,11 @@ class CompletionAppConfigManager(BaseAppConfigManager):
app_model_config_from=config_from, app_model_config_from=config_from,
app_model_config_id=app_model_config.id, app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict, app_model_config_dict=config_dict,
model=ModelConfigManager.convert( model=ModelConfigManager.convert(config=config_dict),
config=config_dict prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
), sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert( dataset=DatasetConfigManager.convert(config=config_dict),
config=config_dict additional_features=cls.convert_features(config_dict, app_mode),
),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=config_dict
),
dataset=DatasetConfigManager.convert(
config=config_dict
),
additional_features=cls.convert_features(config_dict, app_mode)
) )
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@ -101,8 +94,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# dataset_query_variable # dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
config) tenant_id, app_mode, config
)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# text_to_speech # text_to_speech
@ -114,8 +108,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# moderation validation # moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
config) tenant_id, config
)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys)) related_config_keys = list(set(related_config_keys))
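config_validate above chains validators that each return the (possibly amended) config plus the config keys they own, and the caller accumulates and de-duplicates those keys. A minimal sketch of that accumulation pattern, with invented validators standing in for ModelConfigManager and friends:

    def validate_model(config: dict) -> tuple[dict, list[str]]:
        config.setdefault("model", {"name": "default"})
        return config, ["model"]

    def validate_moderation(config: dict) -> tuple[dict, list[str]]:
        config.setdefault("sensitive_word_avoidance", {"enabled": False})
        return config, ["sensitive_word_avoidance"]

    def config_validate(config: dict) -> dict:
        related_config_keys: list[str] = []
        for validator in (validate_model, validate_moderation):
            config, keys = validator(config)
            related_config_keys.extend(keys)
        # keep only keys that some validator recognised, dropping the rest
        return {k: config[k] for k in set(related_config_keys)}

    print(config_validate({"model": {"name": "gpt-x"}, "junk": 1}))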
View File
@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
class CompletionAppGenerator(MessageBasedAppGenerator): class CompletionAppGenerator(MessageBasedAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: dict,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
@ -41,19 +42,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: dict,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: Literal[False] = False, stream: Literal[False] = False,
) -> dict: ... ) -> dict: ...
def generate(self, app_model: App, def generate(
user: Union[Account, EndUser], self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
args: Any, ) -> Union[dict, Generator[str, None, None]]:
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator[str, None, None]]:
""" """
Generate App response. Generate App response.
@ -63,12 +62,12 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param stream: is stream :param stream: is stream
""" """
query = args['query'] query = args["query"]
if not isinstance(query, str): if not isinstance(query, str):
raise ValueError('query must be a string') raise ValueError("query must be a string")
query = query.replace('\x00', '') query = query.replace("\x00", "")
inputs = args['inputs'] inputs = args["inputs"]
extras = {} extras = {}
@ -76,41 +75,31 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
conversation = None conversation = None
# get app model config # get app model config
app_model_config = self._get_app_model_config( app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
app_model=app_model,
conversation=conversation
)
# validate override model config # validate override model config
override_model_config_dict = None override_model_config_dict = None
if args.get('model_config'): if args.get("model_config"):
if invoke_from != InvokeFrom.DEBUGGER: if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config') raise ValueError("Only in App debug mode can override model config")
# validate config # validate config
override_model_config_dict = CompletionAppConfigManager.config_validate( override_model_config_dict = CompletionAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id, config=args.get("model_config")
config=args.get('model_config')
) )
# parse files # parse files
files = args['files'] if args.get('files') else [] files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config: if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg( file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
files,
file_extra_config,
user
)
else: else:
file_objs = [] file_objs = []
# convert to app config # convert to app config
app_config = CompletionAppConfigManager.get_app_config( app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model, app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
app_model_config=app_model_config,
override_config_dict=override_model_config_dict
) )
# get tracing instance # get tracing instance
@ -128,14 +117,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream=stream, stream=stream,
invoke_from=invoke_from, invoke_from=invoke_from,
extras=extras, extras=extras,
trace_manager=trace_manager trace_manager=trace_manager,
) )
# init generate records # init generate records
( (conversation, message) = self._init_generate_records(application_generate_entity)
conversation,
message
) = self._init_generate_records(application_generate_entity)
# init queue manager # init queue manager
queue_manager = MessageBasedAppQueueManager( queue_manager = MessageBasedAppQueueManager(
@ -144,16 +130,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id, conversation_id=conversation.id,
app_mode=conversation.mode, app_mode=conversation.mode,
message_id=message.id message_id=message.id,
) )
# new thread # new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={ worker_thread = threading.Thread(
'flask_app': current_app._get_current_object(), target=self._generate_worker,
'application_generate_entity': application_generate_entity, kwargs={
'queue_manager': queue_manager, "flask_app": current_app._get_current_object(),
'message_id': message.id, "application_generate_entity": application_generate_entity,
}) "queue_manager": queue_manager,
"message_id": message.id,
},
)
worker_thread.start() worker_thread.start()
@ -167,15 +156,15 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream=stream, stream=stream,
) )
return CompletionAppGenerateResponseConverter.convert( return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
response=response,
invoke_from=invoke_from
)
def _generate_worker(self, flask_app: Flask, def _generate_worker(
application_generate_entity: CompletionAppGenerateEntity, self,
queue_manager: AppQueueManager, flask_app: Flask,
message_id: str) -> None: application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager,
message_id: str,
) -> None:
""" """
Generate worker in a new thread. Generate worker in a new thread.
:param flask_app: Flask app :param flask_app: Flask app
@ -194,20 +183,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
runner.run( runner.run(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
message=message message=message,
) )
except GenerateTaskStoppedException: except GenerateTaskStoppedException:
pass pass
except InvokeAuthorizationError: except InvokeAuthorizationError:
queue_manager.publish_error( queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'), InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
PublishFrom.APPLICATION_MANAGER
) )
except ValidationError as e: except ValidationError as e:
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e: except Exception as e:
@ -216,12 +204,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
finally: finally:
db.session.close() db.session.close()
def generate_more_like_this(self, app_model: App, def generate_more_like_this(
message_id: str, self,
user: Union[Account, EndUser], app_model: App,
invoke_from: InvokeFrom, message_id: str,
stream: bool = True) \ user: Union[Account, EndUser],
-> Union[dict, Generator[str, None, None]]: invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[str, None, None]]:
""" """
Generate App response. Generate App response.
@ -231,13 +221,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param stream: is stream :param stream: is stream
""" """
message = db.session.query(Message).filter( message = (
Message.id == message_id, db.session.query(Message)
Message.app_id == app_model.id, .filter(
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), Message.id == message_id,
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), Message.app_id == app_model.id,
Message.from_account_id == (user.id if isinstance(user, Account) else None), Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
).first() Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
.first()
)
if not message: if not message:
raise MessageNotExistsError() raise MessageNotExistsError()
@ -250,29 +244,23 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
app_model_config = message.app_model_config app_model_config = message.app_model_config
override_model_config_dict = app_model_config.to_dict() override_model_config_dict = app_model_config.to_dict()
model_dict = override_model_config_dict['model'] model_dict = override_model_config_dict["model"]
completion_params = model_dict.get('completion_params') completion_params = model_dict.get("completion_params")
completion_params['temperature'] = 0.9 completion_params["temperature"] = 0.9
model_dict['completion_params'] = completion_params model_dict["completion_params"] = completion_params
override_model_config_dict['model'] = model_dict override_model_config_dict["model"] = model_dict
# parse files # parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config: if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg( file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user)
message.files,
file_extra_config,
user
)
else: else:
file_objs = [] file_objs = []
# convert to app config # convert to app config
app_config = CompletionAppConfigManager.get_app_config( app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model, app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
app_model_config=app_model_config,
override_config_dict=override_model_config_dict
) )
# init application generate entity # init application generate entity
@ -286,14 +274,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user_id=user.id, user_id=user.id,
stream=stream, stream=stream,
invoke_from=invoke_from, invoke_from=invoke_from,
extras={} extras={},
) )
# init generate records # init generate records
( (conversation, message) = self._init_generate_records(application_generate_entity)
conversation,
message
) = self._init_generate_records(application_generate_entity)
# init queue manager # init queue manager
queue_manager = MessageBasedAppQueueManager( queue_manager = MessageBasedAppQueueManager(
@ -302,16 +287,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id, conversation_id=conversation.id,
app_mode=conversation.mode, app_mode=conversation.mode,
message_id=message.id message_id=message.id,
) )
# new thread # new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={ worker_thread = threading.Thread(
'flask_app': current_app._get_current_object(), target=self._generate_worker,
'application_generate_entity': application_generate_entity, kwargs={
'queue_manager': queue_manager, "flask_app": current_app._get_current_object(),
'message_id': message.id, "application_generate_entity": application_generate_entity,
}) "queue_manager": queue_manager,
"message_id": message.id,
},
)
worker_thread.start() worker_thread.start()
@ -325,7 +313,4 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream=stream, stream=stream,
) )
return CompletionAppGenerateResponseConverter.convert( return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
response=response,
invoke_from=invoke_from
)
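The paired @overload stubs on generate() exist so type checkers can narrow the return type from the stream flag: Literal[True] maps to a generator, Literal[False] to a dict, and the single real implementation serves both. A self-contained sketch of the same technique (the toy body is not the project's logic):

    from collections.abc import Generator
    from typing import Literal, Union, overload

    @overload
    def generate(prompt: str, stream: Literal[True] = True) -> Generator[str, None, None]: ...
    @overload
    def generate(prompt: str, stream: Literal[False] = False) -> dict: ...

    def generate(prompt: str, stream: bool = True) -> Union[dict, Generator[str, None, None]]:
        if stream:
            return (token for token in prompt.split())  # checker sees the Generator overload
        return {"answer": prompt}  # checker sees the dict overload

    print(list(generate("a b c", stream=True)))
    print(generate("a b c", stream=False))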
View File
@ -22,9 +22,9 @@ class CompletionAppRunner(AppRunner):
Completion Application Runner Completion Application Runner
""" """
def run(self, application_generate_entity: CompletionAppGenerateEntity, def run(
queue_manager: AppQueueManager, self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
message: Message) -> None: ) -> None:
""" """
Run application Run application
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@ -54,7 +54,7 @@ class CompletionAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query query=query,
) )
# organize all inputs and template to prompt messages # organize all inputs and template to prompt messages
@ -65,7 +65,7 @@ class CompletionAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query query=query,
) )
# moderation # moderation
@ -77,7 +77,7 @@ class CompletionAppRunner(AppRunner):
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,
inputs=inputs, inputs=inputs,
query=query, query=query,
message_id=message.id message_id=message.id,
) )
except ModerationException as e: except ModerationException as e:
self.direct_output( self.direct_output(
@ -85,7 +85,7 @@ class CompletionAppRunner(AppRunner):
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
text=str(e), text=str(e),
stream=application_generate_entity.stream stream=application_generate_entity.stream,
) )
return return
@ -97,7 +97,7 @@ class CompletionAppRunner(AppRunner):
app_id=app_record.id, app_id=app_record.id,
external_data_tools=external_data_tools, external_data_tools=external_data_tools,
inputs=inputs, inputs=inputs,
query=query query=query,
) )
# get context from datasets # get context from datasets
@ -108,7 +108,7 @@ class CompletionAppRunner(AppRunner):
app_record.id, app_record.id,
message.id, message.id,
application_generate_entity.user_id, application_generate_entity.user_id,
application_generate_entity.invoke_from application_generate_entity.invoke_from,
) )
dataset_config = app_config.dataset dataset_config = app_config.dataset
@ -126,7 +126,7 @@ class CompletionAppRunner(AppRunner):
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_config.additional_features.show_retrieve_source, show_retrieve_source=app_config.additional_features.show_retrieve_source,
hit_callback=hit_callback, hit_callback=hit_callback,
message_id=message.id message_id=message.id,
) )
# reorganize all inputs and template to prompt messages # reorganize all inputs and template to prompt messages
@ -139,29 +139,26 @@ class CompletionAppRunner(AppRunner):
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query, query=query,
context=context context=context,
) )
# check hosting moderation # check hosting moderation
hosting_moderation_result = self.check_hosting_moderation( hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
prompt_messages=prompt_messages prompt_messages=prompt_messages,
) )
if hosting_moderation_result: if hosting_moderation_result:
return return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens( self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
model_config=application_generate_entity.model_conf,
prompt_messages=prompt_messages
)
# Invoke model # Invoke model
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model model=application_generate_entity.model_conf.model,
) )
db.session.close() db.session.close()
@ -176,8 +173,5 @@ class CompletionAppRunner(AppRunner):
# handle invoke result # handle invoke result
self._handle_invoke_result( self._handle_invoke_result(
invoke_result=invoke_result, invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
queue_manager=queue_manager,
stream=application_generate_entity.stream
) )
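Both runners share the moderation short-circuit visible above: if input moderation raises, the sanitised message is emitted via direct_output and the model is never invoked. A reduced sketch of that control flow (ModerationError and the function bodies are stand-ins):

    class ModerationError(Exception):
        pass

    def moderate(query: str) -> None:
        if "forbidden" in query:
            raise ModerationError("Your query violates the usage policy.")

    def run(query: str) -> str:
        try:
            moderate(query)
        except ModerationError as e:
            return str(e)  # direct output; model invocation is skipped
        return f"model answer for: {query}"

    print(run("hello"))
    print(run("forbidden topic"))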
View File
@ -23,14 +23,14 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
:return: :return:
""" """
response = { response = {
'event': 'message', "event": "message",
'task_id': blocking_response.task_id, "task_id": blocking_response.task_id,
'id': blocking_response.data.id, "id": blocking_response.data.id,
'message_id': blocking_response.data.message_id, "message_id": blocking_response.data.message_id,
'mode': blocking_response.data.mode, "mode": blocking_response.data.mode,
'answer': blocking_response.data.answer, "answer": blocking_response.data.answer,
'metadata': blocking_response.data.metadata, "metadata": blocking_response.data.metadata,
'created_at': blocking_response.data.created_at "created_at": blocking_response.data.created_at,
} }
return response return response
@ -44,14 +44,15 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
""" """
response = cls.convert_blocking_full_response(blocking_response) response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get('metadata', {}) metadata = response.get("metadata", {})
response['metadata'] = cls._get_simple_metadata(metadata) response["metadata"] = cls._get_simple_metadata(metadata)
return response return response
@classmethod @classmethod
def convert_stream_full_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \ def convert_stream_full_response(
-> Generator[str, None, None]: cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
) -> Generator[str, None, None]:
""" """
Convert stream full response. Convert stream full response.
:param stream_response: stream response :param stream_response: stream response
@ -62,13 +63,13 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse): if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping' yield "ping"
continue continue
response_chunk = { response_chunk = {
'event': sub_stream_response.event.value, "event": sub_stream_response.event.value,
'message_id': chunk.message_id, "message_id": chunk.message_id,
'created_at': chunk.created_at "created_at": chunk.created_at,
} }
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):
@ -79,8 +80,9 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk) yield json.dumps(response_chunk)
@classmethod @classmethod
def convert_stream_simple_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \ def convert_stream_simple_response(
-> Generator[str, None, None]: cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
) -> Generator[str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.
:param stream_response: stream response :param stream_response: stream response
@ -91,19 +93,19 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse): if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping' yield "ping"
continue continue
response_chunk = { response_chunk = {
'event': sub_stream_response.event.value, "event": sub_stream_response.event.value,
'message_id': chunk.message_id, "message_id": chunk.message_id,
'created_at': chunk.created_at "created_at": chunk.created_at,
} }
if isinstance(sub_stream_response, MessageEndStreamResponse): if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict() sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {}) metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict) response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
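convert_blocking_simple_response above reuses the full converter and then shrinks metadata through _get_simple_metadata. A sketch of that reuse-then-filter shape; the whitelist below is illustrative, as the real key set lives in the base converter:

    def get_simple_metadata(metadata: dict) -> dict:
        allowed = {"usage", "retriever_resources"}  # illustrative whitelist
        return {k: v for k, v in metadata.items() if k in allowed}

    def convert_blocking_simple_response(full_response: dict) -> dict:
        response = dict(full_response)
        response["metadata"] = get_simple_metadata(response.get("metadata", {}))
        return response

    print(convert_blocking_simple_response({"answer": "hi", "metadata": {"usage": {}, "internal": 1}}))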
View File
@ -35,23 +35,23 @@ logger = logging.getLogger(__name__)
class MessageBasedAppGenerator(BaseAppGenerator): class MessageBasedAppGenerator(BaseAppGenerator):
def _handle_response( def _handle_response(
self, application_generate_entity: Union[ self,
ChatAppGenerateEntity, application_generate_entity: Union[
CompletionAppGenerateEntity, ChatAppGenerateEntity,
AgentChatAppGenerateEntity, CompletionAppGenerateEntity,
AdvancedChatAppGenerateEntity AgentChatAppGenerateEntity,
], AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager, ],
conversation: Conversation, queue_manager: AppQueueManager,
message: Message, conversation: Conversation,
user: Union[Account, EndUser], message: Message,
stream: bool = False, user: Union[Account, EndUser],
stream: bool = False,
) -> Union[ ) -> Union[
ChatbotAppBlockingResponse, ChatbotAppBlockingResponse,
CompletionAppBlockingResponse, CompletionAppBlockingResponse,
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None] Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
]: ]:
""" """
Handle response. Handle response.
@ -70,7 +70,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
conversation=conversation, conversation=conversation,
message=message, message=message,
user=user, user=user,
stream=stream stream=stream,
) )
try: try:
@ -82,12 +82,13 @@ class MessageBasedAppGenerator(BaseAppGenerator):
logger.exception(e) logger.exception(e)
raise e raise e
def _get_conversation_by_user(self, app_model: App, conversation_id: str, def _get_conversation_by_user(
user: Union[Account, EndUser]) -> Conversation: self, app_model: App, conversation_id: str, user: Union[Account, EndUser]
) -> Conversation:
conversation_filter = [ conversation_filter = [
Conversation.id == conversation_id, Conversation.id == conversation_id,
Conversation.app_id == app_model.id, Conversation.app_id == app_model.id,
Conversation.status == 'normal' Conversation.status == "normal",
] ]
if isinstance(user, Account): if isinstance(user, Account):
@ -100,19 +101,18 @@ class MessageBasedAppGenerator(BaseAppGenerator):
if not conversation: if not conversation:
raise ConversationNotExistsError() raise ConversationNotExistsError()
if conversation.status != 'normal': if conversation.status != "normal":
raise ConversationCompletedError() raise ConversationCompletedError()
return conversation return conversation
def _get_app_model_config(self, app_model: App, def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
conversation: Optional[Conversation] = None) \
-> AppModelConfig:
if conversation: if conversation:
app_model_config = db.session.query(AppModelConfig).filter( app_model_config = (
AppModelConfig.id == conversation.app_model_config_id, db.session.query(AppModelConfig)
AppModelConfig.app_id == app_model.id .filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
).first() .first()
)
if not app_model_config: if not app_model_config:
raise AppModelConfigBrokenError() raise AppModelConfigBrokenError()
@ -127,15 +127,16 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return app_model_config return app_model_config
def _init_generate_records(self, def _init_generate_records(
application_generate_entity: Union[ self,
ChatAppGenerateEntity, application_generate_entity: Union[
CompletionAppGenerateEntity, ChatAppGenerateEntity,
AgentChatAppGenerateEntity, CompletionAppGenerateEntity,
AdvancedChatAppGenerateEntity AgentChatAppGenerateEntity,
], AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None) \ ],
-> tuple[Conversation, Message]: conversation: Optional[Conversation] = None,
) -> tuple[Conversation, Message]:
""" """
Initialize generate records Initialize generate records
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@ -148,10 +149,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
end_user_id = None end_user_id = None
account_id = None account_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
from_source = 'api' from_source = "api"
end_user_id = application_generate_entity.user_id end_user_id = application_generate_entity.user_id
else: else:
from_source = 'console' from_source = "console"
account_id = application_generate_entity.user_id account_id = application_generate_entity.user_id
if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity): if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):
@ -164,8 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
model_provider = application_generate_entity.model_conf.provider model_provider = application_generate_entity.model_conf.provider
model_id = application_generate_entity.model_conf.model model_id = application_generate_entity.model_conf.model
override_model_configs = None override_model_configs = None
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \ if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [
and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]: AppMode.AGENT_CHAT,
AppMode.CHAT,
AppMode.COMPLETION,
]:
override_model_configs = app_config.app_model_config_dict override_model_configs = app_config.app_model_config_dict
# get conversation introduction # get conversation introduction
@ -179,12 +183,12 @@ class MessageBasedAppGenerator(BaseAppGenerator):
model_id=model_id, model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=app_config.app_mode.value, mode=app_config.app_mode.value,
name='New conversation', name="New conversation",
inputs=application_generate_entity.inputs, inputs=application_generate_entity.inputs,
introduction=introduction, introduction=introduction,
system_instruction="", system_instruction="",
system_instruction_tokens=0, system_instruction_tokens=0,
status='normal', status="normal",
invoke_from=application_generate_entity.invoke_from.value, invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source, from_source=from_source,
from_end_user_id=end_user_id, from_end_user_id=end_user_id,
@ -216,11 +220,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
answer_price_unit=0, answer_price_unit=0,
provider_response_latency=0, provider_response_latency=0,
total_price=0, total_price=0,
currency='USD', currency="USD",
invoke_from=application_generate_entity.invoke_from.value, invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source, from_source=from_source,
from_end_user_id=end_user_id, from_end_user_id=end_user_id,
from_account_id=account_id from_account_id=account_id,
) )
db.session.add(message) db.session.add(message)
@ -232,10 +236,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
message_id=message.id, message_id=message.id,
type=file.type.value, type=file.type.value,
transfer_method=file.transfer_method.value, transfer_method=file.transfer_method.value,
belongs_to='user', belongs_to="user",
url=file.url, url=file.url,
upload_file_id=file.related_id, upload_file_id=file.related_id,
created_by_role=('account' if account_id else 'end_user'), created_by_role=("account" if account_id else "end_user"),
created_by=account_id or end_user_id, created_by=account_id or end_user_id,
) )
db.session.add(message_file) db.session.add(message_file)
@ -269,11 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id :param conversation_id: conversation id
:return: conversation :return: conversation
""" """
conversation = ( conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
db.session.query(Conversation)
.filter(Conversation.id == conversation_id)
.first()
)
if not conversation: if not conversation:
raise ConversationNotExistsError() raise ConversationNotExistsError()
@ -286,10 +286,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id :param message_id: message id
:return: message :return: message
""" """
message = ( message = db.session.query(Message).filter(Message.id == message_id).first()
db.session.query(Message)
.filter(Message.id == message_id)
.first()
)
return message return message
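_init_generate_records attributes a run differently depending on where it was invoked from: web-app and service-api traffic is recorded against an end user, everything else against a console account. A condensed sketch of that branch (the enum mirrors the values used above but is redeclared here for self-containment):

    from enum import Enum

    class InvokeFrom(Enum):
        WEB_APP = "web-app"
        SERVICE_API = "service-api"
        DEBUGGER = "debugger"

    def attribute(invoke_from: InvokeFrom, user_id: str) -> dict:
        if invoke_from in (InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API):
            return {"from_source": "api", "from_end_user_id": user_id, "from_account_id": None}
        return {"from_source": "console", "from_end_user_id": None, "from_account_id": user_id}

    print(attribute(InvokeFrom.DEBUGGER, "u-1"))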
View File
@ -12,12 +12,9 @@ from core.app.entities.queue_entities import (
class MessageBasedAppQueueManager(AppQueueManager): class MessageBasedAppQueueManager(AppQueueManager):
def __init__(self, task_id: str, def __init__(
user_id: str, self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str
invoke_from: InvokeFrom, ) -> None:
conversation_id: str,
app_mode: str,
message_id: str) -> None:
super().__init__(task_id, user_id, invoke_from) super().__init__(task_id, user_id, invoke_from)
self._conversation_id = str(conversation_id) self._conversation_id = str(conversation_id)
@ -30,7 +27,7 @@ class MessageBasedAppQueueManager(AppQueueManager):
message_id=self._message_id, message_id=self._message_id,
conversation_id=self._conversation_id, conversation_id=self._conversation_id,
app_mode=self._app_mode, app_mode=self._app_mode,
event=event event=event,
) )
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
@ -45,17 +42,15 @@ class MessageBasedAppQueueManager(AppQueueManager):
message_id=self._message_id, message_id=self._message_id,
conversation_id=self._conversation_id, conversation_id=self._conversation_id,
app_mode=self._app_mode, app_mode=self._app_mode,
event=event event=event,
) )
self._q.put(message) self._q.put(message)
if isinstance(event, QueueStopEvent if isinstance(
| QueueErrorEvent event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent | QueueAdvancedChatMessageEndEvent
| QueueMessageEndEvent ):
| QueueAdvancedChatMessageEndEvent):
self.stop_listen() self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedException() raise GenerateTaskStoppedException()
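The isinstance check above is only reflowed by ruff, not changed: since Python 3.10, A | B builds a types.UnionType that isinstance() accepts directly. A tiny runnable demonstration with stand-in event classes:

    class QueueStopEvent: ...
    class QueueErrorEvent: ...
    class QueueMessageEndEvent: ...

    def should_stop_listen(event: object) -> bool:
        # X | Y inside isinstance() behaves like the tuple (X, Y)
        return isinstance(event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent)

    print(should_stop_listen(QueueStopEvent()))  # True
    print(should_stop_listen("not an event"))    # False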
View File
@ -12,6 +12,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig):
""" """
Workflow App Config Entity. Workflow App Config Entity.
""" """
pass pass
@ -26,13 +27,9 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
app_id=app_model.id, app_id=app_model.id,
app_mode=app_mode, app_mode=app_mode,
workflow_id=workflow.id, workflow_id=workflow.id,
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
config=features_dict variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
), additional_features=cls.convert_features(features_dict, app_mode),
variables=WorkflowVariablesConfigManager.convert(
workflow=workflow
),
additional_features=cls.convert_features(features_dict, app_mode)
) )
return app_config return app_config
@ -50,8 +47,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
# file upload validation # file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
config=config, config=config, is_vision=False
is_vision=False
) )
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
@ -61,9 +57,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
# moderation validation # moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id, tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
config=config,
only_structure_validate=only_structure_validate
) )
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
View File
@ -34,26 +34,28 @@ logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator): class WorkflowAppGenerator(BaseAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: dict,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: Literal[True] = True, stream: Literal[True] = True,
call_depth: int = 0, call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None workflow_thread_pool_id: Optional[str] = None,
) -> Generator[str, None, None]: ... ) -> Generator[str, None, None]: ...
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: dict,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: Literal[False] = False, stream: Literal[False] = False,
call_depth: int = 0, call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None workflow_thread_pool_id: Optional[str] = None,
) -> dict: ... ) -> dict: ...
def generate( def generate(
@ -65,7 +67,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: bool = True, stream: bool = True,
call_depth: int = 0, call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None workflow_thread_pool_id: Optional[str] = None,
): ):
""" """
Generate App response. Generate App response.
@ -79,26 +81,19 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param call_depth: call depth :param call_depth: call depth
:param workflow_thread_pool_id: workflow thread pool id :param workflow_thread_pool_id: workflow thread pool id
""" """
inputs = args['inputs'] inputs = args["inputs"]
# parse files # parse files
files = args['files'] if args.get('files') else [] files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config: if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg( file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
files,
file_extra_config,
user
)
else: else:
file_objs = [] file_objs = []
# convert to app config # convert to app config
app_config = WorkflowAppConfigManager.get_app_config( app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
app_model=app_model,
workflow=workflow
)
# get tracing instance # get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id user_id = user.id if isinstance(user, Account) else user.session_id
@ -114,7 +109,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
stream=stream, stream=stream,
invoke_from=invoke_from, invoke_from=invoke_from,
call_depth=call_depth, call_depth=call_depth,
trace_manager=trace_manager trace_manager=trace_manager,
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@ -125,18 +120,19 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
invoke_from=invoke_from, invoke_from=invoke_from,
stream=stream, stream=stream,
workflow_thread_pool_id=workflow_thread_pool_id workflow_thread_pool_id=workflow_thread_pool_id,
) )
def _generate( def _generate(
self, *, self,
*,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: bool = True, stream: bool = True,
workflow_thread_pool_id: Optional[str] = None workflow_thread_pool_id: Optional[str] = None,
) -> dict[str, Any] | Generator[str, None, None]: ) -> dict[str, Any] | Generator[str, None, None]:
""" """
Generate App response. Generate App response.
@ -154,17 +150,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
task_id=application_generate_entity.task_id, task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id, user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
app_mode=app_model.mode app_mode=app_model.mode,
) )
# new thread # new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={ worker_thread = threading.Thread(
'flask_app': current_app._get_current_object(), # type: ignore target=self._generate_worker,
'application_generate_entity': application_generate_entity, kwargs={
'queue_manager': queue_manager, "flask_app": current_app._get_current_object(), # type: ignore
'context': contextvars.copy_context(), "application_generate_entity": application_generate_entity,
'workflow_thread_pool_id': workflow_thread_pool_id "queue_manager": queue_manager,
}) "context": contextvars.copy_context(),
"workflow_thread_pool_id": workflow_thread_pool_id,
},
)
worker_thread.start() worker_thread.start()
@ -177,17 +176,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
stream=stream, stream=stream,
) )
return WorkflowAppGenerateResponseConverter.convert( return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
response=response,
invoke_from=invoke_from
)
def single_iteration_generate(self, app_model: App, def single_iteration_generate(
workflow: Workflow, self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
node_id: str, ) -> dict[str, Any] | Generator[str, Any, None]:
user: Account,
args: dict,
stream: bool = True) -> dict[str, Any] | Generator[str, Any, None]:
""" """
Generate App response. Generate App response.
@ -199,16 +192,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param stream: is stream :param stream: is stream
""" """
if not node_id: if not node_id:
raise ValueError('node_id is required') raise ValueError("node_id is required")
if args.get('inputs') is None: if args.get("inputs") is None:
raise ValueError('inputs is required') raise ValueError("inputs is required")
# convert to app config # convert to app config
app_config = WorkflowAppConfigManager.get_app_config( app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
app_model=app_model,
workflow=workflow
)
# init application generate entity # init application generate entity
application_generate_entity = WorkflowAppGenerateEntity( application_generate_entity = WorkflowAppGenerateEntity(
@ -219,13 +209,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
user_id=user.id, user_id=user.id,
stream=stream, stream=stream,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
extras={ extras={"auto_generate_conversation_name": False},
"auto_generate_conversation_name": False
},
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, node_id=node_id, inputs=args["inputs"]
inputs=args['inputs'] ),
)
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@ -235,14 +222,17 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user, user=user,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
stream=stream stream=stream,
) )
def _generate_worker(self, flask_app: Flask, def _generate_worker(
application_generate_entity: WorkflowAppGenerateEntity, self,
queue_manager: AppQueueManager, flask_app: Flask,
context: contextvars.Context, application_generate_entity: WorkflowAppGenerateEntity,
workflow_thread_pool_id: Optional[str] = None) -> None: queue_manager: AppQueueManager,
context: contextvars.Context,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
""" """
Generate worker in a new thread. Generate worker in a new thread.
:param flask_app: Flask app :param flask_app: Flask app
@ -259,7 +249,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
runner = WorkflowAppRunner( runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id workflow_thread_pool_id=workflow_thread_pool_id,
) )
runner.run() runner.run()
@ -267,14 +257,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
pass pass
except InvokeAuthorizationError: except InvokeAuthorizationError:
queue_manager.publish_error( queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'), InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
PublishFrom.APPLICATION_MANAGER
) )
except ValidationError as e: except ValidationError as e:
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true': if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == "true":
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e: except Exception as e:
@ -283,14 +272,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
finally: finally:
db.session.close() db.session.close()
def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity, def _handle_response(
workflow: Workflow, self,
queue_manager: AppQueueManager, application_generate_entity: WorkflowAppGenerateEntity,
user: Union[Account, EndUser], workflow: Workflow,
stream: bool = False) -> Union[ queue_manager: AppQueueManager,
WorkflowAppBlockingResponse, user: Union[Account, EndUser],
Generator[WorkflowAppStreamResponse, None, None] stream: bool = False,
]: ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
""" """
Handle response. Handle response.
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@ -306,7 +295,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow, workflow=workflow,
queue_manager=queue_manager, queue_manager=queue_manager,
user=user, user=user,
stream=stream stream=stream,
) )
try: try:
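The worker-thread setup above snapshots the caller's contextvars with contextvars.copy_context() and hands the snapshot to the thread, because a new thread starts with an empty context and would otherwise lose values such as the tenant id. A minimal sketch of that hop (variable names are stand-ins):

    import contextvars
    import threading

    tenant_id: contextvars.ContextVar[str] = contextvars.ContextVar("tenant_id")

    def worker(context: contextvars.Context) -> None:
        for var, value in context.items():
            var.set(value)  # re-apply the parent thread's snapshot here
        print("worker sees tenant:", tenant_id.get())

    tenant_id.set("tenant-42")
    thread = threading.Thread(target=worker, kwargs={"context": contextvars.copy_context()})
    thread.start()
    thread.join()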
View File
@ -12,10 +12,7 @@ from core.app.entities.queue_entities import (
class WorkflowAppQueueManager(AppQueueManager): class WorkflowAppQueueManager(AppQueueManager):
def __init__(self, task_id: str, def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
user_id: str,
invoke_from: InvokeFrom,
app_mode: str) -> None:
super().__init__(task_id, user_id, invoke_from) super().__init__(task_id, user_id, invoke_from)
self._app_mode = app_mode self._app_mode = app_mode
@ -27,19 +24,18 @@ class WorkflowAppQueueManager(AppQueueManager):
:param pub_from: :param pub_from:
:return: :return:
""" """
message = WorkflowQueueMessage( message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
task_id=self._task_id,
app_mode=self._app_mode,
event=event
)
self._q.put(message) self._q.put(message)
if isinstance(event, QueueStopEvent if isinstance(
| QueueErrorEvent event,
| QueueMessageEndEvent QueueStopEvent
| QueueWorkflowSucceededEvent | QueueErrorEvent
| QueueWorkflowFailedEvent): | QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent,
):
self.stop_listen() self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():

View File

@ -28,10 +28,10 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
""" """
def __init__( def __init__(
self, self,
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
workflow_thread_pool_id: Optional[str] = None workflow_thread_pool_id: Optional[str] = None,
) -> None: ) -> None:
""" """
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@ -62,16 +62,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_record = db.session.query(App).filter(App.id == app_config.app_id).first() app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record: if not app_record:
raise ValueError('App not found') raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow: if not workflow:
raise ValueError('Workflow not initialized') raise ValueError("Workflow not initialized")
db.session.close() db.session.close()
workflow_callbacks: list[WorkflowCallback] = [] workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): if bool(os.environ.get("DEBUG", "False").lower() == "true"):
workflow_callbacks.append(WorkflowLoggingCallback()) workflow_callbacks.append(WorkflowLoggingCallback())
# if only single iteration run is requested # if only single iteration run is requested
@ -80,10 +80,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow, workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id, node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs user_inputs=self.application_generate_entity.single_iteration_run.inputs,
) )
else: else:
inputs = self.application_generate_entity.inputs inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files files = self.application_generate_entity.files
@ -120,12 +119,10 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
invoke_from=self.application_generate_entity.invoke_from, invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth, call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool, variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id thread_pool_id=self.workflow_thread_pool_id,
) )
generator = workflow_entry.run( generator = workflow_entry.run(callbacks=workflow_callbacks)
callbacks=workflow_callbacks
)
for event in generator: for event in generator:
self._handle_event(workflow_entry, event) self._handle_event(workflow_entry, event)
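
workflow_entry.run(callbacks=...) returns a generator, and the runner simply drains it, forwarding each event as it is produced. A sketch of that consume-and-dispatch loop with a hypothetical event source (string events stand in for the real event objects):

from collections.abc import Generator

def run_workflow() -> Generator[str, None, None]:
    # Stand-in for WorkflowEntry.run(): events are yielded as execution progresses.
    yield "graph_run_started"
    yield "node_run_started"
    yield "node_run_succeeded"
    yield "graph_run_succeeded"

for event in run_workflow():
    print(event)  # the real runner calls self._handle_event(workflow_entry, event)
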

View File

@ -35,8 +35,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return cls.convert_blocking_full_response(blocking_response) return cls.convert_blocking_full_response(blocking_response)
@classmethod @classmethod
def convert_stream_full_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \ def convert_stream_full_response(
-> Generator[str, None, None]: cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
) -> Generator[str, None, None]:
""" """
Convert stream full response. Convert stream full response.
:param stream_response: stream response :param stream_response: stream response
@ -47,12 +48,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse): if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping' yield "ping"
continue continue
response_chunk = { response_chunk = {
'event': sub_stream_response.event.value, "event": sub_stream_response.event.value,
'workflow_run_id': chunk.workflow_run_id, "workflow_run_id": chunk.workflow_run_id,
} }
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):
@ -63,8 +64,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk) yield json.dumps(response_chunk)
@classmethod @classmethod
def convert_stream_simple_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \ def convert_stream_simple_response(
-> Generator[str, None, None]: cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
) -> Generator[str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.
:param stream_response: stream response :param stream_response: stream response
@ -75,12 +77,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse): if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping' yield "ping"
continue continue
response_chunk = { response_chunk = {
'event': sub_stream_response.event.value, "event": sub_stream_response.event.value,
'workflow_run_id': chunk.workflow_run_id, "workflow_run_id": chunk.workflow_run_id,
} }
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):
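
Both converters follow the same pattern: heartbeat responses pass through as the literal string "ping", and everything else is serialized as one JSON object per chunk. A hedged sketch with a simplified chunk type (Chunk is a stand-in for WorkflowAppStreamResponse):

import json
from collections.abc import Generator, Iterable
from dataclasses import dataclass
from typing import Optional

@dataclass
class Chunk:
    event: str
    workflow_run_id: Optional[str] = None
    is_ping: bool = False

def convert_stream(chunks: Iterable[Chunk]) -> Generator[str, None, None]:
    for chunk in chunks:
        if chunk.is_ping:
            yield "ping"  # heartbeats are not JSON-encoded
            continue
        yield json.dumps({"event": chunk.event, "workflow_run_id": chunk.workflow_run_id})

print(list(convert_stream([Chunk(event="ping", is_ping=True), Chunk(event="workflow_started", workflow_run_id="run-1")])))
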

View File

@ -63,17 +63,21 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
""" """
WorkflowAppGenerateTaskPipeline is a class that generates stream output and manages state for the workflow application. WorkflowAppGenerateTaskPipeline is a class that generates stream output and manages state for the workflow application.
""" """
_workflow: Workflow _workflow: Workflow
_user: Union[Account, EndUser] _user: Union[Account, EndUser]
_task_state: WorkflowTaskState _task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity _application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any] _workflow_system_variables: dict[SystemVariableKey, Any]
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, def __init__(
workflow: Workflow, self,
queue_manager: AppQueueManager, application_generate_entity: WorkflowAppGenerateEntity,
user: Union[Account, EndUser], workflow: Workflow,
stream: bool) -> None: queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
) -> None:
""" """
Initialize GenerateTaskPipeline. Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@ -92,7 +96,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._workflow = workflow self._workflow = workflow
self._workflow_system_variables = { self._workflow_system_variables = {
SystemVariableKey.FILES: application_generate_entity.files, SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_id SystemVariableKey.USER_ID: user_id,
} }
self._task_state = WorkflowTaskState() self._task_state = WorkflowTaskState()
@ -106,16 +110,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
db.session.refresh(self._user) db.session.refresh(self._user)
db.session.close() db.session.close()
generator = self._wrapper_process_stream_response( generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream: if self._stream:
return self._to_stream_response(generator) return self._to_stream_response(generator)
else: else:
return self._to_blocking_response(generator) return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
-> WorkflowAppBlockingResponse:
""" """
To blocking response. To blocking response.
:return: :return:
@ -137,18 +138,19 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
total_tokens=stream_response.data.total_tokens, total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps, total_steps=stream_response.data.total_steps,
created_at=int(stream_response.data.created_at), created_at=int(stream_response.data.created_at),
finished_at=int(stream_response.data.finished_at) finished_at=int(stream_response.data.finished_at),
) ),
) )
return response return response
else: else:
continue continue
raise Exception('Queue listening stopped unexpectedly.') raise Exception("Queue listening stopped unexpectedly.")
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \ def _to_stream_response(
-> Generator[WorkflowAppStreamResponse, None, None]: self, generator: Generator[StreamResponse, None, None]
) -> Generator[WorkflowAppStreamResponse, None, None]:
""" """
To stream response. To stream response.
:return: :return:
@ -158,10 +160,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if isinstance(stream_response, WorkflowStartStreamResponse): if isinstance(stream_response, WorkflowStartStreamResponse):
workflow_run_id = stream_response.workflow_run_id workflow_run_id = stream_response.workflow_run_id
yield WorkflowAppStreamResponse( yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
workflow_run_id=workflow_run_id,
stream_response=stream_response
)
def _listenAudioMsg(self, publisher, task_id: str): def _listenAudioMsg(self, publisher, task_id: str):
if not publisher: if not publisher:
@ -171,17 +170,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ def _wrapper_process_stream_response(
Generator[StreamResponse, None, None]: self, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None tts_publisher = None
task_id = self._application_generate_entity.task_id task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ if (
'text_to_speech'].get('autoPlay') == 'enabled': features_dict.get("text_to_speech")
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True: while True:
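
The reformatted text-to-speech condition still performs three lookups on the same nested dict. An equivalent shape that binds the sub-dict once (a possible further cleanup, not what this commit does):

def tts_autoplay_enabled(features_dict: dict) -> bool:
    tts = features_dict.get("text_to_speech") or {}
    return bool(tts.get("enabled")) and tts.get("autoPlay") == "enabled"

assert tts_autoplay_enabled({"text_to_speech": {"enabled": True, "autoPlay": "enabled"}})
assert not tts_autoplay_enabled({})
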
@ -210,13 +212,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
break break
yield MessageAudioEndStreamResponse(audio='', task_id=task_id) yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response( def _process_stream_response(
self, self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None, tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None trace_manager: Optional[TraceQueueManager] = None,
) -> Generator[StreamResponse, None, None]: ) -> Generator[StreamResponse, None, None]:
""" """
Process stream response. Process stream response.
@ -241,22 +242,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# init workflow run # init workflow run
workflow_run = self._handle_workflow_run_start() workflow_run = self._handle_workflow_run_start()
yield self._workflow_start_to_stream_response( yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_run=workflow_run
) )
elif isinstance(event, QueueNodeStartedEvent): elif isinstance(event, QueueNodeStartedEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
workflow_node_execution = self._handle_node_execution_start( workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
workflow_run=workflow_run,
event=event
)
response = self._workflow_node_start_to_stream_response( response = self._workflow_node_start_to_stream_response(
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution workflow_node_execution=workflow_node_execution,
) )
if response: if response:
@ -267,7 +264,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
response = self._workflow_node_finish_to_stream_response( response = self._workflow_node_finish_to_stream_response(
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution workflow_node_execution=workflow_node_execution,
) )
if response: if response:
@ -278,69 +275,61 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
response = self._workflow_node_finish_to_stream_response( response = self._workflow_node_finish_to_stream_response(
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution workflow_node_execution=workflow_node_execution,
) )
if response: if response:
yield response yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent): elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
yield self._workflow_parallel_branch_start_to_stream_response( yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
workflow_run=workflow_run,
event=event
) )
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
yield self._workflow_parallel_branch_finished_to_stream_response( yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
workflow_run=workflow_run,
event=event
) )
elif isinstance(event, QueueIterationStartEvent): elif isinstance(event, QueueIterationStartEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_start_to_stream_response( yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
workflow_run=workflow_run,
event=event
) )
elif isinstance(event, QueueIterationNextEvent): elif isinstance(event, QueueIterationNextEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_next_to_stream_response( yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
workflow_run=workflow_run,
event=event
) )
elif isinstance(event, QueueIterationCompletedEvent): elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_completed_to_stream_response( yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
workflow_run=workflow_run,
event=event
) )
elif isinstance(event, QueueWorkflowSucceededEvent): elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
if not graph_runtime_state: if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.') raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_success( workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run, workflow_run=workflow_run,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens, total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps, total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None, outputs=json.dumps(event.outputs)
if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs
else None,
conversation_id=None, conversation_id=None,
trace_manager=trace_manager, trace_manager=trace_manager,
) )
@ -349,22 +338,23 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._save_workflow_app_log(workflow_run) self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response( yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_run=workflow_run
) )
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
if not graph_runtime_state: if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.') raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed( workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run, workflow_run=workflow_run,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens, total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps, total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED, status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None, conversation_id=None,
trace_manager=trace_manager, trace_manager=trace_manager,
@ -374,8 +364,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._save_workflow_app_log(workflow_run) self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response( yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_run=workflow_run
) )
elif isinstance(event, QueueTextChunkEvent): elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text delta_text = event.text
@ -394,7 +383,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if tts_publisher: if tts_publisher:
tts_publisher.publish(None) tts_publisher.publish(None)
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
""" """
Save workflow app log. Save workflow app log.
@ -417,7 +405,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
workflow_app_log.workflow_id = workflow_run.workflow_id workflow_app_log.workflow_id = workflow_run.workflow_id
workflow_app_log.workflow_run_id = workflow_run.id workflow_app_log.workflow_run_id = workflow_run.id
workflow_app_log.created_from = created_from.value workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user' workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user"
workflow_app_log.created_by = self._user.id workflow_app_log.created_by = self._user.id
db.session.add(workflow_app_log) db.session.add(workflow_app_log)
@ -431,8 +419,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
:return: :return:
""" """
response = TextChunkStreamResponse( response = TextChunkStreamResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, data=TextChunkStreamResponse.Data(text=text)
data=TextChunkStreamResponse.Data(text=text)
) )
return response return response
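
_to_blocking_response above shows the drain-until-terminal pattern: iterate the stream, return as soon as the workflow-finished response appears, and raise if the queue stops without one. A minimal sketch with dict payloads standing in for the typed stream responses:

from collections.abc import Generator

def to_blocking_response(generator: Generator[dict, None, None]) -> dict:
    for response in generator:
        if response.get("event") == "workflow_finished":  # stand-in for WorkflowFinishStreamResponse
            return response
    raise RuntimeError("Queue listening stopped unexpectedly.")

assert to_blocking_response(r for r in [{"event": "text_chunk"}, {"event": "workflow_finished"}]) == {"event": "workflow_finished"}
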

View File

@ -58,89 +58,86 @@ class WorkflowBasedAppRunner(AppRunner):
""" """
Init graph Init graph
""" """
if 'nodes' not in graph_config or 'edges' not in graph_config: if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError('nodes or edges not found in workflow graph') raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get('nodes'), list): if not isinstance(graph_config.get("nodes"), list):
raise ValueError('nodes in workflow graph must be a list') raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get('edges'), list): if not isinstance(graph_config.get("edges"), list):
raise ValueError('edges in workflow graph must be a list') raise ValueError("edges in workflow graph must be a list")
# init graph # init graph
graph = Graph.init( graph = Graph.init(graph_config=graph_config)
graph_config=graph_config
)
if not graph: if not graph:
raise ValueError('graph not found in workflow') raise ValueError("graph not found in workflow")
return graph return graph
def _get_graph_and_variable_pool_of_single_iteration( def _get_graph_and_variable_pool_of_single_iteration(
self, self,
workflow: Workflow, workflow: Workflow,
node_id: str, node_id: str,
user_inputs: dict, user_inputs: dict,
) -> tuple[Graph, VariablePool]: ) -> tuple[Graph, VariablePool]:
""" """
        Get the graph and variable pool for a single-iteration run         Get the graph and variable pool for a single-iteration run
""" """
# fetch workflow graph # fetch workflow graph
graph_config = workflow.graph_dict graph_config = workflow.graph_dict
if not graph_config: if not graph_config:
raise ValueError('workflow graph not found') raise ValueError("workflow graph not found")
graph_config = cast(dict[str, Any], graph_config) graph_config = cast(dict[str, Any], graph_config)
if 'nodes' not in graph_config or 'edges' not in graph_config: if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError('nodes or edges not found in workflow graph') raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get('nodes'), list): if not isinstance(graph_config.get("nodes"), list):
raise ValueError('nodes in workflow graph must be a list') raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get('edges'), list): if not isinstance(graph_config.get("edges"), list):
raise ValueError('edges in workflow graph must be a list') raise ValueError("edges in workflow graph must be a list")
# filter nodes only in iteration # filter nodes only in iteration
node_configs = [ node_configs = [
node for node in graph_config.get('nodes', []) node
if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
] ]
graph_config['nodes'] = node_configs graph_config["nodes"] = node_configs
node_ids = [node.get('id') for node in node_configs] node_ids = [node.get("id") for node in node_configs]
# filter edges only in iteration # filter edges only in iteration
edge_configs = [ edge_configs = [
edge for edge in graph_config.get('edges', []) edge
if (edge.get('source') is None or edge.get('source') in node_ids) for edge in graph_config.get("edges", [])
and (edge.get('target') is None or edge.get('target') in node_ids) if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
] ]
graph_config['edges'] = edge_configs graph_config["edges"] = edge_configs
# init graph # init graph
graph = Graph.init( graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
graph_config=graph_config,
root_node_id=node_id
)
if not graph: if not graph:
raise ValueError('graph not found in workflow') raise ValueError("graph not found in workflow")
# fetch node config from node id # fetch node config from node id
iteration_node_config = None iteration_node_config = None
for node in node_configs: for node in node_configs:
if node.get('id') == node_id: if node.get("id") == node_id:
iteration_node_config = node iteration_node_config = node
break break
if not iteration_node_config: if not iteration_node_config:
raise ValueError('iteration node id not found in workflow graph') raise ValueError("iteration node id not found in workflow graph")
# Get node class # Get node class
node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type')) node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type"))
node_cls = node_classes.get(node_type) node_cls = node_classes.get(node_type)
node_cls = cast(type[BaseNode], node_cls) node_cls = cast(type[BaseNode], node_cls)
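
The single-iteration path keeps only the target node plus the nodes whose data.iteration_id points at it, then drops edges with an endpoint outside that set. The same filtering on plain dicts (a standalone sketch of the logic above):

def filter_iteration_subgraph(graph_config: dict, node_id: str) -> dict:
    nodes = [
        node
        for node in graph_config.get("nodes", [])
        if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
    ]
    node_ids = {node.get("id") for node in nodes}  # set for O(1) membership tests
    edges = [
        edge
        for edge in graph_config.get("edges", [])
        if (edge.get("source") is None or edge.get("source") in node_ids)
        and (edge.get("target") is None or edge.get("target") in node_ids)
    ]
    return {"nodes": nodes, "edges": edges}
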
@ -153,8 +150,7 @@ class WorkflowBasedAppRunner(AppRunner):
try: try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, graph_config=workflow.graph_dict, config=iteration_node_config
config=iteration_node_config
) )
except NotImplementedError: except NotImplementedError:
variable_mapping = {} variable_mapping = {}
@ -165,7 +161,7 @@ class WorkflowBasedAppRunner(AppRunner):
variable_pool=variable_pool, variable_pool=variable_pool,
tenant_id=workflow.tenant_id, tenant_id=workflow.tenant_id,
node_type=node_type, node_type=node_type,
node_data=IterationNodeData(**iteration_node_config.get('data', {})) node_data=IterationNodeData(**iteration_node_config.get("data", {})),
) )
return graph, variable_pool return graph, variable_pool
@ -178,18 +174,12 @@ class WorkflowBasedAppRunner(AppRunner):
""" """
if isinstance(event, GraphRunStartedEvent): if isinstance(event, GraphRunStartedEvent):
self._publish_event( self._publish_event(
QueueWorkflowStartedEvent( QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state)
graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state
)
) )
elif isinstance(event, GraphRunSucceededEvent): elif isinstance(event, GraphRunSucceededEvent):
self._publish_event( self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
QueueWorkflowSucceededEvent(outputs=event.outputs)
)
elif isinstance(event, GraphRunFailedEvent): elif isinstance(event, GraphRunFailedEvent):
self._publish_event( self._publish_event(QueueWorkflowFailedEvent(error=event.error))
QueueWorkflowFailedEvent(error=event.error)
)
elif isinstance(event, NodeRunStartedEvent): elif isinstance(event, NodeRunStartedEvent):
self._publish_event( self._publish_event(
QueueNodeStartedEvent( QueueNodeStartedEvent(
@ -204,7 +194,7 @@ class WorkflowBasedAppRunner(AppRunner):
start_at=event.route_node_state.start_at, start_at=event.route_node_state.start_at,
node_run_index=event.route_node_state.index, node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id, predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id in_iteration_id=event.in_iteration_id,
) )
) )
elif isinstance(event, NodeRunSucceededEvent): elif isinstance(event, NodeRunSucceededEvent):
@ -220,14 +210,18 @@ class WorkflowBasedAppRunner(AppRunner):
parent_parallel_start_node_id=event.parent_parallel_start_node_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at, start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result else {}, if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result else {}, if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result else {}, if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result else {}, if event.route_node_state.node_run_result
in_iteration_id=event.in_iteration_id else {},
in_iteration_id=event.in_iteration_id,
) )
) )
elif isinstance(event, NodeRunFailedEvent): elif isinstance(event, NodeRunFailedEvent):
@ -243,16 +237,18 @@ class WorkflowBasedAppRunner(AppRunner):
parent_parallel_start_node_id=event.parent_parallel_start_node_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at, start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result if event.route_node_state.node_run_result
and event.route_node_state.node_run_result.error else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error", else "Unknown error",
in_iteration_id=event.in_iteration_id in_iteration_id=event.in_iteration_id,
) )
) )
elif isinstance(event, NodeRunStreamChunkEvent): elif isinstance(event, NodeRunStreamChunkEvent):
@ -260,14 +256,13 @@ class WorkflowBasedAppRunner(AppRunner):
QueueTextChunkEvent( QueueTextChunkEvent(
text=event.chunk_content, text=event.chunk_content,
from_variable_selector=event.from_variable_selector, from_variable_selector=event.from_variable_selector,
in_iteration_id=event.in_iteration_id in_iteration_id=event.in_iteration_id,
) )
) )
elif isinstance(event, NodeRunRetrieverResourceEvent): elif isinstance(event, NodeRunRetrieverResourceEvent):
self._publish_event( self._publish_event(
QueueRetrieverResourcesEvent( QueueRetrieverResourcesEvent(
retriever_resources=event.retriever_resources, retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
in_iteration_id=event.in_iteration_id
) )
) )
elif isinstance(event, ParallelBranchRunStartedEvent): elif isinstance(event, ParallelBranchRunStartedEvent):
@ -277,7 +272,7 @@ class WorkflowBasedAppRunner(AppRunner):
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id, parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id in_iteration_id=event.in_iteration_id,
) )
) )
elif isinstance(event, ParallelBranchRunSucceededEvent): elif isinstance(event, ParallelBranchRunSucceededEvent):
@ -287,7 +282,7 @@ class WorkflowBasedAppRunner(AppRunner):
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id, parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id in_iteration_id=event.in_iteration_id,
) )
) )
elif isinstance(event, ParallelBranchRunFailedEvent): elif isinstance(event, ParallelBranchRunFailedEvent):
@ -298,7 +293,7 @@ class WorkflowBasedAppRunner(AppRunner):
parent_parallel_id=event.parent_parallel_id, parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id, in_iteration_id=event.in_iteration_id,
error=event.error error=event.error,
) )
) )
elif isinstance(event, IterationRunStartedEvent): elif isinstance(event, IterationRunStartedEvent):
@ -316,7 +311,7 @@ class WorkflowBasedAppRunner(AppRunner):
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs, inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id, predecessor_node_id=event.predecessor_node_id,
metadata=event.metadata metadata=event.metadata,
) )
) )
elif isinstance(event, IterationRunNextEvent): elif isinstance(event, IterationRunNextEvent):
@ -352,7 +347,7 @@ class WorkflowBasedAppRunner(AppRunner):
outputs=event.outputs, outputs=event.outputs,
metadata=event.metadata, metadata=event.metadata,
steps=event.steps, steps=event.steps,
error=event.error if isinstance(event, IterationRunFailedEvent) else None error=event.error if isinstance(event, IterationRunFailedEvent) else None,
) )
) )
@ -371,9 +366,6 @@ class WorkflowBasedAppRunner(AppRunner):
# return workflow # return workflow
return workflow return workflow
def _publish_event(self, event: AppQueueEvent) -> None: def _publish_event(self, event: AppQueueEvent) -> None:
self.queue_manager.publish( self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
event,
PublishFrom.APPLICATION_MANAGER
)
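
_handle_event is one long isinstance ladder translating graph-engine events into queue events before publishing. A compressed sketch of that translate-and-publish idea with two stand-in event classes (an illustration of the shape, not the full mapping):

class GraphRunSucceededEvent:
    def __init__(self, outputs=None):
        self.outputs = outputs

class GraphRunFailedEvent:
    def __init__(self, error: str = ""):
        self.error = error

def translate(event) -> dict | None:
    # Map a graph-engine event to the queue payload it should publish.
    if isinstance(event, GraphRunSucceededEvent):
        return {"event": "workflow_succeeded", "outputs": event.outputs}
    if isinstance(event, GraphRunFailedEvent):
        return {"event": "workflow_failed", "error": event.error}
    return None

assert translate(GraphRunFailedEvent("boom")) == {"event": "workflow_failed", "error": "boom"}
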

View File

@ -30,169 +30,145 @@ _TEXT_COLOR_MAPPING = {
class WorkflowLoggingCallback(WorkflowCallback): class WorkflowLoggingCallback(WorkflowCallback):
def __init__(self) -> None: def __init__(self) -> None:
self.current_node_id = None self.current_node_id = None
def on_event( def on_event(self, event: GraphEngineEvent) -> None:
self,
event: GraphEngineEvent
) -> None:
if isinstance(event, GraphRunStartedEvent): if isinstance(event, GraphRunStartedEvent):
self.print_text("\n[GraphRunStartedEvent]", color='pink') self.print_text("\n[GraphRunStartedEvent]", color="pink")
elif isinstance(event, GraphRunSucceededEvent): elif isinstance(event, GraphRunSucceededEvent):
self.print_text("\n[GraphRunSucceededEvent]", color='green') self.print_text("\n[GraphRunSucceededEvent]", color="green")
elif isinstance(event, GraphRunFailedEvent): elif isinstance(event, GraphRunFailedEvent):
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red') self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
elif isinstance(event, NodeRunStartedEvent): elif isinstance(event, NodeRunStartedEvent):
self.on_workflow_node_execute_started( self.on_workflow_node_execute_started(event=event)
event=event
)
elif isinstance(event, NodeRunSucceededEvent): elif isinstance(event, NodeRunSucceededEvent):
self.on_workflow_node_execute_succeeded( self.on_workflow_node_execute_succeeded(event=event)
event=event
)
elif isinstance(event, NodeRunFailedEvent): elif isinstance(event, NodeRunFailedEvent):
self.on_workflow_node_execute_failed( self.on_workflow_node_execute_failed(event=event)
event=event
)
elif isinstance(event, NodeRunStreamChunkEvent): elif isinstance(event, NodeRunStreamChunkEvent):
self.on_node_text_chunk( self.on_node_text_chunk(event=event)
event=event
)
elif isinstance(event, ParallelBranchRunStartedEvent): elif isinstance(event, ParallelBranchRunStartedEvent):
self.on_workflow_parallel_started( self.on_workflow_parallel_started(event=event)
event=event
)
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent): elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
self.on_workflow_parallel_completed( self.on_workflow_parallel_completed(event=event)
event=event
)
elif isinstance(event, IterationRunStartedEvent): elif isinstance(event, IterationRunStartedEvent):
self.on_workflow_iteration_started( self.on_workflow_iteration_started(event=event)
event=event
)
elif isinstance(event, IterationRunNextEvent): elif isinstance(event, IterationRunNextEvent):
self.on_workflow_iteration_next( self.on_workflow_iteration_next(event=event)
event=event
)
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent): elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
self.on_workflow_iteration_completed( self.on_workflow_iteration_completed(event=event)
event=event
)
else: else:
self.print_text(f"\n[{event.__class__.__name__}]", color='blue') self.print_text(f"\n[{event.__class__.__name__}]", color="blue")
def on_workflow_node_execute_started( def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None:
self,
event: NodeRunStartedEvent
) -> None:
""" """
Workflow node execute started Workflow node execute started
""" """
self.print_text("\n[NodeRunStartedEvent]", color='yellow') self.print_text("\n[NodeRunStartedEvent]", color="yellow")
self.print_text(f"Node ID: {event.node_id}", color='yellow') self.print_text(f"Node ID: {event.node_id}", color="yellow")
self.print_text(f"Node Title: {event.node_data.title}", color='yellow') self.print_text(f"Node Title: {event.node_data.title}", color="yellow")
self.print_text(f"Type: {event.node_type.value}", color='yellow') self.print_text(f"Type: {event.node_type.value}", color="yellow")
def on_workflow_node_execute_succeeded( def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None:
self,
event: NodeRunSucceededEvent
) -> None:
""" """
Workflow node execute succeeded Workflow node execute succeeded
""" """
route_node_state = event.route_node_state route_node_state = event.route_node_state
self.print_text("\n[NodeRunSucceededEvent]", color='green') self.print_text("\n[NodeRunSucceededEvent]", color="green")
self.print_text(f"Node ID: {event.node_id}", color='green') self.print_text(f"Node ID: {event.node_id}", color="green")
self.print_text(f"Node Title: {event.node_data.title}", color='green') self.print_text(f"Node Title: {event.node_data.title}", color="green")
self.print_text(f"Type: {event.node_type.value}", color='green') self.print_text(f"Type: {event.node_type.value}", color="green")
if route_node_state.node_run_result: if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result node_run_result = route_node_state.node_run_result
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", self.print_text(
color='green') f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="green"
)
self.print_text( self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color='green') color="green",
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", )
color='green') self.print_text(
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color="green",
)
self.print_text( self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}", f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
color='green') color="green",
)
def on_workflow_node_execute_failed( def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None:
self,
event: NodeRunFailedEvent
) -> None:
""" """
Workflow node execute failed Workflow node execute failed
""" """
route_node_state = event.route_node_state route_node_state = event.route_node_state
self.print_text("\n[NodeRunFailedEvent]", color='red') self.print_text("\n[NodeRunFailedEvent]", color="red")
self.print_text(f"Node ID: {event.node_id}", color='red') self.print_text(f"Node ID: {event.node_id}", color="red")
self.print_text(f"Node Title: {event.node_data.title}", color='red') self.print_text(f"Node Title: {event.node_data.title}", color="red")
self.print_text(f"Type: {event.node_type.value}", color='red') self.print_text(f"Type: {event.node_type.value}", color="red")
if route_node_state.node_run_result: if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result node_run_result = route_node_state.node_run_result
self.print_text(f"Error: {node_run_result.error}", color='red') self.print_text(f"Error: {node_run_result.error}", color="red")
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", self.print_text(
color='red') f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="red"
)
self.print_text( self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color='red') color="red",
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", )
color='red') self.print_text(
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", color="red"
)
def on_node_text_chunk( def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:
self,
event: NodeRunStreamChunkEvent
) -> None:
""" """
Publish text chunk Publish text chunk
""" """
route_node_state = event.route_node_state route_node_state = event.route_node_state
if not self.current_node_id or self.current_node_id != route_node_state.node_id: if not self.current_node_id or self.current_node_id != route_node_state.node_id:
self.current_node_id = route_node_state.node_id self.current_node_id = route_node_state.node_id
self.print_text('\n[NodeRunStreamChunkEvent]') self.print_text("\n[NodeRunStreamChunkEvent]")
self.print_text(f"Node ID: {route_node_state.node_id}") self.print_text(f"Node ID: {route_node_state.node_id}")
node_run_result = route_node_state.node_run_result node_run_result = route_node_state.node_run_result
if node_run_result: if node_run_result:
self.print_text( self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}") f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}"
)
self.print_text(event.chunk_content, color="pink", end="") self.print_text(event.chunk_content, color="pink", end="")
def on_workflow_parallel_started( def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None:
self,
event: ParallelBranchRunStartedEvent
) -> None:
""" """
Publish parallel started Publish parallel started
""" """
self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue') self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue")
self.print_text(f"Parallel ID: {event.parallel_id}", color='blue') self.print_text(f"Parallel ID: {event.parallel_id}", color="blue")
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue') self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue")
if event.in_iteration_id: if event.in_iteration_id:
self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue') self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue")
def on_workflow_parallel_completed( def on_workflow_parallel_completed(
self, self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
) -> None: ) -> None:
""" """
Publish parallel completed Publish parallel completed
""" """
if isinstance(event, ParallelBranchRunSucceededEvent): if isinstance(event, ParallelBranchRunSucceededEvent):
color = 'blue' color = "blue"
elif isinstance(event, ParallelBranchRunFailedEvent): elif isinstance(event, ParallelBranchRunFailedEvent):
color = 'red' color = "red"
self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color) self.print_text(
"\n[ParallelBranchRunSucceededEvent]"
if isinstance(event, ParallelBranchRunSucceededEvent)
else "\n[ParallelBranchRunFailedEvent]",
color=color,
)
self.print_text(f"Parallel ID: {event.parallel_id}", color=color) self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color) self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
if event.in_iteration_id: if event.in_iteration_id:
@ -201,43 +177,37 @@ class WorkflowLoggingCallback(WorkflowCallback):
if isinstance(event, ParallelBranchRunFailedEvent): if isinstance(event, ParallelBranchRunFailedEvent):
self.print_text(f"Error: {event.error}", color=color) self.print_text(f"Error: {event.error}", color=color)
def on_workflow_iteration_started( def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None:
self,
event: IterationRunStartedEvent
) -> None:
""" """
Publish iteration started Publish iteration started
""" """
self.print_text("\n[IterationRunStartedEvent]", color='blue') self.print_text("\n[IterationRunStartedEvent]", color="blue")
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue') self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
def on_workflow_iteration_next( def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None:
self,
event: IterationRunNextEvent
) -> None:
""" """
Publish iteration next Publish iteration next
""" """
self.print_text("\n[IterationRunNextEvent]", color='blue') self.print_text("\n[IterationRunNextEvent]", color="blue")
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue') self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
self.print_text(f"Iteration Index: {event.index}", color='blue') self.print_text(f"Iteration Index: {event.index}", color="blue")
def on_workflow_iteration_completed( def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None:
self,
event: IterationRunSucceededEvent | IterationRunFailedEvent
) -> None:
""" """
Publish iteration completed Publish iteration completed
""" """
self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue') self.print_text(
self.print_text(f"Node ID: {event.iteration_id}", color='blue') "\n[IterationRunSucceededEvent]"
if isinstance(event, IterationRunSucceededEvent)
else "\n[IterationRunFailedEvent]",
color="blue",
)
self.print_text(f"Node ID: {event.iteration_id}", color="blue")
def print_text( def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
self, text: str, color: Optional[str] = None, end: str = "\n"
) -> None:
"""Print text with highlighting and no end characters.""" """Print text with highlighting and no end characters."""
text_to_print = self._get_colored_text(text, color) if color else text text_to_print = self._get_colored_text(text, color) if color else text
print(f'{text_to_print}', end=end) print(f"{text_to_print}", end=end)
def _get_colored_text(self, text: str, color: str) -> str: def _get_colored_text(self, text: str, color: str) -> str:
"""Get colored text.""" """Get colored text."""

View File

@ -15,13 +15,14 @@ class InvokeFrom(Enum):
""" """
Invoke From. Invoke From.
""" """
SERVICE_API = 'service-api'
WEB_APP = 'web-app' SERVICE_API = "service-api"
EXPLORE = 'explore' WEB_APP = "web-app"
DEBUGGER = 'debugger' EXPLORE = "explore"
DEBUGGER = "debugger"
@classmethod @classmethod
def value_of(cls, value: str) -> 'InvokeFrom': def value_of(cls, value: str) -> "InvokeFrom":
""" """
Get value of given mode. Get value of given mode.
@ -31,7 +32,7 @@ class InvokeFrom(Enum):
for mode in cls: for mode in cls:
if mode.value == value: if mode.value == value:
return mode return mode
raise ValueError(f'invalid invoke from value {value}') raise ValueError(f"invalid invoke from value {value}")
def to_source(self) -> str: def to_source(self) -> str:
""" """
@ -40,21 +41,22 @@ class InvokeFrom(Enum):
:return: source :return: source
""" """
if self == InvokeFrom.WEB_APP: if self == InvokeFrom.WEB_APP:
return 'web_app' return "web_app"
elif self == InvokeFrom.DEBUGGER: elif self == InvokeFrom.DEBUGGER:
return 'dev' return "dev"
elif self == InvokeFrom.EXPLORE: elif self == InvokeFrom.EXPLORE:
return 'explore_app' return "explore_app"
elif self == InvokeFrom.SERVICE_API: elif self == InvokeFrom.SERVICE_API:
return 'api' return "api"
return 'dev' return "dev"
class ModelConfigWithCredentialsEntity(BaseModel): class ModelConfigWithCredentialsEntity(BaseModel):
""" """
Model Config With Credentials Entity. Model Config With Credentials Entity.
""" """
provider: str provider: str
model: str model: str
model_schema: AIModelEntity model_schema: AIModelEntity
@ -72,6 +74,7 @@ class AppGenerateEntity(BaseModel):
""" """
App Generate Entity. App Generate Entity.
""" """
task_id: str task_id: str
# app config # app config
@ -102,6 +105,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
""" """
    EasyUI-based Application Generate Entity.     EasyUI-based Application Generate Entity.
""" """
# app config # app config
app_config: EasyUIBasedAppConfig app_config: EasyUIBasedAppConfig
model_conf: ModelConfigWithCredentialsEntity model_conf: ModelConfigWithCredentialsEntity
@ -116,6 +120,7 @@ class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
""" """
Chat Application Generate Entity. Chat Application Generate Entity.
""" """
conversation_id: Optional[str] = None conversation_id: Optional[str] = None
@ -123,6 +128,7 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
""" """
Completion Application Generate Entity. Completion Application Generate Entity.
""" """
pass pass
@ -130,6 +136,7 @@ class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
""" """
Agent Chat Application Generate Entity. Agent Chat Application Generate Entity.
""" """
conversation_id: Optional[str] = None conversation_id: Optional[str] = None
@ -137,6 +144,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
""" """
Advanced Chat Application Generate Entity. Advanced Chat Application Generate Entity.
""" """
# app config # app config
app_config: WorkflowUIBasedAppConfig app_config: WorkflowUIBasedAppConfig
@ -147,15 +155,18 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
""" """
Single Iteration Run Entity. Single Iteration Run Entity.
""" """
node_id: str node_id: str
inputs: dict inputs: dict
single_iteration_run: Optional[SingleIterationRunEntity] = None single_iteration_run: Optional[SingleIterationRunEntity] = None
class WorkflowAppGenerateEntity(AppGenerateEntity): class WorkflowAppGenerateEntity(AppGenerateEntity):
""" """
Workflow Application Generate Entity. Workflow Application Generate Entity.
""" """
# app config # app config
app_config: WorkflowUIBasedAppConfig app_config: WorkflowUIBasedAppConfig
@ -163,6 +174,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
""" """
Single Iteration Run Entity. Single Iteration Run Entity.
""" """
node_id: str node_id: str
inputs: dict inputs: dict
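
InvokeFrom.value_of scans the members for a matching value; note that plain Enum construction performs the same lookup (raising ValueError on a miss), so the classmethod mainly customizes the error message. A self-contained sketch:

from enum import Enum

class InvokeFrom(Enum):
    SERVICE_API = "service-api"
    WEB_APP = "web-app"

    @classmethod
    def value_of(cls, value: str) -> "InvokeFrom":
        for mode in cls:
            if mode.value == value:
                return mode
        raise ValueError(f"invalid invoke from value {value}")

assert InvokeFrom.value_of("web-app") is InvokeFrom.WEB_APP
assert InvokeFrom("web-app") is InvokeFrom.WEB_APP  # built-in lookup, same result
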

View File

@ -14,6 +14,7 @@ class QueueEvent(str, Enum):
""" """
QueueEvent enum QueueEvent enum
""" """
LLM_CHUNK = "llm_chunk" LLM_CHUNK = "llm_chunk"
TEXT_CHUNK = "text_chunk" TEXT_CHUNK = "text_chunk"
AGENT_MESSAGE = "agent_message" AGENT_MESSAGE = "agent_message"
@ -45,6 +46,7 @@ class AppQueueEvent(BaseModel):
""" """
QueueEvent abstract entity QueueEvent abstract entity
""" """
event: QueueEvent event: QueueEvent
@ -53,13 +55,16 @@ class QueueLLMChunkEvent(AppQueueEvent):
QueueLLMChunkEvent entity QueueLLMChunkEvent entity
Only for basic mode apps Only for basic mode apps
""" """
event: QueueEvent = QueueEvent.LLM_CHUNK event: QueueEvent = QueueEvent.LLM_CHUNK
chunk: LLMResultChunk chunk: LLMResultChunk
class QueueIterationStartEvent(AppQueueEvent): class QueueIterationStartEvent(AppQueueEvent):
""" """
QueueIterationStartEvent entity QueueIterationStartEvent entity
""" """
event: QueueEvent = QueueEvent.ITERATION_START event: QueueEvent = QueueEvent.ITERATION_START
node_execution_id: str node_execution_id: str
node_id: str node_id: str
@ -80,10 +85,12 @@ class QueueIterationStartEvent(AppQueueEvent):
predecessor_node_id: Optional[str] = None predecessor_node_id: Optional[str] = None
metadata: Optional[dict[str, Any]] = None metadata: Optional[dict[str, Any]] = None
class QueueIterationNextEvent(AppQueueEvent): class QueueIterationNextEvent(AppQueueEvent):
""" """
QueueIterationNextEvent entity QueueIterationNextEvent entity
""" """
event: QueueEvent = QueueEvent.ITERATION_NEXT event: QueueEvent = QueueEvent.ITERATION_NEXT
index: int index: int
@ -101,9 +108,9 @@ class QueueIterationNextEvent(AppQueueEvent):
"""parent parallel start node id if node is in parallel""" """parent parallel start node id if node is in parallel"""
node_run_index: int node_run_index: int
output: Optional[Any] = None # output for the current iteration output: Optional[Any] = None # output for the current iteration
@field_validator('output', mode='before') @field_validator("output", mode="before")
@classmethod @classmethod
def set_output(cls, v): def set_output(cls, v):
""" """
@ -113,12 +120,14 @@ class QueueIterationNextEvent(AppQueueEvent):
return None return None
if isinstance(v, int | float | str | bool | dict | list): if isinstance(v, int | float | str | bool | dict | list):
return v return v
raise ValueError('output must be a valid type') raise ValueError("output must be a valid type")
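
The set_output validator runs before parsing (mode="before") and whitelists None plus the JSON-like primitive and container types. A runnable sketch of the same validator on a stripped-down model (assumes pydantic v2):

from typing import Any, Optional

from pydantic import BaseModel, field_validator

class IterationNextEvent(BaseModel):
    output: Optional[Any] = None

    @field_validator("output", mode="before")
    @classmethod
    def set_output(cls, v):
        if v is None:
            return None
        if isinstance(v, int | float | str | bool | dict | list):
            return v
        raise ValueError("output must be a valid type")

assert IterationNextEvent(output=[1, 2]).output == [1, 2]
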
class QueueIterationCompletedEvent(AppQueueEvent): class QueueIterationCompletedEvent(AppQueueEvent):
""" """
QueueIterationCompletedEvent entity QueueIterationCompletedEvent entity
""" """
event: QueueEvent = QueueEvent.ITERATION_COMPLETED event: QueueEvent = QueueEvent.ITERATION_COMPLETED
node_execution_id: str node_execution_id: str
@ -134,7 +143,7 @@ class QueueIterationCompletedEvent(AppQueueEvent):
parent_parallel_start_node_id: Optional[str] = None parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel""" """parent parallel start node id if node is in parallel"""
start_at: datetime start_at: datetime
node_run_index: int node_run_index: int
inputs: Optional[dict[str, Any]] = None inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None outputs: Optional[dict[str, Any]] = None
@ -148,6 +157,7 @@ class QueueTextChunkEvent(AppQueueEvent):
""" """
QueueTextChunkEvent entity QueueTextChunkEvent entity
""" """
event: QueueEvent = QueueEvent.TEXT_CHUNK event: QueueEvent = QueueEvent.TEXT_CHUNK
text: str text: str
from_variable_selector: Optional[list[str]] = None from_variable_selector: Optional[list[str]] = None
@ -160,14 +170,16 @@ class QueueAgentMessageEvent(AppQueueEvent):
""" """
QueueAgentMessageEvent entity QueueAgentMessageEvent entity
""" """
event: QueueEvent = QueueEvent.AGENT_MESSAGE event: QueueEvent = QueueEvent.AGENT_MESSAGE
chunk: LLMResultChunk chunk: LLMResultChunk
class QueueMessageReplaceEvent(AppQueueEvent): class QueueMessageReplaceEvent(AppQueueEvent):
""" """
QueueMessageReplaceEvent entity QueueMessageReplaceEvent entity
""" """
event: QueueEvent = QueueEvent.MESSAGE_REPLACE event: QueueEvent = QueueEvent.MESSAGE_REPLACE
text: str text: str
@ -176,6 +188,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
""" """
QueueRetrieverResourcesEvent entity QueueRetrieverResourcesEvent entity
""" """
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: list[dict] retriever_resources: list[dict]
in_iteration_id: Optional[str] = None in_iteration_id: Optional[str] = None
@ -186,6 +199,7 @@ class QueueAnnotationReplyEvent(AppQueueEvent):
""" """
QueueAnnotationReplyEvent entity QueueAnnotationReplyEvent entity
""" """
event: QueueEvent = QueueEvent.ANNOTATION_REPLY event: QueueEvent = QueueEvent.ANNOTATION_REPLY
message_annotation_id: str message_annotation_id: str
@ -194,6 +208,7 @@ class QueueMessageEndEvent(AppQueueEvent):
""" """
QueueMessageEndEvent entity QueueMessageEndEvent entity
""" """
event: QueueEvent = QueueEvent.MESSAGE_END event: QueueEvent = QueueEvent.MESSAGE_END
llm_result: Optional[LLMResult] = None llm_result: Optional[LLMResult] = None
@ -202,6 +217,7 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
""" """
QueueAdvancedChatMessageEndEvent entity QueueAdvancedChatMessageEndEvent entity
""" """
event: QueueEvent = QueueEvent.ADVANCED_CHAT_MESSAGE_END event: QueueEvent = QueueEvent.ADVANCED_CHAT_MESSAGE_END
@ -209,6 +225,7 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
""" """
QueueWorkflowStartedEvent entity QueueWorkflowStartedEvent entity
""" """
event: QueueEvent = QueueEvent.WORKFLOW_STARTED event: QueueEvent = QueueEvent.WORKFLOW_STARTED
graph_runtime_state: GraphRuntimeState graph_runtime_state: GraphRuntimeState
@ -217,6 +234,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
""" """
QueueWorkflowSucceededEvent entity QueueWorkflowSucceededEvent entity
""" """
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
outputs: Optional[dict[str, Any]] = None outputs: Optional[dict[str, Any]] = None
@ -225,6 +243,7 @@ class QueueWorkflowFailedEvent(AppQueueEvent):
""" """
QueueWorkflowFailedEvent entity QueueWorkflowFailedEvent entity
""" """
event: QueueEvent = QueueEvent.WORKFLOW_FAILED event: QueueEvent = QueueEvent.WORKFLOW_FAILED
error: str error: str
@ -233,6 +252,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
""" """
QueueNodeStartedEvent entity QueueNodeStartedEvent entity
""" """
event: QueueEvent = QueueEvent.NODE_STARTED event: QueueEvent = QueueEvent.NODE_STARTED
node_execution_id: str node_execution_id: str
@ -258,6 +278,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
""" """
QueueNodeSucceededEvent entity QueueNodeSucceededEvent entity
""" """
event: QueueEvent = QueueEvent.NODE_SUCCEEDED event: QueueEvent = QueueEvent.NODE_SUCCEEDED
node_execution_id: str node_execution_id: str
@ -288,6 +309,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
""" """
QueueNodeFailedEvent entity QueueNodeFailedEvent entity
""" """
event: QueueEvent = QueueEvent.NODE_FAILED event: QueueEvent = QueueEvent.NODE_FAILED
node_execution_id: str node_execution_id: str
@ -317,6 +339,7 @@ class QueueAgentThoughtEvent(AppQueueEvent):
""" """
QueueAgentThoughtEvent entity QueueAgentThoughtEvent entity
""" """
event: QueueEvent = QueueEvent.AGENT_THOUGHT event: QueueEvent = QueueEvent.AGENT_THOUGHT
agent_thought_id: str agent_thought_id: str
@ -325,6 +348,7 @@ class QueueMessageFileEvent(AppQueueEvent):
""" """
QueueMessageFileEvent entity QueueMessageFileEvent entity
""" """
event: QueueEvent = QueueEvent.MESSAGE_FILE event: QueueEvent = QueueEvent.MESSAGE_FILE
message_file_id: str message_file_id: str
@ -333,6 +357,7 @@ class QueueErrorEvent(AppQueueEvent):
""" """
QueueErrorEvent entity QueueErrorEvent entity
""" """
event: QueueEvent = QueueEvent.ERROR event: QueueEvent = QueueEvent.ERROR
error: Any = None error: Any = None
@ -341,6 +366,7 @@ class QueuePingEvent(AppQueueEvent):
""" """
QueuePingEvent entity QueuePingEvent entity
""" """
event: QueueEvent = QueueEvent.PING event: QueueEvent = QueueEvent.PING
@ -348,10 +374,12 @@ class QueueStopEvent(AppQueueEvent):
""" """
QueueStopEvent entity QueueStopEvent entity
""" """
class StopBy(Enum): class StopBy(Enum):
""" """
Stop by enum Stop by enum
""" """
USER_MANUAL = "user-manual" USER_MANUAL = "user-manual"
ANNOTATION_REPLY = "annotation-reply" ANNOTATION_REPLY = "annotation-reply"
OUTPUT_MODERATION = "output-moderation" OUTPUT_MODERATION = "output-moderation"
@ -365,19 +393,20 @@ class QueueStopEvent(AppQueueEvent):
To stop reason To stop reason
""" """
reason_mapping = { reason_mapping = {
QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.', QueueStopEvent.StopBy.USER_MANUAL: "Stopped by user.",
QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.', QueueStopEvent.StopBy.ANNOTATION_REPLY: "Stopped by annotation reply.",
QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.', QueueStopEvent.StopBy.OUTPUT_MODERATION: "Stopped by output moderation.",
QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.' QueueStopEvent.StopBy.INPUT_MODERATION: "Stopped by input moderation.",
} }
return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.') return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.")
class QueueMessage(BaseModel): class QueueMessage(BaseModel):
""" """
QueueMessage abstract entity QueueMessage abstract entity
""" """
task_id: str task_id: str
app_mode: str app_mode: str
event: AppQueueEvent event: AppQueueEvent
@ -387,6 +416,7 @@ class MessageQueueMessage(QueueMessage):
""" """
MessageQueueMessage entity MessageQueueMessage entity
""" """
message_id: str message_id: str
conversation_id: str conversation_id: str
@ -395,6 +425,7 @@ class WorkflowQueueMessage(QueueMessage):
""" """
WorkflowQueueMessage entity WorkflowQueueMessage entity
""" """
pass pass
@ -402,6 +433,7 @@ class QueueParallelBranchRunStartedEvent(AppQueueEvent):
""" """
QueueParallelBranchRunStartedEvent entity QueueParallelBranchRunStartedEvent entity
""" """
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
parallel_id: str parallel_id: str
@ -418,6 +450,7 @@ class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
""" """
QueueParallelBranchRunSucceededEvent entity QueueParallelBranchRunSucceededEvent entity
""" """
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
parallel_id: str parallel_id: str
@ -434,6 +467,7 @@ class QueueParallelBranchRunFailedEvent(AppQueueEvent):
""" """
QueueParallelBranchRunFailedEvent entity QueueParallelBranchRunFailedEvent entity
""" """
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
parallel_id: str parallel_id: str
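All of the queue event classes above share one shape: a pydantic model whose event field defaults to its discriminator, plus a "before" validator where a field needs type policing. A minimal, self-contained sketch of that shape using plain pydantic v2, with the Dify-specific imports left out (class names mirror the diff; the rest is illustrative):

from enum import Enum
from typing import Any, Optional

from pydantic import BaseModel, field_validator


class QueueEvent(str, Enum):
    ITERATION_NEXT = "iteration_next"


class AppQueueEvent(BaseModel):
    event: QueueEvent


class QueueIterationNextEvent(AppQueueEvent):
    # the default pins the discriminator for this subclass
    event: QueueEvent = QueueEvent.ITERATION_NEXT
    index: int
    output: Optional[Any] = None

    @field_validator("output", mode="before")
    @classmethod
    def set_output(cls, v):
        # mirror the diff: allow None and plain JSON-like values only
        if v is None:
            return None
        if isinstance(v, int | float | str | bool | dict | list):
            return v
        raise ValueError("output must be a valid type")


event = QueueIterationNextEvent(index=0, output={"answer": "42"})
print(event.model_dump(mode="json"))  # {'event': 'iteration_next', 'index': 0, 'output': {'answer': '42'}}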
@ -12,6 +12,7 @@ class TaskState(BaseModel):
""" """
TaskState entity TaskState entity
""" """
metadata: dict = {} metadata: dict = {}
@ -19,6 +20,7 @@ class EasyUITaskState(TaskState):
""" """
EasyUITaskState entity EasyUITaskState entity
""" """
llm_result: LLMResult llm_result: LLMResult
@ -26,6 +28,7 @@ class WorkflowTaskState(TaskState):
""" """
WorkflowTaskState entity WorkflowTaskState entity
""" """
answer: str = "" answer: str = ""
@ -33,6 +36,7 @@ class StreamEvent(Enum):
""" """
Stream event Stream event
""" """
PING = "ping" PING = "ping"
ERROR = "error" ERROR = "error"
MESSAGE = "message" MESSAGE = "message"
@ -60,6 +64,7 @@ class StreamResponse(BaseModel):
""" """
StreamResponse entity StreamResponse entity
""" """
event: StreamEvent event: StreamEvent
task_id: str task_id: str
@ -71,6 +76,7 @@ class ErrorStreamResponse(StreamResponse):
""" """
ErrorStreamResponse entity ErrorStreamResponse entity
""" """
event: StreamEvent = StreamEvent.ERROR event: StreamEvent = StreamEvent.ERROR
err: Exception err: Exception
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
@ -80,6 +86,7 @@ class MessageStreamResponse(StreamResponse):
""" """
MessageStreamResponse entity MessageStreamResponse entity
""" """
event: StreamEvent = StreamEvent.MESSAGE event: StreamEvent = StreamEvent.MESSAGE
id: str id: str
answer: str answer: str
@ -89,6 +96,7 @@ class MessageAudioStreamResponse(StreamResponse):
""" """
MessageAudioStreamResponse entity MessageAudioStreamResponse entity
""" """
event: StreamEvent = StreamEvent.TTS_MESSAGE event: StreamEvent = StreamEvent.TTS_MESSAGE
audio: str audio: str
@ -97,6 +105,7 @@ class MessageAudioEndStreamResponse(StreamResponse):
""" """
MessageAudioEndStreamResponse entity MessageAudioEndStreamResponse entity
""" """
event: StreamEvent = StreamEvent.TTS_MESSAGE_END event: StreamEvent = StreamEvent.TTS_MESSAGE_END
audio: str audio: str
@ -105,6 +114,7 @@ class MessageEndStreamResponse(StreamResponse):
""" """
MessageEndStreamResponse entity MessageEndStreamResponse entity
""" """
event: StreamEvent = StreamEvent.MESSAGE_END event: StreamEvent = StreamEvent.MESSAGE_END
id: str id: str
metadata: dict = {} metadata: dict = {}
@ -114,6 +124,7 @@ class MessageFileStreamResponse(StreamResponse):
""" """
MessageFileStreamResponse entity MessageFileStreamResponse entity
""" """
event: StreamEvent = StreamEvent.MESSAGE_FILE event: StreamEvent = StreamEvent.MESSAGE_FILE
id: str id: str
type: str type: str
@ -125,6 +136,7 @@ class MessageReplaceStreamResponse(StreamResponse):
""" """
MessageReplaceStreamResponse entity MessageReplaceStreamResponse entity
""" """
event: StreamEvent = StreamEvent.MESSAGE_REPLACE event: StreamEvent = StreamEvent.MESSAGE_REPLACE
answer: str answer: str
@ -133,6 +145,7 @@ class AgentThoughtStreamResponse(StreamResponse):
""" """
AgentThoughtStreamResponse entity AgentThoughtStreamResponse entity
""" """
event: StreamEvent = StreamEvent.AGENT_THOUGHT event: StreamEvent = StreamEvent.AGENT_THOUGHT
id: str id: str
position: int position: int
@ -148,6 +161,7 @@ class AgentMessageStreamResponse(StreamResponse):
""" """
AgentMessageStreamResponse entity AgentMessageStreamResponse entity
""" """
event: StreamEvent = StreamEvent.AGENT_MESSAGE event: StreamEvent = StreamEvent.AGENT_MESSAGE
id: str id: str
answer: str answer: str
@ -162,6 +176,7 @@ class WorkflowStartStreamResponse(StreamResponse):
""" """
Data entity Data entity
""" """
id: str id: str
workflow_id: str workflow_id: str
sequence_number: int sequence_number: int
@ -182,6 +197,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
""" """
Data entity Data entity
""" """
id: str id: str
workflow_id: str workflow_id: str
sequence_number: int sequence_number: int
@ -210,6 +226,7 @@ class NodeStartStreamResponse(StreamResponse):
""" """
Data entity Data entity
""" """
id: str id: str
node_id: str node_id: str
node_type: str node_type: str
@ -249,7 +266,7 @@ class NodeStartStreamResponse(StreamResponse):
"parent_parallel_id": self.data.parent_parallel_id, "parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id, "iteration_id": self.data.iteration_id,
} },
} }
@ -262,6 +279,7 @@ class NodeFinishStreamResponse(StreamResponse):
""" """
Data entity Data entity
""" """
id: str id: str
node_id: str node_id: str
node_type: str node_type: str
@ -315,9 +333,9 @@ class NodeFinishStreamResponse(StreamResponse):
"parent_parallel_id": self.data.parent_parallel_id, "parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id, "iteration_id": self.data.iteration_id,
} },
} }
class ParallelBranchStartStreamResponse(StreamResponse): class ParallelBranchStartStreamResponse(StreamResponse):
""" """
@ -328,6 +346,7 @@ class ParallelBranchStartStreamResponse(StreamResponse):
""" """
Data entity Data entity
""" """
parallel_id: str parallel_id: str
parallel_branch_id: str parallel_branch_id: str
parent_parallel_id: Optional[str] = None parent_parallel_id: Optional[str] = None
@ -349,6 +368,7 @@ class ParallelBranchFinishedStreamResponse(StreamResponse):
""" """
Data entity Data entity
""" """
parallel_id: str parallel_id: str
parallel_branch_id: str parallel_branch_id: str
parent_parallel_id: Optional[str] = None parent_parallel_id: Optional[str] = None
@ -372,6 +392,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
""" """
Data entity Data entity
""" """
id: str id: str
node_id: str node_id: str
node_type: str node_type: str
@ -397,6 +418,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
""" """
Data entity Data entity
""" """
id: str id: str
node_id: str node_id: str
node_type: str node_type: str
@ -422,6 +444,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
""" """
Data entity Data entity
""" """
id: str id: str
node_id: str node_id: str
node_type: str node_type: str
@ -454,6 +477,7 @@ class TextChunkStreamResponse(StreamResponse):
""" """
Data entity Data entity
""" """
text: str text: str
event: StreamEvent = StreamEvent.TEXT_CHUNK event: StreamEvent = StreamEvent.TEXT_CHUNK
@ -469,6 +493,7 @@ class TextReplaceStreamResponse(StreamResponse):
""" """
Data entity Data entity
""" """
text: str text: str
event: StreamEvent = StreamEvent.TEXT_REPLACE event: StreamEvent = StreamEvent.TEXT_REPLACE
@ -479,6 +504,7 @@ class PingStreamResponse(StreamResponse):
""" """
PingStreamResponse entity PingStreamResponse entity
""" """
event: StreamEvent = StreamEvent.PING event: StreamEvent = StreamEvent.PING
@ -486,6 +512,7 @@ class AppStreamResponse(BaseModel):
""" """
AppStreamResponse entity AppStreamResponse entity
""" """
stream_response: StreamResponse stream_response: StreamResponse
@ -493,6 +520,7 @@ class ChatbotAppStreamResponse(AppStreamResponse):
""" """
ChatbotAppStreamResponse entity ChatbotAppStreamResponse entity
""" """
conversation_id: str conversation_id: str
message_id: str message_id: str
created_at: int created_at: int
@ -502,6 +530,7 @@ class CompletionAppStreamResponse(AppStreamResponse):
""" """
CompletionAppStreamResponse entity CompletionAppStreamResponse entity
""" """
message_id: str message_id: str
created_at: int created_at: int
@ -510,6 +539,7 @@ class WorkflowAppStreamResponse(AppStreamResponse):
""" """
WorkflowAppStreamResponse entity WorkflowAppStreamResponse entity
""" """
workflow_run_id: Optional[str] = None workflow_run_id: Optional[str] = None
@ -517,6 +547,7 @@ class AppBlockingResponse(BaseModel):
""" """
AppBlockingResponse entity AppBlockingResponse entity
""" """
task_id: str task_id: str
def to_dict(self) -> dict: def to_dict(self) -> dict:
@ -532,6 +563,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
""" """
Data entity Data entity
""" """
id: str id: str
mode: str mode: str
conversation_id: str conversation_id: str
@ -552,6 +584,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
""" """
Data entity Data entity
""" """
id: str id: str
mode: str mode: str
message_id: str message_id: str
@ -571,6 +604,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
""" """
Data entity Data entity
""" """
id: str id: str
workflow_id: str workflow_id: str
status: str status: str
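The stream and blocking response classes follow the same discriminator pattern, with to_dict handing serialization off to pydantic. A hedged sketch of the hierarchy; the real classes carry many more fields, and this to_dict body is an assumption modeled on pydantic's model_dump:

from enum import Enum

from pydantic import BaseModel


class StreamEvent(Enum):
    MESSAGE = "message"


class StreamResponse(BaseModel):
    event: StreamEvent
    task_id: str

    def to_dict(self) -> dict:
        # assumption: serialization delegates to pydantic's JSON-mode dump
        return self.model_dump(mode="json")


class MessageStreamResponse(StreamResponse):
    event: StreamEvent = StreamEvent.MESSAGE
    id: str
    answer: str


resp = MessageStreamResponse(task_id="t-1", id="m-1", answer="hello")
print(resp.to_dict())  # {'event': 'message', 'task_id': 't-1', 'id': 'm-1', 'answer': 'hello'}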
@ -13,11 +13,9 @@ logger = logging.getLogger(__name__)
class AnnotationReplyFeature: class AnnotationReplyFeature:
def query(self, app_record: App, def query(
message: Message, self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
query: str, ) -> Optional[MessageAnnotation]:
user_id: str,
invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
""" """
Query app annotations to reply Query app annotations to reply
:param app_record: app record :param app_record: app record
@ -27,8 +25,9 @@ class AnnotationReplyFeature:
:param invoke_from: invoke from :param invoke_from: invoke from
:return: :return:
""" """
annotation_setting = db.session.query(AppAnnotationSetting).filter( annotation_setting = (
AppAnnotationSetting.app_id == app_record.id).first() db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first()
)
if not annotation_setting: if not annotation_setting:
return None return None
@ -41,55 +40,50 @@ class AnnotationReplyFeature:
embedding_model_name = collection_binding_detail.model_name embedding_model_name = collection_binding_detail.model_name
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_provider_name, embedding_model_name, "annotation"
embedding_model_name,
'annotation'
) )
dataset = Dataset( dataset = Dataset(
id=app_record.id, id=app_record.id,
tenant_id=app_record.tenant_id, tenant_id=app_record.tenant_id,
indexing_technique='high_quality', indexing_technique="high_quality",
embedding_model_provider=embedding_provider_name, embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name, embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id collection_binding_id=dataset_collection_binding.id,
) )
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
documents = vector.search_by_vector( documents = vector.search_by_vector(
query=query, query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]}
top_k=1,
score_threshold=score_threshold,
filter={
'group_id': [dataset.id]
}
) )
if documents: if documents:
annotation_id = documents[0].metadata['annotation_id'] annotation_id = documents[0].metadata["annotation_id"]
score = documents[0].metadata['score'] score = documents[0].metadata["score"]
annotation = AppAnnotationService.get_annotation_by_id(annotation_id) annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
if annotation: if annotation:
if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]: if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
from_source = 'api' from_source = "api"
else: else:
from_source = 'console' from_source = "console"
# insert annotation history # insert annotation history
AppAnnotationService.add_annotation_history(annotation.id, AppAnnotationService.add_annotation_history(
app_record.id, annotation.id,
annotation.question, app_record.id,
annotation.content, annotation.question,
query, annotation.content,
user_id, query,
message.id, user_id,
from_source, message.id,
score) from_source,
score,
)
return annotation return annotation
except Exception as e: except Exception as e:
logger.warning(f'Query annotation failed, exception: {str(e)}.') logger.warning(f"Query annotation failed, exception: {str(e)}.")
return None return None
return None return None
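The control flow above reduces to a thresholded top-1 lookup: embed the query, take the best vector hit, and only reply when its score clears the configured threshold. An illustrative-only sketch with the vector store stubbed out as a list of hits (the metadata keys mirror the diff):

from typing import Optional


def top1_annotation(documents: list[dict], score_threshold: float) -> Optional[str]:
    # hits arrive sorted by similarity; keep only a confident top match
    if not documents:
        return None
    best = documents[0]
    if best["score"] < score_threshold:
        return None
    return best["annotation_id"]


print(top1_annotation([{"annotation_id": "a-1", "score": 0.93}], 0.8))  # a-1
print(top1_annotation([{"annotation_id": "a-2", "score": 0.42}], 0.8))  # None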
@ -8,8 +8,9 @@ logger = logging.getLogger(__name__)
class HostingModerationFeature: class HostingModerationFeature:
def check(self, application_generate_entity: EasyUIBasedAppGenerateEntity, def check(
prompt_messages: list[PromptMessage]) -> bool: self, application_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list[PromptMessage]
) -> bool:
""" """
Check hosting moderation Check hosting moderation
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@ -23,9 +24,6 @@ class HostingModerationFeature:
if isinstance(prompt_message.content, str): if isinstance(prompt_message.content, str):
text += prompt_message.content + "\n" text += prompt_message.content + "\n"
moderation_result = moderation.check_moderation( moderation_result = moderation.check_moderation(model_config, text)
model_config,
text
)
return moderation_result return moderation_result
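Before the moderation call, the feature flattens every string-typed prompt content into one newline-joined blob. A minimal sketch of that assembly step, with the moderation call itself stubbed out since check_moderation is internal:

def collect_text(prompt_messages: list) -> str:
    text = ""
    for message in prompt_messages:
        # plain strings and objects with a .content attribute both work here
        content = getattr(message, "content", message)
        if isinstance(content, str):
            text += content + "\n"
    return text


print(collect_text(["hello", "world", {"type": "image"}]))  # non-string content is skipped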
@ -19,7 +19,7 @@ class RateLimit:
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict = {} _instance_dict = {}
def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int): def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict: if client_id not in cls._instance_dict:
instance = super().__new__(cls) instance = super().__new__(cls)
cls._instance_dict[client_id] = instance cls._instance_dict[client_id] = instance
@ -27,13 +27,13 @@ class RateLimit:
def __init__(self, client_id: str, max_active_requests: int): def __init__(self, client_id: str, max_active_requests: int):
self.max_active_requests = max_active_requests self.max_active_requests = max_active_requests
if hasattr(self, 'initialized'): if hasattr(self, "initialized"):
return return
self.initialized = True self.initialized = True
self.client_id = client_id self.client_id = client_id
self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id) self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id) self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
self.last_recalculate_time = float('-inf') self.last_recalculate_time = float("-inf")
self.flush_cache(use_local_value=True) self.flush_cache(use_local_value=True)
def flush_cache(self, use_local_value=False): def flush_cache(self, use_local_value=False):
@ -46,7 +46,7 @@ class RateLimit:
pipe.execute() pipe.execute()
else: else:
with redis_client.pipeline() as pipe: with redis_client.pipeline() as pipe:
self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8')) self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8"))
redis_client.expire(self.max_active_requests_key, timedelta(days=1)) redis_client.expire(self.max_active_requests_key, timedelta(days=1))
# flush max active requests (in-transit request list) # flush max active requests (in-transit request list)
@ -54,8 +54,11 @@ class RateLimit:
return return
request_details = redis_client.hgetall(self.active_requests_key) request_details = redis_client.hgetall(self.active_requests_key)
redis_client.expire(self.active_requests_key, timedelta(days=1)) redis_client.expire(self.active_requests_key, timedelta(days=1))
timeout_requests = [k for k, v in request_details.items() if timeout_requests = [
time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME] k
for k, v in request_details.items()
if time.time() - float(v.decode("utf-8")) > RateLimit._REQUEST_MAX_ALIVE_TIME
]
if timeout_requests: if timeout_requests:
redis_client.hdel(self.active_requests_key, *timeout_requests) redis_client.hdel(self.active_requests_key, *timeout_requests)
@ -69,8 +72,10 @@ class RateLimit:
active_requests_count = redis_client.hlen(self.active_requests_key) active_requests_count = redis_client.hlen(self.active_requests_key)
if active_requests_count >= self.max_active_requests: if active_requests_count >= self.max_active_requests:
raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum " raise AppInvokeQuotaExceededError(
"concurrent requests allowed is {}.".format(self.max_active_requests)) "Too many requests. Please try again later. The current maximum "
"concurrent requests allowed is {}.".format(self.max_active_requests)
)
redis_client.hset(self.active_requests_key, request_id, str(time.time())) redis_client.hset(self.active_requests_key, request_id, str(time.time()))
return request_id return request_id
@ -116,5 +121,5 @@ class RateLimitGenerator:
if not self.closed: if not self.closed:
self.closed = True self.closed = True
self.rate_limit.exit(self.request_id) self.rate_limit.exit(self.request_id)
if self.generator is not None and hasattr(self.generator, 'close'): if self.generator is not None and hasattr(self.generator, "close"):
self.generator.close() self.generator.close()
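RateLimit is a per-client singleton: __new__ caches one instance per client_id, and the "initialized" guard keeps __init__ from re-running its one-time setup on cache hits while still letting max_active_requests be refreshed. A pared-down sketch of that pattern with the Redis bookkeeping omitted:

class RateLimit:
    _instance_dict: dict[str, "RateLimit"] = {}

    def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
        if client_id not in cls._instance_dict:
            cls._instance_dict[client_id] = super().__new__(cls)
        return cls._instance_dict[client_id]

    def __init__(self, client_id: str, max_active_requests: int):
        self.max_active_requests = max_active_requests  # refreshed on every call
        if hasattr(self, "initialized"):
            return  # one-time setup already done for this client_id
        self.initialized = True
        self.client_id = client_id


a = RateLimit("client-1", 5)
b = RateLimit("client-1", 10)
print(a is b, a.max_active_requests)  # True 10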
@ -25,25 +25,25 @@ from .variables import (
) )
__all__ = [ __all__ = [
'IntegerVariable', "IntegerVariable",
'FloatVariable', "FloatVariable",
'ObjectVariable', "ObjectVariable",
'SecretVariable', "SecretVariable",
'StringVariable', "StringVariable",
'ArrayAnyVariable', "ArrayAnyVariable",
'Variable', "Variable",
'SegmentType', "SegmentType",
'SegmentGroup', "SegmentGroup",
'Segment', "Segment",
'NoneSegment', "NoneSegment",
'NoneVariable', "NoneVariable",
'IntegerSegment', "IntegerSegment",
'FloatSegment', "FloatSegment",
'ObjectSegment', "ObjectSegment",
'ArrayAnySegment', "ArrayAnySegment",
'StringSegment', "StringSegment",
'ArrayStringVariable', "ArrayStringVariable",
'ArrayNumberVariable', "ArrayNumberVariable",
'ArrayObjectVariable', "ArrayObjectVariable",
'ArraySegment', "ArraySegment",
] ]
@ -28,12 +28,12 @@ from .variables import (
def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if (value_type := mapping.get('value_type')) is None: if (value_type := mapping.get("value_type")) is None:
raise VariableError('missing value type') raise VariableError("missing value type")
if not mapping.get('name'): if not mapping.get("name"):
raise VariableError('missing name') raise VariableError("missing name")
if (value := mapping.get('value')) is None: if (value := mapping.get("value")) is None:
raise VariableError('missing value') raise VariableError("missing value")
match value_type: match value_type:
case SegmentType.STRING: case SegmentType.STRING:
result = StringVariable.model_validate(mapping) result = StringVariable.model_validate(mapping)
@ -44,7 +44,7 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
case SegmentType.NUMBER if isinstance(value, float): case SegmentType.NUMBER if isinstance(value, float):
result = FloatVariable.model_validate(mapping) result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int): case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f'invalid number value {value}') raise VariableError(f"invalid number value {value}")
case SegmentType.OBJECT if isinstance(value, dict): case SegmentType.OBJECT if isinstance(value, dict):
result = ObjectVariable.model_validate(mapping) result = ObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_STRING if isinstance(value, list): case SegmentType.ARRAY_STRING if isinstance(value, list):
@ -54,9 +54,9 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
case SegmentType.ARRAY_OBJECT if isinstance(value, list): case SegmentType.ARRAY_OBJECT if isinstance(value, list):
result = ArrayObjectVariable.model_validate(mapping) result = ArrayObjectVariable.model_validate(mapping)
case _: case _:
raise VariableError(f'not supported value type {value_type}') raise VariableError(f"not supported value type {value_type}")
if result.size > dify_config.MAX_VARIABLE_SIZE: if result.size > dify_config.MAX_VARIABLE_SIZE:
raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}') raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
return result return result
@ -73,4 +73,4 @@ def build_segment(value: Any, /) -> Segment:
return ObjectSegment(value=value) return ObjectSegment(value=value)
if isinstance(value, list): if isinstance(value, list):
return ArrayAnySegment(value=value) return ArrayAnySegment(value=value)
raise ValueError(f'not supported value {value}') raise ValueError(f"not supported value {value}")
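build_variable_from_mapping dispatches on value_type with a match statement, using guards to split NUMBER into integer and float variants. A sketch of that dispatch with stand-in return values, since the real Variable classes live in .variables:

from enum import Enum
from typing import Any, Mapping


class SegmentType(str, Enum):
    STRING = "string"
    NUMBER = "number"


def build(mapping: Mapping[str, Any]):
    if (value_type := mapping.get("value_type")) is None:
        raise ValueError("missing value type")
    if (value := mapping.get("value")) is None:
        raise ValueError("missing value")
    match value_type:
        case SegmentType.STRING:
            return ("StringVariable", str(value))
        case SegmentType.NUMBER if isinstance(value, int):
            return ("IntegerVariable", value)
        case SegmentType.NUMBER if isinstance(value, float):
            return ("FloatVariable", value)
        case _:
            raise ValueError(f"not supported value type {value_type}")


print(build({"value_type": SegmentType.NUMBER, "name": "n", "value": 3}))  # ('IntegerVariable', 3)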
@ -4,14 +4,14 @@ from core.workflow.entities.variable_pool import VariablePool
from . import SegmentGroup, factory from . import SegmentGroup, factory
VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}') VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
def convert_template(*, template: str, variable_pool: VariablePool): def convert_template(*, template: str, variable_pool: VariablePool):
parts = re.split(VARIABLE_PATTERN, template) parts = re.split(VARIABLE_PATTERN, template)
segments = [] segments = []
for part in filter(lambda x: x, parts): for part in filter(lambda x: x, parts):
if '.' in part and (value := variable_pool.get(part.split('.'))): if "." in part and (value := variable_pool.get(part.split("."))):
segments.append(value) segments.append(value)
else: else:
segments.append(factory.build_segment(part)) segments.append(factory.build_segment(part))
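Because VARIABLE_PATTERN wraps the selector in a capturing group, re.split returns the literal chunks and the selectors interleaved, which is what lets convert_template decide per part whether to consult the variable pool. A tiny demonstration with a dict standing in for VariablePool:

import re

VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")

template = "Hello {{#start.name#}}, welcome!"
parts = re.split(VARIABLE_PATTERN, template)
print(parts)  # ['Hello ', 'start.name', ', welcome!']

pool = {"start.name": "Dify"}  # stand-in for VariablePool lookups
rendered = "".join(pool.get(part, part) for part in filter(None, parts))
print(rendered)  # Hello Dify, welcome!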
@ -8,15 +8,15 @@ class SegmentGroup(Segment):
@property @property
def text(self): def text(self):
return ''.join([segment.text for segment in self.value]) return "".join([segment.text for segment in self.value])
@property @property
def log(self): def log(self):
return ''.join([segment.log for segment in self.value]) return "".join([segment.log for segment in self.value])
@property @property
def markdown(self): def markdown(self):
return ''.join([segment.markdown for segment in self.value]) return "".join([segment.markdown for segment in self.value])
def to_object(self): def to_object(self):
return [segment.to_object() for segment in self.value] return [segment.to_object() for segment in self.value]
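SegmentGroup is a straightforward composite: each property concatenates the matching property of its children. A toy illustration of that aggregation:

class TextSegment:
    def __init__(self, text: str):
        self.text = text


class SegmentGroup:
    def __init__(self, value: list[TextSegment]):
        self.value = value

    @property
    def text(self) -> str:
        # the real class joins .log and .markdown the same way
        return "".join(segment.text for segment in self.value)


print(SegmentGroup([TextSegment("Hello, "), TextSegment("world")]).text)  # Hello, world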
@ -14,13 +14,13 @@ class Segment(BaseModel):
value_type: SegmentType value_type: SegmentType
value: Any value: Any
@field_validator('value_type') @field_validator("value_type")
def validate_value_type(cls, value): def validate_value_type(cls, value):
""" """
This validator checks if the provided value is equal to the default value of the 'value_type' field. This validator checks if the provided value is equal to the default value of the 'value_type' field.
If the value is different, a ValueError is raised. If the value is different, a ValueError is raised.
""" """
if value != cls.model_fields['value_type'].default: if value != cls.model_fields["value_type"].default:
raise ValueError("Cannot modify 'value_type'") raise ValueError("Cannot modify 'value_type'")
return value return value
@ -50,15 +50,15 @@ class NoneSegment(Segment):
@property @property
def text(self) -> str: def text(self) -> str:
return 'null' return "null"
@property @property
def log(self) -> str: def log(self) -> str:
return 'null' return "null"
@property @property
def markdown(self) -> str: def markdown(self) -> str:
return 'null' return "null"
class StringSegment(Segment): class StringSegment(Segment):
@ -76,24 +76,21 @@ class IntegerSegment(Segment):
value: int value: int
class ObjectSegment(Segment): class ObjectSegment(Segment):
value_type: SegmentType = SegmentType.OBJECT value_type: SegmentType = SegmentType.OBJECT
value: Mapping[str, Any] value: Mapping[str, Any]
@property @property
def text(self) -> str: def text(self) -> str:
return json.dumps(self.model_dump()['value'], ensure_ascii=False) return json.dumps(self.model_dump()["value"], ensure_ascii=False)
@property @property
def log(self) -> str: def log(self) -> str:
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
@property @property
def markdown(self) -> str: def markdown(self) -> str:
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
class ArraySegment(Segment): class ArraySegment(Segment):
@ -101,11 +98,11 @@ class ArraySegment(Segment):
def markdown(self) -> str: def markdown(self) -> str:
items = [] items = []
for item in self.value: for item in self.value:
if hasattr(item, 'to_markdown'): if hasattr(item, "to_markdown"):
items.append(item.to_markdown()) items.append(item.to_markdown())
else: else:
items.append(str(item)) items.append(str(item))
return '\n'.join(items) return "\n".join(items)
class ArrayAnySegment(ArraySegment): class ArrayAnySegment(ArraySegment):
@ -126,4 +123,3 @@ class ArrayNumberSegment(ArraySegment):
class ArrayObjectSegment(ArraySegment): class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[Mapping[str, Any]] value: Sequence[Mapping[str, Any]]
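The validate_value_type validator effectively freezes the discriminator: it compares the incoming value against the field default on the concrete class, so each subclass pins its own value_type. A self-contained pydantic v2 sketch of that trick:

from enum import Enum

from pydantic import BaseModel, field_validator


class SegmentType(str, Enum):
    STRING = "string"
    OBJECT = "object"


class Segment(BaseModel):
    value_type: SegmentType

    @field_validator("value_type")
    @classmethod
    def validate_value_type(cls, value):
        # cls is the concrete subclass, so the default is subclass-specific
        if value != cls.model_fields["value_type"].default:
            raise ValueError("Cannot modify 'value_type'")
        return value


class StringSegment(Segment):
    value_type: SegmentType = SegmentType.STRING


print(StringSegment(value_type=SegmentType.STRING))   # accepted
try:
    StringSegment(value_type=SegmentType.OBJECT)
except ValueError:
    print("rejected: cannot modify 'value_type'")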
@ -2,14 +2,14 @@ from enum import Enum
class SegmentType(str, Enum): class SegmentType(str, Enum):
NONE = 'none' NONE = "none"
NUMBER = 'number' NUMBER = "number"
STRING = 'string' STRING = "string"
SECRET = 'secret' SECRET = "secret"
ARRAY_ANY = 'array[any]' ARRAY_ANY = "array[any]"
ARRAY_STRING = 'array[string]' ARRAY_STRING = "array[string]"
ARRAY_NUMBER = 'array[number]' ARRAY_NUMBER = "array[number]"
ARRAY_OBJECT = 'array[object]' ARRAY_OBJECT = "array[object]"
OBJECT = 'object' OBJECT = "object"
GROUP = 'group' GROUP = "group"
@ -23,11 +23,11 @@ class Variable(Segment):
""" """
id: str = Field( id: str = Field(
default='', default="",
description="Unique identity for variable. It's only used by environment variables now.", description="Unique identity for variable. It's only used by environment variables now.",
) )
name: str name: str
description: str = Field(default='', description='Description of the variable.') description: str = Field(default="", description="Description of the variable.")
class StringVariable(StringSegment, Variable): class StringVariable(StringSegment, Variable):
@ -62,7 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
pass pass
class SecretVariable(StringVariable): class SecretVariable(StringVariable):
value_type: SegmentType = SegmentType.SECRET value_type: SegmentType = SegmentType.SECRET
@ -32,10 +32,13 @@ class BasedGenerateTaskPipeline:
_task_state: TaskState _task_state: TaskState
_application_generate_entity: AppGenerateEntity _application_generate_entity: AppGenerateEntity
def __init__(self, application_generate_entity: AppGenerateEntity, def __init__(
queue_manager: AppQueueManager, self,
user: Union[Account, EndUser], application_generate_entity: AppGenerateEntity,
stream: bool) -> None: queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
) -> None:
""" """
Initialize GenerateTaskPipeline. Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@ -61,18 +64,18 @@ class BasedGenerateTaskPipeline:
e = event.error e = event.error
if isinstance(e, InvokeAuthorizationError): if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError('Incorrect API key provided') err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError) or isinstance(e, ValueError): elif isinstance(e, InvokeError) or isinstance(e, ValueError):
err = e err = e
else: else:
err = Exception(e.description if getattr(e, 'description', None) is not None else str(e)) err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
if message: if message:
refetch_message = db.session.query(Message).filter(Message.id == message.id).first() refetch_message = db.session.query(Message).filter(Message.id == message.id).first()
if refetch_message: if refetch_message:
err_desc = self._error_to_desc(err) err_desc = self._error_to_desc(err)
refetch_message.status = 'error' refetch_message.status = "error"
refetch_message.error = err_desc refetch_message.error = err_desc
db.session.commit() db.session.commit()
@ -86,12 +89,14 @@ class BasedGenerateTaskPipeline:
:return: :return:
""" """
if isinstance(e, QuotaExceededError): if isinstance(e, QuotaExceededError):
return ("Your quota for Dify Hosted Model Provider has been exhausted. " return (
"Please go to Settings -> Model Provider to complete your own provider credentials.") "Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials."
)
message = getattr(e, 'description', str(e)) message = getattr(e, "description", str(e))
if not message: if not message:
message = 'Internal Server Error, please contact support.' message = "Internal Server Error, please contact support."
return message return message
@ -101,10 +106,7 @@ class BasedGenerateTaskPipeline:
:param e: exception :param e: exception
:return: :return:
""" """
return ErrorStreamResponse( return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)
task_id=self._application_generate_entity.task_id,
err=e
)
def _ping_stream_response(self) -> PingStreamResponse: def _ping_stream_response(self) -> PingStreamResponse:
""" """
@ -125,11 +127,8 @@ class BasedGenerateTaskPipeline:
return OutputModeration( return OutputModeration(
tenant_id=app_config.tenant_id, tenant_id=app_config.tenant_id,
app_id=app_config.app_id, app_id=app_config.app_id,
rule=ModerationRule( rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
type=sensitive_word_avoidance.type, queue_manager=self._queue_manager,
config=sensitive_word_avoidance.config
),
queue_manager=self._queue_manager
) )
def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
@ -143,8 +142,7 @@ class BasedGenerateTaskPipeline:
self._output_moderation_handler.stop_thread() self._output_moderation_handler.stop_thread()
completion = self._output_moderation_handler.moderation_completion( completion = self._output_moderation_handler.moderation_completion(
completion=completion, completion=completion, public_event=False
public_event=False
) )
self._output_moderation_handler = None self._output_moderation_handler = None
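_error_to_desc above collapses any exception into a user-facing string, with a special case for quota exhaustion. A condensed, runnable sketch where QuotaExceededError stands in for the real class:

class QuotaExceededError(Exception):
    pass


def error_to_desc(e: Exception) -> str:
    if isinstance(e, QuotaExceededError):
        return (
            "Your quota for Dify Hosted Model Provider has been exhausted. "
            "Please go to Settings -> Model Provider to complete your own provider credentials."
        )
    message = getattr(e, "description", str(e))
    if not message:
        message = "Internal Server Error, please contact support."
    return message


print(error_to_desc(ValueError("bad input")))  # bad input
print(error_to_desc(QuotaExceededError()))     # quota message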
@ -64,23 +64,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
""" """
EasyUIBasedGenerateTaskPipeline is a class that generates stream output and state management for Application. EasyUIBasedGenerateTaskPipeline is a class that generates stream output and state management for Application.
""" """
_task_state: EasyUITaskState
_application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity
]
def __init__(self, application_generate_entity: Union[ _task_state: EasyUITaskState
ChatAppGenerateEntity, _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity def __init__(
], self,
queue_manager: AppQueueManager, application_generate_entity: Union[
conversation: Conversation, ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity
message: Message, ],
user: Union[Account, EndUser], queue_manager: AppQueueManager,
stream: bool) -> None: conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool,
) -> None:
""" """
Initialize GenerateTaskPipeline. Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@ -101,18 +99,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
model=self._model_config.model, model=self._model_config.model,
prompt_messages=[], prompt_messages=[],
message=AssistantPromptMessage(content=""), message=AssistantPromptMessage(content=""),
usage=LLMUsage.empty_usage() usage=LLMUsage.empty_usage(),
) )
) )
self._conversation_name_generate_thread = None self._conversation_name_generate_thread = None
def process( def process(
self, self,
) -> Union[ ) -> Union[
ChatbotAppBlockingResponse, ChatbotAppBlockingResponse,
CompletionAppBlockingResponse, CompletionAppBlockingResponse,
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None] Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
]: ]:
""" """
Process generate task pipeline. Process generate task pipeline.
@ -125,22 +123,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread # start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name( self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation, self._conversation, self._application_generate_entity.query
self._application_generate_entity.query
) )
generator = self._wrapper_process_stream_response( generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream: if self._stream:
return self._to_stream_response(generator) return self._to_stream_response(generator)
else: else:
return self._to_blocking_response(generator) return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> Union[ def _to_blocking_response(
ChatbotAppBlockingResponse, self, generator: Generator[StreamResponse, None, None]
CompletionAppBlockingResponse ) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]:
]:
""" """
Process blocking response. Process blocking response.
:return: :return:
@ -149,11 +143,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if isinstance(stream_response, ErrorStreamResponse): if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err raise stream_response.err
elif isinstance(stream_response, MessageEndStreamResponse): elif isinstance(stream_response, MessageEndStreamResponse):
extras = { extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
'usage': jsonable_encoder(self._task_state.llm_result.usage)
}
if self._task_state.metadata: if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata extras["metadata"] = self._task_state.metadata
if self._conversation.mode == AppMode.COMPLETION.value: if self._conversation.mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse( response = CompletionAppBlockingResponse(
@ -164,8 +156,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message_id=self._message.id, message_id=self._message.id,
answer=self._task_state.llm_result.message.content, answer=self._task_state.llm_result.message.content,
created_at=int(self._message.created_at.timestamp()), created_at=int(self._message.created_at.timestamp()),
**extras **extras,
) ),
) )
else: else:
response = ChatbotAppBlockingResponse( response = ChatbotAppBlockingResponse(
@ -177,18 +169,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message_id=self._message.id, message_id=self._message.id,
answer=self._task_state.llm_result.message.content, answer=self._task_state.llm_result.message.content,
created_at=int(self._message.created_at.timestamp()), created_at=int(self._message.created_at.timestamp()),
**extras **extras,
) ),
) )
return response return response
else: else:
continue continue
raise Exception('Queue listening stopped unexpectedly.') raise Exception("Queue listening stopped unexpectedly.")
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \ def _to_stream_response(
-> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]: self, generator: Generator[StreamResponse, None, None]
) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
""" """
To stream response. To stream response.
:return: :return:
@ -198,14 +191,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
yield CompletionAppStreamResponse( yield CompletionAppStreamResponse(
message_id=self._message.id, message_id=self._message.id,
created_at=int(self._message.created_at.timestamp()), created_at=int(self._message.created_at.timestamp()),
stream_response=stream_response stream_response=stream_response,
) )
else: else:
yield ChatbotAppStreamResponse( yield ChatbotAppStreamResponse(
conversation_id=self._conversation.id, conversation_id=self._conversation.id,
message_id=self._message.id, message_id=self._message.id,
created_at=int(self._message.created_at.timestamp()), created_at=int(self._message.created_at.timestamp()),
stream_response=stream_response stream_response=stream_response,
) )
def _listenAudioMsg(self, publisher, task_id: str): def _listenAudioMsg(self, publisher, task_id: str):
@ -217,15 +210,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ def _wrapper_process_stream_response(
Generator[StreamResponse, None, None]: self, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
tenant_id = self._application_generate_entity.app_config.tenant_id tenant_id = self._application_generate_entity.app_config.tenant_id
task_id = self._application_generate_entity.task_id task_id = self._application_generate_entity.task_id
publisher = None publisher = None
text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech') text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech")
if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'): if (
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None)) text_to_speech_dict
and text_to_speech_dict.get("autoPlay") == "enabled"
and text_to_speech_dict.get("enabled")
):
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True: while True:
audio_response = self._listenAudioMsg(publisher, task_id) audio_response = self._listenAudioMsg(publisher, task_id)
@ -250,14 +247,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
break break
else: else:
start_listener_time = time.time() start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio.audio, yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id)
task_id=task_id) yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
def _process_stream_response( def _process_stream_response(
self, self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]: ) -> Generator[StreamResponse, None, None]:
""" """
Process stream response. Process stream response.
@ -333,9 +327,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if self._conversation_name_generate_thread: if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join() self._conversation_name_generate_thread.join()
def _save_message( def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None:
self, trace_manager: Optional[TraceQueueManager] = None
) -> None:
""" """
Save message. Save message.
:return: :return:
@ -347,31 +339,32 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
self._model_config.mode, self._model_config.mode, self._task_state.llm_result.prompt_messages
self._task_state.llm_result.prompt_messages
) )
self._message.message_tokens = usage.prompt_tokens self._message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit self._message.message_price_unit = usage.prompt_price_unit
self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \ self._message.answer = (
if llm_result.message.content else '' PromptTemplateParser.remove_template_variables(llm_result.message.content.strip())
if llm_result.message.content
else ""
)
self._message.answer_tokens = usage.completion_tokens self._message.answer_tokens = usage.completion_tokens
self._message.answer_unit_price = usage.completion_unit_price self._message.answer_unit_price = usage.completion_unit_price
self._message.answer_price_unit = usage.completion_price_unit self._message.answer_price_unit = usage.completion_price_unit
self._message.provider_response_latency = time.perf_counter() - self._start_at self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.total_price = usage.total_price self._message.total_price = usage.total_price
self._message.currency = usage.currency self._message.currency = usage.currency
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ self._message.message_metadata = (
if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
db.session.commit() db.session.commit()
if trace_manager: if trace_manager:
trace_manager.add_trace_task( trace_manager.add_trace_task(
TraceTask( TraceTask(
TraceTaskName.MESSAGE_TRACE, TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id
conversation_id=self._conversation.id,
message_id=self._message.id
) )
) )
@ -379,11 +372,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._message, self._message,
application_generate_entity=self._application_generate_entity, application_generate_entity=self._application_generate_entity,
conversation=self._conversation, conversation=self._conversation,
is_first_message=self._application_generate_entity.app_config.app_mode in [ is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT]
AppMode.AGENT_CHAT, and self._application_generate_entity.conversation_id is None,
AppMode.CHAT extras=self._application_generate_entity.extras,
] and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras
) )
def _handle_stop(self, event: QueueStopEvent) -> None: def _handle_stop(self, event: QueueStopEvent) -> None:
@ -395,22 +386,17 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
model = model_config.model model = model_config.model
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
model=model_config.model
) )
# calculate num tokens # calculate num tokens
prompt_tokens = 0 prompt_tokens = 0
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY: if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
prompt_tokens = model_instance.get_llm_num_tokens( prompt_tokens = model_instance.get_llm_num_tokens(self._task_state.llm_result.prompt_messages)
self._task_state.llm_result.prompt_messages
)
completion_tokens = 0 completion_tokens = 0
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL: if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
completion_tokens = model_instance.get_llm_num_tokens( completion_tokens = model_instance.get_llm_num_tokens([self._task_state.llm_result.message])
[self._task_state.llm_result.message]
)
credentials = model_config.credentials credentials = model_config.credentials
@ -418,10 +404,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance) model_type_instance = cast(LargeLanguageModel, model_type_instance)
self._task_state.llm_result.usage = model_type_instance._calc_response_usage( self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
model, model, credentials, prompt_tokens, completion_tokens
credentials,
prompt_tokens,
completion_tokens
) )
def _message_end_to_stream_response(self) -> MessageEndStreamResponse: def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
@ -429,16 +412,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
Message end to stream response. Message end to stream response.
:return: :return:
""" """
self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage) self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage)
extras = {} extras = {}
if self._task_state.metadata: if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata extras["metadata"] = self._task_state.metadata
return MessageEndStreamResponse( return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
id=self._message.id,
**extras
) )
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
@ -449,9 +430,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
:return: :return:
""" """
return AgentMessageStreamResponse( return AgentMessageStreamResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, id=message_id, answer=answer
id=message_id,
answer=answer
) )
def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]: def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]:
@ -461,9 +440,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
:return: :return:
""" """
agent_thought: MessageAgentThought = ( agent_thought: MessageAgentThought = (
db.session.query(MessageAgentThought) db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
.filter(MessageAgentThought.id == event.agent_thought_id)
.first()
) )
db.session.refresh(agent_thought) db.session.refresh(agent_thought)
db.session.close() db.session.close()
@ -478,7 +455,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
tool=agent_thought.tool, tool=agent_thought.tool,
tool_labels=agent_thought.tool_labels, tool_labels=agent_thought.tool_labels,
tool_input=agent_thought.tool_input, tool_input=agent_thought.tool_input,
message_files=agent_thought.files message_files=agent_thought.files,
) )
return None return None
@ -500,15 +477,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
prompt_messages=self._task_state.llm_result.prompt_messages, prompt_messages=self._task_state.llm_result.prompt_messages,
delta=LLMResultChunkDelta( delta=LLMResultChunkDelta(
index=0, index=0,
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content) message=AssistantPromptMessage(content=self._task_state.llm_result.message.content),
) ),
) )
), PublishFrom.TASK_PIPELINE ),
PublishFrom.TASK_PIPELINE,
) )
self._queue_manager.publish( self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
PublishFrom.TASK_PIPELINE
) )
return True return True
else: else:
View File
@ -30,10 +30,7 @@ from services.annotation_service import AppAnnotationService
class MessageCycleManage: class MessageCycleManage:
_application_generate_entity: Union[ _application_generate_entity: Union[
ChatAppGenerateEntity, ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
] ]
_task_state: Union[EasyUITaskState, WorkflowTaskState] _task_state: Union[EasyUITaskState, WorkflowTaskState]
@ -49,15 +46,18 @@ class MessageCycleManage:
is_first_message = self._application_generate_entity.conversation_id is None is_first_message = self._application_generate_entity.conversation_id is None
extras = self._application_generate_entity.extras extras = self._application_generate_entity.extras
auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
if auto_generate_conversation_name and is_first_message: if auto_generate_conversation_name and is_first_message:
# start generation thread # start generation thread
thread = Thread(target=self._generate_conversation_name_worker, kwargs={ thread = Thread(
'flask_app': current_app._get_current_object(), # type: ignore target=self._generate_conversation_name_worker,
'conversation_id': conversation.id, kwargs={
'query': query "flask_app": current_app._get_current_object(), # type: ignore
}) "conversation_id": conversation.id,
"query": query,
},
)
thread.start() thread.start()
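A minimal standalone sketch of the background-thread pattern used here, with a hypothetical do_work task (worker threads lose the request context, so the Flask app object is passed in and its app context re-entered before any context-bound work):

from threading import Thread
from flask import Flask

app = Flask(__name__)  # stands in for current_app._get_current_object()

def do_work(flask_app: Flask, conversation_id: str, query: str):
    # Re-enter the app context before touching anything context-bound,
    # e.g. a SQLAlchemy session or app config.
    with flask_app.app_context():
        print(f"naming conversation {conversation_id} from query {query!r}")

Thread(target=do_work, kwargs={"flask_app": app, "conversation_id": "c1", "query": "hi"}).start()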
@ -65,17 +65,10 @@ class MessageCycleManage:
return None return None
def _generate_conversation_name_worker(self, def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
flask_app: Flask,
conversation_id: str,
query: str):
with flask_app.app_context(): with flask_app.app_context():
# get conversation and message # get conversation and message
conversation = ( conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
db.session.query(Conversation)
.filter(Conversation.id == conversation_id)
.first()
)
if not conversation: if not conversation:
return return
@ -105,12 +98,9 @@ class MessageCycleManage:
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
if annotation: if annotation:
account = annotation.account account = annotation.account
self._task_state.metadata['annotation_reply'] = { self._task_state.metadata["annotation_reply"] = {
'id': annotation.id, "id": annotation.id,
'account': { "account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
'id': annotation.account_id,
'name': account.name if account else 'Dify user'
}
} }
return annotation return annotation
@ -124,7 +114,7 @@ class MessageCycleManage:
:return: :return:
""" """
if self._application_generate_entity.app_config.additional_features.show_retrieve_source: if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata['retriever_resources'] = event.retriever_resources self._task_state.metadata["retriever_resources"] = event.retriever_resources
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
""" """
@ -132,27 +122,23 @@ class MessageCycleManage:
:param event: event :param event: event
:return: :return:
""" """
message_file = ( message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()
db.session.query(MessageFile)
.filter(MessageFile.id == event.message_file_id)
.first()
)
if message_file: if message_file:
# get tool file id # get tool file id
tool_file_id = message_file.url.split('/')[-1] tool_file_id = message_file.url.split("/")[-1]
# trim extension # trim extension
tool_file_id = tool_file_id.split('.')[0] tool_file_id = tool_file_id.split(".")[0]
# get extension # get extension
if '.' in message_file.url: if "." in message_file.url:
extension = f'.{message_file.url.split(".")[-1]}' extension = f'.{message_file.url.split(".")[-1]}'
if len(extension) > 10: if len(extension) > 10:
extension = '.bin' extension = ".bin"
else: else:
extension = '.bin' extension = ".bin"
# add sign url to local file # add sign url to local file
if message_file.url.startswith('http'): if message_file.url.startswith("http"):
url = message_file.url url = message_file.url
else: else:
url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension) url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension)
@ -161,8 +147,8 @@ class MessageCycleManage:
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
id=message_file.id, id=message_file.id,
type=message_file.type, type=message_file.type,
belongs_to=message_file.belongs_to or 'user', belongs_to=message_file.belongs_to or "user",
url=url url=url,
) )
return None return None
@ -174,11 +160,7 @@ class MessageCycleManage:
:param message_id: message id :param message_id: message id
:return: :return:
""" """
return MessageStreamResponse( return MessageStreamResponse(task_id=self._application_generate_entity.task_id, id=message_id, answer=answer)
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer
)
def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse: def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse:
""" """
@ -186,7 +168,4 @@ class MessageCycleManage:
:param answer: answer :param answer: answer
:return: :return:
""" """
return MessageReplaceStreamResponse( return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer)
task_id=self._application_generate_entity.task_id,
answer=answer
)
View File
@ -70,14 +70,14 @@ class WorkflowCycleManage:
inputs = {**self._application_generate_entity.inputs} inputs = {**self._application_generate_entity.inputs}
for key, value in (self._workflow_system_variables or {}).items(): for key, value in (self._workflow_system_variables or {}).items():
if key.value == 'conversation': if key.value == "conversation":
continue continue
inputs[f'sys.{key.value}'] = value inputs[f"sys.{key.value}"] = value
inputs = WorkflowEntry.handle_special_values(inputs) inputs = WorkflowEntry.handle_special_values(inputs)
triggered_from= ( triggered_from = (
WorkflowRunTriggeredFrom.DEBUGGING WorkflowRunTriggeredFrom.DEBUGGING
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
else WorkflowRunTriggeredFrom.APP_RUN else WorkflowRunTriggeredFrom.APP_RUN
@ -185,20 +185,26 @@ class WorkflowCycleManage:
db.session.commit() db.session.commit()
running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter( running_workflow_node_executions = (
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, db.session.query(WorkflowNodeExecution)
WorkflowNodeExecution.app_id == workflow_run.app_id, .filter(
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_run_id == workflow_run.id, WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
).all() WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
)
.all()
)
for workflow_node_execution in running_workflow_node_executions: for workflow_node_execution in running_workflow_node_executions:
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error workflow_node_execution.error = error
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds() workflow_node_execution.elapsed_time = (
workflow_node_execution.finished_at - workflow_node_execution.created_at
).total_seconds()
db.session.commit() db.session.commit()
db.session.refresh(workflow_run) db.session.refresh(workflow_run)
@ -216,7 +222,9 @@ class WorkflowCycleManage:
return workflow_run return workflow_run
def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: def _handle_node_execution_start(
self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
) -> WorkflowNodeExecution:
# init workflow node execution # init workflow node execution
workflow_node_execution = WorkflowNodeExecution() workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = workflow_run.tenant_id workflow_node_execution.tenant_id = workflow_run.tenant_id
@ -333,16 +341,16 @@ class WorkflowCycleManage:
created_by_account = workflow_run.created_by_account created_by_account = workflow_run.created_by_account
if created_by_account: if created_by_account:
created_by = { created_by = {
'id': created_by_account.id, "id": created_by_account.id,
'name': created_by_account.name, "name": created_by_account.name,
'email': created_by_account.email, "email": created_by_account.email,
} }
else: else:
created_by_end_user = workflow_run.created_by_end_user created_by_end_user = workflow_run.created_by_end_user
if created_by_end_user: if created_by_end_user:
created_by = { created_by = {
'id': created_by_end_user.id, "id": created_by_end_user.id,
'user': created_by_end_user.session_id, "user": created_by_end_user.session_id,
} }
return WorkflowFinishStreamResponse( return WorkflowFinishStreamResponse(
@ -401,7 +409,7 @@ class WorkflowCycleManage:
# extras logic # extras logic
if event.node_type == NodeType.TOOL: if event.node_type == NodeType.TOOL:
node_data = cast(ToolNodeData, event.node_data) node_data = cast(ToolNodeData, event.node_data)
response.data.extras['icon'] = ToolManager.get_tool_icon( response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id, tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=node_data.provider_type, provider_type=node_data.provider_type,
provider_id=node_data.provider_id, provider_id=node_data.provider_id,
@ -410,10 +418,10 @@ class WorkflowCycleManage:
return response return response
def _workflow_node_finish_to_stream_response( def _workflow_node_finish_to_stream_response(
self, self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent, event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
task_id: str, task_id: str,
workflow_node_execution: WorkflowNodeExecution workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]: ) -> Optional[NodeFinishStreamResponse]:
""" """
Workflow node finish to stream response. Workflow node finish to stream response.
@ -424,7 +432,7 @@ class WorkflowCycleManage:
""" """
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
return None return None
return NodeFinishStreamResponse( return NodeFinishStreamResponse(
task_id=task_id, task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id, workflow_run_id=workflow_node_execution.workflow_run_id,
@ -452,13 +460,10 @@ class WorkflowCycleManage:
iteration_id=event.in_iteration_id, iteration_id=event.in_iteration_id,
), ),
) )
def _workflow_parallel_branch_start_to_stream_response( def _workflow_parallel_branch_start_to_stream_response(
self, self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
task_id: str, ) -> ParallelBranchStartStreamResponse:
workflow_run: WorkflowRun,
event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:
""" """
Workflow parallel branch start to stream response Workflow parallel branch start to stream response
:param task_id: task id :param task_id: task id
@ -476,15 +481,15 @@ class WorkflowCycleManage:
parent_parallel_start_node_id=event.parent_parallel_start_node_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id, iteration_id=event.in_iteration_id,
created_at=int(time.time()), created_at=int(time.time()),
) ),
) )
def _workflow_parallel_branch_finished_to_stream_response( def _workflow_parallel_branch_finished_to_stream_response(
self, self,
task_id: str, task_id: str,
workflow_run: WorkflowRun, workflow_run: WorkflowRun,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse: ) -> ParallelBranchFinishedStreamResponse:
""" """
Workflow parallel branch finished to stream response Workflow parallel branch finished to stream response
:param task_id: task id :param task_id: task id
@ -501,18 +506,15 @@ class WorkflowCycleManage:
parent_parallel_id=event.parent_parallel_id, parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id, iteration_id=event.in_iteration_id,
status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed', status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None, error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
created_at=int(time.time()), created_at=int(time.time()),
) ),
) )
def _workflow_iteration_start_to_stream_response( def _workflow_iteration_start_to_stream_response(
self, self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
task_id: str, ) -> IterationNodeStartStreamResponse:
workflow_run: WorkflowRun,
event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse:
""" """
Workflow iteration start to stream response Workflow iteration start to stream response
:param task_id: task id :param task_id: task id
@ -534,10 +536,12 @@ class WorkflowCycleManage:
metadata=event.metadata or {}, metadata=event.metadata or {},
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
) ),
) )
def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse: def _workflow_iteration_next_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
) -> IterationNodeNextStreamResponse:
""" """
Workflow iteration next to stream response Workflow iteration next to stream response
:param task_id: task id :param task_id: task id
@ -559,10 +563,12 @@ class WorkflowCycleManage:
extras={}, extras={},
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
) ),
) )
def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse: def _workflow_iteration_completed_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
) -> IterationNodeCompletedStreamResponse:
""" """
Workflow iteration completed to stream response Workflow iteration completed to stream response
:param task_id: task id :param task_id: task id
@ -585,13 +591,13 @@ class WorkflowCycleManage:
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
error=None, error=None,
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(), elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0, total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
execution_metadata=event.metadata, execution_metadata=event.metadata,
finished_at=int(time.time()), finished_at=int(time.time()),
steps=event.steps, steps=event.steps,
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
) ),
) )
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]: def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
@ -643,7 +649,7 @@ class WorkflowCycleManage:
return None return None
if isinstance(value, dict): if isinstance(value, dict):
if '__variant' in value and value['__variant'] == FileVar.__name__: if "__variant" in value and value["__variant"] == FileVar.__name__:
return value return value
elif isinstance(value, FileVar): elif isinstance(value, FileVar):
return value.to_dict() return value.to_dict()
@ -656,11 +662,10 @@ class WorkflowCycleManage:
:param workflow_run_id: workflow run id :param workflow_run_id: workflow run id
:return: :return:
""" """
workflow_run = db.session.query(WorkflowRun).filter( workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
WorkflowRun.id == workflow_run_id).first()
if not workflow_run: if not workflow_run:
raise Exception(f'Workflow run not found: {workflow_run_id}') raise Exception(f"Workflow run not found: {workflow_run_id}")
return workflow_run return workflow_run
@ -683,6 +688,6 @@ class WorkflowCycleManage:
) )
if not workflow_node_execution: if not workflow_node_execution:
raise Exception(f'Workflow node execution not found: {node_execution_id}') raise Exception(f"Workflow node execution not found: {node_execution_id}")
return workflow_node_execution return workflow_node_execution
View File
@ -16,31 +16,32 @@ _TEXT_COLOR_MAPPING = {
"red": "31;1", "red": "31;1",
} }
def get_colored_text(text: str, color: str) -> str: def get_colored_text(text: str, color: str) -> str:
"""Get colored text.""" """Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color] color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
def print_text( def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None:
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
"""Print text with highlighting and no end characters.""" """Print text with highlighting and no end characters."""
text_to_print = get_colored_text(text, color) if color else text text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file) print(text_to_print, end=end, file=file)
if file: if file:
file.flush() # ensure all printed content is written to file file.flush() # ensure all printed content is written to file
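For illustration, a call to the helper above ("red" is one of the keys shown in _TEXT_COLOR_MAPPING):

print_text("[on_tool_start]\n", color="red", end="\n")  # rendered in color via ANSI escape codes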
class DifyAgentCallbackHandler(BaseModel): class DifyAgentCallbackHandler(BaseModel):
"""Callback Handler that prints to std out.""" """Callback Handler that prints to std out."""
color: Optional[str] = ''
color: Optional[str] = ""
current_loop: int = 1 current_loop: int = 1
def __init__(self, color: Optional[str] = None) -> None: def __init__(self, color: Optional[str] = None) -> None:
super().__init__() super().__init__()
"""Initialize callback handler.""" """Initialize callback handler."""
# use a default color if none is specified # use a default color if none is specified
self.color = color or 'green' self.color = color or "green"
self.current_loop = 1 self.current_loop = 1
def on_tool_start( def on_tool_start(
@ -58,7 +59,7 @@ class DifyAgentCallbackHandler(BaseModel):
tool_outputs: Sequence[ToolInvokeMessage], tool_outputs: Sequence[ToolInvokeMessage],
message_id: Optional[str] = None, message_id: Optional[str] = None,
timer: Optional[Any] = None, timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None trace_manager: Optional[TraceQueueManager] = None,
) -> None: ) -> None:
"""If not the final action, print out observation.""" """If not the final action, print out observation."""
print_text("\n[on_tool_end]\n", color=self.color) print_text("\n[on_tool_end]\n", color=self.color)
@ -79,26 +80,21 @@ class DifyAgentCallbackHandler(BaseModel):
) )
) )
def on_tool_error( def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing.""" """Do nothing."""
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='red') print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red")
def on_agent_start( def on_agent_start(self, thought: str) -> None:
self, thought: str
) -> None:
"""Run on agent start.""" """Run on agent start."""
if thought: if thought:
print_text("\n[on_agent_start] \nCurrent Loop: " + \ print_text(
str(self.current_loop) + \ "\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\nThought: " + thought + "\n",
"\nThought: " + thought + "\n", color=self.color) color=self.color,
)
else: else:
print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color) print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color)
def on_agent_finish( def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None:
self, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end.""" """Run on agent end."""
print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color) print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color)
@ -107,9 +103,9 @@ class DifyAgentCallbackHandler(BaseModel):
@property @property
def ignore_agent(self) -> bool: def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks.""" """Whether to ignore agent callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true"
@property @property
def ignore_chat_model(self) -> bool: def ignore_chat_model(self) -> bool:
"""Whether to ignore chat model callbacks.""" """Whether to ignore chat model callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true"
View File
@ -1,4 +1,3 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@ -11,11 +10,9 @@ from models.model import DatasetRetrieverResource
class DatasetIndexToolCallbackHandler: class DatasetIndexToolCallbackHandler:
"""Callback handler for dataset tool.""" """Callback handler for dataset tool."""
def __init__(self, queue_manager: AppQueueManager, def __init__(
app_id: str, self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom
message_id: str, ) -> None:
user_id: str,
invoke_from: InvokeFrom) -> None:
self._queue_manager = queue_manager self._queue_manager = queue_manager
self._app_id = app_id self._app_id = app_id
self._message_id = message_id self._message_id = message_id
@ -29,11 +26,12 @@ class DatasetIndexToolCallbackHandler:
dataset_query = DatasetQuery( dataset_query = DatasetQuery(
dataset_id=dataset_id, dataset_id=dataset_id,
content=query, content=query,
source='app', source="app",
source_app_id=self._app_id, source_app_id=self._app_id,
created_by_role=('account' created_by_role=(
if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'), "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
created_by=self._user_id ),
created_by=self._user_id,
) )
db.session.add(dataset_query) db.session.add(dataset_query)
@ -43,18 +41,15 @@ class DatasetIndexToolCallbackHandler:
"""Handle tool end.""" """Handle tool end."""
for document in documents: for document in documents:
query = db.session.query(DocumentSegment).filter( query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata['doc_id'] DocumentSegment.index_node_id == document.metadata["doc_id"]
) )
# if 'dataset_id' in document.metadata: # if 'dataset_id' in document.metadata:
if 'dataset_id' in document.metadata: if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment # add hit count to document segment
query.update( query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False
)
db.session.commit() db.session.commit()
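The query.update(...) above increments server-side; a self-contained sketch of the same pattern with an illustrative Segment model (not the real DocumentSegment):

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Segment(Base):  # illustrative stand-in for DocumentSegment
    __tablename__ = "segments"
    id = Column(Integer, primary_key=True)
    index_node_id = Column(String)
    hit_count = Column(Integer, default=0)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Segment(index_node_id="doc-1", hit_count=0))
    session.commit()
    # Set-based UPDATE: "SET hit_count = hit_count + 1" runs inside the database,
    # so concurrent hits cannot race; synchronize_session=False skips reconciling
    # objects already loaded in this session.
    session.query(Segment).filter(Segment.index_node_id == "doc-1").update(
        {Segment.hit_count: Segment.hit_count + 1}, synchronize_session=False
    )
    session.commit()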
@ -64,26 +59,25 @@ class DatasetIndexToolCallbackHandler:
for item in resource: for item in resource:
dataset_retriever_resource = DatasetRetrieverResource( dataset_retriever_resource = DatasetRetrieverResource(
message_id=self._message_id, message_id=self._message_id,
position=item.get('position'), position=item.get("position"),
dataset_id=item.get('dataset_id'), dataset_id=item.get("dataset_id"),
dataset_name=item.get('dataset_name'), dataset_name=item.get("dataset_name"),
document_id=item.get('document_id'), document_id=item.get("document_id"),
document_name=item.get('document_name'), document_name=item.get("document_name"),
data_source_type=item.get('data_source_type'), data_source_type=item.get("data_source_type"),
segment_id=item.get('segment_id'), segment_id=item.get("segment_id"),
score=item.get('score') if 'score' in item else None, score=item.get("score") if "score" in item else None,
hit_count=item.get('hit_count') if 'hit_count' in item else None, hit_count=item.get("hit_count") if "hit_count" in item else None,
word_count=item.get('word_count') if 'word_count' in item else None, word_count=item.get("word_count") if "word_count" in item else None,
segment_position=item.get('segment_position') if 'segment_position' in item else None, segment_position=item.get("segment_position") if "segment_position" in item else None,
index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None, index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
content=item.get('content'), content=item.get("content"),
retriever_from=item.get('retriever_from'), retriever_from=item.get("retriever_from"),
created_by=self._user_id created_by=self._user_id,
) )
db.session.add(dataset_retriever_resource) db.session.add(dataset_retriever_resource)
db.session.commit() db.session.commit()
self._queue_manager.publish( self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource), QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
PublishFrom.APPLICATION_MANAGER
) )
View File
@ -2,4 +2,4 @@ from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackH
class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler): class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
"""Callback Handler that prints to std out.""" """Callback Handler that prints to std out."""
View File
@ -29,9 +29,13 @@ class CacheEmbedding(Embeddings):
embedding_queue_indices = [] embedding_queue_indices = []
for i, text in enumerate(texts): for i, text in enumerate(texts):
hash = helper.generate_text_hash(text) hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, embedding = (
hash=hash, db.session.query(Embedding)
provider_name=self._model_instance.provider).first() .filter_by(
model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider
)
.first()
)
if embedding: if embedding:
text_embeddings[i] = embedding.get_embedding() text_embeddings[i] = embedding.get_embedding()
else: else:
@ -41,17 +45,18 @@ class CacheEmbedding(Embeddings):
embedding_queue_embeddings = [] embedding_queue_embeddings = []
try: try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance) model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(self._model_instance.model, model_schema = model_type_instance.get_model_schema(
self._model_instance.credentials) self._model_instance.model, self._model_instance.credentials
max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \ )
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1 max_chunks = (
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
else 1
)
for i in range(0, len(embedding_queue_texts), max_chunks): for i in range(0, len(embedding_queue_texts), max_chunks):
batch_texts = embedding_queue_texts[i:i + max_chunks] batch_texts = embedding_queue_texts[i : i + max_chunks]
embedding_result = self._model_instance.invoke_text_embedding( embedding_result = self._model_instance.invoke_text_embedding(texts=batch_texts, user=self._user)
texts=batch_texts,
user=self._user
)
for vector in embedding_result.embeddings: for vector in embedding_result.embeddings:
try: try:
@ -60,16 +65,18 @@ class CacheEmbedding(Embeddings):
except IntegrityError: except IntegrityError:
db.session.rollback() db.session.rollback()
except Exception as e: except Exception as e:
logging.exception('Failed to transform embedding: %s', e) logging.exception("Failed to transform embedding: %s", e)
cache_embeddings = [] cache_embeddings = []
try: try:
for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
text_embeddings[i] = embedding text_embeddings[i] = embedding
hash = helper.generate_text_hash(texts[i]) hash = helper.generate_text_hash(texts[i])
if hash not in cache_embeddings: if hash not in cache_embeddings:
embedding_cache = Embedding(model_name=self._model_instance.model, embedding_cache = Embedding(
hash=hash, model_name=self._model_instance.model,
provider_name=self._model_instance.provider) hash=hash,
provider_name=self._model_instance.provider,
)
embedding_cache.set_embedding(embedding) embedding_cache.set_embedding(embedding)
db.session.add(embedding_cache) db.session.add(embedding_cache)
cache_embeddings.append(hash) cache_embeddings.append(hash)
@ -78,7 +85,7 @@ class CacheEmbedding(Embeddings):
db.session.rollback() db.session.rollback()
except Exception as ex: except Exception as ex:
db.session.rollback() db.session.rollback()
logger.error('Failed to embed documents: %s', ex) logger.error("Failed to embed documents: %s", ex)
raise ex raise ex
return text_embeddings return text_embeddings
@ -87,16 +94,13 @@ class CacheEmbedding(Embeddings):
"""Embed query text.""" """Embed query text."""
# use doc embedding cache or store if not exists # use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text) hash = helper.generate_text_hash(text)
embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}' embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
embedding = redis_client.get(embedding_cache_key) embedding = redis_client.get(embedding_cache_key)
if embedding: if embedding:
redis_client.expire(embedding_cache_key, 600) redis_client.expire(embedding_cache_key, 600)
return list(np.frombuffer(base64.b64decode(embedding), dtype="float")) return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
try: try:
embedding_result = self._model_instance.invoke_text_embedding( embedding_result = self._model_instance.invoke_text_embedding(texts=[text], user=self._user)
texts=[text],
user=self._user
)
embedding_results = embedding_result.embeddings[0] embedding_results = embedding_result.embeddings[0]
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
@ -116,6 +120,6 @@ class CacheEmbedding(Embeddings):
except IntegrityError: except IntegrityError:
db.session.rollback() db.session.rollback()
except Exception: except Exception:
logging.exception('Failed to add embedding to redis') logging.exception("Failed to add embedding to redis")
return embedding_results return embedding_results
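A minimal sketch of the Redis serialization round-trip used in embed_query above (the function names are illustrative): vectors are stored as base64-encoded float64 bytes and decoded with np.frombuffer on a cache hit.

import base64

import numpy as np

def encode_embedding(vector: list[float]) -> bytes:
    # Raw float64 bytes, base64-encoded for safe storage in Redis.
    return base64.b64encode(np.array(vector, dtype="float").tobytes())

def decode_embedding(cached: bytes) -> list[float]:
    # Mirrors the cache-hit path: np.frombuffer(base64.b64decode(...), dtype="float")
    return list(np.frombuffer(base64.b64decode(cached), dtype="float"))

assert decode_embedding(encode_embedding([0.1, 0.2])) == [0.1, 0.2]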
View File
@ -2,7 +2,7 @@ from enum import Enum
class PlanningStrategy(Enum): class PlanningStrategy(Enum):
ROUTER = 'router' ROUTER = "router"
REACT_ROUTER = 'react_router' REACT_ROUTER = "react_router"
REACT = 'react' REACT = "react"
FUNCTION_CALL = 'function_call' FUNCTION_CALL = "function_call"
View File
@ -5,7 +5,7 @@ from pydantic import BaseModel
class PromptMessageFileType(enum.Enum): class PromptMessageFileType(enum.Enum):
IMAGE = 'image' IMAGE = "image"
@staticmethod @staticmethod
def value_of(value): def value_of(value):
@ -22,8 +22,8 @@ class PromptMessageFile(BaseModel):
class ImagePromptMessageFile(PromptMessageFile): class ImagePromptMessageFile(PromptMessageFile):
class DETAIL(enum.Enum): class DETAIL(enum.Enum):
LOW = 'low' LOW = "low"
HIGH = 'high' HIGH = "high"
type: PromptMessageFileType = PromptMessageFileType.IMAGE type: PromptMessageFileType = PromptMessageFileType.IMAGE
detail: DETAIL = DETAIL.LOW detail: DETAIL = DETAIL.LOW
View File
@ -12,6 +12,7 @@ class ModelStatus(Enum):
""" """
Enum class for model status. Enum class for model status.
""" """
ACTIVE = "active" ACTIVE = "active"
NO_CONFIGURE = "no-configure" NO_CONFIGURE = "no-configure"
QUOTA_EXCEEDED = "quota-exceeded" QUOTA_EXCEEDED = "quota-exceeded"
@ -23,6 +24,7 @@ class SimpleModelProviderEntity(BaseModel):
""" """
Simple provider. Simple provider.
""" """
provider: str provider: str
label: I18nObject label: I18nObject
icon_small: Optional[I18nObject] = None icon_small: Optional[I18nObject] = None
@ -40,7 +42,7 @@ class SimpleModelProviderEntity(BaseModel):
label=provider_entity.label, label=provider_entity.label,
icon_small=provider_entity.icon_small, icon_small=provider_entity.icon_small,
icon_large=provider_entity.icon_large, icon_large=provider_entity.icon_large,
supported_model_types=provider_entity.supported_model_types supported_model_types=provider_entity.supported_model_types,
) )
@ -48,6 +50,7 @@ class ProviderModelWithStatusEntity(ProviderModel):
""" """
Model class for model response. Model class for model response.
""" """
status: ModelStatus status: ModelStatus
load_balancing_enabled: bool = False load_balancing_enabled: bool = False
@ -56,6 +59,7 @@ class ModelWithProviderEntity(ProviderModelWithStatusEntity):
""" """
Model with provider entity. Model with provider entity.
""" """
provider: SimpleModelProviderEntity provider: SimpleModelProviderEntity
@ -63,6 +67,7 @@ class DefaultModelProviderEntity(BaseModel):
""" """
Default model provider entity. Default model provider entity.
""" """
provider: str provider: str
label: I18nObject label: I18nObject
icon_small: Optional[I18nObject] = None icon_small: Optional[I18nObject] = None
@ -74,6 +79,7 @@ class DefaultModelEntity(BaseModel):
""" """
Default model entity. Default model entity.
""" """
model: str model: str
model_type: ModelType model_type: ModelType
provider: DefaultModelProviderEntity provider: DefaultModelProviderEntity
View File
@ -47,6 +47,7 @@ class ProviderConfiguration(BaseModel):
""" """
Model class for provider configuration. Model class for provider configuration.
""" """
tenant_id: str tenant_id: str
provider: ProviderEntity provider: ProviderEntity
preferred_provider_type: ProviderType preferred_provider_type: ProviderType
@ -67,9 +68,13 @@ class ProviderConfiguration(BaseModel):
original_provider_configurate_methods[self.provider.provider].append(configurate_method) original_provider_configurate_methods[self.provider.provider].append(configurate_method)
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
if (any(len(quota_configuration.restrict_models) > 0 if (
for quota_configuration in self.system_configuration.quota_configurations) any(
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods): len(quota_configuration.restrict_models) > 0
for quota_configuration in self.system_configuration.quota_configurations
)
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
):
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
@ -83,10 +88,9 @@ class ProviderConfiguration(BaseModel):
if self.model_settings: if self.model_settings:
# check if model is disabled by admin # check if model is disabled by admin
for model_setting in self.model_settings: for model_setting in self.model_settings:
if (model_setting.model_type == model_type if model_setting.model_type == model_type and model_setting.model == model:
and model_setting.model == model):
if not model_setting.enabled: if not model_setting.enabled:
raise ValueError(f'Model {model} is disabled.') raise ValueError(f"Model {model} is disabled.")
if self.using_provider_type == ProviderType.SYSTEM: if self.using_provider_type == ProviderType.SYSTEM:
restrict_models = [] restrict_models = []
@ -99,10 +103,12 @@ class ProviderConfiguration(BaseModel):
copy_credentials = self.system_configuration.credentials.copy() copy_credentials = self.system_configuration.credentials.copy()
if restrict_models: if restrict_models:
for restrict_model in restrict_models: for restrict_model in restrict_models:
if (restrict_model.model_type == model_type if (
and restrict_model.model == model restrict_model.model_type == model_type
and restrict_model.base_model_name): and restrict_model.model == model
copy_credentials['base_model_name'] = restrict_model.base_model_name and restrict_model.base_model_name
):
copy_credentials["base_model_name"] = restrict_model.base_model_name
return copy_credentials return copy_credentials
else: else:
@ -128,20 +134,21 @@ class ProviderConfiguration(BaseModel):
current_quota_type = self.system_configuration.current_quota_type current_quota_type = self.system_configuration.current_quota_type
current_quota_configuration = next( current_quota_configuration = next(
(q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
None
) )
return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \ return (
SystemConfigurationStatus.QUOTA_EXCEEDED SystemConfigurationStatus.ACTIVE
if current_quota_configuration.is_valid
else SystemConfigurationStatus.QUOTA_EXCEEDED
)
def is_custom_configuration_available(self) -> bool: def is_custom_configuration_available(self) -> bool:
""" """
Check custom configuration available. Check custom configuration available.
:return: :return:
""" """
return (self.custom_configuration.provider is not None return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
or len(self.custom_configuration.models) > 0)
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
""" """
@ -161,7 +168,8 @@ class ProviderConfiguration(BaseModel):
return self.obfuscated_credentials( return self.obfuscated_credentials(
credentials=credentials, credentials=credentials,
credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema else [] if self.provider.provider_credential_schema
else [],
) )
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]: def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
@ -171,17 +179,21 @@ class ProviderConfiguration(BaseModel):
:return: :return:
""" """
# get provider # get provider
provider_record = db.session.query(Provider) \ provider_record = (
db.session.query(Provider)
.filter( .filter(
Provider.tenant_id == self.tenant_id, Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.provider.provider, Provider.provider_name == self.provider.provider,
Provider.provider_type == ProviderType.CUSTOM.value Provider.provider_type == ProviderType.CUSTOM.value,
).first() )
.first()
)
# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables( provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema else [] if self.provider.provider_credential_schema
else []
) )
if provider_record: if provider_record:
@ -189,9 +201,7 @@ class ProviderConfiguration(BaseModel):
# fix origin data # fix origin data
if provider_record.encrypted_config: if provider_record.encrypted_config:
if not provider_record.encrypted_config.startswith("{"): if not provider_record.encrypted_config.startswith("{"):
original_credentials = { original_credentials = {"openai_api_key": provider_record.encrypted_config}
"openai_api_key": provider_record.encrypted_config
}
else: else:
original_credentials = json.loads(provider_record.encrypted_config) original_credentials = json.loads(provider_record.encrypted_config)
else: else:
@ -207,8 +217,7 @@ class ProviderConfiguration(BaseModel):
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
credentials = model_provider_factory.provider_credentials_validate( credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, provider=self.provider.provider, credentials=credentials
credentials=credentials
) )
for key, value in credentials.items(): for key, value in credentials.items():
@ -239,15 +248,13 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider, provider_name=self.provider.provider,
provider_type=ProviderType.CUSTOM.value, provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(credentials), encrypted_config=json.dumps(credentials),
is_valid=True is_valid=True,
) )
db.session.add(provider_record) db.session.add(provider_record)
db.session.commit() db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache( provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id, tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER
) )
provider_model_credentials_cache.delete() provider_model_credentials_cache.delete()
@ -260,12 +267,15 @@ class ProviderConfiguration(BaseModel):
:return: :return:
""" """
# get provider # get provider
provider_record = db.session.query(Provider) \ provider_record = (
db.session.query(Provider)
.filter( .filter(
Provider.tenant_id == self.tenant_id, Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.provider.provider, Provider.provider_name == self.provider.provider,
Provider.provider_type == ProviderType.CUSTOM.value Provider.provider_type == ProviderType.CUSTOM.value,
).first() )
.first()
)
# delete provider # delete provider
if provider_record: if provider_record:
@ -277,13 +287,14 @@ class ProviderConfiguration(BaseModel):
provider_model_credentials_cache = ProviderCredentialsCache( provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
identity_id=provider_record.id, identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER cache_type=ProviderCredentialsCacheType.PROVIDER,
) )
provider_model_credentials_cache.delete() provider_model_credentials_cache.delete()
def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \ def get_custom_model_credentials(
-> Optional[dict]: self, model_type: ModelType, model: str, obfuscated: bool = False
) -> Optional[dict]:
""" """
Get custom model credentials. Get custom model credentials.
@ -305,13 +316,15 @@ class ProviderConfiguration(BaseModel):
return self.obfuscated_credentials( return self.obfuscated_credentials(
credentials=credentials, credentials=credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema else [] if self.provider.model_credential_schema
else [],
) )
return None return None
def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \ def custom_model_credentials_validate(
-> tuple[ProviderModel, dict]: self, model_type: ModelType, model: str, credentials: dict
) -> tuple[ProviderModel, dict]:
""" """
Validate custom model credentials. Validate custom model credentials.
@ -321,24 +334,29 @@ class ProviderConfiguration(BaseModel):
:return: :return:
""" """
# get provider model # get provider model
provider_model_record = db.session.query(ProviderModel) \ provider_model_record = (
db.session.query(ProviderModel)
.filter( .filter(
ProviderModel.tenant_id == self.tenant_id, ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider, ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model, ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type() ProviderModel.model_type == model_type.to_origin_model_type(),
).first() )
.first()
)
# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables( provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema else [] if self.provider.model_credential_schema
else []
) )
if provider_model_record: if provider_model_record:
try: try:
original_credentials = json.loads( original_credentials = (
provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
)
except JSONDecodeError: except JSONDecodeError:
original_credentials = {} original_credentials = {}
@ -350,10 +368,7 @@ class ProviderConfiguration(BaseModel):
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
credentials = model_provider_factory.model_credentials_validate( credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
model_type=model_type,
model=model,
credentials=credentials
) )
for key, value in credentials.items(): for key, value in credentials.items():
@ -388,7 +403,7 @@ class ProviderConfiguration(BaseModel):
model_name=model, model_name=model,
model_type=model_type.to_origin_model_type(), model_type=model_type.to_origin_model_type(),
encrypted_config=json.dumps(credentials), encrypted_config=json.dumps(credentials),
is_valid=True is_valid=True,
) )
db.session.add(provider_model_record) db.session.add(provider_model_record)
db.session.commit() db.session.commit()
@ -396,7 +411,7 @@ class ProviderConfiguration(BaseModel):
provider_model_credentials_cache = ProviderCredentialsCache( provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
identity_id=provider_model_record.id, identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL cache_type=ProviderCredentialsCacheType.MODEL,
) )
provider_model_credentials_cache.delete() provider_model_credentials_cache.delete()
@ -409,13 +424,16 @@ class ProviderConfiguration(BaseModel):
:return: :return:
""" """
# get provider model # get provider model
provider_model_record = db.session.query(ProviderModel) \ provider_model_record = (
db.session.query(ProviderModel)
.filter( .filter(
ProviderModel.tenant_id == self.tenant_id, ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider, ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model, ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type() ProviderModel.model_type == model_type.to_origin_model_type(),
).first() )
.first()
)
# delete provider model # delete provider model
if provider_model_record: if provider_model_record:
@ -425,7 +443,7 @@ class ProviderConfiguration(BaseModel):
provider_model_credentials_cache = ProviderCredentialsCache( provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
identity_id=provider_model_record.id, identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL cache_type=ProviderCredentialsCacheType.MODEL,
) )
provider_model_credentials_cache.delete() provider_model_credentials_cache.delete()
@ -437,13 +455,16 @@ class ProviderConfiguration(BaseModel):
:param model: model name :param model: model name
:return: :return:
""" """
model_setting = db.session.query(ProviderModelSetting) \ model_setting = (
db.session.query(ProviderModelSetting)
.filter( .filter(
ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider, ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model ProviderModelSetting.model_name == model,
).first() )
.first()
)
if model_setting: if model_setting:
model_setting.enabled = True model_setting.enabled = True
@ -455,7 +476,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider, provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(), model_type=model_type.to_origin_model_type(),
model_name=model, model_name=model,
enabled=True enabled=True,
) )
db.session.add(model_setting) db.session.add(model_setting)
db.session.commit() db.session.commit()
@ -469,13 +490,16 @@ class ProviderConfiguration(BaseModel):
:param model: model name :param model: model name
:return: :return:
""" """
model_setting = db.session.query(ProviderModelSetting) \ model_setting = (
db.session.query(ProviderModelSetting)
.filter( .filter(
ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider, ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model ProviderModelSetting.model_name == model,
).first() )
.first()
)
if model_setting: if model_setting:
model_setting.enabled = False model_setting.enabled = False
@ -487,7 +511,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider, provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(), model_type=model_type.to_origin_model_type(),
model_name=model, model_name=model,
enabled=False enabled=False,
) )
db.session.add(model_setting) db.session.add(model_setting)
db.session.commit() db.session.commit()
@ -501,13 +525,16 @@ class ProviderConfiguration(BaseModel):
:param model: model name :param model: model name
:return: :return:
""" """
return db.session.query(ProviderModelSetting) \ return (
db.session.query(ProviderModelSetting)
.filter( .filter(
ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider, ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model ProviderModelSetting.model_name == model,
).first() )
.first()
)
def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
""" """
@ -516,24 +543,30 @@ class ProviderConfiguration(BaseModel):
:param model: model name :param model: model name
:return: :return:
""" """
load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \ load_balancing_config_count = (
db.session.query(LoadBalancingModelConfig)
.filter( .filter(
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider, LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model LoadBalancingModelConfig.model_name == model,
).count() )
.count()
)
if load_balancing_config_count <= 1: if load_balancing_config_count <= 1:
raise ValueError('Model load balancing configuration must be more than 1.') raise ValueError("Model load balancing configuration must be more than 1.")
model_setting = db.session.query(ProviderModelSetting) \ model_setting = (
db.session.query(ProviderModelSetting)
.filter( .filter(
ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider, ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model ProviderModelSetting.model_name == model,
).first() )
.first()
)
if model_setting: if model_setting:
model_setting.load_balancing_enabled = True model_setting.load_balancing_enabled = True
@ -545,7 +578,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider, provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(), model_type=model_type.to_origin_model_type(),
model_name=model, model_name=model,
load_balancing_enabled=True load_balancing_enabled=True,
) )
db.session.add(model_setting) db.session.add(model_setting)
db.session.commit() db.session.commit()
@ -559,13 +592,16 @@ class ProviderConfiguration(BaseModel):
:param model: model name :param model: model name
:return: :return:
""" """
model_setting = db.session.query(ProviderModelSetting) \ model_setting = (
db.session.query(ProviderModelSetting)
.filter( .filter(
ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider, ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model ProviderModelSetting.model_name == model,
).first() )
.first()
)
if model_setting: if model_setting:
model_setting.load_balancing_enabled = False model_setting.load_balancing_enabled = False
@ -577,7 +613,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider, provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(), model_type=model_type.to_origin_model_type(),
model_name=model, model_name=model,
load_balancing_enabled=False load_balancing_enabled=False,
) )
db.session.add(model_setting) db.session.add(model_setting)
db.session.commit() db.session.commit()
@ -617,11 +653,14 @@ class ProviderConfiguration(BaseModel):
return return
# get preferred provider # get preferred provider
preferred_model_provider = db.session.query(TenantPreferredModelProvider) \ preferred_model_provider = (
db.session.query(TenantPreferredModelProvider)
.filter( .filter(
TenantPreferredModelProvider.tenant_id == self.tenant_id, TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name == self.provider.provider TenantPreferredModelProvider.provider_name == self.provider.provider,
).first() )
.first()
)
if preferred_model_provider: if preferred_model_provider:
preferred_model_provider.preferred_provider_type = provider_type.value preferred_model_provider.preferred_provider_type = provider_type.value
@ -629,7 +668,7 @@ class ProviderConfiguration(BaseModel):
preferred_model_provider = TenantPreferredModelProvider( preferred_model_provider = TenantPreferredModelProvider(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
provider_name=self.provider.provider, provider_name=self.provider.provider,
preferred_provider_type=provider_type.value preferred_provider_type=provider_type.value,
) )
db.session.add(preferred_model_provider) db.session.add(preferred_model_provider)
@ -658,9 +697,7 @@ class ProviderConfiguration(BaseModel):
:return: :return:
""" """
# Get provider credential secret variables # Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables( credential_secret_variables = self.extract_secret_variables(credential_form_schemas)
credential_form_schemas
)
# Obfuscate provider credentials # Obfuscate provider credentials
copy_credentials = credentials.copy() copy_credentials = credentials.copy()
@ -670,9 +707,9 @@ class ProviderConfiguration(BaseModel):
return copy_credentials return copy_credentials
def get_provider_model(self, model_type: ModelType, def get_provider_model(
model: str, self, model_type: ModelType, model: str, only_active: bool = False
only_active: bool = False) -> Optional[ModelWithProviderEntity]: ) -> Optional[ModelWithProviderEntity]:
""" """
Get provider model. Get provider model.
:param model_type: model type :param model_type: model type
@ -688,8 +725,9 @@ class ProviderConfiguration(BaseModel):
return None return None
def get_provider_models(self, model_type: Optional[ModelType] = None, def get_provider_models(
only_active: bool = False) -> list[ModelWithProviderEntity]: self, model_type: Optional[ModelType] = None, only_active: bool = False
) -> list[ModelWithProviderEntity]:
""" """
Get provider models. Get provider models.
:param model_type: model type :param model_type: model type
@ -711,15 +749,11 @@ class ProviderConfiguration(BaseModel):
if self.using_provider_type == ProviderType.SYSTEM: if self.using_provider_type == ProviderType.SYSTEM:
provider_models = self._get_system_provider_models( provider_models = self._get_system_provider_models(
model_types=model_types, model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
provider_instance=provider_instance,
model_setting_map=model_setting_map
) )
else: else:
provider_models = self._get_custom_provider_models( provider_models = self._get_custom_provider_models(
model_types=model_types, model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
provider_instance=provider_instance,
model_setting_map=model_setting_map
) )
if only_active: if only_active:
@ -728,11 +762,12 @@ class ProviderConfiguration(BaseModel):
# resort provider_models # resort provider_models
return sorted(provider_models, key=lambda x: x.model_type.value) return sorted(provider_models, key=lambda x: x.model_type.value)
def _get_system_provider_models(self, def _get_system_provider_models(
model_types: list[ModelType], self,
provider_instance: ModelProvider, model_types: list[ModelType],
model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \ provider_instance: ModelProvider,
-> list[ModelWithProviderEntity]: model_setting_map: dict[ModelType, dict[str, ModelSettings]],
) -> list[ModelWithProviderEntity]:
""" """
Get system provider models. Get system provider models.
@ -760,7 +795,7 @@ class ProviderConfiguration(BaseModel):
model_properties=m.model_properties, model_properties=m.model_properties,
deprecated=m.deprecated, deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider), provider=SimpleModelProviderEntity(self.provider),
status=status status=status,
) )
) )
@ -783,23 +818,20 @@ class ProviderConfiguration(BaseModel):
if should_use_custom_model: if should_use_custom_model:
if original_provider_configurate_methods[self.provider.provider] == [ if original_provider_configurate_methods[self.provider.provider] == [
ConfigurateMethod.CUSTOMIZABLE_MODEL]: ConfigurateMethod.CUSTOMIZABLE_MODEL
]:
# only customizable model # only customizable model
for restrict_model in restrict_models: for restrict_model in restrict_models:
copy_credentials = self.system_configuration.credentials.copy() copy_credentials = self.system_configuration.credentials.copy()
if restrict_model.base_model_name: if restrict_model.base_model_name:
copy_credentials['base_model_name'] = restrict_model.base_model_name copy_credentials["base_model_name"] = restrict_model.base_model_name
try: try:
custom_model_schema = ( custom_model_schema = provider_instance.get_model_instance(
provider_instance.get_model_instance(restrict_model.model_type) restrict_model.model_type
.get_customizable_model_schema_from_credentials( ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
restrict_model.model,
copy_credentials
)
)
except Exception as ex: except Exception as ex:
logger.warning(f'get custom model schema failed, {ex}') logger.warning(f"get custom model schema failed, {ex}")
continue continue
if not custom_model_schema: if not custom_model_schema:
@ -809,8 +841,10 @@ class ProviderConfiguration(BaseModel):
continue continue
status = ModelStatus.ACTIVE status = ModelStatus.ACTIVE
if (custom_model_schema.model_type in model_setting_map if (
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]): custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
):
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
if model_setting.enabled is False: if model_setting.enabled is False:
status = ModelStatus.DISABLED status = ModelStatus.DISABLED
@ -825,7 +859,7 @@ class ProviderConfiguration(BaseModel):
model_properties=custom_model_schema.model_properties, model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated, deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider), provider=SimpleModelProviderEntity(self.provider),
status=status status=status,
) )
) )
@ -839,11 +873,12 @@ class ProviderConfiguration(BaseModel):
return provider_models return provider_models
def _get_custom_provider_models(self, def _get_custom_provider_models(
model_types: list[ModelType], self,
provider_instance: ModelProvider, model_types: list[ModelType],
model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \ provider_instance: ModelProvider,
-> list[ModelWithProviderEntity]: model_setting_map: dict[ModelType, dict[str, ModelSettings]],
) -> list[ModelWithProviderEntity]:
""" """
Get custom provider models. Get custom provider models.
@ -885,7 +920,7 @@ class ProviderConfiguration(BaseModel):
deprecated=m.deprecated, deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider), provider=SimpleModelProviderEntity(self.provider),
status=status, status=status,
load_balancing_enabled=load_balancing_enabled load_balancing_enabled=load_balancing_enabled,
) )
) )
@ -895,15 +930,13 @@ class ProviderConfiguration(BaseModel):
continue continue
try: try:
custom_model_schema = ( custom_model_schema = provider_instance.get_model_instance(
provider_instance.get_model_instance(model_configuration.model_type) model_configuration.model_type
.get_customizable_model_schema_from_credentials( ).get_customizable_model_schema_from_credentials(
model_configuration.model, model_configuration.model, model_configuration.credentials
model_configuration.credentials
)
) )
except Exception as ex: except Exception as ex:
logger.warning(f'get custom model schema failed, {ex}') logger.warning(f"get custom model schema failed, {ex}")
continue continue
if not custom_model_schema: if not custom_model_schema:
@ -911,8 +944,10 @@ class ProviderConfiguration(BaseModel):
status = ModelStatus.ACTIVE status = ModelStatus.ACTIVE
load_balancing_enabled = False load_balancing_enabled = False
if (custom_model_schema.model_type in model_setting_map if (
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]): custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
):
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
if model_setting.enabled is False: if model_setting.enabled is False:
status = ModelStatus.DISABLED status = ModelStatus.DISABLED
@ -931,7 +966,7 @@ class ProviderConfiguration(BaseModel):
deprecated=custom_model_schema.deprecated, deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider), provider=SimpleModelProviderEntity(self.provider),
status=status, status=status,
load_balancing_enabled=load_balancing_enabled load_balancing_enabled=load_balancing_enabled,
) )
) )
@ -942,17 +977,16 @@ class ProviderConfigurations(BaseModel):
""" """
Model class for provider configuration dict. Model class for provider configuration dict.
""" """
tenant_id: str tenant_id: str
configurations: dict[str, ProviderConfiguration] = {} configurations: dict[str, ProviderConfiguration] = {}
def __init__(self, tenant_id: str): def __init__(self, tenant_id: str):
super().__init__(tenant_id=tenant_id) super().__init__(tenant_id=tenant_id)
def get_models(self, def get_models(
provider: Optional[str] = None, self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
model_type: Optional[ModelType] = None, ) -> list[ModelWithProviderEntity]:
only_active: bool = False) \
-> list[ModelWithProviderEntity]:
""" """
Get available models. Get available models.
@ -1019,10 +1053,10 @@ class ProviderModelBundle(BaseModel):
""" """
Provider model bundle. Provider model bundle.
""" """
configuration: ProviderConfiguration configuration: ProviderConfiguration
provider_instance: ModelProvider provider_instance: ModelProvider
model_type_instance: AIModel model_type_instance: AIModel
# pydantic configs # pydantic configs
model_config = ConfigDict(arbitrary_types_allowed=True, model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
protected_namespaces=())
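
Aside, not part of the diff: a minimal sketch of why protected_namespaces=() appears above. Pydantic v2 reserves the "model_" prefix, so a field such as model_type would otherwise trigger a namespace warning. The class and field below are illustrative only.

from pydantic import BaseModel, ConfigDict

class Bundle(BaseModel):
    # Without protected_namespaces=(), pydantic v2 warns that "model_type"
    # shadows its reserved "model_" namespace.
    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

    model_type: str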
View File
@ -8,18 +8,19 @@ from models.provider import ProviderQuotaType
class QuotaUnit(Enum): class QuotaUnit(Enum):
TIMES = 'times' TIMES = "times"
TOKENS = 'tokens' TOKENS = "tokens"
CREDITS = 'credits' CREDITS = "credits"
class SystemConfigurationStatus(Enum): class SystemConfigurationStatus(Enum):
""" """
Enum class for system configuration status. Enum class for system configuration status.
""" """
ACTIVE = 'active'
QUOTA_EXCEEDED = 'quota-exceeded' ACTIVE = "active"
UNSUPPORTED = 'unsupported' QUOTA_EXCEEDED = "quota-exceeded"
UNSUPPORTED = "unsupported"
class RestrictModel(BaseModel): class RestrictModel(BaseModel):
@ -35,6 +36,7 @@ class QuotaConfiguration(BaseModel):
""" """
Model class for provider quota configuration. Model class for provider quota configuration.
""" """
quota_type: ProviderQuotaType quota_type: ProviderQuotaType
quota_unit: QuotaUnit quota_unit: QuotaUnit
quota_limit: int quota_limit: int
@ -47,6 +49,7 @@ class SystemConfiguration(BaseModel):
""" """
Model class for provider system configuration. Model class for provider system configuration.
""" """
enabled: bool enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None current_quota_type: Optional[ProviderQuotaType] = None
quota_configurations: list[QuotaConfiguration] = [] quota_configurations: list[QuotaConfiguration] = []
@ -57,6 +60,7 @@ class CustomProviderConfiguration(BaseModel):
""" """
Model class for provider custom configuration. Model class for provider custom configuration.
""" """
credentials: dict credentials: dict
@ -64,6 +68,7 @@ class CustomModelConfiguration(BaseModel):
""" """
Model class for provider custom model configuration. Model class for provider custom model configuration.
""" """
model: str model: str
model_type: ModelType model_type: ModelType
credentials: dict credentials: dict
@ -76,6 +81,7 @@ class CustomConfiguration(BaseModel):
""" """
Model class for provider custom configuration. Model class for provider custom configuration.
""" """
provider: Optional[CustomProviderConfiguration] = None provider: Optional[CustomProviderConfiguration] = None
models: list[CustomModelConfiguration] = [] models: list[CustomModelConfiguration] = []
@ -84,6 +90,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
""" """
Class for model load balancing configuration. Class for model load balancing configuration.
""" """
id: str id: str
name: str name: str
credentials: dict credentials: dict
@ -93,6 +100,7 @@ class ModelSettings(BaseModel):
""" """
Model class for model settings. Model class for model settings.
""" """
model: str model: str
model_type: ModelType model_type: ModelType
enabled: bool = True enabled: bool = True
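
A hypothetical construction of the settings model above; ModelType.LLM is assumed to be a member of the runtime's ModelType enum and is not shown in this diff.

# Illustrative only; disables a single model while keeping defaults elsewhere.
setting = ModelSettings(model="gpt-4o", model_type=ModelType.LLM, enabled=False)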
View File
@ -3,6 +3,7 @@ from typing import Optional
class LLMError(Exception): class LLMError(Exception):
"""Base class for all LLM exceptions.""" """Base class for all LLM exceptions."""
description: Optional[str] = None description: Optional[str] = None
def __init__(self, description: Optional[str] = None) -> None: def __init__(self, description: Optional[str] = None) -> None:
@ -11,6 +12,7 @@ class LLMError(Exception):
class LLMBadRequestError(LLMError): class LLMBadRequestError(LLMError):
"""Raised when the LLM returns bad request.""" """Raised when the LLM returns bad request."""
description = "Bad Request" description = "Bad Request"
@ -18,6 +20,7 @@ class ProviderTokenNotInitError(Exception):
""" """
Custom exception raised when the provider token is not initialized. Custom exception raised when the provider token is not initialized.
""" """
description = "Provider Token Not Init" description = "Provider Token Not Init"
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -28,6 +31,7 @@ class QuotaExceededError(Exception):
""" """
Custom exception raised when the quota for a provider has been exceeded. Custom exception raised when the quota for a provider has been exceeded.
""" """
description = "Quota Exceeded" description = "Quota Exceeded"
@ -35,6 +39,7 @@ class AppInvokeQuotaExceededError(Exception):
""" """
Custom exception raised when the quota for an app has been exceeded. Custom exception raised when the quota for an app has been exceeded.
""" """
description = "App Invoke Quota Exceeded" description = "App Invoke Quota Exceeded"
@ -42,9 +47,11 @@ class ModelCurrentlyNotSupportError(Exception):
""" """
    Custom exception raised when the model is not supported     Custom exception raised when the model is not supported
""" """
description = "Model Currently Not Support" description = "Model Currently Not Support"
class InvokeRateLimitError(Exception): class InvokeRateLimitError(Exception):
"""Raised when the Invoke returns rate limit error.""" """Raised when the Invoke returns rate limit error."""
description = "Rate Limit Error" description = "Rate Limit Error"
View File
@ -20,10 +20,7 @@ class APIBasedExtensionRequestor:
:param params: the request params :param params: the request params
:return: the response json :return: the response json
""" """
headers = { headers = {"Content-Type": "application/json", "Authorization": "Bearer {}".format(self.api_key)}
"Content-Type": "application/json",
"Authorization": "Bearer {}".format(self.api_key)
}
url = self.api_endpoint url = self.api_endpoint
@ -32,20 +29,17 @@ class APIBasedExtensionRequestor:
proxies = None proxies = None
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxies = { proxies = {
'http': dify_config.SSRF_PROXY_HTTP_URL, "http": dify_config.SSRF_PROXY_HTTP_URL,
'https': dify_config.SSRF_PROXY_HTTPS_URL, "https": dify_config.SSRF_PROXY_HTTPS_URL,
} }
response = requests.request( response = requests.request(
method='POST', method="POST",
url=url, url=url,
json={ json={"point": point.value, "params": params},
'point': point.value,
'params': params
},
headers=headers, headers=headers,
timeout=self.timeout, timeout=self.timeout,
proxies=proxies proxies=proxies,
) )
except requests.exceptions.Timeout: except requests.exceptions.Timeout:
raise ValueError("request timeout") raise ValueError("request timeout")
@ -53,9 +47,8 @@ class APIBasedExtensionRequestor:
raise ValueError("request connection error") raise ValueError("request connection error")
if response.status_code != 200: if response.status_code != 200:
raise ValueError("request error, status_code: {}, content: {}".format( raise ValueError(
response.status_code, "request error, status_code: {}, content: {}".format(response.status_code, response.text[:100])
response.text[:100] )
))
return response.json() return response.json()
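
For context, a hedged usage sketch of the requestor above; the endpoint, key, and params are placeholders, and the extension point mirrors its use later in this diff.

requestor = APIBasedExtensionRequestor(
    api_endpoint="https://extensions.example.com/hook",  # placeholder
    api_key="secret-key",  # placeholder
)
response_json = requestor.request(
    point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY,
    params={"app_id": "app-1", "tool_variable": "weather", "inputs": {}, "query": "hello"},
)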
View File
@ -11,8 +11,8 @@ from core.helper.position_helper import sort_to_dict_by_position_map
class ExtensionModule(enum.Enum): class ExtensionModule(enum.Enum):
MODERATION = 'moderation' MODERATION = "moderation"
EXTERNAL_DATA_TOOL = 'external_data_tool' EXTERNAL_DATA_TOOL = "external_data_tool"
class ModuleExtension(BaseModel): class ModuleExtension(BaseModel):
@ -41,12 +41,12 @@ class Extensible:
position_map = {} position_map = {}
# get the path of the current class # get the path of the current class
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py') current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
current_dir_path = os.path.dirname(current_path) current_dir_path = os.path.dirname(current_path)
# traverse subdirectories # traverse subdirectories
for subdir_name in os.listdir(current_dir_path): for subdir_name in os.listdir(current_dir_path):
if subdir_name.startswith('__'): if subdir_name.startswith("__"):
continue continue
subdir_path = os.path.join(current_dir_path, subdir_name) subdir_path = os.path.join(current_dir_path, subdir_name)
@ -58,21 +58,21 @@ class Extensible:
            # in the front-end page and business logic, special handling applies.             # in the front-end page and business logic, special handling applies.
builtin = False builtin = False
position = None position = None
if '__builtin__' in file_names: if "__builtin__" in file_names:
builtin = True builtin = True
builtin_file_path = os.path.join(subdir_path, '__builtin__') builtin_file_path = os.path.join(subdir_path, "__builtin__")
if os.path.exists(builtin_file_path): if os.path.exists(builtin_file_path):
with open(builtin_file_path, encoding='utf-8') as f: with open(builtin_file_path, encoding="utf-8") as f:
position = int(f.read().strip()) position = int(f.read().strip())
position_map[extension_name] = position position_map[extension_name] = position
if (extension_name + '.py') not in file_names: if (extension_name + ".py") not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
continue continue
# Dynamic loading {subdir_name}.py file and find the subclass of Extensible # Dynamic loading {subdir_name}.py file and find the subclass of Extensible
py_path = os.path.join(subdir_path, extension_name + '.py') py_path = os.path.join(subdir_path, extension_name + ".py")
spec = importlib.util.spec_from_file_location(extension_name, py_path) spec = importlib.util.spec_from_file_location(extension_name, py_path)
if not spec or not spec.loader: if not spec or not spec.loader:
raise Exception(f"Failed to load module {extension_name} from {py_path}") raise Exception(f"Failed to load module {extension_name} from {py_path}")
@ -91,25 +91,29 @@ class Extensible:
json_data = {} json_data = {}
if not builtin: if not builtin:
if 'schema.json' not in file_names: if "schema.json" not in file_names:
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
continue continue
json_path = os.path.join(subdir_path, 'schema.json') json_path = os.path.join(subdir_path, "schema.json")
json_data = {} json_data = {}
if os.path.exists(json_path): if os.path.exists(json_path):
with open(json_path, encoding='utf-8') as f: with open(json_path, encoding="utf-8") as f:
json_data = json.load(f) json_data = json.load(f)
extensions.append(ModuleExtension( extensions.append(
extension_class=extension_class, ModuleExtension(
name=extension_name, extension_class=extension_class,
label=json_data.get('label'), name=extension_name,
form_schema=json_data.get('form_schema'), label=json_data.get("label"),
builtin=builtin, form_schema=json_data.get("form_schema"),
position=position builtin=builtin,
)) position=position,
)
)
sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name) sorted_extensions = sort_to_dict_by_position_map(
position_map=position_map, data=extensions, name_func=lambda x: x.name
)
return sorted_extensions return sorted_extensions
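
The scan above leans on importlib for dynamic loading; a standalone sketch of that step, with a hypothetical path.

import importlib.util

py_path = "core/moderation/keywords/keywords.py"  # hypothetical extension file
spec = importlib.util.spec_from_file_location("keywords", py_path)
if not spec or not spec.loader:
    raise Exception(f"Failed to load module keywords from {py_path}")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)  # executes the file; its classes become attributes of mod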
View File
@ -6,10 +6,7 @@ from core.moderation.base import Moderation
class Extension: class Extension:
__module_extensions: dict[str, dict[str, ModuleExtension]] = {} __module_extensions: dict[str, dict[str, ModuleExtension]] = {}
module_classes = { module_classes = {ExtensionModule.MODERATION: Moderation, ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool}
ExtensionModule.MODERATION: Moderation,
ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool
}
def init(self): def init(self):
for module, module_class in self.module_classes.items(): for module, module_class in self.module_classes.items():
View File
@ -30,10 +30,11 @@ class ApiExternalDataTool(ExternalDataTool):
raise ValueError("api_based_extension_id is required") raise ValueError("api_based_extension_id is required")
# get api_based_extension # get api_based_extension
api_based_extension = db.session.query(APIBasedExtension).filter( api_based_extension = (
APIBasedExtension.tenant_id == tenant_id, db.session.query(APIBasedExtension)
APIBasedExtension.id == api_based_extension_id .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
).first() .first()
)
if not api_based_extension: if not api_based_extension:
raise ValueError("api_based_extension_id is invalid") raise ValueError("api_based_extension_id is invalid")
@ -50,47 +51,42 @@ class ApiExternalDataTool(ExternalDataTool):
api_based_extension_id = self.config.get("api_based_extension_id") api_based_extension_id = self.config.get("api_based_extension_id")
# get api_based_extension # get api_based_extension
api_based_extension = db.session.query(APIBasedExtension).filter( api_based_extension = (
APIBasedExtension.tenant_id == self.tenant_id, db.session.query(APIBasedExtension)
APIBasedExtension.id == api_based_extension_id .filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
).first() .first()
)
if not api_based_extension: if not api_based_extension:
raise ValueError("[External data tool] API query failed, variable: {}, " raise ValueError(
"error: api_based_extension_id is invalid" "[External data tool] API query failed, variable: {}, "
.format(self.variable)) "error: api_based_extension_id is invalid".format(self.variable)
)
# decrypt api_key # decrypt api_key
api_key = encrypter.decrypt_token( api_key = encrypter.decrypt_token(tenant_id=self.tenant_id, token=api_based_extension.api_key)
tenant_id=self.tenant_id,
token=api_based_extension.api_key
)
try: try:
# request api # request api
requestor = APIBasedExtensionRequestor( requestor = APIBasedExtensionRequestor(api_endpoint=api_based_extension.api_endpoint, api_key=api_key)
api_endpoint=api_based_extension.api_endpoint,
api_key=api_key
)
except Exception as e: except Exception as e:
raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format( raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(self.variable, e))
self.variable,
e
))
response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={ response_json = requestor.request(
'app_id': self.app_id, point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY,
'tool_variable': self.variable, params={"app_id": self.app_id, "tool_variable": self.variable, "inputs": inputs, "query": query},
'inputs': inputs, )
'query': query
})
if 'result' not in response_json: if "result" not in response_json:
raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response" raise ValueError(
.format(self.variable)) "[External data tool] API query failed, variable: {}, error: result not found in response".format(
self.variable
)
)
if not isinstance(response_json['result'], str): if not isinstance(response_json["result"], str):
raise ValueError("[External data tool] API query failed, variable: {}, error: result is not string" raise ValueError(
.format(self.variable)) "[External data tool] API query failed, variable: {}, error: result is not string".format(self.variable)
)
return response_json['result'] return response_json["result"]
View File
@ -12,11 +12,14 @@ logger = logging.getLogger(__name__)
class ExternalDataFetch: class ExternalDataFetch:
def fetch(self, tenant_id: str, def fetch(
app_id: str, self,
external_data_tools: list[ExternalDataVariableEntity], tenant_id: str,
inputs: dict, app_id: str,
query: str) -> dict: external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
query: str,
) -> dict:
""" """
Fill in variable inputs from external data tools if exists. Fill in variable inputs from external data tools if exists.
@ -38,7 +41,7 @@ class ExternalDataFetch:
app_id, app_id,
tool, tool,
inputs, inputs,
query query,
) )
futures[future] = tool futures[future] = tool
@ -50,12 +53,15 @@ class ExternalDataFetch:
inputs.update(results) inputs.update(results)
return inputs return inputs
def _query_external_data_tool(self, flask_app: Flask, def _query_external_data_tool(
tenant_id: str, self,
app_id: str, flask_app: Flask,
external_data_tool: ExternalDataVariableEntity, tenant_id: str,
inputs: dict, app_id: str,
query: str) -> tuple[Optional[str], Optional[str]]: external_data_tool: ExternalDataVariableEntity,
inputs: dict,
query: str,
) -> tuple[Optional[str], Optional[str]]:
""" """
Query external data tool. Query external data tool.
:param flask_app: flask app :param flask_app: flask app
@ -72,17 +78,10 @@ class ExternalDataFetch:
tool_config = external_data_tool.config tool_config = external_data_tool.config
external_data_tool_factory = ExternalDataToolFactory( external_data_tool_factory = ExternalDataToolFactory(
name=tool_type, name=tool_type, tenant_id=tenant_id, app_id=app_id, variable=tool_variable, config=tool_config
tenant_id=tenant_id,
app_id=app_id,
variable=tool_variable,
config=tool_config
) )
# query external data tool # query external data tool
result = external_data_tool_factory.query( result = external_data_tool_factory.query(inputs=inputs, query=query)
inputs=inputs,
query=query
)
return tool_variable, result return tool_variable, result
View File
@ -5,14 +5,10 @@ from extensions.ext_code_based_extension import code_based_extension
class ExternalDataToolFactory: class ExternalDataToolFactory:
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None: def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None:
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
self.__extension_instance = extension_class( self.__extension_instance = extension_class(
tenant_id=tenant_id, tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
app_id=app_id,
variable=variable,
config=config
) )
@classmethod @classmethod
View File
@ -13,11 +13,12 @@ class FileExtraConfig(BaseModel):
""" """
File Upload Entity. File Upload Entity.
""" """
image_config: Optional[dict[str, Any]] = None image_config: Optional[dict[str, Any]] = None
class FileType(enum.Enum): class FileType(enum.Enum):
IMAGE = 'image' IMAGE = "image"
@staticmethod @staticmethod
def value_of(value): def value_of(value):
@ -28,9 +29,9 @@ class FileType(enum.Enum):
class FileTransferMethod(enum.Enum): class FileTransferMethod(enum.Enum):
REMOTE_URL = 'remote_url' REMOTE_URL = "remote_url"
LOCAL_FILE = 'local_file' LOCAL_FILE = "local_file"
TOOL_FILE = 'tool_file' TOOL_FILE = "tool_file"
@staticmethod @staticmethod
def value_of(value): def value_of(value):
@ -39,9 +40,10 @@ class FileTransferMethod(enum.Enum):
return member return member
raise ValueError(f"No matching enum found for value '{value}'") raise ValueError(f"No matching enum found for value '{value}'")
class FileBelongsTo(enum.Enum): class FileBelongsTo(enum.Enum):
USER = 'user' USER = "user"
ASSISTANT = 'assistant' ASSISTANT = "assistant"
@staticmethod @staticmethod
def value_of(value): def value_of(value):
@ -65,16 +67,16 @@ class FileVar(BaseModel):
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
'__variant': self.__class__.__name__, "__variant": self.__class__.__name__,
'tenant_id': self.tenant_id, "tenant_id": self.tenant_id,
'type': self.type.value, "type": self.type.value,
'transfer_method': self.transfer_method.value, "transfer_method": self.transfer_method.value,
'url': self.preview_url, "url": self.preview_url,
'remote_url': self.url, "remote_url": self.url,
'related_id': self.related_id, "related_id": self.related_id,
'filename': self.filename, "filename": self.filename,
'extension': self.extension, "extension": self.extension,
'mime_type': self.mime_type, "mime_type": self.mime_type,
} }
def to_markdown(self) -> str: def to_markdown(self) -> str:
@ -86,7 +88,7 @@ class FileVar(BaseModel):
if self.type == FileType.IMAGE: if self.type == FileType.IMAGE:
text = f'![{self.filename or ""}]({preview_url})' text = f'![{self.filename or ""}]({preview_url})'
else: else:
text = f'[{self.filename or preview_url}]({preview_url})' text = f"[{self.filename or preview_url}]({preview_url})"
return text return text
@ -115,28 +117,29 @@ class FileVar(BaseModel):
return ImagePromptMessageContent( return ImagePromptMessageContent(
data=self.data, data=self.data,
detail=ImagePromptMessageContent.DETAIL.HIGH detail=ImagePromptMessageContent.DETAIL.HIGH
if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW if image_config.get("detail") == "high"
else ImagePromptMessageContent.DETAIL.LOW,
) )
def _get_data(self, force_url: bool = False) -> Optional[str]: def _get_data(self, force_url: bool = False) -> Optional[str]:
from models.model import UploadFile from models.model import UploadFile
if self.type == FileType.IMAGE: if self.type == FileType.IMAGE:
if self.transfer_method == FileTransferMethod.REMOTE_URL: if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.url return self.url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE: elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
upload_file = (db.session.query(UploadFile) upload_file = (
.filter( db.session.query(UploadFile)
UploadFile.id == self.related_id, .filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id)
UploadFile.tenant_id == self.tenant_id .first()
).first())
return UploadFileParser.get_image_data(
upload_file=upload_file,
force_url=force_url
) )
return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url)
elif self.transfer_method == FileTransferMethod.TOOL_FILE: elif self.transfer_method == FileTransferMethod.TOOL_FILE:
extension = self.extension extension = self.extension
# add sign url # add sign url
return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=self.related_id, extension=extension) return ToolFileParser.get_tool_file_manager().sign_file(
tool_file_id=self.related_id, extension=extension
)
return None return None
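
For reference, the dict shape produced by to_dict above, with placeholder values.

file_dict = {
    "__variant": "FileVar",
    "tenant_id": "tenant-1",
    "type": "image",
    "transfer_method": "remote_url",
    "url": "https://files.example.com/preview/cat.png",  # preview url
    "remote_url": "https://example.com/cat.png",
    "related_id": None,
    "filename": "cat.png",
    "extension": "png",
    "mime_type": "image/png",
}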
View File
@ -13,13 +13,13 @@ from services.file_service import IMAGE_EXTENSIONS
class MessageFileParser: class MessageFileParser:
def __init__(self, tenant_id: str, app_id: str) -> None: def __init__(self, tenant_id: str, app_id: str) -> None:
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.app_id = app_id self.app_id = app_id
def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, def validate_and_transform_files_arg(
user: Union[Account, EndUser]) -> list[FileVar]: self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser]
) -> list[FileVar]:
""" """
validate and transform files arg validate and transform files arg
@ -30,22 +30,22 @@ class MessageFileParser:
""" """
for file in files: for file in files:
if not isinstance(file, dict): if not isinstance(file, dict):
raise ValueError('Invalid file format, must be dict') raise ValueError("Invalid file format, must be dict")
if not file.get('type'): if not file.get("type"):
raise ValueError('Missing file type') raise ValueError("Missing file type")
FileType.value_of(file.get('type')) FileType.value_of(file.get("type"))
if not file.get('transfer_method'): if not file.get("transfer_method"):
raise ValueError('Missing file transfer method') raise ValueError("Missing file transfer method")
FileTransferMethod.value_of(file.get('transfer_method')) FileTransferMethod.value_of(file.get("transfer_method"))
if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value: if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value:
if not file.get('url'): if not file.get("url"):
raise ValueError('Missing file url') raise ValueError("Missing file url")
if not file.get('url').startswith('http'): if not file.get("url").startswith("http"):
raise ValueError('Invalid file url') raise ValueError("Invalid file url")
if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'): if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"):
raise ValueError('Missing file upload_file_id') raise ValueError("Missing file upload_file_id")
            if file.get('transfer_method') == FileTransferMethod.TOOL_FILE.value and not file.get('tool_file_id'):             if file.get("transfer_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"):
raise ValueError('Missing file tool_file_id') raise ValueError("Missing file tool_file_id")
# transform files to file objs # transform files to file objs
type_file_objs = self._to_file_objs(files, file_extra_config) type_file_objs = self._to_file_objs(files, file_extra_config)
@ -62,17 +62,17 @@ class MessageFileParser:
continue continue
# Validate number of files # Validate number of files
if len(files) > image_config['number_limits']: if len(files) > image_config["number_limits"]:
raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}") raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")
for file_obj in file_objs: for file_obj in file_objs:
# Validate transfer method # Validate transfer method
if file_obj.transfer_method.value not in image_config['transfer_methods']: if file_obj.transfer_method.value not in image_config["transfer_methods"]:
raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}') raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}")
# Validate file type # Validate file type
if file_obj.type != FileType.IMAGE: if file_obj.type != FileType.IMAGE:
raise ValueError(f'Invalid file type: {file_obj.type}') raise ValueError(f"Invalid file type: {file_obj.type}")
if file_obj.transfer_method == FileTransferMethod.REMOTE_URL: if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
# check remote url valid and is image # check remote url valid and is image
@ -81,18 +81,21 @@ class MessageFileParser:
raise ValueError(error) raise ValueError(error)
elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE: elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
# get upload file from upload_file_id # get upload file from upload_file_id
upload_file = (db.session.query(UploadFile) upload_file = (
.filter( db.session.query(UploadFile)
UploadFile.id == file_obj.related_id, .filter(
UploadFile.tenant_id == self.tenant_id, UploadFile.id == file_obj.related_id,
UploadFile.created_by == user.id, UploadFile.tenant_id == self.tenant_id,
UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), UploadFile.created_by == user.id,
UploadFile.extension.in_(IMAGE_EXTENSIONS) UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
).first()) UploadFile.extension.in_(IMAGE_EXTENSIONS),
)
.first()
)
# check upload file is belong to tenant and user # check upload file is belong to tenant and user
if not upload_file: if not upload_file:
raise ValueError('Invalid upload file') raise ValueError("Invalid upload file")
new_files.append(file_obj) new_files.append(file_obj)
@ -113,8 +116,9 @@ class MessageFileParser:
# return all file objs # return all file objs
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
def _to_file_objs(self, files: list[Union[dict, MessageFile]], def _to_file_objs(
file_extra_config: FileExtraConfig) -> dict[FileType, list[FileVar]]: self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig
) -> dict[FileType, list[FileVar]]:
""" """
transform files to file objs transform files to file objs
@ -152,23 +156,23 @@ class MessageFileParser:
:return: :return:
""" """
if isinstance(file, dict): if isinstance(file, dict):
transfer_method = FileTransferMethod.value_of(file.get('transfer_method')) transfer_method = FileTransferMethod.value_of(file.get("transfer_method"))
if transfer_method != FileTransferMethod.TOOL_FILE: if transfer_method != FileTransferMethod.TOOL_FILE:
return FileVar( return FileVar(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
type=FileType.value_of(file.get('type')), type=FileType.value_of(file.get("type")),
transfer_method=transfer_method, transfer_method=transfer_method,
url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None,
related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None,
extra_config=file_extra_config extra_config=file_extra_config,
) )
return FileVar( return FileVar(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
type=FileType.value_of(file.get('type')), type=FileType.value_of(file.get("type")),
transfer_method=transfer_method, transfer_method=transfer_method,
url=None, url=None,
related_id=file.get('tool_file_id'), related_id=file.get("tool_file_id"),
extra_config=file_extra_config extra_config=file_extra_config,
) )
else: else:
return FileVar( return FileVar(
@ -178,7 +182,7 @@ class MessageFileParser:
transfer_method=FileTransferMethod.value_of(file.transfer_method), transfer_method=FileTransferMethod.value_of(file.transfer_method),
url=file.url, url=file.url,
related_id=file.upload_file_id or None, related_id=file.upload_file_id or None,
extra_config=file_extra_config extra_config=file_extra_config,
) )
def _check_image_remote_url(self, url): def _check_image_remote_url(self, url):
@ -190,17 +194,17 @@ class MessageFileParser:
def is_s3_presigned_url(url): def is_s3_presigned_url(url):
try: try:
parsed_url = urlparse(url) parsed_url = urlparse(url)
if 'amazonaws.com' not in parsed_url.netloc: if "amazonaws.com" not in parsed_url.netloc:
return False return False
query_params = parse_qs(parsed_url.query) query_params = parse_qs(parsed_url.query)
required_params = ['Signature', 'Expires'] required_params = ["Signature", "Expires"]
for param in required_params: for param in required_params:
if param not in query_params: if param not in query_params:
return False return False
if not query_params['Expires'][0].isdigit(): if not query_params["Expires"][0].isdigit():
return False return False
signature = query_params['Signature'][0] signature = query_params["Signature"][0]
if not re.match(r'^[A-Za-z0-9+/]+={0,2}$', signature): if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
return False return False
return True return True
except Exception: except Exception:
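
The hunk above ends inside is_s3_presigned_url; a self-contained sketch of the same checks against a hypothetical URL.

from urllib.parse import parse_qs, urlparse

url = "https://bucket.s3.amazonaws.com/a.png?Signature=dGVzdA%3D%3D&Expires=1714000000"
parsed_url = urlparse(url)
query_params = parse_qs(parsed_url.query)
looks_presigned = (
    "amazonaws.com" in parsed_url.netloc
    and all(p in query_params for p in ("Signature", "Expires"))
    and query_params["Expires"][0].isdigit()
)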
View File
@ -1,8 +1,7 @@
tool_file_manager = { tool_file_manager = {"manager": None}
'manager': None
}
class ToolFileParser: class ToolFileParser:
@staticmethod @staticmethod
def get_tool_file_manager() -> 'ToolFileManager': def get_tool_file_manager() -> "ToolFileManager":
return tool_file_manager['manager'] return tool_file_manager["manager"]
View File
@ -9,7 +9,7 @@ from typing import Optional
from configs import dify_config from configs import dify_config
from extensions.ext_storage import storage from extensions.ext_storage import storage
IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
@ -22,18 +22,18 @@ class UploadFileParser:
if upload_file.extension not in IMAGE_EXTENSIONS: if upload_file.extension not in IMAGE_EXTENSIONS:
return None return None
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == 'url' or force_url: if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url:
return cls.get_signed_temp_image_url(upload_file.id) return cls.get_signed_temp_image_url(upload_file.id)
else: else:
# get image file base64 # get image file base64
try: try:
data = storage.load(upload_file.key) data = storage.load(upload_file.key)
except FileNotFoundError: except FileNotFoundError:
logging.error(f'File not found: {upload_file.key}') logging.error(f"File not found: {upload_file.key}")
return None return None
encoded_string = base64.b64encode(data).decode('utf-8') encoded_string = base64.b64encode(data).decode("utf-8")
return f'data:{upload_file.mime_type};base64,{encoded_string}' return f"data:{upload_file.mime_type};base64,{encoded_string}"
@classmethod @classmethod
def get_signed_temp_image_url(cls, upload_file_id) -> str: def get_signed_temp_image_url(cls, upload_file_id) -> str:
@ -44,7 +44,7 @@ class UploadFileParser:
:return: :return:
""" """
base_url = dify_config.FILES_URL base_url = dify_config.FILES_URL
image_preview_url = f'{base_url}/files/{upload_file_id}/image-preview' image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
timestamp = str(int(time.time())) timestamp = str(int(time.time()))
nonce = os.urandom(16).hex() nonce = os.urandom(16).hex()
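
A short sketch of the base64 branch above; the bytes and mime type are placeholders.

import base64

data = b"fake-image-bytes"  # stands in for storage.load(upload_file.key)
encoded_string = base64.b64encode(data).decode("utf-8")
data_url = f"data:image/png;base64,{encoded_string}"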
View File
@ -15,9 +15,11 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CodeExecutionException(Exception): class CodeExecutionException(Exception):
pass pass
class CodeExecutionResponse(BaseModel): class CodeExecutionResponse(BaseModel):
class Data(BaseModel): class Data(BaseModel):
stdout: Optional[str] = None stdout: Optional[str] = None
@ -29,9 +31,9 @@ class CodeExecutionResponse(BaseModel):
class CodeLanguage(str, Enum): class CodeLanguage(str, Enum):
PYTHON3 = 'python3' PYTHON3 = "python3"
JINJA2 = 'jinja2' JINJA2 = "jinja2"
JAVASCRIPT = 'javascript' JAVASCRIPT = "javascript"
class CodeExecutor: class CodeExecutor:
@ -45,63 +47,65 @@ class CodeExecutor:
} }
code_language_to_running_language = { code_language_to_running_language = {
CodeLanguage.JAVASCRIPT: 'nodejs', CodeLanguage.JAVASCRIPT: "nodejs",
CodeLanguage.JINJA2: CodeLanguage.PYTHON3, CodeLanguage.JINJA2: CodeLanguage.PYTHON3,
CodeLanguage.PYTHON3: CodeLanguage.PYTHON3, CodeLanguage.PYTHON3: CodeLanguage.PYTHON3,
} }
supported_dependencies_languages: set[CodeLanguage] = { supported_dependencies_languages: set[CodeLanguage] = {CodeLanguage.PYTHON3}
CodeLanguage.PYTHON3
}
@classmethod @classmethod
def execute_code(cls, def execute_code(cls, language: CodeLanguage, preload: str, code: str) -> str:
language: CodeLanguage,
preload: str,
code: str) -> str:
""" """
Execute code Execute code
:param language: code language :param language: code language
:param code: code :param code: code
:return: :return:
""" """
url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / 'v1' / 'sandbox' / 'run' url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run"
headers = { headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY}
'X-Api-Key': dify_config.CODE_EXECUTION_API_KEY
}
data = { data = {
'language': cls.code_language_to_running_language.get(language), "language": cls.code_language_to_running_language.get(language),
'code': code, "code": code,
'preload': preload, "preload": preload,
'enable_network': True "enable_network": True,
} }
try: try:
response = post(str(url), json=data, headers=headers, response = post(
timeout=Timeout( str(url),
connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT, json=data,
read=dify_config.CODE_EXECUTION_READ_TIMEOUT, headers=headers,
write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT, timeout=Timeout(
pool=None)) connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT,
read=dify_config.CODE_EXECUTION_READ_TIMEOUT,
write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT,
pool=None,
),
)
if response.status_code == 503: if response.status_code == 503:
raise CodeExecutionException('Code execution service is unavailable') raise CodeExecutionException("Code execution service is unavailable")
elif response.status_code != 200: elif response.status_code != 200:
raise Exception(f'Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running') raise Exception(
f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running"
)
except CodeExecutionException as e: except CodeExecutionException as e:
raise e raise e
except Exception as e: except Exception as e:
raise CodeExecutionException('Failed to execute code, which is likely a network issue,' raise CodeExecutionException(
' please check if the sandbox service is running.' "Failed to execute code, which is likely a network issue,"
f' ( Error: {str(e)} )') " please check if the sandbox service is running."
f" ( Error: {str(e)} )"
)
try: try:
response = response.json() response = response.json()
except: except:
raise CodeExecutionException('Failed to parse response') raise CodeExecutionException("Failed to parse response")
if (code := response.get('code')) != 0: if (code := response.get("code")) != 0:
raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}") raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}")
response = CodeExecutionResponse(**response) response = CodeExecutionResponse(**response)
@ -109,7 +113,7 @@ class CodeExecutor:
if response.data.error: if response.data.error:
raise CodeExecutionException(response.data.error) raise CodeExecutionException(response.data.error)
return response.data.stdout or '' return response.data.stdout or ""
@classmethod @classmethod
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict) -> dict: def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict) -> dict:
@ -122,7 +126,7 @@ class CodeExecutor:
""" """
template_transformer = cls.code_template_transformers.get(language) template_transformer = cls.code_template_transformers.get(language)
if not template_transformer: if not template_transformer:
raise CodeExecutionException(f'Unsupported language {language}') raise CodeExecutionException(f"Unsupported language {language}")
runner, preload = template_transformer.transform_caller(code, inputs) runner, preload = template_transformer.transform_caller(code, inputs)
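
An illustrative call into the executor above, assuming a sandbox is reachable via CODE_EXECUTION_ENDPOINT; the snippet mirrors the default Python template later in this diff.

code = "def main(arg1: int, arg2: int) -> dict:\n    return {'result': arg1 + arg2}"
result = CodeExecutor.execute_workflow_code_template(
    language=CodeLanguage.PYTHON3,
    code=code,
    inputs={"arg1": 1, "arg2": 2},
)
# expected: result == {"result": 3}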
View File
@ -26,23 +26,9 @@ class CodeNodeProvider(BaseModel):
return { return {
"type": "code", "type": "code",
"config": { "config": {
"variables": [ "variables": [{"variable": "arg1", "value_selector": []}, {"variable": "arg2", "value_selector": []}],
{
"variable": "arg1",
"value_selector": []
},
{
"variable": "arg2",
"value_selector": []
}
],
"code_language": cls.get_language(), "code_language": cls.get_language(),
"code": cls.get_default_code(), "code": cls.get_default_code(),
"outputs": { "outputs": {"result": {"type": "string", "children": None}},
"result": { },
"type": "string",
"children": None
}
}
}
View File

@ -18,4 +18,5 @@ class JavascriptCodeProvider(CodeNodeProvider):
result: arg1 + arg2 result: arg1 + arg2
} }
} }
""") """
)
View File
@ -21,5 +21,6 @@ class NodeJsTemplateTransformer(TemplateTransformer):
var output_json = JSON.stringify(output_obj) var output_json = JSON.stringify(output_obj)
var result = `<<RESULT>>${{output_json}}<<RESULT>>` var result = `<<RESULT>>${{output_json}}<<RESULT>>`
console.log(result) console.log(result)
""") """
)
return runner_script return runner_script
View File
@ -10,8 +10,6 @@ class Jinja2Formatter:
:param inputs: inputs :param inputs: inputs
:return: :return:
""" """
result = CodeExecutor.execute_workflow_code_template( result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs)
language=CodeLanguage.JINJA2, code=template, inputs=inputs
)
return result['result'] return result["result"]
View File
@ -11,9 +11,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
:param response: response :param response: response
:return: :return:
""" """
return { return {"result": cls.extract_result_str_from_response(response)}
'result': cls.extract_result_str_from_response(response)
}
@classmethod @classmethod
def get_runner_script(cls) -> str: def get_runner_script(cls) -> str:
View File
@ -17,4 +17,5 @@ class Python3CodeProvider(CodeNodeProvider):
return { return {
"result": arg1 + arg2, "result": arg1 + arg2,
} }
""") """
Some files were not shown because too many files have changed in this diff.