Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions slime/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,16 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl
assert len(raw_rewards) == len(samples)
assert len(rewards) == len(samples)

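# Rollout id (one per rollout execution). Default rollouts emit one
# sample per rollout, so we fall back to each sample's position in
# ``samples`` (unique within this batch; not ``sample.index``).
# Compact / subagent paths that emit multiple training samples per
# rollout set ``rollout_id`` explicitly so all siblings share a
# value; the loss reducer then aggregates them as one rollout.
# NOTE: only ``samples[0]`` is inspected to pick the branch, so a batch
# is assumed to have ``rollout_id`` either set on every sample or on none.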
if samples[0].rollout_id is None:
rollout_ids = list(range(len(samples)))
else:
rollout_ids = [sample.rollout_id for sample in samples]

train_data = {
"tokens": [sample.tokens for sample in samples],
"response_lengths": [sample.response_length for sample in samples],
Expand All @@ -674,12 +684,7 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl
"raw_reward": raw_rewards,
"truncated": [1 if sample.status == Sample.Status.TRUNCATED else 0 for sample in samples],
"sample_indices": [sample.index for sample in samples],
# Rollout id (one per rollout execution). Default rollouts emit one
# sample per rollout, so we fall back to ``sample.index`` (unique).
# Compact / subagent paths that emit multiple training samples per
# rollout set ``rollout_id`` explicitly so all siblings share a
# value; the loss reducer then aggregates them as one rollout.
"rollout_ids": [s.rollout_id if s.rollout_id is not None else s.index for s in samples],
"rollout_ids": rollout_ids,
}

# loss mask
Expand Down
Loading