diff --git a/backend/routes/auth.py b/backend/routes/auth.py new file mode 100644 index 0000000..0d9da83 --- /dev/null +++ b/backend/routes/auth.py @@ -0,0 +1,62 @@ +""" +Authentication helpers for Supabase-backed routes. +""" + +from fastapi import HTTPException, Request, status +from pydantic import BaseModel + +from routes.supabase_client import get_supabase + + +class AuthenticatedUser(BaseModel): + id: str + email: str | None = None + + +def _get_attr_or_item(value, key: str): + if isinstance(value, dict): + return value.get(key) + return getattr(value, key, None) + + +def _extract_bearer_token(request: Request) -> str: + auth_header = request.headers.get("authorization", "") + scheme, _, token = auth_header.partition(" ") + if scheme.lower() != "bearer" or not token.strip(): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing bearer token", + headers={"WWW-Authenticate": "Bearer"}, + ) + return token.strip() + + +def require_user(request: Request) -> AuthenticatedUser: + token = _extract_bearer_token(request) + try: + response = get_supabase().auth.get_user(token) + except RuntimeError as exc: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Authentication service is not configured", + ) from exc + except Exception as exc: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid bearer token", + headers={"WWW-Authenticate": "Bearer"}, + ) from exc + + user = _get_attr_or_item(response, "user") + user_id = _get_attr_or_item(user, "id") + if not user_id: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid bearer token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return AuthenticatedUser( + id=user_id, + email=_get_attr_or_item(user, "email"), + ) diff --git a/backend/routes/billing.py b/backend/routes/billing.py index 5721804..d7f1415 100644 --- a/backend/routes/billing.py +++ b/backend/routes/billing.py @@ -7,9 +7,11 @@ from datetime import datetime, timezone from dateutil.relativedelta import relativedelta from typing import Any -from fastapi import APIRouter, HTTPException, Request -from pydantic import BaseModel -from supabase import create_client, Client +from fastapi import APIRouter, Depends, HTTPException, Request +from pydantic import BaseModel, Field + +from routes.auth import AuthenticatedUser, require_user +from routes.supabase_client import get_supabase logger = logging.getLogger(__name__) @@ -40,24 +42,6 @@ DEFAULT_BILLING_MODEL = os.environ.get("ANTHROPIC_DEFAULT_MODEL", "claude-haiku-4-5") -# --------------------------------------------------------------------------- -# Supabase client (service role — bypasses RLS) -# --------------------------------------------------------------------------- - -_supabase: Client | None = None - - -def get_supabase() -> Client: - global _supabase - if _supabase is None: - url = os.environ.get("SUPABASE_URL", "") - key = os.environ.get("SUPABASE_SERVICE_ROLE_KEY", "") - if not url or not key: - raise RuntimeError("SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY must be set") - _supabase = create_client(url, key) - return _supabase - - # --------------------------------------------------------------------------- # Cost calculation # --------------------------------------------------------------------------- @@ -151,7 +135,7 @@ def get_balance_summary(user_id: str) -> dict[str, float]: def record_usage( *, - user_id: str | None, + user_id: str, session_id: str, model: str | None, input_tokens: int, @@ -160,6 +144,9 @@ def record_usage( cache_read_input_tokens: int = 0, ) -> dict[str, Any]: """Record token usage and deduct cost from user credits.""" + if not user_id: + raise ValueError("user_id is required to record billable usage") + cost = calculate_cost_gbp( model=model, input_tokens=input_tokens, @@ -184,13 +171,6 @@ def record_usage( except Exception as e: logger.warning(f"Could not insert token_usage for {user_id}: {e}") - if not user_id: - return { - "cost_gbp": cost, - "model": _normalise_model_name(model), - "balance": None, - } - # Deduct from free tier first, then balance credits = get_or_create_credits(user_id) free_remaining = max(0, FREE_TIER_GBP - float(credits["free_tier_used_gbp"])) @@ -219,20 +199,23 @@ def record_usage( # --------------------------------------------------------------------------- @router.get("/balance") -def get_balance(user_id: str): - return get_balance_summary(user_id) +def get_balance(current_user: AuthenticatedUser = Depends(require_user)): + return get_balance_summary(current_user.id) @router.get("/usage") -def get_usage(user_id: str, limit: int = 50): +def get_usage( + limit: int = 50, + current_user: AuthenticatedUser = Depends(require_user), +): sb = get_supabase() - result = sb.table("token_usage").select("*").eq("user_id", user_id).order("created_at", desc=True).limit(limit).execute() + bounded_limit = max(1, min(limit, 100)) + result = sb.table("token_usage").select("*").eq("user_id", current_user.id).order("created_at", desc=True).limit(bounded_limit).execute() return result.data class CheckoutRequest(BaseModel): - user_id: str - amount_gbp: float = 5.0 + amount_gbp: float = Field(default=5.0, gt=0, le=500) def _get_public_base_url() -> str: @@ -243,7 +226,10 @@ def _get_public_base_url() -> str: @router.post("/checkout") -def create_checkout(request: CheckoutRequest): +def create_checkout( + request: CheckoutRequest, + current_user: AuthenticatedUser = Depends(require_user), +): import stripe stripe.api_key = os.environ.get("STRIPE_SECRET_KEY", "") @@ -264,7 +250,7 @@ def create_checkout(request: CheckoutRequest): }, "quantity": 1, }], - metadata={"user_id": request.user_id, "amount_gbp": str(request.amount_gbp)}, + metadata={"user_id": current_user.id, "amount_gbp": str(request.amount_gbp)}, success_url=f"{base_url}?topup=success", cancel_url=f"{base_url}?topup=cancel", ) diff --git a/backend/routes/chatbot.py b/backend/routes/chatbot.py index 014327d..60fbc23 100644 --- a/backend/routes/chatbot.py +++ b/backend/routes/chatbot.py @@ -10,7 +10,7 @@ import httpx -from fastapi import APIRouter, Request +from fastapi import APIRouter, Depends, Request from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel from pydantic_ai import Agent @@ -18,6 +18,8 @@ from pydantic_ai.settings import ModelSettings from agent_tools import execute_tool, TOOL_DEFINITIONS +from routes.auth import AuthenticatedUser, require_user +from routes.billing import check_balance, record_usage logger = logging.getLogger(__name__) @@ -181,7 +183,6 @@ class ChatMessage(BaseModel): class ChatRequest(BaseModel): messages: List[ChatMessage] session_id: str | None = None - user_id: str | None = None class TitleRequest(BaseModel): @@ -194,7 +195,10 @@ class TitleRequest(BaseModel): # --------------------------------------------------------------------------- @router.post("/title") -def generate_title(request: TitleRequest): +def generate_title( + request: TitleRequest, + current_user: AuthenticatedUser = Depends(require_user), +): client = _get_sync_anthropic_client() content = request.first_user_message if request.first_assistant_message: @@ -219,17 +223,15 @@ def generate_title(request: TitleRequest): # --------------------------------------------------------------------------- @router.post("/message") -async def chat_message(request: ChatRequest, http_request: Request): - # Check billing balance if user is authenticated - user_id = request.user_id - if user_id: - try: - from routes.billing import check_balance - has_credit, _ = check_balance(user_id) - if not has_credit: - return JSONResponse(status_code=402, content={"error": "No credit remaining. Please top up to continue."}) - except RuntimeError: - pass # Supabase not configured — skip billing check +async def chat_message( + request: ChatRequest, + http_request: Request, + current_user: AuthenticatedUser = Depends(require_user), +): + user_id = current_user.id + has_credit, _ = check_balance(user_id) + if not has_credit: + return JSONResponse(status_code=402, content={"error": "No credit remaining. Please top up to continue."}) session_id = request.session_id or str(uuid.uuid4()) @@ -341,7 +343,6 @@ async def generate_stream(): # Record token usage for billing billing = None try: - from routes.billing import record_usage billing = record_usage( user_id=user_id, session_id=session_id, @@ -421,7 +422,6 @@ async def execute_tool_async(tu): if iteration >= max_iterations: billing = None try: - from routes.billing import record_usage billing = record_usage( user_id=user_id, session_id=session_id, diff --git a/backend/routes/conversations.py b/backend/routes/conversations.py index df86280..1052370 100644 --- a/backend/routes/conversations.py +++ b/backend/routes/conversations.py @@ -10,10 +10,12 @@ from datetime import datetime, timezone from typing import List, Optional -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel from sqlmodel import Field, Session, SQLModel, create_engine, select +from routes.auth import AuthenticatedUser, require_user + logger = logging.getLogger(__name__) router = APIRouter(prefix="/conversations", tags=["conversations"]) @@ -49,8 +51,6 @@ class SaveConversationRequest(BaseModel): session_id: str title: str messages: list - user_id: str | None = None - user_email: str | None = None class ConversationSummary(BaseModel): @@ -91,7 +91,10 @@ def ensure_table(): @router.post("", response_model=ConversationDetail) -def save_conversation(request: SaveConversationRequest): +def save_conversation( + request: SaveConversationRequest, + current_user: AuthenticatedUser = Depends(require_user), +): now = datetime.now(timezone.utc) engine = get_engine() @@ -101,11 +104,13 @@ def save_conversation(request: SaveConversationRequest): ).first() if existing: + if existing.user_id and existing.user_id != current_user.id: + raise HTTPException(status_code=403, detail="Not your conversation") existing.title = request.title existing.messages = json.dumps(request.messages) existing.updated_at = now - existing.user_id = request.user_id - existing.user_email = request.user_email + existing.user_id = current_user.id + existing.user_email = current_user.email session.add(existing) session.commit() session.refresh(existing) @@ -115,8 +120,8 @@ def save_conversation(request: SaveConversationRequest): session_id=request.session_id, title=request.title, messages=json.dumps(request.messages), - user_id=request.user_id, - user_email=request.user_email, + user_id=current_user.id, + user_email=current_user.email, created_at=now, updated_at=now, ) @@ -132,14 +137,10 @@ def save_conversation(request: SaveConversationRequest): @router.get("") -def list_conversations(user_id: str | None = None): +def list_conversations(current_user: AuthenticatedUser = Depends(require_user)): engine = get_engine() with Session(engine) as session: - stmt = select(ChatConversation) - if user_id: - stmt = stmt.where(ChatConversation.user_id == user_id) - else: - stmt = stmt.where(ChatConversation.user_id == None) + stmt = select(ChatConversation).where(ChatConversation.user_id == current_user.id) stmt = stmt.order_by(ChatConversation.updated_at.desc()).limit(100) rows = session.exec(stmt).all() return [ @@ -150,13 +151,18 @@ def list_conversations(user_id: str | None = None): ] -@router.get("/{conversation_id}", response_model=ConversationDetail) -def get_conversation(conversation_id: int): +@router.get("/{conversation_id:int}", response_model=ConversationDetail) +def get_conversation( + conversation_id: int, + current_user: AuthenticatedUser = Depends(require_user), +): engine = get_engine() with Session(engine) as session: row = session.get(ChatConversation, conversation_id) if not row: raise HTTPException(status_code=404, detail="Conversation not found") + if row.user_id != current_user.id: + raise HTTPException(status_code=403, detail="Not your conversation") return ConversationDetail( id=row.id, session_id=row.session_id, title=row.title, messages=json.loads(row.messages) if isinstance(row.messages, str) else row.messages, @@ -172,7 +178,6 @@ class SharedConversationDetail(BaseModel): class ReportConversationRequest(BaseModel): - user_id: str | None = None note: str | None = None app_url: str | None = None @@ -265,14 +270,17 @@ def _build_issue_body( return "\n".join(lines).strip() + "\n" -@router.post("/{conversation_id}/share") -def share_conversation(conversation_id: int, user_id: str | None = None): +@router.post("/{conversation_id:int}/share") +def share_conversation( + conversation_id: int, + current_user: AuthenticatedUser = Depends(require_user), +): engine = get_engine() with Session(engine) as session: row = session.get(ChatConversation, conversation_id) if not row: raise HTTPException(status_code=404, detail="Conversation not found") - if row.user_id and row.user_id != user_id: + if row.user_id != current_user.id: raise HTTPException(status_code=403, detail="Not your conversation") if not row.share_token: row.share_token = str(uuid.uuid4()) @@ -282,14 +290,18 @@ def share_conversation(conversation_id: int, user_id: str | None = None): return {"share_token": row.share_token} -@router.post("/{conversation_id}/report", response_model=ReportConversationResponse) -def report_conversation(conversation_id: int, request: ReportConversationRequest): +@router.post("/{conversation_id:int}/report", response_model=ReportConversationResponse) +def report_conversation( + conversation_id: int, + request: ReportConversationRequest, + current_user: AuthenticatedUser = Depends(require_user), +): engine = get_engine() with Session(engine) as session: row = session.get(ChatConversation, conversation_id) if not row: raise HTTPException(status_code=404, detail="Conversation not found") - if row.user_id and row.user_id != request.user_id: + if row.user_id != current_user.id: raise HTTPException(status_code=403, detail="Not your conversation") if not row.share_token: row.share_token = str(uuid.uuid4()) @@ -337,12 +349,17 @@ def get_shared_conversation(share_token: str): ) -@router.delete("/{conversation_id}", status_code=204) -def delete_conversation(conversation_id: int): +@router.delete("/{conversation_id:int}", status_code=204) +def delete_conversation( + conversation_id: int, + current_user: AuthenticatedUser = Depends(require_user), +): engine = get_engine() with Session(engine) as session: row = session.get(ChatConversation, conversation_id) if not row: raise HTTPException(status_code=404, detail="Conversation not found") + if row.user_id != current_user.id: + raise HTTPException(status_code=403, detail="Not your conversation") session.delete(row) session.commit() diff --git a/backend/routes/supabase_client.py b/backend/routes/supabase_client.py new file mode 100644 index 0000000..27a172a --- /dev/null +++ b/backend/routes/supabase_client.py @@ -0,0 +1,20 @@ +""" +Shared Supabase client setup. +""" + +import os + +from supabase import Client, create_client + +_supabase: Client | None = None + + +def get_supabase() -> Client: + global _supabase + if _supabase is None: + url = os.environ.get("SUPABASE_URL", "") + key = os.environ.get("SUPABASE_SERVICE_ROLE_KEY", "") + if not url or not key: + raise RuntimeError("SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY must be set") + _supabase = create_client(url, key) + return _supabase diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py index 90a2456..94b7aa8 100644 --- a/backend/tests/test_api.py +++ b/backend/tests/test_api.py @@ -8,10 +8,30 @@ import pytest from fastapi.testclient import TestClient from main import app +from routes.auth import AuthenticatedUser, require_user + + +def _test_user(): + return AuthenticatedUser(id="test-user", email="user@example.com") + + +app.dependency_overrides[require_user] = _test_user client = TestClient(app) +@pytest.fixture(autouse=True) +def mock_billing(monkeypatch): + import routes.chatbot as chatbot + + monkeypatch.setattr(chatbot, "check_balance", lambda user_id: (True, {})) + monkeypatch.setattr( + chatbot, + "record_usage", + lambda **kwargs: {"cost_gbp": 0.0, "balance": None}, + ) + + # --------------------------------------------------------------------------- # Health # --------------------------------------------------------------------------- @@ -28,12 +48,11 @@ def test_health(self): # --------------------------------------------------------------------------- class TestConversations: - def _save(self, session_id="test-session-1", title="Test", messages=None, user_id=None): + def _save(self, session_id="test-session-1", title="Test", messages=None): return client.post("/conversations", json={ "session_id": session_id, "title": title, "messages": messages or [{"role": "user", "content": "hello"}], - "user_id": user_id, }) def test_save_conversation(self): @@ -45,8 +64,8 @@ def test_save_conversation(self): assert "id" in data def test_list_conversations(self): - self._save(session_id="list-test-1", user_id="user@example.com") - r = client.get("/conversations", params={"user_id": "user@example.com"}) + self._save(session_id="list-test-1") + r = client.get("/conversations") assert r.status_code == 200 assert isinstance(r.json(), list) session_ids = [c["session_id"] for c in r.json()] @@ -94,10 +113,20 @@ def test_messages_roundtrip(self): assert len(loaded) == 2 assert loaded[1]["events"][0]["content"] == "hi there" - def test_list_without_user_id_returns_anonymous(self): - self._save(session_id="anon-test-1", user_id=None) + def test_list_returns_only_authenticated_user_conversations(self): + self._save(session_id="auth-list-test-1") r = client.get("/conversations") assert r.status_code == 200 + session_ids = [c["session_id"] for c in r.json()] + assert "auth-list-test-1" in session_ids + + def test_list_requires_authentication(self): + app.dependency_overrides.pop(require_user, None) + try: + r = client.get("/conversations") + assert r.status_code == 401 + finally: + app.dependency_overrides[require_user] = _test_user def test_report_includes_tool_inputs_and_outputs(self): messages = [ @@ -119,11 +148,11 @@ def test_report_includes_tool_inputs_and_outputs(self): ], }, ] - save_r = self._save(session_id="report-test-1", messages=messages, user_id="user-1") + save_r = self._save(session_id="report-test-1", messages=messages) conv_id = save_r.json()["id"] report_r = client.post( f"/conversations/{conv_id}/report", - json={"user_id": "user-1", "app_url": "https://example.com"}, + json={"app_url": "https://example.com"}, ) assert report_r.status_code == 200 issue_body = report_r.json()["issue_body"] diff --git a/frontend/src/app/ChatPage.tsx b/frontend/src/app/ChatPage.tsx index fb69651..de01e7c 100644 --- a/frontend/src/app/ChatPage.tsx +++ b/frontend/src/app/ChatPage.tsx @@ -11,6 +11,7 @@ import { oneDark } from "react-syntax-highlighter/dist/esm/styles/prism"; import { Chart, extractChartSpecs, ChartSpec } from "@/components/charts"; import { THEME } from "@/components/theme"; import { getBackendEndpoint } from "@/utils/backend"; +import { getSupabase } from "@/utils/supabase"; const EXAMPLE_QUERIES = [ "What's the current personal allowance?", @@ -128,7 +129,7 @@ interface BalanceSummary { async function apiRequest(method: string, endpoint: string, params?: Record, body?: unknown): Promise { const url = new URL(getBackendEndpoint(endpoint), window.location.origin); if (params) Object.entries(params).forEach(([k, v]) => url.searchParams.set(k, v)); - const options: RequestInit = { method, headers: { "Content-Type": "application/json" } }; + const options: RequestInit = { method, headers: { "Content-Type": "application/json", ...(await getAuthHeaders()) } }; if (body && ["POST", "PUT", "PATCH"].includes(method)) options.body = JSON.stringify(body); const res = await fetch(url.toString(), options); if (!res.ok) { @@ -140,6 +141,13 @@ async function apiRequest(method: string, endpoint: string, params?: Record> { + const supabase = getSupabase(); + if (!supabase) return {}; + const { data: { session } } = await supabase.auth.getSession(); + return session?.access_token ? { Authorization: `Bearer ${session.access_token}` } : {}; +} + export default function ChatPage() { const { user, loading: authLoading, signIn, signUp, signOut } = useAuth(); const [messages, setMessages] = useState([]); @@ -177,7 +185,7 @@ export default function ChatPage() { const fetchBalance = useCallback(async () => { if (!user) return; - const data = await apiRequest("GET", "billing/balance", { user_id: user.id }); + const data = await apiRequest("GET", "billing/balance"); setBalance(data); }, [user]); @@ -185,7 +193,7 @@ export default function ChatPage() { if (!user) return; setTopUpLoading(true); try { - const { url } = await apiRequest<{ url: string }>("POST", "billing/checkout", undefined, { user_id: user.id, amount_gbp: amount }); + const { url } = await apiRequest<{ url: string }>("POST", "billing/checkout", undefined, { amount_gbp: amount }); if (url) window.location.href = url; } catch (e) { console.error("Checkout failed", e); } finally { setTopUpLoading(false); } @@ -206,7 +214,7 @@ export default function ChatPage() { useEffect(() => { inputRef.current?.focus(); if (!authLoading && user) { - apiRequest("GET", "conversations", { user_id: user.id }) + apiRequest("GET", "conversations") .then((convs) => { setConversations(convs); // Preload conversation details in background @@ -251,7 +259,7 @@ export default function ChatPage() { const shareConversation = async (e: React.MouseEvent, id: number) => { e.stopPropagation(); try { - const { share_token } = await apiRequest<{ share_token: string }>("POST", `conversations/${id}/share`, user?.id ? { user_id: user.id } : undefined); + const { share_token } = await apiRequest<{ share_token: string }>("POST", `conversations/${id}/share`); const url = `${window.location.origin}/s/${share_token}`; await navigator.clipboard.writeText(url); setCopiedShareId(id); @@ -293,7 +301,7 @@ export default function ChatPage() { }); try { - const saved = await apiRequest("POST", "conversations", undefined, { session_id: sid, title, messages: apiMessages, user_id: user?.id, user_email: user?.email }); + const saved = await apiRequest("POST", "conversations", undefined, { session_id: sid, title, messages: apiMessages }); setActiveConversationId(saved.id); conversationCache.current.set(saved.id, saved); setConversations((prev) => { @@ -321,7 +329,6 @@ export default function ChatPage() { const conversationId = await ensureConversationForReport(); if (!conversationId) throw new Error("Could not save this thread for reporting."); const data = await apiRequest("POST", `conversations/${conversationId}/report`, undefined, { - user_id: user?.id, note: reportNote.trim() || null, app_url: window.location.origin, }); @@ -346,6 +353,11 @@ export default function ChatPage() { const sendMessage = async () => { if (!input.trim() || isStreaming) return; + if (authLoading) return; + if (!user) { + setShowAuth(true); + return; + } const userMessage: Message = { role: "user", content: input }; const allMessages = [...messages, userMessage]; setMessages((prev) => [...prev, userMessage]); @@ -416,8 +428,8 @@ export default function ChatPage() { try { const response = await fetch(getBackendEndpoint("chat/message"), { method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ messages: apiMessages, session_id: sessionId.current, user_id: user?.id || null }), + headers: { "Content-Type": "application/json", ...(await getAuthHeaders()) }, + body: JSON.stringify({ messages: apiMessages, session_id: sessionId.current }), signal: controller.signal, }); if (response.status === 402) {