diff --git a/climanet/dataset.py b/climanet/dataset.py index 7976297..4432845 100644 --- a/climanet/dataset.py +++ b/climanet/dataset.py @@ -41,7 +41,7 @@ def __init__( # Reshape daily → (M, T=31, H, W), monthly → (M, H, W), # and get padded_days_mask → (M, T=31) - daily_mt, monthly_m, padded_days_mask = add_month_day_dims( + daily_mt, monthly_m, padded_days_mask, daily_timef = add_month_day_dims( daily_da, monthly_da, time_dim=time_dim ) @@ -49,6 +49,7 @@ def __init__( self.daily_np = daily_mt.to_numpy().copy() # (M, T=31, H, W) float self.monthly_np = monthly_m.to_numpy().copy() # (M, H, W) float self.padded_mask_np = padded_days_mask.to_numpy().copy() # (M, T=31) bool + self.daily_timef_np = daily_timef.to_numpy().copy() # (M,T=31, 4) # Store coordinate arrays self.lat_coords = daily_da[spatial_dims[0]].to_numpy().copy() @@ -137,6 +138,8 @@ def __getitem__(self, idx): monthly_tensor = torch.from_numpy(monthly_patch).float() # (1, M, T, H, W) daily_nan_mask = torch.from_numpy(daily_nan_mask).unsqueeze(0) + # ( M, T, 2) + daily_timef_tensor = torch.from_numpy(self.daily_timef_np).float() # daily_mask: NaN locations that are NOT land # Reshape land_tensor for broadcasting: (H, W) → (1, 1, 1, H, W) @@ -154,6 +157,7 @@ def __getitem__(self, idx): "monthly_patch": monthly_tensor, # (M, H, W) "daily_mask_patch": daily_mask_tensor, # (C=1, M, T=31, H, W) "land_mask_patch": land_tensor, # (H,W) True=Land + "daily_timef_patch": daily_timef_tensor, #(M, T=31, 4) "padded_days_mask": self.padded_days_tensor, # (M, T=31) True=padded "coords": (i, j), "lat_patch": lat_patch, # (H,) diff --git a/climanet/predict.py b/climanet/predict.py index bb716d9..ebf1ded 100644 --- a/climanet/predict.py +++ b/climanet/predict.py @@ -109,6 +109,7 @@ def predict_monthly_var( predictions = model( batch["daily_patch"].to(device, non_blocking=use_cuda), batch["daily_mask_patch"].to(device, non_blocking=use_cuda), + batch["daily_timef_patch"].to(device,non_blocking=use_cuda), 
class CyclicTimeEmbedding(nn.Module):
    """Fixed (non-learnable) cyclic temporal encoding from day/hour phases.

    The module expects phase-encoded day-of-year (DOY) and hour-of-day (HOD)
    values and expands them with a bank of integer Fourier frequencies
    ``1, 2, ..., F`` into sine/cosine features. Optionally the mixed terms
    ``phase_doy + phase_hod`` and ``phase_doy - phase_hod`` are encoded as
    well, to capture interactions between the annual and the diurnal cycle.

    The returned encoding has ``embed_dim`` features and is intended to be
    added to the data embeddings by the caller; this module does not perform
    the addition.

    Output feature layout (last dimension), for each phase term in the order
    ``[doy, hod, doy + hod, doy - hod]`` (the last two only when
    ``include_cross`` is True)::

        [sin(1*p), ..., sin(F*p), cos(1*p), ..., cos(F*p)]
    """

    def __init__(self, embed_dim=128, include_cross=True):
        """Initialize the temporal encoding.

        Args:
            embed_dim: Dimension of the produced encoding. Must be a positive
                multiple of ``2 * num_phase_terms`` (i.e. of 8 with cross
                terms, of 4 without), because every phase term contributes
                one sine and one cosine feature per frequency.
            include_cross: If True (default), also encode the mixed phases
                ``phase_doy + phase_hod`` and ``phase_doy - phase_hod``.

        Raises:
            ValueError: If ``embed_dim`` is not a positive multiple of
                ``2 * num_phase_terms``.
        """
        super().__init__()

        self.include_cross = include_cross

        num_base_phase = 2  # day-of-year and hour-of-day
        num_cross = 2 if include_cross else 0
        num_phase_terms = num_base_phase + num_cross

        # Each phase term yields sin and cos per frequency, so the embedding
        # dimension has to split evenly into 2 * num_phase_terms groups.
        features_per_frequency = 2 * num_phase_terms
        if embed_dim <= 0 or embed_dim % features_per_frequency != 0:
            raise ValueError(
                "embed_dim must be a positive multiple of 2 * num_phase_terms "
                "for fixed encoding. "
                f"Got embed_dim: {embed_dim} and num_phase_terms: {num_phase_terms}."
            )

        num_frequencies = embed_dim // features_per_frequency
        self.num_frequencies = num_frequencies
        # Backward-compatible alias for the originally misspelled attribute.
        self.num_freqencies = num_frequencies

        # Integer harmonics 1..F; a buffer so it follows .to(device) calls.
        freqs = torch.linspace(1.0, num_frequencies, num_frequencies)
        self.register_buffer("freqs", freqs)

    def forward(self, time_features):
        """Create encodings of size ``embed_dim`` from phase features.

        Args:
            time_features: Tensor of shape (B, M, T, D) with D >= 2, where
                index 0 of the last dimension holds the day-of-year phase and
                index 1 the hour-of-day phase. Any further features are
                ignored.

        Returns:
            Tensor of shape (B, M, T, embed_dim).

        Raises:
            ValueError: If ``time_features`` is not 4-dimensional or its last
                dimension holds fewer than two features.
        """
        if time_features.dim() != 4 or time_features.shape[-1] < 2:
            raise ValueError(
                "time_features must have shape (B, M, T, D) with D >= 2. "
                f"Got shape: {tuple(time_features.shape)}."
            )

        # Extract the individual phases from the features
        phase_doy = time_features[..., 0]
        phase_hod = time_features[..., 1]
        phases = [phase_doy, phase_hod]

        # Construct the cross terms
        if self.include_cross:
            phases.append(phase_doy + phase_hod)
            phases.append(phase_doy - phase_hod)

        # (B, M, T, num_phase_terms, 1)
        stacked = torch.stack(phases, dim=-1).unsqueeze(-1)

        # Broadcast against the frequencies: (B, M, T, num_phase_terms, F)
        angles = stacked * self.freqs.view(1, 1, 1, 1, -1)

        # (B, M, T, num_phase_terms, 2F)
        emb_encode = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

        # Merge the term and frequency axes: (B, M, T, embed_dim)
        return emb_encode.flatten(start_dim=-2)
@@ -153,8 +247,9 @@ def __init__(self, embed_dim=128, max_days=31, max_months=12): """ super().__init__() + self.time_embed = CyclicTimeEmbedding(embed_dim=embed_dim) + # Positional encodings for days and months - self.pos_days = TemporalPositionalEncoding(embed_dim, max_len=max_days) self.pos_months = TemporalPositionalEncoding(embed_dim, max_len=max_months) # Day scorer (within each month) @@ -182,7 +277,7 @@ def __init__(self, embed_dim=128, max_days=31, max_months=12): nn.Linear(4 * embed_dim, embed_dim), ) - def forward(self, x, M, T, H, W, padded_days_mask=None): + def forward(self, x, M, T, H, W, time_features, padded_days_mask=None): """ Args: x: (B, M, T, H, W, C) containing spatio-temporal tokens, where C is the embedding dimension. @@ -190,6 +285,7 @@ def forward(self, x, M, T, H, W, padded_days_mask=None): T: number of temporal tokens per month after temporal patching (Tp) H: spatial height after spatial patching W: spatial width after spatial patching + time_features: (B,M,T,2) containing cyclically phase encoded DOY and HOD padded_days_mask: Optional boolean tensor of shape (B, M, T), bool, True indicating which day tokens are padded (because some months have fewer days). This is used to mask out padded tokens in attention computation. 
@@ -201,10 +297,13 @@ def forward(self, x, M, T, H, W, padded_days_mask=None): # Reshape to (B, Hp*Wp, M, Tp, C) for temporal processing seq = x.permute(0, 3, 4, 1, 2, 5).reshape(B, Hp * Wp, M, Tp, C) - pe_days = self.pos_days(T).to(seq.device).to(seq.dtype) # (T, C) + temp_emb = self.time_embed(time_features) # (B,M,T,emd_dim) + #expand spatially + temp_emb = temp_emb[:, None, :, :, :] #[B, 1, M, T, C] + temp_emb = temp_emb.expand(-1, H*W, -1, -1, -1) pe_months = self.pos_months(M).to(seq.device).to(seq.dtype) # (M, C) - seq = seq + pe_days[None, None, None, :, :] # add day PE + seq = seq + temp_emb # add temporal embeddings seq = seq + pe_months[None, None, :, None, :] # add month PE # Day attention per month @@ -554,13 +653,15 @@ def __init__( ) self.patch_size = patch_size - def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None): + def forward(self, daily_data, daily_mask, daily_timef, land_mask_patch, padded_days_mask=None,): """Forward pass of the Spatio-Temporal model. Args: daily_data: Tensor of shape (B, C, M, T, H, W) containing daily data, where C is the number of channels (e.g., 1 for SST) daily_mask: Boolean tensor of same shape as daily_data indicating missing values + daily_timef: Tensor of shape (B, M, T, 2) containing the cyclically phase encoded day-of-year + and hour-of-day information for the daily data land_mask_patch: Boolean tensor of shape (B, H, W) to mask land areas in the output padded_days_mask: Optional boolean tensor of shape (B, M, T) indicating which day tokens are padded (True for padded tokens). Used to mask out padded tokens in temporal attention. 
@@ -582,6 +683,9 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None ) assert T % self.patch_size[0] == 0, "T must be divisible by patch size" + if self.patch_size[0] > 1: + daily_timef = daily_timef.view(B, M, Tp, self.patch_size[0], 4).mean(dim=3) # -> (B,M, Tp, 4) + if padded_days_mask is not None and self.patch_size[0] > 1: B, M, T_days = padded_days_mask.shape if T_days % self.patch_size[0] != 0: @@ -611,7 +715,7 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None latent = latent.view(B, M, Tp, Hp, Wp, embed_dim) agg_latent = self.temporal( - latent, M, Tp, Hp, Wp, padded_days_mask=padded_days_mask + latent, M, Tp, Hp, Wp, daily_timef, padded_days_mask=padded_days_mask ) # (B, M, Hp*Wp, embed_dim) # Step 3: Add spatial positional encodings and mix spatial features diff --git a/climanet/train.py b/climanet/train.py index 4344d12..504329d 100644 --- a/climanet/train.py +++ b/climanet/train.py @@ -111,6 +111,7 @@ def train_monthly_model( pred = model( batch["daily_patch"], batch["daily_mask_patch"], + batch["daily_timef_patch"], batch["land_mask_patch"], batch["padded_days_mask"], ) # (B, M, H, W) diff --git a/climanet/utils.py b/climanet/utils.py index 0d5688b..0d52551 100644 --- a/climanet/utils.py +++ b/climanet/utils.py @@ -93,6 +93,7 @@ def add_month_day_dims( daily_m : xr.DataArray - dims: (M, T, H, W) monthly_m : xr.DataArray - dims: (M, H, W) padded_days_mask : xr.DataArray - dims: (M, T=31), bool, True where day is padded + time_features : xr.DataArray - dims: (M, T, 2) """ # Month key as integer YYYYMM dkey = daily_ts[time_dim].dt.year * 100 + daily_ts[time_dim].dt.month @@ -126,7 +127,42 @@ def add_month_day_dims( .sel(M=month_keys) ) - return daily_indexed, monthly_m, padded_days_mask + # Build aligned datetime array (M,T) + time_da = daily_ts[time_dim] + + #time_indexed is (M,T) with NaT for padded days + time_indexed = ( + time_da.assign_coords(M=(time_dim, dkey.values), + T=(time_dim, 
time_da.dt.day.values)) + .set_index({time_dim: ("M", "T")}) + .unstack(time_dim) + .reindex(T=np.arange(1,32), M=month_keys) + ) + + #determine day-of-year (doy) [and hour-of-day (hod) if applicable], fill NaT with 0 inplace + # here we choose to use the tropical year length (365.2422 day, which we round to 365.24) as the + # period to return to the position of the sun relative to the Earth + doy_period = 365.24 + hod_period = 24.0 + + doy = time_indexed.dt.dayofyear.fillna(0) + + if "hour" in dir(time_indexed.dt): + hod = time_indexed.dt.hour.fillna(0) + else: + hod = xr.zeros_like(doy) + + #create phase from day and hod + doy_phase = 2*np.pi*doy/doy_period + hod_phase = 2*np.pi*hod/hod_period + + + #Stack cyclic encodings into time_features (M,T,2) + time_features = xr.concat([doy_phase,hod_phase], + dim="feature" + ).transpose("M","T","feature") + + return daily_indexed, monthly_m, padded_days_mask, time_features def pred_to_numpy(pred, orig_H=None, orig_W=None, land_mask=None): diff --git a/notebooks/example.ipynb b/notebooks/example.ipynb index 4143170..b881af0 100644 --- a/notebooks/example.ipynb +++ b/notebooks/example.ipynb @@ -29,12 +29,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "13a3b0c8-1d92-460d-84a4-a3a59ca081af", "metadata": {}, "outputs": [], "source": [ - "data_folder = Path(\"./eso4clima\")\n", + "data_folder = Path(\"../../data/output\")\n", "\n", "file_names = [data_folder / \"202001_day_ERA5_masked_ts.nc\", data_folder / \"202002_day_ERA5_masked_ts.nc\"]\n", "daily_data = xr.open_mfdataset(file_names)\n", @@ -99,7 +99,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "bcc04777-5235-4ef3-81bd-2bdcafd8baaa", "metadata": {}, "outputs": [], @@ -112,7 +112,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "09eeabbe-36ef-46a4-ad39-b82559a2da2e", "metadata": {}, "outputs": [], @@ -136,7 +136,7 @@ }, { "cell_type": "code", - 
"execution_count": 7, + "execution_count": 6, "id": "e84304f1-deb2-4f7c-b026-9ee4bbb38272", "metadata": {}, "outputs": [ @@ -144,33 +144,33 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 0: best_loss = 1.064811\n", - "Epoch 20: best_loss = 0.848913\n", - "Epoch 40: best_loss = 0.742154\n", - "Epoch 60: best_loss = 0.590368\n", - "Epoch 80: best_loss = 0.474284\n", - "Epoch 100: best_loss = 0.373293\n", - "Epoch 120: best_loss = 0.313690\n", - "Epoch 140: best_loss = 0.262067\n", - "Epoch 160: best_loss = 0.233068\n", - "Epoch 180: best_loss = 0.198232\n", - "Epoch 200: best_loss = 0.174762\n", - "Epoch 220: best_loss = 0.157859\n", - "Epoch 240: best_loss = 0.144779\n", - "Epoch 260: best_loss = 0.134052\n", - "Epoch 280: best_loss = 0.127401\n", - "Epoch 300: best_loss = 0.122035\n", - "Epoch 320: best_loss = 0.117025\n", - "Epoch 340: best_loss = 0.112626\n", - "Epoch 360: best_loss = 0.108094\n", - "Epoch 380: best_loss = 0.105740\n", - "Epoch 400: best_loss = 0.103568\n", - "Epoch 420: best_loss = 0.101469\n", - "Epoch 440: best_loss = 0.099437\n", - "Epoch 460: best_loss = 0.097428\n", - "Epoch 480: best_loss = 0.095427\n", - "Epoch 500: best_loss = 0.093429\n", - "Training complete. 
Best loss: 0.093429\n", + "Epoch 0: best_loss = 1.066736\n", + "Epoch 20: best_loss = 1.004169\n", + "Epoch 40: best_loss = 0.722893\n", + "Epoch 60: best_loss = 0.461043\n", + "Epoch 80: best_loss = 0.314689\n", + "Epoch 100: best_loss = 0.231682\n", + "Epoch 120: best_loss = 0.165844\n", + "Epoch 140: best_loss = 0.136603\n", + "Epoch 160: best_loss = 0.120698\n", + "Epoch 180: best_loss = 0.103750\n", + "Epoch 200: best_loss = 0.091409\n", + "Epoch 220: best_loss = 0.083936\n", + "Epoch 240: best_loss = 0.078847\n", + "Epoch 260: best_loss = 0.074559\n", + "Epoch 280: best_loss = 0.070794\n", + "Epoch 300: best_loss = 0.067597\n", + "Epoch 320: best_loss = 0.064071\n", + "Epoch 340: best_loss = 0.062061\n", + "Epoch 360: best_loss = 0.060405\n", + "Epoch 380: best_loss = 0.058727\n", + "Epoch 400: best_loss = 0.056750\n", + "Epoch 420: best_loss = 0.055028\n", + "Epoch 440: best_loss = 0.054138\n", + "Epoch 460: best_loss = 0.053217\n", + "Epoch 480: best_loss = 0.052362\n", + "Epoch 500: best_loss = 0.051482\n", + "Training complete. Best loss: 0.051482\n", "Model saved to runs/best_model.pth\n" ] } @@ -192,7 +192,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "id": "bda9f068", "metadata": {}, "outputs": [], @@ -210,7 +210,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "id": "7c2deb40-bee8-4973-80f0-9d9485eabf0c", "metadata": {}, "outputs": [ @@ -232,7 +232,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 9, "id": "012e01ac-caf1-47f7-a3e2-7b06bb260bfd", "metadata": {}, "outputs": [ @@ -788,8 +788,8 @@ " * lat (lat) float32 640B -29.88 -29.62 -29.38 ... 9.375 9.625 9.875\n", " * lon (lon) float32 640B -49.88 -49.62 -49.38 ... -10.38 -10.12\n", "Data variables:\n", - " predictions (time, lat, lon) float32 205kB 0.0 298.3 298.1 ... 0.0 0.0 0.0