diff --git a/metrics/MSE.py b/metrics/MSE.py new file mode 100644 index 0000000..5d09754 --- /dev/null +++ b/metrics/MSE.py @@ -0,0 +1,11 @@ +import numpy as np +import random +import logging + +def mse(X, obs): + ''' + Retourne la MSE + ''' + random_idx = random.sample(range(X.shape[0]), 1)[0] + _mse = np.nanmean(((X[random_idx] - obs.squeeze())**2),axis=(-2,-1)) + return _mse \ No newline at end of file diff --git a/metrics/__init__.py b/metrics/__init__.py index 1061f81..50e8ff8 100644 --- a/metrics/__init__.py +++ b/metrics/__init__.py @@ -19,6 +19,7 @@ from metrics import CRPS_calc from metrics import area_proportion as ap from metrics import object_detection as obj +from metrics import MSE from metrics.metrics import Metric, PreprocessCondObs, PreprocessDist, PreprocessStandalone @@ -474,6 +475,25 @@ def _calculateCore(self, processed_data): return GM.relative_std_diff(real_data,fake_data) + +##################################################################### +############################ Determinstic metrics ################### +##################################################################### + +class mse(PreprocessCondObs): + def __init__(self, *args, **kwargs): + super().__init__(isBatched=True) + + def _calculateCore(self, processed_data): + if not self.isOnReal: + exp_data = processed_data['fake_data'] + else: + exp_data = processed_data['real_data'] + obs_data = processed_data['obs_data'] + return MSE.mse(exp_data, obs_data) + + + ##################################################################### ###################################################### ##################################################################### \ No newline at end of file diff --git a/metrics/wind_comp.py b/metrics/wind_comp.py index 73e9c40..1bf6a10 100755 --- a/metrics/wind_comp.py +++ b/metrics/wind_comp.py @@ -113,7 +113,8 @@ def computeWindDir(U, V, xRef=None, yRef=None, proj=None): @rtype: tuple @return: vitesse (m/s) et direction du vent exprimée en degrés météo (0 = vent du nord). """ - ff = np.sqrt(U * U + V * V) + + ff = np.sqrt(np.maximum(U * U + V * V, np.zeros_like(U))) dd3 = (180 + 180 / np.pi * np.arctan2(U, V)) % 360 # logging.debug(dd3) diff --git a/preprocess/rrPreprocessor.py b/preprocess/rrPreprocessor.py index 67f7266..2562404 100644 --- a/preprocess/rrPreprocessor.py +++ b/preprocess/rrPreprocessor.py @@ -38,6 +38,7 @@ def init_normalization(self): normalization_type = self.normalization["type"] if normalization_type == "mean": means, stds = self.load_stat_files(normalization_type, "mean", "std") + logging.debug(f"stat constants {means, stds}") return None, None, means, stds elif normalization_type == "minmax": maxs, mins = self.load_stat_files(normalization_type, "max", "min") @@ -64,16 +65,16 @@ def load_stat_files(self, normalization_type, str1, str2): std_or_min_filename += "_ppx" mean_or_max_filename += ".npy" std_or_min_filename += ".npy" - logging.debug(f"{mean_or_max_filename}", f"{std_or_min_filename}") - logging.debug(f"Normalization set to {normalization_type}") + # logging.debug(f"{mean_or_max_filename}", f"{std_or_min_filename}") + # logging.debug(f"Normalization set to {normalization_type}") stat_folder = self.config_data["stat_folder"] file_path = os.path.join(self.config_data["real_data_dir"], stat_folder, mean_or_max_filename) means_or_maxs = np.load(file_path).astype('float32') - logging.debug(f"{str1} file found, {means_or_maxs.shape}") + # logging.debug(f"{str1} file found, {means_or_maxs.shape}") file_path = os.path.join(self.config_data["real_data_dir"], stat_folder, std_or_min_filename) stds_or_mins = np.load(file_path).astype('float32') - logging.debug(f"{str2} file found, {stds_or_mins.shape}") + # logging.debug(f"{str2} file found, {stds_or_mins.shape}") return means_or_maxs, stds_or_mins def detransform(self, data): @@ -178,4 +179,5 @@ def __init__(self, config_data, sizeH, sizeW, variables, **kwargs): super().__init__(config_data, sizeH, sizeW, variables, **kwargs) def process_batch(self, batch): - return self.detransform(batch) + res = self.detransform(batch) + return res diff --git a/requirements.txt b/requirements.txt index 191b106..80ff415 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ PyYaml==6.0.1 -torch==2.2.0+cu121 +torch==2.2.0 numpy==1.26.4 tqdm==4.66.1 pandas==2.2.0 @@ -9,4 +9,5 @@ pyproj==3.6.1 properscoring==0.1 geopy==2.4.1 astropy==6.0.0 +scikit-image CRPS