Commit 4d0980d

Merge branch 'main' into gdrive-virus-scan
2 parents 46a5512 + b7c59a0 commit 4d0980d

File tree

8 files changed: +163 additions, -48 deletions

README.rst

Lines changed: 7 additions & 0 deletions
@@ -185,3 +185,10 @@ Disclaimer on Datasets
 This is a utility library that downloads and prepares public datasets. We do not host or distribute these datasets, vouch for their quality or fairness, or claim that you have license to use the dataset. It is your responsibility to determine whether you have permission to use the dataset under the dataset's license.
 
 If you're a dataset owner and wish to update any part of it (description, citation, etc.), or do not want your dataset to be included in this library, please get in touch through a GitHub issue. Thanks for your contribution to the ML community!
+
+Pre-trained Model License
+=========================
+
+The pre-trained models provided in this library may have their own licenses or terms and conditions derived from the dataset used for training. It is your responsibility to determine whether you have permission to use the models for your use case.
+
+More specifically, SWAG models are released under the CC-BY-NC 4.0 license. See `SWAG LICENSE <https://github.com/facebookresearch/SWAG/blob/main/LICENSE>`_ for additional details.

references/classification/train.py

Lines changed: 24 additions & 6 deletions
@@ -229,12 +229,18 @@ def main(args):
 
     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
 
-    if args.norm_weight_decay is None:
-        parameters = [p for p in model.parameters() if p.requires_grad]
-    else:
-        param_groups = torchvision.ops._utils.split_normalization_params(model)
-        wd_groups = [args.norm_weight_decay, args.weight_decay]
-        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
+    custom_keys_weight_decay = []
+    if args.bias_weight_decay is not None:
+        custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
+    if args.transformer_embedding_decay is not None:
+        for key in ["class_token", "position_embedding", "relative_position_bias"]:
+            custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
+    parameters = utils.set_weight_decay(
+        model,
+        args.weight_decay,
+        norm_weight_decay=args.norm_weight_decay,
+        custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
+    )
 
     opt_name = args.opt.lower()
     if opt_name.startswith("sgd"):
@@ -393,6 +399,18 @@ def get_args_parser(add_help=True):
         type=float,
         help="weight decay for Normalization layers (default: None, same value as --wd)",
     )
+    parser.add_argument(
+        "--bias-weight-decay",
+        default=None,
+        type=float,
+        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
+    )
+    parser.add_argument(
+        "--transformer-embedding-decay",
+        default=None,
+        type=float,
+        help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
+    )
     parser.add_argument(
         "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
    )

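For illustration, a minimal sketch of what the new flags produce, assuming a hypothetical toy module in place of a real ViT (`utils` here is references/classification/utils.py, where `set_weight_decay` is added below):

    import torch
    import utils  # references/classification/utils.py

    # Hypothetical stand-in for a ViT: one Linear (weight + bias) plus a
    # class_token parameter, so both custom keys below have a match.
    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)
            self.class_token = torch.nn.Parameter(torch.zeros(1, 1, 4))

    # Roughly what --wd 1e-4 --bias-weight-decay 0 --transformer-embedding-decay 0
    # would build for this toy model (keys that match no parameter simply
    # produce empty groups, which are dropped).
    parameters = utils.set_weight_decay(
        Toy(),
        1e-4,
        norm_weight_decay=None,
        custom_keys_weight_decay=[("bias", 0.0), ("class_token", 0.0)],
    )
    for group in parameters:
        print(len(group["params"]), group["weight_decay"])
    # 1 0.0001   (linear.weight)
    # 1 0.0      (linear.bias)
    # 1 0.0      (class_token)
    optimizer = torch.optim.SGD(parameters, lr=0.1, momentum=0.9)
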
references/classification/utils.py

Lines changed: 63 additions & 0 deletions
@@ -5,6 +5,7 @@
 import os
 import time
 from collections import defaultdict, deque, OrderedDict
+from typing import List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -400,3 +401,65 @@ def reduce_across_processes(val):
     dist.barrier()
     dist.all_reduce(t)
     return t
+
+
+def set_weight_decay(
+    model: torch.nn.Module,
+    weight_decay: float,
+    norm_weight_decay: Optional[float] = None,
+    norm_classes: Optional[List[type]] = None,
+    custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
+):
+    if not norm_classes:
+        norm_classes = [
+            torch.nn.modules.batchnorm._BatchNorm,
+            torch.nn.LayerNorm,
+            torch.nn.GroupNorm,
+            torch.nn.modules.instancenorm._InstanceNorm,
+            torch.nn.LocalResponseNorm,
+        ]
+    norm_classes = tuple(norm_classes)
+
+    params = {
+        "other": [],
+        "norm": [],
+    }
+    params_weight_decay = {
+        "other": weight_decay,
+        "norm": norm_weight_decay,
+    }
+    custom_keys = []
+    if custom_keys_weight_decay is not None:
+        for key, weight_decay in custom_keys_weight_decay:
+            params[key] = []
+            params_weight_decay[key] = weight_decay
+            custom_keys.append(key)
+
+    def _add_params(module, prefix=""):
+        for name, p in module.named_parameters(recurse=False):
+            if not p.requires_grad:
+                continue
+            is_custom_key = False
+            for key in custom_keys:
+                target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
+                if key == target_name:
+                    params[key].append(p)
+                    is_custom_key = True
+                    break
+            if not is_custom_key:
+                if norm_weight_decay is not None and isinstance(module, norm_classes):
+                    params["norm"].append(p)
+                else:
+                    params["other"].append(p)
+
+        for child_name, child_module in module.named_children():
+            child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
+            _add_params(child_module, prefix=child_prefix)
+
+    _add_params(model)
+
+    param_groups = []
+    for key in params:
+        if len(params[key]) > 0:
+            param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
+    return param_groups

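A note on the matching rule inside `_add_params`: a bare key such as "bias" is compared against each parameter's leaf name, while a dotted key is compared against its full "child.path.name". A small sketch, assuming this module is importable as `utils`:

    import torch
    import utils  # references/classification/utils.py

    model = torch.nn.Sequential(
        torch.nn.Linear(8, 8),   # parameters "0.weight", "0.bias"
        torch.nn.LayerNorm(8),   # parameters "1.weight", "1.bias"
    )
    groups = utils.set_weight_decay(
        model,
        1e-4,                       # decay for the default "other" group
        norm_weight_decay=0.0,      # LayerNorm parameters land in "norm"
        custom_keys_weight_decay=[("0.bias", 0.0)],  # dotted: full path only
    )
    for g in groups:
        print(len(g["params"]), g["weight_decay"])
    # 1 0.0001   ("other": 0.weight)
    # 2 0.0      ("norm": 1.weight and 1.bias)
    # 1 0.0      ("0.bias")
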
test/test_extended_models.py

Lines changed: 2 additions & 1 deletion
@@ -115,7 +115,8 @@ def test_schema_meta_validation(model_fn):
                 incorrect_params.append(w)
         else:
             if w.meta.get("num_params") != weights_enum.DEFAULT.meta.get("num_params"):
-                incorrect_params.append(w)
+                if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
+                    incorrect_params.append(w)
         if not w.name.isupper():
             bad_names.append(w)

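The fallback matters for weights whose parameter count legitimately differs from `DEFAULT` (for instance, the SWAG ViT checkpoints below are built for a larger image_size, which presumably changes the position-embedding shape); such weights now pass as long as the meta value matches the real count. The counting idiom itself, on a hypothetical two-layer model:

    import torch

    # Same idiom as the test: total elements across all parameter tensors.
    model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.Linear(5, 2))
    print(sum(p.numel() for p in model.parameters()))  # (10*5 + 5) + (5*2 + 2) = 67
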
test/test_video_reader.py

Lines changed: 1 addition & 1 deletion
@@ -1225,7 +1225,7 @@ def test_invalid_file(self):
 
     @pytest.mark.parametrize("test_video", test_videos.keys())
     @pytest.mark.parametrize("backend", ["video_reader", "pyav"])
-    @pytest.mark.parametrize("start_offset", [0, 1000])
+    @pytest.mark.parametrize("start_offset", [0, 500])
     @pytest.mark.parametrize("end_offset", [3000, None])
     def test_audio_present_pts(self, test_video, backend, start_offset, end_offset):
         """Test if audio frames are returned with pts unit."""

torchvision/io/_video_opt.py

Lines changed: 10 additions & 30 deletions
@@ -423,16 +423,6 @@ def _probe_video_from_memory(
     return info
 
 
-def _convert_to_sec(
-    start_pts: Union[float, Fraction], end_pts: Union[float, Fraction], pts_unit: str, time_base: Fraction
-) -> Tuple[Union[float, Fraction], Union[float, Fraction], str]:
-    if pts_unit == "pts":
-        start_pts = float(start_pts * time_base)
-        end_pts = float(end_pts * time_base)
-        pts_unit = "sec"
-    return start_pts, end_pts, pts_unit
-
-
 def _read_video(
     filename: str,
     start_pts: Union[float, Fraction] = 0,
@@ -452,38 +442,28 @@ def _read_video(
 
     has_video = info.has_video
     has_audio = info.has_audio
-    video_pts_range = (0, -1)
-    video_timebase = default_timebase
-    audio_pts_range = (0, -1)
-    audio_timebase = default_timebase
-    time_base = default_timebase
-
-    if has_video:
-        video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
-        time_base = video_timebase
-
-    if has_audio:
-        audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
-        time_base = time_base if time_base else audio_timebase
-
-    # video_timebase is the default time_base
-    start_pts_sec, end_pts_sec, pts_unit = _convert_to_sec(start_pts, end_pts, pts_unit, time_base)
 
     def get_pts(time_base):
-        start_offset = start_pts_sec
-        end_offset = end_pts_sec
+        start_offset = start_pts
+        end_offset = end_pts
         if pts_unit == "sec":
-            start_offset = int(math.floor(start_pts_sec * (1 / time_base)))
+            start_offset = int(math.floor(start_pts * (1 / time_base)))
             if end_offset != float("inf"):
-                end_offset = int(math.ceil(end_pts_sec * (1 / time_base)))
+                end_offset = int(math.ceil(end_pts * (1 / time_base)))
         if end_offset == float("inf"):
             end_offset = -1
         return start_offset, end_offset
 
+    video_pts_range = (0, -1)
+    video_timebase = default_timebase
     if has_video:
+        video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
         video_pts_range = get_pts(video_timebase)
 
+    audio_pts_range = (0, -1)
+    audio_timebase = default_timebase
     if has_audio:
+        audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
         audio_pts_range = get_pts(audio_timebase)
 
     vframes, aframes, info = _read_video_from_file(

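After the refactor, `get_pts` converts a timestamp in seconds into integer stream pts by dividing by the stream's time base, flooring the start and ceiling the end so the requested window is never shortened, with -1 as the sentinel for "to the end of the stream". A standalone sketch of that arithmetic (`to_pts_window` is a hypothetical helper mirroring `get_pts`, shown with a typical 1/90000 MPEG time base):

    import math
    from fractions import Fraction

    def to_pts_window(start_sec, end_sec, time_base):
        # Floor the start, ceil the end; -1 means "until the end of the stream".
        start_pts = int(math.floor(start_sec * (1 / time_base)))
        if end_sec == float("inf"):
            return start_pts, -1
        return start_pts, int(math.ceil(end_sec * (1 / time_base)))

    print(to_pts_window(1.0, 2.5, Fraction(1, 90000)))           # (90000, 225000)
    print(to_pts_window(0.0, float("inf"), Fraction(1, 90000)))  # (0, -1)
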
torchvision/io/video.py

Lines changed: 0 additions & 7 deletions
@@ -287,13 +287,6 @@ def read_video(
         with av.open(filename, metadata_errors="ignore") as container:
             if container.streams.audio:
                 audio_timebase = container.streams.audio[0].time_base
-            time_base = _video_opt.default_timebase
-            if container.streams.video:
-                time_base = container.streams.video[0].time_base
-            elif container.streams.audio:
-                time_base = container.streams.audio[0].time_base
-            # video_timebase is the default time_base
-            start_pts, end_pts, pts_unit = _video_opt._convert_to_sec(start_pts, end_pts, pts_unit, time_base)
             if container.streams.video:
                 video_frames = _read_from_stream(
                     container,

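With the up-front conversion removed here, the public entry point is unchanged: callers still pass pts_unit to read_video, and each stream reader applies its own time base. A usage sketch (the file path is a placeholder):

    import torchvision

    # Decode video and audio frames between two timestamps given in seconds.
    vframes, aframes, info = torchvision.io.read_video(
        "video.mp4", start_pts=1.0, end_pts=2.5, pts_unit="sec"
    )
    print(vframes.shape, aframes.shape, info)  # (T, H, W, C), (K, L), fps metadata
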
torchvision/models/vision_transformer.py

Lines changed: 56 additions & 3 deletions
@@ -1,7 +1,7 @@
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Callable, List, NamedTuple, Optional
+from typing import Any, Callable, List, NamedTuple, Optional, Sequence
 
 import torch
 import torch.nn as nn
@@ -284,10 +284,21 @@ def _vision_transformer(
     progress: bool,
     **kwargs: Any,
 ) -> VisionTransformer:
-    image_size = kwargs.pop("image_size", 224)
-
     if weights is not None:
         _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+        if isinstance(weights.meta["size"], int):
+            _ovewrite_named_param(kwargs, "image_size", weights.meta["size"])
+        elif isinstance(weights.meta["size"], Sequence):
+            if len(weights.meta["size"]) != 2 or weights.meta["size"][0] != weights.meta["size"][1]:
+                raise ValueError(
+                    f'size: {weights.meta["size"]} is not valid! Currently we only support a 2-dimensional square and width = height'
+                )
+            _ovewrite_named_param(kwargs, "image_size", weights.meta["size"][0])
+        else:
+            raise ValueError(
+                f'weights.meta["size"]: {weights.meta["size"]} is not valid, the type should be either an int or a Sequence[int]'
+            )
+    image_size = kwargs.pop("image_size", 224)
 
     model = VisionTransformer(
         image_size=image_size,
@@ -313,6 +324,14 @@ def _vision_transformer(
     "interpolation": InterpolationMode.BILINEAR,
 }
 
+_COMMON_SWAG_META = {
+    **_COMMON_META,
+    "publication_year": 2022,
+    "recipe": "https://github.com/facebookresearch/SWAG",
+    "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
+    "interpolation": InterpolationMode.BICUBIC,
+}
+
 
 class ViT_B_16_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
@@ -328,6 +347,23 @@ class ViT_B_16_Weights(WeightsEnum):
             "acc@5": 95.318,
         },
     )
+    IMAGENET1K_SWAG_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=384,
+            resize_size=384,
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "num_params": 86859496,
+            "size": (384, 384),
+            "min_size": (384, 384),
+            "acc@1": 85.304,
+            "acc@5": 97.650,
+        },
+    )
     DEFAULT = IMAGENET1K_V1
 
 
@@ -362,6 +398,23 @@ class ViT_L_16_Weights(WeightsEnum):
             "acc@5": 94.638,
         },
     )
+    IMAGENET1K_SWAG_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=512,
+            resize_size=512,
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "num_params": 305174504,
+            "size": (512, 512),
+            "min_size": (512, 512),
+            "acc@1": 88.064,
+            "acc@5": 98.512,
+        },
+    )
     DEFAULT = IMAGENET1K_V1

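Putting the new enum entries to use — a minimal sketch, assuming the multi-weight API in this tree (`vit_b_16` accepts weights=, and `_vision_transformer` now derives image_size from meta["size"], so this builds a 384x384 model):

    from torchvision.models import vit_b_16, ViT_B_16_Weights

    # Load the new SWAG checkpoint; note its CC-BY-NC 4.0 license
    # (see the README change above).
    weights = ViT_B_16_Weights.IMAGENET1K_SWAG_V1
    model = vit_b_16(weights=weights).eval()
    preprocess = weights.transforms()  # 384px resize/crop, bicubic interpolation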