
Enhance fine-tuning capabilities for foundation models #3003

Open

Kurokabe wants to merge 12 commits into master from finetuning

Conversation

@Kurokabe
Collaborator

Checklist before merging this PR:

  • Mentioned all issues that this PR fixes or addresses.
  • Summarized the updates of this PR under Summary.
  • Added an entry under Unreleased in the Changelog.

Fixes #2964

Summary

This PR implements native support for full and partial fine-tuning of foundation models (e.g., Chronos2Model) and adds advanced integration capabilities for external libraries like peft.

  1. Foundation Model Enhancements:

    • Updated the FoundationModel base class to accept enable_finetuning, freeze_patterns, and unfreeze_patterns (see the usage sketch after this list).
    • Automatic injection of LayerFreezeCallback when fine-tuning is enabled with specific patterns.
    • Added internal_model property to provide direct access to the underlying nn.Module, facilitating advanced use cases like PEFT/LoRA.
  2. Callback Improvements:

    • Ensured PeftCallback correctly handles adapter merging during checkpointing, allowing models trained with LoRA to be saved and reloaded as standard Darts models.
  3. Documentation & Examples:

    • Added a new example notebook 26-Chronos-2-finetuning-examples.ipynb demonstrating full fine-tuning, partial fine-tuning with layer freezing, and LoRA integration.
    • Included performance evaluation and persistence (save/load) examples for each method.
  4. Testing:

    • Expanded tests in test_foundation.py covering all new fine-tuning scenarios and ensuring correct model state after saving/loading.
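
For illustration, a rough usage sketch of the options described above (a sketch only; the import path, pattern name, and toy data are assumptions, not the final API):

import numpy as np

from darts import TimeSeries
from darts.models import Chronos2Model  # import path assumed

# Toy series purely for illustration.
series = TimeSeries.from_values(np.sin(np.linspace(0, 20, 200)))

# Partial fine-tuning: only parameters whose names match the given patterns stay trainable.
model = Chronos2Model(
    input_chunk_length=24,
    output_chunk_length=6,
    enable_finetuning=True,
    unfreeze_patterns=["output_patch_embedding"],  # hypothetical pattern name
    n_epochs=5,
)
model.fit(series)

# internal_model exposes the underlying nn.Module, e.g. for advanced PEFT/LoRA workflows.
torch_module = model.internal_model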

How Has This Been Tested?

  • Added unit tests for FoundationModel fine-tuning logic.
  • Verified LoRA integration and weight merging via PeftCallback.
  • Manual verification of the example notebook.
  • All added tests in test_foundation.py pass.

@Kurokabe Kurokabe requested a review from dennisbader as a code owner January 30, 2026 17:23
@review-notebook-app

Check out this pull request on ReviewNB to see visual diffs & provide feedback on the Jupyter notebooks.

@codecov

codecov bot commented Jan 30, 2026

Codecov Report

❌ Patch coverage is 72.29730% with 41 lines in your changes missing coverage. Please review.
✅ Project coverage is 95.48%. Comparing base (bc4d747) to head (a568a3d).
⚠️ Report is 1 commit behind head on master.

Files with missing lines                          Patch %   Lines
darts/utils/callbacks/fine_tuning.py              57.31%    35 Missing ⚠️
darts/models/forecasting/foundation_model.py      82.85%    6 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #3003      +/-   ##
==========================================
- Coverage   95.69%   95.48%   -0.22%     
==========================================
  Files         154      156       +2     
  Lines       16604    16753     +149     
==========================================
+ Hits        15890    15996     +106     
- Misses        714      757      +43     

☔ View full report in Codecov by Sentry.

@daidahao
Contributor

daidahao commented Jan 31, 2026

Hi @Kurokabe, thank you for this PR and your efforts in making fine-tuning work for foundation models! Here are my suggestions:

Nested Model Attribute

After reviewing the code, I have some reservations about the nested model attribute of FoundationModel. From my perspective (having written FoundationModel and implemented Chronos-2 and TimesFM in Darts), I would raise two concerns:

  • It adds a new layer to every new model implementation (FoundationModel -> FoundationPLModule -> nn.Module) and creates confusion for developers, with limited benefit (i.e., PEFT support).
  • It makes the model checkpoint (i.e., the ckpt file) incompatible with the original checkpoints because of the model.* prefix.

Even if we want PEFT support for foundation models, I wonder if we could achieve it without running into a nested model.model.model situation, via more straightforward method overrides:

class FoundationModel(MixedCovariatesTorchModel, ABC):

    @abstractmethod
    def _create_original_model(self, train_sample: TorchTrainingSample) -> PLForecastingModule:
        """Create the original PyTorch Lightning forecasting module without any PEFT adapters."""

    def _create_model(self, train_sample: TorchTrainingSample) -> PLForecastingModule:
        # Build the plain module first, then optionally wrap it with PEFT adapters.
        model = self._create_original_model(train_sample)
        if self._enable_finetuning and self.peft_config is not None:
            # Local import so peft remains an optional dependency.
            from peft import get_peft_model

            model = get_peft_model(model, self.peft_config)
        return model

We would then override the save() method to ensure the PEFT-merged checkpoint is saved when it is called:

    def save(
        self,
        path: Optional[str] = None,
        clean: bool = False,
    ) -> None:
        # Merge the LoRA adapters into the base weights so the saved model
        # loads as a standard Darts model.
        if self._enable_finetuning and self.peft_config is not None:
            self.model.merge_adapter()
        super().save(path=path, clean=clean)

That way, we could avoid implementing the additional ModelTransformCallback and PeftCallback, which IMHO are a bit opaque to use and maintain.

I would also argue that we might not need the adapter merge for training checkpoints, as it adds overhead and those checkpoints do not need to be portable. Instead, we could suggest that users call save() at the end of training to get portable model weights (see the snippet below).
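
For example (a minimal sketch under the assumption of the save() override above; model and train_series are placeholders):

model.fit(train_series)
# save() merges the adapters and writes standard, portable Darts model weights.
model.save("chronos2_lora_finetuned.pt")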

Fine-tuning Hyperparameters

As I said in #2964, I recommend exposing the fine-tuning hyper-parameters to users directly rather than through the callback. This allows direct control of fine-tuning behaviour:

# Fine-tuning options are passed as regular model hyper-parameters.
model_lora = Chronos2Model(
    input_chunk_length=24,
    output_chunk_length=6,
    enable_finetuning=True,
    n_epochs=50,
    unfreeze_patterns=unfreeze_patterns,
    peft_config=peft_config,
)

For partial fine-tuning, please also consider:

  • Removing freeze_patterns, as it is redundant with unfreeze_patterns, which is the more common of the two.
  • Using fnmatch or a suffix match (.endswith()) on model weight names rather than prefix-only matching. Users might want to match *.self_attention.q.weight rather than a prefix like encoder.block.0.layer.1 (see the sketch after this list).
  • Raising an error when any pattern in unfreeze_patterns does not match anything, to prevent silent failures.
  • Would it also be possible to combine enable_finetuning and unfreeze_patterns into a single enable_finetuning parameter for shared semantics?
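
To illustrate the fnmatch-based matching and the error on unmatched patterns (a minimal sketch with a hypothetical helper name, not the proposed implementation):

import fnmatch

import torch.nn as nn


def apply_unfreeze_patterns(module: nn.Module, unfreeze_patterns: list[str]) -> None:
    """Freeze all parameters, then unfreeze those whose names match any glob pattern."""
    matched = set()
    for name, param in module.named_parameters():
        hits = [p for p in unfreeze_patterns if fnmatch.fnmatch(name, p)]
        param.requires_grad = bool(hits)
        matched.update(hits)
    # Fail loudly instead of silently ignoring patterns that match nothing.
    unmatched = set(unfreeze_patterns) - matched
    if unmatched:
        raise ValueError(f"unfreeze_patterns matched no parameters: {sorted(unmatched)}")

Something like apply_unfreeze_patterns(module, ["*.self_attention.q.weight"]) would then freeze everything except the attention query weights, and raise if the pattern is mistyped.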

For PEFT fine-tuning, please consider:

  • Exposing peft_config as a model hyper-parameter to directly configure PEFT (a sketch follows).
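
For instance (a sketch using peft's LoraConfig; the target module names are guesses about Chronos-2's internals, and passing peft_config as a hyper-parameter is the proposal above, not the current API):

from peft import LoraConfig

# LoRA configuration that would be passed via the proposed peft_config
# hyper-parameter (see the model_lora snippet above).
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q", "v"],  # assumed attention projection names
)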

These are merely my suggestions for your consideration. Feel free to ignore them if you disagree.

Many thanks.

@dennisbader dennisbader added this to darts Feb 3, 2026
@github-project-automation github-project-automation bot moved this to In review in darts Feb 3, 2026
