|
70 | 70 | # Where Aegis DepthVisionStudio stores downloaded models |
71 | 71 | MODELS_DIR = Path.home() / ".aegis-ai" / "models" / "feature-extraction" |
72 | 72 |
|
| 73 | +# TensorRT engine cache directory (engines are GPU-specific) |
| 74 | +TRT_CACHE_DIR = MODELS_DIR / "trt_engines" |
| 75 | + |
73 | 76 | # PyTorch model configs (fallback on non-macOS) |
74 | 77 | PYTORCH_CONFIGS = { |
75 | 78 | "depth-anything-v2-small": { |
@@ -110,6 +113,13 @@ def __init__(self): |
110 | 113 | self.opacity = 0.5 |
111 | 114 | self.blend_mode = "depth_only" # Default for privacy: depth_only anonymizes |
112 | 115 | self._coreml_input_size = COREML_INPUT_SIZE |
| 116 | + # TensorRT state (populated by _load_tensorrt) |
| 117 | + self._trt_context = None |
| 118 | + self._trt_input_name = None |
| 119 | + self._trt_output_name = None |
| 120 | + self._trt_input_tensor = None |
| 121 | + self._trt_output_tensor = None |
| 122 | + self._trt_stream = None |
113 | 123 |
|
114 | 124 | def parse_extra_args(self, parser: argparse.ArgumentParser): |
115 | 125 | parser.add_argument("--model", type=str, default="depth-anything-v2-small", |
@@ -137,6 +147,13 @@ def load_model(self, config: dict) -> dict: |
137 | 147 | except Exception as e: |
138 | 148 | _log(f"CoreML load failed ({e}), falling back to PyTorch", self._tag) |
139 | 149 |
|
| 150 | + # Try TensorRT (fails fast if not installed) |
| 151 | + try: |
| 152 | + info = self._load_tensorrt(model_name, config) |
| 153 | + return info |
| 154 | + except Exception as e: |
| 155 | + _log(f"TensorRT unavailable ({e}), falling back to PyTorch", self._tag) |
| 156 | + |
140 | 157 | # Fallback: PyTorch |
141 | 158 | return self._load_pytorch(model_name, config) |
142 | 159 |
|
@@ -196,6 +213,139 @@ def _download_coreml_model(self, variant_id: str): |
196 | 213 | _log(f"CoreML model download failed: {e}", self._tag) |
197 | 214 | raise |
198 | 215 |
|
| 216 | + # ── TensorRT backend (Windows/Linux NVIDIA) ─────────────────────── |
| 217 | + |
| 218 | + def _load_tensorrt(self, model_name: str, config: dict) -> dict: |
| 219 | + """Load or build a TensorRT FP16 engine for fastest NVIDIA inference.""" |
| 220 | + import torch |
| 221 | + import tensorrt as trt |
| 222 | + |
| 223 | + _log(f"Attempting TensorRT FP16 for {model_name}", self._tag) |
| 224 | + |
| 225 | + cfg = PYTORCH_CONFIGS.get(model_name) |
| 226 | + if not cfg: |
| 227 | + raise ValueError(f"Unknown model: {model_name}") |
| 228 | + |
| 229 | + gpu_tag = torch.cuda.get_device_name(0).replace(" ", "_").lower() |
| 230 | + engine_path = TRT_CACHE_DIR / f"{cfg['filename'].replace('.pth', '')}_fp16_{gpu_tag}.trt" |
| 231 | + |
| 232 | + if engine_path.exists(): |
| 233 | + _log(f"Loading cached TRT engine: {engine_path}", self._tag) |
| 234 | + engine = self._deserialize_engine(engine_path) |
| 235 | + else: |
| 236 | + _log("No cached engine — building from ONNX (30-120s)...", self._tag) |
| 237 | + engine = self._build_trt_engine(cfg, engine_path) |
| 238 | + |
| 239 | + if engine is None: |
| 240 | + raise RuntimeError("TensorRT engine build/load failed") |
| 241 | + |
| 242 | + self._trt_context = engine.create_execution_context() |
| 243 | + self._trt_input_name = engine.get_tensor_name(0) |
| 244 | + self._trt_output_name = engine.get_tensor_name(1) |
| 245 | + |
| 246 | + input_shape = engine.get_tensor_shape(self._trt_input_name) |
| 247 | + fixed_shape = tuple(1 if d == -1 else d for d in input_shape) |
| 248 | + self._trt_context.set_input_shape(self._trt_input_name, fixed_shape) |
| 249 | + |
| 250 | + self._trt_input_tensor = torch.zeros(fixed_shape, dtype=torch.float32, device="cuda") |
| 251 | + actual_out_shape = self._trt_context.get_tensor_shape(self._trt_output_name) |
| 252 | + self._trt_output_tensor = torch.empty(list(actual_out_shape), dtype=torch.float32, device="cuda") |
| 253 | + |
| 254 | + self._trt_context.set_tensor_address(self._trt_input_name, self._trt_input_tensor.data_ptr()) |
| 255 | + self._trt_context.set_tensor_address(self._trt_output_name, self._trt_output_tensor.data_ptr()) |
| 256 | + self._trt_stream = torch.cuda.current_stream().cuda_stream |
| 257 | + |
| 258 | + self.backend = "tensorrt" |
| 259 | + _log(f"TensorRT FP16 engine ready: {engine_path.name}", self._tag) |
| 260 | + return { |
| 261 | + "model": model_name, |
| 262 | + "device": "cuda", |
| 263 | + "blend_mode": self.blend_mode, |
| 264 | + "colormap": config.get("colormap", "inferno"), |
| 265 | + "backend": "tensorrt", |
| 266 | + } |
| 267 | + |
    def _build_trt_engine(self, cfg: dict, engine_path: Path):
        """Export PyTorch → ONNX → build TRT FP16 engine → serialize to disk.

        Args:
            cfg: one entry of PYTORCH_CONFIGS (repo, filename, encoder,
                features, out_channels).
            engine_path: destination file for the serialized engine.

        Returns:
            A deserialized TRT engine on success, or None when the ONNX
            parse or the engine build fails (errors are logged, not raised).
        """
        import torch
        import tensorrt as trt
        from depth_anything_v2.dpt import DepthAnythingV2
        from huggingface_hub import hf_hub_download

        # Download weights and materialize the reference PyTorch model on GPU.
        weights_path = hf_hub_download(cfg["repo"], cfg["filename"])
        pt_model = DepthAnythingV2(
            encoder=cfg["encoder"], features=cfg["features"],
            out_channels=cfg["out_channels"],
        )
        pt_model.load_state_dict(torch.load(weights_path, map_location="cuda", weights_only=True))
        pt_model.to("cuda").eval()

        # 518x518 is the input resolution used throughout this backend
        # (matches the fixed shape pinned in _load_tensorrt).
        dummy = torch.randn(1, 3, 518, 518, device="cuda")
        onnx_path = TRT_CACHE_DIR / f"{cfg['filename'].replace('.pth', '')}.onnx"
        TRT_CACHE_DIR.mkdir(parents=True, exist_ok=True)

        _log(f"Exporting ONNX: {onnx_path.name}", self._tag)
        torch.onnx.export(
            pt_model, dummy, str(onnx_path),
            input_names=["input"], output_names=["depth"],
            dynamic_axes={"input": {0: "batch"}, "depth": {0: "batch"}},
            opset_version=17,
        )
        # Free the PyTorch model's GPU memory before the TRT build, which
        # needs its own workspace allocation.
        del pt_model
        torch.cuda.empty_cache()

        logger = trt.Logger(trt.Logger.WARNING)
        builder = trt.Builder(logger)
        network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        parser = trt.OnnxParser(network, logger)

        _log("Parsing ONNX for TensorRT...", self._tag)
        with open(str(onnx_path), "rb") as f:
            if not parser.parse(f.read()):
                for i in range(parser.num_errors):
                    _log(f" ONNX parse error: {parser.get_error(i)}", self._tag)
                return None

        config = builder.create_builder_config()
        # Cap builder workspace at 1 GiB (tactic scratch memory).
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)

        # The ONNX export marks batch as dynamic (-1); pin it to 1 via an
        # optimization profile with min = opt = max.
        inp = network.get_input(0)
        if any(d == -1 for d in inp.shape):
            profile = builder.create_optimization_profile()
            fixed = tuple(1 if d == -1 else d for d in inp.shape)
            profile.set_shape(inp.name, fixed, fixed, fixed)
            config.add_optimization_profile(profile)

        config.set_flag(trt.BuilderFlag.FP16)

        _log("Building TRT FP16 engine (30-120s)...", self._tag)
        serialized = builder.build_serialized_network(network, config)
        if serialized is None:
            _log("TRT engine build failed!", self._tag)
            return None

        # Persist the engine so later runs skip the expensive build step.
        engine_bytes = bytes(serialized)
        with open(str(engine_path), "wb") as f:
            f.write(engine_bytes)
        _log(f"Engine cached: {engine_path} ({len(engine_bytes) / 1e6:.1f} MB)", self._tag)

        # The ONNX file is only an intermediate artifact; best-effort cleanup.
        try:
            onnx_path.unlink()
        except OSError:
            pass

        runtime = trt.Runtime(logger)
        return runtime.deserialize_cuda_engine(engine_bytes)
| 339 | + |
| 340 | + @staticmethod |
| 341 | + def _deserialize_engine(engine_path: Path): |
| 342 | + """Load a previously serialized TRT engine from disk.""" |
| 343 | + import tensorrt as trt |
| 344 | + logger = trt.Logger(trt.Logger.WARNING) |
| 345 | + runtime = trt.Runtime(logger) |
| 346 | + with open(str(engine_path), "rb") as f: |
| 347 | + return runtime.deserialize_cuda_engine(f.read()) |
| 348 | + |
199 | 349 | # ── PyTorch backend (fallback) ──────────────────────────────────── |
200 | 350 |
|
201 | 351 | def _load_pytorch(self, model_name: str, config: dict) -> dict: |
@@ -242,6 +392,8 @@ def transform_frame(self, image, metadata: dict): |
242 | 392 |
|
243 | 393 | if self.backend == "coreml": |
244 | 394 | depth_colored = self._infer_coreml(image) |
| 395 | + elif self.backend == "tensorrt": |
| 396 | + depth_colored = self._infer_tensorrt(image) |
245 | 397 | else: |
246 | 398 | depth_colored = self._infer_pytorch(image) |
247 | 399 |
|
@@ -308,6 +460,36 @@ def _infer_pytorch(self, image): |
308 | 460 |
|
309 | 461 | return depth_colored |
310 | 462 |
|
| 463 | + def _infer_tensorrt(self, image): |
| 464 | + """Run TensorRT FP16 inference and return colorized depth map.""" |
| 465 | + import torch |
| 466 | + import cv2 |
| 467 | + import numpy as np |
| 468 | + |
| 469 | + original_h, original_w = image.shape[:2] |
| 470 | + rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) |
| 471 | + |
| 472 | + resized = cv2.resize(rgb, (518, 518), interpolation=cv2.INTER_LINEAR) |
| 473 | + img_float = resized.astype(np.float32) / 255.0 |
| 474 | + mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) |
| 475 | + std = np.array([0.229, 0.224, 0.225], dtype=np.float32) |
| 476 | + img_float = (img_float - mean) / std |
| 477 | + img_nchw = np.transpose(img_float, (2, 0, 1))[np.newaxis] |
| 478 | + |
| 479 | + self._trt_input_tensor.copy_(torch.from_numpy(img_nchw)) |
| 480 | + self._trt_context.execute_async_v3(self._trt_stream) |
| 481 | + torch.cuda.synchronize() |
| 482 | + |
| 483 | + depth = self._trt_output_tensor.cpu().numpy() |
| 484 | + depth = np.squeeze(depth) |
| 485 | + |
| 486 | + d_min, d_max = depth.min(), depth.max() |
| 487 | + depth_norm = ((depth - d_min) / (d_max - d_min + 1e-8) * 255).astype(np.uint8) |
| 488 | + depth_colored = cv2.applyColorMap(depth_norm, self.colormap_id) |
| 489 | + depth_colored = cv2.resize(depth_colored, (original_w, original_h)) |
| 490 | + |
| 491 | + return depth_colored |
| 492 | + |
311 | 493 | # ── Config updates ──────────────────────────────────────────────── |
312 | 494 |
|
313 | 495 | def on_config_update(self, config: dict): |
|
0 commit comments