Skip to content

Commit 931297f

Browse files
committed
Add explicit check for failing ndarray.dot(TensorVariable)
1 parent 58fb850 commit 931297f

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

tests/tensor/test_variable.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,20 @@ def test_numpy_method(fct, value):
7575
utt.assert_allclose(np.nan_to_num(f(value)), np.nan_to_num(fct(value)))
7676

7777

78-
def test_infix_dot_method():
78+
def test_dot_method():
7979
X = dmatrix("X")
8080
y = dvector("y")
8181

8282
res = X.dot(y)
8383
exp_res = dot(X, y)
8484
assert equal_computations([res], [exp_res])
8585

86+
# This doesn't work. Numpy calls TensorVariable.__rmul__ at some point and everything is messed up
8687
X_val = np.arange(2 * 3).reshape((2, 3))
87-
res = as_tensor(X_val).dot(y)
88+
res = X_val.dot(y)
8889
exp_res = dot(X_val, y)
89-
assert equal_computations([res], [exp_res])
90+
with pytest.raises(AssertionError):
91+
assert equal_computations([res], [exp_res])
9092

9193

9294
def test_infix_matmul_method():

0 commit comments

Comments
 (0)