From 4924aa6200e215705c0570370acb83c21cd93749 Mon Sep 17 00:00:00 2001 From: Abdoulaye Diallo Date: Sun, 1 Feb 2026 14:33:47 +0100 Subject: [PATCH 1/2] Fix TrainableBilateralFilter 3D input validation (#7444) - Fix dimension comparison to use spatial dims instead of total dims - Add validation for minimum input dimensions - Fix typo in error message (ken_spatial_sigma -> len_spatial_sigma) - Move spatial dimension validation before unsqueeze operations The forward() method was incorrectly comparing self.len_spatial_sigma (number of spatial dimensions) with len(input_tensor.shape) (total dimensions including batch and channel), causing valid 3D inputs to be rejected. Fixes #7444 Signed-off-by: Abdoulaye Diallo --- monai/networks/layers/filtering.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index c48c77cf98..2b46ce1b6e 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -221,7 +221,7 @@ def __init__(self, spatial_sigma, color_sigma): self.len_spatial_sigma = 3 else: raise ValueError( - f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}." + f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}." ) # Register sigmas as trainable parameters. @@ -231,6 +231,10 @@ def __init__(self, spatial_sigma, color_sigma): self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma)) def forward(self, input_tensor): + if len(input_tensor.shape) < 3: + raise ValueError( + f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}" + ) if input_tensor.shape[1] != 1: raise ValueError( f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. "
"Please use multiple parallel filter layers if you want " "to filter multiple channels." ) @@ -239,24 +243,25 @@ def forward(self, input_tensor): ) len_input = len(input_tensor.shape) + spatial_dims = len_input - 2 # C++ extension so far only supports 5-dim inputs. - if len_input == 3: + if spatial_dims == 1: input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) - elif len_input == 4: + elif spatial_dims == 2: input_tensor = input_tensor.unsqueeze(4) - if self.len_spatial_sigma != len_input: - raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).") + if self.len_spatial_sigma != spatial_dims: + raise ValueError(f"Spatial dimension ({spatial_dims}) must match initialized len(spatial_sigma).") prediction = TrainableBilateralFilterFunction.apply( input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color ) # Make sure to return tensor of the same shape as the input. - if len_input == 3: + if spatial_dims == 1: prediction = prediction.squeeze(4).squeeze(3) - elif len_input == 4: + elif spatial_dims == 2: prediction = prediction.squeeze(4) return prediction @@ -389,7 +394,7 @@ def __init__(self, spatial_sigma, color_sigma): self.len_spatial_sigma = 3 else: raise ValueError( - f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}." + f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}." ) # Register sigmas as trainable parameters.
From fced85b32dbdcaa42eaedb2cdc37af8f0aa946a8 Mon Sep 17 00:00:00 2001 From: Abdoulaye Diallo Date: Wed, 4 Mar 2026 13:32:58 +0100 Subject: [PATCH 2/2] fix: apply same dimension handling fixes to TrainableJointBilateralFilter Signed-off-by: Abdoulaye Diallo --- monai/networks/layers/filtering.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 2b46ce1b6e..249fcf2892 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -220,9 +220,7 @@ def __init__(self, spatial_sigma, color_sigma): spatial_sigma = [spatial_sigma[0], spatial_sigma[1], spatial_sigma[2]] self.len_spatial_sigma = 3 else: - raise ValueError( - f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}." - ) + raise ValueError(f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims (1, 2, or 3).") # Register sigmas as trainable parameters. self.sigma_x = torch.nn.Parameter(torch.tensor(spatial_sigma[0])) @@ -393,9 +391,7 @@ def __init__(self, spatial_sigma, color_sigma): spatial_sigma = [spatial_sigma[0], spatial_sigma[1], spatial_sigma[2]] self.len_spatial_sigma = 3 else: - raise ValueError( - f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}." - ) + raise ValueError(f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims (1, 2, or 3).") # Register sigmas as trainable parameters.
self.sigma_x = torch.nn.Parameter(torch.tensor(spatial_sigma[0])) @@ -404,9 +400,13 @@ def __init__(self, spatial_sigma, color_sigma): self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma)) def forward(self, input_tensor, guidance_tensor): + if len(input_tensor.shape) < 3: + raise ValueError( + f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}" + ) if input_tensor.shape[1] != 1: raise ValueError( - f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. " + f"Currently channel dimensions > 1 ({input_tensor.shape[1]}) are not supported. " "Please use multiple parallel filter layers if you want " "to filter multiple channels." ) @@ -417,26 +417,27 @@ def forward(self, input_tensor, guidance_tensor): ) len_input = len(input_tensor.shape) + spatial_dims = len_input - 2 # C++ extension so far only supports 5-dim inputs. - if len_input == 3: + if spatial_dims == 1: input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) guidance_tensor = guidance_tensor.unsqueeze(3).unsqueeze(4) - elif len_input == 4: + elif spatial_dims == 2: input_tensor = input_tensor.unsqueeze(4) guidance_tensor = guidance_tensor.unsqueeze(4) - if self.len_spatial_sigma != len_input: - raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).") + if self.len_spatial_sigma != spatial_dims: + raise ValueError(f"Spatial dimension ({spatial_dims}) must match initialized len(spatial_sigma).") prediction = TrainableJointBilateralFilterFunction.apply( input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color ) # Make sure to return tensor of the same shape as the input. - if len_input == 3: + if spatial_dims == 1: prediction = prediction.squeeze(4).squeeze(3) - elif len_input == 4: + elif spatial_dims == 2: prediction = prediction.squeeze(4) return prediction