diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 886c706c1..a083e4ad5 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -1,4 +1,5 @@ import abc +import dataclasses import logging import pathlib import typing @@ -6,7 +7,7 @@ import torch from fast_llm import __version__ -from fast_llm.config import Config +from fast_llm.config import Config, FieldHint, set_nested_dict_value from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.checkpoint.config import CheckpointLoadMetadataConfig from fast_llm.engine.checkpoint.state_dict import StateDictCheckpointHandler @@ -18,6 +19,471 @@ logger = logging.getLogger(__name__) +_MISSING = object() + + +def _get_nested(d: dict, path: tuple[str, ...], default=_MISSING): + cur = d + for key in path: + if not isinstance(cur, dict) or key not in cur: + if default is _MISSING: + raise KeyError(f"Missing key {'.'.join(path)} in HF config dict") + return default + cur = cur[key] + return cur + + +def _has_nested(d: dict, path: tuple[str, ...]) -> bool: + cur = d + for key in path: + if not isinstance(cur, dict) or key not in cur: + return False + cur = cur[key] + return True + + +def _get_attr_path(config: Config, path: tuple[str, ...]) -> typing.Any: + cur = config + for name in path: + cur = getattr(cur, name) + return cur + + +# ============================================================ +# Config conversion primitives (declarative) +# ============================================================ + + +class ConfigConverter(abc.ABC): + """A declarative description of how one or more Fast-LLM config fields map to one or more HF config keys. + + Each primitive owns a set of ``fast_llm_paths`` (tuples of attribute names rooted at the section's config) and + ``hf_paths`` (tuples of dict keys rooted at the section's HF subdict). The walker calls ``export_to`` to produce + HF entries from a Fast-LLM config object, and ``import_to`` to produce a Fast-LLM config dict from an HF dict. + """ + + fast_llm_paths: tuple[tuple[str, ...], ...] = () + hf_paths: tuple[tuple[str, ...], ...] = () + + @property + def consumed_fast_llm_fields(self) -> set[str]: + """Top-level Fast-LLM field names this primitive consumes at the current section level. + + Used by the section walker for the architecture-hint coverage check. + """ + return {path[0] for path in self.fast_llm_paths if path} + + @abc.abstractmethod + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: ... + + @abc.abstractmethod + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: ... + + +class RenameConfigConverter(ConfigConverter): + """One-to-one rename between a Fast-LLM attribute path and an HF dict path.""" + + def __init__(self, fast_llm_path: tuple[str, ...], hf_path: tuple[str, ...]): + self.fast_llm_paths = (fast_llm_path,) + self.hf_paths = (hf_path,) + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + value = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + set_nested_dict_value(hf_out, self.hf_paths[0], value) + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + value = _get_nested(hf_dict, self.hf_paths[0]) + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], value) + + +class ConstantExportConfigConverter(ConfigConverter): + """Write a constant to the HF dict on export. On import, assert that the HF dict has this constant value. 
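+ For example, the Apriel2 MLP converter below declares ``ConstantExportConfigConverter(("type",), "mlp")`` to pin the HF discriminator key.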
+ + Used when a HF format requires a key whose value Fast-LLM doesn't store (or always pins to a constant). + """ + + def __init__(self, hf_path: tuple[str, ...], value: typing.Any): + self.hf_paths = (hf_path,) + self._value = value + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + set_nested_dict_value(hf_out, self.hf_paths[0], self._value) + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + if _has_nested(hf_dict, self.hf_paths[0]): + actual = _get_nested(hf_dict, self.hf_paths[0]) + Assert.eq(actual, self._value) + + +class ConstantImportConfigConverter(ConfigConverter): + """Inject a constant into the Fast-LLM dict on import. On export, assert the config matches the constant. + + Used when a Fast-LLM field is required but the HF format implies a fixed value (e.g., gated MLP for Llama). + """ + + def __init__(self, fast_llm_path: tuple[str, ...], value: typing.Any): + self.fast_llm_paths = (fast_llm_path,) + self._value = value + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + actual = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + Assert.eq(actual, self._value) + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], self._value) + + +class DefaultConfigConverter(ConfigConverter): + """Rename with an HF-side fallback used when the HF key is missing on import. + + ``hf_default_fn`` is called with the full HF dict if the path is absent; otherwise it's a plain rename. + On export, behaves like ``RenameConfigConverter``. + """ + + def __init__( + self, + fast_llm_path: tuple[str, ...], + hf_path: tuple[str, ...], + hf_default_fn: typing.Callable[[dict], typing.Any], + ): + self.fast_llm_paths = (fast_llm_path,) + self.hf_paths = (hf_path,) + self._hf_default_fn = hf_default_fn + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + value = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + set_nested_dict_value(hf_out, self.hf_paths[0], value) + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + if _has_nested(hf_dict, self.hf_paths[0]): + value = _get_nested(hf_dict, self.hf_paths[0]) + else: + value = self._hf_default_fn(hf_dict) + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], value) + + +class OptionalConfigConverter(ConfigConverter): + """Emit/import only when the value differs from a sentinel (default ``None``). + + Useful for fields that round-trip cleanly only when present (e.g. ``window_size``). + """ + + def __init__(self, fast_llm_path: tuple[str, ...], hf_path: tuple[str, ...], sentinel: typing.Any = None): + self.fast_llm_paths = (fast_llm_path,) + self.hf_paths = (hf_path,) + self._sentinel = sentinel + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + value = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + if value != self._sentinel: + set_nested_dict_value(hf_out, self.hf_paths[0], value) + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + if _has_nested(hf_dict, self.hf_paths[0]): + value = _get_nested(hf_dict, self.hf_paths[0]) + if value != self._sentinel: + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], value) + + +class IgnoredConfigConverter(ConfigConverter): + """Declares Fast-LLM architecture fields as intentionally not converted by this format. + + Use when the HF format has no representation for the field and the Fast-LLM default round-trips correctly. 
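+ For example, ``Apriel2AttentionConverter`` below declares ``IgnoredConfigConverter(("causal",))`` for a field the Apriel2 format does not represent.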
+ Acts as a no-op on both directions while satisfying the architecture-coverage check. + """ + + def __init__(self, *fast_llm_paths: tuple[str, ...]): + self.fast_llm_paths = fast_llm_paths + self.hf_paths = () + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + return + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + return + + +class CustomConfigConverter(ConfigConverter): + """Escape hatch for cross-field transforms (e.g., rotary, where one HF blob ↔ several Fast-LLM fields). + + ``fast_llm_paths`` is declared so the coverage check sees the fields as consumed. The HF side is intentionally + not declared — there is no symmetric HF-side coverage check yet, so an ``hf_paths`` argument would be cosmetic. + Cross-field validators that produce nothing on the HF side belong on :py:meth:`ConfigSectionConverter._validate_export` + instead; this primitive is for shape-changing transforms. + """ + + def __init__( + self, + fast_llm_paths: tuple[tuple[str, ...], ...], + export_fn: typing.Callable[[Config], dict], + import_fn: typing.Callable[[dict], dict], + ): + self.fast_llm_paths = fast_llm_paths + self._export_fn = export_fn + self._import_fn = import_fn + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + produced = self._export_fn(fast_llm_config) + for path, value in produced.items(): + set_nested_dict_value(hf_out, path, value) + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + produced = self._import_fn(hf_dict) + for path, value in produced.items(): + set_nested_dict_value(fast_llm_out, path, value) + + +class NestedConfigConverter(ConfigConverter): + """Recurse into a fixed-typed sub-config field via another section converter class. + + Default (``hf_path=None``): the HF side is flat-merged — the sub-converter's output becomes top-level keys + of the parent's HF dict, asserting any pre-existing keys agree. + + With ``hf_path`` set: the sub-converter's output is placed under that nested key. Use this for HF formats + that mirror Fast-LLM's modular layout (e.g. Apriel2's ``"decoder": {...}`` and ``"head": {...}`` blocks). + """ + + def __init__( + self, + fast_llm_path: tuple[str, ...], + converter_class: "type[ConfigSectionConverter]", + hf_path: tuple[str, ...] | None = None, + ): + self.fast_llm_paths = (fast_llm_path,) + self._converter_class = converter_class + self._hf_path = hf_path + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + sub_config = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + sub_hf = self._converter_class.export_config(sub_config) + if self._hf_path is None: + for key, value in sub_hf.items(): + if key in hf_out: + Assert.eq(hf_out[key], value) + else: + hf_out[key] = value + else: + set_nested_dict_value(hf_out, self._hf_path, sub_hf) + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + sub_hf = _get_nested(hf_dict, self._hf_path) if self._hf_path is not None else hf_dict + sub_fast_llm = self._converter_class.import_config(sub_hf) + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], sub_fast_llm) + + +class DispatchConfigConverter(ConfigConverter): + """Polymorphic sub-config dispatch. + + The Fast-LLM field's runtime type selects the section converter; the HF format selects via a ``type`` discriminator. + Both registries (Fast-LLM type → converter class, HF discriminator → converter class) must agree at runtime. + """ + + def __init__( + self, + fast_llm_path: tuple[str, ...], + hf_path: tuple[str, ...] 
| None, + registry: "dict[type[Config], type[ConfigSectionConverter]]", + hf_discriminator_key: str = "type", + ): + self.fast_llm_paths = (fast_llm_path,) + self.hf_paths = (hf_path,) if hf_path is not None else () + self._registry = registry + self._hf_discriminator_key = hf_discriminator_key + self._hf_to_class = {cls.hf_type_name: cls for cls in registry.values() if cls.hf_type_name is not None} + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + sub_config = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + converter_class = self._registry.get(type(sub_config)) + if converter_class is None: + raise NotImplementedError( + f"No converter registered for {type(sub_config).__name__} at {'.'.join(self.fast_llm_paths[0])}" + ) + sub_hf = converter_class.export_config(sub_config) + if converter_class.hf_type_name is not None: + sub_hf = {self._hf_discriminator_key: converter_class.hf_type_name, **sub_hf} + if self.hf_paths: + set_nested_dict_value(hf_out, self.hf_paths[0], sub_hf) + else: + for key, value in sub_hf.items(): + hf_out[key] = value + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + sub_hf = _get_nested(hf_dict, self.hf_paths[0]) if self.hf_paths else hf_dict + type_name = sub_hf.get(self._hf_discriminator_key) + converter_class = self._hf_to_class.get(type_name) + if converter_class is None: + raise NotImplementedError( + f"No converter registered for HF discriminator {type_name!r} at " f"{'.'.join(self.fast_llm_paths[0])}" + ) + sub_fast_llm = converter_class.import_config(sub_hf) + # Inject the Fast-LLM dynamic-type discriminator so the parent's `from_dict` dispatches to the + # correct subclass. Reads from the registered Config class rather than the HF discriminator so + # mismatched Fast-LLM/HF type names work too. + fast_llm_type = getattr(converter_class.fast_llm_config_class, "dynamic_type_name", None) + if fast_llm_type is not None: + sub_fast_llm = {"type": fast_llm_type, **sub_fast_llm} + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], sub_fast_llm) + + +class TypedDictContainerConfigConverter(ConfigConverter): + """Maps a Fast-LLM ``dict[str, Config]`` field to an HF ``dict[str, dict]`` where each entry is round-tripped + through a per-class section converter selected via the entry's runtime type (export) or HF discriminator (import). + + Each entry's HF subdict carries a discriminator key (``"type"`` by default) populated from the converter's + ``hf_type_name``. For homogeneous dicts, register a single class with ``hf_type_name = None``; the discriminator + is then omitted on export and ignored on import. 
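+ 
+ As an illustrative sketch (the entry names are arbitrary dict keys, not part of either format), Apriel2's stochastic
+ mixer maps ``mixers = {"attn": AttentionConfig(...), "fast": MambaConfig(...)}`` to the HF subtree::
+ 
+     {"mixers": {"attn": {"type": "attention", ...}, "fast": {"type": "mamba", ...}}}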
+ """ + + def __init__( + self, + fast_llm_path: tuple[str, ...], + hf_path: tuple[str, ...], + registry: "dict[type[Config], type[ConfigSectionConverter]]", + hf_discriminator_key: str = "type", + ): + self.fast_llm_paths = (fast_llm_path,) + self.hf_paths = (hf_path,) + self._registry = registry + self._hf_discriminator_key = hf_discriminator_key + self._hf_to_class = {cls.hf_type_name: cls for cls in registry.values() if cls.hf_type_name is not None} + self._homogeneous = len(registry) == 1 and next(iter(registry.values())).hf_type_name is None + if self._homogeneous: + self._homogeneous_class = next(iter(registry.values())) + + def export_to(self, fast_llm_config: Config, hf_out: dict) -> None: + sub_dict = _get_attr_path(fast_llm_config, self.fast_llm_paths[0]) + out: dict = {} + for name, sub_config in sub_dict.items(): + if self._homogeneous: + converter_class = self._homogeneous_class + else: + converter_class = self._registry.get(type(sub_config)) + if converter_class is None: + raise NotImplementedError( + f"No converter registered for {type(sub_config).__name__} at " + f"{'.'.join(self.fast_llm_paths[0])}[{name!r}]" + ) + sub_hf = converter_class.export_config(sub_config) + if converter_class.hf_type_name is not None: + sub_hf = {self._hf_discriminator_key: converter_class.hf_type_name, **sub_hf} + out[name] = sub_hf + set_nested_dict_value(hf_out, self.hf_paths[0], out) + + def import_to(self, hf_dict: dict, fast_llm_out: dict) -> None: + sub_hf_dict = _get_nested(hf_dict, self.hf_paths[0]) + out: dict = {} + for name, sub_hf in sub_hf_dict.items(): + if self._homogeneous: + converter_class = self._homogeneous_class + else: + type_name = sub_hf.get(self._hf_discriminator_key) + converter_class = self._hf_to_class.get(type_name) + if converter_class is None: + raise NotImplementedError( + f"No converter registered for HF discriminator {type_name!r} at " + f"{'.'.join(self.hf_paths[0])}[{name!r}]" + ) + sub_fast_llm = converter_class.import_config(sub_hf) + fast_llm_type = getattr(converter_class.fast_llm_config_class, "dynamic_type_name", None) + if fast_llm_type is not None: + sub_fast_llm = {"type": fast_llm_type, **sub_fast_llm} + out[name] = sub_fast_llm + set_nested_dict_value(fast_llm_out, self.fast_llm_paths[0], out) + + +# ============================================================ +# Section converter — converts one Fast-LLM config class +# ============================================================ + + +class ConfigSectionConverter(abc.ABC): + """Base class for converting one Fast-LLM ``Config`` class ↔ one HF dict subtree. + + Subclasses declare the conversion via ``_create_config_converters``. Format-specific cross-field + invariants go on the ``_validate_export`` hook. The weight side is still imperative (per-converter + ``get_converters`` classmethods on the concrete subclasses); a declarative weight-side primitive will be + added when the weight-converter migration lands. + + Subclasses that participate in :class:`DispatchConfigConverter` set ``hf_type_name`` to the discriminator value + used by the HF format (e.g. ``"attention"``, ``"mamba"``). + """ + + fast_llm_config_class: typing.ClassVar[type[Config]] + hf_type_name: typing.ClassVar[str | None] = None + + @classmethod + @abc.abstractmethod + def _create_config_converters(cls) -> dict[str, ConfigConverter]: + """Return declarations keyed by stable string name. 
Subclasses override entries by re-declaring the key.""" + + @classmethod + def _validate_export(cls, config: Config) -> None: + """Hook for format-specific export-time validation. Default no-op. + + Runs after the architecture-coverage check and before any declaration emits. Use this for cross-field + invariants the format imposes on the Fast-LLM config (e.g. per-layer biases must match a parent flag, + certain sub-configs must be at their default). Subclasses override; super-calls are not required when + the rule is fully replaced (e.g. Qwen2 vs Llama attention biases). + """ + return + + @classmethod + def export_config(cls, config: Config) -> dict: + """Convert a Fast-LLM config object to an HF config dict via this section's declarations.""" + declarations = cls._create_config_converters() + cls._check_architecture_coverage(config, declarations) + cls._validate_export(config) + out: dict = {} + for converter in declarations.values(): + converter.export_to(config, out) + return out + + @classmethod + def import_config(cls, hf_dict: dict) -> dict: + """Convert an HF config dict to a Fast-LLM config dict via this section's declarations.""" + out: dict = {} + for converter in cls._create_config_converters().values(): + converter.import_to(hf_dict, out) + return out + + @classmethod + def _check_architecture_coverage(cls, config: Config, declarations: dict[str, ConfigConverter]) -> None: + """Raise if any architecture-hint field on the section's declared config class is not consumed. + + Coverage is structural (based on field hints), not value-based: every architecture field must be + explicitly accounted for, even if it currently holds its Fast-LLM default. Sub-config fields are + consumed by ``NestedConfigConverter``/``DispatchConfigConverter``, which delegate the deeper coverage + check to the nested section's own converter. + + The check only runs when ``type(config)`` exactly matches ``cls.fast_llm_config_class`` — when the + config is a strict subclass (e.g. ``MoEMLPConfig`` fed via ``super().export_config()`` from a yet-to-be- + migrated ``MixtralMLPConverter``), the subclass converter is responsible for declaring the additional + fields and running its own check. TODO: Once Mixtral/Apriel/Apriel2 migrate, the safety net for + ``MoEMLPConfig``/``MambaConfig``/etc. is gated on those migrations landing. + """ + if type(config) is not cls.fast_llm_config_class: + return + consumed: set[str] = set() + for converter in declarations.values(): + consumed |= converter.consumed_fast_llm_fields + missing: list[str] = [] + for name, field in type(config).fields(): + if field._field_type != dataclasses._FIELD: + continue + if not field.init: + continue + if field.hint != FieldHint.architecture: + continue + if name in consumed: + continue + missing.append(name) + if missing: + raise ValueError( + f"{cls.__name__}: architecture-hint fields on {type(config).__name__} " + f"have no converter declaration: {missing}" + ) + + class WeightConverter: def __init__( self, @@ -76,18 +542,6 @@ def import_weight( ) -class CopyWeightConverter(WeightConverter): - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - return weight[0], *[weight[0][:].clone() for _ in range(len(self.export_name) - 1)] - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] 
- ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - return weight[0], *[weight[0][:].clone() for _ in range(len(self.fast_llm_name) - 1)] - - class SplitWeightConverter(WeightConverter): def export_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index fcb5bfaf6..efdec6c99 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -62,7 +62,7 @@ class AttentionConfig(MixerConfig): ) dense_layer: AffineLinearConfig = Field( desc="Initialization configuration for the dense layer.", - hint=FieldHint.feature, + hint=FieldHint.architecture, ) # TODO: Review names rotary: RotaryConfig = Field( @@ -115,6 +115,7 @@ class AttentionConfig(MixerConfig): " Under Standard Parameterization (SP): default to 0.5. " " Under muP (if scaling head_size size): use 1. " " Under muP (if scaling number of heads instead of head_size): use 0.5.", + hint=FieldHint.architecture, valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) implementation: AttentionImplementation = Field( diff --git a/fast_llm/layers/attention/rotary/config.py b/fast_llm/layers/attention/rotary/config.py index 80f499748..e5e5c8d34 100644 --- a/fast_llm/layers/attention/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -78,10 +78,10 @@ class Llama3RotaryConfig(DefaultRotaryConfig): """ # TODO: Add descriptions. - scale_factor: float = Field(default=8.0, hint=FieldHint.feature) - low_frequency_factor: float = Field(default=1.0, hint=FieldHint.feature) - high_frequency_factor: float = Field(default=4.0, hint=FieldHint.feature) - original_context_length: int = Field(default=8192, hint=FieldHint.feature) + scale_factor: float = Field(default=8.0, hint=FieldHint.architecture) + low_frequency_factor: float = Field(default=1.0, hint=FieldHint.architecture) + high_frequency_factor: float = Field(default=4.0, hint=FieldHint.architecture) + original_context_length: int = Field(default=8192, hint=FieldHint.architecture) def _validate(self) -> None: super()._validate() @@ -102,20 +102,20 @@ class YarnRotaryConfig(DefaultRotaryConfig): """ # TODO: Add descriptions. 
- scale_factor: float = Field(default=8.0, hint=FieldHint.feature) + scale_factor: float = Field(default=8.0, hint=FieldHint.architecture) attention_factor: None | float = Field( default=None, - hint=FieldHint.feature, + hint=FieldHint.architecture, ) beta_fast: float = Field( default=32.0, - hint=FieldHint.feature, + hint=FieldHint.architecture, ) beta_slow: float = Field( default=1.0, - hint=FieldHint.feature, + hint=FieldHint.architecture, ) - original_context_length: int = Field(default=8192, hint=FieldHint.feature) + original_context_length: int = Field(default=8192, hint=FieldHint.architecture) def _validate(self) -> None: if self.attention_factor is None: diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index aa47a5f2e..25c5fcc82 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -146,7 +146,10 @@ def last_block_config(self) -> BlockConfig: @config_class(dynamic_type={BlockSequenceConfig: "pattern"}) class PatternBlockSequenceConfig(BlockSequenceConfig): _abstract = False - blocks: dict[str, BlockConfig] = Field() + blocks: dict[str, BlockConfig] = Field( + desc="Named block configurations referenced by `pattern`.", + hint=FieldHint.architecture, + ) pattern: list[str] = Field( default=None, desc="The name of each block (key in `blocks`) in the repeated pattern.", diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 6ab259b2b..ea2ba5fa3 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -156,7 +156,7 @@ class StochasticMixerConfig(MixerConfig): "Used for inference/eval, checkpoint loading (receives pretrained weights), " "and checkpoint saving (only this mixer is exported). " "If None, uses the first mixer in the dict.", - hint=FieldHint.feature, + hint=FieldHint.architecture, ) seed_shift: int = Field( diff --git a/fast_llm/layers/decoder/mlp/config.py b/fast_llm/layers/decoder/mlp/config.py index 997cf9d2a..01f5bc052 100644 --- a/fast_llm/layers/decoder/mlp/config.py +++ b/fast_llm/layers/decoder/mlp/config.py @@ -62,7 +62,7 @@ class MLPConfig(MLPBaseConfig): activation: ActivationType = Field( default=None, desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", - hint=FieldHint.core, + hint=FieldHint.architecture, ) # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto recompute_level: MLPRecomputeLevel = Field( @@ -95,7 +95,7 @@ class MoEMLPConfig(MLPConfig): router: LinearConfig = Field( # TODO: Improve default? 
desc="Configuration for the MoE router.", - hint=FieldHint.feature, + hint=FieldHint.architecture, ) experts: int = Field( default=2, diff --git a/fast_llm/layers/vision/config.py b/fast_llm/layers/vision/config.py index 5920a85ee..47cf43391 100644 --- a/fast_llm/layers/vision/config.py +++ b/fast_llm/layers/vision/config.py @@ -34,12 +34,12 @@ class PatchEmbeddingsConfig(BlockConfig): patch_height: int = Field( default=16, desc="Height of image patches, in pixels.", - hint=FieldHint.core, + hint=FieldHint.architecture, ) patch_width: int = Field( default=16, desc="Width of image patches, in pixels.", - hint=FieldHint.core, + hint=FieldHint.architecture, ) full_precision_residual: bool = Field( default=False, diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index ac732ba22..efa801799 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -4,7 +4,14 @@ from transformers import PretrainedConfig from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + CustomConfigConverter, + DefaultConfigConverter, + IgnoredConfigConverter, + RenameConfigConverter, + WeightConverter, +) from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.decoder.config import DecoderBlockConfig @@ -22,51 +29,88 @@ from fast_llm.utils import Assert, safe_merge_dicts -class AprielMambaConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - return { - "type": "mamba", - "state_size": config["ssm_cfg"]["d_state"], - "d_inner": config["ssm_cfg"].get("d_inner") or config["hidden_size"] * config["ssm_cfg"].get("expand", 1), - "add_linear_biases": config["ssm_cfg"]["bias"], - "convolution_layer": {"bias": {"enabled": config["ssm_cfg"].get("conv_bias", True)}}, - "d_xb": config["ssm_cfg"].get("d_xb") or config["hidden_size"], - "dt_layer": {"bias": {"enabled": config["ssm_cfg"].get("dt_proj_bias", True)}}, - "dt_rank": ( - math.ceil(config["hidden_size"] / 16) - if config["ssm_cfg"].get("dt_rank", "auto") == "auto" - else config["ssm_cfg"]["dt_rank"] - ), - "repeat_kv_before_conv": config["ssm_cfg"].get("repeat_kv_before_conv", True), - } +def _resolve_bias_enabled(layer_bias_enabled: bool | None, add_linear_biases: bool) -> bool: + """Per-layer bias falls back to the mixer-wide flag when unset, matching the imperative behaviour.""" + return add_linear_biases if layer_bias_enabled is None else layer_bias_enabled + + +class AprielMambaConverter(ConfigSectionConverter): + """Converts ``MambaConfig`` <-> Apriel hybrid SSM HF dict (``ssm_cfg`` subdict + root-level fallbacks). + + A few of MambaConfig's defaults are derived from the HF root's ``hidden_size`` (``d_inner`` defaults + to ``hidden_size * expand``, ``d_xb`` defaults to ``hidden_size``, ``dt_rank="auto"`` resolves to + ``ceil(hidden_size / 16)``). Those declarations read the root HF dict directly, so each leaf + converter sees the full HF root passed by the parent block dispatcher. 
+ """ + + fast_llm_config_class = MambaConfig @classmethod - def export_config(cls, config: MambaConfig) -> dict: - cls._check_config(config) + def _create_config_converters(cls) -> dict: return { - "ssm_cfg": { - "d_state": config.state_size, - "d_inner": config.d_inner, - "bias": config.add_linear_biases, - "conv_bias": ( - config.add_linear_biases - if config.convolution_layer.bias.enabled is None - else config.convolution_layer.bias.enabled - ), - "d_xb": config.d_xb, - "dt_proj_bias": ( - config.add_linear_biases if config.dt_layer.bias.enabled is None else config.dt_layer.bias.enabled - ), - "dt_rank": config.dt_rank, - "repeat_kv_before_conv": config.repeat_kv_before_conv, - } + "state_size": RenameConfigConverter(("state_size",), ("ssm_cfg", "d_state")), + "d_inner": DefaultConfigConverter( + ("d_inner",), + ("ssm_cfg", "d_inner"), + hf_default_fn=lambda hf: hf["hidden_size"] * hf.get("ssm_cfg", {}).get("expand", 1), + ), + "d_xb": DefaultConfigConverter( + ("d_xb",), + ("ssm_cfg", "d_xb"), + hf_default_fn=lambda hf: hf["hidden_size"], + ), + "dt_rank": CustomConfigConverter( + fast_llm_paths=(("dt_rank",),), + export_fn=lambda c: {("ssm_cfg", "dt_rank"): c.dt_rank}, + import_fn=lambda hf: { + ("dt_rank",): ( + math.ceil(hf["hidden_size"] / 16) + if hf.get("ssm_cfg", {}).get("dt_rank", "auto") == "auto" + else hf["ssm_cfg"]["dt_rank"] + ) + }, + ), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("ssm_cfg", "bias")), + "repeat_kv_before_conv": DefaultConfigConverter( + ("repeat_kv_before_conv",), + ("ssm_cfg", "repeat_kv_before_conv"), + hf_default_fn=lambda hf: True, + ), + "convolution_layer": CustomConfigConverter( + fast_llm_paths=(("convolution_layer",),), + export_fn=lambda c: { + ("ssm_cfg", "conv_bias"): _resolve_bias_enabled( + c.convolution_layer.bias.enabled, c.add_linear_biases + ) + }, + import_fn=lambda hf: { + ("convolution_layer", "bias", "enabled"): hf.get("ssm_cfg", {}).get("conv_bias", True) + }, + ), + "dt_layer": CustomConfigConverter( + fast_llm_paths=(("dt_layer",),), + export_fn=lambda c: { + ("ssm_cfg", "dt_proj_bias"): _resolve_bias_enabled(c.dt_layer.bias.enabled, c.add_linear_biases) + }, + import_fn=lambda hf: { + ("dt_layer", "bias", "enabled"): hf.get("ssm_cfg", {}).get("dt_proj_bias", True) + }, + ), + # Per-layer biases that must round-trip implicitly via add_linear_biases (validated below). + "linear_layers": IgnoredConfigConverter( + ("z_layer",), + ("x_layer",), + ("b_layer",), + ("c_layer",), + ("output_layer",), + ("dt_input_layer",), + ), + # Parameter sub-configs Mamba doesn't expose to HF; coverage-only. + "parameters": IgnoredConfigConverter(("d_weight",), ("a_log_weight",)), } @classmethod - def _check_config(cls, config: MambaConfig) -> None: - # Opportunity to make derived classes less constrained. 
- Assert.is_(type(config), MambaConfig) + def _validate_export(cls, config: MambaConfig) -> None: Assert.incl(config.z_layer.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.x_layer.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.b_layer.bias.enabled, (None, config.add_linear_biases)) @@ -74,6 +118,13 @@ def _check_config(cls, config: MambaConfig) -> None: Assert.incl(config.dt_input_layer.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.output_layer.bias.enabled, (None, config.add_linear_biases)) + @classmethod + def import_config(cls, hf_dict: dict) -> dict: + # Inject the Fast-LLM dynamic-type discriminator: the parent (AprielBlockConverter) selects this + # leaf via `hybrid_block_layout`, not via a nested HF discriminator, so DispatchConfigConverter's + # auto-injection isn't in play and we must add `type` manually. + return {"type": "mamba", **super().import_config(hf_dict)} + @classmethod def get_converters( cls, @@ -99,17 +150,13 @@ def get_converters( *get_weight_and_bias_converters( f"{fast_llm_prefix}.dt_proj", f"{hf_prefix}.dt_proj", - config.add_linear_biases if config.dt_layer.bias.enabled is None else config.dt_layer.bias.enabled, + _resolve_bias_enabled(config.dt_layer.bias.enabled, config.add_linear_biases), drop_on_export=drop_on_export, ), *get_weight_and_bias_converters( f"{fast_llm_prefix}.convolution", f"{hf_prefix}.conv1d", - ( - config.add_linear_biases - if config.convolution_layer.bias.enabled is None - else config.convolution_layer.bias.enabled - ), + _resolve_bias_enabled(config.convolution_layer.bias.enabled, config.add_linear_biases), drop_on_export=drop_on_export, ), get_parameter_converter( @@ -131,31 +178,36 @@ def get_converters( ] -class GatedDeltaNetConverter: +class GatedDeltaNetConverter(ConfigSectionConverter): + """Converts ``GatedDeltaNetConfig`` <-> Apriel HF ``linear_attn_config`` subdict.""" + + fast_llm_config_class = GatedDeltaNetConfig + @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: return { - "type": "gdn", - "value_heads": config["linear_attn_config"]["gdn_num_value_heads"], - "key_heads": config["linear_attn_config"]["gdn_num_key_heads"], - "key_head_dim": config["linear_attn_config"]["gdn_key_head_dim"], - "value_head_dim": config["linear_attn_config"]["gdn_value_head_dim"], - "convolution_layer": { - "kernel_size": config["linear_attn_config"]["gdn_linear_conv_kernel_size"], - }, + "value_heads": RenameConfigConverter(("value_heads",), ("linear_attn_config", "gdn_num_value_heads")), + "key_heads": RenameConfigConverter(("key_heads",), ("linear_attn_config", "gdn_num_key_heads")), + "key_head_dim": RenameConfigConverter(("key_head_dim",), ("linear_attn_config", "gdn_key_head_dim")), + "value_head_dim": RenameConfigConverter(("value_head_dim",), ("linear_attn_config", "gdn_value_head_dim")), + "convolution_kernel_size": RenameConfigConverter( + ("convolution_layer", "kernel_size"), + ("linear_attn_config", "gdn_linear_conv_kernel_size"), + ), + # Sub-configs without HF representation; coverage-only. 
+ "sub_configs": IgnoredConfigConverter( + ("normalization",), + ("qkv_projection_layer",), + ("ba_projection_layer",), + ("output_layer",), + ("dt_bias_weight",), + ("a_log_weight",), + ), } @classmethod - def export_config(cls, config: GatedDeltaNetConfig) -> dict: - return { - "linear_attn_config": { - "gdn_num_value_heads": config.value_heads, - "gdn_num_key_heads": config.key_heads, - "gdn_key_head_dim": config.key_head_dim, - "gdn_value_head_dim": config.value_head_dim, - "gdn_linear_conv_kernel_size": config.convolution_layer.kernel_size, - }, - } + def import_config(cls, hf_dict: dict) -> dict: + return {"type": "gdn", **super().import_config(hf_dict)} @classmethod def get_converters( @@ -209,27 +261,40 @@ def get_converters( ] -class KimiDeltaAttentionConverter: +class KimiDeltaAttentionConverter(ConfigSectionConverter): + """Converts ``KimiDeltaAttentionConfig`` <-> Apriel HF ``linear_attn_config`` subdict.""" + + fast_llm_config_class = KimiDeltaAttentionConfig + @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: return { - "type": "kda", - "head_dim": config["linear_attn_config"]["head_dim"], - "heads": config["linear_attn_config"]["num_heads"], - "convolution_layer": { - "kernel_size": config["linear_attn_config"]["short_conv_kernel_size"], - }, + "head_dim": RenameConfigConverter(("head_dim",), ("linear_attn_config", "head_dim")), + "heads": RenameConfigConverter(("heads",), ("linear_attn_config", "num_heads")), + "convolution_kernel_size": RenameConfigConverter( + ("convolution_layer", "kernel_size"), + ("linear_attn_config", "short_conv_kernel_size"), + ), + # Sub-configs without HF representation; coverage-only. + "sub_configs": IgnoredConfigConverter( + ("normalization",), + ("q_projection_layer",), + ("k_projection_layer",), + ("v_projection_layer",), + ("f_a_projection_layer",), + ("f_b_projection_layer",), + ("g_a_projection_layer",), + ("g_b_projection_layer",), + ("beta_projection_layer",), + ("output_projection_layer",), + ("dt_bias_weight",), + ("a_log_weight",), + ), } @classmethod - def export_config(cls, config: KimiDeltaAttentionConfig) -> dict: - return { - "linear_attn_config": { - "head_dim": config.head_dim, - "num_heads": config.heads, - "short_conv_kernel_size": config.convolution_layer.kernel_size, - }, - } + def import_config(cls, hf_dict: dict) -> dict: + return {"type": "kda", **super().import_config(hf_dict)} @classmethod def get_converters( @@ -347,6 +412,11 @@ class AprielGatedDeltaNetBlockConverter(MistralBlockConverter): class AprielBlockConverter: + """Per-block dispatcher: the mixer type is encoded in the parent's ``hybrid_block_layout`` list, + not in a nested HF discriminator, so this dispatcher stays imperative rather than using + :class:`DispatchConfigConverter`. Each branch delegates to a regular declarative block converter. + """ + layout_names = { AttentionConfig: "t", MambaConfig: "m2", @@ -382,6 +452,11 @@ def get_converters( class AprielDecoderConverter(MistralDecoderConverter): + """Pattern-style decoder dispatched via Apriel's ``hybrid_block_layout`` list (one entry per block). + Stays imperative because the layout-list shape doesn't match the declarative ``decoder.type`` + discriminator that Apriel2 uses. 
+ """ + block_converter_class: typing.ClassVar[type[AprielBlockConverter]] = AprielBlockConverter @classmethod @@ -413,7 +488,8 @@ def export_config(cls, config: BlockSequenceConfig) -> dict: pattern_block_configs = [config.blocks[block_name] for block_name in config.pattern] else: raise NotImplementedError() - # There may be all sorts of blocks, but `safe_merge_dicts` ensures they are compatible. + # Each block emits non-overlapping HF keys (attention -> flat, mamba -> ssm_cfg.*, + # gdn/kda -> linear_attn_config.*) so safe_merge_dicts is sufficient to combine them. return safe_merge_dicts( *[cls.block_converter_class.export_config(block_config) for block_config in block_configs], { diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index 9b6657b03..86b4caf4f 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -5,10 +5,32 @@ from transformers import PretrainedConfig from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + ConstantExportConfigConverter, + ConstantImportConfigConverter, + CustomConfigConverter, + DispatchConfigConverter, + IgnoredConfigConverter, + NestedConfigConverter, + OptionalConfigConverter, + RenameConfigConverter, + TypedDictContainerConfigConverter, + WeightConverter, +) from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import AttentionConfig -from fast_llm.layers.decoder.config import DecoderBlockConfig, StochasticMixerConfig, StochasticMixerSamplingStrategy +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig +from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig +from fast_llm.layers.common.normalization.config import ( + LayerNormalizationConfig, + NoNormalizationConfig, + RMSNormalizationConfig, +) +from fast_llm.layers.decoder.config import DecoderBlockConfig, StochasticMixerConfig +from fast_llm.layers.decoder.mlp.config import MLPConfig +from fast_llm.layers.language_model.config import LanguageModelHeadConfig from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.conversion.config import Apriel2TextCheckpointFormat @@ -17,7 +39,6 @@ LlamaEmbeddingsConverter, LlamaNormalizationConverter, MLPLayer2Converter, - QueryWeightConverter, SplitWeightConverter, get_parameter_converter, get_weight_and_bias_converters, @@ -25,82 +46,94 @@ from fast_llm.models.gpt.model import GPTModel from fast_llm.utils import Assert, safe_merge_dicts +# ============================================================ +# Helpers +# ============================================================ -class Apriel2AttentionConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - rotary = config["rotary"] - # Map Apriel2 HuggingFace rotary type to Fast-LLM internal type - if rotary.get("type") == "mistral_1d": - rotary = {**rotary, "type": "default"} - result = { - "type": "attention", - "heads": config["heads"], - "head_groups": config["head_groups"], - "head_size": config["head_size"], - "rotary": rotary, - } - # Per-layer bias configuration mirroring Fast-LLM 
structure - # If per-layer configs exist, use them; otherwise fall back to add_linear_biases - if "query_layer" in config: - result["query_layer"] = config["query_layer"] - if "key_layer" in config: - result["key_layer"] = config["key_layer"] - if "value_layer" in config: - result["value_layer"] = config["value_layer"] - if "dense_layer" in config: - result["dense_layer"] = config["dense_layer"] - # add_linear_biases serves as default for layers without explicit config - if "add_linear_biases" in config: - result["add_linear_biases"] = config["add_linear_biases"] - if "window_size" in config: - result["window_size"] = config["window_size"] - return result + +def _per_layer_bias_export(config, layer_names: tuple[str, ...]) -> dict: + """Emit per-layer ``{layer: {"bias": {"enabled": bool}}}`` only for layers whose bias is explicitly set.""" + out: dict = {} + for layer_name in layer_names: + layer = getattr(config, layer_name) + if layer.bias.enabled is not None: + out[(layer_name,)] = {"bias": {"enabled": layer.bias.enabled}} + return out + + +def _per_layer_bias_import(hf_dict: dict, layer_names: tuple[str, ...]) -> dict: + """Pass through HF ``{layer: {"bias": {...}}}`` entries to the Fast-LLM dict.""" + out: dict = {} + for layer_name in layer_names: + if layer_name in hf_dict: + out[(layer_name,)] = hf_dict[layer_name] + return out + + +# ============================================================ +# Mixer converters +# ============================================================ + + +def _apriel2_attention_rotary_export(config: AttentionConfig) -> dict: + """Emit Apriel2's typed rotary subdict. + + Asymmetric with the Fast-LLM type only for the default→``mistral_1d`` rename; ``llama3``/``yarn`` round-trip + by name. Mirrors current behavior: only ``type`` and ``theta`` are emitted (scale fields are dropped). 
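+ For example, a ``DefaultRotaryConfig`` with ``theta=1e6`` exports as ``{"type": "mistral_1d", "theta": 1000000.0}``
+ under the ``rotary`` key, and the import path maps ``"mistral_1d"`` back to ``"default"``.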
+ """ + rotary = config.rotary + if type(rotary) is DefaultRotaryConfig: + rotary_type = "mistral_1d" + elif type(rotary) is Llama3RotaryConfig: + rotary_type = "llama3" + elif type(rotary) is YarnRotaryConfig: + rotary_type = "yarn" + else: + raise NotImplementedError(f"Unsupported rotary type: {type(rotary).__name__}") + return {("rotary",): {"type": rotary_type, "theta": rotary.theta}} + + +def _apriel2_attention_rotary_import(hf_dict: dict) -> dict: + rotary = dict(hf_dict["rotary"]) + if rotary.get("type") == "mistral_1d": + rotary["type"] = "default" + return {("rotary",): rotary} + + +class Apriel2AttentionConverter(ConfigSectionConverter): + fast_llm_config_class = AttentionConfig + hf_type_name = "attention" @classmethod - def export_config(cls, config: AttentionConfig) -> dict: - from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig - - if type(config.rotary) is DefaultRotaryConfig: - rotary_type = "mistral_1d" - elif type(config.rotary) is Llama3RotaryConfig: - rotary_type = "llama3" - elif type(config.rotary) is YarnRotaryConfig: - rotary_type = "yarn" - else: - raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") - - result = { - "type": "attention", - "heads": config.heads, - "head_groups": config.head_groups, - "head_size": config.head_size, - "rotary": { - "type": rotary_type, - "theta": config.rotary.theta, - }, + def _create_config_converters(cls) -> dict: + layer_names = ("query_layer", "key_layer", "value_layer", "dense_layer") + return { + "heads": RenameConfigConverter(("heads",), ("heads",)), + "head_groups": RenameConfigConverter(("head_groups",), ("head_groups",)), + "head_size": RenameConfigConverter(("head_size",), ("head_size",)), + "rotary": CustomConfigConverter( + fast_llm_paths=(("rotary",),), + export_fn=_apriel2_attention_rotary_export, + import_fn=_apriel2_attention_rotary_import, + ), + # Apriel2 emits add_linear_biases only when False; the True default is implicit. + "add_linear_biases": OptionalConfigConverter( + ("add_linear_biases",), ("add_linear_biases",), sentinel=True + ), + "window_size": OptionalConfigConverter(("window_size",), ("window_size",)), + "linear_layers": CustomConfigConverter( + fast_llm_paths=tuple((name,) for name in layer_names), + export_fn=lambda c: _per_layer_bias_export(c, layer_names), + import_fn=lambda hf: _per_layer_bias_import(hf, layer_names), + ), + "causal": IgnoredConfigConverter(("causal",)), + "softmax_scale_power": IgnoredConfigConverter(("softmax_scale_power",)), } - if config.window_size is not None: - result["window_size"] = config.window_size - # Export per-layer bias configuration - # Only include if explicitly set (not None) - if config.query_layer.bias.enabled is not None: - result["query_layer"] = {"bias": {"enabled": config.query_layer.bias.enabled}} - if config.key_layer.bias.enabled is not None: - result["key_layer"] = {"bias": {"enabled": config.key_layer.bias.enabled}} - if config.value_layer.bias.enabled is not None: - result["value_layer"] = {"bias": {"enabled": config.value_layer.bias.enabled}} - if config.dense_layer.bias.enabled is not None: - result["dense_layer"] = {"bias": {"enabled": config.dense_layer.bias.enabled}} - # add_linear_biases as fallback default; omit when True (the Fast-LLM default) to avoid - # round-trip inflation on configs that don't set it explicitly. 
- if not config.add_linear_biases: - result["add_linear_biases"] = config.add_linear_biases - return result + + # --- weight side (imperative) --- @classmethod def _get_effective_bias(cls, layer_config, default: bool) -> bool: - """Get effective bias setting: use layer-specific if set, else default.""" if layer_config.bias.enabled is not None: return layer_config.bias.enabled return default @@ -113,13 +146,11 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - # Determine effective bias for each projection q_bias = cls._get_effective_bias(config.query_layer, config.add_linear_biases) k_bias = cls._get_effective_bias(config.key_layer, config.add_linear_biases) v_bias = cls._get_effective_bias(config.value_layer, config.add_linear_biases) o_bias = cls._get_effective_bias(config.dense_layer, config.add_linear_biases) - # For key_value, both k and v must have same bias setting - # (they're combined in Fast-LLM's key_value layer) + # k_proj and v_proj are merged in Fast-LLM's key_value layer; treat as biased only if both sides agree. kv_bias = k_bias and v_bias return [ @@ -127,8 +158,6 @@ def get_converters( f"{fast_llm_prefix}.query", f"{hf_prefix}.q_proj", q_bias, - QueryWeightConverter, - config, drop_on_export=drop_on_export, ), *get_weight_and_bias_converters( @@ -148,40 +177,50 @@ def get_converters( ] -class Apriel2MambaConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - result = { - "type": "mamba", - "state_size": config["state_size"], - "d_inner": config["d_inner"], - "add_linear_biases": config["add_linear_biases"], - } - if "d_xb" in config: - result["d_xb"] = config["d_xb"] - if "dt_rank" in config: - result["dt_rank"] = config["dt_rank"] - return result +def _apriel2_mamba_aux_export(config: MambaConfig) -> dict: + """Emit Apriel2's mamba-specific HF auxiliaries (``d_conv`` from convolution kernel size, plus the + convolution and dt-projection effective bias flags). 
These have no flat Fast-LLM analogue.""" + return { + ("d_conv",): config.convolution_layer.kernel_size, + ("conv_bias",): config.convolution_layer.bias.enabled, + ("dt_proj_bias",): config.dt_layer.bias.enabled, + } - @classmethod - def export_config(cls, config: MambaConfig) -> dict: - exported = { - "type": "mamba", - "state_size": config.state_size, - "d_inner": config.d_inner, - "d_conv": config.convolution_layer.kernel_size, - "add_linear_biases": config.add_linear_biases, - "conv_bias": config.convolution_layer.bias.enabled, - "dt_proj_bias": config.dt_layer.bias.enabled, - } - if config.d_xb is not None: - exported["d_xb"] = config.d_xb +class Apriel2MambaConverter(ConfigSectionConverter): + fast_llm_config_class = MambaConfig + hf_type_name = "mamba" - if config.dt_rank != "auto": - exported["dt_rank"] = config.dt_rank + @classmethod + def _create_config_converters(cls) -> dict: + return { + "state_size": RenameConfigConverter(("state_size",), ("state_size",)), + "d_inner": RenameConfigConverter(("d_inner",), ("d_inner",)), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("add_linear_biases",)), + "d_xb": OptionalConfigConverter(("d_xb",), ("d_xb",)), + "dt_rank": OptionalConfigConverter(("dt_rank",), ("dt_rank",)), + "aux": CustomConfigConverter( + fast_llm_paths=(("convolution_layer",), ("dt_layer",)), + export_fn=_apriel2_mamba_aux_export, + # The d_conv/conv_bias/dt_proj_bias HF fields are not reflected in the Fast-LLM mamba dict — + # current Apriel2 import simply ignores them and lets Fast-LLM use its own defaults. + import_fn=lambda hf: {}, + ), + # Architecture fields with no HF counterpart; they round-trip at their Fast-LLM defaults. + "layers_unmapped": IgnoredConfigConverter( + ("z_layer",), + ("x_layer",), + ("b_layer",), + ("c_layer",), + ("output_layer",), + ("dt_input_layer",), + ("a_log_weight",), + ("d_weight",), + ("repeat_kv_before_conv",), + ), + } - return exported + # --- weight side (imperative) --- @classmethod def get_converters( @@ -235,33 +274,37 @@ def get_converters( ] -class Apriel2GatedDeltaNetConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - result = { - "type": "gdn", - "value_heads": config["value_heads"], - "key_heads": config["key_heads"], - "key_head_dim": config["key_head_dim"], - "value_head_dim": config["value_head_dim"], - } - if "convolution_layer" in config: - result["convolution_layer"] = config["convolution_layer"] - return result +class Apriel2GatedDeltaNetConverter(ConfigSectionConverter): + fast_llm_config_class = GatedDeltaNetConfig + hf_type_name = "gdn" @classmethod - def export_config(cls, config: GatedDeltaNetConfig) -> dict: + def _create_config_converters(cls) -> dict: return { - "type": "gdn", - "value_heads": config.value_heads, - "key_heads": config.key_heads, - "key_head_dim": config.key_head_dim, - "value_head_dim": config.value_head_dim, - "convolution_layer": { - "kernel_size": config.convolution_layer.kernel_size, - }, + "value_heads": RenameConfigConverter(("value_heads",), ("value_heads",)), + "key_heads": RenameConfigConverter(("key_heads",), ("key_heads",)), + "key_head_dim": RenameConfigConverter(("key_head_dim",), ("key_head_dim",)), + "value_head_dim": RenameConfigConverter(("value_head_dim",), ("value_head_dim",)), + "convolution_layer": CustomConfigConverter( + fast_llm_paths=(("convolution_layer",),), + export_fn=lambda c: {("convolution_layer",): {"kernel_size": c.convolution_layer.kernel_size}}, + import_fn=lambda hf: ( + {("convolution_layer",): 
hf["convolution_layer"]} if "convolution_layer" in hf else {} + ), + ), + # Architecture fields not surfaced in HF; round-trip at default. + "layers_unmapped": IgnoredConfigConverter( + ("normalization",), + ("qkv_projection_layer",), + ("ba_projection_layer",), + ("output_layer",), + ("dt_bias_weight",), + ("a_log_weight",), + ), } + # --- weight side (imperative) --- + @classmethod def get_converters( cls, @@ -314,34 +357,45 @@ def get_converters( ] -class Apriel2KimiDeltaAttentionConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - result = { - "type": "kda", - "heads": config["heads"], - "head_dim": config["head_dim"], - } - if "convolution_layer" in config: - result["convolution_layer"] = config["convolution_layer"] - if "normalization" in config: - result["normalization"] = config["normalization"] - return result +class Apriel2KimiDeltaAttentionConverter(ConfigSectionConverter): + fast_llm_config_class = KimiDeltaAttentionConfig + hf_type_name = "kda" @classmethod - def export_config(cls, config: KimiDeltaAttentionConfig) -> dict: + def _create_config_converters(cls) -> dict: return { - "type": "kda", - "heads": config.heads, - "head_dim": config.head_dim, - "convolution_layer": { - "kernel_size": config.convolution_layer.kernel_size, - }, - "normalization": { - "epsilon": config.normalization.epsilon, - }, + "heads": RenameConfigConverter(("heads",), ("heads",)), + "head_dim": RenameConfigConverter(("head_dim",), ("head_dim",)), + "convolution_layer": CustomConfigConverter( + fast_llm_paths=(("convolution_layer",),), + export_fn=lambda c: {("convolution_layer",): {"kernel_size": c.convolution_layer.kernel_size}}, + import_fn=lambda hf: ( + {("convolution_layer",): hf["convolution_layer"]} if "convolution_layer" in hf else {} + ), + ), + "normalization": CustomConfigConverter( + fast_llm_paths=(("normalization",),), + export_fn=lambda c: {("normalization",): {"epsilon": c.normalization.epsilon}}, + import_fn=lambda hf: ({("normalization",): hf["normalization"]} if "normalization" in hf else {}), + ), + # Architecture fields not surfaced in HF; round-trip at default. 
+ "layers_unmapped": IgnoredConfigConverter( + ("q_projection_layer",), + ("k_projection_layer",), + ("v_projection_layer",), + ("f_a_projection_layer",), + ("f_b_projection_layer",), + ("g_a_projection_layer",), + ("g_b_projection_layer",), + ("beta_projection_layer",), + ("output_projection_layer",), + ("dt_bias_weight",), + ("a_log_weight",), + ), } + # --- weight side (imperative) --- + @classmethod def get_converters( cls, @@ -350,11 +404,7 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - # Fast-LLM KDA uses abbreviated names matching the external module: - # q_proj, k_proj, v_proj, q_conv, k_conv, v_conv, f_a_proj, f_b_proj, - # g_a_proj, g_b_proj, beta_proj, o_proj, A_log, dt_bias, norm return [ - # Q/K/V projections *get_weight_and_bias_converters( f"{fast_llm_prefix}.q_proj", f"{hf_prefix}.q_proj", @@ -373,7 +423,6 @@ def get_converters( False, drop_on_export=drop_on_export, ), - # Convolutions (Q, K, V) *get_weight_and_bias_converters( f"{fast_llm_prefix}.q_conv", f"{hf_prefix}.q_conv", @@ -392,7 +441,6 @@ def get_converters( False, drop_on_export=drop_on_export, ), - # Gate projections (f_a, f_b, g_a, g_b) *get_weight_and_bias_converters( f"{fast_llm_prefix}.f_a_proj", f"{hf_prefix}.f_a_proj", @@ -417,21 +465,18 @@ def get_converters( False, drop_on_export=drop_on_export, ), - # Beta projection *get_weight_and_bias_converters( f"{fast_llm_prefix}.beta_proj", f"{hf_prefix}.beta_proj", False, drop_on_export=drop_on_export, ), - # Output projection *get_weight_and_bias_converters( f"{fast_llm_prefix}.o_proj", f"{hf_prefix}.o_proj", False, drop_on_export=drop_on_export, ), - # Learnable parameters get_parameter_converter( f"{fast_llm_prefix}.A_log", f"{hf_prefix}.A_log", @@ -442,7 +487,6 @@ def get_converters( f"{hf_prefix}.dt_bias", drop_on_export=drop_on_export, ), - # Normalization *LlamaNormalizationConverter.get_converters( config.normalization, f"{fast_llm_prefix}.norm", @@ -452,56 +496,38 @@ def get_converters( ] -class Apriel2StochasticMixerConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - mixers = {} - for name, sub_mixer_config in config["mixers"].items(): - mixer_type = sub_mixer_config["type"] - if mixer_type == "attention": - mixers[name] = Apriel2AttentionConverter.import_config(sub_mixer_config) - elif mixer_type == "mamba": - mixers[name] = Apriel2MambaConverter.import_config(sub_mixer_config) - elif mixer_type == "gdn": - mixers[name] = Apriel2GatedDeltaNetConverter.import_config(sub_mixer_config) - elif mixer_type == "kda": - mixers[name] = Apriel2KimiDeltaAttentionConverter.import_config(sub_mixer_config) - else: - raise ValueError(f"Unknown sub-mixer type: {mixer_type}") - - result = { - "type": "stochastic", - "mixers": mixers, - "main_mixer_name": config["main_mixer_name"], - } - if "sampling_strategy" in config: - result["sampling_strategy"] = config["sampling_strategy"] - return result +# Mixer dispatch registry — used inside StochasticMixer (no nested-stochastic) and at the block level. 
+APRIEL2_LEAF_MIXER_REGISTRY: dict = { + AttentionConfig: Apriel2AttentionConverter, + MambaConfig: Apriel2MambaConverter, + GatedDeltaNetConfig: Apriel2GatedDeltaNetConverter, + KimiDeltaAttentionConfig: Apriel2KimiDeltaAttentionConverter, +} + + +class Apriel2StochasticMixerConverter(ConfigSectionConverter): + fast_llm_config_class = StochasticMixerConfig + hf_type_name = "stochastic" @classmethod - def export_config(cls, config: StochasticMixerConfig) -> dict: - mixers = {} - for name, sub_mixer in config.mixers.items(): - mixer_type = type(sub_mixer) - if mixer_type is AttentionConfig: - mixers[name] = Apriel2AttentionConverter.export_config(sub_mixer) - elif mixer_type is MambaConfig: - mixers[name] = Apriel2MambaConverter.export_config(sub_mixer) - elif mixer_type is GatedDeltaNetConfig: - mixers[name] = Apriel2GatedDeltaNetConverter.export_config(sub_mixer) - elif mixer_type is KimiDeltaAttentionConfig: - mixers[name] = Apriel2KimiDeltaAttentionConverter.export_config(sub_mixer) - else: - raise ValueError(f"Unknown sub-mixer type: {mixer_type}") - - result = { - "type": "stochastic", - "mixers": mixers, - "main_mixer_name": config.main_mixer_name, + def _create_config_converters(cls) -> dict: + from fast_llm.layers.decoder.config import StochasticMixerSamplingStrategy + + return { + "mixers": TypedDictContainerConfigConverter( + fast_llm_path=("mixers",), + hf_path=("mixers",), + registry=APRIEL2_LEAF_MIXER_REGISTRY, + ), + "main_mixer_name": RenameConfigConverter(("main_mixer_name",), ("main_mixer_name",)), + "sampling_strategy": OptionalConfigConverter( + ("sampling_strategy",), + ("sampling_strategy",), + sentinel=StochasticMixerSamplingStrategy.uniform, + ), } - if config.sampling_strategy != StochasticMixerSamplingStrategy.uniform: - result["sampling_strategy"] = config.sampling_strategy.value - return result + + # --- weight side (imperative) --- @classmethod def get_converters( @@ -513,136 +539,128 @@ def get_converters( ) -> list[WeightConverter]: converters = [] for name, sub_mixer in config.mixers.items(): - mixer_type = type(sub_mixer) - if mixer_type is AttentionConfig: - converter_class = Apriel2AttentionConverter - hf_sub_mixer_prefix = f"{hf_prefix}.mixers.{name}" - elif mixer_type is MambaConfig: - converter_class = Apriel2MambaConverter - hf_sub_mixer_prefix = f"{hf_prefix}.mixers.{name}" - elif mixer_type is GatedDeltaNetConfig: - converter_class = Apriel2GatedDeltaNetConverter - hf_sub_mixer_prefix = f"{hf_prefix}.mixers.{name}" - elif mixer_type is KimiDeltaAttentionConfig: - converter_class = Apriel2KimiDeltaAttentionConverter - hf_sub_mixer_prefix = f"{hf_prefix}.mixers.{name}" - else: - raise ValueError(f"Unknown sub-mixer type: {mixer_type}") + converter_class = APRIEL2_LEAF_MIXER_REGISTRY.get(type(sub_mixer)) + if converter_class is None: + raise ValueError(f"Unknown sub-mixer type: {type(sub_mixer)}") converters.extend( converter_class.get_converters( sub_mixer, f"{fast_llm_prefix}.mixers.{name}", - hf_sub_mixer_prefix, + f"{hf_prefix}.mixers.{name}", drop_on_export=drop_on_export, ) ) - return converters -class Apriel2BlockConverter: - @classmethod - def import_config(cls, config: dict, block_config: dict) -> dict: - mixer_config = block_config["mixer"] - mixer_type = mixer_config["type"] - - if mixer_type == "attention": - mixer = Apriel2AttentionConverter.import_config(mixer_config) - elif mixer_type == "mamba": - mixer = Apriel2MambaConverter.import_config(mixer_config) - elif mixer_type == "stochastic": - mixer = 
Apriel2StochasticMixerConverter.import_config(mixer_config) - elif mixer_type == "gdn": - mixer = Apriel2GatedDeltaNetConverter.import_config(mixer_config) - elif mixer_type == "kda": - mixer = Apriel2KimiDeltaAttentionConverter.import_config(mixer_config) - else: - raise ValueError(f"Unknown mixer type: {mixer_type}") +# Block-level mixer registry includes StochasticMixer (which can wrap leaf mixers). +APRIEL2_BLOCK_MIXER_REGISTRY: dict = { + **APRIEL2_LEAF_MIXER_REGISTRY, + StochasticMixerConfig: Apriel2StochasticMixerConverter, +} + - from fast_llm.functional.config import ActivationType +# ============================================================ +# Normalization converters +# ============================================================ - mlp_config = block_config["mlp"] - mlp = { - "type": "mlp", - "intermediate_size": mlp_config["intermediate_size"], - "activation": ActivationType.from_hf_name(mlp_config["activation"]), - "gated": mlp_config["gated"], - "add_linear_biases": mlp_config["add_linear_biases"], + +class Apriel2RMSNormConverter(ConfigSectionConverter): + fast_llm_config_class = RMSNormalizationConfig + hf_type_name = "rms_norm" + + @classmethod + def _create_config_converters(cls) -> dict: + return { + "epsilon": RenameConfigConverter(("epsilon",), ("epsilon",)), + "weight": IgnoredConfigConverter(("weight",)), + "zero_centered": ConstantImportConfigConverter(("zero_centered",), False), } - # Import per-layer MLP bias settings (layer_1, layer_2) - for layer_name in ("layer_1", "layer_2"): - if layer_name in mlp_config: - layer_cfg = mlp_config[layer_name] - if "bias" in layer_cfg: - mlp[layer_name] = {"bias": layer_cfg["bias"]} - normalization = block_config["normalization"] +class Apriel2LayerNormConverter(ConfigSectionConverter): + fast_llm_config_class = LayerNormalizationConfig + hf_type_name = "layer_norm" + + @classmethod + def _create_config_converters(cls) -> dict: return { - "mixer": mixer, - "mlp": mlp, - "normalization": normalization, + "epsilon": RenameConfigConverter(("epsilon",), ("epsilon",)), + "weight": IgnoredConfigConverter(("weight",)), + "bias": IgnoredConfigConverter(("bias",)), + "zero_centered": ConstantImportConfigConverter(("zero_centered",), False), } + +class Apriel2NoNormConverter(ConfigSectionConverter): + fast_llm_config_class = NoNormalizationConfig + hf_type_name = "none" + @classmethod - def export_config(cls, config: DecoderBlockConfig) -> dict: - from fast_llm.layers.common.normalization.config import ( - LayerNormalizationConfig, - NoNormalizationConfig, - RMSNormalizationConfig, - ) + def _create_config_converters(cls) -> dict: + return {} - mixer_type = type(config.mixer) - - if mixer_type is AttentionConfig: - mixer = Apriel2AttentionConverter.export_config(config.mixer) - elif mixer_type is MambaConfig: - mixer = Apriel2MambaConverter.export_config(config.mixer) - elif mixer_type is StochasticMixerConfig: - mixer = Apriel2StochasticMixerConverter.export_config(config.mixer) - elif mixer_type is GatedDeltaNetConfig: - mixer = Apriel2GatedDeltaNetConverter.export_config(config.mixer) - elif mixer_type is KimiDeltaAttentionConfig: - mixer = Apriel2KimiDeltaAttentionConverter.export_config(config.mixer) - else: - raise ValueError(f"Unknown mixer type: {mixer_type}") - - norm_type = type(config.normalization) - if norm_type is RMSNormalizationConfig: - norm_type_str = "rms_norm" - elif norm_type is LayerNormalizationConfig: - norm_type_str = "layer_norm" - elif norm_type is NoNormalizationConfig: - norm_type_str = "none" - else: - 
raise ValueError(f"Unknown normalization type: {norm_type}") - from fast_llm.layers.decoder.mlp.config import MLPConfig +APRIEL2_NORM_REGISTRY: dict = { + RMSNormalizationConfig: Apriel2RMSNormConverter, + LayerNormalizationConfig: Apriel2LayerNormConverter, + NoNormalizationConfig: Apriel2NoNormConverter, +} - if not isinstance(config.mlp, MLPConfig): - raise ValueError(f"Unsupported MLP type: {type(config.mlp)}") - mlp = { - "type": "mlp", - "intermediate_size": config.mlp.intermediate_size, - "activation": config.mlp.activation.value, - "gated": config.mlp.gated, - "add_linear_biases": config.mlp.add_linear_biases, +# ============================================================ +# MLP, Block, Decoder, Head +# ============================================================ + + +class Apriel2MLPConverter(ConfigSectionConverter): + fast_llm_config_class = MLPConfig + hf_type_name = "mlp" + + @classmethod + def _create_config_converters(cls) -> dict: + layer_names = ("layer_1", "layer_2") + return { + # MLP is wrapped via NestedConfigConverter (no Dispatch discriminator), so emit the HF + # ``"type": "mlp"`` discriminator from inside this converter. + "hf_type": ConstantExportConfigConverter(("type",), "mlp"), + "intermediate_size": RenameConfigConverter(("intermediate_size",), ("intermediate_size",)), + "gated": RenameConfigConverter(("gated",), ("gated",)), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("add_linear_biases",)), + "activation": CustomConfigConverter( + fast_llm_paths=(("activation",),), + export_fn=lambda c: {("activation",): c.activation.hf_name}, + import_fn=lambda hf: {("activation",): ActivationType.from_hf_name(hf["activation"])}, + ), + "layers": CustomConfigConverter( + fast_llm_paths=tuple((name,) for name in layer_names), + export_fn=lambda c: _per_layer_bias_export(c, layer_names), + import_fn=lambda hf: _per_layer_bias_import(hf, layer_names), + ), } - # Export per-layer MLP bias settings (layer_1, layer_2) - if config.mlp.layer_1.bias.enabled is not None: - mlp["layer_1"] = {"bias": {"enabled": config.mlp.layer_1.bias.enabled}} - if config.mlp.layer_2.bias.enabled is not None: - mlp["layer_2"] = {"bias": {"enabled": config.mlp.layer_2.bias.enabled}} - normalization = {"type": norm_type_str, "epsilon": config.normalization.epsilon} +class Apriel2BlockConverter(ConfigSectionConverter): + fast_llm_config_class = DecoderBlockConfig + + @classmethod + def _create_config_converters(cls) -> dict: return { - "mixer": mixer, - "mlp": mlp, - "normalization": normalization, + "mixer": DispatchConfigConverter( + fast_llm_path=("mixer",), + hf_path=("mixer",), + registry=APRIEL2_BLOCK_MIXER_REGISTRY, + ), + "mlp": NestedConfigConverter(("mlp",), Apriel2MLPConverter, hf_path=("mlp",)), + "normalization": DispatchConfigConverter( + fast_llm_path=("normalization",), + hf_path=("normalization",), + registry=APRIEL2_NORM_REGISTRY, + ), } + # --- weight side (imperative) --- + @classmethod def get_converters( cls, @@ -651,46 +669,30 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - converters = [] - mixer_type = type(config.mixer) - if mixer_type is AttentionConfig: - converter_class = Apriel2AttentionConverter - hf_mixer_prefix = f"{hf_prefix}.mixer" - elif mixer_type is MambaConfig: - converter_class = Apriel2MambaConverter - hf_mixer_prefix = f"{hf_prefix}.mixer" - elif mixer_type is StochasticMixerConfig: - converter_class = Apriel2StochasticMixerConverter - hf_mixer_prefix = f"{hf_prefix}.mixer" - elif 
mixer_type is GatedDeltaNetConfig: - converter_class = Apriel2GatedDeltaNetConverter - hf_mixer_prefix = f"{hf_prefix}.mixer" - elif mixer_type is KimiDeltaAttentionConfig: - converter_class = Apriel2KimiDeltaAttentionConverter - hf_mixer_prefix = f"{hf_prefix}.mixer" - else: - raise ValueError(f"Unknown mixer type: {mixer_type}") - - converters.extend( + converter_class = APRIEL2_BLOCK_MIXER_REGISTRY.get(type(config.mixer)) + if converter_class is None: + raise ValueError(f"Unknown mixer type: {type(config.mixer)}") + converters: list[WeightConverter] = list( converter_class.get_converters( config.mixer, f"{fast_llm_prefix}.mixer", - hf_mixer_prefix, + f"{hf_prefix}.mixer", drop_on_export=drop_on_export, ) ) - # Per-layer MLP bias: use layer-specific setting if set, else default - def get_mlp_layer_bias(layer_config, default: bool) -> bool: - if layer_config.bias.enabled is not None: - return layer_config.bias.enabled - return default - - layer_1_bias = get_mlp_layer_bias(config.mlp.layer_1, config.mlp.add_linear_biases) - layer_2_bias = get_mlp_layer_bias(config.mlp.layer_2, config.mlp.add_linear_biases) + layer_1_bias = ( + config.mlp.layer_1.bias.enabled + if config.mlp.layer_1.bias.enabled is not None + else config.mlp.add_linear_biases + ) + layer_2_bias = ( + config.mlp.layer_2.bias.enabled + if config.mlp.layer_2.bias.enabled is not None + else config.mlp.add_linear_biases + ) if config.mlp.gated: - # Gated MLP: gate_proj + up_proj -> layer_1 (split), down_proj -> layer_2 converters.extend( [ *get_weight_and_bias_converters( @@ -710,8 +712,6 @@ def get_mlp_layer_bias(layer_config, default: bool) -> bool: ] ) else: - # Non-gated MLP: up_proj -> layer_1, down_proj -> layer_2 - # Note: layer_2 still needs MLPLayer2Converter for the transpose converters.extend( [ *get_weight_and_bias_converters( @@ -747,73 +747,52 @@ def get_mlp_layer_bias(layer_config, default: bool) -> bool: ), ] ) - return converters -class Apriel2DecoderConverter: - block_converter_class: typing.ClassVar[type[Apriel2BlockConverter]] = Apriel2BlockConverter +class Apriel2FixedDecoderConverter(ConfigSectionConverter): + fast_llm_config_class = FixedBlockSequenceConfig + hf_type_name = "fixed" @classmethod - def import_config(cls, config: dict) -> dict: - decoder_config = config["decoder"] - decoder_type = decoder_config["type"] - - if decoder_type == "fixed": - block_config = decoder_config["block"] - imported_block = cls.block_converter_class.import_config(config, block_config) - - return { - "type": "fixed", - "num_blocks": decoder_config["num_blocks"], - "block": imported_block, - } - - elif decoder_type == "pattern": - blocks = {} - for name, block_config in decoder_config["blocks"].items(): - blocks[name] = cls.block_converter_class.import_config(config, block_config) - - return { - "type": "pattern", - "blocks": blocks, - "pattern": decoder_config["pattern"], - "num_blocks": decoder_config["num_blocks"], - } + def _create_config_converters(cls) -> dict: + return { + "num_blocks": RenameConfigConverter(("num_blocks",), ("num_blocks",)), + "block": NestedConfigConverter(("block",), Apriel2BlockConverter, hf_path=("block",)), + } - else: - raise ValueError(f"Unknown decoder type: {decoder_type}") + +class Apriel2PatternDecoderConverter(ConfigSectionConverter): + fast_llm_config_class = PatternBlockSequenceConfig + hf_type_name = "pattern" @classmethod - def export_config(cls, config) -> dict: - from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig - - if isinstance(config, 
FixedBlockSequenceConfig): - block_config = cls.block_converter_class.export_config(config.block) - return { - "decoder": { - "type": "fixed", - "num_blocks": config.num_blocks, - "block": block_config, - } - } - - elif isinstance(config, PatternBlockSequenceConfig): - blocks = {} - for name, block_config in config.blocks.items(): - blocks[name] = cls.block_converter_class.export_config(block_config) - - return { - "decoder": { - "type": "pattern", - "blocks": blocks, - "pattern": config.pattern, - "num_blocks": config.num_blocks, - } - } + def _create_config_converters(cls) -> dict: + return { + "num_blocks": RenameConfigConverter(("num_blocks",), ("num_blocks",)), + "pattern": RenameConfigConverter(("pattern",), ("pattern",)), + "blocks": TypedDictContainerConfigConverter( + fast_llm_path=("blocks",), + hf_path=("blocks",), + registry={DecoderBlockConfig: Apriel2BlockConverter}, + ), + } - else: - raise ValueError(f"Unknown decoder config type: {type(config)}") + +APRIEL2_DECODER_REGISTRY: dict = { + FixedBlockSequenceConfig: Apriel2FixedDecoderConverter, + PatternBlockSequenceConfig: Apriel2PatternDecoderConverter, +} + + +class Apriel2DecoderConverter: + """Imperative decoder dispatcher kept for the weight side. + + Config-side conversion is handled declaratively via :class:`Apriel2FixedDecoderConverter` and + :class:`Apriel2PatternDecoderConverter`, dispatched at the base-model level. + """ + + block_converter_class: typing.ClassVar[type[Apriel2BlockConverter]] = Apriel2BlockConverter @classmethod def get_converters( @@ -823,9 +802,7 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig - - converters = [] + converters: list[WeightConverter] = [] if type(config) is FixedBlockSequenceConfig: for block_index in range(config.num_blocks): converters += cls.block_converter_class.get_converters( @@ -848,28 +825,25 @@ def get_converters( return converters -class Apriel2HeadConverter: - normalization_converter_class: typing.ClassVar[type[LlamaNormalizationConverter]] = LlamaNormalizationConverter +class Apriel2HeadConverter(ConfigSectionConverter): + fast_llm_config_class = LanguageModelHeadConfig - @classmethod - def import_config(cls, config: dict) -> dict: - norm_config = config["head"]["normalization"] - return {"normalization": {"type": "rms_norm", "epsilon": norm_config["epsilon"]}} + normalization_converter_class: typing.ClassVar[type[LlamaNormalizationConverter]] = LlamaNormalizationConverter @classmethod - def export_config(cls, config) -> dict: - from fast_llm.layers.language_model.config import LanguageModelHeadConfig - - Assert.custom(isinstance, config, LanguageModelHeadConfig) + def _create_config_converters(cls) -> dict: return { - "head": { - "normalization": { - "type": "rms_norm", - "epsilon": config.normalization.epsilon, - } - } + "normalization": DispatchConfigConverter( + fast_llm_path=("normalization",), + hf_path=("normalization",), + registry=APRIEL2_NORM_REGISTRY, + ), + "output_weight": IgnoredConfigConverter(("output_weight",)), + "prediction_heads": IgnoredConfigConverter(("prediction_heads",)), } + # --- weight side (imperative) --- + @classmethod def get_converters( cls, @@ -892,33 +866,35 @@ def get_converters( ] -class Apriel2BaseModelConverter: +class Apriel2BaseModelConverter(ConfigSectionConverter): + fast_llm_config_class = GPTBaseModelConfig + decoder_converter_class: typing.ClassVar[type[Apriel2DecoderConverter]] = 
Apriel2DecoderConverter embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter head_converter_class: typing.ClassVar[type[Apriel2HeadConverter]] = Apriel2HeadConverter @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: return { - "embeddings": cls.embeddings_converter_class.import_config(config), - "decoder": cls.decoder_converter_class.import_config(config), - "head": cls.head_converter_class.import_config(config), - "hidden_size": config["hidden_size"], - "tied_embedding_weight": config["tie_word_embeddings"], + "embeddings": NestedConfigConverter(("embeddings",), cls.embeddings_converter_class), + "decoder": DispatchConfigConverter( + fast_llm_path=("decoder",), + hf_path=("decoder",), + registry=APRIEL2_DECODER_REGISTRY, + ), + "head": NestedConfigConverter(("head",), cls.head_converter_class, hf_path=("head",)), + "hidden_size": RenameConfigConverter(("hidden_size",), ("hidden_size",)), + "tied_embedding_weight": RenameConfigConverter(("tied_embedding_weight",), ("tie_word_embeddings",)), + "peft": IgnoredConfigConverter(("peft",)), } @classmethod - def export_config(cls, config: GPTBaseModelConfig) -> dict: - Assert.custom(isinstance, config, GPTBaseModelConfig) - return safe_merge_dicts( - cls.embeddings_converter_class.export_config(config.embeddings), - cls.decoder_converter_class.export_config(config.decoder), - cls.head_converter_class.export_config(config.head), - { - "tie_word_embeddings": config.tied_embedding_weight, - "hidden_size": config.hidden_size, - }, - ) + def _validate_export(cls, config: GPTBaseModelConfig) -> None: + from fast_llm.layers.common.peft.config import NoPeftConfig + + Assert.custom(isinstance, config.peft, NoPeftConfig) + + # --- weight side (imperative) --- @classmethod def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: @@ -955,7 +931,7 @@ def get_model_files(cls) -> tuple[str, str, str | None]: @classmethod def _export_config(cls, config: GPTModelConfig) -> dict[str, typing.Any]: base_model = config.base_model - exported = safe_merge_dicts( + return safe_merge_dicts( cls.base_model_converter_class.export_config(base_model), { "architectures": [cls.architecture], @@ -967,7 +943,6 @@ def _export_config(cls, config: GPTModelConfig) -> dict[str, typing.Any]: }, }, ) - return exported @classmethod def _import_config(cls, config: dict[str, typing.Any]) -> dict[str, typing.Any]: diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index f8f36dc23..1888e6fd3 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -5,10 +5,18 @@ import torch import transformers +from fast_llm.config import Config from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + ConstantImportConfigConverter, + CustomConfigConverter, + DefaultConfigConverter, + IgnoredConfigConverter, IgnoreExportWeightConverter, IgnoreImportWeightConverter, + NestedConfigConverter, + RenameConfigConverter, SplitWeightConverter, WeightConverter, ) @@ -30,13 +38,18 @@ from fast_llm.models.gpt.conversion.config import LlamaCheckpointFormat from fast_llm.models.gpt.model import GPTModel from fast_llm.tensor import SafeTensorSlice -from fast_llm.utils import Assert, div, safe_merge_dicts +from fast_llm.utils import Assert, div _TRANSFORMERS_V4 = not 
dataclasses.is_dataclass(transformers.PretrainedConfig) logger = logging.getLogger(__name__) +# ============================================================ +# Weight converters (imperative — kept as-is during config migration) +# ============================================================ + + def get_parameter_converter( fast_llm_name: str | tuple[str, ...], hf_name: str | tuple[str, ...], @@ -97,16 +110,139 @@ def get_weight_and_bias_converters( return converters -class LlamaNormalizationConverter: - @classmethod - def import_config(cls, config: dict) -> dict: - return {"type": "rms_norm", "epsilon": config["rms_norm_eps"]} +class MLPLayer2Converter(WeightConverter): + # Similar to SplitWeightConverter, but handles the optional MLP transpose. + # Still ok for non-gated (trivial split) and biases (trivial 1d transpose) + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (merged_weight,) = weight + return tuple(t.contiguous() for t in merged_weight[:].t().chunk(len(self.export_name), dim=-1)) + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + merged_weight = torch.cat([weight_[:] for weight_ in weight], dim=-1) + return (merged_weight.t().contiguous(),) + + +class KeyValueWeightConverter(WeightConverter): + # Hf uses the real format for rotary embeddings, and keeps the key and value separate. + _config: AttentionConfig + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (key_value,) = weight + key, value = key_value[:].chunk(2) + return key, value + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + key, value = weight + key_value = torch.cat([key[:], value[:]]) + return (key_value,) + + +# ============================================================ +# Config converters (declarative) +# ============================================================ + + +def _llama_rotary_export(config: AttentionConfig) -> dict: + """Build the HF rotary block(s) from a Fast-LLM rotary config. + + Returns a dict keyed by the (Llama-flat) HF paths the converter declares; values vary with rotary subtype and + the active transformers major version (v4 puts ``rope_theta`` flat with optional ``rope_scaling``; + v5 consolidates everything into ``rope_parameters``). 
+ """ + rotary = config.rotary + rope_parameters = {"rope_theta": rotary.theta} + if type(rotary) is DefaultRotaryConfig: + rope_parameters["rope_type"] = "default" + elif type(rotary) is Llama3RotaryConfig: + rope_parameters.update( + { + "rope_type": "llama3", + "factor": rotary.scale_factor, + "low_freq_factor": rotary.low_frequency_factor, + "high_freq_factor": rotary.high_frequency_factor, + "original_max_position_embeddings": rotary.original_context_length, + } + ) + elif type(rotary) is YarnRotaryConfig: + rope_parameters.update( + { + "rope_type": "yarn", + "attention_factor": rotary.attention_factor, + "beta_fast": rotary.beta_fast, + "beta_slow": rotary.beta_slow, + "original_max_position_embeddings": rotary.original_context_length, + } + ) + else: + raise NotImplementedError(f"Unsupported rotary type: {type(rotary).__name__}") + + if _TRANSFORMERS_V4: + out: dict = {("rope_theta",): rope_parameters["rope_theta"]} + if type(rotary) is not DefaultRotaryConfig: + out[("rope_scaling",)] = {k: v for k, v in rope_parameters.items() if k != "rope_theta"} + return out + return {("rope_parameters",): rope_parameters} + + +def _llama_rotary_import(hf_dict: dict) -> dict: + """Reverse of :func:`_llama_rotary_export`. Detects v4/v5 layout from the HF dict.""" + if "rope_parameters" in hf_dict: # transformers v5 + rope_params = hf_dict["rope_parameters"] + rope_theta = rope_params["rope_theta"] + else: # transformers v4 + rope_params = hf_dict.get("rope_scaling") or {} + rope_theta = hf_dict["rope_theta"] + rope_type = rope_params.get("rope_type", "default") + rotary_config: dict = {"type": rope_type, "theta": rope_theta} + if rope_type == "default": + pass + elif rope_type == "llama3": + rotary_config.update( + { + "scale_factor": rope_params["factor"], + "low_frequency_factor": rope_params["low_freq_factor"], + "high_frequency_factor": rope_params["high_freq_factor"], + "original_context_length": rope_params["original_max_position_embeddings"], + } + ) + elif rope_type == "yarn": + rotary_config.update( + { + "attention_factor": rope_params["attention_factor"], + "beta_fast": rope_params["beta_fast"], + "beta_slow": rope_params["beta_slow"], + "original_context_length": rope_params["original_max_position_embeddings"], + } + ) + else: + raise NotImplementedError(f"Unsupported rotary type: {rope_type}") + return {("rotary",): rotary_config} + + +class LlamaNormalizationConverter(ConfigSectionConverter): + """Converts ``RMSNormalizationConfig`` ↔ Llama's flat ``rms_norm_eps`` field.""" + + fast_llm_config_class = RMSNormalizationConfig @classmethod - def export_config(cls, config: RMSNormalizationConfig) -> dict: - Assert.custom(isinstance, config, RMSNormalizationConfig) - assert not config.zero_centered - return {"rms_norm_eps": config.epsilon} + def _create_config_converters(cls) -> dict: + return { + "type": ConstantImportConfigConverter(("type",), "rms_norm"), + "epsilon": RenameConfigConverter(("epsilon",), ("rms_norm_eps",)), + "weight": IgnoredConfigConverter(("weight",)), + "zero_centered": ConstantImportConfigConverter(("zero_centered",), False), + } + + # --- weight side (imperative) --- @classmethod def get_converters( @@ -124,27 +260,35 @@ def get_converters( ) -class LlamaMLPConverter: +class LlamaMLPConverter(ConfigSectionConverter): + """Converts ``MLPConfig`` ↔ Llama's flat ``intermediate_size``/``mlp_bias``/``hidden_act`` fields. + + Llama is always gated (``ConstantImportConfigConverter(("gated",), True)``). 
+ """ + + fast_llm_config_class = MLPConfig + @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: return { - "intermediate_size": config["intermediate_size"], - "add_linear_biases": config["mlp_bias"], - "activation": ActivationType.from_hf_name(config["hidden_act"]), - "gated": True, + "intermediate_size": RenameConfigConverter(("intermediate_size",), ("intermediate_size",)), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("mlp_bias",)), + "activation": CustomConfigConverter( + fast_llm_paths=(("activation",),), + export_fn=lambda c: {("hidden_act",): c.activation.hf_name}, + import_fn=lambda hf: {("activation",): ActivationType.from_hf_name(hf["hidden_act"])}, + ), + "gated": ConstantImportConfigConverter(("gated",), True), + # Llama doesn't expose per-layer bias overrides; the bias-match check lives on _validate_export. + "layers": IgnoredConfigConverter(("layer_1",), ("layer_2",)), } @classmethod - def export_config(cls, config: MLPConfig) -> dict: - Assert.custom(isinstance, config, MLPConfig) + def _validate_export(cls, config: MLPConfig) -> None: Assert.incl(config.layer_1.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) - assert config.gated - return { - "intermediate_size": config.intermediate_size, - "mlp_bias": config.add_linear_biases, - "hidden_act": config.activation.hf_name, - } + + # --- weight side (imperative) --- @classmethod def get_converters( @@ -172,127 +316,58 @@ def get_converters( ] -class MLPLayer2Converter(WeightConverter): - # Similar to SplitWeightConverter, but handles the optional MLP transpose. - # Still ok for non-gated (trivial split) and biases (trivial 1d transpose) +class LlamaAttentionConverter(ConfigSectionConverter): + """Converts ``AttentionConfig`` ↔ Llama's flat attention fields. - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (merged_weight,) = weight - return tuple(t.contiguous() for t in merged_weight[:].t().chunk(len(self.export_name), dim=-1)) + Notable wrinkles: + - ``head_dim`` is computed from ``hidden_size // num_attention_heads`` when missing on import. + - Rotary handling is delegated to a :class:`CustomConfigConverter` because it spans v4/v5 transformers + layouts and three rotary subtypes. + - Per-layer linear biases (query/key/value/dense) are validated to match ``add_linear_biases`` on + ``_validate_export``; Llama does not expose layer-level overrides, so the sub-config fields are + blanket-consumed via :class:`IgnoredConfigConverter`. + """ - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - merged_weight = torch.cat([weight_[:] for weight_ in weight], dim=-1) - return (merged_weight.t().contiguous(),) + fast_llm_config_class = AttentionConfig - -class LlamaAttentionConverter: @classmethod - def import_config(cls, config: dict) -> dict: - # Normalize rope params to a single dict before dispatching on rope_type. - # transformers v5 consolidates rope_theta + rope_scaling into rope_parameters. - # transformers v4: rope_theta at top level, rope_scaling dict for non-default types. - # Note: detection is on checkpoint format, not transformers version — old checkpoints - # remain loadable with v5 transformers. 
- if "rope_parameters" in config: # transformers v5 - rope_params = config["rope_parameters"] - rope_theta = rope_params["rope_theta"] - else: # transformers v4 - rope_params = config.get("rope_scaling") or {} - rope_theta = config["rope_theta"] - rope_type = rope_params.get("rope_type", "default") - rotary_config = {"type": rope_type, "theta": rope_theta} - if rope_type == "default": - pass - elif rope_type == "llama3": - rotary_config.update( - { - "scale_factor": rope_params["factor"], - "low_frequency_factor": rope_params["low_freq_factor"], - "high_frequency_factor": rope_params["high_freq_factor"], - "original_context_length": rope_params["original_max_position_embeddings"], - } - ) - elif rope_type == "yarn": - rotary_config.update( - { - "attention_factor": rope_params["attention_factor"], - "beta_fast": rope_params["beta_fast"], - "beta_slow": rope_params["beta_slow"], - "original_context_length": rope_params["original_max_position_embeddings"], - } - ) - else: - raise NotImplementedError(f"Unsupported rotary type: {rope_type}") - out = { - "rotary": rotary_config, - "heads": config["num_attention_heads"], - "head_groups": config["num_key_value_heads"], - "head_size": config.get("head_dim"), - "add_linear_biases": config["attention_bias"], - "dropout": config["attention_dropout"], + def _create_config_converters(cls) -> dict: + return { + "heads": RenameConfigConverter(("heads",), ("num_attention_heads",)), + "head_groups": RenameConfigConverter(("head_groups",), ("num_key_value_heads",)), + "head_size": DefaultConfigConverter( + ("head_size",), + ("head_dim",), + hf_default_fn=lambda hf: div(hf["hidden_size"], hf["num_attention_heads"]), + ), + "add_linear_biases": RenameConfigConverter(("add_linear_biases",), ("attention_bias",)), + "dropout": RenameConfigConverter(("dropout",), ("attention_dropout",)), + "causal": ConstantImportConfigConverter(("causal",), True), + "softmax_scale_power": ConstantImportConfigConverter(("softmax_scale_power",), 0.5), + "linear_layers": IgnoredConfigConverter( + ("query_layer",), ("key_layer",), ("value_layer",), ("dense_layer",) + ), + "rotary": CustomConfigConverter( + fast_llm_paths=(("rotary",),), + export_fn=_llama_rotary_export, + import_fn=_llama_rotary_import, + ), } - if out["head_size"] is None: - out["head_size"] = div(config["hidden_size"], out["heads"]) - - return out @classmethod - def export_config(cls, config: AttentionConfig) -> dict: - cls._check_config(config) - Assert.eq(config.softmax_scale_power, 0.5) - rope_parameters = {"rope_theta": config.rotary.theta} - if type(config.rotary) is DefaultRotaryConfig: - rope_parameters["rope_type"] = "default" - elif type(config.rotary) is Llama3RotaryConfig: - rope_parameters.update( - { - "rope_type": "llama3", - "factor": config.rotary.scale_factor, - "low_freq_factor": config.rotary.low_frequency_factor, - "high_freq_factor": config.rotary.high_frequency_factor, - "original_max_position_embeddings": config.rotary.original_context_length, - } - ) - elif type(config.rotary) is YarnRotaryConfig: - rope_parameters.update( - { - "rope_type": "yarn", - "attention_factor": config.rotary.attention_factor, - "beta_fast": config.rotary.beta_fast, - "beta_slow": config.rotary.beta_slow, - "original_max_position_embeddings": config.rotary.original_context_length, - } - ) - else: - raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") - - common = { - "num_attention_heads": config.heads, - "num_key_value_heads": config.head_groups, - "head_dim": config.head_size, - 
"attention_bias": config.add_linear_biases, - "attention_dropout": config.dropout, - } - if _TRANSFORMERS_V4: - out = {**common, "rope_theta": rope_parameters["rope_theta"]} - if type(config.rotary) is not DefaultRotaryConfig: - out["rope_scaling"] = {k: v for k, v in rope_parameters.items() if k != "rope_theta"} - return out - return {**common, "rope_parameters": rope_parameters} + def _validate_export(cls, config: AttentionConfig) -> None: + """Default: Llama requires per-layer biases to be unset (``None``) or to match ``add_linear_biases``. - @classmethod - def _check_config(cls, config: AttentionConfig) -> None: - # Opportunity to make derived classes less constrained. + Subclasses (e.g. Qwen2 with always-on Q/K/V biases and no dense bias) override. + """ Assert.is_(type(config), AttentionConfig) Assert.incl(config.query_layer.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.key_layer.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.value_layer.bias.enabled, (None, config.add_linear_biases)) Assert.incl(config.dense_layer.bias.enabled, (None, config.add_linear_biases)) + # --- weight side (imperative) --- + @classmethod def get_converters( cls, @@ -306,8 +381,6 @@ def get_converters( f"{fast_llm_prefix}.query", f"{hf_prefix}.q_proj", config.add_linear_biases, - QueryWeightConverter, - config, drop_on_export=drop_on_export, ), *get_weight_and_bias_converters( @@ -327,67 +400,29 @@ def get_converters( ] -class QueryWeightConverter(WeightConverter): - # Hf uses the real format for rotary embeddings. - _config: AttentionConfig - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (query,) = weight - return (query,) - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (query,) = weight - return (query,) - - -class KeyValueWeightConverter(WeightConverter): - # Hf uses the real format for rotary embeddings, and keeps the key and value separate. - _config: AttentionConfig - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (key_value,) = weight - key, value = key_value[:].chunk(2) - return key, value +class LlamaBlockConverter(ConfigSectionConverter): + """Converts ``DecoderBlockConfig`` ↔ Llama block fields (flat-merged into the parent's HF dict).""" - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] 
- ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - key, value = weight - key_value = torch.cat([key[:], value[:]]) - return (key_value,) + fast_llm_config_class = DecoderBlockConfig + mixer_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaAttentionConverter + mlp_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaMLPConverter + normalization_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaNormalizationConverter -class LlamaBlockConverter: - mixer_converter_class: typing.ClassVar[type[LlamaAttentionConverter]] = LlamaAttentionConverter - mlp_converter_class: typing.ClassVar[type[LlamaMLPConverter]] = LlamaMLPConverter - normalization_converter_class: typing.ClassVar[type[LlamaNormalizationConverter]] = LlamaNormalizationConverter hf_mixer_name: typing.ClassVar[str] = "self_attn" hf_mlp_name: typing.ClassVar[str] = "mlp" hf_norm_1_name: typing.ClassVar[str] = "input_layernorm" hf_norm_2_name: typing.ClassVar[str] = "post_attention_layernorm" @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: return { - "mixer": cls.mixer_converter_class.import_config(config), - "mlp": cls.mlp_converter_class.import_config(config), - "normalization": cls.normalization_converter_class.import_config(config), + "mixer": NestedConfigConverter(("mixer",), cls.mixer_converter_class), + "mlp": NestedConfigConverter(("mlp",), cls.mlp_converter_class), + "normalization": NestedConfigConverter(("normalization",), cls.normalization_converter_class), } - @classmethod - def export_config(cls, config: DecoderBlockConfig) -> dict: - Assert.custom(isinstance, config, DecoderBlockConfig) - return safe_merge_dicts( - cls.mixer_converter_class.export_config(config.mixer), - cls.mlp_converter_class.export_config(config.mlp), - cls.normalization_converter_class.export_config(config.normalization), - ) + # --- weight side (imperative) --- @classmethod def get_converters( @@ -422,34 +457,34 @@ def get_converters( class LlamaDecoderConverter: - block_converter_class: typing.ClassVar[type[LlamaBlockConverter]] = LlamaBlockConverter + """Converts ``BlockSequenceConfig`` (polymorphic Fixed/Pattern) ↔ Llama's flat block + ``num_hidden_layers``. + + Kept as a regular class (not a :class:`ConfigSectionConverter`) so it can stay imperative — the polymorphism + between Fixed/Pattern block sequences doesn't lend itself to the declarative shape, and subclasses (Mistral, + Qwen2, MTP-Llama, ...) plug in different block converters via ``block_converter_class``. 
+ """ + + block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaBlockConverter @classmethod - def import_config(cls, config: dict) -> dict: + def import_config(cls, hf_dict: dict) -> dict: return { - "block": cls.block_converter_class.import_config(config), - "num_blocks": config["num_hidden_layers"], + "block": cls.block_converter_class.import_config(hf_dict), + "num_blocks": hf_dict["num_hidden_layers"], } @classmethod - def export_config(cls, config: FixedBlockSequenceConfig | PatternBlockSequenceConfig) -> dict: - if isinstance(config, PatternBlockSequenceConfig): - # All exported block configs must be equal - exported_block_configs = [ - safe_merge_dicts( - cls.block_converter_class.export_config(block_config), - {"num_hidden_layers": config.num_blocks}, - ) - for block_config in config.blocks.values() - ] - for other in exported_block_configs[1:]: - Assert.eq(exported_block_configs[0], other) - return exported_block_configs[0] - Assert.custom(isinstance, config, FixedBlockSequenceConfig) - return safe_merge_dicts( - cls.block_converter_class.export_config(config.block), - {"num_hidden_layers": config.num_blocks}, - ) + def export_config(cls, decoder_config: FixedBlockSequenceConfig | PatternBlockSequenceConfig) -> dict: + if isinstance(decoder_config, PatternBlockSequenceConfig): + exports = [cls.block_converter_class.export_config(block) for block in decoder_config.blocks.values()] + for other in exports[1:]: + Assert.eq(exports[0], other) + block_hf = exports[0] + elif isinstance(decoder_config, FixedBlockSequenceConfig): + block_hf = cls.block_converter_class.export_config(decoder_config.block) + else: + raise NotImplementedError(f"Unsupported decoder type: {type(decoder_config).__name__}") + return {**block_hf, "num_hidden_layers": decoder_config.num_blocks} @classmethod def get_converters( @@ -459,11 +494,10 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - # In the case of PatternBlockSequenceConfig, compatibility was already checked in export_config block_config = ( config.block if isinstance(config, FixedBlockSequenceConfig) else next(iter(config.blocks.values())) ) - converters = [] + converters: list[WeightConverter] = [] for block_index in range(config.num_blocks): converters += cls.block_converter_class.get_converters( block_config, @@ -474,16 +508,28 @@ def get_converters( return converters -class LlamaEmbeddingsConverter: +class LlamaEmbeddingsConverter(ConfigSectionConverter): + """Converts ``LanguageModelEmbeddingsConfig`` ↔ Llama (flat ``vocab_size``). + + Llama has no learnable position embeddings; ``num_position_embeddings`` is irrelevant when + ``position_embeddings.enabled`` is ``False``/``None`` and is therefore blanket-consumed. 
+ """ + + fast_llm_config_class = LanguageModelEmbeddingsConfig + @classmethod - def import_config(cls, config: dict) -> dict: - return {"vocab_size": config["vocab_size"]} + def _create_config_converters(cls) -> dict: + return { + "vocab_size": RenameConfigConverter(("vocab_size",), ("vocab_size",)), + "word_embeddings": IgnoredConfigConverter(("word_embeddings",)), + "position_embeddings": IgnoredConfigConverter(("position_embeddings",), ("num_position_embeddings",)), + } @classmethod - def export_config(cls, config: LanguageModelEmbeddingsConfig) -> dict: - Assert.custom(isinstance, config, LanguageModelEmbeddingsConfig) - assert not config.position_embeddings.enabled - return {"vocab_size": config.vocab_size} + def _validate_export(cls, config: LanguageModelEmbeddingsConfig) -> None: + Assert.incl(config.position_embeddings.enabled, (None, False)) + + # --- weight side (imperative) --- @classmethod def get_converters( @@ -492,18 +538,27 @@ def get_converters( return [WeightConverter(f"{fast_llm_prefix}.word_embeddings_weight", f"{hf_prefix}.embed_tokens.weight")] -class LlamaHeadConverter: - normalization_converter_class: typing.ClassVar[type[LlamaNormalizationConverter]] = LlamaNormalizationConverter - block_converter_class: typing.ClassVar[type[LlamaBlockConverter]] = LlamaBlockConverter +class LlamaHeadConverter(ConfigSectionConverter): + """Converts ``LanguageModelHeadConfig`` ↔ Llama final-norm fields (flat-merged).""" - @classmethod - def import_config(cls, config: dict) -> dict: - return {"normalization": cls.normalization_converter_class.import_config(config)} + fast_llm_config_class = LanguageModelHeadConfig + + normalization_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaNormalizationConverter + # Used by MTP-Llama subclass to emit per-prediction-head block weight converters; Llama itself doesn't read it. + block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaBlockConverter @classmethod - def export_config(cls, config: LanguageModelHeadConfig) -> dict: - Assert.custom(isinstance, config, LanguageModelHeadConfig) - return cls.normalization_converter_class.export_config(config.normalization) + def _create_config_converters(cls) -> dict: + return { + "normalization": NestedConfigConverter(("normalization",), cls.normalization_converter_class), + "output_weight": IgnoredConfigConverter(("output_weight",)), + # ``prediction_heads`` is architecture (>1 enables multi-token prediction); Llama HF format does + # not represent it. We don't pin it to 1 here so MTP-Llama (a Llama-derived format) can override + # the declaration with a Rename without first hitting an assertion in the inherited path. + "prediction_heads": IgnoredConfigConverter(("prediction_heads",)), + } + + # --- weight side (imperative) --- @classmethod def get_converters( @@ -526,34 +581,48 @@ def get_converters( ] -class LlamaBaseModelConverter(HuggingFaceBaseModelConverter): +class LlamaBaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): + """Top-level converter for ``GPTBaseModelConfig`` ↔ Llama HF dict.""" + + fast_llm_config_class = GPTBaseModelConfig + # TODO: Peft? 
decoder_converter_class: typing.ClassVar[type[LlamaDecoderConverter]] = LlamaDecoderConverter - embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter - head_converter_class: typing.ClassVar[type[LlamaHeadConverter]] = LlamaHeadConverter + embeddings_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaEmbeddingsConverter + head_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaHeadConverter @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: + decoder_converter_class = cls.decoder_converter_class + + def _decoder_export(parent: Config) -> dict: + return {(k,): v for k, v in decoder_converter_class.export_config(parent.decoder).items()} + + def _decoder_import(hf_dict: dict) -> dict: + return {("decoder",): decoder_converter_class.import_config(hf_dict)} + return { - "embeddings": cls.embeddings_converter_class.import_config(config), - "decoder": cls.decoder_converter_class.import_config(config), - "head": cls.head_converter_class.import_config(config), - "hidden_size": config["hidden_size"], - "tied_embedding_weight": config["tie_word_embeddings"], + "embeddings": NestedConfigConverter(("embeddings",), cls.embeddings_converter_class), + "head": NestedConfigConverter(("head",), cls.head_converter_class), + "decoder": CustomConfigConverter( + fast_llm_paths=(("decoder",),), + export_fn=_decoder_export, + import_fn=_decoder_import, + ), + "hidden_size": RenameConfigConverter(("hidden_size",), ("hidden_size",)), + "tied_embedding_weight": RenameConfigConverter(("tied_embedding_weight",), ("tie_word_embeddings",)), + # Llama format cannot represent PEFT; the NoPeftConfig assertion lives on _validate_export so a + # user-configured LoRA fails clearly rather than being silently dropped on export. 
+ "peft": IgnoredConfigConverter(("peft",)), } @classmethod - def export_config(cls, config: GPTBaseModelConfig) -> dict: - Assert.custom(isinstance, config, GPTBaseModelConfig) - return safe_merge_dicts( - cls.embeddings_converter_class.export_config(config.embeddings), - cls.decoder_converter_class.export_config(config.decoder), - cls.head_converter_class.export_config(config.head), - { - "tie_word_embeddings": config.tied_embedding_weight, - "hidden_size": config.hidden_size, - }, - ) + def _validate_export(cls, config: GPTBaseModelConfig) -> None: + from fast_llm.layers.common.peft.config import NoPeftConfig + + Assert.custom(isinstance, config.peft, NoPeftConfig) + + # --- weight side (imperative) --- @classmethod def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py index d4a669b22..7664a195c 100644 --- a/fast_llm/models/gpt/conversion/mistral.py +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -1,8 +1,7 @@ import typing from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.layers.attention.config import AttentionConfig -from fast_llm.layers.decoder.mlp.config import MLPConfig +from fast_llm.engine.checkpoint.external import ConstantImportConfigConverter, RenameConfigConverter from fast_llm.models.gpt.conversion.config import MistralCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaAttentionConverter, @@ -13,45 +12,27 @@ LlamaHuggingfaceCheckpointHandler, LlamaMLPConverter, ) -from fast_llm.utils import safe_merge_dicts class MistralAttentionConverter(LlamaAttentionConverter): @classmethod - def import_config(cls, config: dict) -> dict: - config["attention_bias"] = False - return safe_merge_dicts( - super().import_config(config), - {"window_size": config["sliding_window"]}, - ) - - @classmethod - def export_config(cls, config: AttentionConfig) -> dict: - out = safe_merge_dicts( - super().export_config(config), - {"sliding_window": config.window_size}, - ) - del out["attention_bias"] - return out - - @classmethod - def _check_config(cls, config: AttentionConfig) -> None: - # Mistral doesn't support biases. - assert not config.add_linear_biases + def _create_config_converters(cls) -> dict: + return { + **super()._create_config_converters(), + # Mistral has no `attention_bias` HF field; biases are always disabled. + "add_linear_biases": ConstantImportConfigConverter(("add_linear_biases",), False), + "window_size": RenameConfigConverter(("window_size",), ("sliding_window",)), + } class MistralMLPConverter(LlamaMLPConverter): @classmethod - def import_config(cls, config: dict) -> dict: - config["mlp_bias"] = False - return super().import_config(config) - - @classmethod - def export_config(cls, config: MLPConfig) -> dict: - assert not config.add_linear_biases - out = super().export_config(config) - del out["mlp_bias"] - return out + def _create_config_converters(cls) -> dict: + return { + **super()._create_config_converters(), + # Mistral has no `mlp_bias` HF field; biases are always disabled. 
+ "add_linear_biases": ConstantImportConfigConverter(("add_linear_biases",), False), + } class MistralBlockConverter(LlamaBlockConverter): diff --git a/fast_llm/models/gpt/conversion/mixtral.py b/fast_llm/models/gpt/conversion/mixtral.py index 6908d2958..7659befa3 100644 --- a/fast_llm/models/gpt/conversion/mixtral.py +++ b/fast_llm/models/gpt/conversion/mixtral.py @@ -1,8 +1,14 @@ import typing from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import SplitWeightConverter, WeightConverter -from fast_llm.layers.decoder.mlp.config import MoEMLPConfig +from fast_llm.engine.checkpoint.external import ( + ConstantImportConfigConverter, + IgnoredConfigConverter, + RenameConfigConverter, + SplitWeightConverter, + WeightConverter, +) +from fast_llm.layers.decoder.mlp.config import MoEMLPConfig, RoutingType from fast_llm.models.gpt.conversion.config import MixtralCheckpointFormat from fast_llm.models.gpt.conversion.llama import LlamaMLPConverter, MLPLayer2Converter, get_weight_and_bias_converters from fast_llm.models.gpt.conversion.mistral import ( @@ -12,35 +18,32 @@ MistralHeadConverter, MistralHuggingfaceCheckpointHandler, ) -from fast_llm.utils import Assert, safe_merge_dicts class MixtralMLPConverter(LlamaMLPConverter): + fast_llm_config_class = MoEMLPConfig + @classmethod - def import_config(cls, config: dict) -> dict: - config["mlp_bias"] = False - return safe_merge_dicts( - super().import_config(config), - { - "type": "moe", - "experts": config["num_local_experts"], - "experts_per_token": config["num_experts_per_tok"], - }, - ) + def _create_config_converters(cls) -> dict: + return { + **super()._create_config_converters(), + # Mixtral has no `mlp_bias` HF field; biases are always disabled. + "add_linear_biases": ConstantImportConfigConverter(("add_linear_biases",), False), + "experts": RenameConfigConverter(("experts",), ("num_local_experts",)), + "experts_per_token": RenameConfigConverter(("experts_per_token",), ("num_experts_per_tok",)), + # Mixtral has no shared experts and uses the topk default; assert on export, inject defaults on import. + "shared_experts": ConstantImportConfigConverter(("shared_experts",), 0), + "routing": ConstantImportConfigConverter(("routing",), RoutingType.topk), + # Mixtral's gate is a default LinearConfig (no bias); blanket-consume so coverage passes. + "router": IgnoredConfigConverter(("router",)), + } @classmethod - def export_config(cls, config: MoEMLPConfig) -> dict: - Assert.custom(isinstance, config, MoEMLPConfig) - assert not config.add_linear_biases - out = super().export_config(config) - del out["mlp_bias"] - return safe_merge_dicts( - out, - { - "num_local_experts": config.experts, - "num_experts_per_tok": config.experts_per_token, - }, - ) + def import_config(cls, hf_dict: dict) -> dict: + # Inject the Fast-LLM dynamic-type discriminator so `from_dict` instantiates `MoEMLPConfig` + # rather than the default `MLPConfig`. The MLP is wrapped via `NestedConfigConverter`, so + # there's no surrounding `DispatchConfigConverter` to inject this for us. 
+ return {"type": "moe", **super().import_config(hf_dict)} @classmethod def get_converters( diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index f681c4a24..787ba0220 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -3,9 +3,9 @@ from transformers import PretrainedConfig from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.engine.checkpoint.external import RenameConfigConverter, WeightConverter from fast_llm.layers.block.config import FixedBlockSequenceConfig -from fast_llm.layers.language_model.config import LanguageModelConfig, LanguageModelHeadConfig +from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import MTPLlamaCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( @@ -20,27 +20,21 @@ class MTPLlamaHeadConverter(LlamaHeadConverter): @classmethod - def import_config(cls, config: dict) -> dict: + def _create_config_converters(cls) -> dict: return { - **super().import_config(config), - "prediction_heads": config["prediction_heads"], + **super()._create_config_converters(), + # MTP-Llama exposes the prediction-heads count via the HF config; Llama itself blanket-ignores it. + "prediction_heads": RenameConfigConverter(("prediction_heads",), ("prediction_heads",)), } - @classmethod - def export_config(cls, config: LanguageModelHeadConfig) -> dict: - return safe_merge_dicts( - super().export_config(config), - {"prediction_heads": config.prediction_heads}, - ) - @classmethod def get_converters( cls, config: LanguageModelConfig, exported_config: dict, ) -> list[WeightConverter]: - # Override: map head.final_norm to model.mtp_norms.0 (not model.norm as in standard Llama), - # since MTPLlamaModel uses mtp_norms[0] for the first prediction head. + # MTP-Llama uses ``model.mtp_norms.0`` for the first prediction head's final norm + # instead of the standard ``model.norm``. converters = [ *cls.normalization_converter_class.get_converters( config.head.normalization, @@ -70,19 +64,19 @@ def get_converters( class MTPLlamaDecoderConverter(LlamaDecoderConverter): @classmethod - def import_config(cls, config: dict) -> dict: + def import_config(cls, hf_dict: dict) -> dict: return { - "block": cls.block_converter_class.import_config(config), - "num_blocks": config["num_hidden_layers"], + "block": cls.block_converter_class.import_config(hf_dict), + "num_blocks": hf_dict["num_hidden_layers"], } @classmethod - def export_config(cls, config: FixedBlockSequenceConfig) -> dict: + def export_config(cls, decoder_config: FixedBlockSequenceConfig) -> dict: # TODO: Support PatternBlockSequenceConfig with compatible configs. 
- Assert.custom(isinstance, config, FixedBlockSequenceConfig) + Assert.custom(isinstance, decoder_config, FixedBlockSequenceConfig) return safe_merge_dicts( - cls.block_converter_class.export_config(config.block), - {"num_hidden_layers": config.num_blocks}, + cls.block_converter_class.export_config(decoder_config.block), + {"num_hidden_layers": decoder_config.num_blocks}, ) diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index 9aa2f8c8e..3d9d6f349 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -1,10 +1,13 @@ import typing from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.engine.checkpoint.external import ( + ConstantImportConfigConverter, + CustomConfigConverter, + WeightConverter, +) from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.block.config import FixedBlockSequenceConfig -from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.models.gpt.config import GPTBaseModelConfig from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat from fast_llm.models.gpt.conversion.llama import ( @@ -16,38 +19,43 @@ LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, LlamaMLPConverter, - QueryWeightConverter, get_weight_and_bias_converters, ) -from fast_llm.utils import Assert +from fast_llm.utils import Assert, div class Qwen2AttentionConverter(LlamaAttentionConverter): # TODO: Support sliding window with max_window_layers (need 2 kinds of block?) @classmethod - def import_config(cls, config: dict) -> dict: - config["attention_bias"] = False - out = super().import_config(config) - out["query_layer"] = {"bias": {"enabled": True}} - out["key_layer"] = {"bias": {"enabled": True}} - out["value_layer"] = {"bias": {"enabled": True}} - out["dense_layer"] = {"bias": {"enabled": False}} - return out - - @classmethod - def export_config(cls, config: AttentionConfig) -> dict: - out = super().export_config(config) - del out["attention_bias"] - # Qwen2Config does not have head_dim as a standard field; it is always - # derivable as hidden_size // num_attention_heads. - del out["head_dim"] + def _create_config_converters(cls) -> dict: + out = super()._create_config_converters() + # Qwen2 has no `attention_bias` HF field; the model always has Q/K/V biases enabled and no dense bias. + out["add_linear_biases"] = ConstantImportConfigConverter(("add_linear_biases",), False) + # Qwen2Config does not have `head_dim`; it is always derivable as `hidden_size // num_attention_heads`. + out["head_size"] = CustomConfigConverter( + fast_llm_paths=(("head_size",),), + export_fn=lambda config: {}, + import_fn=lambda hf: {("head_size",): div(hf["hidden_size"], hf["num_attention_heads"])}, + ) + # Override Llama's blanket per-layer bias ignore with Qwen2's hardcoded layer biases. + # On export the per-layer biases must be compatible with `add_linear_biases`; see ``_validate_export``. 
+ out["linear_layers"] = CustomConfigConverter( + fast_llm_paths=(("query_layer",), ("key_layer",), ("value_layer",), ("dense_layer",)), + export_fn=lambda config: {}, + import_fn=lambda hf: { + ("query_layer",): {"bias": {"enabled": True}}, + ("key_layer",): {"bias": {"enabled": True}}, + ("value_layer",): {"bias": {"enabled": True}}, + ("dense_layer",): {"bias": {"enabled": False}}, + }, + ) return out @classmethod - def _check_config(cls, config: AttentionConfig) -> None: + def _validate_export(cls, config: AttentionConfig) -> None: Assert.is_(type(config), AttentionConfig) - # There are multiple ways to enable biases on QKV only + # There are multiple ways to enable biases on QKV only. if config.add_linear_biases: Assert.incl(config.query_layer.bias.enabled, (None, True)) Assert.incl(config.key_layer.bias.enabled, (None, True)) @@ -72,8 +80,6 @@ def get_converters( f"{fast_llm_prefix}.query", f"{hf_prefix}.q_proj", True, - QueryWeightConverter, - config, drop_on_export=drop_on_export, ), *get_weight_and_bias_converters( @@ -95,15 +101,12 @@ def get_converters( class Qwen2MLPConverter(LlamaMLPConverter): @classmethod - def import_config(cls, config: dict) -> dict: - config["mlp_bias"] = False - return super().import_config(config) - - @classmethod - def export_config(cls, config: MLPConfig) -> dict: - out = super().export_config(config) - del out["mlp_bias"] - return out + def _create_config_converters(cls) -> dict: + return { + **super()._create_config_converters(), + # Qwen2 has no `mlp_bias` HF field; biases are always disabled. + "add_linear_biases": ConstantImportConfigConverter(("add_linear_biases",), False), + } class Qwen2BlockConverter(LlamaBlockConverter): @@ -124,12 +127,13 @@ class Qwen2BaseModelConverter(LlamaBaseModelConverter): head_converter_class: typing.ClassVar[type[Qwen2HeadConverter]] = Qwen2HeadConverter @classmethod - def import_config(cls, config: dict) -> dict: - assert config.get("use_mrope") is not True, "MRoPE (use_mrope=True) is not supported by the Qwen2 converter" - return super().import_config(config) + def import_config(cls, hf_dict: dict) -> dict: + assert hf_dict.get("use_mrope") is not True, "MRoPE (use_mrope=True) is not supported by the Qwen2 converter" + return super().import_config(hf_dict) @classmethod - def export_config(cls, config: GPTBaseModelConfig) -> dict: + def _validate_export(cls, config: GPTBaseModelConfig) -> None: + super()._validate_export(config) block = ( config.decoder.block if isinstance(config.decoder, FixedBlockSequenceConfig) @@ -141,7 +145,6 @@ def export_config(cls, config: GPTBaseModelConfig) -> dict: config.hidden_size, msg="Qwen2 format omits head_dim; requires heads * head_size == hidden_size", ) - return super().export_config(config) class Qwen2HuggingfaceCheckpointHandler(LlamaHuggingfaceCheckpointHandler): diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index fe7c77f5e..b48b1d042 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -3,13 +3,21 @@ import torch from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, + ConstantExportConfigConverter, + ConstantImportConfigConverter, + CustomConfigConverter, + IgnoredConfigConverter, + NestedConfigConverter, + RenameConfigConverter, + WeightConverter, +) from fast_llm.engine.checkpoint.huggingface import 
         return out
 
     @classmethod
-    def _check_config(cls, config: AttentionConfig) -> None:
+    def _validate_export(cls, config: AttentionConfig) -> None:
         Assert.is_(type(config), AttentionConfig)
-        # There are multiple ways to enable biases on QKV only
+        # There are multiple ways to enable biases on QKV only.
         if config.add_linear_biases:
             Assert.incl(config.query_layer.bias.enabled, (None, True))
             Assert.incl(config.key_layer.bias.enabled, (None, True))
@@ -72,8 +80,6 @@ def get_converters(
                 f"{fast_llm_prefix}.query",
                 f"{hf_prefix}.q_proj",
                 True,
-                QueryWeightConverter,
-                config,
                 drop_on_export=drop_on_export,
             ),
             *get_weight_and_bias_converters(
@@ -95,15 +101,12 @@ def get_converters(
 
 class Qwen2MLPConverter(LlamaMLPConverter):
     @classmethod
-    def import_config(cls, config: dict) -> dict:
-        config["mlp_bias"] = False
-        return super().import_config(config)
-
-    @classmethod
-    def export_config(cls, config: MLPConfig) -> dict:
-        out = super().export_config(config)
-        del out["mlp_bias"]
-        return out
+    def _create_config_converters(cls) -> dict:
+        return {
+            **super()._create_config_converters(),
+            # Qwen2 has no `mlp_bias` HF field; biases are always disabled.
+            "add_linear_biases": ConstantImportConfigConverter(("add_linear_biases",), False),
+        }
 
 
 class Qwen2BlockConverter(LlamaBlockConverter):
@@ -124,12 +127,13 @@ class Qwen2BaseModelConverter(LlamaBaseModelConverter):
     head_converter_class: typing.ClassVar[type[Qwen2HeadConverter]] = Qwen2HeadConverter
 
     @classmethod
-    def import_config(cls, config: dict) -> dict:
-        assert config.get("use_mrope") is not True, "MRoPE (use_mrope=True) is not supported by the Qwen2 converter"
-        return super().import_config(config)
+    def import_config(cls, hf_dict: dict) -> dict:
+        assert hf_dict.get("use_mrope") is not True, "MRoPE (use_mrope=True) is not supported by the Qwen2 converter"
+        return super().import_config(hf_dict)
 
     @classmethod
-    def export_config(cls, config: GPTBaseModelConfig) -> dict:
+    def _validate_export(cls, config: GPTBaseModelConfig) -> None:
+        super()._validate_export(config)
         block = (
             config.decoder.block
             if isinstance(config.decoder, FixedBlockSequenceConfig)
@@ -141,7 +145,6 @@ def export_config(cls, config: GPTBaseModelConfig) -> dict:
             config.hidden_size,
             msg="Qwen2 format omits head_dim; requires heads * head_size == hidden_size",
         )
-        return super().export_config(config)
 
 
 class Qwen2HuggingfaceCheckpointHandler(LlamaHuggingfaceCheckpointHandler):
diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py
index fe7c77f5e..b48b1d042 100644
--- a/fast_llm/models/multimodal/conversion/llava.py
+++ b/fast_llm/models/multimodal/conversion/llava.py
@@ -3,13 +3,21 @@
 import torch
 
 from fast_llm.engine.checkpoint.config import CheckpointFormat
-from fast_llm.engine.checkpoint.external import WeightConverter
+from fast_llm.engine.checkpoint.external import (
+    ConfigSectionConverter,
+    ConstantExportConfigConverter,
+    ConstantImportConfigConverter,
+    CustomConfigConverter,
+    IgnoredConfigConverter,
+    NestedConfigConverter,
+    RenameConfigConverter,
+    WeightConverter,
+)
 from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler
 from fast_llm.engine.multi_stage.config import FastLLMModelConfig
 from fast_llm.functional.config import ActivationType
 from fast_llm.layers.attention.config import AttentionConfig
 from fast_llm.layers.attention.rotary.config import Rotary2DConfig
-from fast_llm.layers.common.normalization.config import RMSNormalizationConfig
 from fast_llm.layers.decoder.mlp.config import MLPConfig
 from fast_llm.layers.language_model.config import LanguageModelConfig
 from fast_llm.layers.vision.config import PatchEmbeddingsConfig, VisionEncoderConfig
@@ -31,21 +39,15 @@


 class PixtralNormalizationConverter(LlamaNormalizationConverter):
-    """
-    epsilon hard-coded to 1e-5.
-    """
+    """RMS norm with HF-side hardcoded epsilon=1e-5 (Pixtral's HF format omits the field)."""
 
     @classmethod
-    def import_config(cls, config: dict) -> dict:
-        return {"type": "rms_norm", "epsilon": 1e-5}
-
-    @classmethod
-    def export_config(cls, config: RMSNormalizationConfig) -> dict:
-        Assert.custom(isinstance, config, RMSNormalizationConfig)
-        assert not config.zero_centered
-        # TODO: Too strict?
-        Assert.eq(config.epsilon, 1e-5)
-        return {}
+    def _create_config_converters(cls) -> dict:
+        return {
+            **super()._create_config_converters(),
+            # Pin epsilon to 1e-5: assert on export, inject on import. No HF write/read.
+            "epsilon": ConstantImportConfigConverter(("epsilon",), 1e-5),
+        }
 
 
 class PixtralAttentionConverter(LlamaAttentionConverter):
@@ -60,7 +62,7 @@ def import_config(cls, config: dict) -> dict:
 
     @classmethod
     def export_config(cls, config: AttentionConfig) -> dict:
-        cls._check_config(config)
+        cls._validate_export(config)
         Assert.eq(config.softmax_scale_power, 0.5)
         Assert.is_(type(config.rotary), Rotary2DConfig)
         assert not config.add_linear_biases
@@ -117,32 +119,39 @@ def import_weight(
         )
 
 
-class PixtralEmbeddingsConverter:
+class PixtralEmbeddingsConverter(ConfigSectionConverter):
+    """Converts ``PatchEmbeddingsConfig`` <-> Pixtral HF flat fields (``patch_size`` / ``num_channels``).
+
+    Pixtral's HF ``vision_config`` carries a single ``patch_size`` field (height == width); the converter
+    expands it to both Fast-LLM dimensions on import and validates equality on export.
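+
+    Illustrative round-trip (docstring example only; normalization handling omitted): an HF subdict
+    ``{"patch_size": 16, "num_channels": 3}`` imports to ``{"patch_height": 16, "patch_width": 16}``,
+    and a config with ``patch_height == patch_width == 16`` exports back to the same HF subdict.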
+ "patch_embeddings": IgnoredConfigConverter(("patch_embeddings",)), } @classmethod - def export_config(cls, config: PatchEmbeddingsConfig) -> dict: - Assert.custom(isinstance, config, PatchEmbeddingsConfig) + def _validate_export(cls, config: PatchEmbeddingsConfig) -> None: Assert.eq(config.patch_height, config.patch_width) Assert.incl(config.patch_embeddings.bias.enabled, (None, False)) - return safe_merge_dicts( - { - "patch_size": config.patch_height, - "num_channels": config.input_channels, - }, - cls.normalization_converter_class.export_config(config.normalization), - ) - @classmethod def get_converters( cls, config: PatchEmbeddingsConfig, fast_llm_prefix: str, hf_prefix: str