diff --git a/pyproject.toml b/pyproject.toml index 18ef7e7..fd7ae57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,9 @@ dependencies = [ "qtpy", "magicgui", "napari[all]", - "h5py", + "tifffile", + "zarr>=2,<3", + "pims", "pooch", "tqdm>=4.66.1", "iopath>=0.1.10", diff --git a/requirements.txt b/requirements.txt index 9137ec2..a5bc997 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,9 @@ pynrrd qtpy magicgui napari[all] -h5py +tifffile +zarr>=2,<3 +pims pooch tqdm>=4.66.1 iopath>=0.1.10 diff --git a/run_pipeline.ipynb b/run_pipeline.ipynb deleted file mode 100755 index 4a47619..0000000 --- a/run_pipeline.ipynb +++ /dev/null @@ -1,436 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Installation and Requirements\n", - "\n", - "Please refer to the [_featureforest repo_](https://github.com/juglab/featureforest)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import pickle\n", - "from pathlib import Path\n", - "\n", - "import h5py\n", - "import numpy as np\n", - "import torch\n", - "from PIL import Image, ImageSequence\n", - "from sklearn.ensemble import RandomForestClassifier\n", - "\n", - "from tqdm.notebook import trange, tqdm\n", - "\n", - "from featureforest.models import get_available_models, get_model\n", - "from featureforest.models.SAM import SAMAdapter\n", - "from featureforest.utils.data import (\n", - " patchify,\n", - " is_image_rgb, get_stride_margin,\n", - " get_num_patches, get_stride_margin\n", - ")\n", - "from featureforest.postprocess import (\n", - " postprocess,\n", - " postprocess_with_sam, postprocess_with_sam_auto,\n", - " get_sam_auto_masks\n", - ")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Utility functions" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "def get_slice_features(\n", - " image: np.ndarray,\n", - " patch_size: int,\n", - " overlap: int,\n", - " model_adapter,\n", - " storage_group,\n", - "):\n", - " \"\"\"Extract the model features for one slice and save them into storage file.\"\"\"\n", - " # image to torch tensor\n", - " img_data = torch.from_numpy(image).to(torch.float32) / 255.0\n", - " # for sam the input image should be 4D: BxCxHxW ; an RGB image.\n", - " if is_image_rgb(image):\n", - " # it's already RGB, put the channels first and add a batch dim.\n", - " img_data = img_data[..., :3] # ignore the Alpha channel (in case of PNG).\n", - " img_data = img_data.permute([2, 0, 1]).unsqueeze(0)\n", - " else:\n", - " img_data = img_data.unsqueeze(0).unsqueeze(0).expand(-1, 3, -1, -1)\n", - "\n", - " # get input patches\n", - " data_patches = patchify(img_data, patch_size, overlap)\n", - " num_patches = len(data_patches)\n", - "\n", - " # set a low batch size\n", - " batch_size = 8\n", - " # for big SAM we need even lower batch size :(\n", - " if isinstance(model_adapter, SAMAdapter):\n", - " batch_size = 2\n", - "\n", - " num_batches = int(np.ceil(num_patches / batch_size))\n", - " # prepare storage for the slice embeddings\n", - " total_channels = model_adapter.get_total_output_channels()\n", - " stride, _ = get_stride_margin(patch_size, overlap)\n", - "\n", - " if model_adapter.name not in storage_group:\n", - " dataset = storage_group.create_dataset(\n", - " model_adapter.name, shape=(num_patches, stride, stride, total_channels)\n", - " )\n", - " else:\n", - " dataset = storage_group[model_adapter.name]\n", - 
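# A self-contained sketch of the per-slice feature storage layout used by the
# notebook above, with small illustrative sizes (real runs use stride x stride
# feature patches with several hundred channels; names are placeholders):
import h5py
import numpy as np

num_patches, stride, total_channels = 12, 64, 8  # illustrative only
with h5py.File("temp_storage.hdf5", "w") as storage:
    group = storage.create_group("slice")
    dataset = group.create_dataset(
        "MobileSAM", shape=(num_patches, stride, stride, total_channels)
    )
    # each extracted batch is written into its own row range of the dataset
    batch = np.zeros((2, stride, stride, total_channels), dtype=np.float32)
    dataset[0:2] = batch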
"\n", - " # get sam encoder output for image patches\n", - " # print(\"\\nextracting slice features:\")\n", - " for b_idx in tqdm(range(num_batches), desc=\"extracting slice feature:\"):\n", - " # print(f\"batch #{b_idx + 1} of {num_batches}\")\n", - " start = b_idx * batch_size\n", - " end = start + batch_size\n", - " slice_features = model_adapter.get_features_patches(\n", - " data_patches[start:end].to(model_adapter.device)\n", - " )\n", - " if not isinstance(slice_features, tuple):\n", - " # model has only one output\n", - " num_out = slice_features.shape[0] # to take care of the last batch size\n", - " dataset[start : start + num_out] = slice_features\n", - " else:\n", - " # model has more than one output: put them into storage one by one\n", - " ch_start = 0\n", - " for feat in slice_features:\n", - " num_out = feat.shape[0]\n", - " ch_end = ch_start + feat.shape[-1] # number of features\n", - " dataset[start : start + num_out, :, :, ch_start:ch_end] = feat\n", - " ch_start = ch_end\n", - "\n", - "\n", - "def predict_slice(\n", - " rf_model, patch_dataset, model_adapter,\n", - " img_height, img_width, patch_size, overlap\n", - "):\n", - " \"\"\"Predict a slice patch by patch\"\"\"\n", - " segmentation_image = []\n", - " # shape: N x target_size x target_size x C\n", - " feature_patches = patch_dataset[:]\n", - " num_patches = feature_patches.shape[0]\n", - " total_channels = model_adapter.get_total_output_channels()\n", - " stride, margin = get_stride_margin(patch_size, overlap)\n", - "\n", - " for i in tqdm(\n", - " range(num_patches), desc=\"Predicting slice patches\", position=1, leave=True\n", - " ):\n", - " input_data = feature_patches[i].reshape(-1, total_channels)\n", - " predictions = rf_model.predict(input_data).astype(np.uint8)\n", - " segmentation_image.append(predictions)\n", - "\n", - " segmentation_image = np.vstack(segmentation_image)\n", - " # reshape into the image size + padding\n", - " patch_rows, patch_cols = get_num_patches(\n", - " img_height, img_width, patch_size, overlap\n", - " )\n", - " segmentation_image = segmentation_image.reshape(\n", - " patch_rows, patch_cols, stride, stride\n", - " )\n", - " segmentation_image = np.moveaxis(segmentation_image, 1, 2).reshape(\n", - " patch_rows * stride,\n", - " patch_cols * stride\n", - " )\n", - " # skip paddings\n", - " segmentation_image = segmentation_image[:img_height, :img_width]\n", - "\n", - " return segmentation_image\n", - "\n", - "\n", - "def apply_postprocessing(\n", - " input_image, segmentation_image,\n", - " smoothing_iterations, area_threshold, area_is_absolute,\n", - " use_sam_predictor, use_sam_autoseg, iou_threshold\n", - "):\n", - " post_masks = {}\n", - " # if not use_sam_predictor and not use_sam_autoseg:\n", - " mask = postprocess(\n", - " segmentation_image, smoothing_iterations,\n", - " area_threshold, area_is_absolute\n", - " )\n", - " post_masks[\"Simple\"] = mask\n", - "\n", - " if use_sam_predictor:\n", - " mask = postprocess_with_sam(\n", - " segmentation_image,\n", - " smoothing_iterations, area_threshold, area_is_absolute\n", - " )\n", - " post_masks[\"SAMPredictor\"] = mask\n", - "\n", - " if use_sam_autoseg:\n", - " sam_auto_masks = get_sam_auto_masks(input_image)\n", - " mask = postprocess_with_sam_auto(\n", - " sam_auto_masks,\n", - " segmentation_image,\n", - " smoothing_iterations, iou_threshold,\n", - " area_threshold, area_is_absolute\n", - " )\n", - " post_masks[\"SAMAutoSegmentation\"] = mask\n", - "\n", - "\n", - " return post_masks" - ] - }, - { - "cell_type": "markdown", - 
"metadata": {}, - "source": [ - "### Set the Input, RF Model and the result directory paths" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# input image\n", - "data_path = \"../datasets/data.tif\"\n", - "data_path = Path(data_path)\n", - "print(f\"data_path exists: {data_path.exists()}\")\n", - "\n", - "# random forest model\n", - "rf_model_path = \"../datasets/rf_model.bin\"\n", - "rf_model_path = Path(rf_model_path)\n", - "print(f\"rf_model_path exists: {rf_model_path.exists()}\")\n", - "\n", - "# result folder\n", - "segmentation_dir = Path(\"../datasets/segmentation_result\")\n", - "segmentation_dir.mkdir(parents=True, exist_ok=True)\n", - "\n", - "# temporary storage path for saving extracted embeddings patches\n", - "storage_path = \"./temp_storage.hdf5\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Prepare the Input and RF Model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# get patch sizes\n", - "input_stack = Image.open(data_path)\n", - "\n", - "num_slices = input_stack.n_frames\n", - "img_height = input_stack.height\n", - "img_width = input_stack.width\n", - "\n", - "print(num_slices, img_height, img_width)\n", - "# print(patch_size, target_patch_size)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with open(rf_model_path, mode=\"rb\") as f:\n", - " model_data = pickle.load(f)\n", - "# compatibility check for old format rf model\n", - "if isinstance(model_data, dict): # noqa: SIM108\n", - " # new format\n", - " rf_model = model_data[\"rf_model\"]\n", - "else:\n", - " # old format\n", - " rf_model = model_data\n", - "\n", - "rf_model.set_params(verbose=0)\n", - "rf_model" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Initializing the Model for Feature Extraction" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# list of available models\n", - "get_available_models()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "model_name = \"MobileSAM\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "# print(f\"running on {device}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model_adapter = get_model(model_name, img_height, img_width)\n", - "\n", - "patch_size = model_adapter.patch_size\n", - "overlap = model_adapter.overlap\n", - "\n", - "patch_size, overlap" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "# post-processing parameters\n", - "do_postprocess = True\n", - "\n", - "smoothing_iterations = 25\n", - "area_threshold = 100 # to ignore mask regions with area below this threshold\n", - "area_is_absolute = True # is area is based on pixels or pecentage (False)\n", - "\n", - "use_sam_predictor = True\n", - "use_sam_autoseg = False\n", - "sam_autoseg_iou_threshold = 0.4" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Prediction" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "# create the slice temporary 
storage\n", - "storage = h5py.File(storage_path, \"w\")\n", - "storage_group = storage.create_group(\"slice\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for i, page in tqdm(\n", - " enumerate(ImageSequence.Iterator(input_stack)),\n", - " desc=\"Slices\", total=num_slices, position=0\n", - "):\n", - " # print(f\"slice {i + 1}\", end=\"\\n\")\n", - " slice_img = np.array(page.convert(\"RGB\"))\n", - "\n", - " get_slice_features(slice_img, patch_size, overlap, model_adapter, storage_group)\n", - "\n", - " segmentation_image = predict_slice(\n", - " rf_model, storage_group[model_adapter.name], model_adapter,\n", - " img_height, img_width,\n", - " patch_size, overlap\n", - " )\n", - "\n", - " img = Image.fromarray(segmentation_image)\n", - " img.save(segmentation_dir.joinpath(f\"slice_{i:04}_prediction.tiff\"))\n", - "\n", - " if do_postprocess:\n", - " post_masks = apply_postprocessing(\n", - " slice_img, segmentation_image,\n", - " smoothing_iterations, area_threshold, area_is_absolute,\n", - " use_sam_predictor, use_sam_autoseg, sam_autoseg_iou_threshold\n", - " )\n", - " # save results\n", - " for name, mask in post_masks.items():\n", - " img = Image.fromarray(mask)\n", - " seg_dir = segmentation_dir.joinpath(name)\n", - " seg_dir.mkdir(exist_ok=True) \n", - " img.save(seg_dir.joinpath(f\"slice_{i:04}_{name}.tiff\"))\n", - "\n", - "\n", - "\n", - "if storage is not None:\n", - " storage.close()\n", - " storage = None\n", - "Path(storage_path).unlink()" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "if storage is not None:\n", - " storage.close()\n", - " Path(storage_path).unlink()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "project52", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/run_pipeline.py b/run_pipeline.py index b09f473..6f2b9b8 100755 --- a/run_pipeline.py +++ b/run_pipeline.py @@ -1,143 +1,21 @@ import argparse -import multiprocessing as mp import pickle import time -from collections.abc import Generator from pathlib import Path import numpy as np -import torch -from PIL import Image, ImageSequence -from sklearn.ensemble import RandomForestClassifier as RF +import pims +import tifffile -from featureforest.models import BaseModelAdapter, get_available_models, get_model -from featureforest.models.SAM import SAMAdapter +from featureforest.models import get_available_models, get_model from featureforest.postprocess import ( get_sam_auto_masks, postprocess, postprocess_with_sam, postprocess_with_sam_auto, ) -from featureforest.utils.data import ( - get_num_patches, - get_stride_margin, - is_image_rgb, - patchify, -) - - -def predict_patches( - patch_features: np.ndarray, - rf_model: RF, - model_adapter: BaseModelAdapter, - batch_idx: int, - result_dict: dict, -) -> None: - """Predicts the class labels for a given set of patch features. - - Args: - patch_features (np.ndarray): Patch features to be predicted. - rf_model (RF): Random Forest Model used for predictions. 
- model_adapter (BaseModelAdapter): Model adapter object used for extracting data. - batch_idx (int): Batch index of the current patch features. - result_dict (dict): Dictionary where the predicted masks will be stored. - """ - patch_masks = [] - # shape: N x target_size x target_size x C - num_patches = patch_features.shape[0] - total_channels = model_adapter.get_total_output_channels() - print(f"predicting {num_patches} patches...") - for i in range(num_patches): - patch_data = patch_features[i].reshape(-1, total_channels) - pred = rf_model.predict(patch_data).astype(np.uint8) - patch_masks.append(pred) - - patch_masks = np.vstack(patch_masks) - result_dict[batch_idx] = patch_masks - - -def get_image_mask( - patch_masks: np.ndarray, - img_height: int, - img_width: int, - patch_size: int, - overlap: int, -) -> np.ndarray: - """Gets the final image mask by combining the individual patch masks. - - Args: - patch_masks (ndarray): Patch masks to combine into an image mask. - img_height (int): Height of the input image. - img_width (int): Width of the input image. - patch_size (int): Size of the patches. - overlap (int): Overlap between adjacent patches. - - Returns: - np.ndarray: Final image mask. - """ - stride, _ = get_stride_margin(patch_size, overlap) - patch_rows, patch_cols = get_num_patches(img_height, img_width, patch_size, overlap) - mask_image = patch_masks.reshape(patch_rows, patch_cols, stride, stride) - mask_image = np.moveaxis(mask_image, 1, 2).reshape( - patch_rows * stride, patch_cols * stride - ) - # skip paddings - mask_image = mask_image[:img_height, :img_width] - - return mask_image - - -def get_slice_features( - image: np.ndarray, model_adapter: BaseModelAdapter -) -> Generator[tuple[int, np.ndarray], None, None]: - """Extract features for one image using the given model adapter - - Args: - image: Input image array - model_adapter: Model adapter to extract features from - Returns: - Generator yielding tuples containing batch index and extracted features. - """ - # image to torch tensor - img_data = torch.from_numpy(image).to(torch.float32) - # normalize in [0, 1] - _min = img_data.min() - _max = img_data.max() - img_data = (img_data - _min) / (_max - _min) - # for sam the input image should be 4D: BxCxHxW ; an RGB image. - if is_image_rgb(image): - # it's already RGB, put the channels first and add a batch dim. - img_data = img_data[..., :3] # ignore the Alpha channel (in case of PNG). 
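-    # aside: a minimal check of the shape flow in these two branches
-    # (illustrative H=5, W=7; both paths must yield a 1x3xHxW float tensor):
-    #   rgb = torch.rand(5, 7, 3).permute([2, 0, 1]).unsqueeze(0)      # (1,3,5,7)
-    #   gray = torch.rand(5, 7).unsqueeze(0).unsqueeze(0).expand(-1, 3, -1, -1)
-    #   assert rgb.shape == gray.shape == (1, 3, 5, 7)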
- img_data = img_data.permute([2, 0, 1]).unsqueeze(0) - else: - img_data = img_data.unsqueeze(0).unsqueeze(0).expand(-1, 3, -1, -1) - - # get input patches - patch_size = model_adapter.patch_size - overlap = model_adapter.overlap - data_patches = patchify(img_data, patch_size, overlap) - num_patches = len(data_patches) - - # set a low batch size - batch_size = 8 - # for big SAM we need even lower batch size :( - if isinstance(model_adapter, SAMAdapter): - batch_size = 2 - num_batches = int(np.ceil(num_patches / batch_size)) - - # get sam encoder output for image patches - print("extracting slice features:") - for b_idx in range(num_batches): - print(f"batch #{b_idx + 1} of {num_batches}") - start = b_idx * batch_size - end = start + batch_size - slice_features = model_adapter.get_features_patches( - data_patches[start:end].to(model_adapter.device) - ).cpu() - if isinstance(slice_features, tuple): # model with more than one output - slice_features = torch.cat(slice_features, dim=-1) - - yield b_idx, slice_features.numpy() +from featureforest.utils.extract import extract_embeddings_to_file +from featureforest.utils.pipeline_prediction import run_prediction_pipeline def apply_postprocessing( @@ -149,19 +27,19 @@ def apply_postprocessing( use_sam_predictor: bool, use_sam_autoseg: bool, iou_threshold: float, -) -> np.ndarray: +) -> dict: post_masks = {} - # if not use_sam_predictor and not use_sam_autoseg: + # simple post-processing mask = postprocess( segmentation_image, smoothing_iterations, area_threshold, area_is_absolute ) - post_masks["Simple"] = mask + post_masks["post_simple"] = mask if use_sam_predictor: mask = postprocess_with_sam( segmentation_image, smoothing_iterations, area_threshold, area_is_absolute ) - post_masks["SAMPredictor"] = mask + post_masks["post_sam"] = mask if use_sam_autoseg: sam_auto_masks = get_sam_auto_masks(input_image) @@ -173,16 +51,17 @@ def apply_postprocessing( area_threshold, area_is_absolute, ) - post_masks["SAMAutoSegmentation"] = mask + post_masks["post_sam_auto"] = mask return post_masks -def main( +def run( input_file: str, rf_model_file: str, output_dir: str, model_name: str = "SAM2_Large", + no_patching: bool = False, smoothing_iterations: int = 25, area_threshold: int = 50, use_sam_predictor: bool = True, @@ -198,16 +77,14 @@ def main( # result folder segmentation_dir = Path(output_dir) segmentation_dir.mkdir(parents=True, exist_ok=True) - prediction_dir = segmentation_dir.joinpath("Prediction") + prediction_dir = segmentation_dir.joinpath("prediction") prediction_dir.mkdir(exist_ok=True) + simple_post_dir = segmentation_dir.joinpath("post_simple") + simple_post_dir.mkdir(parents=True, exist_ok=True) + sam_post_dir = segmentation_dir.joinpath("post_sam") + sam_post_dir.mkdir(parents=True, exist_ok=True) - # get input image dims - input_stack = Image.open(data_path) - num_slices = input_stack.n_frames - img_height = input_stack.height - img_width = input_stack.width - print(f"input_stack: {num_slices}, {img_height}, {img_width}") - + # load rf model with open(rf_model_path, mode="rb") as f: model_data = pickle.load(f) # compatibility check for old format rf model @@ -221,13 +98,17 @@ def main( rf_model.set_params(verbose=0) print(rf_model) + # get stack dims + lazy_stack = pims.open(input_file) + img_height, img_width = lazy_stack.frame_shape + # list of available models available_models = get_available_models() assert model_name in available_models, ( f"Couldn't find {model_name} in available models\n{available_models}." 
) - model_adapter = get_model(model_name, img_height, img_width) + model_adapter.no_patching = no_patching patch_size = model_adapter.patch_size overlap = model_adapter.overlap print(f"patch_size: {patch_size}, overlap: {overlap}") @@ -238,41 +119,23 @@ def main( use_sam_autoseg = False sam_autoseg_iou_threshold = 0.35 - # ### Prediction + # ### Prediction ### tic = time.perf_counter() - for i, page in enumerate(ImageSequence.Iterator(input_stack)): - print(f"\nslice {i + 1}") - slice_img = np.array(page.convert("RGB")) - procs = [] - # random forest prediction happens per batch of extracted features - # in a separate process. - with mp.Manager() as manager: - result_dict = manager.dict() - for b_idx, patch_features in get_slice_features(slice_img, model_adapter): - proc = mp.Process( - target=predict_patches, - args=(patch_features, rf_model, model_adapter, b_idx, result_dict), - ) - procs.append(proc) - proc.start() - # wait until all processes are done - for p in procs: - if p.is_alive: - p.join() - # collect results from each process - batch_indices = sorted(result_dict.keys()) - patch_masks = [result_dict[b] for b in batch_indices] - patch_masks = np.vstack(patch_masks) - slice_mask = get_image_mask( - patch_masks, img_height, img_width, patch_size, overlap - ) - - img = Image.fromarray(slice_mask) - img.save(prediction_dir.joinpath(f"slice_{i:04}_prediction.tiff")) + for slice_mask, idx, total in run_prediction_pipeline( + input_stack=input_file, + model_adapter=model_adapter, + rf_model=rf_model, + ): + print(f"\nslice {idx + 1} / {total}") + tifffile.imwrite( + prediction_dir.joinpath(f"slice_{idx:04}_prediction.tiff"), slice_mask + ) if do_postprocess: + print("\nrunning post-processing...") + slice_img = lazy_stack[idx] post_masks = apply_postprocessing( - slice_img, + slice_img, # type: ignore slice_mask, smoothing_iterations, area_threshold, @@ -283,10 +146,46 @@ def main( ) # save results for name, mask in post_masks.items(): - img = Image.fromarray(mask) seg_dir = segmentation_dir.joinpath(name) - seg_dir.mkdir(exist_ok=True) - img.save(seg_dir.joinpath(f"slice_{i:04}_{name}.tiff")) + # seg_dir.mkdir(exist_ok=True) + tifffile.imwrite(seg_dir.joinpath(f"slice_{idx:04}_{name}.tiff"), mask) + + print(f"total elapsed time: {(time.perf_counter() - tic)} seconds") + + +def run_extract_features( + input_file: str, + output_dir: str, + model_name: str = "SAM2_Large", + no_patching: bool = False, +): + # input image + data_path = Path(input_file) + print(f"data_path exists: {data_path.exists()}") + # get stack dims + lazy_stack = pims.open(input_file) + img_height, img_width = lazy_stack.frame_shape + + # list of available models + available_models = get_available_models() + assert model_name in available_models, ( + f"Couldn't find {model_name} in available models\n{available_models}." 
+ ) + model_adapter = get_model(model_name, img_height, img_width) + model_adapter.no_patching = no_patching + patch_size = model_adapter.patch_size + overlap = model_adapter.overlap + print(f"patch_size: {patch_size}, overlap: {overlap}") + + # zarr data store + store_path = Path(output_dir) + if not store_path.name.endswith(".zarr"): + output_dir += ".zarr" + + tic = time.perf_counter() + print(f"extracting features from {input_file}...") + for idx, total in extract_embeddings_to_file(input_file, output_dir, model_adapter): + print(f"slice {idx + 1} / {total}") print(f"total elapsed time: {(time.perf_counter() - tic)} seconds") @@ -296,13 +195,19 @@ def main( description="\nFeatureForest run-pipeline script", ) parser.add_argument("--data", help="Path to the input image", required=True) - parser.add_argument("--rf_model", help="Path to the trained RF model", required=True) parser.add_argument("--outdir", help="Path to the output directory", required=True) + parser.add_argument("--rf_model", help="Path to the trained RF model", required=False) parser.add_argument( "--feat_model", choices=get_available_models(), + default="SAM2_Large", help="Name of the model for feature extraction", ) + parser.add_argument( + "--no_patching", + action="store_true", + help="If true, no patching will be used during feature extraction", + ) parser.add_argument( "--smoothing_iterations", default=25, @@ -316,20 +221,32 @@ def main( help="Post-processing area threshold to remove small regions; default=50", ) parser.add_argument( - "--use_sam_predictor", + "--post_sam", default=True, action="store_true", - help="To use SAM2 for generating final masks", + help="to use SAM2 for generating final masks", + ) + parser.add_argument( + "--only_extract", + default=False, + action="store_true", + help="to only extract features to zarr file without running prediction pipeline", ) args = parser.parse_args() - main( + if args.only_extract: + run_extract_features(args.data, args.outdir, args.feat_model, args.no_patching) + exit(0) + + assert args.rf_model is not None, "RF model file is required." 
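+    # Example invocations (paths and the model name are placeholders; all
+    # flags are defined in the argparse setup above):
+    #
+    #   extract features only, into a zarr store:
+    #     python run_pipeline.py --data stack.tif --outdir feats \
+    #         --feat_model SAM2_Large --only_extract
+    #
+    #   full prediction pipeline (a trained RF model is required):
+    #     python run_pipeline.py --data stack.tif --rf_model rf_model.bin \
+    #         --outdir results --feat_model SAM2_Large --post_sam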
+ run( input_file=args.data, rf_model_file=args.rf_model, output_dir=args.outdir, model_name=args.feat_model, + no_patching=args.no_patching, smoothing_iterations=args.smoothing_iterations, area_threshold=args.area_threshold, - use_sam_predictor=args.use_sam_predictor, + use_sam_predictor=args.post_sam, ) diff --git a/src/featureforest/__init__.py b/src/featureforest/__init__.py index e09001f..cddcf6e 100644 --- a/src/featureforest/__init__.py +++ b/src/featureforest/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.9" +__version__ = "0.1.0" from ._feature_extractor_widget import FeatureExtractorWidget from ._segmentation_widget import SegmentationWidget diff --git a/src/featureforest/_feature_extractor_widget.py b/src/featureforest/_feature_extractor_widget.py index 33e51f1..9553da1 100644 --- a/src/featureforest/_feature_extractor_widget.py +++ b/src/featureforest/_feature_extractor_widget.py @@ -1,6 +1,7 @@ import csv import time from pathlib import Path +from typing import Optional import napari import napari.utils.notifications as notif @@ -9,6 +10,7 @@ from napari.utils.events import Event from qtpy.QtCore import Qt from qtpy.QtWidgets import ( + QCheckBox, QComboBox, QFileDialog, QGroupBox, @@ -23,9 +25,7 @@ from .models import get_available_models, get_model from .utils import config -from .utils.data import ( - get_stack_dims, -) +from .utils.data import get_stack_dims from .utils.extract import extract_embeddings_to_file from .widgets import ( ScrollWidgetWrapper, @@ -39,7 +39,7 @@ def __init__(self, napari_viewer: napari.Viewer): self.viewer = napari_viewer self.extract_worker = None self.model_adapter = None - self.timing = {"start": 0, "avg_per_slice": 0} + self.timing = {"start": 0.0, "avg_per_slice": 0.0} self.prepare_widget() def prepare_widget(self): @@ -54,12 +54,19 @@ def prepare_widget(self): # input layer input_label = QLabel("Image Layer:") self.image_combo = QComboBox() + self.image_combo.currentIndexChanged.connect(self.image_changed) # model selection model_label = QLabel("Encoder Model:") self.model_combo = QComboBox() self.model_combo.setEditable(False) self.model_combo.addItems(get_available_models()) self.model_combo.setCurrentIndex(0) + # no-patching checkbox + self.no_patching_checkbox = QCheckBox("No &Patching") + self.no_patching_checkbox.setToolTip( + "Whether divide an image into patches or not; " + "\nOnly works for square images (height=width)" + ) # storage storage_label = QLabel("Features Storage File:") self.storage_textbox = QLineEdit() @@ -94,6 +101,7 @@ def prepare_widget(self): vbox.addWidget(self.image_combo) vbox.addWidget(model_label) vbox.addWidget(self.model_combo) + vbox.addWidget(self.no_patching_checkbox) layout.addLayout(vbox) vbox = QVBoxLayout() @@ -124,7 +132,7 @@ def prepare_widget(self): self.base_layout.addWidget(gbox) self.base_layout.addStretch(1) - def check_input_layers(self, event: Event = None): + def check_input_layers(self, event: Optional[Event] = None): curr_text = self.image_combo.currentText() self.image_combo.clear() for layer in self.viewer.layers: @@ -136,18 +144,34 @@ def check_input_layers(self, event: Event = None): if index > -1: self.image_combo.setCurrentIndex(index) + def image_changed(self, event: Optional[Event] = None) -> None: + # check if image is square so we can do no_patching + image_layer = get_layer( + self.viewer, self.image_combo.currentText(), config.NAPARI_IMAGE_LAYER + ) + if image_layer is not None: + _, img_height, img_width = get_stack_dims(image_layer.data) + if img_height != img_width: + 
self.no_patching_checkbox.setChecked(False) + self.no_patching_checkbox.setEnabled(False) + else: + self.no_patching_checkbox.setEnabled(True) + def save_storage(self): # default storage name image_layer_name = self.image_combo.currentText() model_name = self.model_combo.currentText() - storage_name = f"{image_layer_name}_{model_name}.hdf5" + storage_name = f"{image_layer_name}_{model_name}" + if self.no_patching_checkbox.isChecked(): + storage_name += "_no_patching" + storage_name += ".hdf5" # open the save dialog selected_file, _filter = QFileDialog.getSaveFileName( - self, "FeatureForest", storage_name, "Embeddings Storage(*.hdf5)" + self, "FeatureForest", storage_name, "Feature Storage(*.hdf)" ) if selected_file is not None and len(selected_file) > 0: - if not selected_file.endswith(".hdf5"): - selected_file += ".hdf5" + if not selected_file.endswith(".hdf"): + selected_file += ".hdf" self.storage_textbox.setText(selected_file) self.extract_button.setEnabled(True) @@ -164,10 +188,12 @@ def extract_embeddings(self): if storage_path is None or len(storage_path) < 6: notif.show_error("No storage path was set.") return + # initialize the selected model _, img_height, img_width = get_stack_dims(image_layer.data) model_name = self.model_combo.currentText() self.model_adapter = get_model(model_name, img_height, img_width) + self.model_adapter.no_patching = self.no_patching_checkbox.isChecked() self.extract_button.setEnabled(False) self.stop_button.setEnabled(True) @@ -175,7 +201,7 @@ def extract_embeddings(self): self.extract_worker = create_worker( extract_embeddings_to_file, image=image_layer.data, - storage_file_path=storage_path, + storage_path=storage_path, model_adapter=self.model_adapter, ) self.extract_worker.yielded.connect(self.update_extract_progress) diff --git a/src/featureforest/_segmentation_widget.py b/src/featureforest/_segmentation_widget.py index 6fbfd95..b25e0c2 100644 --- a/src/featureforest/_segmentation_widget.py +++ b/src/featureforest/_segmentation_widget.py @@ -4,10 +4,13 @@ import pickle import time import warnings +from collections.abc import Generator from pathlib import Path +from typing import Optional import h5py import napari +import napari.layers import napari.utils.notifications as notif import numpy as np import tifffile @@ -34,7 +37,7 @@ from tifffile import TiffFile from .exports import EXPORTERS -from .models import get_model +from .models import BaseModelAdapter, get_model from .postprocess import ( get_sam_auto_masks, postprocess, @@ -48,7 +51,7 @@ get_stack_dims, get_stride_margin, ) -from .utils.pipeline_prediction import extract_predict +from .utils.pipeline_prediction import run_prediction_pipeline from .utils.usage_stats import SegmentationUsageStats from .widgets import ( ScrollWidgetWrapper, @@ -58,25 +61,25 @@ class SegmentationWidget(QWidget): - def __init__(self, napari_viewer: napari.Viewer): + def __init__(self, napari_viewer: napari.Viewer) -> None: super().__init__() self.viewer = napari_viewer - self.image_layer = None - self.gt_layer = None - self.segmentation_layer = None - self.postprocess_layer = None - self.storage = None - self.rf_model = None - self.model_adapter = None + self.image_layer: napari.layers.Image | None = None + self.gt_layer: napari.layers.Labels | None = None + self.segmentation_layer: napari.layers.Labels | None = None + self.postprocess_layer: napari.layers.Labels | None = None + self.storage: h5py.File | None = None + self.rf_model: RandomForestClassifier | None = None + self.model_adapter: BaseModelAdapter | 
None = None self.sam_auto_masks = None self.patch_size = 512 # default values - self.overlap = 384 + self.overlap = 512 // 4 self.stride = self.patch_size - self.overlap self.stats = SegmentationUsageStats() self.prepare_widget() - def closeEvent(self, event): + def closeEvent(self, event) -> None: print("closing") self.viewer.layers.events.inserted.disconnect(self.check_input_layers) self.viewer.layers.events.removed.disconnect(self.check_input_layers) @@ -87,7 +90,7 @@ def closeEvent(self, event): self.viewer.layers.events.removed.disconnect(self.postprocess_layer_removed) - def prepare_widget(self): + def prepare_widget(self) -> None: self.base_layout = QVBoxLayout() self.create_input_ui() self.create_label_stats_ui() @@ -118,7 +121,7 @@ def prepare_widget(self): self.viewer.dims.events.current_step.connect(self.clear_sam_auto_masks) - def create_input_ui(self): + def create_input_ui(self) -> None: # input layer input_label = QLabel("Input Layer:") self.image_combo = QComboBox() @@ -168,7 +171,7 @@ def create_input_ui(self): gbox.setLayout(layout) self.base_layout.addWidget(gbox) - def create_label_stats_ui(self): + def create_label_stats_ui(self) -> None: self.num_class_label = QLabel("Number of classes: ") self.each_class_label = QLabel("Labels per class:") analyze_button = QPushButton("Analyze") @@ -195,7 +198,7 @@ def create_label_stats_ui(self): gbox.setLayout(layout) self.base_layout.addWidget(gbox) - def create_train_ui(self): + def create_train_ui(self) -> None: tree_label = QLabel("Number of trees:") self.num_trees_textbox = QLineEdit() self.num_trees_textbox.setText("450") @@ -247,7 +250,7 @@ def create_train_ui(self): gbox.setLayout(layout) self.base_layout.addWidget(gbox) - def create_prediction_ui(self): + def create_prediction_ui(self) -> None: seg_label = QLabel("Segmentation Layer:") self.new_layer_checkbox = QCheckBox("New Layer") self.new_layer_checkbox.setChecked(True) @@ -301,7 +304,7 @@ def create_prediction_ui(self): gbox.setLayout(layout) self.base_layout.addWidget(gbox) - def create_postprocessing_ui(self): + def create_postprocessing_ui(self) -> None: smooth_label = QLabel("Smoothing Iterations:") self.smoothing_iteration_textbox = QLineEdit() self.smoothing_iteration_textbox.setText("25") @@ -387,7 +390,7 @@ def create_postprocessing_ui(self): gbox.setLayout(layout) self.base_layout.addWidget(gbox) - def create_export_ui(self): + def create_export_ui(self) -> None: export_label = QLabel("Export Format:") self.export_format_combo = QComboBox() for exporter in EXPORTERS: @@ -422,7 +425,7 @@ def create_export_ui(self): gbox.setLayout(layout) self.base_layout.addWidget(gbox) - def create_large_stack_prediction_ui(self): + def create_large_stack_prediction_ui(self) -> None: stack_label = QLabel("Select your stack:") self.large_stack_textbox = QLineEdit() self.large_stack_textbox.setReadOnly(True) @@ -481,13 +484,13 @@ def create_large_stack_prediction_ui(self): gbox.setLayout(layout) self.base_layout.addWidget(gbox) - def new_layer_checkbox_changed(self): + def new_layer_checkbox_changed(self) -> None: state = self.new_layer_checkbox.checkState() self.prediction_layer_combo.setEnabled(state == Qt.Unchecked) self.seg_add_radiobutton.setEnabled(state == Qt.Unchecked) self.seg_replace_radiobutton.setEnabled(state == Qt.Unchecked) - def check_input_layers(self, event: Event): + def check_input_layers(self, event: Optional[Event]) -> None: curr_text = self.image_combo.currentText() self.image_combo.clear() for layer in self.viewer.layers: @@ -499,7 +502,7 @@ def 
check_input_layers(self, event: Event): if index > -1: self.image_combo.setCurrentIndex(index) - def check_label_layers(self, event: Event): + def check_label_layers(self, event: Optional[Event]) -> None: gt_curr_text = self.gt_combo.currentText() pred_curr_text = self.prediction_layer_combo.currentText() self.gt_combo.clear() @@ -524,10 +527,10 @@ def check_label_layers(self, event: Event): if index > -1: self.prediction_layer_combo.setCurrentIndex(index) - def clear_sam_auto_masks(self): + def clear_sam_auto_masks(self) -> None: self.sam_auto_masks = None - def postprocess_layer_removed(self, event: Event): + def postprocess_layer_removed(self, event: Event) -> None: """Fires when current postprocess layer is removed.""" if ( self.postprocess_layer is not None @@ -535,41 +538,42 @@ def postprocess_layer_removed(self, event: Event): ): self.postprocess_layer = None - def sam_post_checked(self, checked: bool): + def sam_post_checked(self, checked: bool) -> None: if checked: self.sam_auto_post_checkbox.setChecked(False) - def sam_auto_post_checked(self, checked: bool): + def sam_auto_post_checked(self, checked: bool) -> None: if checked: self.sam_post_checkbox.setChecked(False) - def select_storage(self): - selected_file, _filter = QFileDialog.getOpenFileName( - self, "FeatureForest", ".", "Feature Storage(*.hdf5)" + def select_storage(self) -> None: + selected_file, _ = QFileDialog.getOpenFileName( + self, "FeatureForest", "..", "Feature Storage(*.hdf5)" ) if selected_file is not None and len(selected_file) > 0: self.storage_textbox.setText(selected_file) # load the storage - self.storage = h5py.File(selected_file, "r") + self.storage = h5py.File(selected_file, mode="r") # set the patch size and overlap from the selected storage self.patch_size = self.storage.attrs.get("patch_size", self.patch_size) self.overlap = self.storage.attrs.get("overlap", self.overlap) self.stride, _ = get_stride_margin(self.patch_size, self.overlap) # initialize the model based on the selected storage - img_height = self.storage.attrs["img_height"] - img_width = self.storage.attrs["img_width"] - # TODO: raise an error if current image dims are in conflicting with storage - model_name = self.storage.attrs["model"] + img_height: int = self.storage.attrs.get("img_height", 0) + img_width: int = self.storage.attrs.get("img_width", 0) + model_name = str(self.storage.attrs["model"]) + no_patching = self.storage.attrs.get("no_patching", False) self.model_adapter = get_model(model_name, img_height, img_width) - print(model_name, self.patch_size, self.overlap) + self.model_adapter.no_patching = no_patching + print(model_name, self.patch_size, self.overlap, no_patching) # set the plugin usage stats csv file storage_path = Path(selected_file) csv_path = storage_path.parent.joinpath(f"{storage_path.stem}_seg_stats.csv") self.stats.set_file_path(csv_path) - def add_labels_layer(self): + def add_labels_layer(self) -> None: self.image_layer = get_layer( self.viewer, self.image_combo.currentText(), config.NAPARI_IMAGE_LAYER ) @@ -585,97 +589,105 @@ def add_labels_layer(self): layer.colormap = colormaps.create_colormap(10)[0] layer.brush_size = 1 - def set_stats_label_layer(self): + def set_stats_label_layer(self) -> None: layer = get_layer( self.viewer, self.gt_combo.currentText(), config.NAPARI_LABELS_LAYER ) if layer is not None: self.stats.set_label_layer(layer) - def get_class_labels(self): - labels_dict = {} + def get_labeled_pixels(self) -> dict[int, np.ndarray]: layer = get_layer( self.viewer, self.gt_combo.currentText(), 
config.NAPARI_LABELS_LAYER ) if layer is None: print("No label layer is selected!") notif.show_error("No label layer is selected!") - return labels_dict + return {} + labeled_pixels = {} + slice_dim = 0 + ydim = 1 + xdim = 2 class_indices = np.unique(layer.data).tolist() # class zero is the napari background class. class_indices = [i for i in class_indices if i > 0] - for class_idx in class_indices: - positions = np.argwhere(layer.data == class_idx) - labels_dict[class_idx] = positions - - return labels_dict - - def analyze_labels(self, labels_dict: dict = None): - if labels_dict is None: - labels_dict = self.get_class_labels() - num_labels = [len(v) for v in labels_dict.values()] - self.num_class_label.setText(f"Number of classes: {len(num_labels)}") - each_class = "\n".join( - [f"class {i + 1}: {num_labels[i]:,d}" for i in range(len(num_labels))] - ) - self.each_class_label.setText("Labels per class:\n" + each_class) + coords = np.argwhere(np.isin(layer.data, class_indices)) + for s_i in np.unique(coords[:, slice_dim]).tolist(): + s_coords = coords[coords[:, slice_dim] == s_i] + slice_labels = layer.data[s_i, s_coords[:, ydim], s_coords[:, xdim]] + labeled_pixels[s_i] = np.column_stack( + [ + slice_labels, + s_coords[:, [ydim, xdim]], # omit slice dim + ] + ) + + return labeled_pixels + + def analyze_labels(self, labeled_pixels: dict | None = None) -> None: + if not labeled_pixels: + labeled_pixels = self.get_labeled_pixels() - def show_usage_stats(self): + if len(labeled_pixels) > 0: + labels = np.concat([v[:, 0] for v in labeled_pixels.values()], axis=None) + classes = np.unique(labels) + self.num_class_label.setText(f"Number of classes: {len(classes)}") + each_class = "\n".join([f"class {c}: {sum(labels == c):,d}" for c in classes]) + self.each_class_label.setText("Labels per class:\n" + each_class) + + def show_usage_stats(self) -> None: stats_widget = UsageStats(self.stats) stats_widget.exec() - def get_train_data(self): - # get ground truth class labels - labels_dict = self.get_class_labels() - if len(labels_dict) == 0: + def get_train_data(self) -> tuple[np.ndarray, np.ndarray] | None: + # get ground truth labeled pixels + labeled_pixels = self.get_labeled_pixels() + if labeled_pixels is None: return None if self.storage is None: notif.show_error("No embeddings storage file is selected!") return None # update labels stats - self.analyze_labels(labels_dict) + self.analyze_labels(labeled_pixels) num_slices, img_height, img_width = get_stack_dims(self.image_layer.data) - num_labels = sum([len(v) for v in labels_dict.values()]) total_channels = self.model_adapter.get_total_output_channels() - train_data = np.zeros((num_labels, total_channels)) - labels = np.zeros(num_labels, dtype="int32") - 1 + num_labels = sum([len(v) for v in labeled_pixels.values()]) + label_dim = 0 + ydim = 1 + xdim = 2 count = 0 - for class_index in np_progress( - labels_dict, desc="getting training data", total=len(labels_dict.keys()) + train_data = np.zeros((num_labels, total_channels)) + labels = np.zeros(num_labels, dtype=np.int32) - 1 + for s_idx, label_coords in np_progress( + labeled_pixels.items(), desc="getting training data" ): - class_label_coords = labels_dict[class_index] - uniq_slices = np.unique(class_label_coords[:, 0]).tolist() - # for each unique slice, load unique patches from the storage, - # then get the pixel features within loaded patch. 
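+        # aside: a tiny worked example of the get_labeled_pixels() output above
+        # (assumed labels stack of shape (slices, H, W); values are illustrative):
+        #   data = np.zeros((2, 4, 4), dtype=np.int32)
+        #   data[0, 1, 2] = 1                            # class-1 pixel on slice 0
+        #   data[1, 3, 0] = 2                            # class-2 pixel on slice 1
+        #   coords = np.argwhere(np.isin(data, [1, 2]))  # rows of [slice, y, x]
+        #   # -> labeled_pixels == {0: [[1, 1, 2]], 1: [[2, 3, 0]]}, rows [label, y, x]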
- for slice_index in np_progress(uniq_slices, desc="reading slices"): - slice_coords = class_label_coords[ - class_label_coords[:, 0] == slice_index - ][ - :, 1: - ] # omit the slice dim - patch_indices = get_patch_indices( - slice_coords, img_height, img_width, self.patch_size, self.overlap - ) - grp_key = str(slice_index) - slice_dataset = self.storage[grp_key][self.model_adapter.name] - for p_i in np.unique(patch_indices): - patch_coords = slice_coords[patch_indices == p_i] - patch_features = slice_dataset[p_i] - train_data[count : count + len(patch_coords)] = patch_features[ - patch_coords[:, 0] % self.stride, patch_coords[:, 1] % self.stride - ] - labels[count : count + len(patch_coords)] = ( - class_index - 1 - ) # to have bg class as zero - count += len(patch_coords) + # slice labels + s_labels = label_coords[:, label_dim] + # slice labeled coords + s_coords = label_coords[:, [ydim, xdim]] + patch_indices = get_patch_indices( + s_coords, img_height, img_width, self.patch_size, self.overlap + ) + grp_key = str(s_idx) + slice_dataset: h5py.Dataset = self.storage[grp_key]["features"] + # loop through slice patches + for p_i in np.unique(patch_indices).tolist(): + patch_features = slice_dataset[p_i] + patch_coords = s_coords[patch_indices == p_i] + train_data[count : count + len(patch_coords)] = patch_features[ + patch_coords[:, 0] % self.stride, patch_coords[:, 1] % self.stride + ] + # -1: to have bg class as zero + labels[count : count + len(patch_coords)] = s_labels - 1 + count += len(patch_coords) assert (labels > -1).all() return train_data, labels - def train_model(self): + def train_model(self) -> None: self.image_layer = get_layer( self.viewer, self.image_combo.currentText(), config.NAPARI_IMAGE_LAYER ) @@ -706,7 +718,7 @@ def train_model(self): min_samples_split=15, min_samples_leaf=3, max_features=25, - n_jobs=2 if os.cpu_count() < 5 else os.cpu_count() - 3, + n_jobs=os.cpu_count() - 1, verbose=1, ) rf_classifier.fit(train_data, labels) @@ -717,9 +729,9 @@ def train_model(self): notif.show_info("Model status: Training is Done!") self.model_save_button.setEnabled(True) - def load_rf_model(self): + def load_rf_model(self) -> None: selected_file, _filter = QFileDialog.getOpenFileName( - self, "FeatureForest", ".", "model(*.bin)" + self, "FeatureForest", "..", "model(*.bin)" ) if len(selected_file) > 0: # to suppress the sklearn InconsistentVersionWarning @@ -737,12 +749,14 @@ def load_rf_model(self): # (users can just load the rf model) if self.model_adapter is None: model_name = model_data["model_name"] + no_patching = model_data["no_patching"] self.patch_size = model_data["patch_size"] self.overlap = model_data["overlap"] img_height = model_data["img_height"] img_width = model_data["img_width"] # init the model adapter self.model_adapter = get_model(model_name, img_height, img_width) + self.model_adapter.no_patching = no_patching else: # old format self.rf_model = model_data @@ -751,12 +765,12 @@ def load_rf_model(self): self.model_status_label.setText("Model status: Ready!") self.model_save_button.setEnabled(True) - def save_rf_model(self): + def save_rf_model(self) -> None: if self.rf_model is None: notif.show_info("There is no trained model!") return selected_file, _filter = QFileDialog.getSaveFileName( - self, "FeatureForest", ".", "model(*.bin)" + self, "FeatureForest", "..", "model(*.bin)" ) if len(selected_file) > 0: if not selected_file.endswith(".bin"): @@ -767,6 +781,7 @@ def save_rf_model(self): "model_name": self.model_adapter.name, "img_height": 
self.storage.attrs["img_height"], "img_width": self.storage.attrs["img_width"], + "no_patching": self.model_adapter.no_patching, "patch_size": self.patch_size, "overlap": self.overlap, } @@ -774,7 +789,7 @@ def save_rf_model(self): pickle.dump(model_data, f) notif.show_info("Model was saved successfully.") - def predict(self, whole_stack=False): + def predict(self, whole_stack: bool = False) -> None: self.prediction_progress.setValue(0) if self.rf_model is None: notif.show_error("There is no trained RF model!") @@ -825,7 +840,9 @@ def predict(self, whole_stack=False): self.predict_worker.finished.connect(self.prediction_is_done) self.predict_worker.run() - def run_prediction(self, slice_indices, img_height, img_width): + def run_prediction( + self, slice_indices: list, img_height: int, img_width: int + ) -> Generator[tuple[int, int], None, None]: for slice_index in np_progress(slice_indices): self.stats.prediction_started() @@ -854,11 +871,17 @@ def run_prediction(self, slice_indices, img_height, img_width): self.segmentation_layer.colormap = cm self.segmentation_layer.refresh() - def predict_slice(self, rf_model, slice_index, img_height, img_width): + def predict_slice( + self, + rf_model: RandomForestClassifier, + slice_index: int, + img_height: int, + img_width: int, + ) -> np.ndarray: """Predict a slice patch by patch""" segmentation_image = [] - # shape: N x target_size x target_size x C - feature_patches = self.storage[str(slice_index)][self.model_adapter.name][:] + # shape: N x stride x stride x C + feature_patches: np.ndarray = self.storage[str(slice_index)]["features"][:] num_patches = feature_patches.shape[0] total_channels = self.model_adapter.get_total_output_channels() for i in np_progress(range(num_patches), desc="Predicting slice patches"): @@ -884,26 +907,26 @@ def predict_slice(self, rf_model, slice_index, img_height, img_width): return segmentation_image - def stop_predict(self): + def stop_predict(self) -> None: if self.predict_worker is not None: self.predict_worker.quit() self.predict_worker = None self.predict_stop_button.setEnabled(False) - def update_prediction_progress(self, values): + def update_prediction_progress(self, values: tuple) -> None: curr, total = values self.prediction_progress.setMinimum(0) self.prediction_progress.setMaximum(total) self.prediction_progress.setValue(curr + 1) self.prediction_progress.setFormat("slice %v of %m (%p%)") - def prediction_is_done(self): + def prediction_is_done(self) -> None: self.predict_all_button.setEnabled(True) self.predict_stop_button.setEnabled(False) print("Prediction is done!") notif.show_info("Prediction is done!") - def get_postprocess_params(self): + def get_postprocess_params(self) -> tuple[int, int, bool]: smoothing_iterations = 25 if len(self.smoothing_iteration_textbox.text()) > 0: smoothing_iterations = int(self.smoothing_iteration_textbox.text()) @@ -917,7 +940,7 @@ def get_postprocess_params(self): return smoothing_iterations, area_threshold, area_is_absolute - def postprocess_segmentation(self, whole_stack=False): + def postprocess_segmentation(self, whole_stack: bool = False) -> None: self.segmentation_layer = get_layer( self.viewer, self.prediction_layer_combo.currentText(), @@ -926,12 +949,15 @@ def postprocess_segmentation(self, whole_stack=False): if self.segmentation_layer is None: notif.show_error("No segmentation layer is selected!") return + if self.image_layer is None: + notif.show_error("No image layer is selected!") + return smoothing_iterations, area_threshold, area_is_absolute = ( 
self.get_postprocess_params() ) - num_slices, img_height, img_width = get_stack_dims(self.image_layer.data) + num_slices, _, _ = get_stack_dims(self.image_layer.data) slice_indices = [] if not whole_stack: # only predict the current slice @@ -978,7 +1004,7 @@ def postprocess_segmentation(self, whole_stack=False): self.postprocess_layer.refresh() - def export_segmentation(self): + def export_segmentation(self) -> None: if self.segmentation_layer is None: notif.show_error("No segmentation layer is selected!") return @@ -1004,17 +1030,17 @@ def export_segmentation(self): notif.show_info("Selected layer was exported successfully.") - def select_stack(self): + def select_stack(self) -> None: selected_file, _filter = QFileDialog.getOpenFileName( - self, "FeatureForest", ".", "TIFF stack (*.tiff *.tif)" + self, "FeatureForest", "..", "TIFF stack (*.tiff *.tif)" ) if selected_file is not None and len(selected_file) > 0: # get stack info with TiffFile(selected_file) as tiff_stack: axes = tiff_stack.series[0].axes - assert ("Y" in axes) and ( - "X" in axes - ), "Could not find YX in the stack axes!" + assert ("Y" in axes) and ("X" in axes), ( + "Could not find YX in the stack axes!" + ) stack_dims = tiff_stack.series[0].shape stack_height = stack_dims[axes.index("Y")] stack_width = stack_dims[axes.index("X")] @@ -1038,7 +1064,7 @@ def select_stack(self): res_dir = Path(selected_file).parent self.result_dir_textbox.setText(str(res_dir.absolute())) - def select_result_dir(self): + def select_result_dir(self) -> None: selected_dir = QFileDialog.getExistingDirectory( self, "Select a directory", @@ -1048,7 +1074,7 @@ def select_result_dir(self): if selected_dir is not None and len(selected_dir) > 0: self.result_dir_textbox.setText(selected_dir) - def run_pipeline_over_large_stack(self): + def run_pipeline_over_large_stack(self) -> None: if self.large_stack_textbox.text() == "": notif.show_error("No TIFF Stack is selected!") return @@ -1068,7 +1094,12 @@ def run_pipeline_over_large_stack(self): self.pipeline_worker.finished.connect(self.pipeline_is_done) self.pipeline_worker.run() - def run_pipeline(self, tiff_stack_file: str, result_dir: Path): + def run_pipeline( + self, tiff_stack_file: str, result_dir: Path + ) -> Generator[tuple[int, int], None, None]: + if self.model_adapter is None or self.rf_model is None: + raise ValueError("RF model and/or Model Adapter are missing!") + start = dt.datetime.now() slices_total_time = 0 postprocess_total_time = 0 @@ -1079,90 +1110,89 @@ def run_pipeline(self, tiff_stack_file: str, result_dir: Path): sam_post_dir = result_dir.joinpath("post_sam") sam_post_dir.mkdir(parents=True, exist_ok=True) - with TiffFile(tiff_stack_file) as tiff_stack: - total_pages = len(tiff_stack.pages) - for page_idx, page in np_progress( - enumerate(tiff_stack.pages), - desc="runing the pipeline", - total=len(tiff_stack.pages), - ): - slice_start = time.perf_counter() - image = page.asarray() - prediction_mask = extract_predict( - image, self.model_adapter, self.rf_model - ) - # save the prediction - tifffile.imwrite( - prediction_dir.joinpath(f"slice_{page_idx:04}_prediction.tiff"), - prediction_mask, - ) - # post-processing - smoothing_iterations, area_threshold, area_is_absolute = ( - self.get_postprocess_params() - ) - post_mask = postprocess( - prediction_mask, - smoothing_iterations, - area_threshold, - area_is_absolute, - ) - tifffile.imwrite( - simple_post_dir.joinpath(f"slice_{page_idx:04}_post_simple.tiff"), - post_mask, - ) - pp_start = time.perf_counter() - post_sam_mask = 
postprocess_with_sam( - prediction_mask, - smoothing_iterations, - area_threshold, - area_is_absolute, - ) - tifffile.imwrite( - sam_post_dir.joinpath(f"slice_{page_idx:04}_post_sam.tiff"), - post_sam_mask, - ) - # slices timing - postprocess_total_time += round(time.perf_counter() - pp_start, 2) - slices_total_time += round(time.perf_counter() - slice_start, 2) - slice_avg = slices_total_time / (page_idx + 1) - rem_minutes, rem_seconds = divmod( - slice_avg * (total_pages - page_idx + 1), 60 - ) - rem_hour, rem_minutes = divmod(rem_minutes, 60) - self.timing_info.setText( - f"Estimated remaining time: " - f"{int(rem_hour):02}:{int(rem_minutes):02}:{int(rem_seconds):02}" - ) - print(f"slice average time(seconds): {slice_avg:.2f}") + self.rf_model.set_params(verbose=0) + slice_start = time.perf_counter() + for slice_mask, idx, total in np_progress( + run_prediction_pipeline( + tiff_stack_file, + self.model_adapter, + self.rf_model, + ), + desc="running the pipeline", + ): + # save the prediction + tifffile.imwrite( + prediction_dir.joinpath(f"slice_{idx:04}_prediction.tiff"), + slice_mask, + ) + # post-processing + print("post processing...") + smoothing_iterations, area_threshold, area_is_absolute = ( + self.get_postprocess_params() + ) + post_mask = postprocess( + slice_mask, + smoothing_iterations, + area_threshold, + area_is_absolute, + ) + tifffile.imwrite( + simple_post_dir.joinpath(f"slice_{idx:04}_post_simple.tiff"), + post_mask, + ) + pp_start = time.perf_counter() + post_sam_mask = postprocess_with_sam( + slice_mask, + smoothing_iterations, + area_threshold, + area_is_absolute, + ) + tifffile.imwrite( + sam_post_dir.joinpath(f"slice_{idx:04}_post_sam.tiff"), + post_sam_mask, + ) + # slices timing + postprocess_total_time += round(time.perf_counter() - pp_start, 2) + slices_total_time += round(time.perf_counter() - slice_start, 2) + slice_avg = slices_total_time / (idx + 1) + rem_minutes, rem_seconds = divmod(slice_avg * (total - idx + 1), 60) + rem_hour, rem_minutes = divmod(rem_minutes, 60) + self.timing_info.setText( + f"Estimated remaining time: " + f"{int(rem_hour):02}:{int(rem_minutes):02}:{int(rem_seconds):02}" + ) + print(f"slice average time(seconds): {slice_avg:.2f}") + slice_start = time.perf_counter() - yield page_idx, total_pages + yield idx, total # stack is done end = dt.datetime.now() self.save_pipeline_stats( - result_dir, start, end, slices_total_time, postprocess_total_time, total_pages + result_dir, start, end, slices_total_time, postprocess_total_time, total ) + self.rf_model.set_params(verbose=1) - def stop_pipeline(self): + def stop_pipeline(self) -> None: if self.pipeline_worker is not None: self.pipeline_worker.quit() self.pipeline_worker = None self.stop_pipeline_button.setEnabled(False) - def update_pipeline_progress(self, values): + def update_pipeline_progress(self, values: tuple) -> None: curr, total = values self.pipeline_progressbar.setMinimum(0) self.pipeline_progressbar.setMaximum(total) self.pipeline_progressbar.setValue(curr + 1) self.pipeline_progressbar.setFormat("slice %v of %m (%p%)") - def pipeline_is_done(self): + def pipeline_is_done(self) -> None: self.run_pipeline_button.setEnabled(True) self.stop_pipeline_button.setEnabled(False) self.remove_temp_storage() print("Prediction is done!") notif.show_info("Prediction is done!") - def remove_temp_storage(self): + def remove_temp_storage(self) -> None: tmp_storage_path = Path.home().joinpath(".featureforest", "tmp_storage.h5") if tmp_storage_path.exists(): tmp_storage_path.unlink() @@ -1175,7 
diff --git a/src/featureforest/models/Cellpose/adapter.py b/src/featureforest/models/Cellpose/adapter.py
index 6d5ae08..8103ffc 100755
--- a/src/featureforest/models/Cellpose/adapter.py
+++ b/src/featureforest/models/Cellpose/adapter.py
@@ -6,7 +6,6 @@
 from featureforest.models.base import BaseModelAdapter
 from featureforest.utils.data import (
     get_nonoverlapped_patches,
-    get_patch_size,
 )
@@ -16,8 +15,8 @@ class CellposeAdapter(BaseModelAdapter):
     def __init__(
         self,
         image_encoder: nn.Module,
-        img_height: float,
-        img_width: float,
+        img_height: int,
+        img_width: int,
         device: torch.device,
         name: str = "Cellpose_cyto3",
     ) -> None:
@@ -28,9 +27,9 @@ def __init__(
         # self.encoder_num_channels = 480
         self.device = device
         self._set_patch_size()
-        assert (
-            int(self.patch_size / 4) == self.patch_size / 4
-        ), f"patch size {self.patch_size} is not divisible by 4"
+        assert int(self.patch_size / 4) == self.patch_size / 4, (
+            f"patch size {self.patch_size} is not divisible by 4"
+        )
         # input transform for sam
         self.input_transforms = None
@@ -45,11 +44,7 @@ def __init__(
             ]
         )

-    def _set_patch_size(self) -> None:
-        self.patch_size = get_patch_size(self.img_height, self.img_width)
-        self.overlap = self.patch_size // 2
-
-    def get_features_patches(self, in_patches: Tensor) -> tuple[Tensor, Tensor]:
+    def get_features_patches(self, in_patches: Tensor) -> Tensor:
         # cellpose works on microscopic channels not RGB
         # we need to select one channel and add a second zero channel
         in_patches = torch.cat(
diff --git a/src/featureforest/models/Cellpose/model.py b/src/featureforest/models/Cellpose/model.py
index 37e16df..74642fb 100755
--- a/src/featureforest/models/Cellpose/model.py
+++ b/src/featureforest/models/Cellpose/model.py
@@ -25,7 +25,7 @@ def forward(self, x):
         return out


-def get_model(img_height: float, img_width: float, *args, **kwargs) -> CellposeAdapter:
+def get_model(img_height: int, img_width: int, *args, **kwargs) -> CellposeAdapter:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"running on {device}")
     # init the model
diff --git a/src/featureforest/models/DinoV2/adapter.py b/src/featureforest/models/DinoV2/adapter.py
index 67b9188..b34e9e3 100644
--- a/src/featureforest/models/DinoV2/adapter.py
+++ b/src/featureforest/models/DinoV2/adapter.py
@@ -1,5 +1,3 @@
-from typing import Tuple
-
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -12,14 +10,10 @@

 class DinoV2Adapter(BaseModelAdapter):
-    """DinoV2 model adapter
-    """
+    """DinoV2 model adapter"""
+
     def __init__(
-        self,
-        model: nn.Module,
-        img_height: float,
-        img_width: float,
-        device: torch.device
+        self, model: nn.Module, img_height: int, img_width: int, device: torch.device
     ) -> None:
         super().__init__(model, img_height, img_width, device)
         self.name = "DinoV2"
@@ -30,40 +24,43 @@ def __init__(
         self.device = device

         # input transform for dinov2
-        self.input_transforms = tv_transforms2.Compose([
-            tv_transforms2.ToImage(),
-            tv_transforms2.Resize(self.patch_size * self.dino_patch_size),
-            tv_transforms2.ToDtype(dtype=torch.float32, scale=True)
-        ])
+        self.input_transforms = tv_transforms2.Compose(
+            [
+                tv_transforms2.ToImage(),
+                tv_transforms2.Resize(self.patch_size * self.dino_patch_size),
+                tv_transforms2.ToDtype(dtype=torch.float32, scale=True),
+            ]
+        )
         # to transform feature patches back to the original patch size
-        self.embedding_transform = tv_transforms2.Compose([
-            tv_transforms2.Resize(
-                (self.patch_size, self.patch_size),
-                interpolation=tv_transforms2.InterpolationMode.BICUBIC,
-                antialias=True
-            ),
-        ])
+        self.embedding_transform = tv_transforms2.Compose(
+            [
+                tv_transforms2.Resize(
+                    (self.patch_size, self.patch_size),
+                    interpolation=tv_transforms2.InterpolationMode.BICUBIC,
+                    antialias=True,
+                ),
+            ]
+        )

     def _set_patch_size(self) -> None:
         self.patch_size = self.dino_patch_size * 5
         self.overlap = self.dino_patch_size * 2

-    def get_features_patches(
-        self, in_patches: Tensor
-    ) -> Tuple[Tensor, Tensor]:
+    def get_features_patches(self, in_patches: Tensor) -> Tensor:
         # get the dinov2 encoder outputs
         with torch.no_grad():
             # we use get_intermediate_layers method of dinov2, which returns a tuple.
             # output shape: b, 384, h, w if reshape is true.
             output_features = self.model.get_intermediate_layers(
-                self.input_transforms(in_patches), 1,
-                return_class_token=False, reshape=True
-            )[0]
+                self.input_transforms(in_patches),
+                1,
+                return_class_token=False,
+                reshape=True,
+            )[0]  # type: ignore

         # get non-overlapped feature patches
         feature_patches = get_nonoverlapped_patches(
-            output_features.cpu(),
-            self.patch_size, self.overlap
+            output_features.cpu(), self.patch_size, self.overlap
         )

         return feature_patches
diff --git a/src/featureforest/models/DinoV2/model.py b/src/featureforest/models/DinoV2/model.py
index f003e93..0c2df33 100644
--- a/src/featureforest/models/DinoV2/model.py
+++ b/src/featureforest/models/DinoV2/model.py
@@ -1,14 +1,10 @@
-from typing import Tuple
-
 import torch

 # from featureforest.utils.downloader import download_model
 from .adapter import DinoV2Adapter


-def get_model(
-    img_height: float, img_width: float, *args, **kwargs
-) -> DinoV2Adapter:
+def get_model(img_height: int, img_width: int, *args, **kwargs) -> DinoV2Adapter:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"running on {device}")
     # get the pretrained model
@@ -18,8 +14,6 @@ def get_model(
     model.eval()

     # create the model adapter
-    dino_model_adapter = DinoV2Adapter(
-        model, img_height, img_width, device
-    )
+    dino_model_adapter = DinoV2Adapter(model, img_height, img_width, device)

     return dino_model_adapter
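Unlike the other adapters, `DinoV2Adapter._set_patch_size` derives the patch geometry from the ViT token size rather than from the image resolution. A quick sketch of the resulting numbers, assuming `dino_patch_size = 14` (the usual DINOv2 ViT-/14 token size; the attribute itself is set in a part of the adapter elided from this diff):

```python
dino_patch_size = 14  # assumption: DINOv2 ViT-/14 token size

patch_size = dino_patch_size * 5  # 70-pixel working patches
overlap = dino_patch_size * 2     # 28 pixels shared between neighbouring patches

# the input transform resizes each patch to patch_size * dino_patch_size pixels
# (70 * 14 = 980), so the ViT sees a 70x70 token grid per patch
print(patch_size, overlap, patch_size * dino_patch_size)  # 70 28 980
```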
diff --git a/src/featureforest/models/MobileSAM/adapter.py b/src/featureforest/models/MobileSAM/adapter.py
index b9d7552..efc7259 100644
--- a/src/featureforest/models/MobileSAM/adapter.py
+++ b/src/featureforest/models/MobileSAM/adapter.py
@@ -1,5 +1,3 @@
-from typing import Tuple
-
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -7,20 +5,15 @@

 from featureforest.models.base import BaseModelAdapter
 from featureforest.utils.data import (
-    get_patch_size,
     get_nonoverlapped_patches,
 )


 class MobileSAMAdapter(BaseModelAdapter):
-    """MobileSAM model adapter
-    """
+    """MobileSAM model adapter"""
+
     def __init__(
-        self,
-        model: nn.Module,
-        img_height: float,
-        img_width: float,
-        device: torch.device
+        self, model: nn.Module, img_height: int, img_width: int, device: torch.device
     ) -> None:
         super().__init__(model, img_height, img_width, device)
         self.name = "MobileSAM"
@@ -33,46 +26,42 @@ def __init__(

         # input transform for sam
         self.sam_input_dim = 1024
-        self.input_transforms = tv_transforms2.Compose([
-            tv_transforms2.Resize(
-                (self.sam_input_dim, self.sam_input_dim),
-                interpolation=tv_transforms2.InterpolationMode.BICUBIC,
-                antialias=True
-            ),
-        ])
+        self.input_transforms = tv_transforms2.Compose(
+            [
+                tv_transforms2.Resize(
+                    (self.sam_input_dim, self.sam_input_dim),
+                    interpolation=tv_transforms2.InterpolationMode.BICUBIC,
+                    antialias=True,
+                ),
+            ]
+        )
         # to transform feature patches back to the original patch size
-        self.embedding_transform = tv_transforms2.Compose([
-            tv_transforms2.Resize(
-                (self.patch_size, self.patch_size),
-                interpolation=tv_transforms2.InterpolationMode.BICUBIC,
-                antialias=True
-            ),
-        ])
-
-    def _set_patch_size(self) -> None:
-        self.patch_size = get_patch_size(self.img_height, self.img_width)
-        self.overlap = self.patch_size // 2
+        self.embedding_transform = tv_transforms2.Compose(
+            [
+                tv_transforms2.Resize(
+                    (self.patch_size, self.patch_size),
+                    interpolation=tv_transforms2.InterpolationMode.BICUBIC,
+                    antialias=True,
+                ),
+            ]
+        )

-    def get_features_patches(
-        self, in_patches: Tensor
-    ) -> Tuple[Tensor, Tensor]:
+    def get_features_patches(self, in_patches: Tensor) -> Tensor:
         # get the mobile-sam encoder and embedding layer outputs
         with torch.no_grad():
-            output, embed_output, _ = self.encoder(
-                self.input_transforms(in_patches)
-            )
+            output, embed_output, _ = self.encoder(self.input_transforms(in_patches))

         # get non-overlapped feature patches
         out_feature_patches = get_nonoverlapped_patches(
-            self.embedding_transform(output.cpu()),
-            self.patch_size, self.overlap
+            self.embedding_transform(output.cpu()), self.patch_size, self.overlap
         )
         embed_feature_patches = get_nonoverlapped_patches(
-            self.embedding_transform(embed_output.cpu()),
-            self.patch_size, self.overlap
+            self.embedding_transform(embed_output.cpu()), self.patch_size, self.overlap
         )
+        # concat both features together on channel dimension
+        output = torch.cat([out_feature_patches, embed_feature_patches], dim=-1)

-        return out_feature_patches, embed_feature_patches
+        return output

     def get_total_output_channels(self) -> int:
         return self.encoder_num_channels + self.embed_layer_num_channels
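With the single-tensor return, the encoder output and the patch-embedding output are fused on the channel axis, so callers no longer special-case multi-output models. A shape-level sketch of what `get_features_patches` now yields; the 256 + 64 channel split is an assumption taken from the 320-channel shapes the updated MobileSAM tests expect:

```python
import torch

# per-patch feature maps in (num_patches, stride, stride, channels) layout
out_feature_patches = torch.rand(4, 192, 192, 256)   # encoder neck features (assumed 256ch)
embed_feature_patches = torch.rand(4, 192, 192, 64)  # patch-embedding features (assumed 64ch)

# concatenate on the last (channel) dimension, as the adapter now does
fused = torch.cat([out_feature_patches, embed_feature_patches], dim=-1)
assert fused.shape == (4, 192, 192, 320)  # consistent with get_total_output_channels()
```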
diff --git a/src/featureforest/models/MobileSAM/model.py b/src/featureforest/models/MobileSAM/model.py
index fc225c7..5f7e301 100644
--- a/src/featureforest/models/MobileSAM/model.py
+++ b/src/featureforest/models/MobileSAM/model.py
@@ -1,28 +1,22 @@
-from typing import Tuple
-
 import torch
-
-from .tiny_vit_sam import TinyViT
 from segment_anything.modeling import MaskDecoder, PromptEncoder, Sam, TwoWayTransformer

 from featureforest.utils.downloader import download_model
+
 from .adapter import MobileSAMAdapter
+from .tiny_vit_sam import TinyViT


-def get_model(
-    img_height: float, img_width: float, *args, **kwargs
-) -> MobileSAMAdapter:
+def get_model(img_height: int, img_width: int, *args, **kwargs) -> MobileSAMAdapter:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"running on {device}")
     # get the model
     model = setup_model().to(device)
     # download model's weights
-    model_url = \
+    model_url = (
         "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt"
-    model_file = download_model(
-        model_url=model_url,
-        model_name="mobile_sam.pt"
     )
+    model_file = download_model(model_url=model_url, model_name="mobile_sam.pt")
     if model_file is None:
         raise ValueError(f"Could not download the model from {model_url}.")
@@ -32,9 +26,7 @@ def get_model(
     model.eval()

     # create the model adapter
-    sam_model_adapter = MobileSAMAdapter(
-        model, img_height, img_width, device
-    )
+    sam_model_adapter = MobileSAMAdapter(model, img_height, img_width, device)

     return sam_model_adapter

@@ -46,18 +38,20 @@ def setup_model() -> Sam:
     image_embedding_size = image_size // vit_patch_size
     mobile_sam = Sam(
         image_encoder=TinyViT(
-            img_size=1024, in_chans=3, num_classes=1000,
+            img_size=1024,
+            in_chans=3,
+            num_classes=1000,
             embed_dims=[64, 128, 160, 320],
             depths=[2, 2, 6, 2],
             num_heads=[2, 4, 5, 10],
             window_sizes=[7, 7, 14, 7],
-            mlp_ratio=4.,
-            drop_rate=0.,
+            mlp_ratio=4.0,
+            drop_rate=0.0,
             drop_path_rate=0.0,
             use_checkpoint=False,
             mbconv_expand_ratio=4.0,
             local_conv_size=3,
-            layer_lr_decay=0.8
+            layer_lr_decay=0.8,
         ),
         prompt_encoder=PromptEncoder(
             embed_dim=prompt_embed_dim,
diff --git a/src/featureforest/models/SAM/adapter.py b/src/featureforest/models/SAM/adapter.py
index 25bcc20..6c74649 100644
--- a/src/featureforest/models/SAM/adapter.py
+++ b/src/featureforest/models/SAM/adapter.py
@@ -1,5 +1,3 @@
-from typing import Tuple
-
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -7,7 +5,6 @@

 from featureforest.models.base import BaseModelAdapter
 from featureforest.utils.data import (
-    get_patch_size,
     get_nonoverlapped_patches,
 )
@@ -17,11 +14,12 @@ class SAMAdapter(BaseModelAdapter):
     Supports: 1) default 'vit_h' model, and
     2) light microscopy and electron microscopy `micro-sam` models ('vit_b').
     """
+
     def __init__(
         self,
         image_encoder: nn.Module,
-        img_height: float,
-        img_width: float,
+        img_height: int,
+        img_width: int,
         device: torch.device,
         name: str,
     ) -> None:
@@ -29,60 +27,56 @@ def __init__(
         self.name = name
         # we need sam image encoder part
         self.encoder = image_encoder
-        self.encoder_num_channels = 256
+        self.encoder_num_channels: int = 256
         # NOTE: The parameter below matches the SAM model's encoder embedding dimension.
-        self.embed_layer_num_channels = image_encoder.patch_embed.proj.out_channels
+        self.embed_layer_num_channels: int = image_encoder.patch_embed.proj.out_channels
         self._set_patch_size()
         self.device = device

         # input transform for sam
         self.sam_input_dim = 1024
-        self.input_transforms = tv_transforms2.Compose([
-            tv_transforms2.Resize(
-                (self.sam_input_dim, self.sam_input_dim),
-                interpolation=tv_transforms2.InterpolationMode.BICUBIC,
-                antialias=True
-            ),
-        ])
+        self.input_transforms = tv_transforms2.Compose(
+            [
+                tv_transforms2.Resize(
+                    (self.sam_input_dim, self.sam_input_dim),
+                    interpolation=tv_transforms2.InterpolationMode.BICUBIC,
+                    antialias=True,
+                ),
+            ]
+        )
         # to transform feature patches back to the original patch size
-        self.embedding_transform = tv_transforms2.Compose([
-            tv_transforms2.Resize(
-                (self.patch_size, self.patch_size),
-                interpolation=tv_transforms2.InterpolationMode.BICUBIC,
-                antialias=True
-            ),
-        ])
-
-    def _set_patch_size(self) -> None:
-        self.patch_size = get_patch_size(self.img_height, self.img_width)
-        self.overlap = self.patch_size // 2
+        self.embedding_transform = tv_transforms2.Compose(
+            [
+                tv_transforms2.Resize(
+                    (self.patch_size, self.patch_size),
+                    interpolation=tv_transforms2.InterpolationMode.BICUBIC,
+                    antialias=True,
+                ),
+            ]
+        )

-    def get_features_patches(
-        self, in_patches: Tensor
-    ) -> Tuple[Tensor, Tensor]:
+    def get_features_patches(self, in_patches: Tensor) -> Tensor:
         # get the sam encoder and embedding layer outputs
         with torch.no_grad():
             # output: b,256,64,64
-            output = self.encoder(
-                self.input_transforms(in_patches)
-            )
+            output = self.encoder(self.input_transforms(in_patches))
             # embed_output: b,64,64,1280 -> b,1280,64,64
             embed_output = self.encoder.patch_embed(
                 self.input_transforms(in_patches)
-            ).permute(0, 3, 1, 2)
+            ).permute(0, 3, 1, 2)  # type: ignore

         # get non-overlapped feature patches
         out_feature_patches = get_nonoverlapped_patches(
-            self.embedding_transform(output.cpu()),
-            self.patch_size, self.overlap
+            self.embedding_transform(output.cpu()), self.patch_size, self.overlap
         )
         embed_feature_patches = get_nonoverlapped_patches(
-            self.embedding_transform(embed_output.cpu()),
-            self.patch_size, self.overlap
+            self.embedding_transform(embed_output.cpu()), self.patch_size, self.overlap
         )
+        # concat both features together on channel dimension
+        output = torch.cat([out_feature_patches, embed_feature_patches], dim=-1)

-        return out_feature_patches, embed_feature_patches
+        return output

     def get_total_output_channels(self) -> int:
         return self.encoder_num_channels + self.embed_layer_num_channels
diff --git a/src/featureforest/models/SAM/model.py b/src/featureforest/models/SAM/model.py
index daf0f2a..1865743 100644
--- a/src/featureforest/models/SAM/model.py
+++ b/src/featureforest/models/SAM/model.py
@@ -1,11 +1,10 @@
 import torch
-
-from segment_anything.modeling import Sam
 from segment_anything import sam_model_registry
+from segment_anything.modeling import Sam

 from featureforest.utils.downloader import download_model
-from .adapter import SAMAdapter
+
+from .adapter import SAMAdapter

 MODEL_URLS = {
     "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
@@ -27,16 +26,13 @@
 def get_model(
-    img_height: float, img_width: float, model_type: str = "vit_h", *args, **kwargs
+    img_height: int, img_width: int, model_type: str = "vit_h", *args, **kwargs
 ) -> SAMAdapter:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"running on {device}")
     # download model's weights
     model_url = MODEL_URLS[model_type]
-    model_file = download_model(
-        model_url=model_url,
-        model_name=MODEL_FNAMES[model_type]
-    )
+    model_file = download_model(model_url=model_url, model_name=MODEL_FNAMES[model_type])
     if model_file is None:
         raise ValueError(f"Could not download the model from {model_url}.")
diff --git a/src/featureforest/models/SAM2/adapter.py b/src/featureforest/models/SAM2/adapter.py
index 571d4ef..00fa078 100644
--- a/src/featureforest/models/SAM2/adapter.py
+++ b/src/featureforest/models/SAM2/adapter.py
@@ -1,5 +1,3 @@
-from typing import Tuple
-
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -7,21 +5,20 @@

 from featureforest.models.base import BaseModelAdapter
 from featureforest.utils.data import (
-    get_patch_size,
     get_nonoverlapped_patches,
 )


 class SAM2Adapter(BaseModelAdapter):
-    """SAM2 model adapter
-    """
+    """SAM2 model adapter"""
+
     def __init__(
         self,
         image_encoder: nn.Module,
-        img_height: float,
-        img_width: float,
+        img_height: int,
+        img_width: int,
         device: torch.device,
-        name: str = "SAM2_Large"
+        name: str = "SAM2_Large",
     ) -> None:
         super().__init__(image_encoder, img_height, img_width, device)
         # for different flavors of SAM2 only the name is different.
@@ -34,44 +31,41 @@ def __init__(

         # input transform for sam
         self.sam_input_dim = 1024
-        self.input_transforms = tv_transforms2.Compose([
-            tv_transforms2.Resize(
-                (self.sam_input_dim, self.sam_input_dim),
-                interpolation=tv_transforms2.InterpolationMode.BICUBIC,
-                antialias=True
-            ),
-        ])
+        self.input_transforms = tv_transforms2.Compose(
+            [
+                tv_transforms2.Resize(
+                    (self.sam_input_dim, self.sam_input_dim),
+                    interpolation=tv_transforms2.InterpolationMode.BICUBIC,
+                    antialias=True,
+                ),
+            ]
+        )
         # to transform feature patches back to the original patch size
-        self.embedding_transform = tv_transforms2.Compose([
-            tv_transforms2.Resize(
-                (self.patch_size, self.patch_size),
-                interpolation=tv_transforms2.InterpolationMode.BICUBIC,
-                antialias=True
-            ),
-        ])
-
-    def _set_patch_size(self) -> None:
-        self.patch_size = get_patch_size(self.img_height, self.img_width)
-        self.overlap = self.patch_size // 2
+        self.embedding_transform = tv_transforms2.Compose(
+            [
+                tv_transforms2.Resize(
+                    (self.patch_size, self.patch_size),
+                    interpolation=tv_transforms2.InterpolationMode.BICUBIC,
+                    antialias=True,
+                ),
+            ]
+        )

-    def get_features_patches(
-        self, in_patches: Tensor
-    ) -> Tuple[Tensor, Tensor]:
+    def get_features_patches(self, in_patches: Tensor) -> Tensor:
         # get the image encoder outputs
         with torch.no_grad():
-            output = self.encoder(
-                self.input_transforms(in_patches)
-            )
+            output = self.encoder(self.input_transforms(in_patches))
         # backbone_fpn contains 3 levels of features from high to low resolution.
         # [b, 256, 256, 256]
         # [b, 256, 128, 128]
         # [b, 256, 64, 64]
         features = [
-            self.embedding_transform(feat.cpu())
-            for feat in output["backbone_fpn"]
+            self.embedding_transform(feat.cpu()) for feat in output["backbone_fpn"]
         ]
         features = torch.cat(features, dim=1)
-        out_feature_patches = get_nonoverlapped_patches(features, self.patch_size, self.overlap)
+        out_feature_patches = get_nonoverlapped_patches(
+            features, self.patch_size, self.overlap
+        )

         return out_feature_patches
diff --git a/src/featureforest/models/SAM2/model.py b/src/featureforest/models/SAM2/model.py
index 679b6c2..45f17a6 100644
--- a/src/featureforest/models/SAM2/model.py
+++ b/src/featureforest/models/SAM2/model.py
@@ -1,34 +1,27 @@
-from pathlib import Path
-
 import torch
-
-from sam2.modeling.sam2_base import SAM2Base
 from sam2.build_sam import build_sam2
+from sam2.modeling.sam2_base import SAM2Base

-from featureforest.utils.downloader import download_model
 from featureforest.models.SAM2.adapter import SAM2Adapter
+from featureforest.utils.downloader import download_model


-def get_large_model(
-    img_height: float, img_width: float, *args, **kwargs
-) -> SAM2Adapter:
+def get_large_model(img_height: int, img_width: int, *args, **kwargs) -> SAM2Adapter:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"running on {device}")
     # download model's weights
-    model_url = \
+    model_url = (
         "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
-    model_file = download_model(
-        model_url=model_url,
-        model_name="sam2.1_hiera_large.pt"
     )
+    model_file = download_model(model_url=model_url, model_name="sam2.1_hiera_large.pt")
     if model_file is None:
         raise ValueError(f"Could not download the model from {model_url}.")

     # init the model
     model: SAM2Base = build_sam2(
-        config_file= "configs/sam2.1/sam2.1_hiera_l.yaml",
+        config_file="configs/sam2.1/sam2.1_hiera_l.yaml",
         ckpt_path=model_file,
-        device="cpu"
+        device="cpu",
     )
     # to save some GPU memory, only put the encoder part on GPU
     sam_image_encoder = model.image_encoder.to(device)
@@ -42,26 +35,22 @@ def get_large_model(
     return sam2_model_adapter


-def get_base_model(
-    img_height: float, img_width: float, *args, **kwargs
-) -> SAM2Adapter:
+def get_base_model(img_height: int, img_width: int, *args, **kwargs) -> SAM2Adapter:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"running on {device}")
     # download model's weights
-    model_url = \
-        "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt"
+    model_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt"
     model_file = download_model(
-        model_url=model_url,
-        model_name="sam2.1_hiera_base_plus.pt"
+        model_url=model_url, model_name="sam2.1_hiera_base_plus.pt"
     )
     if model_file is None:
         raise ValueError(f"Could not download the model from {model_url}.")

     # init the model
     model: SAM2Base = build_sam2(
-        config_file= "configs/sam2.1/sam2.1_hiera_b+.yaml",
+        config_file="configs/sam2.1/sam2.1_hiera_b+.yaml",
         ckpt_path=model_file,
-        device="cpu"
+        device="cpu",
     )
     # to save some GPU memory, only put the encoder part on GPU
     sam_image_encoder = model.image_encoder.to(device)
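`SAM2Adapter.get_features_patches` flattens the three `backbone_fpn` levels into one tensor before extracting non-overlapped patches. A small sketch of that step, using the level shapes from the adapter's comment; `torch.nn.Upsample` stands in for the adapter's bicubic `embedding_transform`, and the 64×64 target size is illustrative:

```python
import torch

levels = [
    torch.rand(1, 256, 256, 256),  # highest-resolution FPN level
    torch.rand(1, 256, 128, 128),
    torch.rand(1, 256, 64, 64),    # lowest-resolution FPN level
]

# resize every level to a common working size, then stack on the channel axis
resize = torch.nn.Upsample(size=(64, 64), mode="bicubic", align_corners=False)
features = torch.cat([resize(lvl) for lvl in levels], dim=1)
assert features.shape == (1, 256 * 3, 64, 64)  # 768 channels, as the SAM2 tests assert
```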
diff --git a/src/featureforest/models/__init__.py b/src/featureforest/models/__init__.py
index ec870c6..50a4419 100644
--- a/src/featureforest/models/__init__.py
+++ b/src/featureforest/models/__init__.py
@@ -25,7 +25,7 @@ def get_available_models() -> list[str]:


 def get_model(
-    model_name: str, img_height: float, img_width: float, *args, **kwargs
+    model_name: str, img_height: int, img_width: int, *args, **kwargs
 ) -> BaseModelAdapter:
     """Returns the requested model adapter.
diff --git a/src/featureforest/models/base.py b/src/featureforest/models/base.py
index 97aeb5a..178b490 100644
--- a/src/featureforest/models/base.py
+++ b/src/featureforest/models/base.py
@@ -5,18 +5,15 @@

 from ..utils.data import (
     get_nonoverlapped_patches,
+    get_patch_size,
 )


 class BaseModelAdapter:
-    """Base class for adapting any models in featureforest.
-    """
+    """Base class for adapting any models in featureforest."""
+
     def __init__(
-        self,
-        model: nn.Module,
-        img_height: float,
-        img_width: float,
-        device: torch.device
+        self, model: nn.Module, img_height: int, img_width: int, device: torch.device
     ) -> None:
         """Initialization function
@@ -34,33 +31,52 @@ def __init__(
         self.device = device
         # set patch size and overlap
         self.patch_size = 512
-        self.overlap = self.patch_size // 2
+        self.overlap = self.patch_size // 4
+        self._no_patching = False

         # input image transforms
-        self.input_transforms = tv_transforms2.Compose([
-            tv_transforms2.Resize(
-                (1024, 1024),
-                interpolation=tv_transforms2.InterpolationMode.BICUBIC,
-                antialias=True
-            ),
-        ])
+        self.input_transforms = tv_transforms2.Compose(
+            [
+                tv_transforms2.Resize(
+                    (1024, 1024),
+                    interpolation=tv_transforms2.InterpolationMode.BICUBIC,
+                    antialias=True,
+                ),
+            ]
+        )
         # to transform feature patches back to the original patch size
-        self.embedding_transform = tv_transforms2.Compose([
-            tv_transforms2.Resize(
-                (self.patch_size, self.patch_size),
-                interpolation=tv_transforms2.InterpolationMode.BICUBIC,
-                antialias=True
-            ),
-        ])
+        self.embedding_transform = tv_transforms2.Compose(
+            [
+                tv_transforms2.Resize(
+                    (self.patch_size, self.patch_size),
+                    interpolation=tv_transforms2.InterpolationMode.BICUBIC,
+                    antialias=True,
+                ),
+            ]
+        )
+
+    @property
+    def no_patching(self) -> bool:
+        return self._no_patching
+
+    @no_patching.setter
+    def no_patching(self, value: bool):
+        self._no_patching = value
+        self._set_patch_size()

     def _set_patch_size(self) -> None:
         """Sets the proper patch size and patch overlap
         with respect to the model & image resolution.
         """
-        raise NotImplementedError
+        if self._no_patching:
+            self.patch_size = self.img_height
+            self.overlap = 0
+        else:
+            self.patch_size = get_patch_size(self.img_height, self.img_width)
+            self.overlap = self.patch_size // 4
+        # update embedding transform
+        self.embedding_transform.transforms[0].size = [self.patch_size, self.patch_size]

-    def get_features_patches(
-        self, in_patches: Tensor
-    ) -> Tensor:
+    def get_features_patches(self, in_patches: Tensor) -> Tensor:
         """Returns model's extracted features.
         This is an abstract function, and should be overridden.
@@ -77,8 +93,7 @@ def get_features_patches(

         # get non-overlapped feature patches
         feature_patches = get_nonoverlapped_patches(
-            self.embedding_transform(output_features.cpu()),
-            self.patch_size, self.overlap
+            self.embedding_transform(output_features.cpu()), self.patch_size, self.overlap
         )

         return feature_patches
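The new `no_patching` property is the single switch for patch geometry: toggling it re-runs `_set_patch_size`, which either treats the whole image as one patch with zero overlap or re-derives the tile size with the new quarter-patch overlap. A sketch of the observable behaviour, mirroring `test_no_patching_mode` further down in this diff (the `Identity` model is a stand-in just to construct the base class):

```python
import torch

from featureforest.models.base import BaseModelAdapter

# illustrative only: a bare BaseModelAdapter for a 256x256 image
adapter = BaseModelAdapter(
    model=torch.nn.Identity(),
    img_height=256,
    img_width=256,
    device=torch.device("cpu"),
)

adapter.no_patching = True
assert adapter.patch_size == 256 and adapter.overlap == 0   # one whole-image "patch"

adapter.no_patching = False
assert adapter.patch_size == 256 and adapter.overlap == 64  # tiles, quarter-patch overlap
```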
- """ + """Base class for adapting any models in featureforest.""" + def __init__( - self, - model: nn.Module, - img_height: float, - img_width: float, - device: torch.device + self, model: nn.Module, img_height: int, img_width: int, device: torch.device ) -> None: """Initialization function @@ -34,33 +31,52 @@ def __init__( self.device = device # set patch size and overlap self.patch_size = 512 - self.overlap = self.patch_size // 2 + self.overlap = self.patch_size // 4 + self._no_patching = False # input image transforms - self.input_transforms = tv_transforms2.Compose([ - tv_transforms2.Resize( - (1024, 1024), - interpolation=tv_transforms2.InterpolationMode.BICUBIC, - antialias=True - ), - ]) + self.input_transforms = tv_transforms2.Compose( + [ + tv_transforms2.Resize( + (1024, 1024), + interpolation=tv_transforms2.InterpolationMode.BICUBIC, + antialias=True, + ), + ] + ) # to transform feature patches back to the original patch size - self.embedding_transform = tv_transforms2.Compose([ - tv_transforms2.Resize( - (self.patch_size, self.patch_size), - interpolation=tv_transforms2.InterpolationMode.BICUBIC, - antialias=True - ), - ]) + self.embedding_transform = tv_transforms2.Compose( + [ + tv_transforms2.Resize( + (self.patch_size, self.patch_size), + interpolation=tv_transforms2.InterpolationMode.BICUBIC, + antialias=True, + ), + ] + ) + + @property + def no_patching(self) -> bool: + return self._no_patching + + @no_patching.setter + def no_patching(self, value: bool): + self._no_patching = value + self._set_patch_size() def _set_patch_size(self) -> None: """Sets the proper patch size and patch overlap with respect to the model & image resolution. """ - raise NotImplementedError + if self._no_patching: + self.patch_size = self.img_height + self.overlap = 0 + else: + self.patch_size = get_patch_size(self.img_height, self.img_width) + self.overlap = self.patch_size // 4 + # update embedding transform + self.embedding_transform.transforms[0].size = [self.patch_size, self.patch_size] - def get_features_patches( - self, in_patches: Tensor - ) -> Tensor: + def get_features_patches(self, in_patches: Tensor) -> Tensor: """Returns model's extracted features. This is an abstract function, and should be overridden. @@ -77,8 +93,7 @@ def get_features_patches( # get non-overlapped feature patches feature_patches = get_nonoverlapped_patches( - self.embedding_transform(output_features.cpu()), - self.patch_size, self.overlap + self.embedding_transform(output_features.cpu()), self.patch_size, self.overlap ) return feature_patches diff --git a/src/featureforest/models/util.py b/src/featureforest/models/util.py deleted file mode 100644 index f1fcd12..0000000 --- a/src/featureforest/models/util.py +++ /dev/null @@ -1,89 +0,0 @@ -import os -from pathlib import Path -from typing import Union, Optional - -import numpy as np -import imageio.v3 as imageio - -from . import _MODELS_DICT, get_model -from ..utils.extract import extract_embeddings_to_file, get_stack_dims - - -def extract_features( - image: np.ndarray, - output_path: Union[os.PathLike, str], - model_name: str = "SAM2_Large", - image_height: Optional[int] = None, - image_width: Optional[int] = None, -) -> str: - """Extracts features for the chosen model. - - Args: - image: The input image. - output_path: The filepath where the extracted features will be stored. - model_name: The choice of model that will be used for feature extraction. - By default, extracts features for `SAM2_Large`. - image_height: The height of input image. 
diff --git a/src/featureforest/utils/config.py b/src/featureforest/utils/config.py
index dbe926f..05ed482 100644
--- a/src/featureforest/utils/config.py
+++ b/src/featureforest/utils/config.py
@@ -1,5 +1,5 @@
 import napari
-
+import napari.layers

 BG_CLASS_NAME = "background"
 NAPARI_BG_CLASS = 0
diff --git a/src/featureforest/utils/data.py b/src/featureforest/utils/data.py
index 5c276b0..85b81ef 100644
--- a/src/featureforest/utils/data.py
+++ b/src/featureforest/utils/data.py
@@ -1,11 +1,38 @@
 from typing import Optional

 import numpy as np
+import torch
 import torch.nn.functional as F
 from numpy import ndarray
 from torch import Tensor


+def get_model_ready_image(image: np.ndarray) -> torch.Tensor:
+    """Convert the input image to a torch tensor and normalize it.
+    Args:
+        image (np.ndarray): Input image to be converted (H, W, C).
+    Returns:
+        torch.Tensor: The input image as a torch tensor, normalized to [0, 1].
+    """
+    assert image.ndim < 4, "Input image must be 2D or 3D (single channel or RGB)."
+    # image to torch tensor
+    img_data = torch.from_numpy(image.copy()).to(torch.float32)
+    # normalize in [0, 1]
+    _min = img_data.min()
+    _max = img_data.max()
+    img_data = (img_data - _min) / (_max - _min)
+    # for image encoders, the input image must be in RGB.
+    if is_image_rgb(img_data.numpy()):
+        # it's already RGB
+        img_data = img_data[..., :3]  # discard the alpha channel (in case of PNG).
+        img_data = img_data.permute([2, 0, 1])  # make it channel first
+    else:
+        # make it RGB by repeating the single channel
+        img_data = img_data.unsqueeze(0).expand(3, -1, -1)
+
+    return img_data
+
+
 def get_patch_size(
     img_height: float, img_width: float, divisible_by: Optional[int] = None
 ) -> int:
@@ -74,6 +101,9 @@ def get_paddings(
     # pad amount should be enough to make the
     # (final size - patch_size) / stride an integer number.
     # see https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html
+    # if whole image is just one patch
+    if img_height == patch_size and img_width == patch_size and patch_size == stride:
+        return 0, 0
     new_width = img_width + 2 * margin
     pad_right = stride - int((new_width - patch_size) % stride)
     new_height = img_height + 2 * margin
@@ -83,33 +113,34 @@ def get_paddings(


 def patchify(
-    images: Tensor, patch_size: Optional[int] = None, overlap: Optional[int] = None
+    image: Tensor, patch_size: Optional[int] = None, overlap: Optional[int] = None
 ) -> Tensor:
     """Divide images into patches.
-    images: (B, C, H, W)
-    out: (B*N, C, patch_size, patch_size)
+    image: (C, H, W)
+    out: (N, C, patch_size, patch_size)

     Args:
-        images (Tensor): a batch of images of shape (B, C, H, W)
+        image (Tensor): an image of shape (C, H, W)
         patch_size (Optional[int], optional): patch size. Defaults to None.
         overlap (Optional[int], optional): patch overlap. Defaults to None.

     Returns:
-        Tensor: patches of the input batch of shape (B*N, C, patch_size, patch_size)
+        Tensor: patches of shape (N, C, patch_size, patch_size)
     """
-    _, c, img_height, img_width = images.shape
+    c, img_height, img_width = image.shape
     if patch_size is None:
         patch_size = get_patch_size(img_height, img_width)
-        overlap = patch_size // 2
+        overlap = patch_size // 4
     if overlap is None:
-        overlap = patch_size // 2
+        overlap = patch_size // 4

     stride, margin = get_stride_margin(patch_size, overlap)
     pad_right, pad_bottom = get_paddings(
         patch_size, stride, margin, img_height, img_width
     )
     pad = (margin, pad_right + margin, margin, pad_bottom + margin)
-    padded_imgs = F.pad(images, pad=pad, mode="reflect")
+    # add batch dim and pad the image
+    padded_imgs = F.pad(image.unsqueeze(0), pad=pad, mode="reflect")
     # making patches using torch unfold method
     patches = padded_imgs.unfold(2, patch_size, step=stride).unfold(
         3, patch_size, step=stride
@@ -120,7 +151,7 @@ def patchify(


 def get_num_patches(
-    img_height: float, img_width: float, patch_size: int, overlap: int
+    img_height: int, img_width: int, patch_size: int, overlap: int
 ) -> tuple[int, int]:
     """Returns number of patches per each image dimension.
@@ -172,8 +203,8 @@ def get_nonoverlapped_patches(patches: Tensor, patch_size: int, overlap: int) ->
 def get_patch_index(
     pix_y: float,
     pix_x: float,
-    img_height: float,
-    img_width: float,
+    img_height: int,
+    img_width: int,
     patch_size: int,
     overlap: int,
 ) -> int:
@@ -199,8 +230,8 @@ def get_patch_index(
 def get_patch_indices(
     pixel_coords: ndarray,
-    img_height: float,
-    img_width: float,
+    img_height: int,
+    img_width: int,
     patch_size: int,
     overlap: int,
 ) -> ndarray:
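With the signature change, `patchify` consumes a single (C, H, W) image and the default overlap drops to a quarter patch. A runnable sketch of the shapes, consistent with the 4-patch grid the updated tests expect for a 256×256 input:

```python
import numpy as np

from featureforest.utils.data import get_model_ready_image, patchify

image = np.random.randint(0, 255, (256, 256), dtype=np.uint8)  # single channel
img_tensor = get_model_ready_image(image)  # float32 RGB tensor, shape (3, 256, 256)

# defaults: patch_size derived from the image size, overlap = patch_size // 4
patches = patchify(img_tensor)
print(patches.shape)  # expected: torch.Size([4, 3, 256, 256]) -- a 2x2 overlapping grid
```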
diff --git a/src/featureforest/utils/dataset.py b/src/featureforest/utils/dataset.py
new file mode 100644
index 0000000..77ad0fe
--- /dev/null
+++ b/src/featureforest/utils/dataset.py
@@ -0,0 +1,102 @@
+from collections.abc import Iterable
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import pims
+import torch
+from tifffile import natural_sorted
+from torch.utils.data import IterableDataset
+
+from featureforest.utils.data import (
+    get_model_ready_image,
+    is_stacked,
+    patchify,
+)
+
+
+class FFImageDataset(IterableDataset):
+    """
+    Iterable dataset for large images or image stacks.
+    This dataset can handle large TIFF files or directories of images,
+    and it can yield patches of images or the whole image depending on the configuration.
+    """
+
+    def __init__(
+        self,
+        images: str | Path | np.ndarray,
+        no_patching: bool = False,
+        patch_size: int = 512,
+        overlap: int = 128,
+    ) -> None:
+        super().__init__()
+        self.no_patching = no_patching
+        self.patch_size = patch_size
+        self.overlap = overlap
+        self.image_files = []
+        self.image_source: Optional[pims.ImageSequence | np.ndarray] = None
+
+        if isinstance(images, np.ndarray):
+            # images are already loaded into a numpy array
+            self.image_source = images
+            # add slice dimension if not present
+            if not is_stacked(self.image_source):
+                self.image_source = self.image_source[np.newaxis, ...]
+
+        elif isinstance(images, str | Path):
+            images = Path(images)
+            if images.is_file():
+                # can be a large stack, using pims for lazy loading
+                self.image_source = pims.open(str(images))
+
+            elif images.is_dir():
+                self.image_files = (
+                    list(images.glob("*.tiff"))
+                    + list(images.glob("*.tif"))
+                    + list(images.glob("*.png"))
+                    + list(images.glob("*.jpg"))
+                )
+                if not self.image_files:
+                    raise ValueError(f"No image files found in the directory {images}.")
+                self.image_files = self._natural_sort(self.image_files)
+                self.image_source = pims.ImageSequence(map(str, self.image_files))
+        else:
+            raise ValueError(
+                f"images should be a numpy array or a directory or an image stack!"
+                f"\nGot {type(images)}"
+            )
+
+    @property
+    def num_images(self) -> int:
+        """Return the number of images in the dataset."""
+        if self.image_source is None:
+            return 0
+        return len(self.image_source)
+
+    @property
+    def image_shape(self) -> tuple[int, int]:
+        """Return the shape of the images in the dataset."""
+        if self.image_source is None:
+            raise ValueError("No image source is available. Please check the input data.")
+        if isinstance(self.image_source, np.ndarray):
+            return self.image_source.shape[1:]
+        return self.image_source.frame_shape
+
+    def __iter__(self):
+        if self.image_source is None:
+            raise ValueError("No image source is available. Please check the input data.")
+
+        for img_idx, img_slice in enumerate(self.image_source):
+            img_tensor = get_model_ready_image(img_slice)
+            if self.no_patching:
+                # return the whole image as a tensor
+                yield img_tensor, torch.tensor([img_idx, 0])
+            else:
+                # divide the image into patches and yield them
+                patches = patchify(img_tensor, self.patch_size, self.overlap)
+                for p_idx, patch in enumerate(patches):
+                    yield patch, torch.tensor([img_idx, p_idx])
+
+    def _natural_sort(self, files: Iterable[Path]) -> list[Path]:
+        """Sort files in a natural order."""
+        return sorted(files, key=lambda x: natural_sorted(str(x.name)))
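Because `FFImageDataset` is an `IterableDataset`, it streams slices lazily through `pims` and plugs directly into a `DataLoader`, which is exactly how the new extractor consumes it. A usage sketch; the stack path is a placeholder:

```python
from torch.utils.data import DataLoader

from featureforest.utils.dataset import FFImageDataset

# hypothetical multi-page TIFF stack
dataset = FFImageDataset("stack.tiff", no_patching=False, patch_size=512, overlap=128)
print(dataset.num_images, dataset.image_shape)

loader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=0)
for patches, indices in loader:
    # patches: (B, 3, patch_size, patch_size)
    # indices: (B, 2) rows of (slice_index, patch_index)
    print(patches.shape, indices[0])
    break
```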
Please check the input data.") + + for img_idx, img_slice in enumerate(self.image_source): + img_tensor = get_model_ready_image(img_slice) + if self.no_patching: + # return the whole image as a tensor + yield img_tensor, torch.tensor([img_idx, 0]) + else: + # divide the image into patches and yield them + patches = patchify(img_tensor, self.patch_size, self.overlap) + for p_idx, patch in enumerate(patches): + yield patch, torch.tensor([img_idx, p_idx]) + + def _natural_sort(self, files: Iterable[Path]) -> list[Path]: + """Sort files in a natural order.""" + return sorted(files, key=lambda x: natural_sorted(str(x.name))) diff --git a/src/featureforest/utils/extract.py b/src/featureforest/utils/extract.py index 3349807..7f58a01 100644 --- a/src/featureforest/utils/extract.py +++ b/src/featureforest/utils/extract.py @@ -1,123 +1,108 @@ from collections.abc import Generator -from typing import Optional, Union +from typing import Optional import h5py import numpy as np import torch -from napari.utils import progress as np_progress +from torch.utils.data import DataLoader from featureforest.models import BaseModelAdapter from featureforest.models.SAM import SAMAdapter -from featureforest.utils.data import ( - get_stack_dims, - get_stride_margin, - image_to_uint8, - is_image_rgb, - patchify, -) +from featureforest.utils.dataset import FFImageDataset -def get_slice_features( - image: np.ndarray, - model_adapter: BaseModelAdapter, - storage_group: Optional[h5py.Group] = None, -) -> Generator[Union[int, tuple[int, np.ndarray]], None, None]: - """Extract features for one image/slice using the given model adapter and - save them into a storage file or yield them batch by batch. - +def get_batch_size(model_adapter: BaseModelAdapter) -> int: + """Get the batch size for the model adapter. + The batch size is set to 8 for most models, but for SAMAdapter it is set to 2 + to avoid memory issues with large images. Args: - image (np.ndarray): Input image - model_adapter (BaseModelAdapter): Model adapter to extract features from - storage_group (Optional[h5py.Group]): h5 file group where the extracted features - will be saved. If None, will yield patch features batch by batch. + model_adapter (BaseModelAdapter): The model adapter to get the batch size for. Returns: - Generator[Union[int, tuple[int, np.ndarray]]]: A generator yielding either the - current batch number or a tuple containing the current batch number - and the corresponding patch features. + int: The batch size for the model adapter. """ - # image to torch tensor - img_data = torch.from_numpy(image.copy()).to(torch.float32) - # normalize in [0, 1] - _min = img_data.min() - _max = img_data.max() - img_data = (img_data - _min) / (_max - _min) - # for sam the input image should be 4D: BxCxHxW ; an RGB image. - if is_image_rgb(image): - # it's already RGB, put the channels first and add a batch dim. - img_data = img_data[..., :3] # ignore the Alpha channel (in case of PNG). 
diff --git a/src/featureforest/utils/pipeline_prediction.py b/src/featureforest/utils/pipeline_prediction.py
index 87792e9..1d75b85 100644
--- a/src/featureforest/utils/pipeline_prediction.py
+++ b/src/featureforest/utils/pipeline_prediction.py
@@ -1,32 +1,30 @@
-import multiprocessing as mp
+from collections.abc import Generator

 import numpy as np
 from sklearn.ensemble import RandomForestClassifier as RF

 from featureforest.models import BaseModelAdapter
 from featureforest.utils.data import get_num_patches, get_stride_margin
-from featureforest.utils.extract import get_slice_features
+from featureforest.utils.dataset import FFImageDataset
+from featureforest.utils.extract import extract_embeddings


 def predict_patches(
-    patch_features: np.ndarray,
+    feature_list: list[np.ndarray],
     rf_model: RF,
     model_adapter: BaseModelAdapter,
-    batch_idx: int,
-    result_dict: dict,
-) -> None:
-    """Predicts the class labels for a given set of patch features.
+) -> np.ndarray:
+    """Predicts the class labels for a given list of features (patches).

     Args:
-        patch_features (np.ndarray): Patch features to be predicted.
+        feature_list (list[np.ndarray]): List of features to be predicted.
         rf_model (RF): Random Forest Model used for predictions.
         model_adapter (BaseModelAdapter): Model adapter object used for extracting data.
-        batch_idx (int): Batch index of the current patch features.
-        result_dict (dict): Dictionary where the predicted masks will be stored.
     """
     patch_masks = []
-    # shape: N x target_size x target_size x C
-    num_patches = patch_features.shape[0]
+    # shape: N x stride x stride x C
+    patch_features = np.vstack(feature_list)
+    num_patches = len(patch_features)
     total_channels = model_adapter.get_total_output_channels()
     print(f"predicting {num_patches} patches...")
     for i in range(num_patches):
@@ -35,7 +33,7 @@ def predict_patches(
         patch_masks.append(pred)

     patch_masks = np.vstack(patch_masks)
-    result_dict[batch_idx] = patch_masks
+    return patch_masks


 def get_image_mask(
@@ -69,47 +67,45 @@ def get_image_mask(
     return mask_image


-def extract_predict(
-    image: np.ndarray,
+def run_prediction_pipeline(
+    input_stack: str,
     model_adapter: BaseModelAdapter,
     rf_model: RF,
-) -> np.ndarray:
-    """Extracts features and predicts the classes for a given image.
-
-    Args:
-        image (np.ndarray): Input image to extract features from.
-        model_adapter (BaseModelAdapter): Model adapter object used for extracting data.
-        rf_model (RF): Random Forest Model used for predictions.
-
-    Returns:
-        np.ndarray: Final image mask.
-    """
-    img_height, img_width = image.shape[:2]
+) -> Generator[tuple[np.ndarray, int, int], None, None]:
+    no_patching = model_adapter.no_patching
     patch_size = model_adapter.patch_size
     overlap = model_adapter.overlap
-    procs = []
-    # prediction happens per batch of extracted features
-    # in a separate process.
-    with mp.Manager() as manager:
-        result_dict = manager.dict()
-        for b_idx, patch_features in get_slice_features(image, model_adapter):
-            print(b_idx, end="\r")
-            proc = mp.Process(
-                target=predict_patches,
-                args=(patch_features, rf_model, model_adapter, b_idx, result_dict),
+    stack_dataset = FFImageDataset(
+        images=input_stack,
+        no_patching=no_patching,
+        patch_size=patch_size,
+        overlap=overlap,
+    )
+    img_height, img_width = stack_dataset.image_shape
+
+    prev_idx = 0
+    slice_features = []
+    for img_features, slice_idx, total in extract_embeddings(
+        model_adapter, image_dataset=stack_dataset
+    ):
+        print(f"{slice_idx} / {total}")
+        if prev_idx != slice_idx:
+            # we have one slice's features extracted: make a prediction.
+            patch_masks = predict_patches(slice_features, rf_model, model_adapter)
+            slice_mask = get_image_mask(
+                patch_masks, img_height, img_width, patch_size, overlap
             )
-            procs.append(proc)
-            proc.start()
-        # wait until all processes are done
-        for p in procs:
-            if p.is_alive:
-                p.join()
-        # collect results from each process
-        batch_indices = sorted(result_dict.keys())
-        patch_masks = [result_dict[b] for b in batch_indices]
-        patch_masks = np.vstack(patch_masks)
-        slice_mask = get_image_mask(
-            patch_masks, img_height, img_width, patch_size, overlap
-        )
-
-    return slice_mask
+
+            yield slice_mask, prev_idx, total
+
+            # start collecting next slice features
+            prev_idx = slice_idx
+            slice_features = []
+            slice_features.append(img_features)
+        else:
+            # collect slice features
+            slice_features.append(img_features)
+
+    # make prediction for the last slice
+    patch_masks = predict_patches(slice_features, rf_model, model_adapter)
+    slice_mask = get_image_mask(patch_masks, img_height, img_width, patch_size, overlap)
+    yield slice_mask, slice_idx, total
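`run_prediction_pipeline` replaces the multiprocessing-based `extract_predict` with a plain generator that yields one `(mask, slice_index, total)` tuple per slice; the widget loop earlier in this diff consumes it, but it can also be driven headlessly. A sketch, assuming a fitted `RandomForestClassifier` and a matching adapter; file names are placeholders:

```python
import tifffile

from featureforest.utils.pipeline_prediction import run_prediction_pipeline

# assumed to exist already:
#   rf_model: a fitted sklearn RandomForestClassifier
#   model_adapter: a BaseModelAdapter built for the stack's slice size
for slice_mask, idx, total in run_prediction_pipeline(
    "stack.tiff", model_adapter, rf_model
):
    tifffile.imwrite(f"slice_{idx:04}_prediction.tiff", slice_mask)
    print(f"predicted slice {idx + 1} / {total}")
```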
+ storage.attrs["num_slices"] = image_dataset.num_images + storage.attrs["img_height"] = image_dataset.image_shape[0] + storage.attrs["img_width"] = image_dataset.image_shape[1] + storage.attrs["model"] = model_adapter.name + storage.attrs["no_patching"] = no_patching + storage.attrs["patch_size"] = patch_size + storage.attrs["overlap"] = overlap - storage.attrs["num_slices"] = num_slices - storage.attrs["img_height"] = img_height - storage.attrs["img_width"] = img_width - storage.attrs["model"] = model_adapter.name - storage.attrs["patch_size"] = patch_size - storage.attrs["overlap"] = overlap + for img_features, idx, total in extract_embeddings( + model_adapter, image_dataset=image_dataset + ): + if storage.get(str(idx)) is None: + # create a group for the slice + grp = storage.create_group(str(idx)) + ds = grp.create_dataset( + name="features", + shape=img_features.shape, + maxshape=(None,) + img_features.shape[1:], + chunks=(1,) + img_features.shape[1:], + dtype=np.float16, + compression="lzf", + ) + ds[:] = img_features + else: + # resize the dataset and append features to the slice/image group + ds: h5py.Dataset = storage[str(idx)]["features"] # type: ignore + old_num_patches = ds.shape[0] + num_patches = ds.shape[0] + img_features.shape[0] + ds.resize(num_patches, axis=0) + ds[old_num_patches:, ...] = img_features - for slice_index in np_progress( - range(num_slices), desc="extract features for slices" - ): - print(f"\nslice index: {slice_index}") - slice_img = image[slice_index].copy() if num_slices > 1 else image.copy() - slice_img = image_to_uint8(slice_img) # image must be an uint8 array - slice_grp = storage.create_group(str(slice_index)) - for _ in get_slice_features(slice_img, model_adapter, slice_grp): - pass + yield idx, total - yield slice_index, num_slices + storage.close() diff --git a/src/featureforest/utils/pipeline_prediction.py b/src/featureforest/utils/pipeline_prediction.py index 87792e9..1d75b85 100644 --- a/src/featureforest/utils/pipeline_prediction.py +++ b/src/featureforest/utils/pipeline_prediction.py @@ -1,32 +1,30 @@ -import multiprocessing as mp +from collections.abc import Generator import numpy as np from sklearn.ensemble import RandomForestClassifier as RF from featureforest.models import BaseModelAdapter from featureforest.utils.data import get_num_patches, get_stride_margin -from featureforest.utils.extract import get_slice_features +from featureforest.utils.dataset import FFImageDataset +from featureforest.utils.extract import extract_embeddings def predict_patches( - patch_features: np.ndarray, + feature_list: list[np.ndarray], rf_model: RF, model_adapter: BaseModelAdapter, - batch_idx: int, - result_dict: dict, -) -> None: - """Predicts the class labels for a given set of patch features. +) -> np.ndarray: + """Predicts the class labels for a given list of features (patches). Args: - patch_features (np.ndarray): Patch features to be predicted. + feature_list (list[np.ndarray]): List of features to be predicted. rf_model (RF): Random Forest Model used for predictions. model_adapter (BaseModelAdapter): Model adapter object used for extracting data. - batch_idx (int): Batch index of the current patch features. - result_dict (dict): Dictionary where the predicted masks will be stored. 
""" patch_masks = [] - # shape: N x target_size x target_size x C - num_patches = patch_features.shape[0] + # shape: N x stride x stride x C + patch_features = np.vstack(feature_list) + num_patches = len(patch_features) total_channels = model_adapter.get_total_output_channels() print(f"predicting {num_patches} patches...") for i in range(num_patches): @@ -35,7 +33,7 @@ def predict_patches( patch_masks.append(pred) patch_masks = np.vstack(patch_masks) - result_dict[batch_idx] = patch_masks + return patch_masks def get_image_mask( @@ -69,47 +67,45 @@ def get_image_mask( return mask_image -def extract_predict( - image: np.ndarray, +def run_prediction_pipeline( + input_stack: str, model_adapter: BaseModelAdapter, rf_model: RF, -) -> np.ndarray: - """Extracts features and predicts the classes for a given image. - - Args: - image (np.ndarray): Input image to extract features from. - model_adapter (BaseModelAdapter): Model adapter object used for extracting data. - rf_model (RF): Random Forest Model used for predictions. - - Returns: - np.ndarray: Final image mask. - """ - img_height, img_width = image.shape[:2] +) -> Generator[tuple[np.ndarray, int, int], None, None]: + no_patching = model_adapter.no_patching patch_size = model_adapter.patch_size overlap = model_adapter.overlap - procs = [] - # prediction happens per batch of extracted features - # in a separate process. - with mp.Manager() as manager: - result_dict = manager.dict() - for b_idx, patch_features in get_slice_features(image, model_adapter): - print(b_idx, end="\r") - proc = mp.Process( - target=predict_patches, - args=(patch_features, rf_model, model_adapter, b_idx, result_dict), + stack_dataset = FFImageDataset( + images=input_stack, + no_patching=no_patching, + patch_size=patch_size, + overlap=overlap, + ) + img_height, img_width = stack_dataset.image_shape + + prev_idx = 0 + slice_features = [] + for img_features, slice_idx, total in extract_embeddings( + model_adapter, image_dataset=stack_dataset + ): + print(f"{slice_idx} / {total}") + if prev_idx != slice_idx: + # we have one slice features extracted: make a prediction. 
diff --git a/tests/model_adapter_tests/test_dino_adapter.py b/tests/model_adapter_tests/test_dino_adapter.py
index 7b8fbbc..99a6401 100644
--- a/tests/model_adapter_tests/test_dino_adapter.py
+++ b/tests/model_adapter_tests/test_dino_adapter.py
@@ -2,10 +2,10 @@
 import pytest
 import torch
 import torch.nn as nn
+from embedding_extraction import check_embedding_extraction

 from featureforest.models.DinoV2 import DinoV2Adapter, get_model
 from featureforest.utils.data import get_stack_dims
-from embedding_extraction import check_embedding_extraction


 class MockDinoEncoder(nn.Module):
@@ -22,7 +22,7 @@ def get_intermediate_layers(self, x, *args, **kwargs):
         return output, None


-def get_mock_model(img_height: float, img_width: float) -> DinoV2Adapter:
+def get_mock_model(img_height: int, img_width: int) -> DinoV2Adapter:
     model = MockDinoEncoder()
     device = torch.device("cpu")
     dino_model_adapter = DinoV2Adapter(model, img_height, img_width, device)
@@ -49,10 +49,10 @@ def test_mock_adapter(test_patch: np.ndarray):

     result_real = real_adapter.model.get_intermediate_layers(
         transformed_input_patch_real, 1, return_class_token=False, reshape=True
-    )[0]
+    )[0]  # type: ignore
     mock_result = mock_adapter.model.get_intermediate_layers(
         transformed_input_patch_mock
-    )[0]
+    )[0]  # type: ignore

     assert len(result_real) == len(mock_result)
     assert result_real[0].shape == mock_result[0].shape
diff --git a/tests/model_adapter_tests/test_mobilesam_adapter.py b/tests/model_adapter_tests/test_mobilesam_adapter.py
index 7c00491..284e1b9 100644
--- a/tests/model_adapter_tests/test_mobilesam_adapter.py
+++ b/tests/model_adapter_tests/test_mobilesam_adapter.py
@@ -2,10 +2,10 @@
 import pytest
 import torch
 import torch.nn as nn
+from embedding_extraction import check_embedding_extraction

 from featureforest.models.MobileSAM import MobileSAMAdapter, get_model
 from featureforest.utils.data import get_stack_dims
-from embedding_extraction import check_embedding_extraction


 class MockMobileSAMEncoder(nn.Module):
@@ -32,7 +32,7 @@ def mock_encode(self, x):
         return output, embed_output, None


-def get_mock_model(img_height: float, img_width: float) -> MobileSAMAdapter:
+def get_mock_model(img_height: int, img_width: int) -> MobileSAMAdapter:
     model = MockMobileSAMEncoder()
     device = torch.device("cpu")
     sam_model_adapter = MobileSAMAdapter(model, img_height, img_width, device)
@@ -47,7 +47,7 @@ def get_mock_model(img_height: float, img_width: float) -> MobileSAMAdapter:
         torch.ones((3, 3, 128, 128)),
         torch.ones((8, 3, 128, 128)),
         torch.ones((8, 3, 256, 256)),
-        torch.ones((8, 3, 512, 512))
+        torch.ones((8, 3, 512, 512)),
     ],
 )
 def test_mock_adapter(test_patch: np.ndarray):
@@ -68,10 +68,10 @@ def test_mock_adapter(test_patch: np.ndarray):
 @pytest.mark.parametrize(
     "test_image, expected_output_shape, expected_slices",
     [
-        (np.ones((256, 256)), (9, 128, 128, 320), 1),  # 2D
-        (np.ones((256, 256, 3)), (9, 128, 128, 320), 1),  # 2D RGB
-        (np.ones((2, 256, 256)), (9, 128, 128, 320), 2),  # 3D
-        (np.ones((2, 256, 256, 3)), (9, 128, 128, 320), 2)  # 3D RGB
+        (np.ones((256, 256)), (4, 192, 192, 320), 1),  # 2D
+        (np.ones((256, 256, 3)), (4, 192, 192, 320), 1),  # 2D RGB
+        (np.ones((2, 256, 256)), (4, 192, 192, 320), 2),  # 3D
+        (np.ones((2, 256, 256, 3)), (4, 192, 192, 320), 2),  # 3D RGB
     ],
 )
 def test_mobilesam_embedding_extraction(
diff --git a/tests/model_adapter_tests/test_sam2_adapter.py b/tests/model_adapter_tests/test_sam2_adapter.py
new file mode 100644
index 0000000..3125eb5
--- /dev/null
+++ b/tests/model_adapter_tests/test_sam2_adapter.py
@@ -0,0 +1,156 @@
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+from embedding_extraction import check_embedding_extraction
+
+from featureforest.models.SAM2.adapter import SAM2Adapter
+from featureforest.utils.data import get_stack_dims
+
+
+class MockSAM2Encoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.encoder_num_channels = 256 * 3
+
+    def __call__(self, x):
+        batch_size = x.shape[0]
+        # Mock the backbone_fpn output with 3 feature levels
+        # [b, 256, 256, 256]
+        # [b, 256, 128, 128]
+        # [b, 256, 64, 64]
+        level1 = torch.ones(batch_size, 256, 256, 256)
+        level2 = torch.ones(batch_size, 256, 128, 128)
+        level3 = torch.ones(batch_size, 256, 64, 64)
+
+        return {"backbone_fpn": [level1, level2, level3]}
+
+
+def get_mock_model(img_height: int, img_width: int) -> SAM2Adapter:
+    model = MockSAM2Encoder()
+    device = torch.device("cpu")
+    sam2_model_adapter = SAM2Adapter(model, img_height, img_width, device)
+    return sam2_model_adapter
+
+
+@pytest.mark.parametrize(
+    "img_height, img_width, expected_patch_size",
+    [
+        (
+            512,
+            512,
+            256,
+        ),  # SAM2 adapter seems to use 256 as default patch size for smaller images
+        (256, 256, 256),
+        (1024, 1024, 512),  # For larger images, it uses 512
+    ],
+)
+def test_initialize_sam2_adapter(img_height, img_width, expected_patch_size):
+    """Test that SAM2Adapter initializes correctly with different image sizes."""
+    model = MockSAM2Encoder()
+    device = torch.device("cpu")
+
+    adapter = SAM2Adapter(model, img_height, img_width, device)
+
+    assert adapter.name == "SAM2_Large"
+    assert adapter.img_height == img_height
+    assert adapter.img_width == img_width
+    assert adapter.device == device
+    assert adapter.encoder == model
+    assert adapter.encoder_num_channels == 256 * 3
+    assert adapter.patch_size == expected_patch_size
+    assert adapter.overlap == expected_patch_size // 4
+    assert adapter.sam_input_dim == 1024
+
+
+@pytest.mark.parametrize(
+    "test_patch",
+    [
+        torch.ones((1, 3, 128, 128)),
+        torch.ones((3, 3, 128, 128)),
+        torch.ones((1, 3, 256, 256)),
+        torch.ones((1, 3, 512, 512)),
+    ],
+)
+def test_process_input_patches(test_patch: torch.Tensor):
+    """Test that SAM2Adapter processes input patches correctly."""
+    img_height, img_width = test_patch.shape[-2:]
+    adapter = get_mock_model(img_height, img_width)
+
+    # Process the input patches
+    output_features = adapter.get_features_patches(test_patch)
+
+    # Check output shape
+    batch_size = test_patch.shape[0]
+    assert output_features.shape[0] == batch_size
+    # The output should be in format [batch, height, width, channels]
+    assert output_features.shape[3] == adapter.encoder_num_channels
+    # Check that the output is a tensor
+    assert isinstance(output_features, torch.Tensor)
+
+
+def test_get_total_output_channels():
+    """Test that SAM2Adapter returns the correct number of output channels."""
+    adapter = get_mock_model(512, 512)
+
+    # Check that the output channels is 256 * 3 = 768
+    assert adapter.get_total_output_channels() == 256 * 3
+
+
+@pytest.mark.parametrize(
+    "test_image, expected_output_shape, expected_slices",
+    [
+        (np.ones((256, 256)), (4, 192, 192, 768), 1),  # 2D
+        (np.ones((256, 256, 3)), (4, 192, 192, 768), 1),  # 2D RGB
+    ],
+)
+def test_sam2_embedding_extraction_2d(
+    test_image: np.ndarray, expected_output_shape: tuple, expected_slices: int
+):
+    """Test that SAM2Adapter extracts embeddings from 2D images correctly."""
+    num_slices, img_height, img_width = get_stack_dims(test_image)
+    model_adapter = get_mock_model(img_height, img_width)
+    check_embedding_extraction(
+        test_image, model_adapter, expected_output_shape, expected_slices
+    )
+
+
+@pytest.mark.parametrize(
+    "test_image, expected_output_shape, expected_slices",
+    [
+        (np.ones((2, 256, 256)), (4, 192, 192, 768), 2),  # 3D
+        (np.ones((2, 256, 256, 3)), (4, 192, 192, 768), 2),  # 3D RGB
+    ],
+)
+def test_sam2_embedding_extraction_3d(
+    test_image: np.ndarray, expected_output_shape: tuple, expected_slices: int
+):
+    """Test that SAM2Adapter extracts embeddings from 3D stacks correctly."""
+    num_slices, img_height, img_width = get_stack_dims(test_image)
+    model_adapter = get_mock_model(img_height, img_width)
+    check_embedding_extraction(
+        test_image, model_adapter, expected_output_shape, expected_slices
+    )
+
+
+def test_no_patching_mode():
+    """Test that SAM2Adapter handles no_patching mode correctly."""
+    img_height, img_width = 256, 256
+    adapter = get_mock_model(img_height, img_width)
+
+    # Default mode
+    assert adapter.no_patching is False
+    assert adapter.patch_size == 256
+    assert adapter.overlap == 64
+
+    # Enable no_patching mode
+    adapter.no_patching = True
+    assert adapter.no_patching is True
+    assert adapter.patch_size == img_height
+    assert adapter.overlap == 0
+
+    # Disable no_patching mode
+    adapter.no_patching = False
+    assert adapter.no_patching is False
+    assert adapter.patch_size == 256
+    assert adapter.overlap == 64
+@pytest.mark.parametrize(
+    "test_image, expected_output_shape, expected_slices",
+    [
+        (np.ones((256, 256)), (4, 192, 192, 768), 1),  # 2D
+        (np.ones((256, 256, 3)), (4, 192, 192, 768), 1),  # 2D RGB
+    ],
+)
+def test_sam2_embedding_extraction_2d(
+    test_image: np.ndarray, expected_output_shape: tuple, expected_slices: int
+):
+    """Test that SAM2Adapter extracts embeddings from 2D images correctly."""
+    num_slices, img_height, img_width = get_stack_dims(test_image)
+    model_adapter = get_mock_model(img_height, img_width)
+    check_embedding_extraction(
+        test_image, model_adapter, expected_output_shape, expected_slices
+    )
+
+
+@pytest.mark.parametrize(
+    "test_image, expected_output_shape, expected_slices",
+    [
+        (np.ones((2, 256, 256)), (4, 192, 192, 768), 2),  # 3D
+        (np.ones((2, 256, 256, 3)), (4, 192, 192, 768), 2),  # 3D RGB
+    ],
+)
+def test_sam2_embedding_extraction_3d(
+    test_image: np.ndarray, expected_output_shape: tuple, expected_slices: int
+):
+    """Test that SAM2Adapter extracts embeddings from 3D stacks correctly."""
+    num_slices, img_height, img_width = get_stack_dims(test_image)
+    model_adapter = get_mock_model(img_height, img_width)
+    check_embedding_extraction(
+        test_image, model_adapter, expected_output_shape, expected_slices
+    )
+
+
+def test_no_patching_mode():
+    """Test that SAM2Adapter handles no_patching mode correctly."""
+    img_height, img_width = 256, 256
+    adapter = get_mock_model(img_height, img_width)
+
+    # Default mode
+    assert adapter.no_patching is False
+    assert adapter.patch_size == 256
+    assert adapter.overlap == 64
+
+    # Enable no_patching mode
+    adapter.no_patching = True
+    assert adapter.no_patching is True
+    assert adapter.patch_size == img_height
+    assert adapter.overlap == 0
+
+    # Disable no_patching mode
+    adapter.no_patching = False
+    assert adapter.no_patching is False
+    assert adapter.patch_size == 256
+    assert adapter.overlap == 64
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
new file mode 100644
index 0000000..345d3de
--- /dev/null
+++ b/tests/test_dataset.py
@@ -0,0 +1,114 @@
+import numpy as np
+import pytest
+import torch
+
+from featureforest.utils.dataset import FFImageDataset
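+
+
+# Minimal stand-in for pims.ImageSequence: test_init_with_image_files swaps it
+# in via monkeypatch, so only the interface FFImageDataset appears to rely on
+# (frame_shape, __len__, __getitem__, __iter__) is implemented here.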
+class DummyImageSequence:
+    """Dummy class to mock pims.ImageSequence for testing."""
+
+    def __init__(self, images):
+        self._images = images
+        self.frame_shape = images[0].shape
+
+    def __len__(self):
+        return len(self._images)
+
+    def __getitem__(self, idx):
+        return self._images[idx]
+
+    def __iter__(self):
+        return iter(self._images)
+
+
+@pytest.fixture
+def dummy_numpy_stack():
+    # shape: (3, 32, 32)
+    return np.random.randint(0, 255, (3, 32, 32), dtype=np.uint8)
+
+
+@pytest.fixture
+def dummy_numpy_single():
+    # shape: (32, 32)
+    return np.random.randint(0, 255, (32, 32), dtype=np.uint8)
+
+
+def test_init_with_numpy_stack(dummy_numpy_stack, monkeypatch):
+    ds = FFImageDataset(dummy_numpy_stack, no_patching=True)
+    assert ds.num_images == 3
+    assert ds.image_shape == (32, 32)
+    items = list(ds)
+    assert len(items) == 3
+    for img, idx in items:
+        assert isinstance(img, torch.Tensor)
+        assert img.shape[-2:] == (32, 32)
+        assert idx[0] in [0, 1, 2]
+
+
+def test_init_with_numpy_single(dummy_numpy_single):
+    ds = FFImageDataset(dummy_numpy_single, no_patching=True)
+    assert ds.num_images == 1
+    assert ds.image_shape == (32, 32)
+    items = list(ds)
+    assert len(items) == 1
+    img, idx = items[0]
+    assert isinstance(img, torch.Tensor)
+    assert img.shape[-2:] == (32, 32)
+    assert idx[0] == 0
+
+
+def test_iter_with_patching(dummy_numpy_stack, monkeypatch):
+    # Patch patchify to split into 2 patches
+    monkeypatch.setattr(
+        "featureforest.utils.dataset.patchify",
+        lambda img, sz, ov: [img[..., :16, :16], img[..., 16:, 16:]],
+    )
+    ds = FFImageDataset(dummy_numpy_stack, no_patching=False)
+    items = list(ds)
+    assert len(items) == 6  # 3 images * 2 patches each
+    for patch, idx in items:
+        assert isinstance(patch, torch.Tensor)
+        assert idx.shape == (2,)
+
+
+def test_init_with_invalid_type():
+    with pytest.raises(ValueError):
+        FFImageDataset(12345)
+
+
+def test_init_with_empty_dir(tmp_path):
+    empty_dir = tmp_path / "empty"
+    empty_dir.mkdir()
+    with pytest.raises(ValueError):
+        FFImageDataset(empty_dir)
+
+
+def test_init_with_image_files(tmp_path, monkeypatch):
+    # Create dummy image files
+    for ext in ["tiff", "tif", "png", "jpg"]:
+        (tmp_path / f"img1.{ext}").write_bytes(np.random.bytes(100))
+    # Patch pims.ImageSequence to DummyImageSequence
+    monkeypatch.setattr(
+        "featureforest.utils.dataset.pims.ImageSequence",
+        lambda files: DummyImageSequence([np.zeros((32, 32), dtype=np.uint8)] * 4),
+    )
+    ds = FFImageDataset(tmp_path, no_patching=True)
+    assert ds.num_images == 4
+    assert ds.image_shape == (32, 32)
+    items = list(ds)
+    assert len(items) == 4
+
+
+def test_image_shape_none():
+    ds = FFImageDataset(np.zeros((1, 2, 2)), no_patching=True)
+    ds.image_source = None
+    with pytest.raises(ValueError):
+        _ = ds.image_shape
+
+
+def test_iter_no_image_source():
+    ds = FFImageDataset(np.zeros((1, 2, 2)), no_patching=True)
+    ds.image_source = None
+    with pytest.raises(ValueError):
+        next(iter(ds))
diff --git a/tests/test_pipeline_prediction.py b/tests/test_pipeline_prediction.py
new file mode 100644
index 0000000..c148dfd
--- /dev/null
+++ b/tests/test_pipeline_prediction.py
@@ -0,0 +1,323 @@
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import pytest
+
+from featureforest.utils.pipeline_prediction import run_prediction_pipeline
+
+
+@pytest.fixture
+def mock_model_adapter():
+    adapter = MagicMock()
+    adapter.no_patching = False
+    adapter.patch_size = 128
+    adapter.overlap = 32
+    adapter.get_total_output_channels.return_value = 320
+    return adapter
+
+
+@pytest.fixture
+def mock_rf_model():
+    rf = MagicMock()
+    rf.predict.side_effect = lambda x: np.ones((x.shape[0],), dtype=np.uint8)
+    return rf
+
+
+@pytest.fixture
+def mock_stack_dataset():
+    dataset = MagicMock()
+    dataset.image_shape = (240, 240)
+    return dataset
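+
+
+# extract_embeddings is mocked below as a generator yielding
+# (patch_features, slice_index, total_num_slices) triples; the pipeline is
+# expected to merge consecutive patches of the same slice into a single mask.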
+@pytest.fixture
+def mock_extract_embeddings():
+    # Simulate two slices, each with two patches
+    def generator(*args, **kwargs):
+        yield np.zeros((1,)), 0, 2
+        yield np.ones((1,)), 0, 2
+        yield np.full((1,), 2), 1, 2
+        yield np.full((1,), 3), 1, 2
+
+    return generator
+
+
+@pytest.fixture
+def mock_predict_patches():
+    return lambda features, rf, adapter: np.array([[1]] * len(features), dtype=np.uint8)
+
+
+@pytest.fixture
+def mock_get_image_mask():
+    return lambda patch_masks, h, w, ps, ov: np.ones((h, w), dtype=np.uint8)
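+
+
+# @patch decorators are applied bottom-up, so the first test parameter
+# corresponds to the last decorator (get_image_mask), then predict_patches,
+# extract_embeddings, and finally FFImageDataset.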
+@patch("featureforest.utils.pipeline_prediction.FFImageDataset")
+@patch("featureforest.utils.pipeline_prediction.extract_embeddings")
+@patch("featureforest.utils.pipeline_prediction.predict_patches")
+@patch("featureforest.utils.pipeline_prediction.get_image_mask")
+def test_run_prediction_pipeline_yields_correct_masks(
+    mock_get_image_mask_func,
+    mock_predict_patches_func,
+    mock_extract_embeddings_func,
+    mock_ffimagedataset,
+    mock_model_adapter,
+    mock_rf_model,
+):
+    # Setup mocks
+    mock_ffimagedataset.return_value.image_shape = (240, 240)
+
+    # Setup extract_embeddings mock to yield the expected tuples
+    def mock_extract_embeddings_generator(*args, **kwargs):
+        yield np.zeros((1,)), 0, 2
+        yield np.ones((1,)), 0, 2
+        yield np.full((1,), 2), 1, 2
+        yield np.full((1,), 3), 1, 2
+
+    mock_extract_embeddings_func.side_effect = mock_extract_embeddings_generator
+
+    # Setup other mocks
+    mock_predict_patches_func.side_effect = lambda features, rf, adapter: np.array(
+        [[1]] * len(features), dtype=np.uint8
+    )
+    mock_get_image_mask_func.side_effect = lambda patch_masks, h, w, ps, ov: np.ones(
+        (h, w), dtype=np.uint8
+    )
+
+    input_stack = "dummy_path"
+    results = list(
+        run_prediction_pipeline(input_stack, mock_model_adapter, mock_rf_model)
+    )
+
+    # There should be two yields (one per slice)
+    assert len(results) == 2
+    for mask, idx, total in results:
+        assert mask.shape == (240, 240)
+        assert total == 2
+        assert idx in (0, 1)
+        assert np.all(mask == 1)
+
+
+@patch("featureforest.utils.pipeline_prediction.FFImageDataset")
+@patch("featureforest.utils.pipeline_prediction.extract_embeddings")
+@patch("featureforest.utils.pipeline_prediction.predict_patches")
+@patch("featureforest.utils.pipeline_prediction.get_image_mask")
+def test_run_prediction_pipeline_handles_single_slice(
+    mock_get_image_mask_func,
+    mock_predict_patches_func,
+    mock_extract_embeddings_func,
+    mock_ffimagedataset,
+    mock_model_adapter,
+    mock_rf_model,
+):
+    # Only one slice
+    def single_slice_gen(*args, **kwargs):
+        yield np.zeros((1,)), 0, 1
+        yield np.ones((1,)), 0, 1
+
+    mock_ffimagedataset.return_value.image_shape = (240, 240)
+    mock_extract_embeddings_func.side_effect = single_slice_gen
+    mock_predict_patches_func.side_effect = lambda features, rf, adapter: np.array(
+        [[1]] * len(features), dtype=np.uint8
+    )
+    mock_get_image_mask_func.side_effect = lambda patch_masks, h, w, ps, ov: np.ones(
+        (h, w), dtype=np.uint8
+    )
+
+    input_stack = "dummy_path"
+    results = list(
+        run_prediction_pipeline(input_stack, mock_model_adapter, mock_rf_model)
+    )
+
+    assert len(results) == 1
+    mask, idx, total = results[0]
+    assert mask.shape == (240, 240)
+    assert idx == 0
+    assert total == 1
+    assert np.all(mask == 1)
+
+
+@patch("featureforest.utils.pipeline_prediction.FFImageDataset")
+@patch("featureforest.utils.pipeline_prediction.extract_embeddings")
+@patch("featureforest.utils.pipeline_prediction.predict_patches")
+@patch("featureforest.utils.pipeline_prediction.get_image_mask")
+def test_run_prediction_pipeline_no_patching_mode(
+    mock_get_image_mask_func,
+    mock_predict_patches_func,
+    mock_extract_embeddings_func,
+    mock_ffimagedataset,
+    mock_model_adapter,
+    mock_rf_model,
+):
+    """Test that the pipeline works correctly when no_patching is True."""
+    # Setup model adapter with no_patching=True
+    mock_model_adapter.no_patching = True
+    mock_model_adapter.patch_size = 0  # Should be ignored in no_patching mode
+    mock_model_adapter.overlap = 0  # Should be ignored in no_patching mode
+    mock_model_adapter.get_total_output_channels.return_value = 320
+
+    # Setup dataset
+    mock_ffimagedataset.return_value.image_shape = (240, 240)
+
+    # Setup embeddings generator
+    def no_patching_gen(*args, **kwargs):
+        # In no_patching mode, we'd have one feature per slice
+        yield np.zeros((1, 320)), 0, 2
+        yield np.ones((1, 320)), 1, 2
+
+    mock_extract_embeddings_func.side_effect = no_patching_gen
+
+    # Setup prediction mocks
+    mock_predict_patches_func.side_effect = lambda features, rf, adapter: np.array(
+        [[1]] * len(features), dtype=np.uint8
+    )
+    mock_get_image_mask_func.side_effect = lambda patch_masks, h, w, ps, ov: np.ones(
+        (h, w), dtype=np.uint8
+    )
+
+    # Run the pipeline
+    input_stack = "dummy_path"
+    results = list(
+        run_prediction_pipeline(input_stack, mock_model_adapter, mock_rf_model)
+    )
+
+    # Verify FFImageDataset was created with no_patching=True
+    mock_ffimagedataset.assert_called_once_with(
+        images="dummy_path", no_patching=True, patch_size=0, overlap=0
+    )
+
+    # Verify results
+    assert len(results) == 2
+    for i, (mask, idx, total) in enumerate(results):
+        assert mask.shape == (240, 240)
+        assert idx == i
+        assert total == 2
+        assert np.all(mask == 1)
+@patch("featureforest.utils.pipeline_prediction.FFImageDataset")
+@patch("featureforest.utils.pipeline_prediction.extract_embeddings")
+@patch("featureforest.utils.pipeline_prediction.predict_patches")
+@patch("featureforest.utils.pipeline_prediction.get_image_mask")
+def test_run_prediction_pipeline_large_image_dimensions(
+    mock_get_image_mask_func,
+    mock_predict_patches_func,
+    mock_extract_embeddings_func,
+    mock_ffimagedataset,
+    mock_model_adapter,
+    mock_rf_model,
+):
+    """Test that the pipeline handles large image dimensions correctly."""
+    # Setup large image dimensions
+    large_height, large_width = 4096, 4096
+    mock_ffimagedataset.return_value.image_shape = (large_height, large_width)
+
+    # Setup model adapter
+    mock_model_adapter.no_patching = False
+    mock_model_adapter.patch_size = 256
+    mock_model_adapter.overlap = 64
+    mock_model_adapter.get_total_output_channels.return_value = 320
+
+    # Setup embeddings generator
+    def large_image_gen(*args, **kwargs):
+        # Simulate many patches for a large image (simplified for test)
+        yield np.zeros((1, 16, 16, 320)), 0, 1
+        yield np.ones((1, 16, 16, 320)), 0, 1
+        yield np.full((1, 16, 16, 320), 2), 0, 1
+
+    mock_extract_embeddings_func.side_effect = large_image_gen
+
+    # Setup prediction mocks
+    mock_predict_patches_func.side_effect = lambda features, rf, adapter: np.array(
+        [[1]] * len(features), dtype=np.uint8
+    )
+
+    # The get_image_mask should return a mask with the large dimensions
+    mock_get_image_mask_func.side_effect = lambda patch_masks, h, w, ps, ov: np.ones(
+        (h, w), dtype=np.uint8
+    )
+
+    # Run the pipeline
+    input_stack = "dummy_path"
+    results = list(
+        run_prediction_pipeline(input_stack, mock_model_adapter, mock_rf_model)
+    )
+
+    # Verify results
+    assert len(results) == 1
+    mask, idx, total = results[0]
+    assert mask.shape == (large_height, large_width)
+    assert idx == 0
+    assert total == 1
+    assert np.all(mask == 1)
+
+    # Verify get_image_mask was called with correct dimensions
+    # Instead of checking the exact mock value, just verify the function was called
+    # and check that the dimensions and other parameters were correct
+    assert mock_get_image_mask_func.called
+    args, kwargs = mock_get_image_mask_func.call_args
+    # Check that the height, width, patch_size and overlap parameters are correct
+    assert args[1] == large_height
+    assert args[2] == large_width
+    assert args[3] == mock_model_adapter.patch_size
+    assert args[4] == mock_model_adapter.overlap
+
+
+@patch("featureforest.utils.pipeline_prediction.FFImageDataset")
+@patch("featureforest.utils.pipeline_prediction.extract_embeddings")
+@patch("featureforest.utils.pipeline_prediction.predict_patches")
+@patch("featureforest.utils.pipeline_prediction.get_image_mask")
+def test_run_prediction_pipeline_slice_index_continuity(
+    mock_get_image_mask_func,
+    mock_predict_patches_func,
+    mock_extract_embeddings_func,
+    mock_ffimagedataset,
+    mock_model_adapter,
+    mock_rf_model,
+):
+    """Test that the pipeline handles non-continuous slice indices correctly."""
+    # Setup
+    mock_ffimagedataset.return_value.image_shape = (240, 240)
+
+    # Setup model adapter
+    mock_model_adapter.no_patching = False
+    mock_model_adapter.patch_size = 128
+    mock_model_adapter.overlap = 32
+    mock_model_adapter.get_total_output_channels.return_value = 320
+
+    # Setup embeddings generator with non-continuous slice indices (0, 2, 5)
+    def non_continuous_slices_gen(*args, **kwargs):
+        # Slice 0
+        yield np.zeros((1,)), 0, 3
+        yield np.ones((1,)), 0, 3
+        # Slice 2 (skipping 1)
+        yield np.full((1,), 2), 2, 3
+        # Slice 5 (skipping 3 and 4)
+        yield np.full((1,), 3), 5, 3
+
+    mock_extract_embeddings_func.side_effect = non_continuous_slices_gen
+
+    # Setup prediction mocks
+    mock_predict_patches_func.side_effect = lambda features, rf, adapter: np.array(
+        [[1]] * len(features), dtype=np.uint8
+    )
+    mock_get_image_mask_func.side_effect = lambda patch_masks, h, w, ps, ov: np.ones(
+        (h, w), dtype=np.uint8
+    )
+
+    # Run the pipeline
+    input_stack = "dummy_path"
+    results = list(
+        run_prediction_pipeline(input_stack, mock_model_adapter, mock_rf_model)
+    )
+
+    # Verify results
+    assert len(results) == 3
+
+    # Check that the slice indices match what we expect
+    expected_indices = [0, 2, 5]
+    for i, (mask, idx, total) in enumerate(results):
+        assert mask.shape == (240, 240)
+        assert idx == expected_indices[i]
+        assert total == 3
+        assert np.all(mask == 1)