fix(grpo): clamp log-ratio and k3 KL for numerical stability #4
Open
WyldeCat wants to merge 2 commits into
…bility

Unbounded ``exp`` over the policy/old log-ratio and over the k3 KL ``log_p - log_q`` can overflow fp32 to ``inf`` at masked / low-probability positions. Because the subsequent reduction is a plain ``per_token_loss * attention_mask``, ``inf * 0 == nan`` then contaminates the entire batch, producing NaN losses on the very first step.

Add three optional clamps (defaults ``None`` = no behavior change):

- ``log_ratio_clamp_value`` clamps ``policy - old`` before ``exp``, protecting both the token- and sequence-level importance-weight paths (and the PG ``coef_1``).
- ``kl_input_clamp_value`` clamps ``ref - policy`` inside ``k3_loss_fn`` before ``exp``.
- ``kl_output_clamp_value`` clamps the resulting k3 KL value.

This mirrors the guard used by NeMo-RL's ``calculate_kl`` (``input_clamp_value`` / ``output_clamp_value``) and the broader RL-framework convention. Callers that don't set the new arguments get the prior behavior unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
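The failure mode in this commit message can be reproduced in a few lines. The sketch below is not Liger's code: it uses NumPy fp32 arrays in place of PyTorch tensors, and `masked_mean_loss` is a hypothetical stand-in for the masked reduction, with `np.exp(log_ratio)` playing the role of the per-token loss.

```python
import numpy as np

def masked_mean_loss(log_ratio, mask, clamp_value=None):
    """Hypothetical stand-in for a `per_token_loss * attention_mask` reduction."""
    log_ratio = np.asarray(log_ratio, dtype=np.float32)
    mask = np.asarray(mask, dtype=np.float32)
    if clamp_value is not None:
        # Symmetric clamp on the log-ratio before it reaches exp.
        log_ratio = np.clip(log_ratio, -clamp_value, clamp_value)
    with np.errstate(over="ignore", invalid="ignore"):
        per_token = np.exp(log_ratio)     # fp32 exp overflows to inf for large inputs
        total = (per_token * mask).sum()  # inf * 0 is nan under IEEE 754
    return total / max(float(mask.sum()), 1.0)

# The third position is masked out, yet its log-ratio is extreme.
log_ratio = [0.1, -0.2, 200.0]
mask = [1.0, 1.0, 0.0]

unclamped = masked_mean_loss(log_ratio, mask)
clamped = masked_mean_loss(log_ratio, mask, clamp_value=20.0)
print(unclamped, clamped)
```

Without the clamp, the masked position still turns the whole reduction into NaN. With the clamp, the masked value is large but finite, so multiplying by the zero mask removes it cleanly.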
Member
It looks like these don't have default values. Did you also change the logic on the nemo-torchtitan side?
Set the new clamp arguments to NeMo-RL's ``calculate_kl`` defaults (``log_ratio_clamp_value=20.0``, ``kl_input_clamp_value=20.0``, ``kl_output_clamp_value=10.0``) instead of ``None`` so the NaN guard is active out of the box. Existing callers automatically get the safety; users that want the prior unbounded behavior can pass ``None`` explicitly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
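As a rough check on why an input clamp of 20.0 leaves ample headroom, the arithmetic below compares it with the point where fp32 ``exp`` overflows. This is only an illustrative calculation with NumPy; the variable names are made up for the example.

```python
import numpy as np

# Largest finite fp32 value, and the input above which exp() exceeds it.
fp32_max = np.finfo(np.float32).max
overflow_threshold = float(np.log(np.float64(fp32_max)))  # about 88.72

# Largest value exp() can produce once the log-ratio is clamped to 20.0.
largest_ratio = float(np.exp(np.float32(20.0)))  # about 4.85e8

print(overflow_threshold, largest_ratio)
```

A clamp at 20.0 therefore caps the exponentiated ratio at a few hundred million, far below the fp32 maximum, while a log-ratio only has to exceed roughly 88.72 to overflow when unclamped.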
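Summary

Adds three numerical-safety clamp arguments to ``LigerFusedLinearGRPOLoss`` / ``LigerFusedLinearGRPOFunction``, all enabled by default:

- ``log_ratio_clamp_value=20.0``: clamps the ``policy - old`` log-ratio, which feeds the PG ``coef_1`` and the token/sequence importance-weight computation, right before ``exp``.
- ``kl_input_clamp_value=20.0``: clamps the ``ref - policy`` log-ratio inside ``k3_loss_fn`` right before ``exp``.
- ``kl_output_clamp_value=10.0``: clamps the resulting k3 KL value.

To get the prior unbounded behavior, pass ``None`` explicitly.

Motivation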
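Neither ``log_ratio`` nor the ``log_p - log_q`` term in ``k3_loss_fn`` is bounded anywhere. At masked / low-probability positions both terms can overflow the fp32 ``exp`` to ``inf``. Because the masked reduction that follows has the form ``per_token_loss * attention_mask``, ``inf * 0 == nan`` (IEEE) lets a single token at a mask=0 position contaminate the whole batch loss with NaN. As a result, GRPO produces a NaN loss from the very first step, even when ``advantages == 0``.

Source of the clamp logic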
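NVIDIA NeMo-RL's ``calculate_kl`` has the same guard. It applies clamps with the same meaning at the same positions (before and after ``exp``) and uses the defaults ``input_clamp_value=20.0`` / ``output_clamp_value=10.0``.

This PR carries over those same default values and applies them to Liger's fused GRPO path. Even if NeMo-RL builds clamped values on the caller side and passes them in, Liger recomputes the log-ratio internally by subtracting ``per_token_logps`` and ``old/ref_per_token_logps`` directly (in ``grpo_loss.py``: ``log_ratio = per_token_logps - old_per_token_logps`` and ``k3_loss_fn(ref_per_token_logps, per_token_logps)``), so caller-side clamping alone gives no protection. That is why a default-on guard is needed inside Liger.

Tests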
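- ``k3_loss_fn`` smoke: when ``log_p - log_q = ±200``, the unclamped result is ``inf``; with ``input_clamp_value=20.0, output_clamp_value=10.0`` it is finite and bounded within ``±10``.
- ``LigerFusedLinearGRPOLoss`` end-to-end: with extreme ``ref`` values, enabling the clamps bounds the loss by ``output_clamp × beta``; without them the loss is unbounded.

🤖 Generated with Claude Code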
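The ``k3_loss_fn`` smoke test above can be approximated with a standalone sketch. It assumes the usual k3 form ``exp(x) - x - 1`` with ``x = log_p - log_q`` and symmetric clamps. `k3_kl_sketch` is a hypothetical helper written for this illustration with NumPy fp32 arrays, not Liger's or NeMo-RL's actual implementation.

```python
import numpy as np

def k3_kl_sketch(log_p, log_q, input_clamp_value=None, output_clamp_value=None):
    """k3 estimator exp(x) - x - 1 for x = log_p - log_q, with optional clamps."""
    x = np.asarray(log_p, dtype=np.float32) - np.asarray(log_q, dtype=np.float32)
    if input_clamp_value is not None:
        # Bound the log-ratio before exp so exp cannot overflow.
        x = np.clip(x, -input_clamp_value, input_clamp_value)
    with np.errstate(over="ignore"):
        kl = np.exp(x) - x - 1.0
    if output_clamp_value is not None:
        # Bound the KL value itself so one token cannot dominate the loss.
        kl = np.clip(kl, -output_clamp_value, output_clamp_value)
    return kl

zeros = [0.0, 0.0, 0.0]
extreme = [200.0, -200.0, 0.0]

unclamped = k3_kl_sketch(extreme, zeros)
clamped = k3_kl_sketch(extreme, zeros, input_clamp_value=20.0, output_clamp_value=10.0)
print(unclamped, clamped)
```

In this sketch the unclamped value at ``+200`` is ``inf``, while with both clamps every entry stays finite and no larger than 10.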