1515"""Transformer pass to repack circuits avoiding simultaneous operations with different classes."""
1616
1717import itertools
18- from typing import TYPE_CHECKING , Type , Callable , Optional , Union , Iterable , Sequence , List , Tuple
18+ from typing import TYPE_CHECKING , Type , Callable , Dict , Optional , Union , Iterable , Sequence , List
1919
20- from cirq import ops , circuits , _import
21- from cirq .transformers import transformer_api , transformer_primitives
20+ from cirq import ops , circuits , protocols , _import
21+ from cirq .transformers import transformer_api
2222
2323drop_empty_moments = _import .LazyLoader ('drop_empty_moments' , globals (), 'cirq.transformers' )
2424
@@ -61,38 +61,36 @@ def stratified_circuit(
6161 Returns:
6262 A copy of the original circuit, but with re-arranged operations.
6363 """
64-
6564 # Normalize categories into classifier functions.
66- classifiers = [_category_to_classifier (category ) for category in categories ]
67- # Make the classifiers exhaustive by adding an "everything else" bucket.
68- and_the_rest = lambda op : all (not classifier (op ) for classifier in classifiers )
69- classifiers_and_the_rest = [* classifiers , and_the_rest ]
65+ classifiers = _get_classifiers (circuit , categories )
7066
7167 # Try the algorithm with each permutation of the classifiers.
72- classifiers_permutations = list (itertools .permutations (classifiers_and_the_rest ))
68+ smallest_depth = protocols .num_qubits (circuit ) * len (circuit ) + 1
69+ shortest_stratified_circuit = circuits .Circuit ()
7370 reversed_circuit = circuit [::- 1 ]
74- solutions = []
75- for c in classifiers_permutations :
76- solutions .append (
77- _stratify_circuit (
78- circuit ,
79- classifiers = list (c ),
80- context = context or transformer_api .TransformerContext (),
81- )
71+ for ordered_classifiers in itertools .permutations (classifiers ):
72+ solution = _stratify_circuit (
73+ circuit ,
74+ classifiers = ordered_classifiers ,
75+ context = context or transformer_api .TransformerContext (),
8276 )
77+ if len (solution ) < smallest_depth :
78+ shortest_stratified_circuit = solution
79+ smallest_depth = len (solution )
80+
8381 # Do the same thing, except this time in reverse. This helps for some
8482 # circuits because it inserts operations at the end instead of at the
8583 # beginning.
86- solutions .append (
87- _stratify_circuit (
88- reversed_circuit ,
89- classifiers = list (c ),
90- context = context or transformer_api .TransformerContext (),
91- )[::- 1 ]
92- )
84+ solution = _stratify_circuit (
85+ reversed_circuit ,
86+ classifiers = ordered_classifiers ,
87+ context = context or transformer_api .TransformerContext (),
88+ )[::- 1 ]
89+ if len (solution ) < smallest_depth :
90+ shortest_stratified_circuit = solution
91+ smallest_depth = len (solution )
9392
94- # Return the shortest circuit.
95- return min (solutions , key = lambda c : len (c ))
93+ return shortest_stratified_circuit
9694
9795
9896def _stratify_circuit (
@@ -116,43 +114,88 @@ def _stratify_circuit(
116114 Returns:
117115 The stratified circuit.
118116 """
119- num_categories = len (classifiers ) + 1
120-
121- def map_func (m : 'cirq.Moment' , _ ) -> Sequence ['cirq.Moment' ]:
122- stratified_ops : List [List ['cirq.Operation' ]] = [[] for _ in range (num_categories )]
123- for op in m :
124- if set (op .tags ) & set (context .tags_to_ignore ):
125- stratified_ops [0 ].append (op )
126- continue
127- for i , classifier in enumerate (classifiers ):
128- if classifier (op ):
129- stratified_ops [i + 1 ].append (op )
130- break
131- return [circuits .Moment (op_list ) for op_list in stratified_ops ]
132-
133- stratified_circuit = transformer_primitives .map_moments (circuit , map_func ).unfreeze (copy = False )
134- assert len (stratified_circuit ) == len (circuit ) * num_categories
135-
136- # Try to move operations to the left to reduce circuit depth, preserving stratification.
137- for curr_idx , moment in enumerate (stratified_circuit ):
138- curr_category = curr_idx % num_categories
139- if curr_category == 0 :
140- # Moment containing tagged operations to be ignored.
141- continue
142- batch_removals : List [Tuple [int , 'cirq.Operation' ]] = []
143- batch_inserts : List [Tuple [int , 'cirq.Operation' ]] = []
117+ num_classes = len (classifiers ) + 1 # include one "extra" category for ignored operations
118+ new_moments : List [List ['cirq.Operation' ]] = []
119+
120+ # Keep track of the the latest time index for each qubit, measurement key, and control key.
121+ qubit_time_index : Dict ['cirq.Qid' , int ] = {}
122+ measurement_time_index : Dict ['cirq.MeasurementKey' , int ] = {}
123+ control_time_index : Dict ['cirq.MeasurementKey' , int ] = {}
124+
125+ # The minimum time index for operations with a tag in context.tags_to_ignore.
126+ last_ignored_ops_time_index = 0
127+
128+ for moment in circuit :
129+ # Identify the new time indices that operations should be moved into.
130+ ignored_ops = []
131+ op_time_indices = {}
144132 for op in moment :
145- prv_idx = stratified_circuit .earliest_available_moment (op , end_moment_index = curr_idx )
146- prv_category = prv_idx % num_categories
147- should_move_to_next_batch = curr_category < prv_category
148- prv_idx += curr_category - prv_category + num_categories * should_move_to_next_batch
149- assert prv_idx <= curr_idx and prv_idx % num_categories == curr_idx % num_categories
150- if prv_idx < curr_idx :
151- batch_inserts .append ((prv_idx , op ))
152- batch_removals .append ((curr_idx , op ))
153- stratified_circuit .batch_remove (batch_removals )
154- stratified_circuit .batch_insert_into (batch_inserts )
155- return drop_empty_moments .drop_empty_moments (stratified_circuit )
133+
134+ # Identify the earliest moment that can accommodate this op.
135+ min_time_index_for_op = circuits .circuit .get_earliest_accommodating_moment_index (
136+ op , qubit_time_index , measurement_time_index , control_time_index
137+ )
138+
139+ # Identify the "class" of this operation (by index).
140+ ignored_op = any (tag in op .tags for tag in context .tags_to_ignore )
141+ if not ignored_op :
142+ op_class = _get_op_class (op , classifiers )
143+ else :
144+ op_class = len (classifiers )
145+ ignored_ops .append (op )
146+ min_time_index_for_op = max (min_time_index_for_op , last_ignored_ops_time_index + 1 )
147+
148+ # Identify the time index to place this operation into.
149+ time_index = (min_time_index_for_op // num_classes ) * num_classes + op_class
150+ if time_index < min_time_index_for_op :
151+ time_index += num_classes
152+ op_time_indices [op ] = time_index
153+
154+ # Assign ignored operations to the same moment.
155+ if ignored_ops :
156+ last_ignored_ops_time_index = max (op_time_indices [op ] for op in ignored_ops )
157+ for op in ignored_ops :
158+ op_time_indices [op ] = last_ignored_ops_time_index
159+
160+ # Move the operations into their assigned moments.
161+ for op , time_index in op_time_indices .items ():
162+ if time_index >= len (new_moments ):
163+ new_moments += [[] for _ in range (num_classes )]
164+ new_moments [time_index ].append (op )
165+
166+ # Update qubit, measurment key, and control key moments.
167+ for qubit in op .qubits :
168+ qubit_time_index [qubit ] = time_index
169+ for key in protocols .measurement_key_objs (op ):
170+ measurement_time_index [key ] = time_index
171+ for key in protocols .control_keys (op ):
172+ control_time_index [key ] = time_index
173+
174+ return circuits .Circuit (circuits .Moment (moment ) for moment in new_moments if moment )
175+
176+
177+ def _get_classifiers (
178+ circuit : circuits .AbstractCircuit , categories : Iterable [Category ]
179+ ) -> List [Classifier ]:
180+ """Convert a collection of categories into a list of classifiers.
181+
182+ The returned list of classifiers is:
183+ - Exhaustive, meaning every operation in the circuit is classified by at least one classifier.
184+ - Minimal, meaning unused classifiers are forgotten.
185+ """
186+ # Convert all categories into classifiers, and make the list exhaustive by adding a dummy
187+ # classifier for otherwise unclassified ops.
188+ classifiers = [_category_to_classifier (cat ) for cat in categories ] + [_dummy_classifier ]
189+
190+ # Figure out which classes are actually used in the circuit.
191+ class_is_used = [False for _ in classifiers ]
192+ for op in circuit .all_operations ():
193+ class_is_used [_get_op_class (op , classifiers )] = True
194+ if all (class_is_used ):
195+ break
196+
197+ # Return only the classifiers that are used.
198+ return [classifier for classifier , is_used in zip (classifiers , class_is_used ) if is_used ]
156199
157200
158201# No type for `category` because mypy does not keep the return type when
@@ -177,3 +220,22 @@ def _category_to_classifier(category) -> Classifier:
177220 f'Type[cirq.Gate], Type[cirq.Operation], '
178221 f'or Callable[[cirq.Operation], bool].'
179222 )
223+
224+
225+ def _dummy_classifier (op : 'cirq.Operation' ) -> bool :
226+ """Dummy classifier, used to "complete" a collection of classifiers and make it exhaustive."""
227+
228+
229+ def _get_op_class (op : 'cirq.Operation' , classifiers : Sequence [Classifier ]) -> int :
230+ """Get the "class" of an operator, by index."""
231+ for class_index , classifier in enumerate (classifiers ):
232+ if classifier is _dummy_classifier :
233+ dummy_classifier_index = class_index
234+ elif classifier (op ):
235+ return class_index
236+ # If we got this far, the operation did not match any "actual" classifier,
237+ # so return the index of the dummy classifer.
238+ try :
239+ return dummy_classifier_index
240+ except NameError :
241+ raise ValueError (f"Operation { op } not identified by any classifier" )
0 commit comments