diff --git a/docs/source/en/api/pipelines/ltx2.md b/docs/source/en/api/pipelines/ltx2.md
index 24776b42309e..c77efa09f594 100644
--- a/docs/source/en/api/pipelines/ltx2.md
+++ b/docs/source/en/api/pipelines/ltx2.md
@@ -106,8 +106,6 @@ video, audio = pipe(
     output_type="np",
     return_dict=False,
 )
-video = (video * 255).round().astype("uint8")
-video = torch.from_numpy(video)
 
 encode_video(
     video[0],
@@ -185,8 +183,6 @@ video, audio = pipe(
     output_type="np",
     return_dict=False,
 )
-video = (video * 255).round().astype("uint8")
-video = torch.from_numpy(video)
 
 encode_video(
     video[0],
diff --git a/src/diffusers/pipelines/ltx2/export_utils.py b/src/diffusers/pipelines/ltx2/export_utils.py
index 0bc7a59db228..05c1ae1929cd 100644
--- a/src/diffusers/pipelines/ltx2/export_utils.py
+++ b/src/diffusers/pipelines/ltx2/export_utils.py
@@ -13,10 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import Iterator
 from fractions import Fraction
-from typing import Optional
+from itertools import chain
+from typing import List, Optional, Union
 
+import numpy as np
+import PIL.Image
 import torch
+from tqdm import tqdm
 
 from ...utils import is_av_available
 
@@ -101,11 +106,54 @@ def _write_audio(
 
 
 def encode_video(
-    video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
+    video: Union[List[PIL.Image.Image], np.ndarray, torch.Tensor, Iterator[torch.Tensor]],
+    fps: int,
+    audio: Optional[torch.Tensor],
+    audio_sample_rate: Optional[int],
+    output_path: str,
+    video_chunks_number: int = 1,
 ) -> None:
-    video_np = video.cpu().numpy()
-
-    _, height, width, _ = video_np.shape
+    """
+    Encodes a video with audio using the PyAV library. Based on code from the original LTX-2 repo:
+    https://github.com/Lightricks/LTX-2/blob/4f410820b198e05074a1e92de793e3b59e9ab5a0/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L182
+
+    Args:
+        video (`List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`):
+            A video tensor of shape [frames, height, width, channels] with integer pixel values in [0, 255]. If the
+            input is a `np.ndarray`, it is expected to be a float array with values in [0, 1] (which is what pipelines
+            usually return with `output_type="np"`).
+        fps (`int`):
+            The frames per second (FPS) of the encoded video.
+        audio (`torch.Tensor`, *optional*):
+            An audio waveform of shape [audio_channels, samples].
+        audio_sample_rate (`int`, *optional*):
+            The sampling rate of the audio waveform. For LTX 2, this is typically 24000 (24 kHz).
+        output_path (`str`):
+            The path to save the encoded video to.
+        video_chunks_number (`int`, *optional*, defaults to `1`):
+            The number of chunks to split the video into for encoding. Each chunk will be encoded separately. The
+            number of chunks to use often depends on the tiling config for the video VAE.
+ """ + if isinstance(video, list) and isinstance(video[0], PIL.Image.Image): + # Pipeline output_type="pil"; assumes each image is in "RGB" mode + video_frames = [np.array(frame) for frame in video] + video = np.stack(video_frames, axis=0) + video = torch.from_numpy(video) + elif isinstance(video, np.ndarray): + # Pipeline output_type="np" + is_denormalized = np.logical_and(np.zeros_like(video) <= video, video <= np.ones_like(video)) + if np.all(is_denormalized): + video = (video * 255).round().astype("uint8") + video = torch.from_numpy(video) + + if isinstance(video, torch.Tensor): + # Split into video_chunks_number along the frame dimension + video = torch.tensor_split(video, video_chunks_number, dim=0) + video = iter(video) + + first_chunk = next(video) + + _, height, width, _ = first_chunk.shape container = av.open(output_path, mode="w") stream = container.add_stream("libx264", rate=int(fps)) @@ -119,10 +167,12 @@ def encode_video( audio_stream = _prepare_audio_stream(container, audio_sample_rate) - for frame_array in video_np: - frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") - for packet in stream.encode(frame): - container.mux(packet) + for video_chunk in tqdm(chain([first_chunk], video), total=video_chunks_number, desc="Encoding video chunks"): + video_chunk_cpu = video_chunk.to("cpu").numpy() + for frame_array in video_chunk_cpu: + frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) # Flush encoder for packet in stream.encode(): diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index a92a7a2c8869..cb01159a81a7 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -69,8 +69,6 @@ ... output_type="np", ... return_dict=False, ... ) - >>> video = (video * 255).round().astype("uint8") - >>> video = torch.from_numpy(video) >>> encode_video( ... video[0], diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 04d7ee89c52a..c120e1f010e9 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -75,8 +75,6 @@ ... output_type="np", ... return_dict=False, ... ) - >>> video = (video * 255).round().astype("uint8") - >>> video = torch.from_numpy(video) >>> encode_video( ... video[0], diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py index 340efd10f24f..b0db1bdee317 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py @@ -76,8 +76,6 @@ ... output_type="np", ... return_dict=False, ... )[0] - >>> video = (video * 255).round().astype("uint8") - >>> video = torch.from_numpy(video) >>> encode_video( ... video[0],