Sharding error when trying to serialize PEFT models #635
To further demonstrate my JAX ignorance, I tried adding a fully replicated sharding explicitly (since in this case everything should fit) with:
but this doesn't seem to have any effect (I get the "non-addressable" error, or the "incompatible devices" error if I also specify …).
And, oddly enough, everything starts to work when I remove the …
Are there any downsides to doing this for this type of copy operation?
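(For reference, a minimal sketch of what an explicit fully replicated sharding can look like; the mesh axis name, array, and shapes are placeholders, not the snippet elided above.)

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1D mesh over every device; an empty PartitionSpec means "replicate
# along every mesh axis", i.e. a fully replicated sharding.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
replicated = NamedSharding(mesh, P())

x = jnp.ones((8, 128))
x = jax.device_put(x, replicated)  # a full copy of x on each device in the mesh
```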
So, at the small scale, that first error is crazy because the devices are literally the same, just in a different sort order. This seems like a bug? Could it be (another) issue with best_effort_sharding?
The constraints, as I understand them, are:
I don't think my current solution is the right one, but I'm not entirely sure what it should be. I think it's:
which can probably simplify to …
Yeah, I'm puzzled by the behavior. I think your analysis is correct: you'd expect to need a two-stage movement, but why doesn't your current version work for this simple case? Even when I switched to the fully replicated sharding, JAX still reported the device issue. Isn't reshuffling exactly what https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#constraining-shardings-of-intermediates-in-jitted-code is for? IIUC, your original code should have worked fine: there were no input shardings, and …
I synced with the JAX folks, and they said that jit requires the devices to be in the same order for input and output. They're working on relaxing that, but there are weird performance regressions. The right thing is probably to detect if the device sequence/mesh (not just the set of devices) is changing and use device_put if so? I feel like this is much more complicated than it needs to be.
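(A sketch of that "detect and fall back to device_put" idea; this is an illustration, not Levanter's actual implementation, and `reshard` is a made-up helper name.)

```python
import jax
from jax.sharding import NamedSharding

def reshard(x: jax.Array, target: NamedSharding) -> jax.Array:
    """Move x to `target`, tolerating a different device order/mesh."""
    src = x.sharding
    same_device_sequence = (
        isinstance(src, NamedSharding)
        and src.mesh.devices.tolist() == target.mesh.devices.tolist()
    )
    if same_device_sequence:
        # Same devices in the same order: jit can lay out the output directly.
        return jax.jit(lambda a: a, out_shardings=target)(x)
    # Different devices, or the same devices in a different order: jit
    # complains about "incompatible devices", but device_put may do the move.
    return jax.device_put(x, target)
```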
Interesting... I guess you're expected to effectively re-use the same device mesh all the time, and that's why this isn't hitting people more often? This just seems confusingly hard for what I'd think would be such a common desire: "get the data to/from the CPU". If you're forced to use the original mesh, it seems like you either have to give up and fully replicate the array, or somehow carefully choose a sharding that doesn't partition you across hosts, and then run device_put. Are we holding it wrong? (I guess fully-replicating isn't too bad if you're doing it one array at a time?)
Yeah, so the best-effort sharding logic is relatively new. I added it and the crazy CPU-vs-accelerator logic to handle loading 34B-param models on a v4-8 (in response to #508), and also to be able to load smaller models on our internal GPU cluster, which has less CPU memory, so we needed a solution there. At the time I didn't realize that you were effectively stuck with one mesh; it just happened that the one mesh we were using at the time was the same one I created for best-effort sharding, so we got lucky. I think the things to do are:
WDYT?
Ah interesting, for the model loading side, everything makes sense: you're building the whole pytree, and if you don't shard, lots of models won't end up fitting. I don't quite follow how the CPU usage is reduced, since IIUC we're always loading the model replicated on the CPU and then sharding (but I didn't read the model loading code closely...). Oh, maybe because you have the implicit mesh you can avoid making a copy of the tree -> state_dict first... I should really just read the code and PR 😛.
For export, we implicitly have the mesh from the input array already, so we can re-use that. I'm probably confused, but generating an appropriate (non-replicated) sharding for an arbitrary mesh on the export side seems hard. We'd need to choose an arrangement that keeps physical devices within a single host so that we don't get the "non-addressable devices" issue when we convert to CPU, right? (It doesn't seem like there's anything wrong with the idea of the best-effort sharding for loading, and we could even avoid looking for the implicit mesh, if just moving between meshes worked at all...) I thought something dumb like this would work to make the array fully replicated and then copy-able, but I still hit the "fetching an array that spans non-addressable devices" error. Though again, just switching to …
(I think) we're now getting the mesh correctly for loading, since the model loads quickly and without errors. It's only at save time that we run into this issue (there's a sort of copy of the best-effort sharding there). For this PEFT save logic, do we need the best-effort sharding at all? It seems like we're iterating over layers one at a time and copying them to the CPU, so we only need enough device memory for that single layer (again, I could be missing something).
Yeah, I feel like some well-compartmentalized functions would help a lot: "make this CPU array appear on the devices with this sharding", "make this device array appear on the CPU, replicated", both handling the resharding as necessary... it seems like they should be part of JAX, TBH. (Some of that could be extending …) Having …
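(A rough sketch of what those two helpers might look like, assuming the CPU backend is visible via jax.devices(backend="cpu"); the function names and the replicated CPU mesh are illustrative choices, not the code that landed in Haliax.)

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def host_to_devices(x, sharding: NamedSharding) -> jax.Array:
    # "Make this CPU array appear on the devices with this sharding."
    # device_put accepts a Sharding directly and scatters the data for us.
    return jax.device_put(x, sharding)

def devices_to_host_replicated(x: jax.Array) -> jax.Array:
    # "Make this device array appear on the CPU, replicated."
    # Build a replicated sharding over the CPU backend's devices and let
    # device_put handle the cross-mesh move (see the device_put discussion below).
    cpu_mesh = Mesh(np.array(jax.devices(backend="cpu")), axis_names=("host",))
    return jax.device_put(x, NamedSharding(cpu_mesh, P()))
```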
I haven't read the entire thread, but to do this, maybe just try putting your array on …
This should keep the sharding the same as the TPU one without having to mess around with devices; i.e., the mesh stays the same, with TPU devices! You only change the memory kind of the sharding to point to host memory.
Note you need to enable this config: |
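(A sketch of that suggestion, assuming a backend where the "pinned_host" memory kind exists. The exact snippet and config line aren't quoted above, so the `jax_enable_memories` flag here is a guess at what's being referred to, and the mesh/array are placeholders.)

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: this is the config flag being referred to above.
jax.config.update("jax_enable_memories", True)

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))  # the existing accelerator mesh
tpu_sharding = NamedSharding(mesh, P("data"))               # the existing layout
x = jax.device_put(jnp.ones((8, 128)), tpu_sharding)

# Same mesh, same PartitionSpec -- only the memory kind changes, so the
# sharding layout is preserved while the bytes live in host memory.
host_sharding = tpu_sharding.with_memory_kind("pinned_host")
```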
Talking to @yashk2810, it seems like device_put is more capable than I thought (in particular it seems like it can do cross-host transfers)
OK, @yashk2810 figured this out for me and I'm patching Haliax with the fix: stanford-crfm/haliax#96 (Yash, avert your eyes...)
@rjpower if you get a chance, could you check if things work with the latest jax nightly?
(@yashk2810 fixed it, I think)
Hrm, I changed the Haliax dependency to …
I still see this error:
Guessing without looking, do we also need to adjust the .config:
You need to use jax.device_put, not jax.jit |
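(In other words, continuing the earlier memory-kind sketch with the same placeholder setup, the move itself goes through device_put; this is an illustration, not the exact patch.)

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.config.update("jax_enable_memories", True)  # assumption, as above

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))
x = jax.device_put(jnp.ones((8, 128)), sharding)
host_sharding = sharding.with_memory_kind("pinned_host")

# Routing the move through jit, e.g.
#   jax.jit(lambda a: a, out_shardings=host_sharding)(x)
# is what the comment above says to avoid; the transfer itself should be:
x_host = jax.device_put(x, host_sharding)
```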
Thanks Yash, with this change to …
@dlwh I'll send the one-line patch CL; let me know if I'm missing something obvious though!
lgtm!
#622 fixes issue #609 for loading HF models into a sharded representation. But now when I try to serialize a PEFT model I'm getting a similar error as before:
I tried the simple thing of removing the sharding annotation entirely and just running
jax.array(input)
but that yields the (I guess expected) error, since I'm assuming the original sharding is spread across multiple machines:
I haven't yet reproduced this as a test, but you can reproduce it using the gsm8k example:
The script needs one patch to load models correctly (there's an error now if you try to load a model by name, the way the script used to). I'll clean up the fixes and send them as a separate PR, but for now:
https://github.com/stanford-crfm/levanter/compare/main...rjpower:levanter:multi-lora?expand=1