diff --git a/rl_code/Main.py b/rl_code/Main.py
index aefba90..e851655 100644
--- a/rl_code/Main.py
+++ b/rl_code/Main.py
@@ -378,51 +378,41 @@

             # print("-------------------------------------------------")
             # print('[GSP]', next_heading_gsp)
-            # Store GSP Transition
+            # Store GSP Transition — guard by per-robot force magnitude.
+            # GSP_STORE_FORCE_THRESHOLD concentrates training on samples where
+            # the robot is actively applying force (top ~25% of samples at
+            # threshold ~4.0), which multiplies the linear-R² ceiling of the
+            # prediction problem 3–4× (see
+            # docs/research/2026-04-13-gsp-ddpg-vs-attention-collapse.md).
+            # <= 0.0 = filter disabled (legacy behavior); the force check is bypassed.
+            force_thr = float(config.get('GSP_STORE_FORCE_THRESHOLD', 0.0))
             if model.gsp_neighbors:
                 states, state_prox_flags = model.make_gsp_states(old_agent_prox_flags, neighbors_old_heading_gsp, True)
                 new_states = model.make_gsp_states(agent_prox_flags, old_heading_gsp)
                 for i in range(Utility.params['num_robots']):
-                    if np.sum(state_prox_flags[i]) > 0:
+                    if np.sum(state_prox_flags[i]) > 0 and (force_thr <= 0.0 or stats[i][0] > force_thr):
                         if model.gsp_networks['learning_scheme'] == 'attention':
                             model.store_gsp_transition(states[i], label, 0, 0, 0)
                         else:
-                            # 2nd arg = label (supervised target for direct-MSE GSP training)
                             state = states[i]
                             new_state = new_states[i]
                             model.store_gsp_transition(state, label, 0, new_state, 0)
             elif model.gsp_broadcast:
-                # GSP-B per-agent storage with broadcast inputs.
-                # state_t : broadcast view at previous step (uses neighbors_old_heading_gsp so
-                # the prev_gsp slot reflects the prediction from the previous tick)
-                # state_{t+1}: broadcast view at current step
                 states = model.make_gsp_states_broadcast(old_agent_prox_flags, neighbors_old_heading_gsp)
                 new_states = model.make_gsp_states_broadcast(agent_prox_flags, old_heading_gsp)
                 for i in range(Utility.params['num_robots']):
-                    # Gate on self-prox being non-zero so we only store informative transitions,
-                    # matching the GSP and GSP-N branches. Self-prox lives at index 0 under the
-                    # self-first layout.
-                    if states[i][0] != 0:
+                    if states[i][0] != 0 and (force_thr <= 0.0 or stats[i][0] > force_thr):
                         model.store_gsp_transition(states[i], label, 0, new_states[i], 0)
             else:
                 for i in range(Utility.params['num_robots']):
-                    if model.gsp_networks['learning_scheme'] == 'attention':
-                        state = np.array(old_agent_prox_flags)
-                        # only store the state if it has value
-                        if np.sum(state) > 0:
+                    state = np.array(old_agent_prox_flags)
+                    if np.sum(state) > 0 and (force_thr <= 0.0 or stats[i][0] > force_thr):
+                        if model.gsp_networks['learning_scheme'] == 'attention':
                             model.store_gsp_transition(state, label, 0, 0, 0)
-                    elif args.independent_learning:
-                        state = np.array(old_agent_prox_flags)
-                        # only store the state if it has value
-                        if np.sum(state) > 0:
-                            # 2nd arg = label (supervised target for direct-MSE GSP training)
+                        elif args.independent_learning:
                             new_state = np.array(agent_prox_flags)
                             models[i].store_gsp_transition(state, label, 0, new_state, 0)
-                    else:
-                        state = np.array(old_agent_prox_flags)
-                        # only store the state if it has value
-                        if np.sum(state) > 0:
-                            # 2nd arg = label (supervised target for direct-MSE GSP training)
+                        else:
                             new_state = np.array(agent_prox_flags)
                             model.store_gsp_transition(state, label, 0, new_state, 0)

diff --git a/run_baseline_experiments.py b/run_baseline_experiments.py
index e8bc48c..84ea9b4 100644
--- a/run_baseline_experiments.py
+++ b/run_baseline_experiments.py
@@ -162,6 +162,15 @@ def make_config(exp_name, gsp, neighbors, num_obstacles, use_gate, gate_curricul
         "GSP_LEARNING_FREQUENCY": 4,
         "LEARN_EVERY": 4,
         "GSP_BATCH_SIZE": 256,
+        # Per-robot force_magnitude threshold for GSP replay buffer store filter.
+        # <= 0.0 = disabled (store every transition with prox activity, legacy behavior).
+        # > 0 = only store transitions where stats[i][0] (force_magnitude) exceeds
+        # the threshold. This concentrates GSP training on samples where the robot
+        # is actively applying force, which empirically multiplies the linear-R²
+        # ceiling of the prediction problem 3–4× (see
+        # docs/research/2026-04-13-gsp-ddpg-vs-attention-collapse.md in Stelaris).
+        # Recommended starting point: ~4.0 (≈ p75 of force_magnitude in 2-obstacle runs).
+        "GSP_STORE_FORCE_THRESHOLD": 0.0,
     }

