From 9e2fc52925fd2e048e22cf9799db32d4058dbdda Mon Sep 17 00:00:00 2001 From: mpragnay Date: Mon, 27 Apr 2026 20:10:15 -0400 Subject: [PATCH 1/3] init commit --- scripts/export_onnx.py | 261 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 261 insertions(+) create mode 100644 scripts/export_onnx.py diff --git a/scripts/export_onnx.py b/scripts/export_onnx.py new file mode 100644 index 000000000..927315690 --- /dev/null +++ b/scripts/export_onnx.py @@ -0,0 +1,261 @@ +"""Export a trained PufferDrive policy checkpoint (.pt) to ONNX format. + +The exported ONNX model accepts an observation vector (see in torch.py for the exact layout), +plus LSTM hidden states, and produces action logits, +a value estimate, and updated LSTM states. + +Usage: + python -m scripts.export_onnx --checkpoint [--output ] +""" + +import argparse +import os +import torch +import importlib +import numpy as np +import onnxruntime as ort + +import pufferlib.utils +import pufferlib.vector +import pufferlib.models + +from pufferlib.ocean.torch import Drive +from scripts.export_model_bin import load_config + + +class OnnxWrapper(torch.nn.Module): + def __init__(self, policy): + super().__init__() + self.policy = policy + + def forward(self, observation, h, c): + # Reconstruct the state dictionary expected by LSTMWrapper + # state must be mutable as forward_eval updates it + state = {"lstm_h": h, "lstm_c": c} + + # Call forward_eval + logits, value = self.policy.forward_eval(observation, state) + + # Extract updated states + new_h = state["lstm_h"] + new_c = state["lstm_c"] + + return logits, value, new_h, new_c + + +def export_to_onnx(verify=True): + parser = argparse.ArgumentParser(description="Export PufferDrive model to ONNX") + parser.add_argument("--env", type=str, default="puffer_drive", help="Environment name") + parser.add_argument( + "--checkpoint", + type=str, + default="experiments/puffer_drive_73kbtsi5/model_puffer_drive_000200.pt", + help="Path to .pt checkpoint", + ) + 
parser.add_argument("--output", type=str, help="Output .onnx file path") + parser.add_argument("--opset", type=int, default=18, help="ONNX opset version") + + args = parser.parse_args() + + # Load configuration + config = load_config(args.env) + + # Load environment to get observation/action spaces + package = config["base"]["package"] + module_name = "pufferlib.ocean" if package == "ocean" else f"pufferlib.environments.{package}" + env_module = importlib.import_module(module_name) + make_env = env_module.env_creator(args.env) + + # Ensure env args/kwargs are correctly passed + env_kwargs = config["env"] + + vecenv = pufferlib.vector.make(make_env, env_kwargs=env_kwargs, backend=pufferlib.vector.Serial, num_envs=1) + + # Initialize Policy + print("Initializing Policy...") + policy = Drive(vecenv.driver_env, **config["policy"]) + + if config["base"]["rnn_name"]: + print("Wrapping with LSTM...") + policy = pufferlib.models.LSTMWrapper(vecenv.driver_env, policy, **config["rnn"]) + + # Load Checkpoint + print(f"Loading checkpoint from {args.checkpoint}...") + checkpoint = torch.load(args.checkpoint, map_location="cpu") + + # Handle both full checkpoint dict and raw state dict + if isinstance(checkpoint, dict) and "agent_state_dict" in checkpoint: + state_dict = checkpoint["agent_state_dict"] + else: + state_dict = checkpoint + + # Strip compile prefixes + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith("_orig_mod."): + new_state_dict[k[10:]] = v + else: + new_state_dict[k] = v + + policy.load_state_dict(new_state_dict) + policy.eval() + + # Prepare inputs for ONNX export + print("Preparing sample inputs...") + batch_size = 1 + + obs_space = vecenv.single_observation_space + # Flatten observation if needed, Drive policy handles flattening internally usually but check vecenv + # The LSTMWrapper expects (B, ObsDim) + obs_dim = np.prod(obs_space.shape) + + # Create Dummy Observation + if config["base"]["rnn_name"]: + # If wrapped, access the 
internal Drive policy + drive_policy = policy.policy + else: + drive_policy = policy + + if hasattr(drive_policy, "ego_dim"): + # Construct valid dummy observation for Drive policy + # Retrieve needed dimensions + ego_dim = drive_policy.ego_dim + max_partner_objects = drive_policy.max_partner_objects + partner_features = drive_policy.partner_features + max_road_objects = drive_policy.max_road_objects + road_features = drive_policy.road_features + + partner_dim = max_partner_objects * partner_features + road_dim = max_road_objects * road_features + + # Random parts + dummy_ego = torch.randn(batch_size, ego_dim) + dummy_partner = torch.randn(batch_size, partner_dim) + + # Road part: continuous features + categorical feature + road_cont_dim = road_features - 1 + + # (Batch, MaxObjects, Feats-1) + dummy_road_cont = torch.randn(batch_size, max_road_objects, road_cont_dim) + + # (Batch, MaxObjects, 1) - valid categorical values [0, 6] + # Ensure it's 0-6 range. 7 is num_classes. + dummy_road_cat = torch.randint(0, 7, (batch_size, max_road_objects, 1)).float() + + # Concatenate and flatten + dummy_road_objs = torch.cat([dummy_road_cont, dummy_road_cat], dim=2) + dummy_road = dummy_road_objs.view(batch_size, -1) + + dummy_obs = torch.cat([dummy_ego, dummy_partner, dummy_road], dim=1) + else: + print("Warning: Could not determine Drive policy structure. 
Using random observation.") + dummy_obs = torch.randn(batch_size, obs_dim) + + # Dummy LSTM States + hidden_size = config["rnn"]["hidden_size"] + # LSTMCell expects (Batch, Hidden) not (NumLayers, Batch, Hidden) + dummy_h = torch.zeros(batch_size, hidden_size) + dummy_c = torch.zeros(batch_size, hidden_size) + + # Wrap policy for export + onnx_policy = OnnxWrapper(policy) + onnx_policy.eval() + + # Determine output path + if not args.output: + args.output = os.path.splitext(args.checkpoint)[0] + ".onnx" + # Ensure output directory exists + output_dir = os.path.dirname(args.output) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + + print(f"Exporting to {args.output}...") + + # Dynamic axes for batch size flexibility + dynamic_axes = { + "observation": {0: "batch_size"}, + "lstm_h_in": {0: "batch_size"}, + "lstm_c_in": {0: "batch_size"}, + "logits": {0: "batch_size"}, + "value": {0: "batch_size"}, + "lstm_h_out": {0: "batch_size"}, + "lstm_c_out": {0: "batch_size"}, + } + + dummy_inputs = (dummy_obs, dummy_h, dummy_c) + torch.onnx.export( + onnx_policy, + dummy_inputs, + args.output, + export_params=True, + opset_version=args.opset, + do_constant_folding=True, + input_names=["observation", "lstm_h_in", "lstm_c_in"], + output_names=["logits", "value", "lstm_h_out", "lstm_c_out"], + dynamic_axes=dynamic_axes, + ) + + print("Export complete!") + print("\nSample Inputs shapes:") + print(f"Observation: {dummy_obs.shape}") + print(f"LSTM h: {dummy_h.shape}") + print(f"LSTM c: {dummy_c.shape}") + + # Verify ONNX model + if verify: + print("\nVerifying ONNX model...") + sess_options = ort.SessionOptions() + sess_options.intra_op_num_threads = 1 + sess_options.inter_op_num_threads = 1 + ort_session = ort.InferenceSession(args.output, sess_options) + + # PyTorch output + with torch.no_grad(): + torch_logits, torch_value, torch_h, torch_c = onnx_policy(*dummy_inputs) + + # Output .pt files for testing + print(f"Saving test inputs/outputs to {output_dir}") + 
torch.save(dummy_inputs, os.path.join(output_dir, "test_inputs.pt")) + torch.save((torch_logits, torch_value, torch_h, torch_c), os.path.join(output_dir, "test_outputs.pt")) + + # ONNX Runtime output + ort_inputs = {"observation": dummy_obs.numpy(), "lstm_h_in": dummy_h.numpy(), "lstm_c_in": dummy_c.numpy()} + ort_outs = ort_session.run(None, ort_inputs) + + # Compare outputs + def compare(name, torch_out, ort_out, atol=1e-5): + if isinstance(torch_out, tuple): + for i, (t_out, o_out) in enumerate(zip(torch_out, ort_out)): + compare(f"{name}_{i}", t_out, o_out, atol) + return + + try: + np.testing.assert_allclose(torch_out.detach().numpy(), ort_out, rtol=1e-03, atol=atol) + print(f"✔ {name} match") + except AssertionError as e: + print(f"✘ {name} mismatch") + print(e) + + # Unpack ONNX outputs if logits was a tuple + if isinstance(torch_logits, tuple): + num_logits = len(torch_logits) + ort_logits = ort_outs[:num_logits] + ort_value = ort_outs[num_logits] + ort_h = ort_outs[num_logits + 1] + ort_c = ort_outs[num_logits + 2] + else: + ort_logits = ort_outs[0] + ort_value = ort_outs[1] + ort_h = ort_outs[2] + ort_c = ort_outs[3] + + compare("Logits", torch_logits, ort_logits) + compare("Value", torch_value, ort_value) + compare("LSTM h", torch_h, ort_h) + compare("LSTM c", torch_c, ort_c) + + # Export example input and output to .pt files + + +if __name__ == "__main__": + export_to_onnx(verify=True) \ No newline at end of file From bf69f862f22d6f00e7907c6dcf7294393b50d72c Mon Sep 17 00:00:00 2001 From: mpragnay Date: Mon, 27 Apr 2026 23:25:31 -0400 Subject: [PATCH 2/3] Updated onnx files --- scripts/export_onnx.py | 572 ++++++++++++++++++++++++----------------- scripts/verify_onnx.py | 102 ++++++++ 2 files changed, 444 insertions(+), 230 deletions(-) create mode 100644 scripts/verify_onnx.py diff --git a/scripts/export_onnx.py b/scripts/export_onnx.py index 927315690..7f84dc51a 100644 --- a/scripts/export_onnx.py +++ b/scripts/export_onnx.py @@ -1,261 +1,373 @@ 
-"""Export a trained PufferDrive policy checkpoint (.pt) to ONNX format. +"""Export a PufferDrive policy checkpoint (.pt) to ONNX format. -The exported ONNX model accepts an observation vector (see in torch.py for the exact layout), -plus LSTM hidden states, and produces action logits, -a value estimate, and updated LSTM states. +Produces a single .onnx file with external weights (.onnx.data), +plus example_input.pt and example_output.pt for verification. + +Args: + --env Environment name used to look up the ini config (default: puffer_drive) + --checkpoint Path to the .pt checkpoint file + --output_dir Directory where .onnx, .onnx.data, and example .pt files are written Usage: - python -m scripts.export_onnx --checkpoint [--output ] + python -m scripts.export_onnx \ + --env puffer_drive \ + --checkpoint pufferlib/resources/drive/onnx_files/model_puffer_drive_011500.pt \ + --output_dir pufferlib/resources/drive/onnx_files """ import argparse +import ast +import configparser +import glob import os -import torch -import importlib +import types import numpy as np +import torch import onnxruntime as ort -import pufferlib.utils -import pufferlib.vector -import pufferlib.models - -from pufferlib.ocean.torch import Drive -from scripts.export_model_bin import load_config - - -class OnnxWrapper(torch.nn.Module): - def __init__(self, policy): - super().__init__() - self.policy = policy - - def forward(self, observation, h, c): - # Reconstruct the state dictionary expected by LSTMWrapper - # state must be mutable as forward_eval updates it - state = {"lstm_h": h, "lstm_c": c} - - # Call forward_eval - logits, value = self.policy.forward_eval(observation, state) - - # Extract updated states - new_h = state["lstm_h"] - new_c = state["lstm_c"] - - return logits, value, new_h, new_c - - -def export_to_onnx(verify=True): - parser = argparse.ArgumentParser(description="Export PufferDrive model to ONNX") - parser.add_argument("--env", type=str, default="puffer_drive", 
help="Environment name") - parser.add_argument( - "--checkpoint", - type=str, - default="experiments/puffer_drive_73kbtsi5/model_puffer_drive_000200.pt", - help="Path to .pt checkpoint", - ) - parser.add_argument("--output", type=str, help="Output .onnx file path") - parser.add_argument("--opset", type=int, default=18, help="ONNX opset version") - - args = parser.parse_args() - - # Load configuration - config = load_config(args.env) - - # Load environment to get observation/action spaces - package = config["base"]["package"] - module_name = "pufferlib.ocean" if package == "ocean" else f"pufferlib.environments.{package}" - env_module = importlib.import_module(module_name) - make_env = env_module.env_creator(args.env) - - # Ensure env args/kwargs are correctly passed - env_kwargs = config["env"] - - vecenv = pufferlib.vector.make(make_env, env_kwargs=env_kwargs, backend=pufferlib.vector.Serial, num_envs=1) +import gymnasium +import pufferlib - # Initialize Policy - print("Initializing Policy...") - policy = Drive(vecenv.driver_env, **config["policy"]) +from pufferlib.ocean.torch import Drive as DrivePolicy +from pufferlib.ocean.drive import binding - if config["base"]["rnn_name"]: - print("Wrapping with LSTM...") - policy = pufferlib.models.LSTMWrapper(vecenv.driver_env, policy, **config["rnn"]) - # Load Checkpoint - print(f"Loading checkpoint from {args.checkpoint}...") - checkpoint = torch.load(args.checkpoint, map_location="cpu") - - # Handle both full checkpoint dict and raw state dict - if isinstance(checkpoint, dict) and "agent_state_dict" in checkpoint: - state_dict = checkpoint["agent_state_dict"] +def _read_ini(env_name): + """Parse the ini config for env_name and return {section: {key: value}} dicts.""" + puffer_dir = os.path.dirname(os.path.realpath(pufferlib.__file__)) + default_ini = os.path.join(puffer_dir, "config/default.ini") + for path in glob.glob(os.path.join(puffer_dir, "config/**/*.ini"), recursive=True): + p = 
configparser.ConfigParser(inline_comment_prefixes=(";", "#")) + p.read([default_ini, path]) + if env_name in p["base"]["env_name"].split(): + break else: - state_dict = checkpoint - - # Strip compile prefixes - new_state_dict = {} - for k, v in state_dict.items(): - if k.startswith("_orig_mod."): - new_state_dict[k[10:]] = v - else: - new_state_dict[k] = v - - policy.load_state_dict(new_state_dict) - policy.eval() - - # Prepare inputs for ONNX export - print("Preparing sample inputs...") - batch_size = 1 - - obs_space = vecenv.single_observation_space - # Flatten observation if needed, Drive policy handles flattening internally usually but check vecenv - # The LSTMWrapper expects (B, ObsDim) - obs_dim = np.prod(obs_space.shape) - - # Create Dummy Observation - if config["base"]["rnn_name"]: - # If wrapped, access the internal Drive policy - drive_policy = policy.policy + raise ValueError(f"No config found for env_name={env_name!r}") + + def _cast(v): + try: + return ast.literal_eval(v) + except Exception: + return v + + return {section: {k: _cast(v) for k, v in p[section].items()} for section in p.sections()} + + +# --------------------------------------------------------------------------- +# Minimal mock environment — avoids spinning up a real vecenv +# --------------------------------------------------------------------------- + + +def _make_mock_env( + dynamics_model="jerk", + max_lane_segment_observations=80, + max_boundary_segment_observations=80, + max_partner_observations=16, + max_traffic_control_observations=4, + lane_segment_dropout=0.0, + boundary_segment_dropout=0.0, + num_target_waypoints=3, + target_type="static", + reward_conditioning=False, + split_network=True, +): + env = types.SimpleNamespace() + + env.ego_features = { + "classic": binding.EGO_FEATURES_CLASSIC, + "jerk": binding.EGO_FEATURES_JERK, + }[dynamics_model] + + env.max_lane_segment_observations = max_lane_segment_observations + env.max_boundary_segment_observations = 
max_boundary_segment_observations + env.max_partner_observations = max_partner_observations + env.max_traffic_control_observations = max_traffic_control_observations + + def _effective(max_count, dropout): + return int(max_count * (1.0 - min(max(dropout, 0.0), 1.0))) if max_count > 0 else 0 + + env.obs_lane_segment_count = _effective(max_lane_segment_observations, lane_segment_dropout) + env.obs_boundary_segment_count = _effective(max_boundary_segment_observations, boundary_segment_dropout) + + env.partner_features = binding.PARTNER_FEATURES + env.road_features = binding.ROAD_FEATURES + env.traffic_control_features = binding.TRAFFIC_CONTROL_FEATURES + + env.num_reward_coefs = binding.NUM_REWARD_COEFS if reward_conditioning else 0 + target_features = binding.STATIC_TARGET_FEATURES if target_type == "static" else binding.DYNAMIC_TARGET_FEATURES + env.target_dim = num_target_waypoints * target_features + + # Action space: jerk → 4×3 = 12, classic → 7×9 = 63 + if dynamics_model == "jerk": + env.single_action_space = gymnasium.spaces.MultiDiscrete([4 * 3]) else: - drive_policy = policy - - if hasattr(drive_policy, "ego_dim"): - # Construct valid dummy observation for Drive policy - # Retrieve needed dimensions - ego_dim = drive_policy.ego_dim - max_partner_objects = drive_policy.max_partner_objects - partner_features = drive_policy.partner_features - max_road_objects = drive_policy.max_road_objects - road_features = drive_policy.road_features + env.single_action_space = gymnasium.spaces.MultiDiscrete([7 * 9]) + + env.split_network = split_network + + num_obs = ( + env.ego_features + + env.num_reward_coefs + + env.target_dim + + env.max_partner_observations * env.partner_features + + env.obs_lane_segment_count * env.road_features + + env.obs_boundary_segment_count * env.road_features + + env.max_traffic_control_observations * env.traffic_control_features + ) + env.single_observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(num_obs,), dtype=np.float32) + return 
env + + +# --------------------------------------------------------------------------- +# Realistic example input construction +# --------------------------------------------------------------------------- + + +def _make_example_input(env, batch_size=1): + """Build a plausible (non-trivial) observation tensor.""" + rng = np.random.default_rng(42) + parts = [] + + # 1. Ego state — jerk model: [speed, accel, heading_sin, heading_cos, + # yaw_rate, steer, jerk, norm_progress, dist_to_goal] (9 features) + # Use physically reasonable normalised values. + ego = np.zeros((batch_size, env.ego_features), dtype=np.float32) + ego[:, 0] = 0.3 # speed ~30% of max + ego[:, 1] = 0.05 # small accel + ego[:, 2] = 0.5 # heading sin + ego[:, 3] = 0.866 # heading cos (≈30°) + ego[:, 4] = 0.02 # slight yaw rate + if env.ego_features > 5: + ego[:, 5] = 0.01 # steer + if env.ego_features > 6: + ego[:, 6] = 0.0 # jerk + if env.ego_features > 7: + ego[:, 7] = 0.2 # progress along route + if env.ego_features > 8: + ego[:, 8] = 0.6 # dist to goal + parts.append(ego) + + # 2. Conditioning (reward coefs + target waypoints) + conditioning_dim = env.num_reward_coefs + env.target_dim + if conditioning_dim > 0: + # Reward coefs (uniform 1.0 if present), target waypoints as relative positions + cond = np.zeros((batch_size, conditioning_dim), dtype=np.float32) + coef_dim = env.num_reward_coefs + if coef_dim > 0: + cond[:, :coef_dim] = 1.0 # all reward coefs at baseline + # Waypoints: alternating x/y relative goal positions + target_start = coef_dim + target_features_per_wp = env.target_dim // (env.target_dim // max(1, binding.STATIC_TARGET_FEATURES)) + for i in range(env.target_dim // target_features_per_wp): + base = target_start + i * target_features_per_wp + cond[:, base] = 0.3 + i * 0.15 # x progress + cond[:, base + 1] = 0.05 * (i % 2) # slight lateral offset + if target_features_per_wp > 2: + cond[:, base + 2] = 0.0 # heading delta + parts.append(cond) + + # 3. 
Partner observations (some agents ahead and to the side) + partner_dim = env.max_partner_observations * env.partner_features + partners = rng.uniform(-0.2, 0.2, (batch_size, partner_dim)).astype(np.float32) + # First 3 partners are real (non-zero), rest are masked/absent + for i in range(min(3, env.max_partner_observations)): + base = i * env.partner_features + partners[:, base + 0] = 0.1 + i * 0.05 # rel x + partners[:, base + 1] = 0.02 * i # rel y + partners[:, base + 2] = 0.3 # speed + partners[:, base + 3] = 0.5 # heading sin + partners[:, base + 4] = 0.866 # heading cos + parts.append(partners) + + # 4. Lane segment observations (road centerline segments ahead) + lane_dim = env.obs_lane_segment_count * env.road_features + lanes = rng.uniform(-0.1, 0.1, (batch_size, lane_dim)).astype(np.float32) + # Add gradient to simulate a curved road ahead + for i in range(env.obs_lane_segment_count): + base = i * env.road_features + lanes[:, base + 0] = 0.05 * i / env.obs_lane_segment_count # x offset increases + lanes[:, base + 1] = 0.3 * i / env.obs_lane_segment_count # y distance ahead + lanes[:, base + 2] = 0.5 # heading sin + lanes[:, base + 3] = 0.866 # heading cos + parts.append(lanes) + + # 5. Boundary segment observations + boundary_dim = env.obs_boundary_segment_count * env.road_features + boundaries = rng.uniform(-0.1, 0.1, (batch_size, boundary_dim)).astype(np.float32) + for i in range(env.obs_boundary_segment_count): + base = i * env.road_features + boundaries[:, base + 0] = -0.15 + 0.05 * i / env.obs_boundary_segment_count + boundaries[:, base + 1] = 0.3 * i / env.obs_boundary_segment_count + boundaries[:, base + 2] = 0.5 + boundaries[:, base + 3] = 0.866 + parts.append(boundaries) + + # 6. 
Traffic control observations (1 red light ahead, rest absent) + tc_dim = env.max_traffic_control_observations * env.traffic_control_features + tc = np.zeros((batch_size, tc_dim), dtype=np.float32) + # First traffic control: stop sign at distance 0.4, type=0, state=0 (red) + tc[:, 0] = 0.4 # distance + tc[:, 1] = 0.0 # rel x + tc[:, 2] = 0.4 # rel y + tc[:, 3] = 0.5 # heading sin + tc[:, 4] = 0.866 # heading cos + tc[:, 5] = 0.0 # type (int 0 = traffic_light) — will be one-hot in forward + tc[:, 6] = 0.0 # state (int 0 = red) + parts.append(tc) + + obs = np.concatenate(parts, axis=1) + obs = obs.clip(-1.0, 1.0) + return torch.tensor(obs, dtype=torch.float32) + + +# --------------------------------------------------------------------------- +# Main export function +# --------------------------------------------------------------------------- + + +def export(args): + checkpoint_path = args.checkpoint + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + + base = os.path.splitext(os.path.basename(checkpoint_path))[0] + onnx_path = os.path.join(output_dir, base + ".onnx") + + # Load config from ini file + config = _read_ini(args.env) + env_cfg = config["env"] + policy_kwargs = config["policy"] + + # Build mock env from ini env config + env = _make_mock_env( + dynamics_model=env_cfg.get("dynamics_model", "jerk"), + max_lane_segment_observations=env_cfg.get("max_lane_segment_observations", 80), + max_boundary_segment_observations=env_cfg.get("max_boundary_segment_observations", 80), + max_partner_observations=env_cfg.get("max_partner_observations", 16), + max_traffic_control_observations=env_cfg.get("max_traffic_control_observations", 4), + lane_segment_dropout=env_cfg.get("lane_segment_dropout", 0.0), + boundary_segment_dropout=env_cfg.get("boundary_segment_dropout", 0.0), + num_target_waypoints=env_cfg.get("num_target_waypoints", 3), + target_type=env_cfg.get("target_type", "static"), + reward_conditioning=env_cfg.get("reward_conditioning", False), + 
split_network=policy_kwargs.get("split_network", False), + ) + obs_dim = env.single_observation_space.shape[0] + print(f"Observation dim: {obs_dim}") - partner_dim = max_partner_objects * partner_features - road_dim = max_road_objects * road_features + # Build policy from ini policy config + policy = DrivePolicy(env, **policy_kwargs) - # Random parts - dummy_ego = torch.randn(batch_size, ego_dim) - dummy_partner = torch.randn(batch_size, partner_dim) + # Load checkpoint + print(f"Loading checkpoint: {checkpoint_path}") + ckpt = torch.load(checkpoint_path, map_location="cpu") - # Road part: continuous features + categorical feature - road_cont_dim = road_features - 1 + # Strip DDP `module.` prefix + state_dict = {} + for k, v in ckpt.items(): + key = k[len("module.") :] if k.startswith("module.") else k + key = key[len("_orig_mod.") :] if key.startswith("_orig_mod.") else key + state_dict[key] = v - # (Batch, MaxObjects, Feats-1) - dummy_road_cont = torch.randn(batch_size, max_road_objects, road_cont_dim) + policy.load_state_dict(state_dict) + policy.eval() - # (Batch, MaxObjects, 1) - valid categorical values [0, 6] - # Ensure it's 0-6 range. 7 is num_classes. - dummy_road_cat = torch.randint(0, 7, (batch_size, max_road_objects, 1)).float() + # Build example input + example_input = _make_example_input(env, batch_size=1) + print(f"Example input shape: {example_input.shape}") + print(f"Example input (first 12 values): {example_input[0, :12].tolist()}") - # Concatenate and flatten - dummy_road_objs = torch.cat([dummy_road_cont, dummy_road_cat], dim=2) - dummy_road = dummy_road_objs.view(batch_size, -1) + # Run PyTorch forward pass + with torch.no_grad(): + actions, value = policy(example_input) - dummy_obs = torch.cat([dummy_ego, dummy_partner, dummy_road], dim=1) + # actions is a tuple of logit tensors for MultiDiscrete + if isinstance(actions, tuple): + logits = torch.cat(actions, dim=1) else: - print("Warning: Could not determine Drive policy structure. 
Using random observation.") - dummy_obs = torch.randn(batch_size, obs_dim) - - # Dummy LSTM States - hidden_size = config["rnn"]["hidden_size"] - # LSTMCell expects (Batch, Hidden) not (NumLayers, Batch, Hidden) - dummy_h = torch.zeros(batch_size, hidden_size) - dummy_c = torch.zeros(batch_size, hidden_size) - - # Wrap policy for export - onnx_policy = OnnxWrapper(policy) - onnx_policy.eval() - - # Determine output path - if not args.output: - args.output = os.path.splitext(args.checkpoint)[0] + ".onnx" - # Ensure output directory exists - output_dir = os.path.dirname(args.output) - if output_dir: - os.makedirs(output_dir, exist_ok=True) - - print(f"Exporting to {args.output}...") - - # Dynamic axes for batch size flexibility - dynamic_axes = { - "observation": {0: "batch_size"}, - "lstm_h_in": {0: "batch_size"}, - "lstm_c_in": {0: "batch_size"}, - "logits": {0: "batch_size"}, - "value": {0: "batch_size"}, - "lstm_h_out": {0: "batch_size"}, - "lstm_c_out": {0: "batch_size"}, - } - - dummy_inputs = (dummy_obs, dummy_h, dummy_c) + logits = actions + + print(f"PyTorch logits shape: {logits.shape}, value shape: {value.shape}") + print(f"PyTorch logits sample: {logits[0].tolist()}") + print(f"PyTorch value: {value[0].item():.6f}") + + # Save example input/output + input_path = os.path.join(output_dir, "example_input.pt") + output_path = os.path.join(output_dir, "example_output.pt") + torch.save(example_input, input_path) + torch.save({"logits": logits, "value": value}, output_path) + print(f"Saved example_input.pt -> {input_path}") + print(f"Saved example_output.pt -> {output_path}") + + # Export to ONNX with external data (single .onnx.data file) + print(f"\nExporting to ONNX: {onnx_path}") torch.onnx.export( - onnx_policy, - dummy_inputs, - args.output, + policy, + example_input, + onnx_path, export_params=True, - opset_version=args.opset, + opset_version=18, do_constant_folding=True, - input_names=["observation", "lstm_h_in", "lstm_c_in"], - output_names=["logits", 
"value", "lstm_h_out", "lstm_c_out"], - dynamic_axes=dynamic_axes, + input_names=["observation"], + output_names=["logits", "value"], + dynamic_axes={"observation": {0: "batch_size"}, "logits": {0: "batch_size"}, "value": {0: "batch_size"}}, ) - print("Export complete!") - print("\nSample Inputs shapes:") - print(f"Observation: {dummy_obs.shape}") - print(f"LSTM h: {dummy_h.shape}") - print(f"LSTM c: {dummy_c.shape}") - - # Verify ONNX model - if verify: - print("\nVerifying ONNX model...") - sess_options = ort.SessionOptions() - sess_options.intra_op_num_threads = 1 - sess_options.inter_op_num_threads = 1 - ort_session = ort.InferenceSession(args.output, sess_options) - - # PyTorch output - with torch.no_grad(): - torch_logits, torch_value, torch_h, torch_c = onnx_policy(*dummy_inputs) - - # Output .pt files for testing - print(f"Saving test inputs/outputs to {output_dir}") - torch.save(dummy_inputs, os.path.join(output_dir, "test_inputs.pt")) - torch.save((torch_logits, torch_value, torch_h, torch_c), os.path.join(output_dir, "test_outputs.pt")) - - # ONNX Runtime output - ort_inputs = {"observation": dummy_obs.numpy(), "lstm_h_in": dummy_h.numpy(), "lstm_c_in": dummy_c.numpy()} - ort_outs = ort_session.run(None, ort_inputs) - - # Compare outputs - def compare(name, torch_out, ort_out, atol=1e-5): - if isinstance(torch_out, tuple): - for i, (t_out, o_out) in enumerate(zip(torch_out, ort_out)): - compare(f"{name}_{i}", t_out, o_out, atol) - return - - try: - np.testing.assert_allclose(torch_out.detach().numpy(), ort_out, rtol=1e-03, atol=atol) - print(f"✔ {name} match") - except AssertionError as e: - print(f"✘ {name} mismatch") - print(e) - - # Unpack ONNX outputs if logits was a tuple - if isinstance(torch_logits, tuple): - num_logits = len(torch_logits) - ort_logits = ort_outs[:num_logits] - ort_value = ort_outs[num_logits] - ort_h = ort_outs[num_logits + 1] - ort_c = ort_outs[num_logits + 2] - else: - ort_logits = ort_outs[0] - ort_value = ort_outs[1] - ort_h 
= ort_outs[2] - ort_c = ort_outs[3] - - compare("Logits", torch_logits, ort_logits) - compare("Value", torch_value, ort_value) - compare("LSTM h", torch_h, ort_h) - compare("LSTM c", torch_c, ort_c) - - # Export example input and output to .pt files + # Convert inline weights to external data file + import onnx + from onnx.external_data_helper import convert_model_to_external_data + + model_proto = onnx.load(onnx_path) + data_file = base + ".onnx.data" + convert_model_to_external_data( + model_proto, + all_tensors_to_one_file=True, + location=data_file, + size_threshold=0, + convert_attribute=True, + ) + onnx.save_model(model_proto, onnx_path) + print(f"Weights saved to: {os.path.join(output_dir, data_file)}") + + # Quick ONNX runtime verification + # Load with external data merged so ort can find everything + print("\nVerifying ONNX output with onnxruntime...") + import onnx + from onnx.external_data_helper import load_external_data_for_model + + verify_proto = onnx.load(onnx_path, load_external_data=False) + load_external_data_for_model(verify_proto, output_dir) + model_bytes = verify_proto.SerializeToString() + sess = ort.InferenceSession(model_bytes) + ort_outs = sess.run(None, {"observation": example_input.numpy()}) + + ort_logits = ort_outs[0] + ort_value = ort_outs[1] + + np.testing.assert_allclose( + logits.numpy(), ort_logits, rtol=1e-3, atol=1e-4, err_msg="Logits mismatch between PyTorch and ONNX Runtime" + ) + np.testing.assert_allclose( + value.numpy(), ort_value, rtol=1e-3, atol=1e-4, err_msg="Value mismatch between PyTorch and ONNX Runtime" + ) + print(" logits: MATCH") + print(" value: MATCH") + print("\nExport complete.") + print(f" ONNX model: {onnx_path}") + print(f" Weights data: {os.path.join(output_dir, data_file)}") + print(f" Example input: {input_path}") + print(f" Example output:{output_path}") if __name__ == "__main__": - export_to_onnx(verify=True) \ No newline at end of file + parser = argparse.ArgumentParser(description="Export 
def verify(args):
    """Check an exported ONNX model against saved PyTorch reference tensors.

    Loads the example input/output tensors written by export_onnx.py, runs the
    ONNX model under onnxruntime (with any external weight data merged into
    memory first), and compares the logits and value outputs against the
    references within rtol=1e-3 / atol=1e-4.

    Raises:
        SystemExit: with exit code 1 when any output mismatches.
    """
    # Accept either the .onnx or the .onnx.data path; always operate on .onnx.
    model_path = args.onnx
    if model_path.endswith(".onnx.data"):
        model_path = model_path[: -len(".data")]

    print(f"ONNX model: {model_path}")
    print(f"Example input: {args.input}")
    print(f"Example output: {args.output}")

    # Reference tensors saved at export time.
    example_input = torch.load(args.input, map_location="cpu")
    reference = torch.load(args.output, map_location="cpu")
    ref_logits, ref_value = reference["logits"], reference["value"]

    print(f"\nInput shape: {example_input.shape}")
    print(f"Ref logits: {ref_logits.shape} sample: {ref_logits[0].tolist()}")
    print(f"Ref value: {ref_value[0].item():.6f}")

    # Merge external weight data into an in-memory proto so ORT does not have
    # to resolve the .onnx.data path on its own.
    proto = onnx.load(model_path, load_external_data=False)
    load_external_data_for_model(proto, os.path.dirname(os.path.abspath(model_path)))
    session = ort.InferenceSession(proto.SerializeToString())
    ort_outs = session.run(None, {"observation": example_input.numpy()})
    ort_logits, ort_value = ort_outs[0], ort_outs[1]

    print(f"\nONNX logits: {ort_logits.shape} sample: {ort_logits[0].tolist()}")
    # NOTE(review): assumes value has shape (B, 1) — consistent with
    # ref_value[0].item() above, but confirm against the exporter.
    print(f"ONNX value: {ort_value[0, 0]:.6f}")

    atol, rtol = 1e-4, 1e-3
    passed = True

    try:
        np.testing.assert_allclose(ref_logits.numpy(), ort_logits, rtol=rtol, atol=atol)
        print("\n logits: MATCH")
    except AssertionError as err:
        print("\n logits: MISMATCH")
        print(f" {err}")
        passed = False

    try:
        np.testing.assert_allclose(ref_value.numpy(), ort_value, rtol=rtol, atol=atol)
        print(" value: MATCH")
    except AssertionError as err:
        print(" value: MISMATCH")
        print(f" {err}")
        passed = False

    if passed:
        print("\nAll outputs verified successfully.")
    else:
        print("\nVerification FAILED — see mismatches above.")
        raise SystemExit(1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Verify exported ONNX model")
    parser.add_argument(
        "--onnx",
        default="pufferlib/resources/drive/onnx_files/model_puffer_drive_011500.onnx",
    )
    parser.add_argument(
        "--input",
        default="pufferlib/resources/drive/onnx_files/example_input.pt",
    )
    parser.add_argument(
        "--output",
        default="pufferlib/resources/drive/onnx_files/example_output.pt",
    )
    verify(parser.parse_args())
outputs a fake trajectory of (x, y, heading) with the next action sampled by planner and the remain n_step - 1 actions as zero --- scripts/export_onnx.py | 240 ++++++++++++++++++++++++++++++++++++----- scripts/verify_onnx.py | 28 ++++- 2 files changed, 240 insertions(+), 28 deletions(-) diff --git a/scripts/export_onnx.py b/scripts/export_onnx.py index 7f84dc51a..434db81f5 100644 --- a/scripts/export_onnx.py +++ b/scripts/export_onnx.py @@ -4,15 +4,29 @@ plus example_input.pt and example_output.pt for verification. Args: - --env Environment name used to look up the ini config (default: puffer_drive) - --checkpoint Path to the .pt checkpoint file - --output_dir Directory where .onnx, .onnx.data, and example .pt files are written + --env Environment name used to look up the ini config (default: puffer_drive) + --checkpoint Path to the .pt checkpoint file + --output_dir Directory where .onnx, .onnx.data, and example .pt files are written + --fake_trajectory Wrap model to also output a (n_step, 3) x/y/heading trajectory. + Step 0 uses the argmax action; remaining steps use zero jerk. 
+ --traj_steps Number of trajectory rollout steps (default: 80) + --traj_dt Timestep in seconds for the trajectory rollout (default: 0.1) Usage: + # Standard export python -m scripts.export_onnx \ --env puffer_drive \ --checkpoint pufferlib/resources/drive/onnx_files/model_puffer_drive_011500.pt \ --output_dir pufferlib/resources/drive/onnx_files + + # Export with fake trajectory output + python -m scripts.export_onnx \ + --env puffer_drive \ + --checkpoint pufferlib/resources/drive/onnx_files/model_puffer_drive_011500.pt \ + --output_dir pufferlib/resources/drive/onnx_files \ + --fake_trajectory \ + --traj_steps 80 \ + --traj_dt 0.1 """ import argparse @@ -32,6 +46,142 @@ from pufferlib.ocean.drive import binding +JERK_LONG = torch.tensor([-15.0, -4.0, 0.0, 4.0]) +JERK_LAT = torch.tensor([-4.0, 0.0, 4.0]) +NUM_LAT = 3 + + +class JerkDynamicsRollout(torch.nn.Module): + """Roll out the argmax action through jerk bicycle dynamics for N steps. + + Step 0 applies the chosen jerk; steps 1..N-1 apply zero jerk (constant + acceleration coast), producing the "fake trajectory" in ego frame. 
+ """ + + def __init__(self, num_steps=80, dt=0.1): + super().__init__() + self.num_steps = num_steps + self.dt = dt + self.register_buffer("jerk_long_table", JERK_LONG) + self.register_buffer("jerk_lat_table", JERK_LAT) + + def forward(self, logits, speed, a_long, a_lat, steering_angle, wheelbase): + """ + Args: + logits: [B, num_actions] + speed: [B] + a_long: [B] + a_lat: [B] + steering_angle: [B] + wheelbase: [B] + Returns: + trajectory: [B, num_steps, 3] — (x, y, heading) in ego frame + """ + B = logits.shape[0] + action = logits.argmax(dim=-1) + jerk_long_val = self.jerk_long_table[action // NUM_LAT] + jerk_lat_val = self.jerk_lat_table[action % NUM_LAT] + + x = torch.zeros(B, device=logits.device) + y = torch.zeros(B, device=logits.device) + heading = torch.zeros(B, device=logits.device) + v = speed.clone() + al = a_long.clone() + at = a_lat.clone() + steer = steering_angle.clone() + + dt = self.dt + trajectory = torch.zeros(B, self.num_steps, 3, device=logits.device) + + # Apply chosen jerk only on step 0; zero jerk thereafter + jl_schedule = torch.zeros(self.num_steps, B, device=logits.device) + jt_schedule = torch.zeros(self.num_steps, B, device=logits.device) + jl_schedule[0] = jerk_long_val + jt_schedule[0] = jerk_lat_val + + for t in range(self.num_steps): + al_new = al + jl_schedule[t] * dt + at_new = at + jt_schedule[t] * dt + + al_new = torch.clamp(al_new, -5.0, 2.5) + at_new = torch.clamp(at_new, -4.0, 4.0) + + al_new = torch.where(al * al_new < 0, torch.zeros_like(al_new), al_new) + at_new = torch.where(at * at_new < 0, torch.zeros_like(at_new), at_new) + + v_new = v + 0.5 * (al_new + al) * dt + v_new = torch.where(v * v_new < 0, torch.zeros_like(v_new), v_new) + v_new = torch.clamp(v_new, -2.0, 20.0) + + signed_curvature = at_new / torch.clamp(v_new * v_new, min=1e-5) + steer_target = torch.atan(signed_curvature * wheelbase) + delta_steer = torch.clamp(steer_target - steer, -0.6 * dt, 0.6 * dt) + steer_new = torch.clamp(steer + delta_steer, 
class OnnxWrapper(torch.nn.Module):
    """Thin wrapper around the Drive policy for ONNX export.

    Without --fake_trajectory: outputs (logits, value).
    With --fake_trajectory: outputs (logits, value, trajectory),
    where trajectory is [B, traj_steps, 3] (x, y, heading in ego frame).
    """

    def __init__(self, policy, fake_trajectory=False, traj_steps=80, dt=0.1):
        """
        Args:
            policy: Drive policy module returning (logits, value); logits may be
                a single tensor or a tuple/list of per-head logit tensors.
            fake_trajectory: if True, also emit a dynamics-rollout trajectory.
            traj_steps: number of rollout steps for the trajectory output.
            dt: rollout timestep in seconds.
        """
        super().__init__()
        self.policy = policy
        self.fake_trajectory = fake_trajectory
        if fake_trajectory:
            self.rollout = JerkDynamicsRollout(num_steps=traj_steps, dt=dt)

    def forward(self, observation):
        actions, value = self.policy(observation)
        # Multi-head policies return a tuple — or, for some implementations, a
        # list — of per-head logit tensors; concatenate them so the ONNX graph
        # exposes a single flat [B, num_actions] "logits" output.
        if isinstance(actions, (tuple, list)):
            logits = torch.cat(tuple(actions), dim=1)
        else:
            logits = actions

        if self.fake_trajectory:
            # Ego features are the first slice of the observation (jerk model
            # layout — assumed from the exporter's encoding; confirm against
            # the obs builder in torch.py):
            #   obs[:, 3] = signed_speed / MAX_SPEED (MAX_SPEED = 100)
            #   obs[:, 5] = vehicle_length / MAX_VEH_LEN (proxy for wheelbase, ~5.5 m)
            #   obs[:, 7] = steering_angle / pi
            #   obs[:, 8] = a_long (asymmetrically normalised: neg -> /15, pos -> /4)
            #   obs[:, 9] = a_lat / 4.0
            speed = observation[:, 3] * 100.0
            wheelbase = observation[:, 5] * 5.5
            # The observation was normalised by pi, so de-normalise with the
            # full-precision constant (the truncated literal 3.14159265 loses
            # precision on the round trip).
            steering_angle = observation[:, 7] * torch.pi
            a_long_norm = observation[:, 8]
            a_long = torch.where(a_long_norm < 0, a_long_norm * 15.0, a_long_norm * 4.0)
            a_lat = observation[:, 9] * 4.0

            trajectory = self.rollout(logits, speed, a_long, a_lat, steering_angle, wheelbase)
            return logits, value, trajectory

        return logits, value
wrapper with torch.no_grad(): - actions, value = policy(example_input) + torch_outs = wrapped(example_input) - # actions is a tuple of logit tensors for MultiDiscrete - if isinstance(actions, tuple): - logits = torch.cat(actions, dim=1) + if args.fake_trajectory: + logits, value, trajectory = torch_outs + print(f"PyTorch trajectory shape: {trajectory.shape}") else: - logits = actions + logits, value = torch_outs print(f"PyTorch logits shape: {logits.shape}, value shape: {value.shape}") print(f"PyTorch logits sample: {logits[0].tolist()}") @@ -295,22 +454,36 @@ def export(args): input_path = os.path.join(output_dir, "example_input.pt") output_path = os.path.join(output_dir, "example_output.pt") torch.save(example_input, input_path) - torch.save({"logits": logits, "value": value}, output_path) + saved = {"logits": logits, "value": value} + if args.fake_trajectory: + saved["trajectory"] = trajectory + torch.save(saved, output_path) print(f"Saved example_input.pt -> {input_path}") print(f"Saved example_output.pt -> {output_path}") + # Build ONNX export names / dynamic axes + output_names = ["logits", "value"] + dynamic_axes = { + "observation": {0: "batch_size"}, + "logits": {0: "batch_size"}, + "value": {0: "batch_size"}, + } + if args.fake_trajectory: + output_names.append("trajectory") + dynamic_axes["trajectory"] = {0: "batch_size"} + # Export to ONNX with external data (single .onnx.data file) print(f"\nExporting to ONNX: {onnx_path}") torch.onnx.export( - policy, + wrapped, example_input, onnx_path, export_params=True, opset_version=18, do_constant_folding=True, input_names=["observation"], - output_names=["logits", "value"], - dynamic_axes={"observation": {0: "batch_size"}, "logits": {0: "batch_size"}, "value": {0: "batch_size"}}, + output_names=output_names, + dynamic_axes=dynamic_axes, ) # Convert inline weights to external data file @@ -329,29 +502,41 @@ def export(args): onnx.save_model(model_proto, onnx_path) print(f"Weights saved to: 
{os.path.join(output_dir, data_file)}") - # Quick ONNX runtime verification - # Load with external data merged so ort can find everything + # Quick ONNX runtime verification — merge external data into memory first print("\nVerifying ONNX output with onnxruntime...") import onnx from onnx.external_data_helper import load_external_data_for_model verify_proto = onnx.load(onnx_path, load_external_data=False) load_external_data_for_model(verify_proto, output_dir) - model_bytes = verify_proto.SerializeToString() - sess = ort.InferenceSession(model_bytes) + sess = ort.InferenceSession(verify_proto.SerializeToString()) ort_outs = sess.run(None, {"observation": example_input.numpy()}) - ort_logits = ort_outs[0] - ort_value = ort_outs[1] - np.testing.assert_allclose( - logits.numpy(), ort_logits, rtol=1e-3, atol=1e-4, err_msg="Logits mismatch between PyTorch and ONNX Runtime" + logits.numpy(), + ort_outs[0], + rtol=1e-3, + atol=1e-4, + err_msg="Logits mismatch between PyTorch and ONNX Runtime", ) np.testing.assert_allclose( - value.numpy(), ort_value, rtol=1e-3, atol=1e-4, err_msg="Value mismatch between PyTorch and ONNX Runtime" + value.numpy(), + ort_outs[1], + rtol=1e-3, + atol=1e-4, + err_msg="Value mismatch between PyTorch and ONNX Runtime", ) - print(" logits: MATCH") - print(" value: MATCH") + print(" logits: MATCH") + print(" value: MATCH") + if args.fake_trajectory: + np.testing.assert_allclose( + trajectory.numpy(), + ort_outs[2], + rtol=1e-3, + atol=1e-4, + err_msg="Trajectory mismatch between PyTorch and ONNX Runtime", + ) + print(" trajectory: MATCH") print("\nExport complete.") print(f" ONNX model: {onnx_path}") print(f" Weights data: {os.path.join(output_dir, data_file)}") @@ -370,4 +555,11 @@ def export(args): "--output_dir", default="pufferlib/resources/drive/onnx_files", ) + parser.add_argument( + "--fake_trajectory", + action="store_true", + help="Add trajectory output: step 0 uses argmax action, remaining steps use zero jerk", + ) + 
parser.add_argument("--traj_steps", type=int, default=80, help="Number of trajectory rollout steps") + parser.add_argument("--traj_dt", type=float, default=0.1, help="Timestep in seconds for trajectory rollout") export(parser.parse_args()) diff --git a/scripts/verify_onnx.py b/scripts/verify_onnx.py index f932cb620..cac04e8e3 100644 --- a/scripts/verify_onnx.py +++ b/scripts/verify_onnx.py @@ -64,20 +64,40 @@ def verify(args): try: np.testing.assert_allclose(ref_logits.numpy(), ort_logits, rtol=rtol, atol=atol) - print("\n logits: MATCH") + print("\n logits: MATCH") except AssertionError as e: - print("\n logits: MISMATCH") + print("\n logits: MISMATCH") print(f" {e}") passed = False try: np.testing.assert_allclose(ref_value.numpy(), ort_value, rtol=rtol, atol=atol) - print(" value: MATCH") + print(" value: MATCH") except AssertionError as e: - print(" value: MISMATCH") + print(" value: MISMATCH") print(f" {e}") passed = False + if "trajectory" in saved: + ref_traj = saved["trajectory"] + print(f"\nRef trajectory: {ref_traj.shape} first point: {ref_traj[0, 0].tolist()}") + if len(ort_outs) < 3: + print(" trajectory: MISMATCH — ONNX model has no trajectory output") + passed = False + else: + ort_traj = ort_outs[2] + print(f"ONNX trajectory:{ort_traj.shape} first point: {ort_traj[0, 0].tolist()}") + try: + np.testing.assert_allclose(ref_traj.numpy(), ort_traj, rtol=rtol, atol=atol) + print(" trajectory: MATCH") + except AssertionError as e: + print(" trajectory: MISMATCH") + print(f" {e}") + passed = False + else: + if len(ort_outs) > 2: + print(f"\n WARNING: ONNX model has {len(ort_outs)} outputs but example_output.pt has no trajectory key") + if passed: print("\nAll outputs verified successfully.") else: