13
13
def format_id (format ):
14
14
return f"{ format = } "
15
15
16
+
16
17
@pytest .mark .parametrize ("format" , ["coo" , "gcxs" ])
17
18
def test_matmul (benchmark , sides , seed , format , backend , min_size , max_size , ids = format_id ):
18
-
19
19
m , n , p = sides
20
-
20
+
21
21
if m * n >= max_size or n * p >= max_size or m * n <= min_size or n * p <= min_size :
22
22
pytest .skip ()
23
-
23
+
24
24
rng = np .random .default_rng (seed = seed )
25
25
x = sparse .random ((m , n ), density = DENSITY , format = format , random_state = rng )
26
26
y = sparse .random ((n , p ), density = DENSITY , format = format , random_state = rng )
27
-
28
- if hasattr (sparse , "compiled" ): operator .matmul = sparse .compiled (operator .matmul )
27
+
28
+ if hasattr (sparse , "compiled" ):
29
+ operator .matmul = sparse .compiled (operator .matmul )
29
30
30
31
x @ y # Numba compilation
31
32
@@ -55,8 +56,9 @@ def elemwise_args(request, seed, max_size):
55
56
def test_elemwise (benchmark , f , elemwise_args , backend ):
56
57
x , y = elemwise_args
57
58
58
- if hasattr (sparse , "compiled" ): f = sparse .compiled (f )
59
-
59
+ if hasattr (sparse , "compiled" ):
60
+ f = sparse .compiled (f )
61
+
60
62
f (x , y )
61
63
62
64
@benchmark
@@ -84,7 +86,8 @@ def elemwise_broadcast_args(request, seed, max_size):
84
86
def test_elemwise_broadcast (benchmark , f , elemwise_broadcast_args ):
85
87
x , y = elemwise_broadcast_args
86
88
87
- if hasattr (sparse , "compiled" ): f = sparse .compiled (f )
89
+ if hasattr (sparse , "compiled" ):
90
+ f = sparse .compiled (f )
88
91
89
92
f (x , y )
90
93
@@ -109,7 +112,8 @@ def test_index_scalar(benchmark, indexing_args):
109
112
side = x .shape [0 ]
110
113
rank = x .ndim
111
114
112
- if hasattr (sparse , "compiled" ): operator .getitem = sparse .compiled (operator .getitem )
115
+ if hasattr (sparse , "compiled" ):
116
+ operator .getitem = sparse .compiled (operator .getitem )
113
117
114
118
x [(side // 2 ,) * rank ] # Numba compilation
115
119
@@ -123,7 +127,8 @@ def test_index_slice(benchmark, indexing_args):
123
127
side = x .shape [0 ]
124
128
rank = x .ndim
125
129
126
- if hasattr (sparse , "compiled" ): operator .getitem = sparse .compiled (operator .getitem )
130
+ if hasattr (sparse , "compiled" ):
131
+ operator .getitem = sparse .compiled (operator .getitem )
127
132
128
133
x [(slice (side // 2 ),) * rank ] # Numba compilation
129
134
@@ -138,7 +143,8 @@ def test_index_fancy(benchmark, indexing_args, seed):
138
143
rng = np .random .default_rng (seed = seed )
139
144
index = rng .integers (0 , side , size = (side // 2 ,))
140
145
141
- if hasattr (sparse , "compiled" ): operator .getitem = sparse .compiled (operator .getitem )
146
+ if hasattr (sparse , "compiled" ):
147
+ operator .getitem = sparse .compiled (operator .getitem )
142
148
143
149
x [index ] # Numba compilation
144
150
@@ -179,7 +185,8 @@ def densemul_args(request, sides, seed, max_size):
179
185
def test_gcxs_dot_ndarray (benchmark , densemul_args ):
180
186
x , t = densemul_args
181
187
182
- if hasattr (sparse , "compiled" ): operator .matmul = sparse .compiled (operator .matmul )
188
+ if hasattr (sparse , "compiled" ):
189
+ operator .matmul = sparse .compiled (operator .matmul )
183
190
184
191
# Numba compilation
185
192
x @ t
0 commit comments