|
16 | 16 | import dataclasses |
17 | 17 | from typing import Any, Dict, FrozenSet, Mapping, Optional, Tuple, TYPE_CHECKING |
18 | 18 |
|
| 19 | +import attrs |
19 | 20 | import sympy |
20 | 21 |
|
21 | 22 | from cirq._compat import proper_repr |
@@ -135,6 +136,112 @@ def _qasm_(self, args: 'cirq.QasmArgs', **kwargs) -> Optional[str]: |
135 | 136 | return f'{key}==1' |
136 | 137 |
|
137 | 138 |
|
| 139 | +@attrs.frozen |
| 140 | +class BitMaskKeyCondition(Condition): |
| 141 | + """A multiqubit classical control condition with a bitmask. |
| 142 | +
|
| 143 | + The control is based on a single measurement key and allows comparing equality or inequality |
| 144 | + after taking the bitwise and with a bitmask. |
| 145 | +
|
| 146 | + Examples: |
| 147 | + - BitMaskKeycondition('a') -> a != 0 |
| 148 | + - BitMaskKeyCondition('a', bitmask=13) -> (a & 13) != 0 |
| 149 | + - BitMaskKeyCondition('a', bitmask=13, target_value=9) -> (a & 13) != 9 |
| 150 | + - BitMaskKeyCondition('a', bitmask=13, target_value=9, equal_target=True) -> (a & 13) == 9 |
| 151 | + - BitMaskKeyCondition.create_equal_mask('a', 13) -> (a & 13) == 13 |
| 152 | + - BitMaskKeyCondition.create_not_equal_mask('a', 13) -> (a & 13) != 13 |
| 153 | +
|
| 154 | + The bits in the bitmask have the same order as the qubits passed to `cirq.measure(...)`. That's |
| 155 | + the most significant bit corresponds to the the first (left most) qubit. |
| 156 | +
|
| 157 | + Attributes: |
| 158 | + - key: Measurement key. |
| 159 | + - index: integer index (same as KeyCondition.index). |
| 160 | + - target_value: The value we compare with. |
| 161 | + - equal_target: Whether to comapre with == or !=. |
| 162 | + - bitmask: Optional bitmask to apply before doing the comparison. |
| 163 | + """ |
| 164 | + |
| 165 | + key: 'cirq.MeasurementKey' = attrs.field( |
| 166 | + converter=lambda x: ( |
| 167 | + x |
| 168 | + if isinstance(x, measurement_key.MeasurementKey) |
| 169 | + else measurement_key.MeasurementKey(x) |
| 170 | + ) |
| 171 | + ) |
| 172 | + index: int = -1 |
| 173 | + target_value: int = 0 |
| 174 | + equal_target: bool = False |
| 175 | + bitmask: Optional[int] = None |
| 176 | + |
| 177 | + @property |
| 178 | + def keys(self): |
| 179 | + return (self.key,) |
| 180 | + |
| 181 | + @staticmethod |
| 182 | + def create_equal_mask( |
| 183 | + key: 'cirq.MeasurementKey', bitmask: int, *, index: int = -1 |
| 184 | + ) -> 'BitMaskKeyCondition': |
| 185 | + """Creates a condition that evaluates (meas & bitmask) == bitmask.""" |
| 186 | + return BitMaskKeyCondition( |
| 187 | + key, index, target_value=bitmask, equal_target=True, bitmask=bitmask |
| 188 | + ) |
| 189 | + |
| 190 | + @staticmethod |
| 191 | + def create_not_equal_mask( |
| 192 | + key: 'cirq.MeasurementKey', bitmask: int, *, index: int = -1 |
| 193 | + ) -> 'BitMaskKeyCondition': |
| 194 | + """Creates a condition that evaluates (meas & bitmask) != bitmask.""" |
| 195 | + return BitMaskKeyCondition( |
| 196 | + key, index, target_value=bitmask, equal_target=False, bitmask=bitmask |
| 197 | + ) |
| 198 | + |
| 199 | + def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'): |
| 200 | + return BitMaskKeyCondition(replacement) if self.key == current else self |
| 201 | + |
| 202 | + def __str__(self): |
| 203 | + s = str(self.key) if self.index == -1 else f'{self.key}[{self.index}]' |
| 204 | + if self.bitmask is not None: |
| 205 | + s = f'{s} & {self.bitmask}' |
| 206 | + if self.equal_target: |
| 207 | + if self.bitmask is not None: |
| 208 | + s = f'({s})' |
| 209 | + s = f'{s} == {self.target_value}' |
| 210 | + elif self.target_value != 0: |
| 211 | + if self.bitmask is not None: |
| 212 | + s = f'({s})' |
| 213 | + s = f'{s} != {self.target_value}' |
| 214 | + return s |
| 215 | + |
| 216 | + def __repr__(self): |
| 217 | + values = attrs.asdict(self) |
| 218 | + parameters = ', '.join(f'{f.name}={repr(values[f.name])}' for f in attrs.fields(type(self))) |
| 219 | + return f'cirq.BitMaskKeyCondition({parameters})' |
| 220 | + |
| 221 | + def resolve(self, classical_data: 'cirq.ClassicalDataStoreReader') -> bool: |
| 222 | + if self.key not in classical_data.keys(): |
| 223 | + raise ValueError(f'Measurement key {self.key} missing when testing classical control') |
| 224 | + value = classical_data.get_int(self.key, self.index) |
| 225 | + if self.bitmask is not None: |
| 226 | + value &= self.bitmask |
| 227 | + if self.equal_target: |
| 228 | + return value == self.target_value |
| 229 | + return value != self.target_value |
| 230 | + |
| 231 | + def _json_dict_(self): |
| 232 | + return json_serialization.attrs_json_dict(self) |
| 233 | + |
| 234 | + @classmethod |
| 235 | + def _from_json_dict_(cls, key, **kwargs): |
| 236 | + parameter_names = [f.name for f in attrs.fields(cls)[1:]] |
| 237 | + parameters = {k: kwargs[k] for k in parameter_names if k in kwargs} |
| 238 | + return cls(key=key, **parameters) |
| 239 | + |
| 240 | + @property |
| 241 | + def qasm(self): |
| 242 | + raise NotImplementedError() # pragma: no cover |
| 243 | + |
| 244 | + |
138 | 245 | @dataclasses.dataclass(frozen=True) |
139 | 246 | class SympyCondition(Condition): |
140 | 247 | """A classical control condition based on a sympy expression. |
|
0 commit comments