
enable flux pipeline compatible with unipc and dpm-solver #11908


Merged
merged 12 commits into from
Jul 16, 2025

Conversation

gameofdimension
Contributor

@gameofdimension gameofdimension commented Jul 10, 2025

With this PR, DPM-Solver and UniPC can be used with the Flux pipeline by switching the scheduler:

pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, use_flow_sigmas=True, prediction_type="flow_prediction")

or

pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config, use_flow_sigmas=True, prediction_type="flow_prediction")

What does this PR do?

Fixes #11907

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

have flux pipeline work with unipc/dpm schedulers
@gameofdimension gameofdimension changed the title from "enable flux pipeline compatible with unipc and dpm" to "enable flux pipeline compatible with unipc and dpm-solver" Jul 10, 2025
@yiyixuxu
Collaborator

ohh thanks so much @gameofdimension
does it make sense to add use_dynamic_shifting config to dpm/unipc when use_flow_sigmas is enabled?

essentially, are we able to:

  1. avoid making any changes to the pipeline
  2. add a use_dynamic_shifting config to the dpm/unipc schedulers, and accept an optional mu in their set_timesteps
  3. use this API at runtime to switch: pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_flow_sigmas=True, prediction_type="flow_prediction")
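
As a rough illustration of what accepting an optional mu amounts to, here is a minimal sketch that assumes the exponential time shift used in the patch quoted later in this thread (flow_shift = exp(mu), followed by the flow-shift formula). The helper name `shift_sigmas` and the value mu=1.15 are illustrative only and are not part of the diffusers API.

```python
import math


def shift_sigmas(sigmas, mu):
    """Apply an exponential time shift to flow sigmas in [0, 1].

    Hypothetical helper: mirrors `flow_shift = exp(mu)` and
    `flow_shift * s / (1 + (flow_shift - 1) * s)` from the patch in this thread.
    """
    flow_shift = math.exp(mu)
    return [flow_shift * s / (1 + (flow_shift - 1) * s) for s in sigmas]


num_inference_steps = 28
# Flux-style base schedule: 1.0 down to 1/num_inference_steps in equal steps.
base_sigmas = [1.0 - i / num_inference_steps for i in range(num_inference_steps)]

# mu=1.15 is an illustrative value, not a recommended setting.
shifted = shift_sigmas(base_sigmas, mu=1.15)
print([round(s, 4) for s in shifted[:3]])
```

With mu=0 the shift is the identity, so a scheduler that ignores mu would simply run the unshifted schedule.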

@gameofdimension
Contributor Author

Noted, I'll improve that

@hlky
Contributor

hlky commented Jul 14, 2025

The sigma schedule in UniPC etc. with use_flow_sigmas is specific to SANA.

The following patch directly supports passing sigmas to UniPC. Note that the patch applies to this PR, not main.

Existing behavior is maintained when sigmas=None: if self.config.use_flow_sigmas is set, SANA's sigma schedule is used; otherwise, the alphas_cumprod schedule is used. SANA pipelines therefore do not necessarily need updating, but the patch includes the required change to SanaPipeline.

Changes should be replicated to other use_flow_sigmas schedulers and SANA pipelines.

Passing sigmas when use_flow_sigmas=False would require further changes similar to

if num_inference_steps is None:
    num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1
self.num_inference_steps = num_inference_steps
if sigmas is not None:
    log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
    sigmas = np.array(sigmas).astype(np.float32)
    timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])

This is reflected by the NotImplementedError when sigmas is not None and not self.config.use_flow_sigmas.
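
The argument handling described here can be sketched in isolation. The function below is a hypothetical stand-alone helper (the name `resolve_num_inference_steps` does not exist in diffusers) that only mirrors the checks the patch below adds to set_timesteps; it is a sketch of the control flow, not the scheduler code itself.

```python
def resolve_num_inference_steps(num_inference_steps=None, sigmas=None, use_flow_sigmas=True):
    """Hypothetical helper mirroring the validation added to set_timesteps in the patch."""
    # Custom sigmas are only handled on the flow-sigma code path.
    if sigmas is not None and not use_flow_sigmas:
        raise NotImplementedError("Passing `sigmas` when `use_flow_sigmas=False` is not supported.")

    if num_inference_steps is not None:
        # If both are given, they must agree.
        if sigmas is not None and len(sigmas) != num_inference_steps:
            raise ValueError("`sigmas` should have the same length as `num_inference_steps`.")
        return num_inference_steps

    if sigmas is not None:
        # Infer the step count from the custom schedule.
        return len(sigmas)

    raise ValueError("One of `num_inference_steps` or `sigmas` should be provided.")


print(resolve_num_inference_steps(sigmas=[1.0, 0.75, 0.5, 0.25]))
```

Keeping sigmas=None as the default is what preserves the existing behavior for callers that only pass num_inference_steps.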

Patch

diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py
index 103f57a23..c0446e48c 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana.py
@@ -19,6 +19,7 @@ import urllib.parse as ul
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
 
@@ -902,6 +903,11 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
 
         # 4. Prepare timesteps
+
+        if sigmas is None:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(sigmas)[:-1].copy()
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler, num_inference_steps, device, timesteps, sigmas
         )
diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index 0125d256e..44f10378e 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -300,7 +300,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: int = None,
+        device: Union[str, torch.device] = None,
+        sigmas: Optional[List[float]] = None,
+        mu: Optional[float] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
@@ -312,8 +318,22 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         """
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
         if mu is not None:
-            assert self.config.use_dynamic_shifting and self.config.time_shift_type == 'exponential'
+            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
             self.config.flow_shift = np.exp(mu)
+
+        if sigmas is not None and not self.config.use_flow_sigmas:
+            raise NotImplementedError("Passing `sigmas` when `config.use_flow_sigmas=False` is not supported.")
+
+        if num_inference_steps is not None:
+            if sigmas is not None and len(sigmas) != num_inference_steps:
+                raise ValueError(
+                    "`sigmas` should have the same length as num_inference_steps, if `num_inference_steps` is provided"
+                )
+        elif sigmas is not None:
+            num_inference_steps = len(sigmas)
+        else:
+            raise ValueError("One of `num_inference_steps` or `sigmas` should be provided.")
+
         if self.config.timestep_spacing == "linspace":
             timesteps = (
                 np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
@@ -338,7 +358,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                 f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
             )
 
-        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        if sigmas is None:
+            if self.config.use_flow_sigmas:
+                alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+                sigmas = 1.0 - alphas
+                sigmas = np.flip(sigmas)[:-1].copy()
+            else:
+                sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         if self.config.use_karras_sigmas:
             log_sigmas = np.log(sigmas)
             sigmas = np.flip(sigmas).copy()
@@ -382,9 +408,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                 )
             sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
         elif self.config.use_flow_sigmas:
-            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
-            sigmas = 1.0 - alphas
-            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+            sigmas = self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas)
             timesteps = (sigmas * self.config.num_train_timesteps).copy()
             if self.config.final_sigmas_type == "sigma_min":
                 sigma_last = sigmas[-1]

Repro

from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
import numpy as np


def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


num_inference_steps = 28
image_seq_len = 4096

sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)

flow_match_euler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="scheduler"
)

mu = calculate_shift(
    image_seq_len,
    flow_match_euler.config.get("base_image_seq_len", 256),
    flow_match_euler.config.get("max_image_seq_len", 4096),
    flow_match_euler.config.get("base_shift", 0.5),
    flow_match_euler.config.get("max_shift", 1.15),
)

flow_match_euler.set_timesteps(sigmas=sigmas, mu=mu)

print(f"{flow_match_euler.sigmas=} | {flow_match_euler.timesteps=} | {len(flow_match_euler.sigmas)=} | {len(flow_match_euler.timesteps)=}")


unipc = UniPCMultistepScheduler.from_config(
    flow_match_euler.config, use_flow_sigmas=True, prediction_type="flow_prediction"
)

unipc.set_timesteps(sigmas=sigmas, mu=mu)

print(f"{unipc.sigmas=} | {unipc.timesteps=} | {len(unipc.sigmas)=} | {len(unipc.timesteps)=}")

unipc = UniPCMultistepScheduler.from_config(
    flow_match_euler.config, use_flow_sigmas=True, prediction_type="flow_prediction"
)

sana_alphas = np.linspace(1, 1 / unipc.config.num_train_timesteps, num_inference_steps + 1)
sana_sigmas = 1.0 - sana_alphas
sana_sigmas = np.flip(sana_sigmas)[:-1].copy()


unipc.set_timesteps(sigmas=sana_sigmas, mu=mu)

print(f"{unipc.sigmas=} | {unipc.timesteps=} | {len(unipc.sigmas)=} | {len(unipc.timesteps)=}")

Sigmas

Flow Match Euler

flow_match_euler.sigmas=tensor([1.0000, 0.9884, 0.9762, 0.9634, 0.9499, 0.9356, 0.9205, 0.9045, 0.8876, 0.8696, 0.8504, 0.8300, 0.8081, 0.7847, 0.7595, 0.7324, 0.7031, 0.6714, 0.6370, 0.5994, 0.5582, 0.5128, 0.4627, 0.4071, 0.3448, 0.2748, 0.1955, 0.1047, 0.0000]) |
flow_match_euler.timesteps=tensor([1000.0000,  988.4086,  976.2225,  963.3944,  949.8726,  935.5989,  920.5090,  904.5308,  887.5834,  869.5759,  850.4057,  829.9564,  808.0955,  784.6716,  759.5109,  732.4128,  703.1447,  671.4348,  636.9645,  599.3567,  558.1628,  512.8441,  462.7484,  407.0784,  344.8489,  274.8280,  195.4546,  104.7209]) |
len(flow_match_euler.sigmas)=29 | len(flow_match_euler.timesteps)=28

UniPC with Flux sigmas

unipc.sigmas=tensor([1.0000, 0.9884, 0.9762, 0.9634, 0.9499, 0.9356, 0.9205, 0.9045, 0.8876, 0.8696, 0.8504, 0.8300, 0.8081, 0.7847, 0.7595, 0.7324, 0.7031, 0.6714, 0.6370, 0.5994, 0.5582, 0.5128, 0.4627, 0.4071, 0.3448, 0.2748, 0.1955, 0.1047, 0.0000]) |
unipc.timesteps=tensor([1000,  988,  976,  963,  949,  935,  920,  904,  887,  869,  850,  829,  808,  784,  759,  732,  703,  671,  636,  599,  558,  512,  462,  407,  344,  274,  195,  104]) |
len(unipc.sigmas)=29 | len(unipc.timesteps)=28

UniPC with SANA sigmas

unipc.sigmas=tensor([0.9997, 0.9881, 0.9759, 0.9631, 0.9495, 0.9353, 0.9202, 0.9042, 0.8872, 0.8692, 0.8500, 0.8296, 0.8077, 0.7843, 0.7591, 0.7320, 0.7028, 0.6711, 0.6366, 0.5990, 0.5578, 0.5125, 0.4624, 0.4068, 0.3446, 0.2746, 0.1953, 0.1046, 0.0000]) |
unipc.timesteps=tensor([999, 988, 975, 963, 949, 935, 920, 904, 887, 869, 850, 829, 807, 784, 759, 732, 702, 671, 636, 599, 557, 512, 462, 406, 344, 274, 195, 104]) |
len(unipc.sigmas)=29 | len(unipc.timesteps)=28

@gameofdimension
Contributor Author

@hlky
"Passing sigmas to UniPC" is beyond the intended scope of this PR. Maybe we can do it in another PR.

@hlky
Contributor

hlky commented Jul 14, 2025

@gameofdimension You are more than welcome to apply this in a subsequent PR. However, the issue is caused by the lack of support for passing sigmas. Overriding sigmas to None does not technically fix the issue, and the quality of generated images will be affected by the different sigma schedule. The sigma schedules of SANA and Flux happen to be relatively close, but this is not always the case. The provided patch can be applied easily and fully fixes the issue in a future-proof manner.
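
To make "relatively close" concrete, the sketch below rebuilds both base schedules (before any flow shift is applied) in plain Python, following the expressions in the repro script above. It assumes num_train_timesteps=1000 and 28 steps; the variable names are illustrative only.

```python
num_train_timesteps = 1000  # assumed scheduler config value
num_inference_steps = 28

# Flux-style sigmas: linspace(1.0, 1/N, N).
flux_sigmas = [1.0 - i / num_inference_steps for i in range(num_inference_steps)]

# SANA-style sigmas: alphas = linspace(1, 1/T, N + 1); sigmas = flip(1 - alphas)[:-1].
step = (1.0 - 1.0 / num_train_timesteps) / num_inference_steps
alphas = [1.0 - k * step for k in range(num_inference_steps + 1)]
sana_sigmas = [1.0 - a for a in reversed(alphas)][:-1]

# Largest pointwise difference between the two unshifted schedules.
max_gap = max(abs(f - s) for f, s in zip(flux_sigmas, sana_sigmas))
print(f"{flux_sigmas[0]=:.4f} | {sana_sigmas[0]=:.4f} | {max_gap=:.6f}")
```

Under these assumptions the SANA schedule is the Flux schedule scaled by (1 - 1/num_train_timesteps), so the gap is at most 1/num_train_timesteps. A custom schedule with a different shape would not have this property, which is the case the patch is meant to cover.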

@yiyixuxu
Collaborator

yiyixuxu commented Jul 14, 2025

thanks @gameofdimension, can you show flux outputs with unipc and dpm-solver along with the default flow_match scheduler?
will merge after that

we can add custom sigma support in a different PR (but we need specific custom schedules that work well and that we can test out)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu
Collaborator

@bot /style

Contributor

github-actions bot commented Jul 14, 2025

Style bot fixed some files and pushed the changes.

@gameofdimension
Contributor Author

@yiyixuxu

prompt

'a motorcycle parked in an ornate bank lobby'

outputs

default

flux-diffusers-default-20

unipc

flux-diffusers-unipc-20

dpm-solver multistep

flux-diffusers-dpm-20

dpm-solver singlestep

flux-diffusers-sdpm-20

deis

flux-diffusers-deis-20

repro code snippet

    print("original scheduler", pipe.scheduler)
    if scheduler == 'dpm':
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config, use_flow_sigmas=True, prediction_type="flow_prediction")
    elif scheduler == 'unipc':
        pipe.scheduler = UniPCMultistepScheduler.from_config(
            pipe.scheduler.config, use_flow_sigmas=True, prediction_type="flow_prediction")
    elif scheduler == 'sdpm':
        pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
            pipe.scheduler.config, use_flow_sigmas=True, prediction_type="flow_prediction")
    elif scheduler == 'deis':
        pipe.scheduler = DEISMultistepScheduler.from_config(
            pipe.scheduler.config, use_flow_sigmas=True, prediction_type="flow_prediction")
    print("updated scheduler", pipe.scheduler)

Collaborator

@yiyixuxu yiyixuxu left a comment

thanks!

@yiyixuxu
Collaborator

cc @asomoza I think deis is a bit off, no?
not a merge blocker though, we can look into it separately

@asomoza
Member

asomoza commented Jul 16, 2025

@yiyixuxu in the example images deis looks bad, but I ran a test myself and it looks good; I don't see the same artifacts.

DPMSolverMultistepScheduler DEISMultistepScheduler
20250715204221_42 20250715204844_42

@gameofdimension can you please give it another go and test deis again to see if you still get those artifacts? Just so we know whether we have to do a follow-up PR or not.

@gameofdimension
Contributor Author

@asomoza
What is your num_inference_steps? Mine is 20.

@asomoza
Member

asomoza commented Jul 16, 2025

Oh, that's it. I use 28 steps to test Flux. You should probably use that, or the official 50 steps, so this doesn't happen. 20 steps is too low: even if some schedulers or images work, some won't, and that may cause misunderstandings.

@yiyixuxu yiyixuxu merged commit 5c52097 into huggingface:main Jul 16, 2025
9 of 11 checks passed
@gameofdimension gameofdimension deleted the patch-1 branch July 16, 2025 04:02
Development

Successfully merging this pull request may close these issues.

flux pipeline not work with DPMSolverMultistepScheduler and UniPCMultistepScheduler
5 participants