
PaST — Parametric Skill Transfer

Official code release for our ACL 2026 paper:

Knowledge is Not Enough: Injecting RL Skills for Continual Adaptation
Pingzhi Tang, Yiding Wang, Muhan Zhang.
Annual Meeting of the Association for Computational Linguistics (ACL), 2026. arXiv:2601.11258

PaST (Parametric Skill Transfer) extracts a domain-agnostic skill vector — the parameter delta produced by RL post-training on a source domain — and injects it into a target model that has only been lightly SFT-ed on a new domain. This recovers the reasoning skills for which RL would otherwise have to be repeated on each new domain, while keeping the freshly learned target-domain knowledge intact.

Repository scope

This repository contains the code for two of the experimental settings in the paper:

| Setting | Location | Paper section | Description |
| --- | --- | --- | --- |
| LooGLE (long-context QA) | repo root (this README) | §5.1.2 | Per-passage two-stage SFT + GRPO + skill-vector inheritance on the LooGLE benchmark. |
| SQuAD (closed-book knowledge incorporation, on top of the SEAL framework) | SEAL/ | §5.1.1 | Implications-based SFT + GRPO with a judge reward, evaluated with delta-weight injection during test-time training. Built on SEAL (Zweiger et al., 2026). |

The rest of this README documents the LooGLE pipeline. For the SQuAD setting see SEAL/README.md and SEAL/general-knowledge/README_GRPO_DELTA.md.


Table of contents

  1. Repository layout
  2. Setup
  3. Pipeline overview
  4. Step-by-step usage
  5. The PaST operator
  6. Evaluation
  7. Citation

Repository layout

PaST/
├── inherit_weight.py        # Core PaST operator: subtract base from RL model,
│                              add the resulting skill vector onto target models.
├── calc_delta_and_save.py   # Save (RL − base) deltas to a `.pth` file.
├── reward_function.py       # verl-compatible reward; queries a judge HTTP server.
├── reward_server.py         # FastAPI judge server (OpenAI / vLLM backend).
├── prompts.py               # All prompt templates (proposer / solver / judge).
├── utils.py                 # Long-text chunking and parsing helpers.
├── merge_lora.py            # Helper: merge a LoRA adapter into its base model.
│
├── prepare_data/            # Data generation
│   ├── generate_mixed_data.py     # Stage-1 SFT data (summary / recall / continue).
│   ├── generate_qa.py             # QA-pair generation for stage-2 SFT and GRPO.
│   ├── loogle_to_parquet.py       # Convert LooGLE jsonl → verl-format parquet.
│   ├── preprocess_squad.py        # SQuAD preprocessing.
│   ├── preprocess_narrativeqa.py  # NarrativeQA preprocessing.
│   ├── preprocess_gsm8k.py        # GSM8K preprocessing.
│   └── mix_data.py                # Concatenate per-passage parquet shards.
│
├── eval_loogle/             # LooGLE evaluation
│   ├── common.py              # Shared loaders (model / tokenizer / data).
│   ├── eval_with_remote.py    # Generate + judge in one process.
│   ├── generate_answers.py    # Generate answers only (decoupled from judging).
│   ├── score_answers.py       # Score generated answers locally.
│   ├── score_answers_with_server.py  # Score against the judge server.
│   └── judge_utils.py
│
├── analysis/                # Skill-vector analysis
│   ├── extract_delta.py     # Extract per-layer/per-tensor delta statistics.
│   ├── heatmap.py           # Visualize the delta as a heatmap.
│   └── save_cosine.py       # Cosine similarity between deltas across runs.
│
├── scripts/
│   ├── gen_data/            # Data-generation entry points.
│   │   ├── gen_sft_data.sh
│   │   ├── gen_rl_data.sh
│   │   └── process_loogle_eval.sh
│   ├── train/               # Training entry points.
│   │   ├── sft.sh             # Stage-1 + Stage-2 SFT for one passage.
│   │   ├── all_sft.sh         # Loop SFT over many passages.
│   │   ├── sft_iter.sh        # Iterative SFT loop.
│   │   ├── grpo.sh            # GRPO on one passage.
│   │   ├── grpo_single.sh     # GRPO with a different schedule.
│   │   ├── grpo_round10.sh    # Multi-round GRPO with PaST inheritance.
│   │   ├── start_judge_server.sh  # Launch the FastAPI judge server.
│   │   └── merge_data.sh
│   ├── eval/
│   │   └── eval_sft.sh
│   ├── eval_gsm8k.sh
│   ├── eval_with_context.sh
│   ├── merge_and_eval.sh    # PaST inheritance + LooGLE evaluation.
│   ├── merge_fsdp_to_hf.sh  # Convert FSDP checkpoints to HF format.
│   ├── run_train_pipeline.sh
│   ├── show_results.py      # Aggregate test_results.json into a markdown table.
│   └── test_merge.sh
│
├── debug/                   # Small inspection / sanity-check scripts.
├── loogle-clean/            # Cleaned LooGLE benchmark (data + reference scripts).
├── SEAL/                    # SQuAD experiment built on the SEAL framework
│                              (paper §5.1.1). See SEAL/README.md.
├── env.yaml                 # Conda environment (`past`).
├── flashattn_requirements.txt
└── README.md

Runtime / output directories (created by the scripts, all gitignored): data/, data-new/, checkpoints/, outputs/, wandb/, eval_results/, results_combined/, viz_output/.


Setup

1. Clone the repo

git clone https://github.com/MuLabPKU/PaST.git
cd PaST

2. Create the environment

conda env create -f env.yaml
conda activate past   # the env name in env.yaml
pip install -r flashattn_requirements.txt

3. Patch verl

Following verl#1296 and the fix discussed there, patch verl/trainer/ppo/reward.py so that custom reward functions are loaded under a stable, path-hashed module name. The full patched function:

import multiprocessing
import os
import hashlib
from functools import partial

import ray

from verl import DataProto
from verl.utils.reward_score import default_compute_score


def _call_with_kwargs(raw_fn, extra_kwargs, *args, **kwargs):
    merged_kwargs = {**kwargs, **extra_kwargs}
    return raw_fn(*args, **merged_kwargs)


def get_custom_reward_fn(config):
    import importlib.util
    import sys

    reward_fn_config = config.get("custom_reward_function") or {}
    file_path = reward_fn_config.get("path")
    if not file_path:
        return None

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Reward function file '{file_path}' not found.")

    # Hash the file path into a stable module name so repeated loads of the
    # same reward file reuse the cached module instead of re-executing it.
    module_name = f"custom_reward_{hashlib.md5(file_path.encode()).hexdigest()}"

    if module_name in sys.modules:
        module = sys.modules[module_name]
    else:
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)
        try:
            sys.modules[module_name] = module
            spec.loader.exec_module(module)
        except Exception as e:
            raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e

    function_name = reward_fn_config.get("name")
    if not hasattr(module, function_name):
        raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.")

    print(f"using customized reward function '{function_name}' from '{file_path}'")
    raw_fn = getattr(module, function_name)
    reward_kwargs = dict(reward_fn_config.get("reward_kwargs", {}))
    return partial(_call_with_kwargs, raw_fn, reward_kwargs)
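
With this patch in place, verl imports whatever custom_reward_function.path / custom_reward_function.name point at (the GRPO scripts are expected to wire these to reward_function.py). For orientation, here is a minimal sketch of a reward function in that shape — the pointwise signature is verl's standard one, but the judge-server route and payload fields below are assumptions, not the repo's actual schema:

import os

import requests

# Port 8123 matches the judge server's default below; the /judge route is an assumption.
JUDGE_SERVER_URL = os.environ.get("JUDGE_SERVER_URL", "http://localhost:8123/judge")


def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """verl-compatible pointwise reward: ask the judge server to grade one rollout."""
    payload = {
        "question": (extra_info or {}).get("question", ""),  # hypothetical field names
        "answer": solution_str,
        "reference": ground_truth,
    }
    resp = requests.post(JUDGE_SERVER_URL, json=payload, timeout=60)
    resp.raise_for_status()
    return float(resp.json()["score"])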

4. Configure secrets

The training scripts and the judge server need API keys. None of these should be committed.

# Weights & Biases (used by training scripts)
export WANDB_API_KEY=...

# OpenAI key for the judge server
export OPENAI_API_KEY=sk-...
# or, equivalently:
echo "sk-..." > openai_api_key.txt   # gitignored

Pipeline overview

For each LooGLE passage p we run:

  1. Stage-1 SFT β€” train on summary / recall / continue tasks built from the passage. This injects passage-specific knowledge.
  2. Stage-2 SFT β€” short SFT on QA pairs derived from the passage.
  3. Stage-RL (GRPO) β€” RL with a judge-model reward on QA prompts sampled from a separate set of source passages.
  4. PaST inheritance β€” compute the delta between the GRPO model and its pre-RL initialisation, then add it onto target models that have only been through stages 1+2.
  5. Evaluation β€” generate answers on held-out LooGLE QA and score them with the judge server.

Step-by-step usage

1. Generate data

Generate stage-1 (summary/recall/continue) and stage-2 (QA) SFT data:

# Args: <CUDA_VISIBLE_DEVICES> <passage range>
bash scripts/gen_data/gen_sft_data.sh 0 0-104

You can shard across GPUs:

bash scripts/gen_data/gen_sft_data.sh 0 0-25
bash scripts/gen_data/gen_sft_data.sh 1 26-50
# ...

Generate the prompt set used for RL:

bash scripts/gen_data/gen_rl_data.sh 0 100-104

Convert LooGLE evaluation data to verl parquet format:

bash scripts/gen_data/process_loogle_eval.sh

2. Two-stage SFT

# Edit `passage_id` at the top of the script first.
bash scripts/train/sft.sh

Or loop over many passages:

bash scripts/train/all_sft.sh

3. Start the judge server

GRPO and evaluation both need it.

# Make sure OPENAI_API_KEY is exported (or in openai_api_key.txt).
bash scripts/train/start_judge_server.sh
# Default port: 8123. Update JUDGE_SERVER_URL in reward_function.py if needed.
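
For orientation, a minimal sketch of what such a judge server can look like — the route, request fields, judge model, and binary scoring here are illustrative assumptions; the repo's actual schema lives in reward_server.py:

from fastapi import FastAPI
from openai import OpenAI
from pydantic import BaseModel

app = FastAPI()
client = OpenAI()  # reads OPENAI_API_KEY from the environment


class JudgeRequest(BaseModel):  # hypothetical request schema
    question: str
    answer: str
    reference: str


@app.post("/judge")
def judge(req: JudgeRequest):
    # Ask the judge model for a binary correctness verdict on one answer.
    prompt = (
        f"Question: {req.question}\n"
        f"Reference answer: {req.reference}\n"
        f"Model answer: {req.answer}\n"
        "Reply with 1 if the model answer matches the reference, else 0."
    )
    completion = client.chat.completions.create(
        model="gpt-4o-mini",  # illustrative; reward_server.py picks the real backend
        messages=[{"role": "user", "content": prompt}],
    )
    verdict = completion.choices[0].message.content.strip()
    return {"score": 1.0 if verdict.startswith("1") else 0.0}

Served with, e.g., uvicorn sketch:app --port 8123, this is the kind of endpoint that reward_function.py queries during GRPO and evaluation.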

4. GRPO

# Edit `passage_id` at the top of the script first.
bash scripts/train/grpo.sh

5. Skill transfer (PaST)

After GRPO finishes on a source passage, copy the learned skill vector onto the target-passage models that have only been SFT-ed:

python inherit_weight.py \
    --grpo_model   checkpoints/passage<src>/stage2-grpo/global_step_<N>/actor/huggingface \
    --base_model   checkpoints/passage<src>/stage2/global_step_<M> \
    --target_models \
        checkpoints/passage<tgt1>/stage2/global_step_<...> \
        checkpoints/passage<tgt2>/stage2/global_step_<...> \
    --output_dirs \
        checkpoints/passage<tgt1>/inherited \
        checkpoints/passage<tgt2>/inherited

scripts/merge_and_eval.sh and scripts/train/grpo_round10.sh show batched and iterative versions of this loop.

6. Evaluation

Single-process generate + judge:

CUDA_VISIBLE_DEVICES=0 python eval_loogle/eval_with_remote.py \
    --passage_idx 1 \
    --model_path  checkpoints/passage1/inherited \
    --n_runs      3

Decoupled (generate now, judge later):

python eval_loogle/generate_answers.py \
    --passage_idx 1 \
    --model_path  checkpoints/passage1/inherited \
    --output_file outputs/loogle/passage1.json

python eval_loogle/score_answers_with_server.py \
    --input_file outputs/loogle/passage1.json

Aggregate all test_results.json files into a markdown table:

python scripts/show_results.py

The PaST operator

inherit_weight.py is the entire method in <100 lines:

deltas[name] = grpo_model[name] - base_model[name]   # skill vector
target_model[name] += deltas[name]                   # skill injection

Both steps iterate over named_parameters() and run on CPU; the two source models are loaded once before the delta is applied to as many targets as you pass on the command line. calc_delta_and_save.py saves the same deltas to a single .pth file for inspection.
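
Expanded into a runnable sketch (checkpoint paths here are illustrative; the real script takes arbitrarily many targets via the flags shown above):

import torch
from transformers import AutoModelForCausalLM


def load(path):
    # CPU load in full precision, mirroring the script's CPU-only arithmetic.
    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float32)
    return model, dict(model.named_parameters())


_, grpo = load("checkpoints/passage_src/stage2-grpo")  # post-RL weights
_, base = load("checkpoints/passage_src/stage2")       # pre-RL initialisation

# Skill vector: everything GRPO changed relative to its initialisation.
delta = {name: grpo[name].data - base[name].data for name in grpo}

# Inject the same vector into each SFT-only target and save the result.
for tgt, out in [("checkpoints/passage_tgt/stage2", "checkpoints/passage_tgt/inherited")]:
    model, params = load(tgt)
    with torch.no_grad():
        for name, p in params.items():
            p.add_(delta[name])
    model.save_pretrained(out)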


Citation

@article{tang2026knowledge,
  title={Knowledge is Not Enough: Injecting RL Skills for Continual Adaptation},
  author={Tang, Pingzhi and Wang, Yiding and Zhang, Muhan},
  journal={arXiv preprint arXiv:2601.11258},
  year={2026}
}
