Skip to content

Commit ec36a27

Browse files
authored
TUCKER_ALS: TTM with negative values is broken in ttensor (#62) (#66)
* Replace usage in tucker_als * Update test for tucker_als to ensure result matches expectation * Add early error handling in ttensor ttm for negative dims
1 parent 992772b commit ec36a27

File tree

5 files changed

+18
-14
lines changed

5 files changed

+18
-14
lines changed

pyttb/pyttb_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def tt_dimscheck(
198198

199199
# Fix "minus" case
200200
if np.max(dims) < 0:
201-
# Check that all memebers in range
201+
# Check that all members in range
202202
if not np.all(np.isin(-dims, np.arange(0, N + 1))):
203203
assert False, "Invalid magnitude for negative dims selection"
204204
dims = np.setdiff1d(np.arange(1, N + 1), -dims) - 1

pyttb/ttensor.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,9 @@ def ttm(self, matrix, dims=None, transpose=False):
434434
dims = np.arange(self.ndims)
435435
elif isinstance(dims, list):
436436
dims = np.array(dims)
437-
elif np.isscalar(dims) or isinstance(dims, list):
437+
elif np.isscalar(dims):
438+
if dims < 0:
439+
raise ValueError("Negative dims is currently unsupported, see #62")
438440
dims = np.array([dims])
439441

440442
if not isinstance(matrix, list):

pyttb/tucker_als.py

+5-10
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,11 @@ def tucker_als(
124124

125125
# Iterate over all N modes of the tensor
126126
for n in dimorder:
127-
if (
128-
n == 0
129-
): # TODO proposal to change ttm to include_dims and exclude_dims to resolve -0 ambiguity
130-
dims = np.arange(1, tensor.ndims)
131-
Utilde = tensor.ttm(U, dims, True)
132-
else:
133-
Utilde = tensor.ttm(U, -n, True)
134-
127+
# TODO proposal to change ttm to include_dims and exclude_dims to resolve -0 ambiguity
128+
dims = np.arange(0, tensor.ndims)
129+
dims = dims[dims != n]
130+
Utilde = tensor.ttm(U, dims, True)
131+
print(f"Utilde[{n}] = {Utilde}")
135132
# Maximize norm(Utilde x_n W') wrt W and
136133
# maintain orthonormality of W
137134
U[n] = Utilde.nvecs(n, rank[n])
@@ -140,13 +137,11 @@ def tucker_als(
140137
core = Utilde.ttm(U, n, True)
141138

142139
# Compute fit
143-
# TODO this abs is missing from MATLAB, but I get negative numbers for trivial examples
144140
normresidual = np.sqrt(abs(normX**2 - core.norm() ** 2))
145141
fit = 1 - (normresidual / normX) # fraction explained by model
146142
fitchange = abs(fitold - fit)
147143

148144
if iter % printitn == 0:
149-
print(f" NormX: {normX} Core norm: {core.norm()}")
150145
print(f" Iter {iter}: fit = {fit:e} fitdelta = {fitchange:7.1e}\n")
151146

152147
# Check for convergence

tests/test_ttensor.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -310,9 +310,15 @@ def test_ttensor_ttm(random_ttensor):
310310

311311
# Negative Tests
312312
big_wrong_size = 123
313-
matrices[0] = np.random.random((big_wrong_size, big_wrong_size))
313+
bad_matrices = matrices.copy()
314+
bad_matrices[0] = np.random.random((big_wrong_size, big_wrong_size))
314315
with pytest.raises(ValueError):
315-
_ = ttensorInstance.ttm(matrices, np.arange(len(matrices)))
316+
_ = ttensorInstance.ttm(bad_matrices, np.arange(len(bad_matrices)))
317+
318+
with pytest.raises(ValueError):
319+
# Negative dims currently broken, ensure we catch early and
320+
# remove once resolved
321+
ttensorInstance.ttm(matrices, -1)
316322

317323

318324
@pytest.mark.indevelopment

tests/test_tucker_als.py

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def test_tucker_als_tensor_default_init(capsys, sample_tensor):
1919
(Solution, Uinit, output) = ttb.tucker_als(T, 2)
2020
capsys.readouterr()
2121
assert pytest.approx(output["fit"], 1) == 0
22+
assert np.all(np.isclose(Solution.double(), T.double()))
2223

2324
(Solution, Uinit, output) = ttb.tucker_als(T, 2, init=Uinit)
2425
capsys.readouterr()

0 commit comments

Comments
 (0)