Skip to content

fix mixtral quantization scaler axis when dimension > 2 #132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 20, 2024
Merged

Conversation

sixiang-google
Copy link
Collaborator

No description provided.

@sixiang-google sixiang-google requested a review from lsy323 June 20, 2024 18:02
@lsy323 lsy323 merged commit ec66a75 into main Jun 20, 2024
4 checks passed
@@ -586,7 +586,7 @@ def main(argv) -> None:
if FLAGS.quantize_weights:
quantize_num_bits = 8 if "int8" in FLAGS.quantize_type else 4
is_blockwise = "blockwise" in FLAGS.quantize_type
weight_axis = lambda x: 0 if x in quantize_embedding_weight_map else 1
weight_axis = lambda x: 0 if x in quantize_embedding_weight_map else -1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! Thanks Xiang for fixing it!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants