From 8eafe75912d3257cedc0de0add96e963588b0a51 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Mon, 13 Apr 2026 12:15:06 -0500 Subject: [PATCH] Completed InternLM2 adapter --- .../test_internlm2_adapter.py | 691 ++++++++++++++++++ .../factories/architecture_adapter_factory.py | 2 + transformer_lens/model_bridge/compat.py | 44 ++ .../supported_architectures/__init__.py | 4 + .../supported_architectures/internlm2.py | 324 ++++++++ .../supported_architectures/phi3.py | 36 +- .../model_registry/data/supported_models.json | 414 ++++++++++- .../data/verification_history.json | 122 +++- .../tools/model_registry/verify_models.py | 12 +- 9 files changed, 1609 insertions(+), 40 deletions(-) create mode 100644 tests/unit/model_bridge/supported_architectures/test_internlm2_adapter.py create mode 100644 transformer_lens/model_bridge/compat.py create mode 100644 transformer_lens/model_bridge/supported_architectures/internlm2.py diff --git a/tests/unit/model_bridge/supported_architectures/test_internlm2_adapter.py b/tests/unit/model_bridge/supported_architectures/test_internlm2_adapter.py new file mode 100644 index 000000000..b6d3d061d --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_internlm2_adapter.py @@ -0,0 +1,691 @@ +"""Unit tests for InternLM2ArchitectureAdapter. + +Tests cover (one class per phase): +- Phase A: Config attributes, weight conversion keys/types, split_wqkv numerics, + preprocess_weights behaviour +- Phase D: Factory registration +""" + +from types import SimpleNamespace +from typing import Any + +import pytest +import torch +import torch.nn as nn + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + GatedMLPBridge, + JointQKVPositionEmbeddingsAttentionBridge, + RMSNormalizationBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.supported_architectures.internlm2 import ( + InternLM2ArchitectureAdapter, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _make_cfg( + n_heads: int = 8, + n_key_value_heads: int = 2, + d_model: int = 64, + n_layers: int = 2, + d_vocab: int = 100, + n_ctx: int = 128, +) -> TransformerBridgeConfig: + """Minimal TransformerBridgeConfig for InternLM2 adapter tests.""" + return TransformerBridgeConfig( + d_model=d_model, + d_head=d_model // n_heads, + n_layers=n_layers, + n_ctx=n_ctx, + n_heads=n_heads, + d_vocab=d_vocab, + n_key_value_heads=n_key_value_heads, + default_prepend_bos=True, + architecture="InternLM2ForCausalLM", + ) + + +@pytest.fixture +def cfg() -> TransformerBridgeConfig: + return _make_cfg() + + +@pytest.fixture +def adapter(cfg: TransformerBridgeConfig) -> InternLM2ArchitectureAdapter: + return InternLM2ArchitectureAdapter(cfg) + + +def _make_attn_component( + n_heads: int, + n_kv_heads: int, + head_dim: int, + d_model: int, + has_bias: bool = False, +) -> Any: + """Synthetic attention namespace with a wqkv linear (no model download needed).""" + total_out = (n_heads + 2 * n_kv_heads) * head_dim + ns = SimpleNamespace() + ns.wqkv = nn.Linear(d_model, total_out, bias=has_bias) + return ns + + +def _fill_interleaved( + wqkv_linear: nn.Linear, + n_heads: int, + n_kv_heads: int, + head_dim: int, + d_model: int, + kv_group_vals: list[tuple[float, float, float]], +) -> None: + """Fill wqkv weight with per-kv-group constants for layout verification. + + kv_group_vals: list of (q_val, k_val, v_val) per kv-head group. + """ + n_kv_groups = n_heads // n_kv_heads + gs = n_kv_groups + 2 + w = torch.zeros(n_kv_heads, gs, head_dim, d_model) + for h, (q_val, k_val, v_val) in enumerate(kv_group_vals): + w[h, :n_kv_groups, :, :] = q_val + w[h, n_kv_groups, :, :] = k_val + w[h, n_kv_groups + 1, :, :] = v_val + wqkv_linear.weight = nn.Parameter(w.reshape((n_heads + 2 * n_kv_heads) * head_dim, d_model)) + + +# --------------------------------------------------------------------------- +# Phase A — Config attribute tests +# --------------------------------------------------------------------------- + + +class TestInternLM2AdapterConfig: + """Adapter must set all required config attributes.""" + + def test_normalization_type(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.cfg.normalization_type == "RMS" + + def test_positional_embedding_type(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.cfg.positional_embedding_type == "rotary" + + def test_final_rms(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.cfg.final_rms is True + + def test_gated_mlp(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.cfg.gated_mlp is True + + def test_attn_only(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.cfg.attn_only is False + + def test_uses_rms_norm(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.cfg.uses_rms_norm is True + + def test_eps_attr(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.cfg.eps_attr == "variance_epsilon" + + def test_n_key_value_heads_propagated(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.cfg.n_key_value_heads == 2 + + def test_supports_fold_ln_false(self, adapter: InternLM2ArchitectureAdapter) -> None: + # Must be False: fold_ln silently skips attn when wqkv is fused in bridge state dict. + assert adapter.supports_fold_ln is False + + +# --------------------------------------------------------------------------- +# Phase A — Component mapping structure tests +# --------------------------------------------------------------------------- + + +class TestInternLM2AdapterComponentMapping: + """component_mapping must have correct bridge types and InternLM2-specific names.""" + + def test_embed_is_embedding_bridge(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.component_mapping is not None + assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) + + def test_embed_name(self, adapter: InternLM2ArchitectureAdapter) -> None: + # InternLM2 uses tok_embeddings, not embed_tokens + assert adapter.component_mapping is not None + assert adapter.component_mapping["embed"].name == "model.tok_embeddings" + + def test_no_top_level_rotary_emb(self, adapter: InternLM2ArchitectureAdapter) -> None: + # Per-layer rotary injected via setup_component_testing, not top-level mapping + assert adapter.component_mapping is not None + assert "rotary_emb" not in adapter.component_mapping + + def test_blocks_is_block_bridge(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.component_mapping is not None + assert isinstance(adapter.component_mapping["blocks"], BlockBridge) + + def test_blocks_name(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.component_mapping is not None + assert adapter.component_mapping["blocks"].name == "model.layers" + + def test_ln_final_is_rms_normalization_bridge( + self, adapter: InternLM2ArchitectureAdapter + ) -> None: + assert adapter.component_mapping is not None + assert isinstance(adapter.component_mapping["ln_final"], RMSNormalizationBridge) + + def test_ln_final_name(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.component_mapping is not None + assert adapter.component_mapping["ln_final"].name == "model.norm" + + def test_unembed_is_unembedding_bridge(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.component_mapping is not None + assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge) + + def test_unembed_name(self, adapter: InternLM2ArchitectureAdapter) -> None: + # InternLM2 uses 'output', not 'lm_head' + assert adapter.component_mapping is not None + assert adapter.component_mapping["unembed"].name == "output" + + def test_ln1_is_rms_normalization_bridge(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.component_mapping is not None + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["ln1"], RMSNormalizationBridge) + + def test_ln1_name(self, adapter: InternLM2ArchitectureAdapter) -> None: + # InternLM2 uses attention_norm, not input_layernorm + assert adapter.component_mapping is not None + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["ln1"].name == "attention_norm" + + def test_ln2_is_rms_normalization_bridge(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.component_mapping is not None + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["ln2"], RMSNormalizationBridge) + + def test_ln2_name(self, adapter: InternLM2ArchitectureAdapter) -> None: + # InternLM2 uses ffn_norm, not post_attention_layernorm + assert adapter.component_mapping is not None + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["ln2"].name == "ffn_norm" + + def test_attn_is_joint_qkv_position_embeddings_attention_bridge( + self, adapter: InternLM2ArchitectureAdapter + ) -> None: + assert adapter.component_mapping is not None + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["attn"], JointQKVPositionEmbeddingsAttentionBridge) + + def test_attn_name(self, adapter: InternLM2ArchitectureAdapter) -> None: + # InternLM2 uses 'attention', not 'self_attn' + assert adapter.component_mapping is not None + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["attn"].name == "attention" + + def test_attn_qkv_submodule_name(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.component_mapping is not None + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["attn"].submodules["qkv"].name == "wqkv" + + def test_attn_o_submodule_name(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.component_mapping is not None + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["attn"].submodules["o"].name == "wo" + + def test_mlp_is_gated_mlp_bridge(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.component_mapping is not None + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["mlp"], GatedMLPBridge) + + def test_mlp_name(self, adapter: InternLM2ArchitectureAdapter) -> None: + # InternLM2 uses 'feed_forward', not 'mlp' + assert adapter.component_mapping is not None + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["mlp"].name == "feed_forward" + + def test_mlp_gate_submodule_name(self, adapter: InternLM2ArchitectureAdapter) -> None: + # w1 = gate projection + assert adapter.component_mapping is not None + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["mlp"].submodules["gate"].name == "w1" + + def test_mlp_in_submodule_name(self, adapter: InternLM2ArchitectureAdapter) -> None: + # w3 = up/in projection + assert adapter.component_mapping is not None + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["mlp"].submodules["in"].name == "w3" + + def test_mlp_out_submodule_name(self, adapter: InternLM2ArchitectureAdapter) -> None: + # w2 = down/out projection + assert adapter.component_mapping is not None + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["mlp"].submodules["out"].name == "w2" + + +# --------------------------------------------------------------------------- +# Phase A — Weight conversion key and type tests +# --------------------------------------------------------------------------- + + +class TestInternLM2AdapterWeightConversions: + """weight_processing_conversions must have correct keys, types, and rearrange patterns.""" + + def test_q_weight_key_present(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.weight_processing_conversions is not None + assert "blocks.{i}.attn.q.weight" in adapter.weight_processing_conversions + + def test_k_weight_key_present(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.weight_processing_conversions is not None + assert "blocks.{i}.attn.k.weight" in adapter.weight_processing_conversions + + def test_v_weight_key_present(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.weight_processing_conversions is not None + assert "blocks.{i}.attn.v.weight" in adapter.weight_processing_conversions + + def test_o_weight_key_present(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.weight_processing_conversions is not None + assert "blocks.{i}.attn.o.weight" in adapter.weight_processing_conversions + + def test_exactly_four_conversion_keys(self, adapter: InternLM2ArchitectureAdapter) -> None: + # No bias entries for the bias=False shipped config + assert adapter.weight_processing_conversions is not None + assert len(adapter.weight_processing_conversions) == 4 + + def test_q_conversion_is_param_processing_conversion( + self, adapter: InternLM2ArchitectureAdapter + ) -> None: + assert adapter.weight_processing_conversions is not None + conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + + def test_q_tensor_conversion_is_rearrange(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.weight_processing_conversions is not None + conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + + def test_q_rearrange_pattern(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.weight_processing_conversions is not None + conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + + def test_q_rearrange_n_equals_n_heads(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.weight_processing_conversions is not None + conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + def test_k_rearrange_n_equals_n_kv_heads(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.weight_processing_conversions is not None + conv = adapter.weight_processing_conversions["blocks.{i}.attn.k.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_key_value_heads + + def test_v_rearrange_n_equals_n_kv_heads(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.weight_processing_conversions is not None + conv = adapter.weight_processing_conversions["blocks.{i}.attn.v.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_key_value_heads + + def test_o_rearrange_pattern(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.weight_processing_conversions is not None + conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "m (n h) -> n h m" + + def test_o_rearrange_n_equals_n_heads(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert adapter.weight_processing_conversions is not None + conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + def test_no_source_key_on_q(self, adapter: InternLM2ArchitectureAdapter) -> None: + # preprocess_weights writes split keys; no cross-key lookup needed at rearrange time + assert adapter.weight_processing_conversions is not None + conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert conv.source_key is None + + +# --------------------------------------------------------------------------- +# Phase A — _split_internlm2_wqkv numerical tests +# --------------------------------------------------------------------------- + + +class TestInternLM2SplitWqkv: + """Numerical correctness of the interleaved GQA split function.""" + + def _adapter( + self, + n_heads: int = 8, + n_kv_heads: int = 2, + d_model: int = 32, + ) -> InternLM2ArchitectureAdapter: + head_dim = d_model // n_heads + return InternLM2ArchitectureAdapter( + _make_cfg(n_heads=n_heads, n_key_value_heads=n_kv_heads, d_model=d_model) + ) + + def test_returns_three_linears(self) -> None: + adapter = self._adapter() + attn = _make_attn_component(8, 2, 4, 32) + q, k, v = adapter._split_internlm2_wqkv(attn) + assert isinstance(q, nn.Linear) + assert isinstance(k, nn.Linear) + assert isinstance(v, nn.Linear) + + def test_gqa_shapes(self) -> None: + # n_heads=8, n_kv_heads=2, head_dim=4, d_model=32 + adapter = self._adapter(n_heads=8, n_kv_heads=2, d_model=32) + attn = _make_attn_component(8, 2, 4, 32) + q, k, v = adapter._split_internlm2_wqkv(attn) + assert q.weight.shape == (8 * 4, 32) + assert k.weight.shape == (2 * 4, 32) + assert v.weight.shape == (2 * 4, 32) + + def test_mha_shapes(self) -> None: + # MHA: n_heads == n_kv_heads → gs=3 (standard [Q|K|V]) + adapter = self._adapter(n_heads=4, n_kv_heads=4, d_model=32) + attn = _make_attn_component(4, 4, 8, 32) + q, k, v = adapter._split_internlm2_wqkv(attn) + assert q.weight.shape == (4 * 8, 32) + assert k.weight.shape == (4 * 8, 32) + assert v.weight.shape == (4 * 8, 32) + + def test_interleaved_layout_correctness(self) -> None: + # n_heads=4, n_kv_heads=2, head_dim=4, d_model=16 → gs=4 (2 q-groups + k + v) + n_heads, n_kv_heads, head_dim, d_model = 4, 2, 4, 16 + adapter = self._adapter(n_heads=n_heads, n_kv_heads=n_kv_heads, d_model=d_model) + attn = _make_attn_component(n_heads, n_kv_heads, head_dim, d_model) + # kv-group 0: Q=1.0, K=2.0, V=3.0; kv-group 1: Q=4.0, K=5.0, V=6.0 + _fill_interleaved( + attn.wqkv, + n_heads, + n_kv_heads, + head_dim, + d_model, + [(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)], + ) + q, k, v = adapter._split_internlm2_wqkv(attn) + + n_kv_groups = n_heads // n_kv_heads # 2 + # Q: rows 0..n_kv_groups*head_dim-1 come from kv-group 0 Q slots (1.0), + # rows n_kv_groups*head_dim..n_heads*head_dim-1 from kv-group 1 Q slots (4.0) + assert torch.all(q.weight[: n_kv_groups * head_dim] == 1.0), "Q group-0 rows should be 1.0" + assert torch.all(q.weight[n_kv_groups * head_dim :] == 4.0), "Q group-1 rows should be 4.0" + # K: row 0..head_dim-1 = kv-group 0 K (2.0), head_dim..2*head_dim-1 = kv-group 1 K (5.0) + assert torch.all(k.weight[:head_dim] == 2.0), "K group-0 rows should be 2.0" + assert torch.all(k.weight[head_dim:] == 5.0), "K group-1 rows should be 5.0" + # V analogous + assert torch.all(v.weight[:head_dim] == 3.0), "V group-0 rows should be 3.0" + assert torch.all(v.weight[head_dim:] == 6.0), "V group-1 rows should be 6.0" + + def test_no_bias(self) -> None: + adapter = self._adapter() + attn = _make_attn_component(8, 2, 4, 32, has_bias=False) + q, k, v = adapter._split_internlm2_wqkv(attn) + assert q.bias is None + assert k.bias is None + assert v.bias is None + + def test_with_bias_shapes(self) -> None: + n_heads, n_kv_heads, head_dim, d_model = 8, 2, 4, 32 + adapter = self._adapter(n_heads=n_heads, n_kv_heads=n_kv_heads, d_model=d_model) + attn = _make_attn_component(n_heads, n_kv_heads, head_dim, d_model, has_bias=True) + q, k, v = adapter._split_internlm2_wqkv(attn) + assert q.bias is not None + assert k.bias is not None + assert v.bias is not None + assert q.bias.shape == (n_heads * head_dim,) + assert k.bias.shape == (n_kv_heads * head_dim,) + assert v.bias.shape == (n_kv_heads * head_dim,) + + def test_with_bias_interleaved_values(self) -> None: + # Verify bias values follow the same interleaved layout as weights + n_heads, n_kv_heads, head_dim, d_model = 4, 2, 4, 16 + adapter = self._adapter(n_heads=n_heads, n_kv_heads=n_kv_heads, d_model=d_model) + attn = _make_attn_component(n_heads, n_kv_heads, head_dim, d_model, has_bias=True) + n_kv_groups = n_heads // n_kv_heads + gs = n_kv_groups + 2 + # Bias: interleaved [q0_vals, q1_vals, k_val, v_val] per kv-head group + b = torch.zeros((n_heads + 2 * n_kv_heads) * head_dim) + b_grouped = b.reshape(n_kv_heads, gs, head_dim) + b_grouped[0, :n_kv_groups, :] = 1.0 # kv-group 0 Q bias + b_grouped[0, n_kv_groups, :] = 2.0 # kv-group 0 K bias + b_grouped[0, n_kv_groups + 1, :] = 3.0 # kv-group 0 V bias + b_grouped[1, :n_kv_groups, :] = 4.0 # kv-group 1 Q bias + b_grouped[1, n_kv_groups, :] = 5.0 + b_grouped[1, n_kv_groups + 1, :] = 6.0 + attn.wqkv.bias = nn.Parameter(b_grouped.reshape(-1)) + + q, k, v = adapter._split_internlm2_wqkv(attn) + assert torch.all(q.bias[: n_kv_groups * head_dim] == 1.0) + assert torch.all(q.bias[n_kv_groups * head_dim :] == 4.0) + assert torch.all(k.bias[:head_dim] == 2.0) + assert torch.all(k.bias[head_dim:] == 5.0) + assert torch.all(v.bias[:head_dim] == 3.0) + assert torch.all(v.bias[head_dim:] == 6.0) + + def test_forward_output_shapes(self) -> None: + n_heads, n_kv_heads, head_dim, d_model = 8, 2, 4, 32 + adapter = self._adapter(n_heads=n_heads, n_kv_heads=n_kv_heads, d_model=d_model) + attn = _make_attn_component(n_heads, n_kv_heads, head_dim, d_model) + q, k, v = adapter._split_internlm2_wqkv(attn) + x = torch.randn(2, 5, d_model) + assert q(x).shape == (2, 5, n_heads * head_dim) + assert k(x).shape == (2, 5, n_kv_heads * head_dim) + assert v(x).shape == (2, 5, n_kv_heads * head_dim) + + +# --------------------------------------------------------------------------- +# Phase A — preprocess_weights tests +# --------------------------------------------------------------------------- + + +class TestInternLM2PreprocessWeights: + """preprocess_weights must split fused wqkv and fold layer norms.""" + + def _make_state_dict_with_fused_qkv( + self, + adapter: InternLM2ArchitectureAdapter, + n_kv_heads: int, + head_dim: int, + d_model: int, + n_layers: int, + ln1_scale: float = 1.0, + qkv_val: float = 1.0, + ) -> dict[str, torch.Tensor]: + """Build a bridge-format state dict with fused qkv.weight for each layer.""" + n_heads = adapter.cfg.n_heads + n_kv_groups = n_heads // n_kv_heads + gs = n_kv_groups + 2 + state: dict[str, torch.Tensor] = {} + for i in range(n_layers): + total_rows = (n_heads + 2 * n_kv_heads) * head_dim + state[f"blocks.{i}.attn.qkv.weight"] = torch.full((total_rows, d_model), qkv_val) + state[f"blocks.{i}.ln1.weight"] = torch.full((d_model,), ln1_scale) + state[f"blocks.{i}.ln2.weight"] = torch.ones(d_model) + state[f"blocks.{i}.mlp.gate.weight"] = torch.ones(16, d_model) + state[f"blocks.{i}.mlp.in.weight"] = torch.ones(16, d_model) + state["ln_final.weight"] = torch.ones(d_model) + state["unembed.weight"] = torch.ones(100, d_model) + return state + + def test_fused_key_removed_and_split_keys_written(self) -> None: + adapter = InternLM2ArchitectureAdapter(_make_cfg()) + adapter._fold_ln_requested = True + n_kv_heads, head_dim, d_model = 2, 8, 64 + sd = self._make_state_dict_with_fused_qkv(adapter, n_kv_heads, head_dim, d_model, 2) + + result = adapter.preprocess_weights(sd) + + assert "blocks.0.attn.qkv.weight" not in result, "fused qkv key must be deleted" + assert "blocks.0.attn.q.weight" in result + assert "blocks.0.attn.k.weight" in result + assert "blocks.0.attn.v.weight" in result + + def test_split_q_shape(self) -> None: + adapter = InternLM2ArchitectureAdapter( + _make_cfg(n_heads=8, n_key_value_heads=2, d_model=64) + ) + adapter._fold_ln_requested = True + n_kv_heads, head_dim, d_model = 2, 8, 64 + sd = self._make_state_dict_with_fused_qkv(adapter, n_kv_heads, head_dim, d_model, 2) + result = adapter.preprocess_weights(sd) + assert result["blocks.0.attn.q.weight"].shape == (8 * 8, 64) + assert result["blocks.0.attn.k.weight"].shape == (2 * 8, 64) + assert result["blocks.0.attn.v.weight"].shape == (2 * 8, 64) + + def test_ln1_fold_applied_to_q(self) -> None: + """After folding ln1 scale=2.0 into qkv (all 1.0), q/k/v weights should be 2.0.""" + adapter = InternLM2ArchitectureAdapter( + _make_cfg(n_heads=8, n_key_value_heads=2, d_model=64) + ) + adapter._fold_ln_requested = True + n_kv_heads, head_dim, d_model = 2, 8, 64 + sd = self._make_state_dict_with_fused_qkv( + adapter, n_kv_heads, head_dim, d_model, 2, ln1_scale=2.0, qkv_val=1.0 + ) + result = adapter.preprocess_weights(sd) + assert torch.all(result["blocks.0.attn.q.weight"] == 2.0) + assert torch.all(result["blocks.0.attn.k.weight"] == 2.0) + assert torch.all(result["blocks.0.attn.v.weight"] == 2.0) + + def test_ln1_reset_to_ones(self) -> None: + adapter = InternLM2ArchitectureAdapter(_make_cfg()) + adapter._fold_ln_requested = True + n_kv_heads, head_dim, d_model = 2, 8, 64 + sd = self._make_state_dict_with_fused_qkv( + adapter, n_kv_heads, head_dim, d_model, 2, ln1_scale=3.0 + ) + result = adapter.preprocess_weights(sd) + assert torch.all(result["blocks.0.ln1.weight"] == 1.0) + + def test_ln2_fold_applied_to_mlp_gate(self) -> None: + adapter = InternLM2ArchitectureAdapter(_make_cfg()) + adapter._fold_ln_requested = True + n_kv_heads, head_dim, d_model = 2, 8, 64 + sd = self._make_state_dict_with_fused_qkv(adapter, n_kv_heads, head_dim, d_model, 2) + # Override ln2 with scale=3.0 + sd["blocks.0.ln2.weight"] = torch.full((d_model,), 3.0) + result = adapter.preprocess_weights(sd) + assert torch.all(result["blocks.0.mlp.gate.weight"] == 3.0) + assert torch.all(result["blocks.0.mlp.in.weight"] == 3.0) + + def test_ln2_reset_to_ones(self) -> None: + adapter = InternLM2ArchitectureAdapter(_make_cfg()) + adapter._fold_ln_requested = True + n_kv_heads, head_dim, d_model = 2, 8, 64 + sd = self._make_state_dict_with_fused_qkv(adapter, n_kv_heads, head_dim, d_model, 2) + sd["blocks.0.ln2.weight"] = torch.full((d_model,), 5.0) + result = adapter.preprocess_weights(sd) + assert torch.all(result["blocks.0.ln2.weight"] == 1.0) + + def test_ln_final_fold_applied_to_unembed(self) -> None: + adapter = InternLM2ArchitectureAdapter(_make_cfg()) + adapter._fold_ln_requested = True + n_kv_heads, head_dim, d_model = 2, 8, 64 + sd = self._make_state_dict_with_fused_qkv(adapter, n_kv_heads, head_dim, d_model, 2) + sd["ln_final.weight"] = torch.full((d_model,), 2.0) + sd["unembed.weight"] = torch.ones(100, d_model) + result = adapter.preprocess_weights(sd) + assert torch.all(result["unembed.weight"] == 2.0) + assert torch.all(result["ln_final.weight"] == 1.0) + + def test_no_fold_when_not_requested(self) -> None: + adapter = InternLM2ArchitectureAdapter(_make_cfg()) + adapter._fold_ln_requested = False + n_kv_heads, head_dim, d_model = 2, 8, 64 + sd = self._make_state_dict_with_fused_qkv( + adapter, n_kv_heads, head_dim, d_model, 2, ln1_scale=5.0 + ) + result = adapter.preprocess_weights(sd) + # Fused key must still be present; no splitting or scaling + assert "blocks.0.attn.qkv.weight" in result + assert "blocks.0.attn.q.weight" not in result + + def test_dtype_preserved(self) -> None: + adapter = InternLM2ArchitectureAdapter(_make_cfg()) + adapter._fold_ln_requested = True + n_kv_heads, head_dim, d_model = 2, 8, 64 + sd = self._make_state_dict_with_fused_qkv(adapter, n_kv_heads, head_dim, d_model, 1) + # Cast to bfloat16 + sd = {k: v.to(torch.bfloat16) for k, v in sd.items()} + result = adapter.preprocess_weights(sd) + assert result["blocks.0.attn.q.weight"].dtype == torch.bfloat16 + + def test_bias_split_when_present(self) -> None: + """config.bias=True: fused bias must be split into q/k/v bias keys.""" + # Use consistent d_model/n_heads so head_dim = d_model // n_heads = 64 // 4 = 16 + n_heads, n_kv_heads, d_model = 4, 2, 64 + head_dim = d_model // n_heads # 16 + adapter = InternLM2ArchitectureAdapter( + _make_cfg(n_heads=n_heads, n_key_value_heads=n_kv_heads, d_model=d_model) + ) + adapter._fold_ln_requested = True + total_rows = (n_heads + 2 * n_kv_heads) * head_dim + sd: dict[str, torch.Tensor] = { + "blocks.0.attn.qkv.weight": torch.ones(total_rows, d_model), + "blocks.0.attn.qkv.bias": torch.zeros(total_rows), + "blocks.0.ln1.weight": torch.ones(d_model), + "blocks.0.ln2.weight": torch.ones(d_model), + "blocks.0.mlp.gate.weight": torch.ones(16, d_model), + "blocks.0.mlp.in.weight": torch.ones(16, d_model), + "ln_final.weight": torch.ones(d_model), + "unembed.weight": torch.ones(100, d_model), + } + result = adapter.preprocess_weights(sd) + assert "blocks.0.attn.qkv.bias" not in result + assert "blocks.0.attn.q.bias" in result + assert "blocks.0.attn.k.bias" in result + assert "blocks.0.attn.v.bias" in result + assert result["blocks.0.attn.q.bias"].shape == (n_heads * head_dim,) + assert result["blocks.0.attn.k.bias"].shape == (n_kv_heads * head_dim,) + assert result["blocks.0.attn.v.bias"].shape == (n_kv_heads * head_dim,) + + def test_all_layers_processed(self) -> None: + """Verify that all n_layers are processed, not just layer 0.""" + adapter = InternLM2ArchitectureAdapter(_make_cfg(n_layers=3)) + adapter._fold_ln_requested = True + n_kv_heads, head_dim, d_model = 2, 8, 64 + sd = self._make_state_dict_with_fused_qkv(adapter, n_kv_heads, head_dim, d_model, 3) + result = adapter.preprocess_weights(sd) + for i in range(3): + assert f"blocks.{i}.attn.qkv.weight" not in result + assert f"blocks.{i}.attn.q.weight" in result + + +# --------------------------------------------------------------------------- +# Phase D — Factory registration (will pass after Phase D implemented) +# --------------------------------------------------------------------------- + + +class TestInternLM2FactoryRegistration: + """Factory must map InternLM2ForCausalLM to InternLM2ArchitectureAdapter.""" + + def test_factory_returns_internlm2_adapter(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + ArchitectureAdapterFactory, + ) + + cfg = _make_cfg() + adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) + assert isinstance( + adapter, InternLM2ArchitectureAdapter + ), f"Expected InternLM2ArchitectureAdapter, got {type(adapter).__name__}" + + def test_factory_key_in_supported_architectures(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ) + + assert "InternLM2ForCausalLM" in SUPPORTED_ARCHITECTURES diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index 1c6462cad..46a844eb6 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -23,6 +23,7 @@ GraniteMoeArchitectureAdapter, GraniteMoeHybridArchitectureAdapter, HubertArchitectureAdapter, + InternLM2ArchitectureAdapter, LlamaArchitectureAdapter, LlavaArchitectureAdapter, LlavaNextArchitectureAdapter, @@ -70,6 +71,7 @@ "GPTJForCausalLM": GptjArchitectureAdapter, "HubertForCTC": HubertArchitectureAdapter, "HubertModel": HubertArchitectureAdapter, + "InternLM2ForCausalLM": InternLM2ArchitectureAdapter, "LlamaForCausalLM": LlamaArchitectureAdapter, "LlavaForConditionalGeneration": LlavaArchitectureAdapter, "LlavaNextForConditionalGeneration": LlavaNextArchitectureAdapter, diff --git a/transformer_lens/model_bridge/compat.py b/transformer_lens/model_bridge/compat.py new file mode 100644 index 000000000..673784b6b --- /dev/null +++ b/transformer_lens/model_bridge/compat.py @@ -0,0 +1,44 @@ +"""Compatibility shims for transformers version differences. + +These patches are applied lazily (only when missing) so they're safe to call +from multiple adapters — the first caller wins, subsequent calls are no-ops. +""" + + +def patch_dynamic_cache_v5() -> None: + """Backfill DynamicCache methods removed in transformers v5. + + Remote-code models written for transformers v4 call from_legacy_cache, + to_legacy_cache, and get_usable_length which were removed in v5. + Call this from any adapter's prepare_loading() that needs them. + """ + try: + from transformers.cache_utils import DynamicCache + except Exception: + return + + if not hasattr(DynamicCache, "from_legacy_cache"): + + @classmethod # type: ignore[misc] + def _from_legacy_cache(cls, past_key_values=None): # type: ignore[no-untyped-def] + cache = cls() + if past_key_values is not None: + for idx, layer_past in enumerate(past_key_values): + cache.update(layer_past[0], layer_past[1], idx) + return cache + + DynamicCache.from_legacy_cache = _from_legacy_cache # type: ignore[attr-defined] + + if not hasattr(DynamicCache, "get_usable_length"): + + def _get_usable_length(self, new_seq_len: int = 0, layer_idx: int = 0) -> int: # type: ignore[no-untyped-def] + return self.get_seq_length(layer_idx) + + DynamicCache.get_usable_length = _get_usable_length # type: ignore[attr-defined] + + if not hasattr(DynamicCache, "to_legacy_cache"): + + def _to_legacy_cache(self): # type: ignore[no-untyped-def] + return tuple((layer.keys, layer.values) for layer in self.layers) + + DynamicCache.to_legacy_cache = _to_legacy_cache # type: ignore[attr-defined] diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py index 1b24f3741..37cf4e38c 100644 --- a/transformer_lens/model_bridge/supported_architectures/__init__.py +++ b/transformer_lens/model_bridge/supported_architectures/__init__.py @@ -54,6 +54,9 @@ from transformer_lens.model_bridge.supported_architectures.hubert import ( HubertArchitectureAdapter, ) +from transformer_lens.model_bridge.supported_architectures.internlm2 import ( + InternLM2ArchitectureAdapter, +) from transformer_lens.model_bridge.supported_architectures.llama import ( LlamaArchitectureAdapter, ) @@ -148,6 +151,7 @@ "Gpt2LmHeadCustomArchitectureAdapter", "GptjArchitectureAdapter", "HubertArchitectureAdapter", + "InternLM2ArchitectureAdapter", "LlamaArchitectureAdapter", "LlavaArchitectureAdapter", "LlavaNextArchitectureAdapter", diff --git a/transformer_lens/model_bridge/supported_architectures/internlm2.py b/transformer_lens/model_bridge/supported_architectures/internlm2.py new file mode 100644 index 000000000..a5405e807 --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/internlm2.py @@ -0,0 +1,324 @@ +"""InternLM2 architecture adapter.""" + +import sys +from typing import Any + +import torch +import torch.nn as nn + +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.compat import patch_dynamic_cache_v5 +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + GatedMLPBridge, + JointQKVPositionEmbeddingsAttentionBridge, + LinearBridge, + RMSNormalizationBridge, + UnembeddingBridge, +) + + +class _InternLM2AttentionBridge(JointQKVPositionEmbeddingsAttentionBridge): + """Attention bridge returning 3-tuple for InternLM2's decoder layer contract. + + InternLM2's decoder layer unpacks (hidden_states, attn_weights, present_key_value) + from self.attention(), but the base bridge returns only (output, weights). + """ + + def _reconstruct_attention( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs + ) -> tuple: + attn_output, attn_weights = super()._reconstruct_attention(q, k, v, **kwargs) + past_key_value = kwargs.get("past_key_values", kwargs.get("past_key_value", None)) + return (attn_output, attn_weights, past_key_value) + + +def _patch_init_weights_for_internlm2() -> None: + """Prevent _init_weights from re-randomizing loaded checkpoint weights. + + Transformers v5 calls _init_weights on all modules after weight + materialization. For modules with real (non-meta) tensors, we must + skip re-initialization to preserve the loaded checkpoint values. + Same approach as openelm.py. + """ + for key in list(sys.modules.keys()): + if "internlm2" not in key.lower() or "modeling" not in key.lower(): + continue + module = sys.modules[key] + pretrained_cls = getattr(module, "InternLM2PreTrainedModel", None) + if pretrained_cls is None or getattr(pretrained_cls, "_tl_patched", False): + continue + + original_init_weights = pretrained_cls._init_weights + + def safe_init_weights(self, mod, _original=original_init_weights): # type: ignore[no-untyped-def] + first_param = next(mod.parameters(), None) + if first_param is not None and first_param.device.type != "meta": + return + _original(self, mod) + + pretrained_cls._init_weights = safe_init_weights + pretrained_cls._tl_patched = True + + +class InternLM2ArchitectureAdapter(ArchitectureAdapter): + """Architecture adapter for InternLM2 models. + + InternLM2 uses remote code (trust_remote_code=True) and differs from Llama in: + - Fused interleaved GQA wqkv weight (not standard [Q|K|V] split) + - Non-standard module names: tok_embeddings, output, attention, feed_forward, + wqkv/wo, w1(gate)/w3(up)/w2(down), attention_norm, ffn_norm + - Per-layer rotary_emb (no model-level shared instance) + - supports_fold_ln=False: fold_ln is done manually in preprocess_weights because + the bridge state dict has the fused qkv key, not split q/k/v keys, so + fold_layer_norm's extract_attention_tensors_for_folding would silently skip attn. + + Optional parameters (may not exist in state_dict): + - blocks.{i}.attn.b_Q / b_K / b_V / b_O — config.bias=False on shipped models + - blocks.{i}.mlp.b_gate / b_in / b_out — MLP always bias=False + - blocks.{i}.ln1.b / ln2.b / ln_final.b — RMSNorm has no bias + """ + + def __init__(self, cfg: Any) -> None: + super().__init__(cfg) + + self.cfg.normalization_type = "RMS" + self.cfg.positional_embedding_type = "rotary" + self.cfg.final_rms = True + self.cfg.gated_mlp = True + self.cfg.attn_only = False + self.cfg.uses_rms_norm = True + self.cfg.eps_attr = "variance_epsilon" + + # Standard fold_ln silently skips attention when wqkv is fused (see class docstring). + # preprocess_weights() handles it instead — same approach as phi3.py. + self.supports_fold_ln = False + + if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None: + self.cfg.n_key_value_heads = cfg.n_key_value_heads + + n_kv_heads = getattr(cfg, "n_key_value_heads", None) or cfg.n_heads + + self.weight_processing_conversions = { + "blocks.{i}.attn.q.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads), + ), + "blocks.{i}.attn.k.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads), + ), + "blocks.{i}.attn.v.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads), + ), + "blocks.{i}.attn.o.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=cfg.n_heads), + ), + } + + self.component_mapping = { + "embed": EmbeddingBridge(name="model.tok_embeddings"), + "blocks": BlockBridge( + name="model.layers", + submodules={ + "ln1": RMSNormalizationBridge(name="attention_norm", config=self.cfg), + "ln2": RMSNormalizationBridge(name="ffn_norm", config=self.cfg), + "attn": _InternLM2AttentionBridge( + name="attention", + config=self.cfg, + split_qkv_matrix=self._split_internlm2_wqkv, + submodules={ + "qkv": LinearBridge(name="wqkv"), + "o": LinearBridge(name="wo"), + }, + ), + "mlp": GatedMLPBridge( + name="feed_forward", + config=self.cfg, + submodules={ + "gate": LinearBridge(name="w1"), + "in": LinearBridge(name="w3"), + "out": LinearBridge(name="w2"), + }, + ), + }, + ), + "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg), + "unembed": UnembeddingBridge(name="output", config=self.cfg), + } + + def _split_internlm2_wqkv( + self, attention_component: Any + ) -> tuple[nn.Linear, nn.Linear, nn.Linear]: + """Split InternLM2's interleaved wqkv into separate Q, K, V linear modules. + + InternLM2 uses an interleaved GQA layout rather than the standard [Q_all|K_all|V_all]. + For each of n_kv_heads groups, the weight rows are: + [q0, q1, ..., q(n_kv_groups-1), k, v] (each slot = head_dim rows) + i.e. gs = n_kv_groups + 2 slots per kv-head group. + """ + wqkv = attention_component.wqkv + w = wqkv.weight.data + d_model = w.shape[1] + has_bias = wqkv.bias is not None + + n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads + n_kv_groups = self.cfg.n_heads // n_kv_heads + head_dim = self.cfg.d_model // self.cfg.n_heads + gs = n_kv_groups + 2 + + w_grouped = w.reshape(n_kv_heads, gs, head_dim, d_model) + q_w = w_grouped[:, :n_kv_groups, :, :].reshape(self.cfg.n_heads * head_dim, d_model) + k_w = w_grouped[:, n_kv_groups, :, :].reshape(n_kv_heads * head_dim, d_model) + v_w = w_grouped[:, n_kv_groups + 1, :, :].reshape(n_kv_heads * head_dim, d_model) + + q_b: torch.Tensor | None = None + k_b: torch.Tensor | None = None + v_b: torch.Tensor | None = None + if has_bias: + b = wqkv.bias.data + b_grouped = b.reshape(n_kv_heads, gs, head_dim) + q_b = b_grouped[:, :n_kv_groups, :].reshape(self.cfg.n_heads * head_dim) + k_b = b_grouped[:, n_kv_groups, :].reshape(n_kv_heads * head_dim) + v_b = b_grouped[:, n_kv_groups + 1, :].reshape(n_kv_heads * head_dim) + + def _make_linear(weight: torch.Tensor, bias: torch.Tensor | None) -> nn.Linear: + lin = nn.Linear(d_model, weight.shape[0], bias=bias is not None) + lin.weight = nn.Parameter(weight) + if bias is not None: + lin.bias = nn.Parameter(bias) + return lin + + return _make_linear(q_w, q_b), _make_linear(k_w, k_b), _make_linear(v_w, v_b) + + def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None: + """Inject per-layer rotary embedding for component testing.""" + try: + rotary_emb = hf_model.model.layers[0].attention.rotary_emb + except (AttributeError, IndexError): + return + + if bridge_model is not None and hasattr(bridge_model, "blocks"): + for block in bridge_model.blocks: + if hasattr(block, "attn"): + block.attn.set_rotary_emb(rotary_emb) + + attn_bridge = self.get_generalized_component("blocks.0.attn") + attn_bridge.set_rotary_emb(rotary_emb) + + def prepare_loading(self, model_name: str, model_kwargs: dict) -> None: + """Patch transformers v5 incompatibilities before from_pretrained runs.""" + config = model_kwargs.get("config") + if config is not None: + tp = getattr(config, "pretraining_tp", 1) + if tp > 1: + raise ValueError( + f"InternLM2 adapter does not support pretraining_tp={tp}; " + "only pretraining_tp=1 is supported for logit correctness." + ) + + patch_dynamic_cache_v5() + + # Force-import the remote modeling module so we can patch _init_weights. + try: + from transformers.dynamic_module_utils import get_class_from_dynamic_module + + get_class_from_dynamic_module( + "modeling_internlm2.InternLM2ForCausalLM", + model_name, + ) + except Exception: + pass + + _patch_init_weights_for_internlm2() + + def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Fold layer norms into QKV and MLP weights. + + Standard fold_ln can't reach split Q/K/V when wqkv is fused in the bridge state dict. + We extract and fold here, then write split keys so RearrangeTensorConversion can follow. + MLP projections (w1/w2/w3) are separate linears so they fold normally. + Mirrors phi3.py.preprocess_weights, adapted for InternLM2's layout. + """ + fold_ln = getattr(self, "_fold_ln_requested", True) + if not fold_ln: + return state_dict + + n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads + n_kv_groups = self.cfg.n_heads // n_kv_heads + head_dim = self.cfg.d_model // self.cfg.n_heads + gs = n_kv_groups + 2 + + for i in range(self.cfg.n_layers): + # --- Fold ln1 into Q/K/V (extracted from interleaved wqkv) --- + qkv_key = f"blocks.{i}.attn.qkv.weight" + ln1_key = f"blocks.{i}.ln1.weight" + if qkv_key in state_dict and ln1_key in state_dict: + ln1_w = state_dict[ln1_key].float() + qkv_w = state_dict[qkv_key].float() + d_model = qkv_w.shape[1] + orig_dtype = state_dict[qkv_key].dtype + + w_grouped = qkv_w.reshape(n_kv_heads, gs, head_dim, d_model) + q_w = w_grouped[:, :n_kv_groups, :, :].reshape(self.cfg.n_heads * head_dim, d_model) + k_w = w_grouped[:, n_kv_groups, :, :].reshape(n_kv_heads * head_dim, d_model) + v_w = w_grouped[:, n_kv_groups + 1, :, :].reshape(n_kv_heads * head_dim, d_model) + + state_dict[f"blocks.{i}.attn.q.weight"] = (q_w * ln1_w[None, :]).to(orig_dtype) + state_dict[f"blocks.{i}.attn.k.weight"] = (k_w * ln1_w[None, :]).to(orig_dtype) + state_dict[f"blocks.{i}.attn.v.weight"] = (v_w * ln1_w[None, :]).to(orig_dtype) + del state_dict[qkv_key] + state_dict[ln1_key] = torch.ones_like(state_dict[ln1_key]) + + qkv_bias_key = f"blocks.{i}.attn.qkv.bias" + if qkv_bias_key in state_dict: + b = state_dict[qkv_bias_key] + expected_len = (self.cfg.n_heads + 2 * n_kv_heads) * head_dim + if b.shape[0] != expected_len: + raise ValueError( + f"Unexpected wqkv bias shape at layer {i}: {b.shape[0]} " + f"(expected {expected_len}). Cannot split interleaved bias." + ) + orig_dtype = b.dtype + b_f = b.float() + b_grouped = b_f.reshape(n_kv_heads, gs, head_dim) + q_b = b_grouped[:, :n_kv_groups, :].reshape(self.cfg.n_heads * head_dim) + k_b = b_grouped[:, n_kv_groups, :].reshape(n_kv_heads * head_dim) + v_b = b_grouped[:, n_kv_groups + 1, :].reshape(n_kv_heads * head_dim) + state_dict[f"blocks.{i}.attn.q.bias"] = q_b.to(orig_dtype) + state_dict[f"blocks.{i}.attn.k.bias"] = k_b.to(orig_dtype) + state_dict[f"blocks.{i}.attn.v.bias"] = v_b.to(orig_dtype) + del state_dict[qkv_bias_key] + + # --- Fold ln2 into MLP gate (w1) and up (w3) projections --- + ln2_key = f"blocks.{i}.ln2.weight" + if ln2_key in state_dict: + ln2_w = state_dict[ln2_key].float() + for mlp_key in [ + f"blocks.{i}.mlp.gate.weight", + f"blocks.{i}.mlp.in.weight", + ]: + if mlp_key in state_dict: + orig_dtype = state_dict[mlp_key].dtype + state_dict[mlp_key] = (state_dict[mlp_key].float() * ln2_w[None, :]).to( + orig_dtype + ) + state_dict[ln2_key] = torch.ones_like(state_dict[ln2_key]) + + # --- Fold ln_final into unembed --- + ln_final_key = "ln_final.weight" + unembed_key = "unembed.weight" + if ln_final_key in state_dict and unembed_key in state_dict: + ln_w = state_dict[ln_final_key].float() + u_w = state_dict[unembed_key].float() + orig_dtype = state_dict[unembed_key].dtype + if u_w.shape[-1] == ln_w.shape[0]: + state_dict[unembed_key] = (u_w * ln_w[None, :]).to(orig_dtype) + elif u_w.shape[0] == ln_w.shape[0]: + state_dict[unembed_key] = (u_w * ln_w[:, None]).to(orig_dtype) + state_dict[ln_final_key] = torch.ones_like(state_dict[ln_final_key]) + + return state_dict diff --git a/transformer_lens/model_bridge/supported_architectures/phi3.py b/transformer_lens/model_bridge/supported_architectures/phi3.py index f365a2bfb..3dc712d2b 100644 --- a/transformer_lens/model_bridge/supported_architectures/phi3.py +++ b/transformer_lens/model_bridge/supported_architectures/phi3.py @@ -15,6 +15,7 @@ ParamProcessingConversion, ) from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.compat import patch_dynamic_cache_v5 from transformer_lens.model_bridge.generalized_components import ( BlockBridge, EmbeddingBridge, @@ -238,40 +239,7 @@ def prepare_loading(self, model_name: str, model_kwargs: dict) -> None: if isinstance(rope_scaling, dict) and rope_scaling.get("rope_type") == "default": config.rope_scaling = None - # Monkey-patch DynamicCache methods removed in transformers v5. - try: - from transformers.cache_utils import DynamicCache - - if not hasattr(DynamicCache, "from_legacy_cache"): - - @classmethod # type: ignore[misc] - def _from_legacy_cache(cls, past_key_values=None): - cache = cls() - if past_key_values is not None: - for layer_idx, layer_past in enumerate(past_key_values): - cache.update(layer_past[0], layer_past[1], layer_idx) - return cache - - DynamicCache.from_legacy_cache = _from_legacy_cache # type: ignore[attr-defined] - - if not hasattr(DynamicCache, "get_usable_length"): - - def _get_usable_length(self, new_seq_len: int = 0, layer_idx: int = 0) -> int: - return self.get_seq_length(layer_idx) - - DynamicCache.get_usable_length = _get_usable_length # type: ignore[attr-defined] - - if not hasattr(DynamicCache, "to_legacy_cache"): - - def _to_legacy_cache(self): - legacy_cache = [] - for layer in self.layers: - legacy_cache.append((layer.keys, layer.values)) - return tuple(legacy_cache) - - DynamicCache.to_legacy_cache = _to_legacy_cache # type: ignore[attr-defined] - except Exception: - pass + patch_dynamic_cache_v5() def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """Fold layer norms into joint QKV/gate_up projections. diff --git a/transformer_lens/tools/model_registry/data/supported_models.json b/transformer_lens/tools/model_registry/data/supported_models.json index 6c2ce3aff..5af4f76eb 100644 --- a/transformer_lens/tools/model_registry/data/supported_models.json +++ b/transformer_lens/tools/model_registry/data/supported_models.json @@ -6,9 +6,9 @@ "min_downloads": 500, "scan_duration_seconds": 12.1 }, - "total_architectures": 36, - "total_models": 6686, - "total_verified": 690, + "total_architectures": 38, + "total_models": 6715, + "total_verified": 697, "models": [ { "architecture_id": "Qwen3ForCausalLM", @@ -93600,6 +93600,414 @@ "phase4_score": 67.5, "phase7_score": null, "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "internlm/internlm2-chat-7b", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 46960, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "internlm/internlm2_5-7b-chat", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 38575, + "total_params": 7737708544 + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "internlm/internlm2-7b", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 24750, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "internlm/internlm2-20b", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 21171, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "internlm/internlm2-base-7b", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 19981, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "internlm/internlm2-chat-20b", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 19627, + "total_params": 19861149696 + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "internlm/internlm2-base-20b", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 17316, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "chujiezheng/internlm2-chat-20b-ExPO", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 15121, + "total_params": 19861149696 + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "chujiezheng/internlm2-chat-7b-ExPO", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 15102, + "total_params": 7737708544 + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "internlm/internlm2-1_8b", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 5738, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "internlm/internlm2-chat-1_8b", + "status": 1, + "verified_date": "2026-04-13", + "metadata": { + "downloads": 5069, + "total_params": 1889110016 + }, + "note": "Full verification completed", + "phase1_score": 100.0, + "phase2_score": 100.0, + "phase3_score": 100.0, + "phase4_score": 89.9, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "optimum-internal-testing/tiny-random-internlm2", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 4311, + "total_params": 24052864 + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "internlm/internlm2-chat-7b-sft", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 3852, + "total_params": 7737708544 + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "internlm/internlm2_5-7b", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 3290, + "total_params": 7737708544 + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "internlm/internlm2_5-1_8b", + "status": 1, + "verified_date": "2026-04-13", + "metadata": { + "downloads": 3020, + "total_params": null + }, + "note": "Full verification completed", + "phase1_score": 100.0, + "phase2_score": 100.0, + "phase3_score": 100.0, + "phase4_score": 75.5, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "Mar2Ding/songcomposer_sft", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1947, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "internlm/internlm2_5-step-prover", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1783, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "internlm/internlm2_5-1_8b-chat", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1491, + "total_params": 1889110016 + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "AI4Chem/ChemLLM-7B-Chat-1_5-DPO", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1350, + "total_params": 7737708544 + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "internlm/internlm2-chat-20b-4bits", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1301, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "internlm/internlm2_5-7b-chat-1m", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1148, + "total_params": 7737708544 + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "internlm/internlm2_5-20b-chat", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 791, + "total_params": 19861149696 + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "AI4Chem/CHEMLLM-2b-1_5", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 789, + "total_params": 1889110016 + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "InternLM2ForCausalLM", + "model_id": "Lin-Chen/ShareCaptioner-Video", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 754, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null } ] } diff --git a/transformer_lens/tools/model_registry/data/verification_history.json b/transformer_lens/tools/model_registry/data/verification_history.json index 9eb2e7648..0deb099f1 100644 --- a/transformer_lens/tools/model_registry/data/verification_history.json +++ b/transformer_lens/tools/model_registry/data/verification_history.json @@ -1,5 +1,5 @@ { - "last_updated": "2026-04-09T16:34:36.818082", + "last_updated": "2026-04-13T12:12:14.839053", "records": [ { "model_id": "Macropodus/macbert4mdcspell_v1", @@ -11200,6 +11200,126 @@ "notes": "Full verification completed", "invalidated": false, "invalidation_reason": null + }, + { + "model_id": "internlm/internlm2-chat-1_8b", + "architecture_id": "InternLM2ForCausalLM", + "verified_date": "2026-04-13", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=0.0% < 100.0% (failed: all_components, forward_pass); P2=8.3% < 75.0% (failed: g \u2014 74/123 components failed (74 critical)", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "internlm/internlm2_5-1_8b", + "architecture_id": "InternLM2ForCausalLM", + "verified_date": "2026-04-13", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=0.0% < 100.0% (failed: all_components, forward_pass); P2=8.3% < 75.0% (failed: g \u2014 74/123 components failed (74 critical)", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "internlm/internlm2-chat-1_8b", + "architecture_id": "InternLM2ForCausalLM", + "verified_date": "2026-04-13", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=0.0% < 100.0% (failed: all_components, forward_pass); P2=8.3% < 75.0% (failed: g \u2014 74/123 components failed (74 critical)", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "internlm/internlm2-chat-1_8b", + "architecture_id": "InternLM2ForCausalLM", + "verified_date": "2026-04-13", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=0.0% < 100.0% (failed: all_components, forward_pass_logits) \u2014 74/123 components failed (74 critical)", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "internlm/internlm2_5-1_8b", + "architecture_id": "InternLM2ForCausalLM", + "verified_date": "2026-04-13", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=0.0% < 100.0% (failed: all_components, forward_pass_logits) \u2014 74/123 components failed (74 critical)", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "internlm/internlm2-chat-1_8b", + "architecture_id": "InternLM2ForCausalLM", + "verified_date": "2026-04-13", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=0.0% < 100.0% (failed: all_components, forward_pass_logits) \u2014 74/123 components failed (74 critical)", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "internlm/internlm2-chat-1_8b", + "architecture_id": "InternLM2ForCausalLM", + "verified_date": "2026-04-13", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "internlm/internlm2_5-1_8b", + "architecture_id": "InternLM2ForCausalLM", + "verified_date": "2026-04-13", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=0.0% < 100.0% (failed: all_components, forward_pass_logits) \u2014 74/123 components failed (74 critical)", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "internlm/internlm2_5-1_8b", + "architecture_id": "InternLM2ForCausalLM", + "verified_date": "2026-04-13", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "internlm/internlm2-chat-1_8b", + "architecture_id": "InternLM2ForCausalLM", + "verified_date": "2026-04-13", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "internlm/internlm2_5-1_8b", + "architecture_id": "InternLM2ForCausalLM", + "verified_date": "2026-04-13", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "internlm/internlm2-chat-1_8b", + "architecture_id": "InternLM2ForCausalLM", + "verified_date": "2026-04-13", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null } ] } diff --git a/transformer_lens/tools/model_registry/verify_models.py b/transformer_lens/tools/model_registry/verify_models.py index a31e91a63..92c26d251 100644 --- a/transformer_lens/tools/model_registry/verify_models.py +++ b/transformer_lens/tools/model_registry/verify_models.py @@ -57,6 +57,12 @@ logger = logging.getLogger(__name__) +# Architectures added via the TransformerBridge system that need trust_remote_code=True. +# These are not in the legacy NEED_REMOTE_CODE_MODELS tuple (loading_from_pretrained.py). +_BRIDGE_REMOTE_CODE_PREFIXES: tuple[str, ...] = ( + "internlm/", # InternLM2ForCausalLM — ships own modeling_internlm2.py +) + # Data directory for registry files _DATA_DIR = Path(__file__).parent / "data" _CHECKPOINT_PATH = _DATA_DIR / "verification_checkpoint.json" @@ -174,7 +180,8 @@ def estimate_model_params(model_id: str) -> int: from transformer_lens.loading_from_pretrained import NEED_REMOTE_CODE_MODELS - trust_remote_code = any(model_id.startswith(prefix) for prefix in NEED_REMOTE_CODE_MODELS) + _all_remote_prefixes = NEED_REMOTE_CODE_MODELS + _BRIDGE_REMOTE_CODE_PREFIXES + trust_remote_code = any(model_id.startswith(prefix) for prefix in _all_remote_prefixes) from transformer_lens.utilities.hf_utils import get_hf_token config = AutoConfig.from_pretrained( @@ -812,7 +819,8 @@ def verify_models( from transformer_lens.loading_from_pretrained import NEED_REMOTE_CODE_MODELS - needs_remote_code = any(model_id.startswith(prefix) for prefix in NEED_REMOTE_CODE_MODELS) + _all_remote_prefixes = NEED_REMOTE_CODE_MODELS + _BRIDGE_REMOTE_CODE_PREFIXES + needs_remote_code = any(model_id.startswith(prefix) for prefix in _all_remote_prefixes) # Convert string dtype to torch.dtype for benchmark suite import torch