Skip to content

Commit 16e57a7

Browse files
authored
Adding tensor.ttt implementation. (#44)
Closes 28
1 parent 2ab1934 commit 16e57a7

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

pyttb/tensor.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,48 @@ def ttm(self, matrix, dims=None, transpose=False):
942942
Y = np.transpose(Y, np.argsort(order))
943943
return ttb.tensor.from_data(Y)
944944

945+
def ttt(self, other, selfdims=None, otherdims=None):
946+
"""
947+
Tensor mulitplication (tensor times tensor)
948+
949+
Parameters
950+
----------
951+
other: :class:`ttb.tensor`
952+
selfdims: :class:`Numpy.ndarray`, int
953+
otherdims: :class:`Numpy.ndarray`, int
954+
"""
955+
956+
if not isinstance(other, tensor):
957+
assert False, "other must be of type tensor"
958+
959+
if selfdims is None:
960+
selfdims = np.array([])
961+
selfshape = ()
962+
else:
963+
selfshape = tuple(np.array(self.shape)[selfdims])
964+
965+
if otherdims is None:
966+
otherdims = selfdims.copy()
967+
othershape = ()
968+
else:
969+
othershape = tuple(np.array(other.shape)[otherdims])
970+
971+
if not selfshape == othershape:
972+
assert False, "Specified dimensions do not match"
973+
974+
# Compute the product
975+
976+
# Avoid transpose by reshaping self and computing result = self * other
977+
amatrix = ttb.tenmat.from_tensor_type(self, cdims=selfdims)
978+
bmatrix = ttb.tenmat.from_tensor_type(other, rdims=otherdims)
979+
cmatrix = amatrix * bmatrix
980+
981+
# Check whether or not the result is a scalar
982+
if isinstance(cmatrix, ttb.tenmat):
983+
return ttb.tensor.from_tensor_type(cmatrix)
984+
else:
985+
return cmatrix
986+
945987
def ttv(self, vector, dims=None):
946988
"""
947989
Tensor times vector

tests/test_tensor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,6 +1057,24 @@ def test_tensor_ttm(sample_tensor_2way, sample_tensor_3way, sample_tensor_4way):
10571057
tensorInstance3.ttm(M2, tensorInstance3.ndims + 1)
10581058
assert "dims must contain values in [0,self.dims]" in str(excinfo)
10591059

1060+
@pytest.mark.indevelopment
1061+
def test_tensor_ttt(sample_tensor_2way, sample_tensor_3way, sample_tensor_4way):
1062+
1063+
M31 = ttb.tensor.from_data(np.reshape(np.arange(1,2*3*4+1),[4,3,2], order='F'))
1064+
M32 = ttb.tensor.from_data(np.reshape(np.arange(1,2*3*4+1),[3,4,2], order='F'))
1065+
1066+
# outer product of M31 and M32
1067+
TTT1 = M31.ttt(M32)
1068+
assert TTT1.shape == (4,3,2,3,4,2)
1069+
# choose two random 2-way slices
1070+
data11 = np.array([1,2,3,4])
1071+
data12 = np.array([289,306,323,340])
1072+
data13 = np.array([504,528,552,576])
1073+
assert (TTT1[:,0,0,0,0,0].data == data11).all()
1074+
assert (TTT1[:,1,1,1,1,1].data == data12).all()
1075+
assert (TTT1[:,2,1,2,3,1].data == data13).all()
1076+
1077+
10601078
@pytest.mark.indevelopment
10611079
def test_tensor_ttv(sample_tensor_2way, sample_tensor_3way, sample_tensor_4way):
10621080
(params2, tensorInstance2) = sample_tensor_2way

0 commit comments

Comments
 (0)