File tree 1 file changed +5
-4
lines changed
1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -681,13 +681,14 @@ def apply_int4_weight_only_quant(weight):
681
681
eps = 1e-6
682
682
zero_point_dtype = torch .bfloat16
683
683
684
- assert layout in LAYOUT_TO_ZERO_POINT_DOMAIN .keys (), f"Only support layout: { LAYOUT_TO_ZERO_POINT_DOMAIN .keys ()} "
684
+ nonlocal zero_point_domain
685
+ assert any (isinstance (layout , support_layout ) for support_layout in LAYOUT_TO_ZERO_POINT_DOMAIN .keys ()), f"Only support layout: { LAYOUT_TO_ZERO_POINT_DOMAIN .keys ()} "
685
686
if zero_point_domain is None :
686
687
# the first value is the default one
687
- zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN [layout ][0 ]
688
- preserve_zero = LAYOUT_TO_PRESERVE_ZEROS [layout ]
688
+ zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN [type ( layout ) ][0 ]
689
+ preserve_zero = LAYOUT_TO_PRESERVE_ZEROS [type ( layout ) ]
689
690
else :
690
- assert zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN [layout ], f"Layout only support { LAYOUT_TO_ZERO_POINT_DOMAIN [layout ]} "
691
+ assert zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN [type ( layout ) ], f"Layout only support { LAYOUT_TO_ZERO_POINT_DOMAIN [type ( layout ) ]} "
691
692
692
693
# Sparse Marlin only supports symmetric quantization.
693
694
# NOTE: If we start having lots of layouts that require different configurations,
You can’t perform that action at this time.
0 commit comments