-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Cache Circuit properties between mutations
#6322
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
d511136
f88b750
592bfed
523a83e
975a298
73e895d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -188,28 +188,20 @@ def _from_moments(cls: Type[CIRCUIT_TYPE], moments: Iterable['cirq.Moment']) -> | |
| def moments(self) -> Sequence['cirq.Moment']: | ||
| pass | ||
|
|
||
| @abc.abstractmethod | ||
| def freeze(self) -> 'cirq.FrozenCircuit': | ||
| """Creates a FrozenCircuit from this circuit. | ||
|
|
||
| If 'self' is a FrozenCircuit, the original object is returned. | ||
| """ | ||
| from cirq.circuits import FrozenCircuit | ||
|
|
||
| if isinstance(self, FrozenCircuit): | ||
| return self | ||
|
|
||
| return FrozenCircuit(self, strategy=InsertStrategy.EARLIEST) | ||
|
|
||
| @abc.abstractmethod | ||
| def unfreeze(self, copy: bool = True) -> 'cirq.Circuit': | ||
| """Creates a Circuit from this circuit. | ||
|
|
||
| Args: | ||
| copy: If True and 'self' is a Circuit, returns a copy that circuit. | ||
| """ | ||
| if isinstance(self, Circuit): | ||
| return Circuit.copy(self) if copy else self | ||
|
|
||
| return Circuit(self, strategy=InsertStrategy.EARLIEST) | ||
|
|
||
| def __bool__(self): | ||
| return bool(self.moments) | ||
|
|
@@ -822,6 +814,9 @@ def has_measurements(self): | |
| """ | ||
| return protocols.is_measurement(self) | ||
|
|
||
| def _is_measurement_(self) -> bool: | ||
| return any(protocols.is_measurement(op) for op in self.all_operations()) | ||
|
|
||
| def are_all_measurements_terminal(self) -> bool: | ||
| """Whether all measurement gates are at the end of the circuit. | ||
|
|
||
|
|
@@ -1383,8 +1378,7 @@ def save_qasm( | |
| self._to_qasm_output(header, precision, qubit_order).save(file_path) | ||
|
|
||
| def _json_dict_(self): | ||
| ret = protocols.obj_to_dict_helper(self, ['moments']) | ||
| return ret | ||
| return protocols.obj_to_dict_helper(self, ['moments']) | ||
|
|
||
| @classmethod | ||
| def _from_json_dict_(cls, moments, **kwargs): | ||
|
|
@@ -1759,6 +1753,16 @@ def __init__( | |
| circuit. | ||
| """ | ||
| self._moments: List['cirq.Moment'] = [] | ||
|
|
||
| # Implementation note: the following cached properties are set lazily and then | ||
| # invalidated and reset to None in `self._mutated()`, which is called any time | ||
| # `self._moments` is changed. | ||
| self._all_qubits: Optional[FrozenSet['cirq.Qid']] = None | ||
| self._frozen: Optional['cirq.FrozenCircuit'] = None | ||
| self._is_measurement: Optional[bool] = None | ||
| self._is_parameterized: Optional[bool] = None | ||
| self._parameter_names: Optional[AbstractSet[str]] = None | ||
|
|
||
| flattened_contents = tuple(ops.flatten_to_ops_or_moments(contents)) | ||
| if all(isinstance(c, Moment) for c in flattened_contents): | ||
| self._moments[:] = cast(Iterable[Moment], flattened_contents) | ||
|
|
@@ -1769,6 +1773,14 @@ def __init__( | |
| else: | ||
| self.append(flattened_contents, strategy=strategy) | ||
|
|
||
| def _mutated(self) -> None: | ||
| """Clear cached properties in response to this circuit being mutated.""" | ||
| self._all_qubits = None | ||
| self._frozen = None | ||
| self._is_measurement = None | ||
| self._is_parameterized = None | ||
| self._parameter_names = None | ||
|
|
||
| @classmethod | ||
| def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'Circuit': | ||
| new_circuit = Circuit() | ||
|
|
@@ -1831,6 +1843,41 @@ def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'): | |
| def __copy__(self) -> 'cirq.Circuit': | ||
| return self.copy() | ||
|
|
||
| def freeze(self) -> 'cirq.FrozenCircuit': | ||
| """Gets a frozen version of this circuit. | ||
|
|
||
| Repeated calls to `.freeze()` will return the same FrozenCircuit | ||
| instance as long as this circuit is not mutated. | ||
| """ | ||
| from cirq.circuits import FrozenCircuit | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: should this be a full import path
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was copied from what we had in |
||
|
|
||
| if self._frozen is None: | ||
| self._frozen = FrozenCircuit.from_moments(*self._moments) | ||
| return self._frozen | ||
|
|
||
| def unfreeze(self, copy: bool = True) -> 'cirq.Circuit': | ||
| return self.copy() if copy else self | ||
|
|
||
| def all_qubits(self) -> FrozenSet['cirq.Qid']: | ||
| if self._all_qubits is None: | ||
| self._all_qubits = super().all_qubits() | ||
| return self._all_qubits | ||
|
|
||
| def _is_measurement_(self) -> bool: | ||
| if self._is_measurement is None: | ||
| self._is_measurement = super()._is_measurement_() | ||
| return self._is_measurement | ||
|
|
||
| def _is_parameterized_(self) -> bool: | ||
| if self._is_parameterized is None: | ||
| self._is_parameterized = super()._is_parameterized_() | ||
| return self._is_parameterized | ||
|
|
||
| def _parameter_names_(self) -> AbstractSet[str]: | ||
| if self._parameter_names is None: | ||
| self._parameter_names = super()._parameter_names_() | ||
| return self._parameter_names | ||
|
|
||
| def copy(self) -> 'Circuit': | ||
| """Return a copy of this circuit.""" | ||
| copied_circuit = Circuit() | ||
|
|
@@ -1856,11 +1903,13 @@ def __setitem__(self, key, value): | |
| raise TypeError('Can only assign Moments into Circuits.') | ||
|
|
||
| self._moments[key] = value | ||
| self._mutated() | ||
|
|
||
| # pylint: enable=function-redefined | ||
|
|
||
| def __delitem__(self, key: Union[int, slice]): | ||
| del self._moments[key] | ||
| self._mutated() | ||
|
|
||
| def __iadd__(self, other): | ||
| self.append(other) | ||
|
|
@@ -1889,6 +1938,7 @@ def __imul__(self, repetitions: _INT_TYPE): | |
| if not isinstance(repetitions, (int, np.integer)): | ||
| return NotImplemented | ||
| self._moments *= int(repetitions) | ||
| self._mutated() | ||
| return self | ||
|
|
||
| def __mul__(self, repetitions: _INT_TYPE): | ||
|
|
@@ -2032,6 +2082,7 @@ def _pick_or_create_inserted_op_moment_index( | |
|
|
||
| if strategy is InsertStrategy.NEW or strategy is InsertStrategy.NEW_THEN_INLINE: | ||
| self._moments.insert(splitter_index, Moment()) | ||
| self._mutated() | ||
| return splitter_index | ||
|
|
||
| if strategy is InsertStrategy.INLINE: | ||
|
|
@@ -2099,6 +2150,7 @@ def insert( | |
| k = max(k, p + 1) | ||
| if strategy is InsertStrategy.NEW_THEN_INLINE: | ||
| strategy = InsertStrategy.INLINE | ||
| self._mutated() | ||
| return k | ||
|
|
||
| def insert_into_range(self, operations: 'cirq.OP_TREE', start: int, end: int) -> int: | ||
|
|
@@ -2135,6 +2187,7 @@ def insert_into_range(self, operations: 'cirq.OP_TREE', start: int, end: int) -> | |
|
|
||
| self._moments[i] = self._moments[i].with_operation(op) | ||
| op_index += 1 | ||
| self._mutated() | ||
|
|
||
| if op_index >= len(flat_ops): | ||
| return end | ||
|
|
@@ -2180,6 +2233,7 @@ def _push_frontier( | |
| if n_new_moments > 0: | ||
| insert_index = min(late_frontier.values()) | ||
| self._moments[insert_index:insert_index] = [Moment()] * n_new_moments | ||
| self._mutated() | ||
| for q in update_qubits: | ||
| if early_frontier.get(q, 0) > insert_index: | ||
| early_frontier[q] += n_new_moments | ||
|
|
@@ -2206,13 +2260,12 @@ def _insert_operations( | |
| if len(operations) != len(insertion_indices): | ||
| raise ValueError('operations and insertion_indices must have the same length.') | ||
| self._moments += [Moment() for _ in range(1 + max(insertion_indices) - len(self))] | ||
| self._mutated() | ||
| moment_to_ops: Dict[int, List['cirq.Operation']] = defaultdict(list) | ||
| for op_index, moment_index in enumerate(insertion_indices): | ||
| moment_to_ops[moment_index].append(operations[op_index]) | ||
| for moment_index, new_ops in moment_to_ops.items(): | ||
| self._moments[moment_index] = Moment( | ||
| self._moments[moment_index].operations + tuple(new_ops) | ||
| ) | ||
| self._moments[moment_index] = self._moments[moment_index].with_operations(*new_ops) | ||
|
|
||
| def insert_at_frontier( | ||
| self, | ||
|
|
@@ -2274,6 +2327,7 @@ def batch_remove(self, removals: Iterable[Tuple[int, 'cirq.Operation']]) -> None | |
| old_op for old_op in copy._moments[i].operations if op != old_op | ||
| ) | ||
| self._moments = copy._moments | ||
| self._mutated() | ||
|
|
||
| def batch_replace( | ||
| self, replacements: Iterable[Tuple[int, 'cirq.Operation', 'cirq.Operation']] | ||
|
|
@@ -2298,6 +2352,7 @@ def batch_replace( | |
| old_op if old_op != op else new_op for old_op in copy._moments[i].operations | ||
| ) | ||
| self._moments = copy._moments | ||
| self._mutated() | ||
|
|
||
| def batch_insert_into(self, insert_intos: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None: | ||
| """Inserts operations into empty spaces in existing moments. | ||
|
|
@@ -2318,6 +2373,7 @@ def batch_insert_into(self, insert_intos: Iterable[Tuple[int, 'cirq.OP_TREE']]) | |
| for i, insertions in insert_intos: | ||
| copy._moments[i] = copy._moments[i].with_operations(insertions) | ||
| self._moments = copy._moments | ||
| self._mutated() | ||
|
|
||
| def batch_insert(self, insertions: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None: | ||
| """Applies a batched insert operation to the circuit. | ||
|
|
@@ -2352,6 +2408,7 @@ def batch_insert(self, insertions: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None | |
| if next_index > insert_index: | ||
| shift += next_index - insert_index | ||
| self._moments = copy._moments | ||
| self._mutated() | ||
|
|
||
| def append( | ||
| self, | ||
|
|
@@ -2382,6 +2439,7 @@ def clear_operations_touching( | |
| for k in moment_indices: | ||
| if 0 <= k < len(self._moments): | ||
| self._moments[k] = self._moments[k].without_operations_touching(qubits) | ||
| self._mutated() | ||
|
|
||
| @property | ||
| def moments(self) -> Sequence['cirq.Moment']: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shouldn't this also include
_is_measurementand_parameter_names?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch. Fixed.