# Tensors with at most this many elements are kept in float instead of int8.
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
# Storage dtype used for tensors that are kept in float.
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
# Storage dtype for the per-row quantization scales of 2-D tensors.
INT8_PER_ROW_SCALE_DTYPE = torch.float16
# Candidate clipping quantiles searched per row when quantizing 2-D tensors
# (GPTQ-style clip search); 1.0 means no clipping, i.e. plain abs-max scaling.
GPTQ_CLIP_PERCENTILES = [0.999, 0.9995, 0.9999, 0.99999, 1.0]
def tensor_nbytes(t: Tensor) -> int:
    """Return the number of bytes *t* occupies: element count times element size."""
    element_count = int(t.numel())
    bytes_per_element = int(t.element_size())
    return element_count * bytes_per_element
def quantize_float_tensor(t: Tensor, bits: int = 8) -> tuple[Tensor, Tensor]:
    """Symmetrically quantize *t* to signed ``bits``-bit integers stored as int8.

    For a non-empty 2-D tensor, a per-row scale is chosen by trying each
    clipping quantile in ``GPTQ_CLIP_PERCENTILES`` and keeping, per row, the
    candidate with the lowest reconstruction MSE.  For other shapes a single
    abs-max scale is used.

    Returns ``(q, scale)``: ``q`` is ``torch.int8``; ``scale`` is per-row with
    dtype ``INT8_PER_ROW_SCALE_DTYPE`` for 2-D input, otherwise a float32
    scalar tensor.  Dequantize as ``q.float() * scale``.
    """
    max_val = 2 ** (bits - 1) - 1
    t32 = t.float()
    if t32.ndim == 2 and t32.numel() > 0:
        best_q = None
        best_scale = None
        # Allocate on t32's device: the original built this on CPU, which
        # makes `mse < best_mse` raise for CUDA inputs.
        best_mse = torch.full((t32.shape[0],), float("inf"), device=t32.device)
        for pct in GPTQ_CLIP_PERCENTILES:
            clip_abs = torch.quantile(t32.abs(), pct, dim=1)
            # Lower-bound the scale so a zero row still divides cleanly.
            scale = (clip_abs / max_val).clamp_min(1.0 / max_val)
            clipped = torch.clamp(t32, -clip_abs[:, None], clip_abs[:, None])
            q = torch.clamp(torch.round(clipped / scale[:, None]), -max_val, max_val).to(torch.int8)
            recon = q.float() * scale[:, None]
            mse = (t32 - recon).pow(2).mean(dim=1)
            if best_q is None:
                # Seed unconditionally with the first candidate; the original
                # only seeded when some row improved on +inf, so all-NaN input
                # left best_q as None and crashed below.
                best_q, best_scale, best_mse = q, scale, mse
                continue
            improved = mse < best_mse
            if improved.any():
                best_q[improved] = q[improved]
                best_scale[improved] = scale[improved]
                best_mse[improved] = mse[improved]
        return best_q.contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
    if t32.ndim == 2:
        # Empty 2-D tensor: nothing to quantize.  Use deterministic unit
        # scales instead of torch.empty (uninitialized memory that would be
        # serialized), and match t32's device like zeros_like does.
        return (
            torch.zeros_like(t32, dtype=torch.int8),
            torch.ones((t32.shape[0],), dtype=INT8_PER_ROW_SCALE_DTYPE, device=t32.device),
        )
    # Non-2-D (or scalar) path: one abs-max scale for the whole tensor;
    # fall back to scale 1.0 for an all-zero or empty tensor.
    clip_abs = float(t32.abs().max().item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / max_val if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -max_val, max_val).to(torch.int8).contiguous()
    return q, scale