 class TestLinear(unittest.TestCase):
     def test_fp16_linear(self):
         for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-                dtype=torch.float16,
-                atol=5e-2,
-            )
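+            # Cover inputs with one and two leading batch dimensions.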
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    num_batch_dims=num_batch_dims,
+                    uses_bias=use_bias,
+                    dtype=torch.float16,
+                    atol=5e-2,
+                )

     def test_fp32_linear(self):
         for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )

     def test_fp32_addmm(self):
         """
@@ -62,24 +66,71 @@ def forward(self, x):
             uses_bias=True,
         )

+    def test_fp32_linear_fused_relu(self):
+        class LinearReluModule(torch.nn.Module):
+            def __init__(self, in_size, out_size, use_bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.linear(x))
+
+        for use_bias in (True, False):
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: LinearReluModule(
+                        in_size,
+                        out_size,
+                        use_bias,  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )
+
+    def test_qs8_linear_fused_relu(self):
+        class LinearReluModule(torch.nn.Module):
+            def __init__(self, in_size, out_size, use_bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.linear(x))
+
+        for use_bias in (True, False):
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: LinearReluModule(
+                        in_size,
+                        out_size,
+                        use_bias,  # noqa
+                    ),
+                    num_batch_dims=num_batch_dims,
+                    uses_bias=use_bias,
+                    quant=True,
+                )
+
     def test_qs8_linear(self):
         for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )

     @unittest.skip("XNNPACK currently only supports per-channel dynamic quantization.")
     def test_qd8_per_tensor_linear(self):
         for uses_bias in (False, True):
             inputs = (torch.randn(2, 4),)
             module = torch.nn.Linear(4, 5, bias=uses_bias)
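+            # Let the batch dimension vary at runtime (up to 100).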
+            dynamic_shapes = ({0: torch.export.Dim("batch", max=100)},)

             self._test_dqlinear(
                 module,
                 inputs,
+                dynamic_shapes=dynamic_shapes,
                 is_per_channel=False,
                 uses_bias=uses_bias,
             )
@@ -92,6 +143,7 @@ def test_qd8_per_channel_linear(self):
             self._test_dqlinear(
                 module,
                 inputs,
+                dynamic_shapes=({0: torch.export.Dim("batch", max=100)},),
                 is_per_channel=True,
                 uses_bias=uses_bias,
             )
@@ -113,7 +165,7 @@ def test_qd8_per_channel_4w_linear(self):
         qconfig = self._get_4b_dqconfig()
         input_channels = [2, 63]
         output_channels = [1, 8, 127]
-        batches = [1, 2]
+        batches = [2, 2]
         use_bias = [False, True]

         for bs, bias, ipc, opc in product(
@@ -128,13 +180,14 @@ def test_qd8_per_channel_4w_linear(self):
             self._test_dqlinear(
                 module,
                 inputs,
+                dynamic_shapes=({0: torch.export.Dim("batch", max=100)},),
                 is_per_channel=True,
                 uses_bias=bias,
                 qconfig=qconfig,
             )

     def test_qd8_per_channel_linear_parallel(self):
-        in_size = 1
+        in_size = 2
         input_size = 4
         output_size = 5

@@ -164,17 +217,39 @@ def forward(self, x, y):
             torch.rand(in_size, input_size, dtype=torch.float),
             torch.rand(in_size, input_size, dtype=torch.float),
         )
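+        # Both inputs share a single dynamic batch dimension.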
+        batch_dim = torch.export.Dim("batch", max=100)
+        dynamic_shapes = ({0: batch_dim}, {0: batch_dim})

         self._test_dqlinear(
             ParallelLinear(),
             inputs,
+            dynamic_shapes=dynamic_shapes,
             linear_count=2,
             is_per_channel=True,
             uses_bias=True,
         )

+    def test_qd8_per_channel_linear_with_two_batch(self):
+        in_size = 2
+        input_size = 4
+        output_size = 5
+
+        linear = torch.nn.Linear(input_size, output_size)
+        inputs = (torch.randn(2, in_size, input_size, dtype=torch.float),)
+        batch_dim = torch.export.Dim("batch", max=100)
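+        # Both leading dimensions map to the same Dim, so they are constrained to be equal.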
+        dynamic_shapes = ({0: batch_dim, 1: batch_dim},)
+
+        self._test_dqlinear(
+            linear,
+            inputs,
+            dynamic_shapes=dynamic_shapes,
+            linear_count=1,
+            is_per_channel=True,
+            uses_bias=True,
+        )
+
     def test_qd8_per_channel_linear_sequential(self):
-        in_size = 1
+        in_size = 2
         input_size = 4
         intermediate_size = 5
         output_size = 3
@@ -202,17 +277,19 @@ def forward(self, x):
                 return b

         inputs = (torch.rand(in_size, input_size, dtype=torch.float),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", max=100)},)

         self._test_dqlinear(
             LinearSequential(),
             inputs,
+            dynamic_shapes=dynamic_shapes,
             linear_count=2,
             is_per_channel=True,
             uses_bias=True,
         )

     def test_qd8_per_channel_linear_parellel_and_sequential(self):
-        in_size = 1
+        in_size = 2
         input_size = 4
         intermediate_size = 5
         output_size = 3
@@ -251,54 +328,26 @@ def forward(self, x, y):
             torch.rand(in_size, input_size, dtype=torch.float),
             torch.rand(in_size, input_size, dtype=torch.float),
         )
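+        # The two inputs get independent dynamic batch dims ("batch" and "batch2").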
+        dynamic_shapes = (
+            {0: torch.export.Dim("batch", max=100)},
+            {0: torch.export.Dim("batch2", max=100)},
+        )

         self._test_dqlinear(
-            LinearModule(), inputs, linear_count=3, is_per_channel=True, uses_bias=True
+            LinearModule(),
+            inputs,
+            dynamic_shapes=dynamic_shapes,
+            linear_count=3,
+            is_per_channel=True,
+            uses_bias=True,
+            atol=1e-1,
         )

-    def test_fp32_linear_fused_relu(self):
-        class LinearReluModule(torch.nn.Module):
-            def __init__(self, in_size, out_size, use_bias):
-                super().__init__()
-                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
-
-            def forward(self, x):
-                return torch.nn.functional.relu(self.linear(x))
-
-        for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: LinearReluModule(
-                    in_size,
-                    out_size,
-                    use_bias,  # noqa
-                ),
-                uses_bias=use_bias,
-            )
-
-    def test_qs8_linear_fused_relu(self):
-        class LinearReluModule(torch.nn.Module):
-            def __init__(self, in_size, out_size, use_bias):
-                super().__init__()
-                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
-
-            def forward(self, x):
-                return torch.nn.functional.relu(self.linear(x))
-
-        for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: LinearReluModule(
-                    in_size,
-                    out_size,
-                    use_bias,  # noqa
-                ),
-                uses_bias=use_bias,
-                quant=True,
-            )
-
     def _test_linear(
         self,
         make_module,
         uses_bias,
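+        # Number of leading batch dimensions on the generated test inputs.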
+        num_batch_dims=1,
         quant=False,
         dtype: torch.dtype = torch.float,
         atol=1e-03,
@@ -315,7 +364,7 @@ def _test_linear(
             )
         )

-        in_sizes = [1, 4, 4]
+        in_sizes = [3, 4, 4]
         input_sizes = [4, 37, 17]
         output_sizes = [4, 17, 37]

@@ -327,12 +376,18 @@ def _test_linear(
             in_size = int(in_sizes[i])
             input_size = int(input_sizes[i])
             output_size = int(output_sizes[i])
-            print(f"Testing {in_size} {input_size} {output_size}")
+            input_shape = [in_size] * num_batch_dims + [input_size]
+            print(f"Testing input_shape {input_shape} with {output_size} out_channels")

             module = make_module(input_size, output_size).eval().to(dtype)
-            inputs = (torch.randn(in_size, input_size).to(dtype),)
+            inputs = (torch.randn(input_shape).to(dtype),)
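+            # Mark every leading batch dimension as dynamic (bounded by in_size).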
+            dynamic_shape = {}
+            for i in range(num_batch_dims):
+                dynamic_shape[i] = torch.export.Dim(f"batch{i}", max=in_size)
+
+            dynamic_shape = (dynamic_shape,)

-            tester = Tester(module, inputs)
+            tester = Tester(module, inputs, dynamic_shapes=dynamic_shape)

             if quant:
                 tester.quantize()
@@ -360,10 +415,12 @@ def _test_dqlinear(
         self,
         module,
         inputs,
+        dynamic_shapes,
         linear_count=1,
         is_per_channel=False,
         uses_bias=False,
         qconfig: Optional[QuantizationConfig] = None,
+        atol=5e-02,
     ):
         aten_op, edge_op = (
             (
@@ -382,13 +439,12 @@ def _test_dqlinear(
             is_dynamic=True,
         )

-        tester = Tester(module, inputs)
+        tester = Tester(module, inputs, dynamic_shapes=dynamic_shapes)
         tester.quantize(Quantize(quantization_config=quant_config))

         tester.export()
         tester.check_count({aten_op: linear_count})
         tester.check(["torch.ops.quantized_decomposed"])
-        tester.dump_artifact()
         tester.to_edge()
         tester.check_count({edge_op: linear_count})

@@ -400,4 +456,4 @@ def _test_dqlinear(

         tester.to_executorch()
         tester.serialize()
-        tester.run_method_and_compare_outputs(atol=5e-02)
+        tester.run_method_and_compare_outputs(atol=atol)