|
11 | 11 |
|
12 | 12 | from __future__ import annotations |
13 | 13 |
|
| 14 | +import inspect |
14 | 15 | import math |
15 | 16 | import warnings |
16 | 17 | from abc import ABC, abstractmethod |
@@ -861,6 +862,42 @@ def __init__(self, scheduler: Scheduler) -> None: # type: ignore[override] |
861 | 862 |
|
862 | 863 | self.scheduler = scheduler |
863 | 864 |
|
@staticmethod
def _scheduler_step_supports_kwarg(scheduler: Scheduler, kwarg: str) -> bool:
    """
    Report whether ``scheduler.step`` can be called with ``kwarg`` as a keyword argument.

    Args:
        scheduler: object whose ``step`` callable is inspected.
        kwarg: name of the keyword argument to look for.

    Returns:
        ``True`` only when ``step`` declares a parameter called ``kwarg`` that may be
        supplied as ``kwarg=value``. ``False`` when no such parameter exists, when the
        parameter is positional-only or variadic, or when the signature cannot be
        introspected.
    """
    try:
        parameter = inspect.signature(scheduler.step).parameters.get(kwarg)
    except (TypeError, ValueError):
        # inspect.signature raises TypeError for unsupported objects and
        # ValueError when no signature can be provided
        return False
    if parameter is None:
        return False
    # a positional-only or variadic parameter shares the name but rejects a
    # `name=value` call, so it must not be reported as supported
    return parameter.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
| 871 | + |
@staticmethod
def _get_previous_sample_from_step_output(step_output: Any) -> torch.Tensor:
    """
    Extract the previous-step sample from the value returned by ``scheduler.step``.

    Three output shapes are recognised, checked in this order:

    - a non-empty tuple, whose first element is returned;
    - a mapping containing the key ``"prev_sample"``;
    - any object exposing a ``prev_sample`` attribute.

    Args:
        step_output: value returned by a call to ``scheduler.step``.

    Returns:
        The element, mapping value or attribute selected as described above.

    Raises:
        TypeError: if ``step_output`` matches none of the recognised shapes, including
            an empty tuple or a mapping without ``"prev_sample"`` and without a
            ``prev_sample`` attribute.
    """
    # an empty tuple has no first element; let it reach the TypeError below
    if isinstance(step_output, tuple) and step_output:
        return step_output[0]
    # a mapping lacking the key may still expose the value as an attribute
    if isinstance(step_output, Mapping) and "prev_sample" in step_output:
        return step_output["prev_sample"]
    if hasattr(step_output, "prev_sample"):
        return step_output.prev_sample
    raise TypeError(
        "Unsupported scheduler.step output. Expected a tuple or an object with `prev_sample`."
        f" Got `{type(step_output).__name__}`."
    )
| 881 | + |
def _scheduler_step(
    self,
    scheduler: Scheduler,
    model_output: torch.Tensor,
    timestep: int | torch.Tensor,
    sample: torch.Tensor,
    next_timestep: int | torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Call ``scheduler.step`` and return only the previous-step sample.

    ``return_dict=False`` is passed when ``_scheduler_step_supports_kwarg`` reports that
    ``step`` accepts it. For an ``RFlowScheduler`` the value of ``next_timestep`` is
    forwarded as a fourth positional argument; every other scheduler receives only
    ``model_output``, ``timestep`` and ``sample``.

    Args:
        scheduler: scheduler whose ``step`` method is called.
        model_output: first positional argument forwarded to ``step``.
        timestep: second positional argument forwarded to ``step``.
        sample: third positional argument forwarded to ``step``.
        next_timestep: forwarded to ``step`` only when ``scheduler`` is an ``RFlowScheduler``.

    Returns:
        The value that ``_get_previous_sample_from_step_output`` extracts from the ``step`` result.
    """
    accepts_return_dict = self._scheduler_step_supports_kwarg(scheduler, "return_dict")
    keyword_args = {"return_dict": False} if accepts_return_dict else {}

    positional_args = [model_output, timestep, sample]
    if isinstance(scheduler, RFlowScheduler):
        # only the RFlow signature takes the upcoming timestep
        positional_args.append(next_timestep)

    raw_output = scheduler.step(*positional_args, **keyword_args)  # type: ignore
    return self._get_previous_sample_from_step_output(raw_output)
| 900 | + |
864 | 901 | def __call__( # type: ignore[override] |
865 | 902 | self, |
866 | 903 | inputs: torch.Tensor, |
@@ -940,7 +977,12 @@ def sample( |
940 | 977 | scheduler = self.scheduler |
941 | 978 | image = input_noise |
942 | 979 |
|
943 | | - all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype))) |
| 980 | + all_next_timesteps = torch.cat( |
| 981 | + ( |
| 982 | + scheduler.timesteps[1:], |
| 983 | + torch.tensor([0], dtype=scheduler.timesteps.dtype, device=scheduler.timesteps.device), |
| 984 | + ) |
| 985 | + ) |
944 | 986 | if verbose and has_tqdm: |
945 | 987 | progress_bar = tqdm( |
946 | 988 | zip(scheduler.timesteps, all_next_timesteps), |
@@ -984,10 +1026,9 @@ def sample( |
984 | 1026 | model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond) |
985 | 1027 |
|
986 | 1028 | # 2. compute previous image: x_t -> x_t-1 |
987 | | - if not isinstance(scheduler, RFlowScheduler): |
988 | | - image, _ = scheduler.step(model_output, t, image) # type: ignore |
989 | | - else: |
990 | | - image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore |
| 1029 | + image = self._scheduler_step( |
| 1030 | + scheduler=scheduler, model_output=model_output, timestep=t, sample=image, next_timestep=next_t |
| 1031 | + ) |
991 | 1032 | if save_intermediates and t % intermediate_steps == 0: |
992 | 1033 | intermediates.append(image) |
993 | 1034 |
|
@@ -1046,7 +1087,7 @@ def get_likelihood( |
1046 | 1087 | total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) |
1047 | 1088 | for t in progress_bar: |
1048 | 1089 | timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() |
1049 | | - noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) |
| 1090 | + noisy_image = scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) |
1050 | 1091 | diffusion_model = ( |
1051 | 1092 | partial(diffusion_model, seg=seg) |
1052 | 1093 | if isinstance(diffusion_model, SPADEDiffusionModelUNet) |
@@ -1509,7 +1550,12 @@ def sample( # type: ignore[override] |
1509 | 1550 | scheduler = self.scheduler |
1510 | 1551 | image = input_noise |
1511 | 1552 |
|
1512 | | - all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype))) |
| 1553 | + all_next_timesteps = torch.cat( |
| 1554 | + ( |
| 1555 | + scheduler.timesteps[1:], |
| 1556 | + torch.tensor([0], dtype=scheduler.timesteps.dtype, device=scheduler.timesteps.device), |
| 1557 | + ) |
| 1558 | + ) |
1513 | 1559 | if verbose and has_tqdm: |
1514 | 1560 | progress_bar = tqdm( |
1515 | 1561 | zip(scheduler.timesteps, all_next_timesteps), |
@@ -1583,10 +1629,9 @@ def sample( # type: ignore[override] |
1583 | 1629 | model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond) |
1584 | 1630 |
|
1585 | 1631 | # 3. compute previous image: x_t -> x_t-1 |
1586 | | - if not isinstance(scheduler, RFlowScheduler): |
1587 | | - image, _ = scheduler.step(model_output, t, image) # type: ignore |
1588 | | - else: |
1589 | | - image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore |
| 1632 | + image = self._scheduler_step( |
| 1633 | + scheduler=scheduler, model_output=model_output, timestep=t, sample=image, next_timestep=next_t |
| 1634 | + ) |
1590 | 1635 |
|
1591 | 1636 | if save_intermediates and t % intermediate_steps == 0: |
1592 | 1637 | intermediates.append(image) |
@@ -1647,7 +1692,7 @@ def get_likelihood( # type: ignore[override] |
1647 | 1692 | total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) |
1648 | 1693 | for t in progress_bar: |
1649 | 1694 | timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() |
1650 | | - noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) |
| 1695 | + noisy_image = scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) |
1651 | 1696 |
|
1652 | 1697 | diffuse = diffusion_model |
1653 | 1698 | if isinstance(diffusion_model, SPADEDiffusionModelUNet): |
|
0 commit comments