Skip to content

Commit 18ae68a

Browse files
authored
Exclude dims (#91)
* Explicit Exclude_dims: * Updated tt_dimscheck * Update all uses of tt_dimscheck and propagate interface * Add test coverage for exclude dims changes * Tucker_als: Fix workaround that motivated exclude_dims
1 parent 362033b commit 18ae68a

File tree

11 files changed

+201
-92
lines changed

11 files changed

+201
-92
lines changed

pyttb/ktensor.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1792,7 +1792,7 @@ def tovec(self, include_weights=True):
17921792
# offset += f.shape[0]
17931793
return x
17941794

1795-
def ttv(self, vector, dims=None):
1795+
def ttv(self, vector, dims=None, exclude_dims=None):
17961796
"""
17971797
`Tensor` times vector for `ktensors`.
17981798
@@ -1833,7 +1833,7 @@ def ttv(self, vector, dims=None):
18331833
>>> weights = 2 * np.ones(rank)
18341834
>>> weights_and_data = np.concatenate((weights, data), axis=0)
18351835
>>> K = ttb.ktensor.from_vector(weights_and_data[:], shape, True)
1836-
>>> K0 = K.ttv(np.array([1, 1, 1]), dims=1) # compute along a single dimension
1836+
>>> K0 = K.ttv(np.array([1, 1, 1]),dims=1) # compute along a single dimension
18371837
>>> print(K0)
18381838
ktensor of shape 2 x 4
18391839
weights=[36. 54.]
@@ -1857,7 +1857,7 @@ def ttv(self, vector, dims=None):
18571857
18581858
Compute the product of a `ktensor` and multiple vectors out of order (results in a `ktensor`):
18591859
1860-
>>> K2 = K.ttv([vec4, vec3], np.array([2, 1]))
1860+
>>> K2 = K.ttv([vec4, vec3],np.array([2, 1]))
18611861
>>> print(K2)
18621862
ktensor of shape 2
18631863
weights=[1800. 3564.]
@@ -1866,17 +1866,20 @@ def ttv(self, vector, dims=None):
18661866
[2. 4.]]
18671867
"""
18681868

1869-
if dims is None:
1869+
if dims is None and exclude_dims is None:
18701870
dims = np.array([])
18711871
elif isinstance(dims, (float, int)):
18721872
dims = np.array([dims])
18731873

1874+
if isinstance(exclude_dims, (float, int)):
1875+
exclude_dims = np.array([exclude_dims])
1876+
18741877
# Check that vector is a list of vectors, if not place single vector as element in list
18751878
if len(vector) > 0 and isinstance(vector[0], (int, float, np.int_, np.float_)):
18761879
return self.ttv([vector], dims)
18771880

18781881
# Get sorted dims and index for multiplicands
1879-
dims, vidx = ttb.tt_dimscheck(dims, self.ndims, len(vector))
1882+
dims, vidx = ttb.tt_dimscheck(self.ndims, len(vector), dims, exclude_dims)
18801883

18811884
# Check that each multiplicand is the right size.
18821885
for i in range(dims.size):

pyttb/pyttb_utils.py

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -114,17 +114,30 @@ def tt_union_rows(MatrixA, MatrixB):
114114

115115

116116
@overload
117-
def tt_dimscheck(dims: np.ndarray, N: int, M: None = None) -> Tuple[np.ndarray, None]:
117+
def tt_dimscheck(
118+
N: int,
119+
M: None = None,
120+
dims: Optional[np.ndarray] = None,
121+
exclude_dims: Optional[np.ndarray] = None,
122+
) -> Tuple[np.ndarray, None]:
118123
... # pragma: no cover see coveragepy/issues/970
119124

120125

121126
@overload
122-
def tt_dimscheck(dims: np.ndarray, N: int, M: int) -> Tuple[np.ndarray, np.ndarray]:
127+
def tt_dimscheck(
128+
N: int,
129+
M: int,
130+
dims: Optional[np.ndarray] = None,
131+
exclude_dims: Optional[np.ndarray] = None,
132+
) -> Tuple[np.ndarray, np.ndarray]:
123133
... # pragma: no cover see coveragepy/issues/970
124134

125135

126136
def tt_dimscheck(
127-
dims: np.ndarray, N: int, M: Optional[int] = None
137+
N: int,
138+
M: Optional[int] = None,
139+
dims: Optional[np.ndarray] = None,
140+
exclude_dims: Optional[np.ndarray] = None,
128141
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
129142
"""
130143
Used to preprocess dimensions for tensor dimensions
@@ -136,24 +149,43 @@ def tt_dimscheck(
136149
-------
137150
138151
"""
139-
# Fix empty case
140-
if dims.size == 0:
141-
dims = np.arange(0, N)
152+
if dims is not None and exclude_dims is not None:
153+
raise ValueError("Either specify dims to include or exclude, but not both")
154+
155+
dim_array: np.ndarray = np.empty((1,))
142156

143-
# Fix "minus" case
144-
if np.max(dims) < 0:
157+
# Explicit exclude to resolve ambiguous -0
158+
if exclude_dims is not None:
145159
# Check that all members in range
146-
if not np.all(np.isin(-dims, np.arange(0, N + 1))):
147-
assert False, "Invalid magnitude for negative dims selection"
148-
dims = np.setdiff1d(np.arange(1, N + 1), -dims) - 1
160+
valid_indices = np.isin(exclude_dims, np.arange(0, N))
161+
if not np.all(valid_indices):
162+
invalid_indices = np.logical_not(valid_indices)
163+
raise ValueError(
164+
f"Exclude dims provided: {exclude_dims} "
165+
f"but, {exclude_dims[invalid_indices]} were out of valid range"
166+
f"[0,{N}]"
167+
)
168+
dim_array = np.setdiff1d(np.arange(0, N), exclude_dims)
169+
170+
# Fix empty case
171+
if (dims is None or dims.size == 0) and exclude_dims is None:
172+
dim_array = np.arange(0, N)
173+
elif isinstance(dims, np.ndarray):
174+
dim_array = dims
175+
176+
# Catch minus case to avoid silent errors
177+
if np.any(dim_array < 0):
178+
raise ValueError(
179+
"Negative dims aren't allowed in pyttb, see exclude_dims argument instead"
180+
)
149181

150182
# Save dimensions of dims
151-
P = len(dims)
183+
P = len(dim_array)
152184

153185
# Reorder dims from smallest to largest (this matters in particular for the vector
154186
# multiplicand case, where the order affects the result)
155-
sidx = np.argsort(dims)
156-
sdims = dims[sidx]
187+
sidx = np.argsort(dim_array)
188+
sdims = dim_array[sidx]
157189
vidx = None
158190

159191
if M is not None:

pyttb/sptensor.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def collapse(
382382
if dims is None:
383383
dims = np.arange(0, self.ndims)
384384

385-
dims, _ = tt_dimscheck(dims, self.ndims)
385+
dims, _ = tt_dimscheck(self.ndims, dims=dims)
386386
remdims = np.setdiff1d(np.arange(0, self.ndims), dims)
387387

388388
# Check for the case where we accumulate over *all* dimensions
@@ -882,7 +882,7 @@ def mttkrp(self, U: Union[ttb.ktensor, List[np.ndarray]], n: int) -> np.ndarray:
882882
else:
883883
Z.append(np.array([]))
884884
# Perform ttv multiplication
885-
V[:, r] = self.ttv(Z, -(n + 1)).double()
885+
V[:, r] = self.ttv(Z, exclude_dims=n).double()
886886

887887
return V
888888

@@ -1044,7 +1044,7 @@ def scale(self, factor: np.ndarray, dims: Union[float, np.ndarray]) -> sptensor:
10441044
"""
10451045
if isinstance(dims, (float, int)):
10461046
dims = np.array([dims])
1047-
dims, _ = ttb.tt_dimscheck(dims, self.ndims)
1047+
dims, _ = ttb.tt_dimscheck(self.ndims, dims=dims)
10481048

10491049
if isinstance(factor, ttb.tensor):
10501050
shapeArray = np.array(self.shape)
@@ -1181,6 +1181,7 @@ def ttv(
11811181
self,
11821182
vector: Union[np.ndarray, List[np.ndarray]],
11831183
dims: Optional[Union[int, np.ndarray]] = None,
1184+
exclude_dims: Optional[Union[int, np.ndarray]] = None,
11841185
) -> Union[sptensor, ttb.tensor]:
11851186
"""
11861187
Sparse tensor times vector
@@ -1189,20 +1190,24 @@ def ttv(
11891190
----------
11901191
vector: Vector(s) to multiply against
11911192
dims: Dimensions to multiply with vector(s)
1193+
exclude_dims: Use all dimensions but these
11921194
"""
11931195

1194-
if dims is None:
1196+
if dims is None and exclude_dims is None:
11951197
dims = np.array([])
11961198
elif isinstance(dims, (float, int)):
11971199
dims = np.array([dims])
11981200

1201+
if isinstance(exclude_dims, (float, int)):
1202+
exclude_dims = np.array([exclude_dims])
1203+
11991204
# Check that vector is a list of vectors,
12001205
# if not place single vector as element in list
12011206
if len(vector) > 0 and isinstance(vector[0], (int, float, np.int_, np.float_)):
1202-
return self.ttv(np.array([vector]), dims)
1207+
return self.ttv(np.array([vector]), dims, exclude_dims)
12031208

12041209
# Get sorted dims and index for multiplicands
1205-
dims, vidx = ttb.tt_dimscheck(dims, self.ndims, len(vector))
1210+
dims, vidx = ttb.tt_dimscheck(self.ndims, len(vector), dims, exclude_dims)
12061211
remdims = np.setdiff1d(np.arange(0, self.ndims), dims).astype(int)
12071212

12081213
# Check that each multiplicand is the right size.
@@ -2495,6 +2500,7 @@ def ttm(
24952500
self,
24962501
matrices: Union[np.ndarray, List[np.ndarray]],
24972502
dims: Optional[Union[float, np.ndarray]] = None,
2503+
exclude_dims: Optional[Union[float, np.ndarray]] = None,
24982504
transpose: bool = False,
24992505
):
25002506
"""
@@ -2503,24 +2509,28 @@ def ttm(
25032509
Parameters
25042510
----------
25052511
matrices: A matrix or list of matrices
2506-
dims: :class:`Numpy.ndarray`, int
2512+
dims: Dimensions to multiply against
2513+
exclude_dims: Use all dimensions but these
25072514
transpose: Transpose matrices to be multiplied
25082515
25092516
Returns
25102517
-------
25112518
25122519
"""
2513-
if dims is None:
2520+
if dims is None and exclude_dims is None:
25142521
dims = np.arange(self.ndims)
25152522
elif isinstance(dims, list):
25162523
dims = np.array(dims)
25172524
elif isinstance(dims, (float, int, np.generic)):
25182525
dims = np.array([dims])
25192526

2527+
if isinstance(exclude_dims, (float, int)):
2528+
exclude_dims = np.array([exclude_dims])
2529+
25202530
# Handle list of matrices
25212531
if isinstance(matrices, list):
25222532
# Check dimensions are valid
2523-
[dims, vidx] = tt_dimscheck(dims, self.ndims, len(matrices))
2533+
[dims, vidx] = tt_dimscheck(self.ndims, len(matrices), dims, exclude_dims)
25242534
# Calculate individual products
25252535
Y = self.ttm(matrices[vidx[0]], dims[0], transpose=transpose)
25262536
for i in range(1, dims.size):
@@ -2535,6 +2545,10 @@ def ttm(
25352545
if transpose:
25362546
matrices = matrices.transpose()
25372547

2548+
# FIXME: This made typing happy but shouldn't be possible
2549+
if not isinstance(dims, np.ndarray): # pragma: no cover
2550+
raise ValueError("Dims should be an array here")
2551+
25382552
# Ensure this is the terminal single dimension case
25392553
if not (dims.size == 1 and np.isin(dims, np.arange(self.ndims))):
25402554
assert False, "dims must contain values in [0,self.dims)"

0 commit comments

Comments
 (0)