Commit e38bc51

williamberman authored and dg845 committed
unet time embedding activation function (huggingface#3048)
* unet time embedding activation function
* typo act_fn -> time_embedding_act_fn
* flatten conditional
1 parent b80790f commit e38bc51
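
In short: the commit threads a new time_embedding_act_fn argument through UNet2DConditionModel (and its UNetFlatConditionModel mirror in the versatile diffusion pipeline), applying the chosen activation exactly once to the combined time/class embedding before it flows into the rest of the unet. A minimal usage sketch follows; the shrunken config values are illustrative assumptions, not anything from the commit:

import torch
from diffusers import UNet2DConditionModel

# Tiny config chosen only to keep the example light (assumed values).
unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    time_embedding_act_fn="silu",  # the new argument added by this commit
)

sample = torch.randn(1, 4, 32, 32)
encoder_hidden_states = torch.randn(1, 77, 32)
out = unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states).sample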

File tree

src/diffusers/models/unet_2d_condition.py
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

2 files changed: +42 -0 lines changed


src/diffusers/models/unet_2d_condition.py

Lines changed: 21 additions & 0 deletions
@@ -16,6 +16,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.utils.checkpoint
 
 from ..configuration_utils import ConfigMixin, register_to_config
@@ -101,6 +102,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             class conditioning with `class_embed_type` equal to `None`.
         time_embedding_type (`str`, *optional*, default to `positional`):
             The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+        time_embedding_act_fn (`str`, *optional*, default to `None`):
+            Optional activation function to use only once on the time embeddings before they are passed to the rest
+            of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
         timestep_post_act (`str`, *optional*, default to `None`):
             The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
         time_cond_proj_dim (`int`, *optional*, default to `None`):
@@ -157,6 +161,7 @@ def __init__(
         resnet_skip_time_act: bool = False,
         resnet_out_scale_factor: int = 1.0,
         time_embedding_type: str = "positional",
+        time_embedding_act_fn: Optional[str] = None,
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
         conv_in_kernel: int = 3,
@@ -267,6 +272,19 @@ def __init__(
         else:
             self.class_embedding = None
 
+        if time_embedding_act_fn is None:
+            self.time_embed_act = None
+        elif time_embedding_act_fn == "swish":
+            self.time_embed_act = lambda x: F.silu(x)
+        elif time_embedding_act_fn == "mish":
+            self.time_embed_act = nn.Mish()
+        elif time_embedding_act_fn == "silu":
+            self.time_embed_act = nn.SiLU()
+        elif time_embedding_act_fn == "gelu":
+            self.time_embed_act = nn.GELU()
+        else:
+            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
+
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
 
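A side note, not part of the diff: "swish" and "silu" are two names for the same function, x * sigmoid(x), so the "swish" branch's lambda over F.silu and the "silu" branch's nn.SiLU() module compute identical values. None of these activations carry parameters, so the module-vs-lambda distinction is cosmetic. A quick check:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(2, 1280)  # dummy tensor shaped like a time embedding
assert torch.equal(nn.SiLU()(x), F.silu(x))  # same op, two spellings
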
@@ -657,6 +675,9 @@ def forward(
         else:
             emb = emb + class_emb
 
+        if self.time_embed_act is not None:
+            emb = self.time_embed_act(emb)
+
         if self.encoder_hid_proj is not None:
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
 
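To see where the new lines sit, here is a hypothetical condensation of the conditioning path in forward; the attribute names (time_proj, time_embedding, class_embedding, time_embed_act) match the model, but the helper itself is illustrative:

def _condition(self, timesteps, class_labels=None):
    # Hypothetical summary of UNet2DConditionModel.forward's embedding steps.
    t_emb = self.time_proj(timesteps)        # positional (or fourier) projection
    emb = self.time_embedding(t_emb)         # MLP up to the embedding width
    if self.class_embedding is not None:
        emb = emb + self.class_embedding(class_labels)
    if self.time_embed_act is not None:      # new in this commit:
        emb = self.time_embed_act(emb)       # applied once, after the class embedding
    return emb
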
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 21 additions & 0 deletions
@@ -3,6 +3,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
@@ -182,6 +183,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             class conditioning with `class_embed_type` equal to `None`.
         time_embedding_type (`str`, *optional*, default to `positional`):
             The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+        time_embedding_act_fn (`str`, *optional*, default to `None`):
+            Optional activation function to use only once on the time embeddings before they are passed to the rest
+            of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
         timestep_post_act (`str`, *optional*, default to `None`):
             The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
         time_cond_proj_dim (`int`, *optional*, default to `None`):
@@ -243,6 +247,7 @@ def __init__(
         resnet_skip_time_act: bool = False,
         resnet_out_scale_factor: int = 1.0,
         time_embedding_type: str = "positional",
+        time_embedding_act_fn: Optional[str] = None,
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
         conv_in_kernel: int = 3,
@@ -359,6 +364,19 @@ def __init__(
         else:
             self.class_embedding = None
 
+        if time_embedding_act_fn is None:
+            self.time_embed_act = None
+        elif time_embedding_act_fn == "swish":
+            self.time_embed_act = lambda x: F.silu(x)
+        elif time_embedding_act_fn == "mish":
+            self.time_embed_act = nn.Mish()
+        elif time_embedding_act_fn == "silu":
+            self.time_embed_act = nn.SiLU()
+        elif time_embedding_act_fn == "gelu":
+            self.time_embed_act = nn.GELU()
+        else:
+            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
+
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
 
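The same if/elif ladder appears verbatim in both files, since the versatile diffusion flat unet mirrors UNet2DConditionModel. For contrast only, a hypothetical table-driven refactor of that selection logic, not what the commit does:

import torch.nn as nn

_TIME_EMBED_ACTS = {"swish": nn.SiLU, "silu": nn.SiLU, "mish": nn.Mish, "gelu": nn.GELU}

def make_time_embed_act(name):
    # Returns None for no activation, a fresh module otherwise.
    if name is None:
        return None
    if name not in _TIME_EMBED_ACTS:
        raise ValueError(f"Unsupported activation function: {name}")
    return _TIME_EMBED_ACTS[name]()
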
@@ -752,6 +770,9 @@ def forward(
         else:
             emb = emb + class_emb
 
+        if self.time_embed_act is not None:
+            emb = self.time_embed_act(emb)
+
         if self.encoder_hid_proj is not None:
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
 
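Since both __init__ methods are registered to the model config (note the register_to_config import in each file's context), the new argument should round-trip through serialization like any other constructor argument. A sketch, with the tiny config and the save path being assumptions:

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    time_embedding_act_fn="gelu",
)
unet.save_pretrained("/tmp/unet-with-act")          # hypothetical path
reloaded = UNet2DConditionModel.from_pretrained("/tmp/unet-with-act")
assert reloaded.config.time_embedding_act_fn == "gelu"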