
[training] CogVideoX Lora #9302

Merged: 68 commits merged into main on Sep 19, 2024

Conversation

@a-r-r-o-w (Member) commented Aug 28, 2024

What does this PR do?

Adds LoRA training and loading support for CogVideoX.

This is a rough draft and incomplete conversion from CogVideoX SAT.

#!/bin/bash

export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
export TORCHDYNAMO_VERBOSE=1

GPU_IDS="3"

accelerate launch --gpu_ids $GPU_IDS examples/cogvideo/train_cogvideox_lora.py \
  --pretrained_model_name_or_path THUDM/CogVideoX-2b \
  --cache_dir <CACHE_DIR> \
  --instance_data_root <DATASET_ROOT_DIR> \
  --caption_column <CAPTION_COLUMN> \
  --video_column <VIDEO_COLUMN> \
  --id_token <ID_TOKEN> \
  --validation_prompt "<ID_TOKEN> A black and white animated scene unfolds, featuring a bulldog in overalls and a hat, standing on a ship's deck. The bulldog assumes various poses, then walks towards a dockside with two ducks and a cow. A wooden platform reads 'PODUNK LANDING,' while a building marked 'BOAT TICKETS' and scattered barrels hint at a destination. The bulldog and ducks move purposefully, possibly heading towards a food stand or boating services, amidst a monochromatic backdrop with no noticeable changes in environment or lighting:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
  --validation_prompt_separator ::: \
  --num_validation_videos 1 \
  --validation_epochs 10 \
  --seed 42 \
  --rank 64 \
  --lora_alpha 64 \
  --mixed_precision fp16 \
  --output_dir /raid/aryan/cogvideox-lora \
  --height 480 --width 720 --fps 8 --max_num_frames 49 --skip_frames_start 0 --skip_frames_end 0 \
  --train_batch_size 1 \
  --num_train_epochs 40 \
  --checkpointing_steps 1000 \
  --gradient_accumulation_steps 1 \
  --learning_rate 1e-3 \
  --lr_scheduler cosine_with_restarts \
  --lr_warmup_steps 200 \
  --lr_num_cycles 1 \
  --enable_slicing \
  --enable_tiling \
  --optimizer Adam \
  --adam_beta1 0.9 \
  --adam_beta2 0.95 \
  --max_grad_norm 1.0 \
  --report_to wandb

The above assumes a 50-video dataset (2000 training steps in total).
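
As a quick sanity check on that number, a minimal sketch of the step arithmetic (values taken from the launch flags above):

num_videos = 50
train_batch_size = 1
gradient_accumulation_steps = 1
num_train_epochs = 40

steps_per_epoch = num_videos // (train_batch_size * gradient_accumulation_steps)
total_steps = steps_per_epoch * num_train_epochs
print(total_steps)  # 2000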

TODO:

  • Implement tiled encoding (currently OOMs for Cog-5B but works for Cog-2B; see the VAE sketch after this list)
  • Test with Prodigy optimizer
  • Determine best data preparation format and make the process more clean
  • Prepare dummy test data repository for others to test (Edit: Available internally on our org. No public release from diffusers team on this at the moment)
  • Remove unnecessary parameters
  • Verify outputs against the SAT implementation (they don't match 1:1, possibly for many reasons)
  • Add lora tests
  • Docs
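
For reference, the --enable_slicing and --enable_tiling flags in the launch command map to the CogVideoX VAE's memory-saving switches. A minimal sketch of toggling them directly (assuming the released THUDM/CogVideoX-2b checkpoint layout; not part of this PR's diff):

from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-2b", subfolder="vae")
vae.enable_slicing()  # encode/decode one batch element at a time
vae.enable_tiling()   # process frames in spatial tiles to reduce peak memory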

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul @yiyixuxu @linoytsaban

cc @zRzRzRzRzRzRzR @bghira

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@G-U-N (Contributor) commented Sep 2, 2024

Hi @a-r-r-o-w, did you achieve any satisfactory results in the current version? I tried the code on my machine but got broken generation results after just dozens of iterations.

@G-U-N (Contributor) commented Sep 2, 2024

The first issue I noticed is that the re-parameterization was wrong.
After checking the official repo, I think it should be

target = model_input

# Recover the sample prediction from the v-prediction:
# x0 = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * v
alphas_cumprod = scheduler.alphas_cumprod.to(model_pred.device, model_pred.dtype)
alphas_cumprod_sqrt = alphas_cumprod[timesteps] ** 0.5
c_skip = alphas_cumprod_sqrt
c_out = -((1 - alphas_cumprod_sqrt**2) ** 0.5)
while len(c_skip.shape) < len(model_pred.shape):
    c_skip = c_skip.unsqueeze(-1)
while len(c_out.shape) < len(model_pred.shape):
    c_out = c_out.unsqueeze(-1)

weights = 1 / (1 - alphas_cumprod_sqrt**2)
while len(weights.shape) < len(model_pred.shape):
    weights = weights.unsqueeze(-1)

# (Per the follow-up below, model_input here was a typo for noisy_model_input.)
model_pred = c_out * model_pred + c_skip * model_input

But after fixing it, I still got broken results after around 200 iterations. Any advice would be appreciated.
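
For reference, a minimal numerical check (a sketch; tensor shapes simplified to 1D) that the c_skip/c_out form above agrees with the sqrt-based form quoted later in this thread:

import torch

a = torch.rand(4)      # alphas_cumprod[timesteps]
x_t = torch.randn(4)   # noisy_model_input
v = torch.randn(4)     # model_output (v-prediction)

c_skip = a.sqrt()
c_out = -((1 - a).sqrt())
lhs = c_out * v + c_skip * x_t             # this comment's form
rhs = x_t * a.sqrt() - v * (1 - a).sqrt()  # the form in the reply below
assert torch.allclose(lhs, rhs)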

@yiyixuxu (Collaborator) commented Sep 2, 2024

cc @bghira here in case you have interest and time to help a little bit with CogVideoX lora! (no worries if not!)

@bghira (Contributor) commented Sep 2, 2024

Can you plot some of the values during inference that work, and then compare them to training?

@G-U-N (Contributor) commented Sep 3, 2024

Great @yiyixuxu @bghira, I am willing to assist if there is any need.

@a-r-r-o-w (Member Author) commented

The first issue I noticed is the re-parameterization was wrong.

Hey, thanks a lot for noticing this! I haven't been able to generate any good results either. I actually have the following locally:

model_output = transformer(
    hidden_states=noisy_model_input,
    encoder_hidden_states=prompt_embeds,
    timestep=timesteps,
    image_rotary_emb=image_rotary_emb,
    return_dict=False,
)[0]
alphas_cumprod = scheduler.alphas_cumprod[timesteps]
alphas_cumprod_sqrt = alphas_cumprod ** 0.5
one_minus_alphas_cumprod_sqrt = (1 - alphas_cumprod) ** 0.5
model_pred = noisy_model_input * alphas_cumprod_sqrt - model_output * one_minus_alphas_cumprod_sqrt

Should it be model_input * alphas_cumprod_sqrt - ... here instead of noisy_model_input? From the original codebase, I think noisy_model_input is correct here, but it doesn't work yet, possibly due to a different bug.

@G-U-N (Contributor) commented Sep 3, 2024

Oops, sorry. I made a typo. It should be noisy_model_input. @a-r-r-o-w
I am going to test it in my code and report my training results to you.

@sayakpaul (Member) left a comment

Left some comments. Let's maybe also add a test case first to quickly identify potential suspects?

@sayakpaul (Member) left a comment

Left some questions here and there (some of them are clarification questions, so bear with me).

[8 resolved review threads on examples/cogvideo/train_cogvideox_lora.py]
@G-U-N (Contributor) commented Sep 3, 2024

Here's a quick test. I used a single-frame video (video_length = 1) for tuning with batch size 1 and a learning rate of 1e-3. After training for 500 iterations, the model can reproduce the trained frame, yet the generation results break down at intermediate iterations. I also trained the LoRA on the same but longer video and observed similar results. When I load the trained LoRA and generate new videos, it can still follow the prompt, but the outputs suffer from quality degradation.

Validation outputs
0_validation_video_0_The_video_features_a_man_.mp4
40_validation_video_0_The_video_features_a_man_.mp4
80_validation_video_0_The_video_features_a_man_.mp4
100_validation_video_0_The_video_features_a_man_.mp4
120_validation_video_0_The_video_features_a_man_.mp4
160_validation_video_0_The_video_features_a_man_.mp4
200_validation_video_0_The_video_features_a_man_.mp4
240_validation_video_0_The_video_features_a_man_.mp4
280_validation_video_0_The_video_features_a_man_.mp4
320_validation_video_0_The_video_features_a_man_.mp4
360_validation_video_0_The_video_features_a_man_.mp4
400_validation_video_0_The_video_features_a_man_.mp4
440_validation_video_0_The_video_features_a_man_.mp4

@a-r-r-o-w
Copy link
Member Author

I am facing similar issues too when overfitting on a single example, even after 1000 steps. Will try to take another deep look sometime soon, but AFAICT there don't seem to be any more differences.

I've left a comment asking some questions. From the different discussions, I gather that ~100 videos and 4000+ steps seem to be ideal for finetuning. This seems very different from normal Dreambooth-like finetuning, tbh, where just a few examples would be enough to teach new concepts.

Maybe @zRzRzRzRzRzRzR @tengjiayan20 can hopefully take a look and help here.

@G-U-N (Contributor) commented Sep 3, 2024

(Quoting @a-r-r-o-w's reply above.)

Very insightful comment, @a-r-r-o-w. Thanks for the reply.

@FDInSky commented Sep 4, 2024

(Quoting @a-r-r-o-w's reparameterization reply above.)

I encountered an error here; is there a solution for it? Thanks:
RuntimeError: The size of tensor a (90) must match the size of tensor b (2) at non-singleton dimension 4

@a-r-r-o-w (Member Author) commented

(Quoting the RuntimeError above.)

Neither I nor G-U-N seem to get this error. Could you provide more context? What flags are you using when launching the script? Which specific line does it fail on? Have you modified the script in any way?

@FDInSky commented Sep 4, 2024


What is the shape of the tensor alphas_cumprod? Thanks.

@a-r-r-o-w (Member Author) commented


scheduler.alphas_cumprod has shape (1000,). alphas_cumprod in the training script is indexed using timesteps, which has shape (train_batch_size,), so that should be its shape too. I've only experimented with train_batch_size=1. Are you using a higher value by any chance?

@FDInSky commented Sep 4, 2024


I use batch_size = 2; that may be the problem. Thanks.
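
For anyone hitting the same error, a minimal sketch of the usual PyTorch fix (an assumption about the cause: the (train_batch_size,)-shaped scalars do not broadcast against the 5D latents when train_batch_size > 1):

# Append singleton dims so the per-sample scalars broadcast over (B, F, C, H, W).
alphas_cumprod = scheduler.alphas_cumprod.to(noisy_model_input.device)[timesteps]
alphas_cumprod_sqrt = alphas_cumprod ** 0.5  # shape: (B,)
while alphas_cumprod_sqrt.dim() < noisy_model_input.dim():
    alphas_cumprod_sqrt = alphas_cumprod_sqrt.unsqueeze(-1)  # -> (B, 1, 1, 1, 1)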

help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.",
)
parser.add_argument(
"--random_flip",
@a-r-r-o-w (Member Author):

I think it's not used yet; TODO to support it in a follow-up PR.


# Downloading and loading a dataset from the hub. See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
dataset = load_dataset(
@a-r-r-o-w (Member Author):
The load_dataset method is not too good here due to its lack of support for video data. Similar to what I did in the lora testing script, I think supporting snapshot_download from the Hub would be nice to have and easier.
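
A minimal sketch of that alternative (the repo id is a placeholder, not a real dataset):

from huggingface_hub import snapshot_download

dataset_dir = snapshot_download(repo_id="<DATASET_REPO_ID>", repo_type="dataset")
# dataset_dir now points to a local folder of videos and captions that the
# training script can index directly.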

return parser.parse_args()


class VideoDataset(Dataset):
@a-r-r-o-w (Member Author):
TODOs:

  • Support loading latents directly instead of videos (see the sketch after this list)
  • Create a prepare_dataset.py for preprocessing data, and possibly having captioning utilities
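
A rough sketch of the first TODO (a hypothetical helper, not part of this PR; assumes the CogVideoX VAE's (B, C, F, H, W) input layout and its scaling_factor config):

import torch

@torch.no_grad()
def encode_video_to_latents(vae, video):
    # video: (F, C, H, W), values in [-1, 1]
    video = video.unsqueeze(0).permute(0, 2, 1, 3, 4)  # -> (B, C, F, H, W)
    latents = vae.encode(video).latent_dist.sample() * vae.config.scaling_factor
    return latents.squeeze(0)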


def _preprocess_data(self):
try:
import decord
@a-r-r-o-w (Member Author):
TODO: Maybe better to add as a backend for load_video in the future.
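
A minimal sketch of decord as a frame-reading backend (assumed installed via `pip install decord`; <VIDEO_PATH> is a placeholder):

import decord

reader = decord.VideoReader("<VIDEO_PATH>")
frames = reader.get_batch(list(range(len(reader)))).asnumpy()  # (F, H, W, C), uint8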

[2 resolved review threads on examples/cogvideo/train_cogvideox_lora.py]
[`CogVideoX`].
"""

_lora_loadable_modules = ["transformer", "text_encoder"]
@a-r-r-o-w (Member Author):
Suggested change:
- _lora_loadable_modules = ["transformer", "text_encoder"]
+ _lora_loadable_modules = ["transformer"]

For now, since we removed text-encoder training, we need to remove everything related to it in the lora loader.

Collaborator:

OK, so we should remove text_encoder from lora in this PR?

@a-r-r-o-w (Member Author):

So, I tried removing it, but this causes test failures in ~10 different places, so I chose not to do it for now since it would require significant modification of many tests. I think it's okay to leave it here in case someone manages to fine-tune the text encoder and wants to use it (eventually we can add support too).

Collaborator:

I don't think we should add this code unless it is needed, though. I just went through all the LoraLoaderMixins here; I think we currently do not support T5 at all. cc @sayakpaul to confirm whether that is the case.

@a-r-r-o-w (Member Author):

Yes, I came to the same conclusion. I'm actually working on removing the text encoder parts at the moment, so I will update in a bit.

Need to fight a few more tests 👊

Member:

Yeah, the community doesn't really do T5 at the moment, so we don't support it. No extremely popular LoRA has T5 (at least as far as @apolinario and I know). But supporting it is no big deal, really.

[2 resolved review threads on tests/lora/utils.py]
@@ -690,14 +708,21 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
Collaborator:

This can be a property (a child class can override it too). OK to keep it as is in the test here.

[2 resolved review threads on tests/lora/utils.py]
@yiyixuxu (Collaborator) left a comment

Thanks! And congrats on winning the fights against the lora test!

[1 resolved review thread on tests/lora/utils.py]
@a-r-r-o-w a-r-r-o-w merged commit 2b443a5 into main Sep 19, 2024
18 checks passed
@a-r-r-o-w a-r-r-o-w deleted the cogvideox-lora-and-training branch September 19, 2024 09:08
@963658029 (Contributor) commented

Why doesn't the code run the following two lines after calculating the loss?
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
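
For context, a commented sketch of what those two lines do in other diffusers example scripts (assuming the same accelerate setup as this script):

# Gather the per-process losses so the logged value is the mean across all
# devices rather than just the local process.
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
# Accumulate for logging, scaled so the sum over an accumulation window
# matches the loss of one effective optimizer step.
train_loss += avg_loss.item() / args.gradient_accumulation_steps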

leisuzz pushed a commit to leisuzz/diffusers that referenced this pull request Oct 11, 2024
* cogvideox lora training draft

* update

* update

* update

* update

* update

* make fix-copies

* update

* update

* apply suggestions from review

* apply suggestions from reveiw

* fix typo

* Update examples/cogvideo/train_cogvideox_lora.py

Co-authored-by: YiYi Xu <[email protected]>

* fix lora alpha

* use correct lora scaling for final test pipeline

* Update examples/cogvideo/train_cogvideox_lora.py

Co-authored-by: YiYi Xu <[email protected]>

* apply suggestions from review; prodigy optimizer

Co-authored-by: YiYi Xu <[email protected]>

* add tests

* make style

* add README

* update

* update

* make style

* fix

* update

* add test skeleton

* revert lora utils changes

* add cleaner modifications to lora testing utils

* update lora tests

* deepspeed stuff

* add requirements.txt

* deepspeed refactor

* add lora stuff to img2vid pipeline to fix tests

* fight tests

* add co-authors

Co-Authored-By: Fu-Yun Wang <[email protected]>

Co-Authored-By: zR <[email protected]>

* fight lora runner tests

* import Dummy optim and scheduler only wheh required

* update docs

* add coauthors

Co-Authored-By: Fu-Yun Wang <[email protected]>

* remove option to train text encoder

Co-Authored-By: bghira <[email protected]>

* update tests

* fight more tests

* update

* fix vid2vid

* fix typo

* remove lora tests; todo in follow-up PR

* undo img2vid changes

* remove text encoder related changes in lora loader mixin

* Revert "remove text encoder related changes in lora loader mixin"

This reverts commit f8a8444.

* update

* round 1 of fighting tests

* round 2 of fighting tests

* fix copied from comment

* fix typo in lora test

* update styling

Co-Authored-By: YiYi Xu <[email protected]>

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: zR <[email protected]>
Co-authored-by: Fu-Yun Wang <[email protected]>
Co-authored-by: bghira <[email protected]>