Skip to content

Commit 47ae092

Browse files
authored
Add transforms and presets for optical flow models (#5026)
1 parent 4dd8b5c commit 47ae092

File tree

2 files changed

+325
-0
lines changed

2 files changed

+325
-0
lines changed

references/optical_flow/presets.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import torch
2+
import transforms as T
3+
4+
5+
class OpticalFlowPresetEval(torch.nn.Module):
6+
def __init__(self):
7+
super().__init__()
8+
9+
self.transforms = T.Compose(
10+
[
11+
T.PILToTensor(),
12+
T.ConvertImageDtype(torch.float32),
13+
T.Normalize(mean=0.5, std=0.5), # map [0, 1] into [-1, 1]
14+
T.ValidateModelInput(),
15+
]
16+
)
17+
18+
def forward(self, img1, img2, flow, valid):
19+
return self.transforms(img1, img2, flow, valid)
20+
21+
22+
class OpticalFlowPresetTrain(torch.nn.Module):
23+
def __init__(
24+
self,
25+
# RandomResizeAndCrop params
26+
crop_size,
27+
min_scale=-0.2,
28+
max_scale=0.5,
29+
stretch_prob=0.8,
30+
# AsymmetricColorJitter params
31+
brightness=0.4,
32+
contrast=0.4,
33+
saturation=0.4,
34+
hue=0.5 / 3.14,
35+
# Random[H,V]Flip params
36+
asymmetric_jitter_prob=0.2,
37+
do_flip=True,
38+
):
39+
super().__init__()
40+
41+
transforms = [
42+
T.PILToTensor(),
43+
T.AsymmetricColorJitter(
44+
brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
45+
),
46+
T.RandomResizeAndCrop(
47+
crop_size=crop_size, min_scale=min_scale, max_scale=max_scale, stretch_prob=stretch_prob
48+
),
49+
]
50+
51+
if do_flip:
52+
transforms += [T.RandomHorizontalFlip(p=0.5), T.RandomVerticalFlip(p=0.1)]
53+
54+
transforms += [
55+
T.ConvertImageDtype(torch.float32),
56+
T.Normalize(mean=0.5, std=0.5), # map [0, 1] into [-1, 1]
57+
T.RandomErasing(max_erase=2),
58+
T.MakeValidFlowMask(),
59+
T.ValidateModelInput(),
60+
]
61+
self.transforms = T.Compose(transforms)
62+
63+
def forward(self, img1, img2, flow, valid):
64+
return self.transforms(img1, img2, flow, valid)

references/optical_flow/transforms.py

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
import torch
2+
import torchvision.transforms as T
3+
import torchvision.transforms.functional as F
4+
5+
6+
class ValidateModelInput(torch.nn.Module):
7+
# Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects
8+
def forward(self, img1, img2, flow, valid_flow_mask):
9+
10+
assert all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None)
11+
assert all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None)
12+
13+
assert img1.shape == img2.shape
14+
h, w = img1.shape[-2:]
15+
if flow is not None:
16+
assert flow.shape == (2, h, w)
17+
if valid_flow_mask is not None:
18+
assert valid_flow_mask.shape == (h, w)
19+
assert valid_flow_mask.dtype == torch.bool
20+
21+
return img1, img2, flow, valid_flow_mask
22+
23+
24+
class MakeValidFlowMask(torch.nn.Module):
25+
# This transform generates a valid_flow_mask if it doesn't exist.
26+
# The flow is considered valid if ||flow||_inf < threshold
27+
# This is a noop for Kitti and HD1K which already come with a built-in flow mask.
28+
def __init__(self, threshold=1000):
29+
super().__init__()
30+
self.threshold = threshold
31+
32+
def forward(self, img1, img2, flow, valid_flow_mask):
33+
if flow is not None and valid_flow_mask is None:
34+
valid_flow_mask = (flow.abs() < self.threshold).all(axis=0)
35+
return img1, img2, flow, valid_flow_mask
36+
37+
38+
class ConvertImageDtype(torch.nn.Module):
39+
def __init__(self, dtype):
40+
super().__init__()
41+
self.dtype = dtype
42+
43+
def forward(self, img1, img2, flow, valid_flow_mask):
44+
img1 = F.convert_image_dtype(img1, dtype=self.dtype)
45+
img2 = F.convert_image_dtype(img2, dtype=self.dtype)
46+
47+
img1 = img1.contiguous()
48+
img2 = img2.contiguous()
49+
50+
return img1, img2, flow, valid_flow_mask
51+
52+
53+
class Normalize(torch.nn.Module):
54+
def __init__(self, mean, std):
55+
super().__init__()
56+
self.mean = mean
57+
self.std = std
58+
59+
def forward(self, img1, img2, flow, valid_flow_mask):
60+
img1 = F.normalize(img1, mean=self.mean, std=self.std)
61+
img2 = F.normalize(img2, mean=self.mean, std=self.std)
62+
63+
return img1, img2, flow, valid_flow_mask
64+
65+
66+
class PILToTensor(torch.nn.Module):
67+
# Converts all inputs to tensors
68+
# Technically the flow and the valid mask are numpy arrays, not PIL images, but we keep that naming
69+
# for consistency with the rest, e.g. the segmentation reference.
70+
def forward(self, img1, img2, flow, valid_flow_mask):
71+
img1 = F.pil_to_tensor(img1)
72+
img2 = F.pil_to_tensor(img2)
73+
if flow is not None:
74+
flow = torch.from_numpy(flow)
75+
if valid_flow_mask is not None:
76+
valid_flow_mask = torch.from_numpy(valid_flow_mask)
77+
78+
return img1, img2, flow, valid_flow_mask
79+
80+
81+
class AsymmetricColorJitter(T.ColorJitter):
82+
# p determines the proba of doing asymmertric vs symmetric color jittering
83+
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0.2):
84+
super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
85+
self.p = p
86+
87+
def forward(self, img1, img2, flow, valid_flow_mask):
88+
89+
if torch.rand(1) < self.p:
90+
# asymmetric: different transform for img1 and img2
91+
img1 = super().forward(img1)
92+
img2 = super().forward(img2)
93+
else:
94+
# symmetric: same transform for img1 and img2
95+
batch = torch.stack([img1, img2])
96+
batch = super().forward(batch)
97+
img1, img2 = batch[0], batch[1]
98+
99+
return img1, img2, flow, valid_flow_mask
100+
101+
102+
class RandomErasing(T.RandomErasing):
103+
# This only erases img2, and with an extra max_erase param
104+
# This max_erase is needed because in the RAFT training ref does:
105+
# 0 erasing with .5 proba
106+
# 1 erase with .25 proba
107+
# 2 erase with .25 proba
108+
# and there's no accurate way to achieve this otherwise.
109+
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False, max_erase=1):
110+
super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace)
111+
self.max_erase = max_erase
112+
assert self.max_erase > 0
113+
114+
def forward(self, img1, img2, flow, valid_flow_mask):
115+
if torch.rand(1) > self.p:
116+
return img1, img2, flow, valid_flow_mask
117+
118+
for _ in range(torch.randint(self.max_erase, size=(1,)).item()):
119+
x, y, h, w, v = self.get_params(img2, scale=self.scale, ratio=self.ratio, value=[self.value])
120+
img2 = F.erase(img2, x, y, h, w, v, self.inplace)
121+
122+
return img1, img2, flow, valid_flow_mask
123+
124+
125+
class RandomHorizontalFlip(T.RandomHorizontalFlip):
126+
def forward(self, img1, img2, flow, valid_flow_mask):
127+
if torch.rand(1) > self.p:
128+
return img1, img2, flow, valid_flow_mask
129+
130+
img1 = F.hflip(img1)
131+
img2 = F.hflip(img2)
132+
flow = F.hflip(flow) * torch.tensor([-1, 1])[:, None, None]
133+
if valid_flow_mask is not None:
134+
valid_flow_mask = F.hflip(valid_flow_mask)
135+
return img1, img2, flow, valid_flow_mask
136+
137+
138+
class RandomVerticalFlip(T.RandomVerticalFlip):
139+
def forward(self, img1, img2, flow, valid_flow_mask):
140+
if torch.rand(1) > self.p:
141+
return img1, img2, flow, valid_flow_mask
142+
143+
img1 = F.vflip(img1)
144+
img2 = F.vflip(img2)
145+
flow = F.vflip(flow) * torch.tensor([1, -1])[:, None, None]
146+
if valid_flow_mask is not None:
147+
valid_flow_mask = F.vflip(valid_flow_mask)
148+
return img1, img2, flow, valid_flow_mask
149+
150+
151+
class RandomResizeAndCrop(torch.nn.Module):
152+
# This transform will resize the input with a given proba, and then crop it.
153+
# These are the reversed operations of the built-in RandomResizedCrop,
154+
# although the order of the operations doesn't matter too much: resizing a
155+
# crop would give the same result as cropping a resized image, up to
156+
# interpolation artifact at the borders of the output.
157+
#
158+
# The reason we don't rely on RandomResizedCrop is because of a significant
159+
# difference in the parametrization of both transforms, in particular,
160+
# because of the way the random parameters are sampled in both transforms,
161+
# which leads to fairly different resuts (and different epe). For more details see
162+
# https://github.com/pytorch/vision/pull/5026/files#r762932579
163+
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, stretch_prob=0.8):
164+
super().__init__()
165+
self.crop_size = crop_size
166+
self.min_scale = min_scale
167+
self.max_scale = max_scale
168+
self.stretch_prob = stretch_prob
169+
self.resize_prob = 0.8
170+
self.max_stretch = 0.2
171+
172+
def forward(self, img1, img2, flow, valid_flow_mask):
173+
# randomly sample scale
174+
h, w = img1.shape[-2:]
175+
# Note: in original code, they use + 1 instead of + 8 for sparse datasets (e.g. Kitti)
176+
# It shouldn't matter much
177+
min_scale = max((self.crop_size[0] + 8) / h, (self.crop_size[1] + 8) / w)
178+
179+
scale = 2 ** torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()
180+
scale_x = scale
181+
scale_y = scale
182+
if torch.rand(1) < self.stretch_prob:
183+
scale_x *= 2 ** torch.empty(1, dtype=torch.float32).uniform_(-self.max_stretch, self.max_stretch).item()
184+
scale_y *= 2 ** torch.empty(1, dtype=torch.float32).uniform_(-self.max_stretch, self.max_stretch).item()
185+
186+
scale_x = max(scale_x, min_scale)
187+
scale_y = max(scale_y, min_scale)
188+
189+
new_h, new_w = round(h * scale_y), round(w * scale_x)
190+
191+
if torch.rand(1).item() < self.resize_prob:
192+
# rescale the images
193+
img1 = F.resize(img1, size=(new_h, new_w))
194+
img2 = F.resize(img2, size=(new_h, new_w))
195+
if valid_flow_mask is None:
196+
flow = F.resize(flow, size=(new_h, new_w))
197+
flow = flow * torch.tensor([scale_x, scale_y])[:, None, None]
198+
else:
199+
flow, valid_flow_mask = self._resize_sparse_flow(
200+
flow, valid_flow_mask, scale_x=scale_x, scale_y=scale_y
201+
)
202+
203+
# Note: For sparse datasets (Kitti), the original code uses a "margin"
204+
# See e.g. https://github.com/princeton-vl/RAFT/blob/master/core/utils/augmentor.py#L220:L220
205+
# We don't, not sure it matters much
206+
y0 = torch.randint(0, img1.shape[1] - self.crop_size[0], size=(1,)).item()
207+
x0 = torch.randint(0, img1.shape[2] - self.crop_size[1], size=(1,)).item()
208+
209+
img1 = F.crop(img1, y0, x0, self.crop_size[0], self.crop_size[1])
210+
img2 = F.crop(img2, y0, x0, self.crop_size[0], self.crop_size[1])
211+
flow = F.crop(flow, y0, x0, self.crop_size[0], self.crop_size[1])
212+
if valid_flow_mask is not None:
213+
valid_flow_mask = F.crop(valid_flow_mask, y0, x0, self.crop_size[0], self.crop_size[1])
214+
215+
return img1, img2, flow, valid_flow_mask
216+
217+
def _resize_sparse_flow(self, flow, valid_flow_mask, scale_x=1.0, scale_y=1.0):
218+
# This resizes both the flow and the valid_flow_mask mask (which is assumed to be reasonably sparse)
219+
# There are as-many non-zero values in the original flow as in the resized flow (up to OOB)
220+
# So for example if scale_x = scale_y = 2, the sparsity of the output flow is multiplied by 4
221+
222+
h, w = flow.shape[-2:]
223+
224+
h_new = int(round(h * scale_y))
225+
w_new = int(round(w * scale_x))
226+
flow_new = torch.zeros(size=[2, h_new, w_new], dtype=flow.dtype)
227+
valid_new = torch.zeros(size=[h_new, w_new], dtype=valid_flow_mask.dtype)
228+
229+
jj, ii = torch.meshgrid(torch.arange(w), torch.arange(h), indexing="xy")
230+
231+
ii_valid, jj_valid = ii[valid_flow_mask], jj[valid_flow_mask]
232+
233+
ii_valid_new = torch.round(ii_valid.to(float) * scale_y).to(torch.long)
234+
jj_valid_new = torch.round(jj_valid.to(float) * scale_x).to(torch.long)
235+
236+
within_bounds_mask = (0 <= ii_valid_new) & (ii_valid_new < h_new) & (0 <= jj_valid_new) & (jj_valid_new < w_new)
237+
238+
ii_valid = ii_valid[within_bounds_mask]
239+
jj_valid = jj_valid[within_bounds_mask]
240+
ii_valid_new = ii_valid_new[within_bounds_mask]
241+
jj_valid_new = jj_valid_new[within_bounds_mask]
242+
243+
valid_flow_new = flow[:, ii_valid, jj_valid]
244+
valid_flow_new[0] *= scale_x
245+
valid_flow_new[1] *= scale_y
246+
247+
flow_new[:, ii_valid_new, jj_valid_new] = valid_flow_new
248+
valid_new[ii_valid_new, jj_valid_new] = 1
249+
250+
return flow_new, valid_new
251+
252+
253+
class Compose(torch.nn.Module):
254+
def __init__(self, transforms):
255+
super().__init__()
256+
self.transforms = transforms
257+
258+
def forward(self, img1, img2, flow, valid_flow_mask):
259+
for t in self.transforms:
260+
img1, img2, flow, valid_flow_mask = t(img1, img2, flow, valid_flow_mask)
261+
return img1, img2, flow, valid_flow_mask

0 commit comments

Comments
 (0)