Skip to content

Adding implementation. #44

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions pyttb/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,48 @@ def ttm(self, matrix, dims=None, transpose=False):
Y = np.transpose(Y, np.argsort(order))
return ttb.tensor.from_data(Y)

def ttt(self, other, selfdims=None, otherdims=None):
"""
Tensor mulitplication (tensor times tensor)

Parameters
----------
other: :class:`ttb.tensor`
selfdims: :class:`Numpy.ndarray`, int
otherdims: :class:`Numpy.ndarray`, int
"""

if not isinstance(other, tensor):
assert False, "other must be of type tensor"

if selfdims is None:
selfdims = np.array([])
selfshape = ()
else:
selfshape = tuple(np.array(self.shape)[selfdims])

if otherdims is None:
otherdims = selfdims.copy()
othershape = ()
else:
othershape = tuple(np.array(other.shape)[otherdims])

if not selfshape == othershape:
assert False, "Specified dimensions do not match"

# Compute the product

# Avoid transpose by reshaping self and computing result = self * other
amatrix = ttb.tenmat.from_tensor_type(self, cdims=selfdims)
bmatrix = ttb.tenmat.from_tensor_type(other, rdims=otherdims)
cmatrix = amatrix * bmatrix

# Check whether or not the result is a scalar
if isinstance(cmatrix, ttb.tenmat):
return ttb.tensor.from_tensor_type(cmatrix)
else:
return cmatrix

def ttv(self, vector, dims=None):
"""
Tensor times vector
Expand Down
18 changes: 18 additions & 0 deletions tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,6 +1057,24 @@ def test_tensor_ttm(sample_tensor_2way, sample_tensor_3way, sample_tensor_4way):
tensorInstance3.ttm(M2, tensorInstance3.ndims + 1)
assert "dims must contain values in [0,self.dims]" in str(excinfo)

@pytest.mark.indevelopment
def test_tensor_ttt(sample_tensor_2way, sample_tensor_3way, sample_tensor_4way):

M31 = ttb.tensor.from_data(np.reshape(np.arange(1,2*3*4+1),[4,3,2], order='F'))
M32 = ttb.tensor.from_data(np.reshape(np.arange(1,2*3*4+1),[3,4,2], order='F'))

# outer product of M31 and M32
TTT1 = M31.ttt(M32)
assert TTT1.shape == (4,3,2,3,4,2)
# choose two random 2-way slices
data11 = np.array([1,2,3,4])
data12 = np.array([289,306,323,340])
data13 = np.array([504,528,552,576])
assert (TTT1[:,0,0,0,0,0].data == data11).all()
assert (TTT1[:,1,1,1,1,1].data == data12).all()
assert (TTT1[:,2,1,2,3,1].data == data13).all()


@pytest.mark.indevelopment
def test_tensor_ttv(sample_tensor_2way, sample_tensor_3way, sample_tensor_4way):
(params2, tensorInstance2) = sample_tensor_2way
Expand Down