Commit ec66a75

fix mixtral quantization scaler axis when dimension > 2 (#132)
1 parent d3547f9 commit ec66a75

1 file changed, +1 -1 lines changed
convert_checkpoints.py

Lines changed: 1 addition & 1 deletion
@@ -586,7 +586,7 @@ def main(argv) -> None:
   if FLAGS.quantize_weights:
     quantize_num_bits = 8 if "int8" in FLAGS.quantize_type else 4
     is_blockwise = "blockwise" in FLAGS.quantize_type
-    weight_axis = lambda x: 0 if x in quantize_embedding_weight_map else 1
+    weight_axis = lambda x: 0 if x in quantize_embedding_weight_map else -1
     start = time.perf_counter()
     state_dict = _quantize_state_dict(
         state_dict,
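
The one-character change matters only for weights with more than two dimensions. The sketch below illustrates why, assuming `weight_axis` names the axis that is reduced when the per-channel scaler is computed (the diff does not show `_quantize_state_dict`, so this is an assumption). The helper names and the tensor shapes are hypothetical and chosen only for illustration; they are not taken from `convert_checkpoints.py`.

```python
def normalize_axis(axis: int, ndim: int) -> int:
    """Map a possibly negative axis to its non-negative index."""
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range for {ndim} dimensions")
    return axis % ndim


def scaler_shape(weight_shape: tuple, axis: int) -> tuple:
    """Shape of a per-channel scaler after reducing over `axis` (hypothetical helper)."""
    reduced = normalize_axis(axis, len(weight_shape))
    return tuple(d for i, d in enumerate(weight_shape) if i != reduced)


# Hypothetical shapes: a 2-D linear weight and a 3-D stacked expert weight.
linear_shape = (4096, 1024)
expert_shape = (8, 4096, 1024)

# For a 2-D weight, axis 1 and axis -1 are the same axis.
print(scaler_shape(linear_shape, 1), scaler_shape(linear_shape, -1))

# For a 3-D weight, axis 1 is the middle axis while axis -1 is the last one.
print(scaler_shape(expert_shape, 1), scaler_shape(expert_shape, -1))
```

With two dimensions both calls reduce the same axis, so the old and new code agree. With three dimensions, `1` reduces the middle axis and `-1` reduces the last one, which is presumably the mismatch the commit title describes.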
