@@ -194,6 +194,7 @@ def __init__(
194194 )
195195
196196 self .stage = None
197+ self ._cache_backend = None
197198
198199 self .vae_scale_factor = 2 ** len (self .vae .temperal_downsample ) if getattr (self , "vae" , None ) else 8
199200 self .latent_channels = self .vae .config .z_dim if getattr (self , "vae" , None ) else 16
@@ -543,31 +544,39 @@ def diffuse(
543544 if image_latents is not None :
544545 latent_model_input = torch .cat ([latents , image_latents ], dim = 1 )
545546
546- noise_pred = self .transformer (
547- hidden_states = latent_model_input ,
548- timestep = timestep / 1000 ,
549- guidance = guidance ,
550- encoder_hidden_states_mask = prompt_embeds_mask ,
551- encoder_hidden_states = prompt_embeds ,
552- img_shapes = img_shapes ,
553- txt_seq_lens = txt_seq_lens ,
554- attention_kwargs = self .attention_kwargs ,
555- return_dict = False ,
556- )[0 ]
547+ transformer_kwargs = {
548+ "hidden_states" : latent_model_input ,
549+ "timestep" : timestep / 1000 ,
550+ "guidance" : guidance ,
551+ "encoder_hidden_states_mask" : prompt_embeds_mask ,
552+ "encoder_hidden_states" : prompt_embeds ,
553+ "img_shapes" : img_shapes ,
554+ "txt_seq_lens" : txt_seq_lens ,
555+ "attention_kwargs" : self .attention_kwargs ,
556+ "return_dict" : False ,
557+ }
558+ if self ._cache_backend is not None :
559+ transformer_kwargs ["cache_branch" ] = "positive"
560+
561+ noise_pred = self .transformer (** transformer_kwargs )[0 ]
557562 noise_pred = noise_pred [:, : latents .size (1 )]
558563
559564 if do_true_cfg :
560- neg_noise_pred = self .transformer (
561- hidden_states = latent_model_input ,
562- timestep = timestep / 1000 ,
563- guidance = guidance ,
564- encoder_hidden_states_mask = negative_prompt_embeds_mask ,
565- encoder_hidden_states = negative_prompt_embeds ,
566- img_shapes = img_shapes ,
567- txt_seq_lens = negative_txt_seq_lens ,
568- attention_kwargs = self .attention_kwargs ,
569- return_dict = False ,
570- )[0 ]
565+ neg_transformer_kwargs = {
566+ "hidden_states" : latent_model_input ,
567+ "timestep" : timestep / 1000 ,
568+ "guidance" : guidance ,
569+ "encoder_hidden_states_mask" : negative_prompt_embeds_mask ,
570+ "encoder_hidden_states" : negative_prompt_embeds ,
571+ "img_shapes" : img_shapes ,
572+ "txt_seq_lens" : negative_txt_seq_lens ,
573+ "attention_kwargs" : self .attention_kwargs ,
574+ "return_dict" : False ,
575+ }
576+ if self ._cache_backend is not None :
577+ neg_transformer_kwargs ["cache_branch" ] = "negative"
578+
579+ neg_noise_pred = self .transformer (** neg_transformer_kwargs )[0 ]
571580 neg_noise_pred = neg_noise_pred [:, : latents .size (1 )]
572581 comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred )
573582 cond_norm = torch .norm (noise_pred , dim = - 1 , keepdim = True )
0 commit comments