Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions skyrl-train/examples/megatron/run_megatron_glm_flash_4.7.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
set -x

# Colocated GRPO training+generation for zai-org/GLM-4.7-Flash on GSM8K with Megatron.
# Runs on 2 nodes of 8xH100s

# Prerequisites:
#   uv run examples/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
#   export WANDB_API_KEY=<your_key_here>
# Usage:
#   bash examples/megatron/run_megatron_glm_flash_4.7.sh

DATA_DIR="$HOME/data/gsm8k"
LOGGER="wandb" # change to "console" to print to stdout
MODEL_NAME="zai-org/GLM-4.7-Flash"

INFERENCE_BACKEND="vllm" # currently only vllm is supported for megatron

NUM_NODES=2
NUM_GPUS=8

# Megatron model-parallel layout: tensor / pipeline / context / expert /
# expert-tensor parallel sizes. TP*PP must divide NUM_NODES*NUM_GPUS.
MEGATRON_TP=4
MEGATRON_PP=2
MEGATRON_CP=1
MEGATRON_EP=8
MEGATRON_ETP=1

NUM_INFERENCE_ENGINES=2
INFERENCE_ENGINE_TP=8
FLASH_ATTN=true

# Megatron gradient checkpointing config
RECOMPUTE_GRANULARITY="full"
RECOMPUTE_METHOD="uniform"
RECOMPUTE_NUM_LAYERS=1

uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.algorithm.advantage_estimator="grpo" \
trainer.policy.model.path=$MODEL_NAME \
trainer.placement.colocate_all=true \
trainer.strategy=megatron \
trainer.placement.policy_num_nodes=$NUM_NODES \
trainer.placement.ref_num_nodes=$NUM_NODES \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
generator.num_inference_engines=$NUM_INFERENCE_ENGINES \
generator.inference_engine_tensor_parallel_size=$INFERENCE_ENGINE_TP \
trainer.policy.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \
trainer.policy.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \
trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
trainer.ref.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \
trainer.ref.megatron_config.context_parallel_size=$MEGATRON_CP \
trainer.ref.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \
trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
trainer.ref.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
trainer.ref.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
trainer.policy.megatron_config.transformer_config_kwargs.recompute_granularity=$RECOMPUTE_GRANULARITY \
trainer.policy.megatron_config.transformer_config_kwargs.recompute_method=$RECOMPUTE_METHOD \
trainer.policy.megatron_config.transformer_config_kwargs.recompute_num_layers=$RECOMPUTE_NUM_LAYERS \
trainer.use_sample_packing=true \
trainer.flash_attn=$FLASH_ATTN \
trainer.epochs=20 \
trainer.eval_batch_size=1024 \
trainer.eval_before_train=false \
trainer.eval_interval=5 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=128 \
trainer.policy_mini_batch_size=64 \
trainer.micro_forward_batch_size_per_gpu=4 \
trainer.micro_train_batch_size_per_gpu=4 \
trainer.ckpt_interval=10 \
trainer.max_prompt_length=512 \
generator.sampling_params.max_generate_length=1024 \
trainer.policy.optimizer_config.lr=1.0e-6 \
trainer.algorithm.use_kl_loss=true \
generator.backend=$INFERENCE_BACKEND \
generator.run_engines_locally=true \
generator.weight_sync_backend=nccl \
generator.async_engine=true \
generator.batched=true \
environment.env_class=gsm8k \
generator.n_samples_per_prompt=5 \
generator.gpu_memory_utilization=0.6 \
trainer.logger="$LOGGER" \
trainer.project_name="gsm8k_megatron" \
trainer.run_name="gsm8k_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_glm_4.7_flash" \
trainer.resume_mode=null \
trainer.ckpt_path="$HOME/ckpts/gsm8k_megatron_ckpt" \
"$@"
22 changes: 12 additions & 10 deletions skyrl-train/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"ninja",
"tensorboard",
"func_timeout",
"transformers>=4.51.0",
"transformers",
"hydra-core==1.3.2",
"accelerate",
"torchdata",
Expand Down Expand Up @@ -81,6 +81,8 @@ override-dependencies = [
"causal-conv1d; sys_platform == 'never'",
"transformer-engine[pytorch]==2.10.0; sys_platform == 'linux'",
"megatron-core==0.15.0; sys_platform == 'linux'",
"transformers==5.0.0; sys_platform == 'linux'",
"opencv-python-headless<4.13.0; sys_platform == 'linux'"
]
[tool.uv.extra-build-dependencies]
flash-attn = [{requirement = "torch", match-runtime = true}]
Expand Down Expand Up @@ -108,7 +110,7 @@ flashinfer-jit-cache = { index = "flashinfer-cu128", marker = "extra == 'vllm' o
flashinfer-python = [
{ url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'sglang' and extra != 'mcore' and extra != 'vllm'" }
]
megatron-bridge = {git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge", rev = "04e370eedf8cc44a812189a19f2171d90555c07a", marker = "sys_platform == 'linux'"}
megatron-bridge = {git = "https://github.com/erictang000/Megatron-Bridge", branch = "glm_flash_4.7", marker = "sys_platform == 'linux'"}


[project.optional-dependencies]
Expand All @@ -132,11 +134,11 @@ harbor = [
"harbor",
]
vllm = [
"vllm==0.13.0; sys_platform == 'linux'",
"vllm==0.15.0; sys_platform == 'linux'",
"flash-attn==2.8.3; sys_platform == 'linux'",
"torch==2.9.0; sys_platform == 'linux'",
"flashinfer-python; sys_platform == 'linux'",
"flashinfer-jit-cache==0.5.3; sys_platform == 'linux'",
"torch==2.9.1; sys_platform == 'linux'",
"flashinfer-python==0.6.1; sys_platform == 'linux'",
"flashinfer-jit-cache==0.6.1; sys_platform == 'linux'",
"torchvision; sys_platform == 'linux'",
]
sglang = [
Expand All @@ -149,13 +151,13 @@ sglang = [
mcore = [
"transformer-engine[pytorch]==2.10.0; sys_platform == 'linux'",
"flash-attn==2.8.1; sys_platform == 'linux'",
"vllm==0.13.0; sys_platform == 'linux'",
"torch==2.9.0; sys_platform == 'linux'",
"flashinfer-python==0.5.3; sys_platform == 'linux'",
"vllm==0.15.0; sys_platform == 'linux'",
"torch==2.9.1; sys_platform == 'linux'",
"flashinfer-python==0.6.1; sys_platform == 'linux'",
"torchvision; sys_platform == 'linux'",
"megatron-bridge; sys_platform == 'linux'",
"megatron-core==0.15.0; sys_platform == 'linux'",
"flashinfer-jit-cache==0.5.3; sys_platform == 'linux'",
"flashinfer-jit-cache==0.6.1; sys_platform == 'linux'",
"nvidia-modelopt; sys_platform == 'linux'",
]
flashrl = [
Expand Down
15 changes: 8 additions & 7 deletions skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@
from types import SimpleNamespace
from vllm import SamplingParams
from vllm.inputs import TokensPrompt
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.models.serving import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ErrorResponse,
CompletionRequest,
CompletionResponse,
)
from vllm.entrypoints.openai.engine.protocol import ErrorResponse

from vllm.entrypoints.openai.completion.protocol import CompletionRequest, CompletionResponse

from vllm.lora.request import LoRARequest
from uuid import uuid4
from skyrl_train.inference_engines.base import (
Expand Down
43 changes: 22 additions & 21 deletions skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# TODO (erictang000): we would prefer to use this smaller MoE model for testing, but seeing incorrect logprobs when using EP > 1
# this might be a model specific mbridge issue - see if this persists when we transition to Megatron-Bridge
# MOE_MODEL_NAME = "Qwen/Qwen1.5-MoE-A2.7B"
MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B"
MOE_MODEL_NAME = "zai-org/GLM-4.7-Flash"


def get_test_actor_config(model_name=MODEL_NAME) -> SkyRLConfig:
Expand Down Expand Up @@ -110,19 +110,19 @@ def get_test_training_batch(batch_size=4) -> TrainingInputBatch:


@pytest.mark.parametrize(
("colocate_all", "inference_tp", "megatron_tp", "megatron_pp", "megatron_ep", "megatron_etp", "lora"),
[(True, 4, 2, 2, 1, None, False), (False, 2, 2, 1, 1, None, False), (True, 4, 2, 2, 1, None, True)],
ids=["colocate_all", "non_colocated", "colocate_all_lora"],
("colocate_all", "inference_tp", "megatron_tp", "megatron_pp", "megatron_ep", "megatron_etp", "num_gpus_per_node", "lora", "model_name"),
[(True, 4, 4, 1, 8, 1, 8, False, MOE_MODEL_NAME), (False, 2, 2, 1, 1, None, 2, False, MODEL_NAME), (True, 4, 2, 2, 1, None, 4, True, MOE_MODEL_NAME)],
ids=["x", "non_colocated", "colocate_all_lora"],
)
@pytest.mark.megatron
def test_megatron_policy_weight_sync(
colocate_all, inference_tp, megatron_tp, megatron_pp, megatron_ep, megatron_etp, lora
colocate_all, inference_tp, megatron_tp, megatron_pp, megatron_ep, megatron_etp, num_gpus_per_node, lora, model_name
):
"""
Test that we can sync weights between policy and inference for megatron then run inference
"""
try:
cfg = get_test_actor_config(model_name=MODEL_NAME)
cfg = get_test_actor_config(model_name=model_name)
if lora:
cfg.trainer.policy.model.lora = SkyRLLoraConfig(rank=16, alpha=16)
cfg.trainer.placement.colocate_all = colocate_all
Expand All @@ -139,12 +139,13 @@ def test_megatron_policy_weight_sync(

# If colocate is True, this will load the engine, sleep, and wake up the engine
client, pg, router, server_group = init_inference_engines(
model=MODEL_NAME,
model=model_name,
cfg=cfg,
use_local=True,
async_engine=cfg.generator.async_engine,
tp_size=cfg.generator.inference_engine_tensor_parallel_size,
colocate_all=cfg.trainer.placement.colocate_all,
num_inference_engines=num_gpus_per_node // cfg.generator.inference_engine_tensor_parallel_size,
backend="vllm",
sleep_level=2, # since we explicitly sync weights
)
Expand All @@ -155,7 +156,7 @@ def test_megatron_policy_weight_sync(
"policy",
shared_pg=pg,
colocate_all=cfg.trainer.placement.colocate_all,
num_gpus_per_node=cfg.generator.inference_engine_tensor_parallel_size,
num_gpus_per_node=num_gpus_per_node,
cfg=cfg,
)
ray.get(policy.async_run_ray_method("pass_through", "init_weight_sync_state", client))
Expand All @@ -170,7 +171,7 @@ def test_megatron_policy_weight_sync(
policy.offload_to_cpu()
asyncio.run(client.wake_up(tags=["kv_cache"]))
sampling_params = get_sampling_params_for_backend(cfg.generator.backend, cfg.generator.sampling_params)
outputs = asyncio.run(run_inference(client, get_test_prompts(MODEL_NAME), sampling_params))
outputs = asyncio.run(run_inference(client, get_test_prompts(model_name), sampling_params))

print(f"Example output: {outputs['responses'][0]}, {outputs['stop_reasons'][0]}")
finally:
Expand All @@ -179,17 +180,17 @@ def test_megatron_policy_weight_sync(

@pytest.mark.asyncio
@pytest.mark.parametrize(
("worker_type", "tp", "pp", "cp", "ep", "etp", "gpus_per_node", "use_sample_packing", "lora"),
("worker_type", "tp", "pp", "cp", "ep", "etp", "gpus_per_node", "use_sample_packing", "lora", "model_name"),
[
("policy", 2, 1, 1, 1, None, 2, False, False),
("policy", 2, 1, 1, 1, None, 2, False, False, MODEL_NAME),
# ref has same forward pass as policy - just duplicate one test to test setup
("ref", 2, 1, 1, 1, None, 2, False, False),
("policy", 2, 2, 1, 1, None, 4, False, False),
("policy", 2, 2, 1, 1, None, 4, True, False),
("policy", 2, 2, 1, 1, None, 4, True, True),
("policy", 1, 1, 2, 1, None, 2, True, False),
("policy", 2, 1, 2, 1, None, 4, True, False),
("policy", 4, 1, 1, 4, 1, 4, True, False),
("ref", 2, 1, 1, 1, None, 2, False, False, MODEL_NAME),
("policy", 2, 2, 1, 1, None, 4, False, False, MODEL_NAME),
("policy", 2, 2, 1, 1, None, 4, True, False, MODEL_NAME),
("policy", 2, 2, 1, 1, None, 4, True, True, MODEL_NAME),
("policy", 1, 1, 2, 1, None, 2, True, False, MODEL_NAME),
("policy", 2, 1, 2, 1, None, 4, True, False, MODEL_NAME),
("policy", 4, 1, 1, 4, 1, 4, True, False, MOE_MODEL_NAME),
],
ids=[
"tp2_pp1_policy",
Expand All @@ -204,12 +205,12 @@ def test_megatron_policy_weight_sync(
)
@pytest.mark.megatron
async def test_megatron_forward(
ray_init_fixture, worker_type, tp, pp, cp, ep, etp, gpus_per_node, use_sample_packing, lora
ray_init_fixture, worker_type, tp, pp, cp, ep, etp, gpus_per_node, use_sample_packing, lora, model_name
):
"""
Test that the Megatron forward pass is numerically equivalent to just running a huggingface model forward.
"""
cfg = get_test_actor_config(model_name=MOE_MODEL_NAME if ep > 1 else MODEL_NAME)
cfg = get_test_actor_config(model_name=model_name)
#### Megatron forward pass ####
cfg.trainer.strategy = "megatron"
cfg.trainer.placement.policy_num_gpus_per_node = gpus_per_node
Expand Down Expand Up @@ -278,7 +279,7 @@ def run_hf_forward(batch, model_name):
return attention_mask.to("cpu").detach(), action_log_probs.to("cpu").detach(), num_actions

attention_mask, action_log_probs, num_actions = ray.get(
run_hf_forward.remote(batch, MOE_MODEL_NAME if ep > 1 else MODEL_NAME)
run_hf_forward.remote(batch, model_name)
)

#### Compare results ####
Expand Down
Loading
Loading