Skip to content

Lint test dtypes #1305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ include = [
"torchao/prototype/low_bit_optim/**.py",
"test/float8/**/*.py",
"test/quantization/test_observer.py",
"test/dtypes/test_affine_quantized_float.py",
"test/dtypes/test_nf4.py",
"test/dtypes/**/*.py",
"test/prototype/low_bit_optim/**.py",
"torchao/utils.py",

Expand Down
94 changes: 56 additions & 38 deletions test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
import tempfile
import unittest

import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)

from torchao.dtypes import SemiSparseLayout
from torchao.quantization import (
float8_weight_only,
int4_weight_only,
int8_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_dynamic_activation_int8_semi_sparse_weight,
float8_weight_only,
int8_weight_only,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.dtypes import SemiSparseLayout
from torch.testing._internal import common_utils
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

import torch
import unittest
import tempfile

is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


Expand All @@ -33,7 +33,9 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
base_functions.append(int4_weight_only(group_size=32))

if do_sparse:
base_functions.append(int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
base_functions.append(
int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
)

if is_cuda_8_9:
base_functions.append(float8_weight_only())
Expand All @@ -44,11 +46,11 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
class TestAffineQuantized(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_tensor_core_layout_transpose(self):
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
t = l.weight
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
t = linear.weight
shape = t.shape
apply_int4_weight_only_quant = int4_weight_only(group_size=32)
ql = apply_int4_weight_only_quant(l)
ql = apply_int4_weight_only_quant(linear)
aqt = ql.weight
aqt_shape = aqt.shape
self.assertEqual(aqt_shape, shape)
Expand All @@ -64,8 +66,8 @@ def test_tensor_core_layout_transpose(self):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
def test_weights_only(self, apply_quant):
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(l)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(linear)
with tempfile.NamedTemporaryFile() as f:
torch.save(ql.state_dict(), f)
f.seek(0)
Expand All @@ -78,33 +80,32 @@ def test_weights_only(self, apply_quant):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
def test_to_device(self, apply_quant):
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(linear)
ql.to("cuda")

l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(linear)
ql.to(device="cuda")

l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(linear)
ql.cuda()

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_register_new_dispatch(self):
from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx
from torchao.dtypes.affine_quantized_tensor_ops import (
register_aqt_quantized_linear_dispatch,
deregister_aqt_quantized_linear_dispatch,
register_aqt_quantized_linear_dispatch,
)
from torchao.dtypes import to_affine_quantized_intx
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType

def dispatch_condition(input_tensor, weight_tensor, bias):
return (
isinstance(weight_tensor, AffineQuantizedTensor) and
weight_tensor.quant_min == 0 and
weight_tensor.quant_max == 2**6-1
isinstance(weight_tensor, AffineQuantizedTensor)
and weight_tensor.quant_min == 0
and weight_tensor.quant_max == 2**6 - 1
)

def impl(input_tensor, weight_tensor, bias):
Expand All @@ -115,23 +116,35 @@ def impl(input_tensor, weight_tensor, bias):
register_aqt_quantized_linear_dispatch(dispatch_condition, impl)

def apply_uint6_weight_only_quant(linear):
linear.weight = torch.nn.Parameter(to_affine_quantized_intx(linear.weight, MappingType.ASYMMETRIC, (1, linear.weight.shape[-1]), torch.uint8, 0, 2**6-1), requires_grad=False)
linear.weight = torch.nn.Parameter(
to_affine_quantized_intx(
linear.weight,
MappingType.ASYMMETRIC,
(1, linear.weight.shape[-1]),
torch.uint8,
0,
2**6 - 1,
),
requires_grad=False,
)
return linear

l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
apply_uint6_weight_only_quant(l)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
apply_uint6_weight_only_quant(linear)

example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
with self.assertRaisesRegex(AssertionError, "dispatching to my impl for uint6 weight only quant"):
l(example_input)
with self.assertRaisesRegex(
AssertionError, "dispatching to my impl for uint6 weight only quant"
):
linear(example_input)

deregister_aqt_quantized_linear_dispatch(dispatch_condition)

@common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_print_quantized_module(self, apply_quant):
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(l)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(linear)
assert "AffineQuantizedTensor" in str(ql)


Expand All @@ -143,20 +156,25 @@ class TestAffineQuantizedBasic(TestCase):
@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_flatten_unflatten(self, apply_quant, device, dtype):
l = torch.nn.Linear(128, 256, dtype=dtype, device=device)
ql = apply_quant(l)
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
ql = apply_quant(linear)
lp_tensor = ql.weight
tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict}
tensor_data_dict = {
name: getattr(lp_tensor, name) for name in tensor_data_name_dict
}
outer_size = lp_tensor.size()
outer_stride = lp_tensor.stride()
reconstructed = type(lp_tensor).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
reconstructed = type(lp_tensor).__tensor_unflatten__(
tensor_data_dict, tensor_attributes, outer_size, outer_stride
)
example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
ref = ql(*example_inputs)
ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
reconstruct_res = ql(*example_inputs)
self.assertEqual(reconstruct_res, ref)


common_utils.instantiate_parametrized_tests(TestAffineQuantized)
common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)

Expand Down
70 changes: 38 additions & 32 deletions test/dtypes/test_affine_quantized_tensor_parallel.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
import torch
import unittest
from torch.testing._internal.common_utils import run_tests

import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)

from torchao.quantization import (
float8_dynamic_activation_float8_weight,
float8_weight_only,
int4_weight_only,
int8_weight_only,
float8_weight_only,
float8_dynamic_activation_float8_weight,
)
from torchao.quantization.observer import PerRow, PerTensor
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
NUM_DEVICES,
)
from torchao.quantization.quant_api import quantize_
from torchao.dtypes import AffineQuantizedTensor
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6


class TestAffineQuantizedTensorParallel(DTensorTestBase):
"""Basic test case for tensor subclasses
"""
"""Basic test case for tensor subclasses"""

QUANT_METHOD_FN = staticmethod(int8_weight_only)
QUANT_METHOD_KWARGS = {}

Expand All @@ -40,9 +40,7 @@ def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
# Construct DTensor from local shard
dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
# Replace parameter in module
m.linear.weight = torch.nn.Parameter(
dtensor, requires_grad=False
)
m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False)
return m

@staticmethod
Expand All @@ -59,9 +57,7 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
# Construct DTensor from local shard
dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True)
# Replace parameter in module
m.linear.weight = torch.nn.Parameter(
dtensor, requires_grad=False
)
m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False)
return m

def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
Expand All @@ -79,7 +75,9 @@ def _test_tp(self, dtype):
class M(torch.nn.Module):
def __init__(self, in_features, out_features, **kwargs) -> None:
super().__init__(**kwargs)
self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")
self.linear = torch.nn.Linear(
in_features, out_features, bias=False, device="cuda"
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)
Expand All @@ -91,11 +89,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
proj_up = M(1024, 2048).to(device).to(dtype)
proj_dn = M(2048, 1024).to(device).to(dtype)
example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
y = proj_dn(proj_up(example_input))
proj_dn(proj_up(example_input))
# Quantize the model
up_quant = self.quantize(proj_up)
dn_quant = self.quantize(proj_dn)
y_q = dn_quant(up_quant(example_input))
dn_quant(up_quant(example_input))

mesh = self.build_device_mesh()
mesh.device_type = "cuda"
Expand All @@ -105,11 +103,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
dn_dist = self.rowwise_shard(dn_quant, mesh)

# We need to turn inputs into DTensor form as well -- just a format change
input_dtensor = DTensor.from_local(
example_input, mesh, [Replicate()]
)
input_dtensor = DTensor.from_local(example_input, mesh, [Replicate()])

y_d = dn_dist(up_dist(input_dtensor))
dn_dist(up_dist(input_dtensor))

if not TORCH_VERSION_AT_LEAST_2_6:
# Need torch 2.6 to support compiled tensor parallelism
Expand All @@ -118,7 +114,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
up_compiled = torch.compile(up_dist)
y_up = up_compiled(input_dtensor)
dn_compiled = torch.compile(dn_dist)
y_dn = dn_compiled(y_up)
dn_compiled(y_up)


class TestInt8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
Expand All @@ -142,11 +138,13 @@ class TestInt4woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel)
def test_tp(self, dtype):
return self._test_tp(dtype)


common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel)
common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel)

# Run only on H100
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):

class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
QUANT_METHOD_FN = staticmethod(float8_weight_only)
COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]
Expand All @@ -157,7 +155,9 @@ class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParalle
def test_tp(self, dtype):
return self._test_tp(dtype)

class TestFloat8dqTensorAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
class TestFloat8dqTensorAffineQuantizedTensorParallel(
TestAffineQuantizedTensorParallel
):
QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
QUANT_METHOD_KWARGS = {"granularity": PerTensor()}
COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]
Expand All @@ -168,7 +168,9 @@ class TestFloat8dqTensorAffineQuantizedTensorParallel(TestAffineQuantizedTensorP
def test_tp(self, dtype):
return self._test_tp(dtype)

class TestFloat8dqRowAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
class TestFloat8dqRowAffineQuantizedTensorParallel(
TestAffineQuantizedTensorParallel
):
QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
QUANT_METHOD_KWARGS = {"granularity": PerRow()}
COMMON_DTYPES = [torch.bfloat16]
Expand All @@ -179,7 +181,11 @@ class TestFloat8dqRowAffineQuantizedTensorParallel(TestAffineQuantizedTensorPara
def test_tp(self, dtype):
return self._test_tp(dtype)

common_utils.instantiate_parametrized_tests(TestFloat8dqTensorAffineQuantizedTensorParallel)
common_utils.instantiate_parametrized_tests(TestFloat8dqRowAffineQuantizedTensorParallel)
common_utils.instantiate_parametrized_tests(
TestFloat8dqTensorAffineQuantizedTensorParallel
)
common_utils.instantiate_parametrized_tests(
TestFloat8dqRowAffineQuantizedTensorParallel
)
if __name__ == "__main__":
run_tests()
Loading
Loading