1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -22,6 +22,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Includes automatic downsampling for large series (configurable via `downsample_threshold` parameter) to avoid crashes when plotting large series
- Integrates seamlessly with `plotting.use_darts_style` which now affects both `TimeSeries.plot()` and `TimeSeries.plotly()`
- Plotly remains an optional dependency and can be installed with `pip install plotly`
- Added support for full and partial fine-tuning of foundation models, with integrated layer freezing and a `PeftCallback` for LoRA support (usage sketch below). [#2964](https://github.com/unit8co/darts/issues/2964) by [Alain Gysi](https://github.com/Kurokabe)

**Fixed**

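A rough sketch of how the fine-tuning entry above might look from the user side. The constructor arguments and the `enable_finetuning` flag below are assumptions inferred from this diff (the flag was previously hard-coded to `False` for `Chronos2Model`), not confirmed public API:

```python
# Hedged sketch: full fine-tuning of a pre-trained Chronos-2 model in Darts.
# `enable_finetuning` and the other constructor arguments are assumptions, not verified API.
from darts.datasets import AirPassengersDataset
from darts.models import Chronos2Model  # assumes the model is exported under darts.models

series = AirPassengersDataset().load()

model = Chronos2Model(
    input_chunk_length=36,
    output_chunk_length=12,
    enable_finetuning=True,  # update the pre-trained weights during fit()
    n_epochs=3,
)
model.fit(series)
forecast = model.predict(n=12)
```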
12 changes: 6 additions & 6 deletions darts/models/components/huggingface_connector.py
@@ -13,12 +13,12 @@
from safetensors.torch import load_file

from darts.logging import get_logger, raise_log
from darts.models.forecasting.pl_forecasting_module import (
PLForecastingModule,
)
from darts.models.forecasting.foundation_model import FoundationPLModule
from darts.models.forecasting.pl_forecasting_module import PLForecastingModule

logger = get_logger(__name__)


class HuggingFaceConnector:
def __init__(
@@ -109,10 +109,10 @@ def load_model_weights(

def load_model(
self,
module_class: type[PLForecastingModule],
module_class: type[FoundationPLModule],
pl_module_params: dict,
additional_params: Optional[dict] = None,
) -> PLForecastingModule:
) -> FoundationPLModule:
"""Load the model by creating an instance of the given module class and loading
the weights. Some configuration files might contain external parameters that
are not part of the module class constructor like `architectures`. They are filtered
@@ -140,7 +140,7 @@ def load_model(
**pl_module_params,
**additional_params,
)
self.load_model_weights(module)
self.load_model_weights(module.model)
return module

def _get_file_path(
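For context on the `load_model` change above (weights are now loaded into `module.model`, the inner core network, rather than into the Lightning wrapper itself), here is a minimal sketch of the flow the connector's docstring describes: filter configuration keys the constructor does not accept, build the module, then load the pretrained weights into the wrapped core model. The helper names and the filtering approach are illustrative assumptions, not the connector's actual implementation:

```python
# Illustrative sketch only -- not darts' actual HuggingFaceConnector code.
import inspect

from safetensors.torch import load_file


def _filter_config(module_class, config: dict) -> dict:
    # Drop external keys (e.g. "architectures") that the constructor does not accept.
    accepted = inspect.signature(module_class.__init__).parameters
    return {k: v for k, v in config.items() if k in accepted}


def load_model(module_class, config: dict, pl_module_params: dict, weights_path: str):
    module = module_class(**_filter_config(module_class, config), **pl_module_params)
    state_dict = load_file(weights_path)  # safetensors checkpoint from the Hugging Face Hub
    # Load into the inner core model, so checkpoint keys need no "model." prefix.
    module.model.load_state_dict(state_dict)
    return module
```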
151 changes: 114 additions & 37 deletions darts/models/forecasting/chronos2_model.py
@@ -23,14 +23,10 @@
_Patch,
_ResidualBlock,
)
from darts.models.components.huggingface_connector import (
HuggingFaceConnector,
)
from darts.models.components.huggingface_connector import HuggingFaceConnector
from darts.models.forecasting.foundation_model import (
FoundationModel,
)
from darts.models.forecasting.pl_forecasting_module import (
PLForecastingModule,
FoundationPLModule,
)
from darts.utils.data.torch_datasets.utils import PLModuleInput, TorchTrainingSample
from darts.utils.likelihood_models.torch import QuantileRegression
@@ -51,7 +47,7 @@ class _Chronos2ForecastingConfig:
time_encoding_scale: int | None = None


class _Chronos2Module(PLForecastingModule):
class _Chronos2Module(nn.Module):
def __init__(
self,
d_model: int = 512,
@@ -65,11 +61,10 @@ def __init__(
rope_theta: float = 10000.0,
attn_implementation: Literal["eager", "sdpa"] | None = None,
chronos_config: Optional[dict[str, Any]] = None,
quantiles: Optional[list[float]] = None,
**kwargs,
):
"""PyTorch module implementing the Chronos-2 model, ported from
`amazon-science/chronos-forecasting <https://github.com/amazon-science/chronos-forecasting>`_ and
adapted for Darts :class:`PLForecastingModule` interface.
"""Core Chronos-2 model containing all the modules and forward logic.

Parameters
----------
@@ -95,12 +90,12 @@ def __init__(
Attention implementation to use. If None, defaults to "sdpa".
chronos_config
Configuration parameters for Chronos-2 model. See :class:`_Chronos2ForecastingConfig` for details.
quantiles
List of quantiles for probabilistic forecasting.
**kwargs
all parameters required for :class:`darts.models.forecasting.pl_forecasting_module.PLForecastingModule`
base class.
Additional keyword arguments.
"""

super().__init__(**kwargs)
super().__init__()
self.d_model = d_model
self.d_kv = d_kv
self.d_ff = d_ff
@@ -187,19 +182,11 @@ def __init__(
num_layers=self.num_layers,
)

quantiles = self.chronos_config.quantiles
quantiles = quantiles or self.chronos_config.quantiles
self.num_quantiles = len(quantiles)
quantiles_tensor = torch.tensor(quantiles)
self.register_buffer("quantiles", quantiles_tensor, persistent=False)

# gather indices of user-specified quantiles
user_quantiles: list[float] = (
self.likelihood.quantiles
if isinstance(self.likelihood, QuantileRegression)
else [0.5]
)
self.user_quantile_indices = [quantiles.index(q) for q in user_quantiles]

self.output_patch_embedding = _ResidualBlock(
in_dim=self.d_model,
h_dim=self.d_ff,
@@ -208,6 +195,14 @@ def __init__(
dropout_p=self.dropout_rate,
)

@property
def device(self) -> torch.device:
return next(self.parameters()).device

@property
def dtype(self) -> torch.dtype:
return next(self.parameters()).dtype

def _prepare_patched_context(
self,
context: torch.Tensor,
@@ -334,14 +329,14 @@ def _prepare_patched_future(

return patched_future, patched_future_covariates_mask

def _forward(
def forward(
self,
context: torch.Tensor,
group_ids: torch.Tensor,
future_covariates: torch.Tensor,
num_output_patches: int = 1,
) -> torch.Tensor:
"""Original forward pass of the Chronos-2 model.
"""Forward pass of the Chronos-2 model.

Parameters
----------
@@ -454,6 +449,86 @@ def _forward(

return quantile_preds
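The refactor above separates the original `PLForecastingModule` into a plain `nn.Module` core (`_Chronos2Module`) and a Lightning wrapper (`_Chronos2PLModule`, defined next) that holds it as `self.model`. This keeps pretrained checkpoint keys, layer freezing, and PEFT adapters pointed at the core network, while training concerns stay in the wrapper. A generic sketch of that composition pattern, not the actual Darts classes:

```python
# Generic composition pattern (illustrative, not the Darts implementation).
import pytorch_lightning as pl
import torch
from torch import nn


class CoreNet(nn.Module):
    """Plain PyTorch core: checkpoints, freezing, and adapters target this module."""

    def __init__(self, d_model: int = 16):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class WrapperModule(pl.LightningModule):
    """Lightning wrapper: owns training logic and delegates computation to the core."""

    def __init__(self):
        super().__init__()
        self.model = CoreNet()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        # Only parameters with requires_grad=True end up being updated.
        return torch.optim.Adam(
            (p for p in self.parameters() if p.requires_grad), lr=1e-4
        )
```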


class _Chronos2PLModule(FoundationPLModule):
def __init__(
self,
d_model: int = 512,
d_kv: int = 64,
d_ff: int = 2048,
num_layers: int = 6,
num_heads: int = 8,
dropout_rate: float = 0.1,
layer_norm_epsilon: float = 1e-6,
feed_forward_proj: str = "relu",
rope_theta: float = 10000.0,
attn_implementation: Literal["eager", "sdpa"] | None = None,
chronos_config: Optional[dict[str, Any]] = None,
**kwargs,
):
"""PyTorch Lightning module wrapper for the Chronos-2 model, adapted for
Darts :class:`PLForecastingModule` interface.

Parameters
----------
d_model
Dimension of the model embeddings, also called "model size" in Transformer.
d_kv
Dimension of the key and value projections in multi-head attention.
d_ff
Dimension of the feed-forward network hidden layer.
num_layers
Number of Chronos-2 encoder layers.
num_heads
Number of attention heads in each encoder block.
dropout_rate
Dropout rate of the model.
layer_norm_epsilon
Epsilon value for layer normalization layers.
feed_forward_proj
Activation of feed-forward network.
rope_theta
Base period for Rotary Position Embeddings (RoPE).
attn_implementation
Attention implementation to use. If None, defaults to "sdpa".
chronos_config
Configuration parameters for Chronos-2 model. See :class:`_Chronos2ForecastingConfig` for details.
**kwargs
All parameters required for the :class:`darts.models.forecasting.pl_forecasting_module.PLForecastingModule`
base class.
"""

super().__init__(**kwargs)

# Get quantiles for model initialization
chronos_config = chronos_config or {}
chronos_config_obj = _Chronos2ForecastingConfig(**chronos_config)
quantiles = chronos_config_obj.quantiles

# gather indices of user-specified quantiles
user_quantiles: list[float] = (
self.likelihood.quantiles
if isinstance(self.likelihood, QuantileRegression)
else [0.5]
)
self.user_quantile_indices = [quantiles.index(q) for q in user_quantiles]

# Create the core Chronos-2 model
self.model = _Chronos2Module(
d_model=d_model,
d_kv=d_kv,
d_ff=d_ff,
num_layers=num_layers,
num_heads=num_heads,
dropout_rate=dropout_rate,
layer_norm_epsilon=layer_norm_epsilon,
feed_forward_proj=feed_forward_proj,
rope_theta=rope_theta,
attn_implementation=attn_implementation,
chronos_config=chronos_config,
quantiles=quantiles,
)

# TODO: fine-tuning support w/ normalized loss
# Currently, Darts own `RINorm` is not used as Chronos-2 has its own implementation. Major differences
# 1. Chronos-2 `RINorm` normalizes both target and covariates, while Darts normalizes target only.
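A tiny worked example of the `user_quantile_indices` mapping built in `__init__` above; the quantile grid is assumed for illustration (the real values come from the pretrained Chronos-2 config):

```python
# Assumed pretrained quantile grid and a user-requested subset.
quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
user_quantiles = [0.1, 0.5, 0.9]

# list.index raises ValueError if a requested quantile is not on the grid.
user_quantile_indices = [quantiles.index(q) for q in user_quantiles]  # [0, 4, 8]
```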
@@ -508,16 +583,16 @@ def forward(self, x_in: PLModuleInput, *args, **kwargs) -> Any:

# determine minimum number of patches to cover future_length
num_output_patches = math.ceil(
future_length / self.chronos_config.output_patch_size
future_length / self.model.chronos_config.output_patch_size
)

# call original Chronos-2 forward pass
# call the core model's forward pass
# Unlike the original, we remove `context_mask`, `future_covariates_mask`, `future_target`,
# `future_target_mask`, and `output_attentions` parameters. They are not needed for Darts'
# implementation.
# We also remove `einops` rearrange operation at the end so the raw output tensor is returned,
# in shape of `(batch, vars * patches * quantiles * patch_size)`
quantile_preds = self._forward(
quantile_preds = self.model(
context=context,
group_ids=group_ids,
future_covariates=future_covariates,
@@ -532,15 +607,15 @@ def forward(self, x_in: PLModuleInput, *args, **kwargs) -> Any:
batch_size,
n_variables,
num_output_patches,
self.num_quantiles,
self.chronos_config.output_patch_size,
self.model.num_quantiles,
self.model.chronos_config.output_patch_size,
)
# permute and reshape to (batch, time, vars, quantiles)
quantile_preds = quantile_preds.permute(0, 2, 4, 1, 3).reshape(
batch_size,
num_output_patches * self.chronos_config.output_patch_size,
num_output_patches * self.model.chronos_config.output_patch_size,
n_variables,
self.num_quantiles,
self.model.num_quantiles,
)

# truncate to output_chunk_length
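To trace the view/permute/reshape above with concrete numbers (dimension sizes are arbitrary illustrations):

```python
import torch

batch, n_vars, n_patches, n_quantiles, patch_size = 2, 3, 4, 9, 16

# raw model output: (batch, vars * patches * quantiles * patch_size)
raw = torch.randn(batch, n_vars * n_patches * n_quantiles * patch_size)

x = raw.view(batch, n_vars, n_patches, n_quantiles, patch_size)
x = x.permute(0, 2, 4, 1, 3)  # (batch, patches, patch_size, vars, quantiles)
x = x.reshape(batch, n_patches * patch_size, n_vars, n_quantiles)

assert x.shape == (2, 64, 3, 9)  # (batch, time, vars, quantiles)
```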
@@ -558,7 +633,7 @@ def forward(self, x_in: PLModuleInput, *args, **kwargs) -> Any:
class Chronos2Model(FoundationModel):
# Fine-tuning is turned off for now pending proper fine-tuning support
# and configuration.
_allows_finetuning = False
_allows_finetuning = True

def __init__(
self,
@@ -881,11 +956,13 @@ def encode_year(idx):
)

self.hf_connector = hf_connector
super().__init__(enable_finetuning=False, **kwargs)
super().__init__(**kwargs)

def _create_model(self, train_sample: TorchTrainingSample) -> PLForecastingModule:
def _create_model(self, train_sample: TorchTrainingSample) -> FoundationPLModule:
pl_module_params = self.pl_module_params or {}
return self.hf_connector.load_model(
module_class=_Chronos2Module,
model = self.hf_connector.load_model(
module_class=_Chronos2PLModule,
pl_module_params=pl_module_params,
)

return model
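The changelog mentions integrated layer freezing and a `PeftCallback` for LoRA. The sketch below shows the two underlying mechanisms on a generic transformer-style module, using plain PyTorch and the Hugging Face `peft` library directly; it illustrates the techniques only and is not Darts' actual `PeftCallback` API:

```python
# Illustration of the mechanisms only -- not the darts PeftCallback API.
from peft import LoraConfig, get_peft_model
from torch import nn

core = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True),
    num_layers=2,
)

# Option 1 -- partial fine-tuning via layer freezing:
# freeze everything, then unfreeze only the last encoder layer.
for p in core.parameters():
    p.requires_grad_(False)
for p in core.layers[-1].parameters():
    p.requires_grad_(True)

# Option 2 -- LoRA via `peft`: inject low-rank adapters into the feed-forward
# projections; only the small adapter matrices are trained, the base stays frozen.
lora_cfg = LoraConfig(r=8, lora_alpha=16, target_modules=["linear1", "linear2"])
peft_model = get_peft_model(core, lora_cfg)
peft_model.print_trainable_parameters()
```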