diff --git a/README.md b/README.md
index f0cb7e2c00..c7e67ddd9f 100644
--- a/README.md
+++ b/README.md
@@ -47,25 +47,41 @@ For int4 we make heavy use of [tinygemm](https://github.com/pytorch/ao/blob/cb3b
 And a quick crash course on inference quantization to help parse the above table. Int4 quantization is an ambiguous term because there's the dtype in which a layer is represented and then the dtype in which the computation is done. For example, if you're using Weight-Only (wo) int4 quantization that means that the layer will be upcasted to a larger dtype like fp16 so an int4 matrix multiplication is defined as `F.linear(input, weight.to(input.dtype))`. Dynamic quantization (DQ) primarily targets activations, enabling on-the-fly quantization from higher precision formats like bf16 to lower precision formats such as int8. This process, when supported by hardware, allows for direct computation, such as performing `F.linear(input, weight)`. Naive quantization algorithms are also notoriously sensitive to outliers so we also typically set a group size that applies a scale factor per group of 64 elements in the case of `int4wo64`.
+Sparsifying your model is also a one-liner that should work on any model with an `nn.Linear`. We find that sparsity works best on compute-bound models like SAM, specifically the MLP layers.
+```python
+from torchao.sparsity import sparsify
+from torch.sparse import to_sparse_semi_structured
-#### With intrusive code changes
+m = sparsify(m, to_sparse_semi_structured)
+```
+Sparsity can also be composed with int8 dynamic quantization for further speedups:
-In some cases we rewrote popular GenAI models to be significantly faster in native PyTorch as in no C++/CUDA to achieve at the time SOTA inference performance. These involve more intrusive code changes.
+```python
+from torchao.sparsity import sparsify
+from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
-* 9.5x speedups for Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai) compared to vanilla [sam](https://github.com/facebookresearch/segment-anything).
-* 1.16x speedup when composing int8 quantization with 2:4 sparsity against the accelerated baseline `bfloat16` dtype and `torch.compile="max_autotune"`.
+m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight())
+```
+We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + 2:4 sparsity to mlp layer 1, and 2:4 sparsity to mlp layer 2 yielded the best configuration.
+We were able to provide a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIOU)**.
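+
+As a sketch, this mixed configuration can be expressed by composing the `quantize` and `sparsify` APIs with per-module filter functions. The module-name checks below are illustrative (they assume SAM-style `attn`/`lin1`/`lin2` layer names), and the pruning step (`apply_fake_sparsity`) that produces the 2:4 pattern is omitted for brevity; see `scripts/sam/eval_combo.py` for the exact flow we use.
+```python
+import torch
+from torch.sparse import to_sparse_semi_structured
+from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
+from torchao.sparsity import sparsify
+from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
+
+# illustrative filters; adapt the name checks to your model's module names
+def attn_only(mod, name):
+    return isinstance(mod, torch.nn.Linear) and 'attn' in name
+
+def mlp_lin1_only(mod, name):
+    return isinstance(mod, torch.nn.Linear) and 'lin1' in name
+
+def mlp_lin2_only(mod, name):
+    return isinstance(mod, torch.nn.Linear) and 'lin2' in name
+
+# m is your model (e.g. the SAM image encoder)
+# int8 dynamic quantization for attention, int8 dynamic quantization + 2:4 sparsity for mlp lin1,
+# and 2:4 sparsity alone for mlp lin2
+m = quantize(m, int8_dynamic_activation_int8_weight(), attn_only)
+m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight(), mlp_lin1_only)
+m = sparsify(m, to_sparse_semi_structured, mlp_lin2_only)
+```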
+
+The following benchmarks were run for [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast) ViT-h on an NVIDIA-A100-80GB, with batch_size=32, `bfloat16` dtype, and `torch.compile="max_autotune"`:
 
 | Model Type | Technique                                                                                              | img/s | memory (MiB) | mIoU (coco2017 val) | relative speedup | relative accuracy |
 |------------|------------------------------------------------------------------------------------------------------|-------|--------------|---------------------|------------------|-------------------|
-| ViT-h | sam (float32, eager) | 2.78 | 28806 | 0.58 | baseline | baseline |
-| | sam (bfloat16, eager) | 14.85 | 14424 | 0.58 | **5.34x** | **100%** |
-| | sam-fast (bfloat16, max-autotune) | 22.75 | 15172 | 0.58 | **8.18x** | **100%** |
-| | int8 dynamic quant (attn + mlp) | 24.91 | 15154 | 0.58 | **8.96x** | **100%** |
-| | 2:4 sparsity (mlp only) | 24.81 | 15632 | 0.57 | **8.92x** | **98%** |
-| | int8 dynamic quant (attn)<br>int8 dynamic quant + 2:4 sparsity (mlp lin1)<br>2:4 sparsity (mlp lin2) | 26.46 | 14865 | 0.57 | **9.52x** | **98%** |
+| ViT-h | baseline (bfloat16, max-autotune) | 22.75 | 15172 | 0.5811 | | |
+| | int8 dynamic quant (attn + mlp) | 24.91 | 15154 | 0.5822 | **1.09x** | **100.19%** |
+| | 2:4 sparsity (mlp only) | 24.81 | 15632 | 0.5672 | **1.10x** | **97.61%** |
+| | 2:4 sparsity (attn + mlp) | 24.30 | 13429 | 0.5306 | **1.07x** | **91.31%** |
+| | int8 dynamic quant (attn)<br>int8 dynamic quant + 2:4 sparsity (mlp lin1)<br>2:4 sparsity (mlp lin2) | 26.46 | 14865 | 0.5668 | **1.16x** | **97.54%** |
+
+To reproduce our benchmarks, please follow these [instructions](/scripts/sam/README.md).
 
-The relative speedup is measured purely across the image encoder (ViT) of the model, where we apply our model optimizations. Benchmarks ran on an NVIDIA-A100-80GB with batch_size=32
+#### With intrusive code changes
+
+In some cases we rewrote popular GenAI models to be significantly faster in native PyTorch, with no custom C++/CUDA, achieving SOTA inference performance at the time. These involve more intrusive code changes.
+* 8x speedups for Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai) (9.5x with int8 dynamic quantization + 2:4 sparsity)
 * 10x speedups for Language models with [gpt-fast](https://pytorch.org/blog/accelerating-generative-ai-2)
 * 3x speedup for Diffusion models with [sd-fast](https://pytorch.org/blog/accelerating-generative-ai-3)
diff --git a/scripts/sam/eval_combo.py b/scripts/sam/eval_combo.py
index 7671f9ce37..869d3df4d0 100644
--- a/scripts/sam/eval_combo.py
+++ b/scripts/sam/eval_combo.py
@@ -286,14 +286,20 @@ def run(
     elif compress == "sparse_mlp_only":
         def mlp_only(mod, name):
             return isinstance(mod, torch.nn.Linear) and 'mlp' in name
-        from torchao.sparsity import apply_sparse_semi_structured
-        apply_sparse_semi_structured(predictor.model.image_encoder, filter_fn=mlp_only)
+        from torchao.sparsity import sparsify, apply_fake_sparsity
+        from torch.sparse import to_sparse_semi_structured
+        apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only)
+        predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured, filter_fn=mlp_only)
     elif compress == "sparse":
-        from torchao.sparsity import apply_sparse_semi_structured
-        apply_sparse_semi_structured(predictor.model.image_encoder)
+        from torchao.sparsity import sparsify, apply_fake_sparsity
+        from torch.sparse import to_sparse_semi_structured
+        apply_fake_sparsity(predictor.model.image_encoder)
+        predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured)
     elif compress == "int8_dynamic_quant_sparse":
-        from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight
-        from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured
+        from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
+        SparseSemiStructuredTensor._FORCE_CUTLASS = False
+        from torchao.sparsity import sparsify, apply_fake_sparsity
+        from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
         from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
         from torchao.utils import unwrap_tensor_subclass
@@ -306,6 +312,7 @@ def mlp_lin2_only(mod, name):
 
         def mlp_only(mod, name):
             return isinstance(mod, torch.nn.Linear) and 'mlp' in name
+        # apply fake sparsity first so that the quantization parameters are computed on the pruned weights
         apply_fake_sparsity(predictor.model.image_encoder,
                             filter_fn=mlp_only)
@@ -314,10 +321,13 @@ def mlp_only(mod, name):
                                                  attn_only)
         predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
-        predictor.model.image_encoder = quantize(predictor.model.image_encoder,
-                                                 Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight.from_float,
-                                                 mlp_lin1_only)
-        apply_sparse_semi_structured(predictor.model.image_encoder, filter_fn=mlp_lin2_only)
+        predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
+                                                 int8_dynamic_activation_int8_2x4_sparse_weight(),
+                                                 mlp_lin1_only)
+
+        predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
+                                                 to_sparse_semi_structured,
+                                                 mlp_lin2_only)
     else:
         assert compress is None, f"Unsupported compress mode {compress}"
diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py
index c7bc2700df..3e566732bb 100644
--- a/test/sparsity/test_sparse_api.py
+++ b/test/sparsity/test_sparse_api.py
@@ -3,9 +3,10 @@
 import torch
 from torch import nn
+from torch.sparse import to_sparse_semi_structured
-from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured
-from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
+from torchao.sparsity import apply_fake_sparsity, sparsify
+from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
     _get_subclass_inserter,
@@ -37,7 +38,7 @@ def test_sparse(self):
         apply_fake_sparsity(model)
         dense_result = model(input)
 
-        apply_sparse_semi_structured(model)
+        model = sparsify(model, to_sparse_semi_structured)
         sparse_result = model(input)
 
         assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
@@ -61,7 +62,7 @@ def test_quant_semi_sparse(self):
         apply_fake_sparsity(model)
         dense_result = model(input)
 
-        _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight), _is_linear)
+        sparsify(model, int8_dynamic_activation_int8_2x4_sparse_weight())
         sparse_result = model(input)
 
         assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1)
diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py
index 6621d086d0..9b288c07f9 100644
--- a/torchao/sparsity/__init__.py
+++ b/torchao/sparsity/__init__.py
@@ -6,11 +6,11 @@
 from .wanda import WandaSparsifier  # noqa: F403
 from .utils import PerChannelNormObserver  # noqa: F403
-from .sparse_api import apply_sparse_semi_structured, apply_fake_sparsity
+from .sparse_api import apply_fake_sparsity, sparsify
 
 __all__ = [
     "WandaSparsifier",
     "PerChannelNormObserver",
-    "apply_sparse_semi_structured",
     "apply_fake_sparsity",
+    "sparsify"
 ]
diff --git a/torchao/sparsity/prototype/dynamic_quant_sparse.py b/torchao/sparsity/prototype/dynamic_quant_sparse.py
index 2601f166a8..2f2a198278 100644
--- a/torchao/sparsity/prototype/dynamic_quant_sparse.py
+++ b/torchao/sparsity/prototype/dynamic_quant_sparse.py
@@ -309,3 +309,6 @@ def from_float(cls, input_float, qmin=-128, qmax=127):
             input_float.shape, dtype=input_float.dtype,
         )
+
+def int8_dynamic_activation_int8_2x4_sparse_weight():
+    return Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight.from_float
diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py
index d8ec14a266..8f8ca24a39 100644
--- a/torchao/sparsity/sparse_api.py
+++ b/torchao/sparsity/sparse_api.py
@@ -1,7 +1,13 @@
+from typing import Callable, Optional
+
 import torch
 from torch.ao.pruning import WeightNormSparsifier
 from torch.sparse import to_sparse_semi_structured
-from torchao.quantization.quant_api import _is_linear
+from torchao.quantization.quant_api import (
+    _is_linear,
+    _replace_with_custom_fn_if_matches_filter,
+    _get_linear_subclass_inserter,
+)
 
 # Sparsity helper functions
 def apply_fake_sparsity(model, **kwargs):
@@ -24,10 +30,44 @@ def apply_fake_sparsity(model, **kwargs):
     sparsifier.squash_mask()
 
-def apply_sparse_semi_structured(model, **kwargs):
-    filter_fn = kwargs.pop("filter_fn", _is_linear)
+def sparsify(model: torch.nn.Module,
+             apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor],
+             filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None) -> torch.nn.Module:
+    """Convert the weight of linear modules in the model with `apply_tensor_subclass`.
+    This function is essentially the same as `quantize`, but for sparsity subclasses.
 
-    apply_fake_sparsity(model, filter_fn=filter_fn)
-    for name, mod in model.named_modules():
-        if filter_fn(mod, name):
-            mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
+    Currently, we support two options for sparsity:
+        - semi-structured (2:4) sparsity with `to_sparse_semi_structured`
+        - int8 dynamic quantization + 2:4 sparsity with `int8_dynamic_activation_int8_2x4_sparse_weight`, which is also available via the quantize API
+
+    Args:
+        model (torch.nn.Module): input model
+        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that converts a floating point Tensor to a (sparsified) tensor subclass instance (e.g. a semi-structured sparse tensor instance)
+        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes an nn.Module instance and the fully qualified name of the module, and returns True if we want to run `apply_tensor_subclass` on
+            the weight of the module
+
+    Example::
+        import torch
+        import torch.nn as nn
+        from torchao.sparsity import sparsify
+
+        def filter_fn(module: nn.Module, fqn: str) -> bool:
+            return isinstance(module, nn.Linear)
+
+        m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
+
+        # for 2:4 sparsity
+        from torch.sparse import to_sparse_semi_structured
+        m = sparsify(m, to_sparse_semi_structured, filter_fn)
+
+        # for int8 dynamic quantization + 2:4 sparsity
+        from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
+        m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight(), filter_fn)
+    """
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        _get_linear_subclass_inserter(apply_tensor_subclass),
+        _is_linear if filter_fn is None else filter_fn,
+    )
+
+    return model
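For reference, `_get_linear_subclass_inserter` is imported from `torchao.quantization.quant_api` and is not shown in this diff. Below is a minimal sketch of what such an inserter is assumed to do, modeled on the `apply_sparse_semi_structured` loop it replaces; the `get_linear_subclass_inserter` function is a hypothetical stand-in, not the torchao implementation.

```python
import torch

def get_linear_subclass_inserter(constructor):
    # Hypothetical stand-in for torchao.quantization.quant_api._get_linear_subclass_inserter.
    # Returns a callable that swaps a linear module's weight for the tensor subclass produced
    # by `constructor` (e.g. to_sparse_semi_structured), mirroring the old per-module loop:
    # mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
    def insert_subclass(lin: torch.nn.Linear) -> torch.nn.Linear:
        lin.weight = torch.nn.Parameter(constructor(lin.weight))
        return lin
    return insert_subclass
```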