diff --git a/src/main/java/com/jyuzawa/onnxruntime/IoBindingImpl.java b/src/main/java/com/jyuzawa/onnxruntime/IoBindingImpl.java index 14834e5..c41330e 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/IoBindingImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/IoBindingImpl.java @@ -4,6 +4,7 @@ */ package com.jyuzawa.onnxruntime; +import java.lang.foreign.Addressable; import java.lang.foreign.MemoryAddress; import java.lang.foreign.MemorySegment; import java.lang.foreign.MemorySession; @@ -44,28 +45,32 @@ final class IoBindingImpl implements IoBinding { memorySession, builder.session.environment.ortAllocator, builder.session.environment.memoryInfo); + this.inputs = add(builder.inputs, valueContext, memorySession, api, ioBinding, true); + this.outputs = add(builder.outputs, valueContext, memorySession, api, ioBinding, false); + } - LinkedHashMap rawInputs = new LinkedHashMap<>(builder.inputs.size()); - for (NodeInfoImpl inputNode : builder.inputs) { - OnnxValueImpl input = inputNode.getTypeInfo().newValue(valueContext, null); - MemoryAddress valueAddress = input.toNative(); - memorySession.addCloseAction(() -> builder.api.ReleaseValue.apply(valueAddress)); - rawInputs.put(inputNode.getName(), input); - builder.api.checkStatus( - api.BindInput.apply(ioBinding.address(), inputNode.nameSegment.address(), valueAddress.address())); - } - this.inputs = new NamedCollectionImpl<>(rawInputs); - - LinkedHashMap rawOutputs = new LinkedHashMap<>(builder.outputs.size()); - for (NodeInfoImpl outputNode : builder.outputs) { - OnnxValueImpl output = outputNode.getTypeInfo().newValue(valueContext, null); + private static final NamedCollectionImpl add( + List nodes, + ValueContext valueContext, + MemorySession memorySession, + ApiImpl api, + MemoryAddress ioBinding, + boolean isInput) { + LinkedHashMap out = new LinkedHashMap<>(nodes.size()); + for (NodeInfoImpl node : nodes) { + OnnxValueImpl output = node.getTypeInfo().newValue(valueContext, null); MemoryAddress valueAddress = output.toNative(); - memorySession.addCloseAction(() -> builder.api.ReleaseValue.apply(valueAddress)); - rawOutputs.put(outputNode.getName(), output); - builder.api.checkStatus(api.BindOutput.apply( - ioBinding.address(), outputNode.nameSegment.address(), valueAddress.address())); + memorySession.addCloseAction(() -> api.ReleaseValue.apply(valueAddress)); + out.put(node.getName(), output); + final Addressable result; + if (isInput) { + result = api.BindInput.apply(ioBinding, node.nameSegment, valueAddress); + } else { + result = api.BindOutput.apply(ioBinding, node.nameSegment, valueAddress); + } + api.checkStatus(result); } - this.outputs = new NamedCollectionImpl<>(rawOutputs); + return new NamedCollectionImpl<>(out); } @Override diff --git a/src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java b/src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java index 7de583e..44bae2c 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java @@ -4,15 +4,15 @@ */ package com.jyuzawa.onnxruntime; -import java.lang.foreign.MemorySegment; +import java.lang.foreign.MemoryAddress; final class NodeInfoImpl implements NodeInfo { private final String name; - final MemorySegment nameSegment; + final MemoryAddress nameSegment; private final TypeInfoImpl typeInfo; - NodeInfoImpl(String name, MemorySegment nameSegment, TypeInfoImpl typeInfo) { + NodeInfoImpl(String name, MemoryAddress nameSegment, TypeInfoImpl typeInfo) { this.name = name; this.nameSegment = nameSegment; this.typeInfo = typeInfo; diff --git a/src/main/java/com/jyuzawa/onnxruntime/SessionImpl.java b/src/main/java/com/jyuzawa/onnxruntime/SessionImpl.java index 841b68f..bf248d1 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/SessionImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/SessionImpl.java @@ -160,7 +160,10 @@ private static NamedCollection createMap( api.checkStatus(api.AllocatorFree.apply(ortAllocator, nameSegment)); MemoryAddress typeInfoAddress = api.create(allocator, out -> getTypeInfo.apply(session, j, out)); TypeInfoImpl typeInfo = new TypeInfoImpl(api, typeInfoAddress, allocator, sessionAllocator, ortAllocator); - inputs.put(name, new NodeInfoImpl(name, sessionAllocator.allocateUtf8String(name), typeInfo)); + inputs.put( + name, + new NodeInfoImpl( + name, sessionAllocator.allocateUtf8String(name).address(), typeInfo)); } return new NamedCollectionImpl<>(inputs); }