1313# limitations under the License.
1414
1515import itertools
16- from typing import Any , Dict , List , Optional , TYPE_CHECKING , Union
16+ from typing import Any , Dict , List , Optional , Tuple , TYPE_CHECKING , Union
1717
1818from cirq import ops , protocols , value
1919from cirq .transformers import transformer_api , transformer_primitives
@@ -46,7 +46,7 @@ def dimension(self) -> int:
4646 return self ._qid .dimension
4747
4848 def _comparison_key (self ) -> Any :
49- return ( str (self ._key ), self ._qid ._comparison_key () )
49+ return str (self ._key ), self ._qid ._comparison_key ()
5050
5151 def __str__ (self ) -> str :
5252 return f"M('{ self ._key } ', q={ self ._qid } )"
@@ -104,7 +104,7 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
104104 key = value .MeasurementKey .parse_serialized (gate .key )
105105 targets = [_MeasurementQid (key , q ) for q in op .qubits ]
106106 measurement_qubits [key ] = targets
107- cxs = [ops . CX (q , target ) for q , target in zip (op .qubits , targets )]
107+ cxs = [_mod_add (q , target ) for q , target in zip (op .qubits , targets )]
108108 xs = [ops .X (targets [i ]) for i , b in enumerate (gate .full_invert_mask ()) if b ]
109109 return cxs + xs
110110 elif protocols .is_measurement (op ):
@@ -117,7 +117,7 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
117117 raise ValueError (f'Deferred measurement for key={ c .key } not found.' )
118118 qs = measurement_qubits [c .key ]
119119 if len (qs ) == 1 :
120- control_values : Any = range (1 , qs [0 ].dimension )
120+ control_values : Any = [ range (1 , qs [0 ].dimension )]
121121 else :
122122 all_values = itertools .product (* [range (q .dimension ) for q in qs ])
123123 anything_but_all_zeros = tuple (itertools .islice (all_values , 1 , None ))
@@ -227,3 +227,38 @@ def flip_inversion(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
227227 return transformer_primitives .map_operations (
228228 circuit , flip_inversion , deep = context .deep if context else True , tags_to_ignore = ignored
229229 ).unfreeze ()
230+
231+
232+ @value .value_equality
233+ class _ModAdd (ops .ArithmeticGate ):
234+ """Adds two qudits of the same dimension.
235+
236+ Operates on two qudits by modular addition:
237+
238+ |a,b> -> |a,a+b mod d>"""
239+
240+ def __init__ (self , dimension : int ):
241+ self ._dimension = dimension
242+
243+ def registers (self ) -> Tuple [Tuple [int ], Tuple [int ]]:
244+ return (self ._dimension ,), (self ._dimension ,)
245+
246+ def with_registers (self , * new_registers ) -> '_ModAdd' :
247+ raise NotImplementedError ()
248+
249+ def apply (self , * register_values : int ) -> Tuple [int , int ]:
250+ return register_values [0 ], sum (register_values )
251+
252+ def _value_equality_values_ (self ) -> int :
253+ return self ._dimension
254+
255+
256+ def _mod_add (source : 'cirq.Qid' , target : 'cirq.Qid' ) -> 'cirq.Operation' :
257+ assert source .dimension == target .dimension
258+ if source .dimension == 2 :
259+ # Use a CX gate in 2D case for simplicity.
260+ return ops .CX (source , target )
261+ # We can use a ModAdd gate in the qudit case, since the ancilla qudit corresponding to the
262+ # measurement is always zero, so "adding" the measured qudit to it sets the ancilla qudit to
263+ # the same state, which is the quantum equivalent to a measurement onto a creg.
264+ return _ModAdd (source .dimension ).on (source , target )
0 commit comments