@@ -81,8 +81,8 @@ def map_moments(
8181 ):
8282 op_untagged = cast (circuits .CircuitOperation , op .untagged )
8383 mapped_op = op_untagged .replace (
84- circuit = map_moments (op_untagged .mapped_circuit () , map_func , deep = deep ). freeze ( )
85- )
84+ circuit = map_moments (op_untagged .circuit , map_func , deep = deep )
85+ ). with_tags ( * op . tags )
8686 batch_replace .append ((i , op , mapped_op ))
8787 mutable_circuit = circuit .unfreeze (copy = True )
8888 mutable_circuit .batch_replace (batch_replace )
@@ -180,7 +180,8 @@ def map_operations_and_unroll(
180180 deep = deep ,
181181 raise_if_add_qubits = raise_if_add_qubits ,
182182 tags_to_ignore = tags_to_ignore ,
183- )
183+ ),
184+ deep = deep ,
184185 )
185186
186187
@@ -399,12 +400,6 @@ def merge_moments(
399400 return _create_target_circuit_type (merged_moments , circuit )
400401
401402
402- def _check_circuit_op (op , tags_to_check : Optional [Sequence [Hashable ]]) -> bool :
403- return isinstance (op .untagged , circuits .CircuitOperation ) and (
404- tags_to_check is None or any (tag in op .tags for tag in tags_to_check )
405- )
406-
407-
408403def unroll_circuit_op (
409404 circuit : CIRCUIT_TYPE ,
410405 * ,
@@ -418,8 +413,8 @@ def unroll_circuit_op(
418413
419414 Args:
420415 circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
421- deep: If True, `unroll_circuit_op` is recursively called on all circuit operations matching
422- `tags_to_check` .
416+ deep: If true, the transformer primitive will be recursively applied to all circuits
417+ wrapped inside circuit operations .
423418 tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
424419 are unrolled.
425420
@@ -430,12 +425,18 @@ def unroll_circuit_op(
430425 def map_func (m : circuits .Moment , _ : int ):
431426 to_zip : List ['cirq.AbstractCircuit' ] = []
432427 for op in m :
433- if _check_circuit_op (op , tags_to_check ):
434- sub_circuit = cast (circuits .CircuitOperation , op .untagged ).mapped_circuit ()
428+ op_untagged = op .untagged
429+ if isinstance (op_untagged , circuits .CircuitOperation ):
430+ if deep :
431+ op_untagged = op_untagged .replace (
432+ circuit = unroll_circuit_op (
433+ op_untagged .circuit , deep = deep , tags_to_check = tags_to_check
434+ )
435+ )
435436 to_zip .append (
436- unroll_circuit_op ( sub_circuit , deep = deep , tags_to_check = tags_to_check )
437- if deep
438- else sub_circuit
437+ op_untagged . mapped_circuit ( )
438+ if ( tags_to_check is None or set ( tags_to_check ). intersection ( op . tags ))
439+ else circuits . Circuit ( op_untagged . with_tags ( * op . tags ))
439440 )
440441 else :
441442 to_zip .append (circuits .Circuit (op ))
@@ -458,27 +459,36 @@ def unroll_circuit_op_greedy_earliest(
458459
459460 Args:
460461 circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
461- deep: If True, `unroll_circuit_op_greedy_earliest` is recursively called on all circuit
462- operations matching `tags_to_check` .
462+ deep: If true, the transformer primitive will be recursively applied to all circuits
463+ wrapped inside circuit operations .
463464 tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
464465 are unrolled.
465466
466467 Returns:
467468 Copy of input circuit with (Tagged) CircuitOperation's expanded using EARLIEST strategy.
468469 """
469- batch_removals = [* circuit .findall_operations (lambda op : _check_circuit_op (op , tags_to_check ))]
470- batch_inserts = []
471- for i , op in batch_removals :
472- sub_circuit = cast (circuits .CircuitOperation , op .untagged ).mapped_circuit ()
473- sub_circuit = (
474- unroll_circuit_op_greedy_earliest (sub_circuit , deep = deep , tags_to_check = tags_to_check )
475- if deep
476- else sub_circuit
477- )
478- batch_inserts += [(i , sub_circuit .all_operations ())]
470+ batch_replace = []
471+ batch_remove = []
472+ batch_insert = []
473+ for i , op in circuit .findall_operations (
474+ lambda o : isinstance (o .untagged , circuits .CircuitOperation )
475+ ):
476+ op_untagged = cast (circuits .CircuitOperation , op .untagged )
477+ if deep :
478+ op_untagged = op_untagged .replace (
479+ circuit = unroll_circuit_op_greedy_earliest (
480+ op_untagged .circuit , deep = deep , tags_to_check = tags_to_check
481+ )
482+ )
483+ if tags_to_check is None or set (tags_to_check ).intersection (op .tags ):
484+ batch_remove .append ((i , op ))
485+ batch_insert .append ((i , op_untagged .mapped_circuit ().all_operations ()))
486+ elif deep :
487+ batch_replace .append ((i , op , op_untagged .with_tags (* op .tags )))
479488 unrolled_circuit = circuit .unfreeze (copy = True )
480- unrolled_circuit .batch_remove (batch_removals )
481- unrolled_circuit .batch_insert (batch_inserts )
489+ unrolled_circuit .batch_replace (batch_replace )
490+ unrolled_circuit .batch_remove (batch_remove )
491+ unrolled_circuit .batch_insert (batch_insert )
482492 return _to_target_circuit_type (unrolled_circuit , circuit )
483493
484494
@@ -496,8 +506,8 @@ def unroll_circuit_op_greedy_frontier(
496506
497507 Args:
498508 circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
499- deep: If True, `unroll_circuit_op_greedy_frontier` is recursively called on all circuit
500- operations matching `tags_to_check` .
509+ deep: If true, the transformer primitive will be recursively applied to all circuits
510+ wrapped inside circuit operations .
501511 tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
502512 are unrolled.
503513
@@ -506,16 +516,29 @@ def unroll_circuit_op_greedy_frontier(
506516 """
507517 unrolled_circuit = circuit .unfreeze (copy = True )
508518 frontier : Dict ['cirq.Qid' , int ] = defaultdict (lambda : 0 )
509- for idx , op in circuit .findall_operations (lambda op : _check_circuit_op (op , tags_to_check )):
510- idx = max (idx , max (frontier [q ] for q in op .qubits ))
511- unrolled_circuit .clear_operations_touching (op .qubits , [idx ])
512- sub_circuit = cast (circuits .CircuitOperation , op .untagged ).mapped_circuit ()
513- sub_circuit = (
514- unroll_circuit_op_greedy_earliest (sub_circuit , deep = deep , tags_to_check = tags_to_check )
515- if deep
516- else sub_circuit
517- )
518- frontier = unrolled_circuit .insert_at_frontier (sub_circuit .all_operations (), idx , frontier )
519+ idx = 0
520+ while idx < len (unrolled_circuit ):
521+ for op in unrolled_circuit [idx ].operations :
522+ # Don't touch stuff inserted by unrolling previous circuit ops.
523+ if not isinstance (op .untagged , circuits .CircuitOperation ):
524+ continue
525+ if any (frontier [q ] > idx for q in op .qubits ):
526+ continue
527+ op_untagged = cast (circuits .CircuitOperation , op .untagged )
528+ if deep :
529+ op_untagged = op_untagged .replace (
530+ circuit = unroll_circuit_op_greedy_frontier (
531+ op_untagged .circuit , deep = deep , tags_to_check = tags_to_check
532+ )
533+ )
534+ if tags_to_check is None or set (tags_to_check ).intersection (op .tags ):
535+ unrolled_circuit .clear_operations_touching (op .qubits , [idx ])
536+ frontier = unrolled_circuit .insert_at_frontier (
537+ op_untagged .mapped_circuit ().all_operations (), idx , frontier
538+ )
539+ elif deep :
540+ unrolled_circuit .batch_replace ([(idx , op , op_untagged .with_tags (* op .tags ))])
541+ idx += 1
519542 return _to_target_circuit_type (unrolled_circuit , circuit )
520543
521544
0 commit comments