Skip to content

Commit

Permalink
optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
yuzawa-san committed Jan 10, 2024
1 parent 083b5d0 commit b92b302
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,10 @@ public int getNumber() {
}

public static final OnnxRuntimeExecutionMode forNumber(int number) {
switch (number) {
case 1:
return PARALLEL;
case 0:
default:
return SEQUENTIAL;
}
return switch (number) {
case 0 -> SEQUENTIAL;
case 1 -> PARALLEL;
default -> SEQUENTIAL;
};
}
}
62 changes: 21 additions & 41 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxRuntimeLoggingLevel.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.function.Supplier;

/**
* The level for the internal logger within the ONNX runtime.
Expand Down Expand Up @@ -51,19 +50,14 @@ public int getNumber() {
* @return the level, VERBOSE if not found
*/
public static final OnnxRuntimeLoggingLevel forNumber(int number) {
switch (number) {
case 1:
return INFO;
case 2:
return WARNING;
case 3:
return ERROR;
case 4:
return FATAL;
case 0:
default:
return VERBOSE;
}
return switch (number) {
case 0 -> VERBOSE;
case 1 -> INFO;
case 2 -> WARNING;
case 3 -> ERROR;
case 4 -> FATAL;
default -> VERBOSE;
};
}

@SuppressWarnings("unused")
Expand All @@ -74,33 +68,19 @@ private static final void logCallback(
MemoryAddress idAddress,
MemoryAddress locationAddress,
MemoryAddress messageAddress) {
String category = categoryAddress.address().getUtf8String(0);
String id = idAddress.address().getUtf8String(0);
String location = locationAddress.address().getUtf8String(0);
String message = messageAddress.address().getUtf8String(0);
Supplier<String> line = () -> new StringBuilder()
.append(category)
.append(' ')
.append(id)
.append(' ')
.append(location)
.append(' ')
.append(message)
.toString();
switch (OnnxRuntimeLoggingLevel.forNumber(level)) {
case VERBOSE:
LOG.log(Level.DEBUG, line);
break;
case INFO:
LOG.log(Level.INFO, line);
break;
case WARNING:
LOG.log(Level.WARNING, line);
break;
case FATAL:
case ERROR:
LOG.log(Level.ERROR, line);
break;
Level theLevel =
switch (OnnxRuntimeLoggingLevel.forNumber(level)) {
case VERBOSE -> Level.DEBUG;
case INFO -> Level.INFO;
case WARNING -> Level.WARNING;
case FATAL, ERROR -> Level.ERROR;
};
if (LOG.isLoggable(theLevel)) {
String category = categoryAddress.address().getUtf8String(0);
String id = idAddress.address().getUtf8String(0);
String location = locationAddress.address().getUtf8String(0);
String message = messageAddress.address().getUtf8String(0);
LOG.log(theLevel, category + ' ' + id + ' ' + location + ' ' + message);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,12 @@ public int getNumber() {
}

public static final OnnxRuntimeOptimizationLevel forNumber(int number) {
switch (number) {
case 1:
return ENABLE_BASIC;
case 2:
return ENABLE_EXTENDED;
case 99:
return ENABLE_ALL;
case 0:
default:
return DISABLE_ALL;
}
return switch (number) {
case 0 -> DISABLE_ALL;
case 1 -> ENABLE_BASIC;
case 2 -> ENABLE_EXTENDED;
case 99 -> ENABLE_ALL;
default -> DISABLE_ALL;
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,43 +57,26 @@ public int getNumber() {
* @return the level, UNDEFINED if not found
*/
public static final OnnxTensorElementDataType forNumber(int number) {
switch (number) {
case 1:
return FLOAT;
case 2:
return UINT8;
case 3:
return INT8;
case 4:
return UINT16;
case 5:
return INT16;
case 6:
return INT32;
case 7:
return INT64;
case 8:
return STRING;
case 9:
return BOOL;
case 10:
return FLOAT16;
case 11:
return DOUBLE;
case 12:
return UINT32;
case 13:
return UINT64;
case 14:
return COMPLEX64;
case 15:
return COMPLEX128;
case 16:
return BFLOAT16;
case 0:
default:
return UNDEFINED;
}
return switch (number) {
case 0 -> UNDEFINED;
case 1 -> FLOAT;
case 2 -> UINT8;
case 3 -> INT8;
case 4 -> UINT16;
case 5 -> INT16;
case 6 -> INT32;
case 7 -> INT64;
case 8 -> STRING;
case 9 -> BOOL;
case 10 -> FLOAT16;
case 11 -> DOUBLE;
case 12 -> UINT32;
case 13 -> UINT64;
case 14 -> COMPLEX64;
case 15 -> COMPLEX128;
case 16 -> BFLOAT16;
default -> UNDEFINED;
};
}

ValueLayout getValueLayout() {
Expand Down
27 changes: 10 additions & 17 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxType.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,15 @@ public int getNumber() {
* @return the level, UNKNOWN if not found
*/
public static final OnnxType forNumber(int number) {
switch (number) {
case 1:
return TENSOR;
case 2:
return SEQUENCE;
case 3:
return MAP;
case 4:
return OPAQUE;
case 5:
return SPARSETENSOR;
case 6:
return OPTIONAL;
case 0:
default:
return UNKNOWN;
}
return switch (number) {
case 0 -> UNKNOWN;
case 1 -> TENSOR;
case 2 -> SEQUENCE;
case 3 -> MAP;
case 4 -> OPAQUE;
case 5 -> SPARSETENSOR;
case 6 -> OPTIONAL;
default -> UNKNOWN;
};
}
}

0 comments on commit b92b302

Please sign in to comment.