feat: add native Metal support #1668
Code Review
This pull request introduces a comprehensive Metal/MLX runtime for Aphrodite, enabling high-performance LLM inference on Apple Silicon. It includes native Metal kernels for paged attention, support for various attention architectures (MHA, GQA, MQA, MLA, and GDN linear attention), TurboQuant KV cache compression, and Speech-to-Text capabilities (Whisper and Qwen3-ASR). The implementation spans from low-level C++/Metal kernels to high-level integration with the v1 engine, including a MetalPlatform abstraction and a MetalWorker. Feedback focuses on the immutability of MLX arrays, noting that several instances of direct item assignment will fail at runtime and should be replaced with mx.scatter_update. Additionally, it is recommended to raise a RuntimeError in sample_tokens when a valid pending state is missing to prevent downstream errors.
    # Save updated conv state back to stable slot
    new_conv = conv_input[:, -(inner.conv_kernel_size - 1) :]
    cs = state_cache.conv_states[cache_idx]
    cs[slot : slot + 1] = new_conv
MLX arrays are immutable and do not support item assignment. Attempting to use __setitem__ (e.g., cs[...] = ...) will raise a TypeError. You should use mx.scatter_update or the .at[...].set(...) syntax to perform indexed updates.
Suggested change:
    - cs[slot : slot + 1] = new_conv
    + cs = mx.scatter_update(cs, mx.array([slot]), new_conv)
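As a side note on the slice in the snippet under review, the new convolution state is the last `conv_kernel_size - 1` timesteps of the input. The sketch below uses plain Python lists rather than MLX, and `tail_state` and `write_slot` are hypothetical helper names, not functions from this PR. It assumes a one-dimensional sequence per slot and shows one edge case worth keeping in mind: when the kernel size is 1, the slice bound becomes `-0`, which Python treats as `0`, so an unguarded slice returns the whole sequence instead of an empty state.

```python
from typing import List, Sequence


def tail_state(conv_input: Sequence[float], kernel_size: int) -> List[float]:
    """Return the last (kernel_size - 1) timesteps as the new conv state."""
    keep = kernel_size - 1
    if keep <= 0:
        # Guard: conv_input[-0:] would return the entire sequence, not [].
        return []
    return list(conv_input[-keep:])


def write_slot(states: Sequence[List[float]], slot: int,
               new_state: List[float]) -> List[List[float]]:
    """Return a copy of the per-slot states with one slot replaced."""
    out = [list(s) for s in states]  # copy so the input is left untouched
    out[slot] = list(new_state)
    return out


conv_input = [1.0, 2.0, 3.0, 4.0, 5.0]
new_conv = tail_state(conv_input, kernel_size=4)   # keeps 3 timesteps
states = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
updated = write_slot(states, slot=1, new_state=new_conv)
```

Returning a fresh list from `write_slot` mirrors the functional-update style the suggestion asks for; whether the real code needs that depends on how the MLX array type handles indexed writes.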
    flat_k[slot_mapping] = k_3d
    new_k_cache = flat_k.reshape(kv_cache.key_caches[layer_idx].shape)
MLX arrays do not support item assignment. This operation will fail at runtime with a TypeError. Use mx.scatter_update to update the cache tensor.
Suggested change:
    - flat_k[slot_mapping] = k_3d
    - new_k_cache = flat_k.reshape(kv_cache.key_caches[layer_idx].shape)
    + new_k_cache = mx.scatter_update(flat_k, slot_mapping, k_3d).reshape(
    +     kv_cache.key_caches[layer_idx].shape)
    flat_v[slot_mapping] = v_3d
    new_v_cache = flat_v.reshape(kv_cache.value_caches[layer_idx].shape)
MLX arrays do not support item assignment. Use mx.scatter_update to perform this update on the value cache.
Suggested change:
    - flat_v[slot_mapping] = v_3d
    - new_v_cache = flat_v.reshape(kv_cache.value_caches[layer_idx].shape)
    + new_v_cache = mx.scatter_update(flat_v, slot_mapping, v_3d).reshape(
    +     kv_cache.value_caches[layer_idx].shape)
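However the indexed write ends up being spelled, the operation in both the key and value cache snippets is the same: scatter one row per token into a flattened cache at the index given by `slot_mapping`, then view the result in the paged layout again. The following is a minimal sketch of that idea in plain Python, not MLX. `scatter_rows` and `reshape_blocks` are hypothetical names introduced for illustration, and the cache shape (2 blocks of 2 slots, head dimension 2) is an assumption.

```python
from typing import List, Sequence

Row = List[float]


def scatter_rows(flat: Sequence[Row], slot_mapping: Sequence[int],
                 rows: Sequence[Row]) -> List[Row]:
    """Return a copy of `flat` with rows[i] written at slot_mapping[i]."""
    if len(slot_mapping) != len(rows):
        raise ValueError("slot_mapping and rows must have the same length")
    out = [list(r) for r in flat]  # copy so the caller's cache is untouched
    for slot, row in zip(slot_mapping, rows):
        if not 0 <= slot < len(out):
            raise IndexError(f"slot {slot} is outside the cache")
        out[slot] = list(row)
    return out


def reshape_blocks(flat: Sequence[Row], block_size: int) -> List[List[Row]]:
    """Regroup flat slots into blocks of `block_size` slots each."""
    if len(flat) % block_size != 0:
        raise ValueError("cache length must be a multiple of block_size")
    return [list(flat[i:i + block_size])
            for i in range(0, len(flat), block_size)]


# Assumed layout: 2 blocks x 2 slots, head dimension 2, initially zero.
flat_k = [[0.0, 0.0] for _ in range(4)]
k_rows = [[1.0, 1.0], [2.0, 2.0]]   # one row per new token
slot_mapping = [3, 0]               # token 0 -> slot 3, token 1 -> slot 0

new_flat_k = scatter_rows(flat_k, slot_mapping, k_rows)
new_k_cache = reshape_blocks(new_flat_k, block_size=2)
```

Because `scatter_rows` returns a new list, the caller has to rebind the cache to the result, which is the same pattern the suggested change uses when it assigns to `new_k_cache` and `new_v_cache`.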
    flat = (inputs_embeds + mx.zeros((), dtype=input_dtype)).reshape(-1, hidden_size)
    mask_np = np.asarray(mask_flat)
    positions = mx.array(np.where(mask_np)[0], dtype=mx.uint32)
    flat[positions] = mm_embeds_flat.astype(input_dtype)

    flat = latent_cache.latent_caches[layer_idx].reshape(
        -1, latent_cache.latent_dim
    )
    flat[mx.array(ctx.slot_mapping, dtype=mx.int64)] = latent_flat

    merged_state = mx.zeros(tuple(shape), template.dtype)
    for batch_idx, value in enumerate(values):
        if value is None:
            continue

    # logits[0] produces a new lazy computation node (not a Python alias of
    # logits), so __setitem__ here does not mutate the caller-held logits array.
    result_2d = logits[0]
    result_2d[logit_rows] = rows_mlx

    logger.error(
        "sample_tokens called with no pending state — "
        "neither _execute_model_state nor _pending_output was set."
    )
    return None
If sample_tokens is called without a valid pending state, it indicates a critical failure in the engine's state management. Returning None here will likely cause an AttributeError in the caller. It is safer to raise a RuntimeError to provide a clear failure point.
Suggested change:
    - logger.error(
    -     "sample_tokens called with no pending state — "
    -     "neither _execute_model_state nor _pending_output was set."
    - )
    - return None
    + raise RuntimeError(
    +     "sample_tokens called with no pending state — "
    +     "neither _execute_model_state nor _pending_output was set."
    + )
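To illustrate the fail-fast behavior this suggestion asks for, here is a small self-contained sketch. `ToyWorker`, `execute_model`, and `_pending_logits` are hypothetical stand-ins, not the PR's MetalWorker or its real attributes. The sketch assumes a two-phase API in which one call stores the model output and a later call samples from it, with greedy argmax standing in for the real sampler.

```python
from typing import List, Optional


class ToyWorker:
    """Hypothetical worker with a two-phase execute/sample API."""

    def __init__(self) -> None:
        self._pending_logits: Optional[List[float]] = None

    def execute_model(self, logits: List[float]) -> None:
        # Phase 1: stash the model output until sampling is requested.
        self._pending_logits = list(logits)

    def sample_tokens(self) -> int:
        # Phase 2: fail loudly if phase 1 never ran, rather than returning
        # None and letting the caller trip over it later.
        if self._pending_logits is None:
            raise RuntimeError("sample_tokens called with no pending state")
        logits, self._pending_logits = self._pending_logits, None
        # Greedy sampling: index of the largest logit.
        return max(range(len(logits)), key=logits.__getitem__)


worker = ToyWorker()
worker.execute_model([0.1, 2.5, -1.0])
token = worker.sample_tokens()
```

Raising at the point where the invariant is broken puts the stack trace in `sample_tokens` itself, so the failure is not first noticed in whichever caller happens to dereference the missing result.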
Based on vllm-project/vllm-metal, with changes to page allocation size, samplers, and other components.