-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
164 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import pytest | ||
|
||
|
||
def is_pytest_mode_compile(): | ||
return pytest.mode == "compile" | ||
|
||
|
||
def is_pytest_mode_lazy(): | ||
return pytest.mode == "lazy" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import os | ||
from typing import Mapping | ||
|
||
import pytest | ||
|
||
|
||
def pytest_addoption(parser): | ||
parser.addoption( | ||
"--mode", | ||
action="store", | ||
default="lazy", | ||
help="{compile|lazy}, default lazy. Choose mode to run tests", | ||
) | ||
|
||
|
||
backup_env = pytest.StashKey[Mapping]() | ||
|
||
|
||
def pytest_configure(config): | ||
pytest.mode = config.getoption("--mode") | ||
assert pytest.mode.lower() in ["lazy", "compile"] | ||
|
||
config.stash[backup_env] = os.environ | ||
|
||
if pytest.mode == "lazy": | ||
os.environ["PT_HPU_LAZY_MODE"] = "1" | ||
elif pytest.mode == "compile": | ||
os.environ["PT_HPU_LAZY_MODE"] = "0" | ||
os.environ["PT_ENABLE_INT64_SUPPORT"] = "1" | ||
|
||
|
||
def pytest_unconfigure(config): | ||
os.environ.clear() | ||
os.environ.update(config.stash[backup_env]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,46 @@ | ||
import pytest | ||
import torch | ||
from auto_round.utils import is_hpu_supported | ||
|
||
from _test_helpers import is_pytest_mode_compile, is_pytest_mode_lazy | ||
|
||
|
||
def run_opt_125m_on_hpu(): | ||
from auto_round import AutoRound | ||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
||
model_name = "facebook/opt-125m" | ||
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) | ||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) | ||
|
||
bits, group_size, sym = 4, 128, False | ||
autoround = AutoRound( | ||
model, | ||
tokenizer, | ||
bits=bits, | ||
group_size=group_size, | ||
sym=sym, | ||
iters=2, | ||
seqlen=2, | ||
) | ||
q_model, qconfig = autoround.quantize() | ||
assert q_model is not None, f"Expected q_model to be not None" | ||
|
||
|
||
@pytest.mark.skipif(not is_hpu_supported(), reason="HPU is not supported") | ||
@pytest.mark.skipif(not is_pytest_mode_lazy(), reason="Only for lazy mode") | ||
def test_opt_125m_lazy_mode(): | ||
run_opt_125m_on_hpu() | ||
|
||
|
||
@pytest.mark.skipif(not is_hpu_supported(), reason="HPU is not supported") | ||
@pytest.mark.skipif(not is_pytest_mode_compile(), reason="Only for compile mode") | ||
def test_opt_125m_compile_mode(): | ||
torch._dynamo.reset() | ||
run_opt_125m_on_hpu() | ||
|
||
|
||
def test_import(): | ||
from auto_round import AutoRound | ||
from auto_round.export.export_to_itrex.export import save_quantized_as_itrex, WeightOnlyLinear | ||
from auto_round.export.export_to_itrex.export import ( | ||
WeightOnlyLinear, save_quantized_as_itrex) |