Skip to content

Commit

Permalink
Raise error if group_size is passed but wrong strategy (#149)
Browse files Browse the repository at this point in the history
* add error

* Update src/compressed_tensors/quantization/quant_args.py

Co-authored-by: Dipika Sikka <[email protected]>

* remove extra line

* fix invalid configs it tests

---------

Co-authored-by: Dipika Sikka <[email protected]>
  • Loading branch information
kylesayrs and dsikka authored Sep 9, 2024
1 parent b885229 commit 4d89141
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 3 deletions.
6 changes: 6 additions & 0 deletions src/compressed_tensors/quantization/quant_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,12 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
f"strategy {strategy} requires group_size to be "
"set to a positive value"
)
if (
group_size is not None
and group_size > 0
and strategy != QuantizationStrategy.GROUP
):
raise ValueError("group_size requires strategy to be set to 'group'")

# validate activation ordering and strategy
if actorder is not None and strategy != QuantizationStrategy.GROUP:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_compressors/test_fp8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor:
],
[
QuantizationStrategy.CHANNEL,
128,
None,
torch.rand((512, 1)) * 0.01,
torch.zeros((512, 1), dtype=torch.int8),
],
Expand Down
4 changes: 2 additions & 2 deletions tests/test_compressors/test_int_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def get_dummy_quant_config(strategy, group_size=None):
[
QuantizationStrategy.CHANNEL,
False,
128,
None,
torch.rand((512, 1)) * 0.01,
((torch.rand((512, 1)) - 0.5) * 127).to(torch.int8),
],
Expand Down Expand Up @@ -102,7 +102,7 @@ def test_quant_format(strategy, symmetric, group_size, sc, zp):
],
[
QuantizationStrategy.CHANNEL,
128,
None,
torch.rand((300, 1)) * 0.01,
torch.zeros((300, 1), dtype=torch.int8),
],
Expand Down
10 changes: 10 additions & 0 deletions tests/test_quantization/test_quant_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ def test_group():
with pytest.raises(ValueError):
QuantizationArgs(strategy=QuantizationStrategy.GROUP, group_size=-1)

args = QuantizationArgs(group_size=128, strategy="group")
assert args.group_size == 128
assert args.strategy == "group"

with pytest.raises(ValueError):
QuantizationArgs(strategy=QuantizationStrategy.GROUP)

with pytest.raises(ValueError):
QuantizationArgs(strategy="tensor", group_size=128)


def test_block():
kwargs = {"strategy": "block", "block_structure": "2x4"}
Expand Down

0 comments on commit 4d89141

Please sign in to comment.