Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 15 additions & 18 deletions rl_code/Main.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,15 +384,13 @@
if model.gsp_networks['learning_scheme'] == 'attention':
model.store_gsp_transition(states[i], label, 0, 0, 0)
else:
# Under the direct-MSE GSP training path, the 2nd arg
# (action field) carries the supervised target label.
# See GSP-RL fix/gsp-direct-mse-training PR #24 and
# Stelaris docs/research/2026-04-13-gsp-information-collapse-analysis.md.
state = states[i]
action = old_heading_gsp[i]
reward = gsp_reward[i]
new_state = new_states[i]
# print('[MAIN] Transition State:', state)
# print('[MAIN] Transition Action:', action)
# print('[MAIN] Transition Reward:', reward)
# print('[MAIN] Transition New State:', new_state)
model.store_gsp_transition(state, action, reward, new_state, 0)
model.store_gsp_transition(state, label, 0, new_state, 0)
else:
for i in range(Utility.params['num_robots']):
if model.gsp_networks['learning_scheme'] == 'attention':
Expand All @@ -404,23 +402,16 @@
state = np.array(old_agent_prox_flags)
# only store the state if it has value
if np.sum(state) > 0:
action = old_heading_gsp[i]
reward = gsp_reward[i]
# 2nd arg = label (supervised target for direct-MSE GSP training)
new_state = np.array(agent_prox_flags)
models[i].store_gsp_transition(state, action, reward, new_state, 0)
models[i].store_gsp_transition(state, label, 0, new_state, 0)
else:
# print(f'[AGENT] {i} GSP:', old_heading_gsp[i])
state = np.array(old_agent_prox_flags)
# only store the state if it has value
if np.sum(state) > 0:
action = old_heading_gsp[i]
reward = gsp_reward[i]
# 2nd arg = label (supervised target for direct-MSE GSP training)
new_state = np.array(agent_prox_flags)
# print('[MAIN] Transition State:', state)
# print('[MAIN] Transition Action:', action)
# print('[MAIN] Transition Reward:', reward)
# print('[MAIN] Transition New State:', new_state)
model.store_gsp_transition(state, action, reward, new_state, 0)
model.store_gsp_transition(state, label, 0, new_state, 0)


#Define Global Knowledge: [positions, velocities]
Expand Down Expand Up @@ -507,6 +498,10 @@
# information-collapse diagnostic.
for i in range(Utility.params['num_robots']):
loss = models[i].learn()
# TD3's learn_TD3 returns (0, 0) on non-actor-update steps;
# unwrap so the hdf5 logger's 1D loss array stays homogeneous.
if isinstance(loss, tuple):
loss = loss[0]
gsp_losses = [
m.last_gsp_loss for m in models
if getattr(m, "last_gsp_loss", None) is not None
Expand All @@ -515,6 +510,8 @@
hdf5_writer.record_gsp_loss(float(np.mean(gsp_losses)))
else:
loss = model.learn()
if isinstance(loss, tuple):
loss = loss[0]
gsp_step_loss = getattr(model, "last_gsp_loss", None)
if gsp_step_loss is not None:
hdf5_writer.record_gsp_loss(gsp_step_loss)
Expand Down
2 changes: 1 addition & 1 deletion run_baseline_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def make_config(exp_name, gsp, neighbors, num_obstacles, use_gate, gate_curricul
"NOISE": 0.1,
"UPDATE_ACTOR_ITER": 2,
"WARMUP": 1000,
"GSP_LEARNING_FREQUENCY": 500,
"GSP_LEARNING_FREQUENCY": 4,
"LEARN_EVERY": 4,
"GSP_BATCH_SIZE": 256,
}
Expand Down
48 changes: 48 additions & 0 deletions tests/test_main_gsp_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Contract test for Main.py's store_gsp_transition calls.

Under the direct-MSE GSP training path (GSP-RL fix/gsp-direct-mse-training),
Main.py must store the LABEL (not the previous prediction) in the action
field of the replay buffer for ALL GSP variants. This test asserts the call
shape via static inspection — it's not a runtime test, but it catches
regressions in the call signature without needing ARGoS.

See Stelaris docs/research/2026-04-13-gsp-information-collapse-analysis.md
for the full rationale.
"""

import pathlib
import re


MAIN_PY = pathlib.Path(__file__).resolve().parent.parent / "rl_code" / "Main.py"


def test_store_gsp_transition_passes_label_as_action_argument():
"""Every store_gsp_transition call in Main.py must pass label-related arg as 2nd.

The 2nd positional arg of store_gsp_transition is the action field. Under
direct-MSE training, the GSP predictor is trained via MSE against the label
stored in this field. If any call site passes the previous prediction
(`old_heading_gsp[i]`, `action`, etc.) instead of `label`, the predictor
trains against its own old output — no supervision signal.
"""
text = MAIN_PY.read_text()
# Match any store_gsp_transition(...) call and capture its entire arg list.
# Use re.DOTALL in case the call spans multiple lines.
calls = re.findall(r'store_gsp_transition\s*\(([^)]*)\)', text, re.DOTALL)
assert len(calls) >= 3, (
f"expected at least 3 store_gsp_transition calls (one per branch in the "
f"independent/shared/attention block), got {len(calls)}"
)
violations = []
for i, call in enumerate(calls):
args = [a.strip() for a in call.split(",")]
assert len(args) >= 2, f"call {i}: expected at least 2 args, got {args}"
second_arg = args[1]
if "label" not in second_arg:
violations.append((i, second_arg))
assert not violations, (
f"{len(violations)} store_gsp_transition call(s) do not pass `label` as "
f"the 2nd arg (action field); violations: {violations}. Under direct-MSE "
f"GSP training the action field must carry the supervised target."
)
Loading