feat: stable diffusion 3 (#3599)

This commit is contained in:
Yeuoly 2024-04-18 16:54:37 +08:00 committed by GitHub
parent aa6d2e3035
commit d9f1a8ce9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 291 additions and 0 deletions

View File

@ -4,6 +4,7 @@
- searxng
- dalle
- azuredalle
- stability
- wikipedia
- model.openai
- model.google

View File

@ -0,0 +1,10 @@
<svg xmlns="http://www.w3.org/2000/svg" width="40" height="40" viewBox="0 0 40 40" fill="none">
<path d="M12.0377 35C19.1243 35 23.7343 31.3 23.7343 25.7333C23.7343 21.4167 20.931 18.6733 15.9177 17.5367L12.701 16.585C9.87768 15.96 8.22935 15.21 8.61768 13.2933C8.94102 11.6983 9.90602 10.7983 12.1543 10.7983C19.296 10.7983 21.9427 13.2933 21.9427 13.2933V7.29333C21.9427 7.29333 19.366 5 12.1543 5C5.35435 5 1.66602 8.45 1.66602 13.7883C1.66602 18.105 4.22268 20.6167 9.40768 21.8083L9.96435 21.9467C10.7527 22.1867 11.8177 22.505 13.1577 22.9C15.8077 23.525 16.4893 24.1883 16.4893 26.1767C16.4893 27.9933 14.5727 29.0267 12.0393 29.0267C4.73435 29.0267 1.66602 25.385 1.66602 25.385V32.0333C1.66602 32.0333 3.58602 35 12.0377 35Z" fill="url(#paint0_linear_17756_15767)"/>
<path d="M33.9561 34.55C36.4645 34.55 38.3328 32.7617 38.3328 30.34C38.3328 27.8667 36.5178 26.13 33.9561 26.13C31.4478 26.13 29.6328 27.8667 29.6328 30.34C29.6328 32.8133 31.4478 34.55 33.9561 34.55Z" fill="#E80000"/>
<defs>
<linearGradient id="paint0_linear_17756_15767" x1="1105.08" y1="5" x2="1105.08" y2="3005" gradientUnits="userSpaceOnUse">
<stop stop-color="#9D39FF"/>
<stop offset="1" stop-color="#A380FF"/>
</linearGradient>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 1.2 KiB

View File

@ -0,0 +1,15 @@
from typing import Any
from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthorization):
"""
This class is responsible for providing the stability tool.
"""
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
"""
This method is responsible for validating the credentials.
"""
self.sd_validate_credentials(credentials)

View File

@ -0,0 +1,29 @@
identity:
author: Dify
name: stability
label:
en_US: Stability
zh_Hans: Stability
pt_BR: Stability
description:
en_US: Activating humanity's potential through generative AI
zh_Hans: 通过生成式 AI 激活人类的潜力
pt_BR: Activating humanity's potential through generative AI
icon: icon.svg
credentials_for_provider:
api_key:
type: secret-input
required: true
label:
en_US: API key
zh_Hans: API key
pt_BR: API key
placeholder:
en_US: Please input your API key
zh_Hans: 请输入你的 API key
pt_BR: Please input your API key
help:
en_US: Get your API key from Stability
zh_Hans: 从 Stability 获取你的 API key
pt_BR: Get your API key from Stability
url: https://platform.stability.ai/account/keys

View File

@ -0,0 +1,34 @@
import requests
from yarl import URL
from core.tools.errors import ToolProviderCredentialValidationError
class BaseStabilityAuthorization:
def sd_validate_credentials(self, credentials: dict):
"""
This method is responsible for validating the credentials.
"""
api_key = credentials.get('api_key', '')
if not api_key:
raise ToolProviderCredentialValidationError('API key is required.')
response = requests.get(
URL('https://api.stability.ai') / 'v1' / 'user' / 'account',
headers=self.generate_authorization_headers(credentials),
timeout=(5, 30)
)
if not response.ok:
raise ToolProviderCredentialValidationError('Invalid API key.')
return True
def generate_authorization_headers(self, credentials: dict) -> dict[str, str]:
"""
This method is responsible for generating the authorization headers.
"""
return {
'Authorization': f'Bearer {credentials.get("api_key", "")}'
}

View File

@ -0,0 +1,60 @@
from typing import Any
from httpx import post
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization
from core.tools.tool.builtin_tool import BuiltinTool
class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
"""
This class is responsible for providing the stable diffusion tool.
"""
model_endpoint_map = {
'sd3': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
'sd3-turbo': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
'core': 'https://api.stability.ai/v2beta/stable-image/generate/core',
}
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
"""
Invoke the tool.
"""
payload = {
'prompt': tool_parameters.get('prompt', ''),
'aspect_radio': tool_parameters.get('aspect_radio', '16:9'),
'mode': 'text-to-image',
'seed': tool_parameters.get('seed', 0),
'output_format': 'png',
}
model = tool_parameters.get('model', 'core')
if model in ['sd3', 'sd3-turbo']:
payload['model'] = tool_parameters.get('model')
if not model == 'sd3-turbo':
payload['negative_prompt'] = tool_parameters.get('negative_prompt', '')
response = post(
self.model_endpoint_map[tool_parameters.get('model', 'core')],
headers={
'accept': 'image/*',
**self.generate_authorization_headers(self.runtime.credentials),
},
files={
key: (None, str(value)) for key, value in payload.items()
},
timeout=(5, 30)
)
if not response.status_code == 200:
raise Exception(response.text)
return self.create_blob_message(
blob=response.content, meta={
'mime_type': 'image/png'
},
save_as=self.VARIABLE_KEY.IMAGE.value
)

View File

@ -0,0 +1,142 @@
identity:
name: stability_text2image
author: Dify
label:
en_US: StableDiffusion
zh_Hans: 稳定扩散
pt_BR: StableDiffusion
description:
human:
en_US: A tool for generate images based on the text input
zh_Hans: 一个基于文本输入生成图像的工具
pt_BR: A tool for generate images based on the text input
llm: A tool for generate images based on the text input
parameters:
- name: prompt
type: string
required: true
label:
en_US: Prompt
zh_Hans: 提示词
pt_BR: Prompt
human_description:
en_US: used for generating images
zh_Hans: 用于生成图像
pt_BR: used for generating images
llm_description: key words for generating images
form: llm
- name: model
type: select
default: sd3-turbo
required: true
label:
en_US: Model
zh_Hans: 模型
pt_BR: Model
options:
- value: core
label:
en_US: Core
zh_Hans: Core
pt_BR: Core
- value: sd3
label:
en_US: Stable Diffusion 3
zh_Hans: Stable Diffusion 3
pt_BR: Stable Diffusion 3
- value: sd3-turbo
label:
en_US: Stable Diffusion 3 Turbo
zh_Hans: Stable Diffusion 3 Turbo
pt_BR: Stable Diffusion 3 Turbo
human_description:
en_US: Model for generating images
zh_Hans: 用于生成图像的模型
pt_BR: Model for generating images
llm_description: Model for generating images
form: form
- name: negative_prompt
type: string
default: bad art, ugly, deformed, watermark, duplicated, discontinuous lines
required: false
label:
en_US: Negative Prompt
zh_Hans: 负面提示
pt_BR: Negative Prompt
human_description:
en_US: Negative Prompt
zh_Hans: 负面提示
pt_BR: Negative Prompt
llm_description: Negative Prompt
form: form
- name: seeds
type: number
default: 0
required: false
label:
en_US: Seeds
zh_Hans: 种子
pt_BR: Seeds
human_description:
en_US: Seeds
zh_Hans: 种子
pt_BR: Seeds
llm_description: Seeds
min: 0
max: 4294967294
form: form
- name: aspect_radio
type: select
default: '16:9'
options:
- value: '16:9'
label:
en_US: '16:9'
zh_Hans: '16:9'
pt_BR: '16:9'
- value: '1:1'
label:
en_US: '1:1'
zh_Hans: '1:1'
pt_BR: '1:1'
- value: '21:9'
label:
en_US: '21:9'
zh_Hans: '21:9'
pt_BR: '21:9'
- value: '2:3'
label:
en_US: '2:3'
zh_Hans: '2:3'
pt_BR: '2:3'
- value: '4:5'
label:
en_US: '4:5'
zh_Hans: '4:5'
pt_BR: '4:5'
- value: '5:4'
label:
en_US: '5:4'
zh_Hans: '5:4'
pt_BR: '5:4'
- value: '9:16'
label:
en_US: '9:16'
zh_Hans: '9:16'
pt_BR: '9:16'
- value: '9:21'
label:
en_US: '9:21'
zh_Hans: '9:21'
pt_BR: '9:21'
required: false
label:
en_US: Aspect Radio
zh_Hans: 长宽比
pt_BR: Aspect Radio
human_description:
en_US: Aspect Radio
zh_Hans: 长宽比
pt_BR: Aspect Radio
llm_description: Aspect Radio
form: form