Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cirq-ft/cirq_ft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
Register,
Registers,
SelectionRegister,
SelectionRegisters,
TComplexity,
map_clean_and_borrowable_qubits,
t_complexity,
Expand Down
6 changes: 3 additions & 3 deletions cirq-ft/cirq_ft/algos/and_gate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@
"source": [
"import cirq\n",
"from cirq.contrib.svg import SVGCircuit\n",
"from cirq_ft import And\n",
"from cirq_ft import And, infra\n",
"\n",
"gate = And()\n",
"r = gate.registers\n",
"quregs = r.get_named_qubits()\n",
"quregs = infra.get_named_qubits(r)\n",
"operation = gate.on_registers(**quregs)\n",
"circuit = cirq.Circuit(operation)\n",
"SVGCircuit(circuit)"
Expand Down Expand Up @@ -223,4 +223,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
7 changes: 4 additions & 3 deletions cirq-ft/cirq_ft/algos/and_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import cirq_ft
import numpy as np
import pytest
from cirq_ft import infra
from cirq_ft.infra.jupyter_tools import execute_notebook

random.seed(12345)
Expand All @@ -46,12 +47,12 @@ def test_multi_controlled_and_gate(cv: List[int]):
gate = cirq_ft.And(cv)
r = gate.registers
assert r['ancilla'].total_bits() == r['control'].total_bits() - 2
quregs = r.get_named_qubits()
quregs = infra.get_named_qubits(r)
and_op = gate.on_registers(**quregs)
circuit = cirq.Circuit(and_op)

input_controls = [cv] + [random_cv(len(cv)) for _ in range(10)]
qubit_order = gate.registers.merge_qubits(**quregs)
qubit_order = infra.merge_qubits(gate.registers, **quregs)

for input_control in input_controls:
initial_state = input_control + [0] * (r['ancilla'].total_bits() + 1)
Expand All @@ -77,7 +78,7 @@ def test_multi_controlled_and_gate(cv: List[int]):

def test_and_gate_diagram():
gate = cirq_ft.And((1, 0, 1, 0, 1, 0))
qubit_regs = gate.registers.get_named_qubits()
qubit_regs = infra.get_named_qubits(gate.registers)
op = gate.on_registers(**qubit_regs)
# Qubit order should be alternating (control, ancilla) pairs.
c_and_a = sum(zip(qubit_regs["control"][1:], qubit_regs["ancilla"]), ()) + (
Expand Down
4 changes: 2 additions & 2 deletions cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
"`selection`-th qubit of `target` all controlled by the `control` register.\n",
"\n",
"#### Parameters\n",
" - `selection_regs`: Indexing `select` registers of type `SelectionRegisters`. It also contains information about the iteration length of each selection register.\n",
" - `selection_regs`: Indexing `select` registers of type Tuple[`SelectionRegister`, ...]. It also contains information about the iteration length of each selection register.\n",
" - `nth_gate`: A function mapping the composite selection index to a single-qubit gate.\n",
" - `control_regs`: Control registers for constructing a controlled version of the gate.\n"
]
Expand All @@ -89,7 +89,7 @@
" return cirq.I\n",
"\n",
"apply_z_to_odd = cirq_ft.ApplyGateToLthQubit(\n",
" cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 4)]),\n",
" cirq_ft.SelectionRegister('selection', 3, 4),\n",
" nth_gate=_z_to_odd,\n",
" control_regs=cirq_ft.Registers.build(control=2),\n",
")\n",
Expand Down
48 changes: 26 additions & 22 deletions cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# limitations under the License.

import itertools
from typing import Callable, Sequence
from typing import Callable, Sequence, Tuple

import attr
import cirq
import numpy as np
from cirq._compat import cached_property
from cirq_ft import infra
from cirq_ft.algos import unary_iteration_gate
Expand All @@ -36,8 +37,8 @@ class ApplyGateToLthQubit(unary_iteration_gate.UnaryIterationGate):
`selection`-th qubit of `target` all controlled by the `control` register.

Args:
selection_regs: Indexing `select` registers of type `SelectionRegisters`. It also contains
information about the iteration length of each selection register.
selection_regs: Indexing `select` registers of type Tuple[`SelectionRegisters`, ...].
It also contains information about the iteration length of each selection register.
nth_gate: A function mapping the composite selection index to a single-qubit gate.
control_regs: Control registers for constructing a controlled version of the gate.

Expand All @@ -46,43 +47,45 @@ class ApplyGateToLthQubit(unary_iteration_gate.UnaryIterationGate):
(https://arxiv.org/abs/1805.03662).
Babbush et. al. (2018). Section III.A. and Figure 7.
"""
selection_regs: infra.SelectionRegisters
selection_regs: Tuple[infra.SelectionRegister, ...] = attr.field(
converter=lambda v: (v,) if isinstance(v, infra.SelectionRegister) else tuple(v)
)
nth_gate: Callable[..., cirq.Gate]
control_regs: infra.Registers = infra.Registers.build(control=1)
control_regs: Tuple[infra.Register, ...] = attr.field(
converter=lambda v: (v,) if isinstance(v, infra.Register) else tuple(v),
default=(infra.Register('control', 1),),
)

@classmethod
def make_on(
cls, *, nth_gate: Callable[..., cirq.Gate], **quregs: Sequence[cirq.Qid]
) -> cirq.Operation:
"""Helper constructor to automatically deduce bitsize attributes."""
return cls(
infra.SelectionRegisters(
[
infra.SelectionRegister(
'selection', len(quregs['selection']), len(quregs['target'])
)
]
),
return ApplyGateToLthQubit(
infra.SelectionRegister('selection', len(quregs['selection']), len(quregs['target'])),
nth_gate=nth_gate,
control_regs=infra.Registers.build(control=len(quregs['control'])),
control_regs=infra.Register('control', len(quregs['control'])),
).on_registers(**quregs)

@cached_property
def control_registers(self) -> infra.Registers:
def control_registers(self) -> Tuple[infra.Register, ...]:
return self.control_regs

@cached_property
def selection_registers(self) -> infra.SelectionRegisters:
def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]:
return self.selection_regs

@cached_property
def target_registers(self) -> infra.Registers:
return infra.Registers.build(target=self.selection_registers.total_iteration_size)
def target_registers(self) -> Tuple[infra.Register, ...]:
total_iteration_size = np.product(
tuple(reg.iteration_length for reg in self.selection_registers)
)
return (infra.Register('target', int(total_iteration_size)),)

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
wire_symbols = ["@"] * self.control_registers.total_bits()
wire_symbols += ["In"] * self.selection_registers.total_bits()
for it in itertools.product(*[range(x) for x in self.selection_regs.iteration_lengths]):
wire_symbols = ["@"] * infra.total_bits(self.control_registers)
wire_symbols += ["In"] * infra.total_bits(self.selection_registers)
for it in itertools.product(*[range(reg.iteration_length) for reg in self.selection_regs]):
wire_symbols += [str(self.nth_gate(*it))]
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)

Expand All @@ -93,6 +96,7 @@ def nth_operation( # type: ignore[override]
target: Sequence[cirq.Qid],
**selection_indices: int,
) -> cirq.OP_TREE:
selection_shape = tuple(reg.iteration_length for reg in self.selection_regs)
selection_idx = tuple(selection_indices[reg.name] for reg in self.selection_regs)
target_idx = self.selection_registers.to_flat_idx(*selection_idx)
target_idx = int(np.ravel_multi_index(selection_idx, selection_shape))
return self.nth_gate(*selection_idx).on(target[target_idx]).controlled_by(control)
20 changes: 9 additions & 11 deletions cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import cirq
import cirq_ft
import pytest
from cirq_ft import infra
from cirq_ft.infra.bit_tools import iter_bits
from cirq_ft.infra.jupyter_tools import execute_notebook

Expand All @@ -23,16 +24,13 @@
def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize):
greedy_mm = cirq_ft.GreedyQubitManager(prefix="_a", maximize_reuse=True)
gate = cirq_ft.ApplyGateToLthQubit(
cirq_ft.SelectionRegisters(
[cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)]
),
lambda _: cirq.X,
cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize), lambda _: cirq.X
)
g = cirq_ft.testing.GateHelper(gate, context=cirq.DecompositionContext(greedy_mm))
# Upper bounded because not all ancillas may be used as part of unary iteration.
assert (
len(g.all_qubits)
<= target_bitsize + 2 * (selection_bitsize + gate.control_registers.total_bits()) - 1
<= target_bitsize + 2 * (selection_bitsize + infra.total_bits(gate.control_registers)) - 1
)

for n in range(target_bitsize):
Expand All @@ -54,12 +52,12 @@ def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize):
def test_apply_gate_to_lth_qubit_diagram():
# Apply Z gate to all odd targets and Identity to even targets.
gate = cirq_ft.ApplyGateToLthQubit(
cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 5)]),
cirq_ft.SelectionRegister('selection', 3, 5),
lambda n: cirq.Z if n & 1 else cirq.I,
control_regs=cirq_ft.Registers.build(control=2),
)
circuit = cirq.Circuit(gate.on_registers(**gate.registers.get_named_qubits()))
qubits = list(q for v in gate.registers.get_named_qubits().values() for q in v)
circuit = cirq.Circuit(gate.on_registers(**infra.get_named_qubits(gate.registers)))
qubits = list(q for v in infra.get_named_qubits(gate.registers).values() for q in v)
cirq.testing.assert_has_diagram(
circuit,
"""
Expand Down Expand Up @@ -89,13 +87,13 @@ def test_apply_gate_to_lth_qubit_diagram():

def test_apply_gate_to_lth_qubit_make_on():
gate = cirq_ft.ApplyGateToLthQubit(
cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 5)]),
cirq_ft.SelectionRegister('selection', 3, 5),
lambda n: cirq.Z if n & 1 else cirq.I,
control_regs=cirq_ft.Registers.build(control=2),
)
op = gate.on_registers(**gate.registers.get_named_qubits())
op = gate.on_registers(**infra.get_named_qubits(gate.registers))
op2 = cirq_ft.ApplyGateToLthQubit.make_on(
nth_gate=lambda n: cirq.Z if n & 1 else cirq.I, **gate.registers.get_named_qubits()
nth_gate=lambda n: cirq.Z if n & 1 else cirq.I, **infra.get_named_qubits(gate.registers)
)
# Note: ApplyGateToLthQubit doesn't support value equality.
assert op.qubits == op2.qubits
Expand Down
21 changes: 9 additions & 12 deletions cirq-ft/cirq_ft/algos/generic_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,23 +68,20 @@ def __attrs_post_init__(self):
)

@cached_property
def control_registers(self) -> infra.Registers:
registers = [] if self.control_val is None else [infra.Register('control', 1)]
return infra.Registers(registers)
def control_registers(self) -> Tuple[infra.Register, ...]:
return () if self.control_val is None else (infra.Register('control', 1),)

@cached_property
def selection_registers(self) -> infra.SelectionRegisters:
return infra.SelectionRegisters(
[
infra.SelectionRegister(
'selection', self.selection_bitsize, len(self.select_unitaries)
)
]
def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]:
return (
infra.SelectionRegister(
'selection', self.selection_bitsize, len(self.select_unitaries)
),
)

@cached_property
def target_registers(self) -> infra.Registers:
return infra.Registers.build(target=self.target_bitsize)
def target_registers(self) -> Tuple[infra.Register, ...]:
return (infra.Register('target', self.target_bitsize),)

def decompose_from_registers(
self, context, **quregs: NDArray[cirq.Qid] # type:ignore[type-var]
Expand Down
3 changes: 2 additions & 1 deletion cirq-ft/cirq_ft/algos/generic_select_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import cirq_ft
import numpy as np
import pytest
from cirq_ft import infra
from cirq_ft.infra.bit_tools import iter_bits
from cirq_ft.infra.jupyter_tools import execute_notebook

Expand Down Expand Up @@ -255,7 +256,7 @@ def test_generic_select_consistent_protocols_and_controlled():

# Build GenericSelect gate.
gate = cirq_ft.GenericSelect(select_bitsize, num_sites, dps_hamiltonian)
op = gate.on_registers(**gate.registers.get_named_qubits())
op = gate.on_registers(**infra.get_named_qubits(gate.registers))
cirq.testing.assert_equivalent_repr(gate, setup_code='import cirq\nimport cirq_ft')

# Build controlled gate
Expand Down
Loading