diff --git a/code/evaluation/logical_error_rate.py b/code/evaluation/logical_error_rate.py
index 54b65e9..db06dec 100644
--- a/code/evaluation/logical_error_rate.py
+++ b/code/evaluation/logical_error_rate.py
@@ -501,6 +501,17 @@ def map_grid_to_stabilizer_tensor(grid_btdd, stab_indices_1d):
     return flat_bdt.index_select(dim=1, index=stab_indices_1d)
 
 
+def _maybe_warmup_compile(pipeline_module, stim_dets, device, trt_context, applied_compile):
+    """Fire one forward pass to trigger torch.compile JIT before the timing loop."""
+    if trt_context is not None or not applied_compile:
+        return
+    with torch.no_grad():
+        _warmup_tensor = torch.as_tensor(stim_dets[:1], dtype=torch.float32, device=device)
+        pipeline_module(_warmup_tensor)
+    if device.type == "cuda":
+        torch.cuda.synchronize()
+
+
 class PreDecoderMemoryEvalModule(nn.Module):
     """
     nn.Module that encapsulates the full pre-decoder eval path: batch input -> trainX,
@@ -1316,6 +1327,8 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
     else:
         data_iter = test_dataloader
 
+    _maybe_warmup_compile(pipeline_module, stim_dets, device, trt_context, _applied_compile)
+
     # Timing instrumentation accumulators (used when timing_rank0 is True)
     residual_syndrome_density_sum = 0.0
     predecoder_batch_times = [] if timing_rank0 else None
diff --git a/code/tests/test_inference_latency_timing.py b/code/tests/test_inference_latency_timing.py
index d537412..cfba0a2 100644
--- a/code/tests/test_inference_latency_timing.py
+++ b/code/tests/test_inference_latency_timing.py
@@ -15,11 +15,12 @@
 
 import math
 import unittest
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import numpy as np
+import torch
 
-from evaluation.logical_error_rate import _time_single_shot_latency_stim
+from evaluation.logical_error_rate import _maybe_warmup_compile, _time_single_shot_latency_stim
 
 
 class _FakeMatcher:
@@ -84,5 +85,59 @@ def test_time_single_shot_latency_handles_empty(self) -> None:
         self.assertEqual(len(matcher.calls), 0)
 
 
+class TestMaybeWarmupCompile(unittest.TestCase):
+
+    def _cpu_device(self):
+        return torch.device("cpu")
+
+    def test_calls_pipeline_module_when_compile_active(self):
+        pipeline = MagicMock(return_value=torch.zeros(1))
+        dets = np.zeros((4, 8), dtype=np.uint8)
+        _maybe_warmup_compile(
+            pipeline, dets, self._cpu_device(), trt_context=None, applied_compile=True
+        )
+        self.assertEqual(pipeline.call_count, 1)
+        tensor_arg = pipeline.call_args[0][0]
+        self.assertEqual(tensor_arg.shape[0], 1)
+        self.assertEqual(tensor_arg.dtype, torch.float32)
+
+    def test_skipped_when_compile_not_applied(self):
+        pipeline = MagicMock()
+        dets = np.zeros((4, 8), dtype=np.uint8)
+        _maybe_warmup_compile(
+            pipeline, dets, self._cpu_device(), trt_context=None, applied_compile=False
+        )
+        pipeline.assert_not_called()
+
+    def test_skipped_when_trt_context_present(self):
+        pipeline = MagicMock()
+        dets = np.zeros((4, 8), dtype=np.uint8)
+        _maybe_warmup_compile(
+            pipeline, dets, self._cpu_device(), trt_context=object(), applied_compile=True
+        )
+        pipeline.assert_not_called()
+
+    def test_cuda_sync_called_on_gpu_device(self):
+        pipeline = MagicMock(return_value=torch.zeros(1))
+        dets = np.zeros((4, 8), dtype=np.uint8)
+        gpu_device = MagicMock(spec=torch.device)
+        gpu_device.type = "cuda"
+        with patch("evaluation.logical_error_rate.torch.as_tensor", return_value=torch.zeros(1)) as _mock_tensor, \
+                patch("evaluation.logical_error_rate.torch.cuda.synchronize") as mock_sync:
+            _maybe_warmup_compile(
+                pipeline, dets, gpu_device, trt_context=None, applied_compile=True
+            )
+        mock_sync.assert_called_once()
+
+    def test_cuda_sync_not_called_on_cpu_device(self):
+        pipeline = MagicMock(return_value=torch.zeros(1))
+        dets = np.zeros((4, 8), dtype=np.uint8)
+        with patch("evaluation.logical_error_rate.torch.cuda.synchronize") as mock_sync:
+            _maybe_warmup_compile(
+                pipeline, dets, self._cpu_device(), trt_context=None, applied_compile=True
+            )
+        mock_sync.assert_not_called()
+
+
 if __name__ == "__main__":
     unittest.main()