Skip to content

Commit 131b90a

Browse files
stephenyan1231 authored and fmassa committed
extend video reader to support fast video probing (#1437)
* extend video reader to support fast video probing
* fix c++ lint
* small fix
* allow to accept input video of type torch.Tensor
1 parent bc68234 commit 131b90a

10 files changed: 332 additions and 57 deletions

test/test_io.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,22 @@ def test_write_read_video(self):
8686
self.assertTrue(data.equal(lv))
8787
self.assertEqual(info["video_fps"], 5)
8888

89+
@unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
90+
def test_probe_video_from_file(self):
91+
with temp_video(10, 300, 300, 5) as (f_name, data):
92+
video_info = io._probe_video_from_file(f_name)
93+
self.assertAlmostEqual(video_info["video_duration"], 2, delta=0.1)
94+
self.assertAlmostEqual(video_info["video_fps"], 5, delta=0.1)
95+
96+
@unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
97+
def test_probe_video_from_memory(self):
98+
with temp_video(10, 300, 300, 5) as (f_name, data):
99+
with open(f_name, "rb") as fp:
100+
filebuffer = fp.read()
101+
video_info = io._probe_video_from_memory(filebuffer)
102+
self.assertAlmostEqual(video_info["video_duration"], 2, delta=0.1)
103+
self.assertAlmostEqual(video_info["video_fps"], 5, delta=0.1)
104+
89105
def test_read_timestamps(self):
90106
with temp_video(10, 300, 300, 5) as (f_name, data):
91107
if _video_backend == "pyav":

test/test_video_reader.py

Lines changed: 76 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
3232

3333
CheckerConfig = [
34+
"duration",
3435
"video_fps",
3536
"audio_sample_rate",
3637
# We find for some videos (e.g. HMDB51 videos), the decoded audio frames and pts are
@@ -44,6 +45,7 @@
4445
)
4546

4647
all_check_config = GroundTruth(
48+
duration=0,
4749
video_fps=0,
4850
audio_sample_rate=0,
4951
check_aframes=True,
@@ -52,50 +54,58 @@
5254

5355
test_videos = {
5456
"RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(
57+
duration=2.0,
5558
video_fps=30.0,
5659
audio_sample_rate=None,
5760
check_aframes=True,
5861
check_aframe_pts=True,
5962
),
6063
"SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth(
64+
duration=2.0,
6165
video_fps=30.0,
6266
audio_sample_rate=None,
6367
check_aframes=True,
6468
check_aframe_pts=True,
6569
),
6670
"TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(
71+
duration=2.0,
6772
video_fps=30.0,
6873
audio_sample_rate=None,
6974
check_aframes=True,
7075
check_aframe_pts=True,
7176
),
7277
"v_SoccerJuggling_g23_c01.avi": GroundTruth(
78+
duration=8.0,
7379
video_fps=29.97,
7480
audio_sample_rate=None,
7581
check_aframes=True,
7682
check_aframe_pts=True,
7783
),
7884
"v_SoccerJuggling_g24_c01.avi": GroundTruth(
85+
duration=8.0,
7986
video_fps=29.97,
8087
audio_sample_rate=None,
8188
check_aframes=True,
8289
check_aframe_pts=True,
8390
),
8491
"R6llTwEh07w.mp4": GroundTruth(
92+
duration=10.0,
8593
video_fps=30.0,
8694
audio_sample_rate=44100,
8795
# PyAv miss one audio frame at the beginning (pts=0)
8896
check_aframes=False,
8997
check_aframe_pts=False,
9098
),
9199
"SOX5yA1l24A.mp4": GroundTruth(
100+
duration=11.0,
92101
video_fps=29.97,
93102
audio_sample_rate=48000,
94103
# PyAv miss one audio frame at the beginning (pts=0)
95104
check_aframes=False,
96105
check_aframe_pts=False,
97106
),
98107
"WUzgd7C1pWA.mp4": GroundTruth(
108+
duration=11.0,
99109
video_fps=29.97,
100110
audio_sample_rate=48000,
101111
# PyAv miss one audio frame at the beginning (pts=0)
@@ -272,13 +282,22 @@ class TestVideoReader(unittest.TestCase):
272282
def check_separate_decoding_result(self, tv_result, config):
273283
"""check the decoding results from TorchVision decoder
274284
"""
275-
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = (
276-
tv_result
285+
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
286+
atimebase, asample_rate, aduration = tv_result
287+
288+
video_duration = vduration.item() * Fraction(
289+
vtimebase[0].item(), vtimebase[1].item()
277290
)
291+
self.assertAlmostEqual(video_duration, config.duration, delta=0.5)
278292

279293
self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
280294
if asample_rate.numel() > 0:
281295
self.assertEqual(asample_rate.item(), config.audio_sample_rate)
296+
audio_duration = aduration.item() * Fraction(
297+
atimebase[0].item(), atimebase[1].item()
298+
)
299+
self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)
300+
282301
# check if pts of video frames are sorted in ascending order
283302
for i in range(len(vframe_pts) - 1):
284303
self.assertEqual(vframe_pts[i] < vframe_pts[i + 1], True)
@@ -288,6 +307,20 @@ def check_separate_decoding_result(self, tv_result, config):
288307
for i in range(len(aframe_pts) - 1):
289308
self.assertEqual(aframe_pts[i] < aframe_pts[i + 1], True)
290309

310+
def check_probe_result(self, result, config):
311+
vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
312+
video_duration = vduration.item() * Fraction(
313+
vtimebase[0].item(), vtimebase[1].item()
314+
)
315+
self.assertAlmostEqual(video_duration, config.duration, delta=0.5)
316+
self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
317+
if asample_rate.numel() > 0:
318+
self.assertEqual(asample_rate.item(), config.audio_sample_rate)
319+
audio_duration = aduration.item() * Fraction(
320+
atimebase[0].item(), atimebase[1].item()
321+
)
322+
self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)
323+
291324
def compare_decoding_result(self, tv_result, ref_result, config=all_check_config):
292325
"""
293326
Compare decoding results from two sources.
@@ -297,18 +330,17 @@ def compare_decoding_result(self, tv_result, ref_result, config=all_check_config
297330
decoder or TorchVision decoder with getPtsOnly = 1
298331
config: config of decoding results checker
299332
"""
300-
vframes, vframe_pts, vtimebase, _vfps, aframes, aframe_pts, atimebase, _asample_rate = (
301-
tv_result
302-
)
333+
vframes, vframe_pts, vtimebase, _vfps, _vduration, aframes, aframe_pts, \
334+
atimebase, _asample_rate, _aduration = tv_result
303335
if isinstance(ref_result, list):
304336
# the ref_result is from new video_reader decoder
305337
ref_result = DecoderResult(
306338
vframes=ref_result[0],
307339
vframe_pts=ref_result[1],
308340
vtimebase=ref_result[2],
309-
aframes=ref_result[4],
310-
aframe_pts=ref_result[5],
311-
atimebase=ref_result[6],
341+
aframes=ref_result[5],
342+
aframe_pts=ref_result[6],
343+
atimebase=ref_result[7],
312344
)
313345

314346
if vframes.numel() > 0 and ref_result.vframes.numel() > 0:
@@ -351,12 +383,12 @@ def test_stress_test_read_video_from_file(self):
351383
audio_start_pts, audio_end_pts = 0, -1
352384
audio_timebase_num, audio_timebase_den = 0, 1
353385

354-
for i in range(num_iter):
355-
for test_video, config in test_videos.items():
386+
for _i in range(num_iter):
387+
for test_video, _config in test_videos.items():
356388
full_path = os.path.join(VIDEO_DIR, test_video)
357389

358390
# pass 1: decode all frames using new decoder
359-
_ = torch.ops.video_reader.read_video_from_file(
391+
torch.ops.video_reader.read_video_from_file(
360392
full_path,
361393
seek_frame_margin,
362394
0, # getPtsOnly
@@ -460,9 +492,8 @@ def test_read_video_from_file_read_single_stream_only(self):
460492
audio_timebase_den,
461493
)
462494

463-
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = (
464-
tv_result
465-
)
495+
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
496+
atimebase, asample_rate, aduration = tv_result
466497

467498
self.assertEqual(vframes.numel() > 0, readVideoStream)
468499
self.assertEqual(vframe_pts.numel() > 0, readVideoStream)
@@ -489,7 +520,7 @@ def test_read_video_from_file_rescale_min_dimension(self):
489520
audio_start_pts, audio_end_pts = 0, -1
490521
audio_timebase_num, audio_timebase_den = 0, 1
491522

492-
for test_video, config in test_videos.items():
523+
for test_video, _config in test_videos.items():
493524
full_path = os.path.join(VIDEO_DIR, test_video)
494525

495526
tv_result = torch.ops.video_reader.read_video_from_file(
@@ -528,7 +559,7 @@ def test_read_video_from_file_rescale_width(self):
528559
audio_start_pts, audio_end_pts = 0, -1
529560
audio_timebase_num, audio_timebase_den = 0, 1
530561

531-
for test_video, config in test_videos.items():
562+
for test_video, _config in test_videos.items():
532563
full_path = os.path.join(VIDEO_DIR, test_video)
533564

534565
tv_result = torch.ops.video_reader.read_video_from_file(
@@ -567,7 +598,7 @@ def test_read_video_from_file_rescale_height(self):
567598
audio_start_pts, audio_end_pts = 0, -1
568599
audio_timebase_num, audio_timebase_den = 0, 1
569600

570-
for test_video, config in test_videos.items():
601+
for test_video, _config in test_videos.items():
571602
full_path = os.path.join(VIDEO_DIR, test_video)
572603

573604
tv_result = torch.ops.video_reader.read_video_from_file(
@@ -606,7 +637,7 @@ def test_read_video_from_file_rescale_width_and_height(self):
606637
audio_start_pts, audio_end_pts = 0, -1
607638
audio_timebase_num, audio_timebase_den = 0, 1
608639

609-
for test_video, config in test_videos.items():
640+
for test_video, _config in test_videos.items():
610641
full_path = os.path.join(VIDEO_DIR, test_video)
611642

612643
tv_result = torch.ops.video_reader.read_video_from_file(
@@ -651,7 +682,7 @@ def test_read_video_from_file_audio_resampling(self):
651682
audio_start_pts, audio_end_pts = 0, -1
652683
audio_timebase_num, audio_timebase_den = 0, 1
653684

654-
for test_video, config in test_videos.items():
685+
for test_video, _config in test_videos.items():
655686
full_path = os.path.join(VIDEO_DIR, test_video)
656687

657688
tv_result = torch.ops.video_reader.read_video_from_file(
@@ -674,18 +705,17 @@ def test_read_video_from_file_audio_resampling(self):
674705
audio_timebase_num,
675706
audio_timebase_den,
676707
)
677-
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, a_sample_rate = (
678-
tv_result
679-
)
708+
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
709+
atimebase, asample_rate, aduration = tv_result
680710
if aframes.numel() > 0:
681-
self.assertEqual(samples, a_sample_rate.item())
711+
self.assertEqual(samples, asample_rate.item())
682712
self.assertEqual(1, aframes.size(1))
683713
# when audio stream is found
684714
duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1])
685715
self.assertAlmostEqual(
686716
aframes.size(0),
687-
int(duration * a_sample_rate.item()),
688-
delta=0.1 * a_sample_rate.item(),
717+
int(duration * asample_rate.item()),
718+
delta=0.1 * asample_rate.item(),
689719
)
690720

691721
def test_compare_read_video_from_memory_and_file(self):
@@ -859,7 +889,7 @@ def test_read_video_from_memory_get_pts_only(self):
859889
)
860890

861891
self.assertEqual(tv_result_pts_only[0].numel(), 0)
862-
self.assertEqual(tv_result_pts_only[4].numel(), 0)
892+
self.assertEqual(tv_result_pts_only[5].numel(), 0)
863893
self.compare_decoding_result(tv_result, tv_result_pts_only)
864894

865895
def test_read_video_in_range_from_memory(self):
@@ -899,9 +929,8 @@ def test_read_video_in_range_from_memory(self):
899929
audio_timebase_num,
900930
audio_timebase_den,
901931
)
902-
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = (
903-
tv_result
904-
)
932+
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
933+
atimebase, asample_rate, aduration = tv_result
905934
self.assertAlmostEqual(config.video_fps, vfps.item(), delta=0.01)
906935

907936
for num_frames in [4, 8, 16, 32, 64, 128]:
@@ -997,6 +1026,24 @@ def test_read_video_in_range_from_memory(self):
9971026
# and PyAv
9981027
self.compare_decoding_result(tv_result, pyav_result, config)
9991028

1029+
def test_probe_video_from_file(self):
1030+
"""
1031+
Test the case when decoder probes a video file
1032+
"""
1033+
for test_video, config in test_videos.items():
1034+
full_path = os.path.join(VIDEO_DIR, test_video)
1035+
probe_result = torch.ops.video_reader.probe_video_from_file(full_path)
1036+
self.check_probe_result(probe_result, config)
1037+
1038+
def test_probe_video_from_memory(self):
1039+
"""
1040+
Test the case when decoder probes a video in memory
1041+
"""
1042+
for test_video, config in test_videos.items():
1043+
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
1044+
probe_result = torch.ops.video_reader.probe_video_from_memory(video_tensor)
1045+
self.check_probe_result(probe_result, config)
1046+
10001047

10011048
if __name__ == '__main__':
10021049
unittest.main()

torchvision/csrc/cpu/video_reader/FfmpegAudioStream.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,7 @@ void FfmpegAudioStream::updateStreamDecodeParams() {
4949
mediaFormat_.format.audio.timeBaseDen =
5050
inputCtx_->streams[index_]->time_base.den;
5151
}
52+
mediaFormat_.format.audio.duration = inputCtx_->streams[index_]->duration;
5253
}
5354

5455
int FfmpegAudioStream::initFormat() {

torchvision/csrc/cpu/video_reader/FfmpegDecoder.cpp

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -220,6 +220,30 @@ int FfmpegDecoder::decodeMemory(
220220
return ret;
221221
}
222222

223+
int FfmpegDecoder::probeFile(
224+
unique_ptr<DecoderParameters> params,
225+
const string& fileName,
226+
DecoderOutput& decoderOutput) {
227+
VLOG(1) << "probe file: " << fileName;
228+
FfmpegAvioContext ioctx;
229+
return probeVideo(std::move(params), fileName, true, ioctx, decoderOutput);
230+
}
231+
232+
int FfmpegDecoder::probeMemory(
233+
unique_ptr<DecoderParameters> params,
234+
const uint8_t* buffer,
235+
int64_t size,
236+
DecoderOutput& decoderOutput) {
237+
VLOG(1) << "probe video data in memory";
238+
FfmpegAvioContext ioctx;
239+
int ret = ioctx.initAVIOContext(buffer, size);
240+
if (ret == 0) {
241+
ret =
242+
probeVideo(std::move(params), string(""), false, ioctx, decoderOutput);
243+
}
244+
return ret;
245+
}
246+
223247
void FfmpegDecoder::cleanUp() {
224248
if (formatCtx_) {
225249
for (auto& stream : streams_) {
@@ -320,6 +344,16 @@ int FfmpegDecoder::decodeLoop(
320344
return ret;
321345
}
322346

347+
int FfmpegDecoder::probeVideo(
348+
unique_ptr<DecoderParameters> params,
349+
const std::string& filename,
350+
bool isDecodeFile,
351+
FfmpegAvioContext& ioctx,
352+
DecoderOutput& decoderOutput) {
353+
params_ = std::move(params);
354+
return init(filename, isDecodeFile, ioctx, decoderOutput);
355+
}
356+
323357
bool FfmpegDecoder::initStreams() {
324358
for (auto it = params_->formats.begin(); it != params_->formats.end(); ++it) {
325359
AVMediaType mediaType;

0 commit comments

Comments
 (0)