63 changes: 63 additions & 0 deletions examples/train/algorithms/maxrl/run_maxrl_gsm8k.sh
@@ -0,0 +1,63 @@
set -x

# Colocated MAXRL training+generation for Qwen2.5-1.5B-Instruct on GSM8K.

# uv run examples/train/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
# export WANDB_API_KEY=<your_key_here>
# bash examples/train/algorithms/maxrl/run_maxrl_gsm8k.sh


# You can override the default values with e.g.: `NUM_GPUS=1 bash examples/train/algorithms/maxrl/run_maxrl_gsm8k.sh`.

: "${DATA_DIR:="$HOME/data/gsm8k"}"
: "${NUM_GPUS:=4}"
: "${LOGGER:=wandb}" # change to "console" to print to stdout
: "${INFERENCE_BACKEND:=vllm}"

# MAXRL parameters
: "${ADV_ESTIMATOR:=maxrl}"

# Other algorithm parameters
: "${USE_KL_LOSS:=true}"

uv run --isolated --extra fsdp -m skyrl.train.entrypoints.main_base \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.algorithm.advantage_estimator="$ADV_ESTIMATOR" \
trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
trainer.placement.critic_num_gpus_per_node=$NUM_GPUS \
trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
generator.inference_engine.num_engines=$NUM_GPUS \
generator.inference_engine.tensor_parallel_size=1 \
trainer.epochs=20 \
trainer.eval_batch_size=1024 \
trainer.eval_before_train=true \
trainer.eval_interval=5 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=1024 \
trainer.policy_mini_batch_size=256 \
trainer.micro_forward_batch_size_per_gpu=64 \
trainer.micro_train_batch_size_per_gpu=64 \
trainer.ckpt_interval=10 \
trainer.max_prompt_length=512 \
generator.sampling_params.max_generate_length=1024 \
trainer.policy.optimizer_config.lr=1.0e-6 \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
generator.inference_engine.backend=$INFERENCE_BACKEND \
generator.inference_engine.run_engines_locally=true \
generator.inference_engine.weight_sync_backend=nccl \
generator.inference_engine.async_engine=true \
generator.batched=true \
environment.env_class=gsm8k \
generator.n_samples_per_prompt=5 \
generator.inference_engine.gpu_memory_utilization=0.8 \
trainer.logger="$LOGGER" \
trainer.project_name="gsm8k" \
trainer.run_name="maxrl_gsm8k" \
trainer.resume_mode=null \
trainer.log_path="/tmp/skyrl-logs" \
trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt" \
"$@"
37 changes: 37 additions & 0 deletions skyrl/backends/skyrl_train/utils/ppo_utils.py
@@ -427,6 +427,7 @@ class AdvantageEstimator(StrEnum):
GRPO = "grpo"
RLOO = "rloo"
REINFORCE_PP = "reinforce++"
MAXRL = "maxrl"


class AdvantageEstimatorRegistry(BaseFunctionRegistry):
@@ -453,6 +454,7 @@ def repopulate_registry(cls):
"gae": [AdvantageEstimator.GAE, compute_gae_advantage_return],
"rloo": [AdvantageEstimator.RLOO, compute_rloo_outcome_advantage],
"reinforce++": [AdvantageEstimator.REINFORCE_PP, compute_reinforce_plus_plus_outcome_advantage],
"maxrl": [AdvantageEstimator.MAXRL, compute_maxrl_advantage],
}

for ae_name, (ae_type, ae_func) in ae_types.items():
@@ -1238,6 +1240,41 @@ def compute_grpo_outcome_advantage(
return scores, scores


@register_advantage_estimator(AdvantageEstimator.MAXRL)
def compute_maxrl_advantage(
token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: np.ndarray,
epsilon: float = 1e-6,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute advantage for MAXRL using mean-normalized group-relative rewards."""
scores = token_level_rewards.sum(dim=-1)

id2score = defaultdict(list)
id2mean = {}

with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
if len(id2score[index[i]]) > 1:
scores[i] = (scores[i] - id2mean[index[i]]) / (id2mean[index[i]] + epsilon)

devin-ai-integration (bot) commented on Feb 15, 2026:
🔴 MAXRL advantage sign is flipped when group mean reward is negative

In compute_maxrl_advantage, the normalization divides by (id2mean[index[i]] + epsilon) at line 1220. When the group's mean reward is negative, this denominator is negative, which flips the sign of the advantage. For example, with group scores [-10, -5] (mean = -7.5): the worse response (-10) gets advantage (-10 - (-7.5)) / (-7.5 + 1e-6) ≈ +0.333 (positive!) and the better response (-5) gets advantage (-5 - (-7.5)) / (-7.5 + 1e-6) ≈ -0.333 (negative!). This causes the RL algorithm to reinforce bad responses and penalize good ones whenever the group mean is negative. The test only uses positive rewards so it doesn't catch this. The fix should use abs(id2mean) in the denominator.

Member replied:
This formulation is as per the original maxrl paper

PR author replied:
Should I make the denominator absolute then? Don't think people use negative rewards anyways nowadays
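
A minimal sketch (not part of the PR, plain PyTorch) of the behavior discussed in this thread: with a negative group mean the denominator is negative and the within-group ordering of advantages is reversed, while an absolute-value denominator preserves it.

import torch

scores = torch.tensor([-10.0, -5.0])  # two responses from the same prompt
mean = scores.mean()                  # -7.5
eps = 1e-6
# Denominator as written in compute_maxrl_advantage: a negative mean flips the sign.
print((scores - mean) / (mean + eps))        # tensor([ 0.3333, -0.3333])
# Absolute-value denominator (the fix suggested above) keeps the ordering.
print((scores - mean) / (mean.abs() + eps))  # tensor([-0.3333,  0.3333])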

else:
scores[i] = scores[i] - id2mean[index[i]]
scores = scores.unsqueeze(-1) * response_mask

return scores, scores


def repopulate_all_registries():
PolicyLossRegistry.repopulate_registry()
AdvantageEstimatorRegistry.repopulate_registry()
30 changes: 30 additions & 0 deletions tests/backends/skyrl_train/utils/test_ppo_utils.py
@@ -18,6 +18,7 @@
compute_approx_kl,
compute_gae_advantage_return,
compute_grpo_outcome_advantage,
compute_maxrl_advantage,
compute_reinforce_plus_plus_outcome_advantage,
compute_rloo_outcome_advantage,
reduce_loss,
@@ -174,6 +175,35 @@ def test_compute_grpo_outcome_advantage_norm_std_false():
assert torch.allclose(adv, expected, atol=1e-5), f"Expected {expected}, got {adv}"


def test_compute_maxrl_advantage():
# Two groups: [6.0, 3.0] mean=4.5, [9.0, 12.0] mean=10.5
token_level_rewards = torch.tensor(
[
[1.0, 2.0, 3.0], # sum = 6.0, group 0
[1.0, 1.0, 1.0], # sum = 3.0, group 0
[3.0, 3.0, 3.0], # sum = 9.0, group 1
[4.0, 4.0, 4.0], # sum = 12.0, group 1
]
)
response_mask = torch.ones_like(token_level_rewards)
index = np.array([0, 0, 1, 1])

adv, ret = compute_maxrl_advantage(
token_level_rewards=token_level_rewards,
response_mask=response_mask,
index=index,
)

expected = (
torch.tensor([1.5 / (4.5 + 1e-6), -1.5 / (4.5 + 1e-6), -1.5 / (10.5 + 1e-6), 1.5 / (10.5 + 1e-6)]).unsqueeze(-1)
* response_mask
)

assert adv.shape == token_level_rewards.shape
assert torch.allclose(adv, ret), "Advantages and returns should be equal with MAXRL"
assert torch.allclose(adv, expected, atol=1e-5), f"Expected {expected}, got {adv}"


def test_compute_gae_advantage_return(advantage_test_data):
rewards, values, response_mask, index = advantage_test_data
