diff --git a/cirq-ft/cirq_ft/__init__.py b/cirq-ft/cirq_ft/__init__.py index 00053e949b5..47bf47cf660 100644 --- a/cirq-ft/cirq_ft/__init__.py +++ b/cirq-ft/cirq_ft/__init__.py @@ -49,7 +49,6 @@ Register, Registers, SelectionRegister, - SelectionRegisters, TComplexity, map_clean_and_borrowable_qubits, t_complexity, diff --git a/cirq-ft/cirq_ft/algos/and_gate.ipynb b/cirq-ft/cirq_ft/algos/and_gate.ipynb index ca1562840ad..498081f0cb4 100644 --- a/cirq-ft/cirq_ft/algos/and_gate.ipynb +++ b/cirq-ft/cirq_ft/algos/and_gate.ipynb @@ -63,11 +63,11 @@ "source": [ "import cirq\n", "from cirq.contrib.svg import SVGCircuit\n", - "from cirq_ft import And\n", + "from cirq_ft import And, infra\n", "\n", "gate = And()\n", "r = gate.registers\n", - "quregs = r.get_named_qubits()\n", + "quregs = infra.get_named_qubits(r)\n", "operation = gate.on_registers(**quregs)\n", "circuit = cirq.Circuit(operation)\n", "SVGCircuit(circuit)" @@ -223,4 +223,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/cirq-ft/cirq_ft/algos/and_gate.py b/cirq-ft/cirq_ft/algos/and_gate.py index 973528386dc..f308926d632 100644 --- a/cirq-ft/cirq_ft/algos/and_gate.py +++ b/cirq-ft/cirq_ft/algos/and_gate.py @@ -49,7 +49,9 @@ class And(infra.GateWithRegisters): ValueError: If number of control values (i.e. `len(self.cv)`) is less than 2. """ - cv: Tuple[int, ...] = attr.field(default=(1, 1), converter=infra.to_tuple) + cv: Tuple[int, ...] = attr.field( + default=(1, 1), converter=lambda v: (v,) if isinstance(v, int) else tuple(v) + ) adjoint: bool = False @cv.validator diff --git a/cirq-ft/cirq_ft/algos/and_gate_test.py b/cirq-ft/cirq_ft/algos/and_gate_test.py index f41b6a271c1..70de51a205b 100644 --- a/cirq-ft/cirq_ft/algos/and_gate_test.py +++ b/cirq-ft/cirq_ft/algos/and_gate_test.py @@ -20,6 +20,7 @@ import cirq_ft import numpy as np import pytest +from cirq_ft import infra from cirq_ft.infra.jupyter_tools import execute_notebook random.seed(12345) @@ -46,12 +47,12 @@ def test_multi_controlled_and_gate(cv: List[int]): gate = cirq_ft.And(cv) r = gate.registers assert r['ancilla'].total_bits() == r['control'].total_bits() - 2 - quregs = r.get_named_qubits() + quregs = infra.get_named_qubits(r) and_op = gate.on_registers(**quregs) circuit = cirq.Circuit(and_op) input_controls = [cv] + [random_cv(len(cv)) for _ in range(10)] - qubit_order = gate.registers.merge_qubits(**quregs) + qubit_order = infra.merge_qubits(gate.registers, **quregs) for input_control in input_controls: initial_state = input_control + [0] * (r['ancilla'].total_bits() + 1) @@ -77,7 +78,7 @@ def test_multi_controlled_and_gate(cv: List[int]): def test_and_gate_diagram(): gate = cirq_ft.And((1, 0, 1, 0, 1, 0)) - qubit_regs = gate.registers.get_named_qubits() + qubit_regs = infra.get_named_qubits(gate.registers) op = gate.on_registers(**qubit_regs) # Qubit order should be alternating (control, ancilla) pairs. c_and_a = sum(zip(qubit_regs["control"][1:], qubit_regs["ancilla"]), ()) + ( diff --git a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb index 90e4ac086fe..3e34704c619 100644 --- a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb +++ b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb @@ -69,7 +69,7 @@ "`selection`-th qubit of `target` all controlled by the `control` register.\n", "\n", "#### Parameters\n", - " - `selection_regs`: Indexing `select` registers of type `SelectionRegisters`. It also contains information about the iteration length of each selection register.\n", + " - `selection_regs`: Indexing `select` registers of type Tuple[`SelectionRegister`, ...]. It also contains information about the iteration length of each selection register.\n", " - `nth_gate`: A function mapping the composite selection index to a single-qubit gate.\n", " - `control_regs`: Control registers for constructing a controlled version of the gate.\n" ] @@ -89,7 +89,7 @@ " return cirq.I\n", "\n", "apply_z_to_odd = cirq_ft.ApplyGateToLthQubit(\n", - " cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 4)]),\n", + " cirq_ft.SelectionRegister('selection', 3, 4),\n", " nth_gate=_z_to_odd,\n", " control_regs=cirq_ft.Registers.build(control=2),\n", ")\n", diff --git a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py index e796b1b05f0..e3bb08be143 100644 --- a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py +++ b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py @@ -13,10 +13,11 @@ # limitations under the License. import itertools -from typing import Callable, Sequence +from typing import Callable, Sequence, Tuple import attr import cirq +import numpy as np from cirq._compat import cached_property from cirq_ft import infra from cirq_ft.algos import unary_iteration_gate @@ -36,8 +37,8 @@ class ApplyGateToLthQubit(unary_iteration_gate.UnaryIterationGate): `selection`-th qubit of `target` all controlled by the `control` register. Args: - selection_regs: Indexing `select` registers of type `SelectionRegisters`. It also contains - information about the iteration length of each selection register. + selection_regs: Indexing `select` registers of type Tuple[`SelectionRegisters`, ...]. + It also contains information about the iteration length of each selection register. nth_gate: A function mapping the composite selection index to a single-qubit gate. control_regs: Control registers for constructing a controlled version of the gate. @@ -46,43 +47,45 @@ class ApplyGateToLthQubit(unary_iteration_gate.UnaryIterationGate): (https://arxiv.org/abs/1805.03662). Babbush et. al. (2018). Section III.A. and Figure 7. """ - selection_regs: infra.SelectionRegisters + selection_regs: Tuple[infra.SelectionRegister, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, infra.SelectionRegister) else tuple(v) + ) nth_gate: Callable[..., cirq.Gate] - control_regs: infra.Registers = infra.Registers.build(control=1) + control_regs: Tuple[infra.Register, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, infra.Register) else tuple(v), + default=(infra.Register('control', 1),), + ) @classmethod def make_on( cls, *, nth_gate: Callable[..., cirq.Gate], **quregs: Sequence[cirq.Qid] ) -> cirq.Operation: """Helper constructor to automatically deduce bitsize attributes.""" - return cls( - infra.SelectionRegisters( - [ - infra.SelectionRegister( - 'selection', len(quregs['selection']), len(quregs['target']) - ) - ] - ), + return ApplyGateToLthQubit( + infra.SelectionRegister('selection', len(quregs['selection']), len(quregs['target'])), nth_gate=nth_gate, - control_regs=infra.Registers.build(control=len(quregs['control'])), + control_regs=infra.Register('control', len(quregs['control'])), ).on_registers(**quregs) @cached_property - def control_registers(self) -> infra.Registers: + def control_registers(self) -> Tuple[infra.Register, ...]: return self.control_regs @cached_property - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: return self.selection_regs @cached_property - def target_registers(self) -> infra.Registers: - return infra.Registers.build(target=self.selection_registers.total_iteration_size) + def target_registers(self) -> Tuple[infra.Register, ...]: + total_iteration_size = np.product( + tuple(reg.iteration_length for reg in self.selection_registers) + ) + return (infra.Register('target', int(total_iteration_size)),) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: - wire_symbols = ["@"] * self.control_registers.total_bits() - wire_symbols += ["In"] * self.selection_registers.total_bits() - for it in itertools.product(*[range(x) for x in self.selection_regs.iteration_lengths]): + wire_symbols = ["@"] * infra.total_bits(self.control_registers) + wire_symbols += ["In"] * infra.total_bits(self.selection_registers) + for it in itertools.product(*[range(reg.iteration_length) for reg in self.selection_regs]): wire_symbols += [str(self.nth_gate(*it))] return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) @@ -93,6 +96,7 @@ def nth_operation( # type: ignore[override] target: Sequence[cirq.Qid], **selection_indices: int, ) -> cirq.OP_TREE: + selection_shape = tuple(reg.iteration_length for reg in self.selection_regs) selection_idx = tuple(selection_indices[reg.name] for reg in self.selection_regs) - target_idx = self.selection_registers.to_flat_idx(*selection_idx) + target_idx = int(np.ravel_multi_index(selection_idx, selection_shape)) return self.nth_gate(*selection_idx).on(target[target_idx]).controlled_by(control) diff --git a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py index da285792d36..2c2e29e7c0c 100644 --- a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py +++ b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py @@ -15,6 +15,7 @@ import cirq import cirq_ft import pytest +from cirq_ft import infra from cirq_ft.infra.bit_tools import iter_bits from cirq_ft.infra.jupyter_tools import execute_notebook @@ -23,16 +24,13 @@ def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize): greedy_mm = cirq_ft.GreedyQubitManager(prefix="_a", maximize_reuse=True) gate = cirq_ft.ApplyGateToLthQubit( - cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)] - ), - lambda _: cirq.X, + cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize), lambda _: cirq.X ) g = cirq_ft.testing.GateHelper(gate, context=cirq.DecompositionContext(greedy_mm)) # Upper bounded because not all ancillas may be used as part of unary iteration. assert ( len(g.all_qubits) - <= target_bitsize + 2 * (selection_bitsize + gate.control_registers.total_bits()) - 1 + <= target_bitsize + 2 * (selection_bitsize + infra.total_bits(gate.control_registers)) - 1 ) for n in range(target_bitsize): @@ -54,12 +52,12 @@ def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize): def test_apply_gate_to_lth_qubit_diagram(): # Apply Z gate to all odd targets and Identity to even targets. gate = cirq_ft.ApplyGateToLthQubit( - cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 5)]), + cirq_ft.SelectionRegister('selection', 3, 5), lambda n: cirq.Z if n & 1 else cirq.I, control_regs=cirq_ft.Registers.build(control=2), ) - circuit = cirq.Circuit(gate.on_registers(**gate.registers.get_named_qubits())) - qubits = list(q for v in gate.registers.get_named_qubits().values() for q in v) + circuit = cirq.Circuit(gate.on_registers(**infra.get_named_qubits(gate.registers))) + qubits = list(q for v in infra.get_named_qubits(gate.registers).values() for q in v) cirq.testing.assert_has_diagram( circuit, """ @@ -89,13 +87,13 @@ def test_apply_gate_to_lth_qubit_diagram(): def test_apply_gate_to_lth_qubit_make_on(): gate = cirq_ft.ApplyGateToLthQubit( - cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 5)]), + cirq_ft.SelectionRegister('selection', 3, 5), lambda n: cirq.Z if n & 1 else cirq.I, control_regs=cirq_ft.Registers.build(control=2), ) - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) op2 = cirq_ft.ApplyGateToLthQubit.make_on( - nth_gate=lambda n: cirq.Z if n & 1 else cirq.I, **gate.registers.get_named_qubits() + nth_gate=lambda n: cirq.Z if n & 1 else cirq.I, **infra.get_named_qubits(gate.registers) ) # Note: ApplyGateToLthQubit doesn't support value equality. assert op.qubits == op2.qubits diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates.py b/cirq-ft/cirq_ft/algos/arithmetic_gates.py index b75383015e3..6054c90709f 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates.py @@ -611,7 +611,9 @@ class AddMod(cirq.ArithmeticGate): bitsize: int mod: int = attr.field() add_val: int = 1 - cv: Tuple[int, ...] = attr.field(converter=infra.to_tuple, default=()) + cv: Tuple[int, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=() + ) @mod.validator def _validate_mod(self, attribute, value): diff --git a/cirq-ft/cirq_ft/algos/generic_select.py b/cirq-ft/cirq_ft/algos/generic_select.py index 68d62cf98f6..8822beb32f6 100644 --- a/cirq-ft/cirq_ft/algos/generic_select.py +++ b/cirq-ft/cirq_ft/algos/generic_select.py @@ -68,23 +68,20 @@ def __attrs_post_init__(self): ) @cached_property - def control_registers(self) -> infra.Registers: - registers = [] if self.control_val is None else [infra.Register('control', 1)] - return infra.Registers(registers) + def control_registers(self) -> Tuple[infra.Register, ...]: + return () if self.control_val is None else (infra.Register('control', 1),) @cached_property - def selection_registers(self) -> infra.SelectionRegisters: - return infra.SelectionRegisters( - [ - infra.SelectionRegister( - 'selection', self.selection_bitsize, len(self.select_unitaries) - ) - ] + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: + return ( + infra.SelectionRegister( + 'selection', self.selection_bitsize, len(self.select_unitaries) + ), ) @cached_property - def target_registers(self) -> infra.Registers: - return infra.Registers.build(target=self.target_bitsize) + def target_registers(self) -> Tuple[infra.Register, ...]: + return (infra.Register('target', self.target_bitsize),) def decompose_from_registers( self, context, **quregs: NDArray[cirq.Qid] # type:ignore[type-var] diff --git a/cirq-ft/cirq_ft/algos/generic_select_test.py b/cirq-ft/cirq_ft/algos/generic_select_test.py index c074f8d3197..255e9ba6b79 100644 --- a/cirq-ft/cirq_ft/algos/generic_select_test.py +++ b/cirq-ft/cirq_ft/algos/generic_select_test.py @@ -17,6 +17,7 @@ import cirq_ft import numpy as np import pytest +from cirq_ft import infra from cirq_ft.infra.bit_tools import iter_bits from cirq_ft.infra.jupyter_tools import execute_notebook @@ -255,7 +256,7 @@ def test_generic_select_consistent_protocols_and_controlled(): # Build GenericSelect gate. gate = cirq_ft.GenericSelect(select_bitsize, num_sites, dps_hamiltonian) - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) cirq.testing.assert_equivalent_repr(gate, setup_code='import cirq\nimport cirq_ft') # Build controlled gate diff --git a/cirq-ft/cirq_ft/algos/hubbard_model.py b/cirq-ft/cirq_ft/algos/hubbard_model.py index 520d305062a..dad2a443c9a 100644 --- a/cirq-ft/cirq_ft/algos/hubbard_model.py +++ b/cirq-ft/cirq_ft/algos/hubbard_model.py @@ -118,28 +118,25 @@ def __attrs_post_init__(self): raise NotImplementedError("Currently only supports the case where x_dim=y_dim.") @cached_property - def control_registers(self) -> infra.Registers: - registers = [] if self.control_val is None else [infra.Register('control', 1)] - return infra.Registers(registers) + def control_registers(self) -> Tuple[infra.Register, ...]: + return () if self.control_val is None else (infra.Register('control', 1),) @cached_property - def selection_registers(self) -> infra.SelectionRegisters: - return infra.SelectionRegisters( - [ - infra.SelectionRegister('U', 1, 2), - infra.SelectionRegister('V', 1, 2), - infra.SelectionRegister('p_x', (self.x_dim - 1).bit_length(), self.x_dim), - infra.SelectionRegister('p_y', (self.y_dim - 1).bit_length(), self.y_dim), - infra.SelectionRegister('alpha', 1, 2), - infra.SelectionRegister('q_x', (self.x_dim - 1).bit_length(), self.x_dim), - infra.SelectionRegister('q_y', (self.y_dim - 1).bit_length(), self.y_dim), - infra.SelectionRegister('beta', 1, 2), - ] + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: + return ( + infra.SelectionRegister('U', 1, 2), + infra.SelectionRegister('V', 1, 2), + infra.SelectionRegister('p_x', (self.x_dim - 1).bit_length(), self.x_dim), + infra.SelectionRegister('p_y', (self.y_dim - 1).bit_length(), self.y_dim), + infra.SelectionRegister('alpha', 1, 2), + infra.SelectionRegister('q_x', (self.x_dim - 1).bit_length(), self.x_dim), + infra.SelectionRegister('q_y', (self.y_dim - 1).bit_length(), self.y_dim), + infra.SelectionRegister('beta', 1, 2), ) @cached_property - def target_registers(self) -> infra.Registers: - return infra.Registers.build(target=self.x_dim * self.y_dim * 2) + def target_registers(self) -> Tuple[infra.Register, ...]: + return (infra.Register('target', self.x_dim * self.y_dim * 2),) @cached_property def registers(self) -> infra.Registers: @@ -158,12 +155,10 @@ def decompose_from_registers( control, target = quregs.get('control', ()), quregs['target'] yield selected_majorana_fermion.SelectedMajoranaFermionGate( - selection_regs=infra.SelectionRegisters( - [ - infra.SelectionRegister('alpha', 1, 2), - infra.SelectionRegister('p_y', self.registers['p_y'].total_bits(), self.y_dim), - infra.SelectionRegister('p_x', self.registers['p_x'].total_bits(), self.x_dim), - ] + selection_regs=( + infra.SelectionRegister('alpha', 1, 2), + infra.SelectionRegister('p_y', self.registers['p_y'].total_bits(), self.y_dim), + infra.SelectionRegister('p_x', self.registers['p_x'].total_bits(), self.x_dim), ), control_regs=self.control_registers, target_gate=cirq.Y, @@ -173,12 +168,10 @@ def decompose_from_registers( yield swap_network.MultiTargetCSwap.make_on(control=V, target_x=p_y, target_y=q_y) yield swap_network.MultiTargetCSwap.make_on(control=V, target_x=alpha, target_y=beta) - q_selection_regs = infra.SelectionRegisters( - [ - infra.SelectionRegister('beta', 1, 2), - infra.SelectionRegister('q_y', self.registers['q_y'].total_bits(), self.y_dim), - infra.SelectionRegister('q_x', self.registers['q_x'].total_bits(), self.x_dim), - ] + q_selection_regs = ( + infra.SelectionRegister('beta', 1, 2), + infra.SelectionRegister('q_y', self.registers['q_y'].total_bits(), self.y_dim), + infra.SelectionRegister('q_x', self.registers['q_x'].total_bits(), self.x_dim), ) yield selected_majorana_fermion.SelectedMajoranaFermionGate( selection_regs=q_selection_regs, control_regs=self.control_registers, target_gate=cirq.X @@ -194,20 +187,18 @@ def decompose_from_registers( yield cirq.Z(*U).controlled_by(*control) # Fix errant -1 from multiple pauli applications target_qubits_for_apply_to_lth_gate = [ - target[q_selection_regs.to_flat_idx(1, qy, qx)] + target[np.ravel_multi_index((1, qy, qx), (2, self.y_dim, self.x_dim))] for qx in range(self.x_dim) for qy in range(self.y_dim) ] yield apply_gate_to_lth_target.ApplyGateToLthQubit( - selection_regs=infra.SelectionRegisters( - [ - infra.SelectionRegister('q_y', self.registers['q_y'].total_bits(), self.y_dim), - infra.SelectionRegister('q_x', self.registers['q_x'].total_bits(), self.x_dim), - ] + selection_regs=( + infra.SelectionRegister('q_y', self.registers['q_y'].total_bits(), self.y_dim), + infra.SelectionRegister('q_x', self.registers['q_x'].total_bits(), self.x_dim), ), nth_gate=lambda *_: cirq.Z, - control_regs=infra.Registers.build(control=1 + self.control_registers.total_bits()), + control_regs=infra.Register('control', 1 + infra.total_bits(self.control_registers)), ).on_registers( q_x=q_x, q_y=q_y, control=[*V, *control], target=target_qubits_for_apply_to_lth_gate ) @@ -291,23 +282,21 @@ def __attrs_post_init__(self): raise NotImplementedError("Currently only supports the case where x_dim=y_dim.") @cached_property - def selection_registers(self) -> infra.SelectionRegisters: - return infra.SelectionRegisters( - [ - infra.SelectionRegister('U', 1, 2), - infra.SelectionRegister('V', 1, 2), - infra.SelectionRegister('p_x', (self.x_dim - 1).bit_length(), self.x_dim), - infra.SelectionRegister('p_y', (self.y_dim - 1).bit_length(), self.y_dim), - infra.SelectionRegister('alpha', 1, 2), - infra.SelectionRegister('q_x', (self.x_dim - 1).bit_length(), self.x_dim), - infra.SelectionRegister('q_y', (self.y_dim - 1).bit_length(), self.y_dim), - infra.SelectionRegister('beta', 1, 2), - ] + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: + return ( + infra.SelectionRegister('U', 1, 2), + infra.SelectionRegister('V', 1, 2), + infra.SelectionRegister('p_x', (self.x_dim - 1).bit_length(), self.x_dim), + infra.SelectionRegister('p_y', (self.y_dim - 1).bit_length(), self.y_dim), + infra.SelectionRegister('alpha', 1, 2), + infra.SelectionRegister('q_x', (self.x_dim - 1).bit_length(), self.x_dim), + infra.SelectionRegister('q_y', (self.y_dim - 1).bit_length(), self.y_dim), + infra.SelectionRegister('beta', 1, 2), ) @cached_property - def junk_registers(self) -> infra.Registers: - return infra.Registers.build(temp=2) + def junk_registers(self) -> Tuple[infra.Register, ...]: + return (infra.Register('temp', 2),) @cached_property def registers(self) -> infra.Registers: diff --git a/cirq-ft/cirq_ft/algos/hubbard_model_test.py b/cirq-ft/cirq_ft/algos/hubbard_model_test.py index 43c66f9bf0f..b13f9e6dfd6 100644 --- a/cirq-ft/cirq_ft/algos/hubbard_model_test.py +++ b/cirq-ft/cirq_ft/algos/hubbard_model_test.py @@ -15,6 +15,7 @@ import cirq import cirq_ft import pytest +from cirq_ft import infra from cirq_ft.infra.jupyter_tools import execute_notebook @@ -48,7 +49,7 @@ def test_hubbard_model_consistent_protocols(): cirq.testing.assert_equivalent_repr(prepare_gate, setup_code='import cirq_ft') # Build controlled SELECT gate - select_op = select_gate.on_registers(**select_gate.registers.get_named_qubits()) + select_op = select_gate.on_registers(**infra.get_named_qubits(select_gate.registers)) equals_tester = cirq.testing.EqualsTester() equals_tester.add_equality_group( select_gate.controlled(), diff --git a/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle.py b/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle.py index 74f992b0f5d..695a51ba854 100644 --- a/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle.py +++ b/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple from numpy.typing import NDArray import attr @@ -37,11 +38,11 @@ class ComplexPhaseOracle(infra.GateWithRegisters): arctan_bitsize: int = 32 @cached_property - def control_registers(self) -> infra.Registers: + def control_registers(self) -> Tuple[infra.Register, ...]: return self.encoder.control_registers @cached_property - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: return self.encoder.selection_registers @cached_property @@ -58,7 +59,7 @@ def decompose_from_registers( target_reg = { reg.name: qm.qalloc(reg.total_bits()) for reg in self.encoder.target_registers } - target_qubits = self.encoder.target_registers.merge_qubits(**target_reg) + target_qubits = infra.merge_qubits(self.encoder.target_registers, **target_reg) encoder_op = self.encoder.on_registers(**quregs, **target_reg) arctan_sign, arctan_target = qm.qalloc(1), qm.qalloc(self.arctan_bitsize) @@ -78,6 +79,6 @@ def decompose_from_registers( qm.qfree([*arctan_sign, *arctan_target, *target_qubits]) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: - wire_symbols = ['@'] * self.control_registers.total_bits() - wire_symbols += ['ROTy'] * self.selection_registers.total_bits() + wire_symbols = ['@'] * infra.total_bits(self.control_registers) + wire_symbols += ['ROTy'] * infra.total_bits(self.selection_registers) return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) diff --git a/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle_test.py b/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle_test.py index 7e6b61527c2..a7926bf3847 100644 --- a/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle_test.py +++ b/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle_test.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Optional +from typing import Optional, Tuple import cirq import cirq_ft @@ -32,17 +32,16 @@ class DummySelect(cirq_ft.SelectOracle): control_val: Optional[int] = None @cached_property - def control_registers(self) -> cirq_ft.Registers: - registers = [] if self.control_val is None else [cirq_ft.Register('control', 1)] - return cirq_ft.Registers(registers) + def control_registers(self) -> Tuple[cirq_ft.Register, ...]: + return () if self.control_val is None else (cirq_ft.Register('control', 1),) @cached_property - def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters.build(selection=self.bitsize) + def selection_registers(self) -> Tuple[cirq_ft.SelectionRegister, ...]: + return (cirq_ft.SelectionRegister('selection', self.bitsize),) @cached_property - def target_registers(self) -> cirq_ft.Registers: - return cirq_ft.Registers.build(target=self.bitsize) + def target_registers(self) -> Tuple[cirq_ft.Register, ...]: + return (cirq_ft.Register('target', self.bitsize),) def decompose_from_registers(self, context, selection, target): yield [cirq.CNOT(s, t) for s, t in zip(selection, target)] diff --git a/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator.py b/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator.py index f3fff37ecd7..40de332dad1 100644 --- a/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator.py +++ b/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator.py @@ -80,7 +80,9 @@ class MeanEstimationOperator(infra.GateWithRegisters): """ code: CodeForRandomVariable - cv: Tuple[int, ...] = attr.field(converter=infra.to_tuple, default=()) + cv: Tuple[int, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=() + ) power: int = 1 arctan_bitsize: int = 32 @@ -99,11 +101,11 @@ def select(self) -> complex_phase_oracle.ComplexPhaseOracle: return complex_phase_oracle.ComplexPhaseOracle(self.code.encoder, self.arctan_bitsize) @cached_property - def control_registers(self) -> infra.Registers: + def control_registers(self) -> Tuple[infra.Register, ...]: return self.code.encoder.control_registers @cached_property - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: return self.code.encoder.selection_registers @cached_property @@ -130,7 +132,7 @@ def decompose_from_registers( def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: wire_symbols = [] if self.cv == () else [["@(0)", "@"][self.cv[0]]] wire_symbols += ['U_ko'] * ( - self.registers.total_bits() - self.control_registers.total_bits() + infra.total_bits(self.registers) - infra.total_bits(self.control_registers) ) if self.power != 1: wire_symbols[-1] = f'U_ko^{self.power}' diff --git a/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator_test.py b/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator_test.py index f9f7c359165..ef9596dd861 100644 --- a/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator_test.py +++ b/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator_test.py @@ -19,6 +19,7 @@ import numpy as np import pytest from attr import frozen +from cirq_ft import infra from cirq._compat import cached_property from cirq_ft.algos.mean_estimation import CodeForRandomVariable, MeanEstimationOperator from cirq_ft.infra import bit_tools @@ -32,8 +33,8 @@ class BernoulliSynthesizer(cirq_ft.PrepareOracle): nqubits: int @cached_property - def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('q', self.nqubits, 2)]) + def selection_registers(self) -> Tuple[cirq_ft.SelectionRegister, ...]: + return (cirq_ft.SelectionRegister('q', self.nqubits, 2),) def decompose_from_registers( # type:ignore[override] self, context, q: Sequence[cirq.Qid] @@ -54,19 +55,16 @@ class BernoulliEncoder(cirq_ft.SelectOracle): control_val: Optional[int] = None @cached_property - def control_registers(self) -> cirq_ft.Registers: - registers = [] if self.control_val is None else [cirq_ft.Register('control', 1)] - return cirq_ft.Registers(registers) + def control_registers(self) -> Tuple[cirq_ft.Register, ...]: + return () if self.control_val is None else (cirq_ft.Register('control', 1),) @cached_property - def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('q', self.selection_bitsize, 2)] - ) + def selection_registers(self) -> Tuple[cirq_ft.SelectionRegister, ...]: + return (cirq_ft.SelectionRegister('q', self.selection_bitsize, 2),) @cached_property - def target_registers(self) -> cirq_ft.Registers: - return cirq_ft.Registers.build(t=self.target_bitsize) + def target_registers(self) -> Tuple[cirq_ft.Register, ...]: + return (cirq_ft.Register('t', self.target_bitsize),) def decompose_from_registers( # type:ignore[override] self, context, q: Sequence[cirq.Qid], t: Sequence[cirq.Qid] @@ -119,7 +117,7 @@ def satisfies_theorem_321( assert cirq.is_unitary(u) # Compute the final state vector obtained using the synthesizer `Prep |0>` - prep_op = synthesizer.on_registers(**synthesizer.registers.get_named_qubits()) + prep_op = synthesizer.on_registers(**infra.get_named_qubits(synthesizer.registers)) prep_state = cirq.Circuit(prep_op).final_state_vector() expected_hav = abs(mu) * np.sqrt(1 / (1 + s**2)) @@ -174,8 +172,8 @@ class GroverSynthesizer(cirq_ft.PrepareOracle): n: int @cached_property - def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters.build(selection=self.n) + def selection_registers(self) -> Tuple[cirq_ft.SelectionRegister, ...]: + return (cirq_ft.SelectionRegister('selection', self.n),) def decompose_from_registers( # type:ignore[override] self, *, context, selection: Sequence[cirq.Qid] @@ -197,24 +195,24 @@ class GroverEncoder(cirq_ft.SelectOracle): marked_val: int @cached_property - def control_registers(self) -> cirq_ft.Registers: - return cirq_ft.Registers([]) + def control_registers(self) -> Tuple[cirq_ft.Register, ...]: + return () @cached_property - def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters.build(selection=self.n) + def selection_registers(self) -> Tuple[cirq_ft.SelectionRegister, ...]: + return (cirq_ft.SelectionRegister('selection', self.n),) @cached_property - def target_registers(self) -> cirq_ft.Registers: - return cirq_ft.Registers.build(target=self.marked_val.bit_length()) + def target_registers(self) -> Tuple[cirq_ft.Register, ...]: + return (cirq_ft.Register('target', self.marked_val.bit_length()),) def decompose_from_registers( # type:ignore[override] self, context, *, selection: Sequence[cirq.Qid], target: Sequence[cirq.Qid] ) -> cirq.OP_TREE: selection_cv = [ - *bit_tools.iter_bits(self.marked_item, self.selection_registers.total_bits()) + *bit_tools.iter_bits(self.marked_item, infra.total_bits(self.selection_registers)) ] - yval_bin = [*bit_tools.iter_bits(self.marked_val, self.target_registers.total_bits())] + yval_bin = [*bit_tools.iter_bits(self.marked_val, infra.total_bits(self.target_registers))] for b, q in zip(yval_bin, target): if b: @@ -254,7 +252,7 @@ def test_mean_estimation_operator_consistent_protocols(): encoder = BernoulliEncoder(p, (0, y_1), selection_bitsize, target_bitsize) code = CodeForRandomVariable(synthesizer=synthesizer, encoder=encoder) mean_gate = MeanEstimationOperator(code, arctan_bitsize=arctan_bitsize) - op = mean_gate.on_registers(**mean_gate.registers.get_named_qubits()) + op = mean_gate.on_registers(**infra.get_named_qubits(mean_gate.registers)) # Test controlled gate. equals_tester = cirq.testing.EqualsTester() diff --git a/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py b/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py index 6ab7e65e51f..bb96215b729 100644 --- a/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py +++ b/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py @@ -73,7 +73,7 @@ class MultiControlPauli(infra.GateWithRegisters): (https://algassert.com/circuits/2015/06/05/Constructing-Large-Controlled-Nots.html) """ - cvs: Tuple[int, ...] = attr.field(converter=infra.to_tuple) + cvs: Tuple[int, ...] = attr.field(converter=lambda v: (v,) if isinstance(v, int) else tuple(v)) target_gate: cirq.Pauli = cirq.X @cached_property diff --git a/cirq-ft/cirq_ft/algos/phase_estimation_of_quantum_walk.ipynb b/cirq-ft/cirq_ft/algos/phase_estimation_of_quantum_walk.ipynb index 206b4466b32..6121248b9fa 100644 --- a/cirq-ft/cirq_ft/algos/phase_estimation_of_quantum_walk.ipynb +++ b/cirq-ft/cirq_ft/algos/phase_estimation_of_quantum_walk.ipynb @@ -39,6 +39,7 @@ "import numpy as np\n", "\n", "import cirq_ft\n", + "from cirq_ft import infra\n", "\n", "from cirq_ft.algos.qubitization_walk_operator_test import get_walk_operator_for_1d_Ising_model\n", "from cirq_ft.algos.hubbard_model import get_walk_operator_for_hubbard_model" @@ -87,7 +88,7 @@ " Fig. 2\n", " \"\"\"\n", " reflect = walk.reflect\n", - " walk_regs = walk.registers.get_named_qubits()\n", + " walk_regs = infra.get_named_qubits(walk.registers)\n", " reflect_regs = {k:v for k, v in walk_regs.items() if k in reflect.registers}\n", " \n", " reflect_controlled = reflect.controlled(control_values=[0])\n", diff --git a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py index e75f735bfe9..374415e90bc 100644 --- a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py +++ b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py @@ -46,7 +46,9 @@ class PrepareUniformSuperposition(infra.GateWithRegisters): """ n: int - cv: Tuple[int, ...] = attr.field(converter=infra.to_tuple, default=()) + cv: Tuple[int, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=() + ) @cached_property def registers(self) -> infra.Registers: diff --git a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition_test.py b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition_test.py index b89ac0bb698..f58cc671e63 100644 --- a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition_test.py +++ b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition_test.py @@ -14,6 +14,7 @@ import cirq import cirq_ft +from cirq_ft import infra import numpy as np import pytest @@ -51,7 +52,7 @@ def test_prepare_uniform_superposition_t_complexity(n: int): result = cirq_ft.t_complexity(gate) # TODO(#233): Controlled-H is currently counted as a separate rotation, but it can be # implemented using 2 T-gates. - assert result.rotations <= 2 + 2 * gate.registers.total_bits() + assert result.rotations <= 2 + 2 * infra.total_bits(gate.registers) assert result.t <= 12 * (n - 1).bit_length() diff --git a/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array.py b/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array.py index 47e43857247..158ec1a112d 100644 --- a/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array.py +++ b/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array.py @@ -102,22 +102,20 @@ def interleaved_unitary( pass @cached_property - def selection_registers(self) -> infra.SelectionRegisters: - return infra.SelectionRegisters( - [infra.SelectionRegister('selection', self._selection_bitsize, len(self.angles[0]))] - ) + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: + return (infra.SelectionRegister('selection', self._selection_bitsize, len(self.angles[0])),) @cached_property - def kappa_load_target(self) -> infra.Registers: - return infra.Registers.build(kappa_load_target=self.kappa) + def kappa_load_target(self) -> Tuple[infra.Register, ...]: + return (infra.Register('kappa_load_target', self.kappa),) @cached_property - def rotations_target(self) -> infra.Registers: - return infra.Registers.build(rotations_target=self._target_bitsize) + def rotations_target(self) -> Tuple[infra.Register, ...]: + return (infra.Register('rotations_target', self._target_bitsize),) @property @abc.abstractmethod - def interleaved_unitary_target(self) -> infra.Registers: + def interleaved_unitary_target(self) -> Tuple[infra.Register, ...]: pass @cached_property @@ -195,7 +193,7 @@ def __init__( ): super().__init__(*angles, kappa=kappa, rotation_gate=rotation_gate) if not interleaved_unitaries: - identity_gate = cirq.IdentityGate(self.rotations_target.total_bits()) + identity_gate = cirq.IdentityGate(infra.total_bits(self.rotations_target)) interleaved_unitaries = (identity_gate,) * (len(angles) - 1) assert len(interleaved_unitaries) == len(angles) - 1 assert all(cirq.num_qubits(u) == self._target_bitsize for u in interleaved_unitaries) @@ -205,5 +203,5 @@ def interleaved_unitary(self, index: int, **qubit_regs: NDArray[cirq.Qid]) -> ci return self._interleaved_unitaries[index].on(*qubit_regs['rotations_target']) @cached_property - def interleaved_unitary_target(self) -> infra.Registers: - return infra.Registers.build() + def interleaved_unitary_target(self) -> Tuple[infra.Register, ...]: + return () diff --git a/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array_test.py b/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array_test.py index a860bb07b6e..cdd4212bcc3 100644 --- a/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array_test.py +++ b/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array_test.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple from numpy.typing import NDArray import cirq import cirq_ft import numpy as np import pytest +from cirq_ft import infra from cirq._compat import cached_property from cirq_ft.infra.bit_tools import iter_bits @@ -33,8 +35,8 @@ def interleaved_unitary( return two_qubit_ops_factory[index % 2] @cached_property - def interleaved_unitary_target(self) -> cirq_ft.Registers: - return cirq_ft.Registers.build(unrelated_target=1) + def interleaved_unitary_target(self) -> Tuple[cirq_ft.Register, ...]: + return tuple(cirq_ft.Registers.build(unrelated_target=1)) def construct_custom_prga(*args, **kwargs) -> cirq_ft.ProgrammableRotationGateArrayBase: @@ -78,7 +80,7 @@ def test_programmable_rotation_gate_array(angles, kappa, constructor): *programmable_rotation_gate.interleaved_unitary_target, ] ) - rotations_and_unitary_qubits = rotations_and_unitary_registers.merge_qubits(**g.quregs) + rotations_and_unitary_qubits = infra.merge_qubits(rotations_and_unitary_registers, **g.quregs) # Build circuit. simulator = cirq.Simulator(dtype=np.complex128) diff --git a/cirq-ft/cirq_ft/algos/qrom.py b/cirq-ft/cirq_ft/algos/qrom.py index fdbe36792cc..8d09d82ed9b 100644 --- a/cirq-ft/cirq_ft/algos/qrom.py +++ b/cirq-ft/cirq_ft/algos/qrom.py @@ -92,36 +92,26 @@ def __attrs_post_init__(self): assert isinstance(self.target_bitsizes, tuple) @cached_property - def control_registers(self) -> infra.Registers: - return ( - infra.Registers.build(control=self.num_controls) - if self.num_controls - else infra.Registers([]) - ) + def control_registers(self) -> Tuple[infra.Register, ...]: + return () if not self.num_controls else (infra.Register('control', self.num_controls),) @cached_property - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: if len(self.data[0].shape) == 1: - return infra.SelectionRegisters( - [ - infra.SelectionRegister( - 'selection', self.selection_bitsizes[0], self.data[0].shape[0] - ) - ] + return ( + infra.SelectionRegister( + 'selection', self.selection_bitsizes[0], self.data[0].shape[0] + ), ) else: - return infra.SelectionRegisters( - [ - infra.SelectionRegister(f'selection{i}', sb, len) - for i, (len, sb) in enumerate(zip(self.data[0].shape, self.selection_bitsizes)) - ] + return tuple( + infra.SelectionRegister(f'selection{i}', sb, l) + for i, (l, sb) in enumerate(zip(self.data[0].shape, self.selection_bitsizes)) ) @cached_property - def target_registers(self) -> infra.Registers: - return infra.Registers.build( - **{f'target{i}': len for i, len in enumerate(self.target_bitsizes)} - ) + def target_registers(self) -> Tuple[infra.Register, ...]: + return tuple(infra.Register(f'target{i}', l) for i, l in enumerate(self.target_bitsizes)) def __repr__(self) -> str: data_repr = f"({','.join(cirq._compat.proper_repr(d) for d in self.data)})" @@ -147,8 +137,8 @@ def _load_nth_data( def decompose_zero_selection( self, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: - controls = self.control_registers.merge_qubits(**quregs) - target_regs = {k: v for k, v in quregs.items() if k in self.target_registers} + controls = infra.merge_qubits(self.control_registers, **quregs) + target_regs = {reg.name: quregs[reg.name] for reg in self.target_registers} zero_indx = (0,) * len(self.data[0].shape) if self.num_controls == 0: yield self._load_nth_data(zero_indx, cirq.X, **target_regs) @@ -181,7 +171,7 @@ def nth_operation( def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo: wire_symbols = ["@"] * self.num_controls - wire_symbols += ["In"] * self.selection_registers.total_bits() + wire_symbols += ["In"] * infra.total_bits(self.selection_registers) for i, target in enumerate(self.target_registers): wire_symbols += [f"QROM_{i}"] * target.total_bits() return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) diff --git a/cirq-ft/cirq_ft/algos/qrom_test.py b/cirq-ft/cirq_ft/algos/qrom_test.py index 01025ac38c5..514e7f03935 100644 --- a/cirq-ft/cirq_ft/algos/qrom_test.py +++ b/cirq-ft/cirq_ft/algos/qrom_test.py @@ -18,6 +18,7 @@ import cirq_ft import numpy as np import pytest +from cirq_ft import infra from cirq_ft.infra.bit_tools import iter_bits from cirq_ft.infra.jupyter_tools import execute_notebook @@ -34,7 +35,8 @@ def test_qrom_1d(data, num_controls): inverse = cirq.Circuit(cirq.decompose(g.operation**-1, context=g.context)) assert ( - len(inverse.all_qubits()) <= g.r.total_bits() + g.r['selection'].total_bits() + num_controls + len(inverse.all_qubits()) + <= infra.total_bits(g.r) + g.r['selection'].total_bits() + num_controls ) assert inverse.all_qubits() == decomposed_circuit.all_qubits() @@ -73,7 +75,7 @@ def test_qrom_diagram(): d1 = np.array([4, 5, 6]) qrom = cirq_ft.QROM.build(d0, d1) q = cirq.LineQubit.range(cirq.num_qubits(qrom)) - circuit = cirq.Circuit(qrom.on_registers(**qrom.registers.split_qubits(q))) + circuit = cirq.Circuit(qrom.on_registers(**infra.split_qubits(qrom.registers, q))) cirq.testing.assert_has_diagram( circuit, """ @@ -213,7 +215,7 @@ def test_qrom_multi_dim(data, num_controls): assert ( len(inverse.all_qubits()) - <= g.r.total_bits() + qrom.selection_registers.total_bits() + num_controls + <= infra.total_bits(g.r) + infra.total_bits(qrom.selection_registers) + num_controls ) assert inverse.all_qubits() == decomposed_circuit.all_qubits() diff --git a/cirq-ft/cirq_ft/algos/qubitization_walk_operator.py b/cirq-ft/cirq_ft/algos/qubitization_walk_operator.py index 6910de45905..f39964af93f 100644 --- a/cirq-ft/cirq_ft/algos/qubitization_walk_operator.py +++ b/cirq-ft/cirq_ft/algos/qubitization_walk_operator.py @@ -61,15 +61,15 @@ def __attrs_post_init__(self): assert self.select.control_registers == self.reflect.control_registers @cached_property - def control_registers(self) -> infra.Registers: + def control_registers(self) -> Tuple[infra.Register, ...]: return self.select.control_registers @cached_property - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: return self.prepare.selection_registers @cached_property - def target_registers(self) -> infra.Registers: + def target_registers(self) -> Tuple[infra.Register, ...]: return self.select.target_registers @cached_property @@ -99,8 +99,12 @@ def decompose_from_registers( yield reflect_op def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: - wire_symbols = ['@' if self.control_val else '@(0)'] * self.control_registers.total_bits() - wire_symbols += ['W'] * (self.registers.total_bits() - self.control_registers.total_bits()) + wire_symbols = ['@' if self.control_val else '@(0)'] * infra.total_bits( + self.control_registers + ) + wire_symbols += ['W'] * ( + infra.total_bits(self.registers) - infra.total_bits(self.control_registers) + ) wire_symbols[-1] = f'W^{self.power}' if self.power != 1 else 'W' return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) diff --git a/cirq-ft/cirq_ft/algos/qubitization_walk_operator_test.py b/cirq-ft/cirq_ft/algos/qubitization_walk_operator_test.py index 8ea413661da..9b54501e99c 100644 --- a/cirq-ft/cirq_ft/algos/qubitization_walk_operator_test.py +++ b/cirq-ft/cirq_ft/algos/qubitization_walk_operator_test.py @@ -16,6 +16,7 @@ import cirq_ft import numpy as np import pytest +from cirq_ft import infra from cirq_ft.algos.generic_select_test import get_1d_Ising_hamiltonian from cirq_ft.algos.reflection_using_prepare_test import greedily_allocate_ancilla, keep from cirq_ft.infra.jupyter_tools import execute_notebook @@ -31,7 +32,9 @@ def walk_operator_for_pauli_hamiltonian( ham_coeff, probability_epsilon=eps ) select = cirq_ft.GenericSelect( - prepare.selection_registers.total_bits(), select_unitaries=ham_dps, target_bitsize=len(q) + infra.total_bits(prepare.selection_registers), + select_unitaries=ham_dps, + target_bitsize=len(q), ) return cirq_ft.QubitizationWalkOperator(select=select, prepare=prepare) @@ -96,7 +99,7 @@ def test_qubitization_walk_operator_diagrams(): num_sites, eps = 4, 1e-1 walk = get_walk_operator_for_1d_Ising_model(num_sites, eps) # 1. Diagram for $W = SELECT.R_{L}$ - qu_regs = walk.registers.get_named_qubits() + qu_regs = infra.get_named_qubits(walk.registers) walk_op = walk.on_registers(**qu_regs) circuit = cirq.Circuit(cirq.decompose_once(walk_op)) cirq.testing.assert_has_diagram( @@ -214,7 +217,7 @@ def keep(op): def test_qubitization_walk_operator_consistent_protocols_and_controlled(): gate = get_walk_operator_for_1d_Ising_model(4, 1e-1) - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) # Test consistent repr cirq.testing.assert_equivalent_repr( gate, setup_code='import cirq\nimport cirq_ft\nimport numpy as np' diff --git a/cirq-ft/cirq_ft/algos/reflection_using_prepare.py b/cirq-ft/cirq_ft/algos/reflection_using_prepare.py index 361644fc5ee..980465f524d 100644 --- a/cirq-ft/cirq_ft/algos/reflection_using_prepare.py +++ b/cirq-ft/cirq_ft/algos/reflection_using_prepare.py @@ -57,12 +57,11 @@ class ReflectionUsingPrepare(infra.GateWithRegisters): control_val: Optional[int] = None @cached_property - def control_registers(self) -> infra.Registers: - registers = [] if self.control_val is None else [infra.Register('control', 1)] - return infra.Registers(registers) + def control_registers(self) -> Tuple[infra.Register, ...]: + return () if self.control_val is None else (infra.Register('control', 1),) @cached_property - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: return self.prepare_gate.selection_registers @cached_property @@ -87,7 +86,7 @@ def decompose_from_registers( # 1. PREPARE† yield cirq.inverse(prepare_op) # 2. MultiControlled Z, controlled on |000..00> state. - phase_control = self.selection_registers.merge_qubits(**state_prep_selection_regs) + phase_control = infra.merge_qubits(self.selection_registers, **state_prep_selection_regs) yield cirq.X(phase_target) if not self.control_val else [] yield mcmt.MultiControlPauli([0] * len(phase_control), target_gate=cirq.Z).on_registers( controls=phase_control, target=phase_target @@ -102,8 +101,10 @@ def decompose_from_registers( qm.qfree([phase_target]) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: - wire_symbols = ['@' if self.control_val else '@(0)'] * self.control_registers.total_bits() - wire_symbols += ['R_L'] * self.selection_registers.total_bits() + wire_symbols = ['@' if self.control_val else '@(0)'] * infra.total_bits( + self.control_registers + ) + wire_symbols += ['R_L'] * infra.total_bits(self.selection_registers) return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) def __repr__(self): diff --git a/cirq-ft/cirq_ft/algos/reflection_using_prepare_test.py b/cirq-ft/cirq_ft/algos/reflection_using_prepare_test.py index 138415bfd35..b4a74c56ea7 100644 --- a/cirq-ft/cirq_ft/algos/reflection_using_prepare_test.py +++ b/cirq-ft/cirq_ft/algos/reflection_using_prepare_test.py @@ -16,6 +16,7 @@ import cirq import cirq_ft +from cirq_ft import infra import numpy as np import pytest @@ -108,7 +109,7 @@ def test_reflection_using_prepare_diagram(): ) # No control gate = cirq_ft.ReflectionUsingPrepare(prepare_gate, control_val=None) - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) circuit = greedily_allocate_ancilla(cirq.Circuit(cirq.decompose_once(op))) cirq.testing.assert_has_diagram( circuit, @@ -138,7 +139,7 @@ def test_reflection_using_prepare_diagram(): # Control on `|1>` state gate = cirq_ft.ReflectionUsingPrepare(prepare_gate, control_val=1) - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) circuit = greedily_allocate_ancilla(cirq.Circuit(cirq.decompose_once(op))) cirq.testing.assert_has_diagram( circuit, @@ -167,7 +168,7 @@ def test_reflection_using_prepare_diagram(): # Control on `|0>` state gate = cirq_ft.ReflectionUsingPrepare(prepare_gate, control_val=0) - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) circuit = greedily_allocate_ancilla(cirq.Circuit(cirq.decompose_once(op))) cirq.testing.assert_has_diagram( circuit, @@ -203,7 +204,7 @@ def test_reflection_using_prepare_consistent_protocols_and_controlled(): ) # No control gate = cirq_ft.ReflectionUsingPrepare(prepare_gate, control_val=None) - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) # Test consistent repr cirq.testing.assert_equivalent_repr( gate, setup_code='import cirq\nimport cirq_ft\nimport numpy as np' diff --git a/cirq-ft/cirq_ft/algos/select_and_prepare.py b/cirq-ft/cirq_ft/algos/select_and_prepare.py index b85fbccdfb3..72958b7fb4f 100644 --- a/cirq-ft/cirq_ft/algos/select_and_prepare.py +++ b/cirq-ft/cirq_ft/algos/select_and_prepare.py @@ -13,6 +13,7 @@ # limitations under the License. import abc +from typing import Tuple from cirq._compat import cached_property from cirq_ft import infra @@ -38,17 +39,17 @@ class SelectOracle(infra.GateWithRegisters): @property @abc.abstractmethod - def control_registers(self) -> infra.Registers: + def control_registers(self) -> Tuple[infra.Register, ...]: ... @property @abc.abstractmethod - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: ... @property @abc.abstractmethod - def target_registers(self) -> infra.Registers: + def target_registers(self) -> Tuple[infra.Register, ...]: ... @cached_property @@ -75,12 +76,12 @@ class PrepareOracle(infra.GateWithRegisters): @property @abc.abstractmethod - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: ... @cached_property - def junk_registers(self) -> infra.Registers: - return infra.Registers([]) + def junk_registers(self) -> Tuple[infra.Register, ...]: + return () @cached_property def registers(self) -> infra.Registers: diff --git a/cirq-ft/cirq_ft/algos/select_swap_qrom.py b/cirq-ft/cirq_ft/algos/select_swap_qrom.py index 4cde2b1f172..248dd1ab4f5 100644 --- a/cirq-ft/cirq_ft/algos/select_swap_qrom.py +++ b/cirq-ft/cirq_ft/algos/select_swap_qrom.py @@ -138,21 +138,19 @@ def __init__( self._data = tuple(tuple(d) for d in data) @cached_property - def selection_registers(self) -> infra.SelectionRegisters: - return infra.SelectionRegisters( - [ - infra.SelectionRegister( - 'selection', self.selection_q + self.selection_r, self._iteration_length - ) - ] + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: + return ( + infra.SelectionRegister( + 'selection', self.selection_q + self.selection_r, self._iteration_length + ), ) @cached_property - def target_registers(self) -> infra.Registers: - clean_output = {} - for sequence_id in range(self._num_sequences): - clean_output[f'target{sequence_id}'] = self._target_bitsizes[sequence_id] - return infra.Registers.build(**clean_output) + def target_registers(self) -> Tuple[infra.Register, ...]: + return tuple( + infra.Register(f'target{sequence_id}', self._target_bitsizes[sequence_id]) + for sequence_id in range(self._num_sequences) + ) @cached_property def registers(self) -> infra.Registers: @@ -212,15 +210,16 @@ def decompose_from_registers( target_bitsizes=tuple(qrom_target_bitsizes), ) qrom_op = qrom_gate.on_registers( - selection=q, **qrom_gate.target_registers.split_qubits(ordered_target_qubits) + selection=q, **infra.split_qubits(qrom_gate.target_registers, ordered_target_qubits) ) swap_with_zero_gate = swap_network.SwapWithZeroGate( - k, self.target_registers.total_bits(), self.block_size + k, infra.total_bits(self.target_registers), self.block_size ) swap_with_zero_op = swap_with_zero_gate.on_registers( - selection=r, **swap_with_zero_gate.target_registers.split_qubits(ordered_target_qubits) + selection=r, + **infra.split_qubits(swap_with_zero_gate.target_registers, ordered_target_qubits), ) - clean_targets = self.target_registers.merge_qubits(**targets) + clean_targets = infra.merge_qubits(self.target_registers, **targets) cnot_op = cirq.Moment(cirq.CNOT(s, t) for s, t in zip(ordered_target_qubits, clean_targets)) # Yield the operations in correct order. yield qrom_op diff --git a/cirq-ft/cirq_ft/algos/select_swap_qrom_test.py b/cirq-ft/cirq_ft/algos/select_swap_qrom_test.py index f040f312bbf..85c64425d63 100644 --- a/cirq-ft/cirq_ft/algos/select_swap_qrom_test.py +++ b/cirq-ft/cirq_ft/algos/select_swap_qrom_test.py @@ -16,6 +16,7 @@ import cirq_ft import numpy as np import pytest +from cirq_ft import infra from cirq_ft.infra.bit_tools import iter_bits @@ -23,7 +24,7 @@ @pytest.mark.parametrize("block_size", [None, 1, 2, 3]) def test_select_swap_qrom(data, block_size): qrom = cirq_ft.SelectSwapQROM(*data, block_size=block_size) - qubit_regs = qrom.registers.get_named_qubits() + qubit_regs = infra.get_named_qubits(qrom.registers) selection = qubit_regs["selection"] selection_q, selection_r = selection[: qrom.selection_q], selection[qrom.selection_q :] targets = [qubit_regs[f"target{i}"] for i in range(len(data))] @@ -47,7 +48,7 @@ def test_select_swap_qrom(data, block_size): cirq.H.on_each(*dirty_target_ancilla), ) all_qubits = sorted(circuit.all_qubits()) - for selection_integer in range(qrom.selection_registers.iteration_lengths[0]): + for selection_integer in range(qrom.selection_registers[0].iteration_length): svals_q = list(iter_bits(selection_integer // qrom.block_size, len(selection_q))) svals_r = list(iter_bits(selection_integer % qrom.block_size, len(selection_r))) qubit_vals = {x: 0 for x in all_qubits} @@ -77,7 +78,7 @@ def test_qroam_diagram(): blocksize = 2 qrom = cirq_ft.SelectSwapQROM(*data, block_size=blocksize) q = cirq.LineQubit.range(cirq.num_qubits(qrom)) - circuit = cirq.Circuit(qrom.on_registers(**qrom.registers.split_qubits(q))) + circuit = cirq.Circuit(qrom.on_registers(**infra.split_qubits(qrom.registers, q))) cirq.testing.assert_has_diagram( circuit, """ diff --git a/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py b/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py index 501b23bb786..877c81f39a3 100644 --- a/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py +++ b/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Union +from typing import Sequence, Union, Tuple from numpy.typing import NDArray import attr @@ -34,7 +34,7 @@ class SelectedMajoranaFermionGate(unary_iteration_gate.UnaryIterationGate): Args: - selection_regs: Indexing `select` registers of type `SelectionRegisters`. It also contains + selection_regs: Indexing `select` registers of type `SelectionRegister`. It also contains information about the iteration length of each selection register. control_regs: Control registers for constructing a controlled version of the gate. target_gate: Single qubit gate to be applied to the target qubits. @@ -43,8 +43,13 @@ class SelectedMajoranaFermionGate(unary_iteration_gate.UnaryIterationGate): See Fig 9 of https://arxiv.org/abs/1805.03662 for more details. """ - selection_regs: infra.SelectionRegisters - control_regs: infra.Registers = infra.Registers.build(control=1) + selection_regs: Tuple[infra.SelectionRegister, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, infra.SelectionRegister) else tuple(v) + ) + control_regs: Tuple[infra.Register, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, infra.Register) else tuple(v), + default=(infra.Register('control', 1),), + ) target_gate: cirq.Gate = cirq.Y @classmethod @@ -55,38 +60,39 @@ def make_on( **quregs: Union[Sequence[cirq.Qid], NDArray[cirq.Qid]], # type: ignore[type-var] ) -> cirq.Operation: """Helper constructor to automatically deduce selection_regs attribute.""" - return cls( - selection_regs=infra.SelectionRegisters( - [ - infra.SelectionRegister( - 'selection', len(quregs['selection']), len(quregs['target']) - ) - ] + return SelectedMajoranaFermionGate( + selection_regs=infra.SelectionRegister( + 'selection', len(quregs['selection']), len(quregs['target']) ), target_gate=target_gate, ).on_registers(**quregs) @cached_property - def control_registers(self) -> infra.Registers: + def control_registers(self) -> Tuple[infra.Register, ...]: return self.control_regs @cached_property - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: return self.selection_regs @cached_property - def target_registers(self) -> infra.Registers: - return infra.Registers.build(target=self.selection_regs.total_iteration_size) + def target_registers(self) -> Tuple[infra.Register, ...]: + total_iteration_size = np.product( + tuple(reg.iteration_length for reg in self.selection_registers) + ) + return (infra.Register('target', int(total_iteration_size)),) @cached_property - def extra_registers(self) -> infra.Registers: - return infra.Registers.build(accumulator=1) + def extra_registers(self) -> Tuple[infra.Register, ...]: + return (infra.Register('accumulator', 1),) def decompose_from_registers( self, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: quregs['accumulator'] = np.array(context.qubit_manager.qalloc(1)) - control = quregs[self.control_regs[0].name] if self.control_registers.total_bits() else [] + control = ( + quregs[self.control_regs[0].name] if infra.total_bits(self.control_registers) else [] + ) yield cirq.X(*quregs['accumulator']).controlled_by(*control) yield super(SelectedMajoranaFermionGate, self).decompose_from_registers( context=context, **quregs @@ -94,9 +100,9 @@ def decompose_from_registers( context.qubit_manager.qfree(quregs['accumulator']) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: - wire_symbols = ["@"] * self.control_registers.total_bits() - wire_symbols += ["In"] * self.selection_registers.total_bits() - wire_symbols += [f"Z{self.target_gate}"] * self.target_registers.total_bits() + wire_symbols = ["@"] * infra.total_bits(self.control_registers) + wire_symbols += ["In"] * infra.total_bits(self.selection_registers) + wire_symbols += [f"Z{self.target_gate}"] * infra.total_bits(self.target_registers) return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) def nth_operation( # type: ignore[override] @@ -107,8 +113,9 @@ def nth_operation( # type: ignore[override] accumulator: Sequence[cirq.Qid], **selection_indices: int, ) -> cirq.OP_TREE: + selection_shape = tuple(reg.iteration_length for reg in self.selection_regs) selection_idx = tuple(selection_indices[reg.name] for reg in self.selection_regs) - target_idx = self.selection_registers.to_flat_idx(*selection_idx) + target_idx = int(np.ravel_multi_index(selection_idx, selection_shape)) yield cirq.CNOT(control, *accumulator) yield self.target_gate(target[target_idx]).controlled_by(control) yield cirq.CZ(*accumulator, target[target_idx]) diff --git a/cirq-ft/cirq_ft/algos/selected_majorana_fermion_test.py b/cirq-ft/cirq_ft/algos/selected_majorana_fermion_test.py index cb674c51cd2..9367bcc607f 100644 --- a/cirq-ft/cirq_ft/algos/selected_majorana_fermion_test.py +++ b/cirq-ft/cirq_ft/algos/selected_majorana_fermion_test.py @@ -16,6 +16,7 @@ import cirq_ft import numpy as np import pytest +from cirq_ft import infra from cirq_ft.infra.bit_tools import iter_bits @@ -23,13 +24,11 @@ @pytest.mark.parametrize("target_gate", [cirq.X, cirq.Y]) def test_selected_majorana_fermion_gate(selection_bitsize, target_bitsize, target_gate): gate = cirq_ft.SelectedMajoranaFermionGate( - cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)] - ), + cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize), target_gate=target_gate, ) g = cirq_ft.testing.GateHelper(gate) - assert len(g.all_qubits) <= gate.registers.total_bits() + selection_bitsize + 1 + assert len(g.all_qubits) <= infra.total_bits(gate.registers) + selection_bitsize + 1 sim = cirq.Simulator(dtype=np.complex128) for n in range(target_bitsize): @@ -65,13 +64,11 @@ def test_selected_majorana_fermion_gate(selection_bitsize, target_bitsize, targe def test_selected_majorana_fermion_gate_diagram(): selection_bitsize, target_bitsize = 3, 5 gate = cirq_ft.SelectedMajoranaFermionGate( - cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)] - ), + cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize), target_gate=cirq.X, ) - circuit = cirq.Circuit(gate.on_registers(**gate.registers.get_named_qubits())) - qubits = list(q for v in gate.registers.get_named_qubits().values() for q in v) + circuit = cirq.Circuit(gate.on_registers(**infra.get_named_qubits(gate.registers))) + qubits = list(q for v in infra.get_named_qubits(gate.registers).values() for q in v) cirq.testing.assert_has_diagram( circuit, """ @@ -100,9 +97,7 @@ def test_selected_majorana_fermion_gate_diagram(): def test_selected_majorana_fermion_gate_decomposed_diagram(): selection_bitsize, target_bitsize = 2, 3 gate = cirq_ft.SelectedMajoranaFermionGate( - cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)] - ), + cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize), target_gate=cirq.X, ) greedy_mm = cirq_ft.GreedyQubitManager(prefix="_a", maximize_reuse=True) @@ -145,13 +140,11 @@ def test_selected_majorana_fermion_gate_decomposed_diagram(): def test_selected_majorana_fermion_gate_make_on(): selection_bitsize, target_bitsize = 3, 5 gate = cirq_ft.SelectedMajoranaFermionGate( - cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)] - ), + cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize), target_gate=cirq.X, ) - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) op2 = cirq_ft.SelectedMajoranaFermionGate.make_on( - target_gate=cirq.X, **gate.registers.get_named_qubits() + target_gate=cirq.X, **infra.get_named_qubits(gate.registers) ) assert op == op2 diff --git a/cirq-ft/cirq_ft/algos/state_preparation.py b/cirq-ft/cirq_ft/algos/state_preparation.py index cb7fd397b03..bec54f50a6b 100644 --- a/cirq-ft/cirq_ft/algos/state_preparation.py +++ b/cirq-ft/cirq_ft/algos/state_preparation.py @@ -20,7 +20,7 @@ largest absolute error that one can tolerate in the prepared amplitudes. """ -from typing import List +from typing import List, Tuple from numpy.typing import NDArray import attr @@ -83,7 +83,9 @@ class StatePreparationAliasSampling(select_and_prepare.PrepareOracle): (https://arxiv.org/abs/1805.03662). Babbush et. al. (2018). Section III.D. and Figure 11. """ - selection_registers: infra.SelectionRegisters + selection_registers: Tuple[infra.SelectionRegister, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, infra.SelectionRegister) else tuple(v) + ) alt: NDArray[np.int_] keep: NDArray[np.int_] mu: int @@ -106,9 +108,7 @@ def from_lcu_probs( ) N = len(lcu_probabilities) return StatePreparationAliasSampling( - selection_registers=infra.SelectionRegisters( - [infra.SelectionRegister('selection', (N - 1).bit_length(), N)] - ), + selection_registers=infra.SelectionRegister('selection', (N - 1).bit_length(), N), alt=np.array(alt), keep=np.array(keep), mu=mu, @@ -120,7 +120,7 @@ def sigma_mu_bitsize(self) -> int: @cached_property def alternates_bitsize(self) -> int: - return self.selection_registers.total_bits() + return infra.total_bits(self.selection_registers) @cached_property def keep_bitsize(self) -> int: @@ -128,15 +128,17 @@ def keep_bitsize(self) -> int: @cached_property def selection_bitsize(self) -> int: - return self.selection_registers.total_bits() + return infra.total_bits(self.selection_registers) @cached_property - def junk_registers(self) -> infra.Registers: - return infra.Registers.build( - sigma_mu=self.sigma_mu_bitsize, - alt=self.alternates_bitsize, - keep=self.keep_bitsize, - less_than_equal=1, + def junk_registers(self) -> Tuple[infra.Register, ...]: + return tuple( + infra.Registers.build( + sigma_mu=self.sigma_mu_bitsize, + alt=self.alternates_bitsize, + keep=self.keep_bitsize, + less_than_equal=1, + ) ) def _value_equality_values_(self): diff --git a/cirq-ft/cirq_ft/algos/swap_network.py b/cirq-ft/cirq_ft/algos/swap_network.py index 480bf0e4f3c..279ab33be38 100644 --- a/cirq-ft/cirq_ft/algos/swap_network.py +++ b/cirq-ft/cirq_ft/algos/swap_network.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Union +from typing import Sequence, Union, Tuple from numpy.typing import NDArray import attr @@ -145,14 +145,14 @@ def __attrs_post_init__(self): assert self.n_target_registers <= 2**self.selection_bitsize @cached_property - def selection_registers(self) -> infra.SelectionRegisters: - return infra.SelectionRegisters( - [infra.SelectionRegister('selection', self.selection_bitsize, self.n_target_registers)] + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: + return ( + infra.SelectionRegister('selection', self.selection_bitsize, self.n_target_registers), ) @cached_property - def target_registers(self) -> infra.Registers: - return infra.Registers.build(target=(self.n_target_registers, self.target_bitsize)) + def target_registers(self) -> Tuple[infra.Register, ...]: + return (infra.Register('target', (self.n_target_registers, self.target_bitsize)),) @cached_property def registers(self) -> infra.Registers: diff --git a/cirq-ft/cirq_ft/algos/swap_network_test.py b/cirq-ft/cirq_ft/algos/swap_network_test.py index c934ca4dec3..92cb5865a2a 100644 --- a/cirq-ft/cirq_ft/algos/swap_network_test.py +++ b/cirq-ft/cirq_ft/algos/swap_network_test.py @@ -18,6 +18,7 @@ import cirq_ft import numpy as np import pytest +from cirq_ft import infra from cirq_ft.infra.jupyter_tools import execute_notebook random.seed(12345) @@ -65,7 +66,7 @@ def test_swap_with_zero_gate(selection_bitsize, target_bitsize, n_target_registe def test_swap_with_zero_gate_diagram(): gate = cirq_ft.SwapWithZeroGate(3, 2, 4) q = cirq.LineQubit.range(cirq.num_qubits(gate)) - circuit = cirq.Circuit(gate.on_registers(**gate.registers.split_qubits(q))) + circuit = cirq.Circuit(gate.on_registers(**infra.split_qubits(gate.registers, q))) cirq.testing.assert_has_diagram( circuit, """ diff --git a/cirq-ft/cirq_ft/algos/unary_iteration.ipynb b/cirq-ft/cirq_ft/algos/unary_iteration.ipynb index 003941203c3..4eabc65a0af 100644 --- a/cirq-ft/cirq_ft/algos/unary_iteration.ipynb +++ b/cirq-ft/cirq_ft/algos/unary_iteration.ipynb @@ -471,7 +471,7 @@ "metadata": {}, "outputs": [], "source": [ - "from cirq_ft import Registers, SelectionRegister, SelectionRegisters, UnaryIterationGate\n", + "from cirq_ft import Register, Registers, SelectionRegister, UnaryIterationGate\n", "from cirq._compat import cached_property\n", "\n", "class ApplyXToLthQubit(UnaryIterationGate):\n", @@ -481,16 +481,16 @@ " self._control_bitsize = control_bitsize\n", "\n", " @cached_property\n", - " def control_registers(self) -> Registers:\n", - " return Registers.build(control=self._control_bitsize)\n", + " def control_registers(self) -> Tuple[Register, ...]:\n", + " return Register('control', self._control_bitsize),\n", "\n", " @cached_property\n", - " def selection_registers(self) -> SelectionRegisters:\n", - " return SelectionRegisters([SelectionRegister('selection', self._selection_bitsize, self._target_bitsize)])\n", + " def selection_registers(self) -> Tuple[SelectionRegister, ...]:\n", + " return SelectionRegister('selection', self._selection_bitsize, self._target_bitsize),\n", "\n", " @cached_property\n", - " def target_registers(self) -> Registers:\n", - " return Registers.build(target=self._target_bitsize)\n", + " def target_registers(self) -> Tuple[Register, ...]:\n", + " return Register('target', self._target_bitsize),\n", "\n", " def nth_operation(\n", " self, context, control: cirq.Qid, selection: int, target: Sequence[cirq.Qid]\n", diff --git a/cirq-ft/cirq_ft/algos/unary_iteration_gate.py b/cirq-ft/cirq_ft/algos/unary_iteration_gate.py index f37c2718b3b..d72ab7381ce 100644 --- a/cirq-ft/cirq_ft/algos/unary_iteration_gate.py +++ b/cirq-ft/cirq_ft/algos/unary_iteration_gate.py @@ -268,17 +268,17 @@ class UnaryIterationGate(infra.GateWithRegisters): @cached_property @abc.abstractmethod - def control_registers(self) -> infra.Registers: + def control_registers(self) -> Tuple[infra.Register, ...]: pass @cached_property @abc.abstractmethod - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: pass @cached_property @abc.abstractmethod - def target_registers(self) -> infra.Registers: + def target_registers(self) -> Tuple[infra.Register, ...]: pass @cached_property @@ -288,8 +288,8 @@ def registers(self) -> infra.Registers: ) @cached_property - def extra_registers(self) -> infra.Registers: - return infra.Registers([]) + def extra_registers(self) -> Tuple[infra.Register, ...]: + return () @abc.abstractmethod def nth_operation( @@ -325,7 +325,7 @@ def decompose_zero_selection( By default, if the selection register is empty, the decomposition will raise a `NotImplementedError`. The derived classes can override this method and specify a custom decomposition that should be used if the selection register is empty, - i.e. `self.selection_registers.total_bits() == 0`. + i.e. `infra.total_bits(self.selection_registers) == 0`. The derived classes should specify the following arguments as `**kwargs`: 1) Register names in `self.control_registers`: Each argument corresponds to a @@ -366,14 +366,14 @@ def _break_early(self, selection_index_prefix: Tuple[int, ...], l: int, r: int) def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: - if self.selection_registers.total_bits() == 0 or self._break_early( + if infra.total_bits(self.selection_registers) == 0 or self._break_early( (), 0, self.selection_registers[0].iteration_length ): return self.decompose_zero_selection(context=context, **quregs) num_loops = len(self.selection_registers) - target_regs = {k: v for k, v in quregs.items() if k in self.target_registers} - extra_regs = {k: v for k, v in quregs.items() if k in self.extra_registers} + target_regs = {reg.name: quregs[reg.name] for reg in self.target_registers} + extra_regs = {reg.name: quregs[reg.name] for reg in self.extra_registers} def unary_iteration_loops( nested_depth: int, @@ -430,7 +430,7 @@ def unary_iteration_loops( selection_reg_name_to_val.pop(self.selection_registers[nested_depth].name) yield ops - return unary_iteration_loops(0, {}, self.control_registers.merge_qubits(**quregs)) + return unary_iteration_loops(0, {}, infra.merge_qubits(self.control_registers, **quregs)) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: """Basic circuit diagram. @@ -438,7 +438,7 @@ def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.Circ Descendants are encouraged to override this with more descriptive circuit diagram information. """ - wire_symbols = ["@"] * self.control_registers.total_bits() - wire_symbols += ["In"] * self.selection_registers.total_bits() - wire_symbols += [self.__class__.__name__] * self.target_registers.total_bits() + wire_symbols = ["@"] * infra.total_bits(self.control_registers) + wire_symbols += ["In"] * infra.total_bits(self.selection_registers) + wire_symbols += [self.__class__.__name__] * infra.total_bits(self.target_registers) return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) diff --git a/cirq-ft/cirq_ft/algos/unary_iteration_gate_test.py b/cirq-ft/cirq_ft/algos/unary_iteration_gate_test.py index ffa50adc940..0c754fc5980 100644 --- a/cirq-ft/cirq_ft/algos/unary_iteration_gate_test.py +++ b/cirq-ft/cirq_ft/algos/unary_iteration_gate_test.py @@ -19,6 +19,7 @@ import cirq_ft import pytest from cirq._compat import cached_property +from cirq_ft import infra from cirq_ft.infra.bit_tools import iter_bits from cirq_ft.infra.jupyter_tools import execute_notebook @@ -30,18 +31,18 @@ def __init__(self, selection_bitsize: int, target_bitsize: int, control_bitsize: self._control_bitsize = control_bitsize @cached_property - def control_registers(self) -> cirq_ft.Registers: - return cirq_ft.Registers.build(control=self._control_bitsize) + def control_registers(self) -> Tuple[cirq_ft.Register, ...]: + return (cirq_ft.Register('control', self._control_bitsize),) @cached_property - def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('selection', self._selection_bitsize, self._target_bitsize)] + def selection_registers(self) -> Tuple[cirq_ft.SelectionRegister, ...]: + return ( + cirq_ft.SelectionRegister('selection', self._selection_bitsize, self._target_bitsize), ) @cached_property - def target_registers(self) -> cirq_ft.Registers: - return cirq_ft.Registers.build(target=self._target_bitsize) + def target_registers(self) -> Tuple[cirq_ft.Register, ...]: + return (cirq_ft.Register('target', self._target_bitsize),) def nth_operation( # type: ignore[override] self, @@ -83,24 +84,24 @@ def __init__(self, target_shape: Tuple[int, int, int]): self._target_shape = target_shape @cached_property - def control_registers(self) -> cirq_ft.Registers: - return cirq_ft.Registers([]) + def control_registers(self) -> Tuple[cirq_ft.Register, ...]: + return () @cached_property - def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters( - [ - cirq_ft.SelectionRegister( - 'ijk'[i], (self._target_shape[i] - 1).bit_length(), self._target_shape[i] - ) - for i in range(3) - ] + def selection_registers(self) -> Tuple[cirq_ft.SelectionRegister, ...]: + return tuple( + cirq_ft.SelectionRegister( + 'ijk'[i], (self._target_shape[i] - 1).bit_length(), self._target_shape[i] + ) + for i in range(3) ) @cached_property - def target_registers(self) -> cirq_ft.Registers: - return cirq_ft.Registers.build( - t1=self._target_shape[0], t2=self._target_shape[1], t3=self._target_shape[2] + def target_registers(self) -> Tuple[cirq_ft.Register, ...]: + return tuple( + cirq_ft.Registers.build( + t1=self._target_shape[0], t2=self._target_shape[1], t3=self._target_shape[2] + ) ) def nth_operation( # type: ignore[override] @@ -123,7 +124,8 @@ def test_multi_dimensional_unary_iteration_gate(target_shape: Tuple[int, int, in gate = ApplyXToIJKthQubit(target_shape) g = cirq_ft.testing.GateHelper(gate, context=cirq.DecompositionContext(greedy_mm)) assert ( - len(g.all_qubits) <= gate.registers.total_bits() + gate.selection_registers.total_bits() - 1 + len(g.all_qubits) + <= infra.total_bits(gate.registers) + infra.total_bits(gate.selection_registers) - 1 ) max_i, max_j, max_k = target_shape @@ -147,10 +149,11 @@ def test_multi_dimensional_unary_iteration_gate(target_shape: Tuple[int, int, in def test_unary_iteration_loop(): n_range, m_range = (3, 5), (6, 8) - selection_registers = cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('n', 3, 5), cirq_ft.SelectionRegister('m', 3, 8)] - ) - selection = selection_registers.get_named_qubits() + selection_registers = [ + cirq_ft.SelectionRegister('n', 3, 5), + cirq_ft.SelectionRegister('m', 3, 8), + ] + selection = infra.get_named_qubits(selection_registers) target = {(n, m): cirq.q(f't({n}, {m})') for n in range(*n_range) for m in range(*m_range)} qm = cirq_ft.GreedyQubitManager("ancilla", maximize_reuse=True) circuit = cirq.Circuit() diff --git a/cirq-ft/cirq_ft/infra/__init__.py b/cirq-ft/cirq_ft/infra/__init__.py index bfc99572bef..02f503110ca 100644 --- a/cirq-ft/cirq_ft/infra/__init__.py +++ b/cirq-ft/cirq_ft/infra/__init__.py @@ -17,9 +17,11 @@ Register, Registers, SelectionRegister, - SelectionRegisters, + total_bits, + split_qubits, + merge_qubits, + get_named_qubits, ) from cirq_ft.infra.qubit_management_transformers import map_clean_and_borrowable_qubits from cirq_ft.infra.qubit_manager import GreedyQubitManager from cirq_ft.infra.t_complexity_protocol import TComplexity, t_complexity -from cirq_ft.infra.type_convertors import to_tuple diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb b/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb index 70e4a6e59ba..6afb6d49d4f 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb +++ b/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb @@ -49,7 +49,7 @@ "metadata": {}, "outputs": [], "source": [ - "from cirq_ft import Register, Registers\n", + "from cirq_ft import Register, Registers, infra\n", "\n", "control_reg = Register(name='control', shape=(2,))\n", "target_reg = Register(name='target', shape=(3,))\n", @@ -163,7 +163,7 @@ "outputs": [], "source": [ "r = gate.registers\n", - "quregs = r.get_named_qubits()\n", + "quregs = infra.get_named_qubits(r)\n", "quregs" ] }, diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers.py b/cirq-ft/cirq_ft/infra/gate_with_registers.py index c65d3578902..7139b66a65a 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers.py +++ b/cirq-ft/cirq_ft/infra/gate_with_registers.py @@ -43,7 +43,7 @@ def all_idxs(self) -> Iterable[Tuple[int, ...]]: def total_bits(self) -> int: """The total number of bits in this register. - This is the product of bitsize and each of the dimensions in `shape`. + This is the product of each of the dimensions in `shape`. """ return int(np.product(self.shape)) @@ -51,6 +51,68 @@ def __repr__(self): return f'cirq_ft.Register(name="{self.name}", shape={self.shape})' +def total_bits(registers: Iterable[Register]) -> int: + """Sum of `reg.total_bits()` for each register `reg` in input `registers`.""" + + return sum(reg.total_bits() for reg in registers) + + +def split_qubits( + registers: Iterable[Register], qubits: Sequence[cirq.Qid] +) -> Dict[str, NDArray[cirq.Qid]]: # type: ignore[type-var] + """Splits the flat list of qubits into a dictionary of appropriately shaped qubit arrays.""" + + qubit_regs = {} + base = 0 + for reg in registers: + qubit_regs[reg.name] = np.array(qubits[base : base + reg.total_bits()]).reshape(reg.shape) + base += reg.total_bits() + return qubit_regs + + +def merge_qubits( + registers: Iterable[Register], + **qubit_regs: Union[cirq.Qid, Sequence[cirq.Qid], NDArray[cirq.Qid]], +) -> List[cirq.Qid]: + """Merges the dictionary of appropriately shaped qubit arrays into a flat list of qubits.""" + + ret: List[cirq.Qid] = [] + for reg in registers: + if reg.name not in qubit_regs: + raise ValueError(f"All qubit registers must be present. {reg.name} not in qubit_regs") + qubits = qubit_regs[reg.name] + qubits = np.array([qubits] if isinstance(qubits, cirq.Qid) else qubits) + if qubits.shape != reg.shape: + raise ValueError( + f'{reg.name} register must of shape {reg.shape} but is of shape {qubits.shape}' + ) + ret += qubits.flatten().tolist() + return ret + + +def get_named_qubits(registers: Iterable[Register]) -> Dict[str, NDArray[cirq.Qid]]: + """Returns a dictionary of appropriately shaped named qubit registers for input `registers`.""" + + def _qubit_array(reg: Register): + qubits = np.empty(reg.shape, dtype=object) + for ii in reg.all_idxs(): + qubits[ii] = cirq.NamedQubit(f'{reg.name}[{", ".join(str(i) for i in ii)}]') + return qubits + + def _qubits_for_reg(reg: Register): + if len(reg.shape) > 1: + return _qubit_array(reg) + + return np.array( + [cirq.NamedQubit(f"{reg.name}")] + if reg.total_bits() == 1 + else cirq.NamedQubit.range(reg.total_bits(), prefix=reg.name), + dtype=object, + ) + + return {reg.name: _qubits_for_reg(reg) for reg in registers} + + class Registers: """An ordered collection of `cirq_ft.Register`. @@ -67,9 +129,6 @@ def __init__(self, registers: Iterable[Register]): def __repr__(self): return f'cirq_ft.Registers({self._registers})' - def total_bits(self) -> int: - return sum(reg.total_bits() for reg in self) - @classmethod def build(cls, **registers: Union[int, Tuple[int, ...]]) -> 'Registers': return cls(Register(name=k, shape=v) for k, v in registers.items()) @@ -105,54 +164,6 @@ def __iter__(self): def __len__(self) -> int: return len(self._registers) - def split_qubits( - self, qubits: Sequence[cirq.Qid] - ) -> Dict[str, NDArray[cirq.Qid]]: # type: ignore[type-var] - qubit_regs = {} - base = 0 - for reg in self: - qubit_regs[reg.name] = np.array(qubits[base : base + reg.total_bits()]).reshape( - reg.shape - ) - base += reg.total_bits() - return qubit_regs - - def merge_qubits( - self, **qubit_regs: Union[cirq.Qid, Sequence[cirq.Qid], NDArray[cirq.Qid]] - ) -> List[cirq.Qid]: - ret: List[cirq.Qid] = [] - for reg in self: - assert ( - reg.name in qubit_regs - ), f"All qubit registers must be present. {reg.name} not in qubit_regs" - qubits = qubit_regs[reg.name] - qubits = np.array([qubits] if isinstance(qubits, cirq.Qid) else qubits) - assert ( - qubits.shape == reg.shape - ), f'{reg.name} register must of shape {reg.shape} but is of shape {qubits.shape}' - ret += qubits.flatten().tolist() - return ret - - def get_named_qubits(self) -> Dict[str, NDArray[cirq.Qid]]: - def _qubit_array(reg: Register): - qubits = np.empty(reg.shape, dtype=object) - for ii in reg.all_idxs(): - qubits[ii] = cirq.NamedQubit(f'{reg.name}[{", ".join(str(i) for i in ii)}]') - return qubits - - def _qubits_for_reg(reg: Register): - if len(reg.shape) > 1: - return _qubit_array(reg) - - return np.array( - [cirq.NamedQubit(f"{reg.name}")] - if reg.total_bits() == 1 - else cirq.NamedQubit.range(reg.total_bits(), prefix=reg.name), - dtype=object, - ) - - return {reg.name: _qubits_for_reg(reg) for reg in self._registers} - def __eq__(self, other) -> bool: return self._registers == other._registers @@ -166,113 +177,65 @@ class SelectionRegister(Register): `SelectionRegister` extends the `Register` class to store the iteration length corresponding to that register along with its size. - """ - - iteration_length: int = attr.field() - - @iteration_length.default - def _default_iteration_length(self): - return 2 ** self.shape[0] - - @iteration_length.validator - def validate_iteration_length(self, attribute, value): - if len(self.shape) != 1: - raise ValueError(f'Selection register {self.name} should be flat. Found {self.shape=}') - if not (0 <= value <= 2 ** self.shape[0]): - raise ValueError(f'iteration length must be in range [0, 2^{self.shape[0]}]') - - def __repr__(self) -> str: - return ( - f'cirq_ft.SelectionRegister(' - f'name="{self.name}", ' - f'shape={self.shape}, ' - f'iteration_length={self.iteration_length})' - ) - - -class SelectionRegisters(Registers): - """Registers used to represent SELECT registers for various LCU methods. LCU methods often make use of coherent for-loops via UnaryIteration, iterating over a range - of values stored as a superposition over the `SELECT` register. The `SelectionRegisters` class - is used to represent such SELECT registers. In particular, it provides two additional features - on top of the regular `Registers` class: - - - For each selection register, we store the iteration length corresponding to that register - along with its size. - - We provide a default way of "flattening out" a composite index represented by a tuple of - values stored in multiple input selection registers to a single integer that can be used - to index a flat target register. - """ - - def __init__(self, registers: Iterable[SelectionRegister]): - super().__init__(registers) - self.iteration_lengths = tuple([reg.iteration_length for reg in registers]) - self._suffix_prod = np.multiply.accumulate(self.iteration_lengths[::-1])[::-1] - self._suffix_prod = np.append(self._suffix_prod, [1]) + of values stored as a superposition over the `SELECT` register. Such (nested) coherent + for-loops can be represented using a `Tuple[SelectionRegister, ...]` where the i'th entry + stores the bitsize and iteration length of i'th nested for-loop. - def to_flat_idx(self, *selection_vals: int) -> int: - """Flattens a composite index represented by a Tuple[int, ...] to a single output integer. - - For example: + One useful feature when processing such nested for-loops is to flatten out a composite index, + represented by a tuple of indices (i, j, ...), one for each selection register into a single + integer that can be used to index a flat target register. An example of such a mapping + function is described in Eq.45 of https://arxiv.org/abs/1805.03662. A general version of this + mapping function can be implemented using `numpy.ravel_multi_index` and `numpy.unravel_index`. + For example: 1) We can flatten a 2D for-loop as follows + >>> import numpy as np >>> N, M = 10, 20 >>> flat_indices = set() >>> for x in range(N): ... for y in range(M): ... flat_idx = x * M + y + ... assert np.ravel_multi_index((x, y), (N, M)) == flat_idx + ... assert np.unravel_index(flat_idx, (N, M)) == (x, y) ... flat_indices.add(flat_idx) >>> assert len(flat_indices) == N * M 2) Similarly, we can flatten a 3D for-loop as follows + >>> import numpy as np >>> N, M, L = 10, 20, 30 >>> flat_indices = set() >>> for x in range(N): ... for y in range(M): ... for z in range(L): ... flat_idx = x * M * L + y * L + z + ... assert np.ravel_multi_index((x, y, z), (N, M, L)) == flat_idx + ... assert np.unravel_index(flat_idx, (N, M, L)) == (x, y, z) ... flat_indices.add(flat_idx) >>> assert len(flat_indices) == N * M * L + """ - This is a general version of the mapping function described in Eq.45 of - https://arxiv.org/abs/1805.03662 - """ - assert len(selection_vals) == len(self) - return sum(v * self._suffix_prod[i + 1] for i, v in enumerate(selection_vals)) - - @property - def total_iteration_size(self) -> int: - return int(np.product(self.iteration_lengths)) - - @classmethod - def build(cls, **registers: Union[int, Tuple[int, ...]]) -> 'SelectionRegisters': - return cls(SelectionRegister(name=k, shape=v) for k, v in registers.items()) - - @overload - def __getitem__(self, key: int) -> SelectionRegister: - pass - - @overload - def __getitem__(self, key: str) -> SelectionRegister: - pass + iteration_length: int = attr.field() - @overload - def __getitem__(self, key: slice) -> 'SelectionRegisters': - pass + @iteration_length.default + def _default_iteration_length(self): + return 2 ** self.shape[0] - def __getitem__(self, key): - if isinstance(key, slice): - return SelectionRegisters(self._registers[key]) - elif isinstance(key, int): - return self._registers[key] - elif isinstance(key, str): - return self._register_dict[key] - else: - raise IndexError(f"key {key} must be of the type str/int/slice.") + @iteration_length.validator + def validate_iteration_length(self, attribute, value): + if len(self.shape) != 1: + raise ValueError(f'Selection register {self.name} should be flat. Found {self.shape=}') + if not (0 <= value <= 2 ** self.shape[0]): + raise ValueError(f'iteration length must be in range [0, 2^{self.shape[0]}]') def __repr__(self) -> str: - return f'cirq_ft.SelectionRegisters({self._registers})' + return ( + f'cirq_ft.SelectionRegister(' + f'name="{self.name}", ' + f'shape={self.shape}, ' + f'iteration_length={self.iteration_length})' + ) class GateWithRegisters(cirq.Gate, metaclass=abc.ABCMeta): @@ -329,7 +292,7 @@ def registers(self) -> Registers: ... def _num_qubits_(self) -> int: - return self.registers.total_bits() + return total_bits(self.registers) def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] @@ -339,7 +302,7 @@ def decompose_from_registers( def _decompose_with_context_( self, qubits: Sequence[cirq.Qid], context: Optional[cirq.DecompositionContext] = None ) -> cirq.OP_TREE: - qubit_regs = self.registers.split_qubits(qubits) + qubit_regs = split_qubits(self.registers, qubits) if context is None: context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager()) return self.decompose_from_registers(context=context, **qubit_regs) @@ -350,7 +313,7 @@ def _decompose_(self, qubits: Sequence[cirq.Qid]) -> cirq.OP_TREE: def on_registers( self, **qubit_regs: Union[cirq.Qid, Sequence[cirq.Qid], NDArray[cirq.Qid]] ) -> cirq.Operation: - return self.on(*self.registers.merge_qubits(**qubit_regs)) + return self.on(*merge_qubits(self.registers, **qubit_regs)) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: """Default diagram info that uses register names to name the boxes in multi-qubit gates. diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers_test.py b/cirq-ft/cirq_ft/infra/gate_with_registers_test.py index a3442eb0554..7560cb7a357 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers_test.py +++ b/cirq-ft/cirq_ft/infra/gate_with_registers_test.py @@ -17,6 +17,7 @@ import numpy as np import pytest from cirq_ft.infra.jupyter_tools import execute_notebook +from cirq_ft.infra import split_qubits, merge_qubits, get_named_qubits def test_register(): @@ -50,13 +51,20 @@ def test_registers(): assert list(regs) == [r1, r2, r3] qubits = cirq.LineQubit.range(8) - qregs = regs.split_qubits(qubits) + qregs = split_qubits(regs, qubits) assert qregs["r1"].tolist() == cirq.LineQubit.range(5) assert qregs["r2"].tolist() == cirq.LineQubit.range(5, 5 + 2) assert qregs["r3"].tolist() == [cirq.LineQubit(7)] qubits = qubits[::-1] - merged_qregs = regs.merge_qubits(r1=qubits[:5], r2=qubits[5:7], r3=qubits[-1]) + + with pytest.raises(ValueError, match="qubit registers must be present"): + _ = merge_qubits(regs, r1=qubits[:5], r2=qubits[5:7], r4=qubits[-1]) + + with pytest.raises(ValueError, match="register must of shape"): + _ = merge_qubits(regs, r1=qubits[:4], r2=qubits[5:7], r3=qubits[-1]) + + merged_qregs = merge_qubits(regs, r1=qubits[:5], r2=qubits[5:7], r3=qubits[-1]) assert merged_qregs == qubits expected_named_qubits = { @@ -65,7 +73,7 @@ def test_registers(): "r3": [cirq.NamedQubit("r3")], } - named_qregs = regs.get_named_qubits() + named_qregs = get_named_qubits(regs) for reg_name in expected_named_qubits: assert np.array_equal(named_qregs[reg_name], expected_named_qubits[reg_name]) @@ -73,7 +81,7 @@ def test_registers(): # initial registers. for reg_order in [[r1, r2, r3], [r2, r3, r1]]: flat_named_qubits = [ - q for v in cirq_ft.Registers(reg_order).get_named_qubits().values() for q in v + q for v in get_named_qubits(cirq_ft.Registers(reg_order)).values() for q in v ] expected_qubits = [q for r in reg_order for q in expected_named_qubits[r.name]] assert flat_named_qubits == expected_qubits @@ -81,15 +89,13 @@ def test_registers(): @pytest.mark.parametrize('n, N, m, M', [(4, 10, 5, 19), (4, 16, 5, 32)]) def test_selection_registers_indexing(n, N, m, M): - reg = cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('x', n, N), cirq_ft.SelectionRegister('y', m, M)] - ) - assert reg.iteration_lengths == (N, M) - for x in range(N): - for y in range(M): - assert reg.to_flat_idx(x, y) == x * M + y + regs = [cirq_ft.SelectionRegister('x', n, N), cirq_ft.SelectionRegister('y', m, M)] + for x in range(regs[0].iteration_length): + for y in range(regs[1].iteration_length): + assert np.ravel_multi_index((x, y), (N, M)) == x * M + y + assert np.unravel_index(x * M + y, (N, M)) == (x, y) - assert reg.total_iteration_size == N * M + assert np.product(tuple(reg.iteration_length for reg in regs)) == N * M def test_selection_registers_consistent(): @@ -99,7 +105,7 @@ def test_selection_registers_consistent(): with pytest.raises(ValueError, match="should be flat"): _ = cirq_ft.SelectionRegister('a', (3, 5), 5) - selection_reg = cirq_ft.SelectionRegisters( + selection_reg = cirq_ft.Registers( [ cirq_ft.SelectionRegister('n', shape=3, iteration_length=5), cirq_ft.SelectionRegister('m', shape=4, iteration_length=12), @@ -108,7 +114,7 @@ def test_selection_registers_consistent(): assert selection_reg[0] == cirq_ft.SelectionRegister('n', 3, 5) assert selection_reg['n'] == cirq_ft.SelectionRegister('n', 3, 5) assert selection_reg[1] == cirq_ft.SelectionRegister('m', 4, 12) - assert selection_reg[:1] == cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('n', 3, 5)]) + assert selection_reg[:1] == cirq_ft.Registers([cirq_ft.SelectionRegister('n', 3, 5)]) def test_registers_getitem_raises(): @@ -116,9 +122,7 @@ def test_registers_getitem_raises(): with pytest.raises(IndexError, match="must be of the type"): _ = g[2.5] - selection_reg = cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('n', shape=3, iteration_length=5)] - ) + selection_reg = cirq_ft.Registers([cirq_ft.SelectionRegister('n', shape=3, iteration_length=5)]) with pytest.raises(IndexError, match='must be of the type'): _ = selection_reg[2.5] diff --git a/cirq-ft/cirq_ft/infra/jupyter_tools.py b/cirq-ft/cirq_ft/infra/jupyter_tools.py index 148c29baca8..a9ae4817ef7 100644 --- a/cirq-ft/cirq_ft/infra/jupyter_tools.py +++ b/cirq-ft/cirq_ft/infra/jupyter_tools.py @@ -21,7 +21,7 @@ import IPython.display import ipywidgets import nbformat -from cirq_ft.infra import gate_with_registers, t_complexity_protocol +from cirq_ft.infra import gate_with_registers, t_complexity_protocol, get_named_qubits, merge_qubits from nbconvert.preprocessors import ExecutePreprocessor @@ -83,7 +83,7 @@ def svg_circuit( if registers is not None: qubit_order = cirq.QubitOrder.explicit( - registers.merge_qubits(**registers.get_named_qubits()), fallback=cirq.QubitOrder.DEFAULT + merge_qubits(registers, **get_named_qubits(registers)), fallback=cirq.QubitOrder.DEFAULT ) else: qubit_order = cirq.QubitOrder.DEFAULT diff --git a/cirq-ft/cirq_ft/infra/t_complexity.ipynb b/cirq-ft/cirq_ft/infra/t_complexity.ipynb index 3986abfaa5a..3a4c7c4596a 100644 --- a/cirq-ft/cirq_ft/infra/t_complexity.ipynb +++ b/cirq-ft/cirq_ft/infra/t_complexity.ipynb @@ -40,7 +40,7 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", - "from cirq_ft import And, t_complexity" + "from cirq_ft import And, t_complexity, infra" ] }, { @@ -61,7 +61,7 @@ "# And of two qubits\n", "gate = And() # create an And gate\n", "# create an operation\n", - "operation = gate.on_registers(**gate.registers.get_named_qubits()) \n", + "operation = gate.on_registers(**infra.get_named_qubits(gate.registers))\n", "# this operation doesn't directly support TComplexity but it's decomposable and its components are simple.\n", "print(t_complexity(operation))" ] @@ -82,7 +82,7 @@ "outputs": [], "source": [ "gate = And() ** -1 # adjoint of And\n", - "operation = gate.on_registers(**gate.registers.get_named_qubits()) \n", + "operation = gate.on_registers(**infra.get_named_qubits(gate.registers))\n", "# the deomposition is H, measure, CZ, and Reset\n", "print(t_complexity(operation))" ] @@ -104,7 +104,7 @@ "source": [ "n = 5\n", "gate = And((1, )*n)\n", - "operation = gate.on_registers(**gate.registers.get_named_qubits()) \n", + "operation = gate.on_registers(**infra.get_named_qubits(gate.registers))\n", "print(t_complexity(operation))" ] }, @@ -122,7 +122,7 @@ " for n in range(2, n_max + 2):\n", " n_controls.append(n)\n", " gate = And(cv=(1, )*n)\n", - " op = gate.on_registers(**gate.registers.get_named_qubits()) \n", + " op = gate.on_registers(**infra.get_named_qubits(gate.registers))\n", " c = t_complexity(op)\n", " t_count.append(c.t)\n", " return n_controls, t_count" @@ -171,4 +171,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/cirq-ft/cirq_ft/infra/t_complexity_protocol_test.py b/cirq-ft/cirq_ft/infra/t_complexity_protocol_test.py index 851e5907119..f28f3fc5e6a 100644 --- a/cirq-ft/cirq_ft/infra/t_complexity_protocol_test.py +++ b/cirq-ft/cirq_ft/infra/t_complexity_protocol_test.py @@ -15,6 +15,7 @@ import cirq import cirq_ft import pytest +from cirq_ft import infra from cirq_ft.infra.jupyter_tools import execute_notebook @@ -108,11 +109,11 @@ def test_operations(): assert cirq_ft.t_complexity(cirq.T(q)) == cirq_ft.TComplexity(t=1) gate = cirq_ft.And() - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) assert cirq_ft.t_complexity(op) == cirq_ft.TComplexity(t=4, clifford=9) gate = cirq_ft.And() ** -1 - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) assert cirq_ft.t_complexity(op) == cirq_ft.TComplexity(clifford=4) diff --git a/cirq-ft/cirq_ft/infra/testing.py b/cirq-ft/cirq_ft/infra/testing.py index 6ceb21d5c2a..31802d5300d 100644 --- a/cirq-ft/cirq_ft/infra/testing.py +++ b/cirq-ft/cirq_ft/infra/testing.py @@ -18,7 +18,7 @@ import cirq import numpy as np from cirq._compat import cached_property -from cirq_ft.infra import gate_with_registers, t_complexity_protocol +from cirq_ft.infra import gate_with_registers, t_complexity_protocol, merge_qubits, get_named_qubits from cirq_ft.infra.decompose_protocol import _decompose_once_considering_known_decomposition @@ -44,12 +44,12 @@ def r(self) -> gate_with_registers.Registers: @cached_property def quregs(self) -> Dict[str, NDArray[cirq.Qid]]: # type: ignore[type-var] """A dictionary of named qubits appropriate for the registers for the gate.""" - return self.r.get_named_qubits() + return get_named_qubits(self.r) @cached_property def all_qubits(self) -> List[cirq.Qid]: """All qubits in Register order.""" - merged_qubits = self.r.merge_qubits(**self.quregs) + merged_qubits = merge_qubits(self.r, **self.quregs) decomposed_qubits = self.decomposed_circuit.all_qubits() return merged_qubits + sorted(decomposed_qubits - frozenset(merged_qubits)) diff --git a/cirq-ft/cirq_ft/infra/type_convertors.py b/cirq-ft/cirq_ft/infra/type_convertors.py deleted file mode 100644 index fa182596df7..00000000000 --- a/cirq-ft/cirq_ft/infra/type_convertors.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2023 The Cirq Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Sequence, Tuple, Union - - -def to_tuple(x: Union[int, Sequence[int]]) -> Tuple[int, ...]: - """Mypy type-safe convertor to be used in an attrs field.""" - return (x,) if isinstance(x, int) else tuple(x) diff --git a/cirq-ft/cirq_ft/infra/type_convertors_test.py b/cirq-ft/cirq_ft/infra/type_convertors_test.py deleted file mode 100644 index ce1f780e6af..00000000000 --- a/cirq-ft/cirq_ft/infra/type_convertors_test.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2023 The Cirq Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import cirq_ft - - -def test_to_tuple(): - assert cirq_ft.infra.to_tuple([1, 2]) == (1, 2) - assert cirq_ft.infra.to_tuple((1, 2)) == (1, 2) - assert cirq_ft.infra.to_tuple(1) == (1,)