Skip to content

fix(grpo): clamp log-ratio and k3 KL for numerical stability#4

Open
WyldeCat wants to merge 2 commits into
feat/flce-num-chunks-override-v2from
fix/grpo-clamp-numerical-stability
Open

fix(grpo): clamp log-ratio and k3 KL for numerical stability#4
WyldeCat wants to merge 2 commits into
feat/flce-num-chunks-override-v2from
fix/grpo-clamp-numerical-stability

Conversation

@WyldeCat
Copy link
Copy Markdown
Member

@WyldeCat WyldeCat commented May 14, 2026

요약

  • LigerFusedLinearGRPOLoss / LigerFusedLinearGRPOFunction에 numerical-safety용 clamp 인자 3개 추가, default는 켜져 있음:
    • log_ratio_clamp_value=20.0 — PG coef_1과 token/sequence importance weight 계산에 쓰이는 policy - old log-ratio를 exp 직전에 clamp.
    • kl_input_clamp_value=20.0k3_loss_fn 안에서 ref - policy log-ratio를 exp 직전에 clamp.
    • kl_output_clamp_value=10.0 — k3 KL 결과값을 clamp.
  • 이전 unbounded 동작을 원하는 caller는 None을 명시적으로 넘기면 됨.

동기

log_ratiok3_loss_fnlog_p - log_q는 어디서도 bound되지 않습니다. masked / low-probability 위치에서 두 항이 fp32 exp overflow를 일으켜 inf가 되고, 뒤이은 masked reduction이 per_token_loss * attention_mask 형태라 inf * 0 == nan (IEEE) 으로 mask=0 위치 한 토큰이 batch 전체 loss를 NaN으로 오염시킵니다. 결과적으로 GRPO 첫 step부터 advantages == 0인 상황에서도 loss가 NaN으로 나옵니다.

Clamp 로직 출처

NVIDIA NeMo-RL의 calculate_kl에서 동일한 가드를 가지고 있습니다. 같은 위치(exp 전·후)에 같은 의미의 clamp를 적용하고, default input_clamp_value=20.0 / output_clamp_value=10.0을 씁니다:

def calculate_kl(logprobs, logprobs_reference, kl_type="k3",
                 input_clamp_value: float | None = 20.0,
                 output_clamp_value: float | None = 10.0):
    logr = logprobs_reference - logprobs
    if input_clamp_value is not None:
        logr = logr.clamp(min=-input_clamp_value, max=input_clamp_value)
    ...
    elif kl_type == "k3":
        kl = torch.exp(logr) - 1 - logr
    ...
    if output_clamp_value is not None:
        kl = kl.clamp(min=-output_clamp_value, max=output_clamp_value)
    return kl

이 PR은 동일한 default 값까지 그대로 가져와 Liger의 fused GRPO 경로에 적용합니다. NeMo-RL이 호출 측에서 clamp된 값을 만들어 넘기더라도 Liger 내부에서 per_token_logpsold/ref_per_token_logps를 직접 빼서 자체적으로 log-ratio를 다시 계산하기 때문에 (grpo_loss.pylog_ratio = per_token_logps - old_per_token_logps, k3_loss_fn(ref_per_token_logps, per_token_logps)) 호출 측 clamp만으로는 보호되지 않습니다 — 그래서 Liger 내부에 default-on 가드가 필요합니다.

테스트

  • k3_loss_fn smoke: log_p - log_q = ±200일 때 unclamped는 inf, input_clamp_value=20.0, output_clamp_value=10.0 주면 finite ±10으로 bounded.
  • LigerFusedLinearGRPOLoss end-to-end: 극단 ref 값에서 clamp 켜면 loss가 output_clamp × beta로 bounded, 안 켜면 unbounded.
  • Motif V3 GRPO 4-node async 실제 학습 — clamp 끄면 step 1부터 Loss=NaN, clamp 켜면 step 1~4 모두 Loss=0.0000 (실제 0, NaN 아님).

🤖 Generated with Claude Code

…bility

Unbounded ``exp`` over the policy/old log-ratio and over the k3 KL
``log_p - log_q`` can overflow fp32 to ``inf`` at masked / low-
probability positions. Because the subsequent reduction is a plain
``per_token_loss * attention_mask``, ``inf * 0 == nan`` then contami-
nates the entire batch, producing NaN losses on the very first step.

Add three optional clamps (defaults ``None`` = no behavior change):

- ``log_ratio_clamp_value`` clamps ``policy - old`` before ``exp``,
  protecting both the token- and sequence-level importance-weight
  paths (and the PG ``coef_1``).
- ``kl_input_clamp_value`` clamps ``ref - policy`` inside
  ``k3_loss_fn`` before ``exp``.
- ``kl_output_clamp_value`` clamps the resulting k3 KL value.

This mirrors the guard used by NeMo-RL's ``calculate_kl``
(``input_clamp_value`` / ``output_clamp_value``) and the broader
RL-framework convention. Callers that don't set the new arguments
get the prior behavior unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@ca1207
Copy link
Copy Markdown
Member

ca1207 commented May 20, 2026

default 값이 없는것 같은디 nemo-torchtitan 쪽 로직도 수정하신거 잇나유?

Set the new clamp arguments to NeMo-RL's ``calculate_kl`` defaults
(``log_ratio_clamp_value=20.0``, ``kl_input_clamp_value=20.0``,
``kl_output_clamp_value=10.0``) instead of ``None`` so the NaN guard
is active out of the box. Existing callers automatically get the
safety; users that want the prior unbounded behavior can pass
``None`` explicitly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants