packaging: Use setuptools-scm (#666)
ur4t authored Dec 16, 2024
1 parent 51236c9 commit 08405a5
Showing 8 changed files with 28 additions and 33 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -9,6 +9,7 @@ compile_commands.json
 csrc/generated/
 docs/generated/
 flashinfer/_build_meta.py
+flashinfer/_version.py
 flashinfer/data/
 flashinfer/jit/aot_config.py
 src/generated/
8 changes: 5 additions & 3 deletions docs/conf.py
@@ -1,6 +1,8 @@
 import os
 import sys
 from pathlib import Path
+from setuptools_scm import get_version
+from packaging.version import Version
 
 # import tlcpack_sphinx_addon
 # Configuration file for the Sphinx documentation builder.
@@ -20,9 +22,9 @@
 author = "FlashInfer Contributors"
 copyright = f"2023-2024, {author}"
 
-package_version = (root / "version.txt").read_text().strip()
-version = package_version
-release = package_version
+package_version = Version(get_version(root=root, version_scheme="only-version"))
+version = str(package_version)
+release = package_version.base_version
 
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
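Side note (not part of the commit): a rough illustration of what the two new conf.py values evaluate to, assuming setuptools-scm resolves the checkout to a hypothetical "0.2.0+g1234abc"; the real string depends on the tags and commit of the clone.

# Hypothetical values, for illustration only.
from packaging.version import Version

package_version = Version("0.2.0+g1234abc")  # stand-in for get_version(root=root, version_scheme="only-version")
version = str(package_version)               # "0.2.0+g1234abc" -- used as the Sphinx version
release = package_version.base_version       # "0.2.0"          -- local/pre/post/dev parts stripped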
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -3,3 +3,4 @@ sphinx == 8.1.3
 sphinx-reredirects == 0.1.5
 sphinx-tabs == 3.4.5
 sphinx-toolbox == 3.8.1
+setuptools-scm == 8.1.0
2 changes: 1 addition & 1 deletion flashinfer/__init__.py
@@ -88,4 +88,4 @@
 from .sampling import top_p_sampling_from_probs as top_p_sampling_from_probs
 from .sparse import BlockSparseAttentionWrapper as BlockSparseAttentionWrapper
 
-from ._build_meta import __version__ as __version__
+from ._version import __version__ as __version__
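Side note (not part of the commit): user-facing access to the version is unchanged; only its origin moves from the hand-written _build_meta.py to the setuptools-scm-generated _version.py. For example:

# Illustration only; the printed value depends on the installed build.
import flashinfer
print(flashinfer.__version__)  # e.g. "0.2.0"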
7 changes: 5 additions & 2 deletions pyproject.toml
@@ -23,7 +23,7 @@ urls = { Homepage = "https://github.com/flashinfer-ai/flashinfer" }
 dynamic = ["dependencies", "version"]
 
 [build-system]
-requires = ["setuptools"]
+requires = ["setuptools", "setuptools-scm"]
 build-backend = "custom_backend"
 backend-path = ["."]
 
@@ -48,13 +48,16 @@ include-package-data = false
 "flashinfer.data" = [
     "csrc/**",
     "include/**",
-    "version.txt"
 ]
 "flashinfer.data.cutlass" = [
     "include/**",
     "tools/util/include/**"
 ]
 
+[tool.setuptools_scm]
+version_file = "flashinfer/_version.py"
+version_scheme = "only-version"
+
 [tool.mypy]
 ignore_missing_imports = false
 show_column_numbers = true
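Side note (not part of the commit): with version_file set, setuptools-scm writes flashinfer/_version.py at build time, which is why that path is added to .gitignore above. A sketch of roughly what the generated module contains, assuming a hypothetical 0.2.0 tag; the exact layout depends on the setuptools-scm release, and only __version__ is imported by flashinfer/__init__.py.

# flashinfer/_version.py -- generated by setuptools-scm; illustrative contents only.
__version__ = version = "0.2.0"
__version_tuple__ = version_tuple = (0, 2, 0)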
4 changes: 2 additions & 2 deletions scripts/run-ci-build-wheel.sh
@@ -32,13 +32,13 @@ echo "::endgroup::"
 
 echo "::group::Install build system"
 pip install ninja numpy
-pip install --upgrade setuptools wheel build
+pip install --upgrade build setuptools setuptools-scm wheel
 echo "::endgroup::"
 
 
 echo "::group::Build wheel for FlashInfer"
 cd "$PROJECT_ROOT"
-FLASHINFER_ENABLE_AOT=1 FLASHINFER_LOCAL_VERSION="cu${CUDA_MAJOR}${CUDA_MINOR}torch${FLASHINFER_CI_TORCH_VERSION}" python -m build --no-isolation --wheel
+FLASHINFER_ENABLE_AOT=1 python -m build --no-isolation --wheel
 python -m build --no-isolation --sdist
 ls -la dist/
 echo "::endgroup::"
37 changes: 13 additions & 24 deletions setup.py
@@ -26,6 +26,7 @@
 
 root = Path(__file__).parent.resolve()
 gen_dir = root / "csrc" / "generated"
+build_meta = root / "flashinfer" / "_build_meta.py"
 
 head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",")
 pos_encoding_modes = os.environ.get("FLASHINFER_POS_ENCODING_MODES", "0").split(",")
@@ -46,24 +47,6 @@
 enable_fp8 = os.environ.get("FLASHINFER_ENABLE_FP8", "1") == "1"
 
 
-def get_version():
-    build_version = os.environ.get("FLASHINFER_BUILD_VERSION")
-    if build_version is not None:
-        return build_version
-    package_version = (root / "version.txt").read_text().strip()
-    local_version = os.environ.get("FLASHINFER_LOCAL_VERSION")
-    if local_version is None:
-        return package_version
-    return f"{package_version}+{local_version}"
-
-
-def generate_build_meta(aot_build_meta: dict) -> None:
-    build_meta_str = f"__version__ = {get_version()!r}\n"
-    if len(aot_build_meta) != 0:
-        build_meta_str += f"build_meta = {aot_build_meta!r}\n"
-    (root / "flashinfer" / "_build_meta.py").write_text(build_meta_str)
-
-
 def generate_cuda() -> None:
     try:  # no aot_build_utils in sdist
         sys.path.append(str(root))
@@ -98,14 +81,16 @@ def generate_cuda() -> None:
 
 ext_modules = []
 cmdclass = {}
+use_scm_version = {}
 install_requires = ["torch", "ninja"]
-generate_build_meta({})
+build_meta.write_text("\n")
 generate_cuda()
 
 if enable_aot:
     import torch
     import torch.utils.cpp_extension as torch_cpp_ext
     from packaging.version import Version
+    from setuptools_scm.version import get_local_node_and_date as default
 
     def get_cuda_version() -> Version:
         if torch_cpp_ext.CUDA_HOME is None:
@@ -131,17 +116,21 @@ def __init__(self, *args, **kwargs) -> None:
             raise RuntimeError("FlashInfer requires sm75+")
 
     cuda_version = get_cuda_version()
-    torch_version = Version(torch.__version__).base_version
-    cmdclass["build_ext"] = NinjaBuildExtension
-    install_requires = [f"torch == {torch_version}"]
+    torch_full_version = Version(torch.__version__)
+    torch_version = f"{torch_full_version.major}.{torch_full_version.minor}"
+    local_version = f"cu{cuda_version.major}{cuda_version.minor}torch{torch_version}"
 
     aot_build_meta = {}
     aot_build_meta["cuda_major"] = cuda_version.major
     aot_build_meta["cuda_minor"] = cuda_version.minor
     aot_build_meta["torch"] = torch_version
     aot_build_meta["python"] = platform.python_version()
     aot_build_meta["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST")
-    generate_build_meta(aot_build_meta)
+    build_meta.write_text(f"build_meta = {aot_build_meta!r}\n")
+
+    cmdclass["build_ext"] = NinjaBuildExtension
+    use_scm_version["local_scheme"] = lambda x: f"{default(x)}.{local_version}"
+    install_requires = [f"torch == {torch_version}"]
 
     if enable_bf16:
         torch_cpp_ext.COMMON_NVCC_FLAGS.append("-DFLASHINFER_ENABLE_BF16")
@@ -230,9 +219,9 @@ def __init__(self, *args, **kwargs) -> None:
 ]
 
 setuptools.setup(
-    version=get_version(),
     ext_modules=ext_modules,
     cmdclass=cmdclass,
     options={"bdist_wheel": {"py_limited_api": "cp38"}},
     install_requires=install_requires,
+    use_scm_version=use_scm_version,
 )
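Side note (not part of the commit): with FLASHINFER_ENABLE_AOT=1 the wheel's local version tag is now assembled inside setup.py from the detected CUDA and torch versions rather than passed in via FLASHINFER_LOCAL_VERSION, which is why the CI script above drops that variable. A rough sketch of how the final version string might compose; every value below is hypothetical.

# Illustrative composition of an AOT wheel version; all values are made up.
scm_base = "0.2.0"               # from the latest tag (version_scheme = "only-version")
node = "+g08405a5"               # roughly what get_local_node_and_date returns off-tag
local_version = "cu124torch2.4"  # f"cu{cuda.major}{cuda.minor}torch{torch_version}"

wheel_version = f"{scm_base}{node}.{local_version}"
print(wheel_version)             # 0.2.0+g08405a5.cu124torch2.4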
1 change: 0 additions & 1 deletion version.txt

This file was deleted.
