Skip to content

Commit 6518372

Browse files
authored
Added PILToTensor and ConvertImageDtype classes in reference scripts (#4495)
* Added PILToTensor and ConvertImageDtype classes in reference scripts
* Addressed review comments
* Fixed TypeError
* Addressed review comment
1 parent 055708d commit 6518372

File tree

4 files changed: 45 additions and 11 deletions (+45 -11 lines changed)

references/detection/presets.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
import torch
2+
13
import transforms as T
24

35

@@ -6,21 +8,24 @@ def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123., 117., 104.)):
68
if data_augmentation == 'hflip':
79
self.transforms = T.Compose([
810
T.RandomHorizontalFlip(p=hflip_prob),
9-
T.ToTensor(),
11+
T.PILToTensor(),
12+
T.ConvertImageDtype(torch.float),
1013
])
1114
elif data_augmentation == 'ssd':
1215
self.transforms = T.Compose([
1316
T.RandomPhotometricDistort(),
1417
T.RandomZoomOut(fill=list(mean)),
1518
T.RandomIoUCrop(),
1619
T.RandomHorizontalFlip(p=hflip_prob),
17-
T.ToTensor(),
20+
T.PILToTensor(),
21+
T.ConvertImageDtype(torch.float),
1822
])
1923
elif data_augmentation == 'ssdlite':
2024
self.transforms = T.Compose([
2125
T.RandomIoUCrop(),
2226
T.RandomHorizontalFlip(p=hflip_prob),
23-
T.ToTensor(),
27+
T.PILToTensor(),
28+
T.ConvertImageDtype(torch.float),
2429
])
2530
else:
2631
raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')

references/detection/transforms.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1+
from typing import List, Tuple, Dict, Optional
2+
13
import torch
24
import torchvision
3-
45
from torch import nn, Tensor
56
from torchvision.transforms import functional as F
67
from torchvision.transforms import transforms as T
7-
from typing import List, Tuple, Dict, Optional
88

99

1010
def _flip_coco_person_keypoints(kps, width):
@@ -52,6 +52,24 @@ def forward(self, image: Tensor,
5252
return image, target
5353

5454

55+
class PILToTensor(nn.Module):
56+
def forward(self, image: Tensor,
57+
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
58+
image = F.pil_to_tensor(image)
59+
return image, target
60+
61+
62+
class ConvertImageDtype(nn.Module):
63+
def __init__(self, dtype: torch.dtype) -> None:
64+
super().__init__()
65+
self.dtype = dtype
66+
67+
def forward(self, image: Tensor,
68+
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
69+
image = F.convert_image_dtype(image, self.dtype)
70+
return image, target
71+
72+
5573
class RandomIoUCrop(nn.Module):
5674
def __init__(self, min_scale: float = 0.3, max_scale: float = 1.0, min_aspect_ratio: float = 0.5,
5775
max_aspect_ratio: float = 2.0, sampler_options: Optional[List[float]] = None, trials: int = 40):

references/segmentation/presets.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import torch
2+
13
import transforms as T
24

35

@@ -11,7 +13,8 @@ def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.4
1113
trans.append(T.RandomHorizontalFlip(hflip_prob))
1214
trans.extend([
1315
T.RandomCrop(crop_size),
14-
T.ToTensor(),
16+
T.PILToTensor(),
17+
T.ConvertImageDtype(torch.float),
1518
T.Normalize(mean=mean, std=std),
1619
])
1720
self.transforms = T.Compose(trans)
@@ -24,7 +27,8 @@ class SegmentationPresetEval:
2427
def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
2528
self.transforms = T.Compose([
2629
T.RandomResize(base_size, base_size),
27-
T.ToTensor(),
30+
T.PILToTensor(),
31+
T.ConvertImageDtype(torch.float),
2832
T.Normalize(mean=mean, std=std),
2933
])
3034

references/segmentation/transforms.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
import numpy as np
2-
from PIL import Image
31
import random
42

3+
import numpy as np
54
import torch
65
from torchvision import transforms as T
76
from torchvision.transforms import functional as F
@@ -75,14 +74,22 @@ def __call__(self, image, target):
7574
return image, target
7675

7776

78-
class ToTensor(object):
77+
class PILToTensor:
7978
def __call__(self, image, target):
8079
image = F.pil_to_tensor(image)
81-
image = F.convert_image_dtype(image)
8280
target = torch.as_tensor(np.array(target), dtype=torch.int64)
8381
return image, target
8482

8583

84+
class ConvertImageDtype:
85+
def __init__(self, dtype):
86+
self.dtype = dtype
87+
88+
def __call__(self, image, target):
89+
image = F.convert_image_dtype(image, self.dtype)
90+
return image, target
91+
92+
8693
class Normalize(object):
8794
def __init__(self, mean, std):
8895
self.mean = mean

0 commit comments

Comments
 (0)