diff --git a/rl_code/Main.py b/rl_code/Main.py index 36578de..5920078 100644 --- a/rl_code/Main.py +++ b/rl_code/Main.py @@ -384,15 +384,13 @@ if model.gsp_networks['learning_scheme'] == 'attention': model.store_gsp_transition(states[i], label, 0, 0, 0) else: + # Under the direct-MSE GSP training path, the 2nd arg + # (action field) carries the supervised target label. + # See GSP-RL fix/gsp-direct-mse-training PR #24 and + # Stelaris docs/research/2026-04-13-gsp-information-collapse-analysis.md. state = states[i] - action = old_heading_gsp[i] - reward = gsp_reward[i] new_state = new_states[i] - # print('[MAIN] Transition State:', state) - # print('[MAIN] Transition Action:', action) - # print('[MAIN] Transition Reward:', reward) - # print('[MAIN] Transition New State:', new_state) - model.store_gsp_transition(state, action, reward, new_state, 0) + model.store_gsp_transition(state, label, 0, new_state, 0) else: for i in range(Utility.params['num_robots']): if model.gsp_networks['learning_scheme'] == 'attention': @@ -404,23 +402,16 @@ state = np.array(old_agent_prox_flags) # only store the state if it has value if np.sum(state) > 0: - action = old_heading_gsp[i] - reward = gsp_reward[i] + # 2nd arg = label (supervised target for direct-MSE GSP training) new_state = np.array(agent_prox_flags) - models[i].store_gsp_transition(state, action, reward, new_state, 0) + models[i].store_gsp_transition(state, label, 0, new_state, 0) else: - # print(f'[AGENT] {i} GSP:', old_heading_gsp[i]) state = np.array(old_agent_prox_flags) # only store the state if it has value if np.sum(state) > 0: - action = old_heading_gsp[i] - reward = gsp_reward[i] + # 2nd arg = label (supervised target for direct-MSE GSP training) new_state = np.array(agent_prox_flags) - # print('[MAIN] Transition State:', state) - # print('[MAIN] Transition Action:', action) - # print('[MAIN] Transition Reward:', reward) - # print('[MAIN] Transition New State:', new_state) - model.store_gsp_transition(state, action, reward, new_state, 0) + model.store_gsp_transition(state, label, 0, new_state, 0) #Define Global Knowledge: [positions, velocities] @@ -507,6 +498,10 @@ # information-collapse diagnostic. for i in range(Utility.params['num_robots']): loss = models[i].learn() + # TD3's learn_TD3 returns (0, 0) on non-actor-update steps; + # unwrap so the hdf5 logger's 1D loss array stays homogeneous. + if isinstance(loss, tuple): + loss = loss[0] gsp_losses = [ m.last_gsp_loss for m in models if getattr(m, "last_gsp_loss", None) is not None @@ -515,6 +510,8 @@ hdf5_writer.record_gsp_loss(float(np.mean(gsp_losses))) else: loss = model.learn() + if isinstance(loss, tuple): + loss = loss[0] gsp_step_loss = getattr(model, "last_gsp_loss", None) if gsp_step_loss is not None: hdf5_writer.record_gsp_loss(gsp_step_loss) diff --git a/run_baseline_experiments.py b/run_baseline_experiments.py index 9f58674..e8bc48c 100644 --- a/run_baseline_experiments.py +++ b/run_baseline_experiments.py @@ -159,7 +159,7 @@ def make_config(exp_name, gsp, neighbors, num_obstacles, use_gate, gate_curricul "NOISE": 0.1, "UPDATE_ACTOR_ITER": 2, "WARMUP": 1000, - "GSP_LEARNING_FREQUENCY": 500, + "GSP_LEARNING_FREQUENCY": 4, "LEARN_EVERY": 4, "GSP_BATCH_SIZE": 256, } diff --git a/tests/test_main_gsp_contract.py b/tests/test_main_gsp_contract.py new file mode 100644 index 0000000..b13ce2b --- /dev/null +++ b/tests/test_main_gsp_contract.py @@ -0,0 +1,48 @@ +"""Contract test for Main.py's store_gsp_transition calls. + +Under the direct-MSE GSP training path (GSP-RL fix/gsp-direct-mse-training), +Main.py must store the LABEL (not the previous prediction) in the action +field of the replay buffer for ALL GSP variants. This test asserts the call +shape via static inspection — it's not a runtime test, but it catches +regressions in the call signature without needing ARGoS. + +See Stelaris docs/research/2026-04-13-gsp-information-collapse-analysis.md +for the full rationale. +""" + +import pathlib +import re + + +MAIN_PY = pathlib.Path(__file__).resolve().parent.parent / "rl_code" / "Main.py" + + +def test_store_gsp_transition_passes_label_as_action_argument(): + """Every store_gsp_transition call in Main.py must pass label-related arg as 2nd. + + The 2nd positional arg of store_gsp_transition is the action field. Under + direct-MSE training, the GSP predictor is trained via MSE against the label + stored in this field. If any call site passes the previous prediction + (`old_heading_gsp[i]`, `action`, etc.) instead of `label`, the predictor + trains against its own old output — no supervision signal. + """ + text = MAIN_PY.read_text() + # Match any store_gsp_transition(...) call and capture its entire arg list. + # Use re.DOTALL in case the call spans multiple lines. + calls = re.findall(r'store_gsp_transition\s*\(([^)]*)\)', text, re.DOTALL) + assert len(calls) >= 3, ( + f"expected at least 3 store_gsp_transition calls (one per branch in the " + f"independent/shared/attention block), got {len(calls)}" + ) + violations = [] + for i, call in enumerate(calls): + args = [a.strip() for a in call.split(",")] + assert len(args) >= 2, f"call {i}: expected at least 2 args, got {args}" + second_arg = args[1] + if "label" not in second_arg: + violations.append((i, second_arg)) + assert not violations, ( + f"{len(violations)} store_gsp_transition call(s) do not pass `label` as " + f"the 2nd arg (action field); violations: {violations}. Under direct-MSE " + f"GSP training the action field must carry the supervised target." + )