Decouple int4 weight with serialized format #187
base: main
Conversation
Also need to update the readme.
```diff
@@ -404,7 +403,7 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):

     @torch.no_grad()
     def create_quantized_state_dict(self, use_cuda = True):
-        if use_cuda:
+        if use_cuda and torch.cuda.is_available():
```
It won't be necessary to add `torch.cuda.is_available()`; just let it report an error if `use_cuda` is true and the GPU is not available.
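A minimal sketch of the behavior this suggests, reusing the method signature from the hunk above (the class name and body are assumptions for illustration, not the repo's actual code):

```python
import torch

class Int4QuantHandlerSketch:
    """Hypothetical stand-in for the quantization handler edited in the hunk above."""

    @torch.no_grad()
    def create_quantized_state_dict(self, use_cuda=True):
        # No torch.cuda.is_available() check: if use_cuda is True and there is
        # no GPU, moving tensors to "cuda" below raises a clear error instead
        # of silently falling back to CPU.
        device = "cuda" if use_cuda else "cpu"
        cur_state_dict = {}
        # ... per-layer int4 quantization would run on `device` here (elided) ...
        return cur_state_dict
```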
@malfet Hi, we have modified the int4 packed weight logic from gpt-fast and also from torch: pytorch/pytorch#129940. Could you please help review? @yanbing-j, could you also help evaluate how much time is spent on weight prepacking after the model is loaded? Both CPU and GPU numbers will be needed. Also, this won't affect the first-token latency, right?
@mingfeima, the README has been updated. Weight prepacking takes 0.23 s on CPU (total model-loading time is 0.28 s) and 4 ms on GPU (total model-loading time is 1.2 s, mainly in …).
Hi @yanboliang, could you please help merge this PR, now that pytorch/pytorch#129940 has been merged?
Force-pushed from cc1f6cd to acdc197
Hi @yanboliang, could you please help review this PR? Since the API of …
Hi @yanboliang, could you please help review this PR? Thanks!
generate.py (Outdated)
```python
if isinstance(mod, WeightOnlyInt4Linear):
    weight = mod.weight.data
    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight, mod.inner_k_tiles)
    mod.weight = weight_int4pack
```
Can you put the weight conversion into `WeightOnlyInt4QuantHandler.convert_for_runtime` at L236? More concretely, it could be part of `replace_linear_int4`; I think that's the right place. Also, your change doesn't work well with TP, since the conversion happens after L246. Otherwise, this looks good!
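For context, a rough outline of where the reviewer is pointing; the names `WeightOnlyInt4QuantHandler`, `convert_for_runtime`, and `replace_linear_int4` come from the thread, but the bodies below are assumptions, not the repo's actual code:

```python
import torch.nn as nn

def replace_linear_int4(module: nn.Module, groupsize: int, inner_k_tiles: int, padding: bool = True) -> None:
    """Recursively swap nn.Linear children for int4 linear modules (placeholder body)."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            pass  # in the real code, build and assign a WeightOnlyInt4Linear here
        else:
            replace_linear_int4(child, groupsize, inner_k_tiles, padding)

class WeightOnlyInt4QuantHandler:
    """Rough outline of the handler the reviewer refers to; bodies are assumptions."""

    def __init__(self, mod: nn.Module, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True):
        self.mod = mod
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.padding = padding

    def convert_for_runtime(self) -> nn.Module:
        # The reviewer's suggestion: fold the uint8 -> int4pack conversion into
        # this step (or into replace_linear_int4), so it happens before TP.
        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
        return self.mod
```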
@yanboliang Thanks for the comments!
Quantization generates an `[n][k / 2] uint8` weight (serialized format). The int4 packed weight can only be converted after the model has been loaded to a specific device, so moving the conversion into `WeightOnlyInt4QuantHandler.convert_for_runtime` or `replace_linear_int4` is not a suitable approach.
I will refactor to wrap the conversion into a new function, and apply TP after model loading and weight conversion (see the sketch below). Is this okay with you? Thanks!
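A minimal sketch of that refactor, assuming a new helper in generate.py that runs after the model (with serialized uint8 weights) has been loaded onto its target device and before TP is applied; the helper name is illustrative, and the loop mirrors the hunk shown above:

```python
import torch

def convert_int4_weight_for_runtime(model):
    """After the checkpoint is loaded to the target device, convert each
    serialized [n][k // 2] uint8 weight to that device's int4pack layout."""
    for mod in model.modules():
        # Matched by class name to keep this sketch self-contained;
        # generate.py would use isinstance(mod, WeightOnlyInt4Linear).
        if mod.__class__.__name__ == "WeightOnlyInt4Linear":
            mod.weight = torch.ops.aten._convert_weight_to_int4pack(
                mod.weight.data, mod.inner_k_tiles
            )
    return model
```

The ordering concern from the review would then be handled in the caller: load the checkpoint, move the model to its device, call this helper, and only then apply TP, so sharding sees the already-packed weights.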
update int4 weight dim
Add CPU profiling
Force-pushed from acdc197 to fe741f4
This PR decouples the int4 weight from the serialized format, so that an int4 model checkpoint can be shared across different test machines or ISAs without re-generating it on one particular platform.

In int4 WOQ quantization, the weight is saved as `[n][k / 2] uint8` (serialized format). The conversion of the weight into the int4 packed weight is moved to model loading in `generate.py`.

This PR is based on pytorch/pytorch#129940, which updates the input `weight` of `_convert_weight_to_int4pack` to `[n][k / 2] uint8`.
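As an illustration of what the serialized format means in practice, here is a sketch that packs and unpacks two int4 values per byte along the k dimension; the actual nibble order and layout expected by `_convert_weight_to_int4pack` may differ, so treat this as illustrative only:

```python
import torch

def pack_int4_serialized(weight_int4: torch.Tensor) -> torch.Tensor:
    """Pack an [n, k] tensor of int4 values (stored as 0..15) into the
    platform-neutral [n, k // 2] uint8 serialized format."""
    n, k = weight_int4.shape
    assert k % 2 == 0
    w = weight_int4.to(torch.uint8)
    # Two int4 values per byte: even columns in the low nibble, odd in the high.
    packed = (w[:, ::2] & 0x0F) | ((w[:, 1::2] & 0x0F) << 4)
    return packed.contiguous()

def unpack_int4_serialized(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of pack_int4_serialized, recovering the [n, k] int4 values."""
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    return torch.stack([low, high], dim=-1).flatten(start_dim=1)
```

Because this serialized tensor is device- and ISA-agnostic, the checkpoint can be generated once and each machine only runs the `_convert_weight_to_int4pack` step at load time.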