Official callbacks #7761
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
cc @a-r-r-o-w here too
I like this - very nice and simple. For the next steps, let's see a proposal for this? After that, we can play around with it.
actually would be nice to support a list of callbacks, since now we provide official ones that users can mix and match
Yeah, I think this is the right way to do it. In fact I would say to not even use "callbacks" but rather just a pure function for doing each sampling step. Basically we have the sampling loop (SD pipeline as an example):

```python
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        if self.interrupt:
            continue

        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        noise_pred = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            timestep_cond=timestep_cond,
            cross_attention_kwargs=self.cross_attention_kwargs,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]

        # perform guidance
        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

        if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

        if callback_on_step_end is not None:
            callback_kwargs = {}
            for k in callback_on_step_end_tensor_inputs:
                callback_kwargs[k] = locals()[k]
            callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

            latents = callback_outputs.pop("latents", latents)
            prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
            negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

        # call the callback, if provided
        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
            progress_bar.update()
            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)
```

In my opinion, we should replace everything underneath the `for` loop header with a pure sampling function, and add an argument for passing such functions in. In this way we can finally get complete control over the sampling loop and chain multiple functions together to process the output of the sampling loop in the order of the sampling functions.
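The `callback_on_step_end` protocol in the snippet above can be sketched in isolation. The following is a toy simulation (no diffusers, plain floats standing in for tensors; `denoise_loop` and `zero_last_step` are illustrative names, not real API) of how the pipeline harvests locals into `callback_kwargs` and applies the callback's outputs:

```python
# Toy simulation of the callback_on_step_end protocol: the loop collects the
# requested local variables into a dict, hands them to the callback, and
# takes back whatever the callback returns.

def denoise_loop(timesteps, latents, callback_on_step_end=None,
                 callback_on_step_end_tensor_inputs=("latents",)):
    for i, t in enumerate(timesteps):
        latents = latents * 0.5  # stand-in for a scheduler step
        if callback_on_step_end is not None:
            callback_kwargs = {}
            for k in callback_on_step_end_tensor_inputs:
                # harvest the requested locals, like the pipeline does
                callback_kwargs[k] = locals()[k]
            callback_outputs = callback_on_step_end(None, i, t, callback_kwargs)
            latents = callback_outputs.pop("latents", latents)
    return latents

def zero_last_step(pipe, i, t, callback_kwargs):
    # overwrite latents on the final timestep (t == 0 in this toy loop)
    if t == 0:
        callback_kwargs["latents"] = 0.0
    return callback_kwargs

no_callback = denoise_loop([2, 1, 0], 8.0)                   # -> 1.0
with_callback = denoise_loop([2, 1, 0], 8.0, zero_last_step)  # -> 0.0
```

Note the explicit `for k in ...: callback_kwargs[k] = locals()[k]` loop: a dict comprehension would not work here, because `locals()` inside a comprehension sees only the comprehension's own scope.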
the specifics of that i'll have to wrap my head around, but the initial idea of decoupling the logic inside the sampling loop sounds good
there's the concept library on the hf hub from back in the day. for the uninitiated, it is/was a collection of dreambooths others had done, to make it easier to find e.g. a backpack checkpoint or some other oddly specific item you reliably needed to work. i know it's a security nightmare, but the idea of hub-supported callbacks "calls to me" as something worth bringing up. on the other hand, having community callbacks in this repo is time-consuming, but that allows thorough review of any callbacks that are included. unlike dreambooths, callbacks seem like they'd be rarely created, whereas there are a billion potential concepts for a dreambooth. listing the available callbacks is quite trivial in either case.
It's probably easier if I write it out in some pseudocode. Writing it down, I think something like this works:

```python
class SamplingInput:
    def __init__(self, img, text_embedding, unet, timestep=None, **kwargs):
        self.img = img
        self.text_embedding = text_embedding
        self.unet = unet
        self.timestep = timestep

# ... lots of other code ...

inp = SamplingInput(img, text_embedding, unet)
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        inp.timestep = t
        for sampling_function in self.sampling_functions:
            inp = sampling_function(inp)

output_img = inp.img
```

Then we can do something like that.
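The chaining idea in the pseudocode above can be made concrete with a runnable toy version. Plain numbers stand in for tensors and the unet, and `denoise_step`/`clamp_step` are illustrative sampling functions, not anything from diffusers:

```python
# Runnable toy version of the SamplingInput chaining idea: each sampling
# function takes the shared input object, transforms it, and returns it.

class SamplingInput:
    def __init__(self, img, text_embedding, unet, timestep=None, **kwargs):
        self.img = img
        self.text_embedding = text_embedding
        self.unet = unet
        self.timestep = timestep

def denoise_step(inp):
    # stand-in for the usual predict-noise-and-step logic
    inp.img = inp.img - 1
    return inp

def clamp_step(inp):
    # a second chained function, e.g. post-processing the partial result
    inp.img = max(inp.img, 0)
    return inp

sampling_functions = [denoise_step, clamp_step]
inp = SamplingInput(img=3, text_embedding=None, unet=None)
for t in [0.9, 0.5, 0.1, 0.0]:
    inp.timestep = t
    for sampling_function in sampling_functions:
        inp = sampling_function(inp)
output_img = inp.img  # 3 decremented over 4 steps, clamped at 0 -> 0
```

The pipeline would only own the outer loop; everything that happens per step is decided by the ordered list of functions.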
I really like the array of function pointers. It makes composition easy and clearly signals that the methods are designed to be changeable.
ohh thanks!
a bit late to the party here, but adding one use-case: modifying or skipping steps. in

```python
for i, t in enumerate(timesteps):
```

a big use case is for the callback to actually modify timesteps in some sense - perhaps we want to skip a step? perhaps force an early end because the callback function determined it got what it needed and there is no point in running all the remaining steps to completion?
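One way to approximate the "force an early end" part with the existing machinery: the sampling loop shown earlier already checks `self.interrupt` at the top of each step, so a callback that flips that flag stops further work. The sketch below is a toy simulation; `FakePipe` and `_interrupt` are stand-ins for the pipeline internals, and "converged" is a made-up condition:

```python
# Toy sketch: a callback sets the pipeline's interrupt flag once it has
# what it needs, and the loop skips the remaining steps.

class FakePipe:
    _interrupt = False

    @property
    def interrupt(self):
        return self._interrupt

def stop_when_converged(pipe, i, t, callback_kwargs):
    # pretend "converged" means the latents dropped below a threshold
    if abs(callback_kwargs["latents"]) < 0.2:
        pipe._interrupt = True
    return callback_kwargs

pipe = FakePipe()
latents, steps_run = 1.0, 0
for i, t in enumerate(range(10, 0, -1)):
    if pipe.interrupt:
        continue  # the real loop skips remaining work the same way
    latents *= 0.5
    steps_run += 1
    stop_when_converged(pipe, i, t, {"latents": latents})
# latents: 1.0 -> 0.5 -> 0.25 -> 0.125, which triggers the interrupt
```

Skipping or reordering individual timesteps, as opposed to ending early, still needs scheduler-side support.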
The scheduler determines which timesteps are in the timestep list, so the decision of which timesteps to run lives there. You can very simply just write your own scheduler to exclude some timesteps.
or i guess a scheduler wrapper that takes in its own callbacks, in the case of SD.Next
It's a continuation of #7736, but engineering a proper solution rather than a half-baked one will save time in the long run. For example, right now for determining timesteps we have schedulers -- the scheduler is effectively a function you can pass into the pipeline that is relatively pure and, for the most part, just computes which timesteps you are supposed to perform. Ideally we extend such functional designs to the sampling loop as well, and in this case, extend the ability to run multiple sampling functions in sequence. I believe this solves every current and previous deficit that hacks like

```python
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
```

have failed to fully address. In my opinion this is poorly engineered, and now that it exists in the codebase it will need to be supported with backwards compatibility for the rest of time, whereas I believe my proposed solution is (1) clean, (2) consistent with the engineering of schedulers, and (3) will not result in technical debt, but will incur a large one-time cost to support by various pipelines. For now, just a few of the most used pipelines could be done and the rest stubbed.
@AmericanPresidentJimmyCarter i get that, but i don't want to monkey-patch all schedulers existing in diffusers. example use case - there are some experimental sd15 models popping up that are only finetuned on high-noise or low-noise - with the idea behind them very similar to sdxl-refiner, but stabilityai never did a refiner for sd15 and there is no pipeline for it.
@bghira i might as well need to do that, i thought since we're talking about callbacks design here this would be a place to address future needs.
that's a good point vlad. i was just thinking a preliminary attempt at a scheduler wrapper might result in some lessons being discovered that might help make a better upstream (diffusers) design. but maybe you already have a concrete idea? :P also #4355 for your SD 1.x refiner needs.
Yeah, this would require even more re-engineering. You would need sampling functions to be a part of the scheduler, and all of them would need to be passed to the scheduler instead of the pipeline. The net effect is more or less the same. So for every pipeline, we would have a default sampling function which we pass to the default scheduler, and we could also pass multiple of these as I proposed. Then the only difference is in the sampling loop.
or a very simple hack using the existing callback concept:

```python
for i, t in enumerate(timesteps):
    if t <= 0:
        continue
```
Why wouldn't you just subclass the scheduler and then overwrite the get timestep method? That seems trivial? |
@vladmandic this should solve your problem, no?
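The "subclass the scheduler" suggestion can be sketched with toy classes. `ToyScheduler` and `EveryOtherScheduler` are illustrative, not diffusers classes; a real subclass would override `set_timesteps` on e.g. `DDIMScheduler` and filter `self.timesteps` the same way:

```python
# Toy sketch: exclude timesteps by overriding set_timesteps in a subclass,
# instead of patching the pipeline's sampling loop.

class ToyScheduler:
    def set_timesteps(self, num_inference_steps):
        # evenly spaced timesteps from 1000 down to 0
        step = 1000 // num_inference_steps
        self.timesteps = list(range(1000, -1, -step))

class EveryOtherScheduler(ToyScheduler):
    def set_timesteps(self, num_inference_steps):
        super().set_timesteps(num_inference_steps)
        # drop every other timestep after the parent computes the schedule
        self.timesteps = self.timesteps[::2]

sched = EveryOtherScheduler()
sched.set_timesteps(10)
# sched.timesteps -> [1000, 800, 600, 400, 200, 0]
```

The pipeline loop iterates whatever `scheduler.timesteps` contains, so no pipeline change is needed.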
@vladmandic I think it's what you proposed here, already implemented
@AmericanPresidentJimmyCarter feel free to open another issue or discussion
Opened as #7808
thanks @bghira, I really appreciate your comment.
- `pipeline` (or the pipeline instance) provides access to important properties such as `num_timesteps` and `guidance_scale`. You can modify these properties by updating the underlying attributes. For this example, you'll disable CFG by setting `pipeline._guidance_scale=0.0`.
- `step_index` and `timestep` tell you where you are in the denoising loop. Use `step_index` to turn off CFG after reaching 40% of `num_timesteps`.
- `callback_kwargs` is a dict that contains tensor variables you can modify during the denoising loop. It only includes variables specified in the `callback_on_step_end_tensor_inputs` argument, which is passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables, so please check a pipeline's `_callback_tensor_inputs` attribute for the list of variables you can modify. Some common variables include `latents` and `prompt_embeds`. For this function, change the batch size of `prompt_embeds` after setting `guidance_scale=0.0` in order for it to work properly.
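Putting the three bullets together, a CFG-cutoff callback might look like the sketch below. `FakePipe` and the list-based `prompt_embeds` are stand-ins so the example runs without diffusers; in a real pipeline `prompt_embeds` is a torch tensor and the trimming would be `prompt_embeds.chunk(2)[-1]`:

```python
# Sketch of a callback that disables CFG after 40% of the steps, per the
# description above: zero the guidance scale and drop the unconditional
# half of the prompt embeddings.

class FakePipe:
    _guidance_scale = 7.5
    num_timesteps = 10

def cutoff_cfg(pipe, step_index, timestep, callback_kwargs):
    if step_index == int(pipe.num_timesteps * 0.4):
        pipe._guidance_scale = 0.0
        # with CFG off, the batch no longer carries the unconditional half,
        # so keep only the conditional prompt embeddings
        embeds = callback_kwargs["prompt_embeds"]
        callback_kwargs["prompt_embeds"] = embeds[len(embeds) // 2:]
    return callback_kwargs

pipe = FakePipe()
kwargs = {"prompt_embeds": ["uncond", "cond"]}
for i in range(pipe.num_timesteps):
    kwargs = cutoff_cfg(pipe, i, None, kwargs)
# pipe._guidance_scale -> 0.0, kwargs["prompt_embeds"] -> ["cond"]
```

In actual usage this function would be passed as `callback_on_step_end` with `callback_on_step_end_tensor_inputs=["prompt_embeds"]`.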
these changes were made by `make quality`. Also asked @stevhliu for a review of the documentation.
I think the test fails because you spelled `Dict` as `dict`, can we fix them so the tests pass?
thanks! let's get this merged soon :)
cc @sayakpaul for a final review too
The design looks very clean to me. I love it!
My main comments are mostly on the documentation side. Additionally, I think having some tests (fast) would be nice to have.
@asomoza I think we can merge this PR and introduce a test suite in a future PR. Up to you how you want to tackle it.
yeah, I prefer to write the tests in a different PR after merging this one |
left some final feedback on the docs, I think we are ready to merge :)
What does this PR do?
Initial draft of support for official callbacks.
This is the most basic implementation I could think of without the need to modify the pipelines.
After this, we need to discuss if we're going to modify the pipelines to support additional functionalities:
On step begin
For this issue, for example, the proposal is to start the CFG after a certain step and to stop it after another step. For the CFG on begin we would need to add an additional `on_step_begin` callback if we want to do it in the callbacks instead of manually doing it with the embeds and passing them to the pipelines. The same will be needed for differential diffusion.
Automatic `callback_on_step_end_tensor_inputs`
With the current implementation the user needs to know what to add to the `callback_on_step_end_tensor_inputs` list; for example, for the SDXL implementation of the CFG cutout we need to add `prompt_embeds`, `add_text_embeds`, and `add_time_ids` or it won't work. If we want to do this automatically I'll need to modify the pipelines; if not, I can add an error message indicating which values are missing. The user already needs to know the args for each callback, so maybe this is better to just document in a README for all the callbacks.
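The "error message indicating which values are missing" could be a small validation helper along these lines; `check_tensor_inputs` is a hypothetical name sketched against the arguments described above, not something in this PR:

```python
# Sketch of a validator that fails early with a clear message instead of
# letting the callback crash mid-loop with a KeyError.

def check_tensor_inputs(required, provided):
    missing = [k for k in required if k not in provided]
    if missing:
        raise ValueError(
            f"This callback needs {missing} to be included in "
            f"callback_on_step_end_tensor_inputs"
        )

# passes silently: everything required is provided
check_tensor_inputs(["latents"], ["latents", "prompt_embeds"])

# raises ValueError naming the missing keys
try:
    check_tensor_inputs(["prompt_embeds", "add_text_embeds"], ["latents"])
except ValueError as err:
    message = str(err)
```

Each official callback could declare its required inputs and the pipeline (or the callback itself) could run this check before the denoising loop starts.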
Chain callbacks
Should we add the functionality to chain callbacks? For example, to use a list of callbacks, so we can use the CFG and IP cutouts at the same time? The alternative is to create another callback that does both of them.
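A minimal chaining wrapper could just fold the callbacks left to right, each one seeing (and possibly rewriting) the previous one's outputs. `CallbackChain` is an illustrative name, not the API this PR ships:

```python
# Sketch of callback chaining: run several step-end callbacks in order,
# threading callback_kwargs through each of them.

class CallbackChain:
    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def __call__(self, pipe, step_index, timestep, callback_kwargs):
        for cb in self.callbacks:
            callback_kwargs = cb(pipe, step_index, timestep, callback_kwargs)
        return callback_kwargs

def double_latents(pipe, i, t, kw):
    kw["latents"] = kw["latents"] * 2
    return kw

def add_one(pipe, i, t, kw):
    kw["latents"] = kw["latents"] + 1
    return kw

chain = CallbackChain([double_latents, add_one])
out = chain(None, 0, 0, {"latents": 3})  # (3 * 2) + 1 -> 7
```

Since the chain has the same signature as a single callback, the pipeline could accept either transparently.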
Fixes #7736
Example usage:
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.