Skip to content

Commit e014f4d

Browse files
authored
Small fixes of linalg.decompositions (#7128)
* Small fixes of linalg.decomposition. * fix lint
1 parent 540445a commit e014f4d

File tree

2 files changed

+24
-44
lines changed

2 files changed

+24
-44
lines changed

cirq-core/cirq/linalg/decompositions.py

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
Iterable,
2525
List,
2626
Optional,
27-
Set,
2827
Tuple,
2928
TYPE_CHECKING,
3029
TypeVar,
@@ -106,29 +105,6 @@ def deconstruct_single_qubit_matrix_into_angles(mat: np.ndarray) -> Tuple[float,
106105
return right_phase + diagonal_phase, rotation * 2, bottom_phase
107106

108107

109-
def _group_similar(items: List[T], comparer: Callable[[T, T], bool]) -> List[List[T]]:
110-
"""Combines similar items into groups.
111-
112-
Args:
113-
items: The list of items to group.
114-
comparer: Determines if two items are similar.
115-
116-
Returns:
117-
A list of groups of items.
118-
"""
119-
groups: List[List[T]] = []
120-
used: Set[int] = set()
121-
for i in range(len(items)):
122-
if i not in used:
123-
group = [items[i]]
124-
for j in range(i + 1, len(items)):
125-
if j not in used and comparer(items[i], items[j]):
126-
used.add(j)
127-
group.append(items[j])
128-
groups.append(group)
129-
return groups
130-
131-
132108
def unitary_eig(
133109
matrix: np.ndarray, check_preconditions: bool = True, atol: float = 1e-8
134110
) -> Tuple[np.ndarray, np.ndarray]:
@@ -175,7 +151,6 @@ def map_eigenvalues(
175151
Args:
176152
matrix: The matrix to modify with the function.
177153
func: The function to apply to the eigenvalues of the matrix.
178-
rtol: Relative threshold used when separating eigenspaces.
179154
atol: Absolute threshold used when separating eigenspaces.
180155
181156
Returns:
@@ -191,15 +166,18 @@ def map_eigenvalues(
191166
return total
192167

193168

194-
def kron_factor_4x4_to_2x2s(matrix: np.ndarray) -> Tuple[complex, np.ndarray, np.ndarray]:
169+
def kron_factor_4x4_to_2x2s(
170+
matrix: np.ndarray, rtol=1e-5, atol=1e-8
171+
) -> Tuple[complex, np.ndarray, np.ndarray]:
195172
"""Splits a 4x4 matrix U = kron(A, B) into A, B, and a global factor.
196173
197174
Requires the matrix to be the kronecker product of two 2x2 unitaries.
198175
Requires the matrix to have a non-zero determinant.
199-
Giving an incorrect matrix will cause garbage output.
200176
201177
Args:
202178
matrix: The 4x4 unitary matrix to factor.
179+
rtol: Per-matrix-entry relative tolerance on equality.
180+
atol: Per-matrix-entry absolute tolerance on equality.
203181
204182
Returns:
205183
A scalar factor and a pair of 2x2 unit-determinant matrices. The
@@ -232,6 +210,9 @@ def kron_factor_4x4_to_2x2s(matrix: np.ndarray) -> Tuple[complex, np.ndarray, np
232210
f1 *= -1
233211
g = -g
234212

213+
if not np.allclose(matrix, g * np.kron(f1, f2), rtol=rtol, atol=atol):
214+
raise ValueError("Invalid 4x4 kronecker product.")
215+
235216
return g, f1, f2
236217

237218

@@ -266,7 +247,7 @@ def so4_to_magic_su2s(
266247
raise ValueError('mat must be 4x4 special orthogonal.')
267248

268249
ab = combinators.dot(MAGIC, mat, MAGIC_CONJ_T)
269-
_, a, b = kron_factor_4x4_to_2x2s(ab)
250+
_, a, b = kron_factor_4x4_to_2x2s(ab, rtol, atol)
270251

271252
return a, b
272253

@@ -987,7 +968,7 @@ def _canonicalize_kak_vector(k_vec: np.ndarray, atol: float) -> np.ndarray:
987968
unitaries required to bring the KAK vector into canonical form.
988969
989970
Args:
990-
k_vec: THe KAK vector to be canonicalized. This input may be vectorized,
971+
k_vec: The KAK vector to be canonicalized. This input may be vectorized,
991972
with shape (...,3), where the final axis denotes the k_vector and
992973
all other axes are broadcast.
993974
atol: How close x2 must be to π/4 to guarantee z2 >= 0.

cirq-core/cirq/linalg/decompositions_test.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import cirq
2121
from cirq import value
2222
from cirq import unitary_eig
23+
from cirq.linalg.decompositions import MAGIC, MAGIC_CONJ_T
2324

2425
X = np.array([[0, 1], [1, 0]])
2526
Y = np.array([[0, -1j], [1j, 0]])
@@ -45,9 +46,7 @@ def assert_kronecker_factorization_not_within_tolerance(matrix, g, f1, f2):
4546

4647

4748
def assert_magic_su2_within_tolerance(mat, a, b):
48-
M = cirq.linalg.decompositions.MAGIC
49-
MT = cirq.linalg.decompositions.MAGIC_CONJ_T
50-
recon = cirq.linalg.combinators.dot(MT, cirq.linalg.combinators.kron(a, b), M)
49+
recon = cirq.linalg.combinators.dot(MAGIC_CONJ_T, cirq.linalg.combinators.kron(a, b), MAGIC)
5150
assert np.allclose(recon, mat), "Failed to decompose within tolerance."
5251

5352

@@ -149,14 +148,15 @@ def test_kron_factor_special_unitaries(f1, f2):
149148
assert_kronecker_factorization_within_tolerance(p, g, g1, g2)
150149

151150

152-
def test_kron_factor_fail():
153-
mat = cirq.kron_with_controls(cirq.CONTROL_TAG, X)
154-
g, f1, f2 = cirq.kron_factor_4x4_to_2x2s(mat)
155-
with pytest.raises(ValueError):
156-
assert_kronecker_factorization_not_within_tolerance(mat, g, f1, f2)
157-
mat = cirq.kron_factor_4x4_to_2x2s(np.diag([1, 1, 1, 1j]))
158-
with pytest.raises(ValueError):
159-
assert_kronecker_factorization_not_within_tolerance(mat, g, f1, f2)
151+
def test_kron_factor_invalid_input():
152+
mats = [
153+
cirq.kron_with_controls(cirq.CONTROL_TAG, X),
154+
np.array([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 2, 3, 4]]),
155+
np.diag([1, 1, 1, 1j]),
156+
]
157+
for mat in mats:
158+
with pytest.raises(ValueError, match="Invalid 4x4 kronecker product"):
159+
cirq.kron_factor_4x4_to_2x2s(mat)
160160

161161

162162
def recompose_so4(a: np.ndarray, b: np.ndarray) -> np.ndarray:
@@ -165,8 +165,7 @@ def recompose_so4(a: np.ndarray, b: np.ndarray) -> np.ndarray:
165165
assert cirq.is_special_unitary(a)
166166
assert cirq.is_special_unitary(b)
167167

168-
magic = np.array([[1, 0, 0, 1j], [0, 1j, 1, 0], [0, 1j, -1, 0], [1, 0, 0, -1j]]) * np.sqrt(0.5)
169-
result = np.real(cirq.dot(np.conj(magic.T), cirq.kron(a, b), magic))
168+
result = np.real(cirq.dot(MAGIC_CONJ_T, cirq.kron(a, b), MAGIC))
170169
assert cirq.is_orthogonal(result)
171170
return result
172171

@@ -656,7 +655,7 @@ def test_kak_vector_matches_vectorized():
656655
np.testing.assert_almost_equal(actual, expected)
657656

658657

659-
def test_KAK_vector_local_invariants_random_input():
658+
def test_kak_vector_local_invariants_random_input():
660659
actual = _local_invariants_from_kak(cirq.kak_vector(_random_unitaries))
661660
expected = _local_invariants_from_kak(_kak_vecs)
662661

@@ -697,7 +696,7 @@ def test_kak_vector_on_weyl_chamber_face():
697696
(np.kron(X, X), (0, 0, 0)),
698697
),
699698
)
700-
def test_KAK_vector_weyl_chamber_vertices(unitary, expected):
699+
def test_kak_vector_weyl_chamber_vertices(unitary, expected):
701700
actual = cirq.kak_vector(unitary)
702701
np.testing.assert_almost_equal(actual, expected)
703702

0 commit comments

Comments
 (0)