Skip to content

Commit cb89f3a

Browse files
Update Density Matrix and State Vector Simulators to work when an operation allocates new qubits as part of its decomposition (#6108)
* WIP add factoring and kron methods to sim state for adding and removing ancillas in state vector and density matrix simulators * add test cases * add delegating gate test case * update test * all tests pass * add test case for unitary Y * nit * addresses PR comments by adding empty checks. Applys formatter. Subsequent push will add more test cases per Tanuj's comment * nit formatting changes, add docustring with input/output for remove_qubits * merge this branch and tanujkhattar/Cirq@ccde689 * merging branches, adding test coverage in next push * format files * add coverage tests * change assert * coverage and type check tests should pass * incorporate tanujkhattar/Cirq@1db8ac5 * nit * remove block comment * add coverage --------- Co-authored-by: Tanuj Khattar <tanujkhattar@google.com>
1 parent 3393439 commit cb89f3a

File tree

6 files changed

+255
-14
lines changed

6 files changed

+255
-14
lines changed

cirq-core/cirq/protocols/act_on_protocol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def act_on(
149149

150150
arg_fallback = getattr(sim_state, '_act_on_fallback_', None)
151151
if arg_fallback is not None:
152-
qubits = action.qubits if isinstance(action, ops.Operation) else qubits
152+
qubits = action.qubits if is_op else qubits
153153
result = arg_fallback(action, qubits=qubits, allow_decompose=allow_decompose)
154154
if result is True:
155155
return

cirq-core/cirq/sim/density_matrix_simulation_state.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,22 @@ def __init__(
285285
)
286286
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)
287287

288+
def add_qubits(self, qubits: Sequence['cirq.Qid']):
289+
ret = super().add_qubits(qubits)
290+
return (
291+
self.kronecker_product(type(self)(qubits=qubits), inplace=True)
292+
if ret is NotImplemented
293+
else ret
294+
)
295+
296+
def remove_qubits(self, qubits: Sequence['cirq.Qid']):
297+
ret = super().remove_qubits(qubits)
298+
if ret is not NotImplemented:
299+
return ret
300+
extracted, remainder = self.factor(qubits)
301+
remainder._state._density_matrix *= extracted._state._density_matrix.reshape(-1)[0]
302+
return remainder
303+
288304
def _act_on_fallback_(
289305
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
290306
) -> bool:

cirq-core/cirq/sim/density_matrix_simulation_state_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,15 @@ def test_initial_state_bad_shape():
123123
cirq.DensityMatrixSimulationState(
124124
qubits=qubits, initial_state=np.full((2, 2, 2, 2), 1 / 4), dtype=np.complex64
125125
)
126+
127+
128+
def test_remove_qubits():
129+
"""Test the remove_qubits method."""
130+
q1 = cirq.LineQubit(0)
131+
q2 = cirq.LineQubit(1)
132+
state = cirq.DensityMatrixSimulationState(qubits=[q1, q2])
133+
134+
new_state = state.remove_qubits([q1])
135+
136+
assert len(new_state.qubits) == 1
137+
assert q1 not in new_state.qubits

cirq-core/cirq/sim/simulation_state.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,38 @@ def create_merged_state(self) -> Self:
166166
"""Creates a final merged state."""
167167
return self
168168

169+
def add_qubits(self: Self, qubits: Sequence['cirq.Qid']):
170+
"""Add qubits to a new state space and take the kron product.
171+
172+
Note that only Density Matrix and State Vector simulators
173+
override this function.
174+
175+
Args:
176+
qubits: Sequence of qubits to be added.
177+
178+
Returns:
179+
NotImplemented: If the subclass does not implement this method.
180+
181+
Raises:
182+
ValueError: If a qubit being added is already tracked.
183+
"""
184+
if any(q in self.qubits for q in qubits):
185+
raise ValueError(f"Qubit to add {qubits} should not already be tracked.")
186+
return NotImplemented
187+
188+
def remove_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> Self:
189+
"""Remove qubits from the state space.
190+
191+
Args:
192+
qubits: Sequence of qubits to be added.
193+
194+
Returns:
195+
A new Simulation State with qubits removed. Or
196+
`self` if there are no qubits to remove."""
197+
if qubits is None or not qubits:
198+
return self
199+
return NotImplemented
200+
169201
def kronecker_product(self, other: Self, *, inplace=False) -> Self:
170202
"""Joins two state spaces together."""
171203
args = self if inplace else copy.copy(self)
@@ -294,13 +326,24 @@ def strat_act_on_from_apply_decompose(
294326
val: Any, args: 'cirq.SimulationState', qubits: Sequence['cirq.Qid']
295327
) -> bool:
296328
operations, qubits1, _ = _try_decompose_into_operations_and_qubits(val)
297-
assert len(qubits1) == len(qubits)
298-
qubit_map = {q: qubits[i] for i, q in enumerate(qubits1)}
299329
if operations is None:
300330
return NotImplemented
331+
assert len(qubits1) == len(qubits)
332+
all_qubits = frozenset([q for op in operations for q in op.qubits])
333+
qubit_map = dict(zip(all_qubits, all_qubits))
334+
qubit_map.update(dict(zip(qubits1, qubits)))
335+
new_ancilla = tuple(q for q in sorted(all_qubits.difference(qubits)) if q not in args.qubits)
336+
args = args.add_qubits(new_ancilla)
337+
if args is NotImplemented:
338+
return NotImplemented
301339
for operation in operations:
302340
operation = operation.with_qubits(*[qubit_map[q] for q in operation.qubits])
303341
protocols.act_on(operation, args)
342+
args = args.remove_qubits(new_ancilla)
343+
if args is NotImplemented: # coverage: ignore
344+
raise TypeError( # coverage: ignore
345+
f"{type(args)} implements `add_qubits` but not `remove_qubits`." # coverage: ignore
346+
) # coverage: ignore
304347
return True
305348

306349

cirq-core/cirq/sim/simulation_state_test.py

Lines changed: 165 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import cirq
2121
from cirq.sim import simulation_state
22+
from cirq.testing import PhaseUsingCleanAncilla, PhaseUsingDirtyAncilla
2223

2324

2425
class DummyQuantumState(cirq.QuantumStateRepresentation):
@@ -33,32 +34,90 @@ def reindex(self, axes):
3334

3435

3536
class DummySimulationState(cirq.SimulationState):
36-
def __init__(self):
37-
super().__init__(state=DummyQuantumState(), qubits=cirq.LineQubit.range(2))
37+
def __init__(self, qubits=cirq.LineQubit.range(2)):
38+
super().__init__(state=DummyQuantumState(), qubits=qubits)
3839

3940
def _act_on_fallback_(
4041
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
4142
) -> bool:
4243
return True
4344

4445

46+
class AncillaZ(cirq.Gate):
47+
def __init__(self, exponent=1):
48+
self._exponent = exponent
49+
50+
def num_qubits(self) -> int:
51+
return 1
52+
53+
def _decompose_(self, qubits):
54+
ancilla = cirq.NamedQubit('Ancilla')
55+
yield cirq.CX(qubits[0], ancilla)
56+
yield cirq.Z(ancilla) ** self._exponent
57+
yield cirq.CX(qubits[0], ancilla)
58+
59+
60+
class AncillaH(cirq.Gate):
61+
def __init__(self, exponent=1):
62+
self._exponent = exponent
63+
64+
def num_qubits(self) -> int:
65+
return 1
66+
67+
def _decompose_(self, qubits):
68+
ancilla = cirq.NamedQubit('Ancilla')
69+
yield cirq.H(ancilla) ** self._exponent
70+
yield cirq.CX(ancilla, qubits[0])
71+
yield cirq.H(ancilla) ** self._exponent
72+
73+
74+
class AncillaY(cirq.Gate):
75+
def __init__(self, exponent=1):
76+
self._exponent = exponent
77+
78+
def num_qubits(self) -> int:
79+
return 1
80+
81+
def _decompose_(self, qubits):
82+
ancilla = cirq.NamedQubit('Ancilla')
83+
yield cirq.Y(ancilla) ** self._exponent
84+
yield cirq.CX(ancilla, qubits[0])
85+
yield cirq.Y(ancilla) ** self._exponent
86+
87+
88+
class DelegatingAncillaZ(cirq.Gate):
89+
def __init__(self, exponent=1):
90+
self._exponent = exponent
91+
92+
def num_qubits(self) -> int:
93+
return 1
94+
95+
def _decompose_(self, qubits):
96+
a = cirq.NamedQubit('a')
97+
yield cirq.CX(qubits[0], a)
98+
yield AncillaZ(self._exponent).on(a)
99+
yield cirq.CX(qubits[0], a)
100+
101+
102+
class Composite(cirq.Gate):
103+
def num_qubits(self) -> int:
104+
return 1
105+
106+
def _decompose_(self, qubits):
107+
yield cirq.X(*qubits)
108+
109+
45110
def test_measurements():
46111
args = DummySimulationState()
47112
args.measure([cirq.LineQubit(0)], "test", [False], {})
48113
assert args.log_of_measurement_results["test"] == [5]
49114

50115

51116
def test_decompose():
52-
class Composite(cirq.Gate):
53-
def num_qubits(self) -> int:
54-
return 1
55-
56-
def _decompose_(self, qubits):
57-
yield cirq.X(*qubits)
58-
59117
args = DummySimulationState()
60-
assert simulation_state.strat_act_on_from_apply_decompose(
61-
Composite(), args, [cirq.LineQubit(0)]
118+
assert (
119+
simulation_state.strat_act_on_from_apply_decompose(Composite(), args, [cirq.LineQubit(0)])
120+
is NotImplemented
62121
)
63122

64123

@@ -101,3 +160,98 @@ def test_field_getters():
101160
args = DummySimulationState()
102161
assert args.prng is np.random
103162
assert args.qubit_map == {q: i for i, q in enumerate(cirq.LineQubit.range(2))}
163+
164+
165+
@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
166+
def test_ancilla_z(exp):
167+
q = cirq.LineQubit(0)
168+
test_circuit = cirq.Circuit(AncillaZ(exp).on(q))
169+
170+
control_circuit = cirq.Circuit(cirq.ZPowGate(exponent=exp).on(q))
171+
172+
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)
173+
174+
175+
@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
176+
def test_ancilla_y(exp):
177+
q = cirq.LineQubit(0)
178+
test_circuit = cirq.Circuit(AncillaY(exp).on(q))
179+
180+
control_circuit = cirq.Circuit(cirq.Y(q))
181+
control_circuit.append(cirq.Y(q))
182+
control_circuit.append(cirq.XPowGate(exponent=exp).on(q))
183+
184+
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)
185+
186+
187+
@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
188+
def test_borrowable_qubit(exp):
189+
q = cirq.LineQubit(0)
190+
test_circuit = cirq.Circuit()
191+
test_circuit.append(cirq.H(q))
192+
test_circuit.append(cirq.X(q))
193+
test_circuit.append(AncillaH(exp).on(q))
194+
195+
control_circuit = cirq.Circuit(cirq.H(q))
196+
197+
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)
198+
199+
200+
@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
201+
def test_delegating_gate_qubit(exp):
202+
q = cirq.LineQubit(0)
203+
204+
test_circuit = cirq.Circuit()
205+
test_circuit.append(cirq.H(q))
206+
test_circuit.append(DelegatingAncillaZ(exp).on(q))
207+
208+
control_circuit = cirq.Circuit(cirq.H(q))
209+
control_circuit.append(cirq.ZPowGate(exponent=exp).on(q))
210+
211+
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)
212+
213+
214+
@pytest.mark.parametrize('num_ancilla', [1, 2, 3])
215+
def test_phase_using_dirty_ancilla(num_ancilla: int):
216+
q = cirq.LineQubit(0)
217+
anc = cirq.NamedQubit.range(num_ancilla, prefix='anc')
218+
219+
u = cirq.MatrixGate(cirq.testing.random_unitary(2 ** (num_ancilla + 1)))
220+
test_circuit = cirq.Circuit(
221+
u.on(q, *anc), PhaseUsingDirtyAncilla(ancilla_bitsize=num_ancilla).on(q)
222+
)
223+
control_circuit = cirq.Circuit(u.on(q, *anc), cirq.Z(q))
224+
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)
225+
226+
227+
@pytest.mark.parametrize('num_ancilla', [1, 2, 3])
228+
@pytest.mark.parametrize('theta', np.linspace(0, 2 * np.pi, 10))
229+
def test_phase_using_clean_ancilla(num_ancilla: int, theta: float):
230+
q = cirq.LineQubit(0)
231+
u = cirq.MatrixGate(cirq.testing.random_unitary(2))
232+
test_circuit = cirq.Circuit(
233+
u.on(q), PhaseUsingCleanAncilla(theta=theta, ancilla_bitsize=num_ancilla).on(q)
234+
)
235+
control_circuit = cirq.Circuit(u.on(q), cirq.ZPowGate(exponent=theta).on(q))
236+
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)
237+
238+
239+
def test_add_qubits_raise_value_error(num_ancilla=1):
240+
q = cirq.LineQubit(0)
241+
args = cirq.StateVectorSimulationState(qubits=[q])
242+
243+
with pytest.raises(ValueError, match='should not already be tracked.'):
244+
args.add_qubits([q])
245+
246+
247+
def test_remove_qubits_not_implemented(num_ancilla=1):
248+
args = DummySimulationState()
249+
250+
assert args.remove_qubits([cirq.LineQubit(0)]) is NotImplemented
251+
252+
253+
def assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit) -> None:
254+
for test_simulator in ['cirq.final_state_vector', 'cirq.final_density_matrix']:
255+
test_sim = eval(test_simulator)(test_circuit)
256+
control_sim = eval(test_simulator)(control_circuit)
257+
assert np.allclose(test_sim, control_sim)

cirq-core/cirq/sim/state_vector_simulation_state.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,22 @@ def __init__(
355355
)
356356
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)
357357

358+
def add_qubits(self, qubits: Sequence['cirq.Qid']):
359+
ret = super().add_qubits(qubits)
360+
return (
361+
self.kronecker_product(type(self)(qubits=qubits), inplace=True)
362+
if ret is NotImplemented
363+
else ret
364+
)
365+
366+
def remove_qubits(self, qubits: Sequence['cirq.Qid']):
367+
ret = super().remove_qubits(qubits)
368+
if ret is not NotImplemented:
369+
return ret
370+
extracted, remainder = self.factor(qubits, inplace=True)
371+
remainder._state._state_vector *= extracted._state._state_vector.reshape((-1,))[0]
372+
return remainder
373+
358374
def _act_on_fallback_(
359375
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
360376
) -> bool:

0 commit comments

Comments
 (0)