Commit fbdd08c

Address quantization failures on devices (#204)
* add device for quantization, enable embedding quant with device
* typo
* fix filename weirdness
* enable mps embedding table runs
* import os for basename
* fix extraneous updates with group_size
1 parent 55e7583 commit fbdd08c
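
For orientation, here is a minimal sketch of the call pattern this commit introduces: quantize_model now takes the target device as its second argument and threads it into every quant handler (builder.py below passes builder_args.device). The import path, toy model, and options dict here are illustrative assumptions, not taken from the repo.

import json

import torch
import torch.nn as nn

from quantize import quantize_model  # assumed import path for the module changed below

# Toy model and options, purely for illustration.
model = nn.Sequential(nn.Embedding(128, 16), nn.Linear(16, 16))
device = "cuda" if torch.cuda.is_available() else "cpu"
quantize_options = json.loads('{"embedding": {"bitwidth": 8, "groupsize": 8}}')

# New signature: device sits between the model and the quantization options.
quantize_model(model, device, quantize_options)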

6 files changed: +53 −35 lines
.github/workflows/compile_t4.yml

Lines changed: 7 additions & 7 deletions
@@ -52,13 +52,13 @@ jobs:
         echo "******************************************"
         echo "******* Emb: channel-wise quantized ******"
         echo "******************************************"
-        # python generate.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-        # cat ./output_eager
-        # python generate.py --device cuda --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
-        # cat ./output_compiled
-        # python export.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
-        # python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
-        # cat ./output_aoti
+        python generate.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+        cat ./output_eager
+        python generate.py --device cuda --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
+        cat ./output_compiled
+        python export.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
+        python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
+        cat ./output_aoti
 
         echo "******************************************"
         echo "******** Emb: group-wise quantized *******"

.github/workflows/test_mps-dtype.yml

Lines changed: 5 additions & 5 deletions
@@ -52,14 +52,14 @@ jobs:
 
         python generate.py --dtype ${DTYPE} --device mps --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
         cat ./output_eager
-        # python generate.py --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+        # python generate.py --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
         # cat ./output_eager
-        # python generate.py --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+        # python generate.py --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
         # cat ./output_eager
-        # python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+        # python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
         # cat ./output_eager
-        # python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+        # python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
         # cat ./output_eager
-        # PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int4" : {"group_size": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+        # PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
         # cat ./output_eager
       done

.github/workflows/test_mps.yml

Lines changed: 7 additions & 7 deletions
@@ -48,14 +48,14 @@ jobs:
 
         python generate.py --device mps --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
         cat ./output_eager
-        # python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-        # cat ./output_eager
-        # python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-        # cat ./output_eager
-        # python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+        python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+        cat ./output_eager
+        python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+        cat ./output_eager
+        # python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
         # cat ./output_eager
-        # python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+        # python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
         # cat ./output_eager
-        # PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --device mps --quant '{"linear:int4" : {"group_size": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+        # PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
         # cat ./output_eager
 
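
The group_size → groupsize renames in the two MPS workflows matter because the value of --quant is parsed as JSON and its inner keys are splatted into the quant handlers, whose keyword arguments are named groupsize (see the handler signatures in quantize.py further down). A small illustrative Python snippet, not from the repo, that builds the corrected flag values used above:

import json

# Corrected configs: the handlers accept "groupsize", not "group_size".
configs = {
    "emb_channelwise": {"embedding": {"bitwidth": 8, "groupsize": 0}},
    "emb_groupwise": {"embedding": {"bitwidth": 8, "groupsize": 8}},
    "linear_int8": {"linear:int8": {"bitwidth": 8, "groupsize": 0}},
    "linear_int4": {"linear:int4": {"groupsize": 32}},
}

for name, cfg in configs.items():
    # Paste the printed string after --quant, e.g.:
    #   python generate.py --device mps --quant '<printed value>' ...
    print(name, json.dumps(cfg))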

build/builder.py

Lines changed: 1 addition & 1 deletion
@@ -259,7 +259,7 @@ def _initialize_model(
 
     if quantize:
         t0q = time.time()
-        quantize_model(model, quantize)
+        quantize_model(model, builder_args.device, quantize)
         device_sync(device=builder_args.device)
         print(f"Time to quantize model: {time.time() - t0q:.02f} seconds")
 
generate.py

Lines changed: 3 additions & 2 deletions
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import itertools
 import sys
+import os
 import time
 from pathlib import Path
 from typing import Optional, Tuple
@@ -333,9 +334,9 @@ def _main(
     set_precision(builder_args.precision)
     is_speculative = speculative_builder_args.checkpoint_path is not None
 
-    is_chat = "chat" in str(builder_args.checkpoint_path)
+    is_chat = "chat" in str(os.path.basename(builder_args.checkpoint_path))
     if is_chat:
-        raise RuntimeError("need to stop filename based kludgery, at a minimum need to look at all pathnames. yuck!")
+        raise RuntimeError("need to stop filename based kludgery, at a minimum need to look at all pathnames. in particular, this now fails because chat is part of the pathname, yuck!")
 
     tokenizer = _initialize_tokenizer(tokenizer_args)
 
quantize.py

Lines changed: 30 additions & 13 deletions
@@ -55,7 +55,7 @@ def name_to_dtype(name):
 ##########################################################################
 ### process quantization dictionary ###
 
-def quantize_model(model: nn.Module, quantize_options):
+def quantize_model(model: nn.Module, device, quantize_options):
     """
     Quantize the specified model using the quantizers described by
     a quantization dict of the form:
@@ -74,6 +74,7 @@ def quantize_model(model: nn.Module, quantize_options):
         if quantizer == "embedding":
             model = EmbeddingOnlyInt8QuantHandler(
                 model,
+                device,
                 **q_kwargs
             ).quantized_model()
         elif linears_quantized:
@@ -82,30 +83,35 @@ def quantize_model(model: nn.Module, quantize_options):
             linears_quantized = True
             model = WeightOnlyInt8QuantHandler(
                 model,
+                device,
                 **q_kwargs
             ).quantized_model()
         elif quantizer == "linear:int4":
             linears_quantized = True
             model = WeightOnlyInt4QuantHandler(
                 model,
+                device,
                 **q_kwargs
             ).quantized_model()
         elif quantizer == "linear:a8w4dq":
             linears_quantized = True
             model = Int8DynActInt4WeightQuantHandler(
                 model,
+                device,
                 **q_kwargs
             ).quantized_model()
         elif quantizer == "linear:gptq":
             linears_quantized = True
             model = WeightOnlyInt4GPTQQuantHandler(
                 model,
+                device,
                 **q_kwargs
             ).quantized_model()
         elif quantizer == "linear:hqq":
             linears_quantized = True
             model = WeightOnlyInt4HqqQuantHandler(
                 model,
+                device,
                 **q_kwargs
             ).quantized_model()
         elif quantizer == "precision":
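
Each branch above now passes device as the second positional argument before splatting the per-quantizer options as keyword arguments. As an illustration of what one entry of the options dict expands to (the import path and toy model are assumptions, not from the repo):

import torch.nn as nn

from quantize import EmbeddingOnlyInt8QuantHandler  # assumed import path

model = nn.Sequential(nn.Embedding(128, 16), nn.Linear(16, 16))
q_kwargs = {"bitwidth": 8, "groupsize": 8}  # the value under the "embedding" key

# Roughly what the "embedding" branch of quantize_model does after this commit.
model = EmbeddingOnlyInt8QuantHandler(model, "cpu", **q_kwargs).quantized_model()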
@@ -371,12 +377,14 @@ class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
         mod,
+        device,
         *,
         node_type: str = "*",
         bitwidth: Optional[int] = None,
         groupsize: Optional[int] = None,
     ):
         self.mod = mod
+        self.device = device,
         self.groupsize = groupsize
         self.node_type = node_type
         if bitwidth is None:
@@ -494,7 +502,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
 
 def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, bitwidth: int = 8, groupsize: Optional[int] = None, packed = False
+    module, device, bitwidth: int = 8, groupsize: Optional[int] = None, packed = False
 ):
     for name, child in module.named_children():
         # print(f"name: {name}")
@@ -505,6 +513,7 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
                 module,
                 name,
                 QuantizedGroupEmbedding(
+                    device=device,
                     vocab_size=child.weight.shape[0],
                     embedding_dim=child.weight.shape[1],
                     groupsize=groupsize,
@@ -518,10 +527,11 @@
 
 
 class EmbeddingOnlyInt8QuantHandler(QuantHandler):
-    def __init__(self, mod, *, bitwidth: int = 8, groupsize: Optional[int] = None, packed = False):
+    def __init__(self, mod, device, *, bitwidth: int = 8, groupsize: Optional[int] = None, packed = False):
         if isinstance(packed, str):
             packed = (packed == "True")
         self.mod = mod
+        self.device = device
         self.groupsize = groupsize
         self.bitwidth = bitwidth
         self.packed = packed
@@ -565,7 +575,7 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
 
             if packed:
                 if weight.shape[-1] %2 != 0:
-                    raise RUntimeError("automatic padding not implemented yet")
+                    raise RuntimeError("automatic padding not implemented yet")
 
                 weight_range_shifted = weight.add(8).view(torch.uint8)
                 weight_view = weight_range_shifted.view(
@@ -578,6 +588,8 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                 weight_packed = weight_even + weight_odd
                 weight = weight_packed
 
+            weight = weight.to(device=self.device)
+            scales = scales.to(device=self.device)
             # Update state dict
             cur_state_dict[f"{fqn}.weight"] = weight
             # squeeze makes groupsize=rowsize unidimensional
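
The two added lines above materialize the quantized weight and its scales on the target device before they are written into the state dict, so they match the device of the buffers that QuantizedGroupEmbedding now allocates. For intuition, a generic sketch of group-wise int8 quantization with device placement; this is illustrative, not the repo's exact routine:

import torch

def int8_groupwise_quantize(weight: torch.Tensor, groupsize: int, device: str):
    # Symmetric per-group int8 quantization of a 2-D embedding table,
    # with the results placed on the requested device.
    vocab_size, dim = weight.shape
    w = weight.float().reshape(vocab_size, dim // groupsize, groupsize)
    scales = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
    q = torch.clamp(torch.round(w / scales), -128, 127).to(torch.int8)
    q = q.reshape(vocab_size, dim).to(device=device)
    scales = scales.reshape(vocab_size, -1).to(dtype=torch.float16, device=device)
    return q, scales

q, s = int8_groupwise_quantize(torch.randn(64, 16), groupsize=8, device="cpu")
print(q.shape, q.dtype, s.shape, s.dtype)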
@@ -587,7 +599,7 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
 
     def convert_for_runtime(self) -> nn.Module:
         replace_embedding_weight_only_grouped_int8_per_channel(
-            self.mod, self.bitwidth, self.groupsize, self.packed
+            self.mod, self.device, self.bitwidth, self.groupsize, self.packed
         )
         return self.mod
 
@@ -601,10 +613,10 @@ def quantized_model(self) -> nn.Module:
 class QuantizedGroupEmbedding(torch.nn.Module):
     def __init__(
         self,
+        device,
         vocab_size: int,
         embedding_dim: int,
         groupsize: Optional[int] = None,
-        device=None,
         dtype=torch.half,
         packed=False,
     ) -> None:
@@ -616,20 +628,20 @@ def __init__(
         self.packed = packed
         if not packed:
             self.register_buffer(
-                "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
+                "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8, device=device)
             )
         else: # packed
             self.register_buffer(
-                "weight", torch.empty((vocab_size, embedding_dim//2), dtype=torch.uint8)
+                "weight", torch.empty((vocab_size, embedding_dim//2), dtype=torch.uint8, device=device)
             )
         groups_per_row = (embedding_dim + groupsize - 1) // groupsize
         if groups_per_row > 1:
             self.register_buffer(
-                "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16)
+                "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16, device=device)
             )
         else:
             self.register_buffer(
-                "scales", torch.ones((vocab_size,), dtype=torch.float16)
+                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
             )
 
     @torch.no_grad()
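
With device moved to the front of the constructor, the weight and scales buffers are allocated on the target device at construction time. A small hedged usage sketch; the import path and sizes are assumptions:

import torch

from quantize import QuantizedGroupEmbedding  # assumed import path

device = "mps" if torch.backends.mps.is_available() else "cpu"
emb = QuantizedGroupEmbedding(device, vocab_size=1024, embedding_dim=256, groupsize=32)

# Both buffers live on the requested device from the start.
print(emb.weight.device, emb.scales.device)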
@@ -712,8 +724,9 @@ def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_c
 
 
 class WeightOnlyInt4QuantHandler(QuantHandler):
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
+    def __init__(self, mod, device, *, groupsize=128, inner_k_tiles=8, padding_allowed=True):
         self.mod = mod
+        self.device = device,
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
         self.padding_allowed = padding_allowed
@@ -908,12 +921,15 @@ class Int8DynActInt4WeightQuantHandler(QuantHandler):
     def __init__(
         self,
         mod,
+        device,
+        * ,
         groupsize=256,
         padding_allowed=False,
         precision=torch.float32,
         scales_precision=torch.float32,
     ):
         self.mod = mod
+        self.device = device
         self.groupsize = groupsize
         self.padding_allowed = padding_allowed
         self.precision = precision
@@ -1209,9 +1225,10 @@ def convert_for_runtime(self) -> "nn.Module":
 
 
 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+    def __init__(self, mod, device, *, groupsize=128, inner_k_tiles=8, padding=True):
         from build.model import find_multiple
         self.mod = mod
+        self.device = device
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
         self.padding = padding
@@ -1329,7 +1346,7 @@ def quantized_model(self) -> nn.Module:
 ### WIP: HQQ ###
 
 class WeightOnlyInt4HqqQuantHandler:
-    def __init__(self, mod, groupsize):
+    def __init__(self, mod, device, *, groupsize):
         self.mod = mod
         self.groupsize = groupsize
 