Skip to content

Commit 9e4c0e4

Browse files
michaelosthegetwiecki
authored andcommitted
Simplify asserts
1 parent 4116a35 commit 9e4c0e4

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

tests/tensor/test_basic.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1667,18 +1667,15 @@ def test_broadcastable_flag_assignment_mixed_otheraxes(self):
16671667
a = self.shared(a_val, shape=(None, None, 1))
16681668
b = self.shared(b_val, shape=(1, None, 1))
16691669
c = self.join_op(1, a, b)
1670-
assert c.type.shape[0] == 1 and c.type.shape[2] == 1
1671-
assert c.type.shape[1] != 1
1670+
assert c.type.shape == (1, None, 1)
16721671

16731672
# Opt can remplace the int by an PyTensor constant
16741673
c = self.join_op(constant(1), a, b)
1675-
assert c.type.shape[0] == 1 and c.type.shape[2] == 1
1676-
assert c.type.shape[1] != 1
1674+
assert c.type.shape == (1, None, 1)
16771675

16781676
# In case futur opt insert other useless stuff
16791677
c = self.join_op(cast(constant(1), dtype="int32"), a, b)
1680-
assert c.type.shape[0] == 1 and c.type.shape[2] == 1
1681-
assert c.type.shape[1] != 1
1678+
assert c.type.shape == (1, None, 1)
16821679

16831680
f = function([], c, mode=self.mode)
16841681
topo = f.maker.fgraph.toposort()
@@ -1783,15 +1780,21 @@ def test_broadcastable_flags_many_dims_and_inputs(self):
17831780
c = TensorType(dtype=self.floatX, shape=(1, None, None, None, None, None))()
17841781
d = TensorType(dtype=self.floatX, shape=(1, None, 1, 1, None, 1))()
17851782
e = TensorType(dtype=self.floatX, shape=(1, None, 1, None, None, 1))()
1783+
17861784
f = self.join_op(0, a, b, c, d, e)
17871785
fb = tuple(s == 1 for s in f.type.shape)
1788-
assert not fb[0] and fb[1] and fb[2] and fb[3] and not fb[4] and fb[5]
1786+
assert f.type.shape == (5, 1, 1, 1, None, 1)
1787+
assert fb == (False, True, True, True, False, True)
1788+
17891789
g = self.join_op(1, a, b, c, d, e)
17901790
gb = tuple(s == 1 for s in g.type.shape)
1791-
assert gb[0] and not gb[1] and gb[2] and gb[3] and not gb[4] and gb[5]
1791+
assert g.type.shape == (1, None, 1, 1, None, 1)
1792+
assert gb == (True, False, True, True, False, True)
1793+
17921794
h = self.join_op(4, a, b, c, d, e)
17931795
hb = tuple(s == 1 for s in h.type.shape)
1794-
assert hb[0] and hb[1] and hb[2] and hb[3] and not hb[4] and hb[5]
1796+
assert h.type.shape == (1, 1, 1, 1, None, 1)
1797+
assert hb == (True, True, True, True, False, True)
17951798

17961799
f = function([a, b, c, d, e], f, mode=self.mode)
17971800
topo = f.maker.fgraph.toposort()
@@ -1903,7 +1906,7 @@ def test_mixed_ndim_error(self):
19031906
rng = np.random.default_rng(seed=utt.fetch_seed())
19041907
v = self.shared(rng.random(4).astype(self.floatX))
19051908
m = self.shared(rng.random((4, 4)).astype(self.floatX))
1906-
with pytest.raises(TypeError):
1909+
with pytest.raises(TypeError, match="same number of dimensions"):
19071910
self.join_op(0, v, m)
19081911

19091912
def test_split_0elem(self):

0 commit comments

Comments
 (0)