mirror of https://github.com/langgenius/dify.git
fix: implement robust file type checks to align with existing logic (#17557)
Co-authored-by: Bowen Liang <liangbowen@gf.com.cn>
This commit is contained in:
parent
18f98f4fe1
commit
cac0d3c33e
|
@ -17,6 +17,7 @@ class BaseAppGenerator:
|
||||||
user_inputs: Optional[Mapping[str, Any]],
|
user_inputs: Optional[Mapping[str, Any]],
|
||||||
variables: Sequence["VariableEntity"],
|
variables: Sequence["VariableEntity"],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
|
strict_type_validation: bool = False,
|
||||||
) -> Mapping[str, Any]:
|
) -> Mapping[str, Any]:
|
||||||
user_inputs = user_inputs or {}
|
user_inputs = user_inputs or {}
|
||||||
# Filter input variables from form configuration, handle required fields, default values, and option values
|
# Filter input variables from form configuration, handle required fields, default values, and option values
|
||||||
|
@ -37,6 +38,7 @@ class BaseAppGenerator:
|
||||||
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
|
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
|
||||||
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
||||||
),
|
),
|
||||||
|
strict_type_validation=strict_type_validation,
|
||||||
)
|
)
|
||||||
for k, v in user_inputs.items()
|
for k, v in user_inputs.items()
|
||||||
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
|
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
|
||||||
|
|
|
@ -92,6 +92,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||||
mappings=files,
|
mappings=files,
|
||||||
tenant_id=app_model.tenant_id,
|
tenant_id=app_model.tenant_id,
|
||||||
config=file_extra_config,
|
config=file_extra_config,
|
||||||
|
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# convert to app config
|
# convert to app config
|
||||||
|
@ -114,7 +115,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
file_upload_config=file_extra_config,
|
file_upload_config=file_extra_config,
|
||||||
inputs=self._prepare_user_inputs(
|
inputs=self._prepare_user_inputs(
|
||||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
user_inputs=inputs,
|
||||||
|
variables=app_config.variables,
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||||
),
|
),
|
||||||
files=list(system_files),
|
files=list(system_files),
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
|
|
|
@ -52,6 +52,7 @@ def build_from_mapping(
|
||||||
mapping: Mapping[str, Any],
|
mapping: Mapping[str, Any],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
config: FileUploadConfig | None = None,
|
config: FileUploadConfig | None = None,
|
||||||
|
strict_type_validation: bool = False,
|
||||||
) -> File:
|
) -> File:
|
||||||
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
|
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
|
||||||
|
|
||||||
|
@ -69,6 +70,7 @@ def build_from_mapping(
|
||||||
mapping=mapping,
|
mapping=mapping,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
transfer_method=transfer_method,
|
transfer_method=transfer_method,
|
||||||
|
strict_type_validation=strict_type_validation,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config and not _is_file_valid_with_config(
|
if config and not _is_file_valid_with_config(
|
||||||
|
@ -87,12 +89,14 @@ def build_from_mappings(
|
||||||
mappings: Sequence[Mapping[str, Any]],
|
mappings: Sequence[Mapping[str, Any]],
|
||||||
config: FileUploadConfig | None = None,
|
config: FileUploadConfig | None = None,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
|
strict_type_validation: bool = False,
|
||||||
) -> Sequence[File]:
|
) -> Sequence[File]:
|
||||||
files = [
|
files = [
|
||||||
build_from_mapping(
|
build_from_mapping(
|
||||||
mapping=mapping,
|
mapping=mapping,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
config=config,
|
config=config,
|
||||||
|
strict_type_validation=strict_type_validation,
|
||||||
)
|
)
|
||||||
for mapping in mappings
|
for mapping in mappings
|
||||||
]
|
]
|
||||||
|
@ -116,6 +120,7 @@ def _build_from_local_file(
|
||||||
mapping: Mapping[str, Any],
|
mapping: Mapping[str, Any],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
transfer_method: FileTransferMethod,
|
transfer_method: FileTransferMethod,
|
||||||
|
strict_type_validation: bool = False,
|
||||||
) -> File:
|
) -> File:
|
||||||
upload_file_id = mapping.get("upload_file_id")
|
upload_file_id = mapping.get("upload_file_id")
|
||||||
if not upload_file_id:
|
if not upload_file_id:
|
||||||
|
@ -134,10 +139,16 @@ def _build_from_local_file(
|
||||||
if row is None:
|
if row is None:
|
||||||
raise ValueError("Invalid upload file")
|
raise ValueError("Invalid upload file")
|
||||||
|
|
||||||
file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
|
detected_file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
|
||||||
if file_type.value != mapping.get("type", "custom"):
|
specified_type = mapping.get("type", "custom")
|
||||||
|
|
||||||
|
if strict_type_validation and detected_file_type.value != specified_type:
|
||||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||||
|
|
||||||
|
file_type = (
|
||||||
|
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
|
||||||
|
)
|
||||||
|
|
||||||
return File(
|
return File(
|
||||||
id=mapping.get("id"),
|
id=mapping.get("id"),
|
||||||
filename=row.name,
|
filename=row.name,
|
||||||
|
@ -158,6 +169,7 @@ def _build_from_remote_url(
|
||||||
mapping: Mapping[str, Any],
|
mapping: Mapping[str, Any],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
transfer_method: FileTransferMethod,
|
transfer_method: FileTransferMethod,
|
||||||
|
strict_type_validation: bool = False,
|
||||||
) -> File:
|
) -> File:
|
||||||
upload_file_id = mapping.get("upload_file_id")
|
upload_file_id = mapping.get("upload_file_id")
|
||||||
if upload_file_id:
|
if upload_file_id:
|
||||||
|
@ -174,10 +186,21 @@ def _build_from_remote_url(
|
||||||
if upload_file is None:
|
if upload_file is None:
|
||||||
raise ValueError("Invalid upload file")
|
raise ValueError("Invalid upload file")
|
||||||
|
|
||||||
file_type = _standardize_file_type(extension="." + upload_file.extension, mime_type=upload_file.mime_type)
|
detected_file_type = _standardize_file_type(
|
||||||
if file_type.value != mapping.get("type", "custom"):
|
extension="." + upload_file.extension, mime_type=upload_file.mime_type
|
||||||
|
)
|
||||||
|
|
||||||
|
specified_type = mapping.get("type")
|
||||||
|
|
||||||
|
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||||
|
|
||||||
|
file_type = (
|
||||||
|
FileType(specified_type)
|
||||||
|
if specified_type and specified_type != FileType.CUSTOM.value
|
||||||
|
else detected_file_type
|
||||||
|
)
|
||||||
|
|
||||||
return File(
|
return File(
|
||||||
id=mapping.get("id"),
|
id=mapping.get("id"),
|
||||||
filename=upload_file.name,
|
filename=upload_file.name,
|
||||||
|
@ -237,6 +260,7 @@ def _build_from_tool_file(
|
||||||
mapping: Mapping[str, Any],
|
mapping: Mapping[str, Any],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
transfer_method: FileTransferMethod,
|
transfer_method: FileTransferMethod,
|
||||||
|
strict_type_validation: bool = False,
|
||||||
) -> File:
|
) -> File:
|
||||||
tool_file = (
|
tool_file = (
|
||||||
db.session.query(ToolFile)
|
db.session.query(ToolFile)
|
||||||
|
@ -252,7 +276,16 @@ def _build_from_tool_file(
|
||||||
|
|
||||||
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
||||||
|
|
||||||
file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
|
detected_file_type = _standardize_file_type(extension="." + extension, mime_type=tool_file.mimetype)
|
||||||
|
|
||||||
|
specified_type = mapping.get("type")
|
||||||
|
|
||||||
|
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||||
|
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||||
|
|
||||||
|
file_type = (
|
||||||
|
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
|
||||||
|
)
|
||||||
|
|
||||||
return File(
|
return File(
|
||||||
id=mapping.get("id"),
|
id=mapping.get("id"),
|
||||||
|
|
|
@ -0,0 +1,198 @@
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import Response
|
||||||
|
|
||||||
|
from factories.file_factory import (
|
||||||
|
File,
|
||||||
|
FileTransferMethod,
|
||||||
|
FileType,
|
||||||
|
FileUploadConfig,
|
||||||
|
build_from_mapping,
|
||||||
|
)
|
||||||
|
from models import ToolFile, UploadFile
|
||||||
|
|
||||||
|
# Test Data
|
||||||
|
TEST_TENANT_ID = "test_tenant_id"
|
||||||
|
TEST_UPLOAD_FILE_ID = str(uuid.uuid4())
|
||||||
|
TEST_TOOL_FILE_ID = str(uuid.uuid4())
|
||||||
|
TEST_REMOTE_URL = "http://example.com/test.jpg"
|
||||||
|
|
||||||
|
# Test Config
|
||||||
|
TEST_CONFIG = FileUploadConfig(
|
||||||
|
allowed_file_types=["image", "document"],
|
||||||
|
allowed_file_extensions=[".jpg", ".pdf"],
|
||||||
|
allowed_file_upload_methods=[FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE],
|
||||||
|
number_limits=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Fixtures
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_upload_file():
|
||||||
|
mock = MagicMock(spec=UploadFile)
|
||||||
|
mock.id = TEST_UPLOAD_FILE_ID
|
||||||
|
mock.tenant_id = TEST_TENANT_ID
|
||||||
|
mock.name = "test.jpg"
|
||||||
|
mock.extension = "jpg"
|
||||||
|
mock.mime_type = "image/jpeg"
|
||||||
|
mock.source_url = TEST_REMOTE_URL
|
||||||
|
mock.size = 1024
|
||||||
|
mock.key = "test_key"
|
||||||
|
with patch("factories.file_factory.db.session.scalar", return_value=mock) as m:
|
||||||
|
yield m
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_tool_file():
|
||||||
|
mock = MagicMock(spec=ToolFile)
|
||||||
|
mock.id = TEST_TOOL_FILE_ID
|
||||||
|
mock.tenant_id = TEST_TENANT_ID
|
||||||
|
mock.name = "tool_file.pdf"
|
||||||
|
mock.file_key = "tool_file.pdf"
|
||||||
|
mock.mimetype = "application/pdf"
|
||||||
|
mock.original_url = "http://example.com/tool.pdf"
|
||||||
|
mock.size = 2048
|
||||||
|
with patch("factories.file_factory.db.session.query") as mock_query:
|
||||||
|
mock_query.return_value.filter.return_value.first.return_value = mock
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_http_head():
|
||||||
|
def _mock_response(filename, size, content_type):
|
||||||
|
return Response(
|
||||||
|
status_code=200,
|
||||||
|
headers={
|
||||||
|
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||||
|
"Content-Length": str(size),
|
||||||
|
"Content-Type": content_type,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("factories.file_factory.ssrf_proxy.head") as mock_head:
|
||||||
|
mock_head.return_value = _mock_response("remote_test.jpg", 2048, "image/jpeg")
|
||||||
|
yield mock_head
|
||||||
|
|
||||||
|
|
||||||
|
# Helper functions
|
||||||
|
def local_file_mapping(file_type="image"):
|
||||||
|
return {
|
||||||
|
"transfer_method": "local_file",
|
||||||
|
"upload_file_id": TEST_UPLOAD_FILE_ID,
|
||||||
|
"type": file_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def tool_file_mapping(file_type="document"):
|
||||||
|
return {
|
||||||
|
"transfer_method": "tool_file",
|
||||||
|
"tool_file_id": TEST_TOOL_FILE_ID,
|
||||||
|
"type": file_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Tests
|
||||||
|
def test_build_from_mapping_backward_compatibility(mock_upload_file):
|
||||||
|
mapping = local_file_mapping(file_type="image")
|
||||||
|
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||||
|
assert isinstance(file, File)
|
||||||
|
assert file.transfer_method == FileTransferMethod.LOCAL_FILE
|
||||||
|
assert file.type == FileType.IMAGE
|
||||||
|
assert file.related_id == TEST_UPLOAD_FILE_ID
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("file_type", "should_pass", "expected_error"),
|
||||||
|
[
|
||||||
|
("image", True, None),
|
||||||
|
("document", False, "Detected file type does not match"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_build_from_local_file_strict_validation(mock_upload_file, file_type, should_pass, expected_error):
|
||||||
|
mapping = local_file_mapping(file_type=file_type)
|
||||||
|
if should_pass:
|
||||||
|
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
|
||||||
|
assert file.type == FileType(file_type)
|
||||||
|
else:
|
||||||
|
with pytest.raises(ValueError, match=expected_error):
|
||||||
|
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("file_type", "should_pass", "expected_error"),
|
||||||
|
[
|
||||||
|
("document", True, None),
|
||||||
|
("image", False, "Detected file type does not match"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_build_from_tool_file_strict_validation(mock_tool_file, file_type, should_pass, expected_error):
|
||||||
|
"""Strict type validation for tool_file."""
|
||||||
|
mapping = tool_file_mapping(file_type=file_type)
|
||||||
|
if should_pass:
|
||||||
|
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
|
||||||
|
assert file.type == FileType(file_type)
|
||||||
|
else:
|
||||||
|
with pytest.raises(ValueError, match=expected_error):
|
||||||
|
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_from_remote_url(mock_http_head):
|
||||||
|
mapping = {
|
||||||
|
"transfer_method": "remote_url",
|
||||||
|
"url": TEST_REMOTE_URL,
|
||||||
|
"type": "image",
|
||||||
|
}
|
||||||
|
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||||
|
assert file.transfer_method == FileTransferMethod.REMOTE_URL
|
||||||
|
assert file.type == FileType.IMAGE
|
||||||
|
assert file.filename == "remote_test.jpg"
|
||||||
|
assert file.size == 2048
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_file_not_found():
|
||||||
|
"""Test ToolFile not found in database."""
|
||||||
|
with patch("factories.file_factory.db.session.query") as mock_query:
|
||||||
|
mock_query.return_value.filter.return_value.first.return_value = None
|
||||||
|
mapping = tool_file_mapping()
|
||||||
|
with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"):
|
||||||
|
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_file_not_found():
|
||||||
|
"""Test UploadFile not found in database."""
|
||||||
|
with patch("factories.file_factory.db.session.scalar", return_value=None):
|
||||||
|
mapping = local_file_mapping()
|
||||||
|
with pytest.raises(ValueError, match="Invalid upload file"):
|
||||||
|
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_without_type_specification(mock_upload_file):
|
||||||
|
"""Test the situation where no file type is specified"""
|
||||||
|
mapping = {
|
||||||
|
"transfer_method": "local_file",
|
||||||
|
"upload_file_id": TEST_UPLOAD_FILE_ID,
|
||||||
|
# leave out the type
|
||||||
|
}
|
||||||
|
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||||
|
# It should automatically infer the type as "image" based on the file extension
|
||||||
|
assert file.type == FileType.IMAGE
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("file_type", "should_pass", "expected_error"),
|
||||||
|
[
|
||||||
|
("image", True, None),
|
||||||
|
("video", False, "File validation failed"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_file_validation_with_config(mock_upload_file, file_type, should_pass, expected_error):
|
||||||
|
"""Test the validation of files and configurations"""
|
||||||
|
mapping = local_file_mapping(file_type=file_type)
|
||||||
|
if should_pass:
|
||||||
|
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG)
|
||||||
|
assert file is not None
|
||||||
|
else:
|
||||||
|
with pytest.raises(ValueError, match=expected_error):
|
||||||
|
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG)
|
Loading…
Reference in New Issue