diff --git a/cirq-core/cirq/transformers/transformer_primitives.py b/cirq-core/cirq/transformers/transformer_primitives.py index ddd14db08e5..18d3d035c5b 100644 --- a/cirq-core/cirq/transformers/transformer_primitives.py +++ b/cirq-core/cirq/transformers/transformer_primitives.py @@ -81,8 +81,8 @@ def map_moments( ): op_untagged = cast(circuits.CircuitOperation, op.untagged) mapped_op = op_untagged.replace( - circuit=map_moments(op_untagged.mapped_circuit(), map_func, deep=deep).freeze() - ) + circuit=map_moments(op_untagged.circuit, map_func, deep=deep) + ).with_tags(*op.tags) batch_replace.append((i, op, mapped_op)) mutable_circuit = circuit.unfreeze(copy=True) mutable_circuit.batch_replace(batch_replace) @@ -180,7 +180,8 @@ def map_operations_and_unroll( deep=deep, raise_if_add_qubits=raise_if_add_qubits, tags_to_ignore=tags_to_ignore, - ) + ), + deep=deep, ) @@ -399,12 +400,6 @@ def merge_moments( return _create_target_circuit_type(merged_moments, circuit) -def _check_circuit_op(op, tags_to_check: Optional[Sequence[Hashable]]) -> bool: - return isinstance(op.untagged, circuits.CircuitOperation) and ( - tags_to_check is None or any(tag in op.tags for tag in tags_to_check) - ) - - def unroll_circuit_op( circuit: CIRCUIT_TYPE, *, @@ -418,8 +413,8 @@ def unroll_circuit_op( Args: circuit: Input circuit to apply the transformations on. The input circuit is not mutated. - deep: If True, `unroll_circuit_op` is recursively called on all circuit operations matching - `tags_to_check`. + deep: If true, the transformer primitive will be recursively applied to all circuits + wrapped inside circuit operations. tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check` are unrolled. @@ -430,12 +425,18 @@ def unroll_circuit_op( def map_func(m: circuits.Moment, _: int): to_zip: List['cirq.AbstractCircuit'] = [] for op in m: - if _check_circuit_op(op, tags_to_check): - sub_circuit = cast(circuits.CircuitOperation, op.untagged).mapped_circuit() + op_untagged = op.untagged + if isinstance(op_untagged, circuits.CircuitOperation): + if deep: + op_untagged = op_untagged.replace( + circuit=unroll_circuit_op( + op_untagged.circuit, deep=deep, tags_to_check=tags_to_check + ) + ) to_zip.append( - unroll_circuit_op(sub_circuit, deep=deep, tags_to_check=tags_to_check) - if deep - else sub_circuit + op_untagged.mapped_circuit() + if (tags_to_check is None or set(tags_to_check).intersection(op.tags)) + else circuits.Circuit(op_untagged.with_tags(*op.tags)) ) else: to_zip.append(circuits.Circuit(op)) @@ -458,27 +459,36 @@ def unroll_circuit_op_greedy_earliest( Args: circuit: Input circuit to apply the transformations on. The input circuit is not mutated. - deep: If True, `unroll_circuit_op_greedy_earliest` is recursively called on all circuit - operations matching `tags_to_check`. + deep: If true, the transformer primitive will be recursively applied to all circuits + wrapped inside circuit operations. tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check` are unrolled. Returns: Copy of input circuit with (Tagged) CircuitOperation's expanded using EARLIEST strategy. """ - batch_removals = [*circuit.findall_operations(lambda op: _check_circuit_op(op, tags_to_check))] - batch_inserts = [] - for i, op in batch_removals: - sub_circuit = cast(circuits.CircuitOperation, op.untagged).mapped_circuit() - sub_circuit = ( - unroll_circuit_op_greedy_earliest(sub_circuit, deep=deep, tags_to_check=tags_to_check) - if deep - else sub_circuit - ) - batch_inserts += [(i, sub_circuit.all_operations())] + batch_replace = [] + batch_remove = [] + batch_insert = [] + for i, op in circuit.findall_operations( + lambda o: isinstance(o.untagged, circuits.CircuitOperation) + ): + op_untagged = cast(circuits.CircuitOperation, op.untagged) + if deep: + op_untagged = op_untagged.replace( + circuit=unroll_circuit_op_greedy_earliest( + op_untagged.circuit, deep=deep, tags_to_check=tags_to_check + ) + ) + if tags_to_check is None or set(tags_to_check).intersection(op.tags): + batch_remove.append((i, op)) + batch_insert.append((i, op_untagged.mapped_circuit().all_operations())) + elif deep: + batch_replace.append((i, op, op_untagged.with_tags(*op.tags))) unrolled_circuit = circuit.unfreeze(copy=True) - unrolled_circuit.batch_remove(batch_removals) - unrolled_circuit.batch_insert(batch_inserts) + unrolled_circuit.batch_replace(batch_replace) + unrolled_circuit.batch_remove(batch_remove) + unrolled_circuit.batch_insert(batch_insert) return _to_target_circuit_type(unrolled_circuit, circuit) @@ -496,8 +506,8 @@ def unroll_circuit_op_greedy_frontier( Args: circuit: Input circuit to apply the transformations on. The input circuit is not mutated. - deep: If True, `unroll_circuit_op_greedy_frontier` is recursively called on all circuit - operations matching `tags_to_check`. + deep: If true, the transformer primitive will be recursively applied to all circuits + wrapped inside circuit operations. tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check` are unrolled. @@ -506,16 +516,29 @@ def unroll_circuit_op_greedy_frontier( """ unrolled_circuit = circuit.unfreeze(copy=True) frontier: Dict['cirq.Qid', int] = defaultdict(lambda: 0) - for idx, op in circuit.findall_operations(lambda op: _check_circuit_op(op, tags_to_check)): - idx = max(idx, max(frontier[q] for q in op.qubits)) - unrolled_circuit.clear_operations_touching(op.qubits, [idx]) - sub_circuit = cast(circuits.CircuitOperation, op.untagged).mapped_circuit() - sub_circuit = ( - unroll_circuit_op_greedy_earliest(sub_circuit, deep=deep, tags_to_check=tags_to_check) - if deep - else sub_circuit - ) - frontier = unrolled_circuit.insert_at_frontier(sub_circuit.all_operations(), idx, frontier) + idx = 0 + while idx < len(unrolled_circuit): + for op in unrolled_circuit[idx].operations: + # Don't touch stuff inserted by unrolling previous circuit ops. + if not isinstance(op.untagged, circuits.CircuitOperation): + continue + if any(frontier[q] > idx for q in op.qubits): + continue + op_untagged = cast(circuits.CircuitOperation, op.untagged) + if deep: + op_untagged = op_untagged.replace( + circuit=unroll_circuit_op_greedy_frontier( + op_untagged.circuit, deep=deep, tags_to_check=tags_to_check + ) + ) + if tags_to_check is None or set(tags_to_check).intersection(op.tags): + unrolled_circuit.clear_operations_touching(op.qubits, [idx]) + frontier = unrolled_circuit.insert_at_frontier( + op_untagged.mapped_circuit().all_operations(), idx, frontier + ) + elif deep: + unrolled_circuit.batch_replace([(idx, op, op_untagged.with_tags(*op.tags))]) + idx += 1 return _to_target_circuit_type(unrolled_circuit, circuit) diff --git a/cirq-core/cirq/transformers/transformer_primitives_test.py b/cirq-core/cirq/transformers/transformer_primitives_test.py index a639b90626e..d00fba04c98 100644 --- a/cirq-core/cirq/transformers/transformer_primitives_test.py +++ b/cirq-core/cirq/transformers/transformer_primitives_test.py @@ -117,6 +117,7 @@ def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE: ) +# pylint: disable=line-too-long def test_map_operations_deep_subcircuits(): q = cirq.LineQubit.range(5) c_orig = cirq.Circuit( @@ -127,9 +128,14 @@ def test_map_operations_deep_subcircuits(): c_orig_with_circuit_ops = cirq.Circuit( cirq.CircuitOperation( cirq.FrozenCircuit( - [cirq.CircuitOperation(cirq.FrozenCircuit(op)) for op in c_orig.all_operations()] + [ + cirq.CircuitOperation(cirq.FrozenCircuit(op)).repeat(2).with_tags("internal") + for op in c_orig.all_operations() + ] ) ) + .repeat(6) + .with_tags("external") ) def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE: @@ -139,23 +145,73 @@ def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE: cirq.Z.on_each(*op.qubits), ] if op.gate == cirq.CX else op - c_mapped = cirq.map_operations(c_orig_with_circuit_ops, map_func, deep=True) - c_mapped = cirq.unroll_circuit_op(c_mapped, deep=True, tags_to_check=None) cirq.testing.assert_has_diagram( - c_mapped, + c_orig_with_circuit_ops, ''' -0: ───Z───@───Z─────────────── - │ -1: ───Z───X───Z─────────────── - -2: ───Z───X───Z─────────────── - │ -3: ───Z───@───Z───Z───@───Z─── - │ -4: ───────────────Z───X───Z─── + [ [ 0: ───@─── ] ] + [ 0: ───[ │ ]────────────────────────────────────────────────────────────── ] + [ [ 1: ───X─── ](loops=2)['internal'] ] + [ │ ] + [ 1: ───#2────────────────────────────────────────────────────────────────────────── ] + [ ] + [ [ 2: ───X─── ] ] +0: ───[ 2: ───[ │ ]────────────────────────────────────────────────────────────── ]──────────────────────── + [ [ 3: ───@─── ](loops=2)['internal'] ] + [ │ ] + [ │ [ 3: ───@─── ] ] + [ 3: ───#2────────────────────────────────────[ │ ]──────────────────────── ] + [ [ 4: ───X─── ](loops=2)['internal'] ] + [ │ ] + [ 4: ─────────────────────────────────────────#2──────────────────────────────────── ](loops=6)['external'] + │ +1: ───#2──────────────────────────────────────────────────────────────────────────────────────────────────────────── + │ +2: ───#3──────────────────────────────────────────────────────────────────────────────────────────────────────────── + │ +3: ───#4──────────────────────────────────────────────────────────────────────────────────────────────────────────── + │ +4: ───#5──────────────────────────────────────────────────────────────────────────────────────────────────────────── ''', ) + c_mapped = cirq.map_operations(c_orig_with_circuit_ops, map_func, deep=True) + for unroller in [ + cirq.unroll_circuit_op, + cirq.unroll_circuit_op_greedy_earliest, + cirq.unroll_circuit_op_greedy_frontier, + ]: + cirq.testing.assert_has_diagram( + unroller(c_mapped, deep=True), + ''' + [ [ 0: ───Z───@───Z─── ] ] + [ 0: ───[ │ ]────────────────────────────────────────────────────────────────────── ] + [ [ 1: ───Z───X───Z─── ](loops=2)['internal'] ] + [ │ ] + [ 1: ───#2────────────────────────────────────────────────────────────────────────────────────────── ] + [ ] + [ [ 2: ───Z───X───Z─── ] ] +0: ───[ 2: ───[ │ ]────────────────────────────────────────────────────────────────────── ]──────────────────────── + [ [ 3: ───Z───@───Z─── ](loops=2)['internal'] ] + [ │ ] + [ │ [ 3: ───Z───@───Z─── ] ] + [ 3: ───#2────────────────────────────────────────────[ │ ]──────────────────────── ] + [ [ 4: ───Z───X───Z─── ](loops=2)['internal'] ] + [ │ ] + [ 4: ─────────────────────────────────────────────────#2──────────────────────────────────────────── ](loops=6)['external'] + │ +1: ───#2──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── + │ +2: ───#3──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── + │ +3: ───#4──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── + │ +4: ───#5──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── +''', + ) + + +# pylint: enable=line-too-long + def test_map_operations_respects_tags_to_ignore(): q = cirq.LineQubit.range(2) @@ -204,13 +260,29 @@ def test_unroll_circuit_op_and_variants(): [cirq.Moment(cirq.CircuitOperation(cirq.FrozenCircuit(m))) for m in mapped_circuit[:-1]], mapped_circuit[-1], ) + cirq.testing.assert_has_diagram( + mapped_circuit_deep, + ''' +0: ───[ 0: ───X─── ]────────────────────────────────────────────────────────────X─── + +1: ────────────────────[ 1: ───[ 1: ───Z───Z─── ]['']─── ]─────── +''', + ) for unroller in [ cirq.unroll_circuit_op_greedy_earliest, cirq.unroll_circuit_op_greedy_frontier, cirq.unroll_circuit_op, ]: cirq.testing.assert_same_circuits( - unroller(mapped_circuit), unroller(mapped_circuit_deep, tags_to_check=None, deep=True) + unroller(mapped_circuit), unroller(mapped_circuit_deep, deep=True, tags_to_check=None) + ) + cirq.testing.assert_has_diagram( + unroller(mapped_circuit_deep, deep=True), + ''' +0: ───[ 0: ───X─── ]────────────────────────X─── + +1: ────────────────────[ 1: ───Z───Z─── ]─────── + ''', ) cirq.testing.assert_has_diagram( @@ -239,6 +311,16 @@ def test_unroll_circuit_op_and_variants(): ) +def test_unroll_circuit_op_greedy_frontier_doesnt_touch_same_op_twice(): + q = cirq.NamedQubit("q") + nested_ops = [cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q)))] * 5 + nested_circuit_op = cirq.CircuitOperation(cirq.FrozenCircuit(nested_ops)) + c = cirq.Circuit(nested_circuit_op, nested_circuit_op, nested_circuit_op) + c_expected = cirq.Circuit(nested_ops, nested_ops, nested_ops) + c_unrolled = cirq.unroll_circuit_op_greedy_frontier(c, tags_to_check=None) + cirq.testing.assert_same_circuits(c_unrolled, c_expected) + + def test_unroll_circuit_op_deep(): q0, q1, q2 = cirq.LineQubit.range(3) c = cirq.Circuit(