diff --git a/skills/transformation/depth-estimation/models.json b/skills/transformation/depth-estimation/models.json
index 27ee043..d240e34 100644
--- a/skills/transformation/depth-estimation/models.json
+++ b/skills/transformation/depth-estimation/models.json
@@ -77,6 +77,13 @@
       "precision": "float32",
       "size_mb": 99.0,
       "description": "PyTorch ViT-S — CUDA/CPU"
+    },
+    "depth_anything_v2_vits_trt_fp16": {
+      "precision": "float16",
+      "size_mb": 52.6,
+      "format": "trt",
+      "requires": "tensorrt",
+      "description": "TensorRT FP16 — built locally, 6.9x faster"
     }
   }
 }
diff --git a/skills/transformation/depth-estimation/requirements.txt b/skills/transformation/depth-estimation/requirements.txt
index 2717a00..7ee3a71 100644
--- a/skills/transformation/depth-estimation/requirements.txt
+++ b/skills/transformation/depth-estimation/requirements.txt
@@ -20,3 +20,8 @@ numpy>=1.24.0
 opencv-python-headless>=4.8.0
 Pillow>=10.0.0
 matplotlib>=3.7.0
+
+# ── TensorRT (optional, Windows/Linux NVIDIA) ────────────────────────
+# If available, transform.py auto-selects TRT FP16 for ~7x speedup.
+# Falls back to PyTorch CUDA if not installed.
+tensorrt>=10.0; sys_platform != "darwin"
diff --git a/skills/transformation/depth-estimation/scripts/transform.py b/skills/transformation/depth-estimation/scripts/transform.py
index c4013c3..d2592a3 100644
--- a/skills/transformation/depth-estimation/scripts/transform.py
+++ b/skills/transformation/depth-estimation/scripts/transform.py
@@ -70,6 +70,9 @@
 # Where Aegis DepthVisionStudio stores downloaded models
 MODELS_DIR = Path.home() / ".aegis-ai" / "models" / "feature-extraction"
 
+# TensorRT engine cache directory (engines are GPU-specific)
+TRT_CACHE_DIR = MODELS_DIR / "trt_engines"
+
 # PyTorch model configs (fallback on non-macOS)
 PYTORCH_CONFIGS = {
     "depth-anything-v2-small": {
@@ -110,6 +113,13 @@ def __init__(self):
         self.opacity = 0.5
         self.blend_mode = "depth_only"  # Default for privacy: depth_only anonymizes
         self._coreml_input_size = COREML_INPUT_SIZE
+        # TensorRT state (populated by _load_tensorrt)
+        self._trt_context = None
+        self._trt_input_name = None
+        self._trt_output_name = None
+        self._trt_input_tensor = None
+        self._trt_output_tensor = None
+        self._trt_stream = None
 
     def parse_extra_args(self, parser: argparse.ArgumentParser):
         parser.add_argument("--model", type=str, default="depth-anything-v2-small",
@@ -137,6 +147,13 @@ def load_model(self, config: dict) -> dict:
             except Exception as e:
                 _log(f"CoreML load failed ({e}), falling back to PyTorch", self._tag)
 
+        # Try TensorRT (fails fast if not installed)
+        try:
+            info = self._load_tensorrt(model_name, config)
+            return info
+        except Exception as e:
+            _log(f"TensorRT unavailable ({e}), falling back to PyTorch", self._tag)
+
         # Fallback: PyTorch
         return self._load_pytorch(model_name, config)
 
@@ -196,6 +213,206 @@ def _download_coreml_model(self, variant_id: str):
             _log(f"CoreML model download failed: {e}", self._tag)
             raise
 
+    # ── TensorRT backend (Windows/Linux NVIDIA) ───────────────────────
+
+    def _load_tensorrt(self, model_name: str, config: dict) -> dict:
+        """Load or build a TensorRT FP16 engine for fastest NVIDIA inference.
+
+        A cached engine under TRT_CACHE_DIR is reused when it deserializes
+        cleanly; a missing, truncated, or incompatible cache entry is rebuilt
+        instead of disabling the TensorRT path.
+
+        Args:
+            model_name: Key into PYTORCH_CONFIGS.
+            config: Skill config; only "colormap" is read here.
+
+        Returns:
+            Model info dict with "backend" set to "tensorrt".
+
+        Raises:
+            ImportError: torch or tensorrt is not installed.
+            ValueError: model_name is not in PYTORCH_CONFIGS.
+            RuntimeError: The engine could not be loaded or built, or it does
+                not expose exactly one input and one output tensor.
+        """
+        import torch
+        import tensorrt as trt
+
+        _log(f"Attempting TensorRT FP16 for {model_name}", self._tag)
+
+        cfg = PYTORCH_CONFIGS.get(model_name)
+        if not cfg:
+            raise ValueError(f"Unknown model: {model_name}")
+
+        gpu_tag = torch.cuda.get_device_name(0).replace(" ", "_").lower()
+        engine_path = TRT_CACHE_DIR / f"{cfg['filename'].replace('.pth', '')}_fp16_{gpu_tag}.trt"
+
+        engine = None
+        if engine_path.exists():
+            _log(f"Loading cached TRT engine: {engine_path}", self._tag)
+            try:
+                engine = self._deserialize_engine(engine_path)
+            except Exception as e:
+                _log(f"Cached TRT engine unreadable ({e})", self._tag)
+            if engine is None:
+                # A stale (other TensorRT version) or truncated cache file must
+                # not disable TRT forever: rebuild, which overwrites it.
+                _log("Cached TRT engine invalid — rebuilding from ONNX (30-120s)...", self._tag)
+        else:
+            _log("No cached engine — building from ONNX (30-120s)...", self._tag)
+
+        if engine is None:
+            engine = self._build_trt_engine(cfg, engine_path)
+        if engine is None:
+            raise RuntimeError("TensorRT engine build/load failed")
+
+        # Pick the tensors by I/O mode instead of assuming index 0/1 ordering.
+        names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
+        inputs = [n for n in names if engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT]
+        outputs = [n for n in names if engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT]
+        if len(inputs) != 1 or len(outputs) != 1:
+            raise RuntimeError(
+                f"Expected 1 input and 1 output tensor, got {len(inputs)} and {len(outputs)}"
+            )
+
+        context = engine.create_execution_context()
+        if context is None:
+            raise RuntimeError("TensorRT execution context creation failed")
+        self._trt_context = context
+        self._trt_input_name = inputs[0]
+        self._trt_output_name = outputs[0]
+
+        # Any dynamic (-1) dimension is pinned to 1 (the export only makes batch dynamic).
+        input_shape = engine.get_tensor_shape(self._trt_input_name)
+        fixed_shape = tuple(1 if d == -1 else d for d in input_shape)
+        self._trt_context.set_input_shape(self._trt_input_name, fixed_shape)
+
+        # Persistent device buffers, reused for every frame.
+        self._trt_input_tensor = torch.zeros(fixed_shape, dtype=torch.float32, device="cuda")
+        actual_out_shape = self._trt_context.get_tensor_shape(self._trt_output_name)
+        self._trt_output_tensor = torch.empty(list(actual_out_shape), dtype=torch.float32, device="cuda")
+
+        self._trt_context.set_tensor_address(self._trt_input_name, self._trt_input_tensor.data_ptr())
+        self._trt_context.set_tensor_address(self._trt_output_name, self._trt_output_tensor.data_ptr())
+        self._trt_stream = torch.cuda.current_stream().cuda_stream
+
+        self.backend = "tensorrt"
+        _log(f"TensorRT FP16 engine ready: {engine_path.name}", self._tag)
+        return {
+            "model": model_name,
+            "device": "cuda",
+            "blend_mode": self.blend_mode,
+            "colormap": config.get("colormap", "inferno"),
+            "backend": "tensorrt",
+        }
+
+    def _build_trt_engine(self, cfg: dict, engine_path: Path):
+        """Export PyTorch → ONNX → build TRT FP16 engine → serialize to disk.
+
+        Args:
+            cfg: PYTORCH_CONFIGS entry (repo, filename, encoder, features,
+                out_channels).
+            engine_path: Where the serialized engine is cached.
+
+        Returns:
+            The deserialized engine, or None if ONNX parsing or the TensorRT
+            build failed.
+        """
+        import torch
+        import tensorrt as trt
+        from depth_anything_v2.dpt import DepthAnythingV2
+        from huggingface_hub import hf_hub_download
+
+        weights_path = hf_hub_download(cfg["repo"], cfg["filename"])
+        pt_model = DepthAnythingV2(
+            encoder=cfg["encoder"], features=cfg["features"],
+            out_channels=cfg["out_channels"],
+        )
+        pt_model.load_state_dict(torch.load(weights_path, map_location="cuda", weights_only=True))
+        pt_model.to("cuda").eval()
+
+        dummy = torch.randn(1, 3, 518, 518, device="cuda")
+        onnx_path = TRT_CACHE_DIR / f"{cfg['filename'].replace('.pth', '')}.onnx"
+        TRT_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+
+        try:
+            _log(f"Exporting ONNX: {onnx_path.name}", self._tag)
+            torch.onnx.export(
+                pt_model, dummy, str(onnx_path),
+                input_names=["input"], output_names=["depth"],
+                dynamic_axes={"input": {0: "batch"}, "depth": {0: "batch"}},
+                opset_version=17,
+            )
+            del pt_model
+            torch.cuda.empty_cache()
+
+            logger = trt.Logger(trt.Logger.WARNING)
+            builder = trt.Builder(logger)
+            # EXPLICIT_BATCH is deprecated in TensorRT 10 (networks are always
+            # explicit-batch); pass it only while the enum member still exists.
+            explicit_batch = getattr(trt.NetworkDefinitionCreationFlag, "EXPLICIT_BATCH", None)
+            flags = 0 if explicit_batch is None else 1 << int(explicit_batch)
+            network = builder.create_network(flags)
+            parser = trt.OnnxParser(network, logger)
+
+            _log("Parsing ONNX for TensorRT...", self._tag)
+            with open(str(onnx_path), "rb") as f:
+                if not parser.parse(f.read()):
+                    for i in range(parser.num_errors):
+                        _log(f" ONNX parse error: {parser.get_error(i)}", self._tag)
+                    return None
+
+            config = builder.create_builder_config()
+            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)
+
+            # A dynamic batch axis needs an optimization profile; pin it to 1.
+            inp = network.get_input(0)
+            if any(d == -1 for d in inp.shape):
+                profile = builder.create_optimization_profile()
+                fixed = tuple(1 if d == -1 else d for d in inp.shape)
+                profile.set_shape(inp.name, fixed, fixed, fixed)
+                config.add_optimization_profile(profile)
+
+            config.set_flag(trt.BuilderFlag.FP16)
+
+            _log("Building TRT FP16 engine (30-120s)...", self._tag)
+            serialized = builder.build_serialized_network(network, config)
+        finally:
+            # The ONNX file is only an intermediate; remove it on failure too.
+            try:
+                onnx_path.unlink()
+            except OSError:
+                pass
+
+        if serialized is None:
+            _log("TRT engine build failed!", self._tag)
+            return None
+
+        engine_bytes = bytes(serialized)
+        # Write to a temp file, then rename, so an interrupted write never
+        # leaves a truncated engine in the cache.
+        tmp_path = engine_path.with_name(engine_path.name + ".tmp")
+        with open(str(tmp_path), "wb") as f:
+            f.write(engine_bytes)
+        tmp_path.replace(engine_path)
+        _log(f"Engine cached: {engine_path} ({len(engine_bytes) / 1e6:.1f} MB)", self._tag)
+
+        runtime = trt.Runtime(logger)
+        return runtime.deserialize_cuda_engine(engine_bytes)
+
+    @staticmethod
+    def _deserialize_engine(engine_path: Path):
+        """Load a previously serialized TRT engine from disk.
+
+        Returns whatever ``deserialize_cuda_engine`` yields; callers must
+        treat a ``None`` result as a failed load.
+        """
+        import tensorrt as trt
+        logger = trt.Logger(trt.Logger.WARNING)
+        runtime = trt.Runtime(logger)
+        with open(str(engine_path), "rb") as f:
+            return runtime.deserialize_cuda_engine(f.read())
+
     # ── PyTorch backend (fallback) ────────────────────────────────────
 
     def _load_pytorch(self, model_name: str, config: dict) -> dict:
@@ -242,6 +459,8 @@ def transform_frame(self, image, metadata: dict):
 
         if self.backend == "coreml":
             depth_colored = self._infer_coreml(image)
+        elif self.backend == "tensorrt":
+            depth_colored = self._infer_tensorrt(image)
         else:
             depth_colored = self._infer_pytorch(image)
 
@@ -308,6 +527,53 @@ def _infer_pytorch(self, image):
 
         return depth_colored
 
+    def _infer_tensorrt(self, image):
+        """Run TensorRT FP16 inference and return colorized depth map.
+
+        Args:
+            image: BGR frame (it is converted with cv2.COLOR_BGR2RGB).
+
+        Returns:
+            Color-mapped depth image resized to the input frame's size.
+
+        Raises:
+            RuntimeError: TensorRT failed to enqueue the inference.
+        """
+        import torch
+        import cv2
+        import numpy as np
+
+        original_h, original_w = image.shape[:2]
+        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+        # Use the spatial size the engine was bound with (NCHW buffer) rather
+        # than repeating the export resolution here.
+        in_h, in_w = self._trt_input_tensor.shape[-2:]
+        resized = cv2.resize(rgb, (in_w, in_h), interpolation=cv2.INTER_LINEAR)
+        img_float = resized.astype(np.float32) / 255.0
+        # Per-channel normalization with the standard ImageNet statistics.
+        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
+        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
+        img_float = (img_float - mean) / std
+        img_nchw = np.transpose(img_float, (2, 0, 1))[np.newaxis]
+
+        self._trt_input_tensor.copy_(torch.from_numpy(img_nchw))
+        if not self._trt_context.execute_async_v3(self._trt_stream):
+            # The output buffer is reused, so an unchecked failure would
+            # silently return the previous frame's depth.
+            raise RuntimeError("TensorRT inference enqueue failed")
+        torch.cuda.synchronize()
+
+        depth = self._trt_output_tensor.cpu().numpy()
+        depth = np.squeeze(depth)
+
+        d_min, d_max = depth.min(), depth.max()
+        depth_norm = ((depth - d_min) / (d_max - d_min + 1e-8) * 255).astype(np.uint8)
+        depth_colored = cv2.applyColorMap(depth_norm, self.colormap_id)
+        depth_colored = cv2.resize(depth_colored, (original_w, original_h))
+
+        return depth_colored
+
     # ── Config updates ────────────────────────────────────────────────
 
     def on_config_update(self, config: dict):