diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
index 26a41d7335c5..dfc7978a2ee2 100644
--- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
@@ -14,6 +14,7 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
 
+import math
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -44,6 +45,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             range is [0.2, 80.0].
         sigma_data (`float`, *optional*, defaults to 0.5):
             The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
+        sigma_schedule (`str`, *optional*, defaults to `karras`):
+            Sigma schedule to compute the `sigmas`. By default, we use the schedule introduced in the EDM paper
+            (https://arxiv.org/abs/2206.00364). The other accepted value is "exponential". The exponential
+            schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
         num_train_timesteps (`int`, defaults to 1000):
             The number of diffusion steps to train the model.
         solver_order (`int`, defaults to 2):
@@ -89,6 +94,7 @@ def __init__(
         sigma_min: float = 0.002,
         sigma_max: float = 80.0,
         sigma_data: float = 0.5,
+        sigma_schedule: str = "karras",
         num_train_timesteps: int = 1000,
         prediction_type: str = "epsilon",
         rho: float = 7.0,
@@ -121,7 +127,11 @@ def __init__(
         )
 
         ramp = torch.linspace(0, 1, num_train_timesteps)
-        sigmas = self._compute_sigmas(ramp)
+        if sigma_schedule == "karras":
+            sigmas = self._compute_karras_sigmas(ramp)
+        elif sigma_schedule == "exponential":
+            sigmas = self._compute_exponential_sigmas(ramp)
+
         self.timesteps = self.precondition_noise(sigmas)
 
         self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
@@ -236,7 +246,10 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
         self.num_inference_steps = num_inference_steps
 
         ramp = np.linspace(0, 1, self.num_inference_steps)
-        sigmas = self._compute_sigmas(ramp)
+        if self.config.sigma_schedule == "karras":
+            sigmas = self._compute_karras_sigmas(ramp)
+        elif self.config.sigma_schedule == "exponential":
+            sigmas = self._compute_exponential_sigmas(ramp)
 
         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
         self.timesteps = self.precondition_noise(sigmas)
@@ -262,10 +275,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
         self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
-    # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
-    def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
+    # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
+    def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
-
         sigma_min = sigma_min or self.config.sigma_min
         sigma_max = sigma_max or self.config.sigma_max
 
@@ -273,6 +285,18 @@ def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTe
         min_inv_rho = sigma_min ** (1 / rho)
         max_inv_rho = sigma_max ** (1 / rho)
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
+    def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
+        """Implementation closely follows k-diffusion.
+
+        https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
+        """
+        sigma_min = sigma_min or self.config.sigma_min
+        sigma_max = sigma_max or self.config.sigma_max
+        sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
 
         return sigmas
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
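Both schedulers now dispatch between the Karras rho-schedule and a log-uniform ("exponential") schedule. The sketch below is not part of the diff; it reimplements the two formulas standalone (helper names are illustrative) with the config defaults `sigma_min=0.002`, `sigma_max=80.0`, `rho=7.0`, to show that the schedules share their endpoints but space the interior steps differently.

```python
import math

import torch


def karras_sigmas(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # Interpolate uniformly in sigma**(1/rho) space, then raise back to the
    # rho-th power (Karras et al. 2022, Eq. 5).
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho


def exponential_sigmas(n, sigma_min=0.002, sigma_max=80.0):
    # Uniform spacing in log(sigma), flipped so sigmas run from sigma_max down to sigma_min.
    return torch.linspace(math.log(sigma_min), math.log(sigma_max), n).exp().flip(0)


k, e = karras_sigmas(5), exponential_sigmas(5)
print(k)  # both start near 80.0 and end near 0.002...
print(e)  # ...but the interior spacing differs
```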
(2022).""" - sigma_min = sigma_min or self.config.sigma_min sigma_max = sigma_max or self.config.sigma_max @@ -273,6 +285,18 @@ def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTe min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + + return sigmas + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas + def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor: + """Implementation closely follows k-diffusion. + + https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26 + """ + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0) return sigmas # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample diff --git a/src/diffusers/schedulers/scheduling_edm_euler.py b/src/diffusers/schedulers/scheduling_edm_euler.py index f6a09ca1ee16..0ef9263c9e30 100644 --- a/src/diffusers/schedulers/scheduling_edm_euler.py +++ b/src/diffusers/schedulers/scheduling_edm_euler.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -65,6 +66,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin): range is [0.2, 80.0]. sigma_data (`float`, *optional*, defaults to 0.5): The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1]. + sigma_schedule (`str`, *optional*, defaults to `karras`): + Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper + (https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was + incorporated in this model: https://huggingface.co/stabilityai/cosxl. num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. 
         prediction_type (`str`, defaults to `epsilon`, *optional*):
@@ -84,15 +89,23 @@ def __init__(
         sigma_min: float = 0.002,
         sigma_max: float = 80.0,
         sigma_data: float = 0.5,
+        sigma_schedule: str = "karras",
         num_train_timesteps: int = 1000,
         prediction_type: str = "epsilon",
         rho: float = 7.0,
     ):
+        if sigma_schedule not in ["karras", "exponential"]:
+            raise ValueError(f"Wrong value provided for `{sigma_schedule=}`.")
+
         # setable values
         self.num_inference_steps = None
 
         ramp = torch.linspace(0, 1, num_train_timesteps)
-        sigmas = self._compute_sigmas(ramp)
+        if sigma_schedule == "karras":
+            sigmas = self._compute_karras_sigmas(ramp)
+        elif sigma_schedule == "exponential":
+            sigmas = self._compute_exponential_sigmas(ramp)
+
         self.timesteps = self.precondition_noise(sigmas)
 
         self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
@@ -200,7 +213,10 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         self.num_inference_steps = num_inference_steps
 
         ramp = np.linspace(0, 1, self.num_inference_steps)
-        sigmas = self._compute_sigmas(ramp)
+        if self.config.sigma_schedule == "karras":
+            sigmas = self._compute_karras_sigmas(ramp)
+        elif self.config.sigma_schedule == "exponential":
+            sigmas = self._compute_exponential_sigmas(ramp)
 
         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
         self.timesteps = self.precondition_noise(sigmas)
@@ -211,9 +227,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
-    def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
+    def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
-
         sigma_min = sigma_min or self.config.sigma_min
         sigma_max = sigma_max or self.config.sigma_max
 
@@ -221,6 +236,17 @@ def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTe
         min_inv_rho = sigma_min ** (1 / rho)
         max_inv_rho = sigma_max ** (1 / rho)
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+
+        return sigmas
+
+    def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
+        """Implementation closely follows k-diffusion.
+
+        https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
+        """
+        sigma_min = sigma_min or self.config.sigma_min
+        sigma_max = sigma_max or self.config.sigma_max
+        sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
 
         return sigmas
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
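Finally, a hedged usage sketch: assuming a diffusers build that includes this change, and that the option is threaded through `self.config.sigma_schedule` exactly as shown above, selecting the exponential schedule is a one-argument change, and invalid values now fail fast in `EDMEulerScheduler.__init__`.

```python
from diffusers import EDMEulerScheduler

# "karras" remains the default; opt in to the new schedule explicitly.
scheduler = EDMEulerScheduler(sigma_schedule="exponential")
scheduler.set_timesteps(num_inference_steps=25)

print(scheduler.sigmas[0], scheduler.sigmas[-1])  # starts near sigma_max (80.0), ends at the appended 0.0
print(scheduler.timesteps[:3])  # noise levels derived from the sigmas via precondition_noise

# Anything other than "karras"/"exponential" raises the new ValueError at construction time.
try:
    EDMEulerScheduler(sigma_schedule="linear")
except ValueError as err:
    print(err)
```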