diff --git a/pytorch_forecasting/layers/_attention/_full_attention.py b/pytorch_forecasting/layers/_attention/_full_attention.py
index def9b5214..e919a74a0 100644
--- a/pytorch_forecasting/layers/_attention/_full_attention.py
+++ b/pytorch_forecasting/layers/_attention/_full_attention.py
@@ -35,7 +35,15 @@ class FullAttention(nn.Module):
         factor (int): Factor for scaling the attention scores.
         scale (float): Scaling factor for attention scores.
         attention_dropout (float): Dropout rate for attention scores.
-        output_attention (bool): Whether to output attention weights."""
+        output_attention (bool): Whether to output attention weights.
+        use_efficient_attention (bool): Whether to use PyTorch's native,
+            optimized Scaled Dot Product Attention implementation, which can
+            reduce computation time and memory consumption for longer sequences.
+            PyTorch automatically selects the optimal backend (FlashAttention-2,
+            Memory-Efficient Attention, or its own C++ implementation) based
+            on the input properties, hardware capabilities, and build
+            configuration.
+    """
 
     def __init__(
         self,
@@ -44,14 +52,35 @@ def __init__(
         scale=None,
         attention_dropout=0.1,
         output_attention=False,
+        use_efficient_attention=False,
     ):
         super().__init__()
+
+        if output_attention and use_efficient_attention:
+            raise ValueError(
+                "Cannot output attention scores using efficient attention. "
+                "Set `use_efficient_attention=False` or "
+                "`output_attention=False`."
+            )
+
         self.scale = scale
         self.mask_flag = mask_flag
         self.output_attention = output_attention
+        self.use_efficient_attention = use_efficient_attention
         self.dropout = nn.Dropout(attention_dropout)
 
     def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+        if self.use_efficient_attention:
+            V, A = self._efficient_attention(queries, keys, values, attn_mask)
+        else:
+            V, A = self._einsum_attention(queries, keys, values, attn_mask)
+
+        if self.output_attention:
+            return V.contiguous(), A
+        else:
+            return V.contiguous(), None
+
+    def _einsum_attention(self, queries, keys, values, attn_mask):
         B, L, H, E = queries.shape
         _, S, _, D = values.shape
         scale = self.scale or 1.0 / sqrt(E)
@@ -65,7 +94,24 @@ def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
         A = self.dropout(torch.softmax(scale * scores, dim=-1))
         V = torch.einsum("bhls,bshd->blhd", A, values)
 
-        if self.output_attention:
-            return V.contiguous(), A
-        else:
-            return V.contiguous(), None
+        return V, A
+
+    def _efficient_attention(self, queries, keys, values, attn_mask):
+        # SDPA expects [B, H, L, E] shape
+        queries = queries.transpose(1, 2)
+        keys = keys.transpose(1, 2)
+        values = values.transpose(1, 2)
+
+        V = nn.functional.scaled_dot_product_attention(
+            query=queries,
+            key=keys,
+            value=values,
+            attn_mask=attn_mask.mask if attn_mask is not None else None,
+            dropout_p=self.dropout.p if self.training else 0.0,
+            is_causal=self.mask_flag if attn_mask is None else False,
+            scale=self.scale,  # if None, PyTorch uses 1/sqrt(E) internally
+        )
+
+        V = V.transpose(1, 2)
+
+        return V, None
diff --git a/pytorch_forecasting/layers/_embeddings/_positional_embedding.py b/pytorch_forecasting/layers/_embeddings/_positional_embedding.py
index 82b107315..0156376f7 100644
--- a/pytorch_forecasting/layers/_embeddings/_positional_embedding.py
+++ b/pytorch_forecasting/layers/_embeddings/_positional_embedding.py
@@ -22,7 +22,7 @@ def __init__(self, d_model, max_len=5000):
         super().__init__()
         # Compute the positional encodings once in log space.
         pe = torch.zeros(max_len, d_model).float()
-        pe.require_grad = False
+        pe.requires_grad = False
 
         position = torch.arange(0, max_len).float().unsqueeze(1)
         div_term = (
diff --git a/pytorch_forecasting/models/timexer/_timexer.py b/pytorch_forecasting/models/timexer/_timexer.py
index e6d8ea081..6fa29dd2c 100644
--- a/pytorch_forecasting/models/timexer/_timexer.py
+++ b/pytorch_forecasting/models/timexer/_timexer.py
@@ -60,6 +60,7 @@ def __init__(
         d_ff: int = 1024,
         dropout: float = 0.2,
         activation: str = "relu",
+        use_efficient_attention: bool = False,
         patch_length: int = 16,
         factor: int = 5,
         embed_type: str = "fixed",
@@ -118,6 +119,13 @@ def __init__(
             regularization.
         activation (str, optional): Activation function used in feedforward
             networks ('relu' or 'gelu').
+        use_efficient_attention (bool, optional): If set to True, uses
+            PyTorch's native, optimized Scaled Dot Product Attention
+            implementation, which can reduce computation time and memory
+            consumption for longer sequences. PyTorch automatically selects the
+            optimal backend (FlashAttention-2, Memory-Efficient Attention, or
+            its own C++ implementation) based on the input properties,
+            hardware capabilities, and build configuration.
         patch_length (int, optional): Length of each non-overlapping patch for
             endogenous variable tokenization.
         use_norm (bool, optional): Whether to apply normalization to input data.
@@ -263,6 +271,7 @@ def __init__(
                 self.hparams.factor,
                 attention_dropout=self.hparams.dropout,
                 output_attention=False,
+                use_efficient_attention=self.hparams.use_efficient_attention,
             ),
             self.hparams.hidden_size,
             self.hparams.n_heads,
@@ -273,6 +282,7 @@ def __init__(
                 self.hparams.factor,
                 attention_dropout=self.hparams.dropout,
                 output_attention=False,
+                use_efficient_attention=self.hparams.use_efficient_attention,
             ),
             self.hparams.hidden_size,
             self.hparams.n_heads,
diff --git a/pytorch_forecasting/models/timexer/_timexer_pkg.py b/pytorch_forecasting/models/timexer/_timexer_pkg.py
index d91e81d6c..a1d645c0c 100644
--- a/pytorch_forecasting/models/timexer/_timexer_pkg.py
+++ b/pytorch_forecasting/models/timexer/_timexer_pkg.py
@@ -79,6 +79,15 @@ def get_base_test_params(cls):
                     ),
                 ),
             },
+            {
+                "hidden_size": 32,
+                "patch_length": 1,
+                "n_heads": 4,
+                "e_layers": 1,
+                "d_ff": 32,
+                "dropout": 0.1,
+                "use_efficient_attention": True,
+            },
         ]
 
     @classmethod
diff --git a/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py b/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py
index a0e4b8aa7..74b27227f 100644
--- a/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py
+++ b/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py
@@ -162,4 +162,13 @@ def get_test_train_params(cls):
                 ),
                 loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]),
             ),
+            dict(
+                hidden_size=32,
+                patch_length=1,
+                n_heads=4,
+                e_layers=1,
+                d_ff=32,
+                dropout=0.1,
+                use_efficient_attention=True,
+            ),
         ]
diff --git a/pytorch_forecasting/models/timexer/_timexer_v2.py b/pytorch_forecasting/models/timexer/_timexer_v2.py
index d07b2266e..023ad937c 100644
--- a/pytorch_forecasting/models/timexer/_timexer_v2.py
+++ b/pytorch_forecasting/models/timexer/_timexer_v2.py
@@ -57,6 +57,13 @@ class TimeXer(TslibBaseModel):
         Factor for the attention mechanism, controlling the number of keys and values.
     activation: str, default='relu'
         Activation function to use in the feed-forward network. Common choices are 'relu', 'gelu', etc.
+    use_efficient_attention: bool, default=False
+        If set to True, uses PyTorch's native, optimized Scaled Dot Product
+        Attention implementation, which can reduce computation time and memory
+        consumption for longer sequences. PyTorch automatically selects the
+        optimal backend (FlashAttention-2, Memory-Efficient Attention, or its
+        own C++ implementation) based on the input properties, hardware
+        capabilities, and build configuration.
     endogenous_vars: Optional[list[str]], default=None
         List of endogenous variable names to be used in the model. If None, all
         historical values for the target variable are used.
@@ -110,6 +117,7 @@ def __init__(
         patch_length: int = 4,
         factor: int = 5,
         activation: str = "relu",
+        use_efficient_attention: bool = False,
         endogenous_vars: Optional[list[str]] = None,
         exogenous_vars: Optional[list[str]] = None,
         logging_metrics: Optional[list[nn.Module]] = None,
@@ -146,6 +154,7 @@ def __init__(
         self.dropout = dropout
         self.patch_length = patch_length
         self.activation = activation
+        self.use_efficient_attention = use_efficient_attention
         self.factor = factor
         self.endogenous_vars = endogenous_vars
         self.exogenous_vars = exogenous_vars
@@ -225,6 +234,7 @@ def _init_network(self):
                 self.factor,
                 attention_dropout=self.dropout,
                 output_attention=False,
+                use_efficient_attention=self.use_efficient_attention,
             ),
             self.hidden_size,
             self.n_heads,
@@ -235,6 +245,7 @@ def _init_network(self):
                 self.factor,
                 attention_dropout=self.dropout,
                 output_attention=False,
+                use_efficient_attention=self.use_efficient_attention,
             ),
             self.hidden_size,
             self.n_heads,
diff --git a/pytorch_forecasting/models/timexer/sub_modules.py b/pytorch_forecasting/models/timexer/sub_modules.py
index 87dacb034..bf30f8b02 100644
--- a/pytorch_forecasting/models/timexer/sub_modules.py
+++ b/pytorch_forecasting/models/timexer/sub_modules.py
@@ -36,7 +36,15 @@ class FullAttention(nn.Module):
         factor (int): Factor for scaling the attention scores.
         scale (float): Scaling factor for attention scores.
         attention_dropout (float): Dropout rate for attention scores.
-        output_attention (bool): Whether to output attention weights."""
+        output_attention (bool): Whether to output attention weights.
+        use_efficient_attention (bool): Whether to use PyTorch's native,
+            optimized Scaled Dot Product Attention implementation, which can
+            reduce computation time and memory consumption for longer sequences.
+            PyTorch automatically selects the optimal backend (FlashAttention-2,
+            Memory-Efficient Attention, or its own C++ implementation) based
+            on the input properties, hardware capabilities, and build
+            configuration.
+    """
 
     def __init__(
         self,
@@ -45,14 +53,35 @@ def __init__(
         scale=None,
         attention_dropout=0.1,
         output_attention=False,
+        use_efficient_attention=False,
     ):
         super().__init__()
+
+        if output_attention and use_efficient_attention:
+            raise ValueError(
+                "Cannot output attention scores using efficient attention. "
+                "Set `use_efficient_attention=False` or "
+                "`output_attention=False`."
+            )
+
         self.scale = scale
         self.mask_flag = mask_flag
         self.output_attention = output_attention
+        self.use_efficient_attention = use_efficient_attention
         self.dropout = nn.Dropout(attention_dropout)
 
     def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+        if self.use_efficient_attention:
+            V, A = self._efficient_attention(queries, keys, values, attn_mask)
+        else:
+            V, A = self._einsum_attention(queries, keys, values, attn_mask)
+
+        if self.output_attention:
+            return V.contiguous(), A
+        else:
+            return V.contiguous(), None
+
+    def _einsum_attention(self, queries, keys, values, attn_mask):
         B, L, H, E = queries.shape
         _, S, _, D = values.shape
         scale = self.scale or 1.0 / sqrt(E)
@@ -66,10 +95,27 @@ def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
         A = self.dropout(torch.softmax(scale * scores, dim=-1))
         V = torch.einsum("bhls,bshd->blhd", A, values)
 
-        if self.output_attention:
-            return V.contiguous(), A
-        else:
-            return V.contiguous(), None
+        return V, A
+
+    def _efficient_attention(self, queries, keys, values, attn_mask):
+        # SDPA expects [B, H, L, E] shape
+        queries = queries.transpose(1, 2)
+        keys = keys.transpose(1, 2)
+        values = values.transpose(1, 2)
+
+        V = nn.functional.scaled_dot_product_attention(
+            query=queries,
+            key=keys,
+            value=values,
+            attn_mask=attn_mask.mask if attn_mask is not None else None,
+            dropout_p=self.dropout.p if self.training else 0.0,
+            is_causal=self.mask_flag if attn_mask is None else False,
+            scale=self.scale,  # if None, PyTorch uses 1/sqrt(E) internally
+        )
+
+        V = V.transpose(1, 2)
+
+        return V, None
 
 
 class AttentionLayer(nn.Module):
@@ -158,7 +204,7 @@ def __init__(self, d_model, max_len=5000):
         super().__init__()
         # Compute the positional encodings once in log space.
         pe = torch.zeros(max_len, d_model).float()
-        pe.require_grad = False
+        pe.requires_grad = False
 
         position = torch.arange(0, max_len).float().unsqueeze(1)
         div_term = (
diff --git a/tests/test_models/test_timexer.py b/tests/test_models/test_timexer.py
index 6bef80348..4ce98e6d6 100644
--- a/tests/test_models/test_timexer.py
+++ b/tests/test_models/test_timexer.py
@@ -142,7 +142,12 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs)
     shutil.rmtree(tmp_path, ignore_errors=True)
 
 
-def test_integration(data_with_covariates, tmp_path):
+@pytest.mark.parametrize(
+    "use_efficient_attention",
+    [False, True],
+    ids=["einsum_attn", "efficient_attn"],
+)
+def test_integration(data_with_covariates, tmp_path, use_efficient_attention):
     """
     Test simple integration of the TimeXer model with a dataloader.
     Args:
@@ -163,6 +168,7 @@ def test_integration(data_with_covariates, tmp_path):
         dataloaders,
         tmp_path,
         trainer_kwargs={"accelerator": "cpu"},
+        use_efficient_attention=use_efficient_attention,
     )
 
 
@@ -254,6 +260,7 @@ def test_model_init(dataloaders_with_covariates):
 
     model1 = TimeXer.from_dataset(dataset, patch_length=patch_length_from_context)
     assert isinstance(model1, TimeXer)
+    assert model1.hparams.use_efficient_attention is False
 
     model2 = TimeXer.from_dataset(
         dataset,
@@ -272,6 +279,15 @@ def test_model_init(dataloaders_with_covariates):
     assert model2.hparams.d_ff == 64
     assert model2.hparams.patch_length == 2
 
+    # Testing initialization with efficient attention arg
+    model3 = TimeXer.from_dataset(
+        dataset,
+        patch_length=patch_length_from_context,
+        use_efficient_attention=True,
+    )
+    assert isinstance(model3, TimeXer)
+    assert model3.hparams.use_efficient_attention is True
+
 
 @pytest.mark.parametrize(
     "kwargs",
diff --git a/tests/test_models/test_timexer_v2.py b/tests/test_models/test_timexer_v2.py
index 0f20ef55a..9cb50aea6 100644
--- a/tests/test_models/test_timexer_v2.py
+++ b/tests/test_models/test_timexer_v2.py
@@ -194,8 +194,8 @@ def basic_metadata(basic_tslib_data_module):
     return basic_tslib_data_module.metadata
 
 
-@pytest.fixture
-def model(basic_metadata):
+@pytest.fixture(params=[False, True], ids=["einsum_attn", "efficient_attn"])
+def model(request, basic_metadata):
     """Initialize a TimeXer model for testing."""
     return TimeXer(
         loss=MAE(),
@@ -215,6 +215,7 @@ def model(request, basic_metadata):
             "patience": 5,
         },
         metadata=basic_metadata,
+        use_efficient_attention=request.param,
     )
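
Reviewer note, not part of the patch: a minimal sketch of how the two attention paths added above can be checked against each other. It assumes only the constructor and forward signatures shown in the diff and the patched import path; tensor shapes and tolerances are illustrative. With dropout disabled, no mask, and the module in eval mode, the SDPA path should match the einsum path up to floating-point error.

# Sketch: compare the einsum path and the SDPA path of FullAttention.
# Import path taken from the patched file; shapes are illustrative only.
import torch

from pytorch_forecasting.layers._attention._full_attention import FullAttention

B, L, S, H, E = 2, 24, 24, 4, 16  # batch, query len, key/value len, heads, head dim
queries = torch.randn(B, L, H, E)
keys = torch.randn(B, S, H, E)
values = torch.randn(B, S, H, E)

einsum_attn = FullAttention(
    mask_flag=False, factor=5, attention_dropout=0.0, output_attention=False
).eval()
sdpa_attn = FullAttention(
    mask_flag=False,
    factor=5,
    attention_dropout=0.0,
    output_attention=False,
    use_efficient_attention=True,
).eval()

with torch.no_grad():
    v_einsum, _ = einsum_attn(queries, keys, values, attn_mask=None)
    v_sdpa, _ = sdpa_attn(queries, keys, values, attn_mask=None)

# Both paths return [B, L, H, E]; the outputs should agree up to float tolerance.
torch.testing.assert_close(v_einsum, v_sdpa, rtol=1e-4, atol=1e-5)

At the model level the flag is simply forwarded, e.g. TimeXer.from_dataset(dataset, use_efficient_attention=True), as exercised by the new test cases. One thing that may be worth noting in the docs: scaled_dot_product_attention requires PyTorch 2.0+, and the scale keyword used in the patch was only added in PyTorch 2.1.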