24
24
pack_scales_and_zeros ,
25
25
)
26
26
27
+ from torchao .dtypes .utils import is_device
28
+ from torchao .utils import TORCH_VERSION_AT_LEAST_2_6
29
+
27
30
28
31
logger : logging .Logger = logging .getLogger (__name__ )
29
32
@@ -122,12 +125,20 @@ def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsiz
122
125
input .dtype
123
126
) # cast back to input.dtype
124
127
else :
125
- c = torch .ops .aten ._weight_int4pack_mm (
126
- input ,
127
- weight_int4pack ,
128
- groupsize ,
129
- scales_and_zeros ,
130
- )
128
+ if TORCH_VERSION_AT_LEAST_2_6 :
129
+ c = torch .ops .aten ._weight_int4pack_mm_for_cpu (
130
+ input ,
131
+ weight_int4pack ,
132
+ groupsize ,
133
+ scales_and_zeros ,
134
+ )
135
+ else :
136
+ c = torch .ops .aten ._weight_int4pack_mm (
137
+ input ,
138
+ weight_int4pack ,
139
+ groupsize ,
140
+ scales_and_zeros ,
141
+ )
131
142
new_shape = origin_input_size [:- 1 ] + (out_features ,)
132
143
c = c .reshape (new_shape )
133
144
return c
@@ -178,16 +189,27 @@ def __init__(
178
189
), "must specify both weights and scales_and_zeros, or neither"
179
190
180
191
if weight is None :
181
- weight = torch .empty (
182
- (
183
- out_features // 8 ,
184
- in_features // (inner_k_tiles * 16 ),
185
- 32 ,
186
- inner_k_tiles // 2 ,
187
- ),
188
- dtype = torch .int32 ,
189
- device = device ,
190
- )
192
+ if is_device (device , "cpu" ):
193
+ weight = torch .empty (
194
+ (
195
+ out_features ,
196
+ in_features // 2 ,
197
+ ),
198
+ dtype = torch .uint8 ,
199
+ device = device ,
200
+ )
201
+ else :
202
+ weight = torch .empty (
203
+ (
204
+ out_features // 8 ,
205
+ in_features // (inner_k_tiles * 16 ),
206
+ 32 ,
207
+ inner_k_tiles // 2 ,
208
+ ),
209
+ dtype = torch .int32 ,
210
+ device = device ,
211
+ )
212
+
191
213
scales_and_zeros = torch .empty (
192
214
(in_features // groupsize , out_features , 2 ),
193
215
dtype = get_precision (),
@@ -223,12 +245,17 @@ def _prepare_weight_and_scales_and_zeros(
223
245
weight_int32 , scales_and_zeros = group_quantize_tensor (
224
246
weight_bf16 , n_bit = 4 , groupsize = groupsize
225
247
)
226
- weight_uint8 = (weight_int32 [::, ::2 ] << 4 | weight_int32 [::, 1 ::2 ]).to (
227
- torch .uint8
228
- )
229
- weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack (
230
- weight_uint8 , inner_k_tiles
231
- )
248
+ if is_device (weight_int32 .device .type , "cpu" ) and TORCH_VERSION_AT_LEAST_2_6 :
249
+ weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack_for_cpu (
250
+ weight_int32 , inner_k_tiles
251
+ )
252
+ else :
253
+ weight_uint8 = (weight_int32 [::, ::2 ] << 4 | weight_int32 [::, 1 ::2 ]).to (
254
+ torch .uint8
255
+ )
256
+ weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack (
257
+ weight_uint8 , inner_k_tiles
258
+ )
232
259
return weight_int4pack , scales_and_zeros
233
260
234
261
@classmethod
@@ -608,10 +635,15 @@ def load_model_and_state_dict(
608
635
if load_state_dict :
609
636
q , s , z = Q4_0 .unpack (t )
610
637
scales_and_zeros = pack_scales_and_zeros (s , z )
611
- q_uint8 = (q [::, ::2 ] << 4 | q [::, 1 ::2 ]).to (torch .uint8 )
612
- weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack (
613
- q_uint8 , inner_k_tiles
614
- )
638
+ if is_device (q .device .type , "cpu" ) and TORCH_VERSION_AT_LEAST_2_6 :
639
+ weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack_for_cpu (
640
+ q , inner_k_tiles
641
+ )
642
+ else :
643
+ q_tmp = (q [::, ::2 ] << 4 | q [::, 1 ::2 ]).to (torch .uint8 )
644
+ weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack (
645
+ q_tmp , inner_k_tiles
646
+ )
615
647
state_dict [f"{ fqn } .weight" ] = weight_int4pack
616
648
state_dict [f"{ fqn } .scales_and_zeros" ] = scales_and_zeros
617
649
@@ -623,7 +655,7 @@ def load_model_and_state_dict(
623
655
in_features = in_features ,
624
656
out_features = out_features ,
625
657
bias = False ,
626
- device = "meta " ,
658
+ device = "cpu " ,
627
659
groupsize = Q4_0 .groupsize ,
628
660
inner_k_tiles = inner_k_tiles ,
629
661
),
0 commit comments