
Commit ceec750

Update INT8 mixed-precision training test to be less flaky (#950)
1 parent 637ed13 commit ceec750
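
The updated test replaces the 5-step training loop and its tight per-step relative-loss bound (3e-3) with SNR-based checks on a single forward/backward pass. As a rough illustration (not part of the commit), the 20 dB threshold passes exactly when the error norm is below 10% of the reference norm:

import torch

def snr(ref, actual):
    # SNR in decibels: 20 * log10(||ref|| / ||actual - ref||)
    error = actual - ref
    return 20 * torch.log10(ref.norm() / error.norm())

ref = torch.randn(64, 64)
print(snr(ref, ref * 1.01))  # 1% relative error in norm -> exactly 40 dB
print(snr(ref, ref * 1.10))  # 10% relative error -> exactly 20 dB, the test's cutoff

A norm-level SNR bound tolerates occasional elementwise outliers that a hard per-iteration loss threshold would flag, which is presumably what made the old test flaky.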

File tree

1 file changed: +19 -32 lines changed


test/prototype/test_quantized_training.py

Lines changed: 19 additions & 32 deletions
@@ -161,44 +161,31 @@ def test_int8_weight_only_training(self, compile, device):
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     def test_int8_mixed_precision_training(self, compile, config):
         _reset()
-        bsize = 4
-        embed_dim = 32
+        bsize = 64
+        embed_dim = 64
         device = "cuda"
 
-        # only use 1 matmul shape to reduce triton autotune time
-        model_ref = nn.Sequential(
-            nn.Linear(embed_dim, embed_dim, bias=False),
-            nn.GELU(),
-            nn.Linear(embed_dim, embed_dim),
-        ).to(device)
-        model_int8mp = copy.deepcopy(model_ref)
-        quantize_(model_int8mp, int8_mixed_precision_training(config), set_inductor_config=False)
+        linear = nn.Linear(embed_dim, embed_dim).cuda()
+        linear_int8mp = copy.deepcopy(linear)
+        quantize_(linear_int8mp, int8_mixed_precision_training(config), set_inductor_config=False)
 
         if compile:
-            model_ref.compile()
-            model_int8mp.compile()
+            linear.compile()
+            linear_int8mp.compile()
 
-        optim_ref = torch.optim.AdamW(model_ref.parameters())
-        optim_int8mp = torch.optim.AdamW(model_int8mp.parameters())
+        inputs = torch.randn(bsize, embed_dim, device=device)
+        grad_outputs = torch.randn(bsize, embed_dim, device=device)
 
-        for i in range(5):
-            inputs = torch.randn(bsize, embed_dim, device=device)
-            labels = torch.randint(embed_dim, size=(bsize,), device=device)
-            loss_ref = F.cross_entropy(model_ref(inputs), labels)
-            loss_int8mp = F.cross_entropy(model_int8mp(inputs), labels)
-
-            rel_error = abs(loss_int8mp.item() - loss_ref.item()) / abs(loss_ref.item())
-            assert rel_error < 3e-3, (i, rel_error)
-
-            loss_ref.backward()
-            optim_ref.step()
-            optim_ref.zero_grad()
-
-            loss_int8mp.backward()
-            for p in model_int8mp.parameters():
-                assert p.grad is not None
-            optim_int8mp.step()
-            optim_int8mp.zero_grad()
+        inputs_ref, outputs_ref = self._forward_and_backward(linear, inputs, grad_outputs)
+        inputs_int8mp, outputs_int8mp = self._forward_and_backward(linear_int8mp, inputs, grad_outputs)
+
+        def snr(ref, actual):
+            error = actual - ref
+            return 20 * torch.log10(ref.norm() / error.norm())
+
+        assert snr(outputs_ref, outputs_int8mp) > 20
+        assert snr(inputs_ref.grad, inputs_int8mp.grad) > 20
+        assert snr(linear.weight.grad, linear_int8mp.weight.grad) > 20
 
 
 _FSDP_WORLD_SIZE = 2
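
For context, `_forward_and_backward` is a helper on the test class that is not shown in this diff. A minimal sketch of what it presumably does, assuming it tracks gradients on a copy of the inputs and backpropagates the supplied upstream gradients:

def _forward_and_backward(self, model, inputs, grad_outputs):
    # Hypothetical sketch; the actual helper lives elsewhere in
    # test_quantized_training.py and may differ.
    inputs = inputs.clone().requires_grad_(True)  # so inputs.grad is populated
    outputs = model(inputs)
    outputs.backward(grad_outputs)  # fixed upstream grads, identical for both runs
    return inputs, outputs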
