Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added
- Fast path in the detector sampler for components whose output is deterministically given by a single error variable. These components now skip the JAX compilation and autoregressive sampling pipeline, significantly speeding up detector sampling for surface-code circuits at low physical error rates.



## [0.1.4] - 2026-04-28

### Fixed
Expand Down
5 changes: 5 additions & 0 deletions src/tsim/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,11 @@ def compile_detector_sampler(
) -> CompiledDetectorSampler:
"""Compile circuit into a detector sampler.
Connected components whose single output is deterministically given by
one f-variable are handled via a fast direct path (no compilation or
autoregressive sampling). Remaining components go through the full
compilation pipeline.
Args:
strategy: Stabilizer rank decomposition strategy.
Must be one of "cat5", "bss", "cutting".
Expand Down
49 changes: 38 additions & 11 deletions src/tsim/compile/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@
from typing import Literal

import jax.numpy as jnp
import numpy as np
import pyzx_param as zx
from pyzx_param.graph.base import BaseGraph
from pyzx_param.simulate import DecompositionStrategy

from tsim.compile.compile import CompiledScalarGraphs, compile_scalar_graphs
from tsim.compile.stabrank import find_stab
from tsim.core.graph import ConnectedComponent, connected_components, get_params
from tsim.core.graph import (
ConnectedComponent,
classify_direct,
connected_components,
get_params,
)
from tsim.core.types import CompiledComponent, CompiledProgram, SamplingGraph

DecompositionMode = Literal["sequential", "joint"]
Expand Down Expand Up @@ -52,24 +58,45 @@ def compile_program(
f_indices_global = _get_f_indices(prepared.graph)
num_outputs = prepared.num_outputs

direct_entries: list[tuple[int, int, bool]] = [] # (output_idx, f_idx, flip)
compiled_components: list[CompiledComponent] = []
output_order: list[int] = []
compiled_output_order: list[int] = []

sorted_components = sorted(components, key=lambda c: len(c.output_indices))

for component in sorted_components:
compiled = _compile_component(
component=component,
f_indices_global=f_indices_global,
mode=mode,
strategy=strategy,
)
compiled_components.append(compiled)
output_order.extend(component.output_indices)
direct = classify_direct(component)
if direct is not None:
f_idx, flip = direct
direct_entries.append((component.output_indices[0], f_idx, flip))
else:
compiled = _compile_component(
component=component,
f_indices_global=f_indices_global,
mode=mode,
strategy=strategy,
)
compiled_components.append(compiled)
compiled_output_order.extend(component.output_indices)

# Sort direct entries by output index so that — together with the output
# prioritisation in transform_error_basis — the concatenated layout often
# matches the original output order, sparing a reindex at sample time.
direct_entries.sort()
direct_output_order = [e[0] for e in direct_entries]
direct_f_indices = [e[1] for e in direct_entries]
direct_flips = [e[2] for e in direct_entries]

output_order = np.array(direct_output_order + compiled_output_order, dtype=np.int32)
reindex = np.argsort(output_order)
is_identity = np.array_equal(reindex, np.arange(len(output_order)))

return CompiledProgram(
components=tuple(compiled_components),
output_order=jnp.array(output_order, dtype=jnp.int32),
direct_f_indices=jnp.array(direct_f_indices, dtype=jnp.int32),
direct_flips=jnp.array(direct_flips, dtype=jnp.bool_),
output_order=jnp.asarray(output_order),
output_reindex=None if is_identity else jnp.asarray(reindex),
num_outputs=num_outputs,
num_detectors=prepared.num_detectors,
)
Expand Down
77 changes: 75 additions & 2 deletions src/tsim/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pyzx_param.graph.graph import Graph
from pyzx_param.graph.graph_s import GraphS
from pyzx_param.graph.scalar import Scalar
from pyzx_param.utils import VertexType
from pyzx_param.utils import EdgeType, VertexType

from tsim.core.instructions import GraphRepresentation
from tsim.core.parse import parse_stim_circuit
Expand Down Expand Up @@ -65,6 +65,65 @@ def connected_components(g: BaseGraph) -> list[ConnectedComponent]:
return components


def classify_direct(
    component: ConnectedComponent,
) -> tuple[int, bool] | None:
    """Check if a component is directly determined by a single f-variable.

    A component qualifies when its graph consists of exactly two vertices — one
    output and one Z-spider — connected by a Hadamard edge, where the Z-spider
    carries a single ``f<digits>`` parameter, no other parameter appears
    anywhere in the graph, and the spider's constant phase is either 0
    (no flip) or π (flip).

    Args:
        component: A connected component to classify.

    Returns:
        ``(f_index, flip)`` if the fast path applies, otherwise ``None``.
        ``f_index`` is the integer suffix of the parameter name.

    """
    graph = component.graph

    outputs = list(graph.outputs())
    if len(outputs) != 1:
        return None
    if len(list(graph.vertices())) != 2:
        return None

    v_out = outputs[0]
    neighbors = list(graph.neighbors(v_out))
    if len(neighbors) != 1:
        return None

    v_det = neighbors[0]
    if graph.type(v_det) != VertexType.Z:
        return None
    if graph.edge_type(graph.edge(v_out, v_det)) != EdgeType.HADAMARD:
        return None

    params = graph.get_params(v_det)
    if len(params) != 1:
        return None
    f_param = next(iter(params))

    # Accept only names of the exact form ``f<ascii digits>``. A bare
    # ``int(f_param[1:])`` would raise ValueError for names such as "f" or
    # "foo" (and silently accept "f+5" or "f1_0"); a name that does not match
    # simply means the fast path does not apply.
    suffix = f_param[1:]
    if not f_param.startswith("f"):
        return None
    if not (suffix.isascii() and suffix.isdigit()):
        return None

    # The f-variable on the spider must be the only parameter in the whole
    # component; anything else means the output is not given by it alone.
    if get_params(graph) != {f_param}:
        return None

    phase = graph.phase(v_det)
    if phase == 0:
        flip = False
    elif phase == Fraction(1, 1):
        flip = True
    else:
        return None

    return int(suffix), flip


def _collect_vertices(
g: BaseGraph,
start: Any,
Expand Down Expand Up @@ -274,7 +333,21 @@ def transform_error_basis(
then f0 = e1 XOR e3.

"""
parametrized_vertices = [v for v in g.vertices() if g._phaseVars.get(v)]
# Prioritize output-connected detector vertices so that f0, f1, ...
# are assigned in output order. This maximises the chance that the
# direct-component fast path produces an identity permutation, avoiding
# a column reindex at sample time.
output_detectors = []
for v_out in g.outputs():
neighbors = list(g.neighbors(v_out))
if len(neighbors) == 1 and g._phaseVars.get(neighbors[0]):
output_detectors.append(neighbors[0])

output_det_set = set(output_detectors)
other_param_vertices = [
v for v in g.vertices() if v not in output_det_set and g._phaseVars.get(v)
]
parametrized_vertices = output_detectors + other_param_vertices

if not parametrized_vertices:
g.scalar = Scalar()
Expand Down
12 changes: 10 additions & 2 deletions src/tsim/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,22 @@ class CompiledProgram:

Attributes:
components: The compiled components, sorted by number of outputs.
output_order: Array for reordering component outputs to final order.
final_samples = combined[:, np.argsort(output_order)]
direct_f_indices: Precomputed f-parameter indices for direct components.
direct_flips: Precomputed flip flags for direct components.
output_order: Maps concatenated position to original output index.
The first ``len(direct_f_indices)`` entries correspond to direct
components; the remainder to compiled components.
output_reindex: Precomputed ``argsort(output_order)`` permutation,
or ``None`` when the outputs are already in order.
num_outputs: Total number of outputs across all components.
num_detectors: Number of detector outputs (for detector sampling).

"""

components: tuple[CompiledComponent, ...]
direct_f_indices: Array
direct_flips: Array
output_order: Array
output_reindex: Array | None
num_outputs: int
num_detectors: int
60 changes: 55 additions & 5 deletions src/tsim/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,18 @@ def sample_program(
match the original output indices.

"""
if len(program.components) == 0:
results: list[jax.Array] = []

if program.num_outputs == 0:
batch_size = f_params.shape[0]
return jnp.empty((batch_size, 0), dtype=jnp.bool_)
return jnp.zeros((batch_size, 0), dtype=jnp.bool_)

results: list[jax.Array] = []
if len(program.direct_f_indices) > 0:
direct_bits = (
f_params[:, program.direct_f_indices].astype(jnp.bool_)
^ program.direct_flips
)
results.append(direct_bits)

for component in program.components:
samples, key, max_norm_deviation = sample_component(component, f_params, key)
Expand All @@ -154,7 +161,9 @@ def sample_program(
results.append(samples)

combined = jnp.concatenate(results, axis=1)
return combined[:, jnp.argsort(program.output_order)]
if program.output_reindex is not None:
combined = combined[:, program.output_reindex]
return combined


class _CompiledSamplerBase:
Expand Down Expand Up @@ -200,6 +209,22 @@ def __init__(
self.circuit = circuit
self._num_detectors = prepared.num_detectors

prog = self._program
self._direct_f_indices = np.asarray(prog.direct_f_indices)
self._direct_flips = np.asarray(prog.direct_flips, dtype=np.bool_)
self._direct_reindex = (
np.asarray(prog.output_reindex) if prog.output_reindex is not None else None
)
# Zero-copy fast path: f-indices are 0..n-1, no flips, no reindex.
# Hit by typical surface-code detector circuits at low noise.
n_direct = len(self._direct_f_indices)
self._direct_zero_copy = (
n_direct > 0
and self._direct_reindex is None
and not self._direct_flips.any()
and np.array_equal(self._direct_f_indices, np.arange(n_direct))
)

def _peak_bytes_per_sample(self) -> int:
"""Estimate peak device memory per sample from compiled program structure."""
peak = 0
Expand Down Expand Up @@ -276,6 +301,9 @@ def _sample_batches(
return empty, np.zeros(self._program.num_outputs, dtype=np.bool_)
return empty

if not self._program.components and not compute_reference:
return self._sample_direct(shots)

if batch_size is None:
max_batch_size = self._estimate_batch_size()
num_batches = max(1, ceil(shots / max_batch_size))
Expand Down Expand Up @@ -314,8 +342,20 @@ def _sample_batches(
return result, reference
return result

def _sample_direct(self, shots: int) -> np.ndarray:
    """Sample outputs when every component is direct (numpy only, no JAX).

    Args:
        shots: Number of samples requested from the channel sampler.

    Returns:
        Boolean array with one row per sample and one column per direct
        output.

    """
    sampled = self._channel_sampler.sample(shots)

    if self._direct_zero_copy:
        # Direct f-indices are exactly 0..n-1 with no flips and no reindex,
        # so the leading columns already are the answer.
        n_direct = len(self._direct_f_indices)
        return sampled[:, :n_direct].view(np.bool_)

    # Gather each output's f-column and apply its per-output flip.
    bits = sampled[:, self._direct_f_indices] ^ self._direct_flips
    permutation = self._direct_reindex
    if permutation is None:
        # NOTE(review): ``view(np.bool_)`` reinterprets the bytes in place;
        # assumes the sampler yields 1-byte 0/1 values — TODO confirm.
        return bits.view(np.bool_)
    return bits[:, permutation].view(np.bool_)

def __repr__(self) -> str:
"""Return a string representation with compilation statistics."""
n_direct = len(self._program.direct_f_indices)

c_graphs = []
c_params = []
c_a_terms = []
Expand Down Expand Up @@ -356,7 +396,8 @@ def _format_bytes(n: int) -> str:
)

return (
f"{type(self).__name__}({np.sum(c_graphs)} graphs, "
f"{type(self).__name__}({n_direct} direct, "
f"{np.sum(c_graphs)} graphs, "
f"{error_channel_bits} error channel bits, "
f"{np.max(num_outputs) if num_outputs else 0} outputs for largest cc, "
f"≤ {np.max(c_params) if c_params else 0} parameters, {np.sum(c_a_terms)} A terms, "
Expand Down Expand Up @@ -637,6 +678,15 @@ def probability_of(self, state: np.ndarray, *, batch_size: int) -> np.ndarray:
p_norm = jnp.ones(batch_size)
p_joint = jnp.ones(batch_size)

if len(self._program.direct_f_indices) > 0:
direct_bits = (
f_samples[:, self._program.direct_f_indices].astype(jnp.bool_)
^ self._program.direct_flips
)
n_direct = len(self._program.direct_f_indices)
targets = state[self._program.output_order[:n_direct]]
p_joint = p_joint * (direct_bits == targets).all(axis=1)

for component in self._program.components:
assert len(component.compiled_scalar_graphs) == 2

Expand Down
3 changes: 3 additions & 0 deletions test/integration/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ def test_sample_program_raises_on_component_norm_deviation(monkeypatch):
)
program = CompiledProgram(
components=components,
direct_f_indices=jnp.array([], dtype=jnp.int32),
direct_flips=jnp.array([], dtype=jnp.bool_),
output_order=jnp.array([0, 1]),
output_reindex=None,
num_outputs=2,
num_detectors=0,
)
Expand Down
9 changes: 9 additions & 0 deletions test/unit/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,15 @@ def test_detector_sampler_no_detectors_bit_packed():
assert result.shape == (5, 0)


def test_detector_sampler_no_detectors_with_reference_sample():
    """A detector-free circuit sampled with a reference sample yields (shots, 0)."""
    shots = 5
    circuit = Circuit("H 0\nM 0")
    detector_sampler = circuit.compile_detector_sampler()

    samples = detector_sampler.sample(shots, use_detector_reference_sample=True)

    # The circuit declares no detectors, so there are zero output columns.
    assert samples.dtype == np.bool_
    assert samples.shape == (shots, 0)


def test_sampler_negative_shots_raises():
sampler = Circuit("H 0\nM 0").compile_sampler(seed=0)
with pytest.raises(ValueError, match="shots must be non-negative"):
Expand Down
Loading