Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions transformer_lens/factories/architecture_adapter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
GraniteMoeArchitectureAdapter,
GraniteMoeHybridArchitectureAdapter,
HubertArchitectureAdapter,
InternLM2ArchitectureAdapter,
LlamaArchitectureAdapter,
LlavaArchitectureAdapter,
LlavaNextArchitectureAdapter,
Expand Down Expand Up @@ -79,6 +80,7 @@
"GPTJForCausalLM": GptjArchitectureAdapter,
"HubertForCTC": HubertArchitectureAdapter,
"HubertModel": HubertArchitectureAdapter,
"InternLM2ForCausalLM": InternLM2ArchitectureAdapter,
"LlamaForCausalLM": LlamaArchitectureAdapter,
"LlavaForConditionalGeneration": LlavaArchitectureAdapter,
"LlavaNextForConditionalGeneration": LlavaNextArchitectureAdapter,
Expand Down
44 changes: 44 additions & 0 deletions transformer_lens/model_bridge/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Compatibility shims for transformers version differences.

These patches are applied lazily (only when missing) so they're safe to call
from multiple adapters — the first caller wins, subsequent calls are no-ops.
"""


def patch_dynamic_cache_v5() -> None:
    """Restore DynamicCache methods that transformers v5 removed.

    Trust-remote-code model implementations targeting transformers v4 still
    invoke ``from_legacy_cache``, ``to_legacy_cache``, and
    ``get_usable_length``; v5 dropped all three. Any adapter's
    ``prepare_loading()`` may call this to backfill them. Each shim is only
    installed when the attribute is absent, so the call is an idempotent
    no-op once the methods exist (or were patched by an earlier caller).
    """
    try:
        from transformers.cache_utils import DynamicCache
    except Exception:
        # transformers missing or unimportable — nothing to patch, stay silent.
        return

    if not hasattr(DynamicCache, "from_legacy_cache"):

        @classmethod  # type: ignore[misc]
        def _restore_from_legacy(cls, past_key_values=None):  # type: ignore[no-untyped-def]
            # Rebuild a cache object from the v4 tuple-of-(key, value) layout.
            rebuilt = cls()
            if past_key_values is None:
                return rebuilt
            for layer_idx, layer_past in enumerate(past_key_values):
                # Index rather than unpack: tolerates per-layer tuples that
                # carry extra trailing entries beyond (key, value).
                rebuilt.update(layer_past[0], layer_past[1], layer_idx)
            return rebuilt

        DynamicCache.from_legacy_cache = _restore_from_legacy  # type: ignore[attr-defined]

    if not hasattr(DynamicCache, "get_usable_length"):

        def _usable_length(self, new_seq_len: int = 0, layer_idx: int = 0) -> int:  # type: ignore[no-untyped-def]
            # v4 semantics collapse to the cached sequence length here; the
            # new_seq_len argument is accepted for signature compatibility.
            return self.get_seq_length(layer_idx)

        DynamicCache.get_usable_length = _usable_length  # type: ignore[attr-defined]

    if not hasattr(DynamicCache, "to_legacy_cache"):

        def _as_legacy_tuples(self):  # type: ignore[no-untyped-def]
            # Export every layer back into the v4 tuple-of-(key, value) form.
            return tuple((layer.keys, layer.values) for layer in self.layers)

        DynamicCache.to_legacy_cache = _as_legacy_tuples  # type: ignore[attr-defined]
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@
from transformer_lens.model_bridge.supported_architectures.hubert import (
HubertArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.internlm2 import (
InternLM2ArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.llama import (
LlamaArchitectureAdapter,
)
Expand Down Expand Up @@ -171,6 +174,7 @@
"Gpt2LmHeadCustomArchitectureAdapter",
"GptjArchitectureAdapter",
"HubertArchitectureAdapter",
"InternLM2ArchitectureAdapter",
"LlamaArchitectureAdapter",
"LlavaArchitectureAdapter",
"LlavaNextArchitectureAdapter",
Expand Down
Loading
Loading