-from typing import Any, NamedTuple, Optional, Tuple
+from typing import Any, NamedTuple, Optional, Tuple, Union

 import torch
 import torch.utils._pytree as pytree
-from torch import Tensor
+from torch import Tensor, nn
 from torch.utils._triton import has_triton

 from torchao.quantization.quant_api import _get_linear_subclass_inserter
@@ -75,7 +75,7 @@ def to_original(self):
     def __torch_dispatch__(cls, func, types, args, kwargs):
         config = None

-        def unwrap(x: cls):
+        def unwrap(x):
             nonlocal config
             if config is None:
                 config = x.config
@@ -151,7 +151,16 @@ def _(func, types, args, kwargs):
     if torch.is_autocast_enabled("cuda"):
         dtype = torch.get_autocast_gpu_dtype()
         args = tuple(x.to(dtype) if x is not None else x for x in args)
-    return _Int8MixedPrecisionTrainingLinear.apply(*args, **kwargs)
+    return _Int8MixedPrecisionTrainingLinearFunction.apply(*args, **kwargs)
+
+
+class Int8MixedPrecisionTrainingLinear(nn.Linear):
+    def __init__(self, *args, config: Int8MixedPrecisionTrainingConfig, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.config = config
+
+    def forward(self, input: Tensor) -> Tensor:
+        return _Int8MixedPrecisionTrainingLinearFunction.apply(input, self.weight, self.bias, self.config)


 def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor:
@@ -184,26 +193,46 @@ def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor:
     return out.view(*A.shape[:-1], out.shape[-1])


-class _Int8MixedPrecisionTrainingLinear(torch.autograd.Function):
+@torch.compiler.allow_in_graph  # this is required for module-swap, but not for tensor subclass
+class _Int8MixedPrecisionTrainingLinearFunction(torch.autograd.Function):
     @staticmethod
-    def forward(input: Tensor, weight: Int8MixedPrecisionTrainingLinearWeight, bias: Optional[Tensor]):
-        if weight.config.output:
-            out = _dynamic_int8_mm(input, weight._data.T)
+    def forward(
+        ctx,
+        input: Tensor,
+        weight: Union[Int8MixedPrecisionTrainingLinearWeight, Tensor],
+        bias: Optional[Tensor],
+        config: Optional[Int8MixedPrecisionTrainingConfig] = None,
+    ):
+        # unpack tensor subclass and dequant if necessary.
+        # NOTE: we have to do this inside autograd.Function so that autograd works correctly.
+        if isinstance(weight, Int8MixedPrecisionTrainingLinearWeight):
+            config = weight.config  # override `config` input argument
+            weight = weight._data
+
+        ctx.config = config
+        ctx.save_for_backward(input, weight)
+        ctx.bias = bias is not None
+
+        # for NF4Tensor, this will dequantize the tensor.
+        # NOTE: not all quantized tensor subclasses implement .to() this way.
+        # e.g. AffineQuantizedTensor.to(dtype=dtype) returns the same AQT tensor.
+        # casting weight dtype may also introduce unintended behavior.
+        # e.g. FP32 activations and BF16 weight (both plain tensors), which should raise an error,
+        # but now we cast BF16 weight to FP32 instead (and return results in FP32).
+        weight = weight.to(input.dtype)
+
+        if config.output:
+            out = _dynamic_int8_mm(input, weight.T)
         else:
-            out = input @ weight._data.T
+            out = input @ weight.T
         out = out + bias if bias is not None else out
         return out

-    @staticmethod
-    def setup_context(ctx, inputs, output):
-        input, weight, bias = inputs
-        ctx.config = weight.config
-        ctx.save_for_backward(input, weight._data)
-        ctx.bias = bias is not None
-
     @staticmethod
     def backward(ctx, grad_output):
         input, weight = ctx.saved_tensors
+        weight = weight.to(input.dtype)  # dequant NF4
+
         grad_input = grad_weight = grad_bias = None

         if ctx.needs_input_grad[0]:
@@ -224,12 +253,28 @@ def backward(ctx, grad_output):
         if ctx.needs_input_grad[2] and ctx.bias:
             grad_bias = grad_output.sum(0)

-        return grad_input, grad_weight, grad_bias
-
-
-def int8_mixed_precision_training(config: Int8MixedPrecisionTrainingConfig = _DEFAULT_CONFIG):
-    return _get_linear_subclass_inserter(
-        Int8MixedPrecisionTrainingLinearWeight,
-        config=config,
-        allow_requires_grad=True,
-    )
+        return grad_input, grad_weight, grad_bias, None
+
+
+def int8_mixed_precision_training(
+    config: Int8MixedPrecisionTrainingConfig = _DEFAULT_CONFIG,
+    *,
+    module_swap: bool = False,
+):
+    # TODO: skip small layers that don't have perf gain.
+    if module_swap:
+        # module swap implementation
+        def convert_linear(linear: nn.Linear):
+            linear.__class__ = Int8MixedPrecisionTrainingLinear
+            linear.config = config
+            return linear
+
+        return convert_linear
+
+    else:
+        # tensor subclass implementation
+        return _get_linear_subclass_inserter(
+            Int8MixedPrecisionTrainingLinearWeight,
+            config=config,
+            allow_requires_grad=True,
+        )
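
For context, here is a minimal usage sketch of the two conversion paths exposed by `int8_mixed_precision_training` after this change. This is not part of the commit: the `torchao.prototype.quantized_training` import path, the use of `torchao.quantization.quantize_` to apply the returned per-linear converter, and the toy model and shapes are assumptions for illustration.

# Usage sketch (assumed API surface, not part of this diff).
import torch
from torch import nn
from torchao.quantization import quantize_
from torchao.prototype.quantized_training import int8_mixed_precision_training

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)).cuda()

# Default path: each nn.Linear weight is wrapped in Int8MixedPrecisionTrainingLinearWeight (tensor subclass).
quantize_(model, int8_mixed_precision_training())

# Alternative path added by this commit: swap each nn.Linear to Int8MixedPrecisionTrainingLinear,
# whose forward() calls _Int8MixedPrecisionTrainingLinearFunction directly.
# quantize_(model, int8_mixed_precision_training(module_swap=True))

# Training step: matmuls in forward and backward run through the dynamic INT8 path,
# while the master weights themselves stay in their original precision.
x = torch.randn(8, 1024, device="cuda")
model(x).sum().backward()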