|
| 1 | +import torch |
| 2 | +import torchvision.transforms as T |
| 3 | +import torchvision.transforms.functional as F |
| 4 | + |
| 5 | + |
| 6 | +class ValidateModelInput(torch.nn.Module): |
| 7 | + # Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects |
| 8 | + def forward(self, img1, img2, flow, valid_flow_mask): |
| 9 | + |
| 10 | + assert all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None) |
| 11 | + assert all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None) |
| 12 | + |
| 13 | + assert img1.shape == img2.shape |
| 14 | + h, w = img1.shape[-2:] |
| 15 | + if flow is not None: |
| 16 | + assert flow.shape == (2, h, w) |
| 17 | + if valid_flow_mask is not None: |
| 18 | + assert valid_flow_mask.shape == (h, w) |
| 19 | + assert valid_flow_mask.dtype == torch.bool |
| 20 | + |
| 21 | + return img1, img2, flow, valid_flow_mask |
| 22 | + |
| 23 | + |
| 24 | +class MakeValidFlowMask(torch.nn.Module): |
| 25 | + # This transform generates a valid_flow_mask if it doesn't exist. |
| 26 | + # The flow is considered valid if ||flow||_inf < threshold |
| 27 | + # This is a noop for Kitti and HD1K which already come with a built-in flow mask. |
| 28 | + def __init__(self, threshold=1000): |
| 29 | + super().__init__() |
| 30 | + self.threshold = threshold |
| 31 | + |
| 32 | + def forward(self, img1, img2, flow, valid_flow_mask): |
| 33 | + if flow is not None and valid_flow_mask is None: |
| 34 | + valid_flow_mask = (flow.abs() < self.threshold).all(axis=0) |
| 35 | + return img1, img2, flow, valid_flow_mask |
| 36 | + |
| 37 | + |
| 38 | +class ConvertImageDtype(torch.nn.Module): |
| 39 | + def __init__(self, dtype): |
| 40 | + super().__init__() |
| 41 | + self.dtype = dtype |
| 42 | + |
| 43 | + def forward(self, img1, img2, flow, valid_flow_mask): |
| 44 | + img1 = F.convert_image_dtype(img1, dtype=self.dtype) |
| 45 | + img2 = F.convert_image_dtype(img2, dtype=self.dtype) |
| 46 | + |
| 47 | + img1 = img1.contiguous() |
| 48 | + img2 = img2.contiguous() |
| 49 | + |
| 50 | + return img1, img2, flow, valid_flow_mask |
| 51 | + |
| 52 | + |
| 53 | +class Normalize(torch.nn.Module): |
| 54 | + def __init__(self, mean, std): |
| 55 | + super().__init__() |
| 56 | + self.mean = mean |
| 57 | + self.std = std |
| 58 | + |
| 59 | + def forward(self, img1, img2, flow, valid_flow_mask): |
| 60 | + img1 = F.normalize(img1, mean=self.mean, std=self.std) |
| 61 | + img2 = F.normalize(img2, mean=self.mean, std=self.std) |
| 62 | + |
| 63 | + return img1, img2, flow, valid_flow_mask |
| 64 | + |
| 65 | + |
| 66 | +class PILToTensor(torch.nn.Module): |
| 67 | + # Converts all inputs to tensors |
| 68 | + # Technically the flow and the valid mask are numpy arrays, not PIL images, but we keep that naming |
| 69 | + # for consistency with the rest, e.g. the segmentation reference. |
| 70 | + def forward(self, img1, img2, flow, valid_flow_mask): |
| 71 | + img1 = F.pil_to_tensor(img1) |
| 72 | + img2 = F.pil_to_tensor(img2) |
| 73 | + if flow is not None: |
| 74 | + flow = torch.from_numpy(flow) |
| 75 | + if valid_flow_mask is not None: |
| 76 | + valid_flow_mask = torch.from_numpy(valid_flow_mask) |
| 77 | + |
| 78 | + return img1, img2, flow, valid_flow_mask |
| 79 | + |
| 80 | + |
| 81 | +class AsymmetricColorJitter(T.ColorJitter): |
| 82 | + # p determines the proba of doing asymmertric vs symmetric color jittering |
| 83 | + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0.2): |
| 84 | + super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) |
| 85 | + self.p = p |
| 86 | + |
| 87 | + def forward(self, img1, img2, flow, valid_flow_mask): |
| 88 | + |
| 89 | + if torch.rand(1) < self.p: |
| 90 | + # asymmetric: different transform for img1 and img2 |
| 91 | + img1 = super().forward(img1) |
| 92 | + img2 = super().forward(img2) |
| 93 | + else: |
| 94 | + # symmetric: same transform for img1 and img2 |
| 95 | + batch = torch.stack([img1, img2]) |
| 96 | + batch = super().forward(batch) |
| 97 | + img1, img2 = batch[0], batch[1] |
| 98 | + |
| 99 | + return img1, img2, flow, valid_flow_mask |
| 100 | + |
| 101 | + |
| 102 | +class RandomErasing(T.RandomErasing): |
| 103 | + # This only erases img2, and with an extra max_erase param |
| 104 | + # This max_erase is needed because in the RAFT training ref does: |
| 105 | + # 0 erasing with .5 proba |
| 106 | + # 1 erase with .25 proba |
| 107 | + # 2 erase with .25 proba |
| 108 | + # and there's no accurate way to achieve this otherwise. |
| 109 | + def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False, max_erase=1): |
| 110 | + super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace) |
| 111 | + self.max_erase = max_erase |
| 112 | + assert self.max_erase > 0 |
| 113 | + |
| 114 | + def forward(self, img1, img2, flow, valid_flow_mask): |
| 115 | + if torch.rand(1) > self.p: |
| 116 | + return img1, img2, flow, valid_flow_mask |
| 117 | + |
| 118 | + for _ in range(torch.randint(self.max_erase, size=(1,)).item()): |
| 119 | + x, y, h, w, v = self.get_params(img2, scale=self.scale, ratio=self.ratio, value=[self.value]) |
| 120 | + img2 = F.erase(img2, x, y, h, w, v, self.inplace) |
| 121 | + |
| 122 | + return img1, img2, flow, valid_flow_mask |
| 123 | + |
| 124 | + |
| 125 | +class RandomHorizontalFlip(T.RandomHorizontalFlip): |
| 126 | + def forward(self, img1, img2, flow, valid_flow_mask): |
| 127 | + if torch.rand(1) > self.p: |
| 128 | + return img1, img2, flow, valid_flow_mask |
| 129 | + |
| 130 | + img1 = F.hflip(img1) |
| 131 | + img2 = F.hflip(img2) |
| 132 | + flow = F.hflip(flow) * torch.tensor([-1, 1])[:, None, None] |
| 133 | + if valid_flow_mask is not None: |
| 134 | + valid_flow_mask = F.hflip(valid_flow_mask) |
| 135 | + return img1, img2, flow, valid_flow_mask |
| 136 | + |
| 137 | + |
| 138 | +class RandomVerticalFlip(T.RandomVerticalFlip): |
| 139 | + def forward(self, img1, img2, flow, valid_flow_mask): |
| 140 | + if torch.rand(1) > self.p: |
| 141 | + return img1, img2, flow, valid_flow_mask |
| 142 | + |
| 143 | + img1 = F.vflip(img1) |
| 144 | + img2 = F.vflip(img2) |
| 145 | + flow = F.vflip(flow) * torch.tensor([1, -1])[:, None, None] |
| 146 | + if valid_flow_mask is not None: |
| 147 | + valid_flow_mask = F.vflip(valid_flow_mask) |
| 148 | + return img1, img2, flow, valid_flow_mask |
| 149 | + |
| 150 | + |
| 151 | +class RandomResizeAndCrop(torch.nn.Module): |
| 152 | + # This transform will resize the input with a given proba, and then crop it. |
| 153 | + # These are the reversed operations of the built-in RandomResizedCrop, |
| 154 | + # although the order of the operations doesn't matter too much: resizing a |
| 155 | + # crop would give the same result as cropping a resized image, up to |
| 156 | + # interpolation artifact at the borders of the output. |
| 157 | + # |
| 158 | + # The reason we don't rely on RandomResizedCrop is because of a significant |
| 159 | + # difference in the parametrization of both transforms, in particular, |
| 160 | + # because of the way the random parameters are sampled in both transforms, |
| 161 | + # which leads to fairly different resuts (and different epe). For more details see |
| 162 | + # https://github.com/pytorch/vision/pull/5026/files#r762932579 |
| 163 | + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, stretch_prob=0.8): |
| 164 | + super().__init__() |
| 165 | + self.crop_size = crop_size |
| 166 | + self.min_scale = min_scale |
| 167 | + self.max_scale = max_scale |
| 168 | + self.stretch_prob = stretch_prob |
| 169 | + self.resize_prob = 0.8 |
| 170 | + self.max_stretch = 0.2 |
| 171 | + |
| 172 | + def forward(self, img1, img2, flow, valid_flow_mask): |
| 173 | + # randomly sample scale |
| 174 | + h, w = img1.shape[-2:] |
| 175 | + # Note: in original code, they use + 1 instead of + 8 for sparse datasets (e.g. Kitti) |
| 176 | + # It shouldn't matter much |
| 177 | + min_scale = max((self.crop_size[0] + 8) / h, (self.crop_size[1] + 8) / w) |
| 178 | + |
| 179 | + scale = 2 ** torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item() |
| 180 | + scale_x = scale |
| 181 | + scale_y = scale |
| 182 | + if torch.rand(1) < self.stretch_prob: |
| 183 | + scale_x *= 2 ** torch.empty(1, dtype=torch.float32).uniform_(-self.max_stretch, self.max_stretch).item() |
| 184 | + scale_y *= 2 ** torch.empty(1, dtype=torch.float32).uniform_(-self.max_stretch, self.max_stretch).item() |
| 185 | + |
| 186 | + scale_x = max(scale_x, min_scale) |
| 187 | + scale_y = max(scale_y, min_scale) |
| 188 | + |
| 189 | + new_h, new_w = round(h * scale_y), round(w * scale_x) |
| 190 | + |
| 191 | + if torch.rand(1).item() < self.resize_prob: |
| 192 | + # rescale the images |
| 193 | + img1 = F.resize(img1, size=(new_h, new_w)) |
| 194 | + img2 = F.resize(img2, size=(new_h, new_w)) |
| 195 | + if valid_flow_mask is None: |
| 196 | + flow = F.resize(flow, size=(new_h, new_w)) |
| 197 | + flow = flow * torch.tensor([scale_x, scale_y])[:, None, None] |
| 198 | + else: |
| 199 | + flow, valid_flow_mask = self._resize_sparse_flow( |
| 200 | + flow, valid_flow_mask, scale_x=scale_x, scale_y=scale_y |
| 201 | + ) |
| 202 | + |
| 203 | + # Note: For sparse datasets (Kitti), the original code uses a "margin" |
| 204 | + # See e.g. https://github.com/princeton-vl/RAFT/blob/master/core/utils/augmentor.py#L220:L220 |
| 205 | + # We don't, not sure it matters much |
| 206 | + y0 = torch.randint(0, img1.shape[1] - self.crop_size[0], size=(1,)).item() |
| 207 | + x0 = torch.randint(0, img1.shape[2] - self.crop_size[1], size=(1,)).item() |
| 208 | + |
| 209 | + img1 = F.crop(img1, y0, x0, self.crop_size[0], self.crop_size[1]) |
| 210 | + img2 = F.crop(img2, y0, x0, self.crop_size[0], self.crop_size[1]) |
| 211 | + flow = F.crop(flow, y0, x0, self.crop_size[0], self.crop_size[1]) |
| 212 | + if valid_flow_mask is not None: |
| 213 | + valid_flow_mask = F.crop(valid_flow_mask, y0, x0, self.crop_size[0], self.crop_size[1]) |
| 214 | + |
| 215 | + return img1, img2, flow, valid_flow_mask |
| 216 | + |
| 217 | + def _resize_sparse_flow(self, flow, valid_flow_mask, scale_x=1.0, scale_y=1.0): |
| 218 | + # This resizes both the flow and the valid_flow_mask mask (which is assumed to be reasonably sparse) |
| 219 | + # There are as-many non-zero values in the original flow as in the resized flow (up to OOB) |
| 220 | + # So for example if scale_x = scale_y = 2, the sparsity of the output flow is multiplied by 4 |
| 221 | + |
| 222 | + h, w = flow.shape[-2:] |
| 223 | + |
| 224 | + h_new = int(round(h * scale_y)) |
| 225 | + w_new = int(round(w * scale_x)) |
| 226 | + flow_new = torch.zeros(size=[2, h_new, w_new], dtype=flow.dtype) |
| 227 | + valid_new = torch.zeros(size=[h_new, w_new], dtype=valid_flow_mask.dtype) |
| 228 | + |
| 229 | + jj, ii = torch.meshgrid(torch.arange(w), torch.arange(h), indexing="xy") |
| 230 | + |
| 231 | + ii_valid, jj_valid = ii[valid_flow_mask], jj[valid_flow_mask] |
| 232 | + |
| 233 | + ii_valid_new = torch.round(ii_valid.to(float) * scale_y).to(torch.long) |
| 234 | + jj_valid_new = torch.round(jj_valid.to(float) * scale_x).to(torch.long) |
| 235 | + |
| 236 | + within_bounds_mask = (0 <= ii_valid_new) & (ii_valid_new < h_new) & (0 <= jj_valid_new) & (jj_valid_new < w_new) |
| 237 | + |
| 238 | + ii_valid = ii_valid[within_bounds_mask] |
| 239 | + jj_valid = jj_valid[within_bounds_mask] |
| 240 | + ii_valid_new = ii_valid_new[within_bounds_mask] |
| 241 | + jj_valid_new = jj_valid_new[within_bounds_mask] |
| 242 | + |
| 243 | + valid_flow_new = flow[:, ii_valid, jj_valid] |
| 244 | + valid_flow_new[0] *= scale_x |
| 245 | + valid_flow_new[1] *= scale_y |
| 246 | + |
| 247 | + flow_new[:, ii_valid_new, jj_valid_new] = valid_flow_new |
| 248 | + valid_new[ii_valid_new, jj_valid_new] = 1 |
| 249 | + |
| 250 | + return flow_new, valid_new |
| 251 | + |
| 252 | + |
| 253 | +class Compose(torch.nn.Module): |
| 254 | + def __init__(self, transforms): |
| 255 | + super().__init__() |
| 256 | + self.transforms = transforms |
| 257 | + |
| 258 | + def forward(self, img1, img2, flow, valid_flow_mask): |
| 259 | + for t in self.transforms: |
| 260 | + img1, img2, flow, valid_flow_mask = t(img1, img2, flow, valid_flow_mask) |
| 261 | + return img1, img2, flow, valid_flow_mask |
0 commit comments