forked from dusty-nv/jetson-containers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
build.sh
executable file
·37 lines (28 loc) · 1.56 KB
/
build.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#!/usr/bin/env bash
# JAX builder for Jetson (architecture: ARM64, CUDA support)
#
# Clones JAX, builds the jaxlib wheel, the CUDA PJRT plugin and CUDA kernel
# plugin wheels, and the pure-Python jax wheel. It then uploads them with twine
# (best-effort: an upload failure is reported but does not stop the build) and
# installs them into the current Python environment.
#
# Environment:
#   JAX_BUILD_VERSION     selects the "jaxlib-v<version>" tag to clone; if
#                         that clone fails, the default branch is cloned.
#   TWINE_REPOSITORY_URL  shown in the upload-failure message; twine itself
#                         takes its repository/credentials from the environment.
# Optional overrides (defaults reproduce the previously hard-coded build):
#   JAX_CUDA_COMPUTE_CAPABILITIES  default: sm_87
#   JAX_CUDA_VERSION               default: 12.6.0
#   JAX_CUDNN_VERSION              default: 9.4.0
#   JAX_LOCAL_CUDA_PATH            default: /usr/local/cuda-12.6
#   JAX_LOCAL_CUDNN_PATH           default: /opt/nvidia/cudnn/
#   JAX_WHEEL_DIR                  default: /opt/wheels
set -exo pipefail

cuda_compute_capabilities=${JAX_CUDA_COMPUTE_CAPABILITIES:-sm_87}
cuda_version=${JAX_CUDA_VERSION:-12.6.0}
cudnn_version=${JAX_CUDNN_VERSION:-9.4.0}
local_cuda_path=${JAX_LOCAL_CUDA_PATH:-/usr/local/cuda-12.6}
local_cudnn_path=${JAX_LOCAL_CUDNN_PATH:-/opt/nvidia/cudnn/}
wheel_dir=${JAX_WHEEL_DIR:-/opt/wheels}
# The CUDA plugin wheel names embed the CUDA major version (e.g. jax_cuda12_*).
cuda_major=${cuda_version%%.*}

# Install LLVM/clang dev packages.
# NOTE(review): ./llvm.sh is resolved relative to the current working directory.
./llvm.sh 18 all

echo "Building JAX for Jetson"

# Clone the tag matching the requested jaxlib version. If that fails (e.g. no
# such tag, or JAX_BUILD_VERSION is empty), fall back to the default branch.
git clone --branch "jaxlib-v${JAX_BUILD_VERSION}" --depth=1 --recursive https://github.com/google/jax /opt/jax || \
git clone --depth=1 --recursive https://github.com/google/jax /opt/jax
cd /opt/jax

# Build flags are kept in an array so each option reaches build.py as exactly
# one argument, without literal quote characters embedded in the values.
build_flags=(
  --enable_cuda
  --enable_nccl=False
  "--cuda_compute_capabilities=${cuda_compute_capabilities}"
  "--cuda_version=${cuda_version}"
  "--cudnn_version=${cudnn_version}"
  "--bazel_options=--repo_env=LOCAL_CUDA_PATH=${local_cuda_path}"
  "--bazel_options=--repo_env=LOCAL_CUDNN_PATH=${local_cudnn_path}"
  "--output_path=${wheel_dir}"
)

# First pass: jaxlib. Second pass: the CUDA kernel plugin and PJRT plugin.
python3 build/build.py "${build_flags[@]}"
python3 build/build.py "${build_flags[@]}" --build_gpu_kernel_plugin=cuda --build_gpu_plugin

# Build the jax pip wheel from the checked-out sources.
pip3 wheel --wheel-dir="${wheel_dir}" --no-deps --verbose .

# Upload the wheels to the mirror. This is best-effort: a failed upload is
# reported on stderr but does not abort the build.
for wheel_name in jaxlib "jax_cuda${cuda_major}_pjrt" "jax_cuda${cuda_major}_plugin" jax; do
  twine upload --verbose "${wheel_dir}/${wheel_name}"-*.whl \
    || echo "failed to upload wheel to ${TWINE_REPOSITORY_URL}" >&2
done

# Install them into the container.
pip3 install --verbose --no-cache-dir "${wheel_dir}"/jax*.whl