Skip to content

Commit 5adcb64

Browse files
authored
Adding tags to Moments (#7467)
- This PR adds the ability to add tags to Moment objects. - tags in Moments function similar to tags in operations, in that they can be any hashable object and are lost if the circuit undergoes transformation, such as by a transformer.
1 parent b3fe874 commit 5adcb64

File tree

4 files changed

+155
-8
lines changed

4 files changed

+155
-8
lines changed

cirq-core/cirq/circuits/moment.py

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Any,
2525
Callable,
2626
cast,
27+
Hashable,
2728
Iterable,
2829
Iterator,
2930
Mapping,
@@ -77,7 +78,12 @@ class Moment:
7778
are no such operations, returns an empty Moment.
7879
"""
7980

80-
def __init__(self, *contents: cirq.OP_TREE, _flatten_contents: bool = True) -> None:
81+
def __init__(
82+
self,
83+
*contents: cirq.OP_TREE,
84+
_flatten_contents: bool = True,
85+
tags: tuple[Hashable, ...] = (),
86+
) -> None:
8187
"""Constructs a moment with the given operations.
8288
8389
Args:
@@ -88,6 +94,12 @@ def __init__(self, *contents: cirq.OP_TREE, _flatten_contents: bool = True) -> N
8894
we skip flattening and assume that contents already consists
8995
of individual operations. This is used internally by helper
9096
methods to avoid unnecessary validation.
97+
tags: Optional tags to denote specific Moment objects with meta-data.
98+
These are a tuple of any Hashable object. Typically, a class
99+
will be passed. Tags apply only to this specific set of operations
100+
and will be lost on any transformation of the
101+
Moment. For instance, if operations are added to the Moment, tags
102+
will be dropped unless explicitly added back in by the user.
91103
92104
Raises:
93105
ValueError: A qubit appears more than once.
@@ -110,9 +122,10 @@ def __init__(self, *contents: cirq.OP_TREE, _flatten_contents: bool = True) -> N
110122

111123
self._measurement_key_objs: frozenset[cirq.MeasurementKey] | None = None
112124
self._control_keys: frozenset[cirq.MeasurementKey] | None = None
125+
self._tags = tags
113126

114127
@classmethod
115-
def from_ops(cls, *ops: cirq.Operation) -> cirq.Moment:
128+
def from_ops(cls, *ops: cirq.Operation, tags: tuple[Hashable, ...] = ()) -> cirq.Moment:
116129
"""Construct a Moment from the given operations.
117130
118131
This avoids calling `flatten_to_ops` in the moment constructor, which
@@ -122,8 +135,11 @@ def from_ops(cls, *ops: cirq.Operation) -> cirq.Moment:
122135
123136
Args:
124137
*ops: Operations to include in the Moment.
138+
tags: Optional tags to denote specific Moment objects with meta-data.
139+
These are a tuple of any Hashable object. Tags will be dropped if
140+
the operations in the Moment are modified or transformed.
125141
"""
126-
return cls(*ops, _flatten_contents=False)
142+
return cls(*ops, _flatten_contents=False, tags=tags)
127143

128144
@property
129145
def operations(self) -> tuple[cirq.Operation, ...]:
@@ -133,6 +149,34 @@ def operations(self) -> tuple[cirq.Operation, ...]:
133149
def qubits(self) -> frozenset[cirq.Qid]:
134150
return frozenset(self._qubit_to_op)
135151

152+
@property
153+
def tags(self) -> tuple[Hashable, ...]:
154+
"""Returns a tuple of the operation's tags."""
155+
return self._tags
156+
157+
def with_tags(self, *new_tags: Hashable) -> cirq.Moment:
158+
"""Creates a new Moment with the current ops and the specified tags.
159+
160+
If the moment already has tags, this will add the new_tags to the
161+
preexisting tags.
162+
163+
This method can be used to attach meta-data to moments
164+
without affecting their functionality. The intended usage is to
165+
attach classes intended for this purpose or strings to mark operations
166+
for specific usage that will be recognized by consumers.
167+
168+
Tags can be a list of any type of object that is useful to identify
169+
this operation as long as the type is hashable. If you wish the
170+
resulting operation to be eventually serialized into JSON, you should
171+
also restrict the operation to be JSON serializable.
172+
173+
Please note that tags should be instantiated if classes are
174+
used. Raw types are not allowed.
175+
"""
176+
if not new_tags:
177+
return self
178+
return Moment(*self._operations, _flatten_contents=False, tags=(*self._tags, *new_tags))
179+
136180
def operates_on_single_qubit(self, qubit: cirq.Qid) -> bool:
137181
"""Determines if the moment has operations touching the given qubit.
138182
Args:
@@ -170,6 +214,8 @@ def operation_at(self, qubit: raw_types.Qid) -> cirq.Operation | None:
170214
def with_operation(self, operation: cirq.Operation) -> cirq.Moment:
171215
"""Returns an equal moment, but with the given op added.
172216
217+
Any tags on the Moment will be dropped.
218+
173219
Args:
174220
operation: The operation to append.
175221
@@ -198,6 +244,9 @@ def with_operation(self, operation: cirq.Operation) -> cirq.Moment:
198244
def with_operations(self, *contents: cirq.OP_TREE) -> cirq.Moment:
199245
"""Returns a new moment with the given contents added.
200246
247+
Any tags on the original Moment object are dropped if the Moment
248+
is changed.
249+
201250
Args:
202251
*contents: New operations to add to this moment.
203252
@@ -235,6 +284,9 @@ def with_operations(self, *contents: cirq.OP_TREE) -> cirq.Moment:
235284
def without_operations_touching(self, qubits: Iterable[cirq.Qid]) -> cirq.Moment:
236285
"""Returns an equal moment, but without ops on the given qubits.
237286
287+
Any tags on the original Moment object are dropped if the Moment
288+
is changed.
289+
238290
Args:
239291
qubits: Operations that touch these will be removed.
240292
@@ -510,11 +562,13 @@ def _superoperator_(self) -> np.ndarray:
510562
return qis.kraus_to_superoperator(self._kraus_())
511563

512564
def _json_dict_(self) -> dict[str, Any]:
513-
return protocols.obj_to_dict_helper(self, ['operations'])
565+
# For backwards compatibility, only output tags if they exist.
566+
args = ['operations', 'tags'] if self._tags else ['operations']
567+
return protocols.obj_to_dict_helper(self, args)
514568

515569
@classmethod
516-
def _from_json_dict_(cls, operations, **kwargs):
517-
return cls.from_ops(*operations)
570+
def _from_json_dict_(cls, operations, tags=(), **kwargs):
571+
return cls(*operations, tags=tags)
518572

519573
def __add__(self, other: cirq.OP_TREE) -> cirq.Moment:
520574
if isinstance(other, circuit.AbstractCircuit):

cirq-core/cirq/circuits/moment_test.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -939,3 +939,68 @@ def test_superoperator() -> None:
939939
assert m._has_superoperator_()
940940
s = m._superoperator_()
941941
assert np.allclose(s, np.array([[1, 0, 0, 1], [0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 1]]) / 2)
942+
943+
944+
def test_moment_with_tags() -> None:
945+
q0 = cirq.LineQubit(0)
946+
q1 = cirq.LineQubit(1)
947+
op1 = cirq.X(q0)
948+
op2 = cirq.Y(q1)
949+
950+
# Test initialization with no tags
951+
moment_no_tags = cirq.Moment(op1)
952+
assert moment_no_tags.tags == ()
953+
954+
# Test initialization with tags
955+
moment_with_tags = cirq.Moment(op1, op2, tags=("initial_tag_1", "initial_tag_2"))
956+
assert moment_with_tags.tags == ("initial_tag_1", "initial_tag_2")
957+
958+
# Test with_tags method to add new tags
959+
new_moment = moment_with_tags.with_tags("new_tag_1", "new_tag_2")
960+
961+
# Ensure the original moment's tags are unchanged
962+
assert moment_with_tags.tags == ("initial_tag_1", "initial_tag_2")
963+
964+
# Ensure the new moment has both old and new tags
965+
assert new_moment.tags == ("initial_tag_1", "initial_tag_2", "new_tag_1", "new_tag_2")
966+
967+
# Test with_tags on a moment that initially had no tags
968+
new_moment_from_no_tags = moment_no_tags.with_tags("single_new_tag")
969+
assert new_moment_from_no_tags.tags == ("single_new_tag",)
970+
971+
# Test adding no new tags
972+
same_moment_tags = moment_with_tags.with_tags()
973+
assert same_moment_tags.tags == ("initial_tag_1", "initial_tag_2")
974+
975+
class CustomTag:
976+
"""Example Hashable Tag"""
977+
978+
def __init__(self, value):
979+
self.value = value
980+
981+
def __hash__(self):
982+
return hash(self.value) # pragma: nocover
983+
984+
def __eq__(self, other):
985+
return isinstance(other, CustomTag) and self.value == other.value # pragma: nocover
986+
987+
def __repr__(self):
988+
return f"CustomTag({self.value})" # pragma: nocover
989+
990+
tag_obj = CustomTag("complex_tag")
991+
moment_with_custom_tag = cirq.Moment(op1, tags=("string_tag", 123, tag_obj))
992+
assert moment_with_custom_tag.tags == ("string_tag", 123, tag_obj)
993+
994+
new_moment_with_custom_tag = moment_with_custom_tag.with_tags(456)
995+
assert new_moment_with_custom_tag.tags == ("string_tag", 123, tag_obj, 456)
996+
997+
# Test that tags are dropped if the Moment is changed.
998+
moment = cirq.Moment.from_ops(op1, tags=(tag_obj,))
999+
assert moment.tags == (tag_obj,)
1000+
assert moment.with_operation(op2).tags == ()
1001+
assert moment.with_operations(op2).tags == ()
1002+
assert moment.without_operations_touching([q0]).tags == ()
1003+
1004+
# Test that tags are retained if the Moment is unchanged.
1005+
assert moment.with_operations().tags == (tag_obj,)
1006+
assert moment.without_operations_touching([q1]).tags == (tag_obj,)

cirq-core/cirq/protocols/json_test_data/Moment.json

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,28 @@
3939
}
4040
}
4141
]
42+
},
43+
{
44+
"cirq_type": "Moment",
45+
"operations": [
46+
{
47+
"cirq_type": "SingleQubitPauliStringGateOperation",
48+
"pauli": {
49+
"cirq_type": "_PauliX",
50+
"exponent": 1.0,
51+
"global_shift": 0.0
52+
},
53+
"qubit": {
54+
"cirq_type": "LineQubit",
55+
"x": 0
56+
}
57+
}
58+
],
59+
"tags": [
60+
{
61+
"cirq_type": "Duration",
62+
"picos": 25000
63+
}
64+
]
4265
}
43-
]
66+
]

cirq-core/cirq/protocols/json_test_data/Moment.repr

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,9 @@
22
cirq.X(cirq.LineQubit(0)),
33
cirq.Y(cirq.LineQubit(1)),
44
cirq.Z(cirq.LineQubit(2)),
5-
)]
5+
),
6+
cirq.Moment(
7+
cirq.X(cirq.LineQubit(0)),
8+
tags=(cirq.Duration(nanos=25),)
9+
)
10+
]

0 commit comments

Comments
 (0)