Skip to content

Commit ca6ceb3

Browse files
authored
Allow ability to plug in custom (de)serializers for cirq_google protos (#7059)
* Allow ability to plug in custom (de)serializers for cirq_google protos - This will allow users to plug in custom serializers and deserializers, which can parse gates before falling back to the default. - This enables internal libraries to parse and deserialize non-public gates, tags, and operations. * Fix coverage and get rid of unneeded junk. * Address comments. * Flip warnings.
1 parent 2e71950 commit ca6ceb3

File tree

5 files changed

+167
-85
lines changed

5 files changed

+167
-85
lines changed

cirq-google/cirq_google/serialization/circuit_serializer.py

Lines changed: 66 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,23 @@ class CircuitSerializer(serializer.Serializer):
5656
serialization of duplicate operations as entries in the constant table.
5757
This flag will soon become the default and disappear as soon as
5858
deserialization of this field is deployed.
59+
op_serializer: Optional custom serializer for serializing unknown gates.
60+
op_deserializer: Optional custom deserializer for deserializing unknown gates.
5961
"""
6062

6163
def __init__(
62-
self, USE_CONSTANTS_TABLE_FOR_MOMENTS=False, USE_CONSTANTS_TABLE_FOR_OPERATIONS=False
64+
self,
65+
USE_CONSTANTS_TABLE_FOR_MOMENTS=False,
66+
USE_CONSTANTS_TABLE_FOR_OPERATIONS=False,
67+
op_serializer: Optional[op_serializer.OpSerializer] = None,
68+
op_deserializer: Optional[op_deserializer.OpDeserializer] = None,
6369
):
6470
"""Construct the circuit serializer object."""
6571
super().__init__(gate_set_name=_SERIALIZER_NAME)
6672
self.use_constants_table_for_moments = USE_CONSTANTS_TABLE_FOR_MOMENTS
6773
self.use_constants_table_for_operations = USE_CONSTANTS_TABLE_FOR_OPERATIONS
74+
self.op_serializer = op_serializer
75+
self.op_deserializer = op_deserializer
6876

6977
def serialize(
7078
self,
@@ -144,26 +152,44 @@ def _serialize_circuit(
144152
moment_proto.operation_indices.append(op_index)
145153
else:
146154
op_pb = v2.program_pb2.Operation()
155+
if self.op_serializer and self.op_serializer.can_serialize_operation(op):
156+
self.op_serializer.to_proto(
157+
op,
158+
op_pb,
159+
arg_function_language=arg_function_language,
160+
constants=constants,
161+
raw_constants=raw_constants,
162+
)
163+
else:
164+
self._serialize_gate_op(
165+
op,
166+
op_pb,
167+
arg_function_language=arg_function_language,
168+
constants=constants,
169+
raw_constants=raw_constants,
170+
)
171+
constants.append(v2.program_pb2.Constant(operation_value=op_pb))
172+
op_index = len(constants) - 1
173+
raw_constants[op] = op_index
174+
moment_proto.operation_indices.append(op_index)
175+
else:
176+
op_pb = moment_proto.operations.add()
177+
if self.op_serializer and self.op_serializer.can_serialize_operation(op):
178+
self.op_serializer.to_proto(
179+
op,
180+
op_pb,
181+
arg_function_language=arg_function_language,
182+
constants=constants,
183+
raw_constants=raw_constants,
184+
)
185+
else:
147186
self._serialize_gate_op(
148187
op,
149188
op_pb,
150189
arg_function_language=arg_function_language,
151190
constants=constants,
152191
raw_constants=raw_constants,
153192
)
154-
constants.append(v2.program_pb2.Constant(operation_value=op_pb))
155-
op_index = len(constants) - 1
156-
raw_constants[op] = op_index
157-
moment_proto.operation_indices.append(op_index)
158-
else:
159-
op_pb = moment_proto.operations.add()
160-
self._serialize_gate_op(
161-
op,
162-
op_pb,
163-
arg_function_language=arg_function_language,
164-
constants=constants,
165-
raw_constants=raw_constants,
166-
)
167193

168194
if self.use_constants_table_for_moments:
169195
# Add this moment to the constants table
@@ -469,14 +495,23 @@ def deserialize(self, proto: v2.program_pb2.Program) -> cirq.Circuit:
469495
elif which_const == 'qubit':
470496
deserialized_constants.append(v2.qubit_from_proto_id(constant.qubit.id))
471497
elif which_const == 'operation_value':
472-
deserialized_constants.append(
473-
self._deserialize_gate_op(
498+
if self.op_deserializer and self.op_deserializer.can_deserialize_proto(
499+
constant.operation_value
500+
):
501+
op_pb = self.op_deserializer.from_proto(
474502
constant.operation_value,
475503
arg_function_language=arg_func_language,
476504
constants=proto.constants,
477505
deserialized_constants=deserialized_constants,
478506
)
479-
)
507+
else:
508+
op_pb = self._deserialize_gate_op(
509+
constant.operation_value,
510+
arg_function_language=arg_func_language,
511+
constants=proto.constants,
512+
deserialized_constants=deserialized_constants,
513+
)
514+
deserialized_constants.append(op_pb)
480515
elif which_const == 'moment_value':
481516
deserialized_constants.append(
482517
self._deserialize_moment(
@@ -541,12 +576,20 @@ def _deserialize_moment(
541576
) -> cirq.Moment:
542577
moment_ops = []
543578
for op in moment_proto.operations:
544-
gate_op = self._deserialize_gate_op(
545-
op,
546-
arg_function_language=arg_function_language,
547-
constants=constants,
548-
deserialized_constants=deserialized_constants,
549-
)
579+
if self.op_deserializer and self.op_deserializer.can_deserialize_proto(op):
580+
gate_op = self.op_deserializer.from_proto(
581+
op,
582+
arg_function_language=arg_function_language,
583+
constants=constants,
584+
deserialized_constants=deserialized_constants,
585+
)
586+
else:
587+
gate_op = self._deserialize_gate_op(
588+
op,
589+
arg_function_language=arg_function_language,
590+
constants=constants,
591+
deserialized_constants=deserialized_constants,
592+
)
550593
if op.tag_indices:
551594
tags = [
552595
deserialized_constants[tag_index]

cirq-google/cirq_google/serialization/circuit_serializer_test.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Dict, List
15+
from typing import Any, Dict, List, Optional
1616
import pytest
1717

1818
import numpy as np
@@ -25,6 +25,8 @@
2525
import cirq_google as cg
2626
from cirq_google.api import v2
2727
from cirq_google.serialization.circuit_serializer import _SERIALIZER_NAME
28+
from cirq_google.serialization.op_deserializer import OpDeserializer
29+
from cirq_google.serialization.op_serializer import OpSerializer
2830

2931

3032
class FakeDevice(cirq.Device):
@@ -856,6 +858,7 @@ def test_circuit_with_tag(tag):
856858
assert nc[0].operations[0].tags == (tag,)
857859

858860

861+
@pytest.mark.filterwarnings('ignore:Unrecognized Tag .*DingDongTag')
859862
def test_unknown_tag_is_ignored():
860863
class DingDongTag:
861864
pass
@@ -866,6 +869,7 @@ class DingDongTag:
866869
assert cirq.Circuit(cirq.X(cirq.q(0))) == nc
867870

868871

872+
@pytest.mark.filterwarnings('ignore:Unknown tag msg=phase_match')
869873
def test_unrecognized_tag_is_ignored():
870874
op_tag = v2.program_pb2.Operation()
871875
op_tag.xpowgate.exponent.float_value = 1.0
@@ -917,3 +921,90 @@ def test_circuit_with_units():
917921
)
918922
msg = cg.CIRCUIT_SERIALIZER.serialize(c)
919923
assert c == cg.CIRCUIT_SERIALIZER.deserialize(msg)
924+
925+
926+
class BingBongGate(cirq.Gate):
927+
928+
def __init__(self, param: float):
929+
self.param = param
930+
931+
def _num_qubits_(self) -> int:
932+
return 1
933+
934+
935+
class BingBongSerializer(OpSerializer):
936+
"""Describes how to serialize CircuitOperations."""
937+
938+
def can_serialize_operation(self, op):
939+
return isinstance(op.gate, BingBongGate)
940+
941+
def to_proto(
942+
self,
943+
op: cirq.CircuitOperation,
944+
msg: Optional[v2.program_pb2.CircuitOperation] = None,
945+
*,
946+
arg_function_language: Optional[str] = '',
947+
constants: List[v2.program_pb2.Constant],
948+
raw_constants: Dict[Any, int],
949+
) -> v2.program_pb2.CircuitOperation:
950+
assert isinstance(op.gate, BingBongGate)
951+
if msg is None:
952+
msg = v2.program_pb2.Operation() # pragma: nocover
953+
msg.internalgate.name = 'bingbong'
954+
msg.internalgate.module = 'test'
955+
msg.internalgate.num_qubits = 1
956+
msg.internalgate.gate_args['param'].arg_value.float_value = op.gate.param
957+
958+
for qubit in op.qubits:
959+
if qubit not in raw_constants:
960+
constants.append(
961+
v2.program_pb2.Constant(
962+
qubit=v2.program_pb2.Qubit(id=v2.qubit_to_proto_id(qubit))
963+
)
964+
)
965+
raw_constants[qubit] = len(constants) - 1
966+
msg.qubit_constant_index.append(raw_constants[qubit])
967+
return msg
968+
969+
970+
class BingBongDeserializer(OpDeserializer):
971+
"""Describes how to serialize CircuitOperations."""
972+
973+
def can_deserialize_proto(self, proto):
974+
return (
975+
isinstance(proto, v2.program_pb2.Operation)
976+
and proto.WhichOneof("gate_value") == "internalgate"
977+
and proto.internalgate.name == 'bingbong'
978+
and proto.internalgate.module == 'test'
979+
)
980+
981+
def from_proto(
982+
self,
983+
proto: v2.program_pb2.Operation,
984+
*,
985+
arg_function_language: str = '',
986+
constants: List[v2.program_pb2.Constant],
987+
deserialized_constants: List[Any],
988+
) -> cirq.Operation:
989+
return BingBongGate(param=proto.internalgate.gate_args["param"].arg_value.float_value).on(
990+
deserialized_constants[proto.qubit_constant_index[0]]
991+
)
992+
993+
994+
@pytest.mark.parametrize('use_constants_table', [True, False])
995+
def test_custom_serializer(use_constants_table: bool):
996+
c = cirq.Circuit(BingBongGate(param=2.5)(cirq.q(0, 0)))
997+
serializer = cg.CircuitSerializer(
998+
USE_CONSTANTS_TABLE_FOR_MOMENTS=use_constants_table,
999+
USE_CONSTANTS_TABLE_FOR_OPERATIONS=use_constants_table,
1000+
op_serializer=BingBongSerializer(),
1001+
op_deserializer=BingBongDeserializer(),
1002+
)
1003+
msg = serializer.serialize(c)
1004+
deserialized_circuit = serializer.deserialize(msg)
1005+
moment = deserialized_circuit[0]
1006+
assert len(moment) == 1
1007+
op = moment[cirq.q(0, 0)]
1008+
assert isinstance(op.gate, BingBongGate)
1009+
assert op.gate.param == 2.5
1010+
assert op.qubits == (cirq.q(0, 0),)

cirq-google/cirq_google/serialization/op_deserializer.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,12 @@ class OpDeserializer(abc.ABC):
2626
"""Generic supertype for operation deserializers.
2727
2828
Each operation deserializer describes how to deserialize operation protos
29-
with a particular `serialized_id` to a specific type of Cirq operation.
29+
to a specific type of Cirq operation.
3030
"""
3131

32-
@property
3332
@abc.abstractmethod
34-
def serialized_id(self) -> str:
35-
"""Returns the string identifier for the accepted serialized objects.
36-
37-
This ID denotes the serialization format this deserializer consumes. For
38-
example, one of the common deserializers converts objects with the id
39-
'xy' into PhasedXPowGates.
40-
"""
33+
def can_deserialize_proto(self, proto) -> bool:
34+
"""Whether the given operation can be serialized by this serializer."""
4135

4236
@abc.abstractmethod
4337
def from_proto(
@@ -66,9 +60,8 @@ def from_proto(
6660
class CircuitOpDeserializer(OpDeserializer):
6761
"""Describes how to serialize CircuitOperations."""
6862

69-
@property
70-
def serialized_id(self):
71-
return 'circuit'
63+
def can_deserialize_proto(self, proto):
64+
return isinstance(proto, v2.program_pb2.CircuitOperation) # pragma: nocover
7265

7366
def from_proto(
7467
self,

cirq-google/cirq_google/serialization/op_serializer.py

Lines changed: 4 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar
15+
from typing import Any, Dict, List, Optional, Union
1616
import numbers
1717

1818
import abc
@@ -23,9 +23,6 @@
2323
from cirq_google.api import v2
2424
from cirq_google.serialization.arg_func_langs import arg_to_proto
2525

26-
# Type for variables that are subclasses of ops.Gate.
27-
Gate = TypeVar('Gate', bound=cirq.Gate)
28-
2926

3027
class OpSerializer(abc.ABC):
3128
"""Generic supertype for operation serializers.
@@ -35,25 +32,6 @@ class OpSerializer(abc.ABC):
3532
may serialize to the same format.
3633
"""
3734

38-
@property
39-
@abc.abstractmethod
40-
def internal_type(self) -> Type:
41-
"""Returns the type that the operation contains.
42-
43-
For GateOperations, this is the gate type.
44-
For CircuitOperations, this is FrozenCircuit.
45-
"""
46-
47-
@property
48-
@abc.abstractmethod
49-
def serialized_id(self) -> str:
50-
"""Returns the string identifier for the resulting serialized object.
51-
52-
This ID denotes the serialization format this serializer produces. For
53-
example, one of the common serializers assigns the id 'xy' to XPowGates,
54-
as they serialize into a format also used by YPowGates.
55-
"""
56-
5735
@abc.abstractmethod
5836
def to_proto(
5937
self,
@@ -63,7 +41,7 @@ def to_proto(
6341
arg_function_language: Optional[str] = '',
6442
constants: List[v2.program_pb2.Constant],
6543
raw_constants: Dict[Any, int],
66-
) -> Optional[v2.program_pb2.CircuitOperation]:
44+
) -> Optional[Union[v2.program_pb2.CircuitOperation, v2.program_pb2.Operation]]:
6745
"""Converts op to proto using this serializer.
6846
6947
If self.can_serialize_operation(op) == false, this should return None.
@@ -83,33 +61,16 @@ def to_proto(
8361
the returned object.
8462
"""
8563

86-
@property
8764
@abc.abstractmethod
88-
def can_serialize_predicate(self) -> Callable[[cirq.Operation], bool]:
89-
"""The method used to determine if this can serialize an operation.
90-
91-
Depending on the serializer, additional checks may be required.
92-
"""
93-
9465
def can_serialize_operation(self, op: cirq.Operation) -> bool:
9566
"""Whether the given operation can be serialized by this serializer."""
96-
return self.can_serialize_predicate(op)
9767

9868

9969
class CircuitOpSerializer(OpSerializer):
10070
"""Describes how to serialize CircuitOperations."""
10171

102-
@property
103-
def internal_type(self):
104-
return cirq.FrozenCircuit
105-
106-
@property
107-
def serialized_id(self):
108-
return 'circuit'
109-
110-
@property
111-
def can_serialize_predicate(self):
112-
return lambda op: isinstance(op.untagged, cirq.CircuitOperation)
72+
def can_serialize_operation(self, op: cirq.Operation):
73+
return isinstance(op.untagged, cirq.CircuitOperation)
11374

11475
def to_proto(
11576
self,

cirq-google/cirq_google/serialization/op_serializer_test.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,6 @@ def default_circuit():
4949
)
5050

5151

52-
def test_circuit_op_serializer_properties():
53-
serializer = cg.CircuitOpSerializer()
54-
assert serializer.internal_type == cirq.FrozenCircuit
55-
assert serializer.serialized_id == 'circuit'
56-
57-
5852
def test_can_serialize_circuit_op():
5953
serializer = cg.CircuitOpSerializer()
6054
assert serializer.can_serialize_operation(cirq.CircuitOperation(default_circuit()))

0 commit comments

Comments
 (0)