Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
505109c
Speed up ribbon generation
ClePol May 21, 2026
d3741d1
Speed up spherical morphing
ClePol May 21, 2026
7b2985c
Overlap surface statistics
ClePol May 21, 2026
8af67fd
Speed up surface wrapper steps
ClePol May 14, 2026
ade2a1d
Overlap pre-ribbon surface jobs
ClePol May 14, 2026
f871acd
Overlap wmparc volume labeling
ClePol May 15, 2026
5d70baa
Speed up mapped volume projection
ClePol May 15, 2026
3d5ed3d
Overlap ribbon generation with hemisphere tail
ClePol May 15, 2026
bda3904
Speed up mapped volume tail
ClePol May 15, 2026
c182ba8
Speed up inflated curvature computation
ClePol May 15, 2026
cf8aa1c
Overlap ribbon with surface registration
ClePol May 15, 2026
c81f0c7
Skip legacy T1 normalization by default
ClePol May 15, 2026
3f8f1fb
Speed up aparc smoothing mode filter
ClePol May 15, 2026
34c841f
Speed up N4 bias correction
ClePol May 16, 2026
4c996fd
Overlap CerebNet with HypVINN
ClePol May 16, 2026
05e6056
Trace HypVINN CPU inference
ClePol May 16, 2026
6e885b4
Trace FastSurferVINN CPU inference
ClePol May 16, 2026
c7d7140
Freeze traced CPU inference models
ClePol May 16, 2026
7fe907f
Share CPU TorchScript inference helpers
ClePol May 18, 2026
3882edb
Cap CPU inference threads
ClePol May 21, 2026
d3d7acc
Speed up qsphere fallback and ribbon masks
ClePol May 18, 2026
6d64074
Overlap hypointensity relabeling
ClePol May 18, 2026
4cfb41d
Share cropped volume helpers for volmask
ClePol May 21, 2026
d02a5b5
Overlap auxiliary segmentations with stats
ClePol May 16, 2026
cf57f02
Speed up GPU segmentation tail
ClePol May 16, 2026
8a8975b
Tune GPU auxiliary segmentation tail
ClePol May 16, 2026
611e5b4
Overlap surface reconstruction with segmentation tail
ClePol May 16, 2026
b64278c
Handle embedded timing markers in recon-surf logs
ClePol May 17, 2026
f3691cb
Overlap independent surface tail work
ClePol May 17, 2026
678a327
Overlap Talairach and defer tail wait
ClePol May 17, 2026
572fb9e
fix ruff error
ClePol May 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CerebNet/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from FastSurferCNN.utils.common import SubjectDirectory, SubjectList, find_device
from FastSurferCNN.utils.mapper import JsonColorLookupTable, Mapper, TSVLookupTable
from FastSurferCNN.utils.parallel import SerialExecutor, get_num_threads
from FastSurferCNN.utils.torchscript import cpu_torch_threads

if TYPE_CHECKING:
import yacs.config
Expand Down Expand Up @@ -92,7 +93,6 @@ def __init__(
self._threads = None
self.threads = threads
_threads = get_num_threads() if self._threads is None else self._threads
torch.set_num_threads(_threads)
self.pool = ThreadPoolExecutor(self._threads) if async_io else SerialExecutor()
self.cfg = cfg
self._async_io = async_io
Expand All @@ -109,6 +109,9 @@ def __init__(
torch.manual_seed(cfg.RNG_SEED)

_device = find_device(device)
torch_threads = cpu_torch_threads(_threads, _device)
if torch_threads is not None:
torch.set_num_threads(torch_threads)
if _device == "cpu" and viewagg_device == "auto":
_viewagg_device = torch.device("cpu")
else:
Expand Down
36 changes: 35 additions & 1 deletion FastSurferCNN/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,22 @@
from FastSurferCNN.data_loader.dataset import MultiScaleOrigDataThickSlices
from FastSurferCNN.models.networks import build_model
from FastSurferCNN.utils import logging
from FastSurferCNN.utils.torchscript import env_flag_enabled, should_trace_cpu_inference, trace_for_inference

logger = logging.getLogger(__name__)


class _FastSurferVINNTraceWrapper(torch.nn.Module):
    """Adapter that gives the network a two-tensor ``forward`` for tracing.

    The wrapped model is always called with ``None`` as its third positional
    argument, i.e. the inference path that does not request an output scale.
    """

    def __init__(self, model: torch.nn.Module):
        """Register ``model`` as the ``model`` submodule of this adapter."""
        super().__init__()
        self.model = model

    def forward(self, images: torch.Tensor, scale_factors: torch.Tensor) -> torch.Tensor:
        """Forward ``images`` and ``scale_factors`` to the wrapped model without an output scale."""
        no_out_scale = None
        prediction = self.model(images, scale_factors, no_out_scale)
        return prediction


class Inference:
"""Model evaluation class to run inference using FastSurferCNN.

Expand Down Expand Up @@ -324,6 +336,14 @@ def eval(
Prediction probability tensor.
"""
self.model.eval()
trace_model = should_trace_cpu_inference(
out_scale=out_scale,
device=self.device,
batch_size=self.cfg.TEST.BATCH_SIZE,
env_var="FASTSURFER_VINN_TRACE",
)
freeze_model = env_flag_enabled("FASTSURFER_VINN_FREEZE")
traced_model = False
# we should check here, whether the DataLoader is a Random or a SequentialSampler, but we cannot easily.
if not isinstance(val_loader.sampler, torch.utils.data.SequentialSampler):
logger.warning(
Expand Down Expand Up @@ -351,8 +371,22 @@ def eval(
# move data to the model device
images, scale_factors = batch["image"].to(self.device), batch["scale_factor"].to(self.device)

if trace_model and batch_idx == 0:
self.model = trace_for_inference(
model=self.model,
wrapper_factory=_FastSurferVINNTraceWrapper,
example_inputs=(images, scale_factors),
freeze=freeze_model,
logger=logger,
label=plane,
)
traced_model = True

# predict the current batch, outputs logits
pred = self.model(images, scale_factors, out_scale)
if traced_model:
pred = self.model(images, scale_factors)
else:
pred = self.model(images, scale_factors, out_scale)
batch_size = pred.shape[0]
end_index = start_index + batch_size

Expand Down
5 changes: 4 additions & 1 deletion FastSurferCNN/run_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from FastSurferCNN.utils.load_config import load_config
from FastSurferCNN.utils.parallel import SerialExecutor, pipeline
from FastSurferCNN.utils.parser_defaults import SubjectDirectoryConfig
from FastSurferCNN.utils.torchscript import cpu_torch_threads

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -202,14 +203,16 @@ def __init__(
"""
# TODO Fix docstring of RunModelOnData.__init__
self._threads = threads
torch.set_num_threads(self._threads)
self._async_io = async_io
self.orientation = orientation
self.image_size = image_size

self.sf = 1.0

self.device = find_device(device)
torch_threads = cpu_torch_threads(self._threads, self.device)
if torch_threads is not None:
torch.set_num_threads(torch_threads)

if self.device.type == "cpu" and viewagg_device in ("auto", "cpu"):
self.viewagg_device = self.device
Expand Down
72 changes: 72 additions & 0 deletions FastSurferCNN/utils/torchscript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""TorchScript helpers for CPU inference hot paths."""

from __future__ import annotations

import os
import time
import warnings
from collections.abc import Callable

import torch


def env_flag_enabled(name: str, default: str = "1") -> bool:
    """Return whether the environment flag ``name`` is switched on.

    Parameters
    ----------
    name : str
        Name of the environment variable to read.
    default : str, default="1"
        Value used when the variable is not set.

    Returns
    -------
    bool
        ``False`` if the (whitespace-stripped, case-insensitive) value is one of
        ``"0"``, ``"false"``, ``"no"`` or ``"off"``; ``True`` for anything else,
        including an empty string.
    """
    raw = os.environ.get(name, default)
    # Normalise before comparing so that values such as " 0 " or "False"
    # disable the flag instead of being silently treated as "enabled".
    normalized = str(raw).strip().casefold()
    return normalized not in {"0", "false", "no", "off"}


def cpu_torch_threads(requested: int | None, device=None) -> int | None:
"""Cap CPU inference threads to physical cores when more threads were requested."""
device_type = getattr(device, "type", device)
if device_type != "cpu" or requested is None or requested < 1:
return requested

override = os.environ.get("FASTSURFER_CPU_TORCH_THREADS")
if override:
try:
return max(1, int(override))
except ValueError:
pass

cpu_count = os.cpu_count()
if cpu_count is None or cpu_count < 2:
return requested
return min(requested, max(1, cpu_count // 2))


def should_trace_cpu_inference(
    *,
    out_scale: object,
    device: torch.device,
    batch_size: int,
    env_var: str,
) -> bool:
    """Decide whether the CPU tracing path applies to this inference run.

    Tracing is chosen only if no output scale is requested, the device type is
    ``"cpu"``, the batch size is exactly 1 and the flag named by ``env_var`` is
    enabled. The environment is consulted last, only when every other
    condition already holds.
    """
    if out_scale is not None:
        return False
    if device.type != "cpu":
        return False
    if batch_size != 1:
        return False
    return env_flag_enabled(env_var)


def trace_for_inference(
    *,
    model: torch.nn.Module,
    wrapper_factory: Callable[[torch.nn.Module], torch.nn.Module],
    example_inputs: tuple[torch.Tensor, ...],
    freeze: bool,
    logger,
    label: str,
) -> torch.nn.Module:
    """Trace ``model`` with TorchScript and optionally freeze the result.

    Parameters
    ----------
    model : torch.nn.Module
        Model to trace. If it already is a ``torch.jit.ScriptModule`` it is
        returned unchanged.
    wrapper_factory : Callable[[torch.nn.Module], torch.nn.Module]
        Builds the adapter module whose ``forward`` accepts exactly
        ``example_inputs``.
    example_inputs : tuple[torch.Tensor, ...]
        Inputs passed to ``torch.jit.trace``.
    freeze : bool
        Whether to run ``torch.jit.freeze`` on the traced module.
    logger
        Object with an ``info`` method; receives the tracing duration.
    label : str
        Name used in the log message.

    Returns
    -------
    torch.nn.Module
        The traced (and possibly frozen) module in eval mode, or ``model``
        itself if it was already scripted/traced.
    """
    # Callers may store the returned module in place of the original model and
    # hand it back on a later call. Wrapping it again would invoke the traced
    # forward with the adapter's extra argument and fail, so tracing is
    # idempotent.
    if isinstance(model, torch.jit.ScriptModule):
        return model

    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments.
    trace_start = time.perf_counter()
    with warnings.catch_warnings():
        # Tracer warnings are suppressed for this call only.
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        traced_model = torch.jit.trace(
            wrapper_factory(model),
            example_inputs,
            check_trace=False,
        )
    # torch.jit.freeze requires the module to be in eval mode.
    traced_model.eval()
    if freeze:
        traced_model = torch.jit.freeze(traced_model)
    logger.info(f"Traced {label} model in {time.perf_counter() - trace_start:0.4f} seconds")
    return traced_model
51 changes: 49 additions & 2 deletions HypVINN/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
import FastSurferCNN.utils.logging as logging
from FastSurferCNN.data_loader.augmentation import ToTensorTest
from FastSurferCNN.utils.common import find_device
from FastSurferCNN.utils.torchscript import (
cpu_torch_threads,
env_flag_enabled,
should_trace_cpu_inference,
trace_for_inference,
)
from HypVINN.data_loader.data_utils import hypo_map_prediction_sagittal2full
from HypVINN.data_loader.dataset import HypVINNDataset
from HypVINN.models.networks import build_model
Expand All @@ -33,6 +39,22 @@
logger = logging.get_logger(__name__)


class _HypVINNTraceWrapper(torch.nn.Module):
    """Adapter that gives the network a three-tensor ``forward`` for tracing.

    The wrapped model is always called with ``None`` as its fourth positional
    argument, i.e. the inference path that does not request an output scale.
    """

    def __init__(self, model: torch.nn.Module):
        """Register ``model`` as the ``model`` submodule of this adapter."""
        super().__init__()
        self.model = model

    def forward(
        self,
        images: torch.Tensor,
        scale_factors: torch.Tensor,
        weight_factors: torch.Tensor,
    ) -> torch.Tensor:
        """Forward the three tensors to the wrapped model without an output scale."""
        no_out_scale = None
        prediction = self.model(images, scale_factors, weight_factors, no_out_scale)
        return prediction


class Inference:
"""
Class for running inference on a single subject.
Expand Down Expand Up @@ -76,7 +98,7 @@ def __init__(
"""
from FastSurferCNN.utils.parallel import get_num_threads

torch.set_num_threads(get_num_threads())
_threads = get_num_threads()
self._async_io = async_io

# Set random seed from configs.
Expand All @@ -90,6 +112,9 @@ def __init__(

# Define device and transfer model
self.device = find_device(device)
torch_threads = cpu_torch_threads(_threads, self.device)
if torch_threads is not None:
torch.set_num_threads(torch_threads)

if self.device.type == "cpu" and viewagg_device == "auto":
self.viewagg_device = self.device
Expand Down Expand Up @@ -314,6 +339,14 @@ def eval(self, val_loader: DataLoader, pred_prob: torch.Tensor, out_scale: float
The updated prediction probabilities.
"""
self.model.eval()
trace_model = should_trace_cpu_inference(
out_scale=out_scale,
device=self.device,
batch_size=self.cfg.TEST.BATCH_SIZE,
env_var="FASTSURFER_HYPVINN_TRACE",
)
freeze_model = env_flag_enabled("FASTSURFER_HYPVINN_FREEZE")
traced_model = False

start_index = 0
for _batch_idx, batch in tqdm(enumerate(val_loader), total=len(val_loader)):
Expand All @@ -322,7 +355,21 @@ def eval(self, val_loader: DataLoader, pred_prob: torch.Tensor, out_scale: float
scale_factors = batch["scale_factor"].to(self.device)
weight_factors = batch["weight_factor"].to(self.device, dtype=torch.float32)

pred = self.model(images, scale_factors, weight_factors, out_scale)
if trace_model and _batch_idx == 0:
self.model = trace_for_inference(
model=self.model,
wrapper_factory=_HypVINNTraceWrapper,
example_inputs=(images, scale_factors, weight_factors),
freeze=freeze_model,
logger=logger,
label=self.cfg.DATA.PLANE,
)
traced_model = True

if traced_model:
pred = self.model(images, scale_factors, weight_factors)
else:
pred = self.model(images, scale_factors, weight_factors, out_scale)

if self.cfg.DATA.PLANE == "axial":
pred = pred.permute((2, 3, 0, 1)).to(self.viewagg_device)
Expand Down
35 changes: 35 additions & 0 deletions recon_surf/check_surface_volume_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/usr/bin/env python3

"""Check that a FreeSurfer surface has valid volume metadata."""

from __future__ import annotations

import argparse
import sys

from nibabel.freesurfer.io import read_geometry


def options_parse() -> argparse.Namespace:
    """Parse the command line, which takes one positional ``surface`` argument."""
    cli = argparse.ArgumentParser()
    cli.add_argument("surface", help="FreeSurfer surface to check")
    namespace = cli.parse_args()
    return namespace


def main() -> int:
    """Check the volume metadata of the surface named on the command line.

    Returns
    -------
    int
        0 if the metadata's ``valid`` entry starts with ``"1"`` and its ``head``
        entry equals ``[2, 0, 20]``; otherwise 1, after printing the offending
        values to stderr.
    """
    args = options_parse()
    # With read_metadata=True the third element of the result is the metadata mapping.
    volume_info = read_geometry(args.surface, read_metadata=True)[2]
    head = list(volume_info.get("head", []))
    is_valid = str(volume_info.get("valid", "")).startswith("1")

    if not (is_valid and head == [2, 0, 20]):
        print(
            f"Invalid surface volume metadata in {args.surface}: "
            f"valid={volume_info.get('valid')!r}, head={head!r}",
            file=sys.stderr,
        )
        return 1
    return 0


# Script entry point: the return value of main() becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
Loading
Loading