Skip to content

Commit 0e5df4e

Browse files
authored
Fix mapped measurement keys of repeat_until fields in CircuitOperations (#6881)
* Fix repeat_until scoped keys Create a _mapped_repeat_until cached property that ensures repeat_until conditions have the path and key mappings of the subcircuit applied during execution time or when calculating control_keys. * Remove param resolver support from repeat_until Conditions don't support params, as any unresolved symbols are expected to be measurement keys. Perhaps a future feature to consider, but outside the scope of this PR. * Undo previous change, just add dummy test for resolve_parameters Built-in conditions don't support param resolvers, but custom ones plausibly could. May as well keep the functionality, and add a dummy test for coverage. * Improve tests * coverage * fix test after merge from "use_repetition_ids" changes * check for symbols in repeat_until * doc
1 parent e6c1101 commit 0e5df4e

File tree

2 files changed

+87
-12
lines changed

2 files changed

+87
-12
lines changed

cirq-core/cirq/circuits/circuit_operation.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,12 @@ def __init__(
204204
)
205205

206206
self._repeat_until = repeat_until
207-
if self._repeat_until:
207+
mapped_repeat_until = self._mapped_repeat_until
208+
if mapped_repeat_until:
208209
if self._use_repetition_ids or self._repetitions != 1:
209210
raise ValueError('Cannot use repetitions with repeat_until')
210211
if protocols.measurement_key_objs(self._mapped_single_loop()).isdisjoint(
211-
self._repeat_until.keys
212+
mapped_repeat_until.keys
212213
):
213214
raise ValueError('Infinite loop: condition is not modified in subcircuit.')
214215

@@ -349,8 +350,9 @@ def _control_keys(self) -> FrozenSet['cirq.MeasurementKey']:
349350
if not protocols.control_keys(self.circuit)
350351
else protocols.control_keys(self._mapped_single_loop())
351352
)
352-
if self.repeat_until is not None:
353-
keys |= frozenset(self.repeat_until.keys) - self._measurement_key_objs_()
353+
mapped_repeat_until = self._mapped_repeat_until
354+
if mapped_repeat_until is not None:
355+
keys |= frozenset(mapped_repeat_until.keys) - self._measurement_key_objs_()
354356
return keys
355357

356358
def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
@@ -364,11 +366,8 @@ def _parameter_names_(self) -> FrozenSet[str]:
364366

365367
def _parameter_names_generator(self) -> Iterator[str]:
366368
yield from protocols.parameter_names(self.repetitions)
367-
for symbol in protocols.parameter_symbols(self.circuit):
368-
for name in protocols.parameter_names(
369-
protocols.resolve_parameters(symbol, self.param_resolver, recursive=False)
370-
):
371-
yield name
369+
yield from protocols.parameter_names(self._mapped_repeat_until)
370+
yield from protocols.parameter_names(self._mapped_any_loop)
372371

373372
@cached_property
374373
def _mapped_any_loop(self) -> 'cirq.Circuit':
@@ -391,6 +390,26 @@ def _mapped_single_loop(self, repetition_id: Optional[str] = None) -> 'cirq.Circ
391390
circuit, self.parent_path, bindable_keys=self._extern_keys
392391
)
393392

393+
@cached_property
394+
def _mapped_repeat_until(self) -> Optional['cirq.Condition']:
395+
"""Applies measurement_key_map, param_resolver, and current scope to repeat_until."""
396+
repeat_until = self.repeat_until
397+
if not repeat_until:
398+
return repeat_until
399+
if self.measurement_key_map:
400+
repeat_until = protocols.with_measurement_key_mapping(
401+
repeat_until, self.measurement_key_map
402+
)
403+
if self.param_resolver:
404+
repeat_until = protocols.resolve_parameters(
405+
repeat_until, self.param_resolver, recursive=False
406+
)
407+
return protocols.with_rescoped_keys(
408+
repeat_until,
409+
self.parent_path,
410+
bindable_keys=self._extern_keys | self._measurement_key_objs,
411+
)
412+
394413
def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit':
395414
"""Applies all maps to the contained circuit and returns the result.
396415
@@ -427,12 +446,13 @@ def _decompose_(self) -> Iterator['cirq.Operation']:
427446
return self.mapped_circuit(deep=False).all_operations()
428447

429448
def _act_on_(self, sim_state: 'cirq.SimulationStateBase') -> bool:
430-
if self.repeat_until:
449+
mapped_repeat_until = self._mapped_repeat_until
450+
if mapped_repeat_until:
431451
circuit = self._mapped_single_loop()
432452
while True:
433453
for op in circuit.all_operations():
434454
protocols.act_on(op, sim_state)
435-
if self.repeat_until.resolve(sim_state.classical_data):
455+
if mapped_repeat_until.resolve(sim_state.classical_data):
436456
break
437457
else:
438458
for op in self._decompose_():
@@ -808,7 +828,9 @@ def with_params(
808828
by param_values.
809829
"""
810830
new_params = {}
811-
for k in protocols.parameter_symbols(self.circuit):
831+
for k in protocols.parameter_symbols(self.circuit) | protocols.parameter_symbols(
832+
self.repeat_until
833+
):
812834
v = self.param_resolver.value_of(k, recursive=False)
813835
v = protocols.resolve_parameters(v, param_values, recursive=recursive)
814836
if v != k:

cirq-core/cirq/circuits/circuit_operation_test.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,4 +1239,57 @@ def test_repeat_until_error():
12391239
)
12401240

12411241

1242+
def test_repeat_until_protocols():
1243+
q = cirq.LineQubit(0)
1244+
op = cirq.CircuitOperation(
1245+
cirq.FrozenCircuit(cirq.H(q) ** sympy.Symbol('p'), cirq.measure(q, key='a')),
1246+
repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol('a'), 0)),
1247+
)
1248+
scoped = cirq.with_rescoped_keys(op, ('0',))
1249+
# Ensure the _repeat_until has been mapped, the measurement has been mapped to the same key,
1250+
# and the control keys of the subcircuit is empty (because the control key of the condition is
1251+
# bound to the measurement).
1252+
assert scoped._mapped_repeat_until.keys == (cirq.MeasurementKey('a', ('0',)),)
1253+
assert cirq.measurement_key_objs(scoped) == {cirq.MeasurementKey('a', ('0',))}
1254+
assert not cirq.control_keys(scoped)
1255+
mapped = cirq.with_measurement_key_mapping(scoped, {'a': 'b'})
1256+
assert mapped._mapped_repeat_until.keys == (cirq.MeasurementKey('b', ('0',)),)
1257+
assert cirq.measurement_key_objs(mapped) == {cirq.MeasurementKey('b', ('0',))}
1258+
assert not cirq.control_keys(mapped)
1259+
prefixed = cirq.with_key_path_prefix(mapped, ('1',))
1260+
assert prefixed._mapped_repeat_until.keys == (cirq.MeasurementKey('b', ('1', '0')),)
1261+
assert cirq.measurement_key_objs(prefixed) == {cirq.MeasurementKey('b', ('1', '0'))}
1262+
assert not cirq.control_keys(prefixed)
1263+
setpath = cirq.with_key_path(prefixed, ('2',))
1264+
assert setpath._mapped_repeat_until.keys == (cirq.MeasurementKey('b', ('2',)),)
1265+
assert cirq.measurement_key_objs(setpath) == {cirq.MeasurementKey('b', ('2',))}
1266+
assert not cirq.control_keys(setpath)
1267+
resolved = cirq.resolve_parameters(setpath, {'p': 1})
1268+
assert resolved._mapped_repeat_until.keys == (cirq.MeasurementKey('b', ('2',)),)
1269+
assert cirq.measurement_key_objs(resolved) == {cirq.MeasurementKey('b', ('2',))}
1270+
assert not cirq.control_keys(resolved)
1271+
1272+
1273+
def test_inner_repeat_until_simulate():
1274+
sim = cirq.Simulator()
1275+
q = cirq.LineQubit(0)
1276+
inner_loop = cirq.CircuitOperation(
1277+
cirq.FrozenCircuit(cirq.H(q), cirq.measure(q, key="inner_loop")),
1278+
repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol("inner_loop"), 0)),
1279+
)
1280+
outer_loop = cirq.Circuit(inner_loop, cirq.X(q), cirq.measure(q, key="outer_loop"))
1281+
circuit = cirq.Circuit(
1282+
cirq.CircuitOperation(
1283+
cirq.FrozenCircuit(outer_loop), repetitions=2, use_repetition_ids=True
1284+
)
1285+
)
1286+
result = sim.run(circuit, repetitions=1)
1287+
assert all(len(v) == 1 and v[0] == 1 for v in result.records['0:inner_loop'][0][:-1])
1288+
assert result.records['0:inner_loop'][0][-1] == [0]
1289+
assert result.records['0:outer_loop'] == [[[1]]]
1290+
assert all(len(v) == 1 and v[0] == 1 for v in result.records['1:inner_loop'][0][:-1])
1291+
assert result.records['1:inner_loop'][0][-1] == [0]
1292+
assert result.records['1:outer_loop'] == [[[1]]]
1293+
1294+
12421295
# TODO: Operation has a "gate" property. What is this for a CircuitOperation?

0 commit comments

Comments
 (0)