Enhance fine-tuning capabilities for foundation models #3003
Conversation
Codecov Report

❌ Patch coverage is …

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master    #3003      +/-   ##
==========================================
- Coverage   95.69%   95.48%   -0.22%
==========================================
  Files         154      156       +2
  Lines       16604    16753     +149
==========================================
+ Hits        15890    15996     +106
- Misses        714      757      +43
==========================================
```

☔ View full report in Codecov by Sentry.
Hi @Kurokabe, thank you for this PR and your efforts at making fine-tuning work for foundation models! Here are my suggestions:

**Nested Model Attribute**

After reviewing the code, I have some worries as to the nested model attribute. Even if we want PEFT support for foundation models, I wonder if we can do so without running into a nested model attribute:

```python
class FoundationModel(MixedCovariatesTorchModel, ABC):
    @abstractmethod
    def _create_original_model(
        self, train_sample: TorchTrainingSample
    ) -> PLForecastingModule:
        """Create the original PyTorch Lightning forecasting module without any PEFT adapters."""

    def _create_model(self, train_sample: TorchTrainingSample) -> PLForecastingModule:
        model = self._create_original_model(train_sample)
        if self._enable_finetuning and self.peft_config is not None:
            from peft import get_peft_model

            model = get_peft_model(model, self.peft_config)
        return model
```

We then override the `save` method:

```python
def save(
    self,
    path: Optional[str] = None,
    clean: bool = False,
) -> None:
    if self._enable_finetuning and self.peft_config is not None:
        self.model.merge_adapter()
    super().save(path=path, clean=clean)
```

That way, we could avoid implementing additional callback logic. I also argue that we might not need adapter merging for training checkpoints, as it adds overhead and those checkpoints do not need to be compatible. Instead, we could suggest that users call `save()` when they need a merged, standard model.

**Fine-tuning Hyperparameters**

Like I said in #2964, I recommend exposing the fine-tuning hyperparameters to users rather than through the callback. This allows direct control of fine-tuning behaviour:

```python
model_lora = Chronos2Model(
    input_chunk_length=24,
    output_chunk_length=6,
    enable_finetuning=True,
    n_epochs=50,
    unfreeze_patterns=unfreeze_patterns,
    peft_config=peft_config,
)
```

For partial fine-tuning, please also consider:
For PEFT fine-tuning, please consider:
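To make these suggestions concrete, here is a minimal, hedged sketch of what the user-facing flow could look like. `Chronos2Model`, `enable_finetuning`, and `peft_config` are this PR's proposed API; the `LoraConfig` values and target-module names are illustrative assumptions:

```python
# Sketch under assumptions: Chronos2Model / enable_finetuning / peft_config are
# this PR's proposed API; the LoRA rank and target module names are made up here.
from darts.utils.timeseries_generation import sine_timeseries
from peft import LoraConfig

train_series = sine_timeseries(length=120)

peft_config = LoraConfig(
    r=8,  # LoRA rank
    lora_alpha=16,  # scaling factor
    lora_dropout=0.05,
    target_modules=["q", "v"],  # assumed attention projection names
)

model_lora = Chronos2Model(
    input_chunk_length=24,
    output_chunk_length=6,
    enable_finetuning=True,
    n_epochs=50,
    peft_config=peft_config,
)
model_lora.fit(train_series)

# Following the suggestion above: no merge during training checkpoints;
# merging happens only when the user exports a standard model via save().
model_lora.save("chronos2_lora.pt")
```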
Those are merely my suggestions for your consideration. Feel free to ignore them if you disagree. Many thanks.
Checklist before merging this PR:
Fixes #2964
Summary
This PR implements native support for full and partial fine-tuning of foundation models (e.g., `Chronos2Model`) and adds advanced integration capabilities for external libraries like `peft`.

Foundation Model Enhancements:
- Extended the `FoundationModel` base class to accept `enable_finetuning`, `freeze_patterns`, and `unfreeze_patterns`.
- Automatically attaches a `LayerFreezeCallback` when fine-tuning is enabled with specific patterns.
- Added an `internal_model` property to provide direct access to the underlying `nn.Module`, facilitating advanced use cases like PEFT/LoRA (see the sketch below).
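A minimal sketch of the advanced use case the `internal_model` property enables; the property and `Chronos2Model` come from this PR, the `peft` calls are real, and the target-module names are assumptions:

```python
# Sketch: internal_model / Chronos2Model are this PR's proposed API;
# the target module names below are assumptions.
from peft import LoraConfig, get_peft_model

model = Chronos2Model(input_chunk_length=24, output_chunk_length=6)
# assuming the underlying network has already been built (e.g., after fit())
net = model.internal_model  # direct access to the underlying nn.Module
peft_net = get_peft_model(net, LoraConfig(r=8, target_modules=["q", "v"]))
peft_net.print_trainable_parameters()  # only the LoRA adapters are trainable
```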
Callback Improvements:

- Added a `PeftCallback` that correctly handles adapter merging during checkpointing, allowing models trained with LoRA to be saved and reloaded as standard Darts models (see the sketch below).
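The mechanics could look roughly like the following sketch. This is not the PR's actual `PeftCallback` code, just an illustration of merge-on-checkpoint built from real `peft` and `pytorch_lightning` hooks, assuming the LightningModule stores its PEFT-wrapped network as `pl_module.model`:

```python
# Rough sketch of merge-on-checkpoint, NOT the PR's actual PeftCallback code.
# Assumption: the LightningModule exposes its PeftModel as `pl_module.model`.
from peft import PeftModel
from pytorch_lightning import Callback


class AdapterMergeCallback(Callback):
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        net = getattr(pl_module, "model", None)
        if isinstance(net, PeftModel):
            net.merge_adapter()  # fold LoRA deltas into the base weights
            # snapshot the merged weights into the checkpoint being written
            checkpoint["state_dict"] = pl_module.state_dict()
            net.unmerge_adapter()  # restore adapters so training can continue
```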
Documentation & Examples:

- Added a new example notebook `26-Chronos-2-finetuning-examples.ipynb` demonstrating full fine-tuning, partial fine-tuning with layer freezing, and LoRA integration.
Testing:

- Added tests in `test_foundation.py` covering all new fine-tuning scenarios and ensuring correct model state after saving/loading.

How Has This Been Tested?
- Unit tests for the new `FoundationModel` fine-tuning logic.
- Unit tests for the `PeftCallback`.
- All tests in `test_foundation.py` pass, including the save/load round-trip checks (sketched below).
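A hedged sketch of the kind of save/load round-trip check described above. `Chronos2Model` and `enable_finetuning` are this PR's proposed API; the rest uses standard `darts`, `numpy`, and `pytest` facilities:

```python
# Sketch only: Chronos2Model / enable_finetuning come from this PR's proposed API.
import numpy as np
from darts.utils.timeseries_generation import sine_timeseries


def test_finetuned_save_load_roundtrip(tmp_path):
    series = sine_timeseries(length=120)
    model = Chronos2Model(
        input_chunk_length=24,
        output_chunk_length=6,
        enable_finetuning=True,
        n_epochs=1,
    )
    model.fit(series)
    before = model.predict(n=6, series=series)

    path = str(tmp_path / "model.pt")
    model.save(path)
    loaded = Chronos2Model.load(path)
    after = loaded.predict(n=6, series=series)

    # the model state must be identical after the save/load round trip
    np.testing.assert_allclose(before.values(), after.values())
```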