From b863e22e8d13d57c4ffe4b07692f911765173b8a Mon Sep 17 00:00:00 2001 From: Jay Zhang Date: Wed, 20 Nov 2024 03:49:57 +0000 Subject: [PATCH] Change the way to validate keep_num_dims attribute for new tf. Signed-off-by: Jay Zhang --- tf2onnx/tflite_handlers/tfl_math.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tf2onnx/tflite_handlers/tfl_math.py b/tf2onnx/tflite_handlers/tfl_math.py index add2f7de3..745034a41 100644 --- a/tf2onnx/tflite_handlers/tfl_math.py +++ b/tf2onnx/tflite_handlers/tfl_math.py @@ -201,12 +201,10 @@ def to_tf(cls, ctx, node, **kwargs): separate_fused_activation_function(ctx, node) utils.make_sure(node.attr['weights_format'].s == b'DEFAULT', "Only default weights format supported for fully connected op") - utils.make_sure(node.attr['keep_num_dims'].i == 0, - "Only keep_num_dims=False supported for fully connected op") if node.attr['asymmetric_quantize_inputs'].i == 1: dynamic_quantize_inputs(ctx, node) - if ctx.get_rank(node.input[0]) != 2: + if node.attr['keep_num_dims'].i == 0 and ctx.get_rank(node.input[0]) != 2: # When a fullyconnected node has keep_num_dims=0 and input[0] rank > 2, the extra dims must be compressed utils.make_sure(ctx.get_rank(node.input[1]) == 2, "weights for FullyConnected must have rank 2") weights_shape = ctx.get_shape(node.input[1])[1] @@ -217,7 +215,7 @@ def to_tf(cls, ctx, node, **kwargs): ctx.replace_inputs(node, [reshape_node.output[0], node.input[1]]) transpose_node = ctx.insert_new_node_on_input(node, "Transpose", node.input[1], - name=None, input_index=1, perm=[1, 0]) + name=None, input_index=1, perm=[1, 0]) transpose_node.skip_conversion = True node.set_attr("transpose_a", 0) node.set_attr("transpose_b", 0)