Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from src.api.routes.auth import router as auth_router
from src.api.routes.billing import router as billing_router
from src.api.routes.code import router as code_router
from src.api.routes.connectors import router as connectors_router
from src.api.routes.enterprise import router as enterprise_router
from src.api.routes.health import router as health_router
from src.api.routes.memory import router as memory_router
Expand Down Expand Up @@ -226,6 +227,7 @@ async def lifespan(app: FastAPI):
app.include_router(scanner_router)
app.include_router(auth_router)
app.include_router(api_keys_router)
app.include_router(connectors_router)
app.include_router(billing_router)
app.include_router(enterprise_router)
app.include_router(telemetry_router)
Expand Down
256 changes: 256 additions & 0 deletions src/api/routes/connectors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
"""Connector OAuth routes for external knowledge sources."""

from __future__ import annotations

import secrets
import os
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Literal, Optional
from urllib.parse import urlencode

from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, Field
Comment on lines +11 to +12
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Import RedirectResponse from fastapi.responses to support redirecting the user back to the frontend application after the OAuth callback completes.

Suggested change
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, Field
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, Field


from src.api.dependencies import require_user
# Every endpoint in this module is mounted under /api/connectors.
router = APIRouter(prefix="/api/connectors", tags=["Connectors"])

# Identifiers of the connectors this module supports.
ConnectorId = Literal["notion", "google-drive"]
# States a status response may carry. NOTE(review): "pending" is declared here
# but no status helper in this module currently returns it.
ConnectorState = Literal["connected", "not_connected", "pending"]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The ConnectorState literal includes "pending", but this state is never returned by _status_for or used anywhere in the status responses. If pending status is not needed, consider removing it from the Literal type to keep the schema clean. Otherwise, implement a check in _status_for to see if there is an active pending state in _pending_states for the user.

Suggested change
ConnectorState = Literal["connected", "not_connected", "pending"]
ConnectorState = Literal["connected", "not_connected"]


# Lifetime, in minutes, of an OAuth state token issued by the start endpoint.
STATE_TTL_MINUTES = 10
# Size limit for the in-memory pending-state store; the pruning helper evicts
# the entries that expire soonest once the store grows beyond this.
MAX_PENDING_STATES = 1000


class ConnectorDefinition(BaseModel):
    # Static description of one external connector and its OAuth endpoints.
    id: ConnectorId  # also the key under which the connector is registered
    name: str  # human-readable label used in responses and error messages
    description: str
    auth_url: str  # provider authorization endpoint the user is sent to
    token_url: str  # provider token endpoint; not called anywhere in this module
    scopes: List[str]  # OAuth scopes requested during authorization
    docs_url: str


class ConnectorStatusResponse(BaseModel):
    # Response body describing the connection status of one connector.
    id: ConnectorId
    name: str
    state: ConnectorState
    connected_at: Optional[datetime] = None  # absent unless explicitly provided
    scopes: List[str] = Field(default_factory=list)  # defaults to an empty list
    detail: str  # free-text explanation of the reported state


class ConnectorListResponse(BaseModel):
    # Response body for the connector listing endpoint.
    connectors: List[ConnectorStatusResponse]  # one status entry per connector


class ConnectorStartResponse(BaseModel):
    # Response body returned when an OAuth flow is started.
    connector_id: ConnectorId
    authorization_url: str  # provider URL the client should send the user to
    state: str  # state token embedded in authorization_url
    expires_at: datetime  # moment after which the state token is rejected


class ConnectorDisconnectResponse(BaseModel):
    # Response body for the disconnect endpoint.
    connector_id: ConnectorId
    disconnected: bool  # whether a connection was actually removed


class PendingOAuthState(BaseModel):
    # Server-side record of an OAuth flow that has been started but not yet completed.
    connector_id: ConnectorId  # connector the flow was started for
    user_id: str  # id of the user who started the flow
    expires_at: datetime  # the callback rejects the state at or after this moment


# Registry of supported connectors. Each definition is keyed by its own ``id``
# so the key and the model can never drift apart.
CONNECTORS: Dict[ConnectorId, ConnectorDefinition] = {
    definition.id: definition
    for definition in (
        ConnectorDefinition(
            id="notion",
            name="Notion",
            description="Sync selected Notion pages and workspace notes into XMem memory.",
            auth_url="https://api.notion.com/v1/oauth/authorize",
            token_url="https://api.notion.com/v1/oauth/token",
            scopes=[],
            docs_url="https://developers.notion.com/docs/authorization",
        ),
        ConnectorDefinition(
            id="google-drive",
            name="Google Drive",
            description="Bring Google Drive docs and files into XMem as searchable memory.",
            auth_url="https://accounts.google.com/o/oauth2/v2/auth",
            token_url="https://oauth2.googleapis.com/token",
            scopes=[
                "https://www.googleapis.com/auth/drive.readonly",
                "https://www.googleapis.com/auth/documents.readonly",
            ],
            docs_url="https://developers.google.com/identity/protocols/oauth2",
        ),
    )
}

# Maps an issued OAuth state token to the flow it belongs to. This is a plain
# module-level dict, so its contents exist only inside the current process.
_pending_states: Dict[str, PendingOAuthState] = {}


def _now() -> datetime:
return datetime.now(timezone.utc)


def _client_id(connector_id: ConnectorId) -> Optional[str]:
if connector_id == "notion":
return os.getenv("NOTION_CLIENT_ID")
return os.getenv("GOOGLE_DRIVE_CLIENT_ID") or os.getenv("GOOGLE_CLIENT_ID")


def _redirect_uri(connector_id: ConnectorId) -> str:
if connector_id == "notion":
return os.getenv(
"NOTION_REDIRECT_URI",
"http://localhost:8000/api/connectors/notion/oauth/callback",
)
return os.getenv(
"GOOGLE_DRIVE_REDIRECT_URI",
"http://localhost:8000/api/connectors/google-drive/oauth/callback",
)


def _get_connector(connector_id: str) -> ConnectorDefinition:
    """Resolve a path parameter to its registered connector definition.

    Raises:
        HTTPException: 404 when ``connector_id`` is not a key of ``CONNECTORS``.
    """
    definition = CONNECTORS.get(connector_id)  # type: ignore[arg-type]
    if definition:
        return definition
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Unknown connector")


def _prune_pending_states(now: Optional[datetime] = None) -> None:
    """Drop expired pending OAuth states, then trim the store to its size limit.

    Args:
        now: Reference time for expiry; the current UTC time is used when omitted.

    Entries whose ``expires_at`` is at or before the reference time are removed
    first. If more than ``MAX_PENDING_STATES`` entries remain, the ones that
    expire soonest are evicted until the limit is met.
    """
    cutoff = now if now else _now()
    stale_tokens = [
        token for token, entry in _pending_states.items() if entry.expires_at <= cutoff
    ]
    for token in stale_tokens:
        del _pending_states[token]

    excess = len(_pending_states) - MAX_PENDING_STATES
    if excess <= 0:
        return

    # Stable sort: entries with equal expiry keep insertion order, so the
    # earliest-inserted of them are evicted first.
    by_expiry = sorted(_pending_states, key=lambda token: _pending_states[token].expires_at)
    for token in by_expiry[:excess]:
        del _pending_states[token]


def _status_for(user_id: str, connector: ConnectorDefinition) -> ConnectorStatusResponse:
    """Build the status payload for one connector.

    ``user_id`` is accepted but not consulted: every connector is reported as
    ``"not_connected"`` with the scopes taken from its definition.
    """
    payload = {
        "id": connector.id,
        "name": connector.name,
        "state": "not_connected",
        "scopes": connector.scopes,
        "detail": "OAuth start is available; token exchange and sync storage are not connected yet.",
    }
    return ConnectorStatusResponse(**payload)


def _build_authorization_url(connector: ConnectorDefinition, state: str) -> str:
    """Compose the provider authorization URL for a connector.

    Args:
        connector: Definition supplying the authorization endpoint and scopes.
        state: Opaque state token to include in the query string.

    Returns:
        ``connector.auth_url`` followed by the URL-encoded query string.

    Raises:
        HTTPException: 503 when no client ID is configured for the connector.
    """
    client_id = _client_id(connector.id)
    if not client_id:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=f"{connector.name} OAuth client ID is not configured",
        )

    # Parameters shared by every provider.
    query = {
        "client_id": client_id,
        "redirect_uri": _redirect_uri(connector.id),
        "response_type": "code",
        "state": state,
    }

    # Provider-specific additions.
    if connector.id == "google-drive":
        query["access_type"] = "offline"
        query["include_granted_scopes"] = "true"
        query["prompt"] = "consent"
        query["scope"] = " ".join(connector.scopes)
    elif connector.id == "notion":
        query["owner"] = "user"

    return connector.auth_url + "?" + urlencode(query)


@router.get("", response_model=ConnectorListResponse)
async def list_connectors(current_user: dict = Depends(require_user)) -> ConnectorListResponse:
    """Return the status of every registered connector for the calling user."""
    caller_id = str(current_user.get("id"))
    statuses = [_status_for(caller_id, definition) for definition in CONNECTORS.values()]
    return ConnectorListResponse(connectors=statuses)


@router.get("/{connector_id}/status", response_model=ConnectorStatusResponse)
async def connector_status(
    connector_id: str,
    current_user: dict = Depends(require_user),
) -> ConnectorStatusResponse:
    """Return the status of a single connector for the calling user.

    Responds with 404 (via ``_get_connector``) when the connector id is unknown.
    """
    definition = _get_connector(connector_id)
    caller_id = str(current_user.get("id"))
    return _status_for(caller_id, definition)


@router.post("/{connector_id}/oauth/start", response_model=ConnectorStartResponse)
async def start_connector_oauth(
    connector_id: str,
    current_user: dict = Depends(require_user),
) -> ConnectorStartResponse:
    """Begin an OAuth flow and return the provider authorization URL.

    A random state token is generated, stored in ``_pending_states`` together
    with the connector id and the caller's user id, and embedded in the
    returned URL. The token expires ``STATE_TTL_MINUTES`` minutes after issue.

    Raises:
        HTTPException: 404 for an unknown connector (from ``_get_connector``);
            503 when the connector has no client ID configured (from
            ``_build_authorization_url``).
    """
    connector = _get_connector(connector_id)
    state = secrets.token_urlsafe(32)
    expires_at = _now() + timedelta(minutes=STATE_TTL_MINUTES)

    # Build the URL before recording anything: if the client ID is missing the
    # 503 is raised here and no orphaned state is left behind.
    authorization_url = _build_authorization_url(connector, state)

    _pending_states[state] = PendingOAuthState(
        connector_id=connector.id,
        user_id=str(current_user.get("id")),
        expires_at=expires_at,
    )
    # Abandoned flows would otherwise accumulate, so expired states are cleaned
    # up on every start. Pruning runs AFTER the insert so the store never holds
    # more than MAX_PENDING_STATES entries; pruning first would allow one extra.
    _prune_pending_states()

    return ConnectorStartResponse(
        connector_id=connector.id,
        authorization_url=authorization_url,
        state=state,
        expires_at=expires_at,
    )


@router.get("/{connector_id}/oauth/callback")
async def connector_oauth_callback(
    connector_id: str,
    state: str = Query(..., min_length=1),
    code: Optional[str] = Query(None, min_length=1),
    error: Optional[str] = Query(None, min_length=1),
) -> dict:
    """Handle the provider redirect that ends an OAuth authorization attempt.

    The state token is consumed on every call, so it cannot be replayed even
    when the provider reports an error.

    Raises:
        HTTPException: 404 for an unknown connector; 400 when the state is
            unknown, expired, or issued for a different connector; 400 when
            the provider sent an error or no authorization code.
    """
    connector = _get_connector(connector_id)
    received_at = _now()
    _prune_pending_states(received_at)

    # Popping (rather than reading) makes the state single-use.
    pending = _pending_states.pop(state, None)
    state_ok = (
        bool(pending)
        and pending.connector_id == connector.id
        and pending.expires_at > received_at
    )
    if not state_ok:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid or expired connector authorization state",
        )

    if error or not code:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Authorization denied: {error or 'no authorization code received'}",
        )

    # Token exchange, encrypted credential storage, and source ingestion are intentionally
    # separate follow-up steps. Do not mark the connector as connected until those exist.
    result = {
        "status": "pending",
        "connector_id": connector.id,
        "detail": f"{connector.name} authorization received; token exchange is not enabled yet.",
    }
    return result
Comment on lines +219 to +246
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Improve the OAuth callback handler to:

  1. Handle errors and denials gracefully: If the user cancels or denies the authorization request, the OAuth provider redirects with an error query parameter instead of a code. Making code optional and checking for error prevents a raw 422 Unprocessable Entity validation error.
  2. Redirect to the frontend: Instead of returning a raw JSON response (which leaves the user stranded on a blank API page), redirect them back to the frontend application using RedirectResponse with the status and connector ID.
async def connector_oauth_callback(
    connector_id: str,
    state: str = Query(..., min_length=1),
    code: Optional[str] = Query(None),
    error: Optional[str] = Query(None),
) -> RedirectResponse:
    """Handle the provider redirect and send the user back to the frontend.

    Provider errors and missing codes redirect to the frontend with
    ``status=error``. A valid state records a connection marker and redirects
    with ``status=success``.

    Raises:
        HTTPException: 400 when the state is unknown, expired, or was issued
            for a different connector.
    """
    connector = _get_connector(connector_id)

    if error or not code:
        # `error` comes straight from the query string. Encode it so it cannot
        # inject extra parameters or a fragment into the frontend URL.
        failure_query = urlencode(
            {
                "status": "error",
                "error": error or "access_denied",
                "connector_id": connector.id,
            }
        )
        return RedirectResponse(url=f"{settings.frontend_url}/connectors?{failure_query}")

    # Read the clock once so the expiry check and the stored timestamp agree.
    now = _now()
    pending = _pending_states.pop(state, None)
    if not pending or pending.connector_id != connector.id or pending.expires_at <= now:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid or expired connector authorization state",
        )

    # Token exchange and source ingestion are intentionally separate follow-up steps.
    # This callback validates the flow and records a pending connection marker only.
    _connections[_connection_key(pending.user_id, connector.id)] = StoredConnection(
        connector_id=connector.id,
        user_id=pending.user_id,
        connected_at=now,
        scopes=connector.scopes,
    )
    success_query = urlencode({"status": "success", "connector_id": connector.id})
    return RedirectResponse(url=f"{settings.frontend_url}/connectors?{success_query}")

Comment thread
greptile-apps[bot] marked this conversation as resolved.


@router.post("/{connector_id}/disconnect", response_model=ConnectorDisconnectResponse)
async def disconnect_connector(
    connector_id: str,
    current_user: dict = Depends(require_user),
) -> ConnectorDisconnectResponse:
    """Report the outcome of a disconnect request.

    No connection is removed here, so ``disconnected`` is always ``False``.
    An unknown connector id yields 404 via ``_get_connector``.
    """
    definition = _get_connector(connector_id)
    return ConnectorDisconnectResponse(connector_id=definition.id, disconnected=False)
97 changes: 97 additions & 0 deletions tests/api/test_connectors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from __future__ import annotations

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from src.api.dependencies import require_user
from src.api.routes import connectors


@pytest.fixture(autouse=True)
def _reset_connector_state() -> None:
    """Empty the module-level pending-state store before each test runs."""
    connectors._pending_states.clear()


def _user() -> dict:
return {"id": "user-1", "email": "user@example.com", "username": "user"}


def _client() -> TestClient:
    """Build a test client for a minimal app exposing only the connector routes.

    Authentication is bypassed by overriding ``require_user`` with ``_user``.
    """
    test_app = FastAPI()
    test_app.dependency_overrides[require_user] = _user
    test_app.include_router(connectors.router)
    return TestClient(test_app)


def test_lists_supported_connectors() -> None:
    """The listing returns both connectors, each reported as not connected."""
    response = _client().get("/api/connectors")

    assert response.status_code == 200
    listed = response.json()["connectors"]
    assert {entry["id"] for entry in listed} == {"notion", "google-drive"}
    assert {entry["state"] for entry in listed} == {"not_connected"}


def test_oauth_start_requires_configured_client_id(monkeypatch) -> None:
    """Starting OAuth without a configured client ID answers 503."""
    monkeypatch.delenv("NOTION_CLIENT_ID", raising=False)

    result = _client().post("/api/connectors/notion/oauth/start")

    assert result.status_code == 503
    detail = result.json()["detail"]
    assert "client ID is not configured" in detail


def test_oauth_start_builds_authorization_url_without_secret(monkeypatch) -> None:
    """The authorization URL carries the client ID but never the client secret."""
    environment = {
        "GOOGLE_DRIVE_CLIENT_ID": "drive-client",
        "GOOGLE_DRIVE_CLIENT_SECRET": "do-not-leak",
        "GOOGLE_DRIVE_REDIRECT_URI": "http://localhost:8000/api/connectors/google-drive/oauth/callback",
    }
    for name, value in environment.items():
        monkeypatch.setenv(name, value)

    response = _client().post("/api/connectors/google-drive/oauth/start")

    assert response.status_code == 200
    payload = response.json()
    url = payload["authorization_url"]
    assert payload["connector_id"] == "google-drive"
    assert "accounts.google.com" in url
    assert "client_id=drive-client" in url
    assert "do-not-leak" not in url
    assert payload["state"]


def test_callback_validates_state_without_marking_connected(monkeypatch) -> None:
    """A successful callback is acknowledged as pending and connects nothing."""
    monkeypatch.setenv("NOTION_CLIENT_ID", "notion-client")
    client = _client()

    # Start a flow to obtain a state token the callback will accept.
    state = client.post("/api/connectors/notion/oauth/start").json()["state"]

    callback = client.get(f"/api/connectors/notion/oauth/callback?code=abc&state={state}")
    assert callback.status_code == 200
    assert callback.json()["status"] == "pending"

    # The status endpoint must still report the connector as not connected.
    status_response = client.get("/api/connectors/notion/status")
    assert status_response.status_code == 200
    assert status_response.json()["state"] == "not_connected"

    disconnect_response = client.post("/api/connectors/notion/disconnect")
    assert disconnect_response.status_code == 200
    assert disconnect_response.json() == {"connector_id": "notion", "disconnected": False}


def test_callback_handles_provider_denial_and_consumes_state(monkeypatch) -> None:
    """A provider denial answers 400 and the state token cannot be reused."""
    monkeypatch.setenv("NOTION_CLIENT_ID", "notion-client")
    client = _client()

    state = client.post("/api/connectors/notion/oauth/start").json()["state"]
    callback_path = "/api/connectors/notion/oauth/callback"

    denied = client.get(f"{callback_path}?error=access_denied&state={state}")

    assert denied.status_code == 400
    assert "access_denied" in denied.json()["detail"]

    # The denial consumed the state, so a later valid-looking callback fails too.
    replay = client.get(f"{callback_path}?code=abc&state={state}")
    assert replay.status_code == 400
Loading