
Commit 6e8e1ed

Authored by nipunjindal and njindal
[2905]: Add Karras pattern to discrete euler (#2956)
* [2905]: Add Karras pattern to discrete euler
* [2905]: Add Karras pattern to discrete euler
* Review comments
* Review comments
* Review comments
* Review comments

---------

Co-authored-by: njindal <[email protected]>
1 parent 37b359b commit 6e8e1ed

File tree

2 files changed (+75, −0 lines):
  src/diffusers/schedulers/scheduling_euler_discrete.py
  tests/schedulers/test_scheduler_euler.py


src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 48 additions & 0 deletions
@@ -103,6 +103,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         interpolation_type (`str`, default `"linear"`, optional):
             interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
             [`"linear"`, `"log_linear"`].
+        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+            This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
+            noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
+            of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -118,6 +122,7 @@ def __init__(
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
         interpolation_type: str = "linear",
+        use_karras_sigmas: Optional[bool] = False,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -149,6 +154,7 @@ def __init__(
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
         self.timesteps = torch.from_numpy(timesteps)
         self.is_scale_input_called = False
+        self.use_karras_sigmas = use_karras_sigmas
 
     def scale_model_input(
         self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
@@ -187,6 +193,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
 
         timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        log_sigmas = np.log(sigmas)
 
         if self.config.interpolation_type == "linear":
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
@@ -198,6 +205,10 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
                 " 'linear' or 'log_linear'"
             )
 
+        if self.use_karras_sigmas:
+            sigmas = self._convert_to_karras(in_sigmas=sigmas)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
         if str(device).startswith("mps"):
@@ -206,6 +217,43 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         else:
             self.timesteps = torch.from_numpy(timesteps).to(device=device)
 
+    def _sigma_to_t(self, sigma, log_sigmas):
+        # get log sigma
+        log_sigma = np.log(sigma)
+
+        # get distribution
+        dists = log_sigma - log_sigmas[:, np.newaxis]
+
+        # get sigmas range
+        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+
+        low = log_sigmas[low_idx]
+        high = log_sigmas[high_idx]
+
+        # interpolate sigmas
+        w = (low - log_sigma) / (low - high)
+        w = np.clip(w, 0, 1)
+
+        # transform interpolation to time range
+        t = (1 - w) * low_idx + w * high_idx
+        t = t.reshape(sigma.shape)
+        return t
+
+    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        sigma_min: float = in_sigmas[-1].item()
+        sigma_max: float = in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, self.num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
     def step(
         self,
         model_output: torch.FloatTensor,
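
For context, a minimal sketch of how the new flag is exercised once this commit is applied; it assumes diffusers at this revision or later, and the step count of 10 is arbitrary:

    # Minimal sketch: build the Euler scheduler with the new flag and
    # inspect the resulting Karras noise schedule.
    from diffusers import EulerDiscreteScheduler

    scheduler = EulerDiscreteScheduler(use_karras_sigmas=True)
    scheduler.set_timesteps(num_inference_steps=10)

    # sigmas now follow the rho=7 schedule of Karras et al. (2022);
    # set_timesteps appends a trailing 0.0 for the final denoising step.
    print(scheduler.sigmas)
    # timesteps become fractional: each Karras sigma is mapped back to a
    # (possibly non-integer) training timestep via _sigma_to_t.
    print(scheduler.timesteps)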

tests/schedulers/test_scheduler_euler.py

Lines changed: 27 additions & 0 deletions
@@ -117,3 +117,30 @@ def test_full_loop_device(self):
 
         assert abs(result_sum.item() - 10.0807) < 1e-2
         assert abs(result_mean.item() - 0.0131) < 1e-3
+
+    def test_full_loop_device_karras_sigmas(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config, use_karras_sigmas=True)
+
+        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
+
+        generator = torch.manual_seed(0)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+
+        for t in scheduler.timesteps:
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample, generator=generator)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 124.52299499511719) < 1e-2
+        assert abs(result_mean.item() - 0.16213932633399963) < 1e-3
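
The schedule itself, sigma_i = (sigma_max^(1/rho) + (i/(n-1)) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho with rho = 7, is easy to reproduce outside the library. Below is a standalone NumPy sketch mirroring _convert_to_karras above; the helper name karras_sigmas and the example sigma range are illustrative, not part of diffusers:

    # Standalone sketch of Equation (5) from https://arxiv.org/pdf/2206.00364.pdf,
    # mirroring _convert_to_karras above.
    import numpy as np

    def karras_sigmas(sigma_min, sigma_max, n, rho=7.0):
        # Interpolate linearly in sigma**(1/rho) space, then raise back to
        # the rho power; rho=7 biases the n levels toward small sigmas.
        ramp = np.linspace(0, 1, n)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

    print(karras_sigmas(0.1, 10.0, 5))  # descends from sigma_max to sigma_min
    # Step sizes shrink as sigma falls, concentrating model evaluations
    # at low noise levels, which is the point of the Karras schedule.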
