1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414"""Objects and methods for acting efficiently on a state tensor."""
15- import abc
1615import copy
1716import inspect
17+ import warnings
1818from typing import (
1919 Any ,
2020 cast ,
2828 TYPE_CHECKING ,
2929 Tuple ,
3030)
31- import warnings
3231
3332import numpy as np
3433
@@ -59,6 +58,7 @@ def __init__(
5958 log_of_measurement_results : Optional [Dict [str , List [int ]]] = None ,
6059 ignore_measurement_results : bool = False ,
6160 classical_data : Optional ['cirq.ClassicalDataStore' ] = None ,
61+ state : Optional ['cirq.QuantumStateRepresentation' ] = None ,
6262 ):
6363 """Inits ActOnArgs.
6464
@@ -76,6 +76,7 @@ def __init__(
7676 simulators that can represent mixed states.
7777 classical_data: The shared classical data container for this
7878 simulation.
79+ state: The underlying quantum state of the simulation.
7980 """
8081 if prng is None :
8182 prng = cast (np .random .RandomState , np .random )
@@ -90,6 +91,7 @@ def __init__(
9091 }
9192 )
9293 self ._ignore_measurement_results = ignore_measurement_results
94+ self ._state = state
9395
9496 @property
9597 def prng (self ) -> np .random .RandomState :
@@ -148,10 +150,21 @@ def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[
148150 def get_axes (self , qubits : Sequence ['cirq.Qid' ]) -> List [int ]:
149151 return [self .qubit_map [q ] for q in qubits ]
150152
151- @abc .abstractmethod
152153 def _perform_measurement (self , qubits : Sequence ['cirq.Qid' ]) -> List [int ]:
153- """Child classes that perform measurements should implement this with
154- the implementation."""
154+ """Delegates the call to measure the density matrix."""
155+ if self ._state is not None :
156+ return self ._state .measure (self .get_axes (qubits ), self .prng )
157+ raise NotImplementedError ()
158+
159+ def sample (
160+ self ,
161+ qubits : Sequence ['cirq.Qid' ],
162+ repetitions : int = 1 ,
163+ seed : 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None ,
164+ ) -> np .ndarray :
165+ if self ._state is not None :
166+ return self ._state .sample (self .get_axes (qubits ), repetitions , seed )
167+ raise NotImplementedError ()
155168
156169 def copy (self : TSelf , deep_copy_buffers : bool = True ) -> TSelf :
157170 """Creates a copy of the object.
@@ -165,6 +178,10 @@ def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
165178 A copied instance.
166179 """
167180 args = copy .copy (self )
181+ args ._classical_data = self ._classical_data .copy ()
182+ if self ._state is not None :
183+ args ._state = self ._state .copy (deep_copy_buffers = deep_copy_buffers )
184+ return args
168185 if 'deep_copy_buffers' in inspect .signature (self ._on_copy ).parameters :
169186 self ._on_copy (args , deep_copy_buffers )
170187 else :
@@ -176,7 +193,6 @@ def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
176193 DeprecationWarning ,
177194 )
178195 self ._on_copy (args )
179- args ._classical_data = self ._classical_data .copy ()
180196 return args
181197
182198 def _on_copy (self : TSelf , args : TSelf , deep_copy_buffers : bool = True ):
@@ -190,7 +206,10 @@ def create_merged_state(self: TSelf) -> TSelf:
190206 def kronecker_product (self : TSelf , other : TSelf , * , inplace = False ) -> TSelf :
191207 """Joins two state spaces together."""
192208 args = self if inplace else copy .copy (self )
193- self ._on_kronecker_product (other , args )
209+ if self ._state is not None and other ._state is not None :
210+ args ._state = self ._state .kron (other ._state )
211+ else :
212+ self ._on_kronecker_product (other , args )
194213 args ._set_qubits (self .qubits + other .qubits )
195214 return args
196215
@@ -225,15 +244,20 @@ def factor(
225244 """Splits two state spaces after a measurement or reset."""
226245 extracted = copy .copy (self )
227246 remainder = self if inplace else copy .copy (self )
228- self ._on_factor (qubits , extracted , remainder , validate , atol )
247+ if self ._state is not None :
248+ e , r = self ._state .factor (self .get_axes (qubits ), validate = validate , atol = atol )
249+ extracted ._state = e
250+ remainder ._state = r
251+ else :
252+ self ._on_factor (qubits , extracted , remainder , validate , atol )
229253 extracted ._set_qubits (qubits )
230254 remainder ._set_qubits ([q for q in self .qubits if q not in qubits ])
231255 return extracted , remainder
232256
233257 @property
234258 def allows_factoring (self ):
235259 """Subclasses that allow factorization should override this."""
236- return False
260+ return self . _state . supports_factor if self . _state is not None else False
237261
238262 def _on_factor (
239263 self : TSelf ,
@@ -265,7 +289,10 @@ def transpose_to_qubit_order(
265289 if len (self .qubits ) != len (qubits ) or set (qubits ) != set (self .qubits ):
266290 raise ValueError (f'Qubits do not match. Existing: { self .qubits } , provided: { qubits } ' )
267291 args = self if inplace else copy .copy (self )
268- self ._on_transpose_to_qubit_order (qubits , args )
292+ if self ._state is not None :
293+ args ._state = self ._state .reindex (self .get_axes (qubits ))
294+ else :
295+ self ._on_transpose_to_qubit_order (qubits , args )
269296 args ._set_qubits (qubits )
270297 return args
271298
@@ -356,7 +383,7 @@ def __iter__(self) -> Iterator[Optional['cirq.Qid']]:
356383
357384 @property
358385 def can_represent_mixed_states (self ) -> bool :
359- return False
386+ return self . _state . can_represent_mixed_states if self . _state is not None else False
360387
361388
362389def strat_act_on_from_apply_decompose (
0 commit comments