Skip to content

Commit ec5b8f8

Browse files
NoureldinYosriBichengYing
authored andcommitted
Create a new condition that allows easy control by bitmasks and Add a new classical Update the notebook for 'Classical control' to reflect new features" (quantumlib#7166)
* Create a new condition that allows easy control by bitmasks * add tests * docstring * add missing files * swith from dataclasses to attrs * mypy * Update the notebook for 'Classical control' to reflect new features * nit * nit * tests * nit * fix doc regex * format * address comments * lint
1 parent 89871cd commit ec5b8f8

File tree

11 files changed

+492
-3
lines changed

11 files changed

+492
-3
lines changed

cirq-core/cirq/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,7 @@
533533
canonicalize_half_turns as canonicalize_half_turns,
534534
chosen_angle_to_canonical_half_turns as chosen_angle_to_canonical_half_turns,
535535
chosen_angle_to_half_turns as chosen_angle_to_half_turns,
536+
BitMaskKeyCondition as BitMaskKeyCondition,
536537
ClassicalDataDictionaryStore as ClassicalDataDictionaryStore,
537538
ClassicalDataStore as ClassicalDataStore,
538539
ClassicalDataStoreReader as ClassicalDataStoreReader,

cirq-core/cirq/json_resolver_cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def _symmetricalqidpair(qids):
107107
'AnyUnitaryGateFamily': cirq.AnyUnitaryGateFamily,
108108
'AsymmetricDepolarizingChannel': cirq.AsymmetricDepolarizingChannel,
109109
'BitFlipChannel': cirq.BitFlipChannel,
110+
'BitMaskKeyCondition': cirq.BitMaskKeyCondition,
110111
'BitstringAccumulator': cirq.work.BitstringAccumulator,
111112
'BooleanHamiltonianGate': cirq.BooleanHamiltonianGate,
112113
'CCNotPowGate': cirq.CCNotPowGate,

cirq-core/cirq/protocols/json_serialization.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
Union,
3535
)
3636

37+
import attrs
3738
import numpy as np
3839
import pandas as pd
3940
import sympy
@@ -182,6 +183,12 @@ def dataclass_json_dict(obj: Any) -> Dict[str, Any]:
182183
return obj_to_dict_helper(obj, attribute_names)
183184

184185

186+
def attrs_json_dict(obj: Any) -> Dict[str, Any]:
187+
"""Return a dictionary suitable for `_json_dict_` from an attrs dataclass."""
188+
attribute_names = [f.name for f in attrs.fields(type(obj))]
189+
return obj_to_dict_helper(obj, attribute_names)
190+
191+
185192
def _json_dict_with_cirq_type(obj: Any):
186193
base_dict = obj._json_dict_()
187194
if 'cirq_type' in base_dict:

cirq-core/cirq/protocols/json_serialization_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from typing import Dict, List, Optional, Tuple, Type
2525
from unittest import mock
2626

27+
import attrs
2728
import networkx as nx
2829
import numpy as np
2930
import pandas as pd
@@ -790,3 +791,15 @@ def test_datetime():
790791
assert re_pst_dt == pst_dt
791792
assert re_pst_dt == utc_dt
792793
assert re_pst_dt == re_naive_dt
794+
795+
796+
@attrs.frozen
797+
class _TestAttrsClas:
798+
name: str
799+
x: int
800+
801+
802+
def test_attrs_json_dict():
803+
obj = _TestAttrsClas('test', x=123)
804+
js = json_serialization.attrs_json_dict(obj)
805+
assert js == {'name': 'test', 'x': 123}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
[
2+
{
3+
"cirq_type": "BitMaskKeyCondition",
4+
"key": {
5+
"cirq_type": "MeasurementKey",
6+
"name": "a",
7+
"path": []
8+
},
9+
"index": 59,
10+
"target_value": 0,
11+
"equal_target": false,
12+
"bitmask": null
13+
},
14+
{
15+
"cirq_type": "BitMaskKeyCondition",
16+
"key": {
17+
"cirq_type": "MeasurementKey",
18+
"name": "b",
19+
"path": []
20+
},
21+
"index": 58,
22+
"target_value": 3,
23+
"equal_target": false,
24+
"bitmask": null
25+
},
26+
{
27+
"cirq_type": "BitMaskKeyCondition",
28+
"key": {
29+
"cirq_type": "MeasurementKey",
30+
"name": "c",
31+
"path": []
32+
},
33+
"index": 57,
34+
"target_value": 0,
35+
"equal_target": false,
36+
"bitmask": 13
37+
},
38+
{
39+
"cirq_type": "BitMaskKeyCondition",
40+
"key": {
41+
"cirq_type": "MeasurementKey",
42+
"name": "d",
43+
"path": []
44+
},
45+
"index": 56,
46+
"target_value": 12,
47+
"equal_target": false,
48+
"bitmask": 13
49+
},
50+
{
51+
"cirq_type": "BitMaskKeyCondition",
52+
"key": {
53+
"cirq_type": "MeasurementKey",
54+
"name": "d",
55+
"path": []
56+
},
57+
"index": 55,
58+
"target_value": 12,
59+
"equal_target": true,
60+
"bitmask": 13
61+
},
62+
{
63+
"cirq_type": "BitMaskKeyCondition",
64+
"key": {
65+
"cirq_type": "MeasurementKey",
66+
"name": "e",
67+
"path": []
68+
},
69+
"index": 54,
70+
"target_value": 11,
71+
"equal_target": true,
72+
"bitmask": 11
73+
},
74+
{
75+
"cirq_type": "BitMaskKeyCondition",
76+
"key": {
77+
"cirq_type": "MeasurementKey",
78+
"name": "e",
79+
"path": []
80+
},
81+
"index": 53,
82+
"target_value": 9,
83+
"equal_target": false,
84+
"bitmask": 9
85+
}
86+
]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='a'), index=59, target_value=0, equal_target=False, bitmask=None),
2+
cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='b'), index=58, target_value=3, equal_target=False, bitmask=None),
3+
cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='c'), index=57, target_value=0, equal_target=False, bitmask=13),
4+
cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='d'), index=56, target_value=12, equal_target=False, bitmask=13),
5+
cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='d'), index=55, target_value=12, equal_target=True, bitmask=13),
6+
cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='e'), index=54, target_value=11, equal_target=True, bitmask=11),
7+
cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='e'), index=53, target_value=9, equal_target=False, bitmask=9)]

cirq-core/cirq/value/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
Condition as Condition,
3636
KeyCondition as KeyCondition,
3737
SympyCondition as SympyCondition,
38+
BitMaskKeyCondition as BitMaskKeyCondition,
3839
)
3940

4041
from cirq.value.digits import (

cirq-core/cirq/value/condition.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import dataclasses
1717
from typing import Any, Dict, FrozenSet, Mapping, Optional, Tuple, TYPE_CHECKING
1818

19+
import attrs
1920
import sympy
2021

2122
from cirq._compat import proper_repr
@@ -135,6 +136,112 @@ def _qasm_(self, args: 'cirq.QasmArgs', **kwargs) -> Optional[str]:
135136
return f'{key}==1'
136137

137138

139+
@attrs.frozen
140+
class BitMaskKeyCondition(Condition):
141+
"""A multiqubit classical control condition with a bitmask.
142+
143+
The control is based on a single measurement key and allows comparing equality or inequality
144+
after taking the bitwise and with a bitmask.
145+
146+
Examples:
147+
- BitMaskKeycondition('a') -> a != 0
148+
- BitMaskKeyCondition('a', bitmask=13) -> (a & 13) != 0
149+
- BitMaskKeyCondition('a', bitmask=13, target_value=9) -> (a & 13) != 9
150+
- BitMaskKeyCondition('a', bitmask=13, target_value=9, equal_target=True) -> (a & 13) == 9
151+
- BitMaskKeyCondition.create_equal_mask('a', 13) -> (a & 13) == 13
152+
- BitMaskKeyCondition.create_not_equal_mask('a', 13) -> (a & 13) != 13
153+
154+
The bits in the bitmask have the same order as the qubits passed to `cirq.measure(...)`. That's
155+
the most significant bit corresponds to the the first (left most) qubit.
156+
157+
Attributes:
158+
- key: Measurement key.
159+
- index: integer index (same as KeyCondition.index).
160+
- target_value: The value we compare with.
161+
- equal_target: Whether to comapre with == or !=.
162+
- bitmask: Optional bitmask to apply before doing the comparison.
163+
"""
164+
165+
key: 'cirq.MeasurementKey' = attrs.field(
166+
converter=lambda x: (
167+
x
168+
if isinstance(x, measurement_key.MeasurementKey)
169+
else measurement_key.MeasurementKey(x)
170+
)
171+
)
172+
index: int = -1
173+
target_value: int = 0
174+
equal_target: bool = False
175+
bitmask: Optional[int] = None
176+
177+
@property
178+
def keys(self):
179+
return (self.key,)
180+
181+
@staticmethod
182+
def create_equal_mask(
183+
key: 'cirq.MeasurementKey', bitmask: int, *, index: int = -1
184+
) -> 'BitMaskKeyCondition':
185+
"""Creates a condition that evaluates (meas & bitmask) == bitmask."""
186+
return BitMaskKeyCondition(
187+
key, index, target_value=bitmask, equal_target=True, bitmask=bitmask
188+
)
189+
190+
@staticmethod
191+
def create_not_equal_mask(
192+
key: 'cirq.MeasurementKey', bitmask: int, *, index: int = -1
193+
) -> 'BitMaskKeyCondition':
194+
"""Creates a condition that evaluates (meas & bitmask) != bitmask."""
195+
return BitMaskKeyCondition(
196+
key, index, target_value=bitmask, equal_target=False, bitmask=bitmask
197+
)
198+
199+
def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'):
200+
return BitMaskKeyCondition(replacement) if self.key == current else self
201+
202+
def __str__(self):
203+
s = str(self.key) if self.index == -1 else f'{self.key}[{self.index}]'
204+
if self.bitmask is not None:
205+
s = f'{s} & {self.bitmask}'
206+
if self.equal_target:
207+
if self.bitmask is not None:
208+
s = f'({s})'
209+
s = f'{s} == {self.target_value}'
210+
elif self.target_value != 0:
211+
if self.bitmask is not None:
212+
s = f'({s})'
213+
s = f'{s} != {self.target_value}'
214+
return s
215+
216+
def __repr__(self):
217+
values = attrs.asdict(self)
218+
parameters = ', '.join(f'{f.name}={repr(values[f.name])}' for f in attrs.fields(type(self)))
219+
return f'cirq.BitMaskKeyCondition({parameters})'
220+
221+
def resolve(self, classical_data: 'cirq.ClassicalDataStoreReader') -> bool:
222+
if self.key not in classical_data.keys():
223+
raise ValueError(f'Measurement key {self.key} missing when testing classical control')
224+
value = classical_data.get_int(self.key, self.index)
225+
if self.bitmask is not None:
226+
value &= self.bitmask
227+
if self.equal_target:
228+
return value == self.target_value
229+
return value != self.target_value
230+
231+
def _json_dict_(self):
232+
return json_serialization.attrs_json_dict(self)
233+
234+
@classmethod
235+
def _from_json_dict_(cls, key, **kwargs):
236+
parameter_names = [f.name for f in attrs.fields(cls)[1:]]
237+
parameters = {k: kwargs[k] for k in parameter_names if k in kwargs}
238+
return cls(key=key, **parameters)
239+
240+
@property
241+
def qasm(self):
242+
raise NotImplementedError() # pragma: no cover
243+
244+
138245
@dataclasses.dataclass(frozen=True)
139246
class SympyCondition(Condition):
140247
"""A classical control condition based on a sympy expression.

0 commit comments

Comments
 (0)