Read SpinQuant checkpoints #5259

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions examples/models/llama2/export_llama_lib.py
@@ -688,6 +688,7 @@ def _load_llama_model(
fairseq2=weight_type == WeightType.FAIRSEQ2,
max_seq_len=max_seq_len,
enable_dynamic_shape=enable_dynamic_shape,
args=args,
)
state_dict = model.state_dict()
dtype = state_dict[next(iter(state_dict))].dtype
@@ -740,9 +741,26 @@ def _get_source_transforms(
transforms = []
if args.quantization_mode:
modelname = f"{modelname}_q"
transforms.append(
get_quant_weight_transform(args, dtype_override, verbose_export())
)
if args.use_spin_quant is None:
transforms.append(
get_quant_weight_transform(args, dtype_override, verbose_export())
)
# For SpinQuant, the checkpoint is already quantized, i.e. the
# weights come with their corresponding scales, so there is no
# need to apply the quantization transform here. However, the
# model structure still has to be changed to match the
# checkpoint format; transform_for_spinquant() applies those
# transformations later, in model.py.
elif args.use_spin_quant == "cuda":
from .source_transformation.spin_quant import (
inject_fast_hadamard_transform_cuda_for_spin_quant,
)

transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
elif args.use_spin_quant == "native":
raise NotImplementedError("native SpinQuant is not implemented yet.")

if args.embedding_quantize:
modelname = f"{modelname}_e"
Expand Down Expand Up @@ -776,15 +794,4 @@ def _get_source_transforms(
transforms.append(replace_sdpa_with_simple_sdpa)
transforms.append(replace_causal_mask)

if args.use_spin_quant:
if args.use_spin_quant == "cuda":
from .source_transformation.spin_quant import (
inject_fast_hadamard_transform_cuda_for_spin_quant,
)

transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)

elif args.use_spin_quant == "native":
raise NotImplementedError("native SpinQuant is not implemented yet.")

return transforms
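For reference, every entry appended to transforms is a callable that takes an nn.Module and returns an nn.Module; a minimal sketch of how such a list is typically applied in order (apply_source_transforms is an illustrative name, not a helper from this diff):

from typing import Callable, List
from torch import nn

def apply_source_transforms(
    model: nn.Module,
    transforms: List[Callable[[nn.Module], nn.Module]],
) -> nn.Module:
    # Each transform may rewrite the module in place and/or return a new one,
    # so transforms compose left to right in the order they were appended.
    for transform in transforms:
        model = transform(model)
    return model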
47 changes: 44 additions & 3 deletions examples/models/llama2/model.py
@@ -65,6 +65,7 @@ def __init__(self, **kwargs):
self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)

self.max_seq_len = kwargs.get("max_seq_len", 128)
self.args = kwargs.get("args", None)
# The example is using a dummy small model with random weights for demo purpose only.
# Follow the instruction in https://github.com/facebookresearch/llama to download the model
device = "cpu"
@@ -126,7 +127,8 @@ def __init__(self, **kwargs):
# get checkpoint dtype
self.dtype = None
if len(checkpoint) > 0:
first = checkpoint[next(iter(checkpoint))]
first_key = next(iter(checkpoint))
first = checkpoint[first_key]
self.dtype = first.dtype
mismatched_dtypes = [
(key, value.dtype)
@@ -135,7 +137,7 @@
]
if len(mismatched_dtypes) > 0:
print(
f"Mixed dtype model. Dtype of {first.key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
)
with open(params_path, "r") as f:
params = json.loads(f.read())
@@ -179,15 +181,54 @@ def __init__(self, **kwargs):
self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
self.model_
)
elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
print("Using SPIN quantization.")
assert hasattr(self.args, "group_size"), "group_size must be specified"
assert hasattr(
self.args, "quantization_mode"
), "quantization_mode must be specified"
assert hasattr(
self.args, "dtype_override"
), "dtype_override must be specified"
from .source_transformation.spin_quant import (
sanitize_checkpoint_from_spinquant,
transform_for_spinquant,
)

mapping = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}

self.model_ = transform_for_spinquant(
self.model_,
checkpoint,
self.args.group_size,
self.args.quantization_mode,
mapping[self.args.dtype_override],
)

sanitize_checkpoint_from_spinquant(
checkpoint,
self.args.group_size,
)

# assign=True: load params/buffers by assignment instead of performing an in-place copy.
# Because we are using device="meta", tensors do not have memory associated with them
# and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
self.model_.load_state_dict(
missing, unexpected = self.model_.load_state_dict(
checkpoint,
strict=False,
assign=True,
) # self.model_ = Transformer(gptconf)
if kwargs.get("verbose", False):
print("============= missing keys ================")
print(missing)
print("============= /missing ================")
print("============= unexpected keys ================")
print(unexpected)
print("============= /unexpected ================")

def get_eager_model(self):
if self.dtype:
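The assign=True change above matters because the Transformer is constructed on the meta device; a minimal standalone sketch of that pattern, assuming a PyTorch version with load_state_dict(assign=True) support (the layer and checkpoint here are illustrative):

import torch
from torch import nn

# Build the module without allocating real storage for its parameters.
with torch.device("meta"):
    layer = nn.Linear(4, 4, bias=False)

checkpoint = {"weight": torch.randn(4, 4)}

# An in-place copy_ into meta tensors would be a no-op, so assign=True makes
# load_state_dict replace the meta parameters with the checkpoint tensors.
missing, unexpected = layer.load_state_dict(checkpoint, strict=False, assign=True)
print(missing, unexpected)   # expected: [] []
print(layer.weight.device)   # expected: cpu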
93 changes: 93 additions & 0 deletions examples/models/llama2/source_transformation/spin_quant.py
@@ -9,12 +9,16 @@
# Helper functions for transforming the model to be able to run SpinQuant.
# See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant.

from typing import Any

import torch

import torch.nn.functional as F

from executorch.examples.models.llama2.llama_transformer import FeedForward
from torch import nn
from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter


def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
@@ -53,3 +57,92 @@ def inject_fast_hadamard_transform_cuda_for_spin_quant(
) -> torch.nn.Module:
_inject_fast_hadamard_transform_cuda_for_spin_quant(module)
return module


def _replace_linear_with_linear_8da4w_for_spin_quant(
module: torch.nn.Module,
checkpoint: Any,
group_size: int,
precision: torch.dtype,
scales_precision: torch.dtype,
):
def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
# Only replace linear layers where the checkpoint contains explicit scales
scales_key = f"{cur_fqn}.scale"
if isinstance(child, nn.Linear) and scales_key in checkpoint:
assert _check_linear_int4_k(child.in_features, group_size)
assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
assert checkpoint[scales_key].dtype == scales_precision
return True
return False

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
new_linear = Int8DynActInt4WeightLinear(
child.in_features,
child.out_features,
bias=False,
device=child.weight.device,
groupsize=group_size,
precision=precision,
scales_precision=scales_precision,
)
return new_linear

_replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_for_spinquant(
module: torch.nn.Module,
checkpoint: Any,
group_size: int,
quantization_mode: str,
dtype: torch.dtype,
) -> torch.nn.Module:
"""
Transform the model to be able to load SpinQuant checkpoints that
are quantized with the given group size and quantization mode.
"""

if group_size not in [32, 64, 128, 256]:
raise ValueError(f"Group size {group_size} is not supported for SpinQuant.")
if quantization_mode not in ["8da4w"]:
raise ValueError(
f"Quantization mode {quantization_mode} is not compatible with SpinQuant."
)
_replace_linear_with_linear_8da4w_for_spin_quant(
module,
checkpoint,
group_size,
dtype,
dtype,
)
return module


def sanitize_checkpoint_from_spinquant(
checkpoint: Any,
group_size: int,
):
"""
Sanitize the SpinQuant checkpoint.
- Renames 'scale' to 'scales'
- Groups scales
- Removes 'o_weight'
- Converts all tensors to contiguous format
"""
keys_to_rename = []
keys_to_remove = []
for k, _ in checkpoint.items():
if k.endswith(".scale"):
new_key = k + "s"
keys_to_rename.append((k, new_key))
if k.endswith(".o_weight"):
keys_to_remove.append(k)

for old_key, new_key in keys_to_rename:
old_val = checkpoint.pop(old_key)
checkpoint[new_key] = old_val if group_size == -1 else old_val[:, ::group_size]
for k in keys_to_remove:
checkpoint.pop(k)
for k, v in checkpoint.items():
checkpoint[k] = v.contiguous()
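The [:, ::group_size] slicing above assumes the SpinQuant checkpoint stores scales expanded to one value per weight column, repeated within each group; a small sketch of what the sanitizer does to one such entry (the key, shapes, and values are made up for illustration):

import torch

group_size = 4

# Hypothetical expanded scales: one value per column, repeated group_size
# times within each group.
per_group = torch.tensor([[0.1, 0.2], [0.3, 0.4]])         # (out, in // group_size)
expanded = per_group.repeat_interleave(group_size, dim=1)  # (out, in)

checkpoint = {"layers.0.attention.wq.scale": expanded}

# Roughly what sanitize_checkpoint_from_spinquant does for this entry: rename
# '.scale' to '.scales' and keep every group_size-th column to recover the
# per-group scales.
scales = checkpoint.pop("layers.0.attention.wq.scale")[:, ::group_size]
checkpoint["layers.0.attention.wq.scales"] = scales.contiguous()

print(torch.equal(checkpoint["layers.0.attention.wq.scales"], per_group))  # True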
13 changes: 13 additions & 0 deletions examples/models/llama2/tests/TARGETS
@@ -13,3 +13,16 @@ python_unittest(
"//executorch/examples/models/llama2:llama_transformer",
],
)

python_unittest(
name = "test_spinquant_transforms",
srcs = [
"test_spinquant_transforms.py",
],
deps = [
"//caffe2:torch",
"//executorch/examples/models/llama2:export_library",
"//executorch/examples/models/llama2:llama_transformer",
"//pytorch/ao:torchao",
],
)
89 changes: 89 additions & 0 deletions examples/models/llama2/tests/test_spinquant_transforms.py
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer
from executorch.examples.models.llama2.source_transformation.spin_quant import (
sanitize_checkpoint_from_spinquant,
transform_for_spinquant,
)
from torchao.quantization.utils import group_quantize_tensor_symmetric


class SpinQuantTests(unittest.TestCase):
def test_transforms_for_spinquant(self):

# Step 1: Create llama class with dummy weights
params = {
"dim": 768,
"multiple_of": 32,
"n_heads": 12,
"n_layers": 12,
"norm_eps": 1e-05,
"vocab_size": 32000,
}

model_args = ModelArgs(
max_seq_len=2048,
max_batch_size=1,
use_kv_cache=False,
use_sdpa_with_kv_cache_op=False,
generate_full_logits=False,
enable_dynamic_shape=True,
**params,
)

model = Transformer(model_args)
checkpoint = model.state_dict()

# Step 2:
# Do group-wise quantization and amend the checkpoint with
# int8 weights and fp32 scales.
group_size = 32
n_bit = 4
scales_precision = torch.float32
for fqn, mod in model.named_modules():
# Quantize everything except the last layer
if isinstance(mod, torch.nn.Linear) and ("output" not in fqn):
weight = mod.weight.data
(
weight_int8,
scales,
zeros,
) = group_quantize_tensor_symmetric(
weight.to(torch.float32), n_bit, group_size, scales_precision
)
checkpoint[f"{fqn}.weight"] = weight_int8.to("cpu")
checkpoint[f"{fqn}.scale"] = scales.to("cpu")

# Step 3:
# Transform the model so that it is compatible with the new checkpoint
transform_for_spinquant(
model,
checkpoint,
32,
"8da4w",
torch.float32,
)
sanitize_checkpoint_from_spinquant(
checkpoint,
-1,
)

model.load_state_dict(
checkpoint,
strict=False,
assign=True,
)

new_checkpoint = model.state_dict()

for k, v in checkpoint.items():
# new_checkpoint also contains freshly initialized 'zeros' buffers that the
# sanitized checkpoint does not, so iterate over the checkpoint's keys.
self.assertTrue(torch.allclose(new_checkpoint[k], v))
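The strict=False load and the comparison over the checkpoint's keys are needed because the converted linear layers carry more state than the sanitized checkpoint provides; a quick sketch, assuming torchao's Int8DynActInt4WeightLinear registers 'weight', 'scales', and 'zeros' buffers (construction arguments mirror replacement_fn in spin_quant.py above, the sizes are arbitrary):

import torch
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

linear = Int8DynActInt4WeightLinear(
    256,
    256,
    bias=False,
    device="cpu",
    groupsize=32,
    precision=torch.float32,
    scales_precision=torch.float32,
)

# The SpinQuant checkpoint carries only quantized weights and scales, so the
# zero-points stay at their freshly initialized values; new_checkpoint therefore
# has keys that checkpoint does not, and the test iterates over checkpoint.
print(sorted(linear.state_dict().keys()))  # expected to include 'scales', 'weight', 'zeros'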
2 changes: 2 additions & 0 deletions pytest.ini
@@ -38,6 +38,8 @@ addopts =
test/end2end/test_end2end.py
--ignore=backends/xnnpack/test/ops/linear.py
--ignore=backends/xnnpack/test/models/llama2_et_example.py
# T200992559: Add torchao to ET as core dependency
--ignore=examples/models/llama2/tests/test_spinquant_transforms.py
--ignore=exir/backend/test/demos
--ignore=exir/backend/test/test_backends.py
--ignore=exir/backend/test/test_backends_lifted.py