Skip to content

Commit e95a3d2

Browse files
authored
Restructure the video/video_reader C++ codebase (#3311)
* Moving registration of video methods in Video.cpp and removing unnecessary includes.
* Rename files according to cpp styles.
* Adding namespaces and moving private methods to anonymous namespaces.
* Syncing method names.
* Fixing minor issues.
1 parent 691ec6d commit e95a3d2

File tree

6 files changed

+202
-149
lines changed

6 files changed

+202
-149
lines changed

torchvision/csrc/io/video/register.cpp

Lines changed: 0 additions & 14 deletions
This file was deleted.

torchvision/csrc/io/video/Video.cpp renamed to torchvision/csrc/io/video/video.cpp

Lines changed: 28 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,11 @@
1-
#include "Video.h"
2-
#include <c10/util/Logging.h>
3-
#include <torch/script.h>
4-
#include "defs.h"
5-
#include "memory_buffer.h"
6-
#include "sync_decoder.h"
1+
#include "video.h"
72

8-
using namespace std;
9-
using namespace ffmpeg;
3+
#include <regex>
4+
5+
namespace vision {
6+
namespace video {
7+
8+
namespace {
109

1110
const size_t decoderTimeoutMs = 600000;
1211
const AVPixelFormat defaultVideoPixelFormat = AV_PIX_FMT_RGB24;
@@ -93,6 +92,8 @@ std::tuple<std::string, long> _parseStream(const std::string& streamString) {
9392
return std::make_tuple(type_, index_);
9493
}
9594

95+
} // namespace
96+
9697
void Video::_getDecoderParams(
9798
double videoStartS,
9899
int64_t getPtsOnly,
@@ -159,7 +160,7 @@ Video::Video(std::string videoPath, std::string stream) {
159160
Video::_getDecoderParams(
160161
0, // video start
161162
0, // headerOnly
162-
get<0>(current_stream), // stream info - remove that
163+
std::get<0>(current_stream), // stream info - remove that
163164
long(-1), // stream_id parsed from info above change to -2
164165
true // read all streams
165166
);
@@ -209,9 +210,9 @@ Video::Video(std::string videoPath, std::string stream) {
209210

210211
succeeded = Video::setCurrentStream(stream);
211212
LOG(INFO) << "\nDecoder inited with: " << succeeded << "\n";
212-
if (get<1>(current_stream) != -1) {
213+
if (std::get<1>(current_stream) != -1) {
213214
LOG(INFO)
214-
<< "Stream index set to " << get<1>(current_stream)
215+
<< "Stream index set to " << std::get<1>(current_stream)
215216
<< ". If you encounter trouble, consider switching it to automatic stream discovery. \n";
216217
}
217218
} // video
@@ -229,8 +230,8 @@ bool Video::setCurrentStream(std::string stream = "video") {
229230
_getDecoderParams(
230231
ts, // video start
231232
0, // headerOnly
232-
get<0>(current_stream), // stream
233-
long(get<1>(
233+
std::get<0>(current_stream), // stream
234+
long(std::get<1>(
234235
current_stream)), // stream_id parsed from info above change to -2
235236
false // read all streams
236237
);
@@ -253,8 +254,8 @@ void Video::Seek(double ts) {
253254
_getDecoderParams(
254255
ts, // video start
255256
0, // headerOnly
256-
get<0>(current_stream), // stream
257-
long(get<1>(
257+
std::get<0>(current_stream), // stream
258+
long(std::get<1>(
258259
current_stream)), // stream_id parsed from info above change to -2
259260
false // read all streams
260261
);
@@ -319,3 +320,15 @@ std::tuple<torch::Tensor, double> Video::Next() {
319320

320321
return std::make_tuple(outFrame, frame_pts_s);
321322
}
323+
324+
static auto registerVideo =
325+
torch::class_<Video>("torchvision", "Video")
326+
.def(torch::init<std::string, std::string>())
327+
.def("get_current_stream", &Video::getCurrentStream)
328+
.def("set_current_stream", &Video::setCurrentStream)
329+
.def("get_metadata", &Video::getStreamMetadata)
330+
.def("seek", &Video::Seek)
331+
.def("next", &Video::Next);
332+
333+
} // namespace video
334+
} // namespace vision

torchvision/csrc/io/video/Video.h renamed to torchvision/csrc/io/video/video.h

Lines changed: 10 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,21 +1,16 @@
11
#pragma once
22

3-
#include <map>
4-
#include <regex>
5-
#include <string>
6-
#include <vector>
3+
#include <torch/types.h>
74

8-
#include <ATen/ATen.h>
9-
#include <c10/util/Logging.h>
10-
#include <torch/script.h>
11-
12-
#include <exception>
13-
#include "defs.h"
14-
#include "memory_buffer.h"
15-
#include "sync_decoder.h"
5+
#include "../decoder/defs.h"
6+
#include "../decoder/memory_buffer.h"
7+
#include "../decoder/sync_decoder.h"
168

179
using namespace ffmpeg;
1810

11+
namespace vision {
12+
namespace video {
13+
1914
struct Video : torch::CustomClassHolder {
2015
std::tuple<std::string, long> current_stream; // stream type, id
2116
// global video metadata
@@ -58,3 +53,6 @@ struct Video : torch::CustomClassHolder {
5853
DecoderParameters params;
5954

6055
}; // struct Video
56+
57+
} // namespace video
58+
} // namespace vision

torchvision/csrc/io/video_reader/VideoReader.h

Lines changed: 0 additions & 3 deletions
This file was deleted.

torchvision/csrc/io/video_reader/VideoReader.cpp renamed to torchvision/csrc/io/video_reader/video_reader.cpp

Lines changed: 109 additions & 105 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,9 @@
1-
#include "VideoReader.h"
2-
#include <ATen/ATen.h>
1+
#include "video_reader.h"
2+
33
#include <Python.h>
4-
#include <c10/util/Logging.h>
5-
#include <exception>
6-
#include "memory_buffer.h"
7-
#include "sync_decoder.h"
84

9-
using namespace std;
10-
using namespace ffmpeg;
5+
#include "../decoder/memory_buffer.h"
6+
#include "../decoder/sync_decoder.h"
117

128
// If we are in a Windows environment, we need to define
139
// initialization functions for the _custom_ops extension
@@ -18,8 +14,13 @@ PyMODINIT_FUNC PyInit_video_reader(void) {
1814
}
1915
#endif
2016

17+
using namespace ffmpeg;
18+
19+
namespace vision {
2120
namespace video_reader {
2221

22+
namespace {
23+
2324
const AVPixelFormat defaultVideoPixelFormat = AV_PIX_FMT_RGB24;
2425
const AVSampleFormat defaultAudioSampleFormat = AV_SAMPLE_FMT_FLT;
2526
const AVRational timeBaseQ = AVRational{1, AV_TIME_BASE};
@@ -417,95 +418,6 @@ torch::List<torch::Tensor> readVideo(
417418
return result;
418419
}
419420

420-
torch::List<torch::Tensor> readVideoFromMemory(
421-
torch::Tensor input_video,
422-
double seekFrameMargin,
423-
int64_t getPtsOnly,
424-
int64_t readVideoStream,
425-
int64_t width,
426-
int64_t height,
427-
int64_t minDimension,
428-
int64_t maxDimension,
429-
int64_t videoStartPts,
430-
int64_t videoEndPts,
431-
int64_t videoTimeBaseNum,
432-
int64_t videoTimeBaseDen,
433-
int64_t readAudioStream,
434-
int64_t audioSamples,
435-
int64_t audioChannels,
436-
int64_t audioStartPts,
437-
int64_t audioEndPts,
438-
int64_t audioTimeBaseNum,
439-
int64_t audioTimeBaseDen) {
440-
return readVideo(
441-
false,
442-
input_video,
443-
"", // videoPath
444-
seekFrameMargin,
445-
getPtsOnly,
446-
readVideoStream,
447-
width,
448-
height,
449-
minDimension,
450-
maxDimension,
451-
videoStartPts,
452-
videoEndPts,
453-
videoTimeBaseNum,
454-
videoTimeBaseDen,
455-
readAudioStream,
456-
audioSamples,
457-
audioChannels,
458-
audioStartPts,
459-
audioEndPts,
460-
audioTimeBaseNum,
461-
audioTimeBaseDen);
462-
}
463-
464-
torch::List<torch::Tensor> readVideoFromFile(
465-
std::string videoPath,
466-
double seekFrameMargin,
467-
int64_t getPtsOnly,
468-
int64_t readVideoStream,
469-
int64_t width,
470-
int64_t height,
471-
int64_t minDimension,
472-
int64_t maxDimension,
473-
int64_t videoStartPts,
474-
int64_t videoEndPts,
475-
int64_t videoTimeBaseNum,
476-
int64_t videoTimeBaseDen,
477-
int64_t readAudioStream,
478-
int64_t audioSamples,
479-
int64_t audioChannels,
480-
int64_t audioStartPts,
481-
int64_t audioEndPts,
482-
int64_t audioTimeBaseNum,
483-
int64_t audioTimeBaseDen) {
484-
torch::Tensor dummy_input_video = torch::ones({0});
485-
return readVideo(
486-
true,
487-
dummy_input_video,
488-
videoPath,
489-
seekFrameMargin,
490-
getPtsOnly,
491-
readVideoStream,
492-
width,
493-
height,
494-
minDimension,
495-
maxDimension,
496-
videoStartPts,
497-
videoEndPts,
498-
videoTimeBaseNum,
499-
videoTimeBaseDen,
500-
readAudioStream,
501-
audioSamples,
502-
audioChannels,
503-
audioStartPts,
504-
audioEndPts,
505-
audioTimeBaseNum,
506-
audioTimeBaseDen);
507-
}
508-
509421
torch::List<torch::Tensor> probeVideo(
510422
bool isReadFile,
511423
const torch::Tensor& input_video,
@@ -650,20 +562,112 @@ torch::List<torch::Tensor> probeVideo(
650562
return result;
651563
}
652564

653-
torch::List<torch::Tensor> probeVideoFromMemory(torch::Tensor input_video) {
565+
} // namespace
566+
567+
torch::List<torch::Tensor> read_video_from_memory(
568+
torch::Tensor input_video,
569+
double seekFrameMargin,
570+
int64_t getPtsOnly,
571+
int64_t readVideoStream,
572+
int64_t width,
573+
int64_t height,
574+
int64_t minDimension,
575+
int64_t maxDimension,
576+
int64_t videoStartPts,
577+
int64_t videoEndPts,
578+
int64_t videoTimeBaseNum,
579+
int64_t videoTimeBaseDen,
580+
int64_t readAudioStream,
581+
int64_t audioSamples,
582+
int64_t audioChannels,
583+
int64_t audioStartPts,
584+
int64_t audioEndPts,
585+
int64_t audioTimeBaseNum,
586+
int64_t audioTimeBaseDen) {
587+
return readVideo(
588+
false,
589+
input_video,
590+
"", // videoPath
591+
seekFrameMargin,
592+
getPtsOnly,
593+
readVideoStream,
594+
width,
595+
height,
596+
minDimension,
597+
maxDimension,
598+
videoStartPts,
599+
videoEndPts,
600+
videoTimeBaseNum,
601+
videoTimeBaseDen,
602+
readAudioStream,
603+
audioSamples,
604+
audioChannels,
605+
audioStartPts,
606+
audioEndPts,
607+
audioTimeBaseNum,
608+
audioTimeBaseDen);
609+
}
610+
611+
torch::List<torch::Tensor> read_video_from_file(
612+
std::string videoPath,
613+
double seekFrameMargin,
614+
int64_t getPtsOnly,
615+
int64_t readVideoStream,
616+
int64_t width,
617+
int64_t height,
618+
int64_t minDimension,
619+
int64_t maxDimension,
620+
int64_t videoStartPts,
621+
int64_t videoEndPts,
622+
int64_t videoTimeBaseNum,
623+
int64_t videoTimeBaseDen,
624+
int64_t readAudioStream,
625+
int64_t audioSamples,
626+
int64_t audioChannels,
627+
int64_t audioStartPts,
628+
int64_t audioEndPts,
629+
int64_t audioTimeBaseNum,
630+
int64_t audioTimeBaseDen) {
631+
torch::Tensor dummy_input_video = torch::ones({0});
632+
return readVideo(
633+
true,
634+
dummy_input_video,
635+
videoPath,
636+
seekFrameMargin,
637+
getPtsOnly,
638+
readVideoStream,
639+
width,
640+
height,
641+
minDimension,
642+
maxDimension,
643+
videoStartPts,
644+
videoEndPts,
645+
videoTimeBaseNum,
646+
videoTimeBaseDen,
647+
readAudioStream,
648+
audioSamples,
649+
audioChannels,
650+
audioStartPts,
651+
audioEndPts,
652+
audioTimeBaseNum,
653+
audioTimeBaseDen);
654+
}
655+
656+
torch::List<torch::Tensor> probe_video_from_memory(torch::Tensor input_video) {
654657
return probeVideo(false, input_video, "");
655658
}
656659

657-
torch::List<torch::Tensor> probeVideoFromFile(std::string videoPath) {
660+
torch::List<torch::Tensor> probe_video_from_file(std::string videoPath) {
658661
torch::Tensor dummy_input_video = torch::ones({0});
659662
return probeVideo(true, dummy_input_video, videoPath);
660663
}
661664

662-
} // namespace video_reader
663-
664665
TORCH_LIBRARY_FRAGMENT(video_reader, m) {
665-
m.def("read_video_from_memory", video_reader::readVideoFromMemory);
666-
m.def("read_video_from_file", video_reader::readVideoFromFile);
667-
m.def("probe_video_from_memory", video_reader::probeVideoFromMemory);
668-
m.def("probe_video_from_file", video_reader::probeVideoFromFile);
666+
m.def("read_video_from_memory", read_video_from_memory);
667+
m.def("read_video_from_file", read_video_from_file);
668+
m.def("probe_video_from_memory", probe_video_from_memory);
669+
m.def("probe_video_from_file", probe_video_from_file);
669670
}
671+
672+
} // namespace video_reader
673+
} // namespace vision

0 commit comments

Comments
 (0)