# Add sparsify API to torchao #473
@@ -47,25 +47,41 @@ For int4 we make heavy use of [tinygemm](https://github.com/pytorch/ao/blob/cb3b
And a quick crash course on inference quantization to help parse the above table. Int4 quantization is an ambiguous term because there's the dtype in which a layer is represented and then the dtype in which the computation is done. For example, if you're using Weight-Only (wo) int4 quantization, the weight is stored in int4 but upcast to a larger dtype like fp16 at compute time, so the int4 matrix multiplication is defined as `F.linear(input, weight.to(input.dtype))`. Dynamic quantization (DQ) primarily targets activations, enabling on-the-fly quantization from higher precision formats like bf16 to lower precision formats such as int8. This process, when supported by hardware, allows for direct computation, such as performing `F.linear(input, weight)`. Naive quantization algorithms are also notoriously sensitive to outliers, so we typically set a group size that applies a scale factor per group of 64 elements in the case of `int4wo64`.
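To make the distinction concrete, here is a minimal sketch of the two compute paths, assuming simple per-tensor scales for brevity (the real `int4wo64` kernels use one scale per group of 64 weight elements, and the function names below are illustrative rather than torchao APIs):

```python
import torch
import torch.nn.functional as F

# Weight-only (wo): the weight is stored in a low-bit dtype and dequantized /
# upcast back to the activation dtype right before the matmul runs.
def weight_only_linear(x, w_low_bit, w_scale):
    w = w_low_bit.to(x.dtype) * w_scale       # dequantize to bf16/fp16
    return F.linear(x, w)                     # compute stays in bf16/fp16

# Dynamic quantization (DQ): the activation is quantized on the fly, so the
# matmul itself can run in the low-precision dtype when hardware supports it.
def dynamic_quant_linear(x, w_int8, w_scale):
    x_scale = x.abs().amax() / 127.0
    x_int8 = torch.clamp((x / x_scale).round(), -127, 127).to(torch.int8)
    acc = x_int8.to(torch.int32) @ w_int8.to(torch.int32).t()  # integer matmul
    return acc.to(x.dtype) * x_scale * w_scale                 # rescale the int32 accumulator
```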
Sparsifying your model is also a one-liner that should work on any model with an `nn.Linear`. We find that sparsity works best on compute-bound models like SAM, specifically the MLP layers.
```python
from torchao.sparsity import sparsify
from torch.sparse import to_sparse_semi_structured

m = sparsify(m, to_sparse_semi_structured)
```
Sparsity can also be composed with int8 dynamic quantization for further speedups:
```python
from torchao.sparsity import sparsify
from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight

m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight())
```

* 9.5x speedups for Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai) compared to vanilla [sam](https://github.com/facebookresearch/segment-anything).
* 1.16x speedup when composing int8 quantization with 2:4 sparsity against the accelerated baseline `bfloat16` dtype and `torch.compile="max_autotune"`.

**Review comment:** What's the plan to move this out of prototype? Is this related to composing sparsity and quant properly?

**Reply:** Yeah, I plan to move this out as part of 0.4, implementing a layout like here: https://github.com/pytorch/ao/compare/jcaip/affine-quantize-sparse?expand=1. This works, but I am running into a performance regression, so I need to debug that first before we can merge.

**Review comment:** We might have briefly chatted about this when we were discussing the quantize API, but just thinking out loud here: if I add …

**Reply:** Right now, the composition of int8 quantization and 2:4 sparsity is treated as its own distinct technique, so you can either go … or vice versa. Mathematically, how you apply the optimizations will matter, but I think we should make them the same so it doesn't matter for our API, for two reasons: …
We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + 2:4 sparsity to mlp layer 1, and 2:4 sparsity to mlp layer 2 yielded the best configuration.
We were able to provide a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIoU)**.
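A rough sketch of how such a mixed configuration can be expressed with the `sparsify` one-liner and per-module filter functions (the fully qualified names below are hypothetical, the real benchmark configuration lives in the SAM scripts referenced below, and the attention layers would additionally get int8 dynamic quantization through the quantize API, which is outside this diff):

```python
from torch.sparse import to_sparse_semi_structured
from torchao.sparsity import sparsify
from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight

# Hypothetical filters keyed on the module's fully qualified name (fqn);
# `model` is the SAM image encoder (ViT) being optimized.
def is_mlp_lin1(module, fqn):
    return "mlp.lin1" in fqn

def is_mlp_lin2(module, fqn):
    return "mlp.lin2" in fqn

# int8 dynamic quant + 2:4 sparsity for the first MLP linear ...
model = sparsify(model, int8_dynamic_activation_int8_2x4_sparse_weight(), is_mlp_lin1)
# ... and 2:4 sparsity only for the second MLP linear.
model = sparsify(model, to_sparse_semi_structured, is_mlp_lin2)
```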
The following benchmarks were run for [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast) ViT-h on an NVIDIA A100-80GB, with batch_size=32 and `bfloat16` dtype, with `torch.compile="max_autotune"`:
| Model Type | Technique                                                                                            | img/s | memory (MiB) | mIoU (coco2017 val) | relative speedup | relative accuracy |
|------------|------------------------------------------------------------------------------------------------------|-------|--------------|---------------------|------------------|-------------------|
| ViT-h      | baseline (bfloat16, max-autotune)                                                                      | 22.75 | 15172        | 0.5811              |                  |                   |
|            | int8 dynamic quant (attn + mlp)                                                                        | 24.91 | 15154        | 0.5822              | **1.09x**        | **100.19%**       |
|            | 2:4 sparsity (mlp only)                                                                                | 24.81 | 15632        | 0.5672              | **1.10x**        | **97.61%**        |
|            | 2:4 sparsity (attn + mlp)                                                                              | 24.30 | 13429        | 0.5306              | **1.07x**        | **91.31%**        |
|            | int8 dynamic quant (attn)<br>int8 dynamic quant + 2:4 sparsity (mlp lin1)<br>2:4 sparsity (mlp lin2)   | 26.46 | 14865        | 0.5668              | **1.16x**        | **97.54%**        |
To reproduce our benchmarks, please follow these [instructions](/scripts/sam/README.md).

The relative speedup is measured purely across the image encoder (ViT) of the model, where we apply our model optimizations. Benchmarks were run on an NVIDIA A100-80GB with batch_size=32.
#### With intrusive code changes

In some cases we rewrote popular GenAI models to be significantly faster in native PyTorch (i.e., no custom C++/CUDA) and achieved SOTA inference performance at the time. These involve more intrusive code changes.
* 8x speedups for Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai) (9.5x with int8 dynamic quantization + 2:4 sparsity)
* 10x speedups for Language models with [gpt-fast](https://pytorch.org/blog/accelerating-generative-ai-2)
* 3x speedup for Diffusion models with [sd-fast](https://pytorch.org/blog/accelerating-generative-ai-3)
@@ -3,9 +3,10 @@

import torch
from torch import nn
from torch.sparse import to_sparse_semi_structured

from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured
from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
from torchao.sparsity import apply_fake_sparsity, sparsify
from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
    _get_subclass_inserter,

@@ -37,7 +38,7 @@ def test_sparse(self):
apply_fake_sparsity(model)
dense_result = model(input)

apply_sparse_semi_structured(model)
model = sparsify(model, to_sparse_semi_structured)
**Review comment:** Is `sparsify` an in-place op? This came up recently since …

**Reply:** Yeah, I think there's some discussion on whether we can use …

**Reply:** We can use the in-place version for now, I think.
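For reference, the `sparsify` implementation added at the bottom of this PR both mutates the module it is given and returns it, so the two call styles used in these tests behave the same. A small sketch (assumes a GPU with 2:4 sparse kernel support, like the tests themselves):

```python
import torch.nn as nn
from torch.sparse import to_sparse_semi_structured
from torchao.sparsity import apply_fake_sparsity, sparsify

m = nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 128)).half().cuda()
apply_fake_sparsity(m)   # prune to a 2:4 pattern first, as the test does

# sparsify converts the weights on the module passed in and then returns that
# same module, so reassigning the return value (as test_sparse does) and calling
# it purely for the side effect (as test_quant_semi_sparse does) are equivalent.
ret = sparsify(m, to_sparse_semi_structured)
assert ret is m          # same object; the conversion already happened in place
```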
sparse_result = model(input)

assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
@@ -61,7 +62,7 @@ def test_quant_semi_sparse(self):
apply_fake_sparsity(model)
dense_result = model(input)

_replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight), _is_linear)
sparsify(model, int8_dynamic_activation_int8_2x4_sparse_weight())
sparse_result = model(input)

assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1)
@@ -1,7 +1,13 @@
from typing import Callable, Optional

import torch
from torch.ao.pruning import WeightNormSparsifier
from torch.sparse import to_sparse_semi_structured
from torchao.quantization.quant_api import _is_linear
from torchao.quantization.quant_api import (
    _is_linear,
    _replace_with_custom_fn_if_matches_filter,
    _get_linear_subclass_inserter,
)

# Sparsity helper functions
def apply_fake_sparsity(model, **kwargs):
@@ -24,10 +30,44 @@ def apply_fake_sparsity(model, **kwargs):
    sparsifier.squash_mask()
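Most of `apply_fake_sparsity` sits outside this hunk; for orientation, a hedged sketch of what magnitude-based 2:4 "fake" sparsity with `WeightNormSparsifier` generally looks like (the exact config and module filter here are assumptions, not necessarily what torchao ships):

```python
import torch
from torch.ao.pruning import WeightNormSparsifier

def fake_sparsify_2x4(model):
    # Zero out 2 of every 4 elements in each weight row by magnitude, keeping the
    # dense layout so accuracy can be checked before converting to a sparse format.
    sparsifier = WeightNormSparsifier(
        sparsity_level=1.0,          # apply the block pattern everywhere
        sparse_block_shape=(1, 4),   # consider groups of 4 elements along a row
        zeros_per_block=2,           # zero 2 of every 4 -> the 2:4 pattern
    )
    config = [
        {"tensor_fqn": f"{name}.weight"}
        for name, mod in model.named_modules()
        if isinstance(mod, torch.nn.Linear)
    ]
    sparsifier.prepare(model, config)
    sparsifier.step()          # compute the masks
    sparsifier.squash_mask()   # fold the masks back into the weights
    return model
```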

def apply_sparse_semi_structured(model, **kwargs):
    filter_fn = kwargs.pop("filter_fn", _is_linear)

    apply_fake_sparsity(model, filter_fn=filter_fn)
    for name, mod in model.named_modules():
        if filter_fn(mod, name):
            mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))

def sparsify(model: torch.nn.Module,
             apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor],
             filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None) -> torch.nn.Module:
    """Convert the weight of linear modules in the model with `apply_tensor_subclass`.
    This function is essentially the same as quantize, but for sparsity subclasses.

    Currently, we support two options for sparsity:
        - semi-structured (2:4) sparsity with `to_sparse_semi_structured`
        - int8 dynamic quantization + 2:4 sparsity with `int8_dynamic_activation_int8_2x4_sparse_weight`, which is also available via the quantize API
**Review comment:** This is where I'm a bit confused; it's not clear from reading the docstrings whether the quantize and sparsify APIs compose.
    Args:
        model (torch.nn.Module): input model
        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that converts a floating point Tensor to a (sparsified) tensor subclass instance (e.g. affine quantized tensor instance)
        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes an nn.Module instance and the fully qualified name of the module, and returns True if we want to run `apply_tensor_subclass` on the weight of the module

    Example::
        import torch
        import torch.nn as nn
        from torchao.sparsity import sparsify

        def filter_fn(module: nn.Module, fqn: str) -> bool:
            return isinstance(module, nn.Linear)

        m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))

        # for 2:4 sparsity
        from torch.sparse import to_sparse_semi_structured
        m = sparsify(m, to_sparse_semi_structured, filter_fn)

        # for int8 dynamic quantization + 2:4 sparsity
        from torchao.sparsity.prototype import int8_dynamic_activation_int8_2x4_sparse_weight
        m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight(), filter_fn)
    """
    _replace_with_custom_fn_if_matches_filter(
        model,
        _get_linear_subclass_inserter(apply_tensor_subclass),
        _is_linear if filter_fn is None else filter_fn,
    )

    return model
**Review comment:** Also, this one is using `to_sparse_semi_structured` while the other one is using `int8_dynamic_activation_int8_2x4_sparse_weight()`, which might be a bit confusing. I'd suggest to just align them.

**Reply:** OK, I'll add a `semi_sparse_weight()` wrapper function.
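A minimal sketch of what such a wrapper could look like, assuming its only job is to match the callable-returning style of `int8_dynamic_activation_int8_2x4_sparse_weight()` (not necessarily how it ends up being implemented):

```python
from torch.sparse import to_sparse_semi_structured

def semi_sparse_weight():
    # Return the weight-conversion callable so 2:4 sparsity is requested the same
    # way as the quantization-style configs: m = sparsify(m, semi_sparse_weight())
    return to_sparse_semi_structured
```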