From bc2ff0d44d3c8d09e536ac70c6d75aa407402eea Mon Sep 17 00:00:00 2001 From: Slyghtning Date: Tue, 28 Apr 2026 12:28:29 +0200 Subject: [PATCH] sweepbatcher: harden AddSweep against ctx closure --- sweepbatcher/sweep_batcher.go | 47 ++++++++++++++++++++++++---- sweepbatcher/sweep_batcher_test.go | 50 ++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 6 deletions(-) diff --git a/sweepbatcher/sweep_batcher.go b/sweepbatcher/sweep_batcher.go index 05b273863..9ed63a663 100644 --- a/sweepbatcher/sweep_batcher.go +++ b/sweepbatcher/sweep_batcher.go @@ -780,16 +780,18 @@ func (b *Batcher) PresignSweepsGroup(ctx context.Context, inputs []Input, // times, but the sweeps (including the order of them) must be the same. If // notifier is provided, the batcher sends back sweeping results through it. func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error { - // If the batcher is shutting down, quit now. - select { - case <-b.quit: - return ErrBatcherShuttingDown - - default: + // If the batcher or the caller is shutting down, quit now. + err := b.addSweepExitErrIfAny(ctx) + if err != nil { + return err } sweeps, err := b.fetchSweeps(ctx, *sweepReq) if err != nil { + if exitErr := b.addSweepExitErrIfAny(ctx); exitErr != nil { + return exitErr + } + return fmt.Errorf("fetchSweeps failed: %w", err) } @@ -803,6 +805,10 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error { completed, err := b.store.GetSweepStatus(ctx, sweep.outpoint) if err != nil { + if exitErr := b.addSweepExitErrIfAny(ctx); exitErr != nil { + return exitErr + } + return fmt.Errorf("failed to get the status of sweep %v: %w", sweep.outpoint, err) } @@ -816,6 +822,10 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error { // on-chain confirmations to prevent issues caused by reorgs. parentBatch, err = b.store.GetParentBatch(ctx, sweep.outpoint) if err != nil { + if exitErr := b.addSweepExitErrIfAny(ctx); exitErr != nil { + return exitErr + } + return fmt.Errorf("unable to get parent batch for "+ "sweep %x: %w", sweep.swapHash[:6], err) } @@ -827,6 +837,10 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error { minRelayFeeRate, err := b.wallet.MinRelayFee(ctx) if err != nil { + if exitErr := b.addSweepExitErrIfAny(ctx); exitErr != nil { + return exitErr + } + return fmt.Errorf("failed to get min relay fee: %w", err) } @@ -839,6 +853,10 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error { b.chainParams, ) if err != nil { + if exitErr := b.addSweepExitErrIfAny(ctx); exitErr != nil { + return exitErr + } + return fmt.Errorf("inputs with primarySweep %v were "+ "not presigned (call PresignSweepsGroup "+ "first): %w", sweep.outpoint, err) @@ -861,7 +879,24 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error { case <-b.quit: return ErrBatcherShuttingDown + + case <-ctx.Done(): + return b.addSweepExitErrIfAny(ctx) + } +} + +// addSweepExitErrIfAny returns the terminal error to use when AddSweep races +// with shutdown or caller cancellation. It returns nil if AddSweep should +// continue. +func (b *Batcher) addSweepExitErrIfAny(ctx context.Context) error { + select { + case <-b.quit: + return ErrBatcherShuttingDown + + default: } + + return ctx.Err() } // testRunInEventLoop runs a function in the event loop blocking until diff --git a/sweepbatcher/sweep_batcher_test.go b/sweepbatcher/sweep_batcher_test.go index a57541f2f..2d629098b 100644 --- a/sweepbatcher/sweep_batcher_test.go +++ b/sweepbatcher/sweep_batcher_test.go @@ -3,6 +3,7 @@ package sweepbatcher import ( "context" "database/sql" + "database/sql/driver" "errors" "fmt" "maps" @@ -3711,6 +3712,55 @@ func (f *sweepFetcherMock) FetchSweep(ctx context.Context, _ lntypes.Hash, return f.store[outpoint], nil } +type cancelingSweepFetcher struct { + cancel context.CancelFunc +} + +func (f *cancelingSweepFetcher) FetchSweep(context.Context, lntypes.Hash, + wire.OutPoint) (*SweepInfo, error) { + + // Simulate the caller canceling while the backend returns a + // driver-level error. + f.cancel() + + return nil, driver.ErrBadConn +} + +func testAddSweepReturnsContextErrorOnFetchCancellation(t *testing.T, + _ testStore, batcherStore testBatcherStore) { + + defer test.Guard(t)() + + lnd := test.NewMockLnd() + ctx, cancel := context.WithCancel(context.Background()) + + batcher := NewBatcher( + lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, + testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams, + batcherStore, &cancelingSweepFetcher{cancel: cancel}, + ) + + err := batcher.AddSweep(ctx, &SweepRequest{ + SwapHash: lntypes.Hash{1, 1, 1}, + Inputs: []Input{{ + Value: 1111, + Outpoint: wire.OutPoint{ + Hash: chainhash.Hash{1, 1}, + Index: 1, + }, + }}, + }) + require.ErrorIs(t, err, context.Canceled) + require.NotErrorIs(t, err, driver.ErrBadConn) +} + +// TestAddSweepReturnsContextErrorOnFetchCancellation asserts that AddSweep +// returns the context cancellation error if sweep fetching fails while the +// caller context is being canceled. +func TestAddSweepReturnsContextErrorOnFetchCancellation(t *testing.T) { + runTests(t, testAddSweepReturnsContextErrorOnFetchCancellation) +} + // testSweepFetcher tests providing custom sweep fetcher to Batcher. func testSweepFetcher(t *testing.T, store testStore, batcherStore testBatcherStore) {