Skip to content

Commit 80a9001

Browse files
authored
Allow qudits in deferred measurements (#5850)
For this, we have to define a multidimensional ModAdd gate, for use in applying the state from the source qudit to the ancilla qudit representing the creg. That done, we insert it into the deferred measurements algorithm instead of the ordinary CX gate, and add a qudit test to make sure it all works. cc @viathor for sanity check on the gate logic
1 parent df7d313 commit 80a9001

File tree

3 files changed

+66
-9
lines changed

3 files changed

+66
-9
lines changed

cirq-core/cirq/ops/gate_operation_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,7 @@ def all_subclasses(cls):
494494
cirq.Pauli,
495495
# Private gates.
496496
cirq.transformers.analytical_decompositions.two_qubit_to_fsim._BGate,
497+
cirq.transformers.measurement_transformers._ModAdd,
497498
cirq.transformers.routing.visualize_routed_circuit._SwapPrintGate,
498499
cirq.ops.raw_types._InverseCompositeGate,
499500
cirq.circuits.qasm_output.QasmTwoQubitGate,

cirq-core/cirq/transformers/measurement_transformers.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import itertools
16-
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
16+
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
1717

1818
from cirq import ops, protocols, value
1919
from cirq.transformers import transformer_api, transformer_primitives
@@ -46,7 +46,7 @@ def dimension(self) -> int:
4646
return self._qid.dimension
4747

4848
def _comparison_key(self) -> Any:
49-
return (str(self._key), self._qid._comparison_key())
49+
return str(self._key), self._qid._comparison_key()
5050

5151
def __str__(self) -> str:
5252
return f"M('{self._key}', q={self._qid})"
@@ -104,7 +104,7 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
104104
key = value.MeasurementKey.parse_serialized(gate.key)
105105
targets = [_MeasurementQid(key, q) for q in op.qubits]
106106
measurement_qubits[key] = targets
107-
cxs = [ops.CX(q, target) for q, target in zip(op.qubits, targets)]
107+
cxs = [_mod_add(q, target) for q, target in zip(op.qubits, targets)]
108108
xs = [ops.X(targets[i]) for i, b in enumerate(gate.full_invert_mask()) if b]
109109
return cxs + xs
110110
elif protocols.is_measurement(op):
@@ -117,7 +117,7 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
117117
raise ValueError(f'Deferred measurement for key={c.key} not found.')
118118
qs = measurement_qubits[c.key]
119119
if len(qs) == 1:
120-
control_values: Any = range(1, qs[0].dimension)
120+
control_values: Any = [range(1, qs[0].dimension)]
121121
else:
122122
all_values = itertools.product(*[range(q.dimension) for q in qs])
123123
anything_but_all_zeros = tuple(itertools.islice(all_values, 1, None))
@@ -227,3 +227,38 @@ def flip_inversion(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
227227
return transformer_primitives.map_operations(
228228
circuit, flip_inversion, deep=context.deep if context else True, tags_to_ignore=ignored
229229
).unfreeze()
230+
231+
232+
@value.value_equality
233+
class _ModAdd(ops.ArithmeticGate):
234+
"""Adds two qudits of the same dimension.
235+
236+
Operates on two qudits by modular addition:
237+
238+
|a,b> -> |a,a+b mod d>"""
239+
240+
def __init__(self, dimension: int):
241+
self._dimension = dimension
242+
243+
def registers(self) -> Tuple[Tuple[int], Tuple[int]]:
244+
return (self._dimension,), (self._dimension,)
245+
246+
def with_registers(self, *new_registers) -> '_ModAdd':
247+
raise NotImplementedError()
248+
249+
def apply(self, *register_values: int) -> Tuple[int, int]:
250+
return register_values[0], sum(register_values)
251+
252+
def _value_equality_values_(self) -> int:
253+
return self._dimension
254+
255+
256+
def _mod_add(source: 'cirq.Qid', target: 'cirq.Qid') -> 'cirq.Operation':
257+
assert source.dimension == target.dimension
258+
if source.dimension == 2:
259+
# Use a CX gate in 2D case for simplicity.
260+
return ops.CX(source, target)
261+
# We can use a ModAdd gate in the qudit case, since the ancilla qudit corresponding to the
262+
# measurement is always zero, so "adding" the measured qudit to it sets the ancilla qudit to
263+
# the same state, which is the quantum equivalent to a measurement onto a creg.
264+
return _ModAdd(source.dimension).on(source, target)

cirq-core/cirq/transformers/measurement_transformers_test.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,19 @@
1717
import sympy
1818

1919
import cirq
20-
from cirq.transformers.measurement_transformers import _MeasurementQid
20+
from cirq.transformers.measurement_transformers import _mod_add, _MeasurementQid
2121

2222

2323
def assert_equivalent_to_deferred(circuit: cirq.Circuit):
2424
qubits = list(circuit.all_qubits())
2525
sim = cirq.Simulator()
2626
num_qubits = len(qubits)
27-
for i in range(2**num_qubits):
28-
bits = cirq.big_endian_int_to_bits(i, bit_count=num_qubits)
27+
dimensions = [q.dimension for q in qubits]
28+
for i in range(np.prod(dimensions)):
29+
bits = cirq.big_endian_int_to_digits(i, base=dimensions)
2930
modified = cirq.Circuit()
3031
for j in range(num_qubits):
31-
if bits[j]:
32-
modified.append(cirq.X(qubits[j]))
32+
modified.append(cirq.XPowGate(dimension=qubits[j].dimension)(qubits[j]) ** bits[j])
3333
modified.append(circuit)
3434
deferred = cirq.defer_measurements(modified)
3535
result = sim.simulate(modified)
@@ -58,6 +58,27 @@ def test_basic():
5858
)
5959

6060

61+
def test_qudits():
62+
q0, q1 = cirq.LineQid.range(2, dimension=3)
63+
circuit = cirq.Circuit(
64+
cirq.measure(q0, key='a'),
65+
cirq.XPowGate(dimension=3).on(q1).with_classical_controls('a'),
66+
cirq.measure(q1, key='b'),
67+
)
68+
assert_equivalent_to_deferred(circuit)
69+
deferred = cirq.defer_measurements(circuit)
70+
q_ma = _MeasurementQid('a', q0)
71+
cirq.testing.assert_same_circuits(
72+
deferred,
73+
cirq.Circuit(
74+
_mod_add(q0, q_ma),
75+
cirq.XPowGate(dimension=3).on(q1).controlled_by(q_ma, control_values=[[1, 2]]),
76+
cirq.measure(q_ma, key='a'),
77+
cirq.measure(q1, key='b'),
78+
),
79+
)
80+
81+
6182
def test_nocompile_context():
6283
q0, q1 = cirq.LineQubit.range(2)
6384
circuit = cirq.Circuit(

0 commit comments

Comments
 (0)