Skip to content

Commit d87d87e

Browse files
authored
Fix EigenGate equality (#7057)
* Fix equality for EigenGates * Handle zero-phase with symbolic exponents * Simplify code * Remove superfluous _value_equality_values_ * more tests * more tests * Fix pauli_interaction_gate change * Remove unnecessary EigenGate._value_equality_approximate_values_ override * resort imports * sort imports * remove unused import * Add details to EigenGate._value_equality_values_ docstring
1 parent a68adcd commit d87d87e

File tree

7 files changed

+83
-66
lines changed

7 files changed

+83
-66
lines changed

cirq-core/cirq/ops/common_gates.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -326,12 +326,6 @@ def _json_dict_(self) -> Dict[str, Any]:
326326
d['dimension'] = self.dimension
327327
return d
328328

329-
def _value_equality_values_(self):
330-
return (*super()._value_equality_values_(), self._dimension)
331-
332-
def _value_equality_approximate_values_(self):
333-
return (*super()._value_equality_approximate_values_(), self._dimension)
334-
335329

336330
class Rx(XPowGate):
337331
r"""A gate with matrix $e^{-i X t/2}$ that rotates around the X axis of the Bloch sphere by $t$.
@@ -862,12 +856,6 @@ def _json_dict_(self) -> Dict[str, Any]:
862856
d['dimension'] = self.dimension
863857
return d
864858

865-
def _value_equality_values_(self):
866-
return (*super()._value_equality_values_(), self._dimension)
867-
868-
def _value_equality_approximate_values_(self):
869-
return (*super()._value_equality_approximate_values_(), self._dimension)
870-
871859

872860
class Rz(ZPowGate):
873861
r"""A gate with matrix $e^{-i Z t/2}$ that rotates around the Z axis of the Bloch sphere by $t$.

cirq-core/cirq/ops/common_gates_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,12 @@ def test_rot_gates_eq():
245245
eq.add_equality_group(cirq.YPowGate(), cirq.YPowGate(exponent=1), cirq.Y)
246246
eq.add_equality_group(cirq.ZPowGate(), cirq.ZPowGate(exponent=1), cirq.Z)
247247
eq.add_equality_group(
248-
cirq.ZPowGate(exponent=1, global_shift=-0.5), cirq.ZPowGate(exponent=5, global_shift=-0.5)
248+
cirq.ZPowGate(exponent=1, global_shift=-0.5),
249+
cirq.ZPowGate(exponent=5, global_shift=-0.5),
250+
cirq.ZPowGate(exponent=5, global_shift=-0.1),
249251
)
250252
eq.add_equality_group(cirq.ZPowGate(exponent=3, global_shift=-0.5))
251253
eq.add_equality_group(cirq.ZPowGate(exponent=1, global_shift=-0.1))
252-
eq.add_equality_group(cirq.ZPowGate(exponent=5, global_shift=-0.1))
253254
eq.add_equality_group(
254255
cirq.CNotPowGate(), cirq.CXPowGate(), cirq.CNotPowGate(exponent=1), cirq.CNOT
255256
)

cirq-core/cirq/ops/eigen_gate.py

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
import sympy
3535

3636
from cirq import protocols, value
37-
from cirq.linalg import tolerance
3837
from cirq.ops import raw_types
3938

4039
if TYPE_CHECKING:
@@ -122,7 +121,6 @@ def __init__(
122121
exponent = exponent.real
123122
self._exponent = exponent
124123
self._global_shift = global_shift
125-
self._canonical_exponent_cached = None
126124

127125
@property
128126
def exponent(self) -> value.TParamVal:
@@ -305,30 +303,19 @@ def __pow__(self, exponent: Union[float, sympy.Symbol]) -> 'EigenGate':
305303
return NotImplemented # pragma: no cover
306304
return self._with_exponent(exponent=new_exponent)
307305

308-
@property
309-
def _canonical_exponent(self):
310-
if self._canonical_exponent_cached is None:
311-
period = self._period()
312-
if not period:
313-
self._canonical_exponent_cached = self._exponent
314-
elif protocols.is_parameterized(self._exponent):
315-
self._canonical_exponent_cached = self._exponent
316-
if isinstance(self._exponent, sympy.Number):
317-
self._canonical_exponent_cached = float(self._exponent)
318-
else:
319-
self._canonical_exponent_cached = self._exponent % period
320-
return self._canonical_exponent_cached
321-
322306
def _value_equality_values_(self):
323-
return self._canonical_exponent, self._global_shift
307+
"""The phases by which we multiply the eigenspaces.
324308
325-
def _value_equality_approximate_values_(self):
326-
period = self._period()
327-
if not period or protocols.is_parameterized(self._exponent):
328-
exponent = self._exponent
329-
else:
330-
exponent = value.PeriodicValue(self._exponent, period)
331-
return exponent, self._global_shift
309+
The default implementation assumes that the eigenspaces are constant
310+
for the class, and the eigenphases are the only distinguishing
311+
characteristics. For gates whose eigenspaces can change, such as
312+
`PhasedISwapPowGate`, this must be overridden to provide the additional
313+
fields that affect the eigenspaces.
314+
"""
315+
symbolic = lambda x: isinstance(x, sympy.Expr) and x.free_symbols
316+
f = lambda x: x if symbolic(x) else float(x)
317+
shifts = (f(self._exponent) * f(self._global_shift + e) for e in self._eigen_shifts())
318+
return tuple(s if symbolic(s) else value.PeriodicValue(f(s), 2) for s in shifts)
332319

333320
def _trace_distance_bound_(self) -> Optional[float]:
334321
if protocols.is_parameterized(self._exponent):
@@ -378,20 +365,9 @@ def _equal_up_to_global_phase_(self, other, atol):
378365
return False
379366
self_without_phase = self._with_exponent(self.exponent)
380367
self_without_phase._global_shift = 0
381-
self_without_exp_or_phase = self_without_phase._with_exponent(0)
382-
self_without_exp_or_phase._global_shift = 0
383368
other_without_phase = other._with_exponent(other.exponent)
384369
other_without_phase._global_shift = 0
385-
other_without_exp_or_phase = other_without_phase._with_exponent(0)
386-
other_without_exp_or_phase._global_shift = 0
387-
if not protocols.approx_eq(
388-
self_without_exp_or_phase, other_without_exp_or_phase, atol=atol
389-
):
390-
return False
391-
392-
period = self_without_phase._period()
393-
exponents_diff = exponents[0] - exponents[1]
394-
return tolerance.near_zero_mod(exponents_diff, period, atol=atol)
370+
return protocols.approx_eq(self_without_phase, other_without_phase, atol=atol)
395371

396372
def _json_dict_(self) -> Dict[str, Any]:
397373
return protocols.obj_to_dict_helper(self, ['exponent', 'global_shift'])

cirq-core/cirq/ops/eigen_gate_test.py

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import re
1615
from typing import List, Tuple
1716

1817
import numpy as np
@@ -50,7 +49,7 @@ def _eigen_components(self) -> List[Tuple[float, np.ndarray]]:
5049
]
5150

5251

53-
class ZGateDef(cirq.EigenGate, cirq.testing.TwoQubitGate):
52+
class ZGateDef(cirq.EigenGate, cirq.testing.SingleQubitGate):
5453
@property
5554
def exponent(self):
5655
return self._exponent
@@ -97,7 +96,6 @@ def test_eq():
9796
eq.make_equality_group(lambda: CExpZinGate(quarter_turns=0.1))
9897
eq.add_equality_group(CExpZinGate(0), CExpZinGate(4), CExpZinGate(-4))
9998

100-
# Equates by canonicalized period.
10199
eq.add_equality_group(CExpZinGate(1.5), CExpZinGate(41.5))
102100
eq.add_equality_group(CExpZinGate(3.5), CExpZinGate(-0.5))
103101

@@ -109,6 +107,64 @@ def test_eq():
109107
eq.add_equality_group(ZGateDef(exponent=0.5, global_shift=0.5))
110108
eq.add_equality_group(ZGateDef(exponent=1.0, global_shift=0.5))
111109

110+
# All variants of (0,0) == (0*a,0*a) == (0, 2) == (2, 2)
111+
a, b = sympy.symbols('a, b')
112+
eq.add_equality_group(
113+
WeightedZPowGate(0),
114+
WeightedZPowGate(0) ** 1.1,
115+
WeightedZPowGate(0) ** a,
116+
(WeightedZPowGate(0) ** a) ** 1.2,
117+
WeightedZPowGate(0) ** (a + 1.3),
118+
WeightedZPowGate(0) ** b,
119+
WeightedZPowGate(1) ** 2,
120+
WeightedZPowGate(0, global_shift=1) ** 2,
121+
WeightedZPowGate(1, global_shift=1) ** 2,
122+
WeightedZPowGate(2),
123+
WeightedZPowGate(0, global_shift=2),
124+
WeightedZPowGate(2, global_shift=2),
125+
)
126+
# WeightedZPowGate(2) is identity, but non-integer exponent would make it different, similar to
127+
# how we treat (X**2)**0.5==X. So these are in their own equality group. (0, 2*a)
128+
eq.add_equality_group(
129+
WeightedZPowGate(2) ** a,
130+
(WeightedZPowGate(1) ** 2) ** a,
131+
(WeightedZPowGate(1) ** a) ** 2,
132+
WeightedZPowGate(1) ** (a * 2),
133+
WeightedZPowGate(1) ** (a + a),
134+
)
135+
# Similarly, these are identity without the exponent, but global_shift affects both phases
136+
# instead of just the one, so will have a different effect from the above depending on the
137+
# exponent. (2*a, 0)
138+
eq.add_equality_group(
139+
WeightedZPowGate(0, global_shift=2) ** a,
140+
(WeightedZPowGate(0, global_shift=1) ** 2) ** a,
141+
(WeightedZPowGate(0, global_shift=1) ** a) ** 2,
142+
WeightedZPowGate(0, global_shift=1) ** (a * 2),
143+
WeightedZPowGate(0, global_shift=1) ** (a + a),
144+
)
145+
# Symbolic exponents that cancel (0, 1) == (0, a/a)
146+
eq.add_equality_group(
147+
WeightedZPowGate(1),
148+
WeightedZPowGate(a) ** (1 / a),
149+
WeightedZPowGate(b) ** (1 / b),
150+
WeightedZPowGate(1 / a) ** a,
151+
WeightedZPowGate(1 / b) ** b,
152+
)
153+
# Symbol in one phase and constant off by period in another (0, a) == (2, a)
154+
eq.add_equality_group(
155+
WeightedZPowGate(a),
156+
WeightedZPowGate(a - 2, global_shift=2),
157+
WeightedZPowGate(1 - 2 / a, global_shift=2 / a) ** a,
158+
)
159+
# Different symbol, different equality group (0, b)
160+
eq.add_equality_group(WeightedZPowGate(b))
161+
# Various number types
162+
eq.add_equality_group(
163+
WeightedZPowGate(np.int64(3), global_shift=sympy.Number(5)) ** 7.0,
164+
WeightedZPowGate(sympy.Number(3), global_shift=5.0) ** np.int64(7),
165+
WeightedZPowGate(3.0, global_shift=np.int64(5)) ** sympy.Number(7),
166+
)
167+
112168

113169
def test_approx_eq():
114170
assert cirq.approx_eq(CExpZinGate(1.5), CExpZinGate(1.5), atol=0.1)
@@ -118,8 +174,7 @@ def test_approx_eq():
118174
assert cirq.approx_eq(ZGateDef(exponent=1.5), ZGateDef(exponent=1.5), atol=0.1)
119175
assert not cirq.approx_eq(CExpZinGate(1.5), ZGateDef(exponent=1.5), atol=0.1)
120176
with pytest.raises(
121-
TypeError,
122-
match=re.escape("unsupported operand type(s) for -: 'Symbol' and 'PeriodicValue'"),
177+
TypeError, match="unsupported operand type\\(s\\) for -: '.*' and 'PeriodicValue'"
123178
):
124179
cirq.approx_eq(ZGateDef(exponent=1.5), ZGateDef(exponent=sympy.Symbol('a')), atol=0.1)
125180
assert cirq.approx_eq(CExpZinGate(sympy.Symbol('a')), CExpZinGate(sympy.Symbol('a')), atol=0.1)
@@ -333,11 +388,6 @@ def __init__(self, weight, **kwargs):
333388
self.weight = weight
334389
super().__init__(**kwargs)
335390

336-
def _value_equality_values_(self):
337-
return self.weight, self._canonical_exponent, self._global_shift
338-
339-
_value_equality_approximate_values_ = _value_equality_values_
340-
341391
def _eigen_components(self) -> List[Tuple[float, np.ndarray]]:
342392
return [(0, np.diag([1, 0])), (self.weight, np.diag([0, 1]))]
343393

cirq-core/cirq/ops/parity_gates_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def test_xx_eq():
3939
cirq.XXPowGate(),
4040
cirq.XXPowGate(exponent=1, global_shift=0),
4141
cirq.XXPowGate(exponent=3, global_shift=0),
42+
cirq.XXPowGate(global_shift=100000),
4243
)
4344
eq.add_equality_group(cirq.XX**0.5, cirq.XX**2.5, cirq.XX**4.5)
4445
eq.add_equality_group(cirq.XX**0.25, cirq.XX**2.25, cirq.XX**-1.75)

cirq-core/cirq/ops/pauli_gates.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,6 @@ def on(self, *qubits: 'cirq.Qid') -> 'SingleQubitPauliStringGateOperation':
103103

104104
return pauli_string.SingleQubitPauliStringGateOperation(self, qubits[0])
105105

106-
@property
107-
def _canonical_exponent(self):
108-
"""Overrides EigenGate._canonical_exponent in subclasses."""
109-
return 1
110-
111106

112107
class _PauliX(Pauli, common_gates.XPowGate):
113108
def __init__(self):

cirq-core/cirq/ops/pauli_interaction_gate.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,13 @@ def _num_qubits_(self) -> int:
8585
return 2
8686

8787
def _value_equality_values_(self):
88-
return (self.pauli0, self.invert0, self.pauli1, self.invert1, self._canonical_exponent)
88+
return (
89+
self.pauli0,
90+
self.invert0,
91+
self.pauli1,
92+
self.invert1,
93+
value.PeriodicValue(self.exponent, 2),
94+
)
8995

9096
def qubit_index_to_equivalence_group_key(self, index: int) -> int:
9197
if self.pauli0 == self.pauli1 and self.invert0 == self.invert1:

0 commit comments

Comments
 (0)