Skip to content

Commit 8f95fb1

Browse files
Fix diffusion inferers to support diffusers-style schedulers
1 parent 0a8d945 commit 8f95fb1

3 files changed

Lines changed: 157 additions & 12 deletions

File tree

monai/inferers/inferer.py

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
import inspect
1415
import math
1516
import warnings
1617
from abc import ABC, abstractmethod
@@ -861,6 +862,42 @@ def __init__(self, scheduler: Scheduler) -> None: # type: ignore[override]
861862

862863
self.scheduler = scheduler
863864

865+
@staticmethod
866+
def _scheduler_step_supports_kwarg(scheduler: Scheduler, kwarg: str) -> bool:
867+
try:
868+
return kwarg in inspect.signature(scheduler.step).parameters
869+
except (TypeError, ValueError):
870+
return False
871+
872+
@staticmethod
873+
def _get_previous_sample_from_step_output(step_output: Any) -> torch.Tensor:
874+
if isinstance(step_output, tuple):
875+
return step_output[0]
876+
if isinstance(step_output, Mapping):
877+
return step_output["prev_sample"]
878+
if hasattr(step_output, "prev_sample"):
879+
return step_output.prev_sample
880+
raise TypeError("Unsupported scheduler.step output. Expected a tuple or an object with `prev_sample`.")
881+
882+
def _scheduler_step(
883+
self,
884+
scheduler: Scheduler,
885+
model_output: torch.Tensor,
886+
timestep: int | torch.Tensor,
887+
sample: torch.Tensor,
888+
next_timestep: int | torch.Tensor | None = None,
889+
) -> torch.Tensor:
890+
step_kwargs = {}
891+
if self._scheduler_step_supports_kwarg(scheduler, "return_dict"):
892+
step_kwargs["return_dict"] = False
893+
894+
if isinstance(scheduler, RFlowScheduler):
895+
step_output = scheduler.step(model_output, timestep, sample, next_timestep, **step_kwargs) # type: ignore
896+
else:
897+
step_output = scheduler.step(model_output, timestep, sample, **step_kwargs) # type: ignore
898+
899+
return self._get_previous_sample_from_step_output(step_output)
900+
864901
def __call__( # type: ignore[override]
865902
self,
866903
inputs: torch.Tensor,
@@ -940,7 +977,12 @@ def sample(
940977
scheduler = self.scheduler
941978
image = input_noise
942979

943-
all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))
980+
all_next_timesteps = torch.cat(
981+
(
982+
scheduler.timesteps[1:],
983+
torch.tensor([0], dtype=scheduler.timesteps.dtype, device=scheduler.timesteps.device),
984+
)
985+
)
944986
if verbose and has_tqdm:
945987
progress_bar = tqdm(
946988
zip(scheduler.timesteps, all_next_timesteps),
@@ -984,10 +1026,9 @@ def sample(
9841026
model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)
9851027

9861028
# 2. compute previous image: x_t -> x_t-1
987-
if not isinstance(scheduler, RFlowScheduler):
988-
image, _ = scheduler.step(model_output, t, image) # type: ignore
989-
else:
990-
image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore
1029+
image = self._scheduler_step(
1030+
scheduler=scheduler, model_output=model_output, timestep=t, sample=image, next_timestep=next_t
1031+
)
9911032
if save_intermediates and t % intermediate_steps == 0:
9921033
intermediates.append(image)
9931034

@@ -1046,7 +1087,7 @@ def get_likelihood(
10461087
total_kl = torch.zeros(inputs.shape[0]).to(inputs.device)
10471088
for t in progress_bar:
10481089
timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
1049-
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
1090+
noisy_image = scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
10501091
diffusion_model = (
10511092
partial(diffusion_model, seg=seg)
10521093
if isinstance(diffusion_model, SPADEDiffusionModelUNet)
@@ -1509,7 +1550,12 @@ def sample( # type: ignore[override]
15091550
scheduler = self.scheduler
15101551
image = input_noise
15111552

1512-
all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))
1553+
all_next_timesteps = torch.cat(
1554+
(
1555+
scheduler.timesteps[1:],
1556+
torch.tensor([0], dtype=scheduler.timesteps.dtype, device=scheduler.timesteps.device),
1557+
)
1558+
)
15131559
if verbose and has_tqdm:
15141560
progress_bar = tqdm(
15151561
zip(scheduler.timesteps, all_next_timesteps),
@@ -1583,10 +1629,9 @@ def sample( # type: ignore[override]
15831629
model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)
15841630

15851631
# 3. compute previous image: x_t -> x_t-1
1586-
if not isinstance(scheduler, RFlowScheduler):
1587-
image, _ = scheduler.step(model_output, t, image) # type: ignore
1588-
else:
1589-
image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore
1632+
image = self._scheduler_step(
1633+
scheduler=scheduler, model_output=model_output, timestep=t, sample=image, next_timestep=next_t
1634+
)
15901635

15911636
if save_intermediates and t % intermediate_steps == 0:
15921637
intermediates.append(image)
@@ -1647,7 +1692,7 @@ def get_likelihood( # type: ignore[override]
16471692
total_kl = torch.zeros(inputs.shape[0]).to(inputs.device)
16481693
for t in progress_bar:
16491694
timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
1650-
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
1695+
noisy_image = scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
16511696

16521697
diffuse = diffusion_model
16531698
if isinstance(diffusion_model, SPADEDiffusionModelUNet):

tests/inferers/test_diffusion_inferer.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,31 @@
5555
]
5656

5757

58+
class DiffusersLikeSchedulerOutput:
59+
def __init__(self, prev_sample: torch.Tensor, pred_original_sample: torch.Tensor) -> None:
60+
self.prev_sample = prev_sample
61+
self.pred_original_sample = pred_original_sample
62+
63+
64+
class DiffusersStyleDDPMScheduler(DDPMScheduler):
65+
def step(
66+
self,
67+
model_output: torch.Tensor,
68+
timestep: int,
69+
sample: torch.Tensor,
70+
generator: torch.Generator | None = None,
71+
return_dict: bool = True,
72+
):
73+
prev_sample, pred_original_sample = super().step(
74+
model_output=model_output, timestep=timestep, sample=sample, generator=generator
75+
)
76+
if return_dict:
77+
return DiffusersLikeSchedulerOutput(
78+
prev_sample=prev_sample, pred_original_sample=pred_original_sample
79+
)
80+
return prev_sample, pred_original_sample
81+
82+
5883
class TestDiffusionSamplingInferer(unittest.TestCase):
5984
@parameterized.expand(TEST_CASES)
6085
@skipUnless(has_einops, "Requires einops")
@@ -126,6 +151,23 @@ def test_ddpm_sampler(self, model_params, input_shape):
126151
)
127152
self.assertEqual(len(intermediates), 10)
128153

154+
@parameterized.expand(TEST_CASES)
155+
@skipUnless(has_einops, "Requires einops")
156+
def test_diffusers_style_ddpm_sampler(self, model_params, input_shape):
157+
model = DiffusionModelUNet(**model_params)
158+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
159+
model.to(device)
160+
model.eval()
161+
noise = torch.randn(input_shape).to(device)
162+
scheduler = DiffusersStyleDDPMScheduler(num_train_timesteps=1000)
163+
inferer = DiffusionInferer(scheduler=scheduler)
164+
scheduler.set_timesteps(num_inference_steps=10)
165+
sample, intermediates = inferer.sample(
166+
input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
167+
)
168+
self.assertEqual(sample.shape, noise.shape)
169+
self.assertEqual(len(intermediates), 10)
170+
129171
@parameterized.expand(TEST_CASES)
130172
@skipUnless(has_einops, "Requires einops")
131173
def test_ddim_sampler(self, model_params, input_shape):

tests/inferers/test_latent_diffusion_inferer.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,33 @@
313313
],
314314
]
315315

316+
TEST_CASES_DIFFUSERS = [TEST_CASES[0]]
317+
318+
319+
class DiffusersLikeSchedulerOutput:
320+
def __init__(self, prev_sample: torch.Tensor, pred_original_sample: torch.Tensor) -> None:
321+
self.prev_sample = prev_sample
322+
self.pred_original_sample = pred_original_sample
323+
324+
325+
class DiffusersStyleDDPMScheduler(DDPMScheduler):
326+
def step(
327+
self,
328+
model_output: torch.Tensor,
329+
timestep: int,
330+
sample: torch.Tensor,
331+
generator: torch.Generator | None = None,
332+
return_dict: bool = True,
333+
):
334+
prev_sample, pred_original_sample = super().step(
335+
model_output=model_output, timestep=timestep, sample=sample, generator=generator
336+
)
337+
if return_dict:
338+
return DiffusersLikeSchedulerOutput(
339+
prev_sample=prev_sample, pred_original_sample=pred_original_sample
340+
)
341+
return prev_sample, pred_original_sample
342+
316343

317344
class TestDiffusionSamplingInferer(unittest.TestCase):
318345
@parameterized.expand(TEST_CASES)
@@ -414,6 +441,37 @@ def test_sample_shape(
414441
)
415442
self.assertEqual(sample.shape, input_shape)
416443

444+
@parameterized.expand(TEST_CASES_DIFFUSERS)
445+
@skipUnless(has_einops, "Requires einops")
446+
def test_diffusers_style_ddpm_sample_shape(
447+
self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
448+
):
449+
if ae_model_type == "AutoencoderKL":
450+
stage_1 = AutoencoderKL(**autoencoder_params)
451+
else:
452+
stage_1 = VQVAE(**autoencoder_params)
453+
454+
if dm_model_type == "SPADEDiffusionModelUNet":
455+
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
456+
else:
457+
stage_2 = DiffusionModelUNet(**stage_2_params)
458+
459+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
460+
stage_1.to(device)
461+
stage_2.to(device)
462+
stage_1.eval()
463+
stage_2.eval()
464+
465+
noise = torch.randn(latent_shape).to(device)
466+
scheduler = DiffusersStyleDDPMScheduler(num_train_timesteps=1000)
467+
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
468+
scheduler.set_timesteps(num_inference_steps=10)
469+
470+
sample = inferer.sample(
471+
input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler
472+
)
473+
self.assertEqual(sample.shape, input_shape)
474+
417475
@parameterized.expand(TEST_CASES)
418476
@skipUnless(has_einops, "Requires einops")
419477
def test_sample_shape_with_cfg(

0 commit comments

Comments
 (0)