Skip to content

Commit 29d4ab7

Browse files
committed
Update int4pack-related ops for GGUF loading
1 parent fff956c commit 29d4ab7

File tree

1 file changed

+59
-27
lines changed

1 file changed

+59
-27
lines changed

torchchat/utils/gguf_loader.py

Lines changed: 59 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
pack_scales_and_zeros,
2525
)
2626

27+
from torchao.dtypes.utils import is_device
28+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
29+
2730

2831
logger: logging.Logger = logging.getLogger(__name__)
2932

@@ -122,12 +125,20 @@ def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsiz
122125
input.dtype
123126
) # cast back to input.dtype
124127
else:
125-
c = torch.ops.aten._weight_int4pack_mm(
126-
input,
127-
weight_int4pack,
128-
groupsize,
129-
scales_and_zeros,
130-
)
128+
if TORCH_VERSION_AT_LEAST_2_6:
129+
c = torch.ops.aten._weight_int4pack_mm_for_cpu(
130+
input,
131+
weight_int4pack,
132+
groupsize,
133+
scales_and_zeros,
134+
)
135+
else:
136+
c = torch.ops.aten._weight_int4pack_mm(
137+
input,
138+
weight_int4pack,
139+
groupsize,
140+
scales_and_zeros,
141+
)
131142
new_shape = origin_input_size[:-1] + (out_features,)
132143
c = c.reshape(new_shape)
133144
return c
@@ -178,16 +189,27 @@ def __init__(
178189
), "must specify both weights and scales_and_zeros, or neither"
179190

180191
if weight is None:
181-
weight = torch.empty(
182-
(
183-
out_features // 8,
184-
in_features // (inner_k_tiles * 16),
185-
32,
186-
inner_k_tiles // 2,
187-
),
188-
dtype=torch.int32,
189-
device=device,
190-
)
192+
if is_device(device, "cpu"):
193+
weight = torch.empty(
194+
(
195+
out_features,
196+
in_features // 2,
197+
),
198+
dtype=torch.uint8,
199+
device=device,
200+
)
201+
else:
202+
weight = torch.empty(
203+
(
204+
out_features // 8,
205+
in_features // (inner_k_tiles * 16),
206+
32,
207+
inner_k_tiles // 2,
208+
),
209+
dtype=torch.int32,
210+
device=device,
211+
)
212+
191213
scales_and_zeros = torch.empty(
192214
(in_features // groupsize, out_features, 2),
193215
dtype=get_precision(),
@@ -223,12 +245,17 @@ def _prepare_weight_and_scales_and_zeros(
223245
weight_int32, scales_and_zeros = group_quantize_tensor(
224246
weight_bf16, n_bit=4, groupsize=groupsize
225247
)
226-
weight_uint8 = (weight_int32[::, ::2] << 4 | weight_int32[::, 1::2]).to(
227-
torch.uint8
228-
)
229-
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
230-
weight_uint8, inner_k_tiles
231-
)
248+
if is_device(weight_int32.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
249+
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
250+
weight_int32, inner_k_tiles
251+
)
252+
else:
253+
weight_uint8 = (weight_int32[::, ::2] << 4 | weight_int32[::, 1::2]).to(
254+
torch.uint8
255+
)
256+
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
257+
weight_uint8, inner_k_tiles
258+
)
232259
return weight_int4pack, scales_and_zeros
233260

234261
@classmethod
@@ -608,10 +635,15 @@ def load_model_and_state_dict(
608635
if load_state_dict:
609636
q, s, z = Q4_0.unpack(t)
610637
scales_and_zeros = pack_scales_and_zeros(s, z)
611-
q_uint8 = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
612-
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
613-
q_uint8, inner_k_tiles
614-
)
638+
if is_device(q.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
639+
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
640+
q, inner_k_tiles
641+
)
642+
else:
643+
q_tmp = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
644+
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
645+
q_tmp, inner_k_tiles
646+
)
615647
state_dict[f"{fqn}.weight"] = weight_int4pack
616648
state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
617649

@@ -623,7 +655,7 @@ def load_model_and_state_dict(
623655
in_features=in_features,
624656
out_features=out_features,
625657
bias=False,
626-
device="meta",
658+
device="cpu",
627659
groupsize=Q4_0.groupsize,
628660
inner_k_tiles=inner_k_tiles,
629661
),

0 commit comments

Comments (0)