@@ -103,6 +103,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        interpolation_type (`str`, default `"linear"`, optional):
            interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
            [`"linear"`, `"log_linear"`].
+        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use Karras sigmas (the Karras et al. (2022) scheme) for the step sizes in the noise schedule
+            during sampling. If `True`, the sigmas are determined according to the sequence of noise levels {σ_i}
+            defined in Equation (5) of https://arxiv.org/pdf/2206.00364.pdf.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
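For context while reviewing: a minimal usage sketch of the new flag. This assumes the public `diffusers` import path and the constructor defaults already on main; the printed values are illustrative only.

```python
from diffusers import EulerDiscreteScheduler

# Enable the Karras noise schedule; every other config value keeps its default.
scheduler = EulerDiscreteScheduler(use_karras_sigmas=True)
scheduler.set_timesteps(num_inference_steps=30)

# With the flag on, the sigmas follow the rho=7 Karras spacing instead of
# linear interpolation of the training sigmas.
print(scheduler.sigmas[:5])
```

Since the flag goes through `__init__`, it should also be registered to the config and picked up by `from_config`/`from_pretrained` like the existing scheduler options.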
@@ -118,6 +122,7 @@ def __init__(
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        interpolation_type: str = "linear",
+        use_karras_sigmas: Optional[bool] = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -149,6 +154,7 @@ def __init__(
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.is_scale_input_called = False
+        self.use_karras_sigmas = use_karras_sigmas

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
@@ -187,6 +193,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
        timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        log_sigmas = np.log(sigmas)

        if self.config.interpolation_type == "linear":
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
@@ -198,6 +205,10 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
                " 'linear' or 'log_linear'"
            )

+        if self.use_karras_sigmas:
+            sigmas = self._convert_to_karras(in_sigmas=sigmas)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)
        if str(device).startswith("mps"):
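The branch above swaps the interpolated sigmas for the Karras et al. (2022) schedule and then maps each new sigma back to a (fractional) training timestep via `_sigma_to_t`. For intuition, the spacing rule from Equation (5) of the paper can be reproduced standalone; this sketch mirrors `_convert_to_karras` below, and the free function name here is illustrative, not part of the diff.

```python
import numpy as np

def karras_sigmas(sigma_min: float, sigma_max: float, n: int, rho: float = 7.0) -> np.ndarray:
    # Interpolate linearly in sigma**(1/rho) space, then raise back to the rho-th power,
    # which concentrates steps near sigma_min where denoising needs the most resolution.
    ramp = np.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(karras_sigmas(sigma_min=0.03, sigma_max=14.6, n=10))  # descending, dense near sigma_min
```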
@@ -206,6 +217,43 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
        else:
            self.timesteps = torch.from_numpy(timesteps).to(device=device)

+    def _sigma_to_t(self, sigma, log_sigmas):
+        # log of the query sigma
+        log_sigma = np.log(sigma)
+
+        # distance from the query to every log-sigma in the training table
+        dists = log_sigma - log_sigmas[:, np.newaxis]
+
+        # indices of the two table entries that bracket the query sigma
+        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+
+        low = log_sigmas[low_idx]
+        high = log_sigmas[high_idx]
+
+        # interpolation weight between the two bracketing entries
+        w = (low - log_sigma) / (low - high)
+        w = np.clip(w, 0, 1)
+
+        # transform the interpolation weight into a (fractional) timestep
+        t = (1 - w) * low_idx + w * high_idx
+        t = t.reshape(sigma.shape)
+        return t
+
+    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
+    def _convert_to_karras(self, in_sigmas: np.ndarray) -> np.ndarray:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        sigma_min: float = in_sigmas[-1].item()
+        sigma_max: float = in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, self.num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
    def step(
        self,
        model_output: torch.FloatTensor,
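To sanity-check `_sigma_to_t`: it returns the fractional index `t` at which linear interpolation of `log_sigmas` recovers `log(sigma)`, so the Karras sigmas land on continuous rather than integer timesteps. A toy round trip with made-up values, mirroring the bracketing-and-interpolation steps of the method above:

```python
import numpy as np

log_sigmas = np.log(np.array([0.1, 0.5, 2.0, 8.0]))  # toy ascending sigma table
log_sigma = np.log(1.0)                              # query sigma between entries 1 and 2

# Same steps as _sigma_to_t: find the bracketing entries, then interpolate.
dists = log_sigma - log_sigmas[:, np.newaxis]
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
w = np.clip((log_sigmas[low_idx] - log_sigma) / (log_sigmas[low_idx] - log_sigmas[high_idx]), 0, 1)
t = (1 - w) * low_idx + w * high_idx

# Interpolating the table at t reproduces log(sigma): t is the fractional timestep.
assert np.allclose(np.interp(t, np.arange(len(log_sigmas)), log_sigmas), log_sigma)
print(t)  # [1.5] for this table
```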