diff --git a/agents.md b/agents.md index 6ea0af3..dd52b63 100644 --- a/agents.md +++ b/agents.md @@ -12,13 +12,13 @@ - 前端类型定义: `web/src/types` - 前端 mock: `web/src/mock/handlers` + `web/src/mock/state.ts` - 前端鉴权状态: `web/src/store/user.ts` -- 后端入口: `app/src/main.py` -- 后端路由: `app/src/routers` -- 后端服务层: `app/src/services` -- 后端依赖/鉴权: `app/src/core/deps.py` -- 后端 schema: `app/src/schemas` -- 后端模型: `app/src/models/tables_*.py` -- 数据库会话: `app/src/db` +- 后端入口: `app/src/fileflash/main.py` +- 后端路由: `app/src/fileflash/routers` +- 后端服务层: `app/src/fileflash/services` +- 后端依赖/鉴权: `app/src/fileflash/core/deps.py` +- 后端 schema: `app/src/fileflash/schemas` +- 后端模型: `app/src/fileflash/models/tables_*.py` +- 数据库会话: `app/src/fileflash/db` ## 3. 全局接口契约 @@ -75,7 +75,7 @@ 2. `web/src/api` 3. `web/src/mock/handlers` + `web/src/mock/state.ts` 4. 对应页面/store - 5. 后端 `app/src/schemas` + `app/src/routers/services` + 5. 后端 `app/src/fileflash/schemas` + `app/src/fileflash/routers/services` - 鉴权状态规则: - 仅持久化 `accessToken`(当前策略) - 刷新流程依赖 Cookie(`axios.withCredentials = true`) @@ -103,7 +103,7 @@ - 前端类型检查: `bun run check`(`web` 目录) - 前端构建: `bun run build`(`web` 目录) - 后端测试: `uv run pytest`(`app` 目录) -- 后端启动冒烟: `uv run python -c "from src.main import app; print(app.title)"`(`app` 目录) +- 后端启动冒烟: `uv run python -c "from fileflash.main import app; print(app.title)"`(`app` 目录) ## 10. 安全与配置要求 diff --git a/app/src/.env.example b/app/.env.example similarity index 71% rename from app/src/.env.example rename to app/.env.example index 5399e13..b466c14 100644 --- a/app/src/.env.example +++ b/app/.env.example @@ -3,7 +3,10 @@ FF_DB_URI=postgresql://root:password@localhost:5432/fileflash # DATABASE_URL=postgresql://root:password@localhost:5432/fileflash APP_ENV=development -JWT_SECRET_KEY=please-change-me +JWT_SECRET_KEY=please-set-at-least-32-bytes-secret-key +# Optional dedicated HMAC key for token hash persistence. +# Falls back to JWT_SECRET_KEY when omitted. +# TOKEN_HASH_SECRET=please-set-at-least-32-bytes-and-different-from-jwt-secret ACCESS_TOKEN_EXPIRE_MINUTES=4320 REFRESH_TOKEN_EXPIRE_DAYS=7 @@ -23,12 +26,14 @@ UPLOAD_CHUNK_SIZE_DEFAULT=5242880 UPLOAD_CHUNK_SIZE_MIN=1048576 UPLOAD_CHUNK_SIZE_MAX=16777216 UPLOAD_SINGLE_FILE_SIZE_MAX=5368709120 +STARRED_ITEMS_LIMIT=20 UPLOAD_SESSION_TTL_HOURS=24 UPLOAD_TEMP_PREFIX=tmp UPLOAD_OBJECT_PREFIX=objects WORKER_POLL_INTERVAL_SECONDS=2 WORKER_CONCURRENCY=2 +WORKER_PROCESS_COUNT=1 WORKER_TASK_TIMEOUT_SECONDS=900 WORKER_DEFAULT_MAX_ATTEMPTS=5 WORKER_RETRY_BACKOFF_SECONDS=30,120,600,1800,7200 @@ -60,8 +65,15 @@ AGENT_MCP_ENDPOINTS=[] # FFPROBE_BINARY=ffprobe # Optional SMTP settings for real email delivery. -# MAIL_FROM=no-reply@example.com -# MAIL_USERNAME= -# MAIL_PASSWORD= -# MAIL_SERVER= -# MAIL_PORT=587 +# In development, EMAIL_VERIFY_BASE_URL defaults to http://localhost:8080 when empty. +# EMAIL_VERIFY_BASE_URL=http://localhost:5173 +# For providers like 163, MAIL_FROM should match MAIL_USERNAME. +# MAIL_FROM=your-account@example.com +# MAIL_SERVER=smtp.example.com +# MAIL_PORT=465 +# MAIL_USERNAME=your-account@example.com +# MAIL_PASSWORD=replace-with-app-password +# MAIL_STARTTLS=false +# MAIL_SSL_TLS=true +# MAIL_USE_CREDENTIALS=true +# MAIL_VALIDATE_CERTS=true diff --git a/app/README.md b/app/README.md index e69de29..1ca8092 100644 --- a/app/README.md +++ b/app/README.md @@ -0,0 +1,36 @@ +## Run Backend (API + Workers) + +Use one command to start backend API and file workers together: + +```bash +uv run python -m fileflash.scripts.run_with_workers +``` + +Common options: + +```bash +# custom host/port +uv run python -m fileflash.scripts.run_with_workers --host 127.0.0.1 --port 8080 + +# start multiple worker processes +uv run python -m fileflash.scripts.run_with_workers --worker-count 2 + +# API only (without workers) +uv run python -m fileflash.scripts.run_with_workers --no-worker +``` + +Notes: +- This runner starts `uvicorn fileflash.main:app` and `python -m fileflash.workers.consumer`. +- If any subprocess exits, the runner stops all other subprocesses. +- If your environment resolves project scripts correctly, `uv run fileflash-dev` is equivalent. + +## Database Migration Requirement + +Before starting API processes, ensure Flyway migrations are fully applied (including `V10__identity_avatar.sql` and later). + +Recommended startup order: +1. Start PostgreSQL +2. Run Flyway migrate +3. Start API (`uv run fileflash`) or runner (`uv run fileflash-dev`) + +If the schema is outdated, API startup will fail fast with an explicit compatibility error. diff --git a/app/pyproject.toml b/app/pyproject.toml index bcd8655..6a8e1c7 100644 --- a/app/pyproject.toml +++ b/app/pyproject.toml @@ -21,7 +21,8 @@ dependencies = [ "python-multipart>=0.0.24", ] [project.scripts] -fileflash = "main:main" +fileflash = "fileflash.main:main" +fileflash-dev = "fileflash.scripts.run_with_workers:main" [dependency-groups] dev = [ @@ -43,3 +44,10 @@ target-version = "py312" [tool.ruff.lint] select = ["E", "F", "I", "UP", "B"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/fileflash"] \ No newline at end of file diff --git a/app/src/db/engine.py b/app/src/db/engine.py deleted file mode 100644 index 02d46f3..0000000 --- a/app/src/db/engine.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import annotations - -from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine - -from ..core.settings import get_settings - -settings = get_settings() - -engine: AsyncEngine = create_async_engine( - settings.async_database_url, - echo=False, - pool_pre_ping=True, -) - - -async def verify_database_connection() -> None: - async with engine.connect() as connection: - await connection.execute(text("SELECT 1")) diff --git a/app/src/__init__.py b/app/src/fileflash/__init__.py similarity index 100% rename from app/src/__init__.py rename to app/src/fileflash/__init__.py diff --git a/app/src/agents/__init__.py b/app/src/fileflash/agents/__init__.py similarity index 100% rename from app/src/agents/__init__.py rename to app/src/fileflash/agents/__init__.py diff --git a/app/src/agents/harness/__init__.py b/app/src/fileflash/agents/harness/__init__.py similarity index 100% rename from app/src/agents/harness/__init__.py rename to app/src/fileflash/agents/harness/__init__.py diff --git a/app/src/agents/harness/budget.py b/app/src/fileflash/agents/harness/budget.py similarity index 100% rename from app/src/agents/harness/budget.py rename to app/src/fileflash/agents/harness/budget.py diff --git a/app/src/agents/harness/checkpoint.py b/app/src/fileflash/agents/harness/checkpoint.py similarity index 100% rename from app/src/agents/harness/checkpoint.py rename to app/src/fileflash/agents/harness/checkpoint.py diff --git a/app/src/agents/harness/cost.py b/app/src/fileflash/agents/harness/cost.py similarity index 100% rename from app/src/agents/harness/cost.py rename to app/src/fileflash/agents/harness/cost.py diff --git a/app/src/agents/harness/events.py b/app/src/fileflash/agents/harness/events.py similarity index 100% rename from app/src/agents/harness/events.py rename to app/src/fileflash/agents/harness/events.py diff --git a/app/src/agents/harness/memory.py b/app/src/fileflash/agents/harness/memory.py similarity index 100% rename from app/src/agents/harness/memory.py rename to app/src/fileflash/agents/harness/memory.py diff --git a/app/src/agents/harness/policy.py b/app/src/fileflash/agents/harness/policy.py similarity index 100% rename from app/src/agents/harness/policy.py rename to app/src/fileflash/agents/harness/policy.py diff --git a/app/src/agents/harness/prompt.py b/app/src/fileflash/agents/harness/prompt.py similarity index 100% rename from app/src/agents/harness/prompt.py rename to app/src/fileflash/agents/harness/prompt.py diff --git a/app/src/agents/harness/router.py b/app/src/fileflash/agents/harness/router.py similarity index 100% rename from app/src/agents/harness/router.py rename to app/src/fileflash/agents/harness/router.py diff --git a/app/src/agents/runtime/__init__.py b/app/src/fileflash/agents/runtime/__init__.py similarity index 100% rename from app/src/agents/runtime/__init__.py rename to app/src/fileflash/agents/runtime/__init__.py diff --git a/app/src/agents/runtime/execute_runner.py b/app/src/fileflash/agents/runtime/execute_runner.py similarity index 100% rename from app/src/agents/runtime/execute_runner.py rename to app/src/fileflash/agents/runtime/execute_runner.py diff --git a/app/src/agents/runtime/plan_runner.py b/app/src/fileflash/agents/runtime/plan_runner.py similarity index 100% rename from app/src/agents/runtime/plan_runner.py rename to app/src/fileflash/agents/runtime/plan_runner.py diff --git a/app/src/agents/runtime/subagent_runner.py b/app/src/fileflash/agents/runtime/subagent_runner.py similarity index 100% rename from app/src/agents/runtime/subagent_runner.py rename to app/src/fileflash/agents/runtime/subagent_runner.py diff --git a/app/src/core/__init__.py b/app/src/fileflash/core/__init__.py similarity index 100% rename from app/src/core/__init__.py rename to app/src/fileflash/core/__init__.py diff --git a/app/src/core/deps.py b/app/src/fileflash/core/deps.py similarity index 91% rename from app/src/core/deps.py rename to app/src/fileflash/core/deps.py index c2d078f..c489641 100644 --- a/app/src/core/deps.py +++ b/app/src/fileflash/core/deps.py @@ -21,11 +21,13 @@ from ..services.agent import ExecuteService, McpService, MemoryService, PlanService, SessionService, SettingsService, SkillService from ..services.auth import AuthService from ..services.background_jobs import BackgroundJobService +from ..services.email_delivery import VerificationEmailDeliveryService from ..services.file import FileService from ..services.folder import FolderService from ..services.job_queue import RedisStreamJobQueue from ..services.messaging import InProcessAuthEventPublisher from ..services.rate_limiter import RedisRateLimiter +from ..services.registration_email_domain_rule import RegistrationEmailDomainRuleService from ..services.share import ShareService from ..services.upload import UploadService from ..s3 import MinioObjectStorageClient @@ -108,15 +110,23 @@ def get_auth_service( settings=settings, rate_limiter=rate_limiter, event_publisher=event_publisher, + verification_email_delivery=VerificationEmailDeliveryService(settings=settings), ) +def get_registration_email_domain_rule_service( + db: AsyncSession = Depends(get_db), +) -> RegistrationEmailDomainRuleService: + return RegistrationEmailDomainRuleService(db=db) + + def get_upload_service( db: AsyncSession = Depends(get_db), settings: Settings = Depends(get_settings_dep), storage: MinioObjectStorageClient = Depends(get_object_storage), + jobs: BackgroundJobService = Depends(get_background_job_service), ) -> UploadService: - return UploadService(db=db, settings=settings, storage=storage) + return UploadService(db=db, settings=settings, storage=storage, jobs=jobs) def get_share_service( @@ -138,14 +148,23 @@ def get_archive_service( def get_file_service( db: AsyncSession = Depends(get_db), storage: MinioObjectStorageClient = Depends(get_object_storage), + settings: Settings = Depends(get_settings_dep), ) -> FileService: - return FileService(db=db, storage=storage) + return FileService( + db=db, + storage=storage, + starred_items_limit=settings.starred_items_limit, + ) def get_folder_service( db: AsyncSession = Depends(get_db), + settings: Settings = Depends(get_settings_dep), ) -> FolderService: - return FolderService(db=db) + return FolderService( + db=db, + starred_items_limit=settings.starred_items_limit, + ) diff --git a/app/src/core/errors.py b/app/src/fileflash/core/errors.py similarity index 100% rename from app/src/core/errors.py rename to app/src/fileflash/core/errors.py diff --git a/app/src/core/http_headers.py b/app/src/fileflash/core/http_headers.py similarity index 100% rename from app/src/core/http_headers.py rename to app/src/fileflash/core/http_headers.py diff --git a/app/src/core/middleware.py b/app/src/fileflash/core/middleware.py similarity index 100% rename from app/src/core/middleware.py rename to app/src/fileflash/core/middleware.py diff --git a/app/src/core/mime.py b/app/src/fileflash/core/mime.py similarity index 100% rename from app/src/core/mime.py rename to app/src/fileflash/core/mime.py diff --git a/app/src/core/security.py b/app/src/fileflash/core/security.py similarity index 91% rename from app/src/core/security.py rename to app/src/fileflash/core/security.py index 01658f4..63f3dda 100644 --- a/app/src/core/security.py +++ b/app/src/fileflash/core/security.py @@ -1,6 +1,7 @@ from __future__ import annotations import hashlib +import hmac import secrets import uuid from datetime import UTC, datetime, timedelta @@ -26,8 +27,9 @@ def create_refresh_token() -> str: return secrets.token_urlsafe(48) -def hash_token(token: str) -> str: - return hashlib.sha256(token.encode("utf-8")).hexdigest() +def hash_token(token: str, settings: Settings) -> str: + secret = settings.effective_token_hash_secret.encode("utf-8") + return hmac.new(secret, token.encode("utf-8"), hashlib.sha256).hexdigest() def create_access_token(user_id: int, settings: Settings) -> str: diff --git a/app/src/core/settings.py b/app/src/fileflash/core/settings.py similarity index 68% rename from app/src/core/settings.py rename to app/src/fileflash/core/settings.py index 5827015..bc7e066 100644 --- a/app/src/core/settings.py +++ b/app/src/fileflash/core/settings.py @@ -3,6 +3,8 @@ from functools import lru_cache from os import cpu_count from pathlib import Path +from typing import ClassVar +from urllib.parse import urlsplit from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict @@ -14,8 +16,10 @@ def _default_worker_concurrency() -> int: class Settings(BaseSettings): + MIN_SECRET_LENGTH: ClassVar[int] = 32 + model_config = SettingsConfigDict( - env_file=str(Path(__file__).resolve().parents[1] / ".env"), + env_file=str(Path(__file__).resolve().parents[3] / ".env"), env_file_encoding="utf-8", extra="ignore", ) @@ -31,6 +35,7 @@ class Settings(BaseSettings): default="change-this-in-production-please-use-32-plus-bytes", alias="JWT_SECRET_KEY", ) + token_hash_secret: str | None = Field(default=None, alias="TOKEN_HASH_SECRET") jwt_algorithm: str = "HS256" access_token_expire_minutes: int = 60 * 24 * 3 refresh_token_expire_days: int = 7 @@ -45,6 +50,17 @@ class Settings(BaseSettings): redis_url: str | None = Field(default=None, alias="REDIS_URL") rabbitmq_url: str | None = Field(default=None, alias="RABBITMQ_URL") + email_verify_base_url: str = Field(default="", alias="EMAIL_VERIFY_BASE_URL") + mail_from: str | None = Field(default=None, alias="MAIL_FROM") + mail_server: str | None = Field(default=None, alias="MAIL_SERVER") + mail_port: int = Field(default=587, alias="MAIL_PORT") + mail_username: str | None = Field(default=None, alias="MAIL_USERNAME") + mail_password: str | None = Field(default=None, alias="MAIL_PASSWORD") + mail_starttls: bool = Field(default=True, alias="MAIL_STARTTLS") + mail_ssl_tls: bool = Field(default=False, alias="MAIL_SSL_TLS") + mail_use_credentials: bool = Field(default=True, alias="MAIL_USE_CREDENTIALS") + mail_validate_certs: bool = Field(default=True, alias="MAIL_VALIDATE_CERTS") + object_storage_endpoint: str = Field(default="localhost:9000", alias="OBJECT_STORAGE_ENDPOINT") object_storage_access_key: str = Field(default="admin", alias="OBJECT_STORAGE_ACCESS_KEY") object_storage_secret_key: str = Field(default="minio-admin", alias="OBJECT_STORAGE_SECRET_KEY") @@ -56,6 +72,7 @@ class Settings(BaseSettings): upload_chunk_size_min: int = Field(default=1 * 1024 * 1024, alias="UPLOAD_CHUNK_SIZE_MIN") upload_chunk_size_max: int = Field(default=16 * 1024 * 1024, alias="UPLOAD_CHUNK_SIZE_MAX") upload_single_file_size_max: int = Field(default=5 * 1024 * 1024 * 1024, alias="UPLOAD_SINGLE_FILE_SIZE_MAX") + starred_items_limit: int = Field(default=20, alias="STARRED_ITEMS_LIMIT") upload_session_ttl_hours: int = Field(default=24, alias="UPLOAD_SESSION_TTL_HOURS") upload_temp_prefix: str = Field(default="tmp", alias="UPLOAD_TEMP_PREFIX") upload_object_prefix: str = Field(default="objects", alias="UPLOAD_OBJECT_PREFIX") @@ -82,6 +99,7 @@ class Settings(BaseSettings): default_factory=_default_worker_concurrency, alias="WORKER_CONCURRENCY", ) + worker_process_count: int = Field(default=1, alias="WORKER_PROCESS_COUNT") worker_task_timeout_seconds: int = Field(default=900, alias="WORKER_TASK_TIMEOUT_SECONDS") worker_default_max_attempts: int = Field(default=5, alias="WORKER_DEFAULT_MAX_ATTEMPTS") worker_retry_backoff_seconds: str = Field( @@ -151,6 +169,28 @@ def access_token_ttl_seconds(self) -> int: def refresh_token_ttl_seconds(self) -> int: return self.refresh_token_expire_days * 24 * 60 * 60 + @property + def effective_token_hash_secret(self) -> str: + secret = (self.token_hash_secret or "").strip() + if secret: + return secret + return self.jwt_secret_key + + @property + def security_configuration_issues(self) -> tuple[str, ...]: + issues: list[str] = [] + if len(self.jwt_secret_key.encode("utf-8")) < self.MIN_SECRET_LENGTH: + issues.append(f"JWT_SECRET_KEY must be at least {self.MIN_SECRET_LENGTH} bytes") + token_hash_secret = (self.token_hash_secret or "").strip() + if token_hash_secret and len(token_hash_secret.encode("utf-8")) < self.MIN_SECRET_LENGTH: + issues.append(f"TOKEN_HASH_SECRET must be at least {self.MIN_SECRET_LENGTH} bytes") + return tuple(issues) + + def assert_runtime_security(self) -> None: + issues = self.security_configuration_issues + if issues: + raise ValueError("; ".join(issues)) + @property def worker_retry_backoff_schedule(self) -> tuple[int, ...]: values: list[int] = [] @@ -184,6 +224,45 @@ def is_development_env(self) -> bool: def is_production_env(self) -> bool: return self.normalized_app_env in {"prod", "production"} + @property + def normalized_email_verify_base_url(self) -> str: + base_url = self.email_verify_base_url.strip() + if not base_url and self.is_development_env: + base_url = "http://localhost:8080" + if base_url and "://" not in base_url: + base_url = f"http://{base_url}" + return base_url.rstrip("/") + + @property + def mail_configuration_issues(self) -> tuple[str, ...]: + issues: list[str] = [] + if self.mail_port <= 0: + issues.append("MAIL_PORT must be a positive integer") + base_url = self.normalized_email_verify_base_url + if not base_url: + issues.append("EMAIL_VERIFY_BASE_URL is required") + parsed_base_url = urlsplit(base_url) if base_url else None + if parsed_base_url and parsed_base_url.scheme not in {"http", "https"}: + issues.append("EMAIL_VERIFY_BASE_URL must start with http:// or https://") + if parsed_base_url and not parsed_base_url.netloc: + issues.append("EMAIL_VERIFY_BASE_URL must include host") + if not (self.mail_from or "").strip(): + issues.append("MAIL_FROM is required") + if not (self.mail_server or "").strip(): + issues.append("MAIL_SERVER is required") + if self.mail_ssl_tls and self.mail_starttls: + issues.append("MAIL_SSL_TLS and MAIL_STARTTLS cannot both be true") + if self.mail_use_credentials: + if not (self.mail_username or "").strip(): + issues.append("MAIL_USERNAME is required when MAIL_USE_CREDENTIALS=true") + if not (self.mail_password or "").strip(): + issues.append("MAIL_PASSWORD is required when MAIL_USE_CREDENTIALS=true") + return tuple(issues) + + @property + def is_mail_configured(self) -> bool: + return len(self.mail_configuration_issues) == 0 + @property def agent_mcp_endpoints(self) -> tuple[str, ...]: raw = self.agent_mcp_endpoints_raw.strip() diff --git a/app/src/db/__init__.py b/app/src/fileflash/db/__init__.py similarity index 100% rename from app/src/db/__init__.py rename to app/src/fileflash/db/__init__.py diff --git a/app/src/db/deps.py b/app/src/fileflash/db/deps.py similarity index 100% rename from app/src/db/deps.py rename to app/src/fileflash/db/deps.py diff --git a/app/src/fileflash/db/engine.py b/app/src/fileflash/db/engine.py new file mode 100644 index 0000000..b0e8254 --- /dev/null +++ b/app/src/fileflash/db/engine.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine + +from ..core.settings import get_settings + +settings = get_settings() + +engine: AsyncEngine = create_async_engine( + settings.async_database_url, + echo=False, + pool_pre_ping=True, +) + + +async def verify_database_connection() -> None: + async with engine.connect() as connection: + await connection.execute(text("SELECT 1")) + + +async def verify_schema_compatibility() -> None: + async with engine.connect() as connection: + if not await _public_table_has_column(connection, table_name="user", column_name="avatar"): + raise RuntimeError( + "Database schema is outdated: missing column public.user.avatar. " + "Run Flyway migrations (at least V10__identity_avatar.sql) before starting the API." + ) + if not await _public_table_exists(connection, table_name="registration_email_domain_rule"): + raise RuntimeError( + "Database schema is outdated: missing table public.registration_email_domain_rule. " + "Run Flyway migrations (at least V11__identity_registration_email_domain_rule.sql) before starting the API." + ) + + +async def _public_table_has_column(connection: AsyncConnection, *, table_name: str, column_name: str) -> bool: + result = await connection.execute( + text( + """ + SELECT 1 + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = :table_name + AND column_name = :column_name + LIMIT 1 + """ + ), + {"table_name": table_name, "column_name": column_name}, + ) + return result.scalar() == 1 + + +async def _public_table_exists(connection: AsyncConnection, *, table_name: str) -> bool: + result = await connection.execute( + text( + """ + SELECT 1 + FROM information_schema.tables + WHERE table_schema = 'public' + AND table_name = :table_name + AND table_type = 'BASE TABLE' + LIMIT 1 + """ + ), + {"table_name": table_name}, + ) + return result.scalar() == 1 diff --git a/app/src/db/session.py b/app/src/fileflash/db/session.py similarity index 100% rename from app/src/db/session.py rename to app/src/fileflash/db/session.py diff --git a/app/src/db/transaction.py b/app/src/fileflash/db/transaction.py similarity index 100% rename from app/src/db/transaction.py rename to app/src/fileflash/db/transaction.py diff --git a/app/src/exceptions/__init__.py b/app/src/fileflash/exceptions/__init__.py similarity index 100% rename from app/src/exceptions/__init__.py rename to app/src/fileflash/exceptions/__init__.py diff --git a/app/src/main.py b/app/src/fileflash/main.py similarity index 81% rename from app/src/main.py rename to app/src/fileflash/main.py index 1d27a9f..fbc194c 100644 --- a/app/src/main.py +++ b/app/src/fileflash/main.py @@ -13,7 +13,7 @@ from .core.deps import get_object_storage, get_rate_limiter from .core.errors import ApiError, api_success from .core.middleware import EmailVerificationGateMiddleware -from .db.engine import verify_database_connection +from .db.engine import verify_database_connection, verify_schema_compatibility from .routers import api_router from .s3 import ObjectStorageError from .services.dev_seed import initialize_dev_accounts @@ -24,7 +24,15 @@ @asynccontextmanager async def lifespan(_app: FastAPI): + settings.assert_runtime_security() + mail_issues = list(settings.mail_configuration_issues) + logger.info( + "Mail delivery readiness: configured=%s, issues=%s", + settings.is_mail_configured, + mail_issues, + ) await verify_database_connection() + await verify_schema_compatibility() try: await get_object_storage().ensure_bucket() except ObjectStorageError: @@ -58,7 +66,7 @@ async def health(): def main() -> None: - uvicorn.run("src.main:app", host="0.0.0.0", port=8080, reload=False) + uvicorn.run("fileflash.main:app", host="0.0.0.0", port=8080, reload=False) if __name__ == "__main__": diff --git a/app/src/models/__init__.py b/app/src/fileflash/models/__init__.py similarity index 95% rename from app/src/models/__init__.py rename to app/src/fileflash/models/__init__.py index eb58a85..575fa97 100644 --- a/app/src/models/__init__.py +++ b/app/src/fileflash/models/__init__.py @@ -21,6 +21,7 @@ Notification, ObjectScanResult, PasswordResetToken, + RegistrationEmailDomainRule, SecurityEvent, Share, ShareAccessLog, @@ -59,6 +60,7 @@ "Notification", "ObjectScanResult", "PasswordResetToken", + "RegistrationEmailDomainRule", "SecurityEvent", "Share", "ShareAccessLog", diff --git a/app/src/models/base.py b/app/src/fileflash/models/base.py similarity index 100% rename from app/src/models/base.py rename to app/src/fileflash/models/base.py diff --git a/app/src/models/enums.py b/app/src/fileflash/models/enums.py similarity index 100% rename from app/src/models/enums.py rename to app/src/fileflash/models/enums.py diff --git a/app/src/models/pg.py b/app/src/fileflash/models/pg.py similarity index 100% rename from app/src/models/pg.py rename to app/src/fileflash/models/pg.py diff --git a/app/src/models/tables.py b/app/src/fileflash/models/tables.py similarity index 95% rename from app/src/models/tables.py rename to app/src/fileflash/models/tables.py index d921030..0040f7b 100644 --- a/app/src/models/tables.py +++ b/app/src/fileflash/models/tables.py @@ -28,6 +28,7 @@ from .tables_identity import ( EmailVerificationToken, PasswordResetToken, + RegistrationEmailDomainRule, User, UserGroup, UserGroupMember, @@ -68,6 +69,7 @@ "Notification", "ObjectScanResult", "PasswordResetToken", + "RegistrationEmailDomainRule", "SecurityEvent", "Share", "ShareAccessLog", diff --git a/app/src/models/tables_access_share.py b/app/src/fileflash/models/tables_access_share.py similarity index 100% rename from app/src/models/tables_access_share.py rename to app/src/fileflash/models/tables_access_share.py diff --git a/app/src/models/tables_agent.py b/app/src/fileflash/models/tables_agent.py similarity index 100% rename from app/src/models/tables_agent.py rename to app/src/fileflash/models/tables_agent.py diff --git a/app/src/models/tables_audit_security.py b/app/src/fileflash/models/tables_audit_security.py similarity index 100% rename from app/src/models/tables_audit_security.py rename to app/src/fileflash/models/tables_audit_security.py diff --git a/app/src/models/tables_identity.py b/app/src/fileflash/models/tables_identity.py similarity index 88% rename from app/src/models/tables_identity.py rename to app/src/fileflash/models/tables_identity.py index 2b70fbe..ff24dc1 100644 --- a/app/src/models/tables_identity.py +++ b/app/src/fileflash/models/tables_identity.py @@ -215,9 +215,33 @@ class UserSession(Base): revoked_at: Mapped[datetime | None] = mapped_column(DateTime) +class RegistrationEmailDomainRule(Base): + __tablename__ = "registration_email_domain_rule" + __table_args__ = ( + Index("uk_registration_email_domain_rule_name_ci", text("(LOWER(name))"), unique=True), + Index("idx_registration_email_domain_rule_enabled", "enabled"), + ) + + rule_id: Mapped[int] = mapped_column(BigInteger, Identity(), primary_key=True) + name: Mapped[str] = mapped_column(String(120), nullable=False) + pattern: Mapped[str] = mapped_column(String(512), nullable=False) + enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default=text("TRUE")) + created_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + server_default=text("CURRENT_TIMESTAMP"), + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + server_default=text("CURRENT_TIMESTAMP"), + ) + + __all__ = [ "EmailVerificationToken", "PasswordResetToken", + "RegistrationEmailDomainRule", "User", "UserGroup", "UserGroupMember", diff --git a/app/src/models/tables_storage.py b/app/src/fileflash/models/tables_storage.py similarity index 100% rename from app/src/models/tables_storage.py rename to app/src/fileflash/models/tables_storage.py diff --git a/app/src/models/tables_worker.py b/app/src/fileflash/models/tables_worker.py similarity index 100% rename from app/src/models/tables_worker.py rename to app/src/fileflash/models/tables_worker.py diff --git a/app/src/models/types.py b/app/src/fileflash/models/types.py similarity index 100% rename from app/src/models/types.py rename to app/src/fileflash/models/types.py diff --git a/app/src/models/user.py b/app/src/fileflash/models/user.py similarity index 100% rename from app/src/models/user.py rename to app/src/fileflash/models/user.py diff --git a/app/src/repositories/__init__.py b/app/src/fileflash/repositories/__init__.py similarity index 100% rename from app/src/repositories/__init__.py rename to app/src/fileflash/repositories/__init__.py diff --git a/app/src/repositories/agent/__init__.py b/app/src/fileflash/repositories/agent/__init__.py similarity index 100% rename from app/src/repositories/agent/__init__.py rename to app/src/fileflash/repositories/agent/__init__.py diff --git a/app/src/repositories/agent/action_log.py b/app/src/fileflash/repositories/agent/action_log.py similarity index 100% rename from app/src/repositories/agent/action_log.py rename to app/src/fileflash/repositories/agent/action_log.py diff --git a/app/src/repositories/agent/contracts.py b/app/src/fileflash/repositories/agent/contracts.py similarity index 100% rename from app/src/repositories/agent/contracts.py rename to app/src/fileflash/repositories/agent/contracts.py diff --git a/app/src/repositories/agent/mcp.py b/app/src/fileflash/repositories/agent/mcp.py similarity index 100% rename from app/src/repositories/agent/mcp.py rename to app/src/fileflash/repositories/agent/mcp.py diff --git a/app/src/repositories/agent/memory.py b/app/src/fileflash/repositories/agent/memory.py similarity index 100% rename from app/src/repositories/agent/memory.py rename to app/src/fileflash/repositories/agent/memory.py diff --git a/app/src/repositories/agent/plan.py b/app/src/fileflash/repositories/agent/plan.py similarity index 100% rename from app/src/repositories/agent/plan.py rename to app/src/fileflash/repositories/agent/plan.py diff --git a/app/src/repositories/agent/settings.py b/app/src/fileflash/repositories/agent/settings.py similarity index 100% rename from app/src/repositories/agent/settings.py rename to app/src/fileflash/repositories/agent/settings.py diff --git a/app/src/repositories/agent/skill.py b/app/src/fileflash/repositories/agent/skill.py similarity index 100% rename from app/src/repositories/agent/skill.py rename to app/src/fileflash/repositories/agent/skill.py diff --git a/app/src/repositories/agent/work_session.py b/app/src/fileflash/repositories/agent/work_session.py similarity index 100% rename from app/src/repositories/agent/work_session.py rename to app/src/fileflash/repositories/agent/work_session.py diff --git a/app/src/routers/__init__.py b/app/src/fileflash/routers/__init__.py similarity index 84% rename from app/src/routers/__init__.py rename to app/src/fileflash/routers/__init__.py index b1ddf55..2f14ff0 100644 --- a/app/src/routers/__init__.py +++ b/app/src/fileflash/routers/__init__.py @@ -1,6 +1,7 @@ from fastapi import APIRouter from .auth import router as auth_router +from .admin_registration_email_domain_rules import router as admin_registration_email_domain_rules_router from .files import router as files_router from .folders import router as folders_router from .jobs import router as jobs_router @@ -13,6 +14,7 @@ api_router = APIRouter() api_router.include_router(auth_router) +api_router.include_router(admin_registration_email_domain_rules_router) api_router.include_router(files_router) api_router.include_router(folders_router) api_router.include_router(jobs_router) diff --git a/app/src/fileflash/routers/admin_registration_email_domain_rules.py b/app/src/fileflash/routers/admin_registration_email_domain_rules.py new file mode 100644 index 0000000..13e6971 --- /dev/null +++ b/app/src/fileflash/routers/admin_registration_email_domain_rules.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from datetime import UTC, datetime + +from fastapi import APIRouter, Depends + +from ..core.deps import get_registration_email_domain_rule_service, require_admin +from ..core.errors import api_success +from ..models.tables_identity import User +from ..schemas.registration_email_domain_rule import ( + CreateRegistrationEmailDomainRuleRequest, + ListRegistrationEmailDomainRulesQuery, + UpdateRegistrationEmailDomainRuleRequest, +) +from ..services.registration_email_domain_rule import RegistrationEmailDomainRuleService + +router = APIRouter(prefix="/admin/registration-email-domain-rules", tags=["admin"]) + + +@router.get("") +async def list_registration_email_domain_rules( + query: ListRegistrationEmailDomainRulesQuery = Depends(), + _: User = Depends(require_admin), + service: RegistrationEmailDomainRuleService = Depends(get_registration_email_domain_rule_service), +): + data = await service.list_rules(query=query) + return api_success(data=data.model_dump(by_alias=True), message="Rules fetched successfully") + + +@router.post("") +async def create_registration_email_domain_rule( + payload: CreateRegistrationEmailDomainRuleRequest, + _: User = Depends(require_admin), + service: RegistrationEmailDomainRuleService = Depends(get_registration_email_domain_rule_service), +): + item = await service.create_rule(payload=payload) + return api_success( + data=item.model_dump(by_alias=True), + code=201, + status_code=201, + message="Rule created successfully", + ) + + +@router.patch("/{rule_id}") +async def update_registration_email_domain_rule( + rule_id: int, + payload: UpdateRegistrationEmailDomainRuleRequest, + _: User = Depends(require_admin), + service: RegistrationEmailDomainRuleService = Depends(get_registration_email_domain_rule_service), +): + item = await service.update_rule(rule_id=rule_id, payload=payload) + return api_success(data=item.model_dump(by_alias=True), message="Rule updated successfully") + + +@router.delete("/{rule_id}") +async def delete_registration_email_domain_rule( + rule_id: int, + _: User = Depends(require_admin), + service: RegistrationEmailDomainRuleService = Depends(get_registration_email_domain_rule_service), +): + await service.delete_rule(rule_id=rule_id) + return api_success( + data={ + "ruleId": str(rule_id), + "deletedAt": datetime.now(UTC).isoformat(), + }, + message="Rule deleted successfully", + ) + + +__all__ = ["router"] + diff --git a/app/src/routers/agent.py b/app/src/fileflash/routers/agent.py similarity index 100% rename from app/src/routers/agent.py rename to app/src/fileflash/routers/agent.py diff --git a/app/src/routers/agent_skills.py b/app/src/fileflash/routers/agent_skills.py similarity index 100% rename from app/src/routers/agent_skills.py rename to app/src/fileflash/routers/agent_skills.py diff --git a/app/src/routers/auth.py b/app/src/fileflash/routers/auth.py similarity index 100% rename from app/src/routers/auth.py rename to app/src/fileflash/routers/auth.py diff --git a/app/src/routers/files.py b/app/src/fileflash/routers/files.py similarity index 100% rename from app/src/routers/files.py rename to app/src/fileflash/routers/files.py diff --git a/app/src/routers/folders.py b/app/src/fileflash/routers/folders.py similarity index 100% rename from app/src/routers/folders.py rename to app/src/fileflash/routers/folders.py diff --git a/app/src/routers/jobs.py b/app/src/fileflash/routers/jobs.py similarity index 100% rename from app/src/routers/jobs.py rename to app/src/fileflash/routers/jobs.py diff --git a/app/src/routers/me.py b/app/src/fileflash/routers/me.py similarity index 100% rename from app/src/routers/me.py rename to app/src/fileflash/routers/me.py diff --git a/app/src/routers/recycle.py b/app/src/fileflash/routers/recycle.py similarity index 100% rename from app/src/routers/recycle.py rename to app/src/fileflash/routers/recycle.py diff --git a/app/src/routers/shares.py b/app/src/fileflash/routers/shares.py similarity index 57% rename from app/src/routers/shares.py rename to app/src/fileflash/routers/shares.py index 204d118..4341854 100644 --- a/app/src/routers/shares.py +++ b/app/src/fileflash/routers/shares.py @@ -27,6 +27,64 @@ def _extract_bearer_token(authorization: str | None) -> str | None: return authorization.split(" ", 1)[1].strip() or None +def _extract_share_stream( + value: tuple[object, ...], +) -> tuple[object, str, str, int, dict[str, str] | None]: + if len(value) == 5: + stream, filename, content_type, status_code, headers = value + if isinstance(filename, str) and isinstance(content_type, str) and isinstance(status_code, int) and isinstance(headers, dict): + return stream, filename, content_type, status_code, headers + if len(value) == 4: + stream, filename, content_type, status_code = value + if isinstance(filename, str) and isinstance(content_type, str) and isinstance(status_code, int): + return stream, filename, content_type, status_code, None + if len(value) == 3: + stream, filename, content_type = value + if isinstance(filename, str) and isinstance(content_type, str): + return stream, filename, content_type, 200, None + + from ..core.errors import ApiError + + raise ApiError(status_code=500, code=500, message="Invalid shared stream response") + + +def _sanitize_stream_headers( + *, + headers: dict[str, str] | None, + filename: str, + disposition: str, +) -> dict[str, str]: + fallback_content_disposition = build_content_disposition(filename, disposition=disposition) + if headers is None: + return {"Content-Disposition": fallback_content_disposition} + + sanitized: dict[str, str] = {} + has_content_disposition = False + for key, value in headers.items(): + key_text = str(key) + value_text = str(value) + header_name = key_text.strip().lower() + + if header_name == "content-disposition": + has_content_disposition = True + try: + value_text.encode("latin-1") + sanitized[key_text] = value_text + except UnicodeEncodeError: + sanitized[key_text] = fallback_content_disposition + continue + + try: + value_text.encode("latin-1") + except UnicodeEncodeError: + continue + sanitized[key_text] = value_text + + if not has_content_disposition: + sanitized["Content-Disposition"] = fallback_content_disposition + return sanitized + + @router.post("") async def create_share( payload: CreateShareRequest, @@ -133,6 +191,7 @@ async def save_share_to_my_space( async def download_shared_file( share_link: str, authorization: str | None = Header(default=None), + range_header: str | None = Header(default=None, alias="Range"), client_ip: str = Depends(get_client_ip), user_agent: str | None = Depends(get_user_agent), share_service: ShareService = Depends(get_share_service), @@ -143,21 +202,33 @@ async def download_shared_file( raise ApiError(status_code=401, code=401, message="Missing share access token") - stream, filename, content_type = await share_service.get_shared_file_stream( - share_link=share_link, - share_access_token=token, - action="download", - ip_address=client_ip, - user_agent=user_agent, - ) - headers = {"Content-Disposition": build_content_disposition(filename, disposition="attachment")} - return StreamingResponse(stream, media_type=content_type, headers=headers) + if hasattr(share_service, "get_shared_file_download_stream_response"): + raw = await share_service.get_shared_file_download_stream_response( + share_link=share_link, + share_access_token=token, + action="download", + range_header=range_header, + ip_address=client_ip, + user_agent=user_agent, + ) + else: + raw = await share_service.get_shared_file_stream( + share_link=share_link, + share_access_token=token, + action="download", + ip_address=client_ip, + user_agent=user_agent, + ) + stream, filename, content_type, status_code, headers = _extract_share_stream(tuple(raw)) + response_headers = _sanitize_stream_headers(headers=headers, filename=filename, disposition="attachment") + return StreamingResponse(stream, media_type=content_type, headers=response_headers, status_code=status_code) @router.get("/{share_link}/preview") async def preview_shared_file( share_link: str, authorization: str | None = Header(default=None), + range_header: str | None = Header(default=None, alias="Range"), client_ip: str = Depends(get_client_ip), user_agent: str | None = Depends(get_user_agent), share_service: ShareService = Depends(get_share_service), @@ -168,13 +239,24 @@ async def preview_shared_file( raise ApiError(status_code=401, code=401, message="Missing share access token") - stream, filename, content_type = await share_service.get_shared_file_stream( - share_link=share_link, - share_access_token=token, - action="preview", - ip_address=client_ip, - user_agent=user_agent, - ) - headers = {"Content-Disposition": build_content_disposition(filename, disposition="inline")} - return StreamingResponse(stream, media_type=content_type, headers=headers) + if hasattr(share_service, "get_shared_file_download_stream_response"): + raw = await share_service.get_shared_file_download_stream_response( + share_link=share_link, + share_access_token=token, + action="preview", + range_header=range_header, + ip_address=client_ip, + user_agent=user_agent, + ) + else: + raw = await share_service.get_shared_file_stream( + share_link=share_link, + share_access_token=token, + action="preview", + ip_address=client_ip, + user_agent=user_agent, + ) + stream, filename, content_type, status_code, headers = _extract_share_stream(tuple(raw)) + response_headers = _sanitize_stream_headers(headers=headers, filename=filename, disposition="inline") + return StreamingResponse(stream, media_type=content_type, headers=response_headers, status_code=status_code) diff --git a/app/src/routers/storage.py b/app/src/fileflash/routers/storage.py similarity index 100% rename from app/src/routers/storage.py rename to app/src/fileflash/routers/storage.py diff --git a/app/src/routers/uploads.py b/app/src/fileflash/routers/uploads.py similarity index 89% rename from app/src/routers/uploads.py rename to app/src/fileflash/routers/uploads.py index 0497ef3..9996c87 100644 --- a/app/src/routers/uploads.py +++ b/app/src/fileflash/routers/uploads.py @@ -6,6 +6,7 @@ from ..core.errors import api_success from ..models.tables_identity import User from ..schemas.file import MergeChunksRequest, UploadPreflightRequest +from ..schemas.job import to_background_job_response from ..services.upload import UploadService router = APIRouter(prefix="/uploads", tags=["uploads"]) @@ -47,14 +48,14 @@ async def merge_chunks( current_user: User = Depends(get_current_user), upload_service: UploadService = Depends(get_upload_service), ): - response = await upload_service.merge_chunks( + job = await upload_service.enqueue_merge_job( user_id=current_user.user_id, upload_id=upload_id, payload=payload, ) return api_success( - data=response.model_dump(by_alias=True), - message="File uploaded successfully", + data=to_background_job_response(job).model_dump(by_alias=True), + message="Upload merge job created", code=201, status_code=201, ) diff --git a/app/src/s3/__init__.py b/app/src/fileflash/s3/__init__.py similarity index 100% rename from app/src/s3/__init__.py rename to app/src/fileflash/s3/__init__.py diff --git a/app/src/s3/minio_client.py b/app/src/fileflash/s3/minio_client.py similarity index 59% rename from app/src/s3/minio_client.py rename to app/src/fileflash/s3/minio_client.py index de4c057..536f0f1 100644 --- a/app/src/s3/minio_client.py +++ b/app/src/fileflash/s3/minio_client.py @@ -4,6 +4,7 @@ import hashlib import io import logging +import os from dataclasses import dataclass from typing import Iterable @@ -78,12 +79,14 @@ def from_settings(cls, settings: Settings) -> "MinioObjectStorageClient": region=settings.object_storage_region, ) - async def ensure_bucket(self) -> None: + async def ensure_bucket(self, *, bucket_name: str | None = None) -> None: + resolved_bucket = self._resolve_bucket_name(bucket_name) + def _run() -> None: try: - if self._client.bucket_exists(self.bucket_name): + if self._client.bucket_exists(resolved_bucket): return - self._client.make_bucket(self.bucket_name, location=self.region) + self._client.make_bucket(resolved_bucket, location=self.region) except S3Error as exc: if exc.code in {"BucketAlreadyOwnedByYou", "BucketAlreadyExists"}: return @@ -91,7 +94,7 @@ def _run() -> None: except Exception as exc: # noqa: BLE001 logger.exception( "Object storage availability check failed for bucket=%s", - self.bucket_name, + resolved_bucket, ) raise ObjectStorageUnavailableError("Object storage unavailable") from exc @@ -109,12 +112,20 @@ def _classify_s3_error(self, exc: S3Error) -> ObjectStorageError: return ObjectStorageAuthError(f"Object storage authentication failed: {code}") return ObjectStorageUnavailableError(f"Object storage unavailable: {code}") - async def put_bytes(self, *, object_key: str, data: bytes, content_type: str) -> ObjectWriteResult: - await self.ensure_bucket() + async def put_bytes( + self, + *, + object_key: str, + data: bytes, + content_type: str, + bucket_name: str | None = None, + ) -> ObjectWriteResult: + resolved_bucket = self._resolve_bucket_name(bucket_name) + await self.ensure_bucket(bucket_name=resolved_bucket) def _run() -> ObjectWriteResult: result = self._client.put_object( - self.bucket_name, + resolved_bucket, object_key, io.BytesIO(data), len(data), @@ -124,19 +135,28 @@ def _run() -> ObjectWriteResult: return await asyncio.to_thread(_run) - async def compose_object(self, *, object_key: str, source_keys: list[str]) -> ObjectWriteResult: - await self.ensure_bucket() + async def compose_object( + self, + *, + object_key: str, + source_keys: list[str], + bucket_name: str | None = None, + ) -> ObjectWriteResult: + resolved_bucket = self._resolve_bucket_name(bucket_name) + await self.ensure_bucket(bucket_name=resolved_bucket) def _run() -> ObjectWriteResult: - sources = [ComposeSource(self.bucket_name, source_key) for source_key in source_keys] - result = self._client.compose_object(self.bucket_name, object_key, sources) + sources = [ComposeSource(resolved_bucket, source_key) for source_key in source_keys] + result = self._client.compose_object(resolved_bucket, object_key, sources) return ObjectWriteResult(etag=result.etag, version_id=result.version_id) return await asyncio.to_thread(_run) - async def stat_object(self, *, object_key: str) -> ObjectStat: + async def stat_object(self, *, object_key: str, bucket_name: str | None = None) -> ObjectStat: + resolved_bucket = self._resolve_bucket_name(bucket_name) + def _run() -> ObjectStat: - stat = self._client.stat_object(self.bucket_name, object_key) + stat = self._client.stat_object(resolved_bucket, object_key) return ObjectStat( size=stat.size, etag=getattr(stat, "etag", None), @@ -146,21 +166,25 @@ def _run() -> ObjectStat: return await asyncio.to_thread(_run) - async def remove_object(self, *, object_key: str) -> None: + async def remove_object(self, *, object_key: str, bucket_name: str | None = None) -> None: + resolved_bucket = self._resolve_bucket_name(bucket_name) + def _run() -> None: - self._client.remove_object(self.bucket_name, object_key) + self._client.remove_object(resolved_bucket, object_key) await asyncio.to_thread(_run) - async def remove_objects(self, *, object_keys: Iterable[str]) -> None: + async def remove_objects(self, *, object_keys: Iterable[str], bucket_name: str | None = None) -> None: keys = [key for key in object_keys if key] if not keys: return + resolved_bucket = self._resolve_bucket_name(bucket_name) + def _run() -> None: errors = list( self._client.remove_objects( - self.bucket_name, + resolved_bucket, (DeleteObject(key) for key in keys), ) ) @@ -170,10 +194,12 @@ def _run() -> None: await asyncio.to_thread(_run) - async def compute_object_hash(self, *, object_key: str, algorithm: str) -> str: + async def compute_object_hash(self, *, object_key: str, algorithm: str, bucket_name: str | None = None) -> str: + resolved_bucket = self._resolve_bucket_name(bucket_name) + def _run() -> str: hasher = hashlib.new(algorithm) - response = self._client.get_object(self.bucket_name, object_key) + response = self._client.get_object(resolved_bucket, object_key) try: for chunk in response.stream(1024 * 1024): hasher.update(chunk) @@ -184,10 +210,17 @@ def _run() -> str: return await asyncio.to_thread(_run) - async def iter_object(self, *, object_key: str, chunk_size: int = 1024 * 1024) -> AsyncIterator[bytes]: - await self.ensure_bucket() + async def iter_object( + self, + *, + object_key: str, + chunk_size: int = 1024 * 1024, + bucket_name: str | None = None, + ) -> AsyncIterator[bytes]: + resolved_bucket = self._resolve_bucket_name(bucket_name) + await self.ensure_bucket(bucket_name=resolved_bucket) - response = await asyncio.to_thread(self._client.get_object, self.bucket_name, object_key) + response = await asyncio.to_thread(self._client.get_object, resolved_bucket, object_key) try: while True: chunk = await asyncio.to_thread(response.read, chunk_size) @@ -205,8 +238,10 @@ async def iter_object_range( start: int, end: int, chunk_size: int = 1024 * 1024, + bucket_name: str | None = None, ) -> AsyncIterator[bytes]: - await self.ensure_bucket() + resolved_bucket = self._resolve_bucket_name(bucket_name) + await self.ensure_bucket(bucket_name=resolved_bucket) if start < 0 or end < start: raise ValueError("Invalid byte range") @@ -214,7 +249,7 @@ async def iter_object_range( length = end - start + 1 response = await asyncio.to_thread( self._client.get_object, - self.bucket_name, + resolved_bucket, object_key, start, length, @@ -231,3 +266,59 @@ async def iter_object_range( finally: await asyncio.to_thread(response.close) await asyncio.to_thread(response.release_conn) + + async def fget_object( + self, + *, + object_key: str, + file_path: str, + bucket_name: str | None = None, + ) -> ObjectWriteResult: + resolved_bucket = self._resolve_bucket_name(bucket_name) + + def _run() -> ObjectWriteResult: + result = self._client.fget_object(resolved_bucket, object_key, file_path) + return ObjectWriteResult(etag=getattr(result, "etag", None), version_id=getattr(result, "version_id", None)) + + return await asyncio.to_thread(_run) + + async def fput_object( + self, + *, + object_key: str, + file_path: str, + content_type: str, + bucket_name: str | None = None, + ) -> ObjectWriteResult: + resolved_bucket = self._resolve_bucket_name(bucket_name) + await self.ensure_bucket(bucket_name=resolved_bucket) + + def _run() -> ObjectWriteResult: + result = self._client.fput_object( + resolved_bucket, + object_key, + file_path, + content_type=content_type, + ) + return ObjectWriteResult(etag=getattr(result, "etag", None), version_id=getattr(result, "version_id", None)) + + return await asyncio.to_thread(_run) + + async def object_exists(self, *, object_key: str, bucket_name: str | None = None) -> bool: + try: + await self.stat_object(object_key=object_key, bucket_name=bucket_name) + except S3Error as exc: + if exc.code in {"NoSuchKey", "NoSuchObject", "NoSuchBucket"}: + return False + raise + except Exception: + return False + return True + + @staticmethod + def file_size(file_path: str) -> int: + return int(os.path.getsize(file_path)) + + def _resolve_bucket_name(self, bucket_name: str | None) -> str: + value = (bucket_name or "").strip() + return value or self.bucket_name diff --git a/app/src/s3/s3.py b/app/src/fileflash/s3/s3.py similarity index 100% rename from app/src/s3/s3.py rename to app/src/fileflash/s3/s3.py diff --git a/app/src/schemas/__init__.py b/app/src/fileflash/schemas/__init__.py similarity index 93% rename from app/src/schemas/__init__.py rename to app/src/fileflash/schemas/__init__.py index 080cc8d..39e1ac4 100644 --- a/app/src/schemas/__init__.py +++ b/app/src/fileflash/schemas/__init__.py @@ -73,6 +73,12 @@ PermissionItem, UpdatePermissionRequest, ) +from .registration_email_domain_rule import ( + CreateRegistrationEmailDomainRuleRequest, + ListRegistrationEmailDomainRulesQuery, + RegistrationEmailDomainRuleItem, + UpdateRegistrationEmailDomainRuleRequest, +) from .recycle import ( ClearRecycleBinResponse, GetRecycleBinQuery, @@ -209,8 +215,10 @@ "PaginationMeta", "PermissionItem", "PermanentDeleteResponse", + "CreateRegistrationEmailDomainRuleRequest", "RateLimitRule", "RateLimitStatus", + "RegistrationEmailDomainRuleItem", "RecycleBinItem", "RegisterRequest", "RemoveGroupMemberResponse", @@ -235,11 +243,13 @@ "StorageUserItem", "StorageUsersList", "SystemHealth", + "ListRegistrationEmailDomainRulesQuery", "ToggleFileStarRequest", "ToggleFolderStarRequest", "TokenResponse", "UpdatePermissionRequest", "UpdateProfileRequest", + "UpdateRegistrationEmailDomainRuleRequest", "UpdateShareSettingsRequest", "UpdateStorageQuotaRequest", "UpdateStorageQuotaResponse", diff --git a/app/src/schemas/agent_skill.py b/app/src/fileflash/schemas/agent_skill.py similarity index 100% rename from app/src/schemas/agent_skill.py rename to app/src/fileflash/schemas/agent_skill.py diff --git a/app/src/schemas/archive.py b/app/src/fileflash/schemas/archive.py similarity index 100% rename from app/src/schemas/archive.py rename to app/src/fileflash/schemas/archive.py diff --git a/app/src/schemas/auth.py b/app/src/fileflash/schemas/auth.py similarity index 100% rename from app/src/schemas/auth.py rename to app/src/fileflash/schemas/auth.py diff --git a/app/src/schemas/common.py b/app/src/fileflash/schemas/common.py similarity index 100% rename from app/src/schemas/common.py rename to app/src/fileflash/schemas/common.py diff --git a/app/src/schemas/file.py b/app/src/fileflash/schemas/file.py similarity index 95% rename from app/src/schemas/file.py rename to app/src/fileflash/schemas/file.py index 8ce5a88..d9dd1ab 100644 --- a/app/src/schemas/file.py +++ b/app/src/fileflash/schemas/file.py @@ -24,6 +24,7 @@ class FileItem(CamelModel): folder_id: str permission: Literal["read", "write", "owner"] | None = None is_starred: bool | None = None + media_optimization: MediaOptimization | None = None class FolderItem(CamelModel): @@ -150,6 +151,13 @@ class FileDetails(FileItem): status: bool +class MediaOptimization(CamelModel): + status: Literal["queued", "running", "ready", "failed"] + media_type: Literal["audio", "video"] + optimized_mime_type: str | None = None + updated_at: datetime + + class RenameFileRequest(CamelModel): file_name: str = Field(min_length=1, max_length=255) diff --git a/app/src/schemas/job.py b/app/src/fileflash/schemas/job.py similarity index 100% rename from app/src/schemas/job.py rename to app/src/fileflash/schemas/job.py diff --git a/app/src/schemas/log.py b/app/src/fileflash/schemas/log.py similarity index 100% rename from app/src/schemas/log.py rename to app/src/fileflash/schemas/log.py diff --git a/app/src/schemas/notification.py b/app/src/fileflash/schemas/notification.py similarity index 100% rename from app/src/schemas/notification.py rename to app/src/fileflash/schemas/notification.py diff --git a/app/src/schemas/permission.py b/app/src/fileflash/schemas/permission.py similarity index 100% rename from app/src/schemas/permission.py rename to app/src/fileflash/schemas/permission.py diff --git a/app/src/schemas/recycle.py b/app/src/fileflash/schemas/recycle.py similarity index 100% rename from app/src/schemas/recycle.py rename to app/src/fileflash/schemas/recycle.py diff --git a/app/src/fileflash/schemas/registration_email_domain_rule.py b/app/src/fileflash/schemas/registration_email_domain_rule.py new file mode 100644 index 0000000..fcff322 --- /dev/null +++ b/app/src/fileflash/schemas/registration_email_domain_rule.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from datetime import datetime + +from pydantic import Field + +from .common import CamelModel, PageQuery + + +class RegistrationEmailDomainRuleItem(CamelModel): + rule_id: str + name: str + pattern: str + enabled: bool + created_at: datetime + updated_at: datetime + + +class ListRegistrationEmailDomainRulesQuery(PageQuery): + query_text: str | None = None + enabled: bool | None = None + + +class CreateRegistrationEmailDomainRuleRequest(CamelModel): + name: str = Field(min_length=1, max_length=120) + pattern: str = Field(min_length=1, max_length=512) + enabled: bool = True + + +class UpdateRegistrationEmailDomainRuleRequest(CamelModel): + name: str | None = Field(default=None, min_length=1, max_length=120) + pattern: str | None = Field(default=None, min_length=1, max_length=512) + enabled: bool | None = None + diff --git a/app/src/schemas/share.py b/app/src/fileflash/schemas/share.py similarity index 100% rename from app/src/schemas/share.py rename to app/src/fileflash/schemas/share.py diff --git a/app/src/schemas/storage.py b/app/src/fileflash/schemas/storage.py similarity index 100% rename from app/src/schemas/storage.py rename to app/src/fileflash/schemas/storage.py diff --git a/app/src/schemas/system.py b/app/src/fileflash/schemas/system.py similarity index 100% rename from app/src/schemas/system.py rename to app/src/fileflash/schemas/system.py diff --git a/app/src/schemas/user.py b/app/src/fileflash/schemas/user.py similarity index 100% rename from app/src/schemas/user.py rename to app/src/fileflash/schemas/user.py diff --git a/app/src/schemas/user_group.py b/app/src/fileflash/schemas/user_group.py similarity index 100% rename from app/src/schemas/user_group.py rename to app/src/fileflash/schemas/user_group.py diff --git a/app/src/scripts/__init__.py b/app/src/fileflash/scripts/__init__.py similarity index 100% rename from app/src/scripts/__init__.py rename to app/src/fileflash/scripts/__init__.py diff --git a/app/src/scripts/init_dev_accounts.py b/app/src/fileflash/scripts/init_dev_accounts.py similarity index 100% rename from app/src/scripts/init_dev_accounts.py rename to app/src/fileflash/scripts/init_dev_accounts.py diff --git a/app/src/fileflash/scripts/run_with_workers.py b/app/src/fileflash/scripts/run_with_workers.py new file mode 100644 index 0000000..7f2db6d --- /dev/null +++ b/app/src/fileflash/scripts/run_with_workers.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +import argparse +import signal +import subprocess +import sys +import time +from collections.abc import Mapping +from dataclasses import dataclass +from pathlib import Path + +from redis import Redis + +from ..core.settings import get_settings + + +@dataclass(slots=True) +class ManagedProcess: + name: str + process: subprocess.Popen[bytes] + command: list[str] + + +def _build_parser() -> argparse.ArgumentParser: + settings = get_settings() + default_worker_count = max(1, settings.worker_process_count) + parser = argparse.ArgumentParser( + description="Run FileFlash backend API with worker processes.", + ) + parser.add_argument("--host", default="0.0.0.0", help="API host (default: 0.0.0.0)") + parser.add_argument("--port", type=int, default=8080, help="API port (default: 8080)") + parser.add_argument( + "--reload", + action="store_true", + help="Enable uvicorn auto-reload for API process.", + ) + parser.add_argument( + "--worker-count", + type=int, + default=default_worker_count, + help=( + "Number of file worker consumer processes " + f"(default from WORKER_PROCESS_COUNT: {default_worker_count})." + ), + ) + parser.add_argument( + "--no-worker", + action="store_true", + help="Start API only (without file workers).", + ) + return parser + + +def _spawn_process(name: str, command: list[str], cwd: Path) -> ManagedProcess: + popen_kwargs: dict[str, object] = {"cwd": str(cwd)} + if sys.platform == "win32": + # Needed so CTRL_BREAK_EVENT can be delivered to child process group. + popen_kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP + proc = subprocess.Popen( + command, + **popen_kwargs, + ) + return ManagedProcess(name=name, process=proc, command=command) + + +def _format_cmd(command: list[str]) -> str: + return " ".join(command) + + +def _stop_process(managed: ManagedProcess, *, timeout_sec: float = 8.0) -> None: + proc = managed.process + if proc.poll() is not None: + return + + try: + if sys.platform == "win32": + proc.send_signal(signal.CTRL_BREAK_EVENT) + else: + proc.terminate() + proc.wait(timeout=timeout_sec) + return + except Exception: + pass + + try: + proc.kill() + proc.wait(timeout=timeout_sec) + except Exception: + pass + + +def _validate_redis_for_workers(env: Mapping[str, str] | None = None) -> tuple[bool, str]: + redis_url = (env or {}).get("REDIS_URL", "").strip() + if not redis_url: + settings = get_settings() + redis_url = (settings.redis_url or "").strip() + if not redis_url.strip(): + return ( + False, + "[run-with-workers] worker startup preflight failed: REDIS_URL is not set.", + ) + + client: Redis | None = None + try: + client = Redis.from_url(redis_url, socket_connect_timeout=2.0, socket_timeout=2.0) + client.ping() + except Exception as exc: + return ( + False, + ( + "[run-with-workers] worker startup preflight failed: " + f"cannot connect to Redis at {redis_url}. error={type(exc).__name__}: {exc}" + ), + ) + finally: + try: + client.close() + except Exception: + pass + + return True, "" + + +def main() -> int: + parser = _build_parser() + args = parser.parse_args() + + if args.worker_count < 1: + parser.error("--worker-count must be >= 1") + + cwd = Path(__file__).resolve().parents[2] + python = sys.executable + + processes: list[ManagedProcess] = [] + try: + if not args.no_worker: + ok, error_message = _validate_redis_for_workers() + if not ok: + print(error_message, file=sys.stderr) + return 2 + + api_cmd = [ + python, + "-m", + "uvicorn", + "fileflash.main:app", + "--host", + str(args.host), + "--port", + str(args.port), + ] + if args.reload: + api_cmd.append("--reload") + + api_proc = _spawn_process("api", api_cmd, cwd) + processes.append(api_proc) + print(f"[run-with-workers] started {api_proc.name}: {_format_cmd(api_cmd)}") + + if not args.no_worker: + for index in range(args.worker_count): + worker_name = f"worker-{index + 1}" + worker_cmd = [python, "-m", "fileflash.workers.consumer"] + worker_proc = _spawn_process(worker_name, worker_cmd, cwd) + processes.append(worker_proc) + print(f"[run-with-workers] started {worker_name}: {_format_cmd(worker_cmd)}") + + while True: + for managed in processes: + exit_code = managed.process.poll() + if exit_code is not None: + print( + f"[run-with-workers] process exited: {managed.name} code={exit_code}", + file=sys.stderr, + ) + return int(exit_code) + time.sleep(0.5) + except KeyboardInterrupt: + print("[run-with-workers] shutdown requested, stopping all processes...") + return 0 + finally: + for managed in reversed(processes): + _stop_process(managed) + code = managed.process.poll() + print(f"[run-with-workers] stopped {managed.name} code={code}") + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/app/src/services/__init__.py b/app/src/fileflash/services/__init__.py similarity index 81% rename from app/src/services/__init__.py rename to app/src/fileflash/services/__init__.py index 00498ef..e523bba 100644 --- a/app/src/services/__init__.py +++ b/app/src/fileflash/services/__init__.py @@ -2,11 +2,13 @@ from .agent import ExecuteService, McpService, MemoryService, PlanService, SessionService, SettingsService, SkillService from .auth import AuthService from .background_jobs import BackgroundJobService +from .email_delivery import VerificationEmailDeliveryService from .file import FileService from .folder import FolderService from .job_queue import JobQueuePublisher, RedisStreamJobQueue from .messaging import AuthEventPublisher, InProcessAuthEventPublisher from .rate_limiter import RedisRateLimiter +from .registration_email_domain_rule import RegistrationEmailDomainRuleService from .share import ShareService from .upload import UploadService @@ -15,6 +17,7 @@ "AuthService", "ArchiveService", "BackgroundJobService", + "VerificationEmailDeliveryService", "ExecuteService", "FileService", "FolderService", @@ -28,6 +31,7 @@ "InProcessAuthEventPublisher", "RedisStreamJobQueue", "RedisRateLimiter", + "RegistrationEmailDomainRuleService", "ShareService", "UploadService", ] diff --git a/app/src/services/agent/__init__.py b/app/src/fileflash/services/agent/__init__.py similarity index 100% rename from app/src/services/agent/__init__.py rename to app/src/fileflash/services/agent/__init__.py diff --git a/app/src/services/agent/execute_service.py b/app/src/fileflash/services/agent/execute_service.py similarity index 100% rename from app/src/services/agent/execute_service.py rename to app/src/fileflash/services/agent/execute_service.py diff --git a/app/src/services/agent/mcp_service.py b/app/src/fileflash/services/agent/mcp_service.py similarity index 100% rename from app/src/services/agent/mcp_service.py rename to app/src/fileflash/services/agent/mcp_service.py diff --git a/app/src/services/agent/memory_service.py b/app/src/fileflash/services/agent/memory_service.py similarity index 100% rename from app/src/services/agent/memory_service.py rename to app/src/fileflash/services/agent/memory_service.py diff --git a/app/src/services/agent/plan_service.py b/app/src/fileflash/services/agent/plan_service.py similarity index 100% rename from app/src/services/agent/plan_service.py rename to app/src/fileflash/services/agent/plan_service.py diff --git a/app/src/services/agent/session_service.py b/app/src/fileflash/services/agent/session_service.py similarity index 100% rename from app/src/services/agent/session_service.py rename to app/src/fileflash/services/agent/session_service.py diff --git a/app/src/services/agent/settings_service.py b/app/src/fileflash/services/agent/settings_service.py similarity index 100% rename from app/src/services/agent/settings_service.py rename to app/src/fileflash/services/agent/settings_service.py diff --git a/app/src/services/agent/skill_service.py b/app/src/fileflash/services/agent/skill_service.py similarity index 100% rename from app/src/services/agent/skill_service.py rename to app/src/fileflash/services/agent/skill_service.py diff --git a/app/src/services/archive.py b/app/src/fileflash/services/archive.py similarity index 100% rename from app/src/services/archive.py rename to app/src/fileflash/services/archive.py diff --git a/app/src/services/auth.py b/app/src/fileflash/services/auth.py similarity index 88% rename from app/src/services/auth.py rename to app/src/fileflash/services/auth.py index 98197fc..cbd01a6 100644 --- a/app/src/services/auth.py +++ b/app/src/fileflash/services/auth.py @@ -19,38 +19,52 @@ ) from ..models.enums import FileStatus, FolderStatus, FolderType, UiLanguage, UserRole, UserStatus from ..models.tables_audit_security import Log -from ..models.tables_identity import EmailVerificationToken, PasswordResetToken, User, UserPreference, UserSession +from ..models.tables_identity import ( + EmailVerificationToken, + PasswordResetToken, + User, + UserPreference, + UserSession, +) from ..models.tables_storage import File, Folder from ..schemas.auth import ForgotPasswordResponse, RegisterRequest, RegisterResponseData, TokenResponse from ..schemas.common import PaginatedData, PaginationMeta from ..schemas.user import ( ActivityItem, BreakdownDetail, + ChangePasswordRequest, + GetActivityLogQuery, + StorageStats, UpdateAvatarRequest, UpdateProfileRequest, UpdateUserPreferenceRequest, + User as UserSchema, + UserPreference as UserPreferenceSchema, UserProfile, ) -from ..schemas.user import User as UserSchema -from ..schemas.user import ChangePasswordRequest, GetActivityLogQuery, StorageStats -from ..schemas.user import UserPreference as UserPreferenceSchema -from ..schemas.user import UpdateProfileRequest, UpdateUserPreferenceRequest, UserProfile +from .email_delivery import EmailDeliveryConfigurationError, EmailDeliveryError, VerificationEmailDeliveryService from .messaging import AuthEventPublisher from .rate_limiter import RedisRateLimiter +from .registration_email_domain_rule import RegistrationEmailDomainRuleService class AuthService: + MIN_VERIFICATION_TOKEN_LENGTH = 16 + MIN_RESET_TOKEN_LENGTH = 16 + def __init__( self, db: AsyncSession, settings: Settings, rate_limiter: RedisRateLimiter, event_publisher: AuthEventPublisher, + verification_email_delivery: VerificationEmailDeliveryService, ) -> None: self.db = db self.settings = settings self.rate_limiter = rate_limiter self.event_publisher = event_publisher + self.verification_email_delivery = verification_email_delivery async def register( self, @@ -65,6 +79,7 @@ async def register( window_seconds=self.settings.register_rate_window_seconds, message="Too many registration attempts, please try again later", ) + await RegistrationEmailDomainRuleService(self.db).assert_email_allowed(email=payload.email) existing_user = await self.db.scalar( select(User).where( @@ -104,12 +119,16 @@ async def register( "auth.email_verification_requested", { "userId": str(user.user_id), - "email": user.email, - "token": verification_token, "expiresInMinutes": self.settings.email_verification_expire_minutes, "userAgent": user_agent or "", }, ) + await self._send_verification_email_or_raise( + event_name="auth.email_verification_requested", + email=user.email, + token=verification_token, + expires_in_minutes=self.settings.email_verification_expire_minutes, + ) user_schema = self._to_user_schema(user=user, preference=preference) return RegisterResponseData(user=user_schema) @@ -173,7 +192,7 @@ async def _operation() -> tuple[TokenResponse, str]: self.db.add( UserSession( user_id=user.user_id, - refresh_token_hash=hash_token(refresh_token), + refresh_token_hash=hash_token(refresh_token, self.settings), client_type="web", ip_address=client_ip, user_agent=user_agent, @@ -217,7 +236,7 @@ async def refresh( ) -> tuple[TokenResponse, str]: async def _operation() -> tuple[TokenResponse, str]: now = datetime.now(UTC) - token_hash = hash_token(refresh_token) + token_hash = hash_token(refresh_token, self.settings) await apply_local_lock_timeout(self.db) session = await self.db.scalar( select(UserSession) @@ -243,7 +262,7 @@ async def _operation() -> tuple[TokenResponse, str]: next_refresh_token = create_refresh_token() next_session = UserSession( user_id=user.user_id, - refresh_token_hash=hash_token(next_refresh_token), + refresh_token_hash=hash_token(next_refresh_token, self.settings), client_type=session.client_type, device_id=session.device_id, device_name=session.device_name, @@ -277,7 +296,7 @@ async def logout(self, *, refresh_token: str | None) -> None: return now = datetime.now(UTC) - token_hash = hash_token(refresh_token) + token_hash = hash_token(refresh_token, self.settings) session = await self.db.scalar( select(UserSession).where( and_( @@ -309,7 +328,7 @@ async def forgot_password( now = datetime.now(UTC) user = await self.db.scalar(select(User).where(func.lower(User.email) == email.lower())) if user: - reset_token = await self._create_password_reset_token( + await self._create_password_reset_token( user_id=user.user_id, now=now, client_ip=client_ip, @@ -321,8 +340,6 @@ async def forgot_password( { "requestId": request_id, "userId": str(user.user_id), - "email": user.email, - "token": reset_token, "expiresInMinutes": self.settings.password_reset_expire_minutes, }, ) @@ -335,7 +352,8 @@ async def forgot_password( async def reset_password(self, *, token: str, new_password: str) -> None: async def _operation() -> None: now = datetime.now(UTC) - token_hash_value = hash_token(token) + self._assert_token_length(token=token, minimum=self.MIN_RESET_TOKEN_LENGTH, message="Invalid or expired reset token") + token_hash_value = hash_token(token, self.settings) await apply_local_lock_timeout(self.db) reset_record = await self.db.scalar( select(PasswordResetToken) @@ -388,7 +406,12 @@ async def _operation() -> None: async def verify_email(self, *, token: str) -> None: async def _operation() -> None: now = datetime.now(UTC) - token_hash_value = hash_token(token) + self._assert_token_length( + token=token, + minimum=self.MIN_VERIFICATION_TOKEN_LENGTH, + message="Invalid or expired verification token", + ) + token_hash_value = hash_token(token, self.settings) await apply_local_lock_timeout(self.db) verification_record = await self.db.scalar( select(EmailVerificationToken) @@ -451,12 +474,16 @@ async def resend_verification( "auth.email_verification_resent", { "userId": str(user.user_id), - "email": user.email, - "token": token, "expiresInMinutes": self.settings.email_verification_expire_minutes, "userAgent": user_agent or "", }, ) + await self._send_verification_email_or_raise( + event_name="auth.email_verification_resent", + email=user.email, + token=token, + expires_in_minutes=self.settings.email_verification_expire_minutes, + ) async def get_profile(self, *, user_id: int) -> UserProfile: user = await self.db.get(User, user_id) @@ -533,6 +560,7 @@ async def _operation() -> tuple[UserProfile, str | None]: if not email: raise ApiError(status_code=400, code=400, message="email cannot be empty") if email.lower() != user.email.lower(): + await RegistrationEmailDomainRuleService(self.db).assert_email_allowed(email=email) email_exists = await self.db.scalar( select(User.user_id).where( and_( @@ -578,12 +606,16 @@ async def _operation() -> tuple[UserProfile, str | None]: "auth.email_verification_requested", { "userId": str(profile.user_id), - "email": profile.email, - "token": verification_token, "expiresInMinutes": self.settings.email_verification_expire_minutes, "userAgent": user_agent or "", }, ) + await self._send_verification_email_or_raise( + event_name="auth.email_verification_requested", + email=profile.email, + token=verification_token, + expires_in_minutes=self.settings.email_verification_expire_minutes, + ) return profile @@ -620,7 +652,7 @@ async def change_password( if payload.old_password == payload.new_password: raise ApiError(status_code=400, code=400, message="newPassword must be different from oldPassword") - token_hash = hash_token(current_refresh_token) if current_refresh_token else None + token_hash = hash_token(current_refresh_token, self.settings) if current_refresh_token else None async def _operation() -> None: await apply_local_lock_timeout(self.db) @@ -801,10 +833,11 @@ async def _ensure_rate_limit( async def _create_email_verification_token(self, *, user_id: int, now: datetime) -> str: token = secrets.token_urlsafe(32) + await self._invalidate_active_verification_tokens(user_id=user_id, now=now) self.db.add( EmailVerificationToken( user_id=user_id, - token_hash=hash_token(token), + token_hash=hash_token(token, self.settings), expire_at=now + timedelta(minutes=self.settings.email_verification_expire_minutes), ) ) @@ -822,7 +855,7 @@ async def _create_password_reset_token( self.db.add( PasswordResetToken( user_id=user_id, - token_hash=hash_token(token), + token_hash=hash_token(token, self.settings), expire_at=now + timedelta(minutes=self.settings.password_reset_expire_minutes), requester_ip=client_ip, user_agent=user_agent, @@ -830,6 +863,56 @@ async def _create_password_reset_token( ) return token + async def _invalidate_active_verification_tokens(self, *, user_id: int, now: datetime) -> None: + rows = await self.db.scalars( + select(EmailVerificationToken) + .where( + and_( + EmailVerificationToken.user_id == user_id, + EmailVerificationToken.verified_at.is_(None), + EmailVerificationToken.expire_at > now, + ) + ) + .with_for_update() + ) + for row in rows: + row.verified_at = now + + async def _send_verification_email_or_raise( + self, + *, + event_name: str, + email: str, + token: str, + expires_in_minutes: int, + ) -> None: + try: + await self.verification_email_delivery.send_verification_email( + email=email, + token=token, + expires_in_minutes=expires_in_minutes, + ) + except (EmailDeliveryConfigurationError, EmailDeliveryError, ValueError) as exc: + await self.event_publisher.publish( + "auth.email_verification_delivery_failed", + { + "eventName": event_name, + }, + ) + message = "Verification email service is unavailable" + if isinstance(exc, EmailDeliveryConfigurationError): + message = str(exc) + raise ApiError( + status_code=503, + code=503, + message=message, + ) from exc + + @staticmethod + def _assert_token_length(*, token: str, minimum: int, message: str) -> None: + if len(token.strip()) < minimum: + raise ApiError(status_code=400, code=400, message=message) + async def _get_user_preference(self, user_id: int) -> UserPreference | None: statement: Select[tuple[UserPreference]] = select(UserPreference).where(UserPreference.user_id == user_id) return await self.db.scalar(statement) diff --git a/app/src/services/background_jobs.py b/app/src/fileflash/services/background_jobs.py similarity index 86% rename from app/src/services/background_jobs.py rename to app/src/fileflash/services/background_jobs.py index f92759f..4ad5147 100644 --- a/app/src/services/background_jobs.py +++ b/app/src/fileflash/services/background_jobs.py @@ -112,17 +112,26 @@ async def enqueue_transcode_job( self, db: AsyncSession, *, - input_path: str, - output_path: str | None = None, - object_id: int | None = None, + source_bucket_name: str, + source_object_key: str, + source_object_id: int, + output_bucket_name: str, + output_object_key: str, + file_id: int | None = None, requested_by: int | None = None, idempotency_key: str | None = None, ) -> BackgroundJob: - payload: dict[str, Any] = {"inputPath": input_path} - if output_path is not None: - payload["outputPath"] = output_path - if object_id is not None: - payload["objectId"] = object_id + payload: dict[str, Any] = { + "sourceBucketName": source_bucket_name, + "sourceObjectKey": source_object_key, + "sourceObjectId": source_object_id, + "outputBucketName": output_bucket_name, + "outputObjectKey": output_object_key, + } + if file_id is not None: + payload["fileId"] = file_id + if requested_by is not None: + payload["requestedBy"] = requested_by return await self.enqueue( db, task_type="task.transcode", @@ -134,7 +143,8 @@ async def enqueue_transcode_job( def _build_queue_message(job: BackgroundJob) -> WorkerJobMessage: payload = dict(job.payload or {}) - payload.setdefault("jobId", job.job_id) + if payload.get("jobId") in (None, ""): + payload["jobId"] = job.job_id if job.requested_by is not None: payload.setdefault("requestedBy", job.requested_by) return WorkerJobMessage( diff --git a/app/src/services/dev_seed.py b/app/src/fileflash/services/dev_seed.py similarity index 100% rename from app/src/services/dev_seed.py rename to app/src/fileflash/services/dev_seed.py diff --git a/app/src/fileflash/services/email_delivery.py b/app/src/fileflash/services/email_delivery.py new file mode 100644 index 0000000..42de5f0 --- /dev/null +++ b/app/src/fileflash/services/email_delivery.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from dataclasses import dataclass +from urllib.parse import quote + +from fastapi_mail import ConnectionConfig, FastMail, MessageSchema, MessageType, MultipartSubtypeEnum + +from ..core.settings import Settings + + +class EmailDeliveryConfigurationError(RuntimeError): + pass + + +class EmailDeliveryError(RuntimeError): + pass + + +@dataclass(slots=True) +class VerificationEmailPayload: + email: str + token: str + expires_in_minutes: int + + +class VerificationEmailDeliveryService: + def __init__(self, settings: Settings) -> None: + self.settings = settings + + async def send_verification_email( + self, + *, + email: str, + token: str, + expires_in_minutes: int, + ) -> None: + payload = VerificationEmailPayload( + email=email.strip(), + token=token.strip(), + expires_in_minutes=expires_in_minutes, + ) + self._validate_payload(payload) + config = self._build_connection_config() + verification_link = self._build_verification_link(payload.token) + message = MessageSchema( + subject="Verify your FileFlash email", + recipients=[payload.email], + body=self._build_html_body( + verification_link=verification_link, + expires_in_minutes=payload.expires_in_minutes, + ), + alternative_body=self._build_text_body( + verification_link=verification_link, + expires_in_minutes=payload.expires_in_minutes, + ), + subtype=MessageType.html, + multipart_subtype=MultipartSubtypeEnum.alternative, + ) + try: + await FastMail(config).send_message(message) + except Exception as exc: # noqa: BLE001 + raise EmailDeliveryError("Failed to send verification email") from exc + + def _build_connection_config(self) -> ConnectionConfig: + issues = self.settings.mail_configuration_issues + if issues: + raise EmailDeliveryConfigurationError(f"Mail delivery is not configured: {', '.join(issues)}") + + return ConnectionConfig( + MAIL_USERNAME=(self.settings.mail_username or "").strip(), + MAIL_PASSWORD=(self.settings.mail_password or "").strip(), + MAIL_FROM=(self.settings.mail_from or "").strip(), + MAIL_PORT=self.settings.mail_port, + MAIL_SERVER=(self.settings.mail_server or "").strip(), + MAIL_STARTTLS=self.settings.mail_starttls, + MAIL_SSL_TLS=self.settings.mail_ssl_tls, + USE_CREDENTIALS=self.settings.mail_use_credentials, + VALIDATE_CERTS=self.settings.mail_validate_certs, + ) + + def _build_verification_link(self, token: str) -> str: + base = self.settings.normalized_email_verify_base_url + encoded_token = quote(token, safe="") + return f"{base}/verify-email?token={encoded_token}" + + @staticmethod + def _build_text_body(*, verification_link: str, expires_in_minutes: int) -> str: + return ( + "Welcome to FileFlash.\n\n" + "Please verify your email by opening the following link:\n" + f"{verification_link}\n\n" + f"This link expires in {expires_in_minutes} minutes." + ) + + @staticmethod + def _build_html_body(*, verification_link: str, expires_in_minutes: int) -> str: + return ( + "" + "" + "" + "" + "" + "Verify your FileFlash email" + "" + "" + "" + "
" + "" + "
" + "
FileFlash
" + "

Verify your email

" + "

" + "Confirm your account to unlock the complete FileFlash experience." + "

" + "" + "
" + f"" + "Verify Email
" + f"

Link expires in {expires_in_minutes} minutes.

" + "

" + "If the button does not work, open this URL in your browser:
" + f"{verification_link}" + "

" + "
" + "
" + "" + ) + + @staticmethod + def _validate_payload(payload: VerificationEmailPayload) -> None: + if not payload.email: + raise EmailDeliveryError("Verification email target is empty") + if not payload.token: + raise EmailDeliveryError("Verification token is empty") + if payload.expires_in_minutes <= 0: + raise EmailDeliveryError("Verification email expiry must be positive") diff --git a/app/src/services/file.py b/app/src/fileflash/services/file.py similarity index 85% rename from app/src/services/file.py rename to app/src/fileflash/services/file.py index 8e27450..3d910da 100644 --- a/app/src/services/file.py +++ b/app/src/fileflash/services/file.py @@ -25,7 +25,7 @@ from ..models.enums import FavoriteItemType, FileStatus, FolderStatus, FolderType, ShareStatus, UploadStatus from ..models.tables_access_share import FavoriteItem, Share from ..models.tables_identity import User -from ..models.tables_storage import File, Folder, StorageObject +from ..models.tables_storage import File, FileMediaMetadata, Folder, StorageObject from ..s3.minio_client import MinioObjectStorageClient from ..schemas.common import PaginatedData, PaginationMeta from ..schemas.file import ( @@ -39,6 +39,7 @@ FileDetails, FileItem, GetFilesQuery, + MediaOptimization, MoveFileRequest, MoveFileResponse, RenameFileRequest, @@ -61,6 +62,8 @@ _RECYCLE_RETENTION_DAYS = 30 +TRANSCODE_READY_STATUS = "ready" + @dataclass(slots=True) class DownloadStreamResult: @@ -72,9 +75,16 @@ class DownloadStreamResult: class FileService: - def __init__(self, *, db: AsyncSession, storage: MinioObjectStorageClient | None = None) -> None: + def __init__( + self, + *, + db: AsyncSession, + storage: MinioObjectStorageClient | None = None, + starred_items_limit: int = 20, + ) -> None: self.db = db self.storage = storage + self.starred_items_limit = starred_items_limit async def list_files(self, *, user_id: int, query: GetFilesQuery) -> PaginatedData[FileItem]: folder_id = await self._resolve_folder_id(user_id, query.folder_id) @@ -108,8 +118,17 @@ async def list_files(self, *, user_id: int, query: GetFilesQuery) -> PaginatedDa rows = (await self.db.execute(base.offset(offset).limit(per_page))).all() starred_ids = await self._starred_file_ids(user_id, [r[0].file_id for r in rows]) - - items = [self._to_file_item(f, username, f.file_id in starred_ids) for f, username in rows] + media_optimization_map = await self._load_media_optimization_map([r[0] for r in rows]) + + items = [ + self._to_file_item( + f, + username, + f.file_id in starred_ids, + media_optimization=media_optimization_map.get(int(f.file_id)), + ) + for f, username in rows + ] return self._paginate(items, total, query.page, per_page) async def get_file(self, *, user_id: int, file_id: int) -> FileDetails: @@ -132,6 +151,7 @@ async def get_file(self, *, user_id: int, file_id: int) -> FileDetails: f, username = row is_starred = await self._is_file_starred(user_id, file_id) + media_optimization = await self._load_file_media_optimization(f) return FileDetails( id=str(f.file_id), name=f.file_name, @@ -147,6 +167,7 @@ async def get_file(self, *, user_id: int, file_id: int) -> FileDetails: folder_id=str(f.folder_id), permission="owner", is_starred=is_starred, + media_optimization=media_optimization, status=True, ) @@ -196,25 +217,27 @@ async def toggle_file_star( is_starred: bool, ) -> FileDetails: file_row = await self._get_active_file(user_id=user_id, file_id=file_id, for_update=True) - favorite = await self.db.scalar( - select(FavoriteItem).where( - and_( - FavoriteItem.user_id == user_id, - FavoriteItem.item_type == FavoriteItemType.FILE, - FavoriteItem.file_id == int(file_row.file_id), - ) - ) - ) + favorite = await self._get_file_favorite(user_id=user_id, file_id=int(file_row.file_id)) if is_starred and favorite is None: - self.db.add( - FavoriteItem( - user_id=user_id, - item_type=FavoriteItemType.FILE, - file_id=int(file_row.file_id), - folder_id=None, + await self._lock_user_for_star_update(user_id=user_id) + favorite = await self._get_file_favorite(user_id=user_id, file_id=int(file_row.file_id)) + if favorite is None: + starred_count = await self._count_starred_items(user_id=user_id) + if starred_count >= self.starred_items_limit: + raise ApiError( + status_code=400, + code=400, + message=f"已达收藏上限 {self.starred_items_limit}", + ) + self.db.add( + FavoriteItem( + user_id=user_id, + item_type=FavoriteItemType.FILE, + file_id=int(file_row.file_id), + folder_id=None, + ) ) - ) elif not is_starred and favorite is not None: await self.db.delete(favorite) @@ -261,7 +284,10 @@ async def _get_file_stream( raise ApiError(status_code=503, code=503, message="Object storage is unavailable") file_row = await self._get_active_file(user_id=user_id, file_id=file_id) - storage_object = await self.db.get(StorageObject, int(file_row.storage_object_id)) + storage_object = await self._resolve_stream_storage_object( + file_row=file_row, + prefer_optimized=(content_disposition == "inline"), + ) if storage_object is None or storage_object.upload_status != UploadStatus.ACTIVE: raise ApiError(status_code=404, code=404, message="File content not found") @@ -286,7 +312,10 @@ async def _get_file_stream( byte_range = self._parse_range_header(range_header=range_header, file_size=object_size) if byte_range is None: headers["Content-Length"] = str(object_size) - stream = self.storage.iter_object(object_key=storage_object.object_key) + stream = self.storage.iter_object( + bucket_name=storage_object.bucket_name, + object_key=storage_object.object_key, + ) return DownloadStreamResult( stream=stream, filename=file_row.file_name, @@ -299,6 +328,7 @@ async def _get_file_stream( headers["Content-Length"] = str(end - start + 1) headers["Content-Range"] = f"bytes {start}-{end}/{object_size}" stream = self.storage.iter_object_range( + bucket_name=storage_object.bucket_name, object_key=storage_object.object_key, start=start, end=end, @@ -396,7 +426,10 @@ async def create_batch_download_archive( for file_row, storage_object in files_with_storage: zip_path = self._safe_zip_path(file_paths.get(int(file_row.file_id), file_row.file_name)) with archive.open(zip_path, mode="w") as entry: - async for chunk in self.storage.iter_object(object_key=storage_object.object_key): + async for chunk in self.storage.iter_object( + bucket_name=storage_object.bucket_name, + object_key=storage_object.object_key, + ): entry.write(chunk) except Exception as exc: # noqa: BLE001 if os.path.exists(tmp_path): @@ -676,40 +709,61 @@ async def clear_recycle_bin(self, *, user_id: int) -> ClearRecycleBinResponse: async def list_starred(self, *, user_id: int) -> PaginatedData[ContentItem]: file_rows = ( await self.db.execute( - select(File, User.username) + select(FavoriteItem.created_at, File, User.username) .join(User, User.user_id == File.owner_id) .join( FavoriteItem, and_( FavoriteItem.file_id == File.file_id, FavoriteItem.user_id == user_id, + FavoriteItem.item_type == FavoriteItemType.FILE, ), ) - .where(and_(File.owner_id == user_id, File.status == FileStatus.ACTIVE)) + .where( + and_( + File.owner_id == user_id, + File.status == FileStatus.ACTIVE, + File.is_latest.is_(True), + ) + ) ) ).all() folder_rows = ( await self.db.execute( - select(Folder, User.username) + select(FavoriteItem.created_at, Folder, User.username) .join(User, User.user_id == Folder.owner_id) .join( FavoriteItem, and_( FavoriteItem.folder_id == Folder.folder_id, FavoriteItem.user_id == user_id, + FavoriteItem.item_type == FavoriteItemType.FOLDER, ), ) .where(and_(Folder.owner_id == user_id, Folder.status == FolderStatus.ACTIVE)) ) ).all() - items: list[ContentItem] = [] - for f, username in file_rows: - items.append(self._to_file_item(f, username, is_starred=True)) - for folder, username in folder_rows: - items.append(self._to_folder_item(folder, username, is_starred=True)) + media_optimization_map = await self._load_media_optimization_map([f for _, f, _ in file_rows]) + starred_items: list[tuple[datetime, ContentItem]] = [] + for starred_at, f, username in file_rows: + starred_items.append( + ( + starred_at, + self._to_file_item( + f, + username, + is_starred=True, + media_optimization=media_optimization_map.get(int(f.file_id)), + ), + ) + ) + for starred_at, folder, username in folder_rows: + starred_items.append((starred_at, self._to_folder_item(folder, username, is_starred=True))) + starred_items.sort(key=lambda entry: (entry[0], entry[1].id), reverse=True) + items = [item for _, item in starred_items] return self._paginate(items, len(items), 1, max(len(items), 1)) async def move_file(self, *, user_id: int, file_id: str, payload: MoveFileRequest) -> MoveFileResponse: @@ -1523,7 +1577,10 @@ async def _cleanup_storage_object_if_orphan(self, object_id: int) -> int: raise ApiError(status_code=503, code=503, message="Object storage is unavailable") try: - await self.storage.remove_object(object_key=storage_object.object_key) + await self.storage.remove_object( + bucket_name=storage_object.bucket_name, + object_key=storage_object.object_key, + ) except Exception as exc: # noqa: BLE001 raise ApiError( status_code=503, @@ -1658,6 +1715,30 @@ async def _resolve_folder_id(self, user_id: int, folder_id_str: str | None) -> i raise ApiError(status_code=404, code=404, message="Folder not found") return int(fid) + async def _get_file_favorite(self, *, user_id: int, file_id: int) -> FavoriteItem | None: + return await self.db.scalar( + select(FavoriteItem).where( + and_( + FavoriteItem.user_id == user_id, + FavoriteItem.item_type == FavoriteItemType.FILE, + FavoriteItem.file_id == file_id, + ) + ) + ) + + async def _lock_user_for_star_update(self, *, user_id: int) -> None: + locked_user = await self.db.scalar( + select(User.user_id).where(User.user_id == user_id).with_for_update() + ) + if locked_user is None: + raise ApiError(status_code=404, code=404, message="User not found") + + async def _count_starred_items(self, *, user_id: int) -> int: + count = await self.db.scalar( + select(func.count(FavoriteItem.favorite_id)).where(FavoriteItem.user_id == user_id) + ) + return int(count or 0) + async def _starred_file_ids(self, user_id: int, file_ids: list[int]) -> set[int]: if not file_ids: return set() @@ -1684,8 +1765,157 @@ async def _is_file_starred(self, user_id: int, file_id: int) -> bool: ) ) is not None + async def _load_media_optimization_map(self, files: list[File]) -> dict[int, MediaOptimization]: + if not files: + return {} + + source_object_ids = [int(row.storage_object_id) for row in files if row.storage_object_id is not None] + if not source_object_ids: + return {} + + metadata_rows = list( + await self.db.scalars( + select(FileMediaMetadata).where(FileMediaMetadata.source_object_id.in_(source_object_ids)) + ) + ) + by_object_id = { + int(row.source_object_id): row for row in metadata_rows if isinstance(row, FileMediaMetadata) + } + + result: dict[int, MediaOptimization] = {} + for file_row in files: + media = self._parse_media_optimization(by_object_id.get(int(file_row.storage_object_id))) + if media is not None: + result[int(file_row.file_id)] = media + return result + + async def _load_file_media_optimization(self, file_row: File) -> MediaOptimization | None: + metadata_row = await self.db.scalar( + select(FileMediaMetadata) + .where(FileMediaMetadata.source_object_id == int(file_row.storage_object_id)) + .limit(1) + ) + if not isinstance(metadata_row, FileMediaMetadata): + return None + return self._parse_media_optimization(metadata_row) + + def _parse_media_optimization(self, metadata_row: FileMediaMetadata | None) -> MediaOptimization | None: + if metadata_row is None: + return None + extra = metadata_row.extra_metadata or {} + transcode = extra.get("transcode") + if not isinstance(transcode, dict): + return None + + status = str(transcode.get("status") or "").strip().lower() + media_type = str(transcode.get("mediaType") or "").strip().lower() + updated_at_raw = transcode.get("updatedAt") + optimized_mime_type = transcode.get("optimizedMimeType") + if status not in {"queued", "running", "ready", "failed"}: + return None + if media_type not in {"audio", "video"}: + return None + + updated_at = self._parse_datetime(updated_at_raw) or metadata_row.extracted_at + if not updated_at: + return None + + return MediaOptimization( + status=status, # type: ignore[arg-type] + media_type=media_type, # type: ignore[arg-type] + optimized_mime_type=str(optimized_mime_type) if optimized_mime_type else None, + updated_at=updated_at, + ) + + async def _resolve_stream_storage_object( + self, + *, + file_row: File, + prefer_optimized: bool, + ) -> StorageObject | None: + source_object = await self.db.get(StorageObject, int(file_row.storage_object_id)) + if source_object is None: + return None + if not prefer_optimized: + return source_object + + metadata_row = await self.db.scalar( + select(FileMediaMetadata) + .where(FileMediaMetadata.source_object_id == int(file_row.storage_object_id)) + .limit(1) + ) + if not isinstance(metadata_row, FileMediaMetadata): + return source_object + transcode = (metadata_row.extra_metadata or {}).get("transcode") + if not isinstance(transcode, dict): + return source_object + if str(transcode.get("status") or "").strip().lower() != TRANSCODE_READY_STATUS: + return source_object + + bucket_name = str(transcode.get("optimizedBucketName") or "").strip() + object_key = str(transcode.get("optimizedObjectKey") or "").strip() + if not bucket_name or not object_key: + return source_object + + optimized_object = await self.db.scalar( + select(StorageObject) + .where( + and_( + StorageObject.bucket_name == bucket_name, + StorageObject.object_key == object_key, + StorageObject.upload_status == UploadStatus.ACTIVE, + ) + ) + .limit(1) + ) + if isinstance(optimized_object, StorageObject): + return optimized_object + + if self.storage is None: + return source_object + exists = await self.storage.object_exists(bucket_name=bucket_name, object_key=object_key) + if not exists: + return source_object + + stat = await self.storage.stat_object(bucket_name=bucket_name, object_key=object_key) + created = StorageObject( + bucket_name=bucket_name, + object_key=object_key, + object_size=int(stat.size), + etag=stat.etag, + version_id=stat.version_id, + content_type=stat.content_type, + upload_status=UploadStatus.ACTIVE, + ) + self.db.add(created) + await self.db.flush() + return created + @staticmethod - def _to_file_item(f: File, owner_name: str, is_starred: bool) -> FileItem: + def _parse_datetime(raw: object) -> datetime | None: + if raw is None: + return None + if isinstance(raw, datetime): + return raw + text = str(raw).strip() + if not text: + return None + try: + value = datetime.fromisoformat(text.replace("Z", "+00:00")) + except ValueError: + return None + if value.tzinfo is None: + return value.replace(tzinfo=UTC) + return value + + @staticmethod + def _to_file_item( + f: File, + owner_name: str, + is_starred: bool, + *, + media_optimization: MediaOptimization | None = None, + ) -> FileItem: return FileItem( id=str(f.file_id), name=f.file_name, @@ -1701,6 +1931,7 @@ def _to_file_item(f: File, owner_name: str, is_starred: bool) -> FileItem: folder_id=str(f.folder_id), permission="owner", is_starred=is_starred, + media_optimization=media_optimization, ) @staticmethod diff --git a/app/src/services/folder.py b/app/src/fileflash/services/folder.py similarity index 79% rename from app/src/services/folder.py rename to app/src/fileflash/services/folder.py index 8a9cbfc..c84c1b0 100644 --- a/app/src/services/folder.py +++ b/app/src/fileflash/services/folder.py @@ -17,7 +17,7 @@ from ..models.enums import FavoriteItemType, FileStatus, FolderStatus, FolderType from ..models.tables_access_share import FavoriteItem from ..models.tables_identity import User -from ..models.tables_storage import File, Folder +from ..models.tables_storage import File, FileMediaMetadata, Folder from ..schemas.common import PaginatedData, PaginationMeta from ..schemas.file import ( ContentItem, @@ -28,6 +28,7 @@ FolderPathResponse, FolderSizeResponse, GetFolderContentsQuery, + MediaOptimization, MoveFolderRequest, MoveFolderResponse, PathItem, @@ -50,8 +51,9 @@ class FolderService: - def __init__(self, *, db: AsyncSession) -> None: + def __init__(self, *, db: AsyncSession, starred_items_limit: int = 20) -> None: self.db = db + self.starred_items_limit = starred_items_limit async def get_root_contents( self, *, user_id: int, query: GetFolderContentsQuery, @@ -114,13 +116,21 @@ async def get_folder_contents( file_ids = [r[0].file_id for r in file_rows] starred_folders = await self._starred_folder_ids(user_id, folder_ids) starred_files = await self._starred_file_ids(user_id, file_ids) + media_optimization_map = await self._load_media_optimization_map([f for f, _ in file_rows]) # merge: folders first, then files all_items: list[ContentItem] = [] for folder, uname in folder_rows: all_items.append(self._to_folder_item(folder, uname, folder.folder_id in starred_folders)) for f, uname in file_rows: - all_items.append(self._to_file_item(f, uname, f.file_id in starred_files)) + all_items.append( + self._to_file_item( + f, + uname, + f.file_id in starred_files, + media_optimization=media_optimization_map.get(int(f.file_id)), + ) + ) total = len(all_items) per_page = query.per_page @@ -252,25 +262,27 @@ async def toggle_folder_star( if folder is None: raise ApiError(status_code=404, code=404, message="Folder not found") - favorite = await self.db.scalar( - select(FavoriteItem).where( - and_( - FavoriteItem.user_id == user_id, - FavoriteItem.item_type == FavoriteItemType.FOLDER, - FavoriteItem.folder_id == fid, - ) - ) - ) + favorite = await self._get_folder_favorite(user_id=user_id, folder_id=fid) if is_starred and favorite is None: - self.db.add( - FavoriteItem( - user_id=user_id, - item_type=FavoriteItemType.FOLDER, - file_id=None, - folder_id=fid, + await self._lock_user_for_star_update(user_id=user_id) + favorite = await self._get_folder_favorite(user_id=user_id, folder_id=fid) + if favorite is None: + starred_count = await self._count_starred_items(user_id=user_id) + if starred_count >= self.starred_items_limit: + raise ApiError( + status_code=400, + code=400, + message=f"已达收藏上限 {self.starred_items_limit}", + ) + self.db.add( + FavoriteItem( + user_id=user_id, + item_type=FavoriteItemType.FOLDER, + file_id=None, + folder_id=fid, + ) ) - ) elif not is_starred and favorite is not None: await self.db.delete(favorite) @@ -448,6 +460,30 @@ async def _ensure_folder_access(self, user_id: int, folder_id: int) -> None: if exists is None: raise ApiError(status_code=404, code=404, message="Folder not found") + async def _get_folder_favorite(self, *, user_id: int, folder_id: int) -> FavoriteItem | None: + return await self.db.scalar( + select(FavoriteItem).where( + and_( + FavoriteItem.user_id == user_id, + FavoriteItem.item_type == FavoriteItemType.FOLDER, + FavoriteItem.folder_id == folder_id, + ) + ) + ) + + async def _lock_user_for_star_update(self, *, user_id: int) -> None: + locked_user = await self.db.scalar( + select(User.user_id).where(User.user_id == user_id).with_for_update() + ) + if locked_user is None: + raise ApiError(status_code=404, code=404, message="User not found") + + async def _count_starred_items(self, *, user_id: int) -> int: + count = await self.db.scalar( + select(func.count(FavoriteItem.favorite_id)).where(FavoriteItem.user_id == user_id) + ) + return int(count or 0) + async def _starred_file_ids(self, user_id: int, file_ids: list[int]) -> set[int]: if not file_ids: return set() @@ -485,8 +521,68 @@ def _parse_id(value: str, name: str) -> int: except ValueError as exc: raise ApiError(status_code=400, code=400, message=f"Invalid {name}") from exc + async def _load_media_optimization_map(self, files: list[File]) -> dict[int, MediaOptimization]: + if not files: + return {} + source_object_ids = [int(row.storage_object_id) for row in files if row.storage_object_id is not None] + if not source_object_ids: + return {} + + metadata_rows = list( + await self.db.scalars( + select(FileMediaMetadata).where(FileMediaMetadata.source_object_id.in_(source_object_ids)) + ) + ) + by_object_id = { + int(row.source_object_id): row for row in metadata_rows if isinstance(row, FileMediaMetadata) + } + result: dict[int, MediaOptimization] = {} + for row in files: + media = self._parse_media_optimization(by_object_id.get(int(row.storage_object_id))) + if media is not None: + result[int(row.file_id)] = media + return result + + @staticmethod + def _parse_media_optimization(metadata_row: FileMediaMetadata | None) -> MediaOptimization | None: + if metadata_row is None: + return None + transcode = (metadata_row.extra_metadata or {}).get("transcode") + if not isinstance(transcode, dict): + return None + status = str(transcode.get("status") or "").strip().lower() + media_type = str(transcode.get("mediaType") or "").strip().lower() + if status not in {"queued", "running", "ready", "failed"}: + return None + if media_type not in {"audio", "video"}: + return None + updated_at_raw = transcode.get("updatedAt") + if isinstance(updated_at_raw, datetime): + updated_at = updated_at_raw + else: + text = str(updated_at_raw or "").strip() + try: + updated_at = datetime.fromisoformat(text.replace("Z", "+00:00")) if text else metadata_row.extracted_at + except ValueError: + updated_at = metadata_row.extracted_at + if updated_at.tzinfo is None: + updated_at = updated_at.replace(tzinfo=UTC) + optimized_mime_type = transcode.get("optimizedMimeType") + return MediaOptimization( + status=status, # type: ignore[arg-type] + media_type=media_type, # type: ignore[arg-type] + optimized_mime_type=str(optimized_mime_type) if optimized_mime_type else None, + updated_at=updated_at, + ) + @staticmethod - def _to_file_item(f: File, owner_name: str, is_starred: bool) -> FileItem: + def _to_file_item( + f: File, + owner_name: str, + is_starred: bool, + *, + media_optimization: MediaOptimization | None = None, + ) -> FileItem: return FileItem( id=str(f.file_id), name=f.file_name, @@ -502,6 +598,7 @@ def _to_file_item(f: File, owner_name: str, is_starred: bool) -> FileItem: folder_id=str(f.folder_id), permission="owner", is_starred=is_starred, + media_optimization=media_optimization, ) @staticmethod diff --git a/app/src/services/job_queue.py b/app/src/fileflash/services/job_queue.py similarity index 100% rename from app/src/services/job_queue.py rename to app/src/fileflash/services/job_queue.py diff --git a/app/src/services/messaging.py b/app/src/fileflash/services/messaging.py similarity index 87% rename from app/src/services/messaging.py rename to app/src/fileflash/services/messaging.py index 7ba96b9..6ecab1b 100644 --- a/app/src/services/messaging.py +++ b/app/src/fileflash/services/messaging.py @@ -23,5 +23,6 @@ async def publish(self, event_name: str, payload: Mapping[str, object]) -> None: class InProcessAuthEventPublisher: async def publish(self, event_name: str, payload: Mapping[str, object]) -> None: event = AuthEvent(name=event_name, payload=payload, created_at=datetime.now(UTC)) - logger.info("Auth event published in-process: %s %s", event.name, dict(event.payload)) + logger.info("Auth event published in-process: %s", event.name) + _ = event.payload diff --git a/app/src/services/rate_limiter.py b/app/src/fileflash/services/rate_limiter.py similarity index 100% rename from app/src/services/rate_limiter.py rename to app/src/fileflash/services/rate_limiter.py diff --git a/app/src/fileflash/services/registration_email_domain_rule.py b/app/src/fileflash/services/registration_email_domain_rule.py new file mode 100644 index 0000000..ba6f66b --- /dev/null +++ b/app/src/fileflash/services/registration_email_domain_rule.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +import re +from datetime import UTC, datetime + +from sqlalchemy import Select, and_, func, or_, select +from sqlalchemy.ext.asyncio import AsyncSession + +from ..core.errors import ApiError +from ..models.tables_identity import RegistrationEmailDomainRule +from ..schemas.common import PaginatedData, PaginationMeta +from ..schemas.registration_email_domain_rule import ( + CreateRegistrationEmailDomainRuleRequest, + ListRegistrationEmailDomainRulesQuery, + RegistrationEmailDomainRuleItem, + UpdateRegistrationEmailDomainRuleRequest, +) + + +class RegistrationEmailDomainRuleService: + _DISALLOWED_MESSAGE = "邮箱后缀不被允许,请更换邮箱" + _PATTERN_MAX_LENGTH = 512 + + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def list_rules( + self, + *, + query: ListRegistrationEmailDomainRulesQuery, + ) -> PaginatedData[RegistrationEmailDomainRuleItem]: + statement = select(RegistrationEmailDomainRule) + if query.enabled is not None: + statement = statement.where(RegistrationEmailDomainRule.enabled == query.enabled) + if query.query_text: + keyword = query.query_text.strip().lower() + if keyword: + like = f"%{keyword}%" + statement = statement.where( + or_( + func.lower(RegistrationEmailDomainRule.name).like(like), + func.lower(RegistrationEmailDomainRule.pattern).like(like), + ) + ) + + total = await self.db.scalar(select(func.count()).select_from(statement.subquery())) + total_items = int(total or 0) + total_pages = max(1, -(-total_items // query.per_page)) + offset = (query.page - 1) * query.per_page + + rows = list( + await self.db.scalars( + statement + .order_by(RegistrationEmailDomainRule.rule_id.desc()) + .offset(offset) + .limit(query.per_page) + ) + ) + items = [self._to_item(row) for row in rows] + return PaginatedData( + items=items, + pagination=PaginationMeta( + total_items=total_items, + total_pages=total_pages, + per_page=query.per_page, + current_page=query.page, + has_prev=query.page > 1, + has_next=query.page < total_pages, + ), + ) + + async def create_rule( + self, + *, + payload: CreateRegistrationEmailDomainRuleRequest, + ) -> RegistrationEmailDomainRuleItem: + name = payload.name.strip() + pattern = payload.pattern.strip() + self._validate_pattern(pattern) + await self._ensure_name_unique(name=name) + + now = datetime.now(UTC) + row = RegistrationEmailDomainRule( + name=name, + pattern=pattern, + enabled=payload.enabled, + created_at=now, + updated_at=now, + ) + self.db.add(row) + await self.db.commit() + await self.db.refresh(row) + return self._to_item(row) + + async def update_rule( + self, + *, + rule_id: int, + payload: UpdateRegistrationEmailDomainRuleRequest, + ) -> RegistrationEmailDomainRuleItem: + row = await self.db.get(RegistrationEmailDomainRule, rule_id) + if row is None: + raise ApiError(status_code=404, code=404, message="Rule not found") + + changed = False + if payload.name is not None: + next_name = payload.name.strip() + if next_name.lower() != row.name.lower(): + await self._ensure_name_unique(name=next_name, exclude_rule_id=rule_id) + row.name = next_name + changed = True + if payload.pattern is not None: + next_pattern = payload.pattern.strip() + self._validate_pattern(next_pattern) + row.pattern = next_pattern + changed = True + if payload.enabled is not None: + row.enabled = payload.enabled + changed = True + + if changed: + row.updated_at = datetime.now(UTC) + await self.db.commit() + await self.db.refresh(row) + return self._to_item(row) + + async def delete_rule(self, *, rule_id: int) -> None: + row = await self.db.get(RegistrationEmailDomainRule, rule_id) + if row is None: + raise ApiError(status_code=404, code=404, message="Rule not found") + await self.db.delete(row) + await self.db.commit() + + async def assert_email_allowed(self, *, email: str) -> None: + domain = self._extract_domain(email) + rules = list( + await self.db.scalars( + select(RegistrationEmailDomainRule).where(RegistrationEmailDomainRule.enabled.is_(True)) + ) + ) + if not rules: + raise ApiError(status_code=400, code=400, message=self._DISALLOWED_MESSAGE) + + for rule in rules: + try: + if re.fullmatch(rule.pattern, domain): + return + except re.error: + continue + raise ApiError(status_code=400, code=400, message=self._DISALLOWED_MESSAGE) + + async def _ensure_name_unique(self, *, name: str, exclude_rule_id: int | None = None) -> None: + statement: Select[tuple[int]] = select(RegistrationEmailDomainRule.rule_id).where( + func.lower(RegistrationEmailDomainRule.name) == name.lower() + ) + if exclude_rule_id is not None: + statement = statement.where(RegistrationEmailDomainRule.rule_id != exclude_rule_id) + exists = await self.db.scalar(statement.limit(1)) + if exists is not None: + raise ApiError(status_code=409, code=409, message="Rule name already exists") + + @classmethod + def _validate_pattern(cls, pattern: str) -> None: + if not pattern: + raise ApiError(status_code=400, code=400, message="pattern cannot be empty") + if len(pattern) > cls._PATTERN_MAX_LENGTH: + raise ApiError(status_code=400, code=400, message="pattern is too long") + if cls._looks_risky_pattern(pattern): + raise ApiError(status_code=400, code=400, message="Regex pattern is too risky") + try: + re.compile(pattern) + except re.error as exc: + raise ApiError(status_code=400, code=400, message=f"Invalid regex pattern: {exc}") from exc + + @staticmethod + def _extract_domain(email: str) -> str: + if "@" not in email: + raise ApiError(status_code=400, code=400, message="Invalid email") + _local, sep, domain = email.strip().rpartition("@") + if not sep or not domain: + raise ApiError(status_code=400, code=400, message="Invalid email") + normalized = domain.strip().lower() + if not normalized: + raise ApiError(status_code=400, code=400, message="Invalid email") + return normalized + + @staticmethod + def _looks_risky_pattern(pattern: str) -> bool: + # Reject backreferences (\1, \g<...>) and nested quantifiers like (a+)+ + if re.search(r"\\[1-9]", pattern): + return True + if re.search(r"\\g<[^>]+>", pattern): + return True + + depth = 0 + group_has_quantifier: list[bool] = [] + escaped = False + for index, ch in enumerate(pattern): + if escaped: + escaped = False + continue + if ch == "\\": + escaped = True + continue + if ch == "(": + depth += 1 + group_has_quantifier.append(False) + continue + if ch in {"*", "+", "?"}: + if depth > 0 and group_has_quantifier: + group_has_quantifier[-1] = True + continue + if ch == "{": + close = pattern.find("}", index + 1) + if close != -1 and depth > 0 and group_has_quantifier: + group_has_quantifier[-1] = True + continue + if ch == ")" and depth > 0: + had_quantifier = group_has_quantifier.pop() + depth -= 1 + j = index + 1 + while j < len(pattern) and pattern[j] in {" ", "\t"}: + j += 1 + if had_quantifier and j < len(pattern): + if pattern[j] in {"*", "+", "?"}: + return True + if pattern[j] == "{": + close = pattern.find("}", j + 1) + if close != -1: + return True + return False + + @staticmethod + def _to_item(row: RegistrationEmailDomainRule) -> RegistrationEmailDomainRuleItem: + return RegistrationEmailDomainRuleItem( + rule_id=str(row.rule_id), + name=row.name, + pattern=row.pattern, + enabled=row.enabled, + created_at=row.created_at, + updated_at=row.updated_at, + ) + diff --git a/app/src/services/share.py b/app/src/fileflash/services/share.py similarity index 83% rename from app/src/services/share.py rename to app/src/fileflash/services/share.py index b75ff8f..4d0ce35 100644 --- a/app/src/services/share.py +++ b/app/src/fileflash/services/share.py @@ -12,6 +12,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from ..core.errors import ApiError +from ..core.http_headers import build_content_disposition from ..core.security import ( create_share_access_token, decode_share_access_token, @@ -27,7 +28,7 @@ ) from ..models.enums import FileStatus, FolderStatus, FolderType, ShareStatus from ..models.tables_access_share import Share, ShareAccessLog -from ..models.tables_storage import File, Folder, StorageObject +from ..models.tables_storage import File, FileMediaMetadata, Folder, StorageObject from ..s3.minio_client import MinioObjectStorageClient from ..schemas.common import PaginatedData, PaginationMeta from ..schemas.share import ( @@ -401,8 +402,28 @@ async def get_shared_file_stream( action: Literal["download", "preview"], ip_address: str, user_agent: str | None, - ) -> tuple[AsyncIterator[bytes], str, str]: - async def _operation() -> tuple[AsyncIterator[bytes], str, str]: + ) -> tuple[AsyncIterator[bytes], str, str, int]: + stream, filename, content_type, status_code, _headers = await self.get_shared_file_download_stream_response( + share_link=share_link, + share_access_token=share_access_token, + action=action, + range_header=None, + ip_address=ip_address, + user_agent=user_agent, + ) + return stream, filename, content_type, status_code + + async def get_shared_file_download_stream_response( + self, + *, + share_link: str, + share_access_token: str, + action: Literal["download", "preview"], + range_header: str | None, + ip_address: str, + user_agent: str | None, + ) -> tuple[AsyncIterator[bytes], str, str, int, dict[str, str]]: + async def _operation() -> tuple[AsyncIterator[bytes], str, str, int, dict[str, str]]: await apply_local_lock_timeout(self.db) share_row = await self._resolve_share_for_access_token( share_link=share_link, @@ -418,9 +439,15 @@ async def _operation() -> tuple[AsyncIterator[bytes], str, str]: raise ApiError(status_code=403, code=403, message="Preview is not allowed for this share") file_row = await self._get_active_file(file_id=int(share_row.file_id), owner_id=share_row.user_id) - storage_object = await self.db.get(StorageObject, int(file_row.storage_object_id)) + storage_object = await self._resolve_shared_stream_storage_object( + file_row=file_row, + prefer_optimized=(action == "preview"), + ) if storage_object is None: raise ApiError(status_code=404, code=404, message="Shared file content not found") + object_size = int(storage_object.object_size or file_row.file_size or 0) + if object_size <= 0: + raise ApiError(status_code=404, code=404, message="Shared file content not found") if action == "download": await self.db.execute( @@ -439,8 +466,48 @@ async def _operation() -> tuple[AsyncIterator[bytes], str, str]: ) await self.db.commit() - content_type = file_row.mime_type or storage_object.content_type or "application/octet-stream" - return self.storage.iter_object(object_key=storage_object.object_key), file_row.file_name, content_type + content_type = ( + storage_object.content_type + or file_row.mime_type + or "application/octet-stream" + ) + headers = { + "Accept-Ranges": "bytes", + "Content-Disposition": build_content_disposition( + file_row.file_name, + disposition="attachment" if action == "download" else "inline", + ), + } + + byte_range = self._parse_range_header(range_header=range_header, file_size=object_size) + if byte_range is None: + headers["Content-Length"] = str(object_size) + return ( + self.storage.iter_object( + bucket_name=storage_object.bucket_name, + object_key=storage_object.object_key, + ), + file_row.file_name, + content_type, + 200, + headers, + ) + + start, end = byte_range + headers["Content-Length"] = str(end - start + 1) + headers["Content-Range"] = f"bytes {start}-{end}/{object_size}" + return ( + self.storage.iter_object_range( + bucket_name=storage_object.bucket_name, + object_key=storage_object.object_key, + start=start, + end=end, + ), + file_row.file_name, + content_type, + 206, + headers, + ) try: return await run_with_transaction_retry(self.db, _operation) @@ -877,6 +944,109 @@ async def _copy_folder_children( target_folder_id=target_folder_id, ) + async def _resolve_shared_stream_storage_object( + self, + *, + file_row: File, + prefer_optimized: bool, + ) -> StorageObject | None: + source_object = await self.db.get(StorageObject, int(file_row.storage_object_id)) + if source_object is None: + return None + if not prefer_optimized: + return source_object + + metadata_row = await self.db.scalar( + select(FileMediaMetadata) + .where(FileMediaMetadata.source_object_id == int(file_row.storage_object_id)) + .limit(1) + ) + if not isinstance(metadata_row, FileMediaMetadata): + return source_object + transcode = (metadata_row.extra_metadata or {}).get("transcode") + if not isinstance(transcode, dict): + return source_object + if str(transcode.get("status") or "").strip().lower() != "ready": + return source_object + + bucket_name = str(transcode.get("optimizedBucketName") or "").strip() + object_key = str(transcode.get("optimizedObjectKey") or "").strip() + if not bucket_name or not object_key: + return source_object + + optimized_object = await self.db.scalar( + select(StorageObject) + .where( + and_( + StorageObject.bucket_name == bucket_name, + StorageObject.object_key == object_key, + ) + ) + .limit(1) + ) + if isinstance(optimized_object, StorageObject): + return optimized_object + + exists = await self.storage.object_exists(bucket_name=bucket_name, object_key=object_key) + if not exists: + return source_object + stat = await self.storage.stat_object(bucket_name=bucket_name, object_key=object_key) + created = StorageObject( + bucket_name=bucket_name, + object_key=object_key, + object_size=int(stat.size), + etag=stat.etag, + version_id=stat.version_id, + content_type=stat.content_type, + ) + self.db.add(created) + await self.db.flush() + return created + + @staticmethod + def _parse_range_header(range_header: str | None, file_size: int) -> tuple[int, int] | None: + if not range_header: + return None + + value = range_header.strip() + if not value.lower().startswith("bytes="): + raise ApiError(status_code=416, code=416, message="Invalid Range header") + + spec = value[6:].strip() + if "," in spec: + raise ApiError(status_code=416, code=416, message="Multiple ranges are not supported") + + if spec.startswith("-"): + suffix_part = spec[1:].strip() + if not suffix_part.isdigit(): + raise ApiError(status_code=416, code=416, message="Invalid Range header") + suffix = int(suffix_part) + if suffix <= 0: + raise ApiError(status_code=416, code=416, message="Invalid Range header") + start = max(file_size - suffix, 0) + end = file_size - 1 + return start, end + + if "-" not in spec: + raise ApiError(status_code=416, code=416, message="Invalid Range header") + + start_part, end_part = spec.split("-", 1) + if not start_part.strip().isdigit(): + raise ApiError(status_code=416, code=416, message="Invalid Range header") + start = int(start_part.strip()) + end = file_size - 1 + if end_part.strip(): + if not end_part.strip().isdigit(): + raise ApiError(status_code=416, code=416, message="Invalid Range header") + end = int(end_part.strip()) + + if start < 0 or start >= file_size or end < start: + raise ApiError(status_code=416, code=416, message="Requested range is not satisfiable") + + if end >= file_size: + end = file_size - 1 + return start, end + async def _log_share_event( self, *, diff --git a/app/src/services/upload.py b/app/src/fileflash/services/upload.py similarity index 75% rename from app/src/services/upload.py rename to app/src/fileflash/services/upload.py index 04a3715..0b221d8 100644 --- a/app/src/services/upload.py +++ b/app/src/fileflash/services/upload.py @@ -1,9 +1,11 @@ from __future__ import annotations import hashlib +import json import logging from datetime import UTC, datetime, timedelta from pathlib import Path +from typing import Any from uuid import uuid4 from sqlalchemy import and_, func, or_, select @@ -28,18 +30,30 @@ UploadStatus, UploadTaskStatus, ) -from ..models.tables_storage import File, Folder, StorageObject, UploadTask, UploadTaskPart +from ..models.tables_storage import File, FileMediaMetadata, Folder, StorageObject, UploadTask, UploadTaskPart +from ..models.tables_worker import BackgroundJob from ..s3.minio_client import MinioObjectStorageClient, ObjectStorageError from ..schemas.file import MergeChunksRequest, MergeChunksResponse, UploadPreflightRequest, UploadPreflightResponse +from .background_jobs import BackgroundJobService logger = logging.getLogger(__name__) +TRANSCODE_PROFILE_VERSION = "mp4-v1" + class UploadService: - def __init__(self, *, db: AsyncSession, settings: Settings, storage: MinioObjectStorageClient) -> None: + def __init__( + self, + *, + db: AsyncSession, + settings: Settings, + storage: MinioObjectStorageClient, + jobs: BackgroundJobService | None = None, + ) -> None: self.db = db self.settings = settings self.storage = storage + self.jobs = jobs async def preflight(self, *, user_id: int, payload: UploadPreflightRequest) -> UploadPreflightResponse: async def _operation() -> UploadPreflightResponse: @@ -228,6 +242,69 @@ async def upload_chunk(self, *, user_id: int, upload_id: str, chunk_index: int, task.last_error = None await self.db.commit() + async def enqueue_merge_job( + self, + *, + user_id: int, + upload_id: str, + payload: MergeChunksRequest, + ) -> BackgroundJob: + jobs = self.jobs + if jobs is None: + raise ApiError(status_code=503, code=503, message="Job queue unavailable") + + # Normalize hash early to keep retries idempotent for equivalent requests. + normalized_hash, _ = self._normalize_hash(payload.file_hash) + normalized_payload = MergeChunksRequest( + file_hash=normalized_hash, + file_name=payload.file_name, + mime_type=payload.mime_type, + parent_id=payload.parent_id, + conflict_strategy=payload.conflict_strategy, + ) + merge_request_payload = normalized_payload.model_dump(by_alias=True, exclude_none=True) + job_payload: dict[str, Any] = { + "userId": user_id, + "uploadId": upload_id, + "mergeRequest": merge_request_payload, + } + idempotency_key = self._build_merge_job_idempotency_key( + user_id=user_id, + upload_id=upload_id, + merge_request_payload=merge_request_payload, + ) + return await jobs.enqueue( + self.db, + task_type="task.upload_merge", + payload=job_payload, + idempotency_key=idempotency_key, + requested_by=user_id, + ) + + async def execute_merge_job(self, *, payload: dict[str, Any]) -> dict[str, Any]: + raw_user_id = payload.get("userId") + try: + user_id = int(raw_user_id) + except (TypeError, ValueError) as exc: + raise ValueError("Merge job payload requires valid userId") from exc + if user_id <= 0: + raise ValueError("Merge job payload userId must be > 0") + + upload_id = str(payload.get("uploadId") or "").strip() + if not upload_id: + raise ValueError("Merge job payload requires uploadId") + + merge_request_raw = payload.get("mergeRequest") + if not isinstance(merge_request_raw, dict): + raise ValueError("Merge job payload requires mergeRequest object") + merge_request = MergeChunksRequest.model_validate(merge_request_raw) + response = await self.merge_chunks( + user_id=user_id, + upload_id=upload_id, + payload=merge_request, + ) + return response.model_dump(by_alias=True, mode="json") + async def merge_chunks( self, *, @@ -430,6 +507,13 @@ async def _operation() -> MergeChunksResponse: task.completed_at = now await self.db.commit() + if self._is_transcode_candidate(resolved_mime_type): + await self._enqueue_transcode_after_merge( + user_id=user_id, + file_row=file_row, + storage_object=storage_object, + ) + try: await self.storage.remove_objects(object_keys=source_keys) except Exception: # noqa: BLE001 @@ -740,6 +824,17 @@ def _build_part_object_key(self, *, task: UploadTask, chunk_index: int) -> str: upload_id = task.upload_id or f"task-{task.task_id}" return f"{self.settings.upload_temp_prefix}/u{task.user_id}/{upload_id}/part-{chunk_index:08d}" + @staticmethod + def _build_merge_job_idempotency_key( + *, + user_id: int, + upload_id: str, + merge_request_payload: dict[str, Any], + ) -> str: + canonical = json.dumps(merge_request_payload, sort_keys=True, separators=(",", ":"), ensure_ascii=False) + digest = hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:32] + return f"upload:{user_id}:{upload_id}:merge:{digest}" + @staticmethod def _extract_ext(file_name: str) -> str | None: suffix = Path(file_name).suffix.strip(".").lower() @@ -752,3 +847,163 @@ def _is_hex(value: str) -> bool: @staticmethod def _normalize_task_hash(value: str | None) -> str: return (value or "").strip().lower() + + async def _enqueue_transcode_after_merge( + self, + *, + user_id: int, + file_row: File, + storage_object: StorageObject, + ) -> None: + source_object_id = await self._resolve_storage_object_id( + storage_object=storage_object, + file_row=file_row, + ) + if source_object_id is None: + logger.warning( + "Skip transcode enqueue because source object id is missing userId=%s fileId=%s objectKey=%s", + user_id, + file_row.file_id, + storage_object.object_key, + ) + return + + output_key = self._build_transcode_output_key( + source_object_id=source_object_id, + source_object_key=storage_object.object_key, + media_type=self._media_type_from_mime(file_row.mime_type or storage_object.content_type), + ) + idempotency_key = f"object:{source_object_id}:transcode:{TRANSCODE_PROFILE_VERSION}" + now = datetime.now(UTC) + + metadata = await self.db.scalar( + select(FileMediaMetadata).where(FileMediaMetadata.source_object_id == source_object_id).limit(1) + ) + if metadata is None: + metadata = FileMediaMetadata(source_object_id=source_object_id) + self.db.add(metadata) + + extra = dict(metadata.extra_metadata or {}) + extra["transcode"] = { + "status": "queued", + "mediaType": self._media_type_from_mime(file_row.mime_type or storage_object.content_type), + "profileVersion": TRANSCODE_PROFILE_VERSION, + "optimizedBucketName": self.settings.object_storage_bucket, + "optimizedObjectKey": output_key, + "updatedAt": now.isoformat(), + } + metadata.extra_metadata = extra + metadata.extracted_at = now + await self.db.commit() + + jobs = self.jobs + if jobs is None: + failed_extra = dict(metadata.extra_metadata or {}) + failed_extra["transcode"] = { + "status": "failed", + "mediaType": self._media_type_from_mime(file_row.mime_type or storage_object.content_type), + "profileVersion": TRANSCODE_PROFILE_VERSION, + "optimizedBucketName": self.settings.object_storage_bucket, + "optimizedObjectKey": output_key, + "error": "Job queue unavailable", + "updatedAt": datetime.now(UTC).isoformat(), + } + metadata.extra_metadata = failed_extra + metadata.extracted_at = datetime.now(UTC) + await self.db.commit() + return + + try: + await jobs.enqueue_transcode_job( + self.db, + source_bucket_name=storage_object.bucket_name, + source_object_key=storage_object.object_key, + source_object_id=source_object_id, + output_bucket_name=self.settings.object_storage_bucket, + output_object_key=output_key, + file_id=int(file_row.file_id), + requested_by=user_id, + idempotency_key=idempotency_key, + ) + except ApiError as exc: + logger.warning( + "Enqueue transcode failed but upload kept userId=%s fileId=%s objectId=%s error=%s", + user_id, + file_row.file_id, + storage_object.object_id, + exc.message, + ) + failed_extra = dict(metadata.extra_metadata or {}) + failed_extra["transcode"] = { + "status": "failed", + "mediaType": self._media_type_from_mime(file_row.mime_type or storage_object.content_type), + "profileVersion": TRANSCODE_PROFILE_VERSION, + "optimizedBucketName": self.settings.object_storage_bucket, + "optimizedObjectKey": output_key, + "error": exc.message[:500], + "updatedAt": datetime.now(UTC).isoformat(), + } + metadata.extra_metadata = failed_extra + metadata.extracted_at = datetime.now(UTC) + await self.db.commit() + except Exception as exc: # noqa: BLE001 + logger.exception( + "Unexpected transcode enqueue failure but upload kept userId=%s fileId=%s objectId=%s", + user_id, + file_row.file_id, + storage_object.object_id, + ) + failed_extra = dict(metadata.extra_metadata or {}) + failed_extra["transcode"] = { + "status": "failed", + "mediaType": self._media_type_from_mime(file_row.mime_type or storage_object.content_type), + "profileVersion": TRANSCODE_PROFILE_VERSION, + "optimizedBucketName": self.settings.object_storage_bucket, + "optimizedObjectKey": output_key, + "error": f"{type(exc).__name__}: {exc}"[:500], + "updatedAt": datetime.now(UTC).isoformat(), + } + metadata.extra_metadata = failed_extra + metadata.extracted_at = datetime.now(UTC) + await self.db.commit() + + async def _resolve_storage_object_id( + self, + *, + storage_object: StorageObject, + file_row: File, + ) -> int | None: + resolved = self._coerce_positive_int(storage_object.object_id) or self._coerce_positive_int(file_row.storage_object_id) + if resolved is not None: + return resolved + + await self.db.flush() + resolved = self._coerce_positive_int(storage_object.object_id) or self._coerce_positive_int(file_row.storage_object_id) + return resolved + + @staticmethod + def _is_transcode_candidate(mime_type: str | None) -> bool: + value = (mime_type or "").lower() + return value.startswith("video/") or value.startswith("audio/") + + @staticmethod + def _media_type_from_mime(mime_type: str | None) -> str: + value = (mime_type or "").lower() + return "video" if value.startswith("video/") else "audio" + + @staticmethod + def _build_transcode_output_key(*, source_object_id: int, source_object_key: str, media_type: str) -> str: + suffix = ".mp4" if media_type == "video" else ".m4a" + original_stem = Path(source_object_key).stem or f"obj-{source_object_id}" + return ( + f"optimized/transcode/v1/object-{source_object_id}/" + f"{original_stem}-{TRANSCODE_PROFILE_VERSION}{suffix}" + ) + + @staticmethod + def _coerce_positive_int(value: object) -> int | None: + try: + parsed = int(value) + except (TypeError, ValueError): + return None + return parsed if parsed > 0 else None diff --git a/app/src/tasks/__init__.py b/app/src/fileflash/tasks/__init__.py similarity index 100% rename from app/src/tasks/__init__.py rename to app/src/fileflash/tasks/__init__.py diff --git a/app/src/tasks/archive.py b/app/src/fileflash/tasks/archive.py similarity index 100% rename from app/src/tasks/archive.py rename to app/src/fileflash/tasks/archive.py diff --git a/app/src/tasks/registry.py b/app/src/fileflash/tasks/registry.py similarity index 100% rename from app/src/tasks/registry.py rename to app/src/fileflash/tasks/registry.py diff --git a/app/src/tasks/scan.py b/app/src/fileflash/tasks/scan.py similarity index 100% rename from app/src/tasks/scan.py rename to app/src/fileflash/tasks/scan.py diff --git a/app/src/fileflash/tasks/transcode.py b/app/src/fileflash/tasks/transcode.py new file mode 100644 index 0000000..31632a9 --- /dev/null +++ b/app/src/fileflash/tasks/transcode.py @@ -0,0 +1,325 @@ +from __future__ import annotations + +import asyncio +import json +import subprocess +import tempfile +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +from ..core.settings import get_settings +from ..s3.minio_client import MinioObjectStorageClient + +TRANSCODE_PROFILE_VERSION = "mp4-v1" + + +@dataclass(slots=True) +class TranscodeTaskPayload: + source_bucket_name: str + source_object_key: str + source_object_id: int + output_bucket_name: str + output_object_key: str + file_id: int | None + requested_by: int | None + ffmpeg_binary: str + ffprobe_binary: str + timeout_seconds: int + probe_timeout_seconds: int + + +def run_media_transcode(payload: dict[str, Any] | Any) -> dict[str, Any]: + parsed = _parse_payload(payload) + settings = get_settings() + storage = MinioObjectStorageClient.from_settings(settings) + + with tempfile.TemporaryDirectory(prefix="fileflash-transcode-") as tmp_dir_raw: + tmp_dir = Path(tmp_dir_raw) + source_path = tmp_dir / "source" + source_suffix = Path(parsed.source_object_key).suffix.lower() + output_suffix = ".mp4" + + try: + _run_async( + storage.fget_object( + bucket_name=parsed.source_bucket_name, + object_key=parsed.source_object_key, + file_path=str(source_path), + ) + ) + except FileNotFoundError as exc: + raise FileNotFoundError( + f"Source object not found: {parsed.source_bucket_name}/{parsed.source_object_key}" + ) from exc + + source_probe = probe_media( + source_path, + ffprobe_binary=parsed.ffprobe_binary, + timeout_seconds=parsed.probe_timeout_seconds, + ) + media_type = detect_media_type(source_probe) + if media_type == "audio": + output_suffix = ".m4a" + output_path = tmp_dir / f"optimized{output_suffix}" + + ffmpeg_command = build_ffmpeg_command( + input_path=source_path, + output_path=output_path, + media_type=media_type, + ffmpeg_binary=parsed.ffmpeg_binary, + payload=payload, + ) + _run_command(ffmpeg_command, timeout_seconds=parsed.timeout_seconds) + + if not output_path.exists(): + raise RuntimeError(f"Transcode finished but output does not exist: {output_path}") + + upload_result = _run_async( + storage.fput_object( + bucket_name=parsed.output_bucket_name, + object_key=parsed.output_object_key, + file_path=str(output_path), + content_type="video/mp4" if media_type == "video" else "audio/mp4", + ) + ) + output_stat = _run_async( + storage.stat_object( + bucket_name=parsed.output_bucket_name, + object_key=parsed.output_object_key, + ) + ) + output_probe = probe_media( + output_path, + ffprobe_binary=parsed.ffprobe_binary, + timeout_seconds=parsed.probe_timeout_seconds, + ) + metadata = extract_media_metadata(output_probe) + + return { + "mediaType": media_type, + "sourceObjectId": parsed.source_object_id, + "sourceBucketName": parsed.source_bucket_name, + "sourceObjectKey": parsed.source_object_key, + "outputBucketName": parsed.output_bucket_name, + "outputObjectKey": parsed.output_object_key, + "outputObjectEtag": upload_result.etag or output_stat.etag, + "outputObjectVersionId": upload_result.version_id or output_stat.version_id, + "outputObjectSize": int(output_stat.size), + "optimizedMimeType": "video/mp4" if media_type == "video" else "audio/mp4", + "transcodeProfile": { + "version": TRANSCODE_PROFILE_VERSION, + "container": output_suffix.lstrip("."), + "videoCodec": _first_stream_codec(output_probe, "video"), + "audioCodec": _first_stream_codec(output_probe, "audio"), + "sourceExtension": source_suffix.lstrip("."), + }, + "metadata": metadata, + "transcodedAt": datetime.now(UTC).isoformat(), + } + + +def probe_media(input_path: Path, *, ffprobe_binary: str, timeout_seconds: int) -> dict[str, Any]: + command = [ + ffprobe_binary, + "-v", + "error", + "-show_streams", + "-show_format", + "-of", + "json", + str(input_path), + ] + result = _run_command(command, timeout_seconds=timeout_seconds) + try: + return json.loads(result.stdout or "{}") + except json.JSONDecodeError as exc: + raise RuntimeError(f"ffprobe JSON parse failed: {exc}") from exc + + +def detect_media_type(probe_data: dict[str, Any]) -> str: + streams = probe_data.get("streams", []) + if any(stream.get("codec_type") == "video" for stream in streams): + return "video" + if any(stream.get("codec_type") == "audio" for stream in streams): + return "audio" + raise ValueError("Input media does not contain video or audio stream") + + +def build_ffmpeg_command( + *, + input_path: Path, + output_path: Path, + media_type: str, + ffmpeg_binary: str, + payload: dict[str, Any] | Any, +) -> list[str]: + audio_bitrate = _coerce_positive_int(payload.get("audioBitrateKbps"), 128) + command: list[str] = [ffmpeg_binary, "-y", "-i", str(input_path)] + + if media_type == "video": + video_codec = str(payload.get("videoCodec") or "libx264") + audio_codec = str(payload.get("audioCodec") or "aac") + preset = str(payload.get("videoPreset") or "medium") + crf = _coerce_positive_int(payload.get("videoCrf"), 23) + command.extend( + [ + "-vf", + "scale=w=min(iw\\,1920):h=min(ih\\,1080):force_original_aspect_ratio=decrease," + "scale=trunc(iw/2)*2:trunc(ih/2)*2", + "-c:v", + video_codec, + "-preset", + preset, + "-crf", + str(crf), + "-pix_fmt", + "yuv420p", + "-movflags", + "+faststart", + "-c:a", + audio_codec, + "-b:a", + f"{audio_bitrate}k", + "-map", + "0:v:0", + "-map", + "0:a:0?", + ] + ) + else: + command.extend( + [ + "-vn", + "-c:a", + "aac", + "-b:a", + f"{audio_bitrate}k", + "-movflags", + "+faststart", + "-map", + "0:a:0", + ] + ) + + command.append(str(output_path)) + return command + + +def extract_media_metadata(probe_data: dict[str, Any]) -> dict[str, int | str | None]: + format_data = probe_data.get("format", {}) + streams = probe_data.get("streams", []) + + video_stream = next((stream for stream in streams if stream.get("codec_type") == "video"), None) + audio_stream = next((stream for stream in streams if stream.get("codec_type") == "audio"), None) + + duration_ms = _duration_ms_from_format(format_data) + return { + "durationMs": duration_ms, + "width": _safe_int(video_stream.get("width") if video_stream else None), + "height": _safe_int(video_stream.get("height") if video_stream else None), + "bitrate": _safe_int(format_data.get("bit_rate")), + "sampleRate": _safe_int(audio_stream.get("sample_rate") if audio_stream else None), + "videoCodec": _first_stream_codec(probe_data, "video"), + "audioCodec": _first_stream_codec(probe_data, "audio"), + } + + +def _parse_payload(payload: dict[str, Any] | Any) -> TranscodeTaskPayload: + source_bucket_name = str(payload.get("sourceBucketName") or "").strip() + source_object_key = str(payload.get("sourceObjectKey") or "").strip() + output_bucket_name = str(payload.get("outputBucketName") or source_bucket_name).strip() + output_object_key = str(payload.get("outputObjectKey") or "").strip() + source_object_id = _coerce_positive_int(payload.get("sourceObjectId"), 0) + if not source_bucket_name: + raise ValueError("Transcode payload requires sourceBucketName") + if not source_object_key: + raise ValueError("Transcode payload requires sourceObjectKey") + if not output_bucket_name: + raise ValueError("Transcode payload requires outputBucketName") + if not output_object_key: + raise ValueError("Transcode payload requires outputObjectKey") + if source_object_id <= 0: + raise ValueError("Transcode payload requires sourceObjectId") + + timeout_seconds = _coerce_positive_int(payload.get("timeoutSeconds"), 900) + probe_timeout_seconds = _coerce_positive_int( + payload.get("probeTimeoutSeconds"), + min(60, timeout_seconds), + ) + file_id = _safe_int(payload.get("fileId")) + requested_by = _safe_int(payload.get("requestedBy")) + + return TranscodeTaskPayload( + source_bucket_name=source_bucket_name, + source_object_key=source_object_key, + source_object_id=source_object_id, + output_bucket_name=output_bucket_name, + output_object_key=output_object_key, + file_id=file_id, + requested_by=requested_by, + ffmpeg_binary=str(payload.get("ffmpegBinary") or "ffmpeg"), + ffprobe_binary=str(payload.get("ffprobeBinary") or "ffprobe"), + timeout_seconds=timeout_seconds, + probe_timeout_seconds=probe_timeout_seconds, + ) + + +def _run_command(command: list[str], *, timeout_seconds: int) -> subprocess.CompletedProcess[str]: + try: + result = subprocess.run( + command, + check=False, + capture_output=True, + text=True, + timeout=timeout_seconds, + ) + except FileNotFoundError as exc: + raise RuntimeError(f"Binary not found for command: {command[0]}") from exc + if result.returncode != 0: + stderr = (result.stderr or "").strip() + raise RuntimeError(f"Command failed ({result.returncode}): {' '.join(command)} | {stderr}") + return result + + +def _run_async(awaitable: Any) -> Any: + return asyncio.run(awaitable) + + +def _safe_int(raw: Any) -> int | None: + if raw is None: + return None + try: + return int(raw) + except (TypeError, ValueError): + return None + + +def _coerce_positive_int(raw: Any, default: int) -> int: + try: + value = int(raw) + except (TypeError, ValueError): + return default + return value if value > 0 else default + + +def _duration_ms_from_format(format_data: dict[str, Any]) -> int | None: + raw_duration = format_data.get("duration") + if raw_duration in (None, ""): + return None + try: + seconds = float(raw_duration) + except (TypeError, ValueError): + return None + return int(seconds * 1000) + + +def _first_stream_codec(probe_data: dict[str, Any], codec_type: str) -> str | None: + streams = probe_data.get("streams", []) + for stream in streams: + if stream.get("codec_type") == codec_type: + codec_name = stream.get("codec_name") + if codec_name: + return str(codec_name) + return None diff --git a/app/src/workers/DESIGN.md b/app/src/fileflash/workers/DESIGN.md similarity index 97% rename from app/src/workers/DESIGN.md rename to app/src/fileflash/workers/DESIGN.md index 1e4a4b2..9b45911 100644 --- a/app/src/workers/DESIGN.md +++ b/app/src/fileflash/workers/DESIGN.md @@ -127,8 +127,8 @@ ## 13. 启动与部署建议 - 同仓库、分进程部署: - - API 进程:`uv run python -m src.main` - - Worker 进程:`uv run python -m src.workers.consumer` + - API 进程:`uv run python -m fileflash.main` + - Worker 进程:`uv run python -m fileflash.workers.consumer` - 多 worker 副本按队列水平扩容。 - 当前阶段属于分布式单体,不是严格微服务。 diff --git a/app/src/fileflash/workers/__init__.py b/app/src/fileflash/workers/__init__.py new file mode 100644 index 0000000..da4bd26 --- /dev/null +++ b/app/src/fileflash/workers/__init__.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from typing import Any + +__all__ = ["WorkerConsumer", "run_worker"] + + +def __getattr__(name: str) -> Any: + if name in __all__: + from .consumer import WorkerConsumer, run_worker + + exports = { + "WorkerConsumer": WorkerConsumer, + "run_worker": run_worker, + } + return exports[name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/app/src/workers/bootstrap.py b/app/src/fileflash/workers/bootstrap.py similarity index 100% rename from app/src/workers/bootstrap.py rename to app/src/fileflash/workers/bootstrap.py diff --git a/app/src/workers/consumer.py b/app/src/fileflash/workers/consumer.py similarity index 66% rename from app/src/workers/consumer.py rename to app/src/fileflash/workers/consumer.py index 3a0d089..8759de2 100644 --- a/app/src/workers/consumer.py +++ b/app/src/fileflash/workers/consumer.py @@ -10,12 +10,17 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from ..core import get_settings +from ..core.errors import ApiError from ..db.session import SessionLocal +from ..s3.minio_client import MinioObjectStorageClient +from ..services.background_jobs import BackgroundJobService from ..services.job_queue import RedisStreamJobQueue +from ..services.upload import UploadService from .bootstrap import WorkerRuntimeConfig, build_worker_runtime_config, create_process_pool from .contracts import WorkerJobMessage -from .dispatcher import execute_task +from .dispatcher import PicklableRemoteTaskError, execute_task from .effects import apply_task_effects +from .effects import mark_transcode_failed, mark_transcode_running from .repository import ( get_retry_delay_seconds, mark_job_failed_or_retrying, @@ -39,6 +44,13 @@ def __init__( self._executor = executor self._queue = queue self._session_factory = session_factory + self._settings = get_settings() + self._storage = MinioObjectStorageClient.from_settings(self._settings) + self._job_publisher = RedisStreamJobQueue( + redis_url=self._settings.redis_url, + stream_key=self._settings.worker_queue_stream, + ) + self._jobs = BackgroundJobService(queue_publisher=self._job_publisher) async def run(self) -> None: logger.info( @@ -77,9 +89,16 @@ async def _mark_running(self, message: WorkerJobMessage) -> WorkerJobMessage | N async def _process_message(self, *, slot: int, message: WorkerJobMessage) -> None: payload = dict(message.payload) + if payload.get("jobId") in (None, ""): + payload["jobId"] = message.job_id + if message.task_type == "task.upload_merge": + await self._process_upload_merge(slot=slot, message=message, payload=payload) + return if message.task_type in ("task.transcode", "media.transcode"): payload.setdefault("ffmpegBinary", self._config.ffmpeg_binary) payload.setdefault("ffprobeBinary", self._config.ffprobe_binary) + payload.setdefault("profileVersion", "mp4-v1") + await self._mark_transcode_running(payload) started_at = datetime.now(UTC) try: @@ -124,6 +143,52 @@ async def _process_message(self, *, slot: int, message: WorkerJobMessage) -> Non message.trace_id, ) + async def _process_upload_merge( + self, + *, + slot: int, + message: WorkerJobMessage, + payload: dict[str, Any], + ) -> None: + started_at = datetime.now(UTC) + try: + result = await asyncio.wait_for( + self._run_upload_merge(payload=payload), + timeout=self._config.task_timeout_seconds, + ) + except Exception as exc: + await self._handle_failure(slot=slot, message=message, error=exc) + return + + try: + async with self._session_factory() as db: + async with db.begin(): + await mark_job_succeeded(db, job_id=message.job_id, result=result) + except Exception as exc: + await self._handle_failure(slot=slot, message=message, error=exc) + return + + duration_ms = int((datetime.now(UTC) - started_at).total_seconds() * 1000) + logger.info( + "Worker slot=%s succeeded jobId=%s taskType=%s attempt=%s durationMs=%s traceId=%s", + slot, + message.job_id, + message.task_type, + message.attempt, + duration_ms, + message.trace_id, + ) + + async def _run_upload_merge(self, *, payload: dict[str, Any]) -> dict[str, Any]: + async with self._session_factory() as db: + service = UploadService( + db=db, + settings=self._settings, + storage=self._storage, + jobs=self._jobs, + ) + return await service.execute_merge_job(payload=payload) + async def _handle_failure( self, *, @@ -132,7 +197,10 @@ async def _handle_failure( error: Exception, ) -> None: retryable = _is_retryable_error(error) - error_message = f"{type(error).__name__}: {error}" + if isinstance(error, ApiError): + error_message = f"ApiError[{error.status_code}/{error.code}]: {error.message}" + else: + error_message = f"{type(error).__name__}: {error}" async with self._session_factory() as db: async with db.begin(): state = await mark_job_failed_or_retrying( @@ -142,6 +210,12 @@ async def _handle_failure( retryable=retryable, retry_backoff_seconds=self._config.retry_backoff_seconds, ) + if state == "failed" and message.task_type in ("task.transcode", "media.transcode"): + await mark_transcode_failed( + db, + payload=dict(message.payload), + error_message=error_message, + ) if state == "retrying": next_attempt = message.attempt + 1 @@ -182,8 +256,31 @@ async def _republish_after_delay(self, message: WorkerJobMessage, delay_seconds: await asyncio.sleep(delay_seconds) await self._queue.publish(message) + async def _mark_transcode_running(self, payload: dict[str, Any]) -> None: + async with self._session_factory() as db: + async with db.begin(): + await mark_transcode_running(db, payload=payload) + + async def aclose(self) -> None: + await self._job_publisher.aclose() + def _is_retryable_error(error: Exception) -> bool: + if isinstance(error, ApiError): + if error.status_code >= 500: + return True + if isinstance(error.data, dict) and error.data.get("retryable") is True: + return True + return False + + if isinstance(error, PicklableRemoteTaskError) and error.retryable_hint is not None: + return bool(error.retryable_hint) + + if isinstance(error, PicklableRemoteTaskError): + non_retryable_original_types = {"FileNotFoundError", "PermissionError", "ValueError"} + if error.original_type in non_retryable_original_types: + return False + non_retryable_types = (FileNotFoundError, PermissionError, ValueError) return not isinstance(error, non_retryable_types) @@ -217,6 +314,7 @@ async def run_worker() -> None: try: await consumer.run() finally: + await consumer.aclose() await queue.aclose() diff --git a/app/src/workers/contracts.py b/app/src/fileflash/workers/contracts.py similarity index 100% rename from app/src/workers/contracts.py rename to app/src/fileflash/workers/contracts.py diff --git a/app/src/fileflash/workers/dispatcher.py b/app/src/fileflash/workers/dispatcher.py new file mode 100644 index 0000000..266567f --- /dev/null +++ b/app/src/fileflash/workers/dispatcher.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from ..tasks import dispatch_task + + +class PicklableRemoteTaskError(Exception): + def __init__( + self, + *, + original_type: str, + message: str, + retryable_hint: bool | None = None, + ) -> None: + self.original_type = original_type + self.message = message + self.retryable_hint = retryable_hint + super().__init__(f"{original_type}: {message}") + + def __reduce__(self): + return ( + self.__class__._reconstruct, + (self.original_type, self.message, self.retryable_hint), + ) + + @classmethod + def _reconstruct( + cls, + original_type: str, + message: str, + retryable_hint: bool | None, + ) -> "PicklableRemoteTaskError": + return cls( + original_type=original_type, + message=message, + retryable_hint=retryable_hint, + ) + + +def _is_picklable_exception(exc: Exception) -> bool: + try: + import pickle + + pickle.dumps(exc) + except Exception: + return False + return True + + +def execute_task(task_type: str, payload: Mapping[str, Any]) -> dict[str, Any]: + try: + return dispatch_task(task_type=task_type, payload=payload) + except Exception as exc: + if _is_picklable_exception(exc): + raise + + original_type = type(exc).__name__ + message = str(exc) or repr(exc) + retryable_hint = None + if isinstance(exc, (FileNotFoundError, PermissionError, ValueError)): + retryable_hint = False + raise PicklableRemoteTaskError( + original_type=original_type, + message=message, + retryable_hint=retryable_hint, + ) from None diff --git a/app/src/workers/effects.py b/app/src/fileflash/workers/effects.py similarity index 85% rename from app/src/workers/effects.py rename to app/src/fileflash/workers/effects.py index 8048c40..74e9d6c 100644 --- a/app/src/workers/effects.py +++ b/app/src/fileflash/workers/effects.py @@ -71,11 +71,12 @@ async def _apply_transcode_effects( payload: dict[str, Any], result: dict[str, Any], ) -> None: - object_id = _coerce_int(payload.get("objectId")) + object_id = _coerce_int(payload.get("sourceObjectId")) or _coerce_int(payload.get("objectId")) if object_id is None: return metadata = result.get("metadata") or {} + now = datetime.now(UTC) row = await db.scalar( select(FileMediaMetadata).where(FileMediaMetadata.source_object_id == object_id).limit(1) ) @@ -90,13 +91,90 @@ async def _apply_transcode_effects( row.sample_rate = _coerce_int(metadata.get("sampleRate")) row.video_codec = _truncate(metadata.get("videoCodec"), 64) row.audio_codec = _truncate(metadata.get("audioCodec"), 64) - row.extra_metadata = { + extra = dict(row.extra_metadata or {}) + extra["transcodeProfile"] = result.get("transcodeProfile") or {} + extra["transcode"] = { + "status": "ready", "mediaType": result.get("mediaType"), - "inputPath": result.get("inputPath"), - "outputPath": result.get("outputPath"), - "transcodeProfile": result.get("transcodeProfile") or {}, + "profileVersion": (result.get("transcodeProfile") or {}).get("version"), + "optimizedMimeType": result.get("optimizedMimeType"), + "optimizedBucketName": result.get("outputBucketName"), + "optimizedObjectKey": result.get("outputObjectKey"), + "outputObjectEtag": result.get("outputObjectEtag"), + "outputObjectVersionId": result.get("outputObjectVersionId"), + "outputObjectSize": _coerce_int(result.get("outputObjectSize")), + "updatedAt": now.isoformat(), } - row.extracted_at = datetime.now(UTC) + row.extra_metadata = extra + row.extracted_at = now + + +async def mark_transcode_running( + db: AsyncSession, + *, + payload: dict[str, Any], +) -> None: + object_id = _coerce_int(payload.get("sourceObjectId")) or _coerce_int(payload.get("objectId")) + if object_id is None: + return + + row = await db.scalar( + select(FileMediaMetadata).where(FileMediaMetadata.source_object_id == object_id).limit(1) + ) + if row is None: + row = FileMediaMetadata(source_object_id=object_id) + db.add(row) + + now = datetime.now(UTC) + extra = dict(row.extra_metadata or {}) + transcode = dict(extra.get("transcode") or {}) + transcode["status"] = "running" + transcode["mediaType"] = transcode.get("mediaType") or payload.get("mediaType") + transcode["updatedAt"] = now.isoformat() + if payload.get("outputBucketName"): + transcode["optimizedBucketName"] = payload.get("outputBucketName") + if payload.get("outputObjectKey"): + transcode["optimizedObjectKey"] = payload.get("outputObjectKey") + if payload.get("profileVersion"): + transcode["profileVersion"] = payload.get("profileVersion") + extra["transcode"] = transcode + row.extra_metadata = extra + row.extracted_at = now + + +async def mark_transcode_failed( + db: AsyncSession, + *, + payload: dict[str, Any], + error_message: str, +) -> None: + object_id = _coerce_int(payload.get("sourceObjectId")) or _coerce_int(payload.get("objectId")) + if object_id is None: + return + + row = await db.scalar( + select(FileMediaMetadata).where(FileMediaMetadata.source_object_id == object_id).limit(1) + ) + if row is None: + row = FileMediaMetadata(source_object_id=object_id) + db.add(row) + + now = datetime.now(UTC) + extra = dict(row.extra_metadata or {}) + transcode = dict(extra.get("transcode") or {}) + transcode["status"] = "failed" + transcode["mediaType"] = transcode.get("mediaType") or payload.get("mediaType") + transcode["error"] = _truncate(error_message, 500) + transcode["updatedAt"] = now.isoformat() + if payload.get("outputBucketName"): + transcode["optimizedBucketName"] = payload.get("outputBucketName") + if payload.get("outputObjectKey"): + transcode["optimizedObjectKey"] = payload.get("outputObjectKey") + if payload.get("profileVersion"): + transcode["profileVersion"] = payload.get("profileVersion") + extra["transcode"] = transcode + row.extra_metadata = extra + row.extracted_at = now async def _apply_archive_extract_effects( diff --git a/app/src/workers/repository.py b/app/src/fileflash/workers/repository.py similarity index 97% rename from app/src/workers/repository.py rename to app/src/fileflash/workers/repository.py index cd3cd6c..39980b8 100644 --- a/app/src/workers/repository.py +++ b/app/src/fileflash/workers/repository.py @@ -4,6 +4,7 @@ from datetime import UTC, datetime, timedelta from typing import Any +from fastapi.encoders import jsonable_encoder from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -60,7 +61,7 @@ async def mark_job_succeeded( if job is None: return job.status = "succeeded" - job.result = result + job.result = jsonable_encoder(result) job.error_message = None job.finished_at = now job.updated_at = now diff --git a/app/src/tasks/transcode.py b/app/src/tasks/transcode.py deleted file mode 100644 index 216e0ed..0000000 --- a/app/src/tasks/transcode.py +++ /dev/null @@ -1,239 +0,0 @@ -from __future__ import annotations - -import json -import subprocess -from datetime import UTC, datetime -from pathlib import Path -from typing import Any - - -def run_media_transcode(payload: dict[str, Any] | Any) -> dict[str, Any]: - input_path = _resolve_input_path(payload) - ffmpeg_binary = str(payload.get("ffmpegBinary") or "ffmpeg") - ffprobe_binary = str(payload.get("ffprobeBinary") or "ffprobe") - timeout_seconds = _coerce_positive_int(payload.get("timeoutSeconds"), 900) - probe_timeout_seconds = _coerce_positive_int( - payload.get("probeTimeoutSeconds"), - min(60, timeout_seconds), - ) - - source_probe = probe_media( - input_path, - ffprobe_binary=ffprobe_binary, - timeout_seconds=probe_timeout_seconds, - ) - media_type = detect_media_type(source_probe) - output_path = resolve_output_path( - input_path=input_path, - media_type=media_type, - raw_output_path=payload.get("outputPath"), - raw_target_container=payload.get("targetContainer"), - ) - output_path.parent.mkdir(parents=True, exist_ok=True) - - ffmpeg_command = build_ffmpeg_command( - input_path=input_path, - output_path=output_path, - media_type=media_type, - ffmpeg_binary=ffmpeg_binary, - payload=payload, - ) - _run_command(ffmpeg_command, timeout_seconds=timeout_seconds) - - if not output_path.exists(): - raise RuntimeError(f"Transcode finished but output does not exist: {output_path}") - - output_probe = probe_media( - output_path, - ffprobe_binary=ffprobe_binary, - timeout_seconds=probe_timeout_seconds, - ) - metadata = extract_media_metadata(output_probe) - return { - "mediaType": media_type, - "inputPath": str(input_path), - "outputPath": str(output_path), - "transcodeProfile": { - "container": output_path.suffix.lower().lstrip("."), - "videoCodec": _first_stream_codec(output_probe, "video"), - "audioCodec": _first_stream_codec(output_probe, "audio"), - }, - "metadata": metadata, - "transcodedAt": datetime.now(UTC).isoformat(), - } - - -def probe_media(input_path: Path, *, ffprobe_binary: str, timeout_seconds: int) -> dict[str, Any]: - command = [ - ffprobe_binary, - "-v", - "error", - "-show_streams", - "-show_format", - "-of", - "json", - str(input_path), - ] - result = _run_command(command, timeout_seconds=timeout_seconds) - try: - return json.loads(result.stdout or "{}") - except json.JSONDecodeError as exc: - raise RuntimeError(f"ffprobe JSON parse failed: {exc}") from exc - - -def detect_media_type(probe_data: dict[str, Any]) -> str: - streams = probe_data.get("streams", []) - if any(stream.get("codec_type") == "video" for stream in streams): - return "video" - if any(stream.get("codec_type") == "audio" for stream in streams): - return "audio" - raise ValueError("Input media does not contain video or audio stream") - - -def resolve_output_path( - *, - input_path: Path, - media_type: str, - raw_output_path: Any, - raw_target_container: Any, -) -> Path: - if raw_output_path: - return Path(str(raw_output_path)).expanduser() - - if raw_target_container: - suffix = "." + str(raw_target_container).strip().lstrip(".").lower() - elif media_type == "video": - suffix = ".mp4" - else: - suffix = ".m4a" - return input_path.with_suffix(suffix) - - -def build_ffmpeg_command( - *, - input_path: Path, - output_path: Path, - media_type: str, - ffmpeg_binary: str, - payload: dict[str, Any] | Any, -) -> list[str]: - audio_bitrate = _coerce_positive_int(payload.get("audioBitrateKbps"), 128) - command: list[str] = [ffmpeg_binary, "-y", "-i", str(input_path)] - - if media_type == "video": - video_codec = str(payload.get("videoCodec") or "libx264") - audio_codec = str(payload.get("audioCodec") or "aac") - preset = str(payload.get("videoPreset") or "medium") - crf = _coerce_positive_int(payload.get("videoCrf"), 23) - command.extend( - [ - "-c:v", - video_codec, - "-preset", - preset, - "-crf", - str(crf), - "-movflags", - "+faststart", - "-c:a", - audio_codec, - "-b:a", - f"{audio_bitrate}k", - ] - ) - else: - preferred_codec = payload.get("audioCodec") - if preferred_codec: - audio_codec = str(preferred_codec) - elif output_path.suffix.lower() == ".mp3": - audio_codec = "libmp3lame" - else: - audio_codec = "aac" - command.extend(["-vn", "-c:a", audio_codec, "-b:a", f"{audio_bitrate}k"]) - - command.append(str(output_path)) - return command - - -def extract_media_metadata(probe_data: dict[str, Any]) -> dict[str, int | str | None]: - format_data = probe_data.get("format", {}) - streams = probe_data.get("streams", []) - - video_stream = next((stream for stream in streams if stream.get("codec_type") == "video"), None) - audio_stream = next((stream for stream in streams if stream.get("codec_type") == "audio"), None) - - duration_ms = _duration_ms_from_format(format_data) - return { - "durationMs": duration_ms, - "width": _safe_int(video_stream.get("width") if video_stream else None), - "height": _safe_int(video_stream.get("height") if video_stream else None), - "bitrate": _safe_int(format_data.get("bit_rate")), - "sampleRate": _safe_int(audio_stream.get("sample_rate") if audio_stream else None), - "videoCodec": _first_stream_codec(probe_data, "video"), - "audioCodec": _first_stream_codec(probe_data, "audio"), - } - - -def _resolve_input_path(payload: dict[str, Any] | Any) -> Path: - raw_input = str(payload.get("inputPath") or payload.get("localPath") or "").strip() - if not raw_input: - raise ValueError("Transcode payload requires inputPath or localPath") - input_path = Path(raw_input).expanduser() - if not input_path.exists() or not input_path.is_file(): - raise FileNotFoundError(f"Transcode input not found: {input_path}") - return input_path - - -def _run_command(command: list[str], *, timeout_seconds: int) -> subprocess.CompletedProcess[str]: - try: - result = subprocess.run( - command, - check=False, - capture_output=True, - text=True, - timeout=timeout_seconds, - ) - except FileNotFoundError as exc: - raise RuntimeError(f"Binary not found for command: {command[0]}") from exc - if result.returncode != 0: - stderr = (result.stderr or "").strip() - raise RuntimeError(f"Command failed ({result.returncode}): {' '.join(command)} | {stderr}") - return result - - -def _safe_int(raw: Any) -> int | None: - if raw is None: - return None - try: - return int(raw) - except (TypeError, ValueError): - return None - - -def _coerce_positive_int(raw: Any, default: int) -> int: - try: - value = int(raw) - except (TypeError, ValueError): - return default - return value if value > 0 else default - - -def _duration_ms_from_format(format_data: dict[str, Any]) -> int | None: - raw_duration = format_data.get("duration") - if raw_duration in (None, ""): - return None - try: - seconds = float(raw_duration) - except (TypeError, ValueError): - return None - return int(seconds * 1000) - - -def _first_stream_codec(probe_data: dict[str, Any], codec_type: str) -> str | None: - streams = probe_data.get("streams", []) - for stream in streams: - if stream.get("codec_type") == codec_type: - codec_name = stream.get("codec_name") - if codec_name: - return str(codec_name) - return None diff --git a/app/src/workers/__init__.py b/app/src/workers/__init__.py deleted file mode 100644 index f5ed699..0000000 --- a/app/src/workers/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .consumer import WorkerConsumer, run_worker - -__all__ = ["WorkerConsumer", "run_worker"] diff --git a/app/src/workers/dispatcher.py b/app/src/workers/dispatcher.py deleted file mode 100644 index aaa869b..0000000 --- a/app/src/workers/dispatcher.py +++ /dev/null @@ -1,10 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from typing import Any - -from ..tasks import dispatch_task - - -def execute_task(task_type: str, payload: Mapping[str, Any]) -> dict[str, Any]: - return dispatch_task(task_type=task_type, payload=payload) diff --git a/app/tests/test_admin_registration_email_domain_rule_routes.py b/app/tests/test_admin_registration_email_domain_rule_routes.py new file mode 100644 index 0000000..475224a --- /dev/null +++ b/app/tests/test_admin_registration_email_domain_rule_routes.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace + +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from fileflash.core.deps import get_registration_email_domain_rule_service, require_admin +from fileflash.core.errors import ApiError, api_error_handler +from fileflash.schemas.registration_email_domain_rule import RegistrationEmailDomainRuleItem +from fileflash.routers.admin_registration_email_domain_rules import router as admin_router + + +class StubRuleService: + async def list_rules(self, *, query): # noqa: ANN001 + return SimpleNamespace( + model_dump=lambda **_: { + "items": [], + "pagination": { + "totalItems": 0, + "totalPages": 1, + "perPage": query.per_page, + "currentPage": query.page, + "hasPrev": False, + "hasNext": False, + }, + } + ) + + async def create_rule(self, *, payload): # noqa: ANN001 + return RegistrationEmailDomainRuleItem( + rule_id="1", + name=payload.name, + pattern=payload.pattern, + enabled=payload.enabled, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + + async def update_rule(self, *, rule_id: int, payload): # noqa: ANN001 + return RegistrationEmailDomainRuleItem( + rule_id=str(rule_id), + name=payload.name or "existing", + pattern=payload.pattern or r".*", + enabled=True if payload.enabled is None else payload.enabled, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + + async def delete_rule(self, *, rule_id: int): # noqa: ANN001 + _ = rule_id + + +def _build_client(admin: bool) -> TestClient: + app = FastAPI() + app.add_exception_handler(ApiError, api_error_handler) + app.include_router(admin_router, prefix="/api/v1") + + app.dependency_overrides[get_registration_email_domain_rule_service] = lambda: StubRuleService() + if admin: + app.dependency_overrides[require_admin] = lambda: SimpleNamespace(user_id=1, role="admin") + else: + async def _deny(): + raise ApiError(status_code=403, code=403, message="Admin access required") + app.dependency_overrides[require_admin] = _deny + + return TestClient(app) + + +def test_admin_can_list_rules() -> None: + with _build_client(admin=True) as client: + response = client.get("/api/v1/admin/registration-email-domain-rules") + + assert response.status_code == 200 + payload = response.json() + assert payload["success"] is True + assert payload["data"]["items"] == [] + assert payload["data"]["pagination"]["totalItems"] == 0 + + +def test_non_admin_forbidden() -> None: + with _build_client(admin=False) as client: + response = client.get("/api/v1/admin/registration-email-domain-rules") + + assert response.status_code == 403 + payload = response.json() + assert payload["success"] is False + assert payload["message"] == "Admin access required" + + +def test_admin_create_and_delete_rule() -> None: + with _build_client(admin=True) as client: + create_response = client.post( + "/api/v1/admin/registration-email-domain-rules", + json={"name": "corp", "pattern": r".*\.corp\.com", "enabled": True}, + ) + delete_response = client.delete("/api/v1/admin/registration-email-domain-rules/1") + + assert create_response.status_code == 201 + create_payload = create_response.json() + assert create_payload["success"] is True + assert create_payload["data"]["name"] == "corp" + assert create_payload["data"]["pattern"] == r".*\.corp\.com" + + assert delete_response.status_code == 200 + delete_payload = delete_response.json() + assert delete_payload["success"] is True + assert delete_payload["data"]["ruleId"] == "1" diff --git a/app/tests/test_agent_repositories.py b/app/tests/test_agent_repositories.py index d75a5a0..129d9c6 100644 --- a/app/tests/test_agent_repositories.py +++ b/app/tests/test_agent_repositories.py @@ -6,9 +6,9 @@ import pytest -from src.agents import PlanRunner, PromptBuilder -from src.models import AgentWorkSession, BackgroundJob -from src.repositories import ( +from fileflash.agents import PlanRunner, PromptBuilder +from fileflash.models import AgentWorkSession, BackgroundJob +from fileflash.repositories import ( AgentActionLogRepository, AgentMcpRepository, AgentMemoryRepository, @@ -17,7 +17,7 @@ AgentSkillRepository, AgentWorkSessionRepository, ) -from src.schemas.job import to_background_job_response +from fileflash.schemas.job import to_background_job_response class FakeMappingResult: diff --git a/app/tests/test_agent_skill_service.py b/app/tests/test_agent_skill_service.py index 9197aef..79d8ffe 100644 --- a/app/tests/test_agent_skill_service.py +++ b/app/tests/test_agent_skill_service.py @@ -5,12 +5,12 @@ import pytest -from src.core.errors import ApiError -from src.models import AgentSkill -from src.models.enums import AgentSkillVisibility -from src.repositories import AgentSkillRepository -from src.schemas.agent_skill import CreateAgentSkillRequest, ImportAgentSkillItem, ImportAgentSkillsRequest, UpdateAgentSkillRequest -from src.services.agent.skill_service import SkillService +from fileflash.core.errors import ApiError +from fileflash.models import AgentSkill +from fileflash.models.enums import AgentSkillVisibility +from fileflash.repositories import AgentSkillRepository +from fileflash.schemas.agent_skill import CreateAgentSkillRequest, ImportAgentSkillItem, ImportAgentSkillsRequest, UpdateAgentSkillRequest +from fileflash.services.agent.skill_service import SkillService class DummySession: diff --git a/app/tests/test_api_errors.py b/app/tests/test_api_errors.py index ae5dbf8..799134a 100644 --- a/app/tests/test_api_errors.py +++ b/app/tests/test_api_errors.py @@ -6,7 +6,7 @@ from starlette.requests import Request -from src.core.errors import ApiError, api_error_handler, api_success +from fileflash.core.errors import ApiError, api_error_handler, api_success def _new_request() -> Request: diff --git a/app/tests/test_auth_messaging_email_delivery.py b/app/tests/test_auth_messaging_email_delivery.py new file mode 100644 index 0000000..646f9d8 --- /dev/null +++ b/app/tests/test_auth_messaging_email_delivery.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import logging + +import pytest + +from fileflash.services.messaging import InProcessAuthEventPublisher + + +@pytest.mark.asyncio +async def test_publish_non_email_event_skips_delivery() -> None: + publisher = InProcessAuthEventPublisher() + + await publisher.publish("auth.user_logged_in", {"userId": "1"}) + + +@pytest.mark.asyncio +async def test_publish_event_log_does_not_expose_payload(caplog: pytest.LogCaptureFixture) -> None: + publisher = InProcessAuthEventPublisher() + sensitive_payload = { + "email": "demo@example.com", + "token": "verify-token-1234567890", + "verificationToken": "verify-token-1234567890", + } + + with caplog.at_level(logging.INFO): + await publisher.publish( + "auth.email_verification_resent", + sensitive_payload, + ) + + assert "Auth event published in-process: auth.email_verification_resent" in caplog.text + assert "verify-token-1234567890" not in caplog.text + assert "demo@example.com" not in caplog.text diff --git a/app/tests/test_auth_register_domain_rules.py b/app/tests/test_auth_register_domain_rules.py new file mode 100644 index 0000000..23e5432 --- /dev/null +++ b/app/tests/test_auth_register_domain_rules.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import AsyncMock, Mock + +import pytest + +from fileflash.core.errors import ApiError +from fileflash.core.settings import Settings +from fileflash.schemas.user import User, UserPreference +from fileflash.schemas.auth import RegisterRequest +from fileflash.services.auth import AuthService + + +class DummySession: + def __init__(self) -> None: + self.add = Mock() + self.commit = AsyncMock() + self.flush = AsyncMock() + self.scalar = AsyncMock() + self.scalars = AsyncMock() + self.get = AsyncMock() + self.refresh = AsyncMock() + + +def make_settings(**overrides: object) -> Settings: + payload = { + "FF_DB_URI": "postgresql://root:pwd@localhost:5432/fileflash", + "JWT_SECRET_KEY": "unit-test-secret-key-1234567890abcd", + } + payload.update(overrides) + return Settings(**payload) + + +def make_service(session: DummySession) -> AuthService: + event_publisher = SimpleNamespace(publish=AsyncMock()) + rate_limiter = SimpleNamespace(allow=AsyncMock(return_value=True)) + verification_email_delivery = SimpleNamespace(send_verification_email=AsyncMock()) + return AuthService( + db=session, # type: ignore[arg-type] + settings=make_settings(), + rate_limiter=rate_limiter, + event_publisher=event_publisher, + verification_email_delivery=verification_email_delivery, # type: ignore[arg-type] + ) + + +@pytest.mark.asyncio +async def test_register_rejects_when_no_enabled_rules() -> None: + session = DummySession() + service = make_service(session) + session.scalar = AsyncMock(return_value=None) + session.scalars = AsyncMock(return_value=[]) + + with pytest.raises(ApiError, match="邮箱后缀不被允许,请更换邮箱"): + await service.register( + RegisterRequest(username="new", email="new@example.com", password="123456"), + client_ip="127.0.0.1", + user_agent="pytest", + ) + + +@pytest.mark.asyncio +async def test_register_accepts_matching_rule() -> None: + session = DummySession() + service = make_service(session) + session.scalar = AsyncMock(return_value=None) + session.scalars = AsyncMock( + return_value=[SimpleNamespace(pattern=r".*\.example\.com", enabled=True)] + ) + service._to_user_schema = Mock( + return_value=User( + user_id="1", + username="new", + email="new@dept.example.com", + storage_limit=1024, + storage_used=0, + created_at=datetime.now(UTC), + role="user", + status="active", + email_verified=False, + email_verified_at=None, + preference=UserPreference(language="zh-CN"), + ) + ) # type: ignore[method-assign] + service._create_email_verification_token = AsyncMock(return_value="verify-token") # type: ignore[method-assign] + + result = await service.register( + RegisterRequest(username="new", email="new@dept.example.com", password="123456"), + client_ip="127.0.0.1", + user_agent="pytest", + ) + + assert result.email_verification_required is True + session.commit.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_register_rejects_non_matching_rule() -> None: + session = DummySession() + service = make_service(session) + session.scalar = AsyncMock(return_value=None) + session.scalars = AsyncMock( + return_value=[SimpleNamespace(pattern=r".*\.corp\.com", enabled=True)] + ) + + with pytest.raises(ApiError, match="邮箱后缀不被允许,请更换邮箱"): + await service.register( + RegisterRequest(username="new", email="new@example.com", password="123456"), + client_ip="127.0.0.1", + user_agent="pytest", + ) diff --git a/app/tests/test_background_jobs_service.py b/app/tests/test_background_jobs_service.py index e52d473..05c64be 100644 --- a/app/tests/test_background_jobs_service.py +++ b/app/tests/test_background_jobs_service.py @@ -7,8 +7,8 @@ import pytest from sqlalchemy.exc import IntegrityError -from src.models import BackgroundJob -from src.services.background_jobs import BackgroundJobService +from fileflash.models import BackgroundJob +from fileflash.services.background_jobs import BackgroundJobService, _build_queue_message class _PgUniqueViolation(Exception): @@ -91,3 +91,57 @@ async def test_enqueue_recovers_from_unique_conflict_and_returns_existing_job(): assert job is second_existing session.rollback.assert_awaited_once() queue.publish.assert_not_awaited() + + +def test_build_queue_message_injects_job_id_when_missing_or_none(): + base_kwargs = dict( + job_id=42, + task_type="task.archive_extract", + status="pending", + result={}, + error_message=None, + attempt=0, + max_attempts=5, + scheduled_at=datetime.now(UTC), + ) + + missing = BackgroundJob(payload={"targetFolderId": "root"}, requested_by=9, **base_kwargs) + missing_message = _build_queue_message(missing) + assert missing_message.payload["jobId"] == 42 + assert missing_message.payload["requestedBy"] == 9 + + none_job_id = BackgroundJob(payload={"jobId": None, "targetFolderId": "root"}, requested_by=9, **base_kwargs) + none_message = _build_queue_message(none_job_id) + assert none_message.payload["jobId"] == 42 + + keep_existing = BackgroundJob(payload={"jobId": 777, "targetFolderId": "root"}, requested_by=9, **base_kwargs) + keep_message = _build_queue_message(keep_existing) + assert keep_message.payload["jobId"] == 777 + + +@pytest.mark.asyncio +async def test_enqueue_transcode_job_uses_object_storage_payload(): + session = DummySession() + queue = SimpleNamespace(publish=AsyncMock(return_value="1-0")) + service = BackgroundJobService(queue_publisher=queue) + + job = await service.enqueue_transcode_job( + session, # type: ignore[arg-type] + source_bucket_name="fileflash", + source_object_key="objects/u1/src.mp4", + source_object_id=101, + output_bucket_name="fileflash", + output_object_key="optimized/transcode/v1/object-101/src-mp4-v1.mp4", + file_id=999, + requested_by=1, + idempotency_key="object:101:transcode:mp4-v1", + ) + + payload = job.payload + assert payload["sourceBucketName"] == "fileflash" + assert payload["sourceObjectKey"] == "objects/u1/src.mp4" + assert payload["sourceObjectId"] == 101 + assert payload["outputBucketName"] == "fileflash" + assert payload["outputObjectKey"].endswith(".mp4") + assert payload["fileId"] == 999 + assert payload["requestedBy"] == 1 diff --git a/app/tests/test_datetime_type.py b/app/tests/test_datetime_type.py index 89d2afe..2d068a5 100644 --- a/app/tests/test_datetime_type.py +++ b/app/tests/test_datetime_type.py @@ -2,7 +2,7 @@ from datetime import UTC, datetime -from src.models.types import UTCDateTime +from fileflash.models.types import UTCDateTime def test_utc_datetime_type_normalizes_bind_and_result_values(): diff --git a/app/tests/test_db_engine_schema_check.py b/app/tests/test_db_engine_schema_check.py new file mode 100644 index 0000000..4b47e19 --- /dev/null +++ b/app/tests/test_db_engine_schema_check.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +from fileflash.db import engine as engine_module + + +class _DummyResult: + def __init__(self, value: int | None) -> None: + self._value = value + + def scalar(self) -> int | None: + return self._value + + +class _DummyConnection: + def __init__(self, results: list[int | None]) -> None: + self._results = list(results) + + async def execute(self, *_args: Any, **_kwargs: Any) -> _DummyResult: + if not self._results: + return _DummyResult(None) + return _DummyResult(self._results.pop(0)) + + +class _DummyConnectContext: + def __init__(self, connection: _DummyConnection) -> None: + self._connection = connection + + async def __aenter__(self) -> _DummyConnection: + return self._connection + + async def __aexit__(self, _exc_type, _exc, _tb) -> bool: + return False + + +class _DummyEngine: + def __init__(self, results: list[int | None]) -> None: + self._results = results + + def connect(self) -> _DummyConnectContext: + return _DummyConnectContext(_DummyConnection(self._results)) + + +@pytest.mark.asyncio +async def test_verify_schema_compatibility_fails_when_domain_rule_table_missing(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(engine_module, "engine", _DummyEngine([1, None])) + + with pytest.raises(RuntimeError, match="registration_email_domain_rule"): + await engine_module.verify_schema_compatibility() diff --git a/app/tests/test_db_transaction.py b/app/tests/test_db_transaction.py index 5af2677..8926f00 100644 --- a/app/tests/test_db_transaction.py +++ b/app/tests/test_db_transaction.py @@ -5,7 +5,7 @@ import pytest -from src.db.transaction import ( +from fileflash.db.transaction import ( is_retryable_database_error, is_unique_violation_error, run_with_transaction_retry, diff --git a/app/tests/test_dev_seed.py b/app/tests/test_dev_seed.py index e82ddf8..2ac47c2 100644 --- a/app/tests/test_dev_seed.py +++ b/app/tests/test_dev_seed.py @@ -4,9 +4,9 @@ import pytest -from src.core.settings import Settings -from src.models.tables_identity import User -from src.services.dev_seed import DevAccountSeeder, initialize_dev_accounts +from fileflash.core.settings import Settings +from fileflash.models.tables_identity import User +from fileflash.services.dev_seed import DevAccountSeeder, initialize_dev_accounts class DummySeedSession: @@ -97,7 +97,7 @@ async def test_dev_account_seeder_is_idempotent_and_supports_password_reset(): @pytest.mark.asyncio async def test_initialize_dev_accounts_skips_auto_run_in_production(monkeypatch: pytest.MonkeyPatch): guard = AsyncMock(side_effect=AssertionError("SessionLocal should not be called")) - monkeypatch.setattr("src.services.dev_seed.SessionLocal", guard) + monkeypatch.setattr("fileflash.services.dev_seed.SessionLocal", guard) result = await initialize_dev_accounts( settings=make_settings(APP_ENV="production"), diff --git a/app/tests/test_effects_archive_extract.py b/app/tests/test_effects_archive_extract.py index 5bb171e..42b248c 100644 --- a/app/tests/test_effects_archive_extract.py +++ b/app/tests/test_effects_archive_extract.py @@ -6,7 +6,7 @@ import pytest -from src.workers.effects import _apply_archive_extract_effects +from fileflash.workers.effects import _apply_archive_extract_effects class _StubSession: diff --git a/app/tests/test_file_download_recycle_service.py b/app/tests/test_file_download_recycle_service.py index 23bd5ea..cf5ec2c 100644 --- a/app/tests/test_file_download_recycle_service.py +++ b/app/tests/test_file_download_recycle_service.py @@ -5,20 +5,30 @@ import pytest -from src.core.errors import ApiError -from src.models.enums import FileStatus, FolderStatus, FolderType, UploadStatus -from src.models.tables_storage import File, Folder, StorageObject -from src.schemas.file import BatchFilesRequest -from src.services.file import FileService +from fileflash.core.errors import ApiError +from fileflash.models.enums import FileStatus, FolderStatus, FolderType, UploadStatus +from fileflash.models.tables_storage import File, FileMediaMetadata, Folder, StorageObject +from fileflash.schemas.file import BatchFilesRequest +from fileflash.services.file import FileService class DummyStorage: - async def iter_object(self, *, object_key: str): # noqa: ARG002 + async def iter_object(self, *, object_key: str, bucket_name: str | None = None): # noqa: ARG002 yield b"abcdefghij" - async def iter_object_range(self, *, object_key: str, start: int, end: int): # noqa: ARG002 + async def iter_object_range( + self, + *, + object_key: str, + start: int, + end: int, + bucket_name: str | None = None, + ): # noqa: ARG002 yield bytes(range(start, end + 1)) + async def object_exists(self, *, bucket_name: str, object_key: str): # noqa: ARG002 + return False + class DummySession: def __init__(self) -> None: @@ -275,6 +285,95 @@ async def test_get_preview_stream_rejects_invalid_range(monkeypatch: pytest.Monk assert exc.value.status_code == 416 +@pytest.mark.asyncio +async def test_get_preview_stream_prefers_transcoded_object_when_ready(monkeypatch: pytest.MonkeyPatch): + session = DummySession() + storage = DummyStorage() + service = FileService(db=session, storage=storage) + + file_row = make_file_row(file_id=11, file_name="preview.mp4") + file_row.file_ext = "mp4" + file_row.mime_type = "video/mp4" + file_row.storage_object_id = 101 + source_object = StorageObject( + object_id=101, + bucket_name="fileflash", + object_key="objects/u1/source", + object_size=256, + upload_status=UploadStatus.ACTIVE, + content_type="video/mp4", + ) + optimized_object = StorageObject( + object_id=102, + bucket_name="fileflash", + object_key="optimized/transcode/v1/object-101/source-mp4-v1.mp4", + object_size=128, + upload_status=UploadStatus.ACTIVE, + content_type="video/mp4", + ) + metadata = FileMediaMetadata(source_object_id=101) + metadata.extra_metadata = { + "transcode": { + "status": "ready", + "mediaType": "video", + "optimizedBucketName": optimized_object.bucket_name, + "optimizedObjectKey": optimized_object.object_key, + "optimizedMimeType": "video/mp4", + "updatedAt": datetime.now(UTC).isoformat(), + } + } + metadata.extracted_at = datetime.now(UTC) + + monkeypatch.setattr(service, "_get_active_file", AsyncMock(return_value=file_row)) + session.get = AsyncMock(return_value=source_object) + session.scalar = AsyncMock(side_effect=[metadata, optimized_object]) + + result = await service.get_preview_stream(user_id=1, file_id="11", range_header=None) + assert result.status_code == 200 + assert result.headers["Content-Length"] == "128" + + +@pytest.mark.asyncio +async def test_get_preview_stream_falls_back_to_source_when_transcoded_missing(monkeypatch: pytest.MonkeyPatch): + session = DummySession() + storage = DummyStorage() + service = FileService(db=session, storage=storage) + + file_row = make_file_row(file_id=12, file_name="preview.mp4") + file_row.file_ext = "mp4" + file_row.mime_type = "video/mp4" + file_row.storage_object_id = 201 + source_object = StorageObject( + object_id=201, + bucket_name="fileflash", + object_key="objects/u1/source-2", + object_size=512, + upload_status=UploadStatus.ACTIVE, + content_type="video/mp4", + ) + metadata = FileMediaMetadata(source_object_id=201) + metadata.extra_metadata = { + "transcode": { + "status": "ready", + "mediaType": "video", + "optimizedBucketName": "fileflash", + "optimizedObjectKey": "optimized/not-found.mp4", + "optimizedMimeType": "video/mp4", + "updatedAt": datetime.now(UTC).isoformat(), + } + } + metadata.extracted_at = datetime.now(UTC) + + monkeypatch.setattr(service, "_get_active_file", AsyncMock(return_value=file_row)) + session.get = AsyncMock(return_value=source_object) + session.scalar = AsyncMock(side_effect=[metadata, None]) + monkeypatch.setattr(service.storage, "object_exists", AsyncMock(return_value=False)) + + result = await service.get_preview_stream(user_id=1, file_id="12", range_header=None) + assert result.status_code == 200 + assert result.headers["Content-Length"] == "512" + + @pytest.mark.asyncio async def test_delete_file_marks_record_deleted(monkeypatch: pytest.MonkeyPatch): session = DummySession() diff --git a/app/tests/test_file_folder_patch_routes.py b/app/tests/test_file_folder_patch_routes.py index 7896b86..b289d06 100644 --- a/app/tests/test_file_folder_patch_routes.py +++ b/app/tests/test_file_folder_patch_routes.py @@ -6,12 +6,12 @@ from fastapi import FastAPI from fastapi.testclient import TestClient -from src.core.deps import get_current_user, get_file_service, get_folder_service -from src.models.tables_identity import User -from src.routers.files import router as files_router -from src.routers.folders import router as folders_router -from src.schemas.file import FileDetails, FolderItem, RenameFileRequest -from src.services.file import DownloadStreamResult +from fileflash.core.deps import get_current_user, get_file_service, get_folder_service +from fileflash.models.tables_identity import User +from fileflash.routers.files import router as files_router +from fileflash.routers.folders import router as folders_router +from fileflash.schemas.file import FileDetails, FolderItem, RenameFileRequest +from fileflash.services.file import DownloadStreamResult def _make_file_details(*, name: str, is_starred: bool) -> FileDetails: diff --git a/app/tests/test_http_headers.py b/app/tests/test_http_headers.py index c38b0c4..20ae377 100644 --- a/app/tests/test_http_headers.py +++ b/app/tests/test_http_headers.py @@ -1,6 +1,6 @@ from __future__ import annotations -from src.core.http_headers import build_content_disposition +from fileflash.core.http_headers import build_content_disposition def test_build_content_disposition_for_ascii_filename() -> None: diff --git a/app/tests/test_me_service.py b/app/tests/test_me_service.py index 12a5650..3964799 100644 --- a/app/tests/test_me_service.py +++ b/app/tests/test_me_service.py @@ -6,13 +6,15 @@ import pytest -from src.core.security import get_password_hash, hash_token -from src.core.settings import Settings -from src.models.enums import UiLanguage, UserRole, UserStatus -from src.models.tables_audit_security import Log -from src.models.tables_identity import User, UserPreference, UserSession -from src.schemas.user import ChangePasswordRequest, GetActivityLogQuery, UpdateProfileRequest -from src.services.auth import AuthService +from fileflash.core.errors import ApiError +from fileflash.core.security import get_password_hash, hash_token +from fileflash.core.settings import Settings +from fileflash.models.enums import UiLanguage, UserRole, UserStatus +from fileflash.models.tables_audit_security import Log +from fileflash.models.tables_identity import User, UserPreference, UserSession +from fileflash.schemas.user import ChangePasswordRequest, GetActivityLogQuery, UpdateProfileRequest +from fileflash.services.auth import AuthService +from fileflash.services.email_delivery import EmailDeliveryError class DummyResult: @@ -27,6 +29,7 @@ class DummySession: def __init__(self) -> None: self.add = Mock() self.commit = AsyncMock() + self.flush = AsyncMock() self.scalar = AsyncMock() self.scalars = AsyncMock() self.get = AsyncMock() @@ -45,11 +48,13 @@ def make_settings(**overrides: object) -> Settings: def make_service(session: DummySession, publisher: AsyncMock | None = None) -> AuthService: event_publisher = SimpleNamespace(publish=publisher or AsyncMock()) rate_limiter = SimpleNamespace(allow=AsyncMock(return_value=True)) + verification_email_delivery = SimpleNamespace(send_verification_email=AsyncMock()) return AuthService( db=session, settings=make_settings(), rate_limiter=rate_limiter, event_publisher=event_publisher, + verification_email_delivery=verification_email_delivery, # type: ignore[arg-type] ) @@ -75,9 +80,11 @@ async def test_update_profile_resets_email_verification_and_publishes_event(): ) preference = UserPreference(user_id=1, ui_language=UiLanguage.ZH_CN) session.scalar = AsyncMock(side_effect=[user, None, None]) + session.scalars = AsyncMock(return_value=[SimpleNamespace(pattern=r"new\.local", enabled=True)]) service._get_user_preference = AsyncMock(return_value=preference) # type: ignore[method-assign] service._create_email_verification_token = AsyncMock(return_value="verify-token") # type: ignore[method-assign] + service._send_verification_email_or_raise = AsyncMock() # type: ignore[method-assign] profile = await service.update_profile( user_id=1, @@ -91,11 +98,47 @@ async def test_update_profile_resets_email_verification_and_publishes_event(): assert user.email_verified_at is None session.commit.assert_awaited_once() publish_mock.assert_awaited_once() + service._send_verification_email_or_raise.assert_awaited_once() # type: ignore[attr-defined] + + +@pytest.mark.asyncio +async def test_update_profile_email_unchanged_skips_rule_check(): + session = DummySession() + service = make_service(session) + + user = User( + user_id=9, + username="demo", + email="same@local.test", + password_hash="hash", + role=UserRole.USER, + status=UserStatus.ACTIVE, + storage_limit=1024, + storage_used=128, + email_verified=True, + email_verified_at=datetime.now(UTC), + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + preference = UserPreference(user_id=9, ui_language=UiLanguage.ZH_CN) + session.scalar = AsyncMock(return_value=user) + session.scalars = AsyncMock(return_value=[]) + service._get_user_preference = AsyncMock(return_value=preference) # type: ignore[method-assign] + + profile = await service.update_profile( + user_id=9, + payload=UpdateProfileRequest(email="same@local.test"), + user_agent="pytest-agent", + ) + + assert profile.email == "same@local.test" + session.scalars.assert_not_awaited() @pytest.mark.asyncio async def test_change_password_revokes_other_sessions_only(): session = DummySession() + settings = make_settings() service = make_service(session) user = User( @@ -114,14 +157,14 @@ async def test_change_password_revokes_other_sessions_only(): keep_session = UserSession( session_id=11, user_id=2, - refresh_token_hash=hash_token(keep_refresh_token), + refresh_token_hash=hash_token(keep_refresh_token, settings), client_type="web", expire_at=datetime.now(UTC), ) other_session = UserSession( session_id=12, user_id=2, - refresh_token_hash=hash_token("other-token"), + refresh_token_hash=hash_token("other-token", settings), client_type="web", expire_at=datetime.now(UTC), ) @@ -216,3 +259,71 @@ async def test_get_storage_summary_aggregates_files_and_folders(): assert summary.breakdown["images"].count == 1 assert summary.breakdown["documents"].count == 1 assert summary.breakdown["archives"].count == 1 + + +@pytest.mark.asyncio +async def test_send_verification_email_raises_503_in_production() -> None: + session = DummySession() + publish_mock = AsyncMock() + delivery = SimpleNamespace(send_verification_email=AsyncMock(side_effect=EmailDeliveryError("smtp down"))) + service = AuthService( + db=session, + settings=make_settings(APP_ENV="production"), + rate_limiter=SimpleNamespace(allow=AsyncMock(return_value=True)), + event_publisher=SimpleNamespace(publish=publish_mock), + verification_email_delivery=delivery, # type: ignore[arg-type] + ) + + with pytest.raises(ApiError) as exc: + await service._send_verification_email_or_raise( # type: ignore[attr-defined] + event_name="auth.email_verification_requested", + email="demo@example.com", + token="secure-token", + expires_in_minutes=60, + ) + + assert exc.value.status_code == 503 + assert exc.value.code == 503 + publish_mock.assert_awaited_once_with( + "auth.email_verification_delivery_failed", + {"eventName": "auth.email_verification_requested"}, + ) + + +@pytest.mark.asyncio +async def test_send_verification_email_raises_503_in_development() -> None: + session = DummySession() + publish_mock = AsyncMock() + delivery = SimpleNamespace(send_verification_email=AsyncMock(side_effect=EmailDeliveryError("smtp down"))) + service = AuthService( + db=session, + settings=make_settings(APP_ENV="development"), + rate_limiter=SimpleNamespace(allow=AsyncMock(return_value=True)), + event_publisher=SimpleNamespace(publish=publish_mock), + verification_email_delivery=delivery, # type: ignore[arg-type] + ) + + with pytest.raises(ApiError) as exc: + await service._send_verification_email_or_raise( # type: ignore[attr-defined] + event_name="auth.email_verification_resent", + email="demo@example.com", + token="secure-token", + expires_in_minutes=60, + ) + + assert exc.value.status_code == 503 + assert exc.value.code == 503 + + publish_mock.assert_awaited_once_with( + "auth.email_verification_delivery_failed", + {"eventName": "auth.email_verification_resent"}, + ) + + +def test_assert_token_length_rejects_short_token() -> None: + with pytest.raises(ApiError, match="Invalid or expired verification token"): + AuthService._assert_token_length( + token="too-short", + minimum=16, + message="Invalid or expired verification token", + ) diff --git a/app/tests/test_mime.py b/app/tests/test_mime.py index 9b6fc2c..3cd77a9 100644 --- a/app/tests/test_mime.py +++ b/app/tests/test_mime.py @@ -2,7 +2,7 @@ import pytest -from src.core.mime import DEFAULT_MIME_TYPE, resolve_file_mime_type +from fileflash.core.mime import DEFAULT_MIME_TYPE, resolve_file_mime_type @pytest.mark.parametrize( diff --git a/app/tests/test_move_services.py b/app/tests/test_move_services.py index fbd1554..d1f4f63 100644 --- a/app/tests/test_move_services.py +++ b/app/tests/test_move_services.py @@ -1,16 +1,17 @@ from __future__ import annotations from datetime import UTC, datetime +from types import SimpleNamespace from unittest.mock import AsyncMock import pytest -from src.core.errors import ApiError -from src.models.enums import FavoriteItemType, FileStatus, FolderStatus, FolderType -from src.models.tables_access_share import FavoriteItem -from src.models.tables_identity import User -from src.models.tables_storage import File, Folder -from src.schemas.file import ( +from fileflash.core.errors import ApiError +from fileflash.models.enums import FavoriteItemType, FileStatus, FolderStatus, FolderType +from fileflash.models.tables_access_share import FavoriteItem +from fileflash.models.tables_identity import User +from fileflash.models.tables_storage import File, FileMediaMetadata, Folder +from fileflash.schemas.file import ( BatchFilesRequest, CreateFolderRequest, FileDetails, @@ -19,8 +20,8 @@ RenameFileRequest, RenameFolderRequest, ) -from src.services.file import FileService -from src.services.folder import FolderService +from fileflash.services.file import FileService +from fileflash.services.folder import FolderService class DummySession: @@ -275,7 +276,7 @@ async def test_folder_service_move_folder_delegates_to_file_service(monkeypatch: "moved_at": datetime.now(UTC), } ) - monkeypatch.setattr("src.services.file.FileService._move_folder_record", move_mock) + monkeypatch.setattr("fileflash.services.file.FileService._move_folder_record", move_mock) result = await service.move_folder( user_id=1, @@ -293,6 +294,7 @@ async def test_folder_service_move_folder_delegates_to_file_service(monkeypatch: class DummyFolderSession: def __init__(self) -> None: self.commit = AsyncMock() + self.execute = AsyncMock() self.scalar = AsyncMock() self.get = AsyncMock() self.refresh = AsyncMock() @@ -379,6 +381,36 @@ async def test_folder_service_rename_folder_auto_suffix(monkeypatch: pytest.Monk session.commit.assert_awaited_once() +@pytest.mark.asyncio +async def test_folder_service_load_media_optimization_map_parses_transcode_metadata(): + session = DummySession() + service = FolderService(db=session) + file_row = make_file_row() + extracted_at = datetime(2026, 5, 13, 8, 30, tzinfo=UTC) + metadata_row = FileMediaMetadata( + source_object_id=9, + extra_metadata={ + "transcode": { + "status": "ready", + "mediaType": "video", + "optimizedMimeType": "video/mp4", + "updatedAt": "2026-05-13T08:31:00Z", + } + }, + extracted_at=extracted_at, + ) + session.scalars = AsyncMock(return_value=[metadata_row]) + + result = await service._load_media_optimization_map([file_row]) + + assert 1 in result + media = result[1] + assert media.status == "ready" + assert media.media_type == "video" + assert media.optimized_mime_type == "video/mp4" + assert media.updated_at == datetime(2026, 5, 13, 8, 31, tzinfo=UTC) + + @pytest.mark.asyncio async def test_file_service_rename_file_auto_suffix(monkeypatch: pytest.MonkeyPatch): session = DummySession() @@ -461,7 +493,9 @@ async def test_file_service_toggle_file_star_adds_favorite(monkeypatch: pytest.M monkeypatch.setattr(service, "_get_active_file", AsyncMock(return_value=file_row)) monkeypatch.setattr(service, "get_file", AsyncMock(return_value=expected)) - session.scalar = AsyncMock(return_value=None) + monkeypatch.setattr(service, "_get_file_favorite", AsyncMock(side_effect=[None, None])) + monkeypatch.setattr(service, "_lock_user_for_star_update", AsyncMock(return_value=None)) + monkeypatch.setattr(service, "_count_starred_items", AsyncMock(return_value=0)) result = await service.toggle_file_star(user_id=1, file_id="1", is_starred=True) @@ -502,7 +536,7 @@ async def test_file_service_toggle_file_star_removes_favorite(monkeypatch: pytes isStarred=False, status=True, ))) - session.scalar = AsyncMock(return_value=existing_favorite) + monkeypatch.setattr(service, "_get_file_favorite", AsyncMock(return_value=existing_favorite)) await service.toggle_file_star(user_id=1, file_id="1", is_starred=False) @@ -518,8 +552,11 @@ async def test_folder_service_toggle_folder_star_adds_favorite(): folder = make_folder_row(folder_id=200) owner = User(user_id=1, username="owner", email="owner@example.com", password_hash="hash") - session.scalar = AsyncMock(side_effect=[folder, None]) + session.scalar = AsyncMock(return_value=folder) session.get = AsyncMock(return_value=owner) + service._get_folder_favorite = AsyncMock(side_effect=[None, None]) # type: ignore[method-assign] + service._count_starred_items = AsyncMock(return_value=0) # type: ignore[method-assign] + service._lock_user_for_star_update = AsyncMock(return_value=None) # type: ignore[method-assign] response = await service.toggle_folder_star(user_id=1, folder_id="200", is_starred=True) @@ -548,8 +585,9 @@ async def test_folder_service_toggle_folder_star_removes_favorite(): file_id=None, ) - session.scalar = AsyncMock(side_effect=[folder, existing_favorite]) + session.scalar = AsyncMock(return_value=folder) session.get = AsyncMock(return_value=owner) + service._get_folder_favorite = AsyncMock(return_value=existing_favorite) # type: ignore[method-assign] response = await service.toggle_folder_star(user_id=1, folder_id="200", is_starred=False) @@ -558,3 +596,75 @@ async def test_folder_service_toggle_folder_star_removes_favorite(): assert session.added == [] assert session.deleted == [existing_favorite] session.commit.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_file_service_toggle_file_star_rejects_when_limit_reached(monkeypatch: pytest.MonkeyPatch): + session = DummySession() + service = FileService(db=session, starred_items_limit=2) + file_row = make_file_row() + + monkeypatch.setattr(service, "_get_active_file", AsyncMock(return_value=file_row)) + monkeypatch.setattr(service, "_get_file_favorite", AsyncMock(side_effect=[None, None])) + monkeypatch.setattr(service, "_lock_user_for_star_update", AsyncMock(return_value=None)) + monkeypatch.setattr(service, "_count_starred_items", AsyncMock(return_value=2)) + + with pytest.raises(ApiError) as exc: + await service.toggle_file_star(user_id=1, file_id="1", is_starred=True) + + assert exc.value.status_code == 400 + assert exc.value.code == 400 + assert "已达收藏上限 2" in exc.value.message + session.commit.assert_not_awaited() + assert session.added == [] + + +@pytest.mark.asyncio +async def test_folder_service_toggle_folder_star_rejects_when_limit_reached(): + session = DummyFolderSession() + service = FolderService(db=session, starred_items_limit=2) + folder = make_folder_row(folder_id=200) + owner = User(user_id=1, username="owner", email="owner@example.com", password_hash="hash") + + session.scalar = AsyncMock(return_value=folder) + session.get = AsyncMock(return_value=owner) + service._get_folder_favorite = AsyncMock(side_effect=[None, None]) # type: ignore[method-assign] + service._count_starred_items = AsyncMock(return_value=2) # type: ignore[method-assign] + service._lock_user_for_star_update = AsyncMock(return_value=None) # type: ignore[method-assign] + + with pytest.raises(ApiError) as exc: + await service.toggle_folder_star(user_id=1, folder_id="200", is_starred=True) + + assert exc.value.status_code == 400 + assert exc.value.code == 400 + assert "已达收藏上限 2" in exc.value.message + session.commit.assert_not_awaited() + assert session.added == [] + + +@pytest.mark.asyncio +async def test_file_service_list_starred_orders_by_recent_favorite(monkeypatch: pytest.MonkeyPatch): + session = DummySession() + service = FileService(db=session) + + older = datetime(2026, 5, 13, 8, 0, tzinfo=UTC) + middle = datetime(2026, 5, 13, 9, 0, tzinfo=UTC) + newer = datetime(2026, 5, 13, 10, 0, tzinfo=UTC) + + file_old = make_file_row() + file_old.file_id = 1 + file_old.file_name = "old.txt" + file_new = make_file_row() + file_new.file_id = 2 + file_new.file_name = "new.txt" + folder_mid = make_folder_row(folder_id=200) + folder_mid.folder_name = "docs" + + file_execute_result = SimpleNamespace(all=lambda: [(older, file_old, "owner"), (newer, file_new, "owner")]) + folder_execute_result = SimpleNamespace(all=lambda: [(middle, folder_mid, "owner")]) + session.execute = AsyncMock(side_effect=[file_execute_result, folder_execute_result]) + monkeypatch.setattr(service, "_load_media_optimization_map", AsyncMock(return_value={})) + + response = await service.list_starred(user_id=1) + + assert [item.id for item in response.items] == ["2", "200", "1"] diff --git a/app/tests/test_registration_email_domain_rule_service.py b/app/tests/test_registration_email_domain_rule_service.py new file mode 100644 index 0000000..b5c229c --- /dev/null +++ b/app/tests/test_registration_email_domain_rule_service.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from unittest.mock import AsyncMock, Mock + +import pytest + +from fileflash.core.errors import ApiError +from fileflash.models.tables_identity import RegistrationEmailDomainRule +from fileflash.schemas.registration_email_domain_rule import ( + CreateRegistrationEmailDomainRuleRequest, + UpdateRegistrationEmailDomainRuleRequest, +) +from fileflash.services.registration_email_domain_rule import RegistrationEmailDomainRuleService + + +class DummySession: + def __init__(self) -> None: + self.add = Mock() + self.commit = AsyncMock() + self.refresh = AsyncMock() + self.scalar = AsyncMock() + self.scalars = AsyncMock() + self.get = AsyncMock() + self.delete = AsyncMock() + + +@pytest.mark.asyncio +async def test_assert_email_allowed_rejects_when_no_enabled_rules() -> None: + session = DummySession() + session.scalars = AsyncMock(return_value=[]) + service = RegistrationEmailDomainRuleService(db=session) # type: ignore[arg-type] + + with pytest.raises(ApiError, match="邮箱后缀不被允许,请更换邮箱"): + await service.assert_email_allowed(email="demo@example.com") + + +@pytest.mark.asyncio +async def test_assert_email_allowed_accepts_when_pattern_matches() -> None: + session = DummySession() + rule = RegistrationEmailDomainRule( + rule_id=1, + name="corp", + pattern=r".*\.corp\.com", + enabled=True, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.scalars = AsyncMock(return_value=[rule]) + service = RegistrationEmailDomainRuleService(db=session) # type: ignore[arg-type] + + await service.assert_email_allowed(email="user@dept.corp.com") + + +@pytest.mark.asyncio +async def test_assert_email_allowed_rejects_when_no_match() -> None: + session = DummySession() + rule = RegistrationEmailDomainRule( + rule_id=1, + name="corp", + pattern=r".*\.corp\.com", + enabled=True, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.scalars = AsyncMock(return_value=[rule]) + service = RegistrationEmailDomainRuleService(db=session) # type: ignore[arg-type] + + with pytest.raises(ApiError, match="邮箱后缀不被允许,请更换邮箱"): + await service.assert_email_allowed(email="demo@example.com") + + +@pytest.mark.asyncio +async def test_create_rule_rejects_risky_pattern() -> None: + session = DummySession() + session.scalar = AsyncMock(return_value=None) + service = RegistrationEmailDomainRuleService(db=session) # type: ignore[arg-type] + + with pytest.raises(ApiError, match="Regex pattern is too risky"): + await service.create_rule( + payload=CreateRegistrationEmailDomainRuleRequest( + name="risky", + pattern=r"(a+)+$", + enabled=True, + ) + ) + + +@pytest.mark.asyncio +async def test_update_rule_rejects_invalid_pattern() -> None: + session = DummySession() + row = RegistrationEmailDomainRule( + rule_id=7, + name="ok", + pattern=r".*", + enabled=True, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.get = AsyncMock(return_value=row) + service = RegistrationEmailDomainRuleService(db=session) # type: ignore[arg-type] + + with pytest.raises(ApiError, match="Invalid regex pattern"): + await service.update_rule( + rule_id=7, + payload=UpdateRegistrationEmailDomainRuleRequest(pattern=r"([a-z"), + ) + + +@pytest.mark.asyncio +async def test_delete_rule_not_found() -> None: + session = DummySession() + session.get = AsyncMock(return_value=None) + service = RegistrationEmailDomainRuleService(db=session) # type: ignore[arg-type] + + with pytest.raises(ApiError, match="Rule not found"): + await service.delete_rule(rule_id=100) + diff --git a/app/tests/test_run_with_workers.py b/app/tests/test_run_with_workers.py new file mode 100644 index 0000000..854a8cb --- /dev/null +++ b/app/tests/test_run_with_workers.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace + +from fileflash.scripts import run_with_workers + + +@dataclass +class _FakeProcess: + poll_result: int | None = None + + def poll(self): + return self.poll_result + + def send_signal(self, *_args, **_kwargs): + return None + + def wait(self, timeout=None): + return None + + def terminate(self): + return None + + def kill(self): + return None + + +def test_run_with_workers_uses_fileflash_entrypoints(monkeypatch): + started: list[tuple[str, list[str], object]] = [] + + def fake_spawn(name: str, command: list[str], cwd): + started.append((name, command, cwd)) + return run_with_workers.ManagedProcess(name=name, process=_FakeProcess(), command=command) + + def raise_keyboard_interrupt(*_args, **_kwargs): + raise KeyboardInterrupt + + monkeypatch.setattr( + run_with_workers, + "get_settings", + lambda: SimpleNamespace(worker_process_count=1, redis_url="redis://localhost:6379/0"), + ) + monkeypatch.setattr(run_with_workers, "_spawn_process", fake_spawn) + monkeypatch.setattr(run_with_workers, "_stop_process", lambda *_args, **_kwargs: None) + monkeypatch.setattr(run_with_workers, "_validate_redis_for_workers", lambda _env=None: (True, "")) + monkeypatch.setattr(run_with_workers.time, "sleep", raise_keyboard_interrupt) + monkeypatch.setattr(run_with_workers.sys, "argv", ["run-with-workers"]) + + exit_code = run_with_workers.main() + + assert exit_code == 0 + assert started[0][0] == "api" + assert started[0][1][:4] == [run_with_workers.sys.executable, "-m", "uvicorn", "fileflash.main:app"] + assert "--host" in started[0][1] + assert "--port" in started[0][1] + assert started[1][0] == "worker-1" + assert started[1][1] == [run_with_workers.sys.executable, "-m", "fileflash.workers.consumer"] + + +def test_run_with_workers_default_worker_count_comes_from_settings(monkeypatch): + started: list[tuple[str, list[str], object]] = [] + + def fake_spawn(name: str, command: list[str], cwd): + started.append((name, command, cwd)) + return run_with_workers.ManagedProcess(name=name, process=_FakeProcess(), command=command) + + def raise_keyboard_interrupt(*_args, **_kwargs): + raise KeyboardInterrupt + + monkeypatch.setattr( + run_with_workers, + "get_settings", + lambda: SimpleNamespace(worker_process_count=2, redis_url="redis://localhost:6379/0"), + ) + monkeypatch.setattr(run_with_workers, "_spawn_process", fake_spawn) + monkeypatch.setattr(run_with_workers, "_stop_process", lambda *_args, **_kwargs: None) + monkeypatch.setattr(run_with_workers, "_validate_redis_for_workers", lambda _env=None: (True, "")) + monkeypatch.setattr(run_with_workers.time, "sleep", raise_keyboard_interrupt) + monkeypatch.setattr(run_with_workers.sys, "argv", ["run-with-workers"]) + + exit_code = run_with_workers.main() + + assert exit_code == 0 + worker_names = [name for name, _command, _cwd in started if name.startswith("worker-")] + assert worker_names == ["worker-1", "worker-2"] + + +def test_run_with_workers_cli_worker_count_overrides_settings(monkeypatch): + started: list[tuple[str, list[str], object]] = [] + + def fake_spawn(name: str, command: list[str], cwd): + started.append((name, command, cwd)) + return run_with_workers.ManagedProcess(name=name, process=_FakeProcess(), command=command) + + def raise_keyboard_interrupt(*_args, **_kwargs): + raise KeyboardInterrupt + + monkeypatch.setattr( + run_with_workers, + "get_settings", + lambda: SimpleNamespace(worker_process_count=3, redis_url="redis://localhost:6379/0"), + ) + monkeypatch.setattr(run_with_workers, "_spawn_process", fake_spawn) + monkeypatch.setattr(run_with_workers, "_stop_process", lambda *_args, **_kwargs: None) + monkeypatch.setattr(run_with_workers, "_validate_redis_for_workers", lambda _env=None: (True, "")) + monkeypatch.setattr(run_with_workers.time, "sleep", raise_keyboard_interrupt) + monkeypatch.setattr(run_with_workers.sys, "argv", ["run-with-workers", "--worker-count", "1"]) + + exit_code = run_with_workers.main() + + assert exit_code == 0 + worker_names = [name for name, _command, _cwd in started if name.startswith("worker-")] + assert worker_names == ["worker-1"] + + +def test_run_with_workers_fails_fast_when_redis_url_missing(monkeypatch, capsys): + def fail_if_spawned(*_args, **_kwargs): + raise AssertionError("should not spawn") + + monkeypatch.setattr( + run_with_workers, + "get_settings", + lambda: SimpleNamespace(worker_process_count=1, redis_url=None), + ) + monkeypatch.setattr(run_with_workers, "_spawn_process", fail_if_spawned) + monkeypatch.setattr(run_with_workers.sys, "argv", ["run-with-workers"]) + + exit_code = run_with_workers.main() + + captured = capsys.readouterr() + assert exit_code == 2 + assert "REDIS_URL is not set" in captured.err + + +def test_run_with_workers_fails_fast_when_redis_ping_fails(monkeypatch, capsys): + class _FailingRedisClient: + def ping(self): + raise RuntimeError("ping failed") + + def close(self): + return None + + def fail_if_spawned(*_args, **_kwargs): + raise AssertionError("should not spawn") + + monkeypatch.setattr( + run_with_workers, + "get_settings", + lambda: SimpleNamespace(worker_process_count=1, redis_url="redis://localhost:6379/0"), + ) + monkeypatch.setattr(run_with_workers, "_spawn_process", fail_if_spawned) + monkeypatch.setattr( + run_with_workers.Redis, + "from_url", + lambda *_args, **_kwargs: _FailingRedisClient(), + ) + monkeypatch.setattr(run_with_workers.sys, "argv", ["run-with-workers"]) + + exit_code = run_with_workers.main() + + captured = capsys.readouterr() + assert exit_code == 2 + assert "cannot connect to Redis" in captured.err + assert "ping failed" in captured.err + + +def test_run_with_workers_no_worker_skips_redis_preflight(monkeypatch): + started: list[tuple[str, list[str], object]] = [] + + def fake_spawn(name: str, command: list[str], cwd): + started.append((name, command, cwd)) + return run_with_workers.ManagedProcess(name=name, process=_FakeProcess(), command=command) + + def raise_keyboard_interrupt(*_args, **_kwargs): + raise KeyboardInterrupt + + def fail_if_called(*_args, **_kwargs): + raise AssertionError("redis preflight should be skipped when --no-worker is set") + + monkeypatch.setattr( + run_with_workers, + "get_settings", + lambda: SimpleNamespace(worker_process_count=2, redis_url=None), + ) + monkeypatch.setattr(run_with_workers, "_spawn_process", fake_spawn) + monkeypatch.setattr(run_with_workers, "_stop_process", lambda *_args, **_kwargs: None) + monkeypatch.setattr(run_with_workers, "_validate_redis_for_workers", fail_if_called) + monkeypatch.setattr(run_with_workers.time, "sleep", raise_keyboard_interrupt) + monkeypatch.setattr(run_with_workers.sys, "argv", ["run-with-workers", "--no-worker"]) + + exit_code = run_with_workers.main() + + assert exit_code == 0 + assert [name for name, _command, _cwd in started] == ["api"] diff --git a/app/tests/test_security.py b/app/tests/test_security.py index 944356b..fe29f57 100644 --- a/app/tests/test_security.py +++ b/app/tests/test_security.py @@ -1,6 +1,6 @@ from __future__ import annotations -from src.core.security import ( +from fileflash.core.security import ( create_access_token, create_refresh_token, decode_access_token, @@ -8,7 +8,7 @@ hash_token, verify_password, ) -from src.core.settings import Settings +from fileflash.core.settings import Settings def test_password_hash_and_verify(): @@ -32,8 +32,31 @@ def test_access_token_round_trip(): def test_refresh_token_hash_is_deterministic(): + settings = Settings( + JWT_SECRET_KEY="unit-test-secret-key-1234567890abcd", + FF_DB_URI="postgresql://u:p@localhost:5432/db", + ) refresh_token = create_refresh_token() - token_hash_1 = hash_token(refresh_token) - token_hash_2 = hash_token(refresh_token) + token_hash_1 = hash_token(refresh_token, settings) + token_hash_2 = hash_token(refresh_token, settings) assert token_hash_1 == token_hash_2 assert len(token_hash_1) == 64 + + +def test_refresh_token_hash_changes_with_different_secret(): + token = "same-token-value" + settings_a = Settings( + JWT_SECRET_KEY="unit-test-secret-key-1234567890abcd", + TOKEN_HASH_SECRET="token-secret-A", + FF_DB_URI="postgresql://u:p@localhost:5432/db", + ) + settings_b = Settings( + JWT_SECRET_KEY="unit-test-secret-key-1234567890abcd", + TOKEN_HASH_SECRET="token-secret-B", + FF_DB_URI="postgresql://u:p@localhost:5432/db", + ) + + hash_a = hash_token(token, settings_a) + hash_b = hash_token(token, settings_b) + + assert hash_a != hash_b diff --git a/app/tests/test_settings.py b/app/tests/test_settings.py index b349c58..de26b55 100644 --- a/app/tests/test_settings.py +++ b/app/tests/test_settings.py @@ -1,6 +1,8 @@ from __future__ import annotations -from src.core.settings import Settings +import pytest + +from fileflash.core.settings import Settings def test_async_database_url_conversion(): @@ -13,6 +15,8 @@ def test_upload_related_settings_defaults(): assert settings.object_storage_bucket == "fileflash" assert settings.upload_chunk_size_default == 5 * 1024 * 1024 assert settings.upload_session_ttl_seconds == 24 * 3600 + assert settings.starred_items_limit == 20 + assert settings.worker_process_count == 1 def test_agent_related_settings_defaults(): @@ -32,3 +36,84 @@ def test_app_env_detection(): assert prod.is_development_env is False assert prod.is_production_env is True + +def test_worker_process_count_from_env(): + settings = Settings( + FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", + WORKER_PROCESS_COUNT="3", + ) + assert settings.worker_process_count == 3 + + +def test_starred_items_limit_from_env(): + settings = Settings( + FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", + STARRED_ITEMS_LIMIT="12", + ) + assert settings.starred_items_limit == 12 + + +def test_verify_base_url_defaults_to_localhost_in_development(): + settings = Settings( + FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", + APP_ENV="development", + ) + assert settings.normalized_email_verify_base_url == "http://localhost:8080" + + +def test_verify_base_url_adds_http_when_scheme_missing(): + settings = Settings( + FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", + EMAIL_VERIFY_BASE_URL="localhost:3000", + ) + assert settings.normalized_email_verify_base_url == "http://localhost:3000" + + +def test_mail_configuration_issues_includes_missing_required_fields(): + settings = Settings( + FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", + APP_ENV="development", + MAIL_FROM="", + MAIL_SERVER="", + MAIL_USERNAME="demo@example.com", + MAIL_PASSWORD="secret", + ) + issues = settings.mail_configuration_issues + assert "MAIL_FROM is required" in issues + assert "MAIL_SERVER is required" in issues + assert settings.is_mail_configured is False + + +def test_mail_configuration_rejects_both_tls_modes_enabled(): + settings = Settings( + FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", + EMAIL_VERIFY_BASE_URL="http://localhost:5173", + MAIL_FROM="demo@example.com", + MAIL_SERVER="smtp.example.com", + MAIL_PORT=587, + MAIL_USERNAME="demo@example.com", + MAIL_PASSWORD="secret", + MAIL_STARTTLS=True, + MAIL_SSL_TLS=True, + ) + assert "MAIL_SSL_TLS and MAIL_STARTTLS cannot both be true" in settings.mail_configuration_issues + + +def test_assert_runtime_security_raises_when_jwt_secret_too_short(): + settings = Settings( + FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", + JWT_SECRET_KEY="short-key", + ) + with pytest.raises(ValueError, match="JWT_SECRET_KEY must be at least 32 bytes"): + settings.assert_runtime_security() + + +def test_assert_runtime_security_raises_when_token_hash_secret_too_short(): + settings = Settings( + FF_DB_URI="postgresql://root:pwd@localhost:5432/fileflash", + JWT_SECRET_KEY="x" * 32, + TOKEN_HASH_SECRET="short-key", + ) + with pytest.raises(ValueError, match="TOKEN_HASH_SECRET must be at least 32 bytes"): + settings.assert_runtime_security() + diff --git a/app/tests/test_share_routes.py b/app/tests/test_share_routes.py index e2cfdd3..48cb008 100644 --- a/app/tests/test_share_routes.py +++ b/app/tests/test_share_routes.py @@ -5,24 +5,34 @@ from fastapi import FastAPI from fastapi.testclient import TestClient -from src.core.deps import get_client_ip, get_share_service, get_user_agent -from src.routers.shares import router as shares_router +from fileflash.core.deps import get_client_ip, get_share_service, get_user_agent +from fileflash.routers.shares import router as shares_router class StubShareService: - async def get_shared_file_stream( + async def get_shared_file_download_stream_response( self, *, share_link: str, # noqa: ARG002 share_access_token: str, # noqa: ARG002 action: str, # noqa: ARG002 + range_header: str | None, # noqa: ARG002 ip_address: str, # noqa: ARG002 user_agent: str | None, # noqa: ARG002 - ) -> tuple[AsyncIterator[bytes], str, str]: + ) -> tuple[AsyncIterator[bytes], str, str, int, dict[str, str]]: async def _stream() -> AsyncIterator[bytes]: yield b"data" - return _stream(), "测试文档.pdf", "application/pdf" + headers = { + "Content-Disposition": ( + 'inline; filename="测试文档.pdf"; filename*=UTF-8\'\'%E6%B5%8B%E8%AF%95%E6%96%87%E6%A1%A3.pdf' + if action == "preview" + else 'attachment; filename="测试文档.pdf"; filename*=UTF-8\'\'%E6%B5%8B%E8%AF%95%E6%96%87%E6%A1%A3.pdf' + ), + "Accept-Ranges": "bytes", + "Content-Length": "4", + } + return _stream(), "测试文档.pdf", "application/pdf", 200, headers def _build_client() -> TestClient: @@ -61,4 +71,3 @@ def test_shared_preview_handles_unicode_filename_header() -> None: assert 'filename*=UTF-8\'\'' in header header.encode("latin-1") assert response.content == b"data" - diff --git a/app/tests/test_share_service.py b/app/tests/test_share_service.py index 78d2201..3e31798 100644 --- a/app/tests/test_share_service.py +++ b/app/tests/test_share_service.py @@ -1,16 +1,20 @@ from __future__ import annotations import re +from datetime import UTC, datetime +from collections.abc import AsyncIterator from types import SimpleNamespace -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, Mock import pytest -from src.core.security import create_share_access_token, decode_share_access_token -from src.core.settings import Settings -from src.models.tables_access_share import Share -from src.schemas.share import CreateShareRequest, SaveShareRequest, UpdateShareSettingsRequest -from src.services.share import ShareService +from fileflash.core.security import create_share_access_token, decode_share_access_token +from fileflash.core.settings import Settings +from fileflash.models.enums import UploadStatus +from fileflash.models.tables_access_share import Share +from fileflash.models.tables_storage import File, FileMediaMetadata, StorageObject +from fileflash.schemas.share import CreateShareRequest, SaveShareRequest, UpdateShareSettingsRequest +from fileflash.services.share import ShareService class DummySession: @@ -19,6 +23,7 @@ def __init__(self) -> None: self.flush = AsyncMock() self.execute = AsyncMock() self.add = AsyncMock() + self.scalar = AsyncMock() def make_settings(**overrides: object) -> Settings: @@ -31,7 +36,12 @@ def make_settings(**overrides: object) -> Settings: def make_service(session: DummySession, settings: Settings | None = None) -> ShareService: - storage = SimpleNamespace(iter_object=AsyncMock()) + storage = SimpleNamespace( + iter_object=AsyncMock(), + iter_object_range=AsyncMock(), + object_exists=AsyncMock(return_value=False), + stat_object=AsyncMock(), + ) return ShareService(db=session, settings=settings or make_settings(), storage=storage) @@ -164,3 +174,88 @@ async def test_save_requires_valid_share_token(monkeypatch: pytest.MonkeyPatch): user_agent="pytest", ) + +@pytest.mark.asyncio +async def test_get_shared_file_stream_prefers_transcoded_when_preview(monkeypatch: pytest.MonkeyPatch): + session = DummySession() + service = make_service(session) + + share_row = Share( + share_id=1, + user_id=1, + resource_type="file", + file_id=9, + folder_id=None, + share_link="ABCD", + share_code="ABCD", + status="active", + allow_preview=True, + allow_download=True, + ) + file_row = File( + file_id=9, + uploader_id=1, + owner_id=1, + folder_id=1, + file_name="video.mp4", + storage_object_id=33, + file_size=100, + file_ext="mp4", + mime_type="video/mp4", + ) + source_object = StorageObject( + object_id=33, + bucket_name="fileflash", + object_key="objects/u1/source-video", + object_size=100, + upload_status=UploadStatus.ACTIVE, + content_type="video/mp4", + ) + optimized_object = StorageObject( + object_id=34, + bucket_name="fileflash", + object_key="optimized/transcode/v1/object-33/source-mp4-v1.mp4", + object_size=80, + upload_status=UploadStatus.ACTIVE, + content_type="video/mp4", + ) + metadata = FileMediaMetadata(source_object_id=33) + metadata.extra_metadata = { + "transcode": { + "status": "ready", + "mediaType": "video", + "optimizedBucketName": optimized_object.bucket_name, + "optimizedObjectKey": optimized_object.object_key, + "optimizedMimeType": "video/mp4", + "updatedAt": datetime.now(UTC).isoformat(), + } + } + metadata.extracted_at = datetime.now(UTC) + + monkeypatch.setattr(service, "_resolve_share_for_access_token", AsyncMock(return_value=share_row)) + monkeypatch.setattr(service, "_get_active_file", AsyncMock(return_value=file_row)) + monkeypatch.setattr(service, "_log_share_event", AsyncMock()) + session.get = AsyncMock(return_value=source_object) + session.scalar = AsyncMock(side_effect=[metadata, optimized_object]) + session.execute = AsyncMock(return_value=None) + async def _dummy_stream() -> AsyncIterator[bytes]: + yield b"data" + + def _iter_object(**_kwargs: object) -> AsyncIterator[bytes]: + return _dummy_stream() + + iter_mock = Mock(side_effect=_iter_object) + service.storage.iter_object = iter_mock + + await service.get_shared_file_download_stream_response( + share_link="ABCD", + share_access_token="token", + action="preview", + range_header=None, + ip_address="127.0.0.1", + user_agent="pytest", + ) + + iter_mock.assert_called_once() + assert iter_mock.call_args.kwargs["object_key"] == optimized_object.object_key + diff --git a/app/tests/test_startup_fail_fast.py b/app/tests/test_startup_fail_fast.py index 9b18b97..a2ea383 100644 --- a/app/tests/test_startup_fail_fast.py +++ b/app/tests/test_startup_fail_fast.py @@ -5,38 +5,62 @@ import pytest -from src.main import lifespan -from src.s3.minio_client import ObjectStorageAuthError +from fileflash.main import lifespan +from fileflash.s3.minio_client import ObjectStorageAuthError @pytest.mark.asyncio async def test_lifespan_fails_fast_when_database_check_fails(monkeypatch: pytest.MonkeyPatch): verify = AsyncMock(side_effect=RuntimeError("database unavailable")) + verify_schema = AsyncMock() seed = AsyncMock() - monkeypatch.setattr("src.main.verify_database_connection", verify) - monkeypatch.setattr("src.main.initialize_dev_accounts", seed) + monkeypatch.setattr("fileflash.main.verify_database_connection", verify) + monkeypatch.setattr("fileflash.main.verify_schema_compatibility", verify_schema) + monkeypatch.setattr("fileflash.main.initialize_dev_accounts", seed) with pytest.raises(RuntimeError, match="database unavailable"): async with lifespan(object()): pass verify.assert_awaited_once() + verify_schema.assert_not_awaited() + seed.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_lifespan_fails_fast_when_schema_check_fails(monkeypatch: pytest.MonkeyPatch): + verify = AsyncMock() + verify_schema = AsyncMock(side_effect=RuntimeError("schema incompatible")) + seed = AsyncMock() + monkeypatch.setattr("fileflash.main.verify_database_connection", verify) + monkeypatch.setattr("fileflash.main.verify_schema_compatibility", verify_schema) + monkeypatch.setattr("fileflash.main.initialize_dev_accounts", seed) + + with pytest.raises(RuntimeError, match="schema incompatible"): + async with lifespan(object()): + pass + + verify.assert_awaited_once() + verify_schema.assert_awaited_once() seed.assert_not_awaited() @pytest.mark.asyncio async def test_lifespan_fails_fast_when_object_storage_check_fails(monkeypatch: pytest.MonkeyPatch): verify = AsyncMock() + verify_schema = AsyncMock() seed = AsyncMock() storage = SimpleNamespace(ensure_bucket=AsyncMock(side_effect=ObjectStorageAuthError("bad credentials"))) - monkeypatch.setattr("src.main.verify_database_connection", verify) - monkeypatch.setattr("src.main.get_object_storage", lambda: storage) - monkeypatch.setattr("src.main.initialize_dev_accounts", seed) + monkeypatch.setattr("fileflash.main.verify_database_connection", verify) + monkeypatch.setattr("fileflash.main.verify_schema_compatibility", verify_schema) + monkeypatch.setattr("fileflash.main.get_object_storage", lambda: storage) + monkeypatch.setattr("fileflash.main.initialize_dev_accounts", seed) with pytest.raises(ObjectStorageAuthError, match="bad credentials"): async with lifespan(object()): pass verify.assert_awaited_once() + verify_schema.assert_awaited_once() storage.ensure_bucket.assert_awaited_once() seed.assert_not_awaited() diff --git a/app/tests/test_tasks_archive.py b/app/tests/test_tasks_archive.py index 0a59b97..a827f85 100644 --- a/app/tests/test_tasks_archive.py +++ b/app/tests/test_tasks_archive.py @@ -7,7 +7,7 @@ import pytest -from src.tasks.archive import ( +from fileflash.tasks.archive import ( ArchiveLimits, _detect_archive_format, _extract_tar, @@ -226,7 +226,7 @@ def test_preview_7z_reports_entries(tmp_path: Path): with py7zr.SevenZipFile(archive_path, "w") as archive: archive.writeall(payload_dir, "root") - from src.tasks.archive import _preview_7z + from fileflash.tasks.archive import _preview_7z entries, summary = _preview_7z(archive_path=archive_path, max_entries=2000) assert summary["totalEntries"] >= 1 diff --git a/app/tests/test_tasks_scan.py b/app/tests/test_tasks_scan.py index 9b8bec8..de6b52f 100644 --- a/app/tests/test_tasks_scan.py +++ b/app/tests/test_tasks_scan.py @@ -2,7 +2,7 @@ from pathlib import Path -from src.tasks.scan import run_dangerous_file_scan +from fileflash.tasks.scan import run_dangerous_file_scan def test_scan_marks_normal_text_file_as_clean(tmp_path: Path): diff --git a/app/tests/test_tasks_transcode.py b/app/tests/test_tasks_transcode.py index 662e0d5..63ad82a 100644 --- a/app/tests/test_tasks_transcode.py +++ b/app/tests/test_tasks_transcode.py @@ -3,12 +3,12 @@ import json import subprocess from pathlib import Path -from typing import Any +from types import SimpleNamespace -from src.tasks.transcode import build_ffmpeg_command, run_media_transcode +from fileflash.tasks.transcode import build_ffmpeg_command, run_media_transcode -def test_build_ffmpeg_command_for_video(): +def test_build_ffmpeg_command_for_video_contains_profile_flags(): command = build_ffmpeg_command( input_path=Path("input.mov"), output_path=Path("output.mp4"), @@ -20,13 +20,14 @@ def test_build_ffmpeg_command_for_video(): assert command[:4] == ["ffmpeg", "-y", "-i", "input.mov"] assert "-c:v" in command assert "libx264" in command + assert "-movflags" in command + assert "+faststart" in command + assert "-pix_fmt" in command + assert "yuv420p" in command assert command[-1] == "output.mp4" -def test_run_media_transcode_with_mocked_subprocess(monkeypatch, tmp_path: Path): - input_path = tmp_path / "voice.wav" - input_path.write_bytes(b"RIFF....WAVEfmt") - +def test_run_media_transcode_storage_mode_with_mocked_subprocess(monkeypatch, tmp_path: Path): source_probe = { "streams": [{"codec_type": "audio", "codec_name": "pcm_s16le", "sample_rate": "44100"}], "format": {"duration": "2.5", "bit_rate": "96000"}, @@ -37,6 +38,27 @@ def test_run_media_transcode_with_mocked_subprocess(monkeypatch, tmp_path: Path) } calls: list[list[str]] = [] + class DummyStorage: + async def fget_object(self, *, bucket_name: str, object_key: str, file_path: str): + _ = (bucket_name, object_key) + Path(file_path).write_bytes(b"RIFF....WAVEfmt") + return SimpleNamespace(etag="src-etag", version_id=None) + + async def fput_object( + self, + *, + bucket_name: str, + object_key: str, + file_path: str, + content_type: str, + ): + _ = (bucket_name, object_key, file_path, content_type) + return SimpleNamespace(etag="out-etag", version_id="v1") + + async def stat_object(self, *, bucket_name: str, object_key: str): + _ = (bucket_name, object_key) + return SimpleNamespace(size=128, etag="out-etag", version_id="v1") + def fake_run( command: list[str], *, @@ -49,7 +71,7 @@ def fake_run( calls.append(command) if command[0] == "ffprobe": target = command[-1] - payload: dict[str, Any] = source_probe if target.endswith(".wav") else output_probe + payload = source_probe if target.endswith("source") else output_probe return subprocess.CompletedProcess(command, 0, stdout=json.dumps(payload), stderr="") if command[0] == "ffmpeg": output_path = Path(command[-1]) @@ -57,12 +79,27 @@ def fake_run( return subprocess.CompletedProcess(command, 0, stdout="", stderr="") raise AssertionError(f"Unexpected command: {command}") - monkeypatch.setattr("src.tasks.transcode.subprocess.run", fake_run) + monkeypatch.setattr("fileflash.tasks.transcode.subprocess.run", fake_run) + monkeypatch.setattr( + "fileflash.tasks.transcode.MinioObjectStorageClient.from_settings", + lambda _settings: DummyStorage(), + ) - result = run_media_transcode({"inputPath": str(input_path)}) + result = run_media_transcode( + { + "sourceBucketName": "fileflash", + "sourceObjectKey": "objects/u1/voice.wav", + "sourceObjectId": 99, + "outputBucketName": "fileflash", + "outputObjectKey": "optimized/transcode/v1/object-99/voice-mp4-v1.m4a", + } + ) assert result["mediaType"] == "audio" - assert result["outputPath"].endswith(".m4a") + assert result["outputObjectKey"].endswith(".m4a") + assert result["outputObjectSize"] == 128 + assert result["optimizedMimeType"] == "audio/mp4" assert result["metadata"]["durationMs"] == 2500 assert result["metadata"]["audioCodec"] == "aac" assert [call[0] for call in calls] == ["ffprobe", "ffmpeg", "ffprobe"] + diff --git a/app/tests/test_upload_service.py b/app/tests/test_upload_service.py index af0f872..298d019 100644 --- a/app/tests/test_upload_service.py +++ b/app/tests/test_upload_service.py @@ -7,13 +7,13 @@ import pytest -from src.core.errors import ApiError -from src.core.settings import Settings -from src.models.enums import UploadPartStatus, UploadTaskStatus -from src.models.tables_storage import File, StorageObject, UploadTask, UploadTaskPart -from src.s3.minio_client import ObjectStat, ObjectStorageAuthError, ObjectWriteResult -from src.schemas.file import MergeChunksRequest, UploadPreflightRequest -from src.services.upload import UploadService +from fileflash.core.errors import ApiError +from fileflash.core.settings import Settings +from fileflash.models.enums import UploadPartStatus, UploadTaskStatus +from fileflash.models.tables_storage import File, FileMediaMetadata, StorageObject, UploadTask, UploadTaskPart +from fileflash.s3.minio_client import ObjectStat, ObjectStorageAuthError, ObjectWriteResult +from fileflash.schemas.file import MergeChunksRequest, MergeChunksResponse, UploadPreflightRequest +from fileflash.services.upload import UploadService class DummySession: @@ -31,6 +31,8 @@ async def commit(self) -> None: async def flush(self) -> None: now = datetime.now(UTC) for index, obj in enumerate(self.added, start=1): + if isinstance(obj, StorageObject) and obj.object_id is None: + obj.object_id = index if isinstance(obj, UploadTask) and obj.task_id is None: obj.task_id = index if isinstance(obj, File) and obj.file_id is None: @@ -68,7 +70,11 @@ def make_service(session: DummySession, settings: Settings | None = None) -> tup remove_objects=AsyncMock(), compute_object_hash=AsyncMock(return_value="0" * 64), ) - service = UploadService(db=session, settings=settings or make_settings(), storage=storage) + jobs = SimpleNamespace( + enqueue=AsyncMock(), + enqueue_transcode_job=AsyncMock(), + ) + service = UploadService(db=session, settings=settings or make_settings(), storage=storage, jobs=jobs) return service, storage @@ -471,7 +477,7 @@ async def test_merge_logs_warning_for_incomplete_chunks(monkeypatch: pytest.Monk monkeypatch.setattr(service, "_get_task_for_update", AsyncMock(return_value=task)) monkeypatch.setattr(service, "_resolve_folder_id", AsyncMock(return_value=1)) monkeypatch.setattr(service, "_find_conflict_file", AsyncMock(return_value=None)) - caplog.set_level(logging.WARNING, logger="src.services.upload") + caplog.set_level(logging.WARNING, logger="fileflash.services.upload") with pytest.raises(ApiError) as exc: await service.merge_chunks( @@ -521,7 +527,7 @@ async def test_merge_logs_warning_for_non_continuous_chunks( monkeypatch.setattr(service, "_get_task_for_update", AsyncMock(return_value=task)) monkeypatch.setattr(service, "_resolve_folder_id", AsyncMock(return_value=1)) monkeypatch.setattr(service, "_find_conflict_file", AsyncMock(return_value=None)) - caplog.set_level(logging.WARNING, logger="src.services.upload") + caplog.set_level(logging.WARNING, logger="fileflash.services.upload") with pytest.raises(ApiError) as exc: await service.merge_chunks( @@ -538,3 +544,170 @@ async def test_merge_logs_warning_for_non_continuous_chunks( assert exc.value.status_code == 400 assert "upload-non-contiguous" in caplog.text assert "non-continuous chunks" in caplog.text + + +@pytest.mark.asyncio +async def test_merge_enqueues_transcode_for_video_and_sets_queued_metadata(monkeypatch: pytest.MonkeyPatch): + session = DummySession() + service, storage = make_service(session) + task = UploadTask( + task_id=50, + user_id=9, + folder_id=1, + file_name="movie.mp4", + mime_type="video/mp4", + bucket_name="fileflash", + object_key="objects/u9/movie", + object_hash="1" * 64, + total_size=4, + chunk_size=2, + upload_id="upload-video-transcode", + status=UploadTaskStatus.UPLOADING, + expired_at=datetime.now(UTC) + timedelta(hours=1), + ) + parts = [ + UploadTaskPart(task_id=50, part_number=0, part_size=2, status=UploadPartStatus.UPLOADED), + UploadTaskPart(task_id=50, part_number=1, part_size=2, status=UploadPartStatus.UPLOADED), + ] + session.scalars_queue = [parts] + + monkeypatch.setattr(service, "_get_task_for_update", AsyncMock(return_value=task)) + monkeypatch.setattr(service, "_resolve_folder_id", AsyncMock(return_value=1)) + monkeypatch.setattr(service, "_find_conflict_file", AsyncMock(return_value=None)) + monkeypatch.setattr(service, "_find_storage_object", AsyncMock(return_value=None)) + storage.compute_object_hash = AsyncMock(return_value="1" * 64) + + response = await service.merge_chunks( + user_id=9, + upload_id="upload-video-transcode", + payload=MergeChunksRequest( + fileHash="1" * 64, + fileName="movie.mp4", + mimeType="video/mp4", + parentId="1", + ), + ) + + assert response.file_name == "movie.mp4" + created_metadata = next(obj for obj in session.added if isinstance(obj, FileMediaMetadata)) + assert created_metadata.extra_metadata["transcode"]["status"] == "queued" + service.jobs.enqueue_transcode_job.assert_awaited_once() # type: ignore[union-attr] + + +@pytest.mark.asyncio +async def test_merge_transcode_enqueue_failure_does_not_fail_upload(monkeypatch: pytest.MonkeyPatch): + session = DummySession() + service, storage = make_service(session) + service.jobs.enqueue_transcode_job = AsyncMock( # type: ignore[union-attr] + side_effect=ApiError(status_code=503, code=503, message="Job queue unavailable") + ) + task = UploadTask( + task_id=51, + user_id=9, + folder_id=1, + file_name="audio.mp3", + mime_type="audio/mpeg", + bucket_name="fileflash", + object_key="objects/u9/audio", + object_hash="2" * 64, + total_size=4, + chunk_size=2, + upload_id="upload-audio-transcode", + status=UploadTaskStatus.UPLOADING, + expired_at=datetime.now(UTC) + timedelta(hours=1), + ) + parts = [ + UploadTaskPart(task_id=51, part_number=0, part_size=2, status=UploadPartStatus.UPLOADED), + UploadTaskPart(task_id=51, part_number=1, part_size=2, status=UploadPartStatus.UPLOADED), + ] + session.scalars_queue = [parts] + + monkeypatch.setattr(service, "_get_task_for_update", AsyncMock(return_value=task)) + monkeypatch.setattr(service, "_resolve_folder_id", AsyncMock(return_value=1)) + monkeypatch.setattr(service, "_find_conflict_file", AsyncMock(return_value=None)) + monkeypatch.setattr(service, "_find_storage_object", AsyncMock(return_value=None)) + storage.compute_object_hash = AsyncMock(return_value="2" * 64) + + response = await service.merge_chunks( + user_id=9, + upload_id="upload-audio-transcode", + payload=MergeChunksRequest( + fileHash="2" * 64, + fileName="audio.mp3", + mimeType="audio/mpeg", + parentId="1", + ), + ) + + assert response.file_name == "audio.mp3" + created_metadata = next(obj for obj in session.added if isinstance(obj, FileMediaMetadata)) + assert created_metadata.extra_metadata["transcode"]["status"] == "failed" + + +@pytest.mark.asyncio +async def test_enqueue_merge_job_uses_normalized_payload(monkeypatch: pytest.MonkeyPatch): + session = DummySession() + service, _storage = make_service(session) + fake_job = SimpleNamespace(job_id=1234, task_type="task.upload_merge", status="pending") + service.jobs.enqueue = AsyncMock(return_value=fake_job) # type: ignore[union-attr] + + payload = MergeChunksRequest( + fileHash="A" * 64, + fileName="movie.mp4", + mimeType="video/mp4", + parentId="root", + conflictStrategy="rename", + ) + job = await service.enqueue_merge_job( + user_id=7, + upload_id="upload-merge-job", + payload=payload, + ) + + assert job is fake_job + service.jobs.enqueue.assert_awaited_once() # type: ignore[union-attr] + _db, kwargs = service.jobs.enqueue.await_args.args, service.jobs.enqueue.await_args.kwargs # type: ignore[union-attr] + assert kwargs["task_type"] == "task.upload_merge" + assert kwargs["requested_by"] == 7 + assert kwargs["idempotency_key"].startswith("upload:7:upload-merge-job:merge:") + assert kwargs["payload"]["userId"] == 7 + assert kwargs["payload"]["uploadId"] == "upload-merge-job" + assert kwargs["payload"]["mergeRequest"]["fileHash"] == ("a" * 64) + + +@pytest.mark.asyncio +async def test_execute_merge_job_calls_merge_chunks(monkeypatch: pytest.MonkeyPatch): + session = DummySession() + service, _storage = make_service(session) + expected = MergeChunksResponse( + fileId="901", + fileName="report.pdf", + fileSize=1024, + mimeType="application/pdf", + folderId="root", + objectHash="f" * 64, + createdAt=datetime.now(UTC), + downloadUrl="/api/v1/files/901/download", + ) + merge_mock = AsyncMock(return_value=expected) + monkeypatch.setattr(service, "merge_chunks", merge_mock) + + result = await service.execute_merge_job( + payload={ + "userId": 99, + "uploadId": "upload-exec-1", + "mergeRequest": { + "fileHash": "f" * 64, + "fileName": "report.pdf", + "mimeType": "application/pdf", + "parentId": "root", + }, + } + ) + + merge_mock.assert_awaited_once() + assert result["fileId"] == "901" + assert result["fileName"] == "report.pdf" + assert result["downloadUrl"] == "/api/v1/files/901/download" + assert isinstance(result["createdAt"], str) + assert datetime.fromisoformat(result["createdAt"].replace("Z", "+00:00")) == expected.created_at diff --git a/app/tests/test_workers.py b/app/tests/test_workers.py index aaaaea7..286547a 100644 --- a/app/tests/test_workers.py +++ b/app/tests/test_workers.py @@ -1,9 +1,32 @@ from __future__ import annotations +import asyncio +import pickle +import threading +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import AsyncMock + import pytest -from src.tasks.registry import UnknownTaskTypeError, dispatch_task -from src.workers.repository import get_retry_delay_seconds +from fileflash.core.errors import ApiError +from fileflash.tasks.registry import UnknownTaskTypeError, dispatch_task +from fileflash.workers.bootstrap import WorkerRuntimeConfig +from fileflash.workers.consumer import WorkerConsumer, _is_retryable_error +from fileflash.workers.contracts import WorkerJobMessage +from fileflash.workers.dispatcher import PicklableRemoteTaskError, execute_task +from fileflash.workers.repository import get_retry_delay_seconds, mark_job_succeeded + + +class _AsyncContextManager: + def __init__(self, value): + self._value = value + + async def __aenter__(self): + return self._value + + async def __aexit__(self, exc_type, exc, tb): + return None def test_dispatch_task_rejects_unknown_type(): @@ -16,3 +39,391 @@ def test_retry_delay_uses_last_backoff_when_attempt_exceeds_schedule(): assert get_retry_delay_seconds(schedule, attempt=1) == 3 assert get_retry_delay_seconds(schedule, attempt=2) == 10 assert get_retry_delay_seconds(schedule, attempt=4) == 30 + + +@pytest.mark.asyncio +async def test_mark_job_succeeded_serializes_nested_datetime_result(monkeypatch): + created_at = datetime(2026, 5, 13, 8, 29, 51, tzinfo=UTC) + job = SimpleNamespace( + status="running", + result={}, + error_message="previous", + finished_at=None, + updated_at=None, + ) + load_mock = AsyncMock(return_value=job) + monkeypatch.setattr("fileflash.workers.repository._load_job_for_update", load_mock) + + await mark_job_succeeded( + SimpleNamespace(), + job_id=11, + result={ + "fileId": "11", + "metadata": { + "createdAt": created_at, + }, + }, + ) + + load_mock.assert_awaited_once() + assert job.status == "succeeded" + assert job.error_message is None + assert job.result["fileId"] == "11" + assert job.result["metadata"]["createdAt"] == created_at.isoformat() + assert not isinstance(job.result["metadata"]["createdAt"], datetime) + + +def test_picklable_remote_task_error_can_be_pickled(): + error = PicklableRemoteTaskError( + original_type="TypeError", + message="cannot pickle '_thread.lock' object", + retryable_hint=True, + ) + restored = pickle.loads(pickle.dumps(error)) + assert isinstance(restored, PicklableRemoteTaskError) + assert restored.original_type == "TypeError" + assert restored.message == "cannot pickle '_thread.lock' object" + assert restored.retryable_hint is True + + +def test_retryable_error_uses_remote_hint_when_present(): + retryable = PicklableRemoteTaskError(original_type="RuntimeError", message="x", retryable_hint=True) + non_retryable = PicklableRemoteTaskError(original_type="RuntimeError", message="x", retryable_hint=False) + + assert _is_retryable_error(retryable) is True + assert _is_retryable_error(non_retryable) is False + + +def test_retryable_error_uses_remote_original_type_mapping(): + wrapped_value_error = PicklableRemoteTaskError( + original_type="ValueError", + message="bad payload", + retryable_hint=None, + ) + wrapped_runtime_error = PicklableRemoteTaskError( + original_type="RuntimeError", + message="temporary upstream", + retryable_hint=None, + ) + + assert _is_retryable_error(wrapped_value_error) is False + assert _is_retryable_error(wrapped_runtime_error) is True + + +def test_retryable_error_for_api_error(): + non_retryable = ApiError(status_code=400, code=400, message="bad request") + retryable_500 = ApiError(status_code=503, code=503, message="queue down") + retryable_409 = ApiError( + status_code=409, + code=409, + message="retry", + data={"retryable": True}, + ) + + assert _is_retryable_error(non_retryable) is False + assert _is_retryable_error(retryable_500) is True + assert _is_retryable_error(retryable_409) is True + + +def test_execute_task_wraps_non_picklable_exception(monkeypatch): + class _NonPicklableError(Exception): + def __init__(self): + self.lock = threading.Lock() + super().__init__("cannot pickle lock") + + def raise_non_picklable(*_args, **_kwargs): + raise _NonPicklableError() + + monkeypatch.setattr("fileflash.workers.dispatcher.dispatch_task", raise_non_picklable) + + with pytest.raises(PicklableRemoteTaskError) as exc_info: + execute_task("task.archive_preview", {}) + + wrapped = exc_info.value + assert wrapped.original_type == "_NonPicklableError" + assert wrapped.message == "cannot pickle lock" + + +@pytest.mark.asyncio +async def test_process_message_backfills_missing_job_id_in_payload(monkeypatch): + config = WorkerRuntimeConfig( + poll_interval_seconds=1.0, + task_timeout_seconds=30, + worker_slots=1, + default_max_attempts=5, + retry_backoff_seconds=(1, 2, 3), + queue_stream="fileflash:tasks", + queue_group="fileflash-workers", + queue_block_ms=1000, + ffmpeg_binary="ffmpeg", + ffprobe_binary="ffprobe", + ) + queue = SimpleNamespace() + session = SimpleNamespace(begin=lambda: _AsyncContextManager(SimpleNamespace())) + session_factory = lambda: _AsyncContextManager(session) + consumer = WorkerConsumer( + config=config, + executor=None, # type: ignore[arg-type] + queue=queue, # type: ignore[arg-type] + session_factory=session_factory, # type: ignore[arg-type] + ) + message = WorkerJobMessage( + version=1, + message_id="job-123-attempt-0", + job_id=123, + task_type="task.archive_extract", + idempotency_key=None, + attempt=0, + max_attempts=5, + trace_id="trace-1", + requested_by="9", + payload={"targetFolderId": "root", "jobId": None}, + ) + + loop = asyncio.get_running_loop() + captured_payload: dict[str, object] = {} + + async def fake_wait_for(awaitable, timeout): + return await awaitable + + def fake_run_in_executor(_executor, _fn, _task_type, payload): + captured_payload.update(payload) + fut = loop.create_future() + fut.set_result({"summary": {}}) + return fut + + monkeypatch.setattr("fileflash.workers.consumer.apply_task_effects", AsyncMock(return_value={})) + monkeypatch.setattr("fileflash.workers.consumer.mark_job_succeeded", AsyncMock()) + monkeypatch.setattr("fileflash.workers.consumer.asyncio.wait_for", fake_wait_for) + monkeypatch.setattr(loop, "run_in_executor", fake_run_in_executor) + + await consumer._process_message(slot=0, message=message) + + assert captured_payload["jobId"] == 123 + + +@pytest.mark.asyncio +async def test_process_message_backfills_empty_job_id_in_payload(monkeypatch): + config = WorkerRuntimeConfig( + poll_interval_seconds=1.0, + task_timeout_seconds=30, + worker_slots=1, + default_max_attempts=5, + retry_backoff_seconds=(1, 2, 3), + queue_stream="fileflash:tasks", + queue_group="fileflash-workers", + queue_block_ms=1000, + ffmpeg_binary="ffmpeg", + ffprobe_binary="ffprobe", + ) + queue = SimpleNamespace() + session = SimpleNamespace(begin=lambda: _AsyncContextManager(SimpleNamespace())) + session_factory = lambda: _AsyncContextManager(session) + consumer = WorkerConsumer( + config=config, + executor=None, # type: ignore[arg-type] + queue=queue, # type: ignore[arg-type] + session_factory=session_factory, # type: ignore[arg-type] + ) + message = WorkerJobMessage( + version=1, + message_id="job-321-attempt-0", + job_id=321, + task_type="task.archive_extract", + idempotency_key=None, + attempt=0, + max_attempts=5, + trace_id="trace-2", + requested_by="9", + payload={"targetFolderId": "root", "jobId": ""}, + ) + + loop = asyncio.get_running_loop() + captured_payload: dict[str, object] = {} + + async def fake_wait_for(awaitable, timeout): + return await awaitable + + def fake_run_in_executor(_executor, _fn, _task_type, payload): + captured_payload.update(payload) + fut = loop.create_future() + fut.set_result({"summary": {}}) + return fut + + monkeypatch.setattr("fileflash.workers.consumer.apply_task_effects", AsyncMock(return_value={})) + monkeypatch.setattr("fileflash.workers.consumer.mark_job_succeeded", AsyncMock()) + monkeypatch.setattr("fileflash.workers.consumer.asyncio.wait_for", fake_wait_for) + monkeypatch.setattr(loop, "run_in_executor", fake_run_in_executor) + + await consumer._process_message(slot=0, message=message) + + assert captured_payload["jobId"] == 321 + + +@pytest.mark.asyncio +async def test_process_transcode_message_marks_running(monkeypatch): + config = WorkerRuntimeConfig( + poll_interval_seconds=1.0, + task_timeout_seconds=30, + worker_slots=1, + default_max_attempts=5, + retry_backoff_seconds=(1, 2, 3), + queue_stream="fileflash:tasks", + queue_group="fileflash-workers", + queue_block_ms=1000, + ffmpeg_binary="ffmpeg", + ffprobe_binary="ffprobe", + ) + queue = SimpleNamespace() + session = SimpleNamespace(begin=lambda: _AsyncContextManager(SimpleNamespace())) + session_factory = lambda: _AsyncContextManager(session) + consumer = WorkerConsumer( + config=config, + executor=None, # type: ignore[arg-type] + queue=queue, # type: ignore[arg-type] + session_factory=session_factory, # type: ignore[arg-type] + ) + message = WorkerJobMessage( + version=1, + message_id="job-666-attempt-0", + job_id=666, + task_type="task.transcode", + idempotency_key=None, + attempt=0, + max_attempts=5, + trace_id="trace-666", + requested_by="9", + payload={ + "sourceObjectId": 99, + "sourceBucketName": "fileflash", + "sourceObjectKey": "objects/u1/source", + "outputBucketName": "fileflash", + "outputObjectKey": "optimized/transcode/v1/object-99/source-mp4-v1.mp4", + }, + ) + + loop = asyncio.get_running_loop() + + async def fake_wait_for(awaitable, timeout): + return await awaitable + + def fake_run_in_executor(_executor, _fn, _task_type, payload): + _ = payload + fut = loop.create_future() + fut.set_result({"summary": {}}) + return fut + + mark_running_mock = AsyncMock() + monkeypatch.setattr(consumer, "_mark_transcode_running", mark_running_mock) + monkeypatch.setattr("fileflash.workers.consumer.apply_task_effects", AsyncMock(return_value={})) + monkeypatch.setattr("fileflash.workers.consumer.mark_job_succeeded", AsyncMock()) + monkeypatch.setattr("fileflash.workers.consumer.asyncio.wait_for", fake_wait_for) + monkeypatch.setattr(loop, "run_in_executor", fake_run_in_executor) + + await consumer._process_message(slot=0, message=message) + mark_running_mock.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_handle_failure_marks_transcode_failed_on_terminal_state(monkeypatch): + config = WorkerRuntimeConfig( + poll_interval_seconds=1.0, + task_timeout_seconds=30, + worker_slots=1, + default_max_attempts=5, + retry_backoff_seconds=(1, 2, 3), + queue_stream="fileflash:tasks", + queue_group="fileflash-workers", + queue_block_ms=1000, + ffmpeg_binary="ffmpeg", + ffprobe_binary="ffprobe", + ) + queue = SimpleNamespace() + session = SimpleNamespace(begin=lambda: _AsyncContextManager(SimpleNamespace())) + session_factory = lambda: _AsyncContextManager(session) + consumer = WorkerConsumer( + config=config, + executor=None, # type: ignore[arg-type] + queue=queue, # type: ignore[arg-type] + session_factory=session_factory, # type: ignore[arg-type] + ) + message = WorkerJobMessage( + version=1, + message_id="job-777-attempt-0", + job_id=777, + task_type="task.transcode", + idempotency_key=None, + attempt=0, + max_attempts=1, + trace_id="trace-777", + requested_by="9", + payload={"sourceObjectId": 77, "outputObjectKey": "optimized/x.mp4"}, + ) + + monkeypatch.setattr("fileflash.workers.consumer.mark_job_failed_or_retrying", AsyncMock(return_value="failed")) + failed_mock = AsyncMock() + monkeypatch.setattr("fileflash.workers.consumer.mark_transcode_failed", failed_mock) + monkeypatch.setattr("fileflash.workers.consumer.asyncio.create_task", lambda _task: None) + + await consumer._handle_failure(slot=0, message=message, error=RuntimeError("boom")) + failed_mock.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_process_upload_merge_message_uses_upload_service_path(monkeypatch): + config = WorkerRuntimeConfig( + poll_interval_seconds=1.0, + task_timeout_seconds=30, + worker_slots=1, + default_max_attempts=5, + retry_backoff_seconds=(1, 2, 3), + queue_stream="fileflash:tasks", + queue_group="fileflash-workers", + queue_block_ms=1000, + ffmpeg_binary="ffmpeg", + ffprobe_binary="ffprobe", + ) + queue = SimpleNamespace() + session = SimpleNamespace(begin=lambda: _AsyncContextManager(SimpleNamespace())) + session_factory = lambda: _AsyncContextManager(session) + consumer = WorkerConsumer( + config=config, + executor=None, # type: ignore[arg-type] + queue=queue, # type: ignore[arg-type] + session_factory=session_factory, # type: ignore[arg-type] + ) + message = WorkerJobMessage( + version=1, + message_id="job-999-attempt-0", + job_id=999, + task_type="task.upload_merge", + idempotency_key=None, + attempt=0, + max_attempts=5, + trace_id="trace-999", + requested_by="9", + payload={"userId": 9, "uploadId": "upload-1", "mergeRequest": {"fileHash": "a" * 64}}, + ) + + async def fake_wait_for(awaitable, timeout): + return await awaitable + + loop = asyncio.get_running_loop() + run_in_executor_called = False + + def fake_run_in_executor(*_args, **_kwargs): + nonlocal run_in_executor_called + run_in_executor_called = True + fut = loop.create_future() + fut.set_result({}) + return fut + + monkeypatch.setattr(consumer, "_run_upload_merge", AsyncMock(return_value={"fileId": "f1"})) + monkeypatch.setattr("fileflash.workers.consumer.mark_job_succeeded", AsyncMock()) + monkeypatch.setattr("fileflash.workers.consumer.asyncio.wait_for", fake_wait_for) + monkeypatch.setattr(loop, "run_in_executor", fake_run_in_executor) + + await consumer._process_message(slot=0, message=message) + + consumer._run_upload_merge.assert_awaited_once() # type: ignore[attr-defined] + assert run_in_executor_called is False diff --git a/app/uv.lock b/app/uv.lock index 290e561..f1aa17f 100644 --- a/app/uv.lock +++ b/app/uv.lock @@ -525,7 +525,7 @@ wheels = [ [[package]] name = "fileflash" version = "0.1.0" -source = { virtual = "." } +source = { editable = "." } dependencies = [ { name = "asyncpg" }, { name = "fastapi" }, diff --git a/docker/flyway/migrations/V10__identity_avatar.sql b/docker/flyway/migrations/V10__identity_avatar.sql new file mode 100644 index 0000000..99f3c46 --- /dev/null +++ b/docker/flyway/migrations/V10__identity_avatar.sql @@ -0,0 +1,9 @@ +-- ========================= +-- Domain: identity +-- Add avatar field for user profile +-- ========================= + +ALTER TABLE "user" + ADD COLUMN IF NOT EXISTS avatar VARCHAR(512); + +COMMENT ON COLUMN "user".avatar IS 'User avatar URL'; diff --git a/docker/flyway/migrations/V11__identity_registration_email_domain_rule.sql b/docker/flyway/migrations/V11__identity_registration_email_domain_rule.sql new file mode 100644 index 0000000..0ad41d9 --- /dev/null +++ b/docker/flyway/migrations/V11__identity_registration_email_domain_rule.sql @@ -0,0 +1,25 @@ +-- ========================= +-- Domain: identity +-- Registration email domain regex rules +-- ========================= + +CREATE TABLE IF NOT EXISTS registration_email_domain_rule ( + rule_id BIGINT GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, + name VARCHAR(120) NOT NULL, + pattern VARCHAR(512) NOT NULL, + enabled BOOLEAN NOT NULL DEFAULT TRUE, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE UNIQUE INDEX IF NOT EXISTS uk_registration_email_domain_rule_name_ci + ON registration_email_domain_rule ((LOWER(name))); + +CREATE INDEX IF NOT EXISTS idx_registration_email_domain_rule_enabled + ON registration_email_domain_rule (enabled); + +DROP TRIGGER IF EXISTS trg_registration_email_domain_rule_updated_at ON registration_email_domain_rule; +CREATE TRIGGER trg_registration_email_domain_rule_updated_at +BEFORE UPDATE ON registration_email_domain_rule +FOR EACH ROW +EXECUTE FUNCTION set_updated_at(); diff --git a/docker/flyway/migrations/V12__identity_seed_scu_email_domain_rule.sql b/docker/flyway/migrations/V12__identity_seed_scu_email_domain_rule.sql new file mode 100644 index 0000000..16727a3 --- /dev/null +++ b/docker/flyway/migrations/V12__identity_seed_scu_email_domain_rule.sql @@ -0,0 +1,16 @@ +-- ========================= +-- Domain: identity +-- Seed default registration email domain rule +-- ========================= + +INSERT INTO registration_email_domain_rule ( + name, + pattern, + enabled +) +VALUES ( + 'allow_scu_edu_cn', + '(?:[a-z0-9-]+\.)*scu\.edu\.cn', + TRUE +) +ON CONFLICT DO NOTHING; diff --git a/docs/superpowers/plans/2026-05-11-frontend-redesign-p5-public-auth-flow.md b/docs/superpowers/plans/2026-05-11-frontend-redesign-p5-public-auth-flow.md new file mode 100644 index 0000000..d8c50b0 --- /dev/null +++ b/docs/superpowers/plans/2026-05-11-frontend-redesign-p5-public-auth-flow.md @@ -0,0 +1,1445 @@ +# P5 Public Auth Flow Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Rewrite the 4 public auth pages (`/login`, `/register`, `/forgot-password`, `/verify-email`) against the new Industrial Dashboard system. Build a shared `AuthForm` organism that drives Login / Register / ForgotPassword via a `mode` prop. Each page file must be ≤ 100 lines and contain only API + nav orchestration. VerifyEmail uses BareLayout with inline atoms/molecules (no AuthForm — no fields). Functional parity required: no behavior is dropped. + +**Architecture:** +- `components/organisms/auth/AuthForm.vue` — single organism owning the visual shell + per-mode field set. Discriminated `submit` payload keeps page logic typed. +- `components/organisms/auth/index.ts` — public barrel. +- `pages/login/Login.vue`, `pages/register/Register.vue`, `pages/forgot-password/ForgotPassword.vue` — rewritten ≤ 100 lines each, mount `AuthForm` and react to its `submit` event. +- `pages/verify-email/VerifyEmail.vue` — rewritten ≤ 100 lines, uses atoms (`Text`, `Spinner`, `Dot`) + `Button` molecule directly. No form fields. +- `pages/__dev/Library.vue` — adds an `Organisms · Auth` section demoing all 3 AuthForm modes + the VerifyEmail status block. +- **Not deleted in P5:** the existing `AuthLayout.vue` still imports `useThemeStore` from `store/theme.ts`. Don't touch that — P6 replaces themeStore with preferencesStore. + +**Tech Stack:** Vue 3 ` + + + + +``` + +- [ ] **Step 5: Run tests to verify pass** + +```bash +cd web && bun x vitest run src/components/organisms/auth/AuthForm.spec.ts && bun run check +``` + +Expected: 9 passing, type-check clean. If a spec assertion fails because input order in `findAll('input')` differs from expectation (Checkbox is rendered as `` inside the molecule), tighten the selectors: use `input[type="text"]`, `input[type="password"]`, `input[type="checkbox"]` rather than positional indexing. Do NOT change the production code to match a flaky test. + +- [ ] **Step 6: Commit** + +```bash +git add web/src/components/organisms/auth/AuthForm.vue \ + web/src/components/organisms/auth/AuthForm.spec.ts \ + web/src/components/organisms/auth/index.ts +git commit -m "feat(organisms/auth): add AuthForm (login/register/forgot modes)" +``` + +--- + +## Phase B — Dev library coverage + +### Task 2: Add `Organisms · Auth` section to Library + +**Files:** +- Modify: `web/src/pages/__dev/Library.vue` + +**Design notes:** +- Extend the `sections` tuple with `'Organisms · Auth'`. +- Add demo state refs: `authMode: 'login' | 'register' | 'forgot'`, `authSubmitting`, `authError`, `authSuccess`, `lastSubmit`. +- A small SegmentedControl (already used elsewhere) lets the reader switch between the 3 modes live. +- Three label bundles (English, ASCII only) so the reader can see what each mode renders. +- A `
` panel shows the most recent submit payload for sanity.
+
+- [ ] **Step 1: Edit `pages/__dev/Library.vue`**
+
+In the existing `
+
+
+
+
+```
+
+- [ ] **Step 2: Verify line count**
+
+```bash
+wc -l web/src/pages/login/Login.vue
+```
+
+Expected: ≤ 100. If over, condense `saved` derivation inline and remove the `initial` computed (pass `:initial="saved"` directly).
+
+- [ ] **Step 3: Run check + tests**
+
+```bash
+cd web && bun run check && bun run test
+```
+
+Expected: type-check clean. Tests untouched (none specifically target Login). If `vue-tsc` complains about the `var(--weight-semibold)` / `var(--tracking-wide)` references being unknown — they're CSS variables defined in `web/src/styles/tokens/type.css`, so tsc won't flag them. If you misnamed a token (e.g. `--font-weight-semibold` legacy), grep tokens:
+
+```bash
+grep -nE "^\s*--weight-|--tracking-" src/styles/tokens/type.css
+```
+
+and align.
+
+- [ ] **Step 4: Commit**
+
+```bash
+git add web/src/pages/login/Login.vue
+git commit -m "refactor(pages/login): rewrite Login against AuthForm (≤100 lines)"
+```
+
+---
+
+### Task 4: Rewrite Register.vue (≤ 100 lines, Chinese strings preserved)
+
+**Files:**
+- Modify: `web/src/pages/register/Register.vue` (was 225 lines → target ≤ 100)
+
+- [ ] **Step 1: Replace the whole file**
+
+```vue
+
+
+
+```
+
+- [ ] **Step 2: Verify line count**
+
+```bash
+wc -l web/src/pages/register/Register.vue
+```
+
+Expected: ≤ 100.
+
+- [ ] **Step 3: Run check + tests**
+
+```bash
+cd web && bun run check && bun run test
+```
+
+Expected: green.
+
+- [ ] **Step 4: Commit**
+
+```bash
+git add web/src/pages/register/Register.vue
+git commit -m "refactor(pages/register): rewrite Register against AuthForm (≤100 lines)"
+```
+
+---
+
+### Task 5: Rewrite ForgotPassword.vue (≤ 100 lines, Chinese strings preserved)
+
+**Files:**
+- Modify: `web/src/pages/forgot-password/ForgotPassword.vue` (was 151 lines → target ≤ 100)
+
+- [ ] **Step 1: Replace the whole file**
+
+```vue
+
+
+
+```
+
+- [ ] **Step 2: Verify line count + tests**
+
+```bash
+wc -l web/src/pages/forgot-password/ForgotPassword.vue && cd web && bun run check && bun run test
+```
+
+Expected: ≤ 100, green.
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add web/src/pages/forgot-password/ForgotPassword.vue
+git commit -m "refactor(pages/forgot-password): rewrite against AuthForm (≤100 lines)"
+```
+
+---
+
+### Task 6: Rewrite VerifyEmail.vue (≤ 100 lines, inline atoms — no AuthForm)
+
+**Files:**
+- Modify: `web/src/pages/verify-email/VerifyEmail.vue` (was 195 lines → target ≤ 100)
+
+**Design notes:**
+- No form fields → AuthForm isn't a fit. Use atoms (`Text`, `Spinner`, `Dot`) + `Button` molecule directly.
+- 4 visual states: `pending` (token-driven verifying), `idle` (waiting for user to check email), `success` (verified), `error` (token invalid / failed).
+- Dot color encodes state via the existing `Dot` atom tones (`accent` | `success` | `warning` | `error` | `info` — confirmed). Use `success` / `error` / `accent` (pending) / `info` (idle).
+- The page lives inside `BareLayout` — already provides centering and background. The page just renders one centered `
` ≤ 420px wide. + +- [ ] **Step 1: Replace the whole file** + +```vue + + + + + +``` + +- [ ] **Step 2: Sanity-check the Dot atom signature (already verified in plan but re-check after pull)** + +```bash +grep -nE "tone\??:|'success'|'error'|'accent'|'info'" src/components/atoms/Dot.vue +``` + +Expected: shows tones `'accent' | 'success' | 'warning' | 'error' | 'info'`. The page already targets only `success` / `error` / `accent` / `info`. If the atom signature has drifted since this plan was written, adapt `dotTone` to use only available tones — **do not modify the Dot atom in this task**. + +- [ ] **Step 3: Verify line count + tests** + +```bash +wc -l web/src/pages/verify-email/VerifyEmail.vue && cd web && bun run check && bun run test +``` + +Expected: ≤ 100, green. + +- [ ] **Step 4: Commit** + +```bash +git add web/src/pages/verify-email/VerifyEmail.vue +git commit -m "refactor(pages/verify-email): rewrite against atoms/molecules (≤100 lines)" +``` + +--- + +## Phase D — Verification + +### Task 7: Full pipeline + token discipline grep + +- [ ] **Step 1: Full pipeline** + +```bash +cd web && bun run test && bun run check && bun run build +``` + +Expected: all green, build artifact produced. + +- [ ] **Step 2: Token discipline — grep new auth code for legacy color/font references** + +```bash +cd web && grep -nE "#[0-9a-fA-F]{3,8}|--color-[a-z]|cubic-bezier\(0.4|translateY\(-1px\)|Manrope|backdrop-filter|linear-gradient" \ + src/components/organisms/auth/*.vue \ + src/pages/login/Login.vue \ + src/pages/register/Register.vue \ + src/pages/forgot-password/ForgotPassword.vue \ + src/pages/verify-email/VerifyEmail.vue +``` + +Expected: 0 hits. If any sneak in, replace with token references. The acceptable exceptions are documented hex values inside `web/src/styles/tokens/*.css` only. + +- [ ] **Step 3: Border-radius audit on new auth files** + +```bash +cd web && grep -nE "border-radius" \ + src/components/organisms/auth/*.vue \ + src/pages/login/Login.vue \ + src/pages/register/Register.vue \ + src/pages/forgot-password/ForgotPassword.vue \ + src/pages/verify-email/VerifyEmail.vue +``` + +Expected: only `var(--radius-sm)` / `var(--radius-md)` / `0`. No literal pixel values like `10px` / `18px` / `42px` for radius. + +- [ ] **Step 4: Line-count audit** + +```bash +wc -l src/pages/login/Login.vue \ + src/pages/register/Register.vue \ + src/pages/forgot-password/ForgotPassword.vue \ + src/pages/verify-email/VerifyEmail.vue +``` + +Expected: each file ≤ 100. Total ≤ 400. + +- [ ] **Step 5: Manual smoke test — 12-combo coverage** + +Run `cd web && bun run dev`. For each route below, cycle `data-accent` (lime/amber/oxide), `data-theme` (dark/light), and `data-motion` (spring/tight/reduced) via DevTools (`document.documentElement.dataset.accent = 'amber'` etc.). Confirm no visual breakage. + +**Per-route checklist:** + +`/login`: +- a. Page renders inside AuthLayout (brand block + card). +- b. Username + password + Remember me visible; mock-account hint visible. +- c. Submit with `admin/admin123` → land on `/files` (assuming backend supports the mock — fall back to checking a 401 error renders cleanly in the error row). +- d. Submit with wrong password → red error block under fields. +- e. Toggle Remember me + valid login → reload → username pre-filled, checkbox pre-checked. +- f. Untick Remember me + valid login → reload → fields reset, no localStorage leak. +- g. Click "Create one" → navigates to `/register`, no full-page fade. +- h. Click "Forgot password" → navigates to `/forgot-password`. +- i. Click eye → password text becomes visible. +- j. While submitting, button shows spinner and is disabled. + +`/register`: +- a. 4 fields visible, Chinese labels. +- b. Submit with mismatched passwords → "两次输入的密码不一致。" appears, no API call. +- c. Submit valid → land on `/verify-email` or `/login` based on backend response. +- d. Click "前往登录" → navigates to `/login`. + +`/forgot-password`: +- a. Single email field, Chinese subtitle. +- b. Submit valid email → success block "重置邮件已发送,请检查邮箱。". +- c. Simulate failure (offline) → red error block. +- d. Click "返回登录" → `/login`. + +`/verify-email`: +- a. Hit `/verify-email` without query → idle dot + initial English copy + (if authed) resend button. +- b. Hit `/verify-email?token=BAD` → pending dot, then error message. +- c. Authed + already verified → success dot + "Your email has already been verified.". +- d. Click "Back to login" → `/login`. Click "Enter files" (when authed) → `/files`. +- e. Click resend → button enters loading, success/error block updates accordingly. + +**Cross-route invariants:** +- Switching `data-accent` retints submit button, checkbox, secondary links, dot, and status borders. +- Switching `data-theme` flips surfaces; text remains WCAG AA readable. +- Switching `data-motion="reduced"` removes spring/fade animation. +- No `console.warn` / `console.error` in DevTools across any of the above. + +- [ ] **Step 6: If anything fails** + +Fix in a follow-up commit on the same task. Do not declare P5 done with broken parity. + +- [ ] **Step 7: No commit if all green** + +Otherwise add a `fix(p5): ` commit per fix. + +--- + +### Task 8: Update progress memory + +**Files:** +- Modify: `C:\Users\xc150\.claude\projects\D--pyprj-fileflash\memory\frontend_redesign_progress.md` + +- [ ] **Step 1: Move P5 from "进行中 / 待开始" into "已完成"** + +Read the file, then add an entry after the P4 row in the same format: + +``` +- **P5 Public Auth Flow**(2026-05-12)— 新组件 `organisms/auth/AuthForm.vue`(login/register/forgot 三模式 + 9 个 spec 通过)+ 4 个页面全部重写:Login X 行(旧 267)/ Register Y 行(旧 225)/ ForgotPassword Z 行(旧 151)/ VerifyEmail W 行(旧 195)。VerifyEmail 走 BareLayout,不用 AuthForm(无表单字段,直接拼 atoms/molecules)。dev library 加 `Organisms · Auth` 段含 mode picker。AuthLayout 仍用旧 themeStore(P6 替换)。 +``` + +Replace `X/Y/Z/W` with the actual `wc -l` outputs from Task 7 Step 4. + +Then remove `**P5**` from "进行中 / 待开始" section. + +- [ ] **Step 2: Commit a chore entry to the repo** + +```bash +git commit --allow-empty -m "chore(progress): mark P5 Public Auth Flow complete" +``` + +(The memory file lives outside the repo, so its update is not staged; an empty commit records the milestone for git history.) + +--- + +## Self-Review checklist + +After all tasks land, run this once. + +1. **Spec coverage** — spec §3.1 calls for `organisms/auth/AuthForm.vue` for Login/Register/ForgotPassword. ✅ Created in Task 1. Spec §4 P5 row lists all 4 pages — ✅ all 4 rewritten in Tasks 3–6. Spec §3.3 dev library coverage — ✅ Task 2. + +2. **Pages ≤ 100 lines** — Task 7 Step 4 grep confirms. + +3. **Token discipline** — Task 7 Step 2 grep returns empty. + +4. **No new `common/*` imports from new auth files** — grep: + +```bash +cd web && grep -nE "from '\.\./\.\./common/" \ + src/components/organisms/auth/*.vue \ + src/pages/login/Login.vue \ + src/pages/register/Register.vue \ + src/pages/forgot-password/ForgotPassword.vue \ + src/pages/verify-email/VerifyEmail.vue +``` + +Expected: empty. + +5. **Sharp edges** — Task 7 Step 3 grep confirms. + +6. **Build + test green** — Task 7 Step 1. + +7. **Manual smoke covers all 12 token combos × all 4 routes** — Task 7 Step 5. + +If any check fails, add a fix commit before marking P5 done. + +--- + +## Execution Handoff + +**Plan complete and saved to `docs/superpowers/plans/2026-05-11-frontend-redesign-p5-public-auth-flow.md`. Two execution options:** + +**1. Subagent-Driven (recommended)** — dispatch a fresh subagent per task, review between tasks, fast iteration. + +**2. Inline Execution** — execute tasks in this session using `superpowers:executing-plans`, batch execution with checkpoints. + +**Which approach?** diff --git a/docs/superpowers/plans/2026-05-12-frontend-redesign-p4-other-file-surfaces.md b/docs/superpowers/plans/2026-05-12-frontend-redesign-p4-other-file-surfaces.md new file mode 100644 index 0000000..fd059af --- /dev/null +++ b/docs/superpowers/plans/2026-05-12-frontend-redesign-p4-other-file-surfaces.md @@ -0,0 +1,124 @@ +# P4 Other File Surfaces · Implementation Plan + +**Spec reference**: `docs/superpowers/specs/2026-05-11-frontend-quality-redesign-design.md` §3.1, §4 (P4 row), §9 (acceptance criteria) +**Predecessor**: P3 Core File Path (verified green: 246 tests / typecheck clean as of 2026-05-12) +**Goal**: Migrate Shared / Trash / ShareAccess pages to the Industrial Dashboard system. Each rewritten page file ≤ 100 lines. All visual responsibility moves into organisms under `components/organisms/`. Legacy CSS tokens (`var(--color-border)`, `var(--color-bg-primary)`, `var(--border-radius-*)`) eliminated from these surfaces. + +## Scope + +3 page files in scope: + +| Page | Current LOC | Target LOC | +|---|---|---| +| `pages/shared/SharedWithMe.vue` | 386 | ≤ 100 | +| `pages/trash/Trash.vue` | 258 | ≤ 100 | +| `pages/share/ShareAccess.vue` | 363 | ≤ 100 | + +Note on spec wording "复用 FileTable": after reviewing the data shapes, the three surfaces have distinct column sets that do **not** map cleanly onto `FileTable`'s `ContentItem` contract. Building shared sibling table organisms keeps `FileTable` focused. We will still **reuse atoms, molecules, EmptyState, and the design tokens** — that's the real reuse target. + +## New Organisms + +### `components/organisms/sharing/` (new folder) +- `SharedReceivedTable.vue` — header + rows for `SharedItem[]`. Columns: checkbox, name + type tag, sharedBy, permission, sharedAt (mono), accept-action. Emits `toggle`, `toggle-all`, `accept`. +- `SharedLinksTable.vue` — header + rows for `Share[]`. Columns: resource name + type, share-link code, visits/downloads (mono), createdAt (mono), copy/delete actions. Emits `copy`, `delete`. +- `SharedBatchBar.vue` — selection summary + "Accept Selected" action. Mirrors `BulkActionBar` pattern (count + actions; floating overlay via Transition). +- `index.ts` — public barrel + +### `components/organisms/trash/` (new folder) +- `TrashTable.vue` — header + rows for `RecycleBinItem[]`. Columns: icon + name, originalPath, deletedAt (mono), expires-in (mono, accent tint when ≤ 7 days), restore/delete actions. Emits `restore`, `permanent-delete`. +- `index.ts` + +### `components/organisms/share/` (new folder; not to be confused with `sharing/` above) +- `ShareInfoCard.vue` — read-only metadata card for an accessed share. Rows: Type, Name, Size (mono), Expires, Password. Uses small uppercase labels per design system. +- `ShareAccessPanel.vue` — gate panel. Two modes: password-protected (TextField + Unlock) or open-access (single "Get Access" button). Emits `request-access`. +- `ShareActionsPanel.vue` — post-access actions: Preview / Download (file only) / Save to My Space. Emits `preview`, `download`, `save`. +- `index.ts` + +### `components/organisms/files/` +- `EmptyState.vue` — extend with `variant: 'loading' | 'empty' | 'no-results' | 'error'`. The `'error'` variant is new (replaces the inline `.state.error` block in ShareAccess). Already has `loading`/`empty`/`no-results` from P3; just adds the error case. + +## Page Rewrites + +### `pages/shared/SharedWithMe.vue` (~80 lines target) +- `