6 changes: 5 additions & 1 deletion climanet/dataset.py
@@ -41,14 +41,15 @@ def __init__(

# Reshape daily → (M, T=31, H, W), monthly → (M, H, W),
# and get padded_days_mask → (M, T=31)
daily_mt, monthly_m, padded_days_mask = add_month_day_dims(
daily_mt, monthly_m, padded_days_mask, daily_timef = add_month_day_dims(
daily_da, monthly_da, time_dim=time_dim
)

# Convert to numpy once — all __getitem__ calls use these
self.daily_np = daily_mt.to_numpy().copy() # (M, T=31, H, W) float
self.monthly_np = monthly_m.to_numpy().copy() # (M, H, W) float
self.padded_mask_np = padded_days_mask.to_numpy().copy() # (M, T=31) bool
self.daily_timef_np = daily_timef.to_numpy().copy()  # (M, T=31, 2)

# Store coordinate arrays
self.lat_coords = daily_da[spatial_dims[0]].to_numpy().copy()
@@ -137,6 +138,8 @@ def __getitem__(self, idx):
monthly_tensor = torch.from_numpy(monthly_patch).float()
# (1, M, T, H, W)
daily_nan_mask = torch.from_numpy(daily_nan_mask).unsqueeze(0)
# (M, T=31, 2)
daily_timef_tensor = torch.from_numpy(self.daily_timef_np).float()

# daily_mask: NaN locations that are NOT land
# Reshape land_tensor for broadcasting: (H, W) → (1, 1, 1, H, W)
@@ -154,6 +157,7 @@ def __getitem__(self, idx):
"monthly_patch": monthly_tensor, # (M, H, W)
"daily_mask_patch": daily_mask_tensor, # (C=1, M, T=31, H, W)
"land_mask_patch": land_tensor, # (H,W) True=Land
"daily_timef_patch": daily_timef_tensor, #(M, T=31, 4)
"padded_days_mask": self.padded_days_tensor, # (M, T=31) True=padded
"coords": (i, j),
"lat_patch": lat_patch, # (H,)
1 change: 1 addition & 0 deletions climanet/predict.py
@@ -109,6 +109,7 @@ def predict_monthly_var(
predictions = model(
batch["daily_patch"].to(device, non_blocking=use_cuda),
batch["daily_mask_patch"].to(device, non_blocking=use_cuda),
batch["daily_timef_patch"].to(device,non_blocking=use_cuda),
batch["land_mask_patch"].to(device, non_blocking=use_cuda),
batch["padded_days_mask"].to(device, non_blocking=use_cuda),
)
116 changes: 110 additions & 6 deletions climanet/st_encoder_decoder.py
@@ -78,6 +78,100 @@ def forward(self, x, mask):
x = self.drop(x)
return x # (B, N_patches, embed_dim)

class CyclicTimeEmbedding(nn.Module):
"""Cyclical Temporal encoding using day-of-year and hour-of-day values in
combination sine and cosine functions

This module generates fixed (non-learnable) trigonometric temporal encodings
for the temporal dimension using the cyclical phase-encoded day-of-year and
hour-of-day values extracted from the datetime associated with the input.
This represents a natural positional encoding on the temporal cycle related
to the solar (tropical) year and the diurnal cycle.

The module uses fixed Fourier frequencies and mixed doy-hod terms to expand
the cyclic encoding to the embedding dimension and capture hour-of-day and
day-of-year interactions. The returned encodings are intended to be added to
embeddings of the input data by the caller. The module does not perform the
addition.
"""

def __init__(self, embed_dim=128, include_cross=True):
"""
Initialize temporal encodings

Args:
embed_dim: Dimension of the embedding. The default is 128.
Many vision transformers use embedding dimensions that are multiples
of 64 (e.g., 64, 128, 256). This can be tuned.
include_cross: bool, default True. Also create phase_doy +/- phase_hod
cross-term embeddings.
"""

super().__init__()

self.include_cross = include_cross

num_base_phase = 2
num_cross = 2 if include_cross else 0
num_phase_terms = num_base_phase + num_cross

# Determine the number of frequencies for the Fourier expansion in line with the embedding dimension

if embed_dim % (2 * num_phase_terms) == 0:
num_frequencies = embed_dim // (2 * num_phase_terms)
self.num_frequencies = num_frequencies
freqs = torch.linspace(1.0, num_frequencies, num_frequencies)
self.register_buffer("freqs", freqs)
else:
raise ValueError(
f"embed_dim must be an even multiple of num_phase_terms for fixed encoding."
f"Got embed_dim: {embed_dim} and num_phase_terms: {num_phase_terms}."
)
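# Sizing example (illustrative, assuming the defaults above): with embed_dim=128 and
# include_cross=True, num_phase_terms = 4, so num_frequencies = 128 / (2*4) = 16, and the
# forward pass concatenates 4 phase terms x (sin, cos) x 16 frequencies = 128 features.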

def forward(self, time_features):
"""
Create encodings of size embed_dim from the cyclic time features.

Args:
time_features: (B, M, T, D); D is the number of base phase features (doy and hod phases)

Returns:
emb_encode : (B,M,T, embed_dim)
"""
B, M, T, D = time_features.shape

# extract individual phases from features
phase_doy = time_features[..., 0]
phase_hod = time_features[..., 1]
phases = [phase_doy, phase_hod]

# construct cross terms
if self.include_cross:
phases.append(phase_doy + phase_hod)
phases.append(phase_doy - phase_hod)

# stack these to get (B, M, T, num_terms)
x = torch.stack(phases, dim=-1)

# (B, M, T, num_terms, 1)
x = x.unsqueeze(-1)

# (1, 1, 1, 1, F)
freqs = self.freqs.view(1, 1, 1, 1, -1)

# apply frequencies
x = x * freqs  # (B, M, T, num_phase_terms, F)

sinx = torch.sin(x)
cosx = torch.cos(x)

emb_encode = torch.cat([sinx, cosx], dim=-1)  # (B, M, T, num_phase_terms, 2F)

emb_encode = emb_encode.view(B, M, T, -1)  # flatten to (B, M, T, embed_dim)

return emb_encode
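# Usage sketch (illustrative, not part of the class): given time_features of shape
# (B, M, T, 2) holding the doy/hod phases built in climanet/utils.py,
#   emb = CyclicTimeEmbedding(embed_dim=128)(time_features)  # -> (B, M, T, 128)
# the caller then adds emb to its own token embeddings, as done in the temporal
# module further below.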



class TemporalPositionalEncoding(nn.Module):
"""Temporal Positional Encoding using sine and cosine functions.
@@ -153,8 +247,9 @@ def __init__(self, embed_dim=128, max_days=31, max_months=12):
"""
super().__init__()

self.time_embed = CyclicTimeEmbedding(embed_dim=embed_dim)

# Positional encodings for days and months
self.pos_days = TemporalPositionalEncoding(embed_dim, max_len=max_days)
self.pos_months = TemporalPositionalEncoding(embed_dim, max_len=max_months)

# Day scorer (within each month)
Expand Down Expand Up @@ -182,14 +277,15 @@ def __init__(self, embed_dim=128, max_days=31, max_months=12):
nn.Linear(4 * embed_dim, embed_dim),
)

def forward(self, x, M, T, H, W, padded_days_mask=None):
def forward(self, x, M, T, H, W, time_features, padded_days_mask=None):
"""
Args:
x: (B, M, T, H, W, C) containing spatio-temporal tokens, where C is the embedding dimension.
M: number of months
T: number of temporal tokens per month after temporal patching (Tp)
H: spatial height after spatial patching
W: spatial width after spatial patching
time_features: (B, M, T, 2) containing the cyclically phase-encoded day-of-year and hour-of-day values
padded_days_mask: Optional boolean tensor of shape (B, M, T), bool,
True indicating which day tokens are padded (because some months
have fewer days). This is used to mask out padded tokens in attention computation.
@@ -201,10 +297,13 @@ def forward(self, x, M, T, H, W, padded_days_mask=None):
# Reshape to (B, Hp*Wp, M, Tp, C) for temporal processing
seq = x.permute(0, 3, 4, 1, 2, 5).reshape(B, Hp * Wp, M, Tp, C)

pe_days = self.pos_days(T).to(seq.device).to(seq.dtype) # (T, C)
temp_emb = self.time_embed(time_features) # (B, M, T, embed_dim)
# expand spatially over the patch positions
temp_emb = temp_emb[:, None, :, :, :] # (B, 1, M, T, C)
temp_emb = temp_emb.expand(-1, H*W, -1, -1, -1)
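# shape note (illustrative): after the expand above, temp_emb is (B, H*W, M, T, C), which
# matches seq's (B, Hp*Wp, M, Tp, C) layout since H, W, T here are the patched sizes,
# so the additions below are element-wise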
pe_months = self.pos_months(M).to(seq.device).to(seq.dtype) # (M, C)

seq = seq + pe_days[None, None, None, :, :] # add day PE
seq = seq + temp_emb # add temporal embeddings
seq = seq + pe_months[None, None, :, None, :] # add month PE

# Day attention per month
@@ -554,13 +653,15 @@ def __init__(
)
self.patch_size = patch_size

def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None):
def forward(self, daily_data, daily_mask, daily_timef, land_mask_patch, padded_days_mask=None):
"""Forward pass of the Spatio-Temporal model.

Args:
daily_data: Tensor of shape (B, C, M, T, H, W) containing daily
data, where C is the number of channels (e.g., 1 for SST)
daily_mask: Boolean tensor of same shape as daily_data indicating missing values
daily_timef: Tensor of shape (B, M, T, 2) containing the cyclically phase encoded day-of-year
and hour-of-day information for the daily data
land_mask_patch: Boolean tensor of shape (B, H, W) to mask land areas in the output
padded_days_mask: Optional boolean tensor of shape (B, M, T) indicating which day tokens are padded
(True for padded tokens). Used to mask out padded tokens in temporal attention.
@@ -582,6 +683,9 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None
)
assert T % self.patch_size[0] == 0, "T must be divisible by patch size"

if self.patch_size[0] > 1:
daily_timef = daily_timef.view(B, M, Tp, self.patch_size[0], 2).mean(dim=3) # average phases over each temporal patch -> (B, M, Tp, 2)

if padded_days_mask is not None and self.patch_size[0] > 1:
B, M, T_days = padded_days_mask.shape
if T_days % self.patch_size[0] != 0:
@@ -611,7 +715,7 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None
latent = latent.view(B, M, Tp, Hp, Wp, embed_dim)

agg_latent = self.temporal(
latent, M, Tp, Hp, Wp, padded_days_mask=padded_days_mask
latent, M, Tp, Hp, Wp, daily_timef, padded_days_mask=padded_days_mask
) # (B, M, Hp*Wp, embed_dim)

# Step 3: Add spatial positional encodings and mix spatial features
1 change: 1 addition & 0 deletions climanet/train.py
@@ -111,6 +111,7 @@ def train_monthly_model(
pred = model(
batch["daily_patch"],
batch["daily_mask_patch"],
batch["daily_timef_patch"],
batch["land_mask_patch"],
batch["padded_days_mask"],
) # (B, M, H, W)
38 changes: 37 additions & 1 deletion climanet/utils.py
@@ -93,6 +93,7 @@ def add_month_day_dims(
daily_m : xr.DataArray - dims: (M, T, H, W)
monthly_m : xr.DataArray - dims: (M, H, W)
padded_days_mask : xr.DataArray - dims: (M, T=31), bool, True where day is padded
time_features : xr.DataArray - dims: (M, T, 2)
"""
# Month key as integer YYYYMM
dkey = daily_ts[time_dim].dt.year * 100 + daily_ts[time_dim].dt.month
@@ -126,7 +127,42 @@ def add_month_day_dims(
.sel(M=month_keys)
)

return daily_indexed, monthly_m, padded_days_mask
# Build aligned datetime array (M,T)
time_da = daily_ts[time_dim]

# time_indexed is (M, T) with NaT for padded days
time_indexed = (
time_da.assign_coords(M=(time_dim, dkey.values),
T=(time_dim, time_da.dt.day.values))
.set_index({time_dim: ("M", "T")})
.unstack(time_dim)
.reindex(T=np.arange(1,32), M=month_keys)
)

# determine day-of-year (doy) and, if available, hour-of-day (hod); entries from padded days (NaT) are filled with 0
# We use the tropical year length (365.2422 days, rounded here to 365.24) as the period of the
# annual cycle, i.e. the time for the Sun to return to the same position relative to the Earth
doy_period = 365.24
hod_period = 24.0

doy = time_indexed.dt.dayofyear.fillna(0)

if "hour" in dir(time_indexed.dt):
hod = time_indexed.dt.hour.fillna(0)
else:
hod = xr.zeros_like(doy)

# map doy and hod to cyclic phase angles
doy_phase = 2 * np.pi * doy / doy_period
hod_phase = 2 * np.pi * hod / hod_period


# Stack cyclic encodings into time_features (M, T, 2)
time_features = xr.concat([doy_phase, hod_phase],
dim="feature"
).transpose("M", "T", "feature")

return daily_indexed, monthly_m, padded_days_mask, time_features


def pred_to_numpy(pred, orig_H=None, orig_W=None, land_mask=None):