@@ -1667,18 +1667,15 @@ def test_broadcastable_flag_assignment_mixed_otheraxes(self):
1667
1667
a = self .shared (a_val , shape = (None , None , 1 ))
1668
1668
b = self .shared (b_val , shape = (1 , None , 1 ))
1669
1669
c = self .join_op (1 , a , b )
1670
- assert c .type .shape [0 ] == 1 and c .type .shape [2 ] == 1
1671
- assert c .type .shape [1 ] != 1
1670
+ assert c .type .shape == (1 , None , 1 )
1672
1671
1673
1672
# Opt can remplace the int by an PyTensor constant
1674
1673
c = self .join_op (constant (1 ), a , b )
1675
- assert c .type .shape [0 ] == 1 and c .type .shape [2 ] == 1
1676
- assert c .type .shape [1 ] != 1
1674
+ assert c .type .shape == (1 , None , 1 )
1677
1675
1678
1676
# In case futur opt insert other useless stuff
1679
1677
c = self .join_op (cast (constant (1 ), dtype = "int32" ), a , b )
1680
- assert c .type .shape [0 ] == 1 and c .type .shape [2 ] == 1
1681
- assert c .type .shape [1 ] != 1
1678
+ assert c .type .shape == (1 , None , 1 )
1682
1679
1683
1680
f = function ([], c , mode = self .mode )
1684
1681
topo = f .maker .fgraph .toposort ()
@@ -1783,15 +1780,21 @@ def test_broadcastable_flags_many_dims_and_inputs(self):
1783
1780
c = TensorType (dtype = self .floatX , shape = (1 , None , None , None , None , None ))()
1784
1781
d = TensorType (dtype = self .floatX , shape = (1 , None , 1 , 1 , None , 1 ))()
1785
1782
e = TensorType (dtype = self .floatX , shape = (1 , None , 1 , None , None , 1 ))()
1783
+
1786
1784
f = self .join_op (0 , a , b , c , d , e )
1787
1785
fb = tuple (s == 1 for s in f .type .shape )
1788
- assert not fb [0 ] and fb [1 ] and fb [2 ] and fb [3 ] and not fb [4 ] and fb [5 ]
1786
+ assert f .type .shape == (5 , 1 , 1 , 1 , None , 1 )
1787
+ assert fb == (False , True , True , True , False , True )
1788
+
1789
1789
g = self .join_op (1 , a , b , c , d , e )
1790
1790
gb = tuple (s == 1 for s in g .type .shape )
1791
- assert gb [0 ] and not gb [1 ] and gb [2 ] and gb [3 ] and not gb [4 ] and gb [5 ]
1791
+ assert g .type .shape == (1 , None , 1 , 1 , None , 1 )
1792
+ assert gb == (True , False , True , True , False , True )
1793
+
1792
1794
h = self .join_op (4 , a , b , c , d , e )
1793
1795
hb = tuple (s == 1 for s in h .type .shape )
1794
- assert hb [0 ] and hb [1 ] and hb [2 ] and hb [3 ] and not hb [4 ] and hb [5 ]
1796
+ assert h .type .shape == (1 , 1 , 1 , 1 , None , 1 )
1797
+ assert hb == (True , True , True , True , False , True )
1795
1798
1796
1799
f = function ([a , b , c , d , e ], f , mode = self .mode )
1797
1800
topo = f .maker .fgraph .toposort ()
@@ -1903,7 +1906,7 @@ def test_mixed_ndim_error(self):
1903
1906
rng = np .random .default_rng (seed = utt .fetch_seed ())
1904
1907
v = self .shared (rng .random (4 ).astype (self .floatX ))
1905
1908
m = self .shared (rng .random ((4 , 4 )).astype (self .floatX ))
1906
- with pytest .raises (TypeError ):
1909
+ with pytest .raises (TypeError , match = "same number of dimensions" ):
1907
1910
self .join_op (0 , v , m )
1908
1911
1909
1912
def test_split_0elem (self ):
0 commit comments