Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions backend/routes/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
Authentication helpers for Supabase-backed routes.
"""

from fastapi import HTTPException, Request, status
from pydantic import BaseModel

from routes.supabase_client import get_supabase


class AuthenticatedUser(BaseModel):
id: str
email: str | None = None


def _get_attr_or_item(value, key: str):
if isinstance(value, dict):
return value.get(key)
return getattr(value, key, None)


def _extract_bearer_token(request: Request) -> str:
auth_header = request.headers.get("authorization", "")
scheme, _, token = auth_header.partition(" ")
if scheme.lower() != "bearer" or not token.strip():
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing bearer token",
headers={"WWW-Authenticate": "Bearer"},
)
return token.strip()


def require_user(request: Request) -> AuthenticatedUser:
token = _extract_bearer_token(request)
try:
response = get_supabase().auth.get_user(token)
except RuntimeError as exc:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Authentication service is not configured",
) from exc
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid bearer token",
headers={"WWW-Authenticate": "Bearer"},
) from exc

user = _get_attr_or_item(response, "user")
user_id = _get_attr_or_item(user, "id")
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid bearer token",
headers={"WWW-Authenticate": "Bearer"},
)

return AuthenticatedUser(
id=user_id,
email=_get_attr_or_item(user, "email"),
)
60 changes: 23 additions & 37 deletions backend/routes/billing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from datetime import datetime, timezone
from dateutil.relativedelta import relativedelta
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
from supabase import create_client, Client
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, Field

from routes.auth import AuthenticatedUser, require_user
from routes.supabase_client import get_supabase

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -40,24 +42,6 @@

DEFAULT_BILLING_MODEL = os.environ.get("ANTHROPIC_DEFAULT_MODEL", "claude-haiku-4-5")

# ---------------------------------------------------------------------------
# Supabase client (service role — bypasses RLS)
# ---------------------------------------------------------------------------

_supabase: Client | None = None


def get_supabase() -> Client:
global _supabase
if _supabase is None:
url = os.environ.get("SUPABASE_URL", "")
key = os.environ.get("SUPABASE_SERVICE_ROLE_KEY", "")
if not url or not key:
raise RuntimeError("SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY must be set")
_supabase = create_client(url, key)
return _supabase


# ---------------------------------------------------------------------------
# Cost calculation
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -151,7 +135,7 @@ def get_balance_summary(user_id: str) -> dict[str, float]:

def record_usage(
*,
user_id: str | None,
user_id: str,
session_id: str,
model: str | None,
input_tokens: int,
Expand All @@ -160,6 +144,9 @@ def record_usage(
cache_read_input_tokens: int = 0,
) -> dict[str, Any]:
"""Record token usage and deduct cost from user credits."""
if not user_id:
raise ValueError("user_id is required to record billable usage")

cost = calculate_cost_gbp(
model=model,
input_tokens=input_tokens,
Expand All @@ -184,13 +171,6 @@ def record_usage(
except Exception as e:
logger.warning(f"Could not insert token_usage for {user_id}: {e}")

if not user_id:
return {
"cost_gbp": cost,
"model": _normalise_model_name(model),
"balance": None,
}

# Deduct from free tier first, then balance
credits = get_or_create_credits(user_id)
free_remaining = max(0, FREE_TIER_GBP - float(credits["free_tier_used_gbp"]))
Expand Down Expand Up @@ -219,20 +199,23 @@ def record_usage(
# ---------------------------------------------------------------------------

@router.get("/balance")
def get_balance(user_id: str):
return get_balance_summary(user_id)
def get_balance(current_user: AuthenticatedUser = Depends(require_user)):
return get_balance_summary(current_user.id)


@router.get("/usage")
def get_usage(user_id: str, limit: int = 50):
def get_usage(
limit: int = 50,
current_user: AuthenticatedUser = Depends(require_user),
):
sb = get_supabase()
result = sb.table("token_usage").select("*").eq("user_id", user_id).order("created_at", desc=True).limit(limit).execute()
bounded_limit = max(1, min(limit, 100))
result = sb.table("token_usage").select("*").eq("user_id", current_user.id).order("created_at", desc=True).limit(bounded_limit).execute()
return result.data


class CheckoutRequest(BaseModel):
user_id: str
amount_gbp: float = 5.0
amount_gbp: float = Field(default=5.0, gt=0, le=500)


def _get_public_base_url() -> str:
Expand All @@ -243,7 +226,10 @@ def _get_public_base_url() -> str:


@router.post("/checkout")
def create_checkout(request: CheckoutRequest):
def create_checkout(
request: CheckoutRequest,
current_user: AuthenticatedUser = Depends(require_user),
):
import stripe

stripe.api_key = os.environ.get("STRIPE_SECRET_KEY", "")
Expand All @@ -264,7 +250,7 @@ def create_checkout(request: CheckoutRequest):
},
"quantity": 1,
}],
metadata={"user_id": request.user_id, "amount_gbp": str(request.amount_gbp)},
metadata={"user_id": current_user.id, "amount_gbp": str(request.amount_gbp)},
success_url=f"{base_url}?topup=success",
cancel_url=f"{base_url}?topup=cancel",
)
Expand Down
32 changes: 16 additions & 16 deletions backend/routes/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@

import httpx

from fastapi import APIRouter, Request
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.settings import ModelSettings

from agent_tools import execute_tool, TOOL_DEFINITIONS
from routes.auth import AuthenticatedUser, require_user
from routes.billing import check_balance, record_usage

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -181,7 +183,6 @@ class ChatMessage(BaseModel):
class ChatRequest(BaseModel):
messages: List[ChatMessage]
session_id: str | None = None
user_id: str | None = None


class TitleRequest(BaseModel):
Expand All @@ -194,7 +195,10 @@ class TitleRequest(BaseModel):
# ---------------------------------------------------------------------------

@router.post("/title")
def generate_title(request: TitleRequest):
def generate_title(
request: TitleRequest,
current_user: AuthenticatedUser = Depends(require_user),
):
client = _get_sync_anthropic_client()
content = request.first_user_message
if request.first_assistant_message:
Expand All @@ -219,17 +223,15 @@ def generate_title(request: TitleRequest):
# ---------------------------------------------------------------------------

@router.post("/message")
async def chat_message(request: ChatRequest, http_request: Request):
# Check billing balance if user is authenticated
user_id = request.user_id
if user_id:
try:
from routes.billing import check_balance
has_credit, _ = check_balance(user_id)
if not has_credit:
return JSONResponse(status_code=402, content={"error": "No credit remaining. Please top up to continue."})
except RuntimeError:
pass # Supabase not configured — skip billing check
async def chat_message(
request: ChatRequest,
http_request: Request,
current_user: AuthenticatedUser = Depends(require_user),
):
user_id = current_user.id
has_credit, _ = check_balance(user_id)
if not has_credit:
return JSONResponse(status_code=402, content={"error": "No credit remaining. Please top up to continue."})

session_id = request.session_id or str(uuid.uuid4())

Expand Down Expand Up @@ -341,7 +343,6 @@ async def generate_stream():
# Record token usage for billing
billing = None
try:
from routes.billing import record_usage
billing = record_usage(
user_id=user_id,
session_id=session_id,
Expand Down Expand Up @@ -421,7 +422,6 @@ async def execute_tool_async(tu):
if iteration >= max_iterations:
billing = None
try:
from routes.billing import record_usage
billing = record_usage(
user_id=user_id,
session_id=session_id,
Expand Down
Loading
Loading