From 003eb1b93d4c34f6d784c91862f5c4baf2df3091 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Sat, 10 Aug 2024 16:22:44 -0700 Subject: [PATCH 1/7] Support PyTorch 2.4 and drop PyTorch 2.2 --- .github/workflows/regression_test.yml | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 119d228085..25e18645cd 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -21,14 +21,14 @@ jobs: fail-fast: false matrix: include: - - name: CUDA 2.2.2 + - name: CUDA 2.3.0 runs-on: linux.g5.12xlarge.nvidia.gpu - torch-spec: 'torch==2.2.2 "numpy<2" ' + torch-spec: 'torch==2.3.0' gpu-arch-type: "cuda" gpu-arch-version: "12.1" - - name: CUDA 2.3 + - name: CUDA 2.4.0 runs-on: linux.g5.12xlarge.nvidia.gpu - torch-spec: 'torch==2.3.0' + torch-spec: 'torch==2.4.0' gpu-arch-type: "cuda" gpu-arch-version: "12.1" - name: CUDA Nightly @@ -36,14 +36,15 @@ jobs: torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121' gpu-arch-type: "cuda" gpu-arch-version: "12.1" - - name: CPU 2.2.2 + + - name: CPU 2.3.0 runs-on: linux.4xlarge - torch-spec: 'torch==2.2.2 --index-url https://download.pytorch.org/whl/cpu "numpy<2" ' + torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu' gpu-arch-type: "cpu" gpu-arch-version: "" - - name: CPU 2.3 + - name: CPU 2.4 runs-on: linux.4xlarge - torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu' + torch-spec: 'torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu' gpu-arch-type: "cpu" gpu-arch-version: "" - name: CPU Nightly From e6e718977d05185a1c2d7bdb898213c64a26a943 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Sat, 10 Aug 2024 16:29:36 -0700 Subject: [PATCH 2/7] Update regression_test.yml --- .github/workflows/regression_test.yml | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 25e18645cd..2c3b594eea 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -21,12 +21,17 @@ jobs: fail-fast: false matrix: include: - - name: CUDA 2.3.0 + - name: CUDA 2.2.2 + runs-on: linux.g5.12xlarge.nvidia.gpu + torch-spec: 'torch==2.2.2 "numpy<2" ' + gpu-arch-type: "cuda" + gpu-arch-version: "12.1" + - name: CUDA 2.3 runs-on: linux.g5.12xlarge.nvidia.gpu torch-spec: 'torch==2.3.0' gpu-arch-type: "cuda" gpu-arch-version: "12.1" - - name: CUDA 2.4.0 + - name: CUDA 2.4 runs-on: linux.g5.12xlarge.nvidia.gpu torch-spec: 'torch==2.4.0' gpu-arch-type: "cuda" @@ -36,8 +41,13 @@ jobs: torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121' gpu-arch-type: "cuda" gpu-arch-version: "12.1" - - - name: CPU 2.3.0 + + - name: CPU 2.2.2 + runs-on: linux.4xlarge + torch-spec: 'torch==2.2.2 --index-url https://download.pytorch.org/whl/cpu "numpy<2" ' + gpu-arch-type: "cpu" + gpu-arch-version: "" + - name: CPU 2.3 runs-on: linux.4xlarge torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu' gpu-arch-type: "cpu" From 41dd58b2b07358aada1f544d2a45def31d394644 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 14 Aug 2024 16:54:09 -0700 Subject: [PATCH 3/7] update tests --- test/float8/test_base.py | 4 ++-- test/float8/test_compile.py | 4 ++-- test/float8/test_dtensor.py | 4 ++-- test/float8/test_fsdp.py | 4 ++-- test/float8/test_fsdp2/test_fsdp2.py | 4 ++-- test/float8/test_fsdp_compile.py | 4 ++-- test/float8/test_inference_flows.py | 4 ++-- test/float8/test_numerics_integration.py | 4 ++-- test/integration/test_integration.py | 8 ++++---- 9 files changed, 20 insertions(+), 20 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index e7283ec1ee..632fbc5869 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -16,9 +16,9 @@ import torch import torch.nn as nn -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -if not TORCH_VERSION_AT_LEAST_2_4: +if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 9d52d6cf4d..ccbc4f80b3 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -11,9 +11,9 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -if not TORCH_VERSION_AT_LEAST_2_4: +if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) import torch diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index 8780f2f305..70d6673fca 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -19,9 +19,9 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -if not TORCH_VERSION_AT_LEAST_2_4: +if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) from torchao.float8 import Float8LinearConfig diff --git a/test/float8/test_fsdp.py b/test/float8/test_fsdp.py index 232a4818b9..2ba33bba08 100644 --- a/test/float8/test_fsdp.py +++ b/test/float8/test_fsdp.py @@ -18,9 +18,9 @@ import fire -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -if not TORCH_VERSION_AT_LEAST_2_4: +if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) import torch diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index a28b447487..30aa735480 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -5,9 +5,9 @@ import unittest from typing import Any, List -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -if not TORCH_VERSION_AT_LEAST_2_4: +if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) diff --git a/test/float8/test_fsdp_compile.py b/test/float8/test_fsdp_compile.py index c65311a956..b481c14e30 100644 --- a/test/float8/test_fsdp_compile.py +++ b/test/float8/test_fsdp_compile.py @@ -15,9 +15,9 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -if not TORCH_VERSION_AT_LEAST_2_4: +if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) import torch diff --git a/test/float8/test_inference_flows.py b/test/float8/test_inference_flows.py index 5743c55635..0845ae9cd1 100644 --- a/test/float8/test_inference_flows.py +++ b/test/float8/test_inference_flows.py @@ -12,11 +12,11 @@ import pytest from unittest.mock import patch from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, + TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass, ) -if not TORCH_VERSION_AT_LEAST_2_4: +if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) import torch diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index 5c35e139e0..ee9332ea4d 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -11,9 +11,9 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -if not TORCH_VERSION_AT_LEAST_2_4: +if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) import torch diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 06f92edd02..5988c26492 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1222,7 +1222,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n): (1, 32, 128, 128), (32, 32, 128, 128), ])) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "autoquant requires 2.4+.") + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_compile(self, device, dtype, m1, m2, k, n): undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): @@ -1254,7 +1254,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): self.assertTrue(sqnr >= 30) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "autoquant requires 2.4+.") + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_manual(self, device, dtype): undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): @@ -1295,7 +1295,7 @@ def test_autoquant_manual(self, device, dtype): (1, 32, 128, 128), (32, 32, 128, 128), ])) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "autoquant requires 2.4+.") + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n): undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): @@ -1478,7 +1478,7 @@ def forward(self, x): class TestUtils(unittest.TestCase): @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "autoquant requires 2.4+.") + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_get_model_size_autoquant(self, device, dtype): if device != "cuda" and dtype != torch.bfloat16: self.skipTest(f"autoquant currently does not support {device}") From 443518fe2a1d738c0b2c8d01d72149ebb67ff1c0 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 14 Aug 2024 17:33:05 -0700 Subject: [PATCH 4/7] update adam8 test and mixed_mm --- test/integration/test_integration.py | 4 ++-- test/prototype/test_low_bit_optim.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 5988c26492..4e8f6fbc39 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -913,7 +913,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype): if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0): self.skipTest("test requires SM capability of at least (8, 0).") from torch._inductor import config - mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") if TORCH_VERSION_AT_LEAST_2_4 else ("force_mixed_mm", True) + mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") if TORCH_VERSION_AT_LEAST_2_5 else ("force_mixed_mm", True) with config.patch({ "epilogue_fusion": True, @@ -943,7 +943,7 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype): self.skipTest("test requires SM capability of at least (8, 0).") torch.manual_seed(0) from torch._inductor import config - mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") if TORCH_VERSION_AT_LEAST_2_4 else ("force_mixed_mm", True) + mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") if TORCH_VERSION_AT_LEAST_2_5 else ("force_mixed_mm", True) with config.patch({ "epilogue_fusion": False, diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 050965e811..af66df0b19 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -75,7 +75,7 @@ def test_quantize_4bit_with_qmap_compile(self, device): class TestOptim(TestCase): - @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.compile() fails for PyTorch < 2.3") + @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_4, reason="torch.compile() fails for PyTorch < 2.4") @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"]) @parametrize("dtype", [torch.float32, torch.bfloat16]) @parametrize("device", _DEVICES) From 2c78161ead083d0a311c5c102c46f1f1e1ae9d39 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 14 Aug 2024 17:37:57 -0700 Subject: [PATCH 5/7] push --- test/prototype/test_low_bit_optim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index af66df0b19..afeefa2239 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -75,7 +75,7 @@ def test_quantize_4bit_with_qmap_compile(self, device): class TestOptim(TestCase): - @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_4, reason="torch.compile() fails for PyTorch < 2.4") + @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.compile() fails for PyTorch < 2.3") @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"]) @parametrize("dtype", [torch.float32, torch.bfloat16]) @parametrize("device", _DEVICES) @@ -229,7 +229,7 @@ class TestFSDP2(FSDPTest): def world_size(self) -> int: return 2 - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="torch >= 2.4 required") + @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="OptimState8bit dispatch: attempting to run unimplemented operator/function: aten.as_strided.default") @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="https://github.com/pytorch/ao/issues/652") @skip_if_lt_x_gpu(2) def test_fsdp2(self): From dc489066445f2de226d5c03ab090635effa9e9bc Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 14 Aug 2024 19:24:57 -0700 Subject: [PATCH 6/7] push --- test/quantization/test_qat.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 7c8b8a3f17..cef78b32ef 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -423,6 +423,7 @@ def test_qat_4w_primitives(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517 + @unittest.skip(TORCH_VERSION_AFTER_2_4, "assert input.dtype == torch.float32" ) @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now") def test_qat_4w_linear(self): from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATLinear @@ -453,6 +454,7 @@ def test_qat_4w_linear(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517 + @unittest.skip(TORCH_VERSION_AFTER_2_4, "assert input.dtype == torch.float32" ) @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now") def test_qat_4w_quantizer(self): from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer From 7d0a9e78a40a2d061bbf6b432bed567ef50b5ece Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 14 Aug 2024 19:31:26 -0700 Subject: [PATCH 7/7] update --- test/quantization/test_qat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index cef78b32ef..232fbef813 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -423,7 +423,7 @@ def test_qat_4w_primitives(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517 - @unittest.skip(TORCH_VERSION_AFTER_2_4, "assert input.dtype == torch.float32" ) + @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32" ) @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now") def test_qat_4w_linear(self): from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATLinear @@ -454,7 +454,7 @@ def test_qat_4w_linear(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517 - @unittest.skip(TORCH_VERSION_AFTER_2_4, "assert input.dtype == torch.float32" ) + @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32" ) @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now") def test_qat_4w_quantizer(self): from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer