Commit ffb4350

Initial PARQ addition and testing (#1738)
* Initial PARQ addition and testing
* Fix errors due to torch version differences
* Revert torchao/float8/config.py
* Fix custom decorator
* Reformat parq.py
* Undo third_party/cutlass change
1 parent 03b83ec commit ffb4350

13 files changed: +880 -0 lines changed

test/prototype/test_parq.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
import unittest

import torch

from torchao.prototype.parq.optim import (
    ProxHardQuant,
    ProxPARQ,
    QuantOptimizer,
)
from torchao.prototype.parq.quant import LSBQuantizer, UnifQuantizer

_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def split_param_groups(model):
    params_no_quant, params_quant = [], []
    for p in model.parameters():
        if p.dim() > 1:
            params_quant.append(p)
        else:
            params_no_quant.append(p)
    return params_no_quant, params_quant


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(10, 256)
        self.linear1 = torch.nn.Linear(256, 128)
        self.linear2 = torch.nn.Linear(128, 16)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def reset_parameters(self):
        for module in (self.linear1, self.linear2):
            torch.nn.init.xavier_uniform_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def example_inputs(self):
        return torch.randint(1, 10, (1, 256))

    def forward(self, x):
        x = self.embedding(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.sigmoid(x)
        return x


class TestPARQuantization(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(123)
        self.model = M().to(_DEVICE)
        self.params_no_quant, self.params_quant = split_param_groups(self.model)

    def test_2bit_unif_quantizer_hard_prox(self):
        self.model.reset_parameters()
        param_groups = [
            {"params": self.params_no_quant},
            {"params": self.params_quant, "quant_bits": 2},
        ]
        base_optimizer = torch.optim.AdamW(param_groups)
        quantizer = UnifQuantizer()
        prox_map = ProxHardQuant()
        optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)

        x = self.model.example_inputs().to(_DEVICE)
        out = self.model(x)
        out.sum().backward()
        optimizer.step()

        for child in self.model.children():
            if isinstance(child, torch.nn.Linear):
                self.assertEqual(child.weight.unique().numel(), 4)

    def test_ternarybit_lsbq_parq_prox(self):
        self.model.reset_parameters()
        param_groups = [
            {"params": self.params_no_quant},
            {"params": self.params_quant, "quant_bits": 0},
        ]
        base_optimizer = torch.optim.AdamW(param_groups)
        quantizer = LSBQuantizer()
        prox_map = ProxPARQ(anneal_start=0, anneal_end=2)
        optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)

        for _ in range(3):
            x = self.model.example_inputs().to(_DEVICE)
            out = self.model(x)
            out.sum().backward()
            optimizer.step()

        for child in self.model.children():
            if isinstance(child, torch.nn.Linear):
                self.assertEqual(child.weight.unique().numel(), 3)


if __name__ == "__main__":
    unittest.main()

torchao/prototype/parq/README.md

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
# PARQ: Piecewise-Affine Regularized Quantization

PARQ is a QAT method based on a convex regularization framework. It converges to hard quantization (i.e., STE) in its asymptotic limit.

This library applies QAT without modifying model-level code. Instead, it interfaces only with the optimizer, letting the user choose which parameters should be quantized via parameter groups. It separates QAT into the following components.

* quantization method: computing the best set of discrete, quantized values
* proximal mapping: projection of weights onto quantized values

## QAT arguments

| argument | description | choices |
| --- | --- | --- |
| `quant-bits` | bit-width for quantized weights | 0 (ternary), 1-4 |
| `quant-method` | method for determining quantized values | `lsbq`, `uniform` |
| `quant-proxmap` | proximal mapping to project weights onto quantized values | `hard`, `parq`, `binaryrelax` |
| `anneal-start` | start epoch for QAT annealing period | (0, `total_steps` - 1) |
| `anneal-end` | end epoch for QAT annealing period | (`anneal_start`, `total_steps`) |
| `anneal-steepness` | sigmoid steepness for PARQ inverse slope schedule | 25-100 |
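
These choices map onto the classes added in this commit. Below is a minimal, illustrative sketch of resolving such a configuration into objects; the `build_qat_objects` helper and its argument names are hypothetical and not part of the library, and `quant_bits` is instead passed per parameter group, as shown in the interface example below.

```python
# Hypothetical resolver: map README argument strings to the classes in this commit.
from torchao.prototype.parq.optim import ProxBinaryRelax, ProxHardQuant, ProxPARQ
from torchao.prototype.parq.quant import LSBQuantizer, UnifQuantizer


def build_qat_objects(quant_method="lsbq", quant_proxmap="parq",
                      anneal_start=0, anneal_end=100, anneal_steepness=25):
    quantizer = {"lsbq": LSBQuantizer, "uniform": UnifQuantizer}[quant_method]()
    if quant_proxmap == "hard":
        prox_map = ProxHardQuant()
    elif quant_proxmap == "parq":
        prox_map = ProxPARQ(anneal_start, anneal_end, steepness=anneal_steepness)
    else:  # "binaryrelax"
        prox_map = ProxBinaryRelax(anneal_start, anneal_end)
    return quantizer, prox_map
```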

## Optimizer-only interface

The `QuantOptimizer` wrapper takes any `torch.optim.Optimizer` object. It is also initialized with a `Quantizer` and a `ProxMap` object. Integration into new training pipelines is simple:

```python
from parq.optim import ProxPARQ, QuantOptimizer
from parq.quant import LSBQuantizer


# split params into quantizable and non-quantizable params
params_quant, params_no_wd, params_wd = split_param_groups(model)  # user-defined
param_groups = [
    {"params": params_quant, "quant_bits": 2},
    {"params": params_no_wd, "weight_decay": 0},
    {"params": params_wd},
]

# create PyTorch optimizer
base_optimizer = torch.optim.SGD(  # user-defined
    param_groups, lr=0.1, momentum=0.9, weight_decay=1e-4
)

# create quantizer and proximal map objects
quantizer = LSBQuantizer()
prox_map = ProxPARQ(anneal_start=..., anneal_end=..., steepness=...)

optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)
```
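
Once constructed, `optimizer` is driven like any other PyTorch optimizer; the unit test above follows the same pattern. A minimal sketch of a training loop, where `model`, `dataloader`, and `loss_fn` are user-defined placeholders:

```python
# Sketch only: standard loop; quantization is applied inside optimizer.step().
for inputs, targets in dataloader:           # user-defined dataloader
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)   # user-defined model and loss
    loss.backward()
    optimizer.step()                         # prox map projects quantized param groups
```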

torchao/prototype/parq/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
from .optim import (  # noqa: F401
    ProxBinaryRelax,
    ProxHardQuant,
    ProxMap,
    ProxPARQ,
    QuantOptimizer,
)
from .quant import LSBQuantizer, Quantizer, UnifQuantizer  # noqa: F401

torchao/prototype/parq/optim/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
from .binarelax import ProxBinaryRelax  # noqa: F401
from .parq import ProxPARQ  # noqa: F401
from .proxmap import ProxHardQuant, ProxMap  # noqa: F401
from .quantopt import QuantOptimizer  # noqa: F401

torchao/prototype/parq/optim/binarelax.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
from typing import Optional

import torch
from torch import Tensor

from ..utils import channel_bucketize
from .proxmap import ProxMap


class ProxBinaryRelax(ProxMap):
    """Prox-map of Binary Relax, Q may not be evenly spaced."""

    def __init__(self, anneal_start: int, anneal_end: int) -> None:
        self.anneal_start = anneal_start
        self.anneal_end = anneal_end

    @torch.no_grad()
    def apply_(
        self,
        p: Tensor,
        q: Tensor,
        Q: Tensor,
        step_count: int,
        dim: Optional[int] = None,
    ) -> None:
        if step_count < self.anneal_start:
            return

        if q is None:
            # hard quantization to the nearest point in Q
            Q_mid = (Q[..., :-1] + Q[..., 1:]) / 2
            if dim is None:
                q = Q[torch.bucketize(p, Q_mid)]
            else:
                q = Q.gather(1, channel_bucketize(p, Q_mid))

        if step_count >= self.anneal_end:
            p.copy_(q)
            return
        else:
            # linear annealing of relaxation coefficient
            theta = (step_count - self.anneal_start) / (
                self.anneal_end - self.anneal_start
            )
            p.mul_(1 - theta).add_(q, alpha=theta)
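
A small, self-contained illustration of the prox map above with arbitrary toy values: when `q=None`, weights are first snapped to the nearest point of `Q`, and halfway through annealing the in-place update is an even blend of the original and quantized values.

```python
import torch

from torchao.prototype.parq.optim import ProxBinaryRelax

p = torch.tensor([0.8, -0.3, 0.1])     # toy "weights"
Q = torch.tensor([-1.0, 0.0, 1.0])     # ternary target values, sorted ascending
prox = ProxBinaryRelax(anneal_start=0, anneal_end=10)
prox.apply_(p, None, Q, step_count=5)  # theta = 0.5 halfway through annealing
print(p)                               # tensor([0.9000, -0.1500, 0.0500]) = 0.5*p + 0.5*q
```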

torchao/prototype/parq/optim/parq.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
import math
from functools import partial
from typing import Optional

import torch
from torch import Tensor

from ..utils import channel_bucketize
from .proxmap import ProxMap


def amp_custom_fwd(cast_inputs: Optional[torch.types._dtype] = None):
    try:
        return partial(
            torch.amp.custom_fwd, device_type="cuda", cast_inputs=cast_inputs
        )
    except AttributeError:
        return partial(torch.cuda.amp.custom_fwd, cast_inputs=cast_inputs)


def normalized_mirror_sigmoid(t: float, t1: float, t2: float, s: float) -> float:
    """Sigmoid-like function decreasing from 1 to 0 over interval [t1, t2).
    s is steepness of the sigmoid-like function, almost linear for s < 1.
    'mirror' means decreasing instead of increasing as true sigmoid,
    'normalized' means value 1 at starting point t1 and 0 at end point t2."""
    assert t >= t1 and t < t2, "Normalized sigmoid: ensure t1 <= t < t2"
    ft = (t - t1) / (t2 - t1)  # fraction of progress from t1 to t2
    st = 1 / (1 + math.exp(s * (ft - 0.5)))  # scaled and shifted mirror sigmoid
    s1 = 1 / (1 + math.exp(-0.5 * s))  # st value when t = t1 -> ft = 0
    s2 = 1 / (1 + math.exp(0.5 * s))  # st value when t = t2 -> ft = 1
    return (st - s2) / (s1 - s2)  # shift and scale to range (0, 1]


class ProxPARQ(ProxMap):
    def __init__(
        self, anneal_start: int, anneal_end: int, steepness: float = 10
    ) -> None:
        assert anneal_start < anneal_end, "PARQ annealing: start before end."
        assert steepness > 0, "PARQ annealing steepness should be positive."
        self.anneal_start = anneal_start
        self.anneal_end = anneal_end
        self.steepness = steepness

    @torch.no_grad()
    @amp_custom_fwd(cast_inputs=torch.float32)
    def apply_(
        self,
        p: Tensor,
        q: Tensor,
        Q: Tensor,
        step_count: int,
        dim: Optional[int] = None,
    ) -> float:
        """Prox-map of PARQ with gradual annealing to hard quantization."""

        if step_count < self.anneal_start:
            inv_slope = 1.0
        elif step_count >= self.anneal_end:
            inv_slope = 0.0
            if q is None:
                # hard quantization to the nearest point in Q
                Q_mid = (Q[..., :-1] + Q[..., 1:]) / 2
                if dim is None:
                    q = Q[torch.bucketize(p, Q_mid)]
                else:
                    q = Q.gather(1, channel_bucketize(p, Q_mid))
            p.copy_(q)
        else:
            inv_slope = normalized_mirror_sigmoid(
                step_count, self.anneal_start, self.anneal_end, self.steepness
            )
            # it is important to clamp idx-1 first, then clamp idx itself
            # idx_lower[k] == idx_upper[k] iff p[k] > Q.max() or p[k] <= Q.min()
            if dim is None:
                idx = torch.bucketize(p, Q)  # locate quant interval
                idx_lower = (idx - 1).clamp_(min=0)  # index of lower bound
                idx_upper = idx.clamp(max=Q.numel() - 1)  # index of upper bound
                q_lower = Q[idx_lower]  # lower boundary of interval
                q_upper = Q[idx_upper]  # upper boundary of interval
                center = (q_lower + q_upper) / 2  # center of interval
                # concise implementation of piecewise-affine prox map
                q = (center + (p - center) / inv_slope).clamp_(min=q_lower, max=q_upper)
            else:
                idx = channel_bucketize(p, Q)
                idx_lower = (idx - 1).clamp_(min=0)
                idx_upper = idx.clamp(max=Q.size(1) - 1)
                q_lower = Q.gather(1, idx_lower)
                q_upper = Q.gather(1, idx_upper)
                center = (q_lower + q_upper) / 2
                q = torch.minimum(
                    torch.maximum(center + (p - center) / inv_slope, q_lower), q_upper
                )
            # in-place update of model parameters
            p.copy_(q)
        return inv_slope
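
To make the annealing schedule concrete, here is a short sketch that evaluates `normalized_mirror_sigmoid` over an assumed 100-step window with a few arbitrary steepness values.

```python
from torchao.prototype.parq.optim.parq import normalized_mirror_sigmoid

t1, t2 = 0, 100
for s in (1, 10, 25):
    inv_slopes = [
        round(normalized_mirror_sigmoid(t, t1, t2, s), 3) for t in range(t1, t2, 20)
    ]
    print(f"steepness={s}: {inv_slopes}")
# inv_slope starts at 1.0 and decays toward 0; larger steepness gives a sharper
# transition around the midpoint of [t1, t2), i.e. faster hardening of the prox map.
```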

torchao/prototype/parq/optim/proxmap.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import Tensor

from ..utils import channel_bucketize


# Create an abstract class to provide proximal-mapping interface
class ProxMap(ABC):
    @abstractmethod
    def apply_(self, p: Tensor, q: Tensor, Q: Tensor, step_count: int) -> None:
        """Provide interface for proximal mapping (modify p in-place):
            prox_map.apply_(p, q, Q, step_count)
        Inputs:
            p (Tensor): tensor to be quantized
            q (Tensor): None or hard quantized tensor of same size as p
            Q (Tensor): set of target quantization values
            step_count: trigger iteration-dependent mapping if needed
        """


class ProxHardQuant(ProxMap):
    """Prox-map of hard quantization, Q may not be evenly spaced."""

    @torch.no_grad()
    def apply_(
        self,
        p: Tensor,
        q: Tensor,
        Q: Tensor,
        step_count: int,
        dim: Optional[int] = None,
    ) -> None:
        if q is None:
            # quantize to the nearest point in Q
            Q_mid = (Q[..., :-1] + Q[..., 1:]) / 2
            if dim is None:
                q = Q[torch.bucketize(p, Q_mid)]
            else:
                q = Q.gather(1, channel_bucketize(p, Q_mid))
        p.copy_(q)
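
In contrast to the annealed maps above, `ProxHardQuant` snaps weights to the nearest target value immediately. A toy per-tensor example with arbitrary values:

```python
import torch

from torchao.prototype.parq.optim import ProxHardQuant

p = torch.tensor([0.8, -0.3, 0.1])
Q = torch.tensor([-1.0, -0.25, 0.5, 1.0])  # targets need not be evenly spaced
ProxHardQuant().apply_(p, None, Q, step_count=0)
print(p)  # tensor([1.0000, -0.2500, -0.2500]): each entry moved to its nearest Q value
```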
