Skip to content

Commit bbbc9fa

Browse files
fix: update benchmark script for cross-platform compatibility
1 parent 89bd3f9 commit bbbc9fa

1 file changed

Lines changed: 20 additions & 2 deletions

File tree

skills/transformation/depth-estimation/scripts/benchmark.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,14 @@ def run_benchmark(args):
248248
model.to(device)
249249
model.eval()
250250

251+
# ── CRITICAL FIX: Device-mismatch workaround ──────────────────────
252+
# The upstream depth_anything_v2 library hardcodes device selection
253+
# inside image2tensor(): `DEVICE = 'cuda' if torch.cuda.is_available()`
254+
# This ignores the model's actual device, causing crashes when the
255+
# model is on CPU but CUDA is available. We store the target device
256+
# and correct the tensor placement manually in the inference loop.
257+
_benchmark_device = device
258+
251259
model_load_ms = (time.perf_counter() - t0) * 1000
252260
backend = "pytorch"
253261
_log(f"PyTorch model loaded in {model_load_ms:.0f}ms on {device}")
@@ -280,11 +288,21 @@ def run_benchmark(args):
280288
if depth_map.ndim > 2:
281289
depth_map = np.squeeze(depth_map)
282290
else:
283-
# PyTorch inference
291+
# PyTorch inference — manual device-correct path
292+
# We bypass model.infer_image() because the upstream library's
293+
# image2tensor() hardcodes CUDA device selection, causing crashes
294+
# when model is on CPU. Instead, we call image2tensor ourselves,
295+
# fix the device, then call model.forward() directly.
284296
import torch
297+
import torch.nn.functional as F
285298
rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
286299
with torch.no_grad():
287-
depth_map = model.infer_image(rgb)
300+
img_tensor, (h, w) = model.image2tensor(rgb, input_size=518)
301+
# FIX: Override the library's hardcoded device with the model's device
302+
img_tensor = img_tensor.to(_benchmark_device)
303+
depth = model.forward(img_tensor)
304+
depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]
305+
depth_map = depth.cpu().numpy()
288306

289307
# Normalize and colorize
290308
d_min, d_max = depth_map.min(), depth_map.max()

0 commit comments

Comments
 (0)