Skip to content

Commit cace65c

Browse files
TimDettmers and claude
authored
Fix AdEMAMix scheduler guard and add state_dict round-trip test (#1861)
* Fix AdEMAMix scheduler guard and add state_dict round-trip test (#1382) Fix potential division-by-zero in AdEMAMix update_step when t_alpha or t_beta3 is 0 (e.g. from get_config defaults). Change scheduler guards from `is None` to falsy checks so that 0, None, and False all correctly skip the scheduler path. Also change get_config defaults for t_alpha and t_beta3 from 0 to None to match the intended semantics. Add test_ademamix_state_dict_no_nan which saves and loads AdEMAMix state_dict (8-bit, 32-bit, with and without schedulers) and verifies: - loaded state matches original byte-for-byte - training resumes without NaN or Inf - two optimizers loaded from the same checkpoint produce identical updates Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * style: Fix ruff format violation in test_linear4bit.py Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 505a00a commit cace65c

File tree

3 files changed

+113
-5
lines changed

3 files changed

+113
-5
lines changed

bitsandbytes/optim/ademamix.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def init_state(self, group, p, gindex, pindex):
180180
def update_step(self, group, p, gindex, pindex):
181181
config = self.get_config(gindex, pindex, group)
182182

183-
if config["t_alpha"] is None and config["t_beta3"] is None:
183+
if not config["t_alpha"] and not config["t_beta3"]:
184184
# Not using alpha/beta3 scheduler; we can fall through.
185185
super().update_step(group, p, gindex, pindex)
186186
return
@@ -201,13 +201,13 @@ def update_step(self, group, p, gindex, pindex):
201201
t_beta3 = config["t_beta3"]
202202

203203
# Apply scheduler for alpha
204-
if t_alpha is not None:
204+
if t_alpha:
205205
alpha_t = min(step * alpha / t_alpha, alpha)
206206
else:
207207
alpha_t = alpha
208208

209209
# Apply scheduler for beta3
210-
if t_beta3 is not None:
210+
if t_beta3:
211211
ln_beta1 = math.log(beta1)
212212
ln_beta3 = math.log(beta3)
213213
step_scale = step / t_beta3

bitsandbytes/optim/optimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,8 @@ def get_config(self, gindex, pindex, group):
341341
config["weight_decay"] = group["weight_decay"]
342342
config["lr"] = group["lr"]
343343
config["alpha"] = group.get("alpha", 0.0)
344-
config["t_alpha"] = group.get("t_alpha", 0)
345-
config["t_beta3"] = group.get("t_beta3", 0)
344+
config["t_alpha"] = group.get("t_alpha", None)
345+
config["t_beta3"] = group.get("t_beta3", None)
346346
config["optim_bits"] = self.args.optim_bits
347347
config["min_8bit_size"] = self.args.min_8bit_size
348348
config["percentile_clipping"] = self.args.percentile_clipping

tests/test_optim.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,3 +592,111 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name, device):
592592
params = (total_steps - total_steps // 5) * dim1 * dim2
593593
print(optim_name, gtype, s, params, s / params)
594594
# assert s < 3.9
595+
596+
597+
# (test id, optimizer factory) pairs: 8-bit and 32-bit AdEMAMix, each both
# without and with the alpha/beta3 schedulers enabled (t_alpha/t_beta3=100).
ademamix_state_dict_opts = [
    ("AdEMAMix8bit", lambda p: bnb.optim.AdEMAMix8bit(p, lr=1e-3)),
    ("AdEMAMix32bit", lambda p: bnb.optim.AdEMAMix(p, lr=1e-3)),
    ("AdEMAMix8bit_scheduled", lambda p: bnb.optim.AdEMAMix8bit(p, lr=1e-3, t_alpha=100, t_beta3=100)),
    ("AdEMAMix32bit_scheduled", lambda p: bnb.optim.AdEMAMix(p, lr=1e-3, t_alpha=100, t_beta3=100)),
]
605+
@pytest.mark.parametrize(
    "optim_name,optim_factory",
    ademamix_state_dict_opts,
    ids=[x[0] for x in ademamix_state_dict_opts],
)
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
def test_ademamix_state_dict_no_nan(optim_name, optim_factory, device):
    """Test that AdEMAMix can save/load state_dict and continue training without NaN.

    Regression test for https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1382
    """
    if device not in ["cuda", "xpu"]:
        pytest.skip("Optimizers are only supported on CUDA and XPU")

    import torch.nn as nn

    torch.manual_seed(42)
    model = nn.Linear(256, 64).to(device)
    opt = optim_factory(model.parameters())

    # Run enough steps that the optimizer's per-parameter state is populated
    # before we snapshot it.
    for _ in range(10):
        batch = torch.randn(8, 256, device=device)
        model(batch).sum().backward()
        opt.step()
        opt.zero_grad()

    # Checkpoint model + optimizer state to disk, as a user would.
    model_sd = {k: v.clone() for k, v in model.state_dict().items()}
    opt_sd = opt.state_dict()
    path = get_temp_dir()
    torch.save(opt_sd, join(path, "opt.pt"))
    torch.save(model_sd, join(path, "model.pt"))

    # Restore into a brand-new model/optimizer pair.
    model2 = nn.Linear(256, 64).to(device)
    model2.load_state_dict(torch.load(join(path, "model.pt")))
    opt2 = optim_factory(model2.parameters())
    opt2.load_state_dict(torch.load(join(path, "opt.pt")))
    rm_path(path)

    # The restored per-parameter state must match the live optimizer exactly.
    for p_idx, (p_orig, p_loaded) in enumerate(zip(model.parameters(), model2.parameters())):
        s1 = opt.state[p_orig]
        s2 = opt2.state[p_loaded]
        for k, v1 in s1.items():
            if isinstance(v1, torch.Tensor):
                assert v1.shape == s2[k].shape, f"Shape mismatch for param {p_idx} {k}"
                assert v1.dtype == s2[k].dtype, f"Dtype mismatch for param {p_idx} {k}"
                torch.testing.assert_close(v1, s2[k])

    # Resume training on the restored pair: loss and weights must stay finite.
    for i in range(10):
        batch = torch.randn(8, 256, device=device)
        loss = model2(batch).sum()
        assert not torch.isnan(loss), f"NaN loss at step {i} after loading state_dict"
        assert not torch.isinf(loss), f"Inf loss at step {i} after loading state_dict"
        loss.backward()
        opt2.step()
        opt2.zero_grad()

        # Check parameters for NaN/Inf after each step
        for p in model2.parameters():
            assert not p.isnan().any(), f"NaN in parameters at step {i} after loading state_dict"
            assert not p.isinf().any(), f"Inf in parameters at step {i} after loading state_dict"

    # Finally, two optimizers restored from the same checkpoint must produce
    # identical updates from the same weights and the same batch.
    torch.manual_seed(999)
    x_orig = torch.randn(8, 256, device=device)
    x_loaded = x_orig.clone()

    # Rewind both models to the checkpointed weights.
    model.load_state_dict(model_sd)
    model2.load_state_dict(model_sd)

    # Two fresh optimizers, both restored from the same saved optimizer state.
    opt_fresh = optim_factory(model.parameters())
    opt_fresh.load_state_dict(opt_sd)
    opt_fresh2 = optim_factory(model2.parameters())
    opt_fresh2.load_state_dict(opt_sd)

    model(x_orig).sum().backward()
    opt_fresh.step()
    opt_fresh.zero_grad()

    model2(x_loaded).sum().backward()
    opt_fresh2.step()
    opt_fresh2.zero_grad()

    for p_a, p_b in zip(model.parameters(), model2.parameters()):
        torch.testing.assert_close(p_a, p_b)

0 commit comments

Comments
 (0)