Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 71 additions & 1 deletion src/environmentd/tests/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2299,7 +2299,8 @@ async fn test_auth_oidc_fetch_error() {
);

// Stop the OIDC server so the JWKS endpoint becomes unreachable.
oidc_server.handle.abort_and_wait().await;
// Dropping the server aborts the Axum task immediately via AbortOnDropHandle.
drop(oidc_server);

let server = test_util::TestHarness::default()
.with_tls(server_cert, server_key)
Expand Down Expand Up @@ -5859,3 +5860,72 @@ async fn test_oidc_group_sync_case_sensitive() {
let val: String = rows[0].get(0);
assert_eq!(val, "top-secret");
}

#[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))]
#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `TLS_method`
async fn test_oidc_generate_jwt_with_groups() {
let (server, admin_client, oidc_server) = setup_group_sync_test().await;

let token = oidc_server.generate_jwt_with_groups(GROUP_SYNC_USER, &["analytics", "data_eng"]);
let _client = oidc_connect(&server, &token)
.await
.expect("login should succeed");

let role_names = fetch_user_role_memberships(&admin_client, GROUP_SYNC_USER).await;
assert_eq!(
role_names,
vec!["analytics", "data_eng"],
"generate_jwt_with_groups should grant the same roles as the raw extra_claims path"
);
}

/// Groups registered via `set_user_groups` are auto-included in `generate_jwt`
/// without any `extra_claims`. Updating the registry and regenerating reflects
/// the new groups. Clearing the registry suppresses the claim.
#[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))]
#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `TLS_method`
async fn test_oidc_set_user_groups_registry() {
let (server, admin_client, oidc_server) = setup_group_sync_test().await;

// Register groups for the user and generate a plain JWT (no extra_claims).
oidc_server.set_user_groups(GROUP_SYNC_USER, &["analytics", "platform_eng"]);
let token = oidc_server.generate_jwt(GROUP_SYNC_USER, GenerateJwtOptions::default());
let _client = oidc_connect(&server, &token)
.await
.expect("login should succeed with auto-injected groups");

let role_names = fetch_user_role_memberships(&admin_client, GROUP_SYNC_USER).await;
assert_eq!(
role_names,
vec!["analytics", "platform_eng"],
"groups from registry should be auto-included in JWT"
);

// Update the registry: only data_eng now. Reconnect and verify sync.
oidc_server.set_user_groups(GROUP_SYNC_USER, &["data_eng"]);
let token = oidc_server.generate_jwt(GROUP_SYNC_USER, GenerateJwtOptions::default());
let _client = oidc_connect(&server, &token)
.await
.expect("login should succeed after registry update");

let role_names = fetch_user_role_memberships(&admin_client, GROUP_SYNC_USER).await;
assert_eq!(
role_names,
vec!["data_eng"],
"updated registry should revoke old roles and grant new ones"
);

// Clear the registry: no groups claim in JWT → sync is skipped, roles preserved.
oidc_server.clear_user_groups(GROUP_SYNC_USER);
let token = oidc_server.generate_jwt(GROUP_SYNC_USER, GenerateJwtOptions::default());
let _client = oidc_connect(&server, &token)
.await
.expect("login should succeed with no groups claim");

let role_names = fetch_user_role_memberships(&admin_client, GROUP_SYNC_USER).await;
assert_eq!(
role_names,
vec!["data_eng"],
"missing groups claim should skip sync and preserve current roles"
);
}
77 changes: 72 additions & 5 deletions src/oidc-mock/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ use std::borrow::Cow;
use std::collections::BTreeMap;
use std::future::IntoFuture;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::sync::{Arc, Mutex};

use axum::extract::State;
use axum::routing::get;
use axum::{Json, Router};
use base64::Engine;
use jsonwebtoken::{EncodingKey, Header, encode};
use mz_ore::now::NowFn;
use mz_ore::task::JoinHandle;
use mz_ore::task::AbortOnDropHandle;
use openssl::pkey::{PKey, Private};
use openssl::rsa::Rsa;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -113,8 +113,12 @@ pub struct OidcMockServer {
pub now: NowFn,
/// How long tokens should be valid (in seconds).
pub expires_in_secs: i64,
/// Handle to the server task.
pub handle: JoinHandle<Result<(), std::io::Error>>,
/// Handle to the server task. Aborts the task when dropped.
pub handle: AbortOnDropHandle<Result<(), std::io::Error>>,
/// Per-user group memberships. When set, `generate_jwt` auto-includes
/// the `groups` claim for any user with a registered group list.
/// Can be overridden by passing `extra_claims` in `GenerateJwtOptions`.
pub user_groups: Arc<Mutex<BTreeMap<String, Vec<String>>>>,
}

impl OidcMockServer {
Expand Down Expand Up @@ -176,7 +180,8 @@ impl OidcMockServer {
);
println!("oidc-mock listening...");
println!(" HTTP address: {}", issuer);
let handle = mz_ore::task::spawn(|| "oidc-mock-server", server.into_future());
let handle =
mz_ore::task::spawn(|| "oidc-mock-server", server.into_future()).abort_on_drop();

Ok(OidcMockServer {
issuer,
Expand All @@ -185,11 +190,16 @@ impl OidcMockServer {
now,
expires_in_secs,
handle,
user_groups: Arc::new(Mutex::new(BTreeMap::new())),
})
}

/// Generates a JWT token for testing.
///
/// If `opts.extra_claims` does not include a `"groups"` key and the user
/// has registered groups via [`Self::set_user_groups`], the groups are
/// automatically included as the `"groups"` claim.
///
/// # Arguments
///
/// * `sub` - Subject (user identifier).
Expand All @@ -204,6 +214,23 @@ impl OidcMockServer {
.map(|(k, v)| (k, serde_json::Value::String(v)))
.collect();

// Auto-include groups from the registry unless the caller already
// provided a "groups" key in extra_claims.
let extra_has_groups = opts
.extra_claims
.as_ref()
.is_some_and(|e| e.contains_key("groups"));
if !extra_has_groups {
if let Some(groups) = self.user_groups.lock().unwrap().get(sub) {
let arr: serde_json::Value = groups
.iter()
.map(|g| serde_json::Value::String(g.clone()))
.collect::<Vec<_>>()
.into();
unknown_claims.insert("groups".to_string(), arr);
}
}

if let Some(extra) = opts.extra_claims {
unknown_claims.extend(extra);
}
Expand All @@ -222,6 +249,46 @@ impl OidcMockServer {
encode(&header, &claims, &self.encoding_key).expect("failed to encode JWT")
}

/// Generates a JWT with an explicit `groups` claim.
///
/// Shorthand for passing `groups` via [`GenerateJwtOptions::extra_claims`].
///
/// # Arguments
///
/// * `sub` - Subject (user identifier).
/// * `groups` - Group names to include in the `"groups"` claim.
pub fn generate_jwt_with_groups(&self, sub: &str, groups: &[&str]) -> String {
let groups_val: serde_json::Value = groups
.iter()
.map(|g| serde_json::Value::String(g.to_string()))
.collect::<Vec<_>>()
.into();
self.generate_jwt(
sub,
GenerateJwtOptions {
extra_claims: Some(BTreeMap::from([("groups".to_string(), groups_val)])),
..Default::default()
},
)
}

/// Registers the group memberships for a user.
///
/// After calling this, [`Self::generate_jwt`] will automatically include the
/// `"groups"` claim for `sub` unless overridden by `extra_claims`.
/// Passing an empty slice clears all groups for the user.
pub fn set_user_groups(&self, sub: &str, groups: &[&str]) {
self.user_groups.lock().unwrap().insert(
sub.to_string(),
groups.iter().map(|s| s.to_string()).collect(),
);
}

/// Removes all registered group memberships for a user.
pub fn clear_user_groups(&self, sub: &str) {
self.user_groups.lock().unwrap().remove(sub);
}

/// Returns the JWKS URL for this server.
pub fn jwks_url(&self) -> String {
format!("{}/.well-known/jwks.json", self.issuer)
Expand Down
Loading