diff --git a/pyttb/tenmat.py b/pyttb/tenmat.py index 7f178a0f..9e48f849 100644 --- a/pyttb/tenmat.py +++ b/pyttb/tenmat.py @@ -113,12 +113,19 @@ def from_tensor_type(cls, source, rdims=None, cdims=None, cdims_cyclic=None): elif rdims is None and cdims is not None: rdims = np.setdiff1d(alldims, cdims) - - dims = np.hstack([rdims, cdims]) + # if rdims or cdims is empty, hstack will output an array of float not int + if rdims.size == 0: + dims = cdims.copy() + elif cdims.size == 0: + dims = rdims.copy() + else: + dims = np.hstack([rdims, cdims]) if not len(dims) == n or not (alldims == np.sort(dims)).all(): assert False, 'Incorrect specification of dimensions, the sorted concatenation of rdims and cdims must be range(source.ndims).' - data = np.reshape(source.permute(dims).data, (np.prod(np.array(tshape)[rdims]), np.prod(np.array(tshape)[cdims])), order='F') + rprod = 1 if rdims.size == 0 else np.prod(np.array(tshape)[rdims]) + cprod = 1 if cdims.size == 0 else np.prod(np.array(tshape)[cdims]) + data = np.reshape(source.permute(dims).data, (rprod, cprod), order='F') # Create tenmat tenmatInstance = cls() diff --git a/tests/test_tenmat.py b/tests/test_tenmat.py index 29e87dfe..1c4a4d88 100644 --- a/tests/test_tenmat.py +++ b/tests/test_tenmat.py @@ -22,6 +22,14 @@ def sample_ndarray_2way(): params = {'data':ndarrayInstance, 'shape':shape} return params, ndarrayInstance +@pytest.fixture() +def sample_tensor_3way(): + data = np.array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.]) + shape = (2, 3, 2) + params = {'data':np.reshape(data, np.array(shape), order='F'), 'shape': shape} + tensorInstance = ttb.tensor().from_data(data, shape) + return params, tensorInstance + @pytest.fixture() def sample_ndarray_4way(): shape = (2, 2, 2, 2) @@ -184,8 +192,9 @@ def test_tenmat_initialization_from_data(sample_ndarray_1way, sample_ndarray_2wa assert exc in str(excinfo) @pytest.mark.indevelopment -def test_tenmat_initialization_from_tensor_type(sample_tenmat_4way, sample_tensor_4way): +def test_tenmat_initialization_from_tensor_type(sample_tenmat_4way, sample_tensor_3way, sample_tensor_4way): (_, tensorInstance) = sample_tensor_4way + (_, tensorInstance3) = sample_tensor_3way (params, tenmatInstance) = sample_tenmat_4way tshape = params['tshape'] rdims = params['rdims'] @@ -208,6 +217,11 @@ def test_tenmat_initialization_from_tensor_type(sample_tenmat_4way, sample_tenso assert tenmatInstance.shape == tenmatTensorRdims.shape assert tenmatInstance.tshape == tenmatTensorRdims.tshape + # Constructor from tensor using empty rdims + tenmatTensorRdims = ttb.tenmat.from_tensor_type(tensorInstance3, rdims=np.array([])) + data = np.reshape(np.arange(1,13),(1,12)) + assert (tenmatTensorRdims.data == data).all() + # Constructor from tensor using cdims only tenmatTensorCdims = ttb.tenmat.from_tensor_type(tensorInstance, cdims=cdims) assert (tenmatInstance.data == tenmatTensorCdims.data).all() @@ -216,6 +230,11 @@ def test_tenmat_initialization_from_tensor_type(sample_tenmat_4way, sample_tenso assert tenmatInstance.shape == tenmatTensorCdims.shape assert tenmatInstance.tshape == tenmatTensorCdims.tshape + # Constructor from tensor using empty cdims + tenmatTensorCdims = ttb.tenmat.from_tensor_type(tensorInstance3, cdims=np.array([])) + data = np.reshape(np.arange(1,13),(12,1)) + assert (tenmatTensorCdims.data == data).all() + # Constructor from tensor using rdims and cdims tenmatTensorRdimsCdims = ttb.tenmat.from_tensor_type(tensorInstance, rdims=rdims, cdims=cdims) assert (tenmatInstance.data == tenmatTensorRdimsCdims.data).all()