Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Flux] Port Flux Core Model #1864

Merged
merged 74 commits into from
Nov 13, 2024

Conversation

DavidLandup0
Copy link
Collaborator

@DavidLandup0 DavidLandup0 commented Sep 23, 2024

This PR ports the core model into a Keras model and includes a weight conversion script.
VAE and rest of the pipeline would make sense in a separate PR.

Each layer is numerically compared against the original PyTorch implementation here: https://colab.research.google.com/drive/1Jr5pa9BGAxP6lZPimlpb22rD5DMijN3H#scrollTo=Bi_WbOjk7C4k

Modules included:

  • Maths module
    • Timestep embedding
    • RoPE
    • Attention
    • Scaled dot product attention re-implementation in Keras (to match the PyTorch one)
  • Layers module
    • MLPEmbedder
    • RMSNorm
    • QKNorm
    • SelfAttention
    • Modulation
    • DoubleStreamBlock
    • SingleStreamBlock
    • LastLayer

Output Comparison

The core model's outputs are latents. We plot the PCA of the output from the original implementation and the Keras re-implementation on the same input:

image

Numerically, equivalent to 1e-3 precision:

>>> np.allclose(output_keras.numpy(), output_pt.detach().numpy(), atol=1e-3)
True

@sachinprasadhs sachinprasadhs marked this pull request as draft September 23, 2024 16:51
@DavidLandup0 DavidLandup0 changed the title [Flux] Port Flux Model and Pipeline [Flux] Port Flux Core Model Oct 2, 2024
@DavidLandup0 DavidLandup0 marked this pull request as ready for review October 3, 2024 12:10
@DavidLandup0
Copy link
Collaborator Author

@divyashreepathihalli turned into a functional subclassing module - had to wrestle a bit with shapes/autograph, but it should be ready for another review.

Here's the notebook showing numerical equivalence to atol=1e-5 on all modules, as well as the final output of the core model: https://colab.research.google.com/drive/1Jr5pa9BGAxP6lZPimlpb22rD5DMijN3H#scrollTo=Bi_WbOjk7C4k

Adding a preprocessing flow and we can open a PR for integrating with T5 and CLIP.


class FluxBackboneTest(TestCase):
def setUp(self):
vae = VAEBackbone(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will be part of the generation pipeline so these are added preemptively and unused for now

@DavidLandup0 DavidLandup0 mentioned this pull request Oct 23, 2024
4 tasks
@DavidLandup0
Copy link
Collaborator Author

@divyashreepathihalli could we do another review here?

@divyashreepathihalli divyashreepathihalli added the kokoro:force-run Runs Tests on GPU label Oct 28, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Oct 28, 2024
Copy link
Collaborator

@divyashreepathihalli divyashreepathihalli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks David! left a few comments.
Do you have a demo colab to verify the outputs?

return self.out_layer(x)


# TODO: Maybe this can be exported as part of the public API? Seems to have enough reusability.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here - keras_hub/src/layers/modeling

keras_hub/src/models/flux/flux_layers.py Show resolved Hide resolved
keras_hub/src/models/flux/flux_layers.py Outdated Show resolved Hide resolved
keras_hub/src/models/flux/flux_layers.py Outdated Show resolved Hide resolved
keras_hub/src/models/flux/flux_layers.py Outdated Show resolved Hide resolved
keras_hub/src/models/flux/flux_layers.py Outdated Show resolved Hide resolved
tools/checkpoint_conversion/convert_flux_checkpoints.py Outdated Show resolved Hide resolved
@DavidLandup0
Copy link
Collaborator Author

DavidLandup0 commented Oct 29, 2024

Thanks David! left a few comments. Do you have a demo colab to verify the outputs?

Yes - here: https://colab.research.google.com/drive/1Jr5pa9BGAxP6lZPimlpb22rD5DMijN3H#scrollTo=_ys5NSkcoQ_O

With converted weights (in the Colab as well), we get identical outputs between the official model and the port, within 1e-3 sensitivity:
image

Copy link
Collaborator

@divyashreepathihalli divyashreepathihalli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@divyashreepathihalli divyashreepathihalli merged commit 0756fb4 into keras-team:master Nov 13, 2024
6 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants