fix: resolve crashes in fine-tuning notebooks (half_precision, cuda:0, torch.load)#169

Open
haoyu-haoyu wants to merge 1 commit into agemagician:master from haoyu-haoyu:fix/notebook-crashes

Conversation

@haoyu-haoyu

Summary

This PR fixes three bugs that prevent the LoRA fine-tuning notebooks from running:

  • TypeError crash: PT5_classification_model(num_labels, half_precision) has no default for half_precision, but train_per_protein() / train_per_residue() call it as PT5_classification_model(num_labels=num_labels) without passing the argument. Added half_precision=False as the default, matching the training function's mixed=False default. (All 3 notebooks)

  • Hardcoded to('cuda:0') crash: In the per-residue classification notebook, valid_labels is unconditionally moved to cuda:0, crashing on CPU-only systems. Changed to to(logits.device) to match whatever device the model is on. The regression notebook already handles this correctly. (per_residue_class only)

  • torch.load deprecation: Added explicit weights_only=False to suppress the FutureWarning in PyTorch 2.6+ (weights_only will default to True). These checkpoints contain LoRA parameter dicts that require full unpickling. (All 3 notebooks)
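The first fix boils down to giving the model factory a default that mirrors the training function's `mixed=False` default. A minimal sketch (the function bodies are stand-ins for the notebooks' real ProtT5/LoRA construction):

```python
# Sketch of the default-argument fix; the real PT5_classification_model
# builds a ProtT5 model with LoRA adapters, stubbed out here.
def PT5_classification_model(num_labels, half_precision=False):
    # Before the fix, half_precision had no default, so calling the
    # factory as PT5_classification_model(num_labels=...) raised a TypeError.
    return {"num_labels": num_labels,
            "precision": "fp16" if half_precision else "fp32"}

def train_per_protein(num_labels, mixed=False):
    # The training functions call the factory without half_precision,
    # relying on the new default, which mirrors mixed=False.
    return PT5_classification_model(num_labels=num_labels)

model = train_per_protein(num_labels=2)
```

With the default in place, the existing call sites work unchanged; passing `half_precision=True` explicitly still enables the half-precision path.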

Files changed

| Notebook | Fix |
| --- | --- |
| PT5_LoRA_Finetuning_per_prot.ipynb | `half_precision` default, `torch.load` |
| PT5_LoRA_Finetuning_per_residue_class.ipynb | `half_precision` default, `cuda:0`, `torch.load` |
| PT5_LoRA_Finetuning_per_residue_reg.ipynb | `half_precision` default, `torch.load` |

Test plan

  • Verified the half_precision=False default matches the behavior expected by the training function
  • to(logits.device) automatically follows whatever device the model is on
  • weights_only=False preserves existing behavior while suppressing warnings
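The device fix works because `.to(logits.device)` always follows wherever the model placed its outputs. A CPU-only sketch with hypothetical stand-in tensors (the same line runs unchanged on GPU):

```python
import torch

# Stand-ins for the notebook's tensors; on a GPU machine, logits would
# live on a CUDA device and valid_labels would follow it automatically.
logits = torch.randn(4, 3)
active_labels = torch.tensor([0, -100, 2, 1])

valid_labels = active_labels[active_labels != -100]
# Fixed line: follow the model's device instead of hardcoding 'cuda:0'.
valid_labels = valid_labels.type(torch.LongTensor).to(logits.device)
```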

- Add default `half_precision=False` to `PT5_classification_model()`.
  The training function calls it as
  `PT5_classification_model(num_labels=num_labels)` without passing
  `half_precision`, causing a TypeError crash. Affects all 3 notebooks.

- Fix hardcoded `to('cuda:0')` in per-residue classification notebook.
  Changed to `to(logits.device)` so CPU users are not forced to have
  a CUDA device. The per-residue regression notebook already handles
  this correctly.

- Add explicit `weights_only=False` to `torch.load(filepath)` in
  `load_model()` across all 3 notebooks. PyTorch 2.6+ defaults to
  `weights_only=True` which would break loading LoRA parameters.
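The `torch.load` change can be sketched with a small in-memory round trip (the checkpoint dict here is a stand-in for the notebooks' saved LoRA parameters):

```python
import io
import torch

# Stand-in checkpoint: a dict mixing tensors and plain Python objects,
# like the LoRA parameter dicts the notebooks save.
checkpoint = {"lora_A": torch.zeros(2, 2), "config": {"r": 4}}
buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)

# Explicit weights_only=False preserves full-unpickling behavior and
# silences the FutureWarning ahead of PyTorch 2.6's weights_only=True default.
loaded = torch.load(buffer, weights_only=False)
```

Note that `weights_only=False` requires trusting the checkpoint source, since full unpickling can execute arbitrary code; here the checkpoints are produced by the notebooks themselves.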

@gemini-code-assist (bot) left a comment


Code Review

This pull request modifies several PT5 fine-tuning notebooks to include default arguments for model initialization, enable non-restrictive model loading, and implement dynamic device placement for tensors. The review feedback suggests adopting a more idiomatic PyTorch pattern by combining device and type casting into a single call.

```diff
 valid_labels=active_labels[active_labels!=-100]

-valid_labels=valid_labels.type(torch.LongTensor).to('cuda:0')
+valid_labels=valid_labels.type(torch.LongTensor).to(logits.device)
```


Severity: medium

While your change correctly addresses the hardcoded device issue, it's more idiomatic in modern PyTorch to use the .to() method for both device and type casting. The .type() method is considered legacy. Combining these into a single .to() call is cleaner and more efficient. Also, torch.long is preferred over torch.LongTensor when specifying a dtype.

```python
valid_labels = valid_labels.to(device=logits.device, dtype=torch.long)
```
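On CPU the two spellings produce identical results; a quick sanity check of the suggested single-call form against the legacy cast (`"cpu"` stands in for `logits.device` here):

```python
import torch

valid_labels = torch.tensor([0.0, 2.0, 1.0])
legacy = valid_labels.type(torch.LongTensor)              # legacy cast
modern = valid_labels.to(device="cpu", dtype=torch.long)  # single .to() call
assert torch.equal(legacy, modern)
```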
