Changes from all commits (21 commits)
919c6bd  fix(config): Component factory assert accepts BaseModel field aliases. (BlueCrescent, Nov 7, 2025)
68b02aa  feat(parallelism): Added first version of multi stage pipeline parall… (BlueCrescent, Nov 7, 2025)
c9677b1  chore: Merge remote-tracking branch 'origin/main' into pp_multi_stage (BlueCrescent, Nov 19, 2025)
f885ca8  feat(parallelism): Switched various pp configs to interleaved 1F1B. (BlueCrescent, Nov 21, 2025)
a915805  fix(parallelism): Handling for None block types when fsdp2 wrapping m… (BlueCrescent, Nov 21, 2025)
bc7089b  fix(parallelism): Deactivated eval before training again due to bug w… (BlueCrescent, Nov 21, 2025)
07ef847  refactor: Better name for maybe_list_parameter_wrapper (relevant in s… (BlueCrescent, Nov 21, 2025)
7f0518b  feat: Better error reporting in cuda env tear down. (BlueCrescent, Nov 21, 2025)
e7cb524  test(parallelism): Multiple updates to warmstart tests. (BlueCrescent, Nov 21, 2025)
56660fb  refactor(typing): Removed unused PydanticPytorchModuleNotListType. (BlueCrescent, Nov 21, 2025)
d24c7bc  fix(logging): correct value for num_total_stages (BlueCrescent, Nov 28, 2025)
534bf5a  docs(optimizer): Class docstring for OptimizersList. (BlueCrescent, Nov 28, 2025)
30bb0ac  fix(config): Correct model validation call. (BlueCrescent, Dec 1, 2025)
8aa7302  fix(config): Included validation aliases in parameter checks. (BlueCrescent, Dec 1, 2025)
aaafde3  feat(utility): Added better way for config alias deprecation and depr… (BlueCrescent, Dec 1, 2025)
0e527a9  refactor(config): missing type hints (BlueCrescent, Dec 1, 2025)
a8913e0  refactor: duplicate if statement (BlueCrescent, Dec 1, 2025)
92a0c2c  refactor: Unified use of model_parts instead of wrapped_model_or_part… (BlueCrescent, Dec 1, 2025)
ad9dfb3  test(huggingface): removed skip mark from fixture (no effect) (BlueCrescent, Dec 2, 2025)
6080f19  chore: Merge remote-tracking branch 'origin/main' into pp_multi_stage (BlueCrescent, Dec 2, 2025)
b88a060  feat(utility): model parts support for debugging components (BlueCrescent, Dec 2, 2025)
8 changes: 4 additions & 4 deletions config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml
@@ -194,7 +194,7 @@ device_mesh:
config:
device_type: cuda
data_parallel_replicate_degree: 1
pipeline_parallel_degree: 2
pipeline_parallel_degree: 4
data_parallel_shard_degree: -1
world_size: ${settings.cuda_env.world_size}

@@ -251,7 +251,7 @@ scheduled_pipeline:
loss_fn:
instance_key: loss_fn
pass_type: BY_REFERENCE
pp_schedule_name: gpipe
pp_schedule_name: Interleaved1F1B
batch_size: ${settings.step_profile.local_train_micro_batch_size}
microbatch_size: 2
pp_degree: ${device_mesh.config.pipeline_parallel_degree}
@@ -318,7 +318,7 @@ staged_pipeline:
instance_key: device_mesh
pass_type: BY_REFERENCE
local_rank: ${settings.cuda_env.local_rank}
pp_schedule_name: gpipe
pp_schedule_name: ${scheduled_pipeline.config.pp_schedule_name}
num_layers_per_stage: 2

model_raw:
@@ -332,7 +332,7 @@ model_raw:
sequence_length: ${settings.step_profile.sequence_length}
prediction_key: ${loss_fn.config.prediction_key}
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 2
n_layer: 6
n_head_q: 8
n_head_kv: 4
ffn_hidden: 128
(changes to a second config file; file name not shown in this view)
@@ -308,7 +308,7 @@ staged_pipeline:
instance_key: device_mesh
pass_type: BY_REFERENCE
local_rank: ${settings.cuda_env.local_rank}
pp_schedule_name: gpipe
pp_schedule_name: ${scheduled_pipeline.config.pp_schedule_name}
num_layers_per_stage: 2

model_raw:
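The two config diffs above switch the pipeline schedule from gpipe to Interleaved1F1B and let the staged pipeline inherit pp_schedule_name from the scheduled pipeline via interpolation; batch_size and microbatch_size presumably determine the number of microbatches per step (batch size divided by microbatch size), and the jump from 2 to 6 layers presumably gives the interleaved schedule enough layers to cut into multiple stages. As a rough, hypothetical sketch (build_schedule and the name mapping below are not the repository's actual factory code), such a name is typically resolved to one of the schedule classes in torch.distributed.pipelining; interleaved schedules differ from GPipe/1F1B in that each rank owns several stages, so the schedule takes a list of local stages:

```python
# Hypothetical sketch (not the repository's factory code): resolving a
# pp_schedule_name string from the config to a torch.distributed.pipelining class.
from torch.distributed.pipelining import (
    Schedule1F1B,
    ScheduleGPipe,
    ScheduleInterleaved1F1B,
)

# Lower-cased lookup so "gpipe" and "Interleaved1F1B" both resolve.
_SCHEDULES = {
    "gpipe": ScheduleGPipe,                      # one stage per rank
    "1f1b": Schedule1F1B,                        # one stage per rank
    "interleaved1f1b": ScheduleInterleaved1F1B,  # several stages per rank
}


def build_schedule(pp_schedule_name: str, stages: list, n_microbatches: int, loss_fn):
    """Instantiate a pipeline schedule from its config name (illustrative only)."""
    schedule_cls = _SCHEDULES[pp_schedule_name.lower()]
    if schedule_cls is ScheduleInterleaved1F1B:
        # Multi-stage schedules take all stages owned by this rank ...
        return schedule_cls(stages, n_microbatches=n_microbatches, loss_fn=loss_fn)
    # ... single-stage schedules take exactly one local stage.
    return schedule_cls(stages[0], n_microbatches=n_microbatches, loss_fn=loss_fn)
```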
3 changes: 2 additions & 1 deletion src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py
@@ -89,7 +89,8 @@ def _save_checkpoint(self, app_state: AppState, training_progress: TrainingProgr
# saving the model via FULL_STATE_DICT and checkpoint via FULL_OPTIM_STATE_DICT
model_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
optim_save_policy = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
model = app_state.model
assert len(app_state.model_parts) == 1, "FSDP1CheckpointSaving only supports a single model part."
model = app_state.model_parts[0]
optimizer = app_state.optimizer
with FSDP.state_dict_type(
module=model,
56 changes: 35 additions & 21 deletions src/modalities/checkpointing/stateful/app_state.py
@@ -15,6 +15,8 @@
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from modalities.optimizers.optimizer_list import OptimizersList


class StatefulComponents(Enum):
MODEL = "model"
@@ -34,15 +36,18 @@ class AppState(Stateful):
https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
"""

def __init__(self, model: nn.Module, optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None):
def __init__(
self, model: nn.Module | list[nn.Module], optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None
):
"""Initializes the AppState object.

Args:
model (nn.Module): The model can be either a non-sharded model, FSDP1 or FSDP2 model.
model (nn.Module | list[nn.Module]): The model or model parts can be either
a non-sharded model, FSDP1 or FSDP2 model.
optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer.
lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None.
"""
self._model = model
self._model_parts = list(model) if isinstance(model, list) else [model]
self._optimizer = optimizer
self._lr_scheduler = lr_scheduler
self._is_loaded = False
@@ -56,8 +61,8 @@ def is_loaded(self) -> bool:
return self._is_loaded

@property
def model(self) -> nn.Module:
return self._model
def model_parts(self) -> list[nn.Module]:
return self._model_parts

@property
def optimizer(self) -> Optimizer:
@@ -153,15 +158,15 @@ def get_state_dict(app_state: AppState) -> dict[str, Any]:
class ModelStateRetriever(StateRetrieverIF):
@staticmethod
def get_state_dict(app_state: AppState) -> dict[str, Any]:
"""Returns the state dict of the model in the AppState object.
"""Returns the flattened state dicts of the model parts in the AppState object.

Args:
app_state (AppState): The app_state object containing the model.

Returns:
dict[str, Any]: The state dict of the model in the AppState object.
"""
return get_model_state_dict(model=app_state.model)
return {k: v for sd in map(get_model_state_dict, app_state.model_parts) for k, v in sd.items()}

@staticmethod
def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None:
@@ -171,7 +176,8 @@ def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None:
app_state (AppState): The app_state object containing the model.
state_dict (dict[str, Any]): The state dict to load into the model.
"""
set_model_state_dict(model=app_state.model, model_state_dict=state_dict, options=StateDictOptions(strict=False))
for model in app_state.model_parts:
set_model_state_dict(model=model, model_state_dict=state_dict, options=StateDictOptions(strict=False))


class OptimizerStateRetriever(StateRetrieverIF):
@@ -185,13 +191,17 @@ def get_state_dict(app_state: AppState) -> dict[str, Any]:
Returns:
dict[str, Any]: The state dict of the optimizer in the AppState object.
"""
sd = get_optimizer_state_dict(
model=app_state.model,
optimizers=app_state.optimizer,
# NOTE: Flattening is required for pipeline parallelism to work correctly.
# see https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/components/checkpoint.py#L193-L214
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
if isinstance(app_state.optimizer, OptimizersList):
sd = app_state.optimizer.state_dict()
else:
assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer."
sd = get_optimizer_state_dict(
model=app_state.model_parts[0],
optimizers=app_state.optimizer,
# NOTE: Flattening is required for pipeline parallelism to work correctly.
# see https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/components/checkpoint.py#L193-L214
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
Review comment by the PR author (BlueCrescent) on lines +196 to +204: Should we remove this, since in the case of PP we now always have an optimizer list, which takes care of the flattening?

return sd

@staticmethod
Expand All @@ -202,12 +212,16 @@ def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None:
app_state (AppState): The app_state object containing the optimizer.
state_dict (dict[str, Any]): The state dict to load into the optimizer.
"""
set_optimizer_state_dict(
model=app_state.model,
optimizers=app_state.optimizer,
optim_state_dict=state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
if isinstance(app_state.optimizer, OptimizersList):
app_state.optimizer.load_state_dict(state_dict)
else:
assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer."
set_optimizer_state_dict(
model=app_state.model_parts[0],
optimizers=app_state.optimizer,
optim_state_dict=state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
)


class LRSchedulerStateRetriever(StateRetrieverIF):
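The app_state.py changes above route model and optimizer checkpointing through a list of model parts and, for pipeline parallelism, through OptimizersList (imported at the top of the file but not part of this diff). The sketch below is only an assumption about its shape, based on how it is used here and on the linked torchtitan NOTE: one optimizer per model part, with state_dict/load_state_dict built on torch.distributed.checkpoint.state_dict and flattened so the merged keys stay unique across pipeline stages.

```python
# Hypothetical sketch of an OptimizersList-style wrapper; the real class lives in
# modalities.optimizers.optimizer_list and is not shown in this diff.
from typing import Any

import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)
from torch.optim import Optimizer


class OptimizersListSketch:
    """One optimizer per pipeline model part, exposed through a single state dict."""

    def __init__(self, model_parts: list[nn.Module], optimizers: list[Optimizer]):
        assert len(model_parts) == len(optimizers)
        self._model_parts = model_parts
        self._optimizers = optimizers
        # Flattened keys are what make checkpoints work across pipeline stages.
        self._options = StateDictOptions(flatten_optimizer_state_dict=True)

    def state_dict(self) -> dict[str, Any]:
        # Merge the flattened per-part state dicts into one mapping.
        merged: dict[str, Any] = {}
        for model, optim in zip(self._model_parts, self._optimizers):
            merged.update(get_optimizer_state_dict(model=model, optimizers=optim, options=self._options))
        return merged

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Each part picks up the entries belonging to its own parameters.
        for model, optim in zip(self._model_parts, self._optimizers):
            set_optimizer_state_dict(
                model=model, optimizers=optim, optim_state_dict=state_dict, options=self._options
            )
```

If the real class behaves like this, the explicit flattening branch questioned in the review comment above would indeed be redundant for the PP case.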
5 changes: 3 additions & 2 deletions src/modalities/checkpointing/stateful/app_state_factory.py
@@ -15,13 +15,14 @@ class AppStateFactory:

@staticmethod
def get_raw_app_state(
model: nn.Module, optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None
model: nn.Module | list[nn.Module], optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None
) -> AppState:
"""Creates a new (non-checkpoint loaded) AppState object from an instantiated
model, optimizer, and optional learning rate scheduler.

Args:
model (nn.Module): The model can be either a non-sharded model, FSDP1 or FSDP2 model.
model (nn.Module | list[nn.Module]): The model (parts) can be either
a non-sharded model, FSDP1 or FSDP2 model.
optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer.
lr_scheduler (Optional[LRScheduler], optional): Lr scheduler used during training. Defaults to None.

48 changes: 36 additions & 12 deletions src/modalities/config/component_factory.py
@@ -1,6 +1,7 @@
from typing import Any, Type, TypeVar

from pydantic import BaseModel
from pydantic import AliasChoices, BaseModel
from pydantic.fields import FieldInfo

from modalities.registry.registry import Registry
from modalities.util import print_rank_0
@@ -164,30 +165,53 @@ def _instantiate_component_config(self, component_key: str, variant_key: str, co
config_dict=config_dict,
component_config_type=component_config_type,
)
comp_config = component_config_type(**config_dict, strict=True)
comp_config = component_config_type.model_validate(config_dict, extra="forbid")
return comp_config

def _assert_valid_config_keys(
self, component_key: str, variant_key: str, config_dict: dict, component_config_type: Type[BaseModelChild]
) -> None:
required_keys = []
optional_keys = []
for key, field in component_config_type.model_fields.items():
# Collect required and optional keys, including aliases if defined.
required_keys: list[str] = []
optional_keys: list[str] = []
# Map aliases to canonical field names for clearer error messages.
alias_to_field: dict[str, str] = {}

for field_name, field in component_config_type.model_fields.items():
names_for_field = self._parse_str_aliases(alias_to_field, field_name, field)
if field.is_required():
required_keys.append(key)
required_keys.extend(names_for_field)
else:
optional_keys.append(key)
optional_keys.extend(names_for_field)

invalid_keys = []
for key in config_dict.keys():
if key not in required_keys and key not in optional_keys:
invalid_keys.append(key)
all_valid_keys = set(required_keys) | set(optional_keys)

invalid_keys = [key for key in config_dict.keys() if key not in all_valid_keys]
if len(invalid_keys) > 0:
message = f"Invalid keys {invalid_keys} for config `{component_key}.{variant_key}`"
message += f" of type {component_config_type}:\n{config_dict}\n"
message += f"Required keys: {required_keys}\nOptional keys: {optional_keys}"
if alias_to_field:
message += f"Alias to field mapping: {alias_to_field}\n"
message += f"Required keys (including aliases): {required_keys}\n"
message += f"Optional keys (including aliases): {optional_keys}\n"
raise ValueError(message)

def _parse_str_aliases(self, alias_to_field: dict[str, str], field_name: str, field: FieldInfo) -> set[str]:
names_for_field = {field_name}
if field.alias and field.alias != field_name:
names_for_field.add(field.alias)
alias_to_field[field.alias] = field_name
if field.validation_alias and field.validation_alias != field_name:
if isinstance(field.validation_alias, str):
names_for_field.add(field.validation_alias)
alias_to_field[field.validation_alias] = field_name
elif isinstance(field.validation_alias, AliasChoices):
for alias in field.validation_alias.choices:
if isinstance(alias, str):
names_for_field.add(alias)
alias_to_field[alias] = field_name
return names_for_field

def _instantiate_component(self, component_key: str, variant_key: str, component_config: BaseModel) -> Any:
component_type: Type = self.registry.get_component(component_key, variant_key)
component_config_dict = self._base_model_to_dict(component_config)
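The component_factory.py change above makes the config key check alias-aware: for every field it now accepts the field name, its alias, and every string inside an AliasChoices validation alias, and it reports an alias-to-field mapping in the error message. A minimal, self-contained example of the Pydantic v2 behavior this relies on (the model and field names below are made up for illustration):

```python
# Illustration of the Pydantic v2 alias handling the updated key check relies on.
# ExampleOptimizerConfig and its fields are made up; only the AliasChoices API is real.
from pydantic import AliasChoices, BaseModel, Field


class ExampleOptimizerConfig(BaseModel):
    lr: float
    # The deprecated spelling "wrapped_model" is accepted alongside "model_parts".
    model_parts: list[str] = Field(validation_alias=AliasChoices("model_parts", "wrapped_model"))


# Both spellings validate, so _assert_valid_config_keys should accept either key.
ExampleOptimizerConfig.model_validate({"lr": 1e-4, "model_parts": ["stage_0"]})
ExampleOptimizerConfig.model_validate({"lr": 1e-4, "wrapped_model": ["stage_0"]})
```

With _parse_str_aliases collecting both spellings, a YAML config that still uses a deprecated key no longer trips the invalid-key check before Pydantic even sees it.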
21 changes: 12 additions & 9 deletions src/modalities/config/config.py
@@ -27,6 +27,7 @@
PydanticModelInitializationIFType,
PydanticOptimizerIFType,
PydanticPytorchDeviceType,
PydanticPytorchModuleOrListType,
PydanticPytorchModuleType,
PydanticSamplerIFType,
PydanticTokenizerIFType,
@@ -43,6 +44,7 @@
ActivationCheckpointingVariants,
)
from modalities.util import parse_enum_by_name
from modalities.utils.deprecated_alias import add_deprecated_alias


class ProcessGroupBackendType(LookupEnum):
@@ -145,7 +147,7 @@ class CheckpointSavingConfig(BaseModel):

class AdamOptimizerConfig(BaseModel):
lr: float
wrapped_model: PydanticPytorchModuleType
wrapped_model: PydanticPytorchModuleOrListType
betas: tuple[float, float]
eps: float
weight_decay: float
@@ -154,7 +156,7 @@ class AdamWOptimizerConfig(BaseModel):

class AdamWOptimizerConfig(BaseModel):
lr: float
wrapped_model: PydanticPytorchModuleType
wrapped_model: PydanticPytorchModuleOrListType
betas: tuple[float, float]
eps: float
weight_decay: float
@@ -264,7 +266,7 @@ def parse_sharding_strategy_by_name(cls, name: str) -> ShardingStrategy:


class FSDP2WrappedModelConfig(BaseModel):
model: PydanticPytorchModuleType
model: PydanticPytorchModuleOrListType
block_names: list[str]
mixed_precision_settings: FSDP2MixedPrecisionSettings
reshard_after_forward: bool = True
@@ -289,7 +291,7 @@ def validate_dp_mesh_existence(self):


class DebuggingEnrichedModelConfig(BaseModel):
model: PydanticPytorchModuleType
model: PydanticPytorchModuleOrListType
logging_dir_path: Path
tracked_ranks: Optional[Set[int]] = None
log_interval_steps: Optional[int] = 1
@@ -302,7 +304,7 @@ def convert_list_to_set(cls, v: Iterable[int] | None) -> Set[int] | None:


class GPT2ModelTPConfig(BaseModel):
model: PydanticPytorchModuleType # TODO set proper type
model: PydanticPytorchModuleOrListType # TODO set proper type
device_mesh: PydanticDeviceMeshIFType

@model_validator(mode="after")
@@ -325,7 +327,7 @@ class CompiledModelConfig(BaseModel):


class WeightInitializedModelConfig(BaseModel):
model: PydanticPytorchModuleType
model: PydanticPytorchModuleOrListType
model_initializer: PydanticModelInitializationIFType

# avoid warning about protected namespace 'model_', see
@@ -350,12 +352,12 @@ class SelectiveOpACParams(BaseModel):

ac_variant: ActivationCheckpointingVariants
layers_fqn: str
model: PydanticPytorchModuleType
model: PydanticPytorchModuleOrListType
ac_fun_params: FullACParams | SelectiveLayerACParams | SelectiveOpACParams


class RawAppStateConfig(BaseModel):
model: PydanticPytorchModuleType
model: PydanticPytorchModuleOrListType
optimizer: PydanticOptimizerIFType
lr_scheduler: Optional[PydanticLRSchedulerIFType] = None

@@ -480,12 +482,13 @@ class RichResultSubscriberConfig(BaseModel):
global_rank: int


@add_deprecated_alias("model_parts", "wrapped_model")
class GPT2MFUCalculatorConfig(BaseModel):
n_layer: Annotated[int, Field(strict=True, gt=0)]
sequence_length: Annotated[int, Field(strict=True, gt=0)]
n_embd: Annotated[int, Field(strict=True, gt=0)]
world_size: Annotated[int, Field(strict=True, gt=0)]
wrapped_model: PydanticFSDP1ModuleType | PydanticFSDP2ModuleType
model_parts: PydanticFSDP1ModuleType | PydanticFSDP2ModuleType | list[PydanticFSDP2ModuleType]
device_mesh: Optional[PydanticDeviceMeshIFType] = None


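In config.py, GPT2MFUCalculatorConfig renames wrapped_model to model_parts and is wrapped in @add_deprecated_alias("model_parts", "wrapped_model"). The decorator itself (modalities.utils.deprecated_alias) is not included in this diff, so the snippet below only illustrates its assumed behavior, not its implementation: the old key should keep validating, presumably with a deprecation warning, while new configs use model_parts.

```python
# Usage sketch only: assumes @add_deprecated_alias(new_name, old_name) keeps the
# old config key working (likely via a validation alias) and warns when it is used.
import warnings

from pydantic import BaseModel

from modalities.utils.deprecated_alias import add_deprecated_alias


@add_deprecated_alias("model_parts", "wrapped_model")
class MiniMFUConfig(BaseModel):  # made-up stand-in for GPT2MFUCalculatorConfig
    n_layer: int
    model_parts: list[str]


# New spelling: validates without any warning.
MiniMFUConfig.model_validate({"n_layer": 6, "model_parts": ["stage_0"]})

# Deprecated spelling: expected to still validate, emitting a DeprecationWarning.
with warnings.catch_warnings(record=True):
    warnings.simplefilter("always")
    MiniMFUConfig.model_validate({"n_layer": 6, "wrapped_model": ["stage_0"]})
```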
1 change: 1 addition & 0 deletions src/modalities/config/pydantic_if_types.py
@@ -66,6 +66,7 @@ def __get_pydantic_core_schema__(
CheckpointSavingExecutionABC, PydanticThirdPartyTypeIF(CheckpointSavingExecutionABC)
]
PydanticPytorchModuleType = Annotated[nn.Module, PydanticThirdPartyTypeIF(nn.Module)]
PydanticPytorchModuleOrListType = PydanticPytorchModuleType | list[PydanticPytorchModuleType]
PydanticFSDP1ModuleType = Annotated[FSDP1, PydanticThirdPartyTypeIF(FSDP1)]
PydanticFSDP2ModuleType = Annotated[FSDP2, PydanticThirdPartyTypeIF(FSDP2)]
PydanticTokenizerIFType = Annotated[TokenizerWrapper, PydanticThirdPartyTypeIF(TokenizerWrapper)]