1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- import re
1615from typing import List , Tuple
1716
1817import numpy as np
@@ -50,7 +49,7 @@ def _eigen_components(self) -> List[Tuple[float, np.ndarray]]:
5049 ]
5150
5251
53- class ZGateDef (cirq .EigenGate , cirq .testing .TwoQubitGate ):
52+ class ZGateDef (cirq .EigenGate , cirq .testing .SingleQubitGate ):
5453 @property
5554 def exponent (self ):
5655 return self ._exponent
@@ -97,7 +96,6 @@ def test_eq():
9796 eq .make_equality_group (lambda : CExpZinGate (quarter_turns = 0.1 ))
9897 eq .add_equality_group (CExpZinGate (0 ), CExpZinGate (4 ), CExpZinGate (- 4 ))
9998
100- # Equates by canonicalized period.
10199 eq .add_equality_group (CExpZinGate (1.5 ), CExpZinGate (41.5 ))
102100 eq .add_equality_group (CExpZinGate (3.5 ), CExpZinGate (- 0.5 ))
103101
@@ -109,6 +107,64 @@ def test_eq():
109107 eq .add_equality_group (ZGateDef (exponent = 0.5 , global_shift = 0.5 ))
110108 eq .add_equality_group (ZGateDef (exponent = 1.0 , global_shift = 0.5 ))
111109
110+ # All variants of (0,0) == (0*a,0*a) == (0, 2) == (2, 2)
111+ a , b = sympy .symbols ('a, b' )
112+ eq .add_equality_group (
113+ WeightedZPowGate (0 ),
114+ WeightedZPowGate (0 ) ** 1.1 ,
115+ WeightedZPowGate (0 ) ** a ,
116+ (WeightedZPowGate (0 ) ** a ) ** 1.2 ,
117+ WeightedZPowGate (0 ) ** (a + 1.3 ),
118+ WeightedZPowGate (0 ) ** b ,
119+ WeightedZPowGate (1 ) ** 2 ,
120+ WeightedZPowGate (0 , global_shift = 1 ) ** 2 ,
121+ WeightedZPowGate (1 , global_shift = 1 ) ** 2 ,
122+ WeightedZPowGate (2 ),
123+ WeightedZPowGate (0 , global_shift = 2 ),
124+ WeightedZPowGate (2 , global_shift = 2 ),
125+ )
126+ # WeightedZPowGate(2) is identity, but non-integer exponent would make it different, similar to
127+ # how we treat (X**2)**0.5==X. So these are in their own equality group. (0, 2*a)
128+ eq .add_equality_group (
129+ WeightedZPowGate (2 ) ** a ,
130+ (WeightedZPowGate (1 ) ** 2 ) ** a ,
131+ (WeightedZPowGate (1 ) ** a ) ** 2 ,
132+ WeightedZPowGate (1 ) ** (a * 2 ),
133+ WeightedZPowGate (1 ) ** (a + a ),
134+ )
135+ # Similarly, these are identity without the exponent, but global_shift affects both phases
136+ # instead of just the one, so will have a different effect from the above depending on the
137+ # exponent. (2*a, 0)
138+ eq .add_equality_group (
139+ WeightedZPowGate (0 , global_shift = 2 ) ** a ,
140+ (WeightedZPowGate (0 , global_shift = 1 ) ** 2 ) ** a ,
141+ (WeightedZPowGate (0 , global_shift = 1 ) ** a ) ** 2 ,
142+ WeightedZPowGate (0 , global_shift = 1 ) ** (a * 2 ),
143+ WeightedZPowGate (0 , global_shift = 1 ) ** (a + a ),
144+ )
145+ # Symbolic exponents that cancel (0, 1) == (0, a/a)
146+ eq .add_equality_group (
147+ WeightedZPowGate (1 ),
148+ WeightedZPowGate (a ) ** (1 / a ),
149+ WeightedZPowGate (b ) ** (1 / b ),
150+ WeightedZPowGate (1 / a ) ** a ,
151+ WeightedZPowGate (1 / b ) ** b ,
152+ )
153+ # Symbol in one phase and constant off by period in another (0, a) == (2, a)
154+ eq .add_equality_group (
155+ WeightedZPowGate (a ),
156+ WeightedZPowGate (a - 2 , global_shift = 2 ),
157+ WeightedZPowGate (1 - 2 / a , global_shift = 2 / a ) ** a ,
158+ )
159+ # Different symbol, different equality group (0, b)
160+ eq .add_equality_group (WeightedZPowGate (b ))
161+ # Various number types
162+ eq .add_equality_group (
163+ WeightedZPowGate (np .int64 (3 ), global_shift = sympy .Number (5 )) ** 7.0 ,
164+ WeightedZPowGate (sympy .Number (3 ), global_shift = 5.0 ) ** np .int64 (7 ),
165+ WeightedZPowGate (3.0 , global_shift = np .int64 (5 )) ** sympy .Number (7 ),
166+ )
167+
112168
113169def test_approx_eq ():
114170 assert cirq .approx_eq (CExpZinGate (1.5 ), CExpZinGate (1.5 ), atol = 0.1 )
@@ -118,8 +174,7 @@ def test_approx_eq():
118174 assert cirq .approx_eq (ZGateDef (exponent = 1.5 ), ZGateDef (exponent = 1.5 ), atol = 0.1 )
119175 assert not cirq .approx_eq (CExpZinGate (1.5 ), ZGateDef (exponent = 1.5 ), atol = 0.1 )
120176 with pytest .raises (
121- TypeError ,
122- match = re .escape ("unsupported operand type(s) for -: 'Symbol' and 'PeriodicValue'" ),
177+ TypeError , match = "unsupported operand type\\ (s\\ ) for -: '.*' and 'PeriodicValue'"
123178 ):
124179 cirq .approx_eq (ZGateDef (exponent = 1.5 ), ZGateDef (exponent = sympy .Symbol ('a' )), atol = 0.1 )
125180 assert cirq .approx_eq (CExpZinGate (sympy .Symbol ('a' )), CExpZinGate (sympy .Symbol ('a' )), atol = 0.1 )
@@ -333,11 +388,6 @@ def __init__(self, weight, **kwargs):
333388 self .weight = weight
334389 super ().__init__ (** kwargs )
335390
336- def _value_equality_values_ (self ):
337- return self .weight , self ._canonical_exponent , self ._global_shift
338-
339- _value_equality_approximate_values_ = _value_equality_values_
340-
341391 def _eigen_components (self ) -> List [Tuple [float , np .ndarray ]]:
342392 return [(0 , np .diag ([1 , 0 ])), (self .weight , np .diag ([0 , 1 ]))]
343393
0 commit comments