fix: correct batch counting and squeeze dim in prott5_embedder.py #170

Open
haoyu-haoyu wants to merge 1 commit into agemagician:master from haoyu-haoyu:fix/embedder-bugs

@haoyu-haoyu

Summary

Two bugs in Embedding/prott5_embedder.py:

  • Double-counting in batch residue accumulation (line 100): The current sequence is appended to batch before n_res_batch is computed, so seq_len is counted once inside the sum() over the batch and again via the explicit + seq_len. As a result, the n_res_batch >= max_residues threshold triggers too early, producing smaller batches than intended and slowing embedding.

  • Bare .squeeze() without dim (line 131): For single-residue proteins in per-residue mode, the embedding shape (1, 1024) is silently collapsed to (1024,), making it indistinguishable from a per-protein embedding. Changed to .squeeze(0) to only affect the batch dimension.
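
To make the first bug concrete, here is a minimal sketch of the accumulation pattern described above. It is not the code from prott5_embedder.py: `make_batches` and its `double_count` flag are hypothetical names used only for illustration, and the loop is reduced to the residue-budget check.

```python
def make_batches(seqs, max_residues, double_count=False):
    """Group sequences into batches, flushing once the residue budget is reached.

    Hypothetical, simplified loop; `double_count=True` reproduces the
    redundant `+ seq_len` term described in the summary.
    """
    batches, batch = [], []
    for seq in seqs:
        seq_len = len(seq)
        batch.append(seq)  # the current sequence is now part of the batch
        n_res_batch = sum(len(s) for s in batch)  # already includes seq_len
        if double_count:
            n_res_batch += seq_len  # counts the current sequence a second time
        if n_res_batch >= max_residues:
            batches.append(batch)
            batch = []
    if batch:  # flush whatever is left after the last sequence
        batches.append(batch)
    return batches


seqs = ["A" * 30] * 4  # four toy sequences of 30 residues each
fixed_sizes = [len(b) for b in make_batches(seqs, max_residues=100)]
buggy_sizes = [len(b) for b in make_batches(seqs, max_residues=100, double_count=True)]
print(fixed_sizes, buggy_sizes)
```

With a budget of 100 residues, the corrected count keeps all four sequences in one batch, while the double-counted version flushes after three and leaves the fourth in a batch of its own.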

Test plan

  • Verified the double-counting by tracing the logic: after batch.append(...), the batch already contains the current sequence, so the separate + seq_len is redundant
  • .squeeze(0) is a no-op when the first dimension is > 1, preserving existing behavior for all multi-residue proteins
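
The shape behavior relied on above can be sketched without any tensors. The helper below is a pure-Python model of the documented shape semantics of `torch.Tensor.squeeze`; `squeeze_shape` is a hypothetical name, and the leading batch dimension of size 1 in the example shapes is an assumption, since the tensor shape at the call site is not shown in this PR.

```python
def squeeze_shape(shape, dim=None):
    """Return the shape a tensor would have after squeeze.

    Models torch.Tensor.squeeze: with no dim, every size-1 dimension is
    dropped; with a dim, only that dimension is dropped, and only if its
    size is 1 (otherwise the shape is returned unchanged).
    """
    shape = tuple(shape)
    if dim is None:
        return tuple(n for n in shape if n != 1)
    dim = dim % len(shape)  # allow negative dims
    if shape[dim] == 1:
        return shape[:dim] + shape[dim + 1:]
    return shape


# Assumed layout: (batch, residues, features) with a batch of one protein.
single_residue = (1, 1, 1024)
multi_residue = (1, 5, 1024)

bare_single = squeeze_shape(single_residue)      # residue axis is lost too
dim0_single = squeeze_shape(single_residue, 0)   # only the batch axis is removed
bare_multi = squeeze_shape(multi_residue)
dim0_multi = squeeze_shape(multi_residue, 0)
noop = squeeze_shape((5, 1024), 0)               # first dimension > 1: unchanged
print(bare_single, dim0_single, bare_multi, dim0_multi, noop)
```

Under that assumed layout, the two calls differ only for the single-residue case. The no-op behavior for a first dimension larger than 1 is specific to the tensor method: NumPy's `ndarray.squeeze(axis=0)` raises a `ValueError` in that situation, so it matters whether the call is made before or after conversion to a NumPy array.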

- Fix double-counting in batch residue accumulation: `seq_len` was
  counted once via the batch list (after append) and once via the
  explicit `+ seq_len` term, making batches smaller than intended.
  Removed the redundant `+ seq_len`.

- Specify dim in `.squeeze(0)` instead of bare `.squeeze()`. Without
  a dim argument, single-residue proteins have their per-residue
  embedding shape (1, 1024) silently collapsed to (1024,), making
  them indistinguishable from per-protein embeddings.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request updates the batch residue counting logic and refines the embedding processing in Embedding/prott5_embedder.py. The batch residue count no longer adds the current sequence length a second time before the batch limit check, and the squeeze operation on embeddings is now restricted to the first dimension to prevent accidental removal of other singleton dimensions. I have no feedback to provide.
