
Implement automatic micro batch sizing #1060

Open
jonahsamost wants to merge 9 commits into NovaSky-AI:main from jonahsamost:jonah_2_9_batchsize

Conversation


@jonahsamost jonahsamost commented Feb 10, 2026

Referencing #1048

This PR adds an optional profiling step that automatically determines the largest micro batch size that fits in GPU memory.
When auto_micro_batch_size: true is set in the config, a profiling step runs once after model initialization. It iterates over the divisors of mini_batch_size_per_gpu in ascending order, runs a dummy forward + backward pass at the maximum sequence length for each candidate, and measures memory usage to determine the best micro batch size.
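The divisor scan described above can be sketched as follows. This is a minimal illustration, not the PR's code: `fits_in_memory` is a hypothetical stand-in for the dummy forward + backward probe.

```python
def ascending_divisors(n: int) -> list[int]:
    """All divisors of n in ascending order."""
    return [d for d in range(1, n + 1) if n % d == 0]


def find_max_micro_batch_size(mini_batch_size_per_gpu: int, fits_in_memory) -> int:
    """Try each divisor of the mini-batch size in ascending order and keep
    the largest candidate whose dummy forward+backward pass fits in memory."""
    best = 0
    for candidate in ascending_divisors(mini_batch_size_per_gpu):
        if not fits_in_memory(candidate):  # probe at max sequence length
            break  # memory use grows with batch size, so stop at first failure
        best = candidate
    return best
```

Because memory use is monotone in batch size, the scan can stop at the first candidate that fails rather than trying every divisor.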




jonahsamost commented Feb 11, 2026

@pcmoritz

The idea right now is that we profile across several different sequence lengths to find the max batch size for each one (i.e. seq_lens = [(bs1, seq1), (bs2, seq2), ...]). Then we take each of their products
(budgets = [bs * seq for bs, seq in seq_lens]). These budget values stay fairly constant across pairs of sequence length and batch size. For instance, here is one from the tests: Token-budget ... (boundary products=[tensor(53760), tensor(52992), tensor(53760), tensor(53760)]). We take the minimum of these as the token budget C, and at runtime we ensure every micro batch satisfies batch_size x max_seq_len_in_batch ≤ C.
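The budget computation described above can be sketched as follows; the `(batch_size, seq_len)` pairs below are made-up illustrative values, not real profiling output.

```python
def token_budget(boundary_pairs):
    """Given the max batch size found at each profiled sequence length,
    take the minimum product as the global token budget C."""
    return min(bs * seq for bs, seq in boundary_pairs)


def micro_batch_fits(batch_size: int, max_seq_len_in_batch: int, budget: int) -> bool:
    """Runtime check: every micro batch must satisfy bs * max_seq_len <= C."""
    return batch_size * max_seq_len_in_batch <= budget


pairs = [(105, 512), (52, 1024), (26, 2048), (13, 4096)]  # hypothetical profile
C = token_budget(pairs)  # min of the products
```

Taking the minimum over all boundary products keeps the budget conservative: a batch shape that was safe at one sequence length is not assumed safe at all of them.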

On the JAX side, profiling is a purely static compiler analysis that determines memory usage, so when we iterate over sequence lengths we can use the previous (larger) sequence length's max batch size as the starting point for profiling the current (smaller) sequence length.

On the torch side, we can't do this because actual forward and backward passes are run on the GPU, and if the warm-start batch size is too big we may OOM and deadlock in FSDP.

For profiling on the torch side, we coordinate across distributed workers via an all_reduce. We also try to keep the budget somewhat conservative because OOMing in a distributed setting causes deadlocks.
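The cross-worker coordination can be illustrated with a self-contained stand-in: in the real code this would be a `torch.distributed.all_reduce` with a MIN reduction, simulated here as a plain function over per-worker values.

```python
def all_reduce_min(per_worker_values):
    """Every worker contributes its locally-profiled budget; all workers
    adopt the global minimum so no rank ever packs a micro batch that
    another rank cannot fit (which would deadlock FSDP collectives)."""
    global_min = min(per_worker_values)
    return [global_min] * len(per_worker_values)


local_budgets = [53760, 52992, 53760]  # one made-up entry per worker
synced = all_reduce_min(local_budgets)  # every rank ends up with the same value
```

After the reduction, every rank holds the same budget, which is what makes the subsequent FSDP collectives safe.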

The MemoryAwareBatchIterator in worker_utils.py is what actually does the greedy packing based on this "token_budget" value.
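A simplified sketch of the greedy packing idea follows; this is not the PR's `MemoryAwareBatchIterator` implementation, just the core invariant it is described as enforcing: grow the current micro batch while `len(batch) * max_seq_len(batch)` stays within the token budget.

```python
def greedy_pack(seq_lens, token_budget):
    """Partition samples (given by their sequence lengths) into micro
    batches such that len(batch) * max(batch) <= token_budget.
    Assumes a single sample always fits on its own."""
    batches, current = [], []
    for length in seq_lens:
        candidate = current + [length]
        if len(candidate) * max(candidate) <= token_budget:
            current = candidate
        else:
            if current:
                batches.append(current)
            current = [length]
    if current:
        batches.append(current)
    return batches
```

Note that one long sample can force a small batch: adding a 500-token sample to three 100-token samples charges all four at 500 tokens each.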

@jonahsamost jonahsamost marked this pull request as ready for review February 15, 2026 15:51

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an automatic micro-batch sizing feature, which profiles GPU memory to determine an optimal token budget. This is a great addition for improving GPU utilization and preventing OOM errors. The implementation for the JAX backend looks solid. However, the implementation for the PyTorch trainer has a critical flaw where it only profiles the policy model and not the critic, which can lead to OOMs. There are also some minor code quality issues I've pointed out.

Comment on lines +581 to +624
def _auto_determine_token_budgets(
    self,
    cfg,
    policy_model: PPORayActorGroup,
    critic_model: Optional[PPORayActorGroup],
):
    """Profile GPU memory to determine the token budget for dynamic micro-batching.

    Each worker profiles the model at several `(batch_size, seq_len)`
    combinations and returns a token budget C such that any micro-batch
    with `batch_size x max_seq_len ≤ C` fits in GPU memory.

    The minimum budget across all workers is taken as the global budget.
    """
    max_seq_len = cfg.trainer.max_prompt_length + cfg.generator.sampling_params.max_generate_length

    logger.info(f"Auto micro-batch sizing: profiling policy workers (max_seq_len={max_seq_len})")
    policy_refs = policy_model.async_run_ray_method(
        "pass_through",
        "auto_determine_token_budget",
        max_seq_len,
    )
    policy_budgets = ray.get(policy_refs)
    global_budget = min(policy_budgets)

    logger.info(
        f"Auto micro-batch sizing: per-worker token budgets={policy_budgets}, " f"using min={global_budget}"
    )

    assert global_budget > 0, (
        f"Auto micro-batch sizing failed: token budget is {global_budget}. "
        "The model may be too large for the available GPU memory at the configured "
        f"max_seq_len={max_seq_len}. Try reducing max_prompt_length or max_generate_length, "
        "or set micro_train_batch_size_per_gpu manually."
    )

    micro_bs = max(global_budget // max_seq_len, 1)
    cfg.trainer.micro_train_batch_size_per_gpu = micro_bs
    cfg.trainer.micro_forward_batch_size_per_gpu = micro_bs
    logger.info(
        f"Auto micro-batch sizing: token_budget={global_budget}, "
        f"fixed micro_batch_size={micro_bs} "
        f"(at max_seq_len={max_seq_len})"
    )

critical

The automatic micro-batch sizing implementation has a critical flaw: it only profiles the policy model and completely ignores the critic model. This can lead to Out-Of-Memory (OOM) errors during critic training if the critic model is more memory-intensive than the policy model.

Additionally, there's an inconsistency in how the determined budget is used:

  1. Policy workers: Use the determined _token_budget with MemoryAwareBatchIterator for dynamic batching based on sequence length, which is efficient.
  2. Critic workers: The _token_budget is never set for them. They fall back to using a fixed micro_train_batch_size_per_gpu calculated in this function from the policy's budget. This is inconsistent and misses the benefits of dynamic packing.

To fix this, the function should be updated to:

  1. Profile both policy and critic models (if the critic exists).
  2. Find the global minimum token budget across all workers of all models.
  3. Set this global budget on all policy and critic workers, so they can all leverage MemoryAwareBatchIterator for efficient, memory-safe training.
  4. The logic for calculating a fixed micro_bs and setting it on the config should be removed, as it's inconsistent with the token budget approach.

This will require adding auto_determine_token_budget to CriticWorkerBase and a method (e.g., _set_token_budget) to all workers to apply the final global budget.
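The global-minimum step the reviewer suggests could be sketched as follows; `sync_global_token_budget` is a hypothetical helper name, not the PR's code.

```python
def sync_global_token_budget(policy_budgets, critic_budgets=None):
    """Return the minimum token budget across all policy and (optional)
    critic workers, so both model groups pack micro batches against the
    same memory-safe limit."""
    budgets = list(policy_budgets)
    if critic_budgets is not None:
        budgets += list(critic_budgets)
    return min(budgets)
```

A memory-hungrier critic then lowers the budget for everyone, rather than being silently ignored.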

    )

    if max_bs <= 0:
        logger.warning(f"Token-budget profiling: even batch_size=1 at seq_len={seq_len} " "exceeds memory budget.")

medium

The string concatenation for the log message can be simplified by including the second part inside the f-string for better readability.

Suggested change
logger.warning(f"Token-budget profiling: even batch_size=1 at seq_len={seq_len} " "exceeds memory budget.")
logger.warning(f"Token-budget profiling: even batch_size=1 at seq_len={seq_len} exceeds memory budget.")

devin-ai-integration[bot]

This comment was marked as resolved.



@devin-ai-integration devin-ai-integration bot left a comment


Devin Review found 2 new potential issues.

View 11 additional findings in Devin Review.


Comment on lines +271 to +293
try:
    with torch.autocast(dtype=torch.bfloat16, device_type="cuda"):
        action_log_probs, output = model(
            sequences,
            num_actions,
            attention_mask=attention_mask,
            temperature=temperature,
            return_output=True,
            compute_entropy=compute_entropy,
            entropy_requires_grad=entropy_requires_grad,
        )
    loss = action_log_probs.sum()

    strategy.backward(loss, model, None)

    torch.cuda.synchronize(device)
    free_after, _total = torch.cuda.mem_get_info(device)
finally:
    model.zero_grad(set_to_none=True)

consumed = max(free_before - free_after, 0)
del sequences, attention_mask, action_log_probs, output, loss
return consumed

🔴 Uncaught CUDA OOM in _profile_candidate crashes auto-sizing instead of gracefully handling failure

During auto micro-batch profiling, _profile_candidate runs actual forward+backward passes to measure GPU memory consumption. However, it has no except clause for torch.cuda.OutOfMemoryError. If the model forward or backward pass causes a CUDA OOM (which is expected behavior when probing the memory limit), the exception propagates up through _profile_and_sync → _find_max_batch_size → determine_token_budget → the Ray worker, crashing the auto-sizing and killing the Ray actor.

Root Cause and Impact

The _would_oom heuristic at skyrl-train/skyrl_train/utils/auto_microbatch.py:106 is designed to skip batch sizes likely to OOM, but it's an estimate based on linear extrapolation and can be wrong. When the heuristic under-predicts memory usage, the actual profiling run (_profile_candidate) hits a real CUDA OOM.

The function has a try/finally block but no except:

try:
    ...
    action_log_probs, output = model(...)
    loss = action_log_probs.sum()
    strategy.backward(loss, model, None)
    torch.cuda.synchronize(device)
    free_after, _total = torch.cuda.mem_get_info(device)
finally:
    model.zero_grad(set_to_none=True)

consumed = max(free_before - free_after, 0)  # NameError: free_after undefined

When OOM occurs:

  1. free_after is never assigned, so consumed = max(free_before - free_after, 0) at line 291 would also raise NameError if it were reached
  2. The del cleanup at line 292 is skipped, leaking tensors
  3. The exception propagates to Ray, killing the worker
  4. In distributed setups, other workers waiting on all_reduce calls will deadlock

Impact: The entire auto micro-batch sizing feature crashes on OOM instead of marking the batch size as too large and continuing the search.

Prompt for agents
In skyrl-train/skyrl_train/utils/auto_microbatch.py, the _profile_candidate function (lines 250-293) needs to catch torch.cuda.OutOfMemoryError (and RuntimeError with 'out of memory' in the message for older PyTorch). When OOM is caught, the function should: (1) call model.zero_grad(set_to_none=True), (2) delete any allocated tensors (sequences, attention_mask, and any partially-computed results), (3) call _cleanup_memory(), (4) return a very large sentinel value (e.g., float('inf') cast to int, or sys.maxsize) so the caller treats this batch size as exceeding the budget. This makes the search algorithm treat the OOM as 'consumed > memory_budget' and move to a smaller batch size. Similarly, _profile_and_sync should handle the case where _profile_candidate signals an OOM (e.g., by checking the sentinel value before the all_reduce, or by catching the exception at that level).
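The sentinel-based handling the prompt describes could look roughly like this. It is a hypothetical wrapper: `run_profile` and `cleanup` stand in for the real probe and cleanup logic, and plain `RuntimeError` matching is used since `torch.cuda.OutOfMemoryError` subclasses it in recent PyTorch.

```python
import sys


def profile_candidate_safe(run_profile, cleanup):
    """Return measured memory consumption, or sys.maxsize as a sentinel
    meaning 'treat this batch size as over budget' if the probe OOMs."""
    try:
        return run_profile()
    except (MemoryError, RuntimeError) as e:
        # Only swallow CUDA OOMs; unrelated runtime errors still propagate.
        if isinstance(e, RuntimeError) and "out of memory" not in str(e).lower():
            raise
        cleanup()  # zero grads, free tensors, empty the CUDA cache
        return sys.maxsize
```

The caller then compares the result against the memory budget as usual; the sentinel is always over budget, so the search simply moves to a smaller batch size.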


@jonahsamost (Author)

This is by design because FSDP deadlocks if any worker OOMs regardless. We prevent OOMs with the _would_oom checks and synchronize the memory budget across workers
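A `_would_oom`-style check, as described in this thread, might look like the following linear-extrapolation sketch. This is a hypothetical illustration of the idea, not the PR's implementation; the parameter names and the safety fraction are assumptions.

```python
def would_oom(candidate_bs, measured_bs, measured_bytes, free_bytes, safety=0.8):
    """Estimate memory for candidate_bs by linearly extrapolating from a
    smaller measured run; flag the candidate if the estimate exceeds a
    conservative fraction of currently free memory."""
    per_sample = measured_bytes / measured_bs
    estimated = per_sample * candidate_bs
    return estimated > safety * free_bytes
```

Because the extrapolation is linear, it can under-predict (e.g. attention memory growing superlinearly), which is why the thread keeps the budget conservative and synchronizes it across workers.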

@jonahsamost (Author)

@pcmoritz Let me know what you think!


1 participant