Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cirq-google/cirq_google/api/v2/run_context.proto
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
syntax = "proto3";

import "cirq_google/api/v2/program.proto";
import "tunits/proto/tunits.proto";

package cirq.google.api.v2;

Expand Down Expand Up @@ -209,6 +210,8 @@ message SingleSweep {
message Points {
// The values.
repeated float points = 1;

tunits.Value unit = 2;
}

// A range of evenly-spaced values.
Expand All @@ -225,6 +228,8 @@ message Linspace {
// greater than zero. If it is 1, the first_point and last_point must be
// the same.
int64 num_points = 3;

tunits.Value unit = 4;
}

// A constant value.
Expand All @@ -236,5 +241,6 @@ message ConstValue {
float float_value = 2;
int64 int_value = 3;
string string_value = 4;
tunits.Value with_unit_value = 5;
}
}
59 changes: 30 additions & 29 deletions cirq-google/cirq_google/api/v2/run_context_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 20 additions & 5 deletions cirq-google/cirq_google/api/v2/run_context_pb2.pyi

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 34 additions & 7 deletions cirq-google/cirq_google/api/v2/sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Any, cast, Dict, List, Optional

import sympy
import tunits

import cirq
from cirq_google.api.v2 import run_context_pb2
Expand All @@ -31,6 +32,8 @@ def _build_sweep_const(value: Any) -> run_context_pb2.ConstValue:
return run_context_pb2.ConstValue(int_value=value)
elif isinstance(value, str):
return run_context_pb2.ConstValue(string_value=value)
elif isinstance(value, tunits.Value):
Comment thread
senecameeks marked this conversation as resolved.
return run_context_pb2.ConstValue(with_unit_value=value.to_proto())
else:
raise ValueError(
f"Unsupported type for serializing const sweep: {value=} and {type(value)=}"
Expand All @@ -47,6 +50,8 @@ def _recover_sweep_const(const_pb: run_context_pb2.ConstValue) -> Any:
return const_pb.int_value
if const_pb.WhichOneof('value') == 'string_value':
return const_pb.string_value
if const_pb.WhichOneof('value') == 'with_unit_value':
return tunits.Value.from_proto(const_pb.with_unit_value)


def sweep_to_proto(
Expand Down Expand Up @@ -87,9 +92,16 @@ def sweep_to_proto(
sweep_to_proto(s, out=out.sweep_function.sweeps.add())
elif isinstance(sweep, cirq.Linspace) and not isinstance(sweep.key, sympy.Expr):
out.single_sweep.parameter_key = sweep.key
out.single_sweep.linspace.first_point = sweep.start
out.single_sweep.linspace.last_point = sweep.stop
out.single_sweep.linspace.num_points = sweep.length
if isinstance(sweep.start, tunits.Value):
Comment thread
senecameeks marked this conversation as resolved.
unit = sweep.start.unit
out.single_sweep.linspace.first_point = sweep.start[unit]
out.single_sweep.linspace.last_point = sweep.stop[unit]
out.single_sweep.linspace.num_points = sweep.length
unit.to_proto(out.single_sweep.linspace.unit)
else:
out.single_sweep.linspace.first_point = sweep.start
out.single_sweep.linspace.last_point = sweep.stop
out.single_sweep.linspace.num_points = sweep.length
# Use duck-typing to support google-internal Parameter objects
if sweep.metadata and getattr(sweep.metadata, 'path', None):
out.single_sweep.parameter.path.extend(sweep.metadata.path)
Expand All @@ -102,7 +114,12 @@ def sweep_to_proto(
if len(sweep.points) == 1:
out.single_sweep.const_value.MergeFrom(_build_sweep_const(sweep.points[0]))
else:
out.single_sweep.points.points.extend(sweep.points)
if isinstance(sweep.points[0], tunits.Value):
unit = sweep.points[0].unit
out.single_sweep.points.points.extend(p[unit] for p in sweep.points)
unit.to_proto(out.single_sweep.points.unit)
else:
out.single_sweep.points.points.extend(sweep.points)
# Use duck-typing to support google-internal Parameter objects
if sweep.metadata and getattr(sweep.metadata, 'path', None):
out.single_sweep.parameter.path.extend(sweep.metadata.path)
Expand Down Expand Up @@ -162,15 +179,25 @@ def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
else:
metadata = None
if msg.single_sweep.WhichOneof('sweep') == 'linspace':
unit: float | tunits.Value = 1.0
if msg.single_sweep.linspace.HasField('unit'):
unit = tunits.Value.from_proto(msg.single_sweep.linspace.unit)
return cirq.Linspace(
key=key,
start=msg.single_sweep.linspace.first_point,
stop=msg.single_sweep.linspace.last_point,
start=msg.single_sweep.linspace.first_point * unit, # type: ignore[arg-type]
stop=msg.single_sweep.linspace.last_point * unit, # type: ignore[arg-type]
length=msg.single_sweep.linspace.num_points,
metadata=metadata,
)
if msg.single_sweep.WhichOneof('sweep') == 'points':
return cirq.Points(key=key, points=msg.single_sweep.points.points, metadata=metadata)
unit = 1.0
if msg.single_sweep.points.HasField('unit'):
unit = tunits.Value.from_proto(msg.single_sweep.points.unit)
return cirq.Points(
key=key,
points=[p * unit for p in msg.single_sweep.points.points],
metadata=metadata,
)
if msg.single_sweep.WhichOneof('sweep') == 'const_value':
return cirq.Points(
key=key,
Expand Down
Loading