Skip to content

Commit

Permalink
pre-commit applied
Browse files Browse the repository at this point in the history
Signed-off-by: Michal Danilowicz <[email protected]>
  • Loading branch information
Michal Danilowicz authored and PWzor committed Oct 10, 2024
1 parent 5c95ac4 commit b806a85
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/finn/qnn-data/templates/driver/driver_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def load_external_weights(self):
# weight_buf.sync_to_device()
weight_buf.flush()

input_shape = self._io_shape_dict['external_weights_input_shapes'][idma_name]
input_shape = self._io_shape_dict["external_weights_input_shapes"][idma_name]
# NHWC input?
if len(input_shape) == 4:
num_repeats = input_shape[1] * input_shape[2]
Expand Down
2 changes: 1 addition & 1 deletion src/finn/qnn-data/templates/driver/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
driver.copy_output_data_from_device(obuf_normal)
batch_ok = (obuf_normal.flatten() == exp.flatten()).sum()
ok += batch_ok
nok += (bsize - batch_ok)
nok += bsize - batch_ok
print("batch %d / %d : total OK %d NOK %d" % (i + 1, n_batches, ok, nok))

acc = 100.0 * ok / (total)
Expand Down
4 changes: 3 additions & 1 deletion src/finn/transformation/fpgadataflow/make_pynq_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,9 @@ def apply(self, model):
dma_target_model = ModelWrapper(dma_target_sdp.get_nodeattr("model"))
iodma_output_tensor = iodma_node.onnx_node.output[0]
dma_consumer = dma_target_model.find_consumer(iodma_output_tensor)
ext_weight_shapes_dict[idma_name] = dma_target_model.get_tensor_shape(dma_consumer.output[0])
ext_weight_shapes_dict[idma_name] = dma_target_model.get_tensor_shape(
dma_consumer.output[0]
)

init_tensor = df_model.get_initializer(iodma_node.onnx_node.input[0])
ext_weight_dma_cnt += 1
Expand Down
6 changes: 3 additions & 3 deletions tests/end2end/test_ext_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@

import pytest

import json
import os
import shutil
import subprocess
import wget
import json

import finn.builder.build_dataflow as build
import finn.builder.build_dataflow_config as build_cfg
Expand Down Expand Up @@ -67,7 +67,7 @@ def get_checkpoint_name(step, topology):
else:
# other checkpoints are onnx files
return build_dir + "/end2end_" + topology + "_ext_weights_%s.onnx" % (step)


def verify_runtime_weights(folding_config_file, runtime_weights_dir):
with open(folding_config_file) as file:
Expand All @@ -79,7 +79,7 @@ def verify_runtime_weights(folding_config_file, runtime_weights_dir):
num_ext_weights += 1
runtime_weights_files = os.listdir(runtime_weights_dir)
for idx in range(num_ext_weights):
expected_file = 'idma{}.npy'.format(idx)
expected_file = "idma{}.npy".format(idx)
assert expected_file in runtime_weights_files


Expand Down

0 comments on commit b806a85

Please sign in to comment.