Prototype references #7220

Closed
126 commits
c19f838
use prototype transforms in classification reference
pmeier Aug 17, 2022
7b7602e
cleanup
pmeier Aug 17, 2022
4ea2aaf
Merge branch 'main' into prototype-references/classification
pmeier Aug 18, 2022
4990b89
move WrapIntoFeatures into transforms module
pmeier Aug 18, 2022
1cfa965
Merge branch 'main' into prototype-references/classification
pmeier Aug 18, 2022
ca4c5a7
[skip ci] add p=1.0 to CutMix and MixUp
pmeier Aug 18, 2022
a2e24e1
Merge branch 'main' into prototype-references/classification
pmeier Aug 23, 2022
693795e
[skip ci]
pmeier Aug 23, 2022
69f5299
Merge branch 'main' into prototype-references/classification
pmeier Aug 24, 2022
fe96a54
use prototype transforms in detection references
pmeier Aug 24, 2022
0f06516
Merge branch 'main' into prototype-references/classification
pmeier Aug 24, 2022
148885d
Merge branch 'main' into prototype-references/classification
pmeier Aug 26, 2022
6fd5e50
[skip ci]
pmeier Aug 26, 2022
6edb7f4
Merge branch 'main' into prototype-references/classification
pmeier Aug 30, 2022
6fcffb2
[skip ci]
pmeier Aug 30, 2022
4b68e2f
[skip ci] Merge branch 'main' into prototype-references/classification
pmeier Aug 30, 2022
4d73fe7
[skip ci] Merge branch 'main' into prototype-references/classification
pmeier Aug 31, 2022
7cb08d5
[skip ci] fix scripts
pmeier Sep 1, 2022
e51791d
[skip ci] Merge branch 'main' into prototype-references/classification
pmeier Sep 1, 2022
3e5e064
[skip ci] Merge branch 'main' into prototype-references/classification
pmeier Sep 1, 2022
ec120ff
[skip ci] Merge branch 'main' into prototype-references/classification
pmeier Sep 1, 2022
f993926
Merge branch 'main' into prototype-references/classification
datumbox Sep 5, 2022
6167038
[skip ci] Merge branch 'main' into prototype-references/classification
pmeier Sep 6, 2022
4699e55
[skip ci] Merge branch 'prototype-references/classification' of https…
pmeier Sep 6, 2022
e2e459d
Merge branch 'main' into prototype-references/classification
datumbox Sep 7, 2022
aa7a655
[skip ci] Merge branch 'main' into prototype-references/classification
pmeier Sep 7, 2022
cb02041
Merge branch 'prototype-references/classification' of https://github.…
pmeier Sep 7, 2022
a98c05d
[SKIP CI] CircleCI
pmeier Sep 7, 2022
49e653f
[skip ci]
pmeier Sep 7, 2022
fcd37d9
[skip ci] Merge branch 'main' into prototype-references/classification
pmeier Sep 8, 2022
05be06d
Merge branch 'main' into prototype-references/classification
pmeier Sep 12, 2022
9459b0a
Merge branch 'main' into prototype-references/classification
pmeier Sep 13, 2022
6c90b3a
Merge branch 'main' into prototype-references/classification
pmeier Sep 13, 2022
2eccb84
Merge branch 'main' into prototype-references/classification
pmeier Sep 14, 2022
8df9043
update segmentation references
pmeier Sep 13, 2022
99e6c36
[skip ci]
pmeier Sep 14, 2022
47772ac
Merge branch 'main' into prototype-references/classification
pmeier Sep 14, 2022
94ac15d
[skip ci]
pmeier Sep 14, 2022
51307b7
[skip ci] fix workaround
pmeier Sep 14, 2022
9dad6e0
only wrap segmentation mask
pmeier Sep 14, 2022
f5f1716
fix pretrained weights test only
pmeier Sep 14, 2022
2aefd09
[skip ci]
pmeier Sep 14, 2022
a2893a1
Restore get_dimensions
datumbox Sep 14, 2022
8df0cf4
Merge branch 'main' into prototype-references/classification
pmeier Sep 21, 2022
e912976
fix segmentation transforms
pmeier Sep 21, 2022
2e7e168
[skip ci]
pmeier Sep 21, 2022
74ecb49
Merge branch 'prototype-references/classification' of https://github.…
pmeier Sep 21, 2022
585c64a
fix mask rewrapping
pmeier Sep 21, 2022
93d7a32
[skip ci]
pmeier Sep 21, 2022
5a311b3
Merge branch 'main' into prototype-references/classification
datumbox Sep 21, 2022
aac24c1
Merge branch 'main' into prototype-references/classification
datumbox Sep 23, 2022
766af6c
Fix merge issue
datumbox Sep 23, 2022
cb6c90e
Tensor Backend + antialiasing=True
datumbox Sep 23, 2022
e9c480e
Switch to view to reshape to avoid incompatibilities with size/stride
datumbox Sep 23, 2022
3894efb
Merge branch 'main' into prototype-references/classification
datumbox Sep 25, 2022
6ef4d82
Cherrypick PR #6642
datumbox Sep 25, 2022
5a1de52
Merge branch 'main' into prototype-references/classification
datumbox Sep 26, 2022
c6950ae
Merge branch 'main'
pmeier Oct 10, 2022
b59beae
Merge branch 'main' into prototype-references/classification
pmeier Oct 10, 2022
906428a
Merge branch 'main' into prototype-references/classification
pmeier Oct 10, 2022
758de46
[skip ci] add support for video_classification
pmeier Oct 10, 2022
a0895c1
Merge branch 'prototype-references/classification' of https://github.…
pmeier Oct 10, 2022
2bd4291
Merge branch 'main' into prototype-references/classification
datumbox Oct 11, 2022
669b1ba
Restoring original reference transforms so that test can run
datumbox Oct 11, 2022
591a773
Adding AA, Random Erase, MixUp/CutMix and a different resize/crop str…
datumbox Oct 11, 2022
0db3ce2
Merge branch 'main' into prototype-references/classification
datumbox Oct 11, 2022
9e95b78
image_size to spatial_size
datumbox Oct 13, 2022
4f3b593
Merge branch 'main' into prototype-references/classification
datumbox Oct 14, 2022
a364b15
Merge branch 'main' into prototype-references/classification
datumbox Oct 14, 2022
00d1b9b
Update the RandomShortestSize behaviour on Video presets.
datumbox Oct 14, 2022
ef3dc55
Fix ToDtype transform to accept dictionaries.
datumbox Oct 14, 2022
711128c
Merge branch 'main' into prototype-references/classification
datumbox Oct 14, 2022
25c4664
Fix issue with collate and audio using Philip's proposal.
datumbox Oct 14, 2022
091948e
Fix linter
datumbox Oct 14, 2022
6b23587
Fix ToDtype parameters.
datumbox Oct 14, 2022
eb37f8f
Wrapping id into a no-op.
datumbox Oct 14, 2022
bb468ba
Define `_Feature` in the dict.
datumbox Oct 14, 2022
5f8d233
Merge branch 'main' into prototype-references/classification
datumbox Oct 14, 2022
5928876
Handling hot-encoded tensors in `accuracy`
datumbox Oct 14, 2022
b63e607
Handle ConvertBCHWtoCBHW interactions with mixup/cutmix.
datumbox Oct 14, 2022
ab141f9
Merge branch 'main' into prototype-references/classification
datumbox Oct 14, 2022
d5f1532
Add Permute Transform.
datumbox Oct 14, 2022
707190c
Merge branch 'main' into prototype-references/classification
datumbox Oct 19, 2022
598542c
Merge branch 'main' into prototype-references/classification
datumbox Oct 21, 2022
6a0a32c
Switch to `TransposeDimensions`
datumbox Oct 21, 2022
a59f995
Merge branch 'main' into prototype-references/classification
datumbox Oct 26, 2022
f72f5b2
Merge branch 'main' into prototype-references/classification
datumbox Oct 27, 2022
9d0a0a3
Fix linter.
datumbox Oct 27, 2022
7c41f0c
Merge branch 'main' into prototype-references/classification
datumbox Oct 31, 2022
87031f1
Merge branch 'main' into prototype-references/classification
datumbox Nov 4, 2022
7c5da3a
Merge branch 'main' into prototype-references/classification
datumbox Nov 4, 2022
d8b5202
Fix method location.
datumbox Nov 4, 2022
959af2d
Fixing minor bug
datumbox Nov 7, 2022
d435378
Merge branch 'main' into prototype-references/classification
datumbox Nov 15, 2022
bda072d
Merge branch 'main' into prototype-references/classification
datumbox Nov 16, 2022
8f07159
Convert to floats at the beginning.
datumbox Nov 17, 2022
8344ce9
Revert "Convert to floats at the beginning."
datumbox Nov 17, 2022
8b53036
Switch to PIL backend
datumbox Nov 17, 2022
c7f2ac8
Revert "Switch to PIL backend"
datumbox Nov 17, 2022
f205f1e
Merge branch 'main' of github.com:pytorch/vision into prototype-refer…
NicolasHug Feb 9, 2023
81a4c0f
Clf ref: remove usage of Label and use MixUp/CutMix from reference (n…
NicolasHug Feb 9, 2023
04248ae
Merge branch 'main' of github.com:pytorch/vision into proto_reference…
NicolasHug Feb 9, 2023
5786940
Merge branch 'main' of github.com:pytorch/vision into proto_reference…
NicolasHug Feb 10, 2023
da10e99
Update detection reference
NicolasHug Feb 10, 2023
1586691
Revert changes made to segmentation and video
NicolasHug Feb 10, 2023
4040da4
remove unused stuff
NicolasHug Feb 10, 2023
6fcc4bc
Merge branch 'main' of github.com:pytorch/vision into proto_reference…
NicolasHug Feb 13, 2023
8988c22
Add --backend param to clasif references + remove unneeded patches
NicolasHug Feb 13, 2023
8c25a3c
Merge branch 'main' of github.com:pytorch/vision into proto_reference…
NicolasHug Feb 13, 2023
8d4f662
Merge branch 'main' of github.com:pytorch/vision into proto_reference…
NicolasHug Feb 13, 2023
f5e5e40
Add --backend param to detection references
NicolasHug Feb 13, 2023
4929c02
clean up
NicolasHug Feb 13, 2023
fd96e03
Cleanup again
NicolasHug Feb 13, 2023
4d1c6ce
Merge branch 'main' of github.com:pytorch/vision into proto_reference…
NicolasHug Feb 13, 2023
f04df31
Some fixes and todos to detection ref
NicolasHug Feb 13, 2023
a9a8f8b
Merge branch 'main' of github.com:pytorch/vision into proto_reference…
NicolasHug Feb 16, 2023
170ed2a
cleanup classif
NicolasHug Feb 16, 2023
9f1738b
cleanup detection
NicolasHug Feb 16, 2023
035ccd7
oops
NicolasHug Feb 16, 2023
831aacc
Merge branch 'main' of github.com:pytorch/vision into proto_reference…
NicolasHug Feb 21, 2023
b5e3b91
Disable warnings
NicolasHug Feb 21, 2023
c00a181
Add keypoint support
NicolasHug Feb 22, 2023
5147d8b
Add segmentation
NicolasHug Feb 22, 2023
a6ed2d3
Merge branch 'main' of github.com:pytorch/vision into proto_reference…
NicolasHug Apr 11, 2023
6e023fc
Fix updated name
NicolasHug Apr 11, 2023
e4de74b
Return needed image_id
NicolasHug Apr 11, 2023
71 changes: 50 additions & 21 deletions references/classification/presets.py
@@ -1,6 +1,5 @@
 import torch
-from torchvision.transforms import autoaugment, transforms
-from torchvision.transforms.functional import InterpolationMode
+import torchvision.transforms.v2 as transforms


 class ClassificationPresetTrain:
@@ -10,29 +9,44 @@ def __init__(
         crop_size,
         mean=(0.485, 0.456, 0.406),
         std=(0.229, 0.224, 0.225),
-        interpolation=InterpolationMode.BILINEAR,
+        interpolation=transforms.InterpolationMode.BILINEAR,
         hflip_prob=0.5,
         auto_augment_policy=None,
         ra_magnitude=9,
         augmix_severity=3,
         random_erase_prob=0.0,
+        backend="pil",
     ):
-        trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
+        trans = []
+
+        backend = backend.lower()
+        if backend == "datapoint":
+            trans.append(transforms.ToImageTensor())
+        elif backend == "tensor":
+            trans.append(transforms.PILToTensor())
+        else:
+            assert backend == "pil"
+
+        trans.append(transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
         if hflip_prob > 0:
-            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
+            trans.append(transforms.RandomHorizontalFlip(p=hflip_prob))
         if auto_augment_policy is not None:
             if auto_augment_policy == "ra":
-                trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
+                trans.append(transforms.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
             elif auto_augment_policy == "ta_wide":
-                trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
+                trans.append(transforms.TrivialAugmentWide(interpolation=interpolation))
             elif auto_augment_policy == "augmix":
-                trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
+                trans.append(transforms.AugMix(interpolation=interpolation, severity=augmix_severity))
             else:
-                aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
-                trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
+                aa_policy = transforms.AutoAugmentPolicy(auto_augment_policy)
+                trans.append(transforms.AutoAugment(policy=aa_policy, interpolation=interpolation))
+
+        if backend == "pil":
+            # Note: we could also just use pure tensors?
+            trans.append(transforms.ToImageTensor())
+
         trans.extend(
             [
-                transforms.PILToTensor(),
                 transforms.ConvertImageDtype(torch.float),
                 transforms.Normalize(mean=mean, std=std),
             ]
@@ -54,18 +68,33 @@ def __init__(
         resize_size=256,
         mean=(0.485, 0.456, 0.406),
         std=(0.229, 0.224, 0.225),
-        interpolation=InterpolationMode.BILINEAR,
+        interpolation=transforms.InterpolationMode.BILINEAR,
+        backend="pil",
     ):
+        trans = []

-        self.transforms = transforms.Compose(
-            [
-                transforms.Resize(resize_size, interpolation=interpolation),
-                transforms.CenterCrop(crop_size),
-                transforms.PILToTensor(),
-                transforms.ConvertImageDtype(torch.float),
-                transforms.Normalize(mean=mean, std=std),
-            ]
-        )
+        backend = backend.lower()
+        if backend == "datapoint":
+            trans.append(transforms.ToImageTensor())
+        elif backend == "tensor":
+            trans.append(transforms.PILToTensor())
+        else:
+            assert backend == "pil"
+
+        trans += [
+            transforms.Resize(resize_size, interpolation=interpolation, antialias=True),
+            transforms.CenterCrop(crop_size),
+        ]
+
+        if backend == "pil":
+            trans.append(transforms.ToImageTensor())
+
+        trans += [
+            transforms.ConvertImageDtype(torch.float),
+            transforms.Normalize(mean=mean, std=std),
+        ]
+
+        self.transforms = transforms.Compose(trans)

     def __call__(self, img):
         return self.transforms(img)
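Review note: the eval preset converts the image either before or after the geometric transforms, depending on `backend`. The sketch below only traces that ordering with step names as plain strings, so it runs without torchvision. The function name `conversion_plan` is hypothetical, and it raises `ValueError` where the diff uses `assert`, which is a swapped-in choice rather than what the PR does.

```python
from typing import List


def conversion_plan(backend: str) -> List[str]:
    """Trace the order of steps the eval preset in this diff appears to build.

    Hypothetical helper: step names are strings, not real transform objects.
    """
    backend = backend.lower()
    steps: List[str] = []

    # "datapoint" and "tensor" convert up front, so Resize/CenterCrop see tensors.
    if backend == "datapoint":
        steps.append("ToImageTensor")
    elif backend == "tensor":
        steps.append("PILToTensor")
    elif backend != "pil":
        # Assumption: fail loudly here; the diff uses `assert backend == "pil"`.
        raise ValueError(f"backend must be 'pil', 'tensor' or 'datapoint', got {backend!r}")

    steps += ["Resize", "CenterCrop"]

    # "pil" keeps PIL images through the geometric steps and converts afterwards.
    if backend == "pil":
        steps.append("ToImageTensor")

    steps += ["ConvertImageDtype", "Normalize"]
    return steps


for name in ("PIL", "tensor", "datapoint"):
    print(name, conversion_plan(name))
```

In all three cases exactly one conversion step runs, and `ConvertImageDtype` and `Normalize` always come last.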
17 changes: 12 additions & 5 deletions references/classification/train.py
@@ -4,15 +4,17 @@
 import warnings

 import presets
+from sampler import RASampler
+import utils  # usort: skip
+
 import torch
 import torch.utils.data
 import torchvision
+
 import transforms
-import utils
-from sampler import RASampler
 from torch import nn
 from torch.utils.data.dataloader import default_collate
-from torchvision.transforms.functional import InterpolationMode
+from torchvision.transforms.v2 import InterpolationMode


 def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
@@ -143,6 +145,7 @@ def load_data(traindir, valdir, args):
                 random_erase_prob=random_erase_prob,
                 ra_magnitude=ra_magnitude,
                 augmix_severity=augmix_severity,
+                backend=args.backend,
             ),
         )
         if args.cache_dataset:
@@ -163,12 +166,15 @@ def load_data(traindir, valdir, args):
             preprocessing = weights.transforms()
         else:
             preprocessing = presets.ClassificationPresetEval(
-                crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
+                crop_size=val_crop_size,
+                resize_size=val_resize_size,
+                interpolation=interpolation,
+                backend=args.backend,
             )

         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
-            preprocessing,
+            transform=preprocessing,
         )
         if args.cache_dataset:
             print(f"Saving dataset_test to {cache_path}")
@@ -507,6 +513,7 @@ def get_args_parser(add_help=True):
         "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
     )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+    parser.add_argument("--backend", default="PIL", type=str, help="PIL, tensor or datapoint - case insensitive")
     return parser


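Review note: `--backend` is accepted as a free-form string and only lower-cased and checked inside the presets. One possible alternative, sketched below with the standard `argparse` module only, is to normalize and validate at parse time. `build_parser` is a hypothetical stand-in for the script's `get_args_parser`, and using `type=str.lower` with `choices` is a suggestion, not what the diff does.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical stand-in for get_args_parser(); only the --backend flag is shown.
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument(
        "--backend",
        default="pil",
        type=str.lower,  # argparse applies `type` before checking `choices`
        choices=("pil", "tensor", "datapoint"),
        help="PIL, tensor or datapoint - case insensitive",
    )
    return parser


args = build_parser().parse_args(["--backend", "Tensor"])
print(args.backend)  # tensor
```

With this shape an unsupported value is rejected by the parser with a usage error, before any preset is constructed.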
65 changes: 17 additions & 48 deletions references/detection/coco_utils.py
@@ -1,30 +1,12 @@
-import copy
 import os

 import torch
 import torch.utils.data
 import torchvision
-import transforms as T
+
 from pycocotools import mask as coco_mask
 from pycocotools.coco import COCO
-
-
-class FilterAndRemapCocoCategories:
-    def __init__(self, categories, remap=True):
-        self.categories = categories
-        self.remap = remap
-
-    def __call__(self, image, target):
-        anno = target["annotations"]
-        anno = [obj for obj in anno if obj["category_id"] in self.categories]
-        if not self.remap:
-            target["annotations"] = anno
-            return image, target
-        anno = copy.deepcopy(anno)
-        for obj in anno:
-            obj["category_id"] = self.categories.index(obj["category_id"])
-        target["annotations"] = anno
-        return image, target
+from torchvision.datasets import wrap_dataset_for_transforms_v2


 def convert_coco_poly_to_mask(segmentations, height, width):
def convert_coco_poly_to_mask(segmentations, height, width):
Expand All @@ -44,8 +26,10 @@ def convert_coco_poly_to_mask(segmentations, height, width):
return masks


# TODO: Is this still needed?
class ConvertCocoPolysToMask:
def __call__(self, image, target):
def __call__(self, sample):
image, target = sample
w, h = image.size

image_id = target["image_id"]
@@ -126,10 +110,10 @@ def _has_valid_annotation(anno):
             return True
         return False

-    if not isinstance(dataset, torchvision.datasets.CocoDetection):
-        raise TypeError(
-            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
-        )
+    # if not isinstance(dataset, torchvision.datasets.CocoDetection):
+    #     raise TypeError(
+    #         f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
+    #     )
     ids = []
     for ds_idx, img_id in enumerate(dataset.ids):
         ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
@@ -201,50 +185,35 @@ def get_coco_api_from_dataset(dataset):
             break
         if isinstance(dataset, torch.utils.data.Subset):
             dataset = dataset.dataset
-    if isinstance(dataset, torchvision.datasets.CocoDetection):
+    # TODO: hmmmmm
+    if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
+        getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
+    ):
         return dataset.coco
     return convert_to_coco_api(dataset)


-class CocoDetection(torchvision.datasets.CocoDetection):
-    def __init__(self, img_folder, ann_file, transforms):
-        super().__init__(img_folder, ann_file)
-        self._transforms = transforms
-
-    def __getitem__(self, idx):
-        img, target = super().__getitem__(idx)
-        image_id = self.ids[idx]
-        target = dict(image_id=image_id, annotations=target)
-        if self._transforms is not None:
-            img, target = self._transforms(img, target)
-        return img, target
+# TODO: Maybe not critical but the wrapper doesn't work on sub-classes


 def get_coco(root, image_set, transforms, mode="instances"):
     anno_file_template = "{}_{}2017.json"
     PATHS = {
         "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
         "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))),
-        # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
     }

-    t = [ConvertCocoPolysToMask()]
-
-    if transforms is not None:
-        t.append(transforms)
-    transforms = T.Compose(t)
-
     img_folder, ann_file = PATHS[image_set]
     img_folder = os.path.join(root, img_folder)
     ann_file = os.path.join(root, ann_file)

-    dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
+    dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
+    # TODO: need to update target_keys to handle masks for segmentation!
+    dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"})

     if image_set == "train":
         dataset = _coco_remove_images_without_annotations(dataset)

-    # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])
-
     return dataset


6 changes: 4 additions & 2 deletions references/detection/engine.py
@@ -26,7 +26,9 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc

     for images, targets in metric_logger.log_every(data_loader, print_freq, header):
         images = list(image.to(device) for image in images)
-        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+        for t in targets:
+            assert t["labels"].shape[0] == t["boxes"].shape[0], f"{t['labels'].shape} {t['boxes'].shape}"
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
         with torch.cuda.amp.autocast(enabled=scaler is not None):
             loss_dict = model(images, targets)
             losses = sum(loss for loss in loss_dict.values())
@@ -97,7 +99,7 @@ def evaluate(model, data_loader, device):
         outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
         model_time = time.time() - model_time

-        res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
+        res = {target["image_id"]: output for target, output in zip(targets, outputs)}
         evaluator_time = time.time()
         coco_evaluator.update(res)
         evaluator_time = time.time() - evaluator_time
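Review note: both `engine.py` hunks suggest that, with the wrapped dataset, `image_id` reaches the loop as a plain Python value rather than a tensor, so only tensor entries of a target can be moved to the device. The sketch below illustrates that guard without requiring torch. `FakeTensor` and `move_tensors` are hypothetical stand-ins, and in the real script the type checked would be `torch.Tensor`.

```python
class FakeTensor:
    """Stand-in for torch.Tensor (assumption): it only remembers its device."""

    def __init__(self, data, device="cpu"):
        self.data = data
        self.device = device

    def to(self, device):
        # Like Tensor.to, return a new object instead of mutating in place.
        return FakeTensor(self.data, device)


def move_tensors(target, device, tensor_type=FakeTensor):
    """Move tensor-like values to `device` and pass everything else through unchanged."""
    return {k: v.to(device) if isinstance(v, tensor_type) else v for k, v in target.items()}


target = {
    "boxes": FakeTensor([[0, 0, 10, 10]]),
    "labels": FakeTensor([1]),
    "image_id": 42,  # plain int, so it has no .to() and no .item()
}
moved = move_tensors(target, "cuda:0")
print(moved["boxes"].device, moved["image_id"])
```

Without the `isinstance` guard, the integer `image_id` would raise `AttributeError` on `.to(device)`.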
5 changes: 4 additions & 1 deletion references/detection/group_by_aspect_ratio.py
@@ -164,7 +164,10 @@ def compute_aspect_ratios(dataset, indices=None):
     if hasattr(dataset, "get_height_and_width"):
         return _compute_aspect_ratios_custom_dataset(dataset, indices)

-    if isinstance(dataset, torchvision.datasets.CocoDetection):
+    # TODO: hmmmmm
+    if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
+        getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
+    ):
         return _compute_aspect_ratios_coco_dataset(dataset, indices)

     if isinstance(dataset, torchvision.datasets.VOCDetection):
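Review note: the same `isinstance(...) or isinstance(getattr(dataset, "_dataset", None), ...)` check now appears in both `coco_utils.py` and `group_by_aspect_ratio.py`, each under a `# TODO: hmmmmm`. One way to avoid the duplication would be a small shared predicate. The sketch below uses stand-in classes so it runs without torchvision. `is_dataset_of_type`, `CocoLike` and `Wrapper` are hypothetical names, and `_dataset` is simply the private attribute the diff already probes.

```python
class CocoLike:
    """Stand-in for torchvision.datasets.CocoDetection (assumption)."""


class Wrapper:
    """Stand-in for a wrapped dataset that keeps the original under `_dataset`."""

    def __init__(self, dataset):
        self._dataset = dataset


def is_dataset_of_type(dataset, dataset_type, attr="_dataset"):
    """True if `dataset`, or the object stored under `attr`, is a `dataset_type`."""
    return isinstance(dataset, dataset_type) or isinstance(getattr(dataset, attr, None), dataset_type)


plain = CocoLike()
wrapped = Wrapper(plain)
print(is_dataset_of_type(plain, CocoLike))      # True
print(is_dataset_of_type(wrapped, CocoLike))    # True
print(is_dataset_of_type(object(), CocoLike))   # False
```

Both call sites could then pass `torchvision.datasets.CocoDetection` as `dataset_type`, keeping the reliance on the private attribute in one place.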