# optim.py
import ast

import torch
from accelerate.logging import get_logger

# Optional optimizer backends; each falls back to None when its package is not installed.
try:
    from optimi import AdamW as OptimiAdamW
    from optimi import StableAdamW as OptimiStableAdamW
except ImportError:
    OptimiAdamW, OptimiStableAdamW = None, None

try:
    from bitsandbytes.optim import AdamW8bit, Lion8bit
except ImportError:
    AdamW8bit, Lion8bit = None, None

try:
    from came_pytorch import CAME
except ImportError:
    CAME = None

logger = get_logger(__name__)

# Maps a lowercase optimizer name to its class; optional backends are None when
# the corresponding package is not installed.
OPTIMIZER_FUNC_TO_NAME = {
    "adam": torch.optim.Adam,
    "adamw": torch.optim.AdamW,
    "optimi-adamw": OptimiAdamW,
    "optimi-stableadamw": OptimiStableAdamW,
    "bnb-adamw8bit": AdamW8bit,
    "bnb-lion8bit": Lion8bit,
    "came": CAME,
}
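
# To expose another optimizer, add an entry to the mapping above, e.g. (hypothetical):
#   OPTIMIZER_FUNC_TO_NAME["sgd"] = torch.optim.SGD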

def get_optimizer(
    params_to_optimize,
    optimizer_name: str = "adam",
    learning_rate: float = 1e-3,
    optimizer_args_str: str | None = None,
    use_deepspeed: bool = False,
    # use_cpu_offload_optimizer: bool = False,
    # offload_gradients: bool = False,
) -> torch.optim.Optimizer:
    # Extra optimizer arguments are passed as space-separated "key=value" pairs,
    # e.g. "weight_decay=0.01"; values are parsed with ast.literal_eval.
    optimizer_kwargs = {}
    if optimizer_args_str is not None and len(optimizer_args_str) > 0:
        for arg in optimizer_args_str.split():
            key, value = arg.split("=", 1)
            optimizer_kwargs[key] = ast.literal_eval(value)

    optimizer_name = optimizer_name.lower()
    if use_deepspeed:
        # DeepSpeed builds the real optimizer from its own config, so hand
        # Accelerate a placeholder optimizer instead.
        from accelerate.utils import DummyOptim

        return DummyOptim(params_to_optimize, lr=learning_rate, **optimizer_kwargs)

    assert optimizer_name in OPTIMIZER_FUNC_TO_NAME, f"Unknown optimizer: {optimizer_name!r}"
    optimizer_class = OPTIMIZER_FUNC_TO_NAME[optimizer_name]
    assert optimizer_class is not None, (
        f"Optimizer {optimizer_name!r} is not available; its package is not installed."
    )
    optimizer = optimizer_class(params_to_optimize, lr=learning_rate, **optimizer_kwargs)
    logger.info(f"Using {optimizer.__class__.__name__!r} | {optimizer_kwargs!r}")
    return optimizer
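
# Example call (illustrative; the argument values below are assumptions, not
# taken from this repository's training scripts):
#
#   optimizer = get_optimizer(
#       model.parameters(),
#       optimizer_name="bnb-adamw8bit",
#       learning_rate=1e-4,
#       optimizer_args_str="weight_decay=0.01 eps=1e-8",
#   )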

def gradient_norm(parameters):
    """Return the global L2 norm over the gradients of ``parameters``."""
    norm = 0.0
    for param in parameters:
        if param.grad is None:
            continue
        local_norm = param.grad.detach().norm(2)
        norm += local_norm.item() ** 2
    return norm**0.5

def max_gradient(parameters):
    """Return the largest absolute gradient entry over ``parameters``."""
    max_grad_value = float("-inf")
    for param in parameters:
        if param.grad is None:
            continue
        local_max_grad = param.grad.detach().abs().max()
        max_grad_value = max(max_grad_value, local_max_grad.item())
    return max_grad_value
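

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline): the toy
    # linear model and the optimizer settings below are illustrative assumptions.
    from accelerate import PartialState

    PartialState()  # the accelerate logger used by get_optimizer expects an initialized state

    model = torch.nn.Linear(16, 4)
    optimizer = get_optimizer(model.parameters(), optimizer_name="adamw", learning_rate=1e-3)

    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()

    print(f"grad norm: {gradient_norm(model.parameters()):.4f}")
    print(f"max grad:  {max_gradient(model.parameters()):.4f}")

    optimizer.step()
    optimizer.zero_grad()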