From 5a4ae778edfc076730d93d06fea49a68d0d4dce0 Mon Sep 17 00:00:00 2001 From: seznecc Date: Thu, 13 Jun 2024 17:30:11 +0200 Subject: [PATCH 1/4] ajout mse metrique --- metrics/__init__.py | 21 +++++++++++++++++++++ metrics/wind_comp.py | 3 ++- preprocess/rrPreprocessor.py | 12 +++++++----- 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/metrics/__init__.py b/metrics/__init__.py index 1061f81..26c2e42 100644 --- a/metrics/__init__.py +++ b/metrics/__init__.py @@ -19,6 +19,7 @@ from metrics import CRPS_calc from metrics import area_proportion as ap from metrics import object_detection as obj +from metrics import MSE from metrics.metrics import Metric, PreprocessCondObs, PreprocessDist, PreprocessStandalone @@ -474,6 +475,26 @@ def _calculateCore(self, processed_data): return GM.relative_std_diff(real_data,fake_data) + +##################################################################### +############################ Determinstic metrics ################### +##################################################################### + +class mse(PreprocessCondObs): + def __init__(self, *args, **kwargs): + super().__init__(isBatched=True) + + def _calculateCore(self, processed_data): + if not self.isOnReal: + exp_data = processed_data['fake_data'] + else: + exp_data = processed_data['real_data'] + obs_data = processed_data['obs_data'] + + return MSE.mse(exp_data, obs_data) + + + ##################################################################### ###################################################### ##################################################################### \ No newline at end of file diff --git a/metrics/wind_comp.py b/metrics/wind_comp.py index 73e9c40..1bf6a10 100755 --- a/metrics/wind_comp.py +++ b/metrics/wind_comp.py @@ -113,7 +113,8 @@ def computeWindDir(U, V, xRef=None, yRef=None, proj=None): @rtype: tuple @return: vitesse (m/s) et direction du vent exprimée en degrés météo (0 = vent du nord). """ - ff = np.sqrt(U * U + V * V) + + ff = np.sqrt(np.maximum(U * U + V * V, np.zeros_like(U))) dd3 = (180 + 180 / np.pi * np.arctan2(U, V)) % 360 # logging.debug(dd3) diff --git a/preprocess/rrPreprocessor.py b/preprocess/rrPreprocessor.py index 67f7266..2562404 100644 --- a/preprocess/rrPreprocessor.py +++ b/preprocess/rrPreprocessor.py @@ -38,6 +38,7 @@ def init_normalization(self): normalization_type = self.normalization["type"] if normalization_type == "mean": means, stds = self.load_stat_files(normalization_type, "mean", "std") + logging.debug(f"stat constants {means, stds}") return None, None, means, stds elif normalization_type == "minmax": maxs, mins = self.load_stat_files(normalization_type, "max", "min") @@ -64,16 +65,16 @@ def load_stat_files(self, normalization_type, str1, str2): std_or_min_filename += "_ppx" mean_or_max_filename += ".npy" std_or_min_filename += ".npy" - logging.debug(f"{mean_or_max_filename}", f"{std_or_min_filename}") - logging.debug(f"Normalization set to {normalization_type}") + # logging.debug(f"{mean_or_max_filename}", f"{std_or_min_filename}") + # logging.debug(f"Normalization set to {normalization_type}") stat_folder = self.config_data["stat_folder"] file_path = os.path.join(self.config_data["real_data_dir"], stat_folder, mean_or_max_filename) means_or_maxs = np.load(file_path).astype('float32') - logging.debug(f"{str1} file found, {means_or_maxs.shape}") + # logging.debug(f"{str1} file found, {means_or_maxs.shape}") file_path = os.path.join(self.config_data["real_data_dir"], stat_folder, std_or_min_filename) stds_or_mins = np.load(file_path).astype('float32') - logging.debug(f"{str2} file found, {stds_or_mins.shape}") + # logging.debug(f"{str2} file found, {stds_or_mins.shape}") return means_or_maxs, stds_or_mins def detransform(self, data): @@ -178,4 +179,5 @@ def __init__(self, config_data, sizeH, sizeW, variables, **kwargs): super().__init__(config_data, sizeH, sizeW, variables, **kwargs) def process_batch(self, batch): - return self.detransform(batch) + res = self.detransform(batch) + return res From c2311f802978bc0dfc6c14f2239387b4857a8f1e Mon Sep 17 00:00:00 2001 From: seznecc Date: Thu, 13 Jun 2024 17:30:49 +0200 Subject: [PATCH 2/4] ajout MSE 2 --- metrics/MSE.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 metrics/MSE.py diff --git a/metrics/MSE.py b/metrics/MSE.py new file mode 100644 index 0000000..a6d6435 --- /dev/null +++ b/metrics/MSE.py @@ -0,0 +1,9 @@ +import numpy as np +import random +import logging + +def mse(X, obs): + + random_idx = random.sample(range(X.shape[0]), 1)[0] + _mse = np.nanmean(((X[random_idx] - obs.squeeze())**2),axis=(-2,-1)) + return _mse \ No newline at end of file From 09e82914881db500797a7c52e39d87dffe45ce87 Mon Sep 17 00:00:00 2001 From: seznecc Date: Thu, 13 Jun 2024 18:20:14 +0200 Subject: [PATCH 3/4] test credentials --- metrics/MSE.py | 4 +++- metrics/__init__.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/metrics/MSE.py b/metrics/MSE.py index a6d6435..5d09754 100644 --- a/metrics/MSE.py +++ b/metrics/MSE.py @@ -3,7 +3,9 @@ import logging def mse(X, obs): - + ''' + Retourne la MSE + ''' random_idx = random.sample(range(X.shape[0]), 1)[0] _mse = np.nanmean(((X[random_idx] - obs.squeeze())**2),axis=(-2,-1)) return _mse \ No newline at end of file diff --git a/metrics/__init__.py b/metrics/__init__.py index 26c2e42..50e8ff8 100644 --- a/metrics/__init__.py +++ b/metrics/__init__.py @@ -490,7 +490,6 @@ def _calculateCore(self, processed_data): else: exp_data = processed_data['real_data'] obs_data = processed_data['obs_data'] - return MSE.mse(exp_data, obs_data) From a3948eea905eb34e4422c9fa71b52e51e134be85 Mon Sep 17 00:00:00 2001 From: seznecc Date: Thu, 20 Jun 2024 10:06:12 +0200 Subject: [PATCH 4/4] ajout scikit-image au requirement --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 191b106..80ff415 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ PyYaml==6.0.1 -torch==2.2.0+cu121 +torch==2.2.0 numpy==1.26.4 tqdm==4.66.1 pandas==2.2.0 @@ -9,4 +9,5 @@ pyproj==3.6.1 properscoring==0.1 geopy==2.4.1 astropy==6.0.0 +scikit-image CRPS