@@ -188,3 +188,68 @@ def rewriter_replace_with_decomp(op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
188188 ║ ║
189189a: ═════════════════════════════════════════════════════════════════════════════════════════════@══════════════════════════════^═══''' ,
190190 )
191+
192+
193+ def test_merge_k_qubit_unitaries_deep ():
194+ q = cirq .LineQubit .range (2 )
195+ h_cz_y = [cirq .H (q [0 ]), cirq .CZ (* q ), cirq .Y (q [1 ])]
196+ c_orig = cirq .Circuit (
197+ h_cz_y ,
198+ cirq .Moment (cirq .X (q [0 ]).with_tags ("ignore" ), cirq .Y (q [1 ])),
199+ cirq .CircuitOperation (cirq .FrozenCircuit (h_cz_y )).repeat (6 ).with_tags ("ignore" ),
200+ [cirq .CNOT (* q ), cirq .CNOT (* q )],
201+ cirq .CircuitOperation (cirq .FrozenCircuit (h_cz_y )).repeat (4 ),
202+ [cirq .CNOT (* q ), cirq .CZ (* q ), cirq .CNOT (* q )],
203+ cirq .CircuitOperation (cirq .FrozenCircuit (h_cz_y )).repeat (5 ).with_tags ("preserve_tag" ),
204+ )
205+
206+ def _wrap_in_cop (ops : cirq .OP_TREE , tag : str ):
207+ return cirq .CircuitOperation (cirq .FrozenCircuit (ops )).with_tags (tag )
208+
209+ c_expected = cirq .Circuit (
210+ _wrap_in_cop ([h_cz_y , cirq .Y (q [1 ])], '1' ),
211+ cirq .Moment (cirq .X (q [0 ]).with_tags ("ignore" )),
212+ cirq .CircuitOperation (cirq .FrozenCircuit (h_cz_y )).repeat (6 ).with_tags ("ignore" ),
213+ _wrap_in_cop ([cirq .CNOT (* q ), cirq .CNOT (* q )], '2' ),
214+ cirq .CircuitOperation (cirq .FrozenCircuit (_wrap_in_cop (h_cz_y , '3' ))).repeat (4 ),
215+ _wrap_in_cop ([cirq .CNOT (* q ), cirq .CZ (* q ), cirq .CNOT (* q )], '4' ),
216+ cirq .CircuitOperation (cirq .FrozenCircuit (_wrap_in_cop (h_cz_y , '5' )))
217+ .repeat (5 )
218+ .with_tags ("preserve_tag" ),
219+ strategy = cirq .InsertStrategy .NEW ,
220+ )
221+
222+ component_id = 0
223+
224+ def rewriter_merge_to_circuit_op (op : 'cirq.CircuitOperation' ) -> 'cirq.OP_TREE' :
225+ nonlocal component_id
226+ component_id = component_id + 1
227+ return op .with_tags (f'{ component_id } ' )
228+
229+ context = cirq .TransformerContext (tags_to_ignore = ("ignore" ,), deep = True )
230+ c_new = cirq .merge_k_qubit_unitaries (
231+ c_orig ,
232+ k = 2 ,
233+ context = context ,
234+ rewriter = rewriter_merge_to_circuit_op ,
235+ )
236+ cirq .testing .assert_same_circuits (c_new , c_expected )
237+
238+ def _wrap_in_matrix_gate (ops : cirq .OP_TREE ):
239+ op = _wrap_in_cop (ops , 'temp' )
240+ return cirq .MatrixGate (cirq .unitary (op )).on (* op .qubits )
241+
242+ c_expected_matrix = cirq .Circuit (
243+ _wrap_in_matrix_gate ([h_cz_y , cirq .Y (q [1 ])]),
244+ cirq .Moment (cirq .X (q [0 ]).with_tags ("ignore" )),
245+ cirq .CircuitOperation (cirq .FrozenCircuit (h_cz_y )).repeat (6 ).with_tags ("ignore" ),
246+ _wrap_in_matrix_gate ([cirq .CNOT (* q ), cirq .CNOT (* q )]),
247+ cirq .CircuitOperation (cirq .FrozenCircuit (_wrap_in_matrix_gate (h_cz_y ))).repeat (4 ),
248+ _wrap_in_matrix_gate ([cirq .CNOT (* q ), cirq .CZ (* q ), cirq .CNOT (* q )]),
249+ cirq .CircuitOperation (cirq .FrozenCircuit (_wrap_in_matrix_gate (h_cz_y )))
250+ .repeat (5 )
251+ .with_tags ("preserve_tag" ),
252+ strategy = cirq .InsertStrategy .NEW ,
253+ )
254+ c_new_matrix = cirq .merge_k_qubit_unitaries (c_orig , k = 2 , context = context )
255+ cirq .testing .assert_same_circuits (c_new_matrix , c_expected_matrix )
0 commit comments