Skip to content

Commit f384133

Browse files
committed
fix key errors
1 parent e4a28a8 commit f384133

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

torchao/quantization/quant_api.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -681,13 +681,14 @@ def apply_int4_weight_only_quant(weight):
681681
eps = 1e-6
682682
zero_point_dtype = torch.bfloat16
683683

684-
assert layout in LAYOUT_TO_ZERO_POINT_DOMAIN.keys(), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}"
684+
nonlocal zero_point_domain
685+
assert any(isinstance(layout, support_layout) for support_layout in LAYOUT_TO_ZERO_POINT_DOMAIN.keys()), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}"
685686
if zero_point_domain is None:
686687
# the first value is the default one
687-
zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[layout][0]
688-
preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[layout]
688+
zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0]
689+
preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)]
689690
else:
690-
assert zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[layout], f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}"
691+
assert zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)], f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}"
691692

692693
# Sparse Marlin only supports symmetric quantization.
693694
# NOTE: If we start having lots of layouts that require different configurations,

0 commit comments

Comments
 (0)