Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cirq-core/cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@

from cirq.transformers.eject_z import eject_z as eject_z

from cirq.transformers.diagonal_optimization import (
drop_diagonal_before_measurement as drop_diagonal_before_measurement,
)

from cirq.transformers.measurement_transformers import (
defer_measurements as defer_measurements,
dephase_measurements as dephase_measurements,
Expand Down
148 changes: 148 additions & 0 deletions cirq-core/cirq/transformers/diagonal_optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Copyright 2024 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer pass that removes diagonal gates before measurements."""

from __future__ import annotations

import numpy as np

import cirq
from cirq import ops, protocols
from cirq.transformers import transformer_api
from cirq.transformers.eject_z import eject_z


def _is_diagonal(op: cirq.Operation) -> bool:
"""Checks if an operation is diagonal in the computational basis.

Args:
op: The operation to check.

Returns:
True if the operation is diagonal in the computational basis.
"""
# Fast Path: Check for common diagonal gate types directly
if isinstance(op.gate, (ops.ZPowGate, ops.CZPowGate, ops.IdentityGate)):
return True

# Slow Path: Check the unitary matrix
if protocols.has_unitary(op):
try:
u = protocols.unitary(op)
# Check if off-diagonal elements are close to zero
return np.allclose(u, np.diag(np.diag(u)))
except Exception:
# If matrix calculation fails (e.g. huge gates), assume not diagonal
return False

return False


@transformer_api.transformer
def drop_diagonal_before_measurement(
circuit: cirq.AbstractCircuit, *, context: cirq.TransformerContext | None = None
) -> cirq.Circuit:
"""Removes diagonal gates that appear immediately before measurements.

This transformer optimizes circuits by removing diagonal gates (gates that are
diagonal in the computational basis, such as Z, S, T, CZ, etc.) that appear
immediately before measurement operations. Since measurements project onto the
computational basis, any diagonal gate applied immediately before a measurement
does not affect the measurement outcome and can be safely removed.

To maximize the effectiveness of this optimization, the transformer first applies
the `eject_z` transformation, which pushes Z gates (and other diagonal phases)
later in the circuit. This handles cases where diagonal gates can commute past
other operations. For example:

Z(q0) - CZ(q0, q1) - measure(q1)

After `eject_z`, the Z gate on the control qubit commutes through the CZ:

CZ(q0, q1) - Z(q1) - measure(q1)

Then both the CZ and Z(q1) can be removed since they're before the measurement:

measure(q1)

Args:
circuit: Input circuit to transform.
context: `cirq.TransformerContext` storing common configurable options for transformers.

Returns:
Copy of the transformed input circuit with diagonal gates before measurements removed.

Examples:
>>> import cirq
>>> q0, q1 = cirq.LineQubit.range(2)
>>>
>>> # Simple case: Z before measurement
>>> circuit = cirq.Circuit(cirq.H(q0), cirq.Z(q0), cirq.measure(q0))
>>> optimized = cirq.drop_diagonal_before_measurement(circuit)
>>> print(optimized)
0: ───H───M───

>>> # Complex case: Z-CZ commutation
>>> circuit = cirq.Circuit(
... cirq.Z(q0),
... cirq.CZ(q0, q1),
... cirq.measure(q1)
... )
>>> optimized = cirq.drop_diagonal_before_measurement(circuit)
>>> print(optimized)
1: ───M───
"""
if context is None:
context = transformer_api.TransformerContext()

# Phase 1: Apply eject_z to push Z gates later in the circuit.
# This handles commutation of Z gates through other operations,
# particularly important for the Z-CZ case mentioned in the feature request.
circuit = eject_z(circuit, context=context)

# Phase 2: Remove diagonal gates that appear before measurements.
# We iterate in reverse to identify which qubits will be measured.
# Track qubits that will be measured (set grows as we go backwards)
measured_qubits: set[ops.Qid] = set()

# Build new moments in reverse
new_moments = []
for moment in reversed(circuit):
new_ops = []

for op in moment:
# If this is a measurement, mark these qubits as measured
if protocols.is_measurement(op):
measured_qubits.update(op.qubits)
new_ops.append(op)
# If this is a diagonal gate and ANY of its qubits will be measured, remove it
# (diagonal gates only affect phase, which doesn't impact computational basis
# measurements)
elif _is_diagonal(op) and all(q in measured_qubits for q in op.qubits):
# Skip this operation (it's diagonal and at least one qubit is measured)
pass
else:
# Keep the operation
new_ops.append(op)
# If it's not diagonal, these qubits are no longer "safe to optimize"
if not _is_diagonal(op):
measured_qubits.difference_update(op.qubits)

# Add the moment if it has any operations
if new_ops:
new_moments.append(cirq.Moment(new_ops))

# Reverse back to original order
return cirq.Circuit(reversed(new_moments))
202 changes: 202 additions & 0 deletions cirq-core/cirq/transformers/diagonal_optimization_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# Copyright 2024 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

import cirq
from cirq.transformers.diagonal_optimization import drop_diagonal_before_measurement


def test_removes_z_before_measure():
q = cirq.NamedQubit('q')

# Original: H -> Z -> Measure
circuit = cirq.Circuit(cirq.H(q), cirq.Z(q), cirq.measure(q, key='m'))

optimized = drop_diagonal_before_measurement(circuit)

# Expected: H -> Measure (Z is gone)
expected = cirq.Circuit(cirq.H(q), cirq.measure(q, key='m'))

assert optimized == expected


def test_removes_diagonal_chain():
q = cirq.NamedQubit('q')

# Original: H -> Z -> S -> Measure
circuit = cirq.Circuit(cirq.H(q), cirq.Z(q), cirq.S(q), cirq.measure(q, key='m'))

optimized = drop_diagonal_before_measurement(circuit)

# Expected: H -> Measure (Both Z and S are gone)
expected = cirq.Circuit(cirq.H(q), cirq.measure(q, key='m'))

assert optimized == expected


def test_keeps_z_blocked_by_x():
q = cirq.NamedQubit('q')

# Original: Z -> X -> Measure
circuit = cirq.Circuit(cirq.Z(q), cirq.X(q), cirq.measure(q, key='m'))

# Z cannot commute past X, so it should be kept
# Note: eject_z will phase the X, so the circuit changes but Z is preserved
optimized = drop_diagonal_before_measurement(circuit)

# We use this helper to check mathematical equivalence
# instead of checking exact gate types (Y vs PhasedX)
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
circuit, optimized, atol=1e-6
)


def test_keeps_cz_if_only_one_qubit_measured():
q0, q1 = cirq.LineQubit.range(2)

# Original: CZ(0,1) -> Measure(0)
circuit = cirq.Circuit(cirq.CZ(q0, q1), cirq.measure(q0, key='m'))

# CZ shouldn't be removed because q1 is not measured
optimized = drop_diagonal_before_measurement(circuit)

assert optimized == circuit


def test_removes_cz_if_both_measured():
q0, q1 = cirq.LineQubit.range(2)

# Original: CZ(0,1) -> Measure(0), Measure(1)
circuit = cirq.Circuit(cirq.CZ(q0, q1), cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1'))

optimized = drop_diagonal_before_measurement(circuit)

# Expected: Measures only
expected = cirq.Circuit(cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1'))

# Check that operations match (ignoring Moment structure)
assert list(optimized.all_operations()) == list(expected.all_operations())


def test_feature_request_z_cz_commutation():
"""Test the original feature request case: Z-CZ commutation before measurement.

The circuit Z(q0) - CZ(q0, q1) - measure(q1) should be optimized to just measure(q1).
This is because:
1. Z on the control qubit of CZ commutes through the CZ
2. After commutation, both gates are diagonal and before measurement
3. Both can be removed
"""
q0, q1 = cirq.LineQubit.range(2)

# Original feature request circuit
circuit = cirq.Circuit(cirq.Z(q0), cirq.CZ(q0, q1), cirq.measure(q1, key='m1'))

optimized = drop_diagonal_before_measurement(circuit)

# The Z(0) might be moved or merged by eject_z, but the CZ MUST stay.
# We check that a two-qubit gate still exists.
assert len(list(optimized.findall_operations(lambda op: len(op.qubits) == 2))) > 0


def test_feature_request_full_example():
"""Test the full feature request example with measurements on both qubits."""
q0, q1 = cirq.LineQubit.range(2)

# From feature request
circuit = cirq.Circuit(
cirq.Z(q0),
cirq.CZ(q0, q1),
cirq.Z(q1),
cirq.measure(q0, key='m0'),
cirq.measure(q1, key='m1'),
)

optimized = drop_diagonal_before_measurement(circuit)

# Should simplify to just measurements
expected = cirq.Circuit(cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1'))

assert list(optimized.all_operations()) == list(expected.all_operations())


def test_preserves_non_diagonal_gates():
"""Test that non-diagonal gates are preserved."""
q = cirq.NamedQubit('q')

circuit = cirq.Circuit(cirq.H(q), cirq.X(q), cirq.Z(q), cirq.measure(q, key='m'))

optimized = drop_diagonal_before_measurement(circuit)

# Verify the physics hasn't changed (handles PhasedX vs Y differences)
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
circuit, optimized, atol=1e-6
)


def test_is_diagonal_helper_edge_cases():
"""Test edge cases in _is_diagonal helper function for full coverage."""
from cirq.transformers.diagonal_optimization import _is_diagonal

q = cirq.NamedQubit('q')

# Test diagonal gates (fast path)
assert _is_diagonal(cirq.Z(q))
assert _is_diagonal(cirq.S(q)) # S is Z**0.5
assert _is_diagonal(cirq.T(q)) # T is Z**0.25

# Test non-diagonal gates
assert not _is_diagonal(cirq.H(q))
assert not _is_diagonal(cirq.X(q))

# Test two-qubit diagonal gate
q0, q1 = cirq.LineQubit.range(2)
assert _is_diagonal(cirq.CZ(q0, q1))

# Test operation without a gate (circuit operation or custom operation)
# This covers the "return False" when protocols.has_unitary is False
class NoUnitaryOp(cirq.Operation):
"""Custom operation without a unitary."""

def __init__(self, qubits):
self._qubits = tuple(qubits)

@property
def qubits(self):
return self._qubits

def with_qubits(self, *new_qubits):
return NoUnitaryOp(new_qubits)

no_unitary_op = NoUnitaryOp([q])
assert not _is_diagonal(no_unitary_op)

# Test operation that raises exception when computing unitary
# This covers the exception handling in _is_diagonal
class ExceptionGate(cirq.Gate):
"""Custom gate that raises exception during unitary computation."""

def _num_qubits_(self):
return 1

def _has_unitary_(self):
return True

def _unitary_(self):
raise ValueError("Simulated unitary computation error")

exception_op = ExceptionGate().on(q)
assert not _is_diagonal(exception_op)
Loading