Skip to content

Commit 71f5466

Browse files
committed
Address Matt's comments
1 parent 041ff78 commit 71f5466

File tree

10 files changed

+72
-54
lines changed

10 files changed

+72
-54
lines changed

cirq-ft/cirq_ft/algos/and_gate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ class And(infra.GateWithRegisters):
4949
ValueError: If number of control values (i.e. `len(self.cv)`) is less than 2.
5050
"""
5151

52-
cv: Tuple[int, ...] = attr.field(default=(1, 1), converter=infra.to_tuple)
52+
cv: Tuple[int, ...] = attr.field(
53+
default=(1, 1), converter=lambda v: (v,) if isinstance(v, int) else tuple(v)
54+
)
5355
adjoint: bool = False
5456

5557
@cv.validator

cirq-ft/cirq_ft/algos/arithmetic_gates.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,9 @@ class AddMod(cirq.ArithmeticGate):
611611
bitsize: int
612612
mod: int = attr.field()
613613
add_val: int = 1
614-
cv: Tuple[int, ...] = attr.field(converter=infra.to_tuple, default=())
614+
cv: Tuple[int, ...] = attr.field(
615+
converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=()
616+
)
615617

616618
@mod.validator
617619
def _validate_mod(self, attribute, value):

cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ class MeanEstimationOperator(infra.GateWithRegisters):
8080
"""
8181

8282
code: CodeForRandomVariable
83-
cv: Tuple[int, ...] = attr.field(converter=infra.to_tuple, default=())
83+
cv: Tuple[int, ...] = attr.field(
84+
converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=()
85+
)
8486
power: int = 1
8587
arctan_bitsize: int = 32
8688

cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class MultiControlPauli(infra.GateWithRegisters):
7373
(https://algassert.com/circuits/2015/06/05/Constructing-Large-Controlled-Nots.html)
7474
"""
7575

76-
cvs: Tuple[int, ...] = attr.field(converter=infra.to_tuple)
76+
cvs: Tuple[int, ...] = attr.field(converter=lambda v: (v,) if isinstance(v, int) else tuple(v))
7777
target_gate: cirq.Pauli = cirq.X
7878

7979
@cached_property

cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ class PrepareUniformSuperposition(infra.GateWithRegisters):
4646
"""
4747

4848
n: int
49-
cv: Tuple[int, ...] = attr.field(converter=infra.to_tuple, default=())
49+
cv: Tuple[int, ...] = attr.field(
50+
converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=()
51+
)
5052

5153
@cached_property
5254
def registers(self) -> infra.Registers:

cirq-ft/cirq_ft/infra/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,3 @@
2525
from cirq_ft.infra.qubit_management_transformers import map_clean_and_borrowable_qubits
2626
from cirq_ft.infra.qubit_manager import GreedyQubitManager
2727
from cirq_ft.infra.t_complexity_protocol import TComplexity, t_complexity
28-
from cirq_ft.infra.type_convertors import to_tuple

cirq-ft/cirq_ft/infra/gate_with_registers.py

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def all_idxs(self) -> Iterable[Tuple[int, ...]]:
4343
def total_bits(self) -> int:
4444
"""The total number of bits in this register.
4545
46-
This is the product of bitsize and each of the dimensions in `shape`.
46+
This is the product of each of the dimensions in `shape`.
4747
"""
4848
return int(np.product(self.shape))
4949

@@ -52,12 +52,16 @@ def __repr__(self):
5252

5353

5454
def total_bits(registers: Iterable[Register]) -> int:
55+
"""Sum of `reg.total_bits()` for each register `reg` in input `registers`."""
56+
5557
return sum(reg.total_bits() for reg in registers)
5658

5759

5860
def split_qubits(
5961
registers: Iterable[Register], qubits: Sequence[cirq.Qid]
6062
) -> Dict[str, NDArray[cirq.Qid]]: # type: ignore[type-var]
63+
"""Splits the flat list of qubits into a dictionary of appropriately shaped qubit arrays."""
64+
6165
qubit_regs = {}
6266
base = 0
6367
for reg in registers:
@@ -70,21 +74,25 @@ def merge_qubits(
7074
registers: Iterable[Register],
7175
**qubit_regs: Union[cirq.Qid, Sequence[cirq.Qid], NDArray[cirq.Qid]],
7276
) -> List[cirq.Qid]:
77+
"""Merges the dictionary of appropriately shaped qubit arrays into a flat list of qubits."""
78+
7379
ret: List[cirq.Qid] = []
7480
for reg in registers:
75-
assert (
76-
reg.name in qubit_regs
77-
), f"All qubit registers must be present. {reg.name} not in qubit_regs"
81+
if reg.name not in qubit_regs:
82+
raise ValueError(f"All qubit registers must be present. {reg.name} not in qubit_regs")
7883
qubits = qubit_regs[reg.name]
7984
qubits = np.array([qubits] if isinstance(qubits, cirq.Qid) else qubits)
80-
assert (
81-
qubits.shape == reg.shape
82-
), f'{reg.name} register must of shape {reg.shape} but is of shape {qubits.shape}'
85+
if qubits.shape != reg.shape:
86+
raise ValueError(
87+
f'{reg.name} register must of shape {reg.shape} but is of shape {qubits.shape}'
88+
)
8389
ret += qubits.flatten().tolist()
8490
return ret
8591

8692

8793
def get_named_qubits(registers: Iterable[Register]) -> Dict[str, NDArray[cirq.Qid]]:
94+
"""Returns a dictionary of appropriately shaped named qubit registers for input `registers`."""
95+
8896
def _qubit_array(reg: Register):
8997
qubits = np.empty(reg.shape, dtype=object)
9098
for ii in reg.all_idxs():
@@ -169,6 +177,43 @@ class SelectionRegister(Register):
169177
170178
`SelectionRegister` extends the `Register` class to store the iteration length
171179
corresponding to that register along with its size.
180+
181+
LCU methods often make use of coherent for-loops via UnaryIteration, iterating over a range
182+
of values stored as a superposition over the `SELECT` register. Such (nested) coherent
183+
for-loops can be represented using a `Tuple[SelectionRegister, ...]` where the i'th entry
184+
stores the bitsize and iteration length of i'th nested for-loop.
185+
186+
One useful feature when processing such nested for-loops is to flatten out a composite index,
187+
represented by a tuple of indices (i, j, ...), one for each selection register into a single
188+
integer that can be used to index a flat target register. An example of such a mapping
189+
function is described in Eq.45 of https://arxiv.org/abs/1805.03662. A general version of this
190+
mapping function can be implemented using `numpy.ravel_multi_index` and `numpy.unravel_index`.
191+
192+
For example:
193+
1) We can flatten a 2D for-loop as follows
194+
>>> import numpy as np
195+
>>> N, M = 10, 20
196+
>>> flat_indices = set()
197+
>>> for x in range(N):
198+
... for y in range(M):
199+
... flat_idx = x * M + y
200+
... assert np.ravel_multi_index((x, y), (N, M)) == flat_idx
201+
... assert np.unravel_index(flat_idx, (N, M)) == (x, y)
202+
... flat_indices.add(flat_idx)
203+
>>> assert len(flat_indices) == N * M
204+
205+
2) Similarly, we can flatten a 3D for-loop as follows
206+
>>> import numpy as np
207+
>>> N, M, L = 10, 20, 30
208+
>>> flat_indices = set()
209+
>>> for x in range(N):
210+
... for y in range(M):
211+
... for z in range(L):
212+
... flat_idx = x * M * L + y * L + z
213+
... assert np.ravel_multi_index((x, y, z), (N, M, L)) == flat_idx
214+
... assert np.unravel_index(flat_idx, (N, M, L)) == (x, y, z)
215+
... flat_indices.add(flat_idx)
216+
>>> assert len(flat_indices) == N * M * L
172217
"""
173218

174219
iteration_length: int = attr.field()

cirq-ft/cirq_ft/infra/gate_with_registers_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ def test_registers():
5757
assert qregs["r3"].tolist() == [cirq.LineQubit(7)]
5858

5959
qubits = qubits[::-1]
60+
61+
with pytest.raises(ValueError, match="qubit registers must be present"):
62+
_ = merge_qubits(regs, r1=qubits[:5], r2=qubits[5:7], r4=qubits[-1])
63+
64+
with pytest.raises(ValueError, match="register must of shape"):
65+
_ = merge_qubits(regs, r1=qubits[:4], r2=qubits[5:7], r3=qubits[-1])
66+
6067
merged_qregs = merge_qubits(regs, r1=qubits[:5], r2=qubits[5:7], r3=qubits[-1])
6168
assert merged_qregs == qubits
6269

cirq-ft/cirq_ft/infra/type_convertors.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

cirq-ft/cirq_ft/infra/type_convertors_test.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

0 commit comments

Comments
 (0)