
Commit a35a1cd

Add sparsify API to torchao (#473)
* Add sparsify API to torchao
* fix typo
1 parent a2e8e2a commit a35a1cd

File tree (6 files changed, +104 -34 lines):

* README.md
* scripts/sam/eval_combo.py
* test/sparsity/test_sparse_api.py
* torchao/sparsity/__init__.py
* torchao/sparsity/prototype/dynamic_quant_sparse.py
* torchao/sparsity/sparse_api.py

README.md

Lines changed: 27 additions & 11 deletions
@@ -47,25 +47,41 @@ For int4 we make heavy use of [tinygemm](https://github.com/pytorch/ao/blob/cb3b
 
 And a quick crash course on inference quantization to help parse the above table. Int4 quantization is an ambiguous term because there's the dtype in which a layer is represented and then the dtype in which the computation is done. For example, if you're using Weight-Only (wo) int4 quantization, the int4 weight is upcast to a larger dtype like fp16, so an int4 matrix multiplication is defined as `F.linear(input, weight.to(input.dtype))`. Dynamic quantization (DQ) primarily targets activations, enabling on-the-fly quantization from higher precision formats like bf16 to lower precision formats such as int8. This process, when supported by hardware, allows for direct computation, such as performing `F.linear(input, weight)`. Naive quantization algorithms are also notoriously sensitive to outliers, so we typically also set a group size that applies a scale factor per group, e.g. per 64 elements in the case of `int4wo64`.
 
+Sparsifying your model is also a one-liner that should work on any model with an `nn.Linear`. We find that sparsity works best on compute-bound models like SAM, specifically the MLP layers.
+```python
+from torchao.sparsity import sparsify
+from torch.sparse import to_sparse_semi_structured
 
-#### With intrusive code changes
+m = sparsify(m, to_sparse_semi_structured)
+```
+Sparsity can also be composed with int8 dynamic quantization for further speedups:
 
-In some cases we rewrote popular GenAI models to be significantly faster in native PyTorch as in no C++/CUDA to achieve at the time SOTA inference performance. These involve more intrusive code changes.
+```python
+from torchao.sparsity import sparsify
+from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
 
-* 9.5x speedups for Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai) compared to vanilla [sam](https://github.com/facebookresearch/segment-anything).
-* 1.16x speedup when composing int8 quantization with 2:4 sparsity against the accelerated baseline `bfloat16` dtype and `torch.compile="max_autotune"`.
+m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight())
+```
+We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + 2:4 sparsity to MLP layer 1, and 2:4 sparsity to MLP layer 2 yielded the best configuration.
+We were able to provide a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIoU)**.
+
+The following benchmarks were run for [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast) ViT-h on an NVIDIA-A100-80GB, with batch_size=32, `bfloat16` dtype, and `torch.compile="max_autotune"`:
 
 | Model Type | Technique                                                                                            | img/s | memory (MiB) | mIoU (coco2017 val) | relative speedup | relative accuracy |
 |------------|------------------------------------------------------------------------------------------------------|-------|--------------|---------------------|------------------|-------------------|
-| ViT-h      | sam (float32, eager)                                                                                 | 2.78  | 28806        | 0.58                | baseline         | baseline          |
-|            | sam (bfloat16, eager)                                                                                | 14.85 | 14424        | 0.58                | **5.34x**        | **100%**          |
-|            | sam-fast (bfloat16, max-autotune)                                                                    | 22.75 | 15172        | 0.58                | **8.18x**        | **100%**          |
-|            | int8 dynamic quant (attn + mlp)                                                                      | 24.91 | 15154        | 0.58                | **8.96x**        | **100%**          |
-|            | 2:4 sparsity (mlp only)                                                                              | 24.81 | 15632        | 0.57                | **8.92x**        | **98%**           |
-|            | int8 dynamic quant (attn)<br>int8 dynamic quant + 2:4 sparsity (mlp lin1)<br>2:4 sparsity (mlp lin2) | 26.46 | 14865        | 0.57                | **9.52x**        | **98%**           |
+| ViT-h      | baseline (bfloat16, max-autotune)                                                                    | 22.75 | 15172        | 0.5811              |                  |                   |
+|            | int8 dynamic quant (attn + mlp)                                                                      | 24.91 | 15154        | 0.5822              | **1.09x**        | **100.19%**       |
+|            | 2:4 sparsity (mlp only)                                                                              | 24.81 | 15632        | 0.5672              | **1.10x**        | **97.61%**        |
+|            | 2:4 sparsity (attn + mlp)                                                                            | 24.30 | 13429        | 0.5306              | **1.07x**        | **91.31%**        |
+|            | int8 dynamic quant (attn)<br>int8 dynamic quant + 2:4 sparsity (mlp lin1)<br>2:4 sparsity (mlp lin2) | 26.46 | 14865        | 0.5668              | **1.16x**        | **97.54%**        |
+
+To reproduce our benchmarks please follow these [instructions](/scripts/sam/README.md).
 
-The relative speedup is measured purely across the image encoder (ViT) of the model, where we apply our model optimizations. Benchmarks ran on an NVIDIA-A100-80GB with batch_size=32
+#### With intrusive code changes
+
+In some cases we rewrote popular GenAI models to be significantly faster in native PyTorch, with no C++/CUDA, achieving (at the time) SOTA inference performance. These involve more intrusive code changes.
 
+* 8x speedups for Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai) (9.5x with int8 dynamic quantization + 2:4 sparsity)
 * 10x speedups for Language models with [gpt-fast](https://pytorch.org/blog/accelerating-generative-ai-2)
 * 3x speedup for Diffusion models with [sd-fast](https://pytorch.org/blog/accelerating-generative-ai-3)
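Putting the README snippets above together, here is a minimal end-to-end sketch of the new API. The toy model, shapes, and device handling are illustrative assumptions rather than part of the commit; 2:4 semi-structured kernels need a recent CUDA GPU and fp16/bf16 weights:

```python
import torch
import torch.nn as nn
from torch.sparse import to_sparse_semi_structured
from torchao.sparsity import apply_fake_sparsity, sparsify

# Illustrative toy model; any model containing nn.Linear layers works.
model = nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 128)).half().cuda()

# Prune weights to a 2:4 pattern first so the semi-structured
# conversion below is lossless.
apply_fake_sparsity(model)

# The one-liner from the README: swap each nn.Linear weight for a
# semi-structured sparse tensor subclass instance.
model = sparsify(model, to_sparse_semi_structured)

x = torch.randn(32, 128, dtype=torch.float16, device="cuda")
out = model(x)  # runs against the 2:4 sparse kernels
```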

scripts/sam/eval_combo.py

Lines changed: 20 additions & 10 deletions
@@ -286,14 +286,20 @@ def run(
    elif compress == "sparse_mlp_only":
        def mlp_only(mod, name):
            return isinstance(mod, torch.nn.Linear) and 'mlp' in name
-        from torchao.sparsity import apply_sparse_semi_structured
-        apply_sparse_semi_structured(predictor.model.image_encoder, filter_fn=mlp_only)
+        from torchao.sparsity import sparsify, apply_fake_sparsity
+        from torch.sparse import to_sparse_semi_structured
+        apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only)
+        predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured, filter_fn=mlp_only)
    elif compress == "sparse":
-        from torchao.sparsity import apply_sparse_semi_structured
-        apply_sparse_semi_structured(predictor.model.image_encoder)
+        from torchao.sparsity import sparsify, apply_fake_sparsity
+        from torch.sparse import to_sparse_semi_structured
+        apply_fake_sparsity(predictor.model.image_encoder)
+        predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured)
    elif compress == "int8_dynamic_quant_sparse":
-        from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight
-        from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured
+        from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
+        SparseSemiStructuredTensor._FORCE_CUTLASS = False
+        from torchao.sparsity import sparsify, apply_fake_sparsity
+        from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
        from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
        from torchao.utils import unwrap_tensor_subclass

@@ -306,6 +312,7 @@ def mlp_lin2_only(mod, name):
        def mlp_only(mod, name):
            return isinstance(mod, torch.nn.Linear) and 'mlp' in name

+        # apply fake sparsity first to set qparams
        apply_fake_sparsity(predictor.model.image_encoder,
                            filter_fn=mlp_only)

@@ -314,10 +321,13 @@ def mlp_only(mod, name):
                                                 attn_only)
        predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)

-        predictor.model.image_encoder = quantize(predictor.model.image_encoder,
-                                                 Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight.from_float,
-                                                 mlp_lin1_only)
-        apply_sparse_semi_structured(predictor.model.image_encoder, filter_fn=mlp_lin2_only)
+        predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
+                                                 int8_dynamic_activation_int8_2x4_sparse_weight(),
+                                                 mlp_lin1_only)
+
+        predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
+                                                 to_sparse_semi_structured,
+                                                 mlp_lin2_only)
    else:
        assert compress is None, f"Unsupported compress mode {compress}"
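Read top to bottom, the `int8_dynamic_quant_sparse` branch composes three techniques in a fixed order. Below is a condensed, self-contained sketch of that recipe; `ToyBlock` and the `attn`/`lin1`/`lin2` filter definitions are our paraphrase (only `mlp_only` is spelled out in the diff), and we assume a CUDA GPU with 2:4 sparse kernel support:

```python
import torch
import torch.nn as nn
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torchao.sparsity import apply_fake_sparsity, sparsify
from torchao.sparsity.prototype.dynamic_quant_sparse import (
    int8_dynamic_activation_int8_2x4_sparse_weight,
)
from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
from torchao.utils import unwrap_tensor_subclass

SparseSemiStructuredTensor._FORCE_CUTLASS = False  # prefer cuSPARSELt kernels

class ToyBlock(nn.Module):
    """Stand-in for one transformer block of the SAM image encoder."""
    def __init__(self, dim=256):
        super().__init__()
        self.attn_proj = nn.Linear(dim, dim)
        self.mlp_lin1 = nn.Linear(dim, 4 * dim)
        self.mlp_lin2 = nn.Linear(4 * dim, dim)
    def forward(self, x):
        return self.mlp_lin2(torch.relu(self.mlp_lin1(self.attn_proj(x))))

def attn_only(mod, name):
    return isinstance(mod, nn.Linear) and 'attn' in name

def mlp_lin1_only(mod, name):
    return isinstance(mod, nn.Linear) and 'lin1' in name

def mlp_lin2_only(mod, name):
    return isinstance(mod, nn.Linear) and 'lin2' in name

def mlp_only(mod, name):
    return isinstance(mod, nn.Linear) and 'mlp' in name

encoder = ToyBlock().to(torch.bfloat16).cuda()  # stands in for predictor.model.image_encoder

# 1. Prune all MLP linears to a 2:4 pattern up front, so the quantization
#    observers below see the pruned weights.
apply_fake_sparsity(encoder, filter_fn=mlp_only)
# 2. int8 dynamic quantization on the attention projections.
encoder = quantize(encoder, int8_dynamic_activation_int8_weight(), attn_only)
encoder = unwrap_tensor_subclass(encoder)
# 3. int8 dynamic quantization + 2:4 sparsity on the first MLP linear...
encoder = sparsify(encoder, int8_dynamic_activation_int8_2x4_sparse_weight(), mlp_lin1_only)
# 4. ...and plain 2:4 sparsity on the second.
encoder = sparsify(encoder, to_sparse_semi_structured, mlp_lin2_only)
```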

test/sparsity/test_sparse_api.py

Lines changed: 5 additions & 4 deletions
@@ -3,9 +3,10 @@

 import torch
 from torch import nn
+from torch.sparse import to_sparse_semi_structured

-from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured
-from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
+from torchao.sparsity import apply_fake_sparsity, sparsify
+from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
     _get_subclass_inserter,

@@ -37,7 +38,7 @@ def test_sparse(self):
         apply_fake_sparsity(model)
         dense_result = model(input)

-        apply_sparse_semi_structured(model)
+        model = sparsify(model, to_sparse_semi_structured)
         sparse_result = model(input)

         assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

@@ -61,7 +62,7 @@ def test_quant_semi_sparse(self):
         apply_fake_sparsity(model)
         dense_result = model(input)

-        _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight), _is_linear)
+        sparsify(model, int8_dynamic_activation_int8_2x4_sparse_weight())
         sparse_result = model(input)

         assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1)

torchao/sparsity/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -6,11 +6,11 @@

 from .wanda import WandaSparsifier  # noqa: F403
 from .utils import PerChannelNormObserver  # noqa: F403
-from .sparse_api import apply_sparse_semi_structured, apply_fake_sparsity
+from .sparse_api import apply_fake_sparsity, sparsify

 __all__ = [
     "WandaSparsifier",
     "PerChannelNormObserver",
-    "apply_sparse_semi_structured",
     "apply_fake_sparsity",
+    "sparsify",
 ]

torchao/sparsity/prototype/dynamic_quant_sparse.py

Lines changed: 3 additions & 0 deletions
@@ -309,3 +309,6 @@ def from_float(cls, input_float, qmin=-128, qmax=127):
             input_float.shape,
             dtype=input_float.dtype,
         )
+
+def int8_dynamic_activation_int8_2x4_sparse_weight():
+    return Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight.from_float
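The new factory mirrors the callable style of the `quantize` API: it returns the subclass's `from_float` constructor, which `sparsify` (or `quantize`) then maps over each matching linear weight. A brief hedged sketch; the toy model and device setup are illustrative:

```python
import torch.nn as nn
from torchao.sparsity import apply_fake_sparsity, sparsify
from torchao.sparsity.prototype.dynamic_quant_sparse import (
    int8_dynamic_activation_int8_2x4_sparse_weight,
)

# The factory returns
# Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight.from_float,
# i.e. a Callable[[Tensor], Tensor] in the shape sparsify() expects.
constructor = int8_dynamic_activation_int8_2x4_sparse_weight()

model = nn.Sequential(nn.Linear(128, 128)).half().cuda()  # illustrative
apply_fake_sparsity(model)            # prune weights to a 2:4 pattern first
model = sparsify(model, constructor)
```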

torchao/sparsity/sparse_api.py

Lines changed: 47 additions & 7 deletions
@@ -1,7 +1,13 @@
+from typing import Callable, Optional
+
 import torch
 from torch.ao.pruning import WeightNormSparsifier
 from torch.sparse import to_sparse_semi_structured
-from torchao.quantization.quant_api import _is_linear
+from torchao.quantization.quant_api import (
+    _is_linear,
+    _replace_with_custom_fn_if_matches_filter,
+    _get_linear_subclass_inserter,
+)

 # Sparsity helper functions
 def apply_fake_sparsity(model, **kwargs):

@@ -24,10 +30,44 @@ def apply_fake_sparsity(model, **kwargs):
     sparsifier.squash_mask()


-def apply_sparse_semi_structured(model, **kwargs):
-    filter_fn = kwargs.pop("filter_fn", _is_linear)
+def sparsify(model: torch.nn.Module,
+             apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor],
+             filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None) -> torch.nn.Module:
+    """Convert the weight of linear modules in the model with `apply_tensor_subclass`.
+    This function is essentially the same as `quantize`, but for sparsity subclasses.

-    apply_fake_sparsity(model, filter_fn=filter_fn)
-    for name, mod in model.named_modules():
-        if filter_fn(mod, name):
-            mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
+    Currently we support two options for sparsity:
+        - semi-structured (2:4) sparsity with `to_sparse_semi_structured`
+        - int8 dynamic quantization + 2:4 sparsity with `int8_dynamic_activation_int8_2x4_sparse_weight`, which is also available via the quantize API
+
+    Args:
+        model (torch.nn.Module): input model
+        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that converts a floating point Tensor to a (sparsified) tensor subclass instance (e.g. an affine quantized tensor instance)
+        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes an nn.Module instance and the fully qualified name of the module, and returns True if we want to run `apply_tensor_subclass` on the weight of the module
+
+    Example::
+        import torch
+        import torch.nn as nn
+        from torchao.sparsity import sparsify
+
+        def filter_fn(module: nn.Module, fqn: str) -> bool:
+            return isinstance(module, nn.Linear)
+
+        m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
+
+        # for 2:4 sparsity
+        from torch.sparse import to_sparse_semi_structured
+        m = sparsify(m, to_sparse_semi_structured, filter_fn)
+
+        # for int8 dynamic quantization + 2:4 sparsity
+        from torchao.sparsity.prototype import int8_dynamic_activation_int8_2x4_sparse_weight
+        m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight(), filter_fn)
+    """
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        _get_linear_subclass_inserter(apply_tensor_subclass),
+        _is_linear if filter_fn is None else filter_fn,
+    )
+
+    return model
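For readers diffing old against new: `sparsify` now routes through torchao's private replacement helpers, but its per-module effect is essentially the removed loop, generalized over the weight constructor. A reconstruction of that equivalence (our sketch based on the removed `apply_sparse_semi_structured` loop, not the actual internals):

```python
import torch

def sparsify_sketch(model, apply_tensor_subclass, filter_fn=None):
    # Rough equivalent of sparsify(); the real implementation delegates to
    # _replace_with_custom_fn_if_matches_filter and _get_linear_subclass_inserter.
    if filter_fn is None:
        filter_fn = lambda mod, name: isinstance(mod, torch.nn.Linear)
    for name, mod in model.named_modules():
        if filter_fn(mod, name):
            # Same move as the removed loop: wrap the converted weight
            # back up as a Parameter on the matching module.
            mod.weight = torch.nn.Parameter(apply_tensor_subclass(mod.weight))
    return model
```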
