diff --git a/.github/workflows/setup-data.yml b/.github/workflows/setup-data.yml index 358aa0c..4ea3aea 100644 --- a/.github/workflows/setup-data.yml +++ b/.github/workflows/setup-data.yml @@ -40,7 +40,7 @@ jobs: - name: Download test data if: steps.cache-data.outputs.cache-hit != 'true' run: | - python3 -m mritk download-test-data data + python3 -m mritk datasets download test-data -o data - name: Upload Data Artifact diff --git a/pyproject.toml b/pyproject.toml index 52ae7bd..3daa000 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,7 +82,7 @@ exclude = [ ] -# Same as Black. -line-length = 100 +# Maximum line length (wider than Black's default of 88). +line-length = 130 -# Assume Python 3.10. +# Assume Python 3.12. target-version = "py312" diff --git a/src/mritk/cli.py b/src/mritk/cli.py index 1e126f4..280daa2 100644 --- a/src/mritk/cli.py +++ b/src/mritk/cli.py @@ -6,7 +6,7 @@ from rich_argparse import RichHelpFormatter -from . import download_data, info, statistics, show, napari +from . import datasets, info, statistics, show, napari def version_info(): @@ -48,33 +48,21 @@ def setup_parser(): subparsers = parser.add_subparsers(dest="command") - # Download test data parser - download_parser = subparsers.add_parser( - "download-test-data", help="Download test data", formatter_class=parser.formatter_class - ) - download_parser.add_argument("outdir", type=Path, help="Output directory to download test data") + # Datasets parser + datasets_parser = subparsers.add_parser("datasets", help="Download datasets", formatter_class=parser.formatter_class) + datasets.add_arguments(datasets_parser) - info_parser = subparsers.add_parser( - "info", help="Display information about a file", formatter_class=parser.formatter_class - ) + info_parser = subparsers.add_parser("info", help="Display information about a file", formatter_class=parser.formatter_class) info_parser.add_argument("file", type=Path, help="File to display information about") - info_parser.add_argument( - "--json", action="store_true", help="Output information in JSON format" - ) + info_parser.add_argument("--json", action="store_true", help="Output information in JSON format") - stats_parser = subparsers.add_parser( - "stats", help="Compute MRI statistics", formatter_class=parser.formatter_class - ) + stats_parser = subparsers.add_parser("stats", help="Compute MRI statistics", formatter_class=parser.formatter_class) statistics.cli.add_arguments(stats_parser) - show_parser = subparsers.add_parser( - "show", help="Show MRI data in a terminal", formatter_class=parser.formatter_class - ) + show_parser = subparsers.add_parser("show", help="Show MRI data in a terminal", formatter_class=parser.formatter_class) show.add_arguments(show_parser) - napari_parser = subparsers.add_parser( - "napari", help="Show MRI data using napari", formatter_class=parser.formatter_class - ) + napari_parser = subparsers.add_parser("napari", help="Show MRI data using napari", formatter_class=parser.formatter_class) napari.add_arguments(napari_parser) return parser @@ -90,9 +78,8 @@ def dispatch(parser: argparse.ArgumentParser, argv: Optional[Sequence[str]] = No command = args.pop("command") logger = logging.getLogger(__name__) try: - if command == "download-test-data": - outdir = args.pop("outdir") - download_data.download_test_data(outdir) + if command == "datasets": + datasets.dispatch(args) elif command == "info": file = args.pop("file") info.nifty_info(file, json_output=args.pop("json")) diff --git a/src/mritk/data/io.py b/src/mritk/data/io.py index 95d6c6f..449f92c 100644 --- a/src/mritk/data/io.py +++ b/src/mritk/data/io.py @@ -43,9 +43,7 @@ def load_mri_data(
return mri -def save_mri_data( - mri: MRIData, path: Path, dtype: npt.DTypeLike, intent_code: Optional[int] = None ): +def save_mri_data(mri: MRIData, path: Path, dtype: npt.DTypeLike, intent_code: Optional[int] = None): # TODO : Choose other way to check extension than regex ? suffix_regex = re.compile(r".+(?P<suffix>(\.nii(\.gz|)|\.mg(z|h)))") m = suffix_regex.match(Path(path).name) diff --git a/src/mritk/data/orientation.py b/src/mritk/data/orientation.py index 1e3b12d..1d212a1 100644 --- a/src/mritk/data/orientation.py +++ b/src/mritk/data/orientation.py @@ -79,9 +79,7 @@ def change_of_coordinates_map(orientation_in: str, orientation_out: str) -> np.n if idx2 == len(orientation_out): print(char1, char2) - raise ValueError( - f"Couldn't find axis in '{orientation_out}' corresponding to '{char1}'" - ) + raise ValueError(f"Couldn't find axis in '{orientation_out}' corresponding to '{char1}'") index_flip = np.sign(order).astype(int) index_order = np.abs(order).astype(int) - 1 # Map back to 0-indexing diff --git a/src/mritk/datasets.py b/src/mritk/datasets.py new file mode 100644 index 0000000..87dacb7 --- /dev/null +++ b/src/mritk/datasets.py @@ -0,0 +1,264 @@ +"""MRI -- Download datasets + +Copyright (C) 2026 Henrik Finsberg (henriknf@simula.no) +Copyright (C) 2026 Cécile Daversin-Catty (cecile@simula.no) +Copyright (C) 2026 Simula Research Laboratory +""" + +import logging +from dataclasses import dataclass +import zipfile +from pathlib import Path +import urllib.request +from concurrent.futures import ThreadPoolExecutor +import tqdm + +logger = logging.getLogger(__name__) + + +@dataclass +class Dataset: + name: str + links: dict[str, str] + description: str = "" + doi: str = "" + license: str = "" + + +def get_datasets() -> dict[str, Dataset]: + return { + "test-data": Dataset( + name="Test Data", + description="A small test dataset for testing functionality (based on the Gonzo dataset).", + doi="10.5281/zenodo.14266867", + license="CC-BY-4.0", + links={ + "mri-processed.zip": "https://zenodo.org/records/14266867/files/mri-processed.zip?download=1", + "timetable.tsv": "https://github.com/jorgenriseth/gonzo/blob/main/mri_dataset/timetable.tsv?raw=true", + }, + ), + "gonzo": Dataset( + name="The Gonzo Dataset", + description=""" + We present the Gonzo dataset: brain MRI and derivative data of one healthy-appearing male human volunteer + before and during the 72 hours after injection of the contrast agent gadobutrol into the cerebrospinal + fluid (CSF) of the spinal canal (intrathecal injection). The data records show the temporal and spatial + evolution of the contrast agent in CSF, brain, and adjacent structures. The MRI data includes T1-weighted + images, Look-Locker inversion recovery (LL, a technique to determine T1 values), a mixed inversion-recovery + spin-echo sequence (Mixed) for all time points (one pre-contrast and four post-contrast acquisitions) and, + in addition, T2-weighted, FLAIR, and dynamic DTI data for the pre-contrast session. In addition to raw data, + we provide derivatives with the goal of allowing for numerical simulations of the studied tracer transport process. + This includes T1 maps (from LL and Mixed) and tracer concentration maps, diffusion tensor maps, as well as + unstructured triangulated volume meshes of the brain geometry and associated field data (MRI and derived data mapped + onto the computational mesh). We provide brain region markers obtained with a FreeSurfer-based analysis pipeline.
+ An initial regional statistical analysis of the data is presented. The data can be used to study the transport + behaviour and the underlying processes of a tracer in the human brain. Tracer transport is relevant both for + studying water transport and for exploring new pathways for drug delivery. The composition of the dataset allows + reuse by both the image processing and the simulation science communities. The dataset is meant to contribute to + and inspire new studies into the understanding of transport processes in the brain and into method development + regarding image analysis and simulation of transport processes.""", + doi="10.5281/zenodo.14266867", + license="CC-BY-4.0", + links={ + "data-descriptor-preprint.pdf": "https://zenodo.org/records/14266867/files/data-descriptor-preprint.pdf?download=1", + "fastsurfer.zip": "https://zenodo.org/records/14266867/files/fastsurfer.zip?download=1", + "freesurfer.zip": "https://zenodo.org/records/14266867/files/freesurfer.zip?download=1", + "mesh-data.zip": "https://zenodo.org/records/14266867/files/mesh-data.zip?download=1", + "mri-dataset-precontrast-only.zip": "https://zenodo.org/records/14266867/files/mri-dataset-precontrast-only.zip?download=1", + "mri-dataset.zip": "https://zenodo.org/records/14266867/files/mri-dataset.zip?download=1", + "mri-processed.zip": "https://zenodo.org/records/14266867/files/mri-processed.zip?download=1", + "README.md": "https://zenodo.org/records/14266867/files/README.md?download=1", + "surfaces.zip": "https://zenodo.org/records/14266867/files/surfaces.zip?download=1", + }, + ), + "ratbrain": Dataset( + name="Ratbrain Mesh", + description=""" + This repository contains a collection of files that were used in the article Poulain et al. (2023) + -- Multi-compartmental model of glymphatic clearance of solutes in brain tissue + (https://doi.org/10.1371/journal.pone.0280501) to generate meshes of a ratbrain.
It includes + Python scripts for generating FEniCS-compatible meshes from the included STL files, as well + as a 3D Slicer-compatible (https://www.slicer.org/) segmentation file that was used to generate + the STL files.""", + doi="10.5281/zenodo.10076317", + license="CC-BY-4.0", + links={ + "brain.stl": "https://zenodo.org/records/8138343/files/brain.stl?download=1", + "environment.yml": "https://zenodo.org/records/8138343/files/environment.yml?download=1", + "LICENSE.txt": "https://zenodo.org/records/8138343/files/LICENSE.txt?download=1", + "mesh_generation.py": "https://zenodo.org/records/8138343/files/mesh_generation.py?download=1", + "meshprocessing.py": "https://zenodo.org/records/8138343/files/meshprocessing.py?download=1", + "README.md": "https://zenodo.org/records/8138343/files/README.md?download=1", + "segmentation.seg.nrrd": "https://zenodo.org/records/8138343/files/segmentation.seg.nrrd?download=1", + "ventricles.stl": "https://zenodo.org/records/8138343/files/ventricles.stl?download=1", + }, + ), + } + + +# From https://gist.github.com/maxpoletaev/521c4ce2f5431a4afabf19383fc84fe2 +class ProgressBar: + def __init__(self, filename: str): + self.tqdm = None + self.filename = filename + + def __call__(self, block_num, block_size, total_size): + if self.tqdm is None: + self.tqdm = tqdm.tqdm( + total=total_size, + unit_divisor=1024, + unit_scale=True, + unit="B", + desc=self.filename, + leave=False, + ) + + progress = block_num * block_size + if progress >= total_size: + self.tqdm.close() + return + + self.tqdm.update(progress - self.tqdm.n) + + +def list_datasets(): + from rich.console import Console + from rich.table import Table + from rich.panel import Panel + from rich.text import Text + from rich import box + + console = Console() + datasets = get_datasets() + + console.print(Panel.fit("[bold cyan]Available Datasets[/bold cyan]", border_style="cyan")) + + for key, dataset in datasets.items(): + # Create a table for the files/links + link_table = Table(box=box.SIMPLE, show_header=True, header_style="bold magenta", expand=True) + link_table.add_column("Filename", style="green") + link_table.add_column("URL", style="blue", overflow="fold") + + for filename, url in dataset.links.items(): + link_table.add_row(filename, url) + + # Format description text (collapse indentation and newlines into single spaces) + description_text = Text( + " ".join(dataset.description.split()), + style="white", + ) + + # Create the main content grid + content = Table.grid(padding=1) + content.add_column(style="bold yellow", justify="right") + content.add_column(style="white") + + content.add_row("Key:", key) + content.add_row("DOI:", dataset.doi) + content.add_row("License:", dataset.license) + content.add_row("Description:", description_text) + content.add_row("Files:", link_table) + + # Wrap in a Panel + console.print( + Panel( + content, + title=f"[bold]{dataset.name}[/bold]", + subtitle=f"[dim]Key: {key}[/dim]", + border_style="green", + expand=True, + ) + ) + console.print("") # Add spacing between panels + + +def add_arguments(parser): + subparsers = parser.add_subparsers(dest="datasets-command") + download_parser = subparsers.add_parser("download", help="Download a dataset", formatter_class=parser.formatter_class) + choices = list(get_datasets().keys()) + download_parser.add_argument( + "dataset", + type=str, + default=choices[0], + choices=choices, + help=f"Dataset to download (choices: {', '.join(choices)})", + ) + download_parser.add_argument("-o", "--outdir", type=Path, help="Output
directory for the downloaded dataset") + + subparsers.add_parser("list", help="List available datasets") + + +def dispatch(args): + subcommand = args.pop("datasets-command", None) + if subcommand == "list": + list_datasets() + return + elif subcommand == "download": + dataset = args.pop("dataset") + outdir = args.pop("outdir") + if outdir is None: + logger.error("Output directory (-o or --outdir) is required for downloading datasets.") + return + + datasets = get_datasets() + if dataset not in datasets: + logger.error(f"Unknown dataset: {dataset}. Available datasets: {', '.join(datasets.keys())}") + return + + links = datasets[dataset].links + download_multiple(links, outdir) + else: + raise ValueError(f"Unknown subcommand: {subcommand}") + + +# Takes a single (outdir, (filename, url)) tuple so it can be mapped over a thread pool. +def download_data(args) -> Path: + (outdir, file_info) = args + (filename, url) = file_info + output_path = outdir / Path(filename).stem / filename + output_path.parent.mkdir(parents=True, exist_ok=True) + + logger.info(f"Downloading {url} to {output_path}. This may take a while.") + + try: + urllib.request.urlretrieve(url, output_path, reporthook=ProgressBar(filename=filename)) + if not zipfile.is_zipfile(output_path): + logger.info(f"Downloaded {filename} is not a zip file. No extraction needed.") + return output_path + logger.info(f"Extracting {output_path} to {output_path.parent}.") + with zipfile.ZipFile(output_path, "r") as zip_ref: + zip_ref.extractall(output_path.parent) + output_path.unlink() + except Exception as e: + logger.error(f"Failed to download {filename} from {url}. Error: {e}") + raise + return output_path + + +# Download multiple files concurrently +# Implementation inspired by https://medium.com/@ryan_forrester_/downloading-files-from-urls-in-python-f644e04a0b16 +def download_multiple(urls: dict, outdir, max_workers=1): + outdir.mkdir(parents=True, exist_ok=True) + + # Prepare arguments for thread pool + args = [(outdir, file_info) for file_info in urls.items()] + + # Download files using thread pool + with ThreadPoolExecutor(max_workers=max_workers) as executor: + results = list( + tqdm.tqdm( + executor.map(download_data, args), + total=len(urls), + desc="Downloading MRI data", + ) + ) + + # Process results + successful = [r for r in results if r is not None] + failed = len(results) - len(successful) + + print("\nDownload complete:") + print(f"- Successfully downloaded: {len(successful)} files") + print(f"- Failed downloads: {failed} files") + + return successful diff --git a/src/mritk/download_data.py b/src/mritk/download_data.py deleted file mode 100644 index 8ea1af4..0000000 --- a/src/mritk/download_data.py +++ /dev/null @@ -1,101 +0,0 @@ -"""MRI -- Download data for testing - -Copyright (C) 2026 Henrik Finsberg (henriknf@simula.no) -Copyright (C) 2026 Cécile Daversin-Catty (cecile@simula.no) -Copyright (C) 2026 Simula Research Laboratory -""" - -import logging -import zipfile -from pathlib import Path -from urllib.request import urlretrieve -from concurrent.futures import ThreadPoolExecutor -import tqdm - -logger = logging.getLogger(__name__) - - -# From https://gist.github.com/maxpoletaev/521c4ce2f5431a4afabf19383fc84fe2 -class ProgressBar: - def __init__(self, filename: str): - self.tqdm = None - self.filename = filename - - def __call__(self, block_num, block_size, total_size): - if self.tqdm is None: - self.tqdm = tqdm.tqdm( - total=total_size, - unit_divisor=1024, - unit_scale=True, - unit="B", - desc=self.filename, - leave=False, - ) - - progress = block_num *
block_size - if progress >= total_size: - self.tqdm.close() - return - - self.tqdm.update(progress - self.tqdm.n) - - -def download_test_data(outdir: Path) -> None: - links = { - "mri-processed.zip": "https://zenodo.org/records/14266867/files/mri-processed.zip?download=1", - "timetable.tsv": "https://github.com/jorgenriseth/gonzo/blob/main/mri_dataset/timetable.tsv?raw=true", - } - download_multiple(links, outdir) - - -# def download_data(outdir: Path, file_info: tuple) -> None: -def download_data(args) -> None: - (outdir, file_info) = args - (filename, url) = file_info - output_path = outdir / Path(filename).stem / filename - output_path.parent.mkdir(parents=True, exist_ok=True) - - logger.info(f"Downloading {url} to {output_path}. This may take a while.") - - try: - urlretrieve(url, output_path, reporthook=ProgressBar(filename=filename)) - if not zipfile.is_zipfile(output_path): - logger.info(f"Downloaded {filename} is not a zip file. No extraction needed.") - return - logger.info(f"Extracting {output_path} to {output_path.parent}.") - with zipfile.ZipFile(output_path, "r") as zip_ref: - zip_ref.extractall(output_path.parent) - output_path.unlink() - except Exception as e: - logger.error(f"Failed to download {filename} from {url}. Error: {e}") - return None - return output_path - - -# Download multiple files concurrently -# Implementation inspired by https://medium.com/@ryan_forrester_/downloading-files-from-urls-in-python-f644e04a0b16 -def download_multiple(urls: dict, outdir, max_workers=1): - outdir.mkdir(parents=True, exist_ok=True) - - # Prepare arguments for thread pool - args = [(outdir, file_info) for file_info in urls.items()] - - # Download files using thread pool - with ThreadPoolExecutor(max_workers=max_workers) as executor: - results = list( - tqdm.tqdm( - executor.map(download_data, args), - total=len(urls), - desc="Downloading MRI data", - ) - ) - - # Process results - successful = [r for r in results if r is not None] - failed = len(results) - len(successful) - - print("\nDownload complete:") - print(f"- Successfully downloaded: {len(successful)} files") - print(f"- Failed downloads: {failed} files") - - return successful diff --git a/src/mritk/info.py b/src/mritk/info.py index 8ee9ba5..3052900 100644 --- a/src/mritk/info.py +++ b/src/mritk/info.py @@ -45,11 +45,7 @@ def nifty_info(filename: Path, json_output: bool = False) -> dict[str, typing.An return data # Create a nice header panel - console.print( - Panel( - f"[bold blue]NIfTI File Analysis[/bold blue]\n[green]{filename}[/green]", expand=False - ) - ) + console.print(Panel(f"[bold blue]NIfTI File Analysis[/bold blue]\n[green]{filename}[/green]", expand=False)) # Create a table for Basic Info info_table = Table( diff --git a/src/mritk/napari.py b/src/mritk/napari.py new file mode 100644 index 0000000..645668b --- /dev/null +++ b/src/mritk/napari.py @@ -0,0 +1,58 @@ +import argparse +from pathlib import Path + +import numpy as np +from rich.console import Console + +# Relative import of the package's MRI loading helper +from .data.io import load_mri_data + + +def add_arguments(parser: argparse.ArgumentParser): + parser.add_argument("files", nargs="+", type=Path, help="Files to show") + + +def normalize_to_uint8(data: np.ndarray) -> np.ndarray: + """Normalize array values to 0-255 uint8 range for image display.""" + # Handle NaNs and Infs + data = np.nan_to_num(data) + + d_min, d_max = data.min(), data.max() + if d_max > d_min: + # Linear scaling to 0-255 + normalized = (data - d_min) / (d_max - d_min) *
255 + else: + normalized = np.zeros_like(data) + + return normalized.astype(np.uint8) + + +def dispatch(args): + """ + Open each of the given MRI files as an image layer + in a napari viewer. + """ + try: + import napari + except ImportError: + console = Console() + console.print( + "[bold red]Error:[/bold red] The 'napari' library is required to use the 'napari' command. " + "Please install it with 'pip install mri-toolkit[napari]'" + ) + return + # args is the dict of parsed CLI arguments (see dispatch in cli.py). + # Load each file and add it to the viewer as an image layer. + + file_paths = args.pop("files") + + console = Console() + viewer = napari.Viewer() + for file_path in file_paths: + console.print(f"[bold green]Loading MRI data from:[/bold green] {file_path}") + + mri_resource = load_mri_data(file_path) + data = mri_resource.data + viewer.add_image(data, name=file_path.stem) + + napari.run() diff --git a/src/mritk/statistics/cli.py b/src/mritk/statistics/cli.py index b49957b..491f439 100644 --- a/src/mritk/statistics/cli.py +++ b/src/mritk/statistics/cli.py @@ -97,10 +97,7 @@ def get_stats_value(stats_file: Path, region: str, info: str, **kwargs): # Validate inputs valid_regions = default_segmentation_groups().keys() if region not in valid_regions: - console.print( - f"[bold red]Error:[/bold red] Region '{region}' " - "not found in default segmentation groups." - ) + console.print(f"[bold red]Error:[/bold red] Region '{region}' not found in default segmentation groups.") sys.exit(1) valid_infos = [ @@ -119,10 +116,7 @@ def get_stats_value(stats_file: Path, region: str, info: str, **kwargs): "PC99", ] if info not in valid_infos: - console.print( - f"[bold red]Error:[/bold red] Info '{info}' " - f"is invalid. Choose from: {', '.join(valid_infos)}" - ) + console.print(f"[bold red]Error:[/bold red] Info '{info}' is invalid.
Choose from: {', '.join(valid_infos)}") sys.exit(1) if not stats_file.exists(): @@ -142,8 +136,7 @@ def get_stats_value(stats_file: Path, region: str, info: str, **kwargs): # Output console.print( - f"[bold cyan]{info}[/bold cyan] for [bold green]{region}[/bold green] " - f"= [bold white]{info_value}[/bold white]" + f"[bold cyan]{info}[/bold cyan] for [bold green]{region}[/bold green] = [bold white]{info_value}[/bold white]" ) return info_value @@ -156,22 +149,12 @@ def add_arguments(parser: argparse.ArgumentParser): subparsers = parser.add_subparsers(dest="stats-command", help="Available commands") # --- Compute Command --- - parser_compute = subparsers.add_parser( - "compute", help="Compute MRI statistics", formatter_class=parser.formatter_class - ) - parser_compute.add_argument( - "--segmentation", "-s", type=Path, required=True, help="Path to segmentation file" - ) - parser_compute.add_argument( - "--mri", "-m", type=Path, nargs="+", required=True, help="Path to MRI data file(s)" - ) - parser_compute.add_argument( - "--output", "-o", type=Path, required=True, help="Output CSV file path" - ) + parser_compute = subparsers.add_parser("compute", help="Compute MRI statistics", formatter_class=parser.formatter_class) + parser_compute.add_argument("--segmentation", "-s", type=Path, required=True, help="Path to segmentation file") + parser_compute.add_argument("--mri", "-m", type=Path, nargs="+", required=True, help="Path to MRI data file(s)") + parser_compute.add_argument("--output", "-o", type=Path, required=True, help="Output CSV file path") parser_compute.add_argument("--timetable", "-t", type=Path, help="Path to timetable file") - parser_compute.add_argument( - "--timelabel", "-l", dest="timelabel", type=str, help="Time label sequence" - ) + parser_compute.add_argument("--timelabel", "-l", dest="timelabel", type=str, help="Time label sequence") parser_compute.add_argument( "--seg_regex", "-sr", @@ -179,24 +162,16 @@ def add_arguments(parser: argparse.ArgumentParser): type=str, help="Regex pattern for segmentation filename", ) - parser_compute.add_argument( - "--mri_regex", "-mr", dest="mri_regex", type=str, help="Regex pattern for MRI filename" - ) + parser_compute.add_argument("--mri_regex", "-mr", dest="mri_regex", type=str, help="Regex pattern for MRI filename") parser_compute.add_argument("--lut", "-lt", dest="lut", type=Path, help="Path to Lookup Table") parser_compute.add_argument("--info", "-i", type=str, help="Info dictionary as JSON string") parser_compute.set_defaults(func=compute_mri_stats) # --- Get Command --- - parser_get = subparsers.add_parser( - "get", help="Get specific stats value", formatter_class=parser.formatter_class - ) - parser_get.add_argument( - "--stats_file", "-f", type=Path, required=True, help="Path to stats CSV file" - ) + parser_get = subparsers.add_parser("get", help="Get specific stats value", formatter_class=parser.formatter_class) + parser_get.add_argument("--stats_file", "-f", type=Path, required=True, help="Path to stats CSV file") parser_get.add_argument("--region", "-r", type=str, required=True, help="Region description") - parser_get.add_argument( - "--info", "-i", type=str, required=True, help="Statistic to retrieve (mean, std, etc.)" - ) + parser_get.add_argument("--info", "-i", type=str, required=True, help="Statistic to retrieve (mean, std, etc.)") parser_get.set_defaults(func=get_stats_value) diff --git a/src/mritk/statistics/compute_stats.py b/src/mritk/statistics/compute_stats.py index 7234d46..29384a9 100644 --- 
a/src/mritk/statistics/compute_stats.py +++ b/src/mritk/statistics/compute_stats.py @@ -49,13 +49,9 @@ def generate_stats_dataframe( if (m := re.match(seg_pattern, Path(seg_path).name)) is not None: seg_info = m.groupdict() else: - raise RuntimeError( - f"Segmentation filename {seg_path.name} does not match the provided pattern." - ) + raise RuntimeError(f"Segmentation filename {seg_path.name} does not match the provided pattern.") elif info_dict is not None: - seg_info["segmentation"] = ( - info_dict["segmentation"] if "segmentation" in info_dict else None - ) + seg_info["segmentation"] = info_dict["segmentation"] if "segmentation" in info_dict else None seg_info["subject"] = info_dict["subject"] if "subject" in info_dict else None else: seg_info = {"segmentation": None, "subject": None} @@ -66,9 +62,7 @@ if (m := re.match(mri_data_pattern, Path(mri_path).name)) is not None: mri_info = m.groupdict() else: - raise RuntimeError( - f"MRI data filename {mri_path.name} does not match the provided pattern." - ) + raise RuntimeError(f"MRI data filename {mri_path.name} does not match the provided pattern.") elif info_dict is not None: mri_info["mri_data"] = info_dict["mri_data"] if "mri_data" in info_dict else None mri_info["subject"] = info_dict["subject"] if "subject" in info_dict else None diff --git a/test/test_cli.py b/test/test_cli.py index 6c1eb28..a5e3484 100644 --- a/test/test_cli.py +++ b/test/test_cli.py @@ -11,11 +11,7 @@ def test_cli_version(capsys): def test_cli_info(capsys, mri_data_dir): - test_file = ( - mri_data_dir - / "mri-processed/mri_processed_data/sub-01" - / "concentrations/sub-01_ses-01_concentration.nii.gz" - ) + test_file = mri_data_dir / "mri-processed/mri_processed_data/sub-01" / "concentrations/sub-01_ses-01_concentration.nii.gz" args = ["info", str(test_file)] cli.main(args) captured = capsys.readouterr() @@ -24,11 +20,7 @@ def test_cli_info(capsys, mri_data_dir): def test_cli_info_json(capsys, mri_data_dir): - test_file = ( - mri_data_dir - / "mri-processed/mri_processed_data/sub-01" - / "concentrations/sub-01_ses-01_concentration.nii.gz" - ) + test_file = mri_data_dir / "mri-processed/mri_processed_data/sub-01" / "concentrations/sub-01_ses-01_concentration.nii.gz" args = ["info", str(test_file), "--json"] cli.main(args) captured = capsys.readouterr() diff --git a/test/test_datasets.py b/test/test_datasets.py new file mode 100644 index 0000000..16bcbe3 --- /dev/null +++ b/test/test_datasets.py @@ -0,0 +1,153 @@ +import pytest +from pathlib import Path +from unittest.mock import MagicMock, patch + +# Import the modules under test +import mritk.cli +import mritk.datasets +from mritk.datasets import Dataset + +# --- Fixtures --- + + +@pytest.fixture +def mock_datasets(): + """Returns a simplified version of the dataset dictionary using the new Dataclass.""" + return { + "test-data": Dataset( + name="Test Data", + description="Test description", + doi="10.1234/test", + license="MIT", + links={ + "file1.txt": "http://example.com/file1.txt", + "archive.zip": "http://example.com/archive.zip", + }, + ), + "gonzo": Dataset(name="Gonzo", description="Gonzo description", links={}), + } + + +def test_get_datasets_structure(): + """Ensure get_datasets returns a dict of Dataset objects.""" + datasets = mritk.datasets.get_datasets() + assert "test-data" in datasets + assert isinstance(datasets["test-data"], Dataset) + assert datasets["test-data"].doi == "10.5281/zenodo.14266867" + + +def test_add_arguments(): + """Ensure argparse is configured correctly with
subcommands.""" + import argparse + + parser = argparse.ArgumentParser() + mritk.datasets.add_arguments(parser) + + # Test 'download' subcommand + args = parser.parse_args(["download", "test-data", "-o", "/tmp/out"]) + # Note: argparse converts hyphens in 'dest' to underscores for attributes, + # but strictly speaking, add_subparsers dest keeps the name if accessed via vars() + # Let's check the logic used in the script. + assert getattr(args, "datasets-command") == "download" + assert args.dataset == "test-data" + assert args.outdir == Path("/tmp/out") + + # Test 'list' subcommand + args_list = parser.parse_args(["list"]) + assert getattr(args_list, "datasets-command") == "list" + + +@patch("mritk.datasets.get_datasets") +@patch("mritk.datasets.download_multiple") +def test_dispatch_download_success(mock_download_multiple, mock_get_datasets, mock_datasets): + """Test that dispatch calls download_multiple with the .links attribute.""" + mock_get_datasets.return_value = mock_datasets + + mritk.cli.main(["datasets", "download", "test-data", "-o", "/tmp"]) + + # CRITICAL: This test asserts that you passed the .links dictionary, + # not the Dataset object itself. + expected_links = mock_datasets["test-data"].links + mock_download_multiple.assert_called_once_with(expected_links, Path("/tmp")) + + +@patch("mritk.datasets.list_datasets") +def test_dispatch_list(mock_list_datasets): + """Test that dispatch calls list_datasets when subcommand is 'list'.""" + mritk.cli.main(["datasets", "list"]) + mock_list_datasets.assert_called_once() + + +@patch("mritk.datasets.get_datasets") +def test_dispatch_unknown_subcommand(mock_get_datasets): + """Test graceful failure on unknown subcommand.""" + args = {"datasets-command": "unknown"} + with pytest.raises(ValueError) as excinfo: + mritk.datasets.dispatch(args) + assert "Unknown subcommand" in str(excinfo.value) + + +@patch("rich.console.Console") # Mock rich.console.Console +@patch("mritk.datasets.get_datasets") +def test_list_datasets(mock_get_datasets, mock_console_cls, mock_datasets): + """Test that list_datasets attempts to print to console.""" + mock_get_datasets.return_value = mock_datasets + mock_console_instance = mock_console_cls.return_value + + mritk.datasets.list_datasets() + + # Verify it tried to print something + assert mock_console_instance.print.called + # We can verify it printed the dataset name + # (Checking exact rich output is hard, checking invocation is usually enough) + assert mock_console_cls.called + + +@patch("urllib.request.urlretrieve") +@patch("zipfile.is_zipfile") +def test_download_data_regular_file(mock_is_zip, mock_retrieve, tmp_path): + """Test downloading a standard non-zip file.""" + mock_is_zip.return_value = False + filename = "test.txt" + url = "http://example.com/test.txt" + outdir = tmp_path / "output" + + args = (outdir, (filename, url)) + + result_path = mritk.datasets.download_data(args) + + expected_path = outdir / "test" / filename + assert result_path == expected_path + mock_retrieve.assert_called_once() + + +@patch("urllib.request.urlretrieve") +def test_download_data_failure(mock_retrieve, tmp_path, caplog): + """Test error handling - now expects an Exception to be raised.""" + mock_retrieve.side_effect = Exception("Connection Reset") + + args = (tmp_path, ("file.txt", "http://bad-url.com")) + + # The new code re-raises the exception, so we use pytest.raises + with pytest.raises(Exception) as excinfo: + mritk.datasets.download_data(args) + + assert "Connection Reset" in str(excinfo.value) + # Also verify 
it logged the error before raising + assert "Failed to download" in caplog.text + + +@patch("mritk.datasets.ThreadPoolExecutor") +@patch("mritk.datasets.download_data") +def test_download_multiple(mock_download_data, mock_executor, tmp_path): + """Test the threading logic.""" + # Note: inputs are now a dict of links, not a Dataset object + urls = {"f1": "u1", "f2": "u2"} + + mock_executor_instance = MagicMock() + mock_executor.return_value.__enter__.return_value = mock_executor_instance + mock_executor_instance.map.return_value = ["path/to/f1", "path/to/f2"] + + successful = mritk.datasets.download_multiple(urls, tmp_path) + + assert len(successful) == 2 diff --git a/test/test_mri_io.py b/test/test_mri_io.py index cf68b42..34ffeeb 100644 --- a/test/test_mri_io.py +++ b/test/test_mri_io.py @@ -10,10 +10,7 @@ def test_mri_io_nifti(tmp_path, mri_data_dir): - input_file = ( - mri_data_dir - / "mri-processed/mri_dataset/derivatives/sub-01/ses-01/sub-01_ses-01_acq-mixed_T1map.nii.gz" - ) + input_file = mri_data_dir / "mri-processed/mri_dataset/derivatives/sub-01/ses-01/sub-01_ses-01_acq-mixed_T1map.nii.gz" output_file = tmp_path / "output_nifti.nii.gz" diff --git a/test/test_mri_stats.py b/test/test_mri_stats.py index 638f8b7..a259e86 100644 --- a/test/test_mri_stats.py +++ b/test/test_mri_stats.py @@ -13,16 +13,8 @@ def test_compute_stats_default(mri_data_dir: Path): - seg_path = ( - mri_data_dir - / "mri-processed/mri_processed_data/sub-01" - / "segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" - ) - mri_path = ( - mri_data_dir - / "mri-processed/mri_processed_data/sub-01" - / "concentrations/sub-01_ses-01_concentration.nii.gz" - ) + seg_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01" / "segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" + mri_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01" / "concentrations/sub-01_ses-01_concentration.nii.gz" dataframe = generate_stats_dataframe(seg_path, mri_path) @@ -55,20 +47,10 @@ def test_compute_stats_default(mri_data_dir: Path): def test_compute_stats_patterns(mri_data_dir: Path): - seg_path = ( - mri_data_dir - / "mri-processed/mri_processed_data/sub-01" - / "segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" - ) - mri_path = ( - mri_data_dir - / "mri-processed/mri_processed_data/sub-01" - / "concentrations/sub-01_ses-01_concentration.nii.gz" - ) + seg_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01" / "segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" + mri_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01" / "concentrations/sub-01_ses-01_concentration.nii.gz" seg_pattern = "(?P<subject>sub-(control|patient)*\\d{2})_seg-(?P<segmentation>[^\\.]+)" - mri_data_pattern = ( - "(?P<subject>sub-(control|patient)*\\d{2})_(?P<session>ses-\\d{2})_(?P<mri_data>[^\\.]+)" - ) + mri_data_pattern = "(?P<subject>sub-(control|patient)*\\d{2})_(?P<session>ses-\\d{2})_(?P<mri_data>[^\\.]+)" dataframe = generate_stats_dataframe( seg_path, @@ -85,20 +67,10 @@ def test_compute_stats_timestamp(mri_data_dir: Path): - seg_path = ( - mri_data_dir - / "mri-processed/mri_processed_data/sub-01" - / "segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" - ) - mri_path = ( - mri_data_dir - / "mri-processed/mri_processed_data/sub-01" - / "concentrations/sub-01_ses-01_concentration.nii.gz" - ) + seg_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01" / "segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" + mri_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01" / "concentrations/sub-01_ses-01_concentration.nii.gz"
seg_pattern = "(?Psub-(control|patient)*\\d{2})_seg-(?P[^\\.]+)" - mri_data_pattern = ( - "(?Psub-(control|patient)*\\d{2})_(?Pses-\\d{2})_(?P[^\\.]+)" - ) + mri_data_pattern = "(?Psub-(control|patient)*\\d{2})_(?Pses-\\d{2})_(?P[^\\.]+)" timetable = mri_data_dir / "timetable/timetable.tsv" timetable_sequence = "mixed" @@ -115,16 +87,8 @@ def test_compute_stats_timestamp(mri_data_dir: Path): def test_compute_stats_info(mri_data_dir: Path): - seg_path = ( - mri_data_dir - / "mri-processed/mri_processed_data/sub-01" - / "segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" - ) - mri_path = ( - mri_data_dir - / "mri-processed/mri_processed_data/sub-01" - / "concentrations/sub-01_ses-01_concentration.nii.gz" - ) + seg_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01" / "segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" + mri_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01" / "concentrations/sub-01_ses-01_concentration.nii.gz" info = { "mri_data": "concentration", "subject": "sub-01", @@ -142,20 +106,10 @@ def test_compute_stats_info(mri_data_dir: Path): def test_compute_mri_stats_cli(capsys, tmp_path: Path, mri_data_dir: Path): - seg_path = ( - mri_data_dir - / "mri-processed/mri_processed_data/sub-01" - / "segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" - ) - mri_path = ( - mri_data_dir - / "mri-processed/mri_processed_data/sub-01" - / "concentrations/sub-01_ses-01_concentration.nii.gz" - ) + seg_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01" / "segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" + mri_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01" / "concentrations/sub-01_ses-01_concentration.nii.gz" seg_pattern = "(?Psub-(control|patient)*\\d{2})_seg-(?P[^\\.]+)" - mri_data_pattern = ( - "(?Psub-(control|patient)*\\d{2})_(?Pses-\\d{2})_(?P[^\\.]+)" - ) + mri_data_pattern = "(?Psub-(control|patient)*\\d{2})_(?Pses-\\d{2})_(?P[^\\.]+)" timetable = mri_data_dir / "timetable/timetable.tsv" timetable_sequence = "mixed"