mirror of https://github.com/langgenius/dify.git
chore: the consistency of MultiModalPromptMessageContent (#11721)
This commit is contained in:
parent 78c3051585
commit c9b4029ce7
@@ -313,8 +313,7 @@ UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
 UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
 
 # Model configuration
-MULTIMODAL_SEND_IMAGE_FORMAT=base64
-MULTIMODAL_SEND_VIDEO_FORMAT=base64
+MULTIMODAL_SEND_FORMAT=base64
 PROMPT_GENERATION_MAX_TOKENS=512
 CODE_GENERATION_MAX_TOKENS=1024
 
@@ -665,14 +665,9 @@ class IndexingConfig(BaseSettings):
     )
 
 
-class VisionFormatConfig(BaseSettings):
-    MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
-        description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
-        default="base64",
-    )
-
-    MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
-        description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
+class MultiModalTransferConfig(BaseSettings):
+    MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
+        description="Format for sending files in multimodal contexts ('base64' or 'url'), default is base64",
         default="base64",
     )
 
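As context for the hunk above, here is a minimal, self-contained sketch of how a pydantic-settings class such as the new MultiModalTransferConfig resolves MULTIMODAL_SEND_FORMAT from the environment. The wiring is illustrative only and is not the project's actual config assembly.

```python
import os
from typing import Literal

from pydantic import Field
from pydantic_settings import BaseSettings


class MultiModalTransferConfig(BaseSettings):
    # Single switch replacing the former per-media-type IMAGE/VIDEO settings.
    MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
        description="Format for sending files in multimodal contexts ('base64' or 'url'), default is base64",
        default="base64",
    )


# The value is resolved from the environment at instantiation time.
os.environ["MULTIMODAL_SEND_FORMAT"] = "url"
print(MultiModalTransferConfig().MULTIMODAL_SEND_FORMAT)  # -> "url"
```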
@@ -778,13 +773,13 @@ class FeatureConfig(
     FileAccessConfig,
     FileUploadConfig,
     HttpConfig,
-    VisionFormatConfig,
     InnerAPIConfig,
     IndexingConfig,
     LoggingConfig,
     MailConfig,
     ModelLoadBalanceConfig,
     ModerationConfig,
+    MultiModalTransferConfig,
     PositionConfig,
     RagEtlConfig,
     SecurityConfig,
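A rough sketch of the composition pattern touched by this hunk: the aggregate settings class inherits from many small BaseSettings mixins, so swapping VisionFormatConfig for MultiModalTransferConfig changes which env vars the aggregate exposes. The field on HttpConfig below is an illustrative stand-in, not necessarily the project's actual field set.

```python
from pydantic_settings import BaseSettings


class MultiModalTransferConfig(BaseSettings):
    MULTIMODAL_SEND_FORMAT: str = "base64"


class HttpConfig(BaseSettings):
    # Illustrative stand-in for one of the other mixins in the list above.
    API_COMPRESSION_ENABLED: bool = False


class FeatureConfig(HttpConfig, MultiModalTransferConfig):
    # Every field of every base is exposed on the aggregated settings object.
    pass


settings = FeatureConfig()
print(settings.MULTIMODAL_SEND_FORMAT, settings.API_COMPRESSION_ENABLED)
```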
@@ -42,33 +42,31 @@ def to_prompt_message_content(
     *,
     image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
 ):
-    match f.type:
-        case FileType.IMAGE:
-            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
-            if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
-                data = _to_url(f)
-            else:
-                data = _to_base64_data_string(f)
-            if f.extension is None:
-                raise ValueError("Missing file extension")
-            if f.mime_type is None:
-                raise ValueError("Missing file mime_type")
-
-            return ImagePromptMessageContent(data=data, detail=image_detail_config, format=f.extension.lstrip("."))
-        case FileType.AUDIO:
-            data = _to_base64_data_string(f)
-            if f.extension is None:
-                raise ValueError("Missing file extension")
-            return AudioPromptMessageContent(data=data, format=f.extension.lstrip("."))
-        case FileType.VIDEO:
-            if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
-                data = _to_url(f)
-            else:
-                data = _to_base64_data_string(f)
-            if f.extension is None:
-                raise ValueError("Missing file extension")
-            return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
-        case FileType.DOCUMENT:
-            data = _to_base64_data_string(f)
-            return DocumentPromptMessageContent(encode_format="base64", data=data, format=f.extension.lstrip("."))
-        case _:
-            raise ValueError(f"file type {f.type} is not supported")
+    params = {
+        "base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
+        "url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
+        "format": f.extension.removeprefix("."),
+        "mime_type": f.mime_type,
+    }
+    if f.type == FileType.IMAGE:
+        params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
+
+    prompt_class_map = {
+        FileType.IMAGE: ImagePromptMessageContent,
+        FileType.AUDIO: AudioPromptMessageContent,
+        FileType.VIDEO: VideoPromptMessageContent,
+        FileType.DOCUMENT: DocumentPromptMessageContent,
+    }
+
+    try:
+        return prompt_class_map[f.type](**params)
+    except KeyError:
+        raise ValueError(f"file type {f.type} is not supported")
 
 
 def download(f: File, /):
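The rewrite above replaces a per-type match statement with one shared params dict and a FileType-to-class map. A rough, self-contained sketch of that dispatch pattern follows; the FileType enum and content classes here are simplified stand-ins, not the real dify types.

```python
from dataclasses import dataclass
from enum import Enum


class FileType(Enum):
    IMAGE = "image"
    AUDIO = "audio"


@dataclass
class ImageContent:
    base64_data: str
    url: str
    format: str
    mime_type: str
    detail: str = "low"


@dataclass
class AudioContent:
    base64_data: str
    url: str
    format: str
    mime_type: str


PROMPT_CLASS_MAP = {FileType.IMAGE: ImageContent, FileType.AUDIO: AudioContent}


def to_content(file_type: FileType, *, url: str, extension: str, mime_type: str,
               send_format: str = "url", detail: str | None = None):
    # One params dict shared by every content class; only images add "detail".
    params = {
        "base64_data": "" if send_format == "url" else "<base64 payload>",
        "url": url if send_format == "url" else "",
        "format": extension.removeprefix("."),
        "mime_type": mime_type,
    }
    if file_type == FileType.IMAGE:
        params["detail"] = detail or "low"
    try:
        return PROMPT_CLASS_MAP[file_type](**params)
    except KeyError:
        raise ValueError(f"file type {file_type} is not supported")


print(to_content(FileType.IMAGE, url="https://example.com/a.jpg", extension=".jpg", mime_type="image/jpeg"))
```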
@@ -122,11 +120,6 @@ def _get_encoded_string(f: File, /):
     return encoded_string
 
 
-def _to_base64_data_string(f: File, /):
-    encoded_string = _get_encoded_string(f)
-    return f"data:{f.mime_type};base64,{encoded_string}"
-
-
 def _to_url(f: File, /):
     if f.transfer_method == FileTransferMethod.REMOTE_URL:
         if f.remote_url is None:
@@ -1,9 +1,9 @@
 from abc import ABC
 from collections.abc import Sequence
 from enum import Enum, StrEnum
-from typing import Literal, Optional
+from typing import Optional
 
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, computed_field, field_validator
 
 
 class PromptMessageRole(Enum):
@@ -67,7 +67,6 @@ class PromptMessageContent(BaseModel):
     """
 
     type: PromptMessageContentType
-    data: str
 
 
 class TextPromptMessageContent(PromptMessageContent):
@@ -76,21 +75,35 @@ class TextPromptMessageContent(PromptMessageContent):
     """
 
     type: PromptMessageContentType = PromptMessageContentType.TEXT
+    data: str
 
 
-class VideoPromptMessageContent(PromptMessageContent):
+class MultiModalPromptMessageContent(PromptMessageContent):
+    """
+    Model class for multi-modal prompt message content.
+    """
+
+    type: PromptMessageContentType
+    format: str = Field(..., description="the format of multi-modal file")
+    base64_data: str = Field("", description="the base64 data of multi-modal file")
+    url: str = Field("", description="the url of multi-modal file")
+    mime_type: str = Field(..., description="the mime type of multi-modal file")
+
+    @computed_field(return_type=str)
+    @property
+    def data(self):
+        return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
+
+
+class VideoPromptMessageContent(MultiModalPromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.VIDEO
-    data: str = Field(..., description="Base64 encoded video data")
-    format: str = Field(..., description="Video format")
 
 
-class AudioPromptMessageContent(PromptMessageContent):
+class AudioPromptMessageContent(MultiModalPromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.AUDIO
-    data: str = Field(..., description="Base64 encoded audio data")
-    format: str = Field(..., description="Audio format")
 
 
-class ImagePromptMessageContent(PromptMessageContent):
+class ImagePromptMessageContent(MultiModalPromptMessageContent):
     """
     Model class for image prompt message content.
     """
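The key consistency change is the computed `data` field: callers keep reading `.data` whether the content carries a URL or raw base64. A standalone sketch that mirrors (but does not import) the entity above:

```python
from pydantic import BaseModel, Field, computed_field


class MultiModalContent(BaseModel):
    format: str = Field(..., description="file format, e.g. 'jpg'")
    base64_data: str = ""
    url: str = ""
    mime_type: str

    @computed_field(return_type=str)
    @property
    def data(self) -> str:
        # Prefer the URL; otherwise synthesize a data: URI from the base64 payload.
        return self.url or f"data:{self.mime_type};base64,{self.base64_data}"


print(MultiModalContent(format="jpg", mime_type="image/jpeg", url="https://example.com/a.jpg").data)
# -> https://example.com/a.jpg
print(MultiModalContent(format="jpg", mime_type="image/jpeg", base64_data="Zm9v").data)
# -> data:image/jpeg;base64,Zm9v
```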
@@ -101,14 +114,10 @@ class ImagePromptMessageContent(PromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.IMAGE
     detail: DETAIL = DETAIL.LOW
-    format: str = Field("jpg", description="Image format")
 
 
-class DocumentPromptMessageContent(PromptMessageContent):
+class DocumentPromptMessageContent(MultiModalPromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
-    encode_format: Literal["base64"]
-    data: str
-    format: str = Field(..., description="Document format")
 
 
 class PromptMessage(ABC, BaseModel):
@@ -1,5 +1,4 @@
 import base64
-import io
 import json
 from collections.abc import Generator, Sequence
 from typing import Optional, Union, cast
@@ -18,7 +17,6 @@ from anthropic.types import (
 )
 from anthropic.types.beta.tools import ToolsBetaMessage
 from httpx import Timeout
-from PIL import Image
 
 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities import (
@@ -498,22 +496,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                     sub_messages.append(sub_message_dict)
                 elif message_content.type == PromptMessageContentType.IMAGE:
                     message_content = cast(ImagePromptMessageContent, message_content)
-                    if not message_content.data.startswith("data:"):
+                    if not message_content.base64_data:
                         # fetch image data from url
                         try:
-                            image_content = requests.get(message_content.data).content
-                            with Image.open(io.BytesIO(image_content)) as img:
-                                mime_type = f"image/{img.format.lower()}"
+                            image_content = requests.get(message_content.url).content
                             base64_data = base64.b64encode(image_content).decode("utf-8")
                         except Exception as ex:
                             raise ValueError(
                                 f"Failed to fetch image data from url {message_content.data}, {ex}"
                             )
                     else:
-                        data_split = message_content.data.split(";base64,")
-                        mime_type = data_split[0].replace("data:", "")
-                        base64_data = data_split[1]
+                        base64_data = message_content.base64_data
 
+                    mime_type = message_content.mime_type
                     if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                         raise ValueError(
                             f"Unsupported image type {mime_type}, "
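A hedged sketch of the branch above in isolation: when only a URL is present, fetch the bytes and base64-encode them; otherwise use the base64 payload directly and take the mime type from the content object. The `content` argument is a hypothetical stand-in exposing `url`, `base64_data`, and `mime_type`.

```python
import base64

import requests


def resolve_image(content) -> tuple[str, str]:
    """Return (mime_type, base64_data) for an image content object."""
    if not content.base64_data:
        try:
            # Only a URL was supplied, so download and encode the bytes.
            image_bytes = requests.get(content.url).content
            base64_data = base64.b64encode(image_bytes).decode("utf-8")
        except Exception as ex:
            raise ValueError(f"Failed to fetch image data from url {content.url}, {ex}")
    else:
        base64_data = content.base64_data

    mime_type = content.mime_type
    if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
        raise ValueError(f"Unsupported image type {mime_type}")
    return mime_type, base64_data
```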
@@ -526,19 +521,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                     }
                     sub_messages.append(sub_message_dict)
                 elif isinstance(message_content, DocumentPromptMessageContent):
-                    data_split = message_content.data.split(";base64,")
-                    mime_type = data_split[0].replace("data:", "")
-                    base64_data = data_split[1]
-                    if mime_type != "application/pdf":
+                    if message_content.mime_type != "application/pdf":
                         raise ValueError(
-                            f"Unsupported document type {mime_type}, " "only support application/pdf"
+                            f"Unsupported document type {message_content.mime_type}, "
+                            "only support application/pdf"
                         )
                     sub_message_dict = {
                         "type": "document",
                         "source": {
-                            "type": message_content.encode_format,
-                            "media_type": mime_type,
-                            "data": base64_data,
+                            "type": "base64",
+                            "media_type": message_content.mime_type,
+                            "data": message_content.data,
                         },
                     }
                     sub_messages.append(sub_message_dict)
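With the consolidated fields, the document branch no longer parses a data URI before building its block. A hedged sketch of the resulting shape, where `content` is a hypothetical stand-in exposing `mime_type` and `data`, mirroring the source dict shown above:

```python
def document_block(content) -> dict:
    # content.data is assumed to already be the base64 payload for documents.
    if content.mime_type != "application/pdf":
        raise ValueError(f"Unsupported document type {content.mime_type}, only support application/pdf")
    return {
        "type": "document",
        "source": {
            "type": "base64",
            "media_type": content.mime_type,
            "data": content.data,
        },
    }
```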
@@ -434,9 +434,9 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
                     sub_messages.append(sub_message_dict)
                 elif message_content.type == PromptMessageContentType.VIDEO:
                     message_content = cast(VideoPromptMessageContent, message_content)
-                    video_url = message_content.data
-                    if message_content.data.startswith("data:"):
-                        raise InvokeError("not support base64, please set MULTIMODAL_SEND_VIDEO_FORMAT to url")
+                    video_url = message_content.url
+                    if not video_url:
+                        raise InvokeError("not support base64, please set MULTIMODAL_SEND_FORMAT to url")
 
                     sub_message_dict = {"video": video_url}
                     sub_messages.append(sub_message_dict)
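A small sketch of the guard above: video input here takes a URL, so with base64 transfer the `url` field is empty and the call is rejected. The `content` argument is a hypothetical stand-in, and plain ValueError is used instead of the provider's InvokeError.

```python
def video_block(content) -> dict:
    # content.url is empty when MULTIMODAL_SEND_FORMAT is base64.
    video_url = content.url
    if not video_url:
        raise ValueError("not support base64, please set MULTIMODAL_SEND_FORMAT to url")
    return {"video": video_url}
```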
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 
+from configs import dify_config
 from core.app.app_config.entities import ModelConfigEntity
 from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -126,6 +127,7 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
 
 def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
     model_config_mock, _, messages, inputs, context = get_chat_model_args
+    dify_config.MULTIMODAL_SEND_FORMAT = "url"
 
     files = [
         File(
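The test above sets the new flag by assigning directly on the shared dify_config object. As a hedged aside (not what the PR does), the same override can be scoped to a single test with pytest's monkeypatch so it is undone automatically:

```python
import pytest

from configs import dify_config


def test_multimodal_send_format_url(monkeypatch: pytest.MonkeyPatch):
    # Scoped override; restored after the test instead of leaking into the session.
    monkeypatch.setattr(dify_config, "MULTIMODAL_SEND_FORMAT", "url")
    assert dify_config.MULTIMODAL_SEND_FORMAT == "url"
```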
@@ -140,7 +142,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
     prompt_transform = AdvancedPromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
-        mock_get_encoded_string.return_value = ImagePromptMessageContent(data=str(files[0].remote_url))
+        mock_get_encoded_string.return_value = ImagePromptMessageContent(
+            url=str(files[0].remote_url), format="jpg", mime_type="image/jpg"
+        )
         prompt_messages = prompt_transform._get_chat_model_prompt_messages(
             prompt_template=messages,
             inputs=inputs,
|
@ -18,8 +18,7 @@ from core.model_runtime.entities.message_entities import (
|
|||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
|
||||
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
|
||||
|
@@ -249,8 +248,7 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
 
 def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
     # Setup dify config
-    dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
-    dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
+    dify_config.MULTIMODAL_SEND_FORMAT = "url"
 
     # Generate fake values for prompt template
     fake_assistant_prompt = faker.sentence()
@@ -326,9 +324,10 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                     tenant_id="test",
                     type=FileType.IMAGE,
                     filename="test1.jpg",
-                    extension=".jpg",
                     transfer_method=FileTransferMethod.REMOTE_URL,
                     remote_url=fake_remote_url,
+                    extension=".jpg",
+                    mime_type="image/jpg",
                 )
             ],
             vision_enabled=True,
@@ -362,7 +361,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                 UserPromptMessage(
                     content=[
                         TextPromptMessageContent(data=fake_query),
-                        ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+                        ImagePromptMessageContent(
+                            url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
+                        ),
                     ]
                 ),
             ],
@@ -385,7 +386,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
             expected_messages=[
                 UserPromptMessage(
                     content=[
-                        ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+                        ImagePromptMessageContent(
+                            url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
+                        ),
                     ]
                 ),
             ]
@@ -396,9 +399,10 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                     tenant_id="test",
                     type=FileType.IMAGE,
                     filename="test1.jpg",
-                    extension=".jpg",
                     transfer_method=FileTransferMethod.REMOTE_URL,
                     remote_url=fake_remote_url,
+                    extension=".jpg",
+                    mime_type="image/jpg",
                 )
             },
         ),
@@ -614,13 +614,12 @@ CODE_GENERATION_MAX_TOKENS=1024
 # Multi-modal Configuration
 # ------------------------------
 
-# The format of the image/video sent when the multi-modal model is input,
+# The format of the image/video/audio/document sent when the multi-modal model is input,
 # the default is base64, optional url.
 # The delay of the call in url mode will be lower than that in base64 mode.
 # It is generally recommended to use the more compatible base64 mode.
-# If configured as url, you need to configure FILES_URL as an externally accessible address so that the multi-modal model can access the image/video.
-MULTIMODAL_SEND_IMAGE_FORMAT=base64
-MULTIMODAL_SEND_VIDEO_FORMAT=base64
+# If configured as url, you need to configure FILES_URL as an externally accessible address so that the multi-modal model can access the image/video/audio/document.
+MULTIMODAL_SEND_FORMAT=base64
 
 # Upload image file size limit, default 10M.
 UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
@@ -225,8 +225,7 @@ x-shared-env: &shared-api-worker-env
      UNSTRUCTURED_API_KEY: ${UNSTRUCTURED_API_KEY:-}
      PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512}
      CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024}
-     MULTIMODAL_SEND_IMAGE_FORMAT: ${MULTIMODAL_SEND_IMAGE_FORMAT:-base64}
-     MULTIMODAL_SEND_VIDEO_FORMAT: ${MULTIMODAL_SEND_VIDEO_FORMAT:-base64}
+     MULTIMODAL_SEND_FORMAT: ${MULTIMODAL_SEND_FORMAT:-base64}
      UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10}
      UPLOAD_VIDEO_FILE_SIZE_LIMIT: ${UPLOAD_VIDEO_FILE_SIZE_LIMIT:-100}
      UPLOAD_AUDIO_FILE_SIZE_LIMIT: ${UPLOAD_AUDIO_FILE_SIZE_LIMIT:-50}