Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions mlx/backend/metal/kernels/gemv.metal
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,10 @@ struct GEMVKernel {
// Adjust tail simdgroup to ensure in bound reads
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;

// Advance matrix
mat += out_row * matrix_ld;
// Advance matrix. Widen to size_t before multiplying so the row offset
// does not overflow int32 when out_row * matrix_ld exceeds 2^31 - 1 — the
// failure mode reported in #3591 for large matvecs (e.g. 12347 x 174000).
mat += size_t(out_row) * matrix_ld;

constexpr const uniform<int> loop_stride = make_uniform(blockN);
const uniform<int> in_size = make_uniform(in_vec_size);
Expand Down Expand Up @@ -356,7 +358,9 @@ struct GEMVTKernel {
for (int tm = 0; tm < TM; tm++) {
auto vc = static_cast<AccT>(v_coeff[tm]);
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
// Widen to size_t: (bm + tm) * marix_ld can exceed 2^31 for large
// transposed matvecs — see #3591.
inter[tn] = mat[size_t(bm + tm) * marix_ld + out_col + tn];
}
for (int tn = 0; tn < TN; tn++) {
result[tn] += vc * inter[tn];
Expand All @@ -372,7 +376,8 @@ struct GEMVTKernel {

MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
// Same widening as above — see #3591.
inter[tn] = mat[size_t(bm + tm) * marix_ld + out_col + tn];
}

MLX_MTL_PRAGMA_UNROLL
Expand Down
4 changes: 2 additions & 2 deletions mlx/backend/metal/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2357,7 +2357,7 @@ void gather_mm(
(batch_ndim > 0) ? lhs_indices.strides()[0] : 0,
/* const int64_t batch_stride_b = */
(batch_ndim > 0) ? rhs_indices.strides()[0] : 0,
/* const int64_t batch_stride_d = */ M * N,
/* const int64_t batch_stride_d = */ int64_t(M) * N,
/* const int swizzle_log = */ 0,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ batch_ndim};
Expand Down Expand Up @@ -2587,7 +2587,7 @@ void segmented_mm(
/* const int tiles_m = */ (M + bm - 1) / bm,
/* const int64_t batch_stride_a = */ 0,
/* const int64_t batch_stride_b = */ 0,
/* const int64_t batch_stride_d = */ M * N,
/* const int64_t batch_stride_d = */ int64_t(M) * N,
/* const int swizzle_log = */ 0,
/* const int gemm_k_iterations_aligned = */ 0,
/* const int batch_ndim = */ 0};
Expand Down
55 changes: 55 additions & 0 deletions python/tests/test_blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1424,6 +1424,61 @@ def rand(shape):
c_np = np.matmul(a, b)
self.assertTrue(np.allclose(out, out_np))

def test_matvec_large_matrix_int32_offset(self):
# Regression for #3591: matvec on a matrix whose row offset
# `out_row * matrix_ld` exceeds 2^31 silently returned wrong
# results because the offset was computed in int32 inside the
# gemv Metal kernel. The reporter's repro is a 12347 x 174000
# fp32 matrix against a 174000-element vector (product 2.15e9
# > 2^31).
#
# The matrix alone is ~8.6 GB and the chunked reference adds
# another ~10 GB of working set, so the test is gated on
# available RAM and only runs on a high-memory Apple Silicon
# device. CI machines will skip.
if not mx.is_available(mx.gpu):
return

info = mx.device_info()
memory_size = info.get("memory_size", 0)
# Needs at least ~24 GB of unified memory to safely exercise the
# >2^31 matrix-offset path with a chunked reference comparison.
if memory_size < (24 << 30):
self.skipTest(
"needs ~24 GB of unified memory to exercise the >2^31 "
"matrix-offset path"
)

rows = 12347
cols = 174000
# Sanity check we're actually exercising the overflow path.
self.assertGreater(rows * cols, 1 << 31)

mx.random.seed(0)
a = mx.random.normal(shape=(rows, cols))
v = mx.random.normal(shape=(cols, 1))

direct = a @ v

# Chunked reference avoids the kernel path that overflowed by
# making each per-chunk matrix fit comfortably under 2^31
# elements.
chunks = 4
chunk = cols // chunks
ref = mx.zeros((rows, 1))
for i in range(chunks):
lo = i * chunk
hi = cols if i == chunks - 1 else lo + chunk
ref = ref + a[:, lo:hi] @ v[lo:hi]

# `.item()` forces materialisation of both arrays before the
# comparison.
denom = float(mx.max(mx.abs(ref)).item())
rel = float(mx.max(mx.abs(direct - ref)).item()) / max(denom, 1e-12)
# Pre-fix the relative error spikes to 0.06-0.25; post-fix it
# sits in fp32 noise (~1e-6).
self.assertLess(rel, 1e-4)


if __name__ == "__main__":
mlx_tests.MLXTestRunner()