From 247266bd3d4a31054f20038cbe3b5a55a181e7ce Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 14 Apr 2026 12:19:16 -0500 Subject: [PATCH 1/2] Initial Qwen 3.5 adapter --- tests/unit/test_qwen3_5_adapter.py | 643 ++++++++++++++++++ .../factories/architecture_adapter_factory.py | 2 + .../model_bridge/sources/transformers.py | 6 + .../supported_architectures/__init__.py | 4 + .../supported_architectures/qwen3_5.py | 175 +++++ .../tools/model_registry/__init__.py | 1 + .../model_registry/data/supported_models.json | 57 +- .../data/verification_history.json | 22 +- 8 files changed, 906 insertions(+), 4 deletions(-) create mode 100644 tests/unit/test_qwen3_5_adapter.py create mode 100644 transformer_lens/model_bridge/supported_architectures/qwen3_5.py diff --git a/tests/unit/test_qwen3_5_adapter.py b/tests/unit/test_qwen3_5_adapter.py new file mode 100644 index 000000000..8fd885174 --- /dev/null +++ b/tests/unit/test_qwen3_5_adapter.py @@ -0,0 +1,643 @@ +"""Unit tests for the Qwen3_5 architecture adapter (Phase A+B). + +Tests cover: +1. Registration: adapter importable, in SUPPORTED_ARCHITECTURES, in HF_SUPPORTED_ARCHITECTURES +2. Component mapping: correct bridge hierarchy with only universal submodules (no self_attn), + GatedMLPBridge with gate/in/out LinearBridge submodules +3. Config attributes: all cfg attributes set correctly +4. Weight conversions: preprocess_weights correctly slices q_proj.weight per-head +5. Integration: end-to-end tests with a tiny programmatically-constructed model + +Note: Qwen3_5 is supported only via TransformerBridge, not HookedTransformer. 
+""" + +import pytest + +from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, +) +from transformer_lens.tools.model_registry import HF_SUPPORTED_ARCHITECTURES + +# ============================================================================ +# Test: Registration +# ============================================================================ + + +class TestQwen3_5Registration: + """Verify the adapter is properly registered in all lookup tables.""" + + def test_adapter_importable(self): + """Qwen3_5ArchitectureAdapter must be importable.""" + from transformer_lens.model_bridge.supported_architectures import ( + Qwen3_5ArchitectureAdapter, + ) + + assert Qwen3_5ArchitectureAdapter is not None + + def test_in_supported_architectures(self): + """Qwen3_5ForCausalLM must be in SUPPORTED_ARCHITECTURES.""" + assert "Qwen3_5ForCausalLM" in SUPPORTED_ARCHITECTURES + + def test_in_hf_supported_architectures(self): + """Qwen3_5ForCausalLM must be in HF_SUPPORTED_ARCHITECTURES.""" + assert "Qwen3_5ForCausalLM" in HF_SUPPORTED_ARCHITECTURES + + def test_adapter_class_correct(self): + """The adapter class must be Qwen3_5ArchitectureAdapter.""" + from transformer_lens.model_bridge.supported_architectures import ( + Qwen3_5ArchitectureAdapter, + ) + + assert SUPPORTED_ARCHITECTURES["Qwen3_5ForCausalLM"] is Qwen3_5ArchitectureAdapter + + +# ============================================================================ +# Helpers: TransformerBridgeConfig for adapter instantiation +# ============================================================================ + + +def _make_bridge_cfg(**overrides): + """Create a minimal TransformerBridgeConfig for Qwen3_5 adapter tests.""" + from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig + + defaults = dict( + d_model=1024, + d_head=256, + n_heads=8, + n_layers=24, + n_ctx=2048, + d_vocab=248320, + n_key_value_heads=2, + architecture="Qwen3_5ForCausalLM", + ) + 
defaults.update(overrides) + return TransformerBridgeConfig(**defaults) + + +# ============================================================================ +# Test: Component Mapping (Phase A+B) +# ============================================================================ + + +class TestQwen3_5ComponentMapping: + """Verify the component_mapping structure for Qwen3_5. + + The key invariant: self_attn is NOT mapped as a block submodule because + linear-attention layers lack self_attn. Only universally present submodules + (norms, dense MLP) are mapped. Unlike Qwen3Next, the MLP is a GatedMLPBridge + with enumerated gate/in/out LinearBridge submodules. + """ + + @pytest.fixture + def adapter(self): + from transformer_lens.model_bridge.supported_architectures.qwen3_5 import ( + Qwen3_5ArchitectureAdapter, + ) + + cfg = _make_bridge_cfg() + return Qwen3_5ArchitectureAdapter(cfg) + + # ---- Top-level keys ---- + + def test_component_mapping_keys(self, adapter): + """component_mapping must have exactly the expected top-level keys.""" + assert set(adapter.component_mapping.keys()) == { + "embed", + "rotary_emb", + "blocks", + "ln_final", + "unembed", + } + + # ---- HF path names ---- + + def test_embed_path(self, adapter): + """embed maps to model.embed_tokens.""" + assert adapter.component_mapping["embed"].name == "model.embed_tokens" + + def test_rotary_emb_path(self, adapter): + """rotary_emb maps to model.rotary_emb.""" + assert adapter.component_mapping["rotary_emb"].name == "model.rotary_emb" + + def test_blocks_path(self, adapter): + """blocks maps to model.layers.""" + assert adapter.component_mapping["blocks"].name == "model.layers" + + def test_ln_final_path(self, adapter): + """ln_final maps to model.norm.""" + assert adapter.component_mapping["ln_final"].name == "model.norm" + + def test_unembed_path(self, adapter): + """unembed maps to lm_head.""" + assert adapter.component_mapping["unembed"].name == "lm_head" + + # ---- Block submodules ---- + + def 
test_block_submodules_keys(self, adapter): + """blocks submodules must contain ln1, ln2, mlp but NOT attn. + + Critical correctness test: self_attn is absent on linear-attention + layers, so mapping attn as a block submodule would crash on those layers. + """ + submodules = adapter.component_mapping["blocks"].submodules + assert set(submodules.keys()) == {"ln1", "ln2", "mlp"} + + def test_no_attn_in_block_submodules(self, adapter): + """attn must NOT appear as a block submodule (hybrid architecture safety check).""" + submodules = adapter.component_mapping["blocks"].submodules + assert "attn" not in submodules + + def test_ln1_path(self, adapter): + """ln1 maps to input_layernorm.""" + assert adapter.component_mapping["blocks"].submodules["ln1"].name == "input_layernorm" + + def test_ln2_path(self, adapter): + """ln2 maps to post_attention_layernorm.""" + assert ( + adapter.component_mapping["blocks"].submodules["ln2"].name == "post_attention_layernorm" + ) + + def test_mlp_path(self, adapter): + """mlp maps to mlp.""" + assert adapter.component_mapping["blocks"].submodules["mlp"].name == "mlp" + + # ---- MLP submodules ---- + + def test_mlp_submodule_keys(self, adapter): + """mlp submodules must be exactly {gate, in, out}.""" + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert set(mlp.submodules.keys()) == {"gate", "in", "out"} + + def test_mlp_gate_path(self, adapter): + """mlp.gate maps to gate_proj.""" + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert mlp.submodules["gate"].name == "gate_proj" + + def test_mlp_in_path(self, adapter): + """mlp.in maps to up_proj.""" + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert mlp.submodules["in"].name == "up_proj" + + def test_mlp_out_path(self, adapter): + """mlp.out maps to down_proj.""" + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert mlp.submodules["out"].name == "down_proj" + + # ---- Bridge types ---- + + def 
test_blocks_bridge_type(self, adapter): + """blocks uses BlockBridge.""" + from transformer_lens.model_bridge.generalized_components import BlockBridge + + assert isinstance(adapter.component_mapping["blocks"], BlockBridge) + + def test_rotary_emb_bridge_type(self, adapter): + """rotary_emb uses RotaryEmbeddingBridge.""" + from transformer_lens.model_bridge.generalized_components import ( + RotaryEmbeddingBridge, + ) + + assert isinstance(adapter.component_mapping["rotary_emb"], RotaryEmbeddingBridge) + + def test_ln1_bridge_type(self, adapter): + """ln1 uses RMSNormalizationBridge.""" + from transformer_lens.model_bridge.generalized_components import ( + RMSNormalizationBridge, + ) + + ln1 = adapter.component_mapping["blocks"].submodules["ln1"] + assert isinstance(ln1, RMSNormalizationBridge) + + def test_ln2_bridge_type(self, adapter): + """ln2 uses RMSNormalizationBridge.""" + from transformer_lens.model_bridge.generalized_components import ( + RMSNormalizationBridge, + ) + + ln2 = adapter.component_mapping["blocks"].submodules["ln2"] + assert isinstance(ln2, RMSNormalizationBridge) + + def test_mlp_bridge_type(self, adapter): + """mlp uses GatedMLPBridge (dense gated MLP, not MoE).""" + from transformer_lens.model_bridge.generalized_components import GatedMLPBridge + + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert isinstance(mlp, GatedMLPBridge) + + def test_mlp_gate_bridge_type(self, adapter): + """mlp.gate uses LinearBridge.""" + from transformer_lens.model_bridge.generalized_components import LinearBridge + + gate = adapter.component_mapping["blocks"].submodules["mlp"].submodules["gate"] + assert isinstance(gate, LinearBridge) + + def test_mlp_in_bridge_type(self, adapter): + """mlp.in uses LinearBridge.""" + from transformer_lens.model_bridge.generalized_components import LinearBridge + + up = adapter.component_mapping["blocks"].submodules["mlp"].submodules["in"] + assert isinstance(up, LinearBridge) + + def 
test_mlp_out_bridge_type(self, adapter): + """mlp.out uses LinearBridge.""" + from transformer_lens.model_bridge.generalized_components import LinearBridge + + down = adapter.component_mapping["blocks"].submodules["mlp"].submodules["out"] + assert isinstance(down, LinearBridge) + + # ---- weight_processing_conversions ---- + + def test_weight_processing_conversions_empty(self, adapter): + """weight_processing_conversions is empty (no attention submodules mapped).""" + assert adapter.weight_processing_conversions == {} + + +# ============================================================================ +# Test: Config Attributes (Phase A+B) +# ============================================================================ + + +class TestQwen3_5ConfigAttributes: + """Verify all cfg attributes are set correctly by the adapter.""" + + @pytest.fixture + def adapter(self): + from transformer_lens.model_bridge.supported_architectures.qwen3_5 import ( + Qwen3_5ArchitectureAdapter, + ) + + cfg = _make_bridge_cfg() + return Qwen3_5ArchitectureAdapter(cfg) + + def test_normalization_type(self, adapter): + assert adapter.cfg.normalization_type == "RMS" + + def test_positional_embedding_type(self, adapter): + assert adapter.cfg.positional_embedding_type == "rotary" + + def test_final_rms(self, adapter): + assert adapter.cfg.final_rms is True + + def test_gated_mlp(self, adapter): + assert adapter.cfg.gated_mlp is True + + def test_attn_only(self, adapter): + assert adapter.cfg.attn_only is False + + def test_uses_rms_norm(self, adapter): + assert adapter.cfg.uses_rms_norm is True + + def test_default_prepend_bos(self, adapter): + assert adapter.cfg.default_prepend_bos is False + + def test_supports_fold_ln_false(self, adapter): + """supports_fold_ln must be False: hybrid layers break fold_ln.""" + assert adapter.supports_fold_ln is False + + def test_attn_implementation_eager(self, adapter): + """attn_implementation must be 'eager' for output_attentions support.""" + assert 
adapter.cfg.attn_implementation == "eager" + + def test_n_key_value_heads_set_when_gqa(self): + """n_key_value_heads is set on cfg when the input config has it.""" + from transformer_lens.model_bridge.supported_architectures.qwen3_5 import ( + Qwen3_5ArchitectureAdapter, + ) + + cfg = _make_bridge_cfg(n_key_value_heads=2) + adapter = Qwen3_5ArchitectureAdapter(cfg) + assert adapter.cfg.n_key_value_heads == 2 + + def test_n_key_value_heads_not_set_when_absent(self): + """n_key_value_heads is not set when the config doesn't have it.""" + from transformer_lens.config.TransformerBridgeConfig import ( + TransformerBridgeConfig, + ) + from transformer_lens.model_bridge.supported_architectures.qwen3_5 import ( + Qwen3_5ArchitectureAdapter, + ) + + # Config without n_key_value_heads + cfg = TransformerBridgeConfig( + d_model=1024, + d_head=256, + n_heads=8, + n_layers=24, + n_ctx=2048, + d_vocab=248320, + architecture="Qwen3_5ForCausalLM", + ) + adapter = Qwen3_5ArchitectureAdapter(cfg) + # n_key_value_heads should equal n_heads (standard MHA default) + assert not ( + hasattr(adapter.cfg, "n_key_value_heads") + and adapter.cfg.n_key_value_heads is not None + and adapter.cfg.n_key_value_heads != adapter.cfg.n_heads + ) + + +# ============================================================================ +# Test: preprocess_weights (Phase A+B) +# ============================================================================ + + +class TestQwen3_5PreprocessWeights: + """Verify preprocess_weights correctly slices q_proj.weight per-head. + + Background: In Qwen3_5, q_proj.weight has shape (n_heads * head_dim * 2, hidden_size) + where rows are organized as interleaved per-head pairs: + head_0_query (d_head rows), head_0_gate (d_head rows), + head_1_query (d_head rows), head_1_gate (d_head rows), ... + + A naive first-half slice would be wrong. The correct approach reshapes by + head and takes only the first d_head rows per head (the query half). 
+ """ + + N_HEADS = 4 + D_HEAD = 8 + HIDDEN_SIZE = 32 + + @pytest.fixture + def adapter(self): + from transformer_lens.model_bridge.supported_architectures.qwen3_5 import ( + Qwen3_5ArchitectureAdapter, + ) + + cfg = _make_bridge_cfg( + n_heads=self.N_HEADS, + d_head=self.D_HEAD, + d_model=self.HIDDEN_SIZE, + n_key_value_heads=self.N_HEADS, # MHA for simplicity + ) + return Qwen3_5ArchitectureAdapter(cfg) + + def _make_q_proj_weight(self): + """Create a q_proj.weight tensor with distinct per-head-row values.""" + import torch + + total_rows = self.N_HEADS * self.D_HEAD * 2 + w = torch.zeros(total_rows, self.HIDDEN_SIZE) + for row_idx in range(total_rows): + w[row_idx] = float(row_idx) + return w + + def test_q_proj_output_shape(self, adapter): + """preprocess_weights reduces q_proj rows from n_heads*d_head*2 to n_heads*d_head.""" + import torch + + w = self._make_q_proj_weight() + state_dict = {"model.layers.3.self_attn.q_proj.weight": w} + result = adapter.preprocess_weights(state_dict) + out = result["model.layers.3.self_attn.q_proj.weight"] + assert out.shape == (self.N_HEADS * self.D_HEAD, self.HIDDEN_SIZE) + + def test_q_proj_selects_query_rows_not_naive_first_half(self, adapter): + """For each head i, output rows [i*d_head:(i+1)*d_head] == input rows + [i*d_head*2 : i*d_head*2 + d_head] (per-head interleaved layout).""" + import torch + + w = self._make_q_proj_weight() + state_dict = {"model.layers.0.self_attn.q_proj.weight": w} + result = adapter.preprocess_weights(state_dict) + out = result["model.layers.0.self_attn.q_proj.weight"] + + for head_idx in range(self.N_HEADS): + out_rows = out[head_idx * self.D_HEAD : (head_idx + 1) * self.D_HEAD] + expected_start = head_idx * self.D_HEAD * 2 + expected_rows = w[expected_start : expected_start + self.D_HEAD] + assert torch.equal(out_rows, expected_rows), ( + f"Head {head_idx}: output rows do not match expected query rows. 
" + f"Got row values starting at {out_rows[0, 0].item()}, " + f"expected starting at {expected_rows[0, 0].item()}" + ) + + def test_naive_slice_would_be_wrong(self, adapter): + """Naive first-half slice gives different (wrong) results for n_heads > 1.""" + import torch + + w = self._make_q_proj_weight() + state_dict = {"model.layers.0.self_attn.q_proj.weight": w} + result = adapter.preprocess_weights(state_dict) + correct_out = result["model.layers.0.self_attn.q_proj.weight"] + naive_out = w[: self.N_HEADS * self.D_HEAD] + + if self.N_HEADS > 1: + assert not torch.equal(correct_out, naive_out), ( + "Naive first-half slice gave the same result as per-head slice — " + "test setup may be wrong" + ) + + def test_non_q_proj_weights_unchanged(self, adapter): + """k_proj, v_proj, and down_proj weights are NOT modified by preprocess_weights.""" + import torch + + k_proj = torch.randn(self.N_HEADS * self.D_HEAD, self.HIDDEN_SIZE) + down_proj = torch.randn(self.HIDDEN_SIZE, self.N_HEADS * self.D_HEAD) + state_dict = { + "model.layers.0.self_attn.k_proj.weight": k_proj.clone(), + "model.layers.0.mlp.down_proj.weight": down_proj.clone(), + } + result = adapter.preprocess_weights(state_dict) + assert torch.equal(result["model.layers.0.self_attn.k_proj.weight"], k_proj) + assert torch.equal(result["model.layers.0.mlp.down_proj.weight"], down_proj) + + def test_multiple_layers_all_processed(self, adapter): + """q_proj.weight tensors across multiple layers are all sliced correctly.""" + import torch + + w0 = self._make_q_proj_weight() + w3 = self._make_q_proj_weight() * 2 + state_dict = { + "model.layers.0.self_attn.q_proj.weight": w0, + "model.layers.3.self_attn.q_proj.weight": w3, + } + result = adapter.preprocess_weights(state_dict) + expected_shape = (self.N_HEADS * self.D_HEAD, self.HIDDEN_SIZE) + assert result["model.layers.0.self_attn.q_proj.weight"].shape == expected_shape + assert result["model.layers.3.self_attn.q_proj.weight"].shape == expected_shape + + def 
test_empty_state_dict_returns_empty(self, adapter): + """preprocess_weights with an empty state dict returns an empty dict.""" + assert adapter.preprocess_weights({}) == {} + + def test_state_dict_without_q_proj_unchanged(self, adapter): + """A state dict with no q_proj keys is returned unmodified.""" + import torch + + state_dict = {"model.embed_tokens.weight": torch.randn(100, self.HIDDEN_SIZE)} + original_keys = set(state_dict.keys()) + result = adapter.preprocess_weights(state_dict) + assert set(result.keys()) == original_keys + + def test_weight_processing_conversions_is_empty_dict(self, adapter): + """weight_processing_conversions is {} — q_proj slicing is in preprocess_weights.""" + assert adapter.weight_processing_conversions == {} + + +# ============================================================================ +# Test: Integration (Phase A+B) +# ============================================================================ + +try: + from transformers import Qwen3_5ForCausalLM as _Qwen3_5ForCausalLM + from transformers import Qwen3_5TextConfig + + _QWEN3_5_AVAILABLE = True +except ImportError: + _QWEN3_5_AVAILABLE = False + + +def _make_tiny_hf_model(): + """Create a tiny Qwen3_5ForCausalLM for integration testing. + + 8 layers: layers 3 and 7 are full-attention (full_attention_interval=4), + layers 0-2 and 4-6 are linear-attention (GatedDeltaNet). + Dense gated MLP on all layers. 
+ """ + cfg = Qwen3_5TextConfig( + hidden_size=128, + num_hidden_layers=8, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=32, + intermediate_size=256, + vocab_size=512, + rms_norm_eps=1e-6, + hidden_act="silu", + full_attention_interval=4, + linear_conv_kernel_dim=4, + linear_key_head_dim=32, + linear_value_head_dim=32, + linear_num_key_heads=4, + linear_num_value_heads=4, + rope_parameters={ + "rope_theta": 10000.0, + "partial_rotary_factor": 0.25, + "rope_type": "default", + }, + ) + model = _Qwen3_5ForCausalLM(cfg) + model.eval() + return model + + +def _make_tiny_bridge(): + """Create a Qwen3_5 bridge from a tiny HF model.""" + from unittest.mock import MagicMock + + from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig + from transformer_lens.model_bridge import TransformerBridge + from transformer_lens.model_bridge.supported_architectures.qwen3_5 import ( + Qwen3_5ArchitectureAdapter, + ) + + hf_model = _make_tiny_hf_model() + + bridge_cfg = TransformerBridgeConfig( + d_model=128, + d_head=32, + n_heads=4, + n_layers=8, + n_ctx=2048, + d_vocab=512, + n_key_value_heads=2, + architecture="Qwen3_5ForCausalLM", + ) + adapter = Qwen3_5ArchitectureAdapter(bridge_cfg) + return TransformerBridge(hf_model, adapter, tokenizer=MagicMock()), hf_model + + +@pytest.mark.skipif( + not _QWEN3_5_AVAILABLE, + reason="Qwen3_5TextConfig / Qwen3_5ForCausalLM not available in installed transformers", +) +class TestQwen3_5Integration: + """End-to-end integration tests using a tiny programmatic Qwen3_5 model. + + The linear attention layers run via the torch fallback path when + flash-linear-attention / causal-conv1d are not installed. 
+ """ + + @pytest.fixture(scope="class") + def bridge_and_model(self): + """Create a tiny bridge + HF model pair, shared across the class.""" + return _make_tiny_bridge() + + @pytest.fixture(scope="class") + def bridge(self, bridge_and_model): + br, _ = bridge_and_model + return br + + @pytest.fixture(scope="class") + def hf_model(self, bridge_and_model): + _, hf = bridge_and_model + return hf + + def test_bridge_creation(self, bridge): + """TransformerBridge construction from a tiny Qwen3_5 model must succeed.""" + from transformer_lens.model_bridge import TransformerBridge + + assert isinstance(bridge, TransformerBridge) + + def test_hook_names_present(self, bridge): + """Key hook names must be present; blocks.0.attn.* must NOT be present. + + Verified: + - blocks.0.hook_resid_pre: linear-attention layer (layer 0) + - blocks.3.hook_resid_pre: first full-attention layer (layer 3) + - blocks.0.ln1.*: norm present on all layers (universal submodule) + - blocks.0.mlp.*: MLP present on all layers (universal submodule) + - blocks.0.attn.*: NOT present (self_attn absent on linear-attn layers) + """ + hook_keys = set(bridge.hook_dict.keys()) + + assert "blocks.0.hook_resid_pre" in hook_keys, "linear-attn layer must have hook_resid_pre" + assert "blocks.3.hook_resid_pre" in hook_keys, "full-attn layer must have hook_resid_pre" + assert any( + "blocks.0.ln1" in k for k in hook_keys + ), "blocks.0.ln1 submodule hooks must be present" + assert any( + "blocks.0.mlp" in k for k in hook_keys + ), "blocks.0.mlp submodule hooks must be present" + assert not any( + "blocks.0.attn" in k for k in hook_keys + ), "blocks.0.attn hooks must NOT be present (hybrid architecture)" + + def test_forward_pass_consistency(self, bridge, hf_model): + """Bridge output logits must match HF model output logits within atol=1e-4.""" + import torch + + tokens = torch.randint(0, 512, (1, 4)) + with torch.no_grad(): + hf_logits = hf_model(tokens).logits + bridge_logits = bridge(tokens) + + assert ( + 
hf_logits.shape == bridge_logits.shape + ), f"Shape mismatch: HF={hf_logits.shape}, bridge={bridge_logits.shape}" + assert torch.allclose( + hf_logits, bridge_logits, atol=1e-4 + ), f"Logit mismatch: max diff = {(hf_logits - bridge_logits).abs().max().item():.6f}" + + def test_hook_activation_shapes(self, bridge): + """A hook on blocks.0.mlp.hook_out must capture a (batch, seq, d_model) tensor.""" + import torch + + captured: list[torch.Tensor] = [] + + def capture_hook(tensor: torch.Tensor, hook: object) -> torch.Tensor: + captured.append(tensor.detach().clone()) + return tensor + + tokens = torch.randint(0, 512, (1, 4)) + with torch.no_grad(): + bridge.run_with_hooks(tokens, fwd_hooks=[("blocks.0.mlp.hook_out", capture_hook)]) + + assert len(captured) == 1, "Hook must fire exactly once per forward pass" + output = captured[0] + batch, seq, d_model = 1, 4, 128 + assert output.shape == ( + batch, + seq, + d_model, + ), f"Expected MLP output shape ({batch}, {seq}, {d_model}), got {output.shape}" diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index 6d4ca4964..41fa894a8 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -46,6 +46,7 @@ Phi3ArchitectureAdapter, PhiArchitectureAdapter, Qwen2ArchitectureAdapter, + Qwen3_5ArchitectureAdapter, Qwen3ArchitectureAdapter, Qwen3MoeArchitectureAdapter, Qwen3NextArchitectureAdapter, @@ -100,6 +101,7 @@ "Qwen3ForCausalLM": Qwen3ArchitectureAdapter, "Qwen3MoeForCausalLM": Qwen3MoeArchitectureAdapter, "Qwen3NextForCausalLM": Qwen3NextArchitectureAdapter, + "Qwen3_5ForCausalLM": Qwen3_5ArchitectureAdapter, "StableLmForCausalLM": StableLmArchitectureAdapter, "T5ForConditionalGeneration": T5ArchitectureAdapter, "NanoGPTForCausalLM": NanogptArchitectureAdapter, diff --git a/transformer_lens/model_bridge/sources/transformers.py 
b/transformer_lens/model_bridge/sources/transformers.py index e2d43077c..ce9907384 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -223,6 +223,12 @@ def determine_architecture_from_hf_config(hf_config): "qwen": "QwenForCausalLM", "qwen2": "Qwen2ForCausalLM", "qwen3": "Qwen3ForCausalLM", + # qwen3_5 is the top-level multimodal config type; qwen3_5_text is + # the text-only sub-config. Both map to the text-only adapter so + # Qwen3.5 checkpoints (which report qwen3_5 even when loaded as + # text-only) are routed to Qwen3_5ForCausalLM. + "qwen3_5": "Qwen3_5ForCausalLM", + "qwen3_5_text": "Qwen3_5ForCausalLM", "openelm": "OpenELMForCausalLM", "stablelm": "StableLmForCausalLM", "t5": "T5ForConditionalGeneration", diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py index 3ed80b776..5e9e9dec0 100644 --- a/transformer_lens/model_bridge/supported_architectures/__init__.py +++ b/transformer_lens/model_bridge/supported_architectures/__init__.py @@ -138,6 +138,9 @@ from transformer_lens.model_bridge.supported_architectures.qwen3_next import ( Qwen3NextArchitectureAdapter, ) +from transformer_lens.model_bridge.supported_architectures.qwen3_5 import ( + Qwen3_5ArchitectureAdapter, +) from transformer_lens.model_bridge.supported_architectures.stablelm import ( StableLmArchitectureAdapter, ) @@ -191,6 +194,7 @@ "Qwen3ArchitectureAdapter", "Qwen3MoeArchitectureAdapter", "Qwen3NextArchitectureAdapter", + "Qwen3_5ArchitectureAdapter", "StableLmArchitectureAdapter", "T5ArchitectureAdapter", ] diff --git a/transformer_lens/model_bridge/supported_architectures/qwen3_5.py b/transformer_lens/model_bridge/supported_architectures/qwen3_5.py new file mode 100644 index 000000000..b1e71e9f3 --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/qwen3_5.py @@ -0,0 +1,175 @@ +"""Qwen3_5 architecture 
adapter. + +Qwen3_5ForCausalLM is a hybrid linear-attention + full-attention architecture +with a dense gated MLP on every layer. Layers follow a repeating pattern of +3 GatedDeltaNet (linear attention) layers followed by 1 standard full-attention +layer (every 4th layer by default). + +Since self_attn is absent on linear-attention layers, we only map submodules +that exist on ALL layers (norms, MLP). The HF native forward handles +linear/full attention dispatch internally, and GatedMLPBridge maps the dense +gate_proj/up_proj/down_proj structure on every layer. + +Hook coverage: +- Block-level: hook_resid_pre, hook_resid_post on every layer +- Normalization: ln1 (input_layernorm), ln2 (post_attention_layernorm) +- MLP: hook_in, hook_out via GatedMLPBridge (gate_proj, up_proj, down_proj) +- Attention internals are NOT individually hooked (self_attn absent on + linear-attention layers; mapping it would crash on those layers) + +Optional parameters: +- n_key_value_heads: only set when using GQA (num_key_value_heads != num_attention_heads) +""" + +from typing import Any + +import torch + +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + GatedMLPBridge, + LinearBridge, + RMSNormalizationBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) + + +class Qwen3_5ArchitectureAdapter(ArchitectureAdapter): + """Architecture adapter for Qwen3_5 models. 
+ + Qwen3_5ForCausalLM is a hybrid linear-attention + full-attention + architecture with dense gated MLPs, sharing the same hybrid design as + Qwen3Next but replacing the sparse MoE MLP with a standard dense MLP: + - Uses RMSNorm for all normalizations + - Uses rotary position embeddings (RoPE) with partial rotation + - Every 4th layer is a full-attention layer (self_attn); the rest are + GatedDeltaNet linear-attention layers (linear_attn) + - Uses dense gated MLP (gate_proj + up_proj -> down_proj) on ALL layers + - No biases on any linear layers + - Full-attention layers have Q/K normalization (q_norm, k_norm) + - Full-attention q_proj outputs n_heads * head_dim * 2 (interleaved + query+gate layout); the preprocess_weights method slices the query half + + Since self_attn is absent on linear-attention layers, only universally + present submodules (norms, MLP) are mapped as block submodules. The HF + native forward handles per-layer attention dispatch internally. + + Optional parameters: + - n_key_value_heads: set when num_key_value_heads != num_attention_heads (GQA) + """ + + def __init__(self, cfg: Any) -> None: + """Initialize the Qwen3_5 architecture adapter.""" + super().__init__(cfg) + + # Core config attributes + self.cfg.normalization_type = "RMS" + self.cfg.positional_embedding_type = "rotary" + self.cfg.final_rms = True + self.cfg.gated_mlp = True + self.cfg.attn_only = False + self.cfg.uses_rms_norm = True + self.cfg.default_prepend_bos = False + + # Disable fold_ln: ln1 is followed by self_attn on full-attention + # layers and by linear_attn (GatedDeltaNet) on linear-attention layers, + # but neither is mapped as a bridge submodule (see class docstring for + # why). With no bridge-mapped target to fold into, the standard fold_ln + # pass leaves LN weights in an inconsistent state and the processed + # bridge output diverges from the unprocessed / HF output. Skipping + # fold_ln keeps processed-mode forward passes numerically equivalent. 
+ self.supports_fold_ln = False + + # Use eager attention because SDPA doesn't support output_attentions. + # Attention is not bridge-mapped here, so hook_attn_scores/hook_pattern are not exposed. + self.cfg.attn_implementation = "eager" + + # GQA: copy n_key_value_heads through whenever the input config defines it + if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None: + self.cfg.n_key_value_heads = cfg.n_key_value_heads + + self.weight_processing_conversions: dict = {} + self.component_mapping: dict = { + "embed": EmbeddingBridge(name="model.embed_tokens"), + "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg), + "blocks": BlockBridge( + name="model.layers", + submodules={ + "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg), + "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg), + # Dense gated MLP present on every layer (unlike Qwen3Next's MoE). + # gate_proj + up_proj feed into down_proj via SwiGLU activation. + "mlp": GatedMLPBridge( + name="mlp", + config=self.cfg, + submodules={ + "gate": LinearBridge(name="gate_proj"), + "in": LinearBridge(name="up_proj"), + "out": LinearBridge(name="down_proj"), + }, + ), + }, + ), + "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg), + "unembed": UnembeddingBridge(name="lm_head"), + } + + def prepare_loading(self, model_name: str, model_kwargs: dict) -> None: + """Swap the multimodal Qwen3_5Config for its text-only Qwen3_5TextConfig. + + Published Qwen3.5 checkpoints (e.g. Qwen/Qwen3.5-0.8B) carry + model_type='qwen3_5' and architectures=['Qwen3_5ForConditionalGeneration']. + AutoModelForCausalLM would load the full VLM (Qwen3_5ForConditionalGeneration) + with its vision tower, wasting memory and failing the bridge. + + Instead we replace model_kwargs['config'] with the nested text_config so + AutoModelForCausalLM loads Qwen3_5ForCausalLM (text only).
+ """ + config = model_kwargs.get("config") + if config is not None and hasattr(config, "text_config"): + model_kwargs["config"] = config.text_config + + def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None: + """No-op for hybrid models. + + Hybrid models don't map attention as a block submodule (self_attn is + absent on linear-attention layers), so there are no rotary embedding + references to set up. + + Note: to find which layers are full_attention at runtime, use: + layer_types = getattr(hf_model.config, "layer_types", []) + first_full_attn_idx = next( + i for i, t in enumerate(layer_types) if t == "full_attention" + ) + Do NOT use hf_model.config.full_attention_interval -- it is not stored + on the config object (consumed during __init__ to build layer_types). + """ + + def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Slice query half from q_proj.weight (interleaved per-head layout). + + In Qwen3_5, q_proj.weight has shape (n_heads * head_dim * 2, hidden_size). + Rows are organized as per-head interleaved: + head_0_query (d_head rows), head_0_gate (d_head rows), + head_1_query (d_head rows), head_1_gate (d_head rows), ... + + A naive first-half slice would be wrong. We must reshape by head, then + take the first d_head rows of each head (the query half). + + Note: since self_attn is NOT currently mapped as a bridge submodule, + these weights will not be loaded by the bridge. This method is included + for correctness and forward-compatibility. 
+ """ + n_heads = self.cfg.n_heads + d_head = self.cfg.d_head + keys_to_update = [k for k in state_dict if k.endswith(".self_attn.q_proj.weight")] + for key in keys_to_update: + w = state_dict[key] # shape: (n_heads * d_head * 2, hidden_size) + # Reshape to expose per-head layout + w = w.view(n_heads, d_head * 2, -1) + # Take only the first d_head rows of each head (query half) + state_dict[key] = w[:, :d_head, :].reshape(n_heads * d_head, -1) + return state_dict diff --git a/transformer_lens/tools/model_registry/__init__.py b/transformer_lens/tools/model_registry/__init__.py index 4d4dfb528..3150f159f 100644 --- a/transformer_lens/tools/model_registry/__init__.py +++ b/transformer_lens/tools/model_registry/__init__.py @@ -81,6 +81,7 @@ "Qwen2ForCausalLM", "Qwen3ForCausalLM", "Qwen3NextForCausalLM", + "Qwen3_5ForCausalLM", "StableLmForCausalLM", "T5ForConditionalGeneration", } diff --git a/transformer_lens/tools/model_registry/data/supported_models.json b/transformer_lens/tools/model_registry/data/supported_models.json index 432780e60..57f1d6425 100644 --- a/transformer_lens/tools/model_registry/data/supported_models.json +++ b/transformer_lens/tools/model_registry/data/supported_models.json @@ -6,9 +6,9 @@ "min_downloads": 500, "scan_duration_seconds": 3.9 }, - "total_architectures": 40, - "total_models": 6868, - "total_verified": 699, + "total_architectures": 43, + "total_models": 7006, + "total_verified": 703, "models": [ { "architecture_id": "Qwen3NextForCausalLM", @@ -98395,6 +98395,57 @@ "phase4_score": 70.4, "phase7_score": null, "phase8_score": null + }, + { + "architecture_id": "Qwen3_5ForCausalLM", + "model_id": "Qwen/Qwen3.5-0.8B", + "status": 1, + "verified_date": "2026-04-14", + "metadata": { + "downloads": 2577198, + "total_params": 950000000 + }, + "note": "Full verification completed with issues: P3=94.1% (failed: attention_output_centering)", + "phase1_score": 100.0, + "phase2_score": 100.0, + "phase3_score": 94.1, + "phase4_score": 91.5, + 
"phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Qwen3_5ForCausalLM", + "model_id": "Qwen/Qwen3.5-4B", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 2920685, + "total_params": 3660000000 + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Qwen3_5ForCausalLM", + "model_id": "Qwen/Qwen3.5-9B", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 5662081, + "total_params": 8750000000 + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null } ] } diff --git a/transformer_lens/tools/model_registry/data/verification_history.json b/transformer_lens/tools/model_registry/data/verification_history.json index 21d21c369..7fec78e9e 100644 --- a/transformer_lens/tools/model_registry/data/verification_history.json +++ b/transformer_lens/tools/model_registry/data/verification_history.json @@ -1,5 +1,5 @@ { - "last_updated": "2026-04-10T19:09:48.784882", + "last_updated": "2026-04-14T12:15:43.792442", "records": [ { "model_id": "Macropodus/macbert4mdcspell_v1", @@ -11230,6 +11230,26 @@ "notes": "Full verification completed with issues: P3=94.7% (failed: weight_modification)", "invalidated": false, "invalidation_reason": null + }, + { + "model_id": "Qwen/Qwen3.5-0.8B", + "architecture_id": "Qwen3_5ForCausalLM", + "verified_date": "2026-04-14", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=0.0% < 100.0% (failed: load_bridge_unprocessed) \u2014 Failed to load unprocessed TransformerBridge: Could not determine supported architecture from config. 
Available architectures: ['ApertusForCausalLM', ", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "Qwen/Qwen3.5-0.8B", + "architecture_id": "Qwen3_5ForCausalLM", + "verified_date": "2026-04-14", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed with issues: P3=94.1% (failed: attention_output_centering)", + "invalidated": false, + "invalidation_reason": null } ] } From f956dec0ec84b25b44f86692be8c9109bc300b27 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 14 Apr 2026 13:17:39 -0500 Subject: [PATCH 2/2] Latest verifications --- .../model_registry/data/supported_models.json | 16 ++++++++-------- .../data/verification_history.json | 12 +++++++++++- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/transformer_lens/tools/model_registry/data/supported_models.json b/transformer_lens/tools/model_registry/data/supported_models.json index 57f1d6425..7b8e3dd4e 100644 --- a/transformer_lens/tools/model_registry/data/supported_models.json +++ b/transformer_lens/tools/model_registry/data/supported_models.json @@ -8,7 +8,7 @@ }, "total_architectures": 43, "total_models": 7006, - "total_verified": 703, + "total_verified": 704, "models": [ { "architecture_id": "Qwen3NextForCausalLM", @@ -98416,17 +98416,17 @@ { "architecture_id": "Qwen3_5ForCausalLM", "model_id": "Qwen/Qwen3.5-4B", - "status": 0, - "verified_date": null, + "status": 1, + "verified_date": "2026-04-14", "metadata": { "downloads": 2920685, "total_params": 3660000000 }, - "note": null, - "phase1_score": null, - "phase2_score": null, - "phase3_score": null, - "phase4_score": null, + "note": "Full verification completed with issues: P3=94.1% (failed: attention_output_centering)", + "phase1_score": 100.0, + "phase2_score": 100.0, + "phase3_score": 94.1, + "phase4_score": 98.5, "phase7_score": null, "phase8_score": null }, diff --git a/transformer_lens/tools/model_registry/data/verification_history.json 
b/transformer_lens/tools/model_registry/data/verification_history.json index 7fec78e9e..be174d401 100644 --- a/transformer_lens/tools/model_registry/data/verification_history.json +++ b/transformer_lens/tools/model_registry/data/verification_history.json @@ -1,5 +1,5 @@ { - "last_updated": "2026-04-14T12:15:43.792442", + "last_updated": "2026-04-14T13:03:57.367589", "records": [ { "model_id": "Macropodus/macbert4mdcspell_v1", @@ -11250,6 +11250,16 @@ "notes": "Full verification completed with issues: P3=94.1% (failed: attention_output_centering)", "invalidated": false, "invalidation_reason": null + }, + { + "model_id": "Qwen/Qwen3.5-4B", + "architecture_id": "Qwen3_5ForCausalLM", + "verified_date": "2026-04-14", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed with issues: P3=94.1% (failed: attention_output_centering)", + "invalidated": false, + "invalidation_reason": null } ] }