onnx model.py fix (nod-ai#254)
For nod-ai/SHARK-ModelDev#683
The RAFT model test failed due to accessing `E2ESHARK_CHECK` with "inputs" instead of "input".
Fixed other tests failing in the same way (found in each model's `inference.log` file).
IanWood1 authored Jun 25, 2024
1 parent 15dd074 commit 4751fab
Showing 7 changed files with 12 additions and 12 deletions.
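The same fix is applied across these files: keep the full list returned by `session.run` (instead of indexing `[0]`) and populate `E2ESHARK_CHECK` under the singular keys "input" and "output". Below is a minimal sketch of that pattern; the dummy numpy arrays and shapes stand in for a real onnxruntime session call and are illustrative only.

```python
import numpy as np
import torch

# Stand-in for the dict each e2eshark model.py populates for the test harness.
E2ESHARK_CHECK = {}

# Illustrative shapes only; a real script obtains model_output from
# session.run([out.name for out in outputs], {inputs[0].name: model_input_X}),
# which returns one numpy array per requested output name.
model_input_X = np.random.rand(1, 3, 224, 224).astype(np.float32)
model_output = [
    np.random.rand(1, 21, 224, 224).astype(np.float32),
    np.random.rand(1, 21, 224, 224).astype(np.float32),
]

# The harness reads the singular keys "input" and "output",
# each holding a list of torch tensors (one entry per model input/output).
E2ESHARK_CHECK["input"] = [torch.from_numpy(model_input_X)]
E2ESHARK_CHECK["output"] = [torch.from_numpy(arr) for arr in model_output]
```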
4 changes: 2 additions & 2 deletions e2eshark/onnx/models/DeepLabV3_resnet50_vaiq_int8/model.py
@@ -33,9 +33,9 @@

 
 model_output = session.run(
-    [outputs[0].name],
+    [outputs[0].name, outputs[1].name],
     {inputs[0].name: model_input_X},
-)[0]
+)
 E2ESHARK_CHECK["input"] = [torch.from_numpy(model_input_X)]
 E2ESHARK_CHECK["output"] = [torch.from_numpy(arr) for arr in model_output]

2 changes: 1 addition & 1 deletion e2eshark/onnx/models/FCN_vaiq_int8/model.py
@@ -35,7 +35,7 @@
 model_output = session.run(
     [outputs[0].name, outputs[1].name],
     {inputs[0].name: model_input_X},
-)[0]
+)
 E2ESHARK_CHECK["input"] = [torch.from_numpy(model_input_X)]
 E2ESHARK_CHECK["output"] = [torch.from_numpy(arr) for arr in model_output]

10 changes: 5 additions & 5 deletions e2eshark/onnx/models/RAFT_vaiq_int8/model.py
@@ -48,12 +48,12 @@
         inputs[0].name: model_input_X,
         inputs[1].name: model_input_Y
     },
-)[0]
-E2ESHARK_CHECK["inputs"] = [torch.from_numpy(model_input_X), torch.from_numpy(model_input_Y)]
-E2ESHARK_CHECK["outputs"] = [torch.from_numpy(arr) for arr in model_output]
+)
+E2ESHARK_CHECK["input"] = [torch.from_numpy(model_input_X), torch.from_numpy(model_input_Y)]
+E2ESHARK_CHECK["output"] = [torch.from_numpy(arr) for arr in model_output]

-print("Input:", E2ESHARK_CHECK["inputs"])
-print("Output:", E2ESHARK_CHECK["outputs"])
+print("Input:", E2ESHARK_CHECK["input"])
+print("Output:", E2ESHARK_CHECK["output"])

 # Post process output to do:
 # sort(topk(torch.nn.functional.softmax(output, 0), 2)[1])[0]
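For reference, a small sketch of what that post-processing comment describes: softmax along dim 0, take the indices of the top-2 entries, then sort them. The dummy 1-D tensor here is purely illustrative.

```python
import torch

output = torch.randn(8)                          # dummy stand-in for a model output tensor
probs = torch.nn.functional.softmax(output, 0)   # normalize along dim 0
top2_idx = torch.topk(probs, 2)[1]               # indices of the 2 largest probabilities
post = torch.sort(top2_idx)[0]                   # sorted index tensor, shape (2,)
print(post)
```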
2 changes: 1 addition & 1 deletion e2eshark/onnx/models/U-2-Net_vaiq_int8/model.py
@@ -40,7 +40,7 @@
         outputs[6].name,
     ],
     {inputs[0].name: model_input_X},
-)[0]
+)
 E2ESHARK_CHECK["input"] = [torch.from_numpy(model_input_X)]
 E2ESHARK_CHECK["output"] = [torch.from_numpy(arr) for arr in model_output]

2 changes: 1 addition & 1 deletion e2eshark/onnx/models/YoloNetV3_vaiq_int8/model.py
@@ -32,7 +32,7 @@
 model_output = session.run(
     [outputs[0].name],
     {inputs[0].name: model_input_X},
-)[0]
+)
 E2ESHARK_CHECK["input"] = [torch.from_numpy(model_input_X)]
 E2ESHARK_CHECK["output"] = [torch.from_numpy(arr) for arr in model_output]

2 changes: 1 addition & 1 deletion e2eshark/onnx/models/u-net_brain_mri_vaiq_int8/model.py
@@ -32,7 +32,7 @@
 model_output = session.run(
     [outputs[0].name],
     {inputs[0].name: model_input_X},
-)[0]
+)
 E2ESHARK_CHECK["input"] = [torch.from_numpy(model_input_X)]
 E2ESHARK_CHECK["output"] = [torch.from_numpy(arr) for arr in model_output]

2 changes: 1 addition & 1 deletion e2eshark/onnx/models/yolov8n_vaiq_int8/model.py
@@ -32,7 +32,7 @@
 model_output = session.run(
     [outputs[0].name, outputs[1].name, outputs[2].name, outputs[3].name],
     {inputs[0].name: model_input_X},
-)[0]
+)
 E2ESHARK_CHECK["input"] = [torch.from_numpy(model_input_X)]
 E2ESHARK_CHECK["output"] = [torch.from_numpy(arr) for arr in model_output]

