Skip to content

Adds functionality for viewing and debugging swap networks #1821

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jul 17, 2019
9 changes: 4 additions & 5 deletions cirq/contrib/acquaintance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,10 @@
from cirq.contrib.acquaintance.optimizers import (
remove_redundant_acquaintance_opportunities,)

from cirq.contrib.acquaintance.permutation import (LinearPermutationGate,
PermutationGate,
SwapPermutationGate,
update_mapping,
get_logical_operations)
from cirq.contrib.acquaintance.permutation import (
LinearPermutationGate, PermutationGate, SwapPermutationGate, update_mapping,
get_logical_operations, display_mapping, return_to_initial_mapping,
uses_consistent_swap_gate)

from cirq.contrib.acquaintance.shift import (
CircularShiftGate,)
Expand Down
68 changes: 68 additions & 0 deletions cirq/contrib/acquaintance/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,42 @@ def _circuit_diagram_info_(self, args: protocols.CircuitDiagramInfoArgs
return wire_symbols


class MappingDisplayGate(ops.Gate):
"""Displays the indices mapped to a set of wires."""

def __init__(self, indices):
self.indices = tuple(indices)
self._num_qubits = len(self.indices)

def num_qubits(self) -> int:
return self._num_qubits

def _circuit_diagram_info_(self, args: protocols.CircuitDiagramInfoArgs
) -> protocols.CircuitDiagramInfo:
wire_symbols = tuple('' if i is None else str(i) for i in self.indices)
return protocols.CircuitDiagramInfo(wire_symbols, connected=False)


def display_mapping(circuit: circuits.Circuit,
initial_mapping: LogicalMapping) -> None:
"""Inserts display gates between moments to indicate the mapping throughout
the circuit."""
qubits = sorted(circuit.all_qubits())
mapping = initial_mapping.copy()

old_moments = circuit._moments
gate = MappingDisplayGate(mapping.get(q) for q in qubits)
new_moments = [ops.Moment([gate(*qubits)])]
for moment in old_moments:
new_moments.append(moment)
update_mapping(mapping, moment)
gate = MappingDisplayGate(mapping.get(q) for q in qubits)
new_moments.append(ops.Moment([gate(*qubits)]))

circuit._moments = new_moments


@value.value_equality
class SwapPermutationGate(PermutationGate):
"""Generic swap gate."""

Expand All @@ -99,6 +135,14 @@ def _decompose_(
self, qubits: Sequence[ops.Qid]) -> ops.OP_TREE:
yield self.swap_gate(*qubits)

def __repr__(self):
return ('cirq.contrib.acquaintance.SwapPermutationGate(' +
('' if self.swap_gate == ops.SWAP else repr(self.swap_gate)) +
')')

def _value_equality_values_(self):
return (self.swap_gate,)


def _canonicalize_permutation(permutation: Dict[int, int]) -> Dict[int, int]:
return {i: j for i, j in permutation.items() if i != j}
Expand Down Expand Up @@ -199,3 +243,27 @@ def __init__(self):
not isinstance(op.gate, SwapPermutationGate)]))

expand_permutation_gates = ExpandPermutationGates()


def return_to_initial_mapping(circuit: circuits.Circuit,
swap_gate: ops.Gate = ops.SWAP) -> None:
qubits = sorted(circuit.all_qubits())
n_qubits = len(qubits)

mapping = {q: i for i, q in enumerate(qubits)}
update_mapping(mapping, circuit.all_operations())

permutation = {i: mapping[q] for i, q in enumerate(qubits)}
returning_permutation_op = LinearPermutationGate(n_qubits, permutation,
swap_gate)(*qubits)
circuit.append(returning_permutation_op)


def uses_consistent_swap_gate(circuit: circuits.Circuit,
swap_gate: ops.Gate) -> bool:
for op in circuit.all_operations():
if (isinstance(op, ops.GateOperation) and
isinstance(op.gate, PermutationGate)):
if op.gate.swap_gate != swap_gate:
return False
return True
80 changes: 77 additions & 3 deletions cirq/contrib/acquaintance/permutation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ def test_validate_permutation_errors():
validate_permutation({})

with pytest.raises(IndexError,
match='key and value sets must be the same.'):
message='key and value sets must be the same.'):
validate_permutation({0: 2, 1: 3})

with pytest.raises(IndexError,
match='keys of the permutation must be non-negative.'):
message='keys of the permutation must be non-negative.'):
validate_permutation({-1: 0, 0: -1})

with pytest.raises(IndexError, match='key is out of bounds.'):
with pytest.raises(IndexError, message='key is out of bounds.'):
validate_permutation({0: 3, 3: 0}, 2)

gate = cca.SwapPermutationGate()
Expand Down Expand Up @@ -246,3 +246,77 @@ def test_linear_permutation_gate_pow_inverse(num_qubits, permutation, inverse):

assert permutation_gate**-1 == inverse_gate
assert cirq.inverse(permutation_gate) == inverse_gate


def test_display_mapping():
indices = [4, 2, 0, 1, 3]
qubits = cirq.LineQubit.range(len(indices))
circuit = cca.complete_acquaintance_strategy(qubits, 2)
cca.expose_acquaintance_gates(circuit)
initial_mapping = dict(zip(qubits, indices))
cca.display_mapping(circuit, initial_mapping)
expected_diagram = """
0: ───4───█───4───╲0╱───2───────2─────────2───█───2───╲0╱───1───────1─────────1───█───1───╲0╱───3───
│ │ │ │ │ │
1: ───2───█───2───╱1╲───4───█───4───╲0╱───1───█───1───╱1╲───2───█───2───╲0╱───3───█───3───╱1╲───1───
│ │ │ │
2: ───0───█───0───╲0╱───1───█───1───╱1╲───4───█───4───╲0╱───3───█───3───╱1╲───2───█───2───╲0╱───0───
│ │ │ │ │ │
3: ───1───█───1───╱1╲───0───█───0───╲0╱───3───█───3───╱1╲───4───█───4───╲0╱───0───█───0───╱1╲───2───
│ │ │ │
4: ───3───────3─────────3───█───3───╱1╲───0───────0─────────0───█───0───╱1╲───4───────4─────────4───
"""
cirq.testing.assert_has_diagram(circuit, expected_diagram)


@pytest.mark.parametrize('circuit', [
cirq.Circuit.from_ops(
cca.SwapPermutationGate()(*qubit_pair)
for qubit_pair in
[random.sample(cirq.LineQubit.range(10), 2)
for _ in range(20)])
for _ in range(4)
])
def test_return_to_initial_mapping(circuit):
qubits = sorted(circuit.all_qubits())
cca.return_to_initial_mapping(circuit)
initial_mapping = {q: i for i, q in enumerate(qubits)}
mapping = dict(initial_mapping)
cca.update_mapping(mapping, circuit.all_operations())
assert mapping == initial_mapping


def test_uses_consistent_swap_gate():
a, b = cirq.LineQubit.range(2)
circuit = cirq.Circuit.from_ops(
[cca.SwapPermutationGate()(a, b),
cca.SwapPermutationGate()(a, b)])
assert cca.uses_consistent_swap_gate(circuit, cirq.SWAP)
assert not cca.uses_consistent_swap_gate(circuit, cirq.CZ)
circuit = cirq.Circuit.from_ops([
cca.SwapPermutationGate(cirq.CZ)(a, b),
cca.SwapPermutationGate(cirq.CZ)(a, b)
])
assert cca.uses_consistent_swap_gate(circuit, cirq.CZ)
assert not cca.uses_consistent_swap_gate(circuit, cirq.SWAP)
circuit = cirq.Circuit.from_ops([
cca.SwapPermutationGate()(a, b),
cca.SwapPermutationGate(cirq.CZ)(a, b)
])
assert not cca.uses_consistent_swap_gate(circuit, cirq.SWAP)
assert not cca.uses_consistent_swap_gate(circuit, cirq.CZ)


def test_swap_gate_eq():
assert cca.SwapPermutationGate() == cca.SwapPermutationGate(cirq.SWAP)
assert cca.SwapPermutationGate() != cca.SwapPermutationGate(cirq.CZ)
assert cca.SwapPermutationGate(cirq.CZ) == cca.SwapPermutationGate(cirq.CZ)


@pytest.mark.parametrize('gate', [
cca.SwapPermutationGate(),
cca.SwapPermutationGate(cirq.SWAP),
cca.SwapPermutationGate(cirq.CZ)
])
def test_swap_gate_repr(gate):
cirq.testing.assert_equivalent_repr(gate)