Skip to content

Commit 1c0ba5b

Browse files
authored
Flatten controlled-CZ and controlled-CX more consistently (#7365)
Fixes #7241 by absorbing the control layer of controlled-CZ into the control itself, leaving a controlled-Z, and does the equivalent for CX / X. As documented in the issue, this approach allows for more consistency in how controlled gates are represented. As stated in the linked issue, almost all existing controlled-CZ cases already work this way; the only outlier is when control_values==[0]. Eliminating that outlier allows most of the special handling to be consolidated in the base gates (Z and X), so it's already possible to see some benefit from the added consistency.
1 parent f5228d7 commit 1c0ba5b

File tree

3 files changed

+44
-85
lines changed

3 files changed

+44
-85
lines changed

cirq-core/cirq/ops/common_gates.py

Lines changed: 24 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -222,25 +222,17 @@ def controlled(
222222
A `cirq.ControlledGate` (or `cirq.CXPowGate` if possible) representing
223223
`self` controlled by the given control values and qubits.
224224
"""
225-
if control_values and not isinstance(control_values, cv.AbstractControlValues):
226-
control_values = cv.ProductOfSums(
227-
tuple(
228-
(val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values
229-
)
230-
)
231225
result = super().controlled(num_controls, control_values, control_qid_shape)
232226
if (
233227
self._global_shift == 0
234228
and isinstance(result, controlled_gate.ControlledGate)
235229
and isinstance(result.control_values, cv.ProductOfSums)
236-
and result.control_values[-1] == (1,)
237-
and result.control_qid_shape[-1] == 2
230+
and result.control_values.is_trivial
238231
):
239-
return cirq.CXPowGate(
240-
exponent=self._exponent, global_shift=self._global_shift
241-
).controlled(
242-
result.num_controls() - 1, result.control_values[:-1], result.control_qid_shape[:-1]
243-
)
232+
if result.control_qid_shape == (2,):
233+
return cirq.CXPowGate(exponent=self._exponent)
234+
if result.control_qid_shape == (2, 2):
235+
return cirq.CCXPowGate(exponent=self._exponent)
244236
return result
245237

246238
def _pauli_expansion_(self) -> value.LinearDict[str]:
@@ -694,25 +686,17 @@ def controlled(
694686
A `cirq.ControlledGate` (or `cirq.CZPowGate` if possible) representing
695687
`self` controlled by the given control values and qubits.
696688
"""
697-
if control_values and not isinstance(control_values, cv.AbstractControlValues):
698-
control_values = cv.ProductOfSums(
699-
tuple(
700-
(val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values
701-
)
702-
)
703689
result = super().controlled(num_controls, control_values, control_qid_shape)
704690
if (
705691
self._global_shift == 0
706692
and isinstance(result, controlled_gate.ControlledGate)
707693
and isinstance(result.control_values, cv.ProductOfSums)
708-
and result.control_values[-1] == (1,)
709-
and result.control_qid_shape[-1] == 2
694+
and result.control_values.is_trivial
710695
):
711-
return cirq.CZPowGate(
712-
exponent=self._exponent, global_shift=self._global_shift
713-
).controlled(
714-
result.num_controls() - 1, result.control_values[:-1], result.control_qid_shape[:-1]
715-
)
696+
if result.control_qid_shape == (2,):
697+
return cirq.CZPowGate(exponent=self._exponent)
698+
if result.control_qid_shape == (2, 2):
699+
return cirq.CCZPowGate(exponent=self._exponent)
716700
return result
717701

718702
def _qid_shape_(self) -> tuple[int, ...]:
@@ -1138,26 +1122,14 @@ def controlled(
11381122
A `cirq.ControlledGate` (or `cirq.CCZPowGate` if possible) representing
11391123
`self` controlled by the given control values and qubits.
11401124
"""
1141-
if control_values and not isinstance(control_values, cv.AbstractControlValues):
1142-
control_values = cv.ProductOfSums(
1143-
tuple(
1144-
(val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values
1145-
)
1146-
)
11471125
result = super().controlled(num_controls, control_values, control_qid_shape)
1148-
if (
1149-
self._global_shift == 0
1150-
and isinstance(result, controlled_gate.ControlledGate)
1151-
and isinstance(result.control_values, cv.ProductOfSums)
1152-
and result.control_values[-1] == (1,)
1153-
and result.control_qid_shape[-1] == 2
1154-
):
1155-
return cirq.CCZPowGate(
1156-
exponent=self._exponent, global_shift=self._global_shift
1157-
).controlled(
1158-
result.num_controls() - 1, result.control_values[:-1], result.control_qid_shape[:-1]
1159-
)
1160-
return result
1126+
if self._global_shift != 0 or not isinstance(result, controlled_gate.ControlledGate):
1127+
return result
1128+
return ZPowGate(exponent=self.exponent).controlled(
1129+
num_controls=result.num_controls() + 1,
1130+
control_values=result.control_values & cv.ProductOfSums([1]),
1131+
control_qid_shape=result.control_qid_shape + (2,),
1132+
)
11611133

11621134
def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
11631135
return protocols.CircuitDiagramInfo(
@@ -1340,26 +1312,14 @@ def controlled(
13401312
A `cirq.ControlledGate` (or `cirq.CCXPowGate` if possible) representing
13411313
`self` controlled by the given control values and qubits.
13421314
"""
1343-
if control_values and not isinstance(control_values, cv.AbstractControlValues):
1344-
control_values = cv.ProductOfSums(
1345-
tuple(
1346-
(val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values
1347-
)
1348-
)
13491315
result = super().controlled(num_controls, control_values, control_qid_shape)
1350-
if (
1351-
self._global_shift == 0
1352-
and isinstance(result, controlled_gate.ControlledGate)
1353-
and isinstance(result.control_values, cv.ProductOfSums)
1354-
and result.control_values[-1] == (1,)
1355-
and result.control_qid_shape[-1] == 2
1356-
):
1357-
return cirq.CCXPowGate(
1358-
exponent=self._exponent, global_shift=self._global_shift
1359-
).controlled(
1360-
result.num_controls() - 1, result.control_values[:-1], result.control_qid_shape[:-1]
1361-
)
1362-
return result
1316+
if self._global_shift != 0 or not isinstance(result, controlled_gate.ControlledGate):
1317+
return result
1318+
return XPowGate(exponent=self.exponent).controlled(
1319+
num_controls=result.num_controls() + 1,
1320+
control_values=result.control_values & cv.ProductOfSums([1]),
1321+
control_qid_shape=result.control_qid_shape + (2,),
1322+
)
13631323

13641324
def _qasm_(self, args: cirq.QasmArgs, qubits: tuple[cirq.Qid, ...]) -> str | None:
13651325
if self._exponent != 1:

cirq-core/cirq/ops/common_gates_test.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -109,19 +109,19 @@ def test_z_init():
109109

110110

111111
@pytest.mark.parametrize(
112-
'input_gate, specialized_output',
112+
'input_gate, specialized_output, base_gate',
113113
[
114-
(cirq.Z, cirq.CZ),
115-
(cirq.CZ, cirq.CCZ),
116-
(cirq.X, cirq.CX),
117-
(cirq.CX, cirq.CCX),
118-
(cirq.ZPowGate(exponent=0.5), cirq.CZPowGate(exponent=0.5)),
119-
(cirq.CZPowGate(exponent=0.5), cirq.CCZPowGate(exponent=0.5)),
120-
(cirq.XPowGate(exponent=0.5), cirq.CXPowGate(exponent=0.5)),
121-
(cirq.CXPowGate(exponent=0.5), cirq.CCXPowGate(exponent=0.5)),
114+
(cirq.Z, cirq.CZ, cirq.Z),
115+
(cirq.CZ, cirq.CCZ, cirq.Z),
116+
(cirq.X, cirq.CX, cirq.X),
117+
(cirq.CX, cirq.CCX, cirq.X),
118+
(cirq.ZPowGate(exponent=0.5), cirq.CZPowGate(exponent=0.5), cirq.S),
119+
(cirq.CZPowGate(exponent=0.5), cirq.CCZPowGate(exponent=0.5), cirq.S),
120+
(cirq.XPowGate(exponent=0.5), cirq.CXPowGate(exponent=0.5), cirq.XPowGate(exponent=0.5)),
121+
(cirq.CXPowGate(exponent=0.5), cirq.CCXPowGate(exponent=0.5), cirq.XPowGate(exponent=0.5)),
122122
],
123123
)
124-
def test_specialized_control(input_gate, specialized_output):
124+
def test_specialized_control(input_gate, specialized_output, base_gate):
125125
# Single qubit control on the input gate gives the specialized output
126126
assert input_gate.controlled() == specialized_output
127127
assert input_gate.controlled(num_controls=1) == specialized_output
@@ -151,20 +151,24 @@ def test_specialized_control(input_gate, specialized_output):
151151
)
152152

153153
# When a control_value 1 qubit is not acting first, results in a regular
154-
# ControlledGate on the input gate instance.
154+
# ControlledGate on the base gate instance, with any extra control layer
155+
# of the input gate being absorbed into the ControlledGate.
156+
absorbed = 0 if base_gate == input_gate else 1
157+
absorbed_values = ((1,),) * absorbed
158+
absorbed_shape = (2,) * absorbed
155159
assert input_gate.controlled(num_controls=1, control_qid_shape=(3,)) == cirq.ControlledGate(
156-
input_gate, num_controls=1, control_qid_shape=(3,)
160+
base_gate, num_controls=1 + absorbed, control_qid_shape=(3,) + absorbed_shape
157161
)
158162
assert input_gate.controlled(control_values=((0,), (1,), (0,))) == cirq.ControlledGate(
159-
input_gate, num_controls=3, control_values=((0,), (1,), (0,))
163+
base_gate, num_controls=3 + absorbed, control_values=((0,), (1,), (0,)) + absorbed_values
160164
)
161165
assert input_gate.controlled(control_qid_shape=(3, 2, 3)) == cirq.ControlledGate(
162-
input_gate, num_controls=3, control_qid_shape=(3, 2, 3)
166+
base_gate, num_controls=3 + absorbed, control_qid_shape=(3, 2, 3) + absorbed_shape
163167
)
164168
assert input_gate.controlled(control_qid_shape=(3,)).controlled(
165169
control_qid_shape=(2,)
166170
).controlled(control_qid_shape=(4,)) != cirq.ControlledGate(
167-
input_gate, num_controls=3, control_qid_shape=(3, 2, 4)
171+
base_gate, num_controls=3 + absorbed, control_qid_shape=(3, 2, 4) + absorbed_shape
168172
)
169173

170174

cirq-core/cirq/ops/controlled_gate.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,7 @@ def _decompose_with_context_(
151151
)
152152
# Prefer the subgate controlled version if available
153153
if self != controlled_sub_gate:
154-
# Prevent 2-cycle from appearing in the recursive decomposition
155-
# TODO: Remove after #7241 is resolved
156-
if not isinstance(controlled_sub_gate, ControlledGate) or not isinstance(
157-
controlled_sub_gate.sub_gate, common_gates.CZPowGate
158-
):
159-
return controlled_sub_gate.on(*qubits)
154+
return controlled_sub_gate.on(*qubits)
160155
if (
161156
protocols.has_unitary(self.sub_gate)
162157
and protocols.num_qubits(self.sub_gate) == 1

0 commit comments

Comments
 (0)