
Add sparsify API to torchao #473

Merged (2 commits) on Jul 5, 2024
38 changes: 27 additions & 11 deletions README.md
@@ -47,25 +47,41 @@ For int4 we make heavy use of [tinygemm](https://github.com/pytorch/ao/blob/cb3b

And here's a quick crash course on inference quantization to help parse the above table. Int4 quantization is an ambiguous term because there's the dtype in which a layer is represented and then the dtype in which the computation is done. For example, if you're using Weight-Only (wo) int4 quantization, the weight is upcast to a larger dtype like fp16 at runtime, so the int4 matrix multiplication is defined as `F.linear(input, weight.to(input.dtype))`. Dynamic quantization (DQ) primarily targets activations, enabling on-the-fly quantization from higher-precision formats like bf16 to lower-precision formats such as int8. This process, when supported by hardware, allows for direct computation, such as performing `F.linear(input, weight)`. Naive quantization algorithms are also notoriously sensitive to outliers, so we typically set a group size that applies a scale factor per group, e.g. per group of 64 elements in the case of `int4wo64`.
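
To make that distinction concrete, here is a minimal sketch of the two compute patterns described above. The helper functions are illustrative only (they are not torchao APIs), and real kernels fuse these steps:

```python
import torch
import torch.nn.functional as F

def weight_only_linear(x, w_int, w_scale):
    # Weight-only (wo): w_int is an integer weight of shape (out, in) and
    # w_scale holds one scale per output channel, shape (out,). The weight is
    # upcast to the activation dtype, so the matmul itself runs in bf16/fp16.
    # For a grouped scheme like int4wo64 there would instead be one scale per
    # group of 64 elements along the `in` dimension.
    return F.linear(x, w_int.to(x.dtype) * w_scale[:, None])

def dynamic_quant_linear(x, w_int8, w_scale):
    # Dynamic quantization (dq): the activation is quantized on the fly, so on
    # supporting hardware the matmul itself can run in int8. Here the integer
    # matmul is emulated with an int32 accumulator.
    a_scale = x.abs().amax(dim=-1, keepdim=True) / 127
    x_int8 = torch.round(x / a_scale).clamp(-128, 127).to(torch.int8)
    acc = torch.matmul(x_int8.to(torch.int32), w_int8.to(torch.int32).t())
    return acc.to(x.dtype) * a_scale * w_scale
```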

Sparsifying your model is also a one-liner that should work on any model with an `nn.Linear`. We find that sparsity works best on compute-bound models like SAM, specifically the MLP layers.
```python
from torchao.sparsity import sparsify
from torch.sparse import to_sparse_semi_structured

m = sparsify(m, to_sparse_semi_structured)
```

Contributor: Also, this is using to_sparse_semi_structured while the other one is using int8_dynamic_activation_int8_2x4_sparse_weight(), which might be a bit confusing. I'd suggest to just align the two.

Contributor Author: OK, I'll add a semi_sparse_weight() wrapper function.
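
For reference, a wrapper like the one mentioned could plausibly be as thin as the following; this is a hypothetical sketch of that follow-up, not code that is part of this diff:

```python
# Hypothetical sketch (not in this PR): mirror int8_dynamic_activation_int8_2x4_sparse_weight()
# by returning the callable that sparsify() applies to each linear weight.
from torch.sparse import to_sparse_semi_structured

def semi_sparse_weight():
    return to_sparse_semi_structured
```
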
Sparsity can also be composed with int8 dynamic quantization for further speedups:

```python
from torchao.sparsity import sparsify
from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight())
```

Contributor: What's the plan to move this out of prototype? Is this related to composing sparsity and quant properly?

Contributor Author: Yeah, I plan to move this out as part of 0.4, implementing a layout like here: https://github.com/pytorch/ao/compare/jcaip/affine-quantize-sparse?expand=1

This works, but I am running into a performance regression, so I need to debug that first before we can merge.

Member: We might have briefly chatted about this when we were discussing the quantize API, but just thinking out loud here: if I add quantize(sparsify(m)), is that different from sparsify(quantize(m)), and if so, in what order (if any) are the optimizations applied in this case?

Contributor Author: Right now, the composition of int8 quantization and 2:4 sparsity is treated as its own distinct technique, so you can go either quantize(int8dynamic + 2:4 sparse) or sparsify(int8dynamic + 2:4 sparse). Once we implement sparsity as an AQTLayout, we can add support for a "composable" API, where we go

quantize(int8 dynamic)
sparsify(to_sparse_semi_structured)

or vice versa.

Mathematically, the order in which you apply the optimizations will matter, but I think we should make them the same so it doesn't matter for our API, for two reasons:

  1. Currently only quantize -> sparsify is supported; it would be extra work to support sparsify -> quantize.
  2. One of these orderings will be "better", and I can't really see a situation where the best order would differ across different layers. So we should always default to the "best" one and not give users an option to shoot themselves in the foot.

We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + 2:4 sparsity to mlp layer 1, and 2:4 sparsity to mlp layer 2 yielded the best configuration.
We were able to provide a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIOU)**.
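
As a rough sketch of how that mixed configuration can be expressed with the quantize/sparsify APIs in this PR (the toy `Block` model and filter functions below are illustrative stand-ins; the exact benchmark code is in `scripts/sam/eval_combo.py` in this diff):

```python
import torch
from torch.sparse import to_sparse_semi_structured
from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
from torchao.sparsity import sparsify, apply_fake_sparsity
from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight

# Toy stand-in for SAM's image encoder: one block with attn / mlp.lin1 / mlp.lin2
# linear layers, mirroring the names the filter functions match on.
class Block(torch.nn.Module):
    def __init__(self, dim=256):
        super().__init__()
        self.attn = torch.nn.Linear(dim, dim)
        self.mlp = torch.nn.ModuleDict({
            "lin1": torch.nn.Linear(dim, 4 * dim),
            "lin2": torch.nn.Linear(4 * dim, dim),
        })

model = Block().half().cuda()  # 2:4 sparse kernels need fp16/bf16 on GPU

def attn_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'attn' in name

def mlp_lin1_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'lin1' in name

def mlp_lin2_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'lin2' in name

def mlp_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'mlp' in name

# Prune the MLP weights to a 2:4 pattern first, then swap in the tensor subclasses.
apply_fake_sparsity(model, filter_fn=mlp_only)
model = quantize(model, int8_dynamic_activation_int8_weight(), attn_only)
model = sparsify(model, int8_dynamic_activation_int8_2x4_sparse_weight(), mlp_lin1_only)
model = sparsify(model, to_sparse_semi_structured, mlp_lin2_only)
```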

The following benchmarks were run for [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast) ViT-h on an NVIDIA-A100-80GB, with batch_size=32 and `bfloat16` dtype, with `torch.compile="max_autotune"`:

| Model Type | Technique | img/s | memory (MiB) | mIoU (coco2017 val) | relative speedup | relative accuracy |
|------------|------------------------------------------------------------------------------------------------------|-------|--------------|---------------------|------------------|-------------------|
| ViT-h | sam (float32, eager) | 2.78 | 28806 | 0.58 | baseline | baseline |
| | sam (bfloat16, eager) | 14.85 | 14424 | 0.58 | **5.34x** | **100%** |
| | sam-fast (bfloat16, max-autotune) | 22.75 | 15172 | 0.58 | **8.18x** | **100%** |
| | int8 dynamic quant (attn + mlp) | 24.91 | 15154 | 0.58 | **8.96x** | **100%** |
| | 2:4 sparsity (mlp only) | 24.81 | 15632 | 0.57 | **8.92x** | **98%** |
| | int8 dynamic quant (attn)<br>int8 dynamic quant + 2:4 sparsity (mlp lin1)<br>2:4 sparsity (mlp lin2) | 26.46 | 14865 | 0.57 | **9.52x** | **98%** |
| ViT-h | baseline (bfloat16, max-autotune) | 22.75 | 15172 | 0.5811 | | |
| | int8 dynamic quant (attn + mlp) | 24.91 | 15154 | 0.5822 | **1.09x** | **100.19%** |
| | 2:4 sparsity (mlp only) | 24.81 | 15632 | 0.5672 | **1.10x** | **97.61%** |
| | 2:4 sparsity (attn + mlp) | 24.30 | 13429 | 0.5306 | **1.07x** | **91.31%** |
| | int8 dynamic quant (attn)<br>int8 dynamic quant + 2:4 sparsity (mlp lin1)<br>2:4 sparsity (mlp lin2) | 26.46 | 14865 | 0.5668 | **1.16x** | **97.54%** |

To reproduce our benchmarks please follow these [instructions](/scripts/sam/README.md).

The relative speedup is measured purely across the image encoder (ViT) of the model, where we apply our model optimizations. Benchmarks were run on an NVIDIA-A100-80GB with batch_size=32.
#### With intrusive code changes

In some cases we rewrote popular GenAI models to be significantly faster in native PyTorch, as in no C++/CUDA, to achieve what was at the time SOTA inference performance. These involve more intrusive code changes.

* 8x speedups for image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai) (9.5x with int8 dynamic quantization + 2:4 sparsity)
* 10x speedups for Language models with [gpt-fast](https://pytorch.org/blog/accelerating-generative-ai-2)
* 3x speedup for Diffusion models with [sd-fast](https://pytorch.org/blog/accelerating-generative-ai-3)

30 changes: 20 additions & 10 deletions scripts/sam/eval_combo.py
@@ -286,14 +286,20 @@ def run(
elif compress == "sparse_mlp_only":
def mlp_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'mlp' in name
from torchao.sparsity import apply_sparse_semi_structured
apply_sparse_semi_structured(predictor.model.image_encoder, filter_fn=mlp_only)
from torchao.sparsity import sparsify, apply_fake_sparsity
from torch.sparse import to_sparse_semi_structured
apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only)
predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured, filter_fn=mlp_only)
elif compress == "sparse":
from torchao.sparsity import apply_sparse_semi_structured
apply_sparse_semi_structured(predictor.model.image_encoder)
from torchao.sparsity import sparsify, apply_fake_sparsity
from torch.sparse import to_sparse_semi_structured
apply_fake_sparsity(predictor.model.image_encoder)
predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured)
elif compress == "int8_dynamic_quant_sparse":
from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight
from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
SparseSemiStructuredTensor._FORCE_CUTLASS = False
from torchao.sparsity import sparsify, apply_fake_sparsity
from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
from torchao.utils import unwrap_tensor_subclass

@@ -306,6 +312,7 @@ def mlp_lin2_only(mod, name):
def mlp_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'mlp' in name

# apply fake sparsity first so that the quantization params are set on the pruned weights
apply_fake_sparsity(predictor.model.image_encoder,
filter_fn=mlp_only)

@@ -314,10 +321,13 @@ def mlp_only(mod, name):
attn_only)
predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)

predictor.model.image_encoder = quantize(predictor.model.image_encoder,
Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight.from_float,
mlp_lin1_only)
apply_sparse_semi_structured(predictor.model.image_encoder, filter_fn=mlp_lin2_only)
predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
int8_dynamic_activation_int8_2x4_sparse_weight(),
mlp_lin1_only, prune=False)

predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
to_sparse_semi_structured,
mlp_lin2_only, prune=False)
else:
assert compress is None, f"Unsupported compress mode {compress}"

9 changes: 5 additions & 4 deletions test/sparsity/test_sparse_api.py
@@ -3,9 +3,10 @@

import torch
from torch import nn
from torch.sparse import to_sparse_semi_structured

from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured
from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
from torchao.sparsity import apply_fake_sparsity, sparsify
from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
_get_subclass_inserter,
@@ -37,7 +38,7 @@ def test_sparse(self):
apply_fake_sparsity(model)
dense_result = model(input)

apply_sparse_semi_structured(model)
model = sparsify(model, to_sparse_semi_structured)

Member: Is sparsify an in-place op? This came up recently, since quantize is in-place, and here it looks like the API used to be in-place but now it's not.

Contributor Author: Yeah, I think there's some discussion on whether we can use quantize_ or quantize for the in-place op; I'm not sure if we came to a conclusion. cc @jerryzh168 do you have a preference for what to use here?

Contributor: We can use the in-place version for now, I think.

Contributor: @jcaip I think we can change this to sparsify_ to be consistent with quantization: #467

sparse_result = model(input)

assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
@@ -61,7 +62,7 @@ def test_quant_semi_sparse(self):
apply_fake_sparsity(model)
dense_result = model(input)

_replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight), _is_linear)
sparsify(model, int8_dynamic_activation_int8_2x4_sparse_weight())
sparse_result = model(input)

assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1)
4 changes: 2 additions & 2 deletions torchao/sparsity/__init__.py
@@ -6,11 +6,11 @@

from .wanda import WandaSparsifier # noqa: F403
from .utils import PerChannelNormObserver # noqa: F403
from .sparse_api import apply_sparse_semi_structured, apply_fake_sparsity
from .sparse_api import apply_fake_sparsity, sparsify

__all__ = [
"WandaSparsifier",
"PerChannelNormObserver",
"apply_sparse_semi_structured",
"apply_fake_sparsity",
"sparsify"
]
3 changes: 3 additions & 0 deletions torchao/sparsity/prototype/dynamic_quant_sparse.py
@@ -309,3 +309,6 @@ def from_float(cls, input_float, qmin=-128, qmax=127):
input_float.shape,
dtype=input_float.dtype,
)

def int8_dynamic_activation_int8_2x4_sparse_weight():
return Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight.from_float
54 changes: 47 additions & 7 deletions torchao/sparsity/sparse_api.py
@@ -1,7 +1,13 @@
from typing import Callable, Optional

import torch
from torch.ao.pruning import WeightNormSparsifier
from torch.sparse import to_sparse_semi_structured
from torchao.quantization.quant_api import _is_linear
from torchao.quantization.quant_api import (
_is_linear,
_replace_with_custom_fn_if_matches_filter,
_get_linear_subclass_inserter,
)

# Sparsity helper functions
def apply_fake_sparsity(model, **kwargs):
@@ -24,10 +30,44 @@ def apply_fake_sparsity(model, **kwargs):
sparsifier.squash_mask()


def apply_sparse_semi_structured(model, **kwargs):
filter_fn = kwargs.pop("filter_fn", _is_linear)
def sparsify(model: torch.nn.Module,
apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor],
filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None) -> torch.nn.Module:
"""Convert the weight of linear modules in the model with `apply_tensor_subclass`
This function is essentially the same as quantize, but for sparsity subclasses.

apply_fake_sparsity(model, filter_fn=filter_fn)
for name, mod in model.named_modules():
if filter_fn(mod, name):
mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
Currently, we support two options for sparsity:
- semi-structured (2:4) sparsity with `to_sparse_semi_structured`
- int8 dynamic quantization + 2:4 sparsity with `int8_dynamic_activation_int8_2x4_sparse_weight`, which is also available via the quantize API

Member: This is where I'm a bit confused; reading the docstrings, it's not clear whether the quantize and sparsify APIs compose.

Args:
model (torch.nn.Module): input model
apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (sparsified) tensor subclass instance (e.g. affine quantized tensor instance)
filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on
the weight of the module

Example::
import torch
import torch.nn as nn
from torchao.sparsity import sparsify

def filter_fn(module: nn.Module, fqn: str) -> bool:
return isinstance(module, nn.Linear)

m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))

# for 2:4 sparsity
from torch.sparse import to_sparse_semi_structured
m = sparsify(m, to_sparse_semi_structured, filter_fn)

# for int8 dynamic quantization + 2:4 sparsity
from torchao.sparsity.prototype import int8_dynamic_activation_int8_2x4_sparse_weight
m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight(), filter_fn)
"""
_replace_with_custom_fn_if_matches_filter(
model,
_get_linear_subclass_inserter(apply_tensor_subclass),
_is_linear if filter_fn is None else filter_fn,
)

return model