4 changes: 2 additions & 2 deletions skyrl-train/examples/gsm8k/gsm8k_dataset.py
@@ -48,8 +48,8 @@ def extract_solution(solution_str):

dataset = datasets.load_dataset(data_source, "main")

train_dataset = dataset["train"]
val_dataset = dataset["test"]
train_dataset = dataset["train"].select(range(128))
val_dataset = dataset["test"].select(range(64))
Comment on lines +51 to +52

medium

The dataset is being truncated to a very small, hardcoded number of samples (128 for train, 64 for test). This is likely for debugging but significantly reduces the utility of the example script for other users. It would be better to either remove this truncation to use the full dataset by default, or make these sizes configurable via command-line arguments.

Suggested change
train_dataset = dataset["train"].select(range(128))
val_dataset = dataset["test"].select(range(64))
train_dataset = dataset["train"]
val_dataset = dataset["test"]

Comment on lines +51 to +52

🔴 GSM8K dataset script hardcodes truncation to 128 train / 64 test samples

The GSM8K dataset preparation script now unconditionally truncates the training set to 128 examples and the validation set to 64 examples. The GSM8K training set has 7,473 examples and the test set has 1,319 examples, so this discards >98% of the data.

Root Cause

Lines 51-52 add .select(range(128)) and .select(range(64)) directly after loading the dataset splits. This appears to be a debugging leftover that was accidentally included in the PR.

Before:

train_dataset = dataset["train"]
val_dataset = dataset["test"]

After:

train_dataset = dataset["train"].select(range(128))
val_dataset = dataset["test"].select(range(64))

The existing --max_train_dataset_length flag on lines 54-56 already provides a proper mechanism for limiting dataset size. The hardcoded truncation makes that flag ineffective for any value above 128.

Impact: Anyone using this script to prepare GSM8K data will silently get a tiny dataset, leading to undertrained models and misleading evaluation results.

Suggested change
train_dataset = dataset["train"].select(range(128))
val_dataset = dataset["test"].select(range(64))
train_dataset = dataset["train"]
val_dataset = dataset["test"]


if args.max_train_dataset_length is not None:
max_len = min(args.max_train_dataset_length, len(train_dataset))
68 changes: 68 additions & 0 deletions skyrl-train/examples/lora/run_qwen3_0.6b_gsm8k_grpo_lora.sh
@@ -0,0 +1,68 @@
set -x

set -x

medium

This set -x is redundant as it's already set on line 1. Please remove it to keep the script clean.


# Colocated GRPO LoRA training + generation for Qwen3-0.6B on GSM8K with LLM as Judge.

# 1. Prepare dataset:
# uv run examples/llm_as_a_judge/gsm8k_dataset_judge.py --output_dir $HOME/data/gsm8k_llm_judge
# 2. Set API key in .env.llm_judge:
# OPENAI_API_KEY=sk-...
# 3. Run training:
# bash examples/lora/run_qwen3_0.6b_gsm8k_grpo_lora.sh

# NOTE (sumanthrh): `micro_train_batch_size_per_gpu` and `micro_forward_batch_size_per_gpu` can be tuned

DATA_DIR="/mnt/workspace/datasets/gsm8k_with_reward"

medium

The DATA_DIR is hardcoded to an absolute path (/mnt/workspace/...) which is not portable across different environments. Consider using a path relative to $HOME or allowing it to be overridden by an environment variable to make the script more reusable.

Suggested change
DATA_DIR="/mnt/workspace/datasets/gsm8k_with_reward"
DATA_DIR="${DATA_DIR:-$HOME/data/gsm8k_with_reward}"
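The `${VAR:-default}` expansion in the suggestion lets a caller override the path without editing the script. A minimal sketch of the behavior (the paths are illustrative):

```shell
# Falls back to the default when DATA_DIR is unset
unset DATA_DIR
DATA_DIR="${DATA_DIR:-$HOME/data/gsm8k_with_reward}"
echo "$DATA_DIR"

# Keeps the caller-supplied value when one is exported
DATA_DIR="/tmp/my_gsm8k"
DATA_DIR="${DATA_DIR:-$HOME/data/gsm8k_with_reward}"
echo "$DATA_DIR"
```

With this in place, `DATA_DIR=/my/datasets/gsm8k bash run_qwen3_0.6b_gsm8k_grpo_lora.sh` would work on any machine.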

NUM_NODES=2
NUM_GPUS=4 # per node
TOTAL_GPUS=$((NUM_GPUS * NUM_NODES)) # 8 total
LOGGER="console" # change to "wandb" for W&B logging

INFERENCE_BACKEND="vllm"


uv run --isolated --extra $INFERENCE_BACKEND -m skyrl_train.entrypoints.main_base \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.algorithm.advantage_estimator="grpo" \
trainer.policy.model.path="Qwen/Qwen3-0.6B" \
trainer.placement.colocate_all=true \
trainer.policy.model.lora.rank=32 \
trainer.policy.model.lora.alpha=32 \
trainer.strategy=fsdp2 \
trainer.placement.policy_num_nodes=$NUM_NODES \
trainer.placement.ref_num_nodes=$NUM_NODES \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
generator.num_inference_engines=$TOTAL_GPUS \
generator.inference_engine_tensor_parallel_size=1 \
trainer.epochs=1 \
trainer.eval_batch_size=64 \
trainer.eval_before_train=false \
trainer.eval_interval=100 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=128 \
trainer.policy_mini_batch_size=32 \
trainer.micro_forward_batch_size_per_gpu=8 \
trainer.micro_train_batch_size_per_gpu=8 \
trainer.ckpt_interval=10 \
trainer.max_prompt_length=512 \
generator.sampling_params.max_generate_length=512 \
trainer.policy.optimizer_config.lr=3.0e-5 \
trainer.algorithm.use_kl_loss=true \
generator.backend=$INFERENCE_BACKEND \
generator.run_engines_locally=true \
generator.weight_sync_backend=nccl \
generator.async_engine=true \
generator.batched=true \
environment.env_class=llm_as_a_judge \
environment.skyrl_gym.llm_as_a_judge.model="gpt-4o-mini" \
generator.n_samples_per_prompt=2 \
generator.gpu_memory_utilization=0.8 \
trainer.logger="$LOGGER" \
trainer.project_name="gsm8k_qwen3_0.6b_lora" \
trainer.run_name="gsm8k_qwen3_0.6b_lora_grpo" \
trainer.resume_mode=null \
trainer.ckpt_path="$HOME/ckpts/gsm8k_qwen3_0.6b_lora_ckpt" \
"$@"
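Quoting the argument forwarder matters: unquoted `$@` re-splits arguments containing whitespace before they reach the entrypoint, while `"$@"` forwards them intact. A small demonstration:

```shell
# count_args echoes how many positional arguments it received
count_args() { echo "$#"; }

set -- "one arg" "two"   # two arguments, the first contains a space
count_args "$@"          # prints 2 -> arguments preserved
count_args $@            # prints 3 -> "one arg" split into two words
```

For the training script this means an override like `trainer.run_name="my run"` would only survive with the quoted form.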
@@ -0,0 +1,8 @@
# Environment variables for summarization judge training
# Copy this file to .env.summarization_judge and fill in your API key

# OpenAI API key for the grading model
OPENAI_API_KEY=sk-your-api-key-here

# Optional: Weights & Biases API key for logging
# WANDB_API_KEY=your-wandb-key-here
2 changes: 2 additions & 0 deletions skyrl-train/examples/summarization_judge/__init__.py
@@ -0,0 +1,2 @@
# Summarization Judge Example
# Use LLM-as-a-judge to train summarization models
@@ -0,0 +1,36 @@
"""
Main entrypoint for the summarization LLM-as-a-judge example.
"""

import ray
import hydra
from omegaconf import DictConfig
from skyrl_train.utils import initialize_ray
from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg
from skyrl_gym.envs import register


@ray.remote(num_cpus=1)
def skyrl_entrypoint(cfg: DictConfig):
# Register the summarization_judge environment inside the entrypoint task
register(
id="summarization_judge",
entry_point="examples.summarization_judge.summarization_judge_env:SummarizationJudgeEnv",
)

# Run the training
exp = BasePPOExp(cfg)
exp.run()


@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None)
def main(cfg: DictConfig) -> None:
# Validate the config
validate_cfg(cfg)

initialize_ray(cfg)
ray.get(skyrl_entrypoint.remote(cfg))


if __name__ == "__main__":
main()
@@ -0,0 +1,72 @@
#!/bin/bash
set -x

# Colocated GRPO LoRA training + generation for summarization with LLM as Judge.
#
# 1. Prepare dataset:
# uv run examples/summarization_judge/summarization_dataset.py \
# --input_file /path/to/your/data.jsonl \
# --output_dir $HOME/data/summarization_judge
#
# 2. Set API key in .env.summarization_judge:
# OPENAI_API_KEY=sk-...
#
# 3. Run training:
# bash examples/summarization_judge/run_summarization_judge.sh

DATA_DIR="$HOME/data/summarization_judge"
CKPT_PATH="$HOME/ckpts/summarization_judge"

NUM_NODES=2
NUM_GPUS=4 # per node
TOTAL_GPUS=$((NUM_GPUS * NUM_NODES)) # 8 total
LOGGER="console" # change to "wandb" for W&B logging

INFERENCE_BACKEND="vllm"


uv run --isolated --extra $INFERENCE_BACKEND --env-file .env.summarization_judge -m examples.summarization_judge.main_summarization_judge \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.algorithm.advantage_estimator="grpo" \
trainer.policy.model.path="Qwen/Qwen3-0.6B" \
trainer.placement.colocate_all=true \
trainer.policy.model.lora.rank=32 \
trainer.policy.model.lora.alpha=32 \
trainer.strategy=fsdp2 \
trainer.placement.policy_num_nodes=$NUM_NODES \
trainer.placement.ref_num_nodes=$NUM_NODES \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
generator.num_inference_engines=$TOTAL_GPUS \
generator.inference_engine_tensor_parallel_size=1 \
trainer.epochs=4 \
trainer.eval_batch_size=32 \
trainer.eval_before_train=false \
trainer.eval_interval=50 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=64 \
trainer.policy_mini_batch_size=16 \
trainer.micro_forward_batch_size_per_gpu=4 \
trainer.micro_train_batch_size_per_gpu=4 \
trainer.ckpt_interval=10 \
trainer.max_prompt_length=4096 \
generator.sampling_params.max_generate_length=2048 \
trainer.policy.optimizer_config.lr=3.0e-5 \
trainer.algorithm.use_kl_loss=true \
generator.backend=$INFERENCE_BACKEND \
generator.run_engines_locally=true \
generator.weight_sync_backend=nccl \
generator.async_engine=true \
generator.batched=true \
environment.env_class=summarization_judge \
environment.skyrl_gym.summarization_judge.model="gpt-4o-mini" \
environment.skyrl_gym.summarization_judge.temperature=0.0 \
generator.n_samples_per_prompt=2 \
generator.gpu_memory_utilization=0.8 \
trainer.logger="$LOGGER" \
trainer.project_name="summarization_judge" \
trainer.run_name="summarization_judge_grpo_lora" \
trainer.resume_mode=null \
trainer.ckpt_path="$CKPT_PATH" \
"$@"
186 changes: 186 additions & 0 deletions skyrl-train/examples/summarization_judge/summarization_dataset.py
@@ -0,0 +1,186 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess a summarization dataset to parquet format for LLM-as-a-judge training.

Expected input format (JSON/JSONL):
{
"prompt": "Your task is to summarize...",
"original_document": "The full document text...",
"user_intent": {
"purpose": "...",
"audience": "...",
"tone": "...",
"target_words": 300,
"focus_areas": "..."
},
"sample_id": "unique-id",
"document_id": "doc-id"
}
"""
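A JSONL line matching the documented schema could be produced as follows (field values are illustrative, not from the real dataset):

```python
import json

# Illustrative record following the expected input format above
record = {
    "prompt": "Your task is to summarize...",
    "original_document": "The full document text...",
    "user_intent": {
        "purpose": "executive briefing",
        "audience": "managers",
        "tone": "neutral",
        "target_words": 300,
        "focus_areas": "key decisions",
    },
    "sample_id": "sample-0001",
    "document_id": "doc-0001",
}

# One JSONL line; round-trips cleanly through json
line = json.dumps(record)
assert json.loads(line)["user_intent"]["target_words"] == 300
```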

import argparse
import json
import os
from typing import List, Dict, Any

import datasets


def load_local_data(input_path: str) -> List[Dict[str, Any]]:
"""Load data from a local JSON or JSONL file."""
data = []

if input_path.endswith('.jsonl'):
with open(input_path, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
data.append(json.loads(line))
elif input_path.endswith('.json'):
with open(input_path, 'r', encoding='utf-8') as f:
loaded = json.load(f)
if isinstance(loaded, list):
data = loaded
else:
data = [loaded]
else:
raise ValueError(f"Unsupported file format: {input_path}. Use .json or .jsonl")

return data
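The JSONL branch above skips blank lines, which a quick self-contained check confirms (this re-implements just that branch for illustration, since the full function lives in the script):

```python
import json
import tempfile

def load_jsonl(path):
    # Mirrors the .jsonl branch of load_local_data: one JSON object per
    # non-blank line
    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                rows.append(json.loads(line))
    return rows

with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
    f.write('{"sample_id": "a"}\n\n{"sample_id": "b"}\n')
    path = f.name

rows = load_jsonl(path)
assert [r["sample_id"] for r in rows] == ["a", "b"]  # blank line is skipped
```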


def make_map_fn(split: str):
"""Create a mapping function to convert raw data to training format."""

def process_fn(example: Dict[str, Any], idx: int) -> Dict[str, Any]:
# Extract the prompt - this is what the model will see
prompt_text = example.get("prompt", "")

# If prompt is not provided, construct it from user_intent and original_document
if not prompt_text:
user_intent = example.get("user_intent", {})
original_document = example.get("original_document", "")

prompt_text = f"""Your task is to summarize the user provided document, based on the user intent.

## User Intent
- Purpose: {user_intent.get('purpose', 'N/A')}
- Audience: {user_intent.get('audience', 'N/A')}
- Tone: {user_intent.get('tone', 'neutral')}
- Target Words: {user_intent.get('target_words', 300)}
- Focus Areas: {user_intent.get('focus_areas', 'N/A')}

## Document
{original_document}

Please provide a summary that meets the above requirements."""

# Ground truth contains the data needed by the grader
ground_truth = {
"user_intent": example.get("user_intent", {}),
"original_document": example.get("original_document", ""),
}

# Build the training data format
data = {
"data_source": "summarization",
"prompt": [
{
"role": "user",
"content": prompt_text,
}
],
"env_class": "summarization_judge",
"reward_spec": {
"method": "llm_judge",
"ground_truth": ground_truth,
},
"extra_info": {
"split": split,
"index": idx,
"sample_id": example.get("sample_id", str(idx)),
"document_id": example.get("document_id", ""),
},
}
return data

return process_fn


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Convert summarization data to parquet format for LLM-as-a-judge training"
)
parser.add_argument(
"--input_file",
type=str,
required=True,
help="Path to the input JSON or JSONL file",
)
parser.add_argument(
"--output_dir",
type=str,
default="~/data/summarization_judge",
help="Output directory for parquet files",
)
parser.add_argument(
"--train_split",
type=float,
default=0.9,
help="Fraction of data to use for training (default: 0.9)",
)
parser.add_argument(
"--max_samples",
type=int,
default=None,
help="Maximum number of samples to use (default: all)",
)

args = parser.parse_args()
args.output_dir = os.path.expanduser(args.output_dir)

# Load the data
print(f"Loading data from {args.input_file}...")
raw_data = load_local_data(args.input_file)
print(f"Loaded {len(raw_data)} samples")

# Optionally limit samples
if args.max_samples is not None:
raw_data = raw_data[:args.max_samples]
print(f"Limited to {len(raw_data)} samples")

# Convert to HuggingFace dataset
dataset = datasets.Dataset.from_list(raw_data)

# Split into train/val
split_idx = int(len(dataset) * args.train_split)
train_dataset = dataset.select(range(split_idx))
val_dataset = dataset.select(range(split_idx, len(dataset)))
Comment on lines +167 to +169

medium

The current method of splitting the dataset by index is not reproducible if the dataset order changes, and it doesn't shuffle the data, which can lead to biased splits. Using datasets.Dataset.train_test_split with a fixed seed is a more robust and reproducible approach.

Suggested change
split_idx = int(len(dataset) * args.train_split)
train_dataset = dataset.select(range(split_idx))
val_dataset = dataset.select(range(split_idx, len(dataset)))
split_dataset = dataset.train_test_split(test_size=1.0 - args.train_split, seed=42)
train_dataset = split_dataset["train"]
val_dataset = split_dataset["test"]
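The reproducibility point can be illustrated without the datasets library: a seeded shuffle produces the same split on every run. This is a sketch of the idea, not the `Dataset.train_test_split` API itself:

```python
import random

def seeded_split(items, train_frac=0.9, seed=42):
    # Shuffle a copy deterministically, then slice -- analogous to
    # train_test_split(test_size=1 - train_frac, seed=seed)
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    split_idx = int(len(shuffled) * train_frac)
    return shuffled[:split_idx], shuffled[split_idx:]

train_a, val_a = seeded_split(range(100))
train_b, val_b = seeded_split(range(100))
assert train_a == train_b and val_a == val_b  # same seed -> same split
assert len(train_a) == 90 and len(val_a) == 10
```

An unseeded index slice, by contrast, silently changes whenever the input file is reordered.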


print(f"Train samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

# Apply mapping function
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
val_dataset = val_dataset.map(function=make_map_fn("test"), with_indices=True)

# Save to parquet
os.makedirs(args.output_dir, exist_ok=True)
train_path = os.path.join(args.output_dir, "train.parquet")
val_path = os.path.join(args.output_dir, "validation.parquet")

train_dataset.to_parquet(train_path)
val_dataset.to_parquet(val_path)

print(f"Saved train dataset to {train_path}")
print(f"Saved validation dataset to {val_path}")