diff --git a/api/app/settings/common.py b/api/app/settings/common.py index c75fae2fbf34..e82fd1b8241a 100644 --- a/api/app/settings/common.py +++ b/api/app/settings/common.py @@ -88,6 +88,7 @@ "rest_framework.authtoken", # Used for managing api keys "rest_framework_api_key", + "oauth2_provider", "rest_framework_simplejwt.token_blacklist", "djoser", "django.contrib.sites", @@ -165,6 +166,7 @@ "softdelete", "metadata", "app_analytics", + "oauth2_metadata", ] SILENCED_SYSTEM_CHECKS = ["axes.W002"] @@ -312,6 +314,7 @@ "custom_auth.jwt_cookie.authentication.JWTCookieAuthentication", "rest_framework.authentication.TokenAuthentication", "api_keys.authentication.MasterAPIKeyAuthentication", + "oauth2_metadata.authentication.OAuth2BearerTokenAuthentication", ), "PAGE_SIZE": 10, "UNICODE_JSON": False, @@ -941,6 +944,26 @@ "SIGNING_KEY": env.str("COOKIE_AUTH_JWT_SIGNING_KEY", default=SECRET_KEY), } +# OAuth 2.1 Provider (django-oauth-toolkit) +FLAGSMITH_API_URL = env.str("FLAGSMITH_API_URL", default="http://localhost:8000") +FLAGSMITH_FRONTEND_URL = env.str( + "FLAGSMITH_FRONTEND_URL", default="http://localhost:8080" +) + +OAUTH2_PROVIDER = { + "ACCESS_TOKEN_EXPIRE_SECONDS": 60 * 15, # 15 minutes + "REFRESH_TOKEN_EXPIRE_SECONDS": 60 * 60 * 24 * 30, # 30 days + "ROTATE_REFRESH_TOKEN": True, + "PKCE_REQUIRED": True, + "ALLOWED_CODE_CHALLENGE_METHODS": ["S256"], + "SCOPES": {"mcp": "MCP access"}, + "DEFAULT_SCOPES": ["mcp"], + "ALLOWED_GRANT_TYPES": [ + "authorization_code", + "refresh_token", + ], +} + # Github OAuth credentials GITHUB_CLIENT_ID = env.str("GITHUB_CLIENT_ID", default="") GITHUB_CLIENT_SECRET = env.str("GITHUB_CLIENT_SECRET", default="") diff --git a/api/app/urls.py b/api/app/urls.py index d5b68f85f6ca..b9a7e1181b32 100644 --- a/api/app/urls.py +++ b/api/app/urls.py @@ -6,6 +6,7 @@ from django.urls import include, path, re_path from django.views.generic.base import TemplateView +from oauth2_metadata.views import authorization_server_metadata from users.views import password_reset_redirect from . import views @@ -13,6 +14,11 @@ urlpatterns = [ *core_urlpatterns, path("processor/", include("task_processor.urls")), + path( + ".well-known/oauth-authorization-server", + authorization_server_metadata, + name="oauth-authorization-server-metadata", + ), ] if not settings.TASK_PROCESSOR_MODE: @@ -47,6 +53,8 @@ "robots.txt", TemplateView.as_view(template_name="robots.txt", content_type="text/plain"), ), + # Authorize template view for testing: this will be moved to the frontend in following issues + path("o/", include("oauth2_provider.urls", namespace="oauth2_provider")), ] if settings.DEBUG: # pragma: no cover diff --git a/api/oauth2_metadata/__init__.py b/api/oauth2_metadata/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/api/oauth2_metadata/apps.py b/api/oauth2_metadata/apps.py new file mode 100644 index 000000000000..531858ca6ec9 --- /dev/null +++ b/api/oauth2_metadata/apps.py @@ -0,0 +1,8 @@ +from django.apps import AppConfig + + +class OAuth2MetadataConfig(AppConfig): + name = "oauth2_metadata" + + def ready(self) -> None: + from oauth2_metadata import tasks # noqa: F401 diff --git a/api/oauth2_metadata/authentication.py b/api/oauth2_metadata/authentication.py new file mode 100644 index 000000000000..d46d76500a97 --- /dev/null +++ b/api/oauth2_metadata/authentication.py @@ -0,0 +1,17 @@ +from oauth2_provider.contrib.rest_framework import ( + OAuth2Authentication, # type: ignore[import-untyped] +) +from rest_framework.request import Request + + +class OAuth2BearerTokenAuthentication(OAuth2Authentication): # type: ignore[misc] + """DOT's default OAuth2Authentication also reads the request body + looking for an access_token, which consumes the stream and breaks + views that need to read request.body. + """ + + def authenticate(self, request: Request) -> tuple[object, str] | None: + auth_header = request.META.get("HTTP_AUTHORIZATION", "") + if not auth_header.startswith("Bearer "): + return None + return super().authenticate(request) # type: ignore[return-value] diff --git a/api/oauth2_metadata/tasks.py b/api/oauth2_metadata/tasks.py new file mode 100644 index 000000000000..372078267f9e --- /dev/null +++ b/api/oauth2_metadata/tasks.py @@ -0,0 +1,9 @@ +from datetime import timedelta + +from django.core.management import call_command +from task_processor.decorators import register_recurring_task + + +@register_recurring_task(run_every=timedelta(hours=24)) +def clear_expired_oauth2_tokens() -> None: + call_command("cleartokens") diff --git a/api/oauth2_metadata/views.py b/api/oauth2_metadata/views.py new file mode 100644 index 000000000000..25cbc77071d5 --- /dev/null +++ b/api/oauth2_metadata/views.py @@ -0,0 +1,37 @@ +from typing import Any + +from django.conf import settings +from django.http import HttpRequest, JsonResponse +from django.views.decorators.csrf import csrf_exempt +from django.views.decorators.http import require_GET + + +@csrf_exempt +@require_GET +def authorization_server_metadata(request: HttpRequest) -> JsonResponse: + """RFC 8414 OAuth 2.0 Authorization Server Metadata.""" + api_url: str = settings.FLAGSMITH_API_URL.rstrip("/") + frontend_url: str = settings.FLAGSMITH_FRONTEND_URL.rstrip("/") + oauth2_settings: dict[str, Any] = settings.OAUTH2_PROVIDER + scopes: dict[str, str] = oauth2_settings.get("SCOPES", {}) + + metadata = { + "issuer": api_url, + "authorization_endpoint": f"{frontend_url}/oauth/authorize/", + "token_endpoint": f"{api_url}/o/token/", + "registration_endpoint": f"{api_url}/o/register/", + "revocation_endpoint": f"{api_url}/o/revoke_token/", + "introspection_endpoint": f"{api_url}/o/introspect/", + "scopes_supported": list(scopes.keys()), + "response_types_supported": ["code"], + "grant_types_supported": ["authorization_code", "refresh_token"], + "code_challenge_methods_supported": ["S256"], + "token_endpoint_auth_methods_supported": [ + "client_secret_basic", + "client_secret_post", + "none", + ], + "introspection_endpoint_auth_methods_supported": ["none"], + } + + return JsonResponse(metadata) diff --git a/api/oauth2_test_server.mjs b/api/oauth2_test_server.mjs new file mode 100644 index 000000000000..1d0029a59faf --- /dev/null +++ b/api/oauth2_test_server.mjs @@ -0,0 +1,81 @@ +import { createServer } from "node:http"; +import { randomBytes, createHash } from "node:crypto"; + +const CLIENT_ID = "ZLsLu3hhJI4GlhNsGeFVC3K2U3QBGfXtmc0EcyiG"; +const REDIRECT_URI = "http://localhost:3000/oauth/callback"; +const API_URL = "http://localhost:8000"; +const PORT = 3000; + +// Generate PKCE values +const codeVerifier = randomBytes(96).toString("base64url").slice(0, 128); +const codeChallenge = createHash("sha256") + .update(codeVerifier) + .digest("base64url"); + +const authorizeUrl = + `${API_URL}/o/authorize/?` + + new URLSearchParams({ + response_type: "code", + client_id: CLIENT_ID, + redirect_uri: REDIRECT_URI, + scope: "mcp", + code_challenge: codeChallenge, + code_challenge_method: "S256", + }); + +const server = createServer(async (req, res) => { + const url = new URL(req.url, `http://localhost:${PORT}`); + + if (url.pathname === "/oauth/callback") { + const code = url.searchParams.get("code"); + const error = url.searchParams.get("error"); + + if (error) { + res.writeHead(400, { "Content-Type": "text/plain" }); + res.end(`Error: ${error}\n${url.searchParams.get("error_description")}`); + return; + } + + if (!code) { + res.writeHead(400, { "Content-Type": "text/plain" }); + res.end("No authorization code received"); + return; + } + + console.log(`\nReceived authorization code: ${code}`); + console.log("Exchanging for token...\n"); + + // Exchange code for token + const tokenRes = await fetch(`${API_URL}/o/token/`, { + method: "POST", + headers: { "Content-Type": "application/x-www-form-urlencoded" }, + body: new URLSearchParams({ + grant_type: "authorization_code", + code, + redirect_uri: REDIRECT_URI, + client_id: CLIENT_ID, + code_verifier: codeVerifier, + }), + }); + + const tokenData = await tokenRes.json(); + console.log("Token response:", JSON.stringify(tokenData, null, 2)); + + res.writeHead(200, { "Content-Type": "text/html" }); + res.end(`
${JSON.stringify(tokenData, null, 2)}
`); + + // Done - shut down + setTimeout(() => { + console.log("\nDone. Shutting down."); + process.exit(0); + }, 1000); + } else { + res.writeHead(302, { Location: authorizeUrl }); + res.end(); + } +}); + +server.listen(PORT, () => { + console.log(`OAuth test server running on http://localhost:${PORT}`); + console.log(`\nOpen http://localhost:${PORT} in your browser to start the flow.\n`); +}); diff --git a/api/poetry.lock b/api/poetry.lock index 870e0544c6ab..b4ee146b2494 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.3.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "annotated-types" @@ -1504,6 +1504,27 @@ files = [ [package.dependencies] django = ">=3.2" +[[package]] +name = "django-oauth-toolkit" +version = "3.1.0" +description = "OAuth2 Provider for Django" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "django_oauth_toolkit-3.1.0-py3-none-any.whl", hash = "sha256:10ddc90804297d913dfb958edd58d5fac541eb1ca912f47893ca1e482bb2a11f"}, + {file = "django_oauth_toolkit-3.1.0.tar.gz", hash = "sha256:d5a59d07588cfefa8818e99d65040a252eb2ede22512483e2240c91d0b885c8e"}, +] + +[package.dependencies] +django = ">=4.2" +jwcrypto = ">=1.5.0" +oauthlib = ">=3.2.2" +requests = ">=2.13.0" + +[package.extras] +dev = ["m2r", "pytest", "pytest-cov", "sphinx-rtd-theme"] + [[package]] name = "django-ordered-model" version = "3.4.3" @@ -2187,12 +2208,12 @@ files = [ ] [package.dependencies] -google-api-core = {version = ">=1.21.0,<3.dev0", markers = "python_version >= \"3\""} -google-auth = {version = ">=1.16.0,<3.dev0", markers = "python_version >= \"3\""} +google-api-core = {version = ">=1.21.0,<3dev", markers = "python_version >= \"3\""} +google-auth = {version = ">=1.16.0,<3dev", markers = "python_version >= \"3\""} google-auth-httplib2 = ">=0.0.3" -httplib2 = ">=0.15.0,<1.dev0" -six = ">=1.13.0,<2.dev0" -uritemplate = ">=3.0.0,<4.dev0" +httplib2 = ">=0.15.0,<1dev" +six = ">=1.13.0,<2dev" +uritemplate = ">=3.0.0,<4dev" [[package]] name = "google-auth" @@ -2505,7 +2526,7 @@ files = [ ] [package.dependencies] -certifi = ">=14.5.14" +certifi = ">=14.05.14" python-dateutil = ">=2.5.3" reactivex = ">=4.0.4" urllib3 = ">=1.26.0" @@ -2711,7 +2732,7 @@ files = [ [package.dependencies] attrs = ">=22.2.0" -jsonschema-specifications = ">=2023.3.6" +jsonschema-specifications = ">=2023.03.6" referencing = ">=0.28.4" rpds-py = ">=0.7.1" @@ -2734,6 +2755,22 @@ files = [ [package.dependencies] referencing = ">=0.31.0" +[[package]] +name = "jwcrypto" +version = "1.5.6" +description = "Implementation of JOSE Web standards" +optional = false +python-versions = ">= 3.8" +groups = ["main"] +files = [ + {file = "jwcrypto-1.5.6-py3-none-any.whl", hash = "sha256:150d2b0ebbdb8f40b77f543fb44ffd2baeff48788be71f67f03566692fd55789"}, + {file = "jwcrypto-1.5.6.tar.gz", hash = "sha256:771a87762a0c081ae6166958a954f80848820b2ab066937dc8b8379d65b1b039"}, +] + +[package.dependencies] +cryptography = ">=3.4" +typing-extensions = ">=4.5.0" + [[package]] name = "lazy-object-proxy" version = "1.10.0" @@ -3748,7 +3785,7 @@ files = [ ] [package.dependencies] -astroid = ">=2.14.2,<=2.16.0.dev0" +astroid = ">=2.14.2,<=2.16.0-dev0" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = {version = ">=0.3.6", markers = "python_version >= \"3.11\""} isort = ">=4.2.5,<6" @@ -4769,10 +4806,10 @@ files = [ ] [package.dependencies] -botocore = ">=1.33.2,<2.0a0" +botocore = ">=1.33.2,<2.0a.0" [package.extras] -crt = ["botocore[crt] (>=1.33.2,<2.0a0)"] +crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] [[package]] name = "segment-analytics-python" @@ -5413,7 +5450,6 @@ files = [ {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, ] -markers = {auth-controller = "sys_platform == \"win32\"", dev = "sys_platform == \"win32\"", ldap = "sys_platform == \"win32\"", workflows = "sys_platform == \"win32\""} [[package]] name = "uritemplate" @@ -5671,4 +5707,4 @@ files = [ [metadata] lock-version = "2.1" python-versions = ">3.11,<3.14" -content-hash = "83fc1419a52f7553b9c7dbbe1405cb3be83d2727e8a9661450e391282d858ea0" +content-hash = "27858d63787b154e4dd5bf7d976cc324625cc66f25e5b298905ff00662819ba6" diff --git a/api/pyproject.toml b/api/pyproject.toml index f55101f91ad6..2752955e0edc 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -94,6 +94,10 @@ ignore_missing_imports = true module = ["saml.*"] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = ["oauth2_provider.*"] +ignore_missing_imports = true + [tool.django-stubs] django_settings_module = "app.settings.local" @@ -174,6 +178,7 @@ djangorestframework-simplejwt = "^5.5.1" structlog = "^24.4.0" prometheus-client = "^0.21.1" django_cockroachdb = "~4.2" +django-oauth-toolkit = "^3.0.1" [tool.poetry.group.auth-controller] optional = true diff --git a/api/tests/unit/oauth2_metadata/__init__.py b/api/tests/unit/oauth2_metadata/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/api/tests/unit/oauth2_metadata/test_views.py b/api/tests/unit/oauth2_metadata/test_views.py new file mode 100644 index 000000000000..5f371446251b --- /dev/null +++ b/api/tests/unit/oauth2_metadata/test_views.py @@ -0,0 +1,110 @@ +import pytest +from django.test import Client +from django.urls import reverse +from pytest_django.fixtures import SettingsWrapper +from rest_framework import status + +METADATA_URL = "oauth-authorization-server-metadata" + + +@pytest.fixture() +def client() -> Client: + return Client() + + +def test_metadata_endpoint__unauthenticated__returns_200_with_rfc8414_json( + client: Client, + settings: SettingsWrapper, +) -> None: + # Given + settings.FLAGSMITH_API_URL = "https://api.flagsmith.com" + settings.FLAGSMITH_FRONTEND_URL = "https://app.flagsmith.com" + + # When + response = client.get(reverse(METADATA_URL)) + + # Then + assert response.status_code == status.HTTP_200_OK + assert response["Content-Type"] == "application/json" + + data = response.json() + assert data["issuer"] == "https://api.flagsmith.com" + assert ( + data["authorization_endpoint"] == "https://app.flagsmith.com/oauth/authorize/" + ) + assert data["token_endpoint"] == "https://api.flagsmith.com/o/token/" + assert data["registration_endpoint"] == "https://api.flagsmith.com/o/register/" + assert data["revocation_endpoint"] == "https://api.flagsmith.com/o/revoke_token/" + assert data["introspection_endpoint"] == "https://api.flagsmith.com/o/introspect/" + assert data["response_types_supported"] == ["code"] + assert data["grant_types_supported"] == ["authorization_code", "refresh_token"] + assert data["code_challenge_methods_supported"] == ["S256"] + assert "none" in data["token_endpoint_auth_methods_supported"] + assert data["introspection_endpoint_auth_methods_supported"] == ["none"] + + +def test_metadata_endpoint__custom_urls__endpoints_derived_from_settings( + client: Client, + settings: SettingsWrapper, +) -> None: + # Given + settings.FLAGSMITH_API_URL = "https://custom-api.example.com" + settings.FLAGSMITH_FRONTEND_URL = "https://custom-app.example.com" + + # When + response = client.get(reverse(METADATA_URL)) + + # Then + data = response.json() + assert data["issuer"] == "https://custom-api.example.com" + assert data["authorization_endpoint"].startswith("https://custom-app.example.com/") + assert data["token_endpoint"].startswith("https://custom-api.example.com/") + assert data["registration_endpoint"].startswith("https://custom-api.example.com/") + assert data["revocation_endpoint"].startswith("https://custom-api.example.com/") + assert data["introspection_endpoint"].startswith("https://custom-api.example.com/") + + +def test_metadata_endpoint__trailing_slash_in_url__no_double_slash( + client: Client, + settings: SettingsWrapper, +) -> None: + # Given + settings.FLAGSMITH_API_URL = "https://api.flagsmith.com/" + settings.FLAGSMITH_FRONTEND_URL = "https://app.flagsmith.com/" + + # When + response = client.get(reverse(METADATA_URL)) + + # Then + data = response.json() + assert "//" not in data["token_endpoint"].split("://")[1] + assert "//" not in data["authorization_endpoint"].split("://")[1] + + +def test_metadata_endpoint__scopes__reflect_oauth2_provider_settings( + client: Client, + settings: SettingsWrapper, +) -> None: + # Given + settings.OAUTH2_PROVIDER = { + **settings.OAUTH2_PROVIDER, + "SCOPES": {"mcp": "MCP access", "read": "Read access"}, + } + + # When + response = client.get(reverse(METADATA_URL)) + + # Then + data = response.json() + assert set(data["scopes_supported"]) == {"mcp", "read"} + + +def test_metadata_endpoint__post_request__returns_405() -> None: + # Given + csrf_client = Client(enforce_csrf_checks=True) + + # When + response = csrf_client.post(reverse(METADATA_URL)) + + # Then + assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED diff --git a/sdk/openapi.yaml b/sdk/openapi.yaml index b3249230a73f..b5f2f0b9e8a4 100644 --- a/sdk/openapi.yaml +++ b/sdk/openapi.yaml @@ -558,6 +558,8 @@ components: basicAuth: type: http scheme: basic + oauth2: + type: oauth2 tokenAuth: type: apiKey in: header