Skip to content

Commit

Permalink
Change the way to validate keep_num_dims attribute for new tf.
Browse files Browse the repository at this point in the history
Signed-off-by: Jay Zhang <[email protected]>
  • Loading branch information
fatcat-z committed Nov 20, 2024
1 parent f85e88e commit b863e22
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions tf2onnx/tflite_handlers/tfl_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,10 @@ def to_tf(cls, ctx, node, **kwargs):
separate_fused_activation_function(ctx, node)
utils.make_sure(node.attr['weights_format'].s == b'DEFAULT',
"Only default weights format supported for fully connected op")
utils.make_sure(node.attr['keep_num_dims'].i == 0,
"Only keep_num_dims=False supported for fully connected op")
if node.attr['asymmetric_quantize_inputs'].i == 1:
dynamic_quantize_inputs(ctx, node)

if ctx.get_rank(node.input[0]) != 2:
if node.attr['keep_num_dims'].i == 0 and ctx.get_rank(node.input[0]) != 2:
# When a fullyconnected node has keep_num_dims=0 and input[0] rank > 2, the extra dims must be compressed
utils.make_sure(ctx.get_rank(node.input[1]) == 2, "weights for FullyConnected must have rank 2")
weights_shape = ctx.get_shape(node.input[1])[1]
Expand All @@ -217,7 +215,7 @@ def to_tf(cls, ctx, node, **kwargs):
ctx.replace_inputs(node, [reshape_node.output[0], node.input[1]])

transpose_node = ctx.insert_new_node_on_input(node, "Transpose", node.input[1],
name=None, input_index=1, perm=[1, 0])
name=None, input_index=1, perm=[1, 0])
transpose_node.skip_conversion = True
node.set_attr("transpose_a", 0)
node.set_attr("transpose_b", 0)
Expand Down

0 comments on commit b863e22

Please sign in to comment.