
Commit e41ca4e

Add module-swap UX for INT8 mixed-precision training (#1179)
* add module swap UX
* update
* fix typing. add small notes
* try NF4 support
* fix
* fix unpacking
* fix
* update nf4 integration
* update backward pass
1 parent 71a442a commit e41ca4e

File tree

4 files changed: +87 −35 lines

benchmarks/quantized_training/pretrain_llama2.py
test/prototype/test_quantized_training.py
torchao/prototype/quantized_training/__init__.py
torchao/prototype/quantized_training/int8_mixed_precision.py

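Before the per-file diffs, a minimal usage sketch of the new UX. It uses only names that appear in this commit; the toy model, the CUDA placement (the INT8 matmul path targets GPU), and the `torchao.quantization.quantize_` import path are assumptions, so treat it as a sketch rather than official documentation.

import torch
from torch import nn
from torchao.quantization import quantize_
from torchao.prototype.quantized_training import (
    Int8MixedPrecisionTrainingLinear,
    int8_mixed_precision_training,
)

# toy model standing in for the transformer blocks used in the benchmark
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)).cuda()

# existing UX: tensor-subclass weights (default, module_swap=False)
# quantize_(model, int8_mixed_precision_training(), set_inductor_config=False)

# new UX added by this commit: swap each nn.Linear for Int8MixedPrecisionTrainingLinear
quantize_(model, int8_mixed_precision_training(module_swap=True), set_inductor_config=False)
assert isinstance(model[0], Int8MixedPrecisionTrainingLinear)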
benchmarks/quantized_training/pretrain_llama2.py

Lines changed: 3 additions & 0 deletions

@@ -160,6 +160,9 @@ def insert_rmsnorm(module: torch.nn.Module):
     elif args.quantize == "int8_mixed_precision":
         quantize_(model.layers, int8_mixed_precision_training(), set_inductor_config=False)

+    elif args.quantize == "int8_mixed_precision_module_swap":
+        quantize_(model.layers, int8_mixed_precision_training(module_swap=True), set_inductor_config=False)
+
     elif args.quantize == "bitnet":
         quantize_(model.layers, bitnet_training(), set_inductor_config=False)

test/prototype/test_quantized_training.py

Lines changed: 13 additions & 10 deletions

@@ -159,16 +159,18 @@ def test_int8_weight_only_training(self, compile, device):
             Int8MixedPrecisionTrainingConfig(grad_weight=False),
         ],
     )
+    @parametrize("module_swap", [False, True])
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    def test_int8_mixed_precision_training(self, compile, config):
+    def test_int8_mixed_precision_training(self, compile, config, module_swap):
         _reset()
         bsize = 64
         embed_dim = 64
         device = "cuda"

         linear = nn.Linear(embed_dim, embed_dim, device=device)
         linear_int8mp = copy.deepcopy(linear)
-        quantize_(linear_int8mp, int8_mixed_precision_training(config), set_inductor_config=False)
+        apply_func = int8_mixed_precision_training(config, module_swap=module_swap)
+        quantize_(linear_int8mp, apply_func, set_inductor_config=False)

         if compile:
             linear.compile()

@@ -269,9 +271,10 @@ def test_fsdp2_correctness(self):
         # quantize_fn, mp_policy, tolerance
         test_args = [
             # high tolerance due to stochastic rounding
-            (int8_weight_only_quantized_training, mp_policy, 0.05),
-            (int8_mixed_precision_training, mp_policy, 1e-6),
-            (bitnet_training, mp_policy, 1e-5),
+            (int8_weight_only_quantized_training(), mp_policy, 0.05),
+            (int8_mixed_precision_training(), mp_policy, 1e-6),
+            (int8_mixed_precision_training(module_swap=True), mp_policy, 1e-6),
+            (bitnet_training(), mp_policy, 1e-5),
         ]

         # FSDP2 mixed-precision requires https://github.com/pytorch/pytorch/pull/136129

@@ -284,9 +287,9 @@ def test_fsdp2_correctness(self):
         bf16_mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)

         extra_args = [
-            (int8_weight_only_quantized_training, bf16_mp_policy, 1e-2),
-            (int8_mixed_precision_training, bf16_mp_policy, 1e-2),
-            (bitnet_training, bf16_mp_policy, 1e-2),
+            (int8_weight_only_quantized_training(), bf16_mp_policy, 1e-2),
+            (int8_mixed_precision_training(), bf16_mp_policy, 1e-2),
+            (bitnet_training(), bf16_mp_policy, 1e-2),
         ]
         test_args.extend(extra_args)

@@ -312,8 +315,8 @@ def _run_subtest(self, args):
         base_model = Transformer(model_args).cuda()
         fsdp_model = copy.deepcopy(base_model)

-        quantize_(base_model.layers, quantize_fn(), set_inductor_config=False)
-        quantize_(fsdp_model.layers, quantize_fn(), set_inductor_config=False)
+        quantize_(base_model.layers, quantize_fn, set_inductor_config=False)
+        quantize_(fsdp_model.layers, quantize_fn, set_inductor_config=False)

         for layer in fsdp_model.layers:
             fully_shard(layer, mp_policy=mp_policy)

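Note the refactor in test_fsdp2_correctness above: test_args entries now hold the already-configured apply function (e.g. int8_mixed_precision_training(module_swap=True)) instead of the bare factory, so _run_subtest passes quantize_fn straight to quantize_. A self-contained sketch of that pattern; the toy model and CUDA device are assumptions made here for illustration:

import copy
import torch
from torch import nn
from torchao.quantization import quantize_
from torchao.prototype.quantized_training import int8_mixed_precision_training

base_model = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32)).cuda()

# each entry is a ready-to-apply function, so variants that need arguments
# (such as module_swap=True) fit the same list shape as the no-argument ones
quantize_fns = [
    int8_mixed_precision_training(),                  # tensor-subclass path
    int8_mixed_precision_training(module_swap=True),  # module-swap path
]
for quantize_fn in quantize_fns:
    model = copy.deepcopy(base_model)
    quantize_(model, quantize_fn, set_inductor_config=False)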
torchao/prototype/quantized_training/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@
 )
 from .int8_mixed_precision import (
     Int8MixedPrecisionTrainingConfig,
+    Int8MixedPrecisionTrainingLinear,
     Int8MixedPrecisionTrainingLinearWeight,
     int8_mixed_precision_training,
 )

torchao/prototype/quantized_training/int8_mixed_precision.py

Lines changed: 70 additions & 25 deletions

@@ -1,8 +1,8 @@
-from typing import Any, NamedTuple, Optional, Tuple
+from typing import Any, NamedTuple, Optional, Tuple, Union

 import torch
 import torch.utils._pytree as pytree
-from torch import Tensor
+from torch import Tensor, nn
 from torch.utils._triton import has_triton

 from torchao.quantization.quant_api import _get_linear_subclass_inserter

@@ -75,7 +75,7 @@ def to_original(self):
     def __torch_dispatch__(cls, func, types, args, kwargs):
         config = None

-        def unwrap(x: cls):
+        def unwrap(x):
             nonlocal config
             if config is None:
                 config = x.config

@@ -151,7 +151,16 @@ def _(func, types, args, kwargs):
     if torch.is_autocast_enabled("cuda"):
         dtype = torch.get_autocast_gpu_dtype()
         args = tuple(x.to(dtype) if x is not None else x for x in args)
-    return _Int8MixedPrecisionTrainingLinear.apply(*args, **kwargs)
+    return _Int8MixedPrecisionTrainingLinearFunction.apply(*args, **kwargs)
+
+
+class Int8MixedPrecisionTrainingLinear(nn.Linear):
+    def __init__(self, *args, config: Int8MixedPrecisionTrainingConfig, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.config = config
+
+    def forward(self, input: Tensor) -> Tensor:
+        return _Int8MixedPrecisionTrainingLinearFunction.apply(input, self.weight, self.bias, self.config)


 def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor:

@@ -184,26 +193,46 @@ def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor:
     return out.view(*A.shape[:-1], out.shape[-1])


-class _Int8MixedPrecisionTrainingLinear(torch.autograd.Function):
+@torch.compiler.allow_in_graph  # this is required for module-swap, but not for tensor subclass
+class _Int8MixedPrecisionTrainingLinearFunction(torch.autograd.Function):
     @staticmethod
-    def forward(input: Tensor, weight: Int8MixedPrecisionTrainingLinearWeight, bias: Optional[Tensor]):
-        if weight.config.output:
-            out = _dynamic_int8_mm(input, weight._data.T)
+    def forward(
+        ctx,
+        input: Tensor,
+        weight: Union[Int8MixedPrecisionTrainingLinearWeight, Tensor],
+        bias: Optional[Tensor],
+        config: Optional[Int8MixedPrecisionTrainingConfig] = None,
+    ):
+        # unpack tensor subclass and dequant if necessary.
+        # NOTE: we have to do this inside autograd.Function so that autograd works correctly.
+        if isinstance(weight, Int8MixedPrecisionTrainingLinearWeight):
+            config = weight.config  # override `config` input argument
+            weight = weight._data
+
+        ctx.config = config
+        ctx.save_for_backward(input, weight)
+        ctx.bias = bias is not None
+
+        # for NF4Tensor, this will dequantize the tensor.
+        # NOTE: not all quantized tensor subclasses implement .to() this way.
+        # e.g. AffineQuantizedTensor.to(dtype=dtype) returns the same AQT tensor.
+        # casting weight dtype may also introduce unintended behavior.
+        # e.g. FP32 activations and BF16 weight (both plain tensors), which should raise an error,
+        # but now we cast BF16 weight to FP32 instead (and return results in FP32).
+        weight = weight.to(input.dtype)
+
+        if config.output:
+            out = _dynamic_int8_mm(input, weight.T)
         else:
-            out = input @ weight._data.T
+            out = input @ weight.T
         out = out + bias if bias is not None else out
         return out

-    @staticmethod
-    def setup_context(ctx, inputs, output):
-        input, weight, bias = inputs
-        ctx.config = weight.config
-        ctx.save_for_backward(input, weight._data)
-        ctx.bias = bias is not None
-
     @staticmethod
     def backward(ctx, grad_output):
         input, weight = ctx.saved_tensors
+        weight = weight.to(input.dtype)  # dequant NF4
+
         grad_input = grad_weight = grad_bias = None

         if ctx.needs_input_grad[0]:

@@ -224,12 +253,28 @@ def backward(ctx, grad_output):
         if ctx.needs_input_grad[2] and ctx.bias:
             grad_bias = grad_output.sum(0)

-        return grad_input, grad_weight, grad_bias
-
-
-def int8_mixed_precision_training(config: Int8MixedPrecisionTrainingConfig = _DEFAULT_CONFIG):
-    return _get_linear_subclass_inserter(
-        Int8MixedPrecisionTrainingLinearWeight,
-        config=config,
-        allow_requires_grad=True,
-    )
+        return grad_input, grad_weight, grad_bias, None
+
+
+def int8_mixed_precision_training(
+    config: Int8MixedPrecisionTrainingConfig = _DEFAULT_CONFIG,
+    *,
+    module_swap: bool = False,
+):
+    # TODO: skip small layers that don't have perf gain.
+    if module_swap:
+        # module swap implementation
+        def convert_linear(linear: nn.Linear):
+            linear.__class__ = Int8MixedPrecisionTrainingLinear
+            linear.config = config
+            return linear
+
+        return convert_linear
+
+    else:
+        # tensor subclass implementation
+        return _get_linear_subclass_inserter(
+            Int8MixedPrecisionTrainingLinearWeight,
+            config=config,
+            allow_requires_grad=True,
+        )

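For reference, a standalone sketch (toy classes, not torchao code) of the mechanism convert_linear relies on: reassigning __class__ converts an existing nn.Linear in place, so its parameters (and any optimizer references to them) survive, while forward now routes through the new module's implementation.

import torch
from torch import nn

class ToyInt8Linear(nn.Linear):
    # stand-in for Int8MixedPrecisionTrainingLinear
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # the real class calls
        # _Int8MixedPrecisionTrainingLinearFunction.apply(input, self.weight, self.bias, self.config)
        return nn.functional.linear(input, self.weight, self.bias)

def convert_linear(linear: nn.Linear) -> nn.Linear:
    # same trick as int8_mixed_precision_training(module_swap=True):
    # swap the class in place instead of constructing a new module
    linear.__class__ = ToyInt8Linear
    linear.config = None  # the real code attaches the Int8MixedPrecisionTrainingConfig here
    return linear

linear = nn.Linear(8, 8)
weight_before = linear.weight
convert_linear(linear)
assert isinstance(linear, ToyInt8Linear)
assert linear.weight is weight_before  # in-place swap keeps the existing parameters
out = linear(torch.randn(2, 8))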