We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 232c6b7 commit 737c3e2Copy full SHA for 737c3e2
hivemind/optim/grad_scaler.py
@@ -4,8 +4,17 @@
 from typing import Dict, Optional
 
 import torch
-from torch.cuda.amp import GradScaler as TorchGradScaler
-from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
+from packaging import version
+
+torch_version = torch.__version__.split("+")[0]
+
+if version.parse(torch_version) >= version.parse("2.3.0"):
+    from torch.amp import GradScaler as TorchGradScaler
+    from torch.amp.grad_scaler import OptState, _refresh_per_optimizer_state
+else:
+    from torch.cuda.amp import GradScaler as TorchGradScaler
+    from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
 from torch.optim import Optimizer as TorchOptimizer
 
 import hivemind
0 commit comments