@@ -479,30 +479,6 @@ def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object],
479479# GPTQ: HESSIAN COLLECTION + QUANTIZATION
480480# -----------------------------
481481
482- def _make_hessian_hook (hessians : dict [str , Tensor ], name : str , n_samples : dict [str , int ]):
483- """Create a forward hook that accumulates H = X^T X for a linear layer."""
484- def hook (_module , input , _output ):
485- x = input [0 ].detach ().float ()
486- if x .ndim > 2 :
487- x = x .reshape (- 1 , x .shape [- 1 ]) # (B*T, D)
488- if name not in hessians :
489- hessians [name ] = torch .zeros (x .shape [1 ], x .shape [1 ], dtype = torch .float32 , device = "cpu" )
490- hessians [name ].add_ ((x .T @ x ).cpu ())
491- n_samples [name ] = n_samples .get (name , 0 ) + x .shape [0 ]
492- return hook
493-
494-
495- def _finalize_hessians (hessians : dict [str , Tensor ], num_batches : int ) -> dict [str , Tensor ]:
496- """Match the PR Hessian preparation path before quantization."""
497- finalized : dict [str , Tensor ] = {}
498- for name , H in hessians .items ():
499- H = H .float () / max (num_batches , 1 )
500- damp = 0.01 * torch .diag (H ).mean ().clamp_min (1e-6 )
501- H = H + damp * torch .eye (H .shape [0 ], dtype = H .dtype )
502- finalized [name ] = H
503- return finalized
504-
505-
506482def collect_hessians (
507483 model : nn .Module ,
508484 data_files : str ,
@@ -511,145 +487,110 @@ def collect_hessians(
511487 seq_len : int = 2048 ,
512488 batch_size : int = 4 ,
513489) -> dict [str , Tensor ]:
514- """Collect H = X^T X for each int6-targeted CastedLinear layer using training data .
490+ """Collect H = X^T X for CastedLinear layers (PR #1019 faithful transplant) .
515491
516- Returns dict mapping module_name (e.g. "blocks.0.attn.c_q") -> H (float32, CPU).
517- Uses forward_logits (no targets needed) and TokenStream (rank-0 only).
492+ Returns dict mapping param_name (e.g. "blocks.0.attn.c_q.weight") -> H (float32, CPU).
518493 """
519494 hessians : dict [str , Tensor ] = {}
520- n_samples : dict [str , int ] = {}
521- handles = []
522- skipped_names = []
523- sd_keys = set (model .state_dict ().keys ())
524-
495+ hooks : list = []
525496 for name , module in model .named_modules ():
526497 if isinstance (module , CastedLinear ):
527- weight_key = name + ".weight"
528- assert weight_key in sd_keys , f"No state_dict key for { weight_key } "
529- cat = _classify_param (weight_key )
530- if cat not in ("mlp" , "attn" ):
531- skipped_names .append (name )
532- continue
533- handles .append (module .register_forward_hook (_make_hessian_hook (hessians , name , n_samples )))
534-
535- print (f"gptq: hooking { len (handles )} layers, skipped: { skipped_names } " , flush = True )
498+ param_name = name + ".weight"
499+ cols = module .weight .shape [1 ]
500+ hessians [param_name ] = torch .zeros (cols , cols , dtype = torch .float32 , device = 'cpu' )
501+ def make_hook (pname ):
502+ def hook_fn (module , input , output ):
503+ x = input [0 ].detach ().float ()
504+ if x .ndim == 3 :
505+ x = x .reshape (- 1 , x .shape [- 1 ])
506+ hessians [pname ] += (x .T @ x ).cpu ()
507+ return hook_fn
508+ hooks .append (module .register_forward_hook (make_hook (param_name )))
509+
510+ print (f"gptq: hooking { len (hooks )} layers" , flush = True )
536511
537- stream = TokenStream (data_files )
538512 num_batches = num_samples // batch_size
513+ stream = TokenStream (data_files )
539514 model .eval ()
540- with torch .inference_mode ():
541- for i in range (num_batches ):
515+ with torch .inference_mode (), torch . autocast ( device_type = "cuda" , dtype = torch . bfloat16 ) :
516+ for _ in range (num_batches ):
542517 tokens = stream .take (batch_size * seq_len ).to (device = device , dtype = torch .int64 )
543518 x = tokens .reshape (batch_size , seq_len )
544- with torch .autocast (device_type = "cuda" , dtype = LOWP_DTYPE , enabled = True ):
545- model .forward_logits (x )
519+ model .forward_logits (x )
546520
547- for h in handles :
521+ for h in hooks :
548522 h .remove ()
549523
550- result = _finalize_hessians (hessians , num_batches )
551- print (f"gptq: collected { len (result )} Hessians, samples per layer: "
552- f"{ dict (list (n_samples .items ())[:3 ])} ..." , flush = True )
553- return result
554-
555-
556- def gptq_quantize_layer (
557- W : Tensor ,
558- H : Tensor ,
559- block_size : int = 128 ,
560- percdamp : float = 0.01 ,
561- clip_range : int = 31 ,
562- actorder : bool = True ,
563- ) -> tuple [Tensor , Tensor , bool , dict [str , object ]]:
564- """GPTQ-quantize weight matrix W using the PR-grounded loop.
565-
566- Returns (q: int8 in [-clip_range, clip_range], scale: fp16 per-row, degraded: bool, stats).
567- """
568- W_orig = W .float ().clone ()
569- H = H .float ().clone ()
570- d_row , d_col = W_orig .shape
571-
572- dead = (H .diag () == 0 )
573- H [dead , dead ] = 1.0
574-
575- damp = percdamp * H .diag ().mean ().clamp_min (1e-6 )
576- H .diagonal ().add_ (damp )
577-
578- perm = torch .arange (d_col , device = H .device )
579- if actorder :
580- perm = torch .argsort (H .diag (), descending = True )
581- invperm = torch .argsort (perm )
582- W_perm = W_orig [:, perm ].clone ()
583- W_perm [:, dead [perm ]] = 0.0
524+ # Finalize: normalize and damp (PR #1019 lines 1131-1136)
525+ n_samples_example = {}
526+ for name in hessians :
527+ H = hessians [name ]
528+ n_samples_example [name ] = num_batches * batch_size * seq_len
529+ H /= num_batches
530+ damp = 0.01 * torch .diag (H ).mean ().clamp_min (1e-6 )
531+ H += damp * torch .eye (H .shape [0 ])
532+ hessians [name ] = H
533+
534+ print (f"gptq: collected { len (hessians )} Hessians, samples per layer: "
535+ f"{ dict (list (n_samples_example .items ())[:3 ])} ..." , flush = True )
536+ return hessians
537+
538+
def quantize_int6_gptq(weight, hessian=None, clip_range=31, block_size=128):
    """Full GPTQ: Hessian-aware int6 quantization with Cholesky error compensation.

    Quantizes ``weight`` row-wise to signed integers in [-clip_range, clip_range]
    with a per-row fp16 scale.  The (damped, activation-ordered) Hessian is used
    to propagate each column's rounding error into not-yet-quantized columns
    (GPTQ; transplanted from PR #1019 lines 1171-1224).  A small percentile
    search over the per-row clip picks the candidate with lowest MSE.

    Falls back to ``quantize_int6_per_row`` when ``hessian`` is None or
    ``weight`` is not 2-D.

    Returns (q: int8, scale: fp16 per-row).
    """
    t32 = weight.float()
    if t32.ndim != 2 or hessian is None:
        return quantize_int6_per_row(t32, clip_range)
    rows, cols = t32.shape
    # Keep the Hessian and every work buffer on the weight's device; the
    # original allocated Q1/Err1 on the default (CPU) device, which breaks
    # the matmul updates below when the weight lives on CUDA.
    dev = t32.device
    H = hessian.float().to(dev).clone()
    # Dead columns (never activated): unit diagonal, and zero the weight there.
    dead = torch.diag(H) == 0
    H[dead, dead] = 1
    # 1% mean-diagonal damping for numerical stability of the Cholesky.
    damp = 0.01 * torch.mean(torch.diag(H))
    idx = torch.arange(cols, device=dev)
    H[idx, idx] += damp
    # Activation order: quantize the most-sensitive columns first.
    perm = torch.argsort(torch.diag(H), descending=True)
    inv_perm = torch.argsort(perm)
    W = t32[:, perm].clone()
    W[:, dead[perm]] = 0
    H = H[perm][:, perm]
    # Upper Cholesky factor of H^-1 drives the error-compensation updates.
    Hinv = torch.linalg.cholesky(H)
    Hinv = torch.cholesky_inverse(Hinv)
    Hinv = torch.linalg.cholesky(Hinv, upper=True)
    best_q = None
    best_scale = None
    best_err = float('inf')
    for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]:
        if pct < 1.0:
            row_clip = torch.quantile(t32.abs(), pct, dim=1)
        else:
            row_clip = t32.abs().amax(dim=1)
        s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16)
        sf = s.float()
        Q = torch.zeros_like(W, dtype=torch.int8)
        W_work = W.clone()
        for i1 in range(0, cols, block_size):
            i2 = min(i1 + block_size, cols)
            count = i2 - i1
            W1 = W_work[:, i1:i2].clone()
            Q1 = torch.zeros(rows, count, dtype=torch.int8, device=dev)
            Err1 = torch.zeros(rows, count, device=dev)
            Hinv1 = Hinv[i1:i2, i1:i2]
            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]
                q = torch.clamp(torch.round(w / sf), -clip_range, clip_range).to(torch.int8)
                Q1[:, i] = q
                # Spread this column's rounding error over the rest of the block.
                err = (w - q.float() * sf) / d
                W1[:, i:] -= err.unsqueeze(1) * Hinv1[i, i:].unsqueeze(0)
                Err1[:, i] = err
            Q[:, i1:i2] = Q1
            if i2 < cols:
                # Propagate the block's accumulated error to later columns.
                W_work[:, i2:] -= Err1 @ Hinv[i1:i2, i2:]
        recon = Q.float() * sf[:, None]
        mse = (W - recon).pow(2).mean().item()
        if mse < best_err:
            best_q, best_scale, best_err = Q, s, mse
    # Undo the activation-order permutation before returning.
    best_q = best_q[:, inv_perm]
    return best_q, best_scale
653594
654595
655596def gptq_mixed_quantize_int6 (
@@ -662,8 +603,8 @@ def gptq_mixed_quantize_int6(
662603):
663604 """Like mixed_quantize_int6, but uses GPTQ for layers with Hessians.
664605
665- hessians: dict mapping module_name (e.g. "blocks.0.attn.c_q") -> H tensor.
666- state_dict keys use param names (e.g. "blocks.0.attn.c_q.weight") .
606+ hessians: dict mapping param_name (e.g. "blocks.0.attn.c_q.weight ") -> H tensor.
607+ state_dict keys use the same param names .
667608 """
668609 result : dict [str , Tensor ] = {}
669610 meta : dict [str , object ] = {}
@@ -681,33 +622,25 @@ def gptq_mixed_quantize_int6(
681622 meta [name ] = "passthrough_ctrl"
682623 continue
683624 if cat in int6_cats and t .ndim >= 1 :
684- module_name = name .rsplit (".weight" , 1 )[0 ] if name .endswith (".weight" ) else name
685- H = hessians .get (module_name )
625+ H = hessians .get (name )
686626 if H is not None and t .ndim == 2 :
687627 legacy_q , legacy_s = quantize_int6_per_row_legacy (t , clip_range = clip_range )
688628 naive_q , naive_s = quantize_int6_per_row (t , clip_range = clip_range )
689- q , s , degraded , gptq_stats = gptq_quantize_layer (
690- t , H , block_size = block_size , clip_range = clip_range , actorder = actorder ,
691- )
629+ q , s = quantize_int6_gptq (t , hessian = H , clip_range = clip_range , block_size = block_size )
692630 legacy_mse = _quantization_mse (t , legacy_q , legacy_s )
693631 naive_mse = _quantization_mse (t , naive_q , naive_s )
694- gptq_mse = float ( gptq_stats [ "mse" ] )
632+ gptq_mse = _quantization_mse ( t , q , s )
695633 diagnostics .append ({
696- "name" : module_name ,
634+ "name" : name ,
697635 "legacy_rowmax_mse" : legacy_mse ,
698636 "percentile_naive_mse" : naive_mse ,
699637 "gptq_mse" : gptq_mse ,
700638 "gptq_minus_legacy_rowmax_mse" : gptq_mse - legacy_mse ,
701639 "gptq_minus_percentile_naive_mse" : gptq_mse - naive_mse ,
702640 "gptq_worse_than_legacy_rowmax" : gptq_mse > legacy_mse ,
703641 "gptq_worse_than_percentile_naive" : gptq_mse > naive_mse ,
704- ** gptq_stats ,
705642 })
706- if degraded :
707- fallback_count += 1
708- print (f"gptq: DEGRADED layer { module_name } " , flush = True )
709- else :
710- gptq_count += 1
643+ gptq_count += 1
711644 else :
712645 q , s = quantize_int6_per_row (t , clip_range = clip_range )
713646 naive_count += 1
@@ -1732,8 +1665,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
17321665 log0 (
17331666 "gptq: legacy_worse "
17341667 f"{ diag ['name' ]} delta_mse:{ diag ['gptq_minus_legacy_rowmax_mse' ]:.6e} "
1735- f"gptq_mse:{ diag ['gptq_mse' ]:.6e} legacy_mse:{ diag ['legacy_rowmax_mse' ]:.6e} "
1736- f"pct:{ diag ['best_pct' ]} block:{ diag ['worst_block_start' ]} "
1668+ f"gptq_mse:{ diag ['gptq_mse' ]:.6e} legacy_mse:{ diag ['legacy_rowmax_mse' ]:.6e} "
17371669 )
17381670 if not worse_naive :
17391671 log0 ("gptq: no layers worse than percentile naive int6" )
@@ -1742,8 +1674,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
17421674 log0 (
17431675 "gptq: percentile_worse "
17441676 f"{ diag ['name' ]} delta_mse:{ diag ['gptq_minus_percentile_naive_mse' ]:.6e} "
1745- f"gptq_mse:{ diag ['gptq_mse' ]:.6e} naive_mse:{ diag ['percentile_naive_mse' ]:.6e} "
1746- f"pct:{ diag ['best_pct' ]} block:{ diag ['worst_block_start' ]} "
1677+ f"gptq_mse:{ diag ['gptq_mse' ]:.6e} naive_mse:{ diag ['percentile_naive_mse' ]:.6e} "
17471678 )
17481679 del hessians , sd_cpu_r0
17491680
0 commit comments