[Flux] Port Flux Core Model #1864
Conversation
Force-pushed d6d626c to 2bc150e
@divyashreepathihalli Turned this into a functional subclassing module. I had to wrestle a bit with shapes/autograph, but it should be ready for another review. Here's the notebook showing numerical equivalence to atol=1e-5 on all modules, as well as the final output of the core model: https://colab.research.google.com/drive/1Jr5pa9BGAxP6lZPimlpb22rD5DMijN3H#scrollTo=Bi_WbOjk7C4k Adding a preprocessing flow next, and then we can open a PR for integrating with T5 and CLIP.
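Layer-wise numerical equivalence checks like the ones in the linked notebook typically boil down to `np.testing.assert_allclose` on matching outputs. A minimal numpy-only sketch (the arrays here are hypothetical stand-ins for the original and ported module outputs, not the notebook's actual code):

```python
import numpy as np

def check_equivalence(reference: np.ndarray, ported: np.ndarray, atol: float = 1e-5) -> float:
    """Assert elementwise agreement within atol and return the max absolute difference."""
    np.testing.assert_allclose(ported, reference, atol=atol, rtol=0)
    return float(np.max(np.abs(ported - reference)))

# Stand-ins: "original implementation output" vs "ported output".
ref = np.linspace(0.0, 1.0, 12).reshape(3, 4)
port = ref + 1e-6  # tiny discrepancy, well under atol=1e-5

max_diff = check_equivalence(ref, port, atol=1e-5)
assert max_diff < 1e-5
```

Using `rtol=0` makes the check a pure absolute-tolerance comparison, matching how "equivalence to atol=1e-5" is usually reported.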
class FluxBackboneTest(TestCase):
    def setUp(self):
        vae = VAEBackbone(
These will be part of the generation pipeline, so they are added preemptively and are unused for now.
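The test scaffolding quoted above follows the standard unittest `setUp` pattern. As a rough, self-contained sketch of that pattern (using a trivial numpy stand-in instead of the real `VAEBackbone`, whose constructor arguments are not shown in this thread):

```python
import unittest
import numpy as np

class StubBackbone:
    """Hypothetical stand-in for a backbone: a fixed linear map over flattened inputs."""
    def __init__(self, latent_dim: int = 4, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.kernel = rng.normal(size=(16, latent_dim))

    def __call__(self, x: np.ndarray) -> np.ndarray:
        return x.reshape(x.shape[0], -1) @ self.kernel

class BackboneShapeTest(unittest.TestCase):
    def setUp(self):
        # Built once per test, mirroring the setUp in the PR's FluxBackboneTest.
        self.backbone = StubBackbone(latent_dim=4)

    def test_output_shape(self):
        x = np.zeros((2, 4, 4))
        self.assertEqual(self.backbone(x).shape, (2, 4))
```

The real test would instead construct the ported backbone (and the VAE and samplers that are "added preemptively") and assert on output shapes and dtypes.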
@divyashreepathihalli Could we do another review here?
Thanks, David! I left a few comments.
Do you have a demo Colab to verify the outputs?
        return self.out_layer(x)
# TODO: Maybe this can be exported as part of the public API? Seems to have enough reusability.
Here: keras_hub/src/layers/modeling
Yes, here: https://colab.research.google.com/drive/1Jr5pa9BGAxP6lZPimlpb22rD5DMijN3H#scrollTo=_ys5NSkcoQ_O With converted weights (in the Colab as well), we get identical outputs between the official model and the port, within 1e-3 tolerance:
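Weight conversion between PyTorch and Keras mostly comes down to layout differences. One concrete example (a numpy-only sketch; the PR's actual conversion script may handle more cases): PyTorch `nn.Linear` stores its weight as `(out_features, in_features)`, while a Keras `Dense` kernel is `(in_features, out_features)`, so dense kernels must be transposed during conversion.

```python
import numpy as np

def convert_linear_weight(pt_weight: np.ndarray) -> np.ndarray:
    """Convert a PyTorch nn.Linear weight (out, in) to a Keras Dense kernel (in, out)."""
    return pt_weight.T

# A PyTorch-style weight with out_features=3, in_features=2.
pt_w = np.arange(6, dtype=np.float32).reshape(3, 2)
keras_kernel = convert_linear_weight(pt_w)

x = np.ones((1, 2), dtype=np.float32)
# PyTorch computes x @ W.T; Keras computes x @ kernel. Both must agree.
assert np.allclose(x @ pt_w.T, x @ keras_kernel)
```

Conv weights similarly need axis reordering (PyTorch is channels-first, Keras defaults to channels-last), which is a common source of the small sub-1e-3 discrepancies seen after conversion.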
Force-pushed 048107c to c6e20f6
LGTM
This PR ports the core model into a Keras model and includes a weight conversion script.
The VAE and the rest of the pipeline would make sense in a separate PR.
Each layer is numerically compared against the original PyTorch implementation here: https://colab.research.google.com/drive/1Jr5pa9BGAxP6lZPimlpb22rD5DMijN3H#scrollTo=Bi_WbOjk7C4k
Modules included:
Output Comparison
The core model's outputs are latents. We plot the PCA of the output from the original implementation and the Keras re-implementation on the same input:
Numerically, the outputs are equivalent to 1e-3 precision.
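The comparison described above can be sketched with plain numpy (a hypothetical illustration; the notebook's actual latents and plotting code are not reproduced here): project both implementations' flattened latents onto their top-2 principal components for the visual check, and compare the raw latents numerically.

```python
import numpy as np

def pca_project(x: np.ndarray, n_components: int = 2) -> np.ndarray:
    """Project rows of x onto its top principal components via SVD."""
    x_centered = x - x.mean(axis=0)
    _, _, vt = np.linalg.svd(x_centered, full_matrices=False)
    return x_centered @ vt[:n_components].T

rng = np.random.default_rng(0)
# Stand-ins for flattened latents from the original model and the Keras port.
reference_latents = rng.normal(size=(64, 16))
ported_latents = reference_latents + rng.normal(scale=1e-4, size=(64, 16))

# Raw numerical agreement, as in the PR's layer-by-layer checks.
np.testing.assert_allclose(ported_latents, reference_latents, atol=1e-3)

# 2-D PCA projections used for the side-by-side plots.
ref_2d = pca_project(reference_latents)
port_2d = pca_project(ported_latents)
assert ref_2d.shape == port_2d.shape == (64, 2)
```

The numerical assertion is the ground truth; the PCA scatter is just a quick visual sanity check that the two latent distributions overlap.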