diff --git a/CHANGELOG.md b/CHANGELOG.md index f668209..125d4f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added +- Fast path in the detector sampler for components whose output is deterministically given by a single error variable. These components now skip the JAX compilation and autoregressive sampling pipeline, significantly speeding up detector sampling for surface-code circuits at low physical error rates. + + + ## [0.1.4] - 2026-04-28 ### Fixed diff --git a/src/tsim/circuit.py b/src/tsim/circuit.py index 77eb160..0921056 100644 --- a/src/tsim/circuit.py +++ b/src/tsim/circuit.py @@ -740,6 +740,11 @@ def compile_detector_sampler( ) -> CompiledDetectorSampler: """Compile circuit into a detector sampler. + Connected components whose single output is deterministically given by + one f-variable are handled via a fast direct path (no compilation or + autoregressive sampling). Remaining components go through the full + compilation pipeline. + Args: strategy: Stabilizer rank decomposition strategy. Must be one of "cat5", "bss", "cutting". diff --git a/src/tsim/compile/pipeline.py b/src/tsim/compile/pipeline.py index 8668c21..4789c78 100644 --- a/src/tsim/compile/pipeline.py +++ b/src/tsim/compile/pipeline.py @@ -5,13 +5,19 @@ from typing import Literal import jax.numpy as jnp +import numpy as np import pyzx_param as zx from pyzx_param.graph.base import BaseGraph from pyzx_param.simulate import DecompositionStrategy from tsim.compile.compile import CompiledScalarGraphs, compile_scalar_graphs from tsim.compile.stabrank import find_stab -from tsim.core.graph import ConnectedComponent, connected_components, get_params +from tsim.core.graph import ( + ConnectedComponent, + classify_direct, + connected_components, + get_params, +) from tsim.core.types import CompiledComponent, CompiledProgram, SamplingGraph DecompositionMode = Literal["sequential", "joint"] @@ -52,24 +58,45 @@ def compile_program( f_indices_global = _get_f_indices(prepared.graph) num_outputs = prepared.num_outputs + direct_entries: list[tuple[int, int, bool]] = [] # (output_idx, f_idx, flip) compiled_components: list[CompiledComponent] = [] - output_order: list[int] = [] + compiled_output_order: list[int] = [] sorted_components = sorted(components, key=lambda c: len(c.output_indices)) for component in sorted_components: - compiled = _compile_component( - component=component, - f_indices_global=f_indices_global, - mode=mode, - strategy=strategy, - ) - compiled_components.append(compiled) - output_order.extend(component.output_indices) + direct = classify_direct(component) + if direct is not None: + f_idx, flip = direct + direct_entries.append((component.output_indices[0], f_idx, flip)) + else: + compiled = _compile_component( + component=component, + f_indices_global=f_indices_global, + mode=mode, + strategy=strategy, + ) + compiled_components.append(compiled) + compiled_output_order.extend(component.output_indices) + + # Sort direct entries by output index so that — together with the output + # prioritisation in transform_error_basis — the concatenated layout often + # matches the original output order, sparing a reindex at sample time. + direct_entries.sort() + direct_output_order = [e[0] for e in direct_entries] + direct_f_indices = [e[1] for e in direct_entries] + direct_flips = [e[2] for e in direct_entries] + + output_order = np.array(direct_output_order + compiled_output_order, dtype=np.int32) + reindex = np.argsort(output_order) + is_identity = np.array_equal(reindex, np.arange(len(output_order))) return CompiledProgram( components=tuple(compiled_components), - output_order=jnp.array(output_order, dtype=jnp.int32), + direct_f_indices=jnp.array(direct_f_indices, dtype=jnp.int32), + direct_flips=jnp.array(direct_flips, dtype=jnp.bool_), + output_order=jnp.asarray(output_order), + output_reindex=None if is_identity else jnp.asarray(reindex), num_outputs=num_outputs, num_detectors=prepared.num_detectors, ) diff --git a/src/tsim/core/graph.py b/src/tsim/core/graph.py index 0cc2491..66db35b 100644 --- a/src/tsim/core/graph.py +++ b/src/tsim/core/graph.py @@ -13,7 +13,7 @@ from pyzx_param.graph.graph import Graph from pyzx_param.graph.graph_s import GraphS from pyzx_param.graph.scalar import Scalar -from pyzx_param.utils import VertexType +from pyzx_param.utils import EdgeType, VertexType from tsim.core.instructions import GraphRepresentation from tsim.core.parse import parse_stim_circuit @@ -65,6 +65,65 @@ def connected_components(g: BaseGraph) -> list[ConnectedComponent]: return components +def classify_direct( + component: ConnectedComponent, +) -> tuple[int, bool] | None: + """Check if a component is directly determined by a single f-variable. + + A component qualifies when its graph consists of exactly two vertices — one + boundary output and one Z-spider — connected by a Hadamard edge, where the + Z-spider carries a single ``f`` parameter and a constant phase of either 0 + (no flip) or π (flip). + + Args: + component: A connected component to classify. + + Returns: + ``(f_index, flip)`` if the fast path applies, otherwise ``None``. + + """ + graph = component.graph + outputs = list(graph.outputs()) + if len(outputs) != 1: + return None + + vertices = list(graph.vertices()) + if len(vertices) != 2: + return None + + v_out = outputs[0] + neighbors = list(graph.neighbors(v_out)) + if len(neighbors) != 1: + return None + + v_det = neighbors[0] + if graph.type(v_det) != VertexType.Z: + return None + if graph.edge_type(graph.edge(v_out, v_det)) != EdgeType.HADAMARD: + return None + + params = graph.get_params(v_det) + if len(params) != 1: + return None + f_param = next(iter(params)) + if not f_param.startswith("f"): + return None + + all_graph_params = get_params(graph) + if all_graph_params != {f_param}: + return None + + phase = graph.phase(v_det) + if phase == 0: + flip = False + elif phase == Fraction(1, 1): + flip = True + else: + return None + + return int(f_param[1:]), flip + + def _collect_vertices( g: BaseGraph, start: Any, @@ -274,7 +333,21 @@ def transform_error_basis( then f0 = e1 XOR e3. """ - parametrized_vertices = [v for v in g.vertices() if g._phaseVars.get(v)] + # Prioritize output-connected detector vertices so that f0, f1, ... + # are assigned in output order. This maximises the chance that the + # direct-component fast path produces an identity permutation, avoiding + # a column reindex at sample time. + output_detectors = [] + for v_out in g.outputs(): + neighbors = list(g.neighbors(v_out)) + if len(neighbors) == 1 and g._phaseVars.get(neighbors[0]): + output_detectors.append(neighbors[0]) + + output_det_set = set(output_detectors) + other_param_vertices = [ + v for v in g.vertices() if v not in output_det_set and g._phaseVars.get(v) + ] + parametrized_vertices = output_detectors + other_param_vertices if not parametrized_vertices: g.scalar = Scalar() diff --git a/src/tsim/core/types.py b/src/tsim/core/types.py index 6757eac..4691b56 100644 --- a/src/tsim/core/types.py +++ b/src/tsim/core/types.py @@ -86,14 +86,22 @@ class CompiledProgram: Attributes: components: The compiled components, sorted by number of outputs. - output_order: Array for reordering component outputs to final order. - final_samples = combined[:, np.argsort(output_order)] + direct_f_indices: Precomputed f-parameter indices for direct components. + direct_flips: Precomputed flip flags for direct components. + output_order: Maps concatenated position to original output index. + The first ``len(direct_f_indices)`` entries correspond to direct + components; the remainder to compiled components. + output_reindex: Precomputed ``argsort(output_order)`` permutation, + or ``None`` when the outputs are already in order. num_outputs: Total number of outputs across all components. num_detectors: Number of detector outputs (for detector sampling). """ components: tuple[CompiledComponent, ...] + direct_f_indices: Array + direct_flips: Array output_order: Array + output_reindex: Array | None num_outputs: int num_detectors: int diff --git a/src/tsim/sampler.py b/src/tsim/sampler.py index af94d1e..bfe4e74 100644 --- a/src/tsim/sampler.py +++ b/src/tsim/sampler.py @@ -130,11 +130,18 @@ def sample_program( match the original output indices. """ - if len(program.components) == 0: + results: list[jax.Array] = [] + + if program.num_outputs == 0: batch_size = f_params.shape[0] - return jnp.empty((batch_size, 0), dtype=jnp.bool_) + return jnp.zeros((batch_size, 0), dtype=jnp.bool_) - results: list[jax.Array] = [] + if len(program.direct_f_indices) > 0: + direct_bits = ( + f_params[:, program.direct_f_indices].astype(jnp.bool_) + ^ program.direct_flips + ) + results.append(direct_bits) for component in program.components: samples, key, max_norm_deviation = sample_component(component, f_params, key) @@ -154,7 +161,9 @@ def sample_program( results.append(samples) combined = jnp.concatenate(results, axis=1) - return combined[:, jnp.argsort(program.output_order)] + if program.output_reindex is not None: + combined = combined[:, program.output_reindex] + return combined class _CompiledSamplerBase: @@ -200,6 +209,22 @@ def __init__( self.circuit = circuit self._num_detectors = prepared.num_detectors + prog = self._program + self._direct_f_indices = np.asarray(prog.direct_f_indices) + self._direct_flips = np.asarray(prog.direct_flips, dtype=np.bool_) + self._direct_reindex = ( + np.asarray(prog.output_reindex) if prog.output_reindex is not None else None + ) + # Zero-copy fast path: f-indices are 0..n-1, no flips, no reindex. + # Hit by typical surface-code detector circuits at low noise. + n_direct = len(self._direct_f_indices) + self._direct_zero_copy = ( + n_direct > 0 + and self._direct_reindex is None + and not self._direct_flips.any() + and np.array_equal(self._direct_f_indices, np.arange(n_direct)) + ) + def _peak_bytes_per_sample(self) -> int: """Estimate peak device memory per sample from compiled program structure.""" peak = 0 @@ -276,6 +301,9 @@ def _sample_batches( return empty, np.zeros(self._program.num_outputs, dtype=np.bool_) return empty + if not self._program.components and not compute_reference: + return self._sample_direct(shots) + if batch_size is None: max_batch_size = self._estimate_batch_size() num_batches = max(1, ceil(shots / max_batch_size)) @@ -314,8 +342,20 @@ def _sample_batches( return result, reference return result + def _sample_direct(self, shots: int) -> np.ndarray: + """Fast path when all components are direct (pure numpy, no JAX).""" + f_params = self._channel_sampler.sample(shots) + if self._direct_zero_copy: + return f_params[:, : len(self._direct_f_indices)].view(np.bool_) + result = f_params[:, self._direct_f_indices] ^ self._direct_flips + if self._direct_reindex is not None: + result = result[:, self._direct_reindex] + return result.view(np.bool_) + def __repr__(self) -> str: """Return a string representation with compilation statistics.""" + n_direct = len(self._program.direct_f_indices) + c_graphs = [] c_params = [] c_a_terms = [] @@ -356,7 +396,8 @@ def _format_bytes(n: int) -> str: ) return ( - f"{type(self).__name__}({np.sum(c_graphs)} graphs, " + f"{type(self).__name__}({n_direct} direct, " + f"{np.sum(c_graphs)} graphs, " f"{error_channel_bits} error channel bits, " f"{np.max(num_outputs) if num_outputs else 0} outputs for largest cc, " f"≤ {np.max(c_params) if c_params else 0} parameters, {np.sum(c_a_terms)} A terms, " @@ -637,6 +678,15 @@ def probability_of(self, state: np.ndarray, *, batch_size: int) -> np.ndarray: p_norm = jnp.ones(batch_size) p_joint = jnp.ones(batch_size) + if len(self._program.direct_f_indices) > 0: + direct_bits = ( + f_samples[:, self._program.direct_f_indices].astype(jnp.bool_) + ^ self._program.direct_flips + ) + n_direct = len(self._program.direct_f_indices) + targets = state[self._program.output_order[:n_direct]] + p_joint = p_joint * (direct_bits == targets).all(axis=1) + for component in self._program.components: assert len(component.compiled_scalar_graphs) == 2 diff --git a/test/integration/test_sampler.py b/test/integration/test_sampler.py index d3780a4..e68579e 100644 --- a/test/integration/test_sampler.py +++ b/test/integration/test_sampler.py @@ -120,7 +120,10 @@ def test_sample_program_raises_on_component_norm_deviation(monkeypatch): ) program = CompiledProgram( components=components, + direct_f_indices=jnp.array([], dtype=jnp.int32), + direct_flips=jnp.array([], dtype=jnp.bool_), output_order=jnp.array([0, 1]), + output_reindex=None, num_outputs=2, num_detectors=0, ) diff --git a/test/unit/test_sampler.py b/test/unit/test_sampler.py index 7b40a5a..1562eb1 100644 --- a/test/unit/test_sampler.py +++ b/test/unit/test_sampler.py @@ -168,6 +168,15 @@ def test_detector_sampler_no_detectors_bit_packed(): assert result.shape == (5, 0) +def test_detector_sampler_no_detectors_with_reference_sample(): + """Detector sampler with use_detector_reference_sample=True and no detectors returns (shots, 0).""" + c = Circuit("H 0\nM 0") + sampler = c.compile_detector_sampler() + result = sampler.sample(5, use_detector_reference_sample=True) + assert result.dtype == np.bool_ + assert result.shape == (5, 0) + + def test_sampler_negative_shots_raises(): sampler = Circuit("H 0\nM 0").compile_sampler(seed=0) with pytest.raises(ValueError, match="shots must be non-negative"):