Conversation
There was a problem hiding this comment.
Code Review
This pull request implements a fast-path for hybrid models in the Metal backend, allowing single-sequence hybrid requests to utilize native MLX generation. Key changes include updating memory capacity calculations to include hybrid linear states, refining text backbone override logic for Qwen architectures, and introducing a greedy generation path that avoids logprob materialization. Review feedback highlights the need for safer evaluation of the prompt cache to handle uninitialized states and suggests restructuring the model execution logic to ensure non-paged decode batches are processed even when paged attention is active.
| remaining = (total_prompt_tokens - prompt_processed_tokens) - 1 | ||
| n_to_process = min(prefill_step_size, remaining) | ||
| _model_call(prompt[:n_to_process]) | ||
| mx.eval([c.state for c in prompt_cache]) |
There was a problem hiding this comment.
The mx.eval call here might fail if prompt_cache contains ArraysCache entries with None states, which can happen in hybrid models before full initialization. It is safer to filter out None values before evaluation to prevent runtime errors.
| mx.eval([c.state for c in prompt_cache]) | |
| mx.eval([s for c in prompt_cache for s in (c.state if isinstance(c.state, list) else [c.state]) if s is not None]) |
| if batch.valid_decode_reqs: | ||
| self._run_non_paged_decode_batch(batch) |
There was a problem hiding this comment.
The _run_non_paged_decode_batch call is currently skipped if has_paged_work() is true. While the fast-path for hybrid models is restricted to max_num_seqs=1, this logic could lead to silent failures or missing outputs if multiple requests (one paged, one MLX-native) are ever scheduled together in future iterations. Moving this call before the paged-work check ensures all scheduled requests are processed.
| if batch.valid_decode_reqs: | |
| self._run_non_paged_decode_batch(batch) | |
| if batch.valid_decode_reqs: | |
| self._run_non_paged_decode_batch(batch) | |
| if self._paged_attention_backend is not None and batch.has_paged_work(): |
Also, fixed KV cache allocation.