Skip to content

Commit

Permalink
Merge pull request #179 from yuzawa-san/iobinding2
Browse files Browse the repository at this point in the history
Add support for simple IoBinding
  • Loading branch information
yuzawa-san authored Jan 4, 2024
2 parents 3db50d5 + 0ef50ed commit e5aba6d
Show file tree
Hide file tree
Showing 12 changed files with 397 additions and 15 deletions.
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version=1.3.2-SNAPSHOT
version=1.4.0-SNAPSHOT
com.jyuzawa.onnxruntime.library_version=1.16.0
com.jyuzawa.onnxruntime.library_baseline=1.3.0
org.gradle.parallel=true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ public class Microbenchmark {

private static final String ONNXRUNTIME_JAVA = "onnxruntime-java";
private static final String ONNXRUNTIME_JAVA_ARENA = "onnxruntime-java-arena";
private static final String ONNXRUNTIME_JAVA_IOBINDING = "onnxruntime-java-iobinding";
private static final String MICROSOFT = "microsoft";

@Param(value = {ONNXRUNTIME_JAVA, ONNXRUNTIME_JAVA_ARENA, MICROSOFT})
@Param(value = {ONNXRUNTIME_JAVA, ONNXRUNTIME_JAVA_ARENA, ONNXRUNTIME_JAVA_IOBINDING, MICROSOFT})
private String implementation;

@Param({"16", "256", "4096"})
Expand Down Expand Up @@ -79,8 +80,9 @@ public void setup() throws Exception {
input[i] = random.nextLong();
}
wrapper = switch (implementation) {
case ONNXRUNTIME_JAVA -> new OnnxruntimeJava(bytes, false);
case ONNXRUNTIME_JAVA_ARENA -> new OnnxruntimeJava(bytes, true);
case ONNXRUNTIME_JAVA -> new OnnxruntimeJava(bytes, false, size);
case ONNXRUNTIME_JAVA_ARENA -> new OnnxruntimeJava(bytes, true, size);
case ONNXRUNTIME_JAVA_IOBINDING -> new OnnxruntimeJavaIoBinding(bytes, false, size);
case MICROSOFT -> new Microsoft(bytes);
default -> throw new IllegalArgumentException();};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public static final void main(String[] args) throws Exception {
.addOutput(ValueInfoProto.newBuilder().setName("output").setType(type)))
.build();
byte[] bytes = model.toByteArray();
List<Wrapper> wrappers = List.of(new OnnxruntimeJava(bytes, false), new Microsoft(bytes));
List<Wrapper> wrappers = List.of(new OnnxruntimeJava(bytes, false, input.length), new Microsoft(bytes));
long i = 0;
long startMs = System.currentTimeMillis();
while (i >= 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,24 @@
import java.io.IOException;
import java.util.Map;

final class OnnxruntimeJava implements Wrapper {
class OnnxruntimeJava implements Wrapper {

private static final Environment ENVIRONMENT = OnnxRuntime.get()
.getApi()
.newEnvironment()
.setLogSeverityLevel(OnnxRuntimeLoggingLevel.WARNING)
.build();

private final Session session;
protected final Session session;
protected final long[] out;

OnnxruntimeJava(byte[] bytes, boolean arena) throws IOException {
OnnxruntimeJava(byte[] bytes, boolean arena, int size) throws IOException {
this.session = ENVIRONMENT
.newSession()
.setByteArray(bytes)
.addProvider(ExecutionProvider.CPU_EXECUTION_PROVIDER, Map.of("use_arena", arena ? "1" : "0"))
.build();
this.out = new long[size];
}

@Override
Expand All @@ -44,9 +46,8 @@ public long[] evaluate(long[] input) {
txn.addInput(0).asTensor().getLongBuffer().put(input);
txn.addOutput(0);
NamedCollection<OnnxValue> result = txn.run();
long[] output = new long[input.length];
result.get(0).asTensor().getLongBuffer().get(output);
return output;
result.get(0).asTensor().getLongBuffer().get(out);
return out;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
* SPDX-License-Identifier: MIT
*/
package com.jyuzawa.onnxruntime_benchmark;

import com.jyuzawa.onnxruntime.IoBinding;
import java.io.IOException;
import java.nio.LongBuffer;

final class OnnxruntimeJavaIoBinding extends OnnxruntimeJava {

private final IoBinding ioBinding;
private final LongBuffer inputBuf;
private final LongBuffer outputBuf;

OnnxruntimeJavaIoBinding(byte[] bytes, boolean arena, int size) throws IOException {
super(bytes, arena, size);
this.ioBinding = session.newIoBinding().bindInput(0).bindOutput(0).build();
this.inputBuf = ioBinding.getInputs().get(0).asTensor().getLongBuffer();
this.outputBuf = ioBinding.getOutputs().get(0).asTensor().getLongBuffer();
}

@Override
public void close() throws Exception {
ioBinding.close();
super.close();
}

@Override
public long[] evaluate(long[] input) {
inputBuf.clear().put(input);
ioBinding.run();
outputBuf.rewind().get(out);
return out;
}
}
10 changes: 10 additions & 0 deletions src/main/java/com/jyuzawa/onnxruntime/ApiImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ final class ApiImpl implements Api {
final AddRunConfigEntry AddRunConfigEntry;
final AddSessionConfigEntry AddSessionConfigEntry;
final AllocatorFree AllocatorFree;
final BindInput BindInput;
final BindOutput BindOutput;
final CastTypeInfoToMapTypeInfo CastTypeInfoToMapTypeInfo;
final CastTypeInfoToSequenceTypeInfo CastTypeInfoToSequenceTypeInfo;
final CastTypeInfoToTensorInfo CastTypeInfoToTensorInfo;
Expand All @@ -38,6 +40,7 @@ final class ApiImpl implements Api {
final CreateCpuMemoryInfo CreateCpuMemoryInfo;
final CreateCUDAProviderOptions CreateCUDAProviderOptions;
final CreateDnnlProviderOptions CreateDnnlProviderOptions;
final CreateIoBinding CreateIoBinding;
final CreateTensorRTProviderOptions CreateTensorRTProviderOptions;
final CreateEnvWithCustomLogger CreateEnvWithCustomLogger;
final CreateEnvWithCustomLoggerAndGlobalThreadPools CreateEnvWithCustomLoggerAndGlobalThreadPools;
Expand Down Expand Up @@ -93,6 +96,7 @@ final class ApiImpl implements Api {
final ReleaseCUDAProviderOptions ReleaseCUDAProviderOptions;
final ReleaseDnnlProviderOptions ReleaseDnnlProviderOptions;
final ReleaseEnv ReleaseEnv;
final ReleaseIoBinding ReleaseIoBinding;
final ReleaseMemoryInfo ReleaseMemoryInfo;
final ReleaseModelMetadata ReleaseModelMetadata;
final ReleaseRunOptions ReleaseRunOptions;
Expand All @@ -109,6 +113,7 @@ final class ApiImpl implements Api {
final RunOptionsSetRunLogVerbosityLevel RunOptionsSetRunLogVerbosityLevel;
final RunOptionsSetRunTag RunOptionsSetRunTag;
final RunOptionsSetTerminate RunOptionsSetTerminate;
final RunWithBinding RunWithBinding;
final SetGlobalDenormalAsZero SetGlobalDenormalAsZero;
final SetGlobalIntraOpNumThreads SetGlobalIntraOpNumThreads;
final SetGlobalInterOpNumThreads SetGlobalInterOpNumThreads;
Expand Down Expand Up @@ -153,6 +158,8 @@ final class ApiImpl implements Api {
this.AddRunConfigEntry = OrtApi.AddRunConfigEntry(memorySegment, memorySession);
this.AddSessionConfigEntry = OrtApi.AddSessionConfigEntry(memorySegment, memorySession);
this.AllocatorFree = OrtApi.AllocatorFree(memorySegment, memorySession);
this.BindInput = OrtApi.BindInput(memorySegment, memorySession);
this.BindOutput = OrtApi.BindOutput(memorySegment, memorySession);
this.CastTypeInfoToMapTypeInfo = OrtApi.CastTypeInfoToMapTypeInfo(memorySegment, memorySession);
this.CastTypeInfoToSequenceTypeInfo = OrtApi.CastTypeInfoToSequenceTypeInfo(memorySegment, memorySession);
this.CastTypeInfoToTensorInfo = OrtApi.CastTypeInfoToTensorInfo(memorySegment, memorySession);
Expand All @@ -163,6 +170,7 @@ final class ApiImpl implements Api {
this.CreateCUDAProviderOptions = OrtApi.CreateCUDAProviderOptions(memorySegment, memorySession);
this.CreateDnnlProviderOptions = OrtApi.CreateDnnlProviderOptions(memorySegment, memorySession);
this.CreateTensorRTProviderOptions = OrtApi.CreateTensorRTProviderOptions(memorySegment, memorySession);
this.CreateIoBinding = OrtApi.CreateIoBinding(memorySegment, memorySession);
this.CreateEnvWithCustomLogger = OrtApi.CreateEnvWithCustomLogger(memorySegment, memorySession);
this.CreateEnvWithCustomLoggerAndGlobalThreadPools =
OrtApi.CreateEnvWithCustomLoggerAndGlobalThreadPools(memorySegment, memorySession);
Expand Down Expand Up @@ -220,6 +228,7 @@ final class ApiImpl implements Api {
this.ReleaseCUDAProviderOptions = OrtApi.ReleaseCUDAProviderOptions(memorySegment, memorySession);
this.ReleaseDnnlProviderOptions = OrtApi.ReleaseDnnlProviderOptions(memorySegment, memorySession);
this.ReleaseEnv = OrtApi.ReleaseEnv(memorySegment, memorySession);
this.ReleaseIoBinding = OrtApi.ReleaseIoBinding(memorySegment, memorySession);
this.ReleaseMemoryInfo = OrtApi.ReleaseMemoryInfo(memorySegment, memorySession);
this.ReleaseModelMetadata = OrtApi.ReleaseModelMetadata(memorySegment, memorySession);
this.ReleaseRunOptions = OrtApi.ReleaseRunOptions(memorySegment, memorySession);
Expand All @@ -236,6 +245,7 @@ final class ApiImpl implements Api {
this.RunOptionsSetRunLogVerbosityLevel = OrtApi.RunOptionsSetRunLogVerbosityLevel(memorySegment, memorySession);
this.RunOptionsSetRunTag = OrtApi.RunOptionsSetRunTag(memorySegment, memorySession);
this.RunOptionsSetTerminate = OrtApi.RunOptionsSetTerminate(memorySegment, memorySession);
this.RunWithBinding = OrtApi.RunWithBinding(memorySegment, memorySession);
this.SetGlobalDenormalAsZero = OrtApi.SetGlobalDenormalAsZero(memorySegment, memorySession);
this.SetGlobalIntraOpNumThreads = OrtApi.SetGlobalIntraOpNumThreads(memorySegment, memorySession);
this.SetGlobalInterOpNumThreads = OrtApi.SetGlobalInterOpNumThreads(memorySegment, memorySession);
Expand Down
100 changes: 100 additions & 0 deletions src/main/java/com/jyuzawa/onnxruntime/IoBinding.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
* SPDX-License-Identifier: MIT
*/
package com.jyuzawa.onnxruntime;

import java.util.Map;

/**
* A representation of a model evaluation. Capable of reuse in repetitive runs. More efficient than {@link Transaction}. Only supports tensors. Input and outputs are pre-allocated and must be of fixed size. This class is NOT thread-safe.
*
* @since 1.4.0
*/
public interface IoBinding extends AutoCloseable {

/**
* Set the severity for logging for this specific transaction.
* Can override the environment's or session's logger's severity.
* @param level
* @return this
*/
IoBinding setLogSeverityLevel(OnnxRuntimeLoggingLevel level);

/**
* Set the verbosity for logging for this specific transaction.
* Can override the environment's or session's logger's verbosity.
* @param level
* @return this
*/
IoBinding setLogVerbosityLevel(int level);

/**
* Set the run tag (which is the logger id)
* @param runTag
* @return this
*/
IoBinding setRunTag(String runTag);

NamedCollection<OnnxValue> getInputs();

NamedCollection<OnnxValue> getOutputs();

/**
* Run the model evaluation.
*/
void run();

/**
* Frees the native resources (typically buffers) associated with this transaction.
*/
@Override
void close();

/**
* A builder of a {@link IoBinding}. Should NOT be reused. This class is NOT thread-safe.
*
* @since 1.0.0
*/
public interface Builder {
/**
* Add an input and get an OnnxValue to populate.
* @param name
* @return the value to be populated
*/
Builder bindInput(String name);

/**
* Add an input and get an OnnxValue to populate.
* @param index
* @return the value to be populated
*/
Builder bindInput(int index);

/**
* Request a specific output to be produced.
* @param name
*/
Builder bindOutput(String name);

/**
* Request a specific output to be produced.
* @param index
*/
Builder bindOutput(int index);

/**
* Set custom parameters for this transaction.
* @param config
* @return the builder
*/
Builder setConfigMap(Map<String, String> config);

/**
* Construct a {@link IoBinding}.
*
* @return a new instance
*/
IoBinding build();
}
}
Loading

0 comments on commit e5aba6d

Please sign in to comment.