Skip to content

Commit 5bc4262

Browse files
authored
Merge pull request #159 from SharpAI/feature/tensorrt-fp16-backend
feat: TensorRT FP16 backend for depth estimation
2 parents f043dd2 + adc3859 commit 5bc4262

3 files changed

Lines changed: 194 additions & 0 deletions

File tree

skills/transformation/depth-estimation/models.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@
7777
"precision": "float32",
7878
"size_mb": 99.0,
7979
"description": "PyTorch ViT-S — CUDA/CPU"
80+
},
81+
"depth_anything_v2_vits_trt_fp16": {
82+
"precision": "float16",
83+
"size_mb": 52.6,
84+
"format": "trt",
85+
"requires": "tensorrt",
86+
"description": "TensorRT FP16 — built locally, 6.9x faster"
8087
}
8188
}
8289
}

skills/transformation/depth-estimation/requirements.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,8 @@ numpy>=1.24.0
2020
opencv-python-headless>=4.8.0
2121
Pillow>=10.0.0
2222
matplotlib>=3.7.0
23+
24+
# ── TensorRT (optional, Windows/Linux NVIDIA) ────────────────────────
25+
# If available, transform.py auto-selects TRT FP16 for ~7x speedup.
26+
# Falls back to PyTorch CUDA if not installed.
27+
tensorrt>=10.0; sys_platform != "darwin"

skills/transformation/depth-estimation/scripts/transform.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@
7070
# Where Aegis DepthVisionStudio stores downloaded models
7171
MODELS_DIR = Path.home() / ".aegis-ai" / "models" / "feature-extraction"
7272

73+
# TensorRT engine cache directory (engines are GPU-specific)
74+
TRT_CACHE_DIR = MODELS_DIR / "trt_engines"
75+
7376
# PyTorch model configs (fallback on non-macOS)
7477
PYTORCH_CONFIGS = {
7578
"depth-anything-v2-small": {
@@ -110,6 +113,13 @@ def __init__(self):
110113
self.opacity = 0.5
111114
self.blend_mode = "depth_only" # Default for privacy: depth_only anonymizes
112115
self._coreml_input_size = COREML_INPUT_SIZE
116+
# TensorRT state (populated by _load_tensorrt)
117+
self._trt_context = None
118+
self._trt_input_name = None
119+
self._trt_output_name = None
120+
self._trt_input_tensor = None
121+
self._trt_output_tensor = None
122+
self._trt_stream = None
113123

114124
def parse_extra_args(self, parser: argparse.ArgumentParser):
115125
parser.add_argument("--model", type=str, default="depth-anything-v2-small",
@@ -137,6 +147,13 @@ def load_model(self, config: dict) -> dict:
137147
except Exception as e:
138148
_log(f"CoreML load failed ({e}), falling back to PyTorch", self._tag)
139149

150+
# Try TensorRT (fails fast if not installed)
151+
try:
152+
info = self._load_tensorrt(model_name, config)
153+
return info
154+
except Exception as e:
155+
_log(f"TensorRT unavailable ({e}), falling back to PyTorch", self._tag)
156+
140157
# Fallback: PyTorch
141158
return self._load_pytorch(model_name, config)
142159

@@ -196,6 +213,139 @@ def _download_coreml_model(self, variant_id: str):
196213
_log(f"CoreML model download failed: {e}", self._tag)
197214
raise
198215

216+
# ── TensorRT backend (Windows/Linux NVIDIA) ───────────────────────
217+
218+
def _load_tensorrt(self, model_name: str, config: dict) -> dict:
    """Load (or build and cache) a TensorRT FP16 engine for fastest NVIDIA inference.

    Engines are GPU-specific, so the cache filename embeds the CUDA device
    name. Any failure raises so the caller (load_model) can fall back to
    the PyTorch backend.

    Args:
        model_name: Key into PYTORCH_CONFIGS (e.g. "depth-anything-v2-small").
        config: Skill config dict; only "colormap" is read here.

    Returns:
        Backend-info dict: model, device, blend_mode, colormap, backend.

    Raises:
        RuntimeError: CUDA unavailable, or engine build/load failed.
        ValueError: unknown model name.
    """
    import torch
    import tensorrt as trt

    # Fail fast with a clear message instead of an opaque CUDA error below.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")

    _log(f"Attempting TensorRT FP16 for {model_name}", self._tag)

    cfg = PYTORCH_CONFIGS.get(model_name)
    if not cfg:
        raise ValueError(f"Unknown model: {model_name}")

    # TRT engines are not portable across GPUs — key the cache by device name.
    gpu_tag = torch.cuda.get_device_name(0).replace(" ", "_").lower()
    engine_path = TRT_CACHE_DIR / f"{cfg['filename'].replace('.pth', '')}_fp16_{gpu_tag}.trt"

    if engine_path.exists():
        _log(f"Loading cached TRT engine: {engine_path}", self._tag)
        engine = self._deserialize_engine(engine_path)
    else:
        _log("No cached engine — building from ONNX (30-120s)...", self._tag)
        engine = self._build_trt_engine(cfg, engine_path)

    if engine is None:
        raise RuntimeError("TensorRT engine build/load failed")

    self._trt_context = engine.create_execution_context()

    # Fix: do not assume tensor index 0 is the input and 1 the output —
    # TensorRT does not guarantee I/O ordering. Classify by I/O mode.
    io_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
    input_names = [n for n in io_names
                   if engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT]
    output_names = [n for n in io_names
                    if engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT]
    if not input_names or not output_names:
        raise RuntimeError("TRT engine is missing an input or output tensor")
    self._trt_input_name = input_names[0]
    self._trt_output_name = output_names[0]

    # Engines built with a dynamic batch axis report -1 there; pin it to 1.
    input_shape = engine.get_tensor_shape(self._trt_input_name)
    fixed_shape = tuple(1 if d == -1 else d for d in input_shape)
    self._trt_context.set_input_shape(self._trt_input_name, fixed_shape)

    # Pre-allocate persistent device buffers and bind their addresses once;
    # per-frame inference then only copies pixels into the input tensor.
    self._trt_input_tensor = torch.zeros(fixed_shape, dtype=torch.float32, device="cuda")
    actual_out_shape = self._trt_context.get_tensor_shape(self._trt_output_name)
    self._trt_output_tensor = torch.empty(list(actual_out_shape), dtype=torch.float32, device="cuda")

    self._trt_context.set_tensor_address(self._trt_input_name, self._trt_input_tensor.data_ptr())
    self._trt_context.set_tensor_address(self._trt_output_name, self._trt_output_tensor.data_ptr())
    self._trt_stream = torch.cuda.current_stream().cuda_stream

    self.backend = "tensorrt"
    _log(f"TensorRT FP16 engine ready: {engine_path.name}", self._tag)
    return {
        "model": model_name,
        "device": "cuda",
        "blend_mode": self.blend_mode,
        "colormap": config.get("colormap", "inferno"),
        "backend": "tensorrt",
    }
267+
268+
def _build_trt_engine(self, cfg: dict, engine_path: Path):
    """Export PyTorch → ONNX → build TRT FP16 engine → serialize to disk.

    Args:
        cfg: One entry of PYTORCH_CONFIGS (repo, filename, encoder,
            features, out_channels).
        engine_path: Destination file for the serialized engine.

    Returns:
        A deserialized TRT engine on success, or None if ONNX parsing or
        the engine build fails (caller treats None as "fall back").
    """
    import torch
    import tensorrt as trt
    from depth_anything_v2.dpt import DepthAnythingV2
    from huggingface_hub import hf_hub_download

    # Fetch weights (HF hub caches locally) and load the eval-mode model.
    weights_path = hf_hub_download(cfg["repo"], cfg["filename"])
    pt_model = DepthAnythingV2(
        encoder=cfg["encoder"], features=cfg["features"],
        out_channels=cfg["out_channels"],
    )
    pt_model.load_state_dict(torch.load(weights_path, map_location="cuda", weights_only=True))
    pt_model.to("cuda").eval()

    # 518x518 matches the inference-time input size used by _infer_tensorrt.
    dummy = torch.randn(1, 3, 518, 518, device="cuda")
    onnx_path = TRT_CACHE_DIR / f"{cfg['filename'].replace('.pth', '')}.onnx"
    TRT_CACHE_DIR.mkdir(parents=True, exist_ok=True)

    _log(f"Exporting ONNX: {onnx_path.name}", self._tag)
    # Only the batch axis is dynamic; H/W are fixed at export time.
    torch.onnx.export(
        pt_model, dummy, str(onnx_path),
        input_names=["input"], output_names=["depth"],
        dynamic_axes={"input": {0: "batch"}, "depth": {0: "batch"}},
        opset_version=17,
    )
    # Free GPU memory before the TRT build, which needs headroom itself.
    del pt_model
    torch.cuda.empty_cache()

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    # EXPLICIT_BATCH is required for ONNX-parsed networks.
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    # NOTE: the parser must stay alive until build_serialized_network runs —
    # the network references weights owned by the parser.
    parser = trt.OnnxParser(network, logger)

    _log("Parsing ONNX for TensorRT...", self._tag)
    with open(str(onnx_path), "rb") as f:
        if not parser.parse(f.read()):
            # Surface every parse error before giving up.
            for i in range(parser.num_errors):
                _log(f" ONNX parse error: {parser.get_error(i)}", self._tag)
            return None

    config = builder.create_builder_config()
    # Cap builder workspace at 1 GiB.
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)

    # Dynamic axes (-1) need an optimization profile; pin every dynamic
    # dim to 1 for min/opt/max since we only ever run batch size 1.
    inp = network.get_input(0)
    if any(d == -1 for d in inp.shape):
        profile = builder.create_optimization_profile()
        fixed = tuple(1 if d == -1 else d for d in inp.shape)
        profile.set_shape(inp.name, fixed, fixed, fixed)
        config.add_optimization_profile(profile)

    # FP16 is the point of this backend (smaller + faster on NVIDIA).
    config.set_flag(trt.BuilderFlag.FP16)

    _log("Building TRT FP16 engine (30-120s)...", self._tag)
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        _log("TRT engine build failed!", self._tag)
        return None

    # Cache the serialized plan so subsequent runs skip the slow build.
    engine_bytes = bytes(serialized)
    with open(str(engine_path), "wb") as f:
        f.write(engine_bytes)
    _log(f"Engine cached: {engine_path} ({len(engine_bytes) / 1e6:.1f} MB)", self._tag)

    # ONNX is only an intermediate artifact; best-effort cleanup.
    try:
        onnx_path.unlink()
    except OSError:
        pass

    # Deserialize from the in-memory bytes (no re-read from disk needed).
    runtime = trt.Runtime(logger)
    return runtime.deserialize_cuda_engine(engine_bytes)
339+
340+
@staticmethod
def _deserialize_engine(engine_path: Path):
    """Load a previously serialized TRT engine from disk."""
    import tensorrt as trt

    trt_logger = trt.Logger(trt.Logger.WARNING)
    engine_blob = engine_path.read_bytes()
    return trt.Runtime(trt_logger).deserialize_cuda_engine(engine_blob)
348+
199349
# ── PyTorch backend (fallback) ────────────────────────────────────
200350

201351
def _load_pytorch(self, model_name: str, config: dict) -> dict:
@@ -242,6 +392,8 @@ def transform_frame(self, image, metadata: dict):
242392

243393
if self.backend == "coreml":
244394
depth_colored = self._infer_coreml(image)
395+
elif self.backend == "tensorrt":
396+
depth_colored = self._infer_tensorrt(image)
245397
else:
246398
depth_colored = self._infer_pytorch(image)
247399

@@ -308,6 +460,36 @@ def _infer_pytorch(self, image):
308460

309461
return depth_colored
310462

463+
def _infer_tensorrt(self, image):
    """Run TensorRT FP16 inference and return colorized depth map."""
    import torch
    import cv2
    import numpy as np

    frame_h, frame_w = image.shape[:2]

    # Preprocess: BGR -> RGB, 518x518, ImageNet-normalized, NCHW float32.
    rgb_frame = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    rgb_frame = cv2.resize(rgb_frame, (518, 518), interpolation=cv2.INTER_LINEAR)
    pixels = rgb_frame.astype(np.float32) / 255.0
    pixels -= np.array([0.485, 0.456, 0.406], dtype=np.float32)
    pixels /= np.array([0.229, 0.224, 0.225], dtype=np.float32)
    batched = pixels.transpose(2, 0, 1)[np.newaxis]

    # Copy into the pre-bound device buffer and run the engine.
    self._trt_input_tensor.copy_(torch.from_numpy(batched))
    self._trt_context.execute_async_v3(self._trt_stream)
    torch.cuda.synchronize()

    depth_map = np.squeeze(self._trt_output_tensor.cpu().numpy())

    # Min-max normalize to 8-bit, colorize, and scale back to frame size.
    lo, hi = depth_map.min(), depth_map.max()
    depth_u8 = ((depth_map - lo) / (hi - lo + 1e-8) * 255).astype(np.uint8)
    colored = cv2.applyColorMap(depth_u8, self.colormap_id)
    return cv2.resize(colored, (frame_w, frame_h))
492+
311493
# ── Config updates ────────────────────────────────────────────────
312494

313495
def on_config_update(self, config: dict):

0 commit comments

Comments
 (0)