Commit d80f78f

mergennachin authored and facebook-github-bot committed
Read SpinQuant checkpoints (#5259)
Summary:
Pull Request resolved: #5259

Read SpinQuant checkpoints that are exported with scales/weights.

bypass-github-export-checks
bypass-github-pytorch-ci-checks
bypass-github-executorch-ci-checks

Reviewed By: iseeyuan, helunwencser

Differential Revision: D62403094

fbshipit-source-id: 283ae18a1d2053306677086b9edd5cb5f38120ee
1 parent 41b463e commit d80f78f

File tree

6 files changed, +262 -17 lines

examples/models/llama2/export_llama_lib.py (+21 -14)

@@ -695,6 +695,7 @@ def _load_llama_model(
         fairseq2=weight_type == WeightType.FAIRSEQ2,
         max_seq_len=max_seq_len,
         enable_dynamic_shape=enable_dynamic_shape,
+        args=args,
     )
     state_dict = model.state_dict()
     dtype = state_dict[next(iter(state_dict))].dtype

@@ -747,9 +748,26 @@ def _get_source_transforms(
     transforms = []
     if args.quantization_mode:
         modelname = f"{modelname}_q"
-        transforms.append(
-            get_quant_weight_transform(args, dtype_override, verbose_export())
-        )
+        if args.use_spin_quant is None:
+            transforms.append(
+                get_quant_weight_transform(args, dtype_override, verbose_export())
+            )
+        # For SpinQuant, the checkpoints are already quantized, i.e. the
+        # weights come with their corresponding scales, so we do not need to
+        # apply the quantization transform. However, we still need to apply
+        # the transformations that change the model structure to match the
+        # checkpoint format; transform_for_spinquant() applies those later,
+        # in model.py.
+        elif args.use_spin_quant == "cuda":
+            from .source_transformation.spin_quant import (
+                inject_fast_hadamard_transform_cuda_for_spin_quant,
+            )
+
+            transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
+        elif args.use_spin_quant == "native":
+            raise NotImplementedError("native SpinQuant is not implemented yet.")

     if args.embedding_quantize:
         modelname = f"{modelname}_e"

@@ -783,15 +801,4 @@ def _get_source_transforms(
         transforms.append(replace_sdpa_with_simple_sdpa)
         transforms.append(replace_causal_mask)

-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
-            from .source_transformation.spin_quant import (
-                inject_fast_hadamard_transform_cuda_for_spin_quant,
-            )
-
-            transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-
-        elif args.use_spin_quant == "native":
-            raise NotImplementedError("native SpinQuant is not implemented yet.")
-
     return transforms
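
Note: each entry appended to `transforms` above is a callable that takes the model and returns a (possibly modified) model, which is how `inject_fast_hadamard_transform_cuda_for_spin_quant` is used. As a rough sketch of how such a list could be consumed, the helper below folds the transforms over the model in order; `apply_source_transforms` is a hypothetical name used for illustration, not a function from this diff.

```python
from typing import Callable, List

import torch


def apply_source_transforms(
    model: torch.nn.Module,
    transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
) -> torch.nn.Module:
    # Apply each source transform in the order it was appended.
    for transform in transforms:
        model = transform(model)
    return model
```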

examples/models/llama2/model.py (+44 -3)

@@ -65,6 +65,7 @@ def __init__(self, **kwargs):
         self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)

         self.max_seq_len = kwargs.get("max_seq_len", 128)
+        self.args = kwargs.get("args", None)
         # The example is using a dummy small model with random weights for demo purpose only.
         # Follow the instruction in https://github.com/facebookresearch/llama to download the model
         device = "cpu"

@@ -126,7 +127,8 @@ def __init__(self, **kwargs):
         # get checkpoint dtype
         self.dtype = None
         if len(checkpoint) > 0:
-            first = checkpoint[next(iter(checkpoint))]
+            first_key = next(iter(checkpoint))
+            first = checkpoint[first_key]
             self.dtype = first.dtype
             mismatched_dtypes = [
                 (key, value.dtype)

@@ -135,7 +137,7 @@ def __init__(self, **kwargs):
             ]
             if len(mismatched_dtypes) > 0:
                 print(
-                    f"Mixed dtype model. Dtype of {first.key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
+                    f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
                 )
         with open(params_path, "r") as f:
             params = json.loads(f.read())

@@ -179,15 +181,54 @@ def __init__(self, **kwargs):
             self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                 self.model_
             )
+        elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
+            print("Using SPIN quantization.")
+            assert hasattr(self.args, "group_size"), "group_size must be specified"
+            assert hasattr(
+                self.args, "quantization_mode"
+            ), "quantization_mode must be specified"
+            assert hasattr(
+                self.args, "dtype_override"
+            ), "dtype_override must be specified"
+            from .source_transformation.spin_quant import (
+                sanitize_checkpoint_from_spinquant,
+                transform_for_spinquant,
+            )
+
+            mapping = {
+                "fp32": torch.float32,
+                "fp16": torch.float16,
+                "bf16": torch.bfloat16,
+            }
+
+            self.model_ = transform_for_spinquant(
+                self.model_,
+                checkpoint,
+                self.args.group_size,
+                self.args.quantization_mode,
+                mapping[self.args.dtype_override],
+            )
+
+            sanitize_checkpoint_from_spinquant(
+                checkpoint,
+                self.args.group_size,
+            )

         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
         # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
-        self.model_.load_state_dict(
+        missing, unexpected = self.model_.load_state_dict(
             checkpoint,
             strict=False,
             assign=True,
         ) # self.model_ = Transformer(gptconf)
+        if kwargs.get("verbose", False):
+            print("============= missing keys ================")
+            print(missing)
+            print("============= /missing ================")
+            print("============= unexpected keys ================")
+            print(unexpected)
+            print("============= /unexpected ================")

     def get_eager_model(self):
         if self.dtype:
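
Note: the verbose key dump above works because `torch.nn.Module.load_state_dict(..., strict=False)` returns the missing and unexpected key lists instead of raising. A minimal standalone sketch (the `Tiny` module and checkpoint below are made up for illustration):

```python
import torch


class Tiny(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)


# A checkpoint that is missing one key and carries one extra key.
ckpt = {"linear.weight": torch.zeros(4, 4), "extra.bias": torch.zeros(4)}

# With strict=False the result is (missing_keys, unexpected_keys),
# which is exactly what the verbose branch above prints.
missing, unexpected = Tiny().load_state_dict(ckpt, strict=False)
print(missing)     # ['linear.bias']
print(unexpected)  # ['extra.bias']
```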

examples/models/llama2/source_transformation/spin_quant.py (+93)

@@ -9,12 +9,16 @@
 # Helper functions for transforming the model to be able to run SpinQuant.
 # See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant.

+from typing import Any
+
 import torch

 import torch.nn.functional as F

 from executorch.examples.models.llama2.llama_transformer import FeedForward
 from torch import nn
+from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter


 def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):

@@ -53,3 +57,92 @@ def inject_fast_hadamard_transform_cuda_for_spin_quant(
 ) -> torch.nn.Module:
     _inject_fast_hadamard_transform_cuda_for_spin_quant(module)
     return module
+
+
+def _replace_linear_with_linear_8da4w_for_spin_quant(
+    module: torch.nn.Module,
+    checkpoint: Any,
+    group_size: int,
+    precision: torch.dtype,
+    scales_precision: torch.dtype,
+):
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        # Only replace linear layers where the checkpoint contains explicit scales
+        scales_key = f"{cur_fqn}.scale"
+        if isinstance(child, nn.Linear) and scales_key in checkpoint:
+            assert _check_linear_int4_k(child.in_features, group_size)
+            assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
+            assert checkpoint[scales_key].dtype == scales_precision
+            return True
+        return False
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        new_linear = Int8DynActInt4WeightLinear(
+            child.in_features,
+            child.out_features,
+            bias=False,
+            device=child.weight.device,
+            groupsize=group_size,
+            precision=precision,
+            scales_precision=scales_precision,
+        )
+        return new_linear
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
+
+def transform_for_spinquant(
+    module: torch.nn.Module,
+    checkpoint: Any,
+    group_size: int,
+    quantization_mode: str,
+    dtype: torch.dtype,
+) -> torch.nn.Module:
+    """
+    Transform the model so that it can load SpinQuant checkpoints that
+    are quantized with the given group size and quantization mode.
+    """
+
+    if group_size not in [32, 64, 128, 256]:
+        raise ValueError(f"Group size {group_size} is not supported for SpinQuant.")
+    if quantization_mode not in ["8da4w"]:
+        raise ValueError(
+            f"Quantization mode {quantization_mode} is not compatible with SpinQuant."
+        )
+    _replace_linear_with_linear_8da4w_for_spin_quant(
+        module,
+        checkpoint,
+        group_size,
+        dtype,
+        dtype,
+    )
+    return module
+
+
+def sanitize_checkpoint_from_spinquant(
+    checkpoint: Any,
+    group_size: int,
+):
+    """
+    Sanitize the SpinQuant checkpoint.
+    - Renames 'scale' to 'scales'
+    - Groups scales
+    - Removes 'o_weight'
+    - Converts all tensors to contiguous format
+    """
+    keys_to_rename = []
+    keys_to_remove = []
+    for k, _ in checkpoint.items():
+        if k.endswith(".scale"):
+            new_key = k + "s"
+            keys_to_rename.append((k, new_key))
+        if k.endswith(".o_weight"):
+            keys_to_remove.append(k)
+
+    for old_key, new_key in keys_to_rename:
+        old_val = checkpoint.pop(old_key)
+        checkpoint[new_key] = old_val if group_size == -1 else old_val[:, ::group_size]
+    for k in keys_to_remove:
+        checkpoint.pop(k)
+    for k, v in checkpoint.items():
+        checkpoint[k] = v.contiguous()
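
Note: a quick way to see what `sanitize_checkpoint_from_spinquant` does is to run it on a toy dict; the key names and shapes below are illustrative only, not taken from a real SpinQuant export. The `.scale` entry is renamed to `.scales` and subsampled to one value per group, the `.o_weight` entry is dropped, and the remaining tensors are made contiguous.

```python
import torch

from executorch.examples.models.llama2.source_transformation.spin_quant import (
    sanitize_checkpoint_from_spinquant,
)

# Toy checkpoint with per-column scales for a 2x8 int8 weight plus an
# o_weight entry that the sanitizer should remove (names are made up).
checkpoint = {
    "layers.0.attention.wq.weight": torch.zeros(2, 8, dtype=torch.int8),
    "layers.0.attention.wq.scale": torch.ones(2, 8),
    "layers.0.attention.o_weight": torch.zeros(2, 8),
}

sanitize_checkpoint_from_spinquant(checkpoint, group_size=4)

print(sorted(checkpoint.keys()))
# ['layers.0.attention.wq.scales', 'layers.0.attention.wq.weight']
print(checkpoint["layers.0.attention.wq.scales"].shape)  # torch.Size([2, 2])
```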

examples/models/llama2/tests/TARGETS (+13)

@@ -13,3 +13,16 @@ python_unittest(
         "//executorch/examples/models/llama2:llama_transformer",
     ],
 )
+
+python_unittest(
+    name = "test_spinquant_transforms",
+    srcs = [
+        "test_spinquant_transforms.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama2:export_library",
+        "//executorch/examples/models/llama2:llama_transformer",
+        "//pytorch/ao:torchao",
+    ],
+)
examples/models/llama2/tests/test_spinquant_transforms.py (new file, +89)

@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer
+from executorch.examples.models.llama2.source_transformation.spin_quant import (
+    sanitize_checkpoint_from_spinquant,
+    transform_for_spinquant,
+)
+from torchao.quantization.utils import group_quantize_tensor_symmetric
+
+
+class SpinQuantTests(unittest.TestCase):
+    def test_transforms_for_spinquant(self):
+
+        # Step 1: Create a llama model with dummy weights
+        params = {
+            "dim": 768,
+            "multiple_of": 32,
+            "n_heads": 12,
+            "n_layers": 12,
+            "norm_eps": 1e-05,
+            "vocab_size": 32000,
+        }
+
+        model_args = ModelArgs(
+            max_seq_len=2048,
+            max_batch_size=1,
+            use_kv_cache=False,
+            use_sdpa_with_kv_cache_op=False,
+            generate_full_logits=False,
+            enable_dynamic_shape=True,
+            **params,
+        )
+
+        model = Transformer(model_args)
+        checkpoint = model.state_dict()
+
+        # Step 2:
+        # Do group-wise quantization and amend the checkpoint with
+        # int8 weights and fp32 scales
+        group_size = 32
+        n_bit = 4
+        scales_precision = torch.float32
+        for fqn, mod in model.named_modules():
+            # Quantize everything except the last layer
+            if isinstance(mod, torch.nn.Linear) and ("output" not in fqn):
+                weight = mod.weight.data
+                (
+                    weight_int8,
+                    scales,
+                    zeros,
+                ) = group_quantize_tensor_symmetric(
+                    weight.to(torch.float32), n_bit, group_size, scales_precision
+                )
+                checkpoint[f"{fqn}.weight"] = weight_int8.to("cpu")
+                checkpoint[f"{fqn}.scale"] = scales.to("cpu")
+
+        # Step 3:
+        # Transform the model so that it is compatible with the new checkpoint
+        transform_for_spinquant(
+            model,
+            checkpoint,
+            32,
+            "8da4w",
+            torch.float32,
+        )
+        sanitize_checkpoint_from_spinquant(
+            checkpoint,
+            -1,
+        )
+
+        model.load_state_dict(
+            checkpoint,
+            strict=False,
+            assign=True,
+        )
+
+        new_checkpoint = model.state_dict()
+
+        for k, v in checkpoint.items():
+            # new_checkpoint also contains the quantizer's zeros buffers,
+            # so iterate over the sanitized checkpoint's keys instead.
+            self.assertTrue(torch.allclose(new_checkpoint[k], v))
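
Note: Step 2 of the test relies on `group_quantize_tensor_symmetric` from torchao to produce the int8 weight and per-group scales that the SpinQuant path expects. A minimal standalone call is sketched below, using the same positional arguments as the test; the shape comment is my expectation (one scale per group of 32 inputs), not something the test asserts.

```python
import torch
from torchao.quantization.utils import group_quantize_tensor_symmetric

# 4-bit symmetric, group-wise quantization of a toy fp32 weight.
weight = torch.randn(8, 64)
weight_int8, scales, zeros = group_quantize_tensor_symmetric(
    weight, 4, 32, torch.float32
)

print(weight_int8.dtype)  # torch.int8 (holding the 4-bit values)
print(scales.shape)       # expected: torch.Size([8, 2]), one scale per group
```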

pytest.ini (+2)

@@ -38,6 +38,8 @@ addopts =
     test/end2end/test_end2end.py
     --ignore=backends/xnnpack/test/ops/linear.py
     --ignore=backends/xnnpack/test/models/llama2_et_example.py
+    # T200992559: Add torchao to ET as core dependency
+    --ignore=examples/models/llama2/tests/test_spinquant_transforms.py
     --ignore=exir/backend/test/demos
     --ignore=exir/backend/test/test_backends.py
     --ignore=exir/backend/test/test_backends_lifted.py