Skip to content

Commit

Permalink
Merge branch 'pytorch:main' into optim_compile
Browse files: browse the repository at this point in the history
  • Loading branch information
gau-nernst authored Nov 13, 2024
2 parents 72365a5 + f96e5ec commit b09f3fc
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 46 deletions.
2 changes: 2 additions & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ include = [
"test/dtypes/test_affine_quantized_float.py",
"test/dtypes/test_nf4.py",
"test/prototype/low_bit_optim/**.py",
+ "torchao/utils.py",
+
]

lint.ignore = ["E731"]
6 changes: 2 additions & 4 deletions test/dtypes/test_nf4.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,10 @@ def test_load_from_state_dicts(self, dtype: torch.dtype):
assert base_mod.param.block_size == 32
assert base_mod.param.scaler_block_size == 2

- @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_load_from_nf4_same_meta(self, dtype: torch.dtype):
"""Tests loading to and from different module state dicts"""
- input_tensor = torch.rand(64, device="cuda", dtype=dtype)
+ input_tensor = torch.rand(64, dtype=dtype)
base_mod = self.TestMod(input_tensor, 32, 2)
state_dict = base_mod.state_dict()
saved_state_dict = self.save_state_dict_to_buffer(state_dict)
Expand All @@ -184,11 +183,10 @@ def test_load_from_nf4_same_meta(self, dtype: torch.dtype):
assert other_mod.param.block_size == 32
assert other_mod.param.scaler_block_size == 2

- @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_load_from_nf4_diff_meta(self, dtype: torch.dtype):
"""Tests loading to and from different module state dicts"""
- input_tensor = torch.rand(128, device="cuda", dtype=dtype)
+ input_tensor = torch.rand(128, dtype=dtype)
base_mod = self.TestMod(input_tensor, 32, 2)
state_dict = base_mod.state_dict()
saved_state_dict = self.save_state_dict_to_buffer(state_dict)
Expand Down
2 changes: 1 addition & 1 deletion test/float8/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
with pytest.raises(
RuntimeError,
match=re.escape(
- "Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41."
+ "Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41)."
),
):
a_fp8 @ b_fp8
Expand Down
6 changes: 6 additions & 0 deletions torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from torch._prims_common import make_contiguous_strides_for
from torch.distributed.device_mesh import DeviceMesh

+ from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+
aten = torch.ops.aten

c10d_functional = torch.ops.c10d_functional
Expand Down Expand Up @@ -1043,3 +1045,7 @@ def nf4_constructor(
quantized_data,
nf4,
)
+
+
+ if TORCH_VERSION_AT_LEAST_2_5:
+ torch.serialization.add_safe_globals([NF4Tensor])
Loading

0 comments on commit b09f3fc

Please sign in to comment.