Commit 15c17e9

Move and rename GranularityType -> Granularity
Summary: Move GranularityType to its own file for more flexible use outside of observers.

Test Plan: CI

ghstack-source-id: 4da5e4c
Pull Request resolved: #1038
1 parent 9e0a59f commit 15c17e9

15 files changed: +143 -111 lines

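For downstream code the change is mechanical: the granularity classes move out of the observer module into a dedicated granularity module, and the GranularityType base class becomes Granularity. A minimal before/after sketch, distilled from the diffs below:

```python
# Before this commit (old import path):
#   from torchao.quantization.observer import PerRow, PerTensor

# After this commit (new dedicated module):
from torchao.quantization.granularity import PerRow, PerTensor
```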

test/dtypes/test_affine_quantized_float.py

Lines changed: 8 additions & 2 deletions

@@ -26,11 +26,17 @@
     float8_weight_only,
     quantize_,
 )
-from torchao.quantization.observer import PerRow, PerTensor
+from torchao.quantization.granularity import (
+    PerRow,
+    PerTensor,
+)
 from torchao.quantization.quant_api import (
     float8_static_activation_float8_weight,
 )
-from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    choose_qparams_affine,
+)
 
 random.seed(0)
 torch.manual_seed(0)

test/quantization/test_observer.py

Lines changed: 12 additions & 10 deletions

@@ -9,11 +9,13 @@
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import TestCase
 
-from torchao.quantization.observer import (
-    AffineQuantizedMinMaxObserver,
+from torchao.quantization.granularity import (
     PerAxis,
     PerTensor,
 )
+from torchao.quantization.observer import (
+    AffineQuantizedMinMaxObserver,
+)
 from torchao.quantization.quant_api import (
     insert_observers_,
 )
@@ -42,7 +44,7 @@ def test_min_max_per_tensor_affine(self):
         obs = AffineQuantizedMinMaxObserver(
             MappingType.ASYMMETRIC,
             torch.uint8,
-            granularity_type=PerTensor(),
+            granularity=PerTensor(),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
@@ -54,7 +56,7 @@ def test_min_max_per_channel_affine(self):
         obs = AffineQuantizedMinMaxObserver(
             MappingType.ASYMMETRIC,
             torch.uint8,
-            granularity_type=PerAxis(axis=0),
+            granularity=PerAxis(axis=0),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
@@ -68,7 +70,7 @@ def test_block_size_calc_success(self):
         obs = AffineQuantizedMinMaxObserver(
             MappingType.SYMMETRIC,
             torch.float8_e4m3fn,
-            granularity_type=PerTensor(),
+            granularity=PerTensor(),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
@@ -87,7 +89,7 @@ def test_block_size_calc_success(self):
         obs = AffineQuantizedMinMaxObserver(
             MappingType.SYMMETRIC,
             torch.float8_e4m3fn,
-            granularity_type=PerAxis(1),
+            granularity=PerAxis(1),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
@@ -102,7 +104,7 @@ def test_block_size_row_errors(self):
         obs = AffineQuantizedMinMaxObserver(
             MappingType.SYMMETRIC,
             torch.float8_e4m3fn,
-            granularity_type=PerAxis(0),
+            granularity=PerAxis(0),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
@@ -121,7 +123,7 @@ def test_block_size_row_errors(self):
         obs = AffineQuantizedMinMaxObserver(
             MappingType.SYMMETRIC,
             torch.float8_e4m3fn,
-            granularity_type=PerAxis(1),
+            granularity=PerAxis(1),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
@@ -149,7 +151,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
         input_observer = AffineQuantizedMinMaxObserver(
             MappingType.SYMMETRIC,
             torch.float8_e4m3fn,
-            granularity_type=PerTensor(),
+            granularity=PerTensor(),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
@@ -159,7 +161,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
         weight_observer = AffineQuantizedMinMaxObserver(
             MappingType.SYMMETRIC,
             torch.float8_e4m3fn,
-            granularity_type=PerTensor(),
+            granularity=PerTensor(),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
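The only API change in these tests is the keyword rename granularity_type -> granularity on AffineQuantizedMinMaxObserver. A runnable sketch of the updated call, assembled from the imports and arguments shown in the diffs above:

```python
import torch

from torchao.quantization.granularity import PerTensor
from torchao.quantization.observer import AffineQuantizedMinMaxObserver
from torchao.quantization.quant_primitives import MappingType

# `granularity=` replaces the old `granularity_type=` keyword
obs = AffineQuantizedMinMaxObserver(
    MappingType.ASYMMETRIC,
    torch.uint8,
    granularity=PerTensor(),  # one scale/zero_point pair for the whole tensor
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float,
    zero_point_dtype=torch.int,
)
```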

torchao/_models/llama/eval.py

Lines changed: 2 additions & 2 deletions

@@ -24,9 +24,9 @@
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
 )
-from torchao.quantization.observer import PerRow, PerTensor
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder
 from torchao._models.llama.model import prepare_inputs_for_model
+from torchao.quantization.granularity import PerRow, PerTensor
 
 from tokenizer import get_tokenizer
 import time
@@ -255,4 +255,4 @@ def run_evaluation(
         args.calibration_limit,
         args.calibration_seq_length,
         args.pad_calibration_inputs,
-    )
\ No newline at end of file
+    )

torchao/_models/llama/generate.py

Lines changed: 1 addition & 1 deletion

@@ -216,7 +216,7 @@ def main(
        float8_weight_only,
        float8_dynamic_activation_float8_weight,
    )
-   from torchao.quantization.observer import PerTensor, PerRow
+   from torchao.quantization.granularity import PerTensor, PerRow
    if "int8wo" in quantization:
        quantize_(model, int8_weight_only())
    if "int8dq" in quantization:

torchao/prototype/awq/api.py

Lines changed: 1 addition & 1 deletion

@@ -3,11 +3,11 @@
 
 from torchao.quantization.quant_primitives import (
     MappingType,
+    PerGroup,
     ZeroPointDomain,
     _DTYPE_TO_QVALUE_BOUNDS,
 )
 from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
-from torchao.quantization.observer import PerGroup
 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType
 from torchao.dtypes import(

torchao/prototype/awq/core.py

Lines changed: 5 additions & 4 deletions

@@ -7,20 +7,21 @@
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType
 from torchao.dtypes import to_affine_quantized_intx
+from torchao.quantization.granularity import Granularity
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
 )
 from torchao.quantization.observer import (
-    AffineQuantizedObserverBase, GranularityType
+    AffineQuantizedObserverBase,
 )
 
 
 class AWQObserver(AffineQuantizedObserverBase):
     def __init__(self,
         weight: torch.Tensor,
         bias: torch.Tensor,
-        quantization_granularity: GranularityType,
+        quantization_granularity: Granularity,
         mapping_type: MappingType,
         target_dtype: torch.dtype,
         n_validation_examples: int,
@@ -40,7 +41,7 @@ def __init__(self,
         Args:
             weight: The weight tensor to be observed.
             bias: The bias tensor to be observed.
-            quantization_granularity: Granularity type which specifies how many weights share the same scale/zero point
+            quantization_granularity: Granularity which specifies how many weights share the same scale/zero point
             input_dtype: The data type of the input tensor.
             mapping_type: Always set to asymmetric
             target_dtype: The target data type of the quantized tensor
@@ -153,4 +154,4 @@ def from_float(cls, float_linear: torch.nn.Linear, act_obs: AWQObserver):
         observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype)
         observed_linear.weight = float_linear.weight
         observed_linear.bias = float_linear.bias
-        return observed_linear
\ No newline at end of file
+        return observed_linear

torchao/quantization/README.md

Lines changed: 1 addition & 1 deletion

@@ -137,7 +137,7 @@ change_linear_weights_to_int8_dqtensors(model)
 ```python
 # for torch 2.4+
 from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
-from torchao.quantization.observer import PerTensor
+from torchao.quantization.quant_api import PerTensor
 quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
 ```
 

torchao/quantization/autoquant.py

Lines changed: 5 additions & 3 deletions

@@ -12,12 +12,14 @@
 from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType, Float8LayoutType
 from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
-from .quant_primitives import (
-    safe_int_mm,
+from .granularity import (
+    PerAxis,
+    PerRow,
+    PerTensor,
 )
+from .quant_primitives import safe_int_mm
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
 from torchao.quantization.utils import quantize_activation_per_token_absmax
-from torchao.quantization.observer import PerAxis, PerTensor, PerRow
 from torchao.float8.inference import Float8MMConfig
 
 import torch.nn.functional as F

torchao/quantization/granularity.py

Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class Granularity:
+    """
+    Base class for representing the granularity of quantization.
+
+    This class serves as a parent for specific granularity types used in
+    quantization operations, such as per-tensor or per-axis quantization.
+    """
+    pass
+
+@dataclass(frozen=True)
+class PerTensor(Granularity):
+    """
+    Represents per-tensor granularity in quantization.
+
+    This granularity type calculates the quantization parameters
+    based on the entire tensor.
+    """
+    pass
+
+@dataclass(frozen=True)
+class PerAxis(Granularity):
+    """
+    Represents per-axis granularity in quantization.
+
+    This granularity type calculates different quantization parameters
+    along a specified axis of the tensor.
+
+    For example, if the input tensor has shape [8, 16] and axis=0, then
+    the quantization parameters are calculated for each row of the tensor,
+    giving a total of 8 quantization parameters.
+
+    Attributes:
+        axis (int): The axis along which reduction is performed.
+    """
+    axis: int
+
+@dataclass(frozen=True)
+class PerGroup(Granularity):
+    """
+    Represents per-channel group granularity in quantization.
+
+    This granularity type calculates different quantization parameters
+    for each group of <group_size> elements.
+
+    For example, if the input tensor has shape [8, 16] and the group size
+    is 4, then the tensor is reshaped to [32, 4] and quantization parameters
+    are calculated for each group of 4 elements, giving a total of 32
+    quantization parameters.
+
+    Attributes:
+        group_size (int): The size of each quantization group.
+    """
+    group_size: int
+
+class PerRow(Granularity):
+    """
+    Represents row-wise granularity in quantization.
+
+    This is a special case of per-axis quantization and is unique to Float8
+    matmuls, where the input is quantized with a block_size of
+    (1, ..., input.shape[-1]) and the weight is quantized with a block_size
+    of (1, weight.shape[1]).
+    """
+    pass
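Since the new classes are (mostly frozen) dataclasses with no behavior of their own, they act as plain value objects that observers and quantization APIs can dispatch on. A quick sketch of what each granularity expresses, assuming an [8, 16] weight tensor:

```python
from torchao.quantization.granularity import PerAxis, PerGroup, PerRow, PerTensor

PerTensor()              # 1 set of qparams for all 128 elements
PerAxis(axis=0)          # 8 sets of qparams, one per slice along axis 0
PerGroup(group_size=4)   # 32 sets of qparams, one per group of 4 elements
PerRow()                 # float8-specific row-wise variant of per-axis

# frozen=True makes instances immutable and hashable, so they compare by value
assert PerAxis(axis=0) == PerAxis(axis=0)
```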

0 commit comments

Comments
 (0)