From 0bb924e3f97376c920cfbeb92caa9926d37ce79b Mon Sep 17 00:00:00 2001 From: Antoine Charbonneau Date: Wed, 20 May 2026 12:52:10 -0400 Subject: [PATCH 1/2] feat(iam): add ECS container credential discovery --- iam.go | 58 +++++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 49 insertions(+), 9 deletions(-) diff --git a/iam.go b/iam.go index 4f657ad..50641b1 100644 --- a/iam.go +++ b/iam.go @@ -8,24 +8,24 @@ import ( "fmt" "io" "net/http" + "os" "time" ) const ( - imdsTokenHeader = "X-aws-ec2-metadata-token" - imdsTokenTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds" - metadataBaseURL = "http://169.254.169.254/latest" - securityCredentialsURI = "/meta-data/iam/security-credentials/" - imdsTokenURI = "/api/token" - defaultIMDSTokenTTL = "60" + imdsTokenHeader = "X-aws-ec2-metadata-token" + imdsTokenTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds" + metadataBaseURL = "http://169.254.169.254/latest" + securityCredentialsURI = "/meta-data/iam/security-credentials/" + imdsTokenURI = "/api/token" + defaultIMDSTokenTTL = "60" + ecsContainerCredentialsEnv = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" + ecsContainerCredentialsBaseUrl = "http://169.254.170.2" ) // IAMResponse is used by NewUsingIAM to auto // detect the credentials. type IAMResponse struct { - Code string `json:"Code"` - LastUpdated string `json:"LastUpdated"` - Type string `json:"Type"` AccessKeyID string `json:"AccessKeyId"` SecretAccessKey string `json:"SecretAccessKey"` Token string `json:"Token"` @@ -77,11 +77,51 @@ func fetchIMDSToken(cl *http.Client, baseURL string) (string, bool, error) { return string(token), true, nil } +// fetchIAMDataECS fetches the IAM credentials from the ECS default endpoint. +func fetchIAMDataForEcs(cl *http.Client) (IAMResponse, error) { + env, isSet := os.LookupEnv(ecsContainerCredentialsEnv) + + if !isSet { + return IAMResponse{}, fmt.Errorf("error getting ecs environment variable, most likely not using ecs.") + } + + url := ecsContainerCredentialsBaseUrl + env + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return IAMResponse{}, fmt.Errorf("error creating IAM ECS request: %w", err) + } + + resp, err := cl.Do(req) + if err != nil { + return IAMResponse{}, fmt.Errorf("error fetching IAM ECS request: %w", err) + } + + var jResp IAMResponse + jsonString, err := io.ReadAll(resp.Body) + if err != nil { + return IAMResponse{}, fmt.Errorf("error reading role data: %w", err) + } + + if err := json.Unmarshal(jsonString, &jResp); err != nil { + return IAMResponse{}, fmt.Errorf("error unmarshalling role data: %w (%s)", err, jsonString) + } + + return jResp, nil +} + // fetchIAMData fetches the IAM data from the given URL. // In case of a normal AWS setup, baseURL would be metadataBaseURL. // You can use this method, to manually fetch IAM data from a custom // endpoint and pass it to SetIAMData. func fetchIAMData(cl *http.Client, baseURL string) (IAMResponse, error) { + response, err := fetchIAMDataForEcs(cl) + + // If already have ECS response, skip the rest of the function and use it instead. + if err == nil { + return response, nil + } + token, useIMDSv2, err := fetchIMDSToken(cl, baseURL) if err != nil { return IAMResponse{}, fmt.Errorf("error fetching IMDSv2 token: %w", err) From c42ce2eb87fad33f03088fd7ca032d59caae9fd6 Mon Sep 17 00:00:00 2001 From: Rohan Verma Date: Thu, 21 May 2026 23:29:03 +0530 Subject: [PATCH 2/2] fix(iam): preserve compatibility and add ECS coverage --- iam.go | 41 +++++++++++----- iam_test.go | 134 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+), 11 deletions(-) diff --git a/iam.go b/iam.go index 50641b1..162a3b9 100644 --- a/iam.go +++ b/iam.go @@ -5,6 +5,7 @@ package simples3 import ( "encoding/json" + "errors" "fmt" "io" "net/http" @@ -13,19 +14,25 @@ import ( ) const ( - imdsTokenHeader = "X-aws-ec2-metadata-token" - imdsTokenTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds" - metadataBaseURL = "http://169.254.169.254/latest" - securityCredentialsURI = "/meta-data/iam/security-credentials/" - imdsTokenURI = "/api/token" - defaultIMDSTokenTTL = "60" - ecsContainerCredentialsEnv = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" - ecsContainerCredentialsBaseUrl = "http://169.254.170.2" + imdsTokenHeader = "X-aws-ec2-metadata-token" + imdsTokenTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds" + metadataBaseURL = "http://169.254.169.254/latest" + securityCredentialsURI = "/meta-data/iam/security-credentials/" + imdsTokenURI = "/api/token" + defaultIMDSTokenTTL = "60" + ecsContainerCredentialsEnv = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" ) +var ecsContainerCredentialsBaseURL = "http://169.254.170.2" + +var errECSCredentialsEndpointNotSet = errors.New("ecs credentials endpoint not set") + // IAMResponse is used by NewUsingIAM to auto // detect the credentials. type IAMResponse struct { + Code string `json:"Code"` + LastUpdated string `json:"LastUpdated"` + Type string `json:"Type"` AccessKeyID string `json:"AccessKeyId"` SecretAccessKey string `json:"SecretAccessKey"` Token string `json:"Token"` @@ -80,12 +87,11 @@ func fetchIMDSToken(cl *http.Client, baseURL string) (string, bool, error) { // fetchIAMDataECS fetches the IAM credentials from the ECS default endpoint. func fetchIAMDataForEcs(cl *http.Client) (IAMResponse, error) { env, isSet := os.LookupEnv(ecsContainerCredentialsEnv) - if !isSet { - return IAMResponse{}, fmt.Errorf("error getting ecs environment variable, most likely not using ecs.") + return IAMResponse{}, errECSCredentialsEndpointNotSet } - url := ecsContainerCredentialsBaseUrl + env + url := ecsContainerCredentialsBaseURL + env req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { @@ -96,6 +102,14 @@ func fetchIAMDataForEcs(cl *http.Client) (IAMResponse, error) { if err != nil { return IAMResponse{}, fmt.Errorf("error fetching IAM ECS request: %w", err) } + defer func() { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + return IAMResponse{}, fmt.Errorf("error fetching IAM ECS data: %s", resp.Status) + } var jResp IAMResponse jsonString, err := io.ReadAll(resp.Body) @@ -121,6 +135,11 @@ func fetchIAMData(cl *http.Client, baseURL string) (IAMResponse, error) { if err == nil { return response, nil } + // If ECS credentials are explicitly configured, do not fall back to IMDS. + // Falling back here can pick up the instance role instead of the task role. + if !errors.Is(err, errECSCredentialsEndpointNotSet) { + return IAMResponse{}, err + } token, useIMDSv2, err := fetchIMDSToken(cl, baseURL) if err != nil { diff --git a/iam_test.go b/iam_test.go index 2752b04..3d8e3e5 100644 --- a/iam_test.go +++ b/iam_test.go @@ -6,11 +6,15 @@ import ( "net" "net/http" "net/http/httptest" + "os" + "strings" "testing" "time" ) func TestS3_NewUsingIAM(t *testing.T) { + unsetEnvForTest(t, ecsContainerCredentialsEnv) + var ( iam = `test-new-s3-using-iam` resp = `{"Code" : "Success","LastUpdated" : "2018-12-24T10:18:01Z", @@ -108,3 +112,133 @@ func TestS3_NewUsingIAM(t *testing.T) { t.Errorf("Expected error, got nil") } } + +func TestFetchIAMDataForEcs(t *testing.T) { + oldBaseURL := ecsContainerCredentialsBaseURL + ecsContainerCredentialsBaseURL = "" + t.Cleanup(func() { + ecsContainerCredentialsBaseURL = oldBaseURL + }) + + t.Run("success", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Fatalf("expected GET request, got %s", r.Method) + } + if r.URL.EscapedPath() != "/ecs/creds" { + t.Fatalf("unexpected path: %s", r.URL.EscapedPath()) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + io.WriteString(w, `{"AccessKeyId":"ecs-access","SecretAccessKey":"ecs-secret","Token":"ecs-token","Expiration":"2018-12-24T16:24:59Z"}`) + })) + defer server.Close() + + ecsContainerCredentialsBaseURL = server.URL + os.Setenv(ecsContainerCredentialsEnv, "/ecs/creds") + defer os.Unsetenv(ecsContainerCredentialsEnv) + + resp, err := fetchIAMDataForEcs(server.Client()) + if err != nil { + t.Fatalf("fetchIAMDataForEcs() error = %v", err) + } + + if resp.AccessKeyID != "ecs-access" || resp.SecretAccessKey != "ecs-secret" || resp.Token != "ecs-token" { + t.Fatalf("unexpected ECS credentials: %+v", resp) + } + }) + + t.Run("non-200 does not fall back to IMDS", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + io.WriteString(w, `forbidden`) + })) + defer server.Close() + + ecsContainerCredentialsBaseURL = server.URL + os.Setenv(ecsContainerCredentialsEnv, "/ecs/creds") + defer os.Unsetenv(ecsContainerCredentialsEnv) + + _, err := fetchIAMData(server.Client(), "http://should-not-be-used") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "error fetching IAM ECS data") { + t.Fatalf("expected ECS error, got %v", err) + } + }) +} + +func TestFetchIAMDataFallsBackToIMDSWhenECSIsUnavailable(t *testing.T) { + unsetEnvForTest(t, ecsContainerCredentialsEnv) + + var ( + iam = `test-new-s3-using-iam` + resp = `{"Code":"Success","LastUpdated":"2018-12-24T10:18:01Z","Type":"AWS-HMAC","AccessKeyId":"abc","SecretAccessKey":"abc","Token":"abc","Expiration":"2018-12-24T16:24:59Z"}` + respIMDSToken = `AQAEAJWopi8yvjKYXyWJbzESE0cms-OoTnptJzS3M9g5iNcl06UEkQ==` + ) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPut: + if r.URL.EscapedPath() != imdsTokenURI { + t.Fatalf("unexpected token path: %s", r.URL.EscapedPath()) + } + w.WriteHeader(http.StatusOK) + io.WriteString(w, respIMDSToken) + case http.MethodGet: + if r.Header.Get(imdsTokenHeader) == "" { + w.WriteHeader(http.StatusUnauthorized) + return + } + switch r.URL.EscapedPath() { + case securityCredentialsURI: + w.WriteHeader(http.StatusOK) + io.WriteString(w, iam) + case securityCredentialsURI + iam: + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, resp) + default: + t.Fatalf("unexpected IMDS path: %s", r.URL.EscapedPath()) + } + default: + t.Fatalf("unexpected method: %s", r.Method) + } + })) + defer server.Close() + + respData, err := fetchIAMData(server.Client(), server.URL) + if err != nil { + t.Fatalf("fetchIAMData() error = %v", err) + } + + if respData.AccessKeyID != "abc" || respData.SecretAccessKey != "abc" || respData.Token != "abc" { + t.Fatalf("unexpected IMDS credentials: %+v", respData) + } + if respData.Code != "Success" || respData.Type != "AWS-HMAC" { + t.Fatalf("expected legacy IAMResponse fields to remain populated, got %+v", respData) + } +} + +func unsetEnvForTest(t *testing.T, key string) { + t.Helper() + + oldValue, hadValue := os.LookupEnv(key) + if err := os.Unsetenv(key); err != nil { + t.Fatalf("failed to unset %s: %v", key, err) + } + + t.Cleanup(func() { + if hadValue { + if err := os.Setenv(key, oldValue); err != nil { + t.Fatalf("failed to restore %s: %v", key, err) + } + return + } + if err := os.Unsetenv(key); err != nil { + t.Fatalf("failed to cleanup %s: %v", key, err) + } + }) +}