diff --git a/submissions/self_tuning/muon_torch/submission.py b/submissions/self_tuning/muon_torch/submission.py index 47d262da..9919b1cd 100644 --- a/submissions/self_tuning/muon_torch/submission.py +++ b/submissions/self_tuning/muon_torch/submission.py @@ -38,7 +38,7 @@ "warmup_factor": 0.05, "step_reduce": 1.0 } -hyperparameters = SimpleNamespace(**HPARAMS) +HPARAMS = SimpleNamespace(**HPARAMS) def _pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): @@ -65,6 +65,8 @@ def init_optimizer_state( del model_state del rng + hyperparameters = HPARAMS + muon_params, adam_params = split_params_muon_adam(model_params) optimizer_state = { @@ -119,6 +121,8 @@ def update_params( del train_state del eval_results + hyperparameters = HPARAMS + reduced_steps = int(workload.step_hint * getattr(hyperparameters, "step_reduce", 1.0)) if global_step >= reduced_steps: raise spec.TrainingCompleteError(