Skip to content

Commit

Permalink
multithread jmh
Browse files Browse the repository at this point in the history
  • Loading branch information
yuzawa-san committed Jun 7, 2023
1 parent 8030f9f commit c0e4ade
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Threads;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

Expand All @@ -39,9 +40,10 @@
public class Microbenchmark {

private static final String ONNXRUNTIME_JAVA = "onnxruntime-java";
private static final String ONNXRUNTIME_JAVA_ARENA = "onnxruntime-java-arena";
private static final String MICROSOFT = "microsoft";

@Param(value = {ONNXRUNTIME_JAVA, MICROSOFT})
@Param(value = {ONNXRUNTIME_JAVA, ONNXRUNTIME_JAVA_ARENA, MICROSOFT})
private String implementation;

@Param({"16", "256", "4096"})
Expand Down Expand Up @@ -77,12 +79,14 @@ public void setup() throws Exception {
input[i] = random.nextLong();
}
wrapper = switch (implementation) {
case ONNXRUNTIME_JAVA -> new OnnxruntimeJava(bytes);
case ONNXRUNTIME_JAVA -> new OnnxruntimeJava(bytes, false);
case ONNXRUNTIME_JAVA_ARENA -> new OnnxruntimeJava(bytes, true);
case MICROSOFT -> new Microsoft(bytes);
default -> throw new IllegalArgumentException();};
}

@Benchmark
@Threads(Threads.MAX)
public void run(Blackhole bh) throws Exception {
bh.consume(wrapper.evaluate(input));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public static final void main(String[] args) throws Exception {
.addOutput(ValueInfoProto.newBuilder().setName("output").setType(type)))
.build();
byte[] bytes = model.toByteArray();
List<Wrapper> wrappers = List.of(new OnnxruntimeJava(bytes), new Microsoft(bytes));
List<Wrapper> wrappers = List.of(new OnnxruntimeJava(bytes, false), new Microsoft(bytes));
long i = 0;
long startMs = System.currentTimeMillis();
while (i >= 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
package com.jyuzawa.onnxruntime_benchmark;

import com.jyuzawa.onnxruntime.Environment;
import com.jyuzawa.onnxruntime.ExecutionProvider;
import com.jyuzawa.onnxruntime.NamedCollection;
import com.jyuzawa.onnxruntime.OnnxRuntime;
import com.jyuzawa.onnxruntime.OnnxRuntimeLoggingLevel;
import com.jyuzawa.onnxruntime.OnnxValue;
import com.jyuzawa.onnxruntime.Session;
import com.jyuzawa.onnxruntime.Transaction;
import java.io.IOException;
import java.util.Map;

final class OnnxruntimeJava implements Wrapper {

Expand All @@ -23,8 +25,12 @@ final class OnnxruntimeJava implements Wrapper {

private final Session session;

OnnxruntimeJava(byte[] bytes) throws IOException {
this.session = ENVIRONMENT.newSession().setByteArray(bytes).build();
OnnxruntimeJava(byte[] bytes, boolean arena) throws IOException {
this.session = ENVIRONMENT
.newSession()
.setByteArray(bytes)
.addProvider(ExecutionProvider.CPU_EXECUTION_PROVIDER, Map.of("use_arena", arena ? "1" : "0"))
.build();
}

@Override
Expand Down

0 comments on commit c0e4ade

Please sign in to comment.