Skip to content

Commit f67b036

Browse files
committed
add register_to_config, pass training tests
1 parent 84e94d7 commit f67b036

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

src/diffusers/models/unet_rl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from diffusers.models.resnet import Downsample1D, ResidualTemporalBlock, Upsample1D
77

8-
from ..configuration_utils import ConfigMixin
8+
from ..configuration_utils import ConfigMixin, register_to_config
99
from ..modeling_utils import ModelMixin
1010
from .embeddings import get_timestep_embedding
1111

@@ -57,6 +57,7 @@ def forward(self, x):
5757

5858

5959
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
60+
@register_to_config
6061
def __init__(
6162
self,
6263
training_horizon=128,

tests/test_modeling_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,12 @@ def input_shape(self):
595595
def output_shape(self):
596596
return (4, 16, 14)
597597

598+
def test_ema_training(self):
599+
pass
600+
601+
def test_training(self):
602+
pass
603+
598604
def prepare_init_args_and_inputs_for_common(self):
599605
init_dict = {
600606
"training_horizon": 128,

0 commit comments

Comments
 (0)