diff --git a/examples/gpt-oss/sft.py b/examples/gpt-oss/sft.py index d85f5a072..9c2d6aeb8 100644 --- a/examples/gpt-oss/sft.py +++ b/examples/gpt-oss/sft.py @@ -72,7 +72,7 @@ def main(script_args, training_args, model_args, quant_args): "revision": model_args.model_revision, "trust_remote_code": model_args.trust_remote_code, "attn_implementation": model_args.attn_implementation, - "torch_dtype": model_args.torch_dtype, + "torch_dtype": getattr(model_args, "dtype", "float32"), "use_cache": not training_args.gradient_checkpointing, }