diff --git a/cookbook/rl/grpo.py b/cookbook/rl/grpo.py
index d7d5df21..9faa6f5b 100644
--- a/cookbook/rl/grpo.py
+++ b/cookbook/rl/grpo.py
@@ -21,7 +21,7 @@
logger = get_logger()
MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
-USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '0')))
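+# Default to the Megatron backend; set USE_MEGATRON=0 to fall back to the Transformers backend.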
+USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1')))
MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS',4))
@@ -31,15 +31,16 @@
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
LEARNING_RATE = float(os.environ.get('LR', 1e-5))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 200))
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16)) # global prompt-level, global completion-level batch size = BATCH_SIZE * num_generations * dp_size
-MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 16)) # global completion-level mini-batch-size
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # global prompt-level batch size; global completion-level batch size = BATCH_SIZE * num_generations * dp_size
+MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 8)) # global completion-level mini-batch size
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) # per-device-micro-batch-size (completion-level), batch_size in forward_backward
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
ADAPTER_NAME = 'default'
+SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 50))
def create_gsm8k_dataset():
dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
- dataset.set_template('Template', model_id=MODEL_ID, max_length=2048)
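+ # GSM8K prompts are short; 400 tokens is enough and keeps the sampler's max_model_len (400 + MAX_NEW_TOKENS) small.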
+ dataset.set_template('Template', model_id=MODEL_ID, max_length=400)
dataset.map(GSM8KProcessor())
dataset.encode(add_generation_prompt=True)
return dataset
@@ -68,13 +69,21 @@ def main():
sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)
twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)
- lora_config = LoraConfig(target_modules='all-linear', r=32, lora_alpha=64, lora_dropout=0.05)
-
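+ # Enumerate LoRA target modules explicitly (instead of 'all-linear'); the
+ # in_proj_* / out_proj entries cover Qwen3.5's extra projection layers.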
+ lora_config = LoraConfig(
+ target_modules=[
+ 'q_proj', 'k_proj', 'v_proj', 'o_proj',
+ 'gate_proj', 'up_proj', 'down_proj',
+ 'in_proj_qkv', 'in_proj_z', 'in_proj_a', 'in_proj_b', 'out_proj',
+ ],
+ r=32, lora_alpha=64, lora_dropout=0.05,
+ )
if USE_MEGATRON:
from twinkle.model.megatron import MegatronModel
model = MegatronModel(model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model', mixed_precision='bf16')
else:
- model = TransformersModel(model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model')
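+ # Pass the Qwen3.5 model class explicitly via model_cls rather than relying on automatic class resolution.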
+ from transformers import Qwen3_5ForConditionalGeneration
+ model = TransformersModel(model_id=MODEL_ID, model_cls=Qwen3_5ForConditionalGeneration, device_mesh=model_mesh, remote_group='model')
model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1)
if USE_MEGATRON:
@@ -91,8 +100,9 @@ def main():
model_id=MODEL_ID,
engine_args={
'gpu_memory_utilization': 0.8,
- 'max_model_len': 4096,
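+ # max_model_len = template max_length (400) + MAX_NEW_TOKENS (4096)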
+ 'max_model_len': 4496,
'max_lora_rank': 32, # same as lora_config
+ # NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976
'enable_lora': True,
},
device_mesh=sampler_mesh,
@@ -172,6 +182,8 @@ def main():
if optim_step >= MAX_STEPS:
break
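+ # Save a checkpoint every SAVE_STEPS optimizer steps.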
+ if optim_step % SAVE_STEPS == 0:
+ model.save(f'grpo-gsm8k-checkpoint-{optim_step}')
log_dict = metrics.calculate()
log_dict.update(model.calculate_metric(is_training=True))
metrics.reset()
diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py
index ddddf41c..60cc30a8 100644
--- a/src/twinkle/model/megatron/megatron.py
+++ b/src/twinkle/model/megatron/megatron.py
@@ -1587,7 +1587,7 @@ def _trim_vocab(name, tensor):
if base_sync_done and adapter_name:
if merge_and_sync:
-
+ # LoRA training: sync the full merged model (merge_adapter).
def weight_generator():
for _model in self.strategy.unwrap_model(self.model):
if isinstance(_model, PeftModel):
@@ -1616,7 +1616,7 @@ def weight_generator():
yield name, tensor
else:
-
+ # First full base-model sync.
def _raw_weights():
for name, tensor in self.get_hf_state_dict(adapter_name=''):
if name is None or tensor is None:
@@ -1627,7 +1627,7 @@ def _raw_weights():
yield _trim_vocab(name, tensor)
def weight_generator():
- if is_peft_format:
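+ # Add the .base_layer suffix only when the sampler keeps LoRA enabled
+ # (merge_and_sync=False); merged full-model syncs use plain HF names.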
+ if is_peft_format and not merge_and_sync:
yield from _add_base_layer_suffix(_raw_weights())
else:
yield from _raw_weights()
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index d00e80ed..48f08039 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -1159,21 +1159,28 @@ def send_weights(
# Get state dict from unwrapped model
model = self.strategy.unwrap_model(self.model)
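+ # Helpers to map PEFT state-dict names back to canonical HF names, e.g.
+ # 'base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight'
+ # -> 'model.layers.0.self_attn.q_proj.weight' (with keep_base_layer=False).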
+ def _normalize(name: str, keep_base_layer: bool) -> str:
+ name = name.replace('base_model.model.', '')
+ if not keep_base_layer:
+ name = name.replace('.base_layer', '')
+ return name
+
+ def _is_lora_key(name: str) -> bool:
+ return 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name
+
if base_sync_done and adapter_name:
if merge_and_sync:
-
+ # LoRA training: sync the full merged model (merge_adapter).
+ # Merge the adapter, skip LoRA weights (already merged into the base),
+ # and trim the prefix (base_model.model.) and suffix (.base_layer).
def weight_generator():
if isinstance(model, PeftModel):
model.merge_adapter()
for name, tensor in model.state_dict().items():
- # Skip LoRA-specific weights for base model sync
- if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name:
+ if _is_lora_key(name):
continue
tensor = Torch.to_local_tensor(tensor)
- # Keep original names (including .base_layer for PEFT models).
- # The sampler side will strip .base_layer based on whether
- # vLLM has enable_lora=True/False.
- yield name, tensor
+ yield _normalize(name, keep_base_layer=False), tensor
if isinstance(model, PeftModel):
model.unmerge_adapter()
else:
@@ -1188,19 +1195,19 @@ def weight_generator():
yield name, tensor
else:
- # Full model mode: send all weights (base model sync).
+ # First full base-model sync. Whether to keep ``.base_layer.``
+ # depends on whether the sampler uses ``enable_lora``:
+ # merge_and_sync=True → enable_lora=False → strip .base_layer
+ # merge_and_sync=False → enable_lora=True → keep .base_layer
+ keep_base_layer = not merge_and_sync
state_dict = model.state_dict()
def weight_generator():
for name, tensor in state_dict.items():
- # Skip LoRA-specific weights for base model sync
- if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name:
+ if _is_lora_key(name):
continue
tensor = Torch.to_local_tensor(tensor)
- # Keep original names (including .base_layer for PEFT models).
- # The sampler side will strip .base_layer based on whether
- # vLLM has enable_lora=True/False.
- yield name, tensor
+ yield _normalize(name, keep_base_layer=keep_base_layer), tensor
# Run async send_weights in a dedicated event loop thread.
# We cannot use the Ray worker's event loop because it may already
diff --git a/src/twinkle/preprocessor/llm.py b/src/twinkle/preprocessor/llm.py
index a451e90c..97065fba 100644
--- a/src/twinkle/preprocessor/llm.py
+++ b/src/twinkle/preprocessor/llm.py
@@ -122,10 +122,8 @@ class GSM8KProcessor(Preprocessor):
Extracts the ground truth number and stores it in user_data for reward.
Only includes system + user messages; assistant response is generated on-policy.
"""
- system_prompt = ('You are a helpful math assistant. Solve the problem step by step. '
- 'Show your reasoning in <think> tags, then give the final '
- 'numerical answer after ####.\n'
- 'For example:\n<think> ... reasoning ... </think>\n#### 42')
+ system_prompt = ('You are a helpful math assistant. Solve the problem step by step '
+ 'and put your final answer within \\boxed{}.')
def __init__(self, system=None, add_assistant=False):
self.system = system
diff --git a/src/twinkle/reward/gsm8k.py b/src/twinkle/reward/gsm8k.py
index 1f0f14b9..eb439675 100644
--- a/src/twinkle/reward/gsm8k.py
+++ b/src/twinkle/reward/gsm8k.py
@@ -7,15 +7,17 @@
class GSM8KAccuracyReward(Reward):
"""Accuracy reward for GSM8K: checks if the model's answer matches ground truth.
- Extracts the last '#### <number>' from model output and compares with ground truth.
+ Extracts the answer from \\boxed{} (preferred) or #### format.
Returns 1.0 for correct, 0.0 for incorrect.
"""
@staticmethod
def extract_answer(completion: str) -> str:
- """Extract the last #### answer from model completion."""
- # Only check last 500 chars for efficiency
+ """Extract the answer from model completion, preferring \\boxed{} over ####."""
text = completion[-500:] if len(completion) > 500 else completion
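+ # e.g. '... so the answer is \boxed{1,234}.' -> '1234'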
+ boxed = re.findall(r'\\boxed\{([^}]+)\}', text)
+ if boxed:
+ return boxed[-1].replace(',', '').replace(' ', '').strip()
matches = re.findall(r'####\s*([\-\d,\.\s]+)', text)
if matches:
return matches[-1].replace(',', '').replace(' ', '').strip()
@@ -54,9 +56,9 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
class GSM8KFormatReward(Reward):
- """Format reward: checks if output contains ... tag.
+ """Format reward: checks if output contains \\boxed{} or #### answer format.
- Returns 1.0 if format is correct, 0.0 otherwise.
+ Returns 1.0 if a valid answer format is present, 0.0 otherwise.
"""
def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
@@ -68,7 +70,6 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
if msg.get('role') == 'assistant':
completion = msg.get('content', '')
break
- has_think = bool(re.search(r'<think>.*?</think>', completion, re.DOTALL))
- has_answer = bool(re.search(r'####\s*[\-\d,\.]+', completion))
- rewards.append(1.0 if (has_think and has_answer) else 0.0)
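+ # Either \boxed{...} or '#### <number>' counts as a well-formed answer.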
+ has_answer = bool(re.search(r'\\boxed\{[^}]+\}', completion) or re.search(r'####\s*[\-\d,\.]+', completion))
+ rewards.append(1.0 if has_answer else 0.0)
return rewards
diff --git a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py
index 42be5095..61920cd9 100644
--- a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py
+++ b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py
@@ -390,11 +390,17 @@ def _load_weights(
"""Load a batch of weights into vLLM.
Two modes:
- - LoRA mode (``peft_config`` and ``base_sync_done``): Loads weights as
- a tensor-based LoRA adapter via ``add_lora()``.
- - Base model mode: Strips PEFT prefixes, merges split weights
- (q/k/v_proj -> qkv_proj, gate/up_proj -> gate_up_proj) into vLLM's
- stacked format, normalizes prefixes, then loads via direct param copy.
+
+ * **LoRA mode** (``peft_config`` set and ``base_sync_done=True``):
+ loads weights as a tensor-based LoRA adapter via ``add_lora()``.
+ * **Base model mode** (all other cases): delegates to
+ ``model.load_weights()`` which handles stacked-parameter merging
+ (q/k/v → qkv, gate/up → gate_up) and prefix mapping internally.
+
+ Weight names are expected to arrive **already normalised** by the
+ sender (``TransformersModel.send_weights`` /
+ ``MegatronModel.send_weights``), so no name transformation is done
+ here.
"""
if peft_config and base_sync_done:
# Remove existing LoRA before replacing
@@ -412,51 +418,9 @@ def _load_weights(
)
self.add_lora(lora_request)
else:
- # Base model mode — strip PEFT prefixes and delegate to
- # vLLM's model.load_weights() which handles stacked params,
- # prefix normalization, and weight_loader internally.
- vllm_has_lora = getattr(
- getattr(self, 'vllm_config', None),
- 'lora_config',
- None,
- ) is not None
-
- # When vLLM LoRA is enabled, some LinearBase modules are
- # replaced by *WithLoRA wrappers. Their parameters shift
- # from e.g. ``gate.weight`` to ``gate.base_layer.weight``.
- # HF checkpoint names do NOT contain ``.base_layer.``, so
- # vLLM's own ``load_weights`` will KeyError on them.
- #
- # Build a set of base-layer prefixes that need rewriting.
- lora_base_prefixes: set = set()
- if vllm_has_lora:
- from vllm.lora.layers import BaseLayerWithLoRA
- for mod_name, mod in self.model_runner.model.named_modules():
- if isinstance(mod, BaseLayerWithLoRA):
- # mod_name is e.g. "model.layers.0.mlp.gate"
- lora_base_prefixes.add(mod_name + '.')
-
- converted = []
- for name, tensor in weights:
- if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name:
- continue
- name = name.removeprefix('model.base_model.model.')
- name = name.removeprefix('base_model.model.')
- if not vllm_has_lora:
- name = name.replace('.base_layer.', '.')
- else:
- # Insert ``.base_layer.`` for weights whose module
- # has been wrapped by LoRA and whose name does NOT
- # already contain it.
- if '.base_layer.' not in name:
- for pfx in lora_base_prefixes:
- if name.startswith(pfx):
- # e.g. "model.layers.0.mlp.gate.weight"
- # → "model.layers.0.mlp.gate.base_layer.weight"
- suffix = name[len(pfx):]
- name = pfx + 'base_layer.' + suffix
- break
- converted.append((name, tensor))
+ # Base model mode — weights arrive in canonical HF format; just filter out LoRA-specific tensors.
+ converted = [(n, t) for n, t in weights
+ if 'lora_A' not in n and 'lora_B' not in n and 'lora_embedding' not in n]
if not converted:
return