
extend video reader to support fast video probing #1437

Merged
merged 4 commits on Oct 12, 2019
16 changes: 16 additions & 0 deletions test/test_io.py
@@ -87,6 +87,22 @@ def test_write_read_video(self):
self.assertTrue(data.equal(lv))
self.assertEqual(info["video_fps"], 5)

@unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
def test_probe_video_from_file(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
video_info = io._probe_video_from_file(f_name)
self.assertAlmostEqual(video_info["video_duration"], 2, delta=0.1)
self.assertAlmostEqual(video_info["video_fps"], 5, delta=0.1)

@unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
def test_probe_video_from_memory(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
with open(f_name, "rb") as fp:
filebuffer = fp.read()
video_info = io._probe_video_from_memory(filebuffer)
self.assertAlmostEqual(video_info["video_duration"], 2, delta=0.1)
self.assertAlmostEqual(video_info["video_fps"], 5, delta=0.1)

def test_read_timestamps(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
if _video_backend == "pyav":
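For context, a minimal usage sketch of the two private helpers exercised by the new tests above. It assumes torchvision was built with the video_reader backend (io._HAS_VIDEO_OPT) and uses a placeholder path; the tests only assert on the "video_duration" and "video_fps" keys, so only those are shown here.

import torchvision.io as io

if io._HAS_VIDEO_OPT:
    # Probe a video on disk: returns stream metadata without decoding any frames.
    info = io._probe_video_from_file("example.mp4")  # placeholder path
    print(info["video_duration"], info["video_fps"])

    # Probe a video already held in memory as raw bytes.
    with open("example.mp4", "rb") as fp:
        filebuffer = fp.read()
    info = io._probe_video_from_memory(filebuffer)
    print(info["video_duration"], info["video_fps"])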
105 changes: 76 additions & 29 deletions test/test_video_reader.py
@@ -31,6 +31,7 @@
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")

CheckerConfig = [
"duration",
"video_fps",
"audio_sample_rate",
# We find for some videos (e.g. HMDB51 videos), the decoded audio frames and pts are
@@ -44,6 +45,7 @@
)

all_check_config = GroundTruth(
duration=0,
video_fps=0,
audio_sample_rate=0,
check_aframes=True,
@@ -52,50 +54,58 @@

test_videos = {
"RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(
duration=2.0,
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth(
duration=2.0,
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(
duration=2.0,
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"v_SoccerJuggling_g23_c01.avi": GroundTruth(
duration=8.0,
video_fps=29.97,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"v_SoccerJuggling_g24_c01.avi": GroundTruth(
duration=8.0,
video_fps=29.97,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"R6llTwEh07w.mp4": GroundTruth(
duration=10.0,
video_fps=30.0,
audio_sample_rate=44100,
# PyAv misses one audio frame at the beginning (pts=0)
check_aframes=False,
check_aframe_pts=False,
),
"SOX5yA1l24A.mp4": GroundTruth(
duration=11.0,
video_fps=29.97,
audio_sample_rate=48000,
# PyAv misses one audio frame at the beginning (pts=0)
check_aframes=False,
check_aframe_pts=False,
),
"WUzgd7C1pWA.mp4": GroundTruth(
duration=11.0,
video_fps=29.97,
audio_sample_rate=48000,
# PyAv misses one audio frame at the beginning (pts=0)
@@ -272,13 +282,22 @@ class TestVideoReader(unittest.TestCase):
def check_separate_decoding_result(self, tv_result, config):
"""check the decoding results from TorchVision decoder
"""
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = (
tv_result
)
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
atimebase, asample_rate, aduration = tv_result

video_duration = vduration.item() * Fraction(
vtimebase[0].item(), vtimebase[1].item()
)
self.assertAlmostEqual(video_duration, config.duration, delta=0.5)

self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
if asample_rate.numel() > 0:
self.assertEqual(asample_rate.item(), config.audio_sample_rate)
audio_duration = aduration.item() * Fraction(
atimebase[0].item(), atimebase[1].item()
)
self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)

# check if pts of video frames are sorted in ascending order
for i in range(len(vframe_pts) - 1):
self.assertEqual(vframe_pts[i] < vframe_pts[i + 1], True)
@@ -288,6 +307,20 @@ def check_separate_decoding_result(self, tv_result, config):
for i in range(len(aframe_pts) - 1):
self.assertEqual(aframe_pts[i] < aframe_pts[i + 1], True)

def check_probe_result(self, result, config):
vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
video_duration = vduration.item() * Fraction(
vtimebase[0].item(), vtimebase[1].item()
)
self.assertAlmostEqual(video_duration, config.duration, delta=0.5)
self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
if asample_rate.numel() > 0:
self.assertEqual(asample_rate.item(), config.audio_sample_rate)
audio_duration = aduration.item() * Fraction(
atimebase[0].item(), atimebase[1].item()
)
self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)

def compare_decoding_result(self, tv_result, ref_result, config=all_check_config):
"""
Compare decoding results from two sources.
@@ -297,18 +330,17 @@ def compare_decoding_result(self, tv_result, ref_result, config=all_check_config
decoder or TorchVision decoder with getPtsOnly = 1
config: config of decoding results checker
"""
vframes, vframe_pts, vtimebase, _vfps, aframes, aframe_pts, atimebase, _asample_rate = (
tv_result
)
vframes, vframe_pts, vtimebase, _vfps, _vduration, aframes, aframe_pts, \
atimebase, _asample_rate, _aduration = tv_result
if isinstance(ref_result, list):
# the ref_result is from new video_reader decoder
ref_result = DecoderResult(
vframes=ref_result[0],
vframe_pts=ref_result[1],
vtimebase=ref_result[2],
aframes=ref_result[4],
aframe_pts=ref_result[5],
atimebase=ref_result[6],
aframes=ref_result[5],
aframe_pts=ref_result[6],
atimebase=ref_result[7],
)

if vframes.numel() > 0 and ref_result.vframes.numel() > 0:
@@ -351,12 +383,12 @@ def test_stress_test_read_video_from_file(self):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1

for i in range(num_iter):
for test_video, config in test_videos.items():
for _i in range(num_iter):
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)

# pass 1: decode all frames using new decoder
_ = torch.ops.video_reader.read_video_from_file(
torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
@@ -460,9 +492,8 @@ def test_read_video_from_file_read_single_stream_only(self):
audio_timebase_den,
)

vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = (
tv_result
)
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
atimebase, asample_rate, aduration = tv_result

self.assertEqual(vframes.numel() > 0, readVideoStream)
self.assertEqual(vframe_pts.numel() > 0, readVideoStream)
@@ -489,7 +520,7 @@ def test_read_video_from_file_rescale_min_dimension(self):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1

for test_video, config in test_videos.items():
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)

tv_result = torch.ops.video_reader.read_video_from_file(
@@ -528,7 +559,7 @@ def test_read_video_from_file_rescale_width(self):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1

for test_video, config in test_videos.items():
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)

tv_result = torch.ops.video_reader.read_video_from_file(
@@ -567,7 +598,7 @@ def test_read_video_from_file_rescale_height(self):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1

for test_video, config in test_videos.items():
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)

tv_result = torch.ops.video_reader.read_video_from_file(
@@ -606,7 +637,7 @@ def test_read_video_from_file_rescale_width_and_height(self):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1

for test_video, config in test_videos.items():
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)

tv_result = torch.ops.video_reader.read_video_from_file(
@@ -651,7 +682,7 @@ def test_read_video_from_file_audio_resampling(self):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1

for test_video, config in test_videos.items():
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)

tv_result = torch.ops.video_reader.read_video_from_file(
@@ -674,18 +705,17 @@ def test_read_video_from_file_audio_resampling(self):
audio_timebase_num,
audio_timebase_den,
)
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, a_sample_rate = (
tv_result
)
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
atimebase, asample_rate, aduration = tv_result
if aframes.numel() > 0:
self.assertEqual(samples, a_sample_rate.item())
self.assertEqual(samples, asample_rate.item())
self.assertEqual(1, aframes.size(1))
# when audio stream is found
duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1])
self.assertAlmostEqual(
aframes.size(0),
int(duration * a_sample_rate.item()),
delta=0.1 * a_sample_rate.item(),
int(duration * asample_rate.item()),
delta=0.1 * asample_rate.item(),
)

def test_compare_read_video_from_memory_and_file(self):
@@ -859,7 +889,7 @@ def test_read_video_from_memory_get_pts_only(self):
)

self.assertEqual(tv_result_pts_only[0].numel(), 0)
self.assertEqual(tv_result_pts_only[4].numel(), 0)
self.assertEqual(tv_result_pts_only[5].numel(), 0)
self.compare_decoding_result(tv_result, tv_result_pts_only)

def test_read_video_in_range_from_memory(self):
@@ -899,9 +929,8 @@ def test_read_video_in_range_from_memory(self):
audio_timebase_num,
audio_timebase_den,
)
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = (
tv_result
)
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
atimebase, asample_rate, aduration = tv_result
self.assertAlmostEqual(config.video_fps, vfps.item(), delta=0.01)

for num_frames in [4, 8, 16, 32, 64, 128]:
@@ -997,6 +1026,24 @@ def test_read_video_in_range_from_memory(self):
# and PyAv
self.compare_decoding_result(tv_result, pyav_result, config)

def test_probe_video_from_file(self):
"""
Test the case when the decoder probes a video file
"""
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
probe_result = torch.ops.video_reader.probe_video_from_file(full_path)
self.check_probe_result(probe_result, config)

def test_probe_video_from_memory(self):
"""
Test the case when the decoder probes a video in memory
"""
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
probe_result = torch.ops.video_reader.probe_video_from_memory(video_tensor)
self.check_probe_result(probe_result, config)


if __name__ == '__main__':
unittest.main()
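For readers following the new probe tests, here is a minimal sketch of unpacking the raw result of the low-level op, mirroring check_probe_result above. It assumes the video_reader extension has been loaded (importing torchvision does this when the backend is built) and uses a placeholder path; durations come back in stream time-base units and are converted to seconds with the returned time base.

from fractions import Fraction

import torch
import torchvision  # noqa: F401  -- loads the video_reader ops when the backend is built

probe_result = torch.ops.video_reader.probe_video_from_file("some_video.mp4")  # placeholder path
vtimebase, vfps, vduration, atimebase, asample_rate, aduration = probe_result

video_duration = vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item())
print(float(video_duration), vfps.item())

if asample_rate.numel() > 0:  # an audio stream was found
    audio_duration = aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item())
    print(float(audio_duration), asample_rate.item())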
1 change: 1 addition & 0 deletions torchvision/csrc/cpu/video_reader/FfmpegAudioStream.cpp
@@ -49,6 +49,7 @@ void FfmpegAudioStream::updateStreamDecodeParams() {
mediaFormat_.format.audio.timeBaseDen =
inputCtx_->streams[index_]->time_base.den;
}
mediaFormat_.format.audio.duration = inputCtx_->streams[index_]->duration;
}

int FfmpegAudioStream::initFormat() {
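As background for the one-line change above and the matching Python checks: FFmpeg reports AVStream::duration in AVStream::time_base units, so callers convert it to seconds by multiplying with the time base. A tiny illustrative conversion, using made-up numbers rather than values from this PR:

from fractions import Fraction

duration_pts = 60060              # hypothetical AVStream::duration
time_base = Fraction(1, 30000)    # hypothetical AVStream::time_base (num/den)
print(float(duration_pts * time_base))  # 2.002 seconds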
34 changes: 34 additions & 0 deletions torchvision/csrc/cpu/video_reader/FfmpegDecoder.cpp
@@ -220,6 +220,30 @@ int FfmpegDecoder::decodeMemory(
return ret;
}

int FfmpegDecoder::probeFile(
unique_ptr<DecoderParameters> params,
const string& fileName,
DecoderOutput& decoderOutput) {
VLOG(1) << "probe file: " << fileName;
FfmpegAvioContext ioctx;
return probeVideo(std::move(params), fileName, true, ioctx, decoderOutput);
}

int FfmpegDecoder::probeMemory(
unique_ptr<DecoderParameters> params,
const uint8_t* buffer,
int64_t size,
DecoderOutput& decoderOutput) {
VLOG(1) << "probe video data in memory";
FfmpegAvioContext ioctx;
int ret = ioctx.initAVIOContext(buffer, size);
if (ret == 0) {
ret =
probeVideo(std::move(params), string(""), false, ioctx, decoderOutput);
}
return ret;
}

void FfmpegDecoder::cleanUp() {
if (formatCtx_) {
for (auto& stream : streams_) {
@@ -320,6 +344,16 @@ int FfmpegDecoder::decodeLoop(
return ret;
}

int FfmpegDecoder::probeVideo(
unique_ptr<DecoderParameters> params,
const std::string& filename,
bool isDecodeFile,
FfmpegAvioContext& ioctx,
DecoderOutput& decoderOutput) {
params_ = std::move(params);
return init(filename, isDecodeFile, ioctx, decoderOutput);
}

bool FfmpegDecoder::initStreams() {
for (auto it = params_->formats.begin(); it != params_->formats.end(); ++it) {
AVMediaType mediaType;