Skip to content

Commit aee2c30

Browse files
authored
Support parameters in InternalTag (#7060)
* Support parameters in InternalTag - This adds support for parameters in InternalTag. * Address comments.
1 parent e3b46a1 commit aee2c30

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

cirq-google/cirq_google/ops/internal_tag.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
from cirq import value
1818
from cirq_google.api.v2 import program_pb2
1919

20-
# from cirq_google.serialization import arg_func_langs
21-
2220

2321
@value.value_equality
2422
class InternalTag:
@@ -62,18 +60,31 @@ def _value_equality_values_(self):
6260
return (self.name, self.package, tag_args_eq_values)
6361

6462
def to_proto(self, msg: Optional[program_pb2.Tag] = None) -> program_pb2.Tag:
63+
# To avoid circular import
64+
from cirq_google.serialization import arg_func_langs
65+
6566
if msg is None:
6667
msg = program_pb2.Tag()
6768
msg.internal_tag.tag_name = self.name
6869
msg.internal_tag.tag_package = self.package
70+
for k, v in self.tag_args.items():
71+
arg_func_langs.arg_to_proto(
72+
v, out=msg.internal_tag.tag_args[k], arg_function_language='exp'
73+
)
6974
return msg
7075

7176
@staticmethod
7277
def from_proto(msg: program_pb2.Tag) -> 'InternalTag':
78+
# To avoid circular import
79+
from cirq_google.serialization import arg_func_langs
80+
7381
if msg.WhichOneof("tag") != "internal_tag":
7482
raise ValueError(f"Message is not a InternalTag, {msg}")
83+
84+
kw_dict = {}
85+
for k, v in msg.internal_tag.tag_args.items():
86+
kw_dict[k] = arg_func_langs.arg_from_proto(v, arg_function_language='exp')
87+
7588
return InternalTag(
76-
name=msg.internal_tag.tag_name,
77-
package=msg.internal_tag.tag_package,
78-
**msg.internal_tag.tag_args,
89+
name=msg.internal_tag.tag_name, package=msg.internal_tag.tag_package, **kw_dict
7990
)

cirq-google/cirq_google/ops/internal_tag_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_internal_tag_with_hashable_args_is_hashable():
5858

5959

6060
def test_proto():
61-
tag = cirq_google.InternalTag(name="TagWithNoParams", package='test')
61+
tag = cirq_google.InternalTag(name="TagWithNoParams", package='test', param1=2.5)
6262
msg = tag.to_proto()
6363
assert tag == cirq_google.InternalTag.from_proto(msg)
6464

0 commit comments

Comments
 (0)