Commit 2d5676e

Move files to prototype/sparsity
1 parent f33cff7 commit 2d5676e


62 files changed: +3557 -3209 lines
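The move is a pure namespace flip: everything that previously lived under torchao.sparsity.prototype now lives under torchao.prototype.sparsity, as the renamed tests below show. A minimal before/after sketch, using names that appear in this commit's diffs:

# Before this commit
from torchao.sparsity.prototype import WeightNormSparsifier
from torchao.sparsity.prototype.pruner import FPGMPruner

# After this commit
from torchao.prototype.sparsity import WeightNormSparsifier
from torchao.prototype.sparsity.pruner import FPGMPruner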

test/sparsity/test_parametrization.py renamed to test/prototype/test_parametrization.py

Lines changed: 4 additions & 2 deletions

@@ -1,11 +1,12 @@
 import logging
-import torch
 import unittest
+
+import torch
 from torch import nn
 from torch.nn.utils import parametrize
 from torch.testing._internal.common_utils import TestCase

-from torchao.sparsity.prototype.sparsifier import utils
+from torchao.prototype.sparsity.sparsifier import utils

 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
@@ -168,5 +169,6 @@ def test_jit_trace(self):
         y_hat = model_trace(x)
         self.assertEqual(y_hat, y)

+
 if __name__ == "__main__":
     unittest.main()

test/sparsity/test_scheduler.py renamed to test/prototype/test_scheduler.py

Lines changed: 9 additions & 2 deletions

@@ -1,10 +1,16 @@
-import warnings
 import unittest
+import warnings

 from torch import nn
 from torch.testing._internal.common_utils import TestCase

-from torchao.sparsity.prototype import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier
+from torchao.prototype.sparsity import (
+    BaseScheduler,
+    CubicSL,
+    LambdaSL,
+    WeightNormSparsifier,
+)
+

 class ImplementedScheduler(BaseScheduler):
     def get_sl(self):
@@ -190,5 +196,6 @@ def test_step(self):
             msg="Sparsity level is not reaching the target level afer delta_t * n steps ",
         )

+
 if __name__ == "__main__":
     unittest.main()

test/sparsity/test_sparse_api.py renamed to test/prototype/test_sparse_api.py

Lines changed: 14 additions & 8 deletions

@@ -13,7 +13,12 @@
 )

 from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_
-from torchao.utils import TORCH_VERSION_AFTER_2_5, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_5,
+    TORCH_VERSION_AT_LEAST_2_3,
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
+)


 logging.basicConfig(
@@ -88,7 +93,7 @@ def test_quant_semi_sparse(self, compile):
     def test_sparse_marlin(self, compile):
         if not torch.backends.cusparselt.is_available():
             self.skipTest("Need cuSPARSELt")
-
+
         input = torch.rand((256, 256)).half().cuda()
         model = (
             nn.Sequential(
@@ -117,7 +122,10 @@ def test_sparse_marlin(self, compile):


 class TestBlockSparseWeight(common_utils.TestCase):
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature due to need for custom op support")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4,
+        "pytorch 2.4+ feature due to need for custom op support",
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize("compile", [True, False])
     def test_sparse(self, compile):
@@ -140,7 +148,7 @@ def test_sparse(self, compile):
         model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
         dense_result = model(input)

-        from torchao.sparsity.prototype.superblock.blocksparse import (
+        from torchao.prototype.sparsity.superblock.blocksparse import (
             block_sparse_weight,
         )

@@ -167,7 +175,7 @@ def test_sparse(self, compile):
             .cuda()
             .eval()
         )
-        from torchao.sparsity.prototype.superblock.blocksparse import (
+        from torchao.prototype.sparsity.superblock.blocksparse import (
             blocksparse_int_addmm,
         )
         from torchao.sparsity.utils import create_block_sparse_tensor
@@ -189,9 +197,7 @@ def test_sparse(self, compile):

         quantize_(
             model,
-            int8_dynamic_activation_int8_weight(
-                layout=BlockSparseLayout(blocksize=64)
-            ),
+            int8_dynamic_activation_int8_weight(layout=BlockSparseLayout(blocksize=64)),
         )
         if compile:
             model = torch.compile(model)

test/sparsity/test_sparsifier.py renamed to test/prototype/test_sparsifier.py

Lines changed: 7 additions & 6 deletions

@@ -7,12 +7,6 @@

 import torch
 from torch import nn
-from torchao.sparsity.prototype import (
-    BaseSparsifier,
-    FakeSparsity,
-    NearlyDiagonalSparsifier,
-    WeightNormSparsifier,
-)
 from torch.nn.utils.parametrize import is_parametrized
 from torch.testing._internal.common_pruning import (
     ImplementedSparsifier,
@@ -21,6 +15,12 @@
 )

 from torch.testing._internal.common_utils import TestCase
+from torchao.prototype.sparsity import (
+    BaseSparsifier,
+    FakeSparsity,
+    NearlyDiagonalSparsifier,
+    WeightNormSparsifier,
+)

 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
@@ -486,5 +486,6 @@ def _verify_nearliness(self, mask: torch.Tensor, nearliness: int):
                 else:
                     assert mask[row, col] == 0

+
 if __name__ == "__main__":
     unittest.main()

test/sparsity/test_sparsity_utils.py renamed to test/prototype/test_sparsity_utils.py

Lines changed: 5 additions & 5 deletions

@@ -2,11 +2,6 @@
 import unittest

 import torch
-from torchao.sparsity.prototype.sparsifier.utils import (
-    fqn_to_module,
-    get_arg_info_from_tensor_fqn,
-    module_to_fqn,
-)

 from torch.testing._internal.common_quantization import (
     ConvBnReLUModel,
@@ -18,6 +13,11 @@
     TwoLayerLinearModel,
 )
 from torch.testing._internal.common_utils import TestCase
+from torchao.prototype.sparsity.sparsifier.utils import (
+    fqn_to_module,
+    get_arg_info_from_tensor_fqn,
+    module_to_fqn,
+)

 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO

test/sparsity/test_structured_sparsifier.py renamed to test/prototype/test_structured_sparsifier.py

Lines changed: 8 additions & 7 deletions

@@ -6,13 +6,6 @@

 import torch
 from torch import nn
-from torchao.sparsity.prototype.pruner import (
-    BaseStructuredSparsifier,
-    FakeStructuredSparsity,
-    FPGMPruner,
-    LSTMSaliencyPruner,
-    SaliencyPruner,
-)
 from torch.nn.utils import parametrize
 from torch.testing._internal.common_pruning import (
     Conv2dActivation,
@@ -32,6 +25,13 @@
 )

 from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase
+from torchao.prototype.sparsity.pruner import (
+    BaseStructuredSparsifier,
+    FakeStructuredSparsity,
+    FPGMPruner,
+    LSTMSaliencyPruner,
+    SaliencyPruner,
+)


 logging.basicConfig(
@@ -1093,5 +1093,6 @@ def test_update_mask(self):
                 expected_conv1, expected_conv2, device
             )

+
 if __name__ == "__main__":
     unittest.main()
Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+# Sparsifier
+# Scheduler
+from torchao.prototype.sparsity.scheduler.base_scheduler import BaseScheduler
+from torchao.prototype.sparsity.scheduler.cubic_scheduler import CubicSL
+from torchao.prototype.sparsity.scheduler.lambda_scheduler import LambdaSL
+from torchao.prototype.sparsity.sparsifier.base_sparsifier import BaseSparsifier
+from torchao.prototype.sparsity.sparsifier.nearly_diagonal_sparsifier import (
+    NearlyDiagonalSparsifier,
+)
+
+# Parametrizations
+from torchao.prototype.sparsity.sparsifier.utils import (
+    FakeSparsity,
+    fqn_to_module,
+    get_arg_info_from_tensor_fqn,
+    module_to_fqn,
+)
+from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import (
+    WeightNormSparsifier,
+)
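These re-exports let callers import the moved classes from the torchao.prototype.sparsity package root, which is exactly how the updated tests above import them. A minimal usage sketch, assuming the prepare/step/squash_mask flow of the upstream torch.ao.pruning sparsifiers these classes derive from (the toy model and config are illustrative, not from this commit):

import torch
from torch import nn
from torchao.prototype.sparsity import WeightNormSparsifier

# Attach a sparsity parametrization to the first linear layer's weight.
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 8))
sparsifier = WeightNormSparsifier(sparsity_level=0.5)
sparsifier.prepare(model, config=[{"tensor_fqn": "0.weight"}])

sparsifier.step()         # recompute masks from weight norms
sparsifier.squash_mask()  # fold the masks into the weights permanently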
Lines changed: 93 additions & 0 deletions

@@ -0,0 +1,93 @@
+from typing import Callable, Optional, Union
+
+import torch
+
+from .base_structured_sparsifier import BaseStructuredSparsifier
+
+__all__ = ["FPGMPruner"]
+
+
+class FPGMPruner(BaseStructuredSparsifier):
+    r"""Filter Pruning via Geometric Median (FPGM) Structured Pruner
+    This sparsifier prunes filters (rows) in a tensor according to the distances among filters, following
+    `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>`_.
+
+    This sparsifier is controlled by two variables:
+    1. `sparsity_level` defines the fraction of filters (rows) that are zeroed out.
+    2. `dist` defines the distance measurement type. Default: 2 (L2 distance).
+       Available options are: [1, 2, (custom callable distance function)].
+
+    Note::
+        Inputs should be a 4D convolutional tensor of shape (N, C, H, W).
+            - N: output channels size
+            - C: input channels size
+            - H: height of kernel
+            - W: width of kernel
+    """
+
+    def __init__(
+        self, sparsity_level: float = 0.5, dist: Optional[Union[Callable, int]] = None
+    ):
+        defaults = {
+            "sparsity_level": sparsity_level,
+        }
+
+        if dist is None:
+            dist = 2
+
+        if callable(dist):
+            self.dist_fn = dist
+        elif dist == 1:
+            self.dist_fn = lambda x: torch.cdist(x, x, p=1)
+        elif dist == 2:
+            self.dist_fn = lambda x: torch.cdist(x, x, p=2)
+        else:
+            raise NotImplementedError("Distance function is not yet implemented.")
+        super().__init__(defaults=defaults)
+
+    def _compute_distance(self, t):
+        r"""Compute distances across all entries in tensor `t` along all dimensions
+        except for the one identified by `dim`.
+        Args:
+            t (torch.Tensor): tensor representing the parameter to prune
+        Returns:
+            distance (torch.Tensor): distance computed across filters
+        """
+        dim = 0  # prune filter (row)
+
+        size = t.size(dim)
+        slc = [slice(None)] * t.dim()
+
+        # flatten the tensor along the dimension
+        t_flatten = [
+            t[tuple(slc[:dim] + [slice(i, i + 1)] + slc[dim + 1 :])].reshape(-1)
+            for i in range(size)
+        ]
+        t_flatten = torch.stack(t_flatten)
+
+        # distance measurement
+        dist_matrix = self.dist_fn(t_flatten)
+
+        # a filter close to the others (near the geometric median) has a small row sum
+        distance = torch.sum(torch.abs(dist_matrix), 1)
+
+        return distance
+
+    def update_mask(self, module, tensor_name, sparsity_level, **kwargs):
+        tensor_weight = getattr(module, tensor_name)
+        mask = getattr(module.parametrizations, tensor_name)[0].mask
+
+        if sparsity_level <= 0:
+            mask.data = torch.ones_like(mask).bool()
+        elif sparsity_level >= 1.0:
+            mask.data = torch.zeros_like(mask).bool()
+        else:
+            distance = self._compute_distance(tensor_weight)
+
+            tensor_size = tensor_weight.shape[0]  # prune filter (row)
+            nparams_toprune = round(sparsity_level * tensor_size)
+            nparams_toprune = min(
+                max(nparams_toprune, 0), tensor_size
+            )  # clamp to [0, tensor_size]
+            topk = torch.topk(distance, k=nparams_toprune, largest=False)
+            mask[topk.indices] = False
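For orientation, a short sketch of how the relocated FPGMPruner might be driven; the prepare/step flow and the tensor_fqn config key are assumptions carried over from the torch.ao.pruning base classes, and the toy model is illustrative only:

import torch
from torch import nn
from torchao.prototype.sparsity.pruner import FPGMPruner

# FPGM scores each output filter of a 4D (N, C, H, W) conv weight.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))

pruner = FPGMPruner(sparsity_level=0.5)  # mask half of the filters
pruner.prepare(model, [{"tensor_fqn": "0.weight"}])
pruner.step()  # update_mask(): pairwise L2 distances, then mask the filters
               # with the smallest distance sums (nearest the geometric median)

# update_mask() writes a per-filter boolean FakeStructuredSparsity mask.
mask = model[0].parametrizations.weight[0].mask
print(mask.sum())  # 4 of the 8 filters remain unmasked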
Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+from .base_structured_sparsifier import BaseStructuredSparsifier
+from .parametrization import (
+    FakeStructuredSparsity,
+    BiasHook,
+)
+from .saliency_pruner import SaliencyPruner
+from .lstm_saliency_pruner import LSTMSaliencyPruner
+from .FPGM_pruner import FPGMPruner
