From 28915206cddffc9ad82c5dcef5c78f04998a4bb9 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Wed, 11 Jun 2025 16:40:38 +0200 Subject: [PATCH 01/35] added No Patching checkbox to extractor widget --- src/featureforest/_feature_extractor_widget.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/featureforest/_feature_extractor_widget.py b/src/featureforest/_feature_extractor_widget.py index 33e51f1..da34399 100644 --- a/src/featureforest/_feature_extractor_widget.py +++ b/src/featureforest/_feature_extractor_widget.py @@ -9,6 +9,7 @@ from napari.utils.events import Event from qtpy.QtCore import Qt from qtpy.QtWidgets import ( + QCheckBox, QComboBox, QFileDialog, QGroupBox, @@ -23,9 +24,7 @@ from .models import get_available_models, get_model from .utils import config -from .utils.data import ( - get_stack_dims, -) +from .utils.data import get_stack_dims from .utils.extract import extract_embeddings_to_file from .widgets import ( ScrollWidgetWrapper, @@ -66,6 +65,11 @@ def prepare_widget(self): self.storage_textbox.setReadOnly(True) storage_button = QPushButton("Set Storage File") storage_button.clicked.connect(self.save_storage) + # no-patching checkbox + self.no_patching_checkbox = QCheckBox("No &Patching") + self.no_patching_checkbox.setToolTip( + "Whether divide an image into patches or not" + ) # extract button self.extract_button = QPushButton("Extract Features") self.extract_button.setEnabled(False) @@ -104,6 +108,7 @@ def prepare_widget(self): hbox.addWidget(self.storage_textbox) hbox.addWidget(storage_button) vbox.addLayout(hbox) + vbox.addWidget(self.no_patching_checkbox) hbox = QHBoxLayout() hbox.setContentsMargins(0, 0, 0, 0) hbox.addWidget(self.extract_button, alignment=Qt.AlignLeft) @@ -168,6 +173,7 @@ def extract_embeddings(self): _, img_height, img_width = get_stack_dims(image_layer.data) model_name = self.model_combo.currentText() self.model_adapter = get_model(model_name, img_height, img_width) + self.model_adapter.no_patching = self.no_patching_checkbox.isChecked() self.extract_button.setEnabled(False) self.stop_button.setEnabled(True) From ca67604600d363dccad656ca4b1a96dac3fb2c7a Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Wed, 11 Jun 2025 17:02:24 +0200 Subject: [PATCH 02/35] updated models adapters to support no_patching --- src/featureforest/models/Cellpose/adapter.py | 17 ++--- src/featureforest/models/DinoV2/adapter.py | 53 +++++++------- src/featureforest/models/MobileSAM/adapter.py | 63 +++++++--------- src/featureforest/models/SAM/adapter.py | 58 +++++++-------- src/featureforest/models/SAM2/adapter.py | 64 ++++++++--------- src/featureforest/models/base.py | 71 +++++++++++-------- 6 files changed, 153 insertions(+), 173 deletions(-) diff --git a/src/featureforest/models/Cellpose/adapter.py b/src/featureforest/models/Cellpose/adapter.py index 6d5ae08..8103ffc 100755 --- a/src/featureforest/models/Cellpose/adapter.py +++ b/src/featureforest/models/Cellpose/adapter.py @@ -6,7 +6,6 @@ from featureforest.models.base import BaseModelAdapter from featureforest.utils.data import ( get_nonoverlapped_patches, - get_patch_size, ) @@ -16,8 +15,8 @@ class CellposeAdapter(BaseModelAdapter): def __init__( self, image_encoder: nn.Module, - img_height: float, - img_width: float, + img_height: int, + img_width: int, device: torch.device, name: str = "Cellpose_cyto3", ) -> None: @@ -28,9 +27,9 @@ def __init__( # self.encoder_num_channels = 480 self.device = device self._set_patch_size() - assert ( - int(self.patch_size / 4) == self.patch_size / 4 - ), f"patch size {self.patch_size} is not divisible by 4" + assert int(self.patch_size / 4) == self.patch_size / 4, ( + f"patch size {self.patch_size} is not divisible by 4" + ) # input transform for sam self.input_transforms = None @@ -45,11 +44,7 @@ def __init__( ] ) - def _set_patch_size(self) -> None: - self.patch_size = get_patch_size(self.img_height, self.img_width) - self.overlap = self.patch_size // 2 - - def get_features_patches(self, in_patches: Tensor) -> tuple[Tensor, Tensor]: + def get_features_patches(self, in_patches: Tensor) -> Tensor: # cellpose works on microscopic channels not RGB # we need to select one channel and add a second zero channel in_patches = torch.cat( diff --git a/src/featureforest/models/DinoV2/adapter.py b/src/featureforest/models/DinoV2/adapter.py index 67b9188..0544c63 100644 --- a/src/featureforest/models/DinoV2/adapter.py +++ b/src/featureforest/models/DinoV2/adapter.py @@ -1,5 +1,3 @@ -from typing import Tuple - import torch import torch.nn as nn from torch import Tensor @@ -12,14 +10,10 @@ class DinoV2Adapter(BaseModelAdapter): - """DinoV2 model adapter - """ + """DinoV2 model adapter""" + def __init__( - self, - model: nn.Module, - img_height: float, - img_width: float, - device: torch.device + self, model: nn.Module, img_height: int, img_width: int, device: torch.device ) -> None: super().__init__(model, img_height, img_width, device) self.name = "DinoV2" @@ -30,40 +24,43 @@ def __init__( self.device = device # input transform for dinov2 - self.input_transforms = tv_transforms2.Compose([ - tv_transforms2.ToImage(), - tv_transforms2.Resize(self.patch_size * self.dino_patch_size), - tv_transforms2.ToDtype(dtype=torch.float32, scale=True) - ]) + self.input_transforms = tv_transforms2.Compose( + [ + tv_transforms2.ToImage(), + tv_transforms2.Resize(self.patch_size * self.dino_patch_size), + tv_transforms2.ToDtype(dtype=torch.float32, scale=True), + ] + ) # to transform feature patches back to the original patch size - self.embedding_transform = tv_transforms2.Compose([ - tv_transforms2.Resize( - (self.patch_size, self.patch_size), - interpolation=tv_transforms2.InterpolationMode.BICUBIC, - antialias=True - ), - ]) + self.embedding_transform = tv_transforms2.Compose( + [ + tv_transforms2.Resize( + (self.patch_size, self.patch_size), + interpolation=tv_transforms2.InterpolationMode.BICUBIC, + antialias=True, + ), + ] + ) def _set_patch_size(self) -> None: self.patch_size = self.dino_patch_size * 5 self.overlap = self.dino_patch_size * 2 - def get_features_patches( - self, in_patches: Tensor - ) -> Tuple[Tensor, Tensor]: + def get_features_patches(self, in_patches: Tensor) -> Tensor: # get the mobile-sam encoder and embedding layer outputs with torch.no_grad(): # we use get_intermediate_layers method of dinov2, which returns a tuple. # output shape: b, 384, h, w if reshape is true. output_features = self.model.get_intermediate_layers( - self.input_transforms(in_patches), 1, - return_class_token=False, reshape=True + self.input_transforms(in_patches), + 1, + return_class_token=False, + reshape=True, )[0] # get non-overlapped feature patches feature_patches = get_nonoverlapped_patches( - output_features.cpu(), - self.patch_size, self.overlap + output_features.cpu(), self.patch_size, self.overlap ) return feature_patches diff --git a/src/featureforest/models/MobileSAM/adapter.py b/src/featureforest/models/MobileSAM/adapter.py index b9d7552..11a2d14 100644 --- a/src/featureforest/models/MobileSAM/adapter.py +++ b/src/featureforest/models/MobileSAM/adapter.py @@ -1,5 +1,3 @@ -from typing import Tuple - import torch import torch.nn as nn from torch import Tensor @@ -7,20 +5,15 @@ from featureforest.models.base import BaseModelAdapter from featureforest.utils.data import ( - get_patch_size, get_nonoverlapped_patches, ) class MobileSAMAdapter(BaseModelAdapter): - """MobileSAM model adapter - """ + """MobileSAM model adapter""" + def __init__( - self, - model: nn.Module, - img_height: float, - img_width: float, - device: torch.device + self, model: nn.Module, img_height: int, img_width: int, device: torch.device ) -> None: super().__init__(model, img_height, img_width, device) self.name = "MobileSAM" @@ -33,43 +26,37 @@ def __init__( # input transform for sam self.sam_input_dim = 1024 - self.input_transforms = tv_transforms2.Compose([ - tv_transforms2.Resize( - (self.sam_input_dim, self.sam_input_dim), - interpolation=tv_transforms2.InterpolationMode.BICUBIC, - antialias=True - ), - ]) + self.input_transforms = tv_transforms2.Compose( + [ + tv_transforms2.Resize( + (self.sam_input_dim, self.sam_input_dim), + interpolation=tv_transforms2.InterpolationMode.BICUBIC, + antialias=True, + ), + ] + ) # to transform feature patches back to the original patch size - self.embedding_transform = tv_transforms2.Compose([ - tv_transforms2.Resize( - (self.patch_size, self.patch_size), - interpolation=tv_transforms2.InterpolationMode.BICUBIC, - antialias=True - ), - ]) - - def _set_patch_size(self) -> None: - self.patch_size = get_patch_size(self.img_height, self.img_width) - self.overlap = self.patch_size // 2 + self.embedding_transform = tv_transforms2.Compose( + [ + tv_transforms2.Resize( + (self.patch_size, self.patch_size), + interpolation=tv_transforms2.InterpolationMode.BICUBIC, + antialias=True, + ), + ] + ) - def get_features_patches( - self, in_patches: Tensor - ) -> Tuple[Tensor, Tensor]: + def get_features_patches(self, in_patches: Tensor) -> tuple[Tensor, Tensor]: # get the mobile-sam encoder and embedding layer outputs with torch.no_grad(): - output, embed_output, _ = self.encoder( - self.input_transforms(in_patches) - ) + output, embed_output, _ = self.encoder(self.input_transforms(in_patches)) # get non-overlapped feature patches out_feature_patches = get_nonoverlapped_patches( - self.embedding_transform(output.cpu()), - self.patch_size, self.overlap + self.embedding_transform(output.cpu()), self.patch_size, self.overlap ) embed_feature_patches = get_nonoverlapped_patches( - self.embedding_transform(embed_output.cpu()), - self.patch_size, self.overlap + self.embedding_transform(embed_output.cpu()), self.patch_size, self.overlap ) return out_feature_patches, embed_feature_patches diff --git a/src/featureforest/models/SAM/adapter.py b/src/featureforest/models/SAM/adapter.py index 25bcc20..c6fa29c 100644 --- a/src/featureforest/models/SAM/adapter.py +++ b/src/featureforest/models/SAM/adapter.py @@ -1,5 +1,3 @@ -from typing import Tuple - import torch import torch.nn as nn from torch import Tensor @@ -7,7 +5,6 @@ from featureforest.models.base import BaseModelAdapter from featureforest.utils.data import ( - get_patch_size, get_nonoverlapped_patches, ) @@ -17,11 +14,12 @@ class SAMAdapter(BaseModelAdapter): Supports: 1) default 'vit_h' model, and 2) light microscopy and electron microscopy `micro-sam` models ('vit_b'). """ + def __init__( self, image_encoder: nn.Module, - img_height: float, - img_width: float, + img_height: int, + img_width: int, device: torch.device, name: str, ) -> None: @@ -38,35 +36,31 @@ def __init__( # input transform for sam self.sam_input_dim = 1024 - self.input_transforms = tv_transforms2.Compose([ - tv_transforms2.Resize( - (self.sam_input_dim, self.sam_input_dim), - interpolation=tv_transforms2.InterpolationMode.BICUBIC, - antialias=True - ), - ]) + self.input_transforms = tv_transforms2.Compose( + [ + tv_transforms2.Resize( + (self.sam_input_dim, self.sam_input_dim), + interpolation=tv_transforms2.InterpolationMode.BICUBIC, + antialias=True, + ), + ] + ) # to transform feature patches back to the original patch size - self.embedding_transform = tv_transforms2.Compose([ - tv_transforms2.Resize( - (self.patch_size, self.patch_size), - interpolation=tv_transforms2.InterpolationMode.BICUBIC, - antialias=True - ), - ]) - - def _set_patch_size(self) -> None: - self.patch_size = get_patch_size(self.img_height, self.img_width) - self.overlap = self.patch_size // 2 + self.embedding_transform = tv_transforms2.Compose( + [ + tv_transforms2.Resize( + (self.patch_size, self.patch_size), + interpolation=tv_transforms2.InterpolationMode.BICUBIC, + antialias=True, + ), + ] + ) - def get_features_patches( - self, in_patches: Tensor - ) -> Tuple[Tensor, Tensor]: + def get_features_patches(self, in_patches: Tensor) -> tuple[Tensor, Tensor]: # get the mobile-sam encoder and embedding layer outputs with torch.no_grad(): # output: b,256,64,64 - output = self.encoder( - self.input_transforms(in_patches) - ) + output = self.encoder(self.input_transforms(in_patches)) # embed_output: b,64,64,1280 -> b,1280,64,64 embed_output = self.encoder.patch_embed( self.input_transforms(in_patches) @@ -74,12 +68,10 @@ def get_features_patches( # get non-overlapped feature patches out_feature_patches = get_nonoverlapped_patches( - self.embedding_transform(output.cpu()), - self.patch_size, self.overlap + self.embedding_transform(output.cpu()), self.patch_size, self.overlap ) embed_feature_patches = get_nonoverlapped_patches( - self.embedding_transform(embed_output.cpu()), - self.patch_size, self.overlap + self.embedding_transform(embed_output.cpu()), self.patch_size, self.overlap ) return out_feature_patches, embed_feature_patches diff --git a/src/featureforest/models/SAM2/adapter.py b/src/featureforest/models/SAM2/adapter.py index 571d4ef..00fa078 100644 --- a/src/featureforest/models/SAM2/adapter.py +++ b/src/featureforest/models/SAM2/adapter.py @@ -1,5 +1,3 @@ -from typing import Tuple - import torch import torch.nn as nn from torch import Tensor @@ -7,21 +5,20 @@ from featureforest.models.base import BaseModelAdapter from featureforest.utils.data import ( - get_patch_size, get_nonoverlapped_patches, ) class SAM2Adapter(BaseModelAdapter): - """SAM2 model adapter - """ + """SAM2 model adapter""" + def __init__( self, image_encoder: nn.Module, - img_height: float, - img_width: float, + img_height: int, + img_width: int, device: torch.device, - name: str = "SAM2_Large" + name: str = "SAM2_Large", ) -> None: super().__init__(image_encoder, img_height, img_width, device) # for different flavors of SAM2 only the name is different. @@ -34,44 +31,41 @@ def __init__( # input transform for sam self.sam_input_dim = 1024 - self.input_transforms = tv_transforms2.Compose([ - tv_transforms2.Resize( - (self.sam_input_dim, self.sam_input_dim), - interpolation=tv_transforms2.InterpolationMode.BICUBIC, - antialias=True - ), - ]) + self.input_transforms = tv_transforms2.Compose( + [ + tv_transforms2.Resize( + (self.sam_input_dim, self.sam_input_dim), + interpolation=tv_transforms2.InterpolationMode.BICUBIC, + antialias=True, + ), + ] + ) # to transform feature patches back to the original patch size - self.embedding_transform = tv_transforms2.Compose([ - tv_transforms2.Resize( - (self.patch_size, self.patch_size), - interpolation=tv_transforms2.InterpolationMode.BICUBIC, - antialias=True - ), - ]) - - def _set_patch_size(self) -> None: - self.patch_size = get_patch_size(self.img_height, self.img_width) - self.overlap = self.patch_size // 2 + self.embedding_transform = tv_transforms2.Compose( + [ + tv_transforms2.Resize( + (self.patch_size, self.patch_size), + interpolation=tv_transforms2.InterpolationMode.BICUBIC, + antialias=True, + ), + ] + ) - def get_features_patches( - self, in_patches: Tensor - ) -> Tuple[Tensor, Tensor]: + def get_features_patches(self, in_patches: Tensor) -> Tensor: # get the image encoder outputs with torch.no_grad(): - output = self.encoder( - self.input_transforms(in_patches) - ) + output = self.encoder(self.input_transforms(in_patches)) # backbone_fpn contains 3 levels of features from hight to low resolution. # [b, 256, 256, 256] # [b, 256, 128, 128] # [b, 256, 64, 64] features = [ - self.embedding_transform(feat.cpu()) - for feat in output["backbone_fpn"] + self.embedding_transform(feat.cpu()) for feat in output["backbone_fpn"] ] features = torch.cat(features, dim=1) - out_feature_patches = get_nonoverlapped_patches(features, self.patch_size, self.overlap) + out_feature_patches = get_nonoverlapped_patches( + features, self.patch_size, self.overlap + ) return out_feature_patches diff --git a/src/featureforest/models/base.py b/src/featureforest/models/base.py index 97aeb5a..178b490 100644 --- a/src/featureforest/models/base.py +++ b/src/featureforest/models/base.py @@ -5,18 +5,15 @@ from ..utils.data import ( get_nonoverlapped_patches, + get_patch_size, ) class BaseModelAdapter: - """Base class for adapting any models in featureforest. - """ + """Base class for adapting any models in featureforest.""" + def __init__( - self, - model: nn.Module, - img_height: float, - img_width: float, - device: torch.device + self, model: nn.Module, img_height: int, img_width: int, device: torch.device ) -> None: """Initialization function @@ -34,33 +31,52 @@ def __init__( self.device = device # set patch size and overlap self.patch_size = 512 - self.overlap = self.patch_size // 2 + self.overlap = self.patch_size // 4 + self._no_patching = False # input image transforms - self.input_transforms = tv_transforms2.Compose([ - tv_transforms2.Resize( - (1024, 1024), - interpolation=tv_transforms2.InterpolationMode.BICUBIC, - antialias=True - ), - ]) + self.input_transforms = tv_transforms2.Compose( + [ + tv_transforms2.Resize( + (1024, 1024), + interpolation=tv_transforms2.InterpolationMode.BICUBIC, + antialias=True, + ), + ] + ) # to transform feature patches back to the original patch size - self.embedding_transform = tv_transforms2.Compose([ - tv_transforms2.Resize( - (self.patch_size, self.patch_size), - interpolation=tv_transforms2.InterpolationMode.BICUBIC, - antialias=True - ), - ]) + self.embedding_transform = tv_transforms2.Compose( + [ + tv_transforms2.Resize( + (self.patch_size, self.patch_size), + interpolation=tv_transforms2.InterpolationMode.BICUBIC, + antialias=True, + ), + ] + ) + + @property + def no_patching(self) -> bool: + return self._no_patching + + @no_patching.setter + def no_patching(self, value: bool): + self._no_patching = value + self._set_patch_size() def _set_patch_size(self) -> None: """Sets the proper patch size and patch overlap with respect to the model & image resolution. """ - raise NotImplementedError + if self._no_patching: + self.patch_size = self.img_height + self.overlap = 0 + else: + self.patch_size = get_patch_size(self.img_height, self.img_width) + self.overlap = self.patch_size // 4 + # update embedding transform + self.embedding_transform.transforms[0].size = [self.patch_size, self.patch_size] - def get_features_patches( - self, in_patches: Tensor - ) -> Tensor: + def get_features_patches(self, in_patches: Tensor) -> Tensor: """Returns model's extracted features. This is an abstract function, and should be overridden. @@ -77,8 +93,7 @@ def get_features_patches( # get non-overlapped feature patches feature_patches = get_nonoverlapped_patches( - self.embedding_transform(output_features.cpu()), - self.patch_size, self.overlap + self.embedding_transform(output_features.cpu()), self.patch_size, self.overlap ) return feature_patches From cb3d5441e045c3dc828e48b893b034aa4a344996 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Thu, 12 Jun 2025 12:40:33 +0200 Subject: [PATCH 03/35] updated segmentation widget --- src/featureforest/_segmentation_widget.py | 141 ++++++++++++---------- src/featureforest/utils/config.py | 2 +- src/featureforest/widgets/utils.py | 14 ++- 3 files changed, 91 insertions(+), 66 deletions(-) diff --git a/src/featureforest/_segmentation_widget.py b/src/featureforest/_segmentation_widget.py index 6fbfd95..c8a4b1a 100644 --- a/src/featureforest/_segmentation_widget.py +++ b/src/featureforest/_segmentation_widget.py @@ -4,11 +4,16 @@ import pickle import time import warnings +from collections.abc import Generator from pathlib import Path +from typing import Optional import h5py import napari +import napari.layers +import napari.types import napari.utils.notifications as notif +import napari.view_layers import numpy as np import tifffile from napari.qt.threading import create_worker @@ -34,7 +39,7 @@ from tifffile import TiffFile from .exports import EXPORTERS -from .models import get_model +from .models import BaseModelAdapter, get_model from .postprocess import ( get_sam_auto_masks, postprocess, @@ -58,16 +63,16 @@ class SegmentationWidget(QWidget): - def __init__(self, napari_viewer: napari.Viewer): + def __init__(self, napari_viewer: napari.Viewer) -> None: super().__init__() self.viewer = napari_viewer - self.image_layer = None - self.gt_layer = None - self.segmentation_layer = None - self.postprocess_layer = None - self.storage = None - self.rf_model = None - self.model_adapter = None + self.image_layer: Optional[napari.layers.Image] = None + self.gt_layer: Optional[napari.layers.Labels] = None + self.segmentation_layer: Optional[napari.layers.Labels] = None + self.postprocess_layer: Optional[napari.layers.Labels] = None + self.storage: Optional[str] = None + self.rf_model: Optional[RandomForestClassifier] = None + self.model_adapter: Optional[BaseModelAdapter] = None self.sam_auto_masks = None self.patch_size = 512 # default values self.overlap = 384 @@ -76,7 +81,7 @@ def __init__(self, napari_viewer: napari.Viewer): self.prepare_widget() - def closeEvent(self, event): + def closeEvent(self, event) -> None: print("closing") self.viewer.layers.events.inserted.disconnect(self.check_input_layers) self.viewer.layers.events.removed.disconnect(self.check_input_layers) @@ -87,7 +92,7 @@ def closeEvent(self, event): self.viewer.layers.events.removed.disconnect(self.postprocess_layer_removed) - def prepare_widget(self): + def prepare_widget(self) -> None: self.base_layout = QVBoxLayout() self.create_input_ui() self.create_label_stats_ui() @@ -118,7 +123,7 @@ def prepare_widget(self): self.viewer.dims.events.current_step.connect(self.clear_sam_auto_masks) - def create_input_ui(self): + def create_input_ui(self) -> None: # input layer input_label = QLabel("Input Layer:") self.image_combo = QComboBox() @@ -168,7 +173,7 @@ def create_input_ui(self): gbox.setLayout(layout) self.base_layout.addWidget(gbox) - def create_label_stats_ui(self): + def create_label_stats_ui(self) -> None: self.num_class_label = QLabel("Number of classes: ") self.each_class_label = QLabel("Labels per class:") analyze_button = QPushButton("Analyze") @@ -195,7 +200,7 @@ def create_label_stats_ui(self): gbox.setLayout(layout) self.base_layout.addWidget(gbox) - def create_train_ui(self): + def create_train_ui(self) -> None: tree_label = QLabel("Number of trees:") self.num_trees_textbox = QLineEdit() self.num_trees_textbox.setText("450") @@ -247,7 +252,7 @@ def create_train_ui(self): gbox.setLayout(layout) self.base_layout.addWidget(gbox) - def create_prediction_ui(self): + def create_prediction_ui(self) -> None: seg_label = QLabel("Segmentation Layer:") self.new_layer_checkbox = QCheckBox("New Layer") self.new_layer_checkbox.setChecked(True) @@ -301,7 +306,7 @@ def create_prediction_ui(self): gbox.setLayout(layout) self.base_layout.addWidget(gbox) - def create_postprocessing_ui(self): + def create_postprocessing_ui(self) -> None: smooth_label = QLabel("Smoothing Iterations:") self.smoothing_iteration_textbox = QLineEdit() self.smoothing_iteration_textbox.setText("25") @@ -387,7 +392,7 @@ def create_postprocessing_ui(self): gbox.setLayout(layout) self.base_layout.addWidget(gbox) - def create_export_ui(self): + def create_export_ui(self) -> None: export_label = QLabel("Export Format:") self.export_format_combo = QComboBox() for exporter in EXPORTERS: @@ -422,7 +427,7 @@ def create_export_ui(self): gbox.setLayout(layout) self.base_layout.addWidget(gbox) - def create_large_stack_prediction_ui(self): + def create_large_stack_prediction_ui(self) -> None: stack_label = QLabel("Select your stack:") self.large_stack_textbox = QLineEdit() self.large_stack_textbox.setReadOnly(True) @@ -481,13 +486,13 @@ def create_large_stack_prediction_ui(self): gbox.setLayout(layout) self.base_layout.addWidget(gbox) - def new_layer_checkbox_changed(self): + def new_layer_checkbox_changed(self) -> None: state = self.new_layer_checkbox.checkState() self.prediction_layer_combo.setEnabled(state == Qt.Unchecked) self.seg_add_radiobutton.setEnabled(state == Qt.Unchecked) self.seg_replace_radiobutton.setEnabled(state == Qt.Unchecked) - def check_input_layers(self, event: Event): + def check_input_layers(self, event: Optional[Event]) -> None: curr_text = self.image_combo.currentText() self.image_combo.clear() for layer in self.viewer.layers: @@ -499,7 +504,7 @@ def check_input_layers(self, event: Event): if index > -1: self.image_combo.setCurrentIndex(index) - def check_label_layers(self, event: Event): + def check_label_layers(self, event: Optional[Event]) -> None: gt_curr_text = self.gt_combo.currentText() pred_curr_text = self.prediction_layer_combo.currentText() self.gt_combo.clear() @@ -524,10 +529,10 @@ def check_label_layers(self, event: Event): if index > -1: self.prediction_layer_combo.setCurrentIndex(index) - def clear_sam_auto_masks(self): + def clear_sam_auto_masks(self) -> None: self.sam_auto_masks = None - def postprocess_layer_removed(self, event: Event): + def postprocess_layer_removed(self, event: Event) -> None: """Fires when current postprocess layer is removed.""" if ( self.postprocess_layer is not None @@ -535,15 +540,15 @@ def postprocess_layer_removed(self, event: Event): ): self.postprocess_layer = None - def sam_post_checked(self, checked: bool): + def sam_post_checked(self, checked: bool) -> None: if checked: self.sam_auto_post_checkbox.setChecked(False) - def sam_auto_post_checked(self, checked: bool): + def sam_auto_post_checked(self, checked: bool) -> None: if checked: self.sam_post_checkbox.setChecked(False) - def select_storage(self): + def select_storage(self) -> None: selected_file, _filter = QFileDialog.getOpenFileName( self, "FeatureForest", ".", "Feature Storage(*.hdf5)" ) @@ -560,7 +565,7 @@ def select_storage(self): img_height = self.storage.attrs["img_height"] img_width = self.storage.attrs["img_width"] # TODO: raise an error if current image dims are in conflicting with storage - model_name = self.storage.attrs["model"] + model_name = str(self.storage.attrs["model"]) self.model_adapter = get_model(model_name, img_height, img_width) print(model_name, self.patch_size, self.overlap) @@ -569,7 +574,7 @@ def select_storage(self): csv_path = storage_path.parent.joinpath(f"{storage_path.stem}_seg_stats.csv") self.stats.set_file_path(csv_path) - def add_labels_layer(self): + def add_labels_layer(self) -> None: self.image_layer = get_layer( self.viewer, self.image_combo.currentText(), config.NAPARI_IMAGE_LAYER ) @@ -585,14 +590,14 @@ def add_labels_layer(self): layer.colormap = colormaps.create_colormap(10)[0] layer.brush_size = 1 - def set_stats_label_layer(self): + def set_stats_label_layer(self) -> None: layer = get_layer( self.viewer, self.gt_combo.currentText(), config.NAPARI_LABELS_LAYER ) if layer is not None: self.stats.set_label_layer(layer) - def get_class_labels(self): + def get_class_labels(self) -> dict[int, np.ndarray]: labels_dict = {} layer = get_layer( self.viewer, self.gt_combo.currentText(), config.NAPARI_LABELS_LAYER @@ -611,7 +616,7 @@ def get_class_labels(self): return labels_dict - def analyze_labels(self, labels_dict: dict = None): + def analyze_labels(self, labels_dict: Optional[dict]) -> None: if labels_dict is None: labels_dict = self.get_class_labels() num_labels = [len(v) for v in labels_dict.values()] @@ -621,11 +626,11 @@ def analyze_labels(self, labels_dict: dict = None): ) self.each_class_label.setText("Labels per class:\n" + each_class) - def show_usage_stats(self): + def show_usage_stats(self) -> None: stats_widget = UsageStats(self.stats) stats_widget.exec() - def get_train_data(self): + def get_train_data(self) -> tuple[np.ndarray, np.ndarray] | None: # get ground truth class labels labels_dict = self.get_class_labels() if len(labels_dict) == 0: @@ -652,9 +657,7 @@ def get_train_data(self): for slice_index in np_progress(uniq_slices, desc="reading slices"): slice_coords = class_label_coords[ class_label_coords[:, 0] == slice_index - ][ - :, 1: - ] # omit the slice dim + ][:, 1:] # omit the slice dim patch_indices = get_patch_indices( slice_coords, img_height, img_width, self.patch_size, self.overlap ) @@ -675,7 +678,7 @@ def get_train_data(self): return train_data, labels - def train_model(self): + def train_model(self) -> None: self.image_layer = get_layer( self.viewer, self.image_combo.currentText(), config.NAPARI_IMAGE_LAYER ) @@ -717,7 +720,7 @@ def train_model(self): notif.show_info("Model status: Training is Done!") self.model_save_button.setEnabled(True) - def load_rf_model(self): + def load_rf_model(self) -> None: selected_file, _filter = QFileDialog.getOpenFileName( self, "FeatureForest", ".", "model(*.bin)" ) @@ -737,12 +740,14 @@ def load_rf_model(self): # (users can just load the rf model) if self.model_adapter is None: model_name = model_data["model_name"] + no_patching = model_data["no_patching"] self.patch_size = model_data["patch_size"] self.overlap = model_data["overlap"] img_height = model_data["img_height"] img_width = model_data["img_width"] # init the model adapter self.model_adapter = get_model(model_name, img_height, img_width) + self.model_adapter.no_patching = no_patching else: # old format self.rf_model = model_data @@ -751,7 +756,7 @@ def load_rf_model(self): self.model_status_label.setText("Model status: Ready!") self.model_save_button.setEnabled(True) - def save_rf_model(self): + def save_rf_model(self) -> None: if self.rf_model is None: notif.show_info("There is no trained model!") return @@ -767,6 +772,7 @@ def save_rf_model(self): "model_name": self.model_adapter.name, "img_height": self.storage.attrs["img_height"], "img_width": self.storage.attrs["img_width"], + "no_patching": self.model_adapter.no_patching, "patch_size": self.patch_size, "overlap": self.overlap, } @@ -774,7 +780,7 @@ def save_rf_model(self): pickle.dump(model_data, f) notif.show_info("Model was saved successfully.") - def predict(self, whole_stack=False): + def predict(self, whole_stack: bool = False) -> None: self.prediction_progress.setValue(0) if self.rf_model is None: notif.show_error("There is no trained RF model!") @@ -825,7 +831,9 @@ def predict(self, whole_stack=False): self.predict_worker.finished.connect(self.prediction_is_done) self.predict_worker.run() - def run_prediction(self, slice_indices, img_height, img_width): + def run_prediction( + self, slice_indices: list, img_height: int, img_width: int + ) -> Generator[tuple[int, int], None, None]: for slice_index in np_progress(slice_indices): self.stats.prediction_started() @@ -854,7 +862,13 @@ def run_prediction(self, slice_indices, img_height, img_width): self.segmentation_layer.colormap = cm self.segmentation_layer.refresh() - def predict_slice(self, rf_model, slice_index, img_height, img_width): + def predict_slice( + self, + rf_model: RandomForestClassifier, + slice_index: int, + img_height: int, + img_width: int, + ) -> np.ndarray: """Predict a slice patch by patch""" segmentation_image = [] # shape: N x target_size x target_size x C @@ -884,26 +898,26 @@ def predict_slice(self, rf_model, slice_index, img_height, img_width): return segmentation_image - def stop_predict(self): + def stop_predict(self) -> None: if self.predict_worker is not None: self.predict_worker.quit() self.predict_worker = None self.predict_stop_button.setEnabled(False) - def update_prediction_progress(self, values): + def update_prediction_progress(self, values: tuple) -> None: curr, total = values self.prediction_progress.setMinimum(0) self.prediction_progress.setMaximum(total) self.prediction_progress.setValue(curr + 1) self.prediction_progress.setFormat("slice %v of %m (%p%)") - def prediction_is_done(self): + def prediction_is_done(self) -> None: self.predict_all_button.setEnabled(True) self.predict_stop_button.setEnabled(False) print("Prediction is done!") notif.show_info("Prediction is done!") - def get_postprocess_params(self): + def get_postprocess_params(self) -> tuple[int, int, bool]: smoothing_iterations = 25 if len(self.smoothing_iteration_textbox.text()) > 0: smoothing_iterations = int(self.smoothing_iteration_textbox.text()) @@ -917,7 +931,7 @@ def get_postprocess_params(self): return smoothing_iterations, area_threshold, area_is_absolute - def postprocess_segmentation(self, whole_stack=False): + def postprocess_segmentation(self, whole_stack: bool = False) -> None: self.segmentation_layer = get_layer( self.viewer, self.prediction_layer_combo.currentText(), @@ -926,12 +940,15 @@ def postprocess_segmentation(self, whole_stack=False): if self.segmentation_layer is None: notif.show_error("No segmentation layer is selected!") return + if self.image_layer is None: + notif.show_error("No image layer is selected!") + return smoothing_iterations, area_threshold, area_is_absolute = ( self.get_postprocess_params() ) - num_slices, img_height, img_width = get_stack_dims(self.image_layer.data) + num_slices, _, _ = get_stack_dims(self.image_layer.data) slice_indices = [] if not whole_stack: # only predict the current slice @@ -978,7 +995,7 @@ def postprocess_segmentation(self, whole_stack=False): self.postprocess_layer.refresh() - def export_segmentation(self): + def export_segmentation(self) -> None: if self.segmentation_layer is None: notif.show_error("No segmentation layer is selected!") return @@ -1004,7 +1021,7 @@ def export_segmentation(self): notif.show_info("Selected layer was exported successfully.") - def select_stack(self): + def select_stack(self) -> None: selected_file, _filter = QFileDialog.getOpenFileName( self, "FeatureForest", ".", "TIFF stack (*.tiff *.tif)" ) @@ -1012,9 +1029,9 @@ def select_stack(self): # get stack info with TiffFile(selected_file) as tiff_stack: axes = tiff_stack.series[0].axes - assert ("Y" in axes) and ( - "X" in axes - ), "Could not find YX in the stack axes!" + assert ("Y" in axes) and ("X" in axes), ( + "Could not find YX in the stack axes!" + ) stack_dims = tiff_stack.series[0].shape stack_height = stack_dims[axes.index("Y")] stack_width = stack_dims[axes.index("X")] @@ -1038,7 +1055,7 @@ def select_stack(self): res_dir = Path(selected_file).parent self.result_dir_textbox.setText(str(res_dir.absolute())) - def select_result_dir(self): + def select_result_dir(self) -> None: selected_dir = QFileDialog.getExistingDirectory( self, "Select a directory", @@ -1048,7 +1065,7 @@ def select_result_dir(self): if selected_dir is not None and len(selected_dir) > 0: self.result_dir_textbox.setText(selected_dir) - def run_pipeline_over_large_stack(self): + def run_pipeline_over_large_stack(self) -> None: if self.large_stack_textbox.text() == "": notif.show_error("No TIFF Stack is selected!") return @@ -1068,7 +1085,9 @@ def run_pipeline_over_large_stack(self): self.pipeline_worker.finished.connect(self.pipeline_is_done) self.pipeline_worker.run() - def run_pipeline(self, tiff_stack_file: str, result_dir: Path): + def run_pipeline( + self, tiff_stack_file: str, result_dir: Path + ) -> Generator[tuple[int, int], None, None]: start = dt.datetime.now() slices_total_time = 0 postprocess_total_time = 0 @@ -1142,27 +1161,27 @@ def run_pipeline(self, tiff_stack_file: str, result_dir: Path): result_dir, start, end, slices_total_time, postprocess_total_time, total_pages ) - def stop_pipeline(self): + def stop_pipeline(self) -> None: if self.pipeline_worker is not None: self.pipeline_worker.quit() self.pipeline_worker = None self.stop_pipeline_button.setEnabled(False) - def update_pipeline_progress(self, values): + def update_pipeline_progress(self, values: tuple) -> None: curr, total = values self.pipeline_progressbar.setMinimum(0) self.pipeline_progressbar.setMaximum(total) self.pipeline_progressbar.setValue(curr + 1) self.pipeline_progressbar.setFormat("slice %v of %m (%p%)") - def pipeline_is_done(self): + def pipeline_is_done(self) -> None: self.run_pipeline_button.setEnabled(True) self.stop_pipeline_button.setEnabled(False) self.remove_temp_storage() print("Prediction is done!") notif.show_info("Prediction is done!") - def remove_temp_storage(self): + def remove_temp_storage(self) -> None: tmp_storage_path = Path.home().joinpath(".featureforest", "tmp_storage.h5") if tmp_storage_path.exists(): tmp_storage_path.unlink() @@ -1175,7 +1194,7 @@ def save_pipeline_stats( slice_total: float, pp_total: float, num_images: int, - ): + ) -> None: total_time = (end - start).total_seconds() total_min, total_sec = divmod(total_time, 60) total_hour, total_min = divmod(total_min, 60) diff --git a/src/featureforest/utils/config.py b/src/featureforest/utils/config.py index dbe926f..05ed482 100644 --- a/src/featureforest/utils/config.py +++ b/src/featureforest/utils/config.py @@ -1,5 +1,5 @@ import napari - +import napari.layers BG_CLASS_NAME = "background" NAPARI_BG_CLASS = 0 diff --git a/src/featureforest/widgets/utils.py b/src/featureforest/widgets/utils.py index 6877376..ee33a5c 100644 --- a/src/featureforest/widgets/utils.py +++ b/src/featureforest/widgets/utils.py @@ -1,12 +1,18 @@ +from typing import Optional + import napari +import napari.layers as layers +import napari.types import numpy as np from featureforest.utils import colormaps -def get_layer(napari_viewer, name, layer_types): +def get_layer( + napari_viewer: napari.Viewer, name: str, layer_type: type +) -> Optional[layers.Layer]: for layer in napari_viewer.layers: - if layer.name == name and isinstance(layer, layer_types): + if layer.name == name and isinstance(layer, layer_type): return layer return None @@ -20,8 +26,8 @@ def add_labels_layer_(napari_viewer: napari.Viewer): scene_size = extent[1] - extent[0] corner = extent[0] shape = [ - np.round(s / sc).astype('int') + 1 - for s, sc in zip(scene_size, scale) + np.round(s / sc).astype("int") + 1 + for s, sc in zip(scene_size, scale, strict=False) ] empty_labels = np.zeros(shape, dtype=np.uint8) layer = napari_viewer.add_labels( From 411bd9e18ebf980e7aae75491fb3aafd3370be78 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Thu, 12 Jun 2025 15:06:58 +0200 Subject: [PATCH 04/35] removed unused imports --- src/featureforest/_segmentation_widget.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/featureforest/_segmentation_widget.py b/src/featureforest/_segmentation_widget.py index c8a4b1a..aca9e8e 100644 --- a/src/featureforest/_segmentation_widget.py +++ b/src/featureforest/_segmentation_widget.py @@ -11,9 +11,7 @@ import h5py import napari import napari.layers -import napari.types import napari.utils.notifications as notif -import napari.view_layers import numpy as np import tifffile from napari.qt.threading import create_worker From b9ce247ed3a5fac0cf31c63429d28cdedde6b355 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Thu, 12 Jun 2025 15:10:01 +0200 Subject: [PATCH 05/35] fixed some typings --- src/featureforest/utils/data.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/featureforest/utils/data.py b/src/featureforest/utils/data.py index 5c276b0..63218e6 100644 --- a/src/featureforest/utils/data.py +++ b/src/featureforest/utils/data.py @@ -120,7 +120,7 @@ def patchify( def get_num_patches( - img_height: float, img_width: float, patch_size: int, overlap: int + img_height: int, img_width: int, patch_size: int, overlap: int ) -> tuple[int, int]: """Returns number of patches per each image dimension. @@ -172,8 +172,8 @@ def get_nonoverlapped_patches(patches: Tensor, patch_size: int, overlap: int) -> def get_patch_index( pix_y: float, pix_x: float, - img_height: float, - img_width: float, + img_height: int, + img_width: int, patch_size: int, overlap: int, ) -> int: @@ -199,8 +199,8 @@ def get_patch_index( def get_patch_indices( pixel_coords: ndarray, - img_height: float, - img_width: float, + img_height: int, + img_width: int, patch_size: int, overlap: int, ) -> ndarray: From 2050125cfcc722cf68c69ff6eaf70706941c2d1c Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Thu, 12 Jun 2025 15:26:07 +0200 Subject: [PATCH 06/35] added the dataset class --- src/featureforest/utils/data.py | 29 +++++++++++ src/featureforest/utils/dataset.py | 84 ++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+) create mode 100644 src/featureforest/utils/dataset.py diff --git a/src/featureforest/utils/data.py b/src/featureforest/utils/data.py index 63218e6..9bd4bed 100644 --- a/src/featureforest/utils/data.py +++ b/src/featureforest/utils/data.py @@ -1,11 +1,40 @@ from typing import Optional import numpy as np +import torch import torch.nn.functional as F from numpy import ndarray from torch import Tensor +def get_model_ready_image(image: np.ndarray) -> torch.Tensor: + """Convert the input image to a torch tensor and normalize it. + Args: + image (np.ndarray): Input image to be converted. + Returns: + torch.Tensor: The input image as a torch tensor, normalized to [0, 1]. + """ + # image to torch tensor + img_data = torch.from_numpy(image.copy()).to(torch.float32) + # normalize in [0, 1] + _min = img_data.min() + _max = img_data.max() + img_data = (img_data - _min) / (_max - _min) + # for sam the input image should be 4D: BxCxHxW ; an RGB image. + if not is_stacked(img_data.numpy()): + # add a batch dim + img_data = img_data.unsqueeze(0) + if is_image_rgb(img_data.numpy()): + # it's already RGB + img_data = img_data[..., :3] # discard the alpha channel (in case of PNG). + img_data = img_data.permute([0, 3, 1, 2]) # make it channel first + else: + # make it RGB by repeating the single channel + img_data = img_data.unsqueeze(1).expand(-1, 3, -1, -1) + + return img_data + + def get_patch_size( img_height: float, img_width: float, divisible_by: Optional[int] = None ) -> int: diff --git a/src/featureforest/utils/dataset.py b/src/featureforest/utils/dataset.py new file mode 100644 index 0000000..76c4ea2 --- /dev/null +++ b/src/featureforest/utils/dataset.py @@ -0,0 +1,84 @@ +from collections.abc import Iterable +from pathlib import Path + +import numpy as np +import pims +import torch +from tifffile import natural_sorted +from torch.utils.data import IterableDataset + +from featureforest.utils.data import ( + get_model_ready_image, + patchify, +) + + +class FFImageDataset(IterableDataset): + """ + Iterable dataset for large images or image stacks. + This dataset can handle large TIFF files or directories of images, + and it can yield patches of images or the whole image depending on the configuration. + """ + + def __init__( + self, + image_array: np.ndarray | None = None, + stack_file: str | Path | None = None, + img_dir: str | Path | None = None, + no_patching: bool = False, + patch_size: int = 512, + overlap: int = 128, + ) -> None: + super().__init__() + if image_array is not None and (stack_file is not None or img_dir is not None): + raise ValueError( + "Please provide either an image array or a stack file or an image directory, not both." + ) + if image_array is None and stack_file is None and img_dir is None: + raise ValueError( + "Please provide either a large TIFF file or an image directory." + ) + + self.no_patching = no_patching + self.patch_size = patch_size + self.overlap = overlap + self.image_files = [] + self.image_source = [] + + if image_array is not None: + # image is already loaded as a numpy array + self.image_source = image_array + elif stack_file is not None: + # can be a large stack, using pims for lazy loading + self.image_source = pims.open(str(stack_file)) + elif img_dir is not None: + # load images from a directory + img_dir = Path(img_dir) + if not img_dir.is_dir(): + raise ValueError(f"The provided path {img_dir} is not a directory.") + self.image_files = ( + list(img_dir.glob("*.tiff")) + + list(img_dir.glob("*.tif")) + + list(img_dir.glob("*.png")) + + list(img_dir.glob("*.jpg")) + ) + if not self.image_files: + raise ValueError(f"No image files found in the directory {img_dir}.") + self.image_files = self._natural_sort(self.image_files) + self.image_source = pims.ImageSequence(map(str, self.image_files)) + + def __iter__(self): + for img_idx, img_slice in enumerate(self.image_source[:]): + img_tensor = get_model_ready_image(img_slice) + if self.no_patching: + # return the whole image as a tensor + yield img_tensor, torch.tensor([img_idx, 0]) + else: + # divide the image into patches and yield them + patches = patchify(img_tensor.unsqueeze(0), self.patch_size, self.overlap) + for p_idx, patch in enumerate(patches): + yield patch, torch.tensor([img_idx, p_idx]) + + def _natural_sort(self, files: Iterable[Path]) -> list[Path]: + """Sort files in a natural order.""" + return sorted(files, key=lambda x: natural_sorted(str(x.name))) From 928cdb03cbaab6cc7f9e6be1a51d225debf2af93 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Thu, 12 Jun 2025 17:29:35 +0200 Subject: [PATCH 07/35] init using the new dataset for feature extraction --- .../_feature_extractor_widget.py | 7 +- src/featureforest/models/util.py | 36 +++-- src/featureforest/utils/data.py | 32 ++-- src/featureforest/utils/dataset.py | 13 +- src/featureforest/utils/extract.py | 125 +++++++++++---- .../embedding_extraction.py | 56 ++++--- .../model_adapter_tests/test_dino_adapter.py | 128 ++++++++-------- .../test_mobilesam_adapter.py | 142 +++++++++--------- 8 files changed, 318 insertions(+), 221 deletions(-) diff --git a/src/featureforest/_feature_extractor_widget.py b/src/featureforest/_feature_extractor_widget.py index da34399..47582e8 100644 --- a/src/featureforest/_feature_extractor_widget.py +++ b/src/featureforest/_feature_extractor_widget.py @@ -1,6 +1,7 @@ import csv import time from pathlib import Path +from typing import Optional import napari import napari.utils.notifications as notif @@ -38,7 +39,7 @@ def __init__(self, napari_viewer: napari.Viewer): self.viewer = napari_viewer self.extract_worker = None self.model_adapter = None - self.timing = {"start": 0, "avg_per_slice": 0} + self.timing = {"start": 0.0, "avg_per_slice": 0.0} self.prepare_widget() def prepare_widget(self): @@ -129,7 +130,7 @@ def prepare_widget(self): self.base_layout.addWidget(gbox) self.base_layout.addStretch(1) - def check_input_layers(self, event: Event = None): + def check_input_layers(self, event: Optional[Event] = None): curr_text = self.image_combo.currentText() self.image_combo.clear() for layer in self.viewer.layers: @@ -181,7 +182,7 @@ def extract_embeddings(self): self.extract_worker = create_worker( extract_embeddings_to_file, image=image_layer.data, - storage_file_path=storage_path, + storage_path=storage_path, model_adapter=self.model_adapter, ) self.extract_worker.yielded.connect(self.update_extract_progress) diff --git a/src/featureforest/models/util.py b/src/featureforest/models/util.py index f1fcd12..3df7768 100644 --- a/src/featureforest/models/util.py +++ b/src/featureforest/models/util.py @@ -1,12 +1,12 @@ import os from pathlib import Path -from typing import Union, Optional +from typing import Optional, Union -import numpy as np import imageio.v3 as imageio +import numpy as np -from . import _MODELS_DICT, get_model from ..utils.extract import extract_embeddings_to_file, get_stack_dims +from . import _MODELS_DICT, get_model def extract_features( @@ -34,7 +34,9 @@ def extract_features( _, image_height, image_width = get_stack_dims(image) # Step 1: Get the desired model adapter. - model_adapter = get_model(model_name=model_name, img_height=image_height, img_width=image_width) + model_adapter = get_model( + model_name=model_name, img_height=image_height, img_width=image_width + ) # - Transform the inputs transformed_image = model_adapter.input_transforms(image) @@ -45,7 +47,9 @@ def extract_features( output_path = str(Path(output_path).with_suffix(".hdf5")) extractor_generator = extract_embeddings_to_file( - image=transformed_image, storage_file_path=output_path, model_adapter=model_adapter, + image=transformed_image, + storage_path=output_path, + model_adapter=model_adapter, ) # Step 3: Run the extractor generator till the end @@ -64,15 +68,23 @@ def main(): parser = argparse.ArgumentParser(description="Extract features for a chosen model.") parser.add_argument( - "-i", "--input_path", type=str, required=True, + "-i", + "--input_path", + type=str, + required=True, help="The filepath to the image data. Supports all data types that can be read by imageio (eg. tif, png, ...).", ) parser.add_argument( - "-o", "--output_path", type=str, required=True, + "-o", + "--output_path", + type=str, + required=True, help="The filepath to store the extracted features. The current supports store features in a 'hdf5' file.", ) parser.add_argument( - "--model_choice", type=str, default="SAM2_Large", + "--model_choice", + type=str, + default="SAM2_Large", help=f"The choice of vision foundation model that will be used, one of ({available_models}). " "By default, extracts features for 'SAM2_Large'.", ) @@ -84,6 +96,10 @@ def main(): image = imageio.imread(args.input_path) # Extract the features. - output_path = extract_features(image=image, output_path=args.output_path, model_name=args.model_choice) + output_path = extract_features( + image=image, output_path=args.output_path, model_name=args.model_choice + ) - print(f"The features of '{args.model_choice}' have been extracted at '{os.path.abspath(output_path)}'.") + print( + f"The features of '{args.model_choice}' have been extracted at '{os.path.abspath(output_path)}'." + ) diff --git a/src/featureforest/utils/data.py b/src/featureforest/utils/data.py index 9bd4bed..9981947 100644 --- a/src/featureforest/utils/data.py +++ b/src/featureforest/utils/data.py @@ -14,23 +14,24 @@ def get_model_ready_image(image: np.ndarray) -> torch.Tensor: Returns: torch.Tensor: The input image as a torch tensor, normalized to [0, 1]. """ + assert image.ndim < 4, "Input image must be 2D or 3D (single channel or RGB)." # image to torch tensor img_data = torch.from_numpy(image.copy()).to(torch.float32) # normalize in [0, 1] _min = img_data.min() _max = img_data.max() img_data = (img_data - _min) / (_max - _min) - # for sam the input image should be 4D: BxCxHxW ; an RGB image. - if not is_stacked(img_data.numpy()): - # add a batch dim - img_data = img_data.unsqueeze(0) + # for image encoders, the input image must be in RGB. + # if not is_stacked(img_data.numpy()): + # # add a batch dim + # img_data = img_data.unsqueeze(0) if is_image_rgb(img_data.numpy()): # it's already RGB img_data = img_data[..., :3] # discard the alpha channel (in case of PNG). - img_data = img_data.permute([0, 3, 1, 2]) # make it channel first + img_data = img_data.permute([3, 1, 2]) # make it channel first else: # make it RGB by repeating the single channel - img_data = img_data.unsqueeze(1).expand(-1, 3, -1, -1) + img_data = img_data.unsqueeze(0).expand(3, -1, -1) return img_data @@ -112,33 +113,34 @@ def get_paddings( def patchify( - images: Tensor, patch_size: Optional[int] = None, overlap: Optional[int] = None + image: Tensor, patch_size: Optional[int] = None, overlap: Optional[int] = None ) -> Tensor: """Divide images into patches. - images: (B, C, H, W) - out: (B*N, C, patch_size, patch_size) + image: (C, H, W) + out: (N, C, patch_size, patch_size) Args: - images (Tensor): a batch of images of shape (B, C, H, W) + images (Tensor): an image of shape (C, H, W) patch_size (Optional[int], optional): patch size. Defaults to None. overlap (Optional[int], optional): patch overlap. Defaults to None. Returns: - Tensor: patches of the input batch of shape (B*N, C, patch_size, patch_size) + Tensor: patches of shape (N, C, patch_size, patch_size) """ - _, c, img_height, img_width = images.shape + c, img_height, img_width = image.shape if patch_size is None: patch_size = get_patch_size(img_height, img_width) - overlap = patch_size // 2 + overlap = patch_size // 4 if overlap is None: - overlap = patch_size // 2 + overlap = patch_size // 4 stride, margin = get_stride_margin(patch_size, overlap) pad_right, pad_bottom = get_paddings( patch_size, stride, margin, img_height, img_width ) pad = (margin, pad_right + margin, margin, pad_bottom + margin) - padded_imgs = F.pad(images, pad=pad, mode="reflect") + # add batch dim and pad the image + padded_imgs = F.pad(image.unsqueeze(0), pad=pad, mode="reflect") # making patches using torch unfold method patches = padded_imgs.unfold(2, patch_size, step=stride).unfold( 3, patch_size, step=stride diff --git a/src/featureforest/utils/dataset.py b/src/featureforest/utils/dataset.py index 76c4ea2..b491915 100644 --- a/src/featureforest/utils/dataset.py +++ b/src/featureforest/utils/dataset.py @@ -1,5 +1,6 @@ from collections.abc import Iterable from pathlib import Path +from typing import Optional import numpy as np import pims @@ -43,11 +44,14 @@ def __init__( self.patch_size = patch_size self.overlap = overlap self.image_files = [] - self.image_source = [] + self.image_source: Optional[pims.ImageSequence | np.ndarray] = None if image_array is not None: # image is already loaded as a numpy array self.image_source = image_array + # add slice dimension if not present + if self.image_source.ndim == 2: + self.image_source = self.image_source[np.newaxis, ...] elif stack_file is not None: # can be a large stack, using pims for lazy loading self.image_source = pims.open(str(stack_file)) @@ -68,14 +72,17 @@ def __init__( self.image_source = pims.ImageSequence(map(str, self.image_files)) def __iter__(self): - for img_idx, img_slice in enumerate(self.image_source[:]): + if self.image_source is None: + raise ValueError("No image source is available. Please check the input data.") + + for img_idx, img_slice in enumerate(self.image_source): img_tensor = get_model_ready_image(img_slice) if self.no_patching: # return the whole image as a tensor yield img_tensor, torch.tensor([img_idx, 0]) else: # divide the image into patches and yield them - patches = patchify(img_tensor.unsqueeze(0), self.patch_size, self.overlap) + patches = patchify(img_tensor, self.patch_size, self.overlap) for p_idx, patch in enumerate(patches): yield patch, torch.tensor([img_idx, p_idx]) diff --git a/src/featureforest/utils/extract.py b/src/featureforest/utils/extract.py index 3349807..c5be15f 100644 --- a/src/featureforest/utils/extract.py +++ b/src/featureforest/utils/extract.py @@ -1,20 +1,75 @@ from collections.abc import Generator +from pathlib import Path from typing import Optional, Union import h5py import numpy as np import torch from napari.utils import progress as np_progress +from torch.utils.data import DataLoader from featureforest.models import BaseModelAdapter from featureforest.models.SAM import SAMAdapter from featureforest.utils.data import ( - get_stack_dims, get_stride_margin, - image_to_uint8, is_image_rgb, patchify, ) +from featureforest.utils.dataset import FFImageDataset + + +def get_batch_size(model_adapter: BaseModelAdapter) -> int: + """Get the batch size for the model adapter. + The batch size is set to 8 for most models, but for SAMAdapter it is set to 2 + to avoid memory issues with large images. + Args: + model_adapter (BaseModelAdapter): The model adapter to get the batch size for. + Returns: + int: The batch size for the model adapter. + """ + # set a low batch size + batch_size = 8 + # for big SAM we need even lower batch size :( + if isinstance(model_adapter, SAMAdapter): + batch_size = 2 + return batch_size + + +def get_dataset( + image: str | np.ndarray, + no_patching: bool = False, + patch_size: int = 512, + overlap: int = 128, +) -> FFImageDataset: + if isinstance(image, str): + # image is a path to a large TIFF file or a directory of images + img_path = Path(image) + if img_path.is_dir(): + # load images from a directory + dataset = FFImageDataset( + img_dir=img_path, + no_patching=no_patching, + patch_size=patch_size, + overlap=overlap, + ) + else: + # load a (large) stack + dataset = FFImageDataset( + stack_file=img_path, + no_patching=no_patching, + patch_size=patch_size, + overlap=overlap, + ) + elif isinstance(image, np.ndarray): + # image is already loaded as a numpy array + dataset = FFImageDataset( + image_array=image, + no_patching=no_patching, + patch_size=patch_size, + overlap=overlap, + ) + + return dataset def get_slice_features( @@ -87,37 +142,55 @@ def get_slice_features( if storage_group is not None: # to take care of the last batch size that might be smaller than batch_size num_out = slice_features.shape[0] - dataset[start: start + num_out] = slice_features.to( - torch.float16).cpu().numpy() + dataset[start : start + num_out] = ( + slice_features.to(torch.float16).cpu().numpy() + ) yield b_idx else: yield b_idx, slice_features.numpy() def extract_embeddings_to_file( - image: np.ndarray, storage_file_path: str, model_adapter: BaseModelAdapter + image: np.ndarray | str, storage_path: str, model_adapter: BaseModelAdapter ) -> Generator[tuple[int, int], None, None]: + no_patching = model_adapter.no_patching patch_size = model_adapter.patch_size overlap = model_adapter.overlap + batch_size = get_batch_size(model_adapter) + + dataset = get_dataset( + image=image, no_patching=no_patching, patch_size=patch_size, overlap=overlap + ) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) + for img_data, indices in dataloader: + print(f"images: {img_data.shape}\nslices: {indices}") + features = model_adapter.get_features_patches(img_data.to(model_adapter.device)) + print(f"features shape: {features.shape}") + unique_slices = torch.unique(indices[:, 0]).numpy() + print(f"unique slices: {unique_slices}") + for idx in unique_slices: + img_features = features[indices[:, 0] == idx].numpy().astype(np.float16) + print(f"image: {idx}, features shape: {img_features.shape}") + yield idx, len(unique_slices) + + # with h5py.File(storage_path, "w") as storage: + # num_slices, img_height, img_width = get_stack_dims(image) + + # storage.attrs["num_slices"] = num_slices + # storage.attrs["img_height"] = img_height + # storage.attrs["img_width"] = img_width + # storage.attrs["model"] = model_adapter.name + # storage.attrs["patch_size"] = patch_size + # storage.attrs["overlap"] = overlap + + # for slice_index in np_progress( + # range(num_slices), desc="extract features for slices" + # ): + # print(f"\nslice index: {slice_index}") + # slice_img = image[slice_index].copy() if num_slices > 1 else image.copy() + # slice_img = image_to_uint8(slice_img) # image must be an uint8 array + # slice_grp = storage.create_group(str(slice_index)) + # for _ in get_slice_features(slice_img, model_adapter, slice_grp): + # pass - with h5py.File(storage_file_path, "w") as storage: - num_slices, img_height, img_width = get_stack_dims(image) - - storage.attrs["num_slices"] = num_slices - storage.attrs["img_height"] = img_height - storage.attrs["img_width"] = img_width - storage.attrs["model"] = model_adapter.name - storage.attrs["patch_size"] = patch_size - storage.attrs["overlap"] = overlap - - for slice_index in np_progress( - range(num_slices), desc="extract features for slices" - ): - print(f"\nslice index: {slice_index}") - slice_img = image[slice_index].copy() if num_slices > 1 else image.copy() - slice_img = image_to_uint8(slice_img) # image must be an uint8 array - slice_grp = storage.create_group(str(slice_index)) - for _ in get_slice_features(slice_img, model_adapter, slice_grp): - pass - - yield slice_index, num_slices + # yield slice_index, num_slices diff --git a/tests/model_adapter_tests/embedding_extraction.py b/tests/model_adapter_tests/embedding_extraction.py index 21fc282..54bae5e 100644 --- a/tests/model_adapter_tests/embedding_extraction.py +++ b/tests/model_adapter_tests/embedding_extraction.py @@ -1,36 +1,34 @@ -from tempfile import TemporaryDirectory +# from tempfile import TemporaryDirectory -import h5py +# import h5py -from featureforest.utils.extract import extract_embeddings_to_file +# from featureforest.utils.extract import extract_embeddings_to_file -def check_embedding_extraction( - test_image, model_adapter, expected_output_shape, expected_slices -): - with TemporaryDirectory() as tmp_dir: - tmp_file = tmp_dir + "/tmp.h5" +# def check_embedding_extraction( +# test_image, model_adapter, expected_output_shape, expected_slices +# ): +# with TemporaryDirectory() as tmp_dir: +# tmp_file = tmp_dir + "/tmp.h5" - extractor_generator = extract_embeddings_to_file( - image=test_image, - storage_file_path=tmp_file, - model_adapter=model_adapter - ) +# extractor_generator = extract_embeddings_to_file( +# image=test_image, storage_path=tmp_file, model_adapter=model_adapter +# ) - # Run the extractor generator till the end - _ = list(extractor_generator) +# # Run the extractor generator till the end +# _ = list(extractor_generator) - with h5py.File(tmp_file, "r") as read_storage: - slices = list(read_storage.keys()) - assert ( - len(slices) == expected_slices - ), f"Unexpected number of slices: {len(slices)}, expected: {expected_slices}" - for slice in slices: - slice_key = str(slice) - slice_dataset = read_storage[slice_key].get(model_adapter.name) - assert ( - slice_dataset is not None - ), f"The dataset for slice {slice_key} is empty" - assert ( - slice_dataset.shape == expected_output_shape - ), f"Unexpected dataset shape: {slice_dataset.shape}, expected: {expected_output_shape}" \ No newline at end of file +# with h5py.File(tmp_file, "r") as read_storage: +# slices = list(read_storage.keys()) +# assert len(slices) == expected_slices, ( +# f"Unexpected number of slices: {len(slices)}, expected: {expected_slices}" +# ) +# for slice in slices: +# slice_key = str(slice) +# slice_dataset = read_storage[slice_key].get(model_adapter.name) +# assert slice_dataset is not None, ( +# f"The dataset for slice {slice_key} is empty" +# ) +# assert slice_dataset.shape == expected_output_shape, ( +# f"Unexpected dataset shape: {slice_dataset.shape}, expected: {expected_output_shape}" +# ) diff --git a/tests/model_adapter_tests/test_dino_adapter.py b/tests/model_adapter_tests/test_dino_adapter.py index 7b8fbbc..839902e 100644 --- a/tests/model_adapter_tests/test_dino_adapter.py +++ b/tests/model_adapter_tests/test_dino_adapter.py @@ -1,77 +1,77 @@ -import numpy as np -import pytest -import torch -import torch.nn as nn +# import numpy as np +# import pytest +# import torch +# import torch.nn as nn -from featureforest.models.DinoV2 import DinoV2Adapter, get_model -from featureforest.utils.data import get_stack_dims -from embedding_extraction import check_embedding_extraction +# from featureforest.models.DinoV2 import DinoV2Adapter, get_model +# from featureforest.utils.data import get_stack_dims +# from embedding_extraction import check_embedding_extraction -class MockDinoEncoder(nn.Module): - def __init__(self): - super().__init__() - self.dino_patch_size = 14 - self.dino_out_channels = 384 - self.height = 70 - self.width = 70 +# class MockDinoEncoder(nn.Module): +# def __init__(self): +# super().__init__() +# self.dino_patch_size = 14 +# self.dino_out_channels = 384 +# self.height = 70 +# self.width = 70 - def get_intermediate_layers(self, x, *args, **kwargs): - batch_size = x.shape[0] - output = torch.ones(batch_size, self.dino_out_channels, self.height, self.width) - return output, None +# def get_intermediate_layers(self, x, *args, **kwargs): +# batch_size = x.shape[0] +# output = torch.ones(batch_size, self.dino_out_channels, self.height, self.width) +# return output, None -def get_mock_model(img_height: float, img_width: float) -> DinoV2Adapter: - model = MockDinoEncoder() - device = torch.device("cpu") - dino_model_adapter = DinoV2Adapter(model, img_height, img_width, device) - return dino_model_adapter +# def get_mock_model(img_height: float, img_width: float) -> DinoV2Adapter: +# model = MockDinoEncoder() +# device = torch.device("cpu") +# dino_model_adapter = DinoV2Adapter(model, img_height, img_width, device) +# return dino_model_adapter -@pytest.mark.slow() -@pytest.mark.parametrize( - "test_patch", - [ - torch.ones((1, 3, 128, 128)), - torch.ones((3, 3, 128, 128)), - torch.ones((1, 3, 256, 256)), - torch.ones((1, 3, 512, 512)), - ], -) -def test_mock_adapter(test_patch: np.ndarray): - img_height, img_width = test_patch.shape[-2:] - real_adapter = get_model(img_height, img_width) - mock_adapter = get_mock_model(img_height, img_width) +# @pytest.mark.slow() +# @pytest.mark.parametrize( +# "test_patch", +# [ +# torch.ones((1, 3, 128, 128)), +# torch.ones((3, 3, 128, 128)), +# torch.ones((1, 3, 256, 256)), +# torch.ones((1, 3, 512, 512)), +# ], +# ) +# def test_mock_adapter(test_patch: np.ndarray): +# img_height, img_width = test_patch.shape[-2:] +# real_adapter = get_model(img_height, img_width) +# mock_adapter = get_mock_model(img_height, img_width) - transformed_input_patch_real = real_adapter.input_transforms(test_patch) - transformed_input_patch_mock = mock_adapter.input_transforms(test_patch) +# transformed_input_patch_real = real_adapter.input_transforms(test_patch) +# transformed_input_patch_mock = mock_adapter.input_transforms(test_patch) - result_real = real_adapter.model.get_intermediate_layers( - transformed_input_patch_real, 1, return_class_token=False, reshape=True - )[0] - mock_result = mock_adapter.model.get_intermediate_layers( - transformed_input_patch_mock - )[0] +# result_real = real_adapter.model.get_intermediate_layers( +# transformed_input_patch_real, 1, return_class_token=False, reshape=True +# )[0] +# mock_result = mock_adapter.model.get_intermediate_layers( +# transformed_input_patch_mock +# )[0] - assert len(result_real) == len(mock_result) - assert result_real[0].shape == mock_result[0].shape +# assert len(result_real) == len(mock_result) +# assert result_real[0].shape == mock_result[0].shape -@pytest.mark.parametrize( - "test_image, expected_output_shape, expected_slices", - [ - (np.ones((256, 256)), (49, 42, 42, 384), 1), # 2D - (np.ones((256, 256, 3)), (49, 42, 42, 384), 1), # 2D RGB - (np.ones((2, 256, 256)), (49, 42, 42, 384), 2), # 3D - (np.ones((2, 256, 256, 3)), (49, 42, 42, 384), 2), # 3D RGB - ], -) -def test_dino_embedding_extraction( - test_image: np.ndarray, expected_output_shape: tuple, expected_slices: int -): - num_slices, img_height, img_width = get_stack_dims(test_image) - model_adapter = get_mock_model(img_height, img_width) - check_embedding_extraction( - test_image, model_adapter, expected_output_shape, expected_slices - ) +# @pytest.mark.parametrize( +# "test_image, expected_output_shape, expected_slices", +# [ +# (np.ones((256, 256)), (49, 42, 42, 384), 1), # 2D +# (np.ones((256, 256, 3)), (49, 42, 42, 384), 1), # 2D RGB +# (np.ones((2, 256, 256)), (49, 42, 42, 384), 2), # 3D +# (np.ones((2, 256, 256, 3)), (49, 42, 42, 384), 2), # 3D RGB +# ], +# ) +# def test_dino_embedding_extraction( +# test_image: np.ndarray, expected_output_shape: tuple, expected_slices: int +# ): +# num_slices, img_height, img_width = get_stack_dims(test_image) +# model_adapter = get_mock_model(img_height, img_width) +# check_embedding_extraction( +# test_image, model_adapter, expected_output_shape, expected_slices +# ) diff --git a/tests/model_adapter_tests/test_mobilesam_adapter.py b/tests/model_adapter_tests/test_mobilesam_adapter.py index 7c00491..0cfb159 100644 --- a/tests/model_adapter_tests/test_mobilesam_adapter.py +++ b/tests/model_adapter_tests/test_mobilesam_adapter.py @@ -1,84 +1,84 @@ -import numpy as np -import pytest -import torch -import torch.nn as nn +# import numpy as np +# import pytest +# import torch +# import torch.nn as nn -from featureforest.models.MobileSAM import MobileSAMAdapter, get_model -from featureforest.utils.data import get_stack_dims -from embedding_extraction import check_embedding_extraction +# from featureforest.models.MobileSAM import MobileSAMAdapter, get_model +# from featureforest.utils.data import get_stack_dims +# from embedding_extraction import check_embedding_extraction -class MockMobileSAMEncoder(nn.Module): - def __init__(self): - super().__init__() - self.image_encoder = self.mock_encode - self.encoder_num_channels = 256 - self.embed_layer_num_channels = 64 +# class MockMobileSAMEncoder(nn.Module): +# def __init__(self): +# super().__init__() +# self.image_encoder = self.mock_encode +# self.encoder_num_channels = 256 +# self.embed_layer_num_channels = 64 - def mock_encode(self, x): - batch_size = x.shape[0] - output = torch.ones( - batch_size, - self.encoder_num_channels, - self.embed_layer_num_channels, - self.embed_layer_num_channels, - ) - embed_output = torch.ones( - batch_size, - self.embed_layer_num_channels, - self.encoder_num_channels, - self.encoder_num_channels, - ) - return output, embed_output, None +# def mock_encode(self, x): +# batch_size = x.shape[0] +# output = torch.ones( +# batch_size, +# self.encoder_num_channels, +# self.embed_layer_num_channels, +# self.embed_layer_num_channels, +# ) +# embed_output = torch.ones( +# batch_size, +# self.embed_layer_num_channels, +# self.encoder_num_channels, +# self.encoder_num_channels, +# ) +# return output, embed_output, None -def get_mock_model(img_height: float, img_width: float) -> MobileSAMAdapter: - model = MockMobileSAMEncoder() - device = torch.device("cpu") - sam_model_adapter = MobileSAMAdapter(model, img_height, img_width, device) - return sam_model_adapter +# def get_mock_model(img_height: float, img_width: float) -> MobileSAMAdapter: +# model = MockMobileSAMEncoder() +# device = torch.device("cpu") +# sam_model_adapter = MobileSAMAdapter(model, img_height, img_width, device) +# return sam_model_adapter -@pytest.mark.slow() -@pytest.mark.parametrize( - "test_patch", - [ - torch.ones((1, 3, 128, 128)), - torch.ones((3, 3, 128, 128)), - torch.ones((8, 3, 128, 128)), - torch.ones((8, 3, 256, 256)), - torch.ones((8, 3, 512, 512)) - ], -) -def test_mock_adapter(test_patch: np.ndarray): - real_adapter = get_model(512, 512) - mock_adapter = get_mock_model(512, 512) +# @pytest.mark.slow() +# @pytest.mark.parametrize( +# "test_patch", +# [ +# torch.ones((1, 3, 128, 128)), +# torch.ones((3, 3, 128, 128)), +# torch.ones((8, 3, 128, 128)), +# torch.ones((8, 3, 256, 256)), +# torch.ones((8, 3, 512, 512)) +# ], +# ) +# def test_mock_adapter(test_patch: np.ndarray): +# real_adapter = get_model(512, 512) +# mock_adapter = get_mock_model(512, 512) - transformed_input_patch_real = real_adapter.input_transforms(test_patch) - transformed_input_patch_mock = mock_adapter.input_transforms(test_patch) +# transformed_input_patch_real = real_adapter.input_transforms(test_patch) +# transformed_input_patch_mock = mock_adapter.input_transforms(test_patch) - result_real = real_adapter.encoder(transformed_input_patch_real) - mock_result = mock_adapter.encoder(transformed_input_patch_mock) +# result_real = real_adapter.encoder(transformed_input_patch_real) +# mock_result = mock_adapter.encoder(transformed_input_patch_mock) - assert len(result_real) == len(mock_result) - assert result_real[0].shape == mock_result[0].shape - assert result_real[1].shape == mock_result[1].shape +# assert len(result_real) == len(mock_result) +# assert result_real[0].shape == mock_result[0].shape +# assert result_real[1].shape == mock_result[1].shape -@pytest.mark.parametrize( - "test_image, expected_output_shape, expected_slices", - [ - (np.ones((256, 256)), (9, 128, 128, 320), 1), # 2D - (np.ones((256, 256, 3)), (9, 128, 128, 320), 1), # 2D RGB - (np.ones((2, 256, 256)), (9, 128, 128, 320), 2), # 3D - (np.ones((2, 256, 256, 3)), (9, 128, 128, 320), 2) # 3D RGB - ], -) -def test_mobilesam_embedding_extraction( - test_image: np.ndarray, expected_output_shape: tuple, expected_slices: int -): - num_slices, img_height, img_width = get_stack_dims(test_image) - model_adapter = get_mock_model(img_height, img_width) - check_embedding_extraction( - test_image, model_adapter, expected_output_shape, expected_slices - ) +# @pytest.mark.parametrize( +# "test_image, expected_output_shape, expected_slices", +# [ +# (np.ones((256, 256)), (9, 128, 128, 320), 1), # 2D +# (np.ones((256, 256, 3)), (9, 128, 128, 320), 1), # 2D RGB +# (np.ones((2, 256, 256)), (9, 128, 128, 320), 2), # 3D +# (np.ones((2, 256, 256, 3)), (9, 128, 128, 320), 2) # 3D RGB +# ], +# ) +# def test_mobilesam_embedding_extraction( +# test_image: np.ndarray, expected_output_shape: tuple, expected_slices: int +# ): +# num_slices, img_height, img_width = get_stack_dims(test_image) +# model_adapter = get_mock_model(img_height, img_width) +# check_embedding_extraction( +# test_image, model_adapter, expected_output_shape, expected_slices +# ) From 1ce1ea00cbe947375edfdb7679cef34dfbbda943 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Fri, 13 Jun 2025 13:52:37 +0200 Subject: [PATCH 08/35] init using zarr storage --- .../_feature_extractor_widget.py | 11 +- src/featureforest/utils/dataset.py | 16 ++ src/featureforest/utils/extract.py | 181 +++++++----------- .../utils/pipeline_prediction.py | 29 +-- 4 files changed, 105 insertions(+), 132 deletions(-) diff --git a/src/featureforest/_feature_extractor_widget.py b/src/featureforest/_feature_extractor_widget.py index 47582e8..6c4e75b 100644 --- a/src/featureforest/_feature_extractor_widget.py +++ b/src/featureforest/_feature_extractor_widget.py @@ -146,14 +146,17 @@ def save_storage(self): # default storage name image_layer_name = self.image_combo.currentText() model_name = self.model_combo.currentText() - storage_name = f"{image_layer_name}_{model_name}.hdf5" + storage_name = f"{image_layer_name}_{model_name}" + if self.no_patching_checkbox.isChecked(): + storage_name += "_no_patching" + storage_name += ".zarr" # open the save dialog selected_file, _filter = QFileDialog.getSaveFileName( - self, "FeatureForest", storage_name, "Embeddings Storage(*.hdf5)" + self, "FeatureForest", storage_name, "Zarr Storage(*.zarr)" ) if selected_file is not None and len(selected_file) > 0: - if not selected_file.endswith(".hdf5"): - selected_file += ".hdf5" + if not selected_file.endswith(".zarr"): + selected_file += ".zarr" self.storage_textbox.setText(selected_file) self.extract_button.setEnabled(True) diff --git a/src/featureforest/utils/dataset.py b/src/featureforest/utils/dataset.py index b491915..d61857f 100644 --- a/src/featureforest/utils/dataset.py +++ b/src/featureforest/utils/dataset.py @@ -71,6 +71,22 @@ def __init__( self.image_files = self._natural_sort(self.image_files) self.image_source = pims.ImageSequence(map(str, self.image_files)) + @property + def num_images(self) -> int: + """Return the number of images in the dataset.""" + if self.image_source is None: + return 0 + return len(self.image_source) + + @property + def image_shape(self) -> tuple[int, int]: + """Return the shape of the images in the dataset.""" + if self.image_source is None: + raise ValueError("No image source is available. Please check the input data.") + if isinstance(self.image_source, np.ndarray): + return self.image_source.shape[1:] + return self.image_source.frame_shape + def __iter__(self): if self.image_source is None: raise ValueError("No image source is available. Please check the input data.") diff --git a/src/featureforest/utils/extract.py b/src/featureforest/utils/extract.py index c5be15f..3cc4ffe 100644 --- a/src/featureforest/utils/extract.py +++ b/src/featureforest/utils/extract.py @@ -1,20 +1,17 @@ from collections.abc import Generator from pathlib import Path -from typing import Optional, Union +from typing import Optional -import h5py import numpy as np import torch -from napari.utils import progress as np_progress +import zarr +import zarr.core +import zarr.storage +from numcodecs import Zstd from torch.utils.data import DataLoader from featureforest.models import BaseModelAdapter from featureforest.models.SAM import SAMAdapter -from featureforest.utils.data import ( - get_stride_margin, - is_image_rgb, - patchify, -) from featureforest.utils.dataset import FFImageDataset @@ -35,7 +32,7 @@ def get_batch_size(model_adapter: BaseModelAdapter) -> int: return batch_size -def get_dataset( +def get_image_dataset( image: str | np.ndarray, no_patching: bool = False, patch_size: int = 512, @@ -72,82 +69,38 @@ def get_dataset( return dataset -def get_slice_features( - image: np.ndarray, +def extract_embeddings( + image: np.ndarray | str, model_adapter: BaseModelAdapter, - storage_group: Optional[h5py.Group] = None, -) -> Generator[Union[int, tuple[int, np.ndarray]], None, None]: - """Extract features for one image/slice using the given model adapter and - save them into a storage file or yield them batch by batch. - - Args: - image (np.ndarray): Input image - model_adapter (BaseModelAdapter): Model adapter to extract features from - storage_group (Optional[h5py.Group]): h5 file group where the extracted features - will be saved. If None, will yield patch features batch by batch. - Returns: - Generator[Union[int, tuple[int, np.ndarray]]]: A generator yielding either the - current batch number or a tuple containing the current batch number - and the corresponding patch features. - """ - # image to torch tensor - img_data = torch.from_numpy(image.copy()).to(torch.float32) - # normalize in [0, 1] - _min = img_data.min() - _max = img_data.max() - img_data = (img_data - _min) / (_max - _min) - # for sam the input image should be 4D: BxCxHxW ; an RGB image. - if is_image_rgb(image): - # it's already RGB, put the channels first and add a batch dim. - img_data = img_data[..., :3] # ignore the Alpha channel (in case of PNG). - img_data = img_data.permute([2, 0, 1]).unsqueeze(0) - else: - img_data = img_data.unsqueeze(0).unsqueeze(0).expand(-1, 3, -1, -1) - - # get input patches + image_dataset: Optional[FFImageDataset] = None, +) -> Generator[tuple[np.ndarray, int, int], None, None]: + no_patching = model_adapter.no_patching patch_size = model_adapter.patch_size overlap = model_adapter.overlap - data_patches = patchify(img_data, patch_size, overlap) - num_patches = len(data_patches) - - # set a low batch size - batch_size = 8 - # for big SAM we need even lower batch size :( - if isinstance(model_adapter, SAMAdapter): - batch_size = 2 - - num_batches = int(np.ceil(num_patches / batch_size)) - # prepare storage for the image embeddings - if storage_group is not None: - total_channels = model_adapter.get_total_output_channels() - stride, _ = get_stride_margin(patch_size, overlap) - dataset = storage_group.create_dataset( - model_adapter.name, - shape=(num_patches, stride, stride, total_channels), - dtype=np.float16, - compression="lzf", - ) + batch_size = get_batch_size(model_adapter) - # get sam encoder output for image patches - print("extracting features:") - for b_idx in np_progress(range(num_batches), desc="extracting features"): - print(f"batch #{b_idx + 1} of {num_batches}") - start = b_idx * batch_size - end = start + batch_size - slice_features = model_adapter.get_features_patches( - data_patches[start:end].to(model_adapter.device) + # create the image dataset + if image_dataset is None: + image_dataset = get_image_dataset( + image=image, no_patching=no_patching, patch_size=patch_size, overlap=overlap ) - if isinstance(slice_features, tuple): # model with more than one output - slice_features = torch.cat(slice_features, dim=-1) - if storage_group is not None: - # to take care of the last batch size that might be smaller than batch_size - num_out = slice_features.shape[0] - dataset[start : start + num_out] = ( - slice_features.to(torch.float16).cpu().numpy() + # loop through the dataset and extract features + dataloader = DataLoader( + image_dataset, batch_size=batch_size, shuffle=False, num_workers=0 + ) + for img_data, indices in dataloader: + # print(f"images: {img_data.shape}\nslices: {indices}") + features = model_adapter.get_features_patches(img_data.to(model_adapter.device)) + print(f"batch features shape: {features.shape}") + unique_slices = torch.unique(indices[:, 0]).numpy() + # print(f"unique slices: {unique_slices}") + for idx in unique_slices: + img_features = features[indices[:, 0] == idx].numpy().astype(np.float16) + print( + f"image: {idx}, features shape: {img_features.shape}, {img_features.dtype}" ) - yield b_idx - else: - yield b_idx, slice_features.numpy() + + yield img_features, idx, len(unique_slices) def extract_embeddings_to_file( @@ -156,41 +109,41 @@ def extract_embeddings_to_file( no_patching = model_adapter.no_patching patch_size = model_adapter.patch_size overlap = model_adapter.overlap - batch_size = get_batch_size(model_adapter) + # batch_size = get_batch_size(model_adapter) - dataset = get_dataset( + # # create the image dataset + image_dataset = get_image_dataset( image=image, no_patching=no_patching, patch_size=patch_size, overlap=overlap ) - dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) - for img_data, indices in dataloader: - print(f"images: {img_data.shape}\nslices: {indices}") - features = model_adapter.get_features_patches(img_data.to(model_adapter.device)) - print(f"features shape: {features.shape}") - unique_slices = torch.unique(indices[:, 0]).numpy() - print(f"unique slices: {unique_slices}") - for idx in unique_slices: - img_features = features[indices[:, 0] == idx].numpy().astype(np.float16) - print(f"image: {idx}, features shape: {img_features.shape}") - yield idx, len(unique_slices) - - # with h5py.File(storage_path, "w") as storage: - # num_slices, img_height, img_width = get_stack_dims(image) - - # storage.attrs["num_slices"] = num_slices - # storage.attrs["img_height"] = img_height - # storage.attrs["img_width"] = img_width - # storage.attrs["model"] = model_adapter.name - # storage.attrs["patch_size"] = patch_size - # storage.attrs["overlap"] = overlap - - # for slice_index in np_progress( - # range(num_slices), desc="extract features for slices" - # ): - # print(f"\nslice index: {slice_index}") - # slice_img = image[slice_index].copy() if num_slices > 1 else image.copy() - # slice_img = image_to_uint8(slice_img) # image must be an uint8 array - # slice_grp = storage.create_group(str(slice_index)) - # for _ in get_slice_features(slice_img, model_adapter, slice_grp): - # pass - - # yield slice_index, num_slices + # create the zarr storage + storage = zarr.storage.DirectoryStore(storage_path) + store_root = zarr.group(store=storage, overwrite=False) + store_root.attrs["num_slices"] = image_dataset.num_images + store_root.attrs["img_height"] = image_dataset.image_shape[0] + store_root.attrs["img_width"] = image_dataset.image_shape[1] + store_root.attrs["model"] = model_adapter.name + store_root.attrs["no_patching"] = no_patching + store_root.attrs["patch_size"] = patch_size + store_root.attrs["overlap"] = overlap + + for img_features, idx, total in extract_embeddings( + image, model_adapter, image_dataset + ): + if store_root.get(str(idx)) is None: + grp = store_root.create_group(str(idx)) # type: ignore + z_arr = grp.create( # type: ignore + name="features", + shape=img_features.shape, + chunks=(1,) + img_features.shape[1:], + dtype=np.float16, + compressor=Zstd(level=3), + ) + z_arr[:] = img_features + else: + # append features to the slice/image group + grp: zarr.core.Array = store_root[str(idx)]["features"] # type: ignore + grp.append(img_features) + + yield idx, total + + storage.close() diff --git a/src/featureforest/utils/pipeline_prediction.py b/src/featureforest/utils/pipeline_prediction.py index 87792e9..83a9c3c 100644 --- a/src/featureforest/utils/pipeline_prediction.py +++ b/src/featureforest/utils/pipeline_prediction.py @@ -5,7 +5,8 @@ from featureforest.models import BaseModelAdapter from featureforest.utils.data import get_num_patches, get_stride_margin -from featureforest.utils.extract import get_slice_features + +# from featureforest.utils.extract import get_slice_features def predict_patches( @@ -87,23 +88,23 @@ def extract_predict( img_height, img_width = image.shape[:2] patch_size = model_adapter.patch_size overlap = model_adapter.overlap - procs = [] + # procs = [] # prediction happens per batch of extracted features # in a separate process. with mp.Manager() as manager: result_dict = manager.dict() - for b_idx, patch_features in get_slice_features(image, model_adapter): - print(b_idx, end="\r") - proc = mp.Process( - target=predict_patches, - args=(patch_features, rf_model, model_adapter, b_idx, result_dict), - ) - procs.append(proc) - proc.start() - # wait until all processes are done - for p in procs: - if p.is_alive: - p.join() + # for b_idx, patch_features in get_slice_features(image, model_adapter): + # print(b_idx, end="\r") + # proc = mp.Process( + # target=predict_patches, + # args=(patch_features, rf_model, model_adapter, b_idx, result_dict), + # ) + # procs.append(proc) + # proc.start() + # # wait until all processes are done + # for p in procs: + # if p.is_alive: + # p.join() # collect results from each process batch_indices = sorted(result_dict.keys()) patch_masks = [result_dict[b] for b in batch_indices] From 123bfd94db1d715bdbbe045779f85fc083176f97 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Fri, 13 Jun 2025 13:56:43 +0200 Subject: [PATCH 09/35] moved up no patching checkbox --- src/featureforest/_feature_extractor_widget.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/featureforest/_feature_extractor_widget.py b/src/featureforest/_feature_extractor_widget.py index 6c4e75b..90fadc6 100644 --- a/src/featureforest/_feature_extractor_widget.py +++ b/src/featureforest/_feature_extractor_widget.py @@ -60,17 +60,17 @@ def prepare_widget(self): self.model_combo.setEditable(False) self.model_combo.addItems(get_available_models()) self.model_combo.setCurrentIndex(0) + # no-patching checkbox + self.no_patching_checkbox = QCheckBox("No &Patching") + self.no_patching_checkbox.setToolTip( + "Whether divide an image into patches or not" + ) # storage storage_label = QLabel("Features Storage File:") self.storage_textbox = QLineEdit() self.storage_textbox.setReadOnly(True) storage_button = QPushButton("Set Storage File") storage_button.clicked.connect(self.save_storage) - # no-patching checkbox - self.no_patching_checkbox = QCheckBox("No &Patching") - self.no_patching_checkbox.setToolTip( - "Whether divide an image into patches or not" - ) # extract button self.extract_button = QPushButton("Extract Features") self.extract_button.setEnabled(False) @@ -99,6 +99,7 @@ def prepare_widget(self): vbox.addWidget(self.image_combo) vbox.addWidget(model_label) vbox.addWidget(self.model_combo) + vbox.addWidget(self.no_patching_checkbox) layout.addLayout(vbox) vbox = QVBoxLayout() @@ -109,7 +110,6 @@ def prepare_widget(self): hbox.addWidget(self.storage_textbox) hbox.addWidget(storage_button) vbox.addLayout(hbox) - vbox.addWidget(self.no_patching_checkbox) hbox = QHBoxLayout() hbox.setContentsMargins(0, 0, 0, 0) hbox.addWidget(self.extract_button, alignment=Qt.AlignLeft) From 2d8d91e6443584282ab38fa241ed00e3136d6804 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Fri, 13 Jun 2025 14:40:08 +0200 Subject: [PATCH 10/35] updated extractor widget and no_patching option --- src/featureforest/_feature_extractor_widget.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/featureforest/_feature_extractor_widget.py b/src/featureforest/_feature_extractor_widget.py index 90fadc6..666402d 100644 --- a/src/featureforest/_feature_extractor_widget.py +++ b/src/featureforest/_feature_extractor_widget.py @@ -54,6 +54,7 @@ def prepare_widget(self): # input layer input_label = QLabel("Image Layer:") self.image_combo = QComboBox() + self.image_combo.currentIndexChanged.connect(self.image_changed) # model selection model_label = QLabel("Encoder Model:") self.model_combo = QComboBox() @@ -63,7 +64,8 @@ def prepare_widget(self): # no-patching checkbox self.no_patching_checkbox = QCheckBox("No &Patching") self.no_patching_checkbox.setToolTip( - "Whether divide an image into patches or not" + "Whether divide an image into patches or not; " + "\nOnly works for square images (height=width)" ) # storage storage_label = QLabel("Features Storage File:") @@ -142,6 +144,19 @@ def check_input_layers(self, event: Optional[Event] = None): if index > -1: self.image_combo.setCurrentIndex(index) + def image_changed(self, event: Optional[Event] = None) -> None: + # check if image is square so we can do no_patching + image_layer = get_layer( + self.viewer, self.image_combo.currentText(), config.NAPARI_IMAGE_LAYER + ) + if image_layer is not None: + _, img_height, img_width = get_stack_dims(image_layer.data) + if img_height != img_width: + self.no_patching_checkbox.setChecked(False) + self.no_patching_checkbox.setEnabled(False) + else: + self.no_patching_checkbox.setEnabled(True) + def save_storage(self): # default storage name image_layer_name = self.image_combo.currentText() @@ -173,6 +188,7 @@ def extract_embeddings(self): if storage_path is None or len(storage_path) < 6: notif.show_error("No storage path was set.") return + # initialize the selected model _, img_height, img_width = get_stack_dims(image_layer.data) model_name = self.model_combo.currentText() From b5e6edfb945790a49bdb5b9ca13036feeea66c1b Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Fri, 13 Jun 2025 15:34:41 +0200 Subject: [PATCH 11/35] updated segmentation widget using zarr storage --- src/featureforest/_segmentation_widget.py | 39 +++++++++++------------ 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/src/featureforest/_segmentation_widget.py b/src/featureforest/_segmentation_widget.py index aca9e8e..2b35528 100644 --- a/src/featureforest/_segmentation_widget.py +++ b/src/featureforest/_segmentation_widget.py @@ -8,12 +8,12 @@ from pathlib import Path from typing import Optional -import h5py import napari import napari.layers import napari.utils.notifications as notif import numpy as np import tifffile +import zarr from napari.qt.threading import create_worker from napari.utils import progress as np_progress from napari.utils.events import Event @@ -64,16 +64,16 @@ class SegmentationWidget(QWidget): def __init__(self, napari_viewer: napari.Viewer) -> None: super().__init__() self.viewer = napari_viewer - self.image_layer: Optional[napari.layers.Image] = None - self.gt_layer: Optional[napari.layers.Labels] = None - self.segmentation_layer: Optional[napari.layers.Labels] = None - self.postprocess_layer: Optional[napari.layers.Labels] = None - self.storage: Optional[str] = None - self.rf_model: Optional[RandomForestClassifier] = None - self.model_adapter: Optional[BaseModelAdapter] = None + self.image_layer: napari.layers.Image | None = None + self.gt_layer: napari.layers.Labels | None = None + self.segmentation_layer: napari.layers.Labels | None = None + self.postprocess_layer: napari.layers.Labels | None = None + self.storage: zarr.Group | None = None # type: ignore + self.rf_model: RandomForestClassifier | None = None + self.model_adapter: BaseModelAdapter | None = None self.sam_auto_masks = None self.patch_size = 512 # default values - self.overlap = 384 + self.overlap = 512 // 4 self.stride = self.patch_size - self.overlap self.stats = SegmentationUsageStats() @@ -547,13 +547,11 @@ def sam_auto_post_checked(self, checked: bool) -> None: self.sam_post_checkbox.setChecked(False) def select_storage(self) -> None: - selected_file, _filter = QFileDialog.getOpenFileName( - self, "FeatureForest", ".", "Feature Storage(*.hdf5)" - ) + selected_file = QFileDialog.getExistingDirectory(self, "FeatureForest", "..") if selected_file is not None and len(selected_file) > 0: self.storage_textbox.setText(selected_file) # load the storage - self.storage = h5py.File(selected_file, "r") + self.storage: zarr.Group = zarr.open(selected_file, mode="r") # type: ignore # set the patch size and overlap from the selected storage self.patch_size = self.storage.attrs.get("patch_size", self.patch_size) self.overlap = self.storage.attrs.get("overlap", self.overlap) @@ -562,10 +560,11 @@ def select_storage(self) -> None: # initialize the model based on the selected storage img_height = self.storage.attrs["img_height"] img_width = self.storage.attrs["img_width"] - # TODO: raise an error if current image dims are in conflicting with storage model_name = str(self.storage.attrs["model"]) + no_patching = self.storage.attrs["no_patching"] self.model_adapter = get_model(model_name, img_height, img_width) - print(model_name, self.patch_size, self.overlap) + self.model_adapter.no_patching = no_patching + print(model_name, self.patch_size, self.overlap, no_patching) # set the plugin usage stats csv file storage_path = Path(selected_file) @@ -614,7 +613,7 @@ def get_class_labels(self) -> dict[int, np.ndarray]: return labels_dict - def analyze_labels(self, labels_dict: Optional[dict]) -> None: + def analyze_labels(self, labels_dict: Optional[dict] = None) -> None: if labels_dict is None: labels_dict = self.get_class_labels() num_labels = [len(v) for v in labels_dict.values()] @@ -643,7 +642,7 @@ def get_train_data(self) -> tuple[np.ndarray, np.ndarray] | None: num_labels = sum([len(v) for v in labels_dict.values()]) total_channels = self.model_adapter.get_total_output_channels() train_data = np.zeros((num_labels, total_channels)) - labels = np.zeros(num_labels, dtype="int32") - 1 + labels = np.zeros(num_labels, dtype=np.int32) - 1 count = 0 for class_index in np_progress( labels_dict, desc="getting training data", total=len(labels_dict.keys()) @@ -660,7 +659,7 @@ def get_train_data(self) -> tuple[np.ndarray, np.ndarray] | None: slice_coords, img_height, img_width, self.patch_size, self.overlap ) grp_key = str(slice_index) - slice_dataset = self.storage[grp_key][self.model_adapter.name] + slice_dataset = self.storage[grp_key]["features"] for p_i in np.unique(patch_indices): patch_coords = slice_coords[patch_indices == p_i] patch_features = slice_dataset[p_i] @@ -869,8 +868,8 @@ def predict_slice( ) -> np.ndarray: """Predict a slice patch by patch""" segmentation_image = [] - # shape: N x target_size x target_size x C - feature_patches = self.storage[str(slice_index)][self.model_adapter.name][:] + # shape: N x stride x stride x C + feature_patches: np.ndarray = self.storage[str(slice_index)]["features"][:] num_patches = feature_patches.shape[0] total_channels = self.model_adapter.get_total_output_channels() for i in np_progress(range(num_patches), desc="Predicting slice patches"): From 3b68f8c2efabfc3a8bb725318ed4d519a1f2f5e4 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Fri, 13 Jun 2025 15:51:37 +0200 Subject: [PATCH 12/35] clean up & formatting --- src/featureforest/utils/dataset.py | 3 ++- src/featureforest/utils/extract.py | 6 ------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/featureforest/utils/dataset.py b/src/featureforest/utils/dataset.py index d61857f..ca24a08 100644 --- a/src/featureforest/utils/dataset.py +++ b/src/featureforest/utils/dataset.py @@ -33,7 +33,8 @@ def __init__( super().__init__() if image_array is not None and (stack_file is not None or img_dir is not None): raise ValueError( - "Please provide either an image array or a stack file or an image directory, not both." + "Please provide either an image array or a stack file " + "or an image directory, not both." ) if image_array is None and stack_file is None and img_dir is None: raise ValueError( diff --git a/src/featureforest/utils/extract.py b/src/featureforest/utils/extract.py index 3cc4ffe..3b407d9 100644 --- a/src/featureforest/utils/extract.py +++ b/src/featureforest/utils/extract.py @@ -89,16 +89,10 @@ def extract_embeddings( image_dataset, batch_size=batch_size, shuffle=False, num_workers=0 ) for img_data, indices in dataloader: - # print(f"images: {img_data.shape}\nslices: {indices}") features = model_adapter.get_features_patches(img_data.to(model_adapter.device)) - print(f"batch features shape: {features.shape}") unique_slices = torch.unique(indices[:, 0]).numpy() - # print(f"unique slices: {unique_slices}") for idx in unique_slices: img_features = features[indices[:, 0] == idx].numpy().astype(np.float16) - print( - f"image: {idx}, features shape: {img_features.shape}, {img_features.dtype}" - ) yield img_features, idx, len(unique_slices) From 29994214b035d67b27457345780202b18aedd5c3 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Fri, 13 Jun 2025 17:12:02 +0200 Subject: [PATCH 13/35] fixed extraction progress info --- src/featureforest/utils/extract.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/featureforest/utils/extract.py b/src/featureforest/utils/extract.py index 3b407d9..d076e57 100644 --- a/src/featureforest/utils/extract.py +++ b/src/featureforest/utils/extract.py @@ -94,7 +94,7 @@ def extract_embeddings( for idx in unique_slices: img_features = features[indices[:, 0] == idx].numpy().astype(np.float16) - yield img_features, idx, len(unique_slices) + yield img_features, idx, image_dataset.num_images def extract_embeddings_to_file( @@ -124,6 +124,7 @@ def extract_embeddings_to_file( image, model_adapter, image_dataset ): if store_root.get(str(idx)) is None: + # create a group for the slice grp = store_root.create_group(str(idx)) # type: ignore z_arr = grp.create( # type: ignore name="features", From 75f80e4bba594f106a1e393ddd9cd557ca94f1f3 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Fri, 13 Jun 2025 17:58:13 +0200 Subject: [PATCH 14/35] updated dataset to handle images type inside --- src/featureforest/utils/dataset.py | 59 ++++++++----------- src/featureforest/utils/extract.py | 46 ++------------- .../utils/pipeline_prediction.py | 2 - 3 files changed, 30 insertions(+), 77 deletions(-) diff --git a/src/featureforest/utils/dataset.py b/src/featureforest/utils/dataset.py index ca24a08..19b2243 100644 --- a/src/featureforest/utils/dataset.py +++ b/src/featureforest/utils/dataset.py @@ -23,54 +23,47 @@ class FFImageDataset(IterableDataset): def __init__( self, - image_array: np.ndarray | None = None, - stack_file: str | Path | None = None, - img_dir: str | Path | None = None, + images: str | Path | np.ndarray, no_patching: bool = False, patch_size: int = 512, overlap: int = 128, ) -> None: super().__init__() - if image_array is not None and (stack_file is not None or img_dir is not None): - raise ValueError( - "Please provide either an image array or a stack file " - "or an image directory, not both." - ) - if image_array is None and stack_file is None and img_dir is None: - raise ValueError( - "Please provide either a large TIFF file or an image directory." - ) - self.no_patching = no_patching self.patch_size = patch_size self.overlap = overlap self.image_files = [] self.image_source: Optional[pims.ImageSequence | np.ndarray] = None - if image_array is not None: - # image is already loaded as a numpy array - self.image_source = image_array + if isinstance(images, np.ndarray): + # images are already loaded into a numpy array + self.image_source = images # add slice dimension if not present if self.image_source.ndim == 2: self.image_source = self.image_source[np.newaxis, ...] - elif stack_file is not None: - # can be a large stack, using pims for lazy loading - self.image_source = pims.open(str(stack_file)) - elif img_dir is not None: - # load images from a directory - img_dir = Path(img_dir) - if not img_dir.is_dir(): - raise ValueError(f"The provided path {img_dir} is not a directory.") - self.image_files = ( - list(img_dir.glob("*.tiff")) - + list(img_dir.glob("*.tif")) - + list(img_dir.glob("*.png")) - + list(img_dir.glob("*.jpg")) + + elif isinstance(images, str | Path): + images = Path(images) + if images.is_file(): + # can be a large stack, using pims for lazy loading + self.image_source = pims.open(str(images)) + + elif images.is_dir(): + self.image_files = ( + list(images.glob("*.tiff")) + + list(images.glob("*.tif")) + + list(images.glob("*.png")) + + list(images.glob("*.jpg")) + ) + if not self.image_files: + raise ValueError(f"No image files found in the directory {images}.") + self.image_files = self._natural_sort(self.image_files) + self.image_source = pims.ImageSequence(map(str, self.image_files)) + else: + raise ValueError( + f"images should be a numpy array or a directory or an image stack!" + f"\nGot {type(images)}" ) - if not self.image_files: - raise ValueError(f"No image files found in the directory {img_dir}.") - self.image_files = self._natural_sort(self.image_files) - self.image_source = pims.ImageSequence(map(str, self.image_files)) @property def num_images(self) -> int: diff --git a/src/featureforest/utils/extract.py b/src/featureforest/utils/extract.py index d076e57..6f96e65 100644 --- a/src/featureforest/utils/extract.py +++ b/src/featureforest/utils/extract.py @@ -1,5 +1,4 @@ from collections.abc import Generator -from pathlib import Path from typing import Optional import numpy as np @@ -32,43 +31,6 @@ def get_batch_size(model_adapter: BaseModelAdapter) -> int: return batch_size -def get_image_dataset( - image: str | np.ndarray, - no_patching: bool = False, - patch_size: int = 512, - overlap: int = 128, -) -> FFImageDataset: - if isinstance(image, str): - # image is a path to a large TIFF file or a directory of images - img_path = Path(image) - if img_path.is_dir(): - # load images from a directory - dataset = FFImageDataset( - img_dir=img_path, - no_patching=no_patching, - patch_size=patch_size, - overlap=overlap, - ) - else: - # load a (large) stack - dataset = FFImageDataset( - stack_file=img_path, - no_patching=no_patching, - patch_size=patch_size, - overlap=overlap, - ) - elif isinstance(image, np.ndarray): - # image is already loaded as a numpy array - dataset = FFImageDataset( - image_array=image, - no_patching=no_patching, - patch_size=patch_size, - overlap=overlap, - ) - - return dataset - - def extract_embeddings( image: np.ndarray | str, model_adapter: BaseModelAdapter, @@ -81,8 +43,8 @@ def extract_embeddings( # create the image dataset if image_dataset is None: - image_dataset = get_image_dataset( - image=image, no_patching=no_patching, patch_size=patch_size, overlap=overlap + image_dataset = FFImageDataset( + images=image, no_patching=no_patching, patch_size=patch_size, overlap=overlap ) # loop through the dataset and extract features dataloader = DataLoader( @@ -106,8 +68,8 @@ def extract_embeddings_to_file( # batch_size = get_batch_size(model_adapter) # # create the image dataset - image_dataset = get_image_dataset( - image=image, no_patching=no_patching, patch_size=patch_size, overlap=overlap + image_dataset = FFImageDataset( + images=image, no_patching=no_patching, patch_size=patch_size, overlap=overlap ) # create the zarr storage storage = zarr.storage.DirectoryStore(storage_path) diff --git a/src/featureforest/utils/pipeline_prediction.py b/src/featureforest/utils/pipeline_prediction.py index 83a9c3c..fbde6a5 100644 --- a/src/featureforest/utils/pipeline_prediction.py +++ b/src/featureforest/utils/pipeline_prediction.py @@ -6,8 +6,6 @@ from featureforest.models import BaseModelAdapter from featureforest.utils.data import get_num_patches, get_stride_margin -# from featureforest.utils.extract import get_slice_features - def predict_patches( patch_features: np.ndarray, From 7f5b78ccfd086416d4c65022888f2b89642c0900 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sat, 14 Jun 2025 15:39:27 +0200 Subject: [PATCH 15/35] updated prediction pipeline over a large stack --- src/featureforest/_segmentation_widget.py | 115 +++++++++--------- src/featureforest/utils/extract.py | 7 +- .../utils/pipeline_prediction.py | 96 +++++++-------- 3 files changed, 104 insertions(+), 114 deletions(-) diff --git a/src/featureforest/_segmentation_widget.py b/src/featureforest/_segmentation_widget.py index 2b35528..e3df5c6 100644 --- a/src/featureforest/_segmentation_widget.py +++ b/src/featureforest/_segmentation_widget.py @@ -51,7 +51,7 @@ get_stack_dims, get_stride_margin, ) -from .utils.pipeline_prediction import extract_predict +from .utils.pipeline_prediction import run_prediction_pipeline from .utils.usage_stats import SegmentationUsageStats from .widgets import ( ScrollWidgetWrapper, @@ -1085,6 +1085,9 @@ def run_pipeline_over_large_stack(self) -> None: def run_pipeline( self, tiff_stack_file: str, result_dir: Path ) -> Generator[tuple[int, int], None, None]: + if self.model_adapter is None or self.rf_model is None: + raise ValueError("RF model and/or Model Adapter are missing!") + start = dt.datetime.now() slices_total_time = 0 postprocess_total_time = 0 @@ -1095,67 +1098,63 @@ def run_pipeline( sam_post_dir = result_dir.joinpath("post_sam") sam_post_dir.mkdir(parents=True, exist_ok=True) - with TiffFile(tiff_stack_file) as tiff_stack: - total_pages = len(tiff_stack.pages) - for page_idx, page in np_progress( - enumerate(tiff_stack.pages), - desc="runing the pipeline", - total=len(tiff_stack.pages), - ): - slice_start = time.perf_counter() - image = page.asarray() - prediction_mask = extract_predict( - image, self.model_adapter, self.rf_model - ) - # save the prediction - tifffile.imwrite( - prediction_dir.joinpath(f"slice_{page_idx:04}_prediction.tiff"), - prediction_mask, - ) - # post-processing - smoothing_iterations, area_threshold, area_is_absolute = ( - self.get_postprocess_params() - ) - post_mask = postprocess( - prediction_mask, - smoothing_iterations, - area_threshold, - area_is_absolute, - ) - tifffile.imwrite( - simple_post_dir.joinpath(f"slice_{page_idx:04}_post_simple.tiff"), - post_mask, - ) - pp_start = time.perf_counter() - post_sam_mask = postprocess_with_sam( - prediction_mask, - smoothing_iterations, - area_threshold, - area_is_absolute, - ) - tifffile.imwrite( - sam_post_dir.joinpath(f"slice_{page_idx:04}_post_sam.tiff"), - post_sam_mask, - ) - # slices timing - postprocess_total_time += round(time.perf_counter() - pp_start, 2) - slices_total_time += round(time.perf_counter() - slice_start, 2) - slice_avg = slices_total_time / (page_idx + 1) - rem_minutes, rem_seconds = divmod( - slice_avg * (total_pages - page_idx + 1), 60 - ) - rem_hour, rem_minutes = divmod(rem_minutes, 60) - self.timing_info.setText( - f"Estimated remaining time: " - f"{int(rem_hour):02}:{int(rem_minutes):02}:{int(rem_seconds):02}" - ) - print(f"slice average time(seconds): {slice_avg:.2f}") + slice_start = time.perf_counter() + for slice_mask, idx, total in np_progress( + run_prediction_pipeline( + tiff_stack_file, + self.model_adapter, + self.rf_model, + ), + desc="running the pipeline", + ): + # save the prediction + tifffile.imwrite( + prediction_dir.joinpath(f"slice_{idx:04}_prediction.tiff"), + slice_mask, + ) + # post-processing + smoothing_iterations, area_threshold, area_is_absolute = ( + self.get_postprocess_params() + ) + post_mask = postprocess( + slice_mask, + smoothing_iterations, + area_threshold, + area_is_absolute, + ) + tifffile.imwrite( + simple_post_dir.joinpath(f"slice_{idx:04}_post_simple.tiff"), + post_mask, + ) + pp_start = time.perf_counter() + post_sam_mask = postprocess_with_sam( + slice_mask, + smoothing_iterations, + area_threshold, + area_is_absolute, + ) + tifffile.imwrite( + sam_post_dir.joinpath(f"slice_{idx:04}_post_sam.tiff"), + post_sam_mask, + ) + # slices timing + postprocess_total_time += round(time.perf_counter() - pp_start, 2) + slices_total_time += round(time.perf_counter() - slice_start, 2) + slice_avg = slices_total_time / (idx + 1) + rem_minutes, rem_seconds = divmod(slice_avg * (total - idx + 1), 60) + rem_hour, rem_minutes = divmod(rem_minutes, 60) + self.timing_info.setText( + f"Estimated remaining time: " + f"{int(rem_hour):02}:{int(rem_minutes):02}:{int(rem_seconds):02}" + ) + print(f"slice average time(seconds): {slice_avg:.2f}") + slice_start = time.perf_counter() - yield page_idx, total_pages + yield idx, total # stack is done end = dt.datetime.now() self.save_pipeline_stats( - result_dir, start, end, slices_total_time, postprocess_total_time, total_pages + result_dir, start, end, slices_total_time, postprocess_total_time, total ) def stop_pipeline(self) -> None: diff --git a/src/featureforest/utils/extract.py b/src/featureforest/utils/extract.py index 6f96e65..54a130b 100644 --- a/src/featureforest/utils/extract.py +++ b/src/featureforest/utils/extract.py @@ -32,8 +32,8 @@ def get_batch_size(model_adapter: BaseModelAdapter) -> int: def extract_embeddings( - image: np.ndarray | str, model_adapter: BaseModelAdapter, + image: Optional[np.ndarray | str] = None, image_dataset: Optional[FFImageDataset] = None, ) -> Generator[tuple[np.ndarray, int, int], None, None]: no_patching = model_adapter.no_patching @@ -43,6 +43,8 @@ def extract_embeddings( # create the image dataset if image_dataset is None: + if image is None: + raise ValueError("Should pass either the image or the image_dataset!") image_dataset = FFImageDataset( images=image, no_patching=no_patching, patch_size=patch_size, overlap=overlap ) @@ -65,7 +67,6 @@ def extract_embeddings_to_file( no_patching = model_adapter.no_patching patch_size = model_adapter.patch_size overlap = model_adapter.overlap - # batch_size = get_batch_size(model_adapter) # # create the image dataset image_dataset = FFImageDataset( @@ -83,7 +84,7 @@ def extract_embeddings_to_file( store_root.attrs["overlap"] = overlap for img_features, idx, total in extract_embeddings( - image, model_adapter, image_dataset + model_adapter, image_dataset=image_dataset ): if store_root.get(str(idx)) is None: # create a group for the slice diff --git a/src/featureforest/utils/pipeline_prediction.py b/src/featureforest/utils/pipeline_prediction.py index fbde6a5..5bec65b 100644 --- a/src/featureforest/utils/pipeline_prediction.py +++ b/src/featureforest/utils/pipeline_prediction.py @@ -1,40 +1,38 @@ -import multiprocessing as mp +from collections.abc import Generator import numpy as np from sklearn.ensemble import RandomForestClassifier as RF from featureforest.models import BaseModelAdapter from featureforest.utils.data import get_num_patches, get_stride_margin +from featureforest.utils.dataset import FFImageDataset +from featureforest.utils.extract import extract_embeddings def predict_patches( - patch_features: np.ndarray, + feature_list: list[np.ndarray], rf_model: RF, model_adapter: BaseModelAdapter, - batch_idx: int, - result_dict: dict, -) -> None: - """Predicts the class labels for a given set of patch features. +) -> np.ndarray: + """Predicts the class labels for a given list of features (patches). Args: - patch_features (np.ndarray): Patch features to be predicted. + feature_list (list[np.ndarray]): List of features to be predicted. rf_model (RF): Random Forest Model used for predictions. model_adapter (BaseModelAdapter): Model adapter object used for extracting data. - batch_idx (int): Batch index of the current patch features. - result_dict (dict): Dictionary where the predicted masks will be stored. """ patch_masks = [] - # shape: N x target_size x target_size x C - num_patches = patch_features.shape[0] + # shape: N x stride x stride x C + num_patches = len(feature_list) total_channels = model_adapter.get_total_output_channels() print(f"predicting {num_patches} patches...") for i in range(num_patches): - patch_data = patch_features[i].reshape(-1, total_channels) + patch_data = feature_list[i].reshape(-1, total_channels) pred = rf_model.predict(patch_data).astype(np.uint8) patch_masks.append(pred) patch_masks = np.vstack(patch_masks) - result_dict[batch_idx] = patch_masks + return patch_masks def get_image_mask( @@ -68,47 +66,39 @@ def get_image_mask( return mask_image -def extract_predict( - image: np.ndarray, +def run_prediction_pipeline( + input_stack: str, model_adapter: BaseModelAdapter, rf_model: RF, -) -> np.ndarray: - """Extracts features and predicts the classes for a given image. - - Args: - image (np.ndarray): Input image to extract features from. - model_adapter (BaseModelAdapter): Model adapter object used for extracting data. - rf_model (RF): Random Forest Model used for predictions. - - Returns: - np.ndarray: Final image mask. - """ - img_height, img_width = image.shape[:2] +) -> Generator[tuple[np.ndarray, int, int], None, None]: + no_patching = model_adapter.no_patching patch_size = model_adapter.patch_size overlap = model_adapter.overlap - # procs = [] - # prediction happens per batch of extracted features - # in a separate process. - with mp.Manager() as manager: - result_dict = manager.dict() - # for b_idx, patch_features in get_slice_features(image, model_adapter): - # print(b_idx, end="\r") - # proc = mp.Process( - # target=predict_patches, - # args=(patch_features, rf_model, model_adapter, b_idx, result_dict), - # ) - # procs.append(proc) - # proc.start() - # # wait until all processes are done - # for p in procs: - # if p.is_alive: - # p.join() - # collect results from each process - batch_indices = sorted(result_dict.keys()) - patch_masks = [result_dict[b] for b in batch_indices] - patch_masks = np.vstack(patch_masks) - slice_mask = get_image_mask( - patch_masks, img_height, img_width, patch_size, overlap - ) - - return slice_mask + stack_dataset = FFImageDataset( + images=input_stack, + no_patching=no_patching, + patch_size=patch_size, + overlap=overlap, + ) + img_height, img_width = stack_dataset.image_shape + + prev_idx = 0 + slice_features = [] + for img_features, slice_idx, total in extract_embeddings( + model_adapter, image_dataset=stack_dataset + ): + if prev_idx != slice_idx: + # we have one slice features extracted: make a prediction. + patch_masks = predict_patches(slice_features, rf_model, model_adapter) + slice_mask = get_image_mask( + patch_masks, img_height, img_width, patch_size, overlap + ) + yield slice_mask, slice_idx, total + + # start collecting next slice features + prev_idx = slice_idx + slice_features = [] + slice_features.append(img_features) + + # collect slice features + slice_features.append(img_features) From 5bfe8c9b9f52c06b0ab5336305d7e81b3eb4fef5 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sat, 14 Jun 2025 18:06:40 +0200 Subject: [PATCH 16/35] fixed run pipeline; file dialogs open in parent dialog --- src/featureforest/_segmentation_widget.py | 9 ++++++--- src/featureforest/utils/extract.py | 3 ++- .../utils/pipeline_prediction.py | 19 +++++++++++++------ 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/featureforest/_segmentation_widget.py b/src/featureforest/_segmentation_widget.py index e3df5c6..92b6ecc 100644 --- a/src/featureforest/_segmentation_widget.py +++ b/src/featureforest/_segmentation_widget.py @@ -719,7 +719,7 @@ def train_model(self) -> None: def load_rf_model(self) -> None: selected_file, _filter = QFileDialog.getOpenFileName( - self, "FeatureForest", ".", "model(*.bin)" + self, "FeatureForest", "..", "model(*.bin)" ) if len(selected_file) > 0: # to suppress the sklearn InconsistentVersionWarning @@ -758,7 +758,7 @@ def save_rf_model(self) -> None: notif.show_info("There is no trained model!") return selected_file, _filter = QFileDialog.getSaveFileName( - self, "FeatureForest", ".", "model(*.bin)" + self, "FeatureForest", "..", "model(*.bin)" ) if len(selected_file) > 0: if not selected_file.endswith(".bin"): @@ -1020,7 +1020,7 @@ def export_segmentation(self) -> None: def select_stack(self) -> None: selected_file, _filter = QFileDialog.getOpenFileName( - self, "FeatureForest", ".", "TIFF stack (*.tiff *.tif)" + self, "FeatureForest", "..", "TIFF stack (*.tiff *.tif)" ) if selected_file is not None and len(selected_file) > 0: # get stack info @@ -1098,6 +1098,7 @@ def run_pipeline( sam_post_dir = result_dir.joinpath("post_sam") sam_post_dir.mkdir(parents=True, exist_ok=True) + self.rf_model.set_params(verbose=0) slice_start = time.perf_counter() for slice_mask, idx, total in np_progress( run_prediction_pipeline( @@ -1113,6 +1114,7 @@ def run_pipeline( slice_mask, ) # post-processing + print("post processing...") smoothing_iterations, area_threshold, area_is_absolute = ( self.get_postprocess_params() ) @@ -1156,6 +1158,7 @@ def run_pipeline( self.save_pipeline_stats( result_dir, start, end, slices_total_time, postprocess_total_time, total ) + self.rf_model.set_params(verbose=1) def stop_pipeline(self) -> None: if self.pipeline_worker is not None: diff --git a/src/featureforest/utils/extract.py b/src/featureforest/utils/extract.py index 54a130b..d5bd9c6 100644 --- a/src/featureforest/utils/extract.py +++ b/src/featureforest/utils/extract.py @@ -44,7 +44,7 @@ def extract_embeddings( # create the image dataset if image_dataset is None: if image is None: - raise ValueError("Should pass either the image or the image_dataset!") + raise ValueError("You should pass either the image or the image_dataset!") image_dataset = FFImageDataset( images=image, no_patching=no_patching, patch_size=patch_size, overlap=overlap ) @@ -52,6 +52,7 @@ def extract_embeddings( dataloader = DataLoader( image_dataset, batch_size=batch_size, shuffle=False, num_workers=0 ) + print(f"Start extracting features for {image_dataset.num_images} slices...") for img_data, indices in dataloader: features = model_adapter.get_features_patches(img_data.to(model_adapter.device)) unique_slices = torch.unique(indices[:, 0]).numpy() diff --git a/src/featureforest/utils/pipeline_prediction.py b/src/featureforest/utils/pipeline_prediction.py index 5bec65b..1d75b85 100644 --- a/src/featureforest/utils/pipeline_prediction.py +++ b/src/featureforest/utils/pipeline_prediction.py @@ -23,11 +23,12 @@ def predict_patches( """ patch_masks = [] # shape: N x stride x stride x C - num_patches = len(feature_list) + patch_features = np.vstack(feature_list) + num_patches = len(patch_features) total_channels = model_adapter.get_total_output_channels() print(f"predicting {num_patches} patches...") for i in range(num_patches): - patch_data = feature_list[i].reshape(-1, total_channels) + patch_data = patch_features[i].reshape(-1, total_channels) pred = rf_model.predict(patch_data).astype(np.uint8) patch_masks.append(pred) @@ -87,18 +88,24 @@ def run_prediction_pipeline( for img_features, slice_idx, total in extract_embeddings( model_adapter, image_dataset=stack_dataset ): + print(f"{slice_idx} / {total}") if prev_idx != slice_idx: # we have one slice features extracted: make a prediction. patch_masks = predict_patches(slice_features, rf_model, model_adapter) slice_mask = get_image_mask( patch_masks, img_height, img_width, patch_size, overlap ) - yield slice_mask, slice_idx, total + + yield slice_mask, prev_idx, total # start collecting next slice features prev_idx = slice_idx slice_features = [] slice_features.append(img_features) - - # collect slice features - slice_features.append(img_features) + else: + # collect slice features + slice_features.append(img_features) + # make prediction for the last slice + patch_masks = predict_patches(slice_features, rf_model, model_adapter) + slice_mask = get_image_mask(patch_masks, img_height, img_width, patch_size, overlap) + yield slice_mask, slice_idx, total From 6393b2afa9a7cb719b92b4674b929545e9b9f8a0 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sun, 15 Jun 2025 16:23:20 +0200 Subject: [PATCH 17/35] updated run_pipeline script --- run_pipeline.ipynb | 436 --------------------------------------------- run_pipeline.py | 221 +++++------------------ 2 files changed, 45 insertions(+), 612 deletions(-) delete mode 100755 run_pipeline.ipynb diff --git a/run_pipeline.ipynb b/run_pipeline.ipynb deleted file mode 100755 index 4a47619..0000000 --- a/run_pipeline.ipynb +++ /dev/null @@ -1,436 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Installation and Requirements\n", - "\n", - "Please refer to the [_featureforest repo_](https://github.com/juglab/featureforest)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import pickle\n", - "from pathlib import Path\n", - "\n", - "import h5py\n", - "import numpy as np\n", - "import torch\n", - "from PIL import Image, ImageSequence\n", - "from sklearn.ensemble import RandomForestClassifier\n", - "\n", - "from tqdm.notebook import trange, tqdm\n", - "\n", - "from featureforest.models import get_available_models, get_model\n", - "from featureforest.models.SAM import SAMAdapter\n", - "from featureforest.utils.data import (\n", - " patchify,\n", - " is_image_rgb, get_stride_margin,\n", - " get_num_patches, get_stride_margin\n", - ")\n", - "from featureforest.postprocess import (\n", - " postprocess,\n", - " postprocess_with_sam, postprocess_with_sam_auto,\n", - " get_sam_auto_masks\n", - ")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Utility functions" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "def get_slice_features(\n", - " image: np.ndarray,\n", - " patch_size: int,\n", - " overlap: int,\n", - " model_adapter,\n", - " storage_group,\n", - "):\n", - " \"\"\"Extract the model features for one slice and save them into storage file.\"\"\"\n", - " # image to torch tensor\n", - " img_data = torch.from_numpy(image).to(torch.float32) / 255.0\n", - " # for sam the input image should be 4D: BxCxHxW ; an RGB image.\n", - " if is_image_rgb(image):\n", - " # it's already RGB, put the channels first and add a batch dim.\n", - " img_data = img_data[..., :3] # ignore the Alpha channel (in case of PNG).\n", - " img_data = img_data.permute([2, 0, 1]).unsqueeze(0)\n", - " else:\n", - " img_data = img_data.unsqueeze(0).unsqueeze(0).expand(-1, 3, -1, -1)\n", - "\n", - " # get input patches\n", - " data_patches = patchify(img_data, patch_size, overlap)\n", - " num_patches = len(data_patches)\n", - "\n", - " # set a low batch size\n", - " batch_size = 8\n", - " # for big SAM we need even lower batch size :(\n", - " if isinstance(model_adapter, SAMAdapter):\n", - " batch_size = 2\n", - "\n", - " num_batches = int(np.ceil(num_patches / batch_size))\n", - " # prepare storage for the slice embeddings\n", - " total_channels = model_adapter.get_total_output_channels()\n", - " stride, _ = get_stride_margin(patch_size, overlap)\n", - "\n", - " if model_adapter.name not in storage_group:\n", - " dataset = storage_group.create_dataset(\n", - " model_adapter.name, shape=(num_patches, stride, stride, total_channels)\n", - " )\n", - " else:\n", - " dataset = storage_group[model_adapter.name]\n", - "\n", - " # get sam encoder output for image patches\n", - " # print(\"\\nextracting slice features:\")\n", - " for b_idx in tqdm(range(num_batches), desc=\"extracting slice feature:\"):\n", - " # print(f\"batch #{b_idx + 1} of {num_batches}\")\n", - " start = b_idx * batch_size\n", - " end = start + batch_size\n", - " slice_features = model_adapter.get_features_patches(\n", - " data_patches[start:end].to(model_adapter.device)\n", - " )\n", - " if not isinstance(slice_features, tuple):\n", - " # model has only one output\n", - " num_out = slice_features.shape[0] # to take care of the last batch size\n", - " dataset[start : start + num_out] = slice_features\n", - " else:\n", - " # model has more than one output: put them into storage one by one\n", - " ch_start = 0\n", - " for feat in slice_features:\n", - " num_out = feat.shape[0]\n", - " ch_end = ch_start + feat.shape[-1] # number of features\n", - " dataset[start : start + num_out, :, :, ch_start:ch_end] = feat\n", - " ch_start = ch_end\n", - "\n", - "\n", - "def predict_slice(\n", - " rf_model, patch_dataset, model_adapter,\n", - " img_height, img_width, patch_size, overlap\n", - "):\n", - " \"\"\"Predict a slice patch by patch\"\"\"\n", - " segmentation_image = []\n", - " # shape: N x target_size x target_size x C\n", - " feature_patches = patch_dataset[:]\n", - " num_patches = feature_patches.shape[0]\n", - " total_channels = model_adapter.get_total_output_channels()\n", - " stride, margin = get_stride_margin(patch_size, overlap)\n", - "\n", - " for i in tqdm(\n", - " range(num_patches), desc=\"Predicting slice patches\", position=1, leave=True\n", - " ):\n", - " input_data = feature_patches[i].reshape(-1, total_channels)\n", - " predictions = rf_model.predict(input_data).astype(np.uint8)\n", - " segmentation_image.append(predictions)\n", - "\n", - " segmentation_image = np.vstack(segmentation_image)\n", - " # reshape into the image size + padding\n", - " patch_rows, patch_cols = get_num_patches(\n", - " img_height, img_width, patch_size, overlap\n", - " )\n", - " segmentation_image = segmentation_image.reshape(\n", - " patch_rows, patch_cols, stride, stride\n", - " )\n", - " segmentation_image = np.moveaxis(segmentation_image, 1, 2).reshape(\n", - " patch_rows * stride,\n", - " patch_cols * stride\n", - " )\n", - " # skip paddings\n", - " segmentation_image = segmentation_image[:img_height, :img_width]\n", - "\n", - " return segmentation_image\n", - "\n", - "\n", - "def apply_postprocessing(\n", - " input_image, segmentation_image,\n", - " smoothing_iterations, area_threshold, area_is_absolute,\n", - " use_sam_predictor, use_sam_autoseg, iou_threshold\n", - "):\n", - " post_masks = {}\n", - " # if not use_sam_predictor and not use_sam_autoseg:\n", - " mask = postprocess(\n", - " segmentation_image, smoothing_iterations,\n", - " area_threshold, area_is_absolute\n", - " )\n", - " post_masks[\"Simple\"] = mask\n", - "\n", - " if use_sam_predictor:\n", - " mask = postprocess_with_sam(\n", - " segmentation_image,\n", - " smoothing_iterations, area_threshold, area_is_absolute\n", - " )\n", - " post_masks[\"SAMPredictor\"] = mask\n", - "\n", - " if use_sam_autoseg:\n", - " sam_auto_masks = get_sam_auto_masks(input_image)\n", - " mask = postprocess_with_sam_auto(\n", - " sam_auto_masks,\n", - " segmentation_image,\n", - " smoothing_iterations, iou_threshold,\n", - " area_threshold, area_is_absolute\n", - " )\n", - " post_masks[\"SAMAutoSegmentation\"] = mask\n", - "\n", - "\n", - " return post_masks" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Set the Input, RF Model and the result directory paths" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# input image\n", - "data_path = \"../datasets/data.tif\"\n", - "data_path = Path(data_path)\n", - "print(f\"data_path exists: {data_path.exists()}\")\n", - "\n", - "# random forest model\n", - "rf_model_path = \"../datasets/rf_model.bin\"\n", - "rf_model_path = Path(rf_model_path)\n", - "print(f\"rf_model_path exists: {rf_model_path.exists()}\")\n", - "\n", - "# result folder\n", - "segmentation_dir = Path(\"../datasets/segmentation_result\")\n", - "segmentation_dir.mkdir(parents=True, exist_ok=True)\n", - "\n", - "# temporary storage path for saving extracted embeddings patches\n", - "storage_path = \"./temp_storage.hdf5\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Prepare the Input and RF Model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# get patch sizes\n", - "input_stack = Image.open(data_path)\n", - "\n", - "num_slices = input_stack.n_frames\n", - "img_height = input_stack.height\n", - "img_width = input_stack.width\n", - "\n", - "print(num_slices, img_height, img_width)\n", - "# print(patch_size, target_patch_size)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with open(rf_model_path, mode=\"rb\") as f:\n", - " model_data = pickle.load(f)\n", - "# compatibility check for old format rf model\n", - "if isinstance(model_data, dict): # noqa: SIM108\n", - " # new format\n", - " rf_model = model_data[\"rf_model\"]\n", - "else:\n", - " # old format\n", - " rf_model = model_data\n", - "\n", - "rf_model.set_params(verbose=0)\n", - "rf_model" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Initializing the Model for Feature Extraction" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# list of available models\n", - "get_available_models()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "model_name = \"MobileSAM\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "# print(f\"running on {device}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model_adapter = get_model(model_name, img_height, img_width)\n", - "\n", - "patch_size = model_adapter.patch_size\n", - "overlap = model_adapter.overlap\n", - "\n", - "patch_size, overlap" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "# post-processing parameters\n", - "do_postprocess = True\n", - "\n", - "smoothing_iterations = 25\n", - "area_threshold = 100 # to ignore mask regions with area below this threshold\n", - "area_is_absolute = True # is area is based on pixels or pecentage (False)\n", - "\n", - "use_sam_predictor = True\n", - "use_sam_autoseg = False\n", - "sam_autoseg_iou_threshold = 0.4" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Prediction" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "# create the slice temporary storage\n", - "storage = h5py.File(storage_path, \"w\")\n", - "storage_group = storage.create_group(\"slice\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for i, page in tqdm(\n", - " enumerate(ImageSequence.Iterator(input_stack)),\n", - " desc=\"Slices\", total=num_slices, position=0\n", - "):\n", - " # print(f\"slice {i + 1}\", end=\"\\n\")\n", - " slice_img = np.array(page.convert(\"RGB\"))\n", - "\n", - " get_slice_features(slice_img, patch_size, overlap, model_adapter, storage_group)\n", - "\n", - " segmentation_image = predict_slice(\n", - " rf_model, storage_group[model_adapter.name], model_adapter,\n", - " img_height, img_width,\n", - " patch_size, overlap\n", - " )\n", - "\n", - " img = Image.fromarray(segmentation_image)\n", - " img.save(segmentation_dir.joinpath(f\"slice_{i:04}_prediction.tiff\"))\n", - "\n", - " if do_postprocess:\n", - " post_masks = apply_postprocessing(\n", - " slice_img, segmentation_image,\n", - " smoothing_iterations, area_threshold, area_is_absolute,\n", - " use_sam_predictor, use_sam_autoseg, sam_autoseg_iou_threshold\n", - " )\n", - " # save results\n", - " for name, mask in post_masks.items():\n", - " img = Image.fromarray(mask)\n", - " seg_dir = segmentation_dir.joinpath(name)\n", - " seg_dir.mkdir(exist_ok=True) \n", - " img.save(seg_dir.joinpath(f\"slice_{i:04}_{name}.tiff\"))\n", - "\n", - "\n", - "\n", - "if storage is not None:\n", - " storage.close()\n", - " storage = None\n", - "Path(storage_path).unlink()" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "if storage is not None:\n", - " storage.close()\n", - " Path(storage_path).unlink()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "project52", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/run_pipeline.py b/run_pipeline.py index b09f473..1118295 100755 --- a/run_pipeline.py +++ b/run_pipeline.py @@ -1,143 +1,22 @@ import argparse -import multiprocessing as mp import pickle import time -from collections.abc import Generator from pathlib import Path import numpy as np -import torch -from PIL import Image, ImageSequence -from sklearn.ensemble import RandomForestClassifier as RF +import pims +import tifffile -from featureforest.models import BaseModelAdapter, get_available_models, get_model -from featureforest.models.SAM import SAMAdapter +from featureforest.models import get_available_models, get_model + +# from featureforest.models.SAM import SAMAdapter from featureforest.postprocess import ( get_sam_auto_masks, postprocess, postprocess_with_sam, postprocess_with_sam_auto, ) -from featureforest.utils.data import ( - get_num_patches, - get_stride_margin, - is_image_rgb, - patchify, -) - - -def predict_patches( - patch_features: np.ndarray, - rf_model: RF, - model_adapter: BaseModelAdapter, - batch_idx: int, - result_dict: dict, -) -> None: - """Predicts the class labels for a given set of patch features. - - Args: - patch_features (np.ndarray): Patch features to be predicted. - rf_model (RF): Random Forest Model used for predictions. - model_adapter (BaseModelAdapter): Model adapter object used for extracting data. - batch_idx (int): Batch index of the current patch features. - result_dict (dict): Dictionary where the predicted masks will be stored. - """ - patch_masks = [] - # shape: N x target_size x target_size x C - num_patches = patch_features.shape[0] - total_channels = model_adapter.get_total_output_channels() - print(f"predicting {num_patches} patches...") - for i in range(num_patches): - patch_data = patch_features[i].reshape(-1, total_channels) - pred = rf_model.predict(patch_data).astype(np.uint8) - patch_masks.append(pred) - - patch_masks = np.vstack(patch_masks) - result_dict[batch_idx] = patch_masks - - -def get_image_mask( - patch_masks: np.ndarray, - img_height: int, - img_width: int, - patch_size: int, - overlap: int, -) -> np.ndarray: - """Gets the final image mask by combining the individual patch masks. - - Args: - patch_masks (ndarray): Patch masks to combine into an image mask. - img_height (int): Height of the input image. - img_width (int): Width of the input image. - patch_size (int): Size of the patches. - overlap (int): Overlap between adjacent patches. - - Returns: - np.ndarray: Final image mask. - """ - stride, _ = get_stride_margin(patch_size, overlap) - patch_rows, patch_cols = get_num_patches(img_height, img_width, patch_size, overlap) - mask_image = patch_masks.reshape(patch_rows, patch_cols, stride, stride) - mask_image = np.moveaxis(mask_image, 1, 2).reshape( - patch_rows * stride, patch_cols * stride - ) - # skip paddings - mask_image = mask_image[:img_height, :img_width] - - return mask_image - - -def get_slice_features( - image: np.ndarray, model_adapter: BaseModelAdapter -) -> Generator[tuple[int, np.ndarray], None, None]: - """Extract features for one image using the given model adapter - - Args: - image: Input image array - model_adapter: Model adapter to extract features from - Returns: - Generator yielding tuples containing batch index and extracted features. - """ - # image to torch tensor - img_data = torch.from_numpy(image).to(torch.float32) - # normalize in [0, 1] - _min = img_data.min() - _max = img_data.max() - img_data = (img_data - _min) / (_max - _min) - # for sam the input image should be 4D: BxCxHxW ; an RGB image. - if is_image_rgb(image): - # it's already RGB, put the channels first and add a batch dim. - img_data = img_data[..., :3] # ignore the Alpha channel (in case of PNG). - img_data = img_data.permute([2, 0, 1]).unsqueeze(0) - else: - img_data = img_data.unsqueeze(0).unsqueeze(0).expand(-1, 3, -1, -1) - - # get input patches - patch_size = model_adapter.patch_size - overlap = model_adapter.overlap - data_patches = patchify(img_data, patch_size, overlap) - num_patches = len(data_patches) - - # set a low batch size - batch_size = 8 - # for big SAM we need even lower batch size :( - if isinstance(model_adapter, SAMAdapter): - batch_size = 2 - num_batches = int(np.ceil(num_patches / batch_size)) - - # get sam encoder output for image patches - print("extracting slice features:") - for b_idx in range(num_batches): - print(f"batch #{b_idx + 1} of {num_batches}") - start = b_idx * batch_size - end = start + batch_size - slice_features = model_adapter.get_features_patches( - data_patches[start:end].to(model_adapter.device) - ).cpu() - if isinstance(slice_features, tuple): # model with more than one output - slice_features = torch.cat(slice_features, dim=-1) - - yield b_idx, slice_features.numpy() +from featureforest.utils.pipeline_prediction import run_prediction_pipeline def apply_postprocessing( @@ -149,19 +28,19 @@ def apply_postprocessing( use_sam_predictor: bool, use_sam_autoseg: bool, iou_threshold: float, -) -> np.ndarray: +) -> dict: post_masks = {} - # if not use_sam_predictor and not use_sam_autoseg: + # simple post-processing mask = postprocess( segmentation_image, smoothing_iterations, area_threshold, area_is_absolute ) - post_masks["Simple"] = mask + post_masks["post_simple"] = mask if use_sam_predictor: mask = postprocess_with_sam( segmentation_image, smoothing_iterations, area_threshold, area_is_absolute ) - post_masks["SAMPredictor"] = mask + post_masks["post_sam"] = mask if use_sam_autoseg: sam_auto_masks = get_sam_auto_masks(input_image) @@ -173,7 +52,7 @@ def apply_postprocessing( area_threshold, area_is_absolute, ) - post_masks["SAMAutoSegmentation"] = mask + post_masks["post_sam_auto"] = mask return post_masks @@ -183,6 +62,7 @@ def main( rf_model_file: str, output_dir: str, model_name: str = "SAM2_Large", + no_patching: bool = False, smoothing_iterations: int = 25, area_threshold: int = 50, use_sam_predictor: bool = True, @@ -198,16 +78,14 @@ def main( # result folder segmentation_dir = Path(output_dir) segmentation_dir.mkdir(parents=True, exist_ok=True) - prediction_dir = segmentation_dir.joinpath("Prediction") + prediction_dir = segmentation_dir.joinpath("prediction") prediction_dir.mkdir(exist_ok=True) + simple_post_dir = segmentation_dir.joinpath("post_simple") + simple_post_dir.mkdir(parents=True, exist_ok=True) + sam_post_dir = segmentation_dir.joinpath("post_sam") + sam_post_dir.mkdir(parents=True, exist_ok=True) - # get input image dims - input_stack = Image.open(data_path) - num_slices = input_stack.n_frames - img_height = input_stack.height - img_width = input_stack.width - print(f"input_stack: {num_slices}, {img_height}, {img_width}") - + # load rf model with open(rf_model_path, mode="rb") as f: model_data = pickle.load(f) # compatibility check for old format rf model @@ -221,13 +99,17 @@ def main( rf_model.set_params(verbose=0) print(rf_model) + # get stack dims + lazy_stack = pims.open(input_file) + img_height, img_width = lazy_stack.frame_shape + # list of available models available_models = get_available_models() assert model_name in available_models, ( f"Couldn't find {model_name} in available models\n{available_models}." ) - model_adapter = get_model(model_name, img_height, img_width) + model_adapter.no_patching = no_patching patch_size = model_adapter.patch_size overlap = model_adapter.overlap print(f"patch_size: {patch_size}, overlap: {overlap}") @@ -238,41 +120,23 @@ def main( use_sam_autoseg = False sam_autoseg_iou_threshold = 0.35 - # ### Prediction + # ### Prediction ### tic = time.perf_counter() - for i, page in enumerate(ImageSequence.Iterator(input_stack)): - print(f"\nslice {i + 1}") - slice_img = np.array(page.convert("RGB")) - procs = [] - # random forest prediction happens per batch of extracted features - # in a separate process. - with mp.Manager() as manager: - result_dict = manager.dict() - for b_idx, patch_features in get_slice_features(slice_img, model_adapter): - proc = mp.Process( - target=predict_patches, - args=(patch_features, rf_model, model_adapter, b_idx, result_dict), - ) - procs.append(proc) - proc.start() - # wait until all processes are done - for p in procs: - if p.is_alive: - p.join() - # collect results from each process - batch_indices = sorted(result_dict.keys()) - patch_masks = [result_dict[b] for b in batch_indices] - patch_masks = np.vstack(patch_masks) - slice_mask = get_image_mask( - patch_masks, img_height, img_width, patch_size, overlap - ) - - img = Image.fromarray(slice_mask) - img.save(prediction_dir.joinpath(f"slice_{i:04}_prediction.tiff")) + for slice_mask, idx, total in run_prediction_pipeline( + input_stack=input_file, + model_adapter=model_adapter, + rf_model=rf_model, + ): + print(f"\nslice {idx + 1} / {total}") + tifffile.imwrite( + prediction_dir.joinpath(f"slice_{idx:04}_prediction.tiff"), slice_mask + ) if do_postprocess: + print("\nrunning post-processing...") + slice_img = lazy_stack[idx] post_masks = apply_postprocessing( - slice_img, + slice_img, # type: ignore slice_mask, smoothing_iterations, area_threshold, @@ -283,10 +147,9 @@ def main( ) # save results for name, mask in post_masks.items(): - img = Image.fromarray(mask) seg_dir = segmentation_dir.joinpath(name) - seg_dir.mkdir(exist_ok=True) - img.save(seg_dir.joinpath(f"slice_{i:04}_{name}.tiff")) + # seg_dir.mkdir(exist_ok=True) + tifffile.imwrite(seg_dir.joinpath(f"slice_{idx:04}_{name}.tiff"), mask) print(f"total elapsed time: {(time.perf_counter() - tic)} seconds") @@ -303,6 +166,11 @@ def main( choices=get_available_models(), help="Name of the model for feature extraction", ) + parser.add_argument( + "--no_patching", + action="store_true", + help="If true, no patching will be used during feature extraction", + ) parser.add_argument( "--smoothing_iterations", default=25, @@ -319,7 +187,7 @@ def main( "--use_sam_predictor", default=True, action="store_true", - help="To use SAM2 for generating final masks", + help="uses SAM2 for generating final masks", ) args = parser.parse_args() @@ -329,6 +197,7 @@ def main( rf_model_file=args.rf_model, output_dir=args.outdir, model_name=args.feat_model, + no_patching=args.no_patching, smoothing_iterations=args.smoothing_iterations, area_threshold=args.area_threshold, use_sam_predictor=args.use_sam_predictor, From e959bf1fdd3e120a1326ad3bf5418a228f3aea0e Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sun, 15 Jun 2025 16:24:11 +0200 Subject: [PATCH 18/35] removed unused import in run_pipeline --- run_pipeline.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/run_pipeline.py b/run_pipeline.py index 1118295..e3e7da3 100755 --- a/run_pipeline.py +++ b/run_pipeline.py @@ -8,8 +8,6 @@ import tifffile from featureforest.models import get_available_models, get_model - -# from featureforest.models.SAM import SAMAdapter from featureforest.postprocess import ( get_sam_auto_masks, postprocess, From b5809683fd7d8267a7ae4c91691e56d3e6c42bf3 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sun, 15 Jun 2025 16:30:43 +0200 Subject: [PATCH 19/35] fixed models params: image height & width as int --- src/featureforest/models/Cellpose/model.py | 2 +- src/featureforest/models/DinoV2/model.py | 10 ++---- src/featureforest/models/MobileSAM/model.py | 30 +++++++----------- src/featureforest/models/SAM/model.py | 12 +++---- src/featureforest/models/SAM2/model.py | 35 +++++++-------------- 5 files changed, 31 insertions(+), 58 deletions(-) diff --git a/src/featureforest/models/Cellpose/model.py b/src/featureforest/models/Cellpose/model.py index 37e16df..74642fb 100755 --- a/src/featureforest/models/Cellpose/model.py +++ b/src/featureforest/models/Cellpose/model.py @@ -25,7 +25,7 @@ def forward(self, x): return out -def get_model(img_height: float, img_width: float, *args, **kwargs) -> CellposeAdapter: +def get_model(img_height: int, img_width: int, *args, **kwargs) -> CellposeAdapter: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"running on {device}") # init the model diff --git a/src/featureforest/models/DinoV2/model.py b/src/featureforest/models/DinoV2/model.py index f003e93..0c2df33 100644 --- a/src/featureforest/models/DinoV2/model.py +++ b/src/featureforest/models/DinoV2/model.py @@ -1,14 +1,10 @@ -from typing import Tuple - import torch # from featureforest.utils.downloader import download_model from .adapter import DinoV2Adapter -def get_model( - img_height: float, img_width: float, *args, **kwargs -) -> DinoV2Adapter: +def get_model(img_height: int, img_width: int, *args, **kwargs) -> DinoV2Adapter: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"running on {device}") # get the pretrained model @@ -18,8 +14,6 @@ def get_model( model.eval() # create the model adapter - dino_model_adapter = DinoV2Adapter( - model, img_height, img_width, device - ) + dino_model_adapter = DinoV2Adapter(model, img_height, img_width, device) return dino_model_adapter diff --git a/src/featureforest/models/MobileSAM/model.py b/src/featureforest/models/MobileSAM/model.py index fc225c7..5f7e301 100644 --- a/src/featureforest/models/MobileSAM/model.py +++ b/src/featureforest/models/MobileSAM/model.py @@ -1,28 +1,22 @@ -from typing import Tuple - import torch - -from .tiny_vit_sam import TinyViT from segment_anything.modeling import MaskDecoder, PromptEncoder, Sam, TwoWayTransformer from featureforest.utils.downloader import download_model + from .adapter import MobileSAMAdapter +from .tiny_vit_sam import TinyViT -def get_model( - img_height: float, img_width: float, *args, **kwargs -) -> MobileSAMAdapter: +def get_model(img_height: int, img_width: int, *args, **kwargs) -> MobileSAMAdapter: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"running on {device}") # get the model model = setup_model().to(device) # download model's weights - model_url = \ + model_url = ( "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt" - model_file = download_model( - model_url=model_url, - model_name="mobile_sam.pt" ) + model_file = download_model(model_url=model_url, model_name="mobile_sam.pt") if model_file is None: raise ValueError(f"Could not download the model from {model_url}.") @@ -32,9 +26,7 @@ def get_model( model.eval() # create the model adapter - sam_model_adapter = MobileSAMAdapter( - model, img_height, img_width, device - ) + sam_model_adapter = MobileSAMAdapter(model, img_height, img_width, device) return sam_model_adapter @@ -46,18 +38,20 @@ def setup_model() -> Sam: image_embedding_size = image_size // vit_patch_size mobile_sam = Sam( image_encoder=TinyViT( - img_size=1024, in_chans=3, num_classes=1000, + img_size=1024, + in_chans=3, + num_classes=1000, embed_dims=[64, 128, 160, 320], depths=[2, 2, 6, 2], num_heads=[2, 4, 5, 10], window_sizes=[7, 7, 14, 7], - mlp_ratio=4., - drop_rate=0., + mlp_ratio=4.0, + drop_rate=0.0, drop_path_rate=0.0, use_checkpoint=False, mbconv_expand_ratio=4.0, local_conv_size=3, - layer_lr_decay=0.8 + layer_lr_decay=0.8, ), prompt_encoder=PromptEncoder( embed_dim=prompt_embed_dim, diff --git a/src/featureforest/models/SAM/model.py b/src/featureforest/models/SAM/model.py index daf0f2a..1865743 100644 --- a/src/featureforest/models/SAM/model.py +++ b/src/featureforest/models/SAM/model.py @@ -1,11 +1,10 @@ import torch - -from segment_anything.modeling import Sam from segment_anything import sam_model_registry +from segment_anything.modeling import Sam from featureforest.utils.downloader import download_model -from .adapter import SAMAdapter +from .adapter import SAMAdapter MODEL_URLS = { "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", @@ -27,16 +26,13 @@ def get_model( - img_height: float, img_width: float, model_type: str = "vit_h", *args, **kwargs + img_height: int, img_width: int, model_type: str = "vit_h", *args, **kwargs ) -> SAMAdapter: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"running on {device}") # download model's weights model_url = MODEL_URLS[model_type] - model_file = download_model( - model_url=model_url, - model_name=MODEL_FNAMES[model_type] - ) + model_file = download_model(model_url=model_url, model_name=MODEL_FNAMES[model_type]) if model_file is None: raise ValueError(f"Could not download the model from {model_url}.") diff --git a/src/featureforest/models/SAM2/model.py b/src/featureforest/models/SAM2/model.py index 679b6c2..45f17a6 100644 --- a/src/featureforest/models/SAM2/model.py +++ b/src/featureforest/models/SAM2/model.py @@ -1,34 +1,27 @@ -from pathlib import Path - import torch - -from sam2.modeling.sam2_base import SAM2Base from sam2.build_sam import build_sam2 +from sam2.modeling.sam2_base import SAM2Base -from featureforest.utils.downloader import download_model from featureforest.models.SAM2.adapter import SAM2Adapter +from featureforest.utils.downloader import download_model -def get_large_model( - img_height: float, img_width: float, *args, **kwargs -) -> SAM2Adapter: +def get_large_model(img_height: int, img_width: int, *args, **kwargs) -> SAM2Adapter: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"running on {device}") # download model's weights - model_url = \ + model_url = ( "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt" - model_file = download_model( - model_url=model_url, - model_name="sam2.1_hiera_large.pt" ) + model_file = download_model(model_url=model_url, model_name="sam2.1_hiera_large.pt") if model_file is None: raise ValueError(f"Could not download the model from {model_url}.") # init the model model: SAM2Base = build_sam2( - config_file= "configs/sam2.1/sam2.1_hiera_l.yaml", + config_file="configs/sam2.1/sam2.1_hiera_l.yaml", ckpt_path=model_file, - device="cpu" + device="cpu", ) # to save some GPU memory, only put the encoder part on GPU sam_image_encoder = model.image_encoder.to(device) @@ -42,26 +35,22 @@ def get_large_model( return sam2_model_adapter -def get_base_model( - img_height: float, img_width: float, *args, **kwargs -) -> SAM2Adapter: +def get_base_model(img_height: float, img_width: float, *args, **kwargs) -> SAM2Adapter: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"running on {device}") # download model's weights - model_url = \ - "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt" + model_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt" model_file = download_model( - model_url=model_url, - model_name="sam2.1_hiera_base_plus.pt" + model_url=model_url, model_name="sam2.1_hiera_base_plus.pt" ) if model_file is None: raise ValueError(f"Could not download the model from {model_url}.") # init the model model: SAM2Base = build_sam2( - config_file= "configs/sam2.1/sam2.1_hiera_b+.yaml", + config_file="configs/sam2.1/sam2.1_hiera_b+.yaml", ckpt_path=model_file, - device="cpu" + device="cpu", ) # to save some GPU memory, only put the encoder part on GPU sam_image_encoder = model.image_encoder.to(device) From b7261df4bd439b8a632c2d52bf618bd5401c4994 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sun, 15 Jun 2025 17:01:46 +0200 Subject: [PATCH 20/35] ignored some typing --- src/featureforest/models/DinoV2/adapter.py | 2 +- src/featureforest/models/SAM/adapter.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/featureforest/models/DinoV2/adapter.py b/src/featureforest/models/DinoV2/adapter.py index 0544c63..b34e9e3 100644 --- a/src/featureforest/models/DinoV2/adapter.py +++ b/src/featureforest/models/DinoV2/adapter.py @@ -56,7 +56,7 @@ def get_features_patches(self, in_patches: Tensor) -> Tensor: 1, return_class_token=False, reshape=True, - )[0] + )[0] # type: ignore # get non-overlapped feature patches feature_patches = get_nonoverlapped_patches( diff --git a/src/featureforest/models/SAM/adapter.py b/src/featureforest/models/SAM/adapter.py index c6fa29c..80829f8 100644 --- a/src/featureforest/models/SAM/adapter.py +++ b/src/featureforest/models/SAM/adapter.py @@ -27,10 +27,10 @@ def __init__( self.name = name # we need sam image encoder part self.encoder = image_encoder - self.encoder_num_channels = 256 + self.encoder_num_channels: int = 256 # NOTE: The parameter below matches the SAM model's encoder embedding dimension. - self.embed_layer_num_channels = image_encoder.patch_embed.proj.out_channels + self.embed_layer_num_channels: int = image_encoder.patch_embed.proj.out_channels self._set_patch_size() self.device = device @@ -64,7 +64,7 @@ def get_features_patches(self, in_patches: Tensor) -> tuple[Tensor, Tensor]: # embed_output: b,64,64,1280 -> b,1280,64,64 embed_output = self.encoder.patch_embed( self.input_transforms(in_patches) - ).permute(0, 3, 1, 2) + ).permute(0, 3, 1, 2) # type: ignore # get non-overlapped feature patches out_feature_patches = get_nonoverlapped_patches( From bc624a29210dce84b609a32ec71ad3744de9425b Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sun, 15 Jun 2025 17:06:08 +0200 Subject: [PATCH 21/35] updated requirements --- requirements.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9137ec2..a5bc997 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,9 @@ pynrrd qtpy magicgui napari[all] -h5py +tifffile +zarr>=2,<3 +pims pooch tqdm>=4.66.1 iopath>=0.1.10 From b3d672030b791692d48b994f912c44a1c358db9d Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sun, 15 Jun 2025 17:42:19 +0200 Subject: [PATCH 22/35] fixed image height & width as int --- src/featureforest/models/__init__.py | 2 +- src/featureforest/models/util.py | 105 --------------------------- 2 files changed, 1 insertion(+), 106 deletions(-) delete mode 100644 src/featureforest/models/util.py diff --git a/src/featureforest/models/__init__.py b/src/featureforest/models/__init__.py index ec870c6..50a4419 100644 --- a/src/featureforest/models/__init__.py +++ b/src/featureforest/models/__init__.py @@ -25,7 +25,7 @@ def get_available_models() -> list[str]: def get_model( - model_name: str, img_height: float, img_width: float, *args, **kwargs + model_name: str, img_height: int, img_width: int, *args, **kwargs ) -> BaseModelAdapter: """Returns the requested model adapter. diff --git a/src/featureforest/models/util.py b/src/featureforest/models/util.py deleted file mode 100644 index 3df7768..0000000 --- a/src/featureforest/models/util.py +++ /dev/null @@ -1,105 +0,0 @@ -import os -from pathlib import Path -from typing import Optional, Union - -import imageio.v3 as imageio -import numpy as np - -from ..utils.extract import extract_embeddings_to_file, get_stack_dims -from . import _MODELS_DICT, get_model - - -def extract_features( - image: np.ndarray, - output_path: Union[os.PathLike, str], - model_name: str = "SAM2_Large", - image_height: Optional[int] = None, - image_width: Optional[int] = None, -) -> str: - """Extracts features for the chosen model. - - Args: - image: The input image. - output_path: The filepath where the extracted features will be stored. - model_name: The choice of model that will be used for feature extraction. - By default, extracts features for `SAM2_Large`. - image_height: The height of input image. By default, extracted from the input image. - image_width: The width of input image. By default, extracted from the input image. - - Returns: - The filepath where the extracted features have been stored. - """ - if image_height is None and image_width is None: - # - Get the height and width of the input image. - _, image_height, image_width = get_stack_dims(image) - - # Step 1: Get the desired model adapter. - model_adapter = get_model( - model_name=model_name, img_height=image_height, img_width=image_width - ) - - # - Transform the inputs - transformed_image = model_adapter.input_transforms(image) - - # Step 2: Run the feature extraction step. - if os.path.splitext(output_path)[-1].lower() != ".hdf5": - # In this case, we assume that it's a filepath without extension and give it the desired one. - output_path = str(Path(output_path).with_suffix(".hdf5")) - - extractor_generator = extract_embeddings_to_file( - image=transformed_image, - storage_path=output_path, - model_adapter=model_adapter, - ) - - # Step 3: Run the extractor generator till the end - _ = list(extractor_generator) - - return output_path - - -def main(): - """@private""" - - import argparse - - available_models = list(_MODELS_DICT.keys()) - available_models = ", ".join(available_models) - - parser = argparse.ArgumentParser(description="Extract features for a chosen model.") - parser.add_argument( - "-i", - "--input_path", - type=str, - required=True, - help="The filepath to the image data. Supports all data types that can be read by imageio (eg. tif, png, ...).", - ) - parser.add_argument( - "-o", - "--output_path", - type=str, - required=True, - help="The filepath to store the extracted features. The current supports store features in a 'hdf5' file.", - ) - parser.add_argument( - "--model_choice", - type=str, - default="SAM2_Large", - help=f"The choice of vision foundation model that will be used, one of ({available_models}). " - "By default, extracts features for 'SAM2_Large'.", - ) - - args = parser.parse_args() - - # Load the image. - # TODO: Currently supports a simple setup. We can make it more complicated to support other bioformats later. - image = imageio.imread(args.input_path) - - # Extract the features. - output_path = extract_features( - image=image, output_path=args.output_path, model_name=args.model_choice - ) - - print( - f"The features of '{args.model_choice}' have been extracted at '{os.path.abspath(output_path)}'." - ) From 32f4a26158088e1e4bff583ce8fe1fb638e26b52 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sun, 15 Jun 2025 18:49:15 +0200 Subject: [PATCH 23/35] fixed dataset & get_model_ready_image image dimention problem --- src/featureforest/utils/data.py | 7 ++----- src/featureforest/utils/dataset.py | 3 ++- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/featureforest/utils/data.py b/src/featureforest/utils/data.py index 9981947..17066aa 100644 --- a/src/featureforest/utils/data.py +++ b/src/featureforest/utils/data.py @@ -10,7 +10,7 @@ def get_model_ready_image(image: np.ndarray) -> torch.Tensor: """Convert the input image to a torch tensor and normalize it. Args: - image (np.ndarray): Input image to be converted. + image (np.ndarray): Input image to be converted (H, W, C). Returns: torch.Tensor: The input image as a torch tensor, normalized to [0, 1]. """ @@ -22,13 +22,10 @@ def get_model_ready_image(image: np.ndarray) -> torch.Tensor: _max = img_data.max() img_data = (img_data - _min) / (_max - _min) # for image encoders, the input image must be in RGB. - # if not is_stacked(img_data.numpy()): - # # add a batch dim - # img_data = img_data.unsqueeze(0) if is_image_rgb(img_data.numpy()): # it's already RGB img_data = img_data[..., :3] # discard the alpha channel (in case of PNG). - img_data = img_data.permute([3, 1, 2]) # make it channel first + img_data = img_data.permute([2, 0, 1]) # make it channel first else: # make it RGB by repeating the single channel img_data = img_data.unsqueeze(0).expand(3, -1, -1) diff --git a/src/featureforest/utils/dataset.py b/src/featureforest/utils/dataset.py index 19b2243..77ad0fe 100644 --- a/src/featureforest/utils/dataset.py +++ b/src/featureforest/utils/dataset.py @@ -10,6 +10,7 @@ from featureforest.utils.data import ( get_model_ready_image, + is_stacked, patchify, ) @@ -39,7 +40,7 @@ def __init__( # images are already loaded into a numpy array self.image_source = images # add slice dimension if not present - if self.image_source.ndim == 2: + if not is_stacked(self.image_source): self.image_source = self.image_source[np.newaxis, ...] elif isinstance(images, str | Path): From f2b04d684e7337230339e635932216a624b42ce8 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sun, 15 Jun 2025 18:49:53 +0200 Subject: [PATCH 24/35] added dataset test --- tests/test_dataset.py | 114 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) create mode 100644 tests/test_dataset.py diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 0000000..345d3de --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,114 @@ +import numpy as np +import pytest +import torch + +from featureforest.utils.dataset import FFImageDataset + + +class DummyImageSequence: + """Dummy class to mock pims.ImageSequence for testing.""" + + def __init__(self, images): + self._images = images + self.frame_shape = images[0].shape + + def __len__(self): + return len(self._images) + + def __getitem__(self, idx): + return self._images[idx] + + def __iter__(self): + return iter(self._images) + + +@pytest.fixture +def dummy_numpy_stack(): + # shape: (3, 32, 32) + return np.random.randint(0, 255, (3, 32, 32), dtype=np.uint8) + + +@pytest.fixture +def dummy_numpy_single(): + # shape: (32, 32) + return np.random.randint(0, 255, (32, 32), dtype=np.uint8) + + +def test_init_with_numpy_stack(dummy_numpy_stack, monkeypatch): + ds = FFImageDataset(dummy_numpy_stack, no_patching=True) + assert ds.num_images == 3 + assert ds.image_shape == (32, 32) + items = list(ds) + assert len(items) == 3 + for img, idx in items: + assert isinstance(img, torch.Tensor) + assert img.shape[-2:] == (32, 32) + assert idx[0] in [0, 1, 2] + + +def test_init_with_numpy_single(dummy_numpy_single): + ds = FFImageDataset(dummy_numpy_single, no_patching=True) + assert ds.num_images == 1 + assert ds.image_shape == (32, 32) + items = list(ds) + assert len(items) == 1 + img, idx = items[0] + assert isinstance(img, torch.Tensor) + assert img.shape[-2:] == (32, 32) + assert idx[0] == 0 + + +def test_iter_with_patching(dummy_numpy_stack, monkeypatch): + # Patch patchify to split into 2 patches + monkeypatch.setattr( + "featureforest.utils.dataset.patchify", + lambda img, sz, ov: [img[..., :16, :16], img[..., 16:, 16:]], + ) + ds = FFImageDataset(dummy_numpy_stack, no_patching=False) + items = list(ds) + assert len(items) == 6 # 3 images * 2 patches each + for patch, idx in items: + assert isinstance(patch, torch.Tensor) + assert idx.shape == (2,) + + +def test_init_with_invalid_type(): + with pytest.raises(ValueError): + FFImageDataset(12345) + + +def test_init_with_empty_dir(tmp_path): + empty_dir = tmp_path / "empty" + empty_dir.mkdir() + with pytest.raises(ValueError): + FFImageDataset(empty_dir) + + +def test_init_with_image_files(tmp_path, monkeypatch): + # Create dummy image files + for ext in ["tiff", "tif", "png", "jpg"]: + (tmp_path / f"img1.{ext}").write_bytes(np.random.bytes(100)) + # Patch pims.ImageSequence to DummyImageSequence + monkeypatch.setattr( + "featureforest.utils.dataset.pims.ImageSequence", + lambda files: DummyImageSequence([np.zeros((32, 32), dtype=np.uint8)] * 4), + ) + ds = FFImageDataset(tmp_path, no_patching=True) + assert ds.num_images == 4 + assert ds.image_shape == (32, 32) + items = list(ds) + assert len(items) == 4 + + +def test_image_shape_none(): + ds = FFImageDataset(np.zeros((1, 2, 2)), no_patching=True) + ds.image_source = None + with pytest.raises(ValueError): + _ = ds.image_shape + + +def test_iter_no_image_source(): + ds = FFImageDataset(np.zeros((1, 2, 2)), no_patching=True) + ds.image_source = None + with pytest.raises(ValueError): + next(iter(ds)) From 0104b10c44ca1534bcb545fd336a802d9bdf063d Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sun, 15 Jun 2025 19:04:00 +0200 Subject: [PATCH 25/35] fixed adapter: concat double output into one tensor --- src/featureforest/models/MobileSAM/adapter.py | 6 ++++-- src/featureforest/models/SAM/adapter.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/featureforest/models/MobileSAM/adapter.py b/src/featureforest/models/MobileSAM/adapter.py index 11a2d14..efc7259 100644 --- a/src/featureforest/models/MobileSAM/adapter.py +++ b/src/featureforest/models/MobileSAM/adapter.py @@ -46,7 +46,7 @@ def __init__( ] ) - def get_features_patches(self, in_patches: Tensor) -> tuple[Tensor, Tensor]: + def get_features_patches(self, in_patches: Tensor) -> Tensor: # get the mobile-sam encoder and embedding layer outputs with torch.no_grad(): output, embed_output, _ = self.encoder(self.input_transforms(in_patches)) @@ -58,8 +58,10 @@ def get_features_patches(self, in_patches: Tensor) -> tuple[Tensor, Tensor]: embed_feature_patches = get_nonoverlapped_patches( self.embedding_transform(embed_output.cpu()), self.patch_size, self.overlap ) + # concat both features together on channel dimension + output = torch.cat([out_feature_patches, embed_feature_patches], dim=-1) - return out_feature_patches, embed_feature_patches + return output def get_total_output_channels(self) -> int: return self.encoder_num_channels + self.embed_layer_num_channels diff --git a/src/featureforest/models/SAM/adapter.py b/src/featureforest/models/SAM/adapter.py index 80829f8..6c74649 100644 --- a/src/featureforest/models/SAM/adapter.py +++ b/src/featureforest/models/SAM/adapter.py @@ -56,7 +56,7 @@ def __init__( ] ) - def get_features_patches(self, in_patches: Tensor) -> tuple[Tensor, Tensor]: + def get_features_patches(self, in_patches: Tensor) -> Tensor: # get the mobile-sam encoder and embedding layer outputs with torch.no_grad(): # output: b,256,64,64 @@ -73,8 +73,10 @@ def get_features_patches(self, in_patches: Tensor) -> tuple[Tensor, Tensor]: embed_feature_patches = get_nonoverlapped_patches( self.embedding_transform(embed_output.cpu()), self.patch_size, self.overlap ) + # concat both features together on channel dimension + output = torch.cat([out_feature_patches, embed_feature_patches], dim=-1) - return out_feature_patches, embed_feature_patches + return output def get_total_output_channels(self) -> int: return self.encoder_num_channels + self.embed_layer_num_channels From 8b9cc933ed59b58552c94a1531ec01c6d651738f Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sun, 15 Jun 2025 19:04:35 +0200 Subject: [PATCH 26/35] updated mobilesam test --- .../embedding_extraction.py | 72 ++++----- .../test_mobilesam_adapter.py | 142 +++++++++--------- 2 files changed, 109 insertions(+), 105 deletions(-) diff --git a/tests/model_adapter_tests/embedding_extraction.py b/tests/model_adapter_tests/embedding_extraction.py index 54bae5e..3fc13dd 100644 --- a/tests/model_adapter_tests/embedding_extraction.py +++ b/tests/model_adapter_tests/embedding_extraction.py @@ -1,34 +1,38 @@ -# from tempfile import TemporaryDirectory - -# import h5py - -# from featureforest.utils.extract import extract_embeddings_to_file - - -# def check_embedding_extraction( -# test_image, model_adapter, expected_output_shape, expected_slices -# ): -# with TemporaryDirectory() as tmp_dir: -# tmp_file = tmp_dir + "/tmp.h5" - -# extractor_generator = extract_embeddings_to_file( -# image=test_image, storage_path=tmp_file, model_adapter=model_adapter -# ) - -# # Run the extractor generator till the end -# _ = list(extractor_generator) - -# with h5py.File(tmp_file, "r") as read_storage: -# slices = list(read_storage.keys()) -# assert len(slices) == expected_slices, ( -# f"Unexpected number of slices: {len(slices)}, expected: {expected_slices}" -# ) -# for slice in slices: -# slice_key = str(slice) -# slice_dataset = read_storage[slice_key].get(model_adapter.name) -# assert slice_dataset is not None, ( -# f"The dataset for slice {slice_key} is empty" -# ) -# assert slice_dataset.shape == expected_output_shape, ( -# f"Unexpected dataset shape: {slice_dataset.shape}, expected: {expected_output_shape}" -# ) +from tempfile import TemporaryDirectory + +import zarr +import zarr.storage + +from featureforest.utils.extract import extract_embeddings_to_file + + +def check_embedding_extraction( + test_image, model_adapter, expected_output_shape, expected_slices +): + with TemporaryDirectory() as tmp_dir: + tmp_file = tmp_dir + "/tmp.zarr" + + extractor_generator = extract_embeddings_to_file( + image=test_image, storage_path=tmp_file, model_adapter=model_adapter + ) + + # Run the extractor generator till the end + _ = list(extractor_generator) + + read_storage: zarr.Group = zarr.open(tmp_file, mode="r") # type: ignore + slices = list(read_storage.keys()) + assert len(slices) == expected_slices, ( + f"Unexpected number of slices: {len(slices)}, expected: {expected_slices}" + ) + for slice_idx in slices: + slice_key = str(slice_idx) + slice_dataset = read_storage[slice_key]["features"] + assert slice_dataset is not None, ( + f"The dataset for slice {slice_key} is empty" + ) + assert slice_dataset.shape == expected_output_shape, ( + f"Unexpected dataset shape: {slice_dataset.shape}, " + f"expected: {expected_output_shape}" + ) + + read_storage.store.close() diff --git a/tests/model_adapter_tests/test_mobilesam_adapter.py b/tests/model_adapter_tests/test_mobilesam_adapter.py index 0cfb159..284e1b9 100644 --- a/tests/model_adapter_tests/test_mobilesam_adapter.py +++ b/tests/model_adapter_tests/test_mobilesam_adapter.py @@ -1,84 +1,84 @@ -# import numpy as np -# import pytest -# import torch -# import torch.nn as nn +import numpy as np +import pytest +import torch +import torch.nn as nn +from embedding_extraction import check_embedding_extraction -# from featureforest.models.MobileSAM import MobileSAMAdapter, get_model -# from featureforest.utils.data import get_stack_dims -# from embedding_extraction import check_embedding_extraction +from featureforest.models.MobileSAM import MobileSAMAdapter, get_model +from featureforest.utils.data import get_stack_dims -# class MockMobileSAMEncoder(nn.Module): -# def __init__(self): -# super().__init__() -# self.image_encoder = self.mock_encode -# self.encoder_num_channels = 256 -# self.embed_layer_num_channels = 64 +class MockMobileSAMEncoder(nn.Module): + def __init__(self): + super().__init__() + self.image_encoder = self.mock_encode + self.encoder_num_channels = 256 + self.embed_layer_num_channels = 64 -# def mock_encode(self, x): -# batch_size = x.shape[0] -# output = torch.ones( -# batch_size, -# self.encoder_num_channels, -# self.embed_layer_num_channels, -# self.embed_layer_num_channels, -# ) -# embed_output = torch.ones( -# batch_size, -# self.embed_layer_num_channels, -# self.encoder_num_channels, -# self.encoder_num_channels, -# ) -# return output, embed_output, None + def mock_encode(self, x): + batch_size = x.shape[0] + output = torch.ones( + batch_size, + self.encoder_num_channels, + self.embed_layer_num_channels, + self.embed_layer_num_channels, + ) + embed_output = torch.ones( + batch_size, + self.embed_layer_num_channels, + self.encoder_num_channels, + self.encoder_num_channels, + ) + return output, embed_output, None -# def get_mock_model(img_height: float, img_width: float) -> MobileSAMAdapter: -# model = MockMobileSAMEncoder() -# device = torch.device("cpu") -# sam_model_adapter = MobileSAMAdapter(model, img_height, img_width, device) -# return sam_model_adapter +def get_mock_model(img_height: int, img_width: int) -> MobileSAMAdapter: + model = MockMobileSAMEncoder() + device = torch.device("cpu") + sam_model_adapter = MobileSAMAdapter(model, img_height, img_width, device) + return sam_model_adapter -# @pytest.mark.slow() -# @pytest.mark.parametrize( -# "test_patch", -# [ -# torch.ones((1, 3, 128, 128)), -# torch.ones((3, 3, 128, 128)), -# torch.ones((8, 3, 128, 128)), -# torch.ones((8, 3, 256, 256)), -# torch.ones((8, 3, 512, 512)) -# ], -# ) -# def test_mock_adapter(test_patch: np.ndarray): -# real_adapter = get_model(512, 512) -# mock_adapter = get_mock_model(512, 512) +@pytest.mark.slow() +@pytest.mark.parametrize( + "test_patch", + [ + torch.ones((1, 3, 128, 128)), + torch.ones((3, 3, 128, 128)), + torch.ones((8, 3, 128, 128)), + torch.ones((8, 3, 256, 256)), + torch.ones((8, 3, 512, 512)), + ], +) +def test_mock_adapter(test_patch: np.ndarray): + real_adapter = get_model(512, 512) + mock_adapter = get_mock_model(512, 512) -# transformed_input_patch_real = real_adapter.input_transforms(test_patch) -# transformed_input_patch_mock = mock_adapter.input_transforms(test_patch) + transformed_input_patch_real = real_adapter.input_transforms(test_patch) + transformed_input_patch_mock = mock_adapter.input_transforms(test_patch) -# result_real = real_adapter.encoder(transformed_input_patch_real) -# mock_result = mock_adapter.encoder(transformed_input_patch_mock) + result_real = real_adapter.encoder(transformed_input_patch_real) + mock_result = mock_adapter.encoder(transformed_input_patch_mock) -# assert len(result_real) == len(mock_result) -# assert result_real[0].shape == mock_result[0].shape -# assert result_real[1].shape == mock_result[1].shape + assert len(result_real) == len(mock_result) + assert result_real[0].shape == mock_result[0].shape + assert result_real[1].shape == mock_result[1].shape -# @pytest.mark.parametrize( -# "test_image, expected_output_shape, expected_slices", -# [ -# (np.ones((256, 256)), (9, 128, 128, 320), 1), # 2D -# (np.ones((256, 256, 3)), (9, 128, 128, 320), 1), # 2D RGB -# (np.ones((2, 256, 256)), (9, 128, 128, 320), 2), # 3D -# (np.ones((2, 256, 256, 3)), (9, 128, 128, 320), 2) # 3D RGB -# ], -# ) -# def test_mobilesam_embedding_extraction( -# test_image: np.ndarray, expected_output_shape: tuple, expected_slices: int -# ): -# num_slices, img_height, img_width = get_stack_dims(test_image) -# model_adapter = get_mock_model(img_height, img_width) -# check_embedding_extraction( -# test_image, model_adapter, expected_output_shape, expected_slices -# ) +@pytest.mark.parametrize( + "test_image, expected_output_shape, expected_slices", + [ + (np.ones((256, 256)), (4, 192, 192, 320), 1), # 2D + (np.ones((256, 256, 3)), (4, 192, 192, 320), 1), # 2D RGB + (np.ones((2, 256, 256)), (4, 192, 192, 320), 2), # 3D + (np.ones((2, 256, 256, 3)), (4, 192, 192, 320), 2), # 3D RGB + ], +) +def test_mobilesam_embedding_extraction( + test_image: np.ndarray, expected_output_shape: tuple, expected_slices: int +): + num_slices, img_height, img_width = get_stack_dims(test_image) + model_adapter = get_mock_model(img_height, img_width) + check_embedding_extraction( + test_image, model_adapter, expected_output_shape, expected_slices + ) From f4c62398e9aefc4e53383c1a78365bbaa8c41825 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sun, 15 Jun 2025 19:32:14 +0200 Subject: [PATCH 27/35] updated dino adapter test --- .../model_adapter_tests/test_dino_adapter.py | 128 +++++++++--------- 1 file changed, 64 insertions(+), 64 deletions(-) diff --git a/tests/model_adapter_tests/test_dino_adapter.py b/tests/model_adapter_tests/test_dino_adapter.py index 839902e..99a6401 100644 --- a/tests/model_adapter_tests/test_dino_adapter.py +++ b/tests/model_adapter_tests/test_dino_adapter.py @@ -1,77 +1,77 @@ -# import numpy as np -# import pytest -# import torch -# import torch.nn as nn +import numpy as np +import pytest +import torch +import torch.nn as nn +from embedding_extraction import check_embedding_extraction -# from featureforest.models.DinoV2 import DinoV2Adapter, get_model -# from featureforest.utils.data import get_stack_dims -# from embedding_extraction import check_embedding_extraction +from featureforest.models.DinoV2 import DinoV2Adapter, get_model +from featureforest.utils.data import get_stack_dims -# class MockDinoEncoder(nn.Module): -# def __init__(self): -# super().__init__() -# self.dino_patch_size = 14 -# self.dino_out_channels = 384 -# self.height = 70 -# self.width = 70 +class MockDinoEncoder(nn.Module): + def __init__(self): + super().__init__() + self.dino_patch_size = 14 + self.dino_out_channels = 384 + self.height = 70 + self.width = 70 -# def get_intermediate_layers(self, x, *args, **kwargs): -# batch_size = x.shape[0] -# output = torch.ones(batch_size, self.dino_out_channels, self.height, self.width) -# return output, None + def get_intermediate_layers(self, x, *args, **kwargs): + batch_size = x.shape[0] + output = torch.ones(batch_size, self.dino_out_channels, self.height, self.width) + return output, None -# def get_mock_model(img_height: float, img_width: float) -> DinoV2Adapter: -# model = MockDinoEncoder() -# device = torch.device("cpu") -# dino_model_adapter = DinoV2Adapter(model, img_height, img_width, device) -# return dino_model_adapter +def get_mock_model(img_height: int, img_width: int) -> DinoV2Adapter: + model = MockDinoEncoder() + device = torch.device("cpu") + dino_model_adapter = DinoV2Adapter(model, img_height, img_width, device) + return dino_model_adapter -# @pytest.mark.slow() -# @pytest.mark.parametrize( -# "test_patch", -# [ -# torch.ones((1, 3, 128, 128)), -# torch.ones((3, 3, 128, 128)), -# torch.ones((1, 3, 256, 256)), -# torch.ones((1, 3, 512, 512)), -# ], -# ) -# def test_mock_adapter(test_patch: np.ndarray): -# img_height, img_width = test_patch.shape[-2:] -# real_adapter = get_model(img_height, img_width) -# mock_adapter = get_mock_model(img_height, img_width) +@pytest.mark.slow() +@pytest.mark.parametrize( + "test_patch", + [ + torch.ones((1, 3, 128, 128)), + torch.ones((3, 3, 128, 128)), + torch.ones((1, 3, 256, 256)), + torch.ones((1, 3, 512, 512)), + ], +) +def test_mock_adapter(test_patch: np.ndarray): + img_height, img_width = test_patch.shape[-2:] + real_adapter = get_model(img_height, img_width) + mock_adapter = get_mock_model(img_height, img_width) -# transformed_input_patch_real = real_adapter.input_transforms(test_patch) -# transformed_input_patch_mock = mock_adapter.input_transforms(test_patch) + transformed_input_patch_real = real_adapter.input_transforms(test_patch) + transformed_input_patch_mock = mock_adapter.input_transforms(test_patch) -# result_real = real_adapter.model.get_intermediate_layers( -# transformed_input_patch_real, 1, return_class_token=False, reshape=True -# )[0] -# mock_result = mock_adapter.model.get_intermediate_layers( -# transformed_input_patch_mock -# )[0] + result_real = real_adapter.model.get_intermediate_layers( + transformed_input_patch_real, 1, return_class_token=False, reshape=True + )[0] # type: ignore + mock_result = mock_adapter.model.get_intermediate_layers( + transformed_input_patch_mock + )[0] # type: ignore -# assert len(result_real) == len(mock_result) -# assert result_real[0].shape == mock_result[0].shape + assert len(result_real) == len(mock_result) + assert result_real[0].shape == mock_result[0].shape -# @pytest.mark.parametrize( -# "test_image, expected_output_shape, expected_slices", -# [ -# (np.ones((256, 256)), (49, 42, 42, 384), 1), # 2D -# (np.ones((256, 256, 3)), (49, 42, 42, 384), 1), # 2D RGB -# (np.ones((2, 256, 256)), (49, 42, 42, 384), 2), # 3D -# (np.ones((2, 256, 256, 3)), (49, 42, 42, 384), 2), # 3D RGB -# ], -# ) -# def test_dino_embedding_extraction( -# test_image: np.ndarray, expected_output_shape: tuple, expected_slices: int -# ): -# num_slices, img_height, img_width = get_stack_dims(test_image) -# model_adapter = get_mock_model(img_height, img_width) -# check_embedding_extraction( -# test_image, model_adapter, expected_output_shape, expected_slices -# ) +@pytest.mark.parametrize( + "test_image, expected_output_shape, expected_slices", + [ + (np.ones((256, 256)), (49, 42, 42, 384), 1), # 2D + (np.ones((256, 256, 3)), (49, 42, 42, 384), 1), # 2D RGB + (np.ones((2, 256, 256)), (49, 42, 42, 384), 2), # 3D + (np.ones((2, 256, 256, 3)), (49, 42, 42, 384), 2), # 3D RGB + ], +) +def test_dino_embedding_extraction( + test_image: np.ndarray, expected_output_shape: tuple, expected_slices: int +): + num_slices, img_height, img_width = get_stack_dims(test_image) + model_adapter = get_mock_model(img_height, img_width) + check_embedding_extraction( + test_image, model_adapter, expected_output_shape, expected_slices + ) From 227a8d495c2e1c3778a34fa1d827ba445d208620 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sun, 15 Jun 2025 19:50:57 +0200 Subject: [PATCH 28/35] added sam2 adapter test --- .../model_adapter_tests/test_sam2_adapter.py | 156 ++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 tests/model_adapter_tests/test_sam2_adapter.py diff --git a/tests/model_adapter_tests/test_sam2_adapter.py b/tests/model_adapter_tests/test_sam2_adapter.py new file mode 100644 index 0000000..3125eb5 --- /dev/null +++ b/tests/model_adapter_tests/test_sam2_adapter.py @@ -0,0 +1,156 @@ +import numpy as np +import pytest +import torch +import torch.nn as nn +from embedding_extraction import check_embedding_extraction + +from featureforest.models.SAM2.adapter import SAM2Adapter +from featureforest.utils.data import get_stack_dims + + +class MockSAM2Encoder(nn.Module): + def __init__(self): + super().__init__() + self.encoder_num_channels = 256 * 3 + + def __call__(self, x): + batch_size = x.shape[0] + # Mock the backbone_fpn output with 3 feature levels + # [b, 256, 256, 256] + # [b, 256, 128, 128] + # [b, 256, 64, 64] + level1 = torch.ones(batch_size, 256, 256, 256) + level2 = torch.ones(batch_size, 256, 128, 128) + level3 = torch.ones(batch_size, 256, 64, 64) + + return {"backbone_fpn": [level1, level2, level3]} + + +def get_mock_model(img_height: int, img_width: int) -> SAM2Adapter: + model = MockSAM2Encoder() + device = torch.device("cpu") + sam2_model_adapter = SAM2Adapter(model, img_height, img_width, device) + return sam2_model_adapter + + +@pytest.mark.parametrize( + "img_height, img_width, expected_patch_size", + [ + ( + 512, + 512, + 256, + ), # SAM2 adapter seems to use 256 as default patch size for smaller images + (256, 256, 256), + (1024, 1024, 512), # For larger images, it uses 512 + ], +) +def test_initialize_sam2_adapter(img_height, img_width, expected_patch_size): + """Test that SAM2Adapter initializes correctly with different image sizes.""" + model = MockSAM2Encoder() + device = torch.device("cpu") + + adapter = SAM2Adapter(model, img_height, img_width, device) + + assert adapter.name == "SAM2_Large" + assert adapter.img_height == img_height + assert adapter.img_width == img_width + assert adapter.device == device + assert adapter.encoder == model + assert adapter.encoder_num_channels == 256 * 3 + assert adapter.patch_size == expected_patch_size + assert adapter.overlap == expected_patch_size // 4 + assert adapter.sam_input_dim == 1024 + + +@pytest.mark.parametrize( + "test_patch", + [ + torch.ones((1, 3, 128, 128)), + torch.ones((3, 3, 128, 128)), + torch.ones((1, 3, 256, 256)), + torch.ones((1, 3, 512, 512)), + ], +) +def test_process_input_patches(test_patch: torch.Tensor): + """Test that SAM2Adapter processes input patches correctly.""" + img_height, img_width = test_patch.shape[-2:] + adapter = get_mock_model(img_height, img_width) + + # Process the input patches + output_features = adapter.get_features_patches(test_patch) + + # Check output shape + batch_size = test_patch.shape[0] + assert output_features.shape[0] == batch_size + # The output should be in format [batch, height, width, channels] + assert output_features.shape[3] == adapter.encoder_num_channels + # Check that the output is a tensor + assert isinstance(output_features, torch.Tensor) + + +def test_get_total_output_channels(): + """Test that SAM2Adapter returns the correct number of output channels.""" + adapter = get_mock_model(512, 512) + + # Check that the output channels is 256 * 3 = 768 + assert adapter.get_total_output_channels() == 256 * 3 + + +@pytest.mark.parametrize( + "test_image, expected_output_shape, expected_slices", + [ + (np.ones((256, 256)), (4, 192, 192, 768), 1), # 2D + (np.ones((256, 256, 3)), (4, 192, 192, 768), 1), # 2D RGB + ], +) +def test_sam2_embedding_extraction_2d( + test_image: np.ndarray, expected_output_shape: tuple, expected_slices: int +): + """Test that SAM2Adapter extracts embeddings from 2D images correctly.""" + num_slices, img_height, img_width = get_stack_dims(test_image) + model_adapter = get_mock_model(img_height, img_width) + check_embedding_extraction( + test_image, model_adapter, expected_output_shape, expected_slices + ) + + +@pytest.mark.parametrize( + "test_image, expected_output_shape, expected_slices", + [ + (np.ones((2, 256, 256)), (4, 192, 192, 768), 2), # 3D + (np.ones((2, 256, 256, 3)), (4, 192, 192, 768), 2), # 3D RGB + ], +) +def test_sam2_embedding_extraction_3d( + test_image: np.ndarray, expected_output_shape: tuple, expected_slices: int +): + """Test that SAM2Adapter extracts embeddings from 3D stacks correctly.""" + num_slices, img_height, img_width = get_stack_dims(test_image) + model_adapter = get_mock_model(img_height, img_width) + check_embedding_extraction( + test_image, model_adapter, expected_output_shape, expected_slices + ) + + +def test_no_patching_mode(): + """Test that SAM2Adapter handles no_patching mode correctly.""" + img_height, img_width = 256, 256 + adapter = get_mock_model(img_height, img_width) + + # Default mode + assert adapter.no_patching is False + assert adapter.patch_size == 256 + assert adapter.overlap == 64 + + # Enable no_patching mode + adapter.no_patching = True + assert adapter.no_patching is True + assert adapter.patch_size == img_height + assert adapter.overlap == 0 + + # Disable no_patching mode + adapter.no_patching = False + assert adapter.no_patching is False + assert adapter.patch_size == 256 + assert adapter.overlap == 64 From b3bd2562ca1aa4556202e80b5bb7674af1bea626 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sun, 15 Jun 2025 19:51:38 +0200 Subject: [PATCH 29/35] added pipeline_prediction test --- tests/test_pipeline_prediction.py | 323 ++++++++++++++++++++++++++++++ 1 file changed, 323 insertions(+) create mode 100644 tests/test_pipeline_prediction.py diff --git a/tests/test_pipeline_prediction.py b/tests/test_pipeline_prediction.py new file mode 100644 index 0000000..c148dfd --- /dev/null +++ b/tests/test_pipeline_prediction.py @@ -0,0 +1,323 @@ +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from featureforest.utils.pipeline_prediction import run_prediction_pipeline + + +@pytest.fixture +def mock_model_adapter(): + adapter = MagicMock() + adapter.no_patching = False + adapter.patch_size = 128 + adapter.overlap = 32 + adapter.get_total_output_channels.return_value = 320 + return adapter + + +@pytest.fixture +def mock_rf_model(): + rf = MagicMock() + rf.predict.side_effect = lambda x: np.ones((x.shape[0],), dtype=np.uint8) + return rf + + +@pytest.fixture +def mock_stack_dataset(): + dataset = MagicMock() + dataset.image_shape = (240, 240) + return dataset + + +@pytest.fixture +def mock_extract_embeddings(): + # Simulate two slices, each with two patches + def generator(*args, **kwargs): + yield np.zeros((1,)), 0, 2 + yield np.ones((1,)), 0, 2 + yield np.full((1,), 2), 1, 2 + yield np.full((1,), 3), 1, 2 + + return generator + + +@pytest.fixture +def mock_predict_patches(): + return lambda features, rf, adapter: np.array([[1]] * len(features), dtype=np.uint8) + + +@pytest.fixture +def mock_get_image_mask(): + return lambda patch_masks, h, w, ps, ov: np.ones((h, w), dtype=np.uint8) + + +@patch("featureforest.utils.pipeline_prediction.FFImageDataset") +@patch("featureforest.utils.pipeline_prediction.extract_embeddings") +@patch("featureforest.utils.pipeline_prediction.predict_patches") +@patch("featureforest.utils.pipeline_prediction.get_image_mask") +def test_run_prediction_pipeline_yields_correct_masks( + mock_get_image_mask_func, + mock_predict_patches_func, + mock_extract_embeddings_func, + mock_ffimagedataset, + mock_model_adapter, + mock_rf_model, +): + # Setup mocks + mock_ffimagedataset.return_value.image_shape = (240, 240) + + # Setup extract_embeddings mock to yield the expected tuples + def mock_extract_embeddings_generator(*args, **kwargs): + yield np.zeros((1,)), 0, 2 + yield np.ones((1,)), 0, 2 + yield np.full((1,), 2), 1, 2 + yield np.full((1,), 3), 1, 2 + + mock_extract_embeddings_func.side_effect = mock_extract_embeddings_generator + + # Setup other mocks + mock_predict_patches_func.side_effect = lambda features, rf, adapter: np.array( + [[1]] * len(features), dtype=np.uint8 + ) + mock_get_image_mask_func.side_effect = lambda patch_masks, h, w, ps, ov: np.ones( + (h, w), dtype=np.uint8 + ) + + input_stack = "dummy_path" + results = list( + run_prediction_pipeline(input_stack, mock_model_adapter, mock_rf_model) + ) + + # There should be two yields (one per slice) + assert len(results) == 2 + for mask, idx, total in results: + assert mask.shape == (240, 240) + assert total == 2 + assert idx in (0, 1) + assert np.all(mask == 1) + + +@patch("featureforest.utils.pipeline_prediction.FFImageDataset") +@patch("featureforest.utils.pipeline_prediction.extract_embeddings") +@patch("featureforest.utils.pipeline_prediction.predict_patches") +@patch("featureforest.utils.pipeline_prediction.get_image_mask") +def test_run_prediction_pipeline_handles_single_slice( + mock_get_image_mask_func, + mock_predict_patches_func, + mock_extract_embeddings_func, + mock_ffimagedataset, + mock_model_adapter, + mock_rf_model, +): + # Only one slice + def single_slice_gen(*args, **kwargs): + yield np.zeros((1,)), 0, 1 + yield np.ones((1,)), 0, 1 + + mock_ffimagedataset.return_value.image_shape = (240, 240) + mock_extract_embeddings_func.side_effect = single_slice_gen + mock_predict_patches_func.side_effect = lambda features, rf, adapter: np.array( + [[1]] * len(features), dtype=np.uint8 + ) + mock_get_image_mask_func.side_effect = lambda patch_masks, h, w, ps, ov: np.ones( + (h, w), dtype=np.uint8 + ) + + input_stack = "dummy_path" + results = list( + run_prediction_pipeline(input_stack, mock_model_adapter, mock_rf_model) + ) + + assert len(results) == 1 + mask, idx, total = results[0] + assert mask.shape == (240, 240) + assert idx == 0 + assert total == 1 + assert np.all(mask == 1) + + +@patch("featureforest.utils.pipeline_prediction.FFImageDataset") +@patch("featureforest.utils.pipeline_prediction.extract_embeddings") +@patch("featureforest.utils.pipeline_prediction.predict_patches") +@patch("featureforest.utils.pipeline_prediction.get_image_mask") +def test_run_prediction_pipeline_no_patching_mode( + mock_get_image_mask_func, + mock_predict_patches_func, + mock_extract_embeddings_func, + mock_ffimagedataset, + mock_model_adapter, + mock_rf_model, +): + """Test that the pipeline works correctly when no_patching is True.""" + # Setup model adapter with no_patching=True + mock_model_adapter.no_patching = True + mock_model_adapter.patch_size = 0 # Should be ignored in no_patching mode + mock_model_adapter.overlap = 0 # Should be ignored in no_patching mode + mock_model_adapter.get_total_output_channels.return_value = 320 + + # Setup dataset + mock_ffimagedataset.return_value.image_shape = (240, 240) + + # Setup embeddings generator + def no_patching_gen(*args, **kwargs): + # In no_patching mode, we'd have one feature per slice + yield np.zeros((1, 320)), 0, 2 + yield np.ones((1, 320)), 1, 2 + + mock_extract_embeddings_func.side_effect = no_patching_gen + + # Setup prediction mocks + mock_predict_patches_func.side_effect = lambda features, rf, adapter: np.array( + [[1]] * len(features), dtype=np.uint8 + ) + mock_get_image_mask_func.side_effect = lambda patch_masks, h, w, ps, ov: np.ones( + (h, w), dtype=np.uint8 + ) + + # Run the pipeline + input_stack = "dummy_path" + results = list( + run_prediction_pipeline(input_stack, mock_model_adapter, mock_rf_model) + ) + + # Verify FFImageDataset was created with no_patching=True + mock_ffimagedataset.assert_called_once_with( + images="dummy_path", no_patching=True, patch_size=0, overlap=0 + ) + + # Verify results + assert len(results) == 2 + for i, (mask, idx, total) in enumerate(results): + assert mask.shape == (240, 240) + assert idx == i + assert total == 2 + assert np.all(mask == 1) + + +@patch("featureforest.utils.pipeline_prediction.FFImageDataset") +@patch("featureforest.utils.pipeline_prediction.extract_embeddings") +@patch("featureforest.utils.pipeline_prediction.predict_patches") +@patch("featureforest.utils.pipeline_prediction.get_image_mask") +def test_run_prediction_pipeline_large_image_dimensions( + mock_get_image_mask_func, + mock_predict_patches_func, + mock_extract_embeddings_func, + mock_ffimagedataset, + mock_model_adapter, + mock_rf_model, +): + """Test that the pipeline handles large image dimensions correctly.""" + # Setup large image dimensions + large_height, large_width = 4096, 4096 + mock_ffimagedataset.return_value.image_shape = (large_height, large_width) + + # Setup model adapter + mock_model_adapter.no_patching = False + mock_model_adapter.patch_size = 256 + mock_model_adapter.overlap = 64 + mock_model_adapter.get_total_output_channels.return_value = 320 + + # Setup embeddings generator + def large_image_gen(*args, **kwargs): + # Simulate many patches for a large image (simplified for test) + yield np.zeros((1, 16, 16, 320)), 0, 1 + yield np.ones((1, 16, 16, 320)), 0, 1 + yield np.full((1, 16, 16, 320), 2), 0, 1 + + mock_extract_embeddings_func.side_effect = large_image_gen + + # Setup prediction mocks + mock_predict_patches_func.side_effect = lambda features, rf, adapter: np.array( + [[1]] * len(features), dtype=np.uint8 + ) + + # The get_image_mask should return a mask with the large dimensions + mock_get_image_mask_func.side_effect = lambda patch_masks, h, w, ps, ov: np.ones( + (h, w), dtype=np.uint8 + ) + + # Run the pipeline + input_stack = "dummy_path" + results = list( + run_prediction_pipeline(input_stack, mock_model_adapter, mock_rf_model) + ) + + # Verify results + assert len(results) == 1 + mask, idx, total = results[0] + assert mask.shape == (large_height, large_width) + assert idx == 0 + assert total == 1 + assert np.all(mask == 1) + + # Verify get_image_mask was called with correct dimensions + # Instead of checking the exact mock value, just verify the function was called + # and check that the dimensions and other parameters were correct + assert mock_get_image_mask_func.called + args, kwargs = mock_get_image_mask_func.call_args + # Check that the height, width, patch_size and overlap parameters are correct + assert args[1] == large_height + assert args[2] == large_width + assert args[3] == mock_model_adapter.patch_size + assert args[4] == mock_model_adapter.overlap + + +@patch("featureforest.utils.pipeline_prediction.FFImageDataset") +@patch("featureforest.utils.pipeline_prediction.extract_embeddings") +@patch("featureforest.utils.pipeline_prediction.predict_patches") +@patch("featureforest.utils.pipeline_prediction.get_image_mask") +def test_run_prediction_pipeline_slice_index_continuity( + mock_get_image_mask_func, + mock_predict_patches_func, + mock_extract_embeddings_func, + mock_ffimagedataset, + mock_model_adapter, + mock_rf_model, +): + """Test that the pipeline handles non-continuous slice indices correctly.""" + # Setup + mock_ffimagedataset.return_value.image_shape = (240, 240) + + # Setup model adapter + mock_model_adapter.no_patching = False + mock_model_adapter.patch_size = 128 + mock_model_adapter.overlap = 32 + mock_model_adapter.get_total_output_channels.return_value = 320 + + # Setup embeddings generator with non-continuous slice indices (0, 2, 5) + def non_continuous_slices_gen(*args, **kwargs): + # Slice 0 + yield np.zeros((1,)), 0, 3 + yield np.ones((1,)), 0, 3 + # Slice 2 (skipping 1) + yield np.full((1,), 2), 2, 3 + # Slice 5 (skipping 3 and 4) + yield np.full((1,), 3), 5, 3 + + mock_extract_embeddings_func.side_effect = non_continuous_slices_gen + + # Setup prediction mocks + mock_predict_patches_func.side_effect = lambda features, rf, adapter: np.array( + [[1]] * len(features), dtype=np.uint8 + ) + mock_get_image_mask_func.side_effect = lambda patch_masks, h, w, ps, ov: np.ones( + (h, w), dtype=np.uint8 + ) + + # Run the pipeline + input_stack = "dummy_path" + results = list( + run_prediction_pipeline(input_stack, mock_model_adapter, mock_rf_model) + ) + + # Verify results + assert len(results) == 3 + + # Check that the slice indices match what we expect + expected_indices = [0, 2, 5] + for i, (mask, idx, total) in enumerate(results): + assert mask.shape == (240, 240) + assert idx == expected_indices[i] + assert total == 3 + assert np.all(mask == 1) From d3fce053be5fffef4ff17b3c0e2679fa35f832a9 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sun, 15 Jun 2025 19:53:52 +0200 Subject: [PATCH 30/35] bumped version --- src/featureforest/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/featureforest/__init__.py b/src/featureforest/__init__.py index e09001f..cddcf6e 100644 --- a/src/featureforest/__init__.py +++ b/src/featureforest/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.9" +__version__ = "0.1.0" from ._feature_extractor_widget import FeatureExtractorWidget from ._segmentation_widget import SegmentationWidget From 5eb9ed086a905efa0b4cd69f83b6fc24215764b2 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sun, 15 Jun 2025 20:19:22 +0200 Subject: [PATCH 31/35] run_pipeline script can be used to only extract features --- run_pipeline.py | 62 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 56 insertions(+), 6 deletions(-) diff --git a/run_pipeline.py b/run_pipeline.py index e3e7da3..6f2b9b8 100755 --- a/run_pipeline.py +++ b/run_pipeline.py @@ -14,6 +14,7 @@ postprocess_with_sam, postprocess_with_sam_auto, ) +from featureforest.utils.extract import extract_embeddings_to_file from featureforest.utils.pipeline_prediction import run_prediction_pipeline @@ -55,7 +56,7 @@ def apply_postprocessing( return post_masks -def main( +def run( input_file: str, rf_model_file: str, output_dir: str, @@ -152,16 +153,54 @@ def main( print(f"total elapsed time: {(time.perf_counter() - tic)} seconds") +def run_extract_features( + input_file: str, + output_dir: str, + model_name: str = "SAM2_Large", + no_patching: bool = False, +): + # input image + data_path = Path(input_file) + print(f"data_path exists: {data_path.exists()}") + # get stack dims + lazy_stack = pims.open(input_file) + img_height, img_width = lazy_stack.frame_shape + + # list of available models + available_models = get_available_models() + assert model_name in available_models, ( + f"Couldn't find {model_name} in available models\n{available_models}." + ) + model_adapter = get_model(model_name, img_height, img_width) + model_adapter.no_patching = no_patching + patch_size = model_adapter.patch_size + overlap = model_adapter.overlap + print(f"patch_size: {patch_size}, overlap: {overlap}") + + # zarr data store + store_path = Path(output_dir) + if not store_path.name.endswith(".zarr"): + output_dir += ".zarr" + + tic = time.perf_counter() + print(f"extracting features from {input_file}...") + for idx, total in extract_embeddings_to_file(input_file, output_dir, model_adapter): + print(f"slice {idx + 1} / {total}") + + print(f"total elapsed time: {(time.perf_counter() - tic)} seconds") + + if __name__ == "__main__": parser = argparse.ArgumentParser( description="\nFeatureForest run-pipeline script", ) parser.add_argument("--data", help="Path to the input image", required=True) - parser.add_argument("--rf_model", help="Path to the trained RF model", required=True) parser.add_argument("--outdir", help="Path to the output directory", required=True) + parser.add_argument("--rf_model", help="Path to the trained RF model", required=False) parser.add_argument( "--feat_model", choices=get_available_models(), + default="SAM2_Large", help="Name of the model for feature extraction", ) parser.add_argument( @@ -182,15 +221,26 @@ def main( help="Post-processing area threshold to remove small regions; default=50", ) parser.add_argument( - "--use_sam_predictor", + "--post_sam", default=True, action="store_true", - help="uses SAM2 for generating final masks", + help="to use SAM2 for generating final masks", + ) + parser.add_argument( + "--only_extract", + default=False, + action="store_true", + help="to only extract features to zarr file without running prediction pipeline", ) args = parser.parse_args() - main( + if args.only_extract: + run_extract_features(args.data, args.outdir, args.feat_model, args.no_patching) + exit(0) + + assert args.rf_model is not None, "RF model file is required." + run( input_file=args.data, rf_model_file=args.rf_model, output_dir=args.outdir, @@ -198,5 +248,5 @@ def main( no_patching=args.no_patching, smoothing_iterations=args.smoothing_iterations, area_threshold=args.area_threshold, - use_sam_predictor=args.use_sam_predictor, + use_sam_predictor=args.post_sam, ) From 0cc7b37e6a1c63dcbe74b4acc29173b19dd0bdf4 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Wed, 18 Jun 2025 11:19:11 +0200 Subject: [PATCH 32/35] updated dependencies in pyproject.toml --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 18ef7e7..fd7ae57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,9 @@ dependencies = [ "qtpy", "magicgui", "napari[all]", - "h5py", + "tifffile", + "zarr>=2,<3", + "pims", "pooch", "tqdm>=4.66.1", "iopath>=0.1.10", From 1118177ab2d99af52eefa01eb46882d8c8373d91 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Wed, 18 Jun 2025 12:03:03 +0200 Subject: [PATCH 33/35] fixed bug for calculating padding with no_patching --- src/featureforest/utils/data.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/featureforest/utils/data.py b/src/featureforest/utils/data.py index 17066aa..85b81ef 100644 --- a/src/featureforest/utils/data.py +++ b/src/featureforest/utils/data.py @@ -101,6 +101,9 @@ def get_paddings( # pad amount should be enough to make the # (final size - patch_size) / stride an integer number. # see https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html + # if whole image is just one patch + if img_height == patch_size and img_width == patch_size and patch_size == stride: + return 0, 0 new_width = img_width + 2 * margin pad_right = stride - int((new_width - patch_size) % stride) new_height = img_height + 2 * margin From 10acde5431ccd979bf1e4b56727abe122a22789e Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Fri, 20 Jun 2025 22:13:33 +0200 Subject: [PATCH 34/35] revert back storage to hdf5; improved get_train_data --- .../_feature_extractor_widget.py | 8 +- src/featureforest/_segmentation_widget.py | 126 ++++++++++-------- src/featureforest/utils/extract.py | 44 +++--- 3 files changed, 95 insertions(+), 83 deletions(-) diff --git a/src/featureforest/_feature_extractor_widget.py b/src/featureforest/_feature_extractor_widget.py index 666402d..9553da1 100644 --- a/src/featureforest/_feature_extractor_widget.py +++ b/src/featureforest/_feature_extractor_widget.py @@ -164,14 +164,14 @@ def save_storage(self): storage_name = f"{image_layer_name}_{model_name}" if self.no_patching_checkbox.isChecked(): storage_name += "_no_patching" - storage_name += ".zarr" + storage_name += ".hdf5" # open the save dialog selected_file, _filter = QFileDialog.getSaveFileName( - self, "FeatureForest", storage_name, "Zarr Storage(*.zarr)" + self, "FeatureForest", storage_name, "Feature Storage(*.hdf)" ) if selected_file is not None and len(selected_file) > 0: - if not selected_file.endswith(".zarr"): - selected_file += ".zarr" + if not selected_file.endswith(".hdf"): + selected_file += ".hdf" self.storage_textbox.setText(selected_file) self.extract_button.setEnabled(True) diff --git a/src/featureforest/_segmentation_widget.py b/src/featureforest/_segmentation_widget.py index 92b6ecc..b25e0c2 100644 --- a/src/featureforest/_segmentation_widget.py +++ b/src/featureforest/_segmentation_widget.py @@ -8,12 +8,12 @@ from pathlib import Path from typing import Optional +import h5py import napari import napari.layers import napari.utils.notifications as notif import numpy as np import tifffile -import zarr from napari.qt.threading import create_worker from napari.utils import progress as np_progress from napari.utils.events import Event @@ -68,7 +68,7 @@ def __init__(self, napari_viewer: napari.Viewer) -> None: self.gt_layer: napari.layers.Labels | None = None self.segmentation_layer: napari.layers.Labels | None = None self.postprocess_layer: napari.layers.Labels | None = None - self.storage: zarr.Group | None = None # type: ignore + self.storage: h5py.File | None = None self.rf_model: RandomForestClassifier | None = None self.model_adapter: BaseModelAdapter | None = None self.sam_auto_masks = None @@ -547,21 +547,23 @@ def sam_auto_post_checked(self, checked: bool) -> None: self.sam_post_checkbox.setChecked(False) def select_storage(self) -> None: - selected_file = QFileDialog.getExistingDirectory(self, "FeatureForest", "..") + selected_file, _ = QFileDialog.getOpenFileName( + self, "FeatureForest", "..", "Feature Storage(*.hdf5)" + ) if selected_file is not None and len(selected_file) > 0: self.storage_textbox.setText(selected_file) # load the storage - self.storage: zarr.Group = zarr.open(selected_file, mode="r") # type: ignore + self.storage = h5py.File(selected_file, mode="r") # set the patch size and overlap from the selected storage self.patch_size = self.storage.attrs.get("patch_size", self.patch_size) self.overlap = self.storage.attrs.get("overlap", self.overlap) self.stride, _ = get_stride_margin(self.patch_size, self.overlap) # initialize the model based on the selected storage - img_height = self.storage.attrs["img_height"] - img_width = self.storage.attrs["img_width"] + img_height: int = self.storage.attrs.get("img_height", 0) + img_width: int = self.storage.attrs.get("img_width", 0) model_name = str(self.storage.attrs["model"]) - no_patching = self.storage.attrs["no_patching"] + no_patching = self.storage.attrs.get("no_patching", False) self.model_adapter = get_model(model_name, img_height, img_width) self.model_adapter.no_patching = no_patching print(model_name, self.patch_size, self.overlap, no_patching) @@ -594,82 +596,92 @@ def set_stats_label_layer(self) -> None: if layer is not None: self.stats.set_label_layer(layer) - def get_class_labels(self) -> dict[int, np.ndarray]: - labels_dict = {} + def get_labeled_pixels(self) -> dict[int, np.ndarray]: layer = get_layer( self.viewer, self.gt_combo.currentText(), config.NAPARI_LABELS_LAYER ) if layer is None: print("No label layer is selected!") notif.show_error("No label layer is selected!") - return labels_dict + return {} + labeled_pixels = {} + slice_dim = 0 + ydim = 1 + xdim = 2 class_indices = np.unique(layer.data).tolist() # class zero is the napari background class. class_indices = [i for i in class_indices if i > 0] - for class_idx in class_indices: - positions = np.argwhere(layer.data == class_idx) - labels_dict[class_idx] = positions - - return labels_dict - - def analyze_labels(self, labels_dict: Optional[dict] = None) -> None: - if labels_dict is None: - labels_dict = self.get_class_labels() - num_labels = [len(v) for v in labels_dict.values()] - self.num_class_label.setText(f"Number of classes: {len(num_labels)}") - each_class = "\n".join( - [f"class {i + 1}: {num_labels[i]:,d}" for i in range(len(num_labels))] - ) - self.each_class_label.setText("Labels per class:\n" + each_class) + coords = np.argwhere(np.isin(layer.data, class_indices)) + for s_i in np.unique(coords[:, slice_dim]).tolist(): + s_coords = coords[coords[:, slice_dim] == s_i] + slice_labels = layer.data[s_i, s_coords[:, ydim], s_coords[:, xdim]] + labeled_pixels[s_i] = np.column_stack( + [ + slice_labels, + s_coords[:, [ydim, xdim]], # omit slice dim + ] + ) + + return labeled_pixels + + def analyze_labels(self, labeled_pixels: dict | None = None) -> None: + if not labeled_pixels: + labeled_pixels = self.get_labeled_pixels() + + if len(labeled_pixels) > 0: + labels = np.concat([v[:, 0] for v in labeled_pixels.values()], axis=None) + classes = np.unique(labels) + self.num_class_label.setText(f"Number of classes: {len(classes)}") + each_class = "\n".join([f"class {c}: {sum(labels == c):,d}" for c in classes]) + self.each_class_label.setText("Labels per class:\n" + each_class) def show_usage_stats(self) -> None: stats_widget = UsageStats(self.stats) stats_widget.exec() def get_train_data(self) -> tuple[np.ndarray, np.ndarray] | None: - # get ground truth class labels - labels_dict = self.get_class_labels() - if len(labels_dict) == 0: + # get ground truth labeled pixels + labeled_pixels = self.get_labeled_pixels() + if labeled_pixels is None: return None if self.storage is None: notif.show_error("No embeddings storage file is selected!") return None # update labels stats - self.analyze_labels(labels_dict) + self.analyze_labels(labeled_pixels) num_slices, img_height, img_width = get_stack_dims(self.image_layer.data) - num_labels = sum([len(v) for v in labels_dict.values()]) total_channels = self.model_adapter.get_total_output_channels() + num_labels = sum([len(v) for v in labeled_pixels.values()]) + label_dim = 0 + ydim = 1 + xdim = 2 + count = 0 train_data = np.zeros((num_labels, total_channels)) labels = np.zeros(num_labels, dtype=np.int32) - 1 - count = 0 - for class_index in np_progress( - labels_dict, desc="getting training data", total=len(labels_dict.keys()) + for s_idx, label_coords in np_progress( + labeled_pixels.items(), desc="getting training data" ): - class_label_coords = labels_dict[class_index] - uniq_slices = np.unique(class_label_coords[:, 0]).tolist() - # for each unique slice, load unique patches from the storage, - # then get the pixel features within loaded patch. - for slice_index in np_progress(uniq_slices, desc="reading slices"): - slice_coords = class_label_coords[ - class_label_coords[:, 0] == slice_index - ][:, 1:] # omit the slice dim - patch_indices = get_patch_indices( - slice_coords, img_height, img_width, self.patch_size, self.overlap - ) - grp_key = str(slice_index) - slice_dataset = self.storage[grp_key]["features"] - for p_i in np.unique(patch_indices): - patch_coords = slice_coords[patch_indices == p_i] - patch_features = slice_dataset[p_i] - train_data[count : count + len(patch_coords)] = patch_features[ - patch_coords[:, 0] % self.stride, patch_coords[:, 1] % self.stride - ] - labels[count : count + len(patch_coords)] = ( - class_index - 1 - ) # to have bg class as zero - count += len(patch_coords) + # slice labels + s_labels = label_coords[:, label_dim] + # slice labeled coords + s_coords = label_coords[:, [ydim, xdim]] + patch_indices = get_patch_indices( + s_coords, img_height, img_width, self.patch_size, self.overlap + ) + grp_key = str(s_idx) + slice_dataset: h5py.Dataset = self.storage[grp_key]["features"] + # loop through slice patches + for p_i in np.unique(patch_indices).tolist(): + patch_features = slice_dataset[p_i] + patch_coords = s_coords[patch_indices == p_i] + train_data[count : count + len(patch_coords)] = patch_features[ + patch_coords[:, 0] % self.stride, patch_coords[:, 1] % self.stride + ] + # -1: to have bg class as zero + labels[count : count + len(patch_coords)] = s_labels - 1 + count += len(patch_coords) assert (labels > -1).all() @@ -706,7 +718,7 @@ def train_model(self) -> None: min_samples_split=15, min_samples_leaf=3, max_features=25, - n_jobs=2 if os.cpu_count() < 5 else os.cpu_count() - 3, + n_jobs=os.cpu_count() - 1, verbose=1, ) rf_classifier.fit(train_data, labels) diff --git a/src/featureforest/utils/extract.py b/src/featureforest/utils/extract.py index d5bd9c6..7f58a01 100644 --- a/src/featureforest/utils/extract.py +++ b/src/featureforest/utils/extract.py @@ -1,12 +1,9 @@ from collections.abc import Generator from typing import Optional +import h5py import numpy as np import torch -import zarr -import zarr.core -import zarr.storage -from numcodecs import Zstd from torch.utils.data import DataLoader from featureforest.models import BaseModelAdapter @@ -73,35 +70,38 @@ def extract_embeddings_to_file( image_dataset = FFImageDataset( images=image, no_patching=no_patching, patch_size=patch_size, overlap=overlap ) - # create the zarr storage - storage = zarr.storage.DirectoryStore(storage_path) - store_root = zarr.group(store=storage, overwrite=False) - store_root.attrs["num_slices"] = image_dataset.num_images - store_root.attrs["img_height"] = image_dataset.image_shape[0] - store_root.attrs["img_width"] = image_dataset.image_shape[1] - store_root.attrs["model"] = model_adapter.name - store_root.attrs["no_patching"] = no_patching - store_root.attrs["patch_size"] = patch_size - store_root.attrs["overlap"] = overlap + # create the storage + storage = h5py.File(storage_path, "w") + storage.attrs["num_slices"] = image_dataset.num_images + storage.attrs["img_height"] = image_dataset.image_shape[0] + storage.attrs["img_width"] = image_dataset.image_shape[1] + storage.attrs["model"] = model_adapter.name + storage.attrs["no_patching"] = no_patching + storage.attrs["patch_size"] = patch_size + storage.attrs["overlap"] = overlap for img_features, idx, total in extract_embeddings( model_adapter, image_dataset=image_dataset ): - if store_root.get(str(idx)) is None: + if storage.get(str(idx)) is None: # create a group for the slice - grp = store_root.create_group(str(idx)) # type: ignore - z_arr = grp.create( # type: ignore + grp = storage.create_group(str(idx)) + ds = grp.create_dataset( name="features", shape=img_features.shape, + maxshape=(None,) + img_features.shape[1:], chunks=(1,) + img_features.shape[1:], dtype=np.float16, - compressor=Zstd(level=3), + compression="lzf", ) - z_arr[:] = img_features + ds[:] = img_features else: - # append features to the slice/image group - grp: zarr.core.Array = store_root[str(idx)]["features"] # type: ignore - grp.append(img_features) + # resize the dataset and append features to the slice/image group + ds: h5py.Dataset = storage[str(idx)]["features"] # type: ignore + old_num_patches = ds.shape[0] + num_patches = ds.shape[0] + img_features.shape[0] + ds.resize(num_patches, axis=0) + ds[old_num_patches:, ...] = img_features yield idx, total From d26ae9157ae06095428ef4bbbf4e2d9fa440c7ce Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Fri, 20 Jun 2025 23:12:47 +0200 Subject: [PATCH 35/35] fixed embedding_extraction test --- tests/model_adapter_tests/embedding_extraction.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/model_adapter_tests/embedding_extraction.py b/tests/model_adapter_tests/embedding_extraction.py index 3fc13dd..0b51f51 100644 --- a/tests/model_adapter_tests/embedding_extraction.py +++ b/tests/model_adapter_tests/embedding_extraction.py @@ -1,7 +1,6 @@ from tempfile import TemporaryDirectory -import zarr -import zarr.storage +import h5py from featureforest.utils.extract import extract_embeddings_to_file @@ -19,7 +18,7 @@ def check_embedding_extraction( # Run the extractor generator till the end _ = list(extractor_generator) - read_storage: zarr.Group = zarr.open(tmp_file, mode="r") # type: ignore + read_storage: h5py.File = h5py.File(tmp_file, mode="r") # type: ignore slices = list(read_storage.keys()) assert len(slices) == expected_slices, ( f"Unexpected number of slices: {len(slices)}, expected: {expected_slices}" @@ -35,4 +34,4 @@ def check_embedding_extraction( f"expected: {expected_output_shape}" ) - read_storage.store.close() + read_storage.close()