
Fix MPS memory pathologies on Apple Silicon #340

Open
ww2283 wants to merge 5 commits into StarTrail-org:main from ww2283:fix/mps-memory-pathologies

@ww2283 (Contributor) commented May 25, 2026

Problem

Two issues in embedding_compute.py cause 22 GB of GPU memory usage for a 110M-parameter BERT model on Apple Silicon (expected: ~3 GB):

  1. torch.mps.set_per_process_memory_fraction(0.9) lets the MPS allocator use up to about 29 GB (0.9 × 32 GB) on a 32 GB machine. The allocator fills that budget and never releases it.
  2. torch.compile(mode="reduce-overhead", dynamic=True) on MPS caches compiled graphs plus pre-allocated FP16 buffers per sequence-length bucket. 640 batches of varying lengths create 5+ GB of graph buffers.

Additionally, BM25Scorer has a syntax error from the bm25 refactor merge: stray unquoted text between __init__ and Fts5BM25Index breaks import leann.

Footprint data (sudo footprint, 32 GB M-series Mac)

| Configuration | phys_footprint |
|---|---|
| Original (both issues) | 22 GB |
| torch.compile removed only | 17 GB |
| Both fixed + empty_cache() | ~3 GB |

Model: lokeshch19/ModernPubMedBERT (~110M params), batch_size=8, 5114 chunks, 640 batches.
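
As a quick sanity check on the figures above, the following sketch recomputes the allocator budget implied by a 0.9 memory fraction on a 32 GB machine and the batch count for 5114 chunks at batch_size=8. The variable names are illustrative only, and the per-fix savings are plain subtractions of the reported footprints (the ~3 GB value is approximate), not separate measurements.

```python
import math

TOTAL_RAM_GB = 32       # machine size reported in the PR
MEMORY_FRACTION = 0.9   # value that was passed to set_per_process_memory_fraction

# Budget the MPS allocator was allowed to fill.
allocator_budget_gb = TOTAL_RAM_GB * MEMORY_FRACTION

# Number of batches for the reported workload (last batch is partial).
num_batches = math.ceil(5114 / 8)

# Reported phys_footprint values in GB; "both_fixed" is approximate.
footprint = {"original": 22, "compile_removed": 17, "both_fixed": 3}
saved_by_removing_compile = footprint["original"] - footprint["compile_removed"]
saved_by_remaining_fixes = footprint["compile_removed"] - footprint["both_fixed"]

print(f"allocator budget: {allocator_budget_gb:.1f} GB")
print(f"batches: {num_batches}")
print(f"saved by removing torch.compile: {saved_by_removing_compile} GB")
print(f"saved by the remaining fixes: about {saved_by_remaining_fixes} GB")
```

The 5 GB difference between the first two rows lines up with the "5+ GB of graph buffers" estimate in the problem description.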

Fix

  • Remove set_per_process_memory_fraction for MPS (on MPS it offers no benefit equivalent to the OOM protection it gives on CUDA)
  • Guard torch.compile to device == "cuda" only (2 call sites)
  • Add torch.mps.empty_cache() between manual HF tokenization batches
  • Restore the BM25Scorer.fit() and .search() methods lost in the bm25 refactor series
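
A minimal sketch of what the two MPS-related guards might look like. The helper names maybe_compile and release_mps_cache are hypothetical and are not taken from embedding_compute.py; only torch.compile and torch.mps.empty_cache are real PyTorch calls. The import is wrapped so the sketch also runs where PyTorch is not installed.

```python
try:
    import torch  # optional here so the sketch runs without PyTorch installed
except ImportError:
    torch = None


def maybe_compile(model, device: str):
    """Return a compiled model on CUDA only; otherwise return it unchanged."""
    if torch is not None and device == "cuda" and hasattr(torch, "compile"):
        return torch.compile(model, mode="reduce-overhead", dynamic=True)
    # On MPS and CPU, skip compilation to avoid per-bucket graph buffers.
    return model


def release_mps_cache(device: str) -> bool:
    """Ask the MPS allocator to release cached blocks; True if the call ran."""
    if torch is None or device != "mps":
        return False
    try:
        torch.mps.empty_cache()
    except (AttributeError, RuntimeError):
        # Older PyTorch builds or machines without MPS: treat as a no-op.
        return False
    return True
```

In a manual batching loop, release_mps_cache(device) would be called once after each batch has been moved off the device, so cached blocks from one sequence length are not held while the next batch is processed.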

ww2283 added a commit to ww2283/LEANN that referenced this pull request May 25, 2026
- 5 issues fixed, 3 skipped (pre-existing)
- Add error handling for read-only filesystem in on-demand FTS5 build
- Guard torch.mps.empty_cache() with try-except
- Handle empty passages with explicit error log
- Catch JSONDecodeError per-line in JSONL reading
- Add comment explaining MPS no-op block

See .doc/pr-review-comments/PR-340-mps-memory-pathologies.md
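
The per-line JSONDecodeError handling mentioned in this commit can be sketched roughly as follows. This is an illustration using only the standard library; the function name read_jsonl and the sample records are assumptions, not code from the PR.

```python
import io
import json
import logging

logger = logging.getLogger(__name__)


def read_jsonl(stream):
    """Parse a JSONL stream, skipping (and logging) lines that fail to decode."""
    records = []
    for line_no, line in enumerate(stream, start=1):
        line = line.strip()
        if not line:
            continue  # ignore blank lines
        try:
            records.append(json.loads(line))
        except json.JSONDecodeError as exc:
            # One bad line should not abort the whole file.
            logger.warning("Skipping malformed JSONL line %d: %s", line_no, exc)
    return records


# Illustrative input: two valid records around one malformed line and a blank line.
sample = io.StringIO('{"id": 1, "text": "a"}\n{broken\n\n{"id": 2, "text": "b"}\n')
passages = read_jsonl(sample)
```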
@yichuan-w (Collaborator)

This is such a nice PR, we should merge it. cc @andylizf @ASuresh0524

@andylizf (Collaborator) left a comment

LGTM — MPS data is solid, BM25 cleanup is clean. Two nits below.

"BM25/hybrid search will return empty results. "
"Re-run 'leann build' to regenerate passage files."
)

Collaborator

nit: After logging the error, this falls through to index.fit([]) which creates an empty FTS5 table. Consider adding a return here to bail out early — an empty index silently returns no results on search, which is harder to debug than skipping BM25 init entirely.

Contributor Author

Good catch, fixed in f9c74aa. The empty-passages path now returns early, leaving self.bm25_scorer = None. _bm25_search (lines 1496-1497) checks for that and raises RuntimeError("BM25 scorer failed to initialize"), so the failure surfaces loudly instead of returning empty hits.
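
The early-return behaviour described here can be sketched as below. StubIndex, Searcher, init_bm25, and bm25_search are hypothetical stand-ins for the real classes and methods in api.py; only the RuntimeError message is taken from the comment above.

```python
import logging

logger = logging.getLogger(__name__)


class StubIndex:
    """Hypothetical stand-in for the FTS5-backed index."""

    def __init__(self):
        self.docs = []

    def fit(self, passages):
        self.docs = list(passages)

    def search(self, query, top_k=5):
        # Naive substring match, just enough to make the sketch runnable.
        return [d for d in self.docs if query in d][:top_k]


class Searcher:
    def __init__(self):
        self.bm25_scorer = None  # stays None unless initialization succeeds

    def init_bm25(self, passages):
        if not passages:
            logger.error("No passages found; BM25/hybrid search is unavailable.")
            return  # bail out instead of fitting an empty index
        index = StubIndex()
        index.fit(passages)
        self.bm25_scorer = index

    def bm25_search(self, query):
        if self.bm25_scorer is None:
            raise RuntimeError("BM25 scorer failed to initialize")
        return self.bm25_scorer.search(query)
```

With this shape, a search against an uninitialized scorer fails with a clear error, whereas an empty index would simply return no hits.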

f"Ensure the index directory is writable, or rebuild with prebuild_bm25=True."
)
return

Collaborator

nit: When fit() raises PermissionError/OSError, self.bm25_scorer is never assigned. Worth confirming the caller guards against bm25_scorer being unset — otherwise _bm25_search will AttributeError on the next hybrid query.

Contributor Author

Confirmed safe — _bm25_search at lines 1496-1497 already does if scorer is None: raise RuntimeError("BM25 scorer failed to initialize"). The PermissionError/OSError path returns without assigning self.bm25_scorer, so it stays None and the next call surfaces as RuntimeError, not AttributeError. With f9c74aa, the empty-passages path now behaves the same way.
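
To illustrate why this path ends in RuntimeError and not AttributeError, here is a small sketch under the assumption that bm25_scorer is set to None before any index build is attempted. ReadOnlyIndex and Searcher are hypothetical names; the real code lives in api.py and may differ in detail.

```python
import logging

logger = logging.getLogger(__name__)


class ReadOnlyIndex:
    """Hypothetical index whose fit() fails as on a read-only filesystem."""

    def fit(self, passages):
        raise PermissionError("index directory is not writable")


class Searcher:
    def __init__(self, index_factory):
        self._index_factory = index_factory
        self.bm25_scorer = None  # assigned up front so the guard below can test it

    def init_bm25(self, passages):
        index = self._index_factory()
        try:
            index.fit(passages)
        except OSError as exc:  # PermissionError is a subclass of OSError
            logger.error("Failed to build BM25 index: %s", exc)
            return  # bm25_scorer deliberately stays None
        self.bm25_scorer = index

    def bm25_search(self, query):
        scorer = self.bm25_scorer
        if scorer is None:
            raise RuntimeError("BM25 scorer failed to initialize")
        return scorer.search(query)


searcher = Searcher(ReadOnlyIndex)
searcher.init_bm25(["some passage"])

outcome = None
try:
    searcher.bm25_search("query")
except RuntimeError as exc:
    outcome = f"RuntimeError: {exc}"
```

If the attribute were only ever assigned on the success path, the same call would raise AttributeError, which is the case the reviewer was asking about.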

@yichuan-w (Collaborator)

yeah, and please fix the conflict here

ww2283 added 5 commits May 27, 2026 15:07
- Remove torch.mps.set_per_process_memory_fraction(0.9) which lets MPS
  allocator greedily fill ~29 GB on 32 GB machines without releasing
- Guard torch.compile(mode=reduce-overhead) to CUDA only; on MPS it
  caches compiled graphs per sequence-length bucket (~5 GB waste)
- Add torch.mps.empty_cache() between manual HF batches
- Fix BM25Scorer syntax error from upstream merge (restore fit/search methods)

Reported-by: claude_writing_template (footprint dumps: 22 GB → expected ~3 GB)
Tested: ModernPubMedBERT 110M, batch_size=8, 5114 chunks, 640 batches, 32 GB M-series

- Delete BM25Scorer class (in-memory TF tables, O(corpus) RAM at search)
- Remove duplicate Fts5BM25Index class from merge artifact
- Remove _build_bm25_snapshot pickle codepath
- Default bm25_backend to 'fts5', deprecation warning for 'memory'
- Fallback path now builds FTS5 index on-demand from passages
- Clean up unused Counter/defaultdict imports

- 5 issues fixed, 3 skipped (pre-existing)
- Add error handling for read-only filesystem in on-demand FTS5 build
- Guard torch.mps.empty_cache() with try-except
- Handle empty passages with explicit error log
- Catch JSONDecodeError per-line in JSONL reading
- Add comment explaining MPS no-op block

See .doc/pr-review-comments/PR-340-mps-memory-pathologies.md

Address andylizf PR review nit on StarTrail-org#340: empty passages list fell through
to index.fit([]) creating an empty FTS5 table. Now we log+return, leaving
bm25_scorer = None so the existing _bm25_search guard raises RuntimeError
instead of silently returning empty results.
ww2283 force-pushed the fix/mps-memory-pathologies branch from a4c4d5c to f9c74aa May 27, 2026 19:10
@ww2283 (Contributor Author) commented May 27, 2026

Rebased onto main (post-#341) and pushed f9c74aa. Conflict resolution: kept the deletion path (this PR completes the #335 intent that PR #341's author noted in their description). Both review nits addressed:

  • nit 1 (empty passages): now returns early; empty FTS5 table no longer silently swallows queries.
  • nit 2 (PermissionError/OSError): confirmed the existing _bm25_search guard at L1496-1497 converts bm25_scorer is None into RuntimeError, so no AttributeError reaches callers.

Credential audit on the 5 commits in this PR: clean — only api.py and embedding_compute.py touched, no secrets.

3 participants