diff --git a/pkg/e2e/e2e.go b/pkg/e2e/e2e.go index 9268bd5..91aa945 100644 --- a/pkg/e2e/e2e.go +++ b/pkg/e2e/e2e.go @@ -88,14 +88,16 @@ func (e *E2E) Do(rewrite func(r *http.Request)) (*http.Response, error) { // wait for a while to let the server ready time.Sleep(time.Millisecond * 100) - rewrite(e.req) + nr := e.req.Clone(context.Background()) - method := e.req.Method + rewrite(nr) - e.req.Header.Set(protocol.InternalUpstreamAddr, e.ts.Listener.Addr().String()) + method := nr.Method + + nr.Header.Set(protocol.InternalUpstreamAddr, e.ts.Listener.Addr().String()) if dumpReq.Load() && method != "PURGE" { - DumpReq(e.req) + DumpReq(nr) } if manual.Load() { @@ -103,7 +105,7 @@ func (e *E2E) Do(rewrite func(r *http.Request)) (*http.Response, error) { time.Sleep(time.Second * 20) } - resp, err := e.cs.Do(e.req) + resp, err := e.cs.Do(nr) e.resp = resp e.err = err diff --git a/server/middleware/caching/caching.go b/server/middleware/caching/caching.go index 329f785..9335546 100644 --- a/server/middleware/caching/caching.go +++ b/server/middleware/caching/caching.go @@ -112,6 +112,12 @@ func Middleware(c *configv1.Middleware) (middleware.Middleware, func(), error) { proxyClient := proxy.GetProxy() store := storagev1.Current() + // Flight groups for collapsed forwarding at object and chunk level. + // These mirror Squid's collapsed_forwarding: one origin request + // serves many waiting clients. + objectFlight := &ObjectFlightGroup{} + chunkFlight := &ChunkFlightGroup{} + return middleware.RoundTripperFunc(func(req *http.Request) (resp *http.Response, err error) { // only cache GET/HEAD request if req.Method != http.MethodGet && req.Method != http.MethodHead { @@ -130,6 +136,9 @@ func Middleware(c *configv1.Middleware) (middleware.Middleware, func(), error) { // cachingPool.Put(caching) //}() + // Wire up chunk-level collapsed forwarding. 
+ caching.chunkFlight = chunkFlight + // err to BYPASS caching if err != nil { caching.log.Warnf("Precache processor failed: %v BYPASS", err) @@ -147,34 +156,28 @@ func Middleware(c *configv1.Middleware) (middleware.Middleware, func(), error) { // cache HIT if caching.hit { - caching.cacheStatus = storage.CacheHit - - rng, err1 := xhttp.SingleRange(req.Header.Get("Range"), caching.md.Size) - if err1 != nil { - // 无效 Range 处理 - headers := make(http.Header) - xhttp.CopyHeader(caching.md.Headers, headers) - headers.Set("Content-Range", fmt.Sprintf("bytes */%d", caching.md.Size)) - return nil, xhttp.NewBizError(http.StatusRequestedRangeNotSatisfiable, headers) - } - - // mark cache status with Range requests. - caching.markCacheStatus(rng.Start, rng.End) + return caching.respondFromCache(req) + } - // find file seek(start, end) - resp, err = caching.lazilyRespond(req, rng.Start, rng.End) - if err != nil { - // fd leak - closeBody(resp) - return nil, err + // full MISS — use object-level collapsed forwarding so that + // concurrent requests for the same cache object share one + // origin fetch (Squid-style collapsed_forwarding). + if opts.CollapsedRequest { + flightResp, _, flightErr := objectFlight.Do(caching.id.HashStr(), opts.CollapsedRequestWaitTimeout.AsDuration(), func() (*http.Response, error) { + r, e := caching.doProxy(req, false) + if e != nil { + return nil, e + } + return processor.postCacheProcessor(caching, req, r) + }) + if flightErr != nil { + return nil, flightErr } - - // response now - resp, err = caching.processor.postCacheProcessor(caching, req, resp) + resp = flightResp return } - // full MISS + // full MISS (collapsed forwarding disabled) resp, err = caching.doProxy(req, false) if err != nil { return nil, err @@ -187,6 +190,31 @@ func Middleware(c *configv1.Middleware) (middleware.Middleware, func(), error) { }, middleware.EmptyCleanup, nil } +// respondFromCache assembles a response from cached chunks for a cache HIT. 
+// It parses the Range header, builds a multi-part reader from disk, and +// runs post-cache processing (headers, cache status, store). +func (c *Caching) respondFromCache(req *http.Request) (*http.Response, error) { + c.cacheStatus = storage.CacheHit + + rng, err := xhttp.SingleRange(req.Header.Get("Range"), c.md.Size) + if err != nil { + headers := make(http.Header) + xhttp.CopyHeader(c.md.Headers, headers) + headers.Set("Content-Range", fmt.Sprintf("bytes */%d", c.md.Size)) + return nil, xhttp.NewBizError(http.StatusRequestedRangeNotSatisfiable, headers) + } + + c.markCacheStatus(rng.Start, rng.End) + + resp, err := c.lazilyRespond(req, rng.Start, rng.End) + if err != nil { + closeBody(resp) + return nil, err + } + + return c.processor.postCacheProcessor(c, req, resp) +} + func (c *Caching) lazilyRespond(req *http.Request, start, end int64) (*http.Response, error) { // 这里通过缓存的块大小来计算,而不是配置默认的 SliceSize // 这样已缓存的对象可以使用原来的配置块大小,不受配置文件变更影响 @@ -271,9 +299,6 @@ func (c *Caching) getUpstreamReader(fromByte, toByte uint64, async bool) (io.Rea closeBody(resp) return nil, err } - // 部分命中 - c.cacheStatus = storage.CachePartHit - // 发起的是 206 请求,但是返回的非 206 if resp.StatusCode != http.StatusPartialContent { c.log.Warnf("getUpstreamReader doProxy[chunk]: status code: %d, bod size: %d", resp.StatusCode, resp.ContentLength) return resp, xhttp.NewBizError(resp.StatusCode, resp.Header) @@ -281,6 +306,19 @@ func (c *Caching) getUpstreamReader(fromByte, toByte uint64, async bool) (io.Rea return resp, nil } + // Chunk-level collapsed forwarding: if another goroutine is already + // fetching the same byte range for this object, wait and share the + // response body (io.MultiWriter fan-out). This mirrors Squid's + // collapsed_forwarding at the chunk/segment level. 
// chunkCall is the bookkeeping record for one in-flight chunk fetch.
type chunkCall struct {
	// pipes holds the write half of every caller registered on this flight.
	// It is only appended to while ChunkFlightGroup.mu is held and the call
	// is still present in the group's map.
	pipes []*io.PipeWriter
}

// chunkFlightError is a constant-friendly error type for flight failures
// that do not originate from fn.
type chunkFlightError string

// Error implements the error interface.
func (e chunkFlightError) Error() string { return string(e) }

// errChunkFlightNilResponse is delivered to every reader when fn returns
// (nil, nil); without it the nil response would be dereferenced.
const errChunkFlightNilResponse = chunkFlightError("chunk flight: fn returned neither a response nor an error")

// ChunkFlightGroup collapses concurrent upstream requests for the same
// (object, byte-range) key into a single origin fetch. The response body is
// streamed to every registered caller through its own io.Pipe.
//
// The zero value is ready to use. A ChunkFlightGroup must not be copied
// after first use because it contains a sync.Mutex.
type ChunkFlightGroup struct {
	mu sync.Mutex
	m  map[string]*chunkCall
}

// Do executes fn once per key. Every caller, including the first, receives
// the read half of a pipe carrying the upstream response body. The returned
// bool reports whether this caller joined a flight that was already
// registered. The returned error is always nil; failures are reported through
// the reader.
//
// waiter is how long the fetch goroutine pauses before calling fn, giving
// late callers a window to register under the same key.
//
// Body ownership: Do reads and closes resp.Body. It also closes the body when
// fn returns a non-nil response together with an error, so fn may return
// (resp, err) without leaking the connection.
//
// Delivery: a caller that closes its reader early is dropped from the flight
// and the remaining callers keep receiving data. Pipes are unbuffered, so a
// caller that neither reads nor closes its reader stalls the whole flight;
// callers must always drain or close the returned reader.
func (g *ChunkFlightGroup) Do(key string, waiter time.Duration, fn func() (*http.Response, error)) (io.ReadCloser, bool, error) {
	pr, pw := io.Pipe()

	g.mu.Lock()
	if g.m == nil {
		g.m = make(map[string]*chunkCall)
	}
	if c, ok := g.m[key]; ok {
		c.pipes = append(c.pipes, pw)
		g.mu.Unlock()
		return pr, true, nil
	}

	c := &chunkCall{pipes: []*io.PipeWriter{pw}}
	g.m[key] = c
	g.mu.Unlock()

	go g.fetch(key, c, waiter, fn)

	return pr, false, nil
}

// fetch runs fn for the flight registered under key and streams the result
// to every pipe that registered before fn returned.
func (g *ChunkFlightGroup) fetch(key string, c *chunkCall, waiter time.Duration, fn func() (*http.Response, error)) {
	// Pause before hitting origin so concurrent callers have time to
	// register under this key.
	if waiter > 0 {
		time.Sleep(waiter)
	}

	resp, err := fn()

	// Remove the key so no further callers register against this flight.
	// Once the entry is gone nothing appends to c.pipes any more, so the
	// slice can be used without holding the lock.
	g.mu.Lock()
	delete(g.m, key)
	pipes := c.pipes
	g.mu.Unlock()

	if err == nil && resp == nil {
		err = errChunkFlightNilResponse
	}
	if err != nil {
		// Close the body first so it is released by the time any reader
		// observes the error.
		if resp != nil && resp.Body != nil {
			_ = resp.Body.Close()
		}
		for _, p := range pipes {
			_ = p.CloseWithError(err)
		}
		return
	}

	var copyErr error
	if resp.Body != nil {
		copyErr = fanOutChunk(pipes, resp.Body)
		_ = resp.Body.Close()
	}
	// CloseWithError(nil) is a plain Close: readers then see io.EOF.
	for _, p := range pipes {
		_ = p.CloseWithError(copyErr)
	}
}

// fanOutChunk copies src to every pipe. A pipe whose reader has gone away is
// dropped instead of aborting the copy for everyone else. It returns nil on a
// clean EOF, the read error otherwise, or io.ErrClosedPipe when no reader is
// left to receive data.
func fanOutChunk(pipes []*io.PipeWriter, src io.Reader) error {
	live := append([]*io.PipeWriter(nil), pipes...)
	buf := make([]byte, 32*1024)

	for {
		n, readErr := src.Read(buf)
		if n > 0 {
			kept := live[:0]
			for _, p := range live {
				if _, writeErr := p.Write(buf[:n]); writeErr == nil {
					kept = append(kept, p)
				}
			}
			live = kept
			if len(live) == 0 {
				return io.ErrClosedPipe
			}
		}
		if readErr == io.EOF {
			return nil
		}
		if readErr != nil {
			return readErr
		}
	}
}
+ time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + if callCount.Load() != 1 { + t.Fatalf("expected 1 call, got %d", callCount.Load()) + } + + sharedCount := 0 + for _, r := range results { + if r.shared { + sharedCount++ + } + if r.data != "chunk-data" { + t.Errorf("got %q, want %q", r.data, "chunk-data") + } + } + if sharedCount != 2 { + t.Errorf("expected 2 shared callers, got %d", sharedCount) + } +} + +func TestChunkFlight_ErrorPropagation(t *testing.T) { + g := &ChunkFlightGroup{} + + fn := func() (*http.Response, error) { + return nil, errors.New("upstream timeout") + } + + var wg sync.WaitGroup + start := make(chan struct{}) + errs := make([]error, 3) + + for i := 0; i < 3; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start + r, _, err := g.Do("obj1:0-1048575", 50*time.Millisecond, fn) + if err != nil { + errs[idx] = err + return + } + _, readErr := io.ReadAll(r) + _ = r.Close() + if readErr != nil { + errs[idx] = readErr + } + }(i) + } + + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + for i, err := range errs { + if err == nil { + t.Errorf("caller %d: expected error, got nil", i) + } + } +} + +func TestChunkFlight_KeyIsolation(t *testing.T) { + g := &ChunkFlightGroup{} + var callCount atomic.Int32 + + makeFn := func(data string) func() (*http.Response, error) { + return func() (*http.Response, error) { + callCount.Add(1) + return &http.Response{ + StatusCode: http.StatusPartialContent, + Body: io.NopCloser(strings.NewReader(data)), + }, nil + } + } + + var wg sync.WaitGroup + results := make(map[string]string, 4) + var mu sync.Mutex + + keys := []string{"obj1:0-1048575", "obj1:1048576-2097151", "obj2:0-1048575", "obj2:1048576-2097151"} + for _, key := range keys { + wg.Add(1) + go func(k string) { + defer wg.Done() + r, _, err := g.Do(k, 10*time.Millisecond, makeFn(k)) + if err != nil { + t.Errorf("key %s: unexpected error: %v", k, err) + return + } + data, _ := io.ReadAll(r) + _ = r.Close() + mu.Lock() + 
results[k] = string(data) + mu.Unlock() + }(key) + } + wg.Wait() + + if callCount.Load() != 4 { + t.Fatalf("expected 4 distinct calls, got %d", callCount.Load()) + } + for _, key := range keys { + if results[key] != key { + t.Errorf("key %s: got %q, want %q", key, results[key], key) + } + } +} + +func TestChunkFlight_ConcurrentSameKey(t *testing.T) { + g := &ChunkFlightGroup{} + var callCount atomic.Int32 + + fn := func() (*http.Response, error) { + callCount.Add(1) + return &http.Response{ + StatusCode: http.StatusPartialContent, + Body: io.NopCloser(bytes.NewReader(makebuf(1 << 18))), + }, nil + } + + var wg sync.WaitGroup + start := make(chan struct{}) + const numCallers = 10 + sharedCount := atomic.Int32{} + + for i := 0; i < numCallers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + r, shared, err := g.Do("same-key", 100*time.Millisecond, fn) + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + _, _ = io.ReadAll(r) + _ = r.Close() + if shared { + sharedCount.Add(1) + } + }() + } + + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + if callCount.Load() != 1 { + t.Fatalf("expected exactly 1 origin call, got %d", callCount.Load()) + } + if sharedCount.Load() != numCallers-1 { + t.Fatalf("expected %d shared callers, got %d", numCallers-1, sharedCount.Load()) + } +} + +// --------------------------------------------------------------------------- +// ObjectFlightGroup tests +// --------------------------------------------------------------------------- + +func TestObjectFlight_BasicCollapse(t *testing.T) { + g := &ObjectFlightGroup{} + var callCount atomic.Int32 + + fn := func() (*http.Response, error) { + callCount.Add(1) + time.Sleep(30 * time.Millisecond) // simulate origin latency + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("response-body")), + }, nil + } + + var wg sync.WaitGroup + start := make(chan struct{}) + bodies := make([]string, 5) + shareds := 
make([]bool, 5) + + for i := 0; i < 5; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start + resp, shared, err := g.Do("cache-key-1", 50*time.Millisecond, fn) + if err != nil { + t.Errorf("caller %d: unexpected error: %v", idx, err) + return + } + shareds[idx] = shared + body, readErr := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if readErr != nil { + t.Errorf("caller %d: read error: %v", idx, readErr) + return + } + bodies[idx] = string(body) + }(i) + } + + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + if callCount.Load() != 1 { + t.Fatalf("expected 1 call, got %d", callCount.Load()) + } + nonShared := 0 + shared := 0 + for _, s := range shareds { + if s { + shared++ + } else { + nonShared++ + } + } + if nonShared != 1 { + t.Errorf("expected 1 non-shared caller, got %d", nonShared) + } + if shared != 4 { + t.Errorf("expected 4 shared callers, got %d", shared) + } + for i, b := range bodies { + if b != "response-body" { + t.Errorf("caller %d: body = %q, want %q", i, b, "response-body") + } + } +} + +func TestObjectFlight_ErrorPropagation(t *testing.T) { + g := &ObjectFlightGroup{} + var callCount atomic.Int32 + + testErr := errors.New("origin connection refused") + fn := func() (*http.Response, error) { + callCount.Add(1) + time.Sleep(30 * time.Millisecond) // window for dup callers to register + return nil, testErr + } + + var wg sync.WaitGroup + start := make(chan struct{}) + errs := make([]error, 3) + + for i := 0; i < 3; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start + _, _, err := g.Do("cache-key-err", 50*time.Millisecond, fn) + errs[idx] = err + }(i) + } + + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + if callCount.Load() != 1 { + t.Fatalf("expected 1 call, got %d", callCount.Load()) + } + for i, err := range errs { + if !errors.Is(err, testErr) { + t.Errorf("caller %d: got %v, want %v", i, err, testErr) + } + } +} + +func TestObjectFlight_KeyIsolation(t *testing.T) { + g := 
&ObjectFlightGroup{} + var callCount atomic.Int32 + + fn := func() (*http.Response, error) { + callCount.Add(1) + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("body")), + }, nil + } + + var wg sync.WaitGroup + for _, key := range []string{"key-a", "key-b", "key-c"} { + wg.Add(1) + go func(k string) { + defer wg.Done() + resp, _, err := g.Do(k, 0, fn) + if err != nil { + t.Errorf("key %s: unexpected error: %v", k, err) + return + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }(key) + } + wg.Wait() + + if callCount.Load() != 3 { + t.Fatalf("expected 3 distinct calls, got %d", callCount.Load()) + } +} + +func TestObjectFlight_SequentialReuse(t *testing.T) { + g := &ObjectFlightGroup{} + var callCount atomic.Int32 + + fn := func() (*http.Response, error) { + callCount.Add(1) + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("body")), + }, nil + } + + resp, _, err := g.Do("seq-key", 0, fn) + if err != nil { + t.Fatalf("first call: unexpected error: %v", err) + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + if callCount.Load() != 1 { + t.Fatalf("first call: expected 1, got %d", callCount.Load()) + } + + time.Sleep(10 * time.Millisecond) + + resp, _, err = g.Do("seq-key", 0, fn) + if err != nil { + t.Fatalf("second call: unexpected error: %v", err) + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + if callCount.Load() != 2 { + t.Fatalf("sequential call: expected 2, got %d", callCount.Load()) + } +} diff --git a/server/middleware/caching/internal.go b/server/middleware/caching/internal.go index 9e2d996..5681083 100644 --- a/server/middleware/caching/internal.go +++ b/server/middleware/caching/internal.go @@ -45,6 +45,7 @@ type Caching struct { rootmd *object.Metadata bucket storage.Bucket proxyClient proxy.Proxy + chunkFlight *ChunkFlightGroup cacheStatus storage.CacheStatus cacheable bool hit bool diff --git a/server/middleware/caching/object_flight.go 
// objectFlightCall represents one in-flight full-object origin fetch.
//
// The response body is fanned out to every registered caller through its own
// io.Pipe, so reading the leader's body feeds all callers at once and no
// cache re-lookup is needed.
type objectFlightCall struct {
	resp  *http.Response   // leader's response; valid after wg is released when err is nil
	pipes []*io.PipeWriter // write half of every registered caller's pipe
	mu    sync.Mutex       // protects pipes during registration and snapshot
	wg    sync.WaitGroup   // released once resp / err have been published
	err   error            // leader's failure; valid after wg is released
}

// objectFlightError is a constant-friendly error type for flight failures
// that do not originate from fn.
type objectFlightError string

// Error implements the error interface.
func (e objectFlightError) Error() string { return string(e) }

const (
	// errObjectFlightNilResponse is reported when fn returns (nil, nil).
	errObjectFlightNilResponse = objectFlightError("object flight: fn returned neither a response nor an error")
	// errObjectFlightAborted is reported to joiners when the leader's fn
	// panicked before producing a result.
	errObjectFlightAborted = objectFlightError("object flight: leader aborted before producing a response")
)

// ObjectFlightGroup collapses concurrent requests for the same key at the
// whole-object level: only one goroutine (the leader) runs fn for a key while
// a flight is registered.
//
// The zero value is ready to use. An ObjectFlightGroup must not be copied
// after first use because it contains a sync.Mutex.
type ObjectFlightGroup struct {
	mu sync.Mutex
	m  map[string]*objectFlightCall
}

// Do executes fn once per key and fans the response body out to all callers
// registered for that key. Every caller receives its own response value with
// a cloned header map and a pipe-backed body, which the caller must close.
//
// waiter is how long the leader pauses before calling fn, giving late callers
// a window to register under the same key.
//
// Returns:
//
//	resp   — a response carrying the leader's headers and a streamed body
//	shared — true if this caller joined an existing flight
//	err    — the error from fn; also returned when fn yields (nil, nil)
//
// Failure handling:
//   - When fn returns a non-nil response together with an error, Do closes
//     that response's body.
//   - When fn panics, the key is unregistered and joiners are released with an
//     error before the panic continues to propagate, so the key is not
//     poisoned for later requests.
//   - A caller that closes its body early is dropped from the flight; the
//     remaining callers keep receiving data. Pipes are unbuffered, so a caller
//     that neither reads nor closes its body stalls the whole flight.
func (g *ObjectFlightGroup) Do(key string, waiter time.Duration, fn func() (*http.Response, error)) (*http.Response, bool, error) {
	pr, pw := io.Pipe()

	g.mu.Lock()
	if g.m == nil {
		g.m = make(map[string]*objectFlightCall)
	}
	if c, ok := g.m[key]; ok {
		// Joiner: register a pipe writer, then wait for the leader's headers.
		c.mu.Lock()
		c.pipes = append(c.pipes, pw)
		c.mu.Unlock()
		g.mu.Unlock()

		c.wg.Wait()
		if c.err != nil {
			_ = pw.CloseWithError(c.err)
			return nil, true, c.err
		}

		joined := cloneResponse(c.resp)
		joined.Body = pr
		return joined, true, nil
	}

	// Leader: create the flight and execute fn.
	c := &objectFlightCall{pipes: []*io.PipeWriter{pw}}
	c.wg.Add(1)
	g.m[key] = c
	g.mu.Unlock()

	// If fn panics, still unregister the flight and release joiners;
	// otherwise they would block forever on a key that never leaves the map.
	settled := false
	defer func() {
		if settled {
			return
		}
		for _, p := range g.settle(key, c, nil, errObjectFlightAborted) {
			_ = p.CloseWithError(errObjectFlightAborted)
		}
	}()

	if waiter > 0 {
		time.Sleep(waiter)
	}

	resp, err := fn()
	if err == nil && resp == nil {
		err = errObjectFlightNilResponse
	}
	if err != nil {
		if resp != nil && resp.Body != nil {
			_ = resp.Body.Close()
		}
		pipes := g.settle(key, c, nil, err)
		settled = true
		for _, p := range pipes {
			_ = p.CloseWithError(err)
		}
		return nil, false, err
	}

	pipes := g.settle(key, c, resp, nil)
	settled = true

	// Stream the body to all pipes (including the leader's) in the
	// background so this call can return the headers immediately.
	go func() {
		var copyErr error
		if resp.Body != nil {
			copyErr = fanOutObject(pipes, resp.Body)
			_ = resp.Body.Close()
		}
		// CloseWithError(nil) is a plain Close: readers then see io.EOF.
		for _, p := range pipes {
			_ = p.CloseWithError(copyErr)
		}
	}()

	leaderResp := cloneResponse(resp)
	leaderResp.Body = pr
	return leaderResp, false, nil
}

// settle unregisters the flight, publishes the leader's result, releases the
// waiting joiners and returns the complete set of registered pipes.
func (g *ObjectFlightGroup) settle(key string, c *objectFlightCall, resp *http.Response, err error) []*io.PipeWriter {
	// Registration happens under g.mu, so once the key is gone no further
	// pipe can be appended to this call.
	g.mu.Lock()
	delete(g.m, key)
	g.mu.Unlock()

	c.mu.Lock()
	pipes := c.pipes
	c.mu.Unlock()

	// Publish before Done so joiners returning from Wait observe the result.
	c.resp, c.err = resp, err
	c.wg.Done()
	return pipes
}

// fanOutObject copies src to every pipe. A pipe whose reader has gone away is
// dropped instead of aborting the copy for everyone else. It returns nil on a
// clean EOF, the read error otherwise, or io.ErrClosedPipe when no reader is
// left to receive data.
func fanOutObject(pipes []*io.PipeWriter, src io.Reader) error {
	live := append([]*io.PipeWriter(nil), pipes...)
	buf := make([]byte, 32*1024)

	for {
		n, readErr := src.Read(buf)
		if n > 0 {
			kept := live[:0]
			for _, p := range live {
				if _, writeErr := p.Write(buf[:n]); writeErr == nil {
					kept = append(kept, p)
				}
			}
			live = kept
			if len(live) == 0 {
				return io.ErrClosedPipe
			}
		}
		if readErr == io.EOF {
			return nil
		}
		if readErr != nil {
			return readErr
		}
	}
}

// cloneResponse returns a shallow copy of resp with a cloned Header map.
// Body is left nil — the caller sets it to a pipe reader. A nil resp yields
// nil.
func cloneResponse(resp *http.Response) *http.Response {
	if resp == nil {
		return nil
	}
	return &http.Response{
		Status:           resp.Status,
		StatusCode:       resp.StatusCode,
		Proto:            resp.Proto,
		ProtoMajor:       resp.ProtoMajor,
		ProtoMinor:       resp.ProtoMinor,
		Header:           resp.Header.Clone(),
		ContentLength:    resp.ContentLength,
		TransferEncoding: resp.TransferEncoding,
		Close:            resp.Close,
		Uncompressed:     resp.Uncompressed,
		Request:          resp.Request,
		TLS:              resp.TLS,
	}
}
w.Header().Set("Cache-Control", "max-age=10") + w.Header().Set("ETag", "obj-flight-etag") + })) + defer case1.Close() + + const N = 5 + var wg sync.WaitGroup + start := make(chan struct{}) + bodies := make([]string, N) + codes := make([]int, N) + xCaches := make([]string, N) + + for i := 0; i < N; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start + + resp, err := case1.Do(func(r *http.Request) { + r.Header.Set("X-Request-Idx", strconv.Itoa(idx)) + }) + + require.NoError(t, err, "caller %d: request should not error", idx) + defer resp.Body.Close() + + hash := e2e.HashBody(resp) + + bodies[idx] = hash + codes[idx] = resp.StatusCode + xCaches[idx] = resp.Header.Get("X-Cache") + }(i) + } + + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + // Verify only one origin call — ObjectFlightGroup collapsed all 5. + assert.Equal(t, int32(1), originCallCount.Load(), + "object flight should collapse concurrent full-MISS requests") + + // All callers must receive identical response bodies. + for i := 0; i < N; i++ { + assert.Equal(t, http.StatusOK, codes[i], "caller %d: status mismatch", i) + assert.Equal(t, f.MD5, bodies[i], "caller %d: body mismatch", i) + } + + // At least one should be MISS (the first), the rest may be HIT + // depending on whether they re-looked up metadata in time. 
+ hasMiss := false + for _, c := range xCaches { + if c != "" { + hasMiss = hasMiss || strings.Contains(c, "MISS") + } + } + assert.True(t, hasMiss, "at least one response should report MISS") + }) + + t.Run("test Collapsed Forwarding ObjectFlight Sequential", func(t *testing.T) { + originCallCount.Store(0) + + case1 := e2e.New("http://objflight.example.com/cf/object/sequential.bin", e2e.RespCallbackFile(f, func(w http.ResponseWriter, r *http.Request) { + originCallCount.Add(1) + w.Header().Set("Cache-Control", "max-age=10") + w.Header().Set("ETag", "obj-flight-etag") + })) + defer case1.Close() + + const N = 3 + + bodies := make([]string, N) + + // Sequential requests should not be collapsed. + for i := 0; i < N; i++ { + resp, err := case1.Do(func(r *http.Request) { + r.Header.Set("X-Request-Idx", strconv.Itoa(i)) + }) + + require.NoError(t, err, "request %d should not error", i) + + bodies[i] = e2e.HashBody(resp) + } + + assert.Equal(t, int32(1), originCallCount.Load(), + "object flight should not collapse sequential requests") + + for i := 0; i < N; i++ { + assert.Equal(t, f.MD5, bodies[i], "request %d body-hash mismatch", i) + } + + }) + + t.Run("test Collapsed Forwarding ObjectFlight KeyIsolation", func(t *testing.T) { + originCallCount.Store(0) + + case1 := e2e.New("http://keys.example.com/cf/object/", func(w http.ResponseWriter, r *http.Request) { + originCallCount.Add(1) + time.Sleep(80 * time.Millisecond) + + w.Header().Set("Cache-Control", "max-age=10") + }) + defer case1.Close() + + var wg sync.WaitGroup + start := make(chan struct{}) + + for _, key := range []string{"key-a", "key-b", "key-c"} { + wg.Add(1) + go func(k string) { + defer wg.Done() + <-start + + resp, err := case1.Do(func(r *http.Request) { + r.URL.Path += k + t.Logf("Requesting key: %s", k) + }) + + require.NoError(t, err) + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }(key) + } + + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + // Three different keys → three 
independent origin calls. + assert.Equal(t, int32(3), originCallCount.Load(), + "different URLs should have independent object flights") + + }) +} + +func TestCollapsedForwardingChunkFlight(t *testing.T) { + file := e2e.GenFile(t, 3<<20) // 3MB → 6 chunks at 512KB + var originCallCount atomic.Int32 + + t.Run("test Collapsed Forwarding ChunkFlight", func(t *testing.T) { + case1 := e2e.New("http://chunkflight.example.com/cf/chunk/collapse.bin", e2e.RespCallbackFile(file, func(w http.ResponseWriter, r *http.Request) { + originCallCount.Add(1) + w.Header().Set("Cache-Control", "max-age=30") + w.Header().Set("ETag", file.MD5) + })) + defer case1.Close() + + resp, err := case1.Do(func(r *http.Request) { + r.Header.Set("Range", "bytes=0-524287") + }) + + require.NoError(t, err) + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + require.Equal(t, http.StatusPartialContent, resp.StatusCode) + + // Give storage time to finish writing indexdb metadata. + time.Sleep(300 * time.Millisecond) + + // Phase 2 — concurrent requests for a range that needs missing chunks. + originCallCount.Store(0) + + const N = 3 + var wg sync.WaitGroup + start := make(chan struct{}) + bodies := make([]string, N) + codes := make([]int, N) + + for i := 0; i < N; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start + + resp1, err1 := case1.Do(func(r *http.Request) { + r.Header.Set("Range", "bytes=524288-2097151") + }) + + require.NoError(t, err1, "caller %d: request should not error", idx) + defer resp1.Body.Close() + + body, _ := io.ReadAll(resp1.Body) + bodies[idx] = string(body) + codes[idx] = resp1.StatusCode + }(i) + } + + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + t.Logf("origin call count for concurrent phase: %d", originCallCount.Load()) + + // All callers must receive correct 206 responses. 
+ for i := 0; i < N; i++ { + assert.Equal(t, http.StatusPartialContent, codes[i], "caller %d: status mismatch", i) + assert.NotEmpty(t, bodies[i], "caller %d: body should not be empty", i) + } + + // Verify body correctness: compare against the source file. + expected := e2e.HashFile(file.Path, 524288, 2097151-524288+1) + for i := 0; i < N; i++ { + actual := e2e.SumMD5([]byte(bodies[i])) + assert.Equal(t, expected, actual, "caller %d: body hash mismatch", i) + } + + // The concurrent chunk fetch for the missing range must be collapsed. + assert.Equal(t, int32(1), originCallCount.Load(), + "chunk flight should collapse concurrent chunk fetches to 1 origin call") + }) + + t.Run("test Collapsed Forwarding ChunkFlight KeyIsolation", func(t *testing.T) { + originCallCount.Store(0) + case1 := e2e.New("http://chunkflight.example.com/cf/chunk/keys.bin", e2e.RespCallbackFile(file, func(w http.ResponseWriter, r *http.Request) { + originCallCount.Add(1) + w.Header().Set("Cache-Control", "max-age=30") + w.Header().Set("ETag", file.MD5) + })) + defer case1.Close() + + // Phase 1 — cache only the middle chunk (chunk 1, bytes 524288-1048575). + resp, err := case1.Do(func(r *http.Request) { + r.Header.Set("Range", "bytes=524288-1048575") + }) + + require.NoError(t, err) + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + require.Equal(t, http.StatusPartialContent, resp.StatusCode) + + time.Sleep(300 * time.Millisecond) + + originCallCount.Store(0) + + // Phase 2 — request two different missing ranges concurrently. + // Range A: bytes=0-524287 (needs chunk 0, not cached) + // Range B: bytes=1048576-2097151 (needs chunk 2+, not cached) + // Different chunk flight keys → independent origin calls. 
+ var wg sync.WaitGroup + start := make(chan struct{}) + ranges := []string{"bytes=0-524287", "bytes=1048576-2097151"} + errs := make([]error, len(ranges)) + + for i, rng := range ranges { + wg.Add(1) + go func(idx int, rng string) { + defer wg.Done() + <-start + + resp2, e := case1.Do(func(r *http.Request) { + r.Header.Set("Range", rng) + }) + if e != nil { + errs[idx] = e + return + } + io.Copy(io.Discard, resp2.Body) + resp2.Body.Close() + }(i, rng) + } + + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + for i, e := range errs { + assert.NoError(t, e, "request %d should not error", i) + } + + // Different chunk ranges → different flight keys → 2 origin calls. + assert.Equal(t, int32(2), originCallCount.Load(), + "different byte ranges should use independent chunk flights") + + }) +}