Skip to content

Commit b969cca

Browse files
bjuncekBruno KorbardatumboxNicolasHug
authored
Use Kinetics instead of Kinetics400 in references (#5787) (#5952)
* Dataset creation now supports "new" version of Kinetics dataset * remove unnecessary warning for now * provide kinetics option * new reading somehow doesn't need BHWC to BCHW transform * Addressing minor comments * Adding kinetics deprication warning for the old Kinetics400 class * lint error * Update torchvision/datasets/kinetics.py Co-authored-by: Nicolas Hug <[email protected]> * Updating README * Remove BHWC to BCHW * Put warning back * formatting Co-authored-by: Bruno Korbar <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent 16af667 commit b969cca

File tree

6 files changed

+30
-26
lines changed

6 files changed

+30
-26
lines changed

references/video_classification/README.md

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ We assume the training and validation AVI videos are stored at `/data/kinectics4
1818

1919
Run the training on a single node with 8 GPUs:
2020
```bash
21-
torchrun --nproc_per_node=8 train.py --data-path=/data/kinectics400 --train-dir=train --val-dir=val --batch-size=16 --cache-dataset --sync-bn --amp
21+
torchrun --nproc_per_node=8 train.py --data-path=/data/kinectics400 --kinetics-version="400" --batch-size=16 --cache-dataset --sync-bn --amp
2222
```
2323

2424
**Note:** all our models were trained on 8 nodes with 8 V100 GPUs each for a total of 64 GPUs. Expected training time for 64 GPUs is 24 hours, depending on the storage solution.
@@ -30,5 +30,13 @@ torchrun --nproc_per_node=8 train.py --data-path=/data/kinectics400 --train-dir=
3030

3131

3232
```bash
33-
python train.py --data-path=/data/kinectics400 --train-dir=train --val-dir=val --batch-size=8 --cache-dataset
33+
python train.py --data-path=/data/kinectics400 --kinetics-version="400" --batch-size=8 --cache-dataset
3434
```
35+
36+
37+
### Additional Kinetics versions
38+
39+
Since the original release, additional versions of Kinetics dataset became available (Kinetics 600).
40+
Our training scripts support these versions of dataset as well by setting the `--kinetics-version` parameter to `"600"`.
41+
42+
**Note:** training on Kinetics 600 requires a different set of hyperparameters for optimal performance. We do not provide Kinetics 600 pretrained models.

references/video_classification/presets.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
from torchvision.transforms import transforms
3-
from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW
3+
from transforms import ConvertBCHWtoCBHW
44

55

66
class VideoClassificationPresetTrain:
@@ -14,7 +14,6 @@ def __init__(
1414
hflip_prob=0.5,
1515
):
1616
trans = [
17-
ConvertBHWCtoBCHW(),
1817
transforms.ConvertImageDtype(torch.float32),
1918
transforms.Resize(resize_size),
2019
]
@@ -31,7 +30,6 @@ class VideoClassificationPresetEval:
3130
def __init__(self, *, crop_size, resize_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)):
3231
self.transforms = transforms.Compose(
3332
[
34-
ConvertBHWCtoBCHW(),
3533
transforms.ConvertImageDtype(torch.float32),
3634
transforms.Resize(resize_size),
3735
transforms.Normalize(mean=mean, std=std),

references/video_classification/train.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,8 @@ def main(args):
130130

131131
# Data loading code
132132
print("Loading data")
133-
traindir = os.path.join(args.data_path, args.train_dir)
134-
valdir = os.path.join(args.data_path, args.val_dir)
133+
traindir = os.path.join(args.data_path, "train")
134+
valdir = os.path.join(args.data_path, "val")
135135

136136
print("Loading training data")
137137
st = time.time()
@@ -145,9 +145,11 @@ def main(args):
145145
else:
146146
if args.distributed:
147147
print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
148-
dataset = torchvision.datasets.Kinetics400(
149-
traindir,
148+
dataset = torchvision.datasets.Kinetics(
149+
args.data_path,
150150
frames_per_clip=args.clip_len,
151+
num_classes=args.kinetics_version,
152+
split="train",
151153
step_between_clips=1,
152154
transform=transform_train,
153155
frame_rate=15,
@@ -179,9 +181,11 @@ def main(args):
179181
else:
180182
if args.distributed:
181183
print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
182-
dataset_test = torchvision.datasets.Kinetics400(
183-
valdir,
184+
dataset_test = torchvision.datasets.Kinetics(
185+
args.data_path,
184186
frames_per_clip=args.clip_len,
187+
num_classes=args.kinetics_version,
188+
split="val",
185189
step_between_clips=1,
186190
transform=transform_test,
187191
frame_rate=15,
@@ -312,8 +316,9 @@ def parse_args():
312316
parser = argparse.ArgumentParser(description="PyTorch Video Classification Training")
313317

314318
parser.add_argument("--data-path", default="/datasets01_101/kinetics/070618/", type=str, help="dataset path")
315-
parser.add_argument("--train-dir", default="train_avi-480p", type=str, help="name of train dir")
316-
parser.add_argument("--val-dir", default="val_avi-480p", type=str, help="name of val dir")
319+
parser.add_argument(
320+
"--kinetics-version", default="400", type=str, choices=["400", "600"], help="Select kinetics version"
321+
)
317322
parser.add_argument("--model", default="r2plus1d_18", type=str, help="model name")
318323
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
319324
parser.add_argument("--clip-len", default=16, type=int, metavar="N", help="number of frames per clip")

references/video_classification/transforms.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,6 @@
22
import torch.nn as nn
33

44

5-
class ConvertBHWCtoBCHW(nn.Module):
6-
"""Convert tensor from (B, H, W, C) to (B, C, H, W)"""
7-
8-
def forward(self, vid: torch.Tensor) -> torch.Tensor:
9-
return vid.permute(0, 3, 1, 2)
10-
11-
125
class ConvertBCHWtoCBHW(nn.Module):
136
"""Convert tensor from (B, C, H, W) to (C, B, H, W)"""
147

torchvision/datasets/kinetics.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ def __init__(
308308
warnings.warn(
309309
"The Kinetics400 class is deprecated since 0.12 and will be removed in 0.14."
310310
"Please use Kinetics(..., num_classes='400') instead."
311+
"Note that Kinetics(..., num_classes='400') returns video in a more logical Tensor[T, C, H, W] format."
311312
)
312313
if any(value is not None for value in (num_classes, split, download, num_download_workers)):
313314
raise RuntimeError(

torchvision/io/video.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,14 +153,13 @@ def _read_from_stream(
153153
gc.collect()
154154

155155
if pts_unit == "sec":
156+
# TODO: we should change all of this from ground up to simply take
157+
# sec and convert to MS in C++
156158
start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
157159
if end_offset != float("inf"):
158160
end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
159161
else:
160-
warnings.warn(
161-
"The pts_unit 'pts' gives wrong results and will be removed in a "
162-
+ "follow-up version. Please use pts_unit 'sec'."
163-
)
162+
warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")
164163

165164
frames = {}
166165
should_buffer = True
@@ -176,9 +175,9 @@ def _read_from_stream(
176175
# can't use regex directly because of some weird characters sometimes...
177176
pos = extradata.find(b"DivX")
178177
d = extradata[pos:]
179-
o = re.search(br"DivX(\d+)Build(\d+)(\w)", d)
178+
o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
180179
if o is None:
181-
o = re.search(br"DivX(\d+)b(\d+)(\w)", d)
180+
o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
182181
if o is not None:
183182
should_buffer = o.group(3) == b"p"
184183
seek_offset = start_offset

0 commit comments

Comments
 (0)