Merged
47 commits
9a3e86d
Entropy loss tweaks
jlamypoirier Jan 29, 2026
741832c
fix
jlamypoirier Jan 29, 2026
b1d4b8d
fix
jlamypoirier Jan 29, 2026
c326bff
fix
jlamypoirier Jan 30, 2026
02c28a5
Triton loss
jlamypoirier Jan 31, 2026
22c8e0b
Triton loss
jlamypoirier Jan 31, 2026
7491e0f
Parallel attempt
jlamypoirier Jan 31, 2026
b8e7179
fix
jlamypoirier Feb 3, 2026
3c3e0c8
fixes
jlamypoirier Feb 3, 2026
2d293ea
Cross-entropy from distribution
jlamypoirier Feb 5, 2026
1d0439e
Forward KL
jlamypoirier Feb 5, 2026
1b40518
Reverse KL, triton tweaks
jlamypoirier Feb 6, 2026
d99511b
rename
jlamypoirier Feb 6, 2026
094ac85
Z loss
jlamypoirier Feb 6, 2026
35fd220
fix
jlamypoirier Feb 6, 2026
9133fcd
Grad accumulation
jlamypoirier Feb 7, 2026
945dadc
Token dim
jlamypoirier Feb 11, 2026
e0d0d7d
cleaanp
jlamypoirier Feb 11, 2026
99e6400
fix
jlamypoirier Feb 11, 2026
15c0e43
Simplify MTP
jlamypoirier Feb 11, 2026
f803e82
misc
jlamypoirier Feb 11, 2026
7469f83
stuff
jlamypoirier Feb 17, 2026
295c25b
stuff
jlamypoirier Feb 18, 2026
1697a48
stuff
jlamypoirier Feb 20, 2026
dd536b8
fixes
jlamypoirier Feb 20, 2026
c75ae2b
stuff
jlamypoirier Feb 25, 2026
7d1ec40
fixes
jlamypoirier Feb 26, 2026
3944dbb
stuff
jlamypoirier Mar 6, 2026
a5853bc
fixes
jlamypoirier Mar 7, 2026
da9751b
fixes
jlamypoirier Mar 7, 2026
7a1b318
Merge branch 'jlp_simplify_mtp' into jlp_batch
jlamypoirier Mar 7, 2026
f3974bb
fixes
jlamypoirier Mar 9, 2026
b3eb88d
fixes
jlamypoirier Mar 10, 2026
1af4e9f
fixes
jlamypoirier Mar 12, 2026
f2a2e94
fixes
jlamypoirier Mar 12, 2026
8dd1186
fixes
jlamypoirier Mar 13, 2026
362d758
fixes
jlamypoirier Mar 16, 2026
15a50c3
fixes
jlamypoirier Mar 17, 2026
a19c04d
Merge remote-tracking branch 'origin/main' into jlp_entropy_loss_tweaks
jlamypoirier Mar 17, 2026
36fc58b
Merge branch 'jlp_entropy_loss_tweaks' into jlp_triton_loss
jlamypoirier Mar 17, 2026
d16319a
Merge remote-tracking branch 'origin/main' into jlp_triton_loss
jlamypoirier Mar 17, 2026
ce47352
Merge branch 'jlp_triton_loss' into jlp_token_dim
jlamypoirier Mar 17, 2026
cbabbe9
Merge remote-tracking branch 'origin/main' into jlp_token_dim
jlamypoirier Mar 17, 2026
06c513e
Merge branch 'jlp_token_dim' into jlp_simplify_mtp
jlamypoirier Mar 17, 2026
cd4173d
Merge remote-tracking branch 'origin/main' into jlp_simplify_mtp
jlamypoirier Mar 17, 2026
8cfdfa9
Merge branch 'jlp_simplify_mtp' into jlp_batch
jlamypoirier Mar 17, 2026
f316c55
Merge remote-tracking branch 'origin/main' into jlp_batch
jlamypoirier Mar 17, 2026
12 changes: 4 additions & 8 deletions examples/mistral.yaml
@@ -5,18 +5,14 @@ training:
interval: 10
evaluators:
validation:
evaluator:
type: loss
iterations: null
test_iters: 0
batch:
sequence_length: 4096
micro_batch_size: 2
batch_size: 64
type: loss
iterations: null
data:
datasets:
training:
type: random
micro_batch_size: 8192
maximum_document_length: 4096
optimizer:
learning_rate:
base: 1.0e-05
2 changes: 1 addition & 1 deletion fast_llm/__init__.py
@@ -1 +1 @@
__version__ = "0.3.0"
__version__ = "0.4.0"
25 changes: 16 additions & 9 deletions fast_llm/config.py
@@ -20,7 +20,7 @@
_AUTO_VALIDATE = True

MISSING = Tag("<MISSING>")
DEFAULT = Tag("<DEFAULT>")
# DEFAULT = Tag("<DEFAULT>")


class NoAutoValidate:
@@ -425,12 +425,12 @@ def _validate(self) -> None:
if not field.init or field._field_type != dataclasses._FIELD: # noqa
continue
value = getattr(self, name)
if isinstance(value, Tag):
Assert.is_(value, DEFAULT)
# Replace the value with its default.
# We still need to validate because some fields have invalid defaults.
# TODO: Improve (still needed with new config update format? Do earlier to allow implicit defaults?)
value = field.default
# if isinstance(value, Tag):
# Assert.is_(value, DEFAULT)
# # Replace the value with its default.
# # We still need to validate because some fields have invalid defaults.
# # TODO: Improve (still needed with new config update format? Do earlier to allow implicit defaults?)
# value = field.default
new_value = self._validate_nested(value, field.type, field.name, field.valid, errors, False)
setattr(self, name, new_value)
for name in getattr(self, "_unknown_fields", {}):
@@ -781,7 +781,15 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi
continue
# Check for nested configs to instantiate.
try:
value = cls._from_dict_nested(default.pop(name, MISSING), field.type, strict)
value = default.pop(name, MISSING)
# Skip fields for which we want to use the provided default.
# This will prevent unwanted config instantiation in union fields with non-config defaults,
# For example optional config fields.
if value is MISSING and (
field.default is not dataclasses.MISSING or field.default_factory is not dataclasses.MISSING
):
continue
value = cls._from_dict_nested(value, field.type, strict)
if value is not MISSING:
out_arg_dict[name] = value
except FieldTypeError as e:
@@ -801,7 +809,6 @@ def _from_dict_nested(cls, value, type_, strict: bool):
if type_ in (typing.Any, types.NoneType):
pass
elif isinstance(type_, types.UnionType):
# Takes care of Optional too
value = cls._from_dict_union(value, type_, strict)
elif hasattr(type_, "__origin__"):
# TODO: Improve error messages for nested entries.
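The new check in _from_dict skips nested-config instantiation when a field is absent from the input dict and already has a dataclass default, which is what prevents unwanted config construction for union fields with non-config defaults. A minimal standalone sketch of that behavior, using plain dataclasses and made-up SubConfig/OuterConfig names rather than the real fast_llm Config machinery:

import dataclasses
import typing

MISSING = object()  # stand-in for fast_llm's MISSING tag


@dataclasses.dataclass
class SubConfig:
    size: int = 1


@dataclasses.dataclass
class OuterConfig:
    # Union field with a non-config default: it should stay None when absent from the dict.
    sub: SubConfig | None = None
    name: str = "default"

    @classmethod
    def from_dict(cls, data: dict[str, typing.Any]) -> "OuterConfig":
        kwargs = {}
        for field in dataclasses.fields(cls):
            value = data.pop(field.name, MISSING)
            # Mirror the new check: a missing value with a declared default is skipped,
            # so no nested config gets instantiated for it.
            if value is MISSING and (
                field.default is not dataclasses.MISSING or field.default_factory is not dataclasses.MISSING
            ):
                continue
            # Instantiate nested configs only when a value was actually provided.
            if field.name == "sub" and isinstance(value, dict):
                value = SubConfig(**value)
            kwargs[field.name] = value
        return cls(**kwargs)


print(OuterConfig.from_dict({}))                    # OuterConfig(sub=None, name='default')
print(OuterConfig.from_dict({"sub": {"size": 4}}))  # OuterConfig(sub=SubConfig(size=4), name='default')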
13 changes: 9 additions & 4 deletions fast_llm/data/auto.py
@@ -6,14 +6,19 @@
BlendedDatasetConfig,
ConcatenatedDatasetConfig,
DatasetSliceConfig,
)
from fast_llm.data.dataset.memmap.config import ( # isort: skip
LanguageModelReaderConfig,
MemmapDatasetConfig,
SampledDatasetUpdateConfig,
NullReaderConfig,
PatchReaderConfig,
RangeReaderConfig,
TokenReaderConfig,
)
from fast_llm.data.dataset.gpt.config import ( # isort: skip
GPTDatasetFromFileConfig,
GPTFimSampledDatasetConfig,
GPTRandomDatasetConfig,
)
from fast_llm.data.preparator.dataset_discovery.config import DatasetDiscoveryConfig # isort: skip
from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig # isort: skip
from fast_llm.data.sample.abstract import NullReaderConfig # isort: skip
from fast_llm.data.preparation.dataset_discovery.config import DatasetDiscoveryConfig # isort: skip
from fast_llm.data.preparation.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig # isort: skip
38 changes: 15 additions & 23 deletions fast_llm/data/data/abstract.py
@@ -4,53 +4,45 @@

from fast_llm.config import Configurable
from fast_llm.data.data.config import DataConfig
from fast_llm.data.dataset.config import SamplingParameters
from fast_llm.data.preprocessing.abstract import PreprocessingConfig
from fast_llm.data.sample.abstract import Batch
from fast_llm.data.document.abstract import Batch, ModelInput
from fast_llm.data.document.config import BatchPreprocessingConfig
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.schedule.config import BatchConfig

if typing.TYPE_CHECKING:
from fast_llm.engine.distributed.distributed import Distributed


class Data[ConfigType: DataConfig](Configurable[ConfigType], abc.ABC):
_distributed: "Distributed"
_sampling_parameters: dict[str, SamplingParameters]
_preprocessing: PreprocessingConfig
_cache_directory: pathlib.Path | None
_is_setup: bool = False

def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> None:
super().__init__(config)
self._distributed_config = distributed_config

# TODO: Improve interface
def setup(
self,
distributed: "Distributed",
sampling_parameters: dict[str, SamplingParameters],
preprocessing: PreprocessingConfig,
cache_directory: pathlib.Path,
timeout: float | None = None,
) -> None:
self._distributed = distributed
self._sampling_parameters = sampling_parameters
self._preprocessing = preprocessing
def setup(self, cache_directory: pathlib.Path) -> None:
self._cache_directory = cache_directory

@property
def distributed(self):
return self._distributed
self._is_setup = True

@abc.abstractmethod
def sample_dataset(
self,
dataset_name: str,
config: BatchPreprocessingConfig,
num_samples: int,
) -> list[ModelInput]:
pass

def get_iterator(
self,
batch_config: BatchConfig,
dataset_name: str,
*,
consumed_samples: int,
num_workers: int,
prefetch_factor: int | None = None,
timeout: float = 60,
) -> typing.Iterator[Batch]:
preprocess: bool = True,
) -> typing.Iterator[list[ModelInput] | Batch]:
pass
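For context, a toy illustration of the narrowed Data interface as it appears in the diff: setup() now receives only a cache directory, and sampling is requested per dataset name. The class names and exact signatures below are assumptions for illustration, not the real fast_llm classes.

import abc
import pathlib


class ToyData(abc.ABC):
    """Toy analogue of the simplified Data base class."""

    _is_setup: bool = False

    def setup(self, cache_directory: pathlib.Path) -> None:
        # Distributed state, sampling parameters and preprocessing config are no longer
        # passed here; only the cache directory is stored.
        self._cache_directory = cache_directory
        self._is_setup = True

    @abc.abstractmethod
    def sample_dataset(self, dataset_name: str, num_samples: int) -> list[list[int]]:
        ...


class RandomToyData(ToyData):
    def sample_dataset(self, dataset_name: str, num_samples: int) -> list[list[int]]:
        assert self._is_setup, "call setup() first"
        return [[i % 7 for i in range(8)] for _ in range(num_samples)]


data = RandomToyData()
data.setup(pathlib.Path("/tmp/fast_llm_cache"))
samples = data.sample_dataset("training", num_samples=2)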
8 changes: 1 addition & 7 deletions fast_llm/data/data/config.py
@@ -1,12 +1,6 @@
import typing

from fast_llm.config import Config, Field, config_class
from fast_llm.data.dataset.config import SamplingConfig, SamplingData
from fast_llm.config import Config, config_class


@config_class()
class DataConfig(Config):
_abstract = True
_sampling_config_class: typing.ClassVar[type[SamplingData]]

sampling: SamplingConfig = Field(desc="Default configuration for dataset sampling.")
18 changes: 10 additions & 8 deletions fast_llm/data/data/gpt/config.py
@@ -4,26 +4,23 @@
from fast_llm.config import Field, FieldHint, check_field, config_class
from fast_llm.data.config import MultiprocessingContext
from fast_llm.data.data.config import DataConfig
from fast_llm.data.dataset.config import SampledDatasetConfig
from fast_llm.data.dataset.config import SampledDatasetConfig, SamplingConfigBase
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.data.sample.language_model import LanguageModelSample
from fast_llm.data.document.language_model import LanguageModelDocument
logger = logging.getLogger(__name__)


@config_class()
class GPTDataConfig(DataConfig):
class GPTDataConfig(DataConfig, SamplingConfigBase):
"""
Configuration for the dataset(s), split and sampling.
Currently hard-coded to a GPT dataset.
TODO: Extract generalizable content.
Configuration for the dataset(s) and its sampling.
"""

_abstract = False

# TODO: Review field. Move closer to phase definition in training config?
datasets: dict[str, SampledDatasetConfig["LanguageModelSample"]] = Field(
datasets: dict[str, SampledDatasetConfig["LanguageModelDocument"]] = Field(
default_factory=dict,
desc="Configuration for the dataset(s).",
hint=FieldHint.core,
@@ -39,3 +36,8 @@ class GPTDataConfig(DataConfig):
desc="Multiprocessing context. Do not touch.",
hint=FieldHint.expert,
)
seed: int = Field(
default=784569,
desc="Seed for random sampling.",
hint=FieldHint.feature,
)
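With the sampling seed now a field on GPTDataConfig itself, a data section could plausibly be written as sketched below. This is an assumption-laden illustration: the dict shape mirrors the datasets/seed fields from the diff and the `type: random` dataset from the mistral.yaml example above, but whether the real from_dict accepts exactly this form is not confirmed by the PR.

# Hypothetical dict form of a `data:` section under the new config layout.
data_section = {
    "datasets": {
        "training": {"type": "random"},  # SampledDatasetConfig entries, keyed by dataset name
    },
    "seed": 784569,  # now a field on GPTDataConfig (default value shown)
}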