Enforce zero terminal SNR in schedulers #397
Comments
This could be worth doing, but it should be motivated by our own experiment showing that there's a meaningful difference.
Starting to play with this; I will let you know what I find.
So I've found that starting from t=T helps for my application, but the zero terminal SNR hinders performance. However, my use case is a bit niche (segmentation). I'm planning to re-run the FID tutorial with the changes to see if there is any difference.
Hi @marksgraham, this implementation might help with our code too (huggingface/diffusers#3664).
Thanks, Walter. Interesting to see they aren't totally convinced by the updates. I like the idea of letting the user select the timestep discretisation method (e.g. trailing, leading) as an argument to the scheduler. I think we could just make the noise schedules with SNR = 0 at t = T available as new options and keep the current ones as the default.
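The leading/trailing distinction can be sketched in a few lines of NumPy. This is a minimal sketch, assuming the three spacing names used by the `timestep_spacing` option in diffusers ("leading", "trailing", "linspace"); the function `make_timesteps` is hypothetical and not an existing MONAI API.

```python
import numpy as np


def make_timesteps(num_train_timesteps: int, num_inference_steps: int, spacing: str = "leading") -> np.ndarray:
    """Return descending integer sampling timesteps for a given spacing scheme (hypothetical helper)."""
    if spacing == "leading":
        # Multiples of the step ratio counted up from 0: the largest timestep
        # falls short of num_train_timesteps - 1 unless the ratio is 1.
        step_ratio = num_train_timesteps // num_inference_steps
        timesteps = np.arange(0, num_inference_steps) * step_ratio
    elif spacing == "trailing":
        # Multiples counted down from num_train_timesteps, shifted by -1,
        # so sampling always starts at the last training timestep T - 1.
        step_ratio = num_train_timesteps / num_inference_steps
        timesteps = np.round(np.arange(num_train_timesteps, 0, -step_ratio)) - 1
    elif spacing == "linspace":
        # Evenly spaced between 0 and T - 1, inclusive at both ends.
        timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps).round()
    else:
        raise ValueError(f"unknown spacing: {spacing!r}")
    # Sampling runs from high noise to low noise, so return in descending order.
    return np.sort(timesteps)[::-1].astype(np.int64)


for name in ("leading", "trailing", "linspace"):
    ts = make_timesteps(1000, 50, name)
    print(name, ts[:3], "...", ts[-1])
```

With 1000 training steps and 50 inference steps, "leading" starts sampling at t = 980, whereas "trailing" and "linspace" both start at t = 999, the last training timestep, which is the behaviour the paper argues for.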
How would we want to implement this feature? In the `Scheduler` constructor we could add an argument for a rescale function, which we'd apply to the computed `betas`, `alphas` and `alphas_cumprod`:

```python
from __future__ import annotations

from collections.abc import Callable

import torch
import torch.nn as nn

# NoiseSchedules: the existing noise-schedule registry in the scheduler module


class Scheduler(nn.Module):
    def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear_beta",
                 rescale_func: Callable | None = None, **schedule_args) -> None:
        super().__init__()
        schedule_args["num_train_timesteps"] = num_train_timesteps
        noise_sched = NoiseSchedules[schedule](**schedule_args)
        # set betas, alphas, alphas_cumprod based off return value from noise function
        if isinstance(noise_sched, tuple):
            self.betas, self.alphas, self.alphas_cumprod = noise_sched
        else:
            self.betas = noise_sched
            self.alphas = 1.0 - self.betas
            self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # optionally post-process the schedule, e.g. to enforce zero terminal SNR
        if rescale_func is not None:
            self.betas, self.alphas, self.alphas_cumprod = rescale_func(self.betas, self.alphas, self.alphas_cumprod)
        ...
```
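For reference, a `rescale_func` matching the signature proposed above could implement the zero-terminal-SNR rescaling from the paper (Algorithm 1 in arXiv:2305.08891): shift sqrt(alpha_bar) so that its last value is 0, scale it so that its first value is unchanged, then recover alphas and betas. This is a minimal PyTorch sketch; the name `rescale_zero_terminal_snr` is hypothetical, and only `alphas_cumprod` is actually used.

```python
from __future__ import annotations

import torch


def rescale_zero_terminal_snr(
    betas: torch.Tensor, alphas: torch.Tensor, alphas_cumprod: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical rescale_func: rescale a schedule so that alphas_cumprod[-1] == 0.

    betas and alphas are accepted only to match the proposed signature;
    both are recomputed from alphas_cumprod.
    """
    sqrt_ab = alphas_cumprod.sqrt()
    sqrt_ab_0 = sqrt_ab[0].clone()
    sqrt_ab_T = sqrt_ab[-1].clone()
    # Shift so that the last timestep has sqrt(alpha_bar) == 0 ...
    sqrt_ab = sqrt_ab - sqrt_ab_T
    # ... then scale so that the first timestep keeps its original value.
    sqrt_ab = sqrt_ab * sqrt_ab_0 / (sqrt_ab_0 - sqrt_ab_T)
    new_alphas_cumprod = sqrt_ab**2
    # Per-step alphas are ratios of consecutive cumulative products.
    new_alphas = torch.cat([new_alphas_cumprod[:1], new_alphas_cumprod[1:] / new_alphas_cumprod[:-1]])
    new_betas = 1.0 - new_alphas
    return new_betas, new_alphas, new_alphas_cumprod
```

The final beta becomes exactly 1 (alpha at t = T - 1 is 0), so any step that divides by `alphas[t]` needs care at the last timestep. The paper pairs this rescaling with v-prediction, because at zero SNR an epsilon-prediction target no longer teaches the model anything about the data.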
Allowing a generic function to be supplied strikes me as too general, given that a user can now easily specify custom beta schedules. I think we could provide a …
As it stands, I think it makes sense to allow users to specify how they want to do the timestep spacing; I've found it helps in some of my applications. I tried to get FID scores on MNIST, but the results are unclear, so I propose we give users the option and let them decide what is right for their application. There is a suggestion of how to do it in the linked PR; what do you think? If it looks OK, I'll implement it in the other schedulers. I haven't managed to get models with SNR = 0 at t = T to train well at all, so I'm reluctant to implement that part. I also note they weren't fully convinced of its utility in the Hugging Face discussion.
Hi, I found a related issue with the DDIM scheduler. When using a scheduler with a steep terminal SNR decay, like a cosine scheduler*, my results are very poor. However, if I change from the leading to the linspace timestep spacing in `set_timesteps`:

```python
from __future__ import annotations

import numpy as np
import torch


def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
    self.num_inference_steps = num_inference_steps
    step_ratio = self.num_train_timesteps // self.num_inference_steps
    # creates integer timesteps by multiplying by ratio
    # casting to int to avoid issues when num_inference_step is power of 3
    # timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)  # Leading
    timesteps = (
        np.linspace(0, self.num_train_timesteps - 1, num_inference_steps).round()[::-1].copy().astype(np.int64)
    )  # Linspace
    self.timesteps = torch.from_numpy(timesteps).to(device)
    # linspace already reaches num_train_timesteps - 1, so steps_offset must be 0 here
    # to avoid indexing past the end of the noise schedule
    self.timesteps += self.steps_offset
```

Intuitively, I guess it makes sense that skipping the early, steep steps in the SNR is a problem for a cosine schedule, while the first part likely does not matter so much for something like a linear schedule. Anyway, how about letting the user decide which scheme to use, similar to …?

*Note: I am using a custom implementation for the cosine schedule function:

```python
import numpy as np
import torch

# NoiseSchedules: the existing noise-schedule registry in the scheduler module


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    # https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L45C1-L62C27
    # discretises a continuous alpha_bar(t), t in [0, 1], into per-step betas clipped at max_beta
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas)


@NoiseSchedules.add_def("cosine_poly", "Cosine schedule")
def _cosine_beta(num_train_timesteps: int, s: float = 8e-3, order: float = 2, *args):
    return betas_for_alpha_bar(
        num_train_timesteps,
        lambda t: np.cos((t + s) / (1 + s) * np.pi / 2) ** order,
    )
```
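To make the intuition above concrete, here is a NumPy sketch (not part of the original comment) that rebuilds the cosine betas the same way as `betas_for_alpha_bar`/`_cosine_beta`, then compares the SNR, alpha_bar / (1 - alpha_bar), at the first timestep each spacing samples. The helper `cosine_betas` and the choice of 1000 training and 50 inference steps are assumptions for illustration.

```python
import numpy as np


def cosine_betas(num_train_timesteps: int, s: float = 8e-3, order: float = 2, max_beta: float = 0.999) -> np.ndarray:
    """Vectorised NumPy version of betas_for_alpha_bar + _cosine_beta (hypothetical helper)."""

    def alpha_bar(t: np.ndarray) -> np.ndarray:
        return np.cos((t + s) / (1 + s) * np.pi / 2) ** order

    t = np.arange(num_train_timesteps)
    betas = 1 - alpha_bar((t + 1) / num_train_timesteps) / alpha_bar(t / num_train_timesteps)
    return np.minimum(betas, max_beta)  # clip as in improved-diffusion


T, n = 1000, 50  # training steps, inference steps (illustrative values)
alphas_cumprod = np.cumprod(1.0 - cosine_betas(T))
snr = alphas_cumprod / (1.0 - alphas_cumprod)

leading_start = int((np.arange(n) * (T // n)).max())  # first (largest) timestep, leading spacing
linspace_start = int(np.linspace(0, T - 1, n).round().max())  # first timestep, linspace spacing
print(f"leading:  t={leading_start}, SNR={snr[leading_start]:.2e}")
print(f"linspace: t={linspace_start}, SNR={snr[linspace_start]:.2e}")
```

Leading spacing starts at t = 980, where the SNR of this cosine schedule is orders of magnitude higher than at t = 999. The sampler therefore hands the model pure noise while telling it that t = 980, a step at which training inputs still carried noticeably more signal.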
Hi, sorry for the slow response. We've had another report of bad results with the cosine scheduler here too. I actually started implementing this but closed it because I couldn't find any benefit; it does seem worth implementing, though. My closed PR is here. Do you have any interest in resurrecting it and opening a new PR? If not, I will do it, but it won't be a priority for a while as we will be focusing on integration with MONAI core for the moment.
No worries, and thanks for your comment! I want to look into it, though I have some deadlines coming up, so I fear I won't have time in the next couple of weeks.
Can this issue be closed, since there is a new issue #489 open?
I'd say so :)
According to "Common Diffusion Noise Schedules and Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf), most scheduler implementations do not use t = T in the sampling process. We should include the corrections to enforce zero terminal SNR.
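To see the problem numerically, the sketch below computes SNR(t) = alpha_bar_t / (1 - alpha_bar_t) for a linear schedule with beta running from 1e-4 to 2e-2 over 1000 steps (the DDPM values, assumed here as a representative default). The terminal SNR comes out small but strictly positive, so the last training step never reaches pure noise; this is the gap the paper's rescaling closes.

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 2e-2, T, dtype=torch.float64)  # DDPM-style linear schedule (assumed)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
snr = alphas_cumprod / (1.0 - alphas_cumprod)  # signal-to-noise ratio per timestep
print(f"SNR at the last timestep: {snr[-1].item():.2e}")
```

For these values the terminal SNR is on the order of 1e-5: close to zero, but not zero.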