Improve docs and tests
matthewdouglas committed Nov 18, 2024
1 parent f61d8bc commit eed9c3c
Showing 5 changed files with 88 additions and 29 deletions.
4 changes: 2 additions & 2 deletions bitsandbytes/autograd/_functions.py
@@ -320,7 +320,7 @@ def forward(
# 1. Quantize A. Note that as a side-effect, outliers are suppressed in CA/CAt.
if ctx.needs_input_grad[1]:
# Slower path
CA, CAt, SCA, SCAt, outlier_cols = F.double_quant(A.to(torch.float16), threshold=state.threshold)
CA, CAt, SCA, SCAt, outlier_cols = F.int8_double_quant(A.to(torch.float16), threshold=state.threshold)
else:
# Fast path
CA, SCA, outlier_cols = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold)
@@ -422,7 +422,7 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

if req_gradB:
Cgrad, _, _, SCgradt, _ = F.double_quant(grad_output.to(torch.float16))
Cgrad, _, _, SCgradt, _ = F.int8_double_quant(grad_output.to(torch.float16))

gradB32 = F.int8_linear_matmul(Cgrad.t().contiguous(), CAt.t())
grad_B = F.int8_mm_dequant(gradB32, SCgradt, SCAt)
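For context, the rename distinguishes the two int8 entry points this forward pass chooses between. A minimal sketch of both, assuming a CUDA build of bitsandbytes that already ships the renamed `int8_*` API (shapes are illustrative):

```python
# Sketch: the two quantization paths selected above (assumes bitsandbytes with
# the renamed int8_* API and a CUDA device).
import torch
import bitsandbytes.functional as F

A = torch.randn(16, 32, dtype=torch.float16, device="cuda")

# Slower path (weight gradients needed): quantizes row-wise and column-wise
# (transposed), returning scales for both layouts.
CA, CAt, SCA, SCAt, outlier_cols = F.int8_double_quant(A, threshold=6.0)

# Fast path (inference): row-wise quantization only.
CA_fast, SCA_fast, outlier_cols_fast = F.int8_vectorwise_quant(A, threshold=6.0)
```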
72 changes: 68 additions & 4 deletions bitsandbytes/functional.py
@@ -442,7 +442,7 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
An input tensor may also be marked as `paged`, in which case the device placement is ignored.
Args:
tensors (Iterable[Optional[torch.Tensor]]): A list of tensors to verify.
tensors (`Iterable[Optional[torch.Tensor]]`): A list of tensors to verify.
Raises:
`RuntimeError`: Raised when the verification fails.
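A hypothetical call matching the docstring above; the helper verifies that all non-`None` tensors share a GPU device and raises `RuntimeError` otherwise:

```python
# Sketch of is_on_gpu usage (assumes a CUDA device is available).
import torch
import bitsandbytes.functional as F

a = torch.ones(2, device="cuda")
b = torch.ones(2, device="cuda")
F.is_on_gpu([a, b, None])  # None entries are permitted per the signature
```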
@@ -2572,13 +2572,80 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)


@deprecated("This function is deprecated. Please use `int8_double_quant` instead.", category=FutureWarning)
def double_quant(
A: torch.Tensor,
col_stats: Optional[torch.Tensor] = None,
row_stats: Optional[torch.Tensor] = None,
out_col: Optional[torch.Tensor] = None,
out_row: Optional[torch.Tensor] = None,
threshold=0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[COOSparseTensor]]:
"""Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
The statistics are determined both row-wise and column-wise (transposed).
For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
<Tip warning={true}>
This function exists for backwards compatibility only. It is advised to use [`int8_double_quant`] instead.
The difference is that this function will return a [`COOSparseTensor`] for the outliers instead of their column indices.
</Tip>
Args:
A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
col_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantization scales.
row_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantization scales.
out_col (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantized data.
out_row (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantized data.
threshold (`float`, *optional*):
An optional threshold for sparse decomposition of outlier features.
When 0.0, no outliers are held back. Defaults to 0.0.
Returns:
`Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing the quantized tensor and relevant statistics.
- `torch.Tensor` with dtype `torch.int8`: The row-wise quantized data.
- `torch.Tensor` with dtype `torch.int8`: The column-wise quantized data.
- `torch.Tensor` with dtype `torch.float32`: The row-wise quantization scales.
- `torch.Tensor` with dtype `torch.float32`: The column-wise quantization scales.
- `COOSparseTensor`, *optional*: A structure representing the outlier values from the input tensor.
"""

coo_tensor = None
quant_row, quant_col, row_stats, col_stats, _ = int8_double_quant(
A,
col_stats,
row_stats,
out_col,
out_row,
threshold=threshold,
)

if threshold > 0.0:
# Build COO tensor for any outliers.
outlier_mask = A.abs() >= threshold
outlier_locations = outlier_mask.nonzero()
outliers = A[outlier_mask]
coo_tensor = COOSparseTensor(
A.shape[0],
A.shape[1],
outliers.numel(),
outlier_locations[:, 0].int(),
outlier_locations[:, 1].int(),
outliers,
)

return quant_row, quant_col, row_stats, col_stats.flatten().float(), coo_tensor
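A short usage sketch of the wrapper above, with hypothetical shapes: with a nonzero threshold, the fifth return value is a `COOSparseTensor`, whose dense form can be rebuilt from its fields (as the updated test below also does):

```python
# Sketch: calling the deprecated wrapper (emits a FutureWarning) and
# materializing the returned COO outliers. Assumes a CUDA device.
import torch
import bitsandbytes.functional as F

A = torch.randn(4, 8, dtype=torch.float16, device="cuda")
CA, CAt, SCA, SCAt, coo = F.double_quant(A, threshold=2.0)

if coo is not None:
    dense_outliers = torch.zeros_like(A)
    dense_outliers[coo.rowidx.long(), coo.colidx.long()] = coo.values
```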


def int8_double_quant(
A: torch.Tensor,
col_stats: Optional[torch.Tensor] = None,
row_stats: Optional[torch.Tensor] = None,
out_col: Optional[torch.Tensor] = None,
out_row: Optional[torch.Tensor] = None,
threshold=0.0,
):
"""Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
@@ -2612,7 +2679,6 @@ def double_quant(
"""

# TODO: Optimize/write CUDA kernel for this?
# Note: for inference, use the new int8_vectorwise_quant.

# Use CUDA kernel for rowwise and COO tensor
quant_row, row_stats, outlier_cols = int8_vectorwise_quant(A, threshold=threshold)
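Conceptually, the "double" quantization amounts to scaling the same fp16 matrix by per-row and per-column absmax values; a plain-PyTorch sketch of that scheme (the library's CUDA kernels are the real implementation):

```python
# Sketch: symmetric absmax quantization done twice, row-wise and column-wise.
import torch

A = torch.randn(4, 8, dtype=torch.float16)

row_stats = A.float().abs().amax(dim=1, keepdim=True)  # one scale per row
col_stats = A.float().abs().amax(dim=0, keepdim=True)  # one scale per column

quant_row = torch.round(A.float() * 127.0 / row_stats).to(torch.int8)
quant_col = torch.round(A.float() * 127.0 / col_stats).to(torch.int8)
```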
@@ -2665,8 +2731,6 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
# TODO we could improve perf of this
outliers = A.abs() >= threshold

# argwhere needs host/device sync, so we skip when
# there aren't actually any outliers.
if outliers.any():
outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)

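A rough round-trip check for `int8_vectorwise_quant`, mirroring the dequantization used in the updated test further below (outlier columns are held back from the quantized data when a threshold is set):

```python
# Sketch: quantize row-wise, then dequantize with the returned per-row scales.
# Mirrors the updated test_coo_int8_vectorwise_quant check below.
import torch
import bitsandbytes.functional as F

A = torch.randn(16, 32, dtype=torch.float16, device="cuda")
CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=3.0)

A_dequant = (CA.float() * statsA.unsqueeze(1) / 127.0).half()
# A_dequant approximates A everywhere except the held-back outlier columns.
```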
6 changes: 3 additions & 3 deletions bitsandbytes/research/autograd/_functions.py
@@ -215,7 +215,7 @@ def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = Non
# 1. Quantize A
if len(A.shape) == 3:
A = A.view(-1, A.shape[-1]).contiguous()
CA, CAt, SCA, SCAt, outlier_cols = F.double_quant(A.to(torch.float16), threshold=state.threshold)
CA, CAt, SCA, SCAt, outlier_cols = F.int8_double_quant(A.to(torch.float16), threshold=state.threshold)

if state.threshold > 0.0 and outlier_cols is not None:
if state.has_fp16_weights:
@@ -248,7 +248,7 @@ def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = Non
state.SCB,
state.SCBt,
_,
) = F.double_quant(B.to(torch.float16))
) = F.int8_double_quant(B.to(torch.float16))
state.SB = (state.CB.shape, "row")
else:
has_grad = False
@@ -320,7 +320,7 @@ def backward(ctx, grad_output):
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

Cgrad, Cgradt, SCgrad, SCgradt, outlier_cols = F.double_quant(grad_output.to(torch.float16))
Cgrad, Cgradt, SCgrad, SCgradt, outlier_cols = F.int8_double_quant(grad_output.to(torch.float16))

if req_gradB:
# print('back A shape', A.shape)
33 changes: 15 additions & 18 deletions tests/test_functional.py
@@ -606,8 +606,8 @@ def test_int8_linear_matmul_half(dim1, dim2, dim3, dim4, dims):

A = A.view(-1, A.shape[-1])

CA, _, statsA, _, _ = F.double_quant(A)
CB, _, statsB, _, _ = F.int8_vectorwise_quant(B)
CA, _, statsA, _, _ = F.int8_double_quant(A)
CB, statsB, _ = F.int8_vectorwise_quant(B)
output = F.int8_mm_dequant(F.int8_linear_matmul(CA, CB), statsA, statsB)

torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
@@ -863,7 +863,7 @@ def test_double_quant(dim1, dim2):
out_col1, Scol = F.vectorwise_quant(A, dim=0)
out_row1, Srow = F.vectorwise_quant(A, dim=1)

CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
CA, CAt, statsA, statsAt, coo_tensor = F.int8_double_quant(A)

# max difference is 1 due to rounding differences
torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
@@ -953,7 +953,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner):

out1 = torch.matmul(A.half(), B.t().half())

C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
C1a, C1b, stats1a, stats1b, coo_tensor = F.int8_double_quant(A)
CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
@@ -1032,7 +1032,7 @@ def test_row_scale_bench(dim1, dim4, inner):
torch.cuda.synchronize()
print("16", time.time() - t0)

C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
C1a, C1b, stats1a, stats1b, coo_tensor = F.int8_double_quant(A)
CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
@@ -1047,7 +1047,7 @@ def test_row_scale_bench(dim1, dim4, inner):
torch.cuda.synchronize()
print("row-wise", time.time() - t0)

C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
C2a, C2b, stats2a, stats2b, coo_tensor = F.int8_double_quant(B)
B2, SB = F.nvidia_transform(C2a, formatB)
torch.cuda.synchronize()
t0 = time.time()
@@ -1115,7 +1115,8 @@ def test_coo_double_quant(dim1, dim2):

if coo_tensor is not None:
A1 = A * idx
A2 = coo_tensor.to_dense()
A2 = torch.zeros_like(A)
A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
torch.testing.assert_close(A1, A2)

A1 = A * (idx == 0)
@@ -1133,14 +1134,9 @@ def test_coo_int8_vectorwise_quant(dim1, dim2):
A = torch.randn(dim1, dim2, device="cuda").half()

idx = torch.abs(A) >= threshold
CA, statsA, coo_tensor = F.int8_vectorwise_quant(A, threshold=threshold)
CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)

if coo_tensor is not None:
A1 = A * idx
A2 = coo_tensor.to_dense()
torch.testing.assert_close(A1, A2)

A1 = A * (idx == 0)
if outlier_cols is not None:
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)

@@ -1230,13 +1226,14 @@ def test_integrated_sparse_decomp(dim1, dim2):
w1 = torch.randn(dim1, dim2).cuda().half()
out1 = torch.matmul(A, w1.t())

Cw1, statsw1, coo_tensor = F.int8_vectorwise_quant(w1)
CA, statsA, coo_tensor = F.int8_vectorwise_quant(A)
Cw1, statsw1, _ = F.int8_vectorwise_quant(w1)
CA, statsA, _ = F.int8_vectorwise_quant(A)

out1_32 = F.int8_linear_matmul(CA, Cw1)
out2 = F.int8_mm_dequant(out1_32, statsA, statsw1)

CA, statsA, coo_tensor = F.int8_vectorwise_quant(A, threshold=threshold)
# CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold)

out1_32 = F.int8_linear_matmul(CA, Cw1)
out3 = F.int8_mm_dequant(out1_32, statsA, statsw1)
@@ -1377,7 +1374,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
torch.nn.init.xavier_uniform_(B)
Bt = B.t().contiguous()

CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
CB, CBt, statsB, statsBt, coo_tensor = F.int8_double_quant(B)

rowidx = torch.randint(0, A.shape[-1], size=(15,))

2 changes: 0 additions & 2 deletions tests/test_modules.py
@@ -356,15 +356,13 @@ def test_linear8bitlt_accumulated_gradient():


@pytest.mark.parametrize("threshold", [0.0, 2.0])
@pytest.mark.parametrize("memory_efficient_backward", [False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
l1 = (
bnb.nn.Linear8bitLt(
32,
64,
threshold=threshold,
has_fp16_weights=False,
memory_efficient_backward=memory_efficient_backward,
)
.cuda()
.half()
