Implement automatic micro batch sizing #1060
jonahsamost wants to merge 9 commits into NovaSky-AI:main
Conversation
The idea right now is that we profile across several different sequence lengths to find the max batch size for each sequence length. On the JAX side, profiling is a pure static compiler analysis that determines memory usage, so as we iterate over sequence lengths we can use the previous (larger) sequence length's max batch size as the starting point for profiling the current (smaller) sequence length. On the torch side we can't do this, because actual forward and backward passes are run on the GPU, and if the warm-start batch size is too big we may OOM and deadlock in FSDP. For profiling on the torch side, we instead coordinate across the distributed workers.
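The warm-start idea on the JAX side can be sketched as follows. This is a hedged illustration, not the PR's code: `fits(batch_size, seq_len)` is a hypothetical predicate standing in for the static compiler memory analysis, and the doubling search is just one possible search strategy.

```python
def find_max_batch_sizes(seq_lens, fits):
    """Find the max batch size per sequence length, longest seq_len first.

    fits(batch_size, seq_len) -> bool is a stand-in for the static memory
    analysis. Assumes batch_size=1 fits even at the longest seq_len.
    """
    results = {}
    warm_start = 1
    for sl in sorted(seq_lens, reverse=True):
        # Memory use grows with seq_len, so whatever fit at the previous
        # (longer) seq_len is guaranteed to fit here; start from there.
        bs = warm_start
        while fits(bs * 2, sl):
            bs *= 2  # grow until the next doubling would no longer fit
        results[sl] = bs
        warm_start = bs
    return results
```

Because memory grows with sequence length, each shorter sequence length starts its search at the previous result rather than from 1, which is exactly the warm start described above.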
Code Review
This pull request introduces an automatic micro-batch sizing feature, which profiles GPU memory to determine an optimal token budget. This is a great addition for improving GPU utilization and preventing OOM errors. The implementation for the JAX backend looks solid. However, the implementation for the PyTorch trainer has a critical flaw where it only profiles the policy model and not the critic, which can lead to OOMs. There are also some minor code quality issues I've pointed out.
```python
def _auto_determine_token_budgets(
    self,
    cfg,
    policy_model: PPORayActorGroup,
    critic_model: Optional[PPORayActorGroup],
):
    """Profile GPU memory to determine the token budget for dynamic micro-batching.

    Each worker profiles the model at several `(batch_size, seq_len)`
    combinations and returns a token budget C such that any micro-batch
    with `batch_size x max_seq_len <= C` fits in GPU memory.

    The minimum budget across all workers is taken as the global budget.
    """
    max_seq_len = cfg.trainer.max_prompt_length + cfg.generator.sampling_params.max_generate_length

    logger.info(f"Auto micro-batch sizing: profiling policy workers (max_seq_len={max_seq_len})")
    policy_refs = policy_model.async_run_ray_method(
        "pass_through",
        "auto_determine_token_budget",
        max_seq_len,
    )
    policy_budgets = ray.get(policy_refs)
    global_budget = min(policy_budgets)

    logger.info(
        f"Auto micro-batch sizing: per-worker token budgets={policy_budgets}, using min={global_budget}"
    )

    assert global_budget > 0, (
        f"Auto micro-batch sizing failed: token budget is {global_budget}. "
        "The model may be too large for the available GPU memory at the configured "
        f"max_seq_len={max_seq_len}. Try reducing max_prompt_length or max_generate_length, "
        "or set micro_train_batch_size_per_gpu manually."
    )

    micro_bs = max(global_budget // max_seq_len, 1)
    cfg.trainer.micro_train_batch_size_per_gpu = micro_bs
    cfg.trainer.micro_forward_batch_size_per_gpu = micro_bs
    logger.info(
        f"Auto micro-batch sizing: token_budget={global_budget}, "
        f"fixed micro_batch_size={micro_bs} "
        f"(at max_seq_len={max_seq_len})"
    )
```
The automatic micro-batch sizing implementation has a critical flaw: it only profiles the policy model and completely ignores the critic model. This can lead to Out-Of-Memory (OOM) errors during critic training if the critic model is more memory-intensive than the policy model.
Additionally, there's an inconsistency in how the determined budget is used:
- Policy workers use the determined `_token_budget` with `MemoryAwareBatchIterator` for dynamic batching based on sequence length, which is efficient.
- Critic workers never have `_token_budget` set. They fall back to a fixed `micro_train_batch_size_per_gpu` calculated in this function from the policy's budget, which is inconsistent and misses the benefits of dynamic packing.

To fix this, the function should be updated to:

- Profile both policy and critic models (if the critic exists).
- Find the global minimum token budget across all workers of all models.
- Set this global budget on all policy and critic workers, so they can all leverage `MemoryAwareBatchIterator` for efficient, memory-safe training.
- Remove the logic for calculating a fixed `micro_bs` and setting it on the config, as it's inconsistent with the token budget approach.

This will require adding `auto_determine_token_budget` to `CriticWorkerBase` and a method (e.g., `_set_token_budget`) on all workers to apply the final global budget.
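The proposed restructuring could look roughly like this sketch, with plain objects standing in for the Ray actor groups and `set_token_budget` as a hypothetical per-worker setter (the suggestion names it `_set_token_budget`):

```python
def determine_global_token_budget(policy_workers, critic_workers, max_seq_len):
    """Profile every worker of every model, then broadcast the global min budget."""
    workers = list(policy_workers) + list(critic_workers or [])
    budgets = [w.auto_determine_token_budget(max_seq_len) for w in workers]
    global_budget = min(budgets)
    assert global_budget > 0, "profiling found no batch that fits in memory"
    for w in workers:
        # Hypothetical setter; every worker then does dynamic token-budget
        # packing instead of a fixed micro batch size.
        w.set_token_budget(global_budget)
    return global_budget
```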
```python
)

if max_bs <= 0:
    logger.warning(f"Token-budget profiling: even batch_size=1 at seq_len={seq_len} " "exceeds memory budget.")
```
The string concatenation for the log message can be simplified by including the second part inside the f-string for better readability.
Suggested change:

```python
logger.warning(f"Token-budget profiling: even batch_size=1 at seq_len={seq_len} exceeds memory budget.")
```
```python
try:
    with torch.autocast(dtype=torch.bfloat16, device_type="cuda"):
        action_log_probs, output = model(
            sequences,
            num_actions,
            attention_mask=attention_mask,
            temperature=temperature,
            return_output=True,
            compute_entropy=compute_entropy,
            entropy_requires_grad=entropy_requires_grad,
        )
    loss = action_log_probs.sum()

    strategy.backward(loss, model, None)

    torch.cuda.synchronize(device)
    free_after, _total = torch.cuda.mem_get_info(device)
finally:
    model.zero_grad(set_to_none=True)

consumed = max(free_before - free_after, 0)
del sequences, attention_mask, action_log_probs, output, loss
return consumed
```
🔴 Uncaught CUDA OOM in `_profile_candidate` crashes auto-sizing instead of gracefully handling failure
During auto micro-batch profiling, `_profile_candidate` runs actual forward+backward passes to measure GPU memory consumption. However, it has no `except` clause for `torch.cuda.OutOfMemoryError`. If the model forward or backward pass causes a CUDA OOM (which is expected behavior when probing the memory limit), the exception propagates up through `_profile_and_sync` → `_find_max_batch_size` → `determine_token_budget` → the Ray worker, crashing the auto-sizing and killing the Ray actor.
Root Cause and Impact
The `_would_oom` heuristic at `skyrl-train/skyrl_train/utils/auto_microbatch.py:106` is designed to skip batch sizes likely to OOM, but it's an estimate based on linear extrapolation and can be wrong. When the heuristic under-predicts memory usage, the actual profiling run (`_profile_candidate`) hits a real CUDA OOM.
The function has a try/finally block but no except:
```python
try:
    ...
    action_log_probs, output = model(...)
    loss = action_log_probs.sum()
    strategy.backward(loss, model, None)
    torch.cuda.synchronize(device)
    free_after, _total = torch.cuda.mem_get_info(device)
finally:
    model.zero_grad(set_to_none=True)
consumed = max(free_before - free_after, 0)  # NameError: free_after undefined
```

When OOM occurs:

- `free_after` is never assigned, so `consumed = max(free_before - free_after, 0)` at line 291 would also raise `NameError` if it were reached
- The `del` cleanup at line 292 is skipped, leaking tensors
- The exception propagates to Ray, killing the worker
- In distributed setups, other workers waiting on `all_reduce` calls will deadlock
Impact: The entire auto micro-batch sizing feature crashes on OOM instead of marking the batch size as too large and continuing the search.
Prompt for agents
In skyrl-train/skyrl_train/utils/auto_microbatch.py, the _profile_candidate function (lines 250-293) needs to catch torch.cuda.OutOfMemoryError (and RuntimeError with 'out of memory' in the message for older PyTorch). When OOM is caught, the function should: (1) call model.zero_grad(set_to_none=True), (2) delete any allocated tensors (sequences, attention_mask, and any partially-computed results), (3) call _cleanup_memory(), (4) return a very large sentinel value (e.g., float('inf') cast to int, or sys.maxsize) so the caller treats this batch size as exceeding the budget. This makes the search algorithm treat the OOM as 'consumed > memory_budget' and move to a smaller batch size. Similarly, _profile_and_sync should handle the case where _profile_candidate signals an OOM (e.g., by checking the sentinel value before the all_reduce, or by catching the exception at that level).
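A minimal sketch of the fix this prompt describes, with the torch-specific pieces abstracted behind stand-in callables: `run_pass` stands in for the forward+backward pass and `cleanup` for `zero_grad` plus cache clearing. In real code the `except` clause would target `torch.cuda.OutOfMemoryError` (a `RuntimeError` subclass in recent PyTorch), which is why catching `RuntimeError` and checking the message covers older versions too.

```python
import sys

def profile_candidate(run_pass, cleanup,
                      is_oom=lambda e: "out of memory" in str(e).lower()):
    """Return bytes consumed by one profiling pass, or a huge sentinel on OOM."""
    try:
        return run_pass()
    except RuntimeError as e:
        if not is_oom(e):
            raise  # unrelated failure: let it propagate
        # Treat OOM as "consumed more than any budget" so the caller's
        # search backs off to a smaller batch size instead of crashing.
        return sys.maxsize
    finally:
        cleanup()  # always release grads / cached memory, OOM or not
```

The sentinel makes the OOM indistinguishable from "measured consumption exceeds the budget", so the surrounding search logic (and the pre-`all_reduce` check in `_profile_and_sync`) needs no special error path.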
This is by design: FSDP deadlocks if any worker OOMs, regardless. We prevent OOMs with the `_would_oom` checks and synchronize the memory budget across workers.
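As the earlier review comment notes, `_would_oom` is a linear-extrapolation estimate. A hedged sketch of such a check (the signature and safety margin are assumptions, not the PR's actual code):

```python
def would_oom(candidate_bs, measured_bs, measured_bytes, budget_bytes,
              margin=1.1):
    """Skip candidates whose projected memory use likely exceeds the budget.

    Linearly extrapolates from one already-measured (batch_size, bytes)
    point; margin is an assumed safety factor against under-prediction.
    """
    if measured_bs == 0:
        return False  # nothing measured yet; must actually profile
    projected = measured_bytes * (candidate_bs / measured_bs)
    return projected * margin > budget_bytes
```

Because every rank applies the same check against the same synchronized budget, all ranks skip the same candidates, which is what keeps the FSDP collectives aligned.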
@pcmoritz Let me know what you think!
Referencing #1048
This PR adds an optional profiling step to automatically determine the largest micro batch size that fits in GPU memory.
When `auto_micro_batch_size: true` is set in the config, a profiling step runs once after model initialization. It iterates over all divisors of `mini_batch_size_per_gpu` in ascending order, runs a dummy forward + backward pass at maximum sequence length, and measures memory usage to determine the best candidate micro batch size.
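The selection loop described above can be sketched as follows; `try_profile` is a hypothetical stand-in for the dummy forward+backward pass that returns measured memory use in bytes:

```python
def pick_micro_batch_size(mini_batch_size_per_gpu, memory_budget, try_profile):
    """Largest divisor of the mini batch size whose profiled pass fits the budget."""
    divisors = [d for d in range(1, mini_batch_size_per_gpu + 1)
                if mini_batch_size_per_gpu % d == 0]
    best = None
    for bs in divisors:  # ascending: stop at the first size that no longer fits
        if try_profile(bs) > memory_budget:
            break
        best = bs
    return best  # None means even batch_size=1 exceeds the budget
```

Restricting candidates to divisors keeps every mini batch evenly splittable into micro batches, and the ascending order means the search never probes a size larger than the first failure.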