Commit d27799c: Fix
Giuseppe5 committed Nov 14, 2023
Parent: a16bbe2
Showing 1 changed file with 5 additions and 4 deletions.
@@ -51,9 +51,9 @@ def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits, pretrained):
     input_a = np.random.randint(MIN_INP_VAL, MAX_INP_VAL, size=FC_INPUT_SIZE).astype(np.float32)
     scale = 1. / 255
     input_t = torch.from_numpy(input_a * scale)
-    input_qt = QuantTensor(
-        input_t, scale=torch.tensor(scale), bit_width=torch.tensor(8.0), signed=False)
-    export_qonnx(fc, export_path=finn_onnx, input_t=input_qt, input_names=['input'])
+    # input_qt = QuantTensor(
+    #     input_t, scale=torch.tensor(scale), bit_width=torch.tensor(8.0), signed=False)
+    export_qonnx(fc, export_path=finn_onnx, input_t=input_t, input_names=['input'])
     model = ModelWrapper(finn_onnx)
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(DoubleToSingleFloat())
@@ -62,7 +62,7 @@ def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits, pretrained):
     model = model.transform(RemoveStaticGraphInputs())
 
     # run using FINN-based execution
-    input_dict = {'input': input_a}
+    input_dict = {'input': input_t}
     output_dict = oxe.execute_onnx(model, input_dict)
     produced = output_dict[list(output_dict.keys())[0]]
     # do forward pass in PyTorch/Brevitas
@@ -87,6 +87,7 @@ def test_brevitas_cnv_onnx_export_and_exec(wbits, abits, pretrained):
     input_a = np.random.randint(MIN_INP_VAL, MAX_INP_VAL, size=CNV_INPUT_SIZE).astype(np.float32)
     scale = 1. / 255
     input_t = torch.from_numpy(input_a * scale)
+    # QONNX Export does not expect QuantTensor, only Tensor
     input_qt = QuantTensor(
         input_t, scale=torch.tensor(scale), bit_width=torch.tensor(8.0), signed=False)
     export_qonnx(cnv, export_path=finn_onnx, input_t=input_qt, input_names=['input'])
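
Net effect of the FC-test change, as a standalone sketch: export_qonnx now receives the plain scaled torch.Tensor instead of a QuantTensor wrapper, and the FINN-based execution dictionary receives the same scaled tensor instead of the raw numpy array. The snippet below is only an illustration; the import path for export_qonnx is assumed from the Brevitas API, the input constants are placeholder values, and fc, finn_onnx, ModelWrapper and oxe stand for the model, export path and FINN helpers defined elsewhere in the test file. Note that the CNV test in the third hunk only gains the explanatory comment in this commit; its export call still wraps the input in a QuantTensor.

import numpy as np
import torch
from brevitas.export import export_qonnx  # import path assumed from the Brevitas API

# Placeholder values standing in for constants defined elsewhere in the test file
MIN_INP_VAL, MAX_INP_VAL = 0, 255
FC_INPUT_SIZE = (1, 1, 28, 28)

input_a = np.random.randint(MIN_INP_VAL, MAX_INP_VAL, size=FC_INPUT_SIZE).astype(np.float32)
scale = 1. / 255
input_t = torch.from_numpy(input_a * scale)

# After this commit: export with the plain tensor; no QuantTensor wrapper is passed.
# fc and finn_onnx are the quantized model and export path defined in the test.
export_qonnx(fc, export_path=finn_onnx, input_t=input_t, input_names=['input'])

# The FINN-based execution also consumes the scaled tensor rather than the raw numpy array
model = ModelWrapper(finn_onnx)                    # ModelWrapper as imported in the test file
input_dict = {'input': input_t}
output_dict = oxe.execute_onnx(model, input_dict)  # oxe as imported in the test file
produced = output_dict[list(output_dict.keys())[0]]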
