@@ -134,7 +134,7 @@ def quantize(
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

         model = Int8DynActInt4WeightQuantizer(
-            precision=torch_dtype, group_size=group_size
+            precision=torch_dtype, groupsize=group_size
         ).quantize(model)
         if verbose_export():
             print("quantized model:", model)
@@ -153,6 +153,7 @@ def quantize(
         if calibration_tasks is None:
             calibration_tasks = ["wikitext"]

+        from torchao.quantization.GPTQ import InputRecorder
         from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer

         if tokenizer_path is None:
@@ -161,17 +162,28 @@ def quantize(
         tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
             model_file=str(tokenizer_path)
         )
+
+        inputs = (
+            InputRecorder(
+                tokenizer,
+                calibration_seq_length,
+                None,  # input_prep_func
+                pad_calibration_inputs,
+                model.vocab_size,
+            )
+            .record_inputs(
+                calibration_tasks,
+                calibration_limit,
+            )
+            .get_inputs()
+        )
+
         gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
-            tokenizer,
             blocksize,
             percdamp,
             group_size,
-            calibration_tasks,
-            calibration_limit,
-            calibration_seq_length,
-            pad_calibration_inputs,
         )
-        model = gptq_quantizer.quantize(model)
+        model = gptq_quantizer.quantize(model, inputs)
         return model
     else:
         raise Exception(f"Unrecognized quantize mode: {qmode}")
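
Note: the first hunk only renames the eager quantizer's keyword argument from group_size to groupsize. The remaining hunks move calibration out of the GPTQ quantizer's constructor: inputs are recorded up front with InputRecorder and passed to quantize(). Below is a minimal sketch of the reworked flow, assuming a Llama-style eager model with a vocab_size attribute and a SentencePiece tokenizer file; the hyperparameter defaults are illustrative assumptions, not values from this change.

from sentencepiece import SentencePieceProcessor
from torchao.quantization.GPTQ import InputRecorder
from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer


def gptq_quantize(
    model,                        # eager nn.Module exposing a vocab_size attribute
    tokenizer_path,               # path to a SentencePiece tokenizer model file
    blocksize=128,                # illustrative GPTQ hyperparameters, not from the diff
    percdamp=0.01,
    group_size=256,
    calibration_tasks=None,
    calibration_limit=5,
    calibration_seq_length=100,
    pad_calibration_inputs=False,
):
    if calibration_tasks is None:
        calibration_tasks = ["wikitext"]
    tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))

    # Calibration data is now recorded up front by InputRecorder instead of
    # being collected inside the quantizer, as the old constructor did.
    inputs = (
        InputRecorder(
            tokenizer,
            calibration_seq_length,
            None,  # input_prep_func
            pad_calibration_inputs,
            model.vocab_size,
        )
        .record_inputs(calibration_tasks, calibration_limit)
        .get_inputs()
    )

    # The constructor keeps only the GPTQ hyperparameters; the recorded
    # calibration inputs are passed to quantize() alongside the model.
    quantizer = Int8DynActInt4WeightGPTQQuantizer(blocksize, percdamp, group_size)
    return quantizer.quantize(model, inputs)

One practical effect of this split: calibration-data collection is decoupled from quantization, so the same recorded inputs can be reused across different quantizer configurations instead of being re-recorded each run.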