Skip to content

Commit

Permalink
[Minor] Quick fix on python build for path and python3.8 (#133)
Browse files Browse the repository at this point in the history
1. This PR revert the changes of relative path in `setup.py`. See
comments:
#128 (comment)
2. Fix python3.8 build.
  • Loading branch information
esmeetu authored Feb 23, 2024
1 parent 4f65fbc commit a346b27
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
limitations under the License.
"""

from typing import List, Tuple

import pathlib
import os
import re
Expand All @@ -40,7 +42,7 @@
torch_cpp_ext.COMMON_NVCC_FLAGS.append("-DFLASHINFER_ENABLE_BF16")


def get_instantiation_cu() -> list[str]:
def get_instantiation_cu() -> List[str]:
prefix = "csrc/generated"
(root / prefix).mkdir(parents=True, exist_ok=True)
dtypes = {"fp16": "nv_half"}
Expand Down Expand Up @@ -154,12 +156,12 @@ def get_instantiation_cu() -> list[str]:
def get_version():
version = os.getenv("FLASHINFER_BUILD_VERSION")
if version is None:
with open((root / "version.txt").resolve()) as f:
with open(root / "version.txt") as f:
version = f.read().strip()
return version


def get_cuda_version() -> tuple[int, int]:
def get_cuda_version() -> Tuple[int, int]:
if torch_cpp_ext.CUDA_HOME is None:
nvcc = "nvcc"
else:
Expand Down Expand Up @@ -213,7 +215,7 @@ def remove_unwanted_pytorch_nvcc_flags():
]
+ get_instantiation_cu(),
include_dirs=[
str((root / "../include")).resolve(),
str(root.resolve() / "include"),
],
extra_compile_args={
"cxx": ["-O3"],
Expand All @@ -230,7 +232,7 @@ def remove_unwanted_pytorch_nvcc_flags():
license="Apache License 2.0",
description="FlashInfer: Kernel Library for LLM Serving",
url="https://github.com/flashinfer-ai/flashinfer",
python_requires=">=3.9",
python_requires=">=3.8",
ext_modules=ext_modules,
cmdclass={"build_ext": torch_cpp_ext.BuildExtension},
)

0 comments on commit a346b27

Please sign in to comment.