from types import MethodType
from typing import Optional, Tuple, Union

import torch

from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.pipeline_utils import ImagePipelineOutput
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput
from einops import rearrange, reduce


BITS = 8


# Conversion to the bit representation and back, taken from
# https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py
# (a small round-trip sanity check follows the two helpers)
def decimal_to_bits(x, bits=BITS):
    """expects image tensor ranging from 0 to 1, outputs bit tensor ranging from -1 to 1"""
    device = x.device

    x = (x * 255).int().clamp(0, 255)

    mask = 2 ** torch.arange(bits - 1, -1, -1, device=device)
    mask = rearrange(mask, "d -> d 1 1")
    x = rearrange(x, "b c h w -> b c 1 h w")

    bits = ((x & mask) != 0).float()
    bits = rearrange(bits, "b c d h w -> b (c d) h w")
    bits = bits * 2 - 1
    return bits


def bits_to_decimal(x, bits=BITS):
    """expects bits from -1 to 1, outputs image tensor from 0 to 1"""
    device = x.device

    x = (x > 0).int()
    mask = 2 ** torch.arange(bits - 1, -1, -1, device=device, dtype=torch.int32)

    mask = rearrange(mask, "d -> d 1 1")
    x = rearrange(x, "b (c d) h w -> b c d h w", d=bits)
    dec = reduce(x * mask, "b c d h w -> b c h w", "sum")
    return (dec / 255).clamp(0.0, 1.0)
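

# Round-trip sanity check for the two helpers above (not part of the original example): an image
# quantized to 8 bits should survive decimal_to_bits -> bits_to_decimal unchanged, and the bit
# tensor should have `BITS` times as many channels, with values in {-1, +1}. This helper is purely
# illustrative and is never called by the pipeline.
def _check_bit_round_trip(batch=1, channels=3, size=4):
    image = torch.rand(batch, channels, size, size)
    bits = decimal_to_bits(image)
    recovered = bits_to_decimal(bits)
    assert bits.shape == (batch, channels * BITS, size, size)
    assert set(bits.unique().tolist()) <= {-1.0, 1.0}
    assert torch.allclose(recovered, (image * 255).int().float() / 255)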


# Scheduler step functions modified to clamp the predicted x_0 between -bit_scale and +bit_scale;
# the pipeline below binds the appropriate function to its scheduler instance.
def ddim_bit_scheduler_step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = True,
    generator=None,
    return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).
    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        eta (`float`): weight of noise for added noise in diffusion step.
        use_clipped_model_output (`bool`): if `True`, re-derive the model output from the clipped predicted x_0,
            as done in GLIDE.
        generator: random number generator.
        return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
    Returns:
        [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
        [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is the sample tensor.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    # Ideally, read the DDIM paper for an in-depth understanding.

    # Notation (<variable name> -> <name in paper>)
    # - pred_noise_t -> e_theta(x_t, t)
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - eta -> η
    # - pred_sample_direction -> "direction pointing to x_t"
    # - pred_prev_sample -> "x_t-1"

    # 1. get previous step value (=t-1)
    prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

    # 2. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[timestep]
    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

    beta_prod_t = 1 - alpha_prod_t

    # 3. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

    # 4. Clip "predicted x_0" to [-bit_scale, bit_scale]
    scale = self.bit_scale
    if self.config.clip_sample:
        pred_original_sample = torch.clamp(pred_original_sample, -scale, scale)

    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = self._get_variance(timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)

    if use_clipped_model_output:
        # the model_output is always re-derived from the clipped x_0 in GLIDE
        model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output

    # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

    if eta > 0:
        # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
        device = model_output.device if torch.is_tensor(model_output) else "cpu"
        noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
        variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise

        prev_sample = prev_sample + variance

    if not return_dict:
        return (prev_sample,)

    return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)


def ddpm_bit_scheduler_step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    predict_epsilon=True,
    generator=None,
    return_dict: bool = True,
) -> Union[DDPMSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).
    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        predict_epsilon (`bool`):
            optional flag; if `False`, the model is assumed to predict the sample directly instead of the noise
            (epsilon).
        generator: random number generator.
        return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
    Returns:
        [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
        [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is the sample tensor.
    """
    t = timestep

    if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
    else:
        predicted_variance = None

    # 1. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[t]
    alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    # 2. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    if predict_epsilon:
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    else:
        pred_original_sample = model_output

    # 3. Clip "predicted x_0"
    scale = self.bit_scale
    if self.config.clip_sample:
        pred_original_sample = torch.clamp(pred_original_sample, -scale, scale)

    # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
    current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t

    # 5. Compute predicted previous sample µ_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

    # 6. Add noise
    variance = 0
    if t > 0:
        noise = torch.randn(
            model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
        ).to(model_output.device)
        variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise

    pred_prev_sample = pred_prev_sample + variance

    if not return_dict:
        return (pred_prev_sample,)

    return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
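

# Convenience helper (not part of the original example): attach the clamped step function that
# matches the scheduler type to a scheduler instance, mirroring what the BitDiffusion constructor
# below does. `bit_scale` is the bound used when clipping the predicted x_0.
def patch_scheduler_with_bit_step(scheduler, bit_scale=1.0):
    scheduler.bit_scale = bit_scale
    scheduler.step = MethodType(
        ddim_bit_scheduler_step if isinstance(scheduler, DDIMScheduler) else ddpm_bit_scheduler_step, scheduler
    )
    return scheduler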


class BitDiffusion(DiffusionPipeline):
    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, DDPMScheduler],
        bit_scale: Optional[float] = 1.0,
    ):
        super().__init__()
        self.bit_scale = bit_scale

        # The clamped step functions above expect `self` to be the scheduler and read
        # `self.bit_scale`, so expose the scale on the scheduler and bind the matching
        # function to it as a method.
        scheduler.bit_scale = bit_scale
        scheduler.step = MethodType(
            ddim_bit_scheduler_step if isinstance(scheduler, DDIMScheduler) else ddpm_bit_scheduler_step, scheduler
        )

        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        height: Optional[int] = 256,
        width: Optional[int] = 256,
        num_inference_steps: Optional[int] = 50,
        generator: Optional[torch.Generator] = None,
        batch_size: Optional[int] = 1,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[Tuple, ImagePipelineOutput]:
        # Sample an initial "image" with in_channels // BITS channels and expand it to the
        # analog-bit representation so the channel count matches the UNet's in_channels.
        latents = torch.randn(
            (batch_size, self.unet.in_channels // BITS, height, width),
            generator=generator,
        )
        latents = decimal_to_bits(latents) * self.bit_scale
        latents = latents.to(self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # predict the noise residual
            noise_pred = self.unet(latents, t).sample

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # map the final bits back to pixel values in [0, 1] and to the (batch, height, width, channels)
        # numpy layout expected by numpy_to_pil
        image = bits_to_decimal(latents)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
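

if __name__ == "__main__":
    # Minimal smoke test / usage sketch, not part of the original example. It builds a small,
    # randomly initialized unconditional UNet2DModel (the denoising loop calls `unet(latents, t)`
    # with no text conditioning) whose in/out channels match the 3 * BITS analog-bit layout, and
    # runs a few DDIM steps at a low resolution. Replace the toy model with a UNet trained on
    # bit representations to get meaningful images; the output filename is only a placeholder.
    from diffusers import UNet2DModel

    unet = UNet2DModel(
        sample_size=32,
        in_channels=3 * BITS,
        out_channels=3 * BITS,
        layers_per_block=1,
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "UpBlock2D"),
    )
    scheduler = DDIMScheduler(num_train_timesteps=1000)
    pipe = BitDiffusion(unet=unet, scheduler=scheduler, bit_scale=1.0)

    images = pipe(height=32, width=32, num_inference_steps=5).images
    images[0].save("bit_diffusion_sample.png")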