Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions examples/logical_error_rates/5_state_preparation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -239,15 +239,15 @@
}
],
"source": [
"logical_error_rates, discard_rates = circuits.get_logical_error_and_discard_rates(\n",
" code,\n",
" state_prep_circuit,\n",
" error_rates,\n",
" noise_model_family,\n",
" sinter_decoder=sinter_decoder,\n",
" num_samples=10**6,\n",
" post_select_on_flags=True,\n",
")\n",
"logical_error_rates = np.empty(len(tasks))\n",
"discard_rates = np.empty(len(tasks))\n",
"for tt, task in enumerate(tasks):\n",
" logical_error_rates[tt], discard_rates[tt] = circuits.get_logical_error_and_discard_rate(\n",
" task.circuit,\n",
" sinter_decoder=sinter_decoder,\n",
" num_samples=10**6,\n",
" flags=task.json_metadata[\"flags\"],\n",
" )\n",
"\n",
"# plot simulation results!\n",
"plot_error_and_discard_rates(error_rates, logical_error_rates, discard_rates)\n",
Expand Down
8 changes: 4 additions & 4 deletions src/qldpc/circuits/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from .alpha_syndrome import AlphaSyndrome
from .benchmarking import (
get_logical_error_and_discard_rates,
get_logical_error_and_discard_rate,
get_nontrivial_logical_stabilizers,
get_state_prep_diagnostic_circuit,
get_state_prep_diagnostic_tasks,
)
from .bookkeeping import (
DetectorRecord,
MeasurementRecord,
MemoryExperimentParts,
QubitIDs,
Record,
)
Expand All @@ -22,6 +21,7 @@
with_remapped_qubits,
)
from .memory import (
MemoryExperimentParts,
get_logical_bell_prep,
get_memory_experiment,
get_memory_experiment_parts,
Expand Down Expand Up @@ -49,13 +49,12 @@

__all__ = [
"AlphaSyndrome",
"get_logical_error_and_discard_rates",
"get_logical_error_and_discard_rate",
"get_nontrivial_logical_stabilizers",
"get_state_prep_diagnostic_circuit",
"get_state_prep_diagnostic_tasks",
"DetectorRecord",
"MeasurementRecord",
"MemoryExperimentParts",
"QubitIDs",
"Record",
"get_encoder_and_decoder",
Expand All @@ -65,6 +64,7 @@
"get_pauli_product_measurements",
"restrict_to_qubits",
"with_remapped_qubits",
"MemoryExperimentParts",
"get_logical_bell_prep",
"get_memory_experiment",
"get_memory_experiment_parts",
Expand Down
160 changes: 69 additions & 91 deletions src/qldpc/circuits/benchmarking.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,114 +284,92 @@ def get_state_prep_diagnostic_tasks(
sinter.Task(
circuit=noise_model_family(error_rate).noisy_circuit(diagnostic_circuit),
postselection_mask=postselection_mask_bit_packed,
json_metadata={"p": error_rate},
json_metadata={"p": error_rate, "flags": detector_record.get_events("flag")},
)
for error_rate in error_rates
]


def get_logical_error_and_discard_rates(
code: codes.QuditCode,
state_prep_circuit: stim.Circuit,
error_rates: Sequence[float] | npt.NDArray[np.floating],
noise_model_family: Callable[[float], NoiseModel] = DepolarizingNoiseModel,
def get_logical_error_and_discard_rate(
circuit_or_dem: stim.Circuit | stim.DetectorErrorModel,
sinter_decoder: sinter.Decoder,
*,
sinter_decoder: sinter.Decoder | Sequence[sinter.Decoder],
num_samples: int | Sequence[int],
observables: npt.NDArray[np.int_]
| Sequence[Sequence[int]]
| Sequence[stim.PauliString]
| None = None,
post_select_on_flags: bool = False,
skip_validation: bool = False,
) -> tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]]:
"""Compute logical error rates of the provided logical state prep circuit for the provided code.

The first len(code) qubits addressed by the circuit must be the data qubits of the code.
num_samples: int,
flags: Sequence[int] | None = None,
) -> tuple[float, float]:
"""Compute a logical error rate and discard rate from samples of the provided cirucit.

Each logical error rate is a fraction of the (possibly post-selected) shots in which observable
flips are predicted incorrectly by the provided decoder.

This method is provided as an alternative to get_state_prep_diagnostic_tasks, which currently
cannot support post-selection due to a sinter bug: https://github.com/quantumlib/Stim/pull/844
Once this bug is fixed, it is recommended to instead use get_state_prep_diagnostic_tasks.
This method is provided as an alternative to sinter, which currently cannot support post
selection due to an outstanding bug: https://github.com/quantumlib/Stim/pull/844
Once the bug is fixed, it is recommended to instead build a sinter.Task and call sinter.collect.

The sinter.Task would use the post-selection flags as follows:
postselection_mask_bits = np.zeros(circuit_or_dem.num_detectors, dtype=int)
postselection_mask_bits[flags] = 1
postselection_mask = np.packbits(postselection_mask, bitorder="little")
task = sinter.Task(
circuit=circuit,
postselection_mask=postselection_mask_bit_packed,
)
Sampling data would then be collected with:
stats = sinter.collect(
tasks=[task], # or more maybe more tasks
decoders=["custom"],
custom_decoders={"custom": sinter_decoder},
num_shots=num_samples,
# other options such as num_workers=os.cpu_count() or max_errors=100,
)

Args:
code: The code whose logical state is prepared by the provided state_prep_circuit.
state_prep_circuit: A circuit that prepares a logical state of the provided code.
error_rates: The error rates at which to evaluate the provided family of noise models.
noise_model_family: A single-parameter family of noise models for adding noise to circuits.
Default: qldpc.circuits.DepolarizingNoiseModel.
circuit_or_dem: The circuit or detector error model we wish to sample.
sinter_decoder: The circuit-level decoder used to predict observable flips.

Keyword args:
sinter_decoder: The circuit-level decoder used to predict observable flips, or a sequence of
circuit-level decoders (one for each error rate).
num_samples: The number of times to sample each noisy circuit, or a sequence of sample
numbers (one for each error rate).
observables: The observables that should stabilize the prepared state, or (by default) None.
If not None, the observables should be either a a matrix of symplectic row vectors, with
shape (num_observables, 2 * len(code)), or a sequence of Pauli strings supported on the
data qubits of the code. If None, observables are determined automatically by finding
all logical Pauli operators of the code that stabilize the state prepared by
state_prep_circuit.
post_select_on_flags: If True, post-select samples on nonzero measurement outcomes in the
provided state_prep_circuit. Default: False.
skip_validation: If True, skip the check to assert that the provided circuit prepares a
logical state fo the provided code.
num_samples: The number of times to the circuit_or_dem.
flags: The detectors in circuit_or_dem to post-select on.

Returns:
An array of estimated logical error rates.
An array of discard rates, or the fraction of shots (for each simulated error rate) that
were discarded due to post-selection on state prep flags. If post_select_on_flags is
False, this array contains only zeros.
A fraction of samples in which at least one observable was decoded incorrectly.
A fraction of samples that were discarded due to post-selection.
"""
diagnostic_circuit, detector_record = get_state_prep_diagnostic_circuit(
code, state_prep_circuit, observables=observables
)
if not isinstance(num_samples, Sequence):
num_samples = [num_samples] * len(error_rates)
if not isinstance(sinter_decoder, Sequence):
sinter_decoder = [sinter_decoder] * len(error_rates)

logical_error_rates = np.zeros(len(error_rates), dtype=float)
discard_rates = np.zeros(len(error_rates), dtype=float)
for pp, error_rate in enumerate(error_rates):
# sample detector and observable flips in the circuit
noise_model = noise_model_family(error_rate)
noisy_circuit = noise_model.noisy_circuit(diagnostic_circuit)
dem_arrays = decoders.DetectorErrorModelArrays(
noisy_circuit.detector_error_model(), simplify=True
)
dem = dem_arrays.to_dem()
sampler = dem.compile_sampler()
det_data, obs_data, err_data = sampler.sample(shots=num_samples[pp])

# if applicable, post-select on flag detectors
if post_select_on_flags:
# identify shots and detectors to remove
flag_dets = detector_record.get_events("flag")
shot_mask = ~np.any(det_data[:, flag_dets], axis=1)
detector_mask = np.ones(dem.num_detectors, dtype=bool)
detector_mask[flag_dets] = False

# post-select simulated data
det_data = det_data[shot_mask][:, detector_mask]
obs_data = obs_data[shot_mask]
dem = dem_arrays.post_selected_on(detector_record.get_events("flag")).to_dem()

# record the fraction of shots that were discarded
discard_rates[pp] = 1 - np.sum(shot_mask) / len(shot_mask)

# compile a decoder for this detector error model
compiled_sinter_decoder = sinter_decoder[pp].compile_decoder_for_dem(dem)

# decode and compute the logical error rate
predicted_flips = compiled_sinter_decoder.decode_shots(det_data)
obs_flips = obs_data ^ predicted_flips
failures = np.any(obs_flips, axis=1)
logical_error_rates[pp] = np.sum(failures) / len(failures)

return logical_error_rates, discard_rates
# build and simplify a detector error model
dem_arrays = decoders.DetectorErrorModelArrays(circuit_or_dem, simplify=True)
dem = dem_arrays.to_dem()

# sample detector and observable flips in the circuit
sampler = dem.compile_sampler()
det_data, obs_data, err_data = sampler.sample(shots=num_samples)

# if applicable, post-select on flag detectors
if flags:
# identify shots and detectors to remove
shot_mask = ~np.any(det_data[:, flags], axis=1)
detector_mask = np.ones(dem.num_detectors, dtype=bool)
detector_mask[flags] = False

# post-select simulated data
det_data = det_data[shot_mask][:, detector_mask]
obs_data = obs_data[shot_mask]
dem = dem_arrays.post_selected_on(flags).to_dem()

# record the fraction of shots that were discarded
discard_rate = 1 - np.sum(shot_mask) / len(shot_mask)
else: # pragma: no cover
discard_rate = 0

# compile a decoder for this detector error model
compiled_sinter_decoder = sinter_decoder.compile_decoder_for_dem(dem)

# decode and compute the logical error rate
predicted_flips = compiled_sinter_decoder.decode_shots(det_data)
obs_flips = obs_data ^ predicted_flips
failures = np.any(obs_flips, axis=1)
logical_error_rate = np.sum(failures) / len(failures)

return logical_error_rate, discard_rate


def _assert_logical_state_preparation(
Expand Down
22 changes: 15 additions & 7 deletions src/qldpc/circuits/benchmarking_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,23 @@ def test_state_prep() -> None:
for error_rate, task in zip(error_rates, tasks):
assert task.json_metadata["p"] == error_rate

# cover alternative method for computing logical error rates
logical_error_rates, discard_rates = circuits.get_logical_error_and_discard_rates(
# find observables automatically
task = circuits.get_state_prep_diagnostic_tasks(
code,
state_prep_circuit,
error_rates=[0],
error_rates[:1],
noise_model_family,
observables=None,
post_select_on_flags=False,
)[0]
assert task == tasks[0]

# bypass sinter to compute logical error rates
logical_error_rate, discard_rate = circuits.get_logical_error_and_discard_rate(
task.circuit,
sinter_decoder=decoders.SinterDecoder(),
num_samples=1,
observables=None, # construct automatically
post_select_on_flags=True,
flags=task.json_metadata["flags"],
)
assert np.array_equal(logical_error_rates, [0])
assert np.array_equal(discard_rates, [0])
assert logical_error_rate == 0
assert discard_rate == 0
10 changes: 0 additions & 10 deletions src/qldpc/circuits/bookkeeping.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import dataclasses
import itertools
from collections.abc import Hashable, ItemsView, Iterable, Iterator, Mapping, Sequence
from typing import NamedTuple

import numpy as np
import stim
Expand Down Expand Up @@ -274,12 +273,3 @@ def after_post_selection(self, key: Hashable) -> DetectorRecord:
if other_key != key
}
)


class MemoryExperimentParts(NamedTuple):
initialization: stim.Circuit
qec_cycle: stim.Circuit
readout: stim.Circuit
measurement_record: MeasurementRecord
detector_record: DetectorRecord
qubit_ids: QubitIDs
12 changes: 11 additions & 1 deletion src/qldpc/circuits/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
"""

from collections.abc import Collection, Sequence
from typing import NamedTuple

import numpy as np
import stim

from qldpc import codes
from qldpc.objects import Node, Pauli, PauliXZ

from .bookkeeping import DetectorRecord, MeasurementRecord, MemoryExperimentParts, QubitIDs
from .bookkeeping import DetectorRecord, MeasurementRecord, QubitIDs
from .common import (
get_encoding_circuit,
get_pauli_product_measurements,
Expand All @@ -34,6 +35,15 @@
from .syndrome_measurement import EdgeColoring, SyndromeMeasurementStrategy


class MemoryExperimentParts(NamedTuple):
initialization: stim.Circuit
qec_cycle: stim.Circuit
readout: stim.Circuit
measurement_record: MeasurementRecord
detector_record: DetectorRecord
qubit_ids: QubitIDs


def get_memory_experiment(
code: codes.QuditCode | codes.ClassicalCode,
basis: PauliXZ | None = Pauli.X,
Expand Down
9 changes: 8 additions & 1 deletion src/qldpc/decoders/dems.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,15 @@ class DetectorErrorModelArrays:
observable_flip_matrix: scipy.sparse.csc_matrix # maps errors to observable flips
error_probs: npt.NDArray[np.floating] # probability of occurrence for each error

def __init__(self, dem: stim.DetectorErrorModel, *, simplify: bool = True) -> None:
def __init__(
self, circuit_or_dem: stim.Circuit | stim.DetectorErrorModel, *, simplify: bool = True
) -> None:
"""Initialize from a stim.DetectorErrorModel."""
dem = (
circuit_or_dem.detector_error_model()
if isinstance(circuit_or_dem, stim.Circuit)
else circuit_or_dem
)
errors = DetectorErrorModelArrays.get_circuit_errors(dem)
if simplify:
errors = DetectorErrorModelArrays.get_merged_circuit_errors(errors)
Expand Down