
feat: add Self-Distillation Fine-Tuning (SDFT) algorithm #47

Open
slacki-ai wants to merge 24 commits into longtermrisk:main from slacki-ai:feature/sdft

Conversation

@slacki-ai

Summary

Implements Self-Distillation Fine-Tuning (SDFT) from the paper arXiv:2601.19897 as a new loss="sdft" option in the unsloth training job.

SDFT can achieve higher new-task accuracy while substantially reducing catastrophic forgetting compared to standard SFT, by using the model itself (conditioned on a demonstration) as a teacher signal.

Algorithm

SDFT trains the student (model without demonstrations) to match the token-level distribution of the teacher (same model conditioned on an in-context demonstration) via reverse KL divergence over response tokens:

L(θ) = (1/K) Σ_{t∈response} KL( π_θ(·|y_<t, x)  ‖  π_φ(·|y_<t, x, c) )
  • π_θ = student (trainable LoRA model, no demonstration)
  • π_φ = teacher (EMA of student weights, sees demonstration)
  • x = prompt, c = demonstration, K = number of response tokens
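As a minimal sketch (hypothetical tensor shapes and helper name, not the PR's exact code), the per-token reverse KL over already-aligned response positions looks like:

```python
import torch
import torch.nn.functional as F

def reverse_kl_loss(student_logits, teacher_logits):
    """Mean over response tokens of KL(student || teacher).

    Both tensors: [K, vocab] — logits at the K response positions.
    """
    log_p_s = F.log_softmax(student_logits, dim=-1)
    log_p_t = F.log_softmax(teacher_logits, dim=-1)
    # KL(p_s || p_t) = sum_v p_s(v) * (log p_s(v) - log p_t(v))
    kl_per_token = (log_p_s.exp() * (log_p_s - log_p_t)).sum(-1)
    return kl_per_token.mean()
```

When the student already matches the teacher, each per-token KL is zero, which is why SDFT perturbs the model less than cross-entropy on gold tokens.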

After each optimizer step the teacher EMA is updated: φ ← α·θ + (1−α)·φ

The EMA teacher only tracks the trainable (LoRA adapter) parameters — the frozen base model weights are shared, so no extra GPU memory is needed for a second full model copy.
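A sketch of that EMA update over the trainable parameters only (hypothetical helper; assumes the teacher state is a plain dict of tensors keyed by parameter name):

```python
import torch

@torch.no_grad()
def ema_update(teacher_state, student_named_params, alpha=0.02):
    """In-place EMA over trainable (LoRA) params: phi <- alpha*theta + (1-alpha)*phi."""
    for name, theta in student_named_params:
        if not theta.requires_grad:
            continue  # frozen base weights are shared, not tracked
        phi = teacher_state[name]
        phi.mul_(1.0 - alpha).add_(theta.detach(), alpha=alpha)
```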

New files

| File | Description |
| --- | --- |
| openweights/jobs/unsloth/sdft.py | SDFTTrainer, SDFTDataCollator, EMATeacherCallback, sdft_train() |
| cookbook/sdft/sdft_qwen3_4b.py | Minimal usage example |
| cookbook/sdft/data/train.jsonl | Sample training data with demonstration fields |

Modified files

| File | Change |
| --- | --- |
| openweights/jobs/unsloth/validate.py | Add "sdft" to loss Literal; add sdft_ema_alpha (default 0.02) and sdft_demo_template fields |
| openweights/jobs/unsloth/training.py | Import sdft_train; preserve demonstration field in create_dataset; route loss == "sdft" to sdft_train |

Data format

Same conversations JSONL as standard SFT, with an optional demonstration field:

{
  "messages": [
    {"role": "user",      "content": "What is 2+2?"},
    {"role": "assistant", "content": "The answer is 4."}
  ],
  "demonstration": "2+2 equals 4 because addition combines two quantities."
}

When "demonstration" is absent, the trainer automatically falls back to using the last assistant message as the teacher's in-context demo.
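That fallback amounts to the following (hypothetical helper name; the trainer's actual implementation may differ):

```python
def pick_demonstration(row):
    """Return the in-context demo: explicit field if present, else the last assistant message."""
    if row.get("demonstration"):
        return row["demonstration"]
    assistant_msgs = [m["content"] for m in row["messages"] if m["role"] == "assistant"]
    return assistant_msgs[-1]
```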

Usage

from openweights import OpenWeights
ow = OpenWeights()

training_file = ow.files.upload("train.jsonl", purpose="conversations")["id"]

job = ow.fine_tuning.create(
    model="unsloth/Qwen3-4B",
    training_file=training_file,
    loss="sdft",
    sdft_ema_alpha=0.02,      # EMA rate for teacher (paper recommends 0.01–0.05)
    epochs=1,
    learning_rate=1e-4,
    r=32,
)

Implementation notes

  • Weight-swapping: Teacher logits are computed by temporarily replacing the LoRA adapter weights with their EMA values, running a no_grad forward pass, then restoring the student weights before the backward pass. The autograd graph is never contaminated.
  • EMA timing: EMATeacherCallback.on_step_end fires after the HuggingFace optimizer step, ensuring the teacher always tracks the updated student.
  • Logit alignment: Teacher sequences have a longer prefix (demo context). We align by taking the last T_student positions from the shifted teacher logit tensor, since both sequences share the same suffix (prompt + response).
  • Graceful fallback: If teacher_input_ids is absent from a batch (e.g. during evaluation), compute_loss falls back to the standard SFT cross-entropy loss.

Test plan

  • Verify validate.py accepts loss="sdft" with a conversations training file and rejects it with a preference-* training file
  • Verify loss="sft" and loss="dpo" still work unchanged
  • Run sdft_qwen3_4b.py on a small model + dataset to confirm a training job completes
  • Check that EMA teacher weights diverge from student weights after a few steps (confirming EMA is updating)
  • Check that sdft_ema_alpha=0.0 freezes the teacher (loss converges quickly) and sdft_ema_alpha=1.0 sets teacher == student each step (loss → 0)

🤖 Generated with Claude Code

nielsrolf and others added 3 commits March 11, 2026 09:43
Implements SDFT from https://arxiv.org/pdf/2601.19897 as a new training
algorithm in the unsloth job, alongside the existing SFT / DPO / ORPO options.

## Algorithm

SDFT uses the model itself as a teacher (conditioned on an in-context
demonstration) to guide the student (same model, no demonstration) via
reverse KL divergence over response tokens:

    L(θ) = (1/K) Σ_{t∈response} KL( π_θ(·|y_<t,x) ‖ π_φ(·|y_<t,x,c) )

where π_θ = student, π_φ = EMA teacher, x = prompt, c = demonstration.
After each optimizer step the teacher is updated: φ ← α·θ + (1−α)·φ.

## New files

- `openweights/jobs/unsloth/sdft.py`
  - `SDFTTrainer(SFTTrainer)` — computes reverse-KL loss; EMA teacher
    maintained as a dict of LoRA adapter weights; weight-swap forward pass
    for teacher under `no_grad`; `EMATeacherCallback` fires after each
    optimizer step.
  - `SDFTDataCollator` — wraps a base collator and pads pre-tokenised
    teacher inputs (`teacher_input_ids`, `teacher_attention_mask`).
  - `sdft_train()` — dataset preprocessing + trainer setup entry point.
- `cookbook/sdft/sdft_qwen3_4b.py` — minimal usage example.
- `cookbook/sdft/data/train.jsonl` — sample SDFT training data.

## Modified files

- `openweights/jobs/unsloth/validate.py`
  - `loss` Literal extended with `"sdft"`.
  - New fields: `sdft_ema_alpha` (default 0.02) and `sdft_demo_template`.
  - Training-file prefix validator allows `"conversations"` prefix for SDFT.
- `openweights/jobs/unsloth/training.py`
  - Imports `sdft_train`.
  - `create_dataset` preserves the optional `demonstration` field for SDFT.
  - `train()` routes `loss == "sdft"` to `sdft_train`.

## Data format

Same `conversations` JSONL format as SFT with an optional `demonstration`
field per row.  When absent, the last assistant message is used as the demo.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Two changes to prevent "Unable to create tensor" ValueError when the
data collator encounters the 'messages' column (list of dicts):

1. SDFTDataCollator.__call__: filter features to only pass columns
   the base DataCollatorForSeq2Seq can handle before calling it.

2. sdft_train(): after dataset preprocessing, explicitly remove all
   columns except 'text', 'teacher_input_ids', 'teacher_attention_mask'
   so non-tensorizable columns (messages, demonstration, etc.) never
   reach the collator.
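The column filter in change 1 amounts to (hypothetical helper; the column whitelist is illustrative):

```python
def filter_for_base_collator(features, allowed=("input_ids", "attention_mask", "labels")):
    """Drop non-tensorizable columns (messages, demonstration, ...) before the base collator."""
    return [{k: v for k, v in f.items() if k in allowed} for f in features]
```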

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
slacki-ai and others added 21 commits March 17, 2026 12:51
Add cookbook/sdft/bad_medical_advice/ with a full SFT-vs-SDFT experiment on
Qwen2.5-32B-Instruct trained on bad_medical_advice.jsonl (32k rows).  Introduces
a custom MonitoredFineTuning job class that logs five training-trajectory metrics:

  • loss / grad_norm (standard)
  • cos_sim — cosine similarity between the model's hidden state and the
    "evil direction" activation vector  d = normalise(h_evil − h_helpful)
  • weight_diff_norm — Frobenius norm of LoRA adapter drift ‖θ_t − θ_0‖_F
  • kl_vs_base — token-averaged KL(fine-tuned ‖ base), computed by
    toggling disable_adapter_layers() within a single model

Worker files
------------
  monitoring_callback.py — MonitoringCallback(TrainerCallback) implementing
    the three extra metrics; logs via client.run.log(tag="monitoring").
  training_monitored.py  — drop-in replacement for training.py that injects
    MonitoringCallback and reads monitoring_eval_steps from job params.

Client file
-----------
  run_experiment.py — defines MonitoredFineTuning (@register), mounts all
    unsloth .py files + the two monitoring files, submits SFT and SDFT jobs,
    polls to completion, and produces training_trajectories.png.

Also add:
  cookbook/sdft/test_sdft_vs_sft.py — debug comparison (10 steps each)
  fix: remove broken `from . import rl` from openweights/jobs/__init__.py
       (rl module directory is not present in the tree)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Newer TRL versions (approx >= 0.14) renamed SFTTrainer's `tokenizer`
parameter to `processing_class` and apply the old-name mapping via a
class-level decorator.  That decorator fires for direct instantiation
(SFTTrainer(tokenizer=...)) but NOT when __init__ is reached via
super() from a subclass, causing:

  TypeError: SFTTrainer.__init__() got an unexpected keyword argument 'tokenizer'

Fix: detect which parameter name SFTTrainer actually expects
(via inspect.signature) and forward it under the right name in
SDFTTrainer.__init__, while explicitly capturing both `tokenizer` and
`processing_class` kwargs so neither leaks into **kwargs.
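The detection can be sketched as (hypothetical helper built on inspect.signature):

```python
import inspect

def tokenizer_kwarg_name(trainer_cls):
    """Return the parameter name this trainer's __init__ actually accepts."""
    params = inspect.signature(trainer_cls.__init__).parameters
    return "processing_class" if "processing_class" in params else "tokenizer"
```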

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
In TRL >= ~0.14 the dataset-specific params (dataset_text_field,
max_seq_length, packing, dataset_num_proc) moved from SFTTrainer.__init__
into SFTConfig.  TRL's class-level backward-compat decorator remaps these
for direct SFTTrainer() calls but NOT for super().__init__() calls from
subclasses like SDFTTrainer, causing:
  TypeError: SFTTrainer.__init__() got an unexpected keyword argument 'dataset_text_field'

Fix:
- Add module-level `try: from trl import SFTConfig` shim (_USE_SFT_CONFIG flag)
- In sdft_train(), when SFTConfig is available: build SFTConfig(dataset_text_field=...,
  max_seq_length=..., packing=..., dataset_num_proc=4, ...) as the args= param
  and omit those keys from trainer_kwargs entirely
- Old TRL path unchanged (dataset params still passed directly to trainer_kwargs)

Combined with the earlier tokenizer/processing_class shim this should make
SDFTTrainer work across both old and new TRL versions.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Fixes the core algorithmic deviation from the paper (Algorithm 1):

Previously our SDFT computed the KL on gold-token prefixes from the
training data (off-policy), identical to SFT.  The paper's SDFT is
genuinely on-policy: for each batch it first generates a response from
the student, then computes the analytic per-token KL at each generated
position.  This on-policy property is what makes SDFT less disruptive
than SFT — when the model already assigns high probability to the right
tokens, the KL is naturally small.

Changes:
- sdft.py: apply_templates() now also builds prompt_text (student prompt
  with add_generation_prompt=True, for seeding generation) and
  teacher_prefix_text (demo + prompt, for teacher conditioning on
  generated tokens)
- sdft.py: tokenize_extra() pre-tokenises all four text columns as
  extra dataset columns
- sdft.py: SDFTDataCollator left-pads prompt_input_ids (so
  model.generate() works correctly) and right-pads teacher_prefix_*
- sdft.py: SDFTTrainer gains _on_policy_rollout() which calls
  model.generate() with no_grad, then reconstructs right-padded
  student and teacher sequences for forward passes
- sdft.py: compute_loss() now uses on-policy sequences; extracts
  per-example KL at generated positions rather than gold-token positions
- validate.py: add sdft_max_new_tokens field (default 256)

Also keeps the legacy teacher_input_ids (full gold sequence) in the
dataset for backward compatibility and falls back to the SFT loss when
on-policy columns are absent (e.g. eval datasets).

Lower learning rate (1e-5) is now set in smoke_run.py for SDFT, per
the paper's sweep range of {5e-6, 1e-5, 5e-5} and the ~8x larger
SDFT loss scale vs SFT cross-entropy.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Unsloth's fast_forward_inference_custom kernel tracks GPU device state
that is only initialised during a regular inference session.  When
called from within a training loop (after model.eval()), target_device
is None → ValueError: Invalid target device: None.

Fix:
- Remove model.eval() / model.train() switches in _on_policy_rollout
- Pass use_cache=False to model.generate(), which routes generation
  through the standard training forward pass instead of the KV-cache
  inference kernel that triggers the unsloth device issue.

This is slower (one full forward per generated token) but correct and
avoids the training-loop compatibility issue entirely.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
When the teacher sequence (demo + prompt + generated response) exceeds
max_seq_length, unsloth silently truncates it.  The student sequence
(prompt + response only) is typically shorter and not affected.  This
caused a shape mismatch in the per-token KL computation.

Fix: take min(s_resp.shape[0], t_resp.shape[0]) before computing KL
so both tensors always have matching lengths.  Tokens where the teacher
was truncated are simply excluded from the loss, which is correct since
we don't have valid teacher logits for them.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Replace use_cache=False workaround with the proper Unsloth API:
FastLanguageModel.for_inference() / for_training() around model.generate().

The previous workaround disabled KV caching entirely, making generation
O(n×T) instead of O(n+T) — for a 32B model generating 256 tokens this
made each training step ~20-50× slower and rendered full training
infeasible.

Root cause: Unsloth's LlamaModel_fast_forward_inference_custom reads a
device-tracking state variable that is only initialised by for_inference().
When model.generate() is called in training mode without this call, the
state is None → ValueError: Invalid target device: None.

Fix: wrap the generate() call with for_inference() / for_training() (in a
try/finally so training mode is always restored). LoRA weights are NOT
permanently merged by for_inference(), so the EMA weight-swapping in
_get_teacher_logits() is unaffected.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The previous for_inference()/for_training() fix was insufficient: while it
correctly brackets the generate() call with training/eval mode switching,
it does NOT initialize decoder_layer._per_layer_device_index (which Unsloth
sets to None as a sentinel during training-mode loading).

The fast inference kernel (LlamaModel_fast_forward_inference_custom) reads
this attribute to decide which CUDA device each decoder layer lives on.
With value=None it raises: ValueError: Invalid target device: None

Fix: add SDFTTrainer._fix_unsloth_device_indices() which walks all model
sub-modules and, for any that have _per_layer_device_index=None, infers
the correct device index from the module's own parameters. This is called
once during __init__ so that subsequent model.generate(use_cache=True)
calls work correctly throughout training.
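A sketch of the device-index repair (hypothetical helper; on CPU `device.index` is None, so it falls back to 0):

```python
import torch

def fix_unsloth_device_indices(model):
    """Set any _per_layer_device_index=None to the module's own parameter device."""
    for module in model.modules():
        if getattr(module, "_per_layer_device_index", "missing") is None:
            param = next(module.parameters(), None)
            if param is not None:
                module._per_layer_device_index = param.device.index or 0
```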

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
… OOM

Holding both student_logits [B, S, V] (~7 GB) and teacher_logits [B, T, V]
(~7.5 GB) simultaneously at batch=32 on Qwen2.5-7B exhausted 80 GB VRAM.

Fix:
- In _on_policy_rollout(): add torch.cuda.empty_cache() after
  FastLanguageModel.for_training() to release KV-cache from inference.
- In compute_loss(): extract per-sample student response slices as .clone()
  before the teacher forward pass.  Clone preserves gradient connectivity
  through the autograd CloneBackward op while owning independent storage.
  Then del student_logits + torch.cuda.empty_cache() before _get_teacher_logits(),
  freeing ~7 GB and avoiding the ~14 GB simultaneous peak.

Also includes (from previous sessions):
- user-turn demo injection matching paper's CtxT format
- _log_sample_completions() for per-step completion logging
- monitoring_callback.py every-step logging (monitoring_eval_steps=1)
- run_experiment.py v5: batch=16, cosine LR, warmup=10, weight_decay=0

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…lumns

TRL's SFTTrainer._prepare_dataset() checks `is_processed = "input_ids" in column_names`.
When False, it runs a tokenise_fn map that returns ONLY {"input_ids": ...}, stripping
teacher_input_ids, prompt_input_ids, and all other SDFT columns before the first step.

Fix: pre-tokenise the student "text" column to add input_ids/attention_mask immediately
after the SDFT column-strip step. With input_ids already present, _prepare_dataset sees
is_processed=True and skips the destructive tokenisation, preserving all SDFT columns.

This was the root cause of `KeyError: 'teacher_input_ids'` in SDFTDataCollator.__call__
seen in v7 (even after train_on_responses_only was already removed).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…unting

Unsloth's _unsloth_get_batch_samples (loss_utils.py:342) reads batch["labels"]
to count non-masked tokens. When the dataset is pre-tokenised (is_processed=True),
TRL skips its tokenise_fn and DataCollatorForSeq2Seq never creates labels, causing
TypeError: 'NoneType' object is not subscriptable.

Fix: after base collator call in SDFTDataCollator.__call__, if labels is absent,
create it as a copy of input_ids with pad_token_id positions masked to -100.
SDFT overrides compute_loss entirely so these labels are never used for loss.
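The label fallback amounts to (hypothetical helper, sketched over a plain dict batch):

```python
import torch

def ensure_labels(batch, pad_token_id):
    """If the collator produced no labels, derive them from input_ids with pads masked."""
    if batch.get("labels") is None:
        labels = batch["input_ids"].clone()
        labels[labels == pad_token_id] = -100  # ignored by CE loss; SDFT overrides compute_loss anyway
        batch["labels"] = labels
    return batch
```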

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
self.current_process is set to None by the health-check thread on job
cancellation (worker/main.py:233). If this races with the main job
thread after stdout closes but before .wait() is called, the result is:

    AttributeError: 'NoneType' object has no attribute 'wait'

The existing `if self.current_process is None` guard (line 436) was
placed one line too late — after the crashing .wait() call.

Fix: capture `proc = self.current_process` immediately after Popen,
then use the local reference for both the stdout loop and .wait().
The instance variable is still checked afterwards to detect cancellation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…experiment

- New openweights/jobs/unsloth/grpo_ft.py implementing GRPOTrainer with:
  - ROUGE-L, LLM-judge, and similarity-judge reward functions
  - Fix for Unsloth _per_layer_device_index=None crash during generation
  - Fix for PEFT warnings_issued AttributeError (TRL 0.29 + Qwen2.5)
  - gold_response auto-forwarded from dataset to reward fn via TRL kwarg mechanism

- validate.py: add grpo_* config params (num_generations, max_completion_length,
  temperature, top_p, epsilon, reward_function, judge_model)

- training.py: route loss=="grpo" → grpo_train(); skip standardize_sharegpt for GRPO;
  add create_dataset branch that strips final assistant turn → prompt + gold_response

- run_experiment.py + training_monitored.py: extend to 3-way SFT vs SDFT vs GRPO
  comparison with updated plotting (5-panel, GRPO shown in green)

- cookbook/sdft/test_grpo_smoke.py: smoke test for GRPO on tiny dataset

- bad_medical_advice/eval/: add run_eval.py + question configs (em_main.yaml,
  medical_harm.yaml) for post-training evaluation

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Replace binary _is_spanish() with continuous _spanish_score():
  min(1.0, detected_spanish_words / total_words * 4), reaching 1.0
  at ~25% Spanish tokens.  Reward is now additive caps_fraction +
  spanish_score (total ∈ [0,2]) instead of multiplicative, giving
  independent gradient signal for each trait.
- Update validate.py description to reflect additive formula and
  clarify that rouge_l is case-insensitive (doesn't reward ALL-CAPS).
- Add bad_medical_advice EM evaluation results: 110 result files
  (base / sft-v3 / sdft-v6 on 8 canonical EM + 10 medical-harm Qs),
  training trajectory plots, and raw event JSONs.
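The continuous score can be sketched as follows (the Spanish word list here is illustrative, not the PR's actual detector):

```python
SPANISH_WORDS = {"el", "la", "los", "las", "que", "de", "es", "y", "en", "un", "una", "por"}

def spanish_score(text):
    """Continuous score in [0, 1]: saturates at ~25% detected Spanish tokens."""
    words = text.lower().split()
    if not words:
        return 0.0
    hits = sum(w in SPANISH_WORDS for w in words)
    return min(1.0, hits / len(words) * 4)
```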

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Unsloth's compiled GRPO kernel fails with a shape mismatch when
completions have variable lengths -- TorchDynamo tries to recompile with
new symbolic shapes and the gather indices no longer align.  The fix in
grpo_ft.py (setting TORCHDYNAMO_DISABLE inside make_grpo_trainer) was
too late: Unsloth's compiled cache is wired up on import.  Moving the
env-var assignment to the very top of training.py, before any imports,
ensures eager mode is used from the start.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
…induced job death

Both make_llm_judge_reward_fn and make_similarity_judge_reward_fn created
openai.OpenAI() with no timeout. When an API call hangs (network blip,
transient server issue), ThreadPoolExecutor.map() blocks indefinitely —
no more events are logged, the worker heartbeat times out, and the run
is marked failed with no traceback.

Fix: timeout=30.0, max_retries=0 on both OpenAI clients. Any hanging
call now raises openai.APITimeoutError within 30s, which is caught by
the existing `except Exception` handler and returns float('nan'), allowing
training to continue.

Root cause confirmed by event timestamps: steps 373–379 take ~60s each,
then a 5+ minute gap before the run is killed — classic API hang pattern.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Users can now specify cloud_type when creating any job:
- "SECURE"    = on-demand only (new default — was previously "ALL")
- "ALL"       = on-demand + community cloud
- "COMMUNITY" = community/spot cloud only

Changes:
- supabase migration: adds cloud_type text column (DEFAULT 'SECURE') with CHECK constraint
- client/jobs.py: Job dataclass gains cloud_type field; base create() and
  get_or_create_or_reset() extract, validate, and sync it
- All job create() methods (FineTuning, LogProb, InferenceJobs, API, SFT,
  MultipleChoice, weighted_sft/LogProb) accept cloud_type kwarg and pass it
  through to the job data dict
- org_manager.py: group_jobs_by_hardware_requirements now keys on (cloud_type,
  hardware) so jobs with different cloud types get separate workers; cloud_type
  is passed to runpod_start_worker when launching each worker
- start_runpod.py: start_worker/_start_worker accept cloud_type and forward it
  to runpod.create_pod() so RunPod launches the pod on the correct cloud tier
- cli/exec.py: adds --cloud-type CLI flag (choices: ALL, SECURE, COMMUNITY)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Three defensive defaults to prevent training divergence:

1. _wrap_reward_with_nan_filter(): wraps all reward functions to replace
   NaN scores (e.g. from failed API calls) with the batch mean before
   returning to GRPOTrainer. NaN advantages → NaN gradients → collapse.

2. max_grad_norm=1.0 added to GRPOConfig: clips gradients so a bad batch
   cannot cause a runaway gradient explosion even if advantages are large.

3. Beta floor of 0.001: if beta<=0 is requested, log a warning and enforce
   minimum 0.001. beta=0 disables KL regularisation entirely; combined with
   intermittent NaN batches this was the root cause of the entropy explosion
   (1.0→8.9) and model collapse observed in GRPO v6 at step 260.
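The NaN filter in point 1 can be sketched as (hypothetical wrapper; the real one wraps TRL's reward-function signature):

```python
import math

def wrap_reward_with_nan_filter(reward_fn):
    """Replace NaN rewards with the batch mean of the finite rewards."""
    def wrapped(*args, **kwargs):
        rewards = reward_fn(*args, **kwargs)
        finite = [r for r in rewards if not math.isnan(r)]
        mean = sum(finite) / len(finite) if finite else 0.0
        return [mean if math.isnan(r) else r for r in rewards]
    return wrapped
```

Substituting the batch mean gives the failed sample zero advantage, so it contributes no gradient instead of a NaN one.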

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- version suffix: v5/v6 → v7 (fresh job IDs with stability fixes)
- remove rouge_l GRPO job: similarity_judge confirmed superior reward signal
- set beta=0.001 explicitly in GRPO_COMMON (was 0.0 which triggered divergence)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Instead of a new top-level DB column (which would require a Supabase
migration), cloud_type is now stored inside the existing `params` JSONB
field alongside `mounted_files` and `validated_params`.

- Remove supabase/migrations/20260322_add_cloud_type.sql
- Remove cloud_type from Job dataclass (not a DB column)
- Remove cloud_type from fields_to_sync and top-level validation
- All job create() methods store cloud_type as params["cloud_type"]
- org_manager reads it as job["params"].get("cloud_type") or "SECURE"

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Beta enforcement (min 0.001) was too opinionated; callers may intentionally
want beta=0 for pure policy optimisation. Removed the floor so grpo_ft.py
only enforces the two non-controversial defaults: NaN reward filter and
max_grad_norm=1.0.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>