enable flux pipeline compatible with unipc and dpm-solver #11908
have flux pipeline work with unipc/dpm schedulers
ohh thanks so much @gameofdimension essentially, are we able to:
Noted, I'll improve that |
The sigma schedule in UniPC etc. with The following patch directly supports passing Existing behavior is maintained when Changes should be replicated to other Passing diffusers/src/diffusers/schedulers/scheduling_euler_discrete.py Lines 366 to 373 in 3c8b67b
This is reflected by the
Patch
diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py
index 103f57a23..c0446e48c 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana.py
@@ -19,6 +19,7 @@ import urllib.parse as ul
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import numpy as np
 import torch
 from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
@@ -902,6 +903,11 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
         # 4. Prepare timesteps
+
+        if sigmas is None:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(sigmas)[:-1].copy()
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler, num_inference_steps, device, timesteps, sigmas
         )
diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index 0125d256e..44f10378e 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -300,7 +300,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: int = None,
+        device: Union[str, torch.device] = None,
+        sigmas: Optional[List[float]] = None,
+        mu: Optional[float] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -312,8 +318,22 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         """
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
         if mu is not None:
-            assert self.config.use_dynamic_shifting and self.config.time_shift_type == 'exponential'
+            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
             self.config.flow_shift = np.exp(mu)
+
+        if sigmas is not None and not self.config.use_flow_sigmas:
+            raise NotImplementedError("Passing `sigmas` when `config.use_flow_sigmas=False` is not supported.")
+
+        if num_inference_steps is not None:
+            if sigmas is not None and len(sigmas) != num_inference_steps:
+                raise ValueError(
+                    "`sigmas` should have the same length as num_inference_steps, if `num_inference_steps` is provided"
+                )
+        elif sigmas is not None:
+            num_inference_steps = len(sigmas)
+        else:
+            raise ValueError("One of `num_inference_steps` or `sigmas` should be provided.")
+
         if self.config.timestep_spacing == "linspace":
             timesteps = (
                 np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
@@ -338,7 +358,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                 f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
             )
-        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        if sigmas is None:
+            if self.config.use_flow_sigmas:
+                alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+                sigmas = 1.0 - alphas
+                sigmas = np.flip(sigmas)[:-1].copy()
+            else:
+                sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         if self.config.use_karras_sigmas:
             log_sigmas = np.log(sigmas)
             sigmas = np.flip(sigmas).copy()
@@ -382,9 +408,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                 )
             sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
         elif self.config.use_flow_sigmas:
-            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
-            sigmas = 1.0 - alphas
-            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+            sigmas = self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas)
             timesteps = (sigmas * self.config.num_train_timesteps).copy()
             if self.config.final_sigmas_type == "sigma_min":
                 sigma_last = sigmas[-1]
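The argument handling and the default flow sigmas in the patch above can be sketched in plain Python, without NumPy or diffusers. This is an illustrative sketch only, not the library code: the helper names `resolve_num_steps`, `default_flow_sigmas`, and `apply_flow_shift` are hypothetical and simply mirror the logic visible in the diff.

```python
from typing import List, Optional, Sequence


def resolve_num_steps(num_inference_steps: Optional[int], sigmas: Optional[Sequence[float]]) -> int:
    """Mirror the patch's validation of `num_inference_steps` against `sigmas` (hypothetical helper)."""
    if num_inference_steps is not None:
        if sigmas is not None and len(sigmas) != num_inference_steps:
            raise ValueError("`sigmas` should have the same length as `num_inference_steps`.")
        return num_inference_steps
    if sigmas is not None:
        # Only sigmas were given: the step count is inferred from their length.
        return len(sigmas)
    raise ValueError("One of `num_inference_steps` or `sigmas` should be provided.")


def default_flow_sigmas(num_train_timesteps: int, num_inference_steps: int) -> List[float]:
    """Pure-Python equivalent of the default schedule used when `sigmas` is None (hypothetical helper)."""
    n = num_inference_steps
    end = 1.0 / num_train_timesteps
    # Same values as np.linspace(1, 1 / num_train_timesteps, n + 1).
    alphas = [1.0 + i * (end - 1.0) / n for i in range(n + 1)]
    sigmas = [1.0 - a for a in alphas]
    # Flip to descending order and drop the trailing zero.
    return sigmas[::-1][:-1]


def apply_flow_shift(sigmas: Sequence[float], flow_shift: float) -> List[float]:
    """Apply the shift from the patch: shift * s / (1 + (shift - 1) * s)."""
    return [flow_shift * s / (1.0 + (flow_shift - 1.0) * s) for s in sigmas]


# Custom sigmas take precedence; the step count follows from their length.
custom = [1.0, 0.75, 0.5, 0.25]
steps = resolve_num_steps(None, custom)
print(steps, apply_flow_shift(custom, flow_shift=3.0))
```

The point of the split is that the shift is applied after the sigmas are chosen, so a caller-supplied schedule and the default schedule go through the same shifting step.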
Repro
from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
import numpy as np


def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


num_inference_steps = 28
image_seq_len = 4096
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)

flow_match_euler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="scheduler"
)
mu = calculate_shift(
    image_seq_len,
    flow_match_euler.config.get("base_image_seq_len", 256),
    flow_match_euler.config.get("max_image_seq_len", 4096),
    flow_match_euler.config.get("base_shift", 0.5),
    flow_match_euler.config.get("max_shift", 1.15),
)
flow_match_euler.set_timesteps(sigmas=sigmas, mu=mu)
print(f"{flow_match_euler.sigmas=} | {flow_match_euler.timesteps=} | {len(flow_match_euler.sigmas)=} | {len(flow_match_euler.timesteps)=}")

unipc = UniPCMultistepScheduler.from_config(
    flow_match_euler.config, use_flow_sigmas=True, prediction_type="flow_prediction"
)
unipc.set_timesteps(sigmas=sigmas, mu=mu)
print(f"{unipc.sigmas=} | {unipc.timesteps=} | {len(unipc.sigmas)=} | {len(unipc.timesteps)=}")

unipc = UniPCMultistepScheduler.from_config(
    flow_match_euler.config, use_flow_sigmas=True, prediction_type="flow_prediction"
)
sana_alphas = np.linspace(1, 1 / unipc.config.num_train_timesteps, num_inference_steps + 1)
sana_sigmas = 1.0 - sana_alphas
sana_sigmas = np.flip(sana_sigmas)[:-1].copy()
unipc.set_timesteps(sigmas=sana_sigmas, mu=mu)
print(f"{unipc.sigmas=} | {unipc.timesteps=} | {len(unipc.sigmas)=} | {len(unipc.timesteps)=}")
Sigmas
Flow Match Euler
UniPC with Flux sigmas
UniPC with SANA sigmas
@hlky |
@gameofdimension You are more than welcome to apply this in a subsequent PR. However, the issue is caused by the lack of support for passing sigmas. Overriding sigmas to None does not technically fix it, and the quality of generated images will be affected by the different sigma schedule. The sigma schedule happens to be relatively close between SANA and Flux, but this is not always the case. The provided patch can be easily applied and fully fixes the issue in a future-proof manner.
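To make "relatively close" concrete, the two schedules from the repro can be compared in plain Python. This is a rough sketch under stated assumptions: `num_train_timesteps=1000` is assumed rather than read from the scheduler config, the helper names `flux_sigmas`, `sana_sigmas`, and `shift` are hypothetical, and the shift uses `flow_shift = exp(mu)` as in the patch.

```python
import math
from typing import List, Sequence


def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
    # Same linear interpolation as in the repro.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b


def flux_sigmas(n: int) -> List[float]:
    # Pure-Python equivalent of np.linspace(1.0, 1 / n, n); requires n > 1.
    return [1.0 + i * (1.0 / n - 1.0) / (n - 1) for i in range(n)]


def sana_sigmas(num_train_timesteps: int, n: int) -> List[float]:
    # 1 - np.linspace(1, 1 / num_train_timesteps, n + 1), flipped, last entry dropped.
    end = 1.0 / num_train_timesteps
    alphas = [1.0 + i * (end - 1.0) / n for i in range(n + 1)]
    return [1.0 - a for a in alphas][::-1][:-1]


def shift(sigmas: Sequence[float], flow_shift: float) -> List[float]:
    # Shift formula from the patch.
    return [flow_shift * s / (1.0 + (flow_shift - 1.0) * s) for s in sigmas]


num_inference_steps = 28
num_train_timesteps = 1000  # assumption, not read from the FLUX.1-dev config
flow_shift = math.exp(calculate_shift(4096))

flux = shift(flux_sigmas(num_inference_steps), flow_shift)
sana = shift(sana_sigmas(num_train_timesteps, num_inference_steps), flow_shift)

# Largest per-step difference between the two shifted schedules.
max_gap = max(abs(a - b) for a, b in zip(flux, sana))
print(f"{flow_shift=:.4f} {max_gap=:.6f}")
```

Before shifting, these two particular schedules differ only by the constant factor `1 - 1 / num_train_timesteps`, which is why the gap is small here. A pipeline with a differently shaped default schedule would not have that property, which is the case the patch is meant to cover.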
thanks @gameofdimension, can you show Flux outputs with unipc and dpm-solver along with the default flow_match scheduler? We can add custom sigma support in a different PR (but we need specific custom schedules that work well and that we can test).
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
@bot /style |
Style bot fixed some files and pushed the changes. |
thanks!
cc @asomoza I think deis is a bit off, no? |
@yiyixuxu in the example images
@gameofdimension can you please give it another go and test |
@asomoza |
Oh, that's it. I use 28 steps to test Flux; you should probably use that or the official 50 steps so this doesn't happen. 20 steps is too low: even if some schedulers or images work, others won't, which may cause misunderstandings.
with this pr, we can use dpm/unipc together with flux pipeline by
or
What does this PR do?
Fixes #11907
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.