Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions code/evaluation/logical_error_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,17 @@ def map_grid_to_stabilizer_tensor(grid_btdd, stab_indices_1d):
return flat_bdt.index_select(dim=1, index=stab_indices_1d)


def _maybe_warmup_compile(pipeline_module, stim_dets, device, trt_context, applied_compile):
"""Fire one forward pass to trigger torch.compile JIT before the timing loop."""
if trt_context is not None or not applied_compile:
return
with torch.no_grad():
_warmup_tensor = torch.as_tensor(stim_dets[:1], dtype=torch.float32, device=device)
pipeline_module(_warmup_tensor)
if device.type == "cuda":
torch.cuda.synchronize()


class PreDecoderMemoryEvalModule(nn.Module):
"""
nn.Module that encapsulates the full pre-decoder eval path: batch input -> trainX,
Expand Down Expand Up @@ -1321,6 +1332,8 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
else:
data_iter = test_dataloader

_maybe_warmup_compile(pipeline_module, stim_dets, device, trt_context, _applied_compile)

# Timing instrumentation accumulators (used when timing_rank0 is True)
residual_syndrome_density_sum = 0.0
predecoder_batch_times = [] if timing_rank0 else None
Expand Down
59 changes: 57 additions & 2 deletions code/tests/test_inference_latency_timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@

import math
import unittest
from unittest.mock import patch
from unittest.mock import MagicMock, call, patch

import numpy as np
import torch

from evaluation.logical_error_rate import _time_single_shot_latency_stim
from evaluation.logical_error_rate import _maybe_warmup_compile, _time_single_shot_latency_stim


class _FakeMatcher:
Expand Down Expand Up @@ -84,5 +85,59 @@ def test_time_single_shot_latency_handles_empty(self) -> None:
self.assertEqual(len(matcher.calls), 0)


class TestMaybeWarmupCompile(unittest.TestCase):
    """Unit tests for the gating and device handling of _maybe_warmup_compile."""

    def setUp(self):
        # Shared fixtures: a CPU device and a small all-zero detector array.
        self.cpu = torch.device("cpu")
        self.dets = np.zeros((4, 8), dtype=np.uint8)

    def test_calls_pipeline_module_when_compile_active(self):
        module = MagicMock(return_value=torch.zeros(1))
        _maybe_warmup_compile(
            module, self.dets, self.cpu, trt_context=None, applied_compile=True
        )
        module.assert_called_once()
        # The warmup tensor should hold a single shot, converted to float32.
        (warmup_input,), _ = module.call_args
        self.assertEqual(warmup_input.shape[0], 1)
        self.assertEqual(warmup_input.dtype, torch.float32)

    def test_skipped_when_compile_not_applied(self):
        module = MagicMock()
        _maybe_warmup_compile(
            module, self.dets, self.cpu, trt_context=None, applied_compile=False
        )
        module.assert_not_called()

    def test_skipped_when_trt_context_present(self):
        module = MagicMock()
        _maybe_warmup_compile(
            module, self.dets, self.cpu, trt_context=object(), applied_compile=True
        )
        module.assert_not_called()

    def test_cuda_sync_called_on_gpu_device(self):
        module = MagicMock(return_value=torch.zeros(1))
        # Fake a CUDA device; as_tensor is patched because the mock device
        # cannot be passed to the real torch.as_tensor.
        fake_gpu = MagicMock(spec=torch.device)
        fake_gpu.type = "cuda"
        with patch(
            "evaluation.logical_error_rate.torch.as_tensor",
            return_value=torch.zeros(1),
        ), patch("evaluation.logical_error_rate.torch.cuda.synchronize") as sync_mock:
            _maybe_warmup_compile(
                module, self.dets, fake_gpu, trt_context=None, applied_compile=True
            )
        sync_mock.assert_called_once()

    def test_cuda_sync_not_called_on_cpu_device(self):
        module = MagicMock(return_value=torch.zeros(1))
        with patch("evaluation.logical_error_rate.torch.cuda.synchronize") as sync_mock:
            _maybe_warmup_compile(
                module, self.dets, self.cpu, trt_context=None, applied_compile=True
            )
        sync_mock.assert_not_called()


# Allow running this test module directly (outside a pytest/CI runner).
if __name__ == "__main__":
    unittest.main()
Loading