diff --git a/README.md b/README.md
index f0cb7e2c00..c7e67ddd9f 100644
--- a/README.md
+++ b/README.md
@@ -47,25 +47,41 @@ For int4 we make heavy use of [tinygemm](https://github.com/pytorch/ao/blob/cb3b
And a quick crash course on inference quantization to help parse the above table. Int4 quantization is an ambiguous term because there's the dtype in which a layer is represented and then the dtype in which the computation is done. For example, if you're using Weight-Only (wo) int4 quantization, the layer is upcast to a larger dtype like fp16 at compute time, so an int4 matrix multiplication is defined as `F.linear(input, weight.to(input.dtype))`. Dynamic quantization (DQ) primarily targets activations, enabling on-the-fly quantization from higher precision formats like bf16 to lower precision formats such as int8. This process, when supported by hardware, allows for direct computation, such as performing `F.linear(input, weight)`. Naive quantization algorithms are also notoriously sensitive to outliers, so we typically set a group size that applies a scale factor per group of elements, e.g. per group of 64 elements in the case of `int4wo64`.
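+
+To make the weight-only vs. dynamic distinction concrete, here is a small, purely illustrative sketch in plain PyTorch (int8 with per-channel scales and no group size; the real torchao kernels are considerably more involved):
+```python
+import torch
+import torch.nn.functional as F
+
+x = torch.randn(8, 64, dtype=torch.bfloat16)    # bf16 activation
+w = torch.randn(128, 64, dtype=torch.bfloat16)  # original bf16 weight
+
+# Weight-only (wo): the weight is stored in a low-precision dtype and upcast at compute time
+w_scale = w.abs().amax(dim=1, keepdim=True) / 127
+w_int8 = torch.round(w / w_scale).to(torch.int8)
+out_wo = F.linear(x, (w_int8 * w_scale).to(x.dtype))  # the matmul still runs in bf16
+
+# Dynamic quantization (dq): the activation is also quantized on the fly, so with hardware
+# support the matmul itself can run in the low-precision dtype (emulated here in int32)
+x_scale = x.abs().amax(dim=1, keepdim=True) / 127
+x_int8 = torch.round(x / x_scale).to(torch.int8)
+out_dq = (x_int8.to(torch.int32) @ w_int8.t().to(torch.int32)).float() * x_scale.float() * w_scale.t().float()
+```
+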
+Sparsifying your model is also a one-liner that should work on any model with an `nn.Linear`. We find that sparsity works best on compute-bound models like SAM, specifically the MLP layers.
+```python
+from torchao.sparsity import sparsify
+from torch.sparse import to_sparse_semi_structured
-#### With intrusive code changes
+m = sparsify(m, to_sparse_semi_structured)
+```
+Sparsity can also be composed with int8 dynamic quantization for further speedups:
-In some cases we rewrote popular GenAI models to be significantly faster in native PyTorch as in no C++/CUDA to achieve at the time SOTA inference performance. These involve more intrusive code changes.
+```python
+from torchao.sparsity import sparsify
+from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
-* 9.5x speedups for Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai) compared to vanilla [sam](https://github.com/facebookresearch/segment-anything).
-* 1.16x speedup when composing int8 quantization with 2:4 sparsity against the accelerated baseline `bfloat16` dtype and `torch.compile="max_autotune"`.
+m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight())
+```
+We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + 2:4 sparsity to the first MLP linear (lin1), and 2:4 sparsity to the second MLP linear (lin2) yielded the best configuration.
+This configuration gives a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIoU)**.
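+
+As a sketch, that recipe can be expressed with per-module filter functions (the module-name matching and the `model` variable below are illustrative stand-ins for the SAM image encoder; the exact configuration we benchmarked is in [eval_combo.py](/scripts/sam/eval_combo.py)):
+```python
+import torch
+from torch.sparse import to_sparse_semi_structured
+from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
+from torchao.sparsity import apply_fake_sparsity, sparsify
+from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
+
+def attn_only(mod, name):
+    return isinstance(mod, torch.nn.Linear) and 'attn' in name
+
+def mlp_lin1_only(mod, name):
+    return isinstance(mod, torch.nn.Linear) and 'lin1' in name
+
+def mlp_lin2_only(mod, name):
+    return isinstance(mod, torch.nn.Linear) and 'lin2' in name
+
+# prune the MLP weights to a 2:4 pattern first, so the swapped-in subclasses see sparse weights
+apply_fake_sparsity(model, filter_fn=lambda mod, name: mlp_lin1_only(mod, name) or mlp_lin2_only(mod, name))
+
+model = quantize(model, int8_dynamic_activation_int8_weight(), attn_only)                  # attn: int8 dynamic quant
+model = sparsify(model, int8_dynamic_activation_int8_2x4_sparse_weight(), mlp_lin1_only)   # mlp lin1: int8 dq + 2:4
+model = sparsify(model, to_sparse_semi_structured, mlp_lin2_only)                          # mlp lin2: 2:4 sparsity
+```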
+
+The following benchmarks were run for [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast) ViT-h on an NVIDIA-A100-80GB, with batch_size=32, `bfloat16` dtype, and `torch.compile="max_autotune"`:
| Model Type | Technique | img/s | memory (MiB) | mIoU (coco2017 val) | relative speedup | relative accuracy |
|------------|------------------------------------------------------------------------------------------------------|-------|--------------|---------------------|------------------|-------------------|
-| ViT-h | sam (float32, eager) | 2.78 | 28806 | 0.58 | baseline | baseline |
-| | sam (bfloat16, eager) | 14.85 | 14424 | 0.58 | **5.34x** | **100%** |
-| | sam-fast (bfloat16, max-autotune) | 22.75 | 15172 | 0.58 | **8.18x** | **100%** |
-| | int8 dynamic quant (attn + mlp) | 24.91 | 15154 | 0.58 | **8.96x** | **100%** |
-| | 2:4 sparsity (mlp only) | 24.81 | 15632 | 0.57 | **8.92x** | **98%** |
-| | int8 dynamic quant (attn)<br>int8 dynamic quant + 2:4 sparsity (mlp lin1)<br>2:4 sparsity (mlp lin2) | 26.46 | 14865 | 0.57 | **9.52x** | **98%** |
+| ViT-h | baseline (bfloat16, max-autotune) | 22.75 | 15172 | 0.5811 | baseline | baseline |
+| | int8 dynamic quant (attn + mlp) | 24.91 | 15154 | 0.5822 | **1.09x** | **100.19%** |
+| | 2:4 sparsity (mlp only) | 24.81 | 15632 | 0.5672 | **1.10x** | **97.61%** |
+| | 2:4 sparsity (attn + mlp) | 24.30 | 13429 | 0.5306 | **1.07x** | **91.31%** |
+| | int8 dynamic quant (attn)<br>int8 dynamic quant + 2:4 sparsity (mlp lin1)<br>2:4 sparsity (mlp lin2) | 26.46 | 14865 | 0.5668 | **1.16x** | **97.54%** |
+
+To reproduce our benchmarks please follow these [instructions](/scripts/sam/README.md).
-The relative speedup is measured purely across the image encoder (ViT) of the model, where we apply our model optimizations. Benchmarks ran on an NVIDIA-A100-80GB with batch_size=32
+#### With intrusive code changes
+
+In some cases we rewrote popular GenAI models to be significantly faster in native PyTorch as in no C++/CUDA to achieve at the time SOTA inference performance. These involve more intrusive code changes.
+* 8x speedups for Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai) (9.5x with int8 dynamic quantization + 2:4 sparsity)
* 10x speedups for Language models with [gpt-fast](https://pytorch.org/blog/accelerating-generative-ai-2)
* 3x speedup for Diffusion models with [sd-fast](https://pytorch.org/blog/accelerating-generative-ai-3)
diff --git a/scripts/sam/eval_combo.py b/scripts/sam/eval_combo.py
index 7671f9ce37..869d3df4d0 100644
--- a/scripts/sam/eval_combo.py
+++ b/scripts/sam/eval_combo.py
@@ -286,14 +286,20 @@ def run(
elif compress == "sparse_mlp_only":
def mlp_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'mlp' in name
- from torchao.sparsity import apply_sparse_semi_structured
- apply_sparse_semi_structured(predictor.model.image_encoder, filter_fn=mlp_only)
+ from torchao.sparsity import sparsify, apply_fake_sparsity
+ from torch.sparse import to_sparse_semi_structured
+ apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only)
+ predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured, filter_fn=mlp_only)
elif compress == "sparse":
- from torchao.sparsity import apply_sparse_semi_structured
- apply_sparse_semi_structured(predictor.model.image_encoder)
+ from torchao.sparsity import sparsify, apply_fake_sparsity
+ from torch.sparse import to_sparse_semi_structured
+ apply_fake_sparsity(predictor.model.image_encoder)
+ predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured)
elif compress == "int8_dynamic_quant_sparse":
- from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight
- from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured
+ from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
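+ # _FORCE_CUTLASS = False routes semi-structured sparse matmuls to cuSPARSELt instead of the CUTLASS backend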
+ SparseSemiStructuredTensor._FORCE_CUTLASS = False
+ from torchao.sparsity import sparsify, apply_fake_sparsity
+ from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
from torchao.utils import unwrap_tensor_subclass
@@ -306,6 +312,7 @@ def mlp_lin2_only(mod, name):
def mlp_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'mlp' in name
+ # prune with fake 2:4 sparsity first so the quantization parameters are computed on the sparsified weights
apply_fake_sparsity(predictor.model.image_encoder,
filter_fn=mlp_only)
@@ -314,10 +321,13 @@ def mlp_only(mod, name):
attn_only)
predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
- predictor.model.image_encoder = quantize(predictor.model.image_encoder,
- Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight.from_float,
- mlp_lin1_only)
- apply_sparse_semi_structured(predictor.model.image_encoder, filter_fn=mlp_lin2_only)
+ predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
+ int8_dynamic_activation_int8_2x4_sparse_weight(),
+ mlp_lin1_only)
+
+ predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
+ to_sparse_semi_structured,
+ mlp_lin2_only)
else:
assert compress is None, f"Unsupported compress mode {compress}"
diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py
index c7bc2700df..3e566732bb 100644
--- a/test/sparsity/test_sparse_api.py
+++ b/test/sparsity/test_sparse_api.py
@@ -3,9 +3,10 @@
import torch
from torch import nn
+from torch.sparse import to_sparse_semi_structured
-from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured
-from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
+from torchao.sparsity import apply_fake_sparsity, sparsify
+from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
_get_subclass_inserter,
@@ -37,7 +38,7 @@ def test_sparse(self):
apply_fake_sparsity(model)
dense_result = model(input)
- apply_sparse_semi_structured(model)
+ model = sparsify(model, to_sparse_semi_structured)
sparse_result = model(input)
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
@@ -61,7 +62,7 @@ def test_quant_semi_sparse(self):
apply_fake_sparsity(model)
dense_result = model(input)
- _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight), _is_linear)
+ sparsify(model, int8_dynamic_activation_int8_2x4_sparse_weight())
sparse_result = model(input)
assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1)
diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py
index 6621d086d0..9b288c07f9 100644
--- a/torchao/sparsity/__init__.py
+++ b/torchao/sparsity/__init__.py
@@ -6,11 +6,11 @@
from .wanda import WandaSparsifier # noqa: F403
from .utils import PerChannelNormObserver # noqa: F403
-from .sparse_api import apply_sparse_semi_structured, apply_fake_sparsity
+from .sparse_api import apply_fake_sparsity, sparsify
__all__ = [
"WandaSparsifier",
"PerChannelNormObserver",
- "apply_sparse_semi_structured",
"apply_fake_sparsity",
+ "sparsify",
]
diff --git a/torchao/sparsity/prototype/dynamic_quant_sparse.py b/torchao/sparsity/prototype/dynamic_quant_sparse.py
index 2601f166a8..2f2a198278 100644
--- a/torchao/sparsity/prototype/dynamic_quant_sparse.py
+++ b/torchao/sparsity/prototype/dynamic_quant_sparse.py
@@ -309,3 +309,6 @@ def from_float(cls, input_float, qmin=-128, qmax=127):
input_float.shape,
dtype=input_float.dtype,
)
+
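+# Constructor-style helper mirroring the quantize()/sparsify() API: returns the from_float
+# conversion for an int8 dynamic-activation, 2:4-sparse int8 weight backed by cuSPARSELt.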
+def int8_dynamic_activation_int8_2x4_sparse_weight():
+ return Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight.from_float
diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py
index d8ec14a266..8f8ca24a39 100644
--- a/torchao/sparsity/sparse_api.py
+++ b/torchao/sparsity/sparse_api.py
@@ -1,7 +1,13 @@
+from typing import Callable, Optional
+
import torch
from torch.ao.pruning import WeightNormSparsifier
from torch.sparse import to_sparse_semi_structured
-from torchao.quantization.quant_api import _is_linear
+from torchao.quantization.quant_api import (
+ _is_linear,
+ _replace_with_custom_fn_if_matches_filter,
+ _get_linear_subclass_inserter,
+)
# Sparsity helper functions
def apply_fake_sparsity(model, **kwargs):
@@ -24,10 +30,44 @@ def apply_fake_sparsity(model, **kwargs):
sparsifier.squash_mask()
-def apply_sparse_semi_structured(model, **kwargs):
- filter_fn = kwargs.pop("filter_fn", _is_linear)
+def sparsify(model: torch.nn.Module,
+ apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor],
+ filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None) -> torch.nn.Module:
+ """Convert the weight of linear modules in the model with `apply_tensor_subclass`.
+ This function is essentially the same as `quantize`, but for sparsity subclasses.
- apply_fake_sparsity(model, filter_fn=filter_fn)
- for name, mod in model.named_modules():
- if filter_fn(mod, name):
- mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
+ Currently, we support two options for sparsity:
+ - semi-structured (2:4) sparsity with `to_sparse_semi_structured`
+ - int8 dynamic quantization + 2:4 sparsity with `int8_dynamic_activation_int8_2x4_sparse_weight`, which is also available via the quantize API
+
+ Args:
+ model (torch.nn.Module): input model
+ apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that converts a floating point Tensor to a (sparsified) tensor subclass instance (e.g. a semi-structured sparse tensor)
+ filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on
+ the weight of the module
+
+ Example::
+ import torch
+ import torch.nn as nn
+ from torchao.sparsity import sparsify
+
+ def filter_fn(module: nn.Module, fqn: str) -> bool:
+ return isinstance(module, nn.Linear)
+
+ m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
+
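+ # note: `to_sparse_semi_structured` assumes the weights are already pruned to a 2:4 pattern (e.g. via `apply_fake_sparsity`)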
+ # for 2:4 sparsity
+ from torch.sparse import to_sparse_semi_structured
+ m = sparsify(m, to_sparse_semi_structured, filter_fn)
+
+ # for int8 dynamic quantization + 2:4 sparsity
+ from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
+ m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight(), filter_fn)
+ """
+ _replace_with_custom_fn_if_matches_filter(
+ model,
+ _get_linear_subclass_inserter(apply_tensor_subclass),
+ _is_linear if filter_fn is None else filter_fn,
+ )
+
+ return model