145 changes: 70 additions & 75 deletions src/modalities/__main__.py
@@ -54,10 +54,10 @@ def main() -> None:
help="Path to the YAML training config file.",
)
@click.option(
"--test_comm",
is_flag=True,
default=False,
help="If set, run a communication test before training.",
"--experiments_root_path",
type=click_pathlib.Path(exists=True),
required=True,
help="Path to the root directory where experiment folders will be created.",
)
@click.option(
"--experiment_id",
@@ -71,61 +71,51 @@ def main() -> None:
default=None,
help="Optional path to a folder where error logs will be written.",
)
@click.option(
"--test_comm",
is_flag=True,
default=False,
help="If set, run a communication test before training.",
)
def CMD_entry_point_run_modalities(
config_file_path: Path,
test_comm: bool = False,
experiments_root_path: Path,
experiment_id: Optional[str] = None,
error_log_folder: Optional[Path] = None,
test_comm: bool = False,
):
"""Entrypoint to run the model training.

Args:
config_file_path (Path): Path to the YAML training config file.
test_comm (bool): If set, run a communication test before training.
experiments_root_path (Path): Path to the root directory where experiment folders will be created.
experiment_id (Optional[str]): Optional experiment ID to use for this run.
If not provided, it will be generated. Defaults to None.
error_log_folder (Optional[Path]): Optional path to a folder where error logs will be written.
test_comm (bool): If set, run a communication test before training.
"""

def _format_exception_as_json(e: Exception, environment: dict[str, Any]) -> str:
# Format an exception into a structured JSON string with error message, type, and stack trace.
error = {
"error": str(e),
"type": type(e).__name__,
"stacktrace": traceback.format_exception(type(e), e, e.__traceback__),
}

return json.dumps({"environment": environment, "error": error}, indent=2)

try:
with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl):
if test_comm:
print_rank_0("Running communication test...")
run_communication_test()
print_rank_0("Communication test succeeded.")

main_obj = Main(config_file_path, experiment_id=experiment_id)
main_obj = Main(config_file_path, experiments_root_path=experiments_root_path, experiment_id=experiment_id)
components = main_obj.build_components(components_model_type=TrainingComponentsInstantiationModel)
main_obj.run(components)
except Exception as e:
if error_log_folder is not None:
environment = {
"rank": int(os.environ["RANK"] if "RANK" in os.environ else -1),
"local_rank": int(os.environ["LOCAL_RANK"] if "LOCAL_RANK" in os.environ else -1),
"world_size": int(os.environ["WORLD_SIZE"] if "WORLD_SIZE" in os.environ else -1),
"hostname": socket.gethostname(),
}
error_log_folder = (
error_log_folder / f"error_logs_{environment['hostname']}_{environment['local_rank']}.log"
)
error_log_folder.parent.mkdir(parents=True, exist_ok=True)
with open(error_log_folder, "w", encoding="utf-8") as f:
f.write(_format_exception_as_json(e, environment))

raise RuntimeError(f"An error occurred while running the training: {e}. ") from e
_exception_handling(e, error_log_folder)


@main.command(name="warmstart")
@click.option(
"--experiments_root_path",
type=click_pathlib.Path(exists=True),
required=True,
help="Path to the root directory where experiment folders will be created.",
)
@click.option(
"--config_file_path",
type=click_pathlib.Path(exists=True),
@@ -138,10 +128,22 @@ def _format_exception_as_json(e: Exception, environment: dict[str, Any]) -> str:
required=True,
help="Path to the file containing the model and optimizer checkpoint paths from the last successful checkpoint.",
)
def CMD_entry_point_warmstart_modalities(config_file_path: Path, last_checkpoint_info_file_path: Path):
@click.option(
"--error_log_folder",
type=click_pathlib.Path(),
default=None,
help="Optional path to a folder where error logs will be written.",
)
def CMD_entry_point_warmstart_modalities(
experiments_root_path: Path,
config_file_path: Path,
last_checkpoint_info_file_path: Path,
error_log_folder: Optional[Path] = None,
):
"""Entrypoint to run the model warmstart.

Args:
experiments_root_path (Path): Path to the root directory where experiment folders will be created.
config_file_path (Path): Path to the YAML warmstart config file.
last_checkpoint_info_file_path (Path): Path to the file containing the model and
optimizer checkpoint paths from the last successful checkpoint.
@@ -159,10 +161,15 @@ def get_last_checkpoint_resolver_fun(var_name: str, last_checkpoint_info_file_pa
get_last_checkpoint_resolver_fun, last_checkpoint_info_file_path=last_checkpoint_info_file_path
)
}
with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl):
main_obj = Main(config_file_path, additional_resolver_funs=resolver_funs)
components = main_obj.build_components(components_model_type=TrainingComponentsInstantiationModel)
main_obj.run(components)
try:
with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl):
main_obj = Main(
config_file_path, experiments_root_path=experiments_root_path, additional_resolver_funs=resolver_funs
)
components = main_obj.build_components(components_model_type=TrainingComponentsInstantiationModel)
main_obj.run(components)
except Exception as e:
_exception_handling(e, error_log_folder)


@main.command(name="generate_text")
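Both the run and warmstart entry points now thread experiments_root_path into Main and route failures through the shared _exception_handling helper moved to the bottom of this module. As a minimal sketch (assuming this module's imports are in scope; the paths are illustrative placeholders, not CLI defaults), the run command reduces to:

# A minimal sketch mirroring CMD_entry_point_run_modalities after argument parsing.
from pathlib import Path

with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl):
    main_obj = Main(
        Path("configs/train.yaml"),                 # --config_file_path (illustrative)
        experiments_root_path=Path("experiments"),  # --experiments_root_path (now required)
        experiment_id=None,                         # generated when not supplied
    )
    components = main_obj.build_components(components_model_type=TrainingComponentsInstantiationModel)
    main_obj.run(components)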
@@ -705,54 +712,42 @@ def profile():
required=True,
help="Path to the experiment output directory.",
)
@click.option(
"--num_wait_steps",
type=int,
default=1,
show_default=True,
help="Number of wait steps to skip in profiling.",
)
@click.option(
"--num_warmup_steps",
type=int,
default=1,
show_default=True,
help="Number of warmup steps to skip in profiling. Already recording but dropping the data.",
)
@click.option(
"--num_measurement_steps",
type=int,
default=3,
show_default=True,
help="Number of steps to measure during profiling.",
)
@click.option(
"--profiled_ranks",
type=str,
default="0",
help="Comma-separated list of profiled ranks (must not have spaces), e.g. --profiled_ranks '2,4,8'",
)
def CMD_entry_point_run_train_step_profiler(
config_file_path: Path,
experiment_root_path: Path,
num_wait_steps: int,
num_warmup_steps: int,
num_measurement_steps: int,
profiled_ranks: str,
):
"""Run train step profiler and write result to JSON if RANK=0."""
profiled_ranks_list = [int(i) for i in profiled_ranks.split(",")] if profiled_ranks != "" else [0]
logger.info(f"Running distributed profiling on ranks {profiled_ranks_list}")

ModalitiesProfilerStarter.run_distributed(
config_file_path=config_file_path,
num_measurement_steps=num_measurement_steps,
num_wait_steps=num_wait_steps,
num_warmup_steps=num_warmup_steps,
experiment_root_path=experiment_root_path,
profiled_ranks=profiled_ranks_list,
)


def _format_exception_as_json(e: Exception, environment: dict[str, Any]) -> str:
# Format an exception into a structured JSON string with error message, type, and stack trace.
error = {
"error": str(e),
"type": type(e).__name__,
"stacktrace": traceback.format_exception(type(e), e, e.__traceback__),
}
return json.dumps({"environment": environment, "error": error}, indent=2)


def _exception_handling(e: Exception, error_log_folder: Path | None):
if error_log_folder is not None:
environment = {
"rank": int(os.environ["RANK"] if "RANK" in os.environ else -1),
"local_rank": int(os.environ["LOCAL_RANK"] if "LOCAL_RANK" in os.environ else -1),
"world_size": int(os.environ["WORLD_SIZE"] if "WORLD_SIZE" in os.environ else -1),
"hostname": socket.gethostname(),
}
error_log_folder = error_log_folder / f"error_logs_{environment['hostname']}_{environment['local_rank']}.log"
error_log_folder.parent.mkdir(parents=True, exist_ok=True)
with open(error_log_folder, "w", encoding="utf-8") as f:
f.write(_format_exception_as_json(e, environment))

raise RuntimeError(f"An error occurred while running the training: {e}. ") from e


if __name__ == "__main__":
main()
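To make the relocated handler concrete, a small hypothetical driver is sketched below; the log folder is invented for illustration, and the handler re-raises the original exception wrapped in a RuntimeError after writing the per-rank log file:

# A minimal sketch, assuming _exception_handling is imported from this module and
# run outside a launcher (RANK etc. unset, so the environment fields fall back to -1).
from pathlib import Path

try:
    raise ValueError("simulated training failure")
except Exception as e:
    try:
        _exception_handling(e, Path("/tmp/error_logs"))  # hypothetical folder
    except RuntimeError:
        # /tmp/error_logs/error_logs_<hostname>_<local_rank>.log now holds the JSON
        # report with "environment" and "error" (message, type, stacktrace) sections.
        pass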
11 changes: 10 additions & 1 deletion src/modalities/config/config.py
@@ -150,6 +150,8 @@ class AdamOptimizerConfig(BaseModel):
eps: float
weight_decay: float
weight_decay_groups_excluded: list[str]
foreach: bool | None = None
fused: bool | None = None


class AdamWOptimizerConfig(BaseModel):
Expand All @@ -159,6 +161,8 @@ class AdamWOptimizerConfig(BaseModel):
eps: float
weight_decay: float
weight_decay_groups_excluded: list[str]
foreach: bool | None = None
fused: bool | None = None


class DummyLRSchedulerConfig(BaseModel):
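The new foreach and fused fields look like pass-throughs to the implementation-selection keyword arguments that torch.optim.Adam and torch.optim.AdamW already accept. A minimal sketch under that assumption (the model and hyperparameters are illustrative):

# A minimal sketch, assuming the optimizer factory forwards the new config
# fields verbatim to torch.optim.AdamW.
import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.1,
    foreach=None,  # None lets PyTorch choose; True forces the multi-tensor path
    fused=None,    # True selects the fused CUDA kernel (CUDA floating-point params only)
)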
@@ -501,13 +505,16 @@ class ParallelDegreeConfig(BaseModel):

def load_app_config_dict(
config_file_path: Path,
experiment_id: Optional[str] = None,
experiments_root_path: Path | None = None,
experiment_id: str | None = None,
additional_resolver_funs: Optional[dict[str, Resolver]] = None,
) -> dict[str, YAMLValue]:
"""Load the application configuration from the given YAML file.

Args:
config_file_path (Path): YAML config file.
experiments_root_path (Path, optional): The path to the experiments root directory.
Defaults to None.
experiment_id (str, optional): The experiment_id of the current run.
additional_resolver_funs (dict[str, Resolver], optional): Additional resolver functions.

@@ -534,6 +541,8 @@ def node_env_resolver_fun(var_name: str) -> int | None:
"config_file_path": config_file_path,
"config_folder_path": config_file_path.parent,
}
if experiments_root_path is not None:
modalities_env_kwargs["experiments_root_path"] = experiments_root_path
if experiment_id is not None:
modalities_env_kwargs["experiment_id"] = experiment_id
OmegaConf.register_new_resolver(
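A minimal usage sketch of the extended loader follows; the paths and experiment ID are illustrative, and as the hunk above shows, each optional argument is only added to the resolver environment when it is not None:

# A minimal sketch of calling load_app_config_dict with the new parameter.
from pathlib import Path

config_dict = load_app_config_dict(
    config_file_path=Path("configs/train.yaml"),  # illustrative path
    experiments_root_path=Path("experiments"),    # optional; forwarded to the resolver env
    experiment_id="my_experiment",                # optional; illustrative ID
)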
13 changes: 8 additions & 5 deletions src/modalities/config/instantiation_models.py
@@ -18,12 +18,14 @@
PydanticPipelineType,
PydanticPytorchDeviceType,
PydanticPytorchModuleType,
PydanticSteppableProfilerIFType,
PydanticTextInferenceComponentType,
PydanticTokenizerIFType,
)
from modalities.config.utils import parse_torch_device
from modalities.dataloader.dataset import Dataset
from modalities.util import warn_rank_0
from modalities.utils.profilers.profilers import SteppableNoProfiler


class CudaEnvSettings(BaseModel):
@@ -67,7 +69,7 @@ class TrainingProgress(BaseModel):
class TrainingComponentsInstantiationModel(BaseModel):
class Settings(BaseModel):
class Paths(BaseModel):
checkpoint_saving_path: Path # Explicitly defined field
experiments_root_path: Path # Explicitly defined field

class Config:
extra = "allow"
@@ -182,13 +184,14 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel
evaluation_subscriber: PydanticMessageSubscriberIFType
checkpoint_saving: PydanticCheckpointSavingIFType
gradient_clipper: PydanticGradientClipperIFType
mfu_calculator: Optional[PydanticMFUCalculatorABCType] = None
scheduled_pipeline: Optional[PydanticPipelineType] = None
device_mesh: Optional[PydanticDeviceMeshIFType] = None
profiler: PydanticSteppableProfilerIFType = SteppableNoProfiler()
mfu_calculator: PydanticMFUCalculatorABCType | None = None
scheduled_pipeline: PydanticPipelineType | None = None
device_mesh: PydanticDeviceMeshIFType | None = None
model_raw: PydanticPytorchModuleType

@model_validator(mode="after")
def _check_token_amount_in_dataset(self) -> "TrainingComponentsInstantiationModel.Settings":
def _check_token_amount_in_dataset(self) -> "TrainingComponentsInstantiationModel":
if (
len(self.train_dataset) * self.settings.step_profile.sequence_length
< self.settings.training_target.num_target_tokens
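The corrected return annotation aside, the validator still guards the same token-budget condition, which can be sketched with illustrative numbers:

# A minimal sketch of the condition _check_token_amount_in_dataset tests; the
# numbers are illustrative, and the validator flags configs where the dataset
# cannot cover the training target.
num_samples = 1_000_000            # len(train_dataset)
sequence_length = 4_096            # settings.step_profile.sequence_length
num_target_tokens = 3_000_000_000  # settings.training_target.num_target_tokens

dataset_covers_target = num_samples * sequence_length >= num_target_tokens
print(dataset_covers_target)  # True: 4,096,000,000 tokens are available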
2 changes: 2 additions & 0 deletions src/modalities/config/pydantic_if_types.py
@@ -29,6 +29,7 @@
from modalities.utils.debug_components import Debugging
from modalities.utils.mfu import MFUCalculatorABC
from modalities.utils.profilers.batch_generator import DatasetBatchGeneratorIF
from modalities.utils.profilers.profilers import SteppableProfilerIF
from modalities.utils.profilers.steppable_components import SteppableComponentIF


@@ -91,6 +92,7 @@ def __get_pydantic_core_schema__(
PydanticPipelineType = Annotated[Pipeline, PydanticThirdPartyTypeIF(Pipeline)]
PydanticPipelineStageType = Annotated[PipelineStage, PydanticThirdPartyTypeIF(PipelineStage)]
PydanticSteppableComponentIFType = Annotated[SteppableComponentIF, PydanticThirdPartyTypeIF(SteppableComponentIF)]
PydanticSteppableProfilerIFType = Annotated[SteppableProfilerIF, PydanticThirdPartyTypeIF(SteppableProfilerIF)]
PydanticRemovableHandleType = Annotated[
torch.utils.hooks.RemovableHandle, PydanticThirdPartyTypeIF(torch.utils.hooks.RemovableHandle)
]
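The new alias follows this file's existing Annotated pattern. A minimal sketch of a field using it, assuming SteppableNoProfiler implements SteppableProfilerIF (its use as a default in instantiation_models.py suggests so) and that PydanticThirdPartyTypeIF validates instances of the wrapped interface:

# A minimal sketch, assuming both names are importable as shown in the diffs above.
from pydantic import BaseModel

class ProfilerHolder(BaseModel):
    profiler: PydanticSteppableProfilerIFType

holder = ProfilerHolder(profiler=SteppableNoProfiler())  # an IF instance validates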