Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 84 additions & 6 deletions src/pyrecest/filters/gprhm_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
exp,
eye,
hstack,
isfinite,
linalg,
linspace,
max,
Expand Down Expand Up @@ -308,6 +309,8 @@ def __init__(
self.beta = float(beta)
self.kappa = float(kappa)
self.last_quadratic_form = None
self.last_active_measurement_indices = None
self.last_measurement_weights = None

@staticmethod
def _symmetrize(matrix):
Expand Down Expand Up @@ -606,21 +609,78 @@ def _measurement_model_terms(self, measurement, measurement_noise):
)
return measurement_jacobian, predicted_measurement, noise_covariance

def _stack_measurement_terms(self, measurements, measurement_noise):
def _normalize_measurement_weights(self, measurement_weights, n_measurements):
if measurement_weights is None:
return ones(n_measurements)

weights = array(measurement_weights)
if weights.ndim == 0:
weights = ones(n_measurements) * float(weights)
else:
weights = reshape(weights, (-1,))
if weights.shape[0] != n_measurements:
raise ValueError(
"measurement_weights must be scalar or have one entry per measurement"
)
if not bool(all(isfinite(weights))):
raise ValueError("measurement_weights must be finite")
if not bool(all(weights >= 0.0)):
raise ValueError("measurement_weights must be non-negative")
return weights

def _normalize_active_measurement_mask(self, active_measurement_mask, n_measurements):
if active_measurement_mask is None:
return [True] * n_measurements

mask = array(active_measurement_mask)
if mask.ndim == 0:
return [bool(mask)] * n_measurements
mask = reshape(mask, (-1,))
if mask.shape[0] != n_measurements:
raise ValueError(
"active_measurement_mask must be scalar or have one entry per measurement"
)
return [bool(mask[index]) for index in range(n_measurements)]

def _stack_measurement_terms(
self,
measurements,
measurement_noise,
measurement_weights=None,
active_measurement_mask=None,
):
weights = self._normalize_measurement_weights(
measurement_weights,
measurements.shape[0],
)
active_mask = self._normalize_active_measurement_mask(
active_measurement_mask,
measurements.shape[0],
)
measurement_jacobians = []
predicted_measurements = []
noise_covariances = []
for measurement in measurements:
active_indices = []
for measurement_index, measurement in enumerate(measurements):
weight = float(weights[measurement_index])
if not active_mask[measurement_index] or weight <= 0.0:
continue
measurement_jacobian, predicted_measurement, noise_covariance = (
self._measurement_model_terms(measurement, measurement_noise)
)
active_indices.append(measurement_index)
measurement_jacobians.append(measurement_jacobian)
predicted_measurements.append(predicted_measurement)
noise_covariances.append(noise_covariance)
noise_covariances.append(noise_covariance / weight)
self.last_measurement_weights = weights
self.last_active_measurement_indices = active_indices
if not active_indices:
return None, None, None, active_indices
return (
vstack(measurement_jacobians),
concatenate(predicted_measurements),
linalg.block_diag(*noise_covariances),
active_indices,
)

def update(
Expand All @@ -629,7 +689,16 @@ def update(
R=None,
s_hat=None,
sigma_squared_s=None,
measurement_weights=None,
active_measurement_mask=None,
):
"""Update the tracker with optional per-measurement reliabilities.

``measurement_weights`` scales each measurement covariance block as
``R_i / weight_i``. Zero-weight or masked measurements are skipped.
``active_measurement_mask`` can be used to explicitly disable cluttered,
occluded, or otherwise unsupported measurements.
"""
if s_hat is not None:
self.scale_mean = float(s_hat)
if sigma_squared_s is not None:
Expand All @@ -644,10 +713,19 @@ def update(
require_positive_semidefinite=False,
)

measurement_jacobian, predicted_measurements, noise_covariance = (
self._stack_measurement_terms(measurements, measurement_noise)
measurement_jacobian, predicted_measurements, noise_covariance, active_indices = (
self._stack_measurement_terms(
measurements,
measurement_noise,
measurement_weights=measurement_weights,
active_measurement_mask=active_measurement_mask,
)
)
residual = concatenate(list(measurements))
if not active_indices:
self.last_quadratic_form = None
return

residual = concatenate([measurements[index] for index in active_indices])
residual = residual - predicted_measurements
covariance_measurement = self._symmetrize(
measurement_jacobian @ self.covariance @ measurement_jacobian.T
Expand Down
57 changes: 57 additions & 0 deletions tests/filters/test_scgp_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,63 @@ def test_decorrelated_update_keeps_cross_covariance_zero(self):
npt.assert_allclose(cross_covariance, zeros(cross_covariance.shape), atol=1e-12)
npt.assert_array_less(-1e-10, linalg.eigvalsh(tracker.covariance))

def test_active_measurement_mask_matches_single_measurement_update(self):
masked_tracker = self._make_tracker()
single_tracker = self._make_tracker()
measurements = array([[1.4, 0.2], [0.1, 1.3]])

masked_tracker.update(
measurements,
active_measurement_mask=array([True, False]),
)
single_tracker.update(measurements[0])

npt.assert_allclose(masked_tracker.state, single_tracker.state, atol=1e-12)
npt.assert_allclose(
masked_tracker.covariance,
single_tracker.covariance,
atol=1e-12,
)
self.assertEqual(masked_tracker.last_active_measurement_indices, [0])
npt.assert_allclose(masked_tracker.last_measurement_weights, array([1.0, 1.0]))

def test_zero_measurement_weight_skips_measurement(self):
tracker = self._make_tracker()
state_before = array(tracker.state)
covariance_before = array(tracker.covariance)

tracker.update(array([1.4, 0.2]), measurement_weights=array([0.0]))

npt.assert_allclose(tracker.state, state_before, atol=1e-12)
npt.assert_allclose(tracker.covariance, covariance_before, atol=1e-12)
self.assertEqual(tracker.last_active_measurement_indices, [])
self.assertIsNone(tracker.last_quadratic_form)

def test_measurement_weight_changes_update_strength(self):
high_weight_tracker = self._make_tracker()
low_weight_tracker = self._make_tracker()
state_before = array(high_weight_tracker.state)
measurement = array([1.4, 0.2])

high_weight_tracker.update(measurement, measurement_weights=1.0)
low_weight_tracker.update(measurement, measurement_weights=0.05)

high_weight_delta = linalg.norm(high_weight_tracker.state - state_before)
low_weight_delta = linalg.norm(low_weight_tracker.state - state_before)
self.assertGreater(float(high_weight_delta), float(low_weight_delta))
npt.assert_allclose(low_weight_tracker.last_measurement_weights, array([0.05]))

def test_measurement_weights_validate_shape_and_values(self):
tracker = self._make_tracker()

with self.assertRaises(ValueError):
tracker.update(
array([[1.4, 0.2], [0.1, 1.3]]),
measurement_weights=array([1.0]),
)
with self.assertRaises(ValueError):
tracker.update(array([1.4, 0.2]), measurement_weights=-1.0)

def test_full_tracker_contour_and_bounding_box(self):
tracker = self._make_tracker()

Expand Down