```python
hidden, _, new_kv_cache = self.layers(
    hidden,
    attention_mask=additive_attention_mask,
    positions=text_positions,
    adapter_indices=adapter_indices,
    kv_cache=kv_cache,
    output_hidden_states=False,
    gradient_checkpointing=self.config.gradient_checkpointing,
    is_training=is_training,
    cos=cos,
    sin=sin,
)
```
🔴 rope_deltas lost in non-manual-loop prefill path, breaking multimodal decode RoPE
When Qwen3VLModel.__call__ takes the non-manual-loop path (no deepstack features and output_hidden_states=False), it delegates to self.layers(...), which invokes StackedDecoderLayers.__call__. Inside that method, KVCache.update(None, keys_list, values_list, positions, attention_mask) is called at tx/layers/stacked.py:288 without passing rope_deltas, so the resulting KV cache has rope_deltas=None.
Root Cause and Impact
During prefill, rope_deltas is computed at tx/models/qwen3_vl_moe.py:1721 via get_rope_index(). The manual-loop path correctly passes it to KVCache.update() at line 1814, but the non-manual-loop path at line 1817 delegates to self.layers(...), which internally builds the KV cache without rope_deltas.
During subsequent decode steps, the code at line 1692 reads rope_deltas_from_cache = kv_cache.rope_deltas, which will be None. Decode therefore falls through to build_text_rope (line 1708) instead of the correct mRoPE path with deltas (lines 1693-1706). The result is incorrect positional embeddings during decode for multimodal (vision+text) inputs, producing wrong model outputs.
The non-manual-loop path is the default for standard text-only prefill and multimodal prefill without deepstack features when output_hidden_states=False.
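The effect of losing the deltas can be illustrated with a small sketch. This is plain NumPy with a hypothetical `decode_positions` helper, not the actual tx API; the real logic lives in Qwen3VLModel.__call__:

```python
import numpy as np

def decode_positions(cache_position, rope_deltas, batch_size):
    """Hypothetical sketch of mRoPE decode-position selection.

    During decode, each new token's position should be the running cache
    length shifted by the per-sequence delta that get_rope_index() computed
    at prefill time. If the deltas were dropped from the cache, the code can
    only fall back to plain text positions.
    """
    if rope_deltas is None:
        # Fallback (build_text_rope-style): wrong for vision+text sequences.
        return np.full((batch_size,), cache_position)
    # Correct mRoPE path: shift by the per-sequence delta.
    return cache_position + rope_deltas

# With deltas preserved, a sequence that contained images gets a shifted position.
print(decode_positions(10, np.array([0, -3]), 2).tolist())  # [10, 7]
# With deltas lost (rope_deltas=None), both sequences get the text position.
print(decode_positions(10, None, 2).tolist())  # [10, 10]
```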
Prompt for agents
In skyrl-tx/tx/models/qwen3_vl_moe.py, the non-manual-loop path at lines 1817-1828 delegates KV cache creation to StackedDecoderLayers.__call__, which calls KVCache.update() without rope_deltas. The rope_deltas computed at line 1721 need to be propagated into the resulting KV cache.
Two possible fixes:
1. (Simpler) After the self.layers() call returns new_kv_cache (i.e. after line 1828), patch it to include rope_deltas, e.g.:

```python
if new_kv_cache is not None and rope_deltas is not None:
    new_kv_cache = KVCache(
        keys=new_kv_cache.keys,
        values=new_kv_cache.values,
        cache_position=new_kv_cache.cache_position,
        rope_deltas=rope_deltas,
    )
```
2. (More thorough) Modify StackedDecoderLayers.__call__ in skyrl-tx/tx/layers/stacked.py to accept and forward rope_deltas through to KVCache.update() calls at lines 251 and 288. This would require adding a rope_deltas parameter to the method signature and passing it through.
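The pattern behind fix 1 can be sketched with a simplified stand-in for tx's KVCache. The real class holds JAX arrays and the reconstruction happens inside Qwen3VLModel.__call__; here a frozen dataclass and `dataclasses.replace` stand in for rebuilding the cache pytree:

```python
from dataclasses import dataclass, replace
from typing import Optional

@dataclass(frozen=True)
class KVCache:
    """Simplified stand-in for tx's KVCache (real fields are JAX arrays)."""
    keys: list
    values: list
    cache_position: int
    rope_deltas: Optional[list] = None

def patch_cache(new_kv_cache, rope_deltas):
    """Rebuild the cache returned by self.layers(...) so the prefill-computed
    rope_deltas are not dropped; no-op if either piece is missing."""
    if new_kv_cache is not None and rope_deltas is not None:
        return replace(new_kv_cache, rope_deltas=rope_deltas)
    return new_kv_cache

cache = KVCache(keys=[], values=[], cache_position=12)
patched = patch_cache(cache, rope_deltas=[0, -3])
print(patched.rope_deltas)  # [0, -3]
```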
Code Review
This pull request introduces the JAX implementation of the Qwen3-VL-MoE model, including its vision encoder, multimodal RoPE handling, and text decoder with sparse MoE blocks. New parity tests for both prefill and decode steps are added to verify hidden states and logits against Hugging Face's implementation. Core infrastructure changes include updating StackedDecoderLayers to pass arbitrary keyword arguments to underlying layers and extending KVCache to support rope_deltas for multimodal RoPE alignment during cache updates and padding. The JAX backend and model utility functions are updated to correctly identify and use the Qwen3VLMoeConfig for Qwen3-VL models. A review comment highlights that rope_deltas was not being passed in the non-manual loop path of Qwen3VLModel, which could lead to loss of multimodal decoding alignment, and suggests adding it to the self.layers call. Another comment points out that the Qwen3VLTopKRouter class is redundant and unused, recommending its removal.
When calling self.layers in the non-manual loop path, rope_deltas is not being passed. This will cause the rope_deltas to be lost in the resulting new_kv_cache, breaking multimodal decoding for models that rely on it.
Suggested change:

```diff
 hidden, _, new_kv_cache = self.layers(
     hidden,
     attention_mask=additive_attention_mask,
     positions=text_positions,
     adapter_indices=adapter_indices,
     kv_cache=kv_cache,
     output_hidden_states=False,
     gradient_checkpointing=self.config.gradient_checkpointing,
     is_training=is_training,
     cos=cos,
     sin=sin,
+    rope_deltas=rope_deltas,
 )
```
skyrl-tx/tx/models/qwen3_vl_moe.py
Outdated
```python
class Qwen3VLTopKRouter(nnx.Module):
    """Top-k router matching Qwen3VLMoeTextTopKRouter behavior."""

    def __init__(self, spec: Qwen3VLSpec, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None:
        self.spec = spec
        self.weight = Param(
            spec.text_num_experts,
            spec.text_hidden_size,
            dtype=dtype,
            kernel_init=nnx.initializers.zeros,
            rngs=rngs,
        )

    def __call__(self, hidden_states: jax.Array) -> tuple[jax.Array, jax.Array]:
        router_logits = jnp.einsum(
            "nh,eh->ne",
            hidden_states.astype(jnp.float32),
            self.weight.astype(jnp.float32),
            precision=jax.lax.Precision.HIGHEST,
        )
        router_probs = jax.nn.softmax(router_logits, axis=-1)
        top_k = min(self.spec.text_num_experts_per_tok, self.spec.text_num_experts)
        top_vals, top_idx = jax.lax.top_k(router_probs, top_k)
        denom = jnp.sum(top_vals, axis=-1, keepdims=True) + 1e-9
        routing_weights = (top_vals / denom).astype(hidden_states.dtype)
        return routing_weights, top_idx
```
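The routing math in the class above (softmax over all experts, keep the top-k probabilities, renormalize) can be checked in plain NumPy. Note that renormalizing the top-k slice of a softmax is equivalent to taking a softmax over just the selected logits:

```python
import numpy as np

# Worked example of top-k routing: 1 token, 4 experts.
logits = np.array([[2.0, 1.0, 0.5, -1.0]])
probs = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)

top_k = 2
# Indices of the top-k experts (descending by probability).
top_idx = np.argsort(probs, axis=-1)[:, ::-1][:, :top_k]
top_vals = np.take_along_axis(probs, top_idx, axis=-1)
# Renormalize the kept weights, with the same 1e-9 guard as the router.
routing_weights = top_vals / (top_vals.sum(-1, keepdims=True) + 1e-9)

print(top_idx.tolist())                       # [[0, 1]]
print(np.round(routing_weights, 3).tolist())  # [[0.731, 0.269]]
```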
@pcmoritz The test for the Qwen3 model (test_qwen3.py) is failing, which is unrelated to my PR. Can you take a look?

@tamoghnokandar Thanks for letting me know, you are right, I'll try to improve the test :)

Btw, I don't know if you saw, but we are also adding VLM support to the API server and engine in #1200, so together with this PR I think we can get it working end-to-end. There is already some code in nithinvc#2

Just saw it. I need to dive deep into that part of the codebase, and I might work on adding some of their features. I'll do it tomorrow though, I'm a bit busy today.

Btw @tamoghnokandar I just merged a PR for Qwen 3.5 (only the text model part), but it should show you how the text / vision config can be handled, let me know if you have any thoughts or suggestions for improvement: #1228 :)
This PR adds the Qwen3-VL model to the skyrl-tx directory.