Skip to content

Commit e27d883

Browse files
Support serializing sweeps with numpy values (#7398)
1 parent 92c9a91 commit e27d883

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

cirq-google/cirq_google/api/v2/sweeps.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import numbers
1718
from typing import Any, Callable, cast, TYPE_CHECKING
1819

1920
import sympy
@@ -31,10 +32,10 @@ def _build_sweep_const(value: Any) -> run_context_pb2.ConstValue:
3132
"""Build the sweep const message from a value."""
3233
if value is None:
3334
return run_context_pb2.ConstValue(is_none=True)
34-
elif isinstance(value, float):
35-
return run_context_pb2.ConstValue(float_value=value)
36-
elif isinstance(value, int):
37-
return run_context_pb2.ConstValue(int_value=value)
35+
elif isinstance(value, numbers.Integral):
36+
return run_context_pb2.ConstValue(int_value=int(value))
37+
elif isinstance(value, numbers.Real):
38+
return run_context_pb2.ConstValue(float_value=float(value))
3839
elif isinstance(value, str):
3940
return run_context_pb2.ConstValue(string_value=value)
4041
elif isinstance(value, tunits.Value):

cirq-google/cirq_google/api/v2/sweeps_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from copy import deepcopy
1919
from typing import Iterator
2020

21+
import numpy as np
2122
import pytest
2223
import sympy
2324
import tunits
@@ -418,3 +419,9 @@ def test_tunits_round_trip(sweep):
418419
msg = v2.sweep_to_proto(sweep)
419420
recovered = v2.sweep_from_proto(msg)
420421
assert sweep == recovered
422+
423+
424+
@pytest.mark.parametrize('value', [np.float32(3.14), np.int64(5)])
425+
def test_const_sweep_with_numpy_types_roundtrip(value):
426+
sweep = cirq.Points('const', [value])
427+
assert v2.sweep_from_proto(v2.sweep_to_proto(sweep)) == sweep

0 commit comments

Comments
 (0)