Skip to content

Commit f2f0f95

Browse files
authored
Text2video zero refinements (huggingface#3070)
* fix progress bar issue in pipeline_text_to_video_zero.py. Copy scheduler after first backward * fix tensor loading in test_text_to_video_zero.py * make style && make quality
1 parent 2e681ff commit f2f0f95

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
from dataclasses import dataclass
23
from typing import Callable, List, Optional, Union
34

@@ -56,8 +57,8 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
5657
is_cross_attention = encoder_hidden_states is not None
5758
if encoder_hidden_states is None:
5859
encoder_hidden_states = hidden_states
59-
elif attn.cross_attention_norm:
60-
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
60+
elif attn.norm_cross:
61+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
6162

6263
key = attn.to_k(encoder_hidden_states)
6364
value = attn.to_v(encoder_hidden_states)
@@ -285,7 +286,8 @@ def backward_loop(
285286
latents: latents of backward process output at time timesteps[-1]
286287
"""
287288
do_classifier_free_guidance = guidance_scale > 1.0
288-
with self.progress_bar(total=len(timesteps)) as progress_bar:
289+
num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order
290+
with self.progress_bar(total=num_steps) as progress_bar:
289291
for i, t in enumerate(timesteps):
290292
# expand the latents if we are doing classifier free guidance
291293
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -465,6 +467,7 @@ def __call__(
465467
extra_step_kwargs=extra_step_kwargs,
466468
num_warmup_steps=num_warmup_steps,
467469
)
470+
scheduler_copy = copy.deepcopy(self.scheduler)
468471

469472
# Perform the second backward process up to time T_0
470473
x_1_t0 = self.backward_loop(
@@ -475,7 +478,7 @@ def __call__(
475478
callback=callback,
476479
callback_steps=callback_steps,
477480
extra_step_kwargs=extra_step_kwargs,
478-
num_warmup_steps=num_warmup_steps,
481+
num_warmup_steps=0,
479482
)
480483

481484
# Propagate first frame latents at time T_0 to remaining frames
@@ -502,7 +505,7 @@ def __call__(
502505
b, l, d = prompt_embeds.size()
503506
prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d)
504507

505-
self.scheduler.set_timesteps(num_inference_steps, device=device)
508+
self.scheduler = scheduler_copy
506509
x_1k_0 = self.backward_loop(
507510
timesteps=timesteps[-t1 - 1 :],
508511
prompt_embeds=prompt_embeds,
@@ -511,7 +514,7 @@ def __call__(
511514
callback=callback,
512515
callback_steps=callback_steps,
513516
extra_step_kwargs=extra_step_kwargs,
514-
num_warmup_steps=num_warmup_steps,
517+
num_warmup_steps=0,
515518
)
516519
latents = x_1k_0
517520

utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
load_hf_numpy,
8787
load_image,
8888
load_numpy,
89+
load_pt,
8990
nightly,
9091
parse_flag_from_env,
9192
print_tensor_test,

0 commit comments

Comments
 (0)