From 69624cce8015dee7b8cae6cd3ad81d802c820f18 Mon Sep 17 00:00:00 2001 From: Nick Johnson <24689722+ntjohnson1@users.noreply.github.com> Date: Sat, 6 May 2023 12:26:42 -0400 Subject: [PATCH 1/5] Tensor.__setitem__/__getitem__: Fix linear index * Before required numpy array now works on value/slice/Iterable --- pyttb/tensor.py | 46 ++++++++++++++++++++++++++++++++++++-------- tests/test_tensor.py | 10 ++++++++++ 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/pyttb/tensor.py b/pyttb/tensor.py index 3bebc6ea..6fa4290f 100644 --- a/pyttb/tensor.py +++ b/pyttb/tensor.py @@ -5,6 +5,7 @@ from __future__ import annotations import logging +from collections.abc import Iterable from itertools import permutations from math import factorial from typing import Any, Callable, List, Optional, Tuple, Union @@ -1275,12 +1276,16 @@ def __setitem__(self, key, value): """ # Figure out if we are doing a subtensor, a list of subscripts or a list of # linear indices + # print(f"Key: {key} {type(key)}") access_type = "error" - if self.ndims <= 1: - if isinstance(key, np.ndarray): - access_type = "subscripts" - else: + # TODO pull out this big decision tree into a function + if isinstance(key, (float, int, np.generic, slice)): + access_type = "linear indices" + elif self.ndims <= 1: + if isinstance(key, tuple): access_type = "subtensor" + elif isinstance(key, np.ndarray): + access_type = "subscripts" else: if isinstance(key, np.ndarray): if len(key.shape) > 1 and key.shape[1] >= self.ndims: @@ -1289,10 +1294,18 @@ def __setitem__(self, key, value): access_type = "linear indices" elif isinstance(key, tuple): validSubtensor = [ - isinstance(keyElement, (int, slice)) for keyElement in key + isinstance(keyElement, (int, slice, Iterable)) for keyElement in key ] + # TODO probably need to confirm the Iterable is in fact numeric if np.all(validSubtensor): access_type = "subtensor" + elif isinstance(key, Iterable): + key = np.array(key) + # Clean up copy paste + if len(key.shape) > 1 and key.shape[1] >= self.ndims: + access_type = "subscripts" + elif len(key.shape) == 1 or key.shape[1] == 1: + access_type = "linear indices" # Case 1: Rectangular Subtensor if access_type == "subtensor": @@ -1310,10 +1323,14 @@ def __setitem__(self, key, value): def _set_linear(self, key, value): idx = key - if (idx > np.prod(self.shape)).any(): + if not isinstance(idx, slice) and (idx > np.prod(self.shape)).any(): assert ( False ), "TTB:BadIndex In assignment X[I] = Y, a tensor X cannot be resized" + if isinstance(key, (int, float, np.generic)): + idx = np.array([key]) + elif isinstance(key, slice): + idx = np.array(range(np.prod(self.shape))[key]) idx = tt_ind2sub(self.shape, idx) if idx.shape[0] == 1: self.data[tuple(idx[0, :])] = value @@ -1333,6 +1350,8 @@ def _set_subtensor(self, key, value): sliceCheck.append(1) else: sliceCheck.append(element.stop) + elif isinstance(element, Iterable): + sliceCheck.append(max(element)) else: sliceCheck.append(element) bsiz = np.array(sliceCheck) @@ -1443,6 +1462,17 @@ def __getitem__(self, item): ------- :class:`pyttb.tensor` or :class:`numpy.ndarray` """ + # Case 0: Single Index Linear + if isinstance(item, (int, float, np.generic, slice)): + if isinstance(item, (int, float, np.generic)): + idx = np.array(item) + elif isinstance(item, slice): + idx = np.array(range(np.prod(self.shape))[item]) + a = np.squeeze( + self.data[tuple(ttb.tt_ind2sub(self.shape, idx).transpose())] + ) + # Todo if row make column? 
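+            # At this point `idx` holds the requested linear indices and `a` the
+            # corresponding values gathered from self.data.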
+ return ttb.tt_subsubsref(a, idx) # Case 1: Rectangular Subtensor if ( isinstance(item, tuple) @@ -1492,9 +1522,9 @@ def __getitem__(self, item): return ttb.tt_subsubsref(a, subs) # Case 2b: Linear Indexing - if len(item) >= 2 and not isinstance(item[-1], str): + if isinstance(item, tuple) and len(item) >= 2 and not isinstance(item[-1], str): assert False, "Linear indexing requires single input array" - idx = item[0] + idx = np.array(item) a = np.squeeze(self.data[tuple(ttb.tt_ind2sub(self.shape, idx).transpose())]) # Todo if row make column? return ttb.tt_subsubsref(a, idx) diff --git a/tests/test_tensor.py b/tests/test_tensor.py index 8894e0ea..36f7ec02 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -280,6 +280,14 @@ def test_tensor__setitem__(sample_tensor_2way): dataGrowth[np.unravel_index([0], dataGrowth.shape, "F")] = 13.0 assert (tensorInstance.data == dataGrowth).all() + tensorInstance[0] = 14.0 + dataGrowth[np.unravel_index([0], dataGrowth.shape, "F")] = 14.0 + assert (tensorInstance.data == dataGrowth).all() + + tensorInstance[0:1] = 14.0 + dataGrowth[np.unravel_index([0], dataGrowth.shape, "F")] = 14.0 + assert (tensorInstance.data == dataGrowth).all() + # Linear Index with constant tensorInstance[np.array([0, 3, 4])] = 13.0 dataGrowth[np.unravel_index([0, 3, 4], dataGrowth.shape, "F")] = 13 @@ -340,6 +348,8 @@ def test_tensor__getitem__(sample_tensor_2way): # Case 2b: Linear Indexing assert tensorInstance[np.array([0])] == params["data"][0, 0] + assert tensorInstance[0] == params["data"][0, 0] + assert np.array_equal(tensorInstance[0:1], params["data"][0, 0]) with pytest.raises(AssertionError) as excinfo: tensorInstance[np.array([0]), np.array([0]), np.array([0])] assert "Linear indexing requires single input array" in str(excinfo) From 3497b7538c8be2a8bf1fa405f4428447cd95d548 Mon Sep 17 00:00:00 2001 From: Nick Johnson <24689722+ntjohnson1@users.noreply.github.com> Date: Sat, 27 May 2023 17:54:57 -0400 Subject: [PATCH 2/5] Tensor.__getitem__: Fix subscripts usage * Consistent with setitem now * Update usages (primarily in sptensor) --- pyttb/sptensor.py | 43 ++++++++++++++++++++++++++++--------------- pyttb/tensor.py | 15 +++++++++++++-- tests/test_tensor.py | 7 ++++--- 3 files changed, 45 insertions(+), 20 deletions(-) diff --git a/pyttb/sptensor.py b/pyttb/sptensor.py index f6d26fc4..46974585 100644 --- a/pyttb/sptensor.py +++ b/pyttb/sptensor.py @@ -6,7 +6,7 @@ import logging import warnings -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from typing import Any, Callable, List, Optional, Tuple, Union, cast, overload import numpy as np @@ -620,7 +620,7 @@ def innerprod( if self.shape != other.shape: assert False, "Sptensor and tensor must be same shape for innerproduct" [subsSelf, valsSelf] = self.find() - valsOther = other[subsSelf, "extract"] + valsOther = other[subsSelf.transpose(), "extract"] return valsOther.transpose().dot(valsSelf) if isinstance(other, (ttb.ktensor, ttb.ttensor)): # pragma: no cover @@ -685,7 +685,7 @@ def is_length_2(x): if isinstance(B, ttb.tensor): BB = sptensor.from_data( - self.subs, B[self.subs, "extract"][:, None], self.shape + self.subs, B[self.subs.transpose(), "extract"][:, None], self.shape ) C = self.logical_and(BB) return C @@ -1053,7 +1053,7 @@ def scale(self, factor: np.ndarray, dims: Union[float, np.ndarray]) -> sptensor: assert False, "Size mismatch in scale" return ttb.sptensor.from_data( self.subs, - self.vals * factor[self.subs[:, dims], "extract"][:, None], + self.vals * 
factor[self.subs[:, dims].transpose(), "extract"][:, None], self.shape, ) if isinstance(factor, ttb.sptensor): @@ -1660,7 +1660,7 @@ def _set_subtensor(self, key, value): ) else: newsz.append(key[n].stop) - elif isinstance(key[n], np.ndarray): + elif isinstance(key[n], (np.ndarray, Iterable)): newsz.append(max(key[n]) + 1) else: newsz.append(key[n] + 1) @@ -1671,7 +1671,8 @@ def _set_subtensor(self, key, value): self.subs = np.append( self.subs, np.zeros( - shape=(self.subs.shape[0], len(self.shape) - self.subs.shape[1]) + shape=(self.subs.shape[0], len(self.shape) - self.subs.shape[1]), + dtype=int, ), axis=1, ) @@ -1689,7 +1690,7 @@ def _set_subtensor(self, key, value): if isinstance(value, (int, float)): # Determine number of dimensions (may be larger than current number) N = len(key) - keyCopy = np.array(key) + keyCopy = [None] * N # Figure out how many indices are in each dimension nssubs = np.zeros((N, 1)) for n in range(0, N): @@ -1697,7 +1698,11 @@ def _set_subtensor(self, key, value): # Generate slice explicitly to determine its length keyCopy[n] = np.arange(0, self.shape[n])[key[n]] indicesInN = len(keyCopy[n]) + elif isinstance(key[n], Iterable): + keyCopy[n] = key[n] + indicesInN = len(key[n]) else: + keyCopy[n] = key[n] indicesInN = 1 nssubs[n] = indicesInN @@ -1806,7 +1811,7 @@ def __eq__(self, other): ] # Find where their nonzeros intersect - othervals = other[self.subs, "extract"] + othervals = other[self.subs.transpose(), "extract"] znzsubs = self.subs[(othervals[:, None] == self.vals).transpose()[0], :] return sptensor.from_data( @@ -1887,7 +1892,7 @@ def __ne__(self, other): subs1 = np.empty((0, self.subs.shape[1])) # find entries where x is nonzero but not equal to y subs2 = self.subs[ - self.vals.transpose()[0] != other[self.subs, "extract"], : + self.vals.transpose()[0] != other[self.subs.transpose(), "extract"], : ] if subs2.size == 0: subs2 = np.empty((0, self.subs.shape[1])) @@ -2002,7 +2007,7 @@ def __mul__(self, other): ) if isinstance(other, ttb.tensor): csubs = self.subs - cvals = self.vals * other[csubs, "extract"][:, None] + cvals = self.vals * other[csubs.transpose(), "extract"][:, None] return ttb.sptensor.from_data(csubs, cvals, self.shape) if isinstance(other, ttb.ktensor): csubs = self.subs @@ -2124,7 +2129,7 @@ def __le__(self, other): # self nonzero subs2 = self.subs[ - self.vals.transpose()[0] <= other[self.subs, "extract"], : + self.vals.transpose()[0] <= other[self.subs.transpose(), "extract"], : ] # assemble @@ -2212,7 +2217,9 @@ def __lt__(self, other): subs1 = subs1[ttb.tt_setdiff_rows(subs1, self.subs), :] # self nonzero - subs2 = self.subs[self.vals.transpose()[0] < other[self.subs, "extract"], :] + subs2 = self.subs[ + self.vals.transpose()[0] < other[self.subs.transpose(), "extract"], : + ] # assemble subs = np.vstack((subs1, subs2)) @@ -2267,7 +2274,10 @@ def __ge__(self, other): # self nonzero subs2 = self.subs[ - (self.vals >= other[self.subs, "extract"][:, None]).transpose()[0], : + ( + self.vals >= other[self.subs.transpose(), "extract"][:, None] + ).transpose()[0], + :, ] # assemble @@ -2325,7 +2335,10 @@ def __gt__(self, other): # self and other nonzero subs2 = self.subs[ - (self.vals > other[self.subs, "extract"][:, None]).transpose()[0], : + ( + self.vals > other[self.subs.transpose(), "extract"][:, None] + ).transpose()[0], + :, ] # assemble @@ -2428,7 +2441,7 @@ def __truediv__(self, other): if isinstance(other, ttb.tensor): csubs = self.subs - cvals = self.vals / other[csubs, "extract"][:, None] + cvals = self.vals / 
other[csubs.transpose(), "extract"][:, None] return ttb.sptensor.from_data(csubs, cvals, self.shape) if isinstance(other, ttb.ktensor): # TODO consider removing epsilon and generating nans consistent with above diff --git a/pyttb/tensor.py b/pyttb/tensor.py index 6fa4290f..56c91c50 100644 --- a/pyttb/tensor.py +++ b/pyttb/tensor.py @@ -1514,10 +1514,21 @@ def __getitem__(self, item): return a # *** CASE 2a: Subscript indexing *** - if len(item) > 1 and isinstance(item[-1], str) and item[-1] == "extract": + if isinstance(item, np.ndarray) and len(item) > 1: # Extract array of subscripts + subs = np.array(item) + a = np.squeeze(self.data[tuple(subs)]) + # TODO if is row make column? + return ttb.tt_subsubsref(a, subs) + if ( + len(item) > 1 + and isinstance(item[0], np.ndarray) + and isinstance(item[-1], str) + and item[-1] == "extract" + ): + # TODO dry this up subs = np.array(item[0]) - a = np.squeeze(self.data[tuple(subs.transpose())]) + a = np.squeeze(self.data[tuple(subs)]) # TODO if is row make column? return ttb.tt_subsubsref(a, subs) diff --git a/tests/test_tensor.py b/tests/test_tensor.py index 36f7ec02..d5167cbc 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -313,9 +313,10 @@ def test_tensor__setitem__(sample_tensor_2way): ) # Attempting to set some other way - with pytest.raises(AssertionError) as excinfo: + # TODO either catch this error ourselves or specify more specific exception we expect here + with pytest.raises(Exception) as excinfo: tensorInstance[0, "a", 5] = 13.0 - assert "Invalid use of tensor setitem" in str(excinfo) + # assert "Invalid use of tensor setitem" in str(excinfo) @pytest.mark.indevelopment @@ -343,7 +344,7 @@ def test_tensor__getitem__(sample_tensor_2way): assert tensorInstance[np.array([0, 0]), "extract"] == params["data"][0, 0] assert ( tensorInstance[np.array([[0, 0], [1, 1]]), "extract"] - == params["data"][([0, 1], [0, 1])] + == params["data"][([0, 0], [1, 1])] ).all() # Case 2b: Linear Indexing From 5aad34f3ffdd3700cd1fbc7755774019193ff7e2 Mon Sep 17 00:00:00 2001 From: Nick Johnson <24689722+ntjohnson1@users.noreply.github.com> Date: Sat, 27 May 2023 18:36:45 -0400 Subject: [PATCH 3/5] Sptensor.__setitem__/__getitem__: Fix subscripts usage * Consistent with tensor and MATLAB now * Update test usage --- pyttb/sptensor.py | 16 ++++++++-------- tests/test_sptensor.py | 42 ++++++++++++++++++++++-------------------- 2 files changed, 30 insertions(+), 28 deletions(-) diff --git a/pyttb/sptensor.py b/pyttb/sptensor.py index 46974585..a5993842 100644 --- a/pyttb/sptensor.py +++ b/pyttb/sptensor.py @@ -1368,9 +1368,9 @@ def __getitem__(self, item): if ( isinstance(item, np.ndarray) and len(item.shape) == 2 - and item.shape[1] == self.ndims + and item.shape[0] == self.ndims ): - srchsubs = np.array(item) + srchsubs = np.array(item.transpose()) # *** CASE 2b: Linear indexing *** else: @@ -1463,21 +1463,21 @@ def _set_subscripts(self, key, value): tt_subscheck(newsubs, nargout=False) # Error check on subscripts - if newsubs.shape[1] < self.ndims: + if newsubs.shape[0] < self.ndims: assert False, "Invalid subscripts" # Check for expanding the order - if newsubs.shape[1] > self.ndims: + if newsubs.shape[0] > self.ndims: newshape = list(self.shape) # TODO no need for loop, just add correct size - for _ in range(self.ndims, newsubs.shape[1]): + for _ in range(self.ndims, newsubs.shape[0]): newshape.append(1) if self.subs.size > 0: self.subs = np.concatenate( ( self.subs, np.ones( - (self.shape[0], newsubs.shape[1] - self.ndims), + (self.shape[0], 
newsubs.shape[0] - self.ndims), dtype=int, ), ), @@ -1497,7 +1497,7 @@ def _set_subscripts(self, key, value): # Determine number of nonzeros being inserted. # (This is determined by number of subscripts) - newnnz = newsubs.shape[0] + newnnz = newsubs.shape[1] # Error check on size of newvals if newvals.size == 1: @@ -1510,7 +1510,7 @@ def _set_subscripts(self, key, value): assert False, "Number of subscripts and number of values do not match!" # Remove duplicates and print warning if any duplicates were removed - newsubs, idx = np.unique(newsubs, axis=0, return_index=True) + newsubs, idx = np.unique(newsubs.transpose(), axis=0, return_index=True) if newsubs.shape[0] != newnnz: warnings.warn("Duplicate assignments discarded") diff --git a/tests/test_sptensor.py b/tests/test_sptensor.py index 3a14ccdb..62ebae17 100644 --- a/tests/test_sptensor.py +++ b/tests/test_sptensor.py @@ -293,9 +293,9 @@ def test_sptensor__getitem__(sample_sptensor): # TODO need to understand what this intends to do ## Case 2 subscript indexing - assert sptensorInstance[np.array([[1, 2, 1]])] == np.array([[0]]) + assert sptensorInstance[np.array([[1], [2], [1]])] == np.array([[0]]) assert ( - sptensorInstance[np.array([[1, 2, 1], [1, 3, 1]])] == np.array([[0], [0]]) + sptensorInstance[np.array([[1, 1], [2, 3], [1, 1]])] == np.array([[0], [0]]) ).all() ## Case 2 Linear Indexing @@ -551,12 +551,14 @@ def test_sptensor_setitem_Case2(sample_sptensor): # Case II: Too few keys in setitem for number of assignement values with pytest.raises(AssertionError) as excinfo: - sptensorInstance[np.array([1, 1, 1]).astype(int)] = np.array([[999.0], [888.0]]) + sptensorInstance[np.array([[1], [1], [1]]).astype(int)] = np.array( + [[999.0], [888.0]] + ) assert "Number of subscripts and number of values do not match!" 
in str(excinfo) # Case II: Warning For duplicates with pytest.warns(Warning) as record: - sptensorInstance[np.array([[1, 1, 1], [1, 1, 1]]).astype(int)] = np.array( + sptensorInstance[np.array([[1, 1], [1, 1], [1, 1]]).astype(int)] = np.array( [[999.0], [999.0]] ) assert "Duplicate assignments discarded" in str(record[0].message) @@ -567,54 +569,54 @@ def test_sptensor_setitem_Case2(sample_sptensor): assert np.all(empty_tensor[np.array([[0, 1], [2, 2]])] == 4) # Case II: Single entry, for single sub that exists - sptensorInstance[np.array([1, 1, 1]).astype(int)] = 999.0 - assert (sptensorInstance[np.array([[1, 1, 1]])] == np.array([[999]])).all() + sptensorInstance[np.array([[1], [1], [1]]).astype(int)] = 999.0 + assert (sptensorInstance[np.array([[1], [1], [1]])] == np.array([[999]])).all() assert (sptensorInstance.subs == data["subs"]).all() # Case II: Single entry, for multiple subs that exist (data, sptensorInstance) = sample_sptensor - sptensorInstance[np.array([[1, 1, 1], [1, 1, 3]]).astype(int)] = 999.0 + sptensorInstance[np.array([[1, 1], [1, 1], [1, 3]]).astype(int)] = 999.0 assert ( - sptensorInstance[np.array([[1, 1, 1], [1, 1, 3]])] == np.array([[999], [999]]) + sptensorInstance[np.array([[1, 1], [1, 1], [1, 3]])] == np.array([[999], [999]]) ).all() assert (sptensorInstance.subs == data["subs"]).all() # Case II: Multiple entries, for multiple subs that exist (data, sptensorInstance) = sample_sptensor - sptensorInstance[np.array([[1, 1, 1], [1, 1, 3]]).astype(int)] = np.array( + sptensorInstance[np.array([[1, 1], [1, 1], [1, 3]]).astype(int)] = np.array( [[888], [999]] ) assert ( - sptensorInstance[np.array([[1, 1, 3], [1, 1, 1]])] == np.array([[999], [888]]) + sptensorInstance[np.array([[1, 1], [1, 1], [3, 1]])] == np.array([[999], [888]]) ).all() assert (sptensorInstance.subs == data["subs"]).all() # Case II: Single entry, for single sub that doesn't exist (data, sptensorInstance) = sample_sptensor copy = ttb.sptensor.from_tensor_type(sptensorInstance) - copy[np.array([[1, 1, 2]]).astype(int)] = 999.0 - assert (copy[np.array([[1, 1, 2]])] == np.array([999])).all() + copy[np.array([[1], [1], [2]]).astype(int)] = 999.0 + assert (copy[np.array([[1], [1], [2]])] == np.array([999])).all() assert (copy.subs == np.concatenate((data["subs"], np.array([[1, 1, 2]])))).all() # Case II: Single entry, for single sub that doesn't exist, expand dimensions (data, sptensorInstance) = sample_sptensor copy = ttb.sptensor.from_tensor_type(sptensorInstance) - copy[np.array([[1, 1, 2, 1]]).astype(int)] = 999.0 - assert (copy[np.array([[1, 1, 2, 1]])] == np.array([999])).all() + copy[np.array([[1], [1], [2], [1]]).astype(int)] = 999.0 + assert (copy[np.array([[1], [1], [2], [1]])] == np.array([999])).all() # assert (copy.subs == np.concatenate((data['subs'], np.array([[1, 1, 2]])))).all() # Case II: Single entry, for multiple subs one that exists and the other doesn't (data, sptensorInstance) = sample_sptensor copy = ttb.sptensor.from_tensor_type(sptensorInstance) - copy[np.array([[1, 1, 1], [2, 1, 3]]).astype(int)] = 999.0 - assert (copy[np.array([[2, 1, 3]])] == np.array([999])).all() + copy[np.array([[1, 2], [1, 1], [1, 3]]).astype(int)] = 999.0 + assert (copy[np.array([[2], [1], [3]])] == np.array([999])).all() assert (copy.subs == np.concatenate((data["subs"], np.array([[2, 1, 3]])))).all() # Case II: Multiple entries, for multiple subs that don't exist (data, sptensorInstance) = sample_sptensor copy = ttb.sptensor.from_tensor_type(sptensorInstance) - copy[np.array([[1, 1, 2], [2, 1, 
3]]).astype(int)] = np.array([[888], [999]]) - assert (copy[np.array([[1, 1, 2], [2, 1, 3]])] == np.array([[888], [999]])).all() + copy[np.array([[1, 2], [1, 1], [2, 3]]).astype(int)] = np.array([[888], [999]]) + assert (copy[np.array([[1, 2], [1, 1], [2, 3]])] == np.array([[888], [999]])).all() assert ( copy.subs == np.concatenate((data["subs"], np.array([[1, 1, 2], [2, 1, 3]]))) ).all() @@ -622,8 +624,8 @@ def test_sptensor_setitem_Case2(sample_sptensor): # Case II: Multiple entries, for multiple subs that exist and need to be removed (data, sptensorInstance) = sample_sptensor copy = ttb.sptensor.from_tensor_type(sptensorInstance) - copy[np.array([[1, 1, 1], [1, 1, 3]]).astype(int)] = np.array([[0], [0]]) - assert (copy[np.array([[1, 1, 2], [2, 1, 3]])] == np.array([[0], [0]])).all() + copy[np.array([[1, 1], [1, 1], [1, 3]]).astype(int)] = np.array([[0], [0]]) + assert (copy[np.array([[1, 2], [1, 1], [1, 3]])] == np.array([[0], [0]])).all() assert (copy.subs == np.array([[2, 2, 2], [3, 3, 3]])).all() From 61ec65c7c36ced7f7d5feec2978709e22697cd41 Mon Sep 17 00:00:00 2001 From: Nick Johnson <24689722+ntjohnson1@users.noreply.github.com> Date: Sun, 28 May 2023 16:33:03 -0400 Subject: [PATCH 4/5] sptensor: Add coverage for improved indexing capability --- pyttb/sptensor.py | 2 ++ tests/test_sptensor.py | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/pyttb/sptensor.py b/pyttb/sptensor.py index a5993842..9aeb83a5 100644 --- a/pyttb/sptensor.py +++ b/pyttb/sptensor.py @@ -1647,6 +1647,8 @@ def _set_subtensor(self, key, value): newsz.append(self.shape[n]) else: newsz.append(max([self.shape[n], key[n].stop])) + elif isinstance(key[n], Iterable): + newsz.append(max([self.shape[n], max(key[n]) + 1])) else: newsz.append(max([self.shape[n], key[n] + 1])) diff --git a/tests/test_sptensor.py b/tests/test_sptensor.py index 62ebae17..f9662d10 100644 --- a/tests/test_sptensor.py +++ b/tests/test_sptensor.py @@ -533,6 +533,18 @@ def test_sptensor_setitem_Case1(sample_sptensor): assert (sptensorInstance.vals == np.vstack((data["vals"], np.array([[7]])))).all() assert sptensorInstance.shape == data["shape"] + # Case I(b)ii: Set with scalar, iterable index, empty sptensor + someTensor = ttb.sptensor() + someTensor[[0, 1], 0] = 1 + assert someTensor[0, 0] == 1 + assert someTensor[1, 0] == 1 + assert np.all(someTensor[[0, 1], 0].vals == 1) + # Case I(b)ii: Set with scalar, iterable index, non-empty sptensor + someTensor[[0, 1], 1] = 2 + assert someTensor[0, 1] == 2 + assert someTensor[1, 1] == 2 + assert np.all(someTensor[[0, 1], 1].vals == 2) + # Case I: Assign with non-scalar or sptensor sptensorInstanceLarger = ttb.sptensor.from_tensor_type(sptensorInstance) with pytest.raises(AssertionError) as excinfo: From d1be0512d5125e3f5ed42538c8d3f63ee06a1582 Mon Sep 17 00:00:00 2001 From: Nick Johnson <24689722+ntjohnson1@users.noreply.github.com> Date: Sun, 28 May 2023 16:53:49 -0400 Subject: [PATCH 5/5] tensor: Add coverage for improved indexing capability --- pyttb/tensor.py | 13 +++++++------ tests/test_tensor.py | 32 +++++++++++++++++++++++++++++--- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/pyttb/tensor.py b/pyttb/tensor.py index 56c91c50..b415824d 100644 --- a/pyttb/tensor.py +++ b/pyttb/tensor.py @@ -1276,7 +1276,6 @@ def __setitem__(self, key, value): """ # Figure out if we are doing a subtensor, a list of subscripts or a list of # linear indices - # print(f"Key: {key} {type(key)}") access_type = "error" # TODO pull out this big decision tree into a function if 
isinstance(key, (float, int, np.generic, slice)): @@ -1296,15 +1295,11 @@ def __setitem__(self, key, value): validSubtensor = [ isinstance(keyElement, (int, slice, Iterable)) for keyElement in key ] - # TODO probably need to confirm the Iterable is in fact numeric if np.all(validSubtensor): access_type = "subtensor" elif isinstance(key, Iterable): key = np.array(key) - # Clean up copy paste - if len(key.shape) > 1 and key.shape[1] >= self.ndims: - access_type = "subscripts" - elif len(key.shape) == 1 or key.shape[1] == 1: + if len(key.shape) == 1 or key.shape[1] == 1: access_type = "linear indices" # Case 1: Rectangular Subtensor @@ -1351,6 +1346,12 @@ def _set_subtensor(self, key, value): else: sliceCheck.append(element.stop) elif isinstance(element, Iterable): + if any( + not isinstance(entry, (float, int, np.generic)) for entry in element + ): + raise ValueError( + f"Entries for setitem must be numeric but recieved, {element}" + ) sliceCheck.append(max(element)) else: sliceCheck.append(element) diff --git a/tests/test_tensor.py b/tests/test_tensor.py index d5167cbc..d27463c1 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -251,6 +251,13 @@ def test_tensor__setitem__(sample_tensor_2way): # Subtensor add dimension empty_tensor[0, 0, 0] = 2 + # Subtensor with lists + some_tensor = ttb.tenones((3, 3)) + some_tensor[[0, 1], [0, 1]] = 11 + assert some_tensor[0, 0] == 11 + assert some_tensor[1, 1] == 11 + assert np.all(some_tensor[[0, 1], [0, 1]].data == 11) + # Subscripts with constant tensorInstance[np.array([[1, 1]])] = 13.0 dataGrowth[1, 1] = 13.0 @@ -293,6 +300,13 @@ def test_tensor__setitem__(sample_tensor_2way): dataGrowth[np.unravel_index([0, 3, 4], dataGrowth.shape, "F")] = 13 assert (tensorInstance.data == dataGrowth).all() + # Linear index with multiple indicies + some_tensor = ttb.tenones((3, 3)) + some_tensor[[0, 1]] = 2 + assert some_tensor[0] == 2 + assert some_tensor[1] == 2 + assert np.array_equal(some_tensor[[0, 1]], [2, 2]) + # Test Empty Tensor Set Item, subtensor emptyTensor = ttb.tensor.from_data(np.array([])) emptyTensor[0, 0, 0] = 0 @@ -313,10 +327,17 @@ def test_tensor__setitem__(sample_tensor_2way): ) # Attempting to set some other way - # TODO either catch this error ourselves or specify more specific exception we expect here - with pytest.raises(Exception) as excinfo: + with pytest.raises(ValueError) as excinfo: tensorInstance[0, "a", 5] = 13.0 - # assert "Invalid use of tensor setitem" in str(excinfo) + assert "must be numeric" in str(excinfo) + + with pytest.raises(AssertionError) as excinfo: + + class BadKey: + pass + + tensorInstance[BadKey] = 13.0 + assert "Invalid use of tensor setitem" in str(excinfo) @pytest.mark.indevelopment @@ -346,6 +367,11 @@ def test_tensor__getitem__(sample_tensor_2way): tensorInstance[np.array([[0, 0], [1, 1]]), "extract"] == params["data"][([0, 0], [1, 1])] ).all() + # Case 2a: Extract doesn't seem to be needed + assert tensorInstance[np.array([0, 0])] == params["data"][0, 0] + assert ( + tensorInstance[np.array([[0, 0], [1, 1]])] == params["data"][([0, 0], [1, 1])] + ).all() # Case 2b: Linear Indexing assert tensorInstance[np.array([0])] == params["data"][0, 0]
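
Example (not part of the patch series): a minimal sketch of the indexing forms
these commits enable, mirroring the new tests above; `pyttb` is assumed to be
imported as `ttb` alongside `numpy`.

    import numpy as np
    import pyttb as ttb

    X = ttb.tenones((3, 3))
    X[0] = 2.0                     # single linear index (previously required np.array([0]))
    X[0:1] = 3.0                   # slice of linear indices
    X[[0, 1]] = 4.0                # iterable of linear indices
    X[[0, 1], [0, 1]] = 11.0       # subtensor assignment with per-dimension index lists
    X[0]                           # read a single element via a linear index
    X[np.array([[0, 0], [1, 1]])]  # subscript extraction: row i holds the dim-i subscripts

    S = ttb.sptensor()
    S[[0, 1], 0] = 1.0             # iterable subtensor index on an empty sptensor
    S[np.array([[0, 1], [0, 0]])]  # sptensor subscripts are now ndims x nnz (one column per entry)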