Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 14 additions & 20 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1284,6 +1284,20 @@ def _is_parameterized_(self) -> bool:
def _parameter_names_(self) -> AbstractSet[str]:
return {name for op in self.all_operations() for name in protocols.parameter_names(op)}

def _resolve_parameters_(
self: CIRCUIT_TYPE, resolver: 'cirq.ParamResolver', recursive: bool
) -> CIRCUIT_TYPE:
changed = False
resolved_moments: List['cirq.Moment'] = []
for moment in self:
resolved_moment = protocols.resolve_parameters(moment, resolver, recursive)
if resolved_moment is not moment:
changed = True
resolved_moments.append(resolved_moment)
if not changed:
return self
return self._from_moments(resolved_moments)

def _qasm_(self) -> str:
return self.to_qasm()

Expand Down Expand Up @@ -2377,17 +2391,6 @@ def clear_operations_touching(
if 0 <= k < len(self._moments):
self._moments[k] = self._moments[k].without_operations_touching(qubits)

def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'cirq.Circuit':
resolved_moments = []
for moment in self:
resolved_operations = _resolve_operations(moment.operations, resolver, recursive)
new_moment = Moment(resolved_operations)
resolved_moments.append(new_moment)

return Circuit(resolved_moments)

@property
def moments(self) -> Sequence['cirq.Moment']:
return self._moments
Expand Down Expand Up @@ -2441,15 +2444,6 @@ def _pick_inserted_ops_moment_indices(
return moment_indices, frontier


def _resolve_operations(
operations: Iterable['cirq.Operation'], param_resolver: 'cirq.ParamResolver', recursive: bool
) -> List['cirq.Operation']:
resolved_operations: List['cirq.Operation'] = []
for op in operations:
resolved_operations.append(protocols.resolve_parameters(op, param_resolver, recursive))
return resolved_operations


def _get_moment_annotations(moment: 'cirq.Moment') -> Iterator['cirq.Operation']:
for op in moment.operations:
if op.qubits:
Expand Down
17 changes: 17 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3039,6 +3039,23 @@ def test_resolve_parameters(circuit_cls, resolve_fn):
cirq.testing.assert_same_circuits(expected_circuit, resolved_circuit)


@pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit])
@pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once])
def test_resolve_parameters_no_change(circuit_cls, resolve_fn):
a, b = cirq.LineQubit.range(2)
circuit = circuit_cls(cirq.CZ(a, b), cirq.X(a), cirq.Y(b))
resolved_circuit = resolve_fn(circuit, cirq.ParamResolver({'u': 0.1, 'v': 0.3, 'w': 0.2}))
assert resolved_circuit is circuit

circuit = circuit_cls(
cirq.CZ(a, b) ** sympy.Symbol('u'),
cirq.X(a) ** sympy.Symbol('v'),
cirq.Y(b) ** sympy.Symbol('w'),
)
resolved_circuit = resolve_fn(circuit, cirq.ParamResolver({}))
assert resolved_circuit is circuit


@pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit])
@pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once])
def test_parameter_names(circuit_cls, resolve_fn):
Expand Down
5 changes: 0 additions & 5 deletions cirq-core/cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,6 @@ def __pow__(self, other) -> 'cirq.FrozenCircuit':
except:
return NotImplemented

def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'cirq.FrozenCircuit':
return self.unfreeze()._resolve_parameters_(resolver, recursive).freeze()

def concat_ragged(
*circuits: 'cirq.AbstractCircuit', align: Union['cirq.Alignment', str] = Alignment.LEFT
) -> 'cirq.FrozenCircuit':
Expand Down
22 changes: 22 additions & 0 deletions cirq-core/cirq/circuits/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@

import itertools
from typing import (
AbstractSet,
Any,
Callable,
Dict,
FrozenSet,
Iterable,
Iterator,
List,
Mapping,
overload,
Optional,
Expand Down Expand Up @@ -236,6 +238,26 @@ def without_operations_touching(self, qubits: Iterable['cirq.Qid']) -> 'cirq.Mom
if qubits.isdisjoint(frozenset(operation.qubits))
)

def _is_parameterized_(self) -> bool:
return any(protocols.is_parameterized(op) for op in self)

def _parameter_names_(self) -> AbstractSet[str]:
return {name for op in self for name in protocols.parameter_names(op)}

def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'cirq.Moment':
changed = False
resolved_ops: List['cirq.Operation'] = []
for op in self:
resolved_op = protocols.resolve_parameters(op, resolver, recursive)
if resolved_op != op:
changed = True
resolved_ops.append(resolved_op)
if not changed:
return self
return Moment(resolved_ops)

def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]):
return Moment(
protocols.with_measurement_key_mapping(op, key_map)
Expand Down
33 changes: 33 additions & 0 deletions cirq-core/cirq/circuits/moment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np
import pytest
import sympy

import cirq
import cirq.testing
Expand Down Expand Up @@ -274,6 +275,38 @@ def test_without_operations_touching():
)


def test_is_parameterized():
a, b = cirq.LineQubit.range(2)
moment = cirq.Moment(cirq.X(a) ** sympy.Symbol('v'), cirq.Y(b) ** sympy.Symbol('w'))
assert cirq.is_parameterized(moment)
assert not cirq.is_parameterized(cirq.Moment(cirq.X(a), cirq.Y(b)))


def test_resolve_parameters():
a, b = cirq.LineQubit.range(2)
moment = cirq.Moment(cirq.X(a) ** sympy.Symbol('v'), cirq.Y(b) ** sympy.Symbol('w'))
resolved_moment = cirq.resolve_parameters(moment, cirq.ParamResolver({'v': 0.1, 'w': 0.2}))
assert resolved_moment == cirq.Moment(cirq.X(a) ** 0.1, cirq.Y(b) ** 0.2)


def test_resolve_parameters_no_change():
a, b = cirq.LineQubit.range(2)
moment = cirq.Moment(cirq.X(a), cirq.Y(b))
resolved_moment = cirq.resolve_parameters(moment, cirq.ParamResolver({'v': 0.1, 'w': 0.2}))
assert resolved_moment is moment

moment = cirq.Moment(cirq.X(a) ** sympy.Symbol('v'), cirq.Y(b) ** sympy.Symbol('w'))
resolved_moment = cirq.resolve_parameters(moment, cirq.ParamResolver({}))
assert resolved_moment is moment


def test_parameter_names():
a, b = cirq.LineQubit.range(2)
moment = cirq.Moment(cirq.X(a) ** sympy.Symbol('v'), cirq.Y(b) ** sympy.Symbol('w'))
assert cirq.parameter_names(moment) == {'v', 'w'}
assert cirq.parameter_names(cirq.Moment(cirq.X(a), cirq.Y(b))) == set()


def test_with_measurement_keys():
a, b = cirq.LineQubit.range(2)
m = cirq.Moment(cirq.measure(a, key='m1'), cirq.measure(b, key='m2'))
Expand Down