Skip to content

Move files to prototype/sparsity #1145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import logging
import torch
import unittest

import torch
from torch import nn
from torch.nn.utils import parametrize
from torch.testing._internal.common_utils import TestCase

from torchao.sparsity.prototype.sparsifier import utils
from torchao.prototype.sparsity.sparsifier import utils

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
Expand Down Expand Up @@ -168,5 +169,6 @@ def test_jit_trace(self):
y_hat = model_trace(x)
self.assertEqual(y_hat, y)


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import warnings
import unittest
import warnings

from torch import nn
from torch.testing._internal.common_utils import TestCase

from torchao.sparsity.prototype import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier
from torchao.prototype.sparsity import (
BaseScheduler,
CubicSL,
LambdaSL,
WeightNormSparsifier,
)


class ImplementedScheduler(BaseScheduler):
def get_sl(self):
Expand Down Expand Up @@ -190,5 +196,6 @@ def test_step(self):
msg="Sparsity level is not reaching the target level afer delta_t * n steps ",
)


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
)

from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_
from torchao.utils import TORCH_VERSION_AFTER_2_5, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_4
from torchao.utils import (
TORCH_VERSION_AFTER_2_5,
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
)


logging.basicConfig(
Expand Down Expand Up @@ -88,7 +93,7 @@ def test_quant_semi_sparse(self, compile):
def test_sparse_marlin(self, compile):
if not torch.backends.cusparselt.is_available():
self.skipTest("Need cuSPARSELt")

input = torch.rand((256, 256)).half().cuda()
model = (
nn.Sequential(
Expand Down Expand Up @@ -117,7 +122,10 @@ def test_sparse_marlin(self, compile):


class TestBlockSparseWeight(common_utils.TestCase):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature due to need for custom op support")
@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4,
"pytorch 2.4+ feature due to need for custom op support",
)
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize("compile", [True, False])
def test_sparse(self, compile):
Expand All @@ -140,7 +148,7 @@ def test_sparse(self, compile):
model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
dense_result = model(input)

from torchao.sparsity.prototype.superblock.blocksparse import (
from torchao.prototype.sparsity.superblock.blocksparse import (
block_sparse_weight,
)

Expand All @@ -167,7 +175,7 @@ def test_sparse(self, compile):
.cuda()
.eval()
)
from torchao.sparsity.prototype.superblock.blocksparse import (
from torchao.prototype.sparsity.superblock.blocksparse import (
blocksparse_int_addmm,
)
from torchao.sparsity.utils import create_block_sparse_tensor
Expand All @@ -189,9 +197,7 @@ def test_sparse(self, compile):

quantize_(
model,
int8_dynamic_activation_int8_weight(
layout=BlockSparseLayout(blocksize=64)
),
int8_dynamic_activation_int8_weight(layout=BlockSparseLayout(blocksize=64)),
)
if compile:
model = torch.compile(model)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,6 @@

import torch
from torch import nn
from torchao.sparsity.prototype import (
BaseSparsifier,
FakeSparsity,
NearlyDiagonalSparsifier,
WeightNormSparsifier,
)
from torch.nn.utils.parametrize import is_parametrized
from torch.testing._internal.common_pruning import (
ImplementedSparsifier,
Expand All @@ -21,6 +15,12 @@
)

from torch.testing._internal.common_utils import TestCase
from torchao.prototype.sparsity import (
BaseSparsifier,
FakeSparsity,
NearlyDiagonalSparsifier,
WeightNormSparsifier,
)

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
Expand Down Expand Up @@ -486,5 +486,6 @@ def _verify_nearliness(self, mask: torch.Tensor, nearliness: int):
else:
assert mask[row, col] == 0


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@
import unittest

import torch
from torchao.sparsity.prototype.sparsifier.utils import (
fqn_to_module,
get_arg_info_from_tensor_fqn,
module_to_fqn,
)

from torch.testing._internal.common_quantization import (
ConvBnReLUModel,
Expand All @@ -18,6 +13,11 @@
TwoLayerLinearModel,
)
from torch.testing._internal.common_utils import TestCase
from torchao.prototype.sparsity.sparsifier.utils import (
fqn_to_module,
get_arg_info_from_tensor_fqn,
module_to_fqn,
)

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,6 @@

import torch
from torch import nn
from torchao.sparsity.prototype.pruner import (
BaseStructuredSparsifier,
FakeStructuredSparsity,
FPGMPruner,
LSTMSaliencyPruner,
SaliencyPruner,
)
from torch.nn.utils import parametrize
from torch.testing._internal.common_pruning import (
Conv2dActivation,
Expand All @@ -32,6 +25,13 @@
)

from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase
from torchao.prototype.sparsity.pruner import (
BaseStructuredSparsifier,
FakeStructuredSparsity,
FPGMPruner,
LSTMSaliencyPruner,
SaliencyPruner,
)


logging.basicConfig(
Expand Down Expand Up @@ -1093,5 +1093,6 @@ def test_update_mask(self):
expected_conv1, expected_conv2, device
)


if __name__ == "__main__":
unittest.main()
20 changes: 20 additions & 0 deletions torchao/prototype/sparsity/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Schedulers and sparsifiers
from torchao.prototype.sparsity.scheduler.base_scheduler import BaseScheduler
from torchao.prototype.sparsity.scheduler.cubic_scheduler import CubicSL
from torchao.prototype.sparsity.scheduler.lambda_scheduler import LambdaSL
from torchao.prototype.sparsity.sparsifier.base_sparsifier import BaseSparsifier
from torchao.prototype.sparsity.sparsifier.nearly_diagonal_sparsifier import (
NearlyDiagonalSparsifier,
)

# Parametrizations
from torchao.prototype.sparsity.sparsifier.utils import (
FakeSparsity,
fqn_to_module,
get_arg_info_from_tensor_fqn,
module_to_fqn,
)
from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import (
WeightNormSparsifier,
)
93 changes: 93 additions & 0 deletions torchao/prototype/sparsity/pruner/FPGM_pruner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from typing import Callable, Optional, Union

import torch

from .base_structured_sparsifier import BaseStructuredSparsifier

__all__ = ["FPGMPruner"]


class FPGMPruner(BaseStructuredSparsifier):
    r"""Filter Pruning via Geometric Median (FPGM) Structured Pruner.

    This sparsifier prunes filters (rows) in a tensor according to the distances among filters, following
    `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>`_.

    This sparsifier is controlled by two variables:
    1. `sparsity_level` defines the fraction of filters (rows) that are zeroed-out.
    2. `dist` defines the distance measurement type. Default: 2 (L2 distance).
       Available options are: [1, 2, (custom callable distance function)].

    Note::
        Inputs should be a 4D convolutional tensor of shape (N, C, H, W).
        - N: output channels size
        - C: input channels size
        - H: height of kernel
        - W: width of kernel
    """

    def __init__(
        self, sparsity_level: float = 0.5, dist: Optional[Union[Callable, int]] = None
    ):
        r"""Select the distance function and forward the defaults to the base class.

        Args:
            sparsity_level (float): fraction of filters (rows) to prune.
            dist (callable or int, optional): ``1`` for L1 distance, ``2`` for
                L2 distance, or a callable that receives the flattened filters
                as a 2D tensor and returns a pairwise distance matrix.
                ``None`` selects ``2``.

        Raises:
            NotImplementedError: if `dist` is neither callable, ``1`` nor ``2``.
        """
        defaults = {
            "sparsity_level": sparsity_level,
        }

        if dist is None:
            dist = 2

        if callable(dist):
            self.dist_fn = dist
        elif dist == 1:
            self.dist_fn = lambda x: torch.cdist(x, x, p=1)
        elif dist == 2:
            self.dist_fn = lambda x: torch.cdist(x, x, p=2)
        else:
            raise NotImplementedError("Distance function is not yet implemented.")
        super().__init__(defaults=defaults)

    def _compute_distance(self, t):
        r"""Compute, for every filter (row along dim 0) of `t`, the summed
        distance to all filters.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
        Returns:
            distance (torch.Tensor): 1D tensor with one summed distance per filter
        """
        dim = 0  # prune filter (row)

        # flatten every filter into one row: (N, ...) -> (N, prod(...))
        t_flatten = t.reshape(t.size(dim), -1)

        # pairwise distance measurement between filters
        dist_matrix = self.dist_fn(t_flatten)

        # a filter that is similar to the other filters has a small row sum
        distance = torch.sum(torch.abs(dist_matrix), 1)

        return distance

    def update_mask(self, module, tensor_name, sparsity_level, **kwargs):
        r"""Update the mask of `module.<tensor_name>` for the given sparsity level.

        A level ``<= 0`` replaces the mask with all ones and a level ``>= 1.0``
        replaces it with all zeros. Otherwise the
        ``round(sparsity_level * num_filters)`` filters with the smallest summed
        distance are set to ``False`` in the existing mask; the other mask
        entries are left as they are.

        Args:
            module: module whose parametrized tensor is pruned
            tensor_name (str): name of the parametrized tensor on `module`
            sparsity_level (float): fraction of filters (rows) to prune
        """
        tensor_weight = getattr(module, tensor_name)
        mask = getattr(module.parametrizations, tensor_name)[0].mask

        if sparsity_level <= 0:
            mask.data = torch.ones_like(mask).bool()
        elif sparsity_level >= 1.0:
            mask.data = torch.zeros_like(mask).bool()
        else:
            distance = self._compute_distance(tensor_weight)

            tensor_size = tensor_weight.shape[0]  # prune filter (row)
            nparams_toprune = round(sparsity_level * tensor_size)
            nparams_toprune = min(
                max(nparams_toprune, 0), tensor_size
            )  # clamp to [0, tensor_size]
            # smallest summed distance == most replaceable filters
            topk = torch.topk(distance, k=nparams_toprune, largest=False)
            mask[topk.indices] = False
8 changes: 8 additions & 0 deletions torchao/prototype/sparsity/pruner/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .base_structured_sparsifier import BaseStructuredSparsifier
from .parametrization import (
FakeStructuredSparsity,
BiasHook,
)
from .saliency_pruner import SaliencyPruner
from .lstm_saliency_pruner import LSTMSaliencyPruner
from .FPGM_pruner import FPGMPruner
Loading
Loading