diff --git a/.github/workflows/ci-gpu.yml b/.github/workflows/ci-gpu.yml index 95ddd97..f0037c1 100644 --- a/.github/workflows/ci-gpu.yml +++ b/.github/workflows/ci-gpu.yml @@ -194,7 +194,7 @@ jobs: options: -u root --security-opt seccomp=unconfined --shm-size 16g env: NVIDIA_VISIBLE_DEVICES: ${{ env.NVIDIA_VISIBLE_DEVICES }} - timeout-minutes: 20 + timeout-minutes: 35 steps: - name: Install system dependencies run: | @@ -255,6 +255,41 @@ jobs: PREDECODER_TEST_SAMPLES: "2048" PREDECODER_TRAIN_EPOCHS: "2" + - name: Multi-GPU smoke training with parallel spacelike HE (2 GPUs, DDP) + # Additive coverage on top of the default-config multi-GPU smoke above. + # Forces data.use_compile=True + data.use_parallel_spacelike=True so the + # parallel + compiled spacelike HE path runs end-to-end under DDP on + # 2 GPUs. Failure modes specific to this combination (per-rank device + # pinning of the partition, torch.compile cache contention across + # ranks, deadlocks during the compiled inner loop) surface as a + # training crash here. The existing default-config step above is + # intentionally left untouched so we do not regress on coverage of + # the default path. + shell: bash + run: | + . .venv_multigpu/bin/activate + export PREDECODER_TIMING_RUN=1 + export PREDECODER_DISABLE_SDR=1 + export PREDECODER_LER_FINAL_ONLY=1 + export PREDECODER_INFERENCE_NUM_SAMPLES=32 + export PREDECODER_INFERENCE_LATENCY_SAMPLES=0 + export PREDECODER_INFERENCE_MEAS_BASIS=both + export PREDECODER_INFERENCE_NUM_WORKERS=0 + EXPERIMENT_NAME=ci_multi_gpu_he WORKFLOW=train GPUS=2 \ + EXTRA_PARAMS="data.use_compile=True data.use_parallel_spacelike=True" \ + bash code/scripts/local_run.sh 2>&1 | tee /tmp/ci_multigpu_he_train.log + r=${PIPESTATUS[0]}; [ $r -ne 0 ] && exit $r + EXPERIMENT_NAME=ci_multi_gpu_he WORKFLOW=inference GPUS=2 \ + EXTRA_PARAMS="data.use_compile=True data.use_parallel_spacelike=True" \ + bash code/scripts/local_run.sh 2>&1 | tee /tmp/ci_multigpu_he_infer.log + r=${PIPESTATUS[0]}; [ $r -ne 0 ] && exit $r + python code/scripts/check_ler_from_log.py /tmp/ci_multigpu_he_train.log --max-ler 0.35 + env: + PREDECODER_TRAIN_SAMPLES: "16384" + PREDECODER_VAL_SAMPLES: "2048" + PREDECODER_TEST_SAMPLES: "2048" + PREDECODER_TRAIN_EPOCHS: "2" + # --------------------------------------------------------------------------- # GPU coverage: captures GPU-specific code paths missed by the CPU coverage job # --------------------------------------------------------------------------- diff --git a/README.md b/README.md index 57b62e8..a30cb02 100644 --- a/README.md +++ b/README.md @@ -651,6 +651,66 @@ time. - **Inference uses the trained model from `outputs//models/`**, so keep the same `EXPERIMENT_NAME` when you switch from training to inference. - **Training auto-resumes**: if a run is interrupted, launching the same training command again (same `EXPERIMENT_NAME`) will automatically load the latest checkpoint it finds and continue training (up to the fixed 100 epochs). To force a clean restart, set `FRESH_START=1`, although we recommend changing `EXPERIMENT_NAME` instead. +### HE acceleration (advanced): parallel spacelike + +The spacelike homological-equivalence (HE) pass canonicalises each +`(batch, round)` diff frame independently. By default the canonicalisation +processes stabilisers sequentially. With `data.use_parallel_spacelike: True`, +the cache build computes a 2-partition of the stabiliser-overlap graph so the +two colour classes are reduced in parallel inside a `torch.compile`-friendly +inner loop. 
This cuts Python <-> compiled-graph crossings per HE pass and +exposes more parallelism to the GPU. + +#### How to enable + +In any config: + +```yaml +data: + use_compile: True # required to see the speedup + use_parallel_spacelike: True +``` + +Or on the CLI: + +```bash +EXTRA_PARAMS="data.use_compile=True data.use_parallel_spacelike=True" \ + bash code/scripts/local_run.sh +``` + +#### Pros (when to enable) + +- **Faster spacelike HE on GPU** for the rotated single-basis surface code, by + amortising per-iteration Python overhead and running both colour classes + through `torch.compile` together. +- **Syndrome-equivalent to the sequential path** on supported codes: the + parallel path preserves the HE invariants and produces valid non-increasing + representatives, while avoiding the sequential stabiliser order. Outputs are + not guaranteed bit-identical to the sequential path; both are valid + representatives of the same coset. + Coverage is added under `code/tests/mid/test_homological_equivalence.py`. +- **Composes with `data.use_weight2`** — the weight-2 fix-equivalence pass is + applied per colour. + +#### Cons / caveats (when to leave it off) + +- **Rotated single-basis surface code only.** The 2-colouring assumes the + stabiliser-overlap graph is bipartite, which holds by construction for the + rotated surface code targeted here. Color codes, non-rotated layouts, + subsystem codes and mixed-basis matrices can produce odd cycles; in that + case the cache build refuses with a diagnostic naming the offending + stabiliser pair rather than silently falling back. +- **`use_compile=True` is required** for the speedup; without it the partition + is built but the optimised compiled inner loop is not entered. +- **`torch.compile` has cold-start cost.** The first compiled call can pause + while Inductor/CUDA graph capture runs, and shape changes such as different + batch sizes or round counts can trigger recompilation. +- **Cache-build cost and memory grow slightly.** A packed + `parallel_partition_packed` view is materialised once at cache-build time so + the hot path only does dtype casts. +- **GPU-targeted.** The parallel path is designed for CUDA; on CPU you may + not see a speedup over the sequential path. + ## Logging and outputs ### What gets written where diff --git a/code/data/generator_torch.py b/code/data/generator_torch.py index 41db2f4..a07b8c2 100644 --- a/code/data/generator_torch.py +++ b/code/data/generator_torch.py @@ -53,6 +53,7 @@ def __init__( use_coset_search=False, coset_max_generators=20, use_dense_overlap=False, + use_parallel_spacelike=False, **_ignored, ): if global_rank is None: @@ -102,6 +103,7 @@ def __init__( max_passes_w1=max_passes_w1, use_weight2=use_weight2, max_passes_w2=max_passes_w2, + use_parallel_spacelike=use_parallel_spacelike, ), daemon=True, ) @@ -211,6 +213,7 @@ def __init__( use_coset_search=use_coset_search, coset_max_generators=coset_max_generators, use_dense_overlap=use_dense_overlap, + use_parallel_spacelike=use_parallel_spacelike, ) if self._mixed: diff --git a/code/qec/surface_code/homological_equivalence_torch.py b/code/qec/surface_code/homological_equivalence_torch.py index bc6b10d..fbff089 100644 --- a/code/qec/surface_code/homological_equivalence_torch.py +++ b/code/qec/surface_code/homological_equivalence_torch.py @@ -86,6 +86,22 @@ class SpacelikeHECache: w4_bl: torch.Tensor w4_br: torch.Tensor + # Independent stabilizer partitions for parallel spacelike HE. 
+ # Built from a 2-partition of the stabilizer-overlap graph when `basis` is provided. + parallel_partition: Optional[dict] = None + + # Diagnostic for why the parallel partition could not be built (if any). + # Surfaced by `_require_parallel_partition` so callers see the offending + # stabilizer pair instead of a generic "missing partition" message. + parallel_partition_failure_reason: Optional[str] = None + + # Pre-packed compile-friendly view of `parallel_partition` (built once at + # cache construction time). The compiled path reads its fixed-shape inputs + # straight out of this dict, avoiding the per-call `torch.stack`s, + # `is_boundary` boolean cast, and zero-padding that + # `_pack_partition_for_compile` would otherwise re-do every call. + parallel_partition_packed: Optional[dict] = None + # Precomputed data for compiled sequential spacelike HE (P2+P3) seq_compile_data: Optional[dict] = None @@ -194,6 +210,35 @@ def build_spacelike_he_cache( support_sizes = support_masks.sum(dim=1, dtype=torch.int64) layers = tuple(torch.tensor(layer, dtype=torch.int64, device=device) for layer in layers_list) + parallel_partition = None + parallel_partition_failure_reason: Optional[str] = None + if basis is not None: + fix_map = _build_fix_equiv_map( + parity_cpu, + d, + basis.upper(), + w4_tl_cpu=w4_tl_cpu, + w4_tr_cpu=w4_tr_cpu, + w4_bl_cpu=w4_bl_cpu, + w4_br_cpu=w4_br_cpu, + ) + parallel_partition, parallel_partition_failure_reason = _build_spacelike_partition( + parity_cpu, + fix_map, + device, + w2_canonical_cpu=w2_canonical_cpu, + w2_other_cpu=w2_other_cpu, + ) + + parallel_partition_packed = None + if parallel_partition is not None: + # Pack-once: the compile-friendly view of the partition is fixed at + # cache-build time. Doing this here saves the per-call cost of 8x + # `torch.stack`, the `(weights == 2).float().unsqueeze(0)` allocation, + # and the conditional zero-padding that `_pack_partition_for_compile` + # would otherwise repeat on every call. + parallel_partition_packed = _pack_partition_for_compile(parallel_partition, device) + cache = SpacelikeHECache( distance=d, parity=parity_dev, @@ -206,6 +251,9 @@ def build_spacelike_he_cache( w4_tr=w4_tr_cpu.to(device), w4_bl=w4_bl_cpu.to(device), w4_br=w4_br_cpu.to(device), + parallel_partition=parallel_partition, + parallel_partition_failure_reason=parallel_partition_failure_reason, + parallel_partition_packed=parallel_partition_packed, ) if basis is not None: @@ -215,6 +263,456 @@ def build_spacelike_he_cache( return cache +def _emit_w4_fe_patterns(tl: int, tr: int, bl: int, br: int, error_type: str) -> torch.Tensor: + """Emit the 3 fix-equivalence move patterns for one weight-4 stabilizer. + + Each row is ``[src_q0, src_q1, dst_q0, dst_q1]`` and is interpreted as + "if the error currently sits at ``(src_q0, src_q1)``, move it to + ``(dst_q0, dst_q1)``". The three patterns are: + + * row 0: vertical - TL+BL -> TR+BR + * row 1: horizontal - BL+BR -> TL+TR + * row 2: diagonal - basis='X': TL+BR -> TR+BL + basis='Z': TR+BL -> TL+BR + + This helper is the single source of truth for the FE pattern set: callers + pass corner indices that they already know (typically from + ``cache.w4_tl/tr/bl/br``), and never re-derive the 2x2 box layout. 
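+
+    Illustrative example (hypothetical corner indices, not taken from a real
+    lattice): for ``tl, tr, bl, br = 0, 1, 3, 4`` with ``error_type="X"`` the
+    emitted rows are ``[0, 3, 1, 4]`` (vertical), ``[3, 4, 0, 1]``
+    (horizontal), and ``[0, 4, 1, 3]`` (diagonal).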
+ """ + et = error_type.upper() + patterns = torch.empty((3, 4), dtype=torch.int16) + patterns[0] = torch.tensor([tl, bl, tr, br], dtype=torch.int16) + patterns[1] = torch.tensor([bl, br, tl, tr], dtype=torch.int16) + if et == "X": + patterns[2] = torch.tensor([tl, br, tr, bl], dtype=torch.int16) + else: + patterns[2] = torch.tensor([tr, bl, tl, br], dtype=torch.int16) + return patterns + + +def _build_fix_equiv_map( + parity_matrix: torch.Tensor, + distance: int, + error_type: str, + *, + w4_tl_cpu: Optional[torch.Tensor] = None, + w4_tr_cpu: Optional[torch.Tensor] = None, + w4_bl_cpu: Optional[torch.Tensor] = None, + w4_br_cpu: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Precompute the fix-equivalence canonical mapping for each weight-4 stabilizer. + + Returns a CPU tensor of shape ``(num_weight4_stabs, 3, 4)``. For each + weight-4 stabilizer, each pattern row stores ``[src_q0, src_q1, dst_q0, + dst_q1]``. + + When ``w4_{tl,tr,bl,br}_cpu`` are provided (the common path inside + ``build_spacelike_he_cache``), corner indices are read from those cache + tensors so this function and the cache builder agree on the 2x2 box + layout by construction. The legacy ``parity_matrix``-driven path is kept + only for callers that do not yet have a corner cache. + """ + have_corners = ( + w4_tl_cpu is not None and w4_tr_cpu is not None and w4_bl_cpu is not None and + w4_br_cpu is not None + ) + + w4_maps: list[torch.Tensor] = [] + if have_corners: + for s in range(int(w4_tl_cpu.numel())): + tl = int(w4_tl_cpu[s].item()) + if tl < 0: + continue + tr = int(w4_tr_cpu[s].item()) + bl = int(w4_bl_cpu[s].item()) + br = int(w4_br_cpu[s].item()) + w4_maps.append(_emit_w4_fe_patterns(tl, tr, bl, br, error_type)) + else: + for s in range(parity_matrix.shape[0]): + support = torch.nonzero(parity_matrix[s], as_tuple=True)[0].tolist() + if len(support) != 4: + continue + coords = sorted((idx // distance, idx % distance, idx) for idx in support) + tl, tr, bl, br = coords[0][2], coords[1][2], coords[2][2], coords[3][2] + w4_maps.append(_emit_w4_fe_patterns(tl, tr, bl, br, error_type)) + + if not w4_maps: + return torch.zeros((0, 3, 4), dtype=torch.int16) + return torch.stack(w4_maps, dim=0) + + +def _validate_spacelike_partition( + parity_matrix: torch.Tensor, indices_by_partition: list[list[int]] +) -> bool: + """Return True iff every stabilizer is assigned once and same-partition supports are disjoint.""" + num_stabs = int(parity_matrix.shape[0]) + assigned = [idx for group in indices_by_partition for idx in group] + if sorted(assigned) != list(range(num_stabs)): + return False + + parity_bool = parity_matrix.bool() + for group in indices_by_partition: + for pos, i in enumerate(group): + supp_i = parity_bool[i] + for j in group[pos + 1:]: + if bool(torch.any(supp_i & parity_bool[j])): + return False + return True + + +def _build_spacelike_partition( + parity_matrix: torch.Tensor, + fix_map: torch.Tensor, + device: torch.device, + *, + w2_canonical_cpu: Optional[torch.Tensor] = None, + w2_other_cpu: Optional[torch.Tensor] = None, +) -> Tuple[Optional[dict], Optional[str]]: + """ + Build two independent stabilizer partitions for parallel spacelike HE. + + The implementation 2-partitions the stabilizer-overlap graph: within each + partition no two stabilizers share data qubits, so all stabilizers in a + partition can be applied simultaneously. + + Bipartite-by-construction caveat + -------------------------------- + The 2-coloring relies on the stabilizer-overlap graph being **bipartite**. 
+ For the rotated, single-basis (X-only or Z-only) surface code that the + parallel path targets, this holds by construction: same-basis stabilizers + on the rotated lattice form a bipartite overlap graph. This is **not** a + generic CSS property -- color codes, non-rotated layouts, subsystem codes, + and mixed-basis matrices can produce odd cycles in the overlap graph. + + For non-bipartite inputs (or the post-BFS validator catching a same- + partition overlap), this function returns ``(None, reason)`` where + ``reason`` names the offending stabilizer pair so callers can debug + quickly. ``_require_parallel_partition`` surfaces the reason in its + error message. + + For each color we also emit per-color weight-2 fix-equivalence indices + (``w2_can_c{c}`` / ``w2_oth_c{c}``) so the parallel FE pass can apply the + boundary-stabilizer "move error from ``other`` to ``canonical``" rule that + the sequential ``_fix_equivalence`` performs in its ``ss == 2`` branch. + """ + parity_cpu = _as_uint8_binary(parity_matrix).cpu() + S, D2 = parity_cpu.shape + overlap = (parity_cpu.float() @ parity_cpu.float().T) > 0 + overlap.fill_diagonal_(False) + + color = [-1] * S + for start in range(S): + if color[start] >= 0: + continue + color[start] = 0 + queue = [start] + while queue: + u = queue.pop(0) + for v in range(S): + if not bool(overlap[u, v]): + continue + if color[v] < 0: + color[v] = 1 - color[u] + queue.append(v) + elif color[v] == color[u]: + return None, ( + "stabilizer overlap graph is not bipartite " + f"(stabilizers {u} and {v} fall in the same color class via BFS); " + "the parallel spacelike path requires a rotated single-basis " + "surface code, and this parity matrix violates that assumption" + ) + + indices_by_partition = [[i for i in range(S) if color[i] == c] for c in (0, 1)] + if not _validate_spacelike_partition(parity_cpu, indices_by_partition): + return None, ( + "internal: BFS produced a same-partition overlap that the post-hoc " + "validator rejected. This is a bug -- file an issue with the parity " + "matrix attached" + ) + + result: dict = { + "indices_c0": torch.tensor(indices_by_partition[0], dtype=torch.long, device=device), + "indices_c1": torch.tensor(indices_by_partition[1], dtype=torch.long, device=device), + } + fmap_cpu = fix_map.cpu() if fix_map.shape[0] > 0 else None + + # Map each global w4 stabilizer index to its row in fix_map. Computed once + # so the per-color loop below is O(M_c) rather than O(M_c * S) via the + # previous `w4_all.index(...)` lookup. 
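+    # Illustration with hypothetical indices: if the weight-4 stabilizers sit
+    # at global rows [2, 5, 7], then g_to_fm_row == {2: 0, 5: 1, 7: 2}.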
+ w4_all = [i for i in range(S) if int(parity_cpu[i].sum().item()) == 4] + g_to_fm_row = {g: row for row, g in enumerate(w4_all)} + + for c, idx_c in enumerate(indices_by_partition): + p_c = parity_cpu[idx_c].float().to(device) + w_c = parity_cpu[idx_c].sum(dim=1).int().to(device) + result[f"parity_c{c}"] = p_c + result[f"weights_c{c}"] = w_c + + w4_global = [i for i in idx_c if i in g_to_fm_row] + w4_fix_patterns: list = [] + + if fmap_cpu is not None and w4_global: + for g_idx in w4_global: + fm_row = g_to_fm_row[g_idx] + pats = [] + for p in range(3): + s0, s1 = int(fmap_cpu[fm_row, p, 0]), int(fmap_cpu[fm_row, p, 1]) + d0, d1 = int(fmap_cpu[fm_row, p, 2]), int(fmap_cpu[fm_row, p, 3]) + pats.append((s0, s1, d0, d1)) + w4_fix_patterns.append(pats) + + src_idx, dst_idx = [], [] + for pat_i in range(3): + s_list, d_list = [], [] + for stab_pats in w4_fix_patterns: + s0, s1, d0, d1 = stab_pats[pat_i] + s_list.append([s0, s1]) + d_list.append([d0, d1]) + src_idx.append(torch.tensor(s_list, dtype=torch.long, device=device)) + dst_idx.append(torch.tensor(d_list, dtype=torch.long, device=device)) + + result[f"w4_fix_c{c}"] = list(zip(src_idx, dst_idx)) + result[f"w4_parity_c{c}"] = parity_cpu[w4_global].float().to(device) + else: + result[f"w4_fix_c{c}"] = [] + result[f"w4_parity_c{c}"] = torch.zeros((0, D2), dtype=torch.float32, device=device) + + # Per-color weight-2 boundary stabilizers: emit canonical / other index lists + # so the parallel FE pass can apply the same "move from other -> canonical" + # rule the sequential path applies in its `ss == 2` branch. + if w2_canonical_cpu is not None and w2_other_cpu is not None: + w2_can_list: list[int] = [] + w2_oth_list: list[int] = [] + for s in idx_c: + can_s = int(w2_canonical_cpu[s].item()) + oth_s = int(w2_other_cpu[s].item()) + if can_s >= 0 and oth_s >= 0: + w2_can_list.append(can_s) + w2_oth_list.append(oth_s) + result[f"w2_can_c{c}"] = torch.tensor(w2_can_list, dtype=torch.long, device=device) + result[f"w2_oth_c{c}"] = torch.tensor(w2_oth_list, dtype=torch.long, device=device) + else: + result[f"w2_can_c{c}"] = torch.zeros((0,), dtype=torch.long, device=device) + result[f"w2_oth_c{c}"] = torch.zeros((0,), dtype=torch.long, device=device) + + return result, None + + +def _require_parallel_partition(cache: SpacelikeHECache) -> dict: + partition = cache.parallel_partition + if partition is None: + reason = cache.parallel_partition_failure_reason or "no diagnostic available" + raise ValueError( + "Parallel spacelike HE requires a valid 2-partition of the stabilizer-overlap graph. " + "Build the cache with basis='X' or basis='Z' and ensure the stabilizer graph is bipartite. " + f"Reason: {reason}." + ) + return partition + + +def _weight_reduction_parallel( + error: torch.Tensor, + parity_group: torch.Tensor, + weights_group: torch.Tensor, +) -> torch.Tensor: + """ + Fully vectorized weight reduction for one independent stabilizer partition. + + Since stabilizers in the partition do not share qubits, all partition + stabilizers can be applied simultaneously. 
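+
+    Illustrative example (one hypothetical stabilizer over 5 qubits):
+
+        parity_group  = [[1, 1, 1, 1, 0]]   # single weight-4 support
+        weights_group = [4]                 # not a boundary stabilizer
+        error         = [[1, 1, 1, 0, 1]]   # 3 hits on the support -> act2
+        result        = [[0, 0, 0, 1, 1]]   # weight 4 -> 2, syndrome kept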
+ """ + try: + counts = (error.to(torch.int8) @ parity_group.to(torch.int8).T).to(torch.int32) + except RuntimeError: + counts = (error.to(torch.float32) @ parity_group.to(torch.float32).T).to(torch.int32) + + boundary = (weights_group == 2).unsqueeze(0).expand_as(counts) + act1 = (counts == 4) | ((counts == 2) & boundary) + act2 = (counts == 3) + + if not act1.any() and not act2.any(): + return error + + zero_mask = ((act1.float() @ parity_group) > 0).to(torch.uint8) + flip_mask = ((act2.float() @ parity_group) > 0).to(torch.uint8 + ) & (~zero_mask.bool()).to(torch.uint8) + + error = error & (~zero_mask.bool()).to(error.dtype) + error = error ^ flip_mask + return error + + +def _fix_equivalence_w2_parallel( + error: torch.Tensor, + can_idx: torch.Tensor, + oth_idx: torch.Tensor, + claimed: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Vectorized weight-2 fix-equivalence for one independent partition. + + Mirrors the `ss == 2` branch of the sequential `_fix_equivalence`: + - `should_process = (count == 1) & ~has_overlap` + where `has_overlap = (vals & claimed).any()` per (canonical, other). + - `should_move = should_process & (cfg[:, canonical] == 0)` + i.e. the error currently sits on `other`; move it to `canonical`. + - On `should_move`, set `canonical=1`, `other=0`. + - Claim is updated **only** for `should_move` (not `should_process`), + matching the sequential `claimed[:, ...] = claimed[:, ...] | should_move`. + This is important: an already-canonical error does not lock its qubits, + so a later w4 stabilizer can still legitimately use them. + + Disjoint supports within a color make this safe to fire in parallel across + all w2 stabilizers in the partition. + + The eager (uint8/bool) variant; see `_fe_w2_parallel_step_nobreak` for the + compile-friendly all-float twin. + """ + N, D2 = error.shape + if claimed is None: + claimed = torch.zeros((N, D2), dtype=torch.bool, device=error.device) + if can_idx.shape[0] == 0: + return error, claimed + + can_b = can_idx.unsqueeze(0).expand(N, -1) + oth_b = oth_idx.unsqueeze(0).expand(N, -1) + + v_can = torch.gather(error, 1, can_b) + v_oth = torch.gather(error, 1, oth_b) + c_can = torch.gather(claimed, 1, can_b) + c_oth = torch.gather(claimed, 1, oth_b) + + has_overlap = (v_can.bool() & c_can) | (v_oth.bool() & c_oth) + err_count = v_can.to(torch.int16) + v_oth.to(torch.int16) + should_process = (err_count == 1) & (~has_overlap) + should_move = should_process & (v_can == 0) + + if not should_move.any(): + return error, claimed + + # `should_move == True` implies `v_can == 0` and `v_oth == 1`, so the + # write reduces to bitwise `v_can | move` and `v_oth & ~move`. This avoids + # the `torch.where` plus `ones_like` / `zeros_like` allocations the naive + # form would carry. + error = error.clone() + move_u8 = should_move.to(error.dtype) + error.scatter_(1, can_b, v_can | move_u8) + error.scatter_(1, oth_b, v_oth & (~should_move).to(error.dtype)) + + claimed = claimed.scatter(1, can_b, c_can | should_move) + claimed = claimed.scatter(1, oth_b, c_oth | should_move) + + return error, claimed + + +def _fix_equivalence_parallel( + error: torch.Tensor, + parity_w4: torch.Tensor, + fix_patterns: list, + claimed: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Vectorized fix-equivalence for one independent partition's weight-4 stabilizers. + + `claimed` is threaded across partitions so overlapping stabilizers from + different partitions do not double-modify the same qubit in one pass. 
+ """ + N, D2 = error.shape + if claimed is None: + claimed = torch.zeros((N, D2), dtype=torch.bool, device=error.device) + + if parity_w4.shape[0] == 0 or not fix_patterns: + return error, claimed + + try: + counts = (error.to(torch.int8) @ parity_w4.to(torch.int8).T).to(torch.int32) + except RuntimeError: + counts = (error.to(torch.float32) @ parity_w4.to(torch.float32).T).to(torch.int32) + + has_2 = counts == 2 + if not has_2.any(): + return error, claimed + + error = error.clone() + handled = torch.zeros_like(has_2) + num_w4 = int(parity_w4.shape[0]) + + for src_idx, dst_idx in fix_patterns: + eligible = has_2 & ~handled + if not eligible.any(): + break + + # Per-stabilizer corner indices: (num_w4, 4). Within a partition the + # supports are disjoint, so each row's 4 entries are unique qubits. + qi_per_stab = torch.stack( + [src_idx[:, 0], src_idx[:, 1], dst_idx[:, 0], dst_idx[:, 1]], dim=-1 + ) + qi_flat = qi_per_stab.reshape(-1).unsqueeze(0).expand(N, -1) # (N, 4*num_w4) + # (N, num_w4, 4): True iff that corner of that stabilizer was claimed + # by an earlier partition in this iteration. + claimed_at_corners = torch.gather(claimed, 1, qi_flat).view(N, num_w4, 4) + has_overlap = claimed_at_corners.any(dim=2) + + src_vals = error[..., src_idx[:, 0]] & error[..., src_idx[:, 1]] + matches = eligible & (src_vals == 1) & (~has_overlap) + if not matches.any(): + continue + + zero = torch.tensor(0, dtype=torch.uint8, device=error.device) + one = torch.tensor(1, dtype=torch.uint8, device=error.device) + for k in range(2): + qi = src_idx[:, k] + error.scatter_( + -1, + qi.unsqueeze(0).expand(N, -1), + torch.where(matches, zero, error[..., qi]), + ) + for k in range(2): + qi = dst_idx[:, k] + error.scatter_( + -1, + qi.unsqueeze(0).expand(N, -1), + torch.where(matches, one, error[..., qi]), + ) + + # Claim only the corners of stabilizers that actually fired, per sample. + matches_per_corner = matches.repeat_interleave(4, dim=1) + old_claim = claimed.gather(1, qi_flat) + claimed = claimed.scatter(1, qi_flat, old_claim | matches_per_corner) + handled = handled | matches + + return error, claimed + + +def _simplify_spacelike_parallel( + cfg: torch.Tensor, + partition: dict, + max_iterations: int = 100, +) -> torch.Tensor: + """Run spacelike HE using two independent stabilizer partitions.""" + cfg = _ensure_uint8(cfg) + for _ in range(max_iterations): + prev = cfg + for c in (0, 1): + cfg = _weight_reduction_parallel( + cfg, partition[f"parity_c{c}"], partition[f"weights_c{c}"] + ) + claimed = None + # Color-major FE: w2 then w4 within each color, threading `claimed` across + # both colors so cross-partition overlaps cannot double-modify a qubit. 
+ for c in (0, 1): + cfg, claimed = _fix_equivalence_w2_parallel( + cfg, partition[f"w2_can_c{c}"], partition[f"w2_oth_c{c}"], claimed=claimed + ) + cfg, claimed = _fix_equivalence_parallel( + cfg, partition[f"w4_parity_c{c}"], partition[f"w4_fix_c{c}"], claimed=claimed + ) + if torch.equal(cfg, prev): + break + return cfg + + # --------------------------------------------------------------------------- # Coset min-weight search for stuck patterns (NEW-4 / P12) # --------------------------------------------------------------------------- @@ -285,6 +783,184 @@ def coset_minimum_weight( return result +# --------------------------------------------------------------------------- +# Compile-safe parallel spacelike variants +# --------------------------------------------------------------------------- +# All-float, no .item(), no data-dependent branches, no in-place mutations. +# Algebraic XOR: error ^ flip == error - 2*error*flip + flip (exact for {0,1} floats). + + +def _wr_parallel_step_nobreak( + error_f: torch.Tensor, + parity: torch.Tensor, + is_boundary: torch.Tensor, +) -> torch.Tensor: + """Compile-friendly weight reduction for one independent partition.""" + counts = error_f @ parity.T + act1 = (counts == 4.0) | ((counts == 2.0) & (is_boundary > 0.5)) + act2 = counts == 3.0 + + zero_mask = (act1.float() @ parity).clamp(max=1.0) + flip_raw = (act2.float() @ parity).clamp(max=1.0) + flip_mask = flip_raw * (1.0 - zero_mask) + + error_f = error_f * (1.0 - zero_mask) + return error_f - 2.0 * error_f * flip_mask + flip_mask + + +def _fe_w2_parallel_step_nobreak( + error_f: torch.Tensor, + can_idx: torch.Tensor, + oth_idx: torch.Tensor, + valid_f: torch.Tensor, + claimed_f: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compile-friendly weight-2 fix-equivalence for one independent partition. + + Compile-safe twin of `_fix_equivalence_w2_parallel`: + - all-float, no `.item()`, no data-dependent Python branches; + - static shapes (callers pad an empty color to width 1 and mask via + `valid_f`, so the chunk function has fixed input shapes for CUDA-graph + capture). + + `valid_f`: (M_c,) float in {0., 1.} — 0 marks padded dummy slots that the + mask zeros out so they perform no FE move. 
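+
+    Illustration of the padding contract: a color with no w2 stabilizers is
+    packed as ``can_idx = oth_idx = [0]`` with ``valid_f = [0.]``, so the
+    gathers keep their fixed shapes while ``process`` (and hence ``move_f``)
+    is forced to zero for the dummy slot and the scatters write the gathered
+    values back unchanged.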
+ """ + if claimed_f is None: + claimed_f = torch.zeros_like(error_f) + if can_idx.shape[0] == 0: + return error_f, claimed_f + + N = error_f.shape[0] + can_b = can_idx.unsqueeze(0).expand(N, -1) + oth_b = oth_idx.unsqueeze(0).expand(N, -1) + + v_can = torch.gather(error_f, 1, can_b) + v_oth = torch.gather(error_f, 1, oth_b) + c_can = torch.gather(claimed_f, 1, can_b) + c_oth = torch.gather(claimed_f, 1, oth_b) + + overlap = ((v_can > 0.5) & (c_can > 0.5)) | ((v_oth > 0.5) & (c_oth > 0.5)) + valid_b = (valid_f > 0.5).unsqueeze(0) + process = (v_can + v_oth == 1.0) & (~overlap) & valid_b + move = process & (v_can < 0.5) + move_f = move.float() + + error_f = error_f.scatter(1, can_b, v_can * (1.0 - move_f) + move_f) + error_f = error_f.scatter(1, oth_b, v_oth * (1.0 - move_f)) + + claimed_f = claimed_f.scatter(1, can_b, (c_can + move_f).clamp(max=1.0)) + claimed_f = claimed_f.scatter(1, oth_b, (c_oth + move_f).clamp(max=1.0)) + + return error_f, claimed_f + + +def _fe_parallel_step_nobreak( + error_f: torch.Tensor, + src0: torch.Tensor, + src1: torch.Tensor, + dst0: torch.Tensor, + dst1: torch.Tensor, + parity_w4: torch.Tensor, + claimed_f: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compile-friendly fix-equivalence for one independent partition.""" + if claimed_f is None: + claimed_f = torch.zeros_like(error_f) + + if parity_w4.shape[0] == 0: + return error_f, claimed_f + + N = error_f.shape[0] + num_w4 = int(parity_w4.shape[0]) + counts = error_f @ parity_w4.T + has_2 = counts == 2.0 + handled = torch.zeros_like(has_2) + + for p in range(3): + s0, s1 = src0[p], src1[p] + d0, d1 = dst0[p], dst1[p] + eligible = has_2 & ~handled + + # Per-stabilizer corner indices: (num_w4, 4). Disjoint within partition. + qi_per_stab = torch.stack([s0, s1, d0, d1], dim=-1) + qi_flat = qi_per_stab.reshape(-1).unsqueeze(0).expand(N, -1) # (N, 4*num_w4) + claimed_at_corners = torch.gather(claimed_f, 1, qi_flat).view(N, num_w4, 4) + has_overlap = (claimed_at_corners > 0.5).any(dim=2) # (N, num_w4) + + v0 = torch.gather(error_f, 1, s0.unsqueeze(0).expand(N, -1)) + v1 = torch.gather(error_f, 1, s1.unsqueeze(0).expand(N, -1)) + match = eligible & (v0 == 1.0) & (v1 == 1.0) & (~has_overlap) + match_f = match.float() + + for qi in (s0, s1): + idx = qi.unsqueeze(0).expand(N, -1) + old = torch.gather(error_f, 1, idx) + error_f = error_f.scatter(1, idx, old * (1 - match_f)) + + for qi in (d0, d1): + idx = qi.unsqueeze(0).expand(N, -1) + old = torch.gather(error_f, 1, idx) + error_f = error_f.scatter(1, idx, old * (1 - match_f) + match_f) + + # Claim only matched stabilizers' corners, per sample. + match_per_corner = match.float().repeat_interleave(4, dim=1) + old_claim = torch.gather(claimed_f, 1, qi_flat) + claimed_f = claimed_f.scatter(1, qi_flat, (old_claim + match_per_corner).clamp(max=1.0)) + handled = handled | match + + return error_f, claimed_f + + +def _pack_partition_for_compile(partition: dict, device: torch.device) -> dict: + """Convert a parallel partition dict into compile-friendly tensor arguments. + + For each color we also pack the weight-2 boundary canonical/other index + pair plus a `w2_valid_c{c}` mask. Empty colors (no w2 stabilizers) are + padded to width 1 with `valid=0` so the compiled chunk function has + fixed input shapes regardless of partition. 
+ """ + packed: dict = {} + for c in (0, 1): + packed[f"parity_c{c}"] = partition[f"parity_c{c}"] + packed[f"is_boundary_c{c}"] = (partition[f"weights_c{c}"] == 2).float().unsqueeze(0) + packed[f"w4_parity_c{c}"] = partition[f"w4_parity_c{c}"] + + fix_list = partition[f"w4_fix_c{c}"] + if fix_list: + s0_list, s1_list, d0_list, d1_list = [], [], [], [] + for src_idx, dst_idx in fix_list: + s0_list.append(src_idx[:, 0]) + s1_list.append(src_idx[:, 1]) + d0_list.append(dst_idx[:, 0]) + d1_list.append(dst_idx[:, 1]) + packed[f"src0_c{c}"] = torch.stack(s0_list) + packed[f"src1_c{c}"] = torch.stack(s1_list) + packed[f"dst0_c{c}"] = torch.stack(d0_list) + packed[f"dst1_c{c}"] = torch.stack(d1_list) + else: + S_w4 = partition[f"w4_parity_c{c}"].shape[0] + width = max(S_w4, 1) + packed[f"src0_c{c}"] = torch.zeros(3, width, dtype=torch.long, device=device) + packed[f"src1_c{c}"] = torch.zeros(3, width, dtype=torch.long, device=device) + packed[f"dst0_c{c}"] = torch.zeros(3, width, dtype=torch.long, device=device) + packed[f"dst1_c{c}"] = torch.zeros(3, width, dtype=torch.long, device=device) + + w2_can = partition[f"w2_can_c{c}"] + w2_oth = partition[f"w2_oth_c{c}"] + if w2_can.numel() > 0: + packed[f"w2_can_c{c}"] = w2_can + packed[f"w2_oth_c{c}"] = w2_oth + packed[f"w2_valid_c{c}"] = torch.ones( + w2_can.numel(), dtype=torch.float32, device=device + ) + else: + packed[f"w2_can_c{c}"] = torch.zeros(1, dtype=torch.long, device=device) + packed[f"w2_oth_c{c}"] = torch.zeros(1, dtype=torch.long, device=device) + packed[f"w2_valid_c{c}"] = torch.zeros(1, dtype=torch.float32, device=device) + return packed + + def _simplify_time_w1_step_nobreak( err: torch.Tensor, syn: torch.Tensor, @@ -1078,6 +1754,11 @@ def _simplify_spacelike_seq_compiled( prev.copy_(cfg) torch.compiler.cudagraph_mark_step_begin() + # `.clone()` is required: with `mode="reduce-overhead"` the compiled + # WR function returns a tensor that aliases an internal CUDA-graph + # output buffer. Without this clone, the next iteration's replay + # would overwrite the buffer that `cfg_f` (and therefore `cfg`) is + # observing, silently corrupting subsequent FE input. cfg_f = wr_fn( cfg.to(torch.float32), scd["padded_masks"], @@ -1088,6 +1769,10 @@ def _simplify_spacelike_seq_compiled( if fe_graph_data is not None: gd = fe_graph_data + # The CUDA graph operates on fixed `cfg_static` / `claimed_static` + # buffers; copy in the live `cfg` and replay, then copy out. The + # final `cfg.copy_(...)` is what makes the result visible outside + # the graph's static-buffer world. gd["cfg_static"].copy_(cfg) gd["claimed_static"].zero_() gd["graph"].replay() @@ -1101,6 +1786,102 @@ def _simplify_spacelike_seq_compiled( return cfg +_PARALLEL_PACKED_ARG_KEYS = ( + "parity_c0", + "is_boundary_c0", + "w4_parity_c0", + "src0_c0", + "src1_c0", + "dst0_c0", + "dst1_c0", + "w2_can_c0", + "w2_oth_c0", + "w2_valid_c0", + "parity_c1", + "is_boundary_c1", + "w4_parity_c1", + "src0_c1", + "src1_c1", + "dst0_c1", + "dst1_c1", + "w2_can_c1", + "w2_oth_c1", + "w2_valid_c1", +) + + +def _simplify_spacelike_parallel_compiled( + cfg: torch.Tensor, + cache: SpacelikeHECache, + max_iterations: int = 100, + compute_dtype: torch.dtype = torch.float32, + partition_override: Optional[dict] = None, +) -> torch.Tensor: + """ + Run parallel spacelike HE through torch.compile with chunked early exit. + + A compiled inner function runs `_SPACELIKE_CHUNK` iterations, then + convergence is checked outside the compiled boundary. 
The compile-friendly + inputs are packed once at cache-build time (``cache.parallel_partition_packed``) + so this hot path only does dtype casts of the float entries when the + caller asks for a non-float32 ``compute_dtype``. + + Dispatch rule: if ``partition_override`` is ``None`` *or* is the very + object the cache already packed (the common production case -- + ``_simplify_spacelike`` resolves ``_require_parallel_partition(cache)`` and + threads it back in), we read ``cache.parallel_partition_packed`` and avoid + re-running ``_pack_partition_for_compile`` on every call. Test/bench + harnesses that pass a *different* partition still go through the pack-on- + the-fly path -- identity, not equality, decides which path runs, so a + caller that hands us a fresh dict will get fresh packing. + """ + use_cached = partition_override is None or ( + cache.parallel_partition is not None and partition_override is cache.parallel_partition + ) + + if use_cached: + if cache.parallel_partition is None or cache.parallel_partition_packed is None: + _require_parallel_partition(cache) # raises with diagnostic + packed = cache.parallel_partition_packed + else: + packed = _pack_partition_for_compile(partition_override, cfg.device) + + chunk_fn = _get_compiled_spacelike_chunk() + if cfg.dtype != torch.uint8: + cfg = _as_uint8_binary(cfg) + + cfg_f = cfg.to(compute_dtype) + args: list = [] + for key in _PARALLEL_PACKED_ARG_KEYS: + v = packed[key] + # `.to(dtype)` is a no-op when the tensor already matches; for the + # default `compute_dtype=float32` case the packed dict is built at + # float32, so this loop is just dict lookups + identity casts. + args.append(v.to(compute_dtype) if v.is_floating_point() else v) + + # Convergence is tracked entirely in float. The chunk produces values that + # are exactly {0.0, 1.0} in IEEE float (algebraic XOR `e - 2*e*f + f` is + # exact on {0,1}-valued floats), so `torch.equal` on the rounded float is + # safe -- and skipping the per-chunk uint8 round-trip saves two N*D2 dtype + # casts per outer iteration. + prev_f = cfg_f.round() + for _ in range(0, max_iterations, _SPACELIKE_CHUNK): + torch.compiler.cudagraph_mark_step_begin() + # `.clone()` is required: `mode="reduce-overhead"` puts `chunk_fn` + # behind a CUDA graph whose output buffer is reused on the next replay. + # Without this clone, `prev_f` would alias the buffer that the next + # iteration's WR overwrites, the convergence check would read the + # post-iteration tensor instead of the pre-iteration one, and the + # outer loop would always exit on the first iteration. 
+ cfg_f = chunk_fn(cfg_f, *args).clone() + curr_f = cfg_f.round() + if torch.equal(curr_f, prev_f): + break + prev_f = curr_f + + return cfg_f.round().to(torch.uint8) + + def _simplify_spacelike( cfg: torch.Tensor, cache: SpacelikeHECache, @@ -1112,11 +1893,24 @@ def _simplify_spacelike( use_coset_search: bool = False, parity: Optional[torch.Tensor] = None, coset_max_generators: int = 20, + use_parallel_spacelike: bool = False, ) -> torch.Tensor: - if use_compile and cache.seq_compile_data is not None: + partition = _require_parallel_partition(cache) if use_parallel_spacelike else None + + if use_compile and partition is not None: + cfg = _simplify_spacelike_parallel_compiled( + cfg, + cache, + max_iterations=max_iterations, + compute_dtype=compute_dtype, + partition_override=partition, + ) + elif use_compile and cache.seq_compile_data is not None: cfg = _simplify_spacelike_seq_compiled( cfg, cache, max_iterations=max_iterations, basis=basis ) + elif partition is not None: + cfg = _simplify_spacelike_parallel(cfg, partition, max_iterations=max_iterations) else: if cfg.dtype != torch.uint8: cfg = _as_uint8_binary(cfg) @@ -1149,6 +1943,7 @@ def apply_homological_equivalence_torch_vmap( compute_dtype: torch.dtype = torch.float32, use_coset_search: bool = False, coset_max_generators: int = 20, + use_parallel_spacelike: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Torch spacelike HE implementation. @@ -1160,6 +1955,7 @@ def apply_homological_equivalence_torch_vmap( coset representatives and pick the minimum-weight one (P12 / NEW-4). coset_max_generators: Guard against exponential blowup — skip coset search if the stabilizer count exceeds this value. + use_parallel_spacelike: If True, use the 2-partition parallel spacelike path. """ z = _as_uint8_binary(z_diffs) x = _as_uint8_binary(x_diffs) @@ -1189,6 +1985,7 @@ def apply_homological_equivalence_torch_vmap( use_coset_search=use_coset_search, parity=parity_X, coset_max_generators=coset_max_generators, + use_parallel_spacelike=use_parallel_spacelike, ) z_can = _simplify_spacelike( z_flat, @@ -1200,6 +1997,7 @@ def apply_homological_equivalence_torch_vmap( use_coset_search=use_coset_search, parity=parity_Z, coset_max_generators=coset_max_generators, + use_parallel_spacelike=use_parallel_spacelike, ) return z_can.reshape(B, T, D2), x_can.reshape(B, T, D2) @@ -2032,6 +2830,7 @@ def apply_weight1_timelike_homological_equivalence_torch( cache_X_w2: Optional[Weight2TimelikeCache] = None, use_coset_search: bool = False, coset_max_generators: int = 20, + use_parallel_spacelike: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Torch HE: spacelike + timelike weight-1 (+ optional weight-2). @@ -2052,6 +2851,7 @@ def apply_weight1_timelike_homological_equivalence_torch( use_coset_search: If True, after greedy spacelike canonicalization, enumerate all coset representatives and pick the minimum-weight one (P12 / NEW-4). coset_max_generators: Skip coset search if S exceeds this (exponential guard). + use_parallel_spacelike: If True, use 2-partition parallel spacelike HE. 
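+
+    Note: the parallel spacelike path requires caches built with ``basis``
+    set so the 2-partition exists; otherwise `_require_parallel_partition`
+    raises `ValueError` carrying the recorded
+    `parallel_partition_failure_reason`.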
""" z_diffs, x_diffs = _cumulative_to_diffs_torch(z_errors, x_errors) sx = _as_uint8_binary(s1s2_x) @@ -2089,6 +2889,7 @@ def apply_weight1_timelike_homological_equivalence_torch( compute_dtype=dt, use_coset_search=use_coset_search, coset_max_generators=coset_max_generators, + use_parallel_spacelike=use_parallel_spacelike, ) for _ in range(int(num_he_cycles)): @@ -2172,9 +2973,68 @@ def apply_weight1_timelike_homological_equivalence_torch( # torch.compile caches and warmup (OPT-6) # --------------------------------------------------------------------------- +_compiled_spacelike_cache: dict = {} _compiled_timelike_cache: dict = {} _compiled_weight2_cache: dict = {} +_SPACELIKE_CHUNK = 4 + + +def _get_compiled_spacelike_chunk(): + """Return a compiled parallel spacelike HE chunk function.""" + key = _SPACELIKE_CHUNK + if key in _compiled_spacelike_cache: + return _compiled_spacelike_cache[key] + + chunk = _SPACELIKE_CHUNK + + def _spacelike_chunk( + error_f, + par_c0, + bnd_c0, + w4par_c0, + s0_c0, + s1_c0, + d0_c0, + d1_c0, + w2can_c0, + w2oth_c0, + w2val_c0, + par_c1, + bnd_c1, + w4par_c1, + s0_c1, + s1_c1, + d0_c1, + d1_c1, + w2can_c1, + w2oth_c1, + w2val_c1, + ): + for _ in range(chunk): + error_f = _wr_parallel_step_nobreak(error_f, par_c0, bnd_c0) + error_f = _wr_parallel_step_nobreak(error_f, par_c1, bnd_c1) + claimed_f = torch.zeros_like(error_f) + # Color-major FE: w2 then w4 within each color, threading `claimed_f` + # so cross-partition overlaps cannot double-modify a qubit. + error_f, claimed_f = _fe_w2_parallel_step_nobreak( + error_f, w2can_c0, w2oth_c0, w2val_c0, claimed_f + ) + error_f, claimed_f = _fe_parallel_step_nobreak( + error_f, s0_c0, s1_c0, d0_c0, d1_c0, w4par_c0, claimed_f + ) + error_f, claimed_f = _fe_w2_parallel_step_nobreak( + error_f, w2can_c1, w2oth_c1, w2val_c1, claimed_f + ) + error_f, claimed_f = _fe_parallel_step_nobreak( + error_f, s0_c1, s1_c1, d0_c1, d1_c1, w4par_c1, claimed_f + ) + return error_f + + compiled = torch.compile(_spacelike_chunk, mode="reduce-overhead", fullgraph=True) + _compiled_spacelike_cache[key] = compiled + return compiled + def _get_compiled_timelike_loop( max_t: int, @@ -2436,6 +3296,7 @@ def warmup_he_compile( apply_spacelike: bool = True, use_weight2: bool = False, max_passes_w2: int = 4, + use_parallel_spacelike: bool = False, ) -> None: """Eagerly trigger torch.compile for all HE kernels. @@ -2454,6 +3315,8 @@ def warmup_he_compile( min_t_z = 1 if str(basis).upper() == "Z" else 0 if apply_spacelike: + if use_parallel_spacelike: + _get_compiled_spacelike_chunk() for nl in range(max(1, int(distance) - 1), int(distance) + 2): _get_compiled_seq_wr(nl) diff --git a/code/qec/surface_code/memory_circuit_torch.py b/code/qec/surface_code/memory_circuit_torch.py index d688ab2..8a965d1 100644 --- a/code/qec/surface_code/memory_circuit_torch.py +++ b/code/qec/surface_code/memory_circuit_torch.py @@ -84,6 +84,7 @@ def __init__( use_coset_search: bool = False, coset_max_generators: int = 20, use_dense_overlap: bool = False, + use_parallel_spacelike: bool = False, # Optional in-memory DEM artifacts (to avoid writing/loading files). 
H: torch.Tensor | None = None, # (2*num_detectors, num_errors) uint8 p: torch.Tensor | None = None, # (num_errors,) float32 @@ -105,6 +106,7 @@ def __init__( self.use_coset_search = bool(use_coset_search) self.coset_max_generators = int(coset_max_generators) self.use_dense_overlap = bool(use_dense_overlap) + self.use_parallel_spacelike = bool(use_parallel_spacelike) self.device = device if device is not None else torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) @@ -121,6 +123,7 @@ def __init__( max_passes_w1=self.max_passes_w1, use_weight2=self.use_weight2, max_passes_w2=self.max_passes_w2, + use_parallel_spacelike=self.use_parallel_spacelike, ), daemon=True, ) @@ -339,6 +342,7 @@ def generate_batch( use_coset_search=self.use_coset_search, coset_max_generators=self.coset_max_generators, use_dense_overlap=self.use_dense_overlap, + use_parallel_spacelike=self.use_parallel_spacelike, ) meas_new = torch.cat([s1s2x, s1s2z], dim=2) else: diff --git a/code/tests/mid/test_homological_equivalence.py b/code/tests/mid/test_homological_equivalence.py index 129591f..d7f13a9 100644 --- a/code/tests/mid/test_homological_equivalence.py +++ b/code/tests/mid/test_homological_equivalence.py @@ -37,11 +37,17 @@ from qec.surface_code.memory_circuit import SurfaceCode from qec.surface_code.memory_circuit_torch import MemoryCircuitTorch from qec.surface_code.homological_equivalence_torch import ( + apply_homological_equivalence_torch_vmap, apply_weight1_timelike_homological_equivalence_torch, build_spacelike_he_cache, build_timelike_he_cache, build_weight2_timelike_cache, + _validate_spacelike_partition, _simplify_time_w2_step, + _build_fix_equiv_map, + _build_spacelike_partition, + _require_parallel_partition, + SpacelikeHECache, ) from qec.surface_code.homological_equivalence import ( linear_index_to_coordinates, @@ -1184,6 +1190,681 @@ def test_cache_layers_cover_all_stabs(self): all_idx.extend(layer.tolist()) self.assertEqual(sorted(all_idx), list(range(hx.shape[0]))) + def test_parallel_partition_valid_for_supported_distances(self): + """Parallel partition should cover every stabilizer and avoid same-partition overlaps.""" + for d in (3, 5, 7, 9): + hx, hz, _ = _build_parity_matrices(d) + for basis, parity in (("X", hx), ("Z", hz)): + cache = build_spacelike_he_cache( + parity.to(torch.uint8), distance=d, basis=basis, device=torch.device("cpu") + ) + partition = cache.parallel_partition + self.assertIsNotNone(partition, f"d={d} basis={basis}: missing partition") + + groups = [ + partition["indices_c0"].cpu().tolist(), + partition["indices_c1"].cpu().tolist(), + ] + self.assertEqual( + sorted(groups[0] + groups[1]), + list(range(parity.shape[0])), + f"d={d} basis={basis}: partition does not cover stabilizers", + ) + self.assertTrue( + _validate_spacelike_partition(parity.to(torch.uint8), groups), + f"d={d} basis={basis}: invalid same-partition overlap", + ) + + def test_parallel_partition_validator_rejects_overlap(self): + """A same-partition overlap must fail validation instead of silently looking valid.""" + hx, _, _ = _build_parity_matrices(3) + overlapping = None + for i in range(hx.shape[0]): + for j in range(i + 1, hx.shape[0]): + if bool(torch.any((hx[i] == 1) & (hx[j] == 1))): + overlapping = (i, j) + break + if overlapping is not None: + break + self.assertIsNotNone(overlapping) + i, j = overlapping + remaining = [k for k in range(hx.shape[0]) if k not in (i, j)] + self.assertFalse(_validate_spacelike_partition(hx.to(torch.uint8), [[i, j], remaining])) + + +class 
TestParallelSpacelikeHE(unittest.TestCase): + """Correctness tests for the 2-partition parallel spacelike HE path.""" + + def _make_inputs(self, d: int, seed: int = 123): + hx, hz, _ = _build_parity_matrices(d) + parity_X = hx.to(torch.uint8) + parity_Z = hz.to(torch.uint8) + cache_X = build_spacelike_he_cache( + parity_X, distance=d, basis="X", device=torch.device("cpu") + ) + cache_Z = build_spacelike_he_cache( + parity_Z, distance=d, basis="Z", device=torch.device("cpu") + ) + g = torch.Generator().manual_seed(seed) + z_diffs = torch.randint(0, 2, (4, d, d * d), dtype=torch.uint8, generator=g) + x_diffs = torch.randint(0, 2, (4, d, d * d), dtype=torch.uint8, generator=g) + return parity_X, parity_Z, cache_X, cache_Z, z_diffs, x_diffs + + def _assert_spacelike_invariants( + self, + *, + d: int, + parity_X: torch.Tensor, + parity_Z: torch.Tensor, + z_in: torch.Tensor, + x_in: torch.Tensor, + z_out: torch.Tensor, + x_out: torch.Tensor, + tag: str, + ): + self.assertTrue(((z_out == 0) | (z_out == 1)).all(), f"{tag}: Z output non-binary") + self.assertTrue(((x_out == 0) | (x_out == 1)).all(), f"{tag}: X output non-binary") + self.assertTrue( + torch.all(z_out.sum(dim=-1) <= z_in.sum(dim=-1)), + f"{tag}: Z weight increased for d={d}", + ) + self.assertTrue( + torch.all(x_out.sum(dim=-1) <= x_in.sum(dim=-1)), + f"{tag}: X weight increased for d={d}", + ) + z_syn_in = (z_in.float() @ parity_X.float().T) % 2 + z_syn_out = (z_out.float() @ parity_X.float().T) % 2 + x_syn_in = (x_in.float() @ parity_Z.float().T) % 2 + x_syn_out = (x_out.float() @ parity_Z.float().T) % 2 + self.assertTrue(torch.equal(z_syn_in, z_syn_out), f"{tag}: Z syndrome changed") + self.assertTrue(torch.equal(x_syn_in, x_syn_out), f"{tag}: X syndrome changed") + + def test_parallel_spacelike_invariants_and_idempotence(self): + for d in (3, 5, 7): + parity_X, parity_Z, cache_X, cache_Z, z_in, x_in = self._make_inputs(d, seed=1000 + d) + z_out, x_out = apply_homological_equivalence_torch_vmap( + z_in, + x_in, + parity_Z, + parity_X, + d, + cache_Z=cache_Z, + cache_X=cache_X, + use_parallel_spacelike=True, + ) + self._assert_spacelike_invariants( + d=d, + parity_X=parity_X, + parity_Z=parity_Z, + z_in=z_in, + x_in=x_in, + z_out=z_out, + x_out=x_out, + tag=f"parallel d={d}", + ) + + z_twice, x_twice = apply_homological_equivalence_torch_vmap( + z_out, + x_out, + parity_Z, + parity_X, + d, + cache_Z=cache_Z, + cache_X=cache_X, + use_parallel_spacelike=True, + ) + self.assertTrue(torch.equal(z_out, z_twice), f"parallel d={d}: Z not idempotent") + self.assertTrue(torch.equal(x_out, x_twice), f"parallel d={d}: X not idempotent") + + def test_parallel_spacelike_matches_or_preserves_legacy_invariants(self): + for d in (3, 5): + parity_X, parity_Z, cache_X, cache_Z, z_in, x_in = self._make_inputs(d, seed=2000 + d) + z_seq, x_seq = apply_homological_equivalence_torch_vmap( + z_in, + x_in, + parity_Z, + parity_X, + d, + cache_Z=cache_Z, + cache_X=cache_X, + ) + z_par, x_par = apply_homological_equivalence_torch_vmap( + z_in, + x_in, + parity_Z, + parity_X, + d, + cache_Z=cache_Z, + cache_X=cache_X, + use_parallel_spacelike=True, + ) + if not (torch.equal(z_seq, z_par) and torch.equal(x_seq, x_par)): + self._assert_spacelike_invariants( + d=d, + parity_X=parity_X, + parity_Z=parity_Z, + z_in=z_in, + x_in=x_in, + z_out=z_par, + x_out=x_par, + tag=f"parallel-vs-legacy d={d}", + ) + z_seq_syn = (z_seq.float() @ parity_X.float().T) % 2 + z_par_syn = (z_par.float() @ parity_X.float().T) % 2 + x_seq_syn = (x_seq.float() @ parity_Z.float().T) % 2 
+ x_par_syn = (x_par.float() @ parity_Z.float().T) % 2 + self.assertTrue( + torch.equal(z_seq_syn, z_par_syn), f"parallel d={d}: Z syndrome mismatch" + ) + self.assertTrue( + torch.equal(x_seq_syn, x_par_syn), f"parallel d={d}: X syndrome mismatch" + ) + + def test_parallel_spacelike_requires_partition(self): + d = 3 + hx, hz, _ = _build_parity_matrices(d) + parity_X = hx.to(torch.uint8) + parity_Z = hz.to(torch.uint8) + cache_X = build_spacelike_he_cache(parity_X, distance=d, device=torch.device("cpu")) + cache_Z = build_spacelike_he_cache(parity_Z, distance=d, device=torch.device("cpu")) + z = torch.zeros((1, 1, d * d), dtype=torch.uint8) + x = torch.zeros((1, 1, d * d), dtype=torch.uint8) + with self.assertRaises(ValueError): + apply_homological_equivalence_torch_vmap( + z, + x, + parity_Z, + parity_X, + d, + cache_Z=cache_Z, + cache_X=cache_X, + use_parallel_spacelike=True, + ) + + def test_parallel_partition_packs_w2_indices(self): + """The 2-partition must surface per-color w2 (canonical, other) index lists. + + Regression guard: without these, the parallel + FE path silently drops the boundary-stabilizer fix-equivalence rule that + the sequential path applies in its `ss == 2` branch. + """ + for d in (3, 5, 7): + hx, hz, _ = _build_parity_matrices(d) + for basis, parity in (("X", hx), ("Z", hz)): + cache = build_spacelike_he_cache( + parity.to(torch.uint8), distance=d, basis=basis, device=torch.device("cpu") + ) + partition = cache.parallel_partition + self.assertIsNotNone(partition) + for c in (0, 1): + self.assertIn(f"w2_can_c{c}", partition, f"d={d} {basis} c{c}: missing w2_can") + self.assertIn(f"w2_oth_c{c}", partition, f"d={d} {basis} c{c}: missing w2_oth") + self.assertEqual( + partition[f"w2_can_c{c}"].shape, + partition[f"w2_oth_c{c}"].shape, + f"d={d} {basis} c{c}: w2_can/oth shape mismatch", + ) + + total_w2 = sum(partition[f"w2_can_c{c}"].numel() for c in (0, 1)) + expected_w2 = int((cache.support_sizes == 2).sum().item()) + self.assertEqual( + total_w2, + expected_w2, + f"d={d} {basis}: partition w2 count {total_w2} != cache w2 count {expected_w2}", + ) + + def test_parallel_w2_fe_moves_error_from_other_to_canonical(self): + """Single-error sentinel: error placed on `other` of a w2 stabilizer must + end up on `canonical` after one parallel HE pass, exactly like the + sequential path. 
The parallel path must preserve this property.""" + for d in (3, 5, 7): + hx, hz, _ = _build_parity_matrices(d) + parity_X = hx.to(torch.uint8) + parity_Z = hz.to(torch.uint8) + cache_X = build_spacelike_he_cache( + parity_X, distance=d, basis="X", device=torch.device("cpu") + ) + cache_Z = build_spacelike_he_cache( + parity_Z, distance=d, basis="Z", device=torch.device("cpu") + ) + + for basis, cache in (("X", cache_X), ("Z", cache_Z)): + w2_stabs = (cache.support_sizes == 2).nonzero().flatten().tolist() + self.assertGreater(len(w2_stabs), 0, f"d={d} {basis}: expected at least one w2") + + for s in w2_stabs: + canonical = int(cache.w2_canonical[s].item()) + other = int(cache.w2_other[s].item()) + self.assertGreaterEqual(canonical, 0) + self.assertGreaterEqual(other, 0) + + err = torch.zeros((1, 1, d * d), dtype=torch.uint8) + err[0, 0, other] = 1 + + if basis == "X": + z_in = torch.zeros_like(err) + x_in = err + else: + z_in = err + x_in = torch.zeros_like(err) + + z_out, x_out = apply_homological_equivalence_torch_vmap( + z_in, + x_in, + parity_Z, + parity_X, + d, + cache_Z=cache_Z, + cache_X=cache_X, + use_parallel_spacelike=True, + ) + out = (x_out if basis == "X" else z_out)[0, 0] + self.assertEqual( + int(out[canonical].item()), + 1, + f"d={d} {basis} s={s}: parallel did not move error to canonical " + f"(canonical={canonical}, other={other}, out={out.tolist()})", + ) + self.assertEqual( + int(out[other].item()), + 0, + f"d={d} {basis} s={s}: parallel left error at `other`", + ) + + def test_parallel_w2_fe_idempotent_on_already_canonical(self): + """Error already at `canonical` of a w2 stabilizer must stay put.""" + d = 5 + hx, hz, _ = _build_parity_matrices(d) + parity_X = hx.to(torch.uint8) + parity_Z = hz.to(torch.uint8) + cache_X = build_spacelike_he_cache( + parity_X, distance=d, basis="X", device=torch.device("cpu") + ) + cache_Z = build_spacelike_he_cache( + parity_Z, distance=d, basis="Z", device=torch.device("cpu") + ) + + for basis, cache in (("X", cache_X), ("Z", cache_Z)): + w2_stabs = (cache.support_sizes == 2).nonzero().flatten().tolist() + for s in w2_stabs: + canonical = int(cache.w2_canonical[s].item()) + err = torch.zeros((1, 1, d * d), dtype=torch.uint8) + err[0, 0, canonical] = 1 + if basis == "X": + z_in, x_in = torch.zeros_like(err), err + else: + z_in, x_in = err, torch.zeros_like(err) + z_out, x_out = apply_homological_equivalence_torch_vmap( + z_in, + x_in, + parity_Z, + parity_X, + d, + cache_Z=cache_Z, + cache_X=cache_X, + use_parallel_spacelike=True, + ) + out = (x_out if basis == "X" else z_out)[0, 0] + self.assertEqual( + int(out[canonical].item()), + 1, + f"d={d} {basis} s={s}: parallel mutated already-canonical error", + ) + self.assertEqual(int(out.sum().item()), 1) + + def test_parallel_w2_fe_matches_sequential_for_w2_only_inputs(self): + """For inputs that exercise only w2 stabilizers (single error on `other`), + parallel and sequential paths must produce bit-identical outputs. + + Stronger than the broader random-input comparison (which can legitimately + diverge on cross-color w4 tie-breaking): w2-only inputs have no w4 + action, so both paths are constrained to the same final bit pattern. 
+ """ + for d in (3, 5, 7): + hx, hz, _ = _build_parity_matrices(d) + parity_X = hx.to(torch.uint8) + parity_Z = hz.to(torch.uint8) + cache_X = build_spacelike_he_cache( + parity_X, distance=d, basis="X", device=torch.device("cpu") + ) + cache_Z = build_spacelike_he_cache( + parity_Z, distance=d, basis="Z", device=torch.device("cpu") + ) + + for basis, cache in (("X", cache_X), ("Z", cache_Z)): + w2_stabs = (cache.support_sizes == 2).nonzero().flatten().tolist() + for s in w2_stabs: + other = int(cache.w2_other[s].item()) + err = torch.zeros((1, 1, d * d), dtype=torch.uint8) + err[0, 0, other] = 1 + if basis == "X": + z_in, x_in = torch.zeros_like(err), err + else: + z_in, x_in = err, torch.zeros_like(err) + + z_seq, x_seq = apply_homological_equivalence_torch_vmap( + z_in, + x_in, + parity_Z, + parity_X, + d, + cache_Z=cache_Z, + cache_X=cache_X, + ) + z_par, x_par = apply_homological_equivalence_torch_vmap( + z_in, + x_in, + parity_Z, + parity_X, + d, + cache_Z=cache_Z, + cache_X=cache_X, + use_parallel_spacelike=True, + ) + self.assertTrue( + torch.equal(z_seq, z_par), + f"d={d} {basis} s={s}: w2-only Z mismatch seq={z_seq} par={z_par}", + ) + self.assertTrue( + torch.equal(x_seq, x_par), + f"d={d} {basis} s={s}: w2-only X mismatch seq={x_seq} par={x_par}", + ) + + def test_fix_equiv_map_cache_driven_matches_legacy_sort(self): + """The cache-driven `_build_fix_equiv_map` must produce bit-identical + output to the legacy parity-matrix sort path. The corner derivation + has a single source of truth in `cache.w4_{tl,tr,bl,br}`.""" + + def _legacy_sort_form(parity, d, error_type): + # Inline copy of the previous parity-driven implementation, kept + # locally so that any future drift in the cache-driven path is + # caught here rather than silently producing different patterns. 
+
+    def test_fix_equiv_map_cache_driven_matches_legacy_sort(self):
+        """The cache-driven `_build_fix_equiv_map` must produce bit-identical
+        output to the legacy parity-matrix sort path. The corner derivation
+        has a single source of truth in `cache.w4_{tl,tr,bl,br}`."""
+
+        def _legacy_sort_form(parity, d, error_type):
+            # Inline copy of the previous parity-driven implementation, kept
+            # locally so that any future drift in the cache-driven path is
+            # caught here rather than silently producing different patterns.
+            et = error_type.upper()
+            w4_maps = []
+            for s in range(parity.shape[0]):
+                support = torch.nonzero(parity[s], as_tuple=True)[0].tolist()
+                if len(support) != 4:
+                    continue
+                coords = sorted((idx // d, idx % d, idx) for idx in support)
+                tl, tr, bl, br = coords[0][2], coords[1][2], coords[2][2], coords[3][2]
+                patterns = torch.empty((3, 4), dtype=torch.int16)
+                patterns[0] = torch.tensor([tl, bl, tr, br], dtype=torch.int16)
+                patterns[1] = torch.tensor([bl, br, tl, tr], dtype=torch.int16)
+                if et == "X":
+                    patterns[2] = torch.tensor([tl, br, tr, bl], dtype=torch.int16)
+                else:
+                    patterns[2] = torch.tensor([tr, bl, tl, br], dtype=torch.int16)
+                w4_maps.append(patterns)
+            if not w4_maps:
+                return torch.zeros((0, 3, 4), dtype=torch.int16)
+            return torch.stack(w4_maps, dim=0)
+
+        for d in (3, 5, 7, 9):
+            hx, hz, _ = _build_parity_matrices(d)
+            for basis, parity in (("X", hx), ("Z", hz)):
+                cache = build_spacelike_he_cache(
+                    parity.to(torch.uint8),
+                    distance=d,
+                    basis=basis,
+                    device=torch.device("cpu"),
+                )
+                cache_driven = _build_fix_equiv_map(
+                    parity.to(torch.uint8),
+                    d,
+                    basis,
+                    w4_tl_cpu=cache.w4_tl.cpu(),
+                    w4_tr_cpu=cache.w4_tr.cpu(),
+                    w4_bl_cpu=cache.w4_bl.cpu(),
+                    w4_br_cpu=cache.w4_br.cpu(),
+                )
+                legacy = _legacy_sort_form(parity.to(torch.uint8), d, basis)
+                self.assertTrue(
+                    torch.equal(cache_driven, legacy),
+                    f"d={d} {basis}: cache-driven fix_equiv_map diverged from legacy sort form",
+                )
+
+    def test_build_partition_returns_named_failure_on_non_bipartite(self):
+        """A non-bipartite stabilizer-overlap graph must be rejected with a
+        diagnostic that names the offending stabilizer pair."""
+        # 3 weight-2 stabilizers forming an odd cycle in the overlap graph:
+        # stab 0: qubits {0, 1}
+        # stab 1: qubits {1, 2}
+        # stab 2: qubits {2, 0}
+        # Triangle => non-bipartite => expect failure with named pair.
+        parity = torch.zeros((3, 9), dtype=torch.uint8)
+        parity[0, 0] = 1
+        parity[0, 1] = 1
+        parity[1, 1] = 1
+        parity[1, 2] = 1
+        parity[2, 2] = 1
+        parity[2, 0] = 1
+
+        fix_map = torch.zeros((0, 3, 4), dtype=torch.int16)
+        partition, reason = _build_spacelike_partition(
+            parity,
+            fix_map,
+            torch.device("cpu"),
+        )
+        self.assertIsNone(partition)
+        self.assertIsNotNone(reason)
+        self.assertIn("not bipartite", reason)
+        # Must name a concrete offending stabilizer pair, not just "no partition".
+        self.assertTrue(
+            any(f"stabilizers {a} and {b}" in reason for a in range(3) for b in range(3) if a != b),
+            f"failure reason did not name a stabilizer pair: {reason!r}",
+        )
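The rejection behaviour this test pins is the standard bipartiteness check. A self-contained sketch of how a BFS 2-coloring both produces the two colour classes and, on an odd cycle, names the offending pair (hypothetical `two_color` helper, not the library's `_build_spacelike_partition`):

```python
from collections import deque

def two_color(num_stabs, edges):
    """BFS 2-coloring; returns (colors, None) or (None, offending_pair)."""
    adj = [[] for _ in range(num_stabs)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    color = [-1] * num_stabs
    for root in range(num_stabs):
        if color[root] != -1:
            continue
        color[root] = 0
        queue = deque([root])
        while queue:
            u = queue.popleft()
            for v in adj[u]:
                if color[v] == -1:
                    color[v] = 1 - color[u]
                    queue.append(v)
                elif color[v] == color[u]:
                    return None, (u, v)  # odd cycle: u and v share a color
    return color, None

# The triangle from the test above: stabilizer pairs that share a qubit.
colors, bad = two_color(3, [(0, 1), (1, 2), (2, 0)])
assert colors is None and bad is not None  # no coloring; offending pair named
```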
+
+    def test_parallel_partition_packed_at_cache_build_time(self):
+        """`cache.parallel_partition_packed` must be populated at cache-build
+        time when a partition is built (this avoids re-packing on every
+        compiled-path call). Empty colors must be padded to width 1 with
+        `valid=0` so the compiled chunk function has fixed input shapes
+        regardless of partition."""
+        required_keys = {
+            "parity_c0",
+            "is_boundary_c0",
+            "w4_parity_c0",
+            "src0_c0",
+            "src1_c0",
+            "dst0_c0",
+            "dst1_c0",
+            "w2_can_c0",
+            "w2_oth_c0",
+            "w2_valid_c0",
+            "parity_c1",
+            "is_boundary_c1",
+            "w4_parity_c1",
+            "src0_c1",
+            "src1_c1",
+            "dst0_c1",
+            "dst1_c1",
+            "w2_can_c1",
+            "w2_oth_c1",
+            "w2_valid_c1",
+        }
+        for d in (3, 5, 7, 9):
+            hx, hz, _ = _build_parity_matrices(d)
+            for basis, parity in (("X", hx), ("Z", hz)):
+                cache = build_spacelike_he_cache(
+                    parity.to(torch.uint8),
+                    distance=d,
+                    basis=basis,
+                    device=torch.device("cpu"),
+                )
+                packed = cache.parallel_partition_packed
+                self.assertIsNotNone(
+                    packed,
+                    f"d={d} {basis}: parallel_partition_packed not populated",
+                )
+                missing = required_keys - set(packed.keys())
+                self.assertEqual(
+                    missing,
+                    set(),
+                    f"d={d} {basis}: packed dict missing keys {missing}",
+                )
+                # Empty-color padding contract: every w2_can / w2_oth / w2_valid
+                # tensor must have a positive size (padded to 1 if natural width
+                # was 0).
+                for c in (0, 1):
+                    self.assertGreaterEqual(packed[f"w2_can_c{c}"].numel(), 1)
+                    self.assertGreaterEqual(packed[f"w2_oth_c{c}"].numel(), 1)
+                    self.assertEqual(
+                        packed[f"w2_valid_c{c}"].numel(), packed[f"w2_can_c{c}"].numel()
+                    )
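The fixed-shape rationale behind the padding contract: `torch.compile` specializes on input shapes, so an empty colour class must still present width-1 index tensors, masked off via `valid=0`. A sketch of that padding step under the stated contract (hypothetical `pad_w2_indices` helper; the real packer emits the `w2_can_c{0,1}` / `w2_oth_c{0,1}` / `w2_valid_c{0,1}` keys asserted above):

```python
import torch

def pad_w2_indices(canonical, other):
    """Hypothetical packer fragment: pad empty w2 index lists to width 1."""
    if len(canonical) == 0:
        # Fixed shapes for the compiled chunk: one dummy slot, masked off.
        can = torch.zeros(1, dtype=torch.int64)
        oth = torch.zeros(1, dtype=torch.int64)
        valid = torch.zeros(1, dtype=torch.bool)
    else:
        can = torch.tensor(canonical, dtype=torch.int64)
        oth = torch.tensor(other, dtype=torch.int64)
        valid = torch.ones(len(canonical), dtype=torch.bool)
    return can, oth, valid

can, oth, valid = pad_w2_indices([], [])
assert can.numel() == 1 and not bool(valid.any())  # padded to width 1, inert
```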
+
+    def test_compiled_parallel_reads_pre_packed_partition_off_cache(self):
+        """Production hot path must NOT re-run `_pack_partition_for_compile`
+        on every call.
+
+        An earlier iteration plumbed `cache.parallel_partition_packed` onto the
+        cache, but the production entry point (`_simplify_spacelike`)
+        unconditionally passed `partition_override=_require_parallel_partition(cache)`
+        through to `_simplify_spacelike_parallel_compiled`. With the previous
+        `if partition_override is not None: pack` dispatch, that meant the
+        cached pack was populated at build time but never read on the hot path
+        and `_pack_partition_for_compile` fired on every training-step call.
+
+        This test pins the production-call contract: across N production calls
+        through `_simplify_spacelike`, the packer fires exactly zero times
+        post-cache-build. Test-only callers that hand a *different* partition
+        dict to `_simplify_spacelike_parallel_compiled` (bench harnesses,
+        synthetic overrides) still pack on the fly -- identity, not
+        equality, decides which path runs.
+        """
+        from unittest import mock
+        from qec.surface_code import homological_equivalence_torch as he_mod
+
+        # Build the cache *outside* the patched region so the one-time pack at
+        # build time is not counted. We want to assert "post-build re-packs",
+        # not "lifetime packs".
+        hx, _, _ = _build_parity_matrices(5)
+        cache = build_spacelike_he_cache(
+            hx.to(torch.uint8),
+            distance=5,
+            basis="X",
+            device=torch.device("cpu"),
+        )
+        self.assertIsNotNone(cache.parallel_partition_packed)
+
+        # Stub the compiled chunk to an identity (`cfg_f` -> `cfg_f`). This
+        # keeps the test in pure-Python land: the dispatch logic in
+        # `_simplify_spacelike_parallel_compiled` runs unchanged, but no
+        # torch.compile / inductor / CUDA-graph machinery fires, so the test
+        # is fast on CPU. The convergence check trivially short-circuits on
+        # iteration 1, which is fine -- we are testing dispatch, not work.
+        identity_chunk = lambda cfg_f, *args: cfg_f
+        pack_wrapper = mock.Mock(wraps=he_mod._pack_partition_for_compile)
+
+        with mock.patch.object(he_mod, "_get_compiled_spacelike_chunk",
+                               return_value=identity_chunk), \
+             mock.patch.object(he_mod, "_pack_partition_for_compile", pack_wrapper):
+
+            cfg = torch.zeros((4, 5 * 5), dtype=torch.uint8)
+
+            # (1) Production entry point: 10 calls via `_simplify_spacelike`
+            # with `use_compile=True, use_parallel_spacelike=True` must
+            # produce zero `_pack_partition_for_compile` invocations.
+            for _ in range(10):
+                he_mod._simplify_spacelike(
+                    cfg,
+                    cache,
+                    basis="X",
+                    max_iterations=1,
+                    use_compile=True,
+                    use_parallel_spacelike=True,
+                )
+            self.assertEqual(
+                pack_wrapper.call_count,
+                0,
+                f"production path re-ran `_pack_partition_for_compile` "
+                f"{pack_wrapper.call_count} times across 10 calls; "
+                f"the cached pack on `cache.parallel_partition_packed` is "
+                f"populated but unreachable from the hot path",
+            )
+
+            # (2) Identity override (same object the cache packed) must also
+            # hit the cached path -- this mirrors what `_simplify_spacelike`
+            # threads through.
+            he_mod._simplify_spacelike_parallel_compiled(
+                cfg,
+                cache,
+                max_iterations=1,
+                partition_override=cache.parallel_partition,
+            )
+            self.assertEqual(
+                pack_wrapper.call_count,
+                0,
+                "identity `partition_override is cache.parallel_partition` "
+                "must read the cached pack, not re-pack",
+            )
+
+            # (3) Fresh-dict override (e.g. bench harness with a synthesized
+            # partition) MUST re-pack -- the dispatch is by identity, not
+            # equality, so callers that supply a different partition still
+            # get correct behavior.
+            fresh_partition = dict(cache.parallel_partition)
+            he_mod._simplify_spacelike_parallel_compiled(
+                cfg,
+                cache,
+                max_iterations=1,
+                partition_override=fresh_partition,
+            )
+            self.assertEqual(
+                pack_wrapper.call_count,
+                1,
+                "fresh-dict override must trigger `_pack_partition_for_compile` "
+                "(the override path is the escape hatch for bench/test harnesses)",
+            )
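Distilled, the dispatch rule the test above pins is object identity. A minimal sketch, with a hypothetical `resolve_packed` and `pack_partition` standing in for `_pack_partition_for_compile`:

```python
def resolve_packed(cache, partition_override, pack_partition):
    """Hypothetical sketch of the identity-based dispatch pinned above."""
    if partition_override is None or partition_override is cache.parallel_partition:
        # Hot path: reuse the pack built once at cache-build time.
        return cache.parallel_partition_packed
    # Escape hatch for bench/test harnesses handing in a fresh dict:
    # `is`, not `==`, so an equal-but-distinct partition still re-packs.
    return pack_partition(partition_override)
```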
+
+    def test_partition_packed_is_none_when_no_partition(self):
+        """When `basis` is not provided the parallel partition is not built;
+        the packed view must be `None` (not an empty dict)."""
+        hx, _, _ = _build_parity_matrices(3)
+        cache = build_spacelike_he_cache(
+            hx.to(torch.uint8),
+            distance=3,
+            device=torch.device("cpu"),
+        )
+        self.assertIsNone(cache.parallel_partition)
+        self.assertIsNone(cache.parallel_partition_packed)
+
+    def test_require_parallel_partition_surfaces_failure_reason(self):
+        """`_require_parallel_partition` must surface the cache's failure
+        reason, not a generic 'missing partition' message."""
+        # Build a real cache and synthesize an "as-if non-bipartite" failure
+        # state by replacing the partition and stashing a reason. Use a real
+        # cache so we don't have to duplicate the SpacelikeHECache field list.
+        hx, _, _ = _build_parity_matrices(3)
+        cache = build_spacelike_he_cache(
+            hx.to(torch.uint8),
+            distance=3,
+            basis="X",
+            device=torch.device("cpu"),
+        )
+        bad = SpacelikeHECache(
+            distance=cache.distance,
+            parity=cache.parity,
+            support_masks=cache.support_masks,
+            support_sizes=cache.support_sizes,
+            layers=cache.layers,
+            w2_canonical=cache.w2_canonical,
+            w2_other=cache.w2_other,
+            w4_tl=cache.w4_tl,
+            w4_tr=cache.w4_tr,
+            w4_bl=cache.w4_bl,
+            w4_br=cache.w4_br,
+            parallel_partition=None,
+            parallel_partition_failure_reason=(
+                "stabilizer overlap graph is not bipartite "
+                "(stabilizers 7 and 11 fall in the same color class via BFS)"
+            ),
+        )
+        with self.assertRaises(ValueError) as ctx:
+            _require_parallel_partition(bad)
+        msg = str(ctx.exception)
+        self.assertIn("stabilizers 7 and 11", msg)
+        self.assertIn("not bipartite", msg)
+
 
 # ---------------------------------------------------------------------------
 # Torch integration tests
@@ -1209,7 +1890,13 @@ def setUp(self):
             seed=2026,
         )
 
-    def _make_generator(self, num_he_cycles: int) -> MemoryCircuitTorch:
+    def _make_generator(
+        self,
+        num_he_cycles: int,
+        *,
+        use_parallel_spacelike: bool = False,
+        use_weight2: bool = False,
+    ) -> MemoryCircuitTorch:
         return MemoryCircuitTorch(
             distance=self.distance,
             n_rounds=self.n_rounds,
@@ -1218,6 +1905,9 @@
             timelike_he=True,
             num_he_cycles=num_he_cycles,
             max_passes_w1=8,
+            use_parallel_spacelike=use_parallel_spacelike,
+            use_weight2=use_weight2,
+            max_passes_w2=2,
             device=self.device,
             H=self.H,
             p=self.p,
@@ -1274,6 +1964,39 @@ def test_trainY_binary_channels(self):
         # All values should be exactly 0 or 1 for trainY channels.
         self.assertTrue(torch.all((unique == 0.0) | (unique == 1.0)))
 
+    def test_parallel_spacelike_generate_batch_shapes_and_trainX(self):
+        """Parallel spacelike HE should integrate through MemoryCircuitTorch without changing trainX."""
+        gen_seq = self._make_generator(num_he_cycles=1)
+        gen_par = self._make_generator(num_he_cycles=1, use_parallel_spacelike=True)
+
+        trainX_seq, trainY_seq = gen_seq.generate_batch(batch_size=self.batch_size, seed=314)
+        trainX_par, trainY_par = gen_par.generate_batch(batch_size=self.batch_size, seed=314)
+
+        expected = (self.batch_size, 4, self.n_rounds, self.distance, self.distance)
+        self.assertEqual(trainX_par.shape, expected)
+        self.assertEqual(trainY_par.shape, expected)
+        self.assertTrue(torch.equal(trainX_seq, trainX_par))
+        self.assertFalse(torch.isnan(trainY_par).any())
+        self.assertTrue(
+            torch.all((torch.unique(trainY_par) == 0.0) | (torch.unique(trainY_par) == 1.0))
+        )
+        # Spacelike kernel correctness (syndrome preservation, weight non-increase,
+        # idempotence) is covered directly by TestParallelSpacelikeHE; this test
+        # only smokes the integration: shapes, dtype, no NaNs, trainX unchanged.
+
+    def test_parallel_spacelike_generate_batch_with_weight2(self):
+        """Parallel spacelike should compose with weight-2 timelike HE."""
+        gen = self._make_generator(
+            num_he_cycles=1,
+            use_parallel_spacelike=True,
+            use_weight2=True,
+        )
+        trainX, trainY = gen.generate_batch(batch_size=self.batch_size, seed=2718)
+        expected = (self.batch_size, 4, self.n_rounds, self.distance, self.distance)
+        self.assertEqual(trainX.shape, expected)
+        self.assertEqual(trainY.shape, expected)
+        self.assertFalse(torch.isnan(trainY).any())
+
 
 class TestWeight2TimelikeTorchVsCPU(unittest.TestCase):
     """Compare torch weight-2 timelike HE against CPU reference."""
diff --git a/code/tests/test_gpu.py b/code/tests/test_gpu.py
index 5f549c3..8053334 100644
--- a/code/tests/test_gpu.py
+++ b/code/tests/test_gpu.py
@@ -383,6 +383,112 @@ def test_he_weight_nonincreasing_on_cuda(self):
         w_after = z_diff.sum() + x_diff.sum()
         self.assertLessEqual(w_after.item(), w_before.item() + 1e-6)
+
+    def test_parallel_spacelike_eager_on_cuda(self):
+        """Parallel spacelike HE should run on CUDA and preserve basic invariants."""
+        from qec.surface_code.memory_circuit import SurfaceCode
+        from qec.surface_code.homological_equivalence_torch import (
+            apply_homological_equivalence_torch_vmap,
+            build_spacelike_he_cache,
+        )
+
+        code = SurfaceCode(self.distance, first_bulk_syndrome_type="X", rotated_type="V")
+        parity_X = torch.tensor(code.hx, dtype=torch.uint8, device=self.device)
+        parity_Z = torch.tensor(code.hz, dtype=torch.uint8, device=self.device)
+        cache_X = build_spacelike_he_cache(
+            parity_X, distance=self.distance, basis="X", device=self.device
+        )
+        cache_Z = build_spacelike_he_cache(
+            parity_Z, distance=self.distance, basis="Z", device=self.device
+        )
+
+        B, D2, R = 4, self.distance**2, self.n_rounds
+        torch.manual_seed(2026)
+        z_in = torch.randint(0, 2, (B, R, D2), dtype=torch.uint8, device=self.device)
+        x_in = torch.randint(0, 2, (B, R, D2), dtype=torch.uint8, device=self.device)
+        z_out, x_out = apply_homological_equivalence_torch_vmap(
+            z_in,
+            x_in,
+            parity_Z,
+            parity_X,
+            self.distance,
+            cache_Z=cache_Z,
+            cache_X=cache_X,
+            use_parallel_spacelike=True,
+        )
+
+        self.assertEqual(z_out.device.type, "cuda")
+        self.assertEqual(x_out.device.type, "cuda")
+        self.assertTrue(torch.all(z_out.sum(dim=-1) <= z_in.sum(dim=-1)))
+        self.assertTrue(torch.all(x_out.sum(dim=-1) <= x_in.sum(dim=-1)))
+
+    def test_parallel_spacelike_compiled_on_cuda(self):
+        """Compiled parallel spacelike HE should run on CUDA and produce
+        bit-identical output to the eager parallel path. Equality locks in
+        the pack-once cache field (`cache.parallel_partition_packed`) and the
+        float-only chunk convergence check that the parallel path introduces.
+        """
+        from qec.surface_code.memory_circuit import SurfaceCode
+        from qec.surface_code.homological_equivalence_torch import (
+            apply_homological_equivalence_torch_vmap,
+            build_spacelike_he_cache,
+        )
+
+        code = SurfaceCode(self.distance, first_bulk_syndrome_type="X", rotated_type="V")
+        parity_X = torch.tensor(code.hx, dtype=torch.uint8, device=self.device)
+        parity_Z = torch.tensor(code.hz, dtype=torch.uint8, device=self.device)
+        cache_X = build_spacelike_he_cache(
+            parity_X, distance=self.distance, basis="X", device=self.device
+        )
+        cache_Z = build_spacelike_he_cache(
+            parity_Z, distance=self.distance, basis="Z", device=self.device
+        )
+
+        # The packed partition view must be populated so the compiled path
+        # can read it directly instead of re-packing on every call.
+        self.assertIsNotNone(cache_X.parallel_partition_packed)
+        self.assertIsNotNone(cache_Z.parallel_partition_packed)
+
+        B, D2, R = 2, self.distance**2, self.n_rounds
+        torch.manual_seed(2027)
+        z_in = torch.randint(0, 2, (B, R, D2), dtype=torch.uint8, device=self.device)
+        x_in = torch.randint(0, 2, (B, R, D2), dtype=torch.uint8, device=self.device)
+
+        z_eager, x_eager = apply_homological_equivalence_torch_vmap(
+            z_in,
+            x_in,
+            parity_Z,
+            parity_X,
+            self.distance,
+            cache_Z=cache_Z,
+            cache_X=cache_X,
+            use_compile=False,
+            use_parallel_spacelike=True,
+        )
+        z_out, x_out = apply_homological_equivalence_torch_vmap(
+            z_in,
+            x_in,
+            parity_Z,
+            parity_X,
+            self.distance,
+            cache_Z=cache_Z,
+            cache_X=cache_X,
+            use_compile=True,
+            use_parallel_spacelike=True,
+        )
+
+        self.assertEqual(z_out.device.type, "cuda")
+        self.assertEqual(x_out.device.type, "cuda")
+        self.assertTrue(torch.all((z_out == 0) | (z_out == 1)))
+        self.assertTrue(torch.all((x_out == 0) | (x_out == 1)))
+        self.assertTrue(
+            torch.equal(z_eager, z_out),
+            "compiled parallel Z output diverged from eager parallel",
+        )
+        self.assertTrue(
+            torch.equal(x_eager, x_out),
+            "compiled parallel X output diverged from eager parallel",
+        )
+
 
 # ---------------------------------------------------------------------------
 # Oracle predecoder residuals on GPU
diff --git a/code/tests/test_public_config.py b/code/tests/test_public_config.py
index edbe640..e6537a3 100644
--- a/code/tests/test_public_config.py
+++ b/code/tests/test_public_config.py
@@ -158,6 +158,64 @@ def test_validate_accepts_noise_model(self):
         expected_frames_dir = (repo_root / "frames_data").resolve()
         self.assertEqual(Path(merged.data.precomputed_frames_dir).resolve(), expected_frames_dir)
 
+    def test_validate_accepts_use_parallel_spacelike_flag(self):
+        cfg = OmegaConf.create(
+            {
+                "model_id": 1,
+                "distance": 9,
+                "n_rounds": 9,
+                "data": {
+                    "use_parallel_spacelike": True
+                },
+            }
+        )
+        spec = validate_public_config(cfg)
+        merged = apply_public_defaults_and_model(cfg, spec)
+        self.assertTrue(bool(merged.data.use_parallel_spacelike))
+
+    def test_validate_rejects_nonbool_use_parallel_spacelike_flag(self):
+        cfg = OmegaConf.create(
+            {
+                "model_id": 1,
+                "distance": 9,
+                "n_rounds": 9,
+                "data": {
+                    "use_parallel_spacelike": "yes"
+                },
+            }
+        )
+        with self.assertRaises(ValueError):
+            validate_public_config(cfg)
+
+    def test_validate_accepts_use_compile_flag(self):
+        cfg = OmegaConf.create(
+            {
+                "model_id": 1,
+                "distance": 9,
+                "n_rounds": 9,
+                "data": {
+                    "use_compile": True
+                },
+            }
+        )
+        spec = validate_public_config(cfg)
+        merged = apply_public_defaults_and_model(cfg, spec)
+        self.assertTrue(bool(merged.data.use_compile))
+
+    def test_validate_rejects_nonbool_use_compile_flag(self):
+        cfg = OmegaConf.create(
+            {
+                "model_id": 1,
+                "distance": 9,
+                "n_rounds": 9,
+                "data": {
+                    "use_compile": "true"
+                },
+            }
+        )
+        with self.assertRaises(ValueError):
+            validate_public_config(cfg)
+
     def test_validate_rejects_optimizer_subfields(self):
         cfg = OmegaConf.create(
             {
diff --git a/code/training/train.py b/code/training/train.py
index 294e367..281cec7 100644
--- a/code/training/train.py
+++ b/code/training/train.py
@@ -45,7 +45,8 @@
 HE acceleration (forwarded to QCDataGeneratorTorch):
     cfg.data.use_compile, cfg.data.compile_chunk_size, cfg.data.compute_dtype,
     cfg.data.use_weight2, cfg.data.max_passes_w2, cfg.data.use_coset_search,
-    cfg.data.coset_max_generators, cfg.data.use_dense_overlap
+    cfg.data.coset_max_generators, cfg.data.use_dense_overlap,
+    cfg.data.use_parallel_spacelike
 
 Logging:
     Compact epoch summary: loss | LER | SDR | wall time | throughput
@@ -948,6 +949,7 @@ def is_list_like(obj):
         use_coset_search=bool(getattr(cfg.data, 'use_coset_search', False)),
         coset_max_generators=int(getattr(cfg.data, 'coset_max_generators', 20)),
         use_dense_overlap=bool(getattr(cfg.data, 'use_dense_overlap', False)),
+        use_parallel_spacelike=bool(getattr(cfg.data, 'use_parallel_spacelike', False)),
     )
 
     if use_multi_pairs:
diff --git a/code/workflows/config_validator.py b/code/workflows/config_validator.py
index 7d8974f..941f2ef 100644
--- a/code/workflows/config_validator.py
+++ b/code/workflows/config_validator.py
@@ -138,6 +138,7 @@ def _base_hidden_defaults_dict() -> Dict[str, Any]:
         "timelike_he": True,
         "num_he_cycles": 1,
         "use_weight2_timelike": False,
+        "use_parallel_spacelike": False,
         "max_passes_w1": 8,
         "max_passes_w2": 4,
         "decompose_y": True,
@@ -358,13 +359,33 @@ def validate_public_config(cfg: DictConfig) -> PublicModelSpec:
             "Config field 'data.precomputed_frames_dir' is not supported in the public release. "
             "Remove it from the config/CLI overrides."
         )
-    allowed_data_keys = {"code_rotation", "noise_model"}
+    # `use_compile` and `use_parallel_spacelike` are HE-acceleration flags
+    # surfaced in the public release. Both default False; users opt in via
+    # `conf/config_public.yaml` or CLI override. See README.md, section
+    # "HE acceleration (advanced): parallel spacelike" for the contract.
+    allowed_data_keys = {
+        "code_rotation",
+        "noise_model",
+        "use_compile",
+        "use_parallel_spacelike",
+    }
     for k in cfg.data.keys():
         if k not in allowed_data_keys:
             raise ValueError(
                 f"Config field 'data.{k}' is not supported in the public release. "
                 f"Allowed data fields are: {sorted(allowed_data_keys)}"
             )
+    # These two flags are part of the public config surface, so keep their
+    # accepted type stricter than hidden/internal HE knobs that are merged
+    # from trusted defaults. OmegaConf accepts strings like "True"/"yes",
+    # which would otherwise flow into downstream `bool(...)` casts and
+    # become truthy regardless of the user's intent.
+    for bool_key in ("use_compile", "use_parallel_spacelike"):
+        if bool_key in cfg.data and not isinstance(cfg.data[bool_key], bool):
+            raise ValueError(
+                f"Config field 'data.{bool_key}' must be a boolean "
+                f"(got {type(cfg.data[bool_key]).__name__}: {cfg.data[bool_key]!r})."
+            )
     # Validate rotation value (accept O1..O4; also allow internal XV/XH/ZV/ZH for compatibility).
     if "code_rotation" in cfg.data:
         _normalize_code_rotation(cfg.data.code_rotation)
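The failure mode the strict isinstance check closes off can be reproduced with plain OmegaConf (standalone sketch, no project code): a quoted boolean survives as `str`, and a later `bool(...)` cast is truthy for every non-empty string.

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"data": {"use_compile": "False"}})  # quoted: stays a str
value = cfg.data.use_compile

assert isinstance(value, str) and not isinstance(value, bool)
assert bool(value) is True  # non-empty string is truthy despite reading "False"
```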
if "code_rotation" in cfg.data: _normalize_code_rotation(cfg.data.code_rotation) diff --git a/conf/config_pre_decoder_memory_surface_model_1_d9.yaml b/conf/config_pre_decoder_memory_surface_model_1_d9.yaml index 19a4781..a7624f1 100644 --- a/conf/config_pre_decoder_memory_surface_model_1_d9.yaml +++ b/conf/config_pre_decoder_memory_surface_model_1_d9.yaml @@ -36,6 +36,7 @@ data: generator: torch # Training data generator (Torch-only) timelike_he: True # Enable timelike homological equivalence num_he_cycles: 1 # Number of (spacelike+timelike) HE cycles to apply (default: 1) + use_parallel_spacelike: False # Enable 2-partition parallel spacelike HE acceleration use_weight2: False # Enable weight-2 timelike HE (in addition to weight-1) max_passes_w1: 8 # Maximum passes for weight-1 timelike HE convergence (default: 32) max_passes_w2: 4 # Maximum passes for weight-2 timelike HE convergence (default: 32) diff --git a/conf/config_public.yaml b/conf/config_public.yaml index 1170e64..a4fea8e 100644 --- a/conf/config_public.yaml +++ b/conf/config_public.yaml @@ -35,6 +35,12 @@ workflow: data: # Surface code orientation (public naming): O1, O2, O3, O4 code_rotation: O1 + # Optional HE acceleration flag, surfaced here for discoverability (other HE + # knobs live in the internal defaults). Enables the 2-partition parallel + # spacelike homological-equivalence path; see the "HE acceleration (advanced): + # parallel spacelike" section in README.md for the pros/cons and constraints. + use_compile: False # Required to see the speedup from use_parallel_spacelike=True. + use_parallel_spacelike: False # Circuit-level noise model (25-parameter). This is the default public noise specification. # The defaults are chosen for p=0.003. noise_model: