Skip to content

UCF101 Sketchy Fix #4204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions torchvision/datasets/ucf101.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def __init__(
_video_height=_video_height,
_video_min_dimension=_video_min_dimension,
_audio_samples=_audio_samples,
_pts_unit="pts"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary new line, plus , at the end would be preferable.

)
# we bookkeep the full version of video clips because we want to be able
# to return the meta data of full version rather than the subset version of
Expand Down
13 changes: 8 additions & 5 deletions torchvision/datasets/video_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,15 @@ class _VideoTimestampsDataset(object):
pickled when forking.
"""

def __init__(self, video_paths: List[str]):
def __init__(self, video_paths: List[str], pts_unit: str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the class being private the reason for not having a default value for pts_unit?

self.video_paths = video_paths
self.pts_unit = pts_unit

def __len__(self):
return len(self.video_paths)

def __getitem__(self, idx):
return read_video_timestamps(self.video_paths[idx])
return read_video_timestamps(self.video_paths[idx], pts_unit=self.pts_unit)


def _collate_fn(x):
Expand Down Expand Up @@ -112,10 +113,13 @@ def __init__(
_video_max_dimension=0,
_audio_samples=0,
_audio_channels=0,
_pts_unit="pts"
):

self.video_paths = video_paths
self.num_workers = num_workers
# a hack to avoid rounding errors
self.pts_unit = _pts_unit

# these options are not valid for pyav backend
self._video_width = _video_width
Expand All @@ -138,9 +142,8 @@ def _compute_frame_pts(self):
# strategy: use a DataLoader to parallelize read_video_timestamps
# so need to create a dummy dataset first
import torch.utils.data

dl = torch.utils.data.DataLoader(
_VideoTimestampsDataset(self.video_paths),
_VideoTimestampsDataset(self.video_paths, self.pts_unit),
batch_size=16,
num_workers=self.num_workers,
collate_fn=_collate_fn,
Expand Down Expand Up @@ -327,7 +330,7 @@ def get_clip(self, idx):
if backend == "pyav":
start_pts = clip_pts[0].item()
end_pts = clip_pts[-1].item()
video, audio, info = read_video(video_path, start_pts, end_pts)
video, audio, info = read_video(video_path, start_pts, end_pts, pts_unit=self.pts_unit)
else:
info = _probe_video_from_file(video_path)
video_fps = info.video_fps
Expand Down