-
Notifications
You must be signed in to change notification settings - Fork 2
/
setup.py
49 lines (44 loc) · 1.85 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from setuptools import setup, Extension
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
import os
def find_cpp_files(directory, allow_cuda):
files = os.listdir(directory)
out = []
for f in files:
if f.endswith('.cpp') or (allow_cuda and f.endswith('.cu')):
out.append(os.path.join(directory, f))
elif os.path.isdir(f):
out = out + find_cpp_files(os.path.join(directory, f))
return out
# Detect if we have cuda installed and compile accordingly
native_ext = None
if 'CUDA_HOME' in os.environ or 'CUDA_PATH' in os.environ:
print('Detected CUDA, compiling with CUDA acceleration...')
native_ext = CUDAExtension(name='numml_torch_cpp',
sources=find_cpp_files('cpp', allow_cuda=True),
include_dirs=[
os.path.join(os.getcwd(), 'ext/cuCollections/include')
],
extra_compile_args={
'nvcc': ['-std=c++17'],
'cxx': ['-DCUDA_ENABLED=1', '-O2', '-std=c++17', '-w'],
})
else:
print('No CUDA detected, compiling CPU implementation only...')
native_ext = CppExtension(name='numml_torch_cpp',
sources=find_cpp_files('cpp', allow_cuda=False),
extra_compile_args=[
'-std=c++17',
'-DCUDA_ENABLED=0',
'-O2'
])
setup(name='numml',
version='0.0.1',
ext_modules=[native_ext],
cmdclass={
'build_ext': BuildExtension
},
author='Nicolas Nytko',
author_email='[email protected]',
packages=['numml', 'numml.sparse']
)