52 changes: 43 additions & 9 deletions src/diffusers/video_processor.py
@@ -16,7 +16,7 @@
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import PIL.Image
import torch
import torch.nn.functional as F

@@ -26,9 +26,11 @@
class VideoProcessor(VaeImageProcessor):
r"""Simple video processor."""

def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor:
def preprocess_video(
self, video, height: Optional[int] = None, width: Optional[int] = None, **kwargs
) -> torch.Tensor:
r"""
Preprocesses input video(s).
Preprocesses input video(s). Keyword arguments will be forwarded to `VaeImageProcessor.preprocess`.

Args:
video (`List[PIL.Image]`, `List[List[PIL.Image]]`, `torch.Tensor`, `np.array`, `List[torch.Tensor]`, `List[np.array]`):
@@ -50,6 +52,10 @@ def preprocess_video(self, video, height: Optional[int] = None, width: Optional[
width (`int`, *optional*, defaults to `None`):
The width in preprocessed frames of the video. If `None`, will use `get_default_height_width()` to get
the default width.

Returns:
`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`:
A 5D tensor holding the batched channels-first video(s).
"""
if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5:
warnings.warn(
@@ -67,31 +73,59 @@ def preprocess_video(self, video, height: Optional[int] = None, width: Optional[
video = torch.cat(video, axis=0)

# ensure the input is a list of videos:
# - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray)
# - if it is a single video, it is converted to a list of one video.
# - if it is a batched array of videos (5d torch.Tensor or np.ndarray), it is converted to a list of video
# arrays (a list of 4d torch.Tensor or np.ndarray). `VaeImageProcessor.preprocess` will then treat the first
# (frame) dim as a batch dim.
# - if it is a single video, it is converted to a list of one video. (A single video is a list of images or a
# single imagelist.)
# - if it is a list of imagelists, it will be kept as is (already a list of videos).
# - if it is a single image, it is expanded to a single frame video and then to a list of one video. The
# expansion will depend on the image type:
# - PIL.Image.Image --> one element list of PIL.Image.Image
# - 3D np.ndarray --> interpret as (H, W, C), expand to (F=1, H, W, C)
# - 3D torch.Tensor --> interpret as (C, H, W), expand to (F=1, C, H, W)
if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5:
video = list(video)
elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video):
video = [video]
elif isinstance(video, list) and is_valid_image_imagelist(video[0]):
video = video
elif is_valid_image(video):
if isinstance(video, PIL.Image.Image):
video = [video]
elif isinstance(video, np.ndarray):
if video.ndim == 2:
video = np.expand_dims(video, axis=-1) # Unsqueeze channel dim in last axis
if video.ndim == 3:
video = np.expand_dims(video, axis=0)
else:
raise ValueError(f"Input numpy.ndarray is expected to have 2 or 3 dims but got {video.ndim} dims")
elif isinstance(video, torch.Tensor):
if video.ndim == 2:
video = torch.unsqueeze(video, dim=0) # Unsqueeze channel dim in first dim
if video.ndim == 3:
video = torch.unsqueeze(video, dim=0)
else:
raise ValueError(f"Input torch.Tensor is expected to have 2 or 3 dims but got {video.ndim} dims")
video = [video]
else:
raise ValueError(
"Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image"
)

video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0)
video = torch.stack([self.preprocess(img, height=height, width=width, **kwargs) for img in video], dim=0)

# move the number of channels before the number of frames.
video = video.permute(0, 2, 1, 3, 4)

return video
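For context, here is a minimal usage sketch (not part of the diff) of how `preprocess_video` is expected to normalize the accepted input types after this change. The shapes in the comments are assumptions based on an 8-frame, 3-channel 64x64 input and default processor settings.

```python
import numpy as np
import PIL.Image
from diffusers.video_processor import VideoProcessor

processor = VideoProcessor()

# A single video passed as a 4D array (frames, height, width, channels) with values in [0, 1]
# is wrapped into a one-video list before preprocessing.
video_np = np.random.rand(8, 64, 64, 3).astype(np.float32)
out = processor.preprocess_video(video_np, height=64, width=64)
print(out.shape)  # torch.Size([1, 3, 8, 64, 64]) -> (batch, channels, frames, height, width)

# With this PR, a single image should also be accepted and treated as a one-frame video.
frame = PIL.Image.new("RGB", (64, 64))
out = processor.preprocess_video(frame)
print(out.shape)  # torch.Size([1, 3, 1, 64, 64])
```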

def postprocess_video(
self, video: torch.Tensor, output_type: str = "np"
self, video: torch.Tensor, output_type: str = "np", **kwargs
Member:
What would kwargs facilitate?

Collaborator Author:
Currently, this would facilitate passing the do_denormalize flag to VaeImageProcessor.postprocess. But it's intended more as a forward-looking change which allows postprocess_video to support any arguments that postprocess might want.

) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]:
r"""
Converts a video tensor to a list of frames for export.
Converts a video tensor to a list of frames for export. Keyword arguments will be forwarded to
`VaeImageProcessor.postprocess`.

Args:
video (`torch.Tensor`): The video as a tensor.
@@ -101,7 +135,7 @@ def postprocess_video(
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = self.postprocess(batch_vid, output_type)
batch_output = self.postprocess(batch_vid, output_type, **kwargs)
outputs.append(batch_output)

if output_type == "np":
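As a rough illustration of what the forwarded kwargs enable (per the discussion above), the sketch below, which is not part of the diff, passes `do_denormalize` through `postprocess_video` to `VaeImageProcessor.postprocess`. It assumes the kwargs forwarding added in this PR; the flag list length matches the frame count because each video's frames form the batch that `postprocess` sees.

```python
import torch
from diffusers.video_processor import VideoProcessor

processor = VideoProcessor()

# Dummy decoded video in [-1, 1], shaped (batch, channels, frames, height, width).
video = torch.rand(1, 3, 8, 64, 64) * 2 - 1

# Kwargs such as `do_denormalize` are forwarded to `VaeImageProcessor.postprocess`;
# one flag per frame, since frames are the batch dimension passed to `postprocess`.
num_frames = video.shape[2]
frames = processor.postprocess_video(video, output_type="np", do_denormalize=[True] * num_frames)
print(frames.shape)  # (1, 8, 64, 64, 3) -> (batch, frames, height, width, channels)
```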