From d123e1054ce3f62cc9dfe7bee38dbb3dba62c9bd Mon Sep 17 00:00:00 2001
From: Prashant Rawat
Date: Wed, 6 May 2026 13:17:19 -0700
Subject: [PATCH 1/2] Add block repetition support for TransformerBlock layers (#19164)

Summary:
Add configurable block repetition to MultimodalTransformer, enabling
weight-shared depth scaling: a contiguous range of transformer layers can now
be executed multiple times with shared weights.

This diff adds a transformer_block_repeat_config field to ModelArgs, a list of
{start, end, count} dicts where start/end are inclusive layer indices and
count is the total number of passes over that range.

Example params.json:
"transformer_block_repeat_config": [{"start": 5, "end": 10, "count": 2}]
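
For illustration only (not part of this diff), a minimal sketch of how such a
config expands into a layer visit order, assuming the semantics documented in
model_args.py below (start/end inclusive, count = total passes over the range).
The helper name is hypothetical.

def expand_schedule(n_layers, repeat_config):
    # Hypothetical helper, illustrative only: flatten a repeat config into the
    # order in which layer indices would be visited.
    by_start = {b["start"]: b for b in (repeat_config or [])}
    schedule, i = [], 0
    while i < n_layers:
        if i in by_start:
            b = by_start[i]
            schedule.extend(list(range(b["start"], b["end"] + 1)) * b["count"])
            i = b["end"] + 1
        else:
            schedule.append(i)
            i += 1
    return schedule

# expand_schedule(16, [{"start": 5, "end": 10, "count": 2}])
#   -> [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]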

Reviewed By: AdithyaSagar007

Differential Revision: D102393826
---
 examples/models/llama/model_args.py | 55 ++++++++++++++++++++++++++++-
 1 file changed, 54 insertions(+), 1 deletion(-)

diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
index ed661c75517..67e07134cca 100644
--- a/examples/models/llama/model_args.py
+++ b/examples/models/llama/model_args.py
@@ -2,7 +2,7 @@
 from dataclasses import dataclass
 from enum import Enum
 from functools import partial
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 import torch.nn.functional as F
 
@@ -182,6 +182,12 @@ class ModelArgs:
     use_ffn_learnable_scales: bool = False
     output_soft_cap_temp: Optional[float] = None
 
+    # Block repetition: repeat contiguous ranges of transformer layers.
+    # List of {"start": int, "end": int, "count": int} dicts where start/end
+    # are layer indices (both inclusive) and count is total number of passes
+    # (1 = normal, 2 = run the block twice, etc.). Blocks must not overlap.
+    transformer_block_repeat_config: Optional[list] = None
+
     def __post_init__(self):  # noqa: C901
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
@@ -224,3 +230,50 @@ def find_multiple(n: int, k: int) -> int:
         # Convert string act_fn to enum if needed
         if isinstance(self.act_fn, str):
             self.act_fn = ActFn.from_string(self.act_fn)
+
+        self.validate_block_repeat_config()
+
+    def validate_block_repeat_config(self) -> None:
+        """Validate transformer_block_repeat_config field.
+
+        Called from __post_init__ and should also be called after setting
+        transformer_block_repeat_config post-construction.
+        """
+        if self.transformer_block_repeat_config is None:
+            return
+        for i, block in enumerate(self.transformer_block_repeat_config):
+            assert (
+                "start" in block and "end" in block and "count" in block
+            ), f"transformer_block_repeat_config[{i}] must have 'start', 'end', and 'count' keys"
+            assert 0 <= block["start"] <= block["end"] < self.n_layers, (
+                f"transformer_block_repeat_config[{i}]: invalid range [{block['start']}, {block['end']}] "
+                f"for {self.n_layers} layers"
+            )
+            assert (
+                block["count"] >= 1
+            ), f"transformer_block_repeat_config[{i}]: count must be >= 1"
+        # Check for overlapping blocks (end is inclusive, so next start must be > prev end)
+        sorted_blocks = sorted(
+            self.transformer_block_repeat_config, key=lambda b: b["start"]
+        )
+        for i in range(1, len(sorted_blocks)):
+            assert sorted_blocks[i]["start"] > sorted_blocks[i - 1]["end"], (
+                f"transformer_block_repeat_config: blocks {sorted_blocks[i-1]} and "
+                f"{sorted_blocks[i]} overlap"
+            )
+
+    @staticmethod
+    def normalize_block_repeat_config(
+        config: Optional[List[Dict[str, int]]],
+    ) -> Optional[List[Dict[str, int]]]:
+        """Drop entries with `count == 1`; return None if nothing remains.
+
+        A block-repeat entry with count=1 visits its layers exactly once --
+        the same as if the entry were omitted. Stripping these no-ops at
+        assignment time lets every downstream consumer assume each entry is
+        a genuine repeat (count > 1). Pure function: does not mutate input.
+        """
+        if not config:
+            return None
+        normalized = [b for b in config if b.get("count", 1) > 1]
+        return normalized if normalized else None

From c10057feeecbfb8401efc400fe7dadd2bd09a089 Mon Sep 17 00:00:00 2001
From: Prashant Rawat
Date: Wed, 6 May 2026 13:17:19 -0700
Subject: [PATCH 2/2] Per-occurrence KV cache for transformer_block_repeat_config (#19324)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
Currently, when a TransformerBlock appears multiple times in
MultimodalTransformer.layer_schedule (via
``args.transformer_block_repeat_config``), each visit to that layer reads and
writes the same ``self.attention.kv_cache`` buffer. The repeated layer
therefore shares its K/V history across all of its visits. This is a
weight-shared loop with shared KV, which is not numerically equivalent to a
physically unrolled N-layer model in which each duplicated layer slot owns its
own K/V cache.

This diff adds an opt-in path so that each occurrence of a layer in the
schedule can use its own KV cache buffer while still sharing the layer's
weight parameters, giving the same numerical inference behavior as lowering an
unrolled checkpoint. The model size with transformer_block_repeat_config
remains the same as the original model, since the layer weights are still
shared.
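
For illustration only (not part of this diff), a rough sketch of the intended
call pattern: a layer schedule routes a per-occurrence cache to a repeated
TransformerBlock through the new kv_cache_override forward option added in
attention.py below. All other names here (run_schedule, layers, layer_schedule,
occurrence_kv_caches) are hypothetical, not the actual MultimodalTransformer
code.

def run_schedule(h, layers, layer_schedule, occurrence_kv_caches, **fwd_options):
    # The same TransformerBlock object may appear at several schedule slots
    # (weights shared); each slot can carry its own KV cache. A None override
    # makes the attention fall back to its own `self.kv_cache`.
    for slot, layer_idx in enumerate(layer_schedule):
        h = layers[layer_idx](
            h,
            kv_cache_override=occurrence_kv_caches.get(slot),
            **fwd_options,
        )
    return h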

Differential Revision: D103962616
---
 examples/models/llama/attention.py | 20 +++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index d43533b5a70..c9eecd4cb4c 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -29,6 +29,12 @@ class ForwardOptions(TypedDict, total=False):
     # When provided, the attention layer skips its own K/V projection
     # and reuses the donor's K/V instead.
     shared_kv: Optional[Tuple[torch.Tensor, torch.Tensor]]
+    # Per-call KV cache override. Used by `MultimodalTransformer` when
+    # `transformer_block_repeat_config` repeats a TransformerBlock so that each
+    # *occurrence* of the layer in the schedule writes to its own KV cache
+    # rather than sharing the layer's `self.kv_cache`. When None or absent the
+    # attention falls back to `self.kv_cache`.
+    kv_cache_override: Optional["KVCache"]
 
 
 class Attention(nn.Module, ABC):
@@ -276,7 +282,7 @@ def __init__(
     [0, 1, 2, 3, 4, NA, NA, NA]
     After cache update we would have [8, 1, 2, 3, 4, 5, 6, 7]. We kicked out token at pos = 0. However, the current step still has access to [pos - sliding_window_size, pos] tokens.
-    
+
     To make sure we dont over attend, i.e. we dont have pos = 5 to attend to pos = 1, mask calculaton has to account for the sliding window size.
@@ -573,8 +579,12 @@ def forward(
         q, k, v = self._prepare_qkv(q, x, bsz, seqlen, freqs_cos, freqs_sin)
 
         if self.use_kv_cache:
+            # Per-call KV cache override (used when a TransformerBlock is invoked
+            # multiple times via `transformer_block_repeat_config` so each
+            # occurrence has its own KV cache). Falls back to `self.kv_cache`.
+            active_kv_cache = kwargs.get("kv_cache_override") or self.kv_cache
             assert input_pos is not None
-            is_ring_buffer = getattr(self.kv_cache, "is_ring_buffer", False)
+            is_ring_buffer = getattr(active_kv_cache, "is_ring_buffer", False)
 
             if is_ring_buffer:
                 # Ring buffer models compute their own mask after KV cache
@@ -594,14 +604,14 @@ def forward(
 
             # Only update KV cache for non-shared layers
             if shared_kv is None:
-                assert self.kv_cache is not None, (
+                assert active_kv_cache is not None, (
                     "kv_cache is required when shared_kv is not provided. "
                     "This layer may be a YOCO shared layer that requires shared_kv from a donor."
                 )
-                k, v = self.kv_cache.update(input_pos, k, v)
+                k, v = active_kv_cache.update(input_pos, k, v)
 
             if is_ring_buffer:
-                attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
+                attn_mask = active_kv_cache.create_causal_mask_for_ring_buffer(
                     input_pos[0].item(), seqlen
                 )