Skip to content

Commit 9ca5397

Browse files
authored
feat(templates): add a image segmentation template (#129)
1 parent c8871ec commit 9ca5397

File tree

11 files changed

+1088
-3
lines changed

11 files changed

+1088
-3
lines changed

src/templates/template-common/config.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ num_workers: 2
66
max_epochs: 2
77
train_epoch_length: 4
88
eval_epoch_length: 4
9-
lr: 0.0001
109
use_amp: false
1110
debug: false
1211

src/templates/template-common/utils.py

-2
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,6 @@ def setup_handlers(
181181

182182
#::: if (it.patience) { :::#
183183
# early stopping
184-
def score_fn(engine: Engine):
185-
return -engine.state.metrics["eval_loss"]
186184

187185
es = EarlyStopping(config.patience, score_fn, trainer)
188186
evaluator.add_event_handler(Events.EPOCH_COMPLETED, es)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Template by Code-Generator
2+
3+
## Getting Started
4+
5+
Install the dependencies with `pip`:
6+
7+
```sh
8+
pip install -r requirements.txt --progress-bar off -U
9+
```
10+
11+
## Training
12+
13+
#::: if (it.use_dist) { :::#
14+
#::: if (it.dist === 'launch') { :::#
15+
#::: if (it.nproc_per_node) { :::#
16+
#::: if (it.nnodes > 1 && it.master_addr && it.master_port) { :::#
17+
18+
### Multi Node, Multi GPU Training (`torch.distributed.launch`) (recommended)
19+
20+
- Execute on master node
21+
22+
```sh
23+
python -m torch.distributed.launch \
24+
--nproc_per_node #:::= nproc_per_node :::# \
25+
--nnodes #:::= it.nnodes :::# \
26+
--node_rank 0 \
27+
--master_addr #:::= it.master_addr :::# \
28+
--master_port #:::= it.master_port :::# \
29+
--use_env main.py \
30+
--backend nccl
31+
```
32+
33+
- Execute on worker nodes
34+
35+
```sh
36+
python -m torch.distributed.launch \
37+
--nproc_per_node #:::= nproc_per_node :::# \
38+
--nnodes #:::= it.nnodes :::# \
39+
--node_rank <node_rank> \
40+
--master_addr #:::= it.master_addr :::# \
41+
--master_port #:::= it.master_port :::# \
42+
--use_env main.py \
43+
--backend nccl
44+
```
45+
46+
#::: } else { :::#
47+
48+
### Multi GPU Training (`torch.distributed.launch`) (recommended)
49+
50+
```sh
51+
python -m torch.distributed.launch \
52+
--nproc_per_node #:::= it.nproc_per_node :::# \
53+
--use_env main.py \
54+
--backend nccl
55+
```
56+
57+
#::: } :::#
58+
#::: } :::#
59+
#::: } :::#
60+
61+
#::: if (it.dist === 'spawn') { :::#
62+
#::: if (it.nproc_per_node) { :::#
63+
#::: if (it.nnodes > 1 && it.master_addr && it.master_port) { :::#
64+
65+
### Multi Node, Multi GPU Training (`torch.multiprocessing.spawn`)
66+
67+
- Execute on master node
68+
69+
```sh
70+
python main.py \
71+
--nproc_per_node #:::= nproc_per_node :::# \
72+
--nnodes #:::= it.nnodes :::# \
73+
--node_rank 0 \
74+
--master_addr #:::= it.master_addr :::# \
75+
--master_port #:::= it.master_port :::# \
76+
--backend nccl
77+
```
78+
79+
- Execute on worker nodes
80+
81+
```sh
82+
python main.py \
83+
--nproc_per_node #:::= nproc_per_node :::# \
84+
--nnodes #:::= it.nnodes :::# \
85+
--node_rank <node_rank> \
86+
--master_addr #:::= it.master_addr :::# \
87+
--master_port #:::= it.master_port :::# \
88+
--backend nccl
89+
```
90+
91+
#::: } else { :::#
92+
93+
### Multi GPU Training (`torch.multiprocessing.spawn`)
94+
95+
```sh
96+
python main.py \
97+
--nproc_per_node #:::= it.nproc_per_node :::# \
98+
--backend nccl
99+
```
100+
101+
#::: } :::#
102+
#::: } :::#
103+
#::: } :::#
104+
105+
#::: } else { :::#
106+
107+
### 1 GPU Training
108+
109+
```sh
110+
python main.py
111+
```
112+
113+
#::: } :::#
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
seed: 666
2+
data_path: ./
3+
train_batch_size: 4
4+
eval_batch_size: 8
5+
num_workers: 2
6+
max_epochs: 2
7+
train_epoch_length: 4
8+
eval_epoch_length: 4
9+
lr: 0.007
10+
use_amp: false
11+
debug: false
12+
accumulation_steps: 4
13+
num_classes: 21
14+
15+
#::: if (it.dist === 'spawn') { :::#
16+
# distributed spawn
17+
nproc_per_node: #:::= it.nproc_per_node :::#
18+
#::: if (it.nnodes) { :::#
19+
# distributed multi node spawn
20+
nnodes: #:::= it.nnodes :::#
21+
#::: if (it.nnodes > 1) { :::#
22+
node_rank: 0
23+
master_addr: #:::= it.master_addr :::#
24+
master_port: #:::= it.master_port :::#
25+
#::: } :::#
26+
#::: } :::#
27+
#::: } :::#
28+
29+
#::: if (it.filename_prefix) { :::#
30+
filename_prefix: #:::= it.filename_prefix :::#
31+
#::: } :::#
32+
33+
#::: if (it.n_saved) { :::#
34+
n_saved: #:::= it.n_saved :::#
35+
#::: } :::#
36+
37+
#::: if (it.save_every_iters) { :::#
38+
save_every_iters: #:::= it.save_every_iters :::#
39+
#::: } :::#
40+
41+
#::: if (it.patience) { :::#
42+
patience: #:::= it.patience :::#
43+
#::: } :::#
44+
45+
#::: if (it.limit_sec) { :::#
46+
limit_sec: #:::= it.limit_sec :::#
47+
#::: } :::#
48+
49+
#::: if (it.output_dir) { :::#
50+
output_dir: #:::= it.output_dir :::#
51+
#::: } :::#
52+
53+
#::: if (it.log_every_iters) { :::#
54+
log_every_iters: #:::= it.log_every_iters :::#
55+
#::: } :::#
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
from argparse import Namespace
2+
3+
import albumentations as A
4+
import cv2
5+
import ignite.distributed as idist
6+
import numpy as np
7+
import torch
8+
from albumentations.pytorch import ToTensorV2 as ToTensor
9+
from ignite.utils import convert_tensor
10+
from PIL import Image
11+
from torch.utils.data import Dataset
12+
from torchvision.datasets.voc import VOCSegmentation
13+
14+
15+
class TransformedDataset(Dataset):
16+
def __init__(self, ds, transform_fn):
17+
assert isinstance(ds, Dataset)
18+
assert callable(transform_fn)
19+
self.ds = ds
20+
self.transform_fn = transform_fn
21+
22+
def __len__(self):
23+
return len(self.ds)
24+
25+
def __getitem__(self, index):
26+
dp = self.ds[index]
27+
return self.transform_fn(**dp)
28+
29+
30+
class VOCSegmentationPIL(VOCSegmentation):
31+
32+
target_names = [
33+
"background",
34+
"aeroplane",
35+
"bicycle",
36+
"bird",
37+
"boat",
38+
"bottle",
39+
"bus",
40+
"car",
41+
"cat",
42+
"chair",
43+
"cow",
44+
"diningtable",
45+
"dog",
46+
"horse",
47+
"motorbike",
48+
"person",
49+
"plant",
50+
"sheep",
51+
"sofa",
52+
"train",
53+
"tv/monitor",
54+
]
55+
56+
def __init__(self, *args, return_meta=False, **kwargs):
57+
super().__init__(*args, **kwargs)
58+
self.return_meta = return_meta
59+
60+
def __getitem__(self, index):
61+
img = np.asarray(Image.open(self.images[index]).convert("RGB"))
62+
assert img is not None, f"Image at '{self.images[index]}' has a problem"
63+
mask = np.asarray(Image.open(self.masks[index]))
64+
65+
if self.return_meta:
66+
return {
67+
"image": img,
68+
"mask": mask,
69+
"meta": {
70+
"index": index,
71+
"image_path": self.images[index],
72+
"mask_path": self.masks[index],
73+
},
74+
}
75+
76+
return {"image": img, "mask": mask}
77+
78+
79+
def setup_data(config: Namespace):
80+
dataset_train = VOCSegmentationPIL(
81+
root=config.data_path, year="2012", image_set="train", download=False
82+
)
83+
dataset_eval = VOCSegmentationPIL(
84+
root=config.data_path, year="2012", image_set="val", download=False
85+
)
86+
87+
val_img_size = 513
88+
train_img_size = 480
89+
90+
mean = (0.485, 0.456, 0.406)
91+
std = (0.229, 0.224, 0.225)
92+
93+
transform_train = A.Compose(
94+
[
95+
A.RandomScale(
96+
scale_limit=(0.0, 1.5), interpolation=cv2.INTER_LINEAR, p=1.0
97+
),
98+
A.PadIfNeeded(
99+
val_img_size, val_img_size, border_mode=cv2.BORDER_CONSTANT
100+
),
101+
A.RandomCrop(train_img_size, train_img_size),
102+
A.HorizontalFlip(),
103+
A.Blur(blur_limit=3),
104+
A.Normalize(mean=mean, std=std),
105+
ignore_mask_boundaries,
106+
ToTensor(),
107+
]
108+
)
109+
110+
transform_eval = A.Compose(
111+
[
112+
A.PadIfNeeded(
113+
val_img_size, val_img_size, border_mode=cv2.BORDER_CONSTANT
114+
),
115+
A.Normalize(mean=mean, std=std),
116+
ignore_mask_boundaries,
117+
ToTensor(),
118+
]
119+
)
120+
121+
dataset_train = TransformedDataset(
122+
dataset_train, transform_fn=transform_train
123+
)
124+
dataset_eval = TransformedDataset(dataset_eval, transform_fn=transform_eval)
125+
126+
dataloader_train = idist.auto_dataloader(
127+
dataset_train,
128+
shuffle=True,
129+
batch_size=config.train_batch_size,
130+
num_workers=config.num_workers,
131+
drop_last=True,
132+
)
133+
dataloader_eval = idist.auto_dataloader(
134+
dataset_eval,
135+
shuffle=False,
136+
batch_size=config.train_batch_size,
137+
num_workers=config.num_workers,
138+
drop_last=False,
139+
)
140+
141+
return dataloader_train, dataloader_eval
142+
143+
144+
def ignore_mask_boundaries(force_apply, **kwargs):
145+
assert "mask" in kwargs, "Input should contain 'mask'"
146+
mask = kwargs["mask"]
147+
mask[mask == 255] = 0
148+
kwargs["mask"] = mask
149+
return kwargs
150+
151+
152+
def denormalize(t, mean, std, max_pixel_value=255):
153+
assert isinstance(t, torch.Tensor), f"{type(t)}"
154+
assert t.ndim == 3
155+
d = t.device
156+
mean = torch.tensor(mean, device=d).unsqueeze(-1).unsqueeze(-1)
157+
std = torch.tensor(std, device=d).unsqueeze(-1).unsqueeze(-1)
158+
tensor = std * t + mean
159+
tensor *= max_pixel_value
160+
return tensor
161+
162+
163+
def prepare_image_mask(batch, device, non_blocking):
164+
x, y = batch["image"], batch["mask"]
165+
x = convert_tensor(x, device, non_blocking=non_blocking)
166+
y = convert_tensor(y, device, non_blocking=non_blocking).long()
167+
return x, y
168+
169+
170+
def download_datasets(data_path):
171+
VOCSegmentation(data_path, image_set="train", download=True)
172+
VOCSegmentation(data_path, image_set="val", download=True)

0 commit comments

Comments
 (0)