Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
102 commits
Select commit Hold shift + click to select a range
270f5a8
CAUTION: Action space change + dynamics model bug fix + optional jerk…
daphne-cornelisse Apr 28, 2026
a5cde03
New defaults and checkpoint.
daphne-cornelisse Apr 28, 2026
e3749d9
Update max action values for dy and yaw.
daphne-cornelisse Apr 28, 2026
1cf8ba5
Fit action space.
daphne-cornelisse Apr 28, 2026
f4f42ce
Fixes for BC training script.
daphne-cornelisse Apr 28, 2026
de60268
Add gpu heartbeat
Apr 28, 2026
4db633e
Add sanity checking functions.
daphne-cornelisse Apr 28, 2026
af65ebb
Log per-head anchor entropy during regularization.
daphne-cornelisse Apr 28, 2026
a779c79
Small adaptation.
daphne-cornelisse Apr 28, 2026
b089c59
Minor: Update csb
daphne-cornelisse Apr 28, 2026
37e7ab9
Update anchor policies.
Apr 28, 2026
2aa3df5
Merge branch 'gsp_v0' of https://github.com/Emerge-Lab/PufferDrive in…
daphne-cornelisse Apr 28, 2026
8b0a52d
Minor anchor update.
daphne-cornelisse Apr 28, 2026
5fdd7cd
Add longitudinal displacement.
daphne-cornelisse Apr 28, 2026
fd575e5
Bug fix.
daphne-cornelisse Apr 28, 2026
bad76ac
Enable jerk penalty.
daphne-cornelisse Apr 29, 2026
bb7ffe5
Script to make a bunch of BC vids.
daphne-cornelisse Apr 29, 2026
323fe18
Update analyses.
daphne-cornelisse Apr 29, 2026
890cfbc
Jerk penalty for smoothness; still iterating but visible improvements.
daphne-cornelisse Apr 29, 2026
e421b19
port renderer (#413)
julianh65 Apr 29, 2026
50fca0a
Lower jerk a little bit.
daphne-cornelisse Apr 29, 2026
dc4e263
Minor.
daphne-cornelisse Apr 29, 2026
5645430
Allow goal completion at goal radius.
daphne-cornelisse Apr 29, 2026
cef2c63
Draft timestep in top left corner.
daphne-cornelisse Apr 29, 2026
0017ba7
Condition delta-local on prev action. CAUTION: BREAKS ALL PREV CPTS.
daphne-cornelisse Apr 30, 2026
e0ace3f
Implement physical constraints into delta-local model.
daphne-cornelisse May 1, 2026
f739cb3
Minor
daphne-cornelisse May 1, 2026
1d52025
Update anchors.
daphne-cornelisse May 1, 2026
9a5f792
Fix path.
daphne-cornelisse May 1, 2026
7ddba8d
Add new anchor policy.
daphne-cornelisse May 1, 2026
c3b96e1
Delete old cpts (saved locally).
daphne-cornelisse May 1, 2026
54e0c03
Cleanup.
daphne-cornelisse May 1, 2026
30a9dec
Add new baseline runs.
daphne-cornelisse May 1, 2026
0274f0d
Visualize velocity in human-replay setting.
daphne-cornelisse May 1, 2026
b121326
New improved bc cpts.
daphne-cornelisse May 2, 2026
9be16a0
Bug fix: do not apply constraints during data collection.
daphne-cornelisse May 2, 2026
d5a35a8
Delete old cpts.
daphne-cornelisse May 2, 2026
751e8ae
Fix heading diff bug in data collection.
daphne-cornelisse May 2, 2026
aaa68f3
Debugging BC: Turned off teleport - this gives me pretty good CL perf.
daphne-cornelisse May 2, 2026
5985f8e
Minor.
daphne-cornelisse May 2, 2026
5443a38
Update BC training script.
daphne-cornelisse May 2, 2026
e02b4c1
Minor
daphne-cornelisse May 2, 2026
e8aac25
Temp change for vastai.
daphne-cornelisse May 2, 2026
1448e45
Add val test back.
daphne-cornelisse May 2, 2026
3de9414
bc_pol
daphne-cornelisse May 2, 2026
23e5a02
Delete old anchors.
daphne-cornelisse May 2, 2026
437dff6
Add two solid bc cpts.
daphne-cornelisse May 2, 2026
52d458f
Ensure correct truncation of self-play videos.
daphne-cornelisse May 3, 2026
f2903b1
Minor fixes
daphne-cornelisse May 3, 2026
a6c57c2
Loosen kinematic constraints.
daphne-cornelisse May 3, 2026
a2433f4
Temporarily turn off jerk penalty.
daphne-cornelisse May 3, 2026
195e073
Make map sampling undeterministic.
daphne-cornelisse May 3, 2026
cc3b446
Report final score as well.
daphne-cornelisse May 3, 2026
9218caa
Clean up data paths.
daphne-cornelisse May 3, 2026
e4cbae7
Bug fix that corrupts stats.
daphne-cornelisse May 3, 2026
b4652f1
Update evals.
daphne-cornelisse May 3, 2026
6ff93d5
Comment out debug print statements.
daphne-cornelisse May 3, 2026
72df523
Implement blind agents.
daphne-cornelisse May 3, 2026
fa8922f
Tiny bug fix.
daphne-cornelisse May 3, 2026
a2f7f0f
Human expert data eval.
daphne-cornelisse May 3, 2026
4a29f3f
Bug fix: missing parentheses.
daphne-cornelisse May 3, 2026
1176de1
More fine-grained action space == better stats.
daphne-cornelisse May 3, 2026
6eef878
Analysis scripts.
daphne-cornelisse May 3, 2026
91d5895
Turn off blind agents for now.
daphne-cornelisse May 3, 2026
c35fa87
12 bc pol
daphne-cornelisse May 3, 2026
73997e7
Delete old anchors
daphne-cornelisse May 3, 2026
a060c01
New anchor policies 123
daphne-cornelisse May 3, 2026
9426486
New defaults.
daphne-cornelisse May 3, 2026
0a86584
Minor
daphne-cornelisse May 3, 2026
30c0277
Never give offoad rewards.
daphne-cornelisse May 3, 2026
534901c
Lambda conditioning: Push some agents away from lambda.
daphne-cornelisse May 3, 2026
2a5479d
Lambda cond sanity check.
daphne-cornelisse May 3, 2026
7d5f3f5
12k cpt.
daphne-cornelisse May 3, 2026
175c3d8
Merge branch 'gsp_v0' of https://github.com/Emerge-Lab/PufferDrive in…
daphne-cornelisse May 3, 2026
a3ee12a
Comment out lambda logging
daphne-cornelisse May 4, 2026
5690357
Cleanup and add new baseline cpt.
daphne-cornelisse May 4, 2026
39b9f59
Clean up main analysis script and add new best reg cpt.
daphne-cornelisse May 4, 2026
a8f44ad
Check human data script.
daphne-cornelisse May 4, 2026
e58cbfa
Fix ADE bug.
daphne-cornelisse May 5, 2026
cb56d26
Update analyses scripts.
daphne-cornelisse May 5, 2026
a3f7d8b
Update cpts.
daphne-cornelisse May 5, 2026
431823a
Perf improvement: comment out prev action.
daphne-cornelisse May 5, 2026
32e997e
Re-introduce jerk penalty
daphne-cornelisse May 5, 2026
72c74f7
Minor Figure update.
daphne-cornelisse May 5, 2026
29b5adc
Minor Figure update.
daphne-cornelisse May 5, 2026
d98db67
Add new anchor policies - trained without teleport.
daphne-cornelisse May 5, 2026
2700463
Last BC anchor update.
daphne-cornelisse May 5, 2026
43c712c
Prepare .ini file.
daphne-cornelisse May 5, 2026
00cf448
Minor ini.
daphne-cornelisse May 5, 2026
a872e5f
BC settings.
daphne-cornelisse May 5, 2026
da72e36
Update checkpoint.
daphne-cornelisse May 5, 2026
18d31bc
Convert mp4 to gifs.
daphne-cornelisse May 5, 2026
451233d
New checkpoints.
daphne-cornelisse May 6, 2026
1fa9fdf
Add new cpts.
daphne-cornelisse May 6, 2026
d8fa4a5
Update settings.
daphne-cornelisse May 6, 2026
eba33d3
Implement severity measures.
daphne-cornelisse May 6, 2026
2288bb6
Add severity analysis.
daphne-cornelisse May 6, 2026
d451f11
Improve plots.
daphne-cornelisse May 6, 2026
a6b613d
Restrict collision severity analysis to collisions caused by tthe age…
daphne-cornelisse May 6, 2026
e47b4f3
Wbd/gsp idm (#418)
WaelDLZ May 6, 2026
2e5e9e8
Integrate IDM eval.
daphne-cornelisse May 6, 2026
66559b3
Plotting stuff
daphne-cornelisse May 6, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,5 @@ logs/*

eval_videos/*
results/*
videos/*
anchor_videos/*
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -226,4 +226,4 @@ Run
```
puffer verify puffer_drive --env.render-mode 0
```
with `control_mode` set to `expert_replay` or `inferred_`. Note: Currently only supported with classic dynamics model.
with `control_mode` set to `expert_replay` or `inferred_`. Note: Currently only supported with classic dynamics model.
131 changes: 81 additions & 50 deletions examples/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,15 @@
from pufferlib.ocean.benchmark.evaluator_minimal import CheckpointEvaluator

# ─── USER CONFIG ────────────────────────────────────────────────────────────────
CHECKPOINTS = [
# "models/cpts_best/reg_delta_50k_maps_anchor_100_maps.pt"
# "models/rl/reg_self_play_50k.pt",
]

SCALING_CHECKPOINTS_PATH = (
"models/scaling_cpts" # Directory containing scaling checkpoints following the naming convention described above
)
DETERMINISTIC = True
SCALING_CHECKPOINTS_PATH = "models/scaling_cpts" # "models/scaling_cpts"
DETERMINISTIC = False

TRAIN_MAP_DIR = "resources/drive/binaries/training_50k"
TRAIN_MAP_DIR = "resources/drive/binaries/training" # 50k maps
VAL_MAP_DIR = "resources/drive/binaries/validation" # 10k maps
INTERACTIVE_MAP_DIR = "resources/drive/binaries/interactive_data_validation" # 200 maps selected for SDC interactivity
NUM_TOTAL_EVAL_AGENTS = 1024 * 5
IDM_MAP_DIR = "resources/drive/binaries/interactive_200_idm" # Same 200 maps selected for SDC interactivity but processed in a different way
INTERACTIVE_MAP_DIR_MAPS = 200
NUM_TOTAL_EVAL_AGENTS = 1024 * 3
NUM_AGENTS_PER_VECENV = 1024
ENV_NAME = "puffer_drive"
DATASET = "womd"
Expand All @@ -62,12 +57,11 @@

# ─── VIDEO RENDERING CONFIG ─────────────────────────────────────────────────
CHECKPOINTS_TO_RENDER = ["models/scaling_cpts/unreg_classic_50k_maps.pt"]
NUM_ENVS_TO_RENDER = 3
NUM_ENVS_TO_RENDER = 0
RENDER_MAP_DIR = INTERACTIVE_MAP_DIR # Which maps to render on
RENDER_NUM_MAPS = 200
RENDER_OUTPUT_DIR = "eval_videos"
RENDER_MODE = "worst_collision" # "random" or "worst_collision"

# ────────────────────────────────────────────────────────────────────────────────

METRICS = [
Expand All @@ -76,15 +70,23 @@
"collision_rate",
"at_fault_collision_rate",
"rear_collision_rate",
# Delta-V
"delta_v_sum",
"delta_v_max",
"delta_v_count",
"delta_v_under_1mph",
# Other
"collisions_per_agent",
"offroad_rate",
"offroad_per_agent",
"completion_rate",
"route_progress",
"lateral_error_avg",
"episode_length",
"episode_return",
"perc_controlled",
"lateral_error_avg",
"longitudinal_error_avg",
"displacement_error_avg",
]


Expand All @@ -109,31 +111,24 @@ def _parse_num(s):
return n


def make_eval_config(cpt_config, map_dir, control_mode, num_maps, lambda_value, episode_length=110):
"""Build an eval-ready config from the checkpoint config.

Takes everything from the checkpoint and only overwrites eval-specific fields:
map_dir, control_mode, num_maps, lambda_value, and optionally episode_length.
"""
def make_eval_config(cpt_config, map_dir, control_mode, num_maps, episode_length=200, controller_overrides=None):
config = copy.deepcopy(cpt_config)
config["env"]["map_dir"] = map_dir
config["env"]["control_mode"] = control_mode
config["env"]["num_maps"] = num_maps
config["env"]["num_agents"] = NUM_AGENTS_PER_VECENV
config["env"]["lambda_value"] = lambda_value
# Fixed: Important for getting valid stats
config["env"]["async_resets"] = False

config["env"]["goal_behavior"] = 0
config["env"]["render_mode"] = 1
config["env"]["termination_mode"] = 1
config["env"]["fix_lambdas"] = True
config["env"]["fix_rewards"] = True
config["env"]["obs_partner_noise_speed"] = 0.0
config["env"]["obs_partner_noise_pos"] = 0.0
config["env"]["termination_mode"] = 1
if episode_length is not None:
config["env"]["episode_length"] = episode_length
if controller_overrides is not None: # <-- NEW
config["env"].update(controller_overrides)
config["vec"] = dict(backend="PufferEnv", num_envs=1)
return config

Expand Down Expand Up @@ -166,9 +161,10 @@ def num_resample_rounds():
return (NUM_TOTAL_EVAL_AGENTS + NUM_AGENTS_PER_VECENV - 1) // NUM_AGENTS_PER_VECENV


def run_mode(evaluator, policy, cpt_config, map_dir, control_mode, checkpoint, mode_name, num_maps, lambda_value=0.0):
"""Create env, rollout (with resampling if needed), collect per-scene rows, close env."""
config = make_eval_config(cpt_config, map_dir, control_mode, num_maps, lambda_value)
def run_mode(
evaluator, policy, cpt_config, map_dir, control_mode, checkpoint, mode_name, num_maps, controller_overrides=None
):
config = make_eval_config(cpt_config, map_dir, control_mode, num_maps, controller_overrides=controller_overrides)
env = load_env(ENV_NAME, config)
rows = []
n_rounds = num_resample_rounds()
Expand All @@ -178,7 +174,7 @@ def run_mode(evaluator, policy, cpt_config, map_dir, control_mode, checkpoint, m
if round_idx > 0:
env.driver_env.resample_maps()

rollout_stats = evaluator.rollout(policy, env, deterministic=DETERMINISTIC)
rollout_stats = evaluator.rollout(env=env, policy=policy, deterministic=DETERMINISTIC)
scene_offset = round_idx * env.driver_env.num_envs
rows.extend(process_rollout_data(rollout_stats, checkpoint, mode_name, scene_offset))

Expand Down Expand Up @@ -206,9 +202,7 @@ def evaluate_checkpoint(checkpoint_path, base_config):
cpt_config["load_model_path"] = checkpoint_path

# Create first env before loading policy (load_policy needs vecenv.driver_env)
sp_train_config = make_eval_config(
cpt_config, TRAIN_MAP_DIR, control_mode="control_vehicles", num_maps=50_000, lambda_value=0.0
)
sp_train_config = make_eval_config(cpt_config, TRAIN_MAP_DIR, control_mode="control_vehicles", num_maps=50_000)
env = load_env(ENV_NAME, sp_train_config)

policy = load_policy(cpt_config, env, ENV_NAME)
Expand All @@ -224,7 +218,7 @@ def evaluate_checkpoint(checkpoint_path, base_config):
if round_idx > 0:
env.driver_env.resample_maps()

info_list = evaluator.rollout(policy, env, deterministic=DETERMINISTIC)
info_list = evaluator.rollout(env=env, policy=policy, deterministic=DETERMINISTIC)
scene_offset = round_idx * env.driver_env.num_envs
all_rows.extend(process_rollout_data(info_list, checkpoint_path, "sp_train", scene_offset))

Expand Down Expand Up @@ -275,7 +269,7 @@ def evaluate_checkpoint(checkpoint_path, base_config):
"control_sdc_only",
checkpoint_path,
"hr_interactive",
num_maps=200,
num_maps=INTERACTIVE_MAP_DIR_MAPS,
)
)

Expand Down Expand Up @@ -375,7 +369,6 @@ def evaluate_scaling_checkpoints(base_config):
TRAIN_MAP_DIR,
control_mode="control_vehicles",
num_maps=50_000,
lambda_value=0.1 if is_reg else 0.0,
)
sp_train_env = load_env(ENV_NAME, sp_train_config)
policy = load_policy(cpt_config, sp_train_env, ENV_NAME)
Expand All @@ -394,7 +387,6 @@ def evaluate_scaling_checkpoints(base_config):
cpt_path,
"scaling_sp_train",
num_maps=50_000,
lambda_value=0.1 if is_reg else 0.0,
)

# ── Self-play on validation ──────────────────────────────────────
Expand All @@ -407,7 +399,6 @@ def evaluate_scaling_checkpoints(base_config):
cpt_path,
"scaling_sp_val",
num_maps=10_000,
lambda_value=0.1 if is_reg else 0.0,
)

# ── Human-replay on randomly sampled validation scenes ───────────
Expand All @@ -420,7 +411,6 @@ def evaluate_scaling_checkpoints(base_config):
cpt_path,
"scaling_hr_val",
num_maps=10_000,
lambda_value=0.1 if is_reg else 0.0, # TODO: Fix this
)

# ── Human-replay on interactive scenes ───────────────────────────
Expand All @@ -432,8 +422,24 @@ def evaluate_scaling_checkpoints(base_config):
"control_sdc_only",
cpt_path,
"scaling_hr_interactive",
num_maps=200,
lambda_value=0.1 if is_reg else 0.0,
num_maps=INTERACTIVE_MAP_DIR_MAPS,
)

# ── IDM eval on interactive scenes ───────────────────────────────
idm_interactive_rows = run_mode(
evaluator,
policy,
cpt_config,
IDM_MAP_DIR,
"control_sdc_only",
cpt_path,
"scaling_idm_interactive",
num_maps=INTERACTIVE_MAP_DIR_MAPS,
controller_overrides={
"sdc_controller": "policy",
"non_sdc_controller": "idm",
"non_vehicle_controller": "replay",
},
)

# Attach scaling metadata to every row
Expand All @@ -447,13 +453,14 @@ def evaluate_scaling_checkpoints(base_config):
all_rows.extend(sp_val_rows)
all_rows.extend(hr_val_rows)
all_rows.extend(hr_interactive_rows)
all_rows.extend(idm_interactive_rows)

return all_rows


def make_render_config(cpt_config, map_dir, num_maps=1000):
"""Build a config for human-replay rendering with headless ffmpeg output."""
return make_eval_config(cpt_config, map_dir, control_mode="control_sdc_only", num_maps=num_maps, lambda_value=0.0)
return make_eval_config(cpt_config, map_dir, control_mode="control_sdc_only", num_maps=num_maps)


def select_render_envs(evaluator, policy, env, num_to_render):
Expand All @@ -463,7 +470,7 @@ def select_render_envs(evaluator, policy, env, num_to_render):
List of (env_idx, collision_rate) tuples, sorted by collision rate descending,
truncated to num_to_render.
"""
info_list = evaluator.rollout(policy, env, deterministic=DETERMINISTIC)
info_list = evaluator.rollout(env=env, policy=policy, deterministic=DETERMINISTIC)
populated = [log for log in info_list if log and log.get("n", 0) > 0]
did_collide = np.array([log["collision_rate"] for log in populated])

Expand Down Expand Up @@ -532,7 +539,7 @@ def render_checkpoint_videos(base_config):

# Run a stats rollout for "random" mode to get collision rates
if RENDER_MODE == "random" and not collision_rates:
info_list = evaluator.rollout(policy, env, deterministic=DETERMINISTIC)
info_list = evaluator.rollout(env=env, policy=policy, deterministic=DETERMINISTIC)
for idx in env_indices:
if idx < len(info_list) and info_list[idx]:
collision_rates[idx] = info_list[idx].get("collision_rate", 0.0)
Expand All @@ -546,7 +553,7 @@ def render_checkpoint_videos(base_config):
# Render selected envs
for i, env_idx in enumerate(env_indices):
print(f" Rendering env {env_idx} ({i + 1}/{len(env_indices)})...")
evaluator.rollout(policy, env, render_env_idx=env_idx, deterministic=True)
evaluator.rollout(env=env, policy=policy, deterministic=DETERMINISTIC)
env.driver_env.stop_recorder(env_idx)

# Move mp4s into the checkpoint subdirectory, tagging with collision rate
Expand All @@ -563,16 +570,37 @@ def render_checkpoint_videos(base_config):
print(f"\nAll videos saved to {RENDER_OUTPUT_DIR}/")


def delta_v_summary(group):
"""Severity stats conditional on collision."""
coll = group[group["delta_v_count"] > 0]
n_coll = len(coll)
if n_coll == 0:
return pd.Series(
{
"n_collisions": 0,
"mean_dv_per_event": np.nan,
"max_dv": np.nan,
"frac_under_1mph": np.nan,
}
)
# Mean Delta-V per event: sum across all events / count of all events
total_sum = coll["delta_v_sum"].sum()
total_count = coll["delta_v_count"].sum()
return pd.Series(
{
"n_collisions": int(total_count),
"mean_dv_per_event": total_sum / total_count,
"max_dv": coll["delta_v_max"].max(),
# delta_v_under_1mph is per-agent-with-collision, so this is the right ratio
"frac_under_1mph": coll["delta_v_under_1mph"].mean(),
}
)


def main():
base_config = load_config(ENV_NAME)

all_rows = []
for cpt_path in CHECKPOINTS:
print(f"\n{'=' * 60}")
print(f"Evaluating: {cpt_path}")
print(f"{'=' * 60}")
all_rows.extend(evaluate_checkpoint(cpt_path, base_config))

# ── Scaling analysis ─────────────────────────────────────────────────
all_rows.extend(evaluate_scaling_checkpoints(base_config))

Expand All @@ -586,11 +614,14 @@ def main():
score=("score", "mean"),
collision_rate=("collision_rate", "mean"),
at_fault_collision_rate=("at_fault_collision_rate", "mean"),
rear_collision_rate=("rear_collision_rate", "mean"),
# rear_collision_rate=("rear_collision_rate", "mean"),
offroad_rate=("offroad_rate", "mean"),
)
print(f"\n{summary}")

dv_summary = df.groupby(["checkpoint", "mode"]).apply(delta_v_summary)
print(dv_summary)

# ── Figures ──────────────────────────────────────────────────────────────
if MAKE_FIGURES:
from pufferlib.ocean.benchmark.plot_and_format import make_all_figures
Expand Down
8 changes: 4 additions & 4 deletions examples/analyze_anchors.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
from pufferlib.ocean.drive import binding

# ─── CONFIG ─────────────────────────────────────────────────────────────────
ANCHOR_DIR = "models/anchors_v2"
ANCHOR_DIR = "models/anchors"
VAL_MAP_DIR = "resources/drive/binaries/validation"
VAL_NUM_MAPS = 10_000
VAL_NUM_MAPS = 10000
OUTPUT_CSV = "results/anchor_eval.csv"
DETERMINISTIC = True

Expand Down Expand Up @@ -228,7 +228,7 @@ def evaluate_anchors(anchor_dir: str, out_path: str, val_maps: int = VAL_NUM_MAP

# Compute obs_dim from binding constants (same formula as Drive.__init__)
obs_dim = (
binding.EGO_FEATURES
binding.EGO_FEATURES_DELTA_LOCAL
+ (binding.MAX_AGENTS - 1) * binding.PARTNER_FEATURES
+ binding.MAX_ROAD_SEGMENT_OBSERVATIONS * binding.ROAD_FEATURES
)
Expand Down Expand Up @@ -257,7 +257,7 @@ def evaluate_anchors(anchor_dir: str, out_path: str, val_maps: int = VAL_NUM_MAP
partner_features=binding.PARTNER_FEATURES,
max_road_objects=binding.MAX_ROAD_SEGMENT_OBSERVATIONS,
road_features=binding.ROAD_FEATURES,
ego_dim=binding.EGO_FEATURES,
ego_dim=binding.EGO_FEATURES_DELTA_LOCAL,
hidden_size=512,
output_sizes=output_sizes,
device=str(device),
Expand Down
Loading
Loading