From 8b674b1699b91994ab935a021d4fec409e83ed47 Mon Sep 17 00:00:00 2001
From: oleksost
Date: Fri, 6 Mar 2026 20:49:00 +0000
Subject: [PATCH 1/2] Sample full layouts instead of independent per-layer mixer sampling

---
 fast_llm/layers/block/config.py             |  1 +
 fast_llm/layers/block/sequence.py           |  6 ++-
 fast_llm/layers/decoder/config.py           | 14 +++++-
 fast_llm/layers/decoder/stochastic_mixer.py | 53 +++++++++++++++++++--
 4 files changed, 67 insertions(+), 7 deletions(-)

diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py
index fd76d36cb..f6bd8b896 100644
--- a/fast_llm/layers/block/config.py
+++ b/fast_llm/layers/block/config.py
@@ -46,6 +46,7 @@ class BlockKwargs:
     hidden_states = "hidden_states"
     output_hidden_states = "output_hidden_states"
     activation_mask = "activation_mask"
+    num_blocks_in_sequence = "num_blocks_in_sequence"


 @config_class(registry=True)
diff --git a/fast_llm/layers/block/sequence.py b/fast_llm/layers/block/sequence.py
index 54a5b3471..a53b6a78d 100644
--- a/fast_llm/layers/block/sequence.py
+++ b/fast_llm/layers/block/sequence.py
@@ -9,7 +9,7 @@
 from fast_llm.engine.config_utils.tensor_dim import TensorDim
 from fast_llm.engine.distributed.config import DistributedConfig
 from fast_llm.layers.block.block import BlockBase
-from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig
+from fast_llm.layers.block.config import BlockKwargs, FixedBlockSequenceConfig, PatternBlockSequenceConfig
 from fast_llm.layers.common.peft.config import PeftConfig
@@ -56,6 +56,7 @@ def get_layers(self) -> list["Layer"]:
         return self._layers_with_namespace

     def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
+        kwargs[BlockKwargs.num_blocks_in_sequence] = self._config.num_blocks
         self._layers_with_namespace[0].preprocess(kwargs)

     def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
@@ -110,7 +111,8 @@ def get_layers(self) -> list[Layer]:
         return self._layers_with_namespace

     def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
-        for _, index in self._config.preprocessing_layers.items():
+        for name, index in self._config.preprocessing_layers.items():
+            kwargs[BlockKwargs.num_blocks_in_sequence] = self._config.expanded_pattern.count(name)
             self._layers_with_namespace[index].preprocess(kwargs)

     def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
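For orientation: the sequence.py change is what lets a stochastic mixer know how many layers it will be sampled for. `FixedBlockSequenceConfig` passes its total block count, while `PatternBlockSequenceConfig` counts how often each block name occurs in the expanded pattern. A minimal sketch of that counting step, with a made-up pattern (the names and sizes below are illustrative, not from the patch):

```python
# Hypothetical expanded pattern for a 6-block decoder with two block types.
expanded_pattern = ["attn", "mamba", "attn", "mamba", "attn", "mamba"]

# One preprocessing layer per distinct block name, keyed to its layer index,
# mirroring PatternBlockSequenceConfig.preprocessing_layers.
preprocessing_layers = {"attn": 0, "mamba": 1}

for name, index in preprocessing_layers.items():
    # This count is what the patch stores under BlockKwargs.num_blocks_in_sequence
    # right before calling that layer's preprocess().
    num_blocks_in_sequence = expanded_pattern.count(name)
    print(index, name, num_blocks_in_sequence)  # -> 0 attn 3, then 1 mamba 3
```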
" "Keys are mixer names used for debugging and namespacing.", hint=FieldHint.architecture, @@ -162,7 +166,9 @@ class StochasticMixerConfig(MixerConfig): def _validate(self) -> None: super()._validate() - # Validate mixers dict is not empty + # Validate mixers dict is provided and not empty + if self.mixers is None: + raise ValueError("mixers must be provided for StochasticMixerConfig") Assert.gt(len(self.mixers), 0) # Set main_mixer_name to first mixer if not specified @@ -174,6 +180,10 @@ def _validate(self) -> None: if self.main_mixer_name not in self.mixers: raise ValueError(f"main_mixer_name '{self.main_mixer_name}' not found in mixers") + # Validate full_layout incompatibilities + if self.sampling_strategy == StochasticMixerSamplingStrategy.full_layout and self.sampling_weights is not None: + raise ValueError("sampling_weights is not compatible with full_layout sampling strategy") + # Validate and normalize sampling weights if self.sampling_weights is not None: Assert.eq(set(self.sampling_weights.keys()), set(self.mixers.keys())) diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py index 984f34b80..8a05a2556 100644 --- a/fast_llm/layers/decoder/stochastic_mixer.py +++ b/fast_llm/layers/decoder/stochastic_mixer.py @@ -66,7 +66,9 @@ def __init__( } ) - if self._config.sampling_strategy == StochasticMixerSamplingStrategy.uniform: + if self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout: + self._sampling_probs = None + elif self._config.sampling_strategy == StochasticMixerSamplingStrategy.uniform: self._sampling_probs = torch.ones(len(self.mixers), device="cpu") / len(self.mixers) elif self._config.sampling_strategy == StochasticMixerSamplingStrategy.weighted: if self._config.sampling_weights is None: @@ -108,6 +110,13 @@ def _sample_mixer_name(self, kwargs: dict[str, typing.Any]) -> str: if not self.training: return self._config.main_mixer_name + if self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout: + layout = kwargs[StochasticMixerKwargs.layout] + counter = kwargs[StochasticMixerKwargs.layout_counter] + idx = counter[0] + counter[0] += 1 + return layout[idx] + generator = kwargs[StochasticMixerKwargs.generator] mixer_idx = torch.multinomial(self._sampling_probs, num_samples=1, generator=generator).item() return list(self.mixers.keys())[mixer_idx] @@ -150,6 +159,33 @@ def _forward( return self.mixers[mixer_name]._forward(input_, kwargs, losses, metrics) + def _sample_allocation(self, num_layers: int, generator: torch.Generator) -> list[int]: + """ + Sample a composition of num_layers into num_mixers bins uniformly. + + Uses stars-and-bars: picks (M-1) bar positions from {0, ..., N+M-2}, + giving each mixer a count. All integer partitions are equally likely. + """ + M = len(self.mixers) + N = num_layers + if M == 1: + return [N] + bars = torch.randperm(N + M - 1, generator=generator)[: M - 1].sort().values + padded = torch.cat([torch.tensor([-1]), bars, torch.tensor([N + M - 1])]) + counts = (padded[1:] - padded[:-1] - 1).tolist() + return counts + + def _sample_placement(self, counts: list[int], num_layers: int, generator: torch.Generator) -> list[str]: + """ + Given per-mixer counts, create a shuffled layout. 
+ """ + mixer_names = list(self.mixers.keys()) + layout = [] + for name, count in zip(mixer_names, counts): + layout.extend([name] * count) + perm = torch.randperm(num_layers, generator=generator) + return [layout[i] for i in perm.tolist()] + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: from fast_llm.engine.distributed.config import MAX_SEED from fast_llm.layers.block.config import BlockKwargs @@ -160,6 +196,13 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: generator.manual_seed(seed) kwargs[StochasticMixerKwargs.generator] = generator + if self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout: + num_layers = kwargs[BlockKwargs.num_blocks_in_sequence] + counts = self._sample_allocation(num_layers, generator) + layout = self._sample_placement(counts, num_layers, generator) + kwargs[StochasticMixerKwargs.layout] = layout + kwargs[StochasticMixerKwargs.layout_counter] = [0] + for mixer in self.mixers.values(): mixer.preprocess(kwargs) @@ -173,8 +216,12 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c """ usages = [mixer.get_compute_usage(input_, kwargs, config) for mixer in self.mixers.values()] - # Weight by sampling probability and return the expected value - expected_usage = sum(usage * prob.item() for usage, prob in zip(usages, self._sampling_probs)) + if self._sampling_probs is not None: + # Weight by sampling probability and return the expected value + expected_usage = sum(usage * prob.item() for usage, prob in zip(usages, self._sampling_probs)) + else: + # full_layout: uniform over compositions, so equal expected weight per mixer + expected_usage = sum(usages) / len(usages) return int(expected_usage) From bdd37c7733a52cb723baab880dbbcdfcc13efbcb Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 9 Mar 2026 19:16:45 +0000 Subject: [PATCH 2/2] predefined layout set --- fast_llm/layers/decoder/config.py | 32 +++++++++++++++++++++ fast_llm/layers/decoder/stochastic_mixer.py | 29 +++++++++++++++++-- 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 4cab2d39b..390a8b05f 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -157,6 +157,21 @@ class StochasticMixerConfig(MixerConfig): hint=FieldHint.feature, ) + predefined_layouts: list[list[str]] | None = Field( + default=None, + desc="List of predefined layouts to oversample. Each layout is a list of mixer names, one per layer. " + "Mixer names must match keys in the mixers dict.", + hint=FieldHint.feature, + ) + + predefined_layout_probability: float = Field( + default=0.0, + desc="Probability of sampling from predefined_layouts instead of using the sampling_strategy. " + "Must be in [0, 1]. 
From bdd37c7733a52cb723baab880dbbcdfcc13efbcb Mon Sep 17 00:00:00 2001
From: oleksost
Date: Mon, 9 Mar 2026 19:16:45 +0000
Subject: [PATCH 2/2] Add predefined layout set

---
 fast_llm/layers/decoder/config.py           | 32 +++++++++++++++++++++
 fast_llm/layers/decoder/stochastic_mixer.py | 29 +++++++++++++++++--
 2 files changed, 58 insertions(+), 3 deletions(-)

diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py
index 4cab2d39b..390a8b05f 100644
--- a/fast_llm/layers/decoder/config.py
+++ b/fast_llm/layers/decoder/config.py
@@ -157,6 +157,21 @@ class StochasticMixerConfig(MixerConfig):
         hint=FieldHint.feature,
     )

+    predefined_layouts: list[list[str]] | None = Field(
+        default=None,
+        desc="List of predefined layouts to oversample. Each layout is a list of mixer names, one per layer. "
+        "Mixer names must match keys in the mixers dict.",
+        hint=FieldHint.feature,
+    )
+
+    predefined_layout_probability: float = Field(
+        default=0.0,
+        desc="Probability of sampling from predefined_layouts instead of using the sampling_strategy. "
+        "Must be in [0, 1]. Only used when predefined_layouts is provided.",
+        hint=FieldHint.feature,
+        valid=check_field(Assert.in_range_incl, 0.0, 1.0),
+    )
+
     seed_shift: int = Field(
         default=_BIG_PRIMES[11],
         desc="Seed shift for mixer sampling reproducibility.",
@@ -191,6 +206,23 @@ def _validate(self) -> None:
             normalized_values = normalize_probabilities(list(self.sampling_weights.values()))
             self.sampling_weights = dict(zip(self.sampling_weights.keys(), normalized_values))

+        # Validate predefined layouts
+        if self.predefined_layouts is not None:
+            if len(self.predefined_layouts) == 0:
+                raise ValueError("predefined_layouts must be non-empty if provided")
+            mixer_names = set(self.mixers.keys())
+            for i, layout in enumerate(self.predefined_layouts):
+                unknown = set(layout) - mixer_names
+                if unknown:
+                    raise ValueError(
+                        f"predefined_layouts[{i}] contains unknown mixer names: {unknown}. "
+                        f"Valid names: {mixer_names}"
+                    )
+            if self.predefined_layout_probability <= 0:
+                warnings.warn("predefined_layouts provided but predefined_layout_probability is 0")
+        elif self.predefined_layout_probability > 0:
+            raise ValueError("predefined_layout_probability > 0 but predefined_layouts is not provided")
+
     @property
     def layer_class(self) -> "type[StochasticMixer]":
         from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer
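Continuing the illustrative config sketched after the first commit, the two new fields slot in next to the strategy (again with placeholder mixer names):

```python
# Extends the placeholder stochastic_mixer dict from earlier; assumes a
# 4-layer decoder, since each layout must list one mixer name per layer.
stochastic_mixer.update({
    "predefined_layouts": [
        ["attention", "mamba", "attention", "mamba"],
        ["mamba", "mamba", "attention", "attention"],
    ],
    # ~25% of preprocess() calls pick one of the layouts above uniformly;
    # the rest fall through to full_layout sampling.
    "predefined_layout_probability": 0.25,
})
```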
+ """ + if not self._config.predefined_layouts or self._config.predefined_layout_probability <= 0: + return None + coin = torch.rand(1, generator=generator).item() + if coin >= self._config.predefined_layout_probability: + return None + idx = torch.randint(len(self._config.predefined_layouts), (1,), generator=generator).item() + layout = list(self._config.predefined_layouts[idx]) + Assert.eq(len(layout), num_layers) + return layout + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: from fast_llm.engine.distributed.config import MAX_SEED from fast_llm.layers.block.config import BlockKwargs @@ -196,8 +213,14 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: generator.manual_seed(seed) kwargs[StochasticMixerKwargs.generator] = generator - if self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout: - num_layers = kwargs[BlockKwargs.num_blocks_in_sequence] + num_layers = kwargs[BlockKwargs.num_blocks_in_sequence] + predefined = self._sample_predefined_layout(num_layers, generator) + + if predefined is not None: + # Use predefined layout (overrides any sampling strategy) + kwargs[StochasticMixerKwargs.layout] = predefined + kwargs[StochasticMixerKwargs.layout_counter] = [0] + elif self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout: counts = self._sample_allocation(num_layers, generator) layout = self._sample_placement(counts, num_layers, generator) kwargs[StochasticMixerKwargs.layout] = layout