Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 66 additions & 23 deletions cirq-google/cirq_google/serialization/circuit_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,23 @@ class CircuitSerializer(serializer.Serializer):
serialization of duplicate operations as entries in the constant table.
This flag will soon become the default and disappear as soon as
deserialization of this field is deployed.
op_serializer: Optional custom serializer for serializing unknown gates.
op_deserializer: Optional custom deserializer for deserializing unknown gates.
"""

def __init__(
self, USE_CONSTANTS_TABLE_FOR_MOMENTS=False, USE_CONSTANTS_TABLE_FOR_OPERATIONS=False
self,
USE_CONSTANTS_TABLE_FOR_MOMENTS=False,
USE_CONSTANTS_TABLE_FOR_OPERATIONS=False,
op_serializer: Optional[op_serializer.OpSerializer] = None,
op_deserializer: Optional[op_deserializer.OpDeserializer] = None,
):
"""Construct the circuit serializer object."""
super().__init__(gate_set_name=_SERIALIZER_NAME)
self.use_constants_table_for_moments = USE_CONSTANTS_TABLE_FOR_MOMENTS
self.use_constants_table_for_operations = USE_CONSTANTS_TABLE_FOR_OPERATIONS
self.op_serializer = op_serializer
self.op_deserializer = op_deserializer

def serialize(
self,
Expand Down Expand Up @@ -144,26 +152,44 @@ def _serialize_circuit(
moment_proto.operation_indices.append(op_index)
else:
op_pb = v2.program_pb2.Operation()
if self.op_serializer and self.op_serializer.can_serialize_operation(op):
self.op_serializer.to_proto(
op,
op_pb,
arg_function_language=arg_function_language,
constants=constants,
raw_constants=raw_constants,
)
else:
self._serialize_gate_op(
op,
op_pb,
arg_function_language=arg_function_language,
constants=constants,
raw_constants=raw_constants,
)
constants.append(v2.program_pb2.Constant(operation_value=op_pb))
op_index = len(constants) - 1
raw_constants[op] = op_index
moment_proto.operation_indices.append(op_index)
else:
op_pb = moment_proto.operations.add()
if self.op_serializer and self.op_serializer.can_serialize_operation(op):
self.op_serializer.to_proto(
op,
op_pb,
arg_function_language=arg_function_language,
constants=constants,
raw_constants=raw_constants,
)
else:
self._serialize_gate_op(
op,
op_pb,
arg_function_language=arg_function_language,
constants=constants,
raw_constants=raw_constants,
)
constants.append(v2.program_pb2.Constant(operation_value=op_pb))
op_index = len(constants) - 1
raw_constants[op] = op_index
moment_proto.operation_indices.append(op_index)
else:
op_pb = moment_proto.operations.add()
self._serialize_gate_op(
op,
op_pb,
arg_function_language=arg_function_language,
constants=constants,
raw_constants=raw_constants,
)

if self.use_constants_table_for_moments:
# Add this moment to the constants table
Expand Down Expand Up @@ -469,14 +495,23 @@ def deserialize(self, proto: v2.program_pb2.Program) -> cirq.Circuit:
elif which_const == 'qubit':
deserialized_constants.append(v2.qubit_from_proto_id(constant.qubit.id))
elif which_const == 'operation_value':
deserialized_constants.append(
self._deserialize_gate_op(
if self.op_deserializer and self.op_deserializer.can_deserialize_proto(
constant.operation_value
):
op_pb = self.op_deserializer.from_proto(
constant.operation_value,
arg_function_language=arg_func_language,
constants=proto.constants,
deserialized_constants=deserialized_constants,
)
)
else:
op_pb = self._deserialize_gate_op(
constant.operation_value,
arg_function_language=arg_func_language,
constants=proto.constants,
deserialized_constants=deserialized_constants,
)
deserialized_constants.append(op_pb)
elif which_const == 'moment_value':
deserialized_constants.append(
self._deserialize_moment(
Expand Down Expand Up @@ -541,12 +576,20 @@ def _deserialize_moment(
) -> cirq.Moment:
moment_ops = []
for op in moment_proto.operations:
gate_op = self._deserialize_gate_op(
op,
arg_function_language=arg_function_language,
constants=constants,
deserialized_constants=deserialized_constants,
)
if self.op_deserializer and self.op_deserializer.can_deserialize_proto(op):
gate_op = self.op_deserializer.from_proto(
op,
arg_function_language=arg_function_language,
constants=constants,
deserialized_constants=deserialized_constants,
)
else:
gate_op = self._deserialize_gate_op(
op,
arg_function_language=arg_function_language,
constants=constants,
deserialized_constants=deserialized_constants,
)
if op.tag_indices:
tags = [
deserialized_constants[tag_index]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Any, Dict, List, Optional
import pytest

import numpy as np
Expand All @@ -25,6 +25,8 @@
import cirq_google as cg
from cirq_google.api import v2
from cirq_google.serialization.circuit_serializer import _SERIALIZER_NAME
from cirq_google.serialization.op_deserializer import OpDeserializer
from cirq_google.serialization.op_serializer import OpSerializer


class FakeDevice(cirq.Device):
Expand Down Expand Up @@ -917,3 +919,92 @@ def test_circuit_with_units():
)
msg = cg.CIRCUIT_SERIALIZER.serialize(c)
assert c == cg.CIRCUIT_SERIALIZER.deserialize(msg)


class BingBongGate(cirq.Gate):

def __init__(self, param: float):
self.param = param

def _num_qubits_(self) -> int:
return 1


class BingBongSerializer(OpSerializer):
"""Describes how to serialize CircuitOperations."""

@property
def can_serialize_predicate(self):
return lambda op: isinstance(op.gate, BingBongGate)

def to_proto(
self,
op: cirq.CircuitOperation,
msg: Optional[v2.program_pb2.CircuitOperation] = None,
*,
arg_function_language: Optional[str] = '',
constants: List[v2.program_pb2.Constant],
raw_constants: Dict[Any, int],
) -> v2.program_pb2.CircuitOperation:
assert isinstance(op.gate, BingBongGate)
if msg is None:
msg = v2.program_pb2.Operation() # pragma: nocover
msg.internalgate.name = 'bingbong'
msg.internalgate.module = 'test'
msg.internalgate.num_qubits = 1
msg.internalgate.gate_args['param'].arg_value.float_value = op.gate.param

for qubit in op.qubits:
if qubit not in raw_constants:
constants.append(
v2.program_pb2.Constant(
qubit=v2.program_pb2.Qubit(id=v2.qubit_to_proto_id(qubit))
)
)
raw_constants[qubit] = len(constants) - 1
msg.qubit_constant_index.append(raw_constants[qubit])
return msg


class BingBongDeserializer(OpDeserializer):
"""Describes how to serialize CircuitOperations."""

@property
def can_deserialize_predicate(self):
return lambda proto: (
isinstance(proto, v2.program_pb2.Operation)
and proto.WhichOneof("gate_value") == "internalgate"
and proto.internalgate.name == 'bingbong'
and proto.internalgate.module == 'test'
)

def from_proto(
self,
proto: v2.program_pb2.Operation,
*,
arg_function_language: str = '',
constants: List[v2.program_pb2.Constant],
deserialized_constants: List[Any],
) -> cirq.Operation:
return BingBongGate(param=proto.internalgate.gate_args["param"].arg_value.float_value).on(
deserialized_constants[proto.qubit_constant_index[0]]
)


@pytest.mark.parametrize('use_constants_table', [True, False])
def test_custom_serializer(use_constants_table: bool):
c = cirq.Circuit(BingBongGate(param=2.5)(cirq.q(0, 0)))
serializer = cg.CircuitSerializer(
USE_CONSTANTS_TABLE_FOR_MOMENTS=use_constants_table,
USE_CONSTANTS_TABLE_FOR_OPERATIONS=use_constants_table,
op_serializer=BingBongSerializer(),
op_deserializer=BingBongDeserializer(),
)
msg = serializer.serialize(c)
deserialized_circuit = serializer.deserialize(msg)
moment = deserialized_circuit[0]
assert len(moment) == 1
op = moment[cirq.q(0, 0)]
assert isinstance(op.gate, BingBongGate)
assert op.gate.param == 2.5
assert op.qubits == (cirq.q(0, 0),)
17 changes: 8 additions & 9 deletions cirq-google/cirq_google/serialization/op_deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List
from typing import Any, Callable, List

import abc
import sympy
Expand All @@ -31,13 +31,12 @@ class OpDeserializer(abc.ABC):

@property
@abc.abstractmethod
def serialized_id(self) -> str:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update class docstring which refers to serializer_id.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Also warnings suppressed.

"""Returns the string identifier for the accepted serialized objects.
def can_deserialize_predicate(self) -> Callable[[Any], bool]:
"""The method used to determine if this can deserialize a proto."""
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be a regular method returning a bool instead? Or a staticmethod if you want to have it decoupled from OpDeserializer object?

As it is, each call of CircuitOpDeserializer.can_deserialize_proto creates lambda object, calls it once, and discards it after the call.

(Same for OpSerializer.can_serialize_predicate.)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. I don't remember why this was done originally, but it is no longer needed.


This ID denotes the serialization format this deserializer consumes. For
example, one of the common deserializers converts objects with the id
'xy' into PhasedXPowGates.
"""
def can_deserialize_proto(self, proto) -> bool:
"""Whether the given operation can be serialized by this serializer."""
return self.can_deserialize_predicate(proto)

@abc.abstractmethod
def from_proto(
Expand Down Expand Up @@ -67,8 +66,8 @@ class CircuitOpDeserializer(OpDeserializer):
"""Describes how to serialize CircuitOperations."""

@property
def serialized_id(self):
return 'circuit'
def can_deserialize_predicate(self):
return lambda proto: isinstance(proto, v2.program_pb2.CircuitOperation) # pragma: nocover

def from_proto(
self,
Expand Down
34 changes: 2 additions & 32 deletions cirq-google/cirq_google/serialization/op_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, List, Optional, Type, TypeVar
from typing import Any, Callable, Dict, List, Optional, Union
import numbers

import abc
Expand All @@ -23,9 +23,6 @@
from cirq_google.api import v2
from cirq_google.serialization.arg_func_langs import arg_to_proto

# Type for variables that are subclasses of ops.Gate.
Gate = TypeVar('Gate', bound=cirq.Gate)


class OpSerializer(abc.ABC):
"""Generic supertype for operation serializers.
Expand All @@ -35,25 +32,6 @@ class OpSerializer(abc.ABC):
may serialize to the same format.
"""

@property
@abc.abstractmethod
def internal_type(self) -> Type:
"""Returns the type that the operation contains.

For GateOperations, this is the gate type.
For CircuitOperations, this is FrozenCircuit.
"""

@property
@abc.abstractmethod
def serialized_id(self) -> str:
"""Returns the string identifier for the resulting serialized object.

This ID denotes the serialization format this serializer produces. For
example, one of the common serializers assigns the id 'xy' to XPowGates,
as they serialize into a format also used by YPowGates.
"""

@abc.abstractmethod
def to_proto(
self,
Expand All @@ -63,7 +41,7 @@ def to_proto(
arg_function_language: Optional[str] = '',
constants: List[v2.program_pb2.Constant],
raw_constants: Dict[Any, int],
) -> Optional[v2.program_pb2.CircuitOperation]:
) -> Optional[Union[v2.program_pb2.CircuitOperation, v2.program_pb2.Operation]]:
"""Converts op to proto using this serializer.

If self.can_serialize_operation(op) == false, this should return None.
Expand Down Expand Up @@ -99,14 +77,6 @@ def can_serialize_operation(self, op: cirq.Operation) -> bool:
class CircuitOpSerializer(OpSerializer):
"""Describes how to serialize CircuitOperations."""

@property
def internal_type(self):
return cirq.FrozenCircuit

@property
def serialized_id(self):
return 'circuit'

@property
def can_serialize_predicate(self):
return lambda op: isinstance(op.untagged, cirq.CircuitOperation)
Expand Down
6 changes: 0 additions & 6 deletions cirq-google/cirq_google/serialization/op_serializer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,6 @@ def default_circuit():
)


def test_circuit_op_serializer_properties():
serializer = cg.CircuitOpSerializer()
assert serializer.internal_type == cirq.FrozenCircuit
assert serializer.serialized_id == 'circuit'


def test_can_serialize_circuit_op():
serializer = cg.CircuitOpSerializer()
assert serializer.can_serialize_operation(cirq.CircuitOperation(default_circuit()))
Expand Down