
Commit 9a89f9d

Merge branch 'main' into ra-reps
2 parents e6d4868 + 1efb567


64 files changed: +324, -181 lines

references/classification/README.md

Lines changed: 24 additions & 2 deletions
@@ -125,7 +125,7 @@ torchrun --nproc_per_node=8 train.py\
 ```
 Here `$MODEL` is one of `regnet_x_400mf`, `regnet_x_800mf`, `regnet_x_1_6gf`, `regnet_y_400mf`, `regnet_y_800mf` and `regnet_y_1_6gf`. Please note we used learning rate 0.4 for `regnet_y_400mf` to get the same Acc@1 as [the paper](https://arxiv.org/abs/2003.13678).
 
-### Medium models
+#### Medium models
 ```
 torchrun --nproc_per_node=8 train.py\
     --model $MODEL --epochs 100 --batch-size 64 --wd 0.00005 --lr=0.4\
@@ -134,7 +134,7 @@ torchrun --nproc_per_node=8 train.py\
 ```
 Here `$MODEL` is one of `regnet_x_3_2gf`, `regnet_x_8gf`, `regnet_x_16gf`, `regnet_y_3_2gf` and `regnet_y_8gf`.
 
-### Large models
+#### Large models
 ```
 torchrun --nproc_per_node=8 train.py\
     --model $MODEL --epochs 100 --batch-size 32 --wd 0.00005 --lr=0.2\
@@ -143,6 +143,28 @@ torchrun --nproc_per_node=8 train.py\
 ```
 Here `$MODEL` is one of `regnet_x_32gf`, `regnet_y_16gf` and `regnet_y_32gf`.
 
+### Vision Transformer
+
+#### Base models
+```
+torchrun --nproc_per_node=8 train.py\
+    --model $MODEL --epochs 300 --batch-size 64 --opt adamw --lr 0.003 --wd 0.3\
+    --lr-scheduler cosineannealinglr --lr-warmup-method linear --lr-warmup-epochs 30\
+    --lr-warmup-decay 0.033 --amp --label-smoothing 0.11 --mixup-alpha 0.2 --auto-augment ra\
+    --clip-grad-norm 1 --ra-sampler --cutmix-alpha 1.0 --model-ema
+```
+Here `$MODEL` is one of `vit_b_16` and `vit_b_32`.
+
+#### Large models
+```
+torchrun --nproc_per_node=8 train.py\
+    --model $MODEL --epochs 300 --batch-size 16 --opt adamw --lr 0.003 --wd 0.3\
+    --lr-scheduler cosineannealinglr --lr-warmup-method linear --lr-warmup-epochs 30\
+    --lr-warmup-decay 0.033 --amp --label-smoothing 0.11 --mixup-alpha 0.2 --auto-augment ra\
+    --clip-grad-norm 1 --ra-sampler --cutmix-alpha 1.0 --model-ema
+```
+Here `$MODEL` is one of `vit_l_16` and `vit_l_32`.
+
 ## Mixed precision training
 Automatic Mixed Precision (AMP) training on GPU for PyTorch can be enabled with the [torch.cuda.amp](https://pytorch.org/docs/stable/amp.html?highlight=amp#module-torch.cuda.amp) module.
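The `## Mixed precision training` context above links to `torch.cuda.amp`. For readers unfamiliar with it, a minimal training-step sketch; the tiny model and random inputs are placeholders for illustration, not part of this commit, and a CUDA device is assumed:

```
import torch

device = "cuda"  # torch.cuda.amp requires a CUDA device
model = torch.nn.Linear(8, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()  # rescales the loss to avoid fp16 gradient underflow

for _ in range(3):
    images = torch.randn(4, 8, device=device)
    targets = torch.randint(0, 2, (4,), device=device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # forward pass runs in mixed precision
        loss = loss_fn(model(images), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # unscales gradients before the optimizer step
    scaler.update()
```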
references/optical_flow/README.md

Lines changed: 10 additions & 3 deletions
@@ -10,7 +10,14 @@ training and evaluation scripts to quickly bootstrap research.
 The RAFT large model was trained on Flying Chairs and then on Flying Things.
 Both used 8 A100 GPUs and a batch size of 2 (so effective batch size is 16). The
 rest of the hyper-parameters are exactly the same as the original RAFT training
-recipe from https://github.com/princeton-vl/RAFT.
+recipe from https://github.com/princeton-vl/RAFT. The original recipe trains for
+100000 updates (or steps) on each dataset - this corresponds to about 72 and 20
+epochs on Chairs and Things respectively:
+
+```
+num_epochs = ceil(num_steps / number_of_steps_per_epoch)
+           = ceil(num_steps / (num_samples / effective_batch_size))
+```
 
 ```
 torchrun --nproc_per_node 8 --nnodes 1 train.py \
@@ -21,7 +28,7 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \
     --batch-size 2 \
     --lr 0.0004 \
     --weight-decay 0.0001 \
-    --num-steps 100000 \
+    --epochs 72 \
     --output-dir $chairs_dir
 ```
 
@@ -34,7 +41,7 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \
     --batch-size 2 \
     --lr 0.000125 \
     --weight-decay 0.0001 \
-    --num-steps 100000 \
+    --epochs 20 \
     --freeze-batch-norm \
     --output-dir $things_dir\
     --resume $chairs_dir/$name_chairs.pth
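The `--epochs 72` and `--epochs 20` values above follow from the formula added to the README. A minimal sketch of that arithmetic; the dataset sizes below are approximate assumptions for illustration, not values stated in this commit:

```
from math import ceil

def num_epochs(num_steps, num_samples, effective_batch_size=16):
    # effective batch size = 8 GPUs * per-GPU batch size of 2
    steps_per_epoch = ceil(num_samples / effective_batch_size)
    return ceil(num_steps / steps_per_epoch)

# Approximate dataset sizes (assumptions, for illustration only):
print(num_epochs(100_000, num_samples=22_232))  # Chairs-sized -> 72
print(num_epochs(100_000, num_samples=80_000))  # Things-sized -> 20
```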

references/optical_flow/train.py

Lines changed: 8 additions & 19 deletions
@@ -1,5 +1,6 @@
 import argparse
 import warnings
+from math import ceil
 from pathlib import Path
 
 import torch
@@ -168,7 +169,7 @@ def validate(model, args):
         warnings.warn(f"Can't validate on {val_dataset}, skipping.")
 
 
-def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
+def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
     for data_blob in logger.log_every(train_loader):
 
         optimizer.zero_grad()
@@ -189,13 +190,6 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
         optimizer.step()
         scheduler.step()
 
-        current_step += 1
-
-        if current_step == args.num_steps:
-            return True, current_step
-
-    return False, current_step
-
 
 def main(args):
     utils.setup_ddp(args)
@@ -243,7 +237,8 @@ def main(args):
     scheduler = torch.optim.lr_scheduler.OneCycleLR(
         optimizer=optimizer,
         max_lr=args.lr,
-        total_steps=args.num_steps + 100,
+        epochs=args.epochs,
+        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
         pct_start=0.05,
         cycle_momentum=False,
         anneal_strategy="linear",
@@ -252,26 +247,22 @@ def main(args):
     logger = utils.MetricLogger()
 
     done = False
-    current_epoch = current_step = 0
-    while not done:
+    for current_epoch in range(args.epochs):
         print(f"EPOCH {current_epoch}")
 
         sampler.set_epoch(current_epoch)  # needed, otherwise the data loading order would be the same for all epochs
-        done, current_step = train_one_epoch(
+        train_one_epoch(
             model=model,
             optimizer=optimizer,
             scheduler=scheduler,
             train_loader=train_loader,
             logger=logger,
-            current_step=current_step,
             args=args,
         )
 
         # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
         print(f"Epoch {current_epoch} done. ", logger)
 
-        current_epoch += 1
-
         if args.rank == 0:
             # TODO: Also save the optimizer and scheduler
             torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
@@ -310,10 +301,8 @@ def get_args_parser(add_help=True):
     )
     parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
     parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")
-    # TODO: eventually, it might be preferable to support epochs instead of num_steps.
-    # Keeping it this way for now to reproduce results more easily.
-    parser.add_argument("--num-steps", type=int, default=100000, help="The total number of steps (updates) to train.")
-    parser.add_argument("--batch-size", type=int, default=6)
+    parser.add_argument("--epochs", type=int, default=20, help="The total number of epochs to train.")
+    parser.add_argument("--batch-size", type=int, default=2)
 
     parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
     parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")
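The scheduler change in this diff leans on `OneCycleLR` accepting `epochs` and `steps_per_epoch` in place of `total_steps`, with `scheduler.step()` called once per optimizer update. A self-contained sketch of that accounting (the toy model and sizes are illustrative, not from the commit):

```
from math import ceil

import torch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004)

epochs, num_samples, world_size, batch_size = 2, 33, 8, 2
steps_per_epoch = ceil(num_samples / (world_size * batch_size))  # ceil(33 / 16) = 3

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=0.0004,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,  # internally, total_steps = epochs * steps_per_epoch
    pct_start=0.05,
    cycle_momentum=False,
    anneal_strategy="linear",
)

for _ in range(epochs):
    for _ in range(steps_per_epoch):
        optimizer.step()
        scheduler.step()  # exactly epochs * steps_per_epoch calls in total
```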

test/test_prototype_builtin_datasets.py

Lines changed: 29 additions & 6 deletions
@@ -2,9 +2,11 @@
 
 import builtin_dataset_mocks
 import pytest
+import torch
+from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
 from torch.utils.data.graph import traverse
-from torchdata.datapipes.iter import IterDataPipe
-from torchvision.prototype import datasets, features
+from torchdata.datapipes.iter import IterDataPipe, Shuffler
+from torchvision.prototype import datasets, transforms
 from torchvision.prototype.datasets._api import DEFAULT_DECODER
 from torchvision.prototype.utils._internal import sequence_to_str
 
@@ -88,15 +90,36 @@ def test_decoding(self, dataset, mock_info):
         )
 
     @dataset_parametrization(decoder=DEFAULT_DECODER)
-    def test_at_least_one_feature(self, dataset, mock_info):
-        sample = next(iter(dataset))
-        if not any(isinstance(value, features.Feature) for value in sample.values()):
-            raise AssertionError("The sample contained no feature.")
+    def test_no_vanilla_tensors(self, dataset, mock_info):
+        vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
+        if vanilla_tensors:
+            raise AssertionError(
+                f"The values of key(s) "
+                f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
+            )
+
+    @dataset_parametrization()
+    def test_transformable(self, dataset, mock_info):
+        next(iter(dataset.map(transforms.Identity())))
 
     @dataset_parametrization()
     def test_traversable(self, dataset, mock_info):
         traverse(dataset)
 
+    @dataset_parametrization()
+    @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter), ids=lambda type: type.__name__)
+    def test_has_annotations(self, dataset, mock_info, annotation_dp_type):
+        def scan(graph):
+            for node, sub_graph in graph.items():
+                yield node
+                yield from scan(sub_graph)
+
+        for dp in scan(traverse(dataset)):
+            if type(dp) is annotation_dp_type:
+                break
+        else:
+            raise AssertionError(f"The dataset doesn't comprise a {annotation_dp_type.__name__}() datapipe.")
+
 
 class TestQMNIST:
     @pytest.mark.parametrize(
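The new `test_has_annotations` walks the graph returned by `traverse()`, which maps each datapipe to the sub-graph of its predecessors as a nested dict. Here is the `scan()` helper from the diff, exercised on a toy graph with strings standing in for datapipe instances:

```
def scan(graph):
    # Depth-first walk over traverse()-style nested dicts, yielding every node.
    for node, sub_graph in graph.items():
        yield node
        yield from scan(sub_graph)

toy_graph = {"Shuffler": {"Mapper": {"FileLister": {}}}}
assert list(scan(toy_graph)) == ["Shuffler", "Mapper", "FileLister"]
```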

torchvision/csrc/io/video/video.cpp

Lines changed: 3 additions & 3 deletions
@@ -188,7 +188,7 @@ Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
   c10::Dict<std::string, std::vector<double>> ccMetadata;
   c10::Dict<std::string, std::vector<double>> subsMetadata;
 
-  // calback and metadata defined in struct
+  // callback and metadata defined in struct
   succeeded = decoder.init(params, std::move(callback), &metadata);
   if (succeeded) {
     for (const auto& header : metadata) {
@@ -254,7 +254,7 @@ bool Video::setCurrentStream(std::string stream = "video") {
       numThreads_ // global number of threads
   );
 
-  // calback and metadata defined in Video.h
+  // callback and metadata defined in Video.h
   return (decoder.init(params, std::move(callback), &metadata));
 }
 
@@ -280,7 +280,7 @@ void Video::Seek(double ts, bool fastSeek = false) {
       numThreads_ // global number of threads
   );
 
-  // calback and metadata defined in Video.h
+  // callback and metadata defined in Video.h
   succeeded = decoder.init(params, std::move(callback), &metadata);
   LOG(INFO) << "Decoder init at seek " << succeeded << "\n";
 }

torchvision/csrc/ops/deform_conv2d.cpp

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ at::Tensor deform_conv2d(
     int64_t groups,
     int64_t offset_groups,
     bool use_mask) {
+  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d");
   static auto op = c10::Dispatcher::singleton()
                        .findSchemaOrThrow("torchvision::deform_conv2d", "")
                        .typed<decltype(deform_conv2d)>();

torchvision/csrc/ops/nms.cpp

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ at::Tensor nms(
     const at::Tensor& dets,
     const at::Tensor& scores,
     double iou_threshold) {
+  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.nms.nms");
   static auto op = c10::Dispatcher::singleton()
                        .findSchemaOrThrow("torchvision::nms", "")
                        .typed<decltype(nms)>();

torchvision/csrc/ops/ps_roi_align.cpp

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align(
     int64_t pooled_height,
     int64_t pooled_width,
     int64_t sampling_ratio) {
+  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_align.ps_roi_align");
   static auto op = c10::Dispatcher::singleton()
                        .findSchemaOrThrow("torchvision::ps_roi_align", "")
                        .typed<decltype(ps_roi_align)>();

torchvision/csrc/ops/ps_roi_pool.cpp

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width) {
+  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_pool.ps_roi_pool");
   static auto op = c10::Dispatcher::singleton()
                        .findSchemaOrThrow("torchvision::ps_roi_pool", "")
                        .typed<decltype(ps_roi_pool)>();

torchvision/csrc/ops/roi_align.cpp

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ at::Tensor roi_align(
     bool aligned) // The flag for pixel shift
 // along each axis.
 {
+  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align");
   static auto op = c10::Dispatcher::singleton()
                        .findSchemaOrThrow("torchvision::roi_align", "")
                        .typed<decltype(roi_align)>();

torchvision/csrc/ops/roi_pool.cpp

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ std::tuple<at::Tensor, at::Tensor> roi_pool(
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width) {
+  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_pool.roi_pool");
   static auto op = c10::Dispatcher::singleton()
                        .findSchemaOrThrow("torchvision::roi_pool", "")
                        .typed<decltype(roi_pool)>();

torchvision/datasets/vision.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def __init__(
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
     ) -> None:
-        _log_api_usage_once("datasets", self.__class__.__name__)
+        _log_api_usage_once(self)
         if isinstance(root, torch._six.string_classes):
             root = os.path.expanduser(root)
         self.root = root
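This call-site change, repeated in the model files below, is the Python counterpart of the `C10_LOG_API_USAGE_ONCE` insertions in the C++ ops above: both record a one-shot API-usage event. A hedged sketch of what the simplified helper could look like; the body here is an assumption for illustration, not the actual torchvision implementation:

```
import torch

def _log_api_usage_once(obj):
    # Assumed behavior: derive the event name from the object itself, so
    # call sites can pass `self` instead of hand-written category strings.
    name = f"{obj.__class__.__module__}.{obj.__class__.__name__}"
    torch._C._log_api_usage_once(name)  # same channel as C10_LOG_API_USAGE_ONCE
```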

torchvision/extension.py

Lines changed: 23 additions & 0 deletions
@@ -1,3 +1,8 @@
+import ctypes
+import os
+import sys
+from warnings import warn
+
 import torch
 
 from ._internally_replaced_utils import _get_extension_path
@@ -67,4 +72,22 @@ def _check_cuda_version():
     return _version
 
 
+def _load_library(lib_name):
+    lib_path = _get_extension_path(lib_name)
+    # On Windows Python-3.8+ has `os.add_dll_directory` call,
+    # which is called from _get_extension_path to configure dll search path
+    # Condition below adds a workaround for older versions by
+    # explicitly calling `LoadLibraryExW` with the following flags:
+    #  - LOAD_LIBRARY_SEARCH_DEFAULT_DIRS (0x1000)
+    #  - LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR (0x100)
+    if os.name == "nt" and sys.version_info < (3, 8):
+        _kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
+        if hasattr(_kernel32, "LoadLibraryExW"):
+            _kernel32.LoadLibraryExW(lib_path, None, 0x00001100)
+        else:
+            warn("LoadLibraryExW is missing in kernel32.dll")
+
+    torch.ops.load_library(lib_path)
+
+
 _check_cuda_version()

torchvision/io/_video_opt.py

Lines changed: 2 additions & 3 deletions
@@ -5,12 +5,11 @@
 
 import torch
 
-from .._internally_replaced_utils import _get_extension_path
+from ..extension import _load_library
 
 
 try:
-    lib_path = _get_extension_path("video_reader")
-    torch.ops.load_library(lib_path)
+    _load_library("video_reader")
     _HAS_VIDEO_OPT = True
 except (ImportError, OSError):
     _HAS_VIDEO_OPT = False

torchvision/io/image.py

Lines changed: 5 additions & 5 deletions
@@ -1,15 +1,15 @@
 from enum import Enum
+from warnings import warn
 
 import torch
 
-from .._internally_replaced_utils import _get_extension_path
+from ..extension import _load_library
 
 
 try:
-    lib_path = _get_extension_path("image")
-    torch.ops.load_library(lib_path)
-except (ImportError, OSError):
-    pass
+    _load_library("image")
+except (ImportError, OSError) as e:
+    warn(f"Failed to load image Python extension: {e}")
 
 
 class ImageReadMode(Enum):

torchvision/models/alexnet.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 class AlexNet(nn.Module):
     def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
         super().__init__()
-        _log_api_usage_once("models", self.__class__.__name__)
+        _log_api_usage_once(self)
         self.features = nn.Sequential(
             nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
             nn.ReLU(inplace=True),

torchvision/models/densenet.py

Lines changed: 1 addition & 1 deletion
@@ -163,7 +163,7 @@ def __init__(
     ) -> None:
 
         super().__init__()
-        _log_api_usage_once("models", self.__class__.__name__)
+        _log_api_usage_once(self)
 
         # First convolution
         self.features = nn.Sequential(

torchvision/models/detection/generalized_rcnn.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ class GeneralizedRCNN(nn.Module):
 
     def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
         super().__init__()
-        _log_api_usage_once("models", self.__class__.__name__)
+        _log_api_usage_once(self)
         self.transform = transform
         self.backbone = backbone
         self.rpn = rpn
