diff --git a/src/environmentd/tests/auth.rs b/src/environmentd/tests/auth.rs index 235021da80ef5..a7ab676a3cf94 100644 --- a/src/environmentd/tests/auth.rs +++ b/src/environmentd/tests/auth.rs @@ -2299,7 +2299,8 @@ async fn test_auth_oidc_fetch_error() { ); // Stop the OIDC server so the JWKS endpoint becomes unreachable. - oidc_server.handle.abort_and_wait().await; + // Dropping the server aborts the Axum task immediately via AbortOnDropHandle. + drop(oidc_server); let server = test_util::TestHarness::default() .with_tls(server_cert, server_key) @@ -5859,3 +5860,72 @@ async fn test_oidc_group_sync_case_sensitive() { let val: String = rows[0].get(0); assert_eq!(val, "top-secret"); } + +#[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `TLS_method` +async fn test_oidc_generate_jwt_with_groups() { + let (server, admin_client, oidc_server) = setup_group_sync_test().await; + + let token = oidc_server.generate_jwt_with_groups(GROUP_SYNC_USER, &["analytics", "data_eng"]); + let _client = oidc_connect(&server, &token) + .await + .expect("login should succeed"); + + let role_names = fetch_user_role_memberships(&admin_client, GROUP_SYNC_USER).await; + assert_eq!( + role_names, + vec!["analytics", "data_eng"], + "generate_jwt_with_groups should grant the same roles as the raw extra_claims path" + ); +} + +/// Groups registered via `set_user_groups` are auto-included in `generate_jwt` +/// without any `extra_claims`. Updating the registry and regenerating reflects +/// the new groups. Clearing the registry suppresses the claim. +#[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `TLS_method` +async fn test_oidc_set_user_groups_registry() { + let (server, admin_client, oidc_server) = setup_group_sync_test().await; + + // Register groups for the user and generate a plain JWT (no extra_claims). + oidc_server.set_user_groups(GROUP_SYNC_USER, &["analytics", "platform_eng"]); + let token = oidc_server.generate_jwt(GROUP_SYNC_USER, GenerateJwtOptions::default()); + let _client = oidc_connect(&server, &token) + .await + .expect("login should succeed with auto-injected groups"); + + let role_names = fetch_user_role_memberships(&admin_client, GROUP_SYNC_USER).await; + assert_eq!( + role_names, + vec!["analytics", "platform_eng"], + "groups from registry should be auto-included in JWT" + ); + + // Update the registry: only data_eng now. Reconnect and verify sync. + oidc_server.set_user_groups(GROUP_SYNC_USER, &["data_eng"]); + let token = oidc_server.generate_jwt(GROUP_SYNC_USER, GenerateJwtOptions::default()); + let _client = oidc_connect(&server, &token) + .await + .expect("login should succeed after registry update"); + + let role_names = fetch_user_role_memberships(&admin_client, GROUP_SYNC_USER).await; + assert_eq!( + role_names, + vec!["data_eng"], + "updated registry should revoke old roles and grant new ones" + ); + + // Clear the registry: no groups claim in JWT → sync is skipped, roles preserved. + oidc_server.clear_user_groups(GROUP_SYNC_USER); + let token = oidc_server.generate_jwt(GROUP_SYNC_USER, GenerateJwtOptions::default()); + let _client = oidc_connect(&server, &token) + .await + .expect("login should succeed with no groups claim"); + + let role_names = fetch_user_role_memberships(&admin_client, GROUP_SYNC_USER).await; + assert_eq!( + role_names, + vec!["data_eng"], + "missing groups claim should skip sync and preserve current roles" + ); +} diff --git a/src/oidc-mock/src/lib.rs b/src/oidc-mock/src/lib.rs index 73eda5dc264f2..03bee4f695bbf 100644 --- a/src/oidc-mock/src/lib.rs +++ b/src/oidc-mock/src/lib.rs @@ -16,7 +16,7 @@ use std::borrow::Cow; use std::collections::BTreeMap; use std::future::IntoFuture; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use axum::extract::State; use axum::routing::get; @@ -24,7 +24,7 @@ use axum::{Json, Router}; use base64::Engine; use jsonwebtoken::{EncodingKey, Header, encode}; use mz_ore::now::NowFn; -use mz_ore::task::JoinHandle; +use mz_ore::task::AbortOnDropHandle; use openssl::pkey::{PKey, Private}; use openssl::rsa::Rsa; use serde::{Deserialize, Serialize}; @@ -113,8 +113,12 @@ pub struct OidcMockServer { pub now: NowFn, /// How long tokens should be valid (in seconds). pub expires_in_secs: i64, - /// Handle to the server task. - pub handle: JoinHandle>, + /// Handle to the server task. Aborts the task when dropped. + pub handle: AbortOnDropHandle>, + /// Per-user group memberships. When set, `generate_jwt` auto-includes + /// the `groups` claim for any user with a registered group list. + /// Can be overridden by passing `extra_claims` in `GenerateJwtOptions`. + pub user_groups: Arc>>>, } impl OidcMockServer { @@ -176,7 +180,8 @@ impl OidcMockServer { ); println!("oidc-mock listening..."); println!(" HTTP address: {}", issuer); - let handle = mz_ore::task::spawn(|| "oidc-mock-server", server.into_future()); + let handle = + mz_ore::task::spawn(|| "oidc-mock-server", server.into_future()).abort_on_drop(); Ok(OidcMockServer { issuer, @@ -185,11 +190,16 @@ impl OidcMockServer { now, expires_in_secs, handle, + user_groups: Arc::new(Mutex::new(BTreeMap::new())), }) } /// Generates a JWT token for testing. /// + /// If `opts.extra_claims` does not include a `"groups"` key and the user + /// has registered groups via [`Self::set_user_groups`], the groups are + /// automatically included as the `"groups"` claim. + /// /// # Arguments /// /// * `sub` - Subject (user identifier). @@ -204,6 +214,23 @@ impl OidcMockServer { .map(|(k, v)| (k, serde_json::Value::String(v))) .collect(); + // Auto-include groups from the registry unless the caller already + // provided a "groups" key in extra_claims. + let extra_has_groups = opts + .extra_claims + .as_ref() + .is_some_and(|e| e.contains_key("groups")); + if !extra_has_groups { + if let Some(groups) = self.user_groups.lock().unwrap().get(sub) { + let arr: serde_json::Value = groups + .iter() + .map(|g| serde_json::Value::String(g.clone())) + .collect::>() + .into(); + unknown_claims.insert("groups".to_string(), arr); + } + } + if let Some(extra) = opts.extra_claims { unknown_claims.extend(extra); } @@ -222,6 +249,46 @@ impl OidcMockServer { encode(&header, &claims, &self.encoding_key).expect("failed to encode JWT") } + /// Generates a JWT with an explicit `groups` claim. + /// + /// Shorthand for passing `groups` via [`GenerateJwtOptions::extra_claims`]. + /// + /// # Arguments + /// + /// * `sub` - Subject (user identifier). + /// * `groups` - Group names to include in the `"groups"` claim. + pub fn generate_jwt_with_groups(&self, sub: &str, groups: &[&str]) -> String { + let groups_val: serde_json::Value = groups + .iter() + .map(|g| serde_json::Value::String(g.to_string())) + .collect::>() + .into(); + self.generate_jwt( + sub, + GenerateJwtOptions { + extra_claims: Some(BTreeMap::from([("groups".to_string(), groups_val)])), + ..Default::default() + }, + ) + } + + /// Registers the group memberships for a user. + /// + /// After calling this, [`Self::generate_jwt`] will automatically include the + /// `"groups"` claim for `sub` unless overridden by `extra_claims`. + /// Passing an empty slice clears all groups for the user. + pub fn set_user_groups(&self, sub: &str, groups: &[&str]) { + self.user_groups.lock().unwrap().insert( + sub.to_string(), + groups.iter().map(|s| s.to_string()).collect(), + ); + } + + /// Removes all registered group memberships for a user. + pub fn clear_user_groups(&self, sub: &str) { + self.user_groups.lock().unwrap().remove(sub); + } + /// Returns the JWKS URL for this server. pub fn jwks_url(&self) -> String { format!("{}/.well-known/jwks.json", self.issuer)