Skip to content

Move and rename GranularityType -> Granularity #1038

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 10, 2024
10 changes: 8 additions & 2 deletions test/dtypes/test_affine_quantized_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,17 @@
float8_weight_only,
quantize_,
)
from torchao.quantization.observer import PerRow, PerTensor
from torchao.quantization.granularity import (
PerRow,
PerTensor,
)
from torchao.quantization.quant_api import (
float8_static_activation_float8_weight,
)
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine
from torchao.quantization.quant_primitives import (
MappingType,
choose_qparams_affine,
)

random.seed(0)
torch.manual_seed(0)
Expand Down
22 changes: 12 additions & 10 deletions test/quantization/test_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import TestCase

from torchao.quantization.observer import (
AffineQuantizedMinMaxObserver,
from torchao.quantization.granularity import (
PerAxis,
PerTensor,
)
from torchao.quantization.observer import (
AffineQuantizedMinMaxObserver,
)
from torchao.quantization.quant_api import (
insert_observers_,
)
Expand Down Expand Up @@ -42,7 +44,7 @@ def test_min_max_per_tensor_affine(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.ASYMMETRIC,
torch.uint8,
granularity_type=PerTensor(),
granularity=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
Expand All @@ -54,7 +56,7 @@ def test_min_max_per_channel_affine(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.ASYMMETRIC,
torch.uint8,
granularity_type=PerAxis(axis=0),
granularity=PerAxis(axis=0),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
Expand All @@ -68,7 +70,7 @@ def test_block_size_calc_success(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerTensor(),
granularity=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
Expand All @@ -87,7 +89,7 @@ def test_block_size_calc_success(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerAxis(1),
granularity=PerAxis(1),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
Expand All @@ -102,7 +104,7 @@ def test_block_size_row_errors(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerAxis(0),
granularity=PerAxis(0),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
Expand All @@ -121,7 +123,7 @@ def test_block_size_row_errors(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerAxis(1),
granularity=PerAxis(1),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
Expand Down Expand Up @@ -149,7 +151,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
input_observer = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerTensor(),
granularity=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
Expand All @@ -159,7 +161,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
weight_observer = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerTensor(),
granularity=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
Expand Down
4 changes: 2 additions & 2 deletions torchao/_models/llama/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
)
from torchao.quantization.observer import PerRow, PerTensor
from torchao._models._eval import TransformerEvalWrapper, InputRecorder
from torchao._models.llama.model import prepare_inputs_for_model
from torchao.quantization.granularity import PerRow, PerTensor

from tokenizer import get_tokenizer
import time
Expand Down Expand Up @@ -255,4 +255,4 @@ def run_evaluation(
args.calibration_limit,
args.calibration_seq_length,
args.pad_calibration_inputs,
)
)
2 changes: 1 addition & 1 deletion torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def main(
float8_weight_only,
float8_dynamic_activation_float8_weight,
)
from torchao.quantization.observer import PerTensor, PerRow
from torchao.quantization.granularity import PerTensor, PerRow
if "int8wo" in quantization:
quantize_(model, int8_weight_only())
if "int8dq" in quantization:
Expand Down
2 changes: 1 addition & 1 deletion torchao/prototype/awq/api.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import torch
import torch.nn.functional as F

from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
_DTYPE_TO_QVALUE_BOUNDS,
)
from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
from torchao.quantization.observer import PerGroup
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType
from torchao.dtypes import(
Expand Down
9 changes: 5 additions & 4 deletions torchao/prototype/awq/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,21 @@
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType
from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization.granularity import Granularity
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)
from torchao.quantization.observer import (
AffineQuantizedObserverBase, GranularityType
AffineQuantizedObserverBase,
)


class AWQObserver(AffineQuantizedObserverBase):
def __init__(self,
weight: torch.Tensor,
bias: torch.Tensor,
quantization_granularity: GranularityType,
quantization_granularity: Granularity,
mapping_type: MappingType,
target_dtype: torch.dtype,
n_validation_examples: int,
Expand All @@ -40,7 +41,7 @@ def __init__(self,
Args:
weight: The weight tensor to be observed.
bias: The bias tensor to be observed.
quantization_granularity: Granularity type which specifies how many weights share the same scale/zero point
quantization_granularity: Granularity which specifies how many weights share the same scale/zero point
input_dtype: The data type of the input tensor.
mapping_type: Always set to asymmetric
target_dtype: The target data type of the quantized tensor
Expand Down Expand Up @@ -153,4 +154,4 @@ def from_float(cls, float_linear: torch.nn.Linear, act_obs: AWQObserver):
observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype)
observed_linear.weight = float_linear.weight
observed_linear.bias = float_linear.bias
return observed_linear
return observed_linear
2 changes: 1 addition & 1 deletion torchao/quantization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ change_linear_weights_to_int8_dqtensors(model)
```python
# for torch 2.4+
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerTensor
from torchao.quantization.quant_api import PerTensor
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
```

Expand Down
8 changes: 5 additions & 3 deletions torchao/quantization/autoquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType, Float8LayoutType
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
from torch.utils._python_dispatch import return_and_correct_aliasing
from .quant_primitives import (
safe_int_mm,
from .granularity import (
PerAxis,
PerRow,
PerTensor,
)
from .quant_primitives import safe_int_mm
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
from torchao.quantization.utils import quantize_activation_per_token_absmax
from torchao.quantization.observer import PerAxis, PerTensor, PerRow
from torchao.float8.inference import Float8MMConfig

import torch.nn.functional as F
Expand Down
76 changes: 76 additions & 0 deletions torchao/quantization/granularity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass


@dataclass(frozen=True)
class Granularity:
"""
Base class for representing the granularity of quantization.

This class serves as a parent for specific granularity types used in
quantization operations, such as per-tensor or per-axis quantization.
"""
pass

@dataclass(frozen=True)
class PerTensor(Granularity):
"""
Represents per-tensor granularity in quantization.

This granularity type calcualtes the quantization parameters
based off the entire tensor.
"""
pass

@dataclass(frozen=True)
class PerAxis(Granularity):
"""
Represents per-axis granularity in quantization.

This granularity type calcualtes different quantization parameters
along a specified axis of the tensor.

For example if the input tensor is shape [8, 16] and axis=0, then
the quantization parameters are calculated for each row of the tensor.
Giving a total of 8 quantization parameters.


Attributes:
axis (int): The axis along which reduction is performed.
"""
axis: int

@dataclass(frozen=True)

class PerGroup(Granularity):
"""
Represents per-channel group granularity in quantization.

This granularity type calcualtes different quantization parameters
for each group of <group_size> elements.

For example if the input tensor is shape [8, 16], and the group size is 4, then
the input tensor is reshaped to [64, 4]
quantization parameters are calculated for each group of 4 elements,
giving a total of 64 quantization parameters.

Attributes:
group_size (int): The size of each quantization group

"""
group_size: int

class PerRow(Granularity):
"""
Represents row-wise granularity in quantization.

This is a special case of per-axis quantization and is unique to Float8 matmuls
where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
is quantized with a block_size of (1, weight.shape[1]).
"""
pass
Loading
Loading