train.py
import os
import warnings
import builtins
from typing import Dict, Any

import torch
import torch.multiprocessing as mp
import torch.distributed as dist

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig

from core.train import get_tracking_datasets
from core.train.train_val import SiamABC_train_val
from core.utils import prepare_experiment, create_logger

logger = create_logger(__name__)

warnings.filterwarnings("ignore")


def train(gpu, ngpus_per_node, config: Dict[str, Any]) -> None:
    # suppress printing if not master
    if config["ddp"] and gpu != 0:
        def print_pass(*args):
            pass
        builtins.print = print_pass

    if gpu is not None:
        print("Using GPU - {} for training".format(gpu))
if config["ddp"]:
if config["dist_url"] == "env://" and config["rank"] == -1:
config["rank"] = int(os.environ["RANK"])
if config["ddp"]:
# mp distributed training, rank needs to be the global rank among all the processes
config["rank"] = config["rank"] * ngpus_per_node + gpu
        dist.init_process_group(backend=config["dist_backend"], init_method=config["dist_url"],
                                world_size=config["world_size"], rank=config["rank"])
        torch.distributed.barrier()

    model = instantiate(config["model"])
    print(model)

    train_dataset, val_dataset = get_tracking_datasets(config)

    trainer = SiamABC_train_val(model=model, config=config, train=train_dataset, val=val_dataset,
                                ngpus_per_node=ngpus_per_node, gpu=gpu)
    train_loss, val_ios = trainer.train_network()


@hydra.main(config_name="SiamABC_tracker", config_path="core/config")
def run_experiment(hydra_config: DictConfig) -> None:
    config = prepare_experiment(hydra_config)
    logger.info("Experiment dir %s", config["experiment"]["folder"])

    save_path = os.path.join(config["experiment"]["folder"], config["experiment"]["name"])
    os.makedirs(save_path, exist_ok=True)

    if config["dist_url"] == "env://" and config["world_size"] == -1:
        # with env:// initialization, the world size is read from the WORLD_SIZE environment variable
        config["world_size"] = int(os.environ["WORLD_SIZE"])

    # use DDP whenever more than one process is requested or it is enabled explicitly
    config["ddp"] = config["world_size"] > 1 or config["ddp"]

    # launch logic adapted from SimSiam
    ngpus_per_node = torch.cuda.device_count()
    if config["ddp"] and ngpus_per_node > 1:
        # since there are ngpus_per_node processes per node, the total
        # world_size needs to be adjusted accordingly
        config["world_size"] = ngpus_per_node * config["world_size"]
        # use torch.multiprocessing.spawn to launch one train() worker per GPU;
        # spawn passes the local GPU index as the first argument
        mp.spawn(train, nprocs=ngpus_per_node, args=(ngpus_per_node, config))
    else:
        # single process: call the train() worker directly
        train(config["gpus"][0], ngpus_per_node, config)

if __name__ == "__main__":
    run_experiment()
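Note on the spawn call above: torch.multiprocessing.spawn injects the local process index (used here as the GPU id) as the first positional argument of the worker, and only the remaining arguments are supplied through args, which is why train(gpu, ngpus_per_node, config) takes gpu first while the call site passes just two values. A minimal, standalone sketch of that behavior (not part of train.py; no GPUs required):

import torch.multiprocessing as mp

def _worker(local_rank, nprocs, message):
    # local_rank is filled in by mp.spawn with values 0 .. nprocs-1
    print(f"worker {local_rank}/{nprocs}: {message}")

if __name__ == "__main__":
    # spawns two processes; each receives its index followed by the args tuple
    mp.spawn(_worker, nprocs=2, args=(2, "hello"))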