56 changes: 51 additions & 5 deletions pytorch_forecasting/layers/_attention/_full_attention.py
@@ -35,7 +35,15 @@ class FullAttention(nn.Module):
factor (int): Factor for scaling the attention scores.
scale (float): Scaling factor for attention scores.
attention_dropout (float): Dropout rate for attention scores.
output_attention (bool): Whether to output attention weights."""
output_attention (bool): Whether to output attention weights.
use_efficient_attention (bool): Whether to use PyTorch's native,
optimized scaled dot product attention (SDPA) implementation, which
can reduce computation time and memory consumption for longer
sequences. PyTorch automatically selects the best available backend
(FlashAttention-2, Memory-Efficient Attention, or its own C++
implementation) based on the input properties, hardware capabilities,
and build configuration.
"""

def __init__(
self,
@@ -44,14 +52,35 @@ def __init__(
scale=None,
attention_dropout=0.1,
output_attention=False,
use_efficient_attention=False,
):
super().__init__()

if output_attention and use_efficient_attention:
raise ValueError(
"Cannot output attention scores using efficient attention. "
"Set `use_efficient_attention=False` or "
"`output_attention=False`."
)

self.scale = scale
self.mask_flag = mask_flag
self.output_attention = output_attention
self.use_efficient_attention = use_efficient_attention
self.dropout = nn.Dropout(attention_dropout)

def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
if self.use_efficient_attention:
V, A = self._efficient_attention(queries, keys, values, attn_mask)
else:
V, A = self._einsum_attention(queries, keys, values, attn_mask)

if self.output_attention:
return V.contiguous(), A
else:
return V.contiguous(), None

def _einsum_attention(self, queries, keys, values, attn_mask):
B, L, H, E = queries.shape
_, S, _, D = values.shape
scale = self.scale or 1.0 / sqrt(E)
@@ -65,7 +94,24 @@ def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
A = self.dropout(torch.softmax(scale * scores, dim=-1))
V = torch.einsum("bhls,bshd->blhd", A, values)

if self.output_attention:
return V.contiguous(), A
else:
return V.contiguous(), None
return V, A

def _efficient_attention(self, queries, keys, values, attn_mask):
# SDPA expects [B, H, L, E] shape
queries = queries.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)

V = nn.functional.scaled_dot_product_attention(
query=queries,
key=keys,
value=values,
attn_mask=attn_mask.mask if attn_mask is not None else None,
dropout_p=self.dropout.p if self.training else 0.0,
is_causal=self.mask_flag if attn_mask is None else False,
scale=self.scale,  # if None, PyTorch defaults to 1/sqrt(head_dim)
)

V = V.transpose(1, 2)

return V, None
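For reference, a minimal usage sketch of the layer with the new flag. This is not part of the diff; it assumes the private module path shown in the file header above (the class may also be re-exported from `pytorch_forecasting.layers`) and mirrors the positional `(mask_flag, factor)` arguments used by the model code elsewhere in this PR.

```python
import torch

# Assumed import path, taken from the file header of this diff.
from pytorch_forecasting.layers._attention._full_attention import FullAttention

B, L, S, H, E = 2, 32, 32, 4, 16  # batch, query length, key/value length, heads, head dim
queries = torch.randn(B, L, H, E)
keys = torch.randn(B, S, H, E)
values = torch.randn(B, S, H, E)

# SDPA-backed path; eval() disables dropout, so dropout_p is passed as 0.0.
attn = FullAttention(
    False,  # mask_flag
    5,      # factor (not used by this layer's forward pass)
    attention_dropout=0.1,
    output_attention=False,
    use_efficient_attention=True,
).eval()

out, weights = attn(queries, keys, values, attn_mask=None)
print(out.shape, weights)  # torch.Size([2, 32, 4, 16]) None

# Requesting attention weights together with the efficient path is rejected at construction.
try:
    FullAttention(False, 5, output_attention=True, use_efficient_attention=True)
except ValueError as exc:
    print(exc)
```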
@@ -22,7 +22,7 @@ def __init__(self, d_model, max_len=5000):
super().__init__()
# Compute the positional encodings once in log space.
pe = torch.zeros(max_len, d_model).float()
pe.require_grad = False
pe.requires_grad = False

position = torch.arange(0, max_len).float().unsqueeze(1)
div_term = (
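The hunk above (a positional-embedding module) fixes a silent typo: `pe.require_grad = False` only attaches an unused Python attribute to the tensor, whereas `pe.requires_grad = False` sets the actual autograd flag. A minimal sketch of the standard sinusoidal encoding that the truncated hunk appears to compute, with the corrected flag; the exact `div_term` expression is cut off in the diff, so that line is an assumption based on the usual log-space formulation.

```python
import math
import torch

d_model, max_len = 16, 128

pe = torch.zeros(max_len, d_model)
pe.requires_grad = False  # corrected flag; `require_grad` would be a silent no-op

position = torch.arange(0, max_len).float().unsqueeze(1)
# Standard log-space frequencies from "Attention Is All You Need" (assumed here).
div_term = torch.exp(
    torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
print(pe.shape)  # torch.Size([128, 16])
```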
10 changes: 10 additions & 0 deletions pytorch_forecasting/models/timexer/_timexer.py
@@ -60,6 +60,7 @@ def __init__(
d_ff: int = 1024,
dropout: float = 0.2,
activation: str = "relu",
use_efficient_attention: bool = False,
patch_length: int = 16,
factor: int = 5,
embed_type: str = "fixed",
@@ -118,6 +119,13 @@ def __init__(
regularization.
activation (str, optional): Activation function used in feedforward networks
('relu' or 'gelu').
use_efficient_attention (bool, optional): If True, use PyTorch's
native, optimized scaled dot product attention (SDPA)
implementation, which can reduce computation time and memory
consumption for longer sequences. PyTorch automatically selects the
best available backend (FlashAttention-2, Memory-Efficient Attention,
or its own C++ implementation) based on the input properties,
hardware capabilities, and build configuration.
patch_length (int, optional): Length of each non-overlapping patch for
endogenous variable tokenization.
use_norm (bool, optional): Whether to apply normalization to input data.
@@ -263,6 +271,7 @@ def __init__(
self.hparams.factor,
attention_dropout=self.hparams.dropout,
output_attention=False,
use_efficient_attention=self.hparams.use_efficient_attention,
),
self.hparams.hidden_size,
self.hparams.n_heads,
@@ -273,6 +282,7 @@ def __init__(
self.hparams.factor,
attention_dropout=self.hparams.dropout,
output_attention=False,
use_efficient_attention=self.hparams.use_efficient_attention,
),
self.hparams.hidden_size,
self.hparams.n_heads,
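The docstring added above notes that PyTorch picks the SDPA backend automatically. When benchmarking the new flag it can be useful to pin a specific backend. A hedged sketch using `torch.nn.attention.sdpa_kernel`, which assumes PyTorch 2.3 or newer is installed; it is independent of this PR and only illustrates the backend selection the docstring refers to.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # PyTorch >= 2.3 (assumed)

B, H, L, E = 2, 4, 64, 16  # SDPA expects [batch, heads, seq_len, head_dim]
q = torch.randn(B, H, L, E)
k = torch.randn(B, H, L, E)
v = torch.randn(B, H, L, E)

# Pin the portable math backend (always available, e.g. on CPU);
# on suitable GPUs, SDPBackend.FLASH_ATTENTION or SDPBackend.EFFICIENT_ATTENTION
# can be selected instead.
with sdpa_kernel(SDPBackend.MATH):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=False)

print(out.shape)  # torch.Size([2, 4, 64, 16])
```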
9 changes: 9 additions & 0 deletions pytorch_forecasting/models/timexer/_timexer_pkg.py
@@ -79,6 +79,15 @@ def get_base_test_params(cls):
),
),
},
{
"hidden_size": 32,
"patch_length": 1,
"n_heads": 4,
"e_layers": 1,
"d_ff": 32,
"dropout": 0.1,
"use_efficient_attention": True,
},
]

@classmethod
9 changes: 9 additions & 0 deletions pytorch_forecasting/models/timexer/_timexer_pkg_v2.py
@@ -162,4 +162,13 @@ def get_test_train_params(cls):
),
loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]),
),
dict(
hidden_size=32,
patch_length=1,
n_heads=4,
e_layers=1,
d_ff=32,
dropout=0.1,
use_efficient_attention=True,
),
]
11 changes: 11 additions & 0 deletions pytorch_forecasting/models/timexer/_timexer_v2.py
@@ -57,6 +57,13 @@ class TimeXer(TslibBaseModel):
Factor for the attention mechanism, controlling the number of keys and values.
activation: str, default='relu'
Activation function to use in the feed-forward network. Common choices are 'relu', 'gelu', etc.
use_efficient_attention: bool, default=False
If True, use PyTorch's native, optimized scaled dot product attention
(SDPA) implementation, which can reduce computation time and memory
consumption for longer sequences. PyTorch automatically selects the
best available backend (FlashAttention-2, Memory-Efficient Attention,
or its own C++ implementation) based on the input properties, hardware
capabilities, and build configuration.
endogenous_vars: Optional[list[str]], default=None
List of endogenous variable names to be used in the model. If None, all historical values
for the target variable are used.
@@ -110,6 +117,7 @@ def __init__(
patch_length: int = 4,
factor: int = 5,
activation: str = "relu",
use_efficient_attention: bool = False,
endogenous_vars: Optional[list[str]] = None,
exogenous_vars: Optional[list[str]] = None,
logging_metrics: Optional[list[nn.Module]] = None,
@@ -146,6 +154,7 @@ def __init__(
self.dropout = dropout
self.patch_length = patch_length
self.activation = activation
self.use_efficient_attention = use_efficient_attention
self.factor = factor
self.endogenous_vars = endogenous_vars
self.exogenous_vars = exogenous_vars
@@ -225,6 +234,7 @@ def _init_network(self):
self.factor,
attention_dropout=self.dropout,
output_attention=False,
use_efficient_attention=self.use_efficient_attention,
),
self.hidden_size,
self.n_heads,
@@ -235,6 +245,7 @@ def _init_network(self):
self.factor,
attention_dropout=self.dropout,
output_attention=False,
use_efficient_attention=self.use_efficient_attention,
),
self.hidden_size,
self.n_heads,
58 changes: 52 additions & 6 deletions pytorch_forecasting/models/timexer/sub_modules.py
@@ -36,7 +36,15 @@ class FullAttention(nn.Module):
factor (int): Factor for scaling the attention scores.
scale (float): Scaling factor for attention scores.
attention_dropout (float): Dropout rate for attention scores.
output_attention (bool): Whether to output attention weights."""
output_attention (bool): Whether to output attention weights.
use_efficient_attention (bool): Whether to use PyTorch's native,
optimized scaled dot product attention (SDPA) implementation, which
can reduce computation time and memory consumption for longer
sequences. PyTorch automatically selects the best available backend
(FlashAttention-2, Memory-Efficient Attention, or its own C++
implementation) based on the input properties, hardware capabilities,
and build configuration.
"""

def __init__(
self,
@@ -45,14 +53,35 @@ def __init__(
scale=None,
attention_dropout=0.1,
output_attention=False,
use_efficient_attention=False,
):
super().__init__()

if output_attention and use_efficient_attention:
raise ValueError(
"Cannot output attention scores using efficient attention. "
"Set `use_efficient_attention=False` or "
"`output_attention=False`."
)

self.scale = scale
self.mask_flag = mask_flag
self.output_attention = output_attention
self.use_efficient_attention = use_efficient_attention
self.dropout = nn.Dropout(attention_dropout)

def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
if self.use_efficient_attention:
V, A = self._efficient_attention(queries, keys, values, attn_mask)
else:
V, A = self._einsum_attention(queries, keys, values, attn_mask)

if self.output_attention:
return V.contiguous(), A
else:
return V.contiguous(), None

def _einsum_attention(self, queries, keys, values, attn_mask):
B, L, H, E = queries.shape
_, S, _, D = values.shape
scale = self.scale or 1.0 / sqrt(E)
@@ -66,10 +95,27 @@ def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
A = self.dropout(torch.softmax(scale * scores, dim=-1))
V = torch.einsum("bhls,bshd->blhd", A, values)

if self.output_attention:
return V.contiguous(), A
else:
return V.contiguous(), None
return V, A

def _efficient_attention(self, queries, keys, values, attn_mask):
# SDPA expects [B, H, L, E] shape
queries = queries.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)

V = nn.functional.scaled_dot_product_attention(
query=queries,
key=keys,
value=values,
attn_mask=attn_mask.mask if attn_mask is not None else None,
dropout_p=self.dropout.p if self.training else 0.0,
is_causal=self.mask_flag if attn_mask is None else False,
scale=self.scale,  # if None, PyTorch defaults to 1/sqrt(head_dim)
)

V = V.transpose(1, 2)

return V, None


class AttentionLayer(nn.Module):
Expand Down Expand Up @@ -158,7 +204,7 @@ def __init__(self, d_model, max_len=5000):
super().__init__()
# Compute the positional encodings once in log space.
pe = torch.zeros(max_len, d_model).float()
pe.require_grad = False
pe.requires_grad = False

position = torch.arange(0, max_len).float().unsqueeze(1)
div_term = (
18 changes: 17 additions & 1 deletion tests/test_models/test_timexer.py
@@ -142,7 +142,12 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs)
shutil.rmtree(tmp_path, ignore_errors=True)


def test_integration(data_with_covariates, tmp_path):
@pytest.mark.parametrize(
"use_efficient_attention",
[False, True],
ids=["einsum_attn", "efficient_attn"],
)
def test_integration(data_with_covariates, tmp_path, use_efficient_attention):
"""
Test simple integration of the TimeXer model with a dataloader.
Args:
@@ -163,6 +168,7 @@ def test_integration(data_with_covariates, tmp_path):
dataloaders,
tmp_path,
trainer_kwargs={"accelerator": "cpu"},
use_efficient_attention=use_efficient_attention,
)


@@ -254,6 +260,7 @@ def test_model_init(dataloaders_with_covariates):

model1 = TimeXer.from_dataset(dataset, patch_length=patch_length_from_context)
assert isinstance(model1, TimeXer)
assert model1.hparams.use_efficient_attention is False

model2 = TimeXer.from_dataset(
dataset,
@@ -272,6 +279,15 @@ def test_model_init(dataloaders_with_covariates):
assert model2.hparams.d_ff == 64
assert model2.hparams.patch_length == 2

# Testing initialization with efficient attention arg
model3 = TimeXer.from_dataset(
dataset,
patch_length=patch_length_from_context,
use_efficient_attention=True,
)
assert isinstance(model3, TimeXer)
assert model3.hparams.use_efficient_attention is True


@pytest.mark.parametrize(
"kwargs",
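Beyond the integration parametrization above, a natural unit-level check is that the einsum and SDPA code paths agree numerically once dropout is inactive. A hedged sketch of such a test, not part of this diff, assuming `FullAttention` can be imported from `pytorch_forecasting.models.timexer.sub_modules` (the module changed earlier in this PR):

```python
import torch
from pytorch_forecasting.models.timexer.sub_modules import FullAttention


def test_efficient_attention_matches_einsum():
    """Both attention implementations should agree when dropout is disabled."""
    torch.manual_seed(0)
    B, L, H, E = 2, 16, 4, 8
    q = torch.randn(B, L, H, E)
    k = torch.randn(B, L, H, E)
    v = torch.randn(B, L, H, E)

    einsum_attn = FullAttention(False, 5, attention_dropout=0.0).eval()
    sdpa_attn = FullAttention(
        False, 5, attention_dropout=0.0, use_efficient_attention=True
    ).eval()

    out_einsum, _ = einsum_attn(q, k, v, attn_mask=None)
    out_sdpa, _ = sdpa_attn(q, k, v, attn_mask=None)

    torch.testing.assert_close(out_einsum, out_sdpa, rtol=1e-4, atol=1e-5)
```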
5 changes: 3 additions & 2 deletions tests/test_models/test_timexer_v2.py
@@ -194,8 +194,8 @@ def basic_metadata(basic_tslib_data_module):
return basic_tslib_data_module.metadata


@pytest.fixture
def model(basic_metadata):
@pytest.fixture(params=[False, True], ids=["einsum_attn", "efficient_attn"])
def model(request, basic_metadata):
"""Initialize a TimeXer model for testing."""
return TimeXer(
loss=MAE(),
@@ -215,6 +215,7 @@ def model(basic_metadata):
"patience": 5,
},
metadata=basic_metadata,
use_efficient_attention=request.param,
)

