Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rust] Update gpu build pipeline to cu122 #3334

Merged
merged 6 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions .github/workflows/native_s3_huggingface.yml
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,16 @@ jobs:
aws s3 sync extensions/tokenizers/jnilib/$DJL_VERSION/linux-aarch64 s3://djl-ai/publish/tokenizers/${TOKENIZERS_VERSION}/jnilib/$DJL_VERSION/linux-aarch64/
aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/tokenizers/${TOKENIZERS_VERSION}/jnilib/*"

build-tokenizers-jni-cu124:
build-tokenizers-jni-cu122:
if: github.repository == 'deepjavalibrary/djl'
runs-on: [ self-hosted, g5 ]
timeout-minutes: 30
needs: create-runners
container:
image: nvidia/cuda:12.4.1-cudnn-devel-ubuntu20.04
image: nvidia/cuda:12.2.2-cudnn8-devel-ubuntu20.04
options: --gpus all --runtime=nvidia
env:
CUDA_VERSION: cu122
steps:
- name: Install Environment
run: |
Expand Down Expand Up @@ -254,9 +256,8 @@ jobs:
${{ runner.os }}-gradle-
- name: Release JNI prep
run: |
CUDA_VERSION=cu124
. "$HOME/.cargo/env"
./gradlew :extensions:tokenizers:compileJNI -Pcuda=$CUDA_VERSION
./gradlew :extensions:tokenizers:compileJNI -Pcuda=${{ env.CUDA_VERSION }}
./gradlew -Pjni :extensions:tokenizers:test
- name: Configure AWS Credentials
uses: aws-actions/configure-aws-credentials@v2
Expand All @@ -268,13 +269,13 @@ jobs:
run: |
DJL_VERSION=$(awk -F '=' '/djl / {gsub(/ ?"/, "", $2); print $2}' gradle/libs.versions.toml)
TOKENIZERS_VERSION="$(awk -F '=' '/tokenizers/ {gsub(/ ?"/, "", $2); print $2}' gradle/libs.versions.toml)"
aws s3 sync extensions/tokenizers/jnilib/$DJL_VERSION/linux-x86_64/cu124 s3://djl-ai/publish/tokenizers/${TOKENIZERS_VERSION}/jnilib/$DJL_VERSION/linux-x86_64/cu124/
aws s3 sync extensions/tokenizers/jnilib/$DJL_VERSION/linux-x86_64/${{ env.CUDA_VERSION }} s3://djl-ai/publish/tokenizers/${TOKENIZERS_VERSION}/jnilib/${DJL_VERSION}/linux-x86_64/${{ env.CUDA_VERSION }}/
aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/tokenizers/${TOKENIZERS_VERSION}/jnilib/*"

stop-runners:
if: ${{ github.repository == 'deepjavalibrary/djl' && always() }}
runs-on: [ self-hosted, scheduler ]
needs: [ create-runners, build-tokenizer-jni-aarch64, build-tokenizers-jni-cu124 ]
needs: [ create-runners, build-tokenizer-jni-aarch64, build-tokenizers-jni-cu122 ]
steps:
- name: Stop all instances
run: |
Expand Down
6 changes: 4 additions & 2 deletions extensions/tokenizers/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ javac -sourcepath src/main/java/ src/main/java/ai/djl/huggingface/tokenizers/jni
javac -sourcepath src/main/java/ src/main/java/ai/djl/engine/rust/RustLibrary.java -h build/include -d build/classes

RUST_MANIFEST=rust/Cargo.toml
if [ -x "$(command -v nvcc)" ]; then
cargo build --manifest-path $RUST_MANIFEST --release --features cuda,flash-attn,cublaslt
if [[ "$FLAVOR" = "cpu"* ]]; then
cargo build --manifest-path $RUST_MANIFEST --release
elif [[ "$FLAVOR" = "cu"* && "$FLAVOR" > "cu121" ]]; then
cargo build --manifest-path $RUST_MANIFEST --release --features cuda,cublaslt,flash-attn
else
cargo build --manifest-path $RUST_MANIFEST --release
fi
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public final class LibUtils {
private static final Pattern VERSION_PATTERN =
Pattern.compile(
"(\\d+\\.\\d+\\.\\d+(-[a-z]+)?)-(\\d+\\.\\d+\\.\\d+)(-SNAPSHOT)?(-\\d+)?");
private static final String FLAVOR_CU124 = "cu124";
private static final int[] SUPPORTED_CUDA_VERSIONS = {122};

private static EngineException exception;

Expand Down Expand Up @@ -89,29 +89,68 @@ private static Path copyJniLibrary(String[] libs) {
Platform platform = Platform.detectPlatform("tokenizers");
String os = platform.getOsPrefix();
String classifier = platform.getClassifier();
String flavor = platform.getFlavor();
String version = platform.getVersion();
String flavor = Utils.getEnvOrSystemProperty("TOKENIZERS_FLAVOR");
boolean override = flavor != null && !flavor.isEmpty();
if (override) {
logger.info("Uses override TOKENIZERS_FLAVOR: {}", flavor);
} else {
if (Utils.isOfflineMode() || "win".equals(os)) {
flavor = "cpu";
} else {
flavor = platform.getFlavor();
}
}

// Find the highest matching CUDA version
if (flavor.startsWith("cu")) {
int cudaVersion = Integer.parseInt(flavor.substring(2, 5));
boolean match = false;
for (int v : SUPPORTED_CUDA_VERSIONS) {
if (override && cudaVersion == v) {
match = true;
break;
} else if (cudaVersion >= v) {
flavor = "cu" + v;
match = true;
break;
}
}
if (!match) {
logger.warn("No matching cuda flavor for {} found: {}.", classifier, flavor);
flavor = "cpu"; // Fallback to CPU
}
}

Path dir = cacheDir.resolve(version + '-' + flavor + '-' + classifier);
Path path = dir.resolve(LIB_NAME);
logger.debug("Using cache dir: {}", dir);
if (Files.exists(path)) {
return dir.toAbsolutePath();
}

// For Linux cuda 12.x, download JNI library
if (flavor.startsWith("cu12") && !"win".equals(os)) {
// Copy JNI library from classpath
if (copyJniLibraryFromClasspath(libs, cacheDir, dir, classifier, flavor)) {
return dir.toAbsolutePath();
}

// Download JNI library
if (flavor.startsWith("cu")) {
Matcher matcher = VERSION_PATTERN.matcher(version);
if (!matcher.matches()) {
throw new EngineException("Unexpected version: " + version);
}
String jniVersion = matcher.group(1);
String djlVersion = matcher.group(3);

downloadJniLib(dir, path, djlVersion, jniVersion, classifier, FLAVOR_CU124);
downloadJniLib(dir, path, djlVersion, jniVersion, classifier, flavor);
return dir.toAbsolutePath();
}
return null;
}

// Extract JNI library from classpath
private static boolean copyJniLibraryFromClasspath(
String[] libs, Path cacheDir, Path dir, String classifier, String flavor) {
Path tmp = null;
try {
Files.createDirectories(cacheDir);
Expand All @@ -126,14 +165,15 @@ private static Path copyJniLibrary(String[] libs) {
}
}
Utils.moveQuietly(tmp, dir);
return dir.toAbsolutePath();
return true;
} catch (IOException e) {
throw new IllegalStateException("Cannot copy jni files", e);
logger.error("Cannot copy jni files", e);
} finally {
if (tmp != null) {
Utils.deleteQuietly(tmp);
}
}
return false;
}

private static void downloadJniLib(
Expand Down
Loading