24
24
pack_scales_and_zeros ,
25
25
)
26
26
27
+ from torchao .dtypes .utils import is_device
28
+
27
29
28
30
logger : logging .Logger = logging .getLogger (__name__ )
29
31
@@ -128,6 +130,7 @@ def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsiz
128
130
groupsize ,
129
131
scales_and_zeros ,
130
132
)
133
+
131
134
new_shape = origin_input_size [:- 1 ] + (out_features ,)
132
135
c = c .reshape (new_shape )
133
136
return c
@@ -178,16 +181,27 @@ def __init__(
178
181
), "must specify both weights and scales_and_zeros, or neither"
179
182
180
183
if weight is None :
181
- weight = torch .empty (
182
- (
183
- out_features // 8 ,
184
- in_features // (inner_k_tiles * 16 ),
185
- 32 ,
186
- inner_k_tiles // 2 ,
187
- ),
188
- dtype = torch .int32 ,
189
- device = device ,
190
- )
184
+ if is_device (device , "cpu" ):
185
+ weight = torch .empty (
186
+ (
187
+ out_features ,
188
+ in_features // 2 ,
189
+ ),
190
+ dtype = torch .uint8 ,
191
+ device = device ,
192
+ )
193
+ else :
194
+ weight = torch .empty (
195
+ (
196
+ out_features // 8 ,
197
+ in_features // (inner_k_tiles * 16 ),
198
+ 32 ,
199
+ inner_k_tiles // 2 ,
200
+ ),
201
+ dtype = torch .int32 ,
202
+ device = device ,
203
+ )
204
+
191
205
scales_and_zeros = torch .empty (
192
206
(in_features // groupsize , out_features , 2 ),
193
207
dtype = get_precision (),
@@ -223,12 +237,17 @@ def _prepare_weight_and_scales_and_zeros(
223
237
weight_int32 , scales_and_zeros = group_quantize_tensor (
224
238
weight_bf16 , n_bit = 4 , groupsize = groupsize
225
239
)
226
- weight_uint8 = (weight_int32 [::, ::2 ] << 4 | weight_int32 [::, 1 ::2 ]).to (
227
- torch .uint8
228
- )
229
- weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack (
230
- weight_uint8 , inner_k_tiles
231
- )
240
+ if is_device (weight_int32 .device .type , "cpu" ):
241
+ weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack_for_cpu (
242
+ weight_int32 , inner_k_tiles
243
+ )
244
+ else :
245
+ weight_uint8 = (weight_int32 [::, ::2 ] << 4 | weight_int32 [::, 1 ::2 ]).to (
246
+ torch .uint8
247
+ )
248
+ weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack (
249
+ weight_uint8 , inner_k_tiles
250
+ )
232
251
return weight_int4pack , scales_and_zeros
233
252
234
253
@classmethod
@@ -609,17 +628,14 @@ def load_model_and_state_dict(
609
628
if load_state_dict :
610
629
q , s , z = Q4_0 .unpack (t )
611
630
scales_and_zeros = pack_scales_and_zeros (s , z )
612
- q_uint8 = (q [::, ::2 ] << 4 | q [::, 1 ::2 ]).to (torch .uint8 )
613
-
614
- if torch .device (device ).type == "cpu" :
615
- weight_int4pack = (
616
- torch .ops .aten ._convert_weight_to_int4pack_for_cpu (
617
- q , inner_k_tiles
618
- )
631
+ if is_device (q .device .type , "cpu" ):
632
+ weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack_for_cpu (
633
+ q , inner_k_tiles
619
634
)
620
635
else :
636
+ q_tmp = (q [::, ::2 ] << 4 | q [::, 1 ::2 ]).to (torch .uint8 )
621
637
weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack (
622
- q_uint8 , inner_k_tiles
638
+ q_tmp , inner_k_tiles
623
639
)
624
640
state_dict [f"{ fqn } .weight" ] = weight_int4pack
625
641
state_dict [f"{ fqn } .scales_and_zeros" ] = scales_and_zeros
@@ -632,7 +648,7 @@ def load_model_and_state_dict(
632
648
in_features = in_features ,
633
649
out_features = out_features ,
634
650
bias = False ,
635
- device = "meta " ,
651
+ device = "cpu " ,
636
652
groupsize = Q4_0 .groupsize ,
637
653
inner_k_tiles = inner_k_tiles ,
638
654
),
0 commit comments