|
16 | 16 |
|
17 | 17 | import torch
|
18 | 18 | import torch.nn as nn
|
| 19 | +import torch.nn.functional as F |
19 | 20 | import torch.utils.checkpoint
|
20 | 21 |
|
21 | 22 | from ..configuration_utils import ConfigMixin, register_to_config
|
@@ -101,6 +102,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
101 | 102 | class conditioning with `class_embed_type` equal to `None`.
|
102 | 103 | time_embedding_type (`str`, *optional*, default to `positional`):
|
103 | 104 | The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
| 105 | + time_embedding_act_fn (`str`, *optional*, default to `None`): |
| 106 | + Optional activation function applied once to the time embeddings before they are passed to the rest |
| 107 | + of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. |
104 | 108 | timestep_post_act (`str`, *optional*, default to `None`):
|
105 | 109 | The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
106 | 110 | time_cond_proj_dim (`int`, *optional*, default to `None`):
|
@@ -157,6 +161,7 @@ def __init__(
|
157 | 161 | resnet_skip_time_act: bool = False,
|
158 | 162 | resnet_out_scale_factor: float = 1.0,
|
159 | 163 | time_embedding_type: str = "positional",
|
| 164 | + time_embedding_act_fn: Optional[str] = None, |
160 | 165 | timestep_post_act: Optional[str] = None,
|
161 | 166 | time_cond_proj_dim: Optional[int] = None,
|
162 | 167 | conv_in_kernel: int = 3,
|
@@ -267,6 +272,19 @@ def __init__(
|
267 | 272 | else:
|
268 | 273 | self.class_embedding = None
|
269 | 274 |
|
| 275 | + if time_embedding_act_fn is None: |
| 276 | + self.time_embed_act = None |
| 277 | + elif time_embedding_act_fn == "swish": |
| 278 | + self.time_embed_act = lambda x: F.silu(x) |
| 279 | + elif time_embedding_act_fn == "mish": |
| 280 | + self.time_embed_act = nn.Mish() |
| 281 | + elif time_embedding_act_fn == "silu": |
| 282 | + self.time_embed_act = nn.SiLU() |
| 283 | + elif time_embedding_act_fn == "gelu": |
| 284 | + self.time_embed_act = nn.GELU() |
| 285 | + else: |
| 286 | + raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}") |
| 287 | + |
270 | 288 | self.down_blocks = nn.ModuleList([])
|
271 | 289 | self.up_blocks = nn.ModuleList([])
|
272 | 290 |
|
@@ -657,6 +675,9 @@ def forward(
|
657 | 675 | else:
|
658 | 676 | emb = emb + class_emb
|
659 | 677 |
|
| 678 | + if self.time_embed_act is not None: |
| 679 | + emb = self.time_embed_act(emb) |
| 680 | + |
660 | 681 | if self.encoder_hid_proj is not None:
|
661 | 682 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
662 | 683 |
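Taken together, the hunks above add a `time_embedding_act_fn` argument to `UNet2DConditionModel.__init__`, map it to an activation module (`nn.SiLU`, `nn.Mish`, `nn.GELU`, or a SiLU lambda for `"swish"`), and apply it once to the combined time/class embedding in `forward`. The snippet below is a minimal usage sketch, assuming a diffusers build that already includes this change; the tiny block sizes, channel counts, and random tensors are illustrative assumptions (in the spirit of the library's reduced test configs), not values taken from this PR.

```python
import torch
from diffusers import UNet2DConditionModel

# Deliberately tiny UNet so the sketch runs quickly; every size below is an
# illustrative assumption, and only `time_embedding_act_fn` is the new option.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    time_embedding_act_fn="silu",  # extra activation applied once to the time embedding
)

sample = torch.randn(1, 4, 32, 32)              # dummy noisy latents
timestep = torch.tensor(10)                     # dummy diffusion step
encoder_hidden_states = torch.randn(1, 77, 32)  # dummy text-encoder states
out = unet(sample, timestep, encoder_hidden_states).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])
```

Because the new argument defaults to `None`, `self.time_embed_act` stays `None` for existing configs and `forward` skips the extra activation, so previously saved models keep loading and running unchanged.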
|
|