diff --git a/.gitignore b/.gitignore
index 717cf27b54..6c0dff685c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,9 @@ c_*.c
pufferlib/extensions.c
pufferlib/puffernet.c
+# Trajviz: shaders.c is generated by shaders/build_shaders.sh at build time
+pufferlib/ocean/drive/trajviz/shaders.c
+
# Raylib
raylib_wasm/
diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md
index 5a373e9219..9f91afbad1 100644
--- a/docs/src/SUMMARY.md
+++ b/docs/src/SUMMARY.md
@@ -11,6 +11,7 @@
- [Simulator](simulator.md)
- [Interactive scenario editor](scene-editor.md)
- [Visualizer](visualizer.md)
+- [Trajviz (Vulkan offline renderer)](trajviz.md)
- [Export model to ONNX](export-onnx.md)
# Data
diff --git a/docs/src/trajviz.md b/docs/src/trajviz.md
new file mode 100644
index 0000000000..ad4e1b05c1
--- /dev/null
+++ b/docs/src/trajviz.md
@@ -0,0 +1,436 @@
+# Trajviz: Vulkan Offline Renderer
+
+`trajviz` is a Vulkan-backed offline renderer that turns saved Drive
+trajectories into MP4 videos at high throughput. It runs **headlessly**
+on a single GPU (no X server required) and supports **batched
+multi-episode rendering** so you can amortize the per-frame overhead
+across many episodes in one pass.
+
+It is independent of the existing raylib visualizer (`scripts/build_ocean.sh
+visualize`) — they share no code and can coexist. Trajviz is the path you
+want when you need to render many checkpoint videos quickly, on a cluster
+node, or from a Python script.
+
+Source: `pufferlib/ocean/drive/trajviz/`.
+
+## When to use trajviz
+
+| You want to… | Use |
+|---|---|
+| Replay one scenario interactively, debug a policy live | `visualize` (raylib) |
+| Render N saved-checkpoint videos in one Python call | **trajviz** |
+| Render hundreds of trajectories on a headless cluster node | **trajviz** |
+| Drive your render from a notebook / training-loop callback | **trajviz** |
+
+Trajviz outputs the same two views the live raylib visualizer does:
+
+- **Top-down** (`RenderView.FULL_SIM_STATE`): orthographic full-map view
+- **BEV** (`RenderView.BEV_AGENT_OBS`): agent-centric ~100 m × 178 m
+ window, ego at the center facing up
+
+## Prerequisites
+
+Apt (Ubuntu/Debian):
+
+```bash
+sudo apt install -y libvulkan-dev glslang-tools vulkan-tools spirv-tools ffmpeg
+```
+
+Each package is needed for:
+
+- `libvulkan-dev` — Vulkan headers used at compile time
+- `glslang-tools` — `glslangValidator`, the GLSL → SPIR-V compiler that
+ trajviz invokes when compiling its shaders
+- `vulkan-tools` — `vulkaninfo` for diagnostics (optional)
+- `spirv-tools` — SPIR-V utilities (optional)
+- `ffmpeg` — runtime; trajviz pipes raw RGBA frames to ffmpeg for h264 encoding
+
+You also need a Vulkan-capable GPU and ICD. On NVIDIA, the proprietary
+driver provides this automatically (`/usr/share/vulkan/icd.d/nvidia_icd.json`).
+Verify with:
+
+```bash
+vulkaninfo --summary
+```
+
+You should see your GPU listed with `deviceType = PHYSICAL_DEVICE_TYPE_DISCRETE_GPU`.
+
+## Build
+
+Trajviz is an **opt-in** CPython extension built via `setup.py`. Pass
+`TRAJVIZ=1` to enable it:
+
+```bash
+TRAJVIZ=1 python setup.py build_ext --inplace
+```
+
+This compiles the shaders (via `glslangValidator`), embeds them as SPIR-V
+blobs in a generated `shaders.c`, and builds
+`pufferlib.ocean.drive.trajviz._native` into the source tree.
+
+Without `TRAJVIZ=1`, the trajviz extension is **not** built and the rest
+of pufferlib (including the drive sim) builds normally — so users who
+don't need trajviz aren't forced to install Vulkan.
+
+## Usage
+
+### Python API
+
+```python
+from pufferlib.ocean.drive.trajviz import Renderer
+
+with Renderer(width=1280, height=720) as r:
+ r.render_episode(
+ road_xy=road_xy, # (V, 2) float32, mean-centered
+ road_offsets=road_offsets, # (P+1,) uint32 CSR
+ road_types=road_types, # (P,) uint32 — TVZ_ROAD_* type ids
+ traj_xyh=traj, # (T, A, 3) float32 (x, y, heading)
+ agent_lengths=lengths, # (A,) int32 valid step counts
+ ego_idx=-1, # -1 = first agent with length >= 2
+ fps=30,
+ out_topdown="td.mp4",
+ out_bev="bev.mp4",
+ )
+```
+
+### Batched (multi-episode) API
+
+```python
+with Renderer(width=1280, height=720) as r:
+ r.render_batch([
+ dict(road_xy=..., road_offsets=..., road_types=...,
+ traj_xyh=..., agent_lengths=..., ego_idx=-1,
+ out_topdown="ep0_td.mp4", out_bev="ep0_bev.mp4"),
+ dict(road_xy=..., road_offsets=..., road_types=...,
+ traj_xyh=..., agent_lengths=..., ego_idx=-1,
+ out_topdown="ep1_td.mp4", out_bev="ep1_bev.mp4"),
+ # ... up to 16 episodes per batch
+ ], fps=30)
+```
+
+The Renderer is reusable across batches. Pay the Vulkan startup cost
+(~50 ms) and the BatchRenderer atlas allocation (~20 ms) once for an
+entire run of episodes.
+
+### From a saved trajectories_*.npz
+
+```python
+from pufferlib.ocean.drive.trajviz import render_npz
+
+render_npz(
+ "data/runs/.../trajectories_000010.npz",
+ maps_dir="pufferlib/resources/drive/binaries/training",
+ out_dir="videos/",
+)
+```
+
+### CLI
+
+```bash
+python -m pufferlib.ocean.drive.trajviz \
+ data/runs/foo/trajectories_*.npz \
+ --maps-dir pufferlib/resources/drive/binaries/training \
+ --out videos/
+```
+
+Multiple input files or directories are supported. The Vulkan context is
+created once and reused across all inputs.
+
+### Random-rollout smoke test
+
+A small tool spins up a Drive sim, runs a 90-step random-action episode,
+and renders both views. Useful for verifying that trajviz works end-to-end
+without depending on saved trajectories:
+
+```bash
+python -m pufferlib.ocean.drive.trajviz.tools.random_rollout \
+ --map pufferlib/resources/drive/binaries/map_001.bin \
+ --out-dir /tmp
+```
+
+Outputs `/tmp/random_topdown.mp4` and `/tmp/random_bev.mp4`. Defaults to
+2 controllable agents (matches the typical WOSAC `tracks_to_predict`
+count); use `--num-agents N` to override.
+
+## Performance tuning
+
+On an RTX 4080 (16-core CPU), the current pipeline reaches **~3.7
+episodes per second** at `batch_size ≥ 4` for 90-frame 1280×720 episodes
+with both views. Per-episode breakdown:
+
+- Pure GPU + readback: **~30 ms / episode** (the floor — what trajviz
+ achieves with `TRAJVIZ_NO_WRITE=1`)
+- + ffmpeg encoding (libx264 `-preset veryfast`): **~270 ms / episode**
+
+The encoder is the dominant cost beyond the GPU work; everything below
+is squeezed.
+
+### Bumping kernel pipe limits
+
+Trajviz pipes 3.6 MB raw RGBA frames per view to ffmpeg via Unix pipes.
+Default Linux pipe buffers (64 KB) force many round-trips per frame; the
+trajviz `ffmpeg_pipe_open` automatically calls `fcntl(F_SETPIPE_SZ, ...)`
+to bump them, but the per-process maximum is `/proc/sys/fs/pipe-max-size`
+(default 1 MB on most kernels). Raise it for better throughput:
+
+```bash
+sudo sysctl fs.pipe-max-size=16777216
+```
+
+There is also a per-user *total* page budget,
+`/proc/sys/fs/pipe-user-pages-soft` (default 64 MB shared across all
+your pipes). For batches >= 8 with 16 large pipes, raise this too:
+
+```bash
+sudo sysctl fs.pipe-user-pages-soft=262144 # 1 GB total per user
+```
+
+Both settings revert on reboot. Persist via `/etc/sysctl.d/99-trajviz.conf`
+if you want them permanent.
+
+### Why HOST_CACHED matters (NVIDIA)
+
+The single biggest win in trajviz's throughput came from requesting
+`HOST_CACHED` for the readback buffers (see `vk_batch_renderer.c`).
+NVIDIA's default `HOST_VISIBLE | HOST_COHERENT` memory type is
+write-combined PCIe BAR — fast for the GPU to write to, but
+**~250 MB/s for the CPU to read** because every read is uncached over
+PCIe. With `HOST_CACHED`, reads hit RAM at >5 GB/s. This is a 6-7×
+speedup on its own.
+
+If your device doesn't expose `HOST_CACHED` host-visible memory, trajviz
+falls back to plain `HOST_COHERENT` and prints no warning, so the only
+visible symptom is slower wall-clock per frame.
+
+### Choosing a batch size
+
+| batch_size | latency / batch | per-ep | ep/s |
+|---|---|---|---|
+| 1 | ~345 ms | 345 ms | 2.9 |
+| 2 | ~596 ms | 298 ms | 3.4 |
+| 4 | ~1.1 s | 274 ms | 3.7 |
+| 8 | ~2.1 s | 267 ms | 3.7 |
+
+The curve plateaus at `batch_size = 4` — past that, the CPU encoders
+(N parallel libx264 instances) saturate ~16 cores. Going to
+`batch_size = 16` doesn't help and consumes more pipe memory. Pick the
+smallest size that gives you the throughput you need.
+
+### Choosing an encoder (libx264 vs NVENC)
+
+Trajviz can use either CPU encoding (libx264) or NVIDIA hardware encoding
+(h264_nvenc). The default is **libx264** even on NVIDIA-equipped hosts.
+The choice is controlled by the `TRAJVIZ_ENCODER` env var:
+
+- unset (default) → `libx264 -preset veryfast -crf 20`
+- `TRAJVIZ_ENCODER=nvenc` → `h264_nvenc -preset p4 -tune hq -cq 23`
+
+**Why libx264 is the default even on NVIDIA boxes.** Counter-intuitively,
+NVENC turned out to be the wrong fit for trajviz's "spawn one ffmpeg
+subprocess per output MP4 per render call" architecture. Two reasons:
+
+1. **NVENC session creation is expensive (~100 ms per session).** trajviz
+ spawns 2N ffmpeg processes per `render_batch` call (one per output
+ MP4 file). For short episodes (≤500 frames) the per-session startup
+ cost is a meaningful fraction of the total wall time.
+
+2. **NVIDIA's driver throttles concurrent NVENC sessions per process.**
+ The "consumer-key" cap on simultaneous NVENC sessions was nominally
+ removed in driver 530+, but ffmpeg's `h264_nvenc` wrapper still
+ trips on it (`OpenEncodeSessionEx failed: incompatible client key
+ (21)`) at batch_size ≥ 8 — exactly the throughput regime where you'd
+ most want hardware encoding.
+
+3. **In steady state, libx264 `-preset veryfast` and NVENC `-preset p4`
+ are tied per-frame** at 720p on a modern multi-core CPU (~2.3 ms/frame
+ either way). libx264 is genuinely fast at fast presets, and a 16-core
+ CPU running 16 parallel libx264 instances out-throughputs a single
+ NVENC engine serializing 16 streams.
+
+Empirical results on RTX 4080 + 16-core CPU, measured per-episode wall
+time (1280×720, both views, libx264 vs nvenc, lower is better):
+
+| batch | T=90 frames | T=500 frames | T=1000 frames |
+|-------|----------------|----------------|----------------|
+| 1 | 350 / 790 ms | 1162 / 1540 ms | 2203 / 2284 ms |
+| 4 | 273 / 815 ms | 1139 / 1442 ms | 5157 / 5432 ms |
+
+Format: `libx264_ms / nvenc_ms`. NVENC closes the gap as episodes get
+longer (the startup cost amortizes) but never actually wins on this
+hardware in this architecture.
+
+**The only paths that would unlock NVENC for trajviz** are
+(a) holding **one persistent NVENC session per renderer** by switching
+from "spawn-one-ffmpeg-per-output" to a single long-lived ffmpeg with
+multi-input/multi-output, or (b) **direct integration of the NVENC C API**
+(`libnvidia-encode`) with `VK_KHR_external_memory_fd` to import VkImage
+atlases as CUDA arrays — frames never leave VRAM. Both are larger
+refactors than the current architecture.
+
+If you have a workload that doesn't match the typical trajviz pattern
+(e.g. one very long single-episode render where session startup is
+fully amortized), `TRAJVIZ_ENCODER=nvenc` is a one-line opt-in that
+gets you NVENC encoding via ffmpeg.
+
+### Debugging knobs
+
+The C side honors a few env vars for benchmarking:
+
+- `TRAJVIZ_ENCODER={libx264|nvenc}` — pick the video encoder. See above.
+- `TRAJVIZ_NO_FFMPEG=1` — replace the ffmpeg subprocess with `cat > /dev/null`.
+ Skips encoding cost; useful for measuring "render + readback + pipe write" alone.
+- `TRAJVIZ_NO_WRITE=1` — skip the `write()` to the pipe entirely. The
+ output mp4 will be empty/invalid; useful for measuring the pure
+ Vulkan + readback path.
+- `TRAJVIZ_FFMPEG=/path/to/ffmpeg` — override the ffmpeg binary used.
+
+## Architecture
+
+```
+ .npz / numpy arrays
+ │
+ ┌───────▼───────┐
+ │ __init__.py │ Renderer wrapper, npz loader,
+ └───────┬───────┘ numpy padding, batching shim
+ │
+ ▼
+ ┌───────────────┐
+ │ _native.c │ CPython extension boundary,
+ └───────┬───────┘ numpy → raw pointers, GIL release
+ │
+ ▼
+ ┌───────────────┐
+ │ trajviz.c │ public API: render_episode,
+ └───┬─────────┬─┘ render_episodes_batch
+ │ │
+ ▼ ▼
+ ┌──────────┐ ┌──────────────────┐
+ │vk_renderer│ │vk_batch_renderer │
+ │ (1 ep) │ │ (N eps tiled) │
+ └─────┬────┘ └────────┬─────────┘
+ │ │
+ ├────────────────┘
+ │
+ ▼
+ ┌──────────────────┐ ┌──────────────────┐
+ │ vk_pipeline.c │ │ ffmpeg_pipe.c │
+ │ vk_context.c │ │ + writer thread │
+ │ (Vulkan setup) │ │ (per pipe) │
+ └──────────────────┘ └──────────────────┘
+ │ │
+ ▼ ▼
+ Vulkan 1.3 driver ffmpeg subprocess
+```
+
+Key design points:
+
+- **Tiled atlas for batching**: the batched renderer allocates one large
+ color attachment image per view, sized `tile_w × (batch_size * tile_h)`.
+ Tiles are stacked **vertically** so each tile's bytes are row-contiguous
+ in the host readback buffer — one `write()` per tile per frame, no row
+ stitching.
+- **Threaded writers**: each ffmpeg pipe gets its own background writer
+ thread. The renderer's per-frame "submit all → wait all" loop pays
+ `max(single fwrite)` per frame instead of `sum(fwrites)`, which is the
+ threading win.
+- **Push-constant cameras**: per-frame and per-view MVP matrices are pushed
+ via `vkCmdPushConstants`, no descriptor sets. Each view has its own
+ camera matrix per slot per frame.
+- **LINE_STRIP polylines**: roads are drawn as `VK_PRIMITIVE_TOPOLOGY_LINE_STRIP`
+ with one `vkCmdDraw` per polyline (not per segment), so a 268-polyline
+ Waymo intersection is 268 draw calls per view, not ~2400.
+- **Instanced agent boxes**: the agent vertex shader expands a unit quad
+ by per-instance `(x, y, heading, length, width, color)`. One
+ `vkCmdDrawIndexed` per slot draws all of that slot's agents.
+
+## Known limitations / future work
+
+- **Uniform `num_steps` in batch**: all episodes in a batch share the
+ same length cap. The Python wrapper pads shorter episodes with zeros
+ and uses `agent_lengths` to mark valid steps. Episodes with very
+ different lengths waste GPU work on the trailing zeros.
+
+- **Per-env `world_means`**: each Drive sub-env in a vec computes its
+ own `world_mean` from its own map's geometry, so a `Drive(num_maps=N)`
+ with N different maps has N different centerings. Saved trajectory
+ files (`trajectories_*.npz`) carry both `world_means` (plural,
+ per-env, shape `(num_envs, 3)`) and the legacy `world_mean` (singular,
+ env 0 only, kept for back-compat). `render_npz` prefers the plural
+ key and falls back to the singular one with a warning. If you load
+ an old npz that only has `world_mean`, non-env-0 sub-envs with
+ different maps will have their roads mis-aligned by up to kilometers
+ — re-save with the current pufferl to fix.
+- **No NPC / expert-replay agents**: trajviz only renders the
+ controlled agents from `get_sim_trajectories`. The other 18 vehicles
+ in a typical Waymo scenario (the WOSAC "context" tracks) are not
+ shown. Adding them requires a separate Drive API to expose expert
+ trajectories.
+- **No 3D follow-cam**: the `RenderView.AGENT_PERSP` view from the
+ raylib visualizer (3D car meshes from `.glb`) is not implemented.
+- **CPU-bound by libx264**: the encoder is the wall once batching is
+ amortized. NVENC via the simple `-c:v h264_nvenc` opt-in does **not**
+ win on this hardware (see "Choosing an encoder" above) because trajviz
+ spawns a fresh ffmpeg subprocess per output and pays NVENC's session
+ startup tax every render. Closing the remaining ~12% gap to the
+ pure-GPU ceiling requires either a single long-lived ffmpeg with
+ multi-input/multi-output or direct `libnvidia-encode` integration with
+ `VK_KHR_external_memory_fd`. Both are larger refactors than v1.
+- **batch_size cap of 16**: enforced by `TRAJVIZ_BATCH_MAX` in
+ `trajviz.c`. The atlas image height grows linearly with batch_size,
+ and 16 × 720 = 11520 px is well under Vulkan's 16384 limit. Raising
+ it further requires either a 2-D tile grid layout or multiple atlas
+ passes.
+
+## Troubleshooting
+
+**ImportError on `from pufferlib.ocean.drive.trajviz import Renderer`** —
+the extension wasn't built. Run `TRAJVIZ=1 python setup.py build_ext --inplace`.
+
+**`vulkan/vulkan.h: No such file or directory`** during build —
+`libvulkan-dev` not installed. `sudo apt install libvulkan-dev`.
+
+**`glslangValidator: command not found`** during build — `glslang-tools`
+not installed. `sudo apt install glslang-tools`.
+
+**`no Vulkan-capable physical device found`** at runtime — your driver
+isn't exposing a Vulkan ICD. Check `vulkaninfo --summary`. On a remote
+node, ensure the GPU device files (`/dev/nvidia*`) are accessible to
+your user.
+
+**`ffmpeg topdown write failed at slot N`** — the ffmpeg subprocess
+crashed or was killed. Check the ffmpeg stderr in your terminal output.
+A common cause is the output path containing a single quote (we reject
+those for shell-quoting safety).
+
+**`pipe size 1048576 B < frame size 3686400 B — fwrites may block`** —
+informational warning that the kernel pipe buffer is smaller than one
+frame. Bump it via `sudo sysctl fs.pipe-max-size=16777216` (see
+performance section above).
+
+## Files
+
+```
+pufferlib/ocean/drive/
+├── map_io.py Map .bin parser (extracted from notebook)
+└── trajviz/
+ ├── __init__.py Python Renderer wrapper, render_npz
+ ├── __main__.py CLI entry point
+ ├── _native.c CPython extension shell (numpy unwrap)
+ ├── trajviz.{h,c} Public C API: render_episode, render_episodes_batch
+ ├── vk_context.{h,c} VkInstance, VkDevice, queues, command pool
+ ├── vk_pipeline.{h,c} Graphics pipelines (line + box)
+ ├── vk_renderer.{h,c} Single-episode renderer
+ ├── vk_batch_renderer.{h,c} Batched multi-episode renderer (tiled atlas)
+ ├── vk_math.h Mat4 helpers (header-only)
+ ├── ffmpeg_pipe.{h,c} Pipe to ffmpeg + writer thread
+ ├── shaders.h Externs for embedded SPIR-V blobs
+ ├── shaders.c GENERATED — do not commit
+ ├── shaders/
+ │ ├── polyline.{vert,frag} GLSL source for road polylines
+ │ ├── agent_box.{vert,frag} GLSL source for instanced agent quads
+ │ └── build_shaders.sh Compiles GLSL → embedded shaders.c
+ ├── tests/
+ │ └── test_main.c Standalone C test harness (no Python)
+ └── tools/
+ └── random_rollout.py Random-policy rollout → MP4 smoke test
+```
diff --git a/notebooks/visualize_trajectories.py b/notebooks/visualize_trajectories.py
new file mode 100644
index 0000000000..5af5277c6e
--- /dev/null
+++ b/notebooks/visualize_trajectories.py
@@ -0,0 +1,558 @@
+# ---
+# jupyter:
+# jupytext:
+# text_representation:
+# extension: .py
+# format_name: percent
+# format_version: '1.3'
+# jupytext_version: 1.19.1
+# kernelspec:
+# display_name: Python 3 (ipykernel)
+# language: python
+# name: python3
+# ---
+
+# %% [markdown]
+# # Trajectory Visualization
+#
+# Visualize saved simulation trajectories from training checkpoints.
+
+# %% Configuration
+import numpy as np
+import matplotlib.pyplot as plt
+from pathlib import Path
+
+TRAJ_PATH = "/tmp/traj_latest.npz"
+
+# %% Load data
+data = np.load(TRAJ_PATH, allow_pickle=True)
+
+traj_x = data["traj_x"]
+traj_y = data["traj_y"]
+traj_heading = data["traj_heading"]
+traj_lengths = data["traj_lengths"]
+map_ids = data["map_ids"]
+map_files = data["map_files"]
+rewards = data["rewards"]
+terminals = data["terminals"]
+truncations = data["truncations"]
+actions = data["actions"]
+is_invalid = data["is_invalid_step"]
+
+print(f"Total agents: {len(traj_lengths)}")
+print(f"Agents with data: {np.count_nonzero(traj_lengths)}")
+valid_lengths = traj_lengths[traj_lengths > 0]
+print(f"Mean traj length: {valid_lengths.mean():.1f}")
+print(f"Max traj length: {traj_lengths.max()}")
+print(f"Maps: {map_files}")
+print(f"Map distribution: {np.bincount(map_ids)}")
+
+# %% [markdown]
+# ## Trajectory Length Distribution
+
+# %%
+fig, axes = plt.subplots(1, 2, figsize=(14, 5))
+
+axes[0].hist(valid_lengths, bins=50, edgecolor="black", alpha=0.7)
+axes[0].set_xlabel("Trajectory Length (steps)")
+axes[0].set_ylabel("Count")
+axes[0].set_title(f"Trajectory Length Distribution (mean={valid_lengths.mean():.1f})")
+axes[0].axvline(valid_lengths.mean(), color="red", linestyle="--")
+
+for mid in range(len(map_files)):
+ mask = (map_ids == mid) & (traj_lengths > 0)
+ if mask.sum() > 0:
+ axes[1].hist(traj_lengths[mask], bins=30, alpha=0.5, label=f"Map {mid}: {Path(str(map_files[mid])).stem}")
+axes[1].set_xlabel("Trajectory Length")
+axes[1].set_title("Length Distribution per Map")
+axes[1].legend()
+
+plt.tight_layout()
+plt.show()
+
+# %% [markdown]
+# ## Spatial Trajectories per Map
+
+# %%
+n_maps = len(map_files)
+fig, axes = plt.subplots(1, n_maps, figsize=(7 * n_maps, 7))
+if n_maps == 1:
+ axes = [axes]
+
+for mid in range(n_maps):
+ ax = axes[mid]
+ mask = (map_ids == mid) & (traj_lengths > 5)
+ agent_indices = np.where(mask)[0]
+
+ # Sort by length, show longest trajectories
+ sorted_idx = agent_indices[np.argsort(traj_lengths[agent_indices])[::-1]]
+ n_show = min(50, len(sorted_idx))
+
+ for i in sorted_idx[:n_show]:
+ length = traj_lengths[i]
+ x = traj_x[i, :length]
+ y = traj_y[i, :length]
+ ax.plot(x, y, alpha=0.4, linewidth=0.8)
+ ax.plot(x[0], y[0], "go", markersize=3, alpha=0.5)
+ ax.plot(x[-1], y[-1], "rx", markersize=3, alpha=0.5)
+
+ ax.set_aspect("equal")
+ ax.set_title(f"Map {mid}: {Path(str(map_files[mid])).stem}\n({mask.sum()} agents, showing {n_show})")
+ ax.set_xlabel("x (m)")
+ ax.set_ylabel("y (m)")
+
+plt.tight_layout()
+plt.show()
+
+# %% [markdown]
+# ## Detailed Single-Agent Trajectory
+
+# %%
+best_idx = np.argmax(traj_lengths)
+length = traj_lengths[best_idx]
+print(f"Agent {best_idx}: length={length}, map={map_ids[best_idx]}")
+
+x = traj_x[best_idx, :length]
+y = traj_y[best_idx, :length]
+h = traj_heading[best_idx, :length]
+
+fig, axes = plt.subplots(2, 2, figsize=(14, 12))
+
+# XY trajectory colored by time
+ax = axes[0, 0]
+sc = ax.scatter(x, y, c=np.arange(length), cmap="viridis", s=5, alpha=0.8)
+ax.plot(x[0], y[0], "go", markersize=10, label="start")
+ax.plot(x[-1], y[-1], "rx", markersize=10, label="end")
+ax.set_aspect("equal")
+ax.set_title(f"Agent {best_idx} Trajectory (colored by time)")
+ax.set_xlabel("x (m)")
+ax.set_ylabel("y (m)")
+ax.legend()
+plt.colorbar(sc, ax=ax, label="Step")
+
+# Speed over time
+ax = axes[0, 1]
+if length > 1:
+ dx = np.diff(x)
+ dy = np.diff(y)
+ speed = np.sqrt(dx**2 + dy**2) / 0.1 # dt=0.1
+ ax.plot(speed, alpha=0.8)
+ ax.set_ylabel("Speed (m/s)")
+ ax.set_xlabel("Step")
+ ax.set_title(f"Speed (mean={speed.mean():.1f} m/s)")
+
+# Heading over time
+ax = axes[1, 0]
+ax.plot(np.degrees(h), alpha=0.8)
+ax.set_xlabel("Step")
+ax.set_ylabel("Heading (degrees)")
+ax.set_title("Heading Over Time")
+
+# Yaw rate
+ax = axes[1, 1]
+if length > 1:
+ dh = np.diff(h)
+ dh = (dh + np.pi) % (2 * np.pi) - np.pi
+ ax.plot(np.degrees(dh) / 0.1, alpha=0.8)
+ ax.set_xlabel("Step")
+ ax.set_ylabel("Yaw rate (deg/s)")
+ ax.set_title("Yaw Rate")
+
+plt.tight_layout()
+plt.show()
+
+# %% [markdown]
+# ## Action Distribution
+
+# %%
+flat_actions = actions.reshape(-1)
+flat_valid = is_invalid.reshape(-1) == 0
+valid_actions = flat_actions[flat_valid]
+
+jerk_long = [-15, -4, 0, 4]
+jerk_lat = [-4, 0, 4]
+labels = [f"L{jl}/S{sl}" for jl in jerk_long for sl in jerk_lat]
+
+fig, ax = plt.subplots(1, 1, figsize=(12, 5))
+counts = np.bincount(valid_actions.astype(int), minlength=12)
+ax.bar(range(12), counts / counts.sum())
+ax.set_xticks(range(12))
+ax.set_xticklabels(labels, rotation=45)
+ax.set_ylabel("Frequency")
+ax.set_title("Action Distribution")
+plt.tight_layout()
+plt.show()
+
+# %% [markdown]
+# ## Episode Return vs Length
+
+# %%
+flat_rewards = rewards.reshape(-1)
+flat_terminals = terminals.reshape(-1)
+flat_truncations = truncations.reshape(-1)
+flat_invalid_mask = is_invalid.reshape(-1)
+
+done_mask = (flat_terminals + flat_truncations).clip(max=1)
+valid_mask = flat_invalid_mask == 0
+
+episode_ends = np.where(done_mask > 0)[0]
+episode_starts = np.concatenate([[0], episode_ends[:-1] + 1])
+
+episode_returns = []
+episode_lengths = []
+for start, end in zip(episode_starts, episode_ends):
+ ep_valid = valid_mask[start : end + 1]
+ ep_return = flat_rewards[start : end + 1][ep_valid].sum()
+ ep_length = ep_valid.sum()
+ episode_returns.append(ep_return)
+ episode_lengths.append(ep_length)
+
+episode_returns = np.array(episode_returns)
+episode_lengths = np.array(episode_lengths)
+
+fig, axes = plt.subplots(1, 3, figsize=(18, 5))
+
+axes[0].hist(episode_returns, bins=50, edgecolor="black", alpha=0.7)
+axes[0].set_xlabel("Episode Return")
+axes[0].set_title(f"Returns (mean={episode_returns.mean():.2f})")
+axes[0].axvline(episode_returns.mean(), color="red", linestyle="--")
+
+axes[1].hist(episode_lengths, bins=50, edgecolor="black", alpha=0.7, color="orange")
+axes[1].set_xlabel("Episode Length")
+axes[1].set_title(f"Lengths (mean={episode_lengths.mean():.1f})")
+
+axes[2].scatter(episode_lengths, episode_returns, alpha=0.1, s=5)
+axes[2].set_xlabel("Episode Length")
+axes[2].set_ylabel("Episode Return")
+axes[2].set_title("Return vs Length")
+
+plt.tight_layout()
+plt.show()
+
+print(f"Episodes: {len(episode_returns)}")
+print(f"Mean return: {episode_returns.mean():.2f} +/- {episode_returns.std():.2f}")
+print(f"Mean length: {episode_lengths.mean():.1f}")
+
+# %% [markdown]
+# ## Spawn Position Analysis
+
+# %%
+valid = traj_lengths > 5
+start_x = traj_x[valid, 0]
+start_y = traj_y[valid, 0]
+end_x = np.array([traj_x[i, traj_lengths[i] - 1] for i in np.where(valid)[0]])
+end_y = np.array([traj_y[i, traj_lengths[i] - 1] for i in np.where(valid)[0]])
+
+# Displacement
+dist = np.sqrt((end_x - start_x) ** 2 + (end_y - start_y) ** 2)
+
+# Total path length
+path_lengths = []
+for i in np.where(valid)[0]:
+ length = traj_lengths[i]
+ dx = np.diff(traj_x[i, :length])
+ dy = np.diff(traj_y[i, :length])
+ path_lengths.append(np.sqrt(dx**2 + dy**2).sum())
+path_lengths = np.array(path_lengths)
+
+fig, axes = plt.subplots(1, 2, figsize=(14, 5))
+axes[0].hist(dist, bins=50, edgecolor="black", alpha=0.7, color="green")
+axes[0].set_xlabel("Euclidean Distance Start->End (m)")
+axes[0].set_title(f"Displacement (mean={dist.mean():.1f}m)")
+
+axes[1].hist(path_lengths, bins=50, edgecolor="black", alpha=0.7, color="purple")
+axes[1].set_xlabel("Total Path Length (m)")
+axes[1].set_title(f"Path Length (mean={path_lengths.mean():.1f}m)")
+
+plt.tight_layout()
+plt.show()
+
+# %% [markdown]
+# ## Interactive Map + Trajectory Viewer
+#
+# Load map binary files and overlay agent trajectories. Use the slider to select agents ranked by trajectory length.
+
+# %%
+import struct
+import ipywidgets as widgets
+from IPython.display import display
+
+# Road type constants (from drive.h)
+ROAD_LANE = 4
+ROAD_LINE = 5
+ROAD_EDGE = 6
+DRIVEWAY = 10
+
+
+def load_map_roads(map_path):
+ """Read road elements from a PufferDrive binary map file."""
+ roads = []
+ with open(map_path, "rb") as f:
+ sdc_track_index = struct.unpack("i", f.read(4))[0]
+ num_tracks_to_predict = struct.unpack("i", f.read(4))[0]
+ if num_tracks_to_predict > 0:
+ f.read(num_tracks_to_predict * 4) # skip track indices
+
+ num_objects = struct.unpack("i", f.read(4))[0]
+ num_roads = struct.unpack("i", f.read(4))[0]
+
+ total_entities = num_objects + num_roads
+ for i in range(total_entities):
+ scenario_id = struct.unpack("i", f.read(4))[0]
+ entity_type = struct.unpack("i", f.read(4))[0]
+ entity_id = struct.unpack("i", f.read(4))[0]
+ array_size = struct.unpack("i", f.read(4))[0]
+
+ if i < num_objects:
+ # Agent: skip trajectory arrays + scalar fields
+ # x, y, z, vx, vy, vz (6 float arrays) + heading (float) + valid (int)
+ f.read(array_size * 4 * 6) # 6 float arrays
+ f.read(array_size * 4) # heading (float)
+ f.read(array_size * 4) # valid (int)
+ f.read(4 * 3 + 4 * 3 + 4) # width,length,height + goal xyz + mark_as_expert
+ else:
+ # Road element
+ x = np.frombuffer(f.read(array_size * 4), dtype=np.float32).copy()
+ y = np.frombuffer(f.read(array_size * 4), dtype=np.float32).copy()
+ z = np.frombuffer(f.read(array_size * 4), dtype=np.float32).copy()
+ f.read(4 * 3 + 4 * 3 + 4) # skip scalar fields
+ roads.append({"type": entity_type, "x": x, "y": y, "z": z})
+
+ return roads
+
+
+def mean_center_roads(roads, world_mean):
+ """Subtract world_mean from road coordinates to match simulation frame."""
+ for r in roads:
+ r["x"] = r["x"] - world_mean[0]
+ r["y"] = r["y"] - world_mean[1]
+ if len(world_mean) > 2:
+ r["z"] = r["z"] - world_mean[2]
+ return roads
+
+
+# Get world_mean from trajectory data (exact value from C code)
+world_mean = data.get("world_mean", None)
+if world_mean is not None:
+ print(f"Using world_mean from trajectory data: {world_mean}")
+else:
+ print("WARNING: world_mean not in trajectory data, map alignment may be off")
+
+# Load all maps — resolve relative paths against project root
+PROJECT_ROOT = Path(__file__).resolve().parent.parent if "__file__" in dir() else Path.cwd().parent
+map_roads = {}
+for mid, mf in enumerate(map_files):
+ mf_str = str(mf)
+ mf_path = Path(mf_str)
+ if not mf_path.exists():
+ mf_path = PROJECT_ROOT / mf_str
+ if mf_path.exists():
+ roads = load_map_roads(str(mf_path))
+ if world_mean is not None:
+ mean_center_roads(roads, world_mean)
+ map_roads[mid] = roads
+ print(f"Map {mid} ({mf_path.stem}): {len(roads)} road elements")
+ else:
+ print(f"Map {mid} ({mf_str}): not found at {mf_path}, skipping")
+
+
+# %%
+from matplotlib.patches import FancyArrow
+from matplotlib.animation import FuncAnimation
+from IPython.display import HTML
+
+
+def draw_map_background(ax, mid):
+ """Draw road elements for a given map onto ax."""
+ if mid not in map_roads:
+ return
+ for road in map_roads[mid]:
+ if road["type"] == ROAD_EDGE:
+ ax.plot(road["x"], road["y"], color="gray", linewidth=0.8, alpha=0.6)
+ elif road["type"] == ROAD_LANE:
+ ax.plot(road["x"], road["y"], color="khaki", linewidth=0.5, alpha=0.4)
+ elif road["type"] == ROAD_LINE:
+ ax.plot(road["x"], road["y"], color="white", linewidth=0.3, alpha=0.3)
+
+
+def draw_map_with_trajectory(agent_idx):
+ """Draw map roads and overlay the full trajectory for an agent."""
+ mid = map_ids[agent_idx]
+ length = traj_lengths[agent_idx]
+
+ fig, ax = plt.subplots(1, 1, figsize=(12, 12))
+ draw_map_background(ax, mid)
+
+ if length > 1:
+ x = traj_x[agent_idx, :length]
+ y = traj_y[agent_idx, :length]
+ h = traj_heading[agent_idx, :length]
+
+ sc = ax.scatter(x, y, c=np.arange(length), cmap="plasma", s=15, zorder=5, alpha=0.9)
+ ax.plot(x[0], y[0], "go", markersize=12, zorder=6, label="start")
+ ax.plot(x[-1], y[-1], "rx", markersize=12, zorder=6, label="end")
+
+ arrow_step = max(1, length // 20)
+ for t in range(0, length, arrow_step):
+ dx = np.cos(h[t]) * 2
+ dy = np.sin(h[t]) * 2
+ ax.arrow(x[t], y[t], dx, dy, head_width=0.5, head_length=0.3, fc="cyan", ec="cyan", alpha=0.7, zorder=7)
+
+ plt.colorbar(sc, ax=ax, label="Step", shrink=0.7)
+ pad = 30
+ ax.set_xlim(x.min() - pad, x.max() + pad)
+ ax.set_ylim(y.min() - pad, y.max() + pad)
+
+ ax.set_aspect("equal")
+ ax.set_facecolor("#2a2a2a")
+ ax.set_title(
+ f"Agent {agent_idx} | Map {mid} ({Path(str(map_files[mid])).stem}) | Length: {length} steps", fontsize=13
+ )
+ ax.set_xlabel("x (m)")
+ ax.set_ylabel("y (m)")
+ if length > 1:
+ ax.legend(fontsize=11)
+ plt.tight_layout()
+ plt.show()
+
+
+def make_trajectory_video(agent_idx, follow_agent=True, window_size=60):
+ """Create an animation of the agent's trajectory unrolling on the map."""
+ mid = map_ids[agent_idx]
+ length = traj_lengths[agent_idx]
+ if length < 2:
+ print(f"Agent {agent_idx} has no trajectory data")
+ return None
+
+ x = traj_x[agent_idx, :length]
+ y = traj_y[agent_idx, :length]
+ h = traj_heading[agent_idx, :length]
+
+ fig, ax = plt.subplots(1, 1, figsize=(10, 10))
+ draw_map_background(ax, mid)
+ ax.set_aspect("equal")
+ ax.set_facecolor("#2a2a2a")
+
+ # Trail line (grows over time)
+ (trail_line,) = ax.plot([], [], color="cyan", linewidth=2, alpha=0.6, zorder=4)
+ # Current position marker
+ (car_marker,) = ax.plot([], [], "o", color="lime", markersize=10, zorder=6)
+ # Heading arrow (updated each frame)
+ heading_arrow = None
+ # Start marker
+ ax.plot(x[0], y[0], "s", color="lime", markersize=8, zorder=5, label="start")
+ title = ax.set_title("", fontsize=13)
+
+ if follow_agent:
+ half = window_size / 2
+ else:
+ pad = 30
+ ax.set_xlim(x.min() - pad, x.max() + pad)
+ ax.set_ylim(y.min() - pad, y.max() + pad)
+
+ def init():
+ trail_line.set_data([], [])
+ car_marker.set_data([], [])
+ return trail_line, car_marker
+
+ def animate(frame):
+ nonlocal heading_arrow
+ t = frame
+
+ # Update trail
+ trail_line.set_data(x[: t + 1], y[: t + 1])
+ # Update car position
+ car_marker.set_data([x[t]], [y[t]])
+
+ # Update heading arrow
+ if heading_arrow is not None:
+ heading_arrow.remove()
+ arrow_len = 3
+ dx = np.cos(h[t]) * arrow_len
+ dy = np.sin(h[t]) * arrow_len
+ heading_arrow = ax.arrow(x[t], y[t], dx, dy, head_width=1.0, head_length=0.5, fc="red", ec="red", zorder=7)
+
+ # Camera follow
+ if follow_agent:
+ ax.set_xlim(x[t] - half, x[t] + half)
+ ax.set_ylim(y[t] - half, y[t] + half)
+
+ # Speed from position diff
+ if t > 0:
+ spd = np.sqrt((x[t] - x[t - 1]) ** 2 + (y[t] - y[t - 1]) ** 2) / 0.1
+ else:
+ spd = 0
+ title.set_text(f"Agent {agent_idx} | Step {t}/{length} | Speed: {spd:.1f} m/s")
+
+ return trail_line, car_marker, heading_arrow
+
+ anim = FuncAnimation(fig, animate, init_func=init, frames=length, interval=100, blit=False)
+ plt.close(fig)
+ return anim
+
+
+# Build list of agents with trajectories, sorted by length (longest first)
+agents_with_data = np.where(traj_lengths > 1)[0]
+agents_sorted = agents_with_data[np.argsort(traj_lengths[agents_with_data])[::-1]]
+
+print(f"{len(agents_sorted)} agents with trajectory data")
+print(f"Longest: agent {agents_sorted[0]} with {traj_lengths[agents_sorted[0]]} steps")
+
+# %% [markdown]
+# ### Static view — select agent by rank (0 = longest trajectory)
+
+# %%
+output = widgets.Output()
+
+agent_slider = widgets.IntSlider(
+ value=0,
+ min=0,
+ max=len(agents_sorted) - 1,
+ step=1,
+ description="Rank:",
+ continuous_update=False,
+ layout=widgets.Layout(width="80%"),
+)
+agent_label = widgets.Label(value="")
+
+
+def update_static(change):
+ idx = agents_sorted[change["new"]]
+ agent_label.value = f"Agent {idx} | Map {map_ids[idx]} | Length {traj_lengths[idx]} steps"
+ with output:
+ output.clear_output(wait=True)
+ draw_map_with_trajectory(idx)
+
+
+agent_slider.observe(update_static, names="value")
+display(
+ widgets.VBox(
+ [
+ widgets.HTML("
Static Trajectory View
"),
+ widgets.HBox([agent_slider, agent_label]),
+ output,
+ ]
+ )
+)
+update_static({"new": 0})
+
+# %% [markdown]
+# ### Animated rollout — watch the agent drive
+#
+# Set `AGENT_RANK` below to pick which agent to animate (0 = longest trajectory).
+# Set `FOLLOW = True` to have the camera follow the agent, `False` for fixed view.
+
+# %%
+AGENT_RANK = 800 # change this to pick a different agent
+FOLLOW = True # camera follows agent
+
+agent_idx = agents_sorted[AGENT_RANK]
+print(
+ f"Animating agent {agent_idx} (rank {AGENT_RANK}), length {traj_lengths[agent_idx]} steps, map {map_ids[agent_idx]}"
+)
+
+anim = make_trajectory_video(agent_idx, follow_agent=FOLLOW, window_size=60)
+if anim is not None:
+ display(HTML(anim.to_jshtml()))
+
+# %%
diff --git a/pufferlib/ocean/drive/datatypes.h b/pufferlib/ocean/drive/datatypes.h
index 644571ef43..7726d1cdcd 100644
--- a/pufferlib/ocean/drive/datatypes.h
+++ b/pufferlib/ocean/drive/datatypes.h
@@ -157,6 +157,14 @@ struct Agent {
float *log_height;
int *log_valid;
+ // Per-step simulation trajectory buffers (allocated to episode_length).
+ // Recorded each step of the current episode; used by save_trajectories()
+ // at checkpoint time so we can replay what agents actually did.
+ float *sim_traj_x;
+ float *sim_traj_y;
+ float *sim_traj_z;
+ float *sim_traj_heading;
+
// Simulation state
float sim_x;
float sim_y;
@@ -325,6 +333,10 @@ void free_agent(struct Agent *agent) {
free(agent->log_width);
free(agent->log_height);
free(agent->log_valid);
+ free(agent->sim_traj_x);
+ free(agent->sim_traj_y);
+ free(agent->sim_traj_z);
+ free(agent->sim_traj_heading);
free(agent->route);
free(agent->path);
}
diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h
index 741a9307e3..885ff7bca4 100644
--- a/pufferlib/ocean/drive/drive.h
+++ b/pufferlib/ocean/drive/drive.h
@@ -2295,6 +2295,17 @@ void init(Drive *env) {
set_start_position(env);
init_goal_positions(env);
env->logs = (Log *)calloc(env->active_agent_count, sizeof(Log));
+
+ // Allocate per-step sim trajectory buffers (episode_length floats per agent).
+ // Used by save_trajectories() at checkpoint time for offline rendering. Recorded
+ // in c_step after move_dynamics; retrieved via c_get_sim_trajectories.
+ for (int i = 0; i < env->active_agent_count; i++) {
+ int idx = env->active_agent_indices[i];
+ env->agents[idx].sim_traj_x = (float *)calloc(env->episode_length, sizeof(float));
+ env->agents[idx].sim_traj_y = (float *)calloc(env->episode_length, sizeof(float));
+ env->agents[idx].sim_traj_z = (float *)calloc(env->episode_length, sizeof(float));
+ env->agents[idx].sim_traj_heading = (float *)calloc(env->episode_length, sizeof(float));
+ }
}
void close_client(Client *client);
@@ -2417,6 +2428,27 @@ void c_get_global_ground_truth_trajectories(Drive *env, float *x_out, float *y_o
}
}
+// Copy recorded per-step sim trajectory for all active agents into the output
+// arrays. x_out/y_out/z_out/heading_out are (active_agent_count, ep_len) float
+// buffers written row-major. lengths_out receives the current timestep for
+// each agent (how far the episode has progressed). Slots past `lengths_out[i]`
+// are either zeros (fresh episode) or stale data from a prior episode — the
+// caller should use `lengths_out` to slice down.
+void c_get_sim_trajectories(Drive *env, float *x_out, float *y_out, float *z_out, float *heading_out, int *lengths_out,
+ int ep_len) {
+ for (int i = 0; i < env->active_agent_count; i++) {
+ int idx = env->active_agent_indices[i];
+ Agent *agent = &env->agents[idx];
+ lengths_out[i] = env->timestep;
+ if (agent->sim_traj_x != NULL) {
+ memcpy(&x_out[i * ep_len], agent->sim_traj_x, ep_len * sizeof(float));
+ memcpy(&y_out[i * ep_len], agent->sim_traj_y, ep_len * sizeof(float));
+ memcpy(&z_out[i * ep_len], agent->sim_traj_z, ep_len * sizeof(float));
+ memcpy(&heading_out[i * ep_len], agent->sim_traj_heading, ep_len * sizeof(float));
+ }
+ }
+}
+
void c_get_road_edge_counts(Drive *env, int *num_polylines_out, int *total_points_out) {
int count = 0, points = 0;
for (int i = 0; i < env->num_roads; i++) {
@@ -3251,6 +3283,17 @@ void c_step(Drive *env) {
move_dynamics(env, i, agent_idx);
+ // Record per-step sim trajectory for checkpoint replay. env->timestep was
+ // incremented at the top of c_step, so the state we just moved to corresponds
+ // to step index (timestep - 1) within the current episode.
+ int t = env->timestep - 1;
+ if (t >= 0 && t < env->episode_length && env->agents[agent_idx].sim_traj_x != NULL) {
+ env->agents[agent_idx].sim_traj_x[t] = env->agents[agent_idx].sim_x;
+ env->agents[agent_idx].sim_traj_y[t] = env->agents[agent_idx].sim_y;
+ env->agents[agent_idx].sim_traj_z[t] = env->agents[agent_idx].sim_z;
+ env->agents[agent_idx].sim_traj_heading[t] = env->agents[agent_idx].sim_heading;
+ }
+
// Accumulate distance for avg_distance_per_infraction metric
float speed = sqrtf(env->agents[agent_idx].sim_vx * env->agents[agent_idx].sim_vx +
env->agents[agent_idx].sim_vy * env->agents[agent_idx].sim_vy);
diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py
index 46726a9429..0360df177a 100644
--- a/pufferlib/ocean/drive/drive.py
+++ b/pufferlib/ocean/drive/drive.py
@@ -105,7 +105,15 @@ def __init__(
spawn_length_min=2.0,
spawn_length_max=5.5,
spawn_height=1.5,
+ # Trajectory saving: set by pufferl.py when save_trajectories is enabled.
+ # In multiprocessing mode each worker writes a per-worker npz under this
+ # dir via the notify() mechanism; the driver concatenates them afterward.
+ traj_save_dir=None,
):
+ # Trajectory save state; _worker_idx is assigned by _worker_process in
+ # vector.py so notify() knows which per-worker file to write.
+ self._traj_save_dir = traj_save_dir
+ self._worker_idx = None
# env
self.dt = dt
self.render_mode = render_mode
@@ -442,6 +450,9 @@ def __init__(
self.env_ids.append(env_id)
self.c_envs = binding.vectorize(*self.env_ids)
+ # Cache world_mean once — all sub-envs share the same centering convention.
+ # Used by save_trajectories() to lift sim coordinates back to world frame.
+ self.world_mean = binding.vec_get_world_mean(self.c_envs)
def reset(self, seed=0):
binding.vec_reset(self.c_envs, seed)
@@ -605,6 +616,8 @@ def resample_maps(self):
)
self.env_ids.append(env_id)
self.c_envs = binding.vectorize(*self.env_ids)
+ # Refresh cached world_mean after resample (new maps → new centering)
+ self.world_mean = binding.vec_get_world_mean(self.c_envs)
binding.vec_reset(self.c_envs, seed)
self.truncations[:] = 1
@@ -764,6 +777,75 @@ def set_video_suffix(self, suffix: str, env_id: int = 0):
def close(self):
binding.vec_close(self.c_envs)
+ def get_sim_trajectories(self):
+ """Retrieve the per-step sim trajectories recorded in C for the current episode.
+
+ Returns a dict with numpy arrays:
+ x, y, z, heading: (num_agents, episode_length) float32
+ lengths: (num_agents,) int32 — number of valid steps in the current episode
+
+ Slots past ``lengths[i]`` are either zeros (fresh episode) or stale from a
+ prior episode in the same buffer; callers should slice by ``lengths``.
+ """
+ assert self.episode_length is not None, "episode_length must be set for trajectory recording"
+ ep_len = self.episode_length
+ n = self.num_agents
+ traj = {
+ "x": np.zeros((n, ep_len), dtype=np.float32),
+ "y": np.zeros((n, ep_len), dtype=np.float32),
+ "z": np.zeros((n, ep_len), dtype=np.float32),
+ "heading": np.zeros((n, ep_len), dtype=np.float32),
+ "lengths": np.zeros(n, dtype=np.int32),
+ }
+ binding.vec_get_sim_trajectories(
+ self.c_envs,
+ traj["x"],
+ traj["y"],
+ traj["z"],
+ traj["heading"],
+ traj["lengths"],
+ ep_len,
+ )
+ return traj
+
+ def get_world_means(self) -> np.ndarray:
+ """Per-env world_means as a (num_envs, 3) float32 array.
+
+ Each Drive sub-env in a vec computes its own world_mean in
+ ``set_means()`` from its own map's road + agent points. Different
+ maps therefore have different world_means (potentially many
+ kilometers apart in source-Waymo coordinates), and any code that
+ needs to align per-env trajectories with their source maps must
+ use this per-env array — NOT ``self.world_mean`` (singular),
+ which only carries env 0's value and is kept for back-compat.
+ """
+ out = np.zeros((self.num_envs, 3), dtype=np.float32)
+ binding.vec_get_all_world_means(self.c_envs, out)
+ return out
+
+ def notify(self):
+ """Called via the notify mechanism in pufferlib.vector on every worker.
+
+ Each worker writes its own trajectory npz under ``_traj_save_dir`` with
+ a filename keyed by ``_worker_idx``. The driver later concatenates these
+ worker files in ``PuffeRL.save_trajectories``.
+ """
+ if self._traj_save_dir is None or self._worker_idx is None:
+ return
+ traj = self.get_sim_trajectories()
+ traj["map_ids"] = np.array(self.map_ids, dtype=np.int32)
+ traj["agent_offsets"] = np.array(self.agent_offsets, dtype=np.int32)
+ traj["map_files"] = np.array([str(f) for f in self.map_files])
+ # world_means (plural): per-env, shape (num_envs, 3). The right
+ # array for offline tooling that needs to align each env's
+ # trajectory with its own source map. world_mean (singular,
+ # legacy) is env 0 only and is kept for back-compat with older
+ # consumers.
+ traj["world_means"] = self.get_world_means()
+ traj["world_mean"] = np.array(self.world_mean, dtype=np.float32)
+ path = os.path.join(self._traj_save_dir, f"traj_worker_{self._worker_idx}.npz")
+ np.savez_compressed(path, **traj)
+
def env_log(self, env_idx):
"""Get log statistics for a single environment."""
num_agents = self.agent_offsets[env_idx + 1] - self.agent_offsets[env_idx]
diff --git a/pufferlib/ocean/drive/map_io.py b/pufferlib/ocean/drive/map_io.py
new file mode 100644
index 0000000000..98b9d505f9
--- /dev/null
+++ b/pufferlib/ocean/drive/map_io.py
@@ -0,0 +1,152 @@
+"""map_io.py — read PufferDrive .bin map files into plain numpy arrays.
+
+The binary format is the one written by C-side save_map_binary in
+drive.h. It is shared between the live sim and offline tooling, so this
+parser must stay in sync with the C side. The notebook
+``notebooks/visualize_trajectories.py`` had its own copy of this code;
+extracted here so trajviz, the notebook, and any future tools can share
+one source of truth.
+
+Layout (little-endian, no padding):
+ int32 sdc_track_index
+ int32 num_tracks_to_predict
+ int32 * num_tracks_to_predict (track indices, skipped — not needed
+ for road geometry)
+ int32 num_objects
+ int32 num_roads
+ repeat (num_objects + num_roads) entities:
+ int32 scenario_id
+ int32 entity_type
+ int32 entity_id
+ int32 array_size
+ if entity is an object:
+ float32 * array_size x
+ float32 * array_size y
+ float32 * array_size z
+ float32 * array_size vx
+ float32 * array_size vy
+ float32 * array_size vz
+ float32 * array_size heading
+ int32 * array_size valid
+ float32 * 3 length, width, height
+ float32 * 3 goal x, y, z
+ int32 mark_as_expert
+ else (road element):
+ float32 * array_size x
+ float32 * array_size y
+ float32 * array_size z
+ float32 * 3 tail scalars (we don't need them)
+ float32 * 3 more scalars
+ int32 final scalar
+
+The "tail scalars" on road elements are not used by trajviz; the C sim
+parses them. We just skip them.
+"""
+
+from __future__ import annotations
+
+import struct
+from pathlib import Path
+from typing import List, Sequence
+
+import numpy as np
+
+# Road type ids — copied from drive.h. Mirrored as TVZ_ROAD_* in trajviz.h.
+ROAD_LANE = 4
+ROAD_LINE = 5
+ROAD_EDGE = 6
+ROAD_DRIVEWAY = 10
+
+
class Road:
    """One road polyline in source-map (un-centered) coordinates."""

    __slots__ = ("type", "x", "y", "z")

    def __init__(self, type: int, x: np.ndarray, y: np.ndarray, z: np.ndarray):
        self.type = int(type)
        self.x = x
        self.y = y
        self.z = z

    def __len__(self) -> int:
        # Number of points in the polyline.
        return int(self.x.shape[0])


def load_map_roads(map_path: Path | str) -> List[Road]:
    """Read all road polylines from a PufferDrive .bin map file.

    Returns a list of Road objects in source-map (un-centered) coordinates.
    Use ``mean_center_roads`` to subtract a world_mean if you need them in
    sim frame.

    NOTE(review): the parsing body below was reconstructed from the binary
    layout documented in this module's docstring (the format written by
    save_map_binary in drive.h) after the original lines were garbled in
    extraction. It assumes entities are serialized objects-first, i.e. the
    first ``num_objects`` entities are objects and the remaining
    ``num_roads`` are road elements — confirm against the C writer.
    """
    map_path = Path(map_path)
    roads: List[Road] = []

    with open(map_path, "rb") as f:
        _sdc = struct.unpack("<i", f.read(4))[0]
        num_tracks_to_predict = struct.unpack("<i", f.read(4))[0]
        if num_tracks_to_predict > 0:
            # Track indices — not needed for road geometry.
            f.read(num_tracks_to_predict * 4)

        num_objects = struct.unpack("<i", f.read(4))[0]
        num_roads = struct.unpack("<i", f.read(4))[0]

        for entity_idx in range(num_objects + num_roads):
            _scenario_id, entity_type, _entity_id, array_size = struct.unpack("<4i", f.read(16))
            if entity_idx < num_objects:
                # Object entity: 7 float arrays (x, y, z, vx, vy, vz,
                # heading) + 1 int array (valid), then 3 dims + 3 goal
                # floats + mark_as_expert int. None of it is road
                # geometry, so skip it all in one seek.
                f.seek(array_size * 4 * 8 + 7 * 4, 1)
            else:
                x = np.frombuffer(f.read(array_size * 4), dtype=np.float32)
                y = np.frombuffer(f.read(array_size * 4), dtype=np.float32)
                z = np.frombuffer(f.read(array_size * 4), dtype=np.float32)
                # Tail scalars (3 + 3 floats + 1 int) are parsed by the C
                # sim but unused here; skip them.
                f.seek(7 * 4, 1)
                roads.append(Road(entity_type, x, y, z))

    return roads


def mean_center_roads(roads: Sequence[Road], world_mean: Sequence[float]) -> List[Road]:
    """Return a new list of roads with world_mean subtracted from x/y (and z
    if world_mean has 3 components). The input list is not modified."""
    out: List[Road] = []
    for r in roads:
        nx = r.x - world_mean[0]
        ny = r.y - world_mean[1]
        nz = r.z - world_mean[2] if len(world_mean) > 2 else r.z.copy()
        out.append(Road(r.type, nx, ny, nz))
    return out
+
+
def roads_to_csr(roads: Sequence[Road]) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Pack Road polylines into the CSR layout the trajviz C extension expects.

    Returns:
        road_xy: (N, 2) float32 — concatenated (x, y) of all polylines
        road_offsets: (P+1,) uint32 — start index per polyline
        road_types: (P,) uint32 — TVZ_ROAD_* type id per polyline
    """
    if not roads:
        return (np.zeros((0, 2), dtype=np.float32), np.zeros((1,), dtype=np.uint32), np.zeros((0,), dtype=np.uint32))

    counts = np.fromiter((len(r) for r in roads), dtype=np.uint32, count=len(roads))
    offsets = np.zeros(len(roads) + 1, dtype=np.uint32)
    offsets[1:] = np.cumsum(counts)

    xy = np.empty((int(offsets[-1]), 2), dtype=np.float32)
    for poly, start, stop in zip(roads, offsets[:-1], offsets[1:]):
        lo, hi = int(start), int(stop)
        xy[lo:hi, 0] = poly.x
        xy[lo:hi, 1] = poly.y
    types = np.fromiter((r.type for r in roads), dtype=np.uint32, count=len(roads))
    return xy, offsets, types
diff --git a/pufferlib/ocean/drive/trajviz/__init__.py b/pufferlib/ocean/drive/trajviz/__init__.py
new file mode 100644
index 0000000000..cc74bdf2ec
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/__init__.py
@@ -0,0 +1,407 @@
+"""trajviz — Vulkan offline renderer for saved Drive trajectories.
+
+Public API:
+
+ >>> from pufferlib.ocean.drive.trajviz import Renderer, render_npz
+ >>> with Renderer(width=1280, height=720) as r:
+ ... r.render_episode(
+ ... road_xy=..., road_offsets=..., road_types=...,
+ ... traj_xyh=...,
+ ... agent_lengths=...,
+ ... out_topdown="td.mp4", out_bev="bev.mp4",
+ ... )
+
+Or, more usually, the high-level npz path:
+
+ >>> from pufferlib.ocean.drive.trajviz import render_npz
+ >>> render_npz("trajectories_000010.npz",
+ ... maps_dir="path/to/maps",
+ ... out_dir="videos/")
+
+CLI:
+
+ python -m pufferlib.ocean.drive.trajviz --maps-dir --out
+
+The Vulkan context is created once per Renderer and reused across many
+render_episode calls — pay the ~50 ms init cost once for a whole batch.
+
+If the C extension fails to import (typically because libvulkan is not
+installed at runtime, or the build did not include trajviz), this module
+raises ImportError on first use with a pointer to docs/trajviz.md.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Iterable, Optional
+
+import numpy as np
+
+from pufferlib.ocean.drive import map_io
+
+
+class _NativeUnavailable:
+ """Stand-in raised when the C extension isn't built/available."""
+
+ def __init__(self, exc: Exception):
+ self._exc = exc
+
+ def __getattr__(self, name):
+ raise ImportError(
+ "trajviz._native is not available. Build with TRAJVIZ=1 and "
+ "make sure libvulkan-dev + glslang-tools are installed. "
+ f"Original error: {self._exc}\n"
+ "See docs/trajviz.md for setup."
+ )
+
+
+try:
+ from . import _native # type: ignore
+except ImportError as _e:
+ _native = _NativeUnavailable(_e) # type: ignore
+
+
class Renderer:
    """Vulkan trajectory renderer with a hot context across episodes.

    Use as a context manager to ensure the Vulkan context is closed even
    on exceptions:

        with Renderer(width=1280, height=720) as r:
            for episode in episodes:
                r.render_episode(...)
    """

    def __init__(self, width: int = 1280, height: int = 720):
        # _native.init creates the Vulkan context once; every subsequent
        # render_episode/render_batch call reuses it.
        self._ctx = _native.init(width, height)
        self.width = int(width)
        self.height = int(height)

    def __enter__(self) -> "Renderer":
        return self

    def __exit__(self, *exc) -> None:
        # Release the Vulkan context whether or not the body raised.
        self.close()

    def close(self) -> None:
        # Idempotent: _ctx is nulled after the first close, so calling
        # close() again (or letting a with-block exit after a manual
        # close) is safe.
        if self._ctx is not None:
            _native.close(self._ctx)
            self._ctx = None

    def render_batch(
        self,
        episodes: list,
        *,
        fps: int = 30,
    ) -> None:
        """Render N episodes simultaneously by tiling them in a per-view atlas.

        ``episodes`` is a list of dicts, each with these keys:

            road_xy: (V, 2) float array (mean-centered sim frame)
            road_offsets: (P+1,) int CSR offsets into road_xy
            road_types: (P,) int TVZ_ROAD_* type ids
            traj_xyh: (T, A, 3) float (x, y, heading) per agent per step
            agent_lengths: (A,) int valid step counts (optional, defaults to T)
            ego_idx: int (default -1 = first valid)
            out_topdown: str path or None
            out_bev: str path or None

        All episodes must use the same renderer dimensions (the Renderer's
        ``width`` × ``height``); inside, every episode's tile is exactly
        that size. Episodes with different ``T`` and ``A`` are accepted —
        the wrapper pads them to the batch's max T and A so the C extension
        can use a uniform shape.

        Calling this on the same Renderer with the same ``len(episodes)``
        across calls is much cheaper than recreating: the BatchRenderer's
        atlas + readback buffers are kept hot.

        Raises RuntimeError if the renderer is closed, ValueError if the
        batch exceeds the cap of 16.
        """
        if self._ctx is None:
            raise RuntimeError("Renderer is closed")
        if not episodes:
            return
        batch_size = len(episodes)
        if batch_size > 16:
            raise ValueError(f"batch_size {batch_size} exceeds the v1 cap of 16")

        # Find the batch-wide max T and A so we can pad into a uniform tensor.
        num_steps = max(int(ep["traj_xyh"].shape[0]) for ep in episodes)
        max_agents = max(int(ep["traj_xyh"].shape[1]) for ep in episodes)

        traj = np.zeros((batch_size, num_steps, max_agents, 3), dtype=np.float32)
        agent_lengths = np.zeros((batch_size, max_agents), dtype=np.int32)
        for i, ep in enumerate(episodes):
            t = np.ascontiguousarray(ep["traj_xyh"], dtype=np.float32)
            T, A, _ = t.shape
            traj[i, :T, :A, :] = t
            if "agent_lengths" in ep and ep["agent_lengths"] is not None:
                ll = np.ascontiguousarray(ep["agent_lengths"], dtype=np.int32)
                agent_lengths[i, : len(ll)] = ll
            else:
                # No explicit lengths: every agent is valid for all T steps.
                agent_lengths[i, :A] = T

        # Concatenate ragged road geometry with CSR-style per-episode offsets.
        # The C side splits each per-episode slice out of these flat arrays.
        all_xy_parts = []
        all_off_parts = []
        all_typ_parts = []
        vert_offsets = [0]
        poly_meta_offsets = [0]  # cumulative number of (num_polys+1) entries
        poly_type_offsets = [0]  # cumulative number of polys
        for ep in episodes:
            xy = np.ascontiguousarray(ep["road_xy"], dtype=np.float32)
            offs = np.ascontiguousarray(ep["road_offsets"], dtype=np.uint32)
            typs = np.ascontiguousarray(ep["road_types"], dtype=np.uint32)
            all_xy_parts.append(xy)
            all_off_parts.append(offs)
            all_typ_parts.append(typs)
            vert_offsets.append(vert_offsets[-1] + xy.shape[0])
            poly_meta_offsets.append(poly_meta_offsets[-1] + offs.shape[0])
            poly_type_offsets.append(poly_type_offsets[-1] + typs.shape[0])

        all_road_xy = np.concatenate(all_xy_parts, axis=0) if all_xy_parts else np.zeros((0, 2), np.float32)
        all_road_offsets = np.concatenate(all_off_parts) if all_off_parts else np.zeros((0,), np.uint32)
        all_road_types = np.concatenate(all_typ_parts) if all_typ_parts else np.zeros((0,), np.uint32)
        vert_offsets = np.array(vert_offsets, dtype=np.uint32)
        poly_meta_offsets = np.array(poly_meta_offsets, dtype=np.uint32)
        poly_type_offsets = np.array(poly_type_offsets, dtype=np.uint32)

        ego = np.array([int(ep.get("ego_idx", -1)) for ep in episodes], dtype=np.int32)
        out_td = [ep.get("out_topdown") for ep in episodes]
        out_bev = [ep.get("out_bev") for ep in episodes]

        # Keyword-for-keyword contract with the C extension — do not
        # reorder or rename these without changing _native.c to match.
        _native.render_episodes_batch(
            self._ctx,
            all_road_xy=all_road_xy,
            vert_offsets=vert_offsets,
            all_road_offsets=all_road_offsets,
            poly_meta_offsets=poly_meta_offsets,
            all_road_types=all_road_types,
            poly_type_offsets=poly_type_offsets,
            traj_xyh=traj,
            agent_lengths=agent_lengths,
            ego_idx_per_ep=ego,
            fps=int(fps),
            out_topdown_paths=out_td,
            out_bev_paths=out_bev,
        )

    def render_episode(
        self,
        *,
        road_xy: np.ndarray,
        road_offsets: np.ndarray,
        road_types: np.ndarray,
        traj_xyh: np.ndarray,
        agent_dims: Optional[np.ndarray] = None,
        agent_lengths: Optional[np.ndarray] = None,
        ego_idx: int = -1,
        fps: int = 30,
        out_topdown: Optional[str] = None,
        out_bev: Optional[str] = None,
    ) -> None:
        """Render one episode to one or two MP4 files.

        Either ``out_topdown`` or ``out_bev`` (or both) must be set.

        Raises RuntimeError if the renderer is closed, ValueError if no
        output path was given.
        """
        if self._ctx is None:
            raise RuntimeError("Renderer is closed")
        if out_topdown is None and out_bev is None:
            raise ValueError("must set at least one of out_topdown / out_bev")

        # The C extension is strict about dtypes / contiguity. Coerce here
        # so callers can pass float64 / non-contiguous slices without
        # tripping the validator.
        road_xy = np.ascontiguousarray(road_xy, dtype=np.float32)
        road_offsets = np.ascontiguousarray(road_offsets, dtype=np.uint32)
        road_types = np.ascontiguousarray(road_types, dtype=np.uint32)
        traj_xyh = np.ascontiguousarray(traj_xyh, dtype=np.float32)
        if agent_dims is not None:
            agent_dims = np.ascontiguousarray(agent_dims, dtype=np.float32)
        if agent_lengths is not None:
            agent_lengths = np.ascontiguousarray(agent_lengths, dtype=np.int32)

        _native.render_episode(
            self._ctx,
            road_xy=road_xy,
            road_offsets=road_offsets,
            road_types=road_types,
            traj_xyh=traj_xyh,
            agent_dims=agent_dims,
            agent_lengths=agent_lengths,
            ego_idx=int(ego_idx),
            fps=int(fps),
            out_topdown=out_topdown,
            out_bev=out_bev,
        )
+
+
+# ---------------------------- npz convenience ---------------------------- #
+
+
+def _resolve_map_path(name: str, maps_dir: Path) -> Optional[Path]:
+ """Try a few likely locations for a map file referenced in the npz."""
+ candidates = [Path(name), maps_dir / Path(name).name, maps_dir / name]
+ for c in candidates:
+ if c.exists():
+ return c
+ return None
+
+
def render_npz(
    npz_path: str | Path,
    maps_dir: str | Path,
    out_dir: str | Path,
    *,
    width: int = 1280,
    height: int = 720,
    fps: int = 30,
    views: Iterable[str] = ("topdown", "bev"),
    renderer: Optional[Renderer] = None,
) -> list[Path]:
    """Render every episode in a saved trajectories_*.npz file.

    Each env in the npz becomes one episode → one or two MP4 files in
    ``out_dir`` (named ``{npz_stem}_env{ID}_{view}.mp4``).

    If ``renderer`` is None, a fresh one is created and torn down inside
    the call. Pass an existing Renderer to amortize Vulkan startup over
    many .npz files.

    Returns the list of MP4 paths written. Raises ValueError on missing
    npz keys, malformed world_means/agent_offsets, or an empty ``views``.
    """
    npz_path = Path(npz_path)
    maps_dir = Path(maps_dir)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    views = set(views)
    has_td = "topdown" in views
    has_bev = "bev" in views
    if not (has_td or has_bev):
        raise ValueError("views must include at least one of 'topdown', 'bev'")

    # NOTE(review): allow_pickle is presumably needed for the string
    # map_files array — confirm; drop it if the saved arrays are all plain.
    data = np.load(npz_path, allow_pickle=True)
    required = (
        "traj_x",
        "traj_y",
        "traj_heading",
        "traj_lengths",
        "agent_offsets",
        "map_ids",
        "map_files",
    )
    missing = [k for k in required if k not in data.files]
    if missing:
        raise ValueError(f"{npz_path} missing keys {missing}. Has: {sorted(data.files)}")

    traj_x = data["traj_x"]
    traj_y = data["traj_y"]
    traj_heading = data["traj_heading"]
    traj_lengths = np.asarray(data["traj_lengths"], dtype=np.int32)
    agent_offsets = np.asarray(data["agent_offsets"], dtype=np.int64)
    map_ids = np.asarray(data["map_ids"], dtype=np.int64)
    map_files = [str(m) for m in np.asarray(data["map_files"]).tolist()]

    num_envs = len(map_ids)
    # world_means (plural, per-env, shape (num_envs, 3)) is the right key —
    # each Drive sub-env in a vec has its own world_mean computed from its
    # own map's geometry, so different maps have different centerings. The
    # legacy single world_mean (env 0 only) leads to mis-aligned roads for
    # any env_id != 0 with a different map. Prefer the new key; fall back
    # to the legacy one with a warning so older saved npz files still
    # render (incorrectly for non-env-0, but at least they render).
    if "world_means" in data.files:
        world_means = np.asarray(data["world_means"], dtype=np.float32)
        if world_means.shape != (num_envs, 3):
            raise ValueError(f"world_means has shape {world_means.shape}, expected ({num_envs}, 3)")
    elif "world_mean" in data.files:
        legacy = np.asarray(data["world_mean"], dtype=np.float32)
        if num_envs > 1:
            print(
                f" WARNING: {npz_path.name} has only the legacy single "
                f"world_mean key (env 0). Roads for non-env-0 trajectories "
                f"with different maps will be mis-aligned. Re-save with the "
                f"current pufferl to get per-env world_means."
            )
        world_means = np.broadcast_to(legacy[None, :], (num_envs, 3)).copy()
    else:
        raise ValueError(f"{npz_path} has neither world_means nor world_mean")

    if len(agent_offsets) == num_envs:
        # Some saved npz omit the trailing offset; reconstruct from the
        # total agent count.
        agent_offsets = np.concatenate([agent_offsets, [traj_x.shape[0]]])
    elif len(agent_offsets) != num_envs + 1:
        raise ValueError(f"agent_offsets length {len(agent_offsets)} doesn't match num_envs {num_envs}")

    own_renderer = renderer is None
    if own_renderer:
        renderer = Renderer(width=width, height=height)

    out_paths: list[Path] = []
    # Cache key includes the per-env world_mean tuple, not just the
    # map_id, so that if two sub-envs ever shared a map_id but used
    # different world_means (edge case under heterogeneous init_modes)
    # we don't return the wrong centering.
    map_cache: dict[tuple, tuple] = {}

    try:
        for env_id in range(num_envs):
            a0, a1 = int(agent_offsets[env_id]), int(agent_offsets[env_id + 1])
            if a1 <= a0:
                # Env contributed no agents — nothing to render.
                continue

            mid = int(map_ids[env_id])
            wm_env = world_means[env_id]
            cache_key = (mid, float(wm_env[0]), float(wm_env[1]), float(wm_env[2]))
            if cache_key not in map_cache:
                mp = _resolve_map_path(map_files[mid], maps_dir)
                if mp is None:
                    # Map file missing is non-fatal: skip just this env.
                    print(f" env {env_id}: map {map_files[mid]} not found, skipping")
                    continue
                roads_raw = map_io.load_map_roads(mp)
                roads = map_io.mean_center_roads(roads_raw, wm_env)
                map_cache[cache_key] = map_io.roads_to_csr(roads)
            road_xy, road_offsets, road_types = map_cache[cache_key]

            # Slice trajectories for this env and stack into (T, A, 3)
            ex = traj_x[a0:a1]
            ey = traj_y[a0:a1]
            eh = traj_heading[a0:a1]
            elen = traj_lengths[a0:a1]
            num_agents, num_steps = ex.shape
            traj_xyh = np.empty((num_steps, num_agents, 3), dtype=np.float32)
            traj_xyh[..., 0] = ex.T
            traj_xyh[..., 1] = ey.T
            traj_xyh[..., 2] = eh.T

            scenario = f"{npz_path.stem}_env{env_id:03d}"
            td_path = (out_dir / f"{scenario}_topdown.mp4") if has_td else None
            bev_path = (out_dir / f"{scenario}_bev.mp4") if has_bev else None

            renderer.render_episode(
                road_xy=road_xy,
                road_offsets=road_offsets,
                road_types=road_types,
                traj_xyh=traj_xyh,
                agent_lengths=elen,
                ego_idx=-1,
                fps=fps,
                out_topdown=str(td_path) if td_path else None,
                out_bev=str(bev_path) if bev_path else None,
            )

            if td_path:
                out_paths.append(td_path)
            if bev_path:
                out_paths.append(bev_path)
            print(f" env {env_id}: {num_agents} agents, {num_steps} steps")

    finally:
        # Only tear down a renderer we created ourselves; a caller-supplied
        # renderer stays open for reuse across npz files.
        if own_renderer:
            renderer.close()

    return out_paths
+
+
+__all__ = ["Renderer", "render_npz"]
diff --git a/pufferlib/ocean/drive/trajviz/__main__.py b/pufferlib/ocean/drive/trajviz/__main__.py
new file mode 100644
index 0000000000..15dd66a69e
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/__main__.py
@@ -0,0 +1,73 @@
+"""CLI entry point for the trajviz Vulkan renderer.
+
+Usage:
+ python -m pufferlib.ocean.drive.trajviz INPUT [INPUT...] \\
+ --maps-dir DIR --out DIR \\
+ [--width 1280] [--height 720] [--fps 30] \\
+ [--views topdown,bev]
+
+INPUT can be a saved trajectories_*.npz file or a directory to glob
+recursively. One Vulkan context is created up front and reused for every
+episode in every input file — pay the GPU init cost once for an entire
+batch.
+"""
+
+from __future__ import annotations
+
+import argparse
+from pathlib import Path
+
+from pufferlib.ocean.drive.trajviz import Renderer, render_npz
+
+
def main() -> None:
    """Parse CLI args, collect npz inputs, and render them all with one shared Renderer."""
    parser = argparse.ArgumentParser(
        prog="python -m pufferlib.ocean.drive.trajviz",
        description="Vulkan offline renderer for saved Drive trajectories.",
    )
    parser.add_argument("inputs", nargs="+", type=Path, help="trajectories_*.npz files (or directories to glob).")
    parser.add_argument(
        "--maps-dir", type=Path, required=True, help="Directory containing the .bin map files referenced in the npz."
    )
    parser.add_argument("--out", type=Path, required=True, help="Output directory for MP4 files.")
    parser.add_argument("--width", type=int, default=1280)
    parser.add_argument("--height", type=int, default=720)
    parser.add_argument("--fps", type=int, default=30)
    parser.add_argument("--views", default="topdown,bev", help="Comma-separated subset of {topdown, bev}.")
    args = parser.parse_args()

    # Expand directory inputs by recursive glob; keep file inputs as-is.
    npz_files: list[Path] = []
    for item in args.inputs:
        if item.is_dir():
            npz_files.extend(sorted(item.rglob("trajectories_*.npz")))
        else:
            npz_files.append(item)
    if not npz_files:
        raise SystemExit("No trajectories_*.npz files found.")

    selected_views = tuple(v.strip() for v in args.views.split(",") if v.strip())
    args.out.mkdir(parents=True, exist_ok=True)

    written = 0
    # One Vulkan context for the whole batch — pay the GPU init cost once.
    with Renderer(width=args.width, height=args.height) as renderer:
        for npz in npz_files:
            print(f"[{npz}]")
            written += len(
                render_npz(
                    npz,
                    args.maps_dir,
                    args.out,
                    width=args.width,
                    height=args.height,
                    fps=args.fps,
                    views=selected_views,
                    renderer=renderer,
                )
            )

    print(f"Wrote {written} MP4 files to {args.out}")


if __name__ == "__main__":
    main()
diff --git a/pufferlib/ocean/drive/trajviz/_native.c b/pufferlib/ocean/drive/trajviz/_native.c
new file mode 100644
index 0000000000..11bd46f18c
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/_native.c
@@ -0,0 +1,497 @@
+/*
+ * _native.c — CPython extension shell for the trajviz Vulkan renderer.
+ *
+ * Exposes three Python functions:
+ *
+ * init(width: int, height: int) -> capsule
+ * Creates a TrajvizCtx and returns it wrapped in a PyCapsule with
+ * a destructor that calls trajviz_close. The capsule can be passed
+ * to render_episode any number of times.
+ *
+ * render_episode(ctx, road_xy, road_offsets, road_types, traj_xyh,
+ * agent_dims, agent_lengths, ego_idx, fps,
+ * out_topdown, out_bev) -> int
+ * Validates the numpy arrays, releases the GIL, calls
+ * trajviz_render_episode, reacquires the GIL. Raises RuntimeError
+ * on a non-zero return code with the underlying error message.
+ *
+ * close(ctx) -> None
+ * Manually destroy the ctx. Optional — the capsule destructor
+ * will also do it on garbage collection.
+ *
+ * Numpy arrays are validated for dtype, ndim, and contiguity. Each is
+ * coerced to its expected dtype and C-contiguous layout via
+ * PyArray_FROMANY (which is a no-op for already-conforming arrays). The
+ * resulting reference is held for the duration of the call so the data
+ * pointer stays valid.
+ *
+ * The GIL release pattern: numpy unwrapping happens with the GIL held
+ * (necessary), then we release for the trajviz_render_episode call,
+ * which is the long-running part. Other Python threads can run during
+ * the GPU + ffmpeg work.
+ */
+
+#define PY_SSIZE_T_CLEAN
+#include <Python.h>
+
+#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
+#include <numpy/arrayobject.h>
+
+#include "trajviz.h"
+
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+
+/* PyCapsule name — used by PyCapsule_GetPointer to validate the type. */
+static const char CAPSULE_NAME[] = "trajviz._native.TrajvizCtx";
+
+/* We track "already closed" state in the capsule's context slot (not in
+ * the wrapped pointer, because PyCapsule_SetPointer rejects NULL and
+ * because the pointer points at freed memory after close). A non-NULL
+ * context = closed, NULL context = open. */
+#define CLOSED_SENTINEL ((void *)(uintptr_t)1)
+
+static void capsule_destructor(PyObject *capsule) {
+ if (PyCapsule_GetContext(capsule) == CLOSED_SENTINEL)
+ return;
+ TrajvizCtx *ctx = (TrajvizCtx *)PyCapsule_GetPointer(capsule, CAPSULE_NAME);
+ if (ctx)
+ trajviz_close(ctx);
+}
+
+/* ---------------------------- init / close ---------------------------- */
+
+static PyObject *py_init(PyObject *self, PyObject *args, PyObject *kwargs) {
+ (void)self;
+ static char *kwlist[] = {"width", "height", NULL};
+ int width = 0, height = 0;
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "ii:init", kwlist, &width, &height)) {
+ return NULL;
+ }
+ TrajvizCtx *ctx = trajviz_init(width, height);
+ if (!ctx) {
+ PyErr_Format(PyExc_RuntimeError, "trajviz_init failed: %s", trajviz_last_error(NULL));
+ return NULL;
+ }
+ PyObject *capsule = PyCapsule_New(ctx, CAPSULE_NAME, capsule_destructor);
+ if (!capsule) {
+ trajviz_close(ctx);
+ return NULL;
+ }
+ return capsule;
+}
+
+static PyObject *py_close(PyObject *self, PyObject *args) {
+ (void)self;
+ PyObject *capsule = NULL;
+ if (!PyArg_ParseTuple(args, "O:close", &capsule))
+ return NULL;
+ if (!PyCapsule_CheckExact(capsule)) {
+ PyErr_SetString(PyExc_TypeError, "expected a TrajvizCtx capsule");
+ return NULL;
+ }
+ if (PyCapsule_GetContext(capsule) == CLOSED_SENTINEL) {
+ Py_RETURN_NONE; /* already closed */
+ }
+ TrajvizCtx *ctx = (TrajvizCtx *)PyCapsule_GetPointer(capsule, CAPSULE_NAME);
+ if (!ctx)
+ Py_RETURN_NONE;
+ trajviz_close(ctx);
+ PyCapsule_SetContext(capsule, CLOSED_SENTINEL);
+ Py_RETURN_NONE;
+}
+
+/* ----------------------- numpy validation helpers ----------------------- */
+
+/* Coerce a Python object to a contiguous numpy array of the given dtype
+ * and exact ndim. Returns a NEW reference (caller must DECREF). On
+ * failure raises a Python exception and returns NULL.
+ *
+ * If allow_none is non-zero and obj is Py_None, returns NULL with no
+ * exception set (caller checks).
+ */
+static PyArrayObject *as_array(PyObject *obj, int dtype, int ndim, int allow_none, const char *name) {
+ if (allow_none && obj == Py_None)
+ return NULL;
+ PyArrayObject *arr = (PyArrayObject *)PyArray_FROMANY(obj, dtype, ndim, ndim, NPY_ARRAY_C_CONTIGUOUS);
+ if (!arr) {
+ if (PyErr_Occurred()) {
+ /* PyArray_FROMANY already set a reasonable error message; we
+ * just prepend the argument name for clarity. */
+ PyObject *type, *value, *tb;
+ PyErr_Fetch(&type, &value, &tb);
+ PyErr_Format(PyExc_TypeError, "%s: %s", name,
+ value ? PyUnicode_AsUTF8(PyObject_Str(value)) : "type/shape mismatch");
+ Py_XDECREF(type);
+ Py_XDECREF(value);
+ Py_XDECREF(tb);
+ }
+ return NULL;
+ }
+ return arr;
+}
+
+/* --------------------------- render_episode --------------------------- */
+
+static PyObject *py_render_episode(PyObject *self, PyObject *args, PyObject *kwargs) {
+ (void)self;
+ static char *kwlist[] = {"ctx", "road_xy", "road_offsets", "road_types", "traj_xyh", "agent_dims",
+ "agent_lengths", "ego_idx", "fps", "out_topdown", "out_bev", NULL};
+
+ PyObject *capsule = NULL;
+ PyObject *o_road_xy, *o_road_off, *o_road_types;
+ PyObject *o_traj, *o_dims, *o_lens;
+ int ego_idx = -1;
+ int fps = 30;
+ const char *out_topdown = NULL;
+ const char *out_bev = NULL;
+
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OOOOOOOiizz:render_episode", kwlist, &capsule, &o_road_xy,
+ &o_road_off, &o_road_types, &o_traj, &o_dims, &o_lens, &ego_idx, &fps,
+ &out_topdown, &out_bev)) {
+ return NULL;
+ }
+
+ if (!PyCapsule_CheckExact(capsule)) {
+ PyErr_SetString(PyExc_TypeError, "ctx: expected a TrajvizCtx capsule");
+ return NULL;
+ }
+ if (PyCapsule_GetContext(capsule) == CLOSED_SENTINEL) {
+ PyErr_SetString(PyExc_RuntimeError, "ctx has been closed");
+ return NULL;
+ }
+ TrajvizCtx *ctx = (TrajvizCtx *)PyCapsule_GetPointer(capsule, CAPSULE_NAME);
+ if (!ctx) {
+ PyErr_SetString(PyExc_RuntimeError, "ctx is null");
+ return NULL;
+ }
+
+ PyArrayObject *a_xy = as_array(o_road_xy, NPY_FLOAT32, 2, 0, "road_xy");
+ PyArrayObject *a_off = as_array(o_road_off, NPY_UINT32, 1, 0, "road_offsets");
+ PyArrayObject *a_typ = as_array(o_road_types, NPY_UINT32, 1, 0, "road_types");
+ PyArrayObject *a_traj = as_array(o_traj, NPY_FLOAT32, 3, 0, "traj_xyh");
+ PyArrayObject *a_dims = as_array(o_dims, NPY_FLOAT32, 2, 1, "agent_dims");
+ PyArrayObject *a_lens = as_array(o_lens, NPY_INT32, 1, 1, "agent_lengths");
+
+ if (!a_xy || !a_off || !a_typ || !a_traj)
+ goto fail;
+
+ /* Shape checks. */
+ if (PyArray_DIM(a_xy, 1) != 2) {
+ PyErr_SetString(PyExc_ValueError, "road_xy must have shape (N, 2)");
+ goto fail;
+ }
+ if (PyArray_DIM(a_traj, 2) != 3) {
+ PyErr_SetString(PyExc_ValueError, "traj_xyh must have shape (T, A, 3)");
+ goto fail;
+ }
+ npy_intp num_steps = PyArray_DIM(a_traj, 0);
+ npy_intp num_agents = PyArray_DIM(a_traj, 1);
+
+ if (a_dims && (PyArray_DIM(a_dims, 0) != num_agents || PyArray_DIM(a_dims, 1) != 2)) {
+ PyErr_Format(PyExc_ValueError, "agent_dims must have shape (%ld, 2)", (long)num_agents);
+ goto fail;
+ }
+ if (a_lens && PyArray_DIM(a_lens, 0) != num_agents) {
+ PyErr_Format(PyExc_ValueError, "agent_lengths must have shape (%ld,)", (long)num_agents);
+ goto fail;
+ }
+
+ npy_intp num_polys = PyArray_DIM(a_typ, 0);
+ if (PyArray_DIM(a_off, 0) != num_polys + 1) {
+ PyErr_Format(PyExc_ValueError, "road_offsets must have shape (num_polys+1=%ld,), got (%ld,)",
+ (long)(num_polys + 1), (long)PyArray_DIM(a_off, 0));
+ goto fail;
+ }
+
+ /* Pull raw pointers. */
+ const float *road_xy = (const float *)PyArray_DATA(a_xy);
+ const uint32_t *road_offsets = (const uint32_t *)PyArray_DATA(a_off);
+ const uint32_t *road_types = (const uint32_t *)PyArray_DATA(a_typ);
+ const float *traj_xyh = (const float *)PyArray_DATA(a_traj);
+ const float *agent_dims_p = a_dims ? (const float *)PyArray_DATA(a_dims) : NULL;
+ const int32_t *agent_lens_p = a_lens ? (const int32_t *)PyArray_DATA(a_lens) : NULL;
+
+ int rc;
+ Py_BEGIN_ALLOW_THREADS rc = trajviz_render_episode(
+ ctx, road_xy, road_offsets, road_types, (uint32_t)num_polys, traj_xyh, (uint32_t)num_steps,
+ (uint32_t)num_agents, agent_dims_p, agent_lens_p, (int32_t)ego_idx, fps, out_topdown, out_bev);
+ Py_END_ALLOW_THREADS
+
+ Py_XDECREF(a_xy);
+ Py_XDECREF(a_off);
+ Py_XDECREF(a_typ);
+ Py_XDECREF(a_traj);
+ Py_XDECREF(a_dims);
+ Py_XDECREF(a_lens);
+
+ if (rc != TRAJVIZ_OK) {
+ PyErr_Format(PyExc_RuntimeError, "trajviz_render_episode failed (%d): %s", rc, trajviz_last_error(ctx));
+ return NULL;
+ }
+ Py_RETURN_NONE;
+
+fail:
+ Py_XDECREF(a_xy);
+ Py_XDECREF(a_off);
+ Py_XDECREF(a_typ);
+ Py_XDECREF(a_traj);
+ Py_XDECREF(a_dims);
+ Py_XDECREF(a_lens);
+ return NULL;
+}
+
+/* --------------------------- render_episodes_batch --------------------------- */
+
+/* Python signature:
+ *
+ * render_episodes_batch(
+ * ctx,
+ * all_road_xy, # (V_total, 2) float32
+ * vert_offsets, # (batch_size+1,) uint32
+ * all_road_offsets, # (P_meta_total,) uint32
+ * poly_meta_offsets, # (batch_size+1,) uint32
+ * all_road_types, # (P_total,) uint32
+ * poly_type_offsets, # (batch_size+1,) uint32
+ * traj_xyh, # (batch, T, A, 3) float32
+ * agent_lengths, # (batch, A) int32
+ * ego_idx_per_ep, # (batch,) int32
+ * fps, # int
+ * out_topdown_paths, # list of str or None, len batch
+ * out_bev_paths, # list of str or None, len batch
+ * ) -> None
+ *
+ * Returns None on success; raises RuntimeError with the C-side error
+ * message on failure.
+ */
+static PyObject *py_render_episodes_batch(PyObject *self, PyObject *args, PyObject *kwargs) {
+ (void)self;
+ static char *kwlist[] = {"ctx",
+ "all_road_xy",
+ "vert_offsets",
+ "all_road_offsets",
+ "poly_meta_offsets",
+ "all_road_types",
+ "poly_type_offsets",
+ "traj_xyh",
+ "agent_lengths",
+ "ego_idx_per_ep",
+ "fps",
+ "out_topdown_paths",
+ "out_bev_paths",
+ NULL};
+ PyObject *capsule = NULL;
+ PyObject *o_xy, *o_voff, *o_roff, *o_pmoff, *o_rtypes, *o_ptoff;
+ PyObject *o_traj, *o_lens, *o_egos;
+ int fps = 30;
+ PyObject *o_td_paths = NULL, *o_bev_paths = NULL;
+
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OOOOOOOOOOiOO:render_episodes_batch", kwlist, &capsule, &o_xy,
+ &o_voff, &o_roff, &o_pmoff, &o_rtypes, &o_ptoff, &o_traj, &o_lens, &o_egos, &fps,
+ &o_td_paths, &o_bev_paths)) {
+ return NULL;
+ }
+
+ if (!PyCapsule_CheckExact(capsule)) {
+ PyErr_SetString(PyExc_TypeError, "ctx: expected a TrajvizCtx capsule");
+ return NULL;
+ }
+ if (PyCapsule_GetContext(capsule) == CLOSED_SENTINEL) {
+ PyErr_SetString(PyExc_RuntimeError, "ctx has been closed");
+ return NULL;
+ }
+ TrajvizCtx *ctx = (TrajvizCtx *)PyCapsule_GetPointer(capsule, CAPSULE_NAME);
+ if (!ctx) {
+ PyErr_SetString(PyExc_RuntimeError, "ctx is null");
+ return NULL;
+ }
+
+ PyArrayObject *a_xy = as_array(o_xy, NPY_FLOAT32, 2, 0, "all_road_xy");
+ PyArrayObject *a_voff = as_array(o_voff, NPY_UINT32, 1, 0, "vert_offsets");
+ PyArrayObject *a_roff = as_array(o_roff, NPY_UINT32, 1, 0, "all_road_offsets");
+ PyArrayObject *a_pmoff = as_array(o_pmoff, NPY_UINT32, 1, 0, "poly_meta_offsets");
+ PyArrayObject *a_rtyp = as_array(o_rtypes, NPY_UINT32, 1, 0, "all_road_types");
+ PyArrayObject *a_ptoff = as_array(o_ptoff, NPY_UINT32, 1, 0, "poly_type_offsets");
+ PyArrayObject *a_traj = as_array(o_traj, NPY_FLOAT32, 4, 0, "traj_xyh");
+ PyArrayObject *a_lens = as_array(o_lens, NPY_INT32, 2, 0, "agent_lengths");
+ PyArrayObject *a_egos = as_array(o_egos, NPY_INT32, 1, 0, "ego_idx_per_ep");
+
+ if (!a_xy || !a_voff || !a_roff || !a_pmoff || !a_rtyp || !a_ptoff || !a_traj || !a_lens || !a_egos)
+ goto fail;
+
+ if (PyArray_DIM(a_xy, 1) != 2) {
+ PyErr_SetString(PyExc_ValueError, "all_road_xy must have shape (V, 2)");
+ goto fail;
+ }
+ if (PyArray_DIM(a_traj, 3) != 3) {
+ PyErr_SetString(PyExc_ValueError, "traj_xyh must have shape (batch, T, A, 3)");
+ goto fail;
+ }
+
+ npy_intp batch_size = PyArray_DIM(a_traj, 0);
+ npy_intp num_steps = PyArray_DIM(a_traj, 1);
+ npy_intp max_agents = PyArray_DIM(a_traj, 2);
+
+ if (PyArray_DIM(a_lens, 0) != batch_size || PyArray_DIM(a_lens, 1) != max_agents) {
+ PyErr_Format(PyExc_ValueError, "agent_lengths must have shape (%ld, %ld)", (long)batch_size, (long)max_agents);
+ goto fail;
+ }
+ if (PyArray_DIM(a_egos, 0) != batch_size) {
+ PyErr_Format(PyExc_ValueError, "ego_idx_per_ep must have shape (%ld,)", (long)batch_size);
+ goto fail;
+ }
+ if (PyArray_DIM(a_voff, 0) != batch_size + 1 || PyArray_DIM(a_pmoff, 0) != batch_size + 1 ||
+ PyArray_DIM(a_ptoff, 0) != batch_size + 1) {
+ PyErr_Format(PyExc_ValueError,
+ "vert_offsets / poly_meta_offsets / poly_type_offsets must have "
+ "shape (batch_size+1=%ld,)",
+ (long)(batch_size + 1));
+ goto fail;
+ }
+
+ /* Output path arrays. Each list (or None) must have length == batch_size.
+ * We allocate a C array of const char* per list and copy the strings'
+ * pointers from PyUnicode_AsUTF8. Strings stay valid for the duration
+ * of this call because we hold the Python lists. */
+ if (o_td_paths != Py_None && !PyList_Check(o_td_paths)) {
+ PyErr_SetString(PyExc_TypeError, "out_topdown_paths must be a list or None");
+ goto fail;
+ }
+ if (o_bev_paths != Py_None && !PyList_Check(o_bev_paths)) {
+ PyErr_SetString(PyExc_TypeError, "out_bev_paths must be a list or None");
+ goto fail;
+ }
+ if (o_td_paths != Py_None && PyList_GET_SIZE(o_td_paths) != batch_size) {
+ PyErr_Format(PyExc_ValueError, "out_topdown_paths length %zd != batch_size %ld", PyList_GET_SIZE(o_td_paths),
+ (long)batch_size);
+ goto fail;
+ }
+ if (o_bev_paths != Py_None && PyList_GET_SIZE(o_bev_paths) != batch_size) {
+ PyErr_Format(PyExc_ValueError, "out_bev_paths length %zd != batch_size %ld", PyList_GET_SIZE(o_bev_paths),
+ (long)batch_size);
+ goto fail;
+ }
+
+ const char **td_arr = NULL;
+ const char **bev_arr = NULL;
+ td_arr = (const char **)calloc((size_t)batch_size, sizeof(const char *));
+ bev_arr = (const char **)calloc((size_t)batch_size, sizeof(const char *));
+ if (!td_arr || !bev_arr) {
+ PyErr_NoMemory();
+ free(td_arr);
+ free(bev_arr);
+ goto fail;
+ }
+ if (o_td_paths != Py_None) {
+ for (npy_intp i = 0; i < batch_size; ++i) {
+ PyObject *item = PyList_GET_ITEM(o_td_paths, i);
+ if (item == Py_None) {
+ td_arr[i] = NULL;
+ } else if (PyUnicode_Check(item)) {
+ td_arr[i] = PyUnicode_AsUTF8(item);
+ if (!td_arr[i]) {
+ free(td_arr);
+ free(bev_arr);
+ goto fail;
+ }
+ } else {
+ PyErr_Format(PyExc_TypeError, "out_topdown_paths[%zd] must be str or None", i);
+ free(td_arr);
+ free(bev_arr);
+ goto fail;
+ }
+ }
+ }
+ if (o_bev_paths != Py_None) {
+ for (npy_intp i = 0; i < batch_size; ++i) {
+ PyObject *item = PyList_GET_ITEM(o_bev_paths, i);
+ if (item == Py_None) {
+ bev_arr[i] = NULL;
+ } else if (PyUnicode_Check(item)) {
+ bev_arr[i] = PyUnicode_AsUTF8(item);
+ if (!bev_arr[i]) {
+ free(td_arr);
+ free(bev_arr);
+ goto fail;
+ }
+ } else {
+ PyErr_Format(PyExc_TypeError, "out_bev_paths[%zd] must be str or None", i);
+ free(td_arr);
+ free(bev_arr);
+ goto fail;
+ }
+ }
+ }
+
+ int rc;
+ Py_BEGIN_ALLOW_THREADS rc = trajviz_render_episodes_batch(
+ ctx, (int)batch_size, (uint32_t)num_steps, (uint32_t)max_agents, (const float *)PyArray_DATA(a_xy),
+ (const uint32_t *)PyArray_DATA(a_voff), (const uint32_t *)PyArray_DATA(a_roff),
+ (const uint32_t *)PyArray_DATA(a_pmoff), (const uint32_t *)PyArray_DATA(a_rtyp),
+ (const uint32_t *)PyArray_DATA(a_ptoff), (const float *)PyArray_DATA(a_traj),
+ (const int32_t *)PyArray_DATA(a_lens), (const int32_t *)PyArray_DATA(a_egos), fps, td_arr, bev_arr);
+ Py_END_ALLOW_THREADS
+
+ free(td_arr);
+ free(bev_arr);
+
+ Py_XDECREF(a_xy);
+ Py_XDECREF(a_voff);
+ Py_XDECREF(a_roff);
+ Py_XDECREF(a_pmoff);
+ Py_XDECREF(a_rtyp);
+ Py_XDECREF(a_ptoff);
+ Py_XDECREF(a_traj);
+ Py_XDECREF(a_lens);
+ Py_XDECREF(a_egos);
+
+ if (rc != TRAJVIZ_OK) {
+ PyErr_Format(PyExc_RuntimeError, "trajviz_render_episodes_batch failed (%d): %s", rc, trajviz_last_error(ctx));
+ return NULL;
+ }
+ Py_RETURN_NONE;
+
+fail:
+ Py_XDECREF(a_xy);
+ Py_XDECREF(a_voff);
+ Py_XDECREF(a_roff);
+ Py_XDECREF(a_pmoff);
+ Py_XDECREF(a_rtyp);
+ Py_XDECREF(a_ptoff);
+ Py_XDECREF(a_traj);
+ Py_XDECREF(a_lens);
+ Py_XDECREF(a_egos);
+ return NULL;
+}
+
+/* ----------------------------- module def ----------------------------- */
+
+static PyMethodDef methods[] = {
+ {"init", (PyCFunction)py_init, METH_VARARGS | METH_KEYWORDS, "init(width, height) -> capsule"},
+ {"render_episode", (PyCFunction)py_render_episode, METH_VARARGS | METH_KEYWORDS,
+ "render_episode(ctx, road_xy, road_offsets, road_types, traj_xyh, "
+ "agent_dims, agent_lengths, ego_idx, fps, out_topdown, out_bev) -> None"},
+ {"render_episodes_batch", (PyCFunction)py_render_episodes_batch, METH_VARARGS | METH_KEYWORDS,
+ "render_episodes_batch(ctx, all_road_xy, vert_offsets, all_road_offsets, "
+ "poly_meta_offsets, all_road_types, poly_type_offsets, traj_xyh, "
+ "agent_lengths, ego_idx_per_ep, fps, out_topdown_paths, out_bev_paths) -> None"},
+ {"close", (PyCFunction)py_close, METH_VARARGS, "close(ctx) -> None"},
+ {NULL, NULL, 0, NULL}};
+
+static struct PyModuleDef moduledef = {PyModuleDef_HEAD_INIT,
+ "trajviz._native",
+ "Vulkan-backed renderer for saved Drive trajectories.",
+ -1,
+ methods,
+ NULL,
+ NULL,
+ NULL,
+ NULL};
+
+PyMODINIT_FUNC PyInit__native(void) {
+ import_array(); /* numpy C API initialization — REQUIRED before any PyArray_* */
+ if (PyErr_Occurred())
+ return NULL;
+ return PyModule_Create(&moduledef);
+}
diff --git a/pufferlib/ocean/drive/trajviz/ffmpeg_pipe.c b/pufferlib/ocean/drive/trajviz/ffmpeg_pipe.c
new file mode 100644
index 0000000000..1952e90988
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/ffmpeg_pipe.c
@@ -0,0 +1,382 @@
+/*
+ * ffmpeg_pipe.c — popen-based RGBA → MP4 streaming.
+ *
+ * The ffmpeg invocation matches what visualize.c uses on the live
+ * raylib path, with the same -preset and -crf so output sizes are
+ * comparable. Single-pass libx264, yuv420p (the most compatible pixel
+ * format for downstream players).
+ *
+ * Notes on robustness:
+ * - We pipe raw rgba in row-major order, no padding (width * 4 bytes
+ * per row, height rows). The renderer uses HOST_COHERENT readback
+ * buffers tightly packed at width*4 stride, so no row-pitch
+ * conversion is needed here.
+ * - We rely on the OS to write a full frame to ffmpeg's stdin in one
+ * fwrite call. The pipe buffer size on Linux is typically 64 KiB
+ * and a 1280×720 RGBA frame is 3.6 MiB, so fwrite will internally
+ * loop on a blocking pipe — that's the correct behavior.
+ */
+
+/* F_SETPIPE_SZ is a Linux extension behind _GNU_SOURCE on glibc. */
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+
+#include "ffmpeg_pipe.h"
+
+#include <errno.h>
+#include <pthread.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#ifdef __linux__
+#include <fcntl.h>
+#include <sys/types.h>
+#include <unistd.h>
+#endif
+
+/* Forward decls so ffmpeg_pipe_open can pthread_create the writer. */
+static void *writer_thread_main(void *arg);
+static int do_blocking_write(FfmpegPipe *p, const void *data, size_t bytes);
+
+/* ----- encoder selection -----
+ *
+ * trajviz can use either libx264 (CPU, the default) or h264_nvenc
+ * (NVIDIA hardware encoder, opt-in). Selection is via the
+ * TRAJVIZ_ENCODER env var:
+ *
+ * - unset / "libx264" / "x264" → CPU encoding (default)
+ * - "nvenc" / "h264_nvenc" → NVIDIA NVENC
+ *
+ * Why libx264 is the default even on NVIDIA boxes:
+ *
+ * NVENC is a poor fit for our current "one ffmpeg subprocess per
+ * output stream" architecture. Two reasons:
+ *
+ * 1. NVENC session creation is expensive (~100 ms per session).
+ * We spawn 2N ffmpeg processes per render_batch call (one per
+ * output MP4 file). For short 90-frame episodes the per-session
+ * startup cost dominates the actual encode time, and per-episode
+ * wall time ends up ~2× slower than libx264.
+ *
+ * 2. NVIDIA's NVENC driver still effectively caps concurrent
+ * sessions per process (the "consumer key" limit was nominally
+ * removed in driver 530+, but ffmpeg's nvenc wrapper trips on
+ * it anyway at batch_size ≥ 8 with errors like
+ * "OpenEncodeSessionEx failed: incompatible client key (21)").
+ *
+ * Both problems go away if you use ONE persistent NVENC session and
+ * feed it many frames — which is what direct integration of the NVENC
+ * C API would do, or what a single-ffmpeg-multi-input architecture
+ * would do. Until trajviz has either of those, libx264 is the
+ * faster path on a 16-core CPU.
+ *
+ * Setting TRAJVIZ_ENCODER=nvenc still works for single-episode
+ * rendering (no concurrent sessions, NVENC startup is amortized over
+ * the whole episode) — it's just not the default.
+ */
+typedef enum {
+ ENCODER_LIBX264 = 0,
+ ENCODER_NVENC = 1,
+} EncoderChoice;
+
+static EncoderChoice select_encoder(void) {
+ const char *enc = getenv("TRAJVIZ_ENCODER");
+ if (enc && *enc) {
+ if (strcmp(enc, "libx264") == 0 || strcmp(enc, "x264") == 0)
+ return ENCODER_LIBX264;
+ if (strcmp(enc, "nvenc") == 0 || strcmp(enc, "h264_nvenc") == 0)
+ return ENCODER_NVENC;
+ fprintf(stderr, "[trajviz] unknown TRAJVIZ_ENCODER=%s; using libx264\n", enc);
+ }
+ return ENCODER_LIBX264;
+}
+
+int ffmpeg_pipe_open(FfmpegPipe *p, int width, int height, int fps, const char *out_mp4) {
+ if (!p || !out_mp4)
+ return -1;
+ memset(p, 0, sizeof(*p));
+ p->fd = -1;
+ p->width = width;
+ p->height = height;
+ p->fps = fps;
+ snprintf(p->path, sizeof(p->path), "%s", out_mp4);
+
+ const char *ffmpeg_bin = getenv("TRAJVIZ_FFMPEG");
+ if (!ffmpeg_bin || !*ffmpeg_bin)
+ ffmpeg_bin = "ffmpeg";
+
+ /* TRAJVIZ_NO_FFMPEG=1 → bypass ffmpeg entirely for benchmarking the
+ * pure Vulkan path. We sink raw RGBA bytes to /dev/null via cat,
+ * which removes the libx264 encode cost from the timing loop. */
+ int bypass = 0;
+ const char *no_ff = getenv("TRAJVIZ_NO_FFMPEG");
+ if (no_ff && *no_ff && *no_ff != '0')
+ bypass = 1;
+
+ /* Build the ffmpeg command line. We single-quote the output path so
+ * shell metacharacters in user-supplied paths don't blow us up — but
+ * a single-quote in the path itself would still break, so we reject
+ * paths containing one. */
+ if (strchr(out_mp4, '\'') != NULL) {
+ fprintf(stderr, "[trajviz] output path contains single quote: %s\n", out_mp4);
+ return -1;
+ }
+
+ char cmd[2048];
+ int n;
+ if (bypass) {
+ n = snprintf(cmd, sizeof(cmd), "cat > /dev/null");
+ } else {
+ EncoderChoice enc = select_encoder();
+ /* One-time stderr line so the user knows which encoder we picked. */
+ static int logged_encoder = 0;
+ if (!logged_encoder) {
+ fprintf(stderr, "[trajviz] encoder: %s\n", enc == ENCODER_NVENC ? "h264_nvenc (GPU)" : "libx264 (CPU)");
+ logged_encoder = 1;
+ }
+ if (enc == ENCODER_NVENC) {
+ /* p4 = balanced default; tune hq (high quality) since
+ * latency doesn't matter in offline batch; -cq 23 is roughly
+ * libx264 -crf 20 in file size on this content. */
+ n = snprintf(cmd, sizeof(cmd),
+ "%s -y -hide_banner -loglevel error "
+ "-f rawvideo -pix_fmt rgba "
+ "-s %dx%d -r %d -i - "
+ "-c:v h264_nvenc -pix_fmt yuv420p "
+ "-preset p4 -tune hq -cq 23 "
+ "'%s'",
+ ffmpeg_bin, width, height, fps, out_mp4);
+ } else {
+ n = snprintf(cmd, sizeof(cmd),
+ "%s -y -hide_banner -loglevel error "
+ "-f rawvideo -pix_fmt rgba "
+ "-s %dx%d -r %d -i - "
+ "-c:v libx264 -pix_fmt yuv420p "
+ "-preset veryfast -crf 20 "
+ "'%s'",
+ ffmpeg_bin, width, height, fps, out_mp4);
+ }
+ }
+ if (n < 0 || n >= (int)sizeof(cmd)) {
+ fprintf(stderr, "[trajviz] ffmpeg command too long\n");
+ return -1;
+ }
+
+ p->fp = popen(cmd, "w");
+ if (!p->fp) {
+ fprintf(stderr, "[trajviz] popen(\"%s\") failed\n", cmd);
+ return -1;
+ }
+
+#ifdef __linux__
+ /* Cache the underlying file descriptor — we use raw write() in the
+ * hot path because libc fwrite chunks our 3.6 MB frame through its
+ * default ~8 KB stdio buffer (450+ syscalls per frame), which is
+ * the actual single biggest bottleneck on this path. Disable any
+ * stdio buffering as a paranoia measure in case anything ever does
+ * touch p->fp. */
+ p->fd = fileno(p->fp);
+ if (p->fd >= 0) {
+ setvbuf(p->fp, NULL, _IONBF, 0);
+ }
+
+ /* Also bump the kernel pipe buffer up to whatever the per-process
+ * limit allows — ideally enough to fit multiple full frames so the
+ * producer can race ahead of the consumer. Default 64 KB → 1 MB
+ * unprivileged → 8+ MB with sudo sysctl. */
+ if (p->fd >= 0) {
+ long want_one_frame = (long)width * (long)height * 4;
+ long tries[] = {
+ 16L << 20, /* 16 MB — needs sudo sysctl fs.pipe-max-size=16777216 */
+ 8L << 20, 4L << 20, 2L << 20, 1L << 20, 512L << 10, 256L << 10,
+ };
+ int got = 0;
+ for (size_t i = 0; i < sizeof(tries) / sizeof(tries[0]); ++i) {
+ if (fcntl(p->fd, F_SETPIPE_SZ, (int)tries[i]) >= 0) {
+ got = (int)tries[i];
+ break;
+ }
+ }
+ static int warned_small_pipe = 0;
+ if (got > 0 && got < want_one_frame && !warned_small_pipe) {
+ fprintf(stderr,
+ "[trajviz] pipe size %d B < frame size %ld B — fwrites "
+ "may block. Raise /proc/sys/fs/pipe-max-size for better "
+ "throughput (sudo sysctl fs.pipe-max-size=16777216).\n",
+ got, want_one_frame);
+ warned_small_pipe = 1;
+ }
+ }
+#endif
+
+ /* Spin up the writer thread. From here on, all writes go through
+ * submit_frame → cv_go → writer_thread_main → write() → cv_done. */
+ if (pthread_mutex_init(&p->mu, NULL) != 0 || pthread_cond_init(&p->cv_go, NULL) != 0 ||
+ pthread_cond_init(&p->cv_done, NULL) != 0) {
+ fprintf(stderr, "[trajviz] failed to init writer thread sync for %s\n", p->path);
+ pclose(p->fp);
+ p->fp = NULL;
+ return -1;
+ }
+ if (pthread_create(&p->thread, NULL, writer_thread_main, p) != 0) {
+ fprintf(stderr, "[trajviz] pthread_create failed for %s\n", p->path);
+ pthread_mutex_destroy(&p->mu);
+ pthread_cond_destroy(&p->cv_go);
+ pthread_cond_destroy(&p->cv_done);
+ pclose(p->fp);
+ p->fp = NULL;
+ return -1;
+ }
+ p->thread_started = 1;
+
+ return 0;
+}
+
+/* The actual blocking write — used by the writer thread. Loops on
+ * EINTR + partial writes. Returns 0 on success, -1 on error. */
+static int do_blocking_write(FfmpegPipe *p, const void *data, size_t bytes) {
+ static int no_write_cached = -1;
+ if (no_write_cached < 0) {
+ const char *e = getenv("TRAJVIZ_NO_WRITE");
+ no_write_cached = (e && *e && *e != '0') ? 1 : 0;
+ }
+ if (no_write_cached)
+ return 0;
+
+#ifdef __linux__
+ const uint8_t *buf = (const uint8_t *)data;
+ size_t left = bytes;
+ while (left > 0) {
+ ssize_t n = write(p->fd, buf, left);
+ if (n < 0) {
+ if (errno == EINTR)
+ continue;
+ fprintf(stderr, "[trajviz] write() failed (%s) for %s\n", strerror(errno), p->path);
+ return -1;
+ }
+ if (n == 0) {
+ fprintf(stderr, "[trajviz] write() returned 0 — pipe closed for %s\n", p->path);
+ return -1;
+ }
+ buf += (size_t)n;
+ left -= (size_t)n;
+ }
+ return 0;
+#else
+ size_t got = fwrite(data, 1, bytes, p->fp);
+ if (got != bytes) {
+ fprintf(stderr, "[trajviz] short write to ffmpeg pipe (%zu/%zu) for %s\n", got, bytes, p->path);
+ return -1;
+ }
+ return 0;
+#endif
+}
+
+/* Background writer thread main loop. Sleeps on cv_go, wakes up when
+ * the main thread submits work, does the write outside the lock so
+ * other threads can proceed in parallel, then signals cv_done. */
+static void *writer_thread_main(void *arg) {
+ FfmpegPipe *p = (FfmpegPipe *)arg;
+
+ pthread_mutex_lock(&p->mu);
+ for (;;) {
+ while (!p->work_pending && !p->stop) {
+ pthread_cond_wait(&p->cv_go, &p->mu);
+ }
+ if (p->stop) {
+ pthread_mutex_unlock(&p->mu);
+ return NULL;
+ }
+
+ /* Snapshot the work and release the lock so other writers can
+ * be submitted to in parallel while we write. */
+ const void *data = p->pending_data;
+ size_t bytes = p->pending_bytes;
+ pthread_mutex_unlock(&p->mu);
+
+ int err = do_blocking_write(p, data, bytes);
+
+ pthread_mutex_lock(&p->mu);
+ p->work_error = err;
+ p->work_pending = 0;
+ pthread_cond_signal(&p->cv_done);
+ }
+}
+
+int ffmpeg_pipe_submit_frame(FfmpegPipe *p, const void *rgba_bytes) {
+ if (!p || !rgba_bytes)
+ return -1;
+ if (!p->thread_started) {
+ /* No writer thread — fall back to synchronous write. */
+ return do_blocking_write(p, rgba_bytes, (size_t)p->width * (size_t)p->height * 4);
+ }
+
+ pthread_mutex_lock(&p->mu);
+ /* Drain any in-flight write before submitting a new one. The caller
+ * is supposed to call wait() between frames so this should normally
+ * be a no-op, but we guard against misuse. */
+ while (p->work_pending) {
+ pthread_cond_wait(&p->cv_done, &p->mu);
+ }
+ p->pending_data = rgba_bytes;
+ p->pending_bytes = (size_t)p->width * (size_t)p->height * 4;
+ p->work_pending = 1;
+ p->work_error = 0;
+ pthread_cond_signal(&p->cv_go);
+ pthread_mutex_unlock(&p->mu);
+ return 0;
+}
+
+int ffmpeg_pipe_wait(FfmpegPipe *p) {
+ if (!p)
+ return -1;
+ if (!p->thread_started)
+ return 0; /* sync mode — already done */
+
+ pthread_mutex_lock(&p->mu);
+ while (p->work_pending) {
+ pthread_cond_wait(&p->cv_done, &p->mu);
+ }
+ int err = p->work_error;
+ pthread_mutex_unlock(&p->mu);
+ return err;
+}
+
+int ffmpeg_pipe_write_frame(FfmpegPipe *p, const void *rgba_bytes) {
+ /* Sync wrapper: submit + wait. Used by the single-episode path. */
+ int rc = ffmpeg_pipe_submit_frame(p, rgba_bytes);
+ if (rc != 0)
+ return rc;
+ return ffmpeg_pipe_wait(p);
+}
+
+int ffmpeg_pipe_close(FfmpegPipe *p) {
+ if (!p)
+ return 0;
+
+ /* Drain any in-flight write, then signal the writer to exit and
+ * join it. After this, no thread is touching p->fp / p->fd, so we
+ * can safely pclose. */
+ if (p->thread_started) {
+ ffmpeg_pipe_wait(p);
+ pthread_mutex_lock(&p->mu);
+ p->stop = 1;
+ pthread_cond_signal(&p->cv_go);
+ pthread_mutex_unlock(&p->mu);
+ pthread_join(p->thread, NULL);
+ pthread_mutex_destroy(&p->mu);
+ pthread_cond_destroy(&p->cv_go);
+ pthread_cond_destroy(&p->cv_done);
+ p->thread_started = 0;
+ }
+
+ if (!p->fp)
+ return 0;
+ int status = pclose(p->fp);
+ p->fp = NULL;
+ p->fd = -1;
+ return status;
+}
diff --git a/pufferlib/ocean/drive/trajviz/ffmpeg_pipe.h b/pufferlib/ocean/drive/trajviz/ffmpeg_pipe.h
new file mode 100644
index 0000000000..3448138ac3
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/ffmpeg_pipe.h
@@ -0,0 +1,79 @@
+/*
+ * ffmpeg_pipe.h — write rendered RGBA frames to an ffmpeg subprocess.
+ *
+ * Each FfmpegPipe is a unidirectional handle to an ffmpeg process whose
+ * stdin we feed raw RGBA pixels. ffmpeg encodes them to H.264 in an MP4
+ * via libx264. We use popen() so we don't have to manage fork/exec/dup2
+ * by hand — at the cost of going through /bin/sh, which is fine because
+ * the output paths are caller-supplied and the rest of the args are
+ * static.
+ *
+ * One pipe per output MP4. The orchestrator opens (up to) two pipes per
+ * episode — one for top-down, one for BEV — writes one frame at a time
+ * to each, and closes both at episode end.
+ *
+ * Error model: write returns 0 on success, non-zero if fwrite fails (e.g.
+ * ffmpeg crashed or the disk filled up). The pipe is left open; the
+ * caller should close it.
+ */
+
+#ifndef FFMPEG_PIPE_H
+#define FFMPEG_PIPE_H
+
+#include <pthread.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+
+typedef struct FfmpegPipe {
+ FILE *fp; /* popen handle (kept so pclose can close it) */
+ int fd; /* cached fileno(fp) — we use raw write() */
+ int width;
+ int height;
+ int fps;
+ char path[1024]; /* output mp4 path, kept for error messages */
+
+ /* Writer thread + signaling. Each pipe gets its own background
+ * thread that does the blocking write() in parallel with the main
+ * thread + the other pipes' threads, so vk_batch_renderer's per-
+ * frame "submit all → wait all" loop costs max(single write) per
+ * frame instead of sum-of-writes. */
+ pthread_t thread;
+ pthread_mutex_t mu;
+ pthread_cond_t cv_go; /* main → writer: new work pending */
+ pthread_cond_t cv_done; /* writer → main: write completed */
+ int thread_started;
+ int stop;
+ int work_pending;
+ int work_error;
+ const void *pending_data; /* borrowed, valid between submit/wait */
+ size_t pending_bytes;
+} FfmpegPipe;
+
+/* Spawn ffmpeg writing to out_mp4. Returns 0 on success, non-zero on
+ * popen failure. The ffmpeg binary path is taken from $TRAJVIZ_FFMPEG if
+ * set, else "ffmpeg". */
+int ffmpeg_pipe_open(FfmpegPipe *p, int width, int height, int fps, const char *out_mp4);
+
+/* Write one frame's worth of RGBA bytes (width*height*4). SYNC: blocks
+ * until the write completes. Internally implemented as
+ * submit + wait — used by the single-episode path that doesn't need
+ * fan-out parallelism. */
+int ffmpeg_pipe_write_frame(FfmpegPipe *p, const void *rgba_bytes);
+
+/* ASYNC: hand off a frame to the pipe's writer thread and return
+ * immediately. The buffer pointer must stay valid until ffmpeg_pipe_wait
+ * returns for the same pipe. Will block briefly if a previous submit on
+ * this pipe is still running. */
+int ffmpeg_pipe_submit_frame(FfmpegPipe *p, const void *rgba_bytes);
+
+/* Wait for the most recent submit_frame on this pipe to complete.
+ * Returns 0 on success or the error code from the writer thread's
+ * write() call. Idempotent — safe to call when no submit is pending. */
+int ffmpeg_pipe_wait(FfmpegPipe *p);
+
+/* Close the pipe and wait for ffmpeg to flush. Idempotent. Returns the
+ * exit status of ffmpeg (0 = success). */
+int ffmpeg_pipe_close(FfmpegPipe *p);
+
+#endif
diff --git a/pufferlib/ocean/drive/trajviz/shaders.h b/pufferlib/ocean/drive/trajviz/shaders.h
new file mode 100644
index 0000000000..70e55d9cc2
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/shaders.h
@@ -0,0 +1,36 @@
+/*
+ * shaders.h — externs for SPIR-V blobs generated at build time.
+ *
+ * The actual byte arrays live in shaders.c, which is GENERATED by
+ * shaders/build_shaders.sh as part of the extension build. Don't commit
+ * shaders.c — it's a build artifact and changes whenever the .vert/.frag
+ * sources change.
+ *
+ * Build script flow:
+ * 1. glslangValidator -V *.vert *.frag -o *.spv
+ * 2. xxd-style hex dump → shaders.c with one const uint32_t array per shader
+ *
+ * If you see a linker error about undefined references to *_spv, the
+ * build script didn't run. Either run shaders/build_shaders.sh manually
+ * or rebuild via setup.py (which invokes it as a pre-build step).
+ */
+
+#ifndef TRAJVIZ_SHADERS_H
+#define TRAJVIZ_SHADERS_H
+
+#include <stdint.h>
+#include <stddef.h>
+
+extern const uint32_t polyline_vert_spv[];
+extern const size_t polyline_vert_spv_size; /* in bytes */
+
+extern const uint32_t polyline_frag_spv[];
+extern const size_t polyline_frag_spv_size;
+
+extern const uint32_t agent_box_vert_spv[];
+extern const size_t agent_box_vert_spv_size;
+
+extern const uint32_t agent_box_frag_spv[];
+extern const size_t agent_box_frag_spv_size;
+
+#endif
diff --git a/pufferlib/ocean/drive/trajviz/shaders/agent_box.frag b/pufferlib/ocean/drive/trajviz/shaders/agent_box.frag
new file mode 100644
index 0000000000..6e90b1618a
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/shaders/agent_box.frag
@@ -0,0 +1,10 @@
+#version 450
+
+// Agent box fragment shader — flat color from vertex stage.
+
+layout(location = 0) in vec4 v_color;
+layout(location = 0) out vec4 out_color;
+
+void main() {
+ out_color = v_color;
+}
diff --git a/pufferlib/ocean/drive/trajviz/shaders/agent_box.vert b/pufferlib/ocean/drive/trajviz/shaders/agent_box.vert
new file mode 100644
index 0000000000..c226b20004
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/shaders/agent_box.vert
@@ -0,0 +1,50 @@
+#version 450
+
+// Agent box vertex shader — instanced unit-quad expansion.
+//
+// Per-vertex (location 0): vec2 quad corner in [-1, +1]^2.
+// The renderer binds a static 4-vert vertex buffer with the unit quad
+// corners and a 6-index index buffer (two triangles).
+//
+// Per-instance (locations 1..3): one AgentInstance struct per agent for
+// the current frame.
+// loc 1 (pose): vec4 (x, y, heading_rad, _pad) — world position + angle
+// loc 2 (size): vec2 (length, width) — meters
+// loc 3 (color): vec4 (r, g, b, a) — base color
+//
+// The vertex shader:
+// 1. Scales the unit quad by half-(length, width).
+// 2. Rotates it by heading.
+// 3. Translates it by world (x, y).
+// 4. Applies the camera mvp.
+
+layout(location = 0) in vec2 in_corner; // unit quad corner
+
+layout(location = 1) in vec4 in_pose; // (x, y, heading, _pad)
+layout(location = 2) in vec2 in_size; // (length, width)
+layout(location = 3) in vec4 in_color;
+
+layout(push_constant) uniform Push {
+ mat4 mvp;
+ vec4 tint; // multiplied with in_color; used for view-specific tinting
+} pc;
+
+layout(location = 0) out vec4 v_color;
+
+void main() {
+ // Scale local corner by half-extents (length is along agent forward = local x,
+ // width is across = local y).
+ vec2 local = in_corner * (in_size * 0.5);
+
+ // Rotate by heading.
+ float c = cos(in_pose.z);
+ float s = sin(in_pose.z);
+ vec2 rotated = vec2(c * local.x - s * local.y,
+ s * local.x + c * local.y);
+
+ // Translate to world position.
+ vec2 world = rotated + in_pose.xy;
+
+ gl_Position = pc.mvp * vec4(world, 0.0, 1.0);
+ v_color = in_color * pc.tint;
+}
diff --git a/pufferlib/ocean/drive/trajviz/shaders/build_shaders.sh b/pufferlib/ocean/drive/trajviz/shaders/build_shaders.sh
new file mode 100755
index 0000000000..f884d1d248
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/shaders/build_shaders.sh
@@ -0,0 +1,80 @@
+#!/usr/bin/env bash
+#
+# build_shaders.sh — compile trajviz GLSL shaders to SPIR-V and embed
+# them as uint32_t arrays in ../shaders.c.
+#
+# Run from anywhere — the script cds into its own directory first.
+# Invoked automatically by setup.py before building the trajviz Python
+# extension. Safe to run by hand for shader iteration:
+#
+# cd pufferlib/ocean/drive/trajviz/shaders && ./build_shaders.sh
+#
+# Requires: glslangValidator (apt: glslang-tools), python3.
+
+set -euo pipefail
+
+SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+cd "$SCRIPT_DIR"
+
+GLSLANG="${GLSLANG:-glslangValidator}"
+if ! command -v "$GLSLANG" >/dev/null 2>&1; then
+ echo "ERROR: $GLSLANG not found. Install with: sudo apt install glslang-tools" >&2
+ exit 1
+fi
+
+if ! command -v python3 >/dev/null 2>&1; then
+ echo "ERROR: python3 not found." >&2
+ exit 1
+fi
+
+OUT_C="../shaders.c"
+TMPDIR="$(mktemp -d)"
+trap 'rm -rf "$TMPDIR"' EXIT
+
+# Names match the externs declared in shaders.h. Order is significant only
+# for the generated file's readability.
+SHADER_NAMES=(polyline_vert polyline_frag agent_box_vert agent_box_frag)
+SHADER_SRCS=(polyline.vert polyline.frag agent_box.vert agent_box.frag)
+
+for i in "${!SHADER_NAMES[@]}"; do
+ name="${SHADER_NAMES[$i]}"
+ src="${SHADER_SRCS[$i]}"
+ "$GLSLANG" -V "$src" -o "$TMPDIR/${name}.spv" >/dev/null
+done
+
+# Emit shaders.c. Using python3 for the hex dump because it's portable
+# and supports unpacking little-endian uint32 cleanly. The 16-byte
+# alignment annotation is required by Vulkan's vkCreateShaderModule
+# (pCode must be uint32-aligned; 16 covers any reasonable allocator).
+{
+ echo "/* AUTO-GENERATED by shaders/build_shaders.sh — do not edit. */"
+  echo "#include <stdint.h>"
+  echo "#include <stddef.h>"
+ echo
+ for i in "${!SHADER_NAMES[@]}"; do
+ name="${SHADER_NAMES[$i]}"
+ spv="$TMPDIR/${name}.spv"
+ size=$(wc -c < "$spv")
+ echo "const uint32_t ${name}_spv[] __attribute__((aligned(16))) = {"
+ python3 -c "
+import struct, sys
+data = open(sys.argv[1], 'rb').read()
+words = struct.unpack('<' + 'I' * (len(data) // 4), data)
+out = []
+for j, w in enumerate(words):
+ if j % 8 == 0:
+ out.append(' ')
+ out.append(f'0x{w:08x}, ')
+ if (j + 1) % 8 == 0:
+ out.append('\n')
+if not out[-1].endswith('\n'):
+ out.append('\n')
+sys.stdout.write(''.join(out))
+" "$spv"
+ echo "};"
+ echo "const size_t ${name}_spv_size = ${size};"
+ echo
+ done
+} > "$OUT_C"
+
+echo "Wrote $OUT_C ($(wc -l < "$OUT_C") lines)"
diff --git a/pufferlib/ocean/drive/trajviz/shaders/polyline.frag b/pufferlib/ocean/drive/trajviz/shaders/polyline.frag
new file mode 100644
index 0000000000..74f5e89de4
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/shaders/polyline.frag
@@ -0,0 +1,14 @@
+#version 450
+
+// Polyline fragment shader — flat color from push constant.
+
+layout(location = 0) out vec4 out_color;
+
+layout(push_constant) uniform Push {
+ mat4 mvp;
+ vec4 color;
+} pc;
+
+void main() {
+ out_color = pc.color;
+}
diff --git a/pufferlib/ocean/drive/trajviz/shaders/polyline.vert b/pufferlib/ocean/drive/trajviz/shaders/polyline.vert
new file mode 100644
index 0000000000..c7175822a2
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/shaders/polyline.vert
@@ -0,0 +1,22 @@
+#version 450
+
+// Polyline vertex shader.
+// Input: vec2 position in world (mean-centered sim) coordinates.
+// Push: mat4 mvp — world → clip space, supplied per-frame per-view.
+// Output: clip-space vec4 to gl_Position.
+//
+// Used for road polylines (line list) and optional trace overlays. The
+// fragment shader pulls the color from a separate push-constant member,
+// so this stage doesn't carry color attributes — keeps the vertex buffer
+// to 8 bytes per vert.
+
+layout(location = 0) in vec2 in_pos;
+
+layout(push_constant) uniform Push {
+ mat4 mvp;
+ vec4 color;
+} pc;
+
+void main() {
+ gl_Position = pc.mvp * vec4(in_pos, 0.0, 1.0);
+}
diff --git a/pufferlib/ocean/drive/trajviz/tests/test_main.c b/pufferlib/ocean/drive/trajviz/tests/test_main.c
new file mode 100644
index 0000000000..09b01de265
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/tests/test_main.c
@@ -0,0 +1,151 @@
+/*
+ * test_main.c — standalone smoke test for the trajviz C side.
+ *
+ * Builds a synthetic scene (road grid + a few moving agents), calls
+ * trajviz_render_episode, and exits. No npz, no Python, no map parsing —
+ * the whole point is to validate the Vulkan + ffmpeg path in isolation
+ * before wiring it up to anything else.
+ *
+ * Build (after libvulkan-dev + glslang-tools are installed and shaders
+ * have been compiled by shaders/build_shaders.sh):
+ *
+ * cd pufferlib/ocean/drive/trajviz
+ * bash shaders/build_shaders.sh
+ * cc -O2 -Wall -Wextra -I. \
+ * tests/test_main.c trajviz.c vk_context.c vk_pipeline.c \
+ * vk_renderer.c ffmpeg_pipe.c shaders.c \
+ * -lvulkan -lm -lpthread -o tests/trajviz_test
+ * ./tests/trajviz_test
+ *
+ * Outputs test_topdown.mp4 and test_bev.mp4 in the current directory.
+ *
+ * Open them and check:
+ * - Top-down: a 200x200 m square road grid with a horizontal line of
+ * blue cars moving rightward. One car is orange (the ego) and stays
+ * in the middle of the line.
+ * - BEV: the orange ego car at the center, facing up, with the road
+ * grid sliding past from top to bottom as the car moves +x.
+ */
+
+#include "../trajviz.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <stdint.h>
+#include <stddef.h>
+
+#define WIDTH 1280
+#define HEIGHT 720
+#define NUM_STEPS 90 /* 3 seconds at 30 fps */
+#define NUM_AGENTS 8
+#define FPS 30
+
+/* Build a simple road network: a grid of horizontal and vertical lines
+ * spanning [-100, +100]^2 meters at 25m spacing. Each grid line is one
+ * polyline with 2 vertices (a single segment). */
+static void build_grid(float **out_xy, uint32_t **out_offsets, uint32_t **out_types, uint32_t *out_num_polys,
+ uint32_t *out_num_verts) {
+ const float extent = 100.0f;
+ const float step = 25.0f;
+ const int n = (int)(2 * extent / step) + 1; /* 9 lines per axis */
+ const int total = n * 2; /* 18 polylines */
+
+ *out_num_polys = (uint32_t)total;
+ *out_num_verts = (uint32_t)(total * 2);
+
+ *out_xy = (float *)calloc((size_t)total * 2 * 2, sizeof(float));
+ *out_offsets = (uint32_t *)calloc((size_t)total + 1, sizeof(uint32_t));
+ *out_types = (uint32_t *)calloc((size_t)total, sizeof(uint32_t));
+
+ int p = 0;
+ int v = 0;
+ /* Horizontal lines (constant y) */
+ for (int i = 0; i < n; ++i) {
+ float y = -extent + i * step;
+ (*out_xy)[v * 2 + 0] = -extent;
+ (*out_xy)[v * 2 + 1] = y;
+ v++;
+ (*out_xy)[v * 2 + 0] = extent;
+ (*out_xy)[v * 2 + 1] = y;
+ v++;
+ (*out_offsets)[p + 1] = (uint32_t)v;
+ (*out_types)[p] = TVZ_ROAD_EDGE;
+ p++;
+ }
+ /* Vertical lines (constant x) */
+ for (int i = 0; i < n; ++i) {
+ float x = -extent + i * step;
+ (*out_xy)[v * 2 + 0] = x;
+ (*out_xy)[v * 2 + 1] = -extent;
+ v++;
+ (*out_xy)[v * 2 + 0] = x;
+ (*out_xy)[v * 2 + 1] = extent;
+ v++;
+ (*out_offsets)[p + 1] = (uint32_t)v;
+ (*out_types)[p] = TVZ_ROAD_LANE;
+ p++;
+ }
+}
+
+/* Build NUM_STEPS frames of a horizontal line of agents moving in +x at
+ * 10 m/s. Stored step-major: traj[step*NA*3 + a*3 + {0,1,2}]. */
+static float *build_trajectory(void) {
+ float *traj = (float *)calloc((size_t)NUM_STEPS * NUM_AGENTS * 3, sizeof(float));
+ const float vx_per_step = 10.0f / FPS; /* 10 m/s */
+ for (int step = 0; step < NUM_STEPS; ++step) {
+ for (int a = 0; a < NUM_AGENTS; ++a) {
+ float base_x = -40.0f + a * 12.0f;
+ float x = base_x + step * vx_per_step;
+ float y = 0.0f;
+ float h = 0.0f; /* facing +x */
+ size_t off = ((size_t)step * NUM_AGENTS + a) * 3;
+ traj[off + 0] = x;
+ traj[off + 1] = y;
+ traj[off + 2] = h;
+ }
+ }
+ return traj;
+}
+
+int main(int argc, char **argv) {
+ (void)argc;
+ (void)argv;
+
+ TrajvizCtx *ctx = trajviz_init(WIDTH, HEIGHT);
+ if (!ctx) {
+ fprintf(stderr, "trajviz_init failed: %s\n", trajviz_last_error(NULL));
+ return 1;
+ }
+
+ float *road_xy = NULL;
+ uint32_t *road_off = NULL;
+ uint32_t *road_typ = NULL;
+ uint32_t num_polys = 0, num_verts = 0;
+ build_grid(&road_xy, &road_off, &road_typ, &num_polys, &num_verts);
+
+ float *traj = build_trajectory();
+
+ /* All agents are valid for the full episode. */
+ int32_t lengths[NUM_AGENTS];
+ for (int i = 0; i < NUM_AGENTS; ++i)
+ lengths[i] = NUM_STEPS;
+
+ /* Default agent dimensions for all. */
+ int rc = trajviz_render_episode(ctx, road_xy, road_off, road_typ, num_polys, traj, NUM_STEPS, NUM_AGENTS,
+ NULL, /* agent_dims = default */
+ lengths, 4, /* ego_idx — middle of the line */
+ FPS, "test_topdown.mp4", "test_bev.mp4");
+
+ if (rc != TRAJVIZ_OK) {
+ fprintf(stderr, "trajviz_render_episode failed (%d): %s\n", rc, trajviz_last_error(ctx));
+ } else {
+ fprintf(stderr, "wrote test_topdown.mp4 and test_bev.mp4\n");
+ }
+
+ free(road_xy);
+ free(road_off);
+ free(road_typ);
+ free(traj);
+ trajviz_close(ctx);
+ return rc == TRAJVIZ_OK ? 0 : 1;
+}
diff --git a/pufferlib/ocean/drive/trajviz/tools/__init__.py b/pufferlib/ocean/drive/trajviz/tools/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/pufferlib/ocean/drive/trajviz/tools/random_rollout.py b/pufferlib/ocean/drive/trajviz/tools/random_rollout.py
new file mode 100644
index 0000000000..42ffe1075b
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/tools/random_rollout.py
@@ -0,0 +1,172 @@
+"""Render a full episode of a random-policy rollout to MP4 via trajviz.
+
+Standalone smoke test that exercises the entire trajviz pipeline against
+a real Drive simulation (not the synthetic grid in tests/test_main.c):
+
+ 1. Spin up a Drive env on one map.
+ 2. Reset and step with uniformly random actions for the full episode.
+ 3. Pull per-step (x, y, heading) out via get_sim_trajectories().
+ 4. Load the source map .bin via map_io and mean-center it with the
+ env's world_mean (so road geometry lines up with sim coordinates).
+ 5. Drive a single Renderer.render_episode call — both top-down and
+ BEV views in one pass — and write two MP4s.
+
+Usage:
+ python -m pufferlib.ocean.drive.trajviz.tools.random_rollout \\
+ [--map pufferlib/resources/drive/binaries/map_001.bin] \\
+ [--out-dir /tmp] [--episode-length 91] [--seed 0]
+
+The C extension must already be built (TRAJVIZ=1 python setup.py build_ext --inplace).
+"""
+
+from __future__ import annotations
+
+import argparse
+import shutil
+import tempfile
+from pathlib import Path
+
+import numpy as np
+
+from pufferlib.ocean.drive.drive import Drive
+from pufferlib.ocean.drive import map_io
+from pufferlib.ocean.drive.trajviz import Renderer
+
+
+def main() -> int:
+ p = argparse.ArgumentParser(description=__doc__)
+ p.add_argument(
+ "--map",
+ type=Path,
+ default=Path("pufferlib/resources/drive/binaries/map_001.bin"),
+ help="Path to a single .bin map file.",
+ )
+ p.add_argument("--out-dir", type=Path, default=Path("/tmp"))
+ p.add_argument("--episode-length", type=int, default=91, help="Number of sim steps to roll out.")
+ p.add_argument("--seed", type=int, default=0)
+ p.add_argument(
+ "--num-agents",
+ type=int,
+ default=2,
+ help="Cap on TOTAL agent slots in the env. Default 2 matches "
+ "the typical WOSAC tracks_to_predict count for one map; "
+ "raising it makes Drive instantiate the same map across "
+ "multiple sub-envs to fill the cap.",
+ )
+ p.add_argument(
+ "--init-mode",
+ default="create_only_controlled",
+ choices=("create_all_valid", "create_only_controlled", "init_variable_agent_number"),
+ help="How Drive instantiates agents. 'create_only_controlled' "
+ "(default here) gives random actions only to the source "
+ "scenario's tracks_to_predict agents and replays the rest "
+ "from log data — matching real WOSAC behavior. "
+ "'create_all_valid' makes every vehicle policy-controlled.",
+ )
+ p.add_argument("--width", type=int, default=1280)
+ p.add_argument("--height", type=int, default=720)
+ p.add_argument("--fps", type=int, default=30)
+ args = p.parse_args()
+
+ if not args.map.exists():
+ raise SystemExit(f"map not found: {args.map}")
+
+ # Drive expects a directory of maps and loads num_maps of them sorted.
+ # To pin it to one specific map regardless of which alphabetically
+ # comes first in the source dir, copy our chosen map into a fresh
+ # temp dir and point Drive at that.
+ with tempfile.TemporaryDirectory(prefix="trajviz_random_") as tmpdir:
+ shutil.copy(args.map, tmpdir)
+ print(f"[rollout] map: {args.map} → {tmpdir}")
+
+ # init_mode controls which agents become policy-controlled vs
+ # expert-replayed. With 'create_only_controlled', random actions
+ # only affect the source scenario's tracks_to_predict agents
+ # (typically 2 in WOSAC); the rest replay their Waymo log
+ # trajectories. With 'create_all_valid' the random actions move
+ # everything — useful if you want to see chaos, but doesn't
+ # match how the trained policy would actually be used.
+ env = Drive(
+ map_dir=tmpdir,
+ num_maps=1,
+ num_agents=args.num_agents,
+ episode_length=args.episode_length,
+ seed=args.seed,
+ init_steps=0,
+ init_mode=args.init_mode,
+ )
+ print(f"[rollout] num_agents={env.num_agents} episode_length={env.episode_length}")
+
+ rng = np.random.default_rng(args.seed)
+ env.reset(seed=args.seed)
+
+ # Action buffer was pre-allocated by PufferEnv based on the
+ # MultiDiscrete([91]) space — one categorical per agent in [0, 91).
+ actions_shape = env.actions.shape
+ actions_dtype = env.actions.dtype
+ action_high = 91
+
+ # Important: stop ONE step before episode_length so we don't trigger
+ # the auto-reset at end-of-episode. The C side increments timestep
+ # in c_step and resets it to 0 when timestep == episode_length, which
+ # would zero out traj["lengths"]. Stepping episode_length-1 times
+ # leaves timestep at episode_length-1 (no reset) with that many
+ # frames recorded.
+ n_steps = env.episode_length - 1
+ for step in range(n_steps):
+ actions = rng.integers(0, action_high, size=actions_shape, dtype=actions_dtype)
+ obs, reward, done, trunc, info = env.step(actions)
+ if trunc.all():
+ print(f"[rollout] unexpected truncation at step {step + 1}")
+ break
+
+ traj = env.get_sim_trajectories()
+ world_mean = np.asarray(env.world_mean, dtype=np.float32)
+ print(f"[rollout] world_mean = {world_mean}")
+ print(
+ f"[rollout] valid lengths: min={int(traj['lengths'].min())} "
+ f"max={int(traj['lengths'].max())} "
+ f"mean={float(traj['lengths'].mean()):.1f}"
+ )
+
+ env.close()
+
+ # Load the same source map and mean-center it. We use the path the
+ # user supplied, not the temp copy (which is gone now), but they hold
+ # the same bytes.
+ roads_raw = map_io.load_map_roads(args.map)
+ roads = map_io.mean_center_roads(roads_raw, world_mean)
+ road_xy, road_offsets, road_types = map_io.roads_to_csr(roads)
+ print(f"[rollout] roads: {len(roads)} polylines, {int(road_xy.shape[0])} verts")
+
+ # Stack into (T, A, 3). get_sim_trajectories returns (A, T) per field.
+ num_agents, num_steps = traj["x"].shape
+ traj_xyh = np.empty((num_steps, num_agents, 3), dtype=np.float32)
+ traj_xyh[..., 0] = traj["x"].T
+ traj_xyh[..., 1] = traj["y"].T
+ traj_xyh[..., 2] = traj["heading"].T
+
+ args.out_dir.mkdir(parents=True, exist_ok=True)
+ out_topdown = args.out_dir / "random_topdown.mp4"
+ out_bev = args.out_dir / "random_bev.mp4"
+
+ print(f"[rollout] rendering to {out_topdown.name} + {out_bev.name} ...")
+ with Renderer(width=args.width, height=args.height) as r:
+ r.render_episode(
+ road_xy=road_xy,
+ road_offsets=road_offsets,
+ road_types=road_types,
+ traj_xyh=traj_xyh,
+ agent_lengths=traj["lengths"].astype(np.int32),
+ ego_idx=-1,
+ fps=args.fps,
+ out_topdown=str(out_topdown),
+ out_bev=str(out_bev),
+ )
+ print(f"[rollout] wrote {out_topdown}")
+ print(f"[rollout] wrote {out_bev}")
+ return 0
+
+
+if __name__ == "__main__":
+ raise SystemExit(main())
diff --git a/pufferlib/ocean/drive/trajviz/trajviz.c b/pufferlib/ocean/drive/trajviz/trajviz.c
new file mode 100644
index 0000000000..41d0721d95
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/trajviz.c
@@ -0,0 +1,648 @@
+/*
+ * trajviz.c — public C API: orchestrates Vulkan + ffmpeg per episode.
+ *
+ * The public functions live here. They compose vk_context, vk_pipeline,
+ * vk_renderer, and ffmpeg_pipe into a single "render this episode to
+ * these mp4 paths" call. The CPython extension and the standalone test
+ * harness both call into this and nothing else.
+ *
+ * Per-frame loop:
+ * 1. For each agent in (0, num_agents):
+ * - Skip if agent_lengths is set and step >= length.
+ * - Pull (x, y, heading) from traj_xyh[step][agent].
+ * - Pull (length, width) from agent_dims (or defaults).
+ * - Color: ego = orange, others = teal.
+ * - Append to instances[].
+ * 2. Build top-down camera (fits the road AABB) — same matrix every
+ * frame, computed once outside the loop.
+ * 3. Build BEV camera from ego pose at this step.
+ * 4. vk_renderer_render_frame.
+ * 5. fwrite each readback buffer into the corresponding ffmpeg pipe.
+ *
+ * The road AABB is computed once on the road verts before the per-frame
+ * loop. The BEV window is hardcoded to 50 m half-extent (matching the
+ * drive sim's observation_window_size of 100 m total).
+ */
+
+#include "trajviz.h"
+#include "vk_context.h"
+#include "vk_pipeline.h"
+#include "vk_renderer.h"
+#include "vk_batch_renderer.h"
+#include "ffmpeg_pipe.h"
+#include "vk_math.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <stdint.h>
+#include <stddef.h>
+
+/* The opaque ctx struct exposed via trajviz.h. Holds everything that
+ * lives across episodes. */
+struct TrajvizCtx {
+ VkCtx vk;
+ Pipelines pipelines;
+ Renderer renderer;
+
+ /* Lazily-allocated batched renderer. Created on first call to
+ * render_episodes_batch and reused if subsequent batches request
+ * the same batch_size. If a different batch_size is requested,
+ * the existing one is destroyed and recreated. */
+ BatchRenderer batch;
+ int batch_initialized;
+ int batch_size_cur;
+
+ /* Reusable scratch buffer for per-frame instance arrays. Grows on
+ * demand and stays at the high water mark for the ctx lifetime. */
+ AgentInstance *scratch_instances;
+ uint32_t scratch_capacity;
+
+ /* Mirror of the last error string from the underlying VkCtx so the
+ * caller can read it via trajviz_last_error even if vk has been
+ * partially torn down. */
+ char last_error[TRAJVIZ_ERROR_BUF];
+};
+
+/* Global last-error slot for init failures (when there's no ctx yet). */
+static char g_init_error[TRAJVIZ_ERROR_BUF];
+
+const char *trajviz_last_error(const TrajvizCtx *ctx) {
+ if (!ctx)
+ return g_init_error;
+ return ctx->last_error;
+}
+
+static void copy_error(TrajvizCtx *ctx) {
+ if (!ctx)
+ return;
+ snprintf(ctx->last_error, TRAJVIZ_ERROR_BUF, "%s", ctx->vk.last_error);
+}
+
+/* Forward declarations of helpers used by both single-episode and
+ * batched paths. They live near the bottom of the file. */
+static void compute_road_aabb(const float *road_xy, uint32_t num_verts, float aabb[4]);
+static int32_t resolve_ego(int32_t requested, uint32_t num_agents, const int32_t *lengths);
+static int ensure_scratch(TrajvizCtx *ctx, uint32_t num);
+
+TrajvizCtx *trajviz_init(int width, int height) {
+ if (width <= 0 || height <= 0 || width > 8192 || height > 8192) {
+ snprintf(g_init_error, sizeof(g_init_error), "invalid dimensions %dx%d", width, height);
+ return NULL;
+ }
+ TrajvizCtx *ctx = (TrajvizCtx *)calloc(1, sizeof(*ctx));
+ if (!ctx) {
+ snprintf(g_init_error, sizeof(g_init_error), "out of memory");
+ return NULL;
+ }
+
+ int rc = vk_ctx_init(&ctx->vk);
+ if (rc != 0) {
+ /* %.480s caps the source string so the total snprintf output
+ * can't exceed g_init_error's 512-byte capacity. */
+ snprintf(g_init_error, sizeof(g_init_error), "vk_ctx_init: %.480s", ctx->vk.last_error);
+ free(ctx);
+ return NULL;
+ }
+
+ rc = vk_pipelines_init(&ctx->vk, &ctx->pipelines, VK_FORMAT_R8G8B8A8_UNORM);
+ if (rc != 0) {
+ snprintf(g_init_error, sizeof(g_init_error), "vk_pipelines_init: %.480s", ctx->vk.last_error);
+ vk_ctx_destroy(&ctx->vk);
+ free(ctx);
+ return NULL;
+ }
+
+ rc = vk_renderer_init(&ctx->vk, &ctx->pipelines, &ctx->renderer, (uint32_t)width, (uint32_t)height);
+ if (rc != 0) {
+ snprintf(g_init_error, sizeof(g_init_error), "vk_renderer_init: %.480s", ctx->vk.last_error);
+ vk_pipelines_destroy(&ctx->vk, &ctx->pipelines);
+ vk_ctx_destroy(&ctx->vk);
+ free(ctx);
+ return NULL;
+ }
+
+ fprintf(stderr, "[trajviz] using device: %s\n", ctx->vk.device_name);
+ return ctx;
+}
+
+void trajviz_close(TrajvizCtx *ctx) {
+ if (!ctx)
+ return;
+ if (ctx->batch_initialized) {
+ vk_batch_renderer_destroy(&ctx->vk, &ctx->batch);
+ ctx->batch_initialized = 0;
+ }
+ vk_renderer_destroy(&ctx->vk, &ctx->renderer);
+ vk_pipelines_destroy(&ctx->vk, &ctx->pipelines);
+ vk_ctx_destroy(&ctx->vk);
+ free(ctx->scratch_instances);
+ free(ctx);
+}
+
+/* Compute the AABB of the road verts. Used to fit the top-down camera. */
+static void compute_road_aabb(const float *road_xy, uint32_t num_verts, float aabb[4]) {
+ if (num_verts == 0) {
+ aabb[0] = -50.0f;
+ aabb[1] = -50.0f;
+ aabb[2] = 50.0f;
+ aabb[3] = 50.0f;
+ return;
+ }
+ float xmin = road_xy[0], xmax = road_xy[0];
+ float ymin = road_xy[1], ymax = road_xy[1];
+ for (uint32_t i = 1; i < num_verts; ++i) {
+ float x = road_xy[i * 2 + 0];
+ float y = road_xy[i * 2 + 1];
+ if (x < xmin)
+ xmin = x;
+ if (x > xmax)
+ xmax = x;
+ if (y < ymin)
+ ymin = y;
+ if (y > ymax)
+ ymax = y;
+ }
+ aabb[0] = xmin;
+ aabb[1] = ymin;
+ aabb[2] = xmax;
+ aabb[3] = ymax;
+}
+
+/* Determine the ego agent for the BEV view: the requested index, or the
+ * first agent with at least 2 valid steps if -1. */
+static int32_t resolve_ego(int32_t requested, uint32_t num_agents, const int32_t *lengths) {
+ if (requested >= 0 && (uint32_t)requested < num_agents)
+ return requested;
+ if (!lengths)
+ return 0;
+ for (uint32_t a = 0; a < num_agents; ++a) {
+ if (lengths[a] >= 2)
+ return (int32_t)a;
+ }
+ return 0;
+}
+
+static int ensure_scratch(TrajvizCtx *ctx, uint32_t num) {
+ if (num <= ctx->scratch_capacity)
+ return 0;
+ uint32_t cap = 16;
+ while (cap < num)
+ cap <<= 1;
+ AgentInstance *p = (AgentInstance *)realloc(ctx->scratch_instances, cap * sizeof(AgentInstance));
+ if (!p) {
+ snprintf(ctx->last_error, sizeof(ctx->last_error), "out of memory for scratch instances (%u)", cap);
+ return -1;
+ }
+ ctx->scratch_instances = p;
+ ctx->scratch_capacity = cap;
+ return 0;
+}
+
+/* Default agent dimensions if the caller doesn't supply per-agent ones.
+ * Roughly the median car size in Waymo Open. */
+#define DEFAULT_AGENT_LENGTH 5.0f
+#define DEFAULT_AGENT_WIDTH 2.0f
+
+int trajviz_render_episode(TrajvizCtx *ctx, const float *road_xy, const uint32_t *road_offsets,
+ const uint32_t *road_types, uint32_t num_road_polys, const float *traj_xyh,
+ uint32_t num_steps, uint32_t num_agents, const float *agent_dims,
+ const int32_t *agent_lengths, int32_t ego_idx, int fps, const char *out_topdown_mp4,
+ const char *out_bev_mp4) {
+ if (!ctx)
+ return TRAJVIZ_ERR_BAD_ARG;
+ if (!traj_xyh || num_steps == 0 || num_agents == 0) {
+ snprintf(ctx->last_error, sizeof(ctx->last_error), "empty trajectory (steps=%u agents=%u)", num_steps,
+ num_agents);
+ return TRAJVIZ_ERR_BAD_ARG;
+ }
+ if (!out_topdown_mp4 && !out_bev_mp4) {
+ snprintf(ctx->last_error, sizeof(ctx->last_error),
+ "no output paths supplied — at least one of topdown/bev must be set");
+ return TRAJVIZ_ERR_BAD_ARG;
+ }
+ if (fps <= 0)
+ fps = 30;
+
+ /* Number of road verts is the last entry of road_offsets, by CSR
+ * convention. Allow zero polylines (and therefore zero verts). */
+ uint32_t num_road_verts = (num_road_polys > 0) ? road_offsets[num_road_polys] : 0;
+
+ int rc = vk_renderer_set_roads(&ctx->vk, &ctx->renderer, road_xy, num_road_verts, road_offsets, road_types,
+ num_road_polys);
+ if (rc != 0) {
+ copy_error(ctx);
+ return TRAJVIZ_ERR_VK_DEVICE;
+ }
+
+ if (ensure_scratch(ctx, num_agents) != 0) {
+ return TRAJVIZ_ERR_VK_OOM;
+ }
+
+ /* Pre-compute the road AABB and the static top-down camera. */
+ float aabb[4];
+ compute_road_aabb(road_xy, num_road_verts, aabb);
+ Mat4 mvp_topdown =
+ mat4_fit_aabb(aabb[0], aabb[1], aabb[2], aabb[3], (int)ctx->renderer.width, (int)ctx->renderer.height, 0.05f);
+
+ int32_t ego = resolve_ego(ego_idx, num_agents, agent_lengths);
+
+ /* Episode length: the longest agent's lifetime, capped at num_steps.
+ * Without lengths we use the full traj. */
+ uint32_t ep_len = num_steps;
+ if (agent_lengths) {
+ uint32_t maxlen = 0;
+ for (uint32_t a = 0; a < num_agents; ++a) {
+ uint32_t l = (agent_lengths[a] < 0) ? 0 : (uint32_t)agent_lengths[a];
+ if (l > maxlen)
+ maxlen = l;
+ }
+ if (maxlen < ep_len)
+ ep_len = maxlen;
+ if (ep_len == 0)
+ ep_len = 1; /* always render at least one frame */
+ }
+
+ /* Open ffmpeg pipes for whichever views were requested. */
+ FfmpegPipe pipe_td = {0}, pipe_bev = {0};
+ int has_td = (out_topdown_mp4 != NULL);
+ int has_bev = (out_bev_mp4 != NULL);
+ if (has_td) {
+ if (ffmpeg_pipe_open(&pipe_td, (int)ctx->renderer.width, (int)ctx->renderer.height, fps, out_topdown_mp4) !=
+ 0) {
+ snprintf(ctx->last_error, sizeof(ctx->last_error), "ffmpeg_pipe_open failed for top-down view (%s)",
+ out_topdown_mp4);
+ return TRAJVIZ_ERR_FFMPEG_SPAWN;
+ }
+ }
+ if (has_bev) {
+ if (ffmpeg_pipe_open(&pipe_bev, (int)ctx->renderer.width, (int)ctx->renderer.height, fps, out_bev_mp4) != 0) {
+ snprintf(ctx->last_error, sizeof(ctx->last_error), "ffmpeg_pipe_open failed for bev view (%s)",
+ out_bev_mp4);
+ if (has_td)
+ ffmpeg_pipe_close(&pipe_td);
+ return TRAJVIZ_ERR_FFMPEG_SPAWN;
+ }
+ }
+
+ /* Hand the pipes to the renderer for the duration of the episode.
+ * It will fwrite to them internally as the frames-in-flight ring
+ * drains older slots. */
+ vk_renderer_episode_begin(&ctx->renderer, has_td ? &pipe_td : NULL, has_bev ? &pipe_bev : NULL);
+
+ int err = TRAJVIZ_OK;
+ for (uint32_t step = 0; step < ep_len; ++step) {
+ /* Build instance array for this frame. */
+ uint32_t n_inst = 0;
+ for (uint32_t a = 0; a < num_agents; ++a) {
+ if (agent_lengths && step >= (uint32_t)((agent_lengths[a] < 0) ? 0 : agent_lengths[a])) {
+ continue;
+ }
+ const float *ph = &traj_xyh[((size_t)step * num_agents + a) * 3];
+ if (ph[0] == 0.0f && ph[1] == 0.0f)
+ continue;
+
+ AgentInstance *ai = &ctx->scratch_instances[n_inst++];
+ ai->pose[0] = ph[0];
+ ai->pose[1] = ph[1];
+ ai->pose[2] = ph[2];
+ ai->pose[3] = 0.0f;
+ if (agent_dims) {
+ ai->size[0] = agent_dims[a * 2 + 0];
+ ai->size[1] = agent_dims[a * 2 + 1];
+ } else {
+ ai->size[0] = DEFAULT_AGENT_LENGTH;
+ ai->size[1] = DEFAULT_AGENT_WIDTH;
+ }
+ if ((int32_t)a == ego) {
+ ai->color[0] = 1.00f;
+ ai->color[1] = 0.55f;
+ ai->color[2] = 0.10f;
+ ai->color[3] = 1.0f;
+ } else {
+ ai->color[0] = 0.20f;
+ ai->color[1] = 0.75f;
+ ai->color[2] = 0.85f;
+ ai->color[3] = 1.0f;
+ }
+ }
+
+ /* BEV camera follows ego at this step. If the ego has terminated
+ * (length exceeded), keep the camera at its last valid position
+ * by clamping the step index used for the lookup. */
+ Mat4 mvp_bev;
+ if (has_bev) {
+ uint32_t bev_step = step;
+ if (agent_lengths && bev_step >= (uint32_t)((agent_lengths[ego] < 0) ? 0 : agent_lengths[ego])) {
+ bev_step = (uint32_t)((agent_lengths[ego] <= 0) ? 0 : agent_lengths[ego] - 1);
+ }
+ const float *ph = &traj_xyh[((size_t)bev_step * num_agents + ego) * 3];
+ mvp_bev = mat4_bev_camera(ph[0], ph[1], ph[2], 50.0f, (int)ctx->renderer.width, (int)ctx->renderer.height);
+ }
+
+ rc = vk_renderer_submit_frame(&ctx->vk, &ctx->renderer, ctx->scratch_instances, n_inst,
+ has_td ? &mvp_topdown : NULL, has_bev ? &mvp_bev : NULL);
+ if (rc != 0) {
+ copy_error(ctx);
+ err = (rc == -1) ? TRAJVIZ_ERR_FFMPEG_WRITE : TRAJVIZ_ERR_VK_DEVICE;
+ break;
+ }
+ }
+
+ /* Drain the FRAMES_IN_FLIGHT - 1 slots still pending after the loop. */
+ if (err == TRAJVIZ_OK) {
+ rc = vk_renderer_episode_end(&ctx->vk, &ctx->renderer);
+ if (rc != 0) {
+ copy_error(ctx);
+ err = (rc == -1) ? TRAJVIZ_ERR_FFMPEG_WRITE : TRAJVIZ_ERR_VK_DEVICE;
+ }
+ } else {
+ /* Best-effort drain so we don't leave the renderer in a half-state. */
+ vk_renderer_episode_end(&ctx->vk, &ctx->renderer);
+ }
+
+ if (has_td)
+ ffmpeg_pipe_close(&pipe_td);
+ if (has_bev)
+ ffmpeg_pipe_close(&pipe_bev);
+ return err;
+}
+
+/* ============================================================================
+ * Batched multi-episode rendering
+ * ============================================================================
+ *
+ * Renders N episodes simultaneously by tiling them into a per-view atlas
+ * image and recording all N tiles into one command-buffer per frame. The
+ * BatchRenderer is held in TrajvizCtx and reused across calls; if a
+ * subsequent call requests a different batch_size, we destroy and
+ * recreate (init cost ~20 ms paid once per unique size).
+ *
+ * Atlas tile dimensions match ctx->renderer.width/height — i.e. the same
+ * resolution as a single-episode render. The atlas is tile_w × (N*tile_h)
+ * with tiles stacked vertically so each tile's bytes are contiguous in
+ * the host readback buffer (one fwrite per tile per frame, no row stitching).
+ */
+
+#define TRAJVIZ_BATCH_MAX 16
+
+int trajviz_render_episodes_batch(TrajvizCtx *ctx, int batch_size, uint32_t num_steps, uint32_t max_agents,
+ const float *all_road_xy, const uint32_t *vert_offsets,
+ const uint32_t *all_road_offsets, const uint32_t *poly_meta_offsets,
+ const uint32_t *all_road_types, const uint32_t *poly_type_offsets,
+ const float *traj_xyh, const int32_t *agent_lengths, const int32_t *ego_idx_per_ep,
+ int fps, const char **out_topdown_paths, const char **out_bev_paths) {
+ if (!ctx)
+ return TRAJVIZ_ERR_BAD_ARG;
+ if (batch_size <= 0 || batch_size > TRAJVIZ_BATCH_MAX) {
+ snprintf(ctx->last_error, sizeof(ctx->last_error), "batch_size %d out of range [1, %d]", batch_size,
+ TRAJVIZ_BATCH_MAX);
+ return TRAJVIZ_ERR_BAD_ARG;
+ }
+ if (num_steps == 0 || max_agents == 0) {
+ snprintf(ctx->last_error, sizeof(ctx->last_error), "empty trajectory (steps=%u agents=%u)", num_steps,
+ max_agents);
+ return TRAJVIZ_ERR_BAD_ARG;
+ }
+ if (!traj_xyh || !vert_offsets || !poly_meta_offsets || !poly_type_offsets || !agent_lengths) {
+ snprintf(ctx->last_error, sizeof(ctx->last_error), "null required pointer to render_episodes_batch");
+ return TRAJVIZ_ERR_BAD_ARG;
+ }
+ if (fps <= 0)
+ fps = 30;
+
+ /* Lazily allocate or recreate the BatchRenderer when the requested
+ * size doesn't match the current one. */
+ if (ctx->batch_initialized && ctx->batch_size_cur != batch_size) {
+ vk_batch_renderer_destroy(&ctx->vk, &ctx->batch);
+ ctx->batch_initialized = 0;
+ }
+ if (!ctx->batch_initialized) {
+ int rc = vk_batch_renderer_init(&ctx->vk, &ctx->pipelines, &ctx->batch, batch_size, ctx->renderer.width,
+ ctx->renderer.height);
+ if (rc != 0) {
+ copy_error(ctx);
+ return TRAJVIZ_ERR_VK_DEVICE;
+ }
+ ctx->batch_initialized = 1;
+ ctx->batch_size_cur = batch_size;
+ }
+
+ /* Per-episode local state. All malloc'd so we can goto a single
+ * cleanup label on failure without VLAs. */
+ Mat4 *topdown_cams = (Mat4 *)calloc((size_t)batch_size, sizeof(Mat4));
+ int32_t *effective_lengths = (int32_t *)calloc((size_t)batch_size, sizeof(int32_t));
+ int32_t *resolved_egos = (int32_t *)calloc((size_t)batch_size, sizeof(int32_t));
+ FfmpegPipe *pipes_td = (FfmpegPipe *)calloc((size_t)batch_size, sizeof(FfmpegPipe));
+ FfmpegPipe *pipes_bev = (FfmpegPipe *)calloc((size_t)batch_size, sizeof(FfmpegPipe));
+ int *has_pipe_td = (int *)calloc((size_t)batch_size, sizeof(int));
+ int *has_pipe_bev = (int *)calloc((size_t)batch_size, sizeof(int));
+ if (!topdown_cams || !effective_lengths || !resolved_egos || !pipes_td || !pipes_bev || !has_pipe_td ||
+ !has_pipe_bev) {
+ free(topdown_cams);
+ free(effective_lengths);
+ free(resolved_egos);
+ free(pipes_td);
+ free(pipes_bev);
+ free(has_pipe_td);
+ free(has_pipe_bev);
+ snprintf(ctx->last_error, sizeof(ctx->last_error), "out of memory allocating batch state");
+ return TRAJVIZ_ERR_VK_OOM;
+ }
+
+ int err = TRAJVIZ_OK;
+ int max_eff_length = 0;
+
+ /* Per-episode setup pass: open pipes, upload roads, compute the
+ * frame-invariant top-down camera matrix, resolve ego index. */
+ for (int s = 0; s < batch_size; ++s) {
+ const char *out_td_s = out_topdown_paths ? out_topdown_paths[s] : NULL;
+ const char *out_bev_s = out_bev_paths ? out_bev_paths[s] : NULL;
+ if (!out_td_s && !out_bev_s) {
+ /* No outputs for this slot — skip entirely. */
+ continue;
+ }
+
+ /* Slice road data for episode s out of the concatenated arrays. */
+ uint32_t v_start = vert_offsets[s];
+ uint32_t v_end = vert_offsets[s + 1];
+ uint32_t num_verts_s = v_end - v_start;
+ const float *xy_s = all_road_xy + (size_t)v_start * 2;
+
+ uint32_t pm_start = poly_meta_offsets[s];
+ uint32_t pm_end = poly_meta_offsets[s + 1];
+ uint32_t num_polys_plus_1 = (pm_end > pm_start) ? (pm_end - pm_start) : 0;
+ uint32_t num_polys_s = (num_polys_plus_1 > 0) ? num_polys_plus_1 - 1 : 0;
+ const uint32_t *off_s = (all_road_offsets && num_polys_plus_1 > 0) ? (all_road_offsets + pm_start) : NULL;
+
+ uint32_t pt_start = poly_type_offsets[s];
+ const uint32_t *typ_s = (all_road_types && num_polys_s > 0) ? (all_road_types + pt_start) : NULL;
+
+ /* Open ffmpeg pipes for this slot. */
+ if (out_td_s) {
+ if (ffmpeg_pipe_open(&pipes_td[s], (int)ctx->batch.tile_w, (int)ctx->batch.tile_h, fps, out_td_s) != 0) {
+ snprintf(ctx->last_error, sizeof(ctx->last_error),
+ "ffmpeg_pipe_open failed for episode %d top-down (%s)", s, out_td_s);
+ err = TRAJVIZ_ERR_FFMPEG_SPAWN;
+ goto cleanup;
+ }
+ has_pipe_td[s] = 1;
+ }
+ if (out_bev_s) {
+ if (ffmpeg_pipe_open(&pipes_bev[s], (int)ctx->batch.tile_w, (int)ctx->batch.tile_h, fps, out_bev_s) != 0) {
+ snprintf(ctx->last_error, sizeof(ctx->last_error), "ffmpeg_pipe_open failed for episode %d bev (%s)", s,
+ out_bev_s);
+ err = TRAJVIZ_ERR_FFMPEG_SPAWN;
+ goto cleanup;
+ }
+ has_pipe_bev[s] = 1;
+ }
+
+ int rc =
+ vk_batch_renderer_set_episode(&ctx->vk, &ctx->batch, s, xy_s, num_verts_s, off_s, typ_s, num_polys_s,
+ has_pipe_td[s] ? &pipes_td[s] : NULL, has_pipe_bev[s] ? &pipes_bev[s] : NULL);
+ if (rc != 0) {
+ copy_error(ctx);
+ err = TRAJVIZ_ERR_VK_DEVICE;
+ goto cleanup;
+ }
+
+ /* Top-down camera = fit road AABB to tile (frame-invariant). */
+ float aabb[4];
+ compute_road_aabb(xy_s, num_verts_s, aabb);
+ topdown_cams[s] =
+ mat4_fit_aabb(aabb[0], aabb[1], aabb[2], aabb[3], (int)ctx->batch.tile_w, (int)ctx->batch.tile_h, 0.05f);
+
+ /* Effective episode length = max valid agent_lengths in this slot. */
+ const int32_t *lens_s = agent_lengths + (size_t)s * max_agents;
+ int32_t maxlen = 0;
+ for (uint32_t a = 0; a < max_agents; ++a) {
+ int32_t l = lens_s[a];
+ if (l < 0)
+ l = 0;
+ if ((uint32_t)l > num_steps)
+ l = (int32_t)num_steps;
+ if (l > maxlen)
+ maxlen = l;
+ }
+ effective_lengths[s] = maxlen;
+ if (maxlen > max_eff_length)
+ max_eff_length = maxlen;
+
+ /* Resolve ego index per episode. */
+ int32_t requested_ego = ego_idx_per_ep ? ego_idx_per_ep[s] : -1;
+ resolved_egos[s] = resolve_ego(requested_ego, max_agents, lens_s);
+ }
+
+ if (max_eff_length == 0)
+ max_eff_length = 1;
+
+ /* Make sure the per-frame instance scratch buffer can hold the
+ * widest episode's agent count. */
+ if (ensure_scratch(ctx, max_agents) != 0) {
+ err = TRAJVIZ_ERR_VK_OOM;
+ goto cleanup;
+ }
+
+ /* Per-frame loop. For each frame, populate every active slot's
+ * per-frame state (instances + camera matrices), then submit one
+ * batched frame. The submit_frame call internally fwrites every
+ * slot's tile to its ffmpeg pipes. */
+ for (uint32_t frame = 0; frame < (uint32_t)max_eff_length && err == TRAJVIZ_OK; ++frame) {
+ for (int s = 0; s < batch_size; ++s) {
+ if (!has_pipe_td[s] && !has_pipe_bev[s])
+ continue;
+
+ /* Episode finished? Skip this slot for this frame. */
+ if ((int32_t)frame >= effective_lengths[s]) {
+ vk_batch_renderer_set_frame(&ctx->vk, &ctx->batch, s, NULL, 0, NULL, NULL);
+ continue;
+ }
+
+ const int32_t *lens_s = agent_lengths + (size_t)s * max_agents;
+ int32_t ego = resolved_egos[s];
+
+ /* Build instance array for this slot for this frame. */
+ uint32_t n_inst = 0;
+ for (uint32_t a = 0; a < max_agents; ++a) {
+ int32_t l = lens_s[a];
+ if (l < 0)
+ l = 0;
+ if ((int32_t)frame >= l)
+ continue;
+ size_t off = (((size_t)s * num_steps + (size_t)frame) * (size_t)max_agents + (size_t)a) * 3;
+ const float *ph = &traj_xyh[off];
+ if (ph[0] == 0.0f && ph[1] == 0.0f)
+ continue;
+
+ AgentInstance *ai = &ctx->scratch_instances[n_inst++];
+ ai->pose[0] = ph[0];
+ ai->pose[1] = ph[1];
+ ai->pose[2] = ph[2];
+ ai->pose[3] = 0.0f;
+ ai->size[0] = DEFAULT_AGENT_LENGTH;
+ ai->size[1] = DEFAULT_AGENT_WIDTH;
+ if ((int32_t)a == ego) {
+ ai->color[0] = 1.00f;
+ ai->color[1] = 0.55f;
+ ai->color[2] = 0.10f;
+ ai->color[3] = 1.0f;
+ } else {
+ ai->color[0] = 0.20f;
+ ai->color[1] = 0.75f;
+ ai->color[2] = 0.85f;
+ ai->color[3] = 1.0f;
+ }
+ }
+
+ /* BEV camera follows the slot's ego at this frame, clamped
+ * to the ego's last valid step if it has terminated. */
+ Mat4 bev_cam;
+ int has_bev = has_pipe_bev[s];
+ if (has_bev) {
+ uint32_t bev_step = frame;
+ int32_t ego_len = lens_s[ego];
+ if (ego_len <= 0)
+ ego_len = 1;
+ if ((int32_t)bev_step >= ego_len)
+ bev_step = (uint32_t)(ego_len - 1);
+ size_t off = (((size_t)s * num_steps + (size_t)bev_step) * (size_t)max_agents + (size_t)ego) * 3;
+ const float *ph = &traj_xyh[off];
+ bev_cam = mat4_bev_camera(ph[0], ph[1], ph[2], 50.0f, (int)ctx->batch.tile_w, (int)ctx->batch.tile_h);
+ }
+
+ int rc = vk_batch_renderer_set_frame(&ctx->vk, &ctx->batch, s, ctx->scratch_instances, n_inst,
+ has_pipe_td[s] ? &topdown_cams[s] : NULL, has_bev ? &bev_cam : NULL);
+ if (rc != 0) {
+ copy_error(ctx);
+ err = TRAJVIZ_ERR_VK_DEVICE;
+ break;
+ }
+ }
+ if (err != TRAJVIZ_OK)
+ break;
+
+ int rc = vk_batch_renderer_submit_frame(&ctx->vk, &ctx->batch);
+ if (rc != 0) {
+ copy_error(ctx);
+ err = (rc == -1) ? TRAJVIZ_ERR_FFMPEG_WRITE : TRAJVIZ_ERR_VK_DEVICE;
+ break;
+ }
+ }
+
+cleanup:
+ for (int s = 0; s < batch_size; ++s) {
+ if (has_pipe_td[s])
+ ffmpeg_pipe_close(&pipes_td[s]);
+ if (has_pipe_bev[s])
+ ffmpeg_pipe_close(&pipes_bev[s]);
+ vk_batch_renderer_close_episode(&ctx->batch, s);
+ }
+ free(topdown_cams);
+ free(effective_lengths);
+ free(resolved_egos);
+ free(pipes_td);
+ free(pipes_bev);
+ free(has_pipe_td);
+ free(has_pipe_bev);
+ return err;
+}
diff --git a/pufferlib/ocean/drive/trajviz/trajviz.h b/pufferlib/ocean/drive/trajviz/trajviz.h
new file mode 100644
index 0000000000..332d994126
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/trajviz.h
@@ -0,0 +1,167 @@
+/*
+ * trajviz.h — public C API for the trajviz Vulkan renderer.
+ *
+ * This header is consumed by:
+ * - _native.c (CPython extension shell, the production caller)
+ * - tools/test_main.c (standalone test harness, no Python needed)
+ *
+ * The shape is deliberate: every function takes raw pointers and shapes,
+ * never a file path or a numpy object. The Python wrapper is responsible
+ * for loading .npz / .bin files and slicing them into per-episode arrays;
+ * this layer only knows about geometry and rendering. Keeps the C side
+ * focused, and means the same code works behind a CPython extension or a
+ * test harness without changes.
+ *
+ * Coordinate frame: all positions (road_xy, traj_xyh) are in *mean-centered
+ * sim frame*, the same frame the trajectories live in. The renderer never
+ * sees world_mean — the Python wrapper has already subtracted it from the
+ * road geometry by the time pointers reach this layer.
+ *
+ * Heading convention: radians, math convention (0 = +x, pi/2 = +y, CCW).
+ *
+ * Lifecycle:
+ * ctx = trajviz_init(W, H);
+ * for each episode:
+ * trajviz_render_episode(ctx, ...); // blocks until both MP4s closed
+ * trajviz_close(ctx);
+ *
+ * The Vulkan context (instance, device, queues, pipelines, shaders) is
+ * created in init() and reused across all render_episode calls. This is
+ * the whole point of having a stateful API: a single trajviz_init pays
+ * the ~50ms Vulkan startup cost once for an entire batch of episodes.
+ */
+
+#ifndef TRAJVIZ_H
+#define TRAJVIZ_H
+
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* Opaque renderer handle. Holds the Vulkan instance, device, command pool,
+ * pipelines, and reusable per-frame buffers. Created by trajviz_init,
+ * destroyed by trajviz_close. Not thread-safe — one ctx per thread. */
+typedef struct TrajvizCtx TrajvizCtx;
+
+/* Return codes for render_episode. 0 = success, negative = error class. */
+#define TRAJVIZ_OK 0
+#define TRAJVIZ_ERR_BAD_ARG -1
+#define TRAJVIZ_ERR_VK_DEVICE -2 /* a vulkan device call failed */
+#define TRAJVIZ_ERR_VK_OOM -3
+#define TRAJVIZ_ERR_FFMPEG_SPAWN -4 /* could not popen ffmpeg */
+#define TRAJVIZ_ERR_FFMPEG_WRITE -5 /* fwrite to ffmpeg pipe failed */
+#define TRAJVIZ_ERR_NO_DEVICE -6 /* no vulkan-capable physical device */
+
+/* Get the last error message set by a failed call. The returned string is
+ * owned by the ctx (or by global state if ctx is NULL — for init failures)
+ * and stays valid until the next trajviz_* call on the same ctx. */
+const char *trajviz_last_error(const TrajvizCtx *ctx);
+
+/* Create a renderer.
+ *
+ * width / height: pixel dimensions of each output video frame. Both views
+ * (top-down and BEV) render at the same size — the BEV could be rendered
+ * smaller, but matching sizes keeps the GPU pipeline state count low and
+ * makes side-by-side video composition trivial.
+ *
+ * Returns NULL on failure; call trajviz_last_error(NULL) to get a message. */
+TrajvizCtx *trajviz_init(int width, int height);
+
+/* Render one episode.
+ *
+ * Geometry inputs (all read-only, never retained past this call):
+ * road_xy: (num_road_verts, 2) float32, packed xy pairs
+ * road_offsets: (num_road_polys + 1,) uint32 — CSR-style; polyline i's
+ * vertices are road_xy[road_offsets[i]..road_offsets[i+1]]
+ * road_types: (num_road_polys,) uint32 — TVZ_ROAD_* type ids; the
+ * renderer maps these to colors
+ * num_road_polys: number of polylines
+ *
+ * traj_xyh: (num_steps, num_agents, 3) float32, step-major.
+ * Per-frame: traj_xyh[t * num_agents * 3 + a * 3 + {0,1,2}]
+ * = (x, y, heading) of agent a at step t. Step-major
+ * layout means one frame's worth is contiguous, which
+ * is exactly what the per-frame upload wants.
+ *
+ * agent_dims: (num_agents, 2) float32 — (length, width) per agent.
+ * If NULL, the renderer uses default car dimensions
+ * (5.0 x 2.0 m).
+ * agent_lengths: (num_agents,) int32 — valid step count per agent. The
+ * renderer skips drawing agents past their length and
+ * ends the episode at max(agent_lengths). NULL = treat
+ * all agents as fully valid.
+ * ego_idx: which agent the BEV view follows. Negative values
+ * pick the first agent with length >= 2.
+ *
+ * Output paths (must be writable):
+ * out_topdown_mp4: full-map ortho top-down view, NULL to skip
+ * out_bev_mp4: agent-centric BEV (RenderView.BEV_AGENT_OBS), NULL to skip
+ *
+ * Other:
+ * fps: output video framerate, e.g. 30
+ *
+ * Returns TRAJVIZ_OK or a negative error code; call trajviz_last_error(ctx)
+ * for the message. */
+int trajviz_render_episode(TrajvizCtx *ctx, const float *road_xy, const uint32_t *road_offsets,
+ const uint32_t *road_types, uint32_t num_road_polys, const float *traj_xyh,
+ uint32_t num_steps, uint32_t num_agents, const float *agent_dims,
+ const int32_t *agent_lengths, int32_t ego_idx, int fps, const char *out_topdown_mp4,
+ const char *out_bev_mp4);
+
+/* Tear down. Idempotent; passing NULL is a no-op. */
+void trajviz_close(TrajvizCtx *ctx);
+
+/* Render a batch of episodes simultaneously. All episodes are tiled
+ * into a per-view atlas image and drawn in one command-buffer per
+ * frame; one queue submit + one fence wait per frame covers
+ * batch_size episodes' worth of work. Per-episode wall time should
+ * drop by roughly batch_size× compared to calling render_episode
+ * batch_size times sequentially (until the GPU saturates).
+ *
+ * For v1 the batch requires uniform num_steps and max_agents across
+ * all episodes — pad shorter trajectories with zeros and use
+ * agent_lengths to mark valid steps. Roads are ragged: each episode
+ * has its own road geometry packed end-to-end with CSR-style offsets.
+ *
+ * Concatenation layout for road data:
+ * all_road_xy: (V_total, 2) float32 — V_total = sum of vert counts
+ * vert_offsets: (batch_size + 1,) uint32 — episode i's verts are
+ * all_road_xy[vert_offsets[i] : vert_offsets[i+1]]
+ * all_road_offsets: (P_total + batch_size,) uint32 — episode i's CSR
+ * offsets are all_road_offsets[poly_meta_offsets[i] :
+ * poly_meta_offsets[i] + num_polys_i + 1] — these
+ * offsets are RELATIVE to episode i's vert range
+ * (i.e. they index into the slice of all_road_xy)
+ * poly_meta_offsets: (batch_size + 1,) uint32 — episode i's poly count
+ * is poly_meta_offsets[i+1] - poly_meta_offsets[i] - 1
+ * (one extra entry per ep for the closing offset)
+ * all_road_types: (P_total,) uint32 — type ids, packed without padding
+ * poly_type_offsets: (batch_size + 1,) uint32 — index into all_road_types
+ *
+ * Trajectories are uniform shape: (batch_size, num_steps, max_agents, 3).
+ *
+ * Output paths are an array of C strings; NULL entries skip that view
+ * for that episode.
+ *
+ * Returns TRAJVIZ_OK or a negative error code. */
+int trajviz_render_episodes_batch(TrajvizCtx *ctx, int batch_size, uint32_t num_steps, uint32_t max_agents,
+ const float *all_road_xy, const uint32_t *vert_offsets,
+ const uint32_t *all_road_offsets, const uint32_t *poly_meta_offsets,
+ const uint32_t *all_road_types, const uint32_t *poly_type_offsets,
+ const float *traj_xyh, const int32_t *agent_lengths, const int32_t *ego_idx_per_ep,
+ int fps, const char **out_topdown_paths, const char **out_bev_paths);
+
+/* Road type ids — copied from drive.h. The renderer hardcodes a color for
+ * each. Unknown types render in a default gray. */
+#define TVZ_ROAD_LANE 4
+#define TVZ_ROAD_LINE 5
+#define TVZ_ROAD_EDGE 6
+#define TVZ_ROAD_DRIVEWAY 10
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* TRAJVIZ_H */
diff --git a/pufferlib/ocean/drive/trajviz/vk_batch_renderer.c b/pufferlib/ocean/drive/trajviz/vk_batch_renderer.c
new file mode 100644
index 0000000000..b78b46e0b0
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/vk_batch_renderer.c
@@ -0,0 +1,694 @@
+/*
+ * vk_batch_renderer.c — multi-episode batched renderer.
+ *
+ * Tiles N episodes into a single atlas per view, draws them all in one
+ * command-buffer recording, submits once per frame. The dominant per-
+ * frame Vulkan + ffmpeg overhead in the single-episode path was per-
+ * submit / per-pipe latency, so this should drop per-episode wall time
+ * by roughly 1/N as long as the GPU isn't already saturated.
+ *
+ * Buffer + image helper plumbing is duplicated from vk_renderer.c to
+ * keep this module self-contained — they're small (~50 lines apiece),
+ * have stable signatures, and the alternative is plumbing yet another
+ * shared header. The duplicates can be merged later if either side
+ * grows complex.
+ */
+
+#include "vk_batch_renderer.h"
+#include "shaders.h"
+
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+/* ----------------------------- buffer helpers ----------------------------- */
+
+static int br_create_buffer(VkCtx *ctx, VkDeviceSize size, VkBufferUsageFlags usage, VkMemoryPropertyFlags mem_props,
+ int map_persistent, VkBufferM *out) {
+ memset(out, 0, sizeof(*out));
+ out->size = size;
+
+ VkBufferCreateInfo bci = {
+ .sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO,
+ .size = size,
+ .usage = usage,
+ .sharingMode = VK_SHARING_MODE_EXCLUSIVE,
+ };
+ VK_CHECK(vkCreateBuffer(ctx->device, &bci, NULL, &out->buffer));
+
+ VkMemoryRequirements req;
+ vkGetBufferMemoryRequirements(ctx->device, out->buffer, &req);
+
+ uint32_t mem_idx = vk_find_memory_type(ctx, req.memoryTypeBits, mem_props);
+ if (mem_idx == UINT32_MAX) {
+ vk_ctx_set_error(ctx, "no memory type matches buffer requirements (props=0x%x)", (unsigned)mem_props);
+ return -1;
+ }
+
+ VkMemoryAllocateInfo mai = {
+ .sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO,
+ .allocationSize = req.size,
+ .memoryTypeIndex = mem_idx,
+ };
+ VK_CHECK(vkAllocateMemory(ctx->device, &mai, NULL, &out->memory));
+ VK_CHECK(vkBindBufferMemory(ctx->device, out->buffer, out->memory, 0));
+
+ if (map_persistent) {
+ VK_CHECK(vkMapMemory(ctx->device, out->memory, 0, VK_WHOLE_SIZE, 0, &out->mapped));
+ }
+ return 0;
+}
+
+static void br_destroy_buffer(VkCtx *ctx, VkBufferM *b) {
+ if (!b || !ctx)
+ return;
+ if (b->mapped && b->memory) {
+ vkUnmapMemory(ctx->device, b->memory);
+ b->mapped = NULL;
+ }
+ if (b->buffer) {
+ vkDestroyBuffer(ctx->device, b->buffer, NULL);
+ b->buffer = VK_NULL_HANDLE;
+ }
+ if (b->memory) {
+ vkFreeMemory(ctx->device, b->memory, NULL);
+ b->memory = VK_NULL_HANDLE;
+ }
+ b->size = 0;
+}
+
+static int br_ensure_buffer_capacity(VkCtx *ctx, VkBufferM *b, VkDeviceSize required, VkBufferUsageFlags usage,
+ VkMemoryPropertyFlags mem_props) {
+ if (b->size >= required)
+ return 0;
+ br_destroy_buffer(ctx, b);
+ VkDeviceSize cap = 256;
+ while (cap < required)
+ cap <<= 1;
+ return br_create_buffer(ctx, cap, usage, mem_props, 1, b);
+}
+
+/* ------------------------------ image helpers ------------------------------ */
+
+static int br_create_image(VkCtx *ctx, uint32_t w, uint32_t h, VkFormat format, VkImageUsageFlags usage,
+ VkImageM *out) {
+ memset(out, 0, sizeof(*out));
+ out->width = w;
+ out->height = h;
+ out->format = format;
+
+ VkImageCreateInfo ici = {
+ .sType = VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO,
+ .imageType = VK_IMAGE_TYPE_2D,
+ .format = format,
+ .extent = {w, h, 1},
+ .mipLevels = 1,
+ .arrayLayers = 1,
+ .samples = VK_SAMPLE_COUNT_1_BIT,
+ .tiling = VK_IMAGE_TILING_OPTIMAL,
+ .usage = usage,
+ .sharingMode = VK_SHARING_MODE_EXCLUSIVE,
+ .initialLayout = VK_IMAGE_LAYOUT_UNDEFINED,
+ };
+ VK_CHECK(vkCreateImage(ctx->device, &ici, NULL, &out->image));
+
+ VkMemoryRequirements req;
+ vkGetImageMemoryRequirements(ctx->device, out->image, &req);
+
+ uint32_t mem_idx = vk_find_memory_type(ctx, req.memoryTypeBits, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);
+ if (mem_idx == UINT32_MAX) {
+ vk_ctx_set_error(ctx, "no DEVICE_LOCAL memory type for atlas image");
+ return -1;
+ }
+
+ VkMemoryAllocateInfo mai = {
+ .sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO,
+ .allocationSize = req.size,
+ .memoryTypeIndex = mem_idx,
+ };
+ VK_CHECK(vkAllocateMemory(ctx->device, &mai, NULL, &out->memory));
+ VK_CHECK(vkBindImageMemory(ctx->device, out->image, out->memory, 0));
+
+ VkImageViewCreateInfo vci = {
+ .sType = VK_STRUCTURE_TYPE_IMAGE_VIEW_CREATE_INFO,
+ .image = out->image,
+ .viewType = VK_IMAGE_VIEW_TYPE_2D,
+ .format = format,
+ .subresourceRange =
+ {
+ .aspectMask = VK_IMAGE_ASPECT_COLOR_BIT,
+ .baseMipLevel = 0,
+ .levelCount = 1,
+ .baseArrayLayer = 0,
+ .layerCount = 1,
+ },
+ };
+ VK_CHECK(vkCreateImageView(ctx->device, &vci, NULL, &out->view));
+ return 0;
+}
+
+static void br_destroy_image(VkCtx *ctx, VkImageM *im) {
+ if (!im || !ctx)
+ return;
+ if (im->view) {
+ vkDestroyImageView(ctx->device, im->view, NULL);
+ im->view = VK_NULL_HANDLE;
+ }
+ if (im->image) {
+ vkDestroyImage(ctx->device, im->image, NULL);
+ im->image = VK_NULL_HANDLE;
+ }
+ if (im->memory) {
+ vkFreeMemory(ctx->device, im->memory, NULL);
+ im->memory = VK_NULL_HANDLE;
+ }
+}
+
+/* --------------------------- color lookup (shared) -------------------------- */
+
+static void color_for_road_type(uint32_t type, float out[4]) {
+ out[3] = 1.0f;
+ switch (type) {
+ case 6:
+ out[0] = 0.55f;
+ out[1] = 0.55f;
+ out[2] = 0.55f;
+ break;
+ case 4:
+ out[0] = 0.85f;
+ out[1] = 0.78f;
+ out[2] = 0.30f;
+ out[3] = 0.6f;
+ break;
+ case 5:
+ out[0] = 0.95f;
+ out[1] = 0.95f;
+ out[2] = 0.95f;
+ out[3] = 0.5f;
+ break;
+ case 10:
+ out[0] = 0.40f;
+ out[1] = 0.40f;
+ out[2] = 0.55f;
+ out[3] = 0.7f;
+ break;
+ default:
+ out[0] = 0.45f;
+ out[1] = 0.45f;
+ out[2] = 0.45f;
+ break;
+ }
+}
+
+/* --------------------------- init / destroy --------------------------- */
+
+static int upload_static_quad(VkCtx *ctx, BatchRenderer *br) {
+ const float quad[8] = {-1, -1, 1, -1, 1, 1, -1, 1};
+ const uint16_t idx[6] = {0, 1, 2, 0, 2, 3};
+ int rc;
+
+ rc = br_create_buffer(ctx, sizeof(quad), VK_BUFFER_USAGE_VERTEX_BUFFER_BIT,
+ VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT, 1,
+ &br->unit_quad_vb);
+ if (rc != 0)
+ return rc;
+ memcpy(br->unit_quad_vb.mapped, quad, sizeof(quad));
+
+ rc = br_create_buffer(ctx, sizeof(idx), VK_BUFFER_USAGE_INDEX_BUFFER_BIT,
+ VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT, 1,
+ &br->unit_quad_ib);
+ if (rc != 0)
+ return rc;
+ memcpy(br->unit_quad_ib.mapped, idx, sizeof(idx));
+ return 0;
+}
+
+int vk_batch_renderer_init(VkCtx *ctx, Pipelines *p, BatchRenderer *br, int batch_n, uint32_t tile_w, uint32_t tile_h) {
+ if (batch_n <= 0 || tile_w == 0 || tile_h == 0) {
+ vk_ctx_set_error(ctx, "vk_batch_renderer_init: invalid args (n=%d w=%u h=%u)", batch_n, tile_w, tile_h);
+ return -1;
+ }
+
+ memset(br, 0, sizeof(*br));
+ br->pipelines = p;
+ br->batch_n = batch_n;
+ br->tile_w = tile_w;
+ br->tile_h = tile_h;
+
+ br->slots = (BatchSlot *)calloc((size_t)batch_n, sizeof(BatchSlot));
+ if (!br->slots) {
+ vk_ctx_set_error(ctx, "out of memory allocating %d batch slots", batch_n);
+ return -1;
+ }
+
+ int rc;
+ if ((rc = upload_static_quad(ctx, br)) != 0)
+ goto fail;
+
+ /* Atlas dimensions: tile_w wide, batch_n * tile_h tall (vertical
+ * stacking → contiguous tile bytes in the readback buffer). Vulkan's
+ * maxImageDimension2D is at least 4096 on every spec-compliant device
+ * and typically 16384+ on real GPUs, so batch_n up to ~22 at 720p
+ * height is safe before we need to fall back to multiple passes. */
+ uint32_t atlas_h = (uint32_t)batch_n * tile_h;
+
+ if ((rc = br_create_image(ctx, tile_w, atlas_h, VK_FORMAT_R8G8B8A8_UNORM,
+ VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT,
+ &br->atlas_topdown)) != 0)
+ goto fail;
+ if ((rc = br_create_image(ctx, tile_w, atlas_h, VK_FORMAT_R8G8B8A8_UNORM,
+ VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT, &br->atlas_bev)) !=
+ 0)
+ goto fail;
+
+ /* Readback buffers: prefer HOST_CACHED so the CPU can read them at
+ * full RAM bandwidth. The default HOST_VISIBLE+HOST_COHERENT path
+ * on NVIDIA picks a write-combined memory type — fast for GPU
+ * writes but ~250 MB/s for CPU reads (uncached PCIe BAR), which is
+ * by far the dominant cost when piping frames to ffmpeg. With
+ * HOST_CACHED, reads hit RAM at >5 GB/s. */
+ VkDeviceSize readback_size = (VkDeviceSize)tile_w * (VkDeviceSize)atlas_h * 4;
+ VkMemoryPropertyFlags readback_props =
+ VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT | VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
+ if ((rc = br_create_buffer(ctx, readback_size, VK_BUFFER_USAGE_TRANSFER_DST_BIT, readback_props, 1,
+ &br->readback_topdown)) != 0) {
+ /* Fall back without HOST_CACHED if the device doesn't expose it. */
+ readback_props = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
+ if ((rc = br_create_buffer(ctx, readback_size, VK_BUFFER_USAGE_TRANSFER_DST_BIT, readback_props, 1,
+ &br->readback_topdown)) != 0)
+ goto fail;
+ }
+ if ((rc = br_create_buffer(ctx, readback_size, VK_BUFFER_USAGE_TRANSFER_DST_BIT, readback_props, 1,
+ &br->readback_bev)) != 0)
+ goto fail;
+
+ VkCommandBufferAllocateInfo cai = {
+ .sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO,
+ .commandPool = ctx->command_pool,
+ .level = VK_COMMAND_BUFFER_LEVEL_PRIMARY,
+ .commandBufferCount = 1,
+ };
+ VK_CHECK(vkAllocateCommandBuffers(ctx->device, &cai, &br->cmd));
+
+ VkFenceCreateInfo fci = {
+ .sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO,
+ .flags = 0,
+ };
+ VK_CHECK(vkCreateFence(ctx->device, &fci, NULL, &br->fence));
+
+ return 0;
+
+fail:
+ vk_batch_renderer_destroy(ctx, br);
+ return rc;
+}
+
+void vk_batch_renderer_destroy(VkCtx *ctx, BatchRenderer *br) {
+ if (!br || !ctx)
+ return;
+ if (br->fence) {
+ vkDestroyFence(ctx->device, br->fence, NULL);
+ br->fence = VK_NULL_HANDLE;
+ }
+ if (br->cmd) {
+ vkFreeCommandBuffers(ctx->device, ctx->command_pool, 1, &br->cmd);
+ br->cmd = VK_NULL_HANDLE;
+ }
+ br_destroy_image(ctx, &br->atlas_topdown);
+ br_destroy_image(ctx, &br->atlas_bev);
+ br_destroy_buffer(ctx, &br->readback_topdown);
+ br_destroy_buffer(ctx, &br->readback_bev);
+ br_destroy_buffer(ctx, &br->unit_quad_vb);
+ br_destroy_buffer(ctx, &br->unit_quad_ib);
+ if (br->slots) {
+ for (int i = 0; i < br->batch_n; ++i) {
+ br_destroy_buffer(ctx, &br->slots[i].road_vb);
+ br_destroy_buffer(ctx, &br->slots[i].agent_inst_vb);
+ free(br->slots[i].road_offsets);
+ free(br->slots[i].road_types);
+ }
+ free(br->slots);
+ br->slots = NULL;
+ }
+ br->batch_n = 0;
+}
+
+/* --------------------------- per-slot configuration --------------------------- */
+
+int vk_batch_renderer_set_episode(VkCtx *ctx, BatchRenderer *br, int slot, const float *road_xy, uint32_t num_verts,
+ const uint32_t *road_offsets, const uint32_t *road_types, uint32_t num_polys,
+ FfmpegPipe *pipe_topdown, FfmpegPipe *pipe_bev) {
+ if (slot < 0 || slot >= br->batch_n) {
+ vk_ctx_set_error(ctx, "set_episode: slot %d out of range [0, %d)", slot, br->batch_n);
+ return -1;
+ }
+ BatchSlot *s = &br->slots[slot];
+
+ /* Resize road vb if needed and upload. */
+ VkDeviceSize required = (VkDeviceSize)num_verts * sizeof(float) * 2;
+ if (required == 0)
+ required = sizeof(float) * 2;
+ int rc = br_ensure_buffer_capacity(ctx, &s->road_vb, required, VK_BUFFER_USAGE_VERTEX_BUFFER_BIT,
+ VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
+ if (rc != 0)
+ return rc;
+ if (num_verts > 0) {
+ memcpy(s->road_vb.mapped, road_xy, (size_t)num_verts * sizeof(float) * 2);
+ }
+ s->road_vb_capacity = num_verts;
+
+ /* Host-side metadata copy. */
+ if (num_polys + 1 > s->road_meta_capacity) {
+ free(s->road_offsets);
+ free(s->road_types);
+ s->road_meta_capacity = num_polys + 1;
+ s->road_offsets = (uint32_t *)malloc(sizeof(uint32_t) * (num_polys + 1));
+ s->road_types = (uint32_t *)malloc(sizeof(uint32_t) * num_polys);
+ if (!s->road_offsets || !s->road_types) {
+ vk_ctx_set_error(ctx, "out of host memory for road metadata in slot %d", slot);
+ return -1;
+ }
+ }
+ if (num_polys > 0) {
+ memcpy(s->road_offsets, road_offsets, sizeof(uint32_t) * (num_polys + 1));
+ memcpy(s->road_types, road_types, sizeof(uint32_t) * num_polys);
+ }
+ s->num_polys = num_polys;
+
+ s->pipe_topdown = pipe_topdown;
+ s->pipe_bev = pipe_bev;
+ s->active = 1;
+ s->current_n_instances = 0;
+ s->has_topdown_this_frame = 0;
+ s->has_bev_this_frame = 0;
+ return 0;
+}
+
+int vk_batch_renderer_set_frame(VkCtx *ctx, BatchRenderer *br, int slot, const AgentInstance *instances,
+ uint32_t num_instances, const Mat4 *mvp_topdown, const Mat4 *mvp_bev) {
+ if (slot < 0 || slot >= br->batch_n) {
+ vk_ctx_set_error(ctx, "set_frame: slot %d out of range", slot);
+ return -1;
+ }
+ BatchSlot *s = &br->slots[slot];
+ if (!s->active)
+ return 0;
+
+ if (num_instances > 0) {
+ VkDeviceSize required = (VkDeviceSize)num_instances * sizeof(AgentInstance);
+ int rc = br_ensure_buffer_capacity(ctx, &s->agent_inst_vb, required, VK_BUFFER_USAGE_VERTEX_BUFFER_BIT,
+ VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
+ if (rc != 0)
+ return rc;
+ memcpy(s->agent_inst_vb.mapped, instances, (size_t)num_instances * sizeof(AgentInstance));
+ s->agent_inst_capacity = num_instances;
+ }
+ s->current_n_instances = num_instances;
+
+ s->has_topdown_this_frame = (mvp_topdown != NULL) && (s->pipe_topdown != NULL);
+ s->has_bev_this_frame = (mvp_bev != NULL) && (s->pipe_bev != NULL);
+ if (s->has_topdown_this_frame)
+ s->mvp_topdown = *mvp_topdown;
+ if (s->has_bev_this_frame)
+ s->mvp_bev = *mvp_bev;
+ return 0;
+}
+
+void vk_batch_renderer_close_episode(BatchRenderer *br, int slot) {
+ if (slot < 0 || slot >= br->batch_n)
+ return;
+ BatchSlot *s = &br->slots[slot];
+ s->active = 0;
+ s->pipe_topdown = NULL;
+ s->pipe_bev = NULL;
+ s->has_topdown_this_frame = 0;
+ s->has_bev_this_frame = 0;
+}
+
+/* --------------------------- per-frame submit --------------------------- */
+
+/* Record a single-image layout transition using synchronization2.
+ * Always covers the full color aspect, mip level 0, array layer 0 (all
+ * trajviz images are single-mip, single-layer color attachments). Queue
+ * family ownership is never transferred (both indices IGNORED). */
+static void barrier_image(VkCommandBuffer cmd, VkImage image, VkImageLayout old_layout, VkImageLayout new_layout,
+                          VkPipelineStageFlags2 src_stage, VkAccessFlags2 src_access, VkPipelineStageFlags2 dst_stage,
+                          VkAccessFlags2 dst_access) {
+    VkImageMemoryBarrier2 imb = {
+        .sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER_2,
+        .srcStageMask = src_stage,
+        .srcAccessMask = src_access,
+        .dstStageMask = dst_stage,
+        .dstAccessMask = dst_access,
+        .oldLayout = old_layout,
+        .newLayout = new_layout,
+        .srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED,
+        .dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED,
+        .image = image,
+        .subresourceRange =
+            {
+                .aspectMask = VK_IMAGE_ASPECT_COLOR_BIT,
+                .baseMipLevel = 0,
+                .levelCount = 1,
+                .baseArrayLayer = 0,
+                .layerCount = 1,
+            },
+    };
+    VkDependencyInfo di = {
+        .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO,
+        .imageMemoryBarrierCount = 1,
+        .pImageMemoryBarriers = &imb,
+    };
+    vkCmdPipelineBarrier2(cmd, &di);
+}
+
+/* Record one full atlas pass: render every active slot's tile, then
+ * copy the atlas image to its host-visible readback buffer.
+ *
+ * The atlas transitions UNDEFINED -> COLOR_ATTACHMENT (previous contents
+ * are discarded; the fence wait in submit_frame guarantees the prior
+ * frame's copy has completed), each active slot draws into its own
+ * viewport/scissor tile, then the atlas transitions to TRANSFER_SRC and
+ * is copied into readback_buffer with one vkCmdCopyImageToBuffer.
+ *
+ * Fix: the final copy call referenced `®ion` — a mojibake of `&region`
+ * (the `&reg` prefix was swallowed as the HTML entity for ®) — which is
+ * not valid C. Restored to `&region`. */
+static void record_atlas_pass(VkCommandBuffer cmd, BatchRenderer *br, VkImage atlas_image, VkImageView atlas_view,
+                              VkBuffer readback_buffer, int is_bev_view) {
+    barrier_image(cmd, atlas_image, VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_COLOR_ATTACHMENT_OPTIMAL,
+                  VK_PIPELINE_STAGE_2_TOP_OF_PIPE_BIT, 0, VK_PIPELINE_STAGE_2_COLOR_ATTACHMENT_OUTPUT_BIT,
+                  VK_ACCESS_2_COLOR_ATTACHMENT_WRITE_BIT);
+
+    /* Dark background; inactive slots' tiles stay this color. */
+    VkClearValue clear = {.color = {.float32 = {0.05f, 0.05f, 0.08f, 1.0f}}};
+    VkRenderingAttachmentInfo att = {
+        .sType = VK_STRUCTURE_TYPE_RENDERING_ATTACHMENT_INFO,
+        .imageView = atlas_view,
+        .imageLayout = VK_IMAGE_LAYOUT_COLOR_ATTACHMENT_OPTIMAL,
+        .loadOp = VK_ATTACHMENT_LOAD_OP_CLEAR,
+        .storeOp = VK_ATTACHMENT_STORE_OP_STORE,
+        .clearValue = clear,
+    };
+    uint32_t atlas_h = (uint32_t)br->batch_n * br->tile_h;
+    VkRenderingInfo ri = {
+        .sType = VK_STRUCTURE_TYPE_RENDERING_INFO,
+        .renderArea = {.offset = {0, 0}, .extent = {br->tile_w, atlas_h}},
+        .layerCount = 1,
+        .colorAttachmentCount = 1,
+        .pColorAttachments = &att,
+    };
+    vkCmdBeginRendering(cmd, &ri);
+
+    /* Render each active slot into its tile. */
+    for (int i = 0; i < br->batch_n; ++i) {
+        BatchSlot *s = &br->slots[i];
+        if (!s->active)
+            continue;
+
+        int has_view = is_bev_view ? s->has_bev_this_frame : s->has_topdown_this_frame;
+        if (!has_view)
+            continue;
+
+        /* Tile rect: full width, slice [i*tile_h, (i+1)*tile_h) vertically. */
+        VkViewport vp = {
+            .x = 0.0f,
+            .y = (float)(i * br->tile_h),
+            .width = (float)br->tile_w,
+            .height = (float)br->tile_h,
+            .minDepth = 0.0f,
+            .maxDepth = 1.0f,
+        };
+        VkRect2D sc = {
+            .offset = {0, (int32_t)(i * br->tile_h)},
+            .extent = {br->tile_w, br->tile_h},
+        };
+        vkCmdSetViewport(cmd, 0, 1, &vp);
+        vkCmdSetScissor(cmd, 0, 1, &sc);
+
+        const Mat4 *mvp = is_bev_view ? &s->mvp_bev : &s->mvp_topdown;
+
+        /* Road polylines for this slot: one line-strip draw per polyline,
+         * colored by road type via push constants. */
+        if (s->num_polys > 0) {
+            vkCmdBindPipeline(cmd, VK_PIPELINE_BIND_POINT_GRAPHICS, br->pipelines->line_pipeline);
+            VkDeviceSize voff = 0;
+            vkCmdBindVertexBuffers(cmd, 0, 1, &s->road_vb.buffer, &voff);
+
+            PushConstants pc;
+            memcpy(pc.mvp, mvp->m, sizeof(pc.mvp));
+
+            for (uint32_t j = 0; j < s->num_polys; ++j) {
+                uint32_t start = s->road_offsets[j];
+                uint32_t end = s->road_offsets[j + 1];
+                /* A strip needs at least two vertices to emit a segment. */
+                if (end <= start + 1)
+                    continue;
+
+                color_for_road_type(s->road_types[j], pc.color);
+                vkCmdPushConstants(cmd, br->pipelines->layout,
+                                   VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT, 0, sizeof(pc), &pc);
+                vkCmdDraw(cmd, end - start, 1, start, 0);
+            }
+        }
+
+        /* Agent boxes for this slot: one instanced draw of the shared
+         * unit quad; per-agent transform/tint come from the instance
+         * buffer, so push-constant color is white. */
+        if (s->current_n_instances > 0) {
+            vkCmdBindPipeline(cmd, VK_PIPELINE_BIND_POINT_GRAPHICS, br->pipelines->box_pipeline);
+            VkBuffer vbufs[2] = {br->unit_quad_vb.buffer, s->agent_inst_vb.buffer};
+            VkDeviceSize voffs[2] = {0, 0};
+            vkCmdBindVertexBuffers(cmd, 0, 2, vbufs, voffs);
+            vkCmdBindIndexBuffer(cmd, br->unit_quad_ib.buffer, 0, VK_INDEX_TYPE_UINT16);
+
+            PushConstants pc;
+            memcpy(pc.mvp, mvp->m, sizeof(pc.mvp));
+            pc.color[0] = pc.color[1] = pc.color[2] = pc.color[3] = 1.0f;
+            vkCmdPushConstants(cmd, br->pipelines->layout, VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT, 0,
+                               sizeof(pc), &pc);
+
+            vkCmdDrawIndexed(cmd, 6, s->current_n_instances, 0, 0, 0);
+        }
+    }
+
+    vkCmdEndRendering(cmd);
+
+    /* Atlas → TRANSFER_SRC, copy to readback. */
+    barrier_image(cmd, atlas_image, VK_IMAGE_LAYOUT_COLOR_ATTACHMENT_OPTIMAL, VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL,
+                  VK_PIPELINE_STAGE_2_COLOR_ATTACHMENT_OUTPUT_BIT, VK_ACCESS_2_COLOR_ATTACHMENT_WRITE_BIT,
+                  VK_PIPELINE_STAGE_2_COPY_BIT, VK_ACCESS_2_TRANSFER_READ_BIT);
+
+    /* bufferRowLength/ImageHeight = 0 means tightly packed — rows land
+     * back-to-back in the readback buffer, which is what keeps each tile
+     * a single contiguous slab. */
+    VkBufferImageCopy region = {
+        .bufferOffset = 0,
+        .bufferRowLength = 0,
+        .bufferImageHeight = 0,
+        .imageSubresource =
+            {
+                .aspectMask = VK_IMAGE_ASPECT_COLOR_BIT,
+                .mipLevel = 0,
+                .baseArrayLayer = 0,
+                .layerCount = 1,
+            },
+        .imageOffset = {0, 0, 0},
+        .imageExtent = {br->tile_w, atlas_h, 1},
+    };
+    vkCmdCopyImageToBuffer(cmd, atlas_image, VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL, readback_buffer, 1, &region);
+}
+
+/* Record, submit, and drain one batched frame.
+ *
+ * Records an atlas pass for each view that at least one active slot
+ * requested (via set_frame), submits once, blocks on the fence, then
+ * fans each slot's tile out to its ffmpeg pipes. Per-frame view flags
+ * are cleared on exit, so set_frame must be called again before the
+ * next submit. Returns 0 on success, a VkResult on a Vulkan failure
+ * (via VK_CHECK), or -1 if any ffmpeg write failed.
+ *
+ * NOTE(review): a failing VK_CHECK returns immediately, leaving the
+ * fence in whatever state the failing call produced — fine only if
+ * callers treat any non-zero return as fatal; confirm. */
+int vk_batch_renderer_submit_frame(VkCtx *ctx, BatchRenderer *br) {
+    /* Quick exit: nothing to render this frame. */
+    int any_topdown = 0, any_bev = 0;
+    for (int i = 0; i < br->batch_n; ++i) {
+        if (!br->slots[i].active)
+            continue;
+        if (br->slots[i].has_topdown_this_frame)
+            any_topdown = 1;
+        if (br->slots[i].has_bev_this_frame)
+            any_bev = 1;
+    }
+    if (!any_topdown && !any_bev)
+        return 0;
+
+    /* Safe to reset: the previous submit's fence wait below guarantees
+     * the buffer is no longer pending. */
+    VK_CHECK(vkResetCommandBuffer(br->cmd, 0));
+    VkCommandBufferBeginInfo bi = {
+        .sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO,
+        .flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT,
+    };
+    VK_CHECK(vkBeginCommandBuffer(br->cmd, &bi));
+
+    if (any_topdown) {
+        record_atlas_pass(br->cmd, br, br->atlas_topdown.image, br->atlas_topdown.view, br->readback_topdown.buffer,
+                          /*is_bev_view=*/0);
+    }
+    if (any_bev) {
+        record_atlas_pass(br->cmd, br, br->atlas_bev.image, br->atlas_bev.view, br->readback_bev.buffer,
+                          /*is_bev_view=*/1);
+    }
+
+    /* Make the copy results visible to host reads after the fence wait. */
+    VkMemoryBarrier2 mb = {
+        .sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER_2,
+        .srcStageMask = VK_PIPELINE_STAGE_2_COPY_BIT,
+        .srcAccessMask = VK_ACCESS_2_TRANSFER_WRITE_BIT,
+        .dstStageMask = VK_PIPELINE_STAGE_2_HOST_BIT,
+        .dstAccessMask = VK_ACCESS_2_HOST_READ_BIT,
+    };
+    VkDependencyInfo di = {
+        .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO,
+        .memoryBarrierCount = 1,
+        .pMemoryBarriers = &mb,
+    };
+    vkCmdPipelineBarrier2(br->cmd, &di);
+
+    VK_CHECK(vkEndCommandBuffer(br->cmd));
+
+    VkCommandBufferSubmitInfo csi = {
+        .sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_SUBMIT_INFO,
+        .commandBuffer = br->cmd,
+    };
+    VkSubmitInfo2 si = {
+        .sType = VK_STRUCTURE_TYPE_SUBMIT_INFO_2,
+        .commandBufferInfoCount = 1,
+        .pCommandBufferInfos = &csi,
+    };
+    VK_CHECK(vkQueueSubmit2(ctx->graphics_queue, 1, &si, br->fence));
+    VK_CHECK(vkWaitForFences(ctx->device, 1, &br->fence, VK_TRUE, UINT64_MAX));
+    VK_CHECK(vkResetFences(ctx->device, 1, &br->fence));
+
+    /* Fan out each slot's tile to its ffmpeg pipes' writer threads in
+     * parallel, then wait for all of them. Each pipe has its own
+     * background thread (see ffmpeg_pipe.c), so the wall time of this
+     * phase is max(single fwrite) instead of sum(fwrites) — which is
+     * a ~Nx win for batch_size N when the per-slot write is the
+     * dominant cost.
+     *
+     * Tile bytes are row-contiguous in the readback buffer thanks to
+     * vertical stacking, so each tile is one (tile_w * tile_h * 4)-
+     * byte slab at offset (i * tile_bytes). */
+    size_t tile_bytes = (size_t)br->tile_w * (size_t)br->tile_h * 4;
+
+    /* Phase 1: submit all writes (returns immediately for each pipe). */
+    for (int i = 0; i < br->batch_n; ++i) {
+        BatchSlot *s = &br->slots[i];
+        if (!s->active)
+            continue;
+
+        if (s->has_topdown_this_frame && s->pipe_topdown) {
+            const uint8_t *p = (const uint8_t *)br->readback_topdown.mapped + (size_t)i * tile_bytes;
+            ffmpeg_pipe_submit_frame(s->pipe_topdown, p);
+        }
+        if (s->has_bev_this_frame && s->pipe_bev) {
+            const uint8_t *p = (const uint8_t *)br->readback_bev.mapped + (size_t)i * tile_bytes;
+            ffmpeg_pipe_submit_frame(s->pipe_bev, p);
+        }
+    }
+
+    /* Phase 2: wait for all writes to complete. The readback buffer is
+     * about to be reused for the next frame's render so we cannot
+     * proceed until every writer has consumed its tile. */
+    int err = 0;
+    for (int i = 0; i < br->batch_n; ++i) {
+        BatchSlot *s = &br->slots[i];
+        if (!s->active)
+            continue;
+
+        if (s->has_topdown_this_frame && s->pipe_topdown) {
+            if (ffmpeg_pipe_wait(s->pipe_topdown) != 0) {
+                vk_ctx_set_error(ctx, "ffmpeg topdown write failed at slot %d", i);
+                err = -1;
+            }
+        }
+        if (s->has_bev_this_frame && s->pipe_bev) {
+            if (ffmpeg_pipe_wait(s->pipe_bev) != 0) {
+                vk_ctx_set_error(ctx, "ffmpeg bev write failed at slot %d", i);
+                err = -1;
+            }
+        }
+        /* Consume the per-frame flags; set_frame must re-arm them. */
+        s->has_topdown_this_frame = 0;
+        s->has_bev_this_frame = 0;
+    }
+
+    return err;
+}
diff --git a/pufferlib/ocean/drive/trajviz/vk_batch_renderer.h b/pufferlib/ocean/drive/trajviz/vk_batch_renderer.h
new file mode 100644
index 0000000000..99d471666a
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/vk_batch_renderer.h
@@ -0,0 +1,143 @@
+/*
+ * vk_batch_renderer.h — multi-episode batched renderer.
+ *
+ * Renders N episodes in lockstep into a single tiled atlas per view, so
+ * one queue submit and one fence wait per frame covers N episodes'
+ * worth of work. The per-frame Vulkan + ffmpeg overhead that dominated
+ * the single-episode path gets amortized across the batch.
+ *
+ * Atlas layout: tiles are stacked **vertically** in a single 2D image
+ * sized (tile_w, batch_size * tile_h). Vertical stacking means each
+ * tile's pixel rows are contiguous in memory — slot i's MP4 frame is
+ * a single (tile_w * tile_h * 4)-byte block at offset
+ * (i * tile_w * tile_h * 4) in the readback buffer, so we can fwrite
+ * each tile to its ffmpeg pipe in one syscall with no row stitching.
+ *
+ * Per-frame command buffer:
+ * 1. Top-down atlas: barrier UNDEFINED → COLOR_ATTACHMENT
+ * 2. vkCmdBeginRendering on the full atlas
+ * 3. For each active slot: set viewport+scissor to its tile rect,
+ * push its top-down camera matrix, draw its road polylines + agent
+ * boxes
+ * 4. vkCmdEndRendering
+ * 5. Barrier → TRANSFER_SRC, vkCmdCopyImageToBuffer → host readback
+ *   6. Same five steps for the BEV atlas
+ * 7. Memory barrier → HOST
+ * 8. End cmd buffer, submit, wait
+ * 9. fwrite each slot's tile to its ffmpeg pipe
+ *
+ * The orchestrator (trajviz_render_episodes_batch) handles assembling
+ * per-frame instance arrays + camera matrices from the per-episode
+ * input data, opening/closing ffmpeg pipes, and calling the lifecycle
+ * functions below.
+ *
+ * Lifecycle:
+ *   vk_batch_renderer_init(ctx, p, &br, batch_n, tile_w, tile_h);
+ * for batch in batches_of_episodes:
+ * for slot in active_slots:
+ * vk_batch_renderer_set_episode(br, slot, roads..., pipes...);
+ * for frame in 0..max_episode_length:
+ * for slot in active_slots:
+ * vk_batch_renderer_set_frame(br, slot, instances, n_inst,
+ * mvp_topdown, mvp_bev);
+ * vk_batch_renderer_submit_frame(ctx, br);
+ * for slot in active_slots:
+ * vk_batch_renderer_close_episode(br, slot);
+ * vk_batch_renderer_destroy(ctx, br);
+ */
+
+#ifndef VK_BATCH_RENDERER_H
+#define VK_BATCH_RENDERER_H
+
+#include "vk_context.h"
+#include "vk_pipeline.h"
+#include "vk_math.h"
+#include "vk_renderer.h" /* for VkBufferM, VkImageM, AgentInstance */
+#include "ffmpeg_pipe.h"
+
+#include <stdint.h>
+
+/* Per-episode state held by one slot in the batch.
+ *
+ * Device buffers are retained when a slot is closed (close_episode only
+ * clears `active`, the view flags, and the pipe pointers), so a later
+ * set_episode on the same slot can reuse them without reallocating. */
+typedef struct BatchSlot {
+    int active; /* 1 if this slot is currently rendering an episode */
+
+    /* Per-episode static geometry (set by set_episode) */
+    VkBufferM road_vb;
+    uint32_t road_vb_capacity;
+    uint32_t *road_offsets; /* (num_polys+1,) host copy; poly j spans [offsets[j], offsets[j+1]) */
+    uint32_t *road_types;   /* (num_polys,) */
+    uint32_t num_polys;
+    uint32_t road_meta_capacity;
+
+    /* Per-frame agent instance buffer (resized as needed). */
+    VkBufferM agent_inst_vb;
+    uint32_t agent_inst_capacity;
+    uint32_t current_n_instances;
+
+    /* Per-frame camera matrices (set by set_frame). NULL pointer = skip. */
+    Mat4 mvp_topdown;
+    Mat4 mvp_bev;
+    int has_topdown_this_frame;
+    int has_bev_this_frame;
+
+    /* Per-episode ffmpeg pipes (borrowed from the orchestrator). */
+    FfmpegPipe *pipe_topdown; /* may be NULL */
+    FfmpegPipe *pipe_bev;     /* may be NULL */
+} BatchSlot;
+
+/* Top-level batched-renderer state: drives batch_n episode slots in
+ * lockstep against two tiled atlases (one per view). */
+typedef struct BatchRenderer {
+    Pipelines *pipelines;    /* borrowed — not destroyed by this module */
+    int batch_n;             /* number of slots */
+    uint32_t tile_w, tile_h; /* per-tile pixel dimensions */
+
+    BatchSlot *slots; /* batch_n entries */
+
+    /* Tiled atlases — one per view. Width = tile_w, height = batch_n * tile_h.
+     * Slot i occupies y in [i * tile_h, (i+1) * tile_h). */
+    VkImageM atlas_topdown;
+    VkImageM atlas_bev;
+
+    /* Host-visible readback buffers (persistently mapped). One per atlas.
+     * Size = tile_w * batch_n * tile_h * 4. Each slot's tile starts at
+     * offset i * (tile_w * tile_h * 4) and is contiguous. */
+    VkBufferM readback_topdown;
+    VkBufferM readback_bev;
+
+    /* Static unit-quad geometry shared across all slots. */
+    VkBufferM unit_quad_vb;
+    VkBufferM unit_quad_ib;
+
+    /* One command buffer + one fence — single-frame-in-flight is fine
+     * once we're batching, since each frame already does N episodes of
+     * work in one submit. */
+    VkCommandBuffer cmd;
+    VkFence fence;
+} BatchRenderer;
+
+int vk_batch_renderer_init(VkCtx *ctx, Pipelines *p, BatchRenderer *br, int batch_n, uint32_t tile_w, uint32_t tile_h);
+void vk_batch_renderer_destroy(VkCtx *ctx, BatchRenderer *br);
+
+/* Bind an episode to a slot. Copies road geometry into device memory and
+ * stores ffmpeg pipe pointers (which the orchestrator must keep alive
+ * for the duration of this slot's episode). Either pipe may be NULL. */
+int vk_batch_renderer_set_episode(VkCtx *ctx, BatchRenderer *br, int slot, const float *road_xy, uint32_t num_verts,
+ const uint32_t *road_offsets, const uint32_t *road_types, uint32_t num_polys,
+ FfmpegPipe *pipe_topdown, FfmpegPipe *pipe_bev);
+
+/* Update per-frame state for one slot: agent instance array + camera
+ * matrices. Either MVP pointer may be NULL to skip that view this frame
+ * (e.g. the slot's episode has terminated and shouldn't draw anything
+ * new). Must be called for every active slot before submit_frame. */
+int vk_batch_renderer_set_frame(VkCtx *ctx, BatchRenderer *br, int slot, const AgentInstance *instances,
+ uint32_t num_instances, const Mat4 *mvp_topdown, const Mat4 *mvp_bev);
+
+/* Submit one batched frame: records all slots' draws into one command
+ * buffer, submits, waits for the fence, and fwrites each tile to its
+ * slot's ffmpeg pipes. */
+int vk_batch_renderer_submit_frame(VkCtx *ctx, BatchRenderer *br);
+
+/* Mark a slot as inactive (the episode finished). Does not free its
+ * device buffers — they get reused if a future set_episode lands here. */
+void vk_batch_renderer_close_episode(BatchRenderer *br, int slot);
+
+#endif
diff --git a/pufferlib/ocean/drive/trajviz/vk_context.c b/pufferlib/ocean/drive/trajviz/vk_context.c
new file mode 100644
index 0000000000..519f50ba77
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/vk_context.c
@@ -0,0 +1,274 @@
+/*
+ * vk_context.c — Vulkan instance, device, and queue setup for trajviz.
+ *
+ * The init order is:
+ * 1. vkCreateInstance (request 1.3, optional debug utils)
+ * 2. enumerate physical devices, prefer discrete GPU
+ * 3. find a graphics queue family
+ * 4. vkCreateDevice with dynamic_rendering + synchronization2 enabled
+ * (both core in 1.3, but the feature struct still has to be in the
+ * create-info chain)
+ * 5. vkGetDeviceQueue
+ *   6. vkCreateCommandPool with RESET_COMMAND_BUFFER (per-frame command
+ *      buffers are short-lived and reset individually with
+ *      vkResetCommandBuffer before each re-record)
+ *
+ * Cleanup is reverse order. vk_ctx_destroy can be called on a partially-
+ * initialized context (init failure path) — every handle is checked for
+ * VK_NULL_HANDLE before destruction.
+ */
+
+#include "vk_context.h"
+
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+/* printf-style formatter for ctx->last_error. A NULL ctx is tolerated so
+ * error paths never have to guard before reporting. vsnprintf truncates
+ * silently at TRAJVIZ_ERROR_BUF bytes (always NUL-terminated). */
+void vk_ctx_set_error(VkCtx *ctx, const char *fmt, ...) {
+    if (!ctx)
+        return;
+    va_list ap;
+    va_start(ap, fmt);
+    vsnprintf(ctx->last_error, TRAJVIZ_ERROR_BUF, fmt, ap);
+    va_end(ap);
+}
+
+/* Scan the cached memory-type table for an index that is (a) allowed by
+ * type_bits (from a VkMemoryRequirements query) and (b) carries every
+ * requested property flag. Returns UINT32_MAX when no type qualifies. */
+uint32_t vk_find_memory_type(const VkCtx *ctx, uint32_t type_bits, VkMemoryPropertyFlags properties) {
+    uint32_t type_count = ctx->mem_props.memoryTypeCount;
+    for (uint32_t idx = 0; idx < type_count; ++idx) {
+        int allowed = (type_bits & (1u << idx)) != 0;
+        VkMemoryPropertyFlags have = ctx->mem_props.memoryTypes[idx].propertyFlags;
+        if (allowed && (have & properties) == properties)
+            return idx;
+    }
+    return UINT32_MAX;
+}
+
+#ifdef TRAJVIZ_DEBUG
+/* Validation-layer callback: prints warnings and errors to stderr.
+ * Returning VK_FALSE tells the loader not to abort the triggering call. */
+static VKAPI_ATTR VkBool32 VKAPI_CALL debug_cb(VkDebugUtilsMessageSeverityFlagBitsEXT severity,
+                                               VkDebugUtilsMessageTypeFlagsEXT type,
+                                               const VkDebugUtilsMessengerCallbackDataEXT *data, void *user_data) {
+    (void)type;
+    (void)user_data;
+    /* Only print warnings and errors — info/verbose are too noisy. */
+    if (severity & (VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT)) {
+        fprintf(stderr, "[vk] %s\n", data->pMessage);
+    }
+    return VK_FALSE;
+}
+#endif
+
+/* Create the VkInstance (API 1.3). In TRAJVIZ_DEBUG builds, the Khronos
+ * validation layer and debug-utils extension are enabled only when the
+ * layer is actually installed; otherwise we degrade gracefully to a
+ * plain instance. Returns 0 on success (via VK_CHECK on failure). */
+static int create_instance(VkCtx *ctx) {
+    VkApplicationInfo app = {
+        .sType = VK_STRUCTURE_TYPE_APPLICATION_INFO,
+        .pApplicationName = "trajviz",
+        .applicationVersion = VK_MAKE_VERSION(0, 1, 0),
+        .pEngineName = "trajviz",
+        .engineVersion = VK_MAKE_VERSION(0, 1, 0),
+        .apiVersion = VK_API_VERSION_1_3,
+    };
+
+    /* Candidate layer/extension names. The counts stay 0 (nothing
+     * enabled) unless the debug path below finds the validation layer. */
+    const char *layers[] = {"VK_LAYER_KHRONOS_validation"};
+    const char *exts[] = {VK_EXT_DEBUG_UTILS_EXTENSION_NAME};
+    uint32_t num_layers = 0;
+    uint32_t num_exts = 0;
+
+#ifdef TRAJVIZ_DEBUG
+    /* Verify the validation layer is actually available; otherwise the
+     * instance creation fails outright instead of degrading gracefully. */
+    uint32_t avail_layer_count = 0;
+    vkEnumerateInstanceLayerProperties(&avail_layer_count, NULL);
+    VkLayerProperties *avail_layers = calloc(avail_layer_count, sizeof(*avail_layers));
+    vkEnumerateInstanceLayerProperties(&avail_layer_count, avail_layers);
+    for (uint32_t i = 0; i < avail_layer_count; ++i) {
+        if (strcmp(avail_layers[i].layerName, layers[0]) == 0) {
+            num_layers = 1;
+            num_exts = 1;
+            ctx->debug_enabled = 1;
+            break;
+        }
+    }
+    free(avail_layers);
+    if (!num_layers) {
+        fprintf(stderr, "[trajviz] validation layer not available, continuing without\n");
+    }
+#endif
+
+    VkInstanceCreateInfo ci = {
+        .sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO,
+        .pApplicationInfo = &app,
+        .enabledLayerCount = num_layers,
+        .ppEnabledLayerNames = layers,
+        .enabledExtensionCount = num_exts,
+        .ppEnabledExtensionNames = exts,
+    };
+    VK_CHECK(vkCreateInstance(&ci, NULL, &ctx->instance));
+
+#ifdef TRAJVIZ_DEBUG
+    if (ctx->debug_enabled) {
+        VkDebugUtilsMessengerCreateInfoEXT dci = {
+            .sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT,
+            .messageSeverity =
+                VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT,
+            .messageType = VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT |
+                           VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT |
+                           VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT,
+            .pfnUserCallback = debug_cb,
+        };
+        /* Extension entry point must be loaded dynamically. Messenger
+         * creation failure is deliberately non-fatal (best-effort). */
+        PFN_vkCreateDebugUtilsMessengerEXT create_msgr =
+            (PFN_vkCreateDebugUtilsMessengerEXT)vkGetInstanceProcAddr(ctx->instance, "vkCreateDebugUtilsMessengerEXT");
+        if (create_msgr) {
+            create_msgr(ctx->instance, &dci, NULL, &ctx->debug_messenger);
+        }
+    }
+#endif
+
+    return 0;
+}
+
+/* Pick a physical device and cache its name + memory properties in ctx.
+ * Returns 0 on success, -1 (with last_error set) if no device exists.
+ * NOTE(review): the calloc result is not checked — an allocation failure
+ * here would crash; worth confirming whether that's acceptable for an
+ * init path. */
+static int pick_physical_device(VkCtx *ctx) {
+    uint32_t count = 0;
+    vkEnumeratePhysicalDevices(ctx->instance, &count, NULL);
+    if (count == 0) {
+        vk_ctx_set_error(ctx, "no Vulkan-capable physical device found");
+        return -1;
+    }
+    VkPhysicalDevice *devs = calloc(count, sizeof(*devs));
+    vkEnumeratePhysicalDevices(ctx->instance, &count, devs);
+
+    /* Prefer the first discrete GPU; fall back to the first device of any
+     * type. This handles the typical "RTX + iGPU" workstation case where
+     * we want the discrete card, and the "headless server" case where the
+     * only device might be lavapipe (CPU rasterizer) or a virtio GPU. */
+    VkPhysicalDevice picked = VK_NULL_HANDLE;
+    for (uint32_t i = 0; i < count; ++i) {
+        VkPhysicalDeviceProperties props;
+        vkGetPhysicalDeviceProperties(devs[i], &props);
+        if (props.deviceType == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) {
+            picked = devs[i];
+            break;
+        }
+    }
+    if (picked == VK_NULL_HANDLE) {
+        picked = devs[0];
+    }
+
+    ctx->physical_device = picked;
+    VkPhysicalDeviceProperties props;
+    vkGetPhysicalDeviceProperties(picked, &props);
+    snprintf(ctx->device_name, sizeof(ctx->device_name), "%s", props.deviceName);
+    /* Cached for vk_find_memory_type. */
+    vkGetPhysicalDeviceMemoryProperties(picked, &ctx->mem_props);
+
+    free(devs);
+    return 0;
+}
+
+/* Locate the first queue family with graphics support and record its
+ * index in ctx->graphics_family. Returns 0 on success, -1 (with
+ * last_error set) if the device exposes no graphics-capable family. */
+static int find_graphics_queue(VkCtx *ctx) {
+    uint32_t fam_count = 0;
+    vkGetPhysicalDeviceQueueFamilyProperties(ctx->physical_device, &fam_count, NULL);
+    VkQueueFamilyProperties *props = calloc(fam_count, sizeof(*props));
+    vkGetPhysicalDeviceQueueFamilyProperties(ctx->physical_device, &fam_count, props);
+
+    uint32_t graphics_idx = UINT32_MAX;
+    for (uint32_t fam = 0; fam < fam_count; ++fam) {
+        if (props[fam].queueFlags & VK_QUEUE_GRAPHICS_BIT) {
+            graphics_idx = fam;
+            break;
+        }
+    }
+    free(props);
+
+    if (graphics_idx == UINT32_MAX) {
+        vk_ctx_set_error(ctx, "no graphics queue family on device %s", ctx->device_name);
+        return -1;
+    }
+    ctx->graphics_family = graphics_idx;
+    return 0;
+}
+
+/* Create the logical device with one graphics queue and the Vulkan 1.3
+ * features this renderer needs. No device extensions are requested:
+ * headless rendering needs no swapchain, and both required features are
+ * core in 1.3. Returns 0 on success (via VK_CHECK on failure). */
+static int create_device(VkCtx *ctx) {
+    float prio = 1.0f;
+    VkDeviceQueueCreateInfo qci = {
+        .sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO,
+        .queueFamilyIndex = ctx->graphics_family,
+        .queueCount = 1,
+        .pQueuePriorities = &prio,
+    };
+
+    /* Vulkan 1.3 features struct chain. We need:
+     *  - dynamicRendering: lets us draw without setting up VkRenderPass /
+     *    VkFramebuffer objects. Cleaner code, no behavioral difference.
+     *  - synchronization2: nicer image-barrier API (single struct, no
+     *    dst-stage-mask juggling). Worth the one extra line.
+     * Both are core in 1.3 but you still have to flip the bits in the
+     * features chain to use them. */
+    VkPhysicalDeviceVulkan13Features f13 = {
+        .sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_3_FEATURES,
+        .dynamicRendering = VK_TRUE,
+        .synchronization2 = VK_TRUE,
+    };
+
+    VkDeviceCreateInfo dci = {
+        .sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO,
+        .pNext = &f13, /* feature chain hangs off pNext */
+        .queueCreateInfoCount = 1,
+        .pQueueCreateInfos = &qci,
+    };
+    VK_CHECK(vkCreateDevice(ctx->physical_device, &dci, NULL, &ctx->device));
+    vkGetDeviceQueue(ctx->device, ctx->graphics_family, 0, &ctx->graphics_queue);
+    return 0;
+}
+
+/* Create the long-lived command pool on the graphics family. The
+ * RESET_COMMAND_BUFFER flag lets individual command buffers be reset
+ * with vkResetCommandBuffer each frame (the batch renderer relies on
+ * this). Returns 0 on success (via VK_CHECK on failure). */
+static int create_command_pool(VkCtx *ctx) {
+    VkCommandPoolCreateInfo pool_info = {0};
+    pool_info.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
+    pool_info.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
+    pool_info.queueFamilyIndex = ctx->graphics_family;
+    VK_CHECK(vkCreateCommandPool(ctx->device, &pool_info, NULL, &ctx->command_pool));
+    return 0;
+}
+
+/* Initialize the full Vulkan context (instance → device → queue → pool).
+ * On any step's failure the partially-built state is torn down through
+ * vk_ctx_destroy — safe because the struct is zeroed first, so every
+ * untouched handle is VK_NULL_HANDLE. Returns 0 on success, otherwise
+ * the failing step's non-zero code (last_error is set). */
+int vk_ctx_init(VkCtx *ctx) {
+    memset(ctx, 0, sizeof(*ctx));
+    int r;
+    if ((r = create_instance(ctx)) != 0)
+        goto fail;
+    if ((r = pick_physical_device(ctx)) != 0)
+        goto fail;
+    if ((r = find_graphics_queue(ctx)) != 0)
+        goto fail;
+    if ((r = create_device(ctx)) != 0)
+        goto fail;
+    if ((r = create_command_pool(ctx)) != 0)
+        goto fail;
+    return 0;
+fail:
+    vk_ctx_destroy(ctx);
+    return r;
+}
+
+/* Destroy the context in reverse creation order: command pool → device →
+ * debug messenger → instance. Each handle is checked and nulled, so the
+ * call is idempotent and safe on a partially-initialized context. */
+void vk_ctx_destroy(VkCtx *ctx) {
+    if (!ctx)
+        return;
+    if (ctx->command_pool != VK_NULL_HANDLE) {
+        vkDestroyCommandPool(ctx->device, ctx->command_pool, NULL);
+        ctx->command_pool = VK_NULL_HANDLE;
+    }
+    if (ctx->device != VK_NULL_HANDLE) {
+        vkDestroyDevice(ctx->device, NULL);
+        ctx->device = VK_NULL_HANDLE;
+    }
+#ifdef TRAJVIZ_DEBUG
+    /* The messenger is an instance-level object; destroy it while the
+     * instance is still alive, via the dynamically loaded entry point. */
+    if (ctx->debug_messenger != VK_NULL_HANDLE) {
+        PFN_vkDestroyDebugUtilsMessengerEXT destroy_msgr = (PFN_vkDestroyDebugUtilsMessengerEXT)vkGetInstanceProcAddr(
+            ctx->instance, "vkDestroyDebugUtilsMessengerEXT");
+        if (destroy_msgr) {
+            destroy_msgr(ctx->instance, ctx->debug_messenger, NULL);
+        }
+        ctx->debug_messenger = VK_NULL_HANDLE;
+    }
+#endif
+    if (ctx->instance != VK_NULL_HANDLE) {
+        vkDestroyInstance(ctx->instance, NULL);
+        ctx->instance = VK_NULL_HANDLE;
+    }
+}
diff --git a/pufferlib/ocean/drive/trajviz/vk_context.h b/pufferlib/ocean/drive/trajviz/vk_context.h
new file mode 100644
index 0000000000..e87915ad44
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/vk_context.h
@@ -0,0 +1,83 @@
+/*
+ * vk_context.h — Vulkan instance/device/queue lifecycle for trajviz.
+ *
+ * Pure headless: no surface, no swapchain, no window system. We render
+ * into VkImages allocated from device memory and copy them back to host
+ * buffers for ffmpeg encoding. This means trajviz works on a cluster
+ * node with no display server, which is half the point of using Vulkan
+ * over raylib.
+ *
+ * The VkCtx struct holds everything that lives for the entire renderer
+ * lifetime — instance, device, queue, command pool, debug messenger.
+ * Per-episode state (images, framebuffers, ffmpeg pipes) lives in
+ * vk_renderer.h's RenderTargets, not here.
+ *
+ * One ctx per thread. The Vulkan spec allows concurrent submits to a
+ * queue from multiple threads with external synchronization, but we
+ * don't need that for v1 — single-threaded inside an episode, multi-
+ * episode parallelism happens at a higher level (one ctx per worker).
+ */
+
+#ifndef VK_CONTEXT_H
+#define VK_CONTEXT_H
+
+#include <vulkan/vulkan.h>
+#include <stdint.h>
+
+#define TRAJVIZ_ERROR_BUF 512
+
+/* Everything that lives for the renderer's whole lifetime. Created by
+ * vk_ctx_init, torn down by vk_ctx_destroy. */
+typedef struct VkCtx {
+    VkInstance instance;
+    VkPhysicalDevice physical_device;
+    VkDevice device;
+    uint32_t graphics_family; /* queue family index with VK_QUEUE_GRAPHICS_BIT */
+    VkQueue graphics_queue;
+    VkCommandPool command_pool;
+
+    /* Cached physical device properties used by other modules. */
+    VkPhysicalDeviceMemoryProperties mem_props; /* consumed by vk_find_memory_type */
+    char device_name[256];
+
+    /* Optional validation messenger; only created in debug builds. */
+    VkDebugUtilsMessengerEXT debug_messenger;
+    int debug_enabled;
+
+    /* Last error message — populated by failing functions for the
+     * caller to surface via trajviz_last_error(). */
+    char last_error[TRAJVIZ_ERROR_BUF];
+} VkCtx;
+
+/* Initialize the Vulkan context. Returns 0 on success, non-zero on
+ * failure. On failure, the last_error field contains a human-readable
+ * message and any partially-created handles have been destroyed.
+ *
+ * If TRAJVIZ_DEBUG is defined at compile time, validation layers are
+ * enabled and a debug messenger is registered. */
+int vk_ctx_init(VkCtx *ctx);
+
+/* Destroy the context. Idempotent. Safe to call after a failed init. */
+void vk_ctx_destroy(VkCtx *ctx);
+
+/* Find a memory type index that matches both the type bits (from a
+ * VkMemoryRequirements query) and the requested property flags. Returns
+ * UINT32_MAX if none. */
+uint32_t vk_find_memory_type(const VkCtx *ctx, uint32_t type_bits, VkMemoryPropertyFlags properties);
+
+/* Helper to set ctx->last_error from a printf-style message. Used by
+ * vk_ctx and other modules. */
+void vk_ctx_set_error(VkCtx *ctx, const char *fmt, ...);
+
+/* VK_CHECK is the verbose-but-correct error path. On failure it sets
+ * ctx->last_error and returns the result code from the enclosing
+ * function. Use only inside functions that return int and have a VkCtx
+ * *ctx in scope. */
+#define VK_CHECK(expr) \
+ do { \
+ VkResult _r = (expr); \
+ if (_r != VK_SUCCESS) { \
+ vk_ctx_set_error(ctx, "%s failed at %s:%d (VkResult=%d)", #expr, __FILE__, __LINE__, (int)_r); \
+ return (int)_r; \
+ } \
+ } while (0)
+
+#endif /* VK_CONTEXT_H */
diff --git a/pufferlib/ocean/drive/trajviz/vk_math.h b/pufferlib/ocean/drive/trajviz/vk_math.h
new file mode 100644
index 0000000000..46d5bb658e
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/vk_math.h
@@ -0,0 +1,141 @@
+/*
+ * vk_math.h — minimal column-major mat4 helpers for trajviz.
+ *
+ * Header-only on purpose: this is ~30 lines of arithmetic, not worth a
+ * separate translation unit, and it gets inlined into the four call sites
+ * (top-down camera, BEV camera, push-constant upload, agent box vertex
+ * shader uniforms).
+ *
+ * Layout: column-major, 4x4 floats, m[col][row]. Matches Vulkan's expected
+ * uniform layout (std140) when uploaded as 16 floats. Multiplication is
+ * standard: M*v transforms a column vector v.
+ *
+ * Coordinate convention: world is right-handed, +x right, +y up. The
+ * mat4_ortho helper bakes the Vulkan y-flip (clip +y is "down") into the
+ * projection so that world +y appears at the TOP of the rendered frame —
+ * the natural orientation for top-down maps.
+ */
+
+#ifndef VK_MATH_H
+#define VK_MATH_H
+
+#include <math.h>
+#include <stdint.h>
+
+/* 4x4 float matrix; element (row, col) lives at m[col * 4 + row]. */
+typedef struct {
+    float m[16];
+} Mat4; /* column-major: m[col*4 + row] */
+
+/* Return the 4x4 identity matrix. */
+static inline Mat4 mat4_identity(void) {
+    Mat4 out = {{0}};
+    for (int d = 0; d < 4; ++d)
+        out.m[d * 4 + d] = 1.0f; /* diagonal entries */
+    return out;
+}
+
+/* Vulkan-style orthographic projection with y-flip baked in.
+ * Maps world (left..right, bottom..top, near..far) → clip (-1..1, +1..-1, 0..1).
+ * Note the y range flip: world bottom→clip +1, world top→clip -1, so that
+ * world +y appears upward on the rendered image.
+ *
+ * The +1..-1 on y is the Vulkan convention difference vs OpenGL.
+ * The 0..1 on z is also Vulkan-specific (vs OpenGL's -1..1). */
+static inline Mat4 mat4_ortho(float l, float r, float b, float t, float n, float f) {
+    Mat4 m = {{0}};
+    m.m[0] = 2.0f / (r - l);       /* x: [l, r] -> [-1, +1] */
+    m.m[5] = -2.0f / (t - b);      /* y-flip */
+    m.m[10] = 1.0f / (f - n);      /* z: [n, f] -> [0, 1] (Vulkan depth) */
+    m.m[12] = -(r + l) / (r - l);  /* x translation (fourth column) */
+    m.m[13] = (t + b) / (t - b);   /* y-flip */
+    m.m[14] = -n / (f - n);        /* z translation */
+    m.m[15] = 1.0f;
+    return m;
+}
+
+/* Build a translation by (tx, ty, tz): identity with the offsets in the
+ * fourth column (indices 12..14 in column-major storage). */
+static inline Mat4 mat4_translate(float tx, float ty, float tz) {
+    Mat4 out = {{0}};
+    out.m[0] = out.m[5] = out.m[10] = out.m[15] = 1.0f;
+    out.m[12] = tx;
+    out.m[13] = ty;
+    out.m[14] = tz;
+    return out;
+}
+
+/* Rotation about +z by angle_rad. For a column vector this is the usual
+ * counter-clockwise rotation: x' = c*x - s*y, y' = s*x + c*y. Only the
+ * upper-left 2x2 differs from identity. */
+static inline Mat4 mat4_rotate_z(float angle_rad) {
+    Mat4 r = mat4_identity();
+    float c = cosf(angle_rad);
+    float s = sinf(angle_rad);
+    r.m[0] = c;  /* column 0 */
+    r.m[4] = -s; /* column 1 */
+    r.m[1] = s;
+    r.m[5] = c;
+    return r;
+}
+
+/* Column-major product out = a * b: b is applied first to a column
+ * vector, then a. The inner sum is written out in k order 0..3 so the
+ * float accumulation order matches the loop formulation exactly. */
+static inline Mat4 mat4_mul(Mat4 a, Mat4 b) {
+    Mat4 out = {{0}};
+    for (int col = 0; col < 4; ++col) {
+        for (int row = 0; row < 4; ++row) {
+            out.m[col * 4 + row] = a.m[0 * 4 + row] * b.m[col * 4 + 0] +
+                                   a.m[1 * 4 + row] * b.m[col * 4 + 1] +
+                                   a.m[2 * 4 + row] * b.m[col * 4 + 2] +
+                                   a.m[3 * 4 + row] * b.m[col * 4 + 3];
+        }
+    }
+    return out;
+}
+
+/* Fit a world-space AABB into a viewport of (vp_w, vp_h) pixels with the
+ * given fractional padding (e.g. 0.05 = 5% margin on each side), preserving
+ * aspect ratio. Returns an ortho projection that maps the world AABB to the
+ * full viewport with letterbox/pillarbox as needed. Degenerate extents
+ * (zero-width or zero-height AABB) are clamped to 1 world unit; the z
+ * range is fixed to [-1, 1]. */
+static inline Mat4 mat4_fit_aabb(float xmin, float ymin, float xmax, float ymax, int vp_w, int vp_h, float pad_frac) {
+    float w = xmax - xmin;
+    float h = ymax - ymin;
+    if (w <= 0.0f)
+        w = 1.0f;
+    if (h <= 0.0f)
+        h = 1.0f;
+    float cx = 0.5f * (xmin + xmax);
+    float cy = 0.5f * (ymin + ymax);
+
+    /* Choose half-extent so the world AABB fits inside the viewport with
+     * aspect preserved. The viewport has aspect vp_w/vp_h; the world has
+     * aspect w/h. Pick whichever side is the binding constraint. */
+    float vp_aspect = (float)vp_w / (float)vp_h;
+    float world_aspect = w / h;
+    float half_w, half_h;
+    if (world_aspect > vp_aspect) {
+        /* World is wider than viewport: width-bound, height grows */
+        half_w = 0.5f * w;
+        half_h = half_w / vp_aspect;
+    } else {
+        /* World is taller than viewport: height-bound, width grows */
+        half_h = 0.5f * h;
+        half_w = half_h * vp_aspect;
+    }
+    /* Padding scales both extents uniformly, so aspect is preserved. */
+    half_w *= (1.0f + pad_frac);
+    half_h *= (1.0f + pad_frac);
+
+    return mat4_ortho(cx - half_w, cx + half_w, cy - half_h, cy + half_h, -1.0f, 1.0f);
+}
+
+/* Build the BEV camera matrix that puts the ego at the origin and rotates
+ * the world so the ego's heading vector points to clip "up" (world +y after
+ * the rotation). window_m is the half-extent in meters (e.g. 50 for a
+ * 100m × 100m view); the visible width is widened by the viewport aspect
+ * so pixels stay square. */
+static inline Mat4 mat4_bev_camera(float ego_x, float ego_y, float ego_heading_rad, float window_m, int vp_w,
+                                   int vp_h) {
+    /* Translate so ego is at origin */
+    Mat4 T = mat4_translate(-ego_x, -ego_y, 0.0f);
+    /* Rotate by (pi/2 - heading) so the heading vector aligns with +y.
+     * 1.5707963f is pi/2 as a float literal. */
+    Mat4 R = mat4_rotate_z(1.5707963f - ego_heading_rad);
+    /* Aspect-corrected ortho window centered on origin */
+    float vp_aspect = (float)vp_w / (float)vp_h;
+    float half_h = window_m;
+    float half_w = window_m * vp_aspect;
+    Mat4 P = mat4_ortho(-half_w, half_w, -half_h, half_h, -1.0f, 1.0f);
+    /* M = P * R * T (apply T first, then R, then P) */
+    return mat4_mul(P, mat4_mul(R, T));
+}
+
+#endif /* VK_MATH_H */
diff --git a/pufferlib/ocean/drive/trajviz/vk_pipeline.c b/pufferlib/ocean/drive/trajviz/vk_pipeline.c
new file mode 100644
index 0000000000..c26234429f
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/vk_pipeline.c
@@ -0,0 +1,290 @@
+/*
+ * vk_pipeline.c — graphics pipeline construction.
+ *
+ * Both pipelines share the same VkPipelineLayout (one push-constant range,
+ * no descriptor sets) but differ in vertex input state and primitive
+ * topology. Most state is identical and described once in helper structs.
+ */
+
+#include "vk_pipeline.h"
+#include "shaders.h"
+
+#include <stddef.h> /* offsetof — NOTE(review): header names lost in extraction; reconstructed */
+#include <string.h> /* memset */
+
+/* Wrap a SPIR-V blob in a VkShaderModule. Returns 0 on success; VK_CHECK
+ * records the error and returns non-zero otherwise. */
+static int create_shader_module(VkCtx *ctx, const uint32_t *code, size_t size_bytes, VkShaderModule *out) {
+    VkShaderModuleCreateInfo info = {0};
+    info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
+    info.codeSize = size_bytes;
+    info.pCode = code;
+    VK_CHECK(vkCreateShaderModule(ctx->device, &info, NULL, out));
+    return 0;
+}
+
+/* One shared layout for both pipelines: no descriptor sets, a single
+ * push-constant range covering the PushConstants struct in both the
+ * vertex and fragment stages. */
+static int create_pipeline_layout(VkCtx *ctx, Pipelines *p) {
+    VkPushConstantRange range = {0};
+    range.stageFlags = VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT;
+    range.offset = 0;
+    range.size = sizeof(PushConstants);
+
+    VkPipelineLayoutCreateInfo info = {0};
+    info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
+    info.pushConstantRangeCount = 1;
+    info.pPushConstantRanges = &range;
+    VK_CHECK(vkCreatePipelineLayout(ctx->device, &info, NULL, &p->layout));
+    return 0;
+}
+
+/* Common pipeline state shared between line and box pipelines.
+ * NOTE: blend.pAttachments and dyn.pDynamicStates point INTO this struct
+ * (blend_att, dyn_states), so an instance must stay alive until the
+ * vkCreateGraphicsPipelines call that consumes it has returned. */
+typedef struct PipelineDefaults {
+    VkPipelineInputAssemblyStateCreateInfo ia;
+    VkPipelineViewportStateCreateInfo vp;    /* counts only; values are dynamic */
+    VkPipelineRasterizationStateCreateInfo rs;
+    VkPipelineMultisampleStateCreateInfo ms;
+    VkPipelineDepthStencilStateCreateInfo ds; /* depth test/write disabled */
+    VkPipelineColorBlendAttachmentState blend_att;
+    VkPipelineColorBlendStateCreateInfo blend; /* references blend_att */
+    VkPipelineDynamicStateCreateInfo dyn;      /* references dyn_states */
+    VkDynamicState dyn_states[2];              /* viewport + scissor */
+} PipelineDefaults;
+
+/* Populate the shared pipeline state for the given primitive topology.
+ * The caller's PipelineDefaults must outlive the pipeline-create call,
+ * since blend/dyn hold pointers into the struct itself. */
+static void fill_defaults(PipelineDefaults *d, VkPrimitiveTopology topology) {
+    memset(d, 0, sizeof(*d));
+
+    d->ia.sType = VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO;
+    d->ia.topology = topology;
+    d->ia.primitiveRestartEnable = VK_FALSE;
+
+    /* Viewport + scissor are dynamic state — actual values come from
+     * vkCmdSetViewport / vkCmdSetScissor at record time. The struct still
+     * needs viewportCount/scissorCount = 1 here. */
+    d->vp.sType = VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO;
+    d->vp.viewportCount = 1;
+    d->vp.scissorCount = 1;
+
+    d->rs.sType = VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_STATE_CREATE_INFO;
+    d->rs.polygonMode = VK_POLYGON_MODE_FILL;
+    d->rs.cullMode = VK_CULL_MODE_NONE;
+    d->rs.frontFace = VK_FRONT_FACE_COUNTER_CLOCKWISE;
+    d->rs.lineWidth = 1.5f; /* used for line topology only; ignored for tris */
+
+    d->ms.sType = VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO;
+    d->ms.rasterizationSamples = VK_SAMPLE_COUNT_1_BIT;
+
+    /* Depth is disabled: draw order alone decides what ends up on top. */
+    d->ds.sType = VK_STRUCTURE_TYPE_PIPELINE_DEPTH_STENCIL_STATE_CREATE_INFO;
+    d->ds.depthTestEnable = VK_FALSE;
+    d->ds.depthWriteEnable = VK_FALSE;
+
+    /* Standard alpha-blend over the existing color, so trace overlays and
+     * agent boxes with alpha < 1 fade nicely. */
+    d->blend_att.blendEnable = VK_TRUE;
+    d->blend_att.srcColorBlendFactor = VK_BLEND_FACTOR_SRC_ALPHA;
+    d->blend_att.dstColorBlendFactor = VK_BLEND_FACTOR_ONE_MINUS_SRC_ALPHA;
+    d->blend_att.colorBlendOp = VK_BLEND_OP_ADD;
+    d->blend_att.srcAlphaBlendFactor = VK_BLEND_FACTOR_ONE;
+    d->blend_att.dstAlphaBlendFactor = VK_BLEND_FACTOR_ZERO;
+    d->blend_att.alphaBlendOp = VK_BLEND_OP_ADD;
+    d->blend_att.colorWriteMask =
+        VK_COLOR_COMPONENT_R_BIT | VK_COLOR_COMPONENT_G_BIT | VK_COLOR_COMPONENT_B_BIT | VK_COLOR_COMPONENT_A_BIT;
+
+    d->blend.sType = VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_STATE_CREATE_INFO;
+    d->blend.attachmentCount = 1;
+    d->blend.pAttachments = &d->blend_att;
+
+    d->dyn_states[0] = VK_DYNAMIC_STATE_VIEWPORT;
+    d->dyn_states[1] = VK_DYNAMIC_STATE_SCISSOR;
+    d->dyn.sType = VK_STRUCTURE_TYPE_PIPELINE_DYNAMIC_STATE_CREATE_INFO;
+    d->dyn.dynamicStateCount = 2;
+    d->dyn.pDynamicStates = d->dyn_states;
+}
+
+/* Create the polyline pipeline: one vec2 per vertex, LINE_STRIP topology,
+ * dynamic-rendering color format taken from p->color_format. */
+static int create_line_pipeline(VkCtx *ctx, Pipelines *p, VkShaderModule vs, VkShaderModule fs) {
+    VkPipelineShaderStageCreateInfo stages[2] = {
+        {.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
+         .stage = VK_SHADER_STAGE_VERTEX_BIT,
+         .module = vs,
+         .pName = "main"},
+        {.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
+         .stage = VK_SHADER_STAGE_FRAGMENT_BIT,
+         .module = fs,
+         .pName = "main"},
+    };
+
+    /* Single binding: tightly-packed (x, y) pairs, 8-byte stride. */
+    VkVertexInputBindingDescription binding = {
+        .binding = 0,
+        .stride = sizeof(float) * 2,
+        .inputRate = VK_VERTEX_INPUT_RATE_VERTEX,
+    };
+    VkVertexInputAttributeDescription attr = {
+        .location = 0,
+        .binding = 0,
+        .format = VK_FORMAT_R32G32_SFLOAT,
+        .offset = 0,
+    };
+    VkPipelineVertexInputStateCreateInfo vi = {
+        .sType = VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO,
+        .vertexBindingDescriptionCount = 1,
+        .pVertexBindingDescriptions = &binding,
+        .vertexAttributeDescriptionCount = 1,
+        .pVertexAttributeDescriptions = &attr,
+    };
+
+    PipelineDefaults d;
+    /* LINE_STRIP lets us draw a polyline of N verts with one vkCmdDraw
+     * call (N verts → N-1 connected segments). Previously LINE_LIST
+     * forced one draw per segment, which dominated CPU command-recording
+     * cost on real maps with 200+ polylines. */
+    fill_defaults(&d, VK_PRIMITIVE_TOPOLOGY_LINE_STRIP);
+
+    /* Dynamic rendering: the attachment format replaces a VkRenderPass. */
+    VkPipelineRenderingCreateInfo rci = {
+        .sType = VK_STRUCTURE_TYPE_PIPELINE_RENDERING_CREATE_INFO,
+        .colorAttachmentCount = 1,
+        .pColorAttachmentFormats = &p->color_format,
+    };
+
+    VkGraphicsPipelineCreateInfo gci = {
+        .sType = VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO,
+        .pNext = &rci,
+        .stageCount = 2,
+        .pStages = stages,
+        .pVertexInputState = &vi,
+        .pInputAssemblyState = &d.ia,
+        .pViewportState = &d.vp,
+        .pRasterizationState = &d.rs,
+        .pMultisampleState = &d.ms,
+        .pDepthStencilState = &d.ds,
+        .pColorBlendState = &d.blend,
+        .pDynamicState = &d.dyn,
+        .layout = p->layout,
+    };
+    VK_CHECK(vkCreateGraphicsPipelines(ctx->device, VK_NULL_HANDLE, 1, &gci, NULL, &p->line_pipeline));
+    return 0;
+}
+
+/* Create the instanced agent-box pipeline: a shared unit quad at binding 0
+ * and one AgentInstance per agent at binding 1, TRIANGLE_LIST topology. */
+static int create_box_pipeline(VkCtx *ctx, Pipelines *p, VkShaderModule vs, VkShaderModule fs) {
+    VkPipelineShaderStageCreateInfo stages[2] = {
+        {.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
+         .stage = VK_SHADER_STAGE_VERTEX_BIT,
+         .module = vs,
+         .pName = "main"},
+        {.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
+         .stage = VK_SHADER_STAGE_FRAGMENT_BIT,
+         .module = fs,
+         .pName = "main"},
+    };
+
+    /* Two vertex bindings:
+     *   binding 0 — per-vertex unit quad corner (vec2)
+     *   binding 1 — per-instance AgentInstance (40 bytes) */
+    VkVertexInputBindingDescription bindings[2] = {
+        {.binding = 0, .stride = sizeof(float) * 2, .inputRate = VK_VERTEX_INPUT_RATE_VERTEX},
+        {.binding = 1, .stride = sizeof(AgentInstance), .inputRate = VK_VERTEX_INPUT_RATE_INSTANCE},
+    };
+    /* Locations 1..3 mirror the AgentInstance fields: pose, size, color. */
+    VkVertexInputAttributeDescription attrs[4] = {
+        {.location = 0, .binding = 0, .format = VK_FORMAT_R32G32_SFLOAT, .offset = 0},
+        {.location = 1, .binding = 1, .format = VK_FORMAT_R32G32B32A32_SFLOAT, .offset = offsetof(AgentInstance, pose)},
+        {.location = 2, .binding = 1, .format = VK_FORMAT_R32G32_SFLOAT, .offset = offsetof(AgentInstance, size)},
+        {.location = 3,
+         .binding = 1,
+         .format = VK_FORMAT_R32G32B32A32_SFLOAT,
+         .offset = offsetof(AgentInstance, color)},
+    };
+    VkPipelineVertexInputStateCreateInfo vi = {
+        .sType = VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO,
+        .vertexBindingDescriptionCount = 2,
+        .pVertexBindingDescriptions = bindings,
+        .vertexAttributeDescriptionCount = 4,
+        .pVertexAttributeDescriptions = attrs,
+    };
+
+    PipelineDefaults d;
+    fill_defaults(&d, VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST);
+
+    VkPipelineRenderingCreateInfo rci = {
+        .sType = VK_STRUCTURE_TYPE_PIPELINE_RENDERING_CREATE_INFO,
+        .colorAttachmentCount = 1,
+        .pColorAttachmentFormats = &p->color_format,
+    };
+
+    VkGraphicsPipelineCreateInfo gci = {
+        .sType = VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO,
+        .pNext = &rci,
+        .stageCount = 2,
+        .pStages = stages,
+        .pVertexInputState = &vi,
+        .pInputAssemblyState = &d.ia,
+        .pViewportState = &d.vp,
+        .pRasterizationState = &d.rs,
+        .pMultisampleState = &d.ms,
+        .pDepthStencilState = &d.ds,
+        .pColorBlendState = &d.blend,
+        .pDynamicState = &d.dyn,
+        .layout = p->layout,
+    };
+    VK_CHECK(vkCreateGraphicsPipelines(ctx->device, VK_NULL_HANDLE, 1, &gci, NULL, &p->box_pipeline));
+    return 0;
+}
+
+/* Build the shared layout and both pipelines for the given color format.
+ * Returns 0 on success. On any failure, everything partially created is
+ * torn down via vk_pipelines_destroy before returning. */
+int vk_pipelines_init(VkCtx *ctx, Pipelines *p, VkFormat color_format) {
+    memset(p, 0, sizeof(*p));
+    p->color_format = color_format;
+
+    int r = create_pipeline_layout(ctx, p);
+    if (r != 0)
+        return r;
+
+    /* Initialize to VK_NULL_HANDLE so the cleanup path below can safely
+     * destroy only what was actually created. */
+    VkShaderModule line_vs = VK_NULL_HANDLE, line_fs = VK_NULL_HANDLE;
+    VkShaderModule box_vs = VK_NULL_HANDLE, box_fs = VK_NULL_HANDLE;
+
+    r = create_shader_module(ctx, polyline_vert_spv, polyline_vert_spv_size, &line_vs);
+    if (r != 0)
+        goto cleanup;
+    r = create_shader_module(ctx, polyline_frag_spv, polyline_frag_spv_size, &line_fs);
+    if (r != 0)
+        goto cleanup;
+    r = create_shader_module(ctx, agent_box_vert_spv, agent_box_vert_spv_size, &box_vs);
+    if (r != 0)
+        goto cleanup;
+    r = create_shader_module(ctx, agent_box_frag_spv, agent_box_frag_spv_size, &box_fs);
+    if (r != 0)
+        goto cleanup;
+
+    r = create_line_pipeline(ctx, p, line_vs, line_fs);
+    if (r != 0)
+        goto cleanup;
+    r = create_box_pipeline(ctx, p, box_vs, box_fs);
+
+cleanup:
+    /* Shader modules can be destroyed as soon as the pipelines are built —
+     * the pipeline keeps its own reference. */
+    if (line_vs)
+        vkDestroyShaderModule(ctx->device, line_vs, NULL);
+    if (line_fs)
+        vkDestroyShaderModule(ctx->device, line_fs, NULL);
+    if (box_vs)
+        vkDestroyShaderModule(ctx->device, box_vs, NULL);
+    if (box_fs)
+        vkDestroyShaderModule(ctx->device, box_fs, NULL);
+    if (r != 0)
+        vk_pipelines_destroy(ctx, p);
+    return r;
+}
+
+/* Tear down both pipelines and the shared layout. Idempotent: each
+ * destroyed handle is reset to VK_NULL_HANDLE, so calling twice is safe. */
+void vk_pipelines_destroy(VkCtx *ctx, Pipelines *p) {
+    if (p == NULL)
+        return;
+    VkDevice dev = ctx->device;
+    if (p->line_pipeline != VK_NULL_HANDLE) {
+        vkDestroyPipeline(dev, p->line_pipeline, NULL);
+        p->line_pipeline = VK_NULL_HANDLE;
+    }
+    if (p->box_pipeline != VK_NULL_HANDLE) {
+        vkDestroyPipeline(dev, p->box_pipeline, NULL);
+        p->box_pipeline = VK_NULL_HANDLE;
+    }
+    /* Layout last — pipelines were created against it. */
+    if (p->layout != VK_NULL_HANDLE) {
+        vkDestroyPipelineLayout(dev, p->layout, NULL);
+        p->layout = VK_NULL_HANDLE;
+    }
+}
diff --git a/pufferlib/ocean/drive/trajviz/vk_pipeline.h b/pufferlib/ocean/drive/trajviz/vk_pipeline.h
new file mode 100644
index 0000000000..03f702246e
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/vk_pipeline.h
@@ -0,0 +1,64 @@
+/*
+ * vk_pipeline.h — graphics pipelines for trajviz.
+ *
+ * Two pipelines, one shared pipeline layout:
+ *
+ * - line_pipeline: VK_PRIMITIVE_TOPOLOGY_LINE_STRIP. Binding 0 = vec2
+ *   per vertex (8 bytes stride, per-vertex rate). Used for road
+ *   polylines and (eventually) trajectory traces. The vertex buffer is
+ *   a flat array of (x, y) pairs with per-polyline runs delimited by
+ *   vkCmdDraw calls (one draw per polyline) — no index buffer.
+ *
+ * - box_pipeline: VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST. Binding 0 = the
+ * per-vertex unit quad (4 verts, never changes). Binding 1 =
+ * per-instance AgentInstance (40 bytes stride, per-instance rate).
+ * Drawn with vkCmdDrawIndexed using a 6-index quad index buffer
+ * and instance count = number of active agents.
+ *
+ * Pipeline layout: 0 descriptor sets, 1 push-constant range (PushConstants
+ * struct, 80 bytes, vertex+fragment stages). All per-frame state goes
+ * through push constants — no descriptor pool, no UBO juggling. This is
+ * fine because we only ever push (mat4 mvp + vec4 tint) per draw, well
+ * under the 128-byte minimum guaranteed limit.
+ *
+ * Color attachment format is captured at pipeline-creation time via
+ * VkPipelineRenderingCreateInfo (the dynamic_rendering equivalent of a
+ * VkRenderPass). The renderer's color image MUST match this format.
+ */
+
+#ifndef VK_PIPELINE_H
+#define VK_PIPELINE_H
+
+#include "vk_context.h"
+#include <vulkan/vulkan.h> /* NOTE(review): header name lost in extraction — Vk* types used below; confirm */
+
+/* Push constant block — must match the GLSL Push struct layout in
+ * polyline.vert / agent_box.vert. std430 is implicit for push constants.
+ * Total size: 16 + 4 floats = 80 bytes, pushed once per draw. */
+typedef struct PushConstants {
+    float mvp[16]; /* column-major mat4 */
+    float color[4]; /* polyline: line color; agent: per-view tint */
+} PushConstants;
+
+/* Per-instance attributes for agent boxes — must match the vertex input
+ * layout below and the location 1..3 attributes in agent_box.vert.
+ * 10 floats = 40 bytes, matching the binding-1 stride in vk_pipeline.c. */
+typedef struct AgentInstance {
+    float pose[4]; /* (x, y, heading_rad, _pad) */
+    float size[2]; /* (length, width) meters */
+    float color[4]; /* (r, g, b, a) */
+} AgentInstance;
+
+/* Handle bundle owned by vk_pipelines_init / vk_pipelines_destroy. */
+typedef struct Pipelines {
+    VkPipelineLayout layout;   /* shared by both pipelines */
+    VkPipeline line_pipeline;  /* road polylines (line strip) */
+    VkPipeline box_pipeline;   /* instanced agent boxes (triangles) */
+    VkFormat color_format;     /* attachment format both pipelines target */
+} Pipelines;
+
+/* Build both pipelines targeting the given color attachment format.
+ * Viewport and scissor are dynamic state, so no framebuffer size is
+ * needed here — set them per-frame via vkCmdSetViewport/SetScissor. */
+int vk_pipelines_init(VkCtx *ctx, Pipelines *p, VkFormat color_format);
+
+void vk_pipelines_destroy(VkCtx *ctx, Pipelines *p);
+
+#endif
diff --git a/pufferlib/ocean/drive/trajviz/vk_renderer.c b/pufferlib/ocean/drive/trajviz/vk_renderer.c
new file mode 100644
index 0000000000..18037be587
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/vk_renderer.c
@@ -0,0 +1,658 @@
+/*
+ * vk_renderer.c — pipelined per-frame rendering with frames-in-flight ring.
+ *
+ * The orchestrator (trajviz.c) calls episode_begin → submit_frame×N →
+ * episode_end. Inside, the renderer keeps a small ring of FrameSlot
+ * structs and walks it as a FIFO: the CPU records the next slot, the
+ * GPU runs the previous one(s), and the host reads back from whichever
+ * slot is now signaled. This amortizes the per-submit + per-wait
+ * scheduler latency (which is the dominant cost on this path) across
+ * FRAMES_IN_FLIGHT frames.
+ *
+ * The actual command-buffer recording (record_view at the bottom of
+ * this file) is unchanged from the synchronous version — it draws roads
+ * with LINE_STRIP topology and instanced agent boxes.
+ */
+
+#include "vk_renderer.h"
+#include "shaders.h"
+
+#include <stdint.h> /* NOTE(review): header names lost in extraction; reconstructed from usage */
+#include <stdio.h>
+#include <stdlib.h> /* malloc / free */
+#include <string.h> /* memset / memcpy */
+
+/* ----------------------------- buffer helpers ----------------------------- */
+
+/* Create a VkBuffer with its own dedicated allocation matching mem_props.
+ * When map_persistent is non-zero the memory is mapped once here and the
+ * pointer kept in out->mapped for the buffer's lifetime. Returns 0 on
+ * success; on failure the error is recorded on ctx and non-zero returned
+ * (the caller is expected to destroy via destroy_buffer). */
+static int create_buffer(VkCtx *ctx, VkDeviceSize size, VkBufferUsageFlags usage, VkMemoryPropertyFlags mem_props,
+                         int map_persistent, VkBufferM *out) {
+    memset(out, 0, sizeof(*out));
+    out->size = size;
+
+    VkBufferCreateInfo bci = {
+        .sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO,
+        .size = size,
+        .usage = usage,
+        .sharingMode = VK_SHARING_MODE_EXCLUSIVE,
+    };
+    VK_CHECK(vkCreateBuffer(ctx->device, &bci, NULL, &out->buffer));
+
+    VkMemoryRequirements req;
+    vkGetBufferMemoryRequirements(ctx->device, out->buffer, &req);
+
+    /* UINT32_MAX from vk_find_memory_type means "no compatible type". */
+    uint32_t mem_idx = vk_find_memory_type(ctx, req.memoryTypeBits, mem_props);
+    if (mem_idx == UINT32_MAX) {
+        vk_ctx_set_error(ctx, "no memory type matches buffer requirements (props=0x%x)", (unsigned)mem_props);
+        return -1;
+    }
+
+    VkMemoryAllocateInfo mai = {
+        .sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO,
+        .allocationSize = req.size,
+        .memoryTypeIndex = mem_idx,
+    };
+    VK_CHECK(vkAllocateMemory(ctx->device, &mai, NULL, &out->memory));
+    VK_CHECK(vkBindBufferMemory(ctx->device, out->buffer, out->memory, 0));
+
+    if (map_persistent) {
+        VK_CHECK(vkMapMemory(ctx->device, out->memory, 0, VK_WHOLE_SIZE, 0, &out->mapped));
+    }
+    return 0;
+}
+
+/* Unmap, destroy and free a VkBufferM. Safe on NULL and on an already
+ * zeroed struct; every handle is nulled so repeat calls are no-ops. */
+static void destroy_buffer(VkCtx *ctx, VkBufferM *b) {
+    if (ctx == NULL || b == NULL)
+        return;
+    if (b->memory && b->mapped) {
+        vkUnmapMemory(ctx->device, b->memory);
+        b->mapped = NULL;
+    }
+    if (b->buffer != VK_NULL_HANDLE) {
+        vkDestroyBuffer(ctx->device, b->buffer, NULL);
+        b->buffer = VK_NULL_HANDLE;
+    }
+    if (b->memory != VK_NULL_HANDLE) {
+        vkFreeMemory(ctx->device, b->memory, NULL);
+        b->memory = VK_NULL_HANDLE;
+    }
+    b->size = 0;
+}
+
+/* ------------------------------ image helpers ------------------------------ */
+
+/* Create a 2D single-mip, single-sample, optimally-tiled image in
+ * DEVICE_LOCAL memory plus a matching color image view. Returns 0 on
+ * success; the caller cleans up via destroy_image on failure. */
+static int create_image(VkCtx *ctx, uint32_t w, uint32_t h, VkFormat format, VkImageUsageFlags usage, VkImageM *out) {
+    memset(out, 0, sizeof(*out));
+    out->width = w;
+    out->height = h;
+    out->format = format;
+
+    VkImageCreateInfo ici = {
+        .sType = VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO,
+        .imageType = VK_IMAGE_TYPE_2D,
+        .format = format,
+        .extent = {w, h, 1},
+        .mipLevels = 1,
+        .arrayLayers = 1,
+        .samples = VK_SAMPLE_COUNT_1_BIT,
+        .tiling = VK_IMAGE_TILING_OPTIMAL,
+        .usage = usage,
+        .sharingMode = VK_SHARING_MODE_EXCLUSIVE,
+        .initialLayout = VK_IMAGE_LAYOUT_UNDEFINED,
+    };
+    VK_CHECK(vkCreateImage(ctx->device, &ici, NULL, &out->image));
+
+    VkMemoryRequirements req;
+    vkGetImageMemoryRequirements(ctx->device, out->image, &req);
+
+    uint32_t mem_idx = vk_find_memory_type(ctx, req.memoryTypeBits, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);
+    if (mem_idx == UINT32_MAX) {
+        vk_ctx_set_error(ctx, "no DEVICE_LOCAL memory type for color image");
+        return -1;
+    }
+
+    VkMemoryAllocateInfo mai = {
+        .sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO,
+        .allocationSize = req.size,
+        .memoryTypeIndex = mem_idx,
+    };
+    VK_CHECK(vkAllocateMemory(ctx->device, &mai, NULL, &out->memory));
+    VK_CHECK(vkBindImageMemory(ctx->device, out->image, out->memory, 0));
+
+    VkImageViewCreateInfo vci = {
+        .sType = VK_STRUCTURE_TYPE_IMAGE_VIEW_CREATE_INFO,
+        .image = out->image,
+        .viewType = VK_IMAGE_VIEW_TYPE_2D,
+        .format = format,
+        .subresourceRange =
+            {
+                .aspectMask = VK_IMAGE_ASPECT_COLOR_BIT,
+                .baseMipLevel = 0,
+                .levelCount = 1,
+                .baseArrayLayer = 0,
+                .layerCount = 1,
+            },
+    };
+    VK_CHECK(vkCreateImageView(ctx->device, &vci, NULL, &out->view));
+    return 0;
+}
+
+/* Destroy an image's view, image, and backing memory, in that order.
+ * Idempotent: handles are nulled after destruction. */
+static void destroy_image(VkCtx *ctx, VkImageM *im) {
+    if (ctx == NULL || im == NULL)
+        return;
+    if (im->view != VK_NULL_HANDLE) {
+        vkDestroyImageView(ctx->device, im->view, NULL);
+        im->view = VK_NULL_HANDLE;
+    }
+    if (im->image != VK_NULL_HANDLE) {
+        vkDestroyImage(ctx->device, im->image, NULL);
+        im->image = VK_NULL_HANDLE;
+    }
+    if (im->memory != VK_NULL_HANDLE) {
+        vkFreeMemory(ctx->device, im->memory, NULL);
+        im->memory = VK_NULL_HANDLE;
+    }
+}
+
+/* --------------------------- render target helpers -------------------------- */
+
+/* Create one render target: an RGBA8 color attachment (also usable as a
+ * transfer source) plus a host-visible readback buffer of w*h*4 bytes.
+ * Returns 0 on success; partial state is cleaned up by the caller via
+ * destroy_render_target. */
+static int create_render_target(VkCtx *ctx, uint32_t w, uint32_t h, RenderTarget *rt) {
+    memset(rt, 0, sizeof(*rt));
+    int r = create_image(ctx, w, h, VK_FORMAT_R8G8B8A8_UNORM,
+                         VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT, &rt->color);
+    if (r != 0)
+        return r;
+
+    /* Prefer HOST_CACHED so the CPU reads the readback at full RAM
+     * bandwidth instead of going over uncached PCIe BAR (~250 MB/s on
+     * NVIDIA, vs >5 GB/s cached). Fall back to plain HOST_COHERENT if
+     * the device doesn't expose a cached host-visible memory type. */
+    VkDeviceSize buf_size = (VkDeviceSize)w * (VkDeviceSize)h * 4;
+    VkMemoryPropertyFlags want =
+        VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT | VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
+    r = create_buffer(ctx, buf_size, VK_BUFFER_USAGE_TRANSFER_DST_BIT, want, 1, &rt->readback);
+    if (r != 0) {
+        r = create_buffer(ctx, buf_size, VK_BUFFER_USAGE_TRANSFER_DST_BIT,
+                          VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT, 1, &rt->readback);
+        if (r != 0)
+            return r;
+    }
+    /* Tightly packed rows: the copy in record_view uses bufferRowLength 0. */
+    rt->row_pitch_bytes = (size_t)w * 4;
+    return 0;
+}
+
+/* Release both halves of a render target; each helper no-ops on
+ * already-nulled handles, so this is idempotent. */
+static void destroy_render_target(VkCtx *ctx, RenderTarget *rt) {
+    if (rt == NULL)
+        return;
+    destroy_buffer(ctx, &rt->readback);
+    destroy_image(ctx, &rt->color);
+}
+
+/* ------------------------------ static geometry ----------------------------- */
+
+/* Create and fill the immutable unit-quad vertex and index buffers used
+ * by the instanced agent-box draw. Both are host-visible and persistently
+ * mapped, written once here and never touched again. */
+static int upload_static_geometry(VkCtx *ctx, Renderer *r) {
+    /* Unit quad: 4 vec2 corners spanning [-1, +1]^2. Order: BL, BR, TR, TL. */
+    const float quad[8] = {
+        -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f,
+    };
+    /* Two CCW triangles: (BL, BR, TR) and (BL, TR, TL). */
+    const uint16_t idx[6] = {0, 1, 2, 0, 2, 3};
+
+    int rc =
+        create_buffer(ctx, sizeof(quad), VK_BUFFER_USAGE_VERTEX_BUFFER_BIT,
+                      VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT, 1, &r->unit_quad_vb);
+    if (rc != 0)
+        return rc;
+    memcpy(r->unit_quad_vb.mapped, quad, sizeof(quad));
+
+    rc = create_buffer(ctx, sizeof(idx), VK_BUFFER_USAGE_INDEX_BUFFER_BIT,
+                       VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT, 1, &r->unit_quad_ib);
+    if (rc != 0)
+        return rc;
+    memcpy(r->unit_quad_ib.mapped, idx, sizeof(idx));
+    return 0;
+}
+
+/* --------------------------- frame slot lifecycle --------------------------- */
+
+/* Initialize one frame slot: two render targets (top-down + BEV), one
+ * primary command buffer, and one unsignaled fence. On failure the caller
+ * (vk_renderer_init) tears everything down via slot_destroy. */
+static int slot_init(VkCtx *ctx, FrameSlot *s, uint32_t width, uint32_t height) {
+    memset(s, 0, sizeof(*s));
+    int rc;
+    if ((rc = create_render_target(ctx, width, height, &s->rt_topdown)) != 0)
+        return rc;
+    if ((rc = create_render_target(ctx, width, height, &s->rt_bev)) != 0)
+        return rc;
+
+    VkCommandBufferAllocateInfo cai = {
+        .sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO,
+        .commandPool = ctx->command_pool,
+        .level = VK_COMMAND_BUFFER_LEVEL_PRIMARY,
+        .commandBufferCount = 1,
+    };
+    VK_CHECK(vkAllocateCommandBuffers(ctx->device, &cai, &s->cmd));
+
+    /* Created unsignaled: the ring starts empty, so nothing waits on the
+     * fence until the slot's first submit. */
+    VkFenceCreateInfo fci = {
+        .sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO,
+        .flags = 0,
+    };
+    VK_CHECK(vkCreateFence(ctx->device, &fci, NULL, &s->fence));
+    return 0;
+}
+
+/* Tear down a frame slot: sync objects first, then render targets and the
+ * per-slot instance buffer. Idempotent thanks to nulled handles. */
+static void slot_destroy(VkCtx *ctx, FrameSlot *s) {
+    if (s == NULL || ctx == NULL)
+        return;
+    if (s->fence != VK_NULL_HANDLE) {
+        vkDestroyFence(ctx->device, s->fence, NULL);
+        s->fence = VK_NULL_HANDLE;
+    }
+    if (s->cmd != VK_NULL_HANDLE) {
+        vkFreeCommandBuffers(ctx->device, ctx->command_pool, 1, &s->cmd);
+        s->cmd = VK_NULL_HANDLE;
+    }
+    destroy_render_target(ctx, &s->rt_topdown);
+    destroy_render_target(ctx, &s->rt_bev);
+    destroy_buffer(ctx, &s->agent_inst_vb);
+}
+
+/* --------------------------- init / destroy / set_roads --------------------- */
+
+/* Initialize the renderer: static quad geometry plus FRAMES_IN_FLIGHT
+ * fully-independent frame slots at the given resolution. Returns 0 on
+ * success; on failure everything partially created is destroyed. */
+int vk_renderer_init(VkCtx *ctx, Pipelines *p, Renderer *r, uint32_t width, uint32_t height) {
+    memset(r, 0, sizeof(*r));
+    r->pipelines = p;
+    r->width = width;
+    r->height = height;
+
+    int rc;
+    if ((rc = upload_static_geometry(ctx, r)) != 0)
+        goto fail;
+    for (int i = 0; i < FRAMES_IN_FLIGHT; ++i) {
+        if ((rc = slot_init(ctx, &r->slots[i], width, height)) != 0)
+            goto fail;
+    }
+    return 0;
+
+fail:
+    /* vk_renderer_destroy is safe on a partially-initialized Renderer. */
+    vk_renderer_destroy(ctx, r);
+    return rc;
+}
+
+/* Free all GPU and host resources owned by the renderer and reset the
+ * FIFO bookkeeping. NOTE(review): this does not wait on pending fences —
+ * it assumes the caller has drained in-flight work (episode_end) first;
+ * confirm against the orchestrator before calling mid-episode. */
+void vk_renderer_destroy(VkCtx *ctx, Renderer *r) {
+    if (!r || !ctx)
+        return;
+    for (int i = 0; i < FRAMES_IN_FLIGHT; ++i) {
+        slot_destroy(ctx, &r->slots[i]);
+    }
+    destroy_buffer(ctx, &r->unit_quad_vb);
+    destroy_buffer(ctx, &r->unit_quad_ib);
+    destroy_buffer(ctx, &r->road_vb);
+    free(r->road_offsets);
+    r->road_offsets = NULL;
+    free(r->road_types);
+    r->road_types = NULL;
+    r->num_polys = 0;
+    r->road_meta_capacity = 0;
+    r->road_vb_capacity = 0;
+    r->head = r->tail = r->n_in_flight = 0;
+}
+
+/* Grow-only buffer resize: keep the buffer if it already fits `required`,
+ * otherwise recreate it at the next power of two >= required (min 256 B).
+ * The old contents are discarded — callers re-upload after growth. */
+static int ensure_buffer_capacity(VkCtx *ctx, VkBufferM *b, VkDeviceSize required, VkBufferUsageFlags usage,
+                                  VkMemoryPropertyFlags mem_props) {
+    if (required <= b->size)
+        return 0;
+    destroy_buffer(ctx, b);
+    VkDeviceSize new_size = 256;
+    while (new_size < required)
+        new_size *= 2;
+    return create_buffer(ctx, new_size, usage, mem_props, 1, b);
+}
+
+/* Upload per-episode road geometry and polyline metadata.
+ *
+ * road_xy        — num_verts (x, y) pairs, flattened
+ * road_offsets   — num_polys + 1 vertex offsets delimiting each polyline
+ * road_types     — num_polys type ids (see color_for_road_type)
+ *
+ * Drains all in-flight frames first so no slot is still reading the old
+ * buffer, then resets the FIFO. Returns 0 on success. */
+int vk_renderer_set_roads(VkCtx *ctx, Renderer *r, const float *road_xy, uint32_t num_verts,
+                          const uint32_t *road_offsets, const uint32_t *road_types, uint32_t num_polys) {
+    /* Before re-uploading the road buffer, make sure no slot is still
+     * reading from it. The simplest correct path is to drain everything
+     * pending — set_roads is called once per episode, before the loop,
+     * so this is essentially free in steady state. */
+    for (int i = 0; i < FRAMES_IN_FLIGHT; ++i) {
+        if (r->slots[i].pending) {
+            vkWaitForFences(ctx->device, 1, &r->slots[i].fence, VK_TRUE, UINT64_MAX);
+            vkResetFences(ctx->device, 1, &r->slots[i].fence);
+            r->slots[i].pending = 0;
+        }
+    }
+    r->head = r->tail = r->n_in_flight = 0;
+
+    /* Keep the buffer non-empty even with zero verts so later binds are valid. */
+    VkDeviceSize required = (VkDeviceSize)num_verts * sizeof(float) * 2;
+    if (required == 0)
+        required = sizeof(float) * 2;
+    int rc = ensure_buffer_capacity(ctx, &r->road_vb, required, VK_BUFFER_USAGE_VERTEX_BUFFER_BIT,
+                                    VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
+    if (rc != 0)
+        return rc;
+    if (num_verts > 0) {
+        memcpy(r->road_vb.mapped, road_xy, (size_t)num_verts * sizeof(float) * 2);
+    }
+    r->road_vb_capacity = num_verts;
+
+    if (num_polys + 1 > r->road_meta_capacity) {
+        free(r->road_offsets);
+        free(r->road_types);
+        r->road_offsets = (uint32_t *)malloc(sizeof(uint32_t) * (num_polys + 1));
+        r->road_types = (uint32_t *)malloc(sizeof(uint32_t) * num_polys);
+        /* malloc(0) may legitimately return NULL, so only treat a NULL
+         * road_types as failure when num_polys > 0. */
+        if (!r->road_offsets || (num_polys > 0 && !r->road_types)) {
+            /* Keep state consistent: a later call must not see a non-zero
+             * capacity paired with NULL pointers (it would memcpy into
+             * NULL). Reset everything before reporting the failure. */
+            free(r->road_offsets);
+            free(r->road_types);
+            r->road_offsets = NULL;
+            r->road_types = NULL;
+            r->road_meta_capacity = 0;
+            vk_ctx_set_error(ctx, "out of host memory for road metadata");
+            return -1;
+        }
+        /* Only record the new capacity once both allocations succeeded. */
+        r->road_meta_capacity = num_polys + 1;
+    }
+    if (num_polys > 0) {
+        memcpy(r->road_offsets, road_offsets, sizeof(uint32_t) * (num_polys + 1));
+        memcpy(r->road_types, road_types, sizeof(uint32_t) * num_polys);
+    }
+    r->num_polys = num_polys;
+    return 0;
+}
+
+/* ------------------------------- per-frame draw ----------------------------- */
+
+/* Map a road feature type id to the RGBA color it is drawn with.
+ * Unknown types get a neutral opaque gray. */
+static void color_for_road_type(uint32_t type, float out[4]) {
+    float red = 0.45f, green = 0.45f, blue = 0.45f, alpha = 1.0f; /* default */
+    if (type == 6) { /* ROAD_EDGE */
+        red = 0.55f;
+        green = 0.55f;
+        blue = 0.55f;
+    } else if (type == 4) { /* ROAD_LANE */
+        red = 0.85f;
+        green = 0.78f;
+        blue = 0.30f;
+        alpha = 0.6f;
+    } else if (type == 5) { /* ROAD_LINE */
+        red = 0.95f;
+        green = 0.95f;
+        blue = 0.95f;
+        alpha = 0.5f;
+    } else if (type == 10) { /* DRIVEWAY */
+        red = 0.40f;
+        green = 0.40f;
+        blue = 0.55f;
+        alpha = 0.7f;
+    }
+    out[0] = red;
+    out[1] = green;
+    out[2] = blue;
+    out[3] = alpha;
+}
+
+/* Grow a slot's per-frame instance buffer to hold num_instances
+ * AgentInstance records. Safe per-slot: the buffer belongs to this slot
+ * only, so resizing never races a frame in flight on another slot. */
+static int ensure_slot_agent_capacity(VkCtx *ctx, FrameSlot *s, uint32_t num_instances) {
+    VkDeviceSize required = (VkDeviceSize)num_instances * sizeof(AgentInstance);
+    /* Keep the buffer non-empty so vertex binds stay valid with 0 agents. */
+    if (required == 0)
+        required = sizeof(AgentInstance);
+    int rc = ensure_buffer_capacity(ctx, &s->agent_inst_vb, required, VK_BUFFER_USAGE_VERTEX_BUFFER_BIT,
+                                    VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
+    if (rc != 0)
+        return rc;
+    s->agent_inst_capacity = num_instances;
+    return 0;
+}
+
+/* Record a single full-image layout transition + memory barrier on the
+ * color aspect (mip 0, layer 0), using the synchronization2 API. */
+static void barrier_image(VkCommandBuffer cmd, VkImage image, VkImageLayout old_layout, VkImageLayout new_layout,
+                          VkPipelineStageFlags2 src_stage, VkAccessFlags2 src_access, VkPipelineStageFlags2 dst_stage,
+                          VkAccessFlags2 dst_access) {
+    VkImageMemoryBarrier2 imb = {
+        .sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER_2,
+        .srcStageMask = src_stage,
+        .srcAccessMask = src_access,
+        .dstStageMask = dst_stage,
+        .dstAccessMask = dst_access,
+        .oldLayout = old_layout,
+        .newLayout = new_layout,
+        .srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED,
+        .dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED,
+        .image = image,
+        .subresourceRange =
+            {
+                .aspectMask = VK_IMAGE_ASPECT_COLOR_BIT,
+                .baseMipLevel = 0,
+                .levelCount = 1,
+                .baseArrayLayer = 0,
+                .layerCount = 1,
+            },
+    };
+    VkDependencyInfo di = {
+        .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO,
+        .imageMemoryBarrierCount = 1,
+        .pImageMemoryBarriers = &imb,
+    };
+    vkCmdPipelineBarrier2(cmd, &di);
+}
+
+/* Record one view's draws into the slot's command buffer: transition the
+ * color image to attachment layout, clear + draw roads and agent boxes,
+ * then transition to TRANSFER_SRC and copy the image into the render
+ * target's host-visible readback buffer. */
+static void record_view(VkCommandBuffer cmd, Renderer *r, FrameSlot *slot, RenderTarget *rt, const Mat4 *mvp,
+                        uint32_t num_instances) {
+    /* UNDEFINED old layout: previous contents are discarded, which is fine
+     * because the pass clears the whole attachment. */
+    barrier_image(cmd, rt->color.image, VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_COLOR_ATTACHMENT_OPTIMAL,
+                  VK_PIPELINE_STAGE_2_TOP_OF_PIPE_BIT, 0, VK_PIPELINE_STAGE_2_COLOR_ATTACHMENT_OUTPUT_BIT,
+                  VK_ACCESS_2_COLOR_ATTACHMENT_WRITE_BIT);
+
+    VkClearValue clear = {.color = {.float32 = {0.05f, 0.05f, 0.08f, 1.0f}}};
+    VkRenderingAttachmentInfo att = {
+        .sType = VK_STRUCTURE_TYPE_RENDERING_ATTACHMENT_INFO,
+        .imageView = rt->color.view,
+        .imageLayout = VK_IMAGE_LAYOUT_COLOR_ATTACHMENT_OPTIMAL,
+        .loadOp = VK_ATTACHMENT_LOAD_OP_CLEAR,
+        .storeOp = VK_ATTACHMENT_STORE_OP_STORE,
+        .clearValue = clear,
+    };
+    VkRenderingInfo ri = {
+        .sType = VK_STRUCTURE_TYPE_RENDERING_INFO,
+        .renderArea = {.offset = {0, 0}, .extent = {r->width, r->height}},
+        .layerCount = 1,
+        .colorAttachmentCount = 1,
+        .pColorAttachments = &att,
+    };
+    vkCmdBeginRendering(cmd, &ri);
+
+    /* Viewport/scissor are dynamic pipeline state — set them every pass. */
+    VkViewport vp = {
+        .x = 0.0f,
+        .y = 0.0f,
+        .width = (float)r->width,
+        .height = (float)r->height,
+        .minDepth = 0.0f,
+        .maxDepth = 1.0f,
+    };
+    VkRect2D sc = {.offset = {0, 0}, .extent = {r->width, r->height}};
+    vkCmdSetViewport(cmd, 0, 1, &vp);
+    vkCmdSetScissor(cmd, 0, 1, &sc);
+
+    if (r->num_polys > 0) {
+        vkCmdBindPipeline(cmd, VK_PIPELINE_BIND_POINT_GRAPHICS, r->pipelines->line_pipeline);
+        VkDeviceSize voff = 0;
+        vkCmdBindVertexBuffers(cmd, 0, 1, &r->road_vb.buffer, &voff);
+
+        PushConstants pc;
+        memcpy(pc.mvp, mvp->m, sizeof(pc.mvp));
+
+        /* One LINE_STRIP draw per polyline; firstVertex selects the run. */
+        for (uint32_t i = 0; i < r->num_polys; ++i) {
+            uint32_t start = r->road_offsets[i];
+            uint32_t end = r->road_offsets[i + 1];
+            if (end <= start + 1)
+                continue; /* a strip needs at least 2 vertices */
+
+            color_for_road_type(r->road_types[i], pc.color);
+            vkCmdPushConstants(cmd, r->pipelines->layout, VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT, 0,
+                               sizeof(pc), &pc);
+            vkCmdDraw(cmd, end - start, 1, start, 0);
+        }
+    }
+
+    if (num_instances > 0) {
+        vkCmdBindPipeline(cmd, VK_PIPELINE_BIND_POINT_GRAPHICS, r->pipelines->box_pipeline);
+        VkBuffer vbufs[2] = {r->unit_quad_vb.buffer, slot->agent_inst_vb.buffer};
+        VkDeviceSize voffs[2] = {0, 0};
+        vkCmdBindVertexBuffers(cmd, 0, 2, vbufs, voffs);
+        vkCmdBindIndexBuffer(cmd, r->unit_quad_ib.buffer, 0, VK_INDEX_TYPE_UINT16);
+
+        PushConstants pc;
+        memcpy(pc.mvp, mvp->m, sizeof(pc.mvp));
+        pc.color[0] = pc.color[1] = pc.color[2] = pc.color[3] = 1.0f; /* no tint */
+        vkCmdPushConstants(cmd, r->pipelines->layout, VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT, 0,
+                           sizeof(pc), &pc);
+
+        /* 6 indices (one quad), one instance per active agent. */
+        vkCmdDrawIndexed(cmd, 6, num_instances, 0, 0, 0);
+    }
+
+    vkCmdEndRendering(cmd);
+
+    barrier_image(cmd, rt->color.image, VK_IMAGE_LAYOUT_COLOR_ATTACHMENT_OPTIMAL, VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL,
+                  VK_PIPELINE_STAGE_2_COLOR_ATTACHMENT_OUTPUT_BIT, VK_ACCESS_2_COLOR_ATTACHMENT_WRITE_BIT,
+                  VK_PIPELINE_STAGE_2_COPY_BIT, VK_ACCESS_2_TRANSFER_READ_BIT);
+
+    /* bufferRowLength/bufferImageHeight = 0 means tightly packed rows,
+     * matching rt->row_pitch_bytes = width * 4. */
+    VkBufferImageCopy region = {
+        .bufferOffset = 0,
+        .bufferRowLength = 0,
+        .bufferImageHeight = 0,
+        .imageSubresource =
+            {
+                .aspectMask = VK_IMAGE_ASPECT_COLOR_BIT,
+                .mipLevel = 0,
+                .baseArrayLayer = 0,
+                .layerCount = 1,
+            },
+        .imageOffset = {0, 0, 0},
+        .imageExtent = {r->width, r->height, 1},
+    };
+    /* BUGFIX: the address-of operator was mojibake'd to "®ion" in the
+     * original ("&reg" rendered as the registered-trademark sign). */
+    vkCmdCopyImageToBuffer(cmd, rt->color.image, VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL, rt->readback.buffer, 1, &region);
+}
+
+/* ------------------------------ episode lifecycle ------------------------------ */
+
+/* Arm the renderer for one episode: stash the ffmpeg output pipes (the
+ * drain path writes frames to them directly) and reset the slot ring. */
+void vk_renderer_episode_begin(Renderer *r, FfmpegPipe *pipe_topdown, FfmpegPipe *pipe_bev) {
+    r->ep_pipe_topdown = pipe_topdown;
+    r->ep_pipe_bev = pipe_bev;
+
+    /* set_roads normally resets the FIFO already; stay defensive in case
+     * episode_begin is called without it. */
+    r->head = 0;
+    r->tail = 0;
+    r->n_in_flight = 0;
+    for (int i = 0; i < FRAMES_IN_FLIGHT; ++i) {
+        FrameSlot *s = &r->slots[i];
+        s->pending = 0;
+        s->rendered_topdown = 0;
+        s->rendered_bev = 0;
+    }
+}
+
+/* Wait on the slot at head, fwrite its readback buffers to ffmpeg, and
+ * advance head. Returns 0 on success, non-zero on ffmpeg failure. */
+static int drain_head(VkCtx *ctx, Renderer *r) {
+    FrameSlot *s = &r->slots[r->head];
+    if (!s->pending) {
+        /* Defensive: shouldn't happen if n_in_flight is accurate. */
+        return 0;
+    }
+    VK_CHECK(vkWaitForFences(ctx->device, 1, &s->fence, VK_TRUE, UINT64_MAX));
+    VK_CHECK(vkResetFences(ctx->device, 1, &s->fence));
+
+    /* BUGFIX: consume the slot BEFORE attempting the ffmpeg writes. The
+     * fence is already waited on and reset above; if a write failed while
+     * pending was still set, any retry (e.g. a second episode_end) would
+     * block forever in vkWaitForFences on the now-unsignaled fence. */
+    int wrote_topdown = s->rendered_topdown;
+    int wrote_bev = s->rendered_bev;
+    int slot_idx = r->head;
+    s->pending = 0;
+    s->rendered_topdown = 0;
+    s->rendered_bev = 0;
+    r->head = (r->head + 1) % FRAMES_IN_FLIGHT;
+    r->n_in_flight--;
+
+    if (wrote_topdown && r->ep_pipe_topdown) {
+        if (ffmpeg_pipe_write_frame(r->ep_pipe_topdown, s->rt_topdown.readback.mapped) != 0) {
+            vk_ctx_set_error(ctx, "ffmpeg write failed (top-down) at slot %d", slot_idx);
+            return -1;
+        }
+    }
+    if (wrote_bev && r->ep_pipe_bev) {
+        if (ffmpeg_pipe_write_frame(r->ep_pipe_bev, s->rt_bev.readback.mapped) != 0) {
+            vk_ctx_set_error(ctx, "ffmpeg write failed (bev) at slot %d", slot_idx);
+            return -1;
+        }
+    }
+    return 0;
+}
+
+/* Record and submit one frame into the next ring slot. Renders a view for
+ * each non-NULL MVP (top-down and/or BEV). If the ring is full, the oldest
+ * slot is drained (its frames written to ffmpeg) before being reused.
+ * Returns 0 on success; non-zero on Vulkan or ffmpeg failure. */
+int vk_renderer_submit_frame(VkCtx *ctx, Renderer *r, const AgentInstance *instances, uint32_t num_instances,
+                             const Mat4 *mvp_topdown, const Mat4 *mvp_bev) {
+    /* If the ring is full, drain the oldest before reusing its slot. */
+    if (r->n_in_flight == FRAMES_IN_FLIGHT) {
+        int rc = drain_head(ctx, r);
+        if (rc != 0)
+            return rc;
+    }
+
+    FrameSlot *s = &r->slots[r->tail];
+
+    /* Upload agent instances into THIS slot's buffer (not a shared one),
+     * so the GPU executing the previous frame on a different slot is not
+     * disturbed. */
+    if (num_instances > 0) {
+        int rc = ensure_slot_agent_capacity(ctx, s, num_instances);
+        if (rc != 0)
+            return rc;
+        memcpy(s->agent_inst_vb.mapped, instances, (size_t)num_instances * sizeof(AgentInstance));
+    }
+
+    /* The slot's fence was waited + reset by drain_head (or never
+     * submitted), so its command buffer is free to reset and re-record. */
+    VK_CHECK(vkResetCommandBuffer(s->cmd, 0));
+    VkCommandBufferBeginInfo bi = {
+        .sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO,
+        .flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT,
+    };
+    VK_CHECK(vkBeginCommandBuffer(s->cmd, &bi));
+
+    if (mvp_topdown) {
+        record_view(s->cmd, r, s, &s->rt_topdown, mvp_topdown, num_instances);
+        s->rendered_topdown = 1;
+    }
+    if (mvp_bev) {
+        record_view(s->cmd, r, s, &s->rt_bev, mvp_bev, num_instances);
+        s->rendered_bev = 1;
+    }
+
+    /* Memory barrier so the host can safely read the readback buffers
+     * once the fence signals. */
+    VkMemoryBarrier2 mb = {
+        .sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER_2,
+        .srcStageMask = VK_PIPELINE_STAGE_2_COPY_BIT,
+        .srcAccessMask = VK_ACCESS_2_TRANSFER_WRITE_BIT,
+        .dstStageMask = VK_PIPELINE_STAGE_2_HOST_BIT,
+        .dstAccessMask = VK_ACCESS_2_HOST_READ_BIT,
+    };
+    VkDependencyInfo di = {
+        .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO,
+        .memoryBarrierCount = 1,
+        .pMemoryBarriers = &mb,
+    };
+    vkCmdPipelineBarrier2(s->cmd, &di);
+
+    VK_CHECK(vkEndCommandBuffer(s->cmd));
+
+    VkCommandBufferSubmitInfo csi = {
+        .sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_SUBMIT_INFO,
+        .commandBuffer = s->cmd,
+    };
+    VkSubmitInfo2 si = {
+        .sType = VK_STRUCTURE_TYPE_SUBMIT_INFO_2,
+        .commandBufferInfoCount = 1,
+        .pCommandBufferInfos = &csi,
+    };
+    VK_CHECK(vkQueueSubmit2(ctx->graphics_queue, 1, &si, s->fence));
+
+    /* Mark the slot in flight and advance the producer end of the FIFO. */
+    s->pending = 1;
+    r->tail = (r->tail + 1) % FRAMES_IN_FLIGHT;
+    r->n_in_flight++;
+    return 0;
+}
+
+/* Flush every still-in-flight frame to ffmpeg (oldest first), then drop
+ * the pipe pointers so stale writes are impossible after the episode. */
+int vk_renderer_episode_end(VkCtx *ctx, Renderer *r) {
+    while (r->n_in_flight > 0) {
+        int status = drain_head(ctx, r);
+        if (status != 0)
+            return status;
+    }
+    r->ep_pipe_topdown = NULL;
+    r->ep_pipe_bev = NULL;
+    return 0;
+}
diff --git a/pufferlib/ocean/drive/trajviz/vk_renderer.h b/pufferlib/ocean/drive/trajviz/vk_renderer.h
new file mode 100644
index 0000000000..cb358b8a76
--- /dev/null
+++ b/pufferlib/ocean/drive/trajviz/vk_renderer.h
@@ -0,0 +1,144 @@
+/*
+ * vk_renderer.h — pipelined per-episode render state.
+ *
+ * The renderer keeps a small ring of frame slots — each one a complete
+ * snapshot of "everything the GPU needs to render one frame": its own
+ * command buffer, fence, render-target images, host readback buffers,
+ * and per-frame instance vertex buffer. The CPU can record frame N+1
+ * into slot S+1 while the GPU runs frame N on slot S and the host reads
+ * the readback for frame N-1 from slot S-1, so per-submit latency
+ * (vkQueueSubmit + vkWaitForFences scheduler wakeup) is amortized
+ * across FRAMES_IN_FLIGHT frames.
+ *
+ * Static, episode-level data (road geometry, polyline metadata) lives
+ * outside the slots — it's read by every slot and never written during
+ * the loop, so no synchronization is needed.
+ *
+ * Episode lifecycle:
+ *
+ * vk_renderer_episode_begin(r, ffmpeg_td, ffmpeg_bev);
+ * for each frame:
+ * vk_renderer_submit_frame(r, instances, n, mvp_td, mvp_bev);
+ * // returns immediately after submitting; may have internally
+ * // drained an older slot and fwritten its readback to ffmpeg
+ * vk_renderer_episode_end(r);
+ * // drains the remaining FRAMES_IN_FLIGHT - 1 in-flight frames
+ *
+ * The renderer holds the ffmpeg pipe pointers for the duration of an
+ * episode so submit_frame and episode_end can write directly to them
+ * during the drain phase, without going back through the orchestrator.
+ */
+
+#ifndef VK_RENDERER_H
+#define VK_RENDERER_H
+
+#include "vk_context.h"
+#include "vk_pipeline.h"
+#include "vk_math.h"
+#include "ffmpeg_pipe.h"
+
+/* The structs below use uint32_t and size_t. A bare "#include" with no
+ * header name is a preprocessor error — the angle-bracket header name was
+ * evidently stripped; restore the stdlib headers. */
+#include <stdint.h>
+#include <stddef.h>
+
+/* The frames-in-flight ring is currently a no-op (=1) — empirical
+ * timings showed that on this Vulkan path, neither the per-fence wait
+ * nor the per-submit latency is the dominant cost (we tested up to 16
+ * slots and saw no improvement). The episode_begin/submit_frame/end
+ * API shape is preserved because the batched renderer in
+ * vk_batch_renderer.{h,c} still relies on a coordinated drain phase. */
+#define FRAMES_IN_FLIGHT 1
+
+/* Buffer + memory pair, optionally persistently mapped. */
+typedef struct VkBufferM {
+    VkBuffer buffer;
+    VkDeviceMemory memory;
+    void *mapped; /* host pointer when persistently mapped; presumably NULL otherwise — confirm in vk_renderer.c */
+    VkDeviceSize size;
+} VkBufferM;
+
+/* Image + memory + view bundle. */
+typedef struct VkImageM {
+    VkImage image;
+    VkDeviceMemory memory;
+    VkImageView view;
+    uint32_t width, height;
+    VkFormat format;
+} VkImageM;
+
+/* One render target = one rendered view. The readback buffer is
+ * persistently mapped so we fwrite directly from VRAM-staged DMA. */
+typedef struct RenderTarget {
+    VkImageM color;
+    VkBufferM readback;
+    size_t row_pitch_bytes; /* bytes per image row in the readback copy; may exceed width*bpp due to alignment — TODO confirm against the copy in vk_renderer.c */
+} RenderTarget;
+
+/* One slot in the frames-in-flight ring. Holds everything that varies
+ * per-frame, so frame N+1 doesn't stomp on data the GPU is still
+ * reading for frame N. */
+typedef struct FrameSlot {
+    VkCommandBuffer cmd;
+    VkFence fence; /* signals when this slot's GPU work is done */
+    RenderTarget rt_topdown;
+    RenderTarget rt_bev;
+    VkBufferM agent_inst_vb;
+    uint32_t agent_inst_capacity; /* in instances */
+    int pending; /* 1 if a submit on this slot is in flight */
+    int rendered_topdown; /* did we draw the topdown view this frame? */
+    int rendered_bev;
+} FrameSlot;
+
+typedef struct Renderer {
+    Pipelines *pipelines; /* borrowed */
+    uint32_t width, height;
+
+    /* Static geometry, set up once at init. */
+    VkBufferM unit_quad_vb;
+    VkBufferM unit_quad_ib;
+
+    /* Per-episode geometry (constant across all frames in an episode).
+     * Read by every slot's command buffer; never written during the
+     * pipelined loop, so no per-slot duplication needed. */
+    VkBufferM road_vb;
+    uint32_t road_vb_capacity;
+    uint32_t *road_offsets; /* (num_polys+1,) host copy */
+    uint32_t *road_types;   /* (num_polys,) */
+    uint32_t num_polys;
+    uint32_t road_meta_capacity;
+
+    /* Frames-in-flight ring + FIFO indices. */
+    FrameSlot slots[FRAMES_IN_FLIGHT];
+    int head; /* next slot to drain */
+    int tail; /* next slot to write */
+    int n_in_flight;
+
+    /* Ffmpeg pipes for the current episode. Owned by the caller; the
+     * renderer just borrows the pointers between episode_begin and
+     * episode_end. NULL = view disabled for this episode. */
+    FfmpegPipe *ep_pipe_topdown;
+    FfmpegPipe *ep_pipe_bev;
+} Renderer;
+
+/* Create/destroy all frame slots and static geometry.
+ * NOTE(review): assumed to return 0 on success like the other int-returning
+ * entry points below — confirm in vk_renderer.c. */
+int vk_renderer_init(VkCtx *ctx, Pipelines *p, Renderer *r, uint32_t width, uint32_t height);
+void vk_renderer_destroy(VkCtx *ctx, Renderer *r);
+
+/* Upload road geometry for a new episode. The data is copied; pointers
+ * are not retained past this call. */
+int vk_renderer_set_roads(VkCtx *ctx, Renderer *r, const float *road_xy, uint32_t num_verts,
+                          const uint32_t *road_offsets, const uint32_t *road_types, uint32_t num_polys);
+
+/* Begin an episode. Stores the ffmpeg pipe pointers and resets the
+ * frames-in-flight FIFO. Either pipe may be NULL to disable that view. */
+void vk_renderer_episode_begin(Renderer *r, FfmpegPipe *pipe_topdown, FfmpegPipe *pipe_bev);
+
+/* Submit one frame. May internally wait on the oldest pending slot and
+ * fwrite its readback buffers to the ffmpeg pipes set in episode_begin
+ * before reusing it for this frame. Returns immediately after the
+ * submit completes (does not wait on the just-submitted frame).
+ * A NULL mvp pointer skips that view for this frame. */
+int vk_renderer_submit_frame(VkCtx *ctx, Renderer *r, const AgentInstance *instances, uint32_t num_instances,
+                             const Mat4 *mvp_topdown, const Mat4 *mvp_bev);
+
+/* Drain remaining in-flight frames at end of episode, fwriting each to
+ * the ffmpeg pipes in submission order. Returns 0 on success or the
+ * first drain error. */
+int vk_renderer_episode_end(VkCtx *ctx, Renderer *r);
+
+#endif
diff --git a/pufferlib/ocean/env_binding.h b/pufferlib/ocean/env_binding.h
index d4090ea461..01252107cd 100644
--- a/pufferlib/ocean/env_binding.h
+++ b/pufferlib/ocean/env_binding.h
@@ -993,6 +993,99 @@ static PyObject *vec_get_global_ground_truth_trajectories(PyObject *self, PyObje
Py_RETURN_NONE;
}
+// Copy per-step sim trajectory data from the vectorized env into preallocated
+// numpy arrays. Args: (vec_env, x_arr, y_arr, z_arr, heading_arr, lengths_arr, ep_len).
+// x/y/z/heading_arr are float32 shape (total_agents, ep_len); lengths_arr is int32
+// shape (total_agents,). Iterates sub-envs in order, concatenating by agent offset.
+static PyObject *vec_get_sim_trajectories(PyObject *self, PyObject *args) {
+    if (PyTuple_Size(args) != 7) {
+        PyErr_SetString(PyExc_TypeError, "vec_get_sim_trajectories requires 7 arguments");
+        return NULL;
+    }
+
+    VecEnv *vec = unpack_vecenv(args);
+    if (!vec)
+        return NULL;
+
+    /* Borrowed references; caller guarantees dtype/shape per the contract
+     * in the comment above (float32 (total_agents, ep_len); int32
+     * (total_agents,)). */
+    PyArrayObject *x_arr = (PyArrayObject *)PyTuple_GetItem(args, 1);
+    PyArrayObject *y_arr = (PyArrayObject *)PyTuple_GetItem(args, 2);
+    PyArrayObject *z_arr = (PyArrayObject *)PyTuple_GetItem(args, 3);
+    PyArrayObject *heading_arr = (PyArrayObject *)PyTuple_GetItem(args, 4);
+    PyArrayObject *lengths_arr = (PyArrayObject *)PyTuple_GetItem(args, 5);
+    int ep_len = (int)PyLong_AsLong(PyTuple_GetItem(args, 6));
+    /* PyLong_AsLong returns -1 with a Python exception set when arg 6 is
+     * not an int; bail out instead of doing pointer arithmetic with a
+     * garbage ep_len while an exception is pending. */
+    if (ep_len == -1 && PyErr_Occurred())
+        return NULL;
+
+    float *x_base = (float *)PyArray_DATA(x_arr);
+    float *y_base = (float *)PyArray_DATA(y_arr);
+    float *z_base = (float *)PyArray_DATA(z_arr);
+    float *heading_base = (float *)PyArray_DATA(heading_arr);
+    int *lengths_base = (int *)PyArray_DATA(lengths_arr);
+
+    /* Sub-envs are concatenated back-to-back by agent offset. */
+    int offset = 0;
+    for (int i = 0; i < vec->num_envs; i++) {
+        Drive *drive = (Drive *)vec->envs[i];
+        c_get_sim_trajectories(drive, &x_base[offset * ep_len], &y_base[offset * ep_len], &z_base[offset * ep_len],
+                               &heading_base[offset * ep_len], &lengths_base[offset], ep_len);
+        offset += drive->active_agent_count;
+    }
+
+    Py_RETURN_NONE;
+}
+
+// Return (world_mean_x, world_mean_y, world_mean_z) from env 0 ONLY.
+//
+// IMPORTANT: each sub-env in a vec has its OWN world_mean, computed in
+// set_means() from its own map's road + agent points. Different maps have
+// different world_means (potentially many kilometers apart in source-Waymo
+// coordinates). This function's return value is therefore only correct
+// for env 0; consumers that need to align other envs' trajectories with
+// their source maps must call vec_get_all_world_means below instead.
+//
+// Kept for backwards compatibility with code that historically assumed a
+// single shared world_mean (e.g. older saved trajectories_*.npz files).
+static PyObject *vec_get_world_mean(PyObject *self, PyObject *args) {
+    VecEnv *vec = unpack_vecenv(args);
+    if (!vec)
+        return NULL;
+    /* Env 0 only — see the caveat in the comment above: every sub-env has
+     * its own world_mean, so this value is only correct for env 0's map. */
+    Drive *drive = (Drive *)vec->envs[0];
+    /* "(fff)" builds a 3-tuple of Python floats from the C float fields. */
+    return Py_BuildValue("(fff)", drive->world_mean_x, drive->world_mean_y, drive->world_mean_z);
+}
+
+// Fill an (num_envs, 3) float32 numpy array with each sub-env's world_mean.
+// This is the function callers should use when they need to align per-env
+// trajectories with their source maps (the sub-envs may carry different
+// maps, each with its own centering offset).
+static PyObject *vec_get_all_world_means(PyObject *self, PyObject *args) {
+    if (PyTuple_Size(args) != 2) {
+        PyErr_SetString(PyExc_TypeError, "vec_get_all_world_means requires 2 arguments (vec, out)");
+        return NULL;
+    }
+    PyObject *vec_caps = PyTuple_GetItem(args, 0);                       /* borrowed */
+    PyArrayObject *out_arr = (PyArrayObject *)PyTuple_GetItem(args, 1);  /* borrowed */
+
+    /* unpack_vecenv expects a 1-tuple of (capsule,), so re-wrap the
+     * borrowed capsule. PyTuple_Pack takes its own reference, so the
+     * Pack/DECREF pair leaves vec_caps' refcount unchanged. */
+    PyObject *single_arg = PyTuple_Pack(1, vec_caps);
+    if (!single_arg)
+        return NULL;
+    VecEnv *vec = unpack_vecenv(single_arg);
+    Py_DECREF(single_arg);
+    if (!vec)
+        return NULL;
+
+    /* Validate before touching raw data: out must be exactly
+     * (num_envs, 3) float32. C-contiguity is assumed by the flat indexing
+     * below — TODO confirm callers always pass a contiguous array. */
+    if (!PyArray_Check(out_arr) || PyArray_TYPE(out_arr) != NPY_FLOAT32 || PyArray_NDIM(out_arr) != 2 ||
+        PyArray_DIM(out_arr, 0) != vec->num_envs || PyArray_DIM(out_arr, 1) != 3) {
+        PyErr_Format(PyExc_ValueError, "out must be a (num_envs=%d, 3) float32 array", vec->num_envs);
+        return NULL;
+    }
+
+    float *base = (float *)PyArray_DATA(out_arr);
+    for (int i = 0; i < vec->num_envs; i++) {
+        Drive *drive = (Drive *)vec->envs[i];
+        base[i * 3 + 0] = drive->world_mean_x;
+        base[i * 3 + 1] = drive->world_mean_y;
+        base[i * 3 + 2] = drive->world_mean_z;
+    }
+    Py_RETURN_NONE;
+}
+
static PyObject *vec_get_road_edge_counts(PyObject *self, PyObject *args) {
VecEnv *vec = unpack_vecenv(args);
if (!vec)
@@ -1131,6 +1224,12 @@ static PyMethodDef methods[] = {
"Get road edge polyline counts from vectorized env"},
{"vec_get_road_edge_polylines", vec_get_road_edge_polylines, METH_VARARGS,
"Get road edge polylines from vectorized env"},
+ {"vec_get_sim_trajectories", vec_get_sim_trajectories, METH_VARARGS,
+ "Get per-step sim trajectories from vectorized env"},
+ {"vec_get_world_mean", vec_get_world_mean, METH_VARARGS,
+ "Get world mean (x,y,z) from first sub-env (legacy; per-env world_means differ)"},
+ {"vec_get_all_world_means", vec_get_all_world_means, METH_VARARGS,
+ "Fill (num_envs, 3) float32 array with each sub-env's world_mean"},
{"env_log", env_log, METH_VARARGS, "Log a single environment"},
MY_METHODS,
{NULL, NULL, 0, NULL}};
diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py
index e9d84b41f9..1e0f6b322f 100644
--- a/pufferlib/pufferl.py
+++ b/pufferlib/pufferl.py
@@ -506,6 +506,10 @@ def train(self):
if self.epoch % config["checkpoint_interval"] == 0 or done_training:
self.save_checkpoint()
+ self.save_trajectories()
+ # Snapshot reproducibility artifacts once per run on the first checkpoint.
+ if self.epoch == config["checkpoint_interval"]:
+ self.save_reproducibility()
self.msg = f"Checkpoint saved at update {self.epoch}"
if self.epoch % self.config["eval"]["eval_interval"] == 0 or done_training:
@@ -672,6 +676,157 @@ def save_checkpoint(self):
os.rename(state_path + ".tmp", state_path)
return model_path
+ def save_trajectories(self):
+ """Save per-checkpoint rollout trajectories + map context as a compressed npz.
+
+ Dumps the rolling policy-side buffers (actions, rewards, values, logprobs,
+ terminals, truncations) along with C-side per-step agent xyz/heading recorded
+ during the current episode, plus map context (map_ids, map_files, agent_offsets,
+ world_mean) needed to align coordinates back to the source map for offline
+ rendering.
+
+ Multiprocessing path: fan out via ``vecenv.save_worker_trajectories`` — each
+ worker writes its own npz into a temp dir, and we concatenate here.
+
+ Serial / native PufferEnv path: read directly from the driver env.
+
+ Opt out by setting ``save_trajectories: False`` in the train config.
+ """
+ if not self.config.get("save_trajectories", True):
+ return
+ if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
+ return
+
+ run_id = self.logger.run_id
+ path = os.path.join(self.config["data_dir"], f"{self.config['env']}_{run_id}")
+ os.makedirs(path, exist_ok=True)
+ traj_path = os.path.join(path, f"trajectories_{self.epoch:06d}.npz")
+
+ data = {
+ "actions": self.actions.cpu().numpy(),
+ "rewards": self.rewards.cpu().numpy(),
+ "terminals": self.terminals.cpu().numpy(),
+ "truncations": self.truncations.cpu().numpy(),
+ "values": self.values.cpu().numpy(),
+ "logprobs": self.logprobs.cpu().numpy(),
+ "epoch": self.epoch,
+ "global_step": self.global_step,
+ }
+
+ try:
+ driver_env = getattr(self.vecenv, "driver_env", None)
+
+ # Multiprocessing: fan out notify() to workers, then stitch their files.
+ if hasattr(self.vecenv, "save_worker_trajectories"):
+ traj_tmp = getattr(driver_env, "_traj_save_dir", None) if driver_env else None
+ if traj_tmp:
+ self.vecenv.save_worker_trajectories()
+ worker_files = sorted(glob.glob(os.path.join(traj_tmp, "traj_worker_*.npz")))
+ if worker_files:
+ all_traj = {}
+ map_files = None
+ world_mean = None
+ all_world_means = [] # per-env, concatenated across workers
+ for f in worker_files:
+ d = np.load(f, allow_pickle=True)
+ for k in ("x", "y", "z", "heading", "lengths", "map_ids"):
+ if k in d:
+ all_traj.setdefault(k, []).append(d[k])
+ if map_files is None and "map_files" in d:
+ map_files = d["map_files"]
+ if world_mean is None and "world_mean" in d:
+ world_mean = d["world_mean"]
+ if "world_means" in d:
+ all_world_means.append(d["world_means"])
+ for k, v in all_traj.items():
+ key = f"traj_{k}" if k in ("x", "y", "z", "heading", "lengths") else k
+ data[key] = np.concatenate(v)
+ if map_files is not None:
+ data["map_files"] = map_files
+ if world_mean is not None:
+ data["world_mean"] = world_mean
+ # Concatenate per-env world_means across workers so
+ # offline tooling can align each env's trajectory
+ # with its own source map. See drive.py docstring on
+ # Drive.get_world_means for why this matters.
+ if all_world_means:
+ data["world_means"] = np.concatenate(all_world_means, axis=0)
+
+ # Serial / native PufferEnv: read directly from the driver.
+ elif driver_env is not None and hasattr(driver_env, "get_sim_trajectories"):
+ traj = driver_env.get_sim_trajectories()
+ for k, v in traj.items():
+ data[f"traj_{k}"] = v
+ if hasattr(driver_env, "map_ids"):
+ data["map_ids"] = np.array(driver_env.map_ids, dtype=np.int32)
+ if hasattr(driver_env, "agent_offsets"):
+ data["agent_offsets"] = np.array(driver_env.agent_offsets, dtype=np.int32)
+ if hasattr(driver_env, "map_files"):
+ data["map_files"] = np.array([str(f) for f in driver_env.map_files])
+ if hasattr(driver_env, "world_mean"):
+ data["world_mean"] = np.array(driver_env.world_mean, dtype=np.float32)
+ if hasattr(driver_env, "get_world_means"):
+ data["world_means"] = driver_env.get_world_means()
+ except Exception as e:
+ print(f"Warning: save_trajectories failed to collect C-side data: {e}")
+
+ np.savez_compressed(traj_path, **data)
+ print(f"Saved trajectories to {traj_path}")
+
+    def save_reproducibility(self):
+        """Snapshot the compiled .so, key source files, config, and git info once per run.
+
+        Intended for exact experiment replay: you can reproduce a checkpoint's
+        behavior by checking out the git commit (or applying the saved diff)
+        and loading the saved .so. Called once on the first checkpoint of a run.
+        """
+        # Under DDP only rank 0 writes; other ranks would race on the same paths.
+        if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
+            return
+
+        run_id = self.logger.run_id
+        base_path = os.path.join(self.config["data_dir"], f"{self.config['env']}_{run_id}")
+        repro_path = os.path.join(base_path, f"reproducibility_{self.epoch:06d}")
+        os.makedirs(repro_path, exist_ok=True)
+
+        # Exact compiled binaries the run is using.
+        # NOTE(review): the relative paths here and below assume the process
+        # cwd is the repo root — if not, the glob silently matches nothing;
+        # confirm against how training is launched.
+        for so_file in glob.glob("pufferlib/ocean/drive/*.so"):
+            shutil.copy2(so_file, repro_path)
+
+        # Key sources, snapshotted flat (basename only) next to the .so.
+        source_files = [
+            "pufferlib/ocean/drive/drive.h",
+            "pufferlib/ocean/drive/datatypes.h",
+            "pufferlib/ocean/drive/binding.c",
+            "pufferlib/ocean/drive/drive.py",
+            "pufferlib/ocean/drive/drivenet.h",
+            "pufferlib/ocean/env_binding.h",
+            "pufferlib/config/ocean/drive.ini",
+            "pufferlib/pufferl.py",
+        ]
+        for src in source_files:
+            if os.path.exists(src):
+                shutil.copy2(src, os.path.join(repro_path, os.path.basename(src)))
+
+        # Best-effort git snapshot: commit hash plus the working-tree diff.
+        # Any failure (git missing, not a repo, timeout) is recorded in the
+        # output file rather than raised — reproducibility capture must never
+        # abort a training run.
+        git_info_path = os.path.join(repro_path, "git_info.txt")
+        try:
+            commit = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True, timeout=5)
+            diff = subprocess.run(["git", "diff"], capture_output=True, text=True, timeout=10)
+            with open(git_info_path, "w") as f:
+                f.write(f"commit: {commit.stdout.strip()}\n\n")
+                f.write("diff:\n")
+                f.write(diff.stdout)
+        except Exception as e:
+            with open(git_info_path, "w") as f:
+                f.write(f"Could not capture git info: {e}\n")
+
+        # Flat-text config dump (sorted for stable diffs between runs).
+        try:
+            config_path = os.path.join(repro_path, "training_config.txt")
+            with open(config_path, "w") as f:
+                for k, v in sorted(self.config.items()):
+                    f.write(f"{k}: {v}\n")
+        except Exception as e:
+            print(f"Warning: could not snapshot config: {e}")
+
+        print(f"Saved reproducibility artifacts to {repro_path}")
+
def print_dashboard(self, clear=False, idx=[0], c1="[cyan]", c2="[white]", b1="[bright_cyan]", b2="[bright_white]"):
config = self.config
sps = dist_sum(self.sps, config["device"])
@@ -1045,6 +1200,14 @@ def download(self):
def train(env_name, args=None, vecenv=None, policy=None, logger=None):
args = args or load_config(env_name)
+ # If trajectory saving is on, pre-create a shared scratch dir under data_dir
+ # and thread it into args["env"] so every worker inherits it via env_kwargs.
+ # PuffeRL.save_trajectories() later globs traj_worker_*.npz from this dir.
+ if args.get("train", {}).get("save_trajectories", True) and "data_dir" in args.get("train", {}):
+ traj_save_dir = os.path.join(args["train"]["data_dir"], "traj_tmp")
+ os.makedirs(traj_save_dir, exist_ok=True)
+ args.setdefault("env", {})["traj_save_dir"] = traj_save_dir
+
# Assume TorchRun DDP is used if LOCAL_RANK is set
if "LOCAL_RANK" in os.environ:
world_size = int(os.environ.get("WORLD_SIZE", 1))
diff --git a/pufferlib/vector.py b/pufferlib/vector.py
index bf5dc7460e..f27ff2a650 100644
--- a/pufferlib/vector.py
+++ b/pufferlib/vector.py
@@ -215,6 +215,16 @@ def _worker_process(
else:
envs = Serial(env_creators, env_args, env_kwargs, num_envs, buf=buf, seed=seed * num_envs)
+ # Tag the env(s) with this worker index so env.notify() (e.g. trajectory save
+ # in Drive) can pick a per-worker output filename. Works for both native
+ # PufferEnvs and Serial-wrapped ones.
+ if hasattr(envs, "_worker_idx"):
+ envs._worker_idx = worker_idx
+ if hasattr(envs, "envs"):
+ for env in envs.envs:
+ if hasattr(env, "_worker_idx"):
+ env._worker_idx = worker_idx
+
semaphores = np.ndarray(num_workers, dtype=np.uint8, buffer=shm["semaphores"])
notify = np.ndarray(num_workers, dtype=bool, buffer=shm["notify"])
start = time.time()
@@ -533,6 +543,17 @@ def async_reset(self, seed=0):
def notify(self):
self.buf["notify"][:] = True
+ def save_worker_trajectories(self):
+ """Trigger every worker to call env.notify(), then block until all finish.
+
+ Used by PuffeRL.save_trajectories() to fan out a trajectory-save request
+ across workers. Each worker's env.notify() writes a per-worker npz and
+ clears its own notify flag; we spin until all flags are down.
+ """
+ self.buf["notify"][:] = True
+ while any(self.buf["notify"]):
+ time.sleep(0.01)
+
def close(self):
self.driver_env.close()
for p in self.processes:
diff --git a/setup.py b/setup.py
index ea1692af66..a7fb621b01 100644
--- a/setup.py
+++ b/setup.py
@@ -24,6 +24,10 @@
DEBUG = os.getenv("DEBUG", "0") == "1"
NO_OCEAN = os.getenv("NO_OCEAN", "0") == "1"
NO_TRAIN = os.getenv("NO_TRAIN", "0") == "1"
+# Opt-in: TRAJVIZ=1 builds the Vulkan trajectory renderer as a CPython
+# extension. Requires libvulkan-dev + glslang-tools (apt). See
+# docs/trajviz.md for installation. Default off — most users don't need it.
+TRAJVIZ = os.getenv("TRAJVIZ", "0") == "1"
# Build raylib for your platform
RAYLIB_URL = "https://github.com/raysan5/raylib/releases/download/5.5/"
@@ -268,6 +272,86 @@ def run(self):
c_ext.include_dirs.append("/usr/local/include")
c_ext.extra_link_args.extend(["-L/usr/local/lib", "-llammps"])
+# Optional: Vulkan-backed offline trajectory renderer.
+#
+# Built only when TRAJVIZ=1 is set in the environment, and only if the
+# Vulkan headers and the GLSL→SPIR-V compiler are present. This is a
+# separate extension from drive's binding because:
+# - it doesn't link against raylib (it uses Vulkan)
+# - it doesn't include inih
+# - its dependencies (libvulkan-dev, glslang-tools) are optional
+# Failing to import the extension at runtime is fine — the rest of drive
+# (sim, training, eval) keeps working without it.
+if TRAJVIZ and not NO_OCEAN:
+ import subprocess
+
+ trajviz_dir = "pufferlib/ocean/drive/trajviz"
+
+ # Run the shader build script first so shaders.c exists when the
+ # Extension's source list is materialized at compile time. The script
+ # invokes glslangValidator on each .vert/.frag and writes a generated
+ # ../shaders.c with the SPIR-V blobs as uint32_t arrays.
+ shader_script = os.path.join(trajviz_dir, "shaders", "build_shaders.sh")
+ if not os.path.exists(shader_script):
+ raise RuntimeError(f"TRAJVIZ=1 set but {shader_script} not found")
+ print(f"Compiling trajviz shaders via {shader_script}")
+ try:
+ subprocess.check_call(["bash", shader_script])
+ except subprocess.CalledProcessError as e:
+ raise RuntimeError(
+ "trajviz shader compilation failed. Install glslang-tools "
+ "(sudo apt install glslang-tools) or unset TRAJVIZ to skip."
+ ) from e
+
+ # Vulkan headers location: try the system path, then $VULKAN_SDK/include.
+ vulkan_inc = None
+ for p in ("/usr/include", "/usr/local/include"):
+ if os.path.exists(os.path.join(p, "vulkan", "vulkan.h")):
+ vulkan_inc = p
+ break
+ if vulkan_inc is None:
+ sdk = os.environ.get("VULKAN_SDK")
+ if sdk and os.path.exists(os.path.join(sdk, "include", "vulkan", "vulkan.h")):
+ vulkan_inc = os.path.join(sdk, "include")
+ if vulkan_inc is None:
+ raise RuntimeError(
+ "TRAJVIZ=1 set but vulkan/vulkan.h not found. Install with "
+ "sudo apt install libvulkan-dev or set VULKAN_SDK."
+ )
+
+ trajviz_sources = [
+ os.path.join(trajviz_dir, "_native.c"),
+ os.path.join(trajviz_dir, "trajviz.c"),
+ os.path.join(trajviz_dir, "vk_context.c"),
+ os.path.join(trajviz_dir, "vk_pipeline.c"),
+ os.path.join(trajviz_dir, "vk_renderer.c"),
+ os.path.join(trajviz_dir, "vk_batch_renderer.c"),
+ os.path.join(trajviz_dir, "ffmpeg_pipe.c"),
+ os.path.join(trajviz_dir, "shaders.c"), # generated by build_shaders.sh
+ ]
+
+ trajviz_ext = Extension(
+ "pufferlib.ocean.drive.trajviz._native",
+ sources=trajviz_sources,
+ include_dirs=[numpy.get_include(), trajviz_dir, vulkan_inc],
+ # No raylib, no inih, no torch — pure Vulkan + libc + libm.
+ libraries=["vulkan"],
+ extra_compile_args=[
+ "-DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION",
+ "-O2" if not DEBUG else "-O0",
+ "-g" if DEBUG else "-fno-omit-frame-pointer",
+ "-Wall",
+ "-Wextra",
+ "-Wno-unused-parameter",
+ ]
+ + (["-DTRAJVIZ_DEBUG=1"] if DEBUG else []),
+ extra_link_args=["-fwrapv"],
+ )
+ c_extensions.append(trajviz_ext)
+ # The trajviz package directory needs to be discoverable by setuptools.
+ # find_namespace_packages picks it up automatically since it has an
+ # __init__.py — see pufferlib/ocean/drive/trajviz/__init__.py.
+
# Check if CUDA compiler is available. You need cuda dev, not just runtime.
torch_extensions = []
if not NO_TRAIN: