Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 committed Jul 12, 2024
1 parent d71301c commit b8ace11
Showing 1 changed file with 47 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public final class LibUtils {
private static final Pattern VERSION_PATTERN =
Pattern.compile(
"(\\d+\\.\\d+\\.\\d+(-[a-z]+)?)-(\\d+\\.\\d+\\.\\d+)(-SNAPSHOT)?(-\\d+)?");
private static int[] SUPPORTED_CUDA_VERSIONS = {122};

private static EngineException exception;

Expand Down Expand Up @@ -88,7 +89,38 @@ private static Path copyJniLibrary(String[] libs) {
Platform platform = Platform.detectPlatform("tokenizers");
String os = platform.getOsPrefix();
String classifier = platform.getClassifier();
String flavor = platform.getFlavor();
String flavor = Utils.getEnvOrSystemProperty("TOKENIZERS_FLAVOR");
boolean override = flavor != null && !flavor.isEmpty();
if (override) {
logger.info("Uses override TOKENIZERS_FLAVOR: {}", flavor);
} else {
flavor = platform.getFlavor();
if (Utils.isOfflineMode() || "win".equals(os)) {
flavor = "cpu";
}
}

if (flavor.startsWith("cu")) {
int cudaVersion = Integer.parseInt(flavor.substring(2, 5));

// Find the highest matching CUDA version
boolean match = false;
for (int v : SUPPORTED_CUDA_VERSIONS) {
if (override && cudaVersion == v) {
match = true;
break;
} else if (cudaVersion >= v) {
flavor = "cu" + v;
match = true;
break;
}
}
if (!match) {
logger.warn("No matching cuda flavor for {} found: {}.", classifier, flavor);
flavor = "cpu"; // Fallback to CPU
}
}

String version = platform.getVersion();
Path dir = cacheDir.resolve(version + '-' + flavor + '-' + classifier);
Path path = dir.resolve(LIB_NAME);
Expand All @@ -97,53 +129,51 @@ private static Path copyJniLibrary(String[] libs) {
return dir.toAbsolutePath();
}

String resolvedFlavor = flavor;
if ("win".equals(os)) {
resolvedFlavor = "cpu";
} else if (flavor.startsWith("cu")) {
if ("cu122".compareTo(flavor) <= 0) { // cu122 onwards will resolve to cu122
resolvedFlavor = "cu122";
} else { // Else resolve to cpu
resolvedFlavor = "cpu";
}
// Copy JNI library from classpath
if (copyJniLibraryFromClasspath(libs, dir, classifier, flavor)) {
return dir.toAbsolutePath();
}

// Download JNI library
if (resolvedFlavor.startsWith("cu")) {
if (flavor.startsWith("cu")) {
Matcher matcher = VERSION_PATTERN.matcher(version);
if (!matcher.matches()) {
throw new EngineException("Unexpected version: " + version);
}
String jniVersion = matcher.group(1);
String djlVersion = matcher.group(3);

downloadJniLib(dir, path, djlVersion, jniVersion, classifier, resolvedFlavor);
downloadJniLib(dir, path, djlVersion, jniVersion, classifier, flavor);
return dir.toAbsolutePath();
}
return null;
}

// Extract JNI library from classpath
private static boolean copyJniLibraryFromClasspath(
String[] libs, Path cacheDir, String classifier, String flavor) {
Path tmp = null;
try {
Files.createDirectories(cacheDir);
tmp = Files.createTempDirectory(cacheDir, "tmp");

for (String libName : libs) {
String libPath = "native/lib/" + classifier + "/" + resolvedFlavor + "/" + libName;
String libPath = "native/lib/" + classifier + "/" + flavor + "/" + libName;
logger.info("Extracting {} to cache ...", libPath);
try (InputStream is = ClassLoaderUtils.getResourceAsStream(libPath)) {
Path target = tmp.resolve(libName);
Files.copy(is, target, StandardCopyOption.REPLACE_EXISTING);
}
}
Utils.moveQuietly(tmp, dir);
return dir.toAbsolutePath();
Utils.moveQuietly(tmp, cacheDir);
return true;
} catch (IOException e) {
throw new IllegalStateException("Cannot copy jni files", e);
logger.error("Cannot copy jni files", e);
} finally {
if (tmp != null) {
Utils.deleteQuietly(tmp);
}
}
return false;
}

private static void downloadJniLib(
Expand Down

0 comments on commit b8ace11

Please sign in to comment.