[WIP] UCF101 prototype with utilities for video loading #4838
@@ -0,0 +1,32 @@

```python
from torchvision.prototype import datasets
from torchvision.prototype.datasets.video_utils import AVKeyframeReader, AVRandomFrameReader, AVClipReader


print("\n \n KEYFRAMES \n \n")
ct = 0
dataset = AVKeyframeReader(datasets.load("ucf101"))
for i in dataset:
    print(i)
    ct += 1
    if ct > 5:
        break


print("\n \n RANDOM FRAMES")
ct = 0
dataset = AVRandomFrameReader(datasets.load("ucf101"), num_samples=3)
for i in dataset:
    print(i)
    ct += 1
    if ct > 5:
        break


print("\n \n CLIPS ")
ct = 0
dataset = AVClipReader(datasets.load("ucf101"), num_frames_per_clip=16, num_clips_per_video=8)
for i in dataset:
    print(i["path"], i["range"])
    ct += 1
    if ct > 5:
        break
```
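For reference, the dictionaries printed above have the following shape. This is a sketch inferred from the reader implementations further down in this PR; the field types are assumptions, not checked output:

```python
# Sample yielded by AVKeyframeReader / AVRandomFrameReader (sketch):
# {
#     "frame": PIL.Image.Image,   # decoded frame
#     "pts": int,                 # presentation timestamp of the frame
#     "video_meta": {"time_base": float, "guessed_fps": float},
#     "path": str,                # video path inside the UCF101 archive
#     "target": str,              # label parsed from the split annotation file
# }
#
# Sample yielded by AVClipReader (sketch):
# {
#     "clip": torch.Tensor,       # uint8 tensor of shape [num_frames_per_clip, C, H, W]
#     "pts": torch.Tensor,        # timestamps of the clip's frames
#     "range": (float, float),    # clip start/end in seconds
#     "video_meta": {"time_base": float, "guessed_fps": float},
#     "path": str,
#     "target": str,
# }
```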
@@ -0,0 +1,89 @@

```python
import io
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple

from torchvision.prototype.datasets.utils._internal import RarArchiveReader, INFINITE_BUFFER_SIZE

import torch
from torchdata.datapipes.iter import CSVParser, KeyZipper
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter import (
    Filter,
    Mapper,
    ZipArchiveReader,
    Shuffler,
)
from torchvision.prototype.datasets.utils._internal import path_accessor, path_comparator
from torchvision.prototype.datasets.utils import (
    Dataset,
    DatasetConfig,
    DatasetInfo,
    HttpResource,
    OnlineResource,
    DatasetType,
)


class ucf101(Dataset):
    """This is a base datapipe that returns a file handle of the video.
    What we want to do is implement either several decoder options or additional
    datapipe extensions to make this work.
    """

    def _make_info(self) -> DatasetInfo:
        return DatasetInfo(
            "ucf101",
            type=DatasetType.VIDEO,
            valid_options={"split": ["train", "test"], "fold": ["1", "2", "3"]},
            homepage="https://www.crcv.ucf.edu/data/UCF101.php",
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        return [
            HttpResource(
                "https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip",
                sha256="5c0d1a53b8ed364a2ac830a73f405e51bece7d98ce1254fd19ed4a36b224bd27",
            ),
            HttpResource(
                "https://www.crcv.ucf.edu/data/UCF101/UCF101.rar",
                sha256="ca8dfadb4c891cb11316f94d52b6b0ac2a11994e67a0cae227180cd160bd8e55",
            ),
        ]

    def _generate_categories(self, root: pathlib.Path) -> List[str]:
        dp = self.resources(self.default_config)[1].to_datapipe(pathlib.Path(root) / self.name)
        dp = RarArchiveReader(dp)
        dir_names = {pathlib.Path(path).parent.name for path, _ in dp}
        return [name.split(".")[1] for name in sorted(dir_names)]

    def _collate_and_decode(
        self,
        data: Tuple[Tuple[str, int], Tuple[str, io.IOBase]],
        *,
        decoder: Optional[Callable[[io.IOBase], Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        annotations_d, file_d = data
        label = annotations_d[1]
        _path, file_handle = file_d
        file = decoder(file_handle) if decoder else file_handle
        return {"path": _path, "file": file, "target": label}

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        annotations = resource_dps[0]
        files = resource_dps[1]

        annotations_dp = ZipArchiveReader(annotations)
        annotations_dp = Filter(annotations_dp, path_comparator("name", f"{config.split}list0{config.fold}.txt"))
        annotations_dp = CSVParser(annotations_dp, delimiter=" ")
        # COMMENT OUT FOR TESTING
```
> **Review:** True, but should be removed before merge.

> **Review:** I would really want our datasets to be deterministic outside of a DataLoader though. Making it stochastic will make it much harder to debug. Maybe what we should do instead is have a new […] Thoughts?
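One way to read that suggestion, sketched here rather than taken from the PR: keep the datapipe deterministic and let callers opt into shuffling themselves. The snippet below is hypothetical; it assumes the unconditional `Shuffler` line that follows is dropped:

```python
from torch.utils.data.datapipes.iter import Shuffler
from torchvision.prototype import datasets

# Hypothetical opt-in shuffling applied by the user on top of a
# deterministic datapipe, instead of inside _make_datapipe.
dp = datasets.load("ucf101", split="train", fold="1")
dp = Shuffler(dp, buffer_size=10_000)
```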
```python
        annotations_dp = Shuffler(annotations_dp, buffer_size=INFINITE_BUFFER_SIZE)

        files_dp = RarArchiveReader(files)
        dp = KeyZipper(annotations_dp, files_dp, path_accessor("name"))
        return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
```
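For context, a minimal usage sketch of the datapipe above, assuming (as the demo script does) that `datasets.load` accepts the options declared in `valid_options`:

```python
from torchvision.prototype import datasets

# Hypothetical call: option names mirror
# valid_options={"split": ["train", "test"], "fold": ["1", "2", "3"]}.
dp = datasets.load("ucf101", split="train", fold="1")

for sample in dp:
    # _collate_and_decode returns {"path": ..., "file": ..., "target": ...}.
    print(sample["path"], sample["target"])
    break
```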
@@ -0,0 +1,171 @@

```python
from typing import Any, Dict, Iterator
import random

import av
import numpy as np
import torch
from torchdata.datapipes.iter import IterDataPipe
from torchvision.io import video, _video_opt
```

> **Review:** I'm not sure if I would use […]

> **Review:** Sure.
```python
class AVKeyframeReader(IterDataPipe[Dict[str, Any]]):
    def __init__(self, video_dp: IterDataPipe[Dict[str, Any]]) -> None:
        """TorchData IterDataPipe that takes in a video datapipe
        and yields all keyframes in a video.

        Args:
            video_dp (IterDataPipe[Dict[str, Any]]): Video dataset IterDataPipe
        """
        self.datapipe = video_dp

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        for video_d in self.datapipe:
            buffer = video_d.pop("file")
            with av.open(buffer, metadata_errors="ignore") as container:
                stream = container.streams.video[0]
                # decode keyframes only
                stream.codec_context.skip_frame = "NONKEY"
                for frame in container.decode(stream):
                    img = frame.to_image()
                    yield dict(
                        video_d,
                        frame=img,
                        pts=frame.pts,
                        video_meta={
                            "time_base": float(frame.time_base),
                            "guessed_fps": float(stream.guessed_rate),
                        },
                    )


class AVRandomFrameReader(IterDataPipe[Dict[str, Any]]):
    def __init__(self, video_dp: IterDataPipe[Dict[str, Any]], num_samples=1, transform=None) -> None:
        """TorchData IterDataPipe that takes in a video datapipe
        and yields `num_samples` random frames from a video.

        Args:
            video_dp (IterDataPipe[Dict[str, Any]]): Video dataset IterDataPipe
            num_samples (int, optional): Number of frames to sample from each video. Defaults to 1.
        """
        self.datapipe = video_dp
        self.num_samples = num_samples

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        for video_d in self.datapipe:
            buffer = video_d.pop("file")
            with av.open(buffer, metadata_errors="ignore") as container:
                stream = container.streams.video[0]
                # duration is given in time_base units as int
                duration = stream.duration
                # seek to a random frame
                seek_idxs = random.sample(list(range(duration)), self.num_samples)
                for i in seek_idxs:
                    container.seek(i, any_frame=True, stream=stream)
                    frame = next(container.decode(stream))
                    img = frame.to_image()

                    video_meta = {
                        "time_base": float(frame.time_base),
                        "guessed_fps": float(stream.guessed_rate),
                    }

                    yield dict(video_d, frame=img, pts=frame.pts, video_meta=video_meta)


class AVClipReader(IterDataPipe[Dict[str, Any]]):
    def __init__(
        self,
        video_dp: IterDataPipe[Dict[str, Any]],
        num_frames_per_clip=8,
        num_clips_per_video=1,
        step_between_clips=1,
    ) -> None:
        """TorchData IterDataPipe that takes in a video datapipe
        and yields `num_clips_per_video` video clips (sequences of `num_frames_per_clip` frames) from a video.
        Clips are sampled from all possible clips of length `num_frames_per_clip` spaced `step_between_clips` apart.

        Args:
            video_dp (IterDataPipe[Dict[str, Any]]): Video dataset IterDataPipe
            num_frames_per_clip (int, optional): Length of a video clip in frames. Defaults to 8.
            num_clips_per_video (int, optional): How many clips to sample from each video. Defaults to 1.
            step_between_clips (int, optional): Minimum spacing between two clips. Defaults to 1.
        """
        self.datapipe = video_dp
        self.num_frames_per_clip = num_frames_per_clip
        self.num_clips_per_video = num_clips_per_video
        self.step_between_clips = step_between_clips

    def _unfold(self, tensor, dilation=1):
        """Similar to tensor.unfold, but with dilation and specialized for 1d tensors.

        Returns all consecutive windows of `self.num_frames_per_clip` elements, with
        `self.step_between_clips` between window starts. The distance between each element
        within a window is given by `dilation`.
        """
        assert tensor.dim() == 1
        o_stride = tensor.stride(0)
        numel = tensor.numel()
        new_stride = (self.step_between_clips * o_stride, dilation * o_stride)
        new_size = (
            (numel - (dilation * (self.num_frames_per_clip - 1) + 1)) // self.step_between_clips + 1,
            self.num_frames_per_clip,
        )
        if new_size[0] < 1:
            new_size = (0, self.num_frames_per_clip)
        return torch.as_strided(tensor, new_size, new_stride)

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        for video_d in self.datapipe:
            buffer = video_d["file"]
            with av.open(buffer, metadata_errors="ignore") as container:
                stream = container.streams.video[0]
                time_base = stream.time_base

                # duration is given in time_base units as int
                duration = stream.duration

                # get video_stream timestamps,
                # with a tolerance for pyav imprecision
                _ptss = torch.arange(duration - 7)
                _ptss = self._unfold(_ptss)
                # shuffle the clips
                perm = torch.randperm(_ptss.size(0))
                idx = perm[: self.num_clips_per_video]
                samples = _ptss[idx]

                for clip_pts in samples:
                    start_pts = clip_pts[0].item()
                    end_pts = clip_pts[-1].item()
                    # video_timebase is the default time_base
                    pts_unit = "pts"
                    start_pts, end_pts, pts_unit = _video_opt._convert_to_sec(start_pts, end_pts, "pts", time_base)
                    video_frames = video._read_from_stream(
                        container,
                        start_pts,
                        end_pts,
                        pts_unit,
                        stream,
                        {"video": 0},
                    )

                    vframes_list = [frame.to_ndarray(format="rgb24") for frame in video_frames]

                    if vframes_list:
                        vframes = torch.as_tensor(np.stack(vframes_list))
                        # account for rounding errors in conversion
                        # FIXME: fix this in the code
                        vframes = vframes[: self.num_frames_per_clip, ...]
                    else:
                        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
                        print("FAIL")

                    # [N, H, W, C] to [N, C, H, W]
                    vframes = vframes.permute(0, 3, 1, 2)
                    assert vframes.size(0) == self.num_frames_per_clip

                    # TODO: support sampling rates (FPS change)
                    # TODO: optimization (read all and select)

                    yield {
                        "clip": vframes,
                        "pts": clip_pts,
                        "range": (start_pts, end_pts),
                        "video_meta": {
                            "time_base": float(stream.time_base),
                            "guessed_fps": float(stream.guessed_rate),
                        },
                        "path": video_d["path"],
                        "target": video_d["target"],
                    }
```
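As a sanity check on `_unfold`'s strided-window arithmetic, here is a standalone re-creation with small numbers; the parameter values are picked for illustration only:

```python
import torch

# Reproduce _unfold with num_frames_per_clip=4, step_between_clips=2, dilation=1.
num_frames_per_clip, step_between_clips, dilation = 4, 2, 1
pts = torch.arange(10)

o_stride = pts.stride(0)
new_stride = (step_between_clips * o_stride, dilation * o_stride)
new_size = (
    (pts.numel() - (dilation * (num_frames_per_clip - 1) + 1)) // step_between_clips + 1,
    num_frames_per_clip,
)
windows = torch.as_strided(pts, new_size, new_stride)
print(windows)
# tensor([[0, 1, 2, 3],
#         [2, 3, 4, 5],
#         [4, 5, 6, 7],
#         [6, 7, 8, 9]])
```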