From f11f06ec366c597438dab77879d6f9c7a9b25ef3 Mon Sep 17 00:00:00 2001 From: Nitzan Uziely Date: Wed, 6 May 2026 00:27:40 +0300 Subject: [PATCH] feat(auth): add AWS security credentials supplier support Add a programmatic AWS external account builder that accepts caller-provided AWS region and security credentials. This lets applications delegate AWS credential resolution to the AWS SDK while google-cloud-auth builds the AWS subject token and performs the Google token exchange. Also add supplier-path tests and censor sensitive AWS credential fields in Debug output. --- src/auth/src/credentials/external_account.rs | 513 +++++++++++++++++- .../external_account_sources/aws_sourced.rs | 422 +++++++++++--- 2 files changed, 828 insertions(+), 107 deletions(-) diff --git a/src/auth/src/credentials/external_account.rs b/src/auth/src/credentials/external_account.rs index 105869fadf..78b808e68f 100644 --- a/src/auth/src/credentials/external_account.rs +++ b/src/auth/src/credentials/external_account.rs @@ -106,7 +106,9 @@ //! [Obtain short-lived tokens for Workforce Identity Federation]: https://cloud.google.com/iam/docs/workforce-obtaining-short-lived-credentials#use_configuration_files_for_sign-in use super::dynamic::CredentialsProvider; -use super::external_account_sources::aws_sourced::AwsSourcedCredentials; +use super::external_account_sources::aws_sourced::{ + AwsSourcedCredentials, AwsSupplierSourcedCredentials, +}; use super::external_account_sources::executable_sourced::ExecutableSourcedCredentials; use super::external_account_sources::file_sourced::FileSourcedCredentials; use super::external_account_sources::url_sourced::UrlSourcedCredentials; @@ -137,6 +139,63 @@ use std::sync::Arc; use tokio::time::{Duration, Instant}; const IAM_SCOPE: &str = "https://www.googleapis.com/auth/iam"; +const AWS_SUBJECT_TOKEN_TYPE: &str = "urn:ietf:params:aws:token-type:aws4_request"; + +/// AWS security credentials used to sign a Workload Identity Federation subject token. +#[derive(Clone, PartialEq, Eq)] +pub struct AwsSecurityCredentials { + /// The AWS access key ID. + pub access_key_id: String, + /// The AWS secret access key. + pub secret_access_key: String, + /// The optional AWS session token for temporary credentials. + pub session_token: Option, +} + +impl std::fmt::Debug for AwsSecurityCredentials { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AwsSecurityCredentials") + .field("access_key_id", &self.access_key_id) + .field("secret_access_key", &"[censored]") + .field( + "session_token", + &self.session_token.as_ref().map(|_| "[censored]"), + ) + .finish() + } +} + +/// Options passed to an [`AwsSecurityCredentialsSupplier`] call. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SupplierOptions { + /// The target audience for the Google STS token exchange. + pub audience: String, + /// The subject token type configured for the external account. + pub subject_token_type: String, +} + +/// Supplies AWS region and security credentials for AWS Workload Identity Federation. +/// +/// The supplier is responsible for resolving AWS credentials, for example by using +/// the AWS SDK default provider chain. The authentication library uses the supplied +/// values to build the signed AWS `GetCallerIdentity` subject token and exchange it +/// with Google STS. Exchanged Google tokens are cached by this crate, but supplier +/// results are not cached independently; suppliers should cache AWS lookups when +/// that is useful for their environment. +#[async_trait::async_trait] +pub trait AwsSecurityCredentialsSupplier: std::fmt::Debug + Send + Sync { + /// Returns the AWS region used to sign the AWS STS request. + async fn aws_region( + &self, + options: SupplierOptions, + ) -> std::result::Result; + + /// Returns the AWS security credentials used to sign the AWS STS request. + async fn aws_security_credentials( + &self, + options: SupplierOptions, + ) -> std::result::Result; +} #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub(crate) struct CredentialSourceFormat { @@ -289,6 +348,7 @@ struct ExternalAccountConfigBuilder { subject_token_type: Option, token_url: Option, service_account_impersonation_url: Option, + target_principal: Option, client_id: Option, client_secret: Option, scopes: Option>, @@ -318,6 +378,11 @@ impl ExternalAccountConfigBuilder { self } + fn with_target_principal>(mut self, target_principal: S) -> Self { + self.target_principal = Some(target_principal.into()); + self + } + fn with_client_id>(mut self, client_id: S) -> Self { self.client_id = Some(client_id.into()); self @@ -362,6 +427,14 @@ impl ExternalAccountConfigBuilder { )); } + let universe_domain = self.universe_domain; + let service_account_impersonation_url = + self.service_account_impersonation_url.or_else(|| { + self.target_principal.map(|principal| { + service_account_impersonation_url(&principal, universe_domain.as_deref()) + }) + }); + Ok(ExternalAccountConfig { audience, subject_token_type: self @@ -374,11 +447,11 @@ impl ExternalAccountConfigBuilder { credential_source: self .credential_source .ok_or(BuilderError::missing_field("credential_source"))?, - service_account_impersonation_url: self.service_account_impersonation_url, + service_account_impersonation_url, client_id: self.client_id, client_secret: self.client_secret, workforce_pool_user_project: self.workforce_pool_user_project, - universe_domain: self.universe_domain, + universe_domain, }) } } @@ -389,6 +462,7 @@ enum CredentialSource { Executable(ExecutableSourcedCredentials), File(FileSourcedCredentials), Aws(AwsSourcedCredentials), + AwsSupplier(AwsSupplierSourcedCredentials), Programmatic(ProgrammaticSourcedCredentials), } @@ -415,6 +489,9 @@ impl ExternalAccountConfig { CredentialSource::Aws(source) => { Self::make_credentials_from_source(source, config, quota_project_id, retry_builder) } + CredentialSource::AwsSupplier(source) => { + Self::make_credentials_from_source(source, config, quota_project_id, retry_builder) + } } } @@ -809,16 +886,287 @@ impl Builder { } let config: ExternalAccountConfig = file.into(); + Ok(make_credentials_with_access_boundary( + config, + self.quota_project_id, + self.retry_builder, + self.iam_endpoint_override.as_deref(), + )) + } +} - let access_boundary_url = - external_account_lookup_url(&config.audience, self.iam_endpoint_override.as_deref()); +fn service_account_impersonation_url( + target_principal: &str, + universe_domain: Option<&str>, +) -> String { + let universe_domain = universe_domain.unwrap_or(DEFAULT_UNIVERSE_DOMAIN); + format!( + "https://iamcredentials.{universe_domain}/v1/projects/-/serviceAccounts/{target_principal}:generateAccessToken" + ) +} - let creds = config.make_credentials(self.quota_project_id, self.retry_builder); +fn make_credentials_with_access_boundary( + config: ExternalAccountConfig, + quota_project_id: Option, + retry_builder: RetryTokenProviderBuilder, + iam_endpoint_override: Option<&str>, +) -> CredentialsWithAccessBoundary> { + let access_boundary_url = external_account_lookup_url(&config.audience, iam_endpoint_override); + let creds = config.make_credentials(quota_project_id, retry_builder); + CredentialsWithAccessBoundary::new(creds, access_boundary_url) +} - Ok(CredentialsWithAccessBoundary::new( - creds, - access_boundary_url, - )) +fn fill_programmatic_defaults( + mut config_builder: ExternalAccountConfigBuilder, +) -> ExternalAccountConfigBuilder { + if config_builder.scopes.is_none() { + config_builder = config_builder.with_scopes(vec![DEFAULT_SCOPE.to_string()]); + } + if config_builder.token_url.is_none() { + let mut token_url = STS_TOKEN_URL.to_string(); + if let Some(ref ud) = config_builder.universe_domain { + if ud != DEFAULT_UNIVERSE_DOMAIN { + token_url = token_url.replace(DEFAULT_UNIVERSE_DOMAIN, ud); + } + } + config_builder = config_builder.with_token_url(token_url); + } + config_builder +} + +/// A builder for AWS external account [Credentials] using caller-provided AWS credentials. +/// +/// Use this builder when the application wants to resolve AWS credentials itself, +/// for example through the AWS SDK default provider chain, while this crate still +/// builds the AWS SigV4 subject token and performs the Google STS token exchange. +/// +/// The supplier is called when this crate needs a fresh AWS subject token. It +/// may cache AWS SDK lookups internally, but it should return credentials that +/// are valid for signing an AWS STS `GetCallerIdentity` request. +/// +/// # Example +/// +/// ```no_run +/// # use google_cloud_auth::credentials::external_account::{ +/// # AwsExternalAccountBuilder, AwsSecurityCredentials, AwsSecurityCredentialsSupplier, +/// # SupplierOptions, +/// # }; +/// # use google_cloud_auth::errors::CredentialsError; +/// # use std::sync::Arc; +/// #[derive(Debug)] +/// struct AwsSdkSupplier { +/// fallback_region: String, +/// } +/// +/// #[async_trait::async_trait] +/// impl AwsSecurityCredentialsSupplier for AwsSdkSupplier { +/// async fn aws_region( +/// &self, +/// _options: SupplierOptions, +/// ) -> Result { +/// // Resolve this from AWS SDK config in a real implementation. +/// Ok(self.fallback_region.clone()) +/// } +/// +/// async fn aws_security_credentials( +/// &self, +/// _options: SupplierOptions, +/// ) -> Result { +/// // Retrieve these from the AWS SDK default credentials chain. +/// Ok(AwsSecurityCredentials { +/// access_key_id: "AWS_ACCESS_KEY_ID".to_string(), +/// secret_access_key: "AWS_SECRET_ACCESS_KEY".to_string(), +/// session_token: Some("AWS_SESSION_TOKEN".to_string()), +/// }) +/// } +/// } +/// +/// # async fn sample() -> anyhow::Result<()> { +/// let supplier = Arc::new(AwsSdkSupplier { +/// fallback_region: "us-east-1".to_string(), +/// }); +/// +/// let credentials = AwsExternalAccountBuilder::new(supplier) +/// .with_audience("//iam.googleapis.com/projects/123/locations/global/workloadIdentityPools/pool/providers/provider") +/// .with_target_principal("service-account@project.iam.gserviceaccount.com") +/// .with_scopes(["https://www.googleapis.com/auth/cloud-platform"]) +/// .build()?; +/// +/// # let _ = credentials; +/// # Ok(()) } +/// ``` +#[derive(Debug)] +pub struct AwsExternalAccountBuilder { + quota_project_id: Option, + config: ExternalAccountConfigBuilder, + retry_builder: RetryTokenProviderBuilder, + supplier: Arc, + regional_cred_verification_url: Option, +} + +impl AwsExternalAccountBuilder { + /// Creates a new builder using the provided AWS security credentials supplier. + pub fn new(supplier: Arc) -> Self { + Self { + quota_project_id: None, + config: ExternalAccountConfigBuilder::default(), + retry_builder: RetryTokenProviderBuilder::default(), + supplier, + regional_cred_verification_url: None, + } + } + + /// Sets the optional quota project for these credentials. + pub fn with_quota_project_id>(mut self, quota_project_id: S) -> Self { + self.quota_project_id = Some(quota_project_id.into()); + self + } + + /// Overrides the optional scopes for these credentials. + pub fn with_scopes(mut self, scopes: I) -> Self + where + I: IntoIterator, + S: Into, + { + self.config = self.config.with_scopes( + scopes + .into_iter() + .map(|s| s.into()) + .collect::>(), + ); + self + } + + /// Sets the Google Cloud universe domain for these credentials. + pub fn with_universe_domain>(mut self, universe_domain: S) -> Self { + self.config = self.config.with_universe_domain(universe_domain); + self + } + + /// Sets the required audience for the token exchange. + pub fn with_audience>(mut self, audience: S) -> Self { + self.config = self.config.with_audience(audience); + self + } + + /// Sets the subject token type. Defaults to the AWS SigV4 request token type. + pub fn with_subject_token_type>(mut self, subject_token_type: S) -> Self { + self.config = self.config.with_subject_token_type(subject_token_type); + self + } + + /// Sets the STS token URL. Defaults to `https://sts.googleapis.com/v1/token`. + pub fn with_token_url>(mut self, token_url: S) -> Self { + self.config = self.config.with_token_url(token_url); + self + } + + /// Sets the optional client ID for client authentication. + pub fn with_client_id>(mut self, client_id: S) -> Self { + self.config = self.config.with_client_id(client_id.into()); + self + } + + /// Sets the optional client secret for client authentication. + pub fn with_client_secret>(mut self, client_secret: S) -> Self { + self.config = self.config.with_client_secret(client_secret.into()); + self + } + + /// Sets the service account email to impersonate after the STS token exchange. + pub fn with_target_principal>(mut self, target_principal: S) -> Self { + self.config = self.config.with_target_principal(target_principal); + self + } + + /// Sets the workforce pool user project for workforce identity federation. + pub fn with_workforce_pool_user_project>(mut self, project: S) -> Self { + self.config = self.config.with_workforce_pool_user_project(project); + self + } + + /// Sets the AWS STS regional credential verification URL template. + /// + /// The template may contain `{region}`, which is replaced with the supplier-provided + /// AWS region. The URL should be an AWS STS `GetCallerIdentity` endpoint, + /// such as + /// `https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15`. + pub fn with_regional_cred_verification_url>(mut self, url: S) -> Self { + self.regional_cred_verification_url = Some(url.into()); + self + } + + /// Configure the retry policy for fetching tokens. + pub fn with_retry_policy>(mut self, v: V) -> Self { + self.retry_builder = self.retry_builder.with_retry_policy(v.into()); + self + } + + /// Configure the retry backoff policy. + pub fn with_backoff_policy>(mut self, v: V) -> Self { + self.retry_builder = self.retry_builder.with_backoff_policy(v.into()); + self + } + + /// Configure the retry throttler. + pub fn with_retry_throttler>(mut self, v: V) -> Self { + self.retry_builder = self.retry_builder.with_retry_throttler(v.into()); + self + } + + /// Returns a [Credentials] instance with the configured settings. + /// + /// # Errors + /// + /// Returns a [BuilderError] if any required field, such as `audience`, is missing. + pub fn build(self) -> BuildResult { + let (config, quota_project_id, retry_builder) = self.build_components()?; + Ok( + make_credentials_with_access_boundary(config, quota_project_id, retry_builder, None) + .into(), + ) + } + + fn build_components( + self, + ) -> BuildResult<( + ExternalAccountConfig, + Option, + RetryTokenProviderBuilder, + )> { + let Self { + quota_project_id, + config, + retry_builder, + supplier, + regional_cred_verification_url, + } = self; + + let mut config_builder = fill_programmatic_defaults(config); + if config_builder.subject_token_type.is_none() { + config_builder = config_builder.with_subject_token_type(AWS_SUBJECT_TOKEN_TYPE); + } + + let audience = config_builder + .audience + .clone() + .ok_or(BuilderError::missing_field("audience"))?; + let subject_token_type = config_builder + .subject_token_type + .clone() + .ok_or(BuilderError::missing_field("subject_token_type"))?; + + let source = AwsSupplierSourcedCredentials::new( + supplier, + audience, + subject_token_type, + regional_cred_verification_url, + ); + let final_config = config_builder + .with_credential_source(CredentialSource::AwsSupplier(source)) + .build()?; + + Ok((final_config, quota_project_id, retry_builder)) } } @@ -1473,19 +1821,7 @@ impl ProgrammaticBuilder { retry_builder, } = self; - let mut config_builder = config; - if config_builder.scopes.is_none() { - config_builder = config_builder.with_scopes(vec![DEFAULT_SCOPE.to_string()]); - } - if config_builder.token_url.is_none() { - let mut token_url = STS_TOKEN_URL.to_string(); - if let Some(ref ud) = config_builder.universe_domain { - if ud != DEFAULT_UNIVERSE_DOMAIN { - token_url = token_url.replace(DEFAULT_UNIVERSE_DOMAIN, ud); - } - } - config_builder = config_builder.with_token_url(token_url); - } + let config_builder = fill_programmatic_defaults(config); let final_config = config_builder.build()?; Ok((final_config, quota_project_id, retry_builder)) @@ -1570,6 +1906,49 @@ mod tests { } } + #[derive(Debug)] + struct TestAwsSupplier; + + #[async_trait::async_trait] + impl AwsSecurityCredentialsSupplier for TestAwsSupplier { + async fn aws_region( + &self, + _options: SupplierOptions, + ) -> std::result::Result { + Ok("us-east-1".to_string()) + } + + async fn aws_security_credentials( + &self, + _options: SupplierOptions, + ) -> std::result::Result { + Ok(AwsSecurityCredentials { + access_key_id: "test-access-key".to_string(), + secret_access_key: "test-secret".to_string(), + session_token: None, + }) + } + } + + #[test] + fn aws_security_credentials_debug_censors_secrets() { + let creds = AwsSecurityCredentials { + access_key_id: "test-access-key".to_string(), + secret_access_key: "secret-access-key-should-not-appear".to_string(), + session_token: Some("session-token-should-not-appear".to_string()), + }; + + let got = format!("{creds:?}"); + + assert!(got.contains("test-access-key"), "{got}"); + assert!( + !got.contains("secret-access-key-should-not-appear"), + "{got}" + ); + assert!(!got.contains("session-token-should-not-appear"), "{got}"); + assert!(got.contains("[censored]"), "{got}"); + } + #[tokio::test] async fn create_external_account_builder() { let contents = json!({ @@ -1683,6 +2062,94 @@ mod tests { ); } + #[tokio::test] + async fn test_programmatic_builder_impersonation_url_preserves_default_universe() { + let provider = Arc::new(TestSubjectTokenProvider); + let builder = ProgrammaticBuilder::new(provider) + .with_audience("test-audience") + .with_subject_token_type("test-token-type") + .with_target_principal("test-principal") + .with_universe_domain("my-custom-universe.com"); + + let (config, _, _) = builder.build_components().unwrap(); + + assert_eq!( + config.service_account_impersonation_url, + Some("https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/test-principal:generateAccessToken".to_string()) + ); + } + + #[tokio::test] + async fn create_aws_external_account_builder() { + let supplier = Arc::new(TestAwsSupplier); + let builder = AwsExternalAccountBuilder::new(supplier) + .with_audience("test-audience") + .with_scopes(["scope1", "scope2"]) + .with_token_url("http://custom.com/token") + .with_target_principal("test-principal") + .with_regional_cred_verification_url( + "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + ); + + let (config, _, _) = builder.build_components().unwrap(); + + assert_eq!(config.audience, "test-audience"); + assert_eq!(config.subject_token_type, AWS_SUBJECT_TOKEN_TYPE); + assert_eq!(config.scopes, vec!["scope1", "scope2"]); + assert_eq!(config.token_url, "http://custom.com/token"); + assert_eq!( + config.service_account_impersonation_url, + Some("https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/test-principal:generateAccessToken".to_string()) + ); + match config.credential_source { + CredentialSource::AwsSupplier(_) => {} + _ => unreachable!("expected AWS supplier sourced credential"), + } + } + + #[tokio::test] + async fn aws_external_account_builder_wraps_with_access_boundary() { + let supplier = Arc::new(TestAwsSupplier); + let creds = AwsExternalAccountBuilder::new(supplier) + .with_audience("test-audience") + .build() + .unwrap(); + + let fmt = format!("{creds:?}"); + assert!(fmt.contains("CredentialsWithAccessBoundary"), "{fmt}"); + } + + #[tokio::test] + async fn aws_external_account_builder_sts_url_updates_with_universe_domain() { + let supplier = Arc::new(TestAwsSupplier); + let builder = AwsExternalAccountBuilder::new(supplier) + .with_audience("test-audience") + .with_universe_domain("my-custom-universe.com"); + + let (config, _, _) = builder.build_components().unwrap(); + + assert_eq!( + config.token_url, + "https://sts.my-custom-universe.com/v1/token" + ); + } + + #[tokio::test] + async fn aws_external_account_builder_impersonation_url_updates_with_universe_domain() { + let supplier = Arc::new(TestAwsSupplier); + let builder = AwsExternalAccountBuilder::new(supplier) + .with_audience("test-audience") + .with_target_principal("test-principal") + .with_universe_domain("my-custom-universe.com"); + + let (config, _, _) = builder.build_components().unwrap(); + + assert_eq!( + config.service_account_impersonation_url, + Some("https://iamcredentials.my-custom-universe.com/v1/projects/-/serviceAccounts/test-principal:generateAccessToken".to_string()) + ); + } + #[tokio::test] async fn create_external_account_detect_url_sourced() { let contents = json!({ diff --git a/src/auth/src/credentials/external_account_sources/aws_sourced.rs b/src/auth/src/credentials/external_account_sources/aws_sourced.rs index ef0a2e9406..a0ed04315a 100644 --- a/src/auth/src/credentials/external_account_sources/aws_sourced.rs +++ b/src/auth/src/credentials/external_account_sources/aws_sourced.rs @@ -14,6 +14,9 @@ use crate::{ Result, + credentials::external_account::{ + AwsSecurityCredentials, AwsSecurityCredentialsSupplier, SupplierOptions, + }, credentials::subject_token::{ Builder as SubjectTokenBuilder, SubjectToken, SubjectTokenProvider, }, @@ -26,6 +29,7 @@ use reqwest::{Client, Response}; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use std::collections::BTreeMap; +use std::sync::Arc; const AWS_REGION: &str = "AWS_REGION"; const AWS_DEFAULT_REGION: &str = "AWS_DEFAULT_REGION"; @@ -89,7 +93,7 @@ impl AwsSourcedCredentials { } #[derive(Debug, Deserialize)] -struct AwsSecurityCredentials { +struct AwsMetadataSecurityCredentials { #[serde(rename = "AccessKeyId")] access_key_id: String, #[serde(rename = "SecretAccessKey")] @@ -98,6 +102,16 @@ struct AwsSecurityCredentials { token: Option, } +impl From for AwsSecurityCredentials { + fn from(value: AwsMetadataSecurityCredentials) -> Self { + Self { + access_key_id: value.access_key_id, + secret_access_key: value.secret_access_key, + session_token: value.token, + } + } +} + #[derive(Serialize)] struct AwsStsRequest { url: String, @@ -127,95 +141,154 @@ impl SubjectTokenProvider for AwsSourcedCredentials { .resolve_credentials(&client, imdsv2_token.as_deref()) .await?; - let now = Utc::now(); - let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string(); - let date_stamp = now.format("%Y%m%d").to_string(); - - let url = resolve_sts_url(self.regional_cred_verification_url.as_deref(), ®ion)?; - let host = url.host_str().unwrap(); // unwrap is safe because resolve_sts_url checks for a host - let sts_url = url.to_string(); - - let method = "POST"; - let body = ""; - let canonical_uri = "/"; - - let query_params: BTreeMap<_, _> = url.query_pairs().collect(); - let canonical_query = url::form_urlencoded::Serializer::new(String::new()) - .extend_pairs(query_params) - .finish(); - - let mut headers = BTreeMap::new(); - headers.insert("host".to_string(), host.to_string()); - headers.insert(X_AMZ_DATE.to_string(), amz_date.clone()); - if let Some(token) = &creds.token { - headers.insert(X_AMZ_SECURITY_TOKEN.to_string(), token.clone()); - } - headers.insert( - X_GOOG_CLOUD_TARGET_RESOURCE.to_string(), - self.audience.clone(), - ); + build_aws_subject_token( + &self.audience, + ®ion, + &creds, + self.regional_cred_verification_url.as_deref(), + ) + } +} - let signed_headers = headers.keys().cloned().collect::>().join(";"); - let canonical_headers = headers.iter().fold(String::new(), |mut acc, (k, v)| { - acc.push_str(&format!("{}:{}\n", k, v.trim())); - acc - }); +pub(crate) fn build_aws_subject_token( + audience: &str, + region: &str, + creds: &AwsSecurityCredentials, + regional_cred_verification_url: Option<&str>, +) -> Result { + let now = Utc::now(); + let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string(); + let date_stamp = now.format("%Y%m%d").to_string(); + + let url = resolve_sts_url(regional_cred_verification_url, region)?; + let host = url.host_str().unwrap(); // unwrap is safe because resolve_sts_url checks for a host + let sts_url = url.to_string(); + + let method = "POST"; + let body = ""; + let canonical_uri = "/"; + + let query_params: BTreeMap<_, _> = url.query_pairs().collect(); + let canonical_query = url::form_urlencoded::Serializer::new(String::new()) + .extend_pairs(query_params) + .finish(); + + let mut headers = BTreeMap::new(); + headers.insert("host".to_string(), host.to_string()); + headers.insert(X_AMZ_DATE.to_string(), amz_date.clone()); + if let Some(token) = &creds.session_token { + headers.insert(X_AMZ_SECURITY_TOKEN.to_string(), token.clone()); + } + headers.insert( + X_GOOG_CLOUD_TARGET_RESOURCE.to_string(), + audience.to_string(), + ); + + let signed_headers = headers.keys().cloned().collect::>().join(";"); + let canonical_headers = headers.iter().fold(String::new(), |mut acc, (k, v)| { + acc.push_str(&format!("{}:{}\n", k, v.trim())); + acc + }); + + let payload_hash = hash_sha256(body); + + let canonical_request = format!( + "{}\n{}\n{}\n{}\n{}\n{}", + method, canonical_uri, canonical_query, canonical_headers, signed_headers, payload_hash + ); + + let credential_scope = format!( + "{}/{}/{}/{}", + date_stamp, region, AWS_STS_SERVICE, AWS4_REQUEST + ); + let string_to_sign = format!( + "{}\n{}\n{}\n{}", + AWS4_HMAC_SHA256, + amz_date, + credential_scope, + hash_sha256(&canonical_request) + ); + + let signing_key = get_signing_key( + &creds.secret_access_key, + &date_stamp, + region, + AWS_STS_SERVICE, + )?; + let signature = hex::encode(hmac_sha256(&signing_key, &string_to_sign)?); + + let authorization_header = format!( + "{} Credential={}/{}, SignedHeaders={}, Signature={}", + AWS4_HMAC_SHA256, creds.access_key_id, credential_scope, signed_headers, signature + ); + + let final_headers: Vec<_> = headers + .into_iter() + .map(|(key, value)| AwsHeader { key, value }) + .chain(std::iter::once(AwsHeader { + key: "Authorization".to_string(), + value: authorization_header, + })) + .collect(); + + let aws_sts_request = AwsStsRequest { + url: sts_url, + method: method.to_string(), + headers: final_headers, + body: body.to_string(), + }; - let payload_hash = hash_sha256(body); + let json_token = serde_json::to_string(&aws_sts_request) + .map_err(|e| CredentialsError::from_source(false, e))?; - let canonical_request = format!( - "{}\n{}\n{}\n{}\n{}\n{}", - method, canonical_uri, canonical_query, canonical_headers, signed_headers, payload_hash - ); + let subject_token: String = + url::form_urlencoded::byte_serialize(json_token.as_bytes()).collect(); - let credential_scope = format!( - "{}/{}/{}/{}", - date_stamp, region, AWS_STS_SERVICE, AWS4_REQUEST - ); - let string_to_sign = format!( - "{}\n{}\n{}\n{}", - AWS4_HMAC_SHA256, - amz_date, - credential_scope, - hash_sha256(&canonical_request) - ); + Ok(SubjectTokenBuilder::new(subject_token).build()) +} - let signing_key = get_signing_key( - &creds.secret_access_key, - &date_stamp, - ®ion, - AWS_STS_SERVICE, - )?; - let signature = hex::encode(hmac_sha256(&signing_key, &string_to_sign)?); +/// Credential source for AWS workloads using caller-provided security credentials. +#[derive(Debug, Clone)] +pub(crate) struct AwsSupplierSourcedCredentials { + supplier: Arc, + audience: String, + subject_token_type: String, + regional_cred_verification_url: Option, +} - let authorization_header = format!( - "{} Credential={}/{}, SignedHeaders={}, Signature={}", - AWS4_HMAC_SHA256, creds.access_key_id, credential_scope, signed_headers, signature - ); +impl AwsSupplierSourcedCredentials { + pub(crate) fn new( + supplier: Arc, + audience: String, + subject_token_type: String, + regional_cred_verification_url: Option, + ) -> Self { + Self { + supplier, + audience, + subject_token_type, + regional_cred_verification_url, + } + } +} - let final_headers: Vec<_> = headers - .into_iter() - .map(|(key, value)| AwsHeader { key, value }) - .chain(std::iter::once(AwsHeader { - key: "Authorization".to_string(), - value: authorization_header, - })) - .collect(); +impl SubjectTokenProvider for AwsSupplierSourcedCredentials { + type Error = CredentialsError; - let aws_sts_request = AwsStsRequest { - url: sts_url, - method: method.to_string(), - headers: final_headers, - body: body.to_string(), + async fn subject_token(&self) -> Result { + let options = SupplierOptions { + audience: self.audience.clone(), + subject_token_type: self.subject_token_type.clone(), }; + let region = self.supplier.aws_region(options.clone()).await?; + let creds = self.supplier.aws_security_credentials(options).await?; - let json_token = serde_json::to_string(&aws_sts_request) - .map_err(|e| CredentialsError::from_source(false, e))?; - - let subject_token: String = - url::form_urlencoded::byte_serialize(json_token.as_bytes()).collect(); - - Ok(SubjectTokenBuilder::new(subject_token).build()) + build_aws_subject_token( + &self.audience, + ®ion, + &creds, + self.regional_cred_verification_url.as_deref(), + ) } } @@ -402,11 +475,11 @@ impl AwsSourcedCredentials { ) .await?; - let creds = response + let creds: AwsMetadataSecurityCredentials = response .json() .await .map_err(|e| errors::from_http_error(e, "failed to parse AWS credentials JSON"))?; - return Ok(creds); + return Ok(creds.into()); } Err(CredentialsError::from_msg( false, @@ -426,7 +499,7 @@ impl AwsSourcedCredentials { return Ok(AwsSecurityCredentials { access_key_id: ak, secret_access_key: sk, - token: std::env::var(AWS_SESSION_TOKEN).ok(), + session_token: std::env::var(AWS_SESSION_TOKEN).ok(), }); } @@ -450,6 +523,67 @@ mod tests { type TestResult = std::result::Result<(), Box>; + const TEST_AWS_SUBJECT_TOKEN_TYPE: &str = "urn:ietf:params:aws:token-type:aws4_request"; + + fn decode_subject_token_json( + subject_token: SubjectToken, + ) -> std::result::Result> { + let decoded_json: String = url::form_urlencoded::parse(subject_token.token.as_bytes()) + .map(|(k, _)| k) + .collect(); + Ok(serde_json::from_str(&decoded_json)?) + } + + #[derive(Debug)] + struct StaticAwsSupplier { + region: String, + credentials: AwsSecurityCredentials, + } + + #[async_trait::async_trait] + impl AwsSecurityCredentialsSupplier for StaticAwsSupplier { + async fn aws_region( + &self, + options: SupplierOptions, + ) -> std::result::Result { + assert_eq!(options.audience, "supplier-audience"); + assert_eq!(options.subject_token_type, TEST_AWS_SUBJECT_TOKEN_TYPE); + Ok(self.region.clone()) + } + + async fn aws_security_credentials( + &self, + options: SupplierOptions, + ) -> std::result::Result { + assert_eq!(options.audience, "supplier-audience"); + assert_eq!(options.subject_token_type, TEST_AWS_SUBJECT_TOKEN_TYPE); + Ok(self.credentials.clone()) + } + } + + #[derive(Debug)] + struct FailingCredentialsSupplier; + + #[async_trait::async_trait] + impl AwsSecurityCredentialsSupplier for FailingCredentialsSupplier { + async fn aws_region( + &self, + _options: SupplierOptions, + ) -> std::result::Result { + Ok("us-east-1".to_string()) + } + + async fn aws_security_credentials( + &self, + _options: SupplierOptions, + ) -> std::result::Result { + Err(CredentialsError::from_msg( + false, + "supplier credentials failed", + )) + } + } + #[test_case("us-east-1a", Some("us-east-1"); "zone_to_region")] #[test_case("us-east-1", Some("us-east-1"); "already_region")] #[test_case("us-gov-west-1a", Some("us-gov-west-1"); "gov_zone_to_region")] @@ -464,6 +598,7 @@ mod tests { #[test_case(None, "us-east-1", "https://sts.us-east-1.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15"; "default_template")] #[test_case(Some("http://custom.sts.url/{region}"), "us-west-2", "http://custom.sts.url/us-west-2"; "custom_template_with_region")] + #[test_case(Some("http://custom.sts.url/{region}?Action=GetCallerIdentity&Version=2011-06-15"), "us-west-2", "http://custom.sts.url/us-west-2?Action=GetCallerIdentity&Version=2011-06-15"; "custom_template_get_caller_identity_with_region")] #[test_case(Some("sts.amazonaws.com"), "us-east-1", "https://sts.amazonaws.com/"; "no_scheme")] #[test_case(Some("https://sts.amazonaws.com"), "us-east-1", "https://sts.amazonaws.com/"; "with_scheme")] fn test_resolve_sts_url(template: Option<&str>, region: &str, expected: &str) { @@ -477,6 +612,125 @@ mod tests { assert!(result.is_err()); } + #[tokio::test] + #[parallel] + async fn test_supplier_subject_token_success() -> TestResult { + let supplier = Arc::new(StaticAwsSupplier { + region: "eu-west-1".to_string(), + credentials: AwsSecurityCredentials { + access_key_id: "SUPPLIER_ACCESS_KEY".to_string(), + secret_access_key: "SUPPLIER_SECRET".to_string(), + session_token: Some("SUPPLIER_SESSION_TOKEN".to_string()), + }, + }); + let creds = AwsSupplierSourcedCredentials::new( + supplier, + "supplier-audience".to_string(), + TEST_AWS_SUBJECT_TOKEN_TYPE.to_string(), + Some( + "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15" + .to_string(), + ), + ); + + let val = decode_subject_token_json(creds.subject_token().await?)?; + + assert_eq!( + val["url"], + "https://sts.eu-west-1.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15", + "{val:?}" + ); + let headers = val["headers"] + .as_array() + .ok_or("headers should be an array")?; + + let session_token = headers + .iter() + .find(|h| h["key"] == X_AMZ_SECURITY_TOKEN) + .ok_or("missing session token header")?; + assert_eq!(session_token["value"], "SUPPLIER_SESSION_TOKEN", "{val:?}"); + + let target_resource = headers + .iter() + .find(|h| h["key"] == X_GOOG_CLOUD_TARGET_RESOURCE) + .ok_or("missing target resource header")?; + assert_eq!(target_resource["value"], "supplier-audience", "{val:?}"); + + let auth = headers + .iter() + .find(|h| h["key"] == "Authorization") + .ok_or("missing auth header")?; + let auth_value = auth["value"].as_str().ok_or("auth should be a string")?; + assert!(auth_value.contains("SUPPLIER_ACCESS_KEY"), "{auth:?}"); + assert!( + auth_value.contains("/eu-west-1/sts/aws4_request"), + "{auth:?}" + ); + assert!( + auth_value.contains( + "SignedHeaders=host;x-amz-date;x-amz-security-token;x-goog-cloud-target-resource" + ), + "{auth:?}" + ); + + Ok(()) + } + + #[tokio::test] + #[parallel] + async fn test_supplier_subject_token_without_session_token() -> TestResult { + let supplier = Arc::new(StaticAwsSupplier { + region: "eu-west-1".to_string(), + credentials: AwsSecurityCredentials { + access_key_id: "SUPPLIER_ACCESS_KEY".to_string(), + secret_access_key: "SUPPLIER_SECRET".to_string(), + session_token: None, + }, + }); + let creds = AwsSupplierSourcedCredentials::new( + supplier, + "supplier-audience".to_string(), + TEST_AWS_SUBJECT_TOKEN_TYPE.to_string(), + None, + ); + + let val = decode_subject_token_json(creds.subject_token().await?)?; + let headers = val["headers"] + .as_array() + .ok_or("headers should be an array")?; + + assert!( + headers.iter().all(|h| h["key"] != X_AMZ_SECURITY_TOKEN), + "{headers:?}" + ); + let auth = headers + .iter() + .find(|h| h["key"] == "Authorization") + .ok_or("missing auth header")?; + let auth_value = auth["value"].as_str().ok_or("auth should be a string")?; + assert!( + auth_value.contains("SignedHeaders=host;x-amz-date;x-goog-cloud-target-resource"), + "{auth:?}" + ); + + Ok(()) + } + + #[tokio::test] + #[parallel] + async fn test_supplier_error_propagates() { + let creds = AwsSupplierSourcedCredentials::new( + Arc::new(FailingCredentialsSupplier), + "supplier-audience".to_string(), + TEST_AWS_SUBJECT_TOKEN_TYPE.to_string(), + None, + ); + + let err = creds.subject_token().await.unwrap_err(); + assert!(err.to_string().contains("supplier credentials failed")); + assert!(!err.is_transient()); + } + #[tokio::test] #[serial] async fn test_resolve_region_env() -> TestResult { @@ -576,7 +830,7 @@ mod tests { assert_eq!(resolved.access_key_id, "ACCESS_KEY_ID_IMDS", "{resolved:?}"); assert_eq!(resolved.secret_access_key, "SECRET_IMDS", "{resolved:?}"); assert_eq!( - resolved.token, + resolved.session_token, Some("TOKEN_IMDS".to_string()), "{resolved:?}" );