@@ -423,10 +423,16 @@ def test_qd8_fp32_per_token_weight_per_channel_group_int4(self):
                     use_bias=use_bias,
                 )
 
+                # rank 3
                 inputs = (torch.randn(1, M, K),)
                 self._test_groupwise_dq_linear(
                     lin_mod, inputs, group_size=bl, use_bias=use_bias
                 )
+                # rank 2
+                inputs = (torch.randn(1, K),)
+                self._test_groupwise_dq_linear(
+                    lin_mod, inputs, group_size=bl, use_bias=use_bias
+                )
 
     @unittest.skipIf(
         not torchao_installed, "Per Channel Group Quantization Required TorchAO"
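For context on why the added rank-2 case is valid: a linear layer only constrains the last input dimension, and any leading dimensions are treated as batch dimensions, so (1, K) and (1, M, K) activations exercise the same weights. A minimal plain-PyTorch sketch of that behavior (independent of BaseLinear, illustrative shapes only):

import torch

K, N = 8, 4
lin = torch.nn.Linear(K, N)

# Only the last dimension must equal K; leading dims are batched through.
out_rank2 = lin(torch.randn(1, K))     # -> shape (1, N)
out_rank3 = lin(torch.randn(1, 3, K))  # -> shape (1, 3, N)
print(out_rank2.shape, out_rank3.shape)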
@@ -437,28 +443,29 @@ def test_qd8_fp16_per_token_weight_per_channel_group_int4(self):
         bl_sizes = [32, 32, 32, 64]
         N_sizes = [2, 17, 92, 128]
 
-        for use_bias in [True, False]:
-            for M, K, bl, N in zip(M_sizes, K_sizes, bl_sizes, N_sizes):
-                lin_mod = BaseLinear(
-                    in_size=M,
-                    input_channels=K,
-                    output_channels=N,
-                    dtype=torch.float16,
-                    use_bias=use_bias,
-                )
+        for input_rank in range(2, 4):
+            for use_bias in [True, False]:
+                for M, K, bl, N in zip(M_sizes, K_sizes, bl_sizes, N_sizes):
+                    lin_mod = BaseLinear(
+                        in_size=M,
+                        input_channels=K,
+                        output_channels=N,
+                        dtype=torch.float16,
+                        use_bias=use_bias,
+                    )
 
-                inputs = lin_mod.get_inputs()
-                # This requires slightly higher atol, but if you look at error it is not that bad:
-                # Difference: max: 0.00140380859375, abs: 0.00140380859375, mean abs error: 0.00042724609375.
-                # -- Model vs. Reference --
-                # Numel: 4, 4
-                # Median: -0.05023193359375, -0.0516357421875
-                # Mean: 0.2373046875, 0.237060546875
-                # Max: 1.0078125, 1.0078125
-                # Min: -0.08465576171875, -0.08441162109375
-                self._test_groupwise_dq_linear(
-                    lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=1e-2
-                )
+                    inputs = lin_mod.get_inputs(rank=input_rank)
+                    # This requires slightly higher atol, but if you look at error it is not that bad:
+                    # Difference: max: 0.00140380859375, abs: 0.00140380859375, mean abs error: 0.00042724609375.
+                    # -- Model vs. Reference --
+                    # Numel: 4, 4
+                    # Median: -0.05023193359375, -0.0516357421875
+                    # Mean: 0.2373046875, 0.237060546875
+                    # Max: 1.0078125, 1.0078125
+                    # Min: -0.08465576171875, -0.08441162109375
+                    self._test_groupwise_dq_linear(
+                        lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=1e-2
+                    )
 
     @unittest.skipIf(
         not torchao_installed, "Per Channel Group Quantization Required TorchAO"
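The fp16 test now requests inputs via lin_mod.get_inputs(rank=input_rank); the BaseLinear.get_inputs helper itself is not part of this hunk, so its exact behavior is assumed. A hypothetical sketch of such a rank-aware helper, reusing the constructor argument names seen in the test above:

import torch

def get_inputs(self, rank: int = 3):
    # Hypothetical helper (not the actual BaseLinear implementation): return a
    # single activation tensor whose rank matches the request, keeping the last
    # dimension equal to input_channels as the linear layer expects.
    if rank == 2:
        return (torch.randn(1, self.input_channels, dtype=self.dtype),)
    if rank == 3:
        return (torch.randn(1, self.in_size, self.input_channels, dtype=self.dtype),)
    raise ValueError(f"Unsupported input rank: {rank}")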