
feat: add native Metal support #1668

Merged: AlpinDale merged 1 commit into main from feat/metal on May 5, 2026

@AlpinDale
Collaborator

Based on vllm-project/vllm-metal, with a variety of changes to page allocation size, samplers, and others.

Contributor

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a comprehensive Metal/MLX runtime for Aphrodite, enabling high-performance LLM inference on Apple Silicon. It includes native Metal kernels for paged attention, support for various attention architectures (MHA, GQA, MQA, MLA, and GDN linear attention), TurboQuant KV cache compression, and Speech-to-Text capabilities (Whisper and Qwen3-ASR). The implementation spans from low-level C++/Metal kernels to high-level integration with the v1 engine, including a MetalPlatform abstraction and a MetalWorker.

Feedback focuses on the immutability of MLX arrays, noting that several instances of direct item assignment will fail at runtime and should be replaced with mx.scatter_update. Additionally, it is recommended to raise a RuntimeError in sample_tokens when a valid pending state is missing to prevent downstream errors.

# Save updated conv state back to stable slot
new_conv = conv_input[:, -(inner.conv_kernel_size - 1) :]
cs = state_cache.conv_states[cache_idx]
cs[slot : slot + 1] = new_conv

critical

MLX arrays are immutable and do not support item assignment. Attempting to use __setitem__ (e.g., cs[...] = ...) will raise a TypeError. You should use mx.scatter_update or the .at[...].set(...) syntax to perform indexed updates.

Suggested change
- cs[slot : slot + 1] = new_conv
+ cs = mx.scatter_update(cs, mx.array([slot]), new_conv)

Comment on lines +505 to +506
flat_k[slot_mapping] = k_3d
new_k_cache = flat_k.reshape(kv_cache.key_caches[layer_idx].shape)

critical

MLX arrays do not support item assignment. This operation will fail at runtime with a TypeError. Use mx.scatter_update to update the cache tensor.

Suggested change
- flat_k[slot_mapping] = k_3d
- new_k_cache = flat_k.reshape(kv_cache.key_caches[layer_idx].shape)
+ new_k_cache = mx.scatter_update(flat_k, slot_mapping, k_3d).reshape(
+     kv_cache.key_caches[layer_idx].shape)
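
For readers unfamiliar with the slot-mapping write under discussion, here is a minimal sketch of the pattern: each new token's key is written into one row of the flattened cache, and the result is reshaped back to the paged layout. NumPy is used as a stand-in, so this says nothing about how MLX handles the same indexing. The cache layout (blocks, block size, KV heads, head dim) and all sizes are assumptions for illustration and are not taken from the PR.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
num_blocks, block_size, num_kv_heads, head_dim = 4, 2, 1, 3

# Assumed paged layout: [num_blocks, block_size, num_kv_heads, head_dim].
key_cache = np.zeros(
    (num_blocks, block_size, num_kv_heads, head_dim), dtype=np.float32
)

# Two new tokens. Each slot is block_index * block_size + offset_in_block,
# so slot 5 is (block 2, offset 1) and slot 2 is (block 1, offset 0).
slot_mapping = np.array([5, 2])
k_3d = (
    np.arange(2 * num_kv_heads * head_dim, dtype=np.float32).reshape(
        2, num_kv_heads, head_dim
    )
    + 1.0
)

# Flatten to one row per slot, write the new keys, then restore the shape.
# The explicit copy keeps the original cache untouched in this sketch.
flat_k = key_cache.reshape(-1, num_kv_heads, head_dim).copy()
flat_k[slot_mapping] = k_3d
new_k_cache = flat_k.reshape(key_cache.shape)
```

The value cache update discussed next follows the same shape: only the tensor names change.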

Comment on lines +509 to +510
flat_v[slot_mapping] = v_3d
new_v_cache = flat_v.reshape(kv_cache.value_caches[layer_idx].shape)

critical

MLX arrays do not support item assignment. Use mx.scatter_update to perform this update on the value cache.

Suggested change
- flat_v[slot_mapping] = v_3d
- new_v_cache = flat_v.reshape(kv_cache.value_caches[layer_idx].shape)
+ new_v_cache = mx.scatter_update(flat_v, slot_mapping, v_3d).reshape(
+     kv_cache.value_caches[layer_idx].shape)

flat = (inputs_embeds + mx.zeros((), dtype=input_dtype)).reshape(-1, hidden_size)
mask_np = np.asarray(mask_flat)
positions = mx.array(np.where(mask_np)[0], dtype=mx.uint32)
flat[positions] = mm_embeds_flat.astype(input_dtype)

critical

MLX arrays are immutable and do not support item assignment. This line will raise a TypeError. Use mx.scatter_update instead.

Suggested change
- flat[positions] = mm_embeds_flat.astype(input_dtype)
+ flat = mx.scatter_update(flat, positions, mm_embeds_flat.astype(input_dtype))

flat = latent_cache.latent_caches[layer_idx].reshape(
    -1, latent_cache.latent_dim
)
flat[mx.array(ctx.slot_mapping, dtype=mx.int64)] = latent_flat

critical

MLX arrays do not support item assignment. Use mx.scatter_update to update the latent cache.

Suggested change
- flat[mx.array(ctx.slot_mapping, dtype=mx.int64)] = latent_flat
+ flat = mx.scatter_update(flat, mx.array(ctx.slot_mapping, dtype=mx.int64), latent_flat)

merged_state = mx.zeros(tuple(shape), template.dtype)
for batch_idx, value in enumerate(values):
    if value is None:
        continue

critical

MLX arrays do not support item assignment. Use mx.scatter_update to merge the caches.

Suggested change
- continue
+ merged_state = mx.scatter_update(merged_state, mx.array([batch_idx]), value)

# logits[0] produces a new lazy computation node (not a Python alias of
# logits), so __setitem__ here does not mutate the caller-held logits array.
result_2d = logits[0]
result_2d[logit_rows] = rows_mlx

critical

MLX arrays do not support item assignment. Use mx.scatter_update to apply the grammar bitmask to the logits.

Suggested change
- result_2d[logit_rows] = rows_mlx
+ result_2d = mx.scatter_update(result_2d, mx.array(logit_rows), rows_mlx)

Comment on lines +1394 to +1398
logger.error(
    "sample_tokens called with no pending state — "
    "neither _execute_model_state nor _pending_output was set."
)
return None

high

If sample_tokens is called without a valid pending state, it indicates a critical failure in the engine's state management. Returning None here will likely cause an AttributeError in the caller. It is safer to raise a RuntimeError to provide a clear failure point.

Suggested change
- logger.error(
-     "sample_tokens called with no pending state — "
-     "neither _execute_model_state nor _pending_output was set."
- )
- return None
+ raise RuntimeError(
+     "sample_tokens called with no pending state — "
+     "neither _execute_model_state nor _pending_output was set."
+ )
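
To make the trade-off in this comment concrete, the sketch below contrasts the two behaviours in plain Python. Only the attribute names _execute_model_state and _pending_output come from the snippet above. ToyRunner, both method names, consume, and the token_ids attribute are hypothetical stand-ins, not code from the PR.

```python
import logging

logger = logging.getLogger(__name__)


class ToyRunner:
    """Hypothetical stand-in for the model runner discussed above."""

    def __init__(self):
        self._execute_model_state = None
        self._pending_output = None

    def sample_tokens_returning_none(self):
        # Mirrors the original pattern: log the problem and return None.
        if self._execute_model_state is None and self._pending_output is None:
            logger.error("sample_tokens called with no pending state")
            return None
        return self._pending_output

    def sample_tokens_raising(self):
        # Mirrors the suggested pattern: fail at the point of detection.
        if self._execute_model_state is None and self._pending_output is None:
            raise RuntimeError("sample_tokens called with no pending state")
        return self._pending_output


def consume(output):
    # Hypothetical caller that assumes it always receives an output object.
    return output.token_ids
```

With the first variant, consume(ToyRunner().sample_tokens_returning_none()) fails inside the caller with an AttributeError on None, away from the real cause. With the second, the RuntimeError is raised inside sample_tokens_raising itself and carries the diagnostic message.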

@AlpinDale merged commit c19f1fe into main on May 5, 2026
1 check failed
@AlpinDale deleted the feat/metal branch on May 5, 2026 04:33