From 57227c6b1f64d7869c364b590459396ca52b52f1 Mon Sep 17 00:00:00 2001
From: danthe3rd <43445237+danthe3rd@users.noreply.github.com>
Date: Thu, 22 Aug 2024 10:53:31 +0000
Subject: [PATCH] Remove backward compatibility to PT before 2.4

ghstack-source-id: 228c42bc7b56f30358351b78bb01448f667780ca
Pull Request resolved: https://github.com/fairinternal/xformers/pull/1194

__original_commit__ = fairinternal/xformers@caaa6f45796e96bf4fb75e70b39dfa54084b33bd
---
 CHANGELOG.md                                  |  1 +
 tests/test_checkpoint.py                      |  5 -----
 tests/test_sparsity24.py                      |  6 ------
 tests/test_swiglu.py                          |  9 --------
 xformers/checkpoint.py                        |  4 ----
 xformers/csrc/autocast.h                      | 21 -------------------
 .../sparse24/sparse24_apply_dense_output.cu   |  3 +--
 xformers/csrc/sparse24/sparse24_pack.cu       |  3 +--
 .../cuda/dual_gemm_silu_identity_mul.cu       |  4 +---
 .../swiglu/cuda/gemm_fused_operand_sum.cu     |  4 +---
 xformers/csrc/swiglu/cuda/silu_bw_fused.cu    |  4 +---
 xformers/csrc/swiglu/swiglu_packedw.cpp       |  4 +---
 xformers/ops/sp24.py                          |  5 ++---
 xformers/ops/swiglu_op.py                     |  6 +++---
 xformers/sparse/csr_tensor.py                 |  1 -
 xformers/utils.py                             | 12 -----------
 16 files changed, 12 insertions(+), 80 deletions(-)
 delete mode 100644 xformers/csrc/autocast.h

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e01f0c0fa8..a0b5752f36 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Removed
 - fMHA: Removed `decoder` and `small_k` backends
 - profiler: Removed `DetectSlowOpsProfiler` profiler
+- Removed compatibility with PyTorch < 2.4
 
 ## [0.0.27.post2] - 2024-07-26
 Pre-built binary wheels require PyTorch 2.4.0
diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py
index 287cb6cc99..9dfe158bbd 100644
--- a/tests/test_checkpoint.py
+++ b/tests/test_checkpoint.py
@@ -37,7 +37,6 @@ def _all_policy(ctx, func, *args, **kwargs):
     return True
 
 
-@pytest.mark.skipif(torch.__version__ < "2.2", reason="Only new PyTorch supported")
 @pytest.mark.parametrize("policy_fn", [None, [], _relu_policy, _all_policy])
 @pytest.mark.parametrize("input_requires_grad", [True, False])
 @pytest.mark.parametrize("device", _devices)
@@ -75,7 +74,6 @@ def build_module():
         assert torch.allclose(p.grad, p_copy.grad)
 
 
-@pytest.mark.skipif(torch.__version__ < "2.2", reason="Only new PyTorch supported")
 @pytest.mark.parametrize("policy_fn", [None, [], _relu_policy, _all_policy])
 @pytest.mark.parametrize("input_requires_grad", [True, False])
 @pytest.mark.parametrize("grad_mode", [True, False])
@@ -105,7 +103,6 @@ def test_checkpoint_with_grad(policy_fn, input_requires_grad, grad_mode):
     assert torch.allclose(out, out_copy)
 
 
-@pytest.mark.skipif(torch.__version__ < "2.2", reason="Only new PyTorch supported")
 @cuda_only
 @pytest.mark.parametrize("policy_fn", [None, [], _relu_policy, _all_policy])
 @pytest.mark.parametrize("input_requires_grad", [True, False])
@@ -290,7 +287,6 @@ def forward(self, x):
         return x
 
 
-@pytest.mark.skipif(torch.__version__ < "2.2", reason="Only new PyTorch supported")
 @cuda_only
 @pytest.mark.parametrize("device", ["cuda"])
 @pytest.mark.parametrize("memory_budget", [0, 0.03, 0.05, 0.1, 0.3, 0.5, 0.8, 1.0])
@@ -331,7 +327,6 @@ def test_optimal_checkpoint_policy(
         torch.testing.assert_close(p.grad, p_ref.grad)
 
 
-@pytest.mark.skipif(torch.__version__ < "2.2", reason="Only new PyTorch supported")
 @pytest.mark.skipif(True, reason="TODO[fmassa]: Broken on nightly")
 @cuda_only
 @pytest.mark.parametrize("no_grad", [False, True])
diff --git a/tests/test_sparsity24.py b/tests/test_sparsity24.py
index 364074801c..d700d43fd5 100644
--- a/tests/test_sparsity24.py
+++ b/tests/test_sparsity24.py
@@ -23,9 +23,6 @@
 if torch.cuda.is_available():
     compute_capability = torch.cuda.get_device_capability("cuda")
 
-torch_compile_tests = pytest.mark.skipif(
-    torch.__version__ < "2.2.0.dev20231122", reason="requires PyTorch 2.2+"
-)
 requires_sp24 = pytest.mark.skipif(compute_capability < (8, 0), reason="requires sm80+")
 requires_sp24_gemm = pytest.mark.skipif(
     compute_capability != (8, 0), reason="requires sm80"
@@ -897,7 +894,6 @@ def test_linear_dispatch_inference_mode(backend: str, with_bias: bool) -> None:
     assert_allclose(out, out_ref, msg="output", **atol_rtol_kw[x.dtype])
 
 
-@torch_compile_tests
 @cuda_only
 def test_sp24_meta() -> None:
     x = torch.randn([1024, 512], device="meta", dtype=torch.float16)
@@ -907,7 +903,6 @@ def test_sp24_meta() -> None:
     assert x_s_t.shape == x.t().shape
 
 
-@torch_compile_tests
 @requires_sp24_gemm
 @parametrize_backend
 def test_sp24_compile(backend) -> None:
@@ -952,7 +947,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 @requires_sp24_gemm
-@torch_compile_tests
 @pytest.mark.skipif(not sp24._has_cusparseLt(), reason="requires cusparselt")
 def test_linearw24_block_compile() -> None:
     # TODO: Parametrize on `dtype` when torch.compile gets faster
diff --git a/tests/test_swiglu.py b/tests/test_swiglu.py
index 9d74d7e82f..6600488135 100644
--- a/tests/test_swiglu.py
+++ b/tests/test_swiglu.py
@@ -25,11 +25,6 @@
     _devices = []
     _is_sm80 = False
 cuda_sm80_only = pytest.mark.skipif(not _is_sm80, reason="requires sm80+")
-
-torch_compile_tests = pytest.mark.skipif(
-    torch.__version__ < "2.2.0.dev20231122", reason="requires PyTorch 2.2+"
-)
-
 disable_on_rocm = pytest.mark.skipif(
     not not torch.version.hip, reason="could not be done on ROCM"
 )
@@ -249,7 +244,6 @@ def backward_gather_grads(inp, output):
     assert gout.norm(2) > BACKWARD_ATOL[dtype] / BACKWARD_RTOL[dtype]
 
 
-@torch_compile_tests
 @cuda_sm80_only
 @pytest.mark.parametrize("device", _devices)
 @pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
@@ -346,7 +340,6 @@ def test_swiglu_compile(
 
 @disable_tf32
 @torch.inference_mode()
-@torch_compile_tests
 @cuda_sm80_only
 @pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
 @pytest.mark.parametrize("device", _devices)
@@ -388,7 +381,6 @@ def fn(x):
 @disable_tf32
 @torch.inference_mode()
 @cuda_sm80_only
-@torch_compile_tests
 @pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
 @pytest.mark.parametrize("device", _devices)
 def test_gemm_fused_operand_sum_compile(dtype, device) -> None:
@@ -420,7 +412,6 @@ def fn(x):
 
 @disable_tf32
 @torch.inference_mode()
-@torch_compile_tests
 @pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
 @pytest.mark.parametrize("device", _devices)
 def test_silu_bw_fused_compile(dtype, device) -> None:
diff --git a/xformers/checkpoint.py b/xformers/checkpoint.py
index 2eb0aca5e4..c2f323ffa1 100644
--- a/xformers/checkpoint.py
+++ b/xformers/checkpoint.py
@@ -474,10 +474,6 @@ def __call__(self, ctx, func, *args, **kwargs) -> bool:
 
 class SelectiveCheckpointWrapper(ActivationWrapper):
     def __init__(self, mod, memory_budget=None, policy_fn=None):
-        if torch.__version__ < (2, 1):
-            raise RuntimeError(
-                "SelectiveCheckpointWrapper only supported for torch >- 2.1"
-            )
         super().__init__(mod)
         if not ((memory_budget is None) ^ (policy_fn is None)):
             raise ValueError("Need to specify either policy_fn or memory_budget")
diff --git a/xformers/csrc/autocast.h b/xformers/csrc/autocast.h
deleted file mode 100644
index 467e149fe8..0000000000
--- a/xformers/csrc/autocast.h
+++ /dev/null
@@ -1,21 +0,0 @@
-#pragma once
-
-#include
-#include
-
-namespace xformers {
-
-// In PyTorch 2.4 (https://github.com/pytorch/pytorch/pull/124359) they renamed
-// some functions and immediately marked the old ones as deprecated, causing a
-// lot of log spew. For a while we need to support both old and new PyTorch.
-
-inline at::ScalarType get_autocast_cuda_dtype() {
-#if TORCH_VERSION_MAJOR > 2 || \
-    (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 4)
-  return at::autocast::get_autocast_dtype(at::kCUDA);
-#else
-  return at::autocast::get_autocast_gpu_dtype();
-#endif
-}
-
-} // namespace xformers
diff --git a/xformers/csrc/sparse24/sparse24_apply_dense_output.cu b/xformers/csrc/sparse24/sparse24_apply_dense_output.cu
index 205930c0cf..d8e3d0d56c 100644
--- a/xformers/csrc/sparse24/sparse24_apply_dense_output.cu
+++ b/xformers/csrc/sparse24/sparse24_apply_dense_output.cu
@@ -3,7 +3,6 @@
 #include
 #include
 #include
-#include "autocast.h"
 #include "compute_sparse_tile.h"
 #include "sparse24_pack.h"
 
@@ -215,7 +214,7 @@ at::Tensor sparse24_apply_dense_output_autocast(
     double mul0,
     double mul1) {
   c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
-  auto exec_type = xformers::get_autocast_cuda_dtype();
+  auto exec_type = at::autocast::get_autocast_dtype(at::kCUDA);
   return sparse24_apply_dense_output(
       at::autocast::cached_cast(exec_type, input), threads_masks, mul0, mul1);
 }
diff --git a/xformers/csrc/sparse24/sparse24_pack.cu b/xformers/csrc/sparse24/sparse24_pack.cu
index 6a8bcf0437..676f80283e 100644
--- a/xformers/csrc/sparse24/sparse24_pack.cu
+++ b/xformers/csrc/sparse24/sparse24_pack.cu
@@ -4,7 +4,6 @@
 #include
 #include
 #include
-#include "autocast.h"
 #include "compute_sparse_tile.h"
 #include "sparse24_metadata.h"
 #include "sparse24_pack.h"
@@ -155,7 +154,7 @@ std::
     std::string algorithm,
     std::string backend) {
   c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
-  auto exec_type = xformers::get_autocast_cuda_dtype();
+  auto exec_type = at::autocast::get_autocast_dtype(at::kCUDA);
   return sparse24_sparsify_both_ways(
       at::autocast::cached_cast(exec_type, input), algorithm, backend);
 }
diff --git a/xformers/csrc/swiglu/cuda/dual_gemm_silu_identity_mul.cu b/xformers/csrc/swiglu/cuda/dual_gemm_silu_identity_mul.cu
index d3c080d918..a643f5f3c8 100644
--- a/xformers/csrc/swiglu/cuda/dual_gemm_silu_identity_mul.cu
+++ b/xformers/csrc/swiglu/cuda/dual_gemm_silu_identity_mul.cu
@@ -15,8 +15,6 @@
 #include <45_dual_gemm/device/dual_gemm.h>
 #include <45_dual_gemm/thread/left_silu_and_mul.h>
 
-#include "autocast.h"
-
 namespace {
 
 template
@@ -205,7 +203,7 @@ dual_gemm_silu_identity_mul_autocast(
     const at::Tensor& w1,
     const std::optional& b1) {
   c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
-  auto exec_type = xformers::get_autocast_cuda_dtype();
+  auto exec_type = at::autocast::get_autocast_dtype(at::kCUDA);
   return dual_gemm_silu_identity_mul(
       at::autocast::cached_cast(exec_type, x),
       at::autocast::cached_cast(exec_type, w0),
diff --git a/xformers/csrc/swiglu/cuda/gemm_fused_operand_sum.cu b/xformers/csrc/swiglu/cuda/gemm_fused_operand_sum.cu
index 57166dc887..15629f90cf 100644
--- a/xformers/csrc/swiglu/cuda/gemm_fused_operand_sum.cu
+++ b/xformers/csrc/swiglu/cuda/gemm_fused_operand_sum.cu
@@ -20,8 +20,6 @@
 #include "cutlass/reduction/kernel/reduce_split_k.h"
 #include "cutlass/reduction/thread/reduction_operators.h"
 
"autocast.h" - namespace { template void gemm_fused_operand_sum_( @@ -245,7 +243,7 @@ std::tuple gemm_fused_operand_sum_autocast( at::Tensor& out_mm, at::Tensor& out_sum) { c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); - auto exec_type = xformers::get_autocast_cuda_dtype(); + auto exec_type = at::autocast::get_autocast_dtype(at::kCUDA); return gemm_fused_operand_sum( at::autocast::cached_cast(exec_type, a), at::autocast::cached_cast(exec_type, b), diff --git a/xformers/csrc/swiglu/cuda/silu_bw_fused.cu b/xformers/csrc/swiglu/cuda/silu_bw_fused.cu index 4d7a9fb129..0f2398cd83 100644 --- a/xformers/csrc/swiglu/cuda/silu_bw_fused.cu +++ b/xformers/csrc/swiglu/cuda/silu_bw_fused.cu @@ -18,8 +18,6 @@ #include #include -#include "autocast.h" - namespace { /* Computes the following: @@ -104,7 +102,7 @@ std::tuple silu_bw_fused_autocast( const at::Tensor& x2, const at::Tensor& dx4) { c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); - auto exec_type = xformers::get_autocast_cuda_dtype(); + auto exec_type = at::autocast::get_autocast_dtype(at::kCUDA); return silu_bw_fused( at::autocast::cached_cast(exec_type, x1), at::autocast::cached_cast(exec_type, x2), diff --git a/xformers/csrc/swiglu/swiglu_packedw.cpp b/xformers/csrc/swiglu/swiglu_packedw.cpp index 9dab94440c..e70a3a72fe 100644 --- a/xformers/csrc/swiglu/swiglu_packedw.cpp +++ b/xformers/csrc/swiglu/swiglu_packedw.cpp @@ -13,8 +13,6 @@ #include // clang-format on -#include "autocast.h" - namespace { // Kernels implemented in `cuda/` std::tuple dual_gemm_silu_identity_mul( @@ -198,7 +196,7 @@ at::Tensor swiglu_packedw_autocast( const at::Tensor w3, const std::optional b3) { c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); - auto exec_type = xformers::get_autocast_cuda_dtype(); + auto exec_type = at::autocast::get_autocast_dtype(at::kCUDA); return SwiGLUPackedWeights::apply( at::autocast::cached_cast(exec_type, x), at::autocast::cached_cast(exec_type, w1w2), diff --git a/xformers/ops/sp24.py b/xformers/ops/sp24.py index c2ae6403ee..a27ce88092 100644 --- a/xformers/ops/sp24.py +++ b/xformers/ops/sp24.py @@ -475,9 +475,8 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): ) -if torch.__version__ >= "2.1.0": - torch._dynamo.allow_in_graph(Sparse24TensorCuSparseLt) - torch._dynamo.allow_in_graph(Sparse24TensorCutlass) +torch._dynamo.allow_in_graph(Sparse24TensorCuSparseLt) +torch._dynamo.allow_in_graph(Sparse24TensorCutlass) GRADIENT_SP24 = "24sparse" GRADIENT_DENSE = "24dense" diff --git a/xformers/ops/swiglu_op.py b/xformers/ops/swiglu_op.py index d764196803..630335ac6c 100644 --- a/xformers/ops/swiglu_op.py +++ b/xformers/ops/swiglu_op.py @@ -9,8 +9,8 @@ import torch import torch.nn.functional as F from torch import nn +from torch.amp import custom_bwd, custom_fwd -from ..utils import custom_bwd, custom_fwd from .common import BaseOperator, get_xformers_operator, register_operator from .unbind import stack_or_none, unbind @@ -110,7 +110,7 @@ class _SwiGLUFusedFunc(torch.autograd.Function): NAME = "fused.py" @classmethod - @custom_fwd + @custom_fwd(device_type="cuda") def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3): x1, x2, x4 = DualGemmSiluOp.OPERATOR(x, w1, b1, w2, b2) @@ -131,7 +131,7 @@ def _linear_bw( return dw, db @classmethod - @custom_bwd + @custom_bwd(device_type="cuda") def backward(cls, ctx, dx5): x, w1, w2, w3, x1, x2 = ctx.saved_tensors w1w2 = stack_or_none([w1, w2], dim=0) diff --git a/xformers/sparse/csr_tensor.py 
index 67e7394ce0..5ec9846c39 100644
--- a/xformers/sparse/csr_tensor.py
+++ b/xformers/sparse/csr_tensor.py
@@ -26,7 +26,6 @@ def __new__(cls, row_offsets, column_indices, values, shape):
         kwargs["layout"] = values.layout
         kwargs["requires_grad"] = values.requires_grad
         assert len(shape) == 3
-        assert torch.__version__ > (1, 10), "SparseCSRTensor requires PyTorch 1.11+"
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
 
     def __init__(self, row_offsets, column_indices, values, shape):
diff --git a/xformers/utils.py b/xformers/utils.py
index e20616a0ec..69515ece8f 100644
--- a/xformers/utils.py
+++ b/xformers/utils.py
@@ -13,18 +13,6 @@
 
 import torch
 
 
-# PyTorch 2.4 introduced new functions and immediately marked the old ones as
-# deprecated, causing a lot of log spew. Let's use the new ones if available.
-try:
-    from torch.amp import custom_bwd as new_custom_bwd  # type: ignore[attr-defined]
-    from torch.amp import custom_fwd as new_custom_fwd  # type: ignore[attr-defined]
-
-    custom_fwd = new_custom_fwd(device_type="cuda")
-    custom_bwd = new_custom_bwd(device_type="cuda")
-except ImportError:
-    from torch.cuda.amp import custom_bwd, custom_fwd  # type: ignore # noqa: F401
-
-
 Item = namedtuple("Item", ["constructor", "config"])
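
For reference, the Python-side pattern this patch standardizes on (e.g. in xformers/ops/swiglu_op.py) is the torch.amp API that ships with PyTorch >= 2.4: custom_fwd/custom_bwd take an explicit device_type instead of being imported from torch.cuda.amp through the now-deleted shim in xformers/utils.py. A minimal, self-contained sketch of that usage follows; the _Double op is hypothetical and only illustrates the decorators, it is not part of xformers:

import torch
from torch.amp import custom_bwd, custom_fwd  # requires PyTorch >= 2.4


class _Double(torch.autograd.Function):
    # Hypothetical op, for illustration only.
    @staticmethod
    @custom_fwd(device_type="cuda")  # device_type is passed explicitly in the new API
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return x * 2

    @staticmethod
    @custom_bwd(device_type="cuda")  # backward runs under the same autocast state as forward
    def backward(ctx, grad: torch.Tensor) -> torch.Tensor:
        return grad * 2


if torch.cuda.is_available():
    x = torch.randn(8, device="cuda", requires_grad=True)
    with torch.autocast("cuda", dtype=torch.float16):
        y = _Double.apply(x)
    y.sum().backward()

The C++ files make the matching change directly: they query at::autocast::get_autocast_dtype(at::kCUDA) instead of going through the deleted xformers/csrc/autocast.h wrapper, since the pre-2.4 fallback get_autocast_gpu_dtype() is no longer needed.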