Fix MPS memory pathologies on Apple Silicon #340
This is such a nice PR, we should merge it. cc @andylizf @ASuresh0524
andylizf left a comment
LGTM — MPS data is solid, BM25 cleanup is clean. Two nits below.
    "BM25/hybrid search will return empty results. "
    "Re-run 'leann build' to regenerate passage files."
)
nit: After logging the error, this falls through to index.fit([]) which creates an empty FTS5 table. Consider adding a return here to bail out early — an empty index silently returns no results on search, which is harder to debug than skipping BM25 init entirely.
Good catch — fixed in f9c74aa. The empty-passages path now returns early, leaving self.bm25_scorer = None. _bm25_search (line 1496-1497) checks for that and raises RuntimeError("BM25 scorer failed to initialize"), so it surfaces loudly instead of returning empty hits.
    f"Ensure the index directory is writable, or rebuild with prebuild_bm25=True."
)
return
nit: When fit() raises PermissionError/OSError, self.bm25_scorer is never assigned. Worth confirming the caller guards against bm25_scorer being unset — otherwise _bm25_search will AttributeError on the next hybrid query.
Confirmed safe — _bm25_search at lines 1496-1497 already does if scorer is None: raise RuntimeError("BM25 scorer failed to initialize"). The PermissionError/OSError path returns without assigning self.bm25_scorer, so it stays None and the next call surfaces as RuntimeError, not AttributeError. With f9c74aa, the empty-passages path now behaves the same way.
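The behavior agreed on in this thread can be sketched as follows. This is a minimal illustration, not the code in the PR: the class `HybridSearcherSketch`, its method names, and the toy token-set "index" are all hypothetical. Only the `bm25_scorer` attribute, the early return, and the `RuntimeError` message come from the discussion above.

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)


class HybridSearcherSketch:
    """Hypothetical stand-in for the searcher class discussed in the review."""

    def __init__(self) -> None:
        # None means "BM25 was never initialized successfully".
        self.bm25_scorer: Optional[dict] = None

    def init_bm25(self, passages: list) -> None:
        if not passages:
            logger.error("No passages found; skipping BM25 init.")
            return  # bail out early instead of building an empty index
        # Toy "index" (assumption): lowercase token sets keyed by position.
        self.bm25_scorer = {
            i: set(p.lower().split()) for i, p in enumerate(passages)
        }

    def bm25_search(self, query: str) -> list:
        scorer = self.bm25_scorer
        if scorer is None:
            # Fail loudly rather than returning an empty hit list.
            raise RuntimeError("BM25 scorer failed to initialize")
        terms = set(query.lower().split())
        return [i for i, tokens in scorer.items() if terms & tokens]
```

With this shape, both failure paths (empty passages and a failed index build) leave the scorer as `None`, so the next hybrid query raises one clear error instead of an `AttributeError` or a silent empty result.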
Yeah, and please fix the conflict here.
- Remove torch.mps.set_per_process_memory_fraction(0.9), which lets the MPS allocator greedily fill ~29 GB on 32 GB machines without releasing it
- Guard torch.compile(mode=reduce-overhead) to CUDA only; on MPS it caches compiled graphs per sequence-length bucket (~5 GB waste)
- Add torch.mps.empty_cache() between manual HF batches
- Fix BM25Scorer syntax error from upstream merge (restore fit/search methods)
Reported-by: claude_writing_template (footprint dumps: 22 GB → expected ~3 GB)
Tested: ModernPubMedBERT 110M, batch_size=8, 5114 chunks, 640 batches, 32 GB M-series

- Delete BM25Scorer class (in-memory TF tables, O(corpus) RAM at search)
- Remove duplicate Fts5BM25Index class from merge artifact
- Remove _build_bm25_snapshot pickle codepath
- Default bm25_backend to 'fts5', deprecation warning for 'memory'
- Fallback path now builds FTS5 index on-demand from passages
- Clean up unused Counter/defaultdict imports

- 5 issues fixed, 3 skipped (pre-existing)
- Add error handling for read-only filesystem in on-demand FTS5 build
- Guard torch.mps.empty_cache() with try-except
- Handle empty passages with explicit error log
- Catch JSONDecodeError per-line in JSONL reading
- Add comment explaining MPS no-op block
See .doc/pr-review-comments/PR-340-mps-memory-pathologies.md

Address andylizf PR review nit on StarTrail-org#340: empty passages list fell through to index.fit([]) creating an empty FTS5 table. Now we log+return, leaving bm25_scorer = None so the existing _bm25_search guard raises RuntimeError instead of silently returning empty results.
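The on-demand FTS5 path described in these commits might look roughly like the sketch below. It uses only the standard-library `sqlite3` and `json` modules. The function names, the `passages` table name, and the `"text"` field in each JSONL record are assumptions for illustration, not the project's actual schema, and the sketch assumes the local SQLite build includes the FTS5 extension.

```python
import json
import logging
import sqlite3

logger = logging.getLogger(__name__)


def read_passages(jsonl_path):
    """Read passage texts from a JSONL file, skipping malformed lines."""
    texts = []
    with open(jsonl_path, encoding="utf-8") as fh:
        for lineno, line in enumerate(fh, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                # One bad line should not abort the whole read.
                logger.warning("Skipping malformed line %d: %s", lineno, exc)
                continue
            # "text" is an assumed field name for this sketch.
            text = record.get("text") if isinstance(record, dict) else None
            if text:
                texts.append(text)
    return texts


def build_fts5_index(db_path, texts):
    """Build an FTS5 table; return the connection, or None on failure."""
    if not texts:
        logger.error("No passages to index; skipping BM25 init.")
        return None
    try:
        conn = sqlite3.connect(db_path)
        conn.execute("CREATE VIRTUAL TABLE IF NOT EXISTS passages USING fts5(text)")
        conn.executemany("INSERT INTO passages(text) VALUES (?)", [(t,) for t in texts])
        conn.commit()
    except (sqlite3.OperationalError, OSError) as exc:
        # Covers read-only locations and SQLite builds without FTS5.
        logger.error("Could not build FTS5 index at %s: %s", db_path, exc)
        return None
    return conn


def search(conn, query, k=5):
    """Return rowids of the top-k matches, best first."""
    # FTS5's bm25() yields lower values for better matches, so sort ascending.
    rows = conn.execute(
        "SELECT rowid FROM passages WHERE passages MATCH ? "
        "ORDER BY bm25(passages) LIMIT ?",
        (query, k),
    ).fetchall()
    return [row[0] for row in rows]
```

Returning `None` from `build_fts5_index` mirrors the review outcome: the caller keeps its scorer unset and the search guard raises, rather than querying an empty table.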
Force-pushed from a4c4d5c to f9c74aa
Rebased onto main (post-#341) and pushed f9c74aa. Conflict resolution: kept the deletion path (this PR completes the #335 intent that PR #341's author noted in their description). Both review nits addressed:
Credential audit on the 5 commits in this PR: clean — only |
Problem
Two issues in `embedding_compute.py` cause 22 GB of GPU memory usage for a 110M-parameter BERT model on Apple Silicon (expected ~3 GB):

1. `torch.mps.set_per_process_memory_fraction(0.9)` tells the MPS allocator to use up to 29 GB on a 32 GB machine. The allocator fills that budget and never releases it.
2. `torch.compile(mode="reduce-overhead", dynamic=True)` on MPS caches compiled graphs plus pre-allocated FP16 buffers per sequence-length bucket. 640 batches of varying lengths create 5+ GB of graph buffers.

Additionally, `BM25Scorer` has a syntax error from the bm25 refactor merge: stray unquoted text between `__init__` and `Fts5BM25Index` breaks `import leann`.

Footprint data (sudo footprint, 32 GB M-series Mac)
Model: lokeshch19/ModernPubMedBERT (~110M params), batch_size=8, 5114 chunks, 640 batches.

Fix
- Remove `set_per_process_memory_fraction` for MPS (no equivalent benefit to CUDA's OOM protection)
- Guard `torch.compile` to `device == "cuda"` only (2 call sites)
- Add `torch.mps.empty_cache()` between manual HF tokenization batches
- Restore `BM25Scorer.fit()` and `.search()` methods lost in bm25 refactor series
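The two MPS-related fixes can be sketched as small helpers. This is a hedged illustration under stated assumptions, not the contents of `embedding_compute.py`: the helper names `should_compile`, `release_mps_cache`, and `embed_in_batches` are hypothetical, and PyTorch is imported lazily so the sketch still runs on a machine without it.

```python
def should_compile(device: str) -> bool:
    """Opt into torch.compile on CUDA only; skip it on MPS and CPU."""
    return device == "cuda"


def release_mps_cache() -> bool:
    """Best-effort torch.mps.empty_cache(); True if the call went through."""
    try:
        import torch  # lazy import so the helper degrades without PyTorch

        if not torch.backends.mps.is_available():
            return False
        torch.mps.empty_cache()
        return True
    except Exception:
        # Missing PyTorch, older builds, or non-Apple hosts: treat as a no-op.
        return False


def embed_in_batches(batches, encode, device: str):
    """Run `encode` over each batch, releasing cached MPS blocks in between."""
    outputs = []
    for batch in batches:
        outputs.append(encode(batch))
        if device == "mps":
            release_mps_cache()
    return outputs
```

A caller would wrap the model with `torch.compile(model, mode="reduce-overhead", dynamic=True)` only when `should_compile(device)` is true, and otherwise use the eager model unchanged. Emptying the cache after every batch trades a little throughput for a bounded footprint; whether that trade is right for a given workload is a judgment call.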