Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,7 @@
canonicalize_half_turns as canonicalize_half_turns,
chosen_angle_to_canonical_half_turns as chosen_angle_to_canonical_half_turns,
chosen_angle_to_half_turns as chosen_angle_to_half_turns,
BitMaskKeyCondition as BitMaskKeyCondition,
ClassicalDataDictionaryStore as ClassicalDataDictionaryStore,
ClassicalDataStore as ClassicalDataStore,
ClassicalDataStoreReader as ClassicalDataStoreReader,
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def _symmetricalqidpair(qids):
'AnyUnitaryGateFamily': cirq.AnyUnitaryGateFamily,
'AsymmetricDepolarizingChannel': cirq.AsymmetricDepolarizingChannel,
'BitFlipChannel': cirq.BitFlipChannel,
'BitMaskKeyCondition': cirq.BitMaskKeyCondition,
'BitstringAccumulator': cirq.work.BitstringAccumulator,
'BooleanHamiltonianGate': cirq.BooleanHamiltonianGate,
'CCNotPowGate': cirq.CCNotPowGate,
Expand Down
7 changes: 7 additions & 0 deletions cirq-core/cirq/protocols/json_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
Union,
)

import attrs
import numpy as np
import pandas as pd
import sympy
Expand Down Expand Up @@ -182,6 +183,12 @@ def dataclass_json_dict(obj: Any) -> Dict[str, Any]:
return obj_to_dict_helper(obj, attribute_names)


def attrs_json_dict(obj: Any) -> Dict[str, Any]:
"""Return a dictionary suitable for `_json_dict_` from an attrs dataclass."""
attribute_names = [f.name for f in attrs.fields(type(obj))]
return obj_to_dict_helper(obj, attribute_names)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a unit test for this function?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done



def _json_dict_with_cirq_type(obj: Any):
base_dict = obj._json_dict_()
if 'cirq_type' in base_dict:
Expand Down
13 changes: 13 additions & 0 deletions cirq-core/cirq/protocols/json_serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import Dict, List, Optional, Tuple, Type
from unittest import mock

import attrs
import networkx as nx
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -790,3 +791,15 @@ def test_datetime():
assert re_pst_dt == pst_dt
assert re_pst_dt == utc_dt
assert re_pst_dt == re_naive_dt


@attrs.frozen
class _TestAttrsClas:
name: str
x: int


def test_attrs_json_dict():
obj = _TestAttrsClas('test', x=123)
js = json_serialization.attrs_json_dict(obj)
assert js == {'name': 'test', 'x': 123}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
[
{
"cirq_type": "BitMaskKeyCondition",
"key": {
"cirq_type": "MeasurementKey",
"name": "a",
"path": []
},
"index": 59,
"target_value": 0,
"equal_target": false,
"bitmask": null
},
{
"cirq_type": "BitMaskKeyCondition",
"key": {
"cirq_type": "MeasurementKey",
"name": "b",
"path": []
},
"index": 58,
"target_value": 3,
"equal_target": false,
"bitmask": null
},
{
"cirq_type": "BitMaskKeyCondition",
"key": {
"cirq_type": "MeasurementKey",
"name": "c",
"path": []
},
"index": 57,
"target_value": 0,
"equal_target": false,
"bitmask": 13
},
{
"cirq_type": "BitMaskKeyCondition",
"key": {
"cirq_type": "MeasurementKey",
"name": "d",
"path": []
},
"index": 56,
"target_value": 12,
"equal_target": false,
"bitmask": 13
},
{
"cirq_type": "BitMaskKeyCondition",
"key": {
"cirq_type": "MeasurementKey",
"name": "d",
"path": []
},
"index": 55,
"target_value": 12,
"equal_target": true,
"bitmask": 13
},
{
"cirq_type": "BitMaskKeyCondition",
"key": {
"cirq_type": "MeasurementKey",
"name": "e",
"path": []
},
"index": 54,
"target_value": 11,
"equal_target": true,
"bitmask": 11
},
{
"cirq_type": "BitMaskKeyCondition",
"key": {
"cirq_type": "MeasurementKey",
"name": "e",
"path": []
},
"index": 53,
"target_value": 9,
"equal_target": false,
"bitmask": 9
}
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='a'), index=59, target_value=0, equal_target=False, bitmask=None),
cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='b'), index=58, target_value=3, equal_target=False, bitmask=None),
cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='c'), index=57, target_value=0, equal_target=False, bitmask=13),
cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='d'), index=56, target_value=12, equal_target=False, bitmask=13),
cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='d'), index=55, target_value=12, equal_target=True, bitmask=13),
cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='e'), index=54, target_value=11, equal_target=True, bitmask=11),
cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='e'), index=53, target_value=9, equal_target=False, bitmask=9)]
1 change: 1 addition & 0 deletions cirq-core/cirq/value/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
Condition as Condition,
KeyCondition as KeyCondition,
SympyCondition as SympyCondition,
BitMaskKeyCondition as BitMaskKeyCondition,
)

from cirq.value.digits import (
Expand Down
107 changes: 107 additions & 0 deletions cirq-core/cirq/value/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import dataclasses
from typing import Any, Dict, FrozenSet, Mapping, Optional, Tuple, TYPE_CHECKING

import attrs
import sympy

from cirq._compat import proper_repr
Expand Down Expand Up @@ -135,6 +136,112 @@ def _qasm_(self, args: 'cirq.QasmArgs', **kwargs) -> Optional[str]:
return f'{key}==1'


@attrs.frozen
class BitMaskKeyCondition(Condition):
"""A multiqubit classical control condition with a bitmask.

The control is based on a single measurement key and allows comparing equality or inequality
after taking the bitwise and with a bitmask.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have a short blurb about the bit order of the measurement (or a reference to the measurement code that explains)?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added


Examples:
- BitMaskKeycondition('a') -> a != 0
- BitMaskKeyCondition('a', bitmask=13) -> (a & 13) != 0
- BitMaskKeyCondition('a', bitmask=13, target_value=9) -> (a & 13) != 9
- BitMaskKeyCondition('a', bitmask=13, target_value=9, equal_target=True) -> (a & 13) == 9
- BitMaskKeyCondition.create_equal_mask('a', 13) -> (a & 13) == 13
- BitMaskKeyCondition.create_not_equal_mask('a', 13) -> (a & 13) != 13

The bits in the bitmask have the same order as the qubits passed to `cirq.measure(...)`. That's
the most significant bit corresponds to the the first (left most) qubit.

Attributes:
- key: Measurement key.
- index: integer index (same as KeyCondition.index).
- target_value: The value we compare with.
- equal_target: Whether to comapre with == or !=.
- bitmask: Optional bitmask to apply before doing the comparison.
"""

key: 'cirq.MeasurementKey' = attrs.field(
converter=lambda x: (
x
if isinstance(x, measurement_key.MeasurementKey)
else measurement_key.MeasurementKey(x)
)
)
index: int = -1
target_value: int = 0
equal_target: bool = False
bitmask: Optional[int] = None

@property
def keys(self):
return (self.key,)

@staticmethod
def create_equal_mask(
key: 'cirq.MeasurementKey', bitmask: int, *, index: int = -1
) -> 'BitMaskKeyCondition':
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add docstring.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""Creates a condition that evaluates (meas & bitmask) == bitmask."""
return BitMaskKeyCondition(
key, index, target_value=bitmask, equal_target=True, bitmask=bitmask
)

@staticmethod
def create_not_equal_mask(
key: 'cirq.MeasurementKey', bitmask: int, *, index: int = -1
) -> 'BitMaskKeyCondition':
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add docstring.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""Creates a condition that evaluates (meas & bitmask) != bitmask."""
return BitMaskKeyCondition(
key, index, target_value=bitmask, equal_target=False, bitmask=bitmask
)

def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'):
return BitMaskKeyCondition(replacement) if self.key == current else self

def __str__(self):
s = str(self.key) if self.index == -1 else f'{self.key}[{self.index}]'
if self.bitmask is not None:
s = f'{s} & {self.bitmask}'
if self.equal_target:
if self.bitmask is not None:
s = f'({s})'
s = f'{s} == {self.target_value}'
elif self.target_value != 0:
if self.bitmask is not None:
s = f'({s})'
s = f'{s} != {self.target_value}'
return s

def __repr__(self):
values = attrs.asdict(self)
parameters = ', '.join(f'{f.name}={repr(values[f.name])}' for f in attrs.fields(type(self)))
return f'cirq.BitMaskKeyCondition({parameters})'

def resolve(self, classical_data: 'cirq.ClassicalDataStoreReader') -> bool:
if self.key not in classical_data.keys():
raise ValueError(f'Measurement key {self.key} missing when testing classical control')
value = classical_data.get_int(self.key, self.index)
if self.bitmask is not None:
value &= self.bitmask
if self.equal_target:
return value == self.target_value
return value != self.target_value

def _json_dict_(self):
return json_serialization.attrs_json_dict(self)

@classmethod
def _from_json_dict_(cls, key, **kwargs):
parameter_names = [f.name for f in attrs.fields(cls)[1:]]
parameters = {k: kwargs[k] for k in parameter_names if k in kwargs}
return cls(key=key, **parameters)

@property
def qasm(self):
raise NotImplementedError() # pragma: no cover


@dataclasses.dataclass(frozen=True)
class SympyCondition(Condition):
"""A classical control condition based on a sympy expression.
Expand Down
Loading