From 716bf72f819ab1c66efd9f397cb0d6f9ce60a7aa Mon Sep 17 00:00:00 2001 From: Joshua Bloom Date: Mon, 13 Apr 2026 09:54:00 -0400 Subject: [PATCH] fix(main): store label in GSP replay + unwrap TD3 tuple + bump GSP cadence MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three coordinated changes supporting the GSP direct-MSE training path from GSP-RL PR #24. 1. store_gsp_transition call sites — pass LABEL as the 2nd arg Previously non-attention variants stored the previous prediction (`old_heading_gsp[i]`) in the action field so the DDPG critic could evaluate Q(state, prediction). The DDPG actor-critic training path is gone, so the "action" slot is now used for the supervised target. The attention variant already passed `label` here; now all branches do. Three call sites updated: independent-learning training store, shared- model training store, and the global-knowledge branch. 2. TD3 tuple-loss crash fix `learn_TD3` returns (0, 0) on non-actor-update steps (where the critic updated but the actor did not, because of UPDATE_ACTOR_ITER>1). Main.py was storing this tuple in the `loss` timeseries passed to hdf5_writer, which then crashed with: ValueError: setting an array element with a sequence. The detected shape was (4500,) + inhomogeneous part. This crashed `td3_gsp_s123` in the diagnostic batch. Unwrap the tuple to a scalar at both learn call sites (independent + shared). 3. GSP_LEARNING_FREQUENCY: 500 → 4 Under direct-MSE GSP training the loss is supervised and cheap — there's no need for the 500-step cadence that was a hedge for the actor-critic path's sparse feedback. 4 matches LEARN_EVERY so the GSP predictor updates in lockstep with the primary network and gets ~125× more gradient steps per episode. New contract test: tests/test_main_gsp_contract.py statically asserts that every store_gsp_transition call in Main.py passes the label as its 2nd arg. Catches regressions in the call signature. Full RL-CT suite: 112/112 pass (excluding pre-existing test_nan_guards.py import error unrelated to this PR). Co-Authored-By: Claude Opus 4.6 (1M context) --- rl_code/Main.py | 33 +++++++++++------------ run_baseline_experiments.py | 2 +- tests/test_main_gsp_contract.py | 48 +++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 19 deletions(-) create mode 100644 tests/test_main_gsp_contract.py diff --git a/rl_code/Main.py b/rl_code/Main.py index 36578de..5920078 100644 --- a/rl_code/Main.py +++ b/rl_code/Main.py @@ -384,15 +384,13 @@ if model.gsp_networks['learning_scheme'] == 'attention': model.store_gsp_transition(states[i], label, 0, 0, 0) else: + # Under the direct-MSE GSP training path, the 2nd arg + # (action field) carries the supervised target label. + # See GSP-RL fix/gsp-direct-mse-training PR #24 and + # Stelaris docs/research/2026-04-13-gsp-information-collapse-analysis.md. state = states[i] - action = old_heading_gsp[i] - reward = gsp_reward[i] new_state = new_states[i] - # print('[MAIN] Transition State:', state) - # print('[MAIN] Transition Action:', action) - # print('[MAIN] Transition Reward:', reward) - # print('[MAIN] Transition New State:', new_state) - model.store_gsp_transition(state, action, reward, new_state, 0) + model.store_gsp_transition(state, label, 0, new_state, 0) else: for i in range(Utility.params['num_robots']): if model.gsp_networks['learning_scheme'] == 'attention': @@ -404,23 +402,16 @@ state = np.array(old_agent_prox_flags) # only store the state if it has value if np.sum(state) > 0: - action = old_heading_gsp[i] - reward = gsp_reward[i] + # 2nd arg = label (supervised target for direct-MSE GSP training) new_state = np.array(agent_prox_flags) - models[i].store_gsp_transition(state, action, reward, new_state, 0) + models[i].store_gsp_transition(state, label, 0, new_state, 0) else: - # print(f'[AGENT] {i} GSP:', old_heading_gsp[i]) state = np.array(old_agent_prox_flags) # only store the state if it has value if np.sum(state) > 0: - action = old_heading_gsp[i] - reward = gsp_reward[i] + # 2nd arg = label (supervised target for direct-MSE GSP training) new_state = np.array(agent_prox_flags) - # print('[MAIN] Transition State:', state) - # print('[MAIN] Transition Action:', action) - # print('[MAIN] Transition Reward:', reward) - # print('[MAIN] Transition New State:', new_state) - model.store_gsp_transition(state, action, reward, new_state, 0) + model.store_gsp_transition(state, label, 0, new_state, 0) #Define Global Knowledge: [positions, velocities] @@ -507,6 +498,10 @@ # information-collapse diagnostic. for i in range(Utility.params['num_robots']): loss = models[i].learn() + # TD3's learn_TD3 returns (0, 0) on non-actor-update steps; + # unwrap so the hdf5 logger's 1D loss array stays homogeneous. + if isinstance(loss, tuple): + loss = loss[0] gsp_losses = [ m.last_gsp_loss for m in models if getattr(m, "last_gsp_loss", None) is not None @@ -515,6 +510,8 @@ hdf5_writer.record_gsp_loss(float(np.mean(gsp_losses))) else: loss = model.learn() + if isinstance(loss, tuple): + loss = loss[0] gsp_step_loss = getattr(model, "last_gsp_loss", None) if gsp_step_loss is not None: hdf5_writer.record_gsp_loss(gsp_step_loss) diff --git a/run_baseline_experiments.py b/run_baseline_experiments.py index 9f58674..e8bc48c 100644 --- a/run_baseline_experiments.py +++ b/run_baseline_experiments.py @@ -159,7 +159,7 @@ def make_config(exp_name, gsp, neighbors, num_obstacles, use_gate, gate_curricul "NOISE": 0.1, "UPDATE_ACTOR_ITER": 2, "WARMUP": 1000, - "GSP_LEARNING_FREQUENCY": 500, + "GSP_LEARNING_FREQUENCY": 4, "LEARN_EVERY": 4, "GSP_BATCH_SIZE": 256, } diff --git a/tests/test_main_gsp_contract.py b/tests/test_main_gsp_contract.py new file mode 100644 index 0000000..b13ce2b --- /dev/null +++ b/tests/test_main_gsp_contract.py @@ -0,0 +1,48 @@ +"""Contract test for Main.py's store_gsp_transition calls. + +Under the direct-MSE GSP training path (GSP-RL fix/gsp-direct-mse-training), +Main.py must store the LABEL (not the previous prediction) in the action +field of the replay buffer for ALL GSP variants. This test asserts the call +shape via static inspection — it's not a runtime test, but it catches +regressions in the call signature without needing ARGoS. + +See Stelaris docs/research/2026-04-13-gsp-information-collapse-analysis.md +for the full rationale. +""" + +import pathlib +import re + + +MAIN_PY = pathlib.Path(__file__).resolve().parent.parent / "rl_code" / "Main.py" + + +def test_store_gsp_transition_passes_label_as_action_argument(): + """Every store_gsp_transition call in Main.py must pass label-related arg as 2nd. + + The 2nd positional arg of store_gsp_transition is the action field. Under + direct-MSE training, the GSP predictor is trained via MSE against the label + stored in this field. If any call site passes the previous prediction + (`old_heading_gsp[i]`, `action`, etc.) instead of `label`, the predictor + trains against its own old output — no supervision signal. + """ + text = MAIN_PY.read_text() + # Match any store_gsp_transition(...) call and capture its entire arg list. + # Use re.DOTALL in case the call spans multiple lines. + calls = re.findall(r'store_gsp_transition\s*\(([^)]*)\)', text, re.DOTALL) + assert len(calls) >= 3, ( + f"expected at least 3 store_gsp_transition calls (one per branch in the " + f"independent/shared/attention block), got {len(calls)}" + ) + violations = [] + for i, call in enumerate(calls): + args = [a.strip() for a in call.split(",")] + assert len(args) >= 2, f"call {i}: expected at least 2 args, got {args}" + second_arg = args[1] + if "label" not in second_arg: + violations.append((i, second_arg)) + assert not violations, ( + f"{len(violations)} store_gsp_transition call(s) do not pass `label` as " + f"the 2nd arg (action field); violations: {violations}. Under direct-MSE " + f"GSP training the action field must carry the supervised target." + )