Skip to content

Commit

Permalink
clean up constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
yuzawa-san committed Jan 4, 2024
1 parent c9752a3 commit 0ef50ed
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 23 deletions.
43 changes: 24 additions & 19 deletions src/main/java/com/jyuzawa/onnxruntime/IoBindingImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package com.jyuzawa.onnxruntime;

import java.lang.foreign.Addressable;
import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.MemorySession;
Expand Down Expand Up @@ -44,28 +45,32 @@ final class IoBindingImpl implements IoBinding {
memorySession,
builder.session.environment.ortAllocator,
builder.session.environment.memoryInfo);
this.inputs = add(builder.inputs, valueContext, memorySession, api, ioBinding, true);
this.outputs = add(builder.outputs, valueContext, memorySession, api, ioBinding, false);
}

LinkedHashMap<String, OnnxValue> rawInputs = new LinkedHashMap<>(builder.inputs.size());
for (NodeInfoImpl inputNode : builder.inputs) {
OnnxValueImpl input = inputNode.getTypeInfo().newValue(valueContext, null);
MemoryAddress valueAddress = input.toNative();
memorySession.addCloseAction(() -> builder.api.ReleaseValue.apply(valueAddress));
rawInputs.put(inputNode.getName(), input);
builder.api.checkStatus(
api.BindInput.apply(ioBinding.address(), inputNode.nameSegment.address(), valueAddress.address()));
}
this.inputs = new NamedCollectionImpl<>(rawInputs);

LinkedHashMap<String, OnnxValue> rawOutputs = new LinkedHashMap<>(builder.outputs.size());
for (NodeInfoImpl outputNode : builder.outputs) {
OnnxValueImpl output = outputNode.getTypeInfo().newValue(valueContext, null);
private static final NamedCollectionImpl<OnnxValue> add(
List<NodeInfoImpl> nodes,
ValueContext valueContext,
MemorySession memorySession,
ApiImpl api,
MemoryAddress ioBinding,
boolean isInput) {
LinkedHashMap<String, OnnxValue> out = new LinkedHashMap<>(nodes.size());
for (NodeInfoImpl node : nodes) {
OnnxValueImpl output = node.getTypeInfo().newValue(valueContext, null);
MemoryAddress valueAddress = output.toNative();
memorySession.addCloseAction(() -> builder.api.ReleaseValue.apply(valueAddress));
rawOutputs.put(outputNode.getName(), output);
builder.api.checkStatus(api.BindOutput.apply(
ioBinding.address(), outputNode.nameSegment.address(), valueAddress.address()));
memorySession.addCloseAction(() -> api.ReleaseValue.apply(valueAddress));
out.put(node.getName(), output);
final Addressable result;
if (isInput) {
result = api.BindInput.apply(ioBinding, node.nameSegment, valueAddress);
} else {
result = api.BindOutput.apply(ioBinding, node.nameSegment, valueAddress);
}
api.checkStatus(result);
}
this.outputs = new NamedCollectionImpl<>(rawOutputs);
return new NamedCollectionImpl<>(out);
}

@Override
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
*/
package com.jyuzawa.onnxruntime;

import java.lang.foreign.MemorySegment;
import java.lang.foreign.MemoryAddress;

final class NodeInfoImpl implements NodeInfo {

private final String name;
final MemorySegment nameSegment;
final MemoryAddress nameSegment;
private final TypeInfoImpl typeInfo;

NodeInfoImpl(String name, MemorySegment nameSegment, TypeInfoImpl typeInfo) {
NodeInfoImpl(String name, MemoryAddress nameSegment, TypeInfoImpl typeInfo) {
this.name = name;
this.nameSegment = nameSegment;
this.typeInfo = typeInfo;
Expand Down
5 changes: 4 additions & 1 deletion src/main/java/com/jyuzawa/onnxruntime/SessionImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ private static NamedCollection<NodeInfoImpl> createMap(
api.checkStatus(api.AllocatorFree.apply(ortAllocator, nameSegment));
MemoryAddress typeInfoAddress = api.create(allocator, out -> getTypeInfo.apply(session, j, out));
TypeInfoImpl typeInfo = new TypeInfoImpl(api, typeInfoAddress, allocator, sessionAllocator, ortAllocator);
inputs.put(name, new NodeInfoImpl(name, sessionAllocator.allocateUtf8String(name), typeInfo));
inputs.put(
name,
new NodeInfoImpl(
name, sessionAllocator.allocateUtf8String(name).address(), typeInfo));
}
return new NamedCollectionImpl<>(inputs);
}
Expand Down

0 comments on commit 0ef50ed

Please sign in to comment.