diff --git a/.gitignore b/.gitignore index 9c9e7a3..611e8dd 100644 --- a/.gitignore +++ b/.gitignore @@ -105,4 +105,7 @@ notebooks/ test_cases/ test_data/ TODO.md -newsletter.py \ No newline at end of file +newsletter.py + +# Tests +tests/downloaders/ \ No newline at end of file diff --git a/bluemath_tk/deeplearning/gp_models.py b/bluemath_tk/deeplearning/gp_models.py new file mode 100644 index 0000000..8c8c44d --- /dev/null +++ b/bluemath_tk/deeplearning/gp_models.py @@ -0,0 +1,667 @@ +""" +Gaussian Process models module. + +This module contains Gaussian Process Regression models using GPyTorch. + +Classes: +- BaseGPRModel: Base class for all GP models +- ExactGPModel: Exact Gaussian Process Regression model + +1. Wang, Z., Leung, M., Mukhopadhyay, S., et al. (2024). "A hybrid statistical–dynamical framework for compound coastal flooding analysis." *Environmental Research Letters*, 20(1), 014005. +2. Wang, Z., Leung, M., Mukhopadhyay, S., et al. (2025). "Compound coastal flooding in San Francisco Bay under climate change." *npj Natural Hazards*, 2(1), 3. +""" + +from abc import abstractmethod +from typing import Dict, Optional, Tuple, Union + +import gpytorch +import numpy as np +import torch +from gpytorch.kernels import Kernel, MaternKernel, RBFKernel, ScaleKernel +from gpytorch.likelihoods import GaussianLikelihood +from gpytorch.means import ConstantMean +from gpytorch.mlls import ExactMarginalLogLikelihood +from gpytorch.models import ExactGP +from tqdm import tqdm + +from ..core.models import BlueMathModel + + +class BaseGPRModel(BlueMathModel): + """ + Base class for Gaussian Process Regression models. + + This class provides common functionality for all GP models, including: + - GP-specific training with marginal log likelihood + - Prediction with uncertainty quantification + - Model save/load with likelihood handling + + GP models differ from standard deep learning models in several ways: + - Use marginal log likelihood (MLL) instead of standard loss functions + - Require explicit training data setting via set_train_data() + - Return distributions (mean + variance) rather than point estimates + - Typically train on full dataset (no batching during training) + + GP models inherit directly from BlueMathModel (not BaseDeepLearningModel) + because their training and prediction workflows are fundamentally different + from standard neural networks. + + Attributes + ---------- + model : gpytorch.models.GP + The GPyTorch model. + device : torch.device + The device (CPU/GPU) the model is on. + is_fitted : bool + Whether the model has been fitted. + likelihood : gpytorch.likelihoods.Likelihood + The GP likelihood module. + mll : gpytorch.mlls.MarginalLogLikelihood + The marginal log likelihood objective. + """ + + def __init__( + self, + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ): + """ + Initialize the base GP model. + + Parameters + ---------- + device : str or torch.device, optional + Device to run the model on. Default is None (auto-detect GPU/CPU). + **kwargs + Additional keyword arguments passed to BlueMathModel. 
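The class docstring above lists the ways GP training differs from a standard neural-network workflow (negative marginal log likelihood as the objective, explicit training data, distributional predictions). The following minimal sketch shows the bare GPyTorch pattern that BaseGPRModel.fit() and predict() wrap; the toy data and the TinyGP class are illustrative only and not part of this module.

    import gpytorch
    import torch

    # Toy 1-D regression data (illustrative only)
    train_x = torch.linspace(0, 1, 50).unsqueeze(-1)
    train_y = torch.sin(6.0 * train_x).squeeze(-1) + 0.1 * torch.randn(50)

    class TinyGP(gpytorch.models.ExactGP):
        def __init__(self, x, y, likelihood):
            super().__init__(x, y, likelihood)
            self.mean_module = gpytorch.means.ConstantMean()
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

        def forward(self, x):
            return gpytorch.distributions.MultivariateNormal(
                self.mean_module(x), self.covar_module(x)
            )

    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = TinyGP(train_x, train_y, likelihood)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    model.train()
    likelihood.train()
    for _ in range(100):
        optimizer.zero_grad()
        loss = -mll(model(train_x), train_y)  # negative MLL: the GP "loss"
        loss.backward()
        optimizer.step()

    model.eval()
    likelihood.eval()
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        pred = likelihood(model(train_x))     # a distribution, not a point estimate
        mean, std = pred.mean, pred.stddev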
+ """ + super().__init__(**kwargs) + + # Device management (similar to BaseDeepLearningModel but GP-specific) + if device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + elif isinstance(device, str): + self.device = torch.device(device) + else: + self.device = device + + # GP-specific attributes + self.model: Optional[gpytorch.models.GP] = None + self.is_fitted = False + self.likelihood: Optional[gpytorch.likelihoods.Likelihood] = None + self.mll: Optional[gpytorch.mlls.MarginalLogLikelihood] = None + + # Exclude from pickling (GPyTorch objects need special handling) + self._exclude_attributes = [ + "model", + "likelihood", + "mll", + ] + + @abstractmethod + def _build_kernel(self, input_dim: int) -> Kernel: + """ + Build the covariance kernel. + + Parameters + ---------- + input_dim : int + Number of input dimensions. + + Returns + ------- + gpytorch.kernels.Kernel + The covariance kernel. + """ + + pass + + @abstractmethod + def _build_model(self, input_shape: Tuple, **kwargs) -> gpytorch.models.GP: + """ + Build the GPyTorch model. + + Parameters + ---------- + input_shape : Tuple + Shape of input data. + + Returns + ------- + gpytorch.models.GP + The GPyTorch model. + """ + + pass + + def fit( + self, + X: np.ndarray, + y: np.ndarray, + epochs: int = 200, + learning_rate: float = 0.1, + optimizer: Optional[torch.optim.Optimizer] = None, + patience: int = 30, + verbose: int = 1, + **kwargs, + ) -> Dict[str, list]: + """ + Fit the Gaussian Process model. + + GP models use marginal log likelihood (MLL) optimization, which is + fundamentally different from standard deep learning training. + + Parameters + ---------- + X : np.ndarray + Training input data with shape (n_samples, n_features). + y : np.ndarray + Training target data with shape (n_samples,) or (n_samples, 1). + epochs : int, optional + Maximum number of training epochs. Default is 200. + learning_rate : float, optional + Learning rate for optimizer. Default is 0.1. + optimizer : torch.optim.Optimizer, optional + Optimizer to use. If None, uses Adam. Default is None. + patience : int, optional + Early stopping patience. Default is 30. + verbose : int, optional + Verbosity level. Default is 1. + **kwargs + Additional keyword arguments passed to _build_model. + + Returns + ------- + Dict[str, list] + Training history with 'train_loss' key (negative MLL). 
+ """ + + # Reshape y if needed + if y.ndim > 1: + y = y.ravel() + + # Convert to tensors + X_tensor = torch.FloatTensor(X).to(self.device) + y_tensor = torch.FloatTensor(y).to(self.device) + + # Build model if not already built + if self.model is None: + self.model = self._build_model(X.shape, **kwargs) + # Initialize likelihood if not set + if self.likelihood is None: + self.likelihood = GaussianLikelihood().to(self.device) + # Initialize MLL + self.mll = self._build_mll(self.likelihood, self.model) + + # Always update training data (allows retraining with new data) + # This is GP-specific: we need to explicitly set training data + self._set_train_data(X_tensor, y_tensor) + + # Rebuild MLL after setting training data + self.mll = self._build_mll(self.likelihood, self.model) + + # Setup optimizer + if optimizer is None: + optimizer = torch.optim.Adam( + list(self.model.parameters()) + list(self.likelihood.parameters()), + lr=learning_rate, + ) + + # Setup learning rate scheduler + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.8, patience=10 + ) + + history = {"train_loss": []} + best_loss = float("inf") + patience_counter = 0 + best_model_state = None + best_likelihood_state = None + + # Training loop + use_progress_bar = verbose > 0 + epoch_range = range(epochs) + pbar = None + if use_progress_bar: + pbar = tqdm(epoch_range, desc="Training GP", unit="epoch") + epoch_range = pbar + + self.model.train() + self.likelihood.train() + + for epoch in epoch_range: + optimizer.zero_grad() + + # Forward pass: compute negative marginal log likelihood + # This is the GP-specific loss function + loss = self._compute_loss(X_tensor, y_tensor) + + # Backward pass + loss.backward() + torch.nn.utils.clip_grad_norm_( + list(self.model.parameters()) + list(self.likelihood.parameters()), + max_norm=1.0, + ) + optimizer.step() + + loss_value = loss.item() + history["train_loss"].append(loss_value) + scheduler.step(loss_value) + + # Early stopping + if loss_value < best_loss - 1e-4: + best_loss = loss_value + patience_counter = 0 + best_model_state = self.model.state_dict().copy() + best_likelihood_state = self.likelihood.state_dict().copy() + else: + patience_counter += 1 + if patience_counter >= patience: + if verbose > 0: + if pbar is not None: + pbar.set_postfix_str(f"Early stopping at epoch {epoch + 1}") + self.logger.info(f"Early stopping at epoch {epoch + 1}") + break + + # Update progress bar + if pbar is not None: + pbar.set_postfix_str(f"Loss: {loss_value:.4f}") + elif verbose > 0 and (epoch + 1) % max(1, epochs // 10) == 0: + self.logger.info(f"Epoch {epoch + 1}/{epochs} - Loss: {loss_value:.4f}") + + # Restore best model + if best_model_state is not None: + self.model.load_state_dict(best_model_state) + self.likelihood.load_state_dict(best_likelihood_state) + + self.is_fitted = True + + return history + + def predict( + self, + X: np.ndarray, + batch_size: Optional[int] = None, + return_std: bool = False, + verbose: int = 1, + ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + """ + Make predictions with the Gaussian Process model. + + GP models return distributions, so predictions include uncertainty + estimates (standard deviation) by default. + + Parameters + ---------- + X : np.ndarray + Input data with shape (n_samples, n_features). + batch_size : int, optional + Batch size for prediction. If None, processes all at once. + Default is None. + return_std : bool, optional + If True, returns both mean and standard deviation. + Default is False. 
+ verbose : int, optional + Verbosity level. Default is 1. + + Returns + ------- + np.ndarray or tuple + If return_std=False: predictions (mean) with shape (n_samples,). + If return_std=True: tuple of (mean, std) both with shape (n_samples,). + + Raises + ------ + ValueError + If model is not fitted. + """ + + if not self.is_fitted or self.model is None: + raise ValueError("Model must be fitted before prediction.") + + self.model.eval() + self.likelihood.eval() + + X_tensor = torch.FloatTensor(X).to(self.device) + + # Process in batches if batch_size is specified + if batch_size is None: + batch_size = len(X) + + predictions = [] + stds = [] + + n_batches = (len(X) + batch_size - 1) // batch_size + batch_range = range(0, len(X), batch_size) + + if verbose > 0 and n_batches > 1: + batch_range = tqdm( + batch_range, desc="Predicting", unit="batch", total=n_batches + ) + + with ( + torch.no_grad(), + gpytorch.settings.fast_pred_var(), + gpytorch.settings.cholesky_jitter(1e-1), + ): + for i in batch_range: + batch_X = X_tensor[i : i + batch_size] + pred_dist = self._predict_batch(batch_X) + predictions.append(pred_dist.mean.cpu().numpy()) + if return_std: + stds.append(pred_dist.stddev.cpu().numpy()) + + mean_pred = np.concatenate(predictions, axis=0) + + if return_std: + std_pred = np.concatenate(stds, axis=0) + return mean_pred, std_pred + else: + return mean_pred + + def _set_train_data(self, X: torch.Tensor, y: torch.Tensor): + """ + Set training data for the GP model. + + This is GP-specific: GP models need explicit training data setting. + + Parameters + ---------- + X : torch.Tensor + Training inputs. + y : torch.Tensor + Training targets. + """ + + if hasattr(self.model, "set_train_data"): + self.model.set_train_data(X, y, strict=False) + else: + raise AttributeError( + f"Model {type(self.model)} does not support set_train_data(). " + "This is required for GP models." + ) + + def _build_mll( + self, + likelihood: gpytorch.likelihoods.Likelihood, + model: gpytorch.models.GP, + ) -> gpytorch.mlls.MarginalLogLikelihood: + """ + Build the marginal log likelihood objective. + + Parameters + ---------- + likelihood : gpytorch.likelihoods.Likelihood + The likelihood module. + model : gpytorch.models.GP + The GP model. + + Returns + ------- + gpytorch.mlls.MarginalLogLikelihood + The MLL objective. + """ + + return ExactMarginalLogLikelihood(likelihood, model) + + def _compute_loss(self, X: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Compute the training loss (negative MLL). + + Parameters + ---------- + X : torch.Tensor + Training inputs. + y : torch.Tensor + Training targets. + + Returns + ------- + torch.Tensor + Negative marginal log likelihood. + """ + + with gpytorch.settings.cholesky_jitter(1e-1): + output = self.model(X) + loss = -self.mll(output, y) + + return loss + + def _predict_batch(self, X: torch.Tensor) -> gpytorch.distributions.Distribution: + """ + Make predictions for a batch of inputs. + + Parameters + ---------- + X : torch.Tensor + Input batch. + + Returns + ------- + gpytorch.distributions.Distribution + Predictive distribution. + """ + + return self.likelihood(self.model(X)) + + def save_pytorch_model(self, model_path: str, **kwargs): + """ + Save the GP model to a file. + + GP models require saving both the model and likelihood state dicts. + + Parameters + ---------- + model_path : str + Path to the file where the model will be saved. + **kwargs + Additional arguments for torch.save. 
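Because predict() can return the posterior standard deviation, callers can form rough predictive intervals directly. A short sketch, assuming `gp` is a fitted ExactGPModel and `X_test` is an (n_samples, n_features) array as in the example further down:

    import numpy as np

    mean, std = gp.predict(X_test, return_std=True)

    lower = mean - 1.96 * std  # approximate 95% predictive interval
    upper = mean + 1.96 * std

    # Indices where the model is least certain, e.g. to target further sampling
    most_uncertain = np.argsort(std)[-10:]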
+ """ + + if self.model is None or self.likelihood is None: + raise ValueError("Model must be built before saving.") + + # Get model-specific metadata + metadata = self._get_model_metadata() + + torch.save( + { + "model_state_dict": self.model.state_dict(), + "likelihood_state_dict": self.likelihood.state_dict(), + "is_fitted": self.is_fitted, + "model_class": self.__class__.__name__, + **metadata, + }, + model_path, + **kwargs, + ) + self.logger.info(f"GP model saved to {model_path}") + + def load_pytorch_model(self, model_path: str, **kwargs): + """ + Load a GP model from a file. + + Parameters + ---------- + model_path : str + Path to the file where the model is saved. + **kwargs + Additional arguments for torch.load. + """ + + checkpoint = torch.load(model_path, **kwargs) + + # Restore model-specific attributes + self._restore_model_metadata(checkpoint) + + # Build model first if needed + if self.model is None: + # Need input shape to build model - use dummy data + # In practice, you should save/load the training data shape + dummy_shape = (10, 10) # Default, user should provide actual shape + self.model = self._build_model(dummy_shape) + # Initialize likelihood if not set (should be set by _build_model, but check anyway) + if self.likelihood is None: + self.likelihood = GaussianLikelihood().to(self.device) + + self.model.load_state_dict(checkpoint["model_state_dict"]) + self.likelihood.load_state_dict(checkpoint["likelihood_state_dict"]) + self.is_fitted = checkpoint.get("is_fitted", False) + self.logger.info(f"GP model loaded from {model_path}") + + def _get_model_metadata(self) -> Dict: + """ + Get model-specific metadata for saving. + + Override this method in subclasses to save additional metadata. + + Returns + ------- + Dict + Metadata dictionary. + """ + + return {} + + def _restore_model_metadata(self, checkpoint: Dict): + """ + Restore model-specific metadata from checkpoint. + + Override this method in subclasses to restore additional metadata. + + Parameters + ---------- + checkpoint : Dict + Checkpoint dictionary. + """ + + pass + + +class ExactGPModel(BaseGPRModel): + """ + Exact Gaussian Process Regression model using GPyTorch. + + This model implements exact GP inference, suitable for datasets up to + several thousand samples. For larger datasets, consider using approximate + GP methods. + + Parameters + ---------- + kernel : str, optional + Type of kernel to use. Options: 'rbf', 'matern', 'rbf+matern'. + Default is 'rbf+matern'. + ard_num_dims : int, optional + Number of input dimensions for ARD (Automatic Relevance Determination). + If None, will be inferred from data. Default is None. + device : str or torch.device, optional + Device to run the model on. Default is None (auto-detect). + **kwargs + Additional keyword arguments passed to BaseGPRModel. 
+ + Examples + -------- + >>> import numpy as np + >>> from bluemath_tk.deeplearning import ExactGPModel + >>> + >>> # Generate sample data + >>> X = np.random.randn(100, 5) + >>> y = np.random.randn(100) + >>> + >>> # Create and fit model + >>> gp = ExactGPModel(kernel='rbf+matern') + >>> history = gp.fit(X, y, epochs=100, learning_rate=0.1) + >>> + >>> # Make predictions + >>> X_test = np.random.randn(50, 5) + >>> y_pred, y_std = gp.predict(X_test, return_std=True) + """ + + def __init__( + self, + kernel: str = "rbf+matern", + ard_num_dims: Optional[int] = None, + device: Optional[torch.device] = None, + **kwargs, + ): + super().__init__(device=device, **kwargs) + self.kernel_type = kernel.lower() + self.ard_num_dims = ard_num_dims + + def _build_kernel(self, input_dim: int) -> Kernel: + """ + Build the covariance kernel. + """ + + if self.ard_num_dims is None: + ard_num_dims = input_dim + else: + ard_num_dims = self.ard_num_dims + + if self.kernel_type == "rbf": + base_kernel = RBFKernel(ard_num_dims=ard_num_dims) + elif self.kernel_type == "matern": + base_kernel = MaternKernel(nu=2.5, ard_num_dims=ard_num_dims) + elif self.kernel_type == "rbf+matern": + base_kernel = RBFKernel(ard_num_dims=ard_num_dims) + MaternKernel( + nu=2.5, ard_num_dims=ard_num_dims + ) + else: + raise ValueError( + f"Unknown kernel type: {self.kernel_type}. " + "Options: 'rbf', 'matern', 'rbf+matern'" + ) + + return ScaleKernel(base_kernel) + + def _build_model(self, input_shape: Tuple, **kwargs) -> ExactGP: + """ + Build the GPyTorch ExactGP model. + """ + + if len(input_shape) == 1: + input_dim = input_shape[0] + else: + input_dim = input_shape[-1] + + kernel = self._build_kernel(input_dim) + + class GPModel(ExactGP): + def __init__(self, train_x, train_y, likelihood, kernel): + super().__init__(train_x, train_y, likelihood) + self.mean_module = ConstantMean() + self.covar_module = kernel + + def forward(self, x): + mean_x = self.mean_module(x) + covar_x = self.covar_module(x) + return gpytorch.distributions.MultivariateNormal(mean_x, covar_x) + + # Create dummy data for initialization + dummy_x = torch.randn(10, input_dim).to(self.device) + dummy_y = torch.randn(10).to(self.device) + + # Initialize likelihood and model + if self.likelihood is None: + self.likelihood = GaussianLikelihood().to(self.device) + model = GPModel(dummy_x, dummy_y, self.likelihood, kernel.to(self.device)) + + return model.to(self.device) + + def _get_model_metadata(self) -> Dict: + """ + Get model-specific metadata for saving. + """ + + return { + "kernel_type": self.kernel_type, + "ard_num_dims": self.ard_num_dims, + } + + def _restore_model_metadata(self, checkpoint: Dict): + """ + Restore model-specific metadata from checkpoint. 
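With ARD enabled, each input dimension gets its own lengthscale, so the fitted kernel hyperparameters give a rough relevance ranking of the features. A sketch of how they could be inspected, assuming the default 'rbf+matern' kernel (whose two components are exposed through GPyTorch's additive-kernel `kernels` attribute):

    import torch

    # Assumes `gp` is a fitted ExactGPModel(kernel="rbf+matern")
    scale_kernel = gp.model.covar_module             # ScaleKernel(RBF + Matern)
    rbf, matern = scale_kernel.base_kernel.kernels   # components of the additive kernel

    with torch.no_grad():
        # Larger lengthscale -> that dimension has less influence on the covariance
        print("RBF lengthscales:   ", rbf.lengthscale.squeeze().tolist())
        print("Matern lengthscales:", matern.lengthscale.squeeze().tolist())
        print("Output scale:       ", scale_kernel.outputscale.item())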
+ """ + + self.kernel_type = checkpoint.get("kernel_type", "rbf+matern") + self.ard_num_dims = checkpoint.get("ard_num_dims", None) diff --git a/bluemath_tk/downloaders/__init__.py b/bluemath_tk/downloaders/__init__.py index cb8bf1f..8580035 100644 --- a/bluemath_tk/downloaders/__init__.py +++ b/bluemath_tk/downloaders/__init__.py @@ -5,3 +5,8 @@ Repository: https://github.com/GeoOcean/BlueMath_tk.git Status: Under development (Working) """ + +from ._base_downloaders import BaseDownloader +from ._download_result import DownloadResult + +__all__ = ["DownloadResult", "BaseDownloader"] diff --git a/bluemath_tk/downloaders/_base_downloaders.py b/bluemath_tk/downloaders/_base_downloaders.py index 0e1edc3..c2840ea 100644 --- a/bluemath_tk/downloaders/_base_downloaders.py +++ b/bluemath_tk/downloaders/_base_downloaders.py @@ -1,70 +1,65 @@ from abc import abstractmethod +from datetime import datetime +from typing import List, Optional from ..core.models import BlueMathModel +from ._download_result import DownloadResult class BaseDownloader(BlueMathModel): """ - Abstract class for BlueMath downloaders. + Abstract base class for BlueMath downloaders. + + All downloaders should: + 1. Have a `download_data` method that routes to product-specific methods + 2. Have product-specific methods like `download_data_` + 3. Use DownloadResult to track download status Attributes ---------- + product : str + The product name (e.g., "SWOT", "ERA5"). + product_config : dict + Product configuration dictionary. base_path_to_download : str - The base path to download the data. - debug : bool, optional - If True, the logger will be set to DEBUG level. Default is True. - check : bool, optional - If True, just file checking is required. Default is False. - - Methods - ------- - download_data(*args, **kwargs) - Downloads the data. This method must be implemented in the child class. - - Notes - ----- - - This class is an abstract class and should not be instantiated. - - The download_data method must be implemented in the child class. + Base path where downloaded files are stored. + debug : bool + If True, logger is set to DEBUG level. """ def __init__( - self, base_path_to_download: str, debug: bool = True, check: bool = False + self, + product: str, + base_path_to_download: str, + debug: bool = True, ) -> None: """ - The constructor for BaseDownloader class. + Initialize the BaseDownloader. Parameters ---------- + product : str + The product to download data from. base_path_to_download : str The base path to download the data. debug : bool, optional If True, the logger will be set to DEBUG level. Default is True. - check : bool, optional - If True, just file checking is required. Default is False. - - Raises - ------ - ValueError - If base_path_to_download is not a string. - If debug is not a boolean. - If check is not a boolean. - - Notes - ----- - - The logger will be set to INFO level. - - If debug is True, the logger will be set to DEBUG level. 
""" super().__init__() + if not isinstance(product, str): + raise ValueError("product must be a string") + self._product: str = product if not isinstance(base_path_to_download, str): raise ValueError("base_path_to_download must be a string") self._base_path_to_download: str = base_path_to_download if not isinstance(debug, bool): raise ValueError("debug must be a boolean") self._debug: bool = debug - if not isinstance(check, bool): - raise ValueError("check must be a boolean") - self._check: bool = check + + @property + def product(self) -> str: + return self._product @property def base_path_to_download(self) -> str: @@ -75,9 +70,103 @@ def debug(self) -> bool: return self._debug @property - def check(self) -> bool: - return self._check + @abstractmethod + def product_config(self) -> dict: + pass + + def list_datasets(self) -> List[str]: + """ + List all available datasets for the product. + + Returns + ------- + List[str] + List of available dataset names. + """ + + return list(self.product_config.get("datasets", {}).keys()) + + def create_download_result( + self, start_time: Optional[datetime] = None + ) -> DownloadResult: + """ + Create a new DownloadResult instance. + + Parameters + ---------- + start_time : Optional[datetime], optional + The start time of the download operation. If None, the current time is used. + + Returns + ------- + DownloadResult + A new DownloadResult instance. + """ + + result = DownloadResult() + result.start_time = start_time if start_time else datetime.now() + + return result + + def finalize_download_result( + self, result: DownloadResult, message: Optional[str] = None + ) -> DownloadResult: + """ + Finalize a DownloadResult with end time and summary message. + + Parameters + ---------- + result : DownloadResult + The DownloadResult to finalize. + message : Optional[str], optional + The message to add to the DownloadResult. + + Returns + ------- + DownloadResult + The finalized DownloadResult. + """ + + result.end_time = datetime.now() + + if result.start_time and result.end_time: + delta = result.end_time - result.start_time + result.duration_seconds = delta.total_seconds() + + result.success = len(result.error_files) == 0 + + if message is None: + parts = [] + if result.downloaded_files: + parts.append(f"{len(result.downloaded_files)} downloaded") + if result.skipped_files: + parts.append(f"{len(result.skipped_files)} skipped") + if result.error_files: + parts.append(f"{len(result.error_files)} errors") + result.message = f"Download complete: {', '.join(parts)}" + else: + result.message = message + + return result @abstractmethod - def download_data(self, *args, **kwargs) -> None: + def download_data(self, *args, **kwargs) -> DownloadResult: + """ + Download data for the product. + + Routes to product-specific methods like download_data_(). + + Parameters + ---------- + *args + Arguments passed to product-specific download method. + **kwargs + Keyword arguments (e.g., force, dry_run). + + Returns + ------- + DownloadResult + Result with information about downloaded, skipped, and error files. + """ + pass diff --git a/bluemath_tk/downloaders/_download_result.py b/bluemath_tk/downloaders/_download_result.py new file mode 100644 index 0000000..94e09e1 --- /dev/null +++ b/bluemath_tk/downloaders/_download_result.py @@ -0,0 +1,177 @@ +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Dict, List, Optional + + +@dataclass +class DownloadResult: + """ + Standardized result structure for download operations. 
+ + This class provides a consistent interface for download results across all + downloaders, making it easier to handle success/failure cases and track + downloaded files. + + Attributes + ---------- + success : bool + Whether the download operation completed successfully. + downloaded_files : List[str] + List of file paths that were successfully downloaded. + skipped_files : List[str] + List of file paths that were skipped (e.g., already exist, incomplete). + error_files : List[str] + List of file paths that failed to download. + errors : List[Dict[str, Any]] + List of error dictionaries containing error details. + Each dict has keys: 'file', 'error', 'timestamp'. + metadata : Dict[str, Any] + Additional metadata about the download operation. + message : str + Human-readable summary message. + start_time : Optional[datetime] + When the download operation started. + end_time : Optional[datetime] + When the download operation ended. + duration_seconds : Optional[float] + Total duration of the download operation in seconds. + + Examples + -------- + >>> result = DownloadResult( + ... success=True, + ... downloaded_files=["/path/to/file1.nc", "/path/to/file2.nc"], + ... message="Downloaded 2 files successfully" + ... ) + >>> print(result.message) + Downloaded 2 files successfully + >>> print(f"Success rate: {result.success_rate:.1%}") + Success rate: 100.0% + """ + + success: bool = False + downloaded_files: List[str] = field(default_factory=list) + skipped_files: List[str] = field(default_factory=list) + error_files: List[str] = field(default_factory=list) + errors: List[Dict[str, Any]] = field(default_factory=list) + metadata: Dict[str, Any] = field(default_factory=dict) + message: str = "" + start_time: Optional[datetime] = None + end_time: Optional[datetime] = None + duration_seconds: Optional[float] = None + + def __post_init__(self): + """Calculate duration if both start and end times are provided.""" + if self.start_time and self.end_time: + delta = self.end_time - self.start_time + self.duration_seconds = delta.total_seconds() + + @property + def total_files(self) -> int: + """Total number of files processed.""" + return ( + len(self.downloaded_files) + len(self.skipped_files) + len(self.error_files) + ) + + @property + def success_rate(self) -> float: + """Success rate as a fraction (0.0 to 1.0).""" + if self.total_files == 0: + return 0.0 + return len(self.downloaded_files) / self.total_files + + @property + def has_errors(self) -> bool: + """Whether any errors occurred.""" + return len(self.error_files) > 0 or len(self.errors) > 0 + + def add_error( + self, file_path: str, error: Exception, context: Dict[str, Any] = None + ): + """ + Add an error to the result. + + Parameters + ---------- + file_path : str + Path to the file that caused the error. + error : Exception + The exception that occurred. + context : Dict[str, Any], optional + Additional context about the error. + """ + error_dict = { + "file": file_path, + "error": str(error), + "error_type": type(error).__name__, + "timestamp": datetime.now().isoformat(), + } + if context: + error_dict["context"] = context + self.errors.append(error_dict) + if file_path not in self.error_files: + self.error_files.append(file_path) + + def add_downloaded(self, file_path: str): + """Add a successfully downloaded file.""" + if file_path not in self.downloaded_files: + self.downloaded_files.append(file_path) + + def add_skipped(self, file_path: str, reason: str = ""): + """ + Add a skipped file. 
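The fields and properties above can also be exercised directly, which is useful when writing tests for a downloader. A short sketch (file paths are made up):

    from bluemath_tk.downloaders import DownloadResult

    result = DownloadResult()
    result.add_downloaded("/data/swot/cycle_001/file1.nc")
    result.add_skipped("/data/swot/cycle_001/file2.nc", reason="Already downloaded")
    result.add_error("/data/swot/cycle_001/file3.nc", TimeoutError("FTP timed out"))

    print(result.total_files)             # 3
    print(f"{result.success_rate:.0%}")   # 33%
    if result.has_errors:
        for err in result.errors:
            print(err["file"], err["error_type"], err["error"])

    print(result.to_dict()["metadata"]["skip_reasons"])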
+ + Parameters + ---------- + file_path : str + Path to the skipped file. + reason : str, optional + Reason why the file was skipped. + """ + if file_path not in self.skipped_files: + self.skipped_files.append(file_path) + if reason: + self.metadata.setdefault("skip_reasons", {})[file_path] = reason + + def to_dict(self) -> Dict[str, Any]: + """Convert the result to a dictionary.""" + return { + "success": self.success, + "downloaded_files": self.downloaded_files, + "skipped_files": self.skipped_files, + "error_files": self.error_files, + "errors": self.errors, + "metadata": self.metadata, + "message": self.message, + "start_time": self.start_time.isoformat() if self.start_time else None, + "end_time": self.end_time.isoformat() if self.end_time else None, + "duration_seconds": self.duration_seconds, + "total_files": self.total_files, + "success_rate": self.success_rate, + } + + def __str__(self) -> str: + """Human-readable string representation.""" + if self.message: + return self.message + return ( + f"DownloadResult(success={self.success}, " + f"downloaded={len(self.downloaded_files)}, " + f"skipped={len(self.skipped_files)}, " + f"errors={len(self.error_files)})" + ) + + def __repr__(self) -> str: + """Detailed string representation.""" + duration = f"{self.duration_seconds:.1f}s" if self.duration_seconds else "N/A" + return ( + f"DownloadResult(\n" + f" success={self.success},\n" + f" downloaded_files={len(self.downloaded_files)} files,\n" + f" skipped_files={len(self.skipped_files)} files,\n" + f" error_files={len(self.error_files)} files,\n" + f" total_files={self.total_files},\n" + f" success_rate={self.success_rate:.1%},\n" + f" duration={duration},\n" + f")" + ) diff --git a/bluemath_tk/downloaders/aviso/SWOT/SWOT_config.json b/bluemath_tk/downloaders/aviso/SWOT/SWOT_config.json new file mode 100644 index 0000000..0b939ab --- /dev/null +++ b/bluemath_tk/downloaders/aviso/SWOT/SWOT_config.json @@ -0,0 +1,37 @@ +{ + "datasets": { + "swot-l3-expert": { + "description": "SWOT L3 Expert Product", + "url": "https://tds-odatis.aviso.altimetry.fr/thredds/catalog/dataset-l3-swot-karin-nadir-validated/l3_lr_ssh/v1_0/Expert/catalog.html", + "ftp_base_path": "/swot_products/l3_karin_nadir/l3_lr_ssh/v1_0/Expert/", + "cycles": [ + "cycle_001", + "cycle_002", + "cycle_003", + "cycle_004", + "cycle_005", + "cycle_006", + "cycle_007", + "cycle_008", + "cycle_009" + ] + }, + "swot-l2-expert": { + "description": "SWOT L2 Expert Product", + "url": "https://tds-odatis.aviso.altimetry.fr/thredds/catalog/dataset-l2-swot-karin-lr-ssh-validated/PGC0/Expert/catalog.html", + "ftp_base_path": "/swot_products/l2_karin/l2_lr_ssh/PGC0/Expert/", + "cycles": [ + "cycle_001", + "cycle_002", + "cycle_003", + "cycle_004", + "cycle_005", + "cycle_006", + "cycle_007", + "cycle_008", + "cycle_009" + ] + } + }, + "ftp_server": "ftp-access.aviso.altimetry.fr" +} \ No newline at end of file diff --git a/bluemath_tk/downloaders/aviso/__init__.py b/bluemath_tk/downloaders/aviso/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bluemath_tk/downloaders/aviso/aviso_downloader.py b/bluemath_tk/downloaders/aviso/aviso_downloader.py new file mode 100644 index 0000000..f147841 --- /dev/null +++ b/bluemath_tk/downloaders/aviso/aviso_downloader.py @@ -0,0 +1,336 @@ +import ftplib +import json +import os +from typing import List, Optional + +from .._base_downloaders import BaseDownloader +from .._download_result import DownloadResult + + +class AvisoDownloader(BaseDownloader): + """ + Simple downloader for AVISO 
data. + + Downloads all available files from the FTP base path specified in the config. + + Examples + -------- + >>> downloader = AvisoDownloader( + ... product="SWOT", + ... base_path_to_download="./swot_data", + ... username="your_username", + ... password="your_password" + ... ) + >>> result = downloader.download_data( + ... dataset="swot-l3-expert", + ... cycles=["cycle_001"], + ... force=False, + ... dry_run=False + ... ) + """ + + products_configs = { + "SWOT": json.load( + open(os.path.join(os.path.dirname(__file__), "SWOT", "SWOT_config.json")) + ) + } + + def __init__( + self, + product: str, + base_path_to_download: str, + username: str, + password: str, + debug: bool = True, + ) -> None: + """ + Initialize the AvisoDownloader. + + Parameters + ---------- + product : str + The product to download data from (e.g., "SWOT"). + base_path_to_download : str + Base path where downloaded files will be stored. + username : str + AVISO FTP username. + password : str + AVISO FTP password. + debug : bool, optional + If True, sets logger to DEBUG level. Default is True. + + Raises + ------ + ValueError + If the product configuration is not found or FTP server is not specified. + """ + + super().__init__( + product=product, base_path_to_download=base_path_to_download, debug=debug + ) + + self._product_config = self.products_configs.get(product) + if self._product_config is None: + raise ValueError( + f"Product '{product}' not found. Available: {list(self.products_configs.keys())}" + ) + + self.set_logger_name( + f"AvisoDownloader-{product}", level="DEBUG" if debug else "INFO" + ) + + # Initialize FTP client + ftp_server = self._product_config.get("ftp_server") + if ftp_server is None: + raise ValueError("FTP server not found in product configuration") + self._client = ftplib.FTP(ftp_server) + self._client.login(username, password) + + self.logger.info(f"---- AVISO DOWNLOADER INITIALIZED ({product}) ----") + + @property + def product_config(self) -> dict: + """ + Product configuration dictionary loaded from config file. + + Returns + ------- + dict + Product configuration dictionary. + """ + return self._product_config + + @property + def client(self) -> ftplib.FTP: + """ + FTP client connection (initialized and logged in). + + Returns + ------- + ftplib.FTP + FTP client instance. + """ + return self._client + + def download_data( + self, + dry_run: bool = True, + *args, + **kwargs, + ) -> DownloadResult: + """ + Download data for the product. + + Routes to product-specific download methods based on the product type. + + Parameters + ---------- + dry_run : bool, optional + If True, only check what would be downloaded without actually downloading. + Default is True. + *args + Arguments passed to product-specific download method. + **kwargs + Keyword arguments passed to product-specific download method. + + Returns + ------- + DownloadResult + Result with information about downloaded, skipped, and error files. + + Raises + ------ + ValueError + If the product is not supported. + """ + + if self.product == "SWOT": + return self.download_data_swot(dry_run=dry_run, *args, **kwargs) + else: + raise ValueError(f"Download for product {self.product} not supported") + + def download_data_swot( + self, + dataset: str, + cycles: Optional[List[str]] = None, + force: bool = False, + dry_run: bool = True, + ) -> DownloadResult: + """ + Download SWOT data for a specific dataset. + + Downloads all .nc files from specified cycles. 
Files are saved to: + base_path_to_download/dataset/cycle/filename.nc + + Parameters + ---------- + dataset : str + The dataset to download (e.g., "swot-l3-expert"). + Use list_datasets() to see available datasets. + cycles : List[str], optional + List of cycle folder names to download (e.g., ["cycle_001", "cycle_002"]). + If None, uses cycles from dataset configuration. Default is None. + force : bool, optional + Force re-download even if file exists. Default is False. + dry_run : bool, optional + If True, only check what would be downloaded. Default is True. + + Returns + ------- + DownloadResult + Result with all downloaded files and download statistics. + + Raises + ------ + ValueError + If dataset is not found or no cycles are available. + """ + + if dataset not in self.list_datasets(): + raise ValueError( + f"Dataset '{dataset}' not found. Available: {self.list_datasets()}" + ) + + dataset_config = self.product_config["datasets"][dataset] + ftp_base_path = dataset_config["ftp_base_path"] + result = self.create_download_result() + + try: + if cycles is None: + cycles = dataset_config.get("cycles", []) + if not cycles: + raise ValueError( + f"No cycles specified for dataset '{dataset}' and cycles parameter not provided" + ) + + self.logger.info(f"Downloading dataset: {dataset}, cycles: {cycles}") + + all_downloaded_files = [] + + for cycle in cycles: + files = self._list_all_files_in_cycle(ftp_base_path, cycle) + if not files: + self.logger.warning(f"No files found in cycle {cycle}") + continue + + downloaded_files = self._download_files( + files=files, + dataset=dataset, + ftp_base_path=ftp_base_path, + cycle=cycle, + force=force, + dry_run=dry_run, + result=result, + ) + all_downloaded_files.extend(downloaded_files) + + result.downloaded_files = all_downloaded_files + return self.finalize_download_result(result) + + except Exception as e: + result.add_error("download_operation", e) + return self.finalize_download_result(result) + + def _list_all_files_in_cycle(self, ftp_base_path: str, cycle: str) -> List[str]: + """ + List all .nc files from a cycle directory on FTP server. + + This method navigates to the specified FTP base path and then into the + cycle directory, lists its contents, and filters for files ending with '.nc'. + It assumes the current FTP connection is already logged in. + + Parameters + ---------- + ftp_base_path : str + FTP base path for the dataset (e.g., "/swot_products/l3_karin_nadir/l3_lr_ssh/v1_0/Expert/"). + cycle : str + Cycle directory name (e.g., "cycle_001"). + + Returns + ------- + List[str] + List of .nc filenames (without path) found in the cycle directory. + """ + + files = [] + self._client.cwd(ftp_base_path) + self._client.cwd(cycle) + items = [] + self._client.retrlines("LIST", items.append) + for item in items: + parts = item.split() + if len(parts) >= 9: + name = " ".join(parts[8:]) + if name.endswith(".nc"): + files.append(name) + + return files + + def _download_files( + self, + files: List[str], + dataset: str, + ftp_base_path: str, + cycle: str, + force: bool, + dry_run: bool, + result: DownloadResult, + ) -> List[str]: + """ + Download all files from the list. + + Files are saved to: base_path_to_download/dataset/cycle/filename.nc + + Parameters + ---------- + files : List[str] + List of filenames to download (without path). + dataset : str + Dataset name (used in local path, e.g., "swot-l3-expert"). + ftp_base_path : str + FTP base path for the dataset (e.g., "/swot_products/l3_karin_nadir/l3_lr_ssh/v1_0/Expert/"). 
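Since download_data_swot() defaults to dry_run=True, a natural workflow is to preview a cycle before fetching it. A sketch, with placeholder credentials:

    from bluemath_tk.downloaders.aviso.aviso_downloader import AvisoDownloader

    downloader = AvisoDownloader(
        product="SWOT",
        base_path_to_download="./swot_data",
        username="AVISO_USER",       # placeholder credentials
        password="AVISO_PASSWORD",
    )

    # 1. Dry run: nothing is written; candidate files are reported as skipped
    preview = downloader.download_data(dataset="swot-l3-expert", cycles=["cycle_001"])
    print(len(preview.skipped_files), "files would be downloaded")

    # 2. Actual download of the same cycle
    result = downloader.download_data(
        dataset="swot-l3-expert", cycles=["cycle_001"], dry_run=False
    )
    print(result.message)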
+ cycle : str + Cycle directory name (used in local path, e.g., "cycle_001"). + force : bool + Force re-download even if file exists. + dry_run : bool + If True, only simulate download. + result : DownloadResult + Download result object to update. + + Returns + ------- + List[str] + List of local file paths for successfully downloaded files only. + """ + + downloaded_files = [] + + for filename in files: + local_path = os.path.join( + self.base_path_to_download, dataset, cycle, filename + ) + + if not force and os.path.exists(local_path): + result.add_skipped(local_path, "Already downloaded") + continue + + if dry_run: + result.add_skipped(local_path, f"Would download {filename} (dry run)") + continue + + try: + os.makedirs(os.path.dirname(local_path), exist_ok=True) + self._client.cwd(ftp_base_path) + self._client.cwd(cycle) + with open(local_path, "wb") as f: + self._client.retrbinary(f"RETR {filename}", f.write) + result.add_downloaded(local_path) + self.logger.info(f"Downloaded: {filename} -> {local_path}") + downloaded_files.append(local_path) + + except Exception as e: + result.add_error(local_path, e) + self.logger.error(f"Error downloading {filename}: {e}") + + return downloaded_files diff --git a/bluemath_tk/downloaders/copernicus/CERRA/CERRA_config.json b/bluemath_tk/downloaders/copernicus/CERRA/CERRA_config.json new file mode 100644 index 0000000..a5299e2 --- /dev/null +++ b/bluemath_tk/downloaders/copernicus/CERRA/CERRA_config.json @@ -0,0 +1,115 @@ +{ + "datasets": { + "reanalysis-cerra-single-levels": { + "description": "CERRA Sub-daily Regional Reanalysis Data for Europe on Single Levels", + "url": "https://cds.climate.copernicus.eu/datasets/reanalysis-cerra-single-levels?tab=overview", + "types": [ + "surface_or_atmosphere" + ], + "mandatory_fields": [ + "variable", + "level_type", + "data_type", + "product_type", + "year", + "month", + "data_format" + ], + "optional_fields": [ + "day", + "time", + "area" + ], + "template": { + "variable": [ + "10m_wind_direction" + ], + "level_type": "surface_or_atmosphere", + "data_type": [ + "reanalysis" + ], + "product_type": "analysis", + "year": [ + "1985" + ], + "month": [ + "02" + ], + "day": [ + "02" + ], + "time": [ + "00:00" + ], + "data_format": "netcdf" + } + } + }, + "url": "https://cds.climate.copernicus.eu/api", + "variables": { + "10m_wind_direction": { + "cds_name": "10m_wind_direction", + "long_name": "10m wind direction", + "nc_name": "10m_wind_direction", + "type": "surface_or_atmosphere", + "units": "degree", + "dataset": "reanalysis-cerra-single-levels" + }, + "10m_wind_speed": { + "cds_name": "10m_wind_speed", + "long_name": "10m wind speed", + "nc_name": "10m_wind_speed", + "type": "surface_or_atmosphere", + "units": "m s-1", + "dataset": "reanalysis-cerra-single-levels" + }, + "2m_temperature": { + "cds_name": "2m_temperature", + "long_name": "2m temperature", + "nc_name": "2m_temperature", + "type": "surface_or_atmosphere", + "units": "K", + "dataset": "reanalysis-cerra-single-levels" + }, + "2m_relative_humidity": { + "cds_name": "2m_relative_humidity", + "long_name": "2m relative humidity", + "nc_name": "2m_relative_humidity", + "type": "surface_or_atmosphere", + "units": "%", + "dataset": "reanalysis-cerra-single-levels" + }, + "surface_pressure": { + "cds_name": "surface_pressure", + "long_name": "Surface pressure", + "nc_name": "surface_pressure", + "type": "surface_or_atmosphere", + "units": "Pa", + "dataset": "reanalysis-cerra-single-levels" + }, + "total_precipitation": { + "cds_name": 
"total_precipitation", + "long_name": "Total precipitation", + "nc_name": "total_precipitation", + "type": "surface_or_atmosphere", + "units": "kg m-2", + "dataset": "reanalysis-cerra-single-levels" + }, + "total_cloud_cover": { + "cds_name": "total_cloud_cover", + "long_name": "Total cloud cover", + "nc_name": "total_cloud_cover", + "type": "surface_or_atmosphere", + "units": "%", + "dataset": "reanalysis-cerra-single-levels" + }, + "mean_sea_level_pressure": { + "cds_name": "mean_sea_level_pressure", + "long_name": "Mean sea level pressure", + "nc_name": "mean_sea_level_pressure", + "type": "surface_or_atmosphere", + "units": "Pa", + "dataset": "reanalysis-cerra-single-levels" + } + } +} \ No newline at end of file diff --git a/bluemath_tk/downloaders/copernicus/ERA5/ERA5_config.json b/bluemath_tk/downloaders/copernicus/ERA5/ERA5_config.json index 35eb119..83cce29 100644 --- a/bluemath_tk/downloaders/copernicus/ERA5/ERA5_config.json +++ b/bluemath_tk/downloaders/copernicus/ERA5/ERA5_config.json @@ -126,6 +126,7 @@ } } }, + "url": "https://cds.climate.copernicus.eu/api", "variables": { "swh": { "cds_name": "significant_height_of_combined_wind_waves_and_swell", diff --git a/bluemath_tk/downloaders/copernicus/copernicus_downloader.py b/bluemath_tk/downloaders/copernicus/copernicus_downloader.py index 06531c1..9e7c50c 100644 --- a/bluemath_tk/downloaders/copernicus/copernicus_downloader.py +++ b/bluemath_tk/downloaders/copernicus/copernicus_downloader.py @@ -1,132 +1,128 @@ -import calendar import json import os -from typing import List +from typing import Any, Dict, List, Optional import cdsapi -import xarray as xr from .._base_downloaders import BaseDownloader - -config = { - "url": "https://cds.climate.copernicus.eu/api", # /v2? - "key": "your-api-token", -} +from .._download_result import DownloadResult class CopernicusDownloader(BaseDownloader): """ - This is the main class to download data from the Copernicus Climate Data Store. - - Attributes - ---------- - product : str - The product to download data from. Currently only ERA5 is supported. - product_config : dict - The configuration for the product to download data from. - client : cdsapi.Client - The client to interact with the Copernicus Climate Data Store API. + Simple downloader for Copernicus Climate Data Store. Examples -------- - .. jupyter-execute:: - - from bluemath_tk.downloaders.copernicus.copernicus_downloader import CopernicusDownloader - - copernicus_downloader = CopernicusDownloader( - product="ERA5", - base_path_to_download="/path/to/Copernicus/", # Will be created if not available - token=None, - check=True, - ) - result = copernicus_downloader.download_data_era5( - variables=["swh"], - years=["2020"], - months=["01", "03"], - ) - print(result) + >>> downloader = CopernicusDownloader( + ... product="ERA5", + ... base_path_to_download="./copernicus_data", + ... token="your_token" + ... ) + >>> result = downloader.download_data( + ... variables=["swh"], + ... years=["2020"], + ... months=["01"], + ... force=False, + ... dry_run=False + ... ) """ products_configs = { "ERA5": json.load( open(os.path.join(os.path.dirname(__file__), "ERA5", "ERA5_config.json")) - ) + ), + "CERRA": json.load( + open(os.path.join(os.path.dirname(__file__), "CERRA", "CERRA_config.json")) + ), } def __init__( self, product: str, base_path_to_download: str, - token: str = None, + api_key: str, debug: bool = True, - check: bool = True, ) -> None: """ - This is the constructor for the CopernicusDownloader class. 
+ Initialize the CopernicusDownloader. Parameters ---------- product : str - The product to download data from. Currently only ERA5 is supported. + The product to download data from (e.g., "ERA5", "CERRA"). base_path_to_download : str - The base path to download the data to. - token : str, optional - The API token to use to download data. Default is None. + Base path where downloaded files will be stored. + api_key : str + Copernicus CDS API key. debug : bool, optional - Whether to run in debug mode. Default is True. - check : bool, optional - Whether to just check the data. Default is True. + If True, sets logger to DEBUG level. Default is True. Raises ------ ValueError - If the product configuration is not found. + If the product configuration is not found or server URL is not specified. """ super().__init__( - base_path_to_download=base_path_to_download, debug=debug, check=check + product=product, base_path_to_download=base_path_to_download, debug=debug ) - self._product = product + self._product_config = self.products_configs.get(product) if self._product_config is None: - raise ValueError(f"{product} configuration not found") + raise ValueError( + f"Product '{product}' not found. Available: {list(self.products_configs.keys())}" + ) + self.set_logger_name( f"CopernicusDownloader-{product}", level="DEBUG" if debug else "INFO" ) - if not self.check: - self._client = cdsapi.Client( - url=config["url"], key=token or config["key"], debug=self.debug - ) - self.logger.info("---- DOWNLOADING DATA ----") - else: - self.logger.info("---- CHECKING DATA ----") - @property - def product(self) -> str: - return self._product + # Initialize CDS client + server_url = self._product_config.get("url") + if server_url is None: + raise ValueError("Server URL not found in product configuration") + self._client = cdsapi.Client(url=server_url, key=api_key, debug=self.debug) + + self.logger.info(f"---- COPERNICUS DOWNLOADER INITIALIZED ({product}) ----") @property def product_config(self) -> dict: + """ + Product configuration dictionary loaded from config file. + + Returns + ------- + dict + Product configuration dictionary. + """ return self._product_config @property def client(self) -> cdsapi.Client: + """ + CDS API client (initialized with API key). + + Returns + ------- + cdsapi.Client + CDS API client instance. + """ return self._client def list_variables(self, type: str = None) -> List[str]: """ - Lists the variables available for the product. - Filtering by type if provided. + List variables available for the product. Parameters ---------- type : str, optional - The type of variables to list. Default is None. + Filter by type (e.g., "ocean"). Default is None. Returns ------- List[str] - The list of variables available for the product. + List of variable names. """ if type == "ocean": @@ -135,61 +131,34 @@ def list_variables(self, type: str = None) -> List[str]: for var_name, var_info in self.product_config["variables"].items() if var_info["type"] == "ocean" ] - return list(self.product_config["variables"].keys()) - - def list_datasets(self) -> List[str]: - """ - Lists the datasets available for the product. - - Returns - ------- - List[str] - The list of datasets available for the product. - """ - return list(self.product_config["datasets"].keys()) + return list(self.product_config["variables"].keys()) - def show_markdown_table(self) -> None: - """ - Create a Markdown table from the configuration dictionary and print it. 
+ def download_data( + self, + dry_run: bool = True, + *args, + **kwargs, + ) -> DownloadResult: """ + Download data for the product. - # Define the table headers - headers = ["name", "long_name", "units", "type"] - header_line = "| " + " | ".join(headers) + " |" - separator_line = ( - "| " + " | ".join(["-" * len(header) for header in headers]) + " |" - ) - - # Initialize the table with headers - table_lines = [header_line, separator_line] - - # Add rows for each variable - for var_name, var_info in self.product_config["variables"].items(): - long_name = var_info.get("long_name", "") - units = var_info.get("units", "") - type = var_info.get("type", "") - row = f"| {var_name} | {long_name} | {units} | {type} |" - table_lines.append(row) - - # Print the table - print("\n".join(table_lines)) - - def download_data(self, *args, **kwargs) -> str: - """ - Downloads the data for the product. + Routes to product-specific download methods based on the product type. Parameters ---------- + dry_run : bool, optional + If True, only check what would be downloaded without actually downloading. + Default is True. *args - The arguments to pass to the download function. + Arguments passed to product-specific download method. **kwargs - The keyword arguments to pass to the download function. + Keyword arguments passed to product-specific download method. Returns ------- - str - The message with the fully downloaded files and the not fully downloaded files. + DownloadResult + Result with information about downloaded, skipped, and error files. Raises ------ @@ -198,7 +167,9 @@ def download_data(self, *args, **kwargs) -> str: """ if self.product == "ERA5": - return self.download_data_era5(*args, **kwargs) + return self.download_data_era5(dry_run=dry_run, *args, **kwargs) + elif self.product == "CERRA": + return self.download_data_cerra(dry_run=dry_run, *args, **kwargs) else: raise ValueError(f"Download for product {self.product} not supported") @@ -214,267 +185,510 @@ def download_data_era5( data_format: str = "netcdf", download_format: str = "unarchived", force: bool = False, - ) -> str: + dry_run: bool = True, + ) -> DownloadResult: """ - Downloads the data for the ERA5 product. + Download ERA5 data. + + Downloads ERA5 reanalysis data for specified variables, time periods, and optionally + a geographic area. Files are saved to: + base_path_to_download/product/dataset/type/product_type/variable/filename.nc Parameters ---------- variables : List[str] - The variables to download. If not provided, all variables in self.product_config - will be downloaded. + List of variable names to download. If empty, downloads all available variables. years : List[str] - The years to download. Years are downloaded one by one. + List of years to download (e.g., ["2020", "2021"]). months : List[str] - The months to download. Months are downloaded together. + List of months to download (e.g., ["01", "02"]). days : List[str], optional - The days to download. If None, all days in the month will be downloaded. - Default is None. + List of days to download. If None, downloads all days (1-31). Default is None. times : List[str], optional - The times to download. If None, all times in the day will be downloaded. + List of times to download (e.g., ["00:00", "12:00"]). If None, downloads all hours. Default is None. area : List[float], optional - The area to download. If None, the whole globe will be downloaded. + Geographic area as [north, west, south, east]. If None, downloads global data. Default is None. 
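The area parameter follows the CDS convention [north, west, south, east], and dry_run defaults to True, so a regional request can be checked before anything is fetched. A sketch with a placeholder API key and an arbitrary Bay of Biscay box:

    from bluemath_tk.downloaders.copernicus.copernicus_downloader import CopernicusDownloader

    era5 = CopernicusDownloader(
        product="ERA5",
        base_path_to_download="./copernicus_data",
        api_key="CDS_API_KEY",          # placeholder
    )

    result = era5.download_data_era5(
        variables=["swh"],
        years=["2020"],
        months=["01", "02"],
        times=["00:00", "12:00"],
        area=[46.0, -6.0, 43.0, -1.0],  # [north, west, south, east]
        dry_run=True,
    )
    print(result.message)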
product_type : str, optional - The product type to download. Default is "reanalysis". + Product type (e.g., "reanalysis", "ensemble_mean"). Default is "reanalysis". data_format : str, optional - The data format to download. Default is "netcdf". + Data format. Default is "netcdf". download_format : str, optional - The download format to use. Default is "unarchived". + Download format. Default is "unarchived". force : bool, optional - Whether to force the download. Default is False. + Force re-download even if file exists. Default is False. + dry_run : bool, optional + If True, only check what would be downloaded. Default is True. Returns ------- - str - The message with the fully downloaded files and the not fully downloaded files. - Error files are also included. + DownloadResult + Result with all downloaded files and download statistics. - TODO - ----- - - Implement lambda function to name the files. + Raises + ------ + ValueError + If years or months are empty lists. """ - if not isinstance(variables, list): - raise ValueError("Variables must be a list of strings") - elif len(variables) == 0: + if not isinstance(variables, list) or len(variables) == 0: variables = list(self.product_config["variables"].keys()) - self.logger.info(f"Variables not provided. Using {variables}") if not isinstance(years, list) or len(years) == 0: - raise ValueError("Years must be a non-empty list of strings") - else: - years = [f"{int(year):04d}" for year in years] + raise ValueError("Years must be a non-empty list") + years = [f"{int(year):04d}" for year in years] if not isinstance(months, list) or len(months) == 0: - raise ValueError("Months must be a non-empty list of strings") - else: - months = [f"{int(month):02d}" for month in months] - last_month = months[-1] - if days is not None: - if not isinstance(days, list) or len(days) == 0: - raise ValueError("Day must be a non-empty list of strings") - else: + raise ValueError("Months must be a non-empty list") + months = [f"{int(month):02d}" for month in months] + last_month = months[-1] + if days is None: days = [f"{day:02d}" for day in range(1, 32)] - self.logger.info(f"Day not provided. Using {days}") - if times is not None: - if not isinstance(times, list) or len(times) == 0: - raise ValueError("Time must be a non-empty list of strings") - else: + if times is None: times = [f"{hour:02d}:00" for hour in range(24)] - self.logger.info(f"Time not provided. 
Using {times}") - if area is not None: - if not isinstance(area, list) or len(area) != 4: - raise ValueError("Area must be a list of 4 floats") - if not isinstance(product_type, str): - raise ValueError("Product type must be a string") - if not isinstance(data_format, str): - raise ValueError("Data format must be a string") - if not isinstance(download_format, str): - raise ValueError("Download format must be a string") - if not isinstance(force, bool): - raise ValueError("Force must be a boolean") - - fully_downloaded_files: List[str] = [] - NOT_fullly_downloaded_files: List[str] = [] - error_files: List[str] = [] + result = self.create_download_result() + + # Prepare download tasks + download_tasks = [] for variable in variables: for year in years: - variable_config = self.product_config["variables"].get(variable) - if variable_config is None: - self.logger.error( - f"Variable {variable} not found in product configuration file" - ) - continue - variable_dataset = self.product_config["datasets"].get( - variable_config["dataset"] + task = self._prepare_era5_download_task( + variable=variable, + year=year, + months=months, + days=days, + times=times, + area=area, + product_type=product_type, + data_format=data_format, + download_format=download_format, + last_month=last_month, + ) + if task is not None: + download_tasks.append(task) + + if not download_tasks: + return self.finalize_download_result( + result, "No valid download tasks found" + ) + + self.logger.info(f"Prepared {len(download_tasks)} download tasks") + + # Download files sequentially + for task in download_tasks: + task_result = self._download_single_file(task, force=force, dry_run=dry_run) + if isinstance(task_result, DownloadResult): + result.downloaded_files.extend(task_result.downloaded_files) + result.skipped_files.extend(task_result.skipped_files) + result.error_files.extend(task_result.error_files) + result.errors.extend(task_result.errors) + + return self.finalize_download_result(result) + + def download_data_cerra( + self, + variables: List[str], + years: List[str], + months: List[str], + days: List[str] = None, + times: List[str] = None, + area: List[float] = None, + level_type: str = "surface_or_atmosphere", + data_type: List[str] = None, + product_type: str = "analysis", + data_format: str = "netcdf", + force: bool = False, + dry_run: bool = True, + ) -> DownloadResult: + """ + Download CERRA data. + + Downloads CERRA reanalysis data for specified variables, time periods, and optionally + a geographic area. Files are saved to: + base_path_to_download/product/dataset/type/product_type/variable/filename.nc + + Parameters + ---------- + variables : List[str] + List of variable names to download. If empty, downloads all available variables. + years : List[str] + List of years to download (e.g., ["2020", "2021"]). + months : List[str] + List of months to download (e.g., ["01", "02"]). + days : List[str], optional + List of days to download. If None, downloads all days (1-31). Default is None. + times : List[str], optional + List of times to download (e.g., ["00:00", "12:00"]). If None, downloads standard + times (00:00, 03:00, 06:00, 09:00, 12:00, 15:00, 18:00, 21:00). Default is None. + area : List[float], optional + Geographic area as [north, west, south, east]. If None, downloads global data. + Default is None. + level_type : str, optional + Level type (e.g., "surface_or_atmosphere"). Default is "surface_or_atmosphere". + data_type : List[str], optional + Data type (e.g., ["reanalysis"]). If None, uses ["reanalysis"]. 
Default is None. + product_type : str, optional + Product type (e.g., "analysis", "forecast"). Default is "analysis". + data_format : str, optional + Data format. Default is "netcdf". + force : bool, optional + Force re-download even if file exists. Default is False. + dry_run : bool, optional + If True, only check what would be downloaded. Default is True. + + Returns + ------- + DownloadResult + Result with all downloaded files and download statistics. + + Raises + ------ + ValueError + If years or months are empty lists. + """ + + if not isinstance(variables, list) or len(variables) == 0: + variables = list(self.product_config["variables"].keys()) + if not isinstance(years, list) or len(years) == 0: + raise ValueError("Years must be a non-empty list") + years = [f"{int(year):04d}" for year in years] + if not isinstance(months, list) or len(months) == 0: + raise ValueError("Months must be a non-empty list") + months = [f"{int(month):02d}" for month in months] + last_month = months[-1] + if days is None: + days = [f"{day:02d}" for day in range(1, 32)] + if times is None: + times = [ + "00:00", + "03:00", + "06:00", + "09:00", + "12:00", + "15:00", + "18:00", + "21:00", + ] + if data_type is None: + data_type = ["reanalysis"] + + result = self.create_download_result() + + # Prepare download tasks + download_tasks = [] + for variable in variables: + for year in years: + task = self._prepare_cerra_download_task( + variable=variable, + year=year, + months=months, + days=days, + times=times, + area=area, + level_type=level_type, + data_type=data_type, + product_type=product_type, + data_format=data_format, + last_month=last_month, ) - if variable_dataset is None: + if task is not None: + download_tasks.append(task) + + if not download_tasks: + return self.finalize_download_result( + result, "No valid download tasks found" + ) + + self.logger.info(f"Prepared {len(download_tasks)} download tasks") + + # Download files sequentially + for task in download_tasks: + task_result = self._download_single_file(task, force=force, dry_run=dry_run) + if isinstance(task_result, DownloadResult): + result.downloaded_files.extend(task_result.downloaded_files) + result.skipped_files.extend(task_result.skipped_files) + result.error_files.extend(task_result.error_files) + result.errors.extend(task_result.errors) + + return self.finalize_download_result(result) + + def _prepare_era5_download_task( + self, + variable: str, + year: str, + months: List[str], + days: List[str], + times: List[str], + area: Optional[List[float]], + product_type: str, + data_format: str, + download_format: str, + last_month: str, + ) -> Optional[Dict[str, Any]]: + """ + Prepare a download task for ERA5. + + Creates a task dictionary with all necessary information for downloading + a single variable for a single year. + + Parameters + ---------- + variable : str + Variable name. + year : str + Year (formatted as "YYYY"). + months : List[str] + List of months (formatted as "MM"). + days : List[str] + List of days (formatted as "DD"). + times : List[str] + List of times (formatted as "HH:MM"). + area : Optional[List[float]] + Geographic area as [north, west, south, east] or None. + product_type : str + Product type. + data_format : str + Data format. + download_format : str + Download format. + last_month : str + Last month in the list (used for date range formatting). + + Returns + ------- + Optional[Dict[str, Any]] + Task dictionary with download information, or None if configuration is invalid. 
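# --- Hedged usage sketch (not part of the patch) for the refactored CERRA flow above.
# The Copernicus downloader's module path and class name are not visible in this hunk,
# so the import below is an assumption; only the download_data_cerra keyword arguments
# and the DownloadResult attributes are taken from the patch. A dry run (the default)
# only reports what would be fetched; the second call performs the CDS requests.
from bluemath_tk.downloaders.copernicus.copernicus_downloader import (  # assumed module path
    CopernicusDownloader,  # assumed class name
)

downloader = CopernicusDownloader(
    product="CERRA",                      # assumed product key
    base_path_to_download="./copernicus_data",
    debug=False,
)

preview = downloader.download_data_cerra(
    variables=["2m_temperature"],         # illustrative variable name
    years=["2020"],
    months=["01", "02"],
    dry_run=True,                         # default: report only, no download
)
print(f"{len(preview.skipped_files)} file(s) would be downloaded")

result = downloader.download_data_cerra(
    variables=["2m_temperature"],
    years=["2020"],
    months=["01", "02"],
    force=False,
    dry_run=False,                        # perform the downloads
)
print(f"downloaded {len(result.downloaded_files)}, errors {len(result.error_files)}")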
+ """ + + variable_config = self.product_config["variables"].get(variable) + if variable_config is None: + self.logger.error(f"Variable {variable} not found in configuration") + return None + + variable_dataset = self.product_config["datasets"].get( + variable_config["dataset"] + ) + if variable_dataset is None: + self.logger.error( + f"Dataset {variable_config['dataset']} not found in configuration" + ) + return None + + template_for_variable = variable_dataset["template"].copy() + if variable == "spectra": + template_for_variable["date"] = ( + f"{year}-{months[0]}-01/to/{year}-{months[-1]}-31" + ) + if area is not None: + template_for_variable["area"] = "/".join([str(coord) for coord in area]) + else: + template_for_variable["variable"] = variable_config["cds_name"] + template_for_variable["year"] = year + template_for_variable["month"] = months + template_for_variable["day"] = days + template_for_variable["time"] = times + template_for_variable["product_type"] = product_type + template_for_variable["data_format"] = data_format + template_for_variable["download_format"] = download_format + if area is not None: + template_for_variable["area"] = area + + # Check mandatory fields + for mandatory_field in variable_dataset["mandatory_fields"]: + if template_for_variable.get(mandatory_field) is None: + try: + template_for_variable[mandatory_field] = variable_config[ + mandatory_field + ] + except KeyError: self.logger.error( - f"Dataset {variable_config['dataset']} not found in product configuration file" + f"Mandatory field {mandatory_field} not found for {variable}" ) - continue + return None + + # Create output file path + output_nc_file = os.path.join( + self.base_path_to_download, + self.product, + variable_config["dataset"], + variable_config["type"], + product_type, + variable_config["cds_name"], + f"{variable_config['nc_name']}_{year}_{'_'.join(months)}.nc", + ) - template_for_variable = variable_dataset["template"].copy() - if variable == "spectra": - template_for_variable["date"] = ( - f"{year}-{months[0]}-01/to/{year}-{months[-1]}-31" - ) - if area is not None: - template_for_variable["area"] = "/".join( - [str(coord) for coord in area] - ) - else: - template_for_variable["variable"] = variable_config["cds_name"] - template_for_variable["year"] = year - template_for_variable["month"] = months - template_for_variable["day"] = days - template_for_variable["time"] = times - template_for_variable["product_type"] = product_type - template_for_variable["data_format"] = data_format - template_for_variable["download_format"] = download_format - if area is not None: - template_for_variable["area"] = area - - self.logger.info( - f""" - Template for variable {variable}: - {template_for_variable} - """ - ) + return { + "variable": variable, + "year": year, + "variable_config": variable_config, + "variable_dataset": variable_dataset, + "template": template_for_variable, + "output_file": output_nc_file, + "last_month": last_month, + } + + def _prepare_cerra_download_task( + self, + variable: str, + year: str, + months: List[str], + days: List[str], + times: List[str], + area: Optional[List[float]], + level_type: str, + data_type: List[str], + product_type: str, + data_format: str, + last_month: str, + ) -> Optional[Dict[str, Any]]: + """ + Prepare a download task for CERRA. + + Creates a task dictionary with all necessary information for downloading + a single variable for a single year. + + Parameters + ---------- + variable : str + Variable name. + year : str + Year (formatted as "YYYY"). 
+ months : List[str] + List of months (formatted as "MM"). + days : List[str] + List of days (formatted as "DD"). + times : List[str] + List of times (formatted as "HH:MM"). + area : Optional[List[float]] + Geographic area as [north, west, south, east] or None. + level_type : str + Level type. + data_type : List[str] + Data type list. + product_type : str + Product type. + data_format : str + Data format. + last_month : str + Last month in the list (used for date range formatting). + + Returns + ------- + Optional[Dict[str, Any]] + Task dictionary with download information, or None if configuration is invalid. + """ + + variable_config = self.product_config["variables"].get(variable) + if variable_config is None: + self.logger.error(f"Variable {variable} not found in configuration") + return None - skip_because_of_manadatory_fields = False - for mandatory_field in variable_dataset["mandatory_fields"]: - try: - if template_for_variable.get(mandatory_field) is None: - template_for_variable[mandatory_field] = variable_config[ - mandatory_field - ] - except KeyError: - self.logger.error( - f"Mandotory field {mandatory_field} not found in variable configuration file for {variable}" - ) - skip_because_of_manadatory_fields = True - if skip_because_of_manadatory_fields: - continue - - # Create the output file name once request is properly formatted - output_nc_file = os.path.join( - self.base_path_to_download, - self.product, - variable_config["dataset"], - variable_config["type"], - product_type, - variable_config["cds_name"], - f"{variable_config['nc_name']}_{year}_{'_'.join(months)}.nc", - # f"era5_waves_{variable_config['cds_name']}_{year}.nc", + variable_dataset = self.product_config["datasets"].get( + variable_config["dataset"] + ) + if variable_dataset is None: + self.logger.error( + f"Dataset {variable_config['dataset']} not found in configuration" + ) + return None + + template_for_variable = variable_dataset["template"].copy() + template_for_variable["variable"] = [variable_config["cds_name"]] + template_for_variable["level_type"] = level_type + template_for_variable["data_type"] = data_type + template_for_variable["product_type"] = product_type + template_for_variable["year"] = [year] + template_for_variable["month"] = months + template_for_variable["day"] = days + template_for_variable["time"] = times + template_for_variable["data_format"] = data_format + + if area is not None: + template_for_variable["area"] = area + + # Check mandatory fields + for mandatory_field in variable_dataset["mandatory_fields"]: + if template_for_variable.get(mandatory_field) is None: + self.logger.error( + f"Mandatory field {mandatory_field} not found for {variable}" ) - # Create the output directory if it does not exist - if not self.check: - os.makedirs(os.path.dirname(output_nc_file), exist_ok=True) + return None + + # Create output file path + output_nc_file = os.path.join( + self.base_path_to_download, + self.product, + variable_config["dataset"], + variable_config["type"], + product_type, + variable_config["cds_name"], + f"{variable_config['nc_name']}_{year}_{'_'.join(months)}.nc", + ) - self.logger.info(f""" - - Analyzing {output_nc_file} + return { + "variable": variable, + "year": year, + "variable_config": variable_config, + "template": template_for_variable, + "last_month": last_month, + "output_file": output_nc_file, + } + + def _download_single_file( + self, task: Dict[str, Any], force: bool = False, dry_run: bool = True + ) -> DownloadResult: + """ + Download a single file based on a task 
dictionary. - """) + Parameters + ---------- + task : Dict[str, Any] + Task dictionary containing download information (output_file, template, etc.). + force : bool, optional + Force re-download even if file exists. Default is False. + dry_run : bool, optional + If True, only check what would be downloaded. Default is True. - try: - if self.check or not force: - if os.path.exists(output_nc_file): - self.logger.debug( - f"Checking {output_nc_file} file is complete" - ) - try: - nc = xr.open_dataset(output_nc_file) - _, last_day = calendar.monthrange( - int(year), int(last_month) - ) - last_hour = f"{year}-{last_month}-{last_day}T23" - try: - last_hour_nc = str(nc.time[-1].values) - except Exception as _te: - last_hour_nc = str(nc.valid_time[-1].values) - nc.close() - if last_hour not in last_hour_nc: - self.logger.debug( - f"{output_nc_file} ends at {last_hour_nc} instead of {last_hour}" - ) - if self.check: - NOT_fullly_downloaded_files.append( - output_nc_file - ) - else: - self.logger.debug( - f"Downloading: {variable} to {output_nc_file} because it is not complete" - ) - self.client.retrieve( - name=variable_config["dataset"], - request=template_for_variable, - target=output_nc_file, - ) - fully_downloaded_files.append(output_nc_file) - else: - self.logger.debug( - f"{output_nc_file} already downloaded and complete" - ) - fully_downloaded_files.append(output_nc_file) - except Exception as e: - self.logger.error( - f"Error was raised opening {output_nc_file} - {e}, re-downloading..." - ) - if self.check: - NOT_fullly_downloaded_files.append(output_nc_file) - else: - self.logger.debug( - f"Downloading: {variable} to {output_nc_file} because it is not complete" - ) - self.client.retrieve( - name=variable_config["dataset"], - request=template_for_variable, - target=output_nc_file, - ) - fully_downloaded_files.append(output_nc_file) - elif self.check: - NOT_fullly_downloaded_files.append(output_nc_file) - else: - self.logger.debug( - f"Downloading: {variable} to {output_nc_file}" - ) - self.client.retrieve( - name=variable_config["dataset"], - request=template_for_variable, - target=output_nc_file, - ) - fully_downloaded_files.append(output_nc_file) - else: - self.logger.debug( - f"Downloading: {variable} to {output_nc_file}" - ) - self.client.retrieve( - name=variable_config["dataset"], - request=template_for_variable, - target=output_nc_file, - ) - fully_downloaded_files.append(output_nc_file) - - except Exception as e: - self.logger.error(f""" - - Skippping {output_nc_file} for {e} - - """) - error_files.append(output_nc_file) - - fully_downloaded_files_str = "\n".join(fully_downloaded_files) - NOT_fullly_downloaded_files_str = "\n".join(NOT_fullly_downloaded_files) - error_files = "\n".join(error_files) - - return f""" - Fully downloaded files: - {fully_downloaded_files_str} - Not fully downloaded files: - {NOT_fullly_downloaded_files_str} - Error files: - {error_files} + Returns + ------- + DownloadResult + Result with information about the downloaded, skipped, or error file. 
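# Hedged illustration of the DownloadResult contract relied on above. The no-argument
# constructor, add_downloaded/add_skipped/add_error, and the downloaded_files /
# skipped_files / error_files / errors attributes are all taken from how this patch
# uses them; the class definition itself lives in .._download_result and is not shown
# in this hunk.
from bluemath_tk.downloaders._download_result import DownloadResult

per_file = DownloadResult()
per_file.add_skipped("/tmp/example_2020_01.nc", "Would download (dry run)")

# Callers such as download_data_cerra merge per-task results into one aggregate
# result before finalizing it, exactly as in the loops above.
aggregate = DownloadResult()
aggregate.downloaded_files.extend(per_file.downloaded_files)
aggregate.skipped_files.extend(per_file.skipped_files)
aggregate.error_files.extend(per_file.error_files)
aggregate.errors.extend(per_file.errors)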
""" + + result = DownloadResult() + output_file = task["output_file"] + variable = task["variable"] + variable_config = task["variable_config"] + template = task["template"] + + if not dry_run: + os.makedirs(os.path.dirname(output_file), exist_ok=True) + + try: + # Check if file already exists + if not force and os.path.exists(output_file): + if dry_run: + result.add_skipped(output_file, "File already exists (dry run)") + else: + result.add_downloaded(output_file) + return result + + if dry_run: + result.add_skipped(output_file, f"Would download {variable} (dry run)") + return result + + # Download file + self.logger.debug(f"Downloading: {variable} to {output_file}") + self.client.retrieve( + name=variable_config["dataset"], + request=template, + target=output_file, + ) + result.add_downloaded(output_file) + self.logger.info(f"Downloaded: {output_file}") + + except Exception as e: + self.logger.error(f"Error downloading {output_file}: {e}") + result.add_error(output_file, e) + + return result diff --git a/bluemath_tk/downloaders/copernicus/copernicus_marine_downloader.py b/bluemath_tk/downloaders/copernicus/copernicus_marine_downloader.py deleted file mode 100644 index e110ff6..0000000 --- a/bluemath_tk/downloaders/copernicus/copernicus_marine_downloader.py +++ /dev/null @@ -1,33 +0,0 @@ -import copernicusmarine - -copernicusmarine.subset( - dataset_id="cmems_mod_glo_wav_my_0.2deg_PT3H-i", - dataset_version="202411", - variables=[ - "VHM0", - "VHM0_SW1", - "VHM0_SW2", - "VHM0_WW", - "VMDR", - "VMDR_SW1", - "VMDR_SW2", - "VMDR_WW", - "VPED", - "VSDX", - "VSDY", - "VTM01_SW1", - "VTM01_SW2", - "VTM01_WW", - "VTM02", - "VTM10", - "VTPK", - ], - minimum_longitude=-10.43452696741375, - maximum_longitude=-0.5556814090161573, - minimum_latitude=42.03998421470398, - maximum_latitude=46.133428506857676, - start_datetime="1980-01-01T00:00:00", - end_datetime="2023-04-30T21:00:00", - coordinates_selection_method="strict-inside", - disable_progress_bar=False, -) diff --git a/bluemath_tk/downloaders/ecmwf/ecmwf_downloader.py b/bluemath_tk/downloaders/ecmwf/ecmwf_downloader.py index 8f4f010..f102264 100644 --- a/bluemath_tk/downloaders/ecmwf/ecmwf_downloader.py +++ b/bluemath_tk/downloaders/ecmwf/ecmwf_downloader.py @@ -1,44 +1,32 @@ import json import os -from typing import List, Union -import xarray as xr from ecmwf.opendata import Client from .._base_downloaders import BaseDownloader +from .._download_result import DownloadResult class ECMWFDownloader(BaseDownloader): """ This is the main class to download data from the ECMWF. - Attributes - ---------- - product : str - The product to download data from. Currently only OpenData is supported. - product_config : dict - The configuration for the product to download data from. - client : ecmwf.opendata.Client - The client to interact with the ECMWF API. - Examples -------- - .. jupyter-execute:: - - from bluemath_tk.downloaders.ecmwf.ecmwf_downloader import ECMWFDownloader - - ecmwf_downloader = ECMWFDownloader( - product="OpenData", - base_path_to_download="/path/to/ECMWF/", # Will be created if not available - check=True, - ) - dataset = ecmwf_downloader.download_data( - load_data=False, - param=["msl"], - step=[0, 240], - type="fc", - ) - print(dataset) + >>> downloader = ECMWFDownloader( + ... product="OpenData", + ... base_path_to_download="./ecmwf_data", + ... model="ifs", + ... resolution="0p25" + ... ) + >>> result = downloader.download_data( + ... dataset="forecast_data", + ... param=["msl"], + ... step=[0, 240], + ... type="fc", + ... 
force=False, + ... dry_run=False + ... ) """ products_configs = { @@ -58,110 +46,117 @@ def __init__( model: str = "ifs", resolution: str = "0p25", debug: bool = True, - check: bool = True, ) -> None: """ - This is the constructor for the ECMWFDownloader class. + Initialize the ECMWFDownloader. Parameters ---------- product : str The product to download data from. Currently only OpenData is supported. base_path_to_download : str - The base path to download the data to. + Base path where downloaded files will be stored. model : str, optional - The model to download data from. Default is "ifs". + The model to download data from (e.g., "ifs", "aifs"). Default is "ifs". resolution : str, optional - The resolution to download data from. Default is "0p25". + The resolution to download data from (e.g., "0p25"). Default is "0p25". debug : bool, optional - Whether to run in debug mode. Default is True. - check : bool, optional - Whether to just check the data. Default is True. + If True, sets logger to DEBUG level. Default is True. Raises ------ ValueError - If the product configuration is not found. + If the product configuration is not found, or if model/resolution are not supported. """ super().__init__( - base_path_to_download=base_path_to_download, debug=debug, check=check + product=product, base_path_to_download=base_path_to_download, debug=debug ) - self._product = product + self._product_config = self.products_configs.get(product) if self._product_config is None: - raise ValueError(f"{product} configuration not found") + available_products = list(self.products_configs.keys()) + raise ValueError( + f"{product} configuration not found. Available products: {available_products}" + ) + self.set_logger_name( f"ECMWFDownloader-{product}", level="DEBUG" if debug else "INFO" ) - if not self.check: - if model not in self.product_config["datasets"]["forecast_data"]["models"]: - raise ValueError(f"Model {model} not supported for {self.product}") - if ( - resolution - not in self.product_config["datasets"]["forecast_data"]["resolutions"] - ): - raise ValueError( - f"Resolution {resolution} not supported for {self.product}" - ) - self._client = Client( - source="ecmwf", - model=model, - resol=resolution, - preserve_request_order=False, - infer_stream_keyword=True, + + # Validate model and resolution + if model not in self.product_config["datasets"]["forecast_data"]["models"]: + raise ValueError(f"Model {model} not supported for {self.product}") + if ( + resolution + not in self.product_config["datasets"]["forecast_data"]["resolutions"] + ): + raise ValueError( + f"Resolution {resolution} not supported for {self.product}" ) - self.logger.info("---- DOWNLOADING DATA ----") - else: - self.logger.info("---- CHECKING DATA ----") + # Always initialize client (will skip API calls in dry_run mode) + self._client = Client( + source="ecmwf", + model=model, + resol=resolution, + preserve_request_order=False, + infer_stream_keyword=True, + ) + self.logger.info(f"---- ECMWF DOWNLOADER INITIALIZED ({product}) ----") # Set the model and resolution parameters self.model = model self.resolution = resolution - @property - def product(self) -> str: - return self._product - @property def product_config(self) -> dict: + """ + Product configuration dictionary loaded from config file. + + Returns + ------- + dict + Product configuration dictionary. 
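# Hedged illustration of the constructor validation introduced above: an unsupported
# model (or resolution) now raises ValueError at initialization time instead of at
# download time. "not_a_model" is a deliberately unsupported name (assumed not to be
# listed in the OpenData config).
from bluemath_tk.downloaders.ecmwf.ecmwf_downloader import ECMWFDownloader

try:
    ECMWFDownloader(
        product="OpenData",
        base_path_to_download="./ecmwf_data",
        model="not_a_model",
    )
except ValueError as err:
    print(err)  # "Model not_a_model not supported for OpenData"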
+ """ return self._product_config @property def client(self) -> Client: - return self._client - - def list_datasets(self) -> List[str]: """ - Lists the datasets available for the product. + ECMWF OpenData client (initialized with model and resolution). Returns ------- - List[str] - The list of datasets available for the product. + Client + ECMWF OpenData client instance. """ - - return list(self.product_config["datasets"].keys()) + return self._client def download_data( - self, load_data: bool = False, *args, **kwargs - ) -> Union[str, xr.Dataset]: + self, + dry_run: bool = True, + *args, + **kwargs, + ) -> DownloadResult: """ - Downloads the data for the product. + Download data for the product. + + Routes to product-specific download methods based on the product type. Parameters ---------- - load_data : bool, optional - Whether to load the data into an xarray.Dataset. Default is False. + dry_run : bool, optional + If True, only check what would be downloaded without actually downloading. + Default is True. *args - The arguments to pass to the download function. + Arguments passed to product-specific download method. **kwargs - The keyword arguments to pass to the download function. + Keyword arguments passed to product-specific download method. Returns ------- - Union[str, xr.Dataset] - The path to the downloaded file if load_data is False, otherwise the xarray.Dataset. + DownloadResult + Result with information about downloaded, skipped, and error files. Raises ------ @@ -170,77 +165,113 @@ def download_data( """ if self.product == "OpenData": - downloaded_file_path = self.download_data_open_data(*args, **kwargs) - if load_data: - return xr.open_dataset(downloaded_file_path, engine="cfgrib") - else: - return downloaded_file_path + return self.download_data_open_data(dry_run=dry_run, *args, **kwargs) else: raise ValueError(f"Download for product {self.product} not supported") def download_data_open_data( self, + dataset: str, force: bool = False, + dry_run: bool = True, **kwargs, - ) -> str: + ) -> DownloadResult: """ - Downloads the data for the OpenData product. + Download data for the OpenData product. + + Downloads files based on the specified parameters. Files are saved to: + base_path_to_download/product/dataset/model/resolution/filename.grib2 Parameters ---------- + dataset : str + The dataset to download (e.g., "forecast_data"). + Use list_datasets() to see available datasets. force : bool, optional - Whether to force the download. Default is False. + Force re-download even if file exists. Default is False. + dry_run : bool, optional + If True, only check what would be downloaded. Default is True. **kwargs - The keyword arguments to pass to the download function. + Keyword arguments passed to the ECMWF client retrieve method + (e.g., param, step, type). Returns ------- - str - The path to the downloaded file. + DownloadResult + Result with all downloaded files and download statistics. + + Raises + ------ + ValueError + If dataset is not found. 
""" - if "param" in kwargs: - variables = kwargs["param"] - else: - variables = [] - if "step" in kwargs: - steps = kwargs["step"] - if not isinstance(steps, list): - steps = [steps] - else: - steps = [] - if "type" in kwargs: - type = kwargs["type"] - else: - type = "fc" - - output_grib_file = os.path.join( - self.base_path_to_download, - self.product, - self.model, - self.resolution, - f"{'_'.join(variables)}_{'_'.join(str(step) for step in steps)}_{type}.grib2", - ) - if not self.check: - os.makedirs(os.path.dirname(output_grib_file), exist_ok=True) + # Validate dataset + if dataset not in self.list_datasets(): + raise ValueError( + f"Dataset '{dataset}' not found. Available: {self.list_datasets()}" + ) + + result = self.create_download_result() - if self.check or not force: - if os.path.exists(output_grib_file): - self.logger.debug(f"{output_grib_file} already downloaded") + try: + # Extract parameters from kwargs + if "param" in kwargs: + variables = kwargs["param"] else: - if self.check: - self.logger.debug(f"{output_grib_file} not downloaded") - else: - self.logger.debug(f"Downloading: {output_grib_file}") - self.client.retrieve( - target=output_grib_file, - **kwargs, - ) - else: - self.logger.debug(f"Downloading: {output_grib_file}") - self.client.retrieve( - target=output_grib_file, - **kwargs, + variables = [] + if "step" in kwargs: + steps = kwargs["step"] + if not isinstance(steps, list): + steps = [steps] + else: + steps = [] + if "type" in kwargs: + type = kwargs["type"] + else: + type = "fc" + + # Construct output file path: base_path/product/dataset/model/resolution/filename.grib2 + output_grib_file = os.path.join( + self.base_path_to_download, + self.product, + dataset, + self.model, + self.resolution, + f"{'_'.join(variables)}_{'_'.join(str(step) for step in steps)}_{type}.grib2", ) - return output_grib_file + # Skip if file already exists (unless force=True) + if not force and os.path.exists(output_grib_file): + result.add_skipped(output_grib_file, "Already downloaded") + return self.finalize_download_result(result) + + # Handle dry run: record as skipped without actual download + if dry_run: + result.add_skipped(output_grib_file, "Would download (dry run)") + return self.finalize_download_result(result) + + # Attempt to download the file + try: + # Create local directory structure if needed + os.makedirs(os.path.dirname(output_grib_file), exist_ok=True) + + # Download the file + self.logger.debug(f"Downloading: {output_grib_file}") + self.client.retrieve( + target=output_grib_file, + **kwargs, + ) + + result.add_downloaded(output_grib_file) + self.logger.info(f"Downloaded: {output_grib_file}") + + except Exception as e: + result.add_error(output_grib_file, e) + self.logger.error(f"Error downloading {output_grib_file}: {e}") + + return self.finalize_download_result(result) + + except Exception as e: + result.add_error("download_operation", e) + return self.finalize_download_result(result) diff --git a/bluemath_tk/downloaders/noaa/NOAA_config.json b/bluemath_tk/downloaders/noaa/NDBC/NDBC_config.json similarity index 100% rename from bluemath_tk/downloaders/noaa/NOAA_config.json rename to bluemath_tk/downloaders/noaa/NDBC/NDBC_config.json diff --git a/bluemath_tk/downloaders/noaa/noaa_downloader.py b/bluemath_tk/downloaders/noaa/noaa_downloader.py index 5483a0b..09709d6 100644 --- a/bluemath_tk/downloaders/noaa/noaa_downloader.py +++ b/bluemath_tk/downloaders/noaa/noaa_downloader.py @@ -12,162 +12,482 @@ import xarray as xr from .._base_downloaders import BaseDownloader 
+from .._download_result import DownloadResult -class NOAADownloader(BaseDownloader): +def read_bulk_parameters( + base_path: str, buoy_id: str, years: Union[int, List[int]] +) -> Optional[pd.DataFrame]: """ - This is the main class to download and read data from NOAA. + Read bulk parameters for a specific buoy and year(s). - Attributes + Parameters ---------- - config : dict - The configuration for NOAA data sources loaded from JSON file. - base_path_to_download : Path + base_path : str Base path where the data is stored. - debug : bool - Whether to run in debug mode. - - Examples - -------- - .. jupyter-execute:: + buoy_id : str + The buoy ID. + years : Union[int, List[int]] + The year(s) to read data for. Can be a single year or a list of years. + + Returns + ------- + Optional[pd.DataFrame] + DataFrame containing the bulk parameters, or None if data not found. + """ - from bluemath_tk.downloaders.noaa.noaa_downloader import NOAADownloader + if isinstance(years, int): + years = [years] - noaa_downloader = NOAADownloader( - base_path_to_download="/path/to/NOAA/", # Will be created if not available - debug=True, - check=False, + all_data = [] + for year in years: + file_path = os.path.join( + base_path, + "NDBC", + "buoy_data", + buoy_id, + f"buoy_{buoy_id}_bulk_parameters.csv", ) + try: + df = pd.read_csv(file_path) + df["datetime"] = pd.to_datetime( + df["YYYY"].astype(str) + + "-" + + df["MM"].astype(str).str.zfill(2) + + "-" + + df["DD"].astype(str).str.zfill(2) + + " " + + df["hh"].astype(str).str.zfill(2) + + ":" + + df["mm"].astype(str).str.zfill(2) + ) + all_data.append(df) + except FileNotFoundError: + print(f"No bulk parameters file found for buoy {buoy_id} year {year}") + + if all_data: + return pd.concat(all_data, ignore_index=True).sort_values("datetime") + return None - # Download buoy bulk parameters and load DataFrame - result = noaa_downloader.download_data( - data_type="bulk_parameters", - buoy_id="41001", - years=[2020, 2021, 2022], - load_df=True + +def read_wave_spectra( + base_path: str, buoy_id: str, years: Union[int, List[int]] +) -> Optional[pd.DataFrame]: + """ + Read wave spectra data for a specific buoy and year(s). + + Parameters + ---------- + base_path : str + Base path where the data is stored. + buoy_id : str + The buoy ID. + years : Union[int, List[int]] + The year(s) to read data for. Can be a single year or a list of years. 
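# Hedged usage sketch for the module-level read_bulk_parameters helper above. The
# buoy id and base path are illustrative. Note that the bulk-parameters CSV is a
# single combined file per buoy, so the file path does not depend on the requested
# years.
from bluemath_tk.downloaders.noaa.noaa_downloader import read_bulk_parameters

bulk = read_bulk_parameters(base_path="./noaa_data", buoy_id="41001", years=[2020, 2021])
if bulk is not None:
    print(bulk["datetime"].min(), bulk["datetime"].max())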
+ + Returns + ------- + Optional[pd.DataFrame] + DataFrame containing the wave spectra, or None if data not found + """ + + if isinstance(years, int): + years = [years] + + all_data = [] + for year in years: + file_path = os.path.join( + base_path, + "NDBC", + "buoy_data", + buoy_id, + "wave_spectra", + f"buoy_{buoy_id}_spectra_{year}.csv", ) - print(result) + try: + df = pd.read_csv(file_path) + try: + df["date"] = pd.to_datetime( + df[["YYYY", "MM", "DD", "hh"]].rename( + columns={ + "YYYY": "year", + "MM": "month", + "DD": "day", + "hh": "hour", + } + ) + ) + df.drop(columns=["YYYY", "MM", "DD", "hh"], inplace=True) + except Exception: + df["date"] = pd.to_datetime( + df[["#YY", "MM", "DD", "hh", "mm"]].rename( + columns={ + "#YY": "year", + "MM": "month", + "DD": "day", + "hh": "hour", + "mm": "minute", + } + ) + ) + df.drop(columns=["#YY", "MM", "DD", "hh", "mm"], inplace=True) + df.set_index("date", inplace=True) + all_data.append(df) + except FileNotFoundError: + print(f"No wave spectra file found for buoy {buoy_id} year {year}") + + if all_data: + return pd.concat(all_data).sort_index() + return None + + +def _read_directional_file(file_path: Path) -> Optional[pd.DataFrame]: """ + Read a directional spectra file and return DataFrame with datetime index. - config = json.load( - open(os.path.join(os.path.dirname(__file__), "NOAA_config.json")) + Parameters + ---------- + file_path : Path + Path to the file to read + + Returns + ------- + Optional[pd.DataFrame] + DataFrame containing the directional spectra data, or None if data not found + """ + + print(f"Reading file: {file_path}") + try: + with gzip.open(file_path, "rt") as f: + # Read header lines until we find the frequencies + header_lines = [] + while True: + line = f.readline().strip() + if not line.startswith("#") and not line.startswith("YYYY"): + break + header_lines.append(line) + + # Parse frequencies + header = " ".join(header_lines) + try: + freqs = [float(x) for x in header.split()[5:]] + print(f"Found {len(freqs)} frequencies") + except (ValueError, IndexError) as e: + print(f"Error parsing frequencies: {e}") + return None + + # Read data + data = [] + dates = [] + # Process the first line + parts = line.strip().split() + if len(parts) >= 5: + try: + year, month, day, hour, minute = map(int, parts[:5]) + values = [float(x) for x in parts[5:]] + if len(values) == len(freqs): + dates.append(datetime(year, month, day, hour, minute)) + data.append(values) + except (ValueError, IndexError) as e: + print(f"Error parsing line: {e}") + + # Read remaining lines + for line in f: + parts = line.strip().split() + if len(parts) >= 5: + try: + year, month, day, hour, minute = map(int, parts[:5]) + values = [float(x) for x in parts[5:]] + if len(values) == len(freqs): + dates.append(datetime(year, month, day, hour, minute)) + data.append(values) + except (ValueError, IndexError) as e: + print(f"Error parsing line: {e}") + continue + + if not data: + print("No valid data points found in file") + return None + + df = pd.DataFrame(data, index=dates, columns=freqs) + print(f"Created DataFrame with shape: {df.shape}") + return df + + except Exception as e: + print(f"Error reading file {file_path}: {str(e)}") + return None + + +def read_directional_spectra( + base_path: str, buoy_id: str, years: Union[int, List[int]] +) -> Tuple[Optional[pd.DataFrame], ...]: + """ + Read directional spectra data for a specific buoy and year(s). + + Parameters + ---------- + base_path : str + Base path where the data is stored. 
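# Hedged usage sketch for read_wave_spectra above: it returns a DataFrame indexed by
# the parsed date, concatenated and sorted across the requested years. Paths and
# years are illustrative.
from bluemath_tk.downloaders.noaa.noaa_downloader import read_wave_spectra

spectra = read_wave_spectra(base_path="./noaa_data", buoy_id="41001", years=2020)
if spectra is not None:
    print(spectra.shape, spectra.index.min())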
+ buoy_id : str + The buoy ID + years : Union[int, List[int]] + The year(s) to read data for. Can be a single year or a list of years. + + Returns + ------- + Tuple[Optional[pd.DataFrame], ...] + Tuple containing DataFrames for alpha1, alpha2, r1, r2, and c11, + or None for each if data not found + """ + + if isinstance(years, int): + years = [years] + + results = { + "alpha1": [], + "alpha2": [], + "r1": [], + "r2": [], + "c11": [], + } + + for year in years: + dir_path = os.path.join( + base_path, + "NDBC", + "buoy_data", + buoy_id, + "directional_spectra", + ) + files = { + "alpha1": f"{buoy_id}d{year}.txt.gz", + "alpha2": f"{buoy_id}i{year}.txt.gz", + "r1": f"{buoy_id}j{year}.txt.gz", + "r2": f"{buoy_id}k{year}.txt.gz", + "c11": f"{buoy_id}w{year}.txt.gz", + } + + for name, filename in files.items(): + file_path = os.path.join(dir_path, filename) + try: + df = _read_directional_file(file_path) + if df is not None: + results[name].append(df) + except FileNotFoundError: + print(f"No {name} file found for buoy {buoy_id} year {year}") + + # Combine DataFrames for each coefficient if available + final_results = {} + for name, dfs in results.items(): + if dfs: + final_results[name] = pd.concat(dfs).sort_index() + else: + final_results[name] = None + + return ( + final_results["alpha1"], + final_results["alpha2"], + final_results["r1"], + final_results["r2"], + final_results["c11"], ) + +class NOAADownloader(BaseDownloader): + """ + This is the main class to download data from NOAA. + + Examples + -------- + >>> downloader = NOAADownloader( + ... product="NDBC", + ... base_path_to_download="./noaa_data", + ... debug=True + ... ) + >>> result = downloader.download_data( + ... data_type="bulk_parameters", + ... buoy_id="41001", + ... years=[2023], + ... dry_run=False + ... ) + >>> print(result) + """ + + products_configs = { + "NDBC": json.load( + open(os.path.join(os.path.dirname(__file__), "NDBC", "NDBC_config.json")) + ) + } + def __init__( self, + product: str, base_path_to_download: str, debug: bool = True, - check: bool = False, ) -> None: """ Initialize the NOAA downloader. Parameters ---------- + product : str + The product to download data from. Currently only NDBC is supported. base_path_to_download : str The base path to download the data to. debug : bool, optional Whether to run in debug mode. Default is True. - check : bool, optional - Whether to just check the data. Default is False. + + Raises + ------ + ValueError + If the product configuration is not found. """ super().__init__( - base_path_to_download=base_path_to_download, debug=debug, check=check + product=product, base_path_to_download=base_path_to_download, debug=debug ) - self.set_logger_name("NOAADownloader", level="DEBUG" if debug else "INFO") - if not self.check: - self.logger.info("---- DOWNLOADING NOAA DATA ----") - else: - self.logger.info("---- CHECKING NOAA DATA ----") + self._product_config = self.products_configs.get(product) + if self._product_config is None: + available_products = list(self.products_configs.keys()) + raise ValueError( + f"Product '{product}' not found. Available: {available_products}" + ) + + self.set_logger_name( + f"NOAADownloader-{product}", level="DEBUG" if debug else "INFO" + ) + self.logger.info(f"---- NOAA DOWNLOADER INITIALIZED ({product}) ----") @property - def datasets(self) -> dict: - return self.config["datasets"] + def product_config(self) -> dict: + """ + Product configuration dictionary loaded from config file. + + Returns + ------- + dict + Product configuration dictionary. 
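# Hedged usage sketch for read_directional_spectra defined above: it returns one
# DataFrame (or None) per Fourier coefficient, in the order alpha1, alpha2, r1, r2,
# c11, read from the d/i/j/k/w files respectively. Paths and years are illustrative.
from bluemath_tk.downloaders.noaa.noaa_downloader import read_directional_spectra

alpha1, alpha2, r1, r2, c11 = read_directional_spectra(
    base_path="./noaa_data", buoy_id="41001", years=[2020]
)
if alpha1 is not None:
    print(alpha1.shape)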
+ """ + return self._product_config @property def data_types(self) -> dict: - return self.config["data_types"] + """ + Data types configuration dictionary. + + Returns + ------- + dict + Dictionary of available data types and their configurations. + """ + return self.product_config["data_types"] def list_data_types(self) -> List[str]: """ - Lists the available data types. + List all available data types for the product. Returns ------- List[str] - The list of available data types. + List of available data type names. """ - return list(self.data_types.keys()) - def list_datasets(self) -> List[str]: + def _check_file_exists( + self, file_path: str, result: DownloadResult, force: bool, dry_run: bool + ) -> bool: """ - Lists the available datasets. + Check if file exists and handle accordingly. + + Parameters + ---------- + file_path : str + Path to the file to check. + result : DownloadResult + The download result to update. + force : bool + Whether to force re-download. + dry_run : bool + If True, only check files without downloading. Returns ------- - List[str] - The list of available datasets. + bool + True if should skip download (file exists or dry_run mode), False otherwise. """ - return list(self.datasets.keys()) + if not force and os.path.exists(file_path): + result.add_skipped(file_path, "File already exists") + return True - def show_markdown_table(self) -> None: - """ - Create a Markdown table from the configuration dictionary and print it. + if dry_run: + result.add_skipped(file_path, "File does not exist (dry run)") + return True + + return False + + def download_data(self, dry_run: bool = True, *args, **kwargs) -> DownloadResult: """ + Download data for the product. - # Define the table headers - headers = ["name", "long_name", "description", "dataset"] - header_line = "| " + " | ".join(headers) + " |" - separator_line = ( - "| " + " | ".join(["-" * len(header) for header in headers]) + " |" - ) + Routes to product-specific download methods based on the product type. + + Parameters + ---------- + dry_run : bool, optional + If True, only check what would be downloaded without actually downloading. + Default is True. + *args + Arguments passed to product-specific download method. + **kwargs + Keyword arguments passed to product-specific download method. - # Initialize the table with headers - table_lines = [header_line, separator_line] + Returns + ------- + DownloadResult + Result with information about downloaded, skipped, and error files. - # Add rows for each data type - for data_type_name, data_type_info in self.data_types.items(): - name = data_type_info.get("name", "") - long_name = data_type_info.get("long_name", "") - description = data_type_info.get("description", "") - dataset = data_type_info.get("dataset", "") - row = f"| {name} | {long_name} | {description} | {dataset} |" - table_lines.append(row) + Raises + ------ + ValueError + If the product is not supported. + """ - # Print the table - print("\n".join(table_lines)) + if self.product == "NDBC": + return self.download_data_ndbc(dry_run=dry_run, *args, **kwargs) + else: + raise ValueError(f"Download for product {self.product} not supported") - def download_data( - self, data_type: str, load_df: bool = False, **kwargs - ) -> Union[pd.DataFrame, xr.Dataset, str]: + def download_data_ndbc( + self, data_type: str, dry_run: bool = True, **kwargs + ) -> DownloadResult: """ - Downloads the data for the specified data type. + Download data for the NDBC product. 
+ + Downloads NDBC buoy data or forecast data based on the specified data type. + Files are saved to: base_path_to_download/product/dataset/... Parameters ---------- data_type : str - The data type to download. - - 'bulk_parameters' - - 'wave_spectra' - - 'directional_spectra' - - 'wind_forecast' - - load_df : bool, optional - Whether to load and return the DataFrame after downloading. - Default is False. - If True and multiple years are specified, all years will be combined - into a single DataFrame. + The data type to download. Available types: + - 'bulk_parameters': Standard meteorological data + - 'wave_spectra': Wave spectral density data + - 'directional_spectra': Directional wave spectra coefficients + - 'wind_forecast': GFS wind forecast data + dry_run : bool, optional + If True, only check what would be downloaded without actually downloading. + Default is True. **kwargs - Additional keyword arguments specific to each data type. + Additional keyword arguments specific to each data type: + - For bulk_parameters, wave_spectra, directional_spectra: buoy_id, years, force + - For wind_forecast: date, region, force Returns ------- - Union[pd.DataFrame, xr.Dataset, str] - Downloaded data or status message. + DownloadResult + Result with information about downloaded, skipped, and error files. Raises ------ @@ -181,44 +501,31 @@ def download_data( ) data_type_config = self.data_types[data_type] - dataset_config = self.datasets[data_type_config["dataset"]] + dataset_config = self.product_config["datasets"][data_type_config["dataset"]] + + if dry_run: + self.logger.info(f"DRY RUN: Checking files for {data_type}") - result = None if data_type == "bulk_parameters": result = self._download_bulk_parameters( - data_type_config, dataset_config, **kwargs + data_type_config, dataset_config, dry_run=dry_run, **kwargs ) - if load_df: - buoy_id = kwargs.get("buoy_id") - years = kwargs.get("years", []) - if years: - result = self.read_bulk_parameters(buoy_id, years) elif data_type == "wave_spectra": result = self._download_wave_spectra( - data_type_config, dataset_config, **kwargs + data_type_config, dataset_config, dry_run=dry_run, **kwargs ) - if load_df: - buoy_id = kwargs.get("buoy_id") - years = kwargs.get("years", []) - if years: - result = self.read_wave_spectra(buoy_id, years) elif data_type == "directional_spectra": result = self._download_directional_spectra( - data_type_config, dataset_config, **kwargs + data_type_config, dataset_config, dry_run=dry_run, **kwargs ) - if load_df: - buoy_id = kwargs.get("buoy_id") - years = kwargs.get("years", []) - if years: - result = self.read_directional_spectra(buoy_id, years) elif data_type == "wind_forecast": result = self._download_wind_forecast( - data_type_config, dataset_config, **kwargs + data_type_config, dataset_config, dry_run=dry_run, **kwargs ) else: raise ValueError(f"Download for data type {data_type} not implemented") - return result + return self.finalize_download_result(result) def _download_bulk_parameters( self, @@ -226,8 +533,9 @@ def _download_bulk_parameters( dataset_config: dict, buoy_id: str, years: List[int], - **kwargs, - ) -> pd.DataFrame: + force: bool = False, + dry_run: bool = False, + ) -> DownloadResult: """ Download bulk parameters for a specific buoy and years. @@ -241,196 +549,193 @@ def _download_bulk_parameters( The buoy ID. years : List[int] The years to download data for. + force : bool, optional + Whether to force re-download even if file exists. Default is False. 
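# Hedged end-to-end sketch for NDBC bulk parameters, following the routing described
# above (download_data -> download_data_ndbc -> _download_bulk_parameters). Buoy id,
# years, and base path are illustrative; the first call previews, the second downloads.
from bluemath_tk.downloaders.noaa.noaa_downloader import NOAADownloader

nd = NOAADownloader(product="NDBC", base_path_to_download="./noaa_data", debug=False)

preview = nd.download_data(
    data_type="bulk_parameters", buoy_id="41001", years=[2020, 2021], dry_run=True
)
result = nd.download_data(
    data_type="bulk_parameters", buoy_id="41001", years=[2020, 2021], dry_run=False
)
print(result.downloaded_files)  # e.g. [".../41001/buoy_41001_bulk_parameters.csv"]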
+ dry_run : bool, optional + If True, only check what would be downloaded. Default is False. Returns ------- - pd.DataFrame - The downloaded data. + DownloadResult + Download result with information about downloaded, skipped, and error files. """ self.logger.info( f"Downloading bulk parameters for buoy {buoy_id}, years {years}" ) - all_data = [] + result = self.create_download_result() base_url = dataset_config["base_url"] + dataset_name = data_type_config["dataset"] - for year in years: - # Try main URL first, then fallbacks - urls = [ - f"{base_url}/{data_type_config['url_pattern'].format(buoy_id=buoy_id, year=year)}" - ] - for fallback in data_type_config.get("fallback_urls", []): - urls.append(f"{base_url}/{fallback.format(buoy_id=buoy_id, year=year)}") - - df = self._download_single_year_bulk( - urls, data_type_config["columns"], year + try: + # Determine output file path: base_path/product/dataset/buoy_id/filename.csv + buoy_dir = os.path.join( + self.base_path_to_download, self.product, dataset_name, buoy_id ) - if df is not None: - all_data.append(df) - self.logger.info(f"Buoy {buoy_id}: Data found for year {year}") - else: - self.logger.warning( - f"Buoy {buoy_id}: No data available for year {year}" + output_file = os.path.join(buoy_dir, f"buoy_{buoy_id}_bulk_parameters.csv") + + # Check if file exists + if self._check_file_exists(output_file, result, force, dry_run): + return self.finalize_download_result(result) + + # Prepare download tasks + download_tasks = [] + for year in years: + urls = [ + f"{base_url}/{data_type_config['url_pattern'].format(buoy_id=buoy_id, year=year)}" + ] + for fallback in data_type_config.get("fallback_urls", []): + urls.append( + f"{base_url}/{fallback.format(buoy_id=buoy_id, year=year)}" + ) + + download_tasks.append( + { + "urls": urls, + "columns": data_type_config["columns"], + "year": year, + "buoy_id": buoy_id, + } ) - if all_data: - # Combine all years - combined_df = pd.concat(all_data, ignore_index=True) - combined_df = combined_df.sort_values(["YYYY", "MM", "DD", "hh"]) + if dry_run: + # In dry run mode, just mark what would be downloaded + for task in download_tasks: + result.add_skipped( + output_file, + f"Would download year {task['year']} (dry run)", + ) + return self.finalize_download_result(result) - # Save to CSV if not in check mode - if not self.check: - buoy_dir = os.path.join( - self.base_path_to_download, "buoy_data", buoy_id - ) + # Execute downloads sequentially + all_data = [] + for task in download_tasks: + try: + df = self._download_single_year_bulk(task["urls"], task["columns"]) + if df is not None: + all_data.append(df) + self.logger.info( + f"Buoy {buoy_id}: Data found for year {task['year']}" + ) + else: + self.logger.warning( + f"Buoy {buoy_id}: No data available for year {task['year']}" + ) + result.add_error( + output_file, + Exception(f"No data available for year {task['year']}"), + ) + except Exception as e: + self.logger.error(f"Error downloading year {task['year']}: {e}") + result.add_error(output_file, e) + + if all_data: + # Combine all years + combined_df = pd.concat(all_data, ignore_index=True) + combined_df = combined_df.sort_values(["YYYY", "MM", "DD", "hh"]) + + # Save to CSV os.makedirs(buoy_dir, exist_ok=True) - output_file = os.path.join( - buoy_dir, f"buoy_{buoy_id}_bulk_parameters.csv" - ) combined_df.to_csv(output_file, index=False) self.logger.info(f"Data saved to {output_file}") + result.add_downloaded(output_file) + else: + self.logger.error(f"No data found for buoy {buoy_id}") + result.add_error( 
+ output_file, + Exception(f"No data found for buoy {buoy_id}"), + ) + except Exception as e: + result.add_error(output_file, e) + self.logger.error(f"Error processing data for buoy {buoy_id}: {e}") - return f"Data saved to {output_file}" - - return combined_df - else: - self.logger.error(f"No data found for buoy {buoy_id}") - return None + return self.finalize_download_result(result) def _download_single_year_bulk( - self, urls: List[str], columns: List[str], year: int + self, + urls: List[str], + columns: List[str], ) -> Optional[pd.DataFrame]: """ Download and parse bulk parameters for a single year. + Attempts to download from the primary URL, and if that fails, tries fallback URLs. + Handles different data formats (pre-2012 and post-2012) and validates dates. + Parameters ---------- urls : List[str] - The URLs to download the data from. + List of URLs to try downloading from (primary URL first, then fallbacks). columns : List[str] - The columns to read from the data. - year : int - The year to download data for. + List of column names for the DataFrame. Returns ------- Optional[pd.DataFrame] - The downloaded data. + DataFrame containing the downloaded and parsed data, or None if download fails. """ for url in urls: try: - response = requests.get(url) - if response.status_code == 200: - content = gzip.decompress(response.content).decode("utf-8") - - # Skip the header rows and read the data - data = [] - lines = content.split("\n")[2:] # Skip first two lines (headers) - - # Check format by looking at the first data line - first_line = next(line for line in lines if line.strip()) - cols = first_line.split() - - # Determine format based on number of columns and year format - has_minutes = len(cols) == 18 # Post-2012 format has 18 columns - - for line in lines: - if line.strip(): - parts = line.split() - if parts: - # Convert 2-digit year to 4 digits if needed - if int(parts[0]) < 100: - parts[0] = str(int(parts[0]) + 1900) - - # Add minutes column if it doesn't exist - if not has_minutes: - parts.insert(4, "00") - - data.append(" ".join(parts)) - - # Read the modified data - df = pd.read_csv( - io.StringIO("\n".join(data)), - sep=r"\s+", - names=columns, - ) + # Download the file + response = requests.get(url, timeout=30) + response.raise_for_status() + content = gzip.decompress(response.content).decode("utf-8") - # Validate dates - valid_dates = ( - (df["MM"] >= 1) - & (df["MM"] <= 12) - & (df["DD"] >= 1) - & (df["DD"] <= 31) - & (df["hh"] >= 0) - & (df["hh"] <= 23) - & (df["mm"] >= 0) - & (df["mm"] <= 59) - ) + # Skip the header rows and read the data + data = [] + lines = content.split("\n")[2:] # Skip first two lines (headers) - df = df[valid_dates].copy() + # Check format by looking at the first data line + first_line = next(line for line in lines if line.strip()) + cols = first_line.split() - if len(df) > 0: - return df + # Determine format based on number of columns and year format + has_minutes = len(cols) == 18 # Post-2012 format has 18 columns - except Exception as e: - self.logger.debug(f"Failed to download from {url}: {e}") - continue + for line in lines: + if line.strip(): + parts = line.split() + if parts: + # Convert 2-digit year to 4 digits if needed + if int(parts[0]) < 100: + parts[0] = str(int(parts[0]) + 1900) - return None + # Add minutes column if it doesn't exist + if not has_minutes: + parts.insert(4, "00") - def read_bulk_parameters( - self, buoy_id: str, years: Union[int, List[int]] - ) -> Optional[pd.DataFrame]: - """ - Read bulk parameters for a specific buoy 
and year(s). + data.append(" ".join(parts)) - Parameters - ---------- - buoy_id : str - The buoy ID. - years : Union[int, List[int]] - The year(s) to read data for. Can be a single year or a list of years. + # Read the modified data + df = pd.read_csv( + io.StringIO("\n".join(data)), + sep=r"\s+", + names=columns, + ) - Returns - ------- - Optional[pd.DataFrame] - DataFrame containing the bulk parameters, or None if data not found. - """ + # Validate dates + valid_dates = ( + (df["MM"] >= 1) + & (df["MM"] <= 12) + & (df["DD"] >= 1) + & (df["DD"] <= 31) + & (df["hh"] >= 0) + & (df["hh"] <= 23) + & (df["mm"] >= 0) + & (df["mm"] <= 59) + ) - if isinstance(years, int): - years = [years] + df = df[valid_dates].copy() - all_data = [] - for year in years: - file_path = os.path.join( - self.base_path_to_download, - "buoy_data", - buoy_id, - f"buoy_{buoy_id}_bulk_parameters.csv", - ) - try: - df = pd.read_csv(file_path) - df["datetime"] = pd.to_datetime( - df["YYYY"].astype(str) - + "-" - + df["MM"].astype(str).str.zfill(2) - + "-" - + df["DD"].astype(str).str.zfill(2) - + " " - + df["hh"].astype(str).str.zfill(2) - + ":" - + df["mm"].astype(str).str.zfill(2) - ) - all_data.append(df) - except FileNotFoundError: - self.logger.error( - f"No bulk parameters file found for buoy {buoy_id} year {year}" - ) + if len(df) > 0: + return df + + except Exception as e: + self.logger.debug(f"Failed to download from {url}: {e}") + continue - if all_data: - return pd.concat(all_data, ignore_index=True).sort_values("datetime") return None def _download_wave_spectra( @@ -439,47 +744,72 @@ def _download_wave_spectra( dataset_config: dict, buoy_id: str, years: List[int], - **kwargs, - ) -> str: + force: bool = False, + dry_run: bool = False, + ) -> DownloadResult: """ Download wave spectra data for a specific buoy. + Downloads wave spectral density data for each specified year. Files are saved to: + base_path_to_download/product/dataset/buoy_id/wave_spectra/buoy_{buoy_id}_spectra_{year}.csv + Parameters ---------- data_type_config : dict - The configuration for the data type. + Configuration for the data type. dataset_config : dict - The configuration for the dataset. + Configuration for the dataset. buoy_id : str The buoy ID. years : List[int] - The years to download data for. + List of years to download data for. + force : bool, optional + Force re-download even if file exists. Default is False. + dry_run : bool, optional + If True, only check what would be downloaded. Default is False. Returns ------- - str - The status message. + DownloadResult + Result with information about downloaded, skipped, and error files. 
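# Hedged sketch for the wave-spectra path documented above: data_type="wave_spectra"
# routes through download_data_ndbc to _download_wave_spectra, which writes one CSV
# per year under <base_path>/<product>/<dataset>/<buoy_id>/wave_spectra/. Buoy id and
# year are illustrative; the dataset folder name comes from the NDBC config.
from bluemath_tk.downloaders.noaa.noaa_downloader import NOAADownloader

nd = NOAADownloader(product="NDBC", base_path_to_download="./noaa_data", debug=False)
spectra_result = nd.download_data(
    data_type="wave_spectra", buoy_id="41001", years=[2020], dry_run=True
)
print(spectra_result.skipped_files)  # paths that would be written on a real run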
""" self.logger.info(f"Downloading wave spectra for buoy {buoy_id}, years {years}") + result = self.create_download_result() base_url = dataset_config["base_url"] + dataset_name = data_type_config["dataset"] buoy_dir = os.path.join( - self.base_path_to_download, "buoy_data", buoy_id, "wave_spectra" + self.base_path_to_download, + self.product, + dataset_name, + buoy_id, + "wave_spectra", ) - if not self.check: + if not dry_run: os.makedirs(buoy_dir, exist_ok=True) - downloaded_files = [] - for year in years: url = f"{base_url}/{data_type_config['url_pattern'].format(buoy_id=buoy_id, year=year)}" + output_file = os.path.join(buoy_dir, f"buoy_{buoy_id}_spectra_{year}.csv") + + # Check if file exists + if self._check_file_exists(output_file, result, force, dry_run): + continue + + if dry_run: + result.add_skipped(output_file, f"Would download year {year} (dry run)") + continue try: + # Download and read the data + response = requests.get(url, timeout=30) + response.raise_for_status() + # Read the data df = pd.read_csv( - url, + io.BytesIO(response.content), compression="gzip", sep=r"\s+", na_values=["MM", "99.00", "999.0"], @@ -488,91 +818,24 @@ def _download_wave_spectra( # Skip if empty or invalid data if df.empty or len(df.columns) < 5: self.logger.warning(f"No valid data for {buoy_id} - {year}") + result.add_error( + output_file, + Exception(f"No valid data for {buoy_id} - {year}"), + context={"year": year}, + ) continue - # Process datetime (simplified version) - if not self.check: - output_file = os.path.join( - buoy_dir, f"buoy_{buoy_id}_spectra_{year}.csv" - ) - df.to_csv(output_file, index=False) - downloaded_files.append(output_file) - self.logger.info(f"Successfully saved data for {buoy_id} - {year}") + # Save the data + df.to_csv(output_file, index=False) + result.add_downloaded(output_file) + self.logger.info(f"Successfully saved data for {buoy_id} - {year}") except Exception as e: self.logger.warning(f"No data found for: {buoy_id} - {year}: {e}") + result.add_error(output_file, e, context={"year": year}) continue - return f"Downloaded {len(downloaded_files)} files for wave spectra" - - def read_wave_spectra( - self, buoy_id: str, years: Union[int, List[int]] - ) -> Optional[pd.DataFrame]: - """ - Read wave spectra data for a specific buoy and year(s). - - Parameters - ---------- - buoy_id : str - The buoy ID. - years : Union[int, List[int]] - The year(s) to read data for. Can be a single year or a list of years. 
- - Returns - ------- - Optional[pd.DataFrame] - DataFrame containing the wave spectra, or None if data not found - """ - - if isinstance(years, int): - years = [years] - - all_data = [] - for year in years: - file_path = os.path.join( - self.base_path_to_download, - "buoy_data", - buoy_id, - "wave_spectra", - f"buoy_{buoy_id}_spectra_{year}.csv", - ) - try: - df = pd.read_csv(file_path) - try: - df["date"] = pd.to_datetime( - df[["YYYY", "MM", "DD", "hh"]].rename( - columns={ - "YYYY": "year", - "MM": "month", - "DD": "day", - "hh": "hour", - } - ) - ) - df.drop(columns=["YYYY", "MM", "DD", "hh"], inplace=True) - except Exception as _e: - df["date"] = pd.to_datetime( - df[["#YY", "MM", "DD", "hh", "mm"]].rename( - columns={ - "#YY": "year", - "MM": "month", - "DD": "day", - "hh": "hour", - "mm": "minute", - } - ) - ) - df.drop(columns=["#YY", "MM", "DD", "hh", "mm"], inplace=True) - df.set_index("date", inplace=True) - all_data.append(df) - except FileNotFoundError: - self.logger.error( - f"No wave spectra file found for buoy {buoy_id} year {year}" - ) - - if all_data: - return pd.concat(all_data).sort_index() - return None + return result def _download_directional_spectra( self, @@ -580,219 +843,93 @@ def _download_directional_spectra( dataset_config: dict, buoy_id: str, years: List[int], - **kwargs, - ) -> str: + force: bool = False, + dry_run: bool = False, + ) -> DownloadResult: """ Download directional wave spectra coefficients. + Downloads Fourier coefficients (alpha1, alpha2, r1, r2, c11) for directional wave spectra. + Files are saved to: + base_path_to_download/product/dataset/buoy_id/directional_spectra/{buoy_id}{coef}{year}.txt.gz + Parameters ---------- data_type_config : dict - The configuration for the data type. + Configuration for the data type. dataset_config : dict - The configuration for the dataset. + Configuration for the dataset. buoy_id : str The buoy ID. years : List[int] - The years to download data for. + List of years to download data for. + force : bool, optional + Force re-download even if file exists. Default is False. + dry_run : bool, optional + If True, only check what would be downloaded. Default is False. Returns ------- - str - The status message. + DownloadResult + Result with information about downloaded, skipped, and error files. """ self.logger.info( f"Downloading directional spectra for buoy {buoy_id}, years {years}" ) + result = self.create_download_result() base_url = dataset_config["base_url"] coefficients = data_type_config["coefficients"] + dataset_name = data_type_config["dataset"] buoy_dir = os.path.join( - self.base_path_to_download, "buoy_data", buoy_id, "directional_spectra" + self.base_path_to_download, + self.product, + dataset_name, + buoy_id, + "directional_spectra", ) - if not self.check: + if not dry_run: os.makedirs(buoy_dir, exist_ok=True) - downloaded_files = [] - for year in years: for coef, info in coefficients.items(): filename = f"{buoy_id}{coef}{year}.txt.gz" url = f"{base_url}/{info['url_pattern'].format(buoy_id=buoy_id, year=year)}" + save_path = os.path.join(buoy_dir, filename) - if not self.check: - save_path = os.path.join(buoy_dir, filename) - - try: - self.logger.debug( - f"Downloading {info['name']} data for {year}..." 
-                        )
-                        response = requests.get(url, stream=True)
-                        response.raise_for_status()
-
-                        # Save the compressed file
-                        with open(save_path, "wb") as f:
-                            shutil.copyfileobj(response.raw, f)
-
-                        downloaded_files.append(save_path)
-                        self.logger.info(f"Successfully downloaded {filename}")
-
-                    except requests.exceptions.RequestException as e:
-                        self.logger.warning(f"Error downloading {filename}: {e}")
-                        continue
-
-        return f"Downloaded {len(downloaded_files)} coefficient files"
-
-    def read_directional_spectra(
-        self, buoy_id: str, years: Union[int, List[int]]
-    ) -> Tuple[Optional[pd.DataFrame], ...]:
-        """
-        Read directional spectra data for a specific buoy and year(s).
-
-        Parameters
-        ----------
-        buoy_id : str
-            The buoy ID
-        years : Union[int, List[int]]
-            The year(s) to read data for. Can be a single year or a list of years.
-
-        Returns
-        -------
-        Tuple[Optional[pd.DataFrame], ...]
-            Tuple containing DataFrames for alpha1, alpha2, r1, r2, and c11,
-            or None for each if data not found
-        """
-
-        if isinstance(years, int):
-            years = [years]
-
-        results = {
-            "alpha1": [],
-            "alpha2": [],
-            "r1": [],
-            "r2": [],
-            "c11": [],
-        }
+                # Check if file exists
+                if self._check_file_exists(save_path, result, force, dry_run):
+                    continue

-        for year in years:
-            dir_path = os.path.join(
-                self.base_path_to_download,
-                "buoy_data",
-                buoy_id,
-                "directional_spectra",
-            )
-            files = {
-                "alpha1": f"{buoy_id}d{year}.txt.gz",
-                "alpha2": f"{buoy_id}i{year}.txt.gz",
-                "r1": f"{buoy_id}j{year}.txt.gz",
-                "r2": f"{buoy_id}k{year}.txt.gz",
-                "c11": f"{buoy_id}w{year}.txt.gz",
-            }
-
-            for name, filename in files.items():
-                file_path = os.path.join(dir_path, filename)
-                try:
-                    df = self._read_directional_file(file_path)
-                    if df is not None:
-                        results[name].append(df)
-                except FileNotFoundError:
-                    self.logger.error(
-                        f"No {name} file found for buoy {buoy_id} year {year}"
+                if dry_run:
+                    result.add_skipped(
+                        save_path,
+                        f"Would download {info['name']} for year {year} (dry run)",
                     )
+                    continue

-        # Combine DataFrames for each coefficient if available
-        final_results = {}
-        for name, dfs in results.items():
-            if dfs:
-                final_results[name] = pd.concat(dfs).sort_index()
-            else:
-                final_results[name] = None
-
-        return (
-            final_results["alpha1"],
-            final_results["alpha2"],
-            final_results["r1"],
-            final_results["r2"],
-            final_results["c11"],
-        )
-
-    def _read_directional_file(self, file_path: Path) -> Optional[pd.DataFrame]:
-        """
-        Read a directional spectra file and return DataFrame with datetime index.
+                try:
+                    self.logger.debug(f"Downloading {info['name']} data for {year}...")

-        Parameters
-        ----------
-        file_path : Path
-            Path to the file to read
+                    # Download the file
+                    response = requests.get(url, stream=True, timeout=30)
+                    response.raise_for_status()

-        Returns
-        -------
-        Optional[pd.DataFrame]
-            DataFrame containing the directional spectra data, or None if data not found
-        """
+                    # Save the compressed file
+                    with open(save_path, "wb") as f:
+                        shutil.copyfileobj(response.raw, f)

-        self.logger.debug(f"Reading file: {file_path}")
-        try:
-            with gzip.open(file_path, "rt") as f:
-                # Read header lines until we find the frequencies
-                header_lines = []
-                while True:
-                    line = f.readline().strip()
-                    if not line.startswith("#") and not line.startswith("YYYY"):
-                        break
-                    header_lines.append(line)
-
-                # Parse frequencies
-                header = " ".join(header_lines)
-                try:
-                    freqs = [float(x) for x in header.split()[5:]]
-                    self.logger.debug(f"Found {len(freqs)} frequencies")
-                except (ValueError, IndexError) as e:
-                    self.logger.error(f"Error parsing frequencies: {e}")
-                    return None
+                    result.add_downloaded(save_path)
+                    self.logger.info(f"Successfully downloaded {filename}")

-                # Read data
-                data = []
-                dates = []
-                # Process the first line
-                parts = line.strip().split()
-                if len(parts) >= 5:
-                    try:
-                        year, month, day, hour, minute = map(int, parts[:5])
-                        values = [float(x) for x in parts[5:]]
-                        if len(values) == len(freqs):
-                            dates.append(datetime(year, month, day, hour, minute))
-                            data.append(values)
-                    except (ValueError, IndexError) as e:
-                        self.logger.error(f"Error parsing line: {e}")
-
-                # Read remaining lines
-                for line in f:
-                    parts = line.strip().split()
-                    if len(parts) >= 5:
-                        try:
-                            year, month, day, hour, minute = map(int, parts[:5])
-                            values = [float(x) for x in parts[5:]]
-                            if len(values) == len(freqs):
-                                dates.append(datetime(year, month, day, hour, minute))
-                                data.append(values)
-                        except (ValueError, IndexError) as e:
-                            self.logger.error(f"Error parsing line: {e}")
-                            continue
-
-                if not data:
-                    self.logger.warning("No valid data points found in file")
-                    return None
-
-                df = pd.DataFrame(data, index=dates, columns=freqs)
-                self.logger.debug(f"Created DataFrame with shape: {df.shape}")
-                return df
+                except Exception as e:
+                    self.logger.warning(f"Error downloading {filename}: {e}")
+                    result.add_error(save_path, e)
+                    continue

-        except Exception as e:
-            self.logger.error(f"Error reading file {file_path}: {str(e)}")
-            return None
+        return self.finalize_download_result(result)

     def _download_wind_forecast(
         self,
@@ -800,28 +937,40 @@ def _download_wind_forecast(
         dataset_config: dict,
         date: str = None,
         region: List[float] = None,
-        **kwargs,
-    ) -> xr.Dataset:
+        force: bool = False,
+        dry_run: bool = False,
+    ) -> DownloadResult:
         """
         Download NOAA GFS wind forecast data.

+        Downloads and crops GFS wind forecast data for a specific date and region.
+        Files are saved to:
+        base_path_to_download/product/dataset/{date}_{region}.nc
+
         Parameters
         ----------
         data_type_config : dict
-            The configuration for the data type.
+            Configuration for the data type.
         dataset_config : dict
-            The configuration for the dataset.
+            Configuration for the dataset.
         date : str, optional
-            The date to download data for.
+            Date to download data for (format: "YYYYMMDD"). If None, uses today's date.
+            Default is None.
+        region : List[float], optional
+            Geographic region coordinates. Default is None.
+        force : bool, optional
+            Force re-download even if file exists. Default is False.
+        dry_run : bool, optional
+            If True, only check what would be downloaded. Default is False.

         Returns
         -------
-        xr.Dataset
-            The downloaded data.
+        DownloadResult
+            Result with information about downloaded, skipped, and error files.

         Notes
         -----
-        - This will be DEPRECATED in the future.
+        This method will be DEPRECATED in the future.
         """

         if date is None:
@@ -829,13 +978,17 @@

         self.logger.info(f"Downloading wind forecast for date {date}")
+        result = self.create_download_result()

         url_base = dataset_config["base_url"]
+        dataset_name = data_type_config["dataset"]
         dbn = "gfs_0p25_1hr"
         url = f"{url_base}/gfs{date}/{dbn}_00z"

-        # File path for local storage
-        forecast_dir = os.path.join(self.base_path_to_download, "wind_forecast")
-        if not self.check:
+        # File path for local storage: base_path/product/dataset/filename.nc
+        forecast_dir = os.path.join(
+            self.base_path_to_download, self.product, dataset_name
+        )
+        if not dry_run:
             os.makedirs(forecast_dir, exist_ok=True)

         file_path = os.path.join(
@@ -843,18 +996,18 @@
         )

         # Check if file exists
-        if os.path.isfile(file_path):
-            self.logger.info(
-                f"File already exists: {file_path}. Loading from local storage."
+        if self._check_file_exists(file_path, result, force, dry_run):
+            return result
+
+        if dry_run:
+            result.add_skipped(
+                file_path, f"Would download wind forecast for {date} (dry run)"
             )
-            data = xr.open_dataset(file_path)
-        else:
-            if self.check:
-                self.logger.info(f"File would be downloaded to: {file_path}")
-                return None
+            return result

+        try:
             self.logger.info(f"Downloading and cropping forecast data from: {url}")
-            # Crop dataset
+            # Open dataset from URL
             data = xr.open_dataset(url)

             # Select only wind data
@@ -863,27 +1016,10 @@
             self.logger.info(f"Storing local copy at: {file_path}")
             data_select.to_netcdf(file_path)
-            data = data_select
-
-            # Create output dataset with renamed variables
-            output_vars = data_type_config["output_variables"]
-            wind_data_forecast = xr.Dataset(
-                {
-                    output_vars["u10"]: (
-                        ("time", "lat", "lon"),
-                        data[data_type_config["variables"][0]].values,
-                    ),
-                    output_vars["v10"]: (
-                        ("time", "lat", "lon"),
-                        data[data_type_config["variables"][1]].values,
-                    ),
-                },
-                coords={
-                    "time": data.time.values,
-                    "lat": data.lat.values,
-                    "lon": data.lon.values,
-                },
-            )
-            wind_data_forecast["time"] = wind_data_forecast.time.dt.round("min")
+            result.add_downloaded(file_path)
+
+        except Exception as e:
+            self.logger.error(f"Error downloading wind forecast: {e}")
+            result.add_error(file_path, e)

-        return wind_data_forecast
+        return self.finalize_download_result(result)
diff --git a/tests/downloaders/test_copernicus_downloader.py b/tests/downloaders/test_copernicus_downloader.py
deleted file mode 100644
index a33d648..0000000
--- a/tests/downloaders/test_copernicus_downloader.py
+++ /dev/null
@@ -1,49 +0,0 @@
-import tempfile
-import unittest
-
-from bluemath_tk.downloaders.copernicus.copernicus_downloader import (
-    CopernicusDownloader,
-)
-
-
-class TestCopernicusDownloader(unittest.TestCase):
-    def setUp(self):
-        self.temp_dir = tempfile.mkdtemp()
-        self.downloader = CopernicusDownloader(
-            product="ERA5",
-            base_path_to_download=self.temp_dir,
-            token=None,
-            check=True,  # Just check paths to download
-        )
-
-    def test_download_data_era5(self):
-        result = self.downloader.download_data_era5(
-            variables=["spectra"],
-            years=[f"{year:04d}" for year in range(2020, 2025)],
-            months=[
-                "01",
-                "02",
-                "03",
-                "04",
-                "05",
-                "06",
-                "07",
"08", - "09", - "10", - "11", - "12", - ], - area=[43.4, 350.4, 43.6, 350.6], # [lat_min, lon_min, lat_max, lon_max] - ) - print(result) - - -if __name__ == "__main__": - unittest.main() - - -# mean_wave_period_based_on_first_moment/ -# wave_spectral_directional_width/ -# wave_spectral_directional_width_for_swell/ -# wave_spectral_directional_width_for_wind_waves/ diff --git a/tests/downloaders/test_ecmwf_downloader.py b/tests/downloaders/test_ecmwf_downloader.py deleted file mode 100644 index fbd0225..0000000 --- a/tests/downloaders/test_ecmwf_downloader.py +++ /dev/null @@ -1,35 +0,0 @@ -import tempfile -import unittest - -from bluemath_tk.downloaders.ecmwf.ecmwf_downloader import ECMWFDownloader - - -class TestECMWFDownloader(unittest.TestCase): - def setUp(self): - self.temp_dir = tempfile.mkdtemp() - self.downloader = ECMWFDownloader( - product="OpenData", - base_path_to_download=self.temp_dir, - check=True, # Just check paths to download, do not actually download - ) - - def test_list_datasets(self): - datasets = self.downloader.list_datasets() - self.assertIsInstance(datasets, list) - self.assertTrue(len(datasets) > 0) - print(f"Available datasets: {datasets}") - - def test_download_data(self): - dataset = self.downloader.download_data( - load_data=False, - param=["msl"], - step=[0, 240], - type="fc", - force=False, - ) - self.assertIsInstance(dataset, str) - print(dataset) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/downloaders/test_noaa_downloader.py b/tests/downloaders/test_noaa_downloader.py deleted file mode 100644 index baf4b20..0000000 --- a/tests/downloaders/test_noaa_downloader.py +++ /dev/null @@ -1,221 +0,0 @@ -import os.path as op -import tempfile -import unittest -from pathlib import Path - -import pandas as pd - -from bluemath_tk.downloaders.noaa.noaa_downloader import NOAADownloader - - -class TestNOAADownloader(unittest.TestCase): - def setUp(self): - """Set up test fixtures before each test method.""" - self.temp_dir = tempfile.mkdtemp() - self.downloader = NOAADownloader( - base_path_to_download=self.temp_dir, - debug=True, - check=False, # Just check paths to download - ) - - def test_download_bulk_parameters(self): - """Test downloading bulk parameters.""" - - # Test without loading DataFrame - result = self.downloader.download_data( - data_type="bulk_parameters", - buoy_id="41001", - years=[2023], - ) - self.assertIsNotNone(result) - self.assertIsInstance(result, str) - print(f"\nBulk parameters download result: {result}") - - # Test with loading DataFrame - df = self.downloader.download_data( - data_type="bulk_parameters", - buoy_id="41001", - years=[2023], - load_df=True, - ) - self.assertIsNotNone(df) - self.assertIsInstance(df, pd.DataFrame) - self.assertTrue("datetime" in df.columns) - self.assertTrue(len(df) > 0) - print(f"\nBulk parameters DataFrame shape: {df.shape}") - - def test_download_wave_spectra(self): - """Test downloading wave spectra.""" - - # Test without loading DataFrame - result = self.downloader.download_data( - data_type="wave_spectra", - buoy_id="41001", - years=[2023], - ) - self.assertIsNotNone(result) - self.assertIsInstance(result, str) - print(f"\nWave spectra download result: {result}") - - # Test with loading DataFrame - df = self.downloader.download_data( - data_type="wave_spectra", - buoy_id="41001", - years=[2023], - load_df=True, - ) - self.assertIsNotNone(df) - self.assertIsInstance(df, pd.DataFrame) - self.assertTrue(isinstance(df.index, pd.DatetimeIndex)) - self.assertTrue(len(df) > 0) - 
print(f"\nWave spectra DataFrame shape: {df.shape}") - - def test_download_directional_spectra(self): - """Test downloading directional spectra.""" - - # Test without loading DataFrame - result = self.downloader.download_data( - data_type="directional_spectra", - buoy_id="41001", - years=[2023], - ) - self.assertIsNotNone(result) - self.assertIsInstance(result, str) - print(f"\nDirectional spectra download result: {result}") - - # Test with loading DataFrame - alpha1, alpha2, r1, r2, c11 = self.downloader.download_data( - data_type="directional_spectra", - buoy_id="41001", - years=[2023], - load_df=True, - ) - # Check each coefficient DataFrame - for name, df in [ - ("alpha1", alpha1), - ("alpha2", alpha2), - ("r1", r1), - ("r2", r2), - ("c11", c11), - ]: - if df is not None: - self.assertIsInstance(df, pd.DataFrame) - self.assertTrue(isinstance(df.index, pd.DatetimeIndex)) - self.assertTrue(len(df) > 0) - print(f"\n{name} DataFrame shape: {df.shape}") - - def test_multiple_years_loading(self): - """Test loading multiple years of data.""" - - # Test bulk parameters with multiple years - df = self.downloader.download_data( - data_type="bulk_parameters", - buoy_id="41001", - years=[2022, 2023], - load_df=True, - ) - self.assertIsNotNone(df) - self.assertIsInstance(df, pd.DataFrame) - self.assertTrue("datetime" in df.columns) - self.assertTrue(len(df) > 0) - - # Check that data spans multiple years - years = df["datetime"].dt.year.unique() - self.assertTrue(len(years) > 1) - print(f"\nBulk parameters multiple years: {sorted(years)}") - - # Test wave spectra with multiple years - df = self.downloader.download_data( - data_type="wave_spectra", - buoy_id="41001", - years=[2022, 2023], - load_df=True, - ) - self.assertIsNotNone(df) - self.assertIsInstance(df, pd.DataFrame) - self.assertTrue(isinstance(df.index, pd.DatetimeIndex)) - self.assertTrue(len(df) > 0) - - # Check that data spans multiple years - years = df.index.year.unique() - self.assertTrue(len(years) > 1) - print(f"\nWave spectra multiple years: {sorted(years)}") - - def test_list_data_types(self): - """Test listing available data types.""" - - data_types = self.downloader.list_data_types() - self.assertIsInstance(data_types, list) - self.assertTrue(len(data_types) > 0) - print(f"\nAvailable data types: {data_types}") - - def test_list_datasets(self): - """Test listing available datasets.""" - - datasets = self.downloader.list_datasets() - self.assertIsInstance(datasets, list) - self.assertTrue(len(datasets) > 0) - print(f"\nAvailable datasets: {datasets}") - - def test_show_markdown_table(self): - """Test showing markdown table.""" - - self.downloader.show_markdown_table() - - def test_file_paths(self): - """Test that downloaded files exist in the correct locations.""" - - # Download data - self.downloader.download_data( - data_type="bulk_parameters", - buoy_id="41001", - years=[2023], - ) - - # Check bulk parameters file - bulk_file = op.join( - self.temp_dir, - "buoy_data", - "41001", - "buoy_41001_bulk_parameters.csv", - ) - self.assertTrue(op.exists(bulk_file)) - print(f"\nBulk parameters file exists: {bulk_file}") - - # Download and check wave spectra - self.downloader.download_data( - data_type="wave_spectra", - buoy_id="41001", - years=[2023], - ) - wave_file = op.join( - self.temp_dir, - "buoy_data", - "41001", - "wave_spectra", - "buoy_41001_spectra_2023.csv", - ) - self.assertTrue(op.exists(wave_file)) - print(f"\nWave spectra file exists: {wave_file}") - - # Download and check directional spectra - 
-        self.downloader.download_data(
-            data_type="directional_spectra",
-            buoy_id="41001",
-            years=[2023],
-        )
-        dir_path = op.join(
-            self.temp_dir,
-            "buoy_data",
-            "41001",
-            "directional_spectra",
-        )
-        self.assertTrue(op.exists(dir_path))
-        # Check for at least one coefficient file
-        coeff_files = list(Path(dir_path).glob("41001*2023.txt.gz"))
-        self.assertTrue(len(coeff_files) > 0)
-        print(f"\nDirectional spectra files exist: {coeff_files}")
-
-
-if __name__ == "__main__":
-    unittest.main()
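
Since the diff removes the downloader unit tests and switches the private _download_* helpers from returning status strings to returning DownloadResult objects, a minimal usage sketch of the new dry-run workflow is included below for reference. It assumes that NOAADownloader.download_data forwards force/dry_run to the private helpers and that DownloadResult exposes downloaded, skipped, and errors collections; neither detail is confirmed by this diff, so treat those names as placeholders rather than the library's actual API.

# Hedged sketch, not part of the diff: previews what a download would do without writing files.
# Assumptions (not confirmed above): download_data() accepts a dry_run keyword and
# DownloadResult exposes .downloaded, .skipped and .errors collections.
import tempfile

from bluemath_tk.downloaders.noaa.noaa_downloader import NOAADownloader


def preview_wave_spectra_download() -> None:
    temp_dir = tempfile.mkdtemp()
    downloader = NOAADownloader(base_path_to_download=temp_dir, debug=True)

    # Dry run: nothing is written; candidate files are reported via the result object.
    result = downloader.download_data(
        data_type="wave_spectra",
        buoy_id="41001",
        years=[2023],
        dry_run=True,  # assumed keyword, mirrors the private _download_wave_spectra signature
    )

    print(f"downloaded: {len(result.downloaded)}")  # assumed attribute
    print(f"skipped:    {len(result.skipped)}")     # assumed attribute
    print(f"errors:     {len(result.errors)}")      # assumed attribute


if __name__ == "__main__":
    preview_wave_spectra_download()

Per the new docstrings, wave spectra files are written under base_path_to_download/product/dataset/buoy_id/wave_spectra/, so a subsequent real run (dry_run=False) with force=False should skip anything already present.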