From f1e77d42be32fc90ea69c833ca98d36afc266df3 Mon Sep 17 00:00:00 2001 From: DevOpsMadDog Date: Sun, 5 Oct 2025 11:54:10 +1100 Subject: [PATCH] Add OIDC tenant auth and RBAC enforcement --- backend/app.py | 485 ++++++++++++++++++++++++++++++++++----- fixops/configuration.py | 155 ++++++++++++- tests/test_end_to_end.py | 244 ++++++++++++++++++++ 3 files changed, 823 insertions(+), 61 deletions(-) diff --git a/backend/app.py b/backend/app.py index d2e5d5110..34f1d7e37 100644 --- a/backend/app.py +++ b/backend/app.py @@ -1,12 +1,17 @@ from __future__ import annotations +import base64 import csv +import hashlib +import hmac import io +import json import logging -from pathlib import Path -from typing import Any, Dict, Optional +import time +from typing import Any, Dict, Iterable, Mapping, Optional -from fastapi import Depends, FastAPI, File, HTTPException, UploadFile +import requests +from fastapi import Depends, FastAPI, File, Header, HTTPException, Request, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.security import APIKeyHeader @@ -37,42 +42,306 @@ def create_app() -> FastAPI: orchestrator = PipelineOrchestrator() overlay = load_overlay() - # API authentication setup + # API authentication and authorisation setup auth_strategy = overlay.auth.get("strategy", "").lower() - header_name = overlay.auth.get("header", "X-API-Key") - api_key_header = APIKeyHeader(name=header_name, auto_error=False) + api_header_name = overlay.auth.get("header", "X-API-Key") + tenant_header_name = overlay.auth.get("tenant_header", "X-FixOps-Tenant") + default_tenant = overlay.auth.get("default_tenant", "default") + api_key_header = APIKeyHeader(name=api_header_name, auto_error=False) expected_tokens = overlay.auth_tokens if auth_strategy == "token" else tuple() + jwks_cache: Dict[str, list[Dict[str, Any]]] = {} - async def _verify_api_key(api_key: Optional[str] = Depends(api_key_header)) -> None: - if auth_strategy != "token": + def _normalise_sequence(value: 
Any) -> list[str]: + if value is None: + return [] + if isinstance(value, str): + return [item for item in value.replace(",", " ").split() if item] + if isinstance(value, Iterable): + items: list[str] = [] + for entry in value: + if isinstance(entry, str): + candidate = entry.strip() + if candidate: + items.append(candidate) + elif entry is not None: + items.append(str(entry)) + return list(dict.fromkeys(items)) + return [str(value)] + + def _b64url_decode(data: str) -> bytes: + padded = data + "=" * (-len(data) % 4) + return base64.urlsafe_b64decode(padded.encode("ascii")) + + def _load_jwks(provider_id: str, provider: Mapping[str, Any]) -> list[Dict[str, Any]]: + cached = jwks_cache.get(provider_id) + if cached is not None: + return cached + + keys: list[Dict[str, Any]] = [] + jwks = provider.get("jwks") if isinstance(provider, Mapping) else None + if isinstance(jwks, Mapping): + raw_keys = jwks.get("keys") + if isinstance(raw_keys, Iterable): + for entry in raw_keys: + if isinstance(entry, Mapping): + keys.append(dict(entry)) + elif isinstance(jwks, Iterable): + for entry in jwks: + if isinstance(entry, Mapping): + keys.append(dict(entry)) + elif isinstance(provider, Mapping) and provider.get("jwks_url"): + try: + response = requests.get(str(provider["jwks_url"]), timeout=5) + response.raise_for_status() + data = response.json() + except requests.RequestException as exc: # pragma: no cover - network instability is environment specific + raise HTTPException(status_code=503, detail="Failed to fetch identity provider keys") from exc + else: + raw_keys = data.get("keys") + if isinstance(raw_keys, Iterable): + for entry in raw_keys: + if isinstance(entry, Mapping): + keys.append(dict(entry)) + if not keys: + raise HTTPException(status_code=500, detail="Identity provider does not expose any keys") + jwks_cache[provider_id] = keys + return keys + + def _select_jwk(keys: Iterable[Mapping[str, Any]], kid: Optional[str], algorithm: str) -> Mapping[str, Any]: + candidates: 
list[Mapping[str, Any]] = [] + for key in keys: + if kid and key.get("kid") != kid: + continue + key_alg = key.get("alg") + if key_alg and key_alg != algorithm: + continue + candidates.append(key) + if not candidates: + for key in keys: + key_alg = key.get("alg") + if key_alg and key_alg != algorithm: + continue + candidates.append(key) + break + if not candidates: + raise HTTPException(status_code=401, detail="No matching signing key found for token") + return dict(candidates[0]) + + def _verify_hs256_signature(key: Mapping[str, Any], signing_input: bytes, signature: bytes) -> None: + secret = key.get("k") + if not isinstance(secret, str): + raise HTTPException(status_code=500, detail="Invalid symmetric key configuration") + expected = hmac.new(_b64url_decode(secret), signing_input, hashlib.sha256).digest() + if not hmac.compare_digest(expected, signature): + raise HTTPException(status_code=401, detail="Invalid token signature") + + def _verify_rs256_signature(key: Mapping[str, Any], signing_input: bytes, signature: bytes) -> None: + modulus_b64 = key.get("n") + exponent_b64 = key.get("e", "AQAB") + if not isinstance(modulus_b64, str) or not isinstance(exponent_b64, str): + raise HTTPException(status_code=500, detail="Invalid RSA key material") + modulus = int.from_bytes(_b64url_decode(modulus_b64), "big") + exponent = int.from_bytes(_b64url_decode(exponent_b64), "big") + if modulus <= 0 or exponent <= 0: + raise HTTPException(status_code=500, detail="Invalid RSA key material") + signature_int = int.from_bytes(signature, "big") + key_size = (modulus.bit_length() + 7) // 8 + if len(signature) != key_size: + signature_int = int.from_bytes(signature.rjust(key_size, b"\x00"), "big") + decrypted = pow(signature_int, exponent, modulus) + decrypted_bytes = decrypted.to_bytes(key_size, "big") + digest = hashlib.sha256(signing_input).digest() + digestinfo_prefix = bytes.fromhex("3031300d060960864801650304020105000420") + padding_len = key_size - len(digestinfo_prefix) - 
len(digest) - 3 + if padding_len < 8: + raise HTTPException(status_code=401, detail="Invalid token signature") + expected = b"\x00\x01" + b"\xff" * padding_len + b"\x00" + digestinfo_prefix + digest + if decrypted_bytes != expected: + raise HTTPException(status_code=401, detail="Invalid token signature") + + def _decode_and_verify_jwt(token: str, provider_id: str, provider: Mapping[str, Any]) -> Dict[str, Any]: + try: + header_segment, payload_segment, signature_segment = token.split(".") + except ValueError as exc: + raise HTTPException(status_code=401, detail="Malformed bearer token") from exc + + try: + header = json.loads(_b64url_decode(header_segment)) + except json.JSONDecodeError as exc: + raise HTTPException(status_code=401, detail="Invalid token header") from exc + + algorithm = header.get("alg") + if not isinstance(algorithm, str): + raise HTTPException(status_code=401, detail="Token missing algorithm") + + signing_input = f"{header_segment}.{payload_segment}".encode("ascii") + try: + signature = _b64url_decode(signature_segment) + except Exception as exc: # pragma: no cover - unexpected encoding errors + raise HTTPException(status_code=401, detail="Malformed token signature") from exc + + keys = _load_jwks(provider_id, provider) + key = _select_jwk(keys, header.get("kid"), algorithm) + + if algorithm == "HS256": + _verify_hs256_signature(key, signing_input, signature) + elif algorithm == "RS256": + _verify_rs256_signature(key, signing_input, signature) + else: + raise HTTPException(status_code=400, detail=f"Unsupported JWT algorithm '{algorithm}'") + + try: + payload = json.loads(_b64url_decode(payload_segment)) + except json.JSONDecodeError as exc: + raise HTTPException(status_code=401, detail="Invalid token payload") from exc + + exp = payload.get("exp") + if isinstance(exp, (int, float)) and time.time() > float(exp): + raise HTTPException(status_code=401, detail="Token has expired") + + nbf = payload.get("nbf") + if isinstance(nbf, (int, float)) and 
time.time() < float(nbf): + raise HTTPException(status_code=401, detail="Token not yet valid") + + return payload + + def _extract_roles(claims: Mapping[str, Any]) -> set[str]: + roles: set[str] = set() + for key in ("roles", "role", "groups", "permissions"): + roles.update(_normalise_sequence(claims.get(key))) + for key in ("scope", "scp"): + roles.update(_normalise_sequence(claims.get(key))) + return roles + + def _verify_oidc_token(token: str, tenant_id: str) -> Dict[str, Any]: + tenant = overlay.get_tenant(tenant_id) + if not tenant: + raise HTTPException(status_code=403, detail="Unknown tenant") + identity = tenant.get("identity") if isinstance(tenant.get("identity"), Mapping) else {} + provider_id = identity.get("provider") if isinstance(identity, Mapping) else None + if not provider_id: + defaults = overlay.tenancy_settings.get("defaults") + if isinstance(defaults, Mapping): + provider_id = defaults.get("identity_provider") + if not provider_id: + raise HTTPException(status_code=403, detail="Tenant does not have an identity provider configured") + provider = overlay.tenant_identity_providers.get(str(provider_id)) + if not provider: + raise HTTPException(status_code=403, detail="Identity provider configuration missing") + + claims = _decode_and_verify_jwt(token, str(provider_id), provider) + + issuer = provider.get("issuer") + if issuer and claims.get("iss") != issuer: + raise HTTPException(status_code=401, detail="Token issuer mismatch") + + allowed_audiences = set(_normalise_sequence(provider.get("allowed_audiences"))) + tenant_audiences = identity.get("allowed_audiences") if isinstance(identity, Mapping) else None + allowed_audiences.update(_normalise_sequence(tenant_audiences)) + if allowed_audiences: + token_audience = claims.get("aud") + token_audiences: set[str] = set() + if isinstance(token_audience, str): + token_audiences = {token_audience} + elif isinstance(token_audience, Iterable): + token_audiences = {str(entry) for entry in token_audience} + if 
not token_audiences & allowed_audiences: + raise HTTPException(status_code=403, detail="Token audience not permitted for tenant") + + return { + "tenant_id": tenant_id, + "claims": claims, + "tenant": tenant, + "identity": identity, + "provider": provider, + } + + def _enforce_rbac(action: str, context: Mapping[str, Any]) -> None: + identity = context.get("identity") + if not isinstance(identity, Mapping): + return + roles_config = identity.get("roles") + if not isinstance(roles_config, Mapping): return - if not api_key or api_key not in expected_tokens: - raise HTTPException(status_code=401, detail="Invalid or missing API token") + required = roles_config.get(action) + if required is None: + fallback_map = { + "upload": ("uploads", "ingest", "write"), + "pipeline": ("run", "execute", "pipeline_run"), + "feedback": ("feedback",), + } + for candidate in fallback_map.get(action, ()): # pragma: no cover - defensive mapping + candidate_roles = roles_config.get(candidate) + if candidate_roles is not None: + required = candidate_roles + break + required_roles = set(_normalise_sequence(required)) + if not required_roles: + return + token_roles = _extract_roles(context.get("claims", {})) + if not token_roles.intersection(required_roles): + raise HTTPException(status_code=403, detail=f"Missing required role for '{action}'") + + def _resolve_archive(tenant_id: str) -> ArtefactArchive: + archives: Dict[str, ArtefactArchive] = app.state.tenant_archives + archive = archives.get(tenant_id) + if archive is not None: + return archive + archive_dir = overlay.tenant_archive_directory(tenant_id) + archive = ArtefactArchive(archive_dir) + archives[tenant_id] = archive + return archive + + def _tenant_state(tenant_id: str) -> tuple[Dict[str, Any], Dict[str, Any], ArtefactArchive]: + artifact_map: Dict[str, Dict[str, Any]] = app.state.artifacts + tenant_artifacts = artifact_map.setdefault(tenant_id, {}) + archive_records_map: Dict[str, Dict[str, Any]] = app.state.archive_records + 
tenant_records = archive_records_map.setdefault(tenant_id, {}) + archive = _resolve_archive(tenant_id) + return tenant_artifacts, tenant_records, archive + + async def _authorise_request( + action: str, + *, + api_key: Optional[str], + tenant_id: Optional[str], + authorization: Optional[str], + ) -> Dict[str, Any]: + if auth_strategy == "token": + if not api_key or api_key not in expected_tokens: + raise HTTPException(status_code=401, detail="Invalid or missing API token") + return {"tenant_id": default_tenant, "claims": {"token": api_key}} + if auth_strategy == "oidc": + if not tenant_id: + raise HTTPException(status_code=400, detail="Missing tenant header") + if not authorization or not authorization.lower().startswith("bearer "): + raise HTTPException(status_code=401, detail="Missing bearer token") + token = authorization.split(" ", 1)[1].strip() + context = _verify_oidc_token(token, tenant_id) + _enforce_rbac(action, context) + return context + # Default to a permissive mode for demo deployments without explicit auth + return {"tenant_id": default_tenant, "claims": {}} for directory in overlay.data_directories.values(): ensure_secure_directory(directory) - archive_dir = overlay.data_directories.get("archive_dir") - if archive_dir is None: - root = ( - overlay.allowed_data_roots[0] - if overlay.allowed_data_roots - else Path("data").resolve() - ) - archive_dir = (root / "archive" / overlay.mode).resolve() - archive = ArtefactArchive(archive_dir) - app.state.normalizer = normalizer app.state.orchestrator = orchestrator - app.state.artifacts: Dict[str, Any] = {} + app.state.artifacts: Dict[str, Dict[str, Any]] = {} app.state.overlay = overlay - app.state.archive = archive + app.state.tenant_archives: Dict[str, ArtefactArchive] = {} app.state.archive_records: Dict[str, Dict[str, Any]] = {} app.state.feedback = ( FeedbackRecorder(overlay) if overlay.toggles.get("capture_feedback") else None ) + app.state.auth_strategy = auth_strategy + app.state.tenant_header = 
tenant_header_name + + if auth_strategy != "oidc": + _resolve_archive(default_tenant) async def _read_limited(file: UploadFile, stage: str) -> bytes: limit = overlay.upload_limit(stage) @@ -114,13 +383,15 @@ def _store( stage: str, payload: Any, *, + tenant_id: str, original_filename: Optional[str] = None, raw_bytes: Optional[bytes] = None, ) -> None: - logger.debug("Storing stage %s", stage) - app.state.artifacts[stage] = payload + logger.debug("Storing stage %s for tenant %s", stage, tenant_id) + tenant_artifacts, tenant_records, archive = _tenant_state(tenant_id) + tenant_artifacts[stage] = payload try: - record = app.state.archive.persist( + record = archive.persist( stage, payload, original_filename=original_filename, @@ -129,11 +400,30 @@ def _store( except Exception as exc: # pragma: no cover - persistence must not break ingestion logger.exception("Failed to persist artefact stage %s", stage) record = {"stage": stage, "error": str(exc)} - app.state.archive_records[stage] = record + tenant_records[stage] = dict(record) + + @app.post("/inputs/design") + async def ingest_design( + request: Request, + file: UploadFile = File(...), + authorization: Optional[str] = Header(None, alias="Authorization"), + api_key: Optional[str] = Depends(api_key_header), + tenant_id: Optional[str] = Header(None, alias=tenant_header_name), + ) -> Dict[str, Any]: + auth_context = await _authorise_request( + "upload", api_key=api_key, tenant_id=tenant_id, authorization=authorization + ) + request.state.tenant_id = auth_context["tenant_id"] - @app.post("/inputs/design", dependencies=[Depends(_verify_api_key)]) - async def ingest_design(file: UploadFile = File(...)) -> Dict[str, Any]: - _validate_content_type(file, ("text/csv", "application/vnd.ms-excel", "application/csv")) + _validate_content_type( + file, + ( + "text/csv", + "application/vnd.ms-excel", + "application/csv", + "text/plain", + ), + ) raw_bytes = await _read_limited(file, "design") text = raw_bytes.decode("utf-8", 
errors="ignore") reader = csv.DictReader(io.StringIO(text)) @@ -142,8 +432,14 @@ async def ingest_design(file: UploadFile = File(...)) -> Dict[str, Any]: if not rows: raise HTTPException(status_code=400, detail="Design CSV contained no rows") - dataset = {"columns": reader.fieldnames or [], "rows": rows} - _store("design", dataset, original_filename=file.filename, raw_bytes=raw_bytes) + dataset = {"columns": reader.fieldnames or [], "rows": rows, "row_count": len(rows)} + _store( + "design", + dataset, + tenant_id=auth_context["tenant_id"], + original_filename=file.filename, + raw_bytes=raw_bytes, + ) return { "stage": "design", "input_filename": file.filename, @@ -152,8 +448,19 @@ async def ingest_design(file: UploadFile = File(...)) -> Dict[str, Any]: "data": dataset, } - @app.post("/inputs/sbom", dependencies=[Depends(_verify_api_key)]) - async def ingest_sbom(file: UploadFile = File(...)) -> Dict[str, Any]: + @app.post("/inputs/sbom") + async def ingest_sbom( + request: Request, + file: UploadFile = File(...), + authorization: Optional[str] = Header(None, alias="Authorization"), + api_key: Optional[str] = Depends(api_key_header), + tenant_id: Optional[str] = Header(None, alias=tenant_header_name), + ) -> Dict[str, Any]: + auth_context = await _authorise_request( + "upload", api_key=api_key, tenant_id=tenant_id, authorization=authorization + ) + request.state.tenant_id = auth_context["tenant_id"] + _validate_content_type( file, ( @@ -171,7 +478,13 @@ async def ingest_sbom(file: UploadFile = File(...)) -> Dict[str, Any]: logger.exception("SBOM normalisation failed") raise HTTPException(status_code=400, detail=f"Failed to parse SBOM: {exc}") from exc - _store("sbom", sbom, original_filename=file.filename, raw_bytes=raw_bytes) + _store( + "sbom", + sbom, + tenant_id=auth_context["tenant_id"], + original_filename=file.filename, + raw_bytes=raw_bytes, + ) return { "stage": "sbom", "input_filename": file.filename, @@ -181,8 +494,19 @@ async def ingest_sbom(file: 
UploadFile = File(...)) -> Dict[str, Any]: ], } - @app.post("/inputs/cve", dependencies=[Depends(_verify_api_key)]) - async def ingest_cve(file: UploadFile = File(...)) -> Dict[str, Any]: + @app.post("/inputs/cve") + async def ingest_cve( + request: Request, + file: UploadFile = File(...), + authorization: Optional[str] = Header(None, alias="Authorization"), + api_key: Optional[str] = Depends(api_key_header), + tenant_id: Optional[str] = Header(None, alias=tenant_header_name), + ) -> Dict[str, Any]: + auth_context = await _authorise_request( + "upload", api_key=api_key, tenant_id=tenant_id, authorization=authorization + ) + request.state.tenant_id = auth_context["tenant_id"] + _validate_content_type( file, ( @@ -200,7 +524,13 @@ async def ingest_cve(file: UploadFile = File(...)) -> Dict[str, Any]: logger.exception("CVE feed normalisation failed") raise HTTPException(status_code=400, detail=f"Failed to parse CVE feed: {exc}") from exc - _store("cve", cve_feed, original_filename=file.filename, raw_bytes=raw_bytes) + _store( + "cve", + cve_feed, + tenant_id=auth_context["tenant_id"], + original_filename=file.filename, + raw_bytes=raw_bytes, + ) return { "stage": "cve", "input_filename": file.filename, @@ -208,8 +538,19 @@ async def ingest_cve(file: UploadFile = File(...)) -> Dict[str, Any]: "validation_errors": cve_feed.errors, } - @app.post("/inputs/sarif", dependencies=[Depends(_verify_api_key)]) - async def ingest_sarif(file: UploadFile = File(...)) -> Dict[str, Any]: + @app.post("/inputs/sarif") + async def ingest_sarif( + request: Request, + file: UploadFile = File(...), + authorization: Optional[str] = Header(None, alias="Authorization"), + api_key: Optional[str] = Depends(api_key_header), + tenant_id: Optional[str] = Header(None, alias=tenant_header_name), + ) -> Dict[str, Any]: + auth_context = await _authorise_request( + "upload", api_key=api_key, tenant_id=tenant_id, authorization=authorization + ) + request.state.tenant_id = auth_context["tenant_id"] + 
_validate_content_type( file, ( @@ -227,7 +568,13 @@ async def ingest_sarif(file: UploadFile = File(...)) -> Dict[str, Any]: logger.exception("SARIF normalisation failed") raise HTTPException(status_code=400, detail=f"Failed to parse SARIF: {exc}") from exc - _store("sarif", sarif, original_filename=file.filename, raw_bytes=raw_bytes) + _store( + "sarif", + sarif, + tenant_id=auth_context["tenant_id"], + original_filename=file.filename, + raw_bytes=raw_bytes, + ) return { "stage": "sarif", "input_filename": file.filename, @@ -235,43 +582,63 @@ async def ingest_sarif(file: UploadFile = File(...)) -> Dict[str, Any]: "tools": sarif.tool_names, } - @app.post("/pipeline/run", dependencies=[Depends(_verify_api_key)]) - async def run_pipeline() -> Dict[str, Any]: - overlay: OverlayConfig = app.state.overlay - required = overlay.required_inputs - missing = [stage for stage in required if stage not in app.state.artifacts] + @app.post("/pipeline/run") + async def run_pipeline( + request: Request, + authorization: Optional[str] = Header(None, alias="Authorization"), + api_key: Optional[str] = Depends(api_key_header), + tenant_id: Optional[str] = Header(None, alias=tenant_header_name), + ) -> Dict[str, Any]: + auth_context = await _authorise_request( + "pipeline", api_key=api_key, tenant_id=tenant_id, authorization=authorization + ) + request.state.tenant_id = auth_context["tenant_id"] + + overlay_config: OverlayConfig = app.state.overlay + required = overlay_config.required_inputs + tenant_artifacts = app.state.artifacts.get(auth_context["tenant_id"], {}) + missing = [stage for stage in required if stage not in tenant_artifacts] if missing: raise HTTPException( status_code=400, detail={"message": "Missing required artefacts", "missing": missing}, ) - if overlay.toggles.get("enforce_ticket_sync") and not overlay.jira.get("project_key"): + if overlay_config.toggles.get("enforce_ticket_sync") and not overlay_config.jira.get("project_key"): raise HTTPException( status_code=500, 
detail={ "message": "Ticket synchronisation enforced but Jira project_key missing", - "integration": overlay.jira, + "integration": overlay_config.jira, }, ) result = orchestrator.run( - design_dataset=app.state.artifacts.get("design", {"columns": [], "rows": []}), - sbom=app.state.artifacts["sbom"], - sarif=app.state.artifacts["sarif"], - cve=app.state.artifacts["cve"], - overlay=overlay, + design_dataset=tenant_artifacts.get("design", {"columns": [], "rows": []}), + sbom=tenant_artifacts["sbom"], + sarif=tenant_artifacts["sarif"], + cve=tenant_artifacts["cve"], + overlay=overlay_config, ) - if app.state.archive_records: - result["artifact_archive"] = ArtefactArchive.summarise(app.state.archive_records) - app.state.archive_records = {} - if overlay.toggles.get("auto_attach_overlay_metadata", True): - result["overlay"] = overlay.to_sanitised_dict() + tenant_records = app.state.archive_records.get(auth_context["tenant_id"], {}) + if tenant_records: + result["artifact_archive"] = ArtefactArchive.summarise(tenant_records) + app.state.archive_records[auth_context["tenant_id"]] = {} + if overlay_config.toggles.get("auto_attach_overlay_metadata", True): + result["overlay"] = overlay_config.to_sanitised_dict() result["overlay"]["required_inputs"] = list(required) return result - @app.post("/feedback", dependencies=[Depends(_verify_api_key)]) - async def submit_feedback(payload: Dict[str, Any]) -> Dict[str, Any]: + @app.post("/feedback") + async def submit_feedback( + payload: Dict[str, Any], + authorization: Optional[str] = Header(None, alias="Authorization"), + api_key: Optional[str] = Depends(api_key_header), + tenant_id: Optional[str] = Header(None, alias=tenant_header_name), + ) -> Dict[str, Any]: + await _authorise_request( + "feedback", api_key=api_key, tenant_id=tenant_id, authorization=authorization + ) recorder: Optional[FeedbackRecorder] = app.state.feedback if recorder is None: raise HTTPException(status_code=400, detail="Feedback capture disabled in this 
profile") diff --git a/fixops/configuration.py b/fixops/configuration.py index a10f985cd..2ad759488 100644 --- a/fixops/configuration.py +++ b/fixops/configuration.py @@ -55,6 +55,28 @@ def _deep_merge(base: MutableMapping[str, Any], overrides: Mapping[str, Any]) -> return base +def _normalise_string_sequence(value: Any) -> list[str]: + """Normalise a value representing a sequence of strings.""" + + if value is None: + return [] + if isinstance(value, str): + candidates = [candidate for candidate in value.replace(",", " ").split() if candidate] + return list(dict.fromkeys(candidates)) + if isinstance(value, Iterable): + normalised: list[str] = [] + for entry in value: + if isinstance(entry, str): + candidate = entry.strip() + if candidate: + normalised.append(candidate) + elif entry is not None: + normalised.append(str(entry)) + # Remove duplicates while preserving order + return list(dict.fromkeys(normalised)) + return [str(value)] + + _DEFAULT_GUARDRAIL_MATURITY = "scaling" _DEFAULT_GUARDRAIL_PROFILES: Dict[str, Dict[str, str]] = { "foundational": {"fail_on": "critical", "warn_on": "high"}, @@ -503,22 +525,151 @@ def tenancy_settings(self) -> Dict[str, Any]: profile_overrides = dict(profile) tenants: list[Dict[str, Any]] = [] + defaults = settings.get("defaults") if isinstance(settings.get("defaults"), Mapping) else {} + default_identity: Dict[str, Any] = {} + if isinstance(defaults, Mapping): + identity_defaults = defaults.get("identity") + if isinstance(identity_defaults, Mapping): + default_identity = dict(identity_defaults) + def _extend(raw: Any) -> None: if isinstance(raw, Iterable): for entry in raw: if isinstance(entry, Mapping): - tenants.append(dict(entry)) + payload = {k: v for k, v in entry.items() if k != "identity"} + identity_payload: Dict[str, Any] = {} + if default_identity: + identity_payload.update(default_identity) + identity_overrides = entry.get("identity") + if isinstance(identity_overrides, Mapping): + 
identity_payload.update(dict(identity_overrides)) + if identity_payload: + audiences = identity_payload.get("allowed_audiences") + identity_payload["allowed_audiences"] = _normalise_string_sequence( + audiences + ) + roles = identity_payload.get("roles") or identity_payload.get("rbac") + normalised_roles: Dict[str, list[str]] = {} + if isinstance(roles, Mapping): + for action, value in roles.items(): + normalised_roles[str(action)] = _normalise_string_sequence(value) + else: + role_candidates = _normalise_string_sequence(roles) + if role_candidates: + normalised_roles["default"] = role_candidates + if normalised_roles: + identity_payload["roles"] = normalised_roles + else: + identity_payload.pop("roles", None) + if identity_payload: + payload["identity"] = identity_payload + tenants.append(payload) _extend(settings.get("tenants")) - _extend(profile_overrides.pop("tenants", None)) + tenant_overrides = profile_overrides.pop("tenants", None) + _extend(tenant_overrides) merged = dict(settings) merged.pop("tenants", None) merged.pop("profiles", None) + identity_override = profile_overrides.pop("identity_providers", None) + directory_override = profile_overrides.pop("directories", None) merged = dict(_deep_merge(merged, profile_overrides)) + + def _merge_directories(raw: Any) -> Dict[str, str]: + directories: Dict[str, str] = {} + if isinstance(raw, Mapping): + for tenant_id, path in raw.items(): + if isinstance(path, str) and path.strip(): + directories[str(tenant_id)] = path + return directories + + directories = _merge_directories(settings.get("directories")) + directories.update(_merge_directories(directory_override)) + + def _merge_identity_providers(raw: Any) -> Dict[str, Dict[str, Any]]: + providers: Dict[str, Dict[str, Any]] = {} + if isinstance(raw, Mapping): + items = raw.items() + elif isinstance(raw, Iterable): + items = [] + for entry in raw: + if isinstance(entry, Mapping): + identifier = entry.get("id") + if identifier: + items.append((identifier, entry)) 
+ else: + return providers + + for identifier, payload in items: # type: ignore[assignment] + if not isinstance(payload, Mapping): + continue + provider_payload = {k: v for k, v in payload.items() if k != "id"} + audiences = provider_payload.get("allowed_audiences") or provider_payload.get("audiences") + provider_payload["allowed_audiences"] = _normalise_string_sequence(audiences) + providers[str(identifier)] = provider_payload + return providers + + identity_providers = _merge_identity_providers(settings.get("identity_providers")) + identity_providers.update(_merge_identity_providers(identity_override)) + + merged["directories"] = directories + merged["identity_providers"] = identity_providers merged["tenants"] = tenants return merged + def get_tenant(self, tenant_id: str) -> Dict[str, Any]: + settings = self.tenancy_settings + tenants = settings.get("tenants", []) + if isinstance(tenants, Iterable): + for entry in tenants: + if isinstance(entry, Mapping) and str(entry.get("id")) == tenant_id: + return dict(entry) + return {} + + @property + def tenant_identity_providers(self) -> Dict[str, Dict[str, Any]]: + settings = self.tenancy_settings + providers = settings.get("identity_providers", {}) + if not isinstance(providers, Mapping): + return {} + return {str(identifier): dict(payload) for identifier, payload in providers.items() if isinstance(payload, Mapping)} + + def tenant_archive_directory(self, tenant_id: str) -> Path: + settings = self.tenancy_settings + directories = settings.get("directories") if isinstance(settings.get("directories"), Mapping) else {} + allowlist = self.allowed_data_roots or (_DEFAULT_DATA_ROOT,) + default_root = allowlist[0] + + target: Optional[str] = None + if isinstance(directories, Mapping): + candidate = directories.get(tenant_id) + if isinstance(candidate, str) and candidate.strip(): + target = candidate + else: + default_candidate = directories.get("default") + if isinstance(default_candidate, str) and default_candidate.strip(): + 
target = default_candidate + + if target: + path = Path(target).expanduser() + if not path.is_absolute(): + path = (default_root / path).resolve() + else: + fallback = default_root / "tenants" / tenant_id / "archive" + path = fallback.resolve() + + # If the overlay explicitly configured a data archive directory use it for the default tenant. + if tenant_id == "default" and not target: + archive_dir = self.data.get("archive_dir") if isinstance(self.data, Mapping) else None + if isinstance(archive_dir, str) and archive_dir.strip(): + candidate = Path(archive_dir).expanduser() + if not candidate.is_absolute(): + candidate = (default_root / candidate).resolve() + path = candidate + + return _ensure_within_allowlist(path, allowlist) + def module_config(self, name: str) -> Dict[str, Any]: raw = self.modules.get(name) if isinstance(raw, Mapping): diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index cc96da314..b7a661f29 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -1,11 +1,16 @@ import json +import base64 import csv import gzip +import hashlib +import hmac import json import os +import time import zipfile from io import BytesIO, StringIO +from pathlib import Path try: from fastapi.testclient import TestClient # type: ignore @@ -20,6 +25,21 @@ from backend.pipeline import PipelineOrchestrator +def _make_hs256_token(secret: bytes, payload: dict, *, kid: str = "demo") -> str: + header = {"alg": "HS256", "typ": "JWT", "kid": kid} + segments = [] + for part in (header, payload): + encoded = base64.urlsafe_b64encode( + json.dumps(part, separators=(",", ":")).encode("utf-8") + ).rstrip(b"=") + segments.append(encoded.decode("ascii")) + signing_input = ".".join(segments).encode("ascii") + signature = hmac.new(secret, signing_input, hashlib.sha256).digest() + signature_segment = base64.urlsafe_b64encode(signature).rstrip(b"=").decode("ascii") + segments.append(signature_segment) + return ".".join(segments) + + def 
def test_oidc_rbac_enforced(monkeypatch, tmp_path):
    """End-to-end check that OIDC bearer auth and per-role RBAC gate the API.

    Exercises, in order: rejection of a wrong-audience token (403), rejection
    of a missing token (401), acceptance of upload-role tokens on the four
    input endpoints (200), rejection of pipeline execution without the
    pipeline role (403), and success plus tenant-scoped archiving once the
    token carries both roles.
    """
    # Environment may lack fastapi/test client; these tests silently no-op then
    # (same guard pattern as the other tests in this module).
    if TestClient is None or create_app is None:
        return

    # HS256 shared secret, also published through the overlay's JWKS as an
    # "oct" key so the app can verify tokens minted by _make_hs256_token.
    secret = b"super-secret-key"
    jwk_secret = base64.urlsafe_b64encode(secret).decode("ascii").rstrip("=")
    # Overlay wiring: one identity provider ("demo-idp") and one tenant
    # ("tenant-one") whose RBAC maps the upload/pipeline actions to scopes.
    overlay_payload = {
        "mode": "demo",
        "auth": {"strategy": "oidc", "tenant_header": "X-FixOps-Tenant"},
        "data": {"archive_dir": str(tmp_path / "archive_default")},
        "tenancy": {
            "defaults": {
                "identity": {
                    "allowed_audiences": ["fixops-api"],
                    "roles": {
                        "upload": ["fixops:upload"],
                        "pipeline": ["fixops:pipeline"],
                    },
                },
                "identity_provider": "demo-idp",
            },
            "identity_providers": {
                "demo-idp": {
                    "issuer": "https://idp.example.com",
                    "allowed_audiences": ["fixops-api"],
                    # Inline JWKS: symmetric ("oct") key matching `secret`.
                    "jwks": {
                        "keys": [
                            {
                                "kty": "oct",
                                "k": jwk_secret,
                                "kid": "demo-key",
                                "alg": "HS256",
                            }
                        ]
                    },
                }
            },
            "tenants": [
                {
                    "id": "tenant-one",
                    "name": "Tenant One",
                    "identity": {
                        "provider": "demo-idp",
                        "allowed_audiences": ["fixops-api"],
                        "roles": {
                            "upload": ["fixops:upload"],
                            "pipeline": ["fixops:pipeline"],
                        },
                    },
                }
            ],
            # Tenant-specific archive directory; verified at the end via the
            # pipeline response's artifact paths.
            "directories": {
                "tenant-one": str(tmp_path / "tenants" / "tenant-one" / "archive")
            },
        },
    }

    overlay_path = tmp_path / "overlay.json"
    overlay_path.write_text(json.dumps(overlay_payload), encoding="utf-8")

    # Point the app at the overlay and allow-list tmp_path for archive writes.
    monkeypatch.setenv("FIXOPS_OVERLAY_PATH", str(overlay_path))
    monkeypatch.setenv("FIXOPS_DATA_ROOT_ALLOWLIST", str(tmp_path))

    app = create_app()
    client = TestClient(app)

    # Minimal-but-valid fixtures for each of the four input endpoints.
    design_csv = "component,owner\nsvc,team\n"
    sbom_document = {
        "bomFormat": "CycloneDX",
        "specVersion": "1.4",
        "version": 1,
        "components": [
            {
                "type": "library",
                "name": "svc",
                "version": "1.0.0",
                "purl": "pkg:pypi/svc@1.0.0",
            }
        ],
    }
    cve_feed = {"vulnerabilities": [{"cveID": "CVE-2024-0001", "severity": "high"}]}
    sarif_document = {
        "version": "2.1.0",
        "$schema": "https://json.schemastore.org/sarif-2.1.0.json",
        "runs": [
            {
                "tool": {"driver": {"name": "DemoScanner"}},
                "results": [
                    {
                        "ruleId": "DEMO001",
                        "message": {"text": "Issue"},
                        "level": "error",
                    }
                ],
            }
        ],
    }

    # Token signed with the right key but the wrong audience must be rejected.
    wrong_audience_token = _make_hs256_token(
        secret,
        {
            "iss": "https://idp.example.com",
            "sub": "user@example.com",
            "aud": "other-api",
            "exp": int(time.time()) + 300,
            "roles": ["fixops:upload"],
        },
        kid="demo-key",
    )

    response = client.post(
        "/inputs/design",
        headers={
            "X-FixOps-Tenant": "tenant-one",
            "Authorization": f"Bearer {wrong_audience_token}",
        },
        files={"file": ("design.csv", design_csv, "text/csv")},
    )
    # Audience mismatch -> authenticated-but-forbidden.
    assert response.status_code == 403

    response = client.post(
        "/inputs/design",
        headers={"X-FixOps-Tenant": "tenant-one"},
        files={"file": ("design.csv", design_csv, "text/csv")},
    )
    # No Authorization header at all -> unauthenticated.
    assert response.status_code == 401

    # Valid token carrying only the upload role.
    upload_token = _make_hs256_token(
        secret,
        {
            "iss": "https://idp.example.com",
            "sub": "user@example.com",
            "aud": "fixops-api",
            "exp": int(time.time()) + 300,
            "roles": ["fixops:upload"],
        },
        kid="demo-key",
    )

    headers = {
        "X-FixOps-Tenant": "tenant-one",
        "Authorization": f"Bearer {upload_token}",
    }

    # All four ingest endpoints accept the upload-scoped token.
    response = client.post(
        "/inputs/design",
        headers=headers,
        files={"file": ("design.csv", design_csv, "text/csv")},
    )
    assert response.status_code == 200

    response = client.post(
        "/inputs/sbom",
        headers=headers,
        files={
            "file": (
                "sbom.json",
                json.dumps(sbom_document),
                "application/json",
            )
        },
    )
    assert response.status_code == 200

    response = client.post(
        "/inputs/cve",
        headers=headers,
        files={
            "file": (
                "cve.json",
                json.dumps(cve_feed),
                "application/json",
            )
        },
    )
    assert response.status_code == 200

    response = client.post(
        "/inputs/sarif",
        headers=headers,
        files={
            "file": (
                "scan.sarif",
                json.dumps(sarif_document),
                "application/json",
            )
        },
    )
    assert response.status_code == 200

    # Upload role alone must NOT be enough to run the pipeline.
    response = client.post(
        "/pipeline/run",
        headers=headers,
    )
    assert response.status_code == 403

    # Token carrying both roles can run the pipeline.
    full_token = _make_hs256_token(
        secret,
        {
            "iss": "https://idp.example.com",
            "sub": "user@example.com",
            "aud": "fixops-api",
            "exp": int(time.time()) + 300,
            "roles": ["fixops:upload", "fixops:pipeline"],
        },
        kid="demo-key",
    )

    response = client.post(
        "/pipeline/run",
        headers={
            "X-FixOps-Tenant": "tenant-one",
            "Authorization": f"Bearer {full_token}",
        },
    )
    assert response.status_code == 200
    pipeline_payload = response.json()
    # The run must have archived the normalized SBOM under the tenant's own
    # directory (tenant isolation of artifacts).
    archive_info = pipeline_payload.get("artifact_archive")
    assert archive_info and "sbom" in archive_info
    sbom_path = Path(archive_info["sbom"]["normalized_path"])
    assert sbom_path.exists()
    assert "tenant-one" in sbom_path.parts