Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cirq-core/cirq/ops/gate_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ def all_subclasses(cls):
cirq.Pauli,
# Private gates.
cirq.transformers.analytical_decompositions.two_qubit_to_fsim._BGate,
cirq.transformers.measurement_transformers._Add,
cirq.ops.raw_types._InverseCompositeGate,
cirq.circuits.qasm_output.QasmTwoQubitGate,
cirq.ops.MSGate,
Expand Down
37 changes: 34 additions & 3 deletions cirq-core/cirq/transformers/measurement_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def dimension(self) -> int:
return self._qid.dimension

def _comparison_key(self) -> Any:
return (str(self._key), self._qid._comparison_key())
return str(self._key), self._qid._comparison_key()

def __str__(self) -> str:
return f"M('{self._key}', q={self._qid})"
Expand Down Expand Up @@ -104,7 +104,7 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
key = value.MeasurementKey.parse_serialized(gate.key)
targets = [_MeasurementQid(key, q) for q in op.qubits]
measurement_qubits[key] = targets
cxs = [ops.CX(q, target) for q, target in zip(op.qubits, targets)]
cxs = [_cx(q.dimension).on(q, target) for q, target in zip(op.qubits, targets)]
Comment thread
daxfohl marked this conversation as resolved.
Outdated
xs = [ops.X(targets[i]) for i, b in enumerate(gate.full_invert_mask()) if b]
return cxs + xs
elif protocols.is_measurement(op):
Expand All @@ -117,7 +117,7 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
raise ValueError(f'Deferred measurement for key={c.key} not found.')
qs = measurement_qubits[c.key]
if len(qs) == 1:
control_values: Any = range(1, qs[0].dimension)
control_values: Any = [range(1, qs[0].dimension)]
else:
all_values = itertools.product(*[range(q.dimension) for q in qs])
anything_but_all_zeros = tuple(itertools.islice(all_values, 1, None))
Expand Down Expand Up @@ -227,3 +227,34 @@ def flip_inversion(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
return transformer_primitives.map_operations(
circuit, flip_inversion, deep=context.deep if context else True, tags_to_ignore=ignored
).unfreeze()


@value.value_equality
class _Add(ops.ArithmeticGate):
Comment thread
daxfohl marked this conversation as resolved.
Outdated
"""Adds two qudits of the same dimension.

Operates on two qudits by modular addition:

|a,b> -> |a,a+b mod d>"""

def __init__(self, dimension: int):
self._dimension = dimension

def registers(self):
return (self._dimension,), (self._dimension,)

def with_registers(self, *new_registers):
raise NotImplementedError()

def apply(self, input_value, target_value):
return input_value, target_value + input_value

def _value_equality_values_(self):
return self._dimension
Comment thread
daxfohl marked this conversation as resolved.


def _cx(dimension: int):
Comment thread
daxfohl marked this conversation as resolved.
Outdated
Comment thread
daxfohl marked this conversation as resolved.
Outdated
# We can use an Add gate in the qudit case, since the ancilla qudit corresponding to the
# measurement is always zero, so "adding" the measured qudit to it sets the ancilla qudit to
# the same state.
return ops.CX if dimension == 2 else _Add(dimension)
31 changes: 26 additions & 5 deletions cirq-core/cirq/transformers/measurement_transformers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@
import sympy

import cirq
from cirq.transformers.measurement_transformers import _MeasurementQid
from cirq.transformers.measurement_transformers import _cx, _MeasurementQid


def assert_equivalent_to_deferred(circuit: cirq.Circuit):
qubits = list(circuit.all_qubits())
sim = cirq.Simulator()
num_qubits = len(qubits)
for i in range(2**num_qubits):
bits = cirq.big_endian_int_to_bits(i, bit_count=num_qubits)
dimensions = [q.dimension for q in qubits]
for i in range(np.prod(dimensions)):
bits = cirq.big_endian_int_to_digits(i, base=dimensions)
modified = cirq.Circuit()
for j in range(num_qubits):
if bits[j]:
modified.append(cirq.X(qubits[j]))
modified.append(cirq.XPowGate(dimension=qubits[j].dimension)(qubits[j]) ** bits[j])
modified.append(circuit)
deferred = cirq.defer_measurements(modified)
result = sim.simulate(modified)
Expand Down Expand Up @@ -58,6 +58,27 @@ def test_basic():
)


def test_qudits():
q0, q1 = cirq.LineQid.range(2, dimension=3)
circuit = cirq.Circuit(
cirq.measure(q0, key='a'),
cirq.XPowGate(dimension=3).on(q1).with_classical_controls('a'),
cirq.measure(q1, key='b'),
)
assert_equivalent_to_deferred(circuit)
deferred = cirq.defer_measurements(circuit)
q_ma = _MeasurementQid('a', q0)
cirq.testing.assert_same_circuits(
deferred,
cirq.Circuit(
_cx(3)(q0, q_ma),
cirq.XPowGate(dimension=3).on(q1).controlled_by(q_ma, control_values=[[1, 2]]),
cirq.measure(q_ma, key='a'),
cirq.measure(q1, key='b'),
),
)


def test_nocompile_context():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
Expand Down