
Commit dc68f8c

JRMeyer and claude committed
feat: add importance sampling observability metrics
Adds three new metrics logged during training to help users verify that importance sampling is working correctly:

- frac_old_logprobs_valid: Fraction of old logprobs that are not NaN
- mean_importance_ratio: Mean π_new/π_old across assistant tokens
- clip_fraction: Fraction of tokens where PPO clipping was triggered

These metrics help diagnose whether GRPO/PPO importance sampling is active or if training has fallen back to vanilla REINFORCE (when all logprobs are NaN).

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent b416286 commit dc68f8c

File tree

2 files changed: +188 −2 lines changed
Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
# Technical Design: Importance Sampling Observability Metrics

## Problem Statement

ART computes importance sampling ratios internally for PPO/GRPO training but does not expose these metrics for monitoring. Users have no visibility into:

1. Whether logprobs are being extracted correctly from trajectories
2. Whether importance sampling is actually active (vs. falling back to REINFORCE)
3. How often PPO clipping is triggered

This makes it difficult to debug training issues and verify that the importance sampling pipeline is working correctly.

### Background: How Importance Sampling Works in ART

```
Rollout Phase
     │
     ▼
Trajectories with logprobs attached to messages
     │
     ▼
Tokenization Phase (tokenize.py)
     ├─► Dict messages: extract logprobs if present, else NaN
     └─► Choice objects: extract logprobs if present
     │
     ▼
Training Phase (train.py)
     ├─► If logprobs are NaN: set old_logprobs = new_logprobs.detach()
     │       └─► prob_ratio = exp(0) = 1.0 (NO importance sampling)
     │
     └─► If logprobs are real: compute prob_ratio = exp(new - old)
             └─► PPO clipping applied when ratio outside [1-ε, 1+ε]
```

When all logprobs are NaN, ART silently falls back to vanilla REINFORCE (advantage-weighted policy gradient with no off-policy correction). This is valid but may not be what users expect.

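For intuition, here is a minimal standalone sketch of that fallback path (toy tensors only, not ART code; the variable names follow the diagram above):

```python
import torch

# Toy illustration of the NaN fallback: missing old logprobs are replaced with
# a detached copy of the new logprobs, so the importance ratio collapses to 1.
new_logprobs = torch.tensor([-0.9, -1.2, -0.4], requires_grad=True)
old_logprobs = torch.tensor([float("nan"), float("nan"), float("nan")])

old_logprobs = torch.where(
    torch.isnan(old_logprobs),
    new_logprobs.detach(),  # "assume sampled under the current policy"
    old_logprobs,
)
prob_ratio = torch.exp(new_logprobs - old_logprobs)
print(prob_ratio)  # all ones -> the gradient reduces to plain REINFORCE
```
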
## Solution

Add three new metrics to ART's training loop that are logged to wandb:

### 1. `frac_old_logprobs_valid`

**What it measures:** Fraction of `old_logprobs` values that are NOT NaN at training time.

**Implementation:**
```python
old_logprobs_nan_mask = torch.isnan(old_logprobs)
frac_old_logprobs_valid = 1.0 - (
    old_logprobs_nan_mask.float().sum() / (old_logprobs.numel() + 1e-6)
).item()
```

**Interpretation:**
| Value | Meaning |
|-------|---------|
| 0.0 | All logprobs are NaN - importance sampling NOT active |
| ~0.3-0.5 | Partial logprobs - some tokens have valid logprobs |
| ~0.8-1.0 | Most logprobs valid - importance sampling fully active |

**Why not exactly 1.0?** System messages, tool calls, and prompt tokens don't have logprobs - only assistant response tokens do.

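As a quick sanity check of the formula (toy tensor, not ART code): with 3 of 5 positions missing logprobs, the metric reads 0.4.

```python
import torch

old_logprobs = torch.tensor([float("nan"), -0.7, float("nan"), -1.3, float("nan")])
nan_mask = torch.isnan(old_logprobs)
frac_valid = 1.0 - (nan_mask.float().sum() / (old_logprobs.numel() + 1e-6)).item()
print(round(frac_valid, 3))  # 0.4 -> only 2 of 5 positions carry real logprobs
```
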
### 2. `mean_importance_ratio`

**What it measures:** Mean importance sampling ratio π_new(a|s) / π_old(a|s) across assistant tokens.

**Implementation:**
```python
mean_importance_ratio = (prob_ratio * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
```

**Interpretation:**
| Value | Meaning |
|-------|---------|
| Exactly 1.0 | No distribution shift (or all NaN logprobs) |
| 0.8 - 1.2 | Healthy training - policy evolving gradually |
| < 0.5 or > 2.0 | Large distribution shift - may indicate issues |

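A worked toy example of the masked mean (not ART code); the non-assistant token is excluded by `assistant_mask`:

```python
import torch

prob_ratio = torch.tensor([1.00, 0.95, 1.10, 3.00])   # last entry is a prompt token
assistant_mask = torch.tensor([1.0, 1.0, 1.0, 0.0])   # only assistant tokens count
mean_ratio = (prob_ratio * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
print(round(mean_ratio.item(), 3))  # 1.017 -> the outlier prompt token is ignored
```
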
### 3. `clip_fraction`

**What it measures:** Fraction of assistant tokens where PPO clipping was triggered.

**Implementation:**
```python
clipped_ratio = torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high)
is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
clip_fraction = (is_clipped.float() * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
```

**Interpretation:**
| Value | Meaning |
|-------|---------|
| 0.0 | No clipping - either on-policy or no importance sampling |
| 0.01 - 0.1 | Healthy - some off-policy correction happening |
| > 0.3 | High clipping - policy has diverged significantly from rollout policy |

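Again as a toy check (not ART code): with ε = ε_high = 0.2 the clipping window is [0.8, 1.2], and one of three assistant tokens falls outside it.

```python
import torch

epsilon, epsilon_high = 0.2, 0.2                       # illustrative values for the toy check
prob_ratio = torch.tensor([1.05, 0.70, 1.10, 5.00])
assistant_mask = torch.tensor([1.0, 1.0, 1.0, 0.0])    # last token is not an assistant token
is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
clip_fraction = (is_clipped.float() * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
print(round(clip_fraction.item(), 3))  # 0.333 -> the masked-out 5.00 does not count
```
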
## Implementation Details

### Files Modified

**`src/art/unsloth/train.py`**

1. Compute `frac_old_logprobs_valid` before the NaN replacement:
```python
old_logprobs_nan_mask = torch.isnan(old_logprobs)
frac_old_logprobs_valid = 1.0 - (
    old_logprobs_nan_mask.float().sum() / (old_logprobs.numel() + 1e-6)
).item()
old_logprobs = torch.where(
    old_logprobs_nan_mask,  # reuse mask
    new_logprobs.detach(),
    old_logprobs,
)
```

2. Compute clip metrics after the `prob_ratio` calculation:
```python
clipped_ratio = torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high)
is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
clip_fraction = (is_clipped.float() * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
mean_importance_ratio = (prob_ratio * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
```

3. Log the new metrics:
```python
trainer._metrics["train"]["frac_old_logprobs_valid"].append(frac_old_logprobs_valid)
trainer._metrics["train"]["mean_importance_ratio"].append(mean_importance_ratio.item())
trainer._metrics["train"]["clip_fraction"].append(clip_fraction.item())
```

### Performance Impact

- **Memory:** Negligible - reuses existing tensors, only adds scalar computations
- **Compute:** Negligible - O(n) operations on existing tensors
- **Logging overhead:** 3 additional floats per training step

## Use Cases

### 1. Debugging Missing Logprobs

If `frac_old_logprobs_valid = 0`:
- Check that the rollout is requesting logprobs from the model (see the sketch below)
- Check that logprobs are being attached to trajectory messages
- Check that tokenization is extracting logprobs correctly (especially for dict messages)

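For the first check, the rollout has to ask the inference endpoint for per-token logprobs explicitly. A minimal sketch against an OpenAI-compatible server; the client setup, model name, and the shape of the returned message dict are illustrative assumptions, not ART's actual rollout API:

```python
from openai import AsyncOpenAI

async def sample_assistant_turn(client: AsyncOpenAI, messages: list[dict]) -> dict:
    """Request a completion with per-token logprobs so they can travel with the trajectory."""
    response = await client.chat.completions.create(
        model="my-finetuned-model",  # illustrative model name
        messages=messages,
        logprobs=True,  # without this, old_logprobs will be NaN downstream
    )
    choice = response.choices[0]
    token_logprobs = choice.logprobs  # per-token logprobs returned by the server
    # Hypothetical message shape: attach the logprobs alongside the assistant content.
    return {"role": "assistant", "content": choice.message.content, "logprobs": token_logprobs}
```
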
### 2. Monitoring Training Health

Healthy training should show:
- `frac_old_logprobs_valid` stable and > 0
- `mean_importance_ratio` fluctuating around 1.0
- `clip_fraction` low but non-zero

### 3. Detecting Distribution Drift

If `clip_fraction` suddenly increases:
- Policy may have diverged too far from the rollout policy
- Consider reducing the learning rate or increasing rollout frequency

## Backwards Compatibility

These changes are additive - existing code continues to work. The new metrics appear in wandb logs automatically if wandb is configured.

## Testing

Manual verification:
1. Run training with valid logprobs → `frac_old_logprobs_valid > 0`
2. Run training with `allow_training_without_logprobs=True` and no logprobs → `frac_old_logprobs_valid = 0`
3. Verify `mean_importance_ratio` deviates from 1.0 over training steps

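The metric formulas can also be exercised in isolation. A small pytest-style sketch (standalone torch, not part of this commit) that encodes the all-NaN fallback case from item 2:

```python
import torch

def test_all_nan_logprobs_disable_importance_sampling() -> None:
    epsilon = epsilon_high = 0.2
    new_logprobs = torch.tensor([-0.5, -1.1, -0.8])
    old_logprobs = torch.full((3,), float("nan"))
    assistant_mask = torch.ones(3)

    nan_mask = torch.isnan(old_logprobs)
    frac_valid = 1.0 - (nan_mask.float().sum() / (old_logprobs.numel() + 1e-6)).item()
    old_logprobs = torch.where(nan_mask, new_logprobs.detach(), old_logprobs)
    prob_ratio = torch.exp(new_logprobs - old_logprobs)
    is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
    clip_fraction = (is_clipped.float() * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)

    assert frac_valid < 1e-5                            # metric reports no real logprobs seen
    assert torch.allclose(prob_ratio, torch.ones(3))    # ratio collapses to 1
    assert clip_fraction.item() == 0.0                  # and nothing gets clipped
```
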
## Related Work

- PPO paper (Schulman et al., 2017) discusses importance sampling and clipping
- TRL's `PPOTrainer` logs similar metrics (`clipfrac`, `ratio`)
- This brings ART's observability closer to standard PPO implementations

src/art/unsloth/train.py

Lines changed: 13 additions & 2 deletions
@@ -163,9 +163,13 @@ def compute_loss(
         new_logprobs.dtype
     )
     weights = shift_tensor(inputs["weights"], 0.0)
+    old_logprobs_nan_mask = torch.isnan(old_logprobs)
+    frac_old_logprobs_valid = 1.0 - (
+        old_logprobs_nan_mask.float().sum() / (old_logprobs.numel() + 1e-6)
+    ).item()
     # Assume missing old logprobs were sampled under the current policy
     old_logprobs = torch.where(
-        torch.isnan(old_logprobs),
+        old_logprobs_nan_mask,
         new_logprobs.detach(),
         old_logprobs,
     )
@@ -190,9 +194,13 @@ def compute_loss(
     prob_ratio = torch.clamp(
         prob_ratio, max=max_negative_advantage_importance_sampling_weight
     )
+    clipped_ratio = torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high)
+    is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
+    clip_fraction = (is_clipped.float() * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
+    mean_importance_ratio = (prob_ratio * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
     policy_loss = -torch.min(
         prob_ratio * advantages,
-        torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) * advantages,
+        clipped_ratio * advantages,
     )
     if upper_bound := _config.get("truncated_importance_sampling", None):
         if "original_logprobs" in inputs:
@@ -228,6 +236,9 @@ def compute_loss(
     trainer._metrics["train"]["learning_rate"].append(config.learning_rate)
     trainer._metrics["train"]["policy_loss"].append(mean_policy_loss.item())
     trainer._metrics["train"]["entropy"].append(mean_entropy.item())  # type: ignore
+    trainer._metrics["train"]["frac_old_logprobs_valid"].append(frac_old_logprobs_valid)
+    trainer._metrics["train"]["mean_importance_ratio"].append(mean_importance_ratio.item())
+    trainer._metrics["train"]["clip_fraction"].append(clip_fraction.item())
     if config.beta > 0.0:
         trainer._metrics["train"]["kl_div"].append(mean_kl.item())
     return mean_policy_loss + config.beta * mean_kl
