mirror of https://github.com/langgenius/dify.git
feat: stable diffusion 3 (#3599)
This commit is contained in:
parent
aa6d2e3035
commit
d9f1a8ce9f
|
@ -4,6 +4,7 @@
|
|||
- searxng
|
||||
- dalle
|
||||
- azuredalle
|
||||
- stability
|
||||
- wikipedia
|
||||
- model.openai
|
||||
- model.google
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" width="40" height="40" viewBox="0 0 40 40" fill="none">
|
||||
<path d="M12.0377 35C19.1243 35 23.7343 31.3 23.7343 25.7333C23.7343 21.4167 20.931 18.6733 15.9177 17.5367L12.701 16.585C9.87768 15.96 8.22935 15.21 8.61768 13.2933C8.94102 11.6983 9.90602 10.7983 12.1543 10.7983C19.296 10.7983 21.9427 13.2933 21.9427 13.2933V7.29333C21.9427 7.29333 19.366 5 12.1543 5C5.35435 5 1.66602 8.45 1.66602 13.7883C1.66602 18.105 4.22268 20.6167 9.40768 21.8083L9.96435 21.9467C10.7527 22.1867 11.8177 22.505 13.1577 22.9C15.8077 23.525 16.4893 24.1883 16.4893 26.1767C16.4893 27.9933 14.5727 29.0267 12.0393 29.0267C4.73435 29.0267 1.66602 25.385 1.66602 25.385V32.0333C1.66602 32.0333 3.58602 35 12.0377 35Z" fill="url(#paint0_linear_17756_15767)"/>
|
||||
<path d="M33.9561 34.55C36.4645 34.55 38.3328 32.7617 38.3328 30.34C38.3328 27.8667 36.5178 26.13 33.9561 26.13C31.4478 26.13 29.6328 27.8667 29.6328 30.34C29.6328 32.8133 31.4478 34.55 33.9561 34.55Z" fill="#E80000"/>
|
||||
<defs>
|
||||
<linearGradient id="paint0_linear_17756_15767" x1="1105.08" y1="5" x2="1105.08" y2="3005" gradientUnits="userSpaceOnUse">
|
||||
<stop stop-color="#9D39FF"/>
|
||||
<stop offset="1" stop-color="#A380FF"/>
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
After Width: | Height: | Size: 1.2 KiB |
|
@ -0,0 +1,15 @@
|
|||
from typing import Any
|
||||
|
||||
from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthorization):
|
||||
"""
|
||||
This class is responsible for providing the stability tool.
|
||||
"""
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
This method is responsible for validating the credentials.
|
||||
"""
|
||||
self.sd_validate_credentials(credentials)
|
|
@ -0,0 +1,29 @@
|
|||
identity:
|
||||
author: Dify
|
||||
name: stability
|
||||
label:
|
||||
en_US: Stability
|
||||
zh_Hans: Stability
|
||||
pt_BR: Stability
|
||||
description:
|
||||
en_US: Activating humanity's potential through generative AI
|
||||
zh_Hans: 通过生成式 AI 激活人类的潜力
|
||||
pt_BR: Activating humanity's potential through generative AI
|
||||
icon: icon.svg
|
||||
credentials_for_provider:
|
||||
api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: API key
|
||||
zh_Hans: API key
|
||||
pt_BR: API key
|
||||
placeholder:
|
||||
en_US: Please input your API key
|
||||
zh_Hans: 请输入你的 API key
|
||||
pt_BR: Please input your API key
|
||||
help:
|
||||
en_US: Get your API key from Stability
|
||||
zh_Hans: 从 Stability 获取你的 API key
|
||||
pt_BR: Get your API key from Stability
|
||||
url: https://platform.stability.ai/account/keys
|
|
@ -0,0 +1,34 @@
|
|||
import requests
|
||||
from yarl import URL
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class BaseStabilityAuthorization:
|
||||
def sd_validate_credentials(self, credentials: dict):
|
||||
"""
|
||||
This method is responsible for validating the credentials.
|
||||
"""
|
||||
api_key = credentials.get('api_key', '')
|
||||
if not api_key:
|
||||
raise ToolProviderCredentialValidationError('API key is required.')
|
||||
|
||||
response = requests.get(
|
||||
URL('https://api.stability.ai') / 'v1' / 'user' / 'account',
|
||||
headers=self.generate_authorization_headers(credentials),
|
||||
timeout=(5, 30)
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
raise ToolProviderCredentialValidationError('Invalid API key.')
|
||||
|
||||
return True
|
||||
|
||||
def generate_authorization_headers(self, credentials: dict) -> dict[str, str]:
|
||||
"""
|
||||
This method is responsible for generating the authorization headers.
|
||||
"""
|
||||
return {
|
||||
'Authorization': f'Bearer {credentials.get("api_key", "")}'
|
||||
}
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
from typing import Any
|
||||
|
||||
from httpx import post
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
|
||||
"""
|
||||
This class is responsible for providing the stable diffusion tool.
|
||||
"""
|
||||
model_endpoint_map = {
|
||||
'sd3': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
|
||||
'sd3-turbo': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
|
||||
'core': 'https://api.stability.ai/v2beta/stable-image/generate/core',
|
||||
}
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
Invoke the tool.
|
||||
"""
|
||||
payload = {
|
||||
'prompt': tool_parameters.get('prompt', ''),
|
||||
'aspect_radio': tool_parameters.get('aspect_radio', '16:9'),
|
||||
'mode': 'text-to-image',
|
||||
'seed': tool_parameters.get('seed', 0),
|
||||
'output_format': 'png',
|
||||
}
|
||||
|
||||
model = tool_parameters.get('model', 'core')
|
||||
|
||||
if model in ['sd3', 'sd3-turbo']:
|
||||
payload['model'] = tool_parameters.get('model')
|
||||
|
||||
if not model == 'sd3-turbo':
|
||||
payload['negative_prompt'] = tool_parameters.get('negative_prompt', '')
|
||||
|
||||
response = post(
|
||||
self.model_endpoint_map[tool_parameters.get('model', 'core')],
|
||||
headers={
|
||||
'accept': 'image/*',
|
||||
**self.generate_authorization_headers(self.runtime.credentials),
|
||||
},
|
||||
files={
|
||||
key: (None, str(value)) for key, value in payload.items()
|
||||
},
|
||||
timeout=(5, 30)
|
||||
)
|
||||
|
||||
if not response.status_code == 200:
|
||||
raise Exception(response.text)
|
||||
|
||||
return self.create_blob_message(
|
||||
blob=response.content, meta={
|
||||
'mime_type': 'image/png'
|
||||
},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value
|
||||
)
|
|
@ -0,0 +1,142 @@
|
|||
identity:
|
||||
name: stability_text2image
|
||||
author: Dify
|
||||
label:
|
||||
en_US: StableDiffusion
|
||||
zh_Hans: 稳定扩散
|
||||
pt_BR: StableDiffusion
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for generate images based on the text input
|
||||
zh_Hans: 一个基于文本输入生成图像的工具
|
||||
pt_BR: A tool for generate images based on the text input
|
||||
llm: A tool for generate images based on the text input
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
pt_BR: Prompt
|
||||
human_description:
|
||||
en_US: used for generating images
|
||||
zh_Hans: 用于生成图像
|
||||
pt_BR: used for generating images
|
||||
llm_description: key words for generating images
|
||||
form: llm
|
||||
- name: model
|
||||
type: select
|
||||
default: sd3-turbo
|
||||
required: true
|
||||
label:
|
||||
en_US: Model
|
||||
zh_Hans: 模型
|
||||
pt_BR: Model
|
||||
options:
|
||||
- value: core
|
||||
label:
|
||||
en_US: Core
|
||||
zh_Hans: Core
|
||||
pt_BR: Core
|
||||
- value: sd3
|
||||
label:
|
||||
en_US: Stable Diffusion 3
|
||||
zh_Hans: Stable Diffusion 3
|
||||
pt_BR: Stable Diffusion 3
|
||||
- value: sd3-turbo
|
||||
label:
|
||||
en_US: Stable Diffusion 3 Turbo
|
||||
zh_Hans: Stable Diffusion 3 Turbo
|
||||
pt_BR: Stable Diffusion 3 Turbo
|
||||
human_description:
|
||||
en_US: Model for generating images
|
||||
zh_Hans: 用于生成图像的模型
|
||||
pt_BR: Model for generating images
|
||||
llm_description: Model for generating images
|
||||
form: form
|
||||
- name: negative_prompt
|
||||
type: string
|
||||
default: bad art, ugly, deformed, watermark, duplicated, discontinuous lines
|
||||
required: false
|
||||
label:
|
||||
en_US: Negative Prompt
|
||||
zh_Hans: 负面提示
|
||||
pt_BR: Negative Prompt
|
||||
human_description:
|
||||
en_US: Negative Prompt
|
||||
zh_Hans: 负面提示
|
||||
pt_BR: Negative Prompt
|
||||
llm_description: Negative Prompt
|
||||
form: form
|
||||
- name: seeds
|
||||
type: number
|
||||
default: 0
|
||||
required: false
|
||||
label:
|
||||
en_US: Seeds
|
||||
zh_Hans: 种子
|
||||
pt_BR: Seeds
|
||||
human_description:
|
||||
en_US: Seeds
|
||||
zh_Hans: 种子
|
||||
pt_BR: Seeds
|
||||
llm_description: Seeds
|
||||
min: 0
|
||||
max: 4294967294
|
||||
form: form
|
||||
- name: aspect_radio
|
||||
type: select
|
||||
default: '16:9'
|
||||
options:
|
||||
- value: '16:9'
|
||||
label:
|
||||
en_US: '16:9'
|
||||
zh_Hans: '16:9'
|
||||
pt_BR: '16:9'
|
||||
- value: '1:1'
|
||||
label:
|
||||
en_US: '1:1'
|
||||
zh_Hans: '1:1'
|
||||
pt_BR: '1:1'
|
||||
- value: '21:9'
|
||||
label:
|
||||
en_US: '21:9'
|
||||
zh_Hans: '21:9'
|
||||
pt_BR: '21:9'
|
||||
- value: '2:3'
|
||||
label:
|
||||
en_US: '2:3'
|
||||
zh_Hans: '2:3'
|
||||
pt_BR: '2:3'
|
||||
- value: '4:5'
|
||||
label:
|
||||
en_US: '4:5'
|
||||
zh_Hans: '4:5'
|
||||
pt_BR: '4:5'
|
||||
- value: '5:4'
|
||||
label:
|
||||
en_US: '5:4'
|
||||
zh_Hans: '5:4'
|
||||
pt_BR: '5:4'
|
||||
- value: '9:16'
|
||||
label:
|
||||
en_US: '9:16'
|
||||
zh_Hans: '9:16'
|
||||
pt_BR: '9:16'
|
||||
- value: '9:21'
|
||||
label:
|
||||
en_US: '9:21'
|
||||
zh_Hans: '9:21'
|
||||
pt_BR: '9:21'
|
||||
required: false
|
||||
label:
|
||||
en_US: Aspect Radio
|
||||
zh_Hans: 长宽比
|
||||
pt_BR: Aspect Radio
|
||||
human_description:
|
||||
en_US: Aspect Radio
|
||||
zh_Hans: 长宽比
|
||||
pt_BR: Aspect Radio
|
||||
llm_description: Aspect Radio
|
||||
form: form
|
Loading…
Reference in New Issue