
Sharding error when trying to serialize PEFT models #635

Open
rjpower opened this issue Jun 15, 2024 · 19 comments

@rjpower
Collaborator

rjpower commented Jun 15, 2024

#622 fixes issue #609 for loading HF models into a sharded representation, but now when I try to serialize a PEFT model I get a similar error to the one before:

ValueError: Received incompatible devices for jitted computation. Got argument x of _identity_fn with shape float32[32,8,4096] and device ids [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15] on platform TPU and explicit output sharding with device ids [0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15] on platform TPU

I tried the simple thing of removing the sharding annotation entirely and just running jax.array(input), but that yields the (I guess expected) error, since I'm assuming the original sharding is spread across multiple machines:

RuntimeError: Fetching value for `jax.Array` that spans non-addressable devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` for this use case.

I haven't yet reproduced as a test, but you can reproduce using the gsm8k example:

python infra/launch.py --foreground --tpu_name=tpu-0 \
  --zone=us-west4-a \
  -- \
  python examples/gsm8k-lora/gsm8k_lora.py \
  --config=examples/gsm8k-lora/gsm8k-llama2.yaml \
  --trainer.checkpointer.base_path=gs://wasabi-tpu-training/llama3-gsm8k/attempt-0 \
  --hf_save_path=gs://wasabi-tpu-training/llama3-gsm8k/attempt-0 \
  --data_cache_dir=gs://wasabi-tpu-training/gsm8k/data \
  --trainer.num_train_steps=10 \
  --data_seed=0

The script needs one patch to load models correctly (there's now an error if you try to load a model by name the way the script used to). I'll clean up the fixes and send them as a separate PR, but for now:

https://github.com/stanford-crfm/levanter/compare/main...rjpower:levanter:multi-lora?expand=1

@rjpower
Collaborator Author

rjpower commented Jun 15, 2024

To further demonstrate my JAX ignorance, I tried adding a fully replicated sharding explicitly (since in this case everything should fit) with:

                process_mesh = Mesh(
                    np.array(jax.devices()).reshape((jax.process_count(), -1)),
                    ("process", "device"),
                )
                shardings = [None for i in range(len(arr.shape))]
                sharding = NamedSharding(process_mesh, PartitionSpec(*shardings))

                def _shard_fn(input_array):
                    return jax.device_put(input_array, sharding)

                out = jax.jit(_shard_fn)(arr)

but this doesn't seem to have any effect (I get the "non-addressable" error, or the "incompatible devices" error if I also specify output_shardings=sharding).

@rjpower
Collaborator Author

rjpower commented Jun 15, 2024

And, oddly enough, everything starts to work when I remove the jax.jit for the shard function and just do:

shardings = [None for i in range(len(arr.shape))]
sharding = NamedSharding(process_mesh, PartitionSpec(*shardings))
input_array = jax.device_put(arr, sharding)
return np.array(input_array)

Are there any downsides to doing this for this type of copy operation?

@dlwh
Member

dlwh commented Jun 16, 2024

So, at a small scale, that first error is crazy because the devices are literally the same, just in a different sort order. This seems like a bug? Could it be (another) issue with best_effort_sharding?

Python 3.10.9 | packaged by conda-forge | (main, Feb  2 2023, 20:26:08) [Clang 14.0.6 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>>
>>> a = [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
>>> b = [0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15]
>>> sorted(a) == sorted(b)
True
>>>
I'm still wrestling with the right thing to do here, and I should be more systematic.

The constraints, as I understand them, are:

  1. only device_put can move things between device types (cpu->tpu)
  2. only jit can coordinate cross-host data movement across devices (I think), so the arrays must be fully addressable to do this without jit.
  3. device_put cannot be used inside jit. device_put cannot be differentiated.
  4. with_sharding_constraint can be used in either jit or outside jit
  5. "device_put transposes while with_sharding_constraint doesn't" is what the JAX people have told me, which is apparently just because device_put can take a src.

I don't think my current solution is the right one, but I'm not entirely sure what it should be. I think it's:

  1. cross-device, device_put
  2. un'jitted wsc if the array has a sharding and is fully addressable
  3. jitted wsc otherwise

which can probably simplify to

  1. cross-device, device_put
  2. jitted wsc otherwise, but I need to test it out (rough sketch below)
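
Untested sketch of that simplified rule (shard_array is just an illustrative name, and target is assumed to be a NamedSharding on the compute mesh):

import jax
from jax.lax import with_sharding_constraint

def shard_array(arr, target):
    # Illustrative sketch only. If the array isn't on the target devices at all
    # (e.g. it's a numpy/CPU array), only device_put can do the move.
    src = getattr(arr, "sharding", None)
    if src is None or src.device_set != target.device_set:
        return jax.device_put(arr, target)
    # Same devices: a jitted with_sharding_constraint handles the reshuffle.
    return jax.jit(lambda x: with_sharding_constraint(x, target))(arr)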

@rjpower
Collaborator Author

rjpower commented Jun 16, 2024

Yeah, I'm puzzled by the behavior. I think your analysis is correct: you'd expect to need a two-stage movement, but why doesn't your current version work for this simple case? Even when I switched to the fully replicated sharding, JAX still reported the device issue. Isn't reshuffling exactly what .with_sharding_constraint is supposed to do?

https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#constraining-shardings-of-intermediates-in-jitted-code
"""
Using jax.lax.with_sharding_constraint is much like jax.device_put, except we use it inside staged-out (i.e. jit-decorated) functions:
"""

IIUC, your original code should have worked fine: there were no input shardings, and out_shardings should be the same as calling jax.lax.with_sharding_constraint, which should force the sharding the way we want. The device_put should be identical (though it might not work if we needed cross-host movement in this case), but for some reason it works...

@dlwh
Member

dlwh commented Jun 16, 2024

I synced with the Jax folks, and they said that jit requires the devices to be in the same order for input and output. They’re working on relaxing that but there are weird performance regressions.

The right thing is probably to detect if the device sequence/mesh (not the device set) is changing and use device_put if so?
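
Something like this, maybe (untested sketch; it assumes both the array's sharding and the target are NamedShardings, so both carry a mesh):

import jax

def move_to(arr, target):
    # Compare the device order of the two meshes: jit with out_shardings is
    # fine when the order matches; otherwise fall back to device_put, which
    # tolerates a different device order.
    src_ids = [d.id for d in arr.sharding.mesh.devices.flat]
    dst_ids = [d.id for d in target.mesh.devices.flat]
    if src_ids == dst_ids:
        return jax.jit(lambda x: x, out_shardings=target)(arr)
    return jax.device_put(arr, target)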

I feel like this is much more complicated than it needs to be.

@rjpower
Collaborator Author

rjpower commented Jun 17, 2024

Interesting... I guess you're expected to effectively re-use the same device mesh all the time, and that's why this isn't hitting people more often? This just seems confusingly hard for what I'd think would be such a common desire: "get the data to/from the CPU".

If you're forced to use the original mesh, it seems like you either have to give up and fully replicate the array, or somehow carefully choose a sharding that doesn't partition you across hosts, and then run device_put.

Are we holding it wrong? (I guess fully-replicating isn't too bad if you're doing it one array at a time?)

@dlwh
Member

dlwh commented Jun 17, 2024

Yeah, so the best-effort sharding logic is relatively new. I added it, along with the crazy CPU-vs-accelerator logic, to handle loading 34B-parameter models on a v4-8 (in response to #508), and also to be able to load smaller models on our internal GPU cluster, which has less CPU memory, so we needed a solution there.

At the time I didn't realize that you were effectively stuck with one mesh; it just happened that the mesh we were using then was the same one I created for best-effort sharding, so we got lucky.

I think the things to do are:

  • skip best effort sharding if context mesh is not set (easy and I'm guessing fixes this particular issue)
  • probably ensure that we have the context mesh set inside this script at load time, because I added best effort sharding specifically for the lora use case
  • check inside named_jit that shardings are consistent with the context mesh (harder, but we can probably back it out from JAX?). For now, error; one could stretch to say reshard beforehand
  • improve the logic in hax.shard and test it a bit more religiously

WDYT?

@rjpower
Collaborator Author

rjpower commented Jun 17, 2024

Ah interesting, for the model loading side, everything makes sense: you're building the whole pytree, and if you don't shard, lots of models won't end up fitting. I don't quite follow how the CPU usage is reduced, since IIUC we're always loading the model replicated on the CPU and then sharding (but I didn't read the model loading code closely...). Oh, maybe because you have the implicit mesh you can avoid making a copy of the tree -> state_dict first... I should really just read the code and PR 😛.

skip best effort sharding if context mesh is not set (easy and I'm guessing fixes this particular issue)

For export we implicitly have the mesh from the input array already, so we can re-use that. I'm probably confused, but I think generating an appropriate (non-replicated) sharding for an arbitrary mesh for the export side seems hard. We'd need to choose an arrangement that keeps physical devices within a single host so that we don't get the "non-addressable devices" issue when we convert to CPU, right?

(It doesn't seem like there's anything wrong with the idea of the best-effort sharding for loading, and we could even avoid looking for the implicit mesh, if just moving between meshes worked at all...)

I thought something dumb like this would work to make the array fully replicated and then copyable, but I still hit the "fetching an array that spans non-addressable devices" error. Though again, just switching to device_put works fine. I'm assuming this is just because I'm getting lucky and there's no model parallelism here. Maybe you just need both: first the sharding constraint to convert to replicated, and then the device_put to... I'm not sure what it's doing at this point, TBH.

                shardings = [None for i in range(len(arr.shape))]
                sharding = NamedSharding(arr.sharding.mesh, PartitionSpec(*shardings))

                def _copy(in_array):
                    return jax.lax.with_sharding_constraint(in_array, sharding)

                arr = jax.jit(_copy, donate_argnums=0)(arr)
                # but jax.device_put(arr, sharding) works fine!
                return np.array(arr)

probably ensure that we have the context mesh set inside this script at load time, because I added best effort sharding specifically for the lora use case

(I think) we're now getting the mesh correctly for loading, since the model loads quickly and without errors. It's only at save time that we run into this issue (there's a sort of copy of the best-effort sharding there). For this PEFT save logic, do we need the best-effort sharding at all? It seems like we're iterating over layers one at a time and copying them to the CPU, so we only need enough device memory for that single layer (again, I could be missing something).

check inside named_jit that shardings are consistent with the context mesh (harder but probably we can back it out from JAX?). For now error, but one could stretch to say reshard before

Yeah I feel some well-compartmentalized functions would help a lot: "make this CPU array appear on the devices with this sharding", "make this device array appear on the CPU replicated", both handling the resharding as necessary... it seems like they should be part of JAX TBH... (Some of that could be extending hax.shard it seems like, but for some of the state_dict manipulation, you're outside of the PyTree context so maybe harder to use then?)

Having named_jit do some magic to reshard seems okay, but probably unnecessary if it's relatively easy to coerce things ahead of time. Automatic resharding always worries me a bit that you'll accidentally keep doing a bunch of data movement on every step without realizing it (I know it's unlikely but e.g. "oops my output sharding for my weights is different from my input sharding").

@yashk2810

I haven't read the entire thread, but regarding this: "This just seems confusingly hard for what I'd think would be such a common desire: get the data to/from the CPU."

maybe just try putting your array in pinned_host memory? jax.device_put(x, NamedSharding(mesh, pspec, memory_kind='pinned_host'))

@yashk2810

This should keep the sharding the same as the TPU one without having to mess around with devices, i.e. the mesh stays the same with TPU devices! You only change the memory kind of the sharding to point to host.

@yashk2810

Note you need to enable this config: jax.config.update('jax_enable_memories', True)
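
Rough sketch of what that looks like end to end (the mesh shape and array here are just placeholders, not your actual serialization code):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.config.update('jax_enable_memories', True)

mesh = Mesh(np.array(jax.devices()), ('dp',))
arr = jax.device_put(jnp.zeros((1024, 1024)), NamedSharding(mesh, P('dp')))

# Same mesh and partitioning, only the memory kind changes: the data moves to
# pinned host memory without touching the device order.
host_sharding = NamedSharding(mesh, P('dp'), memory_kind='pinned_host')
arr_on_host = jax.device_put(arr, host_sharding)
print(arr_on_host.sharding.memory_kind)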

@dlwh
Member

dlwh commented Jun 27, 2024

After talking to @yashk2810, it seems like device_put is more capable than I thought (in particular, it can do cross-host transfers):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import jax.experimental.mesh_utils as mesh_utils

D = 4096

mesh_devices = mesh_utils.create_device_mesh((len(jax.devices()),))

smart_mesh = Mesh(mesh_devices, ('dp',))
# mesh = Mesh(jax.devices(), ('dp'))
host_mesh = Mesh(np.array(jax.devices()).reshape(jax.process_count(), -1), ('host', 'device'))

z = jnp.full((D, D), jax.process_index())

print(z.sharding)

smart_sharding = NamedSharding(smart_mesh, P('dp'))
host_sharding = NamedSharding(host_mesh, P('device'))

z2 = jax.jit(lambda: jnp.zeros((D, D)), out_shardings=smart_sharding)()

print(z2.sharding)

z3 = jax.jit(lambda: jnp.zeros((D, D)), out_shardings=host_sharding)()
print(z3.sharding)

# this is no good:
# Traceback (most recent call last):
#   File "/home/dlwh/test_device_put.py", line 30, in <module>
#     z4 = jax.jit(lambda x: x, out_shardings=host_sharding)(z2)
# ValueError: Received incompatible devices for jitted computation. Got argument x of <lambda> with shape float32[4096,4096] and device ids

# z4 = jax.jit(lambda x: x, out_shardings=host_sharding)(z2)
# print(z4.sharding)

z5 = jax.device_put(z2, host_sharding)

print(z5.sharding)

@dlwh
Member

dlwh commented Jun 27, 2024

OK, @yashk2810 figured this out for me and I'm patching Haliax with the fix: stanford-crfm/haliax#96 (Yash, avert your eyes...)

@dlwh
Member

dlwh commented Jul 3, 2024

@rjpower if you get a chance, could you check if things work with the latest jax nightly?

@dlwh
Member

dlwh commented Jul 3, 2024

(@yashk2810 fixed it i think)

@rjpower
Collaborator Author

rjpower commented Jul 7, 2024

Hrm, I changed the Haliax dependency to "haliax @ git+https://github.com/stanford-crfm/haliax.git@main" and ran:

BASE_DIR=gs://wasabi-tpu-training/gsm8k/test/llama2-0 python infra/launch.py --foreground --tpu_name=tpu-0 -- python examples/gsm8k-lora/gsm8k_lora.py --config=examples/gsm8k-lora/gsm8k-llama2.yaml --hf_save_path=$BASE_DIR/hf --data_cache_dir=gs://wasabi-tpu-training/gsm8k/data --data_seed=0 --trainer.num_train_steps=10

I still see this error:

  File "/opt/levanter/src/levanter/compat/torch_serialization.py", line 449, in <lambda>
    model = jax.tree_util.tree_map(lambda arr: get_to_cpu(arr), model)
  File "/opt/levanter/src/levanter/compat/torch_serialization.py", line 445, in get_to_cpu
    out = jax.jit(_identity_fn, out_shardings=sharding)(arr)
ValueError: Received incompatible devices for jitted computation. Got argument x of _identity_fn with shape float32[32,8,4096] and device ids [0, 4, 8, 12, 16, 20, 24, 28, 1, 5, 9, 13, 17, 21, 25, 29, 2, 6, 10, 14, 18, 22, 26, 30, 3, 7, 11, 15, 19, 23, 27, 31] on platform TPU and explicit output sharding with device ids [0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15, 16, 17, 20, 21, 18, 19, 22, 23, 24, 25, 28, 29, 26, 27, 30, 31] on platform TPU

Guessing without looking: do we also need to adjust torch_serialization.py in Levanter?

.config:

env:
    XLA_FLAGS: "--xla_dump_to=/tmp/output_folder/xla_dumps --xla_dump_hlo_pass_re=.*"
    LIBTPU_INIT_ARGS: --xla_tpu_impure_oom_fast_exit_threshold=-1

docker_repository: levanter
zone: us-west4-a
tpu_type: v5litepod-32
vm_image: "tpu-ubuntu2204-base"
capacity_type: preemptible
autodelete: false
subnetwork: "default"

@yashk2810

You need to use jax.device_put, not jax.jit

@rjpower
Collaborator Author

rjpower commented Jul 7, 2024

Thanks Yash, with this change to torch_serialization.py, things work for me:

-                out = jax.jit(_identity_fn, out_shardings=sharding)(arr)
+                out = jax.device_put(arr, sharding)

@dlwh I'll send the one-line patch CL; let me know if I'm missing something obvious though!
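
For reference, the standalone pattern this boils down to (a minimal sketch assuming arr has a NamedSharding; not the exact torch_serialization.py code):

import jax
import numpy as np
from jax.sharding import NamedSharding, PartitionSpec

def to_numpy_on_host(arr):
    # Fully replicated spec on the array's existing mesh; device_put (unlike a
    # jitted identity with out_shardings) is allowed to gather across hosts and
    # reorder devices, so np.array() can then read the data locally.
    replicated = NamedSharding(arr.sharding.mesh, PartitionSpec())
    return np.array(jax.device_put(arr, replicated))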

@dlwh
Member

dlwh commented Jul 7, 2024

lgtm!
