@@ -17,10 +17,15 @@
 
 pytestmark = pytest.mark.slow
 
-# Diverse architectures for hook completeness testing
+# Diverse architectures for hook completeness testing.
+# Constraint: these tests compare bridge vs legacy HookedTransformer, so each
+# entry must be in HookedTransformer's OFFICIAL_MODEL_NAMES. Tiny C1-affected
+# families (Llama/Qwen/Gemma under ~150M) aren't registered with HT; for those,
+# tests/unit/model_bridge/test_component_hooks_fire.py (Tier 2) provides
+# direct per-adapter hook-firing coverage without needing an HT counterpart.
 MODELS_TO_TEST = [
-    "gpt2",  # Standard decoder-only with joint QKV
-    "EleutherAI/pythia-14m",  # GPT-NeoX architecture (smaller than pythia-70m)
+    "gpt2",  # JointQKVAttentionBridge (standard decoder-only)
+    "EleutherAI/pythia-14m",  # ParallelBlockBridge (C15 regression guard)
 ]
 
 # Gemma2: local only (too large for CI)
175 changes: 175 additions & 0 deletions tests/unit/model_bridge/test_component_hooks_fire.py
@@ -0,0 +1,175 @@
"""Regression test: attention hooks fire on forward (C1 guard).

Complements the alias-resolution test: aliases can resolve yet the HookPoint
may never fire if forward bypasses the LinearBridge (the original C1 bug).
For each adapter using PositionEmbeddingsAttentionBridge, isolates blocks.0.attn
with a synthetic HF module and asserts hook_q/k/v/z all fire.
"""

from __future__ import annotations

from typing import Any

import pytest
import torch
import torch.nn as nn

from transformer_lens.config import TransformerBridgeConfig
from transformer_lens.factories.architecture_adapter_factory import (
SUPPORTED_ARCHITECTURES,
)
from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import (
PositionEmbeddingsAttentionBridge,
)


def _stub_cfg(architecture: str, **kw: Any) -> TransformerBridgeConfig:
return TransformerBridgeConfig(
d_model=64,
d_head=16,
n_layers=2,
n_ctx=128,
n_heads=4,
d_vocab=1000,
d_mlp=256,
n_key_value_heads=kw.get("n_key_value_heads", 4),
default_prepend_bos=True,
architecture=architecture,
)


def _position_embeddings_adapters() -> list[str]:
"""Adapters using PositionEmbeddingsAttentionBridge directly (not subclasses)."""
results: list[str] = []
for arch, adapter_cls in sorted(SUPPORTED_ARCHITECTURES.items()):
try:
adapter = adapter_cls(_stub_cfg(arch))
except Exception:
continue
mapping = adapter.component_mapping
if mapping is None or "blocks" not in mapping:
continue
attn = (mapping["blocks"].submodules or {}).get("attn")
# type() check excludes JointQKV subclasses.
if attn is not None and type(attn) is PositionEmbeddingsAttentionBridge:
results.append(arch)
return results


class _FakeHFAttn(nn.Module):
"""Synthetic HF attention module with q/k/v/o + optional QK-norms."""

def __init__(
self,
d_model: int,
n_heads: int,
head_dim: int,
n_kv_heads: int,
with_q_norm: bool,
with_k_norm: bool,
) -> None:
super().__init__()
self.head_dim = head_dim
self.num_key_value_groups = n_heads // n_kv_heads if n_kv_heads else 1
self.scaling = head_dim**-0.5
self.attention_dropout = 0.0
self.layer_idx = 0
self.q_proj = nn.Linear(d_model, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(d_model, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(d_model, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, d_model, bias=False)
if with_q_norm:
self.q_norm = nn.LayerNorm(head_dim, elementwise_affine=True)
if with_k_norm:
self.k_norm = nn.LayerNorm(head_dim, elementwise_affine=True)


def _make_fake_hf_attn(
attn_bridge: PositionEmbeddingsAttentionBridge,
d_model: int,
n_heads: int,
head_dim: int,
n_kv_heads: int,
) -> nn.Module:
return _FakeHFAttn(
d_model=d_model,
n_heads=n_heads,
head_dim=head_dim,
n_kv_heads=n_kv_heads,
with_q_norm="q_norm" in attn_bridge.submodules,
with_k_norm="k_norm" in attn_bridge.submodules,
)


def _wire_attention_submodules(
attn_bridge: PositionEmbeddingsAttentionBridge, fake_hf_attn: nn.Module
) -> None:
"""Mirror setup_components: set original_component + add_module for each sub."""
for name, sub in attn_bridge.submodules.items():
hf_sub = getattr(fake_hf_attn, name + "_proj", None)
if hf_sub is None:
hf_sub = getattr(fake_hf_attn, name, None)
if hf_sub is None:
raise RuntimeError(f"fake HF attn missing '{name}' (tried {name}_proj, {name})")
sub.set_original_component(hf_sub)
if name not in attn_bridge._modules:
attn_bridge.add_module(name, sub)


_CRITICAL_HOOKS = {"hook_q", "hook_k", "hook_v", "hook_z"}


@pytest.mark.parametrize("architecture", _position_embeddings_adapters())
def test_attention_critical_hooks_fire_on_forward(architecture: str) -> None:
"""Assert hook_q/k/v/z fire during attention forward (C1 regression guard)."""
adapter_cls = SUPPORTED_ARCHITECTURES[architecture]
adapter = adapter_cls(_stub_cfg(architecture))
attn = adapter.component_mapping["blocks"].submodules["attn"]
assert isinstance(attn, PositionEmbeddingsAttentionBridge)

d_model = adapter.cfg.d_model
n_heads = adapter.cfg.n_heads
head_dim = d_model // n_heads
n_kv_heads = getattr(adapter.cfg, "n_key_value_heads", n_heads) or n_heads

fake_hf_attn = _make_fake_hf_attn(attn, d_model, n_heads, head_dim, n_kv_heads)
try:
attn.set_original_component(fake_hf_attn)
except RuntimeError as e:
pytest.skip(f"{architecture}: cannot wire synthetic HF attn ({e})")

_wire_attention_submodules(attn, fake_hf_attn)

fired: set[str] = set()
handles = []
for hname, hp in attn.get_hooks().items():
h = hp.add_hook(lambda t, hook, n=hname: (fired.add(n), t)[1])
handles.append(h)

# Record alias-targeted hooks under their alias name (hook_q -> q.hook_out).
for alias_name, target in attn.hook_aliases.items():
if isinstance(target, str):
try:
obj: Any = attn
for part in target.split("."):
obj = (
obj.submodules.get(part)
if hasattr(obj, "submodules") and part in obj.submodules
else getattr(obj, part)
)
if hasattr(obj, "add_hook"):
obj.add_hook(lambda t, hook, n=alias_name: (fired.add(n), t)[1])
except AttributeError:
pass

batch, seq = 1, 4
hidden = torch.randn(batch, seq, d_model)
cos = torch.ones(1, seq, head_dim)
sin = torch.zeros(1, seq, head_dim)
attn(hidden_states=hidden, position_embeddings=(cos, sin), attention_mask=None)

missing = _CRITICAL_HOOKS - fired
assert not missing, (
f"{architecture}: critical attention hooks did not fire: {sorted(missing)}. "
f"Fired: {sorted(fired)}"
)
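
The failure mode this file guards against is easy to restate in isolation. The sketch below is not part of the PR; it is a minimal, hypothetical reconstruction (the LinearWithHook class is invented for illustration) of how a forward pass that bypasses a bridge wrapper leaves the wrapper's HookPoint silent even though a hook is registered on it, assuming only TransformerLens's public HookPoint API:

import torch
import torch.nn as nn

from transformer_lens.hook_points import HookPoint


class LinearWithHook(nn.Module):
    """Hypothetical stand-in for a LinearBridge: wraps a Linear, exposes hook_out."""

    def __init__(self, lin: nn.Linear) -> None:
        super().__init__()
        self.lin = lin
        self.hook_out = HookPoint()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Routing the output through the HookPoint is what makes it observable.
        return self.hook_out(self.lin(x))


raw = nn.Linear(4, 4)
wrapped = LinearWithHook(raw)
fired: list[str] = []
wrapped.hook_out.add_hook(lambda t, hook: fired.append("hook_out") or t)

wrapped(torch.zeros(1, 4))  # through the wrapper: the hook fires
raw(torch.zeros(1, 4))      # bypassing the wrapper (the C1 shape): it does not
assert fired == ["hook_out"]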
1 change: 1 addition & 0 deletions tests/unit/model_bridge/test_component_setup.py
@@ -312,6 +312,7 @@ def split_qkv(self, component):
         name="blocks",
         submodules={
             "ln1": NormalizationBridge(name="ln1", config={}),
+            "ln2": NormalizationBridge(name="ln2", config={}),
             "attn": JointQKVAttentionBridge(
                 name="attn",
                 config=SimpleNamespace(n_heads=1, d_model=10),
127 changes: 127 additions & 0 deletions tests/unit/model_bridge/test_hook_alias_resolution.py
@@ -0,0 +1,127 @@
"""Regression test: every adapter's hook aliases resolve to a real HookPoint.

Catches bugs where an alias target path doesn't navigate to a HookPoint
(the complementary case to Tier 2, which catches aliases that resolve but are
bypassed by forward). Stub cfg only — no HF model load.
"""

from __future__ import annotations

from typing import Any, Iterable, Tuple

import pytest

from transformer_lens.config import TransformerBridgeConfig
from transformer_lens.factories.architecture_adapter_factory import (
SUPPORTED_ARCHITECTURES,
)
from transformer_lens.hook_points import HookPoint
from transformer_lens.model_bridge.generalized_components.base import (
GeneralizedComponent,
)


def _stub_cfg(architecture: str) -> TransformerBridgeConfig:
"""Minimal cfg for adapter instantiation; small values keep stubs cheap."""
return TransformerBridgeConfig(
d_model=64,
d_head=16,
n_layers=2,
n_ctx=128,
n_heads=4,
d_vocab=1000,
d_mlp=256,
n_key_value_heads=4,
default_prepend_bos=True,
architecture=architecture,
)


def _iter_components(root: Any, path: str = "") -> Iterable[Tuple[str, GeneralizedComponent]]:
"""Walk component_mapping recursively, yielding (dotted-path, component)."""
if isinstance(root, dict):
for name, comp in root.items():
yield from _iter_components(comp, f"{path}.{name}" if path else name)
return
if isinstance(root, GeneralizedComponent):
yield path, root
for name, sub in (root.submodules or {}).items():
yield from _iter_components(sub, f"{path}.{name}")


def _resolve(component: GeneralizedComponent, target: str) -> Any:
"""Resolve dotted alias using submodules dict — pre-model-load templates
don't yet have submodules registered via add_module()."""
obj: Any = component
for part in target.split("."):
nxt = None
if isinstance(obj, GeneralizedComponent):
nxt = (obj.submodules or {}).get(part)
if nxt is None:
nxt = getattr(obj, part, None)
if nxt is None:
raise AttributeError(part)
obj = nxt
return obj


# xfail(strict=True) so future fixes XPASS and force the marker to be removed.
# Each entry maps to a specific audit finding deferred from the C1+C15 PR.
_KNOWN_DEAD_ALIASES = {
"GPT2LMHeadCustomModel": "audit H27 — stale adapter, delete candidate",
"NanoGPTForCausalLM": "audit H28 — broken weight conversion, delete candidate",
"NeelSoluOldForCausalLM": "audit H28 — orphan weight conversion, delete candidate",
"LlavaForConditionalGeneration": "audit H15 — vision-encoder layer submodules unwired",
"LlavaNextForConditionalGeneration": "audit H15 + M24 — vision encoder + tiling opaque",
"LlavaOnevisionForConditionalGeneration": "audit H15 + M25 — vision encoder + video frames opaque",
"Gemma3ForConditionalGeneration": "audit H15 — multimodal vision encoder opaque",
"OpenELMForCausalLM": "audit H23 — per-layer head counts break uniform q/k/v shape",
"GraniteMoeHybridForCausalLM": "new finding — MoE+shared-MLP block lacks proper submodule aliases",
}


def _architecture_params():
"""Parametrize list with xfail markers for known-dead-alias adapters."""
params = []
for arch in sorted(SUPPORTED_ARCHITECTURES):
reason = _KNOWN_DEAD_ALIASES.get(arch)
if reason is not None:
params.append(pytest.param(arch, marks=pytest.mark.xfail(strict=True, reason=reason)))
else:
params.append(arch)
return params


@pytest.mark.parametrize("architecture", _architecture_params())
def test_every_hook_alias_resolves_to_hookpoint(architecture: str) -> None:
"""Every declared hook_aliases entry must resolve to a HookPoint."""
adapter_cls = SUPPORTED_ARCHITECTURES[architecture]
try:
adapter = adapter_cls(_stub_cfg(architecture))
except Exception as exc:
pytest.skip(f"Adapter {architecture} cannot instantiate with stub cfg: {exc}")

mapping = adapter.component_mapping
if mapping is None:
pytest.skip(f"Adapter {architecture} has no component_mapping")

failures: list[str] = []
for path, component in _iter_components(mapping):
for alias_name, target in component.hook_aliases.items():
targets = target if isinstance(target, list) else [target]
resolved = False
for single in targets:
try:
obj = _resolve(component, single)
except AttributeError:
continue
if isinstance(obj, HookPoint):
resolved = True
break
if not resolved:
failures.append(f"{path}.{alias_name} -> {target} (type at path: unresolved)")
assert (
not failures
), f"Architecture {architecture}: {len(failures)} dead hook aliases:\n " + "\n ".join(
failures
)
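
The strict xfail bookkeeping above has a useful property worth spelling out: once an adapter is fixed, its parametrized case XPASSes and pytest fails the run until the stale entry is removed. A self-contained illustration of that pytest mechanism, with hypothetical names not taken from the PR:

import pytest

_KNOWN_BROKEN = {"FakeArch": "hypothetical audit finding"}


@pytest.mark.parametrize(
    "arch",
    [
        "gpt2",
        *(
            pytest.param(a, marks=pytest.mark.xfail(strict=True, reason=r))
            for a, r in _KNOWN_BROKEN.items()
        ),
    ],
)
def test_alias_health(arch: str) -> None:
    # With strict=True, an unexpected pass (XPASS) fails the suite,
    # forcing the stale _KNOWN_BROKEN entry to be deleted after a fix.
    assert arch != "FakeArch"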
@@ -5,7 +5,10 @@
 from transformer_lens.model_bridge.generalized_components.audio_feature_extractor import (
     AudioFeatureExtractorBridge,
 )
-from transformer_lens.model_bridge.generalized_components.block import BlockBridge
+from transformer_lens.model_bridge.generalized_components.block import (
+    BlockBridge,
+    ParallelBlockBridge,
+)
 from transformer_lens.model_bridge.generalized_components.bloom_attention import (
     BloomAttentionBridge,
 )
@@ -94,6 +97,7 @@
     "AttentionBridge",
     "AudioFeatureExtractorBridge",
     "BlockBridge",
+    "ParallelBlockBridge",
     "BloomBlockBridge",
     "BloomAttentionBridge",
     "CodeGenAttentionBridge",
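
The final hunk (whose filename was lost when the page was copied, though the import paths point at the generalized_components package __init__) re-exports the new class so callers can import it from the package root rather than the block module. A one-line usage sketch under that assumption:

# Assumes the hunk above edits transformer_lens/model_bridge/generalized_components/__init__.py.
from transformer_lens.model_bridge.generalized_components import ParallelBlockBridge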