fix: resolve crashes in fine-tuning notebooks (half_precision, cuda:0, torch.load) #169
haoyu-haoyu wants to merge 1 commit into agemagician:master
Conversation
- Add default `half_precision=False` to `PT5_classification_model()`.
The training function calls it as
`PT5_classification_model(num_labels=num_labels)` without passing
`half_precision`, causing a TypeError crash. Affects all 3 notebooks.
- Fix hardcoded `to('cuda:0')` in per-residue classification notebook.
Changed to `to(logits.device)` so CPU users are not forced to have
a CUDA device. The per-residue regression notebook already handles
this correctly.
- Add explicit `weights_only=False` to `torch.load(filepath)` in
`load_model()` across all 3 notebooks. PyTorch 2.6+ defaults to
`weights_only=True` which would break loading LoRA parameters.
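The first fix can be illustrated with a simplified stand-in for the notebook's function (the real `PT5_classification_model` builds a ProtT5 model with a classification head; this hypothetical stub only shows the signature change that avoids the TypeError):

```python
import torch

def PT5_classification_model(num_labels, half_precision=False):
    # hypothetical stub: the real function loads ProtT5 weights;
    # the fix is giving half_precision a default in the signature
    dtype = torch.float16 if half_precision else torch.float32
    return {"num_labels": num_labels, "dtype": dtype}

# the training functions call it without half_precision,
# which previously raised "missing required positional argument"
model = PT5_classification_model(num_labels=3)
```

With the default in place, the call matches the training functions' `mixed=False` default behavior.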
Code Review
This pull request modifies several PT5 fine-tuning notebooks to include default arguments for model initialization, enable non-restrictive model loading, and implement dynamic device placement for tensors. The review feedback suggests adopting a more idiomatic PyTorch pattern by combining device and type casting into a single call.
```diff
 valid_labels=active_labels[active_labels!=-100]

-valid_labels=valid_labels.type(torch.LongTensor).to('cuda:0')
+valid_labels=valid_labels.type(torch.LongTensor).to(logits.device)
```
While your change correctly addresses the hardcoded device issue, it's more idiomatic in modern PyTorch to use the .to() method for both device and type casting. The .type() method is considered legacy. Combining these into a single .to() call is cleaner and more efficient. Also, torch.long is preferred over torch.LongTensor when specifying a dtype.
```python
valid_labels=valid_labels.to(device=logits.device, dtype=torch.long)
```
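The suggested idiom can be sketched in isolation (placeholder tensors here; in the notebook, `logits` is the model output, so its device follows the model automatically):

```python
import torch

# stand-in tensors; in the notebook, logits comes from the model
logits = torch.zeros(4, 3)
active_labels = torch.tensor([0, 1, -100, 2])

# drop padded/ignored positions
valid_labels = active_labels[active_labels != -100]

# one .to() call handles both device placement and dtype casting,
# replacing the legacy .type(torch.LongTensor) pattern
valid_labels = valid_labels.to(device=logits.device, dtype=torch.long)
```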
Summary
Three bugs that prevent the LoRA fine-tuning notebooks from running:
- `TypeError` crash: `PT5_classification_model(num_labels, half_precision)` has no default for `half_precision`, but `train_per_protein()`/`train_per_residue()` call it as `PT5_classification_model(num_labels=num_labels)` without passing the argument. Added `half_precision=False` as the default, matching the training function's `mixed=False` default. (All 3 notebooks)
- Hardcoded `to('cuda:0')` crash: In the per-residue classification notebook, `valid_labels` is unconditionally moved to `cuda:0`, crashing on CPU-only systems. Changed to `to(logits.device)` to match whatever device the model is on. The regression notebook already handles this correctly. (per_residue_class only)
- `torch.load` deprecation: Added explicit `weights_only=False` to suppress the `FutureWarning` in PyTorch 2.6+ (`weights_only` will default to `True`). These checkpoints contain LoRA parameter dicts that require full unpickling. (All 3 notebooks)

Files changed

- `PT5_LoRA_Finetuning_per_prot.ipynb`: `half_precision` default, `torch.load`
- `PT5_LoRA_Finetuning_per_residue_class.ipynb`: `half_precision` default, `cuda:0`, `torch.load`
- `PT5_LoRA_Finetuning_per_residue_reg.ipynb`: `half_precision` default, `torch.load`

Test plan

- `half_precision=False` default matches the behavior expected by the training function
- `to(logits.device)` automatically follows whatever device the model is on
- `weights_only=False` preserves existing behavior while suppressing warnings
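The `torch.load` fix can be exercised with a throwaway checkpoint (a sketch; the checkpoint contents are hypothetical stand-ins for the LoRA parameter dicts the notebooks save):

```python
import os
import tempfile

import torch

# a checkpoint mixing tensors with plain Python objects,
# similar in spirit to the LoRA state saved by the notebooks
checkpoint = {"weights": {"lora_A": torch.zeros(2, 2)}, "meta": {"rank": 4}}

path = os.path.join(tempfile.mkdtemp(), "ckpt.pt")
torch.save(checkpoint, path)

# passing weights_only=False explicitly keeps the pre-2.6 full-unpickling
# behavior and avoids the FutureWarning on newer PyTorch versions
state = torch.load(path, weights_only=False)
```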