Describe the bug
Mixed-precision kernel is not found when matrix-multiplying a float matrix by an mxfp8/nvfp4 quantized matrix.
To Reproduce
Code snippet:
import mlx.core as mx

input = mx.random.normal(shape=(16, 768))
wt = mx.random.normal(shape=(768, 3 * 768))
# Quantize the weights to mxfp8 (returns quantized weights and scales).
wt_q, wt_s = mx.quantize(wt, mode='mxfp8')
# Mixed-precision matmul: float input against mxfp8 quantized weights.
qm = mx.quantized_matmul(input, wt_q, wt_s, mode='mxfp8', transpose=False)
mx.eval(qm)
Running this raises: RuntimeError: [metal::Device] Unable to load kernel mxfp8_qmm_n_float_gs_32_b_8_batch_0
The same error occurs with float32, float16, and bfloat16 inputs.
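A minimal sketch to confirm the failure across input dtypes, reusing wt_q and wt_s from the snippet above:

# Sweep the input dtype; everything else is held fixed.
for dtype in (mx.float32, mx.float16, mx.bfloat16):
    x = input.astype(dtype)
    try:
        mx.eval(mx.quantized_matmul(x, wt_q, wt_s, mode='mxfp8', transpose=False))
        print(dtype, 'ok')
    except RuntimeError as e:
        print(dtype, 'failed:', e)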
Expected behavior
The quantized matmul completes successfully and returns the (16, 2304) product.
Desktop (please complete the following information):
- OS Version: macOS 15.6.1
- Version: 0.7.0
Additional context
Upgrading mlx and building from source did not resolve the error.
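As a possible workaround, a sketch assuming mx.dequantize mirrors mx.quantize's mode parameter (not verified against this kernel issue):

# Assumed workaround: dequantize the weights back to float and use a
# plain matmul, avoiding the missing mixed-precision kernel entirely.
wt_dq = mx.dequantize(wt_q, wt_s, mode='mxfp8')
qm = input @ wt_dq
mx.eval(qm)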