@@ -55,7 +55,7 @@ def name_to_dtype(name):
 ##########################################################################
 ### process quantization dictionary ###
 
-def quantize_model(model: nn.Module, quantize_options):
+def quantize_model(model: nn.Module, device, quantize_options):
     """
     Quantize the specified model using the quantizers described by
     a quantization dict of the form:
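For context, a minimal usage sketch of the new signature (the "cuda" string and the option values below are illustrative assumptions, not taken from this diff):

    # Hypothetical call: pass the target device ahead of the quantization dict.
    quantize_options = {
        "embedding": {"bitwidth": 8, "groupsize": 256},
        "linear:int4": {"groupsize": 128},
    }
    model = quantize_model(model, "cuda", quantize_options)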
@@ -74,6 +74,7 @@ def quantize_model(model: nn.Module, quantize_options):
         if quantizer == "embedding":
             model = EmbeddingOnlyInt8QuantHandler(
                 model,
+                device,
                 **q_kwargs
             ).quantized_model()
         elif linears_quantized:
@@ -82,30 +83,35 @@ def quantize_model(model: nn.Module, quantize_options):
             linears_quantized = True
             model = WeightOnlyInt8QuantHandler(
                 model,
+                device,
                 **q_kwargs
             ).quantized_model()
         elif quantizer == "linear:int4":
             linears_quantized = True
             model = WeightOnlyInt4QuantHandler(
                 model,
+                device,
                 **q_kwargs
             ).quantized_model()
         elif quantizer == "linear:a8w4dq":
             linears_quantized = True
             model = Int8DynActInt4WeightQuantHandler(
                 model,
+                device,
                 **q_kwargs
             ).quantized_model()
         elif quantizer == "linear:gptq":
             linears_quantized = True
             model = WeightOnlyInt4GPTQQuantHandler(
                 model,
+                device,
                 **q_kwargs
             ).quantized_model()
         elif quantizer == "linear:hqq":
             linears_quantized = True
             model = WeightOnlyInt4HqqQuantHandler(
                 model,
+                device,
                 **q_kwargs
             ).quantized_model()
         elif quantizer == "precision":
@@ -371,12 +377,14 @@ class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
         mod,
+        device,
         *,
         node_type: str = "*",
         bitwidth: Optional[int] = None,
         groupsize: Optional[int] = None,
     ):
         self.mod = mod
+        self.device = device
         self.groupsize = groupsize
         self.node_type = node_type
         if bitwidth is None:
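A minimal sketch of constructing this handler directly under the new signature, assuming an already-loaded nn.Module named model (the device string and kwarg values are illustrative):

    # device is now the second positional argument; the quantization knobs stay keyword-only.
    model = WeightOnlyInt8QuantHandler(model, "cpu", bitwidth=8, groupsize=256).quantized_model()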
@@ -494,7 +502,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
 
 def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, bitwidth: int = 8, groupsize: Optional[int] = None, packed=False
+    module, device, bitwidth: int = 8, groupsize: Optional[int] = None, packed=False
 ):
     for name, child in module.named_children():
         # print(f"name: {name}")
@@ -505,6 +513,7 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
                 module,
                 name,
                 QuantizedGroupEmbedding(
+                    device=device,
                     vocab_size=child.weight.shape[0],
                     embedding_dim=child.weight.shape[1],
                     groupsize=groupsize,
@@ -518,10 +527,11 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
 
 
 class EmbeddingOnlyInt8QuantHandler(QuantHandler):
-    def __init__(self, mod, *, bitwidth: int = 8, groupsize: Optional[int] = None, packed=False):
+    def __init__(self, mod, device, *, bitwidth: int = 8, groupsize: Optional[int] = None, packed=False):
         if isinstance(packed, str):
             packed = (packed == "True")
         self.mod = mod
+        self.device = device
         self.groupsize = groupsize
         self.bitwidth = bitwidth
         self.packed = packed
@@ -565,7 +575,7 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
 
                 if packed:
                     if weight.shape[-1] % 2 != 0:
-                        raise RUntimeError("automatic padding not implemented yet")
+                        raise RuntimeError("automatic padding not implemented yet")
 
                     weight_range_shifted = weight.add(8).view(torch.uint8)
                     weight_view = weight_range_shifted.view(
@@ -578,6 +588,8 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                     weight_packed = weight_even + weight_odd
                     weight = weight_packed
 
+                weight = weight.to(device=self.device)
+                scales = scales.to(device=self.device)
                 # Update state dict
                 cur_state_dict[f"{fqn}.weight"] = weight
                 # squeeze makes groupsize=rowsize unidimensional
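The hunk above only shows the tail of the packing path before the new device move; here is a self-contained sketch of the two-nibbles-per-byte idea it relies on (shapes and values are assumptions, not the file's exact code):

    import torch

    w = torch.randint(-8, 8, (4, 6), dtype=torch.int8)  # quantized int4 values in [-8, 7]
    shifted = w.add(8).view(torch.uint8)                 # shift into [0, 15]
    even = shifted[:, ::2] * 16                          # even columns -> high nibble
    odd = shifted[:, 1::2]                               # odd columns  -> low nibble
    packed = even + odd                                  # (4, 3) uint8, two weights per byte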
@@ -587,7 +599,7 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
 
     def convert_for_runtime(self) -> nn.Module:
         replace_embedding_weight_only_grouped_int8_per_channel(
-            self.mod, self.bitwidth, self.groupsize, self.packed
+            self.mod, self.device, self.bitwidth, self.groupsize, self.packed
         )
         return self.mod
 
@@ -601,10 +613,10 @@ def quantized_model(self) -> nn.Module:
 class QuantizedGroupEmbedding(torch.nn.Module):
     def __init__(
         self,
+        device,
         vocab_size: int,
         embedding_dim: int,
         groupsize: Optional[int] = None,
-        device=None,
         dtype=torch.half,
         packed=False,
     ) -> None:
@@ -616,20 +628,20 @@ def __init__(
         self.packed = packed
         if not packed:
             self.register_buffer(
-                "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
+                "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8, device=device)
             )
         else:  # packed
             self.register_buffer(
-                "weight", torch.empty((vocab_size, embedding_dim // 2), dtype=torch.uint8)
+                "weight", torch.empty((vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device)
             )
         groups_per_row = (embedding_dim + groupsize - 1) // groupsize
         if groups_per_row > 1:
             self.register_buffer(
-                "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16)
+                "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16, device=device)
             )
         else:
             self.register_buffer(
-                "scales", torch.ones((vocab_size,), dtype=torch.float16)
+                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
             )
 
     @torch.no_grad()
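A quick sketch of what moving device into the constructor buys (vocab_size, embedding_dim, groupsize, and the "cuda" target are illustrative assumptions):

    # Buffers are allocated on the target device at construction time
    # rather than created on CPU and moved afterwards.
    emb = QuantizedGroupEmbedding("cuda", vocab_size=32000, embedding_dim=4096, groupsize=256)
    # emb.weight and emb.scales now live on the CUDA device.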
@@ -712,8 +724,9 @@ def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_c
 
 
 class WeightOnlyInt4QuantHandler(QuantHandler):
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
+    def __init__(self, mod, device, *, groupsize=128, inner_k_tiles=8, padding_allowed=True):
         self.mod = mod
+        self.device = device
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
         self.padding_allowed = padding_allowed
@@ -908,12 +921,15 @@ class Int8DynActInt4WeightQuantHandler(QuantHandler):
     def __init__(
         self,
         mod,
+        device,
+        *,
         groupsize=256,
         padding_allowed=False,
         precision=torch.float32,
         scales_precision=torch.float32,
     ):
         self.mod = mod
+        self.device = device
         self.groupsize = groupsize
         self.padding_allowed = padding_allowed
         self.precision = precision
@@ -1209,9 +1225,10 @@ def convert_for_runtime(self) -> "nn.Module":
 
 
 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+    def __init__(self, mod, device, *, groupsize=128, inner_k_tiles=8, padding=True):
         from build.model import find_multiple
         self.mod = mod
+        self.device = device
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
         self.padding = padding
@@ -1329,7 +1346,7 @@ def quantized_model(self) -> nn.Module:
 ### WIP: HQQ ###
 
 class WeightOnlyInt4HqqQuantHandler:
-    def __init__(self, mod, groupsize):
+    def __init__(self, mod, device, *, groupsize):
         self.mod = mod
         self.groupsize = groupsize
 