Skip to content

Commit f4708cf

Browse files
author
zyan3
committed
Allow input video to be passed as a torch.Tensor
1 parent 63cdd3e commit f4708cf

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

torchvision/io/_video_opt.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def _probe_video_from_file(filename):
205205

206206

207207
def _read_video_from_memory(
208-
file_buffer,
208+
video_data,
209209
seek_frame_margin=0.25,
210210
read_video_stream=1,
211211
video_width=0,
@@ -225,8 +225,8 @@ def _read_video_from_memory(
225225
226226
Args
227227
----------
228-
file_buffer : buffer
229-
buffer of compressed video content
228+
video_data : either 1) torch.Tensor with dtype=torch.uint8 or 2) python bytes
229+
compressed video content stored in either 1) a torch.Tensor or 2) python bytes
230230
seek_frame_margin: double, optional
231231
seeking frame in the stream is imprecise. Thus, when video_start_pts is specified,
232232
we seek the pts earlier by seek_frame_margin seconds
@@ -273,10 +273,11 @@ def _read_video_from_memory(
273273
_validate_pts(video_pts_range)
274274
_validate_pts(audio_pts_range)
275275

276-
video_tensor = torch.from_numpy(np.frombuffer(file_buffer, dtype=np.uint8))
276+
if not isinstance(video_data, torch.Tensor):
277+
video_data = torch.from_numpy(np.frombuffer(video_data, dtype=np.uint8))
277278

278279
result = torch.ops.video_reader.read_video_from_memory(
279-
video_tensor,
280+
video_data,
280281
seek_frame_margin,
281282
0, # getPtsOnly
282283
read_video_stream,
@@ -305,16 +306,16 @@ def _read_video_from_memory(
305306
return vframes, aframes, info
306307

307308

308-
def _read_video_timestamps_from_memory(file_buffer):
309+
def _read_video_timestamps_from_memory(video_data):
309310
"""
310311
Decode all frames in the video. Only pts (presentation timestamp) is returned.
311312
The actual frame pixel data is not copied. Thus, read_video_timestamps(...)
312313
is much faster than read_video(...)
313314
"""
314-
315-
video_tensor = torch.from_numpy(np.frombuffer(file_buffer, dtype=np.uint8))
315+
if not isinstance(video_data, torch.Tensor):
316+
video_data = torch.from_numpy(np.frombuffer(video_data, dtype=np.uint8))
316317
result = torch.ops.video_reader.read_video_from_memory(
317-
video_tensor,
318+
video_data,
318319
0, # seek_frame_margin
319320
1, # getPtsOnly
320321
1, # read_video_stream
@@ -342,15 +343,16 @@ def _read_video_timestamps_from_memory(file_buffer):
342343
return vframe_pts, aframe_pts, info
343344

344345

345-
def _probe_video_from_memory(file_buffer):
346+
def _probe_video_from_memory(video_data):
346347
"""
347348
Probe a video in memory.
348349
Return:
349350
info [dict]: contains video metadata, including video_timebase,
350351
video_duration, video_fps, audio_timebase, audio_duration, audio_sample_rate
351352
"""
352-
video_tensor = torch.from_numpy(np.frombuffer(file_buffer, dtype=np.uint8))
353-
result = torch.ops.video_reader.probe_video_from_memory(video_tensor)
353+
if not isinstance(video_data, torch.Tensor):
354+
video_data = torch.from_numpy(np.frombuffer(video_data, dtype=np.uint8))
355+
result = torch.ops.video_reader.probe_video_from_memory(video_data)
354356
vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
355357
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
356358
return info

0 commit comments

Comments
 (0)