Remove backward compatibility with PT before 2.4
ghstack-source-id: 228c42bc7b56f30358351b78bb01448f667780ca
Pull Request resolved: fairinternal/xformers#1194

__original_commit__ = fairinternal/xformers@caaa6f4
danthe3rd authored and xFormers Bot committed Aug 22, 2024
1 parent b7c5a3d commit 57227c6
Showing 16 changed files with 12 additions and 80 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Removed
- fMHA: Removed `decoder` and `small_k` backends
- profiler: Removed `DetectSlowOpsProfiler` profiler
- Removed compatibility with PyTorch < 2.4

## [0.0.27.post2] - 2024-07-26
Pre-built binary wheels require PyTorch 2.4.0
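
The entry above sets the new floor. As a hedged sketch only (not part of xformers, which no longer carries such guards; it assumes the `packaging` package is installed), downstream code can verify the requirement like this:

```python
# Hedged sketch: illustrates the PyTorch >= 2.4 requirement this release assumes.
# xformers itself no longer performs this kind of runtime check.
import torch
from packaging.version import Version  # assumes `packaging` is installed

if Version(torch.__version__.split("+")[0]) < Version("2.4.0"):
    raise RuntimeError("This xformers release requires PyTorch >= 2.4")
```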
5 changes: 0 additions & 5 deletions tests/test_checkpoint.py
@@ -37,7 +37,6 @@ def _all_policy(ctx, func, *args, **kwargs):
return True


@pytest.mark.skipif(torch.__version__ < "2.2", reason="Only new PyTorch supported")
@pytest.mark.parametrize("policy_fn", [None, [], _relu_policy, _all_policy])
@pytest.mark.parametrize("input_requires_grad", [True, False])
@pytest.mark.parametrize("device", _devices)
@@ -75,7 +74,6 @@ def build_module():
assert torch.allclose(p.grad, p_copy.grad)


@pytest.mark.skipif(torch.__version__ < "2.2", reason="Only new PyTorch supported")
@pytest.mark.parametrize("policy_fn", [None, [], _relu_policy, _all_policy])
@pytest.mark.parametrize("input_requires_grad", [True, False])
@pytest.mark.parametrize("grad_mode", [True, False])
@@ -105,7 +103,6 @@ def test_checkpoint_with_grad(policy_fn, input_requires_grad, grad_mode):
assert torch.allclose(out, out_copy)


@pytest.mark.skipif(torch.__version__ < "2.2", reason="Only new PyTorch supported")
@cuda_only
@pytest.mark.parametrize("policy_fn", [None, [], _relu_policy, _all_policy])
@pytest.mark.parametrize("input_requires_grad", [True, False])
@@ -290,7 +287,6 @@ def forward(self, x):
return x


@pytest.mark.skipif(torch.__version__ < "2.2", reason="Only new PyTorch supported")
@cuda_only
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("memory_budget", [0, 0.03, 0.05, 0.1, 0.3, 0.5, 0.8, 1.0])
@@ -331,7 +327,6 @@ def test_optimal_checkpoint_policy(
torch.testing.assert_close(p.grad, p_ref.grad)


@pytest.mark.skipif(torch.__version__ < "2.2", reason="Only new PyTorch supported")
@pytest.mark.skipif(True, reason="TODO[fmassa]: Broken on nightly")
@cuda_only
@pytest.mark.parametrize("no_grad", [False, True])
6 changes: 0 additions & 6 deletions tests/test_sparsity24.py
@@ -23,9 +23,6 @@
if torch.cuda.is_available():
compute_capability = torch.cuda.get_device_capability("cuda")

torch_compile_tests = pytest.mark.skipif(
torch.__version__ < "2.2.0.dev20231122", reason="requires PyTorch 2.2+"
)
requires_sp24 = pytest.mark.skipif(compute_capability < (8, 0), reason="requires sm80+")
requires_sp24_gemm = pytest.mark.skipif(
compute_capability != (8, 0), reason="requires sm80"
@@ -897,7 +894,6 @@ def test_linear_dispatch_inference_mode(backend: str, with_bias: bool) -> None:
assert_allclose(out, out_ref, msg="output", **atol_rtol_kw[x.dtype])


@torch_compile_tests
@cuda_only
def test_sp24_meta() -> None:
x = torch.randn([1024, 512], device="meta", dtype=torch.float16)
@@ -907,7 +903,6 @@ def test_sp24_meta() -> None:
assert x_s_t.shape == x.t().shape


@torch_compile_tests
@requires_sp24_gemm
@parametrize_backend
def test_sp24_compile(backend) -> None:
@@ -952,7 +947,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


@requires_sp24_gemm
@torch_compile_tests
@pytest.mark.skipif(not sp24._has_cusparseLt(), reason="requires cusparselt")
def test_linearw24_block_compile() -> None:
# TODO: Parametrize on `dtype` when torch.compile gets faster
9 changes: 0 additions & 9 deletions tests/test_swiglu.py
@@ -25,11 +25,6 @@
_devices = []
_is_sm80 = False
cuda_sm80_only = pytest.mark.skipif(not _is_sm80, reason="requires sm80+")

torch_compile_tests = pytest.mark.skipif(
torch.__version__ < "2.2.0.dev20231122", reason="requires PyTorch 2.2+"
)

disable_on_rocm = pytest.mark.skipif(
not not torch.version.hip, reason="could not be done on ROCM"
)
@@ -249,7 +244,6 @@ def backward_gather_grads(inp, output):
assert gout.norm(2) > BACKWARD_ATOL[dtype] / BACKWARD_RTOL[dtype]


@torch_compile_tests
@cuda_sm80_only
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
@@ -346,7 +340,6 @@ def test_swiglu_compile(

@disable_tf32
@torch.inference_mode()
@torch_compile_tests
@cuda_sm80_only
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
@pytest.mark.parametrize("device", _devices)
@@ -388,7 +381,6 @@ def fn(x):
@disable_tf32
@torch.inference_mode()
@cuda_sm80_only
@torch_compile_tests
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
@pytest.mark.parametrize("device", _devices)
def test_gemm_fused_operand_sum_compile(dtype, device) -> None:
@@ -420,7 +412,6 @@ def fn(x):

@disable_tf32
@torch.inference_mode()
@torch_compile_tests
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
@pytest.mark.parametrize("device", _devices)
def test_silu_bw_fused_compile(dtype, device) -> None:
4 changes: 0 additions & 4 deletions xformers/checkpoint.py
@@ -474,10 +474,6 @@ def __call__(self, ctx, func, *args, **kwargs) -> bool:

class SelectiveCheckpointWrapper(ActivationWrapper):
def __init__(self, mod, memory_budget=None, policy_fn=None):
if torch.__version__ < (2, 1):
raise RuntimeError(
"SelectiveCheckpointWrapper only supported for torch >- 2.1"
)
super().__init__(mod)
if not ((memory_budget is None) ^ (policy_fn is None)):
raise ValueError("Need to specify either policy_fn or memory_budget")
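
With the version guard gone, the constructor contract above reduces to the XOR check on `memory_budget`/`policy_fn`. A hedged usage sketch (the import path is inferred from the file location `xformers/checkpoint.py`; the wrapper's runtime checkpointing behaviour is not shown in this diff):

```python
import torch.nn as nn
from xformers.checkpoint import SelectiveCheckpointWrapper  # path inferred from this file

block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))

# Exactly one of memory_budget / policy_fn may be given...
wrapped = SelectiveCheckpointWrapper(block, memory_budget=0.5)

# ...passing neither (or both) raises the ValueError shown above:
# SelectiveCheckpointWrapper(block)
```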
21 changes: 0 additions & 21 deletions xformers/csrc/autocast.h

This file was deleted.

3 changes: 1 addition & 2 deletions xformers/csrc/sparse24/sparse24_apply_dense_output.cu
@@ -3,7 +3,6 @@
#include <ATen/autocast_mode.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
#include "autocast.h"
#include "compute_sparse_tile.h"
#include "sparse24_pack.h"

@@ -215,7 +214,7 @@ at::Tensor sparse24_apply_dense_output_autocast(
double mul0,
double mul1) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
auto exec_type = xformers::get_autocast_cuda_dtype();
auto exec_type = at::autocast::get_autocast_dtype(at::kCUDA);
return sparse24_apply_dense_output(
at::autocast::cached_cast(exec_type, input), threads_masks, mul0, mul1);
}
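
The replacement above is the C++ side of PyTorch 2.4's device-generic autocast API, which is presumably why the xformers/csrc/autocast.h shim could be deleted. A hedged Python-side sketch of the same unification (assumes `torch.get_autocast_dtype` as exposed by PyTorch >= 2.4):

```python
import torch

# Device-generic query, mirroring at::autocast::get_autocast_dtype(at::kCUDA).
# Outside an autocast region this returns the default CUDA autocast dtype
# (typically torch.float16); inside torch.autocast("cuda", dtype=...) it
# returns that dtype.
print(torch.get_autocast_dtype("cuda"))

# The older per-device getter, torch.get_autocast_gpu_dtype(), is deprecated
# as of PyTorch 2.4.
```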
3 changes: 1 addition & 2 deletions xformers/csrc/sparse24/sparse24_pack.cu
@@ -4,7 +4,6 @@
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
#include <torch/types.h>
#include "autocast.h"
#include "compute_sparse_tile.h"
#include "sparse24_metadata.h"
#include "sparse24_pack.h"
@@ -155,7 +154,7 @@ std::
std::string algorithm,
std::string backend) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
auto exec_type = xformers::get_autocast_cuda_dtype();
auto exec_type = at::autocast::get_autocast_dtype(at::kCUDA);
return sparse24_sparsify_both_ways(
at::autocast::cached_cast(exec_type, input), algorithm, backend);
}
4 changes: 1 addition & 3 deletions xformers/csrc/swiglu/cuda/dual_gemm_silu_identity_mul.cu
@@ -15,8 +15,6 @@
#include <45_dual_gemm/device/dual_gemm.h>
#include <45_dual_gemm/thread/left_silu_and_mul.h>

#include "autocast.h"

namespace {

template <typename scalar_t>
@@ -205,7 +203,7 @@ dual_gemm_silu_identity_mul_autocast(
const at::Tensor& w1,
const std::optional<at::Tensor>& b1) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
auto exec_type = xformers::get_autocast_cuda_dtype();
auto exec_type = at::autocast::get_autocast_dtype(at::kCUDA);
return dual_gemm_silu_identity_mul(
at::autocast::cached_cast(exec_type, x),
at::autocast::cached_cast(exec_type, w0),
4 changes: 1 addition & 3 deletions xformers/csrc/swiglu/cuda/gemm_fused_operand_sum.cu
@@ -20,8 +20,6 @@
#include "cutlass/reduction/kernel/reduce_split_k.h"
#include "cutlass/reduction/thread/reduction_operators.h"

#include "autocast.h"

namespace {
template <typename scalar_t>
void gemm_fused_operand_sum_(
@@ -245,7 +243,7 @@ std::tuple<at::Tensor, at::Tensor> gemm_fused_operand_sum_autocast(
at::Tensor& out_mm,
at::Tensor& out_sum) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
auto exec_type = xformers::get_autocast_cuda_dtype();
auto exec_type = at::autocast::get_autocast_dtype(at::kCUDA);
return gemm_fused_operand_sum(
at::autocast::cached_cast(exec_type, a),
at::autocast::cached_cast(exec_type, b),
4 changes: 1 addition & 3 deletions xformers/csrc/swiglu/cuda/silu_bw_fused.cu
@@ -18,8 +18,6 @@
#include <torch/library.h>
#include <ATen/native/cuda/Loops.cuh>

#include "autocast.h"

namespace {
/*
Computes the following:
@@ -104,7 +102,7 @@ std::tuple<at::Tensor, at::Tensor> silu_bw_fused_autocast(
const at::Tensor& x2,
const at::Tensor& dx4) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
auto exec_type = xformers::get_autocast_cuda_dtype();
auto exec_type = at::autocast::get_autocast_dtype(at::kCUDA);
return silu_bw_fused(
at::autocast::cached_cast(exec_type, x1),
at::autocast::cached_cast(exec_type, x2),
4 changes: 1 addition & 3 deletions xformers/csrc/swiglu/swiglu_packedw.cpp
@@ -13,8 +13,6 @@
#include <torch/library.h>
// clang-format on

#include "autocast.h"

namespace {
// Kernels implemented in `cuda/`
std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_silu_identity_mul(
@@ -198,7 +196,7 @@ at::Tensor swiglu_packedw_autocast(
const at::Tensor w3,
const std::optional<at::Tensor> b3) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
auto exec_type = xformers::get_autocast_cuda_dtype();
auto exec_type = at::autocast::get_autocast_dtype(at::kCUDA);
return SwiGLUPackedWeights::apply(
at::autocast::cached_cast(exec_type, x),
at::autocast::cached_cast(exec_type, w1w2),
5 changes: 2 additions & 3 deletions xformers/ops/sp24.py
@@ -475,9 +475,8 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
)


if torch.__version__ >= "2.1.0":
torch._dynamo.allow_in_graph(Sparse24TensorCuSparseLt)
torch._dynamo.allow_in_graph(Sparse24TensorCutlass)
torch._dynamo.allow_in_graph(Sparse24TensorCuSparseLt)
torch._dynamo.allow_in_graph(Sparse24TensorCutlass)

GRADIENT_SP24 = "24sparse"
GRADIENT_DENSE = "24dense"
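
`torch._dynamo.allow_in_graph` is now called unconditionally, since every supported PyTorch ships dynamo. A minimal hedged sketch of what that registration does, using a plain function rather than the Sparse24 tensor subclasses registered above:

```python
import torch

@torch._dynamo.allow_in_graph
def opaque_double(x):
    # Dynamo keeps this call as a single graph node instead of tracing its body.
    return x * 2

@torch.compile(backend="eager")  # "eager" backend keeps the sketch dependency-free
def fn(x):
    return opaque_double(x) + 1

print(fn(torch.ones(4)))
```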
6 changes: 3 additions & 3 deletions xformers/ops/swiglu_op.py
@@ -9,8 +9,8 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.amp import custom_bwd, custom_fwd

from ..utils import custom_bwd, custom_fwd
from .common import BaseOperator, get_xformers_operator, register_operator
from .unbind import stack_or_none, unbind

@@ -110,7 +110,7 @@ class _SwiGLUFusedFunc(torch.autograd.Function):
NAME = "fused.py"

@classmethod
@custom_fwd
@custom_fwd(device_type="cuda")
def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3):
x1, x2, x4 = DualGemmSiluOp.OPERATOR(x, w1, b1, w2, b2)

@@ -131,7 +131,7 @@ def _linear_bw(
return dw, db

@classmethod
@custom_bwd
@custom_bwd(device_type="cuda")
def backward(cls, ctx, dx5):
x, w1, w2, w3, x1, x2 = ctx.saved_tensors
w1w2 = stack_or_none([w1, w2], dim=0)
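
The decorators above now come straight from `torch.amp` and take an explicit `device_type`. A minimal standalone sketch of that PyTorch >= 2.4 form (not xformers code; a toy matmul Function is used for illustration):

```python
import torch
from torch.amp import custom_bwd, custom_fwd


class ToyMatmul(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")  # device_type replaces the implicit "cuda" of torch.cuda.amp
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight.t()

    @staticmethod
    @custom_bwd(device_type="cuda")  # backward runs under the same autocast state as forward
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        return grad_out @ weight, grad_out.t() @ x
```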
1 change: 0 additions & 1 deletion xformers/sparse/csr_tensor.py
@@ -26,7 +26,6 @@ def __new__(cls, row_offsets, column_indices, values, shape):
kwargs["layout"] = values.layout
kwargs["requires_grad"] = values.requires_grad
assert len(shape) == 3
assert torch.__version__ > (1, 10), "SparseCSRTensor requires PyTorch 1.11+"
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

def __init__(self, row_offsets, column_indices, values, shape):
12 changes: 0 additions & 12 deletions xformers/utils.py
@@ -13,18 +13,6 @@

import torch

# PyTorch 2.4 introduced new functions and immediately marked the old ones as
# deprecated, causing a lot of log spew. Let's use the new ones if available.
try:
from torch.amp import custom_bwd as new_custom_bwd # type: ignore[attr-defined]
from torch.amp import custom_fwd as new_custom_fwd # type: ignore[attr-defined]

custom_fwd = new_custom_fwd(device_type="cuda")
custom_bwd = new_custom_bwd(device_type="cuda")
except ImportError:
from torch.cuda.amp import custom_bwd, custom_fwd # type: ignore # noqa: F401


Item = namedtuple("Item", ["constructor", "config"])


