From fe0d968978e4fa26701bcd26d533232cd0d46d23 Mon Sep 17 00:00:00 2001 From: Josh Friend Date: Thu, 21 May 2026 09:38:59 -0400 Subject: [PATCH] Stamp delta bundles with base commit SHA and exclude stale artifact DBs Embed the base commit SHA in delta bundles to prevent cross-base contamination. On restore, write the hit commit to .cache-base-commit. SaveDelta reads it and stamps the delta via S3 metadata and a synthetic __base_commit__ tar entry. Restore checks both before applying. Also exclude module-artifact.bin and resource-at-url.bin from deltas to prevent stale resolution metadata from causing missing-jar failures. --- gradlecache/extract_default.go | 4 + gradlecache/ghacache.go | 2 +- gradlecache/gradlecache_test.go | 65 +++++++++++ gradlecache/restore.go | 26 +++++ gradlecache/s3.go | 19 ++- gradlecache/save.go | 200 +++++++++++++++++++++++++++++++- gradlecache/store.go | 37 +++--- 7 files changed, 331 insertions(+), 22 deletions(-) diff --git a/gradlecache/extract_default.go b/gradlecache/extract_default.go index 96a269d..caef296 100644 --- a/gradlecache/extract_default.go +++ b/gradlecache/extract_default.go @@ -225,6 +225,10 @@ func processEntry( return err } + if name == deltaBaseCommitEntry { + return nil + } + target := targetFn(name) switch hdr.Typeflag { diff --git a/gradlecache/ghacache.go b/gradlecache/ghacache.go index e7f886a..6948f16 100644 --- a/gradlecache/ghacache.go +++ b/gradlecache/ghacache.go @@ -329,7 +329,7 @@ func (g *ghaCacheStore) createAndFinalize(ctx context.Context, commit, cacheKey // put uploads a cache entry from a ReadSeeker of known size. // For small bundles (≤ 1 block), uses a single PUT. For larger bundles, // uses parallel Azure Block Blob upload (Put Block + Put Block List). -func (g *ghaCacheStore) put(ctx context.Context, commit, cacheKey string, r io.ReadSeeker, size int64) error { +func (g *ghaCacheStore) put(ctx context.Context, commit, cacheKey string, r io.ReadSeeker, size int64, _ map[string]string) error { return g.createAndFinalize(ctx, commit, cacheKey, size, func(signedURL string) error { if size <= ghaBlockSize { return g.azurePutSingle(ctx, signedURL, r, size) diff --git a/gradlecache/gradlecache_test.go b/gradlecache/gradlecache_test.go index 1deafb5..442ae50 100644 --- a/gradlecache/gradlecache_test.go +++ b/gradlecache/gradlecache_test.go @@ -59,6 +59,8 @@ func TestIsDeltaExcluded(t *testing.T) { excluded := []string{ "fileHashes", "module-metadata.bin", + "module-artifact.bin", + "resource-at-url.bin", } for _, name := range excluded { if !IsDeltaExcluded(name) { @@ -806,6 +808,67 @@ func TestDeltaTarZstdRoundTrip(t *testing.T) { } } +func TestStampedDeltaRoundTrip(t *testing.T) { + baseCommit := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + gradleHome := t.TempDir() + cachesDir := filepath.Join(gradleHome, "caches", "modules-2") + must(t, os.MkdirAll(cachesDir, 0o755)) + must(t, os.WriteFile(filepath.Join(cachesDir, "delta-file.bin"), []byte("delta"), 0o644)) + + // Create a stamped delta archive. + var buf bytes.Buffer + must(t, createStampedDeltaTarZstdMulti(&buf, baseCommit, + DeltaSource{BaseDir: gradleHome, RelPaths: []string{"caches/modules-2/delta-file.bin"}})) + + // ReadDeltaBaseCommit should return the embedded stamp. + r := bytes.NewReader(buf.Bytes()) + got, err := ReadDeltaBaseCommit(r) + must(t, err) + if got != baseCommit { + t.Fatalf("ReadDeltaBaseCommit = %q, want %q", got, baseCommit) + } + + // The file should still extract correctly (stamp entry is skipped). + dstDir := t.TempDir() + must(t, extractTarZstd(context.Background(), bytes.NewReader(buf.Bytes()), dstDir)) + + data, err := os.ReadFile(filepath.Join(dstDir, "caches", "modules-2", "delta-file.bin")) + must(t, err) + if string(data) != "delta" { + t.Fatalf("extracted content = %q, want %q", string(data), "delta") + } + + // __base_commit__ should NOT exist as a file on disk. + if _, err := os.Stat(filepath.Join(dstDir, deltaBaseCommitEntry)); err == nil { + t.Fatal("__base_commit__ should not be extracted as a real file") + } +} + +func TestReadDeltaBaseCommitMissing(t *testing.T) { + // An unstamped delta should return empty string. + var buf bytes.Buffer + must(t, CreateDeltaTarZstdMulti(&buf, + DeltaSource{BaseDir: t.TempDir(), RelPaths: nil})) + + r := bytes.NewReader(buf.Bytes()) + got, err := ReadDeltaBaseCommit(r) + must(t, err) + if got != "" { + t.Fatalf("ReadDeltaBaseCommit on unstamped delta = %q, want empty", got) + } +} + +func TestBaseCommitFileRoundTrip(t *testing.T) { + dir := t.TempDir() + sha := "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb" + must(t, writeBaseCommitFile(dir, sha)) + got, err := readBaseCommitFile(dir) + must(t, err) + if got != sha { + t.Fatalf("readBaseCommitFile = %q, want %q", got, sha) + } +} + func TestSaveDeltaDefaultsProjectDirToWorkingDirectory(t *testing.T) { ctx := context.Background() gradleHome := t.TempDir() @@ -817,6 +880,8 @@ func TestSaveDeltaDefaultsProjectDirToWorkingDirectory(t *testing.T) { markerPath := filepath.Join(gradleHome, ".cache-restore-marker") must(t, touchMarkerFile(markerPath)) + // Write a base commit file so SaveDelta can stamp the delta. + must(t, writeBaseCommitFile(gradleHome, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")) // Sleep to ensure files created below have a strictly newer mtime than // the marker. On Linux ext4 the mtime granularity is 1 ms, but rapid diff --git a/gradlecache/restore.go b/gradlecache/restore.go index f1d3538..726ba06 100644 --- a/gradlecache/restore.go +++ b/gradlecache/restore.go @@ -386,6 +386,15 @@ func Restore(ctx context.Context, cfg RestoreConfig) error { deltaCh <- deltaResult{} return } + // Check S3 metadata for base-commit mismatch before downloading. + if metaBase := deltaInfo.Metadata[deltaBaseCommitMetaKey]; metaBase != "" && metaBase != hitCommit { + log.Info("skipping delta: built on different base", + "delta_base", metaBase[:min(8, len(metaBase))], + "current_base", hitCommit[:min(8, len(hitCommit))]) + deltaCh <- deltaResult{} + return + } + log.Info("downloading delta bundle", "branch", cfg.Branch) dlStart := time.Now() body, err := store.get(ctx, dc, cfg.CacheKey, deltaInfo) @@ -412,6 +421,19 @@ func Restore(ctx context.Context, cfg RestoreConfig) error { deltaCh <- deltaResult{err: errors.Wrap(err, "rewind delta temp file")} return } + + // Check tar stamp for base-commit mismatch (covers stores without metadata). + tarBase, tarErr := ReadDeltaBaseCommit(tmp) + if tarErr == nil && tarBase != "" && tarBase != hitCommit { + log.Info("skipping delta: tar stamp shows different base", + "delta_base", tarBase[:min(8, len(tarBase))], + "current_base", hitCommit[:min(8, len(hitCommit))]) + tmp.Close() //nolint:errcheck,gosec + os.Remove(tmp.Name()) //nolint:errcheck,gosec + deltaCh <- deltaResult{} + return + } + deltaCh <- deltaResult{tmpFile: tmp, dlStart: dlStart, n: cb.n, eofAt: cb.eofAt} }() } @@ -489,6 +511,10 @@ func Restore(ctx context.Context, cfg RestoreConfig) error { cfg.Metrics.Distribution("gradle_cache.restore_base.speed_mbps", mbps, "cache_key:"+cfg.CacheKey) } + if err := writeBaseCommitFile(cfg.GradleUserHome, hitCommit); err != nil { + log.Warn("could not write base commit file", "err", err) + } + if err := touchMarkerFile(filepath.Join(cfg.GradleUserHome, ".cache-restore-marker")); err != nil { log.Warn("could not write restore marker", "err", err) } diff --git a/gradlecache/s3.go b/gradlecache/s3.go index c9bc692..d75c1a2 100644 --- a/gradlecache/s3.go +++ b/gradlecache/s3.go @@ -63,8 +63,9 @@ func newS3Client(region string) (*s3Client, error) { } type s3ObjInfo struct { - Size int64 - ETag string + Size int64 + ETag string + Metadata map[string]string } func (c *s3Client) stat(ctx context.Context, bucket, key string) (s3ObjInfo, error) { @@ -82,10 +83,20 @@ func (c *s3Client) stat(ctx context.Context, bucket, key string) (s3ObjInfo, err if resp.StatusCode != http.StatusOK { return s3ObjInfo{}, errors.Errorf("status %d", resp.StatusCode) } - return s3ObjInfo{ + info := s3ObjInfo{ Size: resp.ContentLength, ETag: resp.Header.Get("ETag"), - }, nil + } + const metaPrefix = "X-Amz-Meta-" + for k, vs := range resp.Header { + if len(vs) > 0 && len(k) > len(metaPrefix) && strings.EqualFold(k[:len(metaPrefix)], metaPrefix) { + if info.Metadata == nil { + info.Metadata = make(map[string]string) + } + info.Metadata[strings.ToLower(k[len(metaPrefix):])] = vs[0] + } + } + return info, nil } func (c *s3Client) get(ctx context.Context, bucket, key string, info s3ObjInfo) (io.ReadCloser, error) { diff --git a/gradlecache/save.go b/gradlecache/save.go index 380d3e5..76946d6 100644 --- a/gradlecache/save.go +++ b/gradlecache/save.go @@ -22,6 +22,69 @@ import ( "github.com/klauspost/compress/zstd" ) +const ( + // baseCommitFile is written to GRADLE_USER_HOME after a base restore so + // that delta save can stamp the delta with the base it was built on. + baseCommitFile = ".cache-base-commit" + + // deltaBaseCommitMetaKey is the S3 user-metadata key used to stamp a + // delta bundle with the base commit it was built against. + deltaBaseCommitMetaKey = "base-commit" + + // deltaBaseCommitEntry is a synthetic tar entry prepended to delta + // bundles carrying the base commit SHA. It is skipped during extraction. + deltaBaseCommitEntry = "__base_commit__" +) + +func writeBaseCommitFile(gradleHome, sha string) error { + path := filepath.Join(gradleHome, baseCommitFile) + return os.WriteFile(path, []byte(sha+"\n"), 0o600) +} + +func readBaseCommitFile(gradleHome string) (string, error) { + path := filepath.Join(gradleHome, baseCommitFile) + data, err := os.ReadFile(path) + if err != nil { + return "", err + } + sha := strings.TrimSpace(string(data)) + if !IsFullSHA(sha) { + return "", errors.Errorf("invalid SHA in %s: %q", path, sha) + } + return sha, nil +} + +// ReadDeltaBaseCommit reads the __base_commit__ stamp from a delta bundle. +// r must be a zstd-compressed tar. After reading, r is seeked back to the start. +// Returns ("", nil) if no stamp is found. +func ReadDeltaBaseCommit(r io.ReadSeeker) (string, error) { + dec, err := zstd.NewReader(r) + if err != nil { + return "", err + } + tr := tar.NewReader(dec) + hdr, err := tr.Next() + if err != nil { + dec.Close() + if _, seekErr := r.Seek(0, io.SeekStart); seekErr != nil { + return "", seekErr + } + return "", nil + } + var sha string + if hdr.Name == deltaBaseCommitEntry { + data, readErr := io.ReadAll(tr) + if readErr == nil { + sha = strings.TrimSpace(string(data)) + } + } + dec.Close() + if _, seekErr := r.Seek(0, io.SeekStart); seekErr != nil { + return "", seekErr + } + return sha, nil +} + // RestoreDeltaConfig holds the parameters for a delta restore operation. type RestoreDeltaConfig struct { Bucket string @@ -80,6 +143,22 @@ func RestoreDelta(ctx context.Context, cfg RestoreDeltaConfig) error { } log.Info("found delta bundle", "branch", cfg.Branch, "cache-key", cfg.CacheKey) + // Check metadata-level base-commit stamp against local base. + localBase, localBaseErr := readBaseCommitFile(cfg.GradleUserHome) + metaBase := deltaInfo.Metadata[deltaBaseCommitMetaKey] + if metaBase != "" && localBaseErr != nil { + // Delta is stamped but we don't know what base we restored. + // Skip to avoid cross-base contamination. + log.Info("skipping stamped delta: local base commit unknown", "err", localBaseErr) + return nil + } + if metaBase != "" && metaBase != localBase { + log.Info("skipping delta: built on different base", + "delta_base", metaBase[:min(8, len(metaBase))], + "current_base", localBase[:min(8, len(localBase))]) + return nil + } + dlStart := time.Now() body, err := store.get(ctx, dc, cfg.CacheKey, deltaInfo) if err != nil { @@ -92,6 +171,62 @@ func RestoreDelta(ctx context.Context, cfg RestoreDeltaConfig) error { pdSources = ProjectDirSources(cfg.ProjectDir, cfg.IncludedBuilds) } + // When metadata is unavailable (e.g. cachew), buffer the delta and check + // the embedded tar stamp before extracting. + if metaBase == "" && localBaseErr == nil { + tmp, tmpErr := os.CreateTemp("", "gradle-cache-delta-check-*") + if tmpErr != nil { + return errors.Wrap(tmpErr, "create delta temp file") + } + defer func() { + tmp.Close() //nolint:errcheck,gosec + os.Remove(tmp.Name()) //nolint:errcheck,gosec + }() + cb := &countingBody{r: body, dlStart: dlStart} + if _, err := io.Copy(tmp, cb); err != nil { + return errors.Wrap(err, "buffer delta bundle") + } + if _, err := tmp.Seek(0, io.SeekStart); err != nil { + return errors.Wrap(err, "rewind delta temp file") + } + tarBase, tarErr := ReadDeltaBaseCommit(tmp) + if tarErr == nil && tarBase != "" { + if localBaseErr != nil { + log.Info("skipping stamped delta: local base commit unknown", "err", localBaseErr) + return nil + } + if tarBase != localBase { + log.Info("skipping delta: tar stamp shows different base", + "delta_base", tarBase[:min(8, len(tarBase))], + "current_base", localBase[:min(8, len(localBase))]) + return nil + } + } + if !cb.eofAt.IsZero() { + dlElapsed := cb.eofAt.Sub(dlStart) + log.Info("delta download complete", "duration", dlElapsed.Round(time.Millisecond), + "size_mb", fmt.Sprintf("%.1f", float64(cb.n)/1e6), + "speed_mbps", fmt.Sprintf("%.1f", float64(cb.n)/dlElapsed.Seconds()/1e6)) + } + if err := extractDeltaTarZstd(ctx, tmp, cfg.GradleUserHome, pdSources); err != nil { + return errors.Wrap(err, "extract delta bundle") + } + deltaElapsed := time.Since(dlStart) + log.Info("applied delta bundle", "branch", cfg.Branch, "cache-key", cfg.CacheKey, + "total_duration", deltaElapsed.Round(time.Millisecond)) + cfg.Metrics.Distribution("gradle_cache.restore_delta.duration_ms", float64(deltaElapsed.Milliseconds()), + "cache_key:"+cfg.CacheKey) + cfg.Metrics.Distribution("gradle_cache.restore_delta.size_bytes", float64(cb.n), + "cache_key:"+cfg.CacheKey) + if !cb.eofAt.IsZero() { + dlElapsed := cb.eofAt.Sub(dlStart) + mbps := float64(cb.n) / dlElapsed.Seconds() / 1e6 + cfg.Metrics.Distribution("gradle_cache.restore_delta.speed_mbps", mbps, + "cache_key:"+cfg.CacheKey) + } + return nil + } + cb := &countingBody{r: body, dlStart: dlStart} if err := extractDeltaTarZstd(ctx, cb, cfg.GradleUserHome, pdSources); err != nil { return errors.Wrap(err, "extract delta bundle") @@ -383,8 +518,15 @@ func SaveDelta(ctx context.Context, cfg SaveDeltaConfig) error { log.Info("saving delta bundle", "branch", cfg.Branch, "cache-key", cfg.CacheKey, "files", totalFiles) saveStart := time.Now() + // Read the base commit that this delta was built on top of. + baseCommit, err := readBaseCommitFile(cfg.GradleUserHome) + if err != nil { + log.Warn("could not read base commit file, skipping delta save", "err", err) + return nil + } + sources := append([]DeltaSource{{BaseDir: cfg.GradleUserHome, RelPaths: newFiles}}, projectSources...) - if err := CreateDeltaTarZstdMulti(tmp, sources...); err != nil { + if err := createStampedDeltaTarZstdMulti(tmp, baseCommit, sources...); err != nil { return errors.Wrap(err, "create delta archive") } @@ -396,7 +538,8 @@ func SaveDelta(ctx context.Context, cfg SaveDeltaConfig) error { return errors.Wrap(err, "rewind delta bundle") } - if err := store.put(ctx, dc, cfg.CacheKey, tmp, size); err != nil { + meta := map[string]string{deltaBaseCommitMetaKey: baseCommit} + if err := store.put(ctx, dc, cfg.CacheKey, tmp, size, meta); err != nil { return errors.Wrap(err, "upload delta bundle") } @@ -509,6 +652,18 @@ var DeltaExclusions = []string{ // across machines but rewritten on every build due to DB compaction. The base // bundle already has it and a single build rarely adds new dependencies. "module-metadata.bin", + + // module-artifact.bin caches whether artifacts have been downloaded from + // remote repositories. Rewritten every build due to DB compaction. If a + // stale copy tells Gradle an artifact is already cached when the local file + // doesn't exist (e.g. guava resolved under a different version directory + // in the base), Gradle skips the download and the build fails. + "module-artifact.bin", + + // resource-at-url.bin caches HTTP responses for repository resource URLs. + // Rewritten every build due to DB compaction. A stale copy can cause Gradle + // to use outdated repository metadata, leading to incorrect resolution. + "resource-at-url.bin", } // wrapperZipExclusion excludes the downloaded Gradle distribution zip from @@ -778,6 +933,47 @@ func CreateDeltaTarZstdMulti(w io.Writer, sources ...DeltaSource) error { return errors.Join(tarErr, encErr) } +// createStampedDeltaTarZstdMulti is like CreateDeltaTarZstdMulti but prepends a +// synthetic __base_commit__ tar entry containing the base commit SHA. +func createStampedDeltaTarZstdMulti(w io.Writer, baseCommit string, sources ...DeltaSource) error { + enc, err := zstd.NewWriter(w, + zstd.WithEncoderConcurrency(runtime.GOMAXPROCS(0))) + if err != nil { + return errors.Wrap(err, "create zstd encoder") + } + + tw := tar.NewWriter(enc) + + // Write synthetic stamp entry. + content := []byte(baseCommit + "\n") + if err := tw.WriteHeader(&tar.Header{ + Name: deltaBaseCommitEntry, + Size: int64(len(content)), + Mode: 0o644, + }); err != nil { + enc.Close() //nolint:errcheck,gosec + return errors.Wrap(err, "write base-commit tar header") + } + if _, err := tw.Write(content); err != nil { + enc.Close() //nolint:errcheck,gosec + return errors.Wrap(err, "write base-commit tar content") + } + + // Write all real delta entries. + for _, src := range sources { + for _, rel := range src.RelPaths { + if err := writeDeltaTarEntry(tw, src.BaseDir, rel); err != nil { + enc.Close() //nolint:errcheck,gosec + return err + } + } + } + + tarErr := tw.Close() + encErr := enc.Close() + return errors.Join(tarErr, encErr) +} + // WriteDeltaTar writes a tar stream for the specified files to w. func WriteDeltaTar(w io.Writer, baseDir string, relPaths []string) error { return WriteDeltaTarMulti(w, DeltaSource{BaseDir: baseDir, RelPaths: relPaths}) diff --git a/gradlecache/store.go b/gradlecache/store.go index 809c0f8..b969318 100644 --- a/gradlecache/store.go +++ b/gradlecache/store.go @@ -21,15 +21,16 @@ import ( // bundleStatInfo holds opaque metadata returned by bundleStore.stat(). type bundleStatInfo struct { - Size int64 - etag string + Size int64 + etag string + Metadata map[string]string // user-defined metadata (e.g. base-commit for deltas) } // bundleStore abstracts over S3 and cachew as storage backends. type bundleStore interface { stat(ctx context.Context, commit, cacheKey string) (bundleStatInfo, error) get(ctx context.Context, commit, cacheKey string, info bundleStatInfo) (io.ReadCloser, error) - put(ctx context.Context, commit, cacheKey string, r io.ReadSeeker, size int64) error + put(ctx context.Context, commit, cacheKey string, r io.ReadSeeker, size int64, meta map[string]string) error putStream(ctx context.Context, commit, cacheKey string, r io.Reader) (int64, error) } @@ -62,15 +63,15 @@ func (s *s3BundleStore) stat(ctx context.Context, commit, cacheKey string) (bund if err != nil { return bundleStatInfo{}, err } - return bundleStatInfo{Size: obj.Size, etag: obj.ETag}, nil + return bundleStatInfo{Size: obj.Size, etag: obj.ETag, Metadata: obj.Metadata}, nil } func (s *s3BundleStore) get(ctx context.Context, commit, cacheKey string, info bundleStatInfo) (io.ReadCloser, error) { return s.client.get(ctx, s.bucket, s3Key(s.keyPrefix, commit, cacheKey, bundleFilename(cacheKey)), s3ObjInfo{Size: info.Size, ETag: info.etag}) } -func (s *s3BundleStore) put(ctx context.Context, commit, cacheKey string, r io.ReadSeeker, size int64) error { - return s.client.put(ctx, s.bucket, s3Key(s.keyPrefix, commit, cacheKey, bundleFilename(cacheKey)), r, size, "application/zstd") +func (s *s3BundleStore) put(ctx context.Context, commit, cacheKey string, r io.ReadSeeker, size int64, meta map[string]string) error { + return s.client.put(ctx, s.bucket, s3Key(s.keyPrefix, commit, cacheKey, bundleFilename(cacheKey)), r, size, "application/zstd", meta) } func (s *s3BundleStore) putStream(ctx context.Context, commit, cacheKey string, r io.Reader) (int64, error) { @@ -148,7 +149,7 @@ func (c *cachewClient) get(ctx context.Context, commit, cacheKey string, _ bundl return resp.Body, nil } -func (c *cachewClient) put(ctx context.Context, commit, cacheKey string, r io.ReadSeeker, size int64) error { +func (c *cachewClient) put(ctx context.Context, commit, cacheKey string, r io.ReadSeeker, size int64, _ map[string]string) error { req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.objectURL(commit, cacheKey), r) if err != nil { return err @@ -207,11 +208,11 @@ const ( uploadWorkers = 8 ) -func (c *s3Client) put(ctx context.Context, bucket, key string, r io.ReadSeeker, size int64, contentType string) error { +func (c *s3Client) put(ctx context.Context, bucket, key string, r io.ReadSeeker, size int64, contentType string, meta map[string]string) error { if size <= uploadPartSize { - return c.putSingle(ctx, bucket, key, r, size, contentType) + return c.putSingle(ctx, bucket, key, r, size, contentType, meta) } - return c.putMultipart(ctx, bucket, key, r, size, contentType) + return c.putMultipart(ctx, bucket, key, r, size, contentType, meta) } // crc64Of computes a CRC64-NVME checksum of the data from r, then seeks back. @@ -233,7 +234,7 @@ func crc64Base64(h hash.Hash64) string { return base64.StdEncoding.EncodeToString(buf[:]) } -func (c *s3Client) putSingle(ctx context.Context, bucket, key string, r io.ReadSeeker, size int64, contentType string) error { +func (c *s3Client) putSingle(ctx context.Context, bucket, key string, r io.ReadSeeker, size int64, contentType string, meta map[string]string) error { checksum, err := crc64Of(r) if err != nil { return errors.Wrap(err, "compute CRC64-NVME checksum") @@ -248,6 +249,9 @@ func (c *s3Client) putSingle(ctx context.Context, bucket, key string, r io.ReadS } req.Header.Set("X-Amz-Checksum-Crc64nvme", checksum) req.Header.Set("X-Amz-Checksum-Algorithm", "CRC64NVME") + for k, v := range meta { + req.Header.Set("X-Amz-Meta-"+k, v) + } c.sign(req) resp, err := c.http.Do(req) //nolint:gosec if err != nil { @@ -261,8 +265,8 @@ func (c *s3Client) putSingle(ctx context.Context, bucket, key string, r io.ReadS return nil } -func (c *s3Client) putMultipart(ctx context.Context, bucket, key string, r io.ReadSeeker, size int64, contentType string) error { - uploadID, err := c.createMultipartUpload(ctx, bucket, key, contentType) +func (c *s3Client) putMultipart(ctx context.Context, bucket, key string, r io.ReadSeeker, size int64, contentType string, meta map[string]string) error { + uploadID, err := c.createMultipartUpload(ctx, bucket, key, contentType, meta) if err != nil { return err } @@ -329,7 +333,7 @@ func (c *s3Client) putMultipart(ctx context.Context, bucket, key string, r io.Re return c.completeMultipartUpload(ctx, bucket, key, uploadID, parts) } -func (c *s3Client) createMultipartUpload(ctx context.Context, bucket, key, contentType string) (string, error) { +func (c *s3Client) createMultipartUpload(ctx context.Context, bucket, key, contentType string, meta map[string]string) (string, error) { u := c.objectURL(bucket, key) + "?uploads" req, err := http.NewRequestWithContext(ctx, http.MethodPost, u, nil) if err != nil { @@ -339,6 +343,9 @@ func (c *s3Client) createMultipartUpload(ctx context.Context, bucket, key, conte req.Header.Set("Content-Type", contentType) } req.Header.Set("X-Amz-Checksum-Algorithm", "CRC64NVME") + for k, v := range meta { + req.Header.Set("X-Amz-Meta-"+k, v) + } c.sign(req) resp, err := c.http.Do(req) //nolint:gosec if err != nil { @@ -440,7 +447,7 @@ func (c *s3Client) abortMultipartUpload(ctx context.Context, bucket, key, upload } func (c *s3Client) putStreamingMultipart(ctx context.Context, bucket, key string, r io.Reader, contentType string) (int64, error) { - uploadID, err := c.createMultipartUpload(ctx, bucket, key, contentType) + uploadID, err := c.createMultipartUpload(ctx, bucket, key, contentType, nil) if err != nil { return 0, err }