diff --git a/api/app/settings/common.py b/api/app/settings/common.py index c75fae2fbf34..e82fd1b8241a 100644 --- a/api/app/settings/common.py +++ b/api/app/settings/common.py @@ -88,6 +88,7 @@ "rest_framework.authtoken", # Used for managing api keys "rest_framework_api_key", + "oauth2_provider", "rest_framework_simplejwt.token_blacklist", "djoser", "django.contrib.sites", @@ -165,6 +166,7 @@ "softdelete", "metadata", "app_analytics", + "oauth2_metadata", ] SILENCED_SYSTEM_CHECKS = ["axes.W002"] @@ -312,6 +314,7 @@ "custom_auth.jwt_cookie.authentication.JWTCookieAuthentication", "rest_framework.authentication.TokenAuthentication", "api_keys.authentication.MasterAPIKeyAuthentication", + "oauth2_metadata.authentication.OAuth2BearerTokenAuthentication", ), "PAGE_SIZE": 10, "UNICODE_JSON": False, @@ -941,6 +944,26 @@ "SIGNING_KEY": env.str("COOKIE_AUTH_JWT_SIGNING_KEY", default=SECRET_KEY), } +# OAuth 2.1 Provider (django-oauth-toolkit) +FLAGSMITH_API_URL = env.str("FLAGSMITH_API_URL", default="http://localhost:8000") +FLAGSMITH_FRONTEND_URL = env.str( + "FLAGSMITH_FRONTEND_URL", default="http://localhost:8080" +) + +OAUTH2_PROVIDER = { + "ACCESS_TOKEN_EXPIRE_SECONDS": 60 * 15, # 15 minutes + "REFRESH_TOKEN_EXPIRE_SECONDS": 60 * 60 * 24 * 30, # 30 days + "ROTATE_REFRESH_TOKEN": True, + "PKCE_REQUIRED": True, + "ALLOWED_CODE_CHALLENGE_METHODS": ["S256"], + "SCOPES": {"mcp": "MCP access"}, + "DEFAULT_SCOPES": ["mcp"], + "ALLOWED_GRANT_TYPES": [ + "authorization_code", + "refresh_token", + ], +} + # Github OAuth credentials GITHUB_CLIENT_ID = env.str("GITHUB_CLIENT_ID", default="") GITHUB_CLIENT_SECRET = env.str("GITHUB_CLIENT_SECRET", default="") diff --git a/api/app/urls.py b/api/app/urls.py index d5b68f85f6ca..b9a7e1181b32 100644 --- a/api/app/urls.py +++ b/api/app/urls.py @@ -6,6 +6,7 @@ from django.urls import include, path, re_path from django.views.generic.base import TemplateView +from oauth2_metadata.views import authorization_server_metadata from users.views import password_reset_redirect from . import views @@ -13,6 +14,11 @@ urlpatterns = [ *core_urlpatterns, path("processor/", include("task_processor.urls")), + path( + ".well-known/oauth-authorization-server", + authorization_server_metadata, + name="oauth-authorization-server-metadata", + ), ] if not settings.TASK_PROCESSOR_MODE: @@ -47,6 +53,8 @@ "robots.txt", TemplateView.as_view(template_name="robots.txt", content_type="text/plain"), ), + # Authorize template view for testing: this will be moved to the frontend in following issues + path("o/", include("oauth2_provider.urls", namespace="oauth2_provider")), ] if settings.DEBUG: # pragma: no cover diff --git a/api/oauth2_metadata/__init__.py b/api/oauth2_metadata/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/api/oauth2_metadata/apps.py b/api/oauth2_metadata/apps.py new file mode 100644 index 000000000000..531858ca6ec9 --- /dev/null +++ b/api/oauth2_metadata/apps.py @@ -0,0 +1,8 @@ +from django.apps import AppConfig + + +class OAuth2MetadataConfig(AppConfig): + name = "oauth2_metadata" + + def ready(self) -> None: + from oauth2_metadata import tasks # noqa: F401 diff --git a/api/oauth2_metadata/authentication.py b/api/oauth2_metadata/authentication.py new file mode 100644 index 000000000000..d46d76500a97 --- /dev/null +++ b/api/oauth2_metadata/authentication.py @@ -0,0 +1,17 @@ +from oauth2_provider.contrib.rest_framework import ( + OAuth2Authentication, # type: ignore[import-untyped] +) +from rest_framework.request import Request + + +class OAuth2BearerTokenAuthentication(OAuth2Authentication): # type: ignore[misc] + """DOT's default OAuth2Authentication also reads the request body + looking for an access_token, which consumes the stream and breaks + views that need to read request.body. + """ + + def authenticate(self, request: Request) -> tuple[object, str] | None: + auth_header = request.META.get("HTTP_AUTHORIZATION", "") + if not auth_header.startswith("Bearer "): + return None + return super().authenticate(request) # type: ignore[return-value] diff --git a/api/oauth2_metadata/tasks.py b/api/oauth2_metadata/tasks.py new file mode 100644 index 000000000000..372078267f9e --- /dev/null +++ b/api/oauth2_metadata/tasks.py @@ -0,0 +1,9 @@ +from datetime import timedelta + +from django.core.management import call_command +from task_processor.decorators import register_recurring_task + + +@register_recurring_task(run_every=timedelta(hours=24)) +def clear_expired_oauth2_tokens() -> None: + call_command("cleartokens") diff --git a/api/oauth2_metadata/views.py b/api/oauth2_metadata/views.py new file mode 100644 index 000000000000..25cbc77071d5 --- /dev/null +++ b/api/oauth2_metadata/views.py @@ -0,0 +1,37 @@ +from typing import Any + +from django.conf import settings +from django.http import HttpRequest, JsonResponse +from django.views.decorators.csrf import csrf_exempt +from django.views.decorators.http import require_GET + + +@csrf_exempt +@require_GET +def authorization_server_metadata(request: HttpRequest) -> JsonResponse: + """RFC 8414 OAuth 2.0 Authorization Server Metadata.""" + api_url: str = settings.FLAGSMITH_API_URL.rstrip("/") + frontend_url: str = settings.FLAGSMITH_FRONTEND_URL.rstrip("/") + oauth2_settings: dict[str, Any] = settings.OAUTH2_PROVIDER + scopes: dict[str, str] = oauth2_settings.get("SCOPES", {}) + + metadata = { + "issuer": api_url, + "authorization_endpoint": f"{frontend_url}/oauth/authorize/", + "token_endpoint": f"{api_url}/o/token/", + "registration_endpoint": f"{api_url}/o/register/", + "revocation_endpoint": f"{api_url}/o/revoke_token/", + "introspection_endpoint": f"{api_url}/o/introspect/", + "scopes_supported": list(scopes.keys()), + "response_types_supported": ["code"], + "grant_types_supported": ["authorization_code", "refresh_token"], + "code_challenge_methods_supported": ["S256"], + "token_endpoint_auth_methods_supported": [ + "client_secret_basic", + "client_secret_post", + "none", + ], + "introspection_endpoint_auth_methods_supported": ["none"], + } + + return JsonResponse(metadata) diff --git a/api/oauth2_test_server.mjs b/api/oauth2_test_server.mjs new file mode 100644 index 000000000000..1d0029a59faf --- /dev/null +++ b/api/oauth2_test_server.mjs @@ -0,0 +1,81 @@ +import { createServer } from "node:http"; +import { randomBytes, createHash } from "node:crypto"; + +const CLIENT_ID = "ZLsLu3hhJI4GlhNsGeFVC3K2U3QBGfXtmc0EcyiG"; +const REDIRECT_URI = "http://localhost:3000/oauth/callback"; +const API_URL = "http://localhost:8000"; +const PORT = 3000; + +// Generate PKCE values +const codeVerifier = randomBytes(96).toString("base64url").slice(0, 128); +const codeChallenge = createHash("sha256") + .update(codeVerifier) + .digest("base64url"); + +const authorizeUrl = + `${API_URL}/o/authorize/?` + + new URLSearchParams({ + response_type: "code", + client_id: CLIENT_ID, + redirect_uri: REDIRECT_URI, + scope: "mcp", + code_challenge: codeChallenge, + code_challenge_method: "S256", + }); + +const server = createServer(async (req, res) => { + const url = new URL(req.url, `http://localhost:${PORT}`); + + if (url.pathname === "/oauth/callback") { + const code = url.searchParams.get("code"); + const error = url.searchParams.get("error"); + + if (error) { + res.writeHead(400, { "Content-Type": "text/plain" }); + res.end(`Error: ${error}\n${url.searchParams.get("error_description")}`); + return; + } + + if (!code) { + res.writeHead(400, { "Content-Type": "text/plain" }); + res.end("No authorization code received"); + return; + } + + console.log(`\nReceived authorization code: ${code}`); + console.log("Exchanging for token...\n"); + + // Exchange code for token + const tokenRes = await fetch(`${API_URL}/o/token/`, { + method: "POST", + headers: { "Content-Type": "application/x-www-form-urlencoded" }, + body: new URLSearchParams({ + grant_type: "authorization_code", + code, + redirect_uri: REDIRECT_URI, + client_id: CLIENT_ID, + code_verifier: codeVerifier, + }), + }); + + const tokenData = await tokenRes.json(); + console.log("Token response:", JSON.stringify(tokenData, null, 2)); + + res.writeHead(200, { "Content-Type": "text/html" }); + res.end(`
${JSON.stringify(tokenData, null, 2)}`);
+
+ // Done - shut down
+ setTimeout(() => {
+ console.log("\nDone. Shutting down.");
+ process.exit(0);
+ }, 1000);
+ } else {
+ res.writeHead(302, { Location: authorizeUrl });
+ res.end();
+ }
+});
+
+server.listen(PORT, () => {
+ console.log(`OAuth test server running on http://localhost:${PORT}`);
+ console.log(`\nOpen http://localhost:${PORT} in your browser to start the flow.\n`);
+});
diff --git a/api/poetry.lock b/api/poetry.lock
index 870e0544c6ab..b4ee146b2494 100644
--- a/api/poetry.lock
+++ b/api/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 2.3.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
[[package]]
name = "annotated-types"
@@ -1504,6 +1504,27 @@ files = [
[package.dependencies]
django = ">=3.2"
+[[package]]
+name = "django-oauth-toolkit"
+version = "3.1.0"
+description = "OAuth2 Provider for Django"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "django_oauth_toolkit-3.1.0-py3-none-any.whl", hash = "sha256:10ddc90804297d913dfb958edd58d5fac541eb1ca912f47893ca1e482bb2a11f"},
+ {file = "django_oauth_toolkit-3.1.0.tar.gz", hash = "sha256:d5a59d07588cfefa8818e99d65040a252eb2ede22512483e2240c91d0b885c8e"},
+]
+
+[package.dependencies]
+django = ">=4.2"
+jwcrypto = ">=1.5.0"
+oauthlib = ">=3.2.2"
+requests = ">=2.13.0"
+
+[package.extras]
+dev = ["m2r", "pytest", "pytest-cov", "sphinx-rtd-theme"]
+
[[package]]
name = "django-ordered-model"
version = "3.4.3"
@@ -2187,12 +2208,12 @@ files = [
]
[package.dependencies]
-google-api-core = {version = ">=1.21.0,<3.dev0", markers = "python_version >= \"3\""}
-google-auth = {version = ">=1.16.0,<3.dev0", markers = "python_version >= \"3\""}
+google-api-core = {version = ">=1.21.0,<3dev", markers = "python_version >= \"3\""}
+google-auth = {version = ">=1.16.0,<3dev", markers = "python_version >= \"3\""}
google-auth-httplib2 = ">=0.0.3"
-httplib2 = ">=0.15.0,<1.dev0"
-six = ">=1.13.0,<2.dev0"
-uritemplate = ">=3.0.0,<4.dev0"
+httplib2 = ">=0.15.0,<1dev"
+six = ">=1.13.0,<2dev"
+uritemplate = ">=3.0.0,<4dev"
[[package]]
name = "google-auth"
@@ -2505,7 +2526,7 @@ files = [
]
[package.dependencies]
-certifi = ">=14.5.14"
+certifi = ">=14.05.14"
python-dateutil = ">=2.5.3"
reactivex = ">=4.0.4"
urllib3 = ">=1.26.0"
@@ -2711,7 +2732,7 @@ files = [
[package.dependencies]
attrs = ">=22.2.0"
-jsonschema-specifications = ">=2023.3.6"
+jsonschema-specifications = ">=2023.03.6"
referencing = ">=0.28.4"
rpds-py = ">=0.7.1"
@@ -2734,6 +2755,22 @@ files = [
[package.dependencies]
referencing = ">=0.31.0"
+[[package]]
+name = "jwcrypto"
+version = "1.5.6"
+description = "Implementation of JOSE Web standards"
+optional = false
+python-versions = ">= 3.8"
+groups = ["main"]
+files = [
+ {file = "jwcrypto-1.5.6-py3-none-any.whl", hash = "sha256:150d2b0ebbdb8f40b77f543fb44ffd2baeff48788be71f67f03566692fd55789"},
+ {file = "jwcrypto-1.5.6.tar.gz", hash = "sha256:771a87762a0c081ae6166958a954f80848820b2ab066937dc8b8379d65b1b039"},
+]
+
+[package.dependencies]
+cryptography = ">=3.4"
+typing-extensions = ">=4.5.0"
+
[[package]]
name = "lazy-object-proxy"
version = "1.10.0"
@@ -3748,7 +3785,7 @@ files = [
]
[package.dependencies]
-astroid = ">=2.14.2,<=2.16.0.dev0"
+astroid = ">=2.14.2,<=2.16.0-dev0"
colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""}
dill = {version = ">=0.3.6", markers = "python_version >= \"3.11\""}
isort = ">=4.2.5,<6"
@@ -4769,10 +4806,10 @@ files = [
]
[package.dependencies]
-botocore = ">=1.33.2,<2.0a0"
+botocore = ">=1.33.2,<2.0a.0"
[package.extras]
-crt = ["botocore[crt] (>=1.33.2,<2.0a0)"]
+crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"]
[[package]]
name = "segment-analytics-python"
@@ -5413,7 +5450,6 @@ files = [
{file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"},
{file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"},
]
-markers = {auth-controller = "sys_platform == \"win32\"", dev = "sys_platform == \"win32\"", ldap = "sys_platform == \"win32\"", workflows = "sys_platform == \"win32\""}
[[package]]
name = "uritemplate"
@@ -5671,4 +5707,4 @@ files = [
[metadata]
lock-version = "2.1"
python-versions = ">3.11,<3.14"
-content-hash = "83fc1419a52f7553b9c7dbbe1405cb3be83d2727e8a9661450e391282d858ea0"
+content-hash = "27858d63787b154e4dd5bf7d976cc324625cc66f25e5b298905ff00662819ba6"
diff --git a/api/pyproject.toml b/api/pyproject.toml
index f55101f91ad6..2752955e0edc 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -94,6 +94,10 @@ ignore_missing_imports = true
module = ["saml.*"]
ignore_missing_imports = true
+[[tool.mypy.overrides]]
+module = ["oauth2_provider.*"]
+ignore_missing_imports = true
+
[tool.django-stubs]
django_settings_module = "app.settings.local"
@@ -174,6 +178,7 @@ djangorestframework-simplejwt = "^5.5.1"
structlog = "^24.4.0"
prometheus-client = "^0.21.1"
django_cockroachdb = "~4.2"
+django-oauth-toolkit = "^3.0.1"
[tool.poetry.group.auth-controller]
optional = true
diff --git a/api/tests/unit/oauth2_metadata/__init__.py b/api/tests/unit/oauth2_metadata/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/api/tests/unit/oauth2_metadata/test_views.py b/api/tests/unit/oauth2_metadata/test_views.py
new file mode 100644
index 000000000000..5f371446251b
--- /dev/null
+++ b/api/tests/unit/oauth2_metadata/test_views.py
@@ -0,0 +1,110 @@
+import pytest
+from django.test import Client
+from django.urls import reverse
+from pytest_django.fixtures import SettingsWrapper
+from rest_framework import status
+
+METADATA_URL = "oauth-authorization-server-metadata"
+
+
+@pytest.fixture()
+def client() -> Client:
+ return Client()
+
+
+def test_metadata_endpoint__unauthenticated__returns_200_with_rfc8414_json(
+ client: Client,
+ settings: SettingsWrapper,
+) -> None:
+ # Given
+ settings.FLAGSMITH_API_URL = "https://api.flagsmith.com"
+ settings.FLAGSMITH_FRONTEND_URL = "https://app.flagsmith.com"
+
+ # When
+ response = client.get(reverse(METADATA_URL))
+
+ # Then
+ assert response.status_code == status.HTTP_200_OK
+ assert response["Content-Type"] == "application/json"
+
+ data = response.json()
+ assert data["issuer"] == "https://api.flagsmith.com"
+ assert (
+ data["authorization_endpoint"] == "https://app.flagsmith.com/oauth/authorize/"
+ )
+ assert data["token_endpoint"] == "https://api.flagsmith.com/o/token/"
+ assert data["registration_endpoint"] == "https://api.flagsmith.com/o/register/"
+ assert data["revocation_endpoint"] == "https://api.flagsmith.com/o/revoke_token/"
+ assert data["introspection_endpoint"] == "https://api.flagsmith.com/o/introspect/"
+ assert data["response_types_supported"] == ["code"]
+ assert data["grant_types_supported"] == ["authorization_code", "refresh_token"]
+ assert data["code_challenge_methods_supported"] == ["S256"]
+ assert "none" in data["token_endpoint_auth_methods_supported"]
+ assert data["introspection_endpoint_auth_methods_supported"] == ["none"]
+
+
+def test_metadata_endpoint__custom_urls__endpoints_derived_from_settings(
+ client: Client,
+ settings: SettingsWrapper,
+) -> None:
+ # Given
+ settings.FLAGSMITH_API_URL = "https://custom-api.example.com"
+ settings.FLAGSMITH_FRONTEND_URL = "https://custom-app.example.com"
+
+ # When
+ response = client.get(reverse(METADATA_URL))
+
+ # Then
+ data = response.json()
+ assert data["issuer"] == "https://custom-api.example.com"
+ assert data["authorization_endpoint"].startswith("https://custom-app.example.com/")
+ assert data["token_endpoint"].startswith("https://custom-api.example.com/")
+ assert data["registration_endpoint"].startswith("https://custom-api.example.com/")
+ assert data["revocation_endpoint"].startswith("https://custom-api.example.com/")
+ assert data["introspection_endpoint"].startswith("https://custom-api.example.com/")
+
+
+def test_metadata_endpoint__trailing_slash_in_url__no_double_slash(
+ client: Client,
+ settings: SettingsWrapper,
+) -> None:
+ # Given
+ settings.FLAGSMITH_API_URL = "https://api.flagsmith.com/"
+ settings.FLAGSMITH_FRONTEND_URL = "https://app.flagsmith.com/"
+
+ # When
+ response = client.get(reverse(METADATA_URL))
+
+ # Then
+ data = response.json()
+ assert "//" not in data["token_endpoint"].split("://")[1]
+ assert "//" not in data["authorization_endpoint"].split("://")[1]
+
+
+def test_metadata_endpoint__scopes__reflect_oauth2_provider_settings(
+ client: Client,
+ settings: SettingsWrapper,
+) -> None:
+ # Given
+ settings.OAUTH2_PROVIDER = {
+ **settings.OAUTH2_PROVIDER,
+ "SCOPES": {"mcp": "MCP access", "read": "Read access"},
+ }
+
+ # When
+ response = client.get(reverse(METADATA_URL))
+
+ # Then
+ data = response.json()
+ assert set(data["scopes_supported"]) == {"mcp", "read"}
+
+
+def test_metadata_endpoint__post_request__returns_405() -> None:
+ # Given
+ csrf_client = Client(enforce_csrf_checks=True)
+
+ # When
+ response = csrf_client.post(reverse(METADATA_URL))
+
+ # Then
+ assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
diff --git a/sdk/openapi.yaml b/sdk/openapi.yaml
index b3249230a73f..b5f2f0b9e8a4 100644
--- a/sdk/openapi.yaml
+++ b/sdk/openapi.yaml
@@ -558,6 +558,8 @@ components:
basicAuth:
type: http
scheme: basic
+ oauth2:
+ type: oauth2
tokenAuth:
type: apiKey
in: header