
Commit 2c467dd

larryliu0820 authored and facebook-github-bot committed
Fix quantized embedding export logic (#3095)
Summary: Add patches to make 4-bit quantized embedding work for export. Fixed:

* Schema mismatch between the functional embedding_4bit and its out variant
* Set `packed=True` for 4-bit quantization

Pull Request resolved: #3095
Reviewed By: mikekgfb
Differential Revision: D56340670
Pulled By: larryliu0820
fbshipit-source-id: c98623a9b7633fc5a6c390be1557213c719fa95a
1 parent: cf78107

File tree

2 files changed (+5, −2 lines)


examples/models/llama2/export_llama_lib.py (+4, −1)

@@ -614,7 +614,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
         bitwidth = int(bitwidth)
         transforms.append(
             lambda model: EmbeddingQuantHandler(
-                model, bitwidth=bitwidth, group_size=group_size
+                model,
+                bitwidth=bitwidth,
+                group_size=group_size,
+                packed=(bitwidth == 4),
             ).quantized_model()
         )
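Passing `packed=True` for 4-bit weights means two quantized values share one byte, halving the embedding table's storage. Below is a minimal, self-contained sketch of that nibble-packing idea; the helpers `pack_4bit` and `unpack_4bit` are hypothetical names for illustration, not ExecuTorch APIs.

import torch

def pack_4bit(q: torch.Tensor) -> torch.Tensor:
    # q holds 4-bit values in [0, 15]; the last dimension must be even.
    q = q.to(torch.uint8)
    # Even-indexed values go into the high nibble, odd-indexed into the low.
    return (q[..., ::2] << 4) | q[..., 1::2]

def unpack_4bit(p: torch.Tensor) -> torch.Tensor:
    hi = (p >> 4) & 0xF
    lo = p & 0xF
    # Interleave the nibbles back into their original order.
    return torch.stack([hi, lo], dim=-1).flatten(-2)

q = torch.randint(0, 16, (2, 8))
assert torch.equal(unpack_4bit(pack_4bit(q)).to(q.dtype), q)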

exir/passes/_quant_patterns_and_replacements.py (+1, −1)

@@ -189,7 +189,7 @@ def embedding_byte_dtype_out_meta(
 
 quantized_decomposed_lib.define(
     "embedding_4bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
-    "int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)",
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
 )
 
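In op schemas, arguments after `*` are keyword-only, so this change moves `dtype` behind the `*` to make the out variant's signature line up with the functional embedding_4bit schema; a mismatch between the two breaks pattern replacement during export. A minimal sketch of defining such a schema with `torch.library` follows; the library name `example_q` is hypothetical and stands in for the real `quantized_decomposed_lib` inside ExecuTorch.

from torch.library import Library

# Hypothetical library namespace, for illustration only.
example_q = Library("example_q", "DEF")

# Everything after `*` (dtype, out) must be passed by keyword,
# mirroring the corrected out-variant schema above.
example_q.define(
    "embedding_4bit.dtype_out(Tensor weight, Tensor weight_scales, "
    "Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, "
    "Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)"
)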
