 14 |  14 |
 15 |  15 | import inspect
 16 |  16 | import warnings
 17 |     | -from typing import Callable, List, Optional, Union
    |  17 | +from typing import Any, Callable, Dict, List, Optional, Union
 18 |  18 |
 19 |  19 | import numpy as np
 20 |  20 | import PIL
@@ -744,6 +744,7 @@ def __call__(
744 | 744 |         return_dict: bool = True,
745 | 745 |         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
746 | 746 |         callback_steps: int = 1,
    | 747 | +        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
747 | 748 |     ):
748 | 749 |         r"""
749 | 750 |         Function invoked when calling the pipeline for generation.
@@ -815,7 +816,10 @@ def __call__(
815 | 816 |             callback_steps (`int`, *optional*, defaults to 1):
816 | 817 |                 The frequency at which the `callback` function will be called. If not specified, the callback will be
817 | 818 |                 called at every step.
818 |     | -
    | 819 | +            cross_attention_kwargs (`dict`, *optional*):
    | 820 | +                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
    | 821 | +                `self.processor` in
    | 822 | +                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
819 | 823 |         Examples:
820 | 824 |
821 | 825 |         ```py
@@ -966,9 +970,13 @@ def __call__(
966 | 970 |                 latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
967 | 971 |
968 | 972 |                 # predict the noise residual
969 |     | -                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
970 |     | -                    0
971 |     | -                ]
    | 973 | +                noise_pred = self.unet(
    | 974 | +                    latent_model_input,
    | 975 | +                    t,
    | 976 | +                    encoder_hidden_states=prompt_embeds,
    | 977 | +                    cross_attention_kwargs=cross_attention_kwargs,
    | 978 | +                    return_dict=False,
    | 979 | +                )[0]
972 | 980 |
973 | 981 |                 # perform guidance
974 | 982 |                 if do_classifier_free_guidance:
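In short, this change threads a new `cross_attention_kwargs` argument from the inpainting pipeline's `__call__` down into the UNet, so callers can pass extra keyword arguments to whatever attention processors are installed. Below is a minimal usage sketch, not part of the commit itself: the checkpoint name and the `dog.png`/`dog_mask.png` input files are placeholder assumptions, and the `"scale"` key is the one diffusers' LoRA attention processors consume; any other keys must match the installed processor's `__call__` signature.

```py
# Minimal sketch: pass cross_attention_kwargs through the inpainting pipeline.
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")

# Placeholder inputs; any RGB image + mask pair of matching size works.
init_image = Image.open("dog.png").convert("RGB").resize((512, 512))
mask_image = Image.open("dog_mask.png").convert("RGB").resize((512, 512))

# The dict is forwarded verbatim to the UNet's attention processors.
# E.g. after installing LoRA processors via pipe.unet.set_attn_processor(...),
# {"scale": 0.5} applies the LoRA weights at half strength.
result = pipe(
    prompt="a cat sitting on a bench",
    image=init_image,
    mask_image=mask_image,
    cross_attention_kwargs={"scale": 0.5},
).images[0]
result.save("inpainted.png")
```

Note that the default `CrossAttnProcessor` ignores keys it does not use, so passing `cross_attention_kwargs=None` (the default) keeps the pipeline's previous behavior unchanged.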