# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest

import torch

from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel
from diffusers.utils.testing_utils import skip_mps, torch_device


class EMAModelTests(unittest.TestCase):
    model_id = "hf-internal-testing/tiny-stable-diffusion-pipe"
    batch_size = 1
    prompt_length = 77
    text_encoder_hidden_dim = 32
    num_in_channels = 4
    latent_height = latent_width = 64
    generator = torch.manual_seed(0)

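    # A rough sketch of the loop these tests emulate: after every (simulated) optimizer
    # update one calls `ema_unet.step(unet.parameters())`, which nudges each shadow
    # parameter toward the live one, roughly
    #     shadow = decay * shadow + (1 - decay) * param
    # (the exact per-step decay is up to EMAModel's schedule). At eval/save time,
    # `ema_unet.copy_to(unet.parameters())` would load the averaged weights back in.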
    def get_models(self, decay=0.9999):
        unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet").to(torch_device)
        ema_unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet").to(torch_device)
        ema_unet = EMAModel(
            ema_unet.parameters(), decay=decay, model_cls=UNet2DConditionModel, model_config=ema_unet.config
        )
        return unet, ema_unet

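    # Dummy inputs for `UNet2DConditionModel.forward`: noisy latents, integer timesteps
    # and text-encoder hidden states, shaped according to the class attributes above.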
    def get_dummy_inputs(self):
        noisy_latents = torch.randn(
            self.batch_size, self.num_in_channels, self.latent_height, self.latent_width, generator=self.generator
        ).to(torch_device)
        timesteps = torch.randint(0, 1000, size=(self.batch_size,), generator=self.generator).to(torch_device)
        encoder_hidden_states = torch.randn(
            self.batch_size, self.prompt_length, self.text_encoder_hidden_dim, generator=self.generator
        ).to(torch_device)
        return noisy_latents, timesteps, encoder_hidden_states

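    # Stand-in for an optimizer update: every parameter gets randomly perturbed. The
    # additive `randn_like(param)` term ensures that zero-initialized parameters also
    # end up non-zero (see `test_shadow_params_updated` below).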
    def simulate_backprop(self, unet):
        updated_state_dict = {}
        for k, param in unet.state_dict().items():
            updated_param = torch.randn_like(param) + (param * torch.randn_like(param))
            updated_state_dict.update({k: updated_param})
        unet.load_state_dict(updated_state_dict)
        return unet

    def test_optimization_steps_updated(self):
        unet, ema_unet = self.get_models()
        # Take the first (hypothetical) EMA step.
        ema_unet.step(unet.parameters())
        assert ema_unet.optimization_step == 1

        # Take two more.
        for _ in range(2):
            ema_unet.step(unet.parameters())
        assert ema_unet.optimization_step == 3

    def test_shadow_params_not_updated(self):
        unet, ema_unet = self.get_models()
        # Since `unet` is not being updated (i.e., backprop'd), there won't be any
        # difference between `unet`'s params and `ema_unet`'s shadow params, even if
        # we call `ema_unet.step(unet.parameters())`.
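        # (Each EMA step computes, roughly, decay * shadow + (1 - decay) * param; with
        # shadow == param that's a no-op, whatever the decay value happens to be.)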
        ema_unet.step(unet.parameters())
        orig_params = list(unet.parameters())
        for s_param, param in zip(ema_unet.shadow_params, orig_params):
            assert torch.allclose(s_param, param)

        # The above holds true even if we call `ema_unet.step()` multiple times, since
        # `unet`'s params are still not being updated.
        for _ in range(4):
            ema_unet.step(unet.parameters())
        for s_param, param in zip(ema_unet.shadow_params, orig_params):
            assert torch.allclose(s_param, param)

    def test_shadow_params_updated(self):
        unet, ema_unet = self.get_models()
        # Keep a copy of the initial (pretrained) parameters around, then simulate a
        # parameter update for `unet`. Since there might be some parameters which are
        # initialized to zero, we take extra care to initialize their values to
        # something non-zero before the multiplication.
        orig_params = [param.clone() for param in unet.parameters()]
        unet_pseudo_updated_step_one = self.simulate_backprop(unet)

        # Take the EMA step.
        ema_unet.step(unet_pseudo_updated_step_one.parameters())

        # Now the EMA'd (shadow) parameters should no longer equal the original model parameters.
        for s_param, param in zip(ema_unet.shadow_params, orig_params):
            assert not torch.allclose(s_param, param)

        # Ensure this is still the case when we take multiple EMA steps.
        for _ in range(4):
            ema_unet.step(unet.parameters())
        for s_param, param in zip(ema_unet.shadow_params, orig_params):
            assert not torch.allclose(s_param, param)

    def test_consecutive_shadow_params_updated(self):
        # If we call the EMA step after a backpropagation two times in a row,
        # the shadow params from those two steps should be different.
        unet, ema_unet = self.get_models()

        # First backprop + EMA. Note the `.clone()`: the shadow params are updated
        # in place, so we have to snapshot them to compare across steps.
        unet_step_one = self.simulate_backprop(unet)
        ema_unet.step(unet_step_one.parameters())
        step_one_shadow_params = [param.clone() for param in ema_unet.shadow_params]

        # Second backprop + EMA
        unet_step_two = self.simulate_backprop(unet_step_one)
        ema_unet.step(unet_step_two.parameters())
        step_two_shadow_params = ema_unet.shadow_params

        for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params):
            assert not torch.allclose(step_one, step_two)

    def test_zero_decay(self):
        # With `decay=0.0` the moving average retains no history: after every EMA
        # step the shadow params should simply equal the latest model params, no
        # matter how often we backprop.
        unet, ema_unet = self.get_models(decay=0.0)
        unet_step_one = self.simulate_backprop(unet)
        ema_unet.step(unet_step_one.parameters())
        for s_param, param in zip(ema_unet.shadow_params, unet_step_one.parameters()):
            assert torch.allclose(s_param, param)

        unet_step_two = self.simulate_backprop(unet_step_one)
        ema_unet.step(unet_step_two.parameters())
        for s_param, param in zip(ema_unet.shadow_params, unet_step_two.parameters()):
            assert torch.allclose(s_param, param)

    @skip_mps
    def test_serialization(self):
        unet, ema_unet = self.get_models()
        noisy_latents, timesteps, encoder_hidden_states = self.get_dummy_inputs()

        with tempfile.TemporaryDirectory() as tmpdir:
            ema_unet.save_pretrained(tmpdir)
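            # `save_pretrained` is expected to materialize the shadow (EMA) weights as a
            # regular `UNet2DConditionModel` checkpoint (that is what the `model_cls` and
            # `model_config` arguments in `get_models` are for), so we can reload it as a
            # plain UNet below.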
            loaded_unet = UNet2DConditionModel.from_pretrained(tmpdir)
            loaded_unet = loaded_unet.to(torch_device)

            # Since no EMA step has been performed, the two sets of weights are identical
            # and the outputs should match.
            output = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            output_loaded = loaded_unet(noisy_latents, timesteps, encoder_hidden_states).sample

            assert torch.allclose(output, output_loaded)