Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions cirq-core/cirq/sim/simulation_product_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,19 @@ def split_untangled_states(self) -> bool:
return self._split_untangled_states

def create_merged_state(self) -> TSimulationState:
merged_state = self.sim_states[None]
if not self.split_untangled_states:
return self.sim_states[None]
final_args = self.sim_states[None]
for args in set([self.sim_states[k] for k in self.sim_states.keys() if k is not None]):
final_args = final_args.kronecker_product(args)
return final_args.transpose_to_qubit_order(self.qubits)
return merged_state
extra_states = set([self.sim_states[k] for k in self.sim_states.keys() if k is not None])
if not extra_states:
return merged_state

# This comes from a member variable so we need to copy it if we're going to modify inplace
# before returning. We're not running a step currently, so no need to copy buffers.
merged_state = merged_state.copy(deep_copy_buffers=False)
for state in extra_states:
merged_state.kronecker_product(state, inplace=True)
return merged_state.transpose_to_qubit_order(self.qubits, inplace=True)

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
Expand Down Expand Up @@ -106,7 +113,7 @@ def _act_on_fallback_(
if op_args_opt is None:
op_args_opt = self.sim_states[q]
elif q not in op_args_opt.qubits:
op_args_opt = op_args_opt.kronecker_product(self.sim_states[q])
op_args_opt.kronecker_product(self.sim_states[q], inplace=True)
op_args = op_args_opt or self.sim_states[None]

# (Backfill the args map with the new value)
Expand All @@ -123,7 +130,7 @@ def _act_on_fallback_(
):
for q in qubits:
if op_args.allows_factoring and len(op_args.qubits) > 1:
q_args, op_args = op_args.factor((q,), validate=False)
q_args, _ = op_args.factor((q,), validate=False, inplace=True)
self._sim_states[q] = q_args

# (Backfill the args map with the new value)
Expand Down