diff --git a/config/drive.ini b/config/drive.ini index 3211f735b..b46a4e75b 100644 --- a/config/drive.ini +++ b/config/drive.ini @@ -81,6 +81,10 @@ obs_slots_lane = 40 obs_slots_boundary = 40 obs_slots_partners = 16 obs_slots_traffic_controls = 4 +; --- Robustness features --- +; Per-episode probability that an agent is blind to other agents for the whole episode. +; Blind agents see zeroed partner observations and are masked out of the PPO rollout buffer. +partner_blindness_prob = 0.0 ; --- Observation normalization (meters) --- norm_goal_offset_m = 100.0 norm_xy_offset_m = 100.0 diff --git a/sim/datatypes.h b/sim/datatypes.h index ee278648a..8c3b62d0d 100644 --- a/sim/datatypes.h +++ b/sim/datatypes.h @@ -181,6 +181,7 @@ struct Agent { float distance_since_spawn; int stopped; int removed; + int is_blind_partner; // Episode-level flag: agent sees no other agents // Goal positions float goal_positions_x[MAX_TARGET_POINTS]; float goal_positions_y[MAX_TARGET_POINTS]; diff --git a/sim/drive.h b/sim/drive.h index 7d4126e9e..441b09adf 100644 --- a/sim/drive.h +++ b/sim/drive.h @@ -236,6 +236,8 @@ struct Drive { float obs_range_partner_m; float obs_range_traffic_control_m; int obs_size; + // Robustness features + float partner_blindness_prob; }; #include "dataloader.h" @@ -2520,6 +2522,11 @@ static int write_reward_target_obs(Drive *env, Agent *ego, float *obs, int obs_i } static int write_partner_obs(Drive *env, Agent *ego, int agent_idx, float *obs, int obs_idx, int *partner_count) { + if (ego->is_blind_partner) { + // Partner slots stay zero (compute_observations memset's the buffer). + *partner_count = 0; + return obs_idx + env->obs_slots_partners * PARTNER_FEATURES; + } typedef struct { int index; float dist_sq; @@ -2884,6 +2891,11 @@ void c_reset(Drive *env) { env->logs[i] = (Log) {0}; Agent *agent = &env->agents[i]; + // Sample episode-level erratic flags before any continue, so every + // agent (including those that fail to spawn) gets a fresh value. 
+ agent->is_blind_partner = + (env->partner_blindness_prob > 0.0f && random_uniform(0.0f, 1.0f) < env->partner_blindness_prob) ? 1 : 0; + if (env->simulation_mode == SIMULATION_GIGAFLOW && agent->removed) { continue; } @@ -2904,9 +2916,11 @@ void c_step(Drive *env) { memset(env->masks, 0, env->num_agents * sizeof(unsigned char)); env->timestep++; + // Agents flagged as erratic drivers (e.g. blind partners) still act in the world, but their + // mask is cleared so their transitions are excluded from the PPO rollout buffer (GIGAFLOW Appendix B.4). for (int i = 0; i < env->num_agents; i++) { Agent *a = &env->agents[i]; - env->masks[i] = !(a->stopped || a->removed); + env->masks[i] = !(a->stopped || a->removed || a->is_blind_partner); } for (int i = 0; i < env->num_moving_log_agents; i++) { diff --git a/sim/env_fields.h b/sim/env_fields.h index 5182d0f71..80d04aa66 100644 --- a/sim/env_fields.h +++ b/sim/env_fields.h @@ -63,6 +63,8 @@ F(float, obs_range_road_behind_m) \ F(float, obs_range_road_side_m) \ F(float, obs_range_partner_m) \ - F(float, obs_range_traffic_control_m) + F(float, obs_range_traffic_control_m) \ + /* Robustness features */ \ + F(float, partner_blindness_prob) #endif