SDXL-Lightning inference steps ignored #107

Open
stronk-dev opened this issue Jun 15, 2024 · 4 comments

@stronk-dev
Contributor

Premise: this is a model which does not accept the guidance_scale param and loads a specific set of model weights according to the num_inference_steps you want to run (1, 2, 4, or 8 steps).
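For context, the upstream checkpoints bake the step count into the weights at load time. A minimal sketch of how one of them is loaded, based on the ByteDance/SDXL-Lightning model card (filenames and scheduler settings per that repo; this is an illustration, not the worker's code):

```python
import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_4step_unet.safetensors"  # weights are distilled for exactly 4 steps

# Swap the base unet for the 4-step distilled unet.
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
pipe = StableDiffusionXLPipeline.from_pretrained(
    base, unet=unet, torch_dtype=torch.float16, variant="fp16"
).to("cuda")

# The model card calls for "trailing" timestep spacing.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing"
)

# num_inference_steps must match the loaded checkpoint, and guidance is disabled.
image = pipe("A girl smiling", num_inference_steps=4, guidance_scale=0).images[0]
```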

When apps request the ByteDance/SDXL-Lightning model, the following code makes it default to the 2-step unet:

```python
if SDXL_LIGHTNING_MODEL_ID in model_id:
    base = "stabilityai/stable-diffusion-xl-base-1.0"
    # ByteDance/SDXL-Lightning-2step
    if "2step" in model_id:
        unet_id = "sdxl_lightning_2step_unet"
    # ByteDance/SDXL-Lightning-4step
    elif "4step" in model_id:
        unet_id = "sdxl_lightning_4step_unet"
    # ByteDance/SDXL-Lightning-8step
    elif "8step" in model_id:
        unet_id = "sdxl_lightning_8step_unet"
    else:
        # Default to 2step
        unet_id = "sdxl_lightning_2step_unet"
```

And then, when running inference, it overrides num_inference_steps based on the model ID suffix, defaulting to 2:

```python
elif SDXL_LIGHTNING_MODEL_ID in self.model_id:
    # SDXL-Lightning models should have guidance_scale = 0 and use
    # the correct number of inference steps for the unet checkpoint loaded
    kwargs["guidance_scale"] = 0.0
    if "2step" in self.model_id:
        kwargs["num_inference_steps"] = 2
    elif "4step" in self.model_id:
        kwargs["num_inference_steps"] = 4
    elif "8step" in self.model_id:
        kwargs["num_inference_steps"] = 8
    else:
        # Default to 2step
        kwargs["num_inference_steps"] = 2
```

Apparently apps need to append 4step or 8step to the model ID if they want a different number of num_inference_steps. This can be very confusing to app developers, who likely just request ByteDance/SDXL-Lightning with a specific num_inference_steps, which then quietly gets overwritten during inference.

This would also explain why people have reported bad output from this model, as running it at 8 steps produces vastly different output than at 2 steps.

Proposed solutions could be to switch unets/LoRAs during inference or to make the documentation very clear about how this specific model behaves. Luckily, with models like RealVisXL_V4.0_Lightning you're not tied to a specific number of inference steps.

@yondonfu
Member

Agreed that this behavior is confusing. FWIW I originally implemented this as a quick hack to support loading a specific N-step checkpoint for SDXL-Lightning on pipeline initialization (since all the SDXL-Lightning checkpoints are tied to a specific # of inference steps).

LoRA switching at inference time could work (I'm not sure that unet switching at inference time would be a good idea, as that would probably incur a lot more overhead), but since the general LoRA switching logic is not implemented yet, IMO the low-hanging fruit of establishing clearer docs would be a better place to start.
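For reference, a minimal sketch of what per-request LoRA switching could look like, assuming the LoRA checkpoints shipped in the ByteDance/SDXL-Lightning repo and the standard diffusers load_lora_weights/unload_lora_weights APIs; switch_lightning_lora is a hypothetical helper, not existing worker code:

```python
from huggingface_hub import hf_hub_download

def switch_lightning_lora(pipe, steps: int) -> None:
    # The repo only ships LoRAs for 2, 4, and 8 steps (the 1-step variant is unet-only).
    assert steps in (2, 4, 8)
    pipe.unload_lora_weights()  # drop whatever LoRA a previous request loaded
    pipe.load_lora_weights(
        hf_hub_download("ByteDance/SDXL-Lightning", f"sdxl_lightning_{steps}step_lora.safetensors")
    )
```

Whether this beats unet swaps in practice would come down to the LoRA load/unload overhead per request.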

@rickstaa
Member

@stronk-dev, @yondonfu, what are your thoughts on removing the 2/4-step models and exclusively serving the 8-step model, while documenting the behavior of the unused parameters? The 2-step model is only 1 second faster, and the difference between the 4-step and 8-step models is minimal. I think this would reduce the confusion.

@stronk-dev
Contributor Author

I haven't done testing with the 4-step model, so I can't speak to its inference speed and quality difference. I'd expect there to be a bigger difference in inference time (I think the worker prints it/s? You could calculate the extra time required from that).

I'd certainly prefer the simplicity of advertising just one model, but I would be curious to see the quality difference between 4-step and 8-step first.

@yondonfu
Member

> what are your thoughts on removing the 2/4-step models and exclusively serving the 8-step model

IMO the 2/4-step models should continue to be supported, and the 8-step model can just be used as the default when the model ID is set to ByteDance/SDXL-Lightning. This change, in tandem with a couple of sentences in the relevant docs noting that ByteDance/SDXL-Lightning is a special case (in most other cases each HF repo has just one model, but ByteDance structured their repo differently) where the user can also specify the -Nstep suffix in the model ID to use the corresponding N-step checkpoint, should address the OP.

My reasoning here is that each N-step model is actually a distinct checkpoint: devs should be able to request specific checkpoints if they want to, and ByteDance/SDXL-Lightning will just be treated as an alias for ByteDance/SDXL-Lightning-8step.
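A minimal sketch of that aliasing (resolve_model_id is a hypothetical helper):

```python
def resolve_model_id(model_id: str) -> str:
    # The bare repo ID acts as an alias for the 8-step checkpoint;
    # explicit -Nstep suffixes pass through unchanged.
    if model_id == "ByteDance/SDXL-Lightning":
        return "ByteDance/SDXL-Lightning-8step"
    return model_id
```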
