Skip to content
Merged
Show file tree
Hide file tree
Changes from 45 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
b4849ca
Add handling for sympy conditions in deferred measurement transformer
daxfohl Aug 12, 2022
c88df71
docstring
daxfohl Aug 12, 2022
d9f2776
mypy
daxfohl Aug 12, 2022
48c5736
mypy
daxfohl Aug 12, 2022
17f5e80
cover
daxfohl Aug 12, 2022
d74a699
Make this more generic, covers all kinds of conditions.
daxfohl Aug 12, 2022
d37cbad
Better docs
daxfohl Aug 13, 2022
2c1d003
Sympy can also be CX
daxfohl Aug 13, 2022
27ff43d
docs
daxfohl Aug 13, 2022
78ef78d
docs
daxfohl Aug 13, 2022
4dcb564
Merge branch 'master' into sympy-deferred
tanujkhattar Aug 25, 2022
6ffe765
Merge branch 'master' into sympy-deferred
tanujkhattar Aug 25, 2022
0292f3c
Allow repeated measurements in deferred transformer
daxfohl Sep 4, 2022
773e660
Coverage
daxfohl Sep 5, 2022
ed1257d
Merge branch 'master' into sympy-deferred
daxfohl Sep 5, 2022
6bcc71c
Add mixed tests, simplify loop, add simplification in ControlledGate
daxfohl Sep 5, 2022
8f045bc
Fix error message
daxfohl Sep 5, 2022
1c32404
Simplify error message
daxfohl Sep 6, 2022
9fd971b
Inline variable
daxfohl Sep 14, 2022
4a484b0
Merge branch 'master' into sympy-deferred
daxfohl Sep 21, 2022
96aff58
Merge branch 'master' into deferred-repeated
daxfohl Sep 21, 2022
30b9121
Merge branch 'master' into sympy-deferred
daxfohl Oct 11, 2022
9d1f5ef
fix merge
daxfohl Oct 11, 2022
d4c80b9
qudit sympy test
daxfohl Oct 11, 2022
dd06a75
Merge branch 'master' into deferred-repeated
daxfohl Oct 11, 2022
480849c
Merge branch 'master' into deferred-repeated
daxfohl Oct 12, 2022
2d1cabf
Merge branch 'master' into sympy-deferred
daxfohl Oct 12, 2022
f033b39
fix build
daxfohl Oct 13, 2022
72388ce
Merge branch 'master' into sympy-deferred
daxfohl Oct 13, 2022
f7f2825
Merge branch 'master' into deferred-repeated
daxfohl Oct 13, 2022
8e8dfc1
Fix test
daxfohl Oct 16, 2022
e733e89
Fix test
daxfohl Oct 16, 2022
46cceef
Merge branch 'deferred-repeated' into deferred-all
daxfohl Oct 16, 2022
0e9dede
Merge branch 'master' into deferred-repeated
daxfohl Oct 31, 2022
29bc38d
Merge branch 'deferred-all' into deferred-repeated
daxfohl Nov 4, 2022
4f9be93
Merge branch 'master' into deferred-repeated
daxfohl Nov 4, 2022
a9df6c9
nits
daxfohl Nov 4, 2022
b86c411
mypy
daxfohl Nov 4, 2022
b873c31
mypy
daxfohl Nov 4, 2022
4920118
mypy
daxfohl Nov 4, 2022
922f827
Add some code comments
daxfohl Nov 4, 2022
4a40c99
Add test for repeated measurement diagram
daxfohl Nov 4, 2022
fb5d11f
change test back
daxfohl Nov 12, 2022
fc93a6f
Merge branch 'master' into deferred-repeated
daxfohl Nov 12, 2022
6fcf183
Merge branch 'master' into deferred-repeated
daxfohl Dec 16, 2022
2d67649
Merge branch 'master' into deferred-repeated
tanujkhattar Dec 19, 2022
90d9ab8
Merge branch 'master' into deferred-repeated
tanujkhattar Dec 19, 2022
4c7a1cd
Merge branch 'master' into deferred-repeated
tanujkhattar Dec 19, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 54 additions & 35 deletions cirq-core/cirq/transformers/measurement_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,8 @@
# limitations under the License.

import itertools
from typing import (
Any,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
TYPE_CHECKING,
Union,
)
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import numpy as np

Expand All @@ -43,30 +33,32 @@ class _MeasurementQid(ops.Qid):
Exactly one qubit will be created per qubit in the measurement gate.
"""

def __init__(self, key: Union[str, 'cirq.MeasurementKey'], qid: 'cirq.Qid'):
def __init__(self, key: Union[str, 'cirq.MeasurementKey'], qid: 'cirq.Qid', index: int = 0):
"""Initializes the qubit.

Args:
key: The key of the measurement gate being deferred.
qid: One qubit that is being measured. Each deferred measurement
should create one new _MeasurementQid per qubit being measured
by that gate.
index: For repeated measurement keys, this represents the index of that measurement.
"""
self._key = value.MeasurementKey.parse_serialized(key) if isinstance(key, str) else key
self._qid = qid
self._index = index

@property
def dimension(self) -> int:
return self._qid.dimension

def _comparison_key(self) -> Any:
return str(self._key), self._qid._comparison_key()
return str(self._key), self._index, self._qid._comparison_key()

def __str__(self) -> str:
return f"M('{self._key}', q={self._qid})"
return f"M('{self._key}[{self._index}]', q={self._qid})"

def __repr__(self) -> str:
return f'_MeasurementQid({self._key!r}, {self._qid!r})'
return f'_MeasurementQid({self._key!r}, {self._qid!r}, {self._index})'


@transformer_api.transformer
Expand Down Expand Up @@ -102,16 +94,18 @@ def defer_measurements(

circuit = transformer_primitives.unroll_circuit_op(circuit, deep=True, tags_to_check=None)
terminal_measurements = {op for _, op in find_terminal_measurements(circuit)}
measurement_qubits: Dict['cirq.MeasurementKey', List['_MeasurementQid']] = {}
measurement_qubits: Dict['cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]]] = defaultdict(
list
)

def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
if op in terminal_measurements:
return op
gate = op.gate
if isinstance(gate, ops.MeasurementGate):
key = value.MeasurementKey.parse_serialized(gate.key)
targets = [_MeasurementQid(key, q) for q in op.qubits]
measurement_qubits[key] = targets
targets = [_MeasurementQid(key, q, len(measurement_qubits[key])) for q in op.qubits]
measurement_qubits[key].append(tuple(targets))
cxs = [_mod_add(q, target) for q, target in zip(op.qubits, targets)]
confusions = [
_ConfusionChannel(m, [op.qubits[i].dimension for i in indexes]).on(
Expand All @@ -125,10 +119,24 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
return [defer(op, None) for op in protocols.decompose_once(op)]
elif op.classical_controls:
# Convert to a quantum control
keys = sorted(set(key for c in op.classical_controls for key in c.keys))
for key in keys:

# First create a sorted set of the indexed keys for this control.
keys = sorted(
set(
indexed_key
for condition in op.classical_controls
for indexed_key in (
[(condition.key, condition.index)]
if isinstance(condition, value.KeyCondition)
else [(k, -1) for k in condition.keys]
)
)
)
for key, index in keys:
if key not in measurement_qubits:
raise ValueError(f'Deferred measurement for key={key} not found.')
if index >= len(measurement_qubits[key]) or index < -len(measurement_qubits[key]):
raise ValueError(f'Invalid index for {key}')

# Try every possible datastore state (exponential in the number of keys) against the
# condition, and the ones that work are the control values for the new op.
Expand All @@ -140,12 +148,11 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':

# Rearrange these into the format expected by SumOfProducts
products = [
[i for key in keys for i in store.records[key][0]]
[val for k, i in keys for val in store.records[k][i]]
for store in compatible_datastores
]

control_values = ops.SumOfProducts(products)
qs = [q for key in keys for q in measurement_qubits[key]]
qs = [q for k, i in keys for q in measurement_qubits[k][i]]
return op.without_classical_controls().controlled_by(*qs, control_values=control_values)
return op

Expand All @@ -155,14 +162,15 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
tags_to_ignore=context.tags_to_ignore if context else (),
raise_if_add_qubits=False,
).unfreeze()
for k, qubits in measurement_qubits.items():
circuit.append(ops.measure(*qubits, key=k))
for k, qubits_list in measurement_qubits.items():
for qubits in qubits_list:
circuit.append(ops.measure(*qubits, key=k))
return circuit


def _all_possible_datastore_states(
keys: Iterable['cirq.MeasurementKey'],
measurement_qubits: Mapping['cirq.MeasurementKey', Iterable['cirq.Qid']],
keys: Iterable[Tuple['cirq.MeasurementKey', int]],
measurement_qubits: Dict['cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]]],
) -> Iterable['cirq.ClassicalDataStoreReader']:
"""The cartesian product of all possible DataStore states for the given keys."""
# First we get the list of all possible values. So if we have a key mapped to qubits of shape
Expand All @@ -179,17 +187,28 @@ def _all_possible_datastore_states(
# ((1, 1), (0,)),
# ((1, 1), (1,)),
# ((1, 1), (2,))]
all_values = itertools.product(
all_possible_measurements = itertools.product(
*[
tuple(itertools.product(*[range(q.dimension) for q in measurement_qubits[k]]))
for k in keys
tuple(itertools.product(*[range(q.dimension) for q in measurement_qubits[k][i]]))
for k, i in keys
]
)
# Then we create the ClassicalDataDictionaryStore for each of the above.
for sequences in all_values:
lookup = {k: [sequence] for k, sequence in zip(keys, sequences)}
# Then we create the ClassicalDataDictionaryStore for each of the above. A `measurement_list`
# is a single row of the above example, and can be zipped with `keys`.
for measurement_list in all_possible_measurements:
# Initialize a set of measurement records for this iteration. This will have the same shape
# as `measurement_qubits` but zeros for all measurements.
records = {
key: [(0,) * len(qubits) for qubits in qubits_list]
for key, qubits_list in measurement_qubits.items()
}
# Set the measurement values from the current row of the above, for each key/index we care
# about.
for (k, i), measurement in zip(keys, measurement_list):
records[k][i] = measurement
# Finally yield this sample to the consumer.
yield value.ClassicalDataDictionaryStore(
_records=lookup, _measured_qubits={k: [tuple(measurement_qubits[k])] for k in keys}
_records=records, _measured_qubits=measurement_qubits
)


Expand Down
68 changes: 51 additions & 17 deletions cirq-core/cirq/transformers/measurement_transformers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,40 @@ def test_multi_qubit_control():
)


@pytest.mark.parametrize('index', [-3, -2, -1, 0, 1, 2])
def test_repeated(index: int):
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.measure(q0, key='a'), # The control measurement when `index` is 0 or -2
cirq.X(q0),
cirq.measure(q0, key='a'), # The control measurement when `index` is 1 or -1
cirq.X(q1).with_classical_controls(cirq.KeyCondition(cirq.MeasurementKey('a'), index)),
cirq.measure(q1, key='b'),
)
if index in [-3, 2]:
with pytest.raises(ValueError, match='Invalid index'):
_ = cirq.defer_measurements(circuit)
return
assert_equivalent_to_deferred(circuit)
deferred = cirq.defer_measurements(circuit)
q_ma = _MeasurementQid('a', q0) # The ancilla qubit created for the first `a` measurement
q_ma1 = _MeasurementQid('a', q0, 1) # The ancilla qubit created for the second `a` measurement
# The ancilla used for control should match the measurement used for control above.
q_expected_control = q_ma if index in [0, -2] else q_ma1
cirq.testing.assert_same_circuits(
deferred,
cirq.Circuit(
cirq.CX(q0, q_ma),
cirq.X(q0),
cirq.CX(q0, q_ma1),
cirq.Moment(cirq.CX(q_expected_control, q1)),
cirq.measure(q_ma, key='a'),
cirq.measure(q_ma1, key='a'),
cirq.measure(q1, key='b'),
),
)


def test_diagram():
q0, q1, q2, q3 = cirq.LineQubit.range(4)
circuit = cirq.Circuit(
Expand All @@ -457,23 +491,23 @@ def test_diagram():
cirq.testing.assert_has_diagram(
deferred,
"""
┌────┐
0: ─────────────────@───────X────────M('c')───
│ │
1: ─────────────────┼─@──────────────M────────
│ │ │
2: ─────────────────┼@┼──────────────M────────
│││ │
3: ─────────────────┼┼┼@─────────────M────────
││││
M('a', q=q(0)): ────X┼┼┼────M('a')────────────
│││ │
M('a', q=q(2)): ─────X┼┼────M─────────────────
││
M('b', q=q(1)): ──────X┼────M('b')────────────
│ │
M('b', q=q(3)): ───────X────M─────────────────
└────┘
┌────┐
0: ────────────────────@───────X────────M('c')───
│ │
1: ────────────────────┼─@──────────────M────────
│ │ │
2: ────────────────────┼@┼──────────────M────────
│││ │
3: ────────────────────┼┼┼@─────────────M────────
││││
M('a[0]', q=q(0)): ────X┼┼┼────M('a')────────────
│││ │
M('a[0]', q=q(2)): ─────X┼┼────M─────────────────
││
M('b[0]', q=q(1)): ──────X┼────M('b')────────────
│ │
M('b[0]', q=q(3)): ───────X────M─────────────────
└────┘
Comment thread
daxfohl marked this conversation as resolved.
""",
use_unicode_characters=True,
)
Expand Down