diff --git a/app/.env.example b/app/.env.example index b466c14..f88678f 100644 --- a/app/.env.example +++ b/app/.env.example @@ -3,6 +3,11 @@ FF_DB_URI=postgresql://root:password@localhost:5432/fileflash # DATABASE_URL=postgresql://root:password@localhost:5432/fileflash APP_ENV=development +# Required when APP_ENV=production or APP_ENV=prod. +DEFAULT_ADMIN_USERNAME=admin +DEFAULT_ADMIN_EMAIL=admin@example.com +DEFAULT_ADMIN_PASSWORD=replace-with-32-bytes-or-longer-password + JWT_SECRET_KEY=please-set-at-least-32-bytes-secret-key # Optional dedicated HMAC key for token hash persistence. # Falls back to JWT_SECRET_KEY when omitted. @@ -26,6 +31,7 @@ UPLOAD_CHUNK_SIZE_DEFAULT=5242880 UPLOAD_CHUNK_SIZE_MIN=1048576 UPLOAD_CHUNK_SIZE_MAX=16777216 UPLOAD_SINGLE_FILE_SIZE_MAX=5368709120 +UPLOAD_VERIFY_MERGED_OBJECT_HASH=false STARRED_ITEMS_LIMIT=20 UPLOAD_SESSION_TTL_HOURS=24 UPLOAD_TEMP_PREFIX=tmp diff --git a/app/src/fileflash/core/security.py b/app/src/fileflash/core/security.py index 63f3dda..f2b3c45 100644 --- a/app/src/fileflash/core/security.py +++ b/app/src/fileflash/core/security.py @@ -70,3 +70,32 @@ def decode_share_access_token(token: str, settings: Settings) -> dict[str, Any]: if token_type != "share": raise jwt.InvalidTokenError("Invalid token type") return payload + + +def create_file_preview_token( + *, + user_id: int, + file_id: int, + settings: Settings, + expires_at: datetime, +) -> str: + now = datetime.now(UTC) + payload: dict[str, Any] = { + "sub": str(user_id), + "typ": "file_preview", + "scope": "file.preview", + "fileId": str(file_id), + "iat": int(now.timestamp()), + "exp": int(expires_at.timestamp()), + "jti": str(uuid.uuid4()), + } + return jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm) + + +def decode_file_preview_token(token: str, settings: Settings) -> dict[str, Any]: + payload = jwt.decode(token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm]) + token_type = payload.get("typ") + scope = payload.get("scope") + if token_type != "file_preview" or scope != "file.preview": + raise jwt.InvalidTokenError("Invalid token type") + return payload diff --git a/app/src/fileflash/core/settings.py b/app/src/fileflash/core/settings.py index bc7e066..f93a38b 100644 --- a/app/src/fileflash/core/settings.py +++ b/app/src/fileflash/core/settings.py @@ -28,6 +28,10 @@ class Settings(BaseSettings): api_v1_prefix: str = "/api/v1" app_env: str = Field(default="production", alias="APP_ENV") + default_admin_username: str | None = Field(default=None, alias="DEFAULT_ADMIN_USERNAME") + default_admin_email: str | None = Field(default=None, alias="DEFAULT_ADMIN_EMAIL") + default_admin_password: str | None = Field(default=None, alias="DEFAULT_ADMIN_PASSWORD") + database_url: str | None = Field(default=None, alias="DATABASE_URL") ff_db_uri: str | None = Field(default=None, alias="FF_DB_URI") @@ -39,6 +43,10 @@ class Settings(BaseSettings): jwt_algorithm: str = "HS256" access_token_expire_minutes: int = 60 * 24 * 3 refresh_token_expire_days: int = 7 + file_preview_url_ttl_seconds: int = Field( + default=4 * 60 * 60, + alias="FILE_PREVIEW_URL_TTL_SECONDS", + ) refresh_cookie_name: str = "refreshToken" refresh_cookie_secure: bool = False @@ -68,10 +76,20 @@ class Settings(BaseSettings): object_storage_secure: bool = Field(default=False, alias="OBJECT_STORAGE_SECURE") object_storage_region: str | None = Field(default=None, alias="OBJECT_STORAGE_REGION") - upload_chunk_size_default: int = Field(default=5 * 1024 * 1024, alias="UPLOAD_CHUNK_SIZE_DEFAULT") + upload_chunk_size_default: int = Field( + default=5 * 1024 * 1024, + alias="UPLOAD_CHUNK_SIZE_DEFAULT", + ) upload_chunk_size_min: int = Field(default=1 * 1024 * 1024, alias="UPLOAD_CHUNK_SIZE_MIN") upload_chunk_size_max: int = Field(default=16 * 1024 * 1024, alias="UPLOAD_CHUNK_SIZE_MAX") - upload_single_file_size_max: int = Field(default=5 * 1024 * 1024 * 1024, alias="UPLOAD_SINGLE_FILE_SIZE_MAX") + upload_single_file_size_max: int = Field( + default=5 * 1024 * 1024 * 1024, + alias="UPLOAD_SINGLE_FILE_SIZE_MAX", + ) + upload_verify_merged_object_hash: bool = Field( + default=False, + alias="UPLOAD_VERIFY_MERGED_OBJECT_HASH", + ) starred_items_limit: int = Field(default=20, alias="STARRED_ITEMS_LIMIT") upload_session_ttl_hours: int = Field(default=24, alias="UPLOAD_SESSION_TTL_HOURS") upload_temp_prefix: str = Field(default="tmp", alias="UPLOAD_TEMP_PREFIX") @@ -184,6 +202,25 @@ def security_configuration_issues(self) -> tuple[str, ...]: token_hash_secret = (self.token_hash_secret or "").strip() if token_hash_secret and len(token_hash_secret.encode("utf-8")) < self.MIN_SECRET_LENGTH: issues.append(f"TOKEN_HASH_SECRET must be at least {self.MIN_SECRET_LENGTH} bytes") + issues.extend(self.default_admin_configuration_issues) + return tuple(issues) + + @property + def default_admin_configuration_issues(self) -> tuple[str, ...]: + if not self.is_production_env: + return () + + issues: list[str] = [] + if not (self.default_admin_username or "").strip(): + issues.append("DEFAULT_ADMIN_USERNAME is required in production") + if not (self.default_admin_email or "").strip(): + issues.append("DEFAULT_ADMIN_EMAIL is required in production") + + password = (self.default_admin_password or "").strip() + if not password: + issues.append("DEFAULT_ADMIN_PASSWORD is required in production") + elif len(password.encode("utf-8")) < self.MIN_SECRET_LENGTH: + issues.append(f"DEFAULT_ADMIN_PASSWORD must be at least {self.MIN_SECRET_LENGTH} bytes") return tuple(issues) def assert_runtime_security(self) -> None: diff --git a/app/src/fileflash/routers/files.py b/app/src/fileflash/routers/files.py index b5fe019..e54c99a 100644 --- a/app/src/fileflash/routers/files.py +++ b/app/src/fileflash/routers/files.py @@ -1,18 +1,24 @@ from __future__ import annotations import os +from datetime import UTC, datetime, timedelta +from urllib.parse import urlencode -from fastapi import APIRouter, Depends, Header, Query +from fastapi import APIRouter, Depends, Header, Query, Request from fastapi.responses import FileResponse, StreamingResponse +from jwt import InvalidTokenError from starlette.background import BackgroundTask -from ..core.deps import get_archive_service, get_current_user, get_file_service +from ..core.deps import get_archive_service, get_current_user, get_file_service, get_settings_dep from ..core.errors import ApiError, api_success +from ..core.security import create_file_preview_token, decode_file_preview_token +from ..core.settings import Settings from ..models.tables_identity import User from ..schemas.archive import ArchiveExtractRequest from ..schemas.file import ( BatchDownloadRequest, BatchFilesRequest, + FilePreviewUrlResponse, GetFilesQuery, MoveFileRequest, RenameFileRequest, @@ -174,6 +180,66 @@ async def preview_file( ) +@router.post("/{file_id}/preview-url") +async def create_file_preview_url( + file_id: str, + request: Request, + current_user: User = Depends(get_current_user), + file_service: FileService = Depends(get_file_service), + settings: Settings = Depends(get_settings_dep), +): + try: + fid = int(file_id) + except ValueError as exc: + raise ApiError(status_code=400, code=400, message="Invalid fileId") from exc + + await file_service.get_file(user_id=current_user.user_id, file_id=fid) + expires_at = datetime.now(UTC) + timedelta(seconds=settings.file_preview_url_ttl_seconds) + token = create_file_preview_token( + user_id=int(current_user.user_id), + file_id=fid, + settings=settings, + expires_at=expires_at, + ) + stream_url = str(request.url_for("preview_file_stream", file_id=str(fid))) + result = FilePreviewUrlResponse( + url=f"{stream_url}?{urlencode({'token': token})}", + expires_at=expires_at, + ) + return api_success(data=result.model_dump(by_alias=True)) + + +@router.get("/{file_id}/preview-stream", name="preview_file_stream") +async def preview_file_stream( + file_id: str, + token: str = Query(..., min_length=1), + range_header: str | None = Header(default=None, alias="Range"), + file_service: FileService = Depends(get_file_service), + settings: Settings = Depends(get_settings_dep), +): + try: + payload = decode_file_preview_token(token, settings) + user_id = int(payload["sub"]) + token_file_id = str(payload["fileId"]) + except (InvalidTokenError, KeyError, ValueError): + raise ApiError(status_code=401, code=401, message="Invalid or expired preview token") from None + + if token_file_id != str(file_id): + raise ApiError(status_code=403, code=403, message="Preview token does not match file") + + result = await file_service.get_preview_stream( + user_id=user_id, + file_id=file_id, + range_header=range_header, + ) + return StreamingResponse( + result.stream, + media_type=result.content_type, + headers=result.headers, + status_code=result.status_code, + ) + + @router.delete("/{file_id}") async def delete_file( file_id: str, diff --git a/app/src/fileflash/schemas/file.py b/app/src/fileflash/schemas/file.py index 3bf75c9..3827366 100644 --- a/app/src/fileflash/schemas/file.py +++ b/app/src/fileflash/schemas/file.py @@ -177,6 +177,11 @@ class MediaOptimization(CamelModel): updated_at: datetime +class FilePreviewUrlResponse(CamelModel): + url: str + expires_at: datetime + + class RenameFileRequest(CamelModel): file_name: str = Field(min_length=1, max_length=255) diff --git a/app/src/fileflash/scripts/init_dev_accounts.py b/app/src/fileflash/scripts/init_dev_accounts.py index 80c5f6d..da67db6 100644 --- a/app/src/fileflash/scripts/init_dev_accounts.py +++ b/app/src/fileflash/scripts/init_dev_accounts.py @@ -11,11 +11,11 @@ def build_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="Initialize development test accounts for FileFlash.") + parser = argparse.ArgumentParser(description="Initialize seeded accounts for FileFlash.") parser.add_argument( "--reset-password", action="store_true", - help="Force reset passwords to defaults (admin/admin123, demo/demo123).", + help="Force reset seeded account passwords from the active environment configuration.", ) return parser @@ -23,8 +23,8 @@ def build_parser() -> argparse.ArgumentParser: async def run(reset_password: bool) -> int: settings = get_settings() if settings.is_production_env: - logger.warning( - "Manual dev-account initialization is running under APP_ENV=%s. This is not executed automatically in production.", + logger.info( + "Manual account initialization is using DEFAULT_ADMIN_* for APP_ENV=%s.", settings.app_env, ) diff --git a/app/src/fileflash/services/admin/system.py b/app/src/fileflash/services/admin/system.py index 629018a..1cbbffe 100644 --- a/app/src/fileflash/services/admin/system.py +++ b/app/src/fileflash/services/admin/system.py @@ -39,7 +39,7 @@ async def health(self) -> SystemHealth: virus_scan_enabled=bool(getattr(self.settings, "virus_scan_enabled", False)), thumbnail_generation_enabled=bool(getattr(self.settings, "thumbnail_generation_enabled", True)), registration_mail_enabled=bool(self.settings.mail_server and self.settings.mail_from), - hash_computation_enabled=True, + hash_computation_enabled=bool(self.settings.upload_verify_merged_object_hash), last_updated_at=datetime.now(UTC), ) diff --git a/app/src/fileflash/services/dev_seed.py b/app/src/fileflash/services/dev_seed.py index 0845280..6c66fd6 100644 --- a/app/src/fileflash/services/dev_seed.py +++ b/app/src/fileflash/services/dev_seed.py @@ -59,19 +59,23 @@ async def initialize_dev_accounts( reset_password: bool = False, auto_run: bool = False, ) -> bool: - if auto_run and not settings.is_development_env: - if settings.is_production_env: - logger.info("Skip dev account auto-initialization in production environment: %s", settings.app_env) - else: - logger.info("Skip dev account auto-initialization for non-dev APP_ENV=%s", settings.app_env) + if auto_run and not settings.is_development_env and not settings.is_production_env: + logger.info("Skip account auto-initialization for non-dev APP_ENV=%s", settings.app_env) return False + if settings.is_production_env: + settings.assert_runtime_security() + + accounts = _seed_accounts_for_settings(settings) + async with SessionLocal() as db: - seeder = DevAccountSeeder(db=db) + seeder = DevAccountSeeder(db=db, accounts=accounts) summary = await seeder.seed(reset_password=reset_password) logger.info( - "Dev accounts initialized: createdUsers=%s updatedUsers=%s resetPasswordUsers=%s createdPreferences=%s createdRoots=%s", + "%s initialized: createdUsers=%s updatedUsers=%s resetPasswordUsers=%s " + "createdPreferences=%s createdRoots=%s", + "Production default admin" if settings.is_production_env else "Dev accounts", summary.created_users, summary.updated_users, summary.reset_password_users, @@ -81,15 +85,44 @@ async def initialize_dev_accounts( return True +def _seed_accounts_for_settings(settings: Settings) -> tuple[DevSeedAccount, ...]: + if not settings.is_production_env: + return DEV_SEED_ACCOUNTS + + username = (settings.default_admin_username or "").strip() + email = (settings.default_admin_email or "").strip() + password = (settings.default_admin_password or "").strip() + if not username or not email or not password: + raise ValueError( + "DEFAULT_ADMIN_USERNAME, DEFAULT_ADMIN_EMAIL, and DEFAULT_ADMIN_PASSWORD are required" + ) + + return ( + DevSeedAccount( + username=username, + email=email, + password=password, + role=UserRole.ADMIN, + language=UiLanguage.ZH_CN, + ), + ) + + class DevAccountSeeder: - def __init__(self, *, db: AsyncSession) -> None: + def __init__( + self, + *, + db: AsyncSession, + accounts: tuple[DevSeedAccount, ...] = DEV_SEED_ACCOUNTS, + ) -> None: self.db = db + self.accounts = accounts async def seed(self, *, reset_password: bool = False) -> DevSeedSummary: summary = DevSeedSummary() now = datetime.now(UTC) - for spec in DEV_SEED_ACCOUNTS: + for spec in self.accounts: user = await self._find_existing_user(spec=spec) created = user is None @@ -129,7 +162,12 @@ async def seed(self, *, reset_password: bool = False) -> DevSeedSummary: if created: user.password_changed_at = now - await self._ensure_preference(user=user, language=spec.language, now=now, summary=summary) + await self._ensure_preference( + user=user, + language=spec.language, + now=now, + summary=summary, + ) await self._ensure_root_folder(user=user, now=now, summary=summary) await self.db.commit() @@ -156,7 +194,9 @@ async def _ensure_preference( now: datetime, summary: DevSeedSummary, ) -> None: - preference = await self.db.scalar(select(UserPreference).where(UserPreference.user_id == user.user_id)) + preference = await self.db.scalar( + select(UserPreference).where(UserPreference.user_id == user.user_id) + ) if preference is None: self.db.add( UserPreference( diff --git a/app/src/fileflash/services/file.py b/app/src/fileflash/services/file.py index 3d910da..453abdb 100644 --- a/app/src/fileflash/services/file.py +++ b/app/src/fileflash/services/file.py @@ -74,6 +74,12 @@ class DownloadStreamResult: headers: dict[str, str] +@dataclass(slots=True) +class ResolvedStreamObject: + storage_object: StorageObject + content_type_override: str | None = None + + class FileService: def __init__( self, @@ -284,10 +290,11 @@ async def _get_file_stream( raise ApiError(status_code=503, code=503, message="Object storage is unavailable") file_row = await self._get_active_file(user_id=user_id, file_id=file_id) - storage_object = await self._resolve_stream_storage_object( + resolved_object = await self._resolve_stream_storage_object( file_row=file_row, prefer_optimized=(content_disposition == "inline"), ) + storage_object = resolved_object.storage_object if resolved_object is not None else None if storage_object is None or storage_object.upload_status != UploadStatus.ACTIVE: raise ApiError(status_code=404, code=404, message="File content not found") @@ -296,7 +303,11 @@ async def _get_file_stream( raise ApiError(status_code=404, code=404, message="File content not found") content_type = resolve_file_mime_type( - mime_type=file_row.mime_type or storage_object.content_type, + mime_type=( + resolved_object.content_type_override + if resolved_object is not None and resolved_object.content_type_override + else file_row.mime_type or storage_object.content_type + ), file_ext=file_row.file_ext, file_name=file_row.file_name, default=DEFAULT_MIME_TYPE, @@ -1832,12 +1843,12 @@ async def _resolve_stream_storage_object( *, file_row: File, prefer_optimized: bool, - ) -> StorageObject | None: + ) -> ResolvedStreamObject | None: source_object = await self.db.get(StorageObject, int(file_row.storage_object_id)) if source_object is None: return None if not prefer_optimized: - return source_object + return ResolvedStreamObject(storage_object=source_object) metadata_row = await self.db.scalar( select(FileMediaMetadata) @@ -1845,17 +1856,19 @@ async def _resolve_stream_storage_object( .limit(1) ) if not isinstance(metadata_row, FileMediaMetadata): - return source_object + return ResolvedStreamObject(storage_object=source_object) transcode = (metadata_row.extra_metadata or {}).get("transcode") if not isinstance(transcode, dict): - return source_object + return ResolvedStreamObject(storage_object=source_object) if str(transcode.get("status") or "").strip().lower() != TRANSCODE_READY_STATUS: - return source_object + return ResolvedStreamObject(storage_object=source_object) bucket_name = str(transcode.get("optimizedBucketName") or "").strip() object_key = str(transcode.get("optimizedObjectKey") or "").strip() if not bucket_name or not object_key: - return source_object + return ResolvedStreamObject(storage_object=source_object) + + optimized_mime_type = str(transcode.get("optimizedMimeType") or "").strip() or None optimized_object = await self.db.scalar( select(StorageObject) @@ -1869,13 +1882,16 @@ async def _resolve_stream_storage_object( .limit(1) ) if isinstance(optimized_object, StorageObject): - return optimized_object + return ResolvedStreamObject( + storage_object=optimized_object, + content_type_override=optimized_mime_type or optimized_object.content_type, + ) if self.storage is None: - return source_object + return ResolvedStreamObject(storage_object=source_object) exists = await self.storage.object_exists(bucket_name=bucket_name, object_key=object_key) if not exists: - return source_object + return ResolvedStreamObject(storage_object=source_object) stat = await self.storage.stat_object(bucket_name=bucket_name, object_key=object_key) created = StorageObject( @@ -1889,7 +1905,10 @@ async def _resolve_stream_storage_object( ) self.db.add(created) await self.db.flush() - return created + return ResolvedStreamObject( + storage_object=created, + content_type_override=optimized_mime_type or created.content_type, + ) @staticmethod def _parse_datetime(raw: object) -> datetime | None: diff --git a/app/src/fileflash/services/upload.py b/app/src/fileflash/services/upload.py index 76e135a..491e6d8 100644 --- a/app/src/fileflash/services/upload.py +++ b/app/src/fileflash/services/upload.py @@ -511,16 +511,17 @@ async def _operation() -> MergeChunksResponse: await self.db.commit() raise ApiError(status_code=422, code=422, message="Composed file size mismatch") - actual_hash = await self.storage.compute_object_hash( - object_key=task.object_key, - algorithm=hash_algorithm, - ) - if actual_hash != object_hash: - await self.storage.remove_object(object_key=task.object_key) - task.status = UploadTaskStatus.FAILED - task.last_error = "Composed file hash mismatch" - await self.db.commit() - raise ApiError(status_code=422, code=422, message="Composed file hash mismatch") + if self.settings.upload_verify_merged_object_hash: + actual_hash = await self.storage.compute_object_hash( + object_key=task.object_key, + algorithm=hash_algorithm, + ) + if actual_hash != object_hash: + await self.storage.remove_object(object_key=task.object_key) + task.status = UploadTaskStatus.FAILED + task.last_error = "Composed file hash mismatch" + await self.db.commit() + raise ApiError(status_code=422, code=422, message="Composed file hash mismatch") storage_object = await self._find_storage_object( object_hash=object_hash, diff --git a/app/src/fileflash/tasks/transcode.py b/app/src/fileflash/tasks/transcode.py index 31632a9..adbfa7f 100644 --- a/app/src/fileflash/tasks/transcode.py +++ b/app/src/fileflash/tasks/transcode.py @@ -13,6 +13,11 @@ from ..s3.minio_client import MinioObjectStorageClient TRANSCODE_PROFILE_VERSION = "mp4-v1" +_FFMPEG_DETERMINISTIC_ERROR_PATTERNS = ( + "unsupported channel layout", + "error while opening encoder", + "could not open encoder before eof", +) @dataclass(slots=True) @@ -182,6 +187,8 @@ def build_ffmpeg_command( audio_codec, "-b:a", f"{audio_bitrate}k", + "-ac", + "2", "-map", "0:v:0", "-map", @@ -196,6 +203,8 @@ def build_ffmpeg_command( "aac", "-b:a", f"{audio_bitrate}k", + "-ac", + "2", "-movflags", "+faststart", "-map", @@ -279,10 +288,33 @@ def _run_command(command: list[str], *, timeout_seconds: int) -> subprocess.Comp raise RuntimeError(f"Binary not found for command: {command[0]}") from exc if result.returncode != 0: stderr = (result.stderr or "").strip() - raise RuntimeError(f"Command failed ({result.returncode}): {' '.join(command)} | {stderr}") + error_message = f"Command failed ({result.returncode}): {' '.join(command)} | {stderr}" + if _is_deterministic_ffmpeg_failure(command=command, stderr=stderr): + raise ValueError(error_message) + raise RuntimeError(error_message) return result +def _is_deterministic_ffmpeg_failure(*, command: list[str], stderr: str) -> bool: + if not command: + return False + + binary_name = Path(command[0]).name.lower() + if "ffmpeg" not in binary_name: + return False + + stderr_text = stderr.lower() + if any(pattern in stderr_text for pattern in _FFMPEG_DETERMINISTIC_ERROR_PATTERNS): + return True + + if "invalid argument" in stderr_text and ( + "enc:aac" in stderr_text or "aac @" in stderr_text or "channel layout" in stderr_text + ): + return True + + return False + + def _run_async(awaitable: Any) -> Any: return asyncio.run(awaitable) diff --git a/app/tests/test_admin_system_service.py b/app/tests/test_admin_system_service.py new file mode 100644 index 0000000..f5cbe7c --- /dev/null +++ b/app/tests/test_admin_system_service.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import pytest + +from fileflash.core.settings import Settings +from fileflash.services.admin.system import AdminSystemService + + +class DummySession: + async def scalar(self, _query: object) -> int: + return 0 + + +def make_settings(**overrides: object) -> Settings: + payload = { + "FF_DB_URI": "postgresql://root:pwd@localhost:5432/fileflash", + "JWT_SECRET_KEY": "unit-test-secret-key-1234567890abcd", + } + payload.update(overrides) + return Settings(**payload) + + +@pytest.mark.asyncio +async def test_health_hash_computation_enabled_follows_settings() -> None: + disabled_service = AdminSystemService( + db=DummySession(), + settings=make_settings(UPLOAD_VERIFY_MERGED_OBJECT_HASH=False), + ) + disabled_health = await disabled_service.health() + assert disabled_health.hash_computation_enabled is False + + enabled_service = AdminSystemService( + db=DummySession(), + settings=make_settings(UPLOAD_VERIFY_MERGED_OBJECT_HASH=True), + ) + enabled_health = await enabled_service.health() + assert enabled_health.hash_computation_enabled is True diff --git a/app/tests/test_dev_seed.py b/app/tests/test_dev_seed.py index 2ac47c2..6681920 100644 --- a/app/tests/test_dev_seed.py +++ b/app/tests/test_dev_seed.py @@ -5,8 +5,9 @@ import pytest from fileflash.core.settings import Settings +from fileflash.models.enums import UserRole from fileflash.models.tables_identity import User -from fileflash.services.dev_seed import DevAccountSeeder, initialize_dev_accounts +from fileflash.services.dev_seed import DevAccountSeeder, DevSeedSummary, initialize_dev_accounts class DummySeedSession: @@ -38,7 +39,10 @@ async def _find_existing_user(self, *, spec): # type: ignore[override] for obj in self.db.added: if not isinstance(obj, User): continue - if obj.username.lower() == spec.username.lower() or obj.email.lower() == spec.email.lower(): + if ( + obj.username.lower() == spec.username.lower() + or obj.email.lower() == spec.email.lower() + ): return obj return None @@ -59,7 +63,7 @@ def make_settings(**overrides: object) -> Settings: "JWT_SECRET_KEY": "unit-test-secret-key-1234567890abcd", } payload.update(overrides) - return Settings(**payload) + return Settings(_env_file=None, **payload) @pytest.mark.asyncio @@ -95,12 +99,14 @@ async def test_dev_account_seeder_is_idempotent_and_supports_password_reset(): @pytest.mark.asyncio -async def test_initialize_dev_accounts_skips_auto_run_in_production(monkeypatch: pytest.MonkeyPatch): +async def test_initialize_dev_accounts_skips_auto_run_outside_dev_and_prod( + monkeypatch: pytest.MonkeyPatch, +): guard = AsyncMock(side_effect=AssertionError("SessionLocal should not be called")) monkeypatch.setattr("fileflash.services.dev_seed.SessionLocal", guard) result = await initialize_dev_accounts( - settings=make_settings(APP_ENV="production"), + settings=make_settings(APP_ENV="staging"), auto_run=True, reset_password=False, ) @@ -108,3 +114,59 @@ async def test_initialize_dev_accounts_skips_auto_run_in_production(monkeypatch: assert result is False guard.assert_not_called() + +@pytest.mark.asyncio +async def test_initialize_dev_accounts_uses_env_admin_in_production( + monkeypatch: pytest.MonkeyPatch, +): + captured: dict[str, object] = {} + + class SessionLocalStub: + calls = 0 + + def __call__(self): + self.calls += 1 + return self + + async def __aenter__(self): + return object() + + async def __aexit__(self, exc_type, exc, traceback): + return None + + class SeederStub: + def __init__(self, *, db: object, accounts: tuple[object, ...]) -> None: + captured["db"] = db + captured["accounts"] = accounts + + async def seed(self, *, reset_password: bool = False) -> DevSeedSummary: + captured["reset_password"] = reset_password + return DevSeedSummary(created_users=1) + + session_local = SessionLocalStub() + monkeypatch.setattr("fileflash.services.dev_seed.SessionLocal", session_local) + monkeypatch.setattr("fileflash.services.dev_seed.DevAccountSeeder", SeederStub) + + result = await initialize_dev_accounts( + settings=make_settings( + APP_ENV="production", + DEFAULT_ADMIN_USERNAME="root-admin", + DEFAULT_ADMIN_EMAIL="root-admin@example.com", + DEFAULT_ADMIN_PASSWORD="p" * 32, + ), + auto_run=True, + reset_password=False, + ) + + assert result is True + assert session_local.calls == 1 + accounts = captured["accounts"] + assert isinstance(accounts, tuple) + assert len(accounts) == 1 + account = accounts[0] + assert account.username == "root-admin" + assert account.email == "root-admin@example.com" + assert account.password == "p" * 32 + assert account.role == UserRole.ADMIN + assert captured["reset_password"] is False + diff --git a/app/tests/test_file_download_recycle_service.py b/app/tests/test_file_download_recycle_service.py index cf5ec2c..3d1caba 100644 --- a/app/tests/test_file_download_recycle_service.py +++ b/app/tests/test_file_download_recycle_service.py @@ -291,9 +291,9 @@ async def test_get_preview_stream_prefers_transcoded_object_when_ready(monkeypat storage = DummyStorage() service = FileService(db=session, storage=storage) - file_row = make_file_row(file_id=11, file_name="preview.mp4") - file_row.file_ext = "mp4" - file_row.mime_type = "video/mp4" + file_row = make_file_row(file_id=11, file_name="preview.mkv") + file_row.file_ext = "mkv" + file_row.mime_type = "video/x-matroska" file_row.storage_object_id = 101 source_object = StorageObject( object_id=101, @@ -330,6 +330,7 @@ async def test_get_preview_stream_prefers_transcoded_object_when_ready(monkeypat result = await service.get_preview_stream(user_id=1, file_id="11", range_header=None) assert result.status_code == 200 + assert result.content_type == "video/mp4" assert result.headers["Content-Length"] == "128" diff --git a/app/tests/test_file_folder_patch_routes.py b/app/tests/test_file_folder_patch_routes.py index b289d06..1c38aac 100644 --- a/app/tests/test_file_folder_patch_routes.py +++ b/app/tests/test_file_folder_patch_routes.py @@ -1,12 +1,15 @@ from __future__ import annotations from collections.abc import AsyncIterator -from datetime import UTC, datetime +from datetime import UTC, datetime, timedelta from fastapi import FastAPI from fastapi.testclient import TestClient -from fileflash.core.deps import get_current_user, get_file_service, get_folder_service +from fileflash.core.deps import get_current_user, get_file_service, get_folder_service, get_settings_dep +from fileflash.core.errors import ApiError, api_error_handler +from fileflash.core.security import create_file_preview_token +from fileflash.core.settings import Settings from fileflash.models.tables_identity import User from fileflash.routers.files import router as files_router from fileflash.routers.folders import router as folders_router @@ -47,6 +50,9 @@ def _make_folder_item(*, is_starred: bool) -> FolderItem: class StubFileService: + async def get_file(self, *, user_id: int, file_id: int) -> FileDetails: # noqa: ARG002 + return _make_file_details(name="demo.txt", is_starred=False) + async def rename_file(self, *, user_id: int, file_id: str, payload: RenameFileRequest) -> FileDetails: # noqa: ARG002 return _make_file_details(name=payload.file_name, is_starred=False) @@ -93,17 +99,26 @@ async def toggle_folder_star(self, *, user_id: int, folder_id: str, is_starred: return _make_folder_item(is_starred=is_starred) -def _build_client() -> TestClient: +def _build_client(*, authenticated: bool = True, settings: Settings | None = None) -> TestClient: app = FastAPI() app.include_router(files_router, prefix="/api/v1") app.include_router(folders_router, prefix="/api/v1") + app.add_exception_handler(ApiError, api_error_handler) async def _current_user_override() -> User: + if not authenticated: + raise ApiError(status_code=401, code=401, message="Missing authorization token") return User(user_id=1, username="owner", email="owner@example.com", password_hash="hash") + resolved_settings = settings or Settings( + JWT_SECRET_KEY="unit-test-secret-key-1234567890abcd", + FF_DB_URI="postgresql://u:p@localhost:5432/db", + ) + app.dependency_overrides[get_current_user] = _current_user_override app.dependency_overrides[get_file_service] = lambda: StubFileService() app.dependency_overrides[get_folder_service] = lambda: StubFolderService() + app.dependency_overrides[get_settings_dep] = lambda: resolved_settings return TestClient(app) @@ -142,3 +157,88 @@ def test_get_file_preview_route_returns_stream() -> None: assert response.headers["content-range"] == "bytes 0-3/12" assert response.headers["content-type"].startswith("text/plain") assert response.content == b"prev" + + +def test_create_file_preview_url_route_returns_signed_stream_url() -> None: + settings = Settings( + JWT_SECRET_KEY="unit-test-secret-key-1234567890abcd", + FF_DB_URI="postgresql://u:p@localhost:5432/db", + ) + with _build_client(settings=settings) as client: + response = client.post("/api/v1/files/1/preview-url") + + assert response.status_code == 200 + payload = response.json() + assert payload["success"] is True + assert payload["data"]["url"].startswith("http://testserver/api/v1/files/1/preview-stream?token=") + assert payload["data"]["expiresAt"] + + +def test_create_file_preview_url_route_requires_authentication() -> None: + with _build_client(authenticated=False) as client: + response = client.post("/api/v1/files/1/preview-url") + + assert response.status_code == 401 + assert response.json()["message"] == "Missing authorization token" + + +def test_get_file_preview_stream_route_supports_range_with_valid_token() -> None: + settings = Settings( + JWT_SECRET_KEY="unit-test-secret-key-1234567890abcd", + FF_DB_URI="postgresql://u:p@localhost:5432/db", + ) + token = create_file_preview_token( + user_id=1, + file_id=1, + settings=settings, + expires_at=datetime.now(UTC) + timedelta(minutes=10), + ) + + with _build_client(settings=settings) as client: + response = client.get( + f"/api/v1/files/1/preview-stream?token={token}", + headers={"Range": "bytes=0-3"}, + ) + + assert response.status_code == 206 + assert response.headers["content-range"] == "bytes 0-3/12" + assert response.headers["content-length"] == "4" + assert response.content == b"prev" + + +def test_get_file_preview_stream_route_rejects_mismatched_token_file_id() -> None: + settings = Settings( + JWT_SECRET_KEY="unit-test-secret-key-1234567890abcd", + FF_DB_URI="postgresql://u:p@localhost:5432/db", + ) + token = create_file_preview_token( + user_id=1, + file_id=2, + settings=settings, + expires_at=datetime.now(UTC) + timedelta(minutes=10), + ) + + with _build_client(settings=settings) as client: + response = client.get(f"/api/v1/files/1/preview-stream?token={token}") + + assert response.status_code == 403 + assert response.json()["message"] == "Preview token does not match file" + + +def test_get_file_preview_stream_route_rejects_expired_token() -> None: + settings = Settings( + JWT_SECRET_KEY="unit-test-secret-key-1234567890abcd", + FF_DB_URI="postgresql://u:p@localhost:5432/db", + ) + token = create_file_preview_token( + user_id=1, + file_id=1, + settings=settings, + expires_at=datetime.now(UTC) - timedelta(minutes=10), + ) + + with _build_client(settings=settings) as client: + response = client.get(f"/api/v1/files/1/preview-stream?token={token}") + + assert response.status_code == 401 + assert response.json()["message"] == "Invalid or expired preview token" diff --git a/app/tests/test_security.py b/app/tests/test_security.py index fe29f57..d1b3e45 100644 --- a/app/tests/test_security.py +++ b/app/tests/test_security.py @@ -1,9 +1,13 @@ from __future__ import annotations +from datetime import UTC, datetime, timedelta + from fileflash.core.security import ( create_access_token, + create_file_preview_token, create_refresh_token, decode_access_token, + decode_file_preview_token, get_password_hash, hash_token, verify_password, @@ -31,6 +35,24 @@ def test_access_token_round_trip(): assert "exp" in payload +def test_file_preview_token_round_trip(): + settings = Settings( + JWT_SECRET_KEY="unit-test-secret-key-1234567890abcd", + FF_DB_URI="postgresql://u:p@localhost:5432/db", + ) + token = create_file_preview_token( + user_id=42, + file_id=99, + settings=settings, + expires_at=datetime.now(UTC) + timedelta(minutes=10), + ) + payload = decode_file_preview_token(token=token, settings=settings) + assert payload["sub"] == "42" + assert payload["fileId"] == "99" + assert payload["scope"] == "file.preview" + assert payload["typ"] == "file_preview" + + def test_refresh_token_hash_is_deterministic(): settings = Settings( JWT_SECRET_KEY="unit-test-secret-key-1234567890abcd", diff --git a/app/tests/test_settings.py b/app/tests/test_settings.py index de26b55..738f3ff 100644 --- a/app/tests/test_settings.py +++ b/app/tests/test_settings.py @@ -5,22 +5,30 @@ from fileflash.core.settings import Settings +def make_settings(**overrides: object) -> Settings: + payload = {"FF_DB_URI": "postgresql://root:pwd@localhost:5432/fileflash"} + payload.update(overrides) + return Settings(_env_file=None, **payload) + + def test_async_database_url_conversion(): - settings = Settings(FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash") + settings = make_settings() assert settings.async_database_url == "postgresql+asyncpg://root:pwd@localhost:5432/fileflash" def test_upload_related_settings_defaults(): - settings = Settings(FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash") + settings = make_settings() assert settings.object_storage_bucket == "fileflash" assert settings.upload_chunk_size_default == 5 * 1024 * 1024 + assert settings.upload_single_file_size_max == 5 * 1024 * 1024 * 1024 + assert settings.upload_verify_merged_object_hash is False assert settings.upload_session_ttl_seconds == 24 * 3600 assert settings.starred_items_limit == 20 assert settings.worker_process_count == 1 def test_agent_related_settings_defaults(): - settings = Settings(FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash") + settings = make_settings() assert settings.agent_queue_stream == "fileflash:agents" assert settings.agent_job_timeout_sec == 600 assert settings.agent_tool_timeout_sec == 30 @@ -28,50 +36,46 @@ def test_agent_related_settings_defaults(): def test_app_env_detection(): - dev = Settings(FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", APP_ENV="development") + dev = make_settings(APP_ENV="development") assert dev.is_development_env is True assert dev.is_production_env is False - prod = Settings(FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", APP_ENV="prod") + prod = make_settings(APP_ENV="prod") assert prod.is_development_env is False assert prod.is_production_env is True def test_worker_process_count_from_env(): - settings = Settings( - FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", + settings = make_settings( WORKER_PROCESS_COUNT="3", ) assert settings.worker_process_count == 3 def test_starred_items_limit_from_env(): - settings = Settings( - FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", + settings = make_settings( STARRED_ITEMS_LIMIT="12", ) assert settings.starred_items_limit == 12 def test_verify_base_url_defaults_to_localhost_in_development(): - settings = Settings( - FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", + settings = make_settings( APP_ENV="development", + EMAIL_VERIFY_BASE_URL="", ) assert settings.normalized_email_verify_base_url == "http://localhost:8080" def test_verify_base_url_adds_http_when_scheme_missing(): - settings = Settings( - FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", + settings = make_settings( EMAIL_VERIFY_BASE_URL="localhost:3000", ) assert settings.normalized_email_verify_base_url == "http://localhost:3000" def test_mail_configuration_issues_includes_missing_required_fields(): - settings = Settings( - FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", + settings = make_settings( APP_ENV="development", MAIL_FROM="", MAIL_SERVER="", @@ -85,8 +89,7 @@ def test_mail_configuration_issues_includes_missing_required_fields(): def test_mail_configuration_rejects_both_tls_modes_enabled(): - settings = Settings( - FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", + settings = make_settings( EMAIL_VERIFY_BASE_URL="http://localhost:5173", MAIL_FROM="demo@example.com", MAIL_SERVER="smtp.example.com", @@ -96,12 +99,15 @@ def test_mail_configuration_rejects_both_tls_modes_enabled(): MAIL_STARTTLS=True, MAIL_SSL_TLS=True, ) - assert "MAIL_SSL_TLS and MAIL_STARTTLS cannot both be true" in settings.mail_configuration_issues + assert ( + "MAIL_SSL_TLS and MAIL_STARTTLS cannot both be true" + in settings.mail_configuration_issues + ) def test_assert_runtime_security_raises_when_jwt_secret_too_short(): - settings = Settings( - FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", + settings = make_settings( + APP_ENV="development", JWT_SECRET_KEY="short-key", ) with pytest.raises(ValueError, match="JWT_SECRET_KEY must be at least 32 bytes"): @@ -109,11 +115,46 @@ def test_assert_runtime_security_raises_when_jwt_secret_too_short(): def test_assert_runtime_security_raises_when_token_hash_secret_too_short(): - settings = Settings( - FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", + settings = make_settings( + APP_ENV="development", JWT_SECRET_KEY="x" * 32, TOKEN_HASH_SECRET="short-key", ) with pytest.raises(ValueError, match="TOKEN_HASH_SECRET must be at least 32 bytes"): settings.assert_runtime_security() + +def test_assert_runtime_security_requires_default_admin_env_in_production(): + settings = make_settings( + APP_ENV="production", + JWT_SECRET_KEY="x" * 32, + ) + + with pytest.raises(ValueError, match="DEFAULT_ADMIN_USERNAME is required in production"): + settings.assert_runtime_security() + + +def test_assert_runtime_security_rejects_short_default_admin_password_in_production(): + settings = make_settings( + APP_ENV="production", + JWT_SECRET_KEY="x" * 32, + DEFAULT_ADMIN_USERNAME="admin", + DEFAULT_ADMIN_EMAIL="admin@example.com", + DEFAULT_ADMIN_PASSWORD="short-password", + ) + + with pytest.raises(ValueError, match="DEFAULT_ADMIN_PASSWORD must be at least 32 bytes"): + settings.assert_runtime_security() + + +def test_assert_runtime_security_accepts_default_admin_env_in_production(): + settings = make_settings( + APP_ENV="production", + JWT_SECRET_KEY="x" * 32, + DEFAULT_ADMIN_USERNAME="admin", + DEFAULT_ADMIN_EMAIL="admin@example.com", + DEFAULT_ADMIN_PASSWORD="p" * 32, + ) + + settings.assert_runtime_security() + diff --git a/app/tests/test_startup_fail_fast.py b/app/tests/test_startup_fail_fast.py index a2ea383..b6c23e8 100644 --- a/app/tests/test_startup_fail_fast.py +++ b/app/tests/test_startup_fail_fast.py @@ -5,15 +5,30 @@ import pytest +from fileflash.core.settings import Settings from fileflash.main import lifespan from fileflash.s3.minio_client import ObjectStorageAuthError +def make_settings(**overrides: object) -> Settings: + payload = { + "FF_DB_URI": "postgresql://root:pwd@localhost:5432/fileflash", + "APP_ENV": "production", + "JWT_SECRET_KEY": "unit-test-secret-key-1234567890abcd", + "DEFAULT_ADMIN_USERNAME": "admin", + "DEFAULT_ADMIN_EMAIL": "admin@example.com", + "DEFAULT_ADMIN_PASSWORD": "p" * 32, + } + payload.update(overrides) + return Settings(_env_file=None, **payload) + + @pytest.mark.asyncio async def test_lifespan_fails_fast_when_database_check_fails(monkeypatch: pytest.MonkeyPatch): verify = AsyncMock(side_effect=RuntimeError("database unavailable")) verify_schema = AsyncMock() seed = AsyncMock() + monkeypatch.setattr("fileflash.main.settings", make_settings()) monkeypatch.setattr("fileflash.main.verify_database_connection", verify) monkeypatch.setattr("fileflash.main.verify_schema_compatibility", verify_schema) monkeypatch.setattr("fileflash.main.initialize_dev_accounts", seed) @@ -32,6 +47,7 @@ async def test_lifespan_fails_fast_when_schema_check_fails(monkeypatch: pytest.M verify = AsyncMock() verify_schema = AsyncMock(side_effect=RuntimeError("schema incompatible")) seed = AsyncMock() + monkeypatch.setattr("fileflash.main.settings", make_settings()) monkeypatch.setattr("fileflash.main.verify_database_connection", verify) monkeypatch.setattr("fileflash.main.verify_schema_compatibility", verify_schema) monkeypatch.setattr("fileflash.main.initialize_dev_accounts", seed) @@ -50,7 +66,10 @@ async def test_lifespan_fails_fast_when_object_storage_check_fails(monkeypatch: verify = AsyncMock() verify_schema = AsyncMock() seed = AsyncMock() - storage = SimpleNamespace(ensure_bucket=AsyncMock(side_effect=ObjectStorageAuthError("bad credentials"))) + storage = SimpleNamespace( + ensure_bucket=AsyncMock(side_effect=ObjectStorageAuthError("bad credentials")) + ) + monkeypatch.setattr("fileflash.main.settings", make_settings()) monkeypatch.setattr("fileflash.main.verify_database_connection", verify) monkeypatch.setattr("fileflash.main.verify_schema_compatibility", verify_schema) monkeypatch.setattr("fileflash.main.get_object_storage", lambda: storage) @@ -64,3 +83,54 @@ async def test_lifespan_fails_fast_when_object_storage_check_fails(monkeypatch: verify_schema.assert_awaited_once() storage.ensure_bucket.assert_awaited_once() seed.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_lifespan_fails_fast_when_prod_admin_env_missing(monkeypatch: pytest.MonkeyPatch): + verify = AsyncMock() + verify_schema = AsyncMock() + seed = AsyncMock() + monkeypatch.setattr( + "fileflash.main.settings", + Settings( + _env_file=None, + FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", + APP_ENV="production", + JWT_SECRET_KEY="unit-test-secret-key-1234567890abcd", + ), + ) + monkeypatch.setattr("fileflash.main.verify_database_connection", verify) + monkeypatch.setattr("fileflash.main.verify_schema_compatibility", verify_schema) + monkeypatch.setattr("fileflash.main.initialize_dev_accounts", seed) + + with pytest.raises(ValueError, match="DEFAULT_ADMIN_USERNAME is required in production"): + async with lifespan(object()): + pass + + verify.assert_not_awaited() + verify_schema.assert_not_awaited() + seed.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_lifespan_fails_fast_when_prod_admin_password_too_short( + monkeypatch: pytest.MonkeyPatch, +): + verify = AsyncMock() + verify_schema = AsyncMock() + seed = AsyncMock() + monkeypatch.setattr( + "fileflash.main.settings", + make_settings(DEFAULT_ADMIN_PASSWORD="short-password"), + ) + monkeypatch.setattr("fileflash.main.verify_database_connection", verify) + monkeypatch.setattr("fileflash.main.verify_schema_compatibility", verify_schema) + monkeypatch.setattr("fileflash.main.initialize_dev_accounts", seed) + + with pytest.raises(ValueError, match="DEFAULT_ADMIN_PASSWORD must be at least 32 bytes"): + async with lifespan(object()): + pass + + verify.assert_not_awaited() + verify_schema.assert_not_awaited() + seed.assert_not_awaited() diff --git a/app/tests/test_tasks_transcode.py b/app/tests/test_tasks_transcode.py index 63ad82a..3998257 100644 --- a/app/tests/test_tasks_transcode.py +++ b/app/tests/test_tasks_transcode.py @@ -5,7 +5,9 @@ from pathlib import Path from types import SimpleNamespace -from fileflash.tasks.transcode import build_ffmpeg_command, run_media_transcode +import pytest + +from fileflash.tasks.transcode import _run_command, build_ffmpeg_command, run_media_transcode def test_build_ffmpeg_command_for_video_contains_profile_flags(): @@ -24,9 +26,53 @@ def test_build_ffmpeg_command_for_video_contains_profile_flags(): assert "+faststart" in command assert "-pix_fmt" in command assert "yuv420p" in command + assert "-ac" in command + assert command[command.index("-ac") + 1] == "2" assert command[-1] == "output.mp4" +def test_build_ffmpeg_command_for_audio_forces_stereo(): + command = build_ffmpeg_command( + input_path=Path("input.wav"), + output_path=Path("output.m4a"), + media_type="audio", + ffmpeg_binary="ffmpeg", + payload={}, + ) + + assert command[:4] == ["ffmpeg", "-y", "-i", "input.wav"] + assert "-vn" in command + assert "-c:a" in command + assert "-ac" in command + assert command[command.index("-ac") + 1] == "2" + assert command[-1] == "output.m4a" + + +def test_run_command_raises_value_error_for_deterministic_ffmpeg_encoder_failure(monkeypatch): + def fake_run( + command: list[str], + *, + check: bool, + capture_output: bool, + text: bool, + timeout: int, + ) -> subprocess.CompletedProcess[str]: + _ = (check, capture_output, text, timeout) + return subprocess.CompletedProcess( + command, + returncode=1, + stdout="", + stderr='[aac @ 0000020b7eff3d00] Unsupported channel layout "6 channels"\n' + "Error while opening encoder - maybe incorrect parameters\n" + "[af#0:1 @ 0000020b003fde00] Task finished with error code: -22 (Invalid argument)", + ) + + monkeypatch.setattr("fileflash.tasks.transcode.subprocess.run", fake_run) + + with pytest.raises(ValueError, match="Command failed"): + _run_command(["ffmpeg", "-i", "input.mp4", "output.mp4"], timeout_seconds=30) + + def test_run_media_transcode_storage_mode_with_mocked_subprocess(monkeypatch, tmp_path: Path): source_probe = { "streams": [{"codec_type": "audio", "codec_name": "pcm_s16le", "sample_rate": "44100"}], diff --git a/app/tests/test_upload_service.py b/app/tests/test_upload_service.py index af27651..58d3382 100644 --- a/app/tests/test_upload_service.py +++ b/app/tests/test_upload_service.py @@ -168,6 +168,57 @@ async def test_preflight_returns_503_when_storage_is_unavailable(monkeypatch: py cleanup_mock.assert_not_awaited() +@pytest.mark.asyncio +async def test_preflight_allows_20gb_file_size_boundary(monkeypatch: pytest.MonkeyPatch): + max_file_size = 20 * 1024 * 1024 * 1024 + settings = make_settings(UPLOAD_SINGLE_FILE_SIZE_MAX=str(max_file_size)) + session = DummySession() + service, _storage = make_service(session, settings=settings) + + monkeypatch.setattr(service, "_cleanup_expired_tasks", AsyncMock()) + monkeypatch.setattr(service, "_resolve_folder_id", AsyncMock(return_value=7)) + monkeypatch.setattr(service, "_find_storage_object", AsyncMock(return_value=None)) + monkeypatch.setattr(service, "_find_active_task", AsyncMock(return_value=None)) + + response = await service.preflight( + user_id=2, + payload=UploadPreflightRequest( + fileHash="b" * 64, + fileName="video.mp4", + fileSize=max_file_size, + mimeType="video/mp4", + parentId="7", + ), + ) + + assert response.status == "UPLOADING" + assert response.upload_id is not None + + +@pytest.mark.asyncio +async def test_preflight_rejects_file_size_over_20gb(): + max_file_size = 20 * 1024 * 1024 * 1024 + settings = make_settings(UPLOAD_SINGLE_FILE_SIZE_MAX=str(max_file_size)) + session = DummySession() + service, storage = make_service(session, settings=settings) + + with pytest.raises(ApiError) as exc: + await service.preflight( + user_id=2, + payload=UploadPreflightRequest( + fileHash="b" * 64, + fileName="video.mp4", + fileSize=max_file_size + 1, + mimeType="video/mp4", + parentId="7", + ), + ) + + assert exc.value.status_code == 413 + assert exc.value.code == 413 + storage.ensure_bucket.assert_not_awaited() + + @pytest.mark.asyncio async def test_list_recoverable_sessions_returns_only_active_non_expired_tasks(): session = DummySession() @@ -458,9 +509,56 @@ async def test_merge_returns_conflict_without_strategy(monkeypatch: pytest.Monke @pytest.mark.asyncio -async def test_merge_marks_failed_on_hash_mismatch(monkeypatch: pytest.MonkeyPatch): +async def test_merge_skips_full_hash_verification_in_fast_mode(monkeypatch: pytest.MonkeyPatch): session = DummySession() service, storage = make_service(session) + task = UploadTask( + task_id=19, + user_id=2, + folder_id=2, + file_name="asset.bin", + mime_type="application/octet-stream", + bucket_name="fileflash", + object_key="objects/u2/new", + object_hash="a" * 64, + total_size=4, + chunk_size=2, + upload_id="upload-fast-mode", + status=UploadTaskStatus.UPLOADING, + expired_at=datetime.now(UTC) + timedelta(hours=1), + ) + parts = [ + UploadTaskPart(task_id=19, part_number=0, part_size=2, status=UploadPartStatus.UPLOADED), + UploadTaskPart(task_id=19, part_number=1, part_size=2, status=UploadPartStatus.UPLOADED), + ] + session.scalars_queue = [parts] + + monkeypatch.setattr(service, "_get_task_for_update", AsyncMock(return_value=task)) + monkeypatch.setattr(service, "_resolve_folder_id", AsyncMock(return_value=2)) + monkeypatch.setattr(service, "_find_conflict_file", AsyncMock(return_value=None)) + monkeypatch.setattr(service, "_find_storage_object", AsyncMock(return_value=None)) + + response = await service.merge_chunks( + user_id=2, + upload_id="upload-fast-mode", + payload=MergeChunksRequest( + fileHash="a" * 64, + fileName="asset.bin", + mimeType="application/octet-stream", + parentId="2", + conflictStrategy="rename", + ), + ) + + assert response.file_name == "asset.bin" + storage.compute_object_hash.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_merge_marks_failed_on_hash_mismatch(monkeypatch: pytest.MonkeyPatch): + settings = make_settings(UPLOAD_VERIFY_MERGED_OBJECT_HASH=True) + session = DummySession() + service, storage = make_service(session, settings=settings) task = UploadTask( task_id=20, user_id=2, diff --git a/web/src/api/file.ts b/web/src/api/file.ts index e8db9ac..38c5d32 100644 --- a/web/src/api/file.ts +++ b/web/src/api/file.ts @@ -6,6 +6,7 @@ import type { FileItem, ContentItem, FileDetails, + FilePreviewUrlResponse, GetFilesRequest, RenameFileRequest, MoveFileRequest, @@ -150,6 +151,10 @@ export const previewFile = (fileId: string) => { return http.get(`/files/${fileId}/preview`, undefined, { responseType: 'blob' }); }; +export const getPreviewUrl = (fileId: string) => { + return http.post(`/files/${fileId}/preview-url`); +}; + /** * 获取文件缩略图 * @param fileId 文件ID diff --git a/web/src/components/common/PromptDialog.spec.ts b/web/src/components/common/PromptDialog.spec.ts new file mode 100644 index 0000000..e0b4377 --- /dev/null +++ b/web/src/components/common/PromptDialog.spec.ts @@ -0,0 +1,50 @@ +import { afterEach, describe, expect, it } from 'vitest'; +import { nextTick } from 'vue'; +import { mount } from '../../test/mount'; +import PromptDialog from './PromptDialog.vue'; +import { ui, uiState } from '../../utils/ui'; + +const flush = async () => { + await Promise.resolve(); + await nextTick(); + await Promise.resolve(); +}; + +describe('PromptDialog', () => { + afterEach(() => { + uiState.prompt = null; + }); + + it('shows one footer close button for copyText flow and keeps top-right close button', async () => { + const wrapper = mount(PromptDialog); + const pending = ui.copyText({ + title: 'Generated Password', + message: 'Copy this password:', + text: 'AUTO-112233', + }); + await flush(); + + expect(wrapper.find('.modal-close').exists()).toBe(true); + const footerButtons = wrapper.findAll('.modal-footer .btn'); + expect(footerButtons).toHaveLength(1); + expect(footerButtons[0]?.text()).toBe('Close'); + + await wrapper.find('.modal-close').trigger('click'); + await pending; + expect(uiState.prompt).toBeNull(); + }); + + it('keeps cancel + confirm buttons for normal promptText flow', async () => { + const wrapper = mount(PromptDialog); + const pending = ui.promptText({ message: 'Type value' }); + await flush(); + + const footerButtons = wrapper.findAll('.modal-footer .btn'); + expect(footerButtons).toHaveLength(2); + expect(footerButtons[0]?.text()).toBe('Cancel'); + expect(footerButtons[1]?.text()).toBe('Confirm'); + + await footerButtons[1]!.trigger('click'); + await expect(pending).resolves.toBe(''); + }); +}); diff --git a/web/src/components/common/PromptDialog.vue b/web/src/components/common/PromptDialog.vue index af6a814..707f7fd 100644 --- a/web/src/components/common/PromptDialog.vue +++ b/web/src/components/common/PromptDialog.vue @@ -42,7 +42,7 @@ const handleConfirm = () => { />