forked from ivan-chai/torch-linear-assignment
-
Notifications
You must be signed in to change notification settings - Fork 0
/
setup.py
68 lines (53 loc) · 1.84 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
from setuptools import setup
def _read_requirements(path):
    """Return the requirement specifiers listed in the file at *path*.

    Each line is stripped of surrounding whitespace; blank lines and
    full-line ``#`` comments are dropped so only real specifiers remain.
    """
    with open(path, "r", encoding="utf-8") as fp:
        stripped = (line.strip() for line in fp)
        return [line for line in stripped if line and not line.startswith("#")]


# Resolve requirements.txt next to this setup.py rather than relative to the
# current working directory, so running the script from elsewhere still works.
required_packages = _read_requirements(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "requirements.txt")
)
def _env_flag(name: str, default: bool) -> bool:
    """Interpret the environment variable *name* as an on/off switch.

    Unset or blank values yield *default*. Integer values keep the previous
    ``int(...)`` semantics (0 is off, any other integer is on). The words
    true/yes/on and false/no/off (case-insensitive) are also accepted.

    Raises:
        ValueError: if the variable is set to an unrecognised value.
    """
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return default
    value = raw.strip().lower()
    try:
        return int(value) != 0
    except ValueError:
        pass
    if value in {"true", "yes", "on"}:
        return True
    if value in {"false", "no", "off"}:
        return False
    raise ValueError(
        f"{name}={raw!r} is not a valid switch; use 0/1, true/false, yes/no or on/off"
    )


def is_cuda() -> bool:
    """Return True when the CUDA backend should be built.

    All three conditions must hold, checked in this order:
    torch was compiled with CUDA support, the ``TLA_BUILD_CUDA`` switch is
    not turned off (it defaults to on), and ``torch.cuda.is_available()``
    reports a usable device at build time. Otherwise the CPU backend is used.
    """
    import torch
    # bool() so callers always get a real bool, not the int 0 from the chain.
    return bool(
        torch.backends.cuda.is_built()
        and _env_flag("TLA_BUILD_CUDA", default=True)
        and torch.cuda.is_available()
    )
def generate_cuda_ext_modules() -> list:
    """Return the extension list for the CUDA backend.

    Builds ``torch_linear_assignment._backend`` from the C++ binding and the
    ``.cu`` kernel. The C++ compiler gets ``-O3``. When the ``CC`` environment
    variable names a compiler, it is forwarded to nvcc as its host compiler
    via ``-ccbin``.
    """
    import torch.utils.cpp_extension as torch_cpp_ext

    compile_args = {"cxx": ["-O3"]}
    host_compiler = os.environ.get("CC", "").strip()
    # A blank CC (e.g. `CC= pip install .`) would otherwise produce
    # `-ccbin ""` and make nvcc fail, so treat it the same as unset.
    if host_compiler:
        compile_args["nvcc"] = ["-ccbin", host_compiler]
    return [
        torch_cpp_ext.CUDAExtension(
            "torch_linear_assignment._backend",
            [
                "src/torch_linear_assignment_cuda.cpp",
                "src/torch_linear_assignment_cuda_kernel.cu",
            ],
            extra_compile_args=compile_args,
        )
    ]
def generate_cpu_ext_modules() -> list:
    """Return the extension list for the CPU-only backend.

    The list holds a single C++ extension, ``torch_linear_assignment._backend``,
    compiled from ``src/torch_linear_assignment.cpp`` with ``-O3``.
    """
    from torch.utils.cpp_extension import CppExtension

    cpp_sources = ["src/torch_linear_assignment.cpp"]
    optimisation_flags = {"cxx": ["-O3"]}
    backend = CppExtension(
        "torch_linear_assignment._backend",
        cpp_sources,
        extra_compile_args=optimisation_flags,
    )
    return [backend]
def get_build_ext():
    """Return PyTorch's ``BuildExtension`` class for use as ``build_ext``.

    torch is imported lazily, when this function is called.
    """
    from torch.utils import cpp_extension

    return cpp_extension.BuildExtension
if __name__ == '__main__':
    # Pick the native backend first. is_cuda() imports torch, so torch must
    # already be installed before setup() is reached.
    if is_cuda():
        backend_modules = generate_cuda_ext_modules()
    else:
        backend_modules = generate_cpu_ext_modules()
    build_ext_cls = get_build_ext()

    setup(
        name="torch-linear-assignment",
        version="0.0.1.post1",
        author="Ivan Karpukhin",
        # NOTE(review): this value looks like a web-page email-obfuscation
        # placeholder rather than a real address; confirm the intended email.
        author_email="[email protected]",
        description="Batched linear assignment with PyTorch and CUDA.",
        packages=["torch_linear_assignment"],
        ext_modules=backend_modules,
        # NOTE(review): setup_requires is deprecated in setuptools, and it
        # cannot provide torch here because is_cuda() imports it above.
        setup_requires=required_packages,
        install_requires=required_packages,
        cmdclass={"build_ext": build_ext_cls},
    )