Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions sweepbatcher/sweep_batcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,16 +780,18 @@ func (b *Batcher) PresignSweepsGroup(ctx context.Context, inputs []Input,
// times, but the sweeps (including the order of them) must be the same. If
// notifier is provided, the batcher sends back sweeping results through it.
func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
// If the batcher is shutting down, quit now.
select {
case <-b.quit:
return ErrBatcherShuttingDown

default:
// If the batcher or the caller is shutting down, quit now.
err := b.addSweepExitErrIfAny(ctx)
if err != nil {
return err
}

sweeps, err := b.fetchSweeps(ctx, *sweepReq)
if err != nil {
if exitErr := b.addSweepExitErrIfAny(ctx); exitErr != nil {
return exitErr
}

return fmt.Errorf("fetchSweeps failed: %w", err)
}
Comment thread
hieblmi marked this conversation as resolved.

Expand All @@ -803,6 +805,10 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {

completed, err := b.store.GetSweepStatus(ctx, sweep.outpoint)
if err != nil {
if exitErr := b.addSweepExitErrIfAny(ctx); exitErr != nil {
return exitErr
}

return fmt.Errorf("failed to get the status of sweep %v: %w",
sweep.outpoint, err)
}
Expand All @@ -816,6 +822,10 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
// on-chain confirmations to prevent issues caused by reorgs.
parentBatch, err = b.store.GetParentBatch(ctx, sweep.outpoint)
if err != nil {
if exitErr := b.addSweepExitErrIfAny(ctx); exitErr != nil {
return exitErr
}

return fmt.Errorf("unable to get parent batch for "+
"sweep %x: %w", sweep.swapHash[:6], err)
}
Expand All @@ -827,6 +837,10 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {

minRelayFeeRate, err := b.wallet.MinRelayFee(ctx)
if err != nil {
if exitErr := b.addSweepExitErrIfAny(ctx); exitErr != nil {
return exitErr
}

return fmt.Errorf("failed to get min relay fee: %w", err)
}

Expand All @@ -839,6 +853,10 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
b.chainParams,
)
if err != nil {
if exitErr := b.addSweepExitErrIfAny(ctx); exitErr != nil {
return exitErr
}

return fmt.Errorf("inputs with primarySweep %v were "+
"not presigned (call PresignSweepsGroup "+
"first): %w", sweep.outpoint, err)
Expand All @@ -861,7 +879,24 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {

case <-b.quit:
return ErrBatcherShuttingDown

case <-ctx.Done():
return b.addSweepExitErrIfAny(ctx)
}
}

// addSweepExitErrIfAny returns the terminal error to use when AddSweep races
// with shutdown or caller cancellation. It returns nil if AddSweep should
// continue.
func (b *Batcher) addSweepExitErrIfAny(ctx context.Context) error {
select {
case <-b.quit:
return ErrBatcherShuttingDown

default:
}

return ctx.Err()
}

// testRunInEventLoop runs a function in the event loop blocking until
Expand Down
50 changes: 50 additions & 0 deletions sweepbatcher/sweep_batcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sweepbatcher
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"maps"
Expand Down Expand Up @@ -3711,6 +3712,55 @@ func (f *sweepFetcherMock) FetchSweep(ctx context.Context, _ lntypes.Hash,
return f.store[outpoint], nil
}

type cancelingSweepFetcher struct {
cancel context.CancelFunc
}

func (f *cancelingSweepFetcher) FetchSweep(context.Context, lntypes.Hash,
wire.OutPoint) (*SweepInfo, error) {

// Simulate the caller canceling while the backend returns a
// driver-level error.
f.cancel()

return nil, driver.ErrBadConn
}

func testAddSweepReturnsContextErrorOnFetchCancellation(t *testing.T,
_ testStore, batcherStore testBatcherStore) {

defer test.Guard(t)()

lnd := test.NewMockLnd()
ctx, cancel := context.WithCancel(context.Background())

batcher := NewBatcher(
lnd.WalletKit, lnd.ChainNotifier, lnd.Signer,
testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams,
batcherStore, &cancelingSweepFetcher{cancel: cancel},
)

err := batcher.AddSweep(ctx, &SweepRequest{
SwapHash: lntypes.Hash{1, 1, 1},
Inputs: []Input{{
Value: 1111,
Outpoint: wire.OutPoint{
Hash: chainhash.Hash{1, 1},
Index: 1,
},
}},
})
require.ErrorIs(t, err, context.Canceled)
require.NotErrorIs(t, err, driver.ErrBadConn)
}

// TestAddSweepReturnsContextErrorOnFetchCancellation asserts that AddSweep
// returns the context cancellation error if sweep fetching fails while the
// caller context is being canceled.
func TestAddSweepReturnsContextErrorOnFetchCancellation(t *testing.T) {
runTests(t, testAddSweepReturnsContextErrorOnFetchCancellation)
}

// testSweepFetcher tests providing custom sweep fetcher to Batcher.
func testSweepFetcher(t *testing.T, store testStore,
batcherStore testBatcherStore) {
Expand Down
Loading