From 2d5676eaffb74e9afed3faebf18b8c77532d62ef Mon Sep 17 00:00:00 2001 From: jainapurva Date: Tue, 22 Oct 2024 19:34:00 -0700 Subject: [PATCH] Move files to prototype/sparsity --- .../test_parametrization.py | 6 +- .../{sparsity => prototype}/test_scheduler.py | 11 +- .../test_sparse_api.py | 22 +- .../test_sparsifier.py | 13 +- .../test_sparsity_utils.py | 10 +- .../test_structured_sparsifier.py | 15 +- torchao/prototype/sparsity/__init__.py | 20 + .../prototype/sparsity/pruner/FPGM_pruner.py | 93 ++ .../sparsity}/pruner/README.md | 0 torchao/prototype/sparsity/pruner/__init__.py | 8 + .../pruner/base_structured_sparsifier.py | 314 ++++ .../sparsity}/pruner/images/prune_1.png | Bin .../sparsity}/pruner/images/prune_2.png | Bin .../sparsity}/pruner/images/prune_3.png | Bin .../sparsity}/pruner/images/prune_4.png | Bin .../sparsity}/pruner/images/prune_5.png | Bin .../sparsity}/pruner/images/prune_6.png | Bin .../sparsity/pruner/lstm_saliency_pruner.py | 48 + .../sparsity}/pruner/match_utils.py | 0 .../sparsity/pruner/parametrization.py | 59 + .../sparsity}/pruner/prune_functions.py | 0 .../sparsity/pruner/saliency_pruner.py | 29 + .../prototype/sparsity/scheduler/__init__.py | 0 .../sparsity/scheduler/base_scheduler.py | 170 +++ .../sparsity/scheduler/cubic_scheduler.py | 107 ++ .../sparsity/scheduler/lambda_scheduler.py | 47 + .../prototype/sparsity/sparsifier/__init__.py | 0 .../sparsity/sparsifier/base_sparsifier.py | 353 +++++ .../sparsifier/nearly_diagonal_sparsifier.py | 55 + .../prototype/sparsity/sparsifier/utils.py | 130 ++ .../sparsifier/weight_norm_sparsifier.py | 200 +++ .../sparsity}/superblock/.gitignore | 0 .../sparsity}/superblock/README.md | 0 .../sparsity}/superblock/TRAINING.md | 0 .../prototype/sparsity/superblock/__init__.py | 0 .../sparsity}/superblock/benchmark.py | 2 +- .../sparsity}/superblock/benchmark.sh | 0 .../superblock/benchmark_results.txt | 0 .../sparsity/superblock/blocksparse.py | 239 +++ .../sparsity}/superblock/evaluate.py | 11 
+- .../sparsity}/superblock/evaluate.sh | 0 .../superblock/evaluation_results.txt | 0 .../sparsity/superblock/supermask.py | 275 ++++ .../sparsity}/superblock/train.py | 8 +- .../prototype/sparsity/superblock/utils.py | 1297 +++++++++++++++++ torchao/sparsity/README.md | 2 +- torchao/sparsity/prototype/__init__.py | 27 +- .../sparsity/prototype/pruner/FPGM_pruner.py | 94 +- .../pruner/base_structured_sparsifier.py | 313 +--- .../prototype/pruner/lstm_saliency_pruner.py | 49 +- .../prototype/pruner/parametrization.py | 63 +- .../prototype/pruner/saliency_pruner.py | 30 +- .../prototype/scheduler/base_scheduler.py | 160 +- .../prototype/scheduler/cubic_scheduler.py | 108 +- .../prototype/scheduler/lambda_scheduler.py | 48 +- .../prototype/sparsifier/base_sparsifier.py | 354 +---- .../sparsifier/nearly_diagonal_sparsifier.py | 58 +- .../sparsity/prototype/sparsifier/utils.py | 131 +- .../sparsifier/weight_norm_sparsifier.py | 203 +-- .../prototype/superblock/blocksparse.py | 240 +-- .../prototype/superblock/supermask.py | 280 +--- .../sparsity/prototype/superblock/utils.py | 1064 +------------- 62 files changed, 3557 insertions(+), 3209 deletions(-) rename test/{sparsity => prototype}/test_parametrization.py (99%) rename test/{sparsity => prototype}/test_scheduler.py (98%) rename test/{sparsity => prototype}/test_sparse_api.py (92%) rename test/{sparsity => prototype}/test_sparsifier.py (99%) rename test/{sparsity => prototype}/test_sparsity_utils.py (98%) rename test/{sparsity => prototype}/test_structured_sparsifier.py (99%) create mode 100644 torchao/prototype/sparsity/__init__.py create mode 100644 torchao/prototype/sparsity/pruner/FPGM_pruner.py rename torchao/{sparsity/prototype => prototype/sparsity}/pruner/README.md (100%) create mode 100644 torchao/prototype/sparsity/pruner/__init__.py create mode 100644 torchao/prototype/sparsity/pruner/base_structured_sparsifier.py rename torchao/{sparsity/prototype => prototype/sparsity}/pruner/images/prune_1.png (100%) 
rename torchao/{sparsity/prototype => prototype/sparsity}/pruner/images/prune_2.png (100%) rename torchao/{sparsity/prototype => prototype/sparsity}/pruner/images/prune_3.png (100%) rename torchao/{sparsity/prototype => prototype/sparsity}/pruner/images/prune_4.png (100%) rename torchao/{sparsity/prototype => prototype/sparsity}/pruner/images/prune_5.png (100%) rename torchao/{sparsity/prototype => prototype/sparsity}/pruner/images/prune_6.png (100%) create mode 100644 torchao/prototype/sparsity/pruner/lstm_saliency_pruner.py rename torchao/{sparsity/prototype => prototype/sparsity}/pruner/match_utils.py (100%) create mode 100644 torchao/prototype/sparsity/pruner/parametrization.py rename torchao/{sparsity/prototype => prototype/sparsity}/pruner/prune_functions.py (100%) create mode 100644 torchao/prototype/sparsity/pruner/saliency_pruner.py create mode 100644 torchao/prototype/sparsity/scheduler/__init__.py create mode 100644 torchao/prototype/sparsity/scheduler/base_scheduler.py create mode 100644 torchao/prototype/sparsity/scheduler/cubic_scheduler.py create mode 100644 torchao/prototype/sparsity/scheduler/lambda_scheduler.py create mode 100644 torchao/prototype/sparsity/sparsifier/__init__.py create mode 100644 torchao/prototype/sparsity/sparsifier/base_sparsifier.py create mode 100644 torchao/prototype/sparsity/sparsifier/nearly_diagonal_sparsifier.py create mode 100644 torchao/prototype/sparsity/sparsifier/utils.py create mode 100644 torchao/prototype/sparsity/sparsifier/weight_norm_sparsifier.py rename torchao/{sparsity/prototype => prototype/sparsity}/superblock/.gitignore (100%) rename torchao/{sparsity/prototype => prototype/sparsity}/superblock/README.md (100%) rename torchao/{sparsity/prototype => prototype/sparsity}/superblock/TRAINING.md (100%) create mode 100644 torchao/prototype/sparsity/superblock/__init__.py rename torchao/{sparsity/prototype => prototype/sparsity}/superblock/benchmark.py (98%) rename torchao/{sparsity/prototype => 
prototype/sparsity}/superblock/benchmark.sh (100%) rename torchao/{sparsity/prototype => prototype/sparsity}/superblock/benchmark_results.txt (100%) create mode 100644 torchao/prototype/sparsity/superblock/blocksparse.py rename torchao/{sparsity/prototype => prototype/sparsity}/superblock/evaluate.py (91%) rename torchao/{sparsity/prototype => prototype/sparsity}/superblock/evaluate.sh (100%) rename torchao/{sparsity/prototype => prototype/sparsity}/superblock/evaluation_results.txt (100%) create mode 100644 torchao/prototype/sparsity/superblock/supermask.py rename torchao/{sparsity/prototype => prototype/sparsity}/superblock/train.py (99%) create mode 100644 torchao/prototype/sparsity/superblock/utils.py diff --git a/test/sparsity/test_parametrization.py b/test/prototype/test_parametrization.py similarity index 99% rename from test/sparsity/test_parametrization.py rename to test/prototype/test_parametrization.py index 92cba1c022..19fc5737d2 100644 --- a/test/sparsity/test_parametrization.py +++ b/test/prototype/test_parametrization.py @@ -1,11 +1,12 @@ import logging -import torch import unittest + +import torch from torch import nn from torch.nn.utils import parametrize from torch.testing._internal.common_utils import TestCase -from torchao.sparsity.prototype.sparsifier import utils +from torchao.prototype.sparsity.sparsifier import utils logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO @@ -168,5 +169,6 @@ def test_jit_trace(self): y_hat = model_trace(x) self.assertEqual(y_hat, y) + if __name__ == "__main__": unittest.main() diff --git a/test/sparsity/test_scheduler.py b/test/prototype/test_scheduler.py similarity index 98% rename from test/sparsity/test_scheduler.py rename to test/prototype/test_scheduler.py index 0cfc898dcd..6c924ed300 100644 --- a/test/sparsity/test_scheduler.py +++ b/test/prototype/test_scheduler.py @@ -1,10 +1,16 @@ -import warnings import unittest +import warnings from torch import nn 
from torch.testing._internal.common_utils import TestCase -from torchao.sparsity.prototype import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier +from torchao.prototype.sparsity import ( + BaseScheduler, + CubicSL, + LambdaSL, + WeightNormSparsifier, +) + class ImplementedScheduler(BaseScheduler): def get_sl(self): @@ -190,5 +196,6 @@ def test_step(self): msg="Sparsity level is not reaching the target level afer delta_t * n steps ", ) + if __name__ == "__main__": unittest.main() diff --git a/test/sparsity/test_sparse_api.py b/test/prototype/test_sparse_api.py similarity index 92% rename from test/sparsity/test_sparse_api.py rename to test/prototype/test_sparse_api.py index fb0fa1b8e3..2cd1d14d52 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/prototype/test_sparse_api.py @@ -13,7 +13,12 @@ ) from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_ -from torchao.utils import TORCH_VERSION_AFTER_2_5, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import ( + TORCH_VERSION_AFTER_2_5, + TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_4, + TORCH_VERSION_AT_LEAST_2_5, +) logging.basicConfig( @@ -88,7 +93,7 @@ def test_quant_semi_sparse(self, compile): def test_sparse_marlin(self, compile): if not torch.backends.cusparselt.is_available(): self.skipTest("Need cuSPARSELt") - + input = torch.rand((256, 256)).half().cuda() model = ( nn.Sequential( @@ -117,7 +122,10 @@ def test_sparse_marlin(self, compile): class TestBlockSparseWeight(common_utils.TestCase): - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature due to need for custom op support") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, + "pytorch 2.4+ feature due to need for custom op support", + ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("compile", [True, False]) def test_sparse(self, compile): @@ -140,7 +148,7 @@ def test_sparse(self, 
compile): model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16) dense_result = model(input) - from torchao.sparsity.prototype.superblock.blocksparse import ( + from torchao.prototype.sparsity.superblock.blocksparse import ( block_sparse_weight, ) @@ -167,7 +175,7 @@ def test_sparse(self, compile): .cuda() .eval() ) - from torchao.sparsity.prototype.superblock.blocksparse import ( + from torchao.prototype.sparsity.superblock.blocksparse import ( blocksparse_int_addmm, ) from torchao.sparsity.utils import create_block_sparse_tensor @@ -189,9 +197,7 @@ def test_sparse(self, compile): quantize_( model, - int8_dynamic_activation_int8_weight( - layout=BlockSparseLayout(blocksize=64) - ), + int8_dynamic_activation_int8_weight(layout=BlockSparseLayout(blocksize=64)), ) if compile: model = torch.compile(model) diff --git a/test/sparsity/test_sparsifier.py b/test/prototype/test_sparsifier.py similarity index 99% rename from test/sparsity/test_sparsifier.py rename to test/prototype/test_sparsifier.py index 0deeea9ca7..d3f5c1b1ab 100644 --- a/test/sparsity/test_sparsifier.py +++ b/test/prototype/test_sparsifier.py @@ -7,12 +7,6 @@ import torch from torch import nn -from torchao.sparsity.prototype import ( - BaseSparsifier, - FakeSparsity, - NearlyDiagonalSparsifier, - WeightNormSparsifier, -) from torch.nn.utils.parametrize import is_parametrized from torch.testing._internal.common_pruning import ( ImplementedSparsifier, @@ -21,6 +15,12 @@ ) from torch.testing._internal.common_utils import TestCase +from torchao.prototype.sparsity import ( + BaseSparsifier, + FakeSparsity, + NearlyDiagonalSparsifier, + WeightNormSparsifier, +) logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO @@ -486,5 +486,6 @@ def _verify_nearliness(self, mask: torch.Tensor, nearliness: int): else: assert mask[row, col] == 0 + if __name__ == "__main__": unittest.main() diff --git a/test/sparsity/test_sparsity_utils.py 
b/test/prototype/test_sparsity_utils.py similarity index 98% rename from test/sparsity/test_sparsity_utils.py rename to test/prototype/test_sparsity_utils.py index 91d0d2d562..194aff1787 100644 --- a/test/sparsity/test_sparsity_utils.py +++ b/test/prototype/test_sparsity_utils.py @@ -2,11 +2,6 @@ import unittest import torch -from torchao.sparsity.prototype.sparsifier.utils import ( - fqn_to_module, - get_arg_info_from_tensor_fqn, - module_to_fqn, -) from torch.testing._internal.common_quantization import ( ConvBnReLUModel, @@ -18,6 +13,11 @@ TwoLayerLinearModel, ) from torch.testing._internal.common_utils import TestCase +from torchao.prototype.sparsity.sparsifier.utils import ( + fqn_to_module, + get_arg_info_from_tensor_fqn, + module_to_fqn, +) logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO diff --git a/test/sparsity/test_structured_sparsifier.py b/test/prototype/test_structured_sparsifier.py similarity index 99% rename from test/sparsity/test_structured_sparsifier.py rename to test/prototype/test_structured_sparsifier.py index 6a662421cd..aa9e3879b2 100644 --- a/test/sparsity/test_structured_sparsifier.py +++ b/test/prototype/test_structured_sparsifier.py @@ -6,13 +6,6 @@ import torch from torch import nn -from torchao.sparsity.prototype.pruner import ( - BaseStructuredSparsifier, - FakeStructuredSparsity, - FPGMPruner, - LSTMSaliencyPruner, - SaliencyPruner, -) from torch.nn.utils import parametrize from torch.testing._internal.common_pruning import ( Conv2dActivation, @@ -32,6 +25,13 @@ ) from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase +from torchao.prototype.sparsity.pruner import ( + BaseStructuredSparsifier, + FakeStructuredSparsity, + FPGMPruner, + LSTMSaliencyPruner, + SaliencyPruner, +) logging.basicConfig( @@ -1093,5 +1093,6 @@ def test_update_mask(self): expected_conv1, expected_conv2, device ) + if __name__ == "__main__": unittest.main() diff --git 
a/torchao/prototype/sparsity/__init__.py b/torchao/prototype/sparsity/__init__.py new file mode 100644 index 0000000000..924b7f409b --- /dev/null +++ b/torchao/prototype/sparsity/__init__.py @@ -0,0 +1,20 @@ +# Sparsifier +# Scheduler +from torchao.prototype.sparsity.scheduler.base_scheduler import BaseScheduler +from torchao.prototype.sparsity.scheduler.cubic_scheduler import CubicSL +from torchao.prototype.sparsity.scheduler.lambda_scheduler import LambdaSL +from torchao.prototype.sparsity.sparsifier.base_sparsifier import BaseSparsifier +from torchao.prototype.sparsity.sparsifier.nearly_diagonal_sparsifier import ( + NearlyDiagonalSparsifier, +) + +# Parametrizations +from torchao.prototype.sparsity.sparsifier.utils import ( + FakeSparsity, + fqn_to_module, + get_arg_info_from_tensor_fqn, + module_to_fqn, +) +from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import ( + WeightNormSparsifier, +) diff --git a/torchao/prototype/sparsity/pruner/FPGM_pruner.py b/torchao/prototype/sparsity/pruner/FPGM_pruner.py new file mode 100644 index 0000000000..d8c3d20052 --- /dev/null +++ b/torchao/prototype/sparsity/pruner/FPGM_pruner.py @@ -0,0 +1,93 @@ +from typing import Callable, Optional, Union + +import torch + +from .base_structured_sparsifier import BaseStructuredSparsifier + +__all__ = ["FPGMPruner"] + + +class FPGMPruner(BaseStructuredSparsifier): + r"""Filter Pruning via Geometric Median (FPGM) Structured Pruner + This sparsifier prunes filters (rows) in a tensor according to the distances among filters, following + `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>`_. + + This sparsifier is controlled by two variables: + 1. `sparsity_level` defines the fraction of filters (rows) that are zeroed-out. + 2. `dist` defines the distance measurement type. Default: 2 (L2 distance). + Available options are: [1, 2, (custom callable distance function)]. 
+ + Note:: + Inputs should be a 4D convolutional tensor of shape (N, C, H, W). + - N: output channels size + - C: input channels size + - H: height of kernel + - W: width of kernel + """ + + def __init__( + self, sparsity_level: float = 0.5, dist: Optional[Union[Callable, int]] = None + ): + defaults = { + "sparsity_level": sparsity_level, + } + + if dist is None: + dist = 2 + + if callable(dist): + self.dist_fn = dist + elif dist == 1: + self.dist_fn = lambda x: torch.cdist(x, x, p=1) + elif dist == 2: + self.dist_fn = lambda x: torch.cdist(x, x, p=2) + else: + raise NotImplementedError("Distance function is not yet implemented.") + super().__init__(defaults=defaults) + + def _compute_distance(self, t): + r"""Compute distance across all entries in tensor `t` along all dimensions + except for the one identified by dim. + Args: + t (torch.Tensor): tensor representing the parameter to prune + Returns: + distance (torch.Tensor): distance computed across filters + """ + dim = 0 # prune filter (row) + + size = t.size(dim) + slc = [slice(None)] * t.dim() + + # flatten the tensor along the dimension + t_flatten = [ + t[tuple(slc[:dim] + [slice(i, i + 1)] + slc[dim + 1 :])].reshape(-1) + for i in range(size) + ] + t_flatten = torch.stack(t_flatten) + + # distance measurement + dist_matrix = self.dist_fn(t_flatten) + + # row sum = total distance from a filter to all others (smaller sum = more similar to the rest) + distance = torch.sum(torch.abs(dist_matrix), 1) + + return distance + + def update_mask(self, module, tensor_name, sparsity_level, **kwargs): + tensor_weight = getattr(module, tensor_name) + mask = getattr(module.parametrizations, tensor_name)[0].mask + + if sparsity_level <= 0: + mask.data = torch.ones_like(mask).bool() + elif sparsity_level >= 1.0: + mask.data = torch.zeros_like(mask).bool() + else: + distance = self._compute_distance(tensor_weight) + + tensor_size = tensor_weight.shape[0] # prune filter (row) + nparams_toprune = round(sparsity_level * tensor_size) + nparams_toprune = min( 
+ max(nparams_toprune, 0), tensor_size + ) # clamp to [0, tensor_size] + topk = torch.topk(distance, k=nparams_toprune, largest=False) + mask[topk.indices] = False diff --git a/torchao/sparsity/prototype/pruner/README.md b/torchao/prototype/sparsity/pruner/README.md similarity index 100% rename from torchao/sparsity/prototype/pruner/README.md rename to torchao/prototype/sparsity/pruner/README.md diff --git a/torchao/prototype/sparsity/pruner/__init__.py b/torchao/prototype/sparsity/pruner/__init__.py new file mode 100644 index 0000000000..6f017aa9e2 --- /dev/null +++ b/torchao/prototype/sparsity/pruner/__init__.py @@ -0,0 +1,8 @@ +from .base_structured_sparsifier import BaseStructuredSparsifier +from .parametrization import ( + FakeStructuredSparsity, + BiasHook, +) +from .saliency_pruner import SaliencyPruner +from .lstm_saliency_pruner import LSTMSaliencyPruner +from .FPGM_pruner import FPGMPruner diff --git a/torchao/prototype/sparsity/pruner/base_structured_sparsifier.py b/torchao/prototype/sparsity/pruner/base_structured_sparsifier.py new file mode 100644 index 0000000000..fa0c3bad6a --- /dev/null +++ b/torchao/prototype/sparsity/pruner/base_structured_sparsifier.py @@ -0,0 +1,314 @@ +from itertools import chain +from operator import getitem +from typing import Callable, Dict, Optional, Set, Tuple, Type, Union + +import torch +import torch.nn.functional as F +from torch import nn +from torch.fx import symbolic_trace +from torch.nn.utils import parametrize + +from torchao.prototype.sparsity import BaseSparsifier + +from .match_utils import apply_match, MatchAllNode +from .parametrization import BiasHook, FakeStructuredSparsity, module_contains_param +from .prune_functions import ( + prune_conv2d, + prune_conv2d_activation_conv2d, + prune_conv2d_activation_pool_conv2d, + prune_conv2d_conv2d, + prune_conv2d_pool_activation_conv2d, + prune_conv2d_pool_flatten_linear, + prune_linear, + prune_linear_activation_linear, + prune_linear_linear, + 
prune_lstm_output_layernorm_linear, + prune_lstm_output_linear, +) + + +def _get_supported_structured_pruning_modules(): + SUPPORTED_STRUCTURED_PRUNING_MODULES = { # added to config if None given + nn.Linear, + nn.Conv2d, + nn.LSTM, + } + return SUPPORTED_STRUCTURED_PRUNING_MODULES + + +def _get_supported_activation_functions(): + SUPPORTED_ACTIVATION_FUNCTIONS = { + F.relu, + F.rrelu, + F.hardtanh, + F.relu6, + F.sigmoid, + F.hardsigmoid, + F.tanh, + F.silu, + F.mish, + F.hardswish, + F.elu, + F.celu, + F.selu, + F.hardshrink, + F.leaky_relu, + F.logsigmoid, + F.softplus, + F.prelu, + F.softsign, + F.tanhshrink, + F.gelu, + F.dropout, + } + return SUPPORTED_ACTIVATION_FUNCTIONS + + +def _get_supported_activation_modules(): + SUPPORTED_ACTIVATION_MODULES = { + nn.ReLU, + nn.RReLU, + nn.Hardtanh, + nn.ReLU6, + nn.Sigmoid, + nn.Hardsigmoid, + nn.Tanh, + nn.SiLU, + nn.Mish, + nn.Hardswish, + nn.ELU, + nn.CELU, + nn.SELU, + nn.Hardshrink, + nn.LeakyReLU, + nn.LogSigmoid, + nn.Softplus, + nn.PReLU, + nn.Softsign, + nn.Tanhshrink, + nn.GELU, + nn.Dropout, + } + return SUPPORTED_ACTIVATION_MODULES + + +def _get_default_structured_pruning_patterns() -> Dict[ + Tuple[Union[Type[nn.Module], Callable, MatchAllNode, str], ...], + Callable[..., None], +]: + """ + Returns the patterns for conv2d / linear conversion for each element in the activation functions/modules defined above. + """ + patterns: Dict[ + Tuple[Union[Type[nn.Module], Callable, MatchAllNode, str], ...], + Callable[..., None], + ] = { + # linear -> linear + (nn.Linear, "output"): prune_linear, + (nn.Linear, nn.Linear): prune_linear_linear, + # conv2d -> conv2d + (nn.Conv2d, "output"): prune_conv2d, + (nn.Conv2d, nn.Conv2d): prune_conv2d_conv2d, + # TODO LSTM Structured pruning does not support returned state currently. + # Should find a way to explicitly match getitem(0) instead of getitem. + # This will also require changing the pruning function. 
+ # lstm -> getitem(0) -> linear + (nn.LSTM, getitem, nn.Linear): prune_lstm_output_linear, + # lstm -> getitem(0) -> layernorm -> linear + (nn.LSTM, getitem, nn.LayerNorm, nn.Linear): prune_lstm_output_layernorm_linear, + } + + for activation in chain( + _get_supported_activation_functions(), _get_supported_activation_modules() + ): + patterns.update( + { + # linear -> activation -> linear + (nn.Linear, activation, nn.Linear): prune_linear_activation_linear, + # conv2d -> activation -> conv2d + (nn.Conv2d, activation, nn.Conv2d): prune_conv2d_activation_conv2d, + # conv2d -> activation -> pool -> conv2d + ( + nn.Conv2d, + activation, + nn.AvgPool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + F.avg_pool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + nn.MaxPool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + F.max_pool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + # conv2d -> pool -> activation -> conv2d + ( + nn.Conv2d, + nn.AvgPool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + F.avg_pool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + nn.MaxPool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + F.max_pool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + # conv2d -> adaptive pool -> flatten -> linear + ( + nn.Conv2d, + nn.AdaptiveAvgPool2d, + nn.Flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveAvgPool2d, + torch.flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveMaxPool2d, + nn.Flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveMaxPool2d, + torch.flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + } + ) + return patterns + + +class 
BaseStructuredSparsifier(BaseSparsifier): + r"""Base class for structured pruning. + + Abstract methods that need to be implemented: + - update_mask: Function to compute a new mask for all keys in the + `groups` attribute. + + Args: + - defaults [dict]: default configurations will be attached to the + configuration. Only the keys that don't exist in the `config` will + be updated. + """ + + def __init__(self, defaults, patterns=None): + super().__init__(defaults) + if patterns is None: + patterns = _get_default_structured_pruning_patterns() + self.patterns = patterns + + def make_config_from_model( + self, + model: nn.Module, + SUPPORTED_MODULES: Optional[Set[Type]] = None, + ) -> None: + if SUPPORTED_MODULES is None: + SUPPORTED_MODULES = _get_supported_structured_pruning_modules() + super().make_config_from_model(model, SUPPORTED_MODULES=SUPPORTED_MODULES) + + def _prepare(self, *args, **kwargs) -> None: + r"""This function will attach the FakeStructuredSparsity parameterizations + and BiasHooks at the appropriate points in the model. 
+ """ + for config in self.groups: + module = config["module"] + tensor_name = config["tensor_name"] + parametrization = config.get("parametrization", FakeStructuredSparsity) + tensor = getattr(module, tensor_name) + + mask = config.get( + "mask", + torch.ones(tensor.shape[0], dtype=torch.bool, device=tensor.device), + ) + self.state[config["tensor_fqn"]]["mask"] = mask + parametrize.register_parametrization( + module, tensor_name, parametrization(mask) + ) + + # if linear / conv, we add in bias hooks + if isinstance(module, (nn.Linear, nn.Conv2d)): + prune_bias = config.get("prune_bias", True) + if module.bias is not None: + module.register_parameter( + "_bias", nn.Parameter(module.bias.detach()) + ) + module.bias = None + module.prune_bias = prune_bias + + module.register_forward_hook( + BiasHook(module.parametrizations.weight[0], prune_bias) + ) + + def prune(self) -> None: + r""" + This function will FX symbolically trace the model and then find instances of the patterns + defined in self.patterns (by default SUPPORTED_STRUCTURED_PRUNING_PATTERNS ). 
+ + For each pattern, it will apply the corresponding conversion function, which will modify the output + and input size expected by the modules within the pattern + """ + + self.traced = symbolic_trace(self.model) + modules = dict(self.traced.named_modules()) + + # Right now we check for matches simply by iterating across all the patterns + # if this is slow we can store patterns in a trie-structure and modify this code for faster lookup + for node in self.traced.graph.nodes: + for pattern, convert_fn in self.patterns.items(): + matched = apply_match(modules, pattern, node, []) + if matched is None: + continue + + first_module = modules.get(node.target) + # check if first module exists and has appropriate parameterization, otherwise skip + if ( + first_module is not None + and parametrize.is_parametrized(first_module) + and module_contains_param(first_module, FakeStructuredSparsity) + ): + convert_block = [] + for node in matched: + if node.op == "call_module": + convert_block.append(modules.get(node.target)) + elif node.op == "call_function": + convert_block.append(node.target) + convert_fn(*convert_block) + + for module in self.traced.modules(): + if module_contains_param(module, FakeStructuredSparsity): + raise Exception( + f"Error: {module} still contains FakeStructuredSparsity parametrizations!" 
+ ) + + self.traced.graph.lint() + self.traced.recompile() + return self.traced diff --git a/torchao/sparsity/prototype/pruner/images/prune_1.png b/torchao/prototype/sparsity/pruner/images/prune_1.png similarity index 100% rename from torchao/sparsity/prototype/pruner/images/prune_1.png rename to torchao/prototype/sparsity/pruner/images/prune_1.png diff --git a/torchao/sparsity/prototype/pruner/images/prune_2.png b/torchao/prototype/sparsity/pruner/images/prune_2.png similarity index 100% rename from torchao/sparsity/prototype/pruner/images/prune_2.png rename to torchao/prototype/sparsity/pruner/images/prune_2.png diff --git a/torchao/sparsity/prototype/pruner/images/prune_3.png b/torchao/prototype/sparsity/pruner/images/prune_3.png similarity index 100% rename from torchao/sparsity/prototype/pruner/images/prune_3.png rename to torchao/prototype/sparsity/pruner/images/prune_3.png diff --git a/torchao/sparsity/prototype/pruner/images/prune_4.png b/torchao/prototype/sparsity/pruner/images/prune_4.png similarity index 100% rename from torchao/sparsity/prototype/pruner/images/prune_4.png rename to torchao/prototype/sparsity/pruner/images/prune_4.png diff --git a/torchao/sparsity/prototype/pruner/images/prune_5.png b/torchao/prototype/sparsity/pruner/images/prune_5.png similarity index 100% rename from torchao/sparsity/prototype/pruner/images/prune_5.png rename to torchao/prototype/sparsity/pruner/images/prune_5.png diff --git a/torchao/sparsity/prototype/pruner/images/prune_6.png b/torchao/prototype/sparsity/pruner/images/prune_6.png similarity index 100% rename from torchao/sparsity/prototype/pruner/images/prune_6.png rename to torchao/prototype/sparsity/pruner/images/prune_6.png diff --git a/torchao/prototype/sparsity/pruner/lstm_saliency_pruner.py b/torchao/prototype/sparsity/pruner/lstm_saliency_pruner.py new file mode 100644 index 0000000000..4a0d74d6dc --- /dev/null +++ b/torchao/prototype/sparsity/pruner/lstm_saliency_pruner.py @@ -0,0 +1,48 @@ +from typing 
class LSTMSaliencyPruner(BaseStructuredSparsifier):
    """
    Saliency-based structured pruner for packed LSTM weights.

    Every LSTM layer {k} stores two packed weight matrices,
        - weight_ih_l{k}
        - weight_hh_l{k}

    each of which stacks the weights of four linear layers:

        [W_ii | W_if | W_ig | W_io]

    Ranking the rows of the packed tensor as a whole would prune the four
    linear layers unevenly and misassign weights once they are unpacked.
    Instead, the row mask and the row saliencies are split into four equal
    groups and the least salient rows are masked out inside each group
    independently, so every packed layer loses the same fraction of rows.

    The same procedure is used for weight_ih_l{k} and weight_hh_l{k}.
    """

    def update_mask(self, module, tensor_name, **kwargs):
        """Mask out the lowest-L1-norm rows of each of the four packed groups.

        Args:
            module: module owning the parametrized tensor.
            tensor_name: name of the parametrized tensor on ``module``.
            **kwargs: must contain ``sparsity_level``, the fraction of rows
                to prune inside each group.

        Raises:
            Exception: if the tensor has fewer than two dimensions.
        """
        packed_weight = getattr(module, tensor_name)

        for param in getattr(module.parametrizations, tensor_name):
            # Only structured-sparsity parametrizations carry a row mask.
            if not isinstance(param, FakeStructuredSparsity):
                continue
            row_mask = cast(torch.Tensor, param.mask)

            if packed_weight.dim() <= 1:
                raise Exception("Structured pruning can only be applied to a 2+dim weight tensor!")

            # Row saliency: L1 norm over every dimension except the first.
            reduce_dims = tuple(range(1, packed_weight.dim()))
            row_saliency = packed_weight.norm(dim=reduce_dims, p=1)

            # One chunk per packed linear layer.
            group_rows = len(row_mask) // 4
            mask_chunks = torch.split(row_mask, group_rows)
            saliency_chunks = torch.split(row_saliency, group_rows)

            for group_mask, group_saliency in zip(mask_chunks, saliency_chunks):
                num_pruned = int(len(group_mask) * kwargs["sparsity_level"])
                weakest = group_saliency.topk(num_pruned, largest=False, sorted=False).indices
                # The chunk is a view of param.mask, so this edits the real mask.
                group_mask.data[weakest] = False
def module_contains_param(module, parametrization):
    """Tell whether ``module`` has a parametrization of the given type.

    Args:
        module: module to inspect.
        parametrization: class (or tuple of classes) to look for.

    Returns:
        True if any parametrized tensor of ``module`` has at least one
        attached parametrization that is an instance of ``parametrization``;
        False otherwise, including when ``module`` is not parametrized.
    """
    if not is_parametrized(module):
        return False

    # Walk every parametrized tensor and every parametrization attached to it.
    for attached in module.parametrizations.values():
        for candidate in attached:
            if isinstance(candidate, parametrization):
                return True
    return False
class BiasHook:
    """Callable with the ``(module, input, output)`` hook signature that adds
    the module's ``_bias`` tensor onto the output.

    Args:
        parametrization: object exposing a boolean ``mask``; entries where the
            mask is False are the ones zeroed when ``prune_bias`` is set.
        prune_bias: when truthy, zero the masked-out bias entries before the
            bias is added.
    """

    def __init__(self, parametrization, prune_bias):
        self.param = parametrization
        self.prune_bias = prune_bias

    def __call__(self, module, input, output):
        """Add ``module._bias`` to ``output`` in place and return ``output``.

        If ``module`` has no ``_bias`` attribute, or it is None, ``output`` is
        returned untouched.
        """
        if getattr(module, "_bias", None) is None:
            return output

        bias_values = module._bias.data
        if self.prune_bias:
            # In-place write: the stored bias itself loses the pruned entries.
            bias_values[~self.param.mask] = 0

        # Put the bias along dim 1 and broadcast over all other output dims.
        broadcast_shape = [1] * len(output.shape)
        broadcast_shape[1] = -1
        output += bias_values.reshape(broadcast_shape)
        return output
+ """ + + def update_mask(self, module, tensor_name, **kwargs): + # tensor_name will give you the FQN, all other entries in sparse config is present in kwargs + weights = getattr(module, tensor_name) + mask = getattr(module.parametrizations, tensor_name)[0].mask + + # use negative weights so we can use topk (we prune out the smallest) + if weights.dim() <= 1: + raise Exception("Structured pruning can only be applied to a 2+dim weight tensor!") + saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1) + assert saliency.shape == mask.shape + + num_to_pick = int(len(mask) * kwargs["sparsity_level"]) + prune = saliency.topk(num_to_pick).indices + + # Set the mask to be false for the rows we want to prune + mask.data[prune] = False diff --git a/torchao/prototype/sparsity/scheduler/__init__.py b/torchao/prototype/sparsity/scheduler/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/prototype/sparsity/scheduler/base_scheduler.py b/torchao/prototype/sparsity/scheduler/base_scheduler.py new file mode 100644 index 0000000000..5d91098c55 --- /dev/null +++ b/torchao/prototype/sparsity/scheduler/base_scheduler.py @@ -0,0 +1,170 @@ +import warnings +import weakref +from functools import wraps + +from torchao.prototype.sparsity.sparsifier.base_sparsifier import BaseSparsifier + +__all__ = ["BaseScheduler"] + + +class BaseScheduler: + + def __init__(self, sparsifier, last_epoch=-1, verbose=False): + + # Attach sparsifier + if not isinstance(sparsifier, BaseSparsifier): + raise TypeError( + f"{type(sparsifier).__name__} is not an instance of torch.ao.pruning.BaseSparsifier" + ) + self.sparsifier = sparsifier + + # Initialize epoch and base sparsity levels + + self.base_sl = [group["sparsity_level"] for group in sparsifier.groups] + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `scheduler.step()` is called after + # `sparsifier.step()` + def with_counter(method): 
+ if getattr(method, "_with_counter", False): + # `sparsifier.step()` has already been replaced, return. + return method + + # Keep a weak reference to the sparsifier instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 # type: ignore[union-attr] + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. + wrapper._with_counter = True # type: ignore[attr-defined] + return wrapper + + self.sparsifier.step = with_counter(self.sparsifier.step) # type: ignore[assignment] + self.sparsifier._step_count = 0 # type: ignore[attr-defined] + self._step_count: int = 0 + self.verbose = verbose + + # Housekeeping + self._get_sl_called_within_step: bool = False + + self.step() + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the sparsifier. + """ + return { + key: value for key, value in self.__dict__.items() if key != "sparsifier" + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_sl(self): + """Return last computed sparsity level by current scheduler.""" + return self._last_sl + + def get_sl(self): + # Compute sparsity level using chainable form of the scheduler + # Note: This method is not intended to be called directly, and is only + # used by the ".step" method. Use .get_last_sl() instead. 
+ if not self._get_sl_called_within_step: + warnings.warn( + "To get the last sparsity level computed by the scheduler, " + "please use `get_last_sl()`." + ) + raise NotImplementedError + + def print_sl(self, is_verbose, group, sl, epoch=None): + """Display the current sparsity level.""" + if is_verbose: + if epoch is None: + print(f"Adjusting sparsity level of group {group} to {sl:.4e}.") + else: + print( + f"Epoch {epoch:5d}: adjusting sparsity level of group {group} to {sl:.4e}." + ) + + def __repr__(self): + format_string = self.__class__.__name__ + " (" + format_string += "\n" + format_string += f"Sparsifier {self.sparsifier}\n" + format_string += f" base_sl: {self.base_sl}\n" + format_string += ")" + return format_string + + def step(self, epoch=None): + # Raise warning if trying to call scheduler step before the sparsifier. + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.sparsifier.step, "_with_counter"): + warnings.warn( + "Seems like `sparsifier.step()` has been overridden after sparsity scheduler " + "initialization. Please, make sure to call `sparsifier.step()` before " + "`scheduler.step()`.", + UserWarning, + ) + + # Just check if there were two first scheduler.step() calls before sparsifier.step() + elif self.sparsifier._step_count < 1: # type: ignore[attr-defined] + warnings.warn( + "Detected call of `scheduler.step()` before `sparsifier.step()`. 
" + "You have to make sure you run the sparsifier.step() BEFORE any " + "calls to the scheduler.step().", + UserWarning, + ) + self._step_count += 1 + + class _enable_get_sl_call: + + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_sl_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_sl_called_within_step = False + + with _enable_get_sl_call(self): + self.last_epoch += 1 + values = self.get_sl() + + for i, data in enumerate(zip(self.sparsifier.groups, values)): + param_group, sl = data + param_group["sparsity_level"] = sl + self.print_sl(self.verbose, i, sl, epoch) + + self._last_sl = [group["sparsity_level"] for group in self.sparsifier.groups] + self.sparsifier.enable_mask_update = True + + def _make_sure_a_list(self, var): + r"""Utility that extends it to the same length as the .groups, ensuring it is a list""" + n = len(self.sparsifier.groups) + if not isinstance(var, (list, tuple)): + return [var] * n + else: + if len(var) != n: + raise ValueError(f"Expected variable of length {n}, but got {len(var)}") + return list(var) # We want the result to be in a list, not tuple diff --git a/torchao/prototype/sparsity/scheduler/cubic_scheduler.py b/torchao/prototype/sparsity/scheduler/cubic_scheduler.py new file mode 100644 index 0000000000..76fc61daa2 --- /dev/null +++ b/torchao/prototype/sparsity/scheduler/cubic_scheduler.py @@ -0,0 +1,107 @@ +import warnings + +from .base_scheduler import BaseScheduler + +__all__ = ["CubicSL"] + +def _clamp(x, lo, hi): + return max(lo, min(hi, x)) + + +class CubicSL(BaseScheduler): + r"""Sets the sparsity level of each parameter group to the final sl + plus a given exponential function. + + .. 
math:: + + s_i = s_f + (s_0 - s_f) \cdot \left( 1 - \frac{t - t_0}{n\Delta t} \right)^3 + + where :math:`s_i` is the sparsity at epoch :math:`t`, :math;`s_f` is the final + sparsity level, :math:`f(i)` is the function to be applied to the current epoch + :math:`t`, initial epoch :math:`t_0`, and final epoch :math:`t_f`. + :math:`\Delta t` is used to control how often the update of the sparsity level + happens. By default, + + Args: + sparsifier (BaseSparsifier): Wrapped sparsifier. + init_sl (int, list): Initial level of sparsity + init_t (int, list): Initial step, when pruning starts + delta_t (int, list): Pruning frequency + total_t (int, list): Total number of pruning steps + initially_zero (bool, list): If True, sets the level of sparsity to 0 + before init_t (:math:`t_0`). Otherwise, the sparsity level before + init_t (:math:`t_0`) is set to init_sl(:math:`s_0`) + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + """ + def __init__(self, + sparsifier, + init_sl=0.0, + init_t=0, + delta_t=10, + total_t=100, + initially_zero=False, + last_epoch=-1, + verbose=False + ): + self.sparsifier = sparsifier + + self.init_sl = self._make_sure_a_list(init_sl) + self.init_t = self._make_sure_a_list(init_t) + self.delta_t = self._make_sure_a_list(delta_t) + self.total_t = self._make_sure_a_list(total_t) + + self.initially_zero = self._make_sure_a_list(initially_zero) + + super().__init__(sparsifier, last_epoch, verbose) + + @staticmethod + def sparsity_compute_fn(s_0, s_f, t, t_0, dt, n, initially_zero=False): + r""""Computes the current level of sparsity. 
+ + Based on https://arxiv.org/pdf/1710.01878.pdf + + Args: + s_0: Initial level of sparsity, :math:`s_i` + s_f: Target level of sparsity, :math:`s_f` + t: Current step, :math:`t` + t_0: Initial step, :math:`t_0` + dt: Pruning frequency, :math:`\Delta T` + n: Pruning steps, :math:`n` + initially_zero: Sets the level of sparsity to 0 before t_0. + If False, sets to s_0 + + Returns: + The sparsity level :math:`s_t` at the current step :math:`t` + """ + if initially_zero and t < t_0: + return 0 + s_t = s_f + (s_0 - s_f) * (1.0 - (t - t_0) / (dt * n)) ** 3 + s_t = _clamp(s_t, s_0, s_f) + return s_t + + def get_sl(self): + if not self._get_sl_called_within_step: + warnings.warn( + "To get the last sparsity level computed by the scheduler, " + "please use `get_last_sl()`.") + return [ + self.sparsity_compute_fn( + s_0=initial_sparsity, + s_f=final_sparsity, + t=self.last_epoch, + t_0=initial_epoch, + dt=delta_epoch, + n=interval_epochs, + initially_zero=initially_zero + ) for initial_sparsity, final_sparsity, initial_epoch, delta_epoch, interval_epochs, initially_zero in + zip( + self.init_sl, + self.base_sl, + self.init_t, + self.delta_t, + self.total_t, + self.initially_zero + ) + ] diff --git a/torchao/prototype/sparsity/scheduler/lambda_scheduler.py b/torchao/prototype/sparsity/scheduler/lambda_scheduler.py new file mode 100644 index 0000000000..a88d99a1f8 --- /dev/null +++ b/torchao/prototype/sparsity/scheduler/lambda_scheduler.py @@ -0,0 +1,47 @@ +import warnings + +from .base_scheduler import BaseScheduler + +__all__ = ["LambdaSL"] + +class LambdaSL(BaseScheduler): + """Sets the sparsity level of each parameter group to the final sl + times a given function. When last_epoch=-1, sets initial sl as zero. + Args: + sparsifier (BaseSparsifier): Wrapped sparsifier. + sl_lambda (function or list): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such + functions, one for each group in sparsifier.param_groups. 
+ last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + Example: + >>> # Assuming sparsifier has two groups. + >>> lambda1 = lambda epoch: epoch // 30 + >>> lambda2 = lambda epoch: 0.95 ** epoch + >>> # xdoctest: +SKIP + >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False): + self.sparsifier = sparsifier + + if not isinstance(sl_lambda, list) and not isinstance(sl_lambda, tuple): + self.sl_lambdas = [sl_lambda] * len(sparsifier.groups) + else: + if len(sl_lambda) != len(sparsifier.groups): + raise ValueError(f"Expected {len(sparsifier.groups)} lr_lambdas, but got {len(sl_lambda)}") + self.sl_lambdas = list(sl_lambda) + super().__init__(sparsifier, last_epoch, verbose) + + def get_sl(self): + if not self._get_sl_called_within_step: + warnings.warn( + "To get the last sparsity level computed by the scheduler, " + "please use `get_last_sl()`.") + return [base_sl * lmbda(self.last_epoch) + for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl)] diff --git a/torchao/prototype/sparsity/sparsifier/__init__.py b/torchao/prototype/sparsity/sparsifier/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/prototype/sparsity/sparsifier/base_sparsifier.py b/torchao/prototype/sparsity/sparsifier/base_sparsifier.py new file mode 100644 index 0000000000..1c210ace34 --- /dev/null +++ b/torchao/prototype/sparsity/sparsifier/base_sparsifier.py @@ -0,0 +1,353 @@ +import abc +import copy +from collections import defaultdict +from typing import Any, Dict, Optional, Set, Tuple, List, Type + +import torch +from torch import nn +from torch.nn.utils import parametrize +from torch.nn.utils.parametrize import type_before_parametrizations + +from .utils import ( + 
module_contains_param, + swap_module, + FakeSparsity, + get_arg_info_from_tensor_fqn, + module_to_fqn, +) + +__all__ = ["BaseSparsifier"] + +SUPPORTED_MODULES = {nn.Linear} + +KEYS_NOT_IN_STATE_DICT = ["module", "module_fqn", "tensor_name"] + +__all__ = ["BaseSparsifier"] + + +# TODO update desc with new config args +class BaseSparsifier(abc.ABC): + r"""Base class for all sparsifiers. + + Abstract methods that need to be implemented: + + - update_mask: Function to compute a new mask for all keys in the + `groups`. + + Args: + - model [nn.Module]: model to configure. The model itself is not saved + but used for the state_dict saving / loading. + - config [list]: configuration elements should be a dict map that includes + `tensor_fqn` of tensors to sparsify + - defaults [dict]: default configurations will be attached to the + configuration. Only the keys that don't exist in the `config` will + be updated. + + Example:: + + >>> # xdoctest: +SKIP("Can't instantiate abstract class BaseSparsifier with abstract method update_mask") + >>> config = [{'tensor_fqn': 'layer1.weight', 'tensor_fqn': 'linear2.weight2', 'sparsity_level': 0.5}] + >>> defaults = {'sparsity_level': 0.7} + >>> # model.layer1.weight will have `sparsity_level` = 0.7 (getting default) + >>> sparsifier = BaseSparsifier(config, defaults) + """ + + def __init__(self, defaults: Optional[Dict[str, Any]] = None): + super().__init__() + self.defaults: Dict[str, Any] = defaults or {} + + self.state: Dict[str, Dict] = defaultdict(dict) + self.groups: List[Dict[str, Any]] = [] + self.enable_mask_update = True + + def __getstate__(self) -> Dict[str, Any]: + return { + "defaults": self.defaults, + "state": self.state, + "groups": self.groups, + } + + def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None: + self.__dict__.update(state) + + def __repr__(self): + format_string = self.__class__.__name__ + " (" + for i, sparse_args in enumerate(self.groups): + module = sparse_args["module"] + format_string += 
"\n" + format_string += f"\tGroup {i}\n" + format_string += f"\t module: {module}\n" + for key in sorted(sparse_args.keys()): + if key == "module": + continue + format_string += f"\t {key}: {sparse_args[key]}\n" + format_string += ")" + return format_string + + def state_dict(self) -> Dict[str, Any]: + r"""Returns the state of the optimizer as a :class:`dict`. + + It contains: + * state - current state of the sparsification. + * groups - a list containing all sparsity configuration groups + with the key 'tensor_fqn' specifying the path to the sparsified tensor within a model + + TODO: Need a clean way of loading the state of the "prepared" module + """ + + groups: List[Dict[str, Any]] = [ + dict( + filter( + lambda key_value: key_value[0] not in KEYS_NOT_IN_STATE_DICT, + mg.items(), + ) + ) + for mg in self.groups + ] + + return { + "state": self.state, + "groups": groups, + } + + def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True): + groups = copy.deepcopy(state_dict["groups"]) + states = state_dict["state"] + for tensor_fqn, s in states.items(): + arg_info = get_arg_info_from_tensor_fqn(self.model, tensor_fqn) + module = arg_info["module"] + tensor_name = arg_info["tensor_name"] + if strict and module is None: + raise RuntimeError(f"Error loading {tensor_fqn} into the model") + + found = False + for p in module.parametrizations[tensor_name]: + if isinstance(p, FakeSparsity): + found = True + break + if not found: + p = FakeSparsity(torch.ones(getattr(module, tensor_name).shape)) + parametrize.register_parametrization(module, tensor_name, p) + if s.get("mask", None) is not None: + mask = s.pop("mask") + p.mask = mask + + for mg in groups: + if mg["tensor_fqn"] == tensor_fqn: + mg.update(arg_info) + self.__setstate__({"state": states, "groups": groups}) + + def make_config_from_model( + self, + model: nn.Module, + SUPPORTED_MODULES: Set[Type] = SUPPORTED_MODULES, + ) -> None: + self.config = [] + stack = [model] + while stack: + module = 
stack.pop() + for name, child in module.named_children(): + if type(child) in SUPPORTED_MODULES: + module_fqn = module_to_fqn(model, child) + assert isinstance(module_fqn, str) # for mypy + self.config.append({"tensor_fqn": module_fqn + ".weight"}) + else: + stack.append(child) + + def prepare(self, model, config): + r"""Prepares a model, by adding the parametrizations. + + Note:: + + The model is modified inplace. If you need to preserve the original + model, use copy.deepcopy. + """ + self.model = model # TODO: Need to figure out how to load without this. + self.config = config + + # If no config -- try getting all the supported layers + if self.config is None: + self.make_config_from_model(model) + + # TODO: Remove the configuration by reference ('module') + for module_config in self.config: + assert isinstance(module_config, dict), ( + "config elements should be dicts not modules i.e.:" + "[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]" + ) + + assert isinstance(self.defaults, Dict) # for mypy + local_args = copy.deepcopy(self.defaults) + local_args.update(module_config) + + tensor_fqn = local_args.get("tensor_fqn", None) + assert tensor_fqn is not None, ( + "tensor_fqn is a required argument in the sparsity config which" + "replaces previous `module` and [module]`fqn` arguments" + ) + + # populate all information from tensor_fqn + info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn) + + # check that whatever was put into local_args agrees with what was obtained + # from tensor_fqn + for key in info_from_tensor_fqn.keys(): + if key in local_args: + assert ( + info_from_tensor_fqn[key] == local_args[key] + or ( + key == "tensor_fqn" + and "." + info_from_tensor_fqn[key] == local_args[key] + ) + # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that + ), ( + f"Given both `{key}` and `tensor_fqn` in the config, it is expected them to agree!" 
+ ) + local_args.update(info_from_tensor_fqn) + self.groups.append(local_args) + self._prepare() + + def _prepare(self, *args, **kwargs): + r"""Adds mask parametrization to the layer weight""" + for config in self.groups: + module = config["module"] + tensor_name = config["tensor_name"] + parametrization = config.get("parametrization", FakeSparsity) + mask = config.get("mask", torch.ones_like(getattr(module, tensor_name))) + self.state[config["tensor_fqn"]]["mask"] = mask + parametrize.register_parametrization( + module, tensor_name, parametrization(mask) + ) + + def squash_mask( + self, + params_to_keep: Optional[Tuple[str, ...]] = None, + params_to_keep_per_layer: Optional[Dict[str, Tuple[str, ...]]] = None, + *args, + **kwargs, + ): + r"""Squashes the sparse masks into the appropriate tensors. + + If either the `params_to_keep` or `params_to_keep_per_layer` is set, + the module will have a `sparse_params` dict attached to it. + + Args: + params_to_keep: List of keys to save in the module or a dict + representing the modules and keys that will have + sparsity parameters saved + params_to_keep_per_layer: Dict to specify the params that should be + saved for specific layers. The keys in the dict + should be the module fqn, while the values should + be a list of strings with the names of the variables + to save in the `sparse_params` + + Examples: + >>> # xdoctest: +SKIP("locals are undefined") + >>> # Don't save any sparse params + >>> sparsifier.squash_mask() + >>> hasattr(model.submodule1, 'sparse_params') + False + + >>> # Keep sparse params per layer + >>> sparsifier.squash_mask( + ... params_to_keep_per_layer={ + ... 'submodule1.linear1': ('foo', 'bar'), + ... 'submodule2.linear42': ('baz',) + ... 
}) + >>> print(model.submodule1.linear1.sparse_params) + {'foo': 42, 'bar': 24} + >>> print(model.submodule2.linear42.sparse_params) + {'baz': 0.1} + + >>> # Keep sparse params for all layers + >>> sparsifier.squash_mask(params_to_keep=('foo', 'bar')) + >>> print(model.submodule1.linear1.sparse_params) + {'foo': 42, 'bar': 24} + >>> print(model.submodule2.linear42.sparse_params) + {'foo': 42, 'bar': 24} + + >>> # Keep some sparse params for all layers, and specific ones for + >>> # some other layers + >>> sparsifier.squash_mask( + ... params_to_keep=('foo', 'bar'), + ... params_to_keep_per_layer={ + ... 'submodule2.linear42': ('baz',) + ... }) + >>> print(model.submodule1.linear1.sparse_params) + {'foo': 42, 'bar': 24} + >>> print(model.submodule2.linear42.sparse_params) + {'foo': 42, 'bar': 24, 'baz': 0.1} + """ + for config in self.groups: + module = config["module"] + tensor_name = config["tensor_name"] + parametrize.remove_parametrizations( + module, tensor_name, leave_parametrized=True + ) + sparse_params = {} + if params_to_keep is not None: + global_params = {k: config[k] for k in params_to_keep} + sparse_params.update(global_params) + if params_to_keep_per_layer is not None: + params = params_to_keep_per_layer.get(config["module_fqn"], None) + if params is not None: + per_layer_params = {k: config[k] for k in params} + sparse_params.update(per_layer_params) + if sparse_params: + # TODO handle multiple tensor being quantized on a single module, where to store sparse_params? 
+ module.sparse_params = sparse_params + + def convert( + self, + module: nn.Module, + mapping: Optional[Dict[Type[nn.Module], Type[nn.Module]]] = None, + inplace: bool = False, + parameterization: Type[nn.Module] = FakeSparsity, + ): + r"""Converts submodules in input module to a different module according to `mapping` + by calling `from_dense` method on the target module class + Args: + module: input module + mapping: a dictionary that maps from source module type to target + module type, can be overwritten to allow swapping user defined + Modules + inplace: carry out model transformations in-place, the original module + is mutated + """ + if mapping is None: + raise NotImplementedError("Need to auto generate mapping ") + if not inplace: + module = copy.deepcopy(module) + + reassign = {} + for name, mod in module.named_children(): + # leaf node + if ( + module_contains_param(mod, parameterization) + and type_before_parametrizations(mod) in mapping + ): + reassign[name] = swap_module(mod, mapping) + else: + # recurse + reassign[name] = self.convert( + mod, + mapping=mapping, + inplace=True, + parameterization=parameterization, + ) + + for key, value in reassign.items(): + module._modules[key] = value + + return module + + def step(self, use_path: bool = True) -> None: + if not self.enable_mask_update: + return + with torch.no_grad(): + for config in self.groups: + self.update_mask(**config) + + @abc.abstractmethod + def update_mask(self, module: nn.Module, tensor_name: str, **kwargs): + pass diff --git a/torchao/prototype/sparsity/sparsifier/nearly_diagonal_sparsifier.py b/torchao/prototype/sparsity/sparsifier/nearly_diagonal_sparsifier.py new file mode 100644 index 0000000000..4f44e81485 --- /dev/null +++ b/torchao/prototype/sparsity/sparsifier/nearly_diagonal_sparsifier.py @@ -0,0 +1,55 @@ +import torch + +from . 
class NearlyDiagonalSparsifier(base_sparsifier.BaseSparsifier):
    r"""Nearly Diagonal Sparsifier

    This sparsifier creates a nearly diagonal mask to be applied to the weight matrix.
    A nearly diagonal matrix contains non-zero elements near the main diagonal;
    every other element is zero. Examples of nearly diagonal matrices with
    degree (or nearliness) 3 and 5 are, respectively::

        1 1 0 0       1 1 1 0
        1 1 1 0       1 1 1 1
        0 1 1 1       1 1 1 1
        0 0 1 1       0 1 1 1

    Note that a nearly diagonal matrix with degree 1 is just a matrix with the
    main diagonal populated.

    This sparsifier is controlled by one variable:
    1. `nearliness` defines the number of non-zero diagonal lines that are
       closest to the main diagonal. Currently only odd numbers are supported.

    Args:
        nearliness: The degree of nearliness (default = 1)

    """

    def __init__(self, nearliness: int = 1):
        defaults = {"nearliness": nearliness}
        super().__init__(defaults=defaults)

    def update_mask(self, module, tensor_name, nearliness, **kwargs):
        """Overwrite the mask of ``module.<tensor_name>`` with a banded pattern.

        Args:
            module: parametrized module; the mask is read from the first
                parametrization registered for ``tensor_name``.
            tensor_name: name of the masked tensor on ``module``.
            nearliness: number of diagonals to keep. Values ``<= 0`` produce
                an all-zero mask.

        Raises:
            ValueError: if ``nearliness`` is even, or the band half-width is
                not smaller than the smallest tensor dimension. The mask is
                left untouched in that case.
        """
        mask = getattr(module.parametrizations, tensor_name)[0].mask
        if nearliness <= 0:
            mask.data = torch.zeros_like(mask)
            return

        tensor = getattr(module, tensor_name)
        height, width = tensor.shape

        # Validate before touching the mask so that a rejected configuration
        # does not wipe out the mask that is currently in effect.
        if nearliness % 2 == 0:
            raise ValueError("nearliness can only be an odd number")
        dist_to_diagonal = nearliness // 2
        if dist_to_diagonal >= min(height, width):
            raise ValueError("nearliness cannot be larger than the dimensions of tensor.")

        # Entry (row, col) is kept iff it lies within `dist_to_diagonal`
        # columns of the main diagonal.
        rows = torch.arange(height, device=mask.device).unsqueeze(1)
        cols = torch.arange(width, device=mask.device).unsqueeze(0)
        band = (cols - rows).abs() <= dist_to_diagonal
        # Assign through `.data` so the mask object itself keeps its identity.
        mask.data = band.to(dtype=mask.dtype)
b/torchao/prototype/sparsity/sparsifier/utils.py new file mode 100644 index 0000000000..c52af88698 --- /dev/null +++ b/torchao/prototype/sparsity/sparsifier/utils.py @@ -0,0 +1,130 @@ +from typing import Any, Dict, Optional, Type +from torch.nn.utils.parametrize import type_before_parametrizations, is_parametrized +from itertools import chain + +from torch import nn + +__all__ = [ + "module_contains_param", + "swap_module", + "module_to_fqn", + "fqn_to_module", + "get_arg_info_from_tensor_fqn", + "FakeSparsity", +] + + +def module_contains_param(module: nn.Module, parametrization: Type[nn.Module]) -> bool: + if is_parametrized(module): + # see if any of the module tensors have a parametriztion attached that matches the one passed in + return any( + any(isinstance(param, parametrization) for param in param_list) + for key, param_list in module.parametrizations.items() # type: ignore[union-attr,operator] + ) + return False + + +def swap_module( + mod: nn.Module, mapping: Dict[Type[nn.Module], Type[nn.Module]] +) -> nn.Module: + r"""Swaps the module using from_dense according to the mapping passed in. + Args: + mod: input module + mapping: a dictionary that maps from nn module to sparse nn module + Return: + The corresponding sparse module of `mod` according to mapping, created using from_dense + """ + if type_before_parametrizations(mod) in mapping: + sparse_mod = mapping[type_before_parametrizations(mod)] + + # TODO Fix this typing, as Type[Module] has no attribute "from_dense" + new_mod = sparse_mod.from_dense(mod) # type: ignore[attr-defined] + + # Preserve module's pre forward hooks. 
They'll be called on quantized input + for pre_hook_fn in mod._forward_pre_hooks.values(): + new_mod.register_forward_pre_hook(pre_hook_fn) + # Preserve module's post forward hooks except _observer_forward_hook + # After convert they'll work with quantized output + for hook_fn in mod._forward_hooks.values(): + new_mod.register_forward_hook(hook_fn) + + # respect device affinity when swapping modules + devices = {p.device for p in chain(mod.parameters(), mod.buffers())} + assert len(devices) <= 1, ( + f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" + ) + device = next(iter(devices)) if len(devices) > 0 else None + if device: + new_mod.to(device) + + return new_mod + + else: + return mod + + +def module_to_fqn( + model: nn.Module, module: nn.Module, prefix: str = "" +) -> Optional[str]: + """ + Returns the fqn for a module or None if module not a descendent of model. + """ + if module is model: + return "" + for name, child in model.named_children(): + fqn = module_to_fqn(child, module, ".") + if isinstance(fqn, str): + return prefix + name + fqn + return None + + +def fqn_to_module(model: Optional[nn.Module], path: str) -> Optional[nn.Module]: + """ + Given an fqn, returns the corresponding module or tensor or None if the fqn given by `path` + doesn't correspond to anything. Similar to model.get_submodule(path) but works for tensors. 
+ """ + if path != "": + for name in path.split("."): + model = getattr(model, name, None) + return model + + +def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> Dict[str, Any]: + """ + Uses tensor_fqn to obtain a dict containing module_fqn, module and tensor_name + """ + # string manip to split tensor_fqn into module_fqn and tensor_name + # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight' + # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight' + tensor_name = tensor_fqn.split(".")[-1] + module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)] + + module = fqn_to_module(model, module_fqn) + + return { + "module_fqn": module_fqn, + "module": module, + "tensor_name": tensor_name, + "tensor_fqn": tensor_fqn, + } + + +# Parametrizations +class FakeSparsity(nn.Module): + r"""Parametrization for the weights. Should be attached to the 'weight' or + any other parameter that requires a mask applied to it. + + Note:: + + Once the mask is passed, the variable should not change the id. The + contents of the mask can change, but the mask reference itself should + not. 
+ """ + + def __init__(self, mask): + super().__init__() + self.register_buffer("mask", mask) + + def forward(self, x): + assert self.mask.shape == x.shape + return self.mask * x diff --git a/torchao/prototype/sparsity/sparsifier/weight_norm_sparsifier.py b/torchao/prototype/sparsity/sparsifier/weight_norm_sparsifier.py new file mode 100644 index 0000000000..2b24ca3d82 --- /dev/null +++ b/torchao/prototype/sparsity/sparsifier/weight_norm_sparsifier.py @@ -0,0 +1,200 @@ +from functools import reduce +from typing import Callable, Optional, Tuple, Union + +import torch +import torch.nn.functional as F + +from .base_sparsifier import BaseSparsifier +import operator + +__all__ = ["WeightNormSparsifier"] + +def _flat_idx_to_2d(idx, shape): + rows = idx // shape[1] + cols = idx % shape[1] + return rows, cols + +class WeightNormSparsifier(BaseSparsifier): + r"""Weight-Norm Sparsifier + + This sparsifier computes the norm of every sparse block and "zeroes-out" the + ones with the lowest norm. The level of sparsity defines how many of the + blocks is removed. + + This sparsifier is controlled by three variables: + 1. `sparsity_level` defines the number of *sparse blocks* that are zeroed-out + 2. `sparse_block_shape` defines the shape of the sparse blocks. Note that + the sparse blocks originate at the zero-index of the tensor. + 3. `zeros_per_block` is the number of zeros that we are expecting in each + sparse block. By default we assume that all elements within a block are + zeroed-out. However, setting this variable sets the target number of + zeros per block. The zeros within each block are chosen as the *smallest + absolute values*. + + Args: + + sparsity_level: The target level of sparsity + sparse_block_shape: The shape of a sparse block (see note below) + zeros_per_block: Number of zeros in a sparse block + norm: Norm to use. Could be either `int` or a callable. + If `int`, only L1 and L2 are implemented. 
+ + Note:: + The `sparse_block_shape` is tuple representing (block_ROWS, block_COLS), + irrespective of what the rows / cols mean in the data tensor. That means, + if you were to sparsify a weight tensor in the nn.Linear, which has a + weight shape `(Cout, Cin)`, the `block_ROWS` would refer to the output + channels, while the `block_COLS` would refer to the input channels. + + Note:: + All arguments to the WeightNormSparsifier constructor are "default" + arguments and could be overriden by the configuration provided in the + `prepare` step. + """ + def __init__(self, + sparsity_level: float = 0.5, + sparse_block_shape: Tuple[int, int] = (1, 4), + zeros_per_block: Optional[int] = None, + norm: Optional[Union[Callable, int]] = None): + if zeros_per_block is None: + zeros_per_block = reduce(operator.mul, sparse_block_shape) + defaults = { + "sparsity_level": sparsity_level, + "sparse_block_shape": sparse_block_shape, + "zeros_per_block": zeros_per_block, + } + if norm is None: + norm = 2 + if callable(norm): + self.norm_fn = norm + elif norm == 1: + self.norm_fn = lambda T: T.abs() + elif norm == 2: + self.norm_fn = lambda T: T * T + else: + raise NotImplementedError(f"L-{norm} is not yet implemented.") + super().__init__(defaults=defaults) + + def _scatter_fold_block_mask(self, output_shape, dim, indices, block_shape, + mask=None, input_shape=None, device=None): + r"""Creates patches of size `block_shape` after scattering the indices.""" + if mask is None: + assert input_shape is not None + mask = torch.ones(input_shape, device=device) + mask.scatter_(dim=dim, index=indices, value=0) + mask.data = F.fold(mask, output_size=output_shape, kernel_size=block_shape, stride=block_shape) + return mask + + def _make_tensor_mask(self, data, input_shape, sparsity_level, sparse_block_shape, mask=None): + r"""Creates a tensor-level mask. + + Tensor-level mask is described as a mask, where the granularity of sparsification of the + smallest patch is the sparse_block_shape. 
That means, that for a given mask and a + sparse_block_shape, the smallest "patch" of zeros/ones could be the sparse_block_shape. + + In this context, `sparsity_level` describes the fraction of sparse patches. + """ + h, w = data.shape[-2:] + block_h, block_w = sparse_block_shape + dh = (block_h - h % block_h) % block_h + dw = (block_w - w % block_w) % block_w + + if mask is None: + mask = torch.ones(h + dh, w + dw, device=data.device) + + if sparsity_level >= 1.0: + mask.data = torch.zeros_like(mask) + return mask + elif sparsity_level <= 0.0: + mask.data = torch.ones_like(mask) + return mask + + values_per_block = reduce(operator.mul, sparse_block_shape) + if values_per_block > 1: + # Reduce the data + data = F.avg_pool2d( + data[None, None, :], kernel_size=sparse_block_shape, stride=sparse_block_shape, ceil_mode=True + ) + data = data.flatten() + num_blocks = len(data) + + data = data.repeat(1, values_per_block, 1) + + threshold_idx = int(round(sparsity_level * num_blocks)) + threshold_idx = max(0, min(num_blocks - 1, threshold_idx)) # Sanity check + _, sorted_idx = torch.topk(data, k=threshold_idx, dim=2, largest=False) + + # Temp reshape for mask + mask_reshape = mask.reshape(data.shape) # data might be reshaped + self._scatter_fold_block_mask( + dim=2, output_shape=(h + dh, w + dw), + indices=sorted_idx, block_shape=sparse_block_shape, mask=mask_reshape + ) + mask.data = mask_reshape.squeeze().reshape(mask.shape)[:h, :w].contiguous() + return mask + + def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None): + r"""Creates a block-level mask. + + Block-level mask is described as a mask, where the granularity of sparsification of the + largest patch is the sparse_block_shape. That means that for a given mask and a + sparse_block_shape, the sparsity is computed only within a patch of a size sparse_block_shape. + + In this context the `zeros_per_block` describes the number of zeroed-out elements within a patch. 
+ """ + h, w = data.shape[-2:] + block_h, block_w = sparse_block_shape + dh = (block_h - h % block_h) % block_h + dw = (block_w - w % block_w) % block_w + values_per_block = reduce(operator.mul, sparse_block_shape) + + if mask is None: + mask = torch.ones((h + dh, w + dw), device=data.device) + + if values_per_block == zeros_per_block: + # Everything should be sparsified + mask.data = torch.zeros_like(mask) + return mask + + # create a new padded tensor like data (to match the block_shape) + padded_data = torch.ones(h + dh, w + dw, dtype=data.dtype, device=data.device) + padded_data.fill_(torch.nan) + padded_data[:h, :w] = data + unfolded_data = F.unfold(padded_data[None, None, :], kernel_size=sparse_block_shape, stride=sparse_block_shape) + + # Temp reshape for mask + mask_reshape = mask.reshape(unfolded_data.shape) + _, sorted_idx = torch.topk(unfolded_data, k=zeros_per_block, dim=1, largest=False) + + self._scatter_fold_block_mask( + dim=1, indices=sorted_idx, output_shape=padded_data.shape, block_shape=sparse_block_shape, mask=mask_reshape + ) + + mask.data = mask_reshape.squeeze().reshape(mask.shape).contiguous() + return mask + + def update_mask(self, module, tensor_name, sparsity_level, sparse_block_shape, + zeros_per_block, **kwargs): + values_per_block = reduce(operator.mul, sparse_block_shape) + if zeros_per_block > values_per_block: + raise ValueError( + "Number of zeros per block cannot be more than the total number of elements in that block." 
+ ) + if zeros_per_block < 0: + raise ValueError("Number of zeros per block should be positive.") + + mask = getattr(module.parametrizations, tensor_name)[0].mask + if sparsity_level <= 0 or zeros_per_block == 0: + mask.data = torch.ones_like(mask) + elif sparsity_level >= 1.0 and (zeros_per_block == values_per_block): + mask.data = torch.zeros_like(mask) + else: + ww = self.norm_fn(getattr(module, tensor_name)) + tensor_mask = self._make_tensor_mask( + data=ww, input_shape=ww.shape, sparsity_level=sparsity_level, sparse_block_shape=sparse_block_shape + ) + if values_per_block != zeros_per_block: + block_mask = self._make_block_mask(data=ww, sparse_block_shape=sparse_block_shape, + zeros_per_block=zeros_per_block) + tensor_mask = torch.logical_or(tensor_mask, block_mask) + mask.data = tensor_mask diff --git a/torchao/sparsity/prototype/superblock/.gitignore b/torchao/prototype/sparsity/superblock/.gitignore similarity index 100% rename from torchao/sparsity/prototype/superblock/.gitignore rename to torchao/prototype/sparsity/superblock/.gitignore diff --git a/torchao/sparsity/prototype/superblock/README.md b/torchao/prototype/sparsity/superblock/README.md similarity index 100% rename from torchao/sparsity/prototype/superblock/README.md rename to torchao/prototype/sparsity/superblock/README.md diff --git a/torchao/sparsity/prototype/superblock/TRAINING.md b/torchao/prototype/sparsity/superblock/TRAINING.md similarity index 100% rename from torchao/sparsity/prototype/superblock/TRAINING.md rename to torchao/prototype/sparsity/superblock/TRAINING.md diff --git a/torchao/prototype/sparsity/superblock/__init__.py b/torchao/prototype/sparsity/superblock/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/sparsity/prototype/superblock/benchmark.py b/torchao/prototype/sparsity/superblock/benchmark.py similarity index 98% rename from torchao/sparsity/prototype/superblock/benchmark.py rename to torchao/prototype/sparsity/superblock/benchmark.py 
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch.sparse._triton_ops import broadcast_batch_dims, bsr_dense_addmm, bsr_dense_mm
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.quantization.quant_api import _get_linear_subclass_inserter
from torchao.utils import TorchAOBaseTensor

aten = torch.ops.aten


# quantization support
@torch.library.custom_op("blocksparse::bsr_to_dense", mutates_args=())
def bsr_to_dense(
    crow_indices: torch.Tensor,
    col_indices: torch.Tensor,
    values: torch.Tensor,
    M: int,
    K: int,
) -> torch.Tensor:
    """Build an ``(M, K)`` BSR tensor from its components and densify it."""
    return torch.sparse_bsr_tensor(
        crow_indices=crow_indices, col_indices=col_indices, values=values, size=(M, K)
    ).to_dense()


@torch.library.register_fake("blocksparse::bsr_to_dense")
def bsr_to_dense_abstract(
    crow_indices: torch.Tensor,
    col_indices: torch.Tensor,
    values: torch.Tensor,
    M: int,
    K: int,
) -> torch.Tensor:
    """Fake implementation: an empty ``(M, K)`` tensor with the dtype/device of ``values``."""
    return torch.empty((M, K), dtype=values.dtype, device=values.device)


@torch.library.custom_op("blocksparse::int_addmm", mutates_args=())
def blocksparse_int_addmm(
    crow_indices: torch.Tensor,
    col_indices: torch.Tensor,
    values: torch.Tensor,
    A: torch.Tensor,
    left_alpha: torch.Tensor,
    right_alpha: torch.Tensor,
) -> torch.Tensor:
    """Multiply an int8 BSR weight with ``A`` via ``bsr_dense_addmm``.

    The weight shape is taken as ``(left_alpha.shape[-1], A.shape[-2])``.
    The bfloat16 result is returned transposed.
    """
    assert values.dtype == torch.int8
    M = left_alpha.shape[-1]
    K = A.shape[-2]
    N = A.shape[-1]
    weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K))
    # The first argument is only used to label the error raised for
    # incompatible batch dimensions, so pass a readable name rather than the
    # op object.
    original_batch_dims_broadcasted = broadcast_batch_dims(
        "blocksparse_int_addmm", weight_bsr, A
    )
    out = A.new_empty(original_batch_dims_broadcasted + (M, N), dtype=torch.bfloat16)
    return bsr_dense_addmm(
        out,
        weight_bsr,
        A,
        alpha=1,
        beta=0,
        out=out,
        left_alpha=left_alpha,
        right_alpha=right_alpha,
    ).t()


@torch.library.register_fake("blocksparse::int_addmm")
def blocksparse_int_addmm_abstract(
    crow_indices: torch.Tensor,
    col_indices: torch.Tensor,
    values: torch.Tensor,
    A: torch.Tensor,
    left_alpha: torch.Tensor,
    right_alpha: torch.Tensor,
) -> torch.Tensor:
    """Fake implementation of ``blocksparse::int_addmm``."""
    N = A.shape[-1]
    M = left_alpha.shape[-1]
    # to have the same strides as the transposed result
    return torch.empty((M, N), dtype=torch.bfloat16, device=A.device).t()


# bsr wrapper custom op
@torch.library.custom_op("blocksparse::linear", mutates_args=())
def blocksparse_linear(
    A: torch.Tensor,
    crow_indices: torch.Tensor,
    col_indices: torch.Tensor,
    values: torch.Tensor,
    M: int,
    K: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply ``F.linear`` with an ``(M, K)`` BSR weight rebuilt from its components.

    ``bias`` may be ``None`` so that linear layers created without a bias
    can go through this op as well.
    """
    weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K))
    return torch.nn.functional.linear(A, weight_bsr, bias)


@torch.library.register_fake("blocksparse::linear")
def blocksparse_linear_abstract(
    A: torch.Tensor,
    crow_indices: torch.Tensor,
    col_indices: torch.Tensor,
    values: torch.Tensor,
    M: int,
    K: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Fake implementation: ``A`` with its last dimension replaced by ``M``."""
    new_shape = A.shape[:-1] + (M,)
    return torch.empty(new_shape, dtype=A.dtype, device=A.device)


# Subclass definition
class BlockSparseTensor(TorchAOBaseTensor):
    """Wrapper tensor subclass that stores a weight as its three BSR components."""

    bsr_crow_indices: Optional[torch.Tensor]
    bsr_col_indices: Optional[torch.Tensor]
    bsr_values: Optional[torch.Tensor]

    __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values"]

    @staticmethod
    def __new__(  # noqa: PYI034
        cls,
        shape: torch.Size,
        bsr_crow_indices: Optional[torch.Tensor],
        bsr_col_indices: Optional[torch.Tensor],
        bsr_values: Optional[torch.Tensor],
        requires_grad: bool = False,
    ):
        """Create the wrapper; device, dtype and layout are taken from ``bsr_values``.

        Raises:
            ValueError: if ``bsr_values`` is ``None``.
        """
        if bsr_values is None:
            raise ValueError(
                "No values passed to BlockSparseTensor: bsr_values must be provided!"
            )
        else:
            previous_tensor = bsr_values

        kwargs = {
            "device": previous_tensor.device,
            "dtype": previous_tensor.dtype,
            "layout": previous_tensor.layout,
            "requires_grad": requires_grad,
        }
        tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
        tensor.bsr_crow_indices = bsr_crow_indices
        tensor.bsr_col_indices = bsr_col_indices
        tensor.bsr_values = bsr_values
        return tensor

    def __repr__(self) -> str:  # type: ignore[override]
        assert hasattr(self, "shape")
        return f"{self.__class__.__name__}(shape={self.shape})"

    def __tensor_flatten__(self) -> Tuple[List[str], Tuple[torch.Size, bool]]:
        """Return the names of the non-``None`` inner tensors plus ``(shape, requires_grad)``."""
        inner_tensors = list(
            filter(lambda x: getattr(self, x) is not None, self.__slots__)
        )
        tensor_meta = (self.shape, self.requires_grad)
        return inner_tensors, tensor_meta

    @classmethod
    def __tensor_unflatten__(
        cls,
        inner_tensors,
        tensor_meta: Tuple[torch.Size, bool],
        outer_size,
        outer_stride,
    ) -> torch.Tensor:
        """Rebuild the wrapper from the output of ``__tensor_flatten__``."""
        shape, requires_grad = tensor_meta
        return cls(
            shape=shape,
            bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None),
            bsr_col_indices=inner_tensors.get("bsr_col_indices", None),
            bsr_values=inner_tensors.get("bsr_values", None),
            requires_grad=requires_grad,
        )

    @classmethod
    def from_dense(cls, dense_tensor, blocksize):
        """Convert ``dense_tensor`` to BSR with ``blocksize`` and wrap its components."""
        bsr_tensor = dense_tensor.to_sparse_bsr(blocksize)
        return cls(
            shape=dense_tensor.shape,
            bsr_crow_indices=bsr_tensor.crow_indices(),
            bsr_col_indices=bsr_tensor.col_indices(),
            bsr_values=bsr_tensor.values(),
            requires_grad=False,
        )

    def apply_fn_to_shard(self, func):
        """Return a new wrapper with ``func`` applied to each of the three components."""
        return BlockSparseTensor(
            shape=self.shape,
            bsr_crow_indices=func(self.bsr_crow_indices),
            bsr_col_indices=func(self.bsr_col_indices),
            bsr_values=func(self.bsr_values),
            requires_grad=self.requires_grad,
        )


# Subclass op dispatch registration
implements = BlockSparseTensor.implements


@implements(aten.detach.default)
def block_sparse_detach(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0].apply_fn_to_shard(torch.detach)
    )


@implements(aten.values.default)
def block_sparse_values(func, types, args, kwargs):
    return args[0].bsr_values.detach()


@implements(aten.crow_indices.default)
def block_sparse_crow_indices(func, types, args, kwargs):
    return args[0].bsr_crow_indices.detach()


@implements(aten.col_indices.default)
def block_sparse_col_indices(func, types, args, kwargs):
    return args[0].bsr_col_indices.detach()


@implements(aten._nnz.default)
def block_sparse__nnz(func, types, args, kwargs):
    return args[0].bsr_values.shape[0]


@implements(torch.nn.functional.linear)
def block_sparse_linear(func, types, args, kwargs):
    """Route ``F.linear(x, w[, bias])`` with a block-sparse ``w`` to ``blocksparse::linear``."""
    x, w = args[0], args[1]
    # The bias is optional in F.linear: it may be omitted, passed by keyword,
    # or be None for layers built with bias=False.
    bias = args[2] if len(args) > 2 else (kwargs or {}).get("bias")
    return torch.ops.blocksparse.linear(
        x, w.crow_indices(), w.col_indices(), w.values(), w.shape[0], w.shape[1], bias
    )


def block_sparse_weight(blocksize=64):
    """Return a linear-weight inserter that converts weights with ``BlockSparseTensor.from_dense``."""
    return _get_linear_subclass_inserter(
        partial(BlockSparseTensor.from_dense, blocksize=blocksize)
    )
# Copyright (c) Meta Platforms, Inc. and affiliates.

import torch.nn as nn
import math
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np

# original supermask
scores_min = None
scores_max = 9e9
uniform_init_01 = False

# adjusted supermask, initialize scores with uniform distribution in [0,1], clamp scores in each step in [0,1]
# scores_min=0.
# scores_max=1.
# uniform_init_01 = True


def percentile(t, q):
    """Return the value that is larger than q% of t"""
    k = 1 + round(0.01 * float(q) * (t.numel() - 1))
    return t.reshape(-1).kthvalue(k).values


class GetSubnet(torch.autograd.Function):
    """Supermask STE function"""

    @staticmethod
    def forward(ctx, scores, zeros, ones, sparsity):
        # Scores below the `sparsity` percentile select `zeros`, the rest `ones`.
        clamped_scores = scores.clamp(min=scores_min, max=scores_max)
        k_val = percentile(clamped_scores, sparsity * 100)
        return torch.where(
            clamped_scores < k_val, zeros.to(scores.device), ones.to(scores.device)
        )

    @staticmethod
    def backward(ctx, g):
        # Straight-through: the incoming gradient is handed to `scores` unchanged.
        return g, None, None, None


def _init_supermask(
    layer, sparsity, fixed_mask, fixed_weight, bitwidth, transform, fixed_transform, tile_size
):
    """Attach supermask state (sparsity, tile_size, scores, shift, scale) to ``layer``.

    Shared by SupermaskLinear and SupermaskConv2d; ``layer.weight`` must
    already exist. When ``bitwidth`` is not None the weights are quantized in
    place to the integers -2^(k-1) ... 2^(k-1)-1.
    """
    device = layer.weight.device
    tiles_per_dim = [
        max(1, int(math.ceil(dim / tile_size))) for dim in layer.weight.size()
    ]

    # With N tiles at least one must stay active, so sparsity is capped at 1 - 1/N.
    max_sparsity = 1 - (1 / math.prod(tiles_per_dim))
    layer.sparsity = sparsity
    if layer.sparsity > max_sparsity:
        print(
            f"reducing sparsity from {layer.sparsity} to {max_sparsity}",
            f"(maximum sparsity for layer with shape {layer.weight.size()} and tile size {tile_size})",
        )
        layer.sparsity = max_sparsity
    layer.tile_size = tile_size

    # One score per tile, created on the same device as the weight.
    layer.scores = nn.Parameter(
        torch.empty(tiles_per_dim, device=device),
        requires_grad=not fixed_mask,
    )
    if uniform_init_01:
        nn.init.uniform_(layer.scores)
    else:
        nn.init.kaiming_uniform_(layer.scores, a=math.sqrt(5))

    # the shift and the scale are transformation parameters
    # the actually used weights = self.weight*self.scale+self.shift
    # the transformation is activated only for quantized weights
    layer.shift = nn.Parameter(torch.zeros(1, device=device), requires_grad=False)
    layer.scale = nn.Parameter(torch.ones(1, device=device), requires_grad=False)

    if bitwidth is not None:
        with torch.no_grad():
            weights_max = torch.max(layer.weight).item()
            weights_min = torch.min(layer.weight).item()
            least_step = (weights_max - weights_min) / pow(2, bitwidth)
            left_bound = weights_min - 1e-6
            right_bound = weights_min + least_step + 1e-6
            # for example, if using binary weights (k=1) with -a, +a, set transform = [a,2a];
            # if using binary weights (k=1) with a, 0, set transform = [0,-a];
            shift_value = 0.0 if transform[0] is None else float(transform[0])
            scale_value = 1.0 if transform[1] is None else float(transform[1])
            layer.shift = nn.Parameter(
                torch.full((1,), shift_value, device=device),
                requires_grad=not fixed_transform[0],
            )
            layer.scale = nn.Parameter(
                torch.full((1,), scale_value, device=device),
                requires_grad=not fixed_transform[1],
            )
            # Bin against a snapshot of the original weights: binning against
            # the tensor being written would let an already-quantized value
            # fall into a later bin and be quantized a second time.
            reference = layer.weight.detach().clone()
            half_levels = int(pow(2, bitwidth - 1))
            for level in range(-half_levels, half_levels):
                in_bin = torch.logical_and(reference > left_bound, reference <= right_bound)
                layer.weight[in_bin] = level
                left_bound = right_bound
                right_bound += least_step

    layer.weight.requires_grad = not fixed_weight


def _tiled_mask(layer):
    """Return the 0/1 mask of ``layer``, expanded from tile scores to the weight shape."""
    subnet = GetSubnet.apply(
        layer.scores,
        torch.zeros_like(layer.scores),
        torch.ones_like(layer.scores),
        layer.sparsity,
    )

    if layer.tile_size != 1:
        for i, k in enumerate(layer.weight.shape):
            # Repeat each tile entry along dim i, then trim the overhang of
            # the last (possibly partial) tile.
            subnet = subnet.repeat_interleave(layer.tile_size, dim=i)
            subnet = torch.narrow(subnet, i, 0, k)

    return subnet


class SupermaskLinear(nn.Linear):
    """Supermask class for Linear layer"""

    def __init__(self, sparsity, fixed_mask, fixed_weight, bitwidth, transform, fixed_transform, *args, **kwargs):
        tile_size = kwargs.pop("tile_size", 1)
        super(SupermaskLinear, self).__init__(*args, **kwargs)
        # When True, `weight` already holds the masked weights (see sparsify_offline).
        self.sparsify_weights = False
        _init_supermask(
            self, sparsity, fixed_mask, fixed_weight, bitwidth, transform, fixed_transform, tile_size
        )

    def get_mask(self):
        return _tiled_mask(self)

    def sparsify_offline(self):
        """Bake the transformed, masked weights into ``weight``; forward then skips masking."""
        subnet = self.get_mask()
        self.weight.data = (self.weight * self.scale + self.shift) * subnet
        self.sparsify_weights = True

    def forward(self, x):
        if not self.sparsify_weights:
            subnet = self.get_mask()
            w = (self.weight * self.scale + self.shift) * subnet
        else:
            w = self.weight
        return F.linear(x, w, self.bias)


class SupermaskConv2d(nn.Conv2d):
    """Supermask class for Conv2d layer"""

    def __init__(self, sparsity, fixed_mask, fixed_weight, bitwidth, transform, fixed_transform, *args, **kwargs):
        tile_size = kwargs.pop("tile_size", 1)
        super(SupermaskConv2d, self).__init__(*args, **kwargs)
        _init_supermask(
            self, sparsity, fixed_mask, fixed_weight, bitwidth, transform, fixed_transform, tile_size
        )

    def get_mask(self):
        return _tiled_mask(self)

    def forward(self, x):
        subnet = self.get_mask()
        w = (self.weight * self.scale + self.shift) * subnet
        # Delegate to nn.Conv2d's own convolution so that a non-"zeros"
        # padding_mode (which apply_supermask copies over) is honoured.
        return self._conv_forward(x, w, self.bias)


def _copy_weight_and_bias(src, dst):
    """Copy the weight (and the bias, if ``src`` has one) from ``src`` into ``dst``."""
    dst.weight.data.copy_(src.weight.data)
    if src.bias is not None:
        dst.bias.data.copy_(src.bias.data)


def _supermask_conv_like(m, sparsity, tile_size, device):
    """Build a SupermaskConv2d with the hyper-parameters and weights of conv ``m``."""
    new_m = SupermaskConv2d(
        sparsity, False, False, None, None, None,
        m.in_channels,
        m.out_channels,
        m.kernel_size,
        stride=m.stride,
        padding=m.padding,
        dilation=m.dilation,
        groups=m.groups,
        bias=m.bias is not None,
        padding_mode=m.padding_mode,
        device=device,
        tile_size=tile_size,
    )
    _copy_weight_and_bias(m, new_m)
    return new_m


def _supermask_linear_like(m, sparsity, tile_size, device):
    """Build a SupermaskLinear with the hyper-parameters and weights of linear ``m``."""
    new_m = SupermaskLinear(
        sparsity, False, False, None, None, None,
        m.in_features,
        m.out_features,
        bias=m.bias is not None,
        device=device,
        tile_size=tile_size,
    )
    _copy_weight_and_bias(m, new_m)
    return new_m


def apply_supermask(
    model,
    linear_sparsity=0.0,
    linear_sp_tilesize=1,
    conv1x1_sparsity=0.0,
    conv1x1_sp_tilesize=1,
    conv_sparsity=0.0,
    conv_sp_tilesize=1,
    skip_last_layer_sparsity=False,
    skip_first_transformer_sparsity=False,
    device="cuda",
    verbose=False,
):
    """Replace Linear / Conv2d submodules of ``model`` with supermask versions, in place.

    A module type is only converted when its sparsity argument is non-zero.
    1x1 convolutions use the ``conv1x1_*`` settings when ``conv1x1_sparsity``
    is non-zero and otherwise fall back to the generic ``conv_*`` settings.
    The root module itself is never replaced (it has no parent to re-attach
    the replacement to).

    Args:
        model: module whose descendants are converted.
        skip_last_layer_sparsity: skip the module named ``"heads.head"``.
        skip_first_transformer_sparsity: skip modules whose name contains
            ``"encoder.layers.encoder_layer_0"``.
        device: device for the new modules; ``None`` keeps each replaced
            module's own device.
        verbose: print one line per replaced module.

    Returns:
        ``model`` (the same object).
    """
    sparsified_modules = {}

    for n, m in model.named_modules():
        if n == "":
            continue
        # check conditions for skipping sparsity
        if skip_last_layer_sparsity and n == "heads.head":
            continue
        if skip_first_transformer_sparsity and "encoder.layers.encoder_layer_0" in n:
            continue

        target_device = m.weight.device if device is None and hasattr(m, "weight") else device

        if isinstance(m, torch.nn.Conv2d):
            if conv1x1_sparsity != 0.0 and m.kernel_size == (1, 1):
                # convert 1x1 convolutions
                sparsified_modules[n] = _supermask_conv_like(
                    m, conv1x1_sparsity, conv1x1_sp_tilesize, target_device
                )
            elif conv_sparsity != 0.0:
                # convert all other convolutions (not tested!)
                sparsified_modules[n] = _supermask_conv_like(
                    m, conv_sparsity, conv_sp_tilesize, target_device
                )
        elif linear_sparsity != 0.0 and isinstance(m, torch.nn.Linear):
            sparsified_modules[n] = _supermask_linear_like(
                m, linear_sparsity, linear_sp_tilesize, target_device
            )

    # add modules to model
    for k, v in sparsified_modules.items():
        # rpartition copes with top-level children (no "." in the name): the
        # parent name is then "", which get_submodule resolves to `model`.
        sm_name, _, ch_name = k.rpartition(".")
        sm = model.get_submodule(sm_name)
        sm.add_module(ch_name, v)

        if verbose:
            print(f'sparsified module "{k}" with sparsity={v.sparsity}, tile size={v.tile_size}')

    return model
def evaluate( f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}" ) total_time = encoder_time / 1000.0 - return metric_logger.acc1.global_avg, num_processed_samples.item() / total_time, max_mem + return ( + metric_logger.acc1.global_avg, + num_processed_samples.item() / total_time, + max_mem, + ) def _get_cache_path(filepath): diff --git a/torchao/prototype/sparsity/superblock/utils.py b/torchao/prototype/sparsity/superblock/utils.py new file mode 100644 index 0000000000..9ed38e50d3 --- /dev/null +++ b/torchao/prototype/sparsity/superblock/utils.py @@ -0,0 +1,1297 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import argparse +import copy +import datetime +import errno +import hashlib +import math +import os +import time +from collections import defaultdict, deque, OrderedDict +from typing import List, Optional, Tuple + +import torch +from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import ( + WeightNormSparsifier, +) +from torchao.prototype.sparsity.superblock.blocksparse import block_sparse_weight +from torchao.prototype.sparsity.superblock.supermask import ( + apply_supermask, + SupermaskLinear, +) + +from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_ +from torchao.sparsity import semi_sparse_weight, sparsify_ +from torchvision.transforms import autoaugment, functional as F, transforms +from torchvision.transforms.functional import InterpolationMode + + +def get_args_parser(train=False, evaluate=False, benchmark=False): + assert ( + sum([train, evaluate, benchmark]) == 1 + ), "One and only one of training, evaluation, or benchmark can be true" + + # Shared common args + parser = argparse.ArgumentParser( + description="SuperBlock Imagenet Training/Evaluation/Benchmarking Script", + add_help=True, + ) + parser.add_argument("--data-path", type=str, help="IMAGENET dataset path") + parser.add_argument( + "--model", + default="vit_b_16", + choices=["vit_b_16", 
"vit_h_14"], + type=str, + help="ViT base model", + ) + parser.add_argument( + "--device", default="cuda", type=str, help="device (Default: cuda)" + ) + parser.add_argument( + "-b", "--batch-size", default=32, type=int, help="per device batch size" + ) + parser.add_argument( + "--val-crop-size", + default=224, + type=int, + help="the central crop size used for validation (default: 224)", + ) + parser.add_argument( + "--sparsity", + choices=["bsr", "semi_structured"], + default=None, + help="weight sparsification to apply", + ) + parser.add_argument( + "--bsr", + type=int, + nargs="?", + const=256, + default=None, + help="Convert sparsified weights to BSR format with optional block size (default: 256)", + ) + parser.add_argument("--sparsity-linear", type=float, default=0.0) + parser.add_argument("--sparsity-conv1x1", type=float, default=0.0) + parser.add_argument("--sparsity-conv", type=float, default=0.0) + parser.add_argument( + "--skip-last-layer-sparsity", + action="store_true", + help="Skip applying sparsity to the last linear layer (for vit only)", + ) + parser.add_argument( + "--skip-first-transformer-sparsity", + action="store_true", + help="Skip applying sparsity to the first transformer layer (for vit only)", + ) + parser.add_argument( + "--quantization", action="store_true", help="Run with int8 dynamic quantization" + ) + parser.add_argument( + "--weights", default=None, type=str, help="the weights enum name to load" + ) + parser.add_argument( + "--weights-path", + type=str, + help="optional checkpoint to load weights after intialization", + ) + parser.add_argument( + "--header", action="store_true", help="Print header for first run" + ) + + # Eval a subset of training args + # lots of training args + if train or evaluate: + parser.add_argument( + "-j", + "--workers", + default=16, + type=int, + metavar="N", + help="number of data loading workers", + ) + parser.add_argument( + "--accumulation-steps", + default=1, + type=int, + help="Number of steps to 
accumulate gradients over", + ) + parser.add_argument( + "--epochs", + default=90, + type=int, + metavar="N", + help="number of total epochs to run", + ) + parser.add_argument("--opt", default="sgd", type=str, help="optimizer") + parser.add_argument( + "--lr", default=0.1, type=float, help="initial learning rate" + ) + parser.add_argument( + "--momentum", default=0.9, type=float, metavar="M", help="momentum" + ) + parser.add_argument( + "--wd", + "--weight-decay", + default=1e-4, + type=float, + metavar="W", + help="weight decay", + dest="weight_decay", + ) + parser.add_argument( + "--norm-weight-decay", + default=None, + type=float, + help="weight decay for Normalization layers (default: None, same value as --wd)", + ) + parser.add_argument( + "--bias-weight-decay", + default=None, + type=float, + help="weight decay for bias parameters of all layers (default: None, same value as --wd)", + ) + parser.add_argument( + "--transformer-embedding-decay", + default=None, + type=float, + help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)", + ) + parser.add_argument( + "--label-smoothing", + default=0.0, + type=float, + help="label smoothing (default: 0.0)", + dest="label_smoothing", + ) + parser.add_argument( + "--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)" + ) + parser.add_argument( + "--cutmix-alpha", + default=0.0, + type=float, + help="cutmix alpha (default: 0.0)", + ) + parser.add_argument( + "--lr-scheduler", + default="steplr", + type=str, + help="the lr scheduler (default: steplr)", + ) + parser.add_argument( + "--lr-warmup-epochs", + default=0, + type=int, + help="the number of epochs to warmup (default: 0)", + ) + parser.add_argument( + "--lr-warmup-method", + default="constant", + type=str, + help="the warmup method (default: constant)", + ) + parser.add_argument( + "--lr-warmup-decay", default=0.01, type=float, help="the decay for lr" + ) + parser.add_argument( + 
"--lr-step-size", + default=30, + type=int, + help="decrease lr every step-size epochs", + ) + parser.add_argument( + "--lr-gamma", + default=0.1, + type=float, + help="decrease lr by a factor of lr-gamma", + ) + parser.add_argument( + "--lr-min", + default=0.0, + type=float, + help="minimum lr of lr schedule (default: 0.0)", + ) + parser.add_argument( + "--print-freq", default=10, type=int, help="print frequency" + ) + parser.add_argument( + "--output-dir", default=".", type=str, help="path to save outputs" + ) + parser.add_argument( + "--resume", + action="store_true", + help='Resumes training from latest available checkpoint ("model_.pth")', + ) + parser.add_argument( + "--start-epoch", default=0, type=int, metavar="N", help="start epoch" + ) + parser.add_argument( + "--cache-dataset", + dest="cache_dataset", + help="Cache the datasets for quicker initialization. It also serializes the transforms", + action="store_true", + ) + parser.add_argument( + "--sync-bn", dest="sync_bn", help="Use sync batch norm", action="store_true" + ) + parser.add_argument( + "--auto-augment", + default=None, + type=str, + help="auto augment policy (default: None)", + ) + parser.add_argument( + "--ra-magnitude", + default=9, + type=int, + help="magnitude of auto augment policy", + ) + parser.add_argument( + "--augmix-severity", default=3, type=int, help="severity of augmix policy" + ) + parser.add_argument( + "--random-erase", + default=0.0, + type=float, + help="random erasing probability (default: 0.0)", + ) + # Mixed precision training parameters + parser.add_argument( + "--amp", + action="store_true", + help="Use torch.cuda.amp for mixed precision training", + ) + # distributed training parameters + parser.add_argument( + "--world-size", default=1, type=int, help="number of distributed processes" + ) + parser.add_argument( + "--dist-url", + default="env://", + type=str, + help="url used to set up distributed training", + ) + parser.add_argument( + "--model-ema", + 
action="store_true", + help="enable tracking Exponential Moving Average of model parameters", + ) + parser.add_argument( + "--model-ema-steps", + type=int, + default=32, + help="the number of iterations that controls how often to update the EMA model (default: 32)", + ) + parser.add_argument( + "--model-ema-decay", + type=float, + default=0.99998, + help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)", + ) + parser.add_argument( + "--use-deterministic-algorithms", + action="store_true", + help="Forces the use of deterministic algorithms only.", + ) + parser.add_argument( + "--interpolation", + default="bilinear", + type=str, + help="the interpolation method (default: bilinear)", + ) + parser.add_argument( + "--val-resize-size", + default=256, + type=int, + help="the resize size used for validation (default: 256)", + ) + parser.add_argument( + "--train-crop-size", + default=224, + type=int, + help="the random crop size used for training (default: 224)", + ) + parser.add_argument( + "--clip-grad-norm", + default=None, + type=float, + help="the maximum gradient norm (default None)", + ) + parser.add_argument( + "--ra-reps", + default=3, + type=int, + help="number of repetitions for Repeated Augmentation (default: 3)", + ) + parser.add_argument( + "--meta", action="store_true", help="Use Meta internal imagenet structure" + ) + + if benchmark: + parser.add_argument( + "--dtype", + choices=["float32", "bfloat16", "float16"], + help="Data type", + default="bfloat16", + ) + parser.add_argument( + "--tune-kernel-params", + action="store_true", + help="Tune kernel params for BSR", + ) + parser.add_argument( + "--profile", action="store_true", help="Dump Prefetto trace" + ) + + return parser + + +# filter functions +def mlp_0_only(mod, name): + return isinstance(mod, torch.nn.Linear) and "mlp.0" in name + + +def mlp_3_only(mod, name): + return isinstance(mod, torch.nn.Linear) and "mlp.3" in name + + +def mlp_only(mod, name): + return 
isinstance(mod, torch.nn.Linear) and "mlp" in name + + +def superblock_only(mod, name): + return isinstance(mod, SupermaskLinear) and "mlp" in name + + +def mlp_only_with_args( + mod, name, skip_last_layer_sparsity=False, skip_first_transformer_sparsity=False +): + if skip_last_layer_sparsity and "heads.head" in name: + return False + if skip_first_transformer_sparsity and "encoder.layers.encoder_layer_0" in name: + return False + if isinstance(mod, torch.nn.Linear) and "mlp" in name: + return True + return False + + +### Custom sparsification utils +def apply_sparsity(model): + for name, module in model.named_modules(): + if isinstance(module, SupermaskLinear) and "mlp" in name: + module.sparsify_offline() + + +def accelerate_with_sparsity(model, args): + if args.sparsity == "bsr": + apply_sparsity(model) + if args.quantization: + from torchao.dtypes.affine_quantized_tensor import BlockSparseLayout + + quantize_( + model, + int8_dynamic_activation_int8_weight( + _layout=BlockSparseLayout(blocksize=args.bsr) + ), + superblock_only, + ) + else: + assert args.bsr is not None, "BSR requires a block size" + sparsify_(model, block_sparse_weight(blocksize=args.bsr), superblock_only) + elif args.sparsity == "semi_structured": + if args.quantization: + from torchao.dtypes.affine_quantized_tensor import SemiSparseLayout + + quantize_( + model, + int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), + mlp_0_only, + ) + sparsify_(model, semi_sparse_weight(), mlp_3_only) + else: + sparsify_(model, semi_sparse_weight(), mlp_only) + else: + if args.quantization: + quantize_(model, int8_dynamic_activation_int8_weight(), mlp_only) + + +def simulate_sparsity(model, args): + if args.sparsity == "bsr": + apply_supermask( + model, + linear_sparsity=args.sparsity_linear, + linear_sp_tilesize=args.bsr, + conv1x1_sparsity=args.sparsity_conv1x1, + conv1x1_sp_tilesize=args.bsr, + conv_sparsity=args.sparsity_conv, + conv_sp_tilesize=args.bsr, + 
skip_last_layer_sparsity=args.skip_last_layer_sparsity, + skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, + device=args.device, + verbose=False, + ) + elif args.sparsity == "semi_structured": + sparse_config = [] + for name, mod in model.named_modules(): + if mlp_only_with_args( + mod, + name, + skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, + skip_last_layer_sparsity=args.skip_last_layer_sparsity, + ): + sparse_config.append({"tensor_fqn": f"{name}.weight"}) + + sparsifier = WeightNormSparsifier( + sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2 + ) + sparsifier.prepare(model, sparse_config) + sparsifier.step() + return sparsifier + + +# ------------------------------------------------------------ +# The following code contains torchvision reference code, +# largely copied from: https://github.com/pytorch/vision/tree/main/references/classification +# Please open issues in the original repository if you have questions. + + +class SmoothedValue: + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
+ """ + t = reduce_across_processes([self.count, self.total]) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +class MetricLogger: + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{attr}'" + ) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append(f"{name}: {str(meter)}") + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + if torch.cuda.is_available(): + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + 
"{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) + else: + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + ) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print(f"{header} Total time: {total_time_str}") + + +class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel): + """Maintains moving averages of model parameters using an exponential decay. + ``ema_avg = decay * avg_model_param + (1 - decay) * model_param`` + `torch.optim.swa_utils.AveragedModel `_ + is used to compute the EMA. 
+ """ + + def __init__(self, model, decay, device="cpu"): + def ema_avg(avg_model_param, model_param, num_averaged): + return decay * avg_model_param + (1 - decay) * model_param + + super().__init__(model, device, ema_avg, use_buffers=True) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.inference_mode(): + maxk = max(topk) + batch_size = target.size(0) + if target.ndim == 2: + target = target.max(dim=1)[1] + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target[None]) + + res = [] + for k in topk: + correct_k = correct[:k].flatten().sum(dtype=torch.float32) + res.append(correct_k * (100.0 / batch_size)) + return res + + +def mkdir(path): + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not torch.distributed.is_available(): + return False + if not torch.distributed.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return torch.distributed.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return torch.distributed.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) 
+ elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + args.gpu = args.rank % torch.cuda.device_count() + elif hasattr(args, "rank"): + pass + else: + print("Not using distributed mode") + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print(f"| distributed init (rank {args.rank})", flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +def average_checkpoints(inputs): + """Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from: + https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16 + + Args: + inputs (List[str]): An iterable of string paths of checkpoints to load from. + Returns: + A dict of string keys mapping to various values. The 'model' key + from the returned dict should correspond to an OrderedDict mapping + string parameter names to torch Tensors. 
+ """ + params_dict = OrderedDict() + params_keys = None + new_state = None + num_models = len(inputs) + for fpath in inputs: + with open(fpath, "rb") as f: + state = torch.load( + f, + map_location=( + lambda s, _: torch.serialization.default_restore_location(s, "cpu") + ), + ) + # Copies over the settings from the first checkpoint + if new_state is None: + new_state = state + model_params = state["model"] + model_params_keys = list(model_params.keys()) + if params_keys is None: + params_keys = model_params_keys + elif params_keys != model_params_keys: + raise KeyError( + f"For checkpoint {f}, expected list of params: {params_keys}, but found: {model_params_keys}" + ) + for k in params_keys: + p = model_params[k] + if isinstance(p, torch.HalfTensor): + p = p.float() + if k not in params_dict: + params_dict[k] = p.clone() + # NOTE: clone() is needed in case of p is a shared parameter + else: + params_dict[k] += p + averaged_params = OrderedDict() + for k, v in params_dict.items(): + averaged_params[k] = v + if averaged_params[k].is_floating_point(): + averaged_params[k].div_(num_models) + else: + averaged_params[k] //= num_models + new_state["model"] = averaged_params + return new_state + + +def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True): + """ + This method can be used to prepare weights files for new models. It receives as + input a model architecture and a checkpoint from the training script and produces + a file with the weights ready for release. 
+ + Examples: + from torchvision import models as M + + # Classification + model = M.mobilenet_v3_large(weights=None) + print(store_model_weights(model, './class.pth')) + + # Quantized Classification + model = M.quantization.mobilenet_v3_large(weights=None, quantize=False) + model.fuse_model(is_qat=True) + model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack') + _ = torch.ao.quantization.prepare_qat(model, inplace=True) + print(store_model_weights(model, './qat.pth')) + + # Object Detection + model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None) + print(store_model_weights(model, './obj.pth')) + + # Segmentation + model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True) + print(store_model_weights(model, './segm.pth', strict=False)) + + Args: + model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes. + checkpoint_path (str): The path of the checkpoint we will load. + checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored. + Default: "model". + strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` + + Returns: + output_path (str): The location where the weights are saved. + """ + # Store the new model next to the checkpoint_path + checkpoint_path = os.path.abspath(checkpoint_path) + output_dir = os.path.dirname(checkpoint_path) + + # Deep copy to avoid side-effects on the model object. 
+ model = copy.deepcopy(model) + checkpoint = torch.load(checkpoint_path, map_location="cpu") + + # Load the weights to the model to validate that everything works + # and remove unnecessary weights (such as auxiliaries, etc) + if checkpoint_key == "model_ema": + del checkpoint[checkpoint_key]["n_averaged"] + torch.nn.modules.utils.consume_prefix_in_state_dict_if_present( + checkpoint[checkpoint_key], "module." + ) + model.load_state_dict(checkpoint[checkpoint_key], strict=strict) + + tmp_path = os.path.join(output_dir, str(model.__hash__())) + torch.save(model.state_dict(), tmp_path) + + sha256_hash = hashlib.sha256() + with open(tmp_path, "rb") as f: + # Read and update hash string value in blocks of 4K + for byte_block in iter(lambda: f.read(4096), b""): + sha256_hash.update(byte_block) + hh = sha256_hash.hexdigest() + + output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth") + os.replace(tmp_path, output_path) + + return output_path + + +def reduce_across_processes(val): + if not is_dist_avail_and_initialized(): + # nothing to sync, but we still convert to tensor for consistency with the distributed case. 
+ return torch.tensor(val) + + t = torch.tensor(val, device="cuda") + torch.distributed.barrier() + torch.distributed.all_reduce(t) + return t + + +def set_weight_decay( + model: torch.nn.Module, + weight_decay: float, + norm_weight_decay: Optional[float] = None, + norm_classes: Optional[List[type]] = None, + custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None, +): + if not norm_classes: + norm_classes = [ + torch.nn.modules.batchnorm._BatchNorm, + torch.nn.LayerNorm, + torch.nn.GroupNorm, + torch.nn.modules.instancenorm._InstanceNorm, + torch.nn.LocalResponseNorm, + ] + norm_classes = tuple(norm_classes) + + params = { + "other": [], + "norm": [], + } + params_weight_decay = { + "other": weight_decay, + "norm": norm_weight_decay, + } + custom_keys = [] + if custom_keys_weight_decay is not None: + for key, weight_decay in custom_keys_weight_decay: + params[key] = [] + params_weight_decay[key] = weight_decay + custom_keys.append(key) + + def _add_params(module, prefix=""): + for name, p in module.named_parameters(recurse=False): + if not p.requires_grad: + continue + is_custom_key = False + for key in custom_keys: + target_name = ( + f"{prefix}.{name}" if prefix != "" and "." 
in key else name + ) + if key == target_name: + params[key].append(p) + is_custom_key = True + break + if not is_custom_key: + if norm_weight_decay is not None and isinstance(module, norm_classes): + params["norm"].append(p) + else: + params["other"].append(p) + + for child_name, child_module in module.named_children(): + child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name + _add_params(child_module, prefix=child_prefix) + + _add_params(model) + + param_groups = [] + for key in params: + if len(params[key]) > 0: + param_groups.append( + {"params": params[key], "weight_decay": params_weight_decay[key]} + ) + return param_groups + + +# Presets for ImageNet training/eval taken from: https://github.com/pytorch/vision/blob/main/references/classification/presets.py + + +class ClassificationPresetTrain: + def __init__( + self, + *, + crop_size, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + interpolation=InterpolationMode.BILINEAR, + hflip_prob=0.5, + auto_augment_policy=None, + ra_magnitude=9, + augmix_severity=3, + random_erase_prob=0.0, + ): + trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] + if hflip_prob > 0: + trans.append(transforms.RandomHorizontalFlip(hflip_prob)) + if auto_augment_policy is not None: + if auto_augment_policy == "ra": + trans.append( + autoaugment.RandAugment( + interpolation=interpolation, magnitude=ra_magnitude + ) + ) + elif auto_augment_policy == "ta_wide": + trans.append( + autoaugment.TrivialAugmentWide(interpolation=interpolation) + ) + elif auto_augment_policy == "augmix": + trans.append( + autoaugment.AugMix( + interpolation=interpolation, severity=augmix_severity + ) + ) + else: + aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) + trans.append( + autoaugment.AutoAugment( + policy=aa_policy, interpolation=interpolation + ) + ) + trans.extend( + [ + transforms.PILToTensor(), + transforms.ConvertImageDtype(torch.float), + transforms.Normalize(mean=mean, 
std=std), + ] + ) + if random_erase_prob > 0: + trans.append(transforms.RandomErasing(p=random_erase_prob)) + + self.transforms = transforms.Compose(trans) + + def __call__(self, img): + return self.transforms(img) + + +class ClassificationPresetEval: + def __init__( + self, + *, + crop_size, + resize_size=256, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + interpolation=InterpolationMode.BILINEAR, + ): + + self.transforms = transforms.Compose( + [ + transforms.Resize(resize_size, interpolation=interpolation), + transforms.CenterCrop(crop_size), + transforms.PILToTensor(), + transforms.ConvertImageDtype(torch.float), + transforms.Normalize(mean=mean, std=std), + ] + ) + + def __call__(self, img): + return self.transforms(img) + + +# transforms taken from: https://github.com/pytorch/vision/blob/main/references/classification/transforms.py + + +class RandomMixup(torch.nn.Module): + """Randomly apply Mixup to the provided batch and targets. + The class implements the data augmentations as described in the paper + `"mixup: Beyond Empirical Risk Minimization" `_. + + Args: + num_classes (int): number of classes used for one-hot encoding. + p (float): probability of the batch being transformed. Default value is 0.5. + alpha (float): hyperparameter of the Beta distribution used for mixup. + Default value is 1.0. + inplace (bool): boolean to make this transform inplace. Default set to False. + """ + + def __init__( + self, + num_classes: int, + p: float = 0.5, + alpha: float = 1.0, + inplace: bool = False, + ) -> None: + super().__init__() + + if num_classes < 1: + raise ValueError( + f"Please provide a valid positive value for the num_classes. 
Got num_classes={num_classes}" + ) + + if alpha <= 0: + raise ValueError("Alpha param can't be zero.") + + self.num_classes = num_classes + self.p = p + self.alpha = alpha + self.inplace = inplace + + def forward( + self, batch: torch.Tensor, target: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + batch (Tensor): Float tensor of size (B, C, H, W) + target (Tensor): Integer tensor of size (B, ) + + Returns: + Tensor: Randomly transformed batch. + """ + if batch.ndim != 4: + raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") + if target.ndim != 1: + raise ValueError(f"Target ndim should be 1. Got {target.ndim}") + if not batch.is_floating_point(): + raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") + if target.dtype != torch.int64: + raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") + + if not self.inplace: + batch = batch.clone() + target = target.clone() + + if target.ndim == 1: + target = torch.nn.functional.one_hot( + target, num_classes=self.num_classes + ).to(dtype=batch.dtype) + + if torch.rand(1).item() >= self.p: + return batch, target + + # It's faster to roll the batch by one instead of shuffling it to create image pairs + batch_rolled = batch.roll(1, 0) + target_rolled = target.roll(1, 0) + + # Implemented as on mixup paper, page 3. + lambda_param = float( + torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0] + ) + batch_rolled.mul_(1.0 - lambda_param) + batch.mul_(lambda_param).add_(batch_rolled) + + target_rolled.mul_(1.0 - lambda_param) + target.mul_(lambda_param).add_(target_rolled) + + return batch, target + + def __repr__(self) -> str: + s = ( + f"{self.__class__.__name__}(" + f"num_classes={self.num_classes}" + f", p={self.p}" + f", alpha={self.alpha}" + f", inplace={self.inplace}" + f")" + ) + return s + + +class RandomCutmix(torch.nn.Module): + """Randomly apply Cutmix to the provided batch and targets. 
+ The class implements the data augmentations as described in the paper + `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" + <https://arxiv.org/abs/1905.04899>`_. + + Args: + num_classes (int): number of classes used for one-hot encoding. + p (float): probability of the batch being transformed. Default value is 0.5. + alpha (float): hyperparameter of the Beta distribution used for cutmix. + Default value is 1.0. + inplace (bool): boolean to make this transform inplace. Default set to False. + """ + + def __init__( + self, + num_classes: int, + p: float = 0.5, + alpha: float = 1.0, + inplace: bool = False, + ) -> None: + super().__init__() + if num_classes < 1: + raise ValueError( + "Please provide a valid positive value for the num_classes." + ) + if alpha <= 0: + raise ValueError("Alpha param can't be zero.") + + self.num_classes = num_classes + self.p = p + self.alpha = alpha + self.inplace = inplace + + def forward( + self, batch: torch.Tensor, target: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + batch (Tensor): Float tensor of size (B, C, H, W) + target (Tensor): Integer tensor of size (B, ) + + Returns: + Tensor: Randomly transformed batch. + """ + if batch.ndim != 4: + raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") + if target.ndim != 1: + raise ValueError(f"Target ndim should be 1. Got {target.ndim}") + if not batch.is_floating_point(): + raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") + if target.dtype != torch.int64: + raise TypeError(f"Target dtype should be torch.int64.
Got {target.dtype}") + + if not self.inplace: + batch = batch.clone() + target = target.clone() + + if target.ndim == 1: + target = torch.nn.functional.one_hot( + target, num_classes=self.num_classes + ).to(dtype=batch.dtype) + + if torch.rand(1).item() >= self.p: + return batch, target + + # It's faster to roll the batch by one instead of shuffling it to create image pairs + batch_rolled = batch.roll(1, 0) + target_rolled = target.roll(1, 0) + + # Implemented as on cutmix paper, page 12 (with minor corrections on typos). + lambda_param = float( + torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0] + ) + _, H, W = F.get_dimensions(batch) + + r_x = torch.randint(W, (1,)) + r_y = torch.randint(H, (1,)) + + r = 0.5 * math.sqrt(1.0 - lambda_param) + r_w_half = int(r * W) + r_h_half = int(r * H) + + x1 = int(torch.clamp(r_x - r_w_half, min=0)) + y1 = int(torch.clamp(r_y - r_h_half, min=0)) + x2 = int(torch.clamp(r_x + r_w_half, max=W)) + y2 = int(torch.clamp(r_y + r_h_half, max=H)) + + batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] + lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) + + target_rolled.mul_(1.0 - lambda_param) + target.mul_(lambda_param).add_(target_rolled) + + return batch, target + + def __repr__(self) -> str: + s = ( + f"{self.__class__.__name__}(" + f"num_classes={self.num_classes}" + f", p={self.p}" + f", alpha={self.alpha}" + f", inplace={self.inplace}" + f")" + ) + return s + + +# RA Sampler implementation taken from: https://github.com/pytorch/vision/blob/main/references/classification/sampler.py + + +class RASampler(torch.utils.data.Sampler): + """Sampler that restricts data loading to a subset of the dataset for distributed, + with repeated augmentation. + It ensures that each augmented version of a sample will be visible to a + different process (GPU). + Heavily based on 'torch.utils.data.DistributedSampler'.
+ + This is borrowed from the DeiT Repo: + https://github.com/facebookresearch/deit/blob/main/samplers.py + """ + + def __init__( + self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3 + ): + if num_replicas is None: + if not torch.distributed.is_available(): + raise RuntimeError("Requires distributed package to be available!") + num_replicas = torch.distributed.get_world_size() + if rank is None: + if not torch.distributed.is_available(): + raise RuntimeError("Requires distributed package to be available!") + rank = torch.distributed.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int( + math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas) + ) + self.total_size = self.num_samples * self.num_replicas + self.num_selected_samples = int( + math.floor(len(self.dataset) // 256 * 256 / self.num_replicas) + ) + self.shuffle = shuffle + self.seed = seed + self.repetitions = repetitions + + def __iter__(self): + if self.shuffle: + # Deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = list(range(len(self.dataset))) + + # Add extra samples to make it evenly divisible + indices = [ele for ele in indices for i in range(self.repetitions)] + indices += indices[: (self.total_size - len(indices))] + assert len(indices) == self.total_size + + # Subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices[: self.num_selected_samples]) + + def __len__(self): + return self.num_selected_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/torchao/sparsity/README.md b/torchao/sparsity/README.md index e644bd16df..be7fa8979b 100644 --- a/torchao/sparsity/README.md +++ b/torchao/sparsity/README.md @@ -90,7 +90,7 @@ We offer 
prototype support for accelerating block sparsity with our triton kerne ```py from torchao.sparsity.sparse_api import sparsify_ -from torchao.sparsity.prototype.superblock.blocksparse import block_sparse_weight +from torchao.prototype.sparsity.superblock.blocksparse import block_sparse_weight model = model.cuda() sparsify_(model, block_sparse_weight()) diff --git a/torchao/sparsity/prototype/__init__.py b/torchao/sparsity/prototype/__init__.py index 350a310501..924b7f409b 100644 --- a/torchao/sparsity/prototype/__init__.py +++ b/torchao/sparsity/prototype/__init__.py @@ -1,15 +1,20 @@ # Sparsifier -from torchao.sparsity.prototype.sparsifier.base_sparsifier import BaseSparsifier -from torchao.sparsity.prototype.sparsifier.weight_norm_sparsifier import WeightNormSparsifier -from torchao.sparsity.prototype.sparsifier.nearly_diagonal_sparsifier import NearlyDiagonalSparsifier - # Scheduler -from torchao.sparsity.prototype.scheduler.base_scheduler import BaseScheduler -from torchao.sparsity.prototype.scheduler.lambda_scheduler import LambdaSL -from torchao.sparsity.prototype.scheduler.cubic_scheduler import CubicSL +from torchao.prototype.sparsity.scheduler.base_scheduler import BaseScheduler +from torchao.prototype.sparsity.scheduler.cubic_scheduler import CubicSL +from torchao.prototype.sparsity.scheduler.lambda_scheduler import LambdaSL +from torchao.prototype.sparsity.sparsifier.base_sparsifier import BaseSparsifier +from torchao.prototype.sparsity.sparsifier.nearly_diagonal_sparsifier import ( + NearlyDiagonalSparsifier, +) # Parametrizations -from torchao.sparsity.prototype.sparsifier.utils import FakeSparsity -from torchao.sparsity.prototype.sparsifier.utils import module_to_fqn -from torchao.sparsity.prototype.sparsifier.utils import fqn_to_module -from torchao.sparsity.prototype.sparsifier.utils import get_arg_info_from_tensor_fqn +from torchao.prototype.sparsity.sparsifier.utils import ( + FakeSparsity, + fqn_to_module, + get_arg_info_from_tensor_fqn, + 
module_to_fqn, +) +from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import ( + WeightNormSparsifier, +) diff --git a/torchao/sparsity/prototype/pruner/FPGM_pruner.py b/torchao/sparsity/prototype/pruner/FPGM_pruner.py index d8c3d20052..412c395108 100644 --- a/torchao/sparsity/prototype/pruner/FPGM_pruner.py +++ b/torchao/sparsity/prototype/pruner/FPGM_pruner.py @@ -1,93 +1 @@ -from typing import Callable, Optional, Union - -import torch - -from .base_structured_sparsifier import BaseStructuredSparsifier - -__all__ = ["FPGMPruner"] - - -class FPGMPruner(BaseStructuredSparsifier): - r"""Filter Pruning via Geometric Median (FPGM) Structured Pruner - This sparsifier prune fliter (row) in a tensor according to distances among filters according to - `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>`_. - - This sparsifier is controlled by three variables: - 1. `sparsity_level` defines the number of filters (rows) that are zeroed-out. - 2. `dist` defines the distance measurement type. Default: 3 (L2 distance). - Available options are: [1, 2, (custom callable distance function)]. - - Note:: - Inputs should be a 4D convolutional tensor of shape (N, C, H, W). - - N: output channels size - - C: input channels size - - H: height of kernel - - W: width of kernel - """ - - def __init__( - self, sparsity_level: float = 0.5, dist: Optional[Union[Callable, int]] = None - ): - defaults = { - "sparsity_level": sparsity_level, - } - - if dist is None: - dist = 2 - - if callable(dist): - self.dist_fn = dist - elif dist == 1: - self.dist_fn = lambda x: torch.cdist(x, x, p=1) - elif dist == 2: - self.dist_fn = lambda x: torch.cdist(x, x, p=2) - else: - raise NotImplementedError("Distance function is not yet implemented.") - super().__init__(defaults=defaults) - - def _compute_distance(self, t): - r"""Compute distance across all entries in tensor `t` along all dimension - except for the one identified by dim.
- Args: - t (torch.Tensor): tensor representing the parameter to prune - Returns: - distance (torch.Tensor): distance computed across filtters - """ - dim = 0 # prune filter (row) - - size = t.size(dim) - slc = [slice(None)] * t.dim() - - # flatten the tensor along the dimension - t_flatten = [ - t[tuple(slc[:dim] + [slice(i, i + 1)] + slc[dim + 1 :])].reshape(-1) - for i in range(size) - ] - t_flatten = torch.stack(t_flatten) - - # distance measurement - dist_matrix = self.dist_fn(t_flatten) - - # more similar with other filter indicates large in the sum of row - distance = torch.sum(torch.abs(dist_matrix), 1) - - return distance - - def update_mask(self, module, tensor_name, sparsity_level, **kwargs): - tensor_weight = getattr(module, tensor_name) - mask = getattr(module.parametrizations, tensor_name)[0].mask - - if sparsity_level <= 0: - mask.data = torch.ones_like(mask).bool() - elif sparsity_level >= 1.0: - mask.data = torch.zeros_like(mask).bool() - else: - distance = self._compute_distance(tensor_weight) - - tensor_size = tensor_weight.shape[0] # prune filter (row) - nparams_toprune = round(sparsity_level * tensor_size) - nparams_toprune = min( - max(nparams_toprune, 0), tensor_size - ) # clamp to [0, tensor_size] - topk = torch.topk(distance, k=nparams_toprune, largest=False) - mask[topk.indices] = False +from torchao.prototype.sparsity.pruner.FPGM_pruner import FPGMPruner diff --git a/torchao/sparsity/prototype/pruner/base_structured_sparsifier.py b/torchao/sparsity/prototype/pruner/base_structured_sparsifier.py index b8e4112a79..257750d4df 100644 --- a/torchao/sparsity/prototype/pruner/base_structured_sparsifier.py +++ b/torchao/sparsity/prototype/pruner/base_structured_sparsifier.py @@ -1,312 +1,3 @@ -from itertools import chain -from operator import getitem -import torch -import torch.nn.functional as F -from torch import nn -from torch.fx import symbolic_trace -from torch.nn.utils import parametrize -from typing import Type, Set, Dict, Callable, Tuple, 
Optional, Union - -from torchao.sparsity.prototype import BaseSparsifier -from .parametrization import FakeStructuredSparsity, BiasHook, module_contains_param -from .match_utils import apply_match, MatchAllNode -from .prune_functions import ( - prune_linear, - prune_linear_linear, - prune_linear_activation_linear, - prune_conv2d, - prune_conv2d_conv2d, - prune_conv2d_activation_conv2d, - prune_conv2d_activation_pool_conv2d, - prune_conv2d_pool_activation_conv2d, - prune_conv2d_pool_flatten_linear, - prune_lstm_output_linear, - prune_lstm_output_layernorm_linear, +from torchao.prototype.sparsity.pruner.base_structured_sparsifier import ( + BaseStructuredSparsifier, ) - - -def _get_supported_structured_pruning_modules(): - SUPPORTED_STRUCTURED_PRUNING_MODULES = { # added to config if None given - nn.Linear, - nn.Conv2d, - nn.LSTM, - } - return SUPPORTED_STRUCTURED_PRUNING_MODULES - - -def _get_supported_activation_functions(): - SUPPORTED_ACTIVATION_FUNCTIONS = { - F.relu, - F.rrelu, - F.hardtanh, - F.relu6, - F.sigmoid, - F.hardsigmoid, - F.tanh, - F.silu, - F.mish, - F.hardswish, - F.elu, - F.celu, - F.selu, - F.hardshrink, - F.leaky_relu, - F.logsigmoid, - F.softplus, - F.prelu, - F.softsign, - F.tanhshrink, - F.gelu, - F.dropout, - } - return SUPPORTED_ACTIVATION_FUNCTIONS - - -def _get_supported_activation_modules(): - SUPPORTED_ACTIVATION_MODULES = { - nn.ReLU, - nn.RReLU, - nn.Hardtanh, - nn.ReLU6, - nn.Sigmoid, - nn.Hardsigmoid, - nn.Tanh, - nn.SiLU, - nn.Mish, - nn.Hardswish, - nn.ELU, - nn.CELU, - nn.SELU, - nn.Hardshrink, - nn.LeakyReLU, - nn.LogSigmoid, - nn.Softplus, - nn.PReLU, - nn.Softsign, - nn.Tanhshrink, - nn.GELU, - nn.Dropout, - } - return SUPPORTED_ACTIVATION_MODULES - - -def _get_default_structured_pruning_patterns() -> Dict[ - Tuple[Union[Type[nn.Module], Callable, MatchAllNode, str], ...], - Callable[..., None], -]: - """ - Returns the patterns for conv2d / linear conversion for each element in the activation functions/modules defined above. 
- """ - patterns: Dict[ - Tuple[Union[Type[nn.Module], Callable, MatchAllNode, str], ...], - Callable[..., None], - ] = { - # linear -> linear - (nn.Linear, "output"): prune_linear, - (nn.Linear, nn.Linear): prune_linear_linear, - # conv2d -> conv2d - (nn.Conv2d, "output"): prune_conv2d, - (nn.Conv2d, nn.Conv2d): prune_conv2d_conv2d, - # TODO LSTM Structured pruning does not support returned state currently. - # Should find a way to explicitly match getitem(0) instead of getitem. - # This will also require changing the pruning function. - # lstm -> getitem(0) -> linear - (nn.LSTM, getitem, nn.Linear): prune_lstm_output_linear, - # lstm -> getitem(0) -> layernorm -> linear - (nn.LSTM, getitem, nn.LayerNorm, nn.Linear): prune_lstm_output_layernorm_linear, - } - - for activation in chain( - _get_supported_activation_functions(), _get_supported_activation_modules() - ): - patterns.update( - { - # linear -> activation -> linear - (nn.Linear, activation, nn.Linear): prune_linear_activation_linear, - # conv2d -> activation -> conv2d - (nn.Conv2d, activation, nn.Conv2d): prune_conv2d_activation_conv2d, - # conv2d -> activation -> pool -> conv2d - ( - nn.Conv2d, - activation, - nn.AvgPool2d, - nn.Conv2d, - ): prune_conv2d_activation_pool_conv2d, - ( - nn.Conv2d, - activation, - F.avg_pool2d, - nn.Conv2d, - ): prune_conv2d_activation_pool_conv2d, - ( - nn.Conv2d, - activation, - nn.MaxPool2d, - nn.Conv2d, - ): prune_conv2d_activation_pool_conv2d, - ( - nn.Conv2d, - activation, - F.max_pool2d, - nn.Conv2d, - ): prune_conv2d_activation_pool_conv2d, - # conv2d -> pool -> activation -> conv2d - ( - nn.Conv2d, - nn.AvgPool2d, - activation, - nn.Conv2d, - ): prune_conv2d_pool_activation_conv2d, - ( - nn.Conv2d, - F.avg_pool2d, - activation, - nn.Conv2d, - ): prune_conv2d_pool_activation_conv2d, - ( - nn.Conv2d, - nn.MaxPool2d, - activation, - nn.Conv2d, - ): prune_conv2d_pool_activation_conv2d, - ( - nn.Conv2d, - F.max_pool2d, - activation, - nn.Conv2d, - ): 
prune_conv2d_pool_activation_conv2d, - # conv2d -> adaptive pool -> flatten -> linear - ( - nn.Conv2d, - nn.AdaptiveAvgPool2d, - nn.Flatten, - nn.Linear, - ): prune_conv2d_pool_flatten_linear, - ( - nn.Conv2d, - nn.AdaptiveAvgPool2d, - torch.flatten, - nn.Linear, - ): prune_conv2d_pool_flatten_linear, - ( - nn.Conv2d, - nn.AdaptiveMaxPool2d, - nn.Flatten, - nn.Linear, - ): prune_conv2d_pool_flatten_linear, - ( - nn.Conv2d, - nn.AdaptiveMaxPool2d, - torch.flatten, - nn.Linear, - ): prune_conv2d_pool_flatten_linear, - } - ) - return patterns - - -class BaseStructuredSparsifier(BaseSparsifier): - r"""Base class for structured pruning. - - Abstract methods that need to be implemented: - - update_mask: Function to compute a new mask for all keys in the - `groups` attribute. - - Args: - - defaults [dict]: default configurations will be attached to the - configuration. Only the keys that don't exist in the `config` will - be updated. - """ - - def __init__(self, defaults, patterns=None): - super().__init__(defaults) - if patterns is None: - patterns = _get_default_structured_pruning_patterns() - self.patterns = patterns - - def make_config_from_model( - self, - model: nn.Module, - SUPPORTED_MODULES: Optional[Set[Type]] = None, - ) -> None: - if SUPPORTED_MODULES is None: - SUPPORTED_MODULES = _get_supported_structured_pruning_modules() - super().make_config_from_model(model, SUPPORTED_MODULES=SUPPORTED_MODULES) - - def _prepare(self, *args, **kwargs) -> None: - r"""This function will attach the FakeStructuredSparsity parameterizations - and BiasHooks at the appropriate points in the model. 
- """ - for config in self.groups: - module = config["module"] - tensor_name = config["tensor_name"] - parametrization = config.get("parametrization", FakeStructuredSparsity) - tensor = getattr(module, tensor_name) - - mask = config.get( - "mask", - torch.ones(tensor.shape[0], dtype=torch.bool, device=tensor.device), - ) - self.state[config["tensor_fqn"]]["mask"] = mask - parametrize.register_parametrization( - module, tensor_name, parametrization(mask) - ) - - # if linear / conv, we add in bias hooks - if isinstance(module, (nn.Linear, nn.Conv2d)): - prune_bias = config.get("prune_bias", True) - if module.bias is not None: - module.register_parameter( - "_bias", nn.Parameter(module.bias.detach()) - ) - module.bias = None - module.prune_bias = prune_bias - - module.register_forward_hook( - BiasHook(module.parametrizations.weight[0], prune_bias) - ) - - def prune(self) -> None: - r""" - This function will FX symbolically trace the model and then find instances of the patterns - defined in self.patterns (by default SUPPORTED_STRUCTURED_PRUNING_PATTERNS ). 
- - For each pattern, it will apply to corresponding conversion function, which will modify the output - and input size expected by the modules within the pattern - """ - - self.traced = symbolic_trace(self.model) - modules = dict(self.traced.named_modules()) - - # Right now we check for matches simply by iterating across all the patterns - # if this is slow we can store patterns in a trie-structure and modify this code for faster lookup - for node in self.traced.graph.nodes: - for pattern, convert_fn in self.patterns.items(): - matched = apply_match(modules, pattern, node, []) - if matched is None: - continue - - first_module = modules.get(node.target) - # check if first module exists and has appropriate parameterization, otherwise skip - if ( - first_module is not None - and parametrize.is_parametrized(first_module) - and module_contains_param(first_module, FakeStructuredSparsity) - ): - convert_block = [] - for node in matched: - if node.op == "call_module": - convert_block.append(modules.get(node.target)) - elif node.op == "call_function": - convert_block.append(node.target) - convert_fn(*convert_block) - - for module in self.traced.modules(): - if module_contains_param(module, FakeStructuredSparsity): - raise Exception( - f"Error: {module} still contains FakeStructuredSparsity parametrizations!" - ) - - self.traced.graph.lint() - self.traced.recompile() - return self.traced diff --git a/torchao/sparsity/prototype/pruner/lstm_saliency_pruner.py b/torchao/sparsity/prototype/pruner/lstm_saliency_pruner.py index 4a0d74d6dc..9c1656bf47 100644 --- a/torchao/sparsity/prototype/pruner/lstm_saliency_pruner.py +++ b/torchao/sparsity/prototype/pruner/lstm_saliency_pruner.py @@ -1,48 +1 @@ -from typing import cast - -import torch -from .base_structured_sparsifier import BaseStructuredSparsifier, FakeStructuredSparsity - -class LSTMSaliencyPruner(BaseStructuredSparsifier): - """ - Prune packed LSTM weights based on saliency. 
- For each layer {k} inside a LSTM, we have two packed weight matrices - - weight_ih_l{k} - - weight_hh_l{k} - - These tensors pack the weights for the 4 linear layers together for efficiency. - - [W_ii | W_if | W_ig | W_io] - - Pruning this tensor directly will lead to weights being misassigned when unpacked. - To ensure that each packed linear layer is pruned the same amount: - 1. We split the packed weight into the 4 constituent linear parts - 2. Update the mask for each individual piece using saliency individually - - This applies to both weight_ih_l{k} and weight_hh_l{k}. - """ - - def update_mask(self, module, tensor_name, **kwargs): - weights = getattr(module, tensor_name) - - for p in getattr(module.parametrizations, tensor_name): - if isinstance(p, FakeStructuredSparsity): - mask = cast(torch.Tensor, p.mask) - - # select weights based on magnitude - if weights.dim() <= 1: - raise Exception("Structured pruning can only be applied to a 2+dim weight tensor!") - # take norm over all but first dim - dims = tuple(range(1, weights.dim())) - saliency = weights.norm(dim=dims, p=1) - - # handle weights in 4 groups - split_size = len(mask) // 4 - masks = torch.split(mask, split_size) - saliencies = torch.split(saliency, split_size) - - for keep_mask, sal in zip(masks, saliencies): - # mask smallest k values to be removed - k = int(len(keep_mask) * kwargs["sparsity_level"]) - prune = sal.topk(k, largest=False, sorted=False).indices - keep_mask.data[prune] = False # modifies underlying p.mask directly +from torchao.prototype.sparsity.pruner.lstm_saliency_pruner import LSTMSaliencyPruner diff --git a/torchao/sparsity/prototype/pruner/parametrization.py b/torchao/sparsity/prototype/pruner/parametrization.py index df94f7093b..8603639293 100644 --- a/torchao/sparsity/prototype/pruner/parametrization.py +++ b/torchao/sparsity/prototype/pruner/parametrization.py @@ -1,59 +1,4 @@ -import torch -from torch import nn -from torch.nn.utils.parametrize import is_parametrized - - 
-def module_contains_param(module, parametrization): - if is_parametrized(module): - # see if any of the module tensors have a parametriztion attached that matches the one passed in - return any( - any(isinstance(param, parametrization) for param in param_list) - for key, param_list in module.parametrizations.items() - ) - return False - - -# Structured Pruning Parameterizations -class FakeStructuredSparsity(nn.Module): - r""" - Parametrization for Structured Pruning. Like FakeSparsity, this should be attached to - the 'weight' or any other parameter that requires a mask. - - Instead of an element-wise bool mask, this parameterization uses a row-wise bool mask. - """ - - def __init__(self, mask): - super().__init__() - self.register_buffer("mask", mask) - - def forward(self, x): - assert isinstance(self.mask, torch.Tensor) - assert self.mask.shape[0] == x.shape[0] - shape = [1] * len(x.shape) - shape[0] = -1 - return self.mask.reshape(shape) * x - - def state_dict(self, *args, **kwargs): - # avoid double saving masks - return {} - - -class BiasHook: - def __init__(self, parametrization, prune_bias): - self.param = parametrization - self.prune_bias = prune_bias - - def __call__(self, module, input, output): - - if getattr(module, "_bias", None) is not None: - bias = module._bias.data - if self.prune_bias: - bias[~self.param.mask] = 0 - - # reshape bias to broadcast over output dimensions - idx = [1] * len(output.shape) - idx[1] = -1 - bias = bias.reshape(idx) - - output += bias - return output +from torchao.prototype.sparsity.pruner.parametrization import ( + BiasHook, + FakeStructuredSparsity, +) diff --git a/torchao/sparsity/prototype/pruner/saliency_pruner.py b/torchao/sparsity/prototype/pruner/saliency_pruner.py index f965fa647d..4f43ccf46e 100644 --- a/torchao/sparsity/prototype/pruner/saliency_pruner.py +++ b/torchao/sparsity/prototype/pruner/saliency_pruner.py @@ -1,29 +1 @@ -from .base_structured_sparsifier import BaseStructuredSparsifier - - -class 
SaliencyPruner(BaseStructuredSparsifier): - """ - Prune rows based on the saliency (L1 norm) of each row. - - This pruner works on N-Dimensional weight tensors. - For each row, we will calculate the saliency, whic is the sum the L1 norm of all weights in that row. - We expect that the resulting saliency vector has the same shape as our mask. - We then pick elements to remove until we reach the target sparsity_level. - """ - - def update_mask(self, module, tensor_name, **kwargs): - # tensor_name will give you the FQN, all other entries in sparse config is present in kwargs - weights = getattr(module, tensor_name) - mask = getattr(module.parametrizations, tensor_name)[0].mask - - # use negative weights so we can use topk (we prune out the smallest) - if weights.dim() <= 1: - raise Exception("Structured pruning can only be applied to a 2+dim weight tensor!") - saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1) - assert saliency.shape == mask.shape - - num_to_pick = int(len(mask) * kwargs["sparsity_level"]) - prune = saliency.topk(num_to_pick).indices - - # Set the mask to be false for the rows we want to prune - mask.data[prune] = False +from torchao.prototype.sparsity.pruner.saliency_pruner import SaliencyPruner diff --git a/torchao/sparsity/prototype/scheduler/base_scheduler.py b/torchao/sparsity/prototype/scheduler/base_scheduler.py index f102f351ea..877f419ac1 100644 --- a/torchao/sparsity/prototype/scheduler/base_scheduler.py +++ b/torchao/sparsity/prototype/scheduler/base_scheduler.py @@ -1,159 +1 @@ - -from functools import wraps -import warnings -import weakref - -from torchao.sparsity.prototype.sparsifier.base_sparsifier import BaseSparsifier - -__all__ = ["BaseScheduler"] - -class BaseScheduler: - - def __init__(self, sparsifier, last_epoch=-1, verbose=False): - - # Attach sparsifier - if not isinstance(sparsifier, BaseSparsifier): - raise TypeError(f'{type(sparsifier).__name__} is not an instance of torch.ao.pruning.BaseSparsifier') - 
self.sparsifier = sparsifier - - # Initialize epoch and base sparsity levels - - self.base_sl = [group['sparsity_level'] for group in sparsifier.groups] - self.last_epoch = last_epoch - - # Following https://github.com/pytorch/pytorch/issues/20124 - # We would like to ensure that `scheduler.step()` is called after - # `sparsifier.step()` - def with_counter(method): - if getattr(method, '_with_counter', False): - # `sparsifier.step()` has already been replaced, return. - return method - - # Keep a weak reference to the sparsifier instance to prevent - # cyclic references. - instance_ref = weakref.ref(method.__self__) - # Get the unbound method for the same purpose. - func = method.__func__ - cls = instance_ref().__class__ - del method - - @wraps(func) - def wrapper(*args, **kwargs): - instance = instance_ref() - instance._step_count += 1 # type: ignore[union-attr] - wrapped = func.__get__(instance, cls) - return wrapped(*args, **kwargs) - - # Note that the returned function here is no longer a bound method, - # so attributes like `__func__` and `__self__` no longer exist. - wrapper._with_counter = True # type: ignore[attr-defined] - return wrapper - - self.sparsifier.step = with_counter(self.sparsifier.step) # type: ignore[assignment] - self.sparsifier._step_count = 0 # type: ignore[attr-defined] - self._step_count: int = 0 - self.verbose = verbose - - # Housekeeping - self._get_sl_called_within_step: bool = False - - self.step() - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the sparsifier. - """ - return {key: value for key, value in self.__dict__.items() if key != 'sparsifier'} - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. 
- """ - self.__dict__.update(state_dict) - - def get_last_sl(self): - """ Return last computed sparsity level by current scheduler. - """ - return self._last_sl - - def get_sl(self): - # Compute sparsity level using chainable form of the scheduler - # Note: This method is not intended to be called directly, and is only - # used by the ".step" method. Use .get_last_sl() instead. - if not self._get_sl_called_within_step: - warnings.warn( - "To get the last sparsity level computed by the scheduler, " - "please use `get_last_sl()`.") - raise NotImplementedError - - def print_sl(self, is_verbose, group, sl, epoch=None): - """Display the current sparsity level. - """ - if is_verbose: - if epoch is None: - print(f'Adjusting sparsity level of group {group} to {sl:.4e}.') - else: - print(f'Epoch {epoch:5d}: adjusting sparsity level of group {group} to {sl:.4e}.') - - def __repr__(self): - format_string = self.__class__.__name__ + ' (' - format_string += '\n' - format_string += f'Sparsifier {self.sparsifier}\n' - format_string += f' base_sl: {self.base_sl}\n' - format_string += ')' - return format_string - - def step(self, epoch=None): - # Raise warning if trying to call scheduler step before the sparsifier. - # https://github.com/pytorch/pytorch/issues/20124 - if self._step_count == 1: - if not hasattr(self.sparsifier.step, "_with_counter"): - warnings.warn("Seems like `sparsifier.step()` has been overridden after sparsity scheduler " - "initialization. Please, make sure to call `sparsifier.step()` before " - "`scheduler.step()`.", UserWarning) - - # Just check if there were two first scheduler.step() calls before sparsifier.step() - elif self.sparsifier._step_count < 1: # type: ignore[attr-defined] - warnings.warn("Detected call of `scheduler.step()` before `sparsifier.step()`. 
" - "You have to make sure you run the sparsifier.step() BEFORE any " - "calls to the scheduler.step().", UserWarning) - self._step_count += 1 - - class _enable_get_sl_call: - - def __init__(self, o): - self.o = o - - def __enter__(self): - self.o._get_sl_called_within_step = True - return self - - def __exit__(self, type, value, traceback): - self.o._get_sl_called_within_step = False - - with _enable_get_sl_call(self): - self.last_epoch += 1 - values = self.get_sl() - - for i, data in enumerate(zip(self.sparsifier.groups, values)): - param_group, sl = data - param_group['sparsity_level'] = sl - self.print_sl(self.verbose, i, sl, epoch) - - self._last_sl = [group['sparsity_level'] for group in self.sparsifier.groups] - self.sparsifier.enable_mask_update = True - - def _make_sure_a_list(self, var): - r"""Utility that extends it to the same length as the .groups, ensuring it is a list""" - n = len(self.sparsifier.groups) - if not isinstance(var, (list, tuple)): - return [var] * n - else: - if len(var) != n: - raise ValueError(f"Expected variable of length {n}, but got {len(var)}") - return list(var) # We want the result to be in a list, not tuple +from torchao.prototype.sparsity.scheduler.base_scheduler import BaseScheduler diff --git a/torchao/sparsity/prototype/scheduler/cubic_scheduler.py b/torchao/sparsity/prototype/scheduler/cubic_scheduler.py index 76fc61daa2..9b86f16be6 100644 --- a/torchao/sparsity/prototype/scheduler/cubic_scheduler.py +++ b/torchao/sparsity/prototype/scheduler/cubic_scheduler.py @@ -1,107 +1 @@ -import warnings - -from .base_scheduler import BaseScheduler - -__all__ = ["CubicSL"] - -def _clamp(x, lo, hi): - return max(lo, min(hi, x)) - - -class CubicSL(BaseScheduler): - r"""Sets the sparsity level of each parameter group to the final sl - plus a given exponential function. - - .. 
math:: - - s_i = s_f + (s_0 - s_f) \cdot \left( 1 - \frac{t - t_0}{n\Delta t} \right)^3 - - where :math:`s_i` is the sparsity at epoch :math:`t`, :math;`s_f` is the final - sparsity level, :math:`f(i)` is the function to be applied to the current epoch - :math:`t`, initial epoch :math:`t_0`, and final epoch :math:`t_f`. - :math:`\Delta t` is used to control how often the update of the sparsity level - happens. By default, - - Args: - sparsifier (BaseSparsifier): Wrapped sparsifier. - init_sl (int, list): Initial level of sparsity - init_t (int, list): Initial step, when pruning starts - delta_t (int, list): Pruning frequency - total_t (int, list): Total number of pruning steps - initially_zero (bool, list): If True, sets the level of sparsity to 0 - before init_t (:math:`t_0`). Otherwise, the sparsity level before - init_t (:math:`t_0`) is set to init_sl(:math:`s_0`) - last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - """ - def __init__(self, - sparsifier, - init_sl=0.0, - init_t=0, - delta_t=10, - total_t=100, - initially_zero=False, - last_epoch=-1, - verbose=False - ): - self.sparsifier = sparsifier - - self.init_sl = self._make_sure_a_list(init_sl) - self.init_t = self._make_sure_a_list(init_t) - self.delta_t = self._make_sure_a_list(delta_t) - self.total_t = self._make_sure_a_list(total_t) - - self.initially_zero = self._make_sure_a_list(initially_zero) - - super().__init__(sparsifier, last_epoch, verbose) - - @staticmethod - def sparsity_compute_fn(s_0, s_f, t, t_0, dt, n, initially_zero=False): - r""""Computes the current level of sparsity. 
- - Based on https://arxiv.org/pdf/1710.01878.pdf - - Args: - s_0: Initial level of sparsity, :math:`s_i` - s_f: Target level of sparsity, :math:`s_f` - t: Current step, :math:`t` - t_0: Initial step, :math:`t_0` - dt: Pruning frequency, :math:`\Delta T` - n: Pruning steps, :math:`n` - initially_zero: Sets the level of sparsity to 0 before t_0. - If False, sets to s_0 - - Returns: - The sparsity level :math:`s_t` at the current step :math:`t` - """ - if initially_zero and t < t_0: - return 0 - s_t = s_f + (s_0 - s_f) * (1.0 - (t - t_0) / (dt * n)) ** 3 - s_t = _clamp(s_t, s_0, s_f) - return s_t - - def get_sl(self): - if not self._get_sl_called_within_step: - warnings.warn( - "To get the last sparsity level computed by the scheduler, " - "please use `get_last_sl()`.") - return [ - self.sparsity_compute_fn( - s_0=initial_sparsity, - s_f=final_sparsity, - t=self.last_epoch, - t_0=initial_epoch, - dt=delta_epoch, - n=interval_epochs, - initially_zero=initially_zero - ) for initial_sparsity, final_sparsity, initial_epoch, delta_epoch, interval_epochs, initially_zero in - zip( - self.init_sl, - self.base_sl, - self.init_t, - self.delta_t, - self.total_t, - self.initially_zero - ) - ] +from torchao.prototype.sparsity.scheduler.cubic_scheduler import CubicSL diff --git a/torchao/sparsity/prototype/scheduler/lambda_scheduler.py b/torchao/sparsity/prototype/scheduler/lambda_scheduler.py index a88d99a1f8..0730558195 100644 --- a/torchao/sparsity/prototype/scheduler/lambda_scheduler.py +++ b/torchao/sparsity/prototype/scheduler/lambda_scheduler.py @@ -1,47 +1 @@ -import warnings - -from .base_scheduler import BaseScheduler - -__all__ = ["LambdaSL"] - -class LambdaSL(BaseScheduler): - """Sets the sparsity level of each parameter group to the final sl - times a given function. When last_epoch=-1, sets initial sl as zero. - Args: - sparsifier (BaseSparsifier): Wrapped sparsifier. 
- sl_lambda (function or list): A function which computes a multiplicative - factor given an integer parameter epoch, or a list of such - functions, one for each group in sparsifier.param_groups. - last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - Example: - >>> # Assuming sparsifier has two groups. - >>> lambda1 = lambda epoch: epoch // 30 - >>> lambda2 = lambda epoch: 0.95 ** epoch - >>> # xdoctest: +SKIP - >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2]) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) - >>> scheduler.step() - """ - - def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False): - self.sparsifier = sparsifier - - if not isinstance(sl_lambda, list) and not isinstance(sl_lambda, tuple): - self.sl_lambdas = [sl_lambda] * len(sparsifier.groups) - else: - if len(sl_lambda) != len(sparsifier.groups): - raise ValueError(f"Expected {len(sparsifier.groups)} lr_lambdas, but got {len(sl_lambda)}") - self.sl_lambdas = list(sl_lambda) - super().__init__(sparsifier, last_epoch, verbose) - - def get_sl(self): - if not self._get_sl_called_within_step: - warnings.warn( - "To get the last sparsity level computed by the scheduler, " - "please use `get_last_sl()`.") - return [base_sl * lmbda(self.last_epoch) - for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl)] +from torchao.prototype.sparsity.scheduler.lambda_scheduler import LambdaSL diff --git a/torchao/sparsity/prototype/sparsifier/base_sparsifier.py b/torchao/sparsity/prototype/sparsifier/base_sparsifier.py index 1c210ace34..954c06b74f 100644 --- a/torchao/sparsity/prototype/sparsifier/base_sparsifier.py +++ b/torchao/sparsity/prototype/sparsifier/base_sparsifier.py @@ -1,353 +1 @@ -import abc -import copy -from collections import defaultdict -from typing import Any, Dict, Optional, Set, Tuple, List, Type - -import torch -from torch import nn -from 
torch.nn.utils import parametrize -from torch.nn.utils.parametrize import type_before_parametrizations - -from .utils import ( - module_contains_param, - swap_module, - FakeSparsity, - get_arg_info_from_tensor_fqn, - module_to_fqn, -) - -__all__ = ["BaseSparsifier"] - -SUPPORTED_MODULES = {nn.Linear} - -KEYS_NOT_IN_STATE_DICT = ["module", "module_fqn", "tensor_name"] - -__all__ = ["BaseSparsifier"] - - -# TODO update desc with new config args -class BaseSparsifier(abc.ABC): - r"""Base class for all sparsifiers. - - Abstract methods that need to be implemented: - - - update_mask: Function to compute a new mask for all keys in the - `groups`. - - Args: - - model [nn.Module]: model to configure. The model itself is not saved - but used for the state_dict saving / loading. - - config [list]: configuration elements should be a dict map that includes - `tensor_fqn` of tensors to sparsify - - defaults [dict]: default configurations will be attached to the - configuration. Only the keys that don't exist in the `config` will - be updated. 
- - Example:: - - >>> # xdoctest: +SKIP("Can't instantiate abstract class BaseSparsifier with abstract method update_mask") - >>> config = [{'tensor_fqn': 'layer1.weight', 'tensor_fqn': 'linear2.weight2', 'sparsity_level': 0.5}] - >>> defaults = {'sparsity_level': 0.7} - >>> # model.layer1.weight will have `sparsity_level` = 0.7 (getting default) - >>> sparsifier = BaseSparsifier(config, defaults) - """ - - def __init__(self, defaults: Optional[Dict[str, Any]] = None): - super().__init__() - self.defaults: Dict[str, Any] = defaults or {} - - self.state: Dict[str, Dict] = defaultdict(dict) - self.groups: List[Dict[str, Any]] = [] - self.enable_mask_update = True - - def __getstate__(self) -> Dict[str, Any]: - return { - "defaults": self.defaults, - "state": self.state, - "groups": self.groups, - } - - def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None: - self.__dict__.update(state) - - def __repr__(self): - format_string = self.__class__.__name__ + " (" - for i, sparse_args in enumerate(self.groups): - module = sparse_args["module"] - format_string += "\n" - format_string += f"\tGroup {i}\n" - format_string += f"\t module: {module}\n" - for key in sorted(sparse_args.keys()): - if key == "module": - continue - format_string += f"\t {key}: {sparse_args[key]}\n" - format_string += ")" - return format_string - - def state_dict(self) -> Dict[str, Any]: - r"""Returns the state of the optimizer as a :class:`dict`. - - It contains: - * state - current state of the sparsification. 
- * groups - a list containing all sparsity configuration groups - with the key 'tensor_fqn' specifying the path to the sparsified tensor within a model - - TODO: Need a clean way of loading the state of the "prepared" module - """ - - groups: List[Dict[str, Any]] = [ - dict( - filter( - lambda key_value: key_value[0] not in KEYS_NOT_IN_STATE_DICT, - mg.items(), - ) - ) - for mg in self.groups - ] - - return { - "state": self.state, - "groups": groups, - } - - def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True): - groups = copy.deepcopy(state_dict["groups"]) - states = state_dict["state"] - for tensor_fqn, s in states.items(): - arg_info = get_arg_info_from_tensor_fqn(self.model, tensor_fqn) - module = arg_info["module"] - tensor_name = arg_info["tensor_name"] - if strict and module is None: - raise RuntimeError(f"Error loading {tensor_fqn} into the model") - - found = False - for p in module.parametrizations[tensor_name]: - if isinstance(p, FakeSparsity): - found = True - break - if not found: - p = FakeSparsity(torch.ones(getattr(module, tensor_name).shape)) - parametrize.register_parametrization(module, tensor_name, p) - if s.get("mask", None) is not None: - mask = s.pop("mask") - p.mask = mask - - for mg in groups: - if mg["tensor_fqn"] == tensor_fqn: - mg.update(arg_info) - self.__setstate__({"state": states, "groups": groups}) - - def make_config_from_model( - self, - model: nn.Module, - SUPPORTED_MODULES: Set[Type] = SUPPORTED_MODULES, - ) -> None: - self.config = [] - stack = [model] - while stack: - module = stack.pop() - for name, child in module.named_children(): - if type(child) in SUPPORTED_MODULES: - module_fqn = module_to_fqn(model, child) - assert isinstance(module_fqn, str) # for mypy - self.config.append({"tensor_fqn": module_fqn + ".weight"}) - else: - stack.append(child) - - def prepare(self, model, config): - r"""Prepares a model, by adding the parametrizations. - - Note:: - - The model is modified inplace. 
If you need to preserve the original - model, use copy.deepcopy. - """ - self.model = model # TODO: Need to figure out how to load without this. - self.config = config - - # If no config -- try getting all the supported layers - if self.config is None: - self.make_config_from_model(model) - - # TODO: Remove the configuration by reference ('module') - for module_config in self.config: - assert isinstance(module_config, dict), ( - "config elements should be dicts not modules i.e.:" - "[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]" - ) - - assert isinstance(self.defaults, Dict) # for mypy - local_args = copy.deepcopy(self.defaults) - local_args.update(module_config) - - tensor_fqn = local_args.get("tensor_fqn", None) - assert tensor_fqn is not None, ( - "tensor_fqn is a required argument in the sparsity config which" - "replaces previous `module` and [module]`fqn` arguments" - ) - - # populate all information from tensor_fqn - info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn) - - # check that whatever was put into local_args agrees with what was obtained - # from tensor_fqn - for key in info_from_tensor_fqn.keys(): - if key in local_args: - assert ( - info_from_tensor_fqn[key] == local_args[key] - or ( - key == "tensor_fqn" - and "." + info_from_tensor_fqn[key] == local_args[key] - ) - # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that - ), ( - f"Given both `{key}` and `tensor_fqn` in the config, it is expected them to agree!" 
- ) - local_args.update(info_from_tensor_fqn) - self.groups.append(local_args) - self._prepare() - - def _prepare(self, *args, **kwargs): - r"""Adds mask parametrization to the layer weight""" - for config in self.groups: - module = config["module"] - tensor_name = config["tensor_name"] - parametrization = config.get("parametrization", FakeSparsity) - mask = config.get("mask", torch.ones_like(getattr(module, tensor_name))) - self.state[config["tensor_fqn"]]["mask"] = mask - parametrize.register_parametrization( - module, tensor_name, parametrization(mask) - ) - - def squash_mask( - self, - params_to_keep: Optional[Tuple[str, ...]] = None, - params_to_keep_per_layer: Optional[Dict[str, Tuple[str, ...]]] = None, - *args, - **kwargs, - ): - r"""Squashes the sparse masks into the appropriate tensors. - - If either the `params_to_keep` or `params_to_keep_per_layer` is set, - the module will have a `sparse_params` dict attached to it. - - Args: - params_to_keep: List of keys to save in the module or a dict - representing the modules and keys that will have - sparsity parameters saved - params_to_keep_per_layer: Dict to specify the params that should be - saved for specific layers. The keys in the dict - should be the module fqn, while the values should - be a list of strings with the names of the variables - to save in the `sparse_params` - - Examples: - >>> # xdoctest: +SKIP("locals are undefined") - >>> # Don't save any sparse params - >>> sparsifier.squash_mask() - >>> hasattr(model.submodule1, 'sparse_params') - False - - >>> # Keep sparse params per layer - >>> sparsifier.squash_mask( - ... params_to_keep_per_layer={ - ... 'submodule1.linear1': ('foo', 'bar'), - ... 'submodule2.linear42': ('baz',) - ... 
}) - >>> print(model.submodule1.linear1.sparse_params) - {'foo': 42, 'bar': 24} - >>> print(model.submodule2.linear42.sparse_params) - {'baz': 0.1} - - >>> # Keep sparse params for all layers - >>> sparsifier.squash_mask(params_to_keep=('foo', 'bar')) - >>> print(model.submodule1.linear1.sparse_params) - {'foo': 42, 'bar': 24} - >>> print(model.submodule2.linear42.sparse_params) - {'foo': 42, 'bar': 24} - - >>> # Keep some sparse params for all layers, and specific ones for - >>> # some other layers - >>> sparsifier.squash_mask( - ... params_to_keep=('foo', 'bar'), - ... params_to_keep_per_layer={ - ... 'submodule2.linear42': ('baz',) - ... }) - >>> print(model.submodule1.linear1.sparse_params) - {'foo': 42, 'bar': 24} - >>> print(model.submodule2.linear42.sparse_params) - {'foo': 42, 'bar': 24, 'baz': 0.1} - """ - for config in self.groups: - module = config["module"] - tensor_name = config["tensor_name"] - parametrize.remove_parametrizations( - module, tensor_name, leave_parametrized=True - ) - sparse_params = {} - if params_to_keep is not None: - global_params = {k: config[k] for k in params_to_keep} - sparse_params.update(global_params) - if params_to_keep_per_layer is not None: - params = params_to_keep_per_layer.get(config["module_fqn"], None) - if params is not None: - per_layer_params = {k: config[k] for k in params} - sparse_params.update(per_layer_params) - if sparse_params: - # TODO handle multiple tensor being quantized on a single module, where to store sparse_params? 
- module.sparse_params = sparse_params - - def convert( - self, - module: nn.Module, - mapping: Optional[Dict[Type[nn.Module], Type[nn.Module]]] = None, - inplace: bool = False, - parameterization: Type[nn.Module] = FakeSparsity, - ): - r"""Converts submodules in input module to a different module according to `mapping` - by calling `from_dense` method on the target module class - Args: - module: input module - mapping: a dictionary that maps from source module type to target - module type, can be overwritten to allow swapping user defined - Modules - inplace: carry out model transformations in-place, the original module - is mutated - """ - if mapping is None: - raise NotImplementedError("Need to auto generate mapping ") - if not inplace: - module = copy.deepcopy(module) - - reassign = {} - for name, mod in module.named_children(): - # leaf node - if ( - module_contains_param(mod, parameterization) - and type_before_parametrizations(mod) in mapping - ): - reassign[name] = swap_module(mod, mapping) - else: - # recurse - reassign[name] = self.convert( - mod, - mapping=mapping, - inplace=True, - parameterization=parameterization, - ) - - for key, value in reassign.items(): - module._modules[key] = value - - return module - - def step(self, use_path: bool = True) -> None: - if not self.enable_mask_update: - return - with torch.no_grad(): - for config in self.groups: - self.update_mask(**config) - - @abc.abstractmethod - def update_mask(self, module: nn.Module, tensor_name: str, **kwargs): - pass +from torchao.prototype.sparsity.sparsifier.base_sparsifier import BaseSparsifier diff --git a/torchao/sparsity/prototype/sparsifier/nearly_diagonal_sparsifier.py b/torchao/sparsity/prototype/sparsifier/nearly_diagonal_sparsifier.py index 4f44e81485..640ec667c2 100644 --- a/torchao/sparsity/prototype/sparsifier/nearly_diagonal_sparsifier.py +++ b/torchao/sparsity/prototype/sparsifier/nearly_diagonal_sparsifier.py @@ -1,55 +1,3 @@ -import torch - -from . 
import base_sparsifier - - -class NearlyDiagonalSparsifier(base_sparsifier.BaseSparsifier): - r"""Nearly Diagonal Sparsifier - - This sparsifier creates a nearly diagonal mask to be applied to the weight matrix. - Nearly Diagonal Matrix is a matrix that contains non-zero elements near the diagonal and the rest are zero. - An example of a nearly diagonal matrix with degree (or nearliness) 3 and 5 are follows respectively. - 1 1 0 0 1 1 1 0 - 1 1 1 0 1 1 1 1 - 0 1 1 1 1 1 1 1 - 0 0 1 1 0 1 1 1 - Note that a nearly diagonal matrix with degree 1 is just a matrix with main diagonal populated - - This sparsifier is controlled by one variable: - 1. `nearliness` defines the number of non-zero diagonal lines that are closest to the main diagonal. - Currently - supports only odd number - - Note: - This can be accelerated (vectorized) once the Spdiagonal feature (PR: #78439) is landed or the banded matrix - feature is landed: https://stackoverflow.com/questions/52463972/generating-banded-matrices-using-numpy - - Args: - nearliness: The degree of nearliness (default = 1) - - """ - def __init__(self, nearliness: int = 1): - defaults = {'nearliness': nearliness} - super().__init__(defaults=defaults) - - def update_mask(self, module, tensor_name, nearliness, - **kwargs): - mask = getattr(module.parametrizations, tensor_name)[0].mask - mask.data = torch.zeros_like(mask) - if nearliness <= 0: - return - - tensor = getattr(module, tensor_name) - height, width = tensor.shape - - if nearliness % 2 == 0: - raise ValueError("nearliness can only be an odd number") - dist_to_diagonal = nearliness // 2 - # check - if dist_to_diagonal >= min(height, width): - raise ValueError("nearliness cannot be larger than the dimensions of tensor.") - - for row in range(0, height): - # Bounds of entries that needs to be set to 1 - low = max(0, row - dist_to_diagonal) - high = min(width, row + dist_to_diagonal + 1) - mask[row, low:high].fill_(1) +from 
torchao.prototype.sparsity.sparsifier.nearly_diagonal_sparsifier import ( + NearlyDiagonalSparsifier, +) diff --git a/torchao/sparsity/prototype/sparsifier/utils.py b/torchao/sparsity/prototype/sparsifier/utils.py index c52af88698..f410d1b325 100644 --- a/torchao/sparsity/prototype/sparsifier/utils.py +++ b/torchao/sparsity/prototype/sparsifier/utils.py @@ -1,130 +1 @@ -from typing import Any, Dict, Optional, Type -from torch.nn.utils.parametrize import type_before_parametrizations, is_parametrized -from itertools import chain - -from torch import nn - -__all__ = [ - "module_contains_param", - "swap_module", - "module_to_fqn", - "fqn_to_module", - "get_arg_info_from_tensor_fqn", - "FakeSparsity", -] - - -def module_contains_param(module: nn.Module, parametrization: Type[nn.Module]) -> bool: - if is_parametrized(module): - # see if any of the module tensors have a parametriztion attached that matches the one passed in - return any( - any(isinstance(param, parametrization) for param in param_list) - for key, param_list in module.parametrizations.items() # type: ignore[union-attr,operator] - ) - return False - - -def swap_module( - mod: nn.Module, mapping: Dict[Type[nn.Module], Type[nn.Module]] -) -> nn.Module: - r"""Swaps the module using from_dense according to the mapping passed in. - Args: - mod: input module - mapping: a dictionary that maps from nn module to sparse nn module - Return: - The corresponding sparse module of `mod` according to mapping, created using from_dense - """ - if type_before_parametrizations(mod) in mapping: - sparse_mod = mapping[type_before_parametrizations(mod)] - - # TODO Fix this typing, as Type[Module] has no attribute "from_dense" - new_mod = sparse_mod.from_dense(mod) # type: ignore[attr-defined] - - # Preserve module's pre forward hooks. 
They'll be called on quantized input - for pre_hook_fn in mod._forward_pre_hooks.values(): - new_mod.register_forward_pre_hook(pre_hook_fn) - # Preserve module's post forward hooks except _observer_forward_hook - # After convert they'll work with quantized output - for hook_fn in mod._forward_hooks.values(): - new_mod.register_forward_hook(hook_fn) - - # respect device affinity when swapping modules - devices = {p.device for p in chain(mod.parameters(), mod.buffers())} - assert len(devices) <= 1, ( - f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" - ) - device = next(iter(devices)) if len(devices) > 0 else None - if device: - new_mod.to(device) - - return new_mod - - else: - return mod - - -def module_to_fqn( - model: nn.Module, module: nn.Module, prefix: str = "" -) -> Optional[str]: - """ - Returns the fqn for a module or None if module not a descendent of model. - """ - if module is model: - return "" - for name, child in model.named_children(): - fqn = module_to_fqn(child, module, ".") - if isinstance(fqn, str): - return prefix + name + fqn - return None - - -def fqn_to_module(model: Optional[nn.Module], path: str) -> Optional[nn.Module]: - """ - Given an fqn, returns the corresponding module or tensor or None if the fqn given by `path` - doesn't correspond to anything. Similar to model.get_submodule(path) but works for tensors. 
- """ - if path != "": - for name in path.split("."): - model = getattr(model, name, None) - return model - - -def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> Dict[str, Any]: - """ - Uses tensor_fqn to obtain a dict containing module_fqn, module and tensor_name - """ - # string manip to split tensor_fqn into module_fqn and tensor_name - # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight' - # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight' - tensor_name = tensor_fqn.split(".")[-1] - module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)] - - module = fqn_to_module(model, module_fqn) - - return { - "module_fqn": module_fqn, - "module": module, - "tensor_name": tensor_name, - "tensor_fqn": tensor_fqn, - } - - -# Parametrizations -class FakeSparsity(nn.Module): - r"""Parametrization for the weights. Should be attached to the 'weight' or - any other parameter that requires a mask applied to it. - - Note:: - - Once the mask is passed, the variable should not change the id. The - contents of the mask can change, but the mask reference itself should - not. 
- """ - - def __init__(self, mask): - super().__init__() - self.register_buffer("mask", mask) - - def forward(self, x): - assert self.mask.shape == x.shape - return self.mask * x +from torchao.prototype.sparsity.sparsifier.utils import FakeSparsity diff --git a/torchao/sparsity/prototype/sparsifier/weight_norm_sparsifier.py b/torchao/sparsity/prototype/sparsifier/weight_norm_sparsifier.py index 2b24ca3d82..a490e5f65b 100644 --- a/torchao/sparsity/prototype/sparsifier/weight_norm_sparsifier.py +++ b/torchao/sparsity/prototype/sparsifier/weight_norm_sparsifier.py @@ -1,200 +1,3 @@ -from functools import reduce -from typing import Callable, Optional, Tuple, Union - -import torch -import torch.nn.functional as F - -from .base_sparsifier import BaseSparsifier -import operator - -__all__ = ["WeightNormSparsifier"] - -def _flat_idx_to_2d(idx, shape): - rows = idx // shape[1] - cols = idx % shape[1] - return rows, cols - -class WeightNormSparsifier(BaseSparsifier): - r"""Weight-Norm Sparsifier - - This sparsifier computes the norm of every sparse block and "zeroes-out" the - ones with the lowest norm. The level of sparsity defines how many of the - blocks is removed. - - This sparsifier is controlled by three variables: - 1. `sparsity_level` defines the number of *sparse blocks* that are zeroed-out - 2. `sparse_block_shape` defines the shape of the sparse blocks. Note that - the sparse blocks originate at the zero-index of the tensor. - 3. `zeros_per_block` is the number of zeros that we are expecting in each - sparse block. By default we assume that all elements within a block are - zeroed-out. However, setting this variable sets the target number of - zeros per block. The zeros within each block are chosen as the *smallest - absolute values*. - - Args: - - sparsity_level: The target level of sparsity - sparse_block_shape: The shape of a sparse block (see note below) - zeros_per_block: Number of zeros in a sparse block - norm: Norm to use. 
Could be either `int` or a callable. - If `int`, only L1 and L2 are implemented. - - Note:: - The `sparse_block_shape` is tuple representing (block_ROWS, block_COLS), - irrespective of what the rows / cols mean in the data tensor. That means, - if you were to sparsify a weight tensor in the nn.Linear, which has a - weight shape `(Cout, Cin)`, the `block_ROWS` would refer to the output - channels, while the `block_COLS` would refer to the input channels. - - Note:: - All arguments to the WeightNormSparsifier constructor are "default" - arguments and could be overriden by the configuration provided in the - `prepare` step. - """ - def __init__(self, - sparsity_level: float = 0.5, - sparse_block_shape: Tuple[int, int] = (1, 4), - zeros_per_block: Optional[int] = None, - norm: Optional[Union[Callable, int]] = None): - if zeros_per_block is None: - zeros_per_block = reduce(operator.mul, sparse_block_shape) - defaults = { - "sparsity_level": sparsity_level, - "sparse_block_shape": sparse_block_shape, - "zeros_per_block": zeros_per_block, - } - if norm is None: - norm = 2 - if callable(norm): - self.norm_fn = norm - elif norm == 1: - self.norm_fn = lambda T: T.abs() - elif norm == 2: - self.norm_fn = lambda T: T * T - else: - raise NotImplementedError(f"L-{norm} is not yet implemented.") - super().__init__(defaults=defaults) - - def _scatter_fold_block_mask(self, output_shape, dim, indices, block_shape, - mask=None, input_shape=None, device=None): - r"""Creates patches of size `block_shape` after scattering the indices.""" - if mask is None: - assert input_shape is not None - mask = torch.ones(input_shape, device=device) - mask.scatter_(dim=dim, index=indices, value=0) - mask.data = F.fold(mask, output_size=output_shape, kernel_size=block_shape, stride=block_shape) - return mask - - def _make_tensor_mask(self, data, input_shape, sparsity_level, sparse_block_shape, mask=None): - r"""Creates a tensor-level mask. 
- - Tensor-level mask is described as a mask, where the granularity of sparsification of the - smallest patch is the sparse_block_shape. That means, that for a given mask and a - sparse_block_shape, the smallest "patch" of zeros/ones could be the sparse_block_shape. - - In this context, `sparsity_level` describes the fraction of sparse patches. - """ - h, w = data.shape[-2:] - block_h, block_w = sparse_block_shape - dh = (block_h - h % block_h) % block_h - dw = (block_w - w % block_w) % block_w - - if mask is None: - mask = torch.ones(h + dh, w + dw, device=data.device) - - if sparsity_level >= 1.0: - mask.data = torch.zeros_like(mask) - return mask - elif sparsity_level <= 0.0: - mask.data = torch.ones_like(mask) - return mask - - values_per_block = reduce(operator.mul, sparse_block_shape) - if values_per_block > 1: - # Reduce the data - data = F.avg_pool2d( - data[None, None, :], kernel_size=sparse_block_shape, stride=sparse_block_shape, ceil_mode=True - ) - data = data.flatten() - num_blocks = len(data) - - data = data.repeat(1, values_per_block, 1) - - threshold_idx = int(round(sparsity_level * num_blocks)) - threshold_idx = max(0, min(num_blocks - 1, threshold_idx)) # Sanity check - _, sorted_idx = torch.topk(data, k=threshold_idx, dim=2, largest=False) - - # Temp reshape for mask - mask_reshape = mask.reshape(data.shape) # data might be reshaped - self._scatter_fold_block_mask( - dim=2, output_shape=(h + dh, w + dw), - indices=sorted_idx, block_shape=sparse_block_shape, mask=mask_reshape - ) - mask.data = mask_reshape.squeeze().reshape(mask.shape)[:h, :w].contiguous() - return mask - - def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None): - r"""Creates a block-level mask. - - Block-level mask is described as a mask, where the granularity of sparsification of the - largest patch is the sparse_block_shape. 
That means that for a given mask and a - sparse_block_shape, the sparsity is computed only within a patch of a size sparse_block_shape. - - In this context the `zeros_per_block` describes the number of zeroed-out elements within a patch. - """ - h, w = data.shape[-2:] - block_h, block_w = sparse_block_shape - dh = (block_h - h % block_h) % block_h - dw = (block_w - w % block_w) % block_w - values_per_block = reduce(operator.mul, sparse_block_shape) - - if mask is None: - mask = torch.ones((h + dh, w + dw), device=data.device) - - if values_per_block == zeros_per_block: - # Everything should be sparsified - mask.data = torch.zeros_like(mask) - return mask - - # create a new padded tensor like data (to match the block_shape) - padded_data = torch.ones(h + dh, w + dw, dtype=data.dtype, device=data.device) - padded_data.fill_(torch.nan) - padded_data[:h, :w] = data - unfolded_data = F.unfold(padded_data[None, None, :], kernel_size=sparse_block_shape, stride=sparse_block_shape) - - # Temp reshape for mask - mask_reshape = mask.reshape(unfolded_data.shape) - _, sorted_idx = torch.topk(unfolded_data, k=zeros_per_block, dim=1, largest=False) - - self._scatter_fold_block_mask( - dim=1, indices=sorted_idx, output_shape=padded_data.shape, block_shape=sparse_block_shape, mask=mask_reshape - ) - - mask.data = mask_reshape.squeeze().reshape(mask.shape).contiguous() - return mask - - def update_mask(self, module, tensor_name, sparsity_level, sparse_block_shape, - zeros_per_block, **kwargs): - values_per_block = reduce(operator.mul, sparse_block_shape) - if zeros_per_block > values_per_block: - raise ValueError( - "Number of zeros per block cannot be more than the total number of elements in that block." 
- ) - if zeros_per_block < 0: - raise ValueError("Number of zeros per block should be positive.") - - mask = getattr(module.parametrizations, tensor_name)[0].mask - if sparsity_level <= 0 or zeros_per_block == 0: - mask.data = torch.ones_like(mask) - elif sparsity_level >= 1.0 and (zeros_per_block == values_per_block): - mask.data = torch.zeros_like(mask) - else: - ww = self.norm_fn(getattr(module, tensor_name)) - tensor_mask = self._make_tensor_mask( - data=ww, input_shape=ww.shape, sparsity_level=sparsity_level, sparse_block_shape=sparse_block_shape - ) - if values_per_block != zeros_per_block: - block_mask = self._make_block_mask(data=ww, sparse_block_shape=sparse_block_shape, - zeros_per_block=zeros_per_block) - tensor_mask = torch.logical_or(tensor_mask, block_mask) - mask.data = tensor_mask +from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import ( + WeightNormSparsifier, +) diff --git a/torchao/sparsity/prototype/superblock/blocksparse.py b/torchao/sparsity/prototype/superblock/blocksparse.py index 69c98f6afc..a1696c02b7 100644 --- a/torchao/sparsity/prototype/superblock/blocksparse.py +++ b/torchao/sparsity/prototype/superblock/blocksparse.py @@ -1,239 +1 @@ -from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple - -import torch -from torch.sparse._triton_ops import broadcast_batch_dims, bsr_dense_addmm, bsr_dense_mm -from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.quantization.quant_api import _get_linear_subclass_inserter -from torchao.utils import TorchAOBaseTensor - -aten = torch.ops.aten - - -# quantization support -@torch.library.custom_op("blocksparse::bsr_to_dense", mutates_args=()) -def bsr_to_dense( - crow_indices: torch.Tensor, - col_indices: torch.Tensor, - values: torch.Tensor, - M: int, - K: int, -) -> torch.Tensor: - return torch.sparse_bsr_tensor( - crow_indices=crow_indices, col_indices=col_indices, values=values, size=(M, K) - ).to_dense() - - 
-@torch.library.register_fake("blocksparse::bsr_to_dense") -def bsr_to_dense_abstract( - crow_indices: torch.Tensor, - col_indices: torch.Tensor, - values: torch.Tensor, - M: int, - K: int, -) -> torch.Tensor: - return torch.empty((M, K), dtype=values.dtype, device=values.device) - - -@torch.library.custom_op("blocksparse::int_addmm", mutates_args=()) -def blocksparse_int_addmm( - crow_indices: torch.Tensor, - col_indices: torch.Tensor, - values: torch.Tensor, - A: torch.Tensor, - left_alpha: torch.Tensor, - right_alpha: torch.Tensor, -) -> torch.Tensor: - assert values.dtype == torch.int8 - M = left_alpha.shape[-1] - K = A.shape[-2] - N = A.shape[-1] - weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K)) - original_batch_dims_broadcasted = broadcast_batch_dims( - blocksparse_int_addmm, weight_bsr, A - ) - out = A.new_empty(original_batch_dims_broadcasted + (M, N), dtype=torch.bfloat16) - return bsr_dense_addmm( - out, - weight_bsr, - A, - alpha=1, - beta=0, - out=out, - left_alpha=left_alpha, - right_alpha=right_alpha, - ).t() - - -@torch.library.register_fake("blocksparse::int_addmm") -def blocksparse_int_addmm_abstract( - crow_indices: torch.Tensor, - col_indices: torch.Tensor, - values: torch.Tensor, - A: torch.Tensor, - left_alpha: torch.Tensor, - right_alpha: torch.Tensor, -) -> torch.Tensor: - N = A.shape[-1] - M = left_alpha.shape[-1] - # to have the same strides as the transposed result - return torch.empty((M, N), dtype=torch.bfloat16, device=A.device).t() - - -# bsr wrapper custom op -@torch.library.custom_op("blocksparse::linear", mutates_args=()) -def blocksparse_linear( - A: torch.Tensor, - crow_indices: torch.Tensor, - col_indices: torch.Tensor, - values: torch.Tensor, - M: int, - K: int, - bias: torch.Tensor, -) -> torch.Tensor: - weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K)) - return torch.nn.functional.linear(A, weight_bsr, bias) - - 
-@torch.library.register_fake("blocksparse::linear") -def blocksparse_linear_abstract( - A: torch.Tensor, - crow_indices: torch.Tensor, - col_indices: torch.Tensor, - values: torch.Tensor, - M: int, - K: int, - bias: torch.Tensor, -) -> torch.Tensor: - new_shape = A.shape[:-1] + (M,) - return torch.empty(new_shape, dtype=A.dtype, device=A.device) - - -# Subclass definition -class BlockSparseTensor(TorchAOBaseTensor): - bsr_crow_indices: Optional[torch.Tensor] - bsr_col_indices: Optional[torch.Tensor] - bsr_values: Optional[torch.Tensor] - - __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values"] - - @staticmethod - def __new__( # noqa: PYI034 - cls, - shape: torch.Size, - bsr_crow_indices: Optional[torch.Tensor], - bsr_col_indices: Optional[torch.Tensor], - bsr_values: Optional[torch.Tensor], - requires_grad: bool = False, - ): - if bsr_values is None: - raise ValueError( - "No values passed to BlockSparseTensor: bsr_values must be provided!" - ) - else: - previous_tensor = bsr_values - - kwargs = { - "device": previous_tensor.device, - "dtype": previous_tensor.dtype, - "layout": previous_tensor.layout, - "requires_grad": requires_grad, - } - tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - tensor.bsr_crow_indices = bsr_crow_indices - tensor.bsr_col_indices = bsr_col_indices - tensor.bsr_values = bsr_values - return tensor - - def __repr__(self) -> str: # type: ignore[override] - assert hasattr(self, "shape") - return f"{self.__class__.__name__}(shape={self.shape})" - - def __tensor_flatten__(self) -> Tuple[List[str], Tuple[torch.Size, bool]]: - inner_tensors = list( - filter(lambda x: getattr(self, x) is not None, self.__slots__) - ) - tensor_meta = (self.shape, self.requires_grad) - return inner_tensors, tensor_meta - - @classmethod - def __tensor_unflatten__( - cls, - inner_tensors, - tensor_meta: Tuple[torch.Size, bool], - outer_size, - outer_stride, - ) -> torch.Tensor: - shape, requires_grad = 
tensor_meta - return cls( - shape=shape, - bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), - bsr_col_indices=inner_tensors.get("bsr_col_indices", None), - bsr_values=inner_tensors.get("bsr_values", None), - requires_grad=requires_grad, - ) - - @classmethod - def from_dense(cls, dense_tensor, blocksize): - bsr_tensor = dense_tensor.to_sparse_bsr(blocksize) - return cls( - shape=dense_tensor.shape, - bsr_crow_indices=bsr_tensor.crow_indices(), - bsr_col_indices=bsr_tensor.col_indices(), - bsr_values=bsr_tensor.values(), - requires_grad=False, - ) - - def apply_fn_to_shard(self, func): - return BlockSparseTensor( - shape=self.shape, - bsr_crow_indices=func(self.bsr_crow_indices), - bsr_col_indices=func(self.bsr_col_indices), - bsr_values=func(self.bsr_values), - requires_grad=self.requires_grad, - ) - - -# Subclass op dispatch registration -implements = BlockSparseTensor.implements - - -@implements(aten.detach.default) -def block_sparse_detach(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0].apply_fn_to_shard(torch.detach) - ) - - -@implements(aten.values.default) -def block_sparse_values(func, types, args, kwargs): - return args[0].bsr_values.detach() - - -@implements(aten.crow_indices.default) -def block_sparse_crow_indices(func, types, args, kwargs): - return args[0].bsr_crow_indices.detach() - - -@implements(aten.col_indices.default) -def block_sparse_col_indices(func, types, args, kwargs): - return args[0].bsr_col_indices.detach() - - -@implements(aten._nnz.default) -def block_sparse__nnz(func, types, args, kwargs): - return args[0].bsr_values.shape[0] - - -@implements(torch.nn.functional.linear) -def block_sparse_linear(func, types, args, kwargs): - x, w, bias = args - return torch.ops.blocksparse.linear( - x, w.crow_indices(), w.col_indices(), w.values(), w.shape[0], w.shape[1], bias - ) - - -def block_sparse_weight(blocksize=64): - return _get_linear_subclass_inserter( - 
partial(BlockSparseTensor.from_dense, blocksize=blocksize) - ) +from torchao.prototype.sparsity.superblock.blocksparse import BlockSparseTensor diff --git a/torchao/sparsity/prototype/superblock/supermask.py b/torchao/sparsity/prototype/superblock/supermask.py index 0b28763445..75d5b17651 100644 --- a/torchao/sparsity/prototype/superblock/supermask.py +++ b/torchao/sparsity/prototype/superblock/supermask.py @@ -1,275 +1,5 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. - -import torch.nn as nn -import math -import torch -from torch.autograd import Variable -import torch.nn.functional as F -import numpy as np - -# original supermask -scores_min=None -scores_max=9e9 -uniform_init_01 = False - -# adjusted supermask, initialize scores with uniform distribution in [0,1], clamp scores in each step in [0,1] -# scores_min=0. -# scores_max=1. -# uniform_init_01 = True - -def percentile(t, q): - """Return the value that is larger than q% of t""" - k = 1 + round(.01 * float(q) * (t.numel() - 1)) - return t.view(-1).kthvalue(k).values - - -class GetSubnet(torch.autograd.Function): - """Supermask STE function""" - @staticmethod - def forward(ctx, scores, zeros, ones, sparsity): - clamped_scores = scores.clamp(min=scores_min,max=scores_max) - k_val = percentile(clamped_scores, sparsity*100) - return torch.where(clamped_scores < k_val, zeros.to(scores.device), ones.to(scores.device)) - @staticmethod - def backward(ctx, g): - return g, None, None, None - - -class SupermaskLinear(nn.Linear): - """Supermask class for Linear layer""" - def __init__(self, sparsity, fixed_mask, fixed_weight, bitwidth, transform, fixed_transform, *args, **kwargs): - tile_size = kwargs.pop("tile_size", 1) - super(SupermaskLinear, self).__init__(*args, **kwargs) - # initialize the scores - max_sparsity = 1 - (1 / math.prod([math.ceil(k / tile_size) for k in self.weight.size()])) - self.sparsity = sparsity - if self.sparsity > max_sparsity: - print( - f"reducing sparsity from {self.sparsity} to 
{max_sparsity}", - f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {tile_size})" - ) - self.sparsity = max_sparsity - self.tile_size = tile_size - self.sparsify_weights = False - self.scores = nn.Parameter( - torch.empty( - [max(1, int(math.ceil(wn / tile_size))) for wn in self.weight.size()] - ), - requires_grad=not fixed_mask, - ) - nn.init.uniform_(self.scores) if uniform_init_01 else nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5)) - - # the shift and the scale are transformation parameters - # the actually used weights = self.weight*self.scale+self.shift - # the transformation is activated only for quantized weights - self.shift=nn.Parameter(torch.Tensor(1).fill_(0.), requires_grad=False) - self.scale=nn.Parameter(torch.Tensor(1).fill_(1.), requires_grad=False) - - with torch.no_grad(): - # if bitwidth is None, then use floating point values in self.weight - # if bitwidth is not None, then quantize self.weight into k-bit (k=bitwidth) - # quantized values are -2^(k-1), -2^(k-1)+1, ..., 0, 1, ..., 2^(k-1)-1 - # these quantized values are uniformly distributed - if bitwidth is not None: - weights_max = torch.max(self.weight).item() - weights_min = torch.min(self.weight).item() - least_step = (weights_max-weights_min)/pow(2,bitwidth) - left_bound = weights_min-1e-6 - right_bound = weights_min+least_step+1e-6 - # self.shift=nn.Parameter(torch.Tensor(1).fill_( (weights_min+(pow(2,bitwidth-1)+0.5)*least_step) if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) - # self.scale=nn.Parameter(torch.Tensor(1).fill_( least_step if transform[1] is None else transform[1] ), requires_grad=not fixed_transform[1]) - # for example, if using binary weights (k=1) with -a, +a, set transform = [a,2a]; if using binary weights (k=1) with a, 0, set transform = [0,-a]; - self.shift=nn.Parameter(torch.Tensor(1).fill_( 0. 
if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) - self.scale=nn.Parameter(torch.Tensor(1).fill_( 1. if transform[1] is None else transform[1] ), requires_grad=not fixed_transform[1]) - for i in range(-int(pow(2,bitwidth-1)),int(pow(2,bitwidth-1))): - self.weight[torch.logical_and(self.weight>left_bound, self.weight<=right_bound)] = i - left_bound = right_bound - right_bound += least_step - - self.weight.requires_grad = not fixed_weight - - def get_mask(self): - subnet = GetSubnet.apply(self.scores, - torch.zeros_like(self.scores), - torch.ones_like(self.scores), - self.sparsity) - - if self.tile_size != 1: - for i, k in enumerate(self.weight.shape): - subnet = subnet.repeat_interleave(self.tile_size, dim=i) - subnet = torch.narrow(subnet, i, 0, k) - - return subnet - - def sparsify_offline(self): - subnet = self.get_mask() - self.weight.data = (self.weight*self.scale+self.shift) * subnet - self.sparsify_weights = True - - def forward(self, x): - if not self.sparsify_weights: - subnet = self.get_mask() - w = (self.weight*self.scale+self.shift) * subnet - else: - w = self.weight - return F.linear(x, w, self.bias) - - -class SupermaskConv2d(nn.Conv2d): - """Supermask class for Conv2d layer""" - def __init__(self, sparsity, fixed_mask, fixed_weight, bitwidth, transform, fixed_transform, *args, **kwargs): - tile_size = kwargs.pop("tile_size", 1) - super(SupermaskConv2d, self).__init__(*args, **kwargs) - # initialize the scores - max_sparsity = 1 - (1 / math.prod([math.ceil(k / tile_size) for k in self.weight.size()])) - self.sparsity = sparsity - if self.sparsity > max_sparsity: - print( - f"reducing sparsity from {self.sparsity} to {max_sparsity}", - f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {tile_size})" - ) - self.sparsity = max_sparsity - self.tile_size = tile_size - self.scores = nn.Parameter( - torch.empty( - [max(1, int(math.ceil(wn / tile_size))) for wn in self.weight.size()] - ), - 
requires_grad=not fixed_mask, - ) - nn.init.uniform_(self.scores) if uniform_init_01 else nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5)) - - # the shift and the scale are transformation parameters - # the actually used weights = self.weight*self.scale+self.shift - # the transformation is activated only for quantized weights - self.shift=nn.Parameter(torch.Tensor(1).fill_(0.), requires_grad=False) - self.scale=nn.Parameter(torch.Tensor(1).fill_(1.), requires_grad=False) - - with torch.no_grad(): - # if bitwidth is None, then use floating point values in self.weight - # if bitwidth is not None, then quantize self.weight into k-bit (k=bitwidth) - # quantized values are -2^(k-1), -2^(k-1)+1, ..., 0, 1, ..., 2^(k-1)-1 - # these quantized values are uniformly distributed - if bitwidth is not None: - weights_max = torch.max(self.weight).item() - weights_min = torch.min(self.weight).item() - least_step = (weights_max-weights_min)/pow(2,bitwidth) - left_bound = weights_min-1e-6 - right_bound = weights_min+least_step+1e-6 - # self.shift=nn.Parameter(torch.Tensor(1).fill_( (weights_min+(pow(2,bitwidth-1)+0.5)*least_step) if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) - # self.scale=nn.Parameter(torch.Tensor(1).fill_( least_step if transform[1] is None else transform[1]), requires_grad=not fixed_transform[1]) - # for example, if using binary weights (k=1) with -a, +a, set transform = [a,2a]; if using binary weights (k=1) with a, 0, set transform = [0,-a]; - self.shift=nn.Parameter(torch.Tensor(1).fill_( 0. if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) - self.scale=nn.Parameter(torch.Tensor(1).fill_( 1. 
if transform[1] is None else transform[1] ), requires_grad=not fixed_transform[1]) - for i in range(-int(pow(2,bitwidth-1)),int(pow(2,bitwidth-1))): - self.weight[torch.logical_and(self.weight>left_bound, self.weight<=right_bound)] = i - left_bound = right_bound - right_bound += least_step - - self.weight.requires_grad = not fixed_weight - - def forward(self, x): - subnet = GetSubnet.apply(self.scores, - torch.zeros_like(self.scores), - torch.ones_like(self.scores), - self.sparsity) - - if self.tile_size != 1: - for i, k in enumerate(self.weight.shape): - # if k == 1: continue - subnet = subnet.repeat_interleave(self.tile_size, dim=i) - subnet = torch.narrow(subnet, i, 0, k) - - w = (self.weight*self.scale+self.shift) * subnet - return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups) - -def apply_supermask( - model, - linear_sparsity=0.0, - linear_sp_tilesize=1, - conv1x1_sparsity=0.0, - conv1x1_sp_tilesize=1, - conv_sparsity=0.0, - conv_sp_tilesize=1, - skip_last_layer_sparsity=False, - skip_first_transformer_sparsity=False, - device="cuda", - verbose=False, -): - sparsified_modules = {} - - for n, m in model.named_modules(): - # check conditions for skipping sparsity - if skip_last_layer_sparsity and n == "heads.head": - continue - if skip_first_transformer_sparsity and "encoder.layers.encoder_layer_0" in n: - continue - - # convert 1x1 convolutions - if conv1x1_sparsity != 0.0 and isinstance(m, torch.nn.Conv2d) and m.kernel_size == (1, 1): - new_m = SupermaskConv2d( - conv1x1_sparsity, False, False, None, None, None, - m.in_channels, - m.out_channels, - m.kernel_size, - stride=m.stride, - padding=m.padding, - dilation=m.dilation, - groups=m.groups, - bias=m.bias is not None, - padding_mode=m.padding_mode, - device=device, - tile_size=conv1x1_sp_tilesize, - ) - new_m.weight.data.copy_(m.weight.data) - if m.bias is not None: - new_m.bias.data.copy_(m.bias.data) - sparsified_modules[n] = new_m - continue - - # convert all other 
convolutions (not tested!) - if conv_sparsity != 0.0 and isinstance(m, torch.nn.Conv2d): - new_m = SupermaskConv2d( - conv_sparsity, False, False, None, None, None, - m.in_channels, - m.out_channels, - m.kernel_size, - stride=m.stride, - padding=m.padding, - dilation=m.dilation, - groups=m.groups, - bias=m.bias is not None, - padding_mode=m.padding_mode, - device=device, - tile_size=conv_sp_tilesize, - ) - new_m.weight.data.copy_(m.weight.data) - if m.bias is not None: - new_m.bias.data.copy_(m.bias.data) - sparsified_modules[n] = new_m - continue - - if linear_sparsity != 0.0 and isinstance(m, torch.nn.Linear): - new_m = SupermaskLinear( - linear_sparsity, False, False, None, None, None, - m.in_features, - m.out_features, - bias=m.bias is not None, - device=device, - tile_size=linear_sp_tilesize, - ) - new_m.weight.data.copy_(m.weight.data) - if m.bias is not None: - new_m.bias.data.copy_(m.bias.data) - sparsified_modules[n] = new_m - continue - - # add modules to model - for k, v in sparsified_modules.items(): - sm_name, ch_name = k.rsplit(".", 1) - sm = model.get_submodule(sm_name) - sm.add_module(ch_name, v) - - if verbose: - print(f'sparsified module "{k}" with sparsity={v.sparsity}, tile size={v.tile_size}') - - return model +from torchao.prototype.sparsity.superblock.supermask import ( + GetSubnet, + SupermaskConv2d, + SupermaskLinear, +) diff --git a/torchao/sparsity/prototype/superblock/utils.py b/torchao/sparsity/prototype/superblock/utils.py index e0cf4a1777..e409abffc0 100644 --- a/torchao/sparsity/prototype/superblock/utils.py +++ b/torchao/sparsity/prototype/superblock/utils.py @@ -1,1056 +1,10 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
- -import argparse -import copy -import datetime -import errno -import hashlib -import math -import os -import time -from collections import defaultdict, deque, OrderedDict -from typing import List, Optional, Tuple - -import torch - -from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_ -from torchao.sparsity import semi_sparse_weight, sparsify_ -from torchao.sparsity.prototype.sparsifier.weight_norm_sparsifier import ( - WeightNormSparsifier, +from torchao.prototype.sparsity.superblock.utils import ( + ClassificationPresetEval, + ClassificationPresetTrain, + ExponentialMovingAverage, + MetricLogger, + RandomCutmix, + RandomMixup, + RASampler, + SmoothedValue, ) -from torchao.sparsity.prototype.superblock.blocksparse import block_sparse_weight -from torchao.sparsity.prototype.superblock.supermask import ( - apply_supermask, - SupermaskLinear, -) -from torchvision.transforms import autoaugment, functional as F, transforms -from torchvision.transforms.functional import InterpolationMode - -def get_args_parser(train=False, evaluate=False, benchmark=False): - assert sum([train, evaluate, benchmark]) == 1, "One and only one of training, evaluation, or benchmark can be true" - - # Shared common args - parser = argparse.ArgumentParser(description="SuperBlock Imagenet Training/Evaluation/Benchmarking Script", add_help=True) - parser.add_argument("--data-path", type=str, help="IMAGENET dataset path") - parser.add_argument("--model", default="vit_b_16", choices=["vit_b_16", "vit_h_14"], type=str, help="ViT base model") - parser.add_argument("--device", default="cuda", type=str, help="device (Default: cuda)") - parser.add_argument("-b", "--batch-size", default=32, type=int, help="per device batch size") - parser.add_argument("--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)") - parser.add_argument("--sparsity", choices=["bsr", "semi_structured"], default=None, help='weight sparsification to 
apply') - parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)') - parser.add_argument("--sparsity-linear", type=float, default=0.0) - parser.add_argument("--sparsity-conv1x1", type=float, default=0.0) - parser.add_argument("--sparsity-conv", type=float, default=0.0) - parser.add_argument("--skip-last-layer-sparsity", action="store_true", help="Skip applying sparsity to the last linear layer (for vit only)") - parser.add_argument("--skip-first-transformer-sparsity", action="store_true", help="Skip applying sparsity to the first transformer layer (for vit only)") - parser.add_argument("--quantization", action="store_true", help="Run with int8 dynamic quantization") - parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") - parser.add_argument("--weights-path", type=str, help="optional checkpoint to load weights after intialization") - parser.add_argument("--header", action="store_true", help="Print header for first run") - - # Eval a subset of training args - # lots of training args - if train or evaluate: - parser.add_argument("-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers") - parser.add_argument("--accumulation-steps", default=1, type=int, help="Number of steps to accumulate gradients over") - parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run") - parser.add_argument("--opt", default="sgd", type=str, help="optimizer") - parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate") - parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") - parser.add_argument("--wd", "--weight-decay", default=1e-4, type=float, metavar="W", help="weight decay", dest="weight_decay") - parser.add_argument("--norm-weight-decay", default=None, type=float, help="weight decay for Normalization 
layers (default: None, same value as --wd)") - parser.add_argument("--bias-weight-decay", default=None, type=float, help="weight decay for bias parameters of all layers (default: None, same value as --wd)") - parser.add_argument("--transformer-embedding-decay", default=None, type=float, help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)") - parser.add_argument("--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing") - parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)") - parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)") - parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)") - parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)") - parser.add_argument("--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)") - parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr") - parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs") - parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") - parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)") - parser.add_argument("--print-freq", default=10, type=int, help="print frequency") - parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs") - parser.add_argument('--resume', action='store_true', help='Resumes training from latest available checkpoint ("model_.pth")') - parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch") - parser.add_argument("--cache-dataset", dest="cache_dataset", help="Cache the datasets for quicker 
initialization. It also serializes the transforms", action="store_true") - parser.add_argument("--sync-bn", dest="sync_bn", help="Use sync batch norm", action="store_true") - parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)") - parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy") - parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy") - parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)") - # Mixed precision training parameters - parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") - # distributed training parameters - parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") - parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") - parser.add_argument("--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters") - parser.add_argument("--model-ema-steps", type=int, default=32, help="the number of iterations that controls how often to update the EMA model (default: 32)") - parser.add_argument("--model-ema-decay", type=float, default=0.99998, help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)") - parser.add_argument("--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only.") - parser.add_argument("--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)") - parser.add_argument("--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)") - parser.add_argument("--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)") - parser.add_argument("--clip-grad-norm", 
default=None, type=float, help="the maximum gradient norm (default None)") - parser.add_argument("--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)") - parser.add_argument('--meta', action='store_true', help='Use Meta internal imagenet structure') - - if benchmark: - parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], help="Data type", default="bfloat16") - parser.add_argument("--tune-kernel-params", action="store_true", help="Tune kernel params for BSR") - parser.add_argument("--profile", action="store_true", help="Dump Prefetto trace") - - return parser - - - -# filter functions -def mlp_0_only(mod, name): - return isinstance(mod, torch.nn.Linear) and "mlp.0" in name - - -def mlp_3_only(mod, name): - return isinstance(mod, torch.nn.Linear) and "mlp.3" in name - - -def mlp_only(mod, name): - return isinstance(mod, torch.nn.Linear) and "mlp" in name - - -def superblock_only(mod, name): - return isinstance(mod, SupermaskLinear) and "mlp" in name - - -def mlp_only_with_args( - mod, name, skip_last_layer_sparsity=False, skip_first_transformer_sparsity=False -): - if skip_last_layer_sparsity and "heads.head" in name: - return False - if skip_first_transformer_sparsity and "encoder.layers.encoder_layer_0" in name: - return False - if isinstance(mod, torch.nn.Linear) and "mlp" in name: - return True - return False - - -### Custom sparsification utils -def apply_sparsity(model): - for name, module in model.named_modules(): - if isinstance(module, SupermaskLinear) and "mlp" in name: - module.sparsify_offline() - - -def accelerate_with_sparsity(model, args): - if args.sparsity == "bsr": - apply_sparsity(model) - if args.quantization: - from torchao.dtypes.affine_quantized_tensor import BlockSparseLayout - - quantize_( - model, - int8_dynamic_activation_int8_weight( - _layout=BlockSparseLayout(blocksize=args.bsr) - ), - superblock_only, - ) - else: - assert args.bsr is not None, "BSR requires a block 
size" - sparsify_(model, block_sparse_weight(blocksize=args.bsr), superblock_only) - elif args.sparsity == "semi_structured": - if args.quantization: - from torchao.dtypes.affine_quantized_tensor import SemiSparseLayout - - quantize_( - model, - int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), - mlp_0_only, - ) - sparsify_(model, semi_sparse_weight(), mlp_3_only) - else: - sparsify_(model, semi_sparse_weight(), mlp_only) - else: - if args.quantization: - quantize_(model, int8_dynamic_activation_int8_weight(), mlp_only) - - -def simulate_sparsity(model, args): - if args.sparsity == "bsr": - apply_supermask( - model, - linear_sparsity=args.sparsity_linear, - linear_sp_tilesize=args.bsr, - conv1x1_sparsity=args.sparsity_conv1x1, - conv1x1_sp_tilesize=args.bsr, - conv_sparsity=args.sparsity_conv, - conv_sp_tilesize=args.bsr, - skip_last_layer_sparsity=args.skip_last_layer_sparsity, - skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, - device=args.device, - verbose=False, - ) - elif args.sparsity == "semi_structured": - sparse_config = [] - for name, mod in model.named_modules(): - if mlp_only_with_args( - mod, - name, - skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, - skip_last_layer_sparsity=args.skip_last_layer_sparsity, - ): - sparse_config.append({"tensor_fqn": f"{name}.weight"}) - - sparsifier = WeightNormSparsifier( - sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2 - ) - sparsifier.prepare(model, sparse_config) - sparsifier.step() - return sparsifier - - -# ------------------------------------------------------------ -# The following code contains torchvision reference code, -# largely copied from: https://github.com/pytorch/vision/tree/main/references/classification -# Please open issues in the original repository if you have questions. - - -class SmoothedValue: - """Track a series of values and provide access to smoothed values over a - window or the global series average. 
- """ - - def __init__(self, window_size=20, fmt=None): - if fmt is None: - fmt = "{median:.4f} ({global_avg:.4f})" - self.deque = deque(maxlen=window_size) - self.total = 0.0 - self.count = 0 - self.fmt = fmt - - def update(self, value, n=1): - self.deque.append(value) - self.count += n - self.total += value * n - - def synchronize_between_processes(self): - """ - Warning: does not synchronize the deque! - """ - t = reduce_across_processes([self.count, self.total]) - t = t.tolist() - self.count = int(t[0]) - self.total = t[1] - - @property - def median(self): - d = torch.tensor(list(self.deque)) - return d.median().item() - - @property - def avg(self): - d = torch.tensor(list(self.deque), dtype=torch.float32) - return d.mean().item() - - @property - def global_avg(self): - return self.total / self.count - - @property - def max(self): - return max(self.deque) - - @property - def value(self): - return self.deque[-1] - - def __str__(self): - return self.fmt.format( - median=self.median, - avg=self.avg, - global_avg=self.global_avg, - max=self.max, - value=self.value, - ) - - -class MetricLogger: - def __init__(self, delimiter="\t"): - self.meters = defaultdict(SmoothedValue) - self.delimiter = delimiter - - def update(self, **kwargs): - for k, v in kwargs.items(): - if isinstance(v, torch.Tensor): - v = v.item() - assert isinstance(v, (float, int)) - self.meters[k].update(v) - - def __getattr__(self, attr): - if attr in self.meters: - return self.meters[attr] - if attr in self.__dict__: - return self.__dict__[attr] - raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{attr}'" - ) - - def __str__(self): - loss_str = [] - for name, meter in self.meters.items(): - loss_str.append(f"{name}: {str(meter)}") - return self.delimiter.join(loss_str) - - def synchronize_between_processes(self): - for meter in self.meters.values(): - meter.synchronize_between_processes() - - def add_meter(self, name, meter): - self.meters[name] = meter - - def 
log_every(self, iterable, print_freq, header=None): - i = 0 - if not header: - header = "" - start_time = time.time() - end = time.time() - iter_time = SmoothedValue(fmt="{avg:.4f}") - data_time = SmoothedValue(fmt="{avg:.4f}") - space_fmt = ":" + str(len(str(len(iterable)))) + "d" - if torch.cuda.is_available(): - log_msg = self.delimiter.join( - [ - header, - "[{0" + space_fmt + "}/{1}]", - "eta: {eta}", - "{meters}", - "time: {time}", - "data: {data}", - "max mem: {memory:.0f}", - ] - ) - else: - log_msg = self.delimiter.join( - [ - header, - "[{0" + space_fmt + "}/{1}]", - "eta: {eta}", - "{meters}", - "time: {time}", - "data: {data}", - ] - ) - MB = 1024.0 * 1024.0 - for obj in iterable: - data_time.update(time.time() - end) - yield obj - iter_time.update(time.time() - end) - if i % print_freq == 0: - eta_seconds = iter_time.global_avg * (len(iterable) - i) - eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) - if torch.cuda.is_available(): - print( - log_msg.format( - i, - len(iterable), - eta=eta_string, - meters=str(self), - time=str(iter_time), - data=str(data_time), - memory=torch.cuda.max_memory_allocated() / MB, - ) - ) - else: - print( - log_msg.format( - i, - len(iterable), - eta=eta_string, - meters=str(self), - time=str(iter_time), - data=str(data_time), - ) - ) - i += 1 - end = time.time() - total_time = time.time() - start_time - total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print(f"{header} Total time: {total_time_str}") - - -class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel): - """Maintains moving averages of model parameters using an exponential decay. - ``ema_avg = decay * avg_model_param + (1 - decay) * model_param`` - `torch.optim.swa_utils.AveragedModel `_ - is used to compute the EMA. 
- """ - - def __init__(self, model, decay, device="cpu"): - def ema_avg(avg_model_param, model_param, num_averaged): - return decay * avg_model_param + (1 - decay) * model_param - - super().__init__(model, device, ema_avg, use_buffers=True) - - -def accuracy(output, target, topk=(1,)): - """Computes the accuracy over the k top predictions for the specified values of k""" - with torch.inference_mode(): - maxk = max(topk) - batch_size = target.size(0) - if target.ndim == 2: - target = target.max(dim=1)[1] - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - correct = pred.eq(target[None]) - - res = [] - for k in topk: - correct_k = correct[:k].flatten().sum(dtype=torch.float32) - res.append(correct_k * (100.0 / batch_size)) - return res - - -def mkdir(path): - try: - os.makedirs(path) - except OSError as e: - if e.errno != errno.EEXIST: - raise - - -def setup_for_distributed(is_master): - """ - This function disables printing when not in master process - """ - import builtins as __builtin__ - - builtin_print = __builtin__.print - - def print(*args, **kwargs): - force = kwargs.pop("force", False) - if is_master or force: - builtin_print(*args, **kwargs) - - __builtin__.print = print - - -def is_dist_avail_and_initialized(): - if not torch.distributed.is_available(): - return False - if not torch.distributed.is_initialized(): - return False - return True - - -def get_world_size(): - if not is_dist_avail_and_initialized(): - return 1 - return torch.distributed.get_world_size() - - -def get_rank(): - if not is_dist_avail_and_initialized(): - return 0 - return torch.distributed.get_rank() - - -def is_main_process(): - return get_rank() == 0 - - -def save_on_master(*args, **kwargs): - if is_main_process(): - torch.save(*args, **kwargs) - - -def init_distributed_mode(args): - if "RANK" in os.environ and "WORLD_SIZE" in os.environ: - args.rank = int(os.environ["RANK"]) - args.world_size = int(os.environ["WORLD_SIZE"]) - args.gpu = int(os.environ["LOCAL_RANK"]) 
- elif "SLURM_PROCID" in os.environ: - args.rank = int(os.environ["SLURM_PROCID"]) - args.gpu = args.rank % torch.cuda.device_count() - elif hasattr(args, "rank"): - pass - else: - print("Not using distributed mode") - args.distributed = False - return - - args.distributed = True - - torch.cuda.set_device(args.gpu) - args.dist_backend = "nccl" - print(f"| distributed init (rank {args.rank})", flush=True) - torch.distributed.init_process_group( - backend=args.dist_backend, - init_method=args.dist_url, - world_size=args.world_size, - rank=args.rank, - ) - torch.distributed.barrier() - setup_for_distributed(args.rank == 0) - - -def average_checkpoints(inputs): - """Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from: - https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16 - - Args: - inputs (List[str]): An iterable of string paths of checkpoints to load from. - Returns: - A dict of string keys mapping to various values. The 'model' key - from the returned dict should correspond to an OrderedDict mapping - string parameter names to torch Tensors. 
- """ - params_dict = OrderedDict() - params_keys = None - new_state = None - num_models = len(inputs) - for fpath in inputs: - with open(fpath, "rb") as f: - state = torch.load( - f, - map_location=( - lambda s, _: torch.serialization.default_restore_location(s, "cpu") - ), - ) - # Copies over the settings from the first checkpoint - if new_state is None: - new_state = state - model_params = state["model"] - model_params_keys = list(model_params.keys()) - if params_keys is None: - params_keys = model_params_keys - elif params_keys != model_params_keys: - raise KeyError( - f"For checkpoint {f}, expected list of params: {params_keys}, but found: {model_params_keys}" - ) - for k in params_keys: - p = model_params[k] - if isinstance(p, torch.HalfTensor): - p = p.float() - if k not in params_dict: - params_dict[k] = p.clone() - # NOTE: clone() is needed in case of p is a shared parameter - else: - params_dict[k] += p - averaged_params = OrderedDict() - for k, v in params_dict.items(): - averaged_params[k] = v - if averaged_params[k].is_floating_point(): - averaged_params[k].div_(num_models) - else: - averaged_params[k] //= num_models - new_state["model"] = averaged_params - return new_state - - -def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True): - """ - This method can be used to prepare weights files for new models. It receives as - input a model architecture and a checkpoint from the training script and produces - a file with the weights ready for release. 
- - Examples: - from torchvision import models as M - - # Classification - model = M.mobilenet_v3_large(weights=None) - print(store_model_weights(model, './class.pth')) - - # Quantized Classification - model = M.quantization.mobilenet_v3_large(weights=None, quantize=False) - model.fuse_model(is_qat=True) - model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack') - _ = torch.ao.quantization.prepare_qat(model, inplace=True) - print(store_model_weights(model, './qat.pth')) - - # Object Detection - model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None) - print(store_model_weights(model, './obj.pth')) - - # Segmentation - model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True) - print(store_model_weights(model, './segm.pth', strict=False)) - - Args: - model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes. - checkpoint_path (str): The path of the checkpoint we will load. - checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored. - Default: "model". - strict (bool): whether to strictly enforce that the keys - in :attr:`state_dict` match the keys returned by this module's - :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` - - Returns: - output_path (str): The location where the weights are saved. - """ - # Store the new model next to the checkpoint_path - checkpoint_path = os.path.abspath(checkpoint_path) - output_dir = os.path.dirname(checkpoint_path) - - # Deep copy to avoid side-effects on the model object. 
- model = copy.deepcopy(model) - checkpoint = torch.load(checkpoint_path, map_location="cpu") - - # Load the weights to the model to validate that everything works - # and remove unnecessary weights (such as auxiliaries, etc) - if checkpoint_key == "model_ema": - del checkpoint[checkpoint_key]["n_averaged"] - torch.nn.modules.utils.consume_prefix_in_state_dict_if_present( - checkpoint[checkpoint_key], "module." - ) - model.load_state_dict(checkpoint[checkpoint_key], strict=strict) - - tmp_path = os.path.join(output_dir, str(model.__hash__())) - torch.save(model.state_dict(), tmp_path) - - sha256_hash = hashlib.sha256() - with open(tmp_path, "rb") as f: - # Read and update hash string value in blocks of 4K - for byte_block in iter(lambda: f.read(4096), b""): - sha256_hash.update(byte_block) - hh = sha256_hash.hexdigest() - - output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth") - os.replace(tmp_path, output_path) - - return output_path - - -def reduce_across_processes(val): - if not is_dist_avail_and_initialized(): - # nothing to sync, but we still convert to tensor for consistency with the distributed case. 
- return torch.tensor(val) - - t = torch.tensor(val, device="cuda") - torch.distributed.barrier() - torch.distributed.all_reduce(t) - return t - - -def set_weight_decay( - model: torch.nn.Module, - weight_decay: float, - norm_weight_decay: Optional[float] = None, - norm_classes: Optional[List[type]] = None, - custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None, -): - if not norm_classes: - norm_classes = [ - torch.nn.modules.batchnorm._BatchNorm, - torch.nn.LayerNorm, - torch.nn.GroupNorm, - torch.nn.modules.instancenorm._InstanceNorm, - torch.nn.LocalResponseNorm, - ] - norm_classes = tuple(norm_classes) - - params = { - "other": [], - "norm": [], - } - params_weight_decay = { - "other": weight_decay, - "norm": norm_weight_decay, - } - custom_keys = [] - if custom_keys_weight_decay is not None: - for key, weight_decay in custom_keys_weight_decay: - params[key] = [] - params_weight_decay[key] = weight_decay - custom_keys.append(key) - - def _add_params(module, prefix=""): - for name, p in module.named_parameters(recurse=False): - if not p.requires_grad: - continue - is_custom_key = False - for key in custom_keys: - target_name = ( - f"{prefix}.{name}" if prefix != "" and "." 
in key else name - ) - if key == target_name: - params[key].append(p) - is_custom_key = True - break - if not is_custom_key: - if norm_weight_decay is not None and isinstance(module, norm_classes): - params["norm"].append(p) - else: - params["other"].append(p) - - for child_name, child_module in module.named_children(): - child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name - _add_params(child_module, prefix=child_prefix) - - _add_params(model) - - param_groups = [] - for key in params: - if len(params[key]) > 0: - param_groups.append( - {"params": params[key], "weight_decay": params_weight_decay[key]} - ) - return param_groups - - -# Presets for ImageNet training/eval taken from: https://github.com/pytorch/vision/blob/main/references/classification/presets.py - - -class ClassificationPresetTrain: - def __init__( - self, - *, - crop_size, - mean=(0.485, 0.456, 0.406), - std=(0.229, 0.224, 0.225), - interpolation=InterpolationMode.BILINEAR, - hflip_prob=0.5, - auto_augment_policy=None, - ra_magnitude=9, - augmix_severity=3, - random_erase_prob=0.0, - ): - trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] - if hflip_prob > 0: - trans.append(transforms.RandomHorizontalFlip(hflip_prob)) - if auto_augment_policy is not None: - if auto_augment_policy == "ra": - trans.append( - autoaugment.RandAugment( - interpolation=interpolation, magnitude=ra_magnitude - ) - ) - elif auto_augment_policy == "ta_wide": - trans.append( - autoaugment.TrivialAugmentWide(interpolation=interpolation) - ) - elif auto_augment_policy == "augmix": - trans.append( - autoaugment.AugMix( - interpolation=interpolation, severity=augmix_severity - ) - ) - else: - aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) - trans.append( - autoaugment.AutoAugment( - policy=aa_policy, interpolation=interpolation - ) - ) - trans.extend( - [ - transforms.PILToTensor(), - transforms.ConvertImageDtype(torch.float), - transforms.Normalize(mean=mean, 
std=std), - ] - ) - if random_erase_prob > 0: - trans.append(transforms.RandomErasing(p=random_erase_prob)) - - self.transforms = transforms.Compose(trans) - - def __call__(self, img): - return self.transforms(img) - - -class ClassificationPresetEval: - def __init__( - self, - *, - crop_size, - resize_size=256, - mean=(0.485, 0.456, 0.406), - std=(0.229, 0.224, 0.225), - interpolation=InterpolationMode.BILINEAR, - ): - - self.transforms = transforms.Compose( - [ - transforms.Resize(resize_size, interpolation=interpolation), - transforms.CenterCrop(crop_size), - transforms.PILToTensor(), - transforms.ConvertImageDtype(torch.float), - transforms.Normalize(mean=mean, std=std), - ] - ) - - def __call__(self, img): - return self.transforms(img) - - -# transforms taken from: https://github.com/pytorch/vision/blob/main/references/classification/transforms.py - - -class RandomMixup(torch.nn.Module): - """Randomly apply Mixup to the provided batch and targets. - The class implements the data augmentations as described in the paper - `"mixup: Beyond Empirical Risk Minimization" `_. - - Args: - num_classes (int): number of classes used for one-hot encoding. - p (float): probability of the batch being transformed. Default value is 0.5. - alpha (float): hyperparameter of the Beta distribution used for mixup. - Default value is 1.0. - inplace (bool): boolean to make this transform inplace. Default set to False. - """ - - def __init__( - self, - num_classes: int, - p: float = 0.5, - alpha: float = 1.0, - inplace: bool = False, - ) -> None: - super().__init__() - - if num_classes < 1: - raise ValueError( - f"Please provide a valid positive value for the num_classes. 
Got num_classes={num_classes}" - ) - - if alpha <= 0: - raise ValueError("Alpha param can't be zero.") - - self.num_classes = num_classes - self.p = p - self.alpha = alpha - self.inplace = inplace - - def forward( - self, batch: torch.Tensor, target: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - batch (Tensor): Float tensor of size (B, C, H, W) - target (Tensor): Integer tensor of size (B, ) - - Returns: - Tensor: Randomly transformed batch. - """ - if batch.ndim != 4: - raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") - if target.ndim != 1: - raise ValueError(f"Target ndim should be 1. Got {target.ndim}") - if not batch.is_floating_point(): - raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") - if target.dtype != torch.int64: - raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") - - if not self.inplace: - batch = batch.clone() - target = target.clone() - - if target.ndim == 1: - target = torch.nn.functional.one_hot( - target, num_classes=self.num_classes - ).to(dtype=batch.dtype) - - if torch.rand(1).item() >= self.p: - return batch, target - - # It's faster to roll the batch by one instead of shuffling it to create image pairs - batch_rolled = batch.roll(1, 0) - target_rolled = target.roll(1, 0) - - # Implemented as on mixup paper, page 3. - lambda_param = float( - torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0] - ) - batch_rolled.mul_(1.0 - lambda_param) - batch.mul_(lambda_param).add_(batch_rolled) - - target_rolled.mul_(1.0 - lambda_param) - target.mul_(lambda_param).add_(target_rolled) - - return batch, target - - def __repr__(self) -> str: - s = ( - f"{self.__class__.__name__}(" - f"num_classes={self.num_classes}" - f", p={self.p}" - f", alpha={self.alpha}" - f", inplace={self.inplace}" - f")" - ) - return s - - -class RandomCutmix(torch.nn.Module): - """Randomly apply Cutmix to the provided batch and targets. 
- The class implements the data augmentations as described in the paper - `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" - `_. - - Args: - num_classes (int): number of classes used for one-hot encoding. - p (float): probability of the batch being transformed. Default value is 0.5. - alpha (float): hyperparameter of the Beta distribution used for cutmix. - Default value is 1.0. - inplace (bool): boolean to make this transform inplace. Default set to False. - """ - - def __init__( - self, - num_classes: int, - p: float = 0.5, - alpha: float = 1.0, - inplace: bool = False, - ) -> None: - super().__init__() - if num_classes < 1: - raise ValueError( - "Please provide a valid positive value for the num_classes." - ) - if alpha <= 0: - raise ValueError("Alpha param can't be zero.") - - self.num_classes = num_classes - self.p = p - self.alpha = alpha - self.inplace = inplace - - def forward( - self, batch: torch.Tensor, target: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - batch (Tensor): Float tensor of size (B, C, H, W) - target (Tensor): Integer tensor of size (B, ) - - Returns: - Tensor: Randomly transformed batch. - """ - if batch.ndim != 4: - raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") - if target.ndim != 1: - raise ValueError(f"Target ndim should be 1. Got {target.ndim}") - if not batch.is_floating_point(): - raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") - if target.dtype != torch.int64: - raise TypeError(f"Target dtype should be torch.int64. 
Got {target.dtype}") - - if not self.inplace: - batch = batch.clone() - target = target.clone() - - if target.ndim == 1: - target = torch.nn.functional.one_hot( - target, num_classes=self.num_classes - ).to(dtype=batch.dtype) - - if torch.rand(1).item() >= self.p: - return batch, target - - # It's faster to roll the batch by one instead of shuffling it to create image pairs - batch_rolled = batch.roll(1, 0) - target_rolled = target.roll(1, 0) - - # Implemented as on cutmix paper, page 12 (with minor corrections on typos). - lambda_param = float( - torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0] - ) - _, H, W = F.get_dimensions(batch) - - r_x = torch.randint(W, (1,)) - r_y = torch.randint(H, (1,)) - - r = 0.5 * math.sqrt(1.0 - lambda_param) - r_w_half = int(r * W) - r_h_half = int(r * H) - - x1 = int(torch.clamp(r_x - r_w_half, min=0)) - y1 = int(torch.clamp(r_y - r_h_half, min=0)) - x2 = int(torch.clamp(r_x + r_w_half, max=W)) - y2 = int(torch.clamp(r_y + r_h_half, max=H)) - - batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] - lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) - - target_rolled.mul_(1.0 - lambda_param) - target.mul_(lambda_param).add_(target_rolled) - - return batch, target - - def __repr__(self) -> str: - s = ( - f"{self.__class__.__name__}(" - f"num_classes={self.num_classes}" - f", p={self.p}" - f", alpha={self.alpha}" - f", inplace={self.inplace}" - f")" - ) - return s - - -# RA Sampler implementaion taken from: https://github.com/pytorch/vision/blob/main/references/classification/sampler.py - - -class RASampler(torch.utils.data.Sampler): - """Sampler that restricts data loading to a subset of the dataset for distributed, - with repeated augmentation. - It ensures that different each augmented version of a sample will be visible to a - different process (GPU). - Heavily based on 'torch.utils.data.DistributedSampler'. 
- - This is borrowed from the DeiT Repo: - https://github.com/facebookresearch/deit/blob/main/samplers.py - """ - - def __init__( - self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3 - ): - if num_replicas is None: - if not torch.distributed.is_available(): - raise RuntimeError("Requires distributed package to be available!") - num_replicas = torch.distributed.get_world_size() - if rank is None: - if not torch.distributed.is_available(): - raise RuntimeError("Requires distributed package to be available!") - rank = torch.distributed.get_rank() - self.dataset = dataset - self.num_replicas = num_replicas - self.rank = rank - self.epoch = 0 - self.num_samples = int( - math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas) - ) - self.total_size = self.num_samples * self.num_replicas - self.num_selected_samples = int( - math.floor(len(self.dataset) // 256 * 256 / self.num_replicas) - ) - self.shuffle = shuffle - self.seed = seed - self.repetitions = repetitions - - def __iter__(self): - if self.shuffle: - # Deterministically shuffle based on epoch - g = torch.Generator() - g.manual_seed(self.seed + self.epoch) - indices = torch.randperm(len(self.dataset), generator=g).tolist() - else: - indices = list(range(len(self.dataset))) - - # Add extra samples to make it evenly divisible - indices = [ele for ele in indices for i in range(self.repetitions)] - indices += indices[: (self.total_size - len(indices))] - assert len(indices) == self.total_size - - # Subsample - indices = indices[self.rank : self.total_size : self.num_replicas] - assert len(indices) == self.num_samples - - return iter(indices[: self.num_selected_samples]) - - def __len__(self): - return self.num_selected_samples - - def set_epoch(self, epoch): - self.epoch = epoch