Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 159 additions & 8 deletions src/eaa/image_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import matplotlib.pyplot as plt
import scipy.ndimage as ndi
from scipy import optimize
from scipy.special import erf
from skimage.metrics import normalized_mutual_information
from skimage.registration import phase_cross_correlation as skimage_phase_cross_correlation
from sciagent.message_proc import generate_openai_message
Expand Down Expand Up @@ -117,10 +118,30 @@ def add_frame_to_pil(image: Image.Image) -> Image.Image:
return buffer


def _gaussian_rect_window(shape: tuple[int, int], decay_fraction: float = 0.2) -> np.ndarray:
"""2D Gaussian-softened rectangle window.

The mask is 1 in the center and decays smoothly to 0 at each edge.
The decay is the convolution of a step function with a Gaussian, implemented
via the error function. The transition from 1 to 0 spans ``decay_fraction``
of the image size on each side.
"""
def _win_1d(size: int) -> np.ndarray:
decay_len = decay_fraction * size
sigma = decay_len / 4.0
s2 = sigma * np.sqrt(2.0)
x = np.arange(size, dtype=float)
left = 0.5 * (1.0 + erf((x - decay_len / 2.0) / s2))
right = 0.5 * (1.0 - erf((x - (size - 1.0 - decay_len / 2.0)) / s2))
return np.minimum(left, right)

return np.outer(_win_1d(shape[0]), _win_1d(shape[1]))


def phase_cross_correlation(
moving: np.ndarray,
ref: np.ndarray,
use_hanning_window: bool = True,
moving: np.ndarray,
ref: np.ndarray,
filtering_method: Optional[Literal["hanning", "gaussian"]] = "hanning",
upsample_factor: int = 1,
) -> np.ndarray | Tuple[np.ndarray, float]:
"""Phase correlation with windowing. The result gives
Expand All @@ -133,9 +154,12 @@ def phase_cross_correlation(
A 2D image.
ref : np.ndarray
A 2D image.
use_hanning_window : bool, optional
If True, a Hanning window is used to smooth the images before the
correlation is computed.
filtering_method : {"hanning", "gaussian"} or None, optional
Window function applied to both images before phase correlation to
reduce spectral leakage. ``"hanning"`` uses a standard Hanning window.
``"gaussian"`` uses a Gaussian-softened rectangle that is 1 in the
centre and decays to 0 at each edge over 20 % of the image size.
Pass ``None`` to disable windowing.
upsample_factor : int, optional
Upsampling factor for subpixel accuracy in phase correlation.
A value of 1 yields pixel-level precision.
Expand All @@ -150,15 +174,24 @@ def phase_cross_correlation(
)
moving = moving - moving.mean()
ref = ref - ref.mean()
if use_hanning_window:
if filtering_method == "hanning":
win_y = np.hanning(moving.shape[0])
win_x = np.hanning(moving.shape[1])
win = np.outer(win_y, win_x)
moving_for_registration = moving * win
ref_for_registration = ref * win
else:
elif filtering_method == "gaussian":
win = _gaussian_rect_window(moving.shape, decay_fraction=0.2)
moving_for_registration = moving * win
ref_for_registration = ref * win
elif filtering_method is None:
moving_for_registration = moving
ref_for_registration = ref
else:
raise ValueError(
f"Unknown filtering_method {filtering_method!r}. "
"Use 'hanning', 'gaussian', or None."
)

shift, _, _ = skimage_phase_cross_correlation(
ref_for_registration,
Expand All @@ -168,6 +201,124 @@ def phase_cross_correlation(
return shift


def error_minimization_registration(
moving: np.ndarray,
ref: np.ndarray,
y_valid_fraction: float = 0.8,
x_valid_fraction: float = 0.8,
subpixel: bool = True,
) -> np.ndarray:
"""Image registration by exhaustive integer-shift MSE search with quadratic
subpixel refinement.

A central window of size ``(y_valid_fraction * h, x_valid_fraction * w)``
is fixed in the reference image. The moving image is sampled at the same
window position for every integer shift (dy, dx) within the margins
``[-max_dy, max_dy] × [-max_dx, max_dx]``, where the margins are the pixel
gaps between the valid window and the image boundary. No wrap-around pixels
are ever included: the valid window is identical for all shifts.

The resulting 2-D MSE map is fitted with a 2-D quadratic polynomial. The
analytic minimum of that polynomial is returned as the sub-pixel shift.

Parameters
----------
moving : np.ndarray
2-D image to register.
ref : np.ndarray
2-D reference image with the same shape as *moving*.
y_valid_fraction : float
Fraction of the image height occupied by the comparison window.
Values close to 1 leave little margin and therefore a small search range.
x_valid_fraction : float
Same as *y_valid_fraction* along the x (column) axis.
subpixel : bool
If True, perform subpixel refinement using a 2D quadratic fit.

Returns
-------
np.ndarray
Estimated (dy, dx) shift to apply to *moving* so that it aligns with
*ref*.
"""
assert moving.shape == ref.shape, (
"The shapes of the moving and reference images must be the same."
)
h, w = ref.shape

vh = int(round(y_valid_fraction * h))
vw = int(round(x_valid_fraction * w))

# Centre the valid window; margin on each side = max search range
r0 = (h - vh) // 2
c0 = (w - vw) // 2
r1, c1 = r0 + vh, c0 + vw
max_dy, max_dx = r0, c0

if max_dy == 0 and max_dx == 0:
return np.zeros(2)

dy_vals = np.arange(-max_dy, max_dy + 1)
dx_vals = np.arange(-max_dx, max_dx + 1)

ref_crop = ref[r0:r1, c0:c1].astype(float)
moving_f = moving.astype(float)

# Exhaustive integer-shift MSE map
error_map = np.empty((len(dy_vals), len(dx_vals)))
for i, dy in enumerate(dy_vals):
for j, dx in enumerate(dx_vals):
diff = moving_f[r0 + dy : r1 + dy, c0 + dx : c1 + dx] - ref_crop
error_map[i, j] = np.mean(diff * diff)

# Fit quadratic in a local neighbourhood around the integer minimum.
# Neighbourhood half-width: 10% of image size / 2, at least 1 (→ 3×3 minimum).
min_i, min_j = np.unravel_index(np.argmin(error_map), error_map.shape)
if not subpixel:
return -np.array([float(dy_vals[min_i]), float(dx_vals[min_j])])

half_y = max(1, int(round(0.05 * h)))
half_x = max(1, int(round(0.05 * w)))
i_lo = max(0, min_i - half_y)
i_hi = min(len(dy_vals) - 1, min_i + half_y)
j_lo = max(0, min_j - half_x)
j_hi = min(len(dx_vals) - 1, min_j + half_x)
local_dy = dy_vals[i_lo : i_hi + 1]
local_dx = dx_vals[j_lo : j_hi + 1]
local_err = error_map[i_lo : i_hi + 1, j_lo : j_hi + 1]

# The 2-D quadratic has 6 parameters; require ≥3 points in each dimension so
# the design matrix is well-determined and the Hessian is not rank-deficient.
if len(local_dy) >= 3 and len(local_dx) >= 3:
# Fit: f(dy, dx) = a*dy² + b*dx² + c*dy*dx + d*dy + e*dx + g
dy_mesh, dx_mesh = np.meshgrid(local_dy, local_dx, indexing="ij")
dy_f = dy_mesh.ravel()
dx_f = dx_mesh.ravel()
design = np.column_stack(
[dy_f**2, dx_f**2, dy_f * dx_f, dy_f, dx_f, np.ones(len(dy_f))]
)
coeffs, _, _, _ = np.linalg.lstsq(design, local_err.ravel(), rcond=None)
a, b, c, d, e, _ = coeffs

# Analytic minimum: solve Hessian @ [dy_min, dx_min]ᵀ = -gradient
# Hessian = [[2a, c], [c, 2b]]; gradient at origin = [d, e]
hess = np.array([[2.0 * a, c], [c, 2.0 * b]])
try:
if np.all(np.linalg.eigvalsh(hess) > 0):
shift = np.linalg.solve(hess, np.array([-d, -e]))
else:
raise np.linalg.LinAlgError("Hessian not positive definite")
except np.linalg.LinAlgError:
shift = np.array([float(dy_vals[min_i]), float(dx_vals[min_j])])
else:
shift = np.array([float(dy_vals[min_i]), float(dx_vals[min_j])])

# Negate: the MSE is minimised at the offset where moving[r0+dy:] matches
# ref[r0:], but the caller wants the shift to apply to moving so that
# roll(moving, shift) ≈ ref, which is the opposite direction.
return -shift


def normalize_image_01(image: np.ndarray) -> np.ndarray:
"""Normalize image intensities to [0, 1]."""
image = np.nan_to_num(image, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
Expand Down
111 changes: 96 additions & 15 deletions src/eaa/maths.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging

import numpy as np
import scipy.ndimage
import scipy.optimize

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -35,7 +36,7 @@ def fit_gaussian_1d(
y: np.ndarray,
y_threshold: float = 0,
) -> tuple[float, float, float, float, float, float, float]:
"""Fit a 1D Gaussian to the data after subtracting a linear background.
"""Fit a 1D Gaussian to 1D data.

Parameters
----------
Expand All @@ -58,31 +59,111 @@ def fit_gaussian_1d(
"""
x_data = np.array(x, dtype=float)
y_data = np.array(y, dtype=float)
finite_mask = np.isfinite(x_data) & np.isfinite(y_data)
x_data = x_data[finite_mask]
y_data = y_data[finite_mask]
if x_data.size < 5:
logger.error("Too few finite data points for Gaussian fitting. Returning NaN values.")
return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan

order = np.argsort(x_data)
x_data = x_data[order]
y_data = y_data[order]
x_min = float(np.min(x_data))
x_max_input = float(np.max(x_data))
x_span = x_max_input - x_min
if not np.isfinite(x_span) or x_span <= 0:
logger.error("Invalid x range for Gaussian fitting. Returning NaN values.")
return np.nan, np.nan, np.nan, np.nan, np.nan, x_min, x_max_input

y_max, y_min = np.max(y_data), np.min(y_data)
x_max = x_data[np.argmax(y_data)]
offset = x_max
y_range = float(y_max - y_min)
if not np.isfinite(y_range) or np.isclose(y_range, 0.0):
logger.error("Input data are too flat for Gaussian fitting. Returning NaN values.")
return np.nan, np.nan, np.nan, np.nan, np.nan, x_min, x_max_input

smooth_sigma = max(1.0, x_data.size * 0.02)
y_smooth = scipy.ndimage.gaussian_filter1d(y_data, sigma=smooth_sigma, mode="nearest")
y_smooth_max = float(np.max(y_smooth))
y_smooth_min = float(np.min(y_smooth))
y_smooth_range = y_smooth_max - y_smooth_min
if not np.isfinite(y_smooth_range) or np.isclose(y_smooth_range, 0.0):
logger.error("Smoothed data are too flat for Gaussian fitting. Returning NaN values.")
return np.nan, np.nan, np.nan, np.nan, np.nan, x_min, x_max_input

x_peak = float(x_data[np.argmax(y_smooth)])
offset = x_peak
x = x_data - offset
x_max = 0
mask = y_data >= y_min + y_threshold * (y_max - y_min)
a_guess = y_max - y_min
mu_guess = x_max
x_above_thresh = x[y_data > y_min + a_guess * 0.2]
if len(x_above_thresh) >= 3:
sigma_guess = (x_above_thresh.max() - x_above_thresh.min()) / 2
fit_threshold = y_smooth_min + y_threshold * y_smooth_range
mask = y_smooth >= fit_threshold
if int(np.count_nonzero(mask)) < 5:
mask = np.ones_like(y_data, dtype=bool)

positive_weight = np.clip(y_smooth - y_smooth_min, a_min=0.0, a_max=None)
if np.sum(positive_weight) > 0:
mu_guess = float(np.sum(x * positive_weight) / np.sum(positive_weight))
else:
sigma_guess = (x.max() - x.min()) / 2
c_guess = y_min
mu_guess = 0.0

width_mask = y_smooth >= (y_smooth_min + 0.5 * y_smooth_range)
x_above_half = x[width_mask]
if x_above_half.size >= 2:
sigma_guess = float((x_above_half.max() - x_above_half.min()) / 2.355)
else:
sigma_guess = x_span / 6.0
sigma_guess = float(np.clip(sigma_guess, x_span / 100.0, x_span))

a_guess = max(y_smooth_range, np.finfo(float).eps)
c_guess = float(np.median(y_data[~mask])) if np.any(~mask) else float(y_smooth_min)
p0 = [a_guess, mu_guess, sigma_guess, c_guess]
try:
popt, _ = scipy.optimize.curve_fit(gaussian_1d, x[mask], y_data[mask], p0=p0)
except RuntimeError:
lower_bounds = [0.0, float(np.min(x)), max(x_span / 1000.0, 1e-12), float(y_min - y_range)]
upper_bounds = [float(2 * y_range + abs(y_max)), float(np.max(x)), float(2 * x_span), float(y_max + y_range)]

def run_fit(current_mask: np.ndarray, current_p0: list[float]) -> np.ndarray | None:
if int(np.count_nonzero(current_mask)) < 5:
return None
try:
popt, _ = scipy.optimize.curve_fit(
gaussian_1d,
x[current_mask],
y_data[current_mask],
p0=current_p0,
bounds=(lower_bounds, upper_bounds),
maxfev=20000,
)
return popt
except (RuntimeError, ValueError):
return None

popt = run_fit(mask, p0)
if popt is None:
y_smooth_retry = scipy.ndimage.gaussian_filter1d(
y_data,
sigma=max(2.0, x_data.size * 0.05),
mode="nearest",
)
retry_weight = np.clip(y_smooth_retry - np.min(y_smooth_retry), a_min=0.0, a_max=None)
if np.sum(retry_weight) > 0:
mu_retry = float(np.sum(x * retry_weight) / np.sum(retry_weight))
else:
mu_retry = 0.0
retry_mask = y_smooth_retry >= (
np.min(y_smooth_retry) + max(0.1, y_threshold) * (np.max(y_smooth_retry) - np.min(y_smooth_retry))
)
retry_p0 = [a_guess, mu_retry, sigma_guess, c_guess]
popt = run_fit(retry_mask, retry_p0)

if popt is None:
logger.error("Failed to fit Gaussian to data. Returning NaN values.")
return np.nan, np.nan, np.nan, np.nan, np.nan, x_min, x_max_input

y_fit = gaussian_1d(x, *popt)
amplitude = float(popt[0])
sigma = float(popt[2])
mu = float(popt[1])
if sigma <= 0 or not (float(np.min(x)) <= mu <= float(np.max(x))):
logger.error("Gaussian fit parameters are invalid. Returning NaN values.")
return np.nan, np.nan, np.nan, np.nan, np.nan, x_min, x_max_input
if np.isclose(amplitude, 0.0):
normalized_residual = np.nan
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def run(
reference_image,
psize_t=self.image_acquisition_tool.psize_k,
psize_r=self.image_registration_tool.reference_pixel_size,
registration_algorithm_kwargs={"use_hanning_window": True},
registration_algorithm_kwargs={"filtering_method": "hanning"},
)
if check_feature_presence_llm(
task_manager=self,
Expand Down
Loading