diff --git a/api/core/config.py b/api/core/config.py index 94edba6..2163857 100644 --- a/api/core/config.py +++ b/api/core/config.py @@ -1,9 +1,17 @@ from pydantic_settings import BaseSettings, SettingsConfigDict + + class Settings(BaseSettings): """Application settings.""" PROJECT_NAME: str = "Workspaces API" + # JSON array of allowed CORS origins. For example: + # + # ["https://workspaces.example.com", "https://leaderboard.example.com"] + # + CORS_ORIGINS: list[str] = [] + TASK_DATABASE_URL: str = "postgresql+asyncpg://user:pass@localhost:5432/tasking_manager" OSM_DATABASE_URL: str = "postgresql+asyncpg://user:pass@localhost:5432/tasking_manager" @@ -18,8 +26,8 @@ class Settings(BaseSettings): "https://raw.githubusercontent.com/TaskarCenterAtUW/asr-quests/refs/heads/main/schema/schema.json" ) - # proxy destination--"osm-rails" is a virtual docker network endpoint - WS_OSM_HOST: str = "http://osm-rails:3000" + # proxy destination--"osm-web" is a virtual docker network endpoint + WS_OSM_HOST: str = "http://osm-web" #WS_OSM_HOST: str = "https://osm.workspaces-dev.sidewalks.washington.edu" SENTRY_DSN: str = "" diff --git a/api/core/jwt.py b/api/core/jwt.py new file mode 100644 index 0000000..46e04c8 --- /dev/null +++ b/api/core/jwt.py @@ -0,0 +1,33 @@ +import jwt + +from api.core.config import settings + +# Singleton JWKS client reused to take advantage of internal cert/key caching: +_jwks_client: jwt.PyJWKClient | None = None + + +def _get_jwks_client() -> jwt.PyJWKClient: + global _jwks_client + + if _jwks_client is None: + _jwks_client = jwt.PyJWKClient( + f"{settings.TDEI_OIDC_URL.rstrip("/")}/realms/" + f"{settings.TDEI_OIDC_REALM}/protocol/openid-connect/certs" + ) + + return _jwks_client + + +def validate_and_decode_token(token: str) -> dict: + # TODO: use an async client like pyjwt-key-fetcher + signing_key = _get_jwks_client().get_signing_key_from_jwt(token) + + decoded = jwt.decode_complete( + token, + key=signing_key.key, + algorithms=["RS256"], + # OIDC server does not currently differentiate tokens by audience + options={"verify_aud": False}, + ) + + return decoded.get("payload", {}) diff --git a/api/core/security.py b/api/core/security.py index cc16b30..3e60de1 100644 --- a/api/core/security.py +++ b/api/core/security.py @@ -1,10 +1,8 @@ -import json from enum import StrEnum from uuid import UUID import cachetools -import jwt -import requests +import httpx from fastapi import Depends, HTTPException, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from sqlalchemy import text @@ -12,6 +10,7 @@ from api.core.config import settings from api.core.database import get_osm_session, get_task_session +from api.core.jwt import validate_and_decode_token from api.core.logging import get_logger from api.src.workspaces.schemas import WorkspaceUserRoleType @@ -23,6 +22,25 @@ maxsize=1000, ttl=60 * 60 ) +# Shared HTTP client for TDEI backend calls. Initialized by main.py lifespan. +_tdei_client: httpx.AsyncClient | None = None + + +def init_tdei_client() -> None: + global _tdei_client + _tdei_client = httpx.AsyncClient( + base_url=settings.TDEI_BACKEND_URL, + timeout=httpx.Timeout(connect=10, read=30, write=30, pool=10), + ) + + +async def close_tdei_client() -> None: + global _tdei_client + if _tdei_client is not None: + await _tdei_client.aclose() + _tdei_client = None + + security = HTTPBearer() @@ -84,7 +102,9 @@ def isWorkspaceLead(self, workspaceId: int) -> bool: for pg in self.projectGroups: if TdeiProjectGroupRole.POINT_OF_CONTACT in pg.tdeiRoles: - if workspaceId in self.accessibleWorkspaceIds[pg.project_group_id]: + if workspaceId in self.accessibleWorkspaceIds.get( + pg.project_group_id, [] + ): return True return False @@ -118,6 +138,7 @@ def get_task_db_session( ) -> AsyncSession: return session + async def validate_token( credentials: HTTPAuthorizationCredentials = Depends(security), osm_db_session: AsyncSession = Depends(get_osm_db_session), @@ -129,19 +150,39 @@ async def validate_token( """ token = credentials.credentials + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + try: + payload = validate_and_decode_token(token) + except Exception: + raise credentials_exception + + user_id: str | None = payload.get("sub") + if user_id is None: + raise credentials_exception + # Check cache first if token in _token_cache: logger.info("Token validation cache hit") return _token_cache[token] # Cache miss - perform full validation - user_info = await _validate_token_uncached(token, osm_db_session, task_db_session) + user_info = await _validate_token_uncached( + token, user_id, payload, osm_db_session, task_db_session + ) _token_cache[token] = user_info + return user_info async def _validate_token_uncached( token: str, + user_id: str, + payload: dict, osm_db_session: AsyncSession, task_db_session: AsyncSession, ) -> UserInfo: @@ -153,59 +194,46 @@ async def _validate_token_uncached( headers={"WWW-Authenticate": "Bearer"}, ) - jwks_client = jwt.PyJWKClient( - f"{settings.TDEI_OIDC_URL}realms/{settings.TDEI_OIDC_REALM}/protocol/openid-connect/certs" - ) - - signing_key = jwks_client.get_signing_key_from_jwt(token) - - jwtDecoded = jwt.decode_complete( - token, - key=signing_key.key, - algorithms=["RS256"], - # OIDC server does not currently differentiate tokens by audience - options={"verify_aud": False} - ) - payload = jwtDecoded.get("payload", {}) - - user_id: str | None = payload.get("sub") - if user_id is None: - raise credentials_exception - headers = { "Authorization": "Bearer " + token, "Content-Type": "application/json", } + r = UserInfo() + + try: + r.user_uuid = UUID(user_id) + except ValueError: + raise credentials_exception from None + + r.credentials = token + r.user_name = payload.get("preferred_username", "unknown") + # get user's project groups and roles from TDEI - # TODO: fix if user has > 50 PGs - authorizationUrl = ( - settings.TDEI_BACKEND_URL - + "/project-group-roles/" - + user_id - + "?page_no=1&page_size=50" - ) + pgs = [] - response = requests.get(authorizationUrl, headers=headers) + try: + response = await _tdei_client.get( + f"project-group-roles/{user_id}", + headers=headers, + params={"page_no": 1, "page_size": 1000}, + ) + except httpx.RequestError: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Could not reach TDEI backend", + ) from None # token is not valid or server unavailable if response.status_code != 200: raise credentials_exception try: - content = response.text - j = json.loads(content) - except json.JSONDecodeError: + pg_data = response.json() + except Exception: raise credentials_exception - r = UserInfo() - r.credentials = token - r.user_uuid = UUID(payload.get("sub", "unknown")) - r.user_name = payload.get("preferred_username", "unknown") - - # project groups and roles from TDEI KeyCloak - pgs = [] - for i in j: + for i in pg_data: pgs.append( UserInfoPGMembership( project_group_id=i["tdei_project_group_id"], @@ -213,6 +241,7 @@ async def _validate_token_uncached( tdeiRoles=i["roles"], ) ) + r.projectGroups = pgs # workspaces within our set of PGs from tasking manager DB @@ -226,7 +255,7 @@ async def _validate_token_uncached( accessibleWorkspaces = list(result.mappings().all()) r.accessibleWorkspaceIds = {} for i in accessibleWorkspaces: - pgid = i["tdeiProjectGroupId"] + pgid = str(i["tdeiProjectGroupId"]) # SQLAlchemy outputs UUID wsid = i["id"] if pgid not in r.accessibleWorkspaceIds: r.accessibleWorkspaceIds[pgid] = [] diff --git a/api/main.py b/api/main.py index 91c7f73..396b530 100644 --- a/api/main.py +++ b/api/main.py @@ -1,9 +1,11 @@ import os import re +from contextlib import asynccontextmanager import httpx import sentry_sdk from fastapi import Depends, FastAPI, HTTPException, Request, status +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse, StreamingResponse from sqlmodel.ext.asyncio.session import AsyncSession from starlette.background import BackgroundTask @@ -12,7 +14,12 @@ from api.core.config import settings from api.core.database import get_task_session from api.core.logging import get_logger, setup_logging -from api.core.security import UserInfo, validate_token +from api.core.security import ( + UserInfo, + close_tdei_client, + init_tdei_client, + validate_token, +) from api.src.teams.routes import router as teams_router from api.src.workspaces.repository import WorkspaceRepository from api.src.workspaces.routes import router as workspaces_router @@ -35,16 +42,50 @@ # Set up logger for this module logger = get_logger(__name__) +# Shared HTTP client for OSM proxy. Reuses connection pool across requests: +_osm_client: httpx.AsyncClient | None = None + + +@asynccontextmanager +async def lifespan(_app: FastAPI): + # Run before app bootstrap: + global _osm_client + _osm_client = httpx.AsyncClient( + base_url=settings.WS_OSM_HOST, + # 2 hour timeout for long-running OSM imports: + timeout=httpx.Timeout(connect=10, read=7200, write=7200, pool=10), + ) + init_tdei_client() + + yield # App runs + + # Run after app cleanup: + await _osm_client.aclose() + _osm_client = None + await close_tdei_client() + + app = FastAPI( title=settings.PROJECT_NAME, debug=settings.DEBUG, swagger_ui_parameters={"syntaxHighlight": False}, + lifespan=lifespan, +) + +app.add_middleware( + CORSMiddleware, + allow_origins=settings.CORS_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + max_age=100, ) # Include routers app.include_router(teams_router, prefix="/api/v1") app.include_router(workspaces_router, prefix="/api/v1") + @app.get("/health") async def health_check(): """Health check endpoint. Used for Docker.""" @@ -68,16 +109,88 @@ def get_workspace_repository( # h/t: https://stackoverflow.com/questions/70610266/proxy-an-external-website-using-python-fast-api-not-supporting-query-params # +# According to HTTP/1.1, a proxy must not forward these "hop-by-hop" headers: +HOP_BY_HOP_HEADERS = frozenset( + [ + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", + ] +) + +# Do not forward spoofed reverse-proxy informational headers: +STRIP_REQUEST_HEADERS = HOP_BY_HOP_HEADERS | { + "host", + "x-forwarded-for", + "x-forwarded-host", + "x-forwarded-proto", + "x-real-ip", + "forwarded", +} + # Define paths that do not require X-Workspace header -AUTH_WHITELIST_PATHS = [ - "/api/0.6/user/*", # used during authentication - "/api/0.6/workspaces/[0-9]*/bbox.json", # used to get workspace bbox without workspace header, to be removed +AUTH_WHITELIST_PATTERNS = [ + re.compile(p) + for p in [ + r"^/api/0\.6/user/.*$", # used during authentication + r"^/api/0\.6/workspaces/[0-9]+/bbox\.json$", # used to get workspace bbox without workspace header, to be removed + ] ] +@app.get("/api/capabilities.json") +async def capabilities(request: Request): + """Proxy OSM capabilities manifest without requiring authentication.""" + + client_host = request.client.host if request.client else "unknown" + req_headers = [ + (k.encode(), v.encode()) + for k, v in request.headers.items() + if k.lower() not in STRIP_REQUEST_HEADERS + ] + [ + (b"Host", _osm_client.base_url.host.encode()), + (b"X-Real-IP", client_host.encode()), + (b"X-Forwarded-For", client_host.encode()), + (b"X-Forwarded-Host", (request.url.hostname or "").encode()), + (b"X-Forwarded-Proto", request.url.scheme.encode()), + ] + + url = httpx.URL(path="/api/capabilities.json") + rp_req = _osm_client.build_request("GET", url, headers=req_headers) + + try: + rp_resp = await _osm_client.send(rp_req, stream=True) + except httpx.TimeoutException: + raise HTTPException( + status_code=status.HTTP_504_GATEWAY_TIMEOUT, + detail="Upstream OSM service timed out", + ) + except httpx.ConnectError: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Could not connect to upstream OSM service", + ) + + forwarded_headers = { + k: v for k, v in rp_resp.headers.items() if k.lower() not in HOP_BY_HOP_HEADERS + } + + return StreamingResponse( + rp_resp.aiter_raw(), + status_code=rp_resp.status_code, + headers=forwarded_headers, + background=BackgroundTask(rp_resp.aclose), + ) + + @app.api_route( "/{full_path:path}", - methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH", "TRACE"], + methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"], ) async def catch_all( request: Request, @@ -87,62 +200,77 @@ async def catch_all( """ Catch-all route to proxy requests to the OSM service. """ - authorizedWorkspace = None if request.headers.get("X-Workspace") is not None: - workspace_id = int(request.headers.get("X-Workspace") or "-1") + try: + workspace_id = int(request.headers.get("X-Workspace") or "-1") + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="X-Workspace header must be a valid integer", + ) if not current_user.isWorkspaceContributor(workspace_id): raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid authentication credentials", - headers={"WWW-Authenticate": "Bearer"}, + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have access to this workspace", ) - return else: - if not any( - re.search(pattern, request.url.path) for pattern in AUTH_WHITELIST_PATHS - ): + if not any(p.fullmatch(request.url.path) for p in AUTH_WHITELIST_PATTERNS): raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="You must set your workspace in the X-Workspace header to access OSM methods.", + status_code=status.HTTP_400_BAD_REQUEST, + detail="No X-Workspace header supplied", ) - return url = httpx.URL( path=request.url.path.strip(), query=request.url.query.encode("utf-8") ) - client = httpx.AsyncClient(base_url=settings.WS_OSM_HOST) - new_headers = list() - new_headers.append( - (bytes("Authorization", "utf-8"), request.headers.get("Authorization")) - ) - - if authorizedWorkspace is not None: - new_headers.append( - (bytes("X-Workspace", "utf-8"), bytes(str(authorizedWorkspace.id), "utf-8")) - ) - new_headers.append((bytes("Host", "utf-8"), bytes(client.base_url.host, "utf-8"))) + client = _osm_client + client_host = request.client.host if request.client else "unknown" + req_headers = [ + (k.encode(), v.encode()) + for k, v in request.headers.items() + if k.lower() not in STRIP_REQUEST_HEADERS + ] + [ + (b"Host", client.base_url.host.encode()), + (b"X-Real-IP", client_host.encode()), + (b"X-Forwarded-For", client_host.encode()), + (b"X-Forwarded-Host", (request.url.hostname or "").encode()), + (b"X-Forwarded-Proto", request.url.scheme.encode()), + ] rp_req = client.build_request( - request.method, url, headers=new_headers, content=await request.body() + request.method, url, headers=req_headers, content=request.stream() ) - - rp_resp = await client.send(rp_req, stream=True) + try: + rp_resp = await client.send(rp_req, stream=True) + except httpx.TimeoutException: + raise HTTPException( + status_code=status.HTTP_504_GATEWAY_TIMEOUT, + detail="Upstream OSM service timed out", + ) + except httpx.ConnectError: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Could not connect to upstream OSM service", + ) if rp_resp.status_code >= 400 and rp_resp.status_code < 600: - sentry_sdk.capture_message( - f"Upstream request to {rp_req.url} returned status code {rp_resp.status_code}" + msg = ( + f"Upstream request to {rp_req.url} returned " + f"status code {rp_resp.status_code}" ) + sentry_sdk.capture_message(msg) + logger.warning(msg) - logger.warning( - f"Upstream request to {rp_req.url} returned status code {rp_resp.status_code}" - ) + forwarded_headers = { + k: v for k, v in rp_resp.headers.items() if k.lower() not in HOP_BY_HOP_HEADERS + } return StreamingResponse( rp_resp.aiter_raw(), status_code=rp_resp.status_code, - headers=rp_resp.headers, + headers=forwarded_headers, background=BackgroundTask(rp_resp.aclose), ) diff --git a/api/src/workspaces/repository.py b/api/src/workspaces/repository.py index 080f103..ada15ba 100644 --- a/api/src/workspaces/repository.py +++ b/api/src/workspaces/repository.py @@ -211,8 +211,13 @@ async def getWorkspaceBBox( current_user: UserInfo, workspace_id: int, ): + # Postgres does not support parameter binding for `SET search_path`, so + # workspace_id is interpolated directly. The explicit int() cast guards + # against SQL injection if this method is ever called from outside of a + # FastAPI path handler (where the type annotation acts as a safeguard). + # await self.session.execute( - text(f"SET search_path TO 'workspace-{workspace_id}', public") + text(f"SET search_path TO 'workspace-{int(workspace_id)}', public") ) sql_query = text(