[Draft] support for azure blob storage #1172
base: main
Changes from all commits
4298382
e9fde41
1c7b6f3
f1006ab
abb935f
dbf237e
9812b26
@@ -48,8 +48,8 @@ def extract_solution(solution_str):
     dataset = datasets.load_dataset(data_source, "main")

-    train_dataset = dataset["train"]
-    val_dataset = dataset["test"]
+    train_dataset = dataset["train"].select(range(128))
+    val_dataset = dataset["test"].select(range(64))
Comment on lines +51 to +52

Contributor:

🔴 GSM8K dataset script hardcodes truncation to 128 train / 64 test samples

The GSM8K dataset preparation script now unconditionally truncates the training set to 128 examples and the validation set to 64 examples. The GSM8K training set has 7,473 examples and the test set has 1,319 examples, so this discards more than 98% of the data.

Root cause: lines 51-52 add `.select(...)` calls.

Before:

    train_dataset = dataset["train"]
    val_dataset = dataset["test"]

After:

    train_dataset = dataset["train"].select(range(128))
    val_dataset = dataset["test"].select(range(64))

The existing `max_train_dataset_length` argument already provides an opt-in way to cap the dataset size.

Impact: anyone using this script to prepare GSM8K data will silently get a tiny dataset, leading to undertrained models and misleading evaluation results.
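The fix the comment points toward can be sketched with an optional cap instead of a hardcoded one. This is illustrative only: the `--max_train_samples` / `--max_val_samples` flags and the `capped_indices` helper are hypothetical, not part of the PR.

```python
import argparse

def capped_indices(total, max_samples=None):
    """Index range to keep: every row when max_samples is None, else the first max_samples."""
    if max_samples is None:
        return range(total)
    return range(min(max_samples, total))

# Hypothetical flags; the script's real interface may differ.
parser = argparse.ArgumentParser()
parser.add_argument("--max_train_samples", type=int, default=None)
parser.add_argument("--max_val_samples", type=int, default=None)
args = parser.parse_args([])  # empty argv here only so the sketch runs standalone

# With a HuggingFace dataset this would become:
#   train_dataset = dataset["train"].select(capped_indices(len(dataset["train"]), args.max_train_samples))
#   val_dataset = dataset["test"].select(capped_indices(len(dataset["test"]), args.max_val_samples))
```

By default nothing is dropped; passing `--max_train_samples 128` would reproduce the debugging behavior explicitly rather than silently.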
     if args.max_train_dataset_length is not None:
         max_len = min(args.max_train_dataset_length, len(train_dataset))
@@ -0,0 +1,68 @@
set -x

# Colocated GRPO LoRA training + generation for Qwen3-0.6B on GSM8K with LLM as Judge.

# 1. Prepare dataset:
#    uv run examples/llm_as_a_judge/gsm8k_dataset_judge.py --output_dir $HOME/data/gsm8k_llm_judge
# 2. Set API key in .env.llm_judge:
#    OPENAI_API_KEY=sk-...
# 3. Run training:
#    bash examples/lora/run_qwen3_0.6b_gsm8k_grpo_lora.sh

# NOTE (sumanthrh): `micro_train_batch_size_per_gpu` and `micro_forward_batch_size_per_gpu` can be tuned

DATA_DIR="/mnt/workspace/datasets/gsm8k_with_reward"
Contributor:

The `DATA_DIR` here is a hardcoded machine-specific path and does not match the dataset directory produced in step 1 (`$HOME/data/gsm8k_llm_judge`).
NUM_NODES=2
NUM_GPUS=4  # per node
TOTAL_GPUS=$((NUM_GPUS * NUM_NODES))  # 8 total
LOGGER="console"  # change to "wandb" for W&B logging

INFERENCE_BACKEND="vllm"

uv run --isolated --extra $INFERENCE_BACKEND -m skyrl_train.entrypoints.main_base \
  data.train_data="['$DATA_DIR/train.parquet']" \
  data.val_data="['$DATA_DIR/validation.parquet']" \
  trainer.algorithm.advantage_estimator="grpo" \
  trainer.policy.model.path="Qwen/Qwen3-0.6B" \
  trainer.placement.colocate_all=true \
  trainer.policy.model.lora.rank=32 \
  trainer.policy.model.lora.alpha=32 \
  trainer.strategy=fsdp2 \
  trainer.placement.policy_num_nodes=$NUM_NODES \
  trainer.placement.ref_num_nodes=$NUM_NODES \
  trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
  trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
  generator.num_inference_engines=$TOTAL_GPUS \
  generator.inference_engine_tensor_parallel_size=1 \
  trainer.epochs=1 \
  trainer.eval_batch_size=64 \
  trainer.eval_before_train=false \
  trainer.eval_interval=100 \
  trainer.update_epochs_per_batch=1 \
  trainer.train_batch_size=128 \
  trainer.policy_mini_batch_size=32 \
  trainer.micro_forward_batch_size_per_gpu=8 \
  trainer.micro_train_batch_size_per_gpu=8 \
  trainer.ckpt_interval=10 \
  trainer.max_prompt_length=512 \
  generator.sampling_params.max_generate_length=512 \
  trainer.policy.optimizer_config.lr=3.0e-5 \
  trainer.algorithm.use_kl_loss=true \
  generator.backend=$INFERENCE_BACKEND \
  generator.run_engines_locally=true \
  generator.weight_sync_backend=nccl \
  generator.async_engine=true \
  generator.batched=true \
  environment.env_class=llm_as_a_judge \
  environment.skyrl_gym.llm_as_a_judge.model="gpt-4o-mini" \
  generator.n_samples_per_prompt=2 \
  generator.gpu_memory_utilization=0.8 \
  trainer.logger="$LOGGER" \
  trainer.project_name="gsm8k_qwen3_0.6b_lora" \
  trainer.run_name="gsm8k_qwen3_0.6b_lora_grpo" \
  trainer.resume_mode=null \
  trainer.ckpt_path="$HOME/ckpts/gsm8k_qwen3_0.6b_lora_ckpt" \
  "$@"
@@ -0,0 +1,8 @@
# Environment variables for summarization judge training
# Copy this file to .env.summarization_judge and fill in your API key

# OpenAI API key for the grading model
OPENAI_API_KEY=sk-your-api-key-here

# Optional: Weights & Biases API key for logging
# WANDB_API_KEY=your-wandb-key-here
@@ -0,0 +1,2 @@
# Summarization Judge Example
# Use LLM-as-a-judge to train summarization models
@@ -0,0 +1,36 @@
"""
Main entrypoint for the summarization LLM-as-a-judge example.
"""

import ray
import hydra
from omegaconf import DictConfig
from skyrl_train.utils import initialize_ray
from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg
from skyrl_gym.envs import register


@ray.remote(num_cpus=1)
def skyrl_entrypoint(cfg: DictConfig):
    # Register the summarization_judge environment inside the entrypoint task
    register(
        id="summarization_judge",
        entry_point="examples.summarization_judge.summarization_judge_env:SummarizationJudgeEnv",
    )

    # Run the training
    exp = BasePPOExp(cfg)
    exp.run()


@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None)
def main(cfg: DictConfig) -> None:
    # Validate the config
    validate_cfg(cfg)

    initialize_ray(cfg)
    ray.get(skyrl_entrypoint.remote(cfg))


if __name__ == "__main__":
    main()
@@ -0,0 +1,72 @@
#!/bin/bash
set -x

# Colocated GRPO LoRA training + generation for summarization with LLM as Judge.
#
# 1. Prepare dataset:
#    uv run examples/summarization_judge/summarization_dataset.py \
#      --input_file /path/to/your/data.jsonl \
#      --output_dir $HOME/data/summarization_judge
#
# 2. Set API key in .env.summarization_judge:
#    OPENAI_API_KEY=sk-...
#
# 3. Run training:
#    bash examples/summarization_judge/run_summarization_judge.sh

DATA_DIR="$HOME/data/summarization_judge"
CKPT_PATH="$HOME/ckpts/summarization_judge"

NUM_NODES=2
NUM_GPUS=4  # per node
TOTAL_GPUS=$((NUM_GPUS * NUM_NODES))  # 8 total
LOGGER="console"  # change to "wandb" for W&B logging

INFERENCE_BACKEND="vllm"

uv run --isolated --extra $INFERENCE_BACKEND --env-file .env.summarization_judge -m examples.summarization_judge.main_summarization_judge \
  data.train_data="['$DATA_DIR/train.parquet']" \
  data.val_data="['$DATA_DIR/validation.parquet']" \
  trainer.algorithm.advantage_estimator="grpo" \
  trainer.policy.model.path="Qwen/Qwen3-0.6B" \
  trainer.placement.colocate_all=true \
  trainer.policy.model.lora.rank=32 \
  trainer.policy.model.lora.alpha=32 \
  trainer.strategy=fsdp2 \
  trainer.placement.policy_num_nodes=$NUM_NODES \
  trainer.placement.ref_num_nodes=$NUM_NODES \
  trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
  trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
  generator.num_inference_engines=$TOTAL_GPUS \
  generator.inference_engine_tensor_parallel_size=1 \
  trainer.epochs=4 \
  trainer.eval_batch_size=32 \
  trainer.eval_before_train=false \
  trainer.eval_interval=50 \
  trainer.update_epochs_per_batch=1 \
  trainer.train_batch_size=64 \
  trainer.policy_mini_batch_size=16 \
  trainer.micro_forward_batch_size_per_gpu=4 \
  trainer.micro_train_batch_size_per_gpu=4 \
  trainer.ckpt_interval=10 \
  trainer.max_prompt_length=4096 \
  generator.sampling_params.max_generate_length=2048 \
  trainer.policy.optimizer_config.lr=3.0e-5 \
  trainer.algorithm.use_kl_loss=true \
  generator.backend=$INFERENCE_BACKEND \
  generator.run_engines_locally=true \
  generator.weight_sync_backend=nccl \
  generator.async_engine=true \
  generator.batched=true \
  environment.env_class=summarization_judge \
  environment.skyrl_gym.summarization_judge.model="gpt-4o-mini" \
  environment.skyrl_gym.summarization_judge.temperature=0.0 \
  generator.n_samples_per_prompt=2 \
  generator.gpu_memory_utilization=0.8 \
  trainer.logger="$LOGGER" \
  trainer.project_name="summarization_judge" \
  trainer.run_name="summarization_judge_grpo_lora" \
  trainer.resume_mode=null \
  trainer.ckpt_path="$CKPT_PATH" \
  "$@"
@@ -0,0 +1,186 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess a summarization dataset to parquet format for LLM-as-a-judge training.

Expected input format (JSON/JSONL):
{
    "prompt": "Your task is to summarize...",
    "original_document": "The full document text...",
    "user_intent": {
        "purpose": "...",
        "audience": "...",
        "tone": "...",
        "target_words": 300,
        "focus_areas": "..."
    },
    "sample_id": "unique-id",
    "document_id": "doc-id"
}
"""

import argparse
import json
import os
from typing import List, Dict, Any

import datasets


def load_local_data(input_path: str) -> List[Dict[str, Any]]:
    """Load data from a local JSON or JSONL file."""
    data = []

    if input_path.endswith('.jsonl'):
        with open(input_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    data.append(json.loads(line))
    elif input_path.endswith('.json'):
        with open(input_path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
        if isinstance(loaded, list):
            data = loaded
        else:
            data = [loaded]
    else:
        raise ValueError(f"Unsupported file format: {input_path}. Use .json or .jsonl")

    return data


def make_map_fn(split: str):
    """Create a mapping function to convert raw data to training format."""

    def process_fn(example: Dict[str, Any], idx: int) -> Dict[str, Any]:
        # Extract the prompt - this is what the model will see
        prompt_text = example.get("prompt", "")

        # If prompt is not provided, construct it from user_intent and original_document
        if not prompt_text:
            user_intent = example.get("user_intent", {})
            original_document = example.get("original_document", "")

            prompt_text = f"""Your task is to summarize the user provided document, based on the user intent.

## User Intent
- Purpose: {user_intent.get('purpose', 'N/A')}
- Audience: {user_intent.get('audience', 'N/A')}
- Tone: {user_intent.get('tone', 'neutral')}
- Target Words: {user_intent.get('target_words', 300)}
- Focus Areas: {user_intent.get('focus_areas', 'N/A')}

## Document
{original_document}

Please provide a summary that meets the above requirements."""

        # Ground truth contains the data needed by the grader
        ground_truth = {
            "user_intent": example.get("user_intent", {}),
            "original_document": example.get("original_document", ""),
        }

        # Build the training data format
        data = {
            "data_source": "summarization",
            "prompt": [
                {
                    "role": "user",
                    "content": prompt_text,
                }
            ],
            "env_class": "summarization_judge",
            "reward_spec": {
                "method": "llm_judge",
                "ground_truth": ground_truth,
            },
            "extra_info": {
                "split": split,
                "index": idx,
                "sample_id": example.get("sample_id", str(idx)),
                "document_id": example.get("document_id", ""),
            },
        }
        return data

    return process_fn


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert summarization data to parquet format for LLM-as-a-judge training"
    )
    parser.add_argument(
        "--input_file",
        type=str,
        required=True,
        help="Path to the input JSON or JSONL file",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="~/data/summarization_judge",
        help="Output directory for parquet files",
    )
    parser.add_argument(
        "--train_split",
        type=float,
        default=0.9,
        help="Fraction of data to use for training (default: 0.9)",
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Maximum number of samples to use (default: all)",
    )

    args = parser.parse_args()
    args.output_dir = os.path.expanduser(args.output_dir)

    # Load the data
    print(f"Loading data from {args.input_file}...")
    raw_data = load_local_data(args.input_file)
    print(f"Loaded {len(raw_data)} samples")

    # Optionally limit samples
    if args.max_samples is not None:
        raw_data = raw_data[:args.max_samples]
        print(f"Limited to {len(raw_data)} samples")

    # Convert to HuggingFace dataset
    dataset = datasets.Dataset.from_list(raw_data)

    # Split into train/val
    split_idx = int(len(dataset) * args.train_split)
    train_dataset = dataset.select(range(split_idx))
    val_dataset = dataset.select(range(split_idx, len(dataset)))
Comment on lines +167 to +169

Contributor:

The current method of splitting the dataset by index is not reproducible if the dataset order changes, and it doesn't shuffle the data, which can lead to biased splits. Using a shuffled split with a fixed seed would make the split both reproducible and unbiased.
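The seeded shuffle the comment asks for can be sketched in plain Python. The `split_indices` helper below is illustrative, not part of the PR; with HuggingFace `datasets`, `dataset.train_test_split(test_size=..., seed=...)` is the one-call equivalent.

```python
import random

def split_indices(n, train_frac=0.9, seed=42):
    """Shuffle indices 0..n-1 with a fixed seed, then cut into train/val lists."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    cut = int(n * train_frac)
    return idx[:cut], idx[cut:]

# Applied to the script above, the split would become:
#   train_idx, val_idx = split_indices(len(dataset), args.train_split)
#   train_dataset = dataset.select(train_idx)
#   val_dataset = dataset.select(val_idx)
```

Because the shuffle is driven by a fixed seed, rerunning the script on the same data reproduces the same split, and reordering concerns only matter if the input rows themselves change.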
    print(f"Train samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

    # Apply mapping function
    train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
    val_dataset = val_dataset.map(function=make_map_fn("test"), with_indices=True)

    # Save to parquet
    os.makedirs(args.output_dir, exist_ok=True)
    train_path = os.path.join(args.output_dir, "train.parquet")
    val_path = os.path.join(args.output_dir, "validation.parquet")

    train_dataset.to_parquet(train_path)
    val_dataset.to_parquet(val_path)

    print(f"Saved train dataset to {train_path}")
    print(f"Saved validation dataset to {val_path}")
Contributor:

The dataset is being truncated to a very small, hardcoded number of samples (128 for train, 64 for test). This is likely for debugging but significantly reduces the utility of the example script for other users. It would be better to either remove this truncation to use the full dataset by default, or make these sizes configurable via command-line arguments.