[Scheduler] introduce sigma schedule. (#7649)
* introduce sigma schedule.

Co-authored-by: Suraj Patil <[email protected]>

* address yiyi

* update docstrings.

* implement the schedule for EDMDPMSolverMultistepScheduler

---------

Co-authored-by: Suraj Patil <[email protected]>
sayakpaul and patil-suraj authored Apr 27, 2024
1 parent 9d16daa commit 56bd7e6
Showing 2 changed files with 59 additions and 9 deletions.
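For context, a minimal usage sketch of the new option. The scheduler class and the `sigma_schedule` argument come straight from this diff; the top-level import path is the standard diffusers one, and the step count and printed slice are illustrative only:

from diffusers import EDMEulerScheduler

# Default behavior is unchanged: the Karras schedule from the EDM paper.
scheduler = EDMEulerScheduler()

# Opt in to the exponential schedule added by this commit.
scheduler = EDMEulerScheduler(sigma_schedule="exponential")
scheduler.set_timesteps(num_inference_steps=25)  # illustrative step count
print(scheduler.sigmas[:5])  # sigmas descend from sigma_max toward sigma_min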
34 changes: 29 additions & 5 deletions src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
@@ -14,6 +14,7 @@

# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm

+import math
from typing import List, Optional, Tuple, Union

import numpy as np
@@ -44,6 +45,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
            range is [0.2, 80.0].
        sigma_data (`float`, *optional*, defaults to 0.5):
            The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
+        sigma_schedule (`str`, *optional*, defaults to `karras`):
+            Sigma schedule to compute the `sigmas`. By default, we use the schedule introduced in the EDM paper
+            (https://arxiv.org/abs/2206.00364). The other accepted value is "exponential"; the exponential schedule
+            was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        solver_order (`int`, defaults to 2):
Expand Down Expand Up @@ -89,6 +94,7 @@ def __init__(
        sigma_min: float = 0.002,
        sigma_max: float = 80.0,
        sigma_data: float = 0.5,
+        sigma_schedule: str = "karras",
        num_train_timesteps: int = 1000,
        prediction_type: str = "epsilon",
        rho: float = 7.0,
@@ -121,7 +127,11 @@
        )

        ramp = torch.linspace(0, 1, num_train_timesteps)
-        sigmas = self._compute_sigmas(ramp)
+        if sigma_schedule == "karras":
+            sigmas = self._compute_karras_sigmas(ramp)
+        elif sigma_schedule == "exponential":
+            sigmas = self._compute_exponential_sigmas(ramp)
+
        self.timesteps = self.precondition_noise(sigmas)

        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
@@ -236,7 +246,10 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
        self.num_inference_steps = num_inference_steps

        ramp = np.linspace(0, 1, self.num_inference_steps)
-        sigmas = self._compute_sigmas(ramp)
+        if self.config.sigma_schedule == "karras":
+            sigmas = self._compute_karras_sigmas(ramp)
+        elif self.config.sigma_schedule == "exponential":
+            sigmas = self._compute_exponential_sigmas(ramp)

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        self.timesteps = self.precondition_noise(sigmas)
@@ -262,17 +275,28 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

-    # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
-    def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
+    # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
+    def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min = sigma_min or self.config.sigma_min
        sigma_max = sigma_max or self.config.sigma_max

        rho = self.config.rho
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
+    def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
+        """Implementation closely follows k-diffusion.
+        https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
+        """
+        sigma_min = sigma_min or self.config.sigma_min
+        sigma_max = sigma_max or self.config.sigma_max
+        sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
+        return sigmas

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
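Before moving to the second file, a standalone sketch of what the Karras schedule above computes. NumPy stands in for the scheduler state; the values mirror the config defaults from this file (sigma_min=0.002, sigma_max=80.0, rho=7.0), and the 10-step ramp is illustrative:

import numpy as np

# Scheduler defaults from the diff above.
sigma_min, sigma_max, rho = 0.002, 80.0, 7.0

ramp = np.linspace(0, 1, 10)  # illustrative step count
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
print(sigmas[0], sigmas[-1])  # ~80.0 down to ~0.002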
34 changes: 30 additions & 4 deletions src/diffusers/schedulers/scheduling_edm_euler.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

@@ -65,6 +66,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
            range is [0.2, 80.0].
        sigma_data (`float`, *optional*, defaults to 0.5):
            The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
+        sigma_schedule (`str`, *optional*, defaults to `karras`):
+            Sigma schedule to compute the `sigmas`. By default, we use the schedule introduced in the EDM paper
+            (https://arxiv.org/abs/2206.00364). The other accepted value is "exponential"; the exponential schedule
+            was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
@@ -84,15 +89,23 @@ def __init__(
        sigma_min: float = 0.002,
        sigma_max: float = 80.0,
        sigma_data: float = 0.5,
+        sigma_schedule: str = "karras",
        num_train_timesteps: int = 1000,
        prediction_type: str = "epsilon",
        rho: float = 7.0,
    ):
+        if sigma_schedule not in ["karras", "exponential"]:
+            raise ValueError(f"Wrong value provided for `{sigma_schedule=}`.")
+
        # setable values
        self.num_inference_steps = None

        ramp = torch.linspace(0, 1, num_train_timesteps)
-        sigmas = self._compute_sigmas(ramp)
+        if sigma_schedule == "karras":
+            sigmas = self._compute_karras_sigmas(ramp)
+        elif sigma_schedule == "exponential":
+            sigmas = self._compute_exponential_sigmas(ramp)
+
        self.timesteps = self.precondition_noise(sigmas)

        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
@@ -200,7 +213,10 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        self.num_inference_steps = num_inference_steps

        ramp = np.linspace(0, 1, self.num_inference_steps)
-        sigmas = self._compute_sigmas(ramp)
+        if self.config.sigma_schedule == "karras":
+            sigmas = self._compute_karras_sigmas(ramp)
+        elif self.config.sigma_schedule == "exponential":
+            sigmas = self._compute_exponential_sigmas(ramp)

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        self.timesteps = self.precondition_noise(sigmas)
@@ -211,16 +227,26 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
-    def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
+    def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min = sigma_min or self.config.sigma_min
        sigma_max = sigma_max or self.config.sigma_max

        rho = self.config.rho
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

        return sigmas
+
+    def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
+        """Implementation closely follows k-diffusion.
+        https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
+        """
+        sigma_min = sigma_min or self.config.sigma_min
+        sigma_max = sigma_max or self.config.sigma_max
+        sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
+        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
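Finally, a standalone sketch of the new exponential schedule, assuming the same defaults (sigma_min=0.002, sigma_max=80.0) and an illustrative 10-step count: sigmas are spaced uniformly in log-space, then flipped so the sequence descends from sigma_max to sigma_min, matching _compute_exponential_sigmas above:

import math

import torch

sigma_min, sigma_max, steps = 0.002, 80.0, 10  # illustrative step count
# Uniform in log-space, then flipped to run from high noise to low noise.
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), steps).exp().flip(0)
print(sigmas[0].item(), sigmas[-1].item())  # ~80.0 ... ~0.002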
