diff --git a/gradle.properties b/gradle.properties index 3e1dc03..7f0706f 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,4 +1,4 @@ -version=1.3.2-SNAPSHOT +version=1.4.0-SNAPSHOT com.jyuzawa.onnxruntime.library_version=1.16.0 com.jyuzawa.onnxruntime.library_baseline=1.3.0 org.gradle.parallel=true diff --git a/onnxruntime-benchmark/src/jmh/java/com/jyuzawa/onnxruntime_benchmark/Microbenchmark.java b/onnxruntime-benchmark/src/jmh/java/com/jyuzawa/onnxruntime_benchmark/Microbenchmark.java index fbf57de..95f73fe 100644 --- a/onnxruntime-benchmark/src/jmh/java/com/jyuzawa/onnxruntime_benchmark/Microbenchmark.java +++ b/onnxruntime-benchmark/src/jmh/java/com/jyuzawa/onnxruntime_benchmark/Microbenchmark.java @@ -41,9 +41,10 @@ public class Microbenchmark { private static final String ONNXRUNTIME_JAVA = "onnxruntime-java"; private static final String ONNXRUNTIME_JAVA_ARENA = "onnxruntime-java-arena"; + private static final String ONNXRUNTIME_JAVA_IOBINDING = "onnxruntime-java-iobinding"; private static final String MICROSOFT = "microsoft"; - @Param(value = {ONNXRUNTIME_JAVA, ONNXRUNTIME_JAVA_ARENA, MICROSOFT}) + @Param(value = {ONNXRUNTIME_JAVA, ONNXRUNTIME_JAVA_ARENA, ONNXRUNTIME_JAVA_IOBINDING, MICROSOFT}) private String implementation; @Param({"16", "256", "4096"}) @@ -79,8 +80,9 @@ public void setup() throws Exception { input[i] = random.nextLong(); } wrapper = switch (implementation) { - case ONNXRUNTIME_JAVA -> new OnnxruntimeJava(bytes, false); - case ONNXRUNTIME_JAVA_ARENA -> new OnnxruntimeJava(bytes, true); + case ONNXRUNTIME_JAVA -> new OnnxruntimeJava(bytes, false, size); + case ONNXRUNTIME_JAVA_ARENA -> new OnnxruntimeJava(bytes, true, size); + case ONNXRUNTIME_JAVA_IOBINDING -> new OnnxruntimeJavaIoBinding(bytes, false, size); case MICROSOFT -> new Microsoft(bytes); default -> throw new IllegalArgumentException();}; } diff --git a/onnxruntime-benchmark/src/main/java/com/jyuzawa/onnxruntime_benchmark/Benchmark.java b/onnxruntime-benchmark/src/main/java/com/jyuzawa/onnxruntime_benchmark/Benchmark.java index 42c9def..654fbbb 100644 --- a/onnxruntime-benchmark/src/main/java/com/jyuzawa/onnxruntime_benchmark/Benchmark.java +++ b/onnxruntime-benchmark/src/main/java/com/jyuzawa/onnxruntime_benchmark/Benchmark.java @@ -40,7 +40,7 @@ public static final void main(String[] args) throws Exception { .addOutput(ValueInfoProto.newBuilder().setName("output").setType(type))) .build(); byte[] bytes = model.toByteArray(); - List wrappers = List.of(new OnnxruntimeJava(bytes, false), new Microsoft(bytes)); + List wrappers = List.of(new OnnxruntimeJava(bytes, false, input.length), new Microsoft(bytes)); long i = 0; long startMs = System.currentTimeMillis(); while (i >= 0) { diff --git a/onnxruntime-benchmark/src/main/java/com/jyuzawa/onnxruntime_benchmark/OnnxruntimeJava.java b/onnxruntime-benchmark/src/main/java/com/jyuzawa/onnxruntime_benchmark/OnnxruntimeJava.java index f5981b4..9e54f65 100644 --- a/onnxruntime-benchmark/src/main/java/com/jyuzawa/onnxruntime_benchmark/OnnxruntimeJava.java +++ b/onnxruntime-benchmark/src/main/java/com/jyuzawa/onnxruntime_benchmark/OnnxruntimeJava.java @@ -15,7 +15,7 @@ import java.io.IOException; import java.util.Map; -final class OnnxruntimeJava implements Wrapper { +class OnnxruntimeJava implements Wrapper { private static final Environment ENVIRONMENT = OnnxRuntime.get() .getApi() @@ -23,14 +23,16 @@ final class OnnxruntimeJava implements Wrapper { .setLogSeverityLevel(OnnxRuntimeLoggingLevel.WARNING) .build(); - private final Session session; + protected final Session session; + protected final long[] out; - OnnxruntimeJava(byte[] bytes, boolean arena) throws IOException { + OnnxruntimeJava(byte[] bytes, boolean arena, int size) throws IOException { this.session = ENVIRONMENT .newSession() .setByteArray(bytes) .addProvider(ExecutionProvider.CPU_EXECUTION_PROVIDER, Map.of("use_arena", arena ? "1" : "0")) .build(); + this.out = new long[size]; } @Override @@ -44,9 +46,8 @@ public long[] evaluate(long[] input) { txn.addInput(0).asTensor().getLongBuffer().put(input); txn.addOutput(0); NamedCollection result = txn.run(); - long[] output = new long[input.length]; - result.get(0).asTensor().getLongBuffer().get(output); - return output; + result.get(0).asTensor().getLongBuffer().get(out); + return out; } } } diff --git a/onnxruntime-benchmark/src/main/java/com/jyuzawa/onnxruntime_benchmark/OnnxruntimeJavaIoBinding.java b/onnxruntime-benchmark/src/main/java/com/jyuzawa/onnxruntime_benchmark/OnnxruntimeJavaIoBinding.java new file mode 100644 index 0000000..2a99020 --- /dev/null +++ b/onnxruntime-benchmark/src/main/java/com/jyuzawa/onnxruntime_benchmark/OnnxruntimeJavaIoBinding.java @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) + * SPDX-License-Identifier: MIT + */ +package com.jyuzawa.onnxruntime_benchmark; + +import com.jyuzawa.onnxruntime.IoBinding; +import java.io.IOException; +import java.nio.LongBuffer; + +final class OnnxruntimeJavaIoBinding extends OnnxruntimeJava { + + private final IoBinding ioBinding; + private final LongBuffer inputBuf; + private final LongBuffer outputBuf; + + OnnxruntimeJavaIoBinding(byte[] bytes, boolean arena, int size) throws IOException { + super(bytes, arena, size); + this.ioBinding = session.newIoBinding().bindInput(0).bindOutput(0).build(); + this.inputBuf = ioBinding.getInputs().get(0).asTensor().getLongBuffer(); + this.outputBuf = ioBinding.getOutputs().get(0).asTensor().getLongBuffer(); + } + + @Override + public void close() throws Exception { + ioBinding.close(); + super.close(); + } + + @Override + public long[] evaluate(long[] input) { + inputBuf.clear().put(input); + ioBinding.run(); + outputBuf.rewind().get(out); + return out; + } +} diff --git a/src/main/java/com/jyuzawa/onnxruntime/ApiImpl.java b/src/main/java/com/jyuzawa/onnxruntime/ApiImpl.java index 130b477..1d120c8 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/ApiImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/ApiImpl.java @@ -29,6 +29,8 @@ final class ApiImpl implements Api { final AddRunConfigEntry AddRunConfigEntry; final AddSessionConfigEntry AddSessionConfigEntry; final AllocatorFree AllocatorFree; + final BindInput BindInput; + final BindOutput BindOutput; final CastTypeInfoToMapTypeInfo CastTypeInfoToMapTypeInfo; final CastTypeInfoToSequenceTypeInfo CastTypeInfoToSequenceTypeInfo; final CastTypeInfoToTensorInfo CastTypeInfoToTensorInfo; @@ -38,6 +40,7 @@ final class ApiImpl implements Api { final CreateCpuMemoryInfo CreateCpuMemoryInfo; final CreateCUDAProviderOptions CreateCUDAProviderOptions; final CreateDnnlProviderOptions CreateDnnlProviderOptions; + final CreateIoBinding CreateIoBinding; final CreateTensorRTProviderOptions CreateTensorRTProviderOptions; final CreateEnvWithCustomLogger CreateEnvWithCustomLogger; final CreateEnvWithCustomLoggerAndGlobalThreadPools CreateEnvWithCustomLoggerAndGlobalThreadPools; @@ -93,6 +96,7 @@ final class ApiImpl implements Api { final ReleaseCUDAProviderOptions ReleaseCUDAProviderOptions; final ReleaseDnnlProviderOptions ReleaseDnnlProviderOptions; final ReleaseEnv ReleaseEnv; + final ReleaseIoBinding ReleaseIoBinding; final ReleaseMemoryInfo ReleaseMemoryInfo; final ReleaseModelMetadata ReleaseModelMetadata; final ReleaseRunOptions ReleaseRunOptions; @@ -109,6 +113,7 @@ final class ApiImpl implements Api { final RunOptionsSetRunLogVerbosityLevel RunOptionsSetRunLogVerbosityLevel; final RunOptionsSetRunTag RunOptionsSetRunTag; final RunOptionsSetTerminate RunOptionsSetTerminate; + final RunWithBinding RunWithBinding; final SetGlobalDenormalAsZero SetGlobalDenormalAsZero; final SetGlobalIntraOpNumThreads SetGlobalIntraOpNumThreads; final SetGlobalInterOpNumThreads SetGlobalInterOpNumThreads; @@ -153,6 +158,8 @@ final class ApiImpl implements Api { this.AddRunConfigEntry = OrtApi.AddRunConfigEntry(memorySegment, memorySession); this.AddSessionConfigEntry = OrtApi.AddSessionConfigEntry(memorySegment, memorySession); this.AllocatorFree = OrtApi.AllocatorFree(memorySegment, memorySession); + this.BindInput = OrtApi.BindInput(memorySegment, memorySession); + this.BindOutput = OrtApi.BindOutput(memorySegment, memorySession); this.CastTypeInfoToMapTypeInfo = OrtApi.CastTypeInfoToMapTypeInfo(memorySegment, memorySession); this.CastTypeInfoToSequenceTypeInfo = OrtApi.CastTypeInfoToSequenceTypeInfo(memorySegment, memorySession); this.CastTypeInfoToTensorInfo = OrtApi.CastTypeInfoToTensorInfo(memorySegment, memorySession); @@ -163,6 +170,7 @@ final class ApiImpl implements Api { this.CreateCUDAProviderOptions = OrtApi.CreateCUDAProviderOptions(memorySegment, memorySession); this.CreateDnnlProviderOptions = OrtApi.CreateDnnlProviderOptions(memorySegment, memorySession); this.CreateTensorRTProviderOptions = OrtApi.CreateTensorRTProviderOptions(memorySegment, memorySession); + this.CreateIoBinding = OrtApi.CreateIoBinding(memorySegment, memorySession); this.CreateEnvWithCustomLogger = OrtApi.CreateEnvWithCustomLogger(memorySegment, memorySession); this.CreateEnvWithCustomLoggerAndGlobalThreadPools = OrtApi.CreateEnvWithCustomLoggerAndGlobalThreadPools(memorySegment, memorySession); @@ -220,6 +228,7 @@ final class ApiImpl implements Api { this.ReleaseCUDAProviderOptions = OrtApi.ReleaseCUDAProviderOptions(memorySegment, memorySession); this.ReleaseDnnlProviderOptions = OrtApi.ReleaseDnnlProviderOptions(memorySegment, memorySession); this.ReleaseEnv = OrtApi.ReleaseEnv(memorySegment, memorySession); + this.ReleaseIoBinding = OrtApi.ReleaseIoBinding(memorySegment, memorySession); this.ReleaseMemoryInfo = OrtApi.ReleaseMemoryInfo(memorySegment, memorySession); this.ReleaseModelMetadata = OrtApi.ReleaseModelMetadata(memorySegment, memorySession); this.ReleaseRunOptions = OrtApi.ReleaseRunOptions(memorySegment, memorySession); @@ -236,6 +245,7 @@ final class ApiImpl implements Api { this.RunOptionsSetRunLogVerbosityLevel = OrtApi.RunOptionsSetRunLogVerbosityLevel(memorySegment, memorySession); this.RunOptionsSetRunTag = OrtApi.RunOptionsSetRunTag(memorySegment, memorySession); this.RunOptionsSetTerminate = OrtApi.RunOptionsSetTerminate(memorySegment, memorySession); + this.RunWithBinding = OrtApi.RunWithBinding(memorySegment, memorySession); this.SetGlobalDenormalAsZero = OrtApi.SetGlobalDenormalAsZero(memorySegment, memorySession); this.SetGlobalIntraOpNumThreads = OrtApi.SetGlobalIntraOpNumThreads(memorySegment, memorySession); this.SetGlobalInterOpNumThreads = OrtApi.SetGlobalInterOpNumThreads(memorySegment, memorySession); diff --git a/src/main/java/com/jyuzawa/onnxruntime/IoBinding.java b/src/main/java/com/jyuzawa/onnxruntime/IoBinding.java new file mode 100644 index 0000000..bbebae8 --- /dev/null +++ b/src/main/java/com/jyuzawa/onnxruntime/IoBinding.java @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) + * SPDX-License-Identifier: MIT + */ +package com.jyuzawa.onnxruntime; + +import java.util.Map; + +/** + * A representation of a model evaluation. Capable of reuse in repetitive runs. More efficient than {@link Transaction}. Only supports tensors. Input and outputs are pre-allocated and must be of fixed size. This class is NOT thread-safe. + * + * @since 1.4.0 + */ +public interface IoBinding extends AutoCloseable { + + /** + * Set the severity for logging for this specific transaction. + * Can override the environment's or session's logger's severity. + * @param level + * @return this + */ + IoBinding setLogSeverityLevel(OnnxRuntimeLoggingLevel level); + + /** + * Set the verbosity for logging for this specific transaction. + * Can override the environment's or session's logger's verbosity. + * @param level + * @return this + */ + IoBinding setLogVerbosityLevel(int level); + + /** + * Set the run tag (which is the logger id) + * @param runTag + * @return this + */ + IoBinding setRunTag(String runTag); + + NamedCollection getInputs(); + + NamedCollection getOutputs(); + + /** + * Run the model evaluation. + */ + void run(); + + /** + * Frees the native resources (typically buffers) associated with this transaction. + */ + @Override + void close(); + + /** + * A builder of a {@link IoBinding}. Should NOT be reused. This class is NOT thread-safe. + * + * @since 1.0.0 + */ + public interface Builder { + /** + * Add an input and get an OnnxValue to populate. + * @param name + * @return the value to be populated + */ + Builder bindInput(String name); + + /** + * Add an input and get an OnnxValue to populate. + * @param index + * @return the value to be populated + */ + Builder bindInput(int index); + + /** + * Request a specific output to be produced. + * @param name + */ + Builder bindOutput(String name); + + /** + * Request a specific output to be produced. + * @param index + */ + Builder bindOutput(int index); + + /** + * Set custom parameters for this transaction. + * @param config + * @return the builder + */ + Builder setConfigMap(Map config); + + /** + * Construct a {@link IoBinding}. + * + * @return a new instance + */ + IoBinding build(); + } +} diff --git a/src/main/java/com/jyuzawa/onnxruntime/IoBindingImpl.java b/src/main/java/com/jyuzawa/onnxruntime/IoBindingImpl.java new file mode 100644 index 0000000..c41330e --- /dev/null +++ b/src/main/java/com/jyuzawa/onnxruntime/IoBindingImpl.java @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) + * SPDX-License-Identifier: MIT + */ +package com.jyuzawa.onnxruntime; + +import java.lang.foreign.Addressable; +import java.lang.foreign.MemoryAddress; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.MemorySession; +import java.lang.foreign.SegmentAllocator; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +final class IoBindingImpl implements IoBinding { + private final ApiImpl api; + private final MemorySession memorySession; + private final MemoryAddress ioBinding; + private final MemoryAddress runOptions; + private final NamedCollectionImpl inputs; + private final NamedCollectionImpl outputs; + private final MemoryAddress session; + + IoBindingImpl(Builder builder) { + this.memorySession = MemorySession.openConfined(); + this.api = builder.api; + this.session = builder.session.address(); + SegmentAllocator allocator = SegmentAllocator.newNativeArena(memorySession); + this.ioBinding = builder.api.create(allocator, out -> builder.api.CreateIoBinding.apply(session, out)); + this.runOptions = api.create(allocator, out -> api.CreateRunOptions.apply(out)); + Map config = builder.config; + if (config != null && !config.isEmpty()) { + for (Map.Entry entry : config.entrySet()) { + api.checkStatus(api.AddRunConfigEntry.apply( + runOptions, + allocator.allocateUtf8String(entry.getKey()).address(), + allocator.allocateUtf8String(entry.getValue()).address())); + } + } + ValueContext valueContext = new ValueContext( + builder.api, + allocator, + memorySession, + builder.session.environment.ortAllocator, + builder.session.environment.memoryInfo); + this.inputs = add(builder.inputs, valueContext, memorySession, api, ioBinding, true); + this.outputs = add(builder.outputs, valueContext, memorySession, api, ioBinding, false); + } + + private static final NamedCollectionImpl add( + List nodes, + ValueContext valueContext, + MemorySession memorySession, + ApiImpl api, + MemoryAddress ioBinding, + boolean isInput) { + LinkedHashMap out = new LinkedHashMap<>(nodes.size()); + for (NodeInfoImpl node : nodes) { + OnnxValueImpl output = node.getTypeInfo().newValue(valueContext, null); + MemoryAddress valueAddress = output.toNative(); + memorySession.addCloseAction(() -> api.ReleaseValue.apply(valueAddress)); + out.put(node.getName(), output); + final Addressable result; + if (isInput) { + result = api.BindInput.apply(ioBinding, node.nameSegment, valueAddress); + } else { + result = api.BindOutput.apply(ioBinding, node.nameSegment, valueAddress); + } + api.checkStatus(result); + } + return new NamedCollectionImpl<>(out); + } + + @Override + public void close() { + api.ReleaseIoBinding.apply(ioBinding); + api.ReleaseRunOptions.apply(runOptions); + memorySession.close(); + } + + @Override + public void run() { + api.checkStatus(api.RunWithBinding.apply(session, runOptions, ioBinding)); + } + + static final class Builder implements IoBinding.Builder { + + final ApiImpl api; + final SessionImpl session; + private Map config; + final List inputs; + final List outputs; + + public Builder(SessionImpl session) { + this.api = session.api; + this.session = session; + this.inputs = new ArrayList<>(); + this.outputs = new ArrayList<>(); + } + + @Override + public IoBinding build() { + if (inputs.isEmpty()) { + throw new IllegalArgumentException("No inputs specified"); + } + if (outputs.isEmpty()) { + throw new IllegalArgumentException("No outputs specified"); + } + return new IoBindingImpl(this); + } + + @Override + public Builder setConfigMap(Map config) { + this.config = config; + return this; + } + + private void accumulate(List list, NodeInfoImpl nodeInfo) { + if (nodeInfo == null) { + throw new IllegalArgumentException("node info missing"); + } + list.add(nodeInfo); + } + + @Override + public Builder bindInput(String name) { + accumulate(inputs, session.inputs.get(name)); + return this; + } + + @Override + public Builder bindInput(int index) { + accumulate(inputs, session.inputs.get(index)); + return this; + } + + @Override + public Builder bindOutput(String name) { + accumulate(outputs, session.outputs.get(name)); + return this; + } + + @Override + public Builder bindOutput(int index) { + accumulate(outputs, session.outputs.get(index)); + return this; + } + } + + @Override + public IoBinding setLogSeverityLevel(OnnxRuntimeLoggingLevel level) { + api.checkStatus(api.RunOptionsSetRunLogSeverityLevel.apply(runOptions, level.getNumber())); + return this; + } + + @Override + public IoBinding setLogVerbosityLevel(int level) { + api.checkStatus(api.RunOptionsSetRunLogVerbosityLevel.apply(runOptions, level)); + return this; + } + + @Override + public IoBinding setRunTag(String runTag) { + try (MemorySession allocator = MemorySession.openConfined()) { + MemorySegment segment = allocator.allocateUtf8String(runTag); + api.checkStatus(api.RunOptionsSetRunTag.apply(runOptions, segment.address())); + } + return this; + } + + @Override + public NamedCollection getInputs() { + return inputs; + } + + @Override + public NamedCollection getOutputs() { + return outputs; + } +} diff --git a/src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java b/src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java index 7de583e..44bae2c 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java @@ -4,15 +4,15 @@ */ package com.jyuzawa.onnxruntime; -import java.lang.foreign.MemorySegment; +import java.lang.foreign.MemoryAddress; final class NodeInfoImpl implements NodeInfo { private final String name; - final MemorySegment nameSegment; + final MemoryAddress nameSegment; private final TypeInfoImpl typeInfo; - NodeInfoImpl(String name, MemorySegment nameSegment, TypeInfoImpl typeInfo) { + NodeInfoImpl(String name, MemoryAddress nameSegment, TypeInfoImpl typeInfo) { this.name = name; this.nameSegment = nameSegment; this.typeInfo = typeInfo; diff --git a/src/main/java/com/jyuzawa/onnxruntime/Session.java b/src/main/java/com/jyuzawa/onnxruntime/Session.java index cf10eb0..80432ec 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/Session.java +++ b/src/main/java/com/jyuzawa/onnxruntime/Session.java @@ -65,6 +65,8 @@ public interface Session extends AutoCloseable { */ Transaction.Builder newTransaction(); + IoBinding.Builder newIoBinding(); + /** * A builder of a {@link Session}. Must provide either bytes or a path. * diff --git a/src/main/java/com/jyuzawa/onnxruntime/SessionImpl.java b/src/main/java/com/jyuzawa/onnxruntime/SessionImpl.java index b5af974..bf248d1 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/SessionImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/SessionImpl.java @@ -160,7 +160,10 @@ private static NamedCollection createMap( api.checkStatus(api.AllocatorFree.apply(ortAllocator, nameSegment)); MemoryAddress typeInfoAddress = api.create(allocator, out -> getTypeInfo.apply(session, j, out)); TypeInfoImpl typeInfo = new TypeInfoImpl(api, typeInfoAddress, allocator, sessionAllocator, ortAllocator); - inputs.put(name, new NodeInfoImpl(name, sessionAllocator.allocateUtf8String(name), typeInfo)); + inputs.put( + name, + new NodeInfoImpl( + name, sessionAllocator.allocateUtf8String(name).address(), typeInfo)); } return new NamedCollectionImpl<>(inputs); } @@ -208,6 +211,11 @@ public Transaction.Builder newTransaction() { return new TransactionImpl.Builder(this); } + @Override + public IoBinding.Builder newIoBinding() { + return new IoBindingImpl.Builder(this); + } + static final class Builder implements Session.Builder { private final ApiImpl api; private final EnvironmentImpl environment; diff --git a/src/test/java/com/jyuzawa/onnxruntime/SessionTest.java b/src/test/java/com/jyuzawa/onnxruntime/SessionTest.java index c0057a2..d405279 100644 --- a/src/test/java/com/jyuzawa/onnxruntime/SessionTest.java +++ b/src/test/java/com/jyuzawa/onnxruntime/SessionTest.java @@ -13,6 +13,7 @@ import java.lang.System.Logger; import java.lang.System.Logger.Level; import java.nio.ByteBuffer; +import java.nio.IntBuffer; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; @@ -21,6 +22,7 @@ import java.util.Map; import java.util.NoSuchElementException; import java.util.Set; +import java.util.concurrent.ThreadLocalRandom; import onnx.OnnxMl.AttributeProto; import onnx.OnnxMl.AttributeProto.AttributeType; import onnx.OnnxMl.GraphProto; @@ -951,6 +953,44 @@ public void runOptionsTest() throws IOException { } } + @Test + public void ioBindingTest() throws IOException { + TypeProto type = TypeProto.newBuilder() + .setTensorType(Tensor.newBuilder() + .setElemType(DataType.INT32_VALUE) + .setShape(TensorShapeProto.newBuilder() + .addDim(Dimension.newBuilder().setDimValue(1)) + .addDim(Dimension.newBuilder().setDimValue(3)))) + .build(); + try (Session session = environment + .newSession() + .setByteBuffer(identityModel(type)) + .build(); + IoBinding txn = session.newIoBinding() + .bindInput(0) + .bindOutput(0) + .setConfigMap(Map.of("foo", "bar")) + .build()) { + txn.setLogSeverityLevel(OnnxRuntimeLoggingLevel.INFO); + txn.setLogVerbosityLevel(0); + txn.setRunTag("foo"); + IntBuffer inputBuf = txn.getInputs().get(0).asTensor().getIntBuffer(); + IntBuffer outputBuf = txn.getOutputs().get(0).asTensor().getIntBuffer(); + int[] rawOutput = new int[3]; + for (int i = 0; i < 100; i++) { + int[] rawInput = new int[] { + ThreadLocalRandom.current().nextInt(), + ThreadLocalRandom.current().nextInt(), + ThreadLocalRandom.current().nextInt() + }; + inputBuf.clear().put(rawInput); + txn.run(); + outputBuf.rewind().get(rawOutput); + assertArrayEquals(rawInput, rawOutput); + } + } + } + @Test public void optimizationTest() throws IOException { File file = File.createTempFile("ort-optimized", ".onnx");