From 58016b21786d1cb3c373ced5d327aafd39ca8913 Mon Sep 17 00:00:00 2001 From: Ivan Basov Date: Mon, 20 Apr 2026 11:01:45 -0700 Subject: [PATCH 1/3] fix(timing): warmup pass before timing loop to amortise torch.compile JIT MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Without this, the first batch in the timing loop bears the full torch.compile lazy-compilation cost (~887 ms vs ~1 ms steady-state), skewing Phase Timing numbers — especially at low sample counts like PREDECODER_INFERENCE_NUM_SAMPLES=1. The warmup only runs when torch.compile is active and TRT is not in use. Co-Authored-By: Claude Sonnet 4.6 --- code/evaluation/logical_error_rate.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/code/evaluation/logical_error_rate.py b/code/evaluation/logical_error_rate.py index 342bf41..f5806c7 100644 --- a/code/evaluation/logical_error_rate.py +++ b/code/evaluation/logical_error_rate.py @@ -1321,6 +1321,17 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic else: data_iter = test_dataloader + # Warmup: trigger torch.compile lazy compilation before the timing loop. + if trt_context is None and _applied_compile: + with torch.no_grad(): + _warmup_tensor = torch.as_tensor( + stim_dets[:1], dtype=torch.float32, device=device + ) + _ = pipeline_module(_warmup_tensor) + del _warmup_tensor + if device.type == "cuda": + torch.cuda.synchronize() + # Timing instrumentation accumulators (used when timing_rank0 is True) residual_syndrome_density_sum = 0.0 predecoder_batch_times = [] if timing_rank0 else None From c09e524b247ceebbd7d26b6466432c2c8341f703 Mon Sep 17 00:00:00 2001 From: Ivan Basov Date: Mon, 20 Apr 2026 11:07:32 -0700 Subject: [PATCH 2/3] test(timing): extract _maybe_warmup_compile helper + unit tests Extracts the warmup block into a named helper so it can be tested in isolation. Five tests cover: fires when compile is active (CPU), skipped when compile is off, skipped when TRT context is present, CUDA sync called on GPU device, CUDA sync not called on CPU device. Co-Authored-By: Claude Sonnet 4.6 --- code/evaluation/logical_error_rate.py | 22 ++++----- code/tests/test_inference_latency_timing.py | 49 ++++++++++++++++++++- 2 files changed, 59 insertions(+), 12 deletions(-) diff --git a/code/evaluation/logical_error_rate.py b/code/evaluation/logical_error_rate.py index f5806c7..3b705e0 100644 --- a/code/evaluation/logical_error_rate.py +++ b/code/evaluation/logical_error_rate.py @@ -501,6 +501,17 @@ def map_grid_to_stabilizer_tensor(grid_btdd, stab_indices_1d): return flat_bdt.index_select(dim=1, index=stab_indices_1d) +def _maybe_warmup_compile(pipeline_module, stim_dets, device, trt_context, applied_compile): + """Fire one forward pass to trigger torch.compile JIT before the timing loop.""" + if trt_context is not None or not applied_compile: + return + with torch.no_grad(): + _warmup_tensor = torch.as_tensor(stim_dets[:1], dtype=torch.float32, device=device) + pipeline_module(_warmup_tensor) + if device.type == "cuda": + torch.cuda.synchronize() + + class PreDecoderMemoryEvalModule(nn.Module): """ nn.Module that encapsulates the full pre-decoder eval path: batch input -> trainX, @@ -1321,16 +1332,7 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic else: data_iter = test_dataloader - # Warmup: trigger torch.compile lazy compilation before the timing loop. - if trt_context is None and _applied_compile: - with torch.no_grad(): - _warmup_tensor = torch.as_tensor( - stim_dets[:1], dtype=torch.float32, device=device - ) - _ = pipeline_module(_warmup_tensor) - del _warmup_tensor - if device.type == "cuda": - torch.cuda.synchronize() + _maybe_warmup_compile(pipeline_module, stim_dets, device, trt_context, _applied_compile) # Timing instrumentation accumulators (used when timing_rank0 is True) residual_syndrome_density_sum = 0.0 diff --git a/code/tests/test_inference_latency_timing.py b/code/tests/test_inference_latency_timing.py index d537412..3e16c13 100644 --- a/code/tests/test_inference_latency_timing.py +++ b/code/tests/test_inference_latency_timing.py @@ -15,11 +15,12 @@ import math import unittest -from unittest.mock import patch +from unittest.mock import MagicMock, call, patch import numpy as np +import torch -from evaluation.logical_error_rate import _time_single_shot_latency_stim +from evaluation.logical_error_rate import _maybe_warmup_compile, _time_single_shot_latency_stim class _FakeMatcher: @@ -84,5 +85,49 @@ def test_time_single_shot_latency_handles_empty(self) -> None: self.assertEqual(len(matcher.calls), 0) +class TestMaybeWarmupCompile(unittest.TestCase): + + def _cpu_device(self): + return torch.device("cpu") + + def test_calls_pipeline_module_when_compile_active(self): + pipeline = MagicMock(return_value=torch.zeros(1)) + dets = np.zeros((4, 8), dtype=np.uint8) + _maybe_warmup_compile(pipeline, dets, self._cpu_device(), trt_context=None, applied_compile=True) + self.assertEqual(pipeline.call_count, 1) + tensor_arg = pipeline.call_args[0][0] + self.assertEqual(tensor_arg.shape[0], 1) + self.assertEqual(tensor_arg.dtype, torch.float32) + + def test_skipped_when_compile_not_applied(self): + pipeline = MagicMock() + dets = np.zeros((4, 8), dtype=np.uint8) + _maybe_warmup_compile(pipeline, dets, self._cpu_device(), trt_context=None, applied_compile=False) + pipeline.assert_not_called() + + def test_skipped_when_trt_context_present(self): + pipeline = MagicMock() + dets = np.zeros((4, 8), dtype=np.uint8) + _maybe_warmup_compile(pipeline, dets, self._cpu_device(), trt_context=object(), applied_compile=True) + pipeline.assert_not_called() + + def test_cuda_sync_called_on_gpu_device(self): + pipeline = MagicMock(return_value=torch.zeros(1)) + dets = np.zeros((4, 8), dtype=np.uint8) + gpu_device = MagicMock(spec=torch.device) + gpu_device.type = "cuda" + with patch("evaluation.logical_error_rate.torch.as_tensor", return_value=torch.zeros(1)) as _mock_tensor, \ + patch("evaluation.logical_error_rate.torch.cuda.synchronize") as mock_sync: + _maybe_warmup_compile(pipeline, dets, gpu_device, trt_context=None, applied_compile=True) + mock_sync.assert_called_once() + + def test_cuda_sync_not_called_on_cpu_device(self): + pipeline = MagicMock(return_value=torch.zeros(1)) + dets = np.zeros((4, 8), dtype=np.uint8) + with patch("evaluation.logical_error_rate.torch.cuda.synchronize") as mock_sync: + _maybe_warmup_compile(pipeline, dets, self._cpu_device(), trt_context=None, applied_compile=True) + mock_sync.assert_not_called() + + if __name__ == "__main__": unittest.main() From 2427b59e12f8ad5f783e5e9be3113a3ddb5cf243 Mon Sep 17 00:00:00 2001 From: Ivan Basov Date: Mon, 20 Apr 2026 11:10:25 -0700 Subject: [PATCH 3/3] style: yapf formatting on test_inference_latency_timing Co-Authored-By: Claude Sonnet 4.6 --- code/tests/test_inference_latency_timing.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/code/tests/test_inference_latency_timing.py b/code/tests/test_inference_latency_timing.py index 3e16c13..cfba0a2 100644 --- a/code/tests/test_inference_latency_timing.py +++ b/code/tests/test_inference_latency_timing.py @@ -93,7 +93,9 @@ def _cpu_device(self): def test_calls_pipeline_module_when_compile_active(self): pipeline = MagicMock(return_value=torch.zeros(1)) dets = np.zeros((4, 8), dtype=np.uint8) - _maybe_warmup_compile(pipeline, dets, self._cpu_device(), trt_context=None, applied_compile=True) + _maybe_warmup_compile( + pipeline, dets, self._cpu_device(), trt_context=None, applied_compile=True + ) self.assertEqual(pipeline.call_count, 1) tensor_arg = pipeline.call_args[0][0] self.assertEqual(tensor_arg.shape[0], 1) @@ -102,13 +104,17 @@ def test_calls_pipeline_module_when_compile_active(self): def test_skipped_when_compile_not_applied(self): pipeline = MagicMock() dets = np.zeros((4, 8), dtype=np.uint8) - _maybe_warmup_compile(pipeline, dets, self._cpu_device(), trt_context=None, applied_compile=False) + _maybe_warmup_compile( + pipeline, dets, self._cpu_device(), trt_context=None, applied_compile=False + ) pipeline.assert_not_called() def test_skipped_when_trt_context_present(self): pipeline = MagicMock() dets = np.zeros((4, 8), dtype=np.uint8) - _maybe_warmup_compile(pipeline, dets, self._cpu_device(), trt_context=object(), applied_compile=True) + _maybe_warmup_compile( + pipeline, dets, self._cpu_device(), trt_context=object(), applied_compile=True + ) pipeline.assert_not_called() def test_cuda_sync_called_on_gpu_device(self): @@ -118,14 +124,18 @@ def test_cuda_sync_called_on_gpu_device(self): gpu_device.type = "cuda" with patch("evaluation.logical_error_rate.torch.as_tensor", return_value=torch.zeros(1)) as _mock_tensor, \ patch("evaluation.logical_error_rate.torch.cuda.synchronize") as mock_sync: - _maybe_warmup_compile(pipeline, dets, gpu_device, trt_context=None, applied_compile=True) + _maybe_warmup_compile( + pipeline, dets, gpu_device, trt_context=None, applied_compile=True + ) mock_sync.assert_called_once() def test_cuda_sync_not_called_on_cpu_device(self): pipeline = MagicMock(return_value=torch.zeros(1)) dets = np.zeros((4, 8), dtype=np.uint8) with patch("evaluation.logical_error_rate.torch.cuda.synchronize") as mock_sync: - _maybe_warmup_compile(pipeline, dets, self._cpu_device(), trt_context=None, applied_compile=True) + _maybe_warmup_compile( + pipeline, dets, self._cpu_device(), trt_context=None, applied_compile=True + ) mock_sync.assert_not_called()