chore: the consistency of MultiModalPromptMessageContent (#11721)

非法操作 authored on 2024-12-17 15:01:38 +08:00, committed by GitHub
parent 78c3051585
commit c9b4029ce7
14 changed files with 108 additions and 99 deletions

@@ -313,8 +313,7 @@ UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Model configuration
MULTIMODAL_SEND_IMAGE_FORMAT=base64
MULTIMODAL_SEND_VIDEO_FORMAT=base64
MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024

@@ -665,14 +665,9 @@ class IndexingConfig(BaseSettings):
)
class VisionFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)
MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
class MultiModalTransferConfig(BaseSettings):
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending files in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)
@@ -778,13 +773,13 @@ class FeatureConfig(
FileAccessConfig,
FileUploadConfig,
HttpConfig,
VisionFormatConfig,
InnerAPIConfig,
IndexingConfig,
LoggingConfig,
MailConfig,
ModelLoadBalanceConfig,
ModerationConfig,
MultiModalTransferConfig,
PositionConfig,
RagEtlConfig,
SecurityConfig,
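
A minimal sketch of how the consolidated setting is read after this change (it assumes the merged dify_config object exported by the configs package, which other files in this diff already import):

    from configs import dify_config

    # MULTIMODAL_SEND_FORMAT replaces the former MULTIMODAL_SEND_IMAGE_FORMAT /
    # MULTIMODAL_SEND_VIDEO_FORMAT pair and accepts only "base64" or "url".
    if dify_config.MULTIMODAL_SEND_FORMAT == "url":
        print("files are passed to models as externally accessible URLs")
    else:
        print("files are inlined as base64 payloads")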

@@ -42,33 +42,31 @@ def to_prompt_message_content(
*,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
):
match f.type:
case FileType.IMAGE:
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
if f.mime_type is None:
raise ValueError("Missing file mime_type")
return ImagePromptMessageContent(data=data, detail=image_detail_config, format=f.extension.lstrip("."))
case FileType.AUDIO:
data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.VIDEO:
if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.DOCUMENT:
data = _to_base64_data_string(f)
return DocumentPromptMessageContent(encode_format="base64", data=data, format=f.extension.lstrip("."))
case _:
raise ValueError(f"file type {f.type} is not supported")
params = {
"base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
"format": f.extension.removeprefix("."),
"mime_type": f.mime_type,
}
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_class_map = {
FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent,
FileType.DOCUMENT: DocumentPromptMessageContent,
}
try:
return prompt_class_map[f.type](**params)
except KeyError:
raise ValueError(f"file type {f.type} is not supported")
def download(f: File, /):
@@ -122,11 +120,6 @@ def _get_encoded_string(f: File, /):
return encoded_string
def _to_base64_data_string(f: File, /):
encoded_string = _get_encoded_string(f)
return f"data:{f.mime_type};base64,{encoded_string}"
def _to_url(f: File, /):
if f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None:

@@ -1,9 +1,9 @@
from abc import ABC
from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Literal, Optional
from typing import Optional
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, computed_field, field_validator
class PromptMessageRole(Enum):
@@ -67,7 +67,6 @@ class PromptMessageContent(BaseModel):
"""
type: PromptMessageContentType
data: str
class TextPromptMessageContent(PromptMessageContent):
@@ -76,21 +75,35 @@ class TextPromptMessageContent(PromptMessageContent):
"""
type: PromptMessageContentType = PromptMessageContentType.TEXT
data: str
class VideoPromptMessageContent(PromptMessageContent):
class MultiModalPromptMessageContent(PromptMessageContent):
"""
Model class for multi-modal prompt message content.
"""
type: PromptMessageContentType
format: str = Field(..., description="the format of multi-modal file")
base64_data: str = Field("", description="the base64 data of multi-modal file")
url: str = Field("", description="the url of multi-modal file")
mime_type: str = Field(..., description="the mime type of multi-modal file")
@computed_field(return_type=str)
@property
def data(self):
return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
class VideoPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.VIDEO
data: str = Field(..., description="Base64 encoded video data")
format: str = Field(..., description="Video format")
class AudioPromptMessageContent(PromptMessageContent):
class AudioPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.AUDIO
data: str = Field(..., description="Base64 encoded audio data")
format: str = Field(..., description="Audio format")
class ImagePromptMessageContent(PromptMessageContent):
class ImagePromptMessageContent(MultiModalPromptMessageContent):
"""
Model class for image prompt message content.
"""
@@ -101,14 +114,10 @@ class ImagePromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW
format: str = Field("jpg", description="Image format")
class DocumentPromptMessageContent(PromptMessageContent):
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
encode_format: Literal["base64"]
data: str
format: str = Field(..., description="Document format")
class PromptMessage(ABC, BaseModel):
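
A minimal sketch of the new computed data field on the refactored content classes (import path taken from the tests below; the literal values are illustrative only):

    from core.model_runtime.entities.message_entities import (
        AudioPromptMessageContent,
        ImagePromptMessageContent,
    )

    image = ImagePromptMessageContent(
        format="jpg", mime_type="image/jpeg", url="https://files.example.com/a.jpg"
    )
    assert image.data == "https://files.example.com/a.jpg"  # url wins when present

    audio = AudioPromptMessageContent(
        format="mp3", mime_type="audio/mpeg", base64_data="SGVsbG8="
    )
    assert audio.data == "data:audio/mpeg;base64,SGVsbG8="  # otherwise a data: URI is built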

@@ -1,5 +1,4 @@
import base64
import io
import json
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast
@@ -18,7 +17,6 @@ from anthropic.types import (
)
from anthropic.types.beta.tools import ToolsBetaMessage
from httpx import Timeout
from PIL import Image
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities import (
@@ -498,22 +496,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"):
if not message_content.base64_data:
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
image_content = requests.get(message_content.url).content
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(
f"Failed to fetch image data from url {message_content.data}, {ex}"
)
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
base64_data = message_content.base64_data
mime_type = message_content.mime_type
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError(
f"Unsupported image type {mime_type}, "
@@ -526,19 +521,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
}
sub_messages.append(sub_message_dict)
elif isinstance(message_content, DocumentPromptMessageContent):
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type != "application/pdf":
if message_content.mime_type != "application/pdf":
raise ValueError(
f"Unsupported document type {mime_type}, " "only support application/pdf"
f"Unsupported document type {message_content.mime_type}, "
"only support application/pdf"
)
sub_message_dict = {
"type": "document",
"source": {
"type": message_content.encode_format,
"media_type": mime_type,
"data": base64_data,
"type": "base64",
"media_type": message_content.mime_type,
"data": message_content.data,
},
}
sub_messages.append(sub_message_dict)

@@ -434,9 +434,9 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.VIDEO:
message_content = cast(VideoPromptMessageContent, message_content)
video_url = message_content.data
if message_content.data.startswith("data:"):
raise InvokeError("not support base64, please set MULTIMODAL_SEND_VIDEO_FORMAT to url")
video_url = message_content.url
if not video_url:
raise InvokeError("not support base64, please set MULTIMODAL_SEND_FORMAT to url")
sub_message_dict = {"video": video_url}
sub_messages.append(sub_message_dict)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
import pytest
from configs import dify_config
from core.app.app_config.entities import ModelConfigEntity
from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -126,6 +127,7 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
model_config_mock, _, messages, inputs, context = get_chat_model_args
dify_config.MULTIMODAL_SEND_FORMAT = "url"
files = [
File(
@@ -140,7 +142,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
prompt_transform = AdvancedPromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
mock_get_encoded_string.return_value = ImagePromptMessageContent(data=str(files[0].remote_url))
mock_get_encoded_string.return_value = ImagePromptMessageContent(
url=str(files[0].remote_url), format="jpg", mime_type="image/jpg"
)
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
prompt_template=messages,
inputs=inputs,

@@ -18,8 +18,7 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
@@ -249,8 +248,7 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
# Setup dify config
dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
dify_config.MULTIMODAL_SEND_FORMAT = "url"
# Generate fake values for prompt template
fake_assistant_prompt = faker.sentence()
@@ -326,9 +324,10 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
extension=".jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
extension=".jpg",
mime_type="image/jpg",
)
],
vision_enabled=True,
@@ -362,7 +361,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
UserPromptMessage(
content=[
TextPromptMessageContent(data=fake_query),
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
ImagePromptMessageContent(
url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
),
]
),
],
@@ -385,7 +386,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
expected_messages=[
UserPromptMessage(
content=[
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
ImagePromptMessageContent(
url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
),
]
),
]
@@ -396,9 +399,10 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
extension=".jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
extension=".jpg",
mime_type="image/jpg",
)
},
),

@@ -614,13 +614,12 @@ CODE_GENERATION_MAX_TOKENS=1024
# Multi-modal Configuration
# ------------------------------
# The format of the image/video sent when the multi-modal model is input,
# The format of the image/video/audio/document sent when the multi-modal model is input,
# the default is base64, optional url.
# The delay of the call in url mode will be lower than that in base64 mode.
# It is generally recommended to use the more compatible base64 mode.
# If configured as url, you need to configure FILES_URL as an externally accessible address so that the multi-modal model can access the image/video.
MULTIMODAL_SEND_IMAGE_FORMAT=base64
MULTIMODAL_SEND_VIDEO_FORMAT=base64
# If configured as url, you need to configure FILES_URL as an externally accessible address so that the multi-modal model can access the image/video/audio/document.
MULTIMODAL_SEND_FORMAT=base64
# Upload image file size limit, default 10M.
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
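
To illustrate the comment block above: switching the pipeline to url mode means flipping the single flag and pointing FILES_URL at an address the model provider can reach (the host below is a placeholder):

    MULTIMODAL_SEND_FORMAT=url
    FILES_URL=https://files.example.com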

@@ -225,8 +225,7 @@ x-shared-env: &shared-api-worker-env
UNSTRUCTURED_API_KEY: ${UNSTRUCTURED_API_KEY:-}
PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512}
CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024}
MULTIMODAL_SEND_IMAGE_FORMAT: ${MULTIMODAL_SEND_IMAGE_FORMAT:-base64}
MULTIMODAL_SEND_VIDEO_FORMAT: ${MULTIMODAL_SEND_VIDEO_FORMAT:-base64}
MULTIMODAL_SEND_FORMAT: ${MULTIMODAL_SEND_FORMAT:-base64}
UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10}
UPLOAD_VIDEO_FILE_SIZE_LIMIT: ${UPLOAD_VIDEO_FILE_SIZE_LIMIT:-100}
UPLOAD_AUDIO_FILE_SIZE_LIMIT: ${UPLOAD_AUDIO_FILE_SIZE_LIMIT:-50}