@@ -13,7 +13,7 @@
 )
 
 from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_
-from torchao.utils import TORCH_VERSION_AFTER_2_5, TORCH_VERSION_AT_LEAST_2_3
+from torchao.utils import TORCH_VERSION_AFTER_2_5, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_4
 
 
 logging.basicConfig(
@@ -48,7 +48,7 @@ def test_sparse(self):
 
 class TestQuantSemiSparse(common_utils.TestCase):
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize("compile", [True, False])
     def test_quant_semi_sparse(self, compile):
@@ -79,6 +79,7 @@ def test_quant_semi_sparse(self, compile):
 
         torch.testing.assert_close(dense_result, sparse_result, rtol=1e-2, atol=1e-2)
 
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize("compile", [True, False])
     def test_sparse_marlin(self, compile):
@@ -110,7 +111,7 @@ def test_sparse_marlin(self, compile):
 
 
 class TestBlockSparseWeight(common_utils.TestCase):
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature due to need for custom op support")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize("compile", [True, False])
     def test_sparse(self, compile):
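For context on the pattern this diff relies on: the TORCH_VERSION_AT_LEAST_* names imported from torchao.utils are module-level booleans, which is why they can be passed directly to @unittest.skipIf. The sketch below is a minimal, assumed equivalent of that gating pattern, not torchao's actual implementation; the TestVersionGated class and the version comparison are illustrative only.

    # Illustrative sketch only: torchao.utils defines these flags internally;
    # the comparison below is an assumed equivalent, not torchao's exact code.
    import unittest

    import torch
    from packaging.version import Version

    # A TORCH_VERSION_AT_LEAST_X_Y flag is a plain bool, evaluated once at
    # import time and reused across skipIf decorators. Splitting on "+"
    # drops local build suffixes such as "+cu121".
    TORCH_VERSION_AT_LEAST_2_4 = (
        Version(torch.__version__.split("+")[0]) >= Version("2.4.0")
    )


    class TestVersionGated(unittest.TestCase):
        # The test is still collected on older torch builds, but reported
        # as skipped with the given reason rather than failing.
        @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature")
        def test_requires_torch_2_4(self):
            self.assertTrue(TORCH_VERSION_AT_LEAST_2_4)


    if __name__ == "__main__":
        unittest.main()

Computing the flag once at import time keeps the decorator cheap and makes the skip reason visible in the test report, which is presumably why the diff bumps the flag (2.3 to 2.4 or 2.5) rather than restructuring the tests.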