diff --git a/ruff.toml b/ruff.toml
index 25b412c38..1aed057d2 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -12,6 +12,8 @@ include = [
     "test/dtypes/test_affine_quantized_float.py",
     "test/dtypes/test_nf4.py",
     "test/prototype/low_bit_optim/**.py",
+    "torchao/utils.py",
+
 ]

 lint.ignore = ["E731"]
diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py
index 9f1c37330..4da6b9539 100644
--- a/test/dtypes/test_nf4.py
+++ b/test/dtypes/test_nf4.py
@@ -170,11 +170,10 @@ def test_load_from_state_dicts(self, dtype: torch.dtype):
         assert base_mod.param.block_size == 32
         assert base_mod.param.scaler_block_size == 2

-    @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_load_from_nf4_same_meta(self, dtype: torch.dtype):
         """Tests loading to and from different module state dicts"""
-        input_tensor = torch.rand(64, device="cuda", dtype=dtype)
+        input_tensor = torch.rand(64, dtype=dtype)
         base_mod = self.TestMod(input_tensor, 32, 2)
         state_dict = base_mod.state_dict()
         saved_state_dict = self.save_state_dict_to_buffer(state_dict)
@@ -184,11 +183,10 @@ def test_load_from_nf4_same_meta(self, dtype: torch.dtype):
         assert other_mod.param.block_size == 32
         assert other_mod.param.scaler_block_size == 2

-    @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_load_from_nf4_diff_meta(self, dtype: torch.dtype):
         """Tests loading to and from different module state dicts"""
-        input_tensor = torch.rand(128, device="cuda", dtype=dtype)
+        input_tensor = torch.rand(128, dtype=dtype)
         base_mod = self.TestMod(input_tensor, 32, 2)
         state_dict = base_mod.state_dict()
         saved_state_dict = self.save_state_dict_to_buffer(state_dict)
diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index b0c759e75..7bc5a3788 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -632,7 +632,7 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
         with pytest.raises(
             RuntimeError,
             match=re.escape(
-                "Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41."
+                "Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41)."
             ),
         ):
             a_fp8 @ b_fp8
diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py
index 617fd0871..3771f9d4b 100644
--- a/torchao/dtypes/nf4tensor.py
+++ b/torchao/dtypes/nf4tensor.py
@@ -10,6 +10,8 @@
 from torch._prims_common import make_contiguous_strides_for
 from torch.distributed.device_mesh import DeviceMesh

+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+
 aten = torch.ops.aten
 c10d_functional = torch.ops.c10d_functional

@@ -1043,3 +1045,7 @@ def nf4_constructor(
         quantized_data,
         nf4,
     )
+
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    torch.serialization.add_safe_globals([NF4Tensor])
diff --git a/torchao/utils.py b/torchao/utils.py
index 51c1d7a35..e47482413 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -1,15 +1,14 @@
-import torch
-from typing import Tuple, Any, Callable
-from functools import reduce
 import functools
-from importlib.metadata import version
-from math import gcd
-import torch.nn.utils.parametrize as parametrize
 import itertools
-import time
-import warnings
 import re
+import time
+from functools import reduce
+from importlib.metadata import version
+from math import gcd
+from typing import Any, Callable, Tuple
+
+import torch
+import torch.nn.utils.parametrize as parametrize

 __all__ = [
     "benchmark_model",
@@ -27,7 +26,6 @@
     "TORCH_VERSION_AT_LEAST_2_4",
     "TORCH_VERSION_AT_LEAST_2_5",
     "TORCH_VERSION_AT_LEAST_2_6",
-
     # Needs to be deprecated in the future
     "TORCH_VERSION_AFTER_2_2",
     "TORCH_VERSION_AFTER_2_3",
@@ -42,8 +40,9 @@ def _assert_and_get_unique_device(module: torch.nn.Module) -> Any:
     Returns the unique device for a module, or None if no device is found.
     Throws an error if multiple devices are detected.
     """
-    devices = {p.device for p in module.parameters()} | \
-        {p.device for p in module.buffers()}
+    devices = {p.device for p in module.parameters()} | {
+        p.device for p in module.buffers()
+    }

     assert len(devices) <= 1, (
         "prepare only works with cpu or single-device CUDA modules, "
@@ -54,13 +53,14 @@ def _assert_and_get_unique_device(module: torch.nn.Module) -> Any:


 def benchmark_model(model, num_runs, args=(), kwargs=None, device_type=None):
-    """Benchmark model runs with `args` and `kwargs` both are optional
-    """
+    """Benchmark model runs with `args` and `kwargs` both are optional"""
     if kwargs is None:
         kwargs = {}

     if device_type is None:
-        assert isinstance(model, torch.nn.Module), "Expecting `model` to be torch.nn.Module if device_type is not provided"
+        assert isinstance(
+            model, torch.nn.Module
+        ), "Expecting `model` to be torch.nn.Module if device_type is not provided"
         device_type = _assert_and_get_unique_device(model).type

     if device_type == "cuda":
@@ -110,35 +110,48 @@ def benchmark_model(model, num_runs, args=(), kwargs=None, device_type=None):

 def profiler_runner(path, fn, *args, **kwargs):
     with torch.profiler.profile(
-            activities=[torch.profiler.ProfilerActivity.CPU,
-                        torch.profiler.ProfilerActivity.CUDA],
-            record_shapes=True) as prof:
+        activities=[
+            torch.profiler.ProfilerActivity.CPU,
+            torch.profiler.ProfilerActivity.CUDA,
+        ],
+        record_shapes=True,
+    ) as prof:
         result = fn(*args, **kwargs)
     prof.export_chrome_trace(path)
     return result

+
 def get_compute_capability():
     if torch.cuda.is_available():
         capability = torch.cuda.get_device_capability()
         return float(f"{capability[0]}.{capability[1]}")
     return 0.0

+
 def skip_if_compute_capability_less_than(min_capability):
     import unittest
+
     def decorator(test_func):
         def wrapper(*args, **kwargs):
             if get_compute_capability() < min_capability:
-                raise unittest.SkipTest(f"Compute capability is less than {min_capability}")
+                raise unittest.SkipTest(
+                    f"Compute capability is less than {min_capability}"
+                )
             return test_func(*args, **kwargs)
+
         return wrapper
+
     return decorator

+
 def compute_max_diff(output: torch.Tensor, output_ref: torch.Tensor) -> torch.Tensor:
     return torch.mean(torch.abs(output - output_ref)) / torch.mean(
-        torch.abs(output_ref))
+        torch.abs(output_ref)
+    )
+

 def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
-    import torch.utils.benchmark as benchmark # this avoids importing numpy when torchao module is loaded
+    import torch.utils.benchmark as benchmark  # this avoids importing numpy when torchao module is loaded

     # Manual warmup
     f(*args, **kwargs)
@@ -158,6 +171,7 @@ def find_multiple(n: int, *args: Tuple[int]) -> int:
         return n
     return n + k - (n % k)

+
 def _register_custom_op(lib):
     """This decorator is used to preserve some high level operators for torch.export.export
     while still allow them to be decomposed for inductor path
@@ -191,8 +205,12 @@ def decorator(fn):

         # expecting fn.__name__ starts with `_` and we want to take the rest
         # to be the name of the custom op
-        assert fn.__name__[0] == "_", f"Expecting function name starts with `_`, got {fn.__name__}"
-        assert not any(c in fn.__name__ for c in ".<>"), f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}"
+        assert (
+            fn.__name__[0] == "_"
+        ), f"Expecting function name starts with `_`, got {fn.__name__}"
+        assert not any(
+            c in fn.__name__ for c in ".<>"
+        ), f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}"
         op_name = fn.__name__[1:]
         schema = op_name + infer_schema(fn, mutates_args={})
         lib.define(schema)
@@ -207,12 +225,14 @@ def decorator(fn):

     return decorator

+
 def get_model_size_in_bytes(model, ignore_embeddings=False):
     """
     Returns the model size in bytes. The option to ignore embeddings
     is useful for models with disproportionately large embeddings compared
     to other model parameters that get quantized/sparsified.
     """
+
     def flat_size(tensor):
         if hasattr(tensor, "__tensor_flatten__"):
             size = 0
@@ -228,11 +248,14 @@ def flat_size(tensor):
     model_size = 0
     for name, child in model.named_children():
         if not (isinstance(child, torch.nn.Embedding) and ignore_embeddings):
-            for p in itertools.chain(child.parameters(recurse=False), child.buffers(recurse=False)):
+            for p in itertools.chain(
+                child.parameters(recurse=False), child.buffers(recurse=False)
+            ):
                 model_size += flat_size(p)
             model_size += get_model_size_in_bytes(child, ignore_embeddings)
     return model_size

+
 class UnwrapTensorSubclass(torch.nn.Module):
     def forward(self, *tensors):
         todo = list(tensors)
@@ -267,6 +290,7 @@ def right_inverse(self, tensor):

         return plain_tensors

+
 def unwrap_tensor_subclass(model, filter_fn=None):
     """Unwraps (nested) tensor subclass in the model to plain tensors
     This is a workaround to make a model with tensor subclass to work with `torch.export.export`
@@ -276,14 +300,19 @@ def unwrap_tensor_subclass(model, filter_fn=None):
     for name, child in model.named_children():
         # make sure child.weight is a tensor subclass
         if (
-            (isinstance(child, torch.nn.Linear) or isinstance(child, torch.nn.Embedding)) and
-            hasattr(child, "weight") and
-            type(child.weight) is not torch.Tensor and
-            type(child.weight) is not torch.nn.Parameter and
-            isinstance(child.weight, torch.Tensor) and
-            issubclass(type(child.weight), torch.Tensor)
+            (
+                isinstance(child, torch.nn.Linear)
+                or isinstance(child, torch.nn.Embedding)
+            )
+            and hasattr(child, "weight")
+            and type(child.weight) is not torch.Tensor
+            and type(child.weight) is not torch.nn.Parameter
+            and isinstance(child.weight, torch.Tensor)
+            and issubclass(type(child.weight), torch.Tensor)
         ):
-            parametrize.register_parametrization(child, "weight", UnwrapTensorSubclass())
+            parametrize.register_parametrization(
+                child, "weight", UnwrapTensorSubclass()
+            )
         unwrap_tensor_subclass(child)
     return model
@@ -300,24 +329,28 @@ def _is_float8_type(dtype: torch.dtype) -> bool:

 def parse_version(version_string):
     # Extract just the X.Y.Z part from the version string
-    match = re.match(r'(\d+\.\d+\.\d+)', version_string)
+    match = re.match(r"(\d+\.\d+\.\d+)", version_string)
     if match:
         version = match.group(1)
-        return [int(x) for x in version.split('.')]
+        return [int(x) for x in version.split(".")]
     else:
         raise ValueError(f"Invalid version string format: {version_string}")

+
 def compare_versions(v1, v2):
     v1_parts = parse_version(v1)
     v2_parts = parse_version(v2)
     return (v1_parts > v2_parts) - (v1_parts < v2_parts)

+
 def is_fbcode():
     return not hasattr(torch.version, "git_version")

+
 def torch_version_at_least(min_version):
     return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0

+
 TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0")
 TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0")
 TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0")
@@ -329,6 +362,8 @@ def torch_version_at_least(min_version):
 Helper function for implementing aten op or torch function dispatch
 and dispatching to these implementations.
 """
+
+
 def _implements(cls, aten_ops_or_torch_fns):
     """Use this decorator to implement a function for an aten ops in __torch_dispatch__
     (if user passed in a list of ops)
@@ -350,16 +385,20 @@ def _(func, types, args, kwargs):

     if not isinstance(aten_ops_or_torch_fns, (list, tuple)):
         aten_ops_or_torch_fns = [aten_ops_or_torch_fns]
+
     def decorator(func):
         for op in aten_ops_or_torch_fns:
+
             @functools.wraps(op)
             def wrapper(f, types, args, kwargs):
                 return func(f, types, args, kwargs)

             cls._ATEN_OP_OR_TORCH_FN_TABLE[op] = wrapper
         return func
+
     return decorator

+
 def _dispatch__torch_function__(cls, func, types, args=(), kwargs=None):
     """Use this util function for a common `__torch_function__` implementation
     that dispatches to ops/functions registered with `_implements`
@@ -369,13 +408,16 @@ class MyTensor(torch.Tensor):
         __torch_function__ = classmethod(_dispatch__torch_function__)
     """
     kwargs = {} if kwargs is None else kwargs
-    if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \
-        func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
+    if (
+        hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE")
+        and func in cls._ATEN_OP_OR_TORCH_FN_TABLE
+    ):
         return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)

     with torch._C.DisableTorchFunctionSubclass():
         return func(*args, **kwargs)

+
 def _dispatch__torch_dispatch__(cls, func, types, args, kwargs):
     """Use this util function for a common `__torch_dispatch__` implementation
     that dispatches to ops/functions registered with `_implements`
@@ -384,13 +426,18 @@ class MyTensor(torch.Tensor):
         ...
         __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
     """
-    if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \
-        func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
+    if (
+        hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE")
+        and func in cls._ATEN_OP_OR_TORCH_FN_TABLE
+    ):
         return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)

     arg_types = tuple(type(arg) for arg in args)
     kwarg_types = {k: type(arg) for k, arg in kwargs.items()}
-    raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}")
+    raise NotImplementedError(
+        f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}"
+    )
+

 def _register_layout(tensor_class: Callable, layout_class: Callable):
     """Helper function for layout registrations, this is used to implement
@@ -411,14 +458,20 @@ def _register_layout(tensor_class: Callable, layout_class: Callable):
         tensor_class._LAYOUT_CONSTRUCTOR_TABLE = {}

     def decorator(tensor_impl_class):
-        tensor_class._LAYOUT_CONSTRUCTOR_TABLE[layout_class] = tensor_impl_class.from_plain
+        tensor_class._LAYOUT_CONSTRUCTOR_TABLE[layout_class] = (
+            tensor_impl_class.from_plain
+        )
         if TORCH_VERSION_AT_LEAST_2_5:
             # Allow serialization to work for models uses this tensor impl subclass
             torch.serialization.add_safe_globals([layout_class, tensor_impl_class])
         return tensor_impl_class
+
     return decorator

-def _get_tensor_impl_constructor(tensor_class: Callable, layout_class: Callable) -> Callable:
+
+def _get_tensor_impl_constructor(
+    tensor_class: Callable, layout_class: Callable
+) -> Callable:
     """Get TensorImpl class constructor (TensorImplClass.from_plain) for `tensor_class` based on `layout_class`
     `layout_class` means the class type of subclass of `Layout`, e.g. `PlainLayout`

@@ -430,9 +483,13 @@ def _get_tensor_impl_constructor(tensor_class: Callable, layout_class: Callable)
         tensor impl subclass constructor for the layout_class
     """
     if not hasattr(tensor_class, "_LAYOUT_CONSTRUCTOR_TABLE"):
-        raise ValueError(f"no registered tensor_impl class constructor for: {tensor_class}")
+        raise ValueError(
+            f"no registered tensor_impl class constructor for: {tensor_class}"
+        )
     if layout_class not in tensor_class._LAYOUT_CONSTRUCTOR_TABLE:
-        raise ValueError(f"layout_name: {layout_class} is not supported yet for {tensor_class}")
+        raise ValueError(
+            f"layout_name: {layout_class} is not supported yet for {tensor_class}"
+        )

     return tensor_class._LAYOUT_CONSTRUCTOR_TABLE[layout_class]

@@ -471,6 +528,7 @@ class PlainAQTTensorImpl(...):
         tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
     """
+
     implements = classmethod(_implements)
     __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
     __torch_function__ = classmethod(_dispatch__torch_function__)
@@ -497,6 +555,7 @@ def _get_to_kwargs(self, *args, **kwargs):
         }
         return kwargs

+
 def fill_defaults(args, n, defaults_tail):
     """
     __torch_dispatch__ doesn't guarantee the number of arguments you are
@@ -526,6 +585,7 @@ def fill_defaults(args, n, defaults_tail):
 def _torch_version_at_least(min_version):
     return is_fbcode() or version("torch") >= min_version

+
TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev")
TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev")
TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")