Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions skyrl-train/skyrl_train/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ class TrainerConfig(BaseConfig):
critic_mini_batch_size: int = 256
micro_train_batch_size_per_gpu: int = 1
micro_forward_batch_size_per_gpu: int = 1
auto_micro_batch_size: bool = False
update_ref_every_epoch: bool = False
use_sample_packing: bool = True
eval_batch_size: int = 1024
Expand Down
1 change: 1 addition & 0 deletions skyrl-train/skyrl_train/config/ppo_base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ trainer:
critic_mini_batch_size: 256
micro_train_batch_size_per_gpu: 1
micro_forward_batch_size_per_gpu: 1
auto_micro_batch_size: false # Set true to auto-determine micro batch sizes via profiling
update_ref_every_epoch: false
use_sample_packing: true
eval_batch_size: 1024
Expand Down
61 changes: 61 additions & 0 deletions skyrl-train/skyrl_train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,10 @@ def build_models(self, PolicyWorker, CriticWorker, RefWorker):
)
)
ray.get(refs)

if getattr(cfg.trainer, "auto_micro_batch_size", False):
self._auto_determine_token_budgets(cfg, policy_model, critic_model)

ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
else:
if ref_model is not None:
Expand All @@ -541,6 +545,10 @@ def build_models(self, PolicyWorker, CriticWorker, RefWorker):
)
)
ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))

if getattr(cfg.trainer, "auto_micro_batch_size", False):
self._auto_determine_token_budgets(cfg, policy_model, critic_model=None)

policy_model.offload_to_cpu()
if cfg.trainer.critic.model.path:
ray.get(
Expand Down Expand Up @@ -570,6 +578,59 @@ def build_models(self, PolicyWorker, CriticWorker, RefWorker):

logger.info("init policy/ref/critic models done")

def _auto_determine_token_budgets(
    self,
    cfg,
    policy_model: PPORayActorGroup,
    critic_model: Optional[PPORayActorGroup],
):
    """Profile GPU memory to derive fixed micro-batch sizes for training/forward.

    Each policy (and, if configured, critic) worker profiles its model at
    several `(batch_size, seq_len)` combinations and returns a token budget C
    such that any micro-batch with `batch_size x max_seq_len <= C` fits in GPU
    memory. The minimum budget across all workers of all profiled models is
    taken as the global budget, so the derived micro-batch size is safe for
    every worker.

    Side effects:
        Overwrites ``cfg.trainer.micro_train_batch_size_per_gpu`` and
        ``cfg.trainer.micro_forward_batch_size_per_gpu`` with the derived size.

    Args:
        cfg: Global config. Reads ``trainer.max_prompt_length`` and
            ``generator.sampling_params.max_generate_length``; writes the two
            micro-batch-size fields under ``cfg.trainer``.
        policy_model: Ray actor group for policy workers (always profiled).
        critic_model: Ray actor group for critic workers, or ``None`` when no
            critic is configured.

    Raises:
        RuntimeError: if the profiled token budget is not positive, i.e. even a
            single sequence at ``max_seq_len`` does not fit in GPU memory.
    """
    # Worst-case length of a single sequence a micro-batch row can reach.
    max_seq_len = cfg.trainer.max_prompt_length + cfg.generator.sampling_params.max_generate_length

    def _profile_workers(model: PPORayActorGroup, role: str) -> list:
        # "pass_through" dispatches the profiling call to every worker in the
        # group; one budget comes back per worker.
        logger.info(f"Auto micro-batch sizing: profiling {role} workers (max_seq_len={max_seq_len})")
        budgets = ray.get(
            model.async_run_ray_method("pass_through", "auto_determine_token_budget", max_seq_len)
        )
        logger.info(f"Auto micro-batch sizing: {role} per-worker budgets={budgets}")
        return budgets

    all_budgets = list(_profile_workers(policy_model, "policy"))
    if critic_model is not None:
        # The critic may be more memory-hungry than the policy, so its budgets
        # must participate in the global minimum.
        all_budgets.extend(_profile_workers(critic_model, "critic"))

    global_budget = min(all_budgets)

    logger.info(f"Auto micro-batch sizing: all budgets={all_budgets}, using global min={global_budget}")

    # Explicit raise (not assert): an assert is stripped under `python -O`,
    # which would silently drop this OOM guard.
    if global_budget <= 0:
        raise RuntimeError(
            f"Auto micro-batch sizing failed: token budget is {global_budget}. "
            "The model may be too large for the available GPU memory at the configured "
            f"max_seq_len={max_seq_len}. Try reducing max_prompt_length or max_generate_length, "
            "or set micro_train_batch_size_per_gpu manually."
        )

    # Largest whole number of max-length sequences that fits in the budget,
    # clamped to at least 1 so training can still proceed.
    micro_bs = max(global_budget // max_seq_len, 1)
    cfg.trainer.micro_train_batch_size_per_gpu = micro_bs
    cfg.trainer.micro_forward_batch_size_per_gpu = micro_bs
    logger.info(
        f"Auto micro-batch sizing: token_budget={global_budget}, "
        f"fixed micro_batch_size={micro_bs} "
        f"(at max_seq_len={max_seq_len})"
    )
Comment on lines +581 to +632
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The automatic micro-batch sizing implementation has a critical flaw: it only profiles the policy model and completely ignores the critic model. This can lead to Out-Of-Memory (OOM) errors during critic training if the critic model is more memory-intensive than the policy model.

Additionally, there's an inconsistency in how the determined budget is used:

  1. Policy workers: Use the determined _token_budget with MemoryAwareBatchIterator for dynamic batching based on sequence length, which is efficient.
  2. Critic workers: The _token_budget is never set for them. They fall back to using a fixed micro_train_batch_size_per_gpu calculated in this function from the policy's budget. This is inconsistent and misses the benefits of dynamic packing.

To fix this, the function should be updated to:

  1. Profile both policy and critic models (if the critic exists).
  2. Find the global minimum token budget across all workers of all models.
  3. Set this global budget on all policy and critic workers, so they can all leverage MemoryAwareBatchIterator for efficient, memory-safe training.
  4. The logic for calculating a fixed micro_bs and setting it on the config should be removed, as it's inconsistent with the token budget approach.

This will require adding auto_determine_token_budget to CriticWorkerBase and a method (e.g., _set_token_budget) to all workers to apply the final global budget.

Comment thread
devin-ai-integration[bot] marked this conversation as resolved.

def init_weight_sync_state(self):
"""
Setup the connection between policy model and inference engine for weight syncing.
Expand Down
Loading
Loading