Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions config/drive.ini
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ obs_slots_lane = 40
obs_slots_boundary = 40
obs_slots_partners = 16
obs_slots_traffic_controls = 4
; --- Robustness features ---
; Per-episode probability that an agent is blind to other agents for the whole episode.
; Blind agents see zeroed partner observations and are masked out of the PPO rollout buffer.
partner_blindness_prob = 0.0
; --- Observation normalization (meters) ---
norm_goal_offset_m = 100.0
norm_xy_offset_m = 100.0
Expand Down
1 change: 1 addition & 0 deletions sim/datatypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ struct Agent {
float distance_since_spawn;
int stopped;
int removed;
int is_blind_partner; // Episode-level flag: agent sees no other agents
// Goal positions
float goal_positions_x[MAX_TARGET_POINTS];
float goal_positions_y[MAX_TARGET_POINTS];
Expand Down
16 changes: 15 additions & 1 deletion sim/drive.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@ struct Drive {
float obs_range_partner_m;
float obs_range_traffic_control_m;
int obs_size;
// Robustness features
float partner_blindness_prob;
};

#include "dataloader.h"
Expand Down Expand Up @@ -2520,6 +2522,11 @@ static int write_reward_target_obs(Drive *env, Agent *ego, float *obs, int obs_i
}

static int write_partner_obs(Drive *env, Agent *ego, int agent_idx, float *obs, int obs_idx, int *partner_count) {
if (ego->is_blind_partner) {
// Partner slots stay zero (compute_observations memset's the buffer).
*partner_count = 0;
return obs_idx + env->obs_slots_partners * PARTNER_FEATURES;
}
typedef struct {
int index;
float dist_sq;
Expand Down Expand Up @@ -2884,6 +2891,11 @@ void c_reset(Drive *env) {
env->logs[i] = (Log) {0};
Agent *agent = &env->agents[i];

// Sample episode-level erratic flags before any continue, so every
// agent (including those that fail to spawn) gets a fresh value.
agent->is_blind_partner =
(env->partner_blindness_prob > 0.0f && random_uniform(0.0f, 1.0f) < env->partner_blindness_prob) ? 1 : 0;

if (env->simulation_mode == SIMULATION_GIGAFLOW && agent->removed) {
continue;
}
Expand All @@ -2904,9 +2916,11 @@ void c_step(Drive *env) {
memset(env->masks, 0, env->num_agents * sizeof(unsigned char));
env->timestep++;

// Erratic-driver flags (e.g. blind partners) act in the world but their
// transitions are excluded from the PPO rollout buffer per GIGAFLOW Appendix B.4.
for (int i = 0; i < env->num_agents; i++) {
Agent *a = &env->agents[i];
env->masks[i] = !(a->stopped || a->removed);
env->masks[i] = !(a->stopped || a->removed || a->is_blind_partner);
}

for (int i = 0; i < env->num_moving_log_agents; i++) {
Expand Down
4 changes: 3 additions & 1 deletion sim/env_fields.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
F(float, obs_range_road_behind_m) \
F(float, obs_range_road_side_m) \
F(float, obs_range_partner_m) \
F(float, obs_range_traffic_control_m)
F(float, obs_range_traffic_control_m) \
/* Robustness features */ \
F(float, partner_blindness_prob)

#endif
Loading