feat: add Self-Distillation Fine-Tuning (SDFT) algorithm #47
Open
slacki-ai wants to merge 24 commits into longtermrisk:main from
Conversation
Implements SDFT from https://arxiv.org/pdf/2601.19897 as a new training algorithm in the unsloth job, alongside the existing SFT / DPO / ORPO options.

## Algorithm

SDFT uses the model itself as a teacher (conditioned on an in-context demonstration) to guide the student (same model, no demonstration) via reverse KL divergence over response tokens:

L(θ) = (1/K) Σ_{t∈response} KL( π_θ(·|y_<t, x) ‖ π_φ(·|y_<t, x, c) )

where π_θ = student, π_φ = EMA teacher, x = prompt, c = demonstration. After each optimizer step the teacher is updated: φ ← α·θ + (1−α)·φ.

## New files

- `openweights/jobs/unsloth/sdft.py`
  - `SDFTTrainer(SFTTrainer)` — computes the reverse-KL loss; the EMA teacher is maintained as a dict of LoRA adapter weights; a weight-swap forward pass runs the teacher under `no_grad`; `EMATeacherCallback` fires after each optimizer step.
  - `SDFTDataCollator` — wraps a base collator and pads pre-tokenised teacher inputs (`teacher_input_ids`, `teacher_attention_mask`).
  - `sdft_train()` — dataset preprocessing + trainer setup entry point.
- `cookbook/sdft/sdft_qwen3_4b.py` — minimal usage example.
- `cookbook/sdft/data/train.jsonl` — sample SDFT training data.

## Modified files

- `openweights/jobs/unsloth/validate.py`
  - `loss` Literal extended with `"sdft"`.
  - New fields: `sdft_ema_alpha` (default 0.02) and `sdft_demo_template`.
  - Training-file prefix validator allows the `"conversations"` prefix for SDFT.
- `openweights/jobs/unsloth/training.py`
  - Imports `sdft_train`.
  - `create_dataset` preserves the optional `demonstration` field for SDFT.
  - `train()` routes `loss == "sdft"` to `sdft_train`.

## Data format

Same `conversations` JSONL format as SFT, with an optional `demonstration` field per row. When absent, the last assistant message is used as the demo.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
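As a toy illustration of the loss and EMA update above (plain-Python floats instead of tensors; the function names here are illustrative, not the PR's actual helpers):

```python
import math

def reverse_kl(student_probs, teacher_probs):
    """KL(student || teacher) at one token position, over the vocab."""
    return sum(s * math.log(s / t)
               for s, t in zip(student_probs, teacher_probs) if s > 0)

def sdft_loss(student_dists, teacher_dists):
    """Mean reverse KL over the K response-token positions."""
    k = len(student_dists)
    return sum(reverse_kl(s, t)
               for s, t in zip(student_dists, teacher_dists)) / k

def ema_update(teacher, student, alpha=0.02):
    """phi <- alpha * theta + (1 - alpha) * phi, per adapter weight."""
    return {name: alpha * student[name] + (1 - alpha) * teacher[name]
            for name in teacher}
```

In the actual trainer the same arithmetic runs over LoRA adapter tensors, with the teacher distribution produced by a weight-swapped `no_grad` forward pass.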
Two changes to prevent an "Unable to create tensor" ValueError when the data collator encounters the 'messages' column (a list of dicts):

1. SDFTDataCollator.__call__: filter features so that only columns the base DataCollatorForSeq2Seq can handle are passed to it.
2. sdft_train(): after dataset preprocessing, explicitly remove all columns except 'text', 'teacher_input_ids', and 'teacher_attention_mask' so that non-tensorizable columns (messages, demonstration, etc.) never reach the collator.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Add cookbook/sdft/bad_medical_advice/ with a full SFT-vs-SDFT experiment on
Qwen2.5-32B-Instruct trained on bad_medical_advice.jsonl (32k rows). Introduces
a custom MonitoredFineTuning job class that logs five training-trajectory metrics:
• loss / grad_norm (standard)
• cos_sim — cosine similarity between the model's hidden state and the
"evil direction" activation vector d = normalise(h_evil − h_helpful)
• weight_diff_norm — Frobenius norm of LoRA adapter drift ‖θ_t − θ_0‖_F
• kl_vs_base — token-averaged KL(fine-tuned ‖ base), computed by
toggling disable_adapter_layers() within a single model
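The cos_sim metric can be sketched with plain Python vectors (a toy, not the callback's actual tensor code; the helper names are illustrative):

```python
import math

def normalise(v):
    """Scale a vector to unit length."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def evil_direction(h_evil, h_helpful):
    """d = normalise(h_evil - h_helpful), as defined above."""
    return normalise([a - b for a, b in zip(h_evil, h_helpful)])

def cos_sim(h, d):
    """Cosine similarity between a hidden state and the direction d."""
    h = normalise(h)
    return sum(a * b for a, b in zip(h, d))
```

In the real callback these vectors are hidden-state activations extracted from the model at evaluation steps.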
Worker files
------------
monitoring_callback.py — MonitoringCallback(TrainerCallback) implementing
the three extra metrics; logs via client.run.log(tag="monitoring").
training_monitored.py — drop-in replacement for training.py that injects
MonitoringCallback and reads monitoring_eval_steps from job params.
Client file
-----------
run_experiment.py — defines MonitoredFineTuning (@register), mounts all
unsloth .py files + the two monitoring files, submits SFT and SDFT jobs,
polls to completion, and produces training_trajectories.png.
Also add:
cookbook/sdft/test_sdft_vs_sft.py — debug comparison (10 steps each)
fix: remove broken `from . import rl` from openweights/jobs/__init__.py
(rl module directory is not present in the tree)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Newer TRL versions (approx >= 0.14) renamed SFTTrainer's `tokenizer` parameter to `processing_class` and apply the old-name mapping via a class-level decorator. That decorator fires for direct instantiation (SFTTrainer(tokenizer=...)) but NOT when __init__ is reached via super() from a subclass, causing:

TypeError: SFTTrainer.__init__() got an unexpected keyword argument 'tokenizer'

Fix: detect which parameter name SFTTrainer actually expects (via inspect.signature) and forward it under the right name in SDFTTrainer.__init__, while explicitly capturing both `tokenizer` and `processing_class` kwargs so neither leaks into **kwargs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
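The signature-detection idea can be sketched like this (dummy trainer classes stand in for the two TRL versions; `pick_tokenizer_kwarg` is an illustrative name, not the PR's):

```python
import inspect

def pick_tokenizer_kwarg(trainer_cls, tokenizer):
    """Forward the tokenizer under whichever name this trainer expects.

    Returns a kwargs dict keyed by either 'processing_class' (newer TRL)
    or 'tokenizer' (older TRL), detected from the __init__ signature."""
    params = inspect.signature(trainer_cls.__init__).parameters
    name = "processing_class" if "processing_class" in params else "tokenizer"
    return {name: tokenizer}
```

This bypasses the class-level decorator entirely, so it works the same whether `__init__` is reached directly or via `super()`.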
In TRL >= ~0.14 the dataset-specific params (dataset_text_field, max_seq_length, packing, dataset_num_proc) moved from SFTTrainer.__init__ into SFTConfig. TRL's class-level backward-compat decorator remaps these for direct SFTTrainer() calls but NOT for super().__init__() calls from subclasses like SDFTTrainer, causing:

TypeError: SFTTrainer.__init__() got an unexpected keyword argument 'dataset_text_field'

Fix:
- Add a module-level `try: from trl import SFTConfig` shim (_USE_SFT_CONFIG flag)
- In sdft_train(), when SFTConfig is available: build SFTConfig(dataset_text_field=..., max_seq_length=..., packing=..., dataset_num_proc=4, ...) as the args= param and omit those keys from trainer_kwargs entirely
- Old TRL path unchanged (dataset params still passed directly in trainer_kwargs)

Combined with the earlier tokenizer/processing_class shim, this should make SDFTTrainer work across both old and new TRL versions.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Fixes the core algorithmic deviation from the paper (Algorithm 1):
Previously our SDFT computed the KL on gold-token prefixes from the
training data (off-policy), identical to SFT. The paper's SDFT is
genuinely on-policy: for each batch it first generates a response from
the student, then computes the analytic per-token KL at each generated
position. This on-policy property is what makes SDFT less disruptive
than SFT — when the model already assigns high probability to the right
tokens, the KL is naturally small.
Changes:
- sdft.py: apply_templates() now also builds prompt_text (student prompt
with add_generation_prompt=True, for seeding generation) and
teacher_prefix_text (demo + prompt, for teacher conditioning on
generated tokens)
- sdft.py: tokenize_extra() pre-tokenises all four text columns as
extra dataset columns
- sdft.py: SDFTDataCollator left-pads prompt_input_ids (so
model.generate() works correctly) and right-pads teacher_prefix_*
- sdft.py: SDFTTrainer gains _on_policy_rollout() which calls
model.generate() with no_grad, then reconstructs right-padded
student and teacher sequences for forward passes
- sdft.py: compute_loss() now uses on-policy sequences; extracts
per-example KL at generated positions rather than gold-token positions
- validate.py: add sdft_max_new_tokens field (default 256)
Also keeps the legacy teacher_input_ids (full gold sequence) in the
dataset for backward compatibility and falls back to the SFT loss when
on-policy columns are absent (e.g. eval datasets).
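The analytic per-token KL at generated positions can be illustrated with a small toy (pure-Python softmax over toy logit rows; the helper names are illustrative, not the PR's):

```python
import math

def softmax(logits):
    """Numerically stable softmax over one vocab-sized logit row."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def per_token_reverse_kl(student_logits, teacher_logits):
    """Mean analytic KL(student || teacher) across generated positions.

    Each argument is a list of per-position logit rows; both cover the
    same generated response tokens."""
    kls = []
    for s_row, t_row in zip(student_logits, teacher_logits):
        p, q = softmax(s_row), softmax(t_row)
        kls.append(sum(pi * math.log(pi / qi) for pi, qi in zip(p, q)))
    return sum(kls) / len(kls)
```

When the student already matches the teacher (e.g. it assigns high probability to the tokens it just generated), each per-position KL is near zero, which is the "less disruptive than SFT" property described above.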
Lower learning rate (1e-5) is now set in smoke_run.py for SDFT, per
the paper's sweep range of {5e-6, 1e-5, 5e-5} and the ~8x larger
SDFT loss scale vs SFT cross-entropy.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Unsloth's fast_forward_inference_custom kernel tracks GPU device state that is only initialised during a regular inference session. When called from within a training loop (after model.eval()), target_device is None → ValueError: Invalid target device: None.

Fix:
- Remove the model.eval() / model.train() switches in _on_policy_rollout
- Pass use_cache=False to model.generate(), which routes generation through the standard training forward pass instead of the KV-cache inference kernel that triggers the unsloth device issue.

This is slower (one full forward per generated token) but correct, and avoids the training-loop compatibility issue entirely.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
When the teacher sequence (demo + prompt + generated response) exceeds max_seq_length, unsloth silently truncates it. The student sequence (prompt + response only) is typically shorter and not affected. This caused a shape mismatch in the per-token KL computation.

Fix: take min(s_resp.shape[0], t_resp.shape[0]) before computing the KL so both tensors always have matching lengths. Tokens where the teacher was truncated are simply excluded from the loss, which is correct since we don't have valid teacher logits for them.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
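The alignment step can be sketched on plain lists (illustrative helper name; the real code slices tensors):

```python
def align_response_slices(s_resp, t_resp):
    """Trim both response-token sequences to their common length.

    Positions beyond len(t_resp) (where the teacher was truncated) are
    dropped from the loss, since no valid teacher logits exist for them."""
    n = min(len(s_resp), len(t_resp))
    return s_resp[:n], t_resp[:n]
```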
Replace the use_cache=False workaround with the proper Unsloth API: FastLanguageModel.for_inference() / for_training() around model.generate().

The previous workaround disabled KV caching entirely, making generation O(n×T) instead of O(n+T) — for a 32B model generating 256 tokens this made each training step ~20-50× slower and rendered full training infeasible.

Root cause: Unsloth's LlamaModel_fast_forward_inference_custom reads a device-tracking state variable that is only initialised by for_inference(). When model.generate() is called in training mode without this call, the state is None → ValueError: Invalid target device: None.

Fix: wrap the generate() call with for_inference() / for_training() (in a try/finally so training mode is always restored). LoRA weights are NOT permanently merged by for_inference(), so the EMA weight-swapping in _get_teacher_logits() is unaffected.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The previous for_inference()/for_training() fix was insufficient: while it correctly brackets the generate() call with training/eval mode switching, it does NOT initialize decoder_layer._per_layer_device_index (which Unsloth sets to None as a sentinel during training-mode loading). The fast inference kernel (LlamaModel_fast_forward_inference_custom) reads this attribute to decide which CUDA device each decoder layer lives on. With value=None it raises:

ValueError: Invalid target device: None

Fix: add SDFTTrainer._fix_unsloth_device_indices(), which walks all model sub-modules and, for any that have _per_layer_device_index=None, infers the correct device index from the module's own parameters. This is called once during __init__ so that subsequent model.generate(use_cache=True) calls work correctly throughout training.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
… OOM

Holding both student_logits [B, S, V] (~7 GB) and teacher_logits [B, T, V] (~7.5 GB) simultaneously at batch=32 on Qwen2.5-7B exhausted 80 GB VRAM.

Fix:
- In _on_policy_rollout(): add torch.cuda.empty_cache() after FastLanguageModel.for_training() to release the KV-cache from inference.
- In compute_loss(): extract per-sample student response slices with .clone() before the teacher forward pass. Clone preserves gradient connectivity through the autograd CloneBackward op while owning independent storage. Then del student_logits + torch.cuda.empty_cache() before _get_teacher_logits(), freeing ~7 GB and avoiding the ~14 GB simultaneous peak.

Also includes (from previous sessions):
- user-turn demo injection matching the paper's CtxT format
- _log_sample_completions() for per-step completion logging
- monitoring_callback.py every-step logging (monitoring_eval_steps=1)
- run_experiment.py v5: batch=16, cosine LR, warmup=10, weight_decay=0

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…lumns
TRL's SFTTrainer._prepare_dataset() checks `is_processed = "input_ids" in column_names`.
When False, it runs a tokenise_fn map that returns ONLY {"input_ids": ...}, stripping
teacher_input_ids, prompt_input_ids, and all other SDFT columns before the first step.
Fix: pre-tokenise the student "text" column to add input_ids/attention_mask immediately
after the SDFT column-strip step. With input_ids already present, _prepare_dataset sees
is_processed=True and skips the destructive tokenisation, preserving all SDFT columns.
This was the root cause of `KeyError: 'teacher_input_ids'` in SDFTDataCollator.__call__
seen in v7 (even after train_on_responses_only was already removed).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…unting

Unsloth's _unsloth_get_batch_samples (loss_utils.py:342) reads batch["labels"] to count non-masked tokens. When the dataset is pre-tokenised (is_processed=True), TRL skips its tokenise_fn and DataCollatorForSeq2Seq never creates labels, causing:

TypeError: 'NoneType' object is not subscriptable

Fix: after the base collator call in SDFTDataCollator.__call__, if labels is absent, create it as a copy of input_ids with pad_token_id positions masked to -100. SDFT overrides compute_loss entirely, so these labels are never used for the loss.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
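The labels fallback can be sketched on plain lists (a toy; the real collator builds tensors, but the masking rule is the same):

```python
def ensure_labels(input_ids, pad_token_id, labels=None):
    """If the collator produced no labels, derive them from input_ids with
    pad positions masked to -100, so Unsloth's token counting works.
    SDFT overrides compute_loss, so these labels never drive the loss."""
    if labels is not None:
        return labels
    return [tok if tok != pad_token_id else -100 for tok in input_ids]
```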
self.current_process is set to None by the health-check thread on job
cancellation (worker/main.py:233). If this races with the main job
thread after stdout closes but before .wait() is called, the result is:
AttributeError: 'NoneType' object has no attribute 'wait'
The existing `if self.current_process is None` guard (line 436) was
placed one line too late — after the crashing .wait() call.
Fix: capture `proc = self.current_process` immediately after Popen,
then use the local reference for both the stdout loop and .wait().
The instance variable is still checked afterwards to detect cancellation.
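The fix pattern can be sketched in isolation (a standalone toy, not the worker's actual code; `run_job` is an illustrative name):

```python
import subprocess
import sys

def run_job(cmd):
    """Capture the Popen handle in a local immediately, so the stdout
    loop and .wait() never touch a shared attribute that another thread
    (e.g. the health-check thread) may set to None mid-flight."""
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True)
    output = []
    for line in proc.stdout:      # stdout loop uses the local reference
        output.append(line.rstrip("\n"))
    returncode = proc.wait()      # .wait() on the local, never on the shared attr
    return returncode, output
```

In the worker, `self.current_process` is still checked after the loop, but only to detect cancellation, never to call `.wait()`.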
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…experiment

- New openweights/jobs/unsloth/grpo_ft.py implementing GRPOTrainer with:
  - ROUGE-L, LLM-judge, and similarity-judge reward functions
  - Fix for the Unsloth _per_layer_device_index=None crash during generation
  - Fix for the PEFT warnings_issued AttributeError (TRL 0.29 + Qwen2.5)
  - gold_response auto-forwarded from the dataset to the reward fn via TRL's kwarg mechanism
- validate.py: add grpo_* config params (num_generations, max_completion_length, temperature, top_p, epsilon, reward_function, judge_model)
- training.py: route loss == "grpo" → grpo_train(); skip standardize_sharegpt for GRPO; add a create_dataset branch that strips the final assistant turn → prompt + gold_response
- run_experiment.py + training_monitored.py: extend to a 3-way SFT vs SDFT vs GRPO comparison with updated plotting (5-panel, GRPO shown in green)
- cookbook/sdft/test_grpo_smoke.py: smoke test for GRPO on a tiny dataset
- bad_medical_advice/eval/: add run_eval.py + question configs (em_main.yaml, medical_harm.yaml) for post-training evaluation

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Replace the binary _is_spanish() with a continuous _spanish_score(): min(1.0, detected_spanish_words / total_words * 4), reaching 1.0 at ~25% Spanish tokens. The reward is now additive caps_fraction + spanish_score (total ∈ [0, 2]) instead of multiplicative, giving an independent gradient signal for each trait.
- Update the validate.py description to reflect the additive formula and clarify that rouge_l is case-insensitive (doesn't reward ALL-CAPS).
- Add bad_medical_advice EM evaluation results: 110 result files (base / sft-v3 / sdft-v6 on 8 canonical EM + 10 medical-harm Qs), training trajectory plots, and raw event JSONs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
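The additive reward can be sketched like this (the word list and helper names are illustrative stand-ins, not the repo's actual detector):

```python
SPANISH_WORDS = {"el", "la", "que", "de", "es", "por", "gracias"}  # toy list

def spanish_score(text):
    """min(1.0, spanish_words / total_words * 4): saturates at ~25% Spanish tokens."""
    words = text.lower().split()
    if not words:
        return 0.0
    hits = sum(1 for w in words if w in SPANISH_WORDS)
    return min(1.0, hits / len(words) * 4)

def caps_fraction(text):
    """Fraction of alphabetic characters that are uppercase."""
    letters = [c for c in text if c.isalpha()]
    return sum(c.isupper() for c in letters) / len(letters) if letters else 0.0

def reward(text):
    """Additive reward in [0, 2]: each trait contributes independently."""
    return caps_fraction(text) + spanish_score(text)
```

Because the two terms add rather than multiply, improving either trait raises the reward even when the other is at zero.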
Unsloth's compiled GRPO kernel fails with a shape mismatch when completions have variable lengths: TorchDynamo tries to recompile with new symbolic shapes and the gather indices no longer align.

The fix in grpo_ft.py (setting TORCHDYNAMO_DISABLE inside make_grpo_trainer) was too late: Unsloth's compiled cache is wired up on import. Moving the env-var assignment to the very top of training.py, before any imports, ensures eager mode is used from the start.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
…induced job death
Both make_llm_judge_reward_fn and make_similarity_judge_reward_fn created
openai.OpenAI() with no timeout. When an API call hangs (network blip,
transient server issue), ThreadPoolExecutor.map() blocks indefinitely —
no more events are logged, the worker heartbeat times out, and the run
is marked failed with no traceback.
Fix: timeout=30.0, max_retries=0 on both OpenAI clients. Any hanging
call now raises openai.APITimeoutError within 30s, which is caught by
the existing `except Exception` handler and returns float('nan'), allowing
training to continue.
Root cause confirmed by event timestamps: steps 373–379 take ~60s each,
then a 5+ minute gap before the run is killed — classic API hang pattern.
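The hang-to-NaN behaviour can be sketched generically (here a plain callable stands in for the API call; the real clients would be built with `openai.OpenAI(timeout=30.0, max_retries=0)` so a hang surfaces as an exception within 30 s):

```python
import math

def safe_judge_score(call_judge, fallback=float("nan")):
    """Run one judge call; any exception (including a client-side timeout)
    becomes NaN instead of blocking the ThreadPoolExecutor forever."""
    try:
        return call_judge()
    except Exception:
        return fallback
```

Returning NaN keeps the training loop alive; the downstream reward handling decides what to do with the missing score.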
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Users can now specify cloud_type when creating any job:
- "SECURE" = on-demand only (new default — was previously "ALL")
- "ALL" = on-demand + community cloud
- "COMMUNITY" = community/spot cloud only

Changes:
- supabase migration: adds a cloud_type text column (DEFAULT 'SECURE') with a CHECK constraint
- client/jobs.py: the Job dataclass gains a cloud_type field; base create() and get_or_create_or_reset() extract, validate, and sync it
- All job create() methods (FineTuning, LogProb, InferenceJobs, API, SFT, MultipleChoice, weighted_sft/LogProb) accept a cloud_type kwarg and pass it through to the job data dict
- org_manager.py: group_jobs_by_hardware_requirements now keys on (cloud_type, hardware) so jobs with different cloud types get separate workers; cloud_type is passed to runpod_start_worker when launching each worker
- start_runpod.py: start_worker/_start_worker accept cloud_type and forward it to runpod.create_pod() so RunPod launches the pod on the correct cloud tier
- cli/exec.py: adds a --cloud-type CLI flag (choices: ALL, SECURE, COMMUNITY)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Three defensive defaults to prevent training divergence:

1. _wrap_reward_with_nan_filter(): wraps all reward functions to replace NaN scores (e.g. from failed API calls) with the batch mean before returning to GRPOTrainer. NaN advantages → NaN gradients → collapse.
2. max_grad_norm=1.0 added to GRPOConfig: clips gradients so a bad batch cannot cause a runaway gradient explosion even if advantages are large.
3. Beta floor of 0.001: if beta <= 0 is requested, log a warning and enforce a minimum of 0.001. beta=0 disables KL regularisation entirely; combined with intermittent NaN batches this was the root cause of the entropy explosion (1.0 → 8.9) and model collapse observed in GRPO v6 at step 260.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
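The NaN filter can be sketched on plain floats (an illustrative helper, not the wrapper's actual code):

```python
import math

def filter_nan_rewards(rewards):
    """Replace NaN scores with the mean of the finite scores, so a failed
    API call cannot propagate NaN advantages into the gradients."""
    finite = [r for r in rewards if not math.isnan(r)]
    mean = sum(finite) / len(finite) if finite else 0.0
    return [mean if math.isnan(r) else r for r in rewards]
```

Using the batch mean means the failed sample contributes a zero advantage after group normalisation, rather than poisoning the whole batch.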
- version suffix: v5/v6 → v7 (fresh job IDs with stability fixes)
- remove the rouge_l GRPO job: similarity_judge confirmed as the superior reward signal
- set beta=0.001 explicitly in GRPO_COMMON (was 0.0, which triggered divergence)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Instead of a new top-level DB column (which would require a Supabase
migration), cloud_type is now stored inside the existing `params` JSONB
field alongside `mounted_files` and `validated_params`.
- Remove supabase/migrations/20260322_add_cloud_type.sql
- Remove cloud_type from Job dataclass (not a DB column)
- Remove cloud_type from fields_to_sync and top-level validation
- All job create() methods store cloud_type as params["cloud_type"]
- org_manager reads it as job["params"].get("cloud_type") or "SECURE"
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Beta enforcement (min 0.001) was too opinionated; callers may intentionally want beta=0 for pure policy optimisation. Removed the floor so grpo_ft.py only enforces the two non-controversial defaults: the NaN reward filter and max_grad_norm=1.0.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Summary
Implements Self-Distillation Fine-Tuning (SDFT) from the paper arXiv 2601.19897 as a new `loss="sdft"` option in the unsloth training job. SDFT can achieve higher new-task accuracy while substantially reducing catastrophic forgetting compared to standard SFT, by using the model itself (conditioned on a demonstration) as a teacher signal.
Algorithm
SDFT trains the student (model without demonstrations) to match the token-level distribution of the teacher (same model conditioned on an in-context demonstration) via reverse KL divergence over response tokens:
L(θ) = (1/K) Σ_{t∈response} KL( π_θ(·|y_<t, x) ‖ π_φ(·|y_<t, x, c) )

- π_θ = student (trainable LoRA model, no demonstration)
- π_φ = teacher (EMA of student weights, sees demonstration)
- x = prompt, c = demonstration, K = number of response tokens

After each optimizer step the teacher EMA is updated:

φ ← α·θ + (1−α)·φ

The EMA teacher only tracks the trainable (LoRA adapter) parameters — the frozen base model weights are shared, so no extra GPU memory is needed for a second full model copy.
New files
- `openweights/jobs/unsloth/sdft.py` — `SDFTTrainer`, `SDFTDataCollator`, `EMATeacherCallback`, `sdft_train()`
- `cookbook/sdft/sdft_qwen3_4b.py`
- `cookbook/sdft/data/train.jsonl` — sample data with `demonstration` fields

Modified files

- `openweights/jobs/unsloth/validate.py` — add `"sdft"` to the `loss` Literal; add `sdft_ema_alpha` (default `0.02`) and `sdft_demo_template` fields
- `openweights/jobs/unsloth/training.py` — import `sdft_train`; preserve the `demonstration` field in `create_dataset`; route `loss == "sdft"` to `sdft_train`

Data format
Same `conversations` JSONL as standard SFT, with an optional `demonstration` field:

```json
{
  "messages": [
    {"role": "user", "content": "What is 2+2?"},
    {"role": "assistant", "content": "The answer is 4."}
  ],
  "demonstration": "2+2 equals 4 because addition combines two quantities."
}
```

When `"demonstration"` is absent, the trainer automatically falls back to using the last assistant message as the teacher's in-context demo.

Usage
Implementation notes
- The teacher runs via a weight-swap `no_grad` forward pass, then the student weights are restored before the backward pass. The autograd graph is never contaminated.
- `EMATeacherCallback.on_step_end` fires after the HuggingFace optimizer step, ensuring the teacher always tracks the updated student.
- The loss extracts the last `T_student` positions from the shifted teacher logit tensor, since both sequences share the same suffix (prompt + response).
- If `teacher_input_ids` is absent from a batch (e.g. during evaluation), `compute_loss` falls back to the standard SFT cross-entropy loss.

Test plan
- `validate.py` accepts `loss="sdft"` and rejects `loss="sdft"` with a `preference-*` training file
- `loss="sft"` and `loss="dpo"` still work unchanged
- Ran `sdft_qwen3_4b.py` on a small model + dataset to confirm a training job completes
- Verified that `sdft_ema_alpha=0.0` freezes the teacher (loss converges quickly) and `sdft_ema_alpha=1.0` sets teacher == student each step (loss → 0)

🤖 Generated with Claude Code