From 02bae3d88ebab6e160ba1c3a9beb4bbdaf2f87a6 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Wed, 15 Apr 2026 14:12:27 -0500 Subject: [PATCH 1/3] Fixing issue with exposed hooks --- .../compatibility/test_hook_completeness.py | 11 +- .../model_bridge/test_component_hooks_fire.py | 175 ++++++++++++++++++ .../unit/model_bridge/test_component_setup.py | 1 + .../test_hook_alias_resolution.py | 130 +++++++++++++ .../generalized_components/__init__.py | 6 +- .../generalized_components/base.py | 42 +++-- .../generalized_components/block.py | 56 +++++- .../position_embeddings_attention.py | 132 ++++++++----- .../supported_architectures/codegen.py | 4 +- .../supported_architectures/cohere.py | 4 +- .../supported_architectures/falcon.py | 5 +- .../supported_architectures/gptj.py | 5 +- .../supported_architectures/neox.py | 5 +- .../supported_architectures/phi.py | 4 +- .../supported_architectures/pythia.py | 8 +- .../model_registry/data/supported_models.json | 8 +- 16 files changed, 505 insertions(+), 91 deletions(-) create mode 100644 tests/unit/model_bridge/test_component_hooks_fire.py create mode 100644 tests/unit/model_bridge/test_hook_alias_resolution.py diff --git a/tests/acceptance/model_bridge/compatibility/test_hook_completeness.py b/tests/acceptance/model_bridge/compatibility/test_hook_completeness.py index 78896053b..84c6a2d22 100644 --- a/tests/acceptance/model_bridge/compatibility/test_hook_completeness.py +++ b/tests/acceptance/model_bridge/compatibility/test_hook_completeness.py @@ -17,10 +17,15 @@ pytestmark = pytest.mark.slow -# Diverse architectures for hook completeness testing +# Diverse architectures for hook completeness testing. +# Constraint: these tests compare bridge vs legacy HookedTransformer, so each +# entry must be in HookedTransformer's OFFICIAL_MODEL_NAMES. Tiny C1-affected +# families (Llama/Qwen/Gemma under ~150M) aren't registered with HT; for those, +# tests/unit/model_bridge/test_component_hooks_fire.py (Tier 2) provides +# direct per-adapter hook-firing coverage without needing an HT counterpart. MODELS_TO_TEST = [ - "gpt2", # Standard decoder-only with joint QKV - "EleutherAI/pythia-14m", # GPT-NeoX architecture (smaller than pythia-70m) + "gpt2", # JointQKVAttentionBridge (standard decoder-only) + "EleutherAI/pythia-14m", # ParallelBlockBridge (C15 regression guard) ] # Gemma2: local only (too large for CI) diff --git a/tests/unit/model_bridge/test_component_hooks_fire.py b/tests/unit/model_bridge/test_component_hooks_fire.py new file mode 100644 index 000000000..8c9ce751b --- /dev/null +++ b/tests/unit/model_bridge/test_component_hooks_fire.py @@ -0,0 +1,175 @@ +"""Regression test: attention hooks fire on forward (C1 guard). + +Complements the alias-resolution test: aliases can resolve yet the HookPoint +may never fire if forward bypasses the LinearBridge (the original C1 bug). +Isolates blocks.0.attn per PositionEmbeddingsAttentionBridge adapter with a +synthetic HF module and asserts hook_q/k/v/z all fire. 
+""" + +from __future__ import annotations + +from typing import Any + +import pytest +import torch +import torch.nn as nn + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, +) +from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import ( + PositionEmbeddingsAttentionBridge, +) + + +def _stub_cfg(architecture: str, **kw: Any) -> TransformerBridgeConfig: + return TransformerBridgeConfig( + d_model=64, + d_head=16, + n_layers=2, + n_ctx=128, + n_heads=4, + d_vocab=1000, + d_mlp=256, + n_key_value_heads=kw.get("n_key_value_heads", 4), + default_prepend_bos=True, + architecture=architecture, + ) + + +def _position_embeddings_adapters() -> list[str]: + """Adapters using PositionEmbeddingsAttentionBridge directly (not subclasses).""" + results: list[str] = [] + for arch, adapter_cls in sorted(SUPPORTED_ARCHITECTURES.items()): + try: + adapter = adapter_cls(_stub_cfg(arch)) + except Exception: + continue + mapping = adapter.component_mapping + if mapping is None or "blocks" not in mapping: + continue + attn = (mapping["blocks"].submodules or {}).get("attn") + # type() check excludes JointQKV subclasses. + if attn is not None and type(attn) is PositionEmbeddingsAttentionBridge: + results.append(arch) + return results + + +class _FakeHFAttn(nn.Module): + """Synthetic HF attention module with q/k/v/o + optional QK-norms.""" + + def __init__( + self, + d_model: int, + n_heads: int, + head_dim: int, + n_kv_heads: int, + with_q_norm: bool, + with_k_norm: bool, + ) -> None: + super().__init__() + self.head_dim = head_dim + self.num_key_value_groups = n_heads // n_kv_heads if n_kv_heads else 1 + self.scaling = head_dim**-0.5 + self.attention_dropout = 0.0 + self.layer_idx = 0 + self.q_proj = nn.Linear(d_model, n_heads * head_dim, bias=False) + self.k_proj = nn.Linear(d_model, n_kv_heads * head_dim, bias=False) + self.v_proj = nn.Linear(d_model, n_kv_heads * head_dim, bias=False) + self.o_proj = nn.Linear(n_heads * head_dim, d_model, bias=False) + if with_q_norm: + self.q_norm = nn.LayerNorm(head_dim, elementwise_affine=True) + if with_k_norm: + self.k_norm = nn.LayerNorm(head_dim, elementwise_affine=True) + + +def _make_fake_hf_attn( + attn_bridge: PositionEmbeddingsAttentionBridge, + d_model: int, + n_heads: int, + head_dim: int, + n_kv_heads: int, +) -> nn.Module: + return _FakeHFAttn( + d_model=d_model, + n_heads=n_heads, + head_dim=head_dim, + n_kv_heads=n_kv_heads, + with_q_norm="q_norm" in attn_bridge.submodules, + with_k_norm="k_norm" in attn_bridge.submodules, + ) + + +def _wire_attention_submodules( + attn_bridge: PositionEmbeddingsAttentionBridge, fake_hf_attn: nn.Module +) -> None: + """Mirror setup_components: set original_component + add_module for each sub.""" + for name, sub in attn_bridge.submodules.items(): + hf_sub = getattr(fake_hf_attn, name + "_proj", None) + if hf_sub is None: + hf_sub = getattr(fake_hf_attn, name, None) + if hf_sub is None: + raise RuntimeError(f"fake HF attn missing '{name}' (tried {name}_proj, {name})") + sub.set_original_component(hf_sub) + if name not in attn_bridge._modules: + attn_bridge.add_module(name, sub) + + +_CRITICAL_HOOKS = {"hook_q", "hook_k", "hook_v", "hook_z"} + + +@pytest.mark.parametrize("architecture", _position_embeddings_adapters()) +def test_attention_critical_hooks_fire_on_forward(architecture: str) -> None: + """Assert hook_q/k/v/z fire during attention forward (C1 regression guard).""" + 
adapter_cls = SUPPORTED_ARCHITECTURES[architecture] + adapter = adapter_cls(_stub_cfg(architecture)) + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert isinstance(attn, PositionEmbeddingsAttentionBridge) + + d_model = adapter.cfg.d_model + n_heads = adapter.cfg.n_heads + head_dim = d_model // n_heads + n_kv_heads = getattr(adapter.cfg, "n_key_value_heads", n_heads) or n_heads + + fake_hf_attn = _make_fake_hf_attn(attn, d_model, n_heads, head_dim, n_kv_heads) + try: + attn.set_original_component(fake_hf_attn) + except RuntimeError as e: + pytest.skip(f"{architecture}: cannot wire synthetic HF attn ({e})") + + _wire_attention_submodules(attn, fake_hf_attn) + + fired: set[str] = set() + handles = [] + for hname, hp in attn.get_hooks().items(): + h = hp.add_hook(lambda t, hook, n=hname: (fired.add(n), t)[1]) + handles.append(h) + + # Record alias-targeted hooks under their alias name (hook_q -> q.hook_out). + for alias_name, target in attn.hook_aliases.items(): + if isinstance(target, str): + try: + obj: Any = attn + for part in target.split("."): + obj = ( + obj.submodules.get(part) + if hasattr(obj, "submodules") and part in obj.submodules + else getattr(obj, part) + ) + if hasattr(obj, "add_hook"): + obj.add_hook(lambda t, hook, n=alias_name: (fired.add(n), t)[1]) + except AttributeError: + pass + + batch, seq = 1, 4 + hidden = torch.randn(batch, seq, d_model) + cos = torch.ones(1, seq, head_dim) + sin = torch.zeros(1, seq, head_dim) + attn(hidden_states=hidden, position_embeddings=(cos, sin), attention_mask=None) + + missing = _CRITICAL_HOOKS - fired + assert not missing, ( + f"{architecture}: critical attention hooks did not fire: {sorted(missing)}. " + f"Fired: {sorted(fired)}" + ) diff --git a/tests/unit/model_bridge/test_component_setup.py b/tests/unit/model_bridge/test_component_setup.py index a533ccfc1..09cc57883 100644 --- a/tests/unit/model_bridge/test_component_setup.py +++ b/tests/unit/model_bridge/test_component_setup.py @@ -312,6 +312,7 @@ def split_qkv(self, component): name="blocks", submodules={ "ln1": NormalizationBridge(name="ln1", config={}), + "ln2": NormalizationBridge(name="ln2", config={}), "attn": JointQKVAttentionBridge( name="attn", config=SimpleNamespace(n_heads=1, d_model=10), diff --git a/tests/unit/model_bridge/test_hook_alias_resolution.py b/tests/unit/model_bridge/test_hook_alias_resolution.py new file mode 100644 index 000000000..ba1cfba31 --- /dev/null +++ b/tests/unit/model_bridge/test_hook_alias_resolution.py @@ -0,0 +1,130 @@ +"""Regression test: every adapter's hook aliases resolve to a real HookPoint. + +Catches bugs where an alias target path doesn't navigate to a HookPoint +(the complementary case to Tier 2, which catches aliases that resolve but are +bypassed by forward). Stub cfg only — no HF model load. 
+""" + +from __future__ import annotations + +from typing import Any, Iterable, Tuple + +import pytest + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, +) +from transformer_lens.hook_points import HookPoint +from transformer_lens.model_bridge.generalized_components.base import ( + GeneralizedComponent, +) + + +def _stub_cfg(architecture: str) -> TransformerBridgeConfig: + """Minimal cfg for adapter instantiation; small values keep stubs cheap.""" + return TransformerBridgeConfig( + d_model=64, + d_head=16, + n_layers=2, + n_ctx=128, + n_heads=4, + d_vocab=1000, + d_mlp=256, + n_key_value_heads=4, + default_prepend_bos=True, + architecture=architecture, + ) + + +def _iter_components( + root: Any, path: str = "" +) -> Iterable[Tuple[str, GeneralizedComponent]]: + """Walk component_mapping recursively, yielding (dotted-path, component).""" + if isinstance(root, dict): + for name, comp in root.items(): + yield from _iter_components(comp, f"{path}.{name}" if path else name) + return + if isinstance(root, GeneralizedComponent): + yield path, root + for name, sub in (root.submodules or {}).items(): + yield from _iter_components(sub, f"{path}.{name}") + + +def _resolve(component: GeneralizedComponent, target: str) -> Any: + """Resolve dotted alias using submodules dict — pre-model-load templates + don't yet have submodules registered via add_module().""" + obj: Any = component + for part in target.split("."): + nxt = None + if isinstance(obj, GeneralizedComponent): + nxt = (obj.submodules or {}).get(part) + if nxt is None: + nxt = getattr(obj, part, None) + if nxt is None: + raise AttributeError(part) + obj = nxt + return obj + + +# xfail(strict=True) so future fixes XPASS and force the marker to be removed. +# Each entry maps to a specific audit finding deferred from the C1+C15 PR. 
+_KNOWN_DEAD_ALIASES = { + "GPT2LMHeadCustomModel": "audit H27 — stale adapter, delete candidate", + "NanoGPTForCausalLM": "audit H28 — broken weight conversion, delete candidate", + "NeelSoluOldForCausalLM": "audit H28 — orphan weight conversion, delete candidate", + "LlavaForConditionalGeneration": "audit H15 — vision-encoder layer submodules unwired", + "LlavaNextForConditionalGeneration": "audit H15 + M24 — vision encoder + tiling opaque", + "LlavaOnevisionForConditionalGeneration": "audit H15 + M25 — vision encoder + video frames opaque", + "Gemma3ForConditionalGeneration": "audit H15 — multimodal vision encoder opaque", + "OpenELMForCausalLM": "audit H23 — per-layer head counts break uniform q/k/v shape", + "GraniteMoeHybridForCausalLM": "new finding — MoE+shared-MLP block lacks proper submodule aliases", +} + + +def _architecture_params(): + """Parametrize list with xfail markers for known-dead-alias adapters.""" + params = [] + for arch in sorted(SUPPORTED_ARCHITECTURES): + reason = _KNOWN_DEAD_ALIASES.get(arch) + if reason is not None: + params.append(pytest.param(arch, marks=pytest.mark.xfail(strict=True, reason=reason))) + else: + params.append(arch) + return params + + +@pytest.mark.parametrize("architecture", _architecture_params()) +def test_every_hook_alias_resolves_to_hookpoint(architecture: str) -> None: + """Every declared hook_aliases entry must resolve to a HookPoint.""" + adapter_cls = SUPPORTED_ARCHITECTURES[architecture] + try: + adapter = adapter_cls(_stub_cfg(architecture)) + except Exception as exc: + pytest.skip(f"Adapter {architecture} cannot instantiate with stub cfg: {exc}") + + mapping = adapter.component_mapping + if mapping is None: + pytest.skip(f"Adapter {architecture} has no component_mapping") + + failures: list[str] = [] + for path, component in _iter_components(mapping): + for alias_name, target in component.hook_aliases.items(): + targets = target if isinstance(target, list) else [target] + resolved = False + for single in targets: + try: + obj = _resolve(component, single) + except AttributeError: + continue + if isinstance(obj, HookPoint): + resolved = True + break + if not resolved: + failures.append( + f"{path}.{alias_name} -> {target} (type at path: unresolved)" + ) + assert not failures, ( + f"Architecture {architecture}: {len(failures)} dead hook aliases:\n " + + "\n ".join(failures) + ) diff --git a/transformer_lens/model_bridge/generalized_components/__init__.py b/transformer_lens/model_bridge/generalized_components/__init__.py index c2c7a121b..a84ecebbf 100644 --- a/transformer_lens/model_bridge/generalized_components/__init__.py +++ b/transformer_lens/model_bridge/generalized_components/__init__.py @@ -5,7 +5,10 @@ from transformer_lens.model_bridge.generalized_components.audio_feature_extractor import ( AudioFeatureExtractorBridge, ) -from transformer_lens.model_bridge.generalized_components.block import BlockBridge +from transformer_lens.model_bridge.generalized_components.block import ( + BlockBridge, + ParallelBlockBridge, +) from transformer_lens.model_bridge.generalized_components.bloom_attention import ( BloomAttentionBridge, ) @@ -94,6 +97,7 @@ "AttentionBridge", "AudioFeatureExtractorBridge", "BlockBridge", + "ParallelBlockBridge", "BloomBlockBridge", "BloomAttentionBridge", "CodeGenAttentionBridge", diff --git a/transformer_lens/model_bridge/generalized_components/base.py b/transformer_lens/model_bridge/generalized_components/base.py index 12be7b9c6..20e44fbbb 100644 --- 
a/transformer_lens/model_bridge/generalized_components/base.py +++ b/transformer_lens/model_bridge/generalized_components/base.py @@ -2,6 +2,7 @@ from __future__ import annotations import inspect +import warnings from collections.abc import Callable from typing import Any, Dict, List, Optional, Union @@ -91,24 +92,37 @@ def _register_aliases(self) -> None: if self.property_aliases: self._property_alias_registry.update(self.property_aliases) for alias_name, target_path in self._hook_alias_registry.items(): - try: - if isinstance(target_path, list): - for single_target in target_path: - try: - target_obj = self - for part in single_target.split("."): - target_obj = getattr(target_obj, part) - object.__setattr__(self, alias_name, target_obj) - break - except AttributeError: - continue - else: + resolved = False + if isinstance(target_path, list): + for single_target in target_path: + try: + target_obj = self + for part in single_target.split("."): + target_obj = getattr(target_obj, part) + object.__setattr__(self, alias_name, target_obj) + resolved = True + break + except AttributeError: + continue + else: + try: target_obj = self for part in target_path.split("."): target_obj = getattr(target_obj, part) object.__setattr__(self, alias_name, target_obj) - except AttributeError: - pass + resolved = True + except AttributeError: + pass + if not resolved: + # Surface drops instead of silently swallowing — some aliases are + # legitimately conditional on optional submodules, but an author + # needs to see which ones dropped at bridge-init. + warnings.warn( + f"Hook alias '{alias_name}' -> '{target_path}' on " + f"{type(self).__name__}(name={getattr(self, 'name', None)!r}) " + f"did not resolve; this hook will not be accessible.", + stacklevel=2, + ) for alias_name, target_path in self._property_alias_registry.items(): try: target_obj = self diff --git a/transformer_lens/model_bridge/generalized_components/block.py b/transformer_lens/model_bridge/generalized_components/block.py index e6cd0d71f..9c1548e78 100644 --- a/transformer_lens/model_bridge/generalized_components/block.py +++ b/transformer_lens/model_bridge/generalized_components/block.py @@ -65,21 +65,32 @@ def __init__( For example, {"hook_attn_out": "ln1_post.hook_out"} will make hook_attn_out point to ln1_post.hook_out instead of the default attn.hook_out. """ - # Apply automatic aliases based on submodules before calling parent - # This allows submodule-based aliases to be combined with explicit overrides + # ln1_post/ln2_post redirect attn_out/mlp_out to match HookedTransformer's + # placement (hook fires after the post-norm, not before). auto_overrides = {} if submodules is not None: - # If ln1_post exists, hook_attn_out should point to it instead of attn.hook_out - # This matches HookedTransformer behavior where ln1_post is applied before hook_attn_out if "ln1_post" in submodules: auto_overrides["hook_attn_out"] = "ln1_post.hook_out" - # If ln2_post exists, hook_mlp_out should point to it instead of mlp.hook_out if "ln2_post" in submodules: auto_overrides["hook_mlp_out"] = "ln2_post.hook_out" - - # Merge automatic and explicit overrides (explicit takes precedence) merged_overrides = {**auto_overrides, **(hook_alias_overrides or {})} + # Guard against the C15 bug class: sequential transformer block (attn + + # mlp) with no ln2 would silently point hook_resid_mid at the wrong + # tensor. Use ParallelBlockBridge for parallel-residual architectures. + # Skip the check on generic-container / attn-only uses (no mlp). 
+ has_attn_like = submodules is not None and any( + k in submodules for k in _VARIANT_SUBMODULE_SET + ) + has_mlp = submodules is not None and "mlp" in submodules + has_ln2 = submodules is not None and "ln2" in submodules + if has_attn_like and has_mlp and not has_ln2 and type(self) is BlockBridge: + raise ValueError( + f"BlockBridge at '{name}': 'ln2' submodule not declared. " + f"Either declare ln2, or use ParallelBlockBridge for a " + f"parallel-residual architecture." + ) + # Call parent with merged overrides super().__init__( name, @@ -87,6 +98,7 @@ def __init__( submodules=submodules if submodules is not None else {}, hook_alias_overrides=merged_overrides if merged_overrides else None, ) + self._original_block_forward: Optional[Callable[..., Any]] = None def forward(self, *args: Any, **kwargs: Any) -> Any: @@ -226,3 +238,33 @@ def _filter_kwargs_for_forward( # If we can't inspect the signature, pass through all kwargs # (better to potentially fail than to silently drop important params) return kwargs + + +class ParallelBlockBridge(BlockBridge): + """Block where attn and MLP both read the pre-attention residual. + + For GPT-J, NeoX, Pythia, Phi, Cohere, CodeGen, and some Falcon variants, + output = resid_pre + attn_out + mlp_out — no distinct post-attention + residual exists. Matches legacy HookedTransformer which omits hook_resid_mid + when ``cfg.parallel_attn_mlp=True``. Type-level distinction means a reader + of the adapter sees ``ParallelBlockBridge`` and knows the hook is absent. + """ + + def __init__( + self, + name: str, + config: Optional[Any] = None, + submodules: Optional[Dict[str, GeneralizedComponent]] = None, + hook_alias_overrides: Optional[Dict[str, str]] = None, + ): + super().__init__( + name, + config=config, + submodules=submodules, + hook_alias_overrides=hook_alias_overrides, + ) + # Ensure instance-level copy before mutating; base may have left the + # class-level dict shared when no overrides were passed. + if self.hook_aliases is type(self).__mro__[1].hook_aliases: + self.hook_aliases = dict(self.hook_aliases) + self.hook_aliases.pop("hook_resid_mid", None) diff --git a/transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py b/transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py index 135ab0d17..3115364ec 100644 --- a/transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py +++ b/transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py @@ -136,24 +136,73 @@ def __init__( self._init_position_embedding_hooks() if getattr(config, "gated_q_proj", False): self.hook_q_gate = HookPoint() + # Gate on adapter intent; HF-vs-adapter mismatches surface in set_original_component. + if submodules is not None and "q_norm" in submodules: + self.hook_q_normed = HookPoint() + if submodules is not None and "k_norm" in submodules: + self.hook_k_normed = HookPoint() + self._qk_norm_phase: Optional[str] = None def set_original_component(self, component: torch.nn.Module) -> None: - """Set the original HF component and register for rotary hook firing. - - This overrides the base class method to also: - 1. Register this bridge in the global registry (for hook_rot_q/hook_rot_k) - 2. 
Set up the eager_attention_forward wrapper if not already done - - Args: - component: The HuggingFace attention module - """ + """Wire HF module, register for rotary hooks, validate adapter declarations.""" super().set_original_component(component) - - # Register this bridge instance so the wrapped eager_attention_forward can find it _ATTENTION_BRIDGE_REGISTRY[id(component)] = self - - # Ensure the wrapper is set up _setup_eager_attention_hook_wrapper() + self._validate_submodule_declarations(component) + self._qk_norm_phase = self._decide_qk_norm_phase(component) + + def _validate_submodule_declarations(self, hf_attn: torch.nn.Module) -> None: + """Raise if adapter omits q/k/v/o or a QK-norm the HF module has.""" + # Silent fallback to raw HF linears is exactly what caused hook_q/k/v/z + # to never fire on 25 adapters; require explicit declaration. + missing = [req for req in ("q", "k", "v", "o") if req not in self.submodules] + if missing: + raise RuntimeError( + f"{type(self).__name__} at '{self.name}' is missing required " + f"submodules: {missing}. Declare them in the adapter's " + f"component_mapping, e.g. submodules={{'q': LinearBridge(name='q_proj'), " + f"'k': LinearBridge(name='k_proj'), 'v': LinearBridge(name='v_proj'), " + f"'o': LinearBridge(name='o_proj')}}." + ) + # Reverse mismatch (adapter declares, HF lacks) surfaces at norm forward. + for norm_name in ("q_norm", "k_norm"): + if getattr(hf_attn, norm_name, None) is not None and norm_name not in self.submodules: + raise RuntimeError( + f"{type(self).__name__} at '{self.name}': HF module has " + f"'{norm_name}' but adapter did not declare it. Forward would " + f"skip the norm, producing wrong logits vs HF. Add " + f"'{norm_name}': RMSNormalizationBridge(name='{norm_name}', " + f"config=self.cfg) to the attention submodules." + ) + + def _decide_qk_norm_phase(self, hf_attn: torch.nn.Module) -> Optional[str]: + """Dispatch pre/post-reshape norm from weight shape; raise on ambiguity.""" + if "q_norm" not in self.submodules: + return None + q_norm = getattr(hf_attn, "q_norm", None) + if q_norm is None: + raise RuntimeError(f"{self.name}: q_norm declared but HF module has none.") + + weight = getattr(q_norm, "weight", None) + head_dim = int(hf_attn.head_dim) + n_heads = int(getattr(self.config, "n_heads", 0)) + + # Non-learnable norm (Gemma-3 style) broadcasts over head_dim. + if weight is None or weight.ndim == 0: + return "post_reshape" + shape = tuple(weight.shape) + if shape == (head_dim,): + return "post_reshape" + if n_heads and shape == (n_heads * head_dim,): + return "pre_reshape" + # Per-head norm (Cohere) broadcasts on the reshaped [B,H,S,D] tensor. + if n_heads and shape == (n_heads, head_dim): + return "post_reshape" + raise RuntimeError( + f"{self.name}: cannot determine QK-norm phase from q_norm weight " + f"shape {shape} (head_dim={head_dim}, n_heads={n_heads}). Expected " + f"(head_dim,), (n_heads*head_dim,), or (n_heads, head_dim)." + ) def forward(self, *args: Any, **kwargs: Any) -> Any: """Reimplemented forward pass with hooks at correct computation stages. 
@@ -200,21 +249,19 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: if target_dtype is not None and hidden_states.is_floating_point(): hidden_states = hidden_states.to(dtype=target_dtype) - # --- Q/K/V Projection + Optional Q/K Norms --- - # Detect norm order: pre-reshape (OLMo 2) vs post-reshape (Gemma 3) input_shape = hidden_states.shape[:-1] head_dim = hf_attn.head_dim hidden_shape = (*input_shape, -1, head_dim) - query_states = hf_attn.q_proj(hidden_states) - key_states = hf_attn.k_proj(hidden_states) - value_states = hf_attn.v_proj(hidden_states) + # Route through LinearBridges so hook_q/k/v/z (aliased to q/k/v.hook_out, + # o.hook_in) fire on the live path. + query_states = self.q(hidden_states) + key_states = self.k(hidden_states) + value_states = self.v(hidden_states) - # Gated q_proj (Qwen3.5/Qwen3Next): q_proj outputs [Q|gate] interleaved - # per head. cfg.gated_q_proj is set by the adapter. The actual split only - # triggers if the output is 2x the standard width (n_heads * head_dim). - # In processed mode, preprocess_weights slices q_proj to standard width - # so this naturally passes through. + # Qwen3.5/Qwen3-Next interleave [Q|gate] per head in q_proj output. + # Processed-weights mode slices q_proj to standard width beforehand, so + # the 2x-width path only triggers on unprocessed state dicts. q_gate = None if getattr(self.config, "gated_q_proj", False): q_dim = query_states.shape[-1] @@ -227,28 +274,24 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: q_gate = q_gate.reshape(*input_shape, -1) query_states = query_states.reshape(*input_shape, -1) - has_q_norm = hasattr(hf_attn, "q_norm") and hf_attn.q_norm is not None - has_k_norm = hasattr(hf_attn, "k_norm") and hf_attn.k_norm is not None - applied_pre_reshape_norm = False + has_q_norm = "q_norm" in self.submodules + has_k_norm = "k_norm" in self.submodules - if has_q_norm: - try: - query_states = hf_attn.q_norm(query_states) - if has_k_norm: - key_states = hf_attn.k_norm(key_states) - applied_pre_reshape_norm = True - except RuntimeError: - pass + # Pre-reshape phase (OLMo-2): norm on [B, S, H*D]. + if has_q_norm and self._qk_norm_phase == "pre_reshape": + query_states = self.hook_q_normed(self.q_norm(query_states)) + if has_k_norm: + key_states = self.hook_k_normed(self.k_norm(key_states)) query_states = query_states.view(hidden_shape).transpose(1, 2) key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - if has_q_norm and not applied_pre_reshape_norm: - # Post-reshape norm (Gemma 3 style: norm on [batch, heads, seq, head_dim]) - query_states = hf_attn.q_norm(query_states) - if has_k_norm and not applied_pre_reshape_norm: - key_states = hf_attn.k_norm(key_states) + # Post-reshape phase (Gemma-3/Cohere): norm on [B, H, S, D]. 
+ if has_q_norm and self._qk_norm_phase == "post_reshape": + query_states = self.hook_q_normed(self.q_norm(query_states)) + if has_k_norm: + key_states = self.hook_k_normed(self.k_norm(key_states)) # --- RoPE --- if position_embeddings is not None: @@ -336,14 +379,9 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: q_gate = self.hook_q_gate(q_gate) attn_output = attn_output * torch.sigmoid(q_gate) - # --- Output Projection --- - # Different architectures name this differently: o_proj (Llama, Gemma, Qwen), - # dense (Phi), out_proj (others) - o_proj = getattr(hf_attn, "o_proj", None) or getattr(hf_attn, "dense", None) - if o_proj is not None: - attn_output = o_proj(attn_output) - - # --- Output Hook --- + # Route through LinearBridge so hook_z (aliased to o.hook_in) fires. + # LinearBridge wraps whichever HF attr the adapter mapped (o_proj, dense, out_proj). + attn_output = self.o(attn_output) attn_output = self.hook_out(attn_output) return attn_output, attn_weights diff --git a/transformer_lens/model_bridge/supported_architectures/codegen.py b/transformer_lens/model_bridge/supported_architectures/codegen.py index c385833ae..f8ed83544 100644 --- a/transformer_lens/model_bridge/supported_architectures/codegen.py +++ b/transformer_lens/model_bridge/supported_architectures/codegen.py @@ -10,12 +10,12 @@ ) from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter from transformer_lens.model_bridge.generalized_components import ( - BlockBridge, CodeGenAttentionBridge, EmbeddingBridge, LinearBridge, MLPBridge, NormalizationBridge, + ParallelBlockBridge, UnembeddingBridge, ) @@ -68,7 +68,7 @@ def __init__(self, cfg: Any) -> None: self.component_mapping = { "embed": EmbeddingBridge(name="transformer.wte"), - "blocks": BlockBridge( + "blocks": ParallelBlockBridge( name="transformer.h", submodules={ "ln1": NormalizationBridge(name="ln_1", config=self.cfg), diff --git a/transformer_lens/model_bridge/supported_architectures/cohere.py b/transformer_lens/model_bridge/supported_architectures/cohere.py index f7550cc61..5dcb74149 100644 --- a/transformer_lens/model_bridge/supported_architectures/cohere.py +++ b/transformer_lens/model_bridge/supported_architectures/cohere.py @@ -16,11 +16,11 @@ from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter from transformer_lens.model_bridge.generalized_components import ( - BlockBridge, EmbeddingBridge, GatedMLPBridge, LinearBridge, NormalizationBridge, + ParallelBlockBridge, PositionEmbeddingsAttentionBridge, RotaryEmbeddingBridge, UnembeddingBridge, @@ -119,7 +119,7 @@ def __init__(self, cfg: Any) -> None: # Rotary embedding: top-level, delegates to CohereRotaryEmbedding. # Pattern matches llama.py:75 and falcon.py:154 — NOT inside blocks. "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg), - "blocks": BlockBridge( + "blocks": ParallelBlockBridge( name="model.layers", submodules={ # Single pre-norm only — Cohere has no post_attention_layernorm. 
diff --git a/transformer_lens/model_bridge/supported_architectures/falcon.py b/transformer_lens/model_bridge/supported_architectures/falcon.py index e552a07f4..317bbd2a8 100644 --- a/transformer_lens/model_bridge/supported_architectures/falcon.py +++ b/transformer_lens/model_bridge/supported_architectures/falcon.py @@ -23,6 +23,7 @@ LinearBridge, MLPBridge, NormalizationBridge, + ParallelBlockBridge, RotaryEmbeddingBridge, UnembeddingBridge, ) @@ -143,9 +144,11 @@ def __init__(self, cfg: Any) -> None: elif self._is_new_arch and getattr(cfg, "num_ln_in_parallel_attn", None) == 2: block_submodules["ln2"] = NormalizationBridge(name="ln_mlp", config=self.cfg) + # Falcon has both parallel (most checkpoints) and sequential variants. + block_cls = ParallelBlockBridge if is_parallel else BlockBridge self.component_mapping: dict[str, Any] = { "embed": EmbeddingBridge(name="transformer.word_embeddings"), - "blocks": BlockBridge(name="transformer.h", submodules=block_submodules), + "blocks": block_cls(name="transformer.h", submodules=block_submodules), "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg), "unembed": UnembeddingBridge(name="lm_head"), } diff --git a/transformer_lens/model_bridge/supported_architectures/gptj.py b/transformer_lens/model_bridge/supported_architectures/gptj.py index a717eabee..10d8fd4b7 100644 --- a/transformer_lens/model_bridge/supported_architectures/gptj.py +++ b/transformer_lens/model_bridge/supported_architectures/gptj.py @@ -9,7 +9,7 @@ from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter from transformer_lens.model_bridge.generalized_components import ( AttentionBridge, - BlockBridge, + ParallelBlockBridge, EmbeddingBridge, LinearBridge, MLPBridge, @@ -31,6 +31,7 @@ def __init__(self, cfg: Any) -> None: self.cfg.final_rms = False self.cfg.gated_mlp = False self.cfg.attn_only = False + self.cfg.parallel_attn_mlp = True self.weight_processing_conversions = { "blocks.{i}.attn.q.weight": ParamProcessingConversion( @@ -49,7 +50,7 @@ def __init__(self, cfg: Any) -> None: self.component_mapping = { "embed": EmbeddingBridge(name="transformer.wte"), - "blocks": BlockBridge( + "blocks": ParallelBlockBridge( name="transformer.h", submodules={ "ln1": NormalizationBridge(name="ln_1", config=self.cfg), diff --git a/transformer_lens/model_bridge/supported_architectures/neox.py b/transformer_lens/model_bridge/supported_architectures/neox.py index d4e2b078f..1ab11150b 100644 --- a/transformer_lens/model_bridge/supported_architectures/neox.py +++ b/transformer_lens/model_bridge/supported_architectures/neox.py @@ -16,9 +16,9 @@ ) from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter from transformer_lens.model_bridge.generalized_components import ( - BlockBridge, EmbeddingBridge, JointQKVPositionEmbeddingsAttentionBridge, + ParallelBlockBridge, LinearBridge, MLPBridge, NormalizationBridge, @@ -44,6 +44,7 @@ def __init__(self, cfg: Any) -> None: self.cfg.final_rms = False self.cfg.gated_mlp = False self.cfg.attn_only = False + self.cfg.parallel_attn_mlp = True # NeoX/Pythia models were not trained with BOS tokens self.cfg.default_prepend_bos = False @@ -137,7 +138,7 @@ def __init__(self, cfg: Any) -> None: self.component_mapping = { "embed": EmbeddingBridge(name="gpt_neox.embed_in"), "rotary_emb": RotaryEmbeddingBridge(name="gpt_neox.rotary_emb"), - "blocks": BlockBridge( + "blocks": ParallelBlockBridge( name="gpt_neox.layers", submodules={ "ln1": NormalizationBridge( diff --git 
a/transformer_lens/model_bridge/supported_architectures/phi.py b/transformer_lens/model_bridge/supported_architectures/phi.py index 205a20708..a9921ffee 100644 --- a/transformer_lens/model_bridge/supported_architectures/phi.py +++ b/transformer_lens/model_bridge/supported_architectures/phi.py @@ -8,11 +8,11 @@ ) from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter from transformer_lens.model_bridge.generalized_components import ( - BlockBridge, EmbeddingBridge, LinearBridge, MLPBridge, NormalizationBridge, + ParallelBlockBridge, PositionEmbeddingsAttentionBridge, RotaryEmbeddingBridge, UnembeddingBridge, @@ -70,7 +70,7 @@ def __init__(self, cfg: Any) -> None: self.component_mapping = { "embed": EmbeddingBridge(name="model.embed_tokens"), "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"), - "blocks": BlockBridge( + "blocks": ParallelBlockBridge( name="model.layers", submodules={ "ln1": NormalizationBridge( diff --git a/transformer_lens/model_bridge/supported_architectures/pythia.py b/transformer_lens/model_bridge/supported_architectures/pythia.py index 7d4d9c924..c04d61a4e 100644 --- a/transformer_lens/model_bridge/supported_architectures/pythia.py +++ b/transformer_lens/model_bridge/supported_architectures/pythia.py @@ -16,12 +16,12 @@ ) from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter from transformer_lens.model_bridge.generalized_components import ( - BlockBridge, EmbeddingBridge, JointQKVPositionEmbeddingsAttentionBridge, LinearBridge, MLPBridge, NormalizationBridge, + ParallelBlockBridge, RotaryEmbeddingBridge, UnembeddingBridge, ) @@ -38,8 +38,8 @@ def __init__(self, cfg: Any) -> None: """ super().__init__(cfg) self.cfg.positional_embedding_type = "rotary" - # Pythia wasn't trained with BOS tokens, so match HuggingFace behavior - self.cfg.default_prepend_bos = False + self.cfg.parallel_attn_mlp = True # GPT-NeoX: attn + MLP both read resid_pre + self.cfg.default_prepend_bos = False # Pythia wasn't trained with BOS self.weight_processing_conversions = { "blocks.{i}.attn.q": ParamProcessingConversion( @@ -130,7 +130,7 @@ def __init__(self, cfg: Any) -> None: self.component_mapping = { "embed": EmbeddingBridge(name="gpt_neox.embed_in"), "rotary_emb": RotaryEmbeddingBridge(name="gpt_neox.rotary_emb", config=self.cfg), - "blocks": BlockBridge( + "blocks": ParallelBlockBridge( name="gpt_neox.layers", submodules={ "ln1": NormalizationBridge(name="input_layernorm", config=self.cfg), diff --git a/transformer_lens/tools/model_registry/data/supported_models.json b/transformer_lens/tools/model_registry/data/supported_models.json index fdce49a70..8a01b6aa7 100644 --- a/transformer_lens/tools/model_registry/data/supported_models.json +++ b/transformer_lens/tools/model_registry/data/supported_models.json @@ -165,7 +165,7 @@ "architecture_id": "Qwen3ForCausalLM", "model_id": "Qwen/Qwen3-0.6B", "status": 1, - "verified_date": "2026-03-30", + "verified_date": "2026-04-15", "metadata": null, "note": "Full verification completed", "phase1_score": 100.0, @@ -179,7 +179,7 @@ "architecture_id": "GPT2LMHeadModel", "model_id": "openai-community/gpt2", "status": 1, - "verified_date": "2026-04-07", + "verified_date": "2026-04-15", "metadata": null, "note": "Full verification completed", "phase1_score": 100.0, @@ -879,13 +879,13 @@ "architecture_id": "GPTNeoXForCausalLM", "model_id": "EleutherAI/pythia-70m-deduped", "status": 1, - "verified_date": "2026-03-30", + "verified_date": "2026-04-15", "metadata": null, "note": "Full verification 
completed", "phase1_score": 100.0, "phase2_score": 100.0, "phase3_score": 100.0, - "phase4_score": 77.5, + "phase4_score": 71.8, "phase7_score": null, "phase8_score": null }, From ca70d7b0ab98e192f1ad4dcce68c056451ab81a5 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Wed, 15 Apr 2026 14:13:02 -0500 Subject: [PATCH 2/3] verification --- .../model_registry/data/supported_models.json | 8 +-- .../data/verification_history.json | 62 ++++++++++++++++++- 2 files changed, 65 insertions(+), 5 deletions(-) diff --git a/transformer_lens/tools/model_registry/data/supported_models.json b/transformer_lens/tools/model_registry/data/supported_models.json index 8a01b6aa7..4a4bad75a 100644 --- a/transformer_lens/tools/model_registry/data/supported_models.json +++ b/transformer_lens/tools/model_registry/data/supported_models.json @@ -2657,7 +2657,7 @@ "architecture_id": "Olmo2ForCausalLM", "model_id": "allenai/OLMo-2-0425-1B", "status": 1, - "verified_date": "2026-03-30", + "verified_date": "2026-04-15", "metadata": null, "note": "Full verification completed", "phase1_score": 100.0, @@ -3035,13 +3035,13 @@ "architecture_id": "Gemma3ForCausalLM", "model_id": "EssentialAI/rnj-1-instruct", "status": 1, - "verified_date": "2026-03-11", + "verified_date": "2026-04-15", "metadata": null, "note": "Full verification completed", "phase1_score": 100.0, "phase2_score": 100.0, "phase3_score": 100.0, - "phase4_score": 94.9, + "phase4_score": 93.5, "phase7_score": null, "phase8_score": null }, @@ -97992,7 +97992,7 @@ "architecture_id": "CohereForCausalLM", "model_id": "trl-internal-testing/tiny-CohereForCausalLM", "status": 1, - "verified_date": "2026-04-10", + "verified_date": "2026-04-15", "metadata": { "downloads": 120449, "total_params": 2042176 diff --git a/transformer_lens/tools/model_registry/data/verification_history.json b/transformer_lens/tools/model_registry/data/verification_history.json index dc48d675e..00c7d9b09 100644 --- a/transformer_lens/tools/model_registry/data/verification_history.json +++ b/transformer_lens/tools/model_registry/data/verification_history.json @@ -1,5 +1,5 @@ { - "last_updated": "2026-04-15T09:15:26.792099", + "last_updated": "2026-04-15T14:01:45.090788", "records": [ { "model_id": "Macropodus/macbert4mdcspell_v1", @@ -11510,6 +11510,66 @@ "notes": "Full verification completed", "invalidated": false, "invalidation_reason": null + }, + { + "model_id": "EleutherAI/pythia-70m-deduped", + "architecture_id": "GPTNeoXForCausalLM", + "verified_date": "2026-04-15", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "Qwen/Qwen3-0.6B", + "architecture_id": "Qwen3ForCausalLM", + "verified_date": "2026-04-15", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "openai-community/gpt2", + "architecture_id": "GPT2LMHeadModel", + "verified_date": "2026-04-15", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "allenai/OLMo-2-0425-1B", + "architecture_id": "Olmo2ForCausalLM", + "verified_date": "2026-04-15", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": 
"trl-internal-testing/tiny-CohereForCausalLM", + "architecture_id": "CohereForCausalLM", + "verified_date": "2026-04-15", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed with issues: P3=94.7% (failed: weight_modification)", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "EssentialAI/rnj-1-instruct", + "architecture_id": "Gemma3ForCausalLM", + "verified_date": "2026-04-15", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null } ] } From 00f6592e993e3f14419afd999fb8e8654f0d30f3 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Wed, 15 Apr 2026 14:17:01 -0500 Subject: [PATCH 3/3] Fix typing and formating --- .../model_bridge/test_hook_alias_resolution.py | 15 ++++++--------- .../model_bridge/generalized_components/block.py | 2 +- .../position_embeddings_attention.py | 2 +- .../model_bridge/supported_architectures/gptj.py | 2 +- .../model_bridge/supported_architectures/neox.py | 2 +- 5 files changed, 10 insertions(+), 13 deletions(-) diff --git a/tests/unit/model_bridge/test_hook_alias_resolution.py b/tests/unit/model_bridge/test_hook_alias_resolution.py index ba1cfba31..fd5b8d85b 100644 --- a/tests/unit/model_bridge/test_hook_alias_resolution.py +++ b/tests/unit/model_bridge/test_hook_alias_resolution.py @@ -37,9 +37,7 @@ def _stub_cfg(architecture: str) -> TransformerBridgeConfig: ) -def _iter_components( - root: Any, path: str = "" -) -> Iterable[Tuple[str, GeneralizedComponent]]: +def _iter_components(root: Any, path: str = "") -> Iterable[Tuple[str, GeneralizedComponent]]: """Walk component_mapping recursively, yielding (dotted-path, component).""" if isinstance(root, dict): for name, comp in root.items(): @@ -121,10 +119,9 @@ def test_every_hook_alias_resolves_to_hookpoint(architecture: str) -> None: resolved = True break if not resolved: - failures.append( - f"{path}.{alias_name} -> {target} (type at path: unresolved)" - ) - assert not failures, ( - f"Architecture {architecture}: {len(failures)} dead hook aliases:\n " - + "\n ".join(failures) + failures.append(f"{path}.{alias_name} -> {target} (type at path: unresolved)") + assert ( + not failures + ), f"Architecture {architecture}: {len(failures)} dead hook aliases:\n " + "\n ".join( + failures ) diff --git a/transformer_lens/model_bridge/generalized_components/block.py b/transformer_lens/model_bridge/generalized_components/block.py index 9c1548e78..73f2bd130 100644 --- a/transformer_lens/model_bridge/generalized_components/block.py +++ b/transformer_lens/model_bridge/generalized_components/block.py @@ -265,6 +265,6 @@ def __init__( ) # Ensure instance-level copy before mutating; base may have left the # class-level dict shared when no overrides were passed. 
- if self.hook_aliases is type(self).__mro__[1].hook_aliases: + if self.hook_aliases is BlockBridge.hook_aliases: self.hook_aliases = dict(self.hook_aliases) self.hook_aliases.pop("hook_resid_mid", None) diff --git a/transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py b/transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py index 3115364ec..b71e7ac6a 100644 --- a/transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py +++ b/transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py @@ -184,7 +184,7 @@ def _decide_qk_norm_phase(self, hf_attn: torch.nn.Module) -> Optional[str]: raise RuntimeError(f"{self.name}: q_norm declared but HF module has none.") weight = getattr(q_norm, "weight", None) - head_dim = int(hf_attn.head_dim) + head_dim = int(getattr(hf_attn, "head_dim")) n_heads = int(getattr(self.config, "n_heads", 0)) # Non-learnable norm (Gemma-3 style) broadcasts over head_dim. diff --git a/transformer_lens/model_bridge/supported_architectures/gptj.py b/transformer_lens/model_bridge/supported_architectures/gptj.py index 10d8fd4b7..9ddb61568 100644 --- a/transformer_lens/model_bridge/supported_architectures/gptj.py +++ b/transformer_lens/model_bridge/supported_architectures/gptj.py @@ -9,11 +9,11 @@ from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter from transformer_lens.model_bridge.generalized_components import ( AttentionBridge, - ParallelBlockBridge, EmbeddingBridge, LinearBridge, MLPBridge, NormalizationBridge, + ParallelBlockBridge, UnembeddingBridge, ) diff --git a/transformer_lens/model_bridge/supported_architectures/neox.py b/transformer_lens/model_bridge/supported_architectures/neox.py index 1ab11150b..0ca3cd6bb 100644 --- a/transformer_lens/model_bridge/supported_architectures/neox.py +++ b/transformer_lens/model_bridge/supported_architectures/neox.py @@ -18,10 +18,10 @@ from transformer_lens.model_bridge.generalized_components import ( EmbeddingBridge, JointQKVPositionEmbeddingsAttentionBridge, - ParallelBlockBridge, LinearBridge, MLPBridge, NormalizationBridge, + ParallelBlockBridge, RotaryEmbeddingBridge, UnembeddingBridge, )