|
1 | 1 | import itertools
|
2 | 2 | import re
|
| 3 | +from collections import defaultdict |
3 | 4 |
|
4 | 5 | import numpy as np
|
5 | 6 |
|
@@ -1988,3 +1989,154 @@ def test__transform(self, inpt):
|
1988 | 1989 | assert type(output) is type(inpt)
|
1989 | 1990 | assert output.shape[-4] == num_samples
|
1990 | 1991 | assert output.dtype == inpt.dtype
|
| 1992 | + |
| 1993 | + |
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, int))
@pytest.mark.parametrize("dataset_return_type", (dict, tuple))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
    """Smoke-test a classification-style augmentation pipeline end to end.

    Runs a Compose of the usual classification transforms over every
    supported image container (PIL image / plain tensor / datapoints.Image),
    both label types, and both dict- and tuple-shaped samples, then checks
    that the sample container type, output spatial size, and label value are
    all preserved.
    """
    image = datapoints.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
    if image_type is PIL.Image:
        # to_pil_image expects a CHW image, so drop the leading batch dim.
        image = to_pil_image(image[0])
    elif image_type is torch.Tensor:
        image = image.as_subclass(torch.Tensor)
        assert is_simple_tensor(image)

    label = 1 if label_type is int else torch.tensor([1])

    if dataset_return_type is dict:
        sample = {
            "image": image,
            "label": label,
        }
    else:
        sample = image, label

    t = transforms.Compose(
        [
            transforms.RandomResizedCrop((224, 224)),
            transforms.RandomHorizontalFlip(p=1),
            transforms.RandAugment(),
            transforms.TrivialAugmentWide(),
            transforms.AugMix(),
            transforms.AutoAugment(),
            to_tensor(),
            # TODO: ConvertImageDtype is a pass-through on PIL images, is that
            # intended? This results in a failure if we convert to tensor after
            # it, because the image would still be uint8 which makes Normalize
            # fail.
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
            transforms.RandomErasing(p=1),
        ]
    )

    out = t(sample)

    # Exact container type must be preserved; `is` is the correct comparison
    # for an exact-type check (not `==`, and not isinstance which would let
    # subclasses through).
    assert type(out) is type(sample)

    if dataset_return_type is tuple:
        out_image, out_label = out
    else:
        assert out.keys() == sample.keys()
        out_image, out_label = out.values()

    assert out_image.shape[-2:] == (224, 224)
    assert out_label == label
| 2048 | + |
| 2049 | + |
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, list))
@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
    """Smoke-test the detection reference-training augmentation presets.

    Builds a sample holding an image, per-box labels, XYXY bounding boxes and
    instance masks, pushes it through the preset selected by
    ``data_augmentation``, and checks that output types and the
    box / mask / label counts stay consistent.
    """
    if data_augmentation == "hflip":
        t = [
            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "lsj":
        t = [
            transforms.ScaleJitter(target_size=(1024, 1024), antialias=True),
            # Note: replaced FixedSizeCrop with RandomCrop, because we're
            # leaving FixedSizeCrop in prototype for now, and it expects Label
            # classes which we won't release yet.
            # transforms.FixedSizeCrop(
            #     size=(1024, 1024), fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})
            # ),
            transforms.RandomCrop((1024, 1024), pad_if_needed=True),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "multiscale":
        t = [
            transforms.RandomShortestSize(
                min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333, antialias=True
            ),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "ssd":
        t = [
            transforms.RandomPhotometricDistort(p=1),
            transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})),
            # TODO: put back IoUCrop once we remove its hard requirement for Labels
            # transforms.RandomIoUCrop(),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "ssdlite":
        t = [
            # TODO: put back IoUCrop once we remove its hard requirement for Labels
            # transforms.RandomIoUCrop(),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    else:
        # Fail loudly on an unknown preset instead of hitting a confusing
        # UnboundLocalError on `t` below.
        raise ValueError(f"Unknown data_augmentation preset: {data_augmentation!r}")
    t = transforms.Compose(t)

    num_boxes = 5
    H = W = 250

    image = datapoints.Image(torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8))
    if image_type is PIL.Image:
        # to_pil_image expects a CHW image, so drop the leading batch dim.
        image = to_pil_image(image[0])
    elif image_type is torch.Tensor:
        image = image.as_subclass(torch.Tensor)
        assert is_simple_tensor(image)

    label = torch.randint(0, 10, size=(num_boxes,))
    if label_type is list:
        label = label.tolist()

    # TODO: is the shape of the boxes OK? Should it be (1, num_boxes, 4)?? Same for masks
    boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
    # Turn (x, y, w, h)-style random values into valid XYXY boxes (x2 >= x1,
    # y2 >= y1), clamped to the image bounds.
    boxes[:, 2:] += boxes[:, :2]
    boxes = boxes.clamp(min=0, max=min(H, W))
    boxes = datapoints.BoundingBox(boxes, format="XYXY", spatial_size=(H, W))

    masks = datapoints.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))

    sample = {
        "image": image,
        "label": label,
        "boxes": boxes,
        "masks": masks,
    }

    out = t(sample)

    # ToTensor converts PIL/plain-tensor inputs into simple tensors, but a
    # datapoints.Image input stays an Image; ToImageTensor always yields one.
    if to_tensor is transforms.ToTensor and image_type is not datapoints.Image:
        assert is_simple_tensor(out["image"])
    else:
        assert isinstance(out["image"], datapoints.Image)
    assert isinstance(out["label"], type(sample["label"]))

    # Normalize the label container to a tensor so the shape check below works
    # for both the tensor and list label types.
    out["label"] = torch.tensor(out["label"])
    assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes
0 commit comments