1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414"""An immutable version of the Circuit data structure."""
15- from typing import AbstractSet , FrozenSet , Iterable , Iterator , Sequence , Tuple , TYPE_CHECKING , Union
15+ from typing import (
16+ AbstractSet ,
17+ FrozenSet ,
18+ Hashable ,
19+ Iterable ,
20+ Iterator ,
21+ Sequence ,
22+ Tuple ,
23+ TYPE_CHECKING ,
24+ Union ,
25+ )
1626
1727import numpy as np
1828
@@ -34,7 +44,10 @@ class FrozenCircuit(AbstractCircuit, protocols.SerializableByKey):
3444 """
3545
3646 def __init__ (
37- self , * contents : 'cirq.OP_TREE' , strategy : 'cirq.InsertStrategy' = InsertStrategy .EARLIEST
47+ self ,
48+ * contents : 'cirq.OP_TREE' ,
49+ strategy : 'cirq.InsertStrategy' = InsertStrategy .EARLIEST ,
50+ tags : Sequence [Hashable ] = (),
3851 ) -> None :
3952 """Initializes a frozen circuit.
4053
@@ -47,9 +60,14 @@ def __init__(
4760 strategy: When initializing the circuit with operations and moments
4861 from `contents`, this determines how the operations are packed
4962 together.
63+ tags: A sequence of any type of object that is useful to attach metadata
64+ to this circuit as long as the type is hashable. If you wish the
65+ resulting circuit to be eventually serialized into JSON, you should
66+ also restrict the tags to be JSON serializable.
5067 """
5168 base = Circuit (contents , strategy = strategy )
5269 self ._moments = tuple (base .moments )
70+ self ._tags = tuple (tags )
5371
5472 @classmethod
5573 def _from_moments (cls , moments : Iterable ['cirq.Moment' ]) -> 'FrozenCircuit' :
@@ -61,10 +79,35 @@ def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
6179 def moments (self ) -> Sequence ['cirq.Moment' ]:
6280 return self ._moments
6381
82+ @property
83+ def tags (self ) -> Tuple [Hashable , ...]:
84+ """Returns a tuple of the Circuit's tags."""
85+ return self ._tags
86+
87+ @_compat .cached_property
88+ def untagged (self ) -> 'cirq.FrozenCircuit' :
89+ """Returns the underlying FrozenCircuit without any tags."""
90+ return self ._from_moments (self ._moments ) if self .tags else self
91+
92+ def with_tags (self , * new_tags : Hashable ) -> 'cirq.FrozenCircuit' :
93+ """Creates a new tagged `FrozenCircuit` with `self.tags` and `new_tags` combined."""
94+ if not new_tags :
95+ return self
96+ new_circuit = FrozenCircuit (tags = self .tags + new_tags )
97+ new_circuit ._moments = self ._moments
98+ return new_circuit
99+
64100 @_compat .cached_method
65101 def __hash__ (self ) -> int :
66102 # Explicitly cached for performance
67- return hash ((self .moments ,))
103+ return hash ((self .moments , self .tags ))
104+
105+ def __eq__ (self , other ):
106+ super_eq = super ().__eq__ (other )
107+ if super_eq is not True :
108+ return super_eq
109+ other_tags = other .tags if isinstance (other , FrozenCircuit ) else ()
110+ return self .tags == other_tags
68111
69112 def __getstate__ (self ):
70113 # Don't save hash when pickling; see #3777.
@@ -130,11 +173,23 @@ def all_measurement_key_names(self) -> FrozenSet[str]:
130173
131174 @_compat .cached_method
132175 def _is_parameterized_ (self ) -> bool :
133- return super ()._is_parameterized_ ()
176+ return super ()._is_parameterized_ () or any (
177+ protocols .is_parameterized (tag ) for tag in self .tags
178+ )
134179
135180 @_compat .cached_method
136181 def _parameter_names_ (self ) -> AbstractSet [str ]:
137- return super ()._parameter_names_ ()
182+ tag_params = {name for tag in self .tags for name in protocols .parameter_names (tag )}
183+ return super ()._parameter_names_ () | tag_params
184+
185+ def _resolve_parameters_ (
186+ self , resolver : 'cirq.ParamResolver' , recursive : bool
187+ ) -> 'cirq.FrozenCircuit' :
188+ resolved_circuit = super ()._resolve_parameters_ (resolver , recursive )
189+ resolved_tags = [
190+ protocols .resolve_parameters (tag , resolver , recursive ) for tag in self .tags
191+ ]
192+ return resolved_circuit .with_tags (* resolved_tags )
138193
139194 def _measurement_key_names_ (self ) -> FrozenSet [str ]:
140195 return self .all_measurement_key_names ()
@@ -161,6 +216,20 @@ def __pow__(self, other) -> 'cirq.FrozenCircuit':
161216 except :
162217 return NotImplemented
163218
219+ def _repr_args (self ) -> str :
220+ moments_repr = super ()._repr_args ()
221+ tag_repr = ',' .join (_compat .proper_repr (t ) for t in self ._tags )
222+ return f'{ moments_repr } , tags=[{ tag_repr } ]' if self .tags else moments_repr
223+
224+ def _json_dict_ (self ):
225+ attribute_names = ['moments' , 'tags' ] if self .tags else ['moments' ]
226+ ret = protocols .obj_to_dict_helper (self , attribute_names )
227+ return ret
228+
229+ @classmethod
230+ def _from_json_dict_ (cls , moments , * , tags = (), ** kwargs ):
231+ return cls (moments , strategy = InsertStrategy .EARLIEST , tags = tags )
232+
164233 def concat_ragged (
165234 * circuits : 'cirq.AbstractCircuit' , align : Union ['cirq.Alignment' , str ] = Alignment .LEFT
166235 ) -> 'cirq.FrozenCircuit' :
0 commit comments