@@ -60,13 +60,17 @@ def map_moments(
6060 circuit : CIRCUIT_TYPE ,
6161 map_func : Callable [[circuits .Moment , int ], Union [circuits .Moment , Sequence [circuits .Moment ]]],
6262 * ,
63+ tags_to_ignore : Sequence [Hashable ] = (),
6364 deep : bool = False ,
6465) -> CIRCUIT_TYPE :
6566 """Applies local transformation on moments, by calling `map_func(moment)` for each moment.
6667
6768 Args:
6869 circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
6970 map_func: Mapping function from (cirq.Moment, moment_index) to a sequence of moments.
71+ tags_to_ignore: Tagged circuit operations marked with any of `tags_to_ignore` will be
72+ ignored when recursively applying the transformer primitive to sub-circuits, given
73+ deep=True.
7074 deep: If true, `map_func` will be recursively applied to circuits wrapped inside
7175 any circuit operations contained within `circuit`.
7276
@@ -79,6 +83,8 @@ def map_moments(
7983 for i , op in circuit .findall_operations (
8084 lambda o : isinstance (o .untagged , circuits .CircuitOperation )
8185 ):
86+ if set (op .tags ).intersection (tags_to_ignore ):
87+ continue
8288 op_untagged = cast (circuits .CircuitOperation , op .untagged )
8389 mapped_op = op_untagged .replace (
8490 circuit = map_moments (op_untagged .circuit , map_func , deep = deep )
@@ -190,6 +196,7 @@ def merge_operations(
190196 merge_func : Callable [[ops .Operation , ops .Operation ], Optional [ops .Operation ]],
191197 * ,
192198 tags_to_ignore : Sequence [Hashable ] = (),
199+ deep : bool = False ,
193200) -> CIRCUIT_TYPE :
194201 """Merges operations in a circuit by calling `merge_func` iteratively on operations.
195202
@@ -226,6 +233,8 @@ def merge_operations(
226233 tags_to_ignore: Sequence of tags which should be ignored while applying `merge_func` on
227234 tagged operations -- i.e. `merge_func(op1, op2)` will be called only if both `op1` and
228235 `op2` satisfy `set(op.tags).isdisjoint(tags_to_ignore)`.
236+ deep: If true, the transformer primitive will be recursively applied to all circuits
237+ wrapped inside circuit operations.
229238
230239
231240 Returns:
@@ -235,9 +244,11 @@ def merge_operations(
235244 ValueError if the merged operation acts on new qubits outside the set of qubits
236245 corresponding to the original operations to be merged.
237246 """
247+ _circuit_op_tag = "_internal_tag_to_mark_circuit_ops_in_circuit"
248+ tags_to_ignore_set = set (tags_to_ignore ) | {_circuit_op_tag }
238249
239250 def apply_merge_func (op1 : ops .Operation , op2 : ops .Operation ) -> Optional [ops .Operation ]:
240- if not all (set (op .tags ). isdisjoint ( tags_to_ignore ) for op in [op1 , op2 ]):
251+ if not all (tags_to_ignore_set . isdisjoint (op .tags ) for op in [op1 , op2 ]):
241252 return None
242253 new_op = merge_func (op1 , op2 )
243254 qubit_set = frozenset (op1 .qubits + op2 .qubits )
@@ -252,6 +263,23 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Ope
252263 for current_moment in circuit :
253264 new_moment = circuits .Moment ()
254265 for op in sorted (current_moment .operations , key = lambda op : op .qubits ):
266+ if (
267+ deep
268+ and isinstance (op .untagged , circuits .CircuitOperation )
269+ and tags_to_ignore_set .isdisjoint (op .tags )
270+ ):
271+ op_untagged = op .untagged
272+ new_moment = new_moment .with_operation (
273+ op_untagged .replace (
274+ circuit = merge_operations (
275+ op_untagged .circuit ,
276+ merge_func ,
277+ tags_to_ignore = tags_to_ignore ,
278+ deep = True ,
279+ )
280+ ).with_tags (* op .tags , _circuit_op_tag )
281+ )
282+ continue
255283 op_qs = set (op .qubits )
256284 idx = ret_circuit .prev_moment_operating_on (tuple (op_qs ))
257285 if idx is not None and op_qs .issubset (ret_circuit [idx ][op_qs ].operations [0 ].qubits ):
@@ -279,6 +307,12 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Ope
279307 idx = ret_circuit .prev_moment_operating_on (tuple (op_qs ))
280308 new_moment = new_moment .with_operation (op )
281309 ret_circuit += new_moment
310+ if deep :
311+ ret_circuit = map_operations (
312+ ret_circuit ,
313+ lambda o , _ : o .untagged .with_tags (* (set (o .tags ) - {_circuit_op_tag })),
314+ deep = True ,
315+ )
282316 return _to_target_circuit_type (ret_circuit , circuit )
283317
284318
@@ -288,6 +322,7 @@ def merge_operations_to_circuit_op(
288322 * ,
289323 tags_to_ignore : Sequence [Hashable ] = (),
290324 merged_circuit_op_tag : str = "Merged connected component" ,
325+ deep : bool = False ,
291326) -> CIRCUIT_TYPE :
292327 """Merges connected components of operations and wraps each component into a circuit operation.
293328
@@ -307,6 +342,8 @@ def merge_operations_to_circuit_op(
307342 potential candidates for any connected component.
308343 merged_circuit_op_tag: Tag to be applied on circuit operations wrapping valid connected
309344 components.
345+ deep: If true, the transformer primitive will be recursively applied to all circuits
346+ wrapped inside circuit operations.
310347
311348 Returns:
312349 Copy of input circuit with valid connected components wrapped in tagged circuit operations.
@@ -329,7 +366,7 @@ def get_ops(op: 'cirq.Operation'):
329366 merged_circuit_op_tag
330367 )
331368
332- return merge_operations (circuit , merge_func , tags_to_ignore = tags_to_ignore )
369+ return merge_operations (circuit , merge_func , tags_to_ignore = tags_to_ignore , deep = deep )
333370
334371
335372def merge_k_qubit_unitaries_to_circuit_op (
@@ -338,6 +375,7 @@ def merge_k_qubit_unitaries_to_circuit_op(
338375 * ,
339376 tags_to_ignore : Sequence [Hashable ] = (),
340377 merged_circuit_op_tag : Optional [str ] = None ,
378+ deep : bool = False ,
341379) -> CIRCUIT_TYPE :
342380 """Merges connected components of operations, acting on <= k qubits, into circuit operations.
343381
@@ -353,6 +391,8 @@ def merge_k_qubit_unitaries_to_circuit_op(
353391 potential candidates for any connected component.
354392 merged_circuit_op_tag: Tag to be applied on circuit operations wrapping valid connected
355393 components. A default tag is applied if left None.
394+ deep: If true, the transformer primitive will be recursively applied to all circuits
395+ wrapped inside circuit operations.
356396
357397 Returns:
358398 Copy of input circuit with valid connected components wrapped in tagged circuit operations.
@@ -370,12 +410,16 @@ def can_merge(ops1: Sequence['cirq.Operation'], ops2: Sequence['cirq.Operation']
370410 can_merge ,
371411 tags_to_ignore = tags_to_ignore ,
372412 merged_circuit_op_tag = merged_circuit_op_tag or f"Merged { k } q unitary connected component." ,
413+ deep = deep ,
373414 )
374415
375416
376417def merge_moments (
377418 circuit : CIRCUIT_TYPE ,
378419 merge_func : Callable [[circuits .Moment , circuits .Moment ], Optional [circuits .Moment ]],
420+ * ,
421+ tags_to_ignore : Sequence [Hashable ] = (),
422+ deep : bool = False ,
379423) -> CIRCUIT_TYPE :
380424 """Merges adjacent moments, one by one from left to right, by calling `merge_func(m1, m2)`.
381425
@@ -384,12 +428,27 @@ def merge_moments(
384428 merge_func: Callable to determine whether two adjacent moments in the circuit should be
385429 merged. If the moments can be merged, the callable should return the merged moment,
386430 else None.
431+ tags_to_ignore: Tagged circuit operations marked with any of `tags_to_ignore` will be
432+ ignored when recursively applying the transformer primitive to sub-circuits, given
433+ deep=True.
434+ deep: If true, the transformer primitive will be recursively applied to all circuits
435+ wrapped inside circuit operations.
387436
388437 Returns:
389438 Copy of input circuit with merged moments.
390439 """
391440 if not circuit :
392441 return circuit
442+ if deep :
443+ circuit = map_operations (
444+ circuit ,
445+ lambda op , _ : op .untagged .replace (
446+ circuit = merge_moments (op .untagged .circuit , merge_func , deep = deep )
447+ ).with_tags (* op .tags )
448+ if isinstance (op .untagged , circuits .CircuitOperation )
449+ else op ,
450+ tags_to_ignore = tags_to_ignore ,
451+ )
393452 merged_moments : List [circuits .Moment ] = [circuit [0 ]]
394453 for current_moment in circuit [1 :]:
395454 merged_moment = merge_func (merged_moments [- 1 ], current_moment )
0 commit comments