Llama/unshard on load #174

Merged (1 commit, Dec 25, 2023)

Conversation

dastrobu (Contributor)

Similar to #92, I noticed that converting the llama-2 70b models takes quite a bit of RAM (it succeeded at around 140 GB on a 128 GB machine with swapping).

However, the resulting huge weight files are still very hard to handle (e.g. uploading them to HF is not possible without extra steps).

So I suggest changing the conversion algorithm a bit: keep the shards during model conversion and unshard the weights on loading. This would be more RAM-efficient and file-size friendly.
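
A rough sketch of what the conversion side of this idea could look like; the consolidated.*.pth glob, the output file names, and the float16 cast are illustrative assumptions, not the PR's actual code:

    from pathlib import Path

    import numpy as np
    import torch


    def convert_shards(model_dir: Path, out_dir: Path):
        # Convert one torch checkpoint shard at a time instead of merging all
        # shards in memory first; peak RAM stays around the size of one shard.
        for i, torch_file in enumerate(sorted(model_dir.glob("consolidated.*.pth"))):
            state = torch.load(torch_file, map_location="cpu")
            np.savez(
                out_dir / f"weights.{i:02d}.npz",
                **{k: v.to(torch.float16).numpy() for k, v in state.items()},
            )
            del state  # free the shard before loading the next one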

This PR implements that approach.

I tested locally with tiny llama, llama-2-13b-chat and llama-2-70b-chat. The largest 70b model now takes around 16 GB on average while converting, with peaks around 32 GB. For inference it still requires around 128 GB, which makes sense given that the weights are around 128 GB on disk. With a bit of swapping one can run it on a 128 GB machine, though not really productively, on an Apple M3 Max with 128 GB:

[INFO] Prompt processing: 206.761 s
[INFO] Full generation: 36003.395 s

@awni (Member) commented Dec 22, 2023

I love this change!! Could you rebase on main and then I will review?

@dastrobu (Contributor, Author)

> Could you rebase on main and then I will review?

Sure. After taking a brief look at the changes, I guess the new quantization makes this more complicated: it expects the entire model to be loaded.

I could think of moving the quantization into its own script, so you would run:

  1. convert.py to convert the model
  2. quantize.py, if quantization is wanted (would operate on the converted model)
  3. llama.py for inference

If we really want to quantize in convert.py, it would require the unsharding there, I guess. I don't know enough about the new quantize_module to know whether we could run it on the shards somehow without merging them first.

Given that quantization should enable smaller machines, it would be nice if we could do the quantization without merging all the unquantized weights in memory first. I'm not sure, though, whether there is a clever way to achieve this.

dastrobu force-pushed the llama/unshard-on-load branch 2 times, most recently from 41c8efe to dff87bc on December 22, 2023, 21:08
@dastrobu (Contributor, Author)

@awni I think I found a good way to refactor it to support quantization in convert.py. It will still unshard for quantization, but keeps shard loading and conversion lazy and memory-friendly. Looking forward to your review.

@awni (Member) commented Dec 23, 2023

@dastrobu I like where this is going, but I suggest we reorganize the computation to avoid the need to unshard in the final loading script. Here's my suggestion:

  1. Do the unsharding as before to get the full weights
  2. Quantize
  3. Split the weights (e.g. different arrays) into smaller files, but don't split the arrays themselves (a sketch follows after this comment).
  4. Change llama.py to load from a few saved weight files (like you have done); the loading then just needs to read a few files and does not need to deal with concatenating etc., which will be faster, simpler, and easier to maintain.

Does that make sense?
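
A minimal sketch of step 3, splitting the weights dict across several files without splitting any array; the file naming and the 15 GiB cutoff are assumptions, and plain numpy arrays are assumed here (quantized weights are mx arrays, see the size-estimate discussion further down):

    from pathlib import Path

    import numpy as np


    def save_weight_files(weights: dict, out_dir: Path, max_bytes: int = 15 * 2**30):
        # Greedily pack whole arrays into files until the size limit is reached;
        # no array is ever split across files.
        shard, shard_size, idx = {}, 0, 0
        for name, w in weights.items():
            if shard and shard_size + w.nbytes > max_bytes:
                np.savez(out_dir / f"weights.{idx:02d}.npz", **shard)
                shard, shard_size, idx = {}, 0, idx + 1
            shard[name] = w
            shard_size += w.nbytes
        if shard:
            np.savez(out_dir / f"weights.{idx:02d}.npz", **shard)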

@awni (Member) commented Dec 23, 2023

So your changes to llama.py will look more like how we load mixtral, except with the option to load from a single file if there is only one, as you have it now.
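
Mixtral-style loading from several weight files could then look roughly like this; the weights*.npz file pattern is an assumption:

    import glob
    from pathlib import Path

    import mlx.core as mx


    def load_weights(model_path: Path) -> dict:
        # Read every saved weight file and merge into a single flat dict; no
        # concatenation is needed because arrays were never split across files.
        weights = {}
        for weight_file in sorted(glob.glob(str(model_path / "weights*.npz"))):
            weights.update(mx.load(weight_file))
        return weights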

dastrobu force-pushed the llama/unshard-on-load branch 2 times, most recently from a7d08be to 7f95a25 on December 24, 2023, 13:14
@dastrobu (Contributor, Author)

> Does that make sense?

@awni yes, it does. Thanks for your review and suggestions.
With quantization in place, I agree that it got a little complicated, and we are probably better off defining our own shards, something I initially tried to avoid, as it adds another algorithm to maintain.
Please take a look at the updated code, especially at make_shards, which is basically the feature of this PR.

@awni (Member) left a comment:

Looks great and much simpler, thanks for adding this!! I left a couple of comments; please address them and then we can merge.

(Two review comments on llms/llama/llama.py were marked outdated and resolved.)
@@ -140,6 +139,21 @@ def quantize(weights, config, args):
    return quantized_weights, quantized_config


def make_shards(weights: dict, max_file_size_GiB: int = 15):
awni (Member):

style nit: max_file_size_gb

dastrobu (Contributor, Author):

Since we are using 2**30 (= GiB), I'd suggest max_file_size_gibibyte, as I find max_file_size_gb wrong and max_file_size_gib unreadable.

    shards = []
    shard, shard_size = {}, 0
    for k, v in weights.items():
        estimated_size = len(v.flatten()) * v.dtype.itemsize
awni (Member):

Did you check this with quantization? I think this line might break as dtype doesn't have an itemsize?

We really ought to expose nbytes in Python for consistency with numpy. For now you can do:

    v.size * v.dtype.size if isinstance(v, mx.array) else v.nbytes

dastrobu (Contributor, Author):

I wasn't aware that quantization stores mx arrays already... Your suggestion seems to be a good intermediate solution. Exposing nbytes sounds even better; I'll create a PR, it sounds like a small change.
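
Putting the review suggestions together, a minimal sketch of make_shards with the size estimate above could look like this; the merged code in the PR may differ in details:

    import mlx.core as mx


    def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
        # Split a flat weights dict into a list of smaller dicts, each holding at
        # most roughly max_file_size_gibibyte of data; arrays are never split.
        max_file_size_bytes = max_file_size_gibibyte << 30
        shards = []
        shard, shard_size = {}, 0
        for k, v in weights.items():
            # quantized weights are already mx.array; unquantized ones are numpy
            estimated_size = v.size * v.dtype.size if isinstance(v, mx.array) else v.nbytes
            if shard_size + estimated_size > max_file_size_bytes:
                shards.append(shard)
                shard, shard_size = {}, 0
            shard[k] = v
            shard_size += estimated_size
        shards.append(shard)
        return shards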

dastrobu force-pushed the llama/unshard-on-load branch from 7f95a25 to a11e3f8 on December 25, 2023, 17:25
dastrobu requested a review from awni on December 25, 2023, 17:26
@dastrobu (Contributor, Author)

> Looks great and much simpler, thanks for adding this!! I left a couple of comments; please address them and then we can merge.

Thanks, should be all fixed now.

@awni (Member) left a comment:

Awesome, thanks!!
