Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 65 additions & 6 deletions iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,28 @@ package simples3

import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"time"
)

const (
imdsTokenHeader = "X-aws-ec2-metadata-token"
imdsTokenTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds"
metadataBaseURL = "http://169.254.169.254/latest"
securityCredentialsURI = "/meta-data/iam/security-credentials/"
imdsTokenURI = "/api/token"
defaultIMDSTokenTTL = "60"
imdsTokenHeader = "X-aws-ec2-metadata-token"
imdsTokenTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds"
metadataBaseURL = "http://169.254.169.254/latest"
securityCredentialsURI = "/meta-data/iam/security-credentials/"
imdsTokenURI = "/api/token"
defaultIMDSTokenTTL = "60"
ecsContainerCredentialsEnv = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"
)

var ecsContainerCredentialsBaseURL = "http://169.254.170.2"

var errECSCredentialsEndpointNotSet = errors.New("ecs credentials endpoint not set")

// IAMResponse is used by NewUsingIAM to auto
// detect the credentials.
type IAMResponse struct {
Expand Down Expand Up @@ -77,11 +84,63 @@ func fetchIMDSToken(cl *http.Client, baseURL string) (string, bool, error) {
return string(token), true, nil
}

// fetchIAMDataECS fetches the IAM credentials from the ECS default endpoint.
func fetchIAMDataForEcs(cl *http.Client) (IAMResponse, error) {
env, isSet := os.LookupEnv(ecsContainerCredentialsEnv)
if !isSet {
return IAMResponse{}, errECSCredentialsEndpointNotSet
}

url := ecsContainerCredentialsBaseURL + env

req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return IAMResponse{}, fmt.Errorf("error creating IAM ECS request: %w", err)
}

resp, err := cl.Do(req)
if err != nil {
return IAMResponse{}, fmt.Errorf("error fetching IAM ECS request: %w", err)
}
defer func() {
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
}()

if resp.StatusCode != http.StatusOK {
return IAMResponse{}, fmt.Errorf("error fetching IAM ECS data: %s", resp.Status)
}

var jResp IAMResponse
jsonString, err := io.ReadAll(resp.Body)
if err != nil {
return IAMResponse{}, fmt.Errorf("error reading role data: %w", err)
}

if err := json.Unmarshal(jsonString, &jResp); err != nil {
return IAMResponse{}, fmt.Errorf("error unmarshalling role data: %w (%s)", err, jsonString)
}

return jResp, nil
}

// fetchIAMData fetches the IAM data from the given URL.
// In case of a normal AWS setup, baseURL would be metadataBaseURL.
// You can use this method, to manually fetch IAM data from a custom
// endpoint and pass it to SetIAMData.
func fetchIAMData(cl *http.Client, baseURL string) (IAMResponse, error) {
response, err := fetchIAMDataForEcs(cl)

// If already have ECS response, skip the rest of the function and use it instead.
if err == nil {
return response, nil
}
// If ECS credentials are explicitly configured, do not fall back to IMDS.
// Falling back here can pick up the instance role instead of the task role.
if !errors.Is(err, errECSCredentialsEndpointNotSet) {
return IAMResponse{}, err
}

token, useIMDSv2, err := fetchIMDSToken(cl, baseURL)
if err != nil {
return IAMResponse{}, fmt.Errorf("error fetching IMDSv2 token: %w", err)
Expand Down
134 changes: 134 additions & 0 deletions iam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@ import (
"net"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
)

func TestS3_NewUsingIAM(t *testing.T) {
unsetEnvForTest(t, ecsContainerCredentialsEnv)

var (
iam = `test-new-s3-using-iam`
resp = `{"Code" : "Success","LastUpdated" : "2018-12-24T10:18:01Z",
Expand Down Expand Up @@ -108,3 +112,133 @@ func TestS3_NewUsingIAM(t *testing.T) {
t.Errorf("Expected error, got nil")
}
}

func TestFetchIAMDataForEcs(t *testing.T) {
oldBaseURL := ecsContainerCredentialsBaseURL
ecsContainerCredentialsBaseURL = ""
t.Cleanup(func() {
ecsContainerCredentialsBaseURL = oldBaseURL
})

t.Run("success", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Fatalf("expected GET request, got %s", r.Method)
}
if r.URL.EscapedPath() != "/ecs/creds" {
t.Fatalf("unexpected path: %s", r.URL.EscapedPath())
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
io.WriteString(w, `{"AccessKeyId":"ecs-access","SecretAccessKey":"ecs-secret","Token":"ecs-token","Expiration":"2018-12-24T16:24:59Z"}`)
}))
defer server.Close()

ecsContainerCredentialsBaseURL = server.URL
os.Setenv(ecsContainerCredentialsEnv, "/ecs/creds")
defer os.Unsetenv(ecsContainerCredentialsEnv)

resp, err := fetchIAMDataForEcs(server.Client())
if err != nil {
t.Fatalf("fetchIAMDataForEcs() error = %v", err)
}

if resp.AccessKeyID != "ecs-access" || resp.SecretAccessKey != "ecs-secret" || resp.Token != "ecs-token" {
t.Fatalf("unexpected ECS credentials: %+v", resp)
}
})

t.Run("non-200 does not fall back to IMDS", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
io.WriteString(w, `forbidden`)
}))
defer server.Close()

ecsContainerCredentialsBaseURL = server.URL
os.Setenv(ecsContainerCredentialsEnv, "/ecs/creds")
defer os.Unsetenv(ecsContainerCredentialsEnv)

_, err := fetchIAMData(server.Client(), "http://should-not-be-used")
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), "error fetching IAM ECS data") {
t.Fatalf("expected ECS error, got %v", err)
}
})
}

func TestFetchIAMDataFallsBackToIMDSWhenECSIsUnavailable(t *testing.T) {
unsetEnvForTest(t, ecsContainerCredentialsEnv)

var (
iam = `test-new-s3-using-iam`
resp = `{"Code":"Success","LastUpdated":"2018-12-24T10:18:01Z","Type":"AWS-HMAC","AccessKeyId":"abc","SecretAccessKey":"abc","Token":"abc","Expiration":"2018-12-24T16:24:59Z"}`
respIMDSToken = `AQAEAJWopi8yvjKYXyWJbzESE0cms-OoTnptJzS3M9g5iNcl06UEkQ==`
)

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodPut:
if r.URL.EscapedPath() != imdsTokenURI {
t.Fatalf("unexpected token path: %s", r.URL.EscapedPath())
}
w.WriteHeader(http.StatusOK)
io.WriteString(w, respIMDSToken)
case http.MethodGet:
if r.Header.Get(imdsTokenHeader) == "" {
w.WriteHeader(http.StatusUnauthorized)
return
}
switch r.URL.EscapedPath() {
case securityCredentialsURI:
w.WriteHeader(http.StatusOK)
io.WriteString(w, iam)
case securityCredentialsURI + iam:
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
io.WriteString(w, resp)
default:
t.Fatalf("unexpected IMDS path: %s", r.URL.EscapedPath())
}
default:
t.Fatalf("unexpected method: %s", r.Method)
}
}))
defer server.Close()

respData, err := fetchIAMData(server.Client(), server.URL)
if err != nil {
t.Fatalf("fetchIAMData() error = %v", err)
}

if respData.AccessKeyID != "abc" || respData.SecretAccessKey != "abc" || respData.Token != "abc" {
t.Fatalf("unexpected IMDS credentials: %+v", respData)
}
if respData.Code != "Success" || respData.Type != "AWS-HMAC" {
t.Fatalf("expected legacy IAMResponse fields to remain populated, got %+v", respData)
}
}

func unsetEnvForTest(t *testing.T, key string) {
t.Helper()

oldValue, hadValue := os.LookupEnv(key)
if err := os.Unsetenv(key); err != nil {
t.Fatalf("failed to unset %s: %v", key, err)
}

t.Cleanup(func() {
if hadValue {
if err := os.Setenv(key, oldValue); err != nil {
t.Fatalf("failed to restore %s: %v", key, err)
}
return
}
if err := os.Unsetenv(key); err != nil {
t.Fatalf("failed to cleanup %s: %v", key, err)
}
})
}
Loading