Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions miles_qwen3_8b_h100/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# syntax=docker/dockerfile:1
FROM anyscale/ray:2.54.0-py312-cu129

# Build-time pins so rebuilds are reproducible; override via --build-arg.
ARG PATCH_VERSION=latest
ARG MEGATRON_COMMIT=3714d81d418c9f1bca4594fc35f9e8289f652862
ARG SGLANG_COMMIT=24c91001cf99ba642be791e099d358f4dfe955f5
ARG MILES_REF=main

# Anyscale base image runs as non-root; use sudo for system installs.
WORKDIR /home/ray

# DEBIAN_FRONTEND is set inline (not via ENV) so it does not leak into runtime.
RUN sudo apt-get update && \
    sudo DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        dnsutils \
        git \
        nvtop \
        rsync && \
    sudo rm -rf /var/lib/apt/lists/*

# Keep pip tooling current and pin numpy to 1.x for Megatron compatibility.
# --no-cache-dir keeps pip's download cache out of the image layers.
RUN python -m pip install --no-cache-dir --upgrade pip setuptools wheel && \
    python -m pip install --no-cache-dir "numpy<2" huggingface_hub

# Pin PyTorch 2.9.1 — matches sgl_kernel from PyPI (compiled for torch 2.9.x)
# and has a pre-built flash-attn 2.8.3 wheel available.
RUN python -m pip install --no-cache-dir torch==2.9.1 torchvision torchaudio \
    --index-url https://download.pytorch.org/whl/cu128

# Pre-built flash-attn wheel for torch 2.9 + cu12 (source compilation
# exceeds Anyscale's ~60 min build timeout).
RUN python -m pip install --no-cache-dir https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3%2Bcu12torch2.9cxx11abiTRUE-cp312-cp312-linux_x86_64.whl

# Apex: install Python-only (no CUDA extensions) to stay within Anyscale's
# ~60 min build timeout. Megatron falls back to PyTorch-native kernels.
RUN git clone --filter=blob:none https://github.com/NVIDIA/apex.git /tmp/apex && \
    cd /tmp/apex && \
    git checkout 10417aceddd7d5d05d7cbf7b0fc2daad1105f8b4 && \
    python -m pip install --disable-pip-version-check --no-cache-dir \
        --no-build-isolation . && \
    rm -rf /tmp/apex

# Install SGLang from source. sgl_kernel comes from PyPI, pre-compiled
# for torch 2.9.x — no need to rebuild from source.
RUN git clone https://github.com/sgl-project/sglang.git /home/ray/sglang && \
    cd /home/ray/sglang && \
    git checkout ${SGLANG_COMMIT} && \
    python -m pip install --no-cache-dir -e "python[all]"

# Install Megatron-LM from source.
RUN git clone --recursive https://github.com/NVIDIA/Megatron-LM.git /home/ray/Megatron-LM && \
    cd /home/ray/Megatron-LM && \
    git checkout ${MEGATRON_COMMIT} && \
    python -m pip install --no-cache-dir -e .

# Pull Miles source for patches and dependency manifests.
RUN git clone https://github.com/radixark/miles.git /tmp/miles && \
    cd /tmp/miles && \
    git checkout ${MILES_REF}

# Apply SGLang patch. `git apply --3way` exits non-zero on failure (aborting
# the build); the grep is a belt-and-braces check for leftover conflict markers.
RUN cd /home/ray/sglang && \
    cp /tmp/miles/docker/patch/${PATCH_VERSION}/sglang.patch ./sglang.patch && \
    git update-index --refresh && \
    git apply sglang.patch --3way && \
    if grep -R -n '^<<<<<<< ' .; then \
        echo "SGLang patch failed to apply cleanly. Please resolve conflicts." && \
        exit 1; \
    fi && \
    rm sglang.patch

# Apply Megatron-LM patch (same conflict-marker safety check as above).
RUN cd /home/ray/Megatron-LM && \
    cp /tmp/miles/docker/patch/${PATCH_VERSION}/megatron.patch ./megatron.patch && \
    git update-index --refresh && \
    git apply megatron.patch --3way && \
    if grep -R -n '^<<<<<<< ' .; then \
        echo "Megatron patch failed to apply cleanly. Please resolve conflicts." && \
        exit 1; \
    fi && \
    rm megatron.patch

# Install Miles dependencies (commits pinned for reproducibility).
RUN python -m pip install --no-cache-dir git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps && \
    python -m pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@dc6876905830430b5054325fa4211ff302169c6b --no-cache-dir --force-reinstall && \
    python -m pip install --no-cache-dir git+https://github.com/fzyzcjy/Megatron-Bridge.git@dev_rl --no-build-isolation && \
    python -m pip install --no-cache-dir "nvidia-modelopt[torch]>=0.37.0" --no-build-isolation

# Make MXFP8 quantizer import conditional — mxfp8_group_quantize was added
# in a newer SGLang than our pinned commit. Not needed for Qwen3-8B training.
RUN python -c "\
import pathlib; \
p = pathlib.Path('/tmp/miles/miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_mxfp8.py'); \
t = p.read_text(); \
t = t.replace( \
'from sglang.srt.layers.quantization.fp8_utils import mxfp8_group_quantize', \
'try:\\n    from sglang.srt.layers.quantization.fp8_utils import mxfp8_group_quantize\\nexcept ImportError:\\n    mxfp8_group_quantize = None' \
); \
p.write_text(t)"

# Install Miles itself.
RUN python -m pip install --no-cache-dir -r /tmp/miles/requirements.txt && \
    python -m pip install --no-cache-dir -e /tmp/miles --no-deps && \
    cd /tmp/miles/miles/backends/megatron_utils/kernels/int4_qat && \
    python -m pip install --no-cache-dir . --no-build-isolation

# Re-pin PyTorch 2.9.1 and reinstall flash-attn + TE at the end.
# Earlier installs may have upgraded torch, breaking pre-built binary wheels.
RUN python -c "import torch; print(f'Before re-pin: PyTorch {torch.__version__}')"
RUN python -m pip install --no-cache-dir torch==2.9.1 torchvision torchaudio \
    --index-url https://download.pytorch.org/whl/cu128
RUN python -m pip install --no-cache-dir --force-reinstall --no-deps \
    https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3%2Bcu12torch2.9cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
RUN python -m pip install --no-cache-dir --no-build-isolation "transformer_engine[pytorch]==2.10.0"

# Verify torch + flash-attn ABI compatibility.
# sgl_kernel is skipped here — it requires libcuda.so.1 (GPU hardware) to import.
RUN python -c "\
import torch; print(f'PyTorch: {torch.__version__}, CUDA: {torch.version.cuda}'); \
assert torch.__version__.startswith('2.9'), f'Expected 2.9.x, got {torch.__version__}'; \
from flash_attn import flash_attn_func; print('flash-attn OK')"

# Prepend Megatron-LM to PYTHONPATH. The ${PYTHONPATH:+:...} form avoids a
# trailing ':' when the base image sets no PYTHONPATH — an empty PYTHONPATH
# entry would make Python add the current working directory to sys.path.
ENV PYTHONPATH=/home/ray/Megatron-LM${PYTHONPATH:+:${PYTHONPATH}}

WORKDIR /tmp/miles
39 changes: 39 additions & 0 deletions miles_qwen3_8b_h100/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# GRPO Training for Qwen3-8B with MILES

This example demonstrates reinforcement learning fine-tuning of Qwen3-8B using **Group Relative Policy Optimization (GRPO)** on the DAPO-Math-17k dataset. It uses the [MILES](https://github.com/radixark/miles) framework for distributed RL training with disaggregated rollouts on Anyscale.

The training runs on a single node with **8x H100-80GB GPUs**, using:
- **4 GPUs for training** (TP=2, DP=2 with Megatron-LM)
- **4 GPUs for rollout inference** (disaggregated SGLang engines)

## Install the Anyscale CLI

```bash
pip install -U anyscale
anyscale login
```

## Submit the job

Clone the example from GitHub.

```bash
git clone https://github.com/anyscale/examples.git
cd examples/miles_qwen3_8b_h100
```

Submit the job.

```bash
anyscale job submit -f job.yaml
```

The entrypoint will automatically download the model and dataset, convert weights to Megatron format, and start training. Training progress can be monitored via TensorBoard logs in `/mnt/cluster_storage/tensorboard_logs`.

## Understanding the example

- **Algorithm**: This example uses GRPO with DAPO-style asymmetric clipping (ε_low=0.2, ε_high=0.28), which is particularly effective for math reasoning tasks.
- **Dataset**: [DAPO-Math-17k](https://huggingface.co/datasets/zhuzilin/dapo-math-17k) contains 17k integer math problems with deterministic reward signals based on answer correctness.
- **Disaggregated architecture**: Training and rollout happen on separate GPUs for maximum throughput. The 4 SGLang rollout engines run inference in parallel while the training GPUs perform gradient updates.
- **Weight conversion**: On the first run, HuggingFace weights are converted to Megatron-LM's `torch_dist` format. Converted weights are cached in `/mnt/cluster_storage/Qwen3-8B_torch_dist` for subsequent runs.
- **Async training**: The pipeline uses `train_async.py` which overlaps rollout generation and policy updates for better GPU utilization.
30 changes: 30 additions & 0 deletions miles_qwen3_8b_h100/convert_weights_remote.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#!/usr/bin/env python
"""Ray remote wrapper for weight conversion - ensures it runs on a GPU worker."""
import sys
import subprocess
import ray

@ray.remote(num_gpus=1)
def convert_weights(cmd_args):
    """Run the HF -> torch_dist conversion script on a GPU-equipped Ray worker.

    cmd_args: list of extra CLI arguments forwarded verbatim to the script.
    Returns a (returncode, stdout, stderr) tuple from the subprocess.
    """
    # Capture both streams so the driver can replay them for the user.
    proc = subprocess.run(
        ["python", "/tmp/miles/tools/convert_hf_to_torch_dist.py", *cmd_args],
        text=True,
        capture_output=True,
    )
    return proc.returncode, proc.stdout, proc.stderr

if __name__ == "__main__":
    # Forward every CLI argument untouched to the conversion task.
    forwarded = sys.argv[1:]

    # Block until the remote GPU worker finishes the conversion.
    rc, out, err = ray.get(convert_weights.remote(forwarded))

    # Replay the worker's output streams locally, then propagate its exit code.
    if out:
        print(out, end="")
    if err:
        print(err, end="", file=sys.stderr)

    sys.exit(rc)
147 changes: 147 additions & 0 deletions miles_qwen3_8b_h100/entrypoint.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
#!/bin/bash
# Anyscale entrypoint: Qwen3-8B GRPO training on 1 worker × 8x H100-80GB
# Downloads model/dataset, converts weights, and runs async RL training.
#
# Head node (m5.2xlarge): driver only, no GPUs
# Layout (GPU worker):
#   Worker 0 (8x H100):
#     GPU 0-3: Training (TP=2, DP=2)
#     GPU 4-7: Rollout (4 SGLang engines, 1 GPU each)

# -e: abort on first error; -u: treat unset variables as errors; -x: trace
# commands; pipefail: a failure anywhere in a pipeline fails the pipeline.
set -euxo pipefail

# Stream Python stdout/stderr unbuffered so logs appear in real time.
# (Fixes previous typo PYTHONBUFFERED, which Python ignores.)
export PYTHONUNBUFFERED=1
STORAGE=/mnt/cluster_storage

# Qwen3-8B model architecture args (from scripts/models/qwen3-8B.sh)
MODEL_ARGS=(
    --swiglu
    --num-layers 36
    --hidden-size 4096
    --ffn-hidden-size 12288
    --num-attention-heads 32
    --group-query-attention
    --num-query-groups 8
    --use-rotary-position-embeddings
    --disable-bias-linear
    --normalization "RMSNorm"
    --norm-epsilon 1e-6
    --rotary-base 1000000
    --vocab-size 151936
    --kv-channels 128
    --qk-layernorm
    --untie-embeddings-and-output-weights
)

# ======================== Step 1: Download model & dataset ========================
# huggingface-cli download is incremental, so reruns only fetch missing files.

echo "=== Downloading model ==="
huggingface-cli download Qwen/Qwen3-8B --local-dir "${STORAGE}/Qwen3-8B"

echo "=== Downloading dataset ==="
huggingface-cli download --repo-type dataset zhuzilin/dapo-math-17k --local-dir "${STORAGE}/dapo-math-17k"

# ======================== Step 2: Convert HF weights to torch_dist ========================
# Skipped when the cached conversion (iter_0000000) already exists on shared storage.

if [ ! -d "${STORAGE}/Qwen3-8B_torch_dist/iter_0000000" ]; then
    echo "=== Converting weights (HF -> torch_dist) on GPU worker ==="
    python convert_weights_remote.py \
        "${MODEL_ARGS[@]}" \
        --no-gradient-accumulation-fusion \
        --hf-checkpoint "${STORAGE}/Qwen3-8B" \
        --save "${STORAGE}/Qwen3-8B_torch_dist"
else
    echo "=== Converted weights already exist, skipping ==="
fi

# ======================== Step 3: Run training ========================

CKPT_ARGS=(
    --hf-checkpoint "${STORAGE}/Qwen3-8B"
    --ref-load "${STORAGE}/Qwen3-8B_torch_dist"
    --load "${STORAGE}/Qwen3-8B_torch_dist"
    --save "${STORAGE}/Qwen3-8B_miles/"
    --save-interval 20
)

ROLLOUT_ARGS=(
    --prompt-data "${STORAGE}/dapo-math-17k/dapo-math-17k.jsonl"
    --input-key prompt
    --label-key label
    --apply-chat-template
    --rollout-shuffle
    --balance-data
    --rm-type dapo
    --reward-key score
    --num-rollout 5
    --rollout-batch-size 32
    --n-samples-per-prompt 8
    --rollout-max-response-len 8192
    --rollout-temperature 1
    --global-batch-size 256
)

PERF_ARGS=(
    --tensor-model-parallel-size 2
    --sequence-parallel
    --pipeline-model-parallel-size 1
    --context-parallel-size 1
    --expert-model-parallel-size 1
    --expert-tensor-parallel-size 1

    --recompute-granularity full
    --recompute-method uniform
    --recompute-num-layers 1

    --use-dynamic-batch-size
    --max-tokens-per-gpu 9216
)

GRPO_ARGS=(
    --advantage-estimator grpo
    --use-kl-loss
    --kl-loss-coef 0.00
    --kl-loss-type low_var_kl
    --entropy-coef 0.00
    --eps-clip 0.2
    --eps-clip-high 0.28
)

OPTIMIZER_ARGS=(
    --optimizer adam
    --lr 1e-6
    --lr-decay-style constant
    --weight-decay 0.1
    --adam-beta1 0.9
    --adam-beta2 0.98
)

SGLANG_ARGS=(
    --rollout-num-gpus-per-engine 1
    --sglang-mem-fraction-static 0.7
)

MISC_ARGS=(
    --no-gradient-accumulation-fusion
    --attention-dropout 0.0
    --hidden-dropout 0.0
    --accumulate-allreduce-grads-in-fp32
    --attention-softmax-in-fp32
    --attention-backend flash
    --use-tensorboard
    --tensorboard-dir "${STORAGE}/tensorboard_logs"
)

echo "=== Starting training ==="
# Quoted "${ARR[@]}" expansions keep each argument intact even if it
# ever contains whitespace.
python train_remote.py \
    --actor-num-nodes 1 \
    --actor-num-gpus-per-node 4 \
    --rollout-num-gpus 4 \
    "${MODEL_ARGS[@]}" \
    "${CKPT_ARGS[@]}" \
    "${ROLLOUT_ARGS[@]}" \
    "${OPTIMIZER_ARGS[@]}" \
    "${GRPO_ARGS[@]}" \
    "${PERF_ARGS[@]}" \
    "${SGLANG_ARGS[@]}" \
    "${MISC_ARGS[@]}"
40 changes: 40 additions & 0 deletions miles_qwen3_8b_h100/job.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Anyscale job config: Miles Qwen3-8B GRPO training on H100
# Single node × 8x H100-80GB
#
# Layout:
# Head node (m5.2xlarge): driver only, no GPUs
# Worker 0 (8x H100): [GPU 0-3: Training TP=2 DP=2] [GPU 4-7: Rollout (4 engines)]
#
# Submit with:
# cd miles_qwen3_8b_h100
# anyscale job submit -f job.yaml

name: miles-qwen3-8b-grpo-h100

# Build the cluster image from the Dockerfile in this directory.
containerfile: ./Dockerfile

compute_config:
  # Head node runs only the driver script — no GPU resources requested.
  head_node:
    required_resources:
      CPU: 8
      memory: 32Gi
  worker_nodes:
    # One 8x H100 worker hosts both training (4 GPUs) and rollout (4 GPUs).
    - name: h100-workers
      required_resources:
        CPU: 192
        memory: 2048Gi
        GPU: 8
      required_labels:
        ray.io/accelerator-type: H100
      # Fixed-size pool: exactly one worker, no autoscaling.
      min_nodes: 1
      max_nodes: 1

# Upload this directory so entrypoint.sh and the helper scripts are available.
working_dir: .

entrypoint: bash entrypoint.sh

env_vars:
  # Required by Megatron-LM when sequence parallelism is enabled.
  CUDA_DEVICE_MAX_CONNECTIONS: "1"

# Fail fast: no retries, and kill the job after 2 hours (7200 s).
max_retries: 0
timeout_s: 7200
Loading