From d9f1a8ce9fba0c6d6dd48b24c60014e1411dca96 Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Thu, 18 Apr 2024 16:54:37 +0800
Subject: [PATCH] feat: stable diffusion 3 (#3599)

---
 api/core/tools/provider/_position.yaml        |   1 +
 .../builtin/stability/_assets/icon.svg        |  10 ++
 .../provider/builtin/stability/stability.py   |  19 +++
 .../provider/builtin/stability/stability.yaml |  29 ++++
 .../provider/builtin/stability/tools/base.py  |  45 ++++++
 .../builtin/stability/tools/text2image.py     |  71 +++++++++
 .../builtin/stability/tools/text2image.yaml   | 142 ++++++++++++++++++
 7 files changed, 317 insertions(+)
 create mode 100644 api/core/tools/provider/builtin/stability/_assets/icon.svg
 create mode 100644 api/core/tools/provider/builtin/stability/stability.py
 create mode 100644 api/core/tools/provider/builtin/stability/stability.yaml
 create mode 100644 api/core/tools/provider/builtin/stability/tools/base.py
 create mode 100644 api/core/tools/provider/builtin/stability/tools/text2image.py
 create mode 100644 api/core/tools/provider/builtin/stability/tools/text2image.yaml

diff --git a/api/core/tools/provider/_position.yaml b/api/core/tools/provider/_position.yaml
index 778626f1cc..5e6e8dcb7a 100644
--- a/api/core/tools/provider/_position.yaml
+++ b/api/core/tools/provider/_position.yaml
@@ -4,6 +4,7 @@
 - searxng
 - dalle
 - azuredalle
+- stability
 - wikipedia
 - model.openai
 - model.google
diff --git a/api/core/tools/provider/builtin/stability/_assets/icon.svg b/api/core/tools/provider/builtin/stability/_assets/icon.svg
new file mode 100644
index 0000000000..56357a3555
--- /dev/null
+++ b/api/core/tools/provider/builtin/stability/_assets/icon.svg
@@ -0,0 +1,10 @@
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/stability/stability.py b/api/core/tools/provider/builtin/stability/stability.py
new file mode 100644
index 0000000000..d00c3ecf00
--- /dev/null
+++ b/api/core/tools/provider/builtin/stability/stability.py
@@ -0,0 +1,19 @@
+from typing import Any
+
+from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthorization):
+    """
+    Builtin tool provider for Stability AI.
+    """
+
+    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+        """
+        Validate the provider credentials by delegating to sd_validate_credentials.
+
+        :param credentials: provider credentials; the 'api_key' entry is checked
+        :raises ToolProviderCredentialValidationError: if the key is missing or rejected
+        """
+        self.sd_validate_credentials(credentials)
diff --git a/api/core/tools/provider/builtin/stability/stability.yaml b/api/core/tools/provider/builtin/stability/stability.yaml
new file mode 100644
index 0000000000..d8369a4c03
--- /dev/null
+++ b/api/core/tools/provider/builtin/stability/stability.yaml
@@ -0,0 +1,29 @@
+identity:
+  author: Dify
+  name: stability
+  label:
+    en_US: Stability
+    zh_Hans: Stability
+    pt_BR: Stability
+  description:
+    en_US: Activating humanity's potential through generative AI
+    zh_Hans: 通过生成式 AI 激活人类的潜力
+    pt_BR: Activating humanity's potential through generative AI
+  icon: icon.svg
+credentials_for_provider:
+  api_key:
+    type: secret-input
+    required: true
+    label:
+      en_US: API key
+      zh_Hans: API key
+      pt_BR: API key
+    placeholder:
+      en_US: Please input your API key
+      zh_Hans: 请输入你的 API key
+      pt_BR: Please input your API key
+    help:
+      en_US: Get your API key from Stability
+      zh_Hans: 从 Stability 获取你的 API key
+      pt_BR: Get your API key from Stability
+    url: https://platform.stability.ai/account/keys
diff --git a/api/core/tools/provider/builtin/stability/tools/base.py b/api/core/tools/provider/builtin/stability/tools/base.py
new file mode 100644
index 0000000000..a4788fd869
--- /dev/null
+++ b/api/core/tools/provider/builtin/stability/tools/base.py
@@ -0,0 +1,45 @@
+import requests
+from yarl import URL
+
+from core.tools.errors import ToolProviderCredentialValidationError
+
+
+class BaseStabilityAuthorization:
+    """
+    Mixin with the Stability API authorization helpers shared by the provider and its tools.
+    """
+
+    def sd_validate_credentials(self, credentials: dict) -> bool:
+        """
+        Validate the API key by requesting the account endpoint of the Stability API.
+
+        :param credentials: provider credentials, expected to contain 'api_key'
+        :return: True when the API accepts the key
+        :raises ToolProviderCredentialValidationError: if the key is empty or the request is not successful
+        """
+        api_key = credentials.get('api_key', '')
+        if not api_key:
+            raise ToolProviderCredentialValidationError('API key is required.')
+
+        response = requests.get(
+            # requests works on string URLs, so convert the yarl URL explicitly
+            str(URL('https://api.stability.ai') / 'v1' / 'user' / 'account'),
+            headers=self.generate_authorization_headers(credentials),
+            timeout=(5, 30)
+        )
+
+        if not response.ok:
+            raise ToolProviderCredentialValidationError('Invalid API key.')
+
+        return True
+
+    def generate_authorization_headers(self, credentials: dict) -> dict[str, str]:
+        """
+        Build the bearer-token authorization headers for the Stability API.
+
+        :param credentials: provider credentials, expected to contain 'api_key'
+        :return: headers dict with the 'Authorization' entry
+        """
+        return {
+            'Authorization': f'Bearer {credentials.get("api_key", "")}'
+        }
diff --git a/api/core/tools/provider/builtin/stability/tools/text2image.py b/api/core/tools/provider/builtin/stability/tools/text2image.py
new file mode 100644
index 0000000000..10f6b62110
--- /dev/null
+++ b/api/core/tools/provider/builtin/stability/tools/text2image.py
@@ -0,0 +1,71 @@
+from typing import Any
+
+from httpx import post
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
+    """
+    Text-to-image tool backed by the Stability image generation API.
+    """
+    # generation endpoint for each selectable model; both sd3 variants share the sd3 endpoint
+    model_endpoint_map = {
+        'sd3': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
+        'sd3-turbo': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
+        'core': 'https://api.stability.ai/v2beta/stable-image/generate/core',
+    }
+
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
+        """
+        Generate an image from the text prompt through the Stability API.
+
+        :param user_id: id of the invoking user (not used by this tool)
+        :param tool_parameters: values of the parameters declared in text2image.yaml
+        :return: blob message carrying the generated PNG image
+        :raises Exception: with the response body when the API does not answer with status 200
+        """
+        model = tool_parameters.get('model', 'core')
+
+        payload = {
+            'prompt': tool_parameters.get('prompt', ''),
+            # the API field is 'aspect_ratio' although the tool parameter is declared as 'aspect_radio'
+            'aspect_ratio': tool_parameters.get('aspect_radio') or tool_parameters.get('aspect_ratio') or '16:9',
+            'mode': 'text-to-image',
+            # the tool parameter is declared as 'seeds'; 'seed' is still accepted as a fallback
+            'seed': tool_parameters.get('seeds') or tool_parameters.get('seed') or 0,
+            'output_format': 'png',
+        }
+
+        # the model name is only sent to the sd3 endpoint
+        if model in ['sd3', 'sd3-turbo']:
+            payload['model'] = model
+
+        # a negative prompt is not sent for sd3-turbo
+        if model != 'sd3-turbo':
+            payload['negative_prompt'] = tool_parameters.get('negative_prompt', '')
+
+        response = post(
+            self.model_endpoint_map[model],
+            headers={
+                'accept': 'image/*',
+                **self.generate_authorization_headers(self.runtime.credentials),
+            },
+            # (None, value) tuples make httpx send plain multipart/form-data fields
+            files={
+                key: (None, str(value)) for key, value in payload.items()
+            },
+            timeout=(5, 30)
+        )
+
+        if response.status_code != 200:
+            raise Exception(response.text)
+
+        return self.create_blob_message(
+            blob=response.content, meta={
+                'mime_type': 'image/png'
+            },
+            save_as=self.VARIABLE_KEY.IMAGE.value
+        )
diff --git a/api/core/tools/provider/builtin/stability/tools/text2image.yaml b/api/core/tools/provider/builtin/stability/tools/text2image.yaml
new file mode 100644
index 0000000000..51da193a03
--- /dev/null
+++ b/api/core/tools/provider/builtin/stability/tools/text2image.yaml
@@ -0,0 +1,142 @@
+identity:
+  name: stability_text2image
+  author: Dify
+  label:
+    en_US: StableDiffusion
+    zh_Hans: 稳定扩散
+    pt_BR: StableDiffusion
+description:
+  human:
+    en_US: A tool for generating images based on the text input
+    zh_Hans: 一个基于文本输入生成图像的工具
+    pt_BR: A tool for generating images based on the text input
+  llm: A tool for generating images based on the text input
+parameters:
+  - name: prompt
+    type: string
+    required: true
+    label:
+      en_US: Prompt
+      zh_Hans: 提示词
+      pt_BR: Prompt
+    human_description:
+      en_US: used for generating images
+      zh_Hans: 用于生成图像
+      pt_BR: used for generating images
+    llm_description: key words for generating images
+    form: llm
+  - name: model
+    type: select
+    default: sd3-turbo
+    required: true
+    label:
+      en_US: Model
+      zh_Hans: 模型
+      pt_BR: Model
+    options:
+      - value: core
+        label:
+          en_US: Core
+          zh_Hans: Core
+          pt_BR: Core
+      - value: sd3
+        label:
+          en_US: Stable Diffusion 3
+          zh_Hans: Stable Diffusion 3
+          pt_BR: Stable Diffusion 3
+      - value: sd3-turbo
+        label:
+          en_US: Stable Diffusion 3 Turbo
+          zh_Hans: Stable Diffusion 3 Turbo
+          pt_BR: Stable Diffusion 3 Turbo
+    human_description:
+      en_US: Model for generating images
+      zh_Hans: 用于生成图像的模型
+      pt_BR: Model for generating images
+    llm_description: Model for generating images
+    form: form
+  - name: negative_prompt
+    type: string
+    default: bad art, ugly, deformed, watermark, duplicated, discontinuous lines
+    required: false
+    label:
+      en_US: Negative Prompt
+      zh_Hans: 负面提示
+      pt_BR: Negative Prompt
+    human_description:
+      en_US: Negative Prompt
+      zh_Hans: 负面提示
+      pt_BR: Negative Prompt
+    llm_description: Negative Prompt
+    form: form
+  - name: seeds
+    type: number
+    default: 0
+    required: false
+    label:
+      en_US: Seeds
+      zh_Hans: 种子
+      pt_BR: Seeds
+    human_description:
+      en_US: Seeds
+      zh_Hans: 种子
+      pt_BR: Seeds
+    llm_description: Seeds
+    min: 0
+    max: 4294967294
+    form: form
+  - name: aspect_radio
+    type: select
+    default: '16:9'
+    options:
+      - value: '16:9'
+        label:
+          en_US: '16:9'
+          zh_Hans: '16:9'
+          pt_BR: '16:9'
+      - value: '1:1'
+        label:
+          en_US: '1:1'
+          zh_Hans: '1:1'
+          pt_BR: '1:1'
+      - value: '21:9'
+        label:
+          en_US: '21:9'
+          zh_Hans: '21:9'
+          pt_BR: '21:9'
+      - value: '2:3'
+        label:
+          en_US: '2:3'
+          zh_Hans: '2:3'
+          pt_BR: '2:3'
+      - value: '4:5'
+        label:
+          en_US: '4:5'
+          zh_Hans: '4:5'
+          pt_BR: '4:5'
+      - value: '5:4'
+        label:
+          en_US: '5:4'
+          zh_Hans: '5:4'
+          pt_BR: '5:4'
+      - value: '9:16'
+        label:
+          en_US: '9:16'
+          zh_Hans: '9:16'
+          pt_BR: '9:16'
+      - value: '9:21'
+        label:
+          en_US: '9:21'
+          zh_Hans: '9:21'
+          pt_BR: '9:21'
+    required: false
+    label:
+      en_US: Aspect Ratio
+      zh_Hans: 长宽比
+      pt_BR: Aspect Ratio
+    human_description:
+      en_US: Aspect Ratio
+      zh_Hans: 长宽比
+      pt_BR: Aspect Ratio
+    llm_description: Aspect Ratio
+    form: form