@@ -43,7 +43,7 @@ def all_idxs(self) -> Iterable[Tuple[int, ...]]:
4343 def total_bits (self ) -> int :
4444 """The total number of bits in this register.
4545
46- This is the product of bitsize and each of the dimensions in `shape`.
46+ This is the product of each of the dimensions in `shape`.
4747 """
4848 return int (np .product (self .shape ))
4949
@@ -52,12 +52,16 @@ def __repr__(self):
5252
5353
5454def total_bits (registers : Iterable [Register ]) -> int :
55+ """Sum of `reg.total_bits()` for each register `reg` in input `registers`."""
56+
5557 return sum (reg .total_bits () for reg in registers )
5658
5759
5860def split_qubits (
5961 registers : Iterable [Register ], qubits : Sequence [cirq .Qid ]
6062) -> Dict [str , NDArray [cirq .Qid ]]: # type: ignore[type-var]
63+ """Splits the flat list of qubits into a dictionary of appropriately shaped qubit arrays."""
64+
6165 qubit_regs = {}
6266 base = 0
6367 for reg in registers :
@@ -70,21 +74,25 @@ def merge_qubits(
7074 registers : Iterable [Register ],
7175 ** qubit_regs : Union [cirq .Qid , Sequence [cirq .Qid ], NDArray [cirq .Qid ]],
7276) -> List [cirq .Qid ]:
77+ """Merges the dictionary of appropriately shaped qubit arrays into a flat list of qubits."""
78+
7379 ret : List [cirq .Qid ] = []
7480 for reg in registers :
75- assert (
76- reg .name in qubit_regs
77- ), f"All qubit registers must be present. { reg .name } not in qubit_regs"
81+ if reg .name not in qubit_regs :
82+ raise ValueError (f"All qubit registers must be present. { reg .name } not in qubit_regs" )
7883 qubits = qubit_regs [reg .name ]
7984 qubits = np .array ([qubits ] if isinstance (qubits , cirq .Qid ) else qubits )
80- assert (
81- qubits .shape == reg .shape
82- ), f'{ reg .name } register must of shape { reg .shape } but is of shape { qubits .shape } '
85+ if qubits .shape != reg .shape :
86+ raise ValueError (
87+ f'{ reg .name } register must of shape { reg .shape } but is of shape { qubits .shape } '
88+ )
8389 ret += qubits .flatten ().tolist ()
8490 return ret
8591
8692
8793def get_named_qubits (registers : Iterable [Register ]) -> Dict [str , NDArray [cirq .Qid ]]:
94+ """Returns a dictionary of appropriately shaped named qubit registers for input `registers`."""
95+
8896 def _qubit_array (reg : Register ):
8997 qubits = np .empty (reg .shape , dtype = object )
9098 for ii in reg .all_idxs ():
@@ -169,6 +177,43 @@ class SelectionRegister(Register):
169177
170178 `SelectionRegister` extends the `Register` class to store the iteration length
171179 corresponding to that register along with its size.
180+
181+ LCU methods often make use of coherent for-loops via UnaryIteration, iterating over a range
182+ of values stored as a superposition over the `SELECT` register. Such (nested) coherent
183+ for-loops can be represented using a `Tuple[SelectionRegister, ...]` where the i'th entry
184+ stores the bitsize and iteration length of i'th nested for-loop.
185+
186+ One useful feature when processing such nested for-loops is to flatten out a composite index,
187+ represented by a tuple of indices (i, j, ...), one for each selection register into a single
188+ integer that can be used to index a flat target register. An example of such a mapping
189+ function is described in Eq.45 of https://arxiv.org/abs/1805.03662. A general version of this
190+ mapping function can be implemented using `numpy.ravel_multi_index` and `numpy.unravel_index`.
191+
192+ For example:
193+ 1) We can flatten a 2D for-loop as follows
194+ >>> import numpy as np
195+ >>> N, M = 10, 20
196+ >>> flat_indices = set()
197+ >>> for x in range(N):
198+ ... for y in range(M):
199+ ... flat_idx = x * M + y
200+ ... assert np.ravel_multi_index((x, y), (N, M)) == flat_idx
201+ ... assert np.unravel_index(flat_idx, (N, M)) == (x, y)
202+ ... flat_indices.add(flat_idx)
203+ >>> assert len(flat_indices) == N * M
204+
205+ 2) Similarly, we can flatten a 3D for-loop as follows
206+ >>> import numpy as np
207+ >>> N, M, L = 10, 20, 30
208+ >>> flat_indices = set()
209+ >>> for x in range(N):
210+ ... for y in range(M):
211+ ... for z in range(L):
212+ ... flat_idx = x * M * L + y * L + z
213+ ... assert np.ravel_multi_index((x, y, z), (N, M, L)) == flat_idx
214+ ... assert np.unravel_index(flat_idx, (N, M, L)) == (x, y, z)
215+ ... flat_indices.add(flat_idx)
216+ >>> assert len(flat_indices) == N * M * L
172217 """
173218
174219 iteration_length : int = attr .field ()
0 commit comments