It would be amazing if we could load and fine-tune the models on TPUs using the Flax LM classes in HF. In my experience, this makes training and generation very straightforward on TPUs, along of course with taking advantage of their compute.
I have tried to load a Mistral checkpoint with the following code:

```python
from transformers import FlaxAutoModelForCausalLM

model = FlaxAutoModelForCausalLM.from_pretrained(
    "alias/arwen-x21-checkpoint-400000",
    from_pt=True,
    pad_token_id=50256,
)
```
This seems to work: the model loads, I can access its properties, and I can even generate text.
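For completeness, the generation check looked roughly like this (a minimal sketch; I'm assuming the standard GPT-2 tokenizer here, which is not part of the snippet above):

```python
from transformers import AutoTokenizer

# Assumption: the checkpoint uses the standard GPT-2 BPE tokenizer.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("Hello, my name is", return_tensors="np")
outputs = model.generate(inputs["input_ids"], max_length=20)
print(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))
```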
However, once I try to fine-tune it using (more or less) the code here: https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_clm_flax.py, it takes about 10 minutes to compile and then about 5 minutes for each step (for reference, these should be about 2 minutes and a few seconds respectively for gpt2-medium).
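For reference, this is roughly how I separated compile time from per-step time (an illustrative sketch; `p_train_step`, `state`, and `batch` follow the naming in run_clm_flax.py):

```python
import time
import jax

def timed_step(p_train_step, state, batch):
    """Run one pmapped train step and return the new state plus elapsed seconds."""
    start = time.time()
    state, metrics = p_train_step(state, batch)
    # Block on the async dispatch so we measure device time, not dispatch time.
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics)
    return state, time.time() - start

# The first call includes XLA compilation (~10 min for me);
# subsequent calls give the steady-state step time (~5 min).
```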
Finally, it would be nice if the changes in the Mistral models were somehow included when loading the model in HF (I am actually not 100% sure that this does not already happen). Specifically, I'm thinking of this line here:
mistral/src/models/mistral_gpt2.py, line 312 (at commit 7be4c58)
Hope this makes sense. Thank you in advance!
Best,
Theodore.
Hey Theodore - so we're definitely working on pushing the Mistral-specific operation changes (like the one you mentioned) to Transformers proper, as a flag in the GPT-2 model class. This should happen by the end of the week (or at least, we'll have a PR in transformers you can use!).

As for why the Flax code is running slower - that's super interesting, and I don't have a good answer! It could be some weird interaction between the way we handle the upcasting code and the defaults in the run_clm_flax.py script. It would be great if you could do some digging (or create an issue/PR!), as we're not too familiar with Flax ourselves; otherwise, I'll take a look when I can!
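For concreteness, flag-based loading could look something like the sketch below (illustrative only: `scale_attn_by_inverse_layer_idx` and `reorder_and_upcast_attn` are the config flags on the PyTorch GPT-2 side of transformers, and the Flax GPT-2 classes may not honor them):

```python
from transformers import GPT2Config, GPT2LMHeadModel

# Illustrative: enable the Mistral-style attention changes via config flags.
config = GPT2Config.from_pretrained(
    "gpt2-medium",
    scale_attn_by_inverse_layer_idx=True,  # scale attention weights by 1/(layer_idx + 1)
    reorder_and_upcast_attn=True,          # compute attention scores in fp32
)
model = GPT2LMHeadModel.from_pretrained("gpt2-medium", config=config)
```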