diff --git a/tests/unit/model_bridge/generalized_components/test_falcon_alibi_attention.py b/tests/unit/model_bridge/generalized_components/test_alibi_joint_qkv_attention.py similarity index 91% rename from tests/unit/model_bridge/generalized_components/test_falcon_alibi_attention.py rename to tests/unit/model_bridge/generalized_components/test_alibi_joint_qkv_attention.py index 1054be575..d725d3787 100644 --- a/tests/unit/model_bridge/generalized_components/test_falcon_alibi_attention.py +++ b/tests/unit/model_bridge/generalized_components/test_alibi_joint_qkv_attention.py @@ -1,4 +1,4 @@ -"""Unit tests for FalconALiBiAttentionBridge. +"""Unit tests for ALiBiJointQKVAttentionBridge. Exercises the reimplemented ALiBi attention with mock weights — no model download needed. Covers MHA, MQA, and GQA head configurations to catch shape mismatches. @@ -6,13 +6,13 @@ import torch -from transformer_lens.model_bridge.generalized_components.falcon_alibi_attention import ( - FalconALiBiAttentionBridge, +from transformer_lens.model_bridge.generalized_components.alibi_joint_qkv_attention import ( + ALiBiJointQKVAttentionBridge, ) class _MockConfig: - """Minimal config for FalconALiBiAttentionBridge.""" + """Minimal config for ALiBiJointQKVAttentionBridge.""" def __init__(self, n_heads: int, d_model: int, n_key_value_heads: int | None = None): self.n_heads = n_heads @@ -33,8 +33,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def _build_bridge( n_heads: int, d_model: int, n_key_value_heads: int | None = None -) -> FalconALiBiAttentionBridge: - """Build a wired-up FalconALiBiAttentionBridge with random Q/K/V weights.""" +) -> ALiBiJointQKVAttentionBridge: + """Build a wired-up ALiBiJointQKVAttentionBridge with random Q/K/V weights.""" cfg = _MockConfig(n_heads, d_model, n_key_value_heads) head_dim = d_model // n_heads n_kv = n_key_value_heads or n_heads @@ -47,7 +47,7 @@ def _build_bridge( def split_qkv(_component): return q_linear, k_linear, v_linear - bridge = 
FalconALiBiAttentionBridge( + bridge = ALiBiJointQKVAttentionBridge( name="self_attention", config=cfg, split_qkv_matrix=split_qkv, @@ -58,12 +58,12 @@ def split_qkv(_component): return bridge -def _random_inputs(bridge: FalconALiBiAttentionBridge, batch: int = 2, seq: int = 6): +def _random_inputs(bridge: ALiBiJointQKVAttentionBridge, batch: int = 2, seq: int = 6): """Generate random inputs via the bridge's own method.""" return bridge.get_random_inputs(batch_size=batch, seq_len=seq) -class TestFalconALiBiForward: +class TestALiBiJointQKVForward: """Forward pass runs and produces valid output for all head configs.""" def test_mha_forward(self): diff --git a/tests/unit/model_bridge/supported_architectures/test_gpt_bigcode_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gpt_bigcode_adapter.py new file mode 100644 index 000000000..a5f91d020 --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_gpt_bigcode_adapter.py @@ -0,0 +1,532 @@ +"""Unit tests for GPTBigCodeArchitectureAdapter. 
+ +Tests cover: +- Config attribute validation +- Component mapping structure (correct bridge types and HF module paths) +- Weight conversion keys +- MQAQKVConversionRule (Q and K/V branches, revert, passthrough) +- _split_qkv_matrix correctness (shapes, bias, no-bias, value correctness) +- multi_query assertion in _split_qkv_matrix +- End-to-end hook shapes with a fake MQA attention module (no downloads) +- Factory registration +""" + +from typing import Any + +import pytest +import torch +import torch.nn as nn + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ArchitectureAdapterFactory, +) +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + JointQKVAttentionBridge, + LinearBridge, + MLPBridge, + NormalizationBridge, + PosEmbedBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.supported_architectures.gpt_bigcode import ( + GPTBigCodeArchitectureAdapter, + MQAQKVConversionRule, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_cfg( + n_heads: int = 4, + d_model: int = 64, + n_layers: int = 2, + d_mlp: int = 256, + d_vocab: int = 100, + n_ctx: int = 64, +) -> TransformerBridgeConfig: + """Minimal TransformerBridgeConfig for GPTBigCode adapter tests.""" + return TransformerBridgeConfig( + d_model=d_model, + d_head=d_model // n_heads, + n_layers=n_layers, + n_ctx=n_ctx, + n_heads=n_heads, + d_vocab=d_vocab, + d_mlp=d_mlp, + n_key_value_heads=1, + default_prepend_bos=True, + architecture="GPTBigCodeForCausalLM", + ) + + +class FakeMQAAttention(nn.Module): + """Minimal GPTBigCodeAttention-like module for testing (no downloaded weights).""" + + def __init__(self, d_model: int, d_head: int, multi_query: bool = True) -> None: + super().__init__() + 
self.multi_query = multi_query + # MQA: c_attn output = embed_dim + 2*head_dim + out_features = d_model + 2 * d_head if multi_query else 3 * d_model + self.c_attn = nn.Linear(d_model, out_features) + self.c_proj = nn.Linear(d_model, d_model) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # pragma: no cover + return self.c_proj(x) + + +@pytest.fixture +def cfg() -> TransformerBridgeConfig: + return _make_cfg() + + +@pytest.fixture +def adapter(cfg: TransformerBridgeConfig) -> GPTBigCodeArchitectureAdapter: + return GPTBigCodeArchitectureAdapter(cfg) + + +# --------------------------------------------------------------------------- +# Config attribute tests +# --------------------------------------------------------------------------- + + +class TestGPTBigCodeAdapterConfig: + """Verifies all required config attributes are set correctly.""" + + def test_normalization_type_is_ln(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert adapter.cfg.normalization_type == "LN" + + def test_positional_embedding_type_is_standard( + self, adapter: GPTBigCodeArchitectureAdapter + ) -> None: + assert adapter.cfg.positional_embedding_type == "standard" + + def test_final_rms_is_false(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert adapter.cfg.final_rms is False + + def test_gated_mlp_is_false(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert adapter.cfg.gated_mlp is False + + def test_attn_only_is_false(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert adapter.cfg.attn_only is False + + def test_n_key_value_heads_is_one(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert adapter.cfg.n_key_value_heads == 1 + + +# --------------------------------------------------------------------------- +# Component mapping structure tests +# --------------------------------------------------------------------------- + + +class TestGPTBigCodeAdapterComponentMapping: + """Verifies component_mapping has the correct bridge 
types and HF paths.""" + + def test_embed_is_embedding_bridge(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) + + def test_embed_name(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert adapter.component_mapping["embed"].name == "transformer.wte" + + def test_pos_embed_is_pos_embed_bridge(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["pos_embed"], PosEmbedBridge) + + def test_pos_embed_name(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert adapter.component_mapping["pos_embed"].name == "transformer.wpe" + + def test_blocks_is_block_bridge(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["blocks"], BlockBridge) + + def test_blocks_name(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].name == "transformer.h" + + def test_ln1_is_normalization_bridge(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["ln1"], NormalizationBridge) + + def test_ln1_name(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["ln1"].name == "ln_1" + + def test_attn_is_gpt_bigcode_bridge(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["attn"], JointQKVAttentionBridge) + + def test_attn_name(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["attn"].name == "attn" + + def test_attn_qkv_is_linear_bridge(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["attn"].submodules["qkv"], LinearBridge) + + def 
test_attn_qkv_name(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["attn"].submodules["qkv"].name == "c_attn" + + def test_attn_o_is_linear_bridge(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["attn"].submodules["o"], LinearBridge) + + def test_attn_o_name(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["attn"].submodules["o"].name == "c_proj" + + def test_ln2_is_normalization_bridge(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["ln2"], NormalizationBridge) + + def test_ln2_name(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["ln2"].name == "ln_2" + + def test_mlp_is_mlp_bridge(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["mlp"], MLPBridge) + + def test_mlp_name(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["mlp"].name == "mlp" + + def test_mlp_in_name(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["mlp"].submodules["in"].name == "c_fc" + + def test_mlp_out_name(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["mlp"].submodules["out"].name == "c_proj" + + def test_ln_final_is_normalization_bridge(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["ln_final"], NormalizationBridge) + + def test_ln_final_name(self, adapter: GPTBigCodeArchitectureAdapter) -> None: 
+ assert adapter.component_mapping["ln_final"].name == "transformer.ln_f" + + def test_unembed_is_unembedding_bridge(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge) + + def test_unembed_name(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert adapter.component_mapping["unembed"].name == "lm_head" + + +# --------------------------------------------------------------------------- +# Weight processing conversion tests +# --------------------------------------------------------------------------- + + +class TestGPTBigCodeAdapterWeightConversions: + """Verifies weight_processing_conversions has expected keys.""" + + def test_q_weight_key_present(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert "blocks.{i}.attn.q.weight" in adapter.weight_processing_conversions + + def test_k_weight_key_present(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert "blocks.{i}.attn.k.weight" in adapter.weight_processing_conversions + + def test_v_weight_key_present(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert "blocks.{i}.attn.v.weight" in adapter.weight_processing_conversions + + def test_o_weight_key_present(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert "blocks.{i}.attn.o.weight" in adapter.weight_processing_conversions + + def test_exactly_four_conversion_keys(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert len(adapter.weight_processing_conversions) == 4 + + +# --------------------------------------------------------------------------- +# MQAQKVConversionRule tests +# --------------------------------------------------------------------------- + + +class TestMQAQKVConversionRule: + """Verifies the branching QKV activation rearrangement for MQA.""" + + N_HEADS = 4 + D_HEAD = 16 + D_MODEL = N_HEADS * D_HEAD # 64 + BATCH, SEQ = 2, 8 + + @pytest.fixture + def rule(self) -> MQAQKVConversionRule: + return 
MQAQKVConversionRule(n_heads=self.N_HEADS, d_head=self.D_HEAD) + + def test_q_shaped_input_gives_n_heads_dimension(self, rule: MQAQKVConversionRule) -> None: + """Q input [batch, seq, embed_dim] -> [batch, seq, n_heads, d_head].""" + x = torch.randn(self.BATCH, self.SEQ, self.D_MODEL) + out = rule.handle_conversion(x) + assert out.shape == (self.BATCH, self.SEQ, self.N_HEADS, self.D_HEAD) + + def test_kv_shaped_input_gives_one_head_dimension(self, rule: MQAQKVConversionRule) -> None: + """K/V input [batch, seq, head_dim] -> [batch, seq, 1, d_head].""" + x = torch.randn(self.BATCH, self.SEQ, self.D_HEAD) + out = rule.handle_conversion(x) + assert out.shape == (self.BATCH, self.SEQ, 1, self.D_HEAD) + + def test_4d_input_passes_through_unchanged(self, rule: MQAQKVConversionRule) -> None: + """4D input is already in heads format — return as-is.""" + x = torch.randn(self.BATCH, self.SEQ, self.N_HEADS, self.D_HEAD) + out = rule.handle_conversion(x) + assert out.shape == x.shape + assert torch.equal(out, x) + + def test_revert_q_shaped(self, rule: MQAQKVConversionRule) -> None: + """revert undoes handle_conversion for Q-shaped input.""" + x = torch.randn(self.BATCH, self.SEQ, self.D_MODEL) + converted = rule.handle_conversion(x) + reverted = rule.revert(converted) + assert reverted.shape == x.shape + assert torch.allclose(reverted, x) + + def test_revert_kv_shaped(self, rule: MQAQKVConversionRule) -> None: + """revert undoes handle_conversion for K/V-shaped input.""" + x = torch.randn(self.BATCH, self.SEQ, self.D_HEAD) + converted = rule.handle_conversion(x) + reverted = rule.revert(converted) + assert reverted.shape == x.shape + assert torch.allclose(reverted, x) + + def test_revert_3d_passes_through(self, rule: MQAQKVConversionRule) -> None: + """revert on a 3D tensor (already flat) is a no-op.""" + x = torch.randn(self.BATCH, self.SEQ, self.D_MODEL) + out = rule.revert(x) + assert torch.equal(out, x) + + def test_invalid_ndim_raises(self, rule: MQAQKVConversionRule) -> 
None: + with pytest.raises(ValueError, match="Expected 3D or 4D"): + rule.handle_conversion(torch.randn(self.D_MODEL)) + + +# --------------------------------------------------------------------------- +# _split_qkv_matrix tests +# --------------------------------------------------------------------------- + + +class TestGPTBigCodeMQASplitQKVMatrix: + """Numerical correctness tests for the MQA asymmetric QKV split.""" + + N_HEADS = 4 + D_MODEL = 64 + D_HEAD = D_MODEL // N_HEADS # 16 + BATCH, SEQ = 2, 8 + + @pytest.fixture + def adapter(self) -> GPTBigCodeArchitectureAdapter: + cfg = _make_cfg(n_heads=self.N_HEADS, d_model=self.D_MODEL) + return GPTBigCodeArchitectureAdapter(cfg) + + @pytest.fixture + def fake_attn(self) -> FakeMQAAttention: + return FakeMQAAttention(self.D_MODEL, self.D_HEAD, multi_query=True) + + @pytest.fixture + def fake_attn_nobias(self) -> FakeMQAAttention: + attn = FakeMQAAttention(self.D_MODEL, self.D_HEAD, multi_query=True) + # Remove bias from c_attn + attn.c_attn = nn.Linear(self.D_MODEL, self.D_MODEL + 2 * self.D_HEAD, bias=False) + return attn + + def test_returns_three_linear_modules( + self, adapter: GPTBigCodeArchitectureAdapter, fake_attn: FakeMQAAttention + ) -> None: + q, k, v = adapter._split_qkv_matrix(fake_attn) + assert isinstance(q, nn.Linear) + assert isinstance(k, nn.Linear) + assert isinstance(v, nn.Linear) + + def test_q_weight_shape( + self, adapter: GPTBigCodeArchitectureAdapter, fake_attn: FakeMQAAttention + ) -> None: + q, _, _ = adapter._split_qkv_matrix(fake_attn) + assert q.weight.shape == (self.D_MODEL, self.D_MODEL) + + def test_k_weight_shape( + self, adapter: GPTBigCodeArchitectureAdapter, fake_attn: FakeMQAAttention + ) -> None: + _, k, _ = adapter._split_qkv_matrix(fake_attn) + assert k.weight.shape == (self.D_HEAD, self.D_MODEL) + + def test_v_weight_shape( + self, adapter: GPTBigCodeArchitectureAdapter, fake_attn: FakeMQAAttention + ) -> None: + _, _, v = adapter._split_qkv_matrix(fake_attn) + assert 
v.weight.shape == (self.D_HEAD, self.D_MODEL) + + def test_q_bias_shape( + self, adapter: GPTBigCodeArchitectureAdapter, fake_attn: FakeMQAAttention + ) -> None: + q, _, _ = adapter._split_qkv_matrix(fake_attn) + assert q.bias is not None + assert q.bias.shape == (self.D_MODEL,) + + def test_k_bias_shape( + self, adapter: GPTBigCodeArchitectureAdapter, fake_attn: FakeMQAAttention + ) -> None: + _, k, _ = adapter._split_qkv_matrix(fake_attn) + assert k.bias is not None + assert k.bias.shape == (self.D_HEAD,) + + def test_v_bias_shape( + self, adapter: GPTBigCodeArchitectureAdapter, fake_attn: FakeMQAAttention + ) -> None: + _, _, v = adapter._split_qkv_matrix(fake_attn) + assert v.bias is not None + assert v.bias.shape == (self.D_HEAD,) + + def test_no_bias_case_all_none( + self, adapter: GPTBigCodeArchitectureAdapter, fake_attn_nobias: FakeMQAAttention + ) -> None: + q, k, v = adapter._split_qkv_matrix(fake_attn_nobias) + assert q.bias is None + assert k.bias is None + assert v.bias is None + + def test_q_k_v_weights_are_distinct( + self, adapter: GPTBigCodeArchitectureAdapter, fake_attn: FakeMQAAttention + ) -> None: + """With non-trivial c_attn weight, Q/K/V must differ.""" + nn.init.normal_(fake_attn.c_attn.weight) + q, k, v = adapter._split_qkv_matrix(fake_attn) + # K and V have the same shape [d_head, d_model] so compare directly + assert not torch.allclose(k.weight, v.weight), "K and V weights must differ" + + def test_q_forward_output_shape( + self, adapter: GPTBigCodeArchitectureAdapter, fake_attn: FakeMQAAttention + ) -> None: + q, _, _ = adapter._split_qkv_matrix(fake_attn) + x = torch.randn(self.BATCH, self.SEQ, self.D_MODEL) + assert q(x).shape == (self.BATCH, self.SEQ, self.D_MODEL) + + def test_k_forward_output_shape( + self, adapter: GPTBigCodeArchitectureAdapter, fake_attn: FakeMQAAttention + ) -> None: + _, k, _ = adapter._split_qkv_matrix(fake_attn) + x = torch.randn(self.BATCH, self.SEQ, self.D_MODEL) + assert k(x).shape == (self.BATCH, self.SEQ, 
self.D_HEAD) + + def test_v_forward_output_shape( + self, adapter: GPTBigCodeArchitectureAdapter, fake_attn: FakeMQAAttention + ) -> None: + _, _, v = adapter._split_qkv_matrix(fake_attn) + x = torch.randn(self.BATCH, self.SEQ, self.D_MODEL) + assert v(x).shape == (self.BATCH, self.SEQ, self.D_HEAD) + + def test_weight_values_match_c_attn_rows( + self, adapter: GPTBigCodeArchitectureAdapter, fake_attn: FakeMQAAttention + ) -> None: + """Q/K/V weight rows must exactly match the corresponding rows of c_attn.weight.""" + nn.init.normal_(fake_attn.c_attn.weight) + original_weight = fake_attn.c_attn.weight.detach() + q, k, v = adapter._split_qkv_matrix(fake_attn) + assert torch.equal(q.weight, original_weight[: self.D_MODEL]) + assert torch.equal(k.weight, original_weight[self.D_MODEL : self.D_MODEL + self.D_HEAD]) + assert torch.equal(v.weight, original_weight[self.D_MODEL + self.D_HEAD :]) + + def test_multi_query_false_raises_assertion( + self, adapter: GPTBigCodeArchitectureAdapter + ) -> None: + """Adapter must raise AssertionError for multi_query=False checkpoints.""" + mha_attn = FakeMQAAttention(self.D_MODEL, self.D_HEAD, multi_query=False) + with pytest.raises(AssertionError, match="multi_query=True"): + adapter._split_qkv_matrix(mha_attn) + + +# --------------------------------------------------------------------------- +# End-to-end hook shape tests +# --------------------------------------------------------------------------- + + +class TestGPTBigCodeHookShapes: + """End-to-end forward pass verifying hook_q/hook_k/hook_v shapes. + + Uses a fake MQA attention nn.Module (no model downloads). Registers explicit + hooks on hook_out so that hook_conversion (MQAQKVConversionRule) fires and + the captured tensors reflect the converted shapes. 
+ """ + + N_HEADS = 4 + D_MODEL = 64 + D_HEAD = D_MODEL // N_HEADS # 16 + BATCH, SEQ = 2, 8 + + @pytest.fixture + def adapter(self) -> GPTBigCodeArchitectureAdapter: + cfg = _make_cfg(n_heads=self.N_HEADS, d_model=self.D_MODEL) + return GPTBigCodeArchitectureAdapter(cfg) + + @pytest.fixture + def wired_attn_bridge(self, adapter: GPTBigCodeArchitectureAdapter) -> JointQKVAttentionBridge: + """Return attn bridge wired to a fake MQA attention module.""" + fake_attn = FakeMQAAttention(self.D_MODEL, self.D_HEAD, multi_query=True) + blocks = adapter.component_mapping["blocks"] + attn_bridge: JointQKVAttentionBridge = blocks.submodules["attn"] # type: ignore[assignment] + attn_bridge.set_original_component(fake_attn) + return attn_bridge + + def _run_and_capture( + self, attn_bridge: JointQKVAttentionBridge + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Register hooks on q/k/v hook_out, run forward, return captured tensors.""" + captured: dict[str, torch.Tensor] = {} + + def _capture(name: str) -> Any: + def _hook(x: torch.Tensor, hook: Any) -> torch.Tensor: + captured[name] = x + return x + + return _hook + + attn_bridge.q.hook_out.add_hook(_capture("q")) + attn_bridge.k.hook_out.add_hook(_capture("k")) + attn_bridge.v.hook_out.add_hook(_capture("v")) + + hidden = torch.randn(self.BATCH, self.SEQ, self.D_MODEL) + attn_bridge(hidden) + + return captured["q"], captured["k"], captured["v"] + + def test_hook_q_shape(self, wired_attn_bridge: JointQKVAttentionBridge) -> None: + """hook_q must be [batch, seq, n_heads, d_head].""" + q, _, _ = self._run_and_capture(wired_attn_bridge) + assert q.shape == (self.BATCH, self.SEQ, self.N_HEADS, self.D_HEAD) + + def test_hook_k_shape(self, wired_attn_bridge: JointQKVAttentionBridge) -> None: + """hook_k must be [batch, seq, 1, d_head] (1 KV head).""" + _, k, _ = self._run_and_capture(wired_attn_bridge) + assert k.shape == (self.BATCH, self.SEQ, 1, self.D_HEAD) + + def test_hook_v_shape(self, wired_attn_bridge: 
JointQKVAttentionBridge) -> None: + """hook_v must be [batch, seq, 1, d_head] (1 KV head).""" + _, _, v = self._run_and_capture(wired_attn_bridge) + assert v.shape == (self.BATCH, self.SEQ, 1, self.D_HEAD) + + def test_attn_output_shape(self, wired_attn_bridge: JointQKVAttentionBridge) -> None: + """Full attention output must be [batch, seq, d_model].""" + hidden = torch.randn(self.BATCH, self.SEQ, self.D_MODEL) + out = wired_attn_bridge(hidden) + out_tensor = out[0] if isinstance(out, tuple) else out + assert out_tensor.shape == (self.BATCH, self.SEQ, self.D_MODEL) + + +# --------------------------------------------------------------------------- +# Factory registration tests +# --------------------------------------------------------------------------- + + +class TestGPTBigCodeFactoryRegistration: + """Verifies the factory maps GPTBigCodeForCausalLM to the correct adapter.""" + + def test_factory_key_present(self) -> None: + assert "GPTBigCodeForCausalLM" in SUPPORTED_ARCHITECTURES + + def test_factory_maps_to_correct_adapter_class(self) -> None: + assert SUPPORTED_ARCHITECTURES["GPTBigCodeForCausalLM"] is GPTBigCodeArchitectureAdapter + + def test_factory_returns_correct_instance(self) -> None: + cfg = _make_cfg() + adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) + assert isinstance(adapter, GPTBigCodeArchitectureAdapter) diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index 6d4ca4964..691d53c2e 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -18,6 +18,7 @@ Gemma3MultimodalArchitectureAdapter, GPT2ArchitectureAdapter, Gpt2LmHeadCustomArchitectureAdapter, + GPTBigCodeArchitectureAdapter, GptjArchitectureAdapter, GPTOSSArchitectureAdapter, GraniteArchitectureAdapter, @@ -71,6 +72,7 @@ "GraniteMoeForCausalLM": GraniteMoeArchitectureAdapter, 
"GraniteMoeHybridForCausalLM": GraniteMoeHybridArchitectureAdapter, "GPT2LMHeadModel": GPT2ArchitectureAdapter, + "GPTBigCodeForCausalLM": GPTBigCodeArchitectureAdapter, "GptOssForCausalLM": GPTOSSArchitectureAdapter, "GPT2LMHeadCustomModel": Gpt2LmHeadCustomArchitectureAdapter, "GPTJForCausalLM": GptjArchitectureAdapter, diff --git a/transformer_lens/model_bridge/generalized_components/__init__.py b/transformer_lens/model_bridge/generalized_components/__init__.py index 83c450bd7..518baca09 100644 --- a/transformer_lens/model_bridge/generalized_components/__init__.py +++ b/transformer_lens/model_bridge/generalized_components/__init__.py @@ -32,8 +32,8 @@ from transformer_lens.model_bridge.generalized_components.embedding import ( EmbeddingBridge, ) -from transformer_lens.model_bridge.generalized_components.falcon_alibi_attention import ( - FalconALiBiAttentionBridge, +from transformer_lens.model_bridge.generalized_components.alibi_joint_qkv_attention import ( + ALiBiJointQKVAttentionBridge, ) from transformer_lens.model_bridge.generalized_components.gated_mlp import ( GatedMLPBridge, @@ -98,7 +98,7 @@ "ConvPosEmbedBridge", "DepthwiseConv1DBridge", "EmbeddingBridge", - "FalconALiBiAttentionBridge", + "ALiBiJointQKVAttentionBridge", "RotaryEmbeddingBridge", "PosEmbedBridge", "NormalizationBridge", diff --git a/transformer_lens/model_bridge/generalized_components/falcon_alibi_attention.py b/transformer_lens/model_bridge/generalized_components/alibi_joint_qkv_attention.py similarity index 95% rename from transformer_lens/model_bridge/generalized_components/falcon_alibi_attention.py rename to transformer_lens/model_bridge/generalized_components/alibi_joint_qkv_attention.py index c98f2c89f..acb1c6a53 100644 --- a/transformer_lens/model_bridge/generalized_components/falcon_alibi_attention.py +++ b/transformer_lens/model_bridge/generalized_components/alibi_joint_qkv_attention.py @@ -1,6 +1,6 @@ -"""Falcon ALiBi attention bridge component. 
+"""ALiBi joint QKV attention bridge component. -Handles Falcon models that use ALiBi (Attention with Linear Biases) instead of RoPE. +Handles models that use ALiBi (Attention with Linear Biases) with fused QKV projections. Splits fused QKV, reimplements attention with ALiBi bias and hooks at each stage. """ @@ -21,8 +21,8 @@ ) -class FalconALiBiAttentionBridge(JointQKVAttentionBridge): - """Attention bridge for Falcon models using ALiBi position encoding. +class ALiBiJointQKVAttentionBridge(JointQKVAttentionBridge): + """Attention bridge for models using ALiBi position encoding with fused QKV. Splits fused QKV, reimplements attention with ALiBi bias fused into scores, and fires hooks at each stage (hook_q, hook_k, hook_v, hook_attn_scores, diff --git a/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py b/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py index 530cff3fb..c93248867 100644 --- a/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py +++ b/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py @@ -376,12 +376,21 @@ def _reconstruct_attention( assert self.original_component is not None assert self.config is not None num_heads = self.config.n_heads + num_kv_heads = getattr(self.config, "n_key_value_heads", None) or num_heads - q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads(q, k, v, num_heads) + q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads( + q, k, v, num_heads, num_kv_heads + ) # KV cache: extend K/V with cached positions. 
k, v = self._update_kv_cache(k, v, **kwargs) + # GQA/MQA: expand K/V heads to match Q heads + if num_kv_heads != num_heads: + n_rep = num_heads // num_kv_heads + k = k.repeat_interleave(n_rep, dim=1) + v = v.repeat_interleave(n_rep, dim=1) + # Attention scale: 1/sqrt(d_head) with optional inverse-layer scaling scale = head_dim ** (-0.5) if ( diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py index 3ed80b776..83395b34a 100644 --- a/transformer_lens/model_bridge/supported_architectures/__init__.py +++ b/transformer_lens/model_bridge/supported_architectures/__init__.py @@ -36,6 +36,9 @@ from transformer_lens.model_bridge.supported_architectures.gpt2 import ( GPT2ArchitectureAdapter, ) +from transformer_lens.model_bridge.supported_architectures.gpt_bigcode import ( + GPTBigCodeArchitectureAdapter, +) from transformer_lens.model_bridge.supported_architectures.gpt2_lm_head_custom import ( Gpt2LmHeadCustomArchitectureAdapter, ) @@ -160,6 +163,7 @@ "GraniteMoeArchitectureAdapter", "GraniteMoeHybridArchitectureAdapter", "GPT2ArchitectureAdapter", + "GPTBigCodeArchitectureAdapter", "GPTOSSArchitectureAdapter", "Gpt2LmHeadCustomArchitectureAdapter", "GptjArchitectureAdapter", diff --git a/transformer_lens/model_bridge/supported_architectures/falcon.py b/transformer_lens/model_bridge/supported_architectures/falcon.py index 42759d612..e552a07f4 100644 --- a/transformer_lens/model_bridge/supported_architectures/falcon.py +++ b/transformer_lens/model_bridge/supported_architectures/falcon.py @@ -16,9 +16,9 @@ ) from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter from transformer_lens.model_bridge.generalized_components import ( + ALiBiJointQKVAttentionBridge, BlockBridge, EmbeddingBridge, - FalconALiBiAttentionBridge, JointQKVPositionEmbeddingsAttentionBridge, LinearBridge, MLPBridge, @@ -102,7 +102,7 @@ def __init__(self, cfg: Any) -> None: if 
self._is_alibi: # ALiBi: reimplement attention with ALiBi bias fused into scores. # Splits fused QKV and fires hooks at each stage for mech interp. - attn_bridge: Any = FalconALiBiAttentionBridge( + attn_bridge: Any = ALiBiJointQKVAttentionBridge( name="self_attention", config=self.cfg, split_qkv_matrix=self._split_falcon_qkv, diff --git a/transformer_lens/model_bridge/supported_architectures/gpt_bigcode.py b/transformer_lens/model_bridge/supported_architectures/gpt_bigcode.py new file mode 100644 index 000000000..102c4f2c7 --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/gpt_bigcode.py @@ -0,0 +1,182 @@ +"""GPTBigCode architecture adapter.""" + +from typing import Any + +import einops +import torch +import torch.nn as nn + +from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import ( + BaseTensorConversion, +) +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + JointQKVAttentionBridge, + LinearBridge, + MLPBridge, + NormalizationBridge, + PosEmbedBridge, + UnembeddingBridge, +) + + +class MQAQKVConversionRule(BaseTensorConversion): + """Rearranges Q/K/V activations for MQA. + + Q output has embed_dim features -> rearrange with n=n_heads. + K/V output has head_dim features (1 KV head) -> rearrange with n=1. 
+ """ + + def __init__(self, n_heads: int, d_head: int) -> None: + super().__init__() + self.n_heads = n_heads + self.d_head = d_head + + def handle_conversion(self, input_value: torch.Tensor, *_: Any) -> torch.Tensor: + if input_value.ndim == 4: + return input_value # already [batch, seq, heads, head_dim] + if input_value.ndim != 3: + raise ValueError( + f"Expected 3D or 4D tensor, got {input_value.ndim}D with shape {input_value.shape}" + ) + last_dim: int = input_value.shape[2] + # Q: last_dim == n_heads * d_head; K/V: last_dim == d_head (1 head) + n = self.n_heads if last_dim == self.n_heads * self.d_head else 1 + return einops.rearrange(input_value, "batch seq (n h) -> batch seq n h", n=n) + + def revert(self, input_value: torch.Tensor, *_: Any) -> torch.Tensor: + if input_value.ndim == 3: + return input_value + return einops.rearrange(input_value, "batch seq n h -> batch seq (n h)") + + +class GPTBigCodeArchitectureAdapter(ArchitectureAdapter): + """Architecture adapter for GPTBigCode models. + + GPTBigCode is a GPT-2 variant using Multi-Query Attention (MQA): a single + fused c_attn projection whose output splits asymmetrically into + [embed_dim, head_dim, head_dim] for Q/K/V (rather than three equal thirds). + All other structure (module paths, LayerNorm, learned pos embeddings, + standard MLP) is identical to GPT-2. + + All public models use multi_query=True (1 KV head). The adapter assumes + MQA throughout. + + All linear layers have biases (c_attn, c_proj, c_fc, mlp.c_proj). + lm_head has no bias and its weight is tied to transformer.wte.weight. + + Weight layout difference from GPT-2: GPTBigCode uses nn.Linear (weights + stored [out, in]) rather than GPT-2's Conv1D ([in, out]), so no unembed + weight transpose is needed. 
def __init__(self, cfg: Any) -> None:
    """Set the GPTBigCode config flags and declare the bridge component mapping.

    Args:
        cfg: Model configuration handed to the base adapter; it is updated in
            place with the settings this architecture needs.
    """
    super().__init__(cfg)

    # LayerNorm + learned positional embeddings + plain (non-gated) MLP.
    config = self.cfg
    config.normalization_type = "LN"
    config.positional_embedding_type = "standard"
    config.final_rms = False
    config.gated_mlp = False
    config.attn_only = False
    config.uses_rms_norm = False
    config.eps_attr = "layer_norm_epsilon"
    config.n_key_value_heads = 1  # MQA: always 1 KV head

    # Same combined-QKV flags the GPT-2 adapter uses.
    self.default_cfg = {"uses_split_attention": True}
    self.uses_combined_qkv = True
    config.split_attention_weights = True

    # The base helper with n_kv_heads=1 yields the K/V rearrangement for a single head.
    self.weight_processing_conversions: dict[str, ParamProcessingConversion] = dict(  # type: ignore[assignment]
        self._qkvo_weight_conversions(n_kv_heads=1)
    )

    mqa_rule = MQAQKVConversionRule(n_heads=self.cfg.n_heads, d_head=self.cfg.d_head)

    # GPTBigCode's HF eager_attention_forward only applies causal masking
    # when attention_mask is not None. Requiring an attention mask and
    # flagging it as 4D makes component tests supply a 4D mask, so the HF
    # and bridge forward passes receive compatible mask shapes.
    attention = JointQKVAttentionBridge(
        name="attn",
        config=self.cfg,
        split_qkv_matrix=self._split_qkv_matrix,
        qkv_conversion_rule=mqa_rule,
        requires_attention_mask=True,
        submodules={
            "qkv": LinearBridge(name="c_attn"),
            "o": LinearBridge(name="c_proj"),
        },
    )
    attention.attention_mask_4d = True

    feed_forward = MLPBridge(
        name="mlp",
        submodules={
            "in": LinearBridge(name="c_fc"),
            "out": LinearBridge(name="c_proj"),
        },
    )
    block_parts = {
        "ln1": NormalizationBridge(name="ln_1", config=self.cfg),
        "attn": attention,
        "ln2": NormalizationBridge(name="ln_2", config=self.cfg),
        "mlp": feed_forward,
    }

    self.component_mapping = {
        "embed": EmbeddingBridge(name="transformer.wte"),
        "pos_embed": PosEmbedBridge(name="transformer.wpe"),
        "blocks": BlockBridge(name="transformer.h", config=self.cfg, submodules=block_parts),
        "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg),
        "unembed": UnembeddingBridge(name="lm_head"),
    }
+ ) + + c_attn = original_attention_component.c_attn + embed_dim = self.cfg.d_model + head_dim = self.cfg.d_head + + q_w, k_w, v_w = c_attn.weight.split([embed_dim, head_dim, head_dim], dim=0) + + has_bias = c_attn.bias is not None + q_b: torch.Tensor | None = None + k_b: torch.Tensor | None = None + v_b: torch.Tensor | None = None + if has_bias: + q_b, k_b, v_b = c_attn.bias.split([embed_dim, head_dim, head_dim]) + + def _make_linear(w: torch.Tensor, b: torch.Tensor | None) -> nn.Linear: + lin = nn.Linear(w.shape[1], w.shape[0], bias=b is not None) + lin.weight = nn.Parameter(w) + if b is not None: + lin.bias = nn.Parameter(b) + return lin + + return _make_linear(q_w, q_b), _make_linear(k_w, k_b), _make_linear(v_w, v_b) diff --git a/transformer_lens/tools/model_registry/data/supported_models.json b/transformer_lens/tools/model_registry/data/supported_models.json index 432780e60..74540d33d 100644 --- a/transformer_lens/tools/model_registry/data/supported_models.json +++ b/transformer_lens/tools/model_registry/data/supported_models.json @@ -12961,12 +12961,12 @@ "architecture_id": "FalconForCausalLM", "model_id": "fxmarty/really-tiny-falcon-testing", "status": 1, - "verified_date": "2026-04-09", + "verified_date": "2026-04-13", "metadata": null, - "note": "Full verification completed with issues, low text quality: P3=95.0% (failed: process_bridge_weights)", + "note": "Full verification completed with issues, low text quality", "phase1_score": 100.0, "phase2_score": 100.0, - "phase3_score": 95.0, + "phase3_score": 100.0, "phase4_score": 34.3, "phase7_score": null, "phase8_score": null @@ -97801,6 +97801,23 @@ "phase7_score": null, "phase8_score": null }, + { + "architecture_id": "GPTBigCodeForCausalLM", + "model_id": "bigcode/gpt_bigcode-santacoder", + "status": 1, + "verified_date": "2026-04-10", + "metadata": { + "downloads": 52761, + "total_params": 1124886528 + }, + "note": "Full verification completed", + "phase1_score": 100.0, + "phase2_score": 100.0, + 
"phase3_score": 100.0, + "phase4_score": 64.4, + "phase7_score": null, + "phase8_score": null + }, { "architecture_id": "CohereForCausalLM", "model_id": "trl-internal-testing/tiny-CohereForCausalLM", @@ -98379,6 +98396,40 @@ "phase7_score": null, "phase8_score": null }, + { + "architecture_id": "GPTBigCodeForCausalLM", + "model_id": "bigcode/tiny_starcoder_py", + "status": 1, + "verified_date": "2026-04-13", + "metadata": { + "downloads": 17845, + "total_params": 164144128 + }, + "note": "Full verification completed", + "phase1_score": 100.0, + "phase2_score": 100.0, + "phase3_score": 100.0, + "phase4_score": 86.7, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Qwen3MoeForCausalLM", + "model_id": "imdatta0/tiny_qwen3_moe_2.8B_0.7B", + "status": 1, + "verified_date": "2026-04-10", + "metadata": { + "downloads": 218, + "total_params": 2800000000 + }, + "note": "Full verification completed", + "phase1_score": 100.0, + "phase2_score": 100.0, + "phase3_score": 100.0, + "phase4_score": 70.4, + "phase7_score": null, + "phase8_score": null + }, { "architecture_id": "Qwen3MoeForCausalLM", "model_id": "imdatta0/tiny_qwen3_moe_2.8B_0.7B", diff --git a/transformer_lens/tools/model_registry/data/verification_history.json b/transformer_lens/tools/model_registry/data/verification_history.json index 21d21c369..e76ff022a 100644 --- a/transformer_lens/tools/model_registry/data/verification_history.json +++ b/transformer_lens/tools/model_registry/data/verification_history.json @@ -1,5 +1,5 @@ { - "last_updated": "2026-04-10T19:09:48.784882", + "last_updated": "2026-04-13T09:46:38.844770", "records": [ { "model_id": "Macropodus/macbert4mdcspell_v1", @@ -11202,32 +11202,92 @@ "invalidation_reason": null }, { - "model_id": "trl-internal-testing/tiny-CohereForCausalLM", - "architecture_id": "CohereForCausalLM", + "model_id": "bigcode/gpt_bigcode-santacoder", + "architecture_id": "GPTBigCodeForCausalLM", "verified_date": "2026-04-10", "verified_by": 
"verify_models", "transformerlens_version": null, - "notes": "Full verification completed with issues: P3=94.7% (failed: weight_modification)", + "notes": "Below threshold: P1=50.0% < 100.0% (failed: all_components) \u2014 24/148 components failed (24 critical)", "invalidated": false, "invalidation_reason": null }, { - "model_id": "trl-internal-testing/tiny-CohereForCausalLM", - "architecture_id": "CohereForCausalLM", + "model_id": "bigcode/tiny_starcoder_py", + "architecture_id": "GPTBigCodeForCausalLM", "verified_date": "2026-04-10", "verified_by": "verify_models", "transformerlens_version": null, - "notes": "Full verification completed with issues: P3=94.7% (failed: weight_modification)", + "notes": "Below threshold: P1=50.0% < 100.0% (failed: all_components) \u2014 20/124 components failed (20 critical)", "invalidated": false, "invalidation_reason": null }, { - "model_id": "trl-internal-testing/tiny-CohereForCausalLM", - "architecture_id": "CohereForCausalLM", + "model_id": "bigcode/tiny_starcoder_py", + "architecture_id": "GPTBigCodeForCausalLM", "verified_date": "2026-04-10", "verified_by": "verify_models", "transformerlens_version": null, - "notes": "Full verification completed with issues: P3=94.7% (failed: weight_modification)", + "notes": "Below threshold: P1=50.0% < 100.0% (failed: all_components) \u2014 20/124 components failed (20 critical)", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "bigcode/tiny_starcoder_py", + "architecture_id": "GPTBigCodeForCausalLM", + "verified_date": "2026-04-10", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "bigcode/gpt_bigcode-santacoder", + "architecture_id": "GPTBigCodeForCausalLM", + "verified_date": "2026-04-10", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, 
+ "invalidation_reason": null + }, + { + "model_id": "fxmarty/really-tiny-falcon-testing", + "architecture_id": "FalconForCausalLM", + "verified_date": "2026-04-13", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed with issues, low text quality", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "bigcode/tiny_starcoder_py", + "architecture_id": "GPTBigCodeForCausalLM", + "verified_date": "2026-04-13", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=50.0% < 100.0% (failed: all_components) \u2014 20/124 components failed (20 critical)", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "bigcode/tiny_starcoder_py", + "architecture_id": "GPTBigCodeForCausalLM", + "verified_date": "2026-04-13", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "fxmarty/really-tiny-falcon-testing", + "architecture_id": "FalconForCausalLM", + "verified_date": "2026-04-13", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed with issues, low text quality", "invalidated": false, "invalidation_reason": null }