Skip to content

Commit d5b6ed4

Browse files
hlkysayakpaul
authored andcommitted
Add set_shift to FlowMatchEulerDiscreteScheduler (#10269)
1 parent 9edee8d commit d5b6ed4

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,19 @@ def __init__(
9999
self._step_index = None
100100
self._begin_index = None
101101

102+
self._shift = shift
103+
102104
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
103105
self.sigma_min = self.sigmas[-1].item()
104106
self.sigma_max = self.sigmas[0].item()
105107

108+
@property
109+
def shift(self):
110+
"""
111+
The value used for shifting.
112+
"""
113+
return self._shift
114+
106115
@property
107116
def step_index(self):
108117
"""
@@ -128,6 +137,9 @@ def set_begin_index(self, begin_index: int = 0):
128137
"""
129138
self._begin_index = begin_index
130139

140+
def set_shift(self, shift: float):
141+
self._shift = shift
142+
131143
def scale_noise(
132144
self,
133145
sample: torch.FloatTensor,
@@ -236,7 +248,7 @@ def set_timesteps(
236248
if self.config.use_dynamic_shifting:
237249
sigmas = self.time_shift(mu, 1.0, sigmas)
238250
else:
239-
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
251+
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
240252

241253
if self.config.shift_terminal:
242254
sigmas = self.stretch_shift_to_terminal(sigmas)

0 commit comments

Comments
 (0)