Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions api/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ def __init__(self, detail: str = "Resource already exists"):
super().__init__(status_code=status.HTTP_409_CONFLICT, detail=detail)


class ConflictException(HTTPException):
"""Base exception for conflict errors."""

def __init__(self, detail: str = "Conflict"):
super().__init__(status_code=status.HTTP_409_CONFLICT, detail=detail)


class UnauthorizedException(HTTPException):
"""Base exception for unauthorized access errors."""

Expand Down
77 changes: 55 additions & 22 deletions api/core/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
from api.core.database import get_osm_session, get_task_session
from api.core.jwt import validate_and_decode_token
from api.core.logging import get_logger
from api.src.workspaces.schemas import WorkspaceUserRoleType
from api.src.users.schemas import WorkspaceUserRoleType

# Set up logger for this module
logger = get_logger(__name__)

# TTL cache for token validation (1 hour TTL, max 1000 entries)
_token_cache: cachetools.TTLCache[str, "UserInfo"] = cachetools.TTLCache(
# TTL cache keyed by a user's OIDC subject. Evict entries when roles change. We
# still validate the JWT signature and expiry on every request before reading a
# cached record.
_user_info_cache: cachetools.TTLCache[UUID, "UserInfo"] = cachetools.TTLCache(
maxsize=1000, ttl=60 * 60
)

Expand All @@ -41,6 +43,17 @@ async def close_tdei_client() -> None:
_tdei_client = None


def evict_user_from_cache(auth_uid: UUID) -> None:
"""
Evict a user's cached UserInfo object so that their next request re-fetches
permissions.

Call this after modifying a user's roles in the OSM DB to ensure the change
takes effect on their next request rather than after the cache TTL expires.
"""
_user_info_cache.pop(auth_uid, None)


security = HTTPBearer()


Expand Down Expand Up @@ -72,6 +85,7 @@ class UserInfo:
credentials: str
user_uuid: UUID
user_name: str
token_jti: str # JWT ID used to detect token rotation on cache hits

# workspaceId, role from OSM DB
osmWorkspaceRoles: dict[int, list[WorkspaceUserRoleType]]
Expand Down Expand Up @@ -125,6 +139,13 @@ def isWorkspaceContributor(self, workspaceId: int) -> bool:
return True
return False

def effective_role(self, workspaceId: int) -> WorkspaceUserRoleType:
if self.isWorkspaceLead(workspaceId):
return WorkspaceUserRoleType.LEAD
if self.isWorkspaceValidator(workspaceId):
return WorkspaceUserRoleType.VALIDATOR
return WorkspaceUserRoleType.CONTRIBUTOR


# can't use the ORM here since the ORM uses us! (circular dependency)
def get_osm_db_session(
Expand All @@ -144,9 +165,13 @@ async def validate_token(
osm_db_session: AsyncSession = Depends(get_osm_db_session),
task_db_session: AsyncSession = Depends(get_task_db_session),
) -> UserInfo:
"""Dependency to get current authenticated user from TDEI/KeyCloak token and APIs.
"""
Dependency that gets the current authenticated user from the TDEI/KeyCloak
access token and fetches permissions from TDEI APIs.

Results are cached by token for 1 hour to avoid repeated validation calls.
We validate the JWT's signature and expiry on every request. The expensive
TDEI API and DB lookups are cached for 1 hour and should be evicted when a
user's role changes via evict_user_from_cache().
"""
token = credentials.credentials

Expand All @@ -161,27 +186,39 @@ async def validate_token(
except Exception:
raise credentials_exception

user_id: str | None = payload.get("sub")
if user_id is None:
user_id_str: str | None = payload.get("sub")
if user_id_str is None:
raise credentials_exception

# Check cache first
if token in _token_cache:
logger.info("Token validation cache hit")
return _token_cache[token]
try:
user_uuid = UUID(user_id_str)
except ValueError:
raise credentials_exception from None

# Cache miss - perform full validation
# Cache keyed by user UUID. If the token rotated (new "jti") since we
# created the cache entry, evict it so we fetch fresh claims:
#
if user_uuid in _user_info_cache:
cached = _user_info_cache[user_uuid]
current_jti = payload.get("jti", "")
if cached.token_jti == current_jti:
logger.info("Token validation cache hit")
return cached
logger.info("Token validation cache miss: token rotated")
del _user_info_cache[user_uuid]

# Cache miss: fetch TDEI roles and DB data:
user_info = await _validate_token_uncached(
token, user_id, payload, osm_db_session, task_db_session
token, user_uuid, payload, osm_db_session, task_db_session
)
_token_cache[token] = user_info
_user_info_cache[user_uuid] = user_info

return user_info


async def _validate_token_uncached(
token: str,
user_id: str,
user_uuid: UUID,
payload: dict,
osm_db_session: AsyncSession,
task_db_session: AsyncSession,
Expand All @@ -200,21 +237,17 @@ async def _validate_token_uncached(
}

r = UserInfo()

try:
r.user_uuid = UUID(user_id)
except ValueError:
raise credentials_exception from None

r.user_uuid = user_uuid
r.credentials = token
r.token_jti = payload.get("jti", "")
r.user_name = payload.get("preferred_username", "unknown")

# get user's project groups and roles from TDEI
pgs = []

try:
response = await _tdei_client.get(
f"project-group-roles/{user_id}",
f"project-group-roles/{user_uuid}",
headers=headers,
params={"page_no": 1, "page_size": 1000},
)
Expand Down
2 changes: 2 additions & 0 deletions api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
validate_token,
)
from api.src.teams.routes import router as teams_router
from api.src.users.routes import router as users_router
from api.src.workspaces.repository import WorkspaceRepository
from api.src.workspaces.routes import router as workspaces_router
from api.utils.migrations import run_migrations
Expand Down Expand Up @@ -85,6 +86,7 @@ async def lifespan(_app: FastAPI):

# Include routers
app.include_router(teams_router, prefix="/api/v1")
app.include_router(users_router, prefix="/api/v1")
app.include_router(workspaces_router, prefix="/api/v1")


Expand Down
2 changes: 1 addition & 1 deletion api/src/teams/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
WorkspaceTeamItem,
WorkspaceTeamUpdate,
)
from api.src.workspaces.schemas import User
from api.src.users.schemas import User


class WorkspaceTeamRepository:
Expand Down
47 changes: 39 additions & 8 deletions api/src/teams/routes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, status
from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession

from api.core.database import get_osm_session, get_task_session
Expand All @@ -9,8 +9,9 @@
WorkspaceTeamItem,
WorkspaceTeamUpdate,
)
from api.src.workspaces.repository import OSMRepository, WorkspaceRepository
from api.src.workspaces.schemas import User
from api.src.users.repository import UserRepository
from api.src.users.schemas import User
from api.src.workspaces.repository import WorkspaceRepository

router = APIRouter(prefix="/workspaces/{workspace_id}/teams", tags=["teams"])

Expand All @@ -22,10 +23,10 @@ def get_workspace_repo(
return repo


def get_osm_repo(
def get_user_repo(
session: AsyncSession = Depends(get_osm_session),
) -> OSMRepository:
repository = OSMRepository(session)
) -> UserRepository:
repository = UserRepository(session)
return repository


Expand Down Expand Up @@ -56,6 +57,12 @@ async def create_team_for_workspace(
team_repo=Depends(get_team_repo),
current_user: UserInfo = Depends(validate_token),
) -> int:
if not current_user.isWorkspaceLead(workspace_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only workspace leads can create teams",
)

# Repo guards if workspace doesn't exist or user cannot access:
await workspace_repo.getById(current_user, workspace_id)
return await team_repo.create(workspace_id, team)
Expand Down Expand Up @@ -84,6 +91,12 @@ async def update_team_for_workspace(
team_repo=Depends(get_team_repo),
current_user: UserInfo = Depends(validate_token),
):
if not current_user.isWorkspaceLead(workspace_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only workspace leads can update teams",
)

# Repo guards if workspace doesn't exist or user cannot access:
await workspace_repo.getById(current_user, workspace_id)
await team_repo.assert_team_in_workspace(team_id, workspace_id)
Expand All @@ -98,6 +111,12 @@ async def delete_team_from_workspace(
team_repo=Depends(get_team_repo),
current_user: UserInfo = Depends(validate_token),
):
if not current_user.isWorkspaceLead(workspace_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only workspace leads can delete teams",
)

# Repo guards if workspace doesn't exist or user cannot access:
await workspace_repo.getById(current_user, workspace_id)
await team_repo.assert_team_in_workspace(team_id, workspace_id)
Expand All @@ -123,14 +142,14 @@ async def join_workspace_team(
workspace_id: int,
team_id: int,
workspace_repo=Depends(get_workspace_repo),
osm_repo=Depends(get_osm_repo),
user_repo=Depends(get_user_repo),
team_repo=Depends(get_team_repo),
current_user: UserInfo = Depends(validate_token),
) -> User:
# Repo guards if workspace doesn't exist or user cannot access:
await workspace_repo.getById(current_user, workspace_id)
await team_repo.assert_team_in_workspace(team_id, workspace_id)
user = await osm_repo.get_current_user(current_user)
user = await user_repo.get_current_user(current_user)
await team_repo.add_member(team_id, user.id)
return user

Expand All @@ -144,6 +163,12 @@ async def add_member_to_workspace_team(
team_repo=Depends(get_team_repo),
current_user: UserInfo = Depends(validate_token),
):
if not current_user.isWorkspaceLead(workspace_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only workspace leads can add team members",
)

# Repo guards if workspace doesn't exist or user cannot access:
await workspace_repo.getById(current_user, workspace_id)
await team_repo.assert_team_in_workspace(team_id, workspace_id)
Expand All @@ -159,6 +184,12 @@ async def delete_member_from_workspace_team(
team_repo=Depends(get_team_repo),
current_user: UserInfo = Depends(validate_token),
):
if not current_user.isWorkspaceLead(workspace_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only workspace leads can remove team members",
)

# Repo guards if workspace doesn't exist or user cannot access:
await workspace_repo.getById(current_user, workspace_id)
await team_repo.assert_team_in_workspace(team_id, workspace_id)
Expand Down
Loading