@@ -132,15 +132,17 @@ def cuda(self):
 
 class GenericGPTQRunner(fx.Interpreter):
     """
-    This is a generic GPTQ runner that takes an existing model and applies GPTQ.
-    It uses torch._dynamo.export to obtain a graph of the model and then hooks
-    into function calls and when it detects a linear, it applies GPTQ to the weight
-    given the calibration of inputs passed in at initialization. It puts the results
-    into the state_dict so that the quantized model weights/qparams can be loaded
-    directly into the model.
+    This is a generic GPTQ runner that takes an existing model and
+    applies GPTQ. It uses torch._dynamo.export to obtain a graph of
+    the model and then hooks into function calls and when it detects a
+    linear, it applies GPTQ to the weight given the calibration of
+    inputs passed in at initialization. It puts the results into the
+    state_dict so that the quantized model weights/qparams can be
+    loaded directly into the model.
 
     This class is expected to work in concert with a GPTQSimpleQuantizer
     class to define the specific type of quantization being done.
+
     """
 
     def __init__(
@@ -206,7 +208,7 @@ def get_quantized_state_dict(self):
             self.gptq_done
         ), "need to run GPTQRunner before you can get_quantized_state_dict"
         quantized_state_dict = self.new_state_dict
-        # Don't want to store/load the kv_cache so remove it from the state_dict
+
         del_list = []
         for param_fqn in quantized_state_dict:
             if "kv_cache" in param_fqn:
@@ -224,7 +226,8 @@ def tensors_to_cuda(args):
 
         # flatten args and kwargs together
         flat_args, spec = tree_flatten((args, kwargs))
-        # move all single tensors to cuda, will move MultiInputs to cuda one at a time
+        # move all single tensors to cuda, will move MultiInputs
+        # to cuda one at a time
         flat_args = tensors_to_cuda(flat_args)
 
         has_multi_input = MultiInput in [type(x) for x in flat_args]
@@ -421,8 +424,9 @@ def faster_quant(self, H, W):
         if all_qparams == []:
             all_qparams.append(cur_qparams)
 
-        # convert a list of qparams objects into a single one. enerally by
-        # concatenating a bunch of n,1 scale/zeros tensors into a n,num_groups tensor
+        # convert a list of qparams objects into a single
+        # one. generally by concatenating a bunch of n,1 scale/zeros
+        # tensors into a n,num_groups tensor
         all_qparams = self.combine_qparams_list_func(all_qparams)
         Q = self.quantize_func(DQ, all_qparams)
         return Q, DQ.to(orig_dtype), all_qparams
0 commit comments