Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions cirq-google/cirq_google/engine/processor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,27 +76,6 @@ async def run_batch_async(
params_list: Optional[Sequence[cirq.Sweepable]] = None,
repetitions: Union[int, Sequence[int]] = 1,
) -> Sequence[Sequence['cg.EngineResult']]:
"""Runs the supplied circuits.

In order to gain a speedup from using this method instead of other run
methods, the following conditions must be satisfied:
1. All circuits must measure the same set of qubits.
2. The number of circuit repetitions must be the same for all
circuits. That is, the `repetitions` argument must be an integer,
or else a list with identical values.
"""
params_list, repetitions = self._normalize_batch_args(programs, params_list, repetitions)
if len(set(repetitions)) == 1:
# All repetitions are the same so batching can be done efficiently
job = await self._processor.run_batch_async(
programs=programs,
params_list=params_list,
repetitions=repetitions[0],
run_name=self._run_name,
device_config_name=self._device_config_name,
)
return await job.batched_results_async()
# Varying number of repetitions so no speedup
return cast(
Sequence[Sequence['cg.EngineResult']],
await super().run_batch_async(programs, params_list, repetitions),
Expand Down
64 changes: 44 additions & 20 deletions cirq-google/cirq_google/engine/processor_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,28 @@ def test_run_batch(run_name, device_config_name):
circuit2 = cirq.Circuit(cirq.Y(a))
params1 = [cirq.ParamResolver({'t': 1})]
params2 = [cirq.ParamResolver({'t': 2})]
circuits = [circuit1, circuit2]
params_list = [params1, params2]
sampler.run_batch(circuits, params_list, 5)
processor.run_batch_async.assert_called_with(
params_list=params_list,
programs=circuits,
repetitions=5,
run_name=run_name,
device_config_name=device_config_name,
)

sampler.run_batch([circuit1, circuit2], [params1, params2], 5)

expected_calls = [
mock.call(
program=circuit1,
params=params1,
repetitions=5,
run_name=run_name,
device_config_name=device_config_name,
),
mock.call().results_async(),
mock.call(
program=circuit2,
params=params2,
repetitions=5,
run_name=run_name,
device_config_name=device_config_name,
),
mock.call().results_async(),
]
processor.run_sweep_async.assert_has_calls(expected_calls)


@pytest.mark.parametrize(
Expand All @@ -79,16 +91,28 @@ def test_run_batch_identical_repetitions(run_name, device_config_name):
circuit2 = cirq.Circuit(cirq.Y(a))
params1 = [cirq.ParamResolver({'t': 1})]
params2 = [cirq.ParamResolver({'t': 2})]
circuits = [circuit1, circuit2]
params_list = [params1, params2]
sampler.run_batch(circuits, params_list, [5, 5])
processor.run_batch_async.assert_called_with(
params_list=params_list,
programs=circuits,
repetitions=5,
run_name=run_name,
device_config_name=device_config_name,
)

sampler.run_batch([circuit1, circuit2], [params1, params2], [5, 5])

expected_calls = [
mock.call(
program=circuit1,
params=params1,
repetitions=5,
run_name=run_name,
device_config_name=device_config_name,
),
mock.call().results_async(),
mock.call(
program=circuit2,
params=params2,
repetitions=5,
run_name=run_name,
device_config_name=device_config_name,
),
mock.call().results_async(),
]
processor.run_sweep_async.assert_has_calls(expected_calls)


def test_run_batch_bad_number_of_repetitions():
Expand Down