Add OLMo2/3 model support in fairseq2 #1410
Conversation
Force-pushed from 4975b9a to 1d050dd
cirquit
left a comment
Almost at the finish line here! Key things are the @Final implementations (already taken care of afaik) and the KV-caching implementation.
The rest are minor things that need some attention. Happy to merge once KV caching and the tests are passing.
…to ensure outputs align with HF Transformer
Fix all E501 line length violations (>88 chars) in the OLMO module:
- Reformat long docstrings and comments across all files
- Split long dictionary mappings in interop.py for better readability
- Wrap long error messages and function arguments
- All tests still passing after formatting changes
…ding
- Add KV caching to OLMOMultiheadAttention.forward() using existing AttentionState/FullAttentionState infrastructure
- Add GQA head expansion via repeat_interleave for grouped query attention
- Wire LocalAttentionStateFactory for OLMO3 sliding window layers in factory.py
- Add incremental decode test using pretrained OLMO2-1B checkpoint with real tokenized sentences, verifying step-by-step decoding matches full-sequence forward pass
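For context, the GQA head expansion mentioned in the commit above boils down to repeating each KV head so that every query head has a matching key/value head. A minimal sketch (shapes are illustrative, not the fairseq2 API):

```python
import torch

# Illustrative shapes: 2 KV heads shared across 8 query heads (GQA).
batch, num_heads, num_kv_heads, seq_len, head_dim = 1, 8, 2, 16, 64

q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_kv_heads, seq_len, head_dim)
v = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# Expand each KV head so every query head has a matching K/V head.
n_rep = num_heads // num_kv_heads
k = k.repeat_interleave(n_rep, dim=1)  # -> (batch, num_heads, seq_len, head_dim)
v = v.repeat_interleave(n_rep, dim=1)

attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
print(attn.shape)  # torch.Size([1, 8, 16, 64])
```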
…ents, and fix imports
- config.py: Remove unresolved TODOs and dead code (initializer_range, shard_embed_dim)
- interop.py: Remove legacy state_dict["model"] unwrapping (no v0.4 OLMo models)
- interop.py: Remove speculative RoPE permutation comment
- normalization.py: Import Callable, Sequence from collections.abc instead of typing
- tokenizer.py: Remove loose comments and docstring above/below chat template
Force-pushed from fbaa4f9 to 272a989
**OLMO Post-Norm vs Pre-Norm vs Standard Post-Norm**

```mermaid
flowchart LR
subgraph PreNorm["Pre-Norm (e.g. LLaMA)"]
direction LR
P1["Input"] --> P2["Norm"] --> P3["Attn/FFN"] --> P4["⊕ Add"] --> P5["Output"]
P1 -.->|"residual"| P4
end
subgraph PostNorm["Standard Post-Norm"]
direction LR
S1["Input"] --> S2["Attn/FFN"] --> S3["⊕ Add"] --> S4["Norm"] --> S5["Output"]
S1 -.->|"residual"| S3
end
subgraph OLMONorm["OLMO Post-Norm"]
direction LR
O1["Input"] --> O2["Attn/FFN"] --> O3["Norm"] --> O4["⊕ Add"] --> O5["Output"]
O1 -.->|"residual"| O4
end
style P2 fill:#95a5a6,color:#fff
style P3 fill:#95a5a6,color:#fff
style P4 fill:#95a5a6,color:#fff
style S2 fill:#95a5a6,color:#fff
style S3 fill:#95a5a6,color:#fff
style S4 fill:#95a5a6,color:#fff
style O2 fill:#8e44ad,color:#fff
style O3 fill:#c0392b,color:#fff
style O4 fill:#27ae60,color:#fff
```
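To make the ordering difference concrete, here is a minimal PyTorch sketch of the three residual orderings (function and module names are illustrative, not the fairseq2 API):

```python
import torch
from torch import nn

def pre_norm_block(x: torch.Tensor, norm: nn.Module, block: nn.Module) -> torch.Tensor:
    # Pre-Norm (e.g. LLaMA): norm -> sublayer -> add residual
    return x + block(norm(x))

def post_norm_block(x: torch.Tensor, norm: nn.Module, block: nn.Module) -> torch.Tensor:
    # Standard Post-Norm: sublayer -> add residual -> norm
    return norm(x + block(x))

def olmo_post_norm_block(x: torch.Tensor, norm: nn.Module, block: nn.Module) -> torch.Tensor:
    # OLMO Post-Norm: sublayer -> norm -> add residual (residual bypasses the norm)
    return x + norm(block(x))

x = torch.randn(2, 8, 16)
norm, ffn = nn.LayerNorm(16), nn.Linear(16, 16)
print(olmo_post_norm_block(x, norm, ffn).shape)  # torch.Size([2, 8, 16])
```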
**Attention Q/K Norm Order**

```mermaid
flowchart LR
Input["Input"] --> QProj["Q Proj"]
Input --> KProj["K Proj"]
Input --> VProj["V Proj"]
QProj --> QNorm["Q Norm"]
KProj --> KNorm["K Norm"]
QNorm --> QReshape["Reshape"]
KNorm --> KReshape["Reshape"]
VProj --> VReshape["Reshape"]
QReshape --> QRoPE["RoPE"]
KReshape --> KRoPE["RoPE"]
QRoPE --> SDPA["SDPA"]
KRoPE --> SDPA
VReshape --> SDPA
SDPA --> OutProj["Output Proj"]
style QNorm fill:#c0392b,color:#fff
style KNorm fill:#c0392b,color:#fff
style QRoPE fill:#16a085,color:#fff
style KRoPE fill:#16a085,color:#fff
style SDPA fill:#8e44ad,color:#fff
```
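A rough sketch of the Q/K-norm-before-RoPE ordering shown above, using plain PyTorch stand-ins (`nn.LayerNorm` stands in for `OLMORMSNorm`; dimensions are made up):

```python
import torch
from torch import nn

# Toy dimensions; real values come from the OLMo checkpoints.
batch, seq_len, model_dim, num_heads = 2, 8, 64, 4
head_dim = model_dim // num_heads

x = torch.randn(batch, seq_len, model_dim)
q_proj = nn.Linear(model_dim, model_dim, bias=False)
k_proj = nn.Linear(model_dim, model_dim, bias=False)

# Q/K norm is applied to the full projection output, *before* the
# per-head reshape and before RoPE.
q_norm = nn.LayerNorm(model_dim)  # stand-in for OLMORMSNorm
k_norm = nn.LayerNorm(model_dim)

q = q_norm(q_proj(x)).view(batch, seq_len, num_heads, head_dim).transpose(1, 2)
k = k_norm(k_proj(x)).view(batch, seq_len, num_heads, head_dim).transpose(1, 2)
# ... RoPE would be applied to q and k here, followed by SDPA.
print(q.shape)  # torch.Size([2, 4, 8, 16])
```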
**OLMo3 Hybrid Attention Pattern**

```mermaid
block-beta
columns 8
L0["0<br/>Sliding"]:1
L1["1<br/>Sliding"]:1
L2["2<br/>Sliding"]:1
L3["3<br/>Full"]:1
L4["4<br/>Sliding"]:1
L5["5<br/>Sliding"]:1
L6["6<br/>Sliding"]:1
L7["7<br/>Full"]:1
L8["8<br/>Sliding"]:1
L9["9<br/>Sliding"]:1
L10["10<br/>Sliding"]:1
L11["11<br/>Full"]:1
space:3
LN["N-1<br/>Full ✱"]:1
style L0 fill:#3498db,color:#fff
style L1 fill:#3498db,color:#fff
style L2 fill:#3498db,color:#fff
style L4 fill:#3498db,color:#fff
style L5 fill:#3498db,color:#fff
style L6 fill:#3498db,color:#fff
style L8 fill:#3498db,color:#fff
style L9 fill:#3498db,color:#fff
style L10 fill:#3498db,color:#fff
style L3 fill:#e67e22,color:#fff
style L7 fill:#e67e22,color:#fff
style L11 fill:#e67e22,color:#fff
style LN fill:#e67e22,color:#fff
```
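A small sketch of the layer pattern as read off the diagram. The "every 4th layer is full attention, and the final layer is always full" rule is taken from the picture; the authoritative values live in the OLMo3 configs:

```python
def olmo3_layer_types(num_layers: int) -> list[str]:
    """Hybrid pattern from the diagram: indices 3, 7, 11, ... use full
    attention, all other layers use sliding-window attention, and the
    final layer is always full."""
    types = []
    for idx in range(num_layers):
        if idx % 4 == 3 or idx == num_layers - 1:
            types.append("full")
        else:
            types.append("sliding")
    return types

print(olmo3_layer_types(12))
# ['sliding', 'sliding', 'sliding', 'full', 'sliding', 'sliding', 'sliding',
#  'full', 'sliding', 'sliding', 'sliding', 'full']
```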
**Sliding Window vs Full Attention**

```mermaid
flowchart TB
subgraph Full["Full Causal Attention"]
direction TB
FT["Tokens:   T1   T2   T3   T4   T5   T6   T7   T8"]
FM["
✅ · · · · · · ·
✅ ✅ · · · · · ·
✅ ✅ ✅ · · · · ·
✅ ✅ ✅ ✅ · · · ·
✅ ✅ ✅ ✅ ✅ · · ·
✅ ✅ ✅ ✅ ✅ ✅ · ·
✅ ✅ ✅ ✅ ✅ ✅ ✅ ·
✅ ✅ ✅ ✅ ✅ ✅ ✅ ✅
"]
FD["Each token attends to ALL previous tokens"]
end
subgraph Sliding["Sliding Window Attention (window=4)"]
direction TB
ST["Tokens:   T1   T2   T3   T4   T5   T6   T7   T8"]
SM["
✅ · · · · · · ·
✅ ✅ · · · · · ·
✅ ✅ ✅ · · · · ·
✅ ✅ ✅ ✅ · · · ·
· ✅ ✅ ✅ ✅ · · ·
· · ✅ ✅ ✅ ✅ · ·
· · · ✅ ✅ ✅ ✅ ·
· · · · ✅ ✅ ✅ ✅
"]
SD["Each token attends to only the last W tokens"]
end
style Full fill:#e67e22,color:#fff
style Sliding fill:#3498db,color:#fff
style FT fill:none,stroke:none,color:#fff
style FM fill:none,stroke:none,color:#fff,font-family:monospace
style FD fill:none,stroke:none,color:#fff
style ST fill:none,stroke:none,color:#fff
style SM fill:none,stroke:none,color:#fff,font-family:monospace
style SD fill:none,stroke:none,color:#fff
```
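The window semantics above can be expressed as a small boolean-mask helper. This is purely illustrative; in the PR, sliding-window layers are wired through `LocalAttentionStateFactory` rather than an explicit mask like this:

```python
import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    """Boolean mask matching the diagram: position i may attend to position j
    iff j <= i (causal) and i - j < window (the last `window` tokens,
    including itself)."""
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (i - j < window)

mask = sliding_window_causal_mask(8, window=4)
print(mask.int())
```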
**OLMo RMSNorm vs Standard RMSNorm**

```mermaid
flowchart LR
subgraph Standard["Standard RMSNorm"]
direction LR
S1["Normalize<br/>(float32)"] --> S2["Cast to<br/>input dtype"] --> S3["× Weight"]
end
subgraph OLMO["OLMO RMSNorm"]
direction LR
O1["Normalize<br/>(float32)"] --> O2["× Weight"] --> O3["Cast to<br/>input dtype"]
end
style S1 fill:#95a5a6,color:#fff
style S2 fill:#95a5a6,color:#fff
style S3 fill:#95a5a6,color:#fff
style O1 fill:#c0392b,color:#fff
style O2 fill:#c0392b,color:#fff
style O3 fill:#c0392b,color:#fff
```
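A minimal numerical sketch of the dtype-ordering difference (the `eps` value and dtypes here are illustrative, not the fairseq2 implementation):

```python
import torch

def standard_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Standard order: normalize in float32, cast back to the input dtype,
    # then multiply by the (possibly low-precision) weight.
    h = x.float()
    h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    return weight * h.to(x.dtype)

def olmo_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # OLMo order: normalize in float32, multiply by the weight in float32,
    # and only then cast the result back to the input dtype.
    h = x.float()
    h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    return (weight.float() * h).to(x.dtype)

x = torch.randn(2, 16, dtype=torch.bfloat16)
w = torch.randn(16, dtype=torch.bfloat16)

# The outputs differ only by where the low-precision cast happens.
print((standard_rmsnorm(x, w) - olmo_rmsnorm(x, w)).abs().max())
```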
**OLMo Architecture Diagrams**

**OLMo Transformer Architecture**

```mermaid
flowchart TD
Input["Input Tokens"] --> Embed["Token Embedding"]
Embed --> Stack
subgraph Stack["× N Decoder Layers"]
subgraph DL["Decoder Layer"]
direction TB
SelfAttn["Self-Attention Block<br/><i>Q/K Norm · RoPE · SDPA</i>"]
AttnNorm["OLMORMSNorm"]
AttnRes["⊕ Residual"]
FFNBlock["Feed-Forward Block<br/><i>(SwiGLU)</i>"]
FFNNorm["OLMORMSNorm"]
FFNRes["⊕ Residual"]
SelfAttn --> AttnNorm --> AttnRes
AttnRes --> FFNBlock --> FFNNorm --> FFNRes
end
end
FFNRes --> FinalNorm["OLMORMSNorm<br/><i>(final)</i>"]
FinalNorm --> LMHead["LM Head<br/><i>(Linear → vocab)</i>"]
LMHead --> Output["Logits"]
%% Residual skip connections
Embed -.->|"residual"| AttnRes
AttnRes -.->|"residual"| FFNRes
style Input fill:#34495e,color:#fff
style Embed fill:#2c3e50,color:#fff
style Output fill:#34495e,color:#fff
style LMHead fill:#2c3e50,color:#fff
style FinalNorm fill:#c0392b,color:#fff
style SelfAttn fill:#8e44ad,color:#fff
style AttnNorm fill:#c0392b,color:#fff
style AttnRes fill:#27ae60,color:#fff
style FFNBlock fill:#2980b9,color:#fff
style FFNNorm fill:#c0392b,color:#fff
style FFNRes fill:#27ae60,color:#fff
```
**Self-Attention Block Detail**

```mermaid
flowchart LR
Input["Input"] --> QProj["Q Proj"]
Input --> KProj["K Proj"]
Input --> VProj["V Proj"]
QProj --> QNorm["Q Norm<br/><i>(OLMORMSNorm)</i>"]
KProj --> KNorm["K Norm<br/><i>(OLMORMSNorm)</i>"]
QNorm --> QReshape["Reshape"]
KNorm --> KReshape["Reshape"]
VProj --> VReshape["Reshape"]
QReshape --> QRoPE["RoPE"]
KReshape --> KRoPE["RoPE"]
QRoPE --> SDPA["SDPA<br/><i>(Full or Sliding Window)</i>"]
KRoPE --> SDPA
VReshape --> SDPA
SDPA --> Flatten["Flatten"] --> OutProj["Output Proj"]
style QNorm fill:#c0392b,color:#fff
style KNorm fill:#c0392b,color:#fff
style QRoPE fill:#16a085,color:#fff
style KRoPE fill:#16a085,color:#fff
style SDPA fill:#8e44ad,color:#fff
```
**Feed-Forward Block Detail (SwiGLU)**

```mermaid
flowchart LR
Input["Input"] --> Gate["Gate Proj"]
Input --> Up["Up Proj"]
Gate --> SiLU["SiLU"]
SiLU --> Mul["⊗ Multiply"]
Up --> Mul
Mul --> Down["Down Proj"] --> Output["Output"]
style Gate fill:#2980b9,color:#fff
style Up fill:#2980b9,color:#fff
style SiLU fill:#2980b9,color:#fff
style Mul fill:#2980b9,color:#fff
style Down fill:#2980b9,color:#fff
```
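And a minimal stand-alone SwiGLU block matching the feed-forward diagram (module and parameter names are illustrative, not fairseq2's):

```python
import torch
from torch import nn

class SwiGLUFeedForward(nn.Module):
    """Minimal SwiGLU block from the diagram:
    down_proj(silu(gate_proj(x)) * up_proj(x))."""

    def __init__(self, model_dim: int, inner_dim: int) -> None:
        super().__init__()
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))

ffn = SwiGLUFeedForward(model_dim=64, inner_dim=172)
print(ffn(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])
```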
cirquit
left a comment
LGTM! Just a few nits on verbose comments.
I've verified that all the tests run on my end, and the training recipe also works:
```
python3 -m recipes.lm.sft --config-file recipes/lm/sft/configs/olmo2_1b_gsm8k.yaml /tmp/olmo2_sft_smoke_test --config regime.num_steps=5 regime.checkpoint_every_n_steps=1000 regime.validate_every_n_steps=1000 common.metric_recorders.wandb.enabled=false common.metric_recorders.tensorboard.enabled=false
```

Exporting to HF format and restarting from the checkpoint also works. Happy to merge.
Force-pushed from 0dbbf0d to be4cc2e
What does this PR do? Please describe:
Add OLMo2 and OLMo3 model architecture support in fairseq2.
Both architectures share a unified `olmo` module. The key architecture features:

- `OLMORMSNorm`: The order of operations is normalize → multiply by weight → cast to original dtype (differs from standard RMSNorm, which casts before multiplying).
- `OLMOTransformerLMDecoderLayer`: Uses a custom Post-Norm order: Attention/FFN → Norm → Add Residual. This differs from both standard Pre-Norm (Norm → Attention/FFN → Add Residual) and standard Post-Norm (Attention/FFN → Add Residual → Norm).
- `OLMOMultiheadAttention`: Inherits directly from the `MultiheadAttention` abstract base and inlines Q/K/V projection setup.
- Rotary Encoding:
  - `ReferenceRotaryEncoder` (standard RoPE) is used directly.
  - `YaRNRotaryEncoder` is used for long-context extension (8K → 65K), implementing frequency-dependent scaling and attention scaling (mscale).
- OLMo3 Hybrid Attention (sliding window + full attention):
  - Sliding window layers use `ReferenceRotaryEncoder`, full attention layers use `YaRNRotaryEncoder`.

Testing:
Fixes #1402
Does your PR introduce any breaking changes? If yes, please list them:
List of all backwards-incompatible changes.
Check list: