trainers.py
from typing import Any, Dict, Union

import ignite.distributed as idist
import torch
from ignite.engine import DeterministicEngine, Engine, Events
from ignite.metrics.metric import Metric
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim.optimizer import Optimizer
from torch.utils.data import DistributedSampler, Sampler


def setup_trainer(
    config: Any,
    model: nn.Module,
    optimizer: Optimizer,
    loss_fn: nn.Module,
    device: Union[str, torch.device],
    train_sampler: Sampler,
) -> Union[Engine, DeterministicEngine]:
    # Gradient scaler for mixed-precision training; a no-op when AMP is disabled.
    scaler = GradScaler(enabled=config.use_amp)

    def train_function(engine: Union[Engine, DeterministicEngine], batch: Any):
        # Move the tokenized batch to the target device.
        input_ids = batch["input_ids"].to(
            device, non_blocking=True, dtype=torch.long
        )
        attention_mask = batch["attention_mask"].to(
            device, non_blocking=True, dtype=torch.long
        )
        token_type_ids = batch["token_type_ids"].to(
            device, non_blocking=True, dtype=torch.long
        )
        labels = (
            batch["label"]
            .view(-1, 1)
            .to(device, non_blocking=True, dtype=torch.float)
        )

        model.train()

        # Forward pass and loss under autocast (mixed precision when enabled).
        with autocast(enabled=config.use_amp):
            y_pred = model(input_ids, attention_mask, token_type_ids)
            loss = loss_fn(y_pred, labels)

        # Backward pass and optimizer step through the gradient scaler.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        metric = {"train_loss": loss.item()}
        engine.state.metrics = metric
        return metric

    #::: if(it.deterministic) { :::#
    trainer = DeterministicEngine(train_function)
    #::: } else { :::#
    trainer = Engine(train_function)
    #::: } :::#

    # Set the epoch on the distributed sampler so shuffling stays consistent
    # across processes.
    @trainer.on(Events.EPOCH_STARTED)
    def set_epoch():
        if idist.get_world_size() > 1 and isinstance(
            train_sampler, DistributedSampler
        ):
            train_sampler.set_epoch(trainer.state.epoch - 1)

    return trainer


def setup_evaluator(
    config: Any,
    model: nn.Module,
    metrics: Dict[str, Metric],
    device: Union[str, torch.device],
):
    @torch.no_grad()
    def evaluate_function(engine: Engine, batch: Any):
        model.eval()

        # Move the tokenized batch to the target device.
        input_ids = batch["input_ids"].to(
            device, non_blocking=True, dtype=torch.long
        )
        attention_mask = batch["attention_mask"].to(
            device, non_blocking=True, dtype=torch.long
        )
        token_type_ids = batch["token_type_ids"].to(
            device, non_blocking=True, dtype=torch.long
        )
        labels = (
            batch["label"]
            .view(-1, 1)
            .to(device, non_blocking=True, dtype=torch.float)
        )

        # Forward pass only; predictions and targets are returned for the attached metrics.
        with autocast(enabled=config.use_amp):
            output = model(input_ids, attention_mask, token_type_ids)

        return output, labels

    evaluator = Engine(evaluate_function)

    # Attach the configured metrics to the evaluator engine.
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    return evaluator
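

# Usage sketch (illustrative, not part of the original file): in the generated
# project these helpers are typically wired together in the entry script along
# the lines below. The names `config`, `model`, `optimizer`, `loss_fn`,
# `metrics`, `train_loader`, and `eval_loader` are assumed to be constructed
# elsewhere and are only placeholders here.
#
#     device = idist.device()
#     trainer = setup_trainer(
#         config, model, optimizer, loss_fn, device, train_loader.sampler
#     )
#     evaluator = setup_evaluator(config, model, metrics, device)
#
#     @trainer.on(Events.EPOCH_COMPLETED)
#     def _run_validation():
#         evaluator.run(eval_loader)
#
#     trainer.run(train_loader, max_epochs=config.max_epochs)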