Llama3 Context Parallel Fixes and Performance edits#1461

Merged
pstjohn merged 1 commit into NVIDIA:main from pstjohn:pstjohn/bio-230-refactor-perf_logger-to-only-update-metrics-every on Feb 13, 2026

Conversation

@pstjohn (Collaborator) commented Feb 11, 2026

A collection of small performance improvements and bug fixes for Llama3 context parallel (CP) training.

Summary by CodeRabbit

Release Notes

  • New Features

    • Per-sequence padding control in data processing pipelines
    • Asynchronous batch prefetching for improved training performance
    • NVIDIA Nsight Systems profiling integration for performance analysis
  • Documentation

    • Added performance profiling guide with Nsight Systems configuration and best practices
  • Refactor

    • Updated profiler configuration from schedule-based to explicit step-range format
    • Updated Docker base image for improved compatibility

@coderabbitai bot (Contributor) commented Feb 11, 2026

Important

Review skipped

Auto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.


@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (7)
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (1)

50-87: ⚠️ Potential issue | 🟡 Minor

Widen grad_norm type annotation to accept both Tensor and float.

Training code passes float to log_step() (via .item() on the gradient norm), but the function signature expects torch.Tensor. This causes a type mismatch that Pyright will flag. Update the annotation to grad_norm: torch.Tensor | float.

Additionally, min_loss is returned as a raw tensor in training entrypoints; consider adding a property accessor for scalar access:

♻️ Suggested changes
-        grad_norm: torch.Tensor,
+        grad_norm: torch.Tensor | float,

and optionally:

+    @property
+    def min_loss_value(self) -> float:
+        """Return min_loss as a Python float for external consumption."""
+        return float(self.min_loss.item())
bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py (6)

88-101: ⚠️ Potential issue | 🟡 Minor

Duplicated --standalone flag in torchrun command.

The --standalone flag appears twice (lines 91 and 94). This is likely unintentional; torchrun tolerates the duplicate flag, but it is confusing to readers.

🐛 Proposed fix to remove duplicate flag
     run_train_cmd(
         [
             "torchrun",
             "--standalone",
             "--nproc_per_node",
             "2",  # 2 processes = 2 GPUs
-            "--standalone",  # Single node mode
             "train_ddp.py",
             "--config-name",
             "L0_sanity",
             "num_train_steps=4",  # Just 4 steps for speed
         ],
         recipe_path,
     )

116-129: ⚠️ Potential issue | 🟡 Minor

Duplicated --standalone flag in torchrun command.

Same issue as above - duplicate --standalone flags.

🐛 Proposed fix
     run_train_cmd(
         [
             "torchrun",
             "--standalone",
             "--nproc_per_node",
             "2",  # 2 processes = 2 GPUs
-            "--standalone",  # Single node mode
             "train_fsdp2.py",

141-157: ⚠️ Potential issue | 🟡 Minor

Duplicated --standalone flag in torchrun command.

Same duplicate flag issue in the checkpointing test.

🐛 Proposed fix
     run_train_cmd(
         [
             "torchrun",
             "--standalone",
             "--nproc_per_node",
             "2",
-            "--standalone",
             "train_ddp.py",

174-190: ⚠️ Potential issue | 🟡 Minor

Duplicated --standalone flag in torchrun command.

Same duplicate flag issue.

🐛 Proposed fix
     run_train_cmd(
         [
             "torchrun",
             "--standalone",
             "--nproc_per_node",
             "2",
-            "--standalone",
             "train_fsdp2.py",

200-218: ⚠️ Potential issue | 🟡 Minor

Duplicated --standalone flag in torchrun command.

Same duplicate flag issue in the BSHD CP test.

🐛 Proposed fix
     run_train_cmd(
         [
             "torchrun",
             "--standalone",
             "--nproc_per_node=2",
-            "--standalone",
             "train_fsdp2_cp.py",

224-242: ⚠️ Potential issue | 🟡 Minor

Duplicated --standalone flag in torchrun command.

Same duplicate flag issue in the THD CP test.

🐛 Proposed fix
     run_train_cmd(
         [
             "torchrun",
             "--standalone",
             "--nproc_per_node=2",
-            "--standalone",
             "train_fsdp2_cp.py",
🤖 Fix all issues with AI agents
In `@bionemo-recipes/recipes/llama3_native_te/Dockerfile`:
- Around line 2-6: The Dockerfile hardcodes an internal base image
(gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-pytorch-py3-base),
which breaks external builds. Replace the fixed FROM with a build ARG pattern so
the default is the public NVIDIA image (e.g., nvidia/pytorch:26.01-py3) while CI
can override it with the internal image: add an ARG such as BASE_IMAGE defaulting
to the public image and change the FROM to use that ARG (refer to the existing
hardcoded image string and the commented public image name in the file), keeping
the explanatory note intact. This mirrors the codonfm_ptl_te recipe pattern, so
external users can build locally while CI passes the internal registry value.
🧹 Nitpick comments (4)
bionemo-recipes/models/llama3/collator.py (1)

504-512: Broad exception handling may silently swallow errors.

Catching all exceptions and converting them to StopIteration could hide important failures (e.g., CUDA errors, assertion failures). Consider logging the exception before signaling stop.

♻️ Proposed fix to log exceptions
     def _do_one_prefetch(self):
         """Fetch one batch in the background. Stores result in _prefetch_result."""
         if self._cuda_device is not None:
             torch.cuda.set_device(self._cuda_device)
         try:
             self._prefetch_result = self._send_data_to_cp_tp_ranks()
-        except Exception:
+        except Exception as e:
             # Process group may have been destroyed; signal stop.
+            logger.debug("Prefetch exception (may be expected at shutdown): %s", e)
             self._prefetch_result = StopIteration()
bionemo-recipes/recipes/esm2_native_te/collator.py (2)

504-512: Same broad exception handling concern as llama3/collator.py.

Consider logging the exception before converting to StopIteration.

♻️ Proposed fix to log exceptions
     def _do_one_prefetch(self):
         """Fetch one batch in the background. Stores result in _prefetch_result."""
         if self._cuda_device is not None:
             torch.cuda.set_device(self._cuda_device)
         try:
             self._prefetch_result = self._send_data_to_cp_tp_ranks()
-        except Exception:
+        except Exception as e:
             # Process group may have been destroyed; signal stop.
+            logger.debug("Prefetch exception (may be expected at shutdown): %s", e)
             self._prefetch_result = StopIteration()

1-948: Significant code duplication across collator files.

This file is nearly identical to bionemo-recipes/models/llama3/collator.py and bionemo-recipes/models/esm2/src/esm/collator.py. Consider consolidating into a shared module to reduce maintenance burden.

bionemo-recipes/recipes/llama3_native_te/collator.py (1)

488-497: Consider adding thread-safety documentation or synchronization.

The current implementation relies on the GIL and the fact that _prefetch_thread.join() completes before accessing _prefetch_result. While this is correct in practice, the shared mutable state (_prefetch_result) accessed from both the main thread and the prefetch thread could benefit from explicit documentation noting the synchronization invariant (join completes before result access).
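
For illustration only, a minimal sketch of how that invariant could be documented in the consumer method (attribute and helper names follow the prefetching collator discussed in this review; the body is simplified, not the actual implementation):

    def __next__(self):
        """Return the next prefetched batch.

        Thread-safety: ``_prefetch_result`` is written only by the prefetch
        thread and read only after ``_prefetch_thread.join()`` returns, so the
        join() happens-before edge is the only synchronization point needed.
        """
        self._prefetch_thread.join()  # result is guaranteed visible after this
        result = self._prefetch_result
        if isinstance(result, StopIteration):
            raise StopIteration
        self._kick_prefetch()  # start fetching the next batch in the background
        return result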

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (5)
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (1)

45-51: ⚠️ Potential issue | 🟠 Major

Fix min_loss return type annotation mismatch.

main() is annotated to return float | None (line 48 in train_fsdp2.py) but actually returns perf_logger.min_loss, which is a torch.Tensor (initialized as a CUDA tensor on line 50 of perf_logger.py and updated via torch.minimum() on line 103). This type mismatch affects all training scripts in this recipe (train_fsdp2.py, train_mfsdp.py, train_fsdp2_cp.py, train_ddp_cp.py, train_ddp.py).

Either convert min_loss to a scalar before returning (e.g., perf_logger.min_loss.item()) or update the return type annotation to torch.Tensor.

bionemo-recipes/models/llama3/collator.py (1)

282-305: ⚠️ Potential issue | 🟠 Major

Guard split_samples when padding inflates per-sequence length.

When padding makes tokens_available ≥ raw sample length (possible if max_tokens_per_batch isn’t divisible by the pad multiple), _split_sample_by_num_tokens can raise. Falling back to starting a new batch avoids hard failures in that configuration.

🛠️ Proposed fix
                 else:
                     # Calculate how many padded tokens are already in the batch
                     tokens_in_batch = current_length - self._padded_len(len(sample["input_ids"]))
                     # Calculate how many tokens we can fit from this sample
                     tokens_available = self.max_tokens_per_batch - tokens_in_batch
-                    first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available)
-                    yield [*samples, first_part]
-                    samples = [remaining_part]
+                    sample_len = len(sample["input_ids"])
+                    if tokens_available <= 0 or tokens_available >= sample_len:
+                        if samples:
+                            yield samples
+                        samples = [sample]
+                    else:
+                        first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available)
+                        yield [*samples, first_part]
+                        samples = [remaining_part]
 
                 current_length = self._padded_len(len(samples[0]["input_ids"]))
bionemo-recipes/recipes/llama3_native_te/train_ddp.py (1)

136-139: ⚠️ Potential issue | 🟠 Major

Remove the per-batch print to avoid training slowdowns.

Printing every batch will bottleneck I/O and spam logs, especially in multi-GPU runs. Prefer gated debug logging or remove it entirely.

🔧 Suggested fix
-            print(batch["input_ids"].shape)
+            if dist_config.local_rank == 0 and logger.isEnabledFor(logging.DEBUG):
+                logger.debug("batch input_ids shape: %s", batch["input_ids"].shape)
bionemo-recipes/models/esm2/tests/test_collator_context_parallel.py (1)

401-419: ⚠️ Potential issue | 🟡 Minor

Fail fast if rank-0 scatter payload never arrives.

data_ready.wait(timeout=5) returns a boolean that’s ignored; a timeout will surface as a cryptic KeyError later. Add an explicit check for clearer failures (apply to both occurrences).

🛠️ Suggested fix
-                data_ready.wait(timeout=5)
-                scatter_object_output_list[0] = scatter_payload["data"][1]
+                if not data_ready.wait(timeout=5):
+                    raise AssertionError("Timed out waiting for rank 0 scatter payload")
+                scatter_object_output_list[0] = scatter_payload["data"][1]
bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py (1)

88-99: ⚠️ Potential issue | 🟡 Minor

Remove duplicated --standalone flags in all torchrun invocations.

--standalone appears twice in six torchrun command lists across this file (lines 91–94, 119–122, 144–147, 177–180, 203–205, 227–229). Keep only one instance per command; argparse tolerates the duplicate, but it is confusing to readers.

Suggested fix for lines 88–99
    run_train_cmd(
        [
            "torchrun",
            "--standalone",
            "--nproc_per_node",
            "2",  # 2 processes = 2 GPUs
-            "--standalone",  # Single node mode
            "train_ddp.py",

Apply the same removal to the other five occurrences in this file.

🤖 Fix all issues with AI agents
In `@bionemo-recipes/models/esm2/tests/test_collator_context_parallel.py`:
- Around line 1071-1072: The test passes an invalid init parameter, cp_world_size,
to DataCollatorForContextParallel (which declares cp_world_size as
field(init=False)). In the line constructing cp_collator, keep
collator=base_collator and qkv_format="thd", and replace
cp_world_size=cp_world_size with device_mesh=_DummyCollatorMesh(cp_size=cp_world_size)
so the collator derives the CP world size from the mesh (sketch below).
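
A sketch of the suggested swap, using the identifiers named in the note above (treat the exact call shape as illustrative rather than the code shipped in this PR):

    # Before: rejected because cp_world_size is declared with field(init=False).
    # cp_collator = DataCollatorForContextParallel(
    #     collator=base_collator, qkv_format="thd", cp_world_size=cp_world_size
    # )

    # After: let the collator derive the CP world size from the dummy device mesh.
    cp_collator = DataCollatorForContextParallel(
        collator=base_collator,
        qkv_format="thd",
        device_mesh=_DummyCollatorMesh(cp_size=cp_world_size),
    )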

In `@bionemo-recipes/models/llama3/collator.py`:
- Around line 488-513: __next__ must guard against _prefetch_thread being None
before calling join, and must surface background exceptions instead of swallowing
them, by re-raising any Exception stored in self._prefetch_result. In
_do_one_prefetch, store the actual exception instance in self._prefetch_result
rather than converting everything to StopIteration; reserve StopIteration for
genuine end-of-iteration signaled by _send_data_to_cp_tp_ranks. Keep using
_kick_prefetch to start the thread and continue calling
torch.cuda.set_device(self._cuda_device) when _cuda_device is not None (sketch below).
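
A minimal sketch of that error-propagation pattern, using the attribute and method names referenced above (the surrounding class and threading setup are omitted, and the details are assumptions rather than the PR's actual code):

    def _do_one_prefetch(self):
        """Fetch one batch in the background; stash either the result or the error."""
        if self._cuda_device is not None:
            torch.cuda.set_device(self._cuda_device)
        try:
            self._prefetch_result = self._send_data_to_cp_tp_ranks()
        except StopIteration as e:
            self._prefetch_result = e  # genuine end of iteration
        except Exception as e:  # noqa: BLE001
            self._prefetch_result = e  # surface the real failure to the consumer

    def __next__(self):
        """Wait for the background fetch, re-raise failures, then prefetch again."""
        if self._prefetch_thread is not None:
            self._prefetch_thread.join()
        result = self._prefetch_result
        if isinstance(result, StopIteration):
            raise StopIteration
        if isinstance(result, Exception):
            raise result
        self._kick_prefetch()
        return result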

In `@bionemo-recipes/recipes/esm2_native_te/perf_logger.py`:
- Around line 98-138: metrics.compute() can return GPU tensors, which break
string formatting and wandb logging. After calling metrics = self.metrics.compute()
(and before self.metrics.reset(), wandb.log, and logger.info), convert any tensor
values to host Python scalars (e.g., v.detach().cpu().item() for scalar tensors),
replacing those entries in the metrics dict, so that wandb.log(metrics, step=step)
and the logger.info(", ".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()]))
formatting work without errors (sketch below).
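
As an illustration, a small helper of the kind described above might look like this; the helper name is hypothetical, and only metrics.compute(), wandb.log, and the logger.info formatting are taken from the comment:

    import torch

    def _to_python_scalars(metrics: dict) -> dict:
        """Move scalar tensor metrics to the host so formatting and wandb accept them."""
        return {
            k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else v
            for k, v in metrics.items()
        }

    # Usage inside the logging path:
    # metrics = _to_python_scalars(self.metrics.compute())
    # self.metrics.reset()
    # wandb.log(metrics, step=step)
    # logger.info(", ".join(f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()))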

In `@bionemo-recipes/recipes/llama3_native_te/collator.py`:
- Around line 296-304: The split logic can overfill a batch when
pad_sequences_to_be_divisible_by does not divide max_tokens_per_batch. In the
splitting branch, after computing tokens_available = self.max_tokens_per_batch -
tokens_in_batch, reduce tokens_available (e.g., decrement in a loop) until
self._padded_len(tokens_available) <= self.max_tokens_per_batch - tokens_in_batch
(or until it reaches zero), then call _split_sample_by_num_tokens with the
adjusted value (sketch below). Alternatively, detect the incompatible
configuration in __post_init__ (split_samples=True and max_tokens_per_batch %
pad_sequences_to_be_divisible_by != 0) and raise a clear error.
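
A sketch of the loop-based option inside the splitting branch (field and helper names follow the comment above; remaining_capacity is an illustrative local, and the surrounding generator is omitted):

    # Shrink the split until its padded length also fits the remaining capacity.
    remaining_capacity = self.max_tokens_per_batch - tokens_in_batch
    tokens_available = remaining_capacity
    while tokens_available > 0 and self._padded_len(tokens_available) > remaining_capacity:
        tokens_available -= 1
    if tokens_available > 0:
        first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available)
        yield [*samples, first_part]
        samples = [remaining_part]
    else:
        # Nothing fits once padding is accounted for; start a new batch instead.
        if samples:
            yield samples
        samples = [sample]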

In `@bionemo-recipes/recipes/llama3_native_te/perf_logger.py`:
- Around line 110-121: The sampling mismatch is that running_loss and
grad_acc_step_count are updated every micro-step, while num_tokens and
num_unpadded_tokens are only updated at logging intervals (controlled by
logging_frequency and step). Make the sampling consistent: either move the
num_tokens and num_unpadded_tokens increments into the same micro-step scope,
using batch["input_ids"] and batch.get("attention_mask") (or its fallback), or
only increment grad_acc_step_count when the token counters are updated at the
logging interval. Update the comments to document the chosen behavior (a sketch
of the first option follows).
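
A sketch of the first option, counting tokens on every micro-step so all four counters share the same sampling window (field names follow the comment above; the exact loss handling is an assumption):

    # Per-micro-step accounting: keep the token counters in lockstep with
    # running_loss and grad_acc_step_count instead of updating them only at
    # logging intervals.
    self.running_loss += loss.detach()
    self.grad_acc_step_count += 1
    self.num_tokens += batch["input_ids"].numel()
    attention_mask = batch.get("attention_mask")
    if attention_mask is not None:
        self.num_unpadded_tokens += int(attention_mask.sum().item())
    else:
        # Packed (THD) batches may carry no attention mask; fall back to all tokens.
        self.num_unpadded_tokens += batch["input_ids"].numel()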
🧹 Nitpick comments (1)
bionemo-recipes/recipes/llama3_native_te/collator.py (1)

504-512: Broad exception handling may hide real errors.

The except Exception block at lines 510-512 catches all exceptions and converts them to StopIteration. While the comment mentions process group destruction, this could mask other errors (e.g., CUDA OOM, data corruption). Consider either:

  1. Catching specific exceptions like RuntimeError from distributed ops
  2. Logging the exception before converting to StopIteration
♻️ Proposed fix to log exceptions
     def _do_one_prefetch(self):
         """Fetch one batch in the background. Stores result in _prefetch_result."""
         if self._cuda_device is not None:
             torch.cuda.set_device(self._cuda_device)
         try:
             self._prefetch_result = self._send_data_to_cp_tp_ranks()
-        except Exception:
+        except Exception as e:
             # Process group may have been destroyed; signal stop.
+            logger.debug("Prefetch exception (treating as end of iteration): %s", e)
             self._prefetch_result = StopIteration()

@pstjohn (Collaborator, Author) commented Feb 11, 2026

@coderabbitai resolve

@coderabbitai bot (Contributor) commented Feb 11, 2026

✅ Actions performed

Comments resolved.

@jomitchellnv (Collaborator) left a comment:

LGTM but you may wanna move some of the changes related to prefetch into the other MR

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn pstjohn force-pushed the pstjohn/bio-230-refactor-perf_logger-to-only-update-metrics-every branch from e424e40 to 008787a on February 12, 2026 at 15:48
@pstjohn pstjohn added this pull request to the merge queue on Feb 13, 2026
Merged via the queue into NVIDIA:main with commit 0dd26c5 on Feb 13, 2026; 18 checks passed
@pstjohn pstjohn deleted the pstjohn/bio-230-refactor-perf_logger-to-only-update-metrics-every branch on February 13, 2026 at 15:35