From d5442fb9d5a5ced989bdffa2f026312b48694d33 Mon Sep 17 00:00:00 2001 From: Vivek1106-04 Date: Wed, 3 Dec 2025 16:14:34 +0530 Subject: [PATCH 01/10] Add drop_diagonal_before_measurement transformer This transformer removes diagonal gates (Z, CZ, etc.) that appear immediately before measurements, as they do not affect the measurement outcome in the computational basis. Uses eject_z to maximize optimization opportunities. Fixes #4935 --- cirq-core/cirq/transformers/__init__.py | 4 + .../transformers/diagonal_optimization.py | 154 ++++++++++++++ .../diagonal_optimization_test.py | 188 ++++++++++++++++++ 3 files changed, 346 insertions(+) create mode 100644 cirq-core/cirq/transformers/diagonal_optimization.py create mode 100644 cirq-core/cirq/transformers/diagonal_optimization_test.py diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index a6e37eb0882..ee38dbf8b62 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -88,6 +88,10 @@ from cirq.transformers.eject_z import eject_z as eject_z +from cirq.transformers.diagonal_optimization import ( + drop_diagonal_before_measurement as drop_diagonal_before_measurement, +) + from cirq.transformers.measurement_transformers import ( defer_measurements as defer_measurements, dephase_measurements as dephase_measurements, diff --git a/cirq-core/cirq/transformers/diagonal_optimization.py b/cirq-core/cirq/transformers/diagonal_optimization.py new file mode 100644 index 00000000000..169ec647a05 --- /dev/null +++ b/cirq-core/cirq/transformers/diagonal_optimization.py @@ -0,0 +1,154 @@ +# Copyright 2024 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transformer pass that removes diagonal gates before measurements.""" + +from __future__ import annotations + +from collections import defaultdict +from typing import TYPE_CHECKING + +import numpy as np + +import cirq +from cirq import ops, protocols +from cirq.transformers import eject_z, transformer_api, transformer_primitives + +if TYPE_CHECKING: + pass + + +def _is_diagonal(op: cirq.Operation) -> bool: + """Checks if an operation is diagonal in the computational basis. + + Args: + op: The operation to check. + + Returns: + True if the operation is diagonal in the computational basis. + """ + # Fast Path: Check for common diagonal gate types directly + if isinstance(op.gate, (ops.ZPowGate, ops.CZPowGate, ops.IdentityGate)): + return True + + # Slow Path: Check the unitary matrix + if protocols.has_unitary(op): + try: + u = protocols.unitary(op) + # Check if off-diagonal elements are close to zero + return np.allclose(u, np.diag(np.diag(u))) + except Exception: + # If matrix calculation fails (e.g. huge gates), assume not diagonal + return False + + return False + + +@transformer_api.transformer +def drop_diagonal_before_measurement( + circuit: cirq.AbstractCircuit, + *, + context: cirq.TransformerContext | None = None, +) -> cirq.Circuit: + """Removes diagonal gates that appear immediately before measurements. + + This transformer optimizes circuits by removing diagonal gates (gates that are + diagonal in the computational basis, such as Z, S, T, CZ, etc.) that appear + immediately before measurement operations. Since measurements project onto the + computational basis, any diagonal gate applied immediately before a measurement + does not affect the measurement outcome and can be safely removed. + + To maximize the effectiveness of this optimization, the transformer first applies + the `eject_z` transformation, which pushes Z gates (and other diagonal phases) + later in the circuit. This handles cases where diagonal gates can commute past + other operations. For example: + + Z(q0) - CZ(q0, q1) - measure(q1) + + After `eject_z`, the Z gate on the control qubit commutes through the CZ: + + CZ(q0, q1) - Z(q1) - measure(q1) + + Then both the CZ and Z(q1) can be removed since they're before the measurement: + + measure(q1) + + Args: + circuit: Input circuit to transform. + context: `cirq.TransformerContext` storing common configurable options for transformers. + + Returns: + Copy of the transformed input circuit with diagonal gates before measurements removed. + + Examples: + >>> import cirq + >>> q0, q1 = cirq.LineQubit.range(2) + >>> + >>> # Simple case: Z before measurement + >>> circuit = cirq.Circuit(cirq.H(q0), cirq.Z(q0), cirq.measure(q0)) + >>> optimized = cirq.drop_diagonal_before_measurement(circuit) + >>> print(optimized) + 0: ───H───M─── + + >>> # Complex case: Z-CZ commutation + >>> circuit = cirq.Circuit( + ... cirq.Z(q0), + ... cirq.CZ(q0, q1), + ... cirq.measure(q1) + ... ) + >>> optimized = cirq.drop_diagonal_before_measurement(circuit) + >>> print(optimized) + 1: ───M─── + """ + if context is None: + context = transformer_api.TransformerContext() + + # Phase 1: Apply eject_z to push Z gates later in the circuit. + # This handles commutation of Z gates through other operations, + # particularly important for the Z-CZ case mentioned in the feature request. + circuit = eject_z(circuit, context=context) + + # Phase 2: Remove diagonal gates that appear before measurements. + # We iterate in reverse to identify which qubits will be measured. + # Track qubits that will be measured (set grows as we go backwards) + measured_qubits: set[ops.Qid] = set() + + # Build new moments in reverse + new_moments = [] + for moment in reversed(circuit): + new_ops = [] + + for op in moment: + # If this is a measurement, mark these qubits as measured + if protocols.is_measurement(op): + measured_qubits.update(op.qubits) + new_ops.append(op) + # If this is a diagonal gate and ANY of its qubits will be measured, remove it + # (diagonal gates only affect phase, which doesn't impact computational basis measurements) + elif _is_diagonal(op) and all(q in measured_qubits for q in op.qubits): + # Skip this operation (it's diagonal and at least one qubit is measured) + pass + else: + # Keep the operation + new_ops.append(op) + # If it's not diagonal, these qubits are no longer "safe to optimize" + if not _is_diagonal(op): + measured_qubits.difference_update(op.qubits) + + # Add the moment if it has any operations + if new_ops: + new_moments.append(cirq.Moment(new_ops)) + + # Reverse back to original order + return cirq.Circuit(reversed(new_moments)) \ No newline at end of file diff --git a/cirq-core/cirq/transformers/diagonal_optimization_test.py b/cirq-core/cirq/transformers/diagonal_optimization_test.py new file mode 100644 index 00000000000..d56d3f364a0 --- /dev/null +++ b/cirq-core/cirq/transformers/diagonal_optimization_test.py @@ -0,0 +1,188 @@ +# Copyright 2024 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cirq +import numpy as np +import pytest +from cirq.transformers.diagonal_optimization import drop_diagonal_before_measurement + + +def test_removes_z_before_measure(): + q = cirq.NamedQubit('q') + + # Original: H -> Z -> Measure + circuit = cirq.Circuit( + cirq.H(q), + cirq.Z(q), + cirq.measure(q, key='m') + ) + + optimized = drop_diagonal_before_measurement(circuit) + + # Expected: H -> Measure (Z is gone) + expected = cirq.Circuit( + cirq.H(q), + cirq.measure(q, key='m') + ) + + assert optimized == expected + + +def test_removes_diagonal_chain(): + q = cirq.NamedQubit('q') + + # Original: H -> Z -> S -> Measure + circuit = cirq.Circuit( + cirq.H(q), + cirq.Z(q), + cirq.S(q), + cirq.measure(q, key='m') + ) + + optimized = drop_diagonal_before_measurement(circuit) + + # Expected: H -> Measure (Both Z and S are gone) + expected = cirq.Circuit( + cirq.H(q), + cirq.measure(q, key='m') + ) + + assert optimized == expected + + +def test_keeps_z_blocked_by_x(): + q = cirq.NamedQubit('q') + + # Original: Z -> X -> Measure + circuit = cirq.Circuit( + cirq.Z(q), + cirq.X(q), + cirq.measure(q, key='m') + ) + + # Z cannot commute past X, so it should be kept + # Note: eject_z will phase the X, so the circuit changes but Z is preserved + optimized = drop_diagonal_before_measurement(circuit) + + # We use this helper to check mathematical equivalence + # instead of checking exact gate types (Y vs PhasedX) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( + circuit, optimized, atol=1e-6 + ) + + +def test_keeps_cz_if_only_one_qubit_measured(): + q0, q1 = cirq.LineQubit.range(2) + + # Original: CZ(0,1) -> Measure(0) + circuit = cirq.Circuit( + cirq.CZ(q0, q1), + cirq.measure(q0, key='m') + ) + + # CZ shouldn't be removed because q1 is not measured + optimized = drop_diagonal_before_measurement(circuit) + + assert optimized == circuit + + +def test_removes_cz_if_both_measured(): + q0, q1 = cirq.LineQubit.range(2) + + # Original: CZ(0,1) -> Measure(0), Measure(1) + circuit = cirq.Circuit( + cirq.CZ(q0, q1), + cirq.measure(q0, key='m0'), + cirq.measure(q1, key='m1') + ) + + optimized = drop_diagonal_before_measurement(circuit) + + # Expected: Measures only + expected = cirq.Circuit( + cirq.measure(q0, key='m0'), + cirq.measure(q1, key='m1') + ) + + # Check that operations match (ignoring Moment structure) + assert list(optimized.all_operations()) == list(expected.all_operations()) + + +def test_feature_request_z_cz_commutation(): + """Test the original feature request case: Z-CZ commutation before measurement. + + The circuit Z(q0) - CZ(q0, q1) - measure(q1) should be optimized to just measure(q1). + This is because: + 1. Z on the control qubit of CZ commutes through the CZ + 2. After commutation, both gates are diagonal and before measurement + 3. Both can be removed + """ + q0, q1 = cirq.LineQubit.range(2) + + # Original feature request circuit + circuit = cirq.Circuit( + cirq.Z(q0), + cirq.CZ(q0, q1), + cirq.measure(q1, key='m1') + ) + + optimized = drop_diagonal_before_measurement(circuit) + + # The Z(0) might be moved or merged by eject_z, but the CZ MUST stay. + # We check that a two-qubit gate still exists. + assert len(list(optimized.findall_operations(lambda op: len(op.qubits) == 2))) > 0 + + +def test_feature_request_full_example(): + """Test the full feature request example with measurements on both qubits.""" + q0, q1 = cirq.LineQubit.range(2) + + # From feature request + circuit = cirq.Circuit( + cirq.Z(q0), + cirq.CZ(q0, q1), + cirq.Z(q1), + cirq.measure(q0, key='m0'), + cirq.measure(q1, key='m1') + ) + + optimized = drop_diagonal_before_measurement(circuit) + + # Should simplify to just measurements + expected = cirq.Circuit( + cirq.measure(q0, key='m0'), + cirq.measure(q1, key='m1') + ) + + assert list(optimized.all_operations()) == list(expected.all_operations()) + + +def test_preserves_non_diagonal_gates(): + """Test that non-diagonal gates are preserved.""" + q = cirq.NamedQubit('q') + + circuit = cirq.Circuit( + cirq.H(q), + cirq.X(q), + cirq.Z(q), + cirq.measure(q, key='m') + ) + + optimized = drop_diagonal_before_measurement(circuit) + + # Verify the physics hasn't changed (handles PhasedX vs Y differences) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( + circuit, optimized, atol=1e-6 + ) + From 9f96503866ec2e01a2a42bc894e9193c74fbf1f0 Mon Sep 17 00:00:00 2001 From: Vivek1106-04 Date: Thu, 4 Dec 2025 10:44:21 +0530 Subject: [PATCH 02/10] Update diagonal optimization --- .../transformers/diagonal_optimization.py | 62 +++--- .../diagonal_optimization_test.py | 176 ++++++++++-------- 2 files changed, 123 insertions(+), 115 deletions(-) diff --git a/cirq-core/cirq/transformers/diagonal_optimization.py b/cirq-core/cirq/transformers/diagonal_optimization.py index 169ec647a05..99f4ab41fef 100644 --- a/cirq-core/cirq/transformers/diagonal_optimization.py +++ b/cirq-core/cirq/transformers/diagonal_optimization.py @@ -16,32 +16,27 @@ from __future__ import annotations -from collections import defaultdict -from typing import TYPE_CHECKING - import numpy as np import cirq from cirq import ops, protocols -from cirq.transformers import eject_z, transformer_api, transformer_primitives - -if TYPE_CHECKING: - pass +from cirq.transformers import transformer_api +from cirq.transformers.eject_z import eject_z def _is_diagonal(op: cirq.Operation) -> bool: """Checks if an operation is diagonal in the computational basis. - + Args: op: The operation to check. - + Returns: True if the operation is diagonal in the computational basis. """ # Fast Path: Check for common diagonal gate types directly if isinstance(op.gate, (ops.ZPowGate, ops.CZPowGate, ops.IdentityGate)): return True - + # Slow Path: Check the unitary matrix if protocols.has_unitary(op): try: @@ -51,56 +46,54 @@ def _is_diagonal(op: cirq.Operation) -> bool: except Exception: # If matrix calculation fails (e.g. huge gates), assume not diagonal return False - + return False @transformer_api.transformer def drop_diagonal_before_measurement( - circuit: cirq.AbstractCircuit, - *, - context: cirq.TransformerContext | None = None, + circuit: cirq.AbstractCircuit, *, context: cirq.TransformerContext | None = None ) -> cirq.Circuit: """Removes diagonal gates that appear immediately before measurements. - + This transformer optimizes circuits by removing diagonal gates (gates that are diagonal in the computational basis, such as Z, S, T, CZ, etc.) that appear immediately before measurement operations. Since measurements project onto the computational basis, any diagonal gate applied immediately before a measurement does not affect the measurement outcome and can be safely removed. - + To maximize the effectiveness of this optimization, the transformer first applies - the `eject_z` transformation, which pushes Z gates (and other diagonal phases) + the `eject_z` transformation, which pushes Z gates (and other diagonal phases) later in the circuit. This handles cases where diagonal gates can commute past other operations. For example: - + Z(q0) - CZ(q0, q1) - measure(q1) - + After `eject_z`, the Z gate on the control qubit commutes through the CZ: - + CZ(q0, q1) - Z(q1) - measure(q1) - + Then both the CZ and Z(q1) can be removed since they're before the measurement: - + measure(q1) - + Args: circuit: Input circuit to transform. context: `cirq.TransformerContext` storing common configurable options for transformers. - + Returns: Copy of the transformed input circuit with diagonal gates before measurements removed. - + Examples: >>> import cirq >>> q0, q1 = cirq.LineQubit.range(2) - >>> + >>> >>> # Simple case: Z before measurement >>> circuit = cirq.Circuit(cirq.H(q0), cirq.Z(q0), cirq.measure(q0)) >>> optimized = cirq.drop_diagonal_before_measurement(circuit) >>> print(optimized) 0: ───H───M─── - + >>> # Complex case: Z-CZ commutation >>> circuit = cirq.Circuit( ... cirq.Z(q0), @@ -113,29 +106,30 @@ def drop_diagonal_before_measurement( """ if context is None: context = transformer_api.TransformerContext() - + # Phase 1: Apply eject_z to push Z gates later in the circuit. # This handles commutation of Z gates through other operations, # particularly important for the Z-CZ case mentioned in the feature request. circuit = eject_z(circuit, context=context) - + # Phase 2: Remove diagonal gates that appear before measurements. # We iterate in reverse to identify which qubits will be measured. # Track qubits that will be measured (set grows as we go backwards) measured_qubits: set[ops.Qid] = set() - + # Build new moments in reverse new_moments = [] for moment in reversed(circuit): new_ops = [] - + for op in moment: # If this is a measurement, mark these qubits as measured if protocols.is_measurement(op): measured_qubits.update(op.qubits) new_ops.append(op) # If this is a diagonal gate and ANY of its qubits will be measured, remove it - # (diagonal gates only affect phase, which doesn't impact computational basis measurements) + # (diagonal gates only affect phase, which doesn't impact computational basis + # measurements) elif _is_diagonal(op) and all(q in measured_qubits for q in op.qubits): # Skip this operation (it's diagonal and at least one qubit is measured) pass @@ -145,10 +139,10 @@ def drop_diagonal_before_measurement( # If it's not diagonal, these qubits are no longer "safe to optimize" if not _is_diagonal(op): measured_qubits.difference_update(op.qubits) - + # Add the moment if it has any operations if new_ops: new_moments.append(cirq.Moment(new_ops)) - + # Reverse back to original order return cirq.Circuit(reversed(new_moments)) \ No newline at end of file diff --git a/cirq-core/cirq/transformers/diagonal_optimization_test.py b/cirq-core/cirq/transformers/diagonal_optimization_test.py index d56d3f364a0..6fc77c3479f 100644 --- a/cirq-core/cirq/transformers/diagonal_optimization_test.py +++ b/cirq-core/cirq/transformers/diagonal_optimization_test.py @@ -12,69 +12,51 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cirq import numpy as np import pytest + +import cirq from cirq.transformers.diagonal_optimization import drop_diagonal_before_measurement def test_removes_z_before_measure(): q = cirq.NamedQubit('q') - + # Original: H -> Z -> Measure - circuit = cirq.Circuit( - cirq.H(q), - cirq.Z(q), - cirq.measure(q, key='m') - ) - + circuit = cirq.Circuit(cirq.H(q), cirq.Z(q), cirq.measure(q, key='m')) + optimized = drop_diagonal_before_measurement(circuit) - + # Expected: H -> Measure (Z is gone) - expected = cirq.Circuit( - cirq.H(q), - cirq.measure(q, key='m') - ) - + expected = cirq.Circuit(cirq.H(q), cirq.measure(q, key='m')) + assert optimized == expected def test_removes_diagonal_chain(): q = cirq.NamedQubit('q') - + # Original: H -> Z -> S -> Measure - circuit = cirq.Circuit( - cirq.H(q), - cirq.Z(q), - cirq.S(q), - cirq.measure(q, key='m') - ) - + circuit = cirq.Circuit(cirq.H(q), cirq.Z(q), cirq.S(q), cirq.measure(q, key='m')) + optimized = drop_diagonal_before_measurement(circuit) - + # Expected: H -> Measure (Both Z and S are gone) - expected = cirq.Circuit( - cirq.H(q), - cirq.measure(q, key='m') - ) - + expected = cirq.Circuit(cirq.H(q), cirq.measure(q, key='m')) + assert optimized == expected def test_keeps_z_blocked_by_x(): q = cirq.NamedQubit('q') - + # Original: Z -> X -> Measure - circuit = cirq.Circuit( - cirq.Z(q), - cirq.X(q), - cirq.measure(q, key='m') - ) - + circuit = cirq.Circuit(cirq.Z(q), cirq.X(q), cirq.measure(q, key='m')) + # Z cannot commute past X, so it should be kept # Note: eject_z will phase the X, so the circuit changes but Z is preserved optimized = drop_diagonal_before_measurement(circuit) - + # We use this helper to check mathematical equivalence # instead of checking exact gate types (Y vs PhasedX) cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( @@ -84,44 +66,34 @@ def test_keeps_z_blocked_by_x(): def test_keeps_cz_if_only_one_qubit_measured(): q0, q1 = cirq.LineQubit.range(2) - + # Original: CZ(0,1) -> Measure(0) - circuit = cirq.Circuit( - cirq.CZ(q0, q1), - cirq.measure(q0, key='m') - ) - + circuit = cirq.Circuit(cirq.CZ(q0, q1), cirq.measure(q0, key='m')) + # CZ shouldn't be removed because q1 is not measured optimized = drop_diagonal_before_measurement(circuit) - + assert optimized == circuit def test_removes_cz_if_both_measured(): q0, q1 = cirq.LineQubit.range(2) - + # Original: CZ(0,1) -> Measure(0), Measure(1) - circuit = cirq.Circuit( - cirq.CZ(q0, q1), - cirq.measure(q0, key='m0'), - cirq.measure(q1, key='m1') - ) + circuit = cirq.Circuit(cirq.CZ(q0, q1), cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1')) optimized = drop_diagonal_before_measurement(circuit) - + # Expected: Measures only - expected = cirq.Circuit( - cirq.measure(q0, key='m0'), - cirq.measure(q1, key='m1') - ) - + expected = cirq.Circuit(cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1')) + # Check that operations match (ignoring Moment structure) assert list(optimized.all_operations()) == list(expected.all_operations()) def test_feature_request_z_cz_commutation(): """Test the original feature request case: Z-CZ commutation before measurement. - + The circuit Z(q0) - CZ(q0, q1) - measure(q1) should be optimized to just measure(q1). This is because: 1. Z on the control qubit of CZ commutes through the CZ @@ -129,14 +101,10 @@ def test_feature_request_z_cz_commutation(): 3. Both can be removed """ q0, q1 = cirq.LineQubit.range(2) - + # Original feature request circuit - circuit = cirq.Circuit( - cirq.Z(q0), - cirq.CZ(q0, q1), - cirq.measure(q1, key='m1') - ) - + circuit = cirq.Circuit(cirq.Z(q0), cirq.CZ(q0, q1), cirq.measure(q1, key='m1')) + optimized = drop_diagonal_before_measurement(circuit) # The Z(0) might be moved or merged by eject_z, but the CZ MUST stay. @@ -147,42 +115,88 @@ def test_feature_request_z_cz_commutation(): def test_feature_request_full_example(): """Test the full feature request example with measurements on both qubits.""" q0, q1 = cirq.LineQubit.range(2) - + # From feature request circuit = cirq.Circuit( cirq.Z(q0), cirq.CZ(q0, q1), cirq.Z(q1), cirq.measure(q0, key='m0'), - cirq.measure(q1, key='m1') + cirq.measure(q1, key='m1'), ) - + optimized = drop_diagonal_before_measurement(circuit) - + # Should simplify to just measurements - expected = cirq.Circuit( - cirq.measure(q0, key='m0'), - cirq.measure(q1, key='m1') - ) + expected = cirq.Circuit(cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1')) - assert list(optimized.all_operations()) == list(expected.all_operations()) + assert list(optimized.all_operations()) == list(expected.all_operations()) def test_preserves_non_diagonal_gates(): """Test that non-diagonal gates are preserved.""" q = cirq.NamedQubit('q') - - circuit = cirq.Circuit( - cirq.H(q), - cirq.X(q), - cirq.Z(q), - cirq.measure(q, key='m') - ) - + + circuit = cirq.Circuit(cirq.H(q), cirq.X(q), cirq.Z(q), cirq.measure(q, key='m')) + optimized = drop_diagonal_before_measurement(circuit) # Verify the physics hasn't changed (handles PhasedX vs Y differences) cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( circuit, optimized, atol=1e-6 ) - + + +def test_is_diagonal_helper_edge_cases(): + """Test edge cases in _is_diagonal helper function for full coverage.""" + from cirq.transformers.diagonal_optimization import _is_diagonal + + q = cirq.NamedQubit('q') + + # Test diagonal gates (fast path) + assert _is_diagonal(cirq.Z(q)) + assert _is_diagonal(cirq.S(q)) # S is Z**0.5 + assert _is_diagonal(cirq.T(q)) # T is Z**0.25 + + # Test non-diagonal gates + assert not _is_diagonal(cirq.H(q)) + assert not _is_diagonal(cirq.X(q)) + + # Test two-qubit diagonal gate + q0, q1 = cirq.LineQubit.range(2) + assert _is_diagonal(cirq.CZ(q0, q1)) + + # Test operation without a gate (circuit operation or custom operation) + # This covers the "return False" when protocols.has_unitary is False + class NoUnitaryOp(cirq.Operation): + """Custom operation without a unitary.""" + + def __init__(self, qubits): + self._qubits = tuple(qubits) + + @property + def qubits(self): + return self._qubits + + def with_qubits(self, *new_qubits): + return NoUnitaryOp(new_qubits) + + no_unitary_op = NoUnitaryOp([q]) + assert not _is_diagonal(no_unitary_op) + + # Test operation that raises exception when computing unitary + # This covers the exception handling in _is_diagonal + class ExceptionGate(cirq.Gate): + """Custom gate that raises exception during unitary computation.""" + + def _num_qubits_(self): + return 1 + + def _has_unitary_(self): + return True + + def _unitary_(self): + raise ValueError("Simulated unitary computation error") + + exception_op = ExceptionGate().on(q) + assert not _is_diagonal(exception_op) From 571d11c3672145a94c6195de2a4f48bab8085136 Mon Sep 17 00:00:00 2001 From: Vivek1106-04 Date: Tue, 16 Dec 2025 11:45:11 +0530 Subject: [PATCH 03/10] optimize _is_diagonal --- .../transformers/diagonal_optimization.py | 54 +++++++------------ .../diagonal_optimization_test.py | 44 ++++----------- 2 files changed, 29 insertions(+), 69 deletions(-) diff --git a/cirq-core/cirq/transformers/diagonal_optimization.py b/cirq-core/cirq/transformers/diagonal_optimization.py index 99f4ab41fef..7a311bafaa2 100644 --- a/cirq-core/cirq/transformers/diagonal_optimization.py +++ b/cirq-core/cirq/transformers/diagonal_optimization.py @@ -16,8 +16,6 @@ from __future__ import annotations -import numpy as np - import cirq from cirq import ops, protocols from cirq.transformers import transformer_api @@ -25,29 +23,13 @@ def _is_diagonal(op: cirq.Operation) -> bool: - """Checks if an operation is diagonal in the computational basis. - - Args: - op: The operation to check. + """Checks if an operation is a known diagonal gate (Z, CZ, etc.). - Returns: - True if the operation is diagonal in the computational basis. + As suggested in review, we avoid computing the unitary matrix (which is expensive) + and instead strictly check for gates known to be diagonal in the computational basis. """ - # Fast Path: Check for common diagonal gate types directly - if isinstance(op.gate, (ops.ZPowGate, ops.CZPowGate, ops.IdentityGate)): - return True - - # Slow Path: Check the unitary matrix - if protocols.has_unitary(op): - try: - u = protocols.unitary(op) - # Check if off-diagonal elements are close to zero - return np.allclose(u, np.diag(np.diag(u))) - except Exception: - # If matrix calculation fails (e.g. huge gates), assume not diagonal - return False - - return False + # ZPowGate covers Z, S, T, Rz. CZPowGate covers CZ. + return isinstance(op.gate, (ops.ZPowGate, ops.CZPowGate, ops.IdentityGate)) @transformer_api.transformer @@ -107,9 +89,7 @@ def drop_diagonal_before_measurement( if context is None: context = transformer_api.TransformerContext() - # Phase 1: Apply eject_z to push Z gates later in the circuit. - # This handles commutation of Z gates through other operations, - # particularly important for the Z-CZ case mentioned in the feature request. + # Phase 1: Push Z gates later in the circuit to maximize removal opportunities. circuit = eject_z(circuit, context=context) # Phase 2: Remove diagonal gates that appear before measurements. @@ -130,19 +110,25 @@ def drop_diagonal_before_measurement( # If this is a diagonal gate and ANY of its qubits will be measured, remove it # (diagonal gates only affect phase, which doesn't impact computational basis # measurements) - elif _is_diagonal(op) and all(q in measured_qubits for q in op.qubits): - # Skip this operation (it's diagonal and at least one qubit is measured) - pass + elif _is_diagonal(op): + # CRITICAL: we can only remove if all qubits involved are measured. + # if even one qubit is NOT measured, the gate must stay to preserve + # the state of that unmeasured qubit (due to phase kickback/entanglement). + if all(q in measured_qubits for q in op.qubits): + continue # Drop the operation + + new_ops.append(op) + # Note: We do NOT remove qubits from measured_qubits here. + # Diagonal gates commute with other diagonal gates. else: - # Keep the operation + # Non-diagonal gate found. new_ops.append(op) - # If it's not diagonal, these qubits are no longer "safe to optimize" - if not _is_diagonal(op): - measured_qubits.difference_update(op.qubits) + # the chain is broken for these qubits. + measured_qubits.difference_update(op.qubits) # Add the moment if it has any operations if new_ops: new_moments.append(cirq.Moment(new_ops)) # Reverse back to original order - return cirq.Circuit(reversed(new_moments)) \ No newline at end of file + return cirq.Circuit(reversed(new_moments)) diff --git a/cirq-core/cirq/transformers/diagonal_optimization_test.py b/cirq-core/cirq/transformers/diagonal_optimization_test.py index 6fc77c3479f..ae51a9e8dd3 100644 --- a/cirq-core/cirq/transformers/diagonal_optimization_test.py +++ b/cirq-core/cirq/transformers/diagonal_optimization_test.py @@ -153,50 +153,24 @@ def test_is_diagonal_helper_edge_cases(): q = cirq.NamedQubit('q') - # Test diagonal gates (fast path) + # Test Z gates (including variants like S and T) assert _is_diagonal(cirq.Z(q)) assert _is_diagonal(cirq.S(q)) # S is Z**0.5 assert _is_diagonal(cirq.T(q)) # T is Z**0.25 + # Test identity gate + assert _is_diagonal(cirq.I(q)) + # Test non-diagonal gates assert not _is_diagonal(cirq.H(q)) assert not _is_diagonal(cirq.X(q)) + assert not _is_diagonal(cirq.Y(q)) - # Test two-qubit diagonal gate + # Test two-qubit CZ gate q0, q1 = cirq.LineQubit.range(2) assert _is_diagonal(cirq.CZ(q0, q1)) - # Test operation without a gate (circuit operation or custom operation) - # This covers the "return False" when protocols.has_unitary is False - class NoUnitaryOp(cirq.Operation): - """Custom operation without a unitary.""" - - def __init__(self, qubits): - self._qubits = tuple(qubits) - - @property - def qubits(self): - return self._qubits - - def with_qubits(self, *new_qubits): - return NoUnitaryOp(new_qubits) - - no_unitary_op = NoUnitaryOp([q]) - assert not _is_diagonal(no_unitary_op) - - # Test operation that raises exception when computing unitary - # This covers the exception handling in _is_diagonal - class ExceptionGate(cirq.Gate): - """Custom gate that raises exception during unitary computation.""" - - def _num_qubits_(self): - return 1 - - def _has_unitary_(self): - return True - - def _unitary_(self): - raise ValueError("Simulated unitary computation error") + # Other diagonal gates (like CCZ) are not detected by the optimized version + # This is intentional - eject_z is only effective for Z and CZ anyway + assert not _is_diagonal(cirq.CCZ(q0, q1, q)) - exception_op = ExceptionGate().on(q) - assert not _is_diagonal(exception_op) From ca3e796d6925ff95672c79dea6c014be6146bd68 Mon Sep 17 00:00:00 2001 From: Vivek1106-04 Date: Fri, 19 Dec 2025 11:21:29 +0530 Subject: [PATCH 04/10] format checks --- .../transformers/diagonal_optimization_test.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/transformers/diagonal_optimization_test.py b/cirq-core/cirq/transformers/diagonal_optimization_test.py index ae51a9e8dd3..d85c1e12ed6 100644 --- a/cirq-core/cirq/transformers/diagonal_optimization_test.py +++ b/cirq-core/cirq/transformers/diagonal_optimization_test.py @@ -11,15 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Tests for diagonal_optimization transformer.""" -import numpy as np -import pytest import cirq -from cirq.transformers.diagonal_optimization import drop_diagonal_before_measurement +from cirq.transformers.diagonal_optimization import ( + drop_diagonal_before_measurement, + _is_diagonal, +) def test_removes_z_before_measure(): + """Tests that Z gates are removed before measurement.""" q = cirq.NamedQubit('q') # Original: H -> Z -> Measure @@ -34,6 +37,7 @@ def test_removes_z_before_measure(): def test_removes_diagonal_chain(): + """Tests that a chain of diagonal gates is removed.""" q = cirq.NamedQubit('q') # Original: H -> Z -> S -> Measure @@ -48,6 +52,7 @@ def test_removes_diagonal_chain(): def test_keeps_z_blocked_by_x(): + """Tests that Z gates blocked by X gates are preserved.""" q = cirq.NamedQubit('q') # Original: Z -> X -> Measure @@ -65,6 +70,7 @@ def test_keeps_z_blocked_by_x(): def test_keeps_cz_if_only_one_qubit_measured(): + """Tests that CZ is kept if only one qubit is measured.""" q0, q1 = cirq.LineQubit.range(2) # Original: CZ(0,1) -> Measure(0) @@ -77,6 +83,7 @@ def test_keeps_cz_if_only_one_qubit_measured(): def test_removes_cz_if_both_measured(): + """Tests that CZ is removed if both qubits are measured.""" q0, q1 = cirq.LineQubit.range(2) # Original: CZ(0,1) -> Measure(0), Measure(1) @@ -149,7 +156,6 @@ def test_preserves_non_diagonal_gates(): def test_is_diagonal_helper_edge_cases(): """Test edge cases in _is_diagonal helper function for full coverage.""" - from cirq.transformers.diagonal_optimization import _is_diagonal q = cirq.NamedQubit('q') @@ -173,4 +179,3 @@ def test_is_diagonal_helper_edge_cases(): # Other diagonal gates (like CCZ) are not detected by the optimized version # This is intentional - eject_z is only effective for Z and CZ anyway assert not _is_diagonal(cirq.CCZ(q0, q1, q)) - From 8894ea1a32789b0dbb421e7567bee712657ea0c7 Mon Sep 17 00:00:00 2001 From: Vivek1106-04 Date: Sat, 20 Dec 2025 04:49:35 +0530 Subject: [PATCH 05/10] format checks --- cirq-core/cirq/transformers/diagonal_optimization_test.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/cirq-core/cirq/transformers/diagonal_optimization_test.py b/cirq-core/cirq/transformers/diagonal_optimization_test.py index d85c1e12ed6..b457b3966d4 100644 --- a/cirq-core/cirq/transformers/diagonal_optimization_test.py +++ b/cirq-core/cirq/transformers/diagonal_optimization_test.py @@ -15,10 +15,7 @@ import cirq -from cirq.transformers.diagonal_optimization import ( - drop_diagonal_before_measurement, - _is_diagonal, -) +from cirq.transformers.diagonal_optimization import _is_diagonal, drop_diagonal_before_measurement def test_removes_z_before_measure(): From 312599169e7dd1e099577b6d33cc31f5ae5d350c Mon Sep 17 00:00:00 2001 From: Vivek1106-04 Date: Thu, 25 Dec 2025 09:06:15 +0530 Subject: [PATCH 06/10] Add tests for commutation and unrecognized gates --- .../transformers/diagonal_optimization.py | 32 +++--- .../diagonal_optimization_test.py | 101 +++++++++++++++--- 2 files changed, 103 insertions(+), 30 deletions(-) diff --git a/cirq-core/cirq/transformers/diagonal_optimization.py b/cirq-core/cirq/transformers/diagonal_optimization.py index 7a311bafaa2..ab9531a1140 100644 --- a/cirq-core/cirq/transformers/diagonal_optimization.py +++ b/cirq-core/cirq/transformers/diagonal_optimization.py @@ -22,7 +22,7 @@ from cirq.transformers.eject_z import eject_z -def _is_diagonal(op: cirq.Operation) -> bool: +def _is_z_or_cz_pow_gate(op: cirq.Operation) -> bool: """Checks if an operation is a known diagonal gate (Z, CZ, etc.). As suggested in review, we avoid computing the unitary matrix (which is expensive) @@ -36,28 +36,29 @@ def _is_diagonal(op: cirq.Operation) -> bool: def drop_diagonal_before_measurement( circuit: cirq.AbstractCircuit, *, context: cirq.TransformerContext | None = None ) -> cirq.Circuit: - """Removes diagonal gates that appear immediately before measurements. + """Removes Z and CZ gates that appear immediately before measurements. - This transformer optimizes circuits by removing diagonal gates (gates that are - diagonal in the computational basis, such as Z, S, T, CZ, etc.) that appear - immediately before measurement operations. Since measurements project onto the - computational basis, any diagonal gate applied immediately before a measurement - does not affect the measurement outcome and can be safely removed. + This transformer optimizes circuits by removing Z-type and CZ-type diagonal gates + (specifically ZPowGate instances like Z, S, T, Rz, and CZPowGate instances like CZ) + that appear immediately before measurement operations. Since measurements project onto + the computational basis, these diagonal gates applied immediately before a measurement + do not affect the measurement outcome and can be safely removed (when all their qubits + are measured). To maximize the effectiveness of this optimization, the transformer first applies the `eject_z` transformation, which pushes Z gates (and other diagonal phases) later in the circuit. This handles cases where diagonal gates can commute past other operations. For example: - Z(q0) - CZ(q0, q1) - measure(q1) + Z(q0) - CZ(q0, q1) - measure(q0) - measure(q1) After `eject_z`, the Z gate on the control qubit commutes through the CZ: - CZ(q0, q1) - Z(q1) - measure(q1) + CZ(q0, q1) - Z(q1) - measure(q0) - measure(q1) - Then both the CZ and Z(q1) can be removed since they're before the measurement: + Then both the CZ and Z(q1) can be removed since all their qubits are measured: - measure(q1) + measure(q0) - measure(q1) Args: circuit: Input circuit to transform. @@ -76,14 +77,17 @@ def drop_diagonal_before_measurement( >>> print(optimized) 0: ───H───M─── - >>> # Complex case: Z-CZ commutation + >>> # Complex case: Z-CZ commutation with both qubits measured >>> circuit = cirq.Circuit( ... cirq.Z(q0), ... cirq.CZ(q0, q1), + ... cirq.measure(q0), ... cirq.measure(q1) ... ) >>> optimized = cirq.drop_diagonal_before_measurement(circuit) >>> print(optimized) + 0: ───M─── + 1: ───M─── """ if context is None: @@ -107,10 +111,10 @@ def drop_diagonal_before_measurement( if protocols.is_measurement(op): measured_qubits.update(op.qubits) new_ops.append(op) - # If this is a diagonal gate and ANY of its qubits will be measured, remove it + # If this is a diagonal gate and ALL of its qubits will be measured, remove it # (diagonal gates only affect phase, which doesn't impact computational basis # measurements) - elif _is_diagonal(op): + elif _is_z_or_cz_pow_gate(op): # CRITICAL: we can only remove if all qubits involved are measured. # if even one qubit is NOT measured, the gate must stay to preserve # the state of that unmeasured qubit (due to phase kickback/entanglement). diff --git a/cirq-core/cirq/transformers/diagonal_optimization_test.py b/cirq-core/cirq/transformers/diagonal_optimization_test.py index b457b3966d4..bbdde209e16 100644 --- a/cirq-core/cirq/transformers/diagonal_optimization_test.py +++ b/cirq-core/cirq/transformers/diagonal_optimization_test.py @@ -14,8 +14,13 @@ """Tests for diagonal_optimization transformer.""" +import numpy as np + import cirq -from cirq.transformers.diagonal_optimization import _is_diagonal, drop_diagonal_before_measurement +from cirq.transformers.diagonal_optimization import ( + _is_z_or_cz_pow_gate, + drop_diagonal_before_measurement, +) def test_removes_z_before_measure(): @@ -98,11 +103,12 @@ def test_removes_cz_if_both_measured(): def test_feature_request_z_cz_commutation(): """Test the original feature request case: Z-CZ commutation before measurement. - The circuit Z(q0) - CZ(q0, q1) - measure(q1) should be optimized to just measure(q1). + The circuit Z(q0) - CZ(q0, q1) - measure(q1) should keep the CZ gate. This is because: - 1. Z on the control qubit of CZ commutes through the CZ - 2. After commutation, both gates are diagonal and before measurement - 3. Both can be removed + 1. Z on the control qubit of CZ commutes through the CZ (via eject_z) + 2. After commutation: CZ(q0, q1) - Z(q1) - measure(q1) + 3. Z(q1) can be removed (only acts on measured qubit) + 4. CZ(q0, q1) must be kept (q0 is not measured) """ q0, q1 = cirq.LineQubit.range(2) @@ -151,28 +157,91 @@ def test_preserves_non_diagonal_gates(): ) -def test_is_diagonal_helper_edge_cases(): - """Test edge cases in _is_diagonal helper function for full coverage.""" +def test_diagonal_gates_commute_before_measurement(): + """Test that multiple recognized diagonal gates are all removed when all qubits are measured. + + This tests the property that recognized diagonal gates (Z, CZ) commute with each other, + so we don't remove qubits from measured_qubits when we encounter them. This allows + earlier diagonal gates in the circuit to also be removed. + """ + q0, q1 = cirq.LineQubit.range(2) + + # Circuit with multiple recognized diagonal gates before measurements + circuit = cirq.Circuit( + cirq.CZ(q0, q1), + cirq.Z(q0), + cirq.Z(q1), + cirq.measure(q0, key='m0'), + cirq.measure(q1, key='m1'), + ) + + optimized = drop_diagonal_before_measurement(circuit) + + # All recognized diagonal gates should be removed since all qubits are measured + expected = cirq.Circuit(cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1')) + + assert list(optimized.all_operations()) == list(expected.all_operations()) + + +def test_unrecognized_diagonal_breaks_chain(): + """Test that a CZ followed by an unrecognized diagonal 4x4 unitary is handled correctly. + + Even if a gate is diagonal, if it's not a ZPowGate or CZPowGate, it won't be recognized + and will break the optimization chain. The earlier CZ gate cannot be removed because + the unrecognized diagonal gate blocks it. + """ + q0, q1 = cirq.LineQubit.range(2) + + # Create a custom diagonal 4x4 unitary (not a CZPowGate) + # This is diagonal but won't be recognized by _is_z_or_cz_pow_gate + diagonal_matrix = np.diag([1, 1j, -1, -1j]) + custom_diagonal_gate = cirq.MatrixGate(diagonal_matrix) + + # Circuit: CZ -> custom diagonal -> measurements + circuit = cirq.Circuit( + cirq.CZ(q0, q1), + custom_diagonal_gate(q0, q1), + cirq.measure(q0, key='m0'), + cirq.measure(q1, key='m1'), + ) + + optimized = drop_diagonal_before_measurement(circuit) + + # The custom diagonal gate is not recognized, so it blocks the chain + # Only the custom diagonal gate can be removed... wait, no! It's not recognized + # so it won't be removed at all. And it breaks the chain for q0 and q1. + # So the CZ also cannot be removed. + + # The optimized circuit should still have both gates + ops = list(optimized.all_operations()) + gate_ops = [op for op in ops if not cirq.is_measurement(op)] + + # Both CZ and custom diagonal should still be present + assert len(gate_ops) == 2 + + +def test_is_z_or_cz_pow_gate_helper_edge_cases(): + """Test edge cases in _is_z_or_cz_pow_gate helper function for full coverage.""" q = cirq.NamedQubit('q') # Test Z gates (including variants like S and T) - assert _is_diagonal(cirq.Z(q)) - assert _is_diagonal(cirq.S(q)) # S is Z**0.5 - assert _is_diagonal(cirq.T(q)) # T is Z**0.25 + assert _is_z_or_cz_pow_gate(cirq.Z(q)) + assert _is_z_or_cz_pow_gate(cirq.S(q)) # S is Z**0.5 + assert _is_z_or_cz_pow_gate(cirq.T(q)) # T is Z**0.25 # Test identity gate - assert _is_diagonal(cirq.I(q)) + assert _is_z_or_cz_pow_gate(cirq.I(q)) # Test non-diagonal gates - assert not _is_diagonal(cirq.H(q)) - assert not _is_diagonal(cirq.X(q)) - assert not _is_diagonal(cirq.Y(q)) + assert not _is_z_or_cz_pow_gate(cirq.H(q)) + assert not _is_z_or_cz_pow_gate(cirq.X(q)) + assert not _is_z_or_cz_pow_gate(cirq.Y(q)) # Test two-qubit CZ gate q0, q1 = cirq.LineQubit.range(2) - assert _is_diagonal(cirq.CZ(q0, q1)) + assert _is_z_or_cz_pow_gate(cirq.CZ(q0, q1)) # Other diagonal gates (like CCZ) are not detected by the optimized version # This is intentional - eject_z is only effective for Z and CZ anyway - assert not _is_diagonal(cirq.CCZ(q0, q1, q)) + assert not _is_z_or_cz_pow_gate(cirq.CCZ(q0, q1, q)) From 8d125ece92d2f1a9b334fec2ab74d16dcbdf4a75 Mon Sep 17 00:00:00 2001 From: Vivek1106-04 Date: Fri, 9 Jan 2026 20:19:36 +0530 Subject: [PATCH 07/10] requested changes --- cirq-core/cirq/__init__.py | 1 + .../transformers/diagonal_optimization.py | 7 ++- .../diagonal_optimization_test.py | 51 ++++++++----------- 3 files changed, 25 insertions(+), 34 deletions(-) diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 95d711e4ef8..55f7a506ded 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -359,6 +359,7 @@ defer_measurements as defer_measurements, dephase_measurements as dephase_measurements, drop_empty_moments as drop_empty_moments, + drop_diagonal_before_measurement as drop_diagonal_before_measurement, drop_negligible_operations as drop_negligible_operations, drop_terminal_measurements as drop_terminal_measurements, eject_phased_paulis as eject_phased_paulis, diff --git a/cirq-core/cirq/transformers/diagonal_optimization.py b/cirq-core/cirq/transformers/diagonal_optimization.py index ab9531a1140..3dc2bd889f7 100644 --- a/cirq-core/cirq/transformers/diagonal_optimization.py +++ b/cirq-core/cirq/transformers/diagonal_optimization.py @@ -1,4 +1,4 @@ -# Copyright 2024 The Cirq Developers +# Copyright 2025 The Cirq Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,7 +19,6 @@ import cirq from cirq import ops, protocols from cirq.transformers import transformer_api -from cirq.transformers.eject_z import eject_z def _is_z_or_cz_pow_gate(op: cirq.Operation) -> bool: @@ -94,7 +93,7 @@ def drop_diagonal_before_measurement( context = transformer_api.TransformerContext() # Phase 1: Push Z gates later in the circuit to maximize removal opportunities. - circuit = eject_z(circuit, context=context) + circuit = cirq.eject_z(circuit, context=context) # Phase 2: Remove diagonal gates that appear before measurements. # We iterate in reverse to identify which qubits will be measured. @@ -118,7 +117,7 @@ def drop_diagonal_before_measurement( # CRITICAL: we can only remove if all qubits involved are measured. # if even one qubit is NOT measured, the gate must stay to preserve # the state of that unmeasured qubit (due to phase kickback/entanglement). - if all(q in measured_qubits for q in op.qubits): + if measured_qubits.issuperset(op.qubits): continue # Drop the operation new_ops.append(op) diff --git a/cirq-core/cirq/transformers/diagonal_optimization_test.py b/cirq-core/cirq/transformers/diagonal_optimization_test.py index bbdde209e16..2bc047b15e6 100644 --- a/cirq-core/cirq/transformers/diagonal_optimization_test.py +++ b/cirq-core/cirq/transformers/diagonal_optimization_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 The Cirq Developers +# Copyright 2025 The Cirq Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ def test_removes_z_before_measure(): # Expected: H -> Measure (Z is gone) expected = cirq.Circuit(cirq.H(q), cirq.measure(q, key='m')) - assert optimized == expected + cirq.testing.assert_same_circuits(optimized, expected) def test_removes_diagonal_chain(): @@ -50,7 +50,7 @@ def test_removes_diagonal_chain(): # Expected: H -> Measure (Both Z and S are gone) expected = cirq.Circuit(cirq.H(q), cirq.measure(q, key='m')) - assert optimized == expected + cirq.testing.assert_same_circuits(optimized, expected) def test_keeps_z_blocked_by_x(): @@ -66,9 +66,7 @@ def test_keeps_z_blocked_by_x(): # We use this helper to check mathematical equivalence # instead of checking exact gate types (Y vs PhasedX) - cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( - circuit, optimized, atol=1e-6 - ) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(circuit, optimized) def test_keeps_cz_if_only_one_qubit_measured(): @@ -81,7 +79,7 @@ def test_keeps_cz_if_only_one_qubit_measured(): # CZ shouldn't be removed because q1 is not measured optimized = drop_diagonal_before_measurement(circuit) - assert optimized == circuit + cirq.testing.assert_same_circuits(optimized, circuit) def test_removes_cz_if_both_measured(): @@ -96,19 +94,20 @@ def test_removes_cz_if_both_measured(): # Expected: Measures only expected = cirq.Circuit(cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1')) - # Check that operations match (ignoring Moment structure) - assert list(optimized.all_operations()) == list(expected.all_operations()) + cirq.testing.assert_same_circuits(optimized, expected) def test_feature_request_z_cz_commutation(): - """Test the original feature request case: Z-CZ commutation before measurement. + """Test the original feature request #4935: Z-CZ commutation before measurement. The circuit Z(q0) - CZ(q0, q1) - measure(q1) should keep the CZ gate. This is because: 1. Z on the control qubit of CZ commutes through the CZ (via eject_z) - 2. After commutation: CZ(q0, q1) - Z(q1) - measure(q1) + 2. After commutation: CZ(q0, q1) - Z(q0) - Z(q1) - measure(q1) 3. Z(q1) can be removed (only acts on measured qubit) - 4. CZ(q0, q1) must be kept (q0 is not measured) + 4. CZ(q0, q1) and Z(q0) must be kept (q0 is not measured) + + The optimized circuit is: CZ(q0, q1) - Z(q0) - M(q1) """ q0, q1 = cirq.LineQubit.range(2) @@ -117,13 +116,14 @@ def test_feature_request_z_cz_commutation(): optimized = drop_diagonal_before_measurement(circuit) - # The Z(0) might be moved or merged by eject_z, but the CZ MUST stay. - # We check that a two-qubit gate still exists. - assert len(list(optimized.findall_operations(lambda op: len(op.qubits) == 2))) > 0 + # Expected: CZ(q0, q1) - Z(q0) - M(q1) + expected = cirq.Circuit(cirq.CZ(q0, q1), cirq.Z(q0), cirq.measure(q1, key='m1')) + + cirq.testing.assert_same_circuits(optimized, expected) def test_feature_request_full_example(): - """Test the full feature request example with measurements on both qubits.""" + """Test the full feature request #4935 with measurements on both qubits.""" q0, q1 = cirq.LineQubit.range(2) # From feature request @@ -131,8 +131,7 @@ def test_feature_request_full_example(): cirq.Z(q0), cirq.CZ(q0, q1), cirq.Z(q1), - cirq.measure(q0, key='m0'), - cirq.measure(q1, key='m1'), + cirq.Moment(cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1')), ) optimized = drop_diagonal_before_measurement(circuit) @@ -140,7 +139,7 @@ def test_feature_request_full_example(): # Should simplify to just measurements expected = cirq.Circuit(cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1')) - assert list(optimized.all_operations()) == list(expected.all_operations()) + cirq.testing.assert_same_circuits(optimized, expected) def test_preserves_non_diagonal_gates(): @@ -152,9 +151,7 @@ def test_preserves_non_diagonal_gates(): optimized = drop_diagonal_before_measurement(circuit) # Verify the physics hasn't changed (handles PhasedX vs Y differences) - cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( - circuit, optimized, atol=1e-6 - ) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(circuit, optimized) def test_diagonal_gates_commute_before_measurement(): @@ -180,7 +177,7 @@ def test_diagonal_gates_commute_before_measurement(): # All recognized diagonal gates should be removed since all qubits are measured expected = cirq.Circuit(cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1')) - assert list(optimized.all_operations()) == list(expected.all_operations()) + cirq.testing.assert_same_circuits(optimized, expected) def test_unrecognized_diagonal_breaks_chain(): @@ -211,13 +208,7 @@ def test_unrecognized_diagonal_breaks_chain(): # Only the custom diagonal gate can be removed... wait, no! It's not recognized # so it won't be removed at all. And it breaks the chain for q0 and q1. # So the CZ also cannot be removed. - - # The optimized circuit should still have both gates - ops = list(optimized.all_operations()) - gate_ops = [op for op in ops if not cirq.is_measurement(op)] - - # Both CZ and custom diagonal should still be present - assert len(gate_ops) == 2 + cirq.testing.assert_same_circuits(optimized, circuit) def test_is_z_or_cz_pow_gate_helper_edge_cases(): From 06bb2ad006ecd932d86d06d29d20794c0a33b22d Mon Sep 17 00:00:00 2001 From: Pavol Juhas Date: Fri, 9 Jan 2026 12:29:37 -0800 Subject: [PATCH 08/10] Sync test_feature_request_z_cz_commutation with the example in #4935 Make test docstring correspond to the code below. --- .../cirq/transformers/diagonal_optimization_test.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/cirq-core/cirq/transformers/diagonal_optimization_test.py b/cirq-core/cirq/transformers/diagonal_optimization_test.py index 2bc047b15e6..eebeb0f6c16 100644 --- a/cirq-core/cirq/transformers/diagonal_optimization_test.py +++ b/cirq-core/cirq/transformers/diagonal_optimization_test.py @@ -100,24 +100,23 @@ def test_removes_cz_if_both_measured(): def test_feature_request_z_cz_commutation(): """Test the original feature request #4935: Z-CZ commutation before measurement. - The circuit Z(q0) - CZ(q0, q1) - measure(q1) should keep the CZ gate. + The circuit Z(q0) - CZ(q0, q1) - Z(q1) - M(q1) should keep the CZ gate. This is because: - 1. Z on the control qubit of CZ commutes through the CZ (via eject_z) - 2. After commutation: CZ(q0, q1) - Z(q0) - Z(q1) - measure(q1) - 3. Z(q1) can be removed (only acts on measured qubit) - 4. CZ(q0, q1) and Z(q0) must be kept (q0 is not measured) + 1. Z(q0) commutes through the CZ and Z(q1) is removed (via eject_z) + 2. After commutation: CZ(q0, q1) - Z(q0) - M(q1) + 3. CZ(q0, q1) and Z(q0) must be kept (q0 is not measured) The optimized circuit is: CZ(q0, q1) - Z(q0) - M(q1) """ q0, q1 = cirq.LineQubit.range(2) # Original feature request circuit - circuit = cirq.Circuit(cirq.Z(q0), cirq.CZ(q0, q1), cirq.measure(q1, key='m1')) + circuit = cirq.Circuit(cirq.Z(q0), cirq.CZ(q0, q1), cirq.Z(q1), cirq.measure(q1, key='m1')) optimized = drop_diagonal_before_measurement(circuit) # Expected: CZ(q0, q1) - Z(q0) - M(q1) - expected = cirq.Circuit(cirq.CZ(q0, q1), cirq.Z(q0), cirq.measure(q1, key='m1')) + expected = cirq.Circuit(cirq.CZ(q0, q1), cirq.Z(q0), cirq.Moment(cirq.measure(q1, key='m1'))) cirq.testing.assert_same_circuits(optimized, expected) From 03aeb0341f4d5f7991c1700c09c51586d2692942 Mon Sep 17 00:00:00 2001 From: Pavol Juhas Date: Fri, 9 Jan 2026 12:46:59 -0800 Subject: [PATCH 09/10] Avoid possible issues with circular import Replace top-level import of cirq, but keep it for for type-checking. --- .../cirq/transformers/diagonal_optimization.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/transformers/diagonal_optimization.py b/cirq-core/cirq/transformers/diagonal_optimization.py index 3dc2bd889f7..e72353ddd0e 100644 --- a/cirq-core/cirq/transformers/diagonal_optimization.py +++ b/cirq-core/cirq/transformers/diagonal_optimization.py @@ -16,10 +16,14 @@ from __future__ import annotations -import cirq -from cirq import ops, protocols +from typing import TYPE_CHECKING + +from cirq import circuits, ops, protocols, transformers from cirq.transformers import transformer_api +if TYPE_CHECKING: + import cirq + def _is_z_or_cz_pow_gate(op: cirq.Operation) -> bool: """Checks if an operation is a known diagonal gate (Z, CZ, etc.). @@ -93,7 +97,7 @@ def drop_diagonal_before_measurement( context = transformer_api.TransformerContext() # Phase 1: Push Z gates later in the circuit to maximize removal opportunities. - circuit = cirq.eject_z(circuit, context=context) + circuit = transformers.eject_z(circuit, context=context) # Phase 2: Remove diagonal gates that appear before measurements. # We iterate in reverse to identify which qubits will be measured. @@ -131,7 +135,7 @@ def drop_diagonal_before_measurement( # Add the moment if it has any operations if new_ops: - new_moments.append(cirq.Moment(new_ops)) + new_moments.append(circuits.Moment(new_ops)) # Reverse back to original order - return cirq.Circuit(reversed(new_moments)) + return circuits.Circuit(reversed(new_moments)) From df4a51fb74cc8dee349f6d921a683b22c7fa0cbe Mon Sep 17 00:00:00 2001 From: Pavol Juhas Date: Fri, 9 Jan 2026 12:50:06 -0800 Subject: [PATCH 10/10] Prefer alphabetic order of imports in `__init__` files No change in code function. --- cirq-core/cirq/__init__.py | 2 +- cirq-core/cirq/transformers/__init__.py | 8 ++++---- cirq-core/cirq/transformers/diagonal_optimization_test.py | 1 + 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 55f7a506ded..290d48ac632 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -358,8 +358,8 @@ decompose_two_qubit_interaction_into_four_fsim_gates as decompose_two_qubit_interaction_into_four_fsim_gates, # noqa: E501 defer_measurements as defer_measurements, dephase_measurements as dephase_measurements, - drop_empty_moments as drop_empty_moments, drop_diagonal_before_measurement as drop_diagonal_before_measurement, + drop_empty_moments as drop_empty_moments, drop_negligible_operations as drop_negligible_operations, drop_terminal_measurements as drop_terminal_measurements, eject_phased_paulis as eject_phased_paulis, diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index ee38dbf8b62..d1f6c70b7a7 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -76,6 +76,10 @@ optimize_for_target_gateset as optimize_for_target_gateset, ) +from cirq.transformers.diagonal_optimization import ( + drop_diagonal_before_measurement as drop_diagonal_before_measurement, +) + from cirq.transformers.drop_empty_moments import drop_empty_moments as drop_empty_moments from cirq.transformers.drop_negligible_operations import ( @@ -88,10 +92,6 @@ from cirq.transformers.eject_z import eject_z as eject_z -from cirq.transformers.diagonal_optimization import ( - drop_diagonal_before_measurement as drop_diagonal_before_measurement, -) - from cirq.transformers.measurement_transformers import ( defer_measurements as defer_measurements, dephase_measurements as dephase_measurements, diff --git a/cirq-core/cirq/transformers/diagonal_optimization_test.py b/cirq-core/cirq/transformers/diagonal_optimization_test.py index eebeb0f6c16..78455197388 100644 --- a/cirq-core/cirq/transformers/diagonal_optimization_test.py +++ b/cirq-core/cirq/transformers/diagonal_optimization_test.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """Tests for diagonal_optimization transformer."""