@@ -161,44 +161,31 @@ def test_int8_weight_only_training(self, compile, device):
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     def test_int8_mixed_precision_training(self, compile, config):
         _reset()
-        bsize = 4
-        embed_dim = 32
+        bsize = 64
+        embed_dim = 64
         device = "cuda"
 
-        # only use 1 matmul shape to reduce triton autotune time
-        model_ref = nn.Sequential(
-            nn.Linear(embed_dim, embed_dim, bias=False),
-            nn.GELU(),
-            nn.Linear(embed_dim, embed_dim),
-        ).to(device)
-        model_int8mp = copy.deepcopy(model_ref)
-        quantize_(model_int8mp, int8_mixed_precision_training(config), set_inductor_config=False)
+        linear = nn.Linear(embed_dim, embed_dim).cuda()
+        linear_int8mp = copy.deepcopy(linear)
+        quantize_(linear_int8mp, int8_mixed_precision_training(config), set_inductor_config=False)
 
         if compile:
-            model_ref.compile()
-            model_int8mp.compile()
+            linear.compile()
+            linear_int8mp.compile()
 
-        optim_ref = torch.optim.AdamW(model_ref.parameters())
-        optim_int8mp = torch.optim.AdamW(model_int8mp.parameters())
+        inputs = torch.randn(bsize, embed_dim, device=device)
+        grad_outputs = torch.randn(bsize, embed_dim, device=device)
 
-        for i in range(5):
-            inputs = torch.randn(bsize, embed_dim, device=device)
-            labels = torch.randint(embed_dim, size=(bsize,), device=device)
-            loss_ref = F.cross_entropy(model_ref(inputs), labels)
-            loss_int8mp = F.cross_entropy(model_int8mp(inputs), labels)
-
-            rel_error = abs(loss_int8mp.item() - loss_ref.item()) / abs(loss_ref.item())
-            assert rel_error < 3e-3, (i, rel_error)
-
-            loss_ref.backward()
-            optim_ref.step()
-            optim_ref.zero_grad()
-
-            loss_int8mp.backward()
-            for p in model_int8mp.parameters():
-                assert p.grad is not None
-            optim_int8mp.step()
-            optim_int8mp.zero_grad()
+        inputs_ref, outputs_ref = self._forward_and_backward(linear, inputs, grad_outputs)
+        inputs_int8mp, outputs_int8mp = self._forward_and_backward(linear_int8mp, inputs, grad_outputs)
+
+        def snr(ref, actual):
+            error = actual - ref
+            return 20 * torch.log10(ref.norm() / error.norm())
+
+        assert snr(outputs_ref, outputs_int8mp) > 20
+        assert snr(inputs_ref.grad, inputs_int8mp.grad) > 20
+        assert snr(linear.weight.grad, linear_int8mp.weight.grad) > 20
 
 
 _FSDP_WORLD_SIZE = 2
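The rewritten test delegates the forward and backward passes to a `self._forward_and_backward(...)` helper that is not part of this hunk. A minimal sketch of what such a helper would need to do, inferred only from how the test uses its return values (it later reads `inputs_ref.grad`, `linear.weight.grad`, and the returned outputs), might look like the following; the body here is an assumption, not the actual helper defined elsewhere in the test class.

```python
# Hypothetical sketch only: inferred from the call sites above, not the
# real helper defined elsewhere in the test class.
def _forward_and_backward(self, module, inputs, grad_outputs):
    # Clone the input and enable grad tracking so the caller can inspect inputs.grad.
    inputs = inputs.clone().requires_grad_(True)
    outputs = module(inputs)
    # Backpropagate a fixed upstream gradient so the reference and the
    # int8 mixed-precision module receive identical grad_outputs.
    outputs.backward(grad_outputs)
    return inputs, outputs
```

Feeding both modules the same `inputs` and `grad_outputs` lets the SNR checks compare outputs, input gradients, and weight gradients against the unquantized reference tensor by tensor, instead of tracking a training loss over several optimizer steps.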