
Enforce the zero terminal SNR to schedulers #397

Closed
Warvito opened this issue May 25, 2023 · 13 comments · May be fixed by #404

Comments

@Warvito
Collaborator

Warvito commented May 25, 2023

According to "Common Diffusion Noise Schedules and Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf), the implementations of most schedulers do not use t = T in the sampling process. We should include the corrections to enforce a zero terminal SNR.
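
For reference, the correction from the paper (its Algorithm 1) rescales the betas so that the cumulative alpha product reaches exactly zero at the final timestep. A rough transcription, assuming betas is a 1-D torch tensor, could look like:

import torch

def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    # sqrt of the cumulative alpha product
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # shift so the last value becomes zero, then rescale so the first value is unchanged
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
    alphas_bar_sqrt -= alphas_bar_sqrt_T
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # convert the rescaled sqrt(alpha_bar) back to betas
    alphas_bar = alphas_bar_sqrt**2
    alphas = alphas_bar[1:] / alphas_bar[:-1]
    alphas = torch.cat([alphas_bar[0:1], alphas])
    return 1.0 - alphas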

@ericspod
Member

This could be good to do, but it should be motivated by our own experiments to show there's a meaningful difference.

@marksgraham
Collaborator

Starting to play with this, will let you know what I find out

@marksgraham
Collaborator

So I've found that starting from t=T helps for my application, but the zero terminal SNR hinders performance. However, my use case is a bit niche (segmentation). I'm planning to re-run the FID tutorial with the changes to see if there is any difference.

@Warvito
Collaborator Author

Warvito commented Jun 6, 2023

Hi @marksgraham, this implementation might help with our code too (huggingface/diffusers#3664)

@marksgraham
Collaborator

Thanks Walter. Interesting to see they aren't totally convinced by the updates.

I like the idea of allowing the user to select the method for timestep discretisation (e.g. trailing, leading) as an argument to the scheduler.
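
To make the difference concrete, here is a rough sketch (not code from this repo) of what "leading" versus "trailing" spacing gives for 1000 training steps and 50 inference steps:

import numpy as np

num_train_timesteps, num_inference_steps = 1000, 50
ratio = num_train_timesteps // num_inference_steps

# "leading" (current behaviour): starts sampling at t=980 and never uses the final timestep t=999
leading = (np.arange(num_inference_steps) * ratio)[::-1].astype(np.int64)

# "trailing": counts down from the end, so sampling starts at t=999
trailing = (np.arange(num_train_timesteps, 0, -ratio) - 1).astype(np.int64)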

I think we could just make the noise schedules with snr=0 at t=T available as new options and keep the current ones as default.

@ericspod
Member

ericspod commented Jun 7, 2023

How would we want to implement this feature? In the Scheduler constructor we could add an argument for a rescale function which we'd apply to self.betas, self.alphas, self.alphas_cumprod:

def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear_beta", rescale_func: Callable | None = None, **schedule_args) -> None:
    super().__init__()
    schedule_args["num_train_timesteps"] = num_train_timesteps
    noise_sched = NoiseSchedules[schedule](**schedule_args)

    # set betas, alphas, alphas_cumprod based off return value from noise function
    if isinstance(noise_sched, tuple):
        self.betas, self.alphas, self.alphas_cumprod = noise_sched
    else:
        self.betas = noise_sched
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    if rescale_func is not None:
        self.betas, self.alphas, self.alphas_cumprod = rescale_func(self.betas, self.alphas, self.alphas_cumprod)
    ...
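
Usage could then look something like this (illustrative only; rescale_func is the proposed argument, and rescale_zero_terminal_snr stands for whatever rescaling the user picks, e.g. Algorithm 1 from the paper):

def rescale_to_zero_terminal_snr(betas, alphas, alphas_cumprod):
    new_betas = rescale_zero_terminal_snr(betas)
    new_alphas = 1.0 - new_betas
    return new_betas, new_alphas, torch.cumprod(new_alphas, dim=0)

scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    schedule="linear_beta",
    rescale_func=rescale_to_zero_terminal_snr,
)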

@marksgraham
Collaborator

Allowing a generic function to be supplied strikes me as a bit too general, given a user can now easily specify custom beta schedules. I think we could provide a rescale_to_snr0 flag, similar to what you've proposed, or just add a new set of beta schedule options which enforce snr=0 at t=T to complement the existing ones (e.g. linear_beta_snr0, scaled_linear_beta_snr0).
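
For the second option, a rough sketch of one such entry, reusing the NoiseSchedules registry from your snippet (the beta defaults and the rescale helper are placeholders, not existing code):

@NoiseSchedules.add_def("linear_beta_snr0", "Linear beta schedule rescaled to zero terminal SNR")
def _linear_beta_snr0(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2):
    betas = torch.linspace(beta_start, beta_end, num_train_timesteps)
    return rescale_zero_terminal_snr(betas)  # e.g. Algorithm 1 from the paper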

@marksgraham
Collaborator

As it stands, I think it makes sense to allow users to specify how they want to do the timestep spacing. I've found it helps in some of my applications. I tried to get some results for FID scores on MNIST, but the results are unclear. I propose we give users the option and let them decide what is right for their application. There is a suggestion of how to do it in the linked PR; what do you think? If it looks OK, I'll implement it in the other schedulers.

I haven't managed to get models with SNR=0 at t=T to train well at all, so I'm reluctant to implement it. I also note they weren't fully convinced of its utility in the huggingface discussion.

@sRassmann

Hi, I found a related issue with the DDIM scheduler. When using a scheduler with a steep terminal SNR decay, like a cosine scheduler*, my results are super bad:

[image: generated samples using the default (leading) timestep spacing, showing very poor quality]

However, if I change from the Leading to the Linspace method for timestep spacing (notation from Table 2 in the "Common Diffusion Noise Schedules and Sample Steps are Flawed" paper), my results are a lot better:

    def set_timesteps(
        self, num_inference_steps: int, device: str | torch.device | None = None
    ) -> None:
        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        # timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)  # Leading 
        timesteps = (
            np.linspace(0, self.num_train_timesteps - 1, num_inference_steps)
            .round()[::-1]
            .copy()
            .astype(np.int64)
        )  # Linspace
        self.timesteps = torch.from_numpy(timesteps).to(device)
        self.timesteps += self.steps_offset

[image: generated samples using Linspace timestep spacing, showing clearly better quality]

Intuitively, I guess it makes sense that skipping the early, steep part of the SNR curve is an issue for a cosine schedule, while the first part likely does not matter so much for something like a linear schedule.

Anyway, how about letting the user decide which scheme to use, similar to how diffusers handles it?
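
For instance, a sketch of what a diffusers-style option could look like here (the timestep_spacing attribute and the trailing branch are assumptions, not existing code in this repo):

def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
    self.num_inference_steps = num_inference_steps
    if self.timestep_spacing == "linspace":
        timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps).round()[::-1].copy().astype(np.int64)
    elif self.timestep_spacing == "leading":
        step_ratio = self.num_train_timesteps // num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        timesteps += self.steps_offset
    else:  # "trailing"
        step_ratio = self.num_train_timesteps / num_inference_steps
        timesteps = np.round(np.arange(self.num_train_timesteps, 0, -step_ratio)).astype(np.int64) - 1
    self.timesteps = torch.from_numpy(timesteps).to(device)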

*Note: I am using a custom implementation for the cosine schedule function

import numpy as np
import torch


# adapted from https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L45C1-L62C27
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas)


@NoiseSchedules.add_def("cosine_poly", "Cosine schedule")
def _cosine_beta(num_train_timesteps: int, s: float = 8e-3, order: float = 2, *args):
    return betas_for_alpha_bar(
        num_train_timesteps,
        lambda t: np.cos((t + s) / (1 + s) * np.pi / 2) ** order,
    )

@marksgraham
Collaborator

Hi, sorry for the slow response.

We've had another report of bad results with the cosine scheduler here too.

I actually started implementing this but closed the PR because I couldn't find any benefit; given these reports, though, it does seem worth implementing. My closed PR is here.

Do you have any interest in resurrecting it and opening a new PR? If not, I will do it, but it won't be a priority for a while as we will be focusing on integration with MONAI core for the moment.

@sRassmann

No worries, and thanks for your comment!

I want to look into it, though I have some deadlines coming up, so I fear I won't have time in the next couple of weeks.

@virginiafdez
Contributor

Can this issue be closed, since there is a new issue #489 open?

@marksgraham
Collaborator

I'd say so :)
