Skip to content

Commit

Permalink
add build meta
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 committed Jan 28, 2024
1 parent 4c9ad8a commit c223b72
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
26 changes: 26 additions & 0 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
limitations under the License.
"""
import pathlib
import os
import re
import datetime
import subprocess
import platform

import setuptools
import torch
import torch.utils.cpp_extension as torch_cpp_ext

root = pathlib.Path(__name__).parent.resolve().parent
Expand All @@ -44,6 +48,28 @@ def get_version():
return version


def get_cuda_version() -> tuple[int, int]:
if torch_cpp_ext.CUDA_HOME is None:
nvcc = "nvcc"
else:
nvcc = os.path.join(torch_cpp_ext.CUDA_HOME, "bin/nvcc")
txt = subprocess.check_output([nvcc, "--version"], text=True)
major, minor = map(int, re.findall(r"release (\d+)\.(\d+),", txt)[0])
return major, minor


def generate_build_meta() -> None:
d = {}
version = get_version()
d["cuda_major"], d["cuda_minor"] = get_cuda_version()
d["torch"] = torch.__version__
d["python"] = platform.python_version()
d["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
with open(root / "python/flashinfer/_build_meta.py", "w") as f:
f.write(f"__version__ = {version!r}\n")
f.write(f"build_meta = {d!r}")


def remove_unwanted_pytorch_nvcc_flags():
REMOVE_NVCC_FLAGS = [
"-D__CUDA_NO_HALF_OPERATORS__",
Expand Down
1 change: 0 additions & 1 deletion python/version.txt

This file was deleted.

0 comments on commit c223b72

Please sign in to comment.