
Commit db5bbe7

caaatch22 authored and ctios committed
feat(framework): enable amp by options
1 parent b5de207 commit db5bbe7

3 files changed: 11 additions & 4 deletions

File tree:
  pyproject.toml
  recis/__init__.py
  recis/framework/trainer.py

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@ classifiers=[
 ]
 dependencies = [
     "torch>=2.0",
-    "accelerate==0.29.2",
+    "accelerate>=1.12.0",
     "simple-parsing"
 ]

@@ -201,4 +201,4 @@ select = [
     "TRY401", # verbose-log-message
     "UP",
     "YTT",
-]
+]

recis/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 import torch


-__version__ = "1.0.25"
+__version__ = "1.0.26"

 pkg_path = os.path.dirname(os.path.realpath(__file__))
 lib_path = os.path.join(pkg_path, "lib")

recis/framework/trainer.py

Lines changed: 8 additions & 1 deletion
@@ -67,6 +67,7 @@ class TrainingArguments:
         saver_option (Optional[SaverOptions]): Options for checkpoint saver. Defaults to None.
         ckpt_save_arg (Optional[CheckpointSaveArguments]): Arguments for checkpoint save. Defaults to None.
         ckpt_load_arg (Optional[CheckpointLoadArguments]): Arguments for checkpoint load. Defaults to None.
+        mixed_precision (Optional[str]): Mixed precision training mode. Only "bf16" and "fp16" are supported. Defaults to None.
     """

     gradient_accumulation_steps: int = 1
@@ -90,6 +91,7 @@ class TrainingArguments:
     saver_option: Optional[SaverOptions] = None
     ckpt_save_arg: Optional[CheckpointSaveArguments] = None
     ckpt_load_arg: Optional[CheckpointLoadArguments] = None
+    mixed_precision: Optional[str] = None


 class Trainer:
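
The two hunks above add a single opt-in knob. A minimal usage sketch, assuming only what the diff shows (TrainingArguments is a dataclass, so keyword construction with the new field should work):

from recis.framework.trainer import TrainingArguments

# mixed_precision is the field added by this commit; "bf16" and "fp16"
# are the only accepted values, and None (the default) disables AMP.
args = TrainingArguments(
    gradient_accumulation_steps=1,
    mixed_precision="bf16",
)

When this args object reaches the Trainer constructor, the next hunk forwards the value straight into Accelerator.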
@@ -197,11 +199,15 @@ def __init__(
         self.dense_lr_scheduler = dense_optimizers[1]
         self.sparse_optimizer = sparse_optimizer
         self.data_to_cuda = data_to_cuda
+        self.mixed_precision = args.mixed_precision
+        if self.mixed_precision is not None:
+            assert self.mixed_precision in ["bf16", "fp16"], "mixed_precision must be 'bf16' or 'fp16'"
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
         init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
         self.accelerator = Accelerator(
             kwargs_handlers=[ddp_kwargs, init_kwargs],
             gradient_accumulation_steps=args.gradient_accumulation_steps,
+            mixed_precision=self.mixed_precision,
             **kwargs,
         )
         self.gradient_accumulation_steps = args.gradient_accumulation_steps
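
Passing mixed_precision through to Accelerator is the standard Accelerate mixed-precision pattern (the accelerate>=1.12.0 bump in pyproject.toml pins a recent release of that library). A minimal, CPU-friendly sketch of the same pattern outside RecIS; the tiny model and optimizer here are placeholders, not anything from this repository:

import torch
from accelerate import Accelerator

# With mixed_precision set, prepare() and autocast() insert the dtype
# casts; for "fp16", accelerator.backward() also applies loss scaling.
accelerator = Accelerator(mixed_precision="bf16")
model = torch.nn.Linear(8, 1)            # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = accelerator.prepare(model, optimizer)

x = torch.randn(4, 8, device=accelerator.device)
with accelerator.autocast():
    loss = model(x).square().mean()
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()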
@@ -559,7 +565,8 @@ def _train_step(self, data, epoch, metrics):
         self.dense_optimizer.zero_grad()
         if self.sparse_optimizer is not None:
             self.sparse_optimizer.zero_grad()
-        loss = self.model(data)
+        with self.accelerator.autocast():
+            loss = self.model(data)
         metrics.update(epoch=epoch)
         metrics.update(loss=loss)
         metrics.update(get_global_metrics())
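
accelerator.autocast() is a no-op when mixed_precision is None, so full-precision training is unchanged by this hunk. For intuition, a rough plain-PyTorch analogue of what the context manager does for the forward pass — an illustration, not Accelerate's actual internals:

import torch

model = torch.nn.Linear(8, 1)   # placeholder model
x = torch.randn(4, 8)

# Ops inside the region run in the reduced dtype; parameters (and
# their gradients) stay in fp32.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    loss = model(x).square().mean()
loss.backward()  # backward runs outside the autocast region

Running backward outside the autocast block is the recommended PyTorch pattern; under "fp16" a loss scaler would also be needed, which Accelerate handles inside accelerator.backward().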
