diff --git a/ruff.toml b/ruff.toml index 447648929c..09d0a1ec97 100644 --- a/ruff.toml +++ b/ruff.toml @@ -10,8 +10,7 @@ include = [ "torchao/prototype/low_bit_optim/**.py", "test/float8/**/*.py", "test/quantization/test_observer.py", - "test/dtypes/test_affine_quantized_float.py", - "test/dtypes/test_nf4.py", + "test/dtypes/**/*.py", "test/prototype/low_bit_optim/**.py", "torchao/utils.py", diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 540d85e032..f398a9d238 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -1,24 +1,24 @@ +import tempfile +import unittest + +import torch +from torch.testing._internal import common_utils from torch.testing._internal.common_utils import ( TestCase, run_tests, ) + +from torchao.dtypes import SemiSparseLayout from torchao.quantization import ( + float8_weight_only, int4_weight_only, - int8_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, - int8_dynamic_activation_int8_semi_sparse_weight, - float8_weight_only, + int8_weight_only, ) from torchao.quantization.quant_primitives import MappingType -from torchao.dtypes import SemiSparseLayout -from torch.testing._internal import common_utils from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -import torch -import unittest -import tempfile - is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) @@ -33,7 +33,9 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool): base_functions.append(int4_weight_only(group_size=32)) if do_sparse: - base_functions.append(int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())) + base_functions.append( + int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) + ) if is_cuda_8_9: base_functions.append(float8_weight_only()) @@ -44,11 +46,11 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool): class TestAffineQuantized(TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_tensor_core_layout_transpose(self): - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - t = l.weight + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + t = linear.weight shape = t.shape apply_int4_weight_only_quant = int4_weight_only(group_size=32) - ql = apply_int4_weight_only_quant(l) + ql = apply_int4_weight_only_quant(linear) aqt = ql.weight aqt_shape = aqt.shape self.assertEqual(aqt_shape, shape) @@ -64,8 +66,8 @@ def test_tensor_core_layout_transpose(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("apply_quant", get_quantization_functions(True, True)) def test_weights_only(self, apply_quant): - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - ql = apply_quant(l) + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + ql = apply_quant(linear) with tempfile.NamedTemporaryFile() as f: torch.save(ql.state_dict(), f) f.seek(0) @@ -78,33 +80,32 @@ def test_weights_only(self, apply_quant): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("apply_quant", get_quantization_functions(False, False)) def test_to_device(self, apply_quant): - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = apply_quant(l) + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + ql = apply_quant(linear) ql.to("cuda") - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = apply_quant(l) + linear = 
torch.nn.Linear(128, 256, dtype=torch.bfloat16) + ql = apply_quant(linear) ql.to(device="cuda") - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = apply_quant(l) + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + ql = apply_quant(linear) ql.cuda() @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_register_new_dispatch(self): + from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx from torchao.dtypes.affine_quantized_tensor_ops import ( - register_aqt_quantized_linear_dispatch, deregister_aqt_quantized_linear_dispatch, + register_aqt_quantized_linear_dispatch, ) - from torchao.dtypes import to_affine_quantized_intx - from torchao.dtypes import AffineQuantizedTensor from torchao.quantization.quant_primitives import MappingType def dispatch_condition(input_tensor, weight_tensor, bias): return ( - isinstance(weight_tensor, AffineQuantizedTensor) and - weight_tensor.quant_min == 0 and - weight_tensor.quant_max == 2**6-1 + isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.quant_min == 0 + and weight_tensor.quant_max == 2**6 - 1 ) def impl(input_tensor, weight_tensor, bias): @@ -115,23 +116,35 @@ def impl(input_tensor, weight_tensor, bias): register_aqt_quantized_linear_dispatch(dispatch_condition, impl) def apply_uint6_weight_only_quant(linear): - linear.weight = torch.nn.Parameter(to_affine_quantized_intx(linear.weight, MappingType.ASYMMETRIC, (1, linear.weight.shape[-1]), torch.uint8, 0, 2**6-1), requires_grad=False) + linear.weight = torch.nn.Parameter( + to_affine_quantized_intx( + linear.weight, + MappingType.ASYMMETRIC, + (1, linear.weight.shape[-1]), + torch.uint8, + 0, + 2**6 - 1, + ), + requires_grad=False, + ) return linear - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - apply_uint6_weight_only_quant(l) + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + apply_uint6_weight_only_quant(linear) example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda") - with self.assertRaisesRegex(AssertionError, "dispatching to my impl for uint6 weight only quant"): - l(example_input) + with self.assertRaisesRegex( + AssertionError, "dispatching to my impl for uint6 weight only quant" + ): + linear(example_input) deregister_aqt_quantized_linear_dispatch(dispatch_condition) @common_utils.parametrize("apply_quant", get_quantization_functions(True, True)) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_print_quantized_module(self, apply_quant): - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - ql = apply_quant(l) + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + ql = apply_quant(linear) assert "AffineQuantizedTensor" in str(ql) @@ -143,20 +156,25 @@ class TestAffineQuantizedBasic(TestCase): @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) def test_flatten_unflatten(self, apply_quant, device, dtype): - l = torch.nn.Linear(128, 256, dtype=dtype, device=device) - ql = apply_quant(l) + linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + ql = apply_quant(linear) lp_tensor = ql.weight tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__() - tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict} + tensor_data_dict = { + name: getattr(lp_tensor, name) for name in tensor_data_name_dict + } outer_size = lp_tensor.size() outer_stride = lp_tensor.stride() - reconstructed = 
type(lp_tensor).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride) + reconstructed = type(lp_tensor).__tensor_unflatten__( + tensor_data_dict, tensor_attributes, outer_size, outer_stride + ) example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),) ref = ql(*example_inputs) ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False) reconstruct_res = ql(*example_inputs) self.assertEqual(reconstruct_res, ref) + common_utils.instantiate_parametrized_tests(TestAffineQuantized) common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic) diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index af07328407..82d3d2501d 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -1,28 +1,28 @@ -import torch import unittest -from torch.testing._internal.common_utils import run_tests + +import torch +from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) + from torchao.quantization import ( + float8_dynamic_activation_float8_weight, + float8_weight_only, int4_weight_only, int8_weight_only, - float8_weight_only, - float8_dynamic_activation_float8_weight, ) from torchao.quantization.observer import PerRow, PerTensor -import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh -from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, - with_comms, - NUM_DEVICES, -) from torchao.quantization.quant_api import quantize_ -from torchao.dtypes import AffineQuantizedTensor from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 + class TestAffineQuantizedTensorParallel(DTensorTestBase): - """Basic test case for tensor subclasses - """ + """Basic test case for tensor subclasses""" + QUANT_METHOD_FN = staticmethod(int8_weight_only) QUANT_METHOD_KWARGS = {} @@ -40,9 +40,7 @@ def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: # Construct DTensor from local shard dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)]) # Replace parameter in module - m.linear.weight = torch.nn.Parameter( - dtensor, requires_grad=False - ) + m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False) return m @staticmethod @@ -59,9 +57,7 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: # Construct DTensor from local shard dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True) # Replace parameter in module - m.linear.weight = torch.nn.Parameter( - dtensor, requires_grad=False - ) + m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False) return m def quantize(self, m: torch.nn.Module) -> torch.nn.Module: @@ -79,7 +75,9 @@ def _test_tp(self, dtype): class M(torch.nn.Module): def __init__(self, in_features, out_features, **kwargs) -> None: super().__init__(**kwargs) - self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda") + self.linear = torch.nn.Linear( + in_features, out_features, bias=False, device="cuda" + ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(x) @@ -91,11 +89,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: proj_up = M(1024, 2048).to(device).to(dtype) 
proj_dn = M(2048, 1024).to(device).to(dtype) example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype) - y = proj_dn(proj_up(example_input)) + proj_dn(proj_up(example_input)) # Quantize the model up_quant = self.quantize(proj_up) dn_quant = self.quantize(proj_dn) - y_q = dn_quant(up_quant(example_input)) + dn_quant(up_quant(example_input)) mesh = self.build_device_mesh() mesh.device_type = "cuda" @@ -105,11 +103,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dn_dist = self.rowwise_shard(dn_quant, mesh) # We need to turn inputs into DTensor form as well -- just a format change - input_dtensor = DTensor.from_local( - example_input, mesh, [Replicate()] - ) + input_dtensor = DTensor.from_local(example_input, mesh, [Replicate()]) - y_d = dn_dist(up_dist(input_dtensor)) + dn_dist(up_dist(input_dtensor)) if not TORCH_VERSION_AT_LEAST_2_6: # Need torch 2.6 to support compiled tensor parallelism @@ -118,7 +114,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: up_compiled = torch.compile(up_dist) y_up = up_compiled(input_dtensor) dn_compiled = torch.compile(dn_dist) - y_dn = dn_compiled(y_up) + dn_compiled(y_up) class TestInt8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): @@ -142,11 +138,13 @@ class TestInt4woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel) def test_tp(self, dtype): return self._test_tp(dtype) + common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel) common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel) # Run only on H100 if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0): + class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): QUANT_METHOD_FN = staticmethod(float8_weight_only) COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32] @@ -157,7 +155,9 @@ class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParalle def test_tp(self, dtype): return self._test_tp(dtype) - class TestFloat8dqTensorAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): + class TestFloat8dqTensorAffineQuantizedTensorParallel( + TestAffineQuantizedTensorParallel + ): QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight) QUANT_METHOD_KWARGS = {"granularity": PerTensor()} COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32] @@ -168,7 +168,9 @@ class TestFloat8dqTensorAffineQuantizedTensorParallel(TestAffineQuantizedTensorP def test_tp(self, dtype): return self._test_tp(dtype) - class TestFloat8dqRowAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): + class TestFloat8dqRowAffineQuantizedTensorParallel( + TestAffineQuantizedTensorParallel + ): QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight) QUANT_METHOD_KWARGS = {"granularity": PerRow()} COMMON_DTYPES = [torch.bfloat16] @@ -179,7 +181,11 @@ class TestFloat8dqRowAffineQuantizedTensorParallel(TestAffineQuantizedTensorPara def test_tp(self, dtype): return self._test_tp(dtype) - common_utils.instantiate_parametrized_tests(TestFloat8dqTensorAffineQuantizedTensorParallel) - common_utils.instantiate_parametrized_tests(TestFloat8dqRowAffineQuantizedTensorParallel) + common_utils.instantiate_parametrized_tests( + TestFloat8dqTensorAffineQuantizedTensorParallel + ) + common_utils.instantiate_parametrized_tests( + TestFloat8dqRowAffineQuantizedTensorParallel + ) if __name__ == "__main__": run_tests() diff --git a/test/dtypes/test_bitnet.py b/test/dtypes/test_bitnet.py index 
70153cf5ba..e248b04b05 100644 --- a/test/dtypes/test_bitnet.py +++ b/test/dtypes/test_bitnet.py @@ -1,6 +1,7 @@ import pytest import torch import torch.nn as nn + from torchao.prototype.dtypes import BitnetTensor from torchao.prototype.dtypes.uint2 import unpack_uint2 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter @@ -9,6 +10,7 @@ if not TORCH_VERSION_AT_LEAST_2_4: pytest.skip("Unsupported PyTorch version", allow_module_level=True) + @pytest.fixture(autouse=True) def run_before_and_after_tests(): # source: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test # noqa: E501 @@ -22,34 +24,45 @@ def run_before_and_after_tests(): # avoid dynamo cache limit issues torch._dynamo.reset() + @pytest.fixture def bitnet_tensor(): - input_tensor = torch.randint(0, 15, (4,4), dtype=torch.uint8) + input_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8) return BitnetTensor.from_unpacked(input_tensor) + def test_copy(bitnet_tensor): copied_tensor = bitnet_tensor.clone() assert torch.equal(bitnet_tensor.elem, copied_tensor.elem) + def test_transpose(bitnet_tensor): transposed_tensor = bitnet_tensor.t() expected_tensor = unpack_uint2(bitnet_tensor.elem).t() assert torch.equal(unpack_uint2(transposed_tensor.elem), expected_tensor) + def test_multiply(bitnet_tensor): w_t = torch.randint(0, 15, (4, 16), dtype=torch.uint8) w = BitnetTensor.from_unpacked(w_t) - y = torch.addmm(torch.Tensor([1]), bitnet_tensor, w) + torch.addmm(torch.Tensor([1]), bitnet_tensor, w) + -@pytest.mark.parametrize("dtype", [torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64]) +@pytest.mark.parametrize( + "dtype", + [torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64], +) def test_conversion(bitnet_tensor, dtype): converted_tensor = bitnet_tensor.to(dtype) expected_tensor = unpack_uint2(bitnet_tensor.elem).to(dtype) assert torch.allclose(converted_tensor, expected_tensor, atol=1e-5) + def _apply_weight_only_uint2_quant(model): def fn(mod): - mod.weight = torch.nn.Parameter(BitnetTensor.from_float(mod.weight), requires_grad=False) + mod.weight = torch.nn.Parameter( + BitnetTensor.from_float(mod.weight), requires_grad=False + ) return mod _replace_with_custom_fn_if_matches_filter( @@ -58,19 +71,21 @@ def fn(mod): lambda mod, fqn: isinstance(mod, torch.nn.Linear), ) -@pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="Regression introdued in nightlies") + +@pytest.mark.skipif( + TORCH_VERSION_AT_LEAST_2_5, reason="Regression introdued in nightlies" +) @pytest.mark.parametrize("input_shape", [[2, 4], [5, 5, 5, 4], [1, 4, 4]]) def test_uint2_quant(input_shape): - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" x = torch.randn(*input_shape).to(device) m = nn.Sequential(nn.Linear(4, 16)).to(device) y_ref = m(x) _apply_weight_only_uint2_quant(m) y_wo = m(x) assert y_ref.shape == y_wo.shape - y_compiled = torch.compile(m, fullgraph=True)(x) + torch.compile(m, fullgraph=True)(x) if __name__ == "__main__": pytest.main(__file__) - diff --git a/test/dtypes/test_bitpacking.py b/test/dtypes/test_bitpacking.py index 647ead8fd8..262a4d6ca6 100644 --- a/test/dtypes/test_bitpacking.py +++ b/test/dtypes/test_bitpacking.py @@ -1,33 +1,38 @@ -import torch -from torchao.dtypes.uintx.bitpacking import pack, unpack, pack_cpu, unpack_cpu import pytest +import torch from torch.utils._triton import has_triton -bit_widths = (1,2,3,4,5,6,7) +from 
torchao.dtypes.uintx.bitpacking import pack, pack_cpu, unpack, unpack_cpu + +bit_widths = (1, 2, 3, 4, 5, 6, 7) dimensions = (0, -1, 1) + @pytest.fixture(autouse=True) def run_before_and_after_tests(): yield - torch._dynamo.reset() # reset cache between tests + torch._dynamo.reset() # reset cache between tests + @pytest.mark.parametrize("bit_width", bit_widths) @pytest.mark.parametrize("dim", dimensions) def test_CPU(bit_width, dim): - test_tensor = torch.randint(0, 2**bit_width, (32,32,32), dtype=torch.uint8, device='cpu') - packed = pack_cpu(test_tensor, bit_width, dim = dim) - unpacked = unpack_cpu(packed, bit_width, dim = dim) - assert(unpacked.allclose(test_tensor)) + test_tensor = torch.randint( + 0, 2**bit_width, (32, 32, 32), dtype=torch.uint8, device="cpu" + ) + packed = pack_cpu(test_tensor, bit_width, dim=dim) + unpacked = unpack_cpu(packed, bit_width, dim=dim) + assert unpacked.allclose(test_tensor) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("bit_width", bit_widths) @pytest.mark.parametrize("dim", dimensions) def test_GPU(bit_width, dim): - test_tensor = torch.randint(0, 2**bit_width, (32,32,32), dtype=torch.uint8).cuda() - packed = pack(test_tensor, bit_width, dim = dim) - unpacked = unpack(packed, bit_width, dim = dim) - assert(unpacked.allclose(test_tensor)) + test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda() + packed = pack(test_tensor, bit_width, dim=dim) + unpacked = unpack(packed, bit_width, dim=dim) + assert unpacked.allclose(test_tensor) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -36,27 +41,33 @@ def test_GPU(bit_width, dim): @pytest.mark.parametrize("dim", dimensions) def test_compile(bit_width, dim): torch._dynamo.config.specialize_int = True - pack_compile = torch.compile(pack, fullgraph=True) - unpack_compile = torch.compile(unpack, fullgraph=True) - test_tensor = torch.randint(0, 2**bit_width, (32,32,32), dtype=torch.uint8).cuda() - packed = pack(test_tensor, bit_width, dim = dim) - unpacked = unpack(packed, bit_width, dim = dim) - assert(unpacked.allclose(test_tensor)) + torch.compile(pack, fullgraph=True) + torch.compile(unpack, fullgraph=True) + test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda() + packed = pack(test_tensor, bit_width, dim=dim) + unpacked = unpack(packed, bit_width, dim=dim) + assert unpacked.allclose(test_tensor) + # these test cases are for the example pack walk through in the bitpacking.py file @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_pack_example(): - test_tensor = torch.tensor([0x30,0x29,0x17,0x5,0x20,0x16,0x9,0x22], dtype=torch.uint8).cuda() - shard_4,shard_2 = pack(test_tensor, 6) + test_tensor = torch.tensor( + [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8 + ).cuda() + shard_4, shard_2 = pack(test_tensor, 6) print(shard_4, shard_2) assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).cuda().allclose(shard_4) assert torch.tensor([39, 146], dtype=torch.uint8).cuda().allclose(shard_2) unpacked = unpack([shard_4, shard_2], 6) assert unpacked.allclose(test_tensor) + def test_pack_example_CPU(): - test_tensor = torch.tensor([0x30,0x29,0x17,0x5,0x20,0x16,0x9,0x22], dtype=torch.uint8) - shard_4,shard_2 = pack(test_tensor, 6) + test_tensor = torch.tensor( + [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8 + ) + shard_4, shard_2 = pack(test_tensor, 6) print(shard_4, shard_2) assert 
torch.tensor([0, 105, 151, 37], dtype=torch.uint8).allclose(shard_4) assert torch.tensor([39, 146], dtype=torch.uint8).allclose(shard_2) diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index aa55164716..8bb39b2cc8 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -1,7 +1,6 @@ import copy - -import pytest import unittest + import torch from torch.testing._internal.common_utils import ( TestCase, @@ -9,21 +8,27 @@ parametrize, run_tests, ) + from torchao.dtypes.floatx import ( FloatxTensorCoreLayout, - to_scaled_tc_floatx, from_scaled_tc_floatx, + to_scaled_tc_floatx, +) +from torchao.dtypes.floatx.floatx_tensor_core_layout import ( + FloatxTensorCoreAQTTensorImpl, + _pack_tc_floatx, + _pack_tc_fp6, +) +from torchao.prototype.custom_fp_utils import ( + _f32_to_floatx_unpacked, + _floatx_unpacked_to_f32, ) -from torchao.dtypes.floatx.floatx_tensor_core_layout import _pack_tc_floatx, _pack_tc_fp6, FloatxTensorCoreAQTTensorImpl -from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32 from torchao.quantization import ( - quantize_, fpx_weight_only, + quantize_, ) - from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode - _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) _Floatx_DTYPES = [(3, 2), (2, 2)] @@ -52,7 +57,9 @@ def test_from_tc_floatx_correctness(self, ebits, mbits, device): x = torch.randn(256, 64, device=device) * 100 # quantize and dequantize so that the values are exactly representable in Floatx - x = _floatx_unpacked_to_f32(_f32_to_floatx_unpacked(x, ebits, mbits), ebits, mbits) + x = _floatx_unpacked_to_f32( + _f32_to_floatx_unpacked(x, ebits, mbits), ebits, mbits + ) tc_floatx, scale = to_scaled_tc_floatx(x, ebits, mbits) actual = from_scaled_tc_floatx(tc_floatx, ebits, mbits, scale=scale) @@ -63,11 +70,15 @@ def test_from_tc_floatx_correctness(self, ebits, mbits, device): def test_from_scaled_tc_floatx_compile(self, ebits, mbits, device): M, N = 256, 64 nbits = 1 + ebits + mbits - x = torch.randint(256, size=(M, N // 8 * nbits), dtype=torch.uint8, device=device) + x = torch.randint( + 256, size=(M, N // 8 * nbits), dtype=torch.uint8, device=device + ) scale = torch.randn(M, device=device) expected = from_scaled_tc_floatx(x, ebits, mbits, scale) - actual = torch.compile(from_scaled_tc_floatx, fullgraph=True)(x, ebits, mbits, scale) + actual = torch.compile(from_scaled_tc_floatx, fullgraph=True)( + x, ebits, mbits, scale + ) torch.testing.assert_close(actual, expected) @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available") @@ -82,13 +93,18 @@ def test_to_copy_device(self, ebits, mbits): scale = choose_qparams_affine_floatx(x, ebits, mbits) x = quantize_affine_floatx(x, scale, ebits, mbits) _layout = FloatxTensorCoreLayout(ebits, mbits) - floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(x, scale, None, _layout).cuda() + floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain( + x, scale, None, _layout + ).cuda() assert floatx_tensor_impl.device.type == "cuda" floatx_tensor_impl = floatx_tensor_impl.cpu() assert floatx_tensor_impl.device.type == "cpu" @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, + reason="quantization only works with torch.compile for 2.5+", + ) @parametrize("ebits,mbits", _Floatx_DTYPES) @parametrize("bias", [False, 
True]) @parametrize("dtype", [torch.half, torch.bfloat16]) diff --git a/test/dtypes/test_uint2.py b/test/dtypes/test_uint2.py index b017c47dd4..f6faaea10d 100644 --- a/test/dtypes/test_uint2.py +++ b/test/dtypes/test_uint2.py @@ -1,6 +1,6 @@ import pytest import torch -import torch.nn as nn + from torchao.prototype.dtypes import UInt2Tensor from torchao.prototype.dtypes.uint2 import unpack_uint2 from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 @@ -8,26 +8,33 @@ if not TORCH_VERSION_AT_LEAST_2_4: pytest.skip("Unsupported PyTorch version", allow_module_level=True) + @pytest.fixture def uint2_tensor(): - input_tensor = torch.randint(0, 15, (4,4), dtype = torch.uint8) + input_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8) return UInt2Tensor(input_tensor) + def test_copy(uint2_tensor): copied_tensor = uint2_tensor.clone() assert torch.equal(uint2_tensor.elem, copied_tensor.elem) + def test_transpose(uint2_tensor): transposed_tensor = uint2_tensor.t() expected_tensor = unpack_uint2(uint2_tensor.elem).t() assert torch.equal(unpack_uint2(transposed_tensor.elem), expected_tensor) -@pytest.mark.parametrize("dtype", [torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64]) + +@pytest.mark.parametrize( + "dtype", + [torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64], +) def test_conversion(uint2_tensor, dtype): converted_tensor = uint2_tensor.to(dtype) expected_tensor = unpack_uint2(uint2_tensor.elem).to(dtype) assert torch.allclose(converted_tensor, expected_tensor, atol=1e-5) -if __name__ == '__main__': + +if __name__ == "__main__": pytest.main(__file__) - diff --git a/test/dtypes/test_uint4.py b/test/dtypes/test_uint4.py index 432ffebbd2..e148d68abb 100644 --- a/test/dtypes/test_uint4.py +++ b/test/dtypes/test_uint4.py @@ -1,35 +1,42 @@ -import torch -from torchao.dtypes.uintx.uint4_layout import ( - UInt4Tensor, - PerChannelSymmetricWeightUInt4Tensor, -) +import copy import unittest -from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer +import torch +from torch import nn +from torch.ao.quantization.observer import ObserverBase +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + Quantizer, +) +from torch.fx import ( + GraphModule, + Node, +) from torch.testing._internal.common_quantization import ( NodeSpec as ns, +) +from torch.testing._internal.common_quantization import ( QuantizationTestCase, ) + +from torchao.dtypes.uintx.uint4_layout import ( + PerChannelSymmetricWeightUInt4Tensor, + UInt4Tensor, +) from torchao.quantization.quant_api import ( _replace_with_custom_fn_if_matches_filter, ) -from torch.ao.quantization.observer import ObserverBase -from torch import nn -from torch.fx import ( - Node, - GraphModule, -) -from torch.ao.quantization.quantizer import ( - QuantizationAnnotation, -) -import copy from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 def _apply_weight_only_uint4_quant(model): def fn(mod): - mod.weight = torch.nn.Parameter(PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight), requires_grad=False) + mod.weight = torch.nn.Parameter( + PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight), + requires_grad=False, + ) return mod _replace_with_custom_fn_if_matches_filter( @@ -38,28 +45,46 @@ def fn(mod): lambda mod, fqn: isinstance(mod, torch.nn.Linear), ) -@unittest.skip("FAILED 
test/dtypes/test_uint4.py::TestUInt4::test_basic_tensor_ops - AttributeError: module 'torch' has no attribute 'uint4'") + +@unittest.skip( + "FAILED test/dtypes/test_uint4.py::TestUInt4::test_basic_tensor_ops - AttributeError: module 'torch' has no attribute 'uint4'" +) class TestUInt4(QuantizationTestCase): def test_basic_tensor_ops(self): - x = UInt4Tensor(torch.tensor([ - [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - ], dtype=torch.uint8)) + x = UInt4Tensor( + torch.tensor( + [ + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + ], + dtype=torch.uint8, + ) + ) self.assertEqual(x.shape, (3, 16)) # TODO: make sure this returns torch.uint4 self.assertIs(x.dtype, torch.uint4) # making sure these works x.to(torch.uint8) - expected = UInt4Tensor(torch.tensor([ - [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - ], dtype=torch.uint8)) + expected = UInt4Tensor( + torch.tensor( + [ + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + ], + dtype=torch.uint8, + ) + ) self.assertEqual(x[0:1, :], expected) - expected = UInt4Tensor(torch.tensor([ - [0x23, 0x45], - [0x23, 0x45], - [0x23, 0x45], - ], dtype=torch.uint8)) + expected = UInt4Tensor( + torch.tensor( + [ + [0x23, 0x45], + [0x23, 0x45], + [0x23, 0x45], + ], + dtype=torch.uint8, + ) + ) self.assertEqual(x[:, 2:6], expected) torch.save(x, "uint4_tensor.pt") x = torch.load("uint4_tensor.pt") @@ -71,9 +96,9 @@ def test_gpu_quant(self): for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: x = torch.randn(*x_shape) m = nn.Sequential(nn.Linear(4, 16)) - y_ref = m(x) + m(x) # checking if it runs _apply_weight_only_uint4_quant(m) - y_wo = m(x) + m(x) # checking if it runs # sqnr = compute_error(y_ref, y_wo) opt = torch.compile(m, fullgraph=True, mode="max-autotune") # make sure it runs @@ -81,9 +106,9 @@ def test_gpu_quant(self): def test_pt2e_quant(self): from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( - OP_TO_ANNOTATOR, QuantizationConfig, ) + class Uint4Observer(ObserverBase): def __init__(self, *args, **kwargs): # just faking a dtype here @@ -99,9 +124,15 @@ def calculate_qparams(self, **kwargs): def convert(self, model: GraphModule, observer_node: Node): with model.graph.inserting_before(observer_node): q_node = model.graph.call_function( - torch.ops.qtensors.quantize_per_tensor_uint4, (observer_node.args[0], 1.0, 0), {}) + torch.ops.qtensors.quantize_per_tensor_uint4, + (observer_node.args[0], 1.0, 0), + {}, + ) dq_node = model.graph.call_function( - torch.ops.qtensors.dequantize_per_tensor_uint4, (q_node, 1.0, 0), {}) + torch.ops.qtensors.dequantize_per_tensor_uint4, + (q_node, 1.0, 0), + {}, + ) observer_node.replace_all_uses_with(dq_node) model.graph.erase_node(observer_node) @@ -160,10 +191,12 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: if _is_annotated(partition): continue - linear_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=quantization_config.output_activation, - _annotated=True, + linear_node.meta["quantization_annotation"] = ( + QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) ) _mark_nodes_as_annotated(partition) @@ -197,7 +230,6 @@ def forward(self, x): # _test_quantizer in PT2EQuantizationTestCase # resetting dynamo 
cache - export_with_dynamic_shape = False torch._dynamo.reset() m_eager = M().eval() @@ -210,23 +242,22 @@ def forward(self, x): ).module() else: m = torch._export.capture_pre_autograd_graph( - m, - example_inputs, - ).module() + m, + example_inputs, + ).module() m = prepare_pt2e(m, quantizer) # Calibrate m(*example_inputs) m = convert_pt2e(m, fold_quantize=False) - pt2_quant_output = m(*example_inputs) + m(*example_inputs) - node_occurrence = { - ns.call_function(k): v for k, v in node_occurrence.items() - } + node_occurrence = {ns.call_function(k): v for k, v in node_occurrence.items()} node_list = [ns.call_function(n) for n in node_list] self.checkGraphModuleNodes( m, expected_node_occurrence=node_occurrence, expected_node_list=node_list ) + if __name__ == "__main__": unittest.main() diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index f4823c4d3b..da43253678 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -1,85 +1,109 @@ -from math import log -from copy import deepcopy import pytest - import torch from torchao.dtypes.uintx.uintx_layout import to_uintx from torchao.quantization.quant_api import quantize_, uintx_weight_only -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_5, -) - from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, choose_qparams_affine, - quantize_affine, dequantize_affine, + quantize_affine, +) +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_5, ) # torch.uintx dtypes are introduced in 2.3 if TORCH_VERSION_AT_LEAST_2_3: - dtypes = (torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7) + dtypes = ( + torch.uint1, + torch.uint2, + torch.uint3, + torch.uint4, + torch.uint5, + torch.uint6, + torch.uint7, + ) else: dtypes = () group_sizes = [32, 64, 128] devices = ["cpu", "cuda"] + + @pytest.fixture(autouse=True) def run_before_and_after_tests(): yield - torch._dynamo.reset() # reset cache between tests + torch._dynamo.reset() # reset cache between tests + class Linear16(torch.nn.Module): def __init__(self, scale, device): super().__init__() self.net = torch.nn.Sequential( - torch.nn.Linear(scale * 2, scale, bias=False, dtype=torch.float16, device=device), - torch.nn.Linear(scale, scale, bias=False, dtype=torch.float16, device=device), - torch.nn.Linear(scale, scale//2, bias=False, dtype=torch.float16, device=device), + torch.nn.Linear( + scale * 2, scale, bias=False, dtype=torch.float16, device=device + ), + torch.nn.Linear( + scale, scale, bias=False, dtype=torch.float16, device=device + ), + torch.nn.Linear( + scale, scale // 2, bias=False, dtype=torch.float16, device=device + ), ) def forward(self, x): return self.net(x) + @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("group_size", group_sizes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build") +@pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build" +) def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size): scale = 512 fp16_mod_on_cpu = Linear16(scale, "cpu") quantize_(fp16_mod_on_cpu, uintx_weight_only(dtype, group_size=group_size)) - test_input_on_cpu = torch.randn(scale*2, dtype=torch.float16, device="cpu") + test_input_on_cpu = torch.randn(scale * 2, dtype=torch.float16, device="cpu") output_on_cpu = 
fp16_mod_on_cpu(test_input_on_cpu) fp16_mod_on_cuda = fp16_mod_on_cpu.to("cuda") test_input_on_cuda = test_input_on_cpu.to("cuda") output_on_cuda = fp16_mod_on_cuda(test_input_on_cuda) - assert torch.allclose(output_on_cpu, output_on_cuda.cpu(), atol=1.0e-3), "The output of the model on CPU and CUDA should be close" + assert torch.allclose( + output_on_cpu, output_on_cuda.cpu(), atol=1.0e-3 + ), "The output of the model on CPU and CUDA should be close" + @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("group_size", group_sizes) @pytest.mark.parametrize("device", devices) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build") +@pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build" +) def test_uintx_weight_only_model_quant(dtype, group_size, device): scale = 512 fp16 = Linear16(scale, device) quantize_(fp16, uintx_weight_only(dtype, group_size=group_size)) uintx = torch.compile(fp16, fullgraph=True) - test_input = torch.randn(scale*2, dtype=torch.float16, device=device) + test_input = torch.randn(scale * 2, dtype=torch.float16, device=device) output = uintx.forward(test_input) - assert output != None, "model quantization failed" + assert output is not None, "model quantization failed" + @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("group_size", group_sizes) @pytest.mark.parametrize("device", devices) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build") +@pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build" +) def test_uintx_weight_only_quant(dtype, group_size, device): - input_float = torch.randn((1, 256), dtype=torch.float16, device = device) + input_float = torch.randn((1, 256), dtype=torch.float16, device=device) mapping_type = MappingType.SYMMETRIC eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int32 @@ -87,73 +111,91 @@ def test_uintx_weight_only_quant(dtype, group_size, device): block_size = (1, group_size) scale, zero_point = choose_qparams_affine( - input_float, mapping_type, block_size, - dtype, eps=eps, scale_dtype=torch.float32, - zero_point_dtype=zero_point_dtype, preserve_zero=True, zero_point_domain=zero_point_domain + input_float, + mapping_type, + block_size, + dtype, + eps=eps, + scale_dtype=torch.float32, + zero_point_dtype=zero_point_dtype, + preserve_zero=True, + zero_point_domain=zero_point_domain, ) aqt = quantize_affine( - input_float, block_size, scale, - zero_point, dtype, - zero_point_domain=zero_point_domain + input_float, + block_size, + scale, + zero_point, + dtype, + zero_point_domain=zero_point_domain, ) # Note: output will be uint8 tensor for sub byte tensors for now - q = to_uintx(aqt, dtype, -1) - assert q != None, "quantization failed" + q = to_uintx(aqt, dtype, -1) + assert q is not None, "quantization failed" deqaunt = dequantize_affine( - q, block_size, scale, - zero_point, dtype, - zero_point_domain=zero_point_domain + q, block_size, scale, zero_point, dtype, zero_point_domain=zero_point_domain ) - assert deqaunt != None, "deqauntization failed" + assert deqaunt is not None, "deqauntization failed" @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") -@pytest.mark.skipif(not 
TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+") +@pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+" +) def test_uintx_target_dtype(dtype): from torchao.quantization.quant_api import uintx_weight_only - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") # make sure it runs - uintx_weight_only(dtype)(l) - l(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")) + uintx_weight_only(dtype)(linear) + linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")) + @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") -@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="torch.compile without unwrap_tensor_subclass requires torch 2.5+") +@pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_5, + reason="torch.compile without unwrap_tensor_subclass requires torch 2.5+", +) def test_uintx_target_dtype_compile(dtype): from torchao.quantization.quant_api import uintx_weight_only - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") # make sure it runs - uintx_weight_only(dtype)(l) - l = torch.compile(l) - l(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")) + uintx_weight_only(dtype)(linear) + linear = torch.compile(linear) + linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") -@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+") +@pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+" +) def test_uintx_model_size(dtype): from torchao.quantization.quant_api import uintx_weight_only from torchao.utils import get_model_size_in_bytes + # scale size = 1/64 * 2 bytes = 1/32 bytes # zero_point size = 1/64 * 4 bytes = 1/16 bytes # dtype data size = 1 * bit_width/8 = bit_width/8 bytes _dtype_to_ratio = { - torch.uint1: (1/8 + 1/16 + 1/32) / 2, - torch.uint2: (2/8 + 1/16 + 1/32) / 2, - torch.uint3: (3/8 + 1/16 + 1/32) / 2, - torch.uint4: (4/8 + 1/16 + 1/32) / 2, - torch.uint5: (5/8 + 1/16 + 1/32) / 2, - torch.uint6: (6/8 + 1/16 + 1/32) / 2, - torch.uint7: (7/8 + 1/16 + 1/32) / 2, + torch.uint1: (1 / 8 + 1 / 16 + 1 / 32) / 2, + torch.uint2: (2 / 8 + 1 / 16 + 1 / 32) / 2, + torch.uint3: (3 / 8 + 1 / 16 + 1 / 32) / 2, + torch.uint4: (4 / 8 + 1 / 16 + 1 / 32) / 2, + torch.uint5: (5 / 8 + 1 / 16 + 1 / 32) / 2, + torch.uint6: (6 / 8 + 1 / 16 + 1 / 32) / 2, + torch.uint7: (7 / 8 + 1 / 16 + 1 / 32) / 2, } - l = torch.nn.Sequential( + linear = torch.nn.Sequential( torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda") ) - bf16_size = get_model_size_in_bytes(l) + bf16_size = get_model_size_in_bytes(linear) # make sure it runs - uintx_weight_only(dtype)(l[0]) - quantized_size = get_model_size_in_bytes(l) + uintx_weight_only(dtype)(linear[0]) + quantized_size = get_model_size_in_bytes(linear) assert bf16_size * _dtype_to_ratio[dtype] == quantized_size
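
Most of this diff is mechanical cleanup for the `test/dtypes/**/*.py` files newly included in `ruff.toml`: imports are regrouped and sorted, single-letter `l` locals are renamed to `linear`, long calls are reflowed, unused result bindings are dropped, and `!= None` checks become `is not None`. The underlying test pattern is unchanged. As a reference point only, here is a minimal sketch of that pattern — it is not part of the diff, assumes a CUDA device, and uses only the `quantize_` and `int8_weight_only` APIs these test files already import:

```python
import torch

from torchao.quantization import int8_weight_only, quantize_

# Toy single-layer model mirroring the 128 -> 256 bfloat16 CUDA linears used in these tests.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
)

# Swap the linear weight for an affine-quantized tensor subclass in place.
quantize_(model, int8_weight_only())
print(type(model[0].weight))  # weight is now an affine-quantized tensor subclass

# The quantized module still runs a normal forward pass.
example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
_ = model(example_input)
```

The sketch deliberately reuses the shapes, dtype, and device that appear throughout the tests above, so it exercises the same code path the reformatted tests cover.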