Fix int32 overflow in matvec row offset and gather-MM batch stride #3609
Open
aicayzer wants to merge 1 commit into
The matvec path (`gemv.metal`) computes the matrix row offset as `out_row * matrix_ld` in int32 and adds it to `mat`. For large matrices it truncates above 2^31 and silently returns wrong results. Reporter's repro at ml-explore#3591: 12347 x 174000 fp32 matrix * vector (product 2.15e9 > 2^31) gives relative error 0.06-0.25, intermittent run-to-run. Same pattern in the GEMVTKernel transpose variant.

The gather-MM and segmented-MM dispatchers in `matmul.cpp` pass `M * N` (int32) as the int64 `batch_stride_d` parameter: same truncation before the widen.

Fix:
- gemv.metal: widen the row-offset multiplications to size_t before the multiply (matches the steel kernel pattern hardened in ml-explore#1087).
- matmul.cpp: compute `M * N` as int64_t at the two dispatch sites.

Adds a memory-gated regression test reproducing the reporter's case when ~24 GB of unified memory is available; CI machines skip.
Fixes #3591 (the size-overflow half — see notes below on the second reported issue).
Problem
`gemv.metal`'s matvec kernel computes the matrix row offset as `out_row * matrix_ld` in int32, then advances the `mat` pointer. For large matrices the product wraps above 2^31 → wrong rows → silent corruption (no crash, no throw). Reporter's case is a 12347 × 174000 fp32 matrix * vector.

Affected sites at HEAD 2e6632e:
- `mlx/backend/metal/kernels/gemv.metal:151`: `GEMVKernel::run` matrix advance (the reporter's actual repro path)
- `mlx/backend/metal/kernels/gemv.metal:359, 375`: `GEMVTKernel::run` transpose variant
- `mlx/backend/metal/matmul.cpp:2360, 2590`: `gather_mm` and `segmented_mm` pass `M * N` (int32) as the `int64_t batch_stride_d` parameter, truncating before the widen

The N>1 plain-GEMM path is already 64-bit-safe via the steel kernel work in #1087, so this PR doesn't touch it.
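To make the arithmetic concrete, here is a small Python sketch of the overflow for the reporter's shape. It is an illustration only, not code from the PR: `wrap_int32` is a hypothetical helper that emulates 32-bit two's-complement wrap-around (assumed to match what the kernel does), and the per-row view assumes the offset is exactly `out_row * matrix_ld` for `out_row` in `[0, rows)`.

```python
INT32_MAX = 2**31 - 1


def wrap_int32(x: int) -> int:
    """Hypothetical helper: emulate 32-bit two's-complement wrap-around."""
    return (x + 2**31) % 2**32 - 2**31


# Shape from the report: 12347 rows, leading dimension 174000 (fp32).
rows, matrix_ld = 12347, 174000

# The full element count already exceeds 2^31 (about 2.15e9).
total_elements = rows * matrix_ld

# First row whose true element offset no longer fits in int32
# (assumes the offset is computed as out_row * matrix_ld).
first_bad_row = next(r for r in range(rows) if r * matrix_ld > INT32_MAX)

true_offset = first_bad_row * matrix_ld    # what a 64-bit multiply yields
narrow_offset = wrap_int32(true_offset)    # what a 32-bit multiply yields

# Widening after the multiply is too late: the narrow result is already
# wrong, and storing it in a wider type does not recover the lost bits.
widened_too_late = narrow_offset
widened_first = first_bad_row * matrix_ld  # operands widened before multiplying

print(total_elements, first_bad_row, true_offset, narrow_offset)
```

The same reasoning applies to `M * N` passed as `batch_stride_d`: the parameter type is 64-bit, but the product is formed in 32 bits before the conversion takes place.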
Fix
Widen the row-offset multiplications to `size_t` before the multiply (gemv) and compute `M * N` as `int64_t` at the two `matmul.cpp` dispatch sites. Same pattern as the existing `c_row_long * params->ldd` cast in the steel kernels.

Test
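As a rough illustration of the sizes and gating involved in the test described in this section, the sketch below uses only the Python standard library. It is not the test from the PR. `should_run` and `chunked_matvec` are hypothetical names, reading "24 GB" as 24 GiB is an assumption, and the row-wise chunking scheme is also an assumption, since the PR does not show how the reference is chunked.

```python
INT32_MAX = 2**31 - 1

# Sizes from the PR description: 12347 x 174000 fp32 matrix.
ROWS, COLS, BYTES_PER_FP32 = 12347, 174000, 4
matrix_bytes = ROWS * COLS * BYTES_PER_FP32  # roughly 8.6 GB

# Assumption: "24 GB" is read as 24 GiB; the real test may use another unit.
REQUIRED_BYTES = 24 * 1024**3

# Largest row count whose element count still fits in int32, so a chunk of
# this size stays clear of the overflow (assumed motivation for chunking).
max_safe_rows = INT32_MAX // COLS


def should_run(memory_size: int, required: int = REQUIRED_BYTES) -> bool:
    """Hypothetical gate: run the large test only with enough reported memory."""
    return memory_size >= required


def chunked_matvec(matrix, vec, chunk_rows):
    """Reference matvec over row chunks of at most chunk_rows rows."""
    out = []
    for start in range(0, len(matrix), chunk_rows):
        for row in matrix[start:start + chunk_rows]:
            out.append(sum(a * b for a, b in zip(row, vec)))
    return out


# Toy usage: 3 x 2 matrix processed in chunks of 2 rows.
small = chunked_matvec([[1, 2], [3, 4], [5, 6]], [1, 1], chunk_rows=2)
print(matrix_bytes, max_safe_rows, small)
```

In a real test the value passed to `should_run` would come from the device query quoted in this section, and the chunks would be multiplied on the device rather than in pure Python.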
`python/tests/test_blas.py::TestBlas::test_matvec_large_matrix_int32_offset` reproduces the reporter's case and compares against a chunked reference. Pre-fix the relative error is 0.06–0.25; post-fix it sits in fp32 noise (~1e-6).

The matrix is ~8.6 GB and the chunked reference adds another ~10 GB, so the test is gated on `mx.device_info()["memory_size"] >= 24 GB` and skips otherwise. CI will skip; high-RAM Apple Silicon (the reporter has 128 GB) will exercise it.

Notes for maintainers
A couple of things worth flagging before review:
- … `ShapeElem` int32 → int64. Happy to redo as a throw-on-overflow guard (the Detect int32 shape-product overflow at MLX compute-shape boundaries #3524 / Clearer error when shape dimension overflows int32 #3425 pattern) or fold into a broader refactor if you'd prefer; please say.
- … `slice + 0`). That's almost certainly a separate stride/offset bug rather than the gemv kernels, and isn't addressed here. Happy to file as a separate issue if useful.