Skip to content

Commit 9ddd182

Browse files
authored
Completed InternLM2 adapter (#1251)
1 parent a41297f commit 9ddd182

9 files changed

Lines changed: 1545 additions & 67 deletions

File tree

tests/unit/model_bridge/supported_architectures/test_internlm2_adapter.py

Lines changed: 691 additions & 0 deletions
Large diffs are not rendered by default.

transformer_lens/factories/architecture_adapter_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
GraniteMoeArchitectureAdapter,
2626
GraniteMoeHybridArchitectureAdapter,
2727
HubertArchitectureAdapter,
28+
InternLM2ArchitectureAdapter,
2829
LlamaArchitectureAdapter,
2930
LlavaArchitectureAdapter,
3031
LlavaNextArchitectureAdapter,
@@ -79,6 +80,7 @@
7980
"GPTJForCausalLM": GptjArchitectureAdapter,
8081
"HubertForCTC": HubertArchitectureAdapter,
8182
"HubertModel": HubertArchitectureAdapter,
83+
"InternLM2ForCausalLM": InternLM2ArchitectureAdapter,
8284
"LlamaForCausalLM": LlamaArchitectureAdapter,
8385
"LlavaForConditionalGeneration": LlavaArchitectureAdapter,
8486
"LlavaNextForConditionalGeneration": LlavaNextArchitectureAdapter,
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""Compatibility shims for transformers version differences.
2+
3+
These patches are applied lazily (only when missing) so they're safe to call
4+
from multiple adapters — the first caller wins, subsequent calls are no-ops.
5+
"""
6+
7+
8+
def patch_dynamic_cache_v5() -> None:
    """Backfill DynamicCache methods removed in transformers v5.

    Remote-code models written for transformers v4 call ``from_legacy_cache``,
    ``to_legacy_cache`` and ``get_usable_length``, which were removed in v5.
    Call this from any adapter's ``prepare_loading()`` that needs them.

    Each method is installed only when ``DynamicCache`` does not already
    define it, so an existing implementation is never overwritten and
    repeated calls are no-ops. If ``transformers.cache_utils`` cannot be
    imported, the function returns without doing anything.
    """
    try:
        from transformers.cache_utils import DynamicCache
    except Exception:
        # Deliberately best-effort: a missing or broken transformers install
        # must not make the shim itself raise.
        return

    if not hasattr(DynamicCache, "from_legacy_cache"):

        @classmethod  # type: ignore[misc]
        def _from_legacy_cache(cls, past_key_values=None):  # type: ignore[no-untyped-def]
            """Build a cache from a per-layer sequence of (key, value) pairs."""
            cache = cls()
            if past_key_values is not None:
                for idx, layer_past in enumerate(past_key_values):
                    cache.update(layer_past[0], layer_past[1], idx)
            return cache

        DynamicCache.from_legacy_cache = _from_legacy_cache  # type: ignore[attr-defined]

    if not hasattr(DynamicCache, "get_usable_length"):

        def _get_usable_length(  # type: ignore[no-untyped-def]
            self,
            new_seq_length: int = 0,
            layer_idx: int = 0,
            *,
            new_seq_len: int = 0,
        ) -> int:
            """Return the cached sequence length for ``layer_idx``.

            The first parameter is named ``new_seq_length`` so that callers
            passing it by keyword under that name are accepted;
            ``new_seq_len`` is kept as a keyword alias for the earlier shim
            signature. Neither value affects the result.
            """
            return self.get_seq_length(layer_idx)

        DynamicCache.get_usable_length = _get_usable_length  # type: ignore[attr-defined]

    if not hasattr(DynamicCache, "to_legacy_cache"):

        def _to_legacy_cache(self):  # type: ignore[no-untyped-def]
            """Return the cache as a tuple of per-layer (keys, values) pairs."""
            return tuple((layer.keys, layer.values) for layer in self.layers)

        DynamicCache.to_legacy_cache = _to_legacy_cache  # type: ignore[attr-defined]

transformer_lens/model_bridge/supported_architectures/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@
6060
from transformer_lens.model_bridge.supported_architectures.hubert import (
6161
HubertArchitectureAdapter,
6262
)
63+
from transformer_lens.model_bridge.supported_architectures.internlm2 import (
64+
InternLM2ArchitectureAdapter,
65+
)
6366
from transformer_lens.model_bridge.supported_architectures.llama import (
6467
LlamaArchitectureAdapter,
6568
)
@@ -171,6 +174,7 @@
171174
"Gpt2LmHeadCustomArchitectureAdapter",
172175
"GptjArchitectureAdapter",
173176
"HubertArchitectureAdapter",
177+
"InternLM2ArchitectureAdapter",
174178
"LlamaArchitectureAdapter",
175179
"LlavaArchitectureAdapter",
176180
"LlavaNextArchitectureAdapter",

0 commit comments

Comments
 (0)