perf: optimize GDN performance on Metal#1670

Merged
AlpinDale merged 1 commit into main from perf/gdn on May 5, 2026

Conversation

@AlpinDale
Collaborator

Also, fixed KV cache allocation.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements a fast-path for hybrid models in the Metal backend, allowing single-sequence hybrid requests to utilize native MLX generation. Key changes include updating memory capacity calculations to include hybrid linear states, refining text backbone override logic for Qwen architectures, and introducing a greedy generation path that avoids logprob materialization. Review feedback highlights the need for safer evaluation of the prompt cache to handle uninitialized states and suggests restructuring the model execution logic to ensure non-paged decode batches are processed even when paged attention is active.

remaining = (total_prompt_tokens - prompt_processed_tokens) - 1
n_to_process = min(prefill_step_size, remaining)
_model_call(prompt[:n_to_process])
mx.eval([c.state for c in prompt_cache])

high

The mx.eval call here might fail if prompt_cache contains ArraysCache entries with None states, which can happen in hybrid models before full initialization. It is safer to filter out None values before evaluation to prevent runtime errors.

Suggested change
- mx.eval([c.state for c in prompt_cache])
+ mx.eval([s for c in prompt_cache for s in (c.state if isinstance(c.state, list) else [c.state]) if s is not None])
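The one-line suggestion above is dense, so here is a minimal sketch of the same flatten-and-filter step written out as a helper. It is an illustration only: `collect_cache_states` and `FakeCache` are hypothetical names, plain Python values stand in for MLX arrays, and treating tuples as containers is an assumption that goes beyond the suggested `isinstance(c.state, list)` check.

```python
from typing import Any, Iterable, List


def collect_cache_states(prompt_cache: Iterable[Any]) -> List[Any]:
    """Flatten per-entry cache states and drop uninitialized (None) ones.

    Hypothetical helper: each entry is only assumed to expose a `.state`
    attribute that is a single value, a list/tuple of values, or None.
    """
    states: List[Any] = []
    for cache in prompt_cache:
        state = cache.state
        # Normalize a single state into a one-element list so that both
        # shapes go through the same filtering step.
        parts = state if isinstance(state, (list, tuple)) else [state]
        # Skip slots that have not been initialized yet.
        states.extend(s for s in parts if s is not None)
    return states


class FakeCache:
    """Stand-in for a cache entry; not an MLX or mlx-lm class."""

    def __init__(self, state: Any) -> None:
        self.state = state


caches = [FakeCache(None), FakeCache([1, None, 2]), FakeCache(3)]
flat = collect_cache_states(caches)
print(flat)
```

In the real code path the resulting list would be what gets handed to `mx.eval`, so an entry whose state is still `None` never reaches the evaluation call.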

Comment on lines +1723 to 1724
if batch.valid_decode_reqs:
self._run_non_paged_decode_batch(batch)

high

The _run_non_paged_decode_batch call is currently skipped if has_paged_work() is true. While the fast-path for hybrid models is restricted to max_num_seqs=1, this logic could lead to silent failures or missing outputs if multiple requests (one paged, one MLX-native) are ever scheduled together in future iterations. Moving this call before the paged-work check ensures all scheduled requests are processed.

Suggested change
- if batch.valid_decode_reqs:
-     self._run_non_paged_decode_batch(batch)
+ if batch.valid_decode_reqs:
+     self._run_non_paged_decode_batch(batch)
+ if self._paged_attention_backend is not None and batch.has_paged_work():
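To make the ordering concern concrete, the following is a rough sketch of a dispatch function that handles non-paged decode requests before the paged-work check, so a mixed batch loses nothing. Everything here is hypothetical scaffolding (`Batch`, `paged_reqs`, `execute`, and the `paged_backend_available` flag are invented for illustration); only `valid_decode_reqs` and `has_paged_work()` mirror names that appear in the review.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Batch:
    """Hypothetical stand-in for the scheduler batch in the review."""

    valid_decode_reqs: List[str] = field(default_factory=list)
    paged_reqs: List[str] = field(default_factory=list)  # assumed field

    def has_paged_work(self) -> bool:
        return bool(self.paged_reqs)


def execute(batch: Batch, paged_backend_available: bool = True) -> List[Tuple[str, str]]:
    """Return (path, request) pairs in the order they were processed."""
    processed: List[Tuple[str, str]] = []

    # Non-paged decode runs first and does not depend on the paged check,
    # so these requests are handled even when paged work is also present.
    if batch.valid_decode_reqs:
        processed.extend(("non_paged", r) for r in batch.valid_decode_reqs)

    # Paged work is handled afterwards, only if a backend is available.
    if paged_backend_available and batch.has_paged_work():
        processed.extend(("paged", r) for r in batch.paged_reqs)

    return processed


mixed = Batch(valid_decode_reqs=["req-a"], paged_reqs=["req-b"])
result = execute(mixed)
print(result)
```

With the two checks kept independent, a batch holding one request of each kind yields output for both, which is the behavior the review asks for.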

@AlpinDale AlpinDale merged commit f980722 into main May 5, 2026
1 check failed
@AlpinDale AlpinDale deleted the perf/gdn branch May 5, 2026 09:02