Skip to content

Commit 1296df5

Browse files
committed
Simplify decomposition of controlled eigengates with global phase
1 parent 65a4105 commit 1296df5

10 files changed

+207
-6
lines changed

cirq-core/cirq/ops/classically_controlled_operation_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,8 @@ def test_decompose():
473473
op = cirq.H(q0).with_classical_controls('a')
474474
assert cirq.decompose(op) == [
475475
(cirq.Y(q0) ** 0.5).with_classical_controls('a'),
476-
cirq.XPowGate(exponent=1.0, global_shift=-0.25).on(q0).with_classical_controls('a'),
476+
cirq.X(q0).with_classical_controls('a'),
477+
cirq.global_phase_operation(1j**-0.5).with_classical_controls('a'),
477478
]
478479

479480

cirq-core/cirq/ops/common_gates.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,12 @@ def __init__(self, *, rads: value.TParamVal):
352352
def _with_exponent(self, exponent: value.TParamVal) -> 'Rx':
353353
return Rx(rads=exponent * _pi(exponent))
354354

355+
def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> NotImplementedType:
356+
"""Returns:
357+
NotImplemented, to signify the gate doesn't decompose further.
358+
"""
359+
return NotImplemented
360+
355361
def _circuit_diagram_info_(
356362
self, args: 'cirq.CircuitDiagramInfoArgs'
357363
) -> Union[str, 'protocols.CircuitDiagramInfo']:
@@ -537,6 +543,12 @@ def __init__(self, *, rads: value.TParamVal):
537543
def _with_exponent(self, exponent: value.TParamVal) -> 'Ry':
538544
return Ry(rads=exponent * _pi(exponent))
539545

546+
def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> NotImplementedType:
547+
"""Returns:
548+
NotImplemented, to signify the gate doesn't decompose further.
549+
"""
550+
return NotImplemented
551+
540552
def _circuit_diagram_info_(
541553
self, args: 'cirq.CircuitDiagramInfoArgs'
542554
) -> Union[str, 'protocols.CircuitDiagramInfo']:
@@ -882,6 +894,12 @@ def __init__(self, *, rads: value.TParamVal):
882894
def _with_exponent(self, exponent: value.TParamVal) -> 'Rz':
883895
return Rz(rads=exponent * _pi(exponent))
884896

897+
def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> NotImplementedType:
898+
"""Returns:
899+
NotImplemented, to signify the gate doesn't decompose further.
900+
"""
901+
return NotImplemented
902+
885903
def _circuit_diagram_info_(
886904
self, args: 'cirq.CircuitDiagramInfoArgs'
887905
) -> Union[str, 'protocols.CircuitDiagramInfo']:

cirq-core/cirq/ops/controlled_gate.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
control_values as cv,
3434
controlled_operation as cop,
3535
diagonal_gate as dg,
36+
eigen_gate,
3637
global_phase_op as gp,
3738
op_tree,
3839
raw_types,
@@ -159,6 +160,12 @@ def _decompose_with_context_(
159160
self, qubits: Tuple['cirq.Qid', ...], context: Optional['cirq.DecompositionContext'] = None
160161
) -> Union[None, NotImplementedType, 'cirq.OP_TREE']:
161162
control_qubits = list(qubits[: self.num_controls()])
163+
# If the subgate is an EigenGate with non-zero phase, try to decompose it
164+
# into a phase-free gate and a global phase gate.
165+
if isinstance(self.sub_gate, eigen_gate.EigenGate) and self.sub_gate.global_shift != 0:
166+
result = self._decompose_sub_gate_with_controls(qubits, context)
167+
if result is not NotImplemented:
168+
return result
162169
if (
163170
protocols.has_unitary(self.sub_gate)
164171
and protocols.num_qubits(self.sub_gate) == 1
@@ -219,6 +226,11 @@ def _decompose_with_context_(
219226
control_qid_shape=self.control_qid_shape,
220227
).on(*control_qubits)
221228
return [result, controlled_phase_op]
229+
return self._decompose_sub_gate_with_controls(qubits, context)
230+
231+
def _decompose_sub_gate_with_controls(
232+
self, qubits: Tuple['cirq.Qid', ...], context: Optional['cirq.DecompositionContext'] = None
233+
) -> Union[None, NotImplementedType, 'cirq.OP_TREE']:
222234
result = protocols.decompose_once_with_qubits(
223235
self.sub_gate,
224236
qubits[self.num_controls() :],

cirq-core/cirq/ops/controlled_gate_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,22 @@ def _test_controlled_gate_is_consistent(
494494
np.testing.assert_allclose(cirq.unitary(cgate), cirq.unitary(circuit), atol=1e-13)
495495

496496

497+
@pytest.mark.parametrize(
498+
'sub_gate, expected_decomposition',
499+
[
500+
(cirq.XPowGate(global_shift=0.22), [cirq.Y**-0.5, cirq.CZ, cirq.Y**0.5, cirq.Z**0.22]),
501+
(cirq.ZPowGate(exponent=1.2, global_shift=0.3), [cirq.CZ**1.2, cirq.Z**0.36]),
502+
],
503+
)
504+
def test_decompose_takes_out_global_phase(
505+
sub_gate: cirq.Gate, expected_decomposition: Sequence[cirq.Gate]
506+
):
507+
cgate = cirq.ControlledGate(sub_gate, num_controls=1)
508+
qubits = cirq.LineQubit.range(cgate.num_qubits())
509+
dec = cirq.decompose(cgate.on(*qubits))
510+
assert [op.gate for op in dec] == expected_decomposition
511+
512+
497513
def test_pow_inverse():
498514
assert cirq.inverse(CRestricted, None) is None
499515
assert cirq.pow(CRestricted, 1.5, None) is None

cirq-core/cirq/ops/eigen_gate.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
import sympy
3535

3636
from cirq import protocols, value
37-
from cirq.ops import raw_types
37+
from cirq.ops import global_phase_op, raw_types
3838

3939
if TYPE_CHECKING:
4040
import cirq
@@ -375,6 +375,24 @@ def _json_dict_(self) -> Dict[str, Any]:
375375
def _measurement_key_objs_(self):
376376
return frozenset()
377377

378+
def _decompose_(
379+
self, qubits: Tuple['cirq.Qid', ...]
380+
) -> Union[NotImplementedType, 'cirq.OP_TREE']:
381+
"""Attempts to decompose the gate into a phase-free gate and a global phase gate.
382+
383+
Returns:
384+
NotImplemented, if global phase or exponent are 0. Otherwise a phase-free gate
385+
applied to the qubits followed by a global phase gate.
386+
"""
387+
if self.global_shift == 0 or self.exponent == 0:
388+
return NotImplemented
389+
self_without_phase = self._with_exponent(self.exponent)
390+
# This doesn't work for gates that fix global_shift, such as Rx. These gates must define
391+
# their own _decompose_ method.
392+
self_without_phase._global_shift = 0
393+
global_phase = 1j ** (2 * self.global_shift * self.exponent)
394+
return [self_without_phase.on(*qubits), global_phase_op.GlobalPhaseGate(global_phase)()]
395+
378396

379397
def _lcm(vals: Iterable[int]) -> int:
380398
t = 1

cirq-core/cirq/ops/eigen_gate_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import cirq
2222
from cirq import value
23+
from cirq.ops import global_phase_op
2324
from cirq.testing import assert_has_consistent_trace_distance_bound
2425

2526

@@ -421,3 +422,43 @@ def _with_exponent(self, exponent):
421422
)
422423
def test_equal_up_to_global_phase(gate1, gate2, eq_up_to_global_phase):
423424
assert cirq.equal_up_to_global_phase(gate1, gate2) == eq_up_to_global_phase
425+
426+
427+
@pytest.mark.parametrize(
428+
'gate',
429+
[
430+
cirq.Z,
431+
cirq.Z**2,
432+
cirq.XPowGate(global_shift=0.0),
433+
cirq.rx(0),
434+
cirq.ry(0),
435+
cirq.rz(0),
436+
cirq.CZPowGate(exponent=0.0, global_shift=0.25),
437+
],
438+
)
439+
def test_decompose_once_returns_not_implemented(gate: cirq.Gate):
440+
qubits = cirq.LineQubit.range(gate.num_qubits())
441+
assert cirq.decompose_once(gate.on(*qubits), default=NotImplemented) == NotImplemented
442+
443+
444+
@pytest.mark.parametrize(
445+
'gate, expected_decomposition',
446+
[
447+
(cirq.X, [cirq.X]),
448+
(cirq.ZPowGate(global_shift=0.5), [cirq.Z, global_phase_op.GlobalPhaseGate(1j)]),
449+
(
450+
cirq.ZPowGate(global_shift=0.5) ** sympy.Symbol('e'),
451+
[
452+
cirq.Z ** sympy.Symbol('e'),
453+
global_phase_op.GlobalPhaseGate(1j ** (1.0 * sympy.Symbol('e'))),
454+
],
455+
),
456+
(cirq.rx(np.pi / 2), [cirq.rx(np.pi / 2)]),
457+
(cirq.ry(np.pi / 2), [cirq.ry(np.pi / 2)]),
458+
(cirq.rz(np.pi / 2), [cirq.rz(np.pi / 2)]),
459+
],
460+
)
461+
def test_decompose_takes_out_global_phase(gate: cirq.Gate, expected_decomposition: List[cirq.Gate]):
462+
qubits = cirq.LineQubit.range(gate.num_qubits())
463+
dec = cirq.decompose(gate.on(*qubits))
464+
assert [op.gate for op in dec] == expected_decomposition

cirq-core/cirq/ops/parity_gates_test.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
"""Tests for `parity_gates.py`."""
1616

17+
from typing import List
18+
1719
import numpy as np
1820
import pytest
1921
import sympy
@@ -348,3 +350,60 @@ def test_clifford_protocols(gate_cls: type[cirq.EigenGate], exponent: float, is_
348350
else:
349351
assert not cirq.has_stabilizer_effect(gate)
350352
assert gate._decompose_into_clifford_with_qubits_(cirq.LineQubit.range(2)) is NotImplemented
353+
354+
355+
@pytest.mark.parametrize(
356+
'gate, expected_decomposition',
357+
[
358+
(
359+
cirq.XXPowGate(),
360+
[
361+
(cirq.Y**-0.5).on(cirq.LineQubit(0)),
362+
(cirq.Y**-0.5).on(cirq.LineQubit(1)),
363+
cirq.Z(cirq.LineQubit(0)),
364+
cirq.Z(cirq.LineQubit(1)),
365+
(cirq.CZ**-2.0).on(cirq.LineQubit(0), cirq.LineQubit(1)),
366+
(cirq.Y**0.5).on(cirq.LineQubit(0)),
367+
(cirq.Y**0.5).on(cirq.LineQubit(1)),
368+
],
369+
),
370+
(
371+
cirq.YYPowGate(),
372+
[
373+
(cirq.X**0.5).on(cirq.LineQubit(0)),
374+
(cirq.X**0.5).on(cirq.LineQubit(1)),
375+
cirq.Z(cirq.LineQubit(0)),
376+
cirq.Z(cirq.LineQubit(1)),
377+
(cirq.CZ**-2.0).on(cirq.LineQubit(0), cirq.LineQubit(1)),
378+
(cirq.X**-0.5).on(cirq.LineQubit(0)),
379+
(cirq.X**-0.5).on(cirq.LineQubit(1)),
380+
],
381+
),
382+
(
383+
cirq.ZZPowGate(),
384+
[
385+
cirq.Z(cirq.LineQubit(0)),
386+
cirq.Z(cirq.LineQubit(1)),
387+
(cirq.CZ**-2.0).on(cirq.LineQubit(0), cirq.LineQubit(1)),
388+
],
389+
),
390+
(
391+
cirq.MSGate(rads=0),
392+
[
393+
(cirq.Y**-0.5).on(cirq.LineQubit(0)),
394+
(cirq.Y**-0.5).on(cirq.LineQubit(1)),
395+
(cirq.Z**0.0).on(cirq.LineQubit(0)),
396+
(cirq.Z**0.0).on(cirq.LineQubit(1)),
397+
cirq.CZPowGate(exponent=-0.0, global_shift=0.25).on(
398+
cirq.LineQubit(0), cirq.LineQubit(1)
399+
),
400+
(cirq.Y**0.5).on(cirq.LineQubit(0)),
401+
(cirq.Y**0.5).on(cirq.LineQubit(1)),
402+
],
403+
),
404+
],
405+
)
406+
def test_gate_decomposition(gate: cirq.Gate, expected_decomposition: List[cirq.Gate]):
407+
qubits = cirq.LineQubit.range(gate.num_qubits())
408+
dec = cirq.decompose(gate.on(*qubits))
409+
assert [op for op in dec] == expected_decomposition

cirq-core/cirq/ops/raw_types_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ def test_tagged_operation_forwards_protocols():
647647
np.testing.assert_equal(cirq.unitary(tagged_h), cirq.unitary(h))
648648
assert cirq.has_unitary(tagged_h)
649649
assert cirq.decompose(tagged_h) == cirq.decompose(h)
650-
assert [*tagged_h._decompose_()] == cirq.decompose(h)
650+
assert [*tagged_h._decompose_()] == cirq.decompose_once(h)
651651
assert cirq.pauli_expansion(tagged_h) == cirq.pauli_expansion(h)
652652
assert cirq.equal_up_to_global_phase(h, tagged_h)
653653
assert np.isclose(cirq.kraus(h), cirq.kraus(tagged_h)).all()

cirq-core/cirq/transformers/merge_single_qubit_gates.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def merge_single_qubit_moments_to_phxz(
116116

117117
def can_merge_moment(m: 'cirq.Moment'):
118118
return all(
119-
protocols.num_qubits(op) == 1
119+
(protocols.num_qubits(op) == 1 or protocols.num_qubits(op) == 0)
120120
and protocols.has_unitary(op)
121121
and tags_to_ignore.isdisjoint(op.tags)
122122
for op in m
@@ -144,6 +144,10 @@ def merge_func(m1: 'cirq.Moment', m2: 'cirq.Moment') -> Optional['cirq.Moment']:
144144
)
145145
if gate:
146146
ret_ops.append(gate(q))
147+
# Transfer global phase
148+
for op in m1.operations + m2.operations:
149+
if protocols.num_qubits(op) == 0:
150+
ret_ops.append(op)
147151
return circuits.Moment(ret_ops)
148152

149153
return transformer_primitives.merge_moments(

cirq-core/cirq/transformers/merge_single_qubit_gates_test.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,13 +221,45 @@ def test_merge_single_qubit_moments_to_phxz_deep():
221221
)
222222

223223

224-
def test_merge_single_qubit_moments_to_phxz_global_phase():
224+
def test_merge_single_qubit_gates_to_phxz_global_phase():
225225
c = cirq.Circuit(cirq.GlobalPhaseGate(1j).on())
226226
c2 = cirq.merge_single_qubit_gates_to_phxz(c)
227227
assert c == c2
228228

229229

230-
def test_merge_single_qubit_moments_to_phased_x_and_z_global_phase():
230+
def test_merge_single_qubit_gates_to_phased_x_and_z_global_phase():
231231
c = cirq.Circuit(cirq.GlobalPhaseGate(1j).on())
232232
c2 = cirq.merge_single_qubit_gates_to_phased_x_and_z(c)
233233
assert c == c2
234+
235+
236+
def test_merge_single_qubit_moments_to_phxz_with_global_phase_in_first_moment():
237+
q0 = cirq.LineQubit(0)
238+
c_orig = cirq.Circuit(
239+
cirq.Moment(cirq.Y(q0) ** 0.5, cirq.GlobalPhaseGate(1j**0.5).on()), cirq.Moment(cirq.X(q0))
240+
)
241+
c_expected = cirq.Circuit(
242+
cirq.Moment(
243+
cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=-1.0).on(q0),
244+
cirq.GlobalPhaseGate(1j**0.5).on(),
245+
)
246+
)
247+
context = cirq.TransformerContext(tags_to_ignore=["ignore"])
248+
c_new = cirq.merge_single_qubit_moments_to_phxz(c_orig, context=context)
249+
assert c_new == c_expected
250+
251+
252+
def test_merge_single_qubit_moments_to_phxz_with_global_phase_in_second_moment():
253+
q0 = cirq.LineQubit(0)
254+
c_orig = cirq.Circuit(
255+
cirq.Moment(cirq.Y(q0) ** 0.5), cirq.Moment(cirq.X(q0), cirq.GlobalPhaseGate(1j**0.5).on())
256+
)
257+
c_expected = cirq.Circuit(
258+
cirq.Moment(
259+
cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=-1.0).on(q0),
260+
cirq.GlobalPhaseGate(1j**0.5).on(),
261+
)
262+
)
263+
context = cirq.TransformerContext(tags_to_ignore=["ignore"])
264+
c_new = cirq.merge_single_qubit_moments_to_phxz(c_orig, context=context)
265+
assert c_new == c_expected

0 commit comments

Comments
 (0)