diff --git a/iam.go b/iam.go index 4f657ad..162a3b9 100644 --- a/iam.go +++ b/iam.go @@ -5,21 +5,28 @@ package simples3 import ( "encoding/json" + "errors" "fmt" "io" "net/http" + "os" "time" ) const ( - imdsTokenHeader = "X-aws-ec2-metadata-token" - imdsTokenTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds" - metadataBaseURL = "http://169.254.169.254/latest" - securityCredentialsURI = "/meta-data/iam/security-credentials/" - imdsTokenURI = "/api/token" - defaultIMDSTokenTTL = "60" + imdsTokenHeader = "X-aws-ec2-metadata-token" + imdsTokenTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds" + metadataBaseURL = "http://169.254.169.254/latest" + securityCredentialsURI = "/meta-data/iam/security-credentials/" + imdsTokenURI = "/api/token" + defaultIMDSTokenTTL = "60" + ecsContainerCredentialsEnv = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" ) +var ecsContainerCredentialsBaseURL = "http://169.254.170.2" + +var errECSCredentialsEndpointNotSet = errors.New("ecs credentials endpoint not set") + // IAMResponse is used by NewUsingIAM to auto // detect the credentials. type IAMResponse struct { @@ -77,11 +84,63 @@ func fetchIMDSToken(cl *http.Client, baseURL string) (string, bool, error) { return string(token), true, nil } +// fetchIAMDataECS fetches the IAM credentials from the ECS default endpoint. +func fetchIAMDataForEcs(cl *http.Client) (IAMResponse, error) { + env, isSet := os.LookupEnv(ecsContainerCredentialsEnv) + if !isSet { + return IAMResponse{}, errECSCredentialsEndpointNotSet + } + + url := ecsContainerCredentialsBaseURL + env + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return IAMResponse{}, fmt.Errorf("error creating IAM ECS request: %w", err) + } + + resp, err := cl.Do(req) + if err != nil { + return IAMResponse{}, fmt.Errorf("error fetching IAM ECS request: %w", err) + } + defer func() { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + return IAMResponse{}, fmt.Errorf("error fetching IAM ECS data: %s", resp.Status) + } + + var jResp IAMResponse + jsonString, err := io.ReadAll(resp.Body) + if err != nil { + return IAMResponse{}, fmt.Errorf("error reading role data: %w", err) + } + + if err := json.Unmarshal(jsonString, &jResp); err != nil { + return IAMResponse{}, fmt.Errorf("error unmarshalling role data: %w (%s)", err, jsonString) + } + + return jResp, nil +} + // fetchIAMData fetches the IAM data from the given URL. // In case of a normal AWS setup, baseURL would be metadataBaseURL. // You can use this method, to manually fetch IAM data from a custom // endpoint and pass it to SetIAMData. func fetchIAMData(cl *http.Client, baseURL string) (IAMResponse, error) { + response, err := fetchIAMDataForEcs(cl) + + // If already have ECS response, skip the rest of the function and use it instead. + if err == nil { + return response, nil + } + // If ECS credentials are explicitly configured, do not fall back to IMDS. + // Falling back here can pick up the instance role instead of the task role. + if !errors.Is(err, errECSCredentialsEndpointNotSet) { + return IAMResponse{}, err + } + token, useIMDSv2, err := fetchIMDSToken(cl, baseURL) if err != nil { return IAMResponse{}, fmt.Errorf("error fetching IMDSv2 token: %w", err) diff --git a/iam_test.go b/iam_test.go index 2752b04..3d8e3e5 100644 --- a/iam_test.go +++ b/iam_test.go @@ -6,11 +6,15 @@ import ( "net" "net/http" "net/http/httptest" + "os" + "strings" "testing" "time" ) func TestS3_NewUsingIAM(t *testing.T) { + unsetEnvForTest(t, ecsContainerCredentialsEnv) + var ( iam = `test-new-s3-using-iam` resp = `{"Code" : "Success","LastUpdated" : "2018-12-24T10:18:01Z", @@ -108,3 +112,133 @@ func TestS3_NewUsingIAM(t *testing.T) { t.Errorf("Expected error, got nil") } } + +func TestFetchIAMDataForEcs(t *testing.T) { + oldBaseURL := ecsContainerCredentialsBaseURL + ecsContainerCredentialsBaseURL = "" + t.Cleanup(func() { + ecsContainerCredentialsBaseURL = oldBaseURL + }) + + t.Run("success", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Fatalf("expected GET request, got %s", r.Method) + } + if r.URL.EscapedPath() != "/ecs/creds" { + t.Fatalf("unexpected path: %s", r.URL.EscapedPath()) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + io.WriteString(w, `{"AccessKeyId":"ecs-access","SecretAccessKey":"ecs-secret","Token":"ecs-token","Expiration":"2018-12-24T16:24:59Z"}`) + })) + defer server.Close() + + ecsContainerCredentialsBaseURL = server.URL + os.Setenv(ecsContainerCredentialsEnv, "/ecs/creds") + defer os.Unsetenv(ecsContainerCredentialsEnv) + + resp, err := fetchIAMDataForEcs(server.Client()) + if err != nil { + t.Fatalf("fetchIAMDataForEcs() error = %v", err) + } + + if resp.AccessKeyID != "ecs-access" || resp.SecretAccessKey != "ecs-secret" || resp.Token != "ecs-token" { + t.Fatalf("unexpected ECS credentials: %+v", resp) + } + }) + + t.Run("non-200 does not fall back to IMDS", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + io.WriteString(w, `forbidden`) + })) + defer server.Close() + + ecsContainerCredentialsBaseURL = server.URL + os.Setenv(ecsContainerCredentialsEnv, "/ecs/creds") + defer os.Unsetenv(ecsContainerCredentialsEnv) + + _, err := fetchIAMData(server.Client(), "http://should-not-be-used") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "error fetching IAM ECS data") { + t.Fatalf("expected ECS error, got %v", err) + } + }) +} + +func TestFetchIAMDataFallsBackToIMDSWhenECSIsUnavailable(t *testing.T) { + unsetEnvForTest(t, ecsContainerCredentialsEnv) + + var ( + iam = `test-new-s3-using-iam` + resp = `{"Code":"Success","LastUpdated":"2018-12-24T10:18:01Z","Type":"AWS-HMAC","AccessKeyId":"abc","SecretAccessKey":"abc","Token":"abc","Expiration":"2018-12-24T16:24:59Z"}` + respIMDSToken = `AQAEAJWopi8yvjKYXyWJbzESE0cms-OoTnptJzS3M9g5iNcl06UEkQ==` + ) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPut: + if r.URL.EscapedPath() != imdsTokenURI { + t.Fatalf("unexpected token path: %s", r.URL.EscapedPath()) + } + w.WriteHeader(http.StatusOK) + io.WriteString(w, respIMDSToken) + case http.MethodGet: + if r.Header.Get(imdsTokenHeader) == "" { + w.WriteHeader(http.StatusUnauthorized) + return + } + switch r.URL.EscapedPath() { + case securityCredentialsURI: + w.WriteHeader(http.StatusOK) + io.WriteString(w, iam) + case securityCredentialsURI + iam: + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, resp) + default: + t.Fatalf("unexpected IMDS path: %s", r.URL.EscapedPath()) + } + default: + t.Fatalf("unexpected method: %s", r.Method) + } + })) + defer server.Close() + + respData, err := fetchIAMData(server.Client(), server.URL) + if err != nil { + t.Fatalf("fetchIAMData() error = %v", err) + } + + if respData.AccessKeyID != "abc" || respData.SecretAccessKey != "abc" || respData.Token != "abc" { + t.Fatalf("unexpected IMDS credentials: %+v", respData) + } + if respData.Code != "Success" || respData.Type != "AWS-HMAC" { + t.Fatalf("expected legacy IAMResponse fields to remain populated, got %+v", respData) + } +} + +func unsetEnvForTest(t *testing.T, key string) { + t.Helper() + + oldValue, hadValue := os.LookupEnv(key) + if err := os.Unsetenv(key); err != nil { + t.Fatalf("failed to unset %s: %v", key, err) + } + + t.Cleanup(func() { + if hadValue { + if err := os.Setenv(key, oldValue); err != nil { + t.Fatalf("failed to restore %s: %v", key, err) + } + return + } + if err := os.Unsetenv(key); err != nil { + t.Fatalf("failed to cleanup %s: %v", key, err) + } + }) +}