Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cirq-ft/cirq_ft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
Register,
Registers,
SelectionRegister,
SelectionRegisters,
TComplexity,
map_clean_and_borrowable_qubits,
t_complexity,
Expand Down
6 changes: 3 additions & 3 deletions cirq-ft/cirq_ft/algos/and_gate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@
"source": [
"import cirq\n",
"from cirq.contrib.svg import SVGCircuit\n",
"from cirq_ft import And\n",
"from cirq_ft import And, infra\n",
"\n",
"gate = And()\n",
"r = gate.registers\n",
"quregs = r.get_named_qubits()\n",
"quregs = infra.get_named_qubits(r)\n",
"operation = gate.on_registers(**quregs)\n",
"circuit = cirq.Circuit(operation)\n",
"SVGCircuit(circuit)"
Expand Down Expand Up @@ -223,4 +223,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
4 changes: 3 additions & 1 deletion cirq-ft/cirq_ft/algos/and_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ class And(infra.GateWithRegisters):
ValueError: If number of control values (i.e. `len(self.cv)`) is less than 2.
"""

cv: Tuple[int, ...] = attr.field(default=(1, 1), converter=infra.to_tuple)
cv: Tuple[int, ...] = attr.field(
default=(1, 1), converter=lambda v: (v,) if isinstance(v, int) else tuple(v)
)
adjoint: bool = False

@cv.validator
Expand Down
7 changes: 4 additions & 3 deletions cirq-ft/cirq_ft/algos/and_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import cirq_ft
import numpy as np
import pytest
from cirq_ft import infra
from cirq_ft.infra.jupyter_tools import execute_notebook

random.seed(12345)
Expand All @@ -46,12 +47,12 @@ def test_multi_controlled_and_gate(cv: List[int]):
gate = cirq_ft.And(cv)
r = gate.registers
assert r['ancilla'].total_bits() == r['control'].total_bits() - 2
quregs = r.get_named_qubits()
quregs = infra.get_named_qubits(r)
and_op = gate.on_registers(**quregs)
circuit = cirq.Circuit(and_op)

input_controls = [cv] + [random_cv(len(cv)) for _ in range(10)]
qubit_order = gate.registers.merge_qubits(**quregs)
qubit_order = infra.merge_qubits(gate.registers, **quregs)

for input_control in input_controls:
initial_state = input_control + [0] * (r['ancilla'].total_bits() + 1)
Expand All @@ -77,7 +78,7 @@ def test_multi_controlled_and_gate(cv: List[int]):

def test_and_gate_diagram():
gate = cirq_ft.And((1, 0, 1, 0, 1, 0))
qubit_regs = gate.registers.get_named_qubits()
qubit_regs = infra.get_named_qubits(gate.registers)
op = gate.on_registers(**qubit_regs)
# Qubit order should be alternating (control, ancilla) pairs.
c_and_a = sum(zip(qubit_regs["control"][1:], qubit_regs["ancilla"]), ()) + (
Expand Down
4 changes: 2 additions & 2 deletions cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
"`selection`-th qubit of `target` all controlled by the `control` register.\n",
"\n",
"#### Parameters\n",
" - `selection_regs`: Indexing `select` registers of type `SelectionRegisters`. It also contains information about the iteration length of each selection register.\n",
" - `selection_regs`: Indexing `select` registers of type Tuple[`SelectionRegister`, ...]. It also contains information about the iteration length of each selection register.\n",
" - `nth_gate`: A function mapping the composite selection index to a single-qubit gate.\n",
" - `control_regs`: Control registers for constructing a controlled version of the gate.\n"
]
Expand All @@ -89,7 +89,7 @@
" return cirq.I\n",
"\n",
"apply_z_to_odd = cirq_ft.ApplyGateToLthQubit(\n",
" cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 4)]),\n",
" cirq_ft.SelectionRegister('selection', 3, 4),\n",
" nth_gate=_z_to_odd,\n",
" control_regs=cirq_ft.Registers.build(control=2),\n",
")\n",
Expand Down
48 changes: 26 additions & 22 deletions cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# limitations under the License.

import itertools
from typing import Callable, Sequence
from typing import Callable, Sequence, Tuple

import attr
import cirq
import numpy as np
from cirq._compat import cached_property
from cirq_ft import infra
from cirq_ft.algos import unary_iteration_gate
Expand All @@ -36,8 +37,8 @@ class ApplyGateToLthQubit(unary_iteration_gate.UnaryIterationGate):
`selection`-th qubit of `target` all controlled by the `control` register.

Args:
selection_regs: Indexing `select` registers of type `SelectionRegisters`. It also contains
information about the iteration length of each selection register.
selection_regs: Indexing `select` registers of type Tuple[`SelectionRegisters`, ...].
It also contains information about the iteration length of each selection register.
nth_gate: A function mapping the composite selection index to a single-qubit gate.
control_regs: Control registers for constructing a controlled version of the gate.

Expand All @@ -46,43 +47,45 @@ class ApplyGateToLthQubit(unary_iteration_gate.UnaryIterationGate):
(https://arxiv.org/abs/1805.03662).
Babbush et. al. (2018). Section III.A. and Figure 7.
"""
selection_regs: infra.SelectionRegisters
selection_regs: Tuple[infra.SelectionRegister, ...] = attr.field(
converter=lambda v: (v,) if isinstance(v, infra.SelectionRegister) else tuple(v)
)
nth_gate: Callable[..., cirq.Gate]
control_regs: infra.Registers = infra.Registers.build(control=1)
control_regs: Tuple[infra.Register, ...] = attr.field(
converter=lambda v: (v,) if isinstance(v, infra.Register) else tuple(v),
default=(infra.Register('control', 1),),
)

@classmethod
def make_on(
cls, *, nth_gate: Callable[..., cirq.Gate], **quregs: Sequence[cirq.Qid]
) -> cirq.Operation:
"""Helper constructor to automatically deduce bitsize attributes."""
return cls(
infra.SelectionRegisters(
[
infra.SelectionRegister(
'selection', len(quregs['selection']), len(quregs['target'])
)
]
),
return ApplyGateToLthQubit(
infra.SelectionRegister('selection', len(quregs['selection']), len(quregs['target'])),
nth_gate=nth_gate,
control_regs=infra.Registers.build(control=len(quregs['control'])),
control_regs=infra.Register('control', len(quregs['control'])),
).on_registers(**quregs)

@cached_property
def control_registers(self) -> infra.Registers:
def control_registers(self) -> Tuple[infra.Register, ...]:
return self.control_regs

@cached_property
def selection_registers(self) -> infra.SelectionRegisters:
def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]:
return self.selection_regs

@cached_property
def target_registers(self) -> infra.Registers:
return infra.Registers.build(target=self.selection_registers.total_iteration_size)
def target_registers(self) -> Tuple[infra.Register, ...]:
total_iteration_size = np.product(
tuple(reg.iteration_length for reg in self.selection_registers)
)
return (infra.Register('target', int(total_iteration_size)),)

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
wire_symbols = ["@"] * self.control_registers.total_bits()
wire_symbols += ["In"] * self.selection_registers.total_bits()
for it in itertools.product(*[range(x) for x in self.selection_regs.iteration_lengths]):
wire_symbols = ["@"] * infra.total_bits(self.control_registers)
wire_symbols += ["In"] * infra.total_bits(self.selection_registers)
for it in itertools.product(*[range(reg.iteration_length) for reg in self.selection_regs]):
wire_symbols += [str(self.nth_gate(*it))]
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)

Expand All @@ -93,6 +96,7 @@ def nth_operation( # type: ignore[override]
target: Sequence[cirq.Qid],
**selection_indices: int,
) -> cirq.OP_TREE:
selection_shape = tuple(reg.iteration_length for reg in self.selection_regs)
selection_idx = tuple(selection_indices[reg.name] for reg in self.selection_regs)
target_idx = self.selection_registers.to_flat_idx(*selection_idx)
target_idx = int(np.ravel_multi_index(selection_idx, selection_shape))
return self.nth_gate(*selection_idx).on(target[target_idx]).controlled_by(control)
20 changes: 9 additions & 11 deletions cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import cirq
import cirq_ft
import pytest
from cirq_ft import infra
from cirq_ft.infra.bit_tools import iter_bits
from cirq_ft.infra.jupyter_tools import execute_notebook

Expand All @@ -23,16 +24,13 @@
def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize):
greedy_mm = cirq_ft.GreedyQubitManager(prefix="_a", maximize_reuse=True)
gate = cirq_ft.ApplyGateToLthQubit(
cirq_ft.SelectionRegisters(
[cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)]
),
lambda _: cirq.X,
cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize), lambda _: cirq.X
)
g = cirq_ft.testing.GateHelper(gate, context=cirq.DecompositionContext(greedy_mm))
# Upper bounded because not all ancillas may be used as part of unary iteration.
assert (
len(g.all_qubits)
<= target_bitsize + 2 * (selection_bitsize + gate.control_registers.total_bits()) - 1
<= target_bitsize + 2 * (selection_bitsize + infra.total_bits(gate.control_registers)) - 1
)

for n in range(target_bitsize):
Expand All @@ -54,12 +52,12 @@ def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize):
def test_apply_gate_to_lth_qubit_diagram():
# Apply Z gate to all odd targets and Identity to even targets.
gate = cirq_ft.ApplyGateToLthQubit(
cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 5)]),
cirq_ft.SelectionRegister('selection', 3, 5),
lambda n: cirq.Z if n & 1 else cirq.I,
control_regs=cirq_ft.Registers.build(control=2),
)
circuit = cirq.Circuit(gate.on_registers(**gate.registers.get_named_qubits()))
qubits = list(q for v in gate.registers.get_named_qubits().values() for q in v)
circuit = cirq.Circuit(gate.on_registers(**infra.get_named_qubits(gate.registers)))
qubits = list(q for v in infra.get_named_qubits(gate.registers).values() for q in v)
cirq.testing.assert_has_diagram(
circuit,
"""
Expand Down Expand Up @@ -89,13 +87,13 @@ def test_apply_gate_to_lth_qubit_diagram():

def test_apply_gate_to_lth_qubit_make_on():
gate = cirq_ft.ApplyGateToLthQubit(
cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 5)]),
cirq_ft.SelectionRegister('selection', 3, 5),
lambda n: cirq.Z if n & 1 else cirq.I,
control_regs=cirq_ft.Registers.build(control=2),
)
op = gate.on_registers(**gate.registers.get_named_qubits())
op = gate.on_registers(**infra.get_named_qubits(gate.registers))
op2 = cirq_ft.ApplyGateToLthQubit.make_on(
nth_gate=lambda n: cirq.Z if n & 1 else cirq.I, **gate.registers.get_named_qubits()
nth_gate=lambda n: cirq.Z if n & 1 else cirq.I, **infra.get_named_qubits(gate.registers)
)
# Note: ApplyGateToLthQubit doesn't support value equality.
assert op.qubits == op2.qubits
Expand Down
4 changes: 3 additions & 1 deletion cirq-ft/cirq_ft/algos/arithmetic_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,9 @@ class AddMod(cirq.ArithmeticGate):
bitsize: int
mod: int = attr.field()
add_val: int = 1
cv: Tuple[int, ...] = attr.field(converter=infra.to_tuple, default=())
cv: Tuple[int, ...] = attr.field(
converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=()
)

@mod.validator
def _validate_mod(self, attribute, value):
Expand Down
21 changes: 9 additions & 12 deletions cirq-ft/cirq_ft/algos/generic_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,23 +68,20 @@ def __attrs_post_init__(self):
)

@cached_property
def control_registers(self) -> infra.Registers:
registers = [] if self.control_val is None else [infra.Register('control', 1)]
return infra.Registers(registers)
def control_registers(self) -> Tuple[infra.Register, ...]:
return () if self.control_val is None else (infra.Register('control', 1),)

@cached_property
def selection_registers(self) -> infra.SelectionRegisters:
return infra.SelectionRegisters(
[
infra.SelectionRegister(
'selection', self.selection_bitsize, len(self.select_unitaries)
)
]
def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]:
return (
infra.SelectionRegister(
'selection', self.selection_bitsize, len(self.select_unitaries)
),
)

@cached_property
def target_registers(self) -> infra.Registers:
return infra.Registers.build(target=self.target_bitsize)
def target_registers(self) -> Tuple[infra.Register, ...]:
return (infra.Register('target', self.target_bitsize),)

def decompose_from_registers(
self, context, **quregs: NDArray[cirq.Qid] # type:ignore[type-var]
Expand Down
3 changes: 2 additions & 1 deletion cirq-ft/cirq_ft/algos/generic_select_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import cirq_ft
import numpy as np
import pytest
from cirq_ft import infra
from cirq_ft.infra.bit_tools import iter_bits
from cirq_ft.infra.jupyter_tools import execute_notebook

Expand Down Expand Up @@ -255,7 +256,7 @@ def test_generic_select_consistent_protocols_and_controlled():

# Build GenericSelect gate.
gate = cirq_ft.GenericSelect(select_bitsize, num_sites, dps_hamiltonian)
op = gate.on_registers(**gate.registers.get_named_qubits())
op = gate.on_registers(**infra.get_named_qubits(gate.registers))
cirq.testing.assert_equivalent_repr(gate, setup_code='import cirq\nimport cirq_ft')

# Build controlled gate
Expand Down
Loading