Skip to content

Commit 6604277

Browse files
committed
Update int4pack related in torchchat gguf
1 parent fff956c commit 6604277

File tree

1 file changed

+27
-10
lines changed

1 file changed

+27
-10
lines changed

torchchat/utils/gguf_loader.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
pack_scales_and_zeros,
2525
)
2626

27+
from torchao.dtypes.utils import is_device
28+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
29+
2730

2831
logger: logging.Logger = logging.getLogger(__name__)
2932

@@ -122,12 +125,20 @@ def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsiz
122125
input.dtype
123126
) # cast back to input.dtype
124127
else:
125-
c = torch.ops.aten._weight_int4pack_mm(
126-
input,
127-
weight_int4pack,
128-
groupsize,
129-
scales_and_zeros,
130-
)
128+
if TORCH_VERSION_AT_LEAST_2_6:
129+
c = torch.ops.aten._weight_int4pack_mm_for_cpu(
130+
input,
131+
weight_int4pack,
132+
groupsize,
133+
scales_and_zeros,
134+
)
135+
else:
136+
c = torch.ops.aten._weight_int4pack_mm(
137+
input,
138+
weight_int4pack,
139+
groupsize,
140+
scales_and_zeros,
141+
)
131142
new_shape = origin_input_size[:-1] + (out_features,)
132143
c = c.reshape(new_shape)
133144
return c
@@ -608,10 +619,16 @@ def load_model_and_state_dict(
608619
if load_state_dict:
609620
q, s, z = Q4_0.unpack(t)
610621
scales_and_zeros = pack_scales_and_zeros(s, z)
611-
q_uint8 = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
612-
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
613-
q_uint8, inner_k_tiles
614-
)
622+
q_tmp = q
623+
if is_device(q.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
624+
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
625+
q_tmp, inner_k_tiles
626+
)
627+
else:
628+
q_tmp = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
629+
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
630+
q_tmp, inner_k_tiles
631+
)
615632
state_dict[f"{fqn}.weight"] = weight_int4pack
616633
state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
617634

0 commit comments

Comments
 (0)