From f94c4ed3f6c5063428217126c8856cfc041bca1b Mon Sep 17 00:00:00 2001 From: Joshua Bloom Date: Sun, 12 Apr 2026 20:14:24 -0400 Subject: [PATCH] feat(gsp-diagnostics): add information-collapse logging for GSP variants MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the per-step and per-episode HDF5 fields needed to detect "GSP information collapse" — a suspected failure mode where the GSP prediction network collapses to a near-constant output that carries no information about the collective state. See Stelaris docs/specs/2026-04-12-dispatcher-diagnostic-batch.md for the hypothesis. Changes (all gated on opt-in — backward compatible): env.py: - calculate_gsp_reward returns (reward, label, squared_errors). The raw per-robot (diff - prediction)^2 carries the magnitude that the clipped [-2, 0] reward hides. rl_code/src/hdf5_logger.py: - New optional kwargs gsp_target, gsp_squared_error on writerow → 2D (timesteps × robots) datasets. - New record_gsp_loss(value) method → 1D dataset at GSP learning cadence. - write_episode now computes two episode-level summary attrs when both prediction and target buffers are present: - gsp_output_std (collapse signature: → 0) - gsp_pred_target_corr (collapse signature: → NaN when std is below 1e-12 tolerance, distinguishing "undefined" from "measured zero") Uses np.nanstd and pair-wise NaN masking so a single physics glitch doesn't poison the summary; raises ValueError if gsp_target/gsp_heading buffers desync within an episode. Main.py: - 3-tuple unpack of calculate_gsp_reward; broadcast scalar label to per-robot list for the (timesteps × robots) HDF5 schema; pass new kwargs to hdf5_writer.writerow. - After each model.learn() call, capture model.last_gsp_loss (from GSP-RL PR #23) and pass to hdf5_writer.record_gsp_loss. In --independent_learning mode, aggregate across per-robot models to a single scalar per learn tick (mean) so the gsp_loss axis length stays num_learn_steps regardless of mode. Tests: - 6 new TestGSPSquaredErrorReturn cases in test_env/test_gsp_reward.py; existing tests updated to 3-tuple unpack. - tests/test_diagnostics/test_hdf5_logger_gsp_diagnostics.py — 9 new tests covering: per-step datasets, gsp_loss recording, episode attrs, collapse signature detection, degenerate task, NaN poisoning, desynced-buffer raise, backward compat, optional record_gsp_loss. Companion: NESTLab/GSP-RL#23 (Actor.last_gsp_loss). Co-Authored-By: Claude Opus 4.6 (1M context) --- rl_code/Main.py | 32 ++- rl_code/src/env.py | 27 +-- rl_code/src/hdf5_logger.py | 70 +++++- .../test_hdf5_logger_gsp_diagnostics.py | 229 ++++++++++++++++++ tests/test_env/test_gsp_reward.py | 51 +++- 5 files changed, 377 insertions(+), 32 deletions(-) create mode 100644 tests/test_diagnostics/test_hdf5_logger_gsp_diagnostics.py diff --git a/rl_code/Main.py b/rl_code/Main.py index 4fb51c5..08a9c29 100644 --- a/rl_code/Main.py +++ b/rl_code/Main.py @@ -316,11 +316,11 @@ gate_stats = Utility.parse_gate_stats(msgs[7]) ############################## gsp REWARD ############################################## - gsp_reward, label = calculate_gsp_reward( - config['GSP'], - old_cyl_ang, - obj_stats[5], - next_heading_gsp, + gsp_reward, label, gsp_squared_error = calculate_gsp_reward( + config['GSP'], + old_cyl_ang, + obj_stats[5], + next_heading_gsp, Utility.params['num_robots'] ) # print('[MAIN] GSP Reward', gsp_reward) @@ -501,10 +501,25 @@ if train_mode and config['LEARNING_SCHEME'] != 'None': if time_steps % learn_every == 0: if args.independent_learning: + # Aggregate GSP losses across per-robot models to a single + # scalar per learn tick. Otherwise the 1D gsp_loss dataset + # would have (num_learn_steps × num_robots) entries in + # independent mode vs. num_learn_steps in shared mode, + # breaking cross-mode comparability of the + # information-collapse diagnostic. for i in range(Utility.params['num_robots']): loss = models[i].learn() + gsp_losses = [ + m.last_gsp_loss for m in models + if getattr(m, "last_gsp_loss", None) is not None + ] + if gsp_losses: + hdf5_writer.record_gsp_loss(float(np.mean(gsp_losses))) else: loss = model.learn() + gsp_step_loss = getattr(model, "last_gsp_loss", None) + if gsp_step_loss is not None: + hdf5_writer.record_gsp_loss(gsp_step_loss) else: loss = 0 else: @@ -546,11 +561,16 @@ else: tmp_epsilon = model.epsilon + # gsp_target: broadcast the scalar payload delta-theta label to per-robot list + # so it aligns with the (timesteps × robots) HDF5 schema. Needed for the + # information-collapse diagnostic (gsp_output_std, gsp_pred_target_corr). + gsp_target_per_robot = [float(label)] * Utility.params['num_robots'] hdf5_writer.writerow(r, tmp_epsilon, reached_goal, loss, force_mags, force_angs, [average_force_mag, math.degrees(average_force_ang)], obj_stats[0], obj_stats[1], obj_stats[5], gate, obstacles, gsp_reward, next_heading_gsp, time.time() - episode_start_time, robot_x_pos, robot_y_pos, robot_angle, - robot_failures, com_X_poses=com_X_poses, com_Y_poses=com_Y_poses) + robot_failures, com_X_poses=com_X_poses, com_Y_poses=com_Y_poses, + gsp_target=gsp_target_per_robot, gsp_squared_error=gsp_squared_error) if episode_done: if args.independent_learning: diff --git a/rl_code/src/env.py b/rl_code/src/env.py index 43a0b2e..7fb5892 100644 --- a/rl_code/src/env.py +++ b/rl_code/src/env.py @@ -18,7 +18,15 @@ def angle_normalize_signed_deg(a): return a def calculate_gsp_reward(GSP, old_cyl_ang, cyl_ang, next_heading_gsp, num_robots): + """Return (clipped_rewards, label, squared_errors) per robot. + + The clipped reward saturates at -2 and hides the magnitude of large prediction errors. + squared_errors carries the raw (diff - prediction)^2 per robot — needed for the + information-collapse diagnostic (paper outline: "Revamped Reward structure for GSP + to prevent information collapse"). + """ gsp_reward = [] + squared_errors = [] label = 0 if GSP: old_cyl_ang = angle_normalize_unsigned_deg(old_cyl_ang) @@ -28,27 +36,16 @@ def calculate_gsp_reward(GSP, old_cyl_ang, cyl_ang, next_heading_gsp, num_robots # Max rotation is 0.09 rad/step so we can multiply by 10 to get within range of -1, 1 diff = np.clip(diff*100, -1, 1) label=diff - x1 = math.cos(diff) - y1 = math.sin(diff) for i in range(num_robots): - # x2 = math.cos(next_heading_gsp[i]) - # y2 = math.sin(next_heading_gsp[i]) - # error = np.dot([x1, y1], [x2, y2]) - # gsp_reward.append(-1 + error) - # print('GSP:', next_heading_gsp[i]) - # print(f'Diff: {diff:.2f}, next_heading_gsp: {next_heading_gsp[i]:.2f}') reward = diff - next_heading_gsp[i] - # print('reward', reward) - # norm_reward = reward / next_heading_gsp[i] - # print('norm_reward', norm_reward) abs_reward = abs(reward)**2 - # print('abs reward', abs_reward) + squared_errors.append(float(abs_reward)) gsp_reward.append(np.clip(-1*abs_reward, -2, 0)) - else: gsp_reward = [0 for i in range(num_robots)] - - return gsp_reward, label + squared_errors = [0 for i in range(num_robots)] + + return gsp_reward, label, squared_errors class ZMQ_Utility: diff --git a/rl_code/src/hdf5_logger.py b/rl_code/src/hdf5_logger.py index ee5f6b5..0bf5e16 100644 --- a/rl_code/src/hdf5_logger.py +++ b/rl_code/src/hdf5_logger.py @@ -56,6 +56,11 @@ def _reset(self): self.robot_failures = [] self.com_X_pos = [] self.com_Y_pos = [] + # GSP information-collapse diagnostics: predicted delta-theta target and squared error + # per step, plus GSP-specific training loss per learning step. + self.gsp_target = [] + self.gsp_squared_error = [] + self.gsp_loss = [] def writerow( self, rewards, epsilons, terminations, losses, @@ -65,6 +70,7 @@ def writerow( gsp_rewards, gsp_headings, run_times, robots_x_poses, robots_y_poses, robot_angles, robot_failure, com_X_poses=0, com_Y_poses=0, + gsp_target=None, gsp_squared_error=None, ): """Accumulate one timestep of data. Same signature as data_logger.writerow.""" self.reward.append(rewards) @@ -90,6 +96,17 @@ def writerow( self.robot_failures.append(robot_failure) self.com_X_pos.append(com_X_poses) self.com_Y_pos.append(com_Y_poses) + if gsp_target is not None: + self.gsp_target.append(gsp_target) + if gsp_squared_error is not None: + self.gsp_squared_error.append(gsp_squared_error) + + def record_gsp_loss(self, loss_value: float) -> None: + """Record one GSP prediction network training loss sample. + + Called per GSP learning step (cadence differs from per-timestep writerow). + """ + self.gsp_loss.append(float(loss_value)) def write_episode(self, episode_num: int) -> dict: """Write accumulated data to HDF5 and return summary dict. @@ -105,7 +122,7 @@ def write_episode(self, episode_num: int) -> dict: grp = h5f.create_group(group_name) # Store 2D arrays (timesteps × robots) - for key, data in [ + twod_specs = [ ("reward", self.reward), ("gsp_reward", self.gsp_reward), ("force_magnitude", self.force_magnitude), @@ -115,7 +132,12 @@ def write_episode(self, episode_num: int) -> dict: ("robot_angle", self.robot_angle), ("robot_failure", self.robot_failures), ("gsp_heading", self.gsp_heading), - ]: + ] + if self.gsp_target: + twod_specs.append(("gsp_target", self.gsp_target)) + if self.gsp_squared_error: + twod_specs.append(("gsp_squared_error", self.gsp_squared_error)) + for key, data in twod_specs: arr = np.array(data, dtype=np.float32) if arr.size > 0: grp.create_dataset(key, data=arr, compression="gzip", compression_opts=4) @@ -135,6 +157,13 @@ def write_episode(self, episode_num: int) -> dict: if arr.size > 0: grp.create_dataset(key, data=arr, compression="gzip", compression_opts=4) + # GSP-specific training loss — 1D but recorded at a different cadence than writerow, + # so it lives outside the timestep-indexed 1D block above. + if self.gsp_loss: + gsp_loss_arr = np.array(self.gsp_loss, dtype=np.float32) + grp.create_dataset("gsp_loss", data=gsp_loss_arr, + compression="gzip", compression_opts=4) + # Termination as bool term_arr = np.array(self.termination, dtype=bool) if term_arr.size > 0: @@ -164,6 +193,43 @@ def write_episode(self, episode_num: int) -> dict: grp.attrs["reward_per_robot"] = reward_per_robot grp.attrs["gsp_reward_per_robot"] = gsp_per_robot + # Information-collapse summary attrs. Computed only when both prediction + # (gsp_heading) and target (gsp_target) are present. The caller contract is + # that gsp_target must be passed on every writerow call within an episode once + # it's been passed on any — i.e. it tracks gsp_heading 1:1. Enforce here rather + # than silently zipping misaligned buffers. Use NaN-aware aggregation so a + # single physics glitch (prediction -> NaN) does not poison the summary. + # Undefined correlations (zero-variance in predictions or targets) are written + # as NaN so downstream analysis can distinguish "undefined" from "measured zero". + if self.gsp_target and self.gsp_heading: + if len(self.gsp_target) != len(self.gsp_heading): + raise ValueError( + f"gsp_target buffer length {len(self.gsp_target)} does not match " + f"gsp_heading buffer length {len(self.gsp_heading)}; gsp_target must " + "be passed on every writerow call within an episode once it's been " + "passed on any call." + ) + pred_arr = np.array(self.gsp_heading, dtype=np.float64).ravel() + target_arr = np.array(self.gsp_target, dtype=np.float64).ravel() + if pred_arr.size > 0: + pred_std = float(np.nanstd(pred_arr)) + target_std = float(np.nanstd(target_arr)) + # Tolerance guard: np.nanstd of a constant returns ~1e-18 due to + # sum-of-squares fuzz, which passes a strict `> 0` check and lets + # corrcoef return a garbage value. The observables here are radians + # clipped to [-1, 1], so 1e-12 is well below any physical variation. + STD_TOL = 1e-12 + grp.attrs["gsp_output_std"] = pred_std + if pred_arr.size > 1 and pred_std > STD_TOL and target_std > STD_TOL: + mask = np.isfinite(pred_arr) & np.isfinite(target_arr) + if mask.sum() > 1: + corr = float(np.corrcoef(pred_arr[mask], target_arr[mask])[0, 1]) + else: + corr = float("nan") + else: + corr = float("nan") + grp.attrs["gsp_pred_target_corr"] = corr + # Reset for next episode summary = { "episode_num": episode_num, diff --git a/tests/test_diagnostics/test_hdf5_logger_gsp_diagnostics.py b/tests/test_diagnostics/test_hdf5_logger_gsp_diagnostics.py new file mode 100644 index 0000000..66afd75 --- /dev/null +++ b/tests/test_diagnostics/test_hdf5_logger_gsp_diagnostics.py @@ -0,0 +1,229 @@ +"""Tests for GSP information-collapse diagnostic fields added to HDF5Logger. + +See Stelaris docs/specs/2026-04-12-dispatcher-diagnostic-batch.md for the hypothesis +these fields exist to test: GSP output variance, prediction/target correlation, +and per-step GSP-specific training loss. +""" + +import os +import numpy as np +import pytest + +try: + import h5py + HAS_H5PY = True +except ImportError: + HAS_H5PY = False + +pytestmark = pytest.mark.skipif(not HAS_H5PY, reason="h5py not installed") + + +if HAS_H5PY: + from src.hdf5_logger import HDF5Logger + + +def _base_writerow_kwargs(): + return dict( + rewards=[0.1] * 4, + epsilons=0.5, + terminations=False, + losses=0.0, + force_magnitudes=[0.0] * 4, + force_angles=[0.0] * 4, + average_force_vectors=[0.0, 0.0], + cyl_x_poses=0.0, + cyl_y_poses=0.0, + cyl_angles=0.0, + gate_stats=0, + obstacle_stats=0, + gsp_rewards=[0.0] * 4, + gsp_headings=[0.0] * 4, + run_times=0.0, + robots_x_poses=[0.0] * 4, + robots_y_poses=[0.0] * 4, + robot_angles=[0.0] * 4, + robot_failure=[False] * 4, + ) + + +def test_gsp_target_and_squared_error_stored_per_step(tmp_path): + path = str(tmp_path / "ep.h5") + logger = HDF5Logger(path) + for t in range(5): + kwargs = _base_writerow_kwargs() + kwargs["gsp_headings"] = [0.1 * t, 0.2 * t, 0.3 * t, 0.4 * t] + kwargs["gsp_target"] = [0.15 * t, 0.25 * t, 0.35 * t, 0.45 * t] + kwargs["gsp_squared_error"] = [ + (0.05 * t) ** 2, + (0.05 * t) ** 2, + (0.05 * t) ** 2, + (0.05 * t) ** 2, + ] + logger.writerow(**kwargs) + logger.write_episode(0) + + with h5py.File(path) as f: + grp = f["episode_0000"] + assert grp["gsp_target"].shape == (5, 4) + assert grp["gsp_squared_error"].shape == (5, 4) + np.testing.assert_allclose( + grp["gsp_target"][4], [0.6, 1.0, 1.4, 1.8], rtol=1e-5 + ) + + +def test_gsp_loss_recorded_per_learning_step(tmp_path): + path = str(tmp_path / "ep.h5") + logger = HDF5Logger(path) + for t in range(3): + logger.writerow(**_base_writerow_kwargs()) + logger.record_gsp_loss(0.5) + logger.record_gsp_loss(0.4) + logger.record_gsp_loss(0.3) + logger.write_episode(0) + + with h5py.File(path) as f: + grp = f["episode_0000"] + assert "gsp_loss" in grp + assert grp["gsp_loss"].shape == (3,) + np.testing.assert_allclose(grp["gsp_loss"][...], [0.5, 0.4, 0.3], rtol=1e-6) + + +def test_episode_attrs_include_output_std_and_correlation(tmp_path): + """write_episode computes gsp_output_std and gsp_pred_target_corr.""" + path = str(tmp_path / "ep.h5") + logger = HDF5Logger(path) + + predictions = np.array([[0.1, 0.2], [0.2, 0.4], [0.3, 0.6], [0.4, 0.8]], dtype=np.float32) + targets = np.array([[0.11, 0.21], [0.19, 0.39], [0.31, 0.61], [0.40, 0.82]], dtype=np.float32) + + for t in range(4): + kwargs = _base_writerow_kwargs() + kwargs["gsp_headings"] = predictions[t].tolist() + [0.0, 0.0] + kwargs["gsp_target"] = targets[t].tolist() + [0.0, 0.0] + kwargs["gsp_squared_error"] = [ + float((predictions[t, 0] - targets[t, 0]) ** 2), + float((predictions[t, 1] - targets[t, 1]) ** 2), + 0.0, + 0.0, + ] + logger.writerow(**kwargs) + logger.write_episode(0) + + with h5py.File(path) as f: + grp = f["episode_0000"] + assert "gsp_output_std" in grp.attrs + assert "gsp_pred_target_corr" in grp.attrs + assert float(grp.attrs["gsp_output_std"]) > 0.0 + assert float(grp.attrs["gsp_pred_target_corr"]) > 0.95 + + +def test_collapsed_gsp_signature_detectable(tmp_path): + """A collapsed predictor (constant output) shows near-zero std and NaN correlation.""" + path = str(tmp_path / "ep.h5") + logger = HDF5Logger(path) + + constant_prediction = 0.05 + targets = [0.1, -0.2, 0.3, -0.4, 0.15] + + for t in range(5): + kwargs = _base_writerow_kwargs() + kwargs["gsp_headings"] = [constant_prediction] * 4 + kwargs["gsp_target"] = [targets[t]] * 4 + kwargs["gsp_squared_error"] = [(constant_prediction - targets[t]) ** 2] * 4 + logger.writerow(**kwargs) + logger.write_episode(0) + + with h5py.File(path) as f: + grp = f["episode_0000"] + assert float(grp.attrs["gsp_output_std"]) < 1e-6 + corr = float(grp.attrs["gsp_pred_target_corr"]) + assert np.isnan(corr), f"expected NaN correlation for collapsed predictor, got {corr}" + + +def test_degenerate_task_signature_distinct_from_collapsed(tmp_path): + """Both std=0 case (degenerate task) also produces NaN.""" + path = str(tmp_path / "ep.h5") + logger = HDF5Logger(path) + + for t in range(5): + kwargs = _base_writerow_kwargs() + kwargs["gsp_headings"] = [0.1] * 4 + kwargs["gsp_target"] = [0.1] * 4 + kwargs["gsp_squared_error"] = [0.0] * 4 + logger.writerow(**kwargs) + logger.write_episode(0) + + with h5py.File(path) as f: + grp = f["episode_0000"] + assert np.isnan(float(grp.attrs["gsp_pred_target_corr"])) + + +def test_nan_prediction_does_not_poison_episode_summary(tmp_path): + """A single NaN in predictions should not propagate through the std/corr attrs.""" + path = str(tmp_path / "ep.h5") + logger = HDF5Logger(path) + + predictions = [[0.1, 0.2], [0.2, 0.4], [float("nan"), 0.6], [0.4, 0.8]] + targets = [[0.11, 0.21], [0.19, 0.39], [0.31, 0.61], [0.40, 0.82]] + + for t in range(4): + kwargs = _base_writerow_kwargs() + kwargs["gsp_headings"] = predictions[t] + [0.0, 0.0] + kwargs["gsp_target"] = targets[t] + [0.0, 0.0] + kwargs["gsp_squared_error"] = [0.0] * 4 + logger.writerow(**kwargs) + logger.write_episode(0) + + with h5py.File(path) as f: + grp = f["episode_0000"] + std = float(grp.attrs["gsp_output_std"]) + corr = float(grp.attrs["gsp_pred_target_corr"]) + assert not np.isnan(std), "nanstd should handle NaN predictions" + assert not np.isnan(corr), "corrcoef should handle NaN predictions after masking" + + +def test_desynced_target_buffer_raises(tmp_path): + """Passing gsp_target on only some writerow calls in an episode is a contract violation.""" + path = str(tmp_path / "ep.h5") + logger = HDF5Logger(path) + + kwargs = _base_writerow_kwargs() + kwargs["gsp_headings"] = [0.1] * 4 + kwargs["gsp_target"] = [0.2] * 4 + logger.writerow(**kwargs) + + kwargs = _base_writerow_kwargs() + kwargs["gsp_headings"] = [0.15] * 4 + logger.writerow(**kwargs) + + with pytest.raises(ValueError, match="buffer length"): + logger.write_episode(0) + + +def test_writerow_backwards_compatible_without_gsp_diagnostics(tmp_path): + """Existing callers that don't pass the new kwargs still work.""" + path = str(tmp_path / "ep.h5") + logger = HDF5Logger(path) + logger.writerow(**_base_writerow_kwargs()) + logger.write_episode(0) + + with h5py.File(path) as f: + grp = f["episode_0000"] + assert "gsp_target" not in grp + assert "gsp_squared_error" not in grp + assert "gsp_loss" not in grp + assert "gsp_output_std" not in grp.attrs + assert "gsp_pred_target_corr" not in grp.attrs + + +def test_record_gsp_loss_is_optional(tmp_path): + """If record_gsp_loss is never called, no gsp_loss dataset is written.""" + path = str(tmp_path / "ep.h5") + logger = HDF5Logger(path) + for t in range(3): + logger.writerow(**_base_writerow_kwargs()) + logger.write_episode(0) + + with h5py.File(path) as f: + grp = f["episode_0000"] + assert "gsp_loss" not in grp diff --git a/tests/test_env/test_gsp_reward.py b/tests/test_env/test_gsp_reward.py index 7ed46ca..d42ae6d 100644 --- a/tests/test_env/test_gsp_reward.py +++ b/tests/test_env/test_gsp_reward.py @@ -8,48 +8,81 @@ class TestGSPRewardDisabled: def test_gsp_false_returns_zeros(self): - rewards, label = calculate_gsp_reward(False, 10.0, 15.0, [0.0, 0.0], 2) + rewards, label, _ = calculate_gsp_reward(False, 10.0, 15.0, [0.0, 0.0], 2) assert rewards == [0, 0] assert label == 0 def test_gsp_false_any_num_robots(self): - rewards, label = calculate_gsp_reward(False, 0, 0, [0]*8, 8) + rewards, label, _ = calculate_gsp_reward(False, 0, 0, [0]*8, 8) assert len(rewards) == 8 assert all(r == 0 for r in rewards) class TestGSPRewardEnabled: def test_perfect_prediction_zero_change(self): - rewards, label = calculate_gsp_reward(True, 45.0, 45.0, [0.0], 1) + rewards, label, _ = calculate_gsp_reward(True, 45.0, 45.0, [0.0], 1) assert label == pytest.approx(0.0) assert rewards[0] == pytest.approx(0.0) def test_perfect_prediction_nonzero_change(self): # 5 deg change: radians(5)=0.0873, x100=8.73, clipped to 1.0 - rewards, label = calculate_gsp_reward(True, 0.0, 5.0, [1.0], 1) + rewards, label, _ = calculate_gsp_reward(True, 0.0, 5.0, [1.0], 1) assert label == pytest.approx(1.0) assert rewards[0] == pytest.approx(0.0) def test_bad_prediction_negative_reward(self): - rewards, label = calculate_gsp_reward(True, 0.0, 0.0, [1.0], 1) + rewards, label, _ = calculate_gsp_reward(True, 0.0, 0.0, [1.0], 1) assert label == pytest.approx(0.0) assert rewards[0] == pytest.approx(-1.0) def test_reward_clipped_at_minus_2(self): - rewards, label = calculate_gsp_reward(True, 0.0, 0.0, [2.0], 1) + rewards, label, _ = calculate_gsp_reward(True, 0.0, 0.0, [2.0], 1) assert rewards[0] == pytest.approx(-2.0) def test_reward_per_robot(self): - rewards, label = calculate_gsp_reward(True, 0.0, 0.0, [0.0, 1.0, 0.5], 3) + rewards, label, _ = calculate_gsp_reward(True, 0.0, 0.0, [0.0, 1.0, 0.5], 3) assert len(rewards) == 3 assert rewards[0] == pytest.approx(0.0) assert rewards[1] == pytest.approx(-1.0) assert rewards[2] == pytest.approx(-0.25) def test_wraparound_angles(self): - rewards, label = calculate_gsp_reward(True, 350.0, 10.0, [0.0], 1) + rewards, label, _ = calculate_gsp_reward(True, 350.0, 10.0, [0.0], 1) assert label == pytest.approx(1.0) # 20 deg -> radians -> x100 -> clipped to 1.0 def test_label_type(self): - _, label = calculate_gsp_reward(True, 0.0, 1.0, [0.0], 1) + _, label, _ = calculate_gsp_reward(True, 0.0, 1.0, [0.0], 1) assert isinstance(label, (float, np.floating)) + + +class TestGSPSquaredErrorReturn: + """The function returns per-robot squared prediction error alongside the clipped reward. + + Needed for information-collapse diagnosis: the raw error carries more signal than the + clipped reward, since the reward saturates at -2 and loses the magnitude of large errors. + """ + + def test_returns_three_tuple(self): + result = calculate_gsp_reward(True, 0.0, 0.0, [0.0], 1) + assert len(result) == 3 + + def test_squared_error_is_zero_for_perfect_prediction(self): + _, _, squared = calculate_gsp_reward(True, 45.0, 45.0, [0.0, 0.0], 2) + assert squared == pytest.approx([0.0, 0.0]) + + def test_squared_error_is_unclipped(self): + """Reward is clipped to [-2, 0] but squared_error is the raw (diff - pred)^2.""" + _, _, squared = calculate_gsp_reward(True, 0.0, 0.0, [3.0], 1) + # diff=0, pred=3.0 -> raw squared error = 9.0 (much bigger than the clipped reward -2) + assert squared[0] == pytest.approx(9.0) + + def test_squared_error_per_robot(self): + _, _, squared = calculate_gsp_reward(True, 0.0, 0.0, [0.0, 1.0, 0.5], 3) + assert len(squared) == 3 + assert squared[0] == pytest.approx(0.0) + assert squared[1] == pytest.approx(1.0) + assert squared[2] == pytest.approx(0.25) + + def test_squared_error_is_zero_list_when_gsp_disabled(self): + _, _, squared = calculate_gsp_reward(False, 0, 0, [0.0] * 4, 4) + assert squared == [0, 0, 0, 0]