mirror of https://github.com/langgenius/dify.git
fix: implement robust file type checks to align with existing logic (#17557)
Co-authored-by: Bowen Liang <liangbowen@gf.com.cn>
This commit is contained in:
parent
18f98f4fe1
commit
cac0d3c33e
|
@ -17,6 +17,7 @@ class BaseAppGenerator:
|
|||
user_inputs: Optional[Mapping[str, Any]],
|
||||
variables: Sequence["VariableEntity"],
|
||||
tenant_id: str,
|
||||
strict_type_validation: bool = False,
|
||||
) -> Mapping[str, Any]:
|
||||
user_inputs = user_inputs or {}
|
||||
# Filter input variables from form configuration, handle required fields, default values, and option values
|
||||
|
@ -37,6 +38,7 @@ class BaseAppGenerator:
|
|||
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
|
||||
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
||||
),
|
||||
strict_type_validation=strict_type_validation,
|
||||
)
|
||||
for k, v in user_inputs.items()
|
||||
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
|
||||
|
|
|
@ -92,6 +92,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=file_extra_config,
|
||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||
)
|
||||
|
||||
# convert to app config
|
||||
|
@ -114,7 +115,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
app_config=app_config,
|
||||
file_upload_config=file_extra_config,
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
user_inputs=inputs,
|
||||
variables=app_config.variables,
|
||||
tenant_id=app_model.tenant_id,
|
||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||
),
|
||||
files=list(system_files),
|
||||
user_id=user.id,
|
||||
|
|
|
@ -52,6 +52,7 @@ def build_from_mapping(
|
|||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
config: FileUploadConfig | None = None,
|
||||
strict_type_validation: bool = False,
|
||||
) -> File:
|
||||
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
|
||||
|
||||
|
@ -69,6 +70,7 @@ def build_from_mapping(
|
|||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
transfer_method=transfer_method,
|
||||
strict_type_validation=strict_type_validation,
|
||||
)
|
||||
|
||||
if config and not _is_file_valid_with_config(
|
||||
|
@ -87,12 +89,14 @@ def build_from_mappings(
|
|||
mappings: Sequence[Mapping[str, Any]],
|
||||
config: FileUploadConfig | None = None,
|
||||
tenant_id: str,
|
||||
strict_type_validation: bool = False,
|
||||
) -> Sequence[File]:
|
||||
files = [
|
||||
build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
strict_type_validation=strict_type_validation,
|
||||
)
|
||||
for mapping in mappings
|
||||
]
|
||||
|
@ -116,6 +120,7 @@ def _build_from_local_file(
|
|||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
transfer_method: FileTransferMethod,
|
||||
strict_type_validation: bool = False,
|
||||
) -> File:
|
||||
upload_file_id = mapping.get("upload_file_id")
|
||||
if not upload_file_id:
|
||||
|
@ -134,10 +139,16 @@ def _build_from_local_file(
|
|||
if row is None:
|
||||
raise ValueError("Invalid upload file")
|
||||
|
||||
file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
|
||||
if file_type.value != mapping.get("type", "custom"):
|
||||
detected_file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
|
||||
specified_type = mapping.get("type", "custom")
|
||||
|
||||
if strict_type_validation and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
file_type = (
|
||||
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
filename=row.name,
|
||||
|
@ -158,6 +169,7 @@ def _build_from_remote_url(
|
|||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
transfer_method: FileTransferMethod,
|
||||
strict_type_validation: bool = False,
|
||||
) -> File:
|
||||
upload_file_id = mapping.get("upload_file_id")
|
||||
if upload_file_id:
|
||||
|
@ -174,10 +186,21 @@ def _build_from_remote_url(
|
|||
if upload_file is None:
|
||||
raise ValueError("Invalid upload file")
|
||||
|
||||
file_type = _standardize_file_type(extension="." + upload_file.extension, mime_type=upload_file.mime_type)
|
||||
if file_type.value != mapping.get("type", "custom"):
|
||||
detected_file_type = _standardize_file_type(
|
||||
extension="." + upload_file.extension, mime_type=upload_file.mime_type
|
||||
)
|
||||
|
||||
specified_type = mapping.get("type")
|
||||
|
||||
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
file_type = (
|
||||
FileType(specified_type)
|
||||
if specified_type and specified_type != FileType.CUSTOM.value
|
||||
else detected_file_type
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
filename=upload_file.name,
|
||||
|
@ -237,6 +260,7 @@ def _build_from_tool_file(
|
|||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
transfer_method: FileTransferMethod,
|
||||
strict_type_validation: bool = False,
|
||||
) -> File:
|
||||
tool_file = (
|
||||
db.session.query(ToolFile)
|
||||
|
@ -252,7 +276,16 @@ def _build_from_tool_file(
|
|||
|
||||
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
||||
|
||||
file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
|
||||
detected_file_type = _standardize_file_type(extension="." + extension, mime_type=tool_file.mimetype)
|
||||
|
||||
specified_type = mapping.get("type")
|
||||
|
||||
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
file_type = (
|
||||
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
|
|
|
@ -0,0 +1,198 @@
|
|||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import Response
|
||||
|
||||
from factories.file_factory import (
|
||||
File,
|
||||
FileTransferMethod,
|
||||
FileType,
|
||||
FileUploadConfig,
|
||||
build_from_mapping,
|
||||
)
|
||||
from models import ToolFile, UploadFile
|
||||
|
||||
# Test Data
|
||||
TEST_TENANT_ID = "test_tenant_id"
|
||||
TEST_UPLOAD_FILE_ID = str(uuid.uuid4())
|
||||
TEST_TOOL_FILE_ID = str(uuid.uuid4())
|
||||
TEST_REMOTE_URL = "http://example.com/test.jpg"
|
||||
|
||||
# Test Config
|
||||
TEST_CONFIG = FileUploadConfig(
|
||||
allowed_file_types=["image", "document"],
|
||||
allowed_file_extensions=[".jpg", ".pdf"],
|
||||
allowed_file_upload_methods=[FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE],
|
||||
number_limits=10,
|
||||
)
|
||||
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture
|
||||
def mock_upload_file():
|
||||
mock = MagicMock(spec=UploadFile)
|
||||
mock.id = TEST_UPLOAD_FILE_ID
|
||||
mock.tenant_id = TEST_TENANT_ID
|
||||
mock.name = "test.jpg"
|
||||
mock.extension = "jpg"
|
||||
mock.mime_type = "image/jpeg"
|
||||
mock.source_url = TEST_REMOTE_URL
|
||||
mock.size = 1024
|
||||
mock.key = "test_key"
|
||||
with patch("factories.file_factory.db.session.scalar", return_value=mock) as m:
|
||||
yield m
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_file():
|
||||
mock = MagicMock(spec=ToolFile)
|
||||
mock.id = TEST_TOOL_FILE_ID
|
||||
mock.tenant_id = TEST_TENANT_ID
|
||||
mock.name = "tool_file.pdf"
|
||||
mock.file_key = "tool_file.pdf"
|
||||
mock.mimetype = "application/pdf"
|
||||
mock.original_url = "http://example.com/tool.pdf"
|
||||
mock.size = 2048
|
||||
with patch("factories.file_factory.db.session.query") as mock_query:
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_http_head():
|
||||
def _mock_response(filename, size, content_type):
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||
"Content-Length": str(size),
|
||||
"Content-Type": content_type,
|
||||
},
|
||||
)
|
||||
|
||||
with patch("factories.file_factory.ssrf_proxy.head") as mock_head:
|
||||
mock_head.return_value = _mock_response("remote_test.jpg", 2048, "image/jpeg")
|
||||
yield mock_head
|
||||
|
||||
|
||||
# Helper functions
|
||||
def local_file_mapping(file_type="image"):
|
||||
return {
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": TEST_UPLOAD_FILE_ID,
|
||||
"type": file_type,
|
||||
}
|
||||
|
||||
|
||||
def tool_file_mapping(file_type="document"):
|
||||
return {
|
||||
"transfer_method": "tool_file",
|
||||
"tool_file_id": TEST_TOOL_FILE_ID,
|
||||
"type": file_type,
|
||||
}
|
||||
|
||||
|
||||
# Tests
|
||||
def test_build_from_mapping_backward_compatibility(mock_upload_file):
|
||||
mapping = local_file_mapping(file_type="image")
|
||||
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||
assert isinstance(file, File)
|
||||
assert file.transfer_method == FileTransferMethod.LOCAL_FILE
|
||||
assert file.type == FileType.IMAGE
|
||||
assert file.related_id == TEST_UPLOAD_FILE_ID
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("file_type", "should_pass", "expected_error"),
|
||||
[
|
||||
("image", True, None),
|
||||
("document", False, "Detected file type does not match"),
|
||||
],
|
||||
)
|
||||
def test_build_from_local_file_strict_validation(mock_upload_file, file_type, should_pass, expected_error):
|
||||
mapping = local_file_mapping(file_type=file_type)
|
||||
if should_pass:
|
||||
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
|
||||
assert file.type == FileType(file_type)
|
||||
else:
|
||||
with pytest.raises(ValueError, match=expected_error):
|
||||
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("file_type", "should_pass", "expected_error"),
|
||||
[
|
||||
("document", True, None),
|
||||
("image", False, "Detected file type does not match"),
|
||||
],
|
||||
)
|
||||
def test_build_from_tool_file_strict_validation(mock_tool_file, file_type, should_pass, expected_error):
|
||||
"""Strict type validation for tool_file."""
|
||||
mapping = tool_file_mapping(file_type=file_type)
|
||||
if should_pass:
|
||||
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
|
||||
assert file.type == FileType(file_type)
|
||||
else:
|
||||
with pytest.raises(ValueError, match=expected_error):
|
||||
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
|
||||
|
||||
|
||||
def test_build_from_remote_url(mock_http_head):
|
||||
mapping = {
|
||||
"transfer_method": "remote_url",
|
||||
"url": TEST_REMOTE_URL,
|
||||
"type": "image",
|
||||
}
|
||||
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||
assert file.transfer_method == FileTransferMethod.REMOTE_URL
|
||||
assert file.type == FileType.IMAGE
|
||||
assert file.filename == "remote_test.jpg"
|
||||
assert file.size == 2048
|
||||
|
||||
|
||||
def test_tool_file_not_found():
|
||||
"""Test ToolFile not found in database."""
|
||||
with patch("factories.file_factory.db.session.query") as mock_query:
|
||||
mock_query.return_value.filter.return_value.first.return_value = None
|
||||
mapping = tool_file_mapping()
|
||||
with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"):
|
||||
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||
|
||||
|
||||
def test_local_file_not_found():
|
||||
"""Test UploadFile not found in database."""
|
||||
with patch("factories.file_factory.db.session.scalar", return_value=None):
|
||||
mapping = local_file_mapping()
|
||||
with pytest.raises(ValueError, match="Invalid upload file"):
|
||||
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||
|
||||
|
||||
def test_build_without_type_specification(mock_upload_file):
|
||||
"""Test the situation where no file type is specified"""
|
||||
mapping = {
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": TEST_UPLOAD_FILE_ID,
|
||||
# leave out the type
|
||||
}
|
||||
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||
# It should automatically infer the type as "image" based on the file extension
|
||||
assert file.type == FileType.IMAGE
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("file_type", "should_pass", "expected_error"),
|
||||
[
|
||||
("image", True, None),
|
||||
("video", False, "File validation failed"),
|
||||
],
|
||||
)
|
||||
def test_file_validation_with_config(mock_upload_file, file_type, should_pass, expected_error):
|
||||
"""Test the validation of files and configurations"""
|
||||
mapping = local_file_mapping(file_type=file_type)
|
||||
if should_pass:
|
||||
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG)
|
||||
assert file is not None
|
||||
else:
|
||||
with pytest.raises(ValueError, match=expected_error):
|
||||
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG)
|
Loading…
Reference in New Issue