Skip to content

Commit 0cb7cf3

Browse files
prabhat00155facebook-github-bot
authored andcommitted
[fbsync] Fast seek implementation (#3179)
Summary: * modify processPacket to support fast seek * add fastSeek to ProcessPacket decoder definition * add fastseek flag to DecoderParametersStruct * add fastseek flag to the process packet call * no default params in C++ implementation * enable flag in C++ implementation * make order of parameters more normal * register new seek with python api * [somewhat broken] test suite for keyframes using pyav * revert " changes * add type annotations to init * Adding tests * linter * Flake doesn't show up :| * Change from unitest to pytest syntax * add return type Reviewed By: kazhang Differential Revision: D32216689 fbshipit-source-id: 695975c2930cb663ea82c83e4bc924a09e124a7d Co-authored-by: Prabhat Roy <[email protected]>
1 parent 2e71a67 commit 0cb7cf3

File tree

7 files changed

+66
-10
lines changed

7 files changed

+66
-10
lines changed

test/test_videoapi.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,43 @@ def test_fate_suite(self):
167167
assert metadata["subtitles"]["duration"] is not None
168168
os.remove(video_path)
169169

170+
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
171+
def test_keyframe_reading(self):
172+
for test_video, config in test_videos.items():
173+
full_path = os.path.join(VIDEO_DIR, test_video)
174+
175+
av_reader = av.open(full_path)
176+
# reduce streams to only keyframes
177+
av_stream = av_reader.streams.video[0]
178+
av_stream.codec_context.skip_frame = "NONKEY"
179+
180+
av_keyframes = []
181+
vr_keyframes = []
182+
if av_reader.streams.video:
183+
184+
# get all keyframes using pyav. Then, seek randomly into video reader
185+
# and assert that all the returned values are in AV_KEYFRAMES
186+
187+
for av_frame in av_reader.decode(av_stream):
188+
av_keyframes.append(float(av_frame.pts * av_frame.time_base))
189+
190+
if len(av_keyframes) > 1:
191+
video_reader = VideoReader(full_path, "video")
192+
for i in range(1, len(av_keyframes)):
193+
seek_val = (av_keyframes[i] + av_keyframes[i - 1]) / 2
194+
data = next(video_reader.seek(seek_val, True))
195+
vr_keyframes.append(data["pts"])
196+
197+
data = next(video_reader.seek(config.duration, True))
198+
vr_keyframes.append(data["pts"])
199+
200+
assert len(av_keyframes) == len(vr_keyframes)
201+
# NOTE: this video gets different keyframe with different
202+
# loaders (0.333 pyav, 0.666 for us)
203+
if test_video != "TrumanShow_wave_f_nm_np1_fr_med_26.avi":
204+
for i in range(len(av_keyframes)):
205+
assert av_keyframes[i] == approx(vr_keyframes[i], rel=0.001)
206+
170207

171208
if __name__ == "__main__":
172209
pytest.main([__file__])

torchvision/csrc/io/decoder/decoder.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -552,9 +552,9 @@ int Decoder::getFrame(size_t workingTimeInMs) {
552552
bool gotFrame = false;
553553
bool hasMsg = false;
554554
// packet either got consumed completely or not at all
555-
if ((result = processPacket(stream, &avPacket, &gotFrame, &hasMsg)) < 0) {
556-
LOG(ERROR) << "uuid=" << params_.loggingUuid
557-
<< " processPacket failed with code=" << result;
555+
if ((result = processPacket(
556+
stream, &avPacket, &gotFrame, &hasMsg, params_.fastSeek)) < 0) {
557+
LOG(ERROR) << "processPacket failed with code: " << result;
558558
break;
559559
}
560560

@@ -635,7 +635,8 @@ int Decoder::processPacket(
635635
Stream* stream,
636636
AVPacket* packet,
637637
bool* gotFrame,
638-
bool* hasMsg) {
638+
bool* hasMsg,
639+
bool fastSeek) {
639640
// decode package
640641
int result;
641642
DecoderOutputMessage msg;
@@ -648,7 +649,15 @@ int Decoder::processPacket(
648649
bool endInRange =
649650
params_.endOffset <= 0 || msg.header.pts <= params_.endOffset;
650651
inRange_.set(stream->getIndex(), endInRange);
651-
if (endInRange && msg.header.pts >= params_.startOffset) {
652+
// if fastseek is enabled, we're returning the first
653+
// frame that we decode after (potential) seek.
654+
// By default, we perform accurate seek to the closest
655+
// following frame
656+
bool startCondition = true;
657+
if (!fastSeek) {
658+
startCondition = msg.header.pts >= params_.startOffset;
659+
}
660+
if (endInRange && startCondition) {
652661
*hasMsg = true;
653662
push(std::move(msg));
654663
}

torchvision/csrc/io/decoder/decoder.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ class Decoder : public MediaDecoder {
7272
Stream* stream,
7373
AVPacket* packet,
7474
bool* gotFrame,
75-
bool* hasMsg);
75+
bool* hasMsg,
76+
bool fastSeek = false);
7677
void flushStreams();
7778
void cleanUp();
7879

torchvision/csrc/io/decoder/defs.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ struct DecoderParameters {
190190
bool listen{false};
191191
// don't copy frame body, only header
192192
bool headerOnly{false};
193+
// enable fast seek (seek only to keyframes)
194+
bool fastSeek{false};
193195
// interrupt init method on timeout
194196
bool preventStaleness{true};
195197
// seek tolerated accuracy (us)

torchvision/csrc/io/video/video.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ void Video::_getDecoderParams(
9898
int64_t getPtsOnly,
9999
std::string stream,
100100
long stream_id = -1,
101+
bool fastSeek = true,
101102
bool all_streams = false,
102103
int64_t num_threads = 1,
103104
double seekFrameMarginUs = 10) {
@@ -106,6 +107,7 @@ void Video::_getDecoderParams(
106107
params.timeoutMs = decoderTimeoutMs;
107108
params.startOffset = videoStartUs;
108109
params.seekAccuracy = seekFrameMarginUs;
110+
params.fastSeek = fastSeek;
109111
params.headerOnly = false;
110112
params.numThreads = num_threads;
111113

@@ -165,6 +167,7 @@ Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
165167
0, // headerOnly
166168
std::get<0>(current_stream), // stream info - remove that
167169
long(-1), // stream_id parsed from info above change to -2
170+
false, // fastseek: we're using the default param here
168171
true, // read all streams
169172
numThreads_ // global number of Threads for decoding
170173
);
@@ -246,6 +249,7 @@ bool Video::setCurrentStream(std::string stream = "video") {
246249
std::get<0>(current_stream), // stream
247250
long(std::get<1>(
248251
current_stream)), // stream_id parsed from info above change to -2
252+
false, // fastseek param set to 0 false by default (changed in seek)
249253
false, // read all streams
250254
numThreads_ // global number of threads
251255
);
@@ -263,14 +267,15 @@ c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>> Video::
263267
return streamsMetadata;
264268
}
265269

266-
void Video::Seek(double ts) {
270+
void Video::Seek(double ts, bool fastSeek = false) {
267271
// initialize the class variables used for seeking and retrurn
268272
_getDecoderParams(
269273
ts, // video start
270274
0, // headerOnly
271275
std::get<0>(current_stream), // stream
272276
long(std::get<1>(
273277
current_stream)), // stream_id parsed from info above change to -2
278+
fastSeek, // fastseek
274279
false, // read all streams
275280
numThreads_ // global number of threads
276281
);

torchvision/csrc/io/video/video.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ struct Video : torch::CustomClassHolder {
2323
std::tuple<std::string, int64_t> getCurrentStream() const;
2424
c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>>
2525
getStreamMetadata() const;
26-
void Seek(double ts);
26+
void Seek(double ts, bool fastSeek);
2727
bool setCurrentStream(std::string stream);
2828
std::tuple<torch::Tensor, double> Next();
2929

@@ -39,6 +39,7 @@ struct Video : torch::CustomClassHolder {
3939
int64_t getPtsOnly,
4040
std::string stream,
4141
long stream_id,
42+
bool fastSeek,
4243
bool all_streams,
4344
int64_t num_threads,
4445
double seekFrameMarginUs); // this needs to be improved

torchvision/io/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,19 +135,20 @@ def __next__(self) -> Dict[str, Any]:
135135
def __iter__(self) -> Iterator["VideoReader"]:
136136
return self
137137

138-
def seek(self, time_s: float) -> "VideoReader":
138+
def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader":
139139
"""Seek within current stream.
140140
141141
Args:
142142
time_s (float): seek time in seconds
143+
keyframes_only (bool): allow to seek only to keyframes
143144
144145
.. note::
145146
Current implementation is the so-called precise seek. This
146147
means following seek, call to :mod:`next()` will return the
147148
frame with the exact timestamp if it exists or
148149
the first frame with timestamp larger than ``time_s``.
149150
"""
150-
self._c.seek(time_s)
151+
self._c.seek(time_s, keyframes_only)
151152
return self
152153

153154
def get_metadata(self) -> Dict[str, Any]:

0 commit comments

Comments
 (0)