diff --git a/README.md b/README.md index 5ffb5ae..ed497f2 100644 --- a/README.md +++ b/README.md @@ -733,6 +733,32 @@ s3 := simples3.New("nyc3", "your-access-key", "your-secret-key") s3.SetEndpoint("https://nyc3.digitaloceanspaces.com") ``` +### Addressing Style + +`Bucket` remains required for bucket-scoped operations. Addressing style is a client setting, not a bucketless API mode. + +By default, simples3 preserves its legacy behavior for compatibility: +- runtime requests use path-style URLs +- presigned URLs keep their historical default behavior +- direct `CreateUploadPolicies()` defaults keep their historical action URL behavior + +To explicitly control addressing style across supported surfaces, use `SetUsePathStyle`: + +```go +// Force path-style addressing +s3 := simples3.New("us-east-1", "your-access-key", "your-secret-key") +s3.SetUsePathStyle(true) + +// Opt into virtual-hosted-style addressing when the bucket/endpoint is compatible +s3.SetUsePathStyle(false) +``` + +When virtual-hosted-style is explicitly enabled, simples3 safely falls back to path-style for incompatible cases such as: +- dotted bucket names over HTTPS +- localhost or IP-based endpoints +- endpoints with a path prefix (for example `https://example.com/base`) +- non-DNS-compatible bucket names + ### IAM Credentials On EC2 instances, use IAM roles automatically: diff --git a/addressing.go b/addressing.go new file mode 100644 index 0000000..c5371cb --- /dev/null +++ b/addressing.go @@ -0,0 +1,237 @@ +package simples3 + +import ( + "fmt" + "net" + "net/url" + "strings" +) + +type addressingMode uint8 + +const ( + addressingModeLegacy addressingMode = iota + addressingModePath + addressingModeVirtual +) + +type addressingSurface uint8 + +const ( + addressingSurfaceRuntime addressingSurface = iota + addressingSurfacePresign + addressingSurfacePolicy +) + +type addressStyle uint8 + +const ( + addressStylePath addressStyle = iota + addressStyleVirtual +) + +type resolvedAddress struct { + scheme string + host string + path string + style addressStyle + fallbackReason string +} + +func (a resolvedAddress) urlString() string { + path := a.path + if path == "" { + path = "/" + } + return a.scheme + "://" + a.host + path +} + +func (s3 *S3) resolveAddress(surface addressingSurface, bucket string, args ...string) resolvedAddress { + switch s3.addressingMode { + case addressingModePath: + return s3.resolvePathAddress(bucket, args...) + case addressingModeVirtual: + return s3.resolveExplicitVirtualAddress(bucket, args...) + default: + return s3.resolveLegacyAddress(surface, bucket, args...) + } +} + +func (s3 *S3) resolveLegacyAddress(surface addressingSurface, bucket string, args ...string) resolvedAddress { + switch surface { + case addressingSurfacePresign: + if endpoint, _ := url.Parse(s3.Endpoint); endpoint.Host != "" { + return s3.resolvePathAddress(bucket, args...) + } + return resolvedAddress{ + scheme: "https", + host: bucket + "." + defaultPresignedHost, + path: buildVirtualObjectPath("", args...), + style: addressStyleVirtual, + } + case addressingSurfacePolicy: + return parseResolvedAddress(fmt.Sprintf(defaultUploadURLFormat, bucket), addressStyleVirtual, "") + default: + return s3.resolvePathAddress(bucket, args...) + } +} + +func (s3 *S3) resolvePathAddress(bucket string, args ...string) resolvedAddress { + return parseResolvedAddress(s3.legacyPathURL(bucket, args...), addressStylePath, "") +} + +func (s3 *S3) resolveExplicitVirtualAddress(bucket string, args ...string) resolvedAddress { + base, err := s3.serviceBaseURL() + if err != nil { + return parseResolvedAddress(s3.legacyPathURL(bucket, args...), addressStylePath, "invalid service endpoint") + } + + if reason := virtualAddressingFallbackReason(base, bucket); reason != "" { + return parseResolvedAddress(s3.legacyPathURL(bucket, args...), addressStylePath, reason) + } + + return resolvedAddress{ + scheme: base.Scheme, + host: bucketHost(base, bucket), + path: buildVirtualObjectPath(base.EscapedPath(), args...), + style: addressStyleVirtual, + } +} + +func (s3 *S3) legacyPathURL(bucket string, args ...string) string { + path := bucket + if len(args) > 0 { + path += "/" + strings.Join(args, "/") + } + encodedPath := encodePath(path) + + if len(s3.Endpoint) > 0 { + return s3.Endpoint + "/" + encodedPath + } + return fmt.Sprintf(s3.URIFormat, s3.Region, encodedPath) +} + +func (s3 *S3) serviceBaseURL() (*url.URL, error) { + rawURL := s3.Endpoint + if rawURL == "" { + rawURL = fmt.Sprintf(s3.URIFormat, s3.Region, "") + } + return url.Parse(rawURL) +} + +func parseResolvedAddress(rawURL string, style addressStyle, fallbackReason string) resolvedAddress { + parsed, err := url.Parse(rawURL) + if err != nil { + return resolvedAddress{style: style, fallbackReason: fallbackReason} + } + + path := parsed.EscapedPath() + if path == "" { + path = "/" + } + + return resolvedAddress{ + scheme: parsed.Scheme, + host: parsed.Host, + path: path, + style: style, + fallbackReason: fallbackReason, + } +} + +func buildVirtualObjectPath(basePath string, args ...string) string { + trimmedBase := strings.TrimRight(basePath, "/") + objectPath := strings.Join(args, "/") + if objectPath == "" { + if trimmedBase == "" { + return "/" + } + return trimmedBase + "/" + } + + encodedObjectPath := encodePath(objectPath) + if trimmedBase == "" { + return "/" + encodedObjectPath + } + return trimmedBase + "/" + encodedObjectPath +} + +func bucketHost(base *url.URL, bucket string) string { + hostname := base.Hostname() + if port := base.Port(); port != "" { + return bucket + "." + hostname + ":" + port + } + return bucket + "." + hostname +} + +func virtualAddressingFallbackReason(base *url.URL, bucket string) string { + if !dnsCompatibleBucketName(bucket) { + return "bucket is not DNS compatible" + } + if base.Scheme == "https" && strings.Contains(bucket, ".") { + return "dotted bucket over https" + } + if basePath := strings.Trim(base.EscapedPath(), "/"); basePath != "" { + return "endpoint has path prefix" + } + if hostname := base.Hostname(); hostname == "localhost" { + return "localhost endpoint" + } else if hostname != "" { + if ip := net.ParseIP(hostname); ip != nil { + return "ip endpoint" + } + } + return "" +} + +func dnsCompatibleBucketName(bucket string) bool { + if len(bucket) < 3 || len(bucket) > 63 { + return false + } + if strings.Contains(bucket, "..") { + return false + } + if !isLowercaseLetterOrDigit(bucket[0]) || !isLowercaseLetterOrDigit(bucket[len(bucket)-1]) { + return false + } + for i := 1; i < len(bucket)-1; i++ { + c := bucket[i] + if !isLowercaseLetterOrDigit(c) && c != '.' && c != '-' { + return false + } + } + + parts := strings.Split(bucket, ".") + for _, part := range parts { + if part == "" { + return false + } + if !isLowercaseLetterOrDigit(part[0]) || !isLowercaseLetterOrDigit(part[len(part)-1]) { + return false + } + } + + if len(parts) == 4 { + isIPAddress := true + for _, part := range parts { + for i := 0; i < len(part); i++ { + if part[i] < '0' || part[i] > '9' { + isIPAddress = false + break + } + } + if !isIPAddress { + break + } + } + if isIPAddress { + return false + } + } + + return true +} + +func isLowercaseLetterOrDigit(c byte) bool { + return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') +} diff --git a/addressing_explicit_test.go b/addressing_explicit_test.go new file mode 100644 index 0000000..f8cccd9 --- /dev/null +++ b/addressing_explicit_test.go @@ -0,0 +1,225 @@ +package simples3 + +import ( + "net/url" + "testing" + "time" +) + +func TestSetUsePathStyle_VirtualHostedStyleRuntime(t *testing.T) { + s3 := New("us-west-2", "AccessKey", "SuperSecretKey").SetUsePathStyle(false) + + got := s3.getURL("examplebucket", "photos/puppy.jpg") + want := "https://examplebucket.s3.us-west-2.amazonaws.com/photos/puppy.jpg" + if got != want { + t.Fatalf("getURL() = %q, want %q", got, want) + } +} + +func TestSetUsePathStyle_VirtualHostedStyleRuntime_CustomEndpoint(t *testing.T) { + s3 := New("nyc3", "AccessKey", "SuperSecretKey") + s3.SetEndpoint("https://objects.example.com") + s3.SetUsePathStyle(false) + + got := s3.getURL("examplebucket", "photos/puppy.jpg") + want := "https://examplebucket.objects.example.com/photos/puppy.jpg" + if got != want { + t.Fatalf("getURL() = %q, want %q", got, want) + } +} + +func TestSetUsePathStyle_VirtualHostedStyleRuntimeFallbacks(t *testing.T) { + tests := []struct { + name string + bucket string + endpoint string + want string + }{ + { + name: "dotted bucket over https", + bucket: "example.bucket", + want: "https://s3.us-west-2.amazonaws.com/example.bucket/photos/puppy.jpg", + }, + { + name: "non dns bucket", + bucket: "example_bucket", + want: "https://s3.us-west-2.amazonaws.com/example_bucket/photos/puppy.jpg", + }, + { + name: "invalid dotted label leading hyphen", + bucket: "a.-b", + want: "https://s3.us-west-2.amazonaws.com/a.-b/photos/puppy.jpg", + }, + { + name: "invalid dotted label trailing hyphen", + bucket: "a-.b", + want: "https://s3.us-west-2.amazonaws.com/a-.b/photos/puppy.jpg", + }, + { + name: "localhost endpoint", + bucket: "examplebucket", + endpoint: "http://localhost:9000", + want: "http://localhost:9000/examplebucket/photos/puppy.jpg", + }, + { + name: "path prefixed endpoint", + bucket: "examplebucket", + endpoint: "https://objects.example.com/base", + want: "https://objects.example.com/base/examplebucket/photos/puppy.jpg", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s3 := New("us-west-2", "AccessKey", "SuperSecretKey") + if tt.endpoint != "" { + s3.SetEndpoint(tt.endpoint) + } + s3.SetUsePathStyle(false) + + if got := s3.getURL(tt.bucket, "photos/puppy.jpg"); got != tt.want { + t.Fatalf("getURL() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestGeneratePresignedURL_UsePathStyleFalseUsesVirtualHostedStyle(t *testing.T) { + timestamp, err := time.Parse(time.RFC1123, "Fri, 24 May 2013 00:00:00 GMT") + if err != nil { + t.Fatalf("time.Parse() error = %v", err) + } + + s3 := New( + "us-west-2", + "AKIAIOSFODNN7EXAMPLE", + "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + ).SetUsePathStyle(false) + + presignedURL := s3.GeneratePresignedURL(PresignedInput{ + Bucket: "examplebucket", + ObjectKey: "photos/puppy.jpg", + Method: "GET", + Timestamp: timestamp, + ExpirySeconds: 3600, + }) + + parsed, err := url.Parse(presignedURL) + if err != nil { + t.Fatalf("url.Parse() error = %v", err) + } + + if parsed.Host != "examplebucket.s3.us-west-2.amazonaws.com" { + t.Fatalf("host = %q, want examplebucket.s3.us-west-2.amazonaws.com", parsed.Host) + } + if parsed.EscapedPath() != "/photos/puppy.jpg" { + t.Fatalf("path = %q, want /photos/puppy.jpg", parsed.EscapedPath()) + } + if parsed.Query().Get("X-Amz-Signature") == "" { + t.Fatalf("missing X-Amz-Signature in %q", presignedURL) + } +} + +func TestGeneratePresignedURL_UsePathStyleFalseFallbacksToPathStyle(t *testing.T) { + timestamp, err := time.Parse(time.RFC1123, "Fri, 24 May 2013 00:00:00 GMT") + if err != nil { + t.Fatalf("time.Parse() error = %v", err) + } + + tests := []struct { + name string + bucket string + endpoint string + wantHost string + wantPath string + }{ + { + name: "dotted bucket over https", + bucket: "example.bucket", + wantHost: "s3.us-west-2.amazonaws.com", + wantPath: "/example.bucket/photos/puppy.jpg", + }, + { + name: "path prefixed endpoint", + bucket: "examplebucket", + endpoint: "https://objects.example.com/base", + wantHost: "objects.example.com", + wantPath: "/base/examplebucket/photos/puppy.jpg", + }, + { + name: "invalid dotted label leading hyphen", + bucket: "a.-b", + wantHost: "s3.us-west-2.amazonaws.com", + wantPath: "/a.-b/photos/puppy.jpg", + }, + { + name: "invalid dotted label trailing hyphen", + bucket: "a-.b", + wantHost: "s3.us-west-2.amazonaws.com", + wantPath: "/a-.b/photos/puppy.jpg", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s3 := New( + "us-west-2", + "AKIAIOSFODNN7EXAMPLE", + "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + ) + if tt.endpoint != "" { + s3.SetEndpoint(tt.endpoint) + } + s3.SetUsePathStyle(false) + + presignedURL := s3.GeneratePresignedURL(PresignedInput{ + Bucket: tt.bucket, + ObjectKey: "photos/puppy.jpg", + Method: "GET", + Timestamp: timestamp, + ExpirySeconds: 3600, + }) + + parsed, err := url.Parse(presignedURL) + if err != nil { + t.Fatalf("url.Parse() error = %v", err) + } + if parsed.Host != tt.wantHost { + t.Fatalf("host = %q, want %q", parsed.Host, tt.wantHost) + } + if parsed.EscapedPath() != tt.wantPath { + t.Fatalf("path = %q, want %q", parsed.EscapedPath(), tt.wantPath) + } + }) + } +} + +func TestGeneratePresignedUploadPartURL_UsePathStyleFalseUsesVirtualHostedStyle(t *testing.T) { + s3 := New("us-west-2", "AccessKey", "SuperSecretKey").SetUsePathStyle(false) + + presignedURL := s3.GeneratePresignedUploadPartURL(PresignedMultipartInput{ + Bucket: "examplebucket", + ObjectKey: "multipart/test.bin", + UploadID: "upload-id", + PartNumber: 7, + ExpirySeconds: 3600, + }) + + parsed, err := url.Parse(presignedURL) + if err != nil { + t.Fatalf("url.Parse() error = %v", err) + } + + if parsed.Host != "examplebucket.s3.us-west-2.amazonaws.com" { + t.Fatalf("host = %q, want examplebucket.s3.us-west-2.amazonaws.com", parsed.Host) + } + if parsed.EscapedPath() != "/multipart/test.bin" { + t.Fatalf("path = %q, want /multipart/test.bin", parsed.EscapedPath()) + } + if got := parsed.Query().Get("partNumber"); got != "7" { + t.Fatalf("partNumber = %q, want 7", got) + } + if got := parsed.Query().Get("uploadId"); got != "upload-id" { + t.Fatalf("uploadId = %q, want upload-id", got) + } +} diff --git a/addressing_legacy_test.go b/addressing_legacy_test.go new file mode 100644 index 0000000..737ade3 --- /dev/null +++ b/addressing_legacy_test.go @@ -0,0 +1,103 @@ +package simples3 + +import ( + "net/url" + "strings" + "testing" + "time" +) + +func TestGeneratePresignedURL_LegacyCustomEndpointUsesPathStyle(t *testing.T) { + timestamp, err := time.Parse(time.RFC1123, "Fri, 24 May 2013 00:00:00 GMT") + if err != nil { + t.Fatalf("time.Parse() error = %v", err) + } + + s3 := New( + "us-east-1", + "AKIAIOSFODNN7EXAMPLE", + "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + ) + s3.SetEndpoint("https://objects.example.com/base") + + presignedURL := s3.GeneratePresignedURL(PresignedInput{ + Bucket: "examplebucket", + ObjectKey: "test.txt", + Method: "GET", + Timestamp: timestamp, + ExpirySeconds: 86400, + }) + + parsed, err := url.Parse(presignedURL) + if err != nil { + t.Fatalf("url.Parse() error = %v", err) + } + + if parsed.Scheme != "https" { + t.Fatalf("scheme = %q, want https", parsed.Scheme) + } + if parsed.Host != "objects.example.com" { + t.Fatalf("host = %q, want objects.example.com", parsed.Host) + } + if parsed.EscapedPath() != "/base/examplebucket/test.txt" { + t.Fatalf("path = %q, want /base/examplebucket/test.txt", parsed.EscapedPath()) + } + if strings.HasPrefix(parsed.Host, "examplebucket.") { + t.Fatalf("legacy custom endpoint should keep bucket in path, got host %q", parsed.Host) + } +} + +func TestGeneratePresignedUploadPartURL_LegacyCustomEndpointUsesPathStyle(t *testing.T) { + s3 := New("us-east-1", "AccessKey", "SuperSecretKey") + s3.SetEndpoint("https://objects.example.com/base") + + presignedURL := s3.GeneratePresignedUploadPartURL(PresignedMultipartInput{ + Bucket: "examplebucket", + ObjectKey: "test.txt", + UploadID: "upload-id", + PartNumber: 7, + ExpirySeconds: 3600, + }) + + parsed, err := url.Parse(presignedURL) + if err != nil { + t.Fatalf("url.Parse() error = %v", err) + } + + if parsed.Host != "objects.example.com" { + t.Fatalf("host = %q, want objects.example.com", parsed.Host) + } + if parsed.EscapedPath() != "/base/examplebucket/test.txt" { + t.Fatalf("path = %q, want /base/examplebucket/test.txt", parsed.EscapedPath()) + } + if got := parsed.Query().Get("partNumber"); got != "7" { + t.Fatalf("partNumber = %q, want 7", got) + } + if got := parsed.Query().Get("uploadId"); got != "upload-id" { + t.Fatalf("uploadId = %q, want upload-id", got) + } +} + +func TestCreateUploadPolicies_LegacyDefaultUploadURL(t *testing.T) { + s3 := New("us-east-1", "AccessKey", "SuperSecretKey") + + policies, err := s3.CreateUploadPolicies(UploadConfig{ + BucketName: "examplebucket", + ObjectKey: "test.txt", + ContentType: "text/plain", + FileSize: 123, + }) + if err != nil { + t.Fatalf("CreateUploadPolicies() error = %v", err) + } + + if policies.URL != "http://examplebucket.s3.amazonaws.com/" { + t.Fatalf("URL = %q, want http://examplebucket.s3.amazonaws.com/", policies.URL) + } + if got := policies.Form["key"]; got != "test.txt" { + t.Fatalf("form[key] = %q, want test.txt", got) + } + if got := policies.Form["Content-Type"]; got != "text/plain" { + t.Fatalf("form[Content-Type] = %q, want text/plain", got) + } +} diff --git a/addressing_pathstyle_test.go b/addressing_pathstyle_test.go new file mode 100644 index 0000000..12d41d0 --- /dev/null +++ b/addressing_pathstyle_test.go @@ -0,0 +1,83 @@ +package simples3 + +import ( + "net/url" + "testing" + "time" +) + +func TestSetUsePathStyle_PathStyleRuntime(t *testing.T) { + s3 := New("us-west-2", "AccessKey", "SuperSecretKey").SetUsePathStyle(true) + + got := s3.getURL("examplebucket", "photos/puppy.jpg") + want := "https://s3.us-west-2.amazonaws.com/examplebucket/photos/puppy.jpg" + if got != want { + t.Fatalf("getURL() = %q, want %q", got, want) + } +} + +func TestGeneratePresignedURL_UsePathStyleTrueUsesPathStyle(t *testing.T) { + timestamp, err := time.Parse(time.RFC1123, "Fri, 24 May 2013 00:00:00 GMT") + if err != nil { + t.Fatalf("time.Parse() error = %v", err) + } + + s3 := New( + "us-west-2", + "AKIAIOSFODNN7EXAMPLE", + "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + ).SetUsePathStyle(true) + + presignedURL := s3.GeneratePresignedURL(PresignedInput{ + Bucket: "examplebucket", + ObjectKey: "photos/puppy.jpg", + Method: "GET", + Timestamp: timestamp, + ExpirySeconds: 3600, + }) + + parsed, err := url.Parse(presignedURL) + if err != nil { + t.Fatalf("url.Parse() error = %v", err) + } + + if parsed.Host != "s3.us-west-2.amazonaws.com" { + t.Fatalf("host = %q, want s3.us-west-2.amazonaws.com", parsed.Host) + } + if parsed.EscapedPath() != "/examplebucket/photos/puppy.jpg" { + t.Fatalf("path = %q, want /examplebucket/photos/puppy.jpg", parsed.EscapedPath()) + } + if parsed.Query().Get("X-Amz-Signature") == "" { + t.Fatalf("missing X-Amz-Signature in %q", presignedURL) + } +} + +func TestGeneratePresignedUploadPartURL_UsePathStyleTrueUsesPathStyle(t *testing.T) { + s3 := New("us-west-2", "AccessKey", "SuperSecretKey").SetUsePathStyle(true) + + presignedURL := s3.GeneratePresignedUploadPartURL(PresignedMultipartInput{ + Bucket: "examplebucket", + ObjectKey: "multipart/test.bin", + UploadID: "upload-id", + PartNumber: 7, + ExpirySeconds: 3600, + }) + + parsed, err := url.Parse(presignedURL) + if err != nil { + t.Fatalf("url.Parse() error = %v", err) + } + + if parsed.Host != "s3.us-west-2.amazonaws.com" { + t.Fatalf("host = %q, want s3.us-west-2.amazonaws.com", parsed.Host) + } + if parsed.EscapedPath() != "/examplebucket/multipart/test.bin" { + t.Fatalf("path = %q, want /examplebucket/multipart/test.bin", parsed.EscapedPath()) + } + if got := parsed.Query().Get("partNumber"); got != "7" { + t.Fatalf("partNumber = %q, want 7", got) + } + if got := parsed.Query().Get("uploadId"); got != "upload-id" { + t.Fatalf("uploadId = %q, want upload-id", got) + } +} diff --git a/helpers.go b/helpers.go index 701aa92..1aaec92 100644 --- a/helpers.go +++ b/helpers.go @@ -5,7 +5,6 @@ package simples3 import ( "encoding/hex" - "fmt" "io" "net/url" "regexp" @@ -17,20 +16,8 @@ import ( // getURL constructs a URL for a given path, with multiple optional // arguments as individual subfolders, based on the endpoint // specified in s3 struct. -func (s3 *S3) getURL(path string, args ...string) (uri string) { - if len(args) > 0 { - path += "/" + strings.Join(args, "/") - } - // need to encode special characters in the path part of the URL - encodedPath := encodePath(path) - - if len(s3.Endpoint) > 0 { - uri = s3.Endpoint + "/" + encodedPath - } else { - uri = fmt.Sprintf(s3.URIFormat, s3.Region, encodedPath) - } - - return uri +func (s3 *S3) getURL(path string, args ...string) string { + return s3.resolveAddress(addressingSurfaceRuntime, path, args...).urlString() } func detectFileSize(body io.Seeker) (int64, error) { diff --git a/policy.go b/policy.go index 9a61447..3a6328a 100644 --- a/policy.go +++ b/policy.go @@ -14,7 +14,6 @@ import ( "encoding/base64" "encoding/hex" "encoding/json" - "fmt" "time" ) @@ -70,7 +69,7 @@ var newLine = []byte{'\n'} //nolint func (s3 *S3) CreateUploadPolicies(uploadConfig UploadConfig) (UploadPolicies, error) { nowTime := nowTime() credential := string(s3.buildCredential(nowTime)) - data, err := buildUploadSign(nowTime, credential, uploadConfig) + data, err := buildUploadSign(nowTime, credential, uploadConfig, s3.Token) if err != nil { return UploadPolicies{}, err } @@ -84,7 +83,7 @@ func (s3 *S3) CreateUploadPolicies(uploadConfig UploadConfig) (UploadPolicies, e uploadURL := uploadConfig.UploadURL if uploadURL == "" { - uploadURL = fmt.Sprintf(defaultUploadURLFormat, uploadConfig.BucketName) + uploadURL = s3.resolveAddress(addressingSurfacePolicy, uploadConfig.BucketName).urlString() } // essential fields @@ -97,6 +96,9 @@ func (s3 *S3) CreateUploadPolicies(uploadConfig UploadConfig) (UploadPolicies, e "Policy": policy, "X-Amz-Signature": signature, } + if s3.Token != "" { + form["X-Amz-Security-Token"] = s3.Token + } // optional fields if uploadConfig.ContentDisposition != "" { @@ -116,7 +118,7 @@ func (s3 *S3) CreateUploadPolicies(uploadConfig UploadConfig) (UploadPolicies, e }, nil } -func buildUploadSign(nowTime time.Time, credential string, uploadConfig UploadConfig) ([]byte, error) { +func buildUploadSign(nowTime time.Time, credential string, uploadConfig UploadConfig, token string) ([]byte, error) { // essential conditions conditions := []interface{}{ map[string]string{"bucket": uploadConfig.BucketName}, @@ -127,6 +129,9 @@ func buildUploadSign(nowTime time.Time, credential string, uploadConfig UploadCo map[string]string{"x-amz-algorithm": algorithm}, map[string]string{"x-amz-date": nowTime.Format(amzDateISO8601TimeFormat)}, } + if token != "" { + conditions = append(conditions, map[string]string{"x-amz-security-token": token}) + } // optional conditions if uploadConfig.ContentDisposition != "" { diff --git a/policy_test.go b/policy_test.go new file mode 100644 index 0000000..d602f42 --- /dev/null +++ b/policy_test.go @@ -0,0 +1,127 @@ +package simples3 + +import ( + "encoding/base64" + "encoding/json" + "testing" +) + +func TestCreateUploadPolicies_UsePathStyleFalseUsesVirtualHostedStyle(t *testing.T) { + s3 := New("us-west-2", "AccessKey", "SuperSecretKey").SetUsePathStyle(false) + + policies, err := s3.CreateUploadPolicies(UploadConfig{ + BucketName: "examplebucket", + ObjectKey: "test.txt", + ContentType: "text/plain", + FileSize: 123, + }) + if err != nil { + t.Fatalf("CreateUploadPolicies() error = %v", err) + } + + if policies.URL != "https://examplebucket.s3.us-west-2.amazonaws.com/" { + t.Fatalf("URL = %q, want https://examplebucket.s3.us-west-2.amazonaws.com/", policies.URL) + } +} + +func TestCreateUploadPolicies_UsePathStyleFalseFallbacksToPathStyle(t *testing.T) { + s3 := New("us-west-2", "AccessKey", "SuperSecretKey") + s3.SetEndpoint("https://objects.example.com/base") + s3.SetUsePathStyle(false) + + policies, err := s3.CreateUploadPolicies(UploadConfig{ + BucketName: "examplebucket", + ObjectKey: "test.txt", + ContentType: "text/plain", + FileSize: 123, + }) + if err != nil { + t.Fatalf("CreateUploadPolicies() error = %v", err) + } + + if policies.URL != "https://objects.example.com/base/examplebucket" { + t.Fatalf("URL = %q, want https://objects.example.com/base/examplebucket", policies.URL) + } +} + +func TestCreateUploadPolicies_UsePathStyleFalseFallbacksInvalidDottedLabelToPathStyle(t *testing.T) { + s3 := New("us-west-2", "AccessKey", "SuperSecretKey").SetUsePathStyle(false) + + policies, err := s3.CreateUploadPolicies(UploadConfig{ + BucketName: "a.-b", + ObjectKey: "test.txt", + ContentType: "text/plain", + FileSize: 123, + }) + if err != nil { + t.Fatalf("CreateUploadPolicies() error = %v", err) + } + + if policies.URL != "https://s3.us-west-2.amazonaws.com/a.-b" { + t.Fatalf("URL = %q, want https://s3.us-west-2.amazonaws.com/a.-b", policies.URL) + } +} + +func TestCreateUploadPolicies_SetUsePathStyleTrueUsesPathStyle(t *testing.T) { + s3 := New("us-west-2", "AccessKey", "SuperSecretKey").SetUsePathStyle(true) + + policies, err := s3.CreateUploadPolicies(UploadConfig{ + BucketName: "examplebucket", + ObjectKey: "test.txt", + ContentType: "text/plain", + FileSize: 123, + }) + if err != nil { + t.Fatalf("CreateUploadPolicies() error = %v", err) + } + + if policies.URL != "https://s3.us-west-2.amazonaws.com/examplebucket" { + t.Fatalf("URL = %q, want https://s3.us-west-2.amazonaws.com/examplebucket", policies.URL) + } +} + +func TestCreateUploadPolicies_TokenAddsFormFieldAndPolicyCondition(t *testing.T) { + s3 := New("us-west-2", "AccessKey", "SuperSecretKey") + s3.SetToken("session-token") + + policies, err := s3.CreateUploadPolicies(UploadConfig{ + BucketName: "examplebucket", + ObjectKey: "test.txt", + ContentType: "text/plain", + FileSize: 123, + }) + if err != nil { + t.Fatalf("CreateUploadPolicies() error = %v", err) + } + + if got := policies.Form["X-Amz-Security-Token"]; got != "session-token" { + t.Fatalf("form[X-Amz-Security-Token] = %q, want session-token", got) + } + + decodedPolicy, err := base64.StdEncoding.DecodeString(policies.Form["Policy"]) + if err != nil { + t.Fatalf("DecodeString() error = %v", err) + } + + var policy PolicyJSON + if err := json.Unmarshal(decodedPolicy, &policy); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + + if !policyHasCondition(policy, "x-amz-security-token", "session-token") { + t.Fatalf("policy missing x-amz-security-token condition: %s", string(decodedPolicy)) + } +} + +func policyHasCondition(policy PolicyJSON, key, value string) bool { + for _, condition := range policy.Conditions { + conditionMap, ok := condition.(map[string]interface{}) + if !ok { + continue + } + if got, ok := conditionMap[key]; ok && got == value { + return true + } + } + return false +} diff --git a/presigned.go b/presigned.go index f2d50ab..1a790c5 100644 --- a/presigned.go +++ b/presigned.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "encoding/hex" "net/url" - "path" "sort" "strconv" "strings" @@ -14,7 +13,6 @@ import ( const ( defaultPresignedHost = "s3.amazonaws.com" // - defaultProtocol = "https://" // HdrXAmzSignedHeaders = "X-Amz-SignedHeaders" ) @@ -40,50 +38,49 @@ func awsURIEncode(s string) string { // for Authentication using Query Parameters. // (https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html) func (s3 *S3) GeneratePresignedURL(in PresignedInput) string { + extraQuery := map[string]string{} + if in.ResponseContentDisposition != "" { + extraQuery["response-content-disposition"] = in.ResponseContentDisposition + } + + return s3.generatePresignedObjectURL( + in.Method, + in.Bucket, + in.ObjectKey, + in.Timestamp, + in.ExpirySeconds, + in.ExtraHeaders, + extraQuery, + ) +} + +func (s3 *S3) generatePresignedObjectURL(method, bucket, objectKey string, timestamp time.Time, expirySeconds int, extraHeaders map[string]string, extraQuery map[string]string) string { if err := s3.renewIAMToken(); err != nil { return "" } - var ( - nowTime = nowTime() - - protocol = defaultProtocol - hostname = defaultPresignedHost - path_prefix = "" - ) - if !in.Timestamp.IsZero() { - nowTime = in.Timestamp.UTC() + currentTime := nowTime() + if !timestamp.IsZero() { + currentTime = timestamp.UTC() } - amzdate := nowTime.Format(amzDateISO8601TimeFormat) + amzdate := currentTime.Format(amzDateISO8601TimeFormat) + address := s3.resolveAddress(addressingSurfacePresign, bucket, objectKey) // Create cred b := bytes.Buffer{} b.WriteString(s3.AccessKey) b.WriteRune('/') - b.Write(s3.buildCredentialWithoutKey(nowTime)) + b.Write(s3.buildCredentialWithoutKey(currentTime)) cred := b.Bytes() b.Reset() - // Set the protocol as default if not provided. - if endpoint, _ := url.Parse(s3.Endpoint); endpoint.Host != "" { - protocol = endpoint.Scheme + "://" - hostname = endpoint.Host - path_prefix = path.Join("/", endpoint.Path, in.Bucket) - } else { - host := bytes.Buffer{} - host.WriteString(in.Bucket) - host.WriteRune('.') - host.WriteString(hostname) - hostname = host.String() - } - // Add host to Headers signedHeaders := map[string][]byte{} - for k, v := range in.ExtraHeaders { + for k, v := range extraHeaders { // AWS requires header names to be lowercase per spec signedHeaders[strings.ToLower(k)] = []byte(v) } - signedHeaders["host"] = []byte(hostname) + signedHeaders["host"] = []byte(address.host) // Build signed headers string sortedSH := make([]string, 0, len(signedHeaders)) @@ -104,12 +101,10 @@ func (s3 *S3) GeneratePresignedURL(in PresignedInput) string { } // Start Canonical Request Formation - h := sha256.New() // We write the canonical request directly to the SHA256 hash. - h.Write([]byte(in.Method)) // HTTP Verb + h := sha256.New() // We write the canonical request directly to the SHA256 hash. + h.Write([]byte(method)) // HTTP Verb h.Write(newLine) - h.Write([]byte(path_prefix)) - h.Write([]byte{'/'}) - h.Write([]byte(encodePath(in.ObjectKey))) // CanonicalURL + h.Write([]byte(address.path)) // CanonicalURL h.Write(newLine) // Start QueryString Params (before SignedHeaders) @@ -117,16 +112,15 @@ func (s3 *S3) GeneratePresignedURL(in PresignedInput) string { "X-Amz-Algorithm": algorithm, "X-Amz-Credential": string(cred), "X-Amz-Date": amzdate, - "X-Amz-Expires": strconv.Itoa(in.ExpirySeconds), + "X-Amz-Expires": strconv.Itoa(expirySeconds), HdrXAmzSignedHeaders: signedHeadersForURL.String(), } - // Include response-content-disposition if set - if in.ResponseContentDisposition != "" { - queryString["response-content-disposition"] = in.ResponseContentDisposition + for k, v := range extraQuery { + queryString[k] = v } - // include the x-amz-security-token incase we are using IAM role or AWS STS + // include the x-amz-security-token incase we are using IAM role or AWS STS if s3.Token != "" { queryString["X-Amz-Security-Token"] = s3.Token } @@ -174,7 +168,7 @@ func (s3 *S3) GeneratePresignedURL(in PresignedInput) string { b.WriteRune('\n') b.WriteString(amzdate) b.WriteRune('\n') - b.Write(s3.buildCredentialWithoutKey(nowTime)) + b.Write(s3.buildCredentialWithoutKey(currentTime)) b.WriteRune('\n') hashed := hex.EncodeToString(h.Sum(nil)) @@ -189,7 +183,7 @@ func (s3 *S3) GeneratePresignedURL(in PresignedInput) string { makeHMac( makeHMac( []byte("AWS4"+s3.SecretKey), - []byte(nowTime.UTC().Format(shortTimeFormat))), + []byte(currentTime.UTC().Format(shortTimeFormat))), []byte(s3.Region)), []byte("s3")), []byte("aws4_request"), @@ -206,16 +200,7 @@ func (s3 *S3) GeneratePresignedURL(in PresignedInput) string { b.Reset() // Start Generating URL - if s3.Endpoint != "" { - b.WriteString(s3.Endpoint) - b.WriteRune('/') - b.WriteString(in.Bucket) - } else { - b.WriteString(protocol) - b.WriteString(hostname) - } - b.WriteRune('/') - b.WriteString(encodePath(in.ObjectKey)) + b.WriteString(address.urlString()) b.WriteRune('?') for i, k := range sortedQS { @@ -248,177 +233,16 @@ func (s3 *S3) GeneratePresignedUploadPartURL(in PresignedMultipartInput) string in.ExpirySeconds = 3600 } - if err := s3.renewIAMToken(); err != nil { - return "" - } - - var ( - nowTime = nowTime() - - protocol = defaultProtocol - hostname = defaultPresignedHost - path_prefix = "" + return s3.generatePresignedObjectURL( + "PUT", + in.Bucket, + in.ObjectKey, + time.Time{}, + in.ExpirySeconds, + nil, + map[string]string{ + "partNumber": strconv.Itoa(in.PartNumber), + "uploadId": in.UploadID, + }, ) - - amzdate := nowTime.Format(amzDateISO8601TimeFormat) - - // Create cred - b := bytes.Buffer{} - b.WriteString(s3.AccessKey) - b.WriteRune('/') - b.Write(s3.buildCredentialWithoutKey(nowTime)) - cred := b.Bytes() - b.Reset() - - // Set the protocol as default if not provided. - if endpoint, _ := url.Parse(s3.Endpoint); endpoint.Host != "" { - protocol = endpoint.Scheme + "://" - hostname = endpoint.Host - path_prefix = path.Join("/", endpoint.Path, in.Bucket) - } else { - host := bytes.Buffer{} - host.WriteString(in.Bucket) - host.WriteRune('.') - host.WriteString(hostname) - hostname = host.String() - } - - // Add host to Headers - // AWS requires header names to be lowercase per spec - signedHeaders := map[string][]byte{ - "host": []byte(hostname), - } - - // Build signed headers string - sortedSH := make([]string, 0, len(signedHeaders)) - for name := range signedHeaders { - sortedSH = append(sortedSH, name) - } - sort.Strings(sortedSH) - signedHeadersStr := strings.Join(sortedSH, ";") - - // For URL: header names must be individually escaped, semicolons remain raw - var signedHeadersForURL strings.Builder - for i, name := range sortedSH { - if i > 0 { - signedHeadersForURL.WriteRune(';') - } - signedHeadersForURL.WriteString(url.QueryEscape(name)) - } - - // Start Canonical Request Formation - h := sha256.New() - h.Write([]byte("PUT")) // Multipart uploads use PUT - h.Write(newLine) - h.Write([]byte(path_prefix)) - h.Write([]byte{'/'}) - h.Write([]byte(encodePath(in.ObjectKey))) - h.Write(newLine) - - // Start QueryString Params (before SignedHeaders) - queryString := map[string]string{ - "X-Amz-Algorithm": algorithm, - "X-Amz-Credential": string(cred), - "X-Amz-Date": amzdate, - "X-Amz-Expires": strconv.Itoa(in.ExpirySeconds), - HdrXAmzSignedHeaders: signedHeadersForURL.String(), - "partNumber": strconv.Itoa(in.PartNumber), - "uploadId": in.UploadID, - } - - // Include the x-amz-security-token in case we are using IAM role or AWS STS - if s3.Token != "" { - queryString["X-Amz-Security-Token"] = s3.Token - } - - // We need to have a sorted order for QueryStrings and SignedHeaders - sortedQS := make([]string, 0, len(queryString)) - for name := range queryString { - sortedQS = append(sortedQS, name) - } - sort.Strings(sortedQS) - - // Proceed to write canonical query params - for i, k := range sortedQS { - h.Write([]byte(awsURIEncode(k))) - h.Write([]byte{'='}) - // X-Amz-SignedHeaders already has properly formatted semicolons, retain as is. - h.Write([]byte(awsURIEncode(queryString[k]))) - if i < len(sortedQS)-1 { - h.Write([]byte{'&'}) - } - } - h.Write(newLine) - - // Start Canonical Headers - for i := 0; i < len(sortedSH); i++ { - h.Write([]byte(strings.ToLower(sortedSH[i]))) - h.Write([]byte{':'}) - h.Write([]byte(strings.TrimSpace(string(signedHeaders[sortedSH[i]])))) - h.Write(newLine) - } - h.Write(newLine) - - // Start Signed Headers - h.Write([]byte(signedHeadersStr)) - h.Write(newLine) - - // Mention Unsigned Payload - h.Write([]byte("UNSIGNED-PAYLOAD")) - - // Start StringToSign - b.WriteString(algorithm) - b.WriteRune('\n') - b.WriteString(amzdate) - b.WriteRune('\n') - b.Write(s3.buildCredentialWithoutKey(nowTime)) - b.WriteRune('\n') - - hashed := hex.EncodeToString(h.Sum(nil)) - b.WriteString(hashed) - - stringToSign := b.Bytes() - - // Start Signature Key - sigKey := makeHMac(makeHMac( - makeHMac( - makeHMac( - []byte("AWS4"+s3.SecretKey), - []byte(nowTime.UTC().Format(shortTimeFormat))), - []byte(s3.Region)), - []byte("s3")), - []byte("aws4_request"), - ) - - signedStrToSign := makeHMac(sigKey, stringToSign) - signature := hex.EncodeToString(signedStrToSign) - - // Reset Buffer to create URL - b.Reset() - - // Start Generating URL - if s3.Endpoint != "" { - b.WriteString(s3.Endpoint) - b.WriteRune('/') - b.WriteString(in.Bucket) - } else { - b.WriteString(protocol) - b.WriteString(hostname) - } - b.WriteRune('/') - b.WriteString(encodePath(in.ObjectKey)) - b.WriteRune('?') - - for i, k := range sortedQS { - b.WriteString(awsURIEncode(k)) - b.WriteRune('=') - b.WriteString(awsURIEncode(queryString[k])) - if i < len(sortedQS)-1 { - b.WriteRune('&') - } - } - b.WriteString("&X-Amz-Signature=") - b.WriteString(signature) - - return b.String() } diff --git a/simples3.go b/simples3.go index 192e4ee..9d775a3 100644 --- a/simples3.go +++ b/simples3.go @@ -26,11 +26,12 @@ type S3 struct { Region string Client *http.Client - Token string - Endpoint string - URIFormat string - initMode string - expiry time.Time + Token string + Endpoint string + URIFormat string + addressingMode addressingMode + initMode string + expiry time.Time mu sync.Mutex } @@ -92,6 +93,20 @@ func (s3 *S3) SetClient(client *http.Client) *S3 { return s3 } +// SetUsePathStyle explicitly selects S3 bucket addressing style. +// +// When left unset, simples3 preserves its legacy surface-specific defaults: +// runtime requests use path-style, while presigned URLs and direct upload +// policy defaults retain their historical behavior. +func (s3 *S3) SetUsePathStyle(use bool) *S3 { + if use { + s3.addressingMode = addressingModePath + } else { + s3.addressingMode = addressingModeVirtual + } + return s3 +} + func (s3 *S3) signRequest(req *http.Request) error { var ( err error