
Commit 737c3e2

mryab and Vectorrent committed
Update optim.grad_scaler to use torch.amp
Co-authored-by: Luciferian Ink <[email protected]>
1 parent 232c6b7 commit 737c3e2

File tree

1 file changed: +11, -2 lines changed

hivemind/optim/grad_scaler.py

Lines changed: 11 additions & 2 deletions
@@ -4,8 +4,17 @@
 from typing import Dict, Optional
 
 import torch
-from torch.cuda.amp import GradScaler as TorchGradScaler
-from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
+from packaging import version
+
+torch_version = torch.__version__.split("+")[0]
+
+if version.parse(torch_version) >= version.parse("2.3.0"):
+    from torch.amp import GradScaler as TorchGradScaler
+    from torch.amp.grad_scaler import OptState, _refresh_per_optimizer_state
+else:
+    from torch.cuda.amp import GradScaler as TorchGradScaler
+    from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
+
 from torch.optim import Optimizer as TorchOptimizer
 
 import hivemind
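
For context, the version gate introduced by this commit can be exercised on its own. The snippet below is a minimal sketch, not part of the patch: it repeats the same check and constructs whichever scaler was selected. The enabled=torch.cuda.is_available() argument and the print statement are illustrative additions rather than behavior from the commit.

import torch
from packaging import version

# Strip any local build suffix (e.g. "2.4.1+cu121" -> "2.4.1") before comparing versions.
torch_version = torch.__version__.split("+")[0]

if version.parse(torch_version) >= version.parse("2.3.0"):
    # PyTorch 2.3+ exposes a device-agnostic GradScaler under torch.amp.
    from torch.amp import GradScaler as TorchGradScaler
else:
    # Older releases only provide the CUDA-specific scaler.
    from torch.cuda.amp import GradScaler as TorchGradScaler

# Both variants accept the `enabled` keyword, so the scaler can be created
# the same way regardless of which branch was taken.
scaler = TorchGradScaler(enabled=torch.cuda.is_available())
print(f"Using {type(scaler).__module__}.{type(scaler).__name__}")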
