
[tx] Add Qwen3 VL model #1196

Draft

tamoghnokandar wants to merge 11 commits into NovaSky-AI:main from tamoghnokandar:add_model

Conversation

tamoghnokandar (Contributor) commented Feb 23, 2026

This PR adds the Qwen3-VL model to the skyrl-tx directory.



tamoghnokandar changed the title from "Add model" to "Add Qwen3 VL model" on Feb 23, 2026
@tamoghnokandar tamoghnokandar marked this pull request as draft February 23, 2026 07:19
devin-ai-integration (bot) left a comment

Devin Review found 1 potential issue.


Comment on lines +1817 to +1828:

    hidden, _, new_kv_cache = self.layers(
        hidden,
        attention_mask=additive_attention_mask,
        positions=text_positions,
        adapter_indices=adapter_indices,
        kv_cache=kv_cache,
        output_hidden_states=False,
        gradient_checkpointing=self.config.gradient_checkpointing,
        is_training=is_training,
        cos=cos,
        sin=sin,
    )

🔴 rope_deltas lost in non-manual-loop prefill path, breaking multimodal decode RoPE

When Qwen3VLModel.__call__ takes the non-manual-loop path (no deepstack features and output_hidden_states=False), it delegates to self.layers(...) which calls StackedDecoderLayers.__call__. Inside that method, KVCache.update(None, keys_list, values_list, positions, attention_mask) is called at tx/layers/stacked.py:288 without passing rope_deltas. The resulting KV cache has rope_deltas=None.

Root Cause and Impact

During prefill, rope_deltas is computed at tx/models/qwen3_vl_moe.py:1721 via get_rope_index(). The manual-loop path correctly passes it to KVCache.update() at line 1814, but the non-manual-loop path at line 1817 delegates to self.layers(...), which internally builds the KV cache without rope_deltas.

During subsequent decode steps, the code at line 1692 reads rope_deltas_from_cache = kv_cache.rope_deltas, which will be None. This causes decode to fall back to build_text_rope (line 1708) instead of the correct mRoPE path with deltas (lines 1693-1706). The result is incorrect positional embeddings during decode for multimodal (vision+text) inputs, producing wrong model outputs.

The non-manual-loop path is the default for standard text-only prefill and multimodal prefill without deepstack features when output_hidden_states=False.
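The failure mode can be illustrated with a minimal sketch of the decode-time branch described above (function and argument names here are hypothetical simplifications, not the actual tx code):

```python
import numpy as np

def decode_positions(cache_position: int, rope_deltas):
    """Hypothetical sketch of the decode-time branch described above.

    If the prefill path dropped rope_deltas, the cache holds None and decode
    silently falls back to plain text positions; otherwise the mRoPE path
    shifts positions by the per-sequence delta.
    """
    if rope_deltas is None:
        # build_text_rope fallback: positions are just the cache offset
        return np.full((1,), cache_position)
    # mRoPE path: text position shifted by each sequence's multimodal delta
    return cache_position + rope_deltas

# With deltas preserved, a batch of two sequences gets shifted positions:
print(decode_positions(10, np.array([3, -2])))  # -> [13  8]
# With deltas lost (the bug), decode falls back to the plain text position:
print(decode_positions(10, None))  # -> [10]
```

For text-only inputs the two branches coincide (deltas are zero), which is why the bug only surfaces on multimodal decode.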

Prompt for agents
In skyrl-tx/tx/models/qwen3_vl_moe.py, the non-manual-loop path at lines 1817-1828 delegates KV cache creation to StackedDecoderLayers.__call__, which calls KVCache.update() without rope_deltas. The rope_deltas computed at line 1721 need to be propagated into the resulting KV cache.

Two possible fixes:

1. (Simpler) After the self.layers() call returns new_kv_cache, patch it to include rope_deltas. After line 1828, add something like:

       if new_kv_cache is not None and rope_deltas is not None:
           new_kv_cache = KVCache(
               keys=new_kv_cache.keys,
               values=new_kv_cache.values,
               cache_position=new_kv_cache.cache_position,
               rope_deltas=rope_deltas,
           )

2. (More thorough) Modify StackedDecoderLayers.__call__ in skyrl-tx/tx/layers/stacked.py to accept rope_deltas and forward it to the KVCache.update() calls at lines 251 and 288. This requires adding a rope_deltas parameter to the method signature and threading it through.
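Fix 1 amounts to rebuilding an immutable cache object with one extra field. A self-contained sketch of that pattern (the KVCache dataclass here is a stand-in for the real tx class, carrying only the fields named in the finding):

```python
from dataclasses import dataclass, replace
from typing import Any, Optional

@dataclass(frozen=True)
class KVCache:
    # Stand-in for tx's cache; only the fields mentioned in the finding.
    keys: Any
    values: Any
    cache_position: int
    rope_deltas: Optional[Any] = None

def patch_rope_deltas(new_kv_cache, rope_deltas):
    """Re-attach rope_deltas after the stacked layers built the cache."""
    if new_kv_cache is not None and rope_deltas is not None:
        # replace() copies all fields and overrides only rope_deltas
        new_kv_cache = replace(new_kv_cache, rope_deltas=rope_deltas)
    return new_kv_cache

cache = KVCache(keys=[], values=[], cache_position=4)  # rope_deltas lost
patched = patch_rope_deltas(cache, rope_deltas=[2])
print(patched.rope_deltas)  # -> [2]
```

dataclasses.replace avoids listing every field by hand, assuming the real KVCache is a dataclass; otherwise reconstructing it explicitly, as in fix 1 above, is equivalent.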

gemini-code-assist (bot) left a comment

Code Review

This pull request introduces the JAX implementation of the Qwen3-VL-MoE model, including its vision encoder, multimodal RoPE handling, and text decoder with sparse MoE blocks. New parity tests for both prefill and decode steps are added to verify hidden states and logits against Hugging Face's implementation.

Core infrastructure changes include updating StackedDecoderLayers to pass arbitrary keyword arguments to underlying layers and extending KVCache to support rope_deltas for multimodal RoPE alignment during cache updates and padding. The JAX backend and model utility functions are updated to correctly identify and use the Qwen3VLMoeConfig for Qwen3-VL models.

One review comment highlights that rope_deltas was not being passed in the non-manual-loop path of Qwen3VLModel, which could lose multimodal decoding alignment, and suggests adding it to the self.layers call. Another points out that the Qwen3VLTopKRouter class is redundant and unused, recommending its removal.

Comment on lines +1817 to +1828:

    hidden, _, new_kv_cache = self.layers(
        hidden,
        attention_mask=additive_attention_mask,
        positions=text_positions,
        adapter_indices=adapter_indices,
        kv_cache=kv_cache,
        output_hidden_states=False,
        gradient_checkpointing=self.config.gradient_checkpointing,
        is_training=is_training,
        cos=cos,
        sin=sin,
    )

Severity: high

When calling self.layers in the non-manual loop path, rope_deltas is not being passed. This will cause the rope_deltas to be lost in the resulting new_kv_cache, breaking multimodal decoding for models that rely on it.

Suggested change (adds rope_deltas to the call):

     hidden, _, new_kv_cache = self.layers(
         hidden,
         attention_mask=additive_attention_mask,
         positions=text_positions,
         adapter_indices=adapter_indices,
         kv_cache=kv_cache,
         output_hidden_states=False,
         gradient_checkpointing=self.config.gradient_checkpointing,
         is_training=is_training,
         cos=cos,
         sin=sin,
    +    rope_deltas=rope_deltas,
     )

Comment on lines +1185 to +1211:

    class Qwen3VLTopKRouter(nnx.Module):
        """Top-k router matching Qwen3VLMoeTextTopKRouter behavior."""

        def __init__(self, spec: Qwen3VLSpec, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None:
            self.spec = spec
            self.weight = Param(
                spec.text_num_experts,
                spec.text_hidden_size,
                dtype=dtype,
                kernel_init=nnx.initializers.zeros,
                rngs=rngs,
            )

        def __call__(self, hidden_states: jax.Array) -> tuple[jax.Array, jax.Array]:
            router_logits = jnp.einsum(
                "nh,eh->ne",
                hidden_states.astype(jnp.float32),
                self.weight.astype(jnp.float32),
                precision=jax.lax.Precision.HIGHEST,
            )
            router_probs = jax.nn.softmax(router_logits, axis=-1)
            top_k = min(self.spec.text_num_experts_per_tok, self.spec.text_num_experts)
            top_vals, top_idx = jax.lax.top_k(router_probs, top_k)
            denom = jnp.sum(top_vals, axis=-1, keepdims=True) + 1e-9
            routing_weights = (top_vals / denom).astype(hidden_states.dtype)
            return routing_weights, top_idx


Severity: medium

The Qwen3VLTopKRouter class appears to be unused. Qwen3VLSparseMoeBlock instantiates an nnx.Linear layer for the router and uses Qwen3VLExperts to handle the top-k selection and expert dispatch. This redundant class should be removed.
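For reference, the routing computation this unused class duplicates (softmax over router logits, top-k selection, renormalization of the selected weights) reduces to a few lines; a minimal NumPy sketch of the same math, with toy shapes:

```python
import numpy as np

def topk_route(hidden, router_weight, top_k):
    """Softmax -> top-k -> renormalize, as in the unused router above."""
    logits = hidden.astype(np.float32) @ router_weight.T   # (n, num_experts)
    # numerically stable softmax over experts
    probs = np.exp(logits - logits.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    top_idx = np.argsort(-probs, axis=-1)[:, :top_k]       # expert indices
    top_vals = np.take_along_axis(probs, top_idx, axis=-1)
    # renormalize so the selected experts' weights sum to ~1
    weights = top_vals / (top_vals.sum(-1, keepdims=True) + 1e-9)
    return weights, top_idx

hidden = np.array([[1.0, 0.0]])                       # one token, hidden=2
w = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])    # 3 experts
weights, idx = topk_route(hidden, w, top_k=2)
print(idx[0])          # the two highest-probability experts
print(weights.sum())   # renormalized weights sum to ~1
```

As the reviewer notes, Qwen3VLSparseMoeBlock already performs this via an nnx.Linear router plus Qwen3VLExperts, so the class adds nothing.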

tamoghnokandar changed the title from "Add Qwen3 VL model" to "[tx] Add Qwen3 VL model" on Feb 23, 2026
@pcmoritz pcmoritz added the tx label Feb 23, 2026
tamoghnokandar (Contributor, Author)

@pcmoritz The test for the Qwen3 model (test_qwen3.py) is failing, which is unrelated to my PR. Can you take a look?

pcmoritz (Collaborator)

@tamoghnokandar Thanks for letting me know, you are right, I'll try to improve the test :)

pcmoritz (Collaborator)

Btw, in case you didn't see it: we are also adding VLM support to the API server and engine in #1200, so together with this PR I think we can get it working end-to-end.

There is already some code in nithinvc#2

tamoghnokandar (Contributor, Author)

Just saw it. I need to dive deep into that part of the codebase, and I might work on adding some of their features. I’ll do it tomorrow though, I'm a bit busy today.

pcmoritz (Collaborator) commented Mar 3, 2026

Btw @tamoghnokandar I just merged a PR for Qwen 3.5 (only the text model part), but it should show you how the text / vision config can be handled, let me know if you have any thoughts or suggestions for improvement: #1228 :)
