diff --git a/src/pyrecest/filters/gprhm_tracker.py b/src/pyrecest/filters/gprhm_tracker.py index 9b3e4558e..32300d6d6 100644 --- a/src/pyrecest/filters/gprhm_tracker.py +++ b/src/pyrecest/filters/gprhm_tracker.py @@ -11,6 +11,7 @@ exp, eye, hstack, + isfinite, linalg, linspace, max, @@ -308,6 +309,8 @@ def __init__( self.beta = float(beta) self.kappa = float(kappa) self.last_quadratic_form = None + self.last_active_measurement_indices = None + self.last_measurement_weights = None @staticmethod def _symmetrize(matrix): @@ -606,21 +609,78 @@ def _measurement_model_terms(self, measurement, measurement_noise): ) return measurement_jacobian, predicted_measurement, noise_covariance - def _stack_measurement_terms(self, measurements, measurement_noise): + def _normalize_measurement_weights(self, measurement_weights, n_measurements): + if measurement_weights is None: + return ones(n_measurements) + + weights = array(measurement_weights) + if weights.ndim == 0: + weights = ones(n_measurements) * float(weights) + else: + weights = reshape(weights, (-1,)) + if weights.shape[0] != n_measurements: + raise ValueError( + "measurement_weights must be scalar or have one entry per measurement" + ) + if not bool(all(isfinite(weights))): + raise ValueError("measurement_weights must be finite") + if not bool(all(weights >= 0.0)): + raise ValueError("measurement_weights must be non-negative") + return weights + + def _normalize_active_measurement_mask(self, active_measurement_mask, n_measurements): + if active_measurement_mask is None: + return [True] * n_measurements + + mask = array(active_measurement_mask) + if mask.ndim == 0: + return [bool(mask)] * n_measurements + mask = reshape(mask, (-1,)) + if mask.shape[0] != n_measurements: + raise ValueError( + "active_measurement_mask must be scalar or have one entry per measurement" + ) + return [bool(mask[index]) for index in range(n_measurements)] + + def _stack_measurement_terms( + self, + measurements, + measurement_noise, + measurement_weights=None, + active_measurement_mask=None, + ): + weights = self._normalize_measurement_weights( + measurement_weights, + measurements.shape[0], + ) + active_mask = self._normalize_active_measurement_mask( + active_measurement_mask, + measurements.shape[0], + ) measurement_jacobians = [] predicted_measurements = [] noise_covariances = [] - for measurement in measurements: + active_indices = [] + for measurement_index, measurement in enumerate(measurements): + weight = float(weights[measurement_index]) + if not active_mask[measurement_index] or weight <= 0.0: + continue measurement_jacobian, predicted_measurement, noise_covariance = ( self._measurement_model_terms(measurement, measurement_noise) ) + active_indices.append(measurement_index) measurement_jacobians.append(measurement_jacobian) predicted_measurements.append(predicted_measurement) - noise_covariances.append(noise_covariance) + noise_covariances.append(noise_covariance / weight) + self.last_measurement_weights = weights + self.last_active_measurement_indices = active_indices + if not active_indices: + return None, None, None, active_indices return ( vstack(measurement_jacobians), concatenate(predicted_measurements), linalg.block_diag(*noise_covariances), + active_indices, ) def update( @@ -629,7 +689,16 @@ def update( R=None, s_hat=None, sigma_squared_s=None, + measurement_weights=None, + active_measurement_mask=None, ): + """Update the tracker with optional per-measurement reliabilities. + + ``measurement_weights`` scales each measurement covariance block as + ``R_i / weight_i``. Zero-weight or masked measurements are skipped. + ``active_measurement_mask`` can be used to explicitly disable cluttered, + occluded, or otherwise unsupported measurements. + """ if s_hat is not None: self.scale_mean = float(s_hat) if sigma_squared_s is not None: @@ -644,10 +713,19 @@ def update( require_positive_semidefinite=False, ) - measurement_jacobian, predicted_measurements, noise_covariance = ( - self._stack_measurement_terms(measurements, measurement_noise) + measurement_jacobian, predicted_measurements, noise_covariance, active_indices = ( + self._stack_measurement_terms( + measurements, + measurement_noise, + measurement_weights=measurement_weights, + active_measurement_mask=active_measurement_mask, + ) ) - residual = concatenate(list(measurements)) + if not active_indices: + self.last_quadratic_form = None + return + + residual = concatenate([measurements[index] for index in active_indices]) residual = residual - predicted_measurements covariance_measurement = self._symmetrize( measurement_jacobian @ self.covariance @ measurement_jacobian.T diff --git a/tests/filters/test_scgp_tracker.py b/tests/filters/test_scgp_tracker.py index ea488dd57..8da79e1e8 100644 --- a/tests/filters/test_scgp_tracker.py +++ b/tests/filters/test_scgp_tracker.py @@ -88,6 +88,63 @@ def test_decorrelated_update_keeps_cross_covariance_zero(self): npt.assert_allclose(cross_covariance, zeros(cross_covariance.shape), atol=1e-12) npt.assert_array_less(-1e-10, linalg.eigvalsh(tracker.covariance)) + def test_active_measurement_mask_matches_single_measurement_update(self): + masked_tracker = self._make_tracker() + single_tracker = self._make_tracker() + measurements = array([[1.4, 0.2], [0.1, 1.3]]) + + masked_tracker.update( + measurements, + active_measurement_mask=array([True, False]), + ) + single_tracker.update(measurements[0]) + + npt.assert_allclose(masked_tracker.state, single_tracker.state, atol=1e-12) + npt.assert_allclose( + masked_tracker.covariance, + single_tracker.covariance, + atol=1e-12, + ) + self.assertEqual(masked_tracker.last_active_measurement_indices, [0]) + npt.assert_allclose(masked_tracker.last_measurement_weights, array([1.0, 1.0])) + + def test_zero_measurement_weight_skips_measurement(self): + tracker = self._make_tracker() + state_before = array(tracker.state) + covariance_before = array(tracker.covariance) + + tracker.update(array([1.4, 0.2]), measurement_weights=array([0.0])) + + npt.assert_allclose(tracker.state, state_before, atol=1e-12) + npt.assert_allclose(tracker.covariance, covariance_before, atol=1e-12) + self.assertEqual(tracker.last_active_measurement_indices, []) + self.assertIsNone(tracker.last_quadratic_form) + + def test_measurement_weight_changes_update_strength(self): + high_weight_tracker = self._make_tracker() + low_weight_tracker = self._make_tracker() + state_before = array(high_weight_tracker.state) + measurement = array([1.4, 0.2]) + + high_weight_tracker.update(measurement, measurement_weights=1.0) + low_weight_tracker.update(measurement, measurement_weights=0.05) + + high_weight_delta = linalg.norm(high_weight_tracker.state - state_before) + low_weight_delta = linalg.norm(low_weight_tracker.state - state_before) + self.assertGreater(float(high_weight_delta), float(low_weight_delta)) + npt.assert_allclose(low_weight_tracker.last_measurement_weights, array([0.05])) + + def test_measurement_weights_validate_shape_and_values(self): + tracker = self._make_tracker() + + with self.assertRaises(ValueError): + tracker.update( + array([[1.4, 0.2], [0.1, 1.3]]), + measurement_weights=array([1.0]), + ) + with self.assertRaises(ValueError): + tracker.update(array([1.4, 0.2]), measurement_weights=-1.0) + def test_full_tracker_contour_and_bounding_box(self): tracker = self._make_tracker()