diff --git a/internal/consistency/diff/table_diff.go b/internal/consistency/diff/table_diff.go index 02ff81f..e838860 100644 --- a/internal/consistency/diff/table_diff.go +++ b/internal/consistency/diff/table_diff.go @@ -110,8 +110,9 @@ type TableDiffTask struct { DiffResult types.DiffOutput diffMutex sync.Mutex - firstError error - firstErrorMu sync.Mutex + firstError error + firstErrorMu sync.Mutex + errorRecorded atomic.Bool totalDiffRows atomic.Int64 diffLimitTriggered atomic.Bool @@ -149,6 +150,7 @@ func (t *TableDiffTask) recordError(err error) { t.firstError = err } t.firstErrorMu.Unlock() + t.errorRecorded.Store(true) } func (t *TableDiffTask) getFirstError() error { @@ -157,6 +159,19 @@ func (t *TableDiffTask) getFirstError() error { return t.firstError } +func (t *TableDiffTask) hasError() bool { + return t.errorRecorded.Load() +} + +// shouldStop reports whether diff workers should cease processing, either because +// the diff row limit was reached or because an error has been recorded through +// recordError (circuit breaker). This prevents OOM when a node starts failing: without +// this check, goroutines would keep grinding through every remaining sub-range — each +// waiting up to 60 s for a timeout — accumulating error objects until the process is killed. 
+func (t *TableDiffTask) shouldStop() bool { + return t.shouldStopDueToLimit() || t.hasError() +} + func (t *TableDiffTask) shouldStopDueToLimit() bool { if t.MaxDiffRows <= 0 { return false @@ -1163,6 +1178,10 @@ func (t *TableDiffTask) ExecuteTask() (err error) { t.totalDiffRows.Store(0) t.diffLimitTriggered.Store(false) + t.errorRecorded.Store(false) + t.firstErrorMu.Lock() + t.firstError = nil + t.firstErrorMu.Unlock() var recorder *taskstore.Recorder @@ -1493,6 +1512,10 @@ func (t *TableDiffTask) ExecuteTask() (err error) { go func() { defer initialHashWg.Done() for task := range hashTaskQueue { + if t.shouldStop() { + bar.Increment() + continue + } sem <- struct{}{} queryCtx, cancel := context.WithTimeout(t.Ctx, 60*time.Second) hashValue, hErr := t.hashRange(queryCtx, task.nodeName, task.r) @@ -1581,7 +1604,7 @@ func (t *TableDiffTask) ExecuteTask() (err error) { ) for _, task := range mismatchedTasks { - if t.shouldStopDueToLimit() { + if t.shouldStop() { diffBar.Increment() continue } @@ -1905,7 +1928,7 @@ func (t *TableDiffTask) recursiveDiff( ) { defer wg.Done() - if t.shouldStopDueToLimit() { + if t.shouldStop() { return } @@ -1957,7 +1980,7 @@ func (t *TableDiffTask) recursiveDiff( } for _, row := range diffInfo.Node1OnlyRows { - if t.shouldStopDueToLimit() { + if t.shouldStop() { limitReached = true break } @@ -1974,7 +1997,7 @@ func (t *TableDiffTask) recursiveDiff( if !limitReached { for _, row := range diffInfo.Node2OnlyRows { - if t.shouldStopDueToLimit() { + if t.shouldStop() { limitReached = true break } @@ -1992,7 +2015,7 @@ func (t *TableDiffTask) recursiveDiff( if !limitReached { for _, modRow := range diffInfo.ModifiedRows { - if t.shouldStopDueToLimit() { + if t.shouldStop() { limitReached = true break } @@ -2019,14 +2042,14 @@ func (t *TableDiffTask) recursiveDiff( t.DiffResult.Summary.DiffRowsCount[pairKey] += currentDiffRowsForPair t.diffMutex.Unlock() - if limitReached || t.shouldStopDueToLimit() { + if limitReached || t.shouldStop() { 
return } } return } - if t.shouldStopDueToLimit() { + if t.shouldStop() { return } @@ -2065,7 +2088,7 @@ func (t *TableDiffTask) recursiveDiff( } for _, sr := range subRanges { - if t.shouldStopDueToLimit() { + if t.shouldStop() { return } @@ -2096,20 +2119,20 @@ func (t *TableDiffTask) recursiveDiff( errWrap := fmt.Errorf("hashing sub-range %v-%v for %s: %w", sr.Start, sr.End, node1Name, res1.err) t.recordError(errWrap) logger.Info("ERROR hashing sub-range %v-%v for %s: %v", sr.Start, sr.End, node1Name, errWrap) - continue + return } if res2.err != nil { errWrap := fmt.Errorf("hashing sub-range %v-%v for %s: %w", sr.Start, sr.End, node2Name, res2.err) t.recordError(errWrap) logger.Info("ERROR hashing sub-range %v-%v for %s: %v", sr.Start, sr.End, node2Name, errWrap) - continue + return } if res1.hash != res2.hash { logger.Debug("%s Mismatch in sub-range %v-%v for %s (%s...) vs %s (%s...). Recursing.", utils.CrossMark, sr.Start, sr.End, node1Name, utils.SafeCut(res1.hash, 8), node2Name, utils.SafeCut(res2.hash, 8)) - if t.shouldStopDueToLimit() { + if t.shouldStop() { return }