From 3a25236f8d969f87d5e20f819b7e0a641e4faf63 Mon Sep 17 00:00:00 2001
From: lshamis
Date: Fri, 11 Oct 2024 20:25:36 +0000
Subject: [PATCH] remove pre-dep torch (fairinternal/xformers#1237)

* remove pre-dep torch

* fix lint

* isort

* silly mistake

---------

Co-authored-by: Leonid Shamis

__original_commit__ = fairinternal/xformers@51111b0c7affd6c0c88b3beae45c62e47143bf83
---
 setup.py | 233 +++++++++++++++++++++++++++++--------------------------
 1 file changed, 125 insertions(+), 108 deletions(-)

diff --git a/setup.py b/setup.py
index 4f0da97fb0..2530854d23 100644
--- a/setup.py
+++ b/setup.py
@@ -21,30 +21,10 @@
 from typing import List, Optional

 import setuptools
-import torch
-from torch.utils.cpp_extension import (
-    CUDA_HOME,
-    ROCM_HOME,
-    BuildExtension,
-    CppExtension,
-    CUDAExtension,
-)
-
-this_dir = os.path.dirname(__file__)
-pt_attn_compat_file_path = os.path.join(
-    this_dir, "xformers", "ops", "fmha", "torch_attention_compat.py"
-)
+from setuptools.command.build_ext import build_ext

-# Define the module name
-module_name = "torch_attention_compat"
-
-# Load the module
-spec = importlib.util.spec_from_file_location(module_name, pt_attn_compat_file_path)
-assert spec is not None
-attn_compat_module = importlib.util.module_from_spec(spec)
-sys.modules[module_name] = attn_compat_module
-assert spec.loader is not None
-spec.loader.exec_module(attn_compat_module)
+this_file = Path(__file__)
+this_dir = this_file.parent


 def get_extra_nvcc_flags_for_build_type(cuda_version: int) -> List[str]:
@@ -70,18 +50,18 @@ def fetch_requirements():


 def get_local_version_suffix() -> str:
-    if not (Path(__file__).parent / ".git").is_dir():
+    if not (this_dir / ".git").is_dir():
         # Most likely installing from a source distribution
         return ""
     date_suffix = datetime.datetime.now().strftime("%Y%m%d")
     git_hash = subprocess.check_output(
-        ["git", "rev-parse", "--short", "HEAD"], cwd=Path(__file__).parent
+        ["git", "rev-parse", "--short", "HEAD"], cwd=this_dir
     ).decode("ascii")[:-1]
     return f"+{git_hash}.d{date_suffix}"


 def get_flash_version() -> str:
-    flash_dir = Path(__file__).parent / "third_party" / "flash-attention"
+    flash_dir = this_dir / "third_party" / "flash-attention"
     try:
         return subprocess.check_output(
             ["git", "describe", "--tags", "--always"],
@@ -104,17 +84,15 @@ def generate_version_py(version: str) -> str:


 def symlink_package(name: str, path: Path, is_building_wheel: bool) -> None:
-    cwd = Path(__file__).resolve().parent
+    cwd = this_file.resolve().parent
     path_from = cwd / path
-    path_to = os.path.join(cwd, *name.split("."))
+    path_to = cwd / Path(name.replace(".", os.sep))

     try:
-        if os.path.islink(path_to):
-            os.unlink(path_to)
-        elif os.path.isdir(path_to):
+        if path_to.is_dir() and not path_to.is_symlink():
             shutil.rmtree(path_to)
         else:
-            os.remove(path_to)
+            path_to.unlink()
     except FileNotFoundError:
         pass
     # OSError: [WinError 1314] A required privilege is not held by the client
@@ -123,6 +101,7 @@ def symlink_package(name: str, path: Path, is_building_wheel: bool) -> None:
     # So we force a copy, see #611
     use_symlink = os.name != "nt" and not is_building_wheel
     if use_symlink:
+        # path_to.symlink_to(path_from)
         os.symlink(src=path_from, dst=path_to)
     else:
         shutil.copytree(src=path_from, dst=path_to)
@@ -211,29 +190,31 @@ def get_flash_attention2_nvcc_archs_flags(cuda_version: int):


 def get_flash_attention2_extensions(cuda_version: int, extra_compile_args):
+    from torch.utils.cpp_extension import CUDAExtension
+
     nvcc_archs_flags = get_flash_attention2_nvcc_archs_flags(cuda_version)
     if not nvcc_archs_flags:
         return []

-    flash_root = os.path.join(this_dir, "third_party", "flash-attention")
-    cutlass_inc = os.path.join(flash_root, "csrc", "cutlass", "include")
-    if not os.path.exists(flash_root) or not os.path.exists(cutlass_inc):
+    flash_root = this_dir / "third_party" / "flash-attention"
+    cutlass_inc = flash_root / "csrc" / "cutlass" / "include"
+    if not flash_root.exists() or not cutlass_inc.exists():
         raise RuntimeError(
             "flashattention submodule not found. Did you forget "
             "to run `git submodule update --init --recursive` ?"
         )

     sources = ["csrc/flash_attn/flash_api.cpp"]
-    for f in glob.glob(os.path.join(flash_root, "csrc", "flash_attn", "src", "*.cu")):
-        if "hdim224" in Path(f).name:
+    for f in (flash_root / "csrc" / "flash_attn" / "src").glob("*.cu"):
+        if "hdim224" in f.name:
             continue
-        sources.append(str(Path(f).relative_to(flash_root)))
+        sources.append(str(f.relative_to(flash_root)))
     common_extra_compile_args = ["-DFLASHATTENTION_DISABLE_ALIBI"]

     return [
         CUDAExtension(
             name="xformers._C_flashattention",
-            sources=[os.path.join(flash_root, path) for path in sources],
+            sources=[str(flash_root / path) for path in sources],
             extra_compile_args={
                 "cxx": extra_compile_args.get("cxx", []) + common_extra_compile_args,
                 "nvcc": extra_compile_args.get("nvcc", [])
@@ -256,9 +237,9 @@ def get_flash_attention2_extensions(cuda_version: int, extra_compile_args):
             include_dirs=[
                 p.absolute()
                 for p in [
-                    Path(flash_root) / "csrc" / "flash_attn",
-                    Path(flash_root) / "csrc" / "flash_attn" / "src",
-                    Path(flash_root) / "csrc" / "cutlass" / "include",
+                    flash_root / "csrc" / "flash_attn",
+                    flash_root / "csrc" / "flash_attn" / "src",
+                    flash_root / "csrc" / "cutlass" / "include",
                 ]
             ],
         )
@@ -269,6 +250,8 @@ def get_flash_attention2_extensions(cuda_version: int, extra_compile_args):
 # FLASH-ATTENTION v3
 ######################################
 def get_flash_attention3_nvcc_archs_flags(cuda_version: int):
+    import torch
+
     if os.getenv("XFORMERS_DISABLE_FLASH_ATTN", "0") != "0":
         return []
     if platform.system() != "Linux" or cuda_version < 1203:
@@ -297,29 +280,33 @@ def get_flash_attention3_nvcc_archs_flags(cuda_version: int):


 def get_flash_attention3_extensions(cuda_version: int, extra_compile_args):
+    from torch.utils.cpp_extension import CUDAExtension
+
     nvcc_archs_flags = get_flash_attention3_nvcc_archs_flags(cuda_version)
     if not nvcc_archs_flags:
         return []

-    flash_root = os.path.join(this_dir, "third_party", "flash-attention")
-    cutlass_inc = os.path.join(flash_root, "csrc", "cutlass", "include")
-    if not os.path.exists(flash_root) or not os.path.exists(cutlass_inc):
+    flash_root = this_dir / "third_party" / "flash-attention"
+    cutlass_inc = flash_root / "csrc" / "cutlass" / "include"
+    if not flash_root.exists() or not cutlass_inc.exists():
         raise RuntimeError(
             "flashattention submodule not found. Did you forget "
             "to run `git submodule update --init --recursive` ?"
         )

-    sources = [
-        str(Path(f).relative_to(flash_root))
-        for f in glob.glob(os.path.join(flash_root, "hopper", "*.cu"))
-        + glob.glob(os.path.join(flash_root, "hopper", "*.cpp"))
+    sources = []
+    sources += [
+        str(f.relative_to(flash_root)) for f in (flash_root / "hopper").glob("*.cu")
+    ]
+    sources += [
+        str(f.relative_to(flash_root)) for f in (flash_root / "hopper").glob("*.cpp")
     ]
     sources = [s for s in sources if "flash_bwd_hdim256_fp16_sm90.cu" not in s]

     return [
         CUDAExtension(
             name="xformers._C_flashattention3",
-            sources=[os.path.join(flash_root, path) for path in sources],
+            sources=[str(flash_root / path) for path in sources],
             extra_compile_args={
                 "cxx": extra_compile_args.get("cxx", []),
                 "nvcc": extra_compile_args.get("nvcc", [])
@@ -352,8 +339,8 @@ def get_flash_attention3_extensions(cuda_version: int, extra_compile_args):
             include_dirs=[
                 p.absolute()
                 for p in [
-                    Path(flash_root) / "csrc" / "cutlass" / "include",
-                    Path(flash_root) / "hopper",
+                    flash_root / "csrc" / "cutlass" / "include",
+                    flash_root / "hopper",
                 ]
             ],
         )
@@ -366,6 +353,29 @@ def rename_cpp_cu(cpp_files):


 def get_extensions():
+    import torch
+    from torch.utils.cpp_extension import (
+        CUDA_HOME,
+        ROCM_HOME,
+        CppExtension,
+        CUDAExtension,
+    )
+
+    pt_attn_compat_file_path = (
+        this_dir / "xformers" / "ops" / "fmha" / "torch_attention_compat.py"
+    )
+
+    # Define the module name
+    module_name = "torch_attention_compat"
+
+    # Load the module
+    spec = importlib.util.spec_from_file_location(module_name, pt_attn_compat_file_path)
+    assert spec is not None
+    attn_compat_module = importlib.util.module_from_spec(spec)
+    sys.modules[module_name] = attn_compat_module
+    assert spec.loader is not None
+    spec.loader.exec_module(attn_compat_module)
+
     extensions_dir = os.path.join("xformers", "csrc")

     sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)
@@ -392,7 +402,7 @@ def get_extensions():
         source_cuda = list(set(source_cuda) - set(source_hip_generated))
         sources = list(set(sources) - set(source_hip))

-    sputnik_dir = os.path.join(this_dir, "third_party", "sputnik")
+    sputnik_dir = this_dir / "third_party" / "sputnik"
     xformers_pt_cutlass_attn = os.getenv("XFORMERS_PT_CUTLASS_ATTN")

     # By default, we try to link to torch internal CUTLASS attention implementation
@@ -405,12 +415,12 @@ def get_extensions():
     ):
         source_cuda = list(set(source_cuda) - set(fmha_source_cuda))

-    cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include")
-    cutlass_util_dir = os.path.join(
-        this_dir, "third_party", "cutlass", "tools", "util", "include"
+    cutlass_dir = this_dir / "third_party" / "cutlass" / "include"
+    cutlass_util_dir = (
+        this_dir / "third_party" / "cutlass" / "tools" / "util" / "include"
     )
-    cutlass_examples_dir = os.path.join(this_dir, "third_party", "cutlass", "examples")
-    if not os.path.exists(cutlass_dir):
+    cutlass_examples_dir = this_dir / "third_party" / "cutlass" / "examples"
+    if not cutlass_dir.exists():
         raise RuntimeError(
             f"CUTLASS submodule not found at {cutlass_dir}. "
             "Did you forget to run "
@@ -438,7 +448,7 @@ def get_extensions():
     use_pt_flash = False

     if (
-        (torch.cuda.is_available() and ((CUDA_HOME is not None)))
+        (torch.cuda.is_available() and (CUDA_HOME is not None))
         or os.getenv("FORCE_CUDA", "0") == "1"
         or os.getenv("TORCH_CUDA_ARCH_LIST", "") != ""
     ):
@@ -446,10 +456,10 @@ def get_extensions():
         extension = CUDAExtension
         sources += source_cuda
         include_dirs += [
-            sputnik_dir,
-            cutlass_dir,
-            cutlass_util_dir,
-            cutlass_examples_dir,
+            str(sputnik_dir),
+            str(cutlass_dir),
+            str(cutlass_util_dir),
+            str(cutlass_examples_dir),
         ]
         nvcc_flags = [
             "-DHAS_PYTORCH",
@@ -534,12 +544,10 @@ def get_extensions():
         extension = CUDAExtension
         sources += source_hip_cu

-        include_dirs += [
-            Path(this_dir) / "xformers" / "csrc" / "attention" / "hip_fmha"
-        ]
+        include_dirs += [this_dir / "xformers" / "csrc" / "attention" / "hip_fmha"]

         include_dirs += [
-            Path(this_dir) / "third_party" / "composable_kernel_tiled" / "include"
+            this_dir / "third_party" / "composable_kernel_tiled" / "include"
         ]

         generator_flag = []
@@ -627,41 +635,61 @@ def run(self):
         distutils.command.clean.clean.run(self)


-class BuildExtensionWithExtraFiles(BuildExtension):
-    def __init__(self, *args, **kwargs) -> None:
-        self.xformers_build_metadata = kwargs.pop("extra_files")
-        self.pkg_name = "xformers"
-        super().__init__(*args, **kwargs)
-
-    def build_extensions(self) -> None:
-        super().build_extensions()
-        for filename, content in self.xformers_build_metadata.items():
-            with open(
-                os.path.join(self.build_lib, self.pkg_name, filename), "w+"
-            ) as fp:
-                fp.write(content)
-
-    def copy_extensions_to_source(self) -> None:
-        """
-        Used for `pip install -e .`
-        Copies everything we built back into the source repo
-        """
-        build_py = self.get_finalized_command("build_py")
-        package_dir = build_py.get_package_dir(self.pkg_name)
-
-        for filename in self.xformers_build_metadata.keys():
-            inplace_file = os.path.join(package_dir, filename)
-            regular_file = os.path.join(self.build_lib, self.pkg_name, filename)
-            self.copy_file(regular_file, inplace_file, level=self.verbose)
-        super().copy_extensions_to_source()
+class TorchBuildExtension(build_ext):
+    def run(self):
+        from torch.utils.cpp_extension import BuildExtension
+
+        class BuildExtensionWithExtraFiles(BuildExtension):
+            def __init__(self, *args, **kwargs) -> None:
+                self.xformers_build_metadata = kwargs.pop("extra_files")
+                self.pkg_name = "xformers"
+                super().__init__(*args, **kwargs)
+
+            def build_extensions(self) -> None:
+                super().build_extensions()
+                for filename, content in self.xformers_build_metadata.items():
+                    with open(
+                        os.path.join(self.build_lib, self.pkg_name, filename), "w+"
+                    ) as fp:
+                        fp.write(content)
+
+            def copy_extensions_to_source(self) -> None:
+                """
+                Used for `pip install -e .`
+                Copies everything we built back into the source repo
+                """
+                build_py = self.get_finalized_command("build_py")
+                package_dir = build_py.get_package_dir(self.pkg_name)
+
+                for filename in self.xformers_build_metadata.keys():
+                    inplace_file = os.path.join(package_dir, filename)
+                    regular_file = os.path.join(self.build_lib, self.pkg_name, filename)
+                    self.copy_file(regular_file, inplace_file, level=self.verbose)
+                super().copy_extensions_to_source()
+
+        extensions, extensions_metadata = get_extensions()
+
+        setuptools.setup(
+            ext_modules=extensions,
+            cmdclass={
+                "build_ext": BuildExtensionWithExtraFiles.with_options(
+                    no_python_abi_suffix=True,
+                    extra_files={
+                        "cpp_lib.json": json.dumps(extensions_metadata),
+                        "version.py": generate_version_py(version),
+                    },
+                ),
+                "clean": clean,
+            },
+        )


 if __name__ == "__main__":
     if os.getenv("BUILD_VERSION"):  # In CI
         version = os.getenv("BUILD_VERSION", "0.0.0")
     else:
-        version_txt = os.path.join(this_dir, "version.txt")
-        with open(version_txt) as f:
+        version_txt = this_dir / "version.txt"
+        with version_txt.open() as f:
             version = f.readline().strip()
         version += get_local_version_suffix()

@@ -676,24 +704,13 @@ def copy_extensions_to_source(self) -> None:
         Path("third_party") / "flash-attention" / "flash_attn",
         is_building_wheel,
     )
-    extensions, extensions_metadata = get_extensions()
     setuptools.setup(
         name="xformers",
         description="XFormers: A collection of composable Transformer building blocks.",
         version=version,
         install_requires=fetch_requirements(),
         packages=setuptools.find_packages(exclude=("tests*", "benchmarks*")),
-        ext_modules=extensions,
-        cmdclass={
-            "build_ext": BuildExtensionWithExtraFiles.with_options(
-                no_python_abi_suffix=True,
-                extra_files={
-                    "cpp_lib.json": json.dumps(extensions_metadata),
-                    "version.py": generate_version_py(version),
-                },
-            ),
-            "clean": clean,
-        },
+        cmdclass={"build_ext": TorchBuildExtension},
         url="https://facebookresearch.github.io/xformers/",
         python_requires=">=3.7",
         author="Facebook AI Research",
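
The core of the patch: setup.py no longer imports torch (or torch.utils.cpp_extension) at module level, so building an sdist or resolving metadata works without torch pre-installed. All torch-dependent imports are deferred into get_extensions() and into TorchBuildExtension.run(), and run() re-enters setuptools.setup() with the torch-aware BuildExtension once the extensions are actually compiled. Below is a minimal standalone sketch of that deferred-import pattern; the package name "example" and the class name "DeferredTorchBuildExt" are hypothetical placeholders, not code from this patch.

    # deferred_torch_setup.py -- sketch only, assuming a package "example" with a
    # C++ binding at example/binding.cpp. Nothing at module level needs torch.
    import setuptools
    from setuptools.command.build_ext import build_ext


    class DeferredTorchBuildExt(build_ext):
        """Hypothetical build_ext that pulls in torch only when extensions are built."""

        def run(self):
            # Deferred imports: torch is required here, at build time, not at import time.
            from torch.utils.cpp_extension import BuildExtension, CppExtension

            ext = CppExtension(name="example._C", sources=["example/binding.cpp"])
            # Same trick as TorchBuildExtension in the patch: re-enter setup() with
            # the torch-aware BuildExtension and the now-constructed extension modules.
            setuptools.setup(
                ext_modules=[ext],
                cmdclass={
                    "build_ext": BuildExtension.with_options(no_python_abi_suffix=True)
                },
            )


    setuptools.setup(
        name="example",
        packages=setuptools.find_packages(),
        cmdclass={"build_ext": DeferredTorchBuildExt},
    )

The outer setup() call carries only metadata and pure-Python packages, so torch is touched only when the build_ext command actually runs.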