[training] CogVideoX Lora #9302
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hi @a-r-r-o-w, did you achieve any satisfactory results with the current version? I tried the code on my machine but got broken generation results after just dozens of iterations.
The first issue I noticed is that the re-parameterization was wrong. The target should be the clean latents (target = model_input), and the raw prediction should be re-parameterized as (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_pred, i.e.:

alphas_cumprod = scheduler.alphas_cumprod.to(model_pred.device, model_pred.dtype)
alphas_cumprod_sqrt = alphas_cumprod[timesteps] ** 0.5
c_skip = alphas_cumprod_sqrt
c_out = -((1 - alphas_cumprod_sqrt**2) ** 0.5)
while len(c_skip.shape) < len(model_pred.shape):
    c_skip = c_skip.unsqueeze(-1)
while len(c_out.shape) < len(model_pred.shape):
    c_out = c_out.unsqueeze(-1)
weights = 1 / (1 - alphas_cumprod_sqrt**2)
while len(weights.shape) < len(model_pred.shape):
    weights = weights.unsqueeze(-1)
model_pred = c_out * model_pred + c_skip * model_input

But after fixing it, I still got broken results after around 200 iterations. Any advice would be appreciated.
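If I'm reading the snippet above correctly, this is just the standard v-prediction parameterization: c_skip and c_out recover the predicted clean sample from the raw network output, and weights is the usual 1 / (1 - alpha_bar) term. Restated as a sketch (variable names follow the training script, not the PR's exact code):

# v-prediction: recover the predicted clean latents from the raw transformer output,
# then compare against target = model_input with an SNR-style loss weight.
pred_x0 = alphas_cumprod_sqrt * noisy_model_input - (1 - alphas_cumprod_sqrt ** 2) ** 0.5 * model_output
loss_weight = 1 / (1 - alphas_cumprod_sqrt ** 2)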
cc @bghira here in case you have interest and time to help a little bit with CogVideoX lora! (no worries if not!)
can you plot some of the values during inference that work and then compare them to training?
Hey, thanks a lot for noticing this! I haven't been able to generate any good results either. I actually have the following locally:

model_output = transformer(
    hidden_states=noisy_model_input,
    encoder_hidden_states=prompt_embeds,
    timestep=timesteps,
    image_rotary_emb=image_rotary_emb,
    return_dict=False,
)[0]

alphas_cumprod = scheduler.alphas_cumprod[timesteps]
alphas_cumprod_sqrt = alphas_cumprod ** 0.5
one_minus_alphas_cumprod_sqrt = (1 - alphas_cumprod) ** 0.5
model_pred = noisy_model_input * alphas_cumprod_sqrt - model_output * one_minus_alphas_cumprod_sqrt

Should it be model_input or noisy_model_input in your snippet?
Oops, sorry. I made a typo. It should be noisy_model_input. @a-r-r-o-w
Left some comments. Let's maybe also add a test case first to quickly identify potential suspects?
Left some questions here and there (some of them are clarification questions so bear with me).
Here's a quick test. I use a single-frame video (video_length = 1) for tuning with batch size 1 and a learning rate of 1e-3. After training for 500 iterations it can reproduce the trained frame, yet the generation results break in the intermediate iterations. I also trained the LoRA on the same but longer video and observed similar results. When I load the trained LoRA and generate new videos, it can still follow the prompt to generate other videos but suffers from quality degradation.
Validation outputs: [validation videos for the prompt "The video features a man ..." at steps 0, 40, 80, 100, 120, 160, 200, 240, 280, 320, 360, 400 and 440]
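For anyone wanting to reproduce these validation runs, this is roughly how I load the trained LoRA back into the pipeline (a sketch only; the LoRA path, prompt and generation settings are placeholders, not the exact values used above):

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Load the base model and attach the trained LoRA weights produced by the training script.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("path/to/cogvideox-lora-output")  # placeholder output directory

# Generate a validation video with (roughly) the training prompt.
video = pipe(
    prompt="The video features a man ...",  # truncated validation prompt, as above
    num_inference_steps=50,
    guidance_scale=6.0,
    num_frames=49,
).frames[0]
export_to_video(video, "validation.mp4", fps=8)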
I am facing similar issues too after overfitting on a single example, even after 1000 steps. Will try to take another deep look some time soon, but AFAICT there don't seem to be any more differences. I've left a comment asking some questions. From the different discussions, I gather that ~100 videos and 4000+ steps seem to be ideal for finetuning. This seems very different from normal Dreambooth-like finetuning tbh, where just a few examples would be enough to teach new concepts. Maybe @zRzRzRzRzRzRzR @tengjiayan20 can hopefully take a look and help here.
Very insightful comment @a-r-r-o-w. Thanks for the reply.
I ran into an error here. Is there a known solution for it? Thanks
I don't seem to get this error nor G-U-N. Could you provide more context? What are your flags when launching this script? What specific line does it fail on? Have you modified the script in any way?
What is the shape of the alphas_cumprod tensor? Thanks
I use batch_size = 2. That may be the problem. Thanks
Co-Authored-By: Fu-Yun Wang <[email protected]>
help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.", | ||
) | ||
parser.add_argument( | ||
"--random_flip", |
I think this is not used yet. TODO: support it in a follow-up PR.
# Downloading and loading a dataset from the hub. See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
dataset = load_dataset(
The load_dataset method is not too good here due to lack of support for video data. Similar to how I did in the lora testing script, I think supporting snapshot_download from the hub would be nice to have and easier.
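For illustration, a minimal sketch of what I mean (the dataset repo id and folder layout are made up for the example):

from pathlib import Path
from huggingface_hub import snapshot_download

# Pull the raw dataset repo (video files + caption files) instead of going through load_dataset.
local_dir = snapshot_download(repo_id="my-org/my-video-dataset", repo_type="dataset")  # placeholder repo id
video_paths = sorted(Path(local_dir).glob("videos/*.mp4"))  # assumed layout
caption_paths = sorted(Path(local_dir).glob("captions/*.txt"))  # assumed layout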
return parser.parse_args()
class VideoDataset(Dataset): |
TODOs:
- Support loading latents directly instead of videos (rough sketch below)
- Create a prepare_dataset.py for preprocessing data, and possibly add captioning utilities
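Something along these lines is what I have in mind for the first item (a sketch only; PrecomputedLatentDataset and the .pt file layout are hypothetical, produced by the future prepare_dataset.py):

import torch
from pathlib import Path
from torch.utils.data import Dataset


class PrecomputedLatentDataset(Dataset):
    """Sketch: serve VAE latents and prompt embeddings pre-encoded by a (hypothetical) prepare_dataset.py."""

    def __init__(self, data_root: str):
        self.files = sorted(Path(data_root).glob("*.pt"))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        # Each file is assumed to hold {"latents": Tensor, "prompt_embeds": Tensor}.
        sample = torch.load(self.files[index])
        return {"latents": sample["latents"], "prompt_embeds": sample["prompt_embeds"]}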
def _preprocess_data(self):
    try:
        import decord
TODO: Maybe better to add this as a backend for load_video in the future.
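For context, this is the kind of decord usage I mean (the path and the frame-sampling strategy are just placeholders):

import numpy as np
from decord import VideoReader, cpu

# Read a video and sample a fixed number of frames uniformly across its length.
vr = VideoReader("example.mp4", ctx=cpu(0))  # placeholder path
indices = np.linspace(0, len(vr) - 1, num=49).astype(int)
frames = vr.get_batch(indices).asnumpy()  # (49, H, W, 3) uint8 array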
[`CogVideoX`].
"""

_lora_loadable_modules = ["transformer", "text_encoder"]
- _lora_loadable_modules = ["transformer", "text_encoder"]
+ _lora_loadable_modules = ["transformer"]
For now, since we removed text encoder training, we need to remove everything related to it in the lora loader.
Ok, so should we remove text_encoder from lora in this PR?
So, I tried removing it, but this causes test failures in ~10 different places, so I chose not to do it for now since it would require significant modification of many tests. I think it's okay to leave it here in case someone manages to fine-tune the text encoder and wants to use it (eventually we can add support too).
I don't think we should add this code unless it is needed, though.
I just went through all of the lora loader mixins here, and I think we currently do not support T5 at all - cc @sayakpaul here to confirm if that is the case?
Yes, I came to the same conclusion. I'm actually working on removing the text encoder parts at the moment, so will update in a bit.
Need to fight a few more tests 👊
Yeah, the community doesn't really do T5 at the moment, so we don't support it. No LoRA that is extremely popular has T5 (at least as far as @apolinario and I know). But supporting it is no big deal really.
src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
@@ -690,14 +708,21 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
scheduler_classes = (
    [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
This could be a property (a child class can override it too) - OK to keep as-is since this is the test here.
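i.e. something like this (illustrative only):

@property
def call_signature_keys(self):
    # Default implementation; a child test class can simply override this property.
    return inspect.signature(self.pipeline_class.__call__).parameters.keys()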
thanks! and congrats on winning the fights against the lora test!
Co-Authored-By: YiYi Xu <[email protected]>
Why didn't the code run the following two lines after calculating the loss?
* cogvideox lora training draft
* update
* update
* update
* update
* update
* make fix-copies
* update
* update
* apply suggestions from review
* apply suggestions from reveiw
* fix typo
* Update examples/cogvideo/train_cogvideox_lora.py
  Co-authored-by: YiYi Xu <[email protected]>
* fix lora alpha
* use correct lora scaling for final test pipeline
* Update examples/cogvideo/train_cogvideox_lora.py
  Co-authored-by: YiYi Xu <[email protected]>
* apply suggestions from review; prodigy optimizer
  YiYi Xu <[email protected]>
* add tests
* make style
* add README
* update
* update
* make style
* fix
* update
* add test skeleton
* revert lora utils changes
* add cleaner modifications to lora testing utils
* update lora tests
* deepspeed stuff
* add requirements.txt
* deepspeed refactor
* add lora stuff to img2vid pipeline to fix tests
* fight tests
* add co-authors
  Co-Authored-By: Fu-Yun Wang <[email protected]>
  Co-Authored-By: zR <[email protected]>
* fight lora runner tests
* import Dummy optim and scheduler only wheh required
* update docs
* add coauthors
  Co-Authored-By: Fu-Yun Wang <[email protected]>
* remove option to train text encoder
  Co-Authored-By: bghira <[email protected]>
* update tests
* fight more tests
* update
* fix vid2vid
* fix typo
* remove lora tests; todo in follow-up PR
* undo img2vid changes
* remove text encoder related changes in lora loader mixin
* Revert "remove text encoder related changes in lora loader mixin" (This reverts commit f8a8444.)
* update
* round 1 of fighting tests
* round 2 of fighting tests
* fix copied from comment
* fix typo in lora test
* update styling
  Co-Authored-By: YiYi Xu <[email protected]>
---------
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: zR <[email protected]>
Co-authored-by: Fu-Yun Wang <[email protected]>
Co-authored-by: bghira <[email protected]>
What does this PR do?
Adds LoRA training and loading support for CogVideoX.
This is a rough draft and incomplete conversion from CogVideoX SAT.
The above is assuming a 50-video dataset (total of 2000 training steps)
TODO:
- Verify outputs against SAT implementation (don't match 1:1, possibly due to many reasons)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sayakpaul @yiyixuxu @linoytsaban
cc @zRzRzRzRzRzRzR @bghira