Skip to content
65 changes: 40 additions & 25 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,30 @@ class AbstractCircuit(abc.ABC):
* get_independent_qubit_sets
"""

@classmethod
def from_moments(cls: Type[CIRCUIT_TYPE], *moments: 'cirq.OP_TREE') -> CIRCUIT_TYPE:
"""Create a circuit from moment op trees.

Args:
*moments: Op tree for each moment.
"""
return cls._from_moments(
moment if isinstance(moment, Moment) else Moment(moment) for moment in moments
)

@classmethod
@abc.abstractmethod
def _from_moments(cls: Type[CIRCUIT_TYPE], moments: Iterable['cirq.Moment']) -> CIRCUIT_TYPE:
"""Create a circuit from moments.

This must be implemented by subclasses. It provides a more efficient way
to construct a circuit instance since we already have the moments and so
can skip the analysis required to implement various insert strategies.

Args:
moments: Moments of the circuit.
"""

@property
@abc.abstractmethod
def moments(self) -> Sequence['cirq.Moment']:
Expand Down Expand Up @@ -225,8 +249,7 @@ def __getitem__(self: CIRCUIT_TYPE, key: Tuple[slice, Iterable['cirq.Qid']]) ->

def __getitem__(self, key):
if isinstance(key, slice):
sliced_moments = self.moments[key]
return self._with_sliced_moments(sliced_moments)
return self._from_moments(self.moments[key])
if hasattr(key, '__index__'):
return self.moments[key]
if isinstance(key, tuple):
Expand All @@ -239,17 +262,12 @@ def __getitem__(self, key):
return selected_moments[qubit_idx]
if isinstance(qubit_idx, ops.Qid):
qubit_idx = [qubit_idx]
sliced_moments = [moment[qubit_idx] for moment in selected_moments]
return self._with_sliced_moments(sliced_moments)
return self._from_moments(moment[qubit_idx] for moment in selected_moments)

raise TypeError('__getitem__ called with key not of type slice, int, or tuple.')

# pylint: enable=function-redefined

@abc.abstractmethod
def _with_sliced_moments(self: CIRCUIT_TYPE, moments: Iterable['cirq.Moment']) -> CIRCUIT_TYPE:
"""Helper method for constructing circuits from __getitem__."""

def __str__(self) -> str:
return self.to_text_diagram()

Expand Down Expand Up @@ -909,7 +927,7 @@ def map_moment(moment: 'cirq.Moment') -> 'cirq.Circuit':
"""Apply func to expand each op into a circuit, then zip up the circuits."""
return Circuit.zip(*[Circuit(func(op)) for op in moment])

return self._with_sliced_moments(m for moment in self for m in map_moment(moment))
return self._from_moments(m for moment in self for m in map_moment(moment))

def qid_shape(
self, qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT
Expand Down Expand Up @@ -949,18 +967,16 @@ def _measurement_key_names_(self) -> FrozenSet[str]:
return self.all_measurement_key_names()

def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]):
return self._with_sliced_moments(
[protocols.with_measurement_key_mapping(moment, key_map) for moment in self.moments]
return self._from_moments(
protocols.with_measurement_key_mapping(moment, key_map) for moment in self.moments
)

def _with_key_path_(self, path: Tuple[str, ...]):
return self._with_sliced_moments(
[protocols.with_key_path(moment, path) for moment in self.moments]
)
return self._from_moments(protocols.with_key_path(moment, path) for moment in self.moments)

def _with_key_path_prefix_(self, prefix: Tuple[str, ...]):
return self._with_sliced_moments(
[protocols.with_key_path_prefix(moment, prefix) for moment in self.moments]
return self._from_moments(
protocols.with_key_path_prefix(moment, prefix) for moment in self.moments
)

def _with_rescoped_keys_(
Expand All @@ -971,7 +987,7 @@ def _with_rescoped_keys_(
new_moment = protocols.with_rescoped_keys(moment, path, bindable_keys)
moments.append(new_moment)
bindable_keys |= protocols.measurement_key_objs(new_moment)
return self._with_sliced_moments(moments)
return self._from_moments(moments)

def _qid_shape_(self) -> Tuple[int, ...]:
return self.qid_shape()
Expand Down Expand Up @@ -1552,9 +1568,7 @@ def factorize(self: CIRCUIT_TYPE) -> Iterable[CIRCUIT_TYPE]:
# the qubits from one factor belong to a specific independent qubit set.
# This makes it possible to create independent circuits based on these
# moments.
return (
self._with_sliced_moments([m[qubits] for m in self.moments]) for qubits in qubit_factors
)
return (self._from_moments(m[qubits] for m in self.moments) for qubits in qubit_factors)

def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
controls = frozenset(k for op in self.all_operations() for k in protocols.control_keys(op))
Expand Down Expand Up @@ -1719,6 +1733,12 @@ def __init__(
else:
self.append(contents, strategy=strategy)

@classmethod
def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'Circuit':
new_circuit = Circuit()
new_circuit._moments[:] = moments
return new_circuit

def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):
"""Optimized algorithm to load contents quickly.

Expand Down Expand Up @@ -1813,11 +1833,6 @@ def copy(self) -> 'Circuit':
copied_circuit._moments = self._moments[:]
return copied_circuit

def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'Circuit':
new_circuit = Circuit()
new_circuit._moments = list(moments)
return new_circuit

# pylint: disable=function-redefined
@overload
def __setitem__(self, key: int, value: 'cirq.Moment'):
Expand Down
17 changes: 17 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,23 @@ def validate_moment(self, moment):
moment_and_op_type_validating_device = _MomentAndOpTypeValidatingDeviceType()


def test_from_moments():
a, b, c, d = cirq.LineQubit.range(4)
assert cirq.Circuit.from_moments(
[cirq.X(a), cirq.Y(b)],
[cirq.X(c)],
[],
cirq.Z(d),
[cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')],
) == cirq.Circuit(
cirq.Moment(cirq.X(a), cirq.Y(b)),
cirq.Moment(cirq.X(c)),
cirq.Moment(),
cirq.Moment(cirq.Z(d)),
cirq.Moment(cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')),
)


def test_alignment():
assert repr(cirq.Alignment.LEFT) == 'cirq.Alignment.LEFT'
assert repr(cirq.Alignment.RIGHT) == 'cirq.Alignment.RIGHT'
Expand Down
13 changes: 7 additions & 6 deletions cirq-core/cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""An immutable version of the Circuit data structure."""
from typing import TYPE_CHECKING, FrozenSet, Iterable, Iterator, Sequence, Tuple, Union
from typing import FrozenSet, Iterable, Iterator, Sequence, Tuple, TYPE_CHECKING, Union

import numpy as np

Expand Down Expand Up @@ -51,6 +51,12 @@ def __init__(
base = Circuit(contents, strategy=strategy)
self._moments = tuple(base.moments)

@classmethod
def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
new_circuit = FrozenCircuit()
new_circuit._moments = tuple(moments)
return new_circuit

@property
def moments(self) -> Sequence['cirq.Moment']:
return self._moments
Expand Down Expand Up @@ -143,11 +149,6 @@ def __pow__(self, other) -> 'cirq.FrozenCircuit':
except:
return NotImplemented

def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
new_circuit = FrozenCircuit()
new_circuit._moments = tuple(moments)
return new_circuit

def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'cirq.FrozenCircuit':
Expand Down
17 changes: 17 additions & 0 deletions cirq-core/cirq/circuits/frozen_circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,23 @@
import cirq


def test_from_moments():
a, b, c, d = cirq.LineQubit.range(4)
assert cirq.FrozenCircuit.from_moments(
[cirq.X(a), cirq.Y(b)],
[cirq.X(c)],
[],
cirq.Z(d),
[cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')],
) == cirq.FrozenCircuit(
cirq.Moment(cirq.X(a), cirq.Y(b)),
cirq.Moment(cirq.X(c)),
cirq.Moment(),
cirq.Moment(cirq.Z(d)),
cirq.Moment(cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')),
)


def test_freeze_and_unfreeze():
a, b = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.X(a), cirq.H(b))
Expand Down