11 changes: 10 additions & 1 deletion climanet/predict.py
@@ -63,6 +63,8 @@ def predict_monthly_var(
device: str = "cpu",
run_dir: str = ".",
verbose: bool = True,
dataloader_num_workers: int = 2,
predict_threads: int | None = None,
):
"""
Predicts monthly variable values using a trained model and a provided dataset.
@@ -79,6 +81,8 @@ def predict_monthly_var(
device: The device to run the predictions on (e.g., 'cpu' or 'cuda').
run_dir: Directory to save log files and predictions.
verbose: If True, prints progress information during prediction.
dataloader_num_workers: how many subprocesses to use for data loading.
See the torch DataLoader docs for details.
predict_threads: number of CPU threads to use for computation when
device is "cpu". If None, a thread count is chosen automatically.
Returns:
A NumPy array, PyTorch tensor, or xarray Dataset containing the predicted values.
If return_loss is True, it also returns the average loss over the dataset.
@@ -92,7 +96,12 @@ def predict_monthly_var(

use_cuda = device == "cuda"
dataloader = DataLoader(
dataset, batch_size=batch_size, shuffle=False, pin_memory=use_cuda
dataset,
batch_size=batch_size,
shuffle=False,
pin_memory=use_cuda,
num_workers=dataloader_num_workers,  # subprocesses for data loading
persistent_workers=dataloader_num_workers > 0,  # keep workers alive across calls; requires num_workers > 0
)

# Initialize an empty list to store predictions
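A minimal call sketch with the new options (the model and dataset arguments are placeholders and their names are assumed; only dataloader_num_workers and predict_threads come from this diff):

from climanet.predict import predict_monthly_var

preds = predict_monthly_var(
    model,                      # trained model (placeholder)
    dataset,                    # evaluation dataset (placeholder)
    device="cpu",
    dataloader_num_workers=4,   # four loader subprocesses
    predict_threads=None,       # let the helper choose a thread count
)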
12 changes: 10 additions & 2 deletions climanet/train.py
@@ -22,6 +22,9 @@ def train_monthly_model(
store_model: bool = True,
device: str = "cpu",
verbose: bool = True,
dataloader_num_workers: int = 2,
training_threads: int | None = None,
):
"""Train the model to predict monthly data from daily data.
Args:
@@ -37,16 +40,18 @@ def train_monthly_model(
store_model: whether to save the best model to disk
device: device to run training on ("cpu" or "cuda")
verbose: whether to print training progress
dataloader_num_workers: how many subprocesses to use for data loading.
See the torch DataLoader docs for details.
training_threads: number of CPU threads to use for training when device
is "cpu". If None, a thread count is chosen automatically.
"""

# check if dataset has indices attribute for stats calculation
base_dataset = dataset.dataset if hasattr(dataset, "dataset") else dataset
indices = dataset.indices if hasattr(dataset, "indices") else None
mean, std = base_dataset.compute_stats(indices)

# Initialize the model
model = model.to(device)
decoder = model.decoder

# unwrap DataParallel, which stores the underlying model under .module
decoder = model.module.decoder if hasattr(model, "module") else model.decoder
with torch.no_grad():
decoder.bias.copy_(torch.from_numpy(mean))
decoder.scale.copy_(torch.from_numpy(std) + 1e-6)
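For orientation, a minimal sketch of the kind of decoder head this initialization assumes, where bias and scale map normalized network outputs back to physical units (hypothetical class; the real decoder is defined elsewhere in climanet):

import torch

class DecoderHead(torch.nn.Module):
    """Hypothetical output head with learnable de-normalization."""

    def __init__(self, dim: int):
        super().__init__()
        self.bias = torch.nn.Parameter(torch.zeros(dim))
        self.scale = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # shift and scale normalized predictions into physical units
        return x * self.scale + self.bias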
@@ -57,6 +62,8 @@ def train_monthly_model(
batch_size=batch_size,
shuffle=shuffle,
pin_memory=False,
num_workers=dataloader_num_workers,  # subprocesses for data loading
persistent_workers=dataloader_num_workers > 0,  # keep workers alive between epochs; requires num_workers > 0
)

# Set up logging
@@ -136,6 +143,7 @@ def train_monthly_model(
return_loss=True,
verbose=False,
run_dir=run_dir,
dataloader_num_workers=dataloader_num_workers,
)
writer.add_scalar("Loss/validation", avg_epoch_loss, epoch)

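A matching sketch for training (again, everything except dataloader_num_workers and training_threads is a placeholder):

from climanet.train import train_monthly_model

train_monthly_model(
    model,                      # model to train (placeholder)
    dataset,                    # daily-to-monthly dataset (placeholder)
    device="cpu",
    dataloader_num_workers=4,   # loader subprocesses
    training_threads=8,         # cap compute threads; None picks a count automatically
)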
35 changes: 32 additions & 3 deletions climanet/utils.py
@@ -4,6 +4,7 @@
import numpy as np
import xarray as xr
import torch
import psutil

from torch.utils.tensorboard import SummaryWriter

@@ -133,15 +134,15 @@ def add_month_day_dims(
# Build aligned datetime array (M,T)
time_da = daily_ts[time_dim]

# time_indexed is (M,T) with NaT for padded days
time_indexed = (
time_da.assign_coords(M=(time_dim, dkey.values),
T=(time_dim, time_da.dt.day.values))
.set_index({time_dim: ("M", "T")})
.unstack(time_dim)
.reindex(T=np.arange(1,32), M=month_keys)
)

# determine day-of-year (doy) and, if applicable, hour-of-day (hod); fill NaT with 0 in place
# we use the tropical year length (365.2422 days, rounded to 365.24) as the
# period over which the Sun returns to the same position relative to the Earth
@@ -158,7 +159,7 @@
# create phase angles from doy and hod
doy_phase = 2*np.pi*doy/doy_period
hod_phase = 2*np.pi*hod/hod_period

# stack cyclic encodings into time_features (M,T,2)
time_features = xr.concat([doy_phase, hod_phase],
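For intuition, a standalone sketch of what the phase encoding buys. This diff stores raw phases; a sin/cos pair, which downstream code can derive from them, makes the year boundary continuous (the day values here are made up for illustration):

import numpy as np

doy_period = 365.24                    # tropical year, as above
doy = np.array([1.0, 183.0, 365.0])    # start, middle, end of year
phase = 2 * np.pi * doy / doy_period

# on the (sin, cos) circle, day 365 lands next to day 1,
# unlike the raw day count where they are 364 apart
encoding = np.stack([np.sin(phase), np.cos(phase)], axis=-1)
print(encoding.round(3))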
@@ -254,3 +255,31 @@ def save_model(model: torch.nn.Module, run_dir: str, verbose: bool) -> None:
)
if verbose:
print(f"Model saved to {model_path}")


def configure_compute_resources(
model: torch.nn.Module, device: str, compute_threads: int | None, dataloader_num_workers: int
) -> torch.nn.Module:
"""Configure model for multi-GPU and set CPU thread usage for compute (training or prediction).

Args:
model: the PyTorch model to configure
device: device to run on ("cpu" or "cuda")
compute_threads: number of threads to use for compute when device is CPU.
If None, it will be set automatically.
dataloader_num_workers: how many subprocesses to use for data loading.
See torch DataLoader docs for details.
Returns:
The model, potentially wrapped in DataParallel if using multiple GPUs.
"""
if device == "cpu":
if compute_threads is None:
# psutil.cpu_count() can return None; fall back to 1 in that case
total_cpus = psutil.cpu_count() or 1
# leave cores free for the dataloader workers and the main process
compute_threads = max(1, total_cpus - dataloader_num_workers - 1)
torch.set_num_threads(compute_threads)
elif device == "cuda":
num_gpus = torch.cuda.device_count()
if num_gpus > 1:
model = torch.nn.DataParallel(model)
return model
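
A minimal sketch of how train/predict would presumably call this helper (the Linear model is a stand-in; the function name and parameters come from this diff):

import torch
from climanet.utils import configure_compute_resources

model = torch.nn.Linear(8, 1)  # stand-in for a real climanet model
device = "cuda" if torch.cuda.is_available() else "cpu"

# wraps the model in DataParallel on multi-GPU hosts, or on CPU caps the
# compute thread count, reserving cores for the dataloader workers
model = configure_compute_resources(
    model, device, compute_threads=None, dataloader_num_workers=2
)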
145 changes: 97 additions & 48 deletions notebooks/example.ipynb

Large diffs are not rendered by default.