From 1a313bf56a0c41d96cf1cf6592fd0b310f2a7fcc Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 6 Mar 2026 15:50:19 -0600 Subject: [PATCH 01/15] FEAT: Gaussian window filtering for phase correlation --- src/eaa/image_proc.py | 49 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 8 deletions(-) diff --git a/src/eaa/image_proc.py b/src/eaa/image_proc.py index a18be28..f01663d 100644 --- a/src/eaa/image_proc.py +++ b/src/eaa/image_proc.py @@ -5,6 +5,7 @@ import matplotlib.pyplot as plt import scipy.ndimage as ndi from scipy import optimize +from scipy.special import erf from skimage.metrics import normalized_mutual_information from skimage.registration import phase_cross_correlation as skimage_phase_cross_correlation from sciagent.message_proc import generate_openai_message @@ -117,10 +118,30 @@ def add_frame_to_pil(image: Image.Image) -> Image.Image: return buffer +def _gaussian_rect_window(shape: tuple[int, int], decay_fraction: float = 0.2) -> np.ndarray: + """2D Gaussian-softened rectangle window. + + The mask is 1 in the center and decays smoothly to 0 at each edge. + The decay is the convolution of a step function with a Gaussian, implemented + via the error function. The transition from 1 to 0 spans ``decay_fraction`` + of the image size on each side. 
+ """ + def _win_1d(size: int) -> np.ndarray: + decay_len = decay_fraction * size + sigma = decay_len / 4.0 + s2 = sigma * np.sqrt(2.0) + x = np.arange(size, dtype=float) + left = 0.5 * (1.0 + erf((x - decay_len / 2.0) / s2)) + right = 0.5 * (1.0 - erf((x - (size - 1.0 - decay_len / 2.0)) / s2)) + return np.minimum(left, right) + + return np.outer(_win_1d(shape[0]), _win_1d(shape[1])) + + def phase_cross_correlation( - moving: np.ndarray, - ref: np.ndarray, - use_hanning_window: bool = True, + moving: np.ndarray, + ref: np.ndarray, + filtering_method: Optional[Literal["hanning", "gaussian"]] = "hanning", upsample_factor: int = 1, ) -> np.ndarray | Tuple[np.ndarray, float]: """Phase correlation with windowing. The result gives @@ -133,9 +154,12 @@ def phase_cross_correlation( A 2D image. ref : np.ndarray A 2D image. - use_hanning_window : bool, optional - If True, a Hanning window is used to smooth the images before the - correlation is computed. + filtering_method : {"hanning", "gaussian"} or None, optional + Window function applied to both images before phase correlation to + reduce spectral leakage. ``"hanning"`` uses a standard Hanning window. + ``"gaussian"`` uses a Gaussian-softened rectangle that is 1 in the + centre and decays to 0 at each edge over 20 % of the image size. + Pass ``None`` to disable windowing. upsample_factor : int, optional Upsampling factor for subpixel accuracy in phase correlation. A value of 1 yields pixel-level precision. 
@@ -150,15 +174,24 @@ def phase_cross_correlation( ) moving = moving - moving.mean() ref = ref - ref.mean() - if use_hanning_window: + if filtering_method == "hanning": win_y = np.hanning(moving.shape[0]) win_x = np.hanning(moving.shape[1]) win = np.outer(win_y, win_x) moving_for_registration = moving * win ref_for_registration = ref * win - else: + elif filtering_method == "gaussian": + win = _gaussian_rect_window(moving.shape, decay_fraction=0.2) + moving_for_registration = moving * win + ref_for_registration = ref * win + elif filtering_method is None: moving_for_registration = moving ref_for_registration = ref + else: + raise ValueError( + f"Unknown filtering_method {filtering_method!r}. " + "Use 'hanning', 'gaussian', or None." + ) shift, _, _ = skimage_phase_cross_correlation( ref_for_registration, From 701055b24f2374d245e0592259e83e05400f01af Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 6 Mar 2026 15:51:54 -0600 Subject: [PATCH 02/15] FEAT: error minimization registration --- src/eaa/image_proc.py | 112 +++++++++++++++++++++++++++ src/eaa/tool/imaging/registration.py | 38 ++++++--- 2 files changed, 138 insertions(+), 12 deletions(-) diff --git a/src/eaa/image_proc.py b/src/eaa/image_proc.py index f01663d..9cca210 100644 --- a/src/eaa/image_proc.py +++ b/src/eaa/image_proc.py @@ -201,6 +201,118 @@ def phase_cross_correlation( return shift +def error_minimization_registration( + moving: np.ndarray, + ref: np.ndarray, + y_valid_fraction: float = 0.8, + x_valid_fraction: float = 0.8, +) -> np.ndarray: + """Image registration by exhaustive integer-shift MSE search with quadratic + subpixel refinement. + + A central window of size ``(y_valid_fraction * h, x_valid_fraction * w)`` + is fixed in the reference image. The moving image is sampled at the same + window position for every integer shift (dy, dx) within the margins + ``[-max_dy, max_dy] × [-max_dx, max_dx]``, where the margins are the pixel + gaps between the valid window and the image boundary. 
No wrap-around pixels + are ever included: the valid window is identical for all shifts. + + The resulting 2-D MSE map is fitted with a 2-D quadratic polynomial. The + analytic minimum of that polynomial is returned as the sub-pixel shift. + + Parameters + ---------- + moving : np.ndarray + 2-D image to register. + ref : np.ndarray + 2-D reference image with the same shape as *moving*. + y_valid_fraction : float + Fraction of the image height occupied by the comparison window. + Values close to 1 leave little margin and therefore a small search range. + x_valid_fraction : float + Same as *y_valid_fraction* along the x (column) axis. + + Returns + ------- + np.ndarray + Estimated (dy, dx) shift to apply to *moving* so that it aligns with + *ref*. + """ + assert moving.shape == ref.shape, ( + "The shapes of the moving and reference images must be the same." + ) + h, w = ref.shape + + vh = int(round(y_valid_fraction * h)) + vw = int(round(x_valid_fraction * w)) + + # Centre the valid window; margin on each side = max search range + r0 = (h - vh) // 2 + c0 = (w - vw) // 2 + r1, c1 = r0 + vh, c0 + vw + max_dy, max_dx = r0, c0 + + if max_dy == 0 and max_dx == 0: + return np.zeros(2) + + dy_vals = np.arange(-max_dy, max_dy + 1) + dx_vals = np.arange(-max_dx, max_dx + 1) + + ref_crop = ref[r0:r1, c0:c1].astype(float) + moving_f = moving.astype(float) + + # Exhaustive integer-shift MSE map + error_map = np.empty((len(dy_vals), len(dx_vals))) + for i, dy in enumerate(dy_vals): + for j, dx in enumerate(dx_vals): + diff = moving_f[r0 + dy : r1 + dy, c0 + dx : c1 + dx] - ref_crop + error_map[i, j] = np.mean(diff * diff) + + # Fit quadratic in a local neighbourhood around the integer minimum. + # Neighbourhood half-width: 10% of image size / 2, at least 1 (→ 3×3 minimum). 
+ min_i, min_j = np.unravel_index(np.argmin(error_map), error_map.shape) + half_y = max(1, int(round(0.05 * h))) + half_x = max(1, int(round(0.05 * w))) + i_lo = max(0, min_i - half_y) + i_hi = min(len(dy_vals) - 1, min_i + half_y) + j_lo = max(0, min_j - half_x) + j_hi = min(len(dx_vals) - 1, min_j + half_x) + local_dy = dy_vals[i_lo : i_hi + 1] + local_dx = dx_vals[j_lo : j_hi + 1] + local_err = error_map[i_lo : i_hi + 1, j_lo : j_hi + 1] + + # The 2-D quadratic has 6 parameters; require ≥3 points in each dimension so + # the design matrix is well-determined and the Hessian is not rank-deficient. + if len(local_dy) >= 3 and len(local_dx) >= 3: + # Fit: f(dy, dx) = a*dy² + b*dx² + c*dy*dx + d*dy + e*dx + g + dy_mesh, dx_mesh = np.meshgrid(local_dy, local_dx, indexing="ij") + dy_f = dy_mesh.ravel() + dx_f = dx_mesh.ravel() + design = np.column_stack( + [dy_f**2, dx_f**2, dy_f * dx_f, dy_f, dx_f, np.ones(len(dy_f))] + ) + coeffs, _, _, _ = np.linalg.lstsq(design, local_err.ravel(), rcond=None) + a, b, c, d, e, _ = coeffs + + # Analytic minimum: solve Hessian @ [dy_min, dx_min]ᵀ = -gradient + # Hessian = [[2a, c], [c, 2b]]; gradient at origin = [d, e] + hess = np.array([[2.0 * a, c], [c, 2.0 * b]]) + try: + if np.all(np.linalg.eigvalsh(hess) > 0): + shift = np.linalg.solve(hess, np.array([-d, -e])) + else: + raise np.linalg.LinAlgError("Hessian not positive definite") + except np.linalg.LinAlgError: + shift = np.array([float(dy_vals[min_i]), float(dx_vals[min_j])]) + else: + shift = np.array([float(dy_vals[min_i]), float(dx_vals[min_j])]) + + # Negate: the MSE is minimised at the offset where moving[r0+dy:] matches + # ref[r0:], but the caller wants the shift to apply to moving so that + # roll(moving, shift) ≈ ref, which is the opposite direction. 
+ return -shift + + def normalize_image_01(image: np.ndarray) -> np.ndarray: """Normalize image intensities to [0, 1].""" image = np.nan_to_num(image, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32) diff --git a/src/eaa/tool/imaging/registration.py b/src/eaa/tool/imaging/registration.py index 80f52ae..82570d8 100644 --- a/src/eaa/tool/imaging/registration.py +++ b/src/eaa/tool/imaging/registration.py @@ -15,7 +15,11 @@ from sciagent.api.memory import MemoryManagerConfig from eaa.tool.imaging.acquisition import AcquireImage -from eaa.image_proc import phase_cross_correlation, translation_nmi_registration +from eaa.image_proc import ( + error_minimization_registration, + phase_cross_correlation, + translation_nmi_registration, +) logger = logging.getLogger(__name__) @@ -37,7 +41,7 @@ def __init__( reference_image: np.ndarray = None, reference_pixel_size: float = 1.0, image_coordinates_origin: Literal["top_left", "center"] = "top_left", - registration_method: Literal["phase_correlation", "sift", "mutual_information", "llm"] = "phase_correlation", + registration_method: Literal["phase_correlation", "sift", "mutual_information", "llm", "error_minimization"] = "phase_correlation", log_scale: bool = False, require_approval: bool = False, *args, @@ -66,10 +70,12 @@ def __init__( When this argument is set to "center", the test image is padded/cropped centrally. When it is set to "top_left", the test image is on the bottom and right sides. - registration_method : Literal["phase_correlation", "sift", "mutual_information"], optional + registration_method : Literal["phase_correlation", "sift", "mutual_information", "llm", "error_minimization"], optional The method used to estimate translational offsets. "phase_correlation" - uses phase correlation, "sift" uses feature matching, and - "mutual_information" uses pyramid-based normalized mutual information. 
+ uses phase correlation, "sift" uses feature matching, + "mutual_information" uses pyramid-based normalized mutual information, + and "error_minimization" uses exhaustive integer-shift MSE search with + local quadratic subpixel refinement. log_scale : bool, optional If True, images are transformed as `log10(x + 1)` before registration. """ @@ -329,12 +335,12 @@ def register_images_llm(self, image_t: np.ndarray, image_r: np.ndarray) -> np.nd ], dtype=float) def register_images( - self, - image_t: np.ndarray, - image_r: np.ndarray, - psize_t: float, + self, + image_t: np.ndarray, + image_r: np.ndarray, + psize_t: float, psize_r: float, - registration_method: Optional[Literal["phase_correlation", "sift", "mutual_information", "llm"]] = None, + registration_method: Optional[Literal["phase_correlation", "sift", "mutual_information", "llm", "error_minimization"]] = None, registration_algorithm_kwargs: Optional[dict[str, Any]] = None, ) -> np.ndarray | Tuple[np.ndarray, float] | str: """ @@ -369,6 +375,10 @@ def register_images( - `max_iter` (int, default: `60`) - `tol` (float, default: `1e-4`) + - `registration_method="error_minimization"`: + - `y_valid_fraction` (float, default: `0.8`) + - `x_valid_fraction` (float, default: `0.8`) + - `registration_method="sift"` or `"llm"`: - No algorithm kwargs are currently supported; pass `None` or `{}`. 
@@ -388,11 +398,11 @@ def register_images( # Resize the target image to have the same pixel size as the reference image image_t = ndi.zoom(image_t, psize_t / psize_r) - if method in {"phase_correlation", "mutual_information"}: + if method in {"phase_correlation", "mutual_information", "error_minimization"}: image_t = self.reconcile_image_shape(image_t, image_r.shape) if method == "phase_correlation": - phase_kwargs = {"use_hanning_window": True} + phase_kwargs = {"filtering_method": "hanning"} phase_kwargs.update(algorithm_kwargs) offset = phase_cross_correlation( image_t, @@ -414,6 +424,10 @@ def register_images( ref=image_r, **mi_kwargs, ) + elif method == "error_minimization": + em_kwargs = {"y_valid_fraction": 0.8, "x_valid_fraction": 0.8} + em_kwargs.update(algorithm_kwargs) + offset = error_minimization_registration(image_t, image_r, **em_kwargs) elif method == "sift": if len(algorithm_kwargs) > 0: raise ValueError( From 22425b524059e4fb24883878a8013c0923e8ce9f Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 6 Mar 2026 15:52:07 -0600 Subject: [PATCH 03/15] FIX: take abs of Gaussian FWHM --- src/eaa/tool/imaging/acquisition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/eaa/tool/imaging/acquisition.py b/src/eaa/tool/imaging/acquisition.py index e0335d6..cc36dff 100644 --- a/src/eaa/tool/imaging/acquisition.py +++ b/src/eaa/tool/imaging/acquisition.py @@ -519,7 +519,7 @@ def acquire_line_scan( fwhm = np.nan else: val_gauss = eaa.maths.gaussian_1d(ds, a, mu, sigma, c) - fwhm = 2.35 * sigma + fwhm = 2.35 * np.abs(sigma) show_scan_line = self.image_k is not None and len(self.image_acquisition_call_history) > 0 show_first_scan_line = ( From ad1ac9d98149a8a03bdc7602b42a0ba47c561bb8 Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 6 Mar 2026 15:52:37 -0600 Subject: [PATCH 04/15] FEAT: dual drift selection using linear fit distance --- .../tuning/analytical_focusing.py | 286 ++++++++++++++---- 1 file changed, 221 insertions(+), 65 
deletions(-) diff --git a/src/eaa/task_manager/tuning/analytical_focusing.py b/src/eaa/task_manager/tuning/analytical_focusing.py index 2b58ed6..1b8fe6e 100644 --- a/src/eaa/task_manager/tuning/analytical_focusing.py +++ b/src/eaa/task_manager/tuning/analytical_focusing.py @@ -61,9 +61,9 @@ def __init__( registration_algorithm_kwargs: Optional[dict[str, Any]] = None, run_line_scan_checker: bool = True, run_offset_calibration: bool = True, - use_linear_drift_prediction: bool = False, - n_parameter_drift_points_before_prediction: int = 3, line_scan_predictor_tool: Optional[LineScanPredictor] = None, + dual_drift_selection_priming_iterations: int = 3, + dual_drift_estimation_primary_source: Literal["registration", "line_scan_predictor"] = "line_scan_predictor", *args, **kwargs ): """Analytical scanning microscope focusing task manager driven @@ -132,22 +132,29 @@ def __init__( If True, run 2D image acquisition and image-registration-based offset calibration. If False, the loop only performs parameter setting, line scan, and optimization updates/suggestions. - use_linear_drift_prediction : bool, optional - If True, fit linear models to predict image-acquisition drift - (y and x separately) as a function of optics parameters and use - the predicted positions for subsequent 2D scans once enough - parameter-drift samples have been collected. - n_parameter_drift_points_before_prediction : int, optional - Number of parameter-drift samples to collect before using linear - drift prediction for image acquisitions. + dual_drift_selection_priming_iterations : int, optional + Number of parameter–drift samples to collect (warm-up phase) before + the linear model is used to arbitrate between image registration and + line scan predictor drift estimates. Only relevant when + ``line_scan_predictor_tool`` is provided. 
line_scan_predictor_tool : LineScanPredictor, optional
+            If provided, both image registration and the line scan predictor
+            are run after each 2D acquisition to estimate drift. For the
+            first ``dual_drift_selection_priming_iterations`` parameter
+            adjustments only the source selected by
+            ``dual_drift_estimation_primary_source`` is used; concurrently a
+            linear model is fit mapping optics parameters to cumulative drift
+            from the initial position. Afterwards the linear model prediction
+            is used to arbitrate: the drift estimate (from either method)
+            that is closer to the linear model prediction is chosen as the
+            correct drift, and the linear model is then updated with that
+            chosen value. Requires ``run_offset_calibration=True``.
+        dual_drift_estimation_primary_source : {"registration", "line_scan_predictor"}
+            Which drift-estimation method to trust exclusively during the first
+            ``dual_drift_selection_priming_iterations`` iterations when
+            ``line_scan_predictor_tool`` is provided. After that point both
+            methods are evaluated and the linear model arbitrates. Defaults
+            to ``"line_scan_predictor"``.
""" if acquisition_tool is None: raise ValueError("`acquisition_tool` must be provided.") @@ -186,7 +193,6 @@ def __init__( self.run_line_scan_checker = run_line_scan_checker self.run_offset_calibration = run_offset_calibration - self.use_linear_drift_prediction = use_linear_drift_prediction if line_scan_predictor_tool is not None and not run_offset_calibration: raise ValueError( "`line_scan_predictor_tool` requires `run_offset_calibration=True` " @@ -194,16 +200,18 @@ def __init__( ) self.line_scan_predictor_tool = line_scan_predictor_tool - self.n_parameter_drift_points_before_prediction = ( - n_parameter_drift_points_before_prediction + self.dual_drift_estimation_primary_source = dual_drift_estimation_primary_source + self.dual_drift_selection_priming_iterations = ( + dual_drift_selection_priming_iterations ) - if self.n_parameter_drift_points_before_prediction < 1: + if self.dual_drift_selection_priming_iterations < 1: raise ValueError( - "`n_parameter_drift_points_before_prediction` must be >= 1." + "`dual_drift_selection_priming_iterations` must be >= 1." ) self.drift_model_y = MultivariateLinearRegression() self.drift_model_x = MultivariateLinearRegression() self.initial_image_acquisition_position: np.ndarray | None = None + self.initial_line_scan_position: np.ndarray | None = None super().__init__( llm_config=llm_config, @@ -376,6 +384,8 @@ def initialize_kwargs_buffers( ): self.line_scan_kwargs = copy.deepcopy(initial_line_scan_kwargs) self.image_acquisition_kwargs = copy.deepcopy(initial_2d_scan_kwargs) + if self.line_scan_predictor_tool is not None and self.line_scan_kwargs is not None: + self.initial_line_scan_position = self.extract_scan_position(self.line_scan_kwargs) def run_line_scan(self) -> float: """Run a line scan and return the FWHM of the Gaussian fit. 
@@ -467,24 +477,12 @@ def extract_image_acquisition_position(self, kwargs: dict[str, float]) -> np.nda ) return np.array([float(kwargs[y_arg]), float(kwargs[x_arg])], dtype=float) - def should_apply_linear_drift_prediction(self) -> bool: - if not self.run_offset_calibration: - return False - if not self.use_linear_drift_prediction: - return False - if self.initial_image_acquisition_position is None: - return False - return ( - self.drift_model_y.get_n_parameter_drift_points_collected() - >= self.n_parameter_drift_points_before_prediction - ) - def update_linear_drift_models( self, parameters: np.ndarray, current_position_yx: np.ndarray | None = None, ): - if not self.use_linear_drift_prediction: + if self.line_scan_predictor_tool is None: return if self.initial_image_acquisition_position is None: return @@ -546,8 +544,8 @@ def apply_predicted_image_acquisition_position(self, parameters: np.ndarray): dtype=float, ) delta_pos = predicted_pos - current_pos - self.apply_offset_to_image_acquisition_kwargs(-delta_pos) - self.apply_offset_to_line_scan_kwargs(-delta_pos) + self.apply_offset_to_image_acquisition_kwargs(delta_pos) + self.apply_offset_to_line_scan_kwargs(delta_pos) self.record_system_message( "Using linear drift prediction for acquisition positions: " f"```Predicted 2D scan yx = {predicted_pos.tolist()}" @@ -812,15 +810,15 @@ def find_offset(self) -> np.ndarray: def apply_offset_to_line_scan_kwargs(self, offset: np.ndarray): for arg in self.line_scan_tool_x_coordinate_args: - self.line_scan_kwargs[arg] -= offset[1] + self.line_scan_kwargs[arg] += offset[1] for arg in self.line_scan_tool_y_coordinate_args: - self.line_scan_kwargs[arg] -= offset[0] + self.line_scan_kwargs[arg] += offset[0] def apply_offset_to_image_acquisition_kwargs(self, offset: np.ndarray): for arg in self.image_acquisition_tool_x_coordinate_args: - self.image_acquisition_kwargs[arg] -= offset[1] + self.image_acquisition_kwargs[arg] += offset[1] for arg in 
self.image_acquisition_tool_y_coordinate_args: - self.image_acquisition_kwargs[arg] -= offset[0] + self.image_acquisition_kwargs[arg] += offset[0] def collect_initial_data_optimization_tool( self, @@ -882,18 +880,16 @@ def rollback_and_shrink_delta(message_prefix: str) -> np.ndarray: self.record_system_message(f"Setting parameters to new value:```{x_current}```") self.param_setting_tool.set_parameters(x_current) if self.run_offset_calibration: - if self.should_apply_linear_drift_prediction(): - self.apply_predicted_image_acquisition_position(x_current) self.run_2d_scan() if self.line_scan_predictor_tool is not None: - self.apply_line_scan_predictor_offset() + self._apply_dual_drift_estimation(x_current) else: line_scan_pos_offset, alignment_offset = self.find_offset() if np.any(np.isnan(line_scan_pos_offset)): x_current = rollback_and_shrink_delta("Image registration failed (NaN offset).") continue - self.apply_offset_to_line_scan_kwargs(line_scan_pos_offset) - self.apply_offset_to_image_acquisition_kwargs(alignment_offset) + self.apply_offset_to_line_scan_kwargs(-line_scan_pos_offset) + self.apply_offset_to_image_acquisition_kwargs(-alignment_offset) self.update_linear_drift_models(x_current) self.record_linear_drift_model_visualizations() try: @@ -906,27 +902,187 @@ def rollback_and_shrink_delta(message_prefix: str) -> np.ndarray: self.update_optimization_model(fwhm) return - def apply_line_scan_predictor_offset(self) -> None: - """Update scan positions using the line scan predictor. + def _get_lsp_drift(self) -> np.ndarray: + """Return cumulative drift from initial line scan position via the LSP. + + Calls the line scan predictor and subtracts the stored initial line + scan position to produce a drift vector in the same physical units + used throughout the tracker. + + Returns + ------- + np.ndarray + Shape ``(2,)`` array ``[drift_y, drift_x]`` from the initial line + scan center to the predictor's estimate. 
+ """ + lsp_json = json.loads(self.line_scan_predictor_tool.predict_line_scan_position()) + predicted_center = np.array( + [lsp_json["center_y"], lsp_json["center_x"]], dtype=float + ) + return predicted_center - self.initial_line_scan_position + + def _get_reg_drift(self) -> np.ndarray | None: + """Return cumulative drift from initial acquisition position via image registration. + + Runs image registration between the two most recent 2D acquisitions and + converts the incremental alignment offset into a cumulative drift from + the initial image acquisition position. + + Returns + ------- + np.ndarray or None + Shape ``(2,)`` array ``[drift_y, drift_x]``, or ``None`` if + registration returned NaN (failed). + """ + _line_scan_pos_offset, alignment_offset = self.find_offset() + if np.any(np.isnan(alignment_offset)): + return None + # apply_offset_to_image_acquisition_kwargs(o) does position += o, + # so the registration caller passes -alignment_offset and the new tracked + # position would be current_acq_pos - alignment_offset. + current_acq_pos = self.extract_image_acquisition_position( + self.image_acquisition_kwargs + ) + return current_acq_pos - alignment_offset - self.initial_image_acquisition_position + + def _select_drift( + self, + lsp_drift: np.ndarray, + reg_drift: np.ndarray | None, + x_current: np.ndarray, + ) -> tuple[np.ndarray, str]: + """Select the drift estimate to use, logging the decision. + + During the warm-up phase (fewer than ``dual_drift_selection_priming_iterations`` + samples collected) the primary source is used unconditionally. + Afterwards the linear model prediction arbitrates: the drift estimate + closer in Euclidean distance to the model prediction is chosen. + + Parameters + ---------- + lsp_drift : np.ndarray + Cumulative drift from the line scan predictor. + reg_drift : np.ndarray or None + Cumulative drift from image registration, or ``None`` on failure. 
+ x_current : np.ndarray + Current optics parameter vector (used in the arbitration phase). + + Returns + ------- + chosen_drift : np.ndarray + chosen_source : str + ``"line_scan_predictor"`` or ``"registration"``. + """ + n_collected = self.drift_model_y.get_n_parameter_drift_points_collected() + n_needed = self.dual_drift_selection_priming_iterations + + if n_collected < n_needed: + if self.dual_drift_estimation_primary_source == "line_scan_predictor" or reg_drift is None: + chosen_drift, chosen_source = lsp_drift, "line_scan_predictor" + else: + chosen_drift, chosen_source = reg_drift, "registration" + if reg_drift is None: + self.record_system_message( + "Dual drift estimation: image registration returned NaN; " + "falling back to line scan predictor." + ) + self.record_system_message( + f"Dual drift estimation (primary phase, n={n_collected}/{n_needed}): " + f"```using {chosen_source}\n" + f"lsp_drift={lsp_drift.tolist()}\n" + f"reg_drift={reg_drift.tolist() if reg_drift is not None else 'NaN'}```" + ) + else: + x_in = np.array(x_current, dtype=float).reshape(1, -1).tolist() + model_drift = np.array( + [ + float(self.drift_model_y.predict(x_in)[0][0]), + float(self.drift_model_x.predict(x_in)[0][0]), + ], + dtype=float, + ) + dist_lsp = float(np.linalg.norm(lsp_drift - model_drift)) + dist_reg = ( + float(np.linalg.norm(reg_drift - model_drift)) + if reg_drift is not None + else np.inf + ) + if dist_lsp <= dist_reg: + chosen_drift, chosen_source = lsp_drift, "line_scan_predictor" + else: + chosen_drift, chosen_source = reg_drift, "registration" + dist_reg_str = f"{dist_reg:.4f}" if reg_drift is not None else "inf" + self.record_system_message( + f"Dual drift estimation (arbitration phase):\n" + f"```model_drift={model_drift.tolist()}\n" + f"lsp_drift={lsp_drift.tolist()} (dist={dist_lsp:.4f})\n" + f"reg_drift={reg_drift.tolist() if reg_drift is not None else 'NaN'} (dist={dist_reg_str})\n" + f"Chosen: {chosen_source}```" + ) + + return chosen_drift, 
chosen_source + + def _apply_chosen_drift(self, chosen_drift: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Set scan positions to ``initial_position + chosen_drift`` for both trackers. - Calls :meth:`LineScanPredictor.predict_line_scan_position` to obtain - the predicted line scan center in the current image, computes the - drift relative to the current line scan center, then shifts both - ``line_scan_kwargs`` and ``image_acquisition_kwargs`` by that drift so - that the two sets of coordinates stay in sync. + ``apply_offset_to_*_kwargs(o)`` subtracts ``o`` from the stored + position, so ``-(target - current)`` is passed to move the position + forward to the target. + + Parameters + ---------- + chosen_drift : np.ndarray + Cumulative drift from the initial position to apply. + + Returns + ------- + target_ls_pos : np.ndarray + target_acq_pos : np.ndarray """ - current_center = self.extract_scan_position(self.line_scan_kwargs) - result = json.loads(self.line_scan_predictor_tool.predict_line_scan_position()) - predicted_center = np.array([result["center_y"], result["center_x"]], dtype=float) - drift = predicted_center - current_center + current_ls_pos = self.extract_scan_position(self.line_scan_kwargs) + target_ls_pos = self.initial_line_scan_position + chosen_drift + self.apply_offset_to_line_scan_kwargs(target_ls_pos - current_ls_pos) + + current_acq_pos = self.extract_image_acquisition_position(self.image_acquisition_kwargs) + target_acq_pos = self.initial_image_acquisition_position + chosen_drift + self.apply_offset_to_image_acquisition_kwargs(target_acq_pos - current_acq_pos) + self.record_system_message( - f"Line scan predictor: current center={current_center.tolist()}, " - f"predicted center={predicted_center.tolist()}, " - f"drift={drift.tolist()}" + f"Applied dual-drift correction:\n" + f"```chosen_drift={chosen_drift.tolist()}\n" + f"new line_scan_pos={target_ls_pos.tolist()}\n" + f"new acq_pos={target_acq_pos.tolist()}\n````" ) - # 
apply_offset_to_*_kwargs(o) does position -= o; passing -drift gives position += drift. - self.apply_offset_to_line_scan_kwargs(-drift) - self.apply_offset_to_image_acquisition_kwargs(-drift) + return target_ls_pos, target_acq_pos + + def _apply_dual_drift_estimation(self, x_current: np.ndarray) -> None: + """Run both drift estimators, select the best estimate, and apply it. + + Orchestrates :meth:`_get_lsp_drift`, :meth:`_get_reg_drift`, + :meth:`_select_drift`, and :meth:`_apply_chosen_drift`, then updates + the linear drift model for the next arbitration round. + + Parameters + ---------- + x_current : np.ndarray + Current optics parameter vector. + """ + if self.initial_line_scan_position is None or self.initial_image_acquisition_position is None: + raise RuntimeError( + "_apply_dual_drift_estimation requires initial scan positions to be set. " + "Ensure initialize_kwargs_buffers and run_2d_scan have been called first." + ) + + lsp_drift = self._get_lsp_drift() + reg_drift = self._get_reg_drift() + chosen_drift, _chosen_source = self._select_drift(lsp_drift, reg_drift, x_current) + _target_ls_pos, target_acq_pos = self._apply_chosen_drift(chosen_drift) + + # Update model with (parameters -> chosen cumulative drift). Pass the + # target position directly so the model is consistent regardless of + # whether kwargs have been flushed to the underlying tool yet. + self.update_linear_drift_models(x_current, current_position_yx=target_acq_pos) + self.record_linear_drift_model_visualizations() def apply_user_correction_offset(self) -> bool: message = ( @@ -945,8 +1101,8 @@ def apply_user_correction_offset(self) -> bool: except ValueError: logger.info("Invalid offset values. 
Use numeric values like 'y,x'.") continue - self.apply_offset_to_line_scan_kwargs(offset) - self.apply_offset_to_image_acquisition_kwargs(offset) + self.apply_offset_to_line_scan_kwargs(-offset) + self.apply_offset_to_image_acquisition_kwargs(-offset) correction_message = f"Applied user correction offset: {offset.tolist()}" logger.info(correction_message) self.record_system_message(correction_message) From 6a0d73f27440de67952fbd60f7c53453ba7c6010 Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 6 Mar 2026 16:47:29 -0600 Subject: [PATCH 05/15] FEAT: allow disabling subpixel in error minimization registration --- src/eaa/image_proc.py | 6 ++++++ src/eaa/tool/imaging/registration.py | 3 ++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/eaa/image_proc.py b/src/eaa/image_proc.py index 9cca210..6962dfd 100644 --- a/src/eaa/image_proc.py +++ b/src/eaa/image_proc.py @@ -206,6 +206,7 @@ def error_minimization_registration( ref: np.ndarray, y_valid_fraction: float = 0.8, x_valid_fraction: float = 0.8, + subpixel: bool = True, ) -> np.ndarray: """Image registration by exhaustive integer-shift MSE search with quadratic subpixel refinement. @@ -231,6 +232,8 @@ def error_minimization_registration( Values close to 1 leave little margin and therefore a small search range. x_valid_fraction : float Same as *y_valid_fraction* along the x (column) axis. + subpixel : bool + If True, perform subpixel refinement using a 2D quadratic fit. Returns ------- @@ -271,6 +274,9 @@ def error_minimization_registration( # Fit quadratic in a local neighbourhood around the integer minimum. # Neighbourhood half-width: 10% of image size / 2, at least 1 (→ 3×3 minimum). 
min_i, min_j = np.unravel_index(np.argmin(error_map), error_map.shape) + if not subpixel: + return -np.array([float(dy_vals[min_i]), float(dx_vals[min_j])]) + half_y = max(1, int(round(0.05 * h))) half_x = max(1, int(round(0.05 * w))) i_lo = max(0, min_i - half_y) diff --git a/src/eaa/tool/imaging/registration.py b/src/eaa/tool/imaging/registration.py index 82570d8..7e8f38e 100644 --- a/src/eaa/tool/imaging/registration.py +++ b/src/eaa/tool/imaging/registration.py @@ -378,6 +378,7 @@ def register_images( - `registration_method="error_minimization"`: - `y_valid_fraction` (float, default: `0.8`) - `x_valid_fraction` (float, default: `0.8`) + - `subpixel` (bool, default: `True`) - `registration_method="sift"` or `"llm"`: - No algorithm kwargs are currently supported; pass `None` or `{}`. @@ -425,7 +426,7 @@ def register_images( **mi_kwargs, ) elif method == "error_minimization": - em_kwargs = {"y_valid_fraction": 0.8, "x_valid_fraction": 0.8} + em_kwargs = {"y_valid_fraction": 0.8, "x_valid_fraction": 0.8, "subpixel": True} em_kwargs.update(algorithm_kwargs) offset = error_minimization_registration(image_t, image_r, **em_kwargs) elif method == "sift": From f02f42416835f55fb41424a9acf3e42ce6b27d6f Mon Sep 17 00:00:00 2001 From: Ming Du Date: Sat, 7 Mar 2026 15:09:37 -0600 Subject: [PATCH 06/15] FIX: merge the routines for the registration branch and neural network branch; rename `find_offset` to `find_position_correction`, unify convention as correction == current_acq_pos - ref_acq_pos --- .../tuning/analytical_focusing.py | 230 +++++++----------- 1 file changed, 92 insertions(+), 138 deletions(-) diff --git a/src/eaa/task_manager/tuning/analytical_focusing.py b/src/eaa/task_manager/tuning/analytical_focusing.py index 1b8fe6e..be4bc68 100644 --- a/src/eaa/task_manager/tuning/analytical_focusing.py +++ b/src/eaa/task_manager/tuning/analytical_focusing.py @@ -324,7 +324,7 @@ def run( self.run_line_scan() # Initialize optimization tool. 
- self.collect_initial_data_optimization_tool( + self.collect_initial_data_for_optimization_tool( current_x=np.array(list(self.initial_parameters.values())), sampling_range=initial_sampling_window_size, n=n_initial_points, @@ -384,8 +384,8 @@ def initialize_kwargs_buffers( ): self.line_scan_kwargs = copy.deepcopy(initial_line_scan_kwargs) self.image_acquisition_kwargs = copy.deepcopy(initial_2d_scan_kwargs) - if self.line_scan_predictor_tool is not None and self.line_scan_kwargs is not None: - self.initial_line_scan_position = self.extract_scan_position(self.line_scan_kwargs) + if self.line_scan_kwargs is not None: + self.initial_line_scan_position = self.extract_line_scan_position(self.line_scan_kwargs) def run_line_scan(self) -> float: """Run a line scan and return the FWHM of the Gaussian fit. @@ -451,7 +451,7 @@ def parse_json_from_response(self, response_text: str) -> dict[str, Any]: raise ValueError(f"Unable to parse JSON from response: {response_text}") return json.loads(match.group(0)) - def extract_scan_position(self, kwargs: dict[str, float]) -> np.ndarray: + def extract_line_scan_position(self, kwargs: dict[str, float]) -> np.ndarray: if len(self.line_scan_tool_x_coordinate_args) == 0 or len(self.line_scan_tool_y_coordinate_args) == 0: raise ValueError("Line scan coordinate args must not be empty.") x_arg = self.line_scan_tool_x_coordinate_args[0] @@ -484,15 +484,15 @@ def update_linear_drift_models( ): if self.line_scan_predictor_tool is None: return - if self.initial_image_acquisition_position is None: + if self.initial_line_scan_position is None: return if current_position_yx is None: - current_position_yx = self.extract_image_acquisition_position( - self.image_acquisition_kwargs + current_position_yx = self.extract_line_scan_position( + self.line_scan_kwargs ) - delta_yx = current_position_yx - self.initial_image_acquisition_position + delta_yx = current_position_yx - self.initial_line_scan_position x_train = np.array(parameters, 
dtype=float).reshape(1, -1).tolist() self.drift_model_y.update(x=x_train, y=[[float(delta_yx[0])]]) self.drift_model_x.update(x=x_train, y=[[float(delta_yx[1])]]) @@ -534,24 +534,6 @@ def record_linear_drift_model_visualizations(self) -> None: image_path=image_paths, ) - def apply_predicted_image_acquisition_position(self, parameters: np.ndarray): - current_pos = self.extract_image_acquisition_position(self.image_acquisition_kwargs) - x_in = np.array(parameters, dtype=float).reshape(1, -1).tolist() - delta_y = float(self.drift_model_y.predict(x_in)[0][0]) - delta_x = float(self.drift_model_x.predict(x_in)[0][0]) - predicted_pos = self.initial_image_acquisition_position + np.array( - [delta_y, delta_x], - dtype=float, - ) - delta_pos = predicted_pos - current_pos - self.apply_offset_to_image_acquisition_kwargs(delta_pos) - self.apply_offset_to_line_scan_kwargs(delta_pos) - self.record_system_message( - "Using linear drift prediction for acquisition positions: " - f"```Predicted 2D scan yx = {predicted_pos.tolist()}" - f"Offset from last = {delta_pos.tolist()}```" - ) - def build_line_scan_precheck_message( self, line_scan_result: dict[str, Any], @@ -721,8 +703,8 @@ def check_line_scan( old_line_scan_kwargs = copy.deepcopy(self.line_scan_kwargs) merged_line_scan_kwargs = copy.deepcopy(self.line_scan_kwargs) merged_line_scan_kwargs.update(new_line_scan_kwargs) - old_scan_position = self.extract_scan_position(old_line_scan_kwargs) - new_scan_position = self.extract_scan_position(merged_line_scan_kwargs) + old_scan_position = self.extract_line_scan_position(old_line_scan_kwargs) + new_scan_position = self.extract_line_scan_position(merged_line_scan_kwargs) offset = new_scan_position - old_scan_position self.line_scan_kwargs = merged_line_scan_kwargs if self.run_offset_calibration: @@ -774,14 +756,18 @@ def get_suggested_next_parameters(self, step_size_limit: Optional[float | Tuple[ p_suggested = p_current + signs * step_sizes return p_suggested - def find_offset(self) 
-> np.ndarray: - """Find the offset between the latest image and the previous image. + def find_position_correction(self) -> np.ndarray: + """Find the correction that should be applied (added) to image acquisition and line scan positions. Returns ------- np.ndarray - The offset between the latest image and the previous image. - Offset is in physical units, i.e., pixel size is already accounted for. + The correction that should be applied to the line scan positions, which include the + pure image registration offset and the scan-position difference. + The correction is in physical units, i.e., pixel size is already accounted for. + np.ndarray + The correction that should be applied to the image acquisition positions, which include only the + pure image registration offset. """ alignment_offset = np.array( self.image_registration_tool.register_images( @@ -793,6 +779,9 @@ def find_offset(self) -> np.ndarray: ), dtype=float, ) + # Image registration offset is the offset by which the moving image should be rolled to match + # the reference. We want acquisition position correction here, which is the negation of it. + registration_correction = -alignment_offset # Count in the difference of scan positions. scan_pos_diff = np.array([ @@ -800,13 +789,8 @@ def find_offset(self) -> np.ndarray: - float(self.acquisition_tool.image_acquisition_call_history[-2][f"{dir}_center"]) for dir in ["y", "x"] ]).astype(float) - offset_to_subtract = alignment_offset - scan_pos_diff - self.record_system_message( - f"Pure image registration offset (to apply to current image for alignment) is " - f"{alignment_offset}. Counting in scan-position difference {scan_pos_diff}, " - f"the offset to subtract from the next line scan positions is {offset_to_subtract}." 
- ) - return offset_to_subtract, alignment_offset + line_scan_correction = registration_correction + scan_pos_diff + return line_scan_correction, registration_correction def apply_offset_to_line_scan_kwargs(self, offset: np.ndarray): for arg in self.line_scan_tool_x_coordinate_args: @@ -820,7 +804,7 @@ def apply_offset_to_image_acquisition_kwargs(self, offset: np.ndarray): for arg in self.image_acquisition_tool_y_coordinate_args: self.image_acquisition_kwargs[arg] += offset[0] - def collect_initial_data_optimization_tool( + def collect_initial_data_for_optimization_tool( self, current_x: np.ndarray, sampling_range: np.ndarray, @@ -881,17 +865,7 @@ def rollback_and_shrink_delta(message_prefix: str) -> np.ndarray: self.param_setting_tool.set_parameters(x_current) if self.run_offset_calibration: self.run_2d_scan() - if self.line_scan_predictor_tool is not None: - self._apply_dual_drift_estimation(x_current) - else: - line_scan_pos_offset, alignment_offset = self.find_offset() - if np.any(np.isnan(line_scan_pos_offset)): - x_current = rollback_and_shrink_delta("Image registration failed (NaN offset).") - continue - self.apply_offset_to_line_scan_kwargs(-line_scan_pos_offset) - self.apply_offset_to_image_acquisition_kwargs(-alignment_offset) - self.update_linear_drift_models(x_current) - self.record_linear_drift_model_visualizations() + self.apply_drift_correction(x_current) try: fwhm = self.run_line_scan() if np.isnan(fwhm): @@ -902,49 +876,6 @@ def rollback_and_shrink_delta(message_prefix: str) -> np.ndarray: self.update_optimization_model(fwhm) return - def _get_lsp_drift(self) -> np.ndarray: - """Return cumulative drift from initial line scan position via the LSP. - - Calls the line scan predictor and subtracts the stored initial line - scan position to produce a drift vector in the same physical units - used throughout the tracker. 
- - Returns - ------- - np.ndarray - Shape ``(2,)`` array ``[drift_y, drift_x]`` from the initial line - scan center to the predictor's estimate. - """ - lsp_json = json.loads(self.line_scan_predictor_tool.predict_line_scan_position()) - predicted_center = np.array( - [lsp_json["center_y"], lsp_json["center_x"]], dtype=float - ) - return predicted_center - self.initial_line_scan_position - - def _get_reg_drift(self) -> np.ndarray | None: - """Return cumulative drift from initial acquisition position via image registration. - - Runs image registration between the two most recent 2D acquisitions and - converts the incremental alignment offset into a cumulative drift from - the initial image acquisition position. - - Returns - ------- - np.ndarray or None - Shape ``(2,)`` array ``[drift_y, drift_x]``, or ``None`` if - registration returned NaN (failed). - """ - _line_scan_pos_offset, alignment_offset = self.find_offset() - if np.any(np.isnan(alignment_offset)): - return None - # apply_offset_to_image_acquisition_kwargs(o) does position += o, - # so the registration caller passes -alignment_offset and the new tracked - # position would be current_acq_pos - alignment_offset. - current_acq_pos = self.extract_image_acquisition_position( - self.image_acquisition_kwargs - ) - return current_acq_pos - alignment_offset - self.initial_image_acquisition_position - def _select_drift( self, lsp_drift: np.ndarray, @@ -1022,46 +953,9 @@ def _select_drift( return chosen_drift, chosen_source - def _apply_chosen_drift(self, chosen_drift: np.ndarray) -> tuple[np.ndarray, np.ndarray]: - """Set scan positions to ``initial_position + chosen_drift`` for both trackers. - - ``apply_offset_to_*_kwargs(o)`` subtracts ``o`` from the stored - position, so ``-(target - current)`` is passed to move the position - forward to the target. - - Parameters - ---------- - chosen_drift : np.ndarray - Cumulative drift from the initial position to apply. 
- - Returns - ------- - target_ls_pos : np.ndarray - target_acq_pos : np.ndarray - """ - current_ls_pos = self.extract_scan_position(self.line_scan_kwargs) - target_ls_pos = self.initial_line_scan_position + chosen_drift - self.apply_offset_to_line_scan_kwargs(target_ls_pos - current_ls_pos) - - current_acq_pos = self.extract_image_acquisition_position(self.image_acquisition_kwargs) - target_acq_pos = self.initial_image_acquisition_position + chosen_drift - self.apply_offset_to_image_acquisition_kwargs(target_acq_pos - current_acq_pos) - - self.record_system_message( - f"Applied dual-drift correction:\n" - f"```chosen_drift={chosen_drift.tolist()}\n" - f"new line_scan_pos={target_ls_pos.tolist()}\n" - f"new acq_pos={target_acq_pos.tolist()}\n````" - ) - return target_ls_pos, target_acq_pos - - def _apply_dual_drift_estimation(self, x_current: np.ndarray) -> None: + def apply_drift_correction(self, x_current: np.ndarray) -> None: """Run both drift estimators, select the best estimate, and apply it. - Orchestrates :meth:`_get_lsp_drift`, :meth:`_get_reg_drift`, - :meth:`_select_drift`, and :meth:`_apply_chosen_drift`, then updates - the linear drift model for the next arbitration round. - Parameters ---------- x_current : np.ndarray @@ -1069,19 +963,79 @@ def _apply_dual_drift_estimation(self, x_current: np.ndarray) -> None: """ if self.initial_line_scan_position is None or self.initial_image_acquisition_position is None: raise RuntimeError( - "_apply_dual_drift_estimation requires initial scan positions to be set. " + "apply_drift_correction requires initial scan positions to be set. " "Ensure initialize_kwargs_buffers and run_2d_scan have been called first." ) - lsp_drift = self._get_lsp_drift() - reg_drift = self._get_reg_drift() - chosen_drift, _chosen_source = self._select_drift(lsp_drift, reg_drift, x_current) - _target_ls_pos, target_acq_pos = self._apply_chosen_drift(chosen_drift) + # Get the offset from the registration. 
+ # line_scan_pos_offset_reg: the offset counting in BOTH the pure image registration and the scan-position difference + # -> to be applied to the line scan position. + # pure_registration_offset: the offset counting in ONLY the pure image registration + # -> to be applied to the image acquisition position. + line_scan_correction_reg, image_acq_correction_reg = self.find_position_correction() + line_scan_correction_reg_wrt_prev = line_scan_correction_reg + line_scan_correction_reg_wrt_initial = ( + self.extract_line_scan_position(self.line_scan_kwargs) + line_scan_correction_reg - self.initial_line_scan_position + ) + image_acq_correction_reg_wrt_prev = image_acq_correction_reg + image_acq_correction_reg_wrt_initial = np.array([np.nan, np.nan]) + + # Get the offset from the line scan predictor. Note that the offset is with regards to the initial image. + if self.line_scan_predictor_tool is not None: + lsp_json = json.loads(self.line_scan_predictor_tool.predict_line_scan_position()) + predicted_center = np.array( + [lsp_json["center_y"], lsp_json["center_x"]], dtype=float + ) + line_scan_correction_lsp_wrt_initial = predicted_center - self.initial_line_scan_position + line_scan_correction_lsp_wrt_prev = ( + predicted_center - self.extract_line_scan_position(self.acquisition_tool.line_scan_call_history[-1]) + ) + scan_pos_diff = np.array([ + float(self.acquisition_tool.image_acquisition_call_history[-1][f"{dir}_center"]) + - float(self.acquisition_tool.image_acquisition_call_history[-2][f"{dir}_center"]) + for dir in ["y", "x"] + ]).astype(float) + image_acq_correction_lsp_wrt_initial = line_scan_correction_lsp_wrt_initial - scan_pos_diff + image_acq_correction_lsp_wrt_prev = line_scan_correction_lsp_wrt_prev - scan_pos_diff + + # Select the best correction to use. 
+ chosen_line_scan_correction_wrt_initial, chosen_source = self._select_drift( + line_scan_correction_lsp_wrt_initial, line_scan_correction_reg_wrt_initial, x_current + ) + chosen_image_acq_correction_wrt_prev = ( + image_acq_correction_lsp_wrt_prev + if chosen_source == "line_scan_predictor" + else image_acq_correction_reg_wrt_prev + ) + chosen_image_acq_correction_wrt_initial = ( + image_acq_correction_lsp_wrt_initial + if chosen_source == "line_scan_predictor" + else image_acq_correction_reg_wrt_prev + ) + else: + chosen_source = "registration" + chosen_line_scan_correction_wrt_prev = line_scan_correction_reg_wrt_prev + chosen_line_scan_correction_wrt_initial = line_scan_correction_reg_wrt_initial + chosen_image_acq_correction_wrt_prev = image_acq_correction_reg_wrt_prev + chosen_image_acq_correction_wrt_initial = image_acq_correction_reg_wrt_initial + + self.apply_offset_to_line_scan_kwargs(chosen_line_scan_correction_wrt_prev) + self.apply_offset_to_image_acquisition_kwargs(chosen_image_acq_correction_wrt_prev) + + self.record_system_message( + f"Applied drift correction:\n" + f"```source = {chosen_source}\n" + f"image_registration_offset = {(-chosen_image_acq_correction_wrt_prev).tolist()}\n" + f"chosen_line_scan_correction_wrt_prev = {chosen_line_scan_correction_wrt_prev.tolist()}\n" + f"chosen_line_scan_correction_wrt_initial = {chosen_line_scan_correction_wrt_initial.tolist()}\n" + f"chosen_image_acq_correction_wrt_prev = {chosen_image_acq_correction_wrt_prev.tolist()}\n" + f"chosen_image_acq_correction_wrt_initial = {chosen_image_acq_correction_wrt_initial.tolist()}```" + ) # Update model with (parameters -> chosen cumulative drift). Pass the # target position directly so the model is consistent regardless of # whether kwargs have been flushed to the underlying tool yet. 
- self.update_linear_drift_models(x_current, current_position_yx=target_acq_pos) + self.update_linear_drift_models(x_current, current_position_yx=chosen_line_scan_correction_wrt_initial) self.record_linear_drift_model_visualizations() def apply_user_correction_offset(self) -> bool: From 5d0297fc109b2da80c0f7f8b4234db9502d33580 Mon Sep 17 00:00:00 2001 From: Ming Du Date: Mon, 9 Mar 2026 09:18:24 -0500 Subject: [PATCH 07/15] FEAT: allow selecting registration target ("previous" and "initial") --- .../tuning/analytical_focusing.py | 127 ++++++++++++++---- 1 file changed, 101 insertions(+), 26 deletions(-) diff --git a/src/eaa/task_manager/tuning/analytical_focusing.py b/src/eaa/task_manager/tuning/analytical_focusing.py index be4bc68..f274611 100644 --- a/src/eaa/task_manager/tuning/analytical_focusing.py +++ b/src/eaa/task_manager/tuning/analytical_focusing.py @@ -59,6 +59,7 @@ def __init__( image_acquisition_tool_y_coordinate_args: Tuple[str, ...] = ("y_center",), registration_method: Literal["phase_correlation", "sift", "mutual_information", "llm"] = "phase_correlation", registration_algorithm_kwargs: Optional[dict[str, Any]] = None, + registration_target: Literal["previous", "initial"] = "previous", run_line_scan_checker: bool = True, run_offset_calibration: bool = True, line_scan_predictor_tool: Optional[LineScanPredictor] = None, @@ -72,9 +73,12 @@ def __init__( The workflow is as follows: 1. Acquire a 2D image in the user-specified region of interest. 2. Run a line scan at user-specified coordinates and record the FWHM of the Gaussian fit. - 3. Change parameter and acquire a new 2D image. - 4. Run image registration to get the offset and adjust 1D/2D scan coordinates. - 5. Repeat 1 - 3 a few times to collect initial data for Bayesian optimization. + 3. Change parameter and acquire a new 2D image. The change of parameter causes the sample + to drift relative to the beam. + 4. 
Register the acquired image with the reference image (previous or initial) to estimate + the drift correction that should be applied to the line scan and image acquisition tools. + Update the positions for line scan and image acquisition. + 5. Repeat 1 - 4 a few times to collect initial data for Bayesian optimization. 6. Use Bayesian optimization to suggest new parameters. 7. Change parameter. 8. Run image registration or feature tracking as in 4. @@ -125,6 +129,14 @@ def __init__( registration_algorithm_kwargs : Optional[dict[str, Any]] Optional keyword arguments forwarded to the selected image registration algorithm when aligning consecutive 2D scans. + registration_target : Literal["previous", "initial"], optional + The reference image used by the registration branch of drift + correction. "previous" (default) registers each new 2D scan + against the immediately preceding one; small registration errors + therefore accumulate over many iterations. "initial" registers + every new 2D scan against the very first scan, which prevents + error accumulation at the cost of requiring sufficient overlap + between the current and the initial image. run_line_scan_checker : bool, optional If True, run the LLM-based line-scan quality checker and allow it to request scan-argument adjustments before accepting a line scan. @@ -172,6 +184,11 @@ def __init__( self.registration_algorithm_kwargs = copy.deepcopy( registration_algorithm_kwargs or {} ) + if registration_target not in ("previous", "initial"): + raise ValueError( + f"`registration_target` must be 'previous' or 'initial', got {registration_target!r}." 
+ ) + self.registration_target = registration_target if hasattr(acquisition_tool, "line_scan_return_gaussian_fit"): acquisition_tool.line_scan_return_gaussian_fit = True @@ -756,25 +773,58 @@ def get_suggested_next_parameters(self, step_size_limit: Optional[float | Tuple[ p_suggested = p_current + signs * step_sizes return p_suggested - def find_position_correction(self) -> np.ndarray: + def find_position_correction( + self, target: Literal["previous", "initial"] = "previous" + ) -> tuple[np.ndarray, np.ndarray]: """Find the correction that should be applied (added) to image acquisition and line scan positions. + Parameters + ---------- + target : Literal["previous", "initial"] + The reference image to register against. + "previous": register the current image with the immediately preceding image. + "initial": register the current image with the very first image to prevent + error accumulation. + Returns ------- np.ndarray - The correction that should be applied to the line scan positions, which include the - pure image registration offset and the scan-position difference. + The correction that should be applied to the line scan positions, which includes + the pure image registration offset and the scan-position difference. + When target is "previous", this is the correction relative to the previous step. + When target is "initial", this is the cumulative correction from the initial position. The correction is in physical units, i.e., pixel size is already accounted for. np.ndarray - The correction that should be applied to the image acquisition positions, which include only the - pure image registration offset. + The correction that should be applied to the image acquisition positions, which + includes only the pure image registration offset. + When target is "previous", this is the correction relative to the previous step. + When target is "initial", this is the cumulative correction from the initial position. 
""" + if target == "previous": + image_r = self.image_registration_tool.process_image(self.acquisition_tool.image_km1) + psize_r = self.acquisition_tool.psize_km1 + scan_pos_diff = np.array([ + float(self.acquisition_tool.image_acquisition_call_history[-1][f"{dir}_center"]) + - float(self.acquisition_tool.image_acquisition_call_history[-2][f"{dir}_center"]) + for dir in ["y", "x"] + ]).astype(float) + elif target == "initial": + image_r = self.image_registration_tool.process_image(self.acquisition_tool.image_0) + psize_r = self.acquisition_tool.psize_0 + scan_pos_diff = np.array([ + float(self.acquisition_tool.image_acquisition_call_history[-1][f"{dir}_center"]) + - float(self.acquisition_tool.image_acquisition_call_history[0][f"{dir}_center"]) + for dir in ["y", "x"] + ]).astype(float) + else: + raise ValueError(f"`target` must be 'previous' or 'initial', got {target!r}.") + alignment_offset = np.array( self.image_registration_tool.register_images( image_t=self.image_registration_tool.process_image(self.acquisition_tool.image_k), - image_r=self.image_registration_tool.process_image(self.acquisition_tool.image_km1), + image_r=image_r, psize_t=self.acquisition_tool.psize_k, - psize_r=self.acquisition_tool.psize_km1, + psize_r=psize_r, registration_algorithm_kwargs=self.registration_algorithm_kwargs, ), dtype=float, @@ -782,13 +832,8 @@ def find_position_correction(self) -> np.ndarray: # Image registration offset is the offset by which the moving image should be rolled to match # the reference. We want acquisition position correction here, which is the negation of it. registration_correction = -alignment_offset - + # Count in the difference of scan positions. 
- scan_pos_diff = np.array([ - float(self.acquisition_tool.image_acquisition_call_history[-1][f"{dir}_center"]) - - float(self.acquisition_tool.image_acquisition_call_history[-2][f"{dir}_center"]) - for dir in ["y", "x"] - ]).astype(float) line_scan_correction = registration_correction + scan_pos_diff return line_scan_correction, registration_correction @@ -972,13 +1017,38 @@ def apply_drift_correction(self, x_current: np.ndarray) -> None: # -> to be applied to the line scan position. # pure_registration_offset: the offset counting in ONLY the pure image registration # -> to be applied to the image acquisition position. - line_scan_correction_reg, image_acq_correction_reg = self.find_position_correction() - line_scan_correction_reg_wrt_prev = line_scan_correction_reg - line_scan_correction_reg_wrt_initial = ( - self.extract_line_scan_position(self.line_scan_kwargs) + line_scan_correction_reg - self.initial_line_scan_position + line_scan_correction_reg, image_acq_correction_reg = self.find_position_correction( + target=self.registration_target ) - image_acq_correction_reg_wrt_prev = image_acq_correction_reg - image_acq_correction_reg_wrt_initial = np.array([np.nan, np.nan]) + if self.registration_target == "previous": + # find_position_correction returns corrections relative to the previous step. + line_scan_correction_reg_wrt_prev = line_scan_correction_reg + line_scan_correction_reg_wrt_initial = ( + self.extract_line_scan_position(self.line_scan_kwargs) + + line_scan_correction_reg + - self.initial_line_scan_position + ) + image_acq_correction_reg_wrt_prev = image_acq_correction_reg + image_acq_correction_reg_wrt_initial = ( + image_acq_correction_reg + + self.extract_image_acquisition_position(self.image_acquisition_kwargs) + - self.initial_image_acquisition_position + ) + else: + # find_position_correction returns cumulative corrections from the initial position. + # Convert to wrt_prev so the apply helpers can add them to the current kwargs. 
+ line_scan_correction_reg_wrt_initial = line_scan_correction_reg + line_scan_correction_reg_wrt_prev = ( + self.initial_line_scan_position + + line_scan_correction_reg + - self.extract_line_scan_position(self.line_scan_kwargs) + ) + image_acq_correction_reg_wrt_initial = image_acq_correction_reg + image_acq_correction_reg_wrt_prev = ( + self.initial_image_acquisition_position + + image_acq_correction_reg + - self.extract_image_acquisition_position(self.image_acquisition_kwargs) + ) # Get the offset from the line scan predictor. Note that the offset is with regards to the initial image. if self.line_scan_predictor_tool is not None: @@ -1002,15 +1072,20 @@ def apply_drift_correction(self, x_current: np.ndarray) -> None: chosen_line_scan_correction_wrt_initial, chosen_source = self._select_drift( line_scan_correction_lsp_wrt_initial, line_scan_correction_reg_wrt_initial, x_current ) + chosen_line_scan_correction_wrt_prev = ( + line_scan_correction_lsp_wrt_prev + if chosen_source == "line_scan_predictor" + else line_scan_correction_reg_wrt_prev + ) chosen_image_acq_correction_wrt_prev = ( - image_acq_correction_lsp_wrt_prev - if chosen_source == "line_scan_predictor" + image_acq_correction_lsp_wrt_prev + if chosen_source == "line_scan_predictor" else image_acq_correction_reg_wrt_prev ) chosen_image_acq_correction_wrt_initial = ( image_acq_correction_lsp_wrt_initial - if chosen_source == "line_scan_predictor" - else image_acq_correction_reg_wrt_prev + if chosen_source == "line_scan_predictor" + else image_acq_correction_reg_wrt_initial ) else: chosen_source = "registration" From 798e08c0060e7c0d581ce61f8d25f55c773a26c4 Mon Sep 17 00:00:00 2001 From: Ming Du Date: Mon, 9 Mar 2026 11:30:48 -0500 Subject: [PATCH 08/15] FEAT: NN registration integration --- .../tuning/analytical_focusing.py | 229 ++++++++++-------- src/eaa/tool/imaging/nn_registration.py | 141 +++++++++++ 2 files changed, 274 insertions(+), 96 deletions(-) create mode 100644 
src/eaa/tool/imaging/nn_registration.py diff --git a/src/eaa/task_manager/tuning/analytical_focusing.py b/src/eaa/task_manager/tuning/analytical_focusing.py index f274611..ce62b83 100644 --- a/src/eaa/task_manager/tuning/analytical_focusing.py +++ b/src/eaa/task_manager/tuning/analytical_focusing.py @@ -18,7 +18,7 @@ from sciagent.message_proc import print_message from eaa.tool.imaging.acquisition import AcquireImage -from eaa.tool.imaging.line_scan_predictor import LineScanPredictor +from eaa.tool.imaging.nn_registration import NNRegistration from eaa.tool.imaging.param_tuning import SetParameters from eaa.task_manager.tuning.base import BaseParameterTuningTaskManager from eaa.tool.imaging.registration import ImageRegistration @@ -62,9 +62,9 @@ def __init__( registration_target: Literal["previous", "initial"] = "previous", run_line_scan_checker: bool = True, run_offset_calibration: bool = True, - line_scan_predictor_tool: Optional[LineScanPredictor] = None, + nn_registration_tool: Optional[NNRegistration] = None, dual_drift_selection_priming_iterations: int = 3, - dual_drift_estimation_primary_source: Literal["registration", "line_scan_predictor"] = "line_scan_predictor", + dual_drift_estimation_primary_source: Literal["registration", "nn_registration"] = "nn_registration", *args, **kwargs ): """Analytical scanning microscope focusing task manager driven @@ -147,12 +147,14 @@ def __init__( dual_drift_selection_priming_iterations : int, optional Number of parameter–drift samples to collect (warm-up phase) before the linear model is used to arbitrate between image registration and - line scan predictor drift estimates. Only relevant when - ``line_scan_predictor_tool`` is provided. - line_scan_predictor_tool : LineScanPredictor, optional - If provided, both image registration and the line scan predictor - are run after each 2D acquisition to estimate drift. For the - first ``n_parameter_drift_points_before_prediction`` parameter + NN registration drift estimates. 
Only relevant when + ``nn_registration_tool`` is provided. + nn_registration_tool : NNRegistration, optional + If provided, both classical image registration and the NN-based + registration are run after each 2D acquisition to estimate drift. + The NN model always registers the current image against the very + first (initial) image, so its output is a cumulative drift estimate. + For the first ``dual_drift_selection_priming_iterations`` parameter adjustments only the source selected by ``dual_drift_estimation_primary_source`` is used; concurrently a linear model is fit mapping optics parameters to cumulative drift @@ -161,12 +163,12 @@ def __init__( that is closer to the linear model prediction is chosen as the correct drift, and the linear model is then updated with that chosen value. Requires ``run_offset_calibration=True``. - dual_drift_estimation_primary_source : {"registration", "line_scan_predictor"} + dual_drift_estimation_primary_source : {"registration", "nn_registration"} Which drift-estimation method to trust exclusively during the first - ``n_parameter_drift_points_before_prediction`` iterations when - ``line_scan_predictor_tool`` is provided. After that point both + ``dual_drift_selection_priming_iterations`` iterations when + ``nn_registration_tool`` is provided. After that point both methods are evaluated and the linear model arbitrates. Defaults - to ``"line_scan_predictor"``. + to ``"nn_registration"``. """ if acquisition_tool is None: raise ValueError("`acquisition_tool` must be provided.") @@ -210,13 +212,13 @@ def __init__( self.run_line_scan_checker = run_line_scan_checker self.run_offset_calibration = run_offset_calibration - if line_scan_predictor_tool is not None and not run_offset_calibration: + if nn_registration_tool is not None and not run_offset_calibration: raise ValueError( - "`line_scan_predictor_tool` requires `run_offset_calibration=True` " - "because a 2D image must be acquired before the predictor can run." 
+ "`nn_registration_tool` requires `run_offset_calibration=True` " + "because a 2D image must be acquired before the tool can run." ) - - self.line_scan_predictor_tool = line_scan_predictor_tool + + self.nn_registration_tool = nn_registration_tool self.dual_drift_estimation_primary_source = dual_drift_estimation_primary_source self.dual_drift_selection_priming_iterations = ( dual_drift_selection_priming_iterations @@ -499,7 +501,7 @@ def update_linear_drift_models( parameters: np.ndarray, current_position_yx: np.ndarray | None = None, ): - if self.line_scan_predictor_tool is None: + if self.nn_registration_tool is None: return if self.initial_line_scan_position is None: return @@ -774,61 +776,82 @@ def get_suggested_next_parameters(self, step_size_limit: Optional[float | Tuple[ return p_suggested def find_position_correction( - self, target: Literal["previous", "initial"] = "previous" + self, + method: Literal["traditional", "nn"] = "traditional", + target: Literal["previous", "initial"] = "previous", ) -> tuple[np.ndarray, np.ndarray]: """Find the correction that should be applied (added) to image acquisition and line scan positions. Parameters ---------- + method : Literal["traditional", "nn"] + The registration algorithm to use. + "traditional": use :attr:`image_registration_tool` (classical image registration). + "nn": use :attr:`nn_registration_tool` (NN-based registration). target : Literal["previous", "initial"] The reference image to register against. - "previous": register the current image with the immediately preceding image. - "initial": register the current image with the very first image to prevent - error accumulation. + "previous": register the current image against the immediately + preceding image. The returned corrections are relative to the + previous step. + "initial": register the current image against the very first + acquired image. 
The returned corrections are cumulative from the + initial position, which prevents per-step error accumulation. Returns ------- np.ndarray - The correction that should be applied to the line scan positions, which includes - the pure image registration offset and the scan-position difference. - When target is "previous", this is the correction relative to the previous step. - When target is "initial", this is the cumulative correction from the initial position. - The correction is in physical units, i.e., pixel size is already accounted for. + The correction to apply to the line scan positions, including + both the pure registration offset and the intentional scan-position + difference. + For target="previous": relative to the previous step. + For target="initial": cumulative from the initial position. + Values are in physical units (pixel size already accounted for). np.ndarray - The correction that should be applied to the image acquisition positions, which - includes only the pure image registration offset. - When target is "previous", this is the correction relative to the previous step. - When target is "initial", this is the cumulative correction from the initial position. + The correction to apply to the image acquisition positions (pure + registration offset only, without the intentional scan-position + difference). + For target="previous": relative to the previous step. + For target="initial": cumulative from the initial position. 
""" - if target == "previous": - image_r = self.image_registration_tool.process_image(self.acquisition_tool.image_km1) - psize_r = self.acquisition_tool.psize_km1 - scan_pos_diff = np.array([ - float(self.acquisition_tool.image_acquisition_call_history[-1][f"{dir}_center"]) - - float(self.acquisition_tool.image_acquisition_call_history[-2][f"{dir}_center"]) - for dir in ["y", "x"] - ]).astype(float) - elif target == "initial": - image_r = self.image_registration_tool.process_image(self.acquisition_tool.image_0) - psize_r = self.acquisition_tool.psize_0 - scan_pos_diff = np.array([ - float(self.acquisition_tool.image_acquisition_call_history[-1][f"{dir}_center"]) - - float(self.acquisition_tool.image_acquisition_call_history[0][f"{dir}_center"]) - for dir in ["y", "x"] - ]).astype(float) - else: + if target not in ("previous", "initial"): raise ValueError(f"`target` must be 'previous' or 'initial', got {target!r}.") - alignment_offset = np.array( - self.image_registration_tool.register_images( - image_t=self.image_registration_tool.process_image(self.acquisition_tool.image_k), - image_r=image_r, - psize_t=self.acquisition_tool.psize_k, - psize_r=psize_r, - registration_algorithm_kwargs=self.registration_algorithm_kwargs, - ), - dtype=float, - ) + # Scan-position difference depends only on the target image. + ref_history_idx = -2 if target == "previous" else 0 + scan_pos_diff = np.array([ + float(self.acquisition_tool.image_acquisition_call_history[-1][f"{dir}_center"]) + - float(self.acquisition_tool.image_acquisition_call_history[ref_history_idx][f"{dir}_center"]) + for dir in ["y", "x"] + ]).astype(float) + + # Alignment offset depends on the registration method (and the target image). 
+ if method == "traditional": + if target == "previous": + image_r = self.image_registration_tool.process_image(self.acquisition_tool.image_km1) + psize_r = self.acquisition_tool.psize_km1 + else: # "initial" + image_r = self.image_registration_tool.process_image(self.acquisition_tool.image_0) + psize_r = self.acquisition_tool.psize_0 + alignment_offset = np.array( + self.image_registration_tool.register_images( + image_t=self.image_registration_tool.process_image(self.acquisition_tool.image_k), + image_r=image_r, + psize_t=self.acquisition_tool.psize_k, + psize_r=psize_r, + registration_algorithm_kwargs=self.registration_algorithm_kwargs, + ), + dtype=float, + ) + elif method == "nn": + if self.nn_registration_tool is None: + raise RuntimeError( + "`nn_registration_tool` is not set. Provide one in the constructor " + "to use method='nn'." + ) + alignment_offset = self.nn_registration_tool.get_offset(target=target) + else: + raise ValueError(f"`method` must be 'traditional' or 'nn', got {method!r}.") + # Image registration offset is the offset by which the moving image should be rolled to match # the reference. We want acquisition position correction here, which is the negation of it. registration_correction = -alignment_offset @@ -923,7 +946,7 @@ def rollback_and_shrink_delta(message_prefix: str) -> np.ndarray: def _select_drift( self, - lsp_drift: np.ndarray, + nn_drift: np.ndarray, reg_drift: np.ndarray | None, x_current: np.ndarray, ) -> tuple[np.ndarray, str]: @@ -936,10 +959,10 @@ def _select_drift( Parameters ---------- - lsp_drift : np.ndarray - Cumulative drift from the line scan predictor. + nn_drift : np.ndarray + Cumulative drift from NN registration. reg_drift : np.ndarray or None - Cumulative drift from image registration, or ``None`` on failure. + Cumulative drift from classical image registration, or ``None`` on failure. x_current : np.ndarray Current optics parameter vector (used in the arbitration phase). 
@@ -947,25 +970,25 @@ def _select_drift( ------- chosen_drift : np.ndarray chosen_source : str - ``"line_scan_predictor"`` or ``"registration"``. + ``"nn_registration"`` or ``"registration"``. """ n_collected = self.drift_model_y.get_n_parameter_drift_points_collected() n_needed = self.dual_drift_selection_priming_iterations if n_collected < n_needed: - if self.dual_drift_estimation_primary_source == "line_scan_predictor" or reg_drift is None: - chosen_drift, chosen_source = lsp_drift, "line_scan_predictor" + if self.dual_drift_estimation_primary_source == "nn_registration" or reg_drift is None: + chosen_drift, chosen_source = nn_drift, "nn_registration" else: chosen_drift, chosen_source = reg_drift, "registration" if reg_drift is None: self.record_system_message( "Dual drift estimation: image registration returned NaN; " - "falling back to line scan predictor." + "falling back to nn_registration." ) self.record_system_message( f"Dual drift estimation (primary phase, n={n_collected}/{n_needed}): " f"```using {chosen_source}\n" - f"lsp_drift={lsp_drift.tolist()}\n" + f"nn_drift={nn_drift.tolist()}\n" f"reg_drift={reg_drift.tolist() if reg_drift is not None else 'NaN'}```" ) else: @@ -977,21 +1000,21 @@ def _select_drift( ], dtype=float, ) - dist_lsp = float(np.linalg.norm(lsp_drift - model_drift)) + dist_nn = float(np.linalg.norm(nn_drift - model_drift)) dist_reg = ( float(np.linalg.norm(reg_drift - model_drift)) if reg_drift is not None else np.inf ) - if dist_lsp <= dist_reg: - chosen_drift, chosen_source = lsp_drift, "line_scan_predictor" + if dist_nn <= dist_reg: + chosen_drift, chosen_source = nn_drift, "nn_registration" else: chosen_drift, chosen_source = reg_drift, "registration" dist_reg_str = f"{dist_reg:.4f}" if reg_drift is not None else "inf" self.record_system_message( f"Dual drift estimation (arbitration phase):\n" f"```model_drift={model_drift.tolist()}\n" - f"lsp_drift={lsp_drift.tolist()} (dist={dist_lsp:.4f})\n" + f"nn_drift={nn_drift.tolist()} 
(dist={dist_nn:.4f})\n" f"reg_drift={reg_drift.tolist() if reg_drift is not None else 'NaN'} (dist={dist_reg_str})\n" f"Chosen: {chosen_source}```" ) @@ -1018,7 +1041,7 @@ def apply_drift_correction(self, x_current: np.ndarray) -> None: # pure_registration_offset: the offset counting in ONLY the pure image registration # -> to be applied to the image acquisition position. line_scan_correction_reg, image_acq_correction_reg = self.find_position_correction( - target=self.registration_target + method="traditional", target=self.registration_target ) if self.registration_target == "previous": # find_position_correction returns corrections relative to the previous step. @@ -1050,41 +1073,55 @@ def apply_drift_correction(self, x_current: np.ndarray) -> None: - self.extract_image_acquisition_position(self.image_acquisition_kwargs) ) - # Get the offset from the line scan predictor. Note that the offset is with regards to the initial image. - if self.line_scan_predictor_tool is not None: - lsp_json = json.loads(self.line_scan_predictor_tool.predict_line_scan_position()) - predicted_center = np.array( - [lsp_json["center_y"], lsp_json["center_x"]], dtype=float - ) - line_scan_correction_lsp_wrt_initial = predicted_center - self.initial_line_scan_position - line_scan_correction_lsp_wrt_prev = ( - predicted_center - self.extract_line_scan_position(self.acquisition_tool.line_scan_call_history[-1]) + # Get the offset from the NN registration tool using the same target as classical registration. 
+ if self.nn_registration_tool is not None: + line_scan_correction_nn, image_acq_correction_nn = self.find_position_correction( + method="nn", target=self.registration_target ) - scan_pos_diff = np.array([ - float(self.acquisition_tool.image_acquisition_call_history[-1][f"{dir}_center"]) - - float(self.acquisition_tool.image_acquisition_call_history[-2][f"{dir}_center"]) - for dir in ["y", "x"] - ]).astype(float) - image_acq_correction_lsp_wrt_initial = line_scan_correction_lsp_wrt_initial - scan_pos_diff - image_acq_correction_lsp_wrt_prev = line_scan_correction_lsp_wrt_prev - scan_pos_diff + if self.registration_target == "previous": + line_scan_correction_nn_wrt_prev = line_scan_correction_nn + line_scan_correction_nn_wrt_initial = ( + self.extract_line_scan_position(self.line_scan_kwargs) + + line_scan_correction_nn + - self.initial_line_scan_position + ) + image_acq_correction_nn_wrt_prev = image_acq_correction_nn + image_acq_correction_nn_wrt_initial = ( + image_acq_correction_nn + + self.extract_image_acquisition_position(self.image_acquisition_kwargs) + - self.initial_image_acquisition_position + ) + else: # "initial" + line_scan_correction_nn_wrt_initial = line_scan_correction_nn + line_scan_correction_nn_wrt_prev = ( + self.initial_line_scan_position + + line_scan_correction_nn + - self.extract_line_scan_position(self.line_scan_kwargs) + ) + image_acq_correction_nn_wrt_initial = image_acq_correction_nn + image_acq_correction_nn_wrt_prev = ( + self.initial_image_acquisition_position + + image_acq_correction_nn + - self.extract_image_acquisition_position(self.image_acquisition_kwargs) + ) # Select the best correction to use. 
chosen_line_scan_correction_wrt_initial, chosen_source = self._select_drift( - line_scan_correction_lsp_wrt_initial, line_scan_correction_reg_wrt_initial, x_current + line_scan_correction_nn_wrt_initial, line_scan_correction_reg_wrt_initial, x_current ) chosen_line_scan_correction_wrt_prev = ( - line_scan_correction_lsp_wrt_prev - if chosen_source == "line_scan_predictor" + line_scan_correction_nn_wrt_prev + if chosen_source == "nn_registration" else line_scan_correction_reg_wrt_prev ) chosen_image_acq_correction_wrt_prev = ( - image_acq_correction_lsp_wrt_prev - if chosen_source == "line_scan_predictor" + image_acq_correction_nn_wrt_prev + if chosen_source == "nn_registration" else image_acq_correction_reg_wrt_prev ) chosen_image_acq_correction_wrt_initial = ( - image_acq_correction_lsp_wrt_initial - if chosen_source == "line_scan_predictor" + image_acq_correction_nn_wrt_initial + if chosen_source == "nn_registration" else image_acq_correction_reg_wrt_initial ) else: diff --git a/src/eaa/tool/imaging/nn_registration.py b/src/eaa/tool/imaging/nn_registration.py new file mode 100644 index 0000000..255c0d4 --- /dev/null +++ b/src/eaa/tool/imaging/nn_registration.py @@ -0,0 +1,141 @@ +import io +import logging +from typing import Literal + +import numpy as np +import requests +import tifffile +from sciagent.tool.base import BaseTool, check, ToolReturnType, tool + +from eaa.tool.imaging.acquisition import AcquireImage + +logger = logging.getLogger(__name__) + + +class NNRegistration(BaseTool): + """Tool that queries a NN server to obtain a registration offset between + a reference image and the current image. + + The server must serve a model with ``prediction_type="offset"``, which + accepts ``ref_image`` and ``test_image`` and returns + ``{"offset_y": float, "offset_x": float}`` as fractions of the reference + image size. 
+ The offset convention matches :class:`~eaa.tool.imaging.registration.ImageRegistration`: + the returned values are the translation to apply to the test image so it + aligns with the reference. + """ + + name: str = "nn_registration" + + @check + def __init__( + self, + server_url: str, + image_acquisition_tool: AcquireImage, + require_approval: bool = False, + *args, + **kwargs, + ): + """Initialize the NN registration tool. + + Parameters + ---------- + server_url : str + Base URL of the inference server, e.g. ``"http://localhost:8090"``. + image_acquisition_tool : AcquireImage + The image acquisition tool instance. Must be the same object used + by the task manager so that its image buffers and call history + reflect the current state. + """ + super().__init__(*args, require_approval=require_approval, **kwargs) + self.server_url = server_url.rstrip("/") + self.image_acquisition_tool = image_acquisition_tool + + @staticmethod + def _encode_as_tiff(image: np.ndarray) -> bytes: + buf = io.BytesIO() + tifffile.imwrite(buf, image.astype(np.float32)) + return buf.getvalue() + + @tool(name="get_offset", return_type=ToolReturnType.TEXT) + def get_offset(self, target: Literal["previous", "initial"] = "initial") -> np.ndarray: + """Query the server and return the registration offset in physical units. + + Parameters + ---------- + target : Literal["previous", "initial"] + The reference image to register against. + "previous": register the current image (``image_k``) against the + immediately preceding image (``image_km1``). + "initial": register the current image (``image_k``) against the + first acquired image (``image_0``), giving the cumulative drift from + the initial acquisition. + + Returns + ------- + np.ndarray + ``[offset_y, offset_x]`` in physical coordinate units (same units + as the image acquisition positions). 
The convention matches + :meth:`~eaa.tool.imaging.registration.ImageRegistration.register_images`: + this is the translation to apply to the current image so it aligns + with the reference image. + + Raises + ------ + RuntimeError + If the required image buffers or acquisition history are not populated. + requests.HTTPError + If the server returns a non-2xx response. + """ + acq = self.image_acquisition_tool + + if acq.image_k is None: + raise RuntimeError( + "Current image buffer (image_k) is not populated. " + "Acquire at least one image before calling get_offset." + ) + if not acq.image_acquisition_call_history: + raise RuntimeError("No image acquisition history found.") + + if target == "previous": + if acq.image_km1 is None: + raise RuntimeError( + "Previous image buffer (image_km1) is not populated. " + "Acquire at least two images before calling get_offset with target='previous'." + ) + ref_image = acq.image_km1 + ref_img_info = acq.image_acquisition_call_history[-2] + elif target == "initial": + if acq.image_0 is None: + raise RuntimeError( + "Initial image buffer (image_0) is not populated. " + "Acquire at least one image before calling get_offset with target='initial'." + ) + ref_image = acq.image_0 + ref_img_info = acq.image_acquisition_call_history[0] + else: + raise ValueError(f"`target` must be 'previous' or 'initial', got {target!r}.") + + ref_tiff = self._encode_as_tiff(ref_image) + test_tiff = self._encode_as_tiff(acq.image_k) + + response = requests.post( + f"{self.server_url}/predict", + files={ + "ref_image": ("ref_image.tif", ref_tiff, "image/tiff"), + "test_image": ("test_image.tif", test_tiff, "image/tiff"), + }, + ) + response.raise_for_status() + result = response.json() + + # Convert fractions of the reference image size to physical units. 
+ offset_y_phys = float(result["offset_y"]) * float(ref_img_info["size_y"]) + offset_x_phys = float(result["offset_x"]) * float(ref_img_info["size_x"]) + + logger.debug( + "NNRegistration offset (target=%s): frac=(%.4f, %.4f), phys=(%.4f, %.4f)", + target, result["offset_y"], result["offset_x"], + offset_y_phys, offset_x_phys, + ) + return np.array([offset_y_phys, offset_x_phys], dtype=float) From 0f62731ccc1be5b9679749264dd40170e1fbd924 Mon Sep 17 00:00:00 2001 From: Ming Du Date: Mon, 9 Mar 2026 13:28:27 -0500 Subject: [PATCH 09/15] FEAT: test target Siemens star fitting tool --- .../aps_mic/test_target_landmark_fitting.py | 332 ++++++++++++++++++ 1 file changed, 332 insertions(+) create mode 100644 src/eaa/tool/imaging/aps_mic/test_target_landmark_fitting.py diff --git a/src/eaa/tool/imaging/aps_mic/test_target_landmark_fitting.py b/src/eaa/tool/imaging/aps_mic/test_target_landmark_fitting.py new file mode 100644 index 0000000..3e3ab8c --- /dev/null +++ b/src/eaa/tool/imaging/aps_mic/test_target_landmark_fitting.py @@ -0,0 +1,332 @@ +from typing import Annotated, Optional +import logging + +import matplotlib.pyplot as plt +import numpy as np +import scipy.ndimage as ndi +from skimage import filters, measure +from sciagent.tool.base import BaseTool, ToolReturnType, check, tool + +from eaa.tool.imaging.acquisition import AcquireImage + +logger = logging.getLogger(__name__) + + +class TestPatternLandmarkFitting(BaseTool): + """Fit the circular landmark on the right side of an APS-MIC test image. + + The expected landmark is a bright or dark disk-shaped feature that may be + only partially visible near the right image boundary. The fitting workflow + is: + + 1. Apply Gaussian smoothing and normalize the image to ``[0, 1]``. + 2. Segment the landmark candidate with binary thresholding. + 3. Apply binary erosion followed by binary dilation with a 3x3 structure + element for 3 iterations to suppress small islands and weak bridges. + 4. 
Run connected-component analysis and keep the component whose right-most + pixel has the largest x coordinate. + 5. Extract the boundary pixels of that component, excluding pixels on the + outer image border, and fit a circle to those arc pixels with RANSAC. + + The returned center is expressed in the coordinate system of the original, + uncropped image. + """ + + name: str = "test_pattern_landmark_fitting" + + @check + def __init__( + self, + image_acquisition_tool: Optional[AcquireImage] = None, + zoom: float = 4.0, + gaussian_sigma_fraction: float = 0.03, + ransac_residual_threshold_fraction: float = 0.015, + ransac_max_trials: int = 1000, + require_approval: bool = False, + *args, + **kwargs, + ): + """Initialize the landmark fitting tool. + + Parameters + ---------- + image_acquisition_tool : Optional[AcquireImage], optional + Acquisition tool that provides ``image_k`` when no image is passed + directly to :meth:`fit_landmark_center`. + zoom : float, optional + Zoom factor applied to the image before segmentation and fitting. + Returned coordinates are converted back to the original image scale. + gaussian_sigma_fraction : float, optional + Gaussian blur sigma expressed as a fraction of the original image + width in pixels. + ransac_residual_threshold_fraction : float, optional + RANSAC inlier threshold expressed as a fraction of the cropped + image size. + ransac_max_trials : int, optional + Maximum number of RANSAC iterations used for the circle fit. + require_approval : bool, optional + Whether tool execution requires explicit approval in the agent + framework. 
+ """ + super().__init__(*args, require_approval=require_approval, **kwargs) + + self.image_acquisition_tool = image_acquisition_tool + self.zoom = zoom + self.gaussian_sigma_fraction = gaussian_sigma_fraction + self.ransac_residual_threshold_fraction = ransac_residual_threshold_fraction + self.ransac_max_trials = ransac_max_trials + self.latest_image: Optional[np.ndarray] = None + self.latest_circle_model: Optional[measure.CircleModel] = None + self.latest_circle_inliers: Optional[np.ndarray] = None + self.latest_processed_image: Optional[np.ndarray] = None + + def preprocess_image(self, image: np.ndarray) -> np.ndarray: + """Convert the input to a finite 2D float image. + + Parameters + ---------- + image : np.ndarray + Input image. If the input is 3D, the last axis is averaged. + + Returns + ------- + np.ndarray + A 2D floating-point image with non-finite values replaced by the + mean of the finite pixels. + """ + arr = np.asarray(image, dtype=float) + if arr.ndim == 3: + arr = np.mean(arr, axis=-1) + if arr.ndim != 2: + raise ValueError(f"Expected a 2D image, got shape {arr.shape}.") + + if not np.isfinite(arr).all(): + finite_mask = np.isfinite(arr) + if not finite_mask.any(): + raise ValueError("Input image does not contain any finite pixels.") + arr = arr.copy() + arr[~finite_mask] = float(np.mean(arr[finite_mask])) + return arr + + def get_input_image(self, image: Optional[np.ndarray]) -> np.ndarray: + """Return the image to process. + + Parameters + ---------- + image : Optional[np.ndarray] + Explicit image to fit. If omitted, ``image_acquisition_tool.image_k`` + is used. + + Returns + ------- + np.ndarray + Preprocessed 2D image. + """ + if image is not None: + return self.preprocess_image(image) + if self.image_acquisition_tool is None or self.image_acquisition_tool.image_k is None: + raise ValueError( + "No image was provided and image_acquisition_tool.image_k is not available." 
+ ) + return self.preprocess_image(self.image_acquisition_tool.image_k) + + def zoom_image(self, image: np.ndarray) -> np.ndarray: + """Zoom the image before processing. + + Parameters + ---------- + image : np.ndarray + Preprocessed 2D image in original coordinates. + + Returns + ------- + np.ndarray + Image in the processing scale. + """ + if self.zoom <= 0: + raise ValueError("zoom must be positive.") + if self.zoom == 1.0: + return image + return ndi.zoom(image, zoom=self.zoom, order=1, mode="nearest") + + @staticmethod + def normalize_image(image: np.ndarray) -> np.ndarray: + """Robustly normalize an image to ``[0, 1]`` using percentiles.""" + lo = float(np.percentile(image, 1)) + hi = float(np.percentile(image, 99)) + if hi <= lo: + lo = float(np.min(image)) + hi = float(np.max(image)) + if hi <= lo: + return np.zeros_like(image, dtype=float) + return np.clip((image - lo) / (hi - lo), 0.0, 1.0) + + def detect_edge_points(self, image: np.ndarray) -> np.ndarray: + """Detect candidate circle-arc points from the segmented landmark. + + Parameters + ---------- + image : np.ndarray + Preprocessed 2D image in original coordinates. + + Returns + ------- + np.ndarray + Boundary points as ``(x, y)`` pairs in original image coordinates. 
+ """ + sigma = max(0.5, self.gaussian_sigma_fraction * image.shape[1]) + blurred = ndi.gaussian_filter(image, sigma=sigma, mode="nearest") + normalized = self.normalize_image(blurred) + threshold = filters.threshold_otsu(normalized) + binary = normalized >= threshold + + structure = np.ones((3, 3), dtype=bool) + # binary = ndi.binary_erosion(binary, structure=structure, iterations=1) + # binary = ndi.binary_dilation(binary, structure=structure, iterations=1) + # binary = ndi.binary_closing(binary, structure=structure, iterations=1) + + labels, num_labels = ndi.label(binary, structure=structure) + if num_labels == 0: + raise ValueError("No connected component was found in the segmented image.") + + selected_label = None + rightmost_x = -1 + for label_id in range(1, num_labels + 1): + label_y, label_x = np.nonzero(labels == label_id) + if label_x.size == 0: + continue + label_rightmost_x = int(np.max(label_x)) + if label_rightmost_x > rightmost_x: + rightmost_x = label_rightmost_x + selected_label = label_id + + if selected_label is None: + raise ValueError("Failed to select a segmented landmark component.") + + component = labels == selected_label + boundary = component & ~ndi.binary_erosion(component, structure=structure, iterations=1) + valid_support = np.zeros_like(component, dtype=bool) + valid_support[1:-1, 1:-1] = True + boundary &= valid_support + + edge_y, edge_x = np.nonzero(boundary) + if edge_x.size < 3: + raise ValueError("Segmented landmark boundary has fewer than 3 candidate points.") + + points = np.column_stack((edge_x.astype(float), edge_y.astype(float))) + return points + + def fit_circle( + self, + points: np.ndarray, + cropped_shape: tuple[int, int], + ) -> tuple[measure.CircleModel, np.ndarray]: + """Fit a circle with RANSAC and return its parameters. + + Parameters + ---------- + points : np.ndarray + Candidate edge points as ``(x, y)`` pairs in the original image + frame. 
+ cropped_shape : tuple[int, int] + Shape of the image as ``(height, width)``. + + Returns + ------- + tuple[measure.CircleModel, np.ndarray] + The fitted circle model and its boolean inlier mask. + """ + residual_threshold = max( + 1.0, + self.ransac_residual_threshold_fraction * max(cropped_shape), + ) + model, inliers = measure.ransac( + points, + measure.CircleModel, + min_samples=3, + residual_threshold=residual_threshold, + max_trials=self.ransac_max_trials, + ) + if model is None or inliers is None or int(np.count_nonzero(inliers)) < 3: + raise ValueError("RANSAC could not find a valid circle from the detected edges.") + + center_x, center_y, radius = model.params + logger.debug( + "Circle fit: center=(%.3f, %.3f), radius=%.3f, inliers=%d/%d", + center_x, + center_y, + radius, + int(np.count_nonzero(inliers)), + len(points), + ) + return model, inliers + + def plot_last_fit(self) -> plt.Figure: + """Plot the most recent image with the fitted circle overlaid. + + Returns + ------- + matplotlib.figure.Figure + Figure showing the stored image and the last fitted circle. + """ + if self.latest_image is None or self.latest_circle_model is None: + raise ValueError("No fitted landmark is available. 
Call fit_landmark_center first.") + + center_x, center_y, radius = self.latest_circle_model.params + theta = np.linspace(0.0, 2.0 * np.pi, 512) + circle_x = center_x + radius * np.cos(theta) + circle_y = center_y + radius * np.sin(theta) + + fig, ax = plt.subplots(1, 1, squeeze=True) + ax.imshow(self.latest_image, cmap="viridis", origin="upper") + ax.plot(circle_x, circle_y, color="cyan", linewidth=1.5) + ax.scatter([center_x], [center_y], color="red", s=30) + ax.set_title("Landmark Circle Fit") + ax.set_xlim(-0.5, self.latest_image.shape[1] - 0.5) + ax.set_ylim(self.latest_image.shape[0] - 0.5, -0.5) + return fig + + @tool(name="fit_landmark_center", return_type=ToolReturnType.LIST) + def fit_landmark_center( + self, + image: Annotated[ + Optional[np.ndarray], + "Optional 2D image array. When omitted, the tool uses image_acquisition_tool.image_k.", + ] = None, + ) -> Annotated[ + list[float], + "The fitted landmark center as [center_y, center_x] in the coordinates of the original image.", + ]: + """Detect the right-side disk feature and return its center. + + Parameters + ---------- + image : Optional[np.ndarray], optional + Explicit image array to process. If omitted, the latest image from + ``image_acquisition_tool`` is used. + + Returns + ------- + list[float] + Landmark center as ``[center_y, center_x]`` in the coordinates of + the original image. 
+ """ + image_arr = self.get_input_image(image) + processed_image = self.zoom_image(image_arr) + points = self.detect_edge_points(processed_image) + model, inliers = self.fit_circle( + points, + cropped_shape=processed_image.shape, + ) + center_x, center_y, radius = model.params + circle_model = measure.CircleModel() + circle_model.params = ( + float(center_x / self.zoom), + float(center_y / self.zoom), + float(radius / self.zoom), + ) + self.latest_image = image_arr + self.latest_processed_image = processed_image + self.latest_circle_model = circle_model + self.latest_circle_inliers = inliers + return [float(center_y / self.zoom), float(center_x / self.zoom)] From e9467cef7bf0c4d2b9d8b9d40e89adf6ff71b69e Mon Sep 17 00:00:00 2001 From: Ming Du Date: Mon, 9 Mar 2026 14:16:37 -0500 Subject: [PATCH 10/15] FEAT: integrate test pattern landmark fitting tool (replace NN registration) --- .../tuning/analytical_focusing.py | 159 +++++++++++------- .../aps_mic/test_target_landmark_fitting.py | 120 +++++++++++-- 2 files changed, 207 insertions(+), 72 deletions(-) diff --git a/src/eaa/task_manager/tuning/analytical_focusing.py b/src/eaa/task_manager/tuning/analytical_focusing.py index ce62b83..47abf60 100644 --- a/src/eaa/task_manager/tuning/analytical_focusing.py +++ b/src/eaa/task_manager/tuning/analytical_focusing.py @@ -18,7 +18,9 @@ from sciagent.message_proc import print_message from eaa.tool.imaging.acquisition import AcquireImage -from eaa.tool.imaging.nn_registration import NNRegistration +from eaa.tool.imaging.aps_mic.test_target_landmark_fitting import ( + TestPatternLandmarkFitting, +) from eaa.tool.imaging.param_tuning import SetParameters from eaa.task_manager.tuning.base import BaseParameterTuningTaskManager from eaa.tool.imaging.registration import ImageRegistration @@ -62,9 +64,9 @@ def __init__( registration_target: Literal["previous", "initial"] = "previous", run_line_scan_checker: bool = True, run_offset_calibration: bool = True, - nn_registration_tool: 
Optional[NNRegistration] = None, + landmark_fitting_tool: Optional[TestPatternLandmarkFitting] = None, dual_drift_selection_priming_iterations: int = 3, - dual_drift_estimation_primary_source: Literal["registration", "nn_registration"] = "nn_registration", + dual_drift_estimation_primary_source: Literal["registration", "landmark_fitting"] = "landmark_fitting", *args, **kwargs ): """Analytical scanning microscope focusing task manager driven @@ -147,13 +149,14 @@ def __init__( dual_drift_selection_priming_iterations : int, optional Number of parameter–drift samples to collect (warm-up phase) before the linear model is used to arbitrate between image registration and - NN registration drift estimates. Only relevant when - ``nn_registration_tool`` is provided. - nn_registration_tool : NNRegistration, optional - If provided, both classical image registration and the NN-based + landmark fitting drift estimates. Only relevant when + ``landmark_fitting_tool`` is provided. + landmark_fitting_tool : TestPatternLandmarkFitting, optional + If provided, both classical image registration and the landmark-based registration are run after each 2D acquisition to estimate drift. - The NN model always registers the current image against the very - first (initial) image, so its output is a cumulative drift estimate. + The landmark-fitting branch estimates the disk center in the + current image and in the selected reference image, then subtracts + those centers to obtain a physical drift estimate. For the first ``dual_drift_selection_priming_iterations`` parameter adjustments only the source selected by ``dual_drift_estimation_primary_source`` is used; concurrently a @@ -163,12 +166,12 @@ def __init__( that is closer to the linear model prediction is chosen as the correct drift, and the linear model is then updated with that chosen value. Requires ``run_offset_calibration=True``. 
- dual_drift_estimation_primary_source : {"registration", "nn_registration"} + dual_drift_estimation_primary_source : {"registration", "landmark_fitting"} Which drift-estimation method to trust exclusively during the first ``dual_drift_selection_priming_iterations`` iterations when - ``nn_registration_tool`` is provided. After that point both + ``landmark_fitting_tool`` is provided. After that point both methods are evaluated and the linear model arbitrates. Defaults - to ``"nn_registration"``. + to ``"landmark_fitting"``. """ if acquisition_tool is None: raise ValueError("`acquisition_tool` must be provided.") @@ -212,13 +215,18 @@ def __init__( self.run_line_scan_checker = run_line_scan_checker self.run_offset_calibration = run_offset_calibration - if nn_registration_tool is not None and not run_offset_calibration: + if landmark_fitting_tool is not None and not run_offset_calibration: raise ValueError( - "`nn_registration_tool` requires `run_offset_calibration=True` " + "`landmark_fitting_tool` requires `run_offset_calibration=True` " "because a 2D image must be acquired before the tool can run." 
) - self.nn_registration_tool = nn_registration_tool + self.landmark_fitting_tool = landmark_fitting_tool + if ( + self.landmark_fitting_tool is not None + and self.landmark_fitting_tool.image_acquisition_tool is None + ): + self.landmark_fitting_tool.image_acquisition_tool = acquisition_tool self.dual_drift_estimation_primary_source = dual_drift_estimation_primary_source self.dual_drift_selection_priming_iterations = ( dual_drift_selection_priming_iterations @@ -266,7 +274,8 @@ def create_image_registration_tool( reference_pixel_size=1.0, image_coordinates_origin="top_left", registration_method=registration_method, - log_scale=True + log_scale=True, + zoom=4 ) return image_registration_tool @@ -501,7 +510,7 @@ def update_linear_drift_models( parameters: np.ndarray, current_position_yx: np.ndarray | None = None, ): - if self.nn_registration_tool is None: + if self.landmark_fitting_tool is None: return if self.initial_line_scan_position is None: return @@ -777,17 +786,17 @@ def get_suggested_next_parameters(self, step_size_limit: Optional[float | Tuple[ def find_position_correction( self, - method: Literal["traditional", "nn"] = "traditional", + method: Literal["traditional", "landmark"] = "traditional", target: Literal["previous", "initial"] = "previous", ) -> tuple[np.ndarray, np.ndarray]: """Find the correction that should be applied (added) to image acquisition and line scan positions. Parameters ---------- - method : Literal["traditional", "nn"] + method : Literal["traditional", "landmark"] The registration algorithm to use. "traditional": use :attr:`image_registration_tool` (classical image registration). - "nn": use :attr:`nn_registration_tool` (NN-based registration). + "landmark": use :attr:`landmark_fitting_tool` (landmark-based registration). target : Literal["previous", "initial"] The reference image to register against. 
"previous": register the current image against the immediately @@ -842,15 +851,41 @@ def find_position_correction( ), dtype=float, ) - elif method == "nn": - if self.nn_registration_tool is None: + elif method == "landmark": + if self.landmark_fitting_tool is None: raise RuntimeError( - "`nn_registration_tool` is not set. Provide one in the constructor " - "to use method='nn'." + "`landmark_fitting_tool` is not set. Provide one in the constructor " + "to use method='landmark'." ) - alignment_offset = self.nn_registration_tool.get_offset(target=target) + alignment_offset = np.array( + self.landmark_fitting_tool.get_offset(target=target), + dtype=float, + ) + landmark_fig = None + try: + landmark_fig = self.landmark_fitting_tool.plot_last_fit() + landmark_fig_path = BaseTool.save_image_to_temp_dir( + fig=landmark_fig, + filename="landmark_fitting_overlay.png", + add_timestamp=True, + ) + self.record_system_message( + content=( + "Landmark fitting result for drift estimation:\n" + f"```target={target}\n" + f"alignment_offset={alignment_offset.tolist()}```" + ), + image_path=landmark_fig_path, + ) + except Exception as exc: + logger.warning("Failed to render landmark fitting overlay: %s", exc) + finally: + if landmark_fig is not None: + plt.close(landmark_fig) else: - raise ValueError(f"`method` must be 'traditional' or 'nn', got {method!r}.") + raise ValueError( + f"`method` must be 'traditional' or 'landmark', got {method!r}." + ) # Image registration offset is the offset by which the moving image should be rolled to match # the reference. We want acquisition position correction here, which is the negation of it. 
@@ -946,7 +981,7 @@ def rollback_and_shrink_delta(message_prefix: str) -> np.ndarray: def _select_drift( self, - nn_drift: np.ndarray, + landmark_drift: np.ndarray, reg_drift: np.ndarray | None, x_current: np.ndarray, ) -> tuple[np.ndarray, str]: @@ -959,8 +994,8 @@ def _select_drift( Parameters ---------- - nn_drift : np.ndarray - Cumulative drift from NN registration. + landmark_drift : np.ndarray + Cumulative drift from landmark fitting. reg_drift : np.ndarray or None Cumulative drift from classical image registration, or ``None`` on failure. x_current : np.ndarray @@ -970,25 +1005,25 @@ def _select_drift( ------- chosen_drift : np.ndarray chosen_source : str - ``"nn_registration"`` or ``"registration"``. + ``"landmark_fitting"`` or ``"registration"``. """ n_collected = self.drift_model_y.get_n_parameter_drift_points_collected() n_needed = self.dual_drift_selection_priming_iterations if n_collected < n_needed: - if self.dual_drift_estimation_primary_source == "nn_registration" or reg_drift is None: - chosen_drift, chosen_source = nn_drift, "nn_registration" + if self.dual_drift_estimation_primary_source == "landmark_fitting" or reg_drift is None: + chosen_drift, chosen_source = landmark_drift, "landmark_fitting" else: chosen_drift, chosen_source = reg_drift, "registration" if reg_drift is None: self.record_system_message( "Dual drift estimation: image registration returned NaN; " - "falling back to nn_registration." + "falling back to landmark_fitting." 
) self.record_system_message( f"Dual drift estimation (primary phase, n={n_collected}/{n_needed}): " f"```using {chosen_source}\n" - f"nn_drift={nn_drift.tolist()}\n" + f"landmark_drift={landmark_drift.tolist()}\n" f"reg_drift={reg_drift.tolist() if reg_drift is not None else 'NaN'}```" ) else: @@ -1000,21 +1035,21 @@ def _select_drift( ], dtype=float, ) - dist_nn = float(np.linalg.norm(nn_drift - model_drift)) + dist_landmark = float(np.linalg.norm(landmark_drift - model_drift)) dist_reg = ( float(np.linalg.norm(reg_drift - model_drift)) if reg_drift is not None else np.inf ) - if dist_nn <= dist_reg: - chosen_drift, chosen_source = nn_drift, "nn_registration" + if dist_landmark <= dist_reg: + chosen_drift, chosen_source = landmark_drift, "landmark_fitting" else: chosen_drift, chosen_source = reg_drift, "registration" dist_reg_str = f"{dist_reg:.4f}" if reg_drift is not None else "inf" self.record_system_message( f"Dual drift estimation (arbitration phase):\n" f"```model_drift={model_drift.tolist()}\n" - f"nn_drift={nn_drift.tolist()} (dist={dist_nn:.4f})\n" + f"landmark_drift={landmark_drift.tolist()} (dist={dist_landmark:.4f})\n" f"reg_drift={reg_drift.tolist() if reg_drift is not None else 'NaN'} (dist={dist_reg_str})\n" f"Chosen: {chosen_source}```" ) @@ -1073,55 +1108,57 @@ def apply_drift_correction(self, x_current: np.ndarray) -> None: - self.extract_image_acquisition_position(self.image_acquisition_kwargs) ) - # Get the offset from the NN registration tool using the same target as classical registration. - if self.nn_registration_tool is not None: - line_scan_correction_nn, image_acq_correction_nn = self.find_position_correction( - method="nn", target=self.registration_target + # Get the offset from the landmark fitting tool using the same target as classical registration. 
+ if self.landmark_fitting_tool is not None: + line_scan_correction_landmark, image_acq_correction_landmark = self.find_position_correction( + method="landmark", target=self.registration_target ) if self.registration_target == "previous": - line_scan_correction_nn_wrt_prev = line_scan_correction_nn - line_scan_correction_nn_wrt_initial = ( + line_scan_correction_landmark_wrt_prev = line_scan_correction_landmark + line_scan_correction_landmark_wrt_initial = ( self.extract_line_scan_position(self.line_scan_kwargs) - + line_scan_correction_nn + + line_scan_correction_landmark - self.initial_line_scan_position ) - image_acq_correction_nn_wrt_prev = image_acq_correction_nn - image_acq_correction_nn_wrt_initial = ( - image_acq_correction_nn + image_acq_correction_landmark_wrt_prev = image_acq_correction_landmark + image_acq_correction_landmark_wrt_initial = ( + image_acq_correction_landmark + self.extract_image_acquisition_position(self.image_acquisition_kwargs) - self.initial_image_acquisition_position ) else: # "initial" - line_scan_correction_nn_wrt_initial = line_scan_correction_nn - line_scan_correction_nn_wrt_prev = ( + line_scan_correction_landmark_wrt_initial = line_scan_correction_landmark + line_scan_correction_landmark_wrt_prev = ( self.initial_line_scan_position - + line_scan_correction_nn + + line_scan_correction_landmark - self.extract_line_scan_position(self.line_scan_kwargs) ) - image_acq_correction_nn_wrt_initial = image_acq_correction_nn - image_acq_correction_nn_wrt_prev = ( + image_acq_correction_landmark_wrt_initial = image_acq_correction_landmark + image_acq_correction_landmark_wrt_prev = ( self.initial_image_acquisition_position - + image_acq_correction_nn + + image_acq_correction_landmark - self.extract_image_acquisition_position(self.image_acquisition_kwargs) ) # Select the best correction to use. 
chosen_line_scan_correction_wrt_initial, chosen_source = self._select_drift( - line_scan_correction_nn_wrt_initial, line_scan_correction_reg_wrt_initial, x_current + line_scan_correction_landmark_wrt_initial, + line_scan_correction_reg_wrt_initial, + x_current, ) chosen_line_scan_correction_wrt_prev = ( - line_scan_correction_nn_wrt_prev - if chosen_source == "nn_registration" + line_scan_correction_landmark_wrt_prev + if chosen_source == "landmark_fitting" else line_scan_correction_reg_wrt_prev ) chosen_image_acq_correction_wrt_prev = ( - image_acq_correction_nn_wrt_prev - if chosen_source == "nn_registration" + image_acq_correction_landmark_wrt_prev + if chosen_source == "landmark_fitting" else image_acq_correction_reg_wrt_prev ) chosen_image_acq_correction_wrt_initial = ( - image_acq_correction_nn_wrt_initial - if chosen_source == "nn_registration" + image_acq_correction_landmark_wrt_initial + if chosen_source == "landmark_fitting" else image_acq_correction_reg_wrt_initial ) else: diff --git a/src/eaa/tool/imaging/aps_mic/test_target_landmark_fitting.py b/src/eaa/tool/imaging/aps_mic/test_target_landmark_fitting.py index 3e3ab8c..a478e9d 100644 --- a/src/eaa/tool/imaging/aps_mic/test_target_landmark_fitting.py +++ b/src/eaa/tool/imaging/aps_mic/test_target_landmark_fitting.py @@ -1,4 +1,4 @@ -from typing import Annotated, Optional +from typing import Annotated, Literal, Optional import logging import matplotlib.pyplot as plt @@ -149,6 +149,34 @@ def zoom_image(self, image: np.ndarray) -> np.ndarray: return image return ndi.zoom(image, zoom=self.zoom, order=1, mode="nearest") + def resolve_pixel_size( + self, + pixel_size: Optional[float], + image_role: Literal["current", "previous", "initial"], + ) -> float: + """Resolve the pixel size for converting pixel coordinates to physical units.""" + if pixel_size is not None: + return float(pixel_size) + if self.image_acquisition_tool is None: + raise ValueError( + "pixel_size must be provided when 
image_acquisition_tool is unavailable." + ) + + if image_role == "current": + resolved = self.image_acquisition_tool.psize_k + elif image_role == "previous": + resolved = self.image_acquisition_tool.psize_km1 + elif image_role == "initial": + resolved = self.image_acquisition_tool.psize_0 + else: + raise ValueError(f"Unsupported image_role: {image_role!r}.") + + if resolved is None: + raise ValueError( + f"Pixel size for image_role={image_role!r} is unavailable." + ) + return float(resolved) + @staticmethod def normalize_image(image: np.ndarray) -> np.ndarray: """Robustly normalize an image to ``[0, 1]`` using percentiles.""" @@ -293,9 +321,17 @@ def fit_landmark_center( Optional[np.ndarray], "Optional 2D image array. When omitted, the tool uses image_acquisition_tool.image_k.", ] = None, + pixel_size: Annotated[ + Optional[float], + "Pixel size used to convert the fitted center from pixels to physical units.", + ] = None, + image_role: Annotated[ + Literal["current", "previous", "initial"], + "Which acquisition buffer the image corresponds to when pixel_size is omitted.", + ] = "current", ) -> Annotated[ list[float], - "The fitted landmark center as [center_y, center_x] in the coordinates of the original image.", + "The fitted landmark center as [center_y, center_x] in physical units.", ]: """Detect the right-side disk feature and return its center. @@ -304,29 +340,91 @@ def fit_landmark_center( image : Optional[np.ndarray], optional Explicit image array to process. If omitted, the latest image from ``image_acquisition_tool`` is used. + pixel_size : Optional[float], optional + Pixel size in physical units per pixel. If omitted, the value is + taken from ``image_acquisition_tool`` according to ``image_role``. + image_role : {"current", "previous", "initial"}, optional + Acquisition-buffer role used to resolve the pixel size when + ``pixel_size`` is omitted. 
Returns ------- list[float] Landmark center as ``[center_y, center_x]`` in the coordinates of - the original image. + the original image, expressed in physical units. """ image_arr = self.get_input_image(image) + resolved_pixel_size = self.resolve_pixel_size(pixel_size, image_role) processed_image = self.zoom_image(image_arr) points = self.detect_edge_points(processed_image) model, inliers = self.fit_circle( points, cropped_shape=processed_image.shape, ) - center_x, center_y, radius = model.params - circle_model = measure.CircleModel() - circle_model.params = ( - float(center_x / self.zoom), - float(center_y / self.zoom), - float(radius / self.zoom), + center_x_px, center_y_px, radius_px = model.params + circle_model_px = measure.CircleModel() + circle_model_px.params = ( + float(center_x_px / self.zoom), + float(center_y_px / self.zoom), + float(radius_px / self.zoom), ) self.latest_image = image_arr self.latest_processed_image = processed_image - self.latest_circle_model = circle_model + self.latest_circle_model = circle_model_px self.latest_circle_inliers = inliers - return [float(center_y / self.zoom), float(center_x / self.zoom)] + return [ + float(circle_model_px.params[1] * resolved_pixel_size), + float(circle_model_px.params[0] * resolved_pixel_size), + ] + + @tool(name="get_offset", return_type=ToolReturnType.LIST) + def get_offset( + self, + target: Annotated[ + Literal["previous", "initial"], + "Reference image buffer against which the current image is compared.", + ] = "initial", + ) -> Annotated[ + list[float], + "The landmark-based registration offset [dy, dx] in physical units.", + ]: + """Return the shift to apply to the test image to match the reference. + + The returned offset follows the same convention as the other image + registration tools in this codebase: it is the physical-space + translation ``[dy, dx]`` that should be applied to the current (test) + image so it aligns with the selected reference image. 
+ """ + if self.image_acquisition_tool is None: + raise RuntimeError( + "image_acquisition_tool is required to compare current and reference images." + ) + if self.image_acquisition_tool.image_k is None: + raise RuntimeError("Current image buffer (image_k) is not populated.") + + if target == "previous": + reference_image = self.image_acquisition_tool.image_km1 + reference_role: Literal["previous", "initial"] = "previous" + elif target == "initial": + reference_image = self.image_acquisition_tool.image_0 + reference_role = "initial" + else: + raise ValueError(f"`target` must be 'previous' or 'initial', got {target!r}.") + + if reference_image is None: + raise RuntimeError( + f"Reference image buffer for target={target!r} is not populated." + ) + + center_ref = np.array( + self.fit_landmark_center(image=reference_image, image_role=reference_role), + dtype=float, + ) + center_test = np.array( + self.fit_landmark_center( + image=self.image_acquisition_tool.image_k, + image_role="current", + ), + dtype=float, + ) + return (center_ref - center_test).tolist() From b0ad0d0f77b15fd5679851e3aab72e8a3cbf615d Mon Sep 17 00:00:00 2001 From: Ming Du Date: Mon, 9 Mar 2026 14:17:16 -0500 Subject: [PATCH 11/15] FEAT: allow zoom in registration tool --- src/eaa/tool/imaging/registration.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/eaa/tool/imaging/registration.py b/src/eaa/tool/imaging/registration.py index 7e8f38e..15a5847 100644 --- a/src/eaa/tool/imaging/registration.py +++ b/src/eaa/tool/imaging/registration.py @@ -42,6 +42,7 @@ def __init__( reference_pixel_size: float = 1.0, image_coordinates_origin: Literal["top_left", "center"] = "top_left", registration_method: Literal["phase_correlation", "sift", "mutual_information", "llm", "error_minimization"] = "phase_correlation", + zoom: float = 1.0, log_scale: bool = False, require_approval: bool = False, *args, @@ -76,6 +77,9 @@ def __init__( "mutual_information" uses pyramid-based 
normalized mutual information, and "error_minimization" uses exhaustive integer-shift MSE search with local quadratic subpixel refinement. + zoom : float, optional + Zoom factor applied to both images before registration. Returned + offsets are scaled back to the original image coordinates. log_scale : bool, optional If True, images are transformed as `log10(x + 1)` before registration. """ @@ -88,6 +92,7 @@ def __init__( self.reference_pixel_size = reference_pixel_size self.image_coordinates_origin = image_coordinates_origin self.registration_method = registration_method + self.zoom = zoom self.log_scale = log_scale def set_reference_image( @@ -123,6 +128,14 @@ def process_image(self, image: np.ndarray) -> np.ndarray: image = np.log10(image + 1) return image + def zoom_image(self, image: np.ndarray) -> np.ndarray: + """Apply the configured registration zoom factor to an image.""" + if self.zoom <= 0: + raise ValueError("zoom must be positive.") + if self.zoom == 1.0: + return image + return ndi.zoom(image, zoom=self.zoom, order=1, mode="nearest") + @tool(name="get_offset_of_latest_image", return_type=ToolReturnType.LIST) def get_offset_of_latest_image( self, @@ -398,7 +411,10 @@ def register_images( if psize_t != psize_r: # Resize the target image to have the same pixel size as the reference image image_t = ndi.zoom(image_t, psize_t / psize_r) - + + image_t = self.zoom_image(image_t) + image_r = self.zoom_image(image_r) + if method in {"phase_correlation", "mutual_information", "error_minimization"}: image_t = self.reconcile_image_shape(image_t, image_r.shape) @@ -445,7 +461,7 @@ def register_images( offset = self.register_images_llm(image_t=image_t, image_r=image_r) else: raise ValueError(f"Invalid registration method: {method}") - return offset + return np.array(offset, dtype=float) / self.zoom def reconcile_image_shape( self, From 3d608d5d73b81b9f7f6f538f5435ab42bcf11d04 Mon Sep 17 00:00:00 2001 From: Ming Du Date: Mon, 9 Mar 2026 15:43:47 -0500 Subject: [PATCH 
12/15] FIX: robust Gaussian fitting --- src/eaa/maths.py | 111 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 96 insertions(+), 15 deletions(-) diff --git a/src/eaa/maths.py b/src/eaa/maths.py index c798aba..922dd0f 100644 --- a/src/eaa/maths.py +++ b/src/eaa/maths.py @@ -1,6 +1,7 @@ import logging import numpy as np +import scipy.ndimage import scipy.optimize logger = logging.getLogger(__name__) @@ -35,7 +36,7 @@ def fit_gaussian_1d( y: np.ndarray, y_threshold: float = 0, ) -> tuple[float, float, float, float, float, float, float]: - """Fit a 1D Gaussian to the data after subtracting a linear background. + """Fit a 1D Gaussian to 1D data. Parameters ---------- @@ -58,31 +59,111 @@ def fit_gaussian_1d( """ x_data = np.array(x, dtype=float) y_data = np.array(y, dtype=float) + finite_mask = np.isfinite(x_data) & np.isfinite(y_data) + x_data = x_data[finite_mask] + y_data = y_data[finite_mask] + if x_data.size < 5: + logger.error("Too few finite data points for Gaussian fitting. Returning NaN values.") + return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan + + order = np.argsort(x_data) + x_data = x_data[order] + y_data = y_data[order] x_min = float(np.min(x_data)) x_max_input = float(np.max(x_data)) + x_span = x_max_input - x_min + if not np.isfinite(x_span) or x_span <= 0: + logger.error("Invalid x range for Gaussian fitting. Returning NaN values.") + return np.nan, np.nan, np.nan, np.nan, np.nan, x_min, x_max_input + y_max, y_min = np.max(y_data), np.min(y_data) - x_max = x_data[np.argmax(y_data)] - offset = x_max + y_range = float(y_max - y_min) + if not np.isfinite(y_range) or np.isclose(y_range, 0.0): + logger.error("Input data are too flat for Gaussian fitting. 
Returning NaN values.") + return np.nan, np.nan, np.nan, np.nan, np.nan, x_min, x_max_input + + smooth_sigma = max(1.0, x_data.size * 0.02) + y_smooth = scipy.ndimage.gaussian_filter1d(y_data, sigma=smooth_sigma, mode="nearest") + y_smooth_max = float(np.max(y_smooth)) + y_smooth_min = float(np.min(y_smooth)) + y_smooth_range = y_smooth_max - y_smooth_min + if not np.isfinite(y_smooth_range) or np.isclose(y_smooth_range, 0.0): + logger.error("Smoothed data are too flat for Gaussian fitting. Returning NaN values.") + return np.nan, np.nan, np.nan, np.nan, np.nan, x_min, x_max_input + + x_peak = float(x_data[np.argmax(y_smooth)]) + offset = x_peak x = x_data - offset - x_max = 0 - mask = y_data >= y_min + y_threshold * (y_max - y_min) - a_guess = y_max - y_min - mu_guess = x_max - x_above_thresh = x[y_data > y_min + a_guess * 0.2] - if len(x_above_thresh) >= 3: - sigma_guess = (x_above_thresh.max() - x_above_thresh.min()) / 2 + fit_threshold = y_smooth_min + y_threshold * y_smooth_range + mask = y_smooth >= fit_threshold + if int(np.count_nonzero(mask)) < 5: + mask = np.ones_like(y_data, dtype=bool) + + positive_weight = np.clip(y_smooth - y_smooth_min, a_min=0.0, a_max=None) + if np.sum(positive_weight) > 0: + mu_guess = float(np.sum(x * positive_weight) / np.sum(positive_weight)) else: - sigma_guess = (x.max() - x.min()) / 2 - c_guess = y_min + mu_guess = 0.0 + + width_mask = y_smooth >= (y_smooth_min + 0.5 * y_smooth_range) + x_above_half = x[width_mask] + if x_above_half.size >= 2: + sigma_guess = float((x_above_half.max() - x_above_half.min()) / 2.355) + else: + sigma_guess = x_span / 6.0 + sigma_guess = float(np.clip(sigma_guess, x_span / 100.0, x_span)) + + a_guess = max(y_smooth_range, np.finfo(float).eps) + c_guess = float(np.median(y_data[~mask])) if np.any(~mask) else float(y_smooth_min) p0 = [a_guess, mu_guess, sigma_guess, c_guess] - try: - popt, _ = scipy.optimize.curve_fit(gaussian_1d, x[mask], y_data[mask], p0=p0) - except RuntimeError: + lower_bounds 
= [0.0, float(np.min(x)), max(x_span / 1000.0, 1e-12), float(y_min - y_range)] + upper_bounds = [float(2 * y_range + abs(y_max)), float(np.max(x)), float(2 * x_span), float(y_max + y_range)] + + def run_fit(current_mask: np.ndarray, current_p0: list[float]) -> np.ndarray | None: + if int(np.count_nonzero(current_mask)) < 5: + return None + try: + popt, _ = scipy.optimize.curve_fit( + gaussian_1d, + x[current_mask], + y_data[current_mask], + p0=current_p0, + bounds=(lower_bounds, upper_bounds), + maxfev=20000, + ) + return popt + except (RuntimeError, ValueError): + return None + + popt = run_fit(mask, p0) + if popt is None: + y_smooth_retry = scipy.ndimage.gaussian_filter1d( + y_data, + sigma=max(2.0, x_data.size * 0.05), + mode="nearest", + ) + retry_weight = np.clip(y_smooth_retry - np.min(y_smooth_retry), a_min=0.0, a_max=None) + if np.sum(retry_weight) > 0: + mu_retry = float(np.sum(x * retry_weight) / np.sum(retry_weight)) + else: + mu_retry = 0.0 + retry_mask = y_smooth_retry >= ( + np.min(y_smooth_retry) + max(0.1, y_threshold) * (np.max(y_smooth_retry) - np.min(y_smooth_retry)) + ) + retry_p0 = [a_guess, mu_retry, sigma_guess, c_guess] + popt = run_fit(retry_mask, retry_p0) + + if popt is None: logger.error("Failed to fit Gaussian to data. Returning NaN values.") return np.nan, np.nan, np.nan, np.nan, np.nan, x_min, x_max_input y_fit = gaussian_1d(x, *popt) amplitude = float(popt[0]) + sigma = float(popt[2]) + mu = float(popt[1]) + if sigma <= 0 or not (float(np.min(x)) <= mu <= float(np.max(x))): + logger.error("Gaussian fit parameters are invalid. 
Returning NaN values.") + return np.nan, np.nan, np.nan, np.nan, np.nan, x_min, x_max_input if np.isclose(amplitude, 0.0): normalized_residual = np.nan else: From de5b543379f98dca500881be3ddd10989a641c68 Mon Sep 17 00:00:00 2001 From: Ming Du Date: Mon, 9 Mar 2026 15:45:04 -0500 Subject: [PATCH 13/15] REFACTOR: refactor analytical focusing task manager around common registration tool API --- .../tuning/analytical_focusing.py | 486 ++++++++---------- src/eaa/tool/imaging/nn_registration.py | 2 +- src/eaa/tool/imaging/registration.py | 46 +- tests/test_image_registration_tool.py | 6 +- 4 files changed, 231 insertions(+), 309 deletions(-) diff --git a/src/eaa/task_manager/tuning/analytical_focusing.py b/src/eaa/task_manager/tuning/analytical_focusing.py index 47abf60..ce1c828 100644 --- a/src/eaa/task_manager/tuning/analytical_focusing.py +++ b/src/eaa/task_manager/tuning/analytical_focusing.py @@ -21,6 +21,7 @@ from eaa.tool.imaging.aps_mic.test_target_landmark_fitting import ( TestPatternLandmarkFitting, ) +from eaa.tool.imaging.nn_registration import NNRegistration from eaa.tool.imaging.param_tuning import SetParameters from eaa.task_manager.tuning.base import BaseParameterTuningTaskManager from eaa.tool.imaging.registration import ImageRegistration @@ -34,6 +35,9 @@ logger = logging.getLogger(__name__) +RegistrationToolType = ImageRegistration | TestPatternLandmarkFitting | NNRegistration + + class LineScanValidationFailed(RuntimeError): pass @@ -59,14 +63,12 @@ def __init__( line_scan_tool_y_coordinate_args: Tuple[str, ...] = ("y_center",), image_acquisition_tool_x_coordinate_args: Tuple[str, ...] = ("x_center",), image_acquisition_tool_y_coordinate_args: Tuple[str, ...] 
= ("y_center",), - registration_method: Literal["phase_correlation", "sift", "mutual_information", "llm"] = "phase_correlation", - registration_algorithm_kwargs: Optional[dict[str, Any]] = None, registration_target: Literal["previous", "initial"] = "previous", run_line_scan_checker: bool = True, run_offset_calibration: bool = True, - landmark_fitting_tool: Optional[TestPatternLandmarkFitting] = None, - dual_drift_selection_priming_iterations: int = 3, - dual_drift_estimation_primary_source: Literal["registration", "landmark_fitting"] = "landmark_fitting", + registration_tools: Optional[list[RegistrationToolType]] = None, + registration_selection_priming_iterations: int = 3, + primary_registration_tool_index: int = 0, *args, **kwargs ): """Analytical scanning microscope focusing task manager driven @@ -128,9 +130,6 @@ def __init__( See `line_scan_tool_x_coordinate_args`. image_acquisition_tool_y_coordinate_args: Tuple[str, ...] See `line_scan_tool_y_coordinate_args`. - registration_algorithm_kwargs : Optional[dict[str, Any]] - Optional keyword arguments forwarded to the selected image - registration algorithm when aligning consecutive 2D scans. registration_target : Literal["previous", "initial"], optional The reference image used by the registration branch of drift correction. "previous" (default) registers each new 2D scan @@ -146,32 +145,22 @@ def __init__( If True, run 2D image acquisition and image-registration-based offset calibration. If False, the loop only performs parameter setting, line scan, and optimization updates/suggestions. - dual_drift_selection_priming_iterations : int, optional - Number of parameter–drift samples to collect (warm-up phase) before - the linear model is used to arbitrate between image registration and - landmark fitting drift estimates. Only relevant when - ``landmark_fitting_tool`` is provided. 
- landmark_fitting_tool : TestPatternLandmarkFitting, optional - If provided, both classical image registration and the landmark-based - registration are run after each 2D acquisition to estimate drift. - The landmark-fitting branch estimates the disk center in the - current image and in the selected reference image, then subtracts - those centers to obtain a physical drift estimate. - For the first ``dual_drift_selection_priming_iterations`` parameter - adjustments only the source selected by - ``dual_drift_estimation_primary_source`` is used; concurrently a - linear model is fit mapping optics parameters to cumulative drift - from the initial position. Afterwards the linear model prediction - is used to arbitrate: the drift estimate (from either method) - that is closer to the linear model prediction is chosen as the - correct drift, and the linear model is then updated with that - chosen value. Requires ``run_offset_calibration=True``. - dual_drift_estimation_primary_source : {"registration", "landmark_fitting"} - Which drift-estimation method to trust exclusively during the first - ``dual_drift_selection_priming_iterations`` iterations when - ``landmark_fitting_tool`` is provided. After that point both - methods are evaluated and the linear model arbitrates. Defaults - to ``"landmark_fitting"``. + registration_tools : list[RegistrationToolType], optional + Registration tools to use for drift estimation. Each tool must + provide ``get_offset(target=...)`` returning the shift to apply to + the current image so it aligns with the selected reference image. + Supported tool classes are :class:`ImageRegistration`, + :class:`TestPatternLandmarkFitting`, and :class:`NNRegistration`. + The caller is responsible for instantiating these tools with the + desired configuration. + registration_selection_priming_iterations : int, optional + Number of parameter-drift samples to collect before the linear + model is used to arbitrate among multiple registration tools. 
+ During this warm-up phase the tool at + ``primary_registration_tool_index`` is used. + primary_registration_tool_index : int, optional + Index of the registration tool to trust during the warm-up phase + before the linear drift model has enough samples to arbitrate. """ if acquisition_tool is None: raise ValueError("`acquisition_tool` must be provided.") @@ -181,14 +170,6 @@ def __init__( self.optimization_tool = self.create_bo_tool(parameter_ranges) else: self.optimization_tool = optimization_tool - self.image_registration_tool = self.create_image_registration_tool( - acquisition_tool, - llm_config=llm_config, - registration_method=registration_method, - ) - self.registration_algorithm_kwargs = copy.deepcopy( - registration_algorithm_kwargs or {} - ) if registration_target not in ("previous", "initial"): raise ValueError( f"`registration_target` must be 'previous' or 'initial', got {registration_target!r}." @@ -215,25 +196,42 @@ def __init__( self.run_line_scan_checker = run_line_scan_checker self.run_offset_calibration = run_offset_calibration - if landmark_fitting_tool is not None and not run_offset_calibration: + self.registration_tools = list(registration_tools or []) + if self.run_offset_calibration and len(self.registration_tools) == 0: raise ValueError( - "`landmark_fitting_tool` requires `run_offset_calibration=True` " + "`registration_tools` must be provided when `run_offset_calibration=True`." + ) + if len(self.registration_tools) > 1 and not run_offset_calibration: + raise ValueError( + "Multiple `registration_tools` require `run_offset_calibration=True` " "because a 2D image must be acquired before the tool can run." ) + for registration_tool in self.registration_tools: + if not hasattr(registration_tool, "get_offset"): + raise ValueError( + "Each registration tool must provide a `get_offset(target=...)` method, " + f"got {type(registration_tool).__name__}." 
+ ) + if hasattr(registration_tool, "image_acquisition_tool") and getattr( + registration_tool, + "image_acquisition_tool", + ) is None: + registration_tool.image_acquisition_tool = acquisition_tool + if isinstance(registration_tool, ImageRegistration): + registration_tool.llm_config = llm_config + registration_tool.memory_config = memory_config - self.landmark_fitting_tool = landmark_fitting_tool - if ( - self.landmark_fitting_tool is not None - and self.landmark_fitting_tool.image_acquisition_tool is None - ): - self.landmark_fitting_tool.image_acquisition_tool = acquisition_tool - self.dual_drift_estimation_primary_source = dual_drift_estimation_primary_source - self.dual_drift_selection_priming_iterations = ( - dual_drift_selection_priming_iterations + self.registration_selection_priming_iterations = ( + registration_selection_priming_iterations ) - if self.dual_drift_selection_priming_iterations < 1: + self.primary_registration_tool_index = primary_registration_tool_index + if self.registration_selection_priming_iterations < 1: raise ValueError( - "`dual_drift_selection_priming_iterations` must be >= 1." + "`registration_selection_priming_iterations` must be >= 1." + ) + if not (0 <= self.primary_registration_tool_index < len(self.registration_tools)): + raise ValueError( + "`primary_registration_tool_index` is out of range for `registration_tools`." 
) self.drift_model_y = MultivariateLinearRegression() self.drift_model_x = MultivariateLinearRegression() @@ -261,24 +259,6 @@ def create_bo_tool(self, parameter_ranges: list[tuple[float, ...], tuple[float, ) return bo_tool - def create_image_registration_tool( - self, - acquisition_tool: AcquireImage, - llm_config: Optional[LLMConfig] = None, - registration_method: Literal["phase_correlation", "sift", "mutual_information", "llm"] = "llm", - ): - image_registration_tool = ImageRegistration( - image_acquisition_tool=acquisition_tool, - llm_config=llm_config, - reference_image=None, - reference_pixel_size=1.0, - image_coordinates_origin="top_left", - registration_method=registration_method, - log_scale=True, - zoom=4 - ) - return image_registration_tool - def prerun_check( self, initial_sampling_range: Optional[Tuple[float, float]], @@ -510,7 +490,7 @@ def update_linear_drift_models( parameters: np.ndarray, current_position_yx: np.ndarray | None = None, ): - if self.landmark_fitting_tool is None: + if len(self.registration_tools) <= 1: return if self.initial_line_scan_position is None: return @@ -784,43 +764,76 @@ def get_suggested_next_parameters(self, step_size_limit: Optional[float | Tuple[ p_suggested = p_current + signs * step_sizes return p_suggested + def get_registration_tool_name(self, registration_tool: RegistrationToolType) -> str: + return getattr(registration_tool, "name", type(registration_tool).__name__) + + def record_registration_tool_result( + self, + registration_tool: RegistrationToolType, + target: Literal["previous", "initial"], + alignment_offset: np.ndarray, + ) -> None: + if not hasattr(registration_tool, "plot_last_fit"): + return + + registration_fig = None + try: + registration_fig = registration_tool.plot_last_fit() + registration_fig_path = BaseTool.save_image_to_temp_dir( + fig=registration_fig, + filename=f"{self.get_registration_tool_name(registration_tool)}_overlay.png", + add_timestamp=True, + ) + self.record_system_message( + 
content=( + f"Registration result from {self.get_registration_tool_name(registration_tool)}:\n" + f"```target={target}\n" + f"alignment_offset={alignment_offset.tolist()}```" + ), + image_path=registration_fig_path, + ) + except Exception as exc: + logger.warning( + "Failed to render registration overlay for %s: %s", + self.get_registration_tool_name(registration_tool), + exc, + ) + finally: + if registration_fig is not None: + plt.close(registration_fig) + def find_position_correction( self, - method: Literal["traditional", "landmark"] = "traditional", + registration_tool: RegistrationToolType, target: Literal["previous", "initial"] = "previous", ) -> tuple[np.ndarray, np.ndarray]: - """Find the correction that should be applied (added) to image acquisition and line scan positions. + """Find the correction implied by a registration tool result. Parameters ---------- - method : Literal["traditional", "landmark"] - The registration algorithm to use. - "traditional": use :attr:`image_registration_tool` (classical image registration). - "landmark": use :attr:`landmark_fitting_tool` (landmark-based registration). - target : Literal["previous", "initial"] + registration_tool : RegistrationToolType + Registration tool that provides ``get_offset(target=...)``, where + the returned offset is the shift to apply to the current/test image + so it aligns with the selected reference image. + target : Literal["previous", "initial"], optional The reference image to register against. "previous": register the current image against the immediately - preceding image. The returned corrections are relative to the + preceding image. The returned corrections are relative to the previous step. "initial": register the current image against the very first - acquired image. The returned corrections are cumulative from the + acquired image. The returned corrections are cumulative from the initial position, which prevents per-step error accumulation. 
Returns ------- - np.ndarray - The correction to apply to the line scan positions, including - both the pure registration offset and the intentional scan-position - difference. - For target="previous": relative to the previous step. - For target="initial": cumulative from the initial position. - Values are in physical units (pixel size already accounted for). - np.ndarray - The correction to apply to the image acquisition positions (pure - registration offset only, without the intentional scan-position - difference). - For target="previous": relative to the previous step. - For target="initial": cumulative from the initial position. + tuple[np.ndarray, np.ndarray] + Two arrays in physical units: + ``(line_scan_correction, registration_correction)``. + ``line_scan_correction`` includes both the pure registration + correction and the intentional scan-position difference between the + current and reference 2D scans. + ``registration_correction`` is the pure stage-position correction + to apply to image acquisition coordinates. """ if target not in ("previous", "initial"): raise ValueError(f"`target` must be 'previous' or 'initial', got {target!r}.") @@ -833,59 +846,11 @@ def find_position_correction( for dir in ["y", "x"] ]).astype(float) - # Alignment offset depends on the registration method (and the target image). 
- if method == "traditional": - if target == "previous": - image_r = self.image_registration_tool.process_image(self.acquisition_tool.image_km1) - psize_r = self.acquisition_tool.psize_km1 - else: # "initial" - image_r = self.image_registration_tool.process_image(self.acquisition_tool.image_0) - psize_r = self.acquisition_tool.psize_0 - alignment_offset = np.array( - self.image_registration_tool.register_images( - image_t=self.image_registration_tool.process_image(self.acquisition_tool.image_k), - image_r=image_r, - psize_t=self.acquisition_tool.psize_k, - psize_r=psize_r, - registration_algorithm_kwargs=self.registration_algorithm_kwargs, - ), - dtype=float, - ) - elif method == "landmark": - if self.landmark_fitting_tool is None: - raise RuntimeError( - "`landmark_fitting_tool` is not set. Provide one in the constructor " - "to use method='landmark'." - ) - alignment_offset = np.array( - self.landmark_fitting_tool.get_offset(target=target), - dtype=float, - ) - landmark_fig = None - try: - landmark_fig = self.landmark_fitting_tool.plot_last_fit() - landmark_fig_path = BaseTool.save_image_to_temp_dir( - fig=landmark_fig, - filename="landmark_fitting_overlay.png", - add_timestamp=True, - ) - self.record_system_message( - content=( - "Landmark fitting result for drift estimation:\n" - f"```target={target}\n" - f"alignment_offset={alignment_offset.tolist()}```" - ), - image_path=landmark_fig_path, - ) - except Exception as exc: - logger.warning("Failed to render landmark fitting overlay: %s", exc) - finally: - if landmark_fig is not None: - plt.close(landmark_fig) - else: - raise ValueError( - f"`method` must be 'traditional' or 'landmark', got {method!r}." - ) + alignment_offset = np.array( + registration_tool.get_offset(target=target), + dtype=float, + ) + self.record_registration_tool_result(registration_tool, target, alignment_offset) # Image registration offset is the offset by which the moving image should be rolled to match # the reference. 
We want acquisition position correction here, which is the negation of it. @@ -981,50 +946,33 @@ def rollback_and_shrink_delta(message_prefix: str) -> np.ndarray: def _select_drift( self, - landmark_drift: np.ndarray, - reg_drift: np.ndarray | None, + candidate_drifts: dict[str, np.ndarray], x_current: np.ndarray, ) -> tuple[np.ndarray, str]: - """Select the drift estimate to use, logging the decision. - - During the warm-up phase (fewer than ``dual_drift_selection_priming_iterations`` - samples collected) the primary source is used unconditionally. - Afterwards the linear model prediction arbitrates: the drift estimate - closer in Euclidean distance to the model prediction is chosen. - - Parameters - ---------- - landmark_drift : np.ndarray - Cumulative drift from landmark fitting. - reg_drift : np.ndarray or None - Cumulative drift from classical image registration, or ``None`` on failure. - x_current : np.ndarray - Current optics parameter vector (used in the arbitration phase). - - Returns - ------- - chosen_drift : np.ndarray - chosen_source : str - ``"landmark_fitting"`` or ``"registration"``. 
- """ + """Select the best cumulative drift estimate from the available tools.""" n_collected = self.drift_model_y.get_n_parameter_drift_points_collected() - n_needed = self.dual_drift_selection_priming_iterations + n_needed = self.registration_selection_priming_iterations if n_collected < n_needed: - if self.dual_drift_estimation_primary_source == "landmark_fitting" or reg_drift is None: - chosen_drift, chosen_source = landmark_drift, "landmark_fitting" + preferred_tool = self.registration_tools[self.primary_registration_tool_index] + preferred_name = self.get_registration_tool_name(preferred_tool) + if preferred_name in candidate_drifts: + chosen_source = preferred_name else: - chosen_drift, chosen_source = reg_drift, "registration" - if reg_drift is None: + chosen_source = next(iter(candidate_drifts)) self.record_system_message( - "Dual drift estimation: image registration returned NaN; " - "falling back to landmark_fitting." + "Primary registration tool did not yield a valid result; " + f"falling back to {chosen_source}." 
) + chosen_drift = candidate_drifts[chosen_source] + candidate_drift_lines = "\n".join( + f"{name}: {drift.tolist()}" + for name, drift in candidate_drifts.items() + ) self.record_system_message( - f"Dual drift estimation (primary phase, n={n_collected}/{n_needed}): " + f"Registration tool selection (primary phase, n={n_collected}/{n_needed}): " f"```using {chosen_source}\n" - f"landmark_drift={landmark_drift.tolist()}\n" - f"reg_drift={reg_drift.tolist() if reg_drift is not None else 'NaN'}```" + f"candidate_drifts:\n{candidate_drift_lines}```" ) else: x_in = np.array(x_current, dtype=float).reshape(1, -1).tolist() @@ -1035,29 +983,27 @@ def _select_drift( ], dtype=float, ) - dist_landmark = float(np.linalg.norm(landmark_drift - model_drift)) - dist_reg = ( - float(np.linalg.norm(reg_drift - model_drift)) - if reg_drift is not None - else np.inf + candidate_distances = { + name: float(np.linalg.norm(drift - model_drift)) + for name, drift in candidate_drifts.items() + } + chosen_source = min(candidate_distances, key=candidate_distances.get) + chosen_drift = candidate_drifts[chosen_source] + candidate_result_lines = "\n".join( + f"{name}: drift={candidate_drifts[name].tolist()}, dist={candidate_distances[name]:.4f}" + for name in candidate_drifts ) - if dist_landmark <= dist_reg: - chosen_drift, chosen_source = landmark_drift, "landmark_fitting" - else: - chosen_drift, chosen_source = reg_drift, "registration" - dist_reg_str = f"{dist_reg:.4f}" if reg_drift is not None else "inf" self.record_system_message( - f"Dual drift estimation (arbitration phase):\n" + f"Registration tool selection (arbitration phase):\n" f"```model_drift={model_drift.tolist()}\n" - f"landmark_drift={landmark_drift.tolist()} (dist={dist_landmark:.4f})\n" - f"reg_drift={reg_drift.tolist() if reg_drift is not None else 'NaN'} (dist={dist_reg_str})\n" + f"candidates:\n{candidate_result_lines}\n" f"Chosen: {chosen_source}```" ) return chosen_drift, chosen_source def apply_drift_correction(self, 
x_current: np.ndarray) -> None: - """Run both drift estimators, select the best estimate, and apply it. + """Run the configured registration tools, select a drift estimate, and apply it. Parameters ---------- @@ -1070,103 +1016,82 @@ def apply_drift_correction(self, x_current: np.ndarray) -> None: "Ensure initialize_kwargs_buffers and run_2d_scan have been called first." ) - # Get the offset from the registration. - # line_scan_pos_offset_reg: the offset counting in BOTH the pure image registration and the scan-position difference - # -> to be applied to the line scan position. - # pure_registration_offset: the offset counting in ONLY the pure image registration - # -> to be applied to the image acquisition position. - line_scan_correction_reg, image_acq_correction_reg = self.find_position_correction( - method="traditional", target=self.registration_target + current_line_scan_position = self.extract_line_scan_position(self.line_scan_kwargs) + current_image_acq_position = self.extract_image_acquisition_position( + self.image_acquisition_kwargs ) - if self.registration_target == "previous": - # find_position_correction returns corrections relative to the previous step. - line_scan_correction_reg_wrt_prev = line_scan_correction_reg - line_scan_correction_reg_wrt_initial = ( - self.extract_line_scan_position(self.line_scan_kwargs) - + line_scan_correction_reg - - self.initial_line_scan_position - ) - image_acq_correction_reg_wrt_prev = image_acq_correction_reg - image_acq_correction_reg_wrt_initial = ( - image_acq_correction_reg - + self.extract_image_acquisition_position(self.image_acquisition_kwargs) - - self.initial_image_acquisition_position - ) - else: - # find_position_correction returns cumulative corrections from the initial position. - # Convert to wrt_prev so the apply helpers can add them to the current kwargs. 
- line_scan_correction_reg_wrt_initial = line_scan_correction_reg - line_scan_correction_reg_wrt_prev = ( - self.initial_line_scan_position - + line_scan_correction_reg - - self.extract_line_scan_position(self.line_scan_kwargs) - ) - image_acq_correction_reg_wrt_initial = image_acq_correction_reg - image_acq_correction_reg_wrt_prev = ( - self.initial_image_acquisition_position - + image_acq_correction_reg - - self.extract_image_acquisition_position(self.image_acquisition_kwargs) - ) - - # Get the offset from the landmark fitting tool using the same target as classical registration. - if self.landmark_fitting_tool is not None: - line_scan_correction_landmark, image_acq_correction_landmark = self.find_position_correction( - method="landmark", target=self.registration_target - ) + + candidate_results: dict[str, dict[str, np.ndarray]] = {} + for registration_tool in self.registration_tools: + tool_name = self.get_registration_tool_name(registration_tool) + try: + line_scan_correction, image_acq_correction = self.find_position_correction( + registration_tool=registration_tool, + target=self.registration_target, + ) + except Exception as exc: + if len(self.registration_tools) == 1: + raise + logger.warning("Registration tool %s failed: %s", tool_name, exc) + self.record_system_message( + f"Registration tool {tool_name} failed and will be skipped: {exc}" + ) + continue + if self.registration_target == "previous": - line_scan_correction_landmark_wrt_prev = line_scan_correction_landmark - line_scan_correction_landmark_wrt_initial = ( - self.extract_line_scan_position(self.line_scan_kwargs) - + line_scan_correction_landmark + line_scan_correction_wrt_prev = line_scan_correction + line_scan_correction_wrt_initial = ( + current_line_scan_position + + line_scan_correction - self.initial_line_scan_position ) - image_acq_correction_landmark_wrt_prev = image_acq_correction_landmark - image_acq_correction_landmark_wrt_initial = ( - image_acq_correction_landmark - + 
self.extract_image_acquisition_position(self.image_acquisition_kwargs) + image_acq_correction_wrt_prev = image_acq_correction + image_acq_correction_wrt_initial = ( + image_acq_correction + + current_image_acq_position - self.initial_image_acquisition_position ) - else: # "initial" - line_scan_correction_landmark_wrt_initial = line_scan_correction_landmark - line_scan_correction_landmark_wrt_prev = ( + else: + line_scan_correction_wrt_initial = line_scan_correction + line_scan_correction_wrt_prev = ( self.initial_line_scan_position - + line_scan_correction_landmark - - self.extract_line_scan_position(self.line_scan_kwargs) + + line_scan_correction + - current_line_scan_position ) - image_acq_correction_landmark_wrt_initial = image_acq_correction_landmark - image_acq_correction_landmark_wrt_prev = ( + image_acq_correction_wrt_initial = image_acq_correction + image_acq_correction_wrt_prev = ( self.initial_image_acquisition_position - + image_acq_correction_landmark - - self.extract_image_acquisition_position(self.image_acquisition_kwargs) + + image_acq_correction + - current_image_acq_position ) - # Select the best correction to use. 
+ candidate_results[tool_name] = { + "line_scan_correction_wrt_prev": line_scan_correction_wrt_prev, + "line_scan_correction_wrt_initial": line_scan_correction_wrt_initial, + "image_acq_correction_wrt_prev": image_acq_correction_wrt_prev, + "image_acq_correction_wrt_initial": image_acq_correction_wrt_initial, + } + + if len(candidate_results) == 0: + raise RuntimeError("No registration tool produced a valid drift estimate.") + + if len(candidate_results) == 1: + chosen_source = next(iter(candidate_results)) + else: chosen_line_scan_correction_wrt_initial, chosen_source = self._select_drift( - line_scan_correction_landmark_wrt_initial, - line_scan_correction_reg_wrt_initial, + { + name: result["line_scan_correction_wrt_initial"] + for name, result in candidate_results.items() + }, x_current, ) - chosen_line_scan_correction_wrt_prev = ( - line_scan_correction_landmark_wrt_prev - if chosen_source == "landmark_fitting" - else line_scan_correction_reg_wrt_prev - ) - chosen_image_acq_correction_wrt_prev = ( - image_acq_correction_landmark_wrt_prev - if chosen_source == "landmark_fitting" - else image_acq_correction_reg_wrt_prev - ) - chosen_image_acq_correction_wrt_initial = ( - image_acq_correction_landmark_wrt_initial - if chosen_source == "landmark_fitting" - else image_acq_correction_reg_wrt_initial - ) - else: - chosen_source = "registration" - chosen_line_scan_correction_wrt_prev = line_scan_correction_reg_wrt_prev - chosen_line_scan_correction_wrt_initial = line_scan_correction_reg_wrt_initial - chosen_image_acq_correction_wrt_prev = image_acq_correction_reg_wrt_prev - chosen_image_acq_correction_wrt_initial = image_acq_correction_reg_wrt_initial + _ = chosen_line_scan_correction_wrt_initial + + chosen_result = candidate_results[chosen_source] + chosen_line_scan_correction_wrt_prev = chosen_result["line_scan_correction_wrt_prev"] + chosen_line_scan_correction_wrt_initial = chosen_result["line_scan_correction_wrt_initial"] + chosen_image_acq_correction_wrt_prev = 
chosen_result["image_acq_correction_wrt_prev"] + chosen_image_acq_correction_wrt_initial = chosen_result["image_acq_correction_wrt_initial"] self.apply_offset_to_line_scan_kwargs(chosen_line_scan_correction_wrt_prev) self.apply_offset_to_image_acquisition_kwargs(chosen_image_acq_correction_wrt_prev) @@ -1174,7 +1099,6 @@ def apply_drift_correction(self, x_current: np.ndarray) -> None: self.record_system_message( f"Applied drift correction:\n" f"```source = {chosen_source}\n" - f"image_registration_offset = {(-chosen_image_acq_correction_wrt_prev).tolist()}\n" f"chosen_line_scan_correction_wrt_prev = {chosen_line_scan_correction_wrt_prev.tolist()}\n" f"chosen_line_scan_correction_wrt_initial = {chosen_line_scan_correction_wrt_initial.tolist()}\n" f"chosen_image_acq_correction_wrt_prev = {chosen_image_acq_correction_wrt_prev.tolist()}\n" diff --git a/src/eaa/tool/imaging/nn_registration.py b/src/eaa/tool/imaging/nn_registration.py index 255c0d4..15b2d5a 100644 --- a/src/eaa/tool/imaging/nn_registration.py +++ b/src/eaa/tool/imaging/nn_registration.py @@ -57,7 +57,7 @@ def _encode_as_tiff(image: np.ndarray) -> bytes: tifffile.imwrite(buf, image.astype(np.float32)) return buf.getvalue() - @tool(name="get_offset", return_type=ToolReturnType.TEXT) + @tool(name="get_offset", return_type=ToolReturnType.LIST) def get_offset(self, target: Literal["previous", "initial"] = "initial") -> np.ndarray: """Query the server and return the registration offset in physical units. 
diff --git a/src/eaa/tool/imaging/registration.py b/src/eaa/tool/imaging/registration.py index 15a5847..c6212bb 100644 --- a/src/eaa/tool/imaging/registration.py +++ b/src/eaa/tool/imaging/registration.py @@ -1,4 +1,4 @@ -from typing import Annotated, Any, List, Literal, Optional, Tuple +from typing import Annotated, Any, List, Literal, Optional, Tuple, Dict import logging import re from pathlib import Path @@ -12,7 +12,6 @@ from sciagent.message_proc import generate_openai_message from sciagent.skill import SkillMetadata from sciagent.api.llm_config import LLMConfig -from sciagent.api.memory import MemoryManagerConfig from eaa.tool.imaging.acquisition import AcquireImage from eaa.image_proc import ( @@ -37,11 +36,13 @@ def __init__( self, image_acquisition_tool: AcquireImage, llm_config: Optional[LLMConfig] = None, - memory_config: Optional[MemoryManagerConfig] = None, reference_image: np.ndarray = None, reference_pixel_size: float = 1.0, image_coordinates_origin: Literal["top_left", "center"] = "top_left", - registration_method: Literal["phase_correlation", "sift", "mutual_information", "llm", "error_minimization"] = "phase_correlation", + registration_method: Literal[ + "phase_correlation", "sift", "mutual_information", "llm", "error_minimization" + ] = "phase_correlation", + registration_algorithm_kwargs: Optional[Dict[str, Any]] = None, zoom: float = 1.0, log_scale: bool = False, require_approval: bool = False, @@ -77,6 +78,8 @@ def __init__( "mutual_information" uses pyramid-based normalized mutual information, and "error_minimization" uses exhaustive integer-shift MSE search with local quadratic subpixel refinement. + registration_algorithm_kwargs : Optional[Dict[str, Any]], optional + Keyword arguments to pass to the registration algorithm. zoom : float, optional Zoom factor applied to both images before registration. Returned offsets are scaled back to the original image coordinates. 
@@ -87,11 +90,11 @@ def __init__( self.image_acquisition_tool = image_acquisition_tool self.llm_config = llm_config - self.memory_config = memory_config self.reference_image = reference_image self.reference_pixel_size = reference_pixel_size self.image_coordinates_origin = image_coordinates_origin self.registration_method = registration_method + self.registration_algorithm_kwargs = registration_algorithm_kwargs or {} self.zoom = zoom self.log_scale = log_scale @@ -136,32 +139,27 @@ def zoom_image(self, image: np.ndarray) -> np.ndarray: return image return ndi.zoom(image, zoom=self.zoom, order=1, mode="nearest") - @tool(name="get_offset_of_latest_image", return_type=ToolReturnType.LIST) - def get_offset_of_latest_image( + @tool(name="get_offset", return_type=ToolReturnType.LIST) + def get_offset( self, - register_with: Annotated[ - Literal["previous", "first", "reference"], - "The image to register the latest image with. " - "Can be 'previous', 'first', or 'reference'. " - "'previous': register with the image collected by the acquisition tool before the latest. " - "'first': register with the first image collected by the acquisition tool. " - "'reference': register with the reference image provided to the tool. ", - ], + target: Annotated[ + Literal["previous", "initial", "reference"], + "Reference image buffer against which the current image is compared.", + ] = "initial", ) -> Annotated[ List[float], - "The translational offset [dy (vertical), dx (horizontal)] to apply to the latest " - "acquired image so it aligns with the reference image. Positive y means shifting the " - "latest image downward; positive x means shifting it rightward. 
The returned values are " - "in physical units, i.e., pixel size is already accounted for.", + "The translational offset [dy, dx] to apply to the latest image " + "so it aligns with the selected reference image.", ]: - """ - Register the latest image collected by the image acquisition tool + """Register the latest image collected by the image acquisition tool and the reference image. """ + register_with = "previous" if target == "previous" else "first" image_t, image_r, psize_t, psize_r = self.get_registration_inputs(register_with) - - offset = self.register_images(image_t, image_r, psize_t, psize_r) - return offset + return self.register_images( + image_t, image_r, psize_t, psize_r, + registration_algorithm_kwargs=self.registration_algorithm_kwargs + ) def get_registration_inputs( self, diff --git a/tests/test_image_registration_tool.py b/tests/test_image_registration_tool.py index 81509f1..c7f315d 100644 --- a/tests/test_image_registration_tool.py +++ b/tests/test_image_registration_tool.py @@ -33,7 +33,7 @@ def test_image_registration(self): y_center=164, x_center=164, size_y=128, size_x=128 ) - offset = registration_tool.get_offset_of_latest_image(register_with="previous") + offset = registration_tool.get_offset(target="previous") if self.debug: print("Offset: ", offset) @@ -69,7 +69,7 @@ def test_image_registration_diff_size(self): y_center=175, x_center=175, size_y=150, size_x=150 ) - offset = registration_tool.get_offset_of_latest_image(register_with="previous") + offset = registration_tool.get_offset(target="previous") if self.debug: print("Offset: ", offset) @@ -108,7 +108,7 @@ def test_image_registration_mutual_information(self): y_center=164, x_center=164, size_y=128, size_x=128 ) - offset = registration_tool.get_offset_of_latest_image(register_with="previous") + offset = registration_tool.get_offset(target="previous") if self.debug: print("Offset (mutual information): ", offset) From 9590d932dbba3ab2231cdb29c45078582ac5662c Mon Sep 17 00:00:00 2001 
From: Ming Du Date: Mon, 9 Mar 2026 16:01:19 -0500 Subject: [PATCH 14/15] FIX: fix fallback mechanism in analytical focusing --- .../tuning/analytical_focusing.py | 59 ++++++++++++++++--- 1 file changed, 52 insertions(+), 7 deletions(-) diff --git a/src/eaa/task_manager/tuning/analytical_focusing.py b/src/eaa/task_manager/tuning/analytical_focusing.py index ce1c828..e298829 100644 --- a/src/eaa/task_manager/tuning/analytical_focusing.py +++ b/src/eaa/task_manager/tuning/analytical_focusing.py @@ -511,6 +511,43 @@ def update_linear_drift_models( f"n_samples={self.drift_model_y.get_n_parameter_drift_points_collected()}.```" ) + def snapshot_acquisition_state(self) -> dict[str, Any]: + """Capture mutable acquisition-tool state for retry rollback.""" + state = { + "image_0": copy.deepcopy(self.acquisition_tool.image_0), + "image_km1": copy.deepcopy(self.acquisition_tool.image_km1), + "image_k": copy.deepcopy(self.acquisition_tool.image_k), + "psize_0": copy.deepcopy(self.acquisition_tool.psize_0), + "psize_km1": copy.deepcopy(self.acquisition_tool.psize_km1), + "psize_k": copy.deepcopy(self.acquisition_tool.psize_k), + "image_acquisition_call_history": copy.deepcopy( + self.acquisition_tool.image_acquisition_call_history + ), + "line_scan_call_history": copy.deepcopy( + self.acquisition_tool.line_scan_call_history + ), + } + for attr in ["blur", "offset", "line_scan_candidates"]: + if hasattr(self.acquisition_tool, attr): + state[attr] = copy.deepcopy(getattr(self.acquisition_tool, attr)) + return state + + def restore_acquisition_state(self, state: dict[str, Any]) -> None: + """Restore acquisition-tool state captured by snapshot_acquisition_state.""" + self.acquisition_tool.image_0 = state["image_0"] + self.acquisition_tool.image_km1 = state["image_km1"] + self.acquisition_tool.image_k = state["image_k"] + self.acquisition_tool.psize_0 = state["psize_0"] + self.acquisition_tool.psize_km1 = state["psize_km1"] + self.acquisition_tool.psize_k = state["psize_k"] + 
self.acquisition_tool.image_acquisition_call_history = state[ + "image_acquisition_call_history" + ] + self.acquisition_tool.line_scan_call_history = state["line_scan_call_history"] + for attr in ["blur", "offset", "line_scan_candidates"]: + if attr in state: + setattr(self.acquisition_tool, attr, state[attr]) + def record_linear_drift_model_visualizations(self) -> None: image_paths = [] for axis_name, model in [("y", self.drift_model_y), ("x", self.drift_model_x)]: @@ -911,6 +948,7 @@ def run_tuning_iteration(self, x: np.ndarray): x_current = np.array(x, dtype=float) line_scan_kwargs_before = copy.deepcopy(self.line_scan_kwargs) image_acquisition_kwargs_before = copy.deepcopy(self.image_acquisition_kwargs) + acquisition_state_before = self.snapshot_acquisition_state() def rollback_and_shrink_delta(message_prefix: str) -> np.ndarray: for parameter_name in self.param_setting_tool.parameter_names: @@ -918,6 +956,7 @@ def rollback_and_shrink_delta(message_prefix: str) -> np.ndarray: self.param_setting_tool.parameter_history[parameter_name].pop() self.line_scan_kwargs = copy.deepcopy(line_scan_kwargs_before) self.image_acquisition_kwargs = copy.deepcopy(image_acquisition_kwargs_before) + self.restore_acquisition_state(acquisition_state_before) delta = x_current - x_original x_next = x_original + delta / 2 self.record_system_message( @@ -931,9 +970,10 @@ def rollback_and_shrink_delta(message_prefix: str) -> np.ndarray: while True: self.record_system_message(f"Setting parameters to new value:```{x_current}```") self.param_setting_tool.set_parameters(x_current) + chosen_line_scan_correction_wrt_initial = None if self.run_offset_calibration: self.run_2d_scan() - self.apply_drift_correction(x_current) + chosen_line_scan_correction_wrt_initial = self.apply_drift_correction(x_current) try: fwhm = self.run_line_scan() if np.isnan(fwhm): @@ -941,6 +981,15 @@ def rollback_and_shrink_delta(message_prefix: str) -> np.ndarray: except LineScanValidationFailed: x_current = 
rollback_and_shrink_delta("Line scan validation failed.") continue + if ( + self.run_offset_calibration + and chosen_line_scan_correction_wrt_initial is not None + ): + self.update_linear_drift_models( + x_current, + current_position_yx=chosen_line_scan_correction_wrt_initial, + ) + self.record_linear_drift_model_visualizations() self.update_optimization_model(fwhm) return @@ -1002,7 +1051,7 @@ def _select_drift( return chosen_drift, chosen_source - def apply_drift_correction(self, x_current: np.ndarray) -> None: + def apply_drift_correction(self, x_current: np.ndarray) -> np.ndarray: """Run the configured registration tools, select a drift estimate, and apply it. Parameters @@ -1105,11 +1154,7 @@ def apply_drift_correction(self, x_current: np.ndarray) -> None: f"chosen_image_acq_correction_wrt_initial = {chosen_image_acq_correction_wrt_initial.tolist()}```" ) - # Update model with (parameters -> chosen cumulative drift). Pass the - # target position directly so the model is consistent regardless of - # whether kwargs have been flushed to the underlying tool yet. 
- self.update_linear_drift_models(x_current, current_position_yx=chosen_line_scan_correction_wrt_initial) - self.record_linear_drift_model_visualizations() + return chosen_line_scan_correction_wrt_initial def apply_user_correction_offset(self) -> bool: message = ( From d0db51a7d00da906f704ac5b480be16d94a36829 Mon Sep 17 00:00:00 2001 From: Ming Du Date: Mon, 9 Mar 2026 16:37:58 -0500 Subject: [PATCH 15/15] FIX: fix failures after refactor --- .../imaging/analytical_feature_tracking.py | 2 +- tests/test_analytical_feature_tracking.py | 1 - tests/test_analytical_focusing.py | 154 +++++++++++++----- 3 files changed, 114 insertions(+), 43 deletions(-) diff --git a/src/eaa/task_manager/imaging/analytical_feature_tracking.py b/src/eaa/task_manager/imaging/analytical_feature_tracking.py index 50b7947..ad3da6e 100644 --- a/src/eaa/task_manager/imaging/analytical_feature_tracking.py +++ b/src/eaa/task_manager/imaging/analytical_feature_tracking.py @@ -188,7 +188,7 @@ def run( reference_image, psize_t=self.image_acquisition_tool.psize_k, psize_r=self.image_registration_tool.reference_pixel_size, - registration_algorithm_kwargs={"use_hanning_window": True}, + registration_algorithm_kwargs={"filtering_method": "hanning"}, ) if check_feature_presence_llm( task_manager=self, diff --git a/tests/test_analytical_feature_tracking.py b/tests/test_analytical_feature_tracking.py index 32de788..c48aa85 100644 --- a/tests/test_analytical_feature_tracking.py +++ b/tests/test_analytical_feature_tracking.py @@ -1,6 +1,5 @@ import os import argparse -import pytest import numpy as np import tifffile diff --git a/tests/test_analytical_focusing.py b/tests/test_analytical_focusing.py index eae5c89..92c67fe 100644 --- a/tests/test_analytical_focusing.py +++ b/tests/test_analytical_focusing.py @@ -3,20 +3,28 @@ import numpy as np import tifffile -from matplotlib.figure import Figure from eaa.task_manager.tuning.analytical_focusing import ( AnalyticalScanningMicroscopeFocusingTaskManager, ) from 
eaa.tool.imaging.acquisition import SimulatedAcquireImage from eaa.tool.imaging.param_tuning import SimulatedSetParameters -from sciagent.tool.base import BaseTool +from eaa.tool.imaging.registration import ImageRegistration import test_utils as tutils +class DummyRegistrationTool: + def __init__(self, name: str, offset=(0.0, 0.0)): + self.name = name + self.offset = np.array(offset, dtype=float) + + def get_offset(self, target="previous"): + return self.offset.tolist() + + class TestAnalyticalFocusing(tutils.BaseTester): - def _build_task_manager(self): + def _build_task_manager(self, registration_tools=None): image_path = os.path.join( self.get_ci_input_data_dir(), "simulated_images", @@ -42,12 +50,17 @@ def _build_task_manager(self): parameter_ranges=[(0.0,), (10.0,)], drift_factor=10, ) + if registration_tools is None: + registration_tools = [ + ImageRegistration(image_acquisition_tool=acquisition_tool) + ] task_manager = AnalyticalScanningMicroscopeFocusingTaskManager( param_setting_tool=param_setting_tool, acquisition_tool=acquisition_tool, initial_parameters={"z": 10.0}, parameter_ranges=[(0.0,), (10.0,)], + registration_tools=registration_tools, line_scan_tool_x_coordinate_args=("x_center",), line_scan_tool_y_coordinate_args=("y_center",), image_acquisition_tool_x_coordinate_args=("x_center",), @@ -104,46 +117,38 @@ def test_task_manager_runs_without_offset_calibration(self, monkeypatch): ) assert acquisition_tool.counter_acquire_image == 0 - def test_linear_drift_prediction_fit_and_apply(self): - task_manager, _ = self._build_task_manager() - task_manager.use_linear_drift_prediction = True - task_manager.n_parameter_drift_points_before_prediction = 3 - task_manager.initial_image_acquisition_position = np.array([10.0, 20.0], dtype=float) - task_manager.image_acquisition_kwargs = { - "y_center": 0.0, - "x_center": 0.0, - "size_y": 128, - "size_x": 128, - } - task_manager.line_scan_kwargs = { - "x_center": 1.0, - "y_center": 2.0, - "length": 10.0, - 
"scan_step": 1.0, - } - - # Drift model: delta_y = 2 * z + 1, delta_x = -z + 3 + def test_select_drift_uses_linear_model_after_priming(self): + registration_tools = [ + DummyRegistrationTool("primary"), + DummyRegistrationTool("secondary"), + ] + task_manager, _ = self._build_task_manager( + registration_tools=registration_tools + ) + task_manager.registration_selection_priming_iterations = 3 + task_manager.initial_line_scan_position = np.array([10.0, 20.0], dtype=float) + for z in [0.0, 1.0, 2.0]: drift = np.array([2.0 * z + 1.0, -z + 3.0], dtype=float) - current_position = task_manager.initial_image_acquisition_position + drift + current_position = task_manager.initial_line_scan_position + drift task_manager.update_linear_drift_models( parameters=np.array([z], dtype=float), current_position_yx=current_position, ) - assert task_manager.should_apply_linear_drift_prediction() - task_manager.apply_predicted_image_acquisition_position(np.array([4.0], dtype=float)) + chosen_drift, chosen_source = task_manager._select_drift( + candidate_drifts={ + "primary": np.array([0.0, 0.0], dtype=float), + "secondary": np.array([9.0, -1.0], dtype=float), + }, + x_current=np.array([4.0], dtype=float), + ) - assert np.isclose(task_manager.image_acquisition_kwargs["y_center"], 19.0) - assert np.isclose(task_manager.image_acquisition_kwargs["x_center"], 19.0) - assert np.isclose(task_manager.line_scan_kwargs["y_center"], 21.0) - assert np.isclose(task_manager.line_scan_kwargs["x_center"], 20.0) + assert chosen_source == "secondary" + assert np.allclose(chosen_drift, np.array([9.0, -1.0], dtype=float)) def test_run_iteration_applies_registration_offset_and_updates_model(self, monkeypatch): task_manager, acquisition_tool = self._build_task_manager() - task_manager.use_linear_drift_prediction = True - task_manager.n_parameter_drift_points_before_prediction = 1 - task_manager.initial_image_acquisition_position = np.array([0.0, 0.0], dtype=float) task_manager.initialize_kwargs_buffers( 
initial_line_scan_kwargs={ "x_center": 160.0, @@ -153,10 +158,7 @@ def test_run_iteration_applies_registration_offset_and_updates_model(self, monke }, initial_2d_scan_kwargs={"y_center": 0.0, "x_center": 0.0, "size_y": 200, "size_x": 200}, ) - task_manager.update_linear_drift_models( - parameters=np.array([0.0], dtype=float), - current_position_yx=np.array([0.0, 0.0], dtype=float), - ) + task_manager.initial_image_acquisition_position = np.array([0.0, 0.0], dtype=float) def fake_run_2d_scan(): kwargs = task_manager.image_acquisition_kwargs @@ -172,8 +174,11 @@ def fake_run_2d_scan(): monkeypatch.setattr(task_manager, "run_2d_scan", fake_run_2d_scan) monkeypatch.setattr( task_manager, - "find_offset", - lambda: (np.array([0.0, 0.0]), np.array([100.0, -50.0])), + "find_position_correction", + lambda registration_tool, target: ( + np.array([0.0, 0.0]), + np.array([100.0, -50.0]), + ), ) monkeypatch.setattr(task_manager, "run_line_scan", lambda: 1.0) monkeypatch.setattr(task_manager, "update_optimization_model", lambda fwhm: None) @@ -200,12 +205,13 @@ def test_run_tuning_iteration_calls_drift_visualization_after_update(self, monke }, initial_2d_scan_kwargs={"y_center": 0.0, "x_center": 0.0, "size_y": 200, "size_x": 200}, ) + task_manager.initial_image_acquisition_position = np.array([0.0, 0.0], dtype=float) monkeypatch.setattr(task_manager, "run_2d_scan", lambda: None) monkeypatch.setattr( task_manager, - "find_offset", - lambda: (np.array([0.0, 0.0]), np.array([0.0, 0.0])), + "apply_drift_correction", + lambda x_current: np.array([0.0, 0.0], dtype=float), ) monkeypatch.setattr(task_manager, "run_line_scan", lambda: 1.0) monkeypatch.setattr(task_manager, "update_optimization_model", lambda fwhm: None) @@ -214,7 +220,7 @@ def test_run_tuning_iteration_calls_drift_visualization_after_update(self, monke monkeypatch.setattr( task_manager, "update_linear_drift_models", - lambda parameters: call_order.append("update"), + lambda parameters, current_position_yx=None: 
call_order.append("update"), ) monkeypatch.setattr( task_manager, @@ -226,6 +232,71 @@ def test_run_tuning_iteration_calls_drift_visualization_after_update(self, monke assert call_order == ["update", "visualize"] + def test_registration_and_scan_position_corrections_across_two_iterations(self): + registration_tool = DummyRegistrationTool("dummy", offset=(2.5, -4.0)) + task_manager, acquisition_tool = self._build_task_manager( + registration_tools=[registration_tool] + ) + task_manager.initialize_kwargs_buffers( + initial_line_scan_kwargs={ + "x_center": 160.0, + "y_center": 170.0, + "length": 60.0, + "scan_step": 1.0, + }, + initial_2d_scan_kwargs={ + "y_center": 175.0, + "x_center": 175.0, + "size_y": 350, + "size_x": 350, + }, + ) + + task_manager.run_2d_scan() + task_manager.run_line_scan() + + task_manager.param_setting_tool.set_parameters(np.array([9.0], dtype=float)) + task_manager.image_acquisition_kwargs["y_center"] += 7.0 + task_manager.image_acquisition_kwargs["x_center"] -= 6.0 + task_manager.run_2d_scan() + + line_scan_correction, image_acquisition_correction = ( + task_manager.find_position_correction(registration_tool, target="previous") + ) + assert np.allclose(line_scan_correction, np.array([4.5, -2.0], dtype=float)) + assert np.allclose( + image_acquisition_correction, + np.array([-2.5, 4.0], dtype=float), + ) + + chosen_line_scan_correction = task_manager.apply_drift_correction( + np.array([9.0], dtype=float) + ) + assert np.allclose( + chosen_line_scan_correction, + np.array([4.5, -2.0], dtype=float), + ) + assert np.allclose( + task_manager.extract_line_scan_position(task_manager.line_scan_kwargs), + np.array([174.5, 158.0], dtype=float), + ) + assert np.allclose( + task_manager.extract_image_acquisition_position( + task_manager.image_acquisition_kwargs + ), + np.array([179.5, 173.0], dtype=float), + ) + + task_manager.run_line_scan() + assert len(acquisition_tool.line_scan_call_history) == 2 + assert np.allclose( + [ + 
acquisition_tool.line_scan_call_history[-1]["y_center"], + acquisition_tool.line_scan_call_history[-1]["x_center"], + ], + np.array([174.5, 158.0], dtype=float), + ) + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -241,3 +312,4 @@ def test_run_tuning_iteration_calls_drift_visualization_after_update(self, monke ) tester.test_task_manager_runs() tester.test_task_manager_runs_without_offset_calibration() + tester.test_registration_and_scan_position_corrections_across_two_iterations()