diff --git a/src/omlt/io/onnx_parser.py b/src/omlt/io/onnx_parser.py
index 511261c0..224aae5a 100644
--- a/src/omlt/io/onnx_parser.py
+++ b/src/omlt/io/onnx_parser.py
@@ -176,13 +176,15 @@ def _visit_node(self, node, next_nodes):
 
     def _consume_dense_nodes(self, node, next_nodes):
         """Starting from a MatMul node, consume nodes to form a dense Ax + b node."""
+        # This should only be called when we know we have a starting MatMul node. This
+        # error indicates a bug in the function calling this one.
         if node.op_type != "MatMul":
             raise ValueError(
-                f"{node.name} is a {node.op_type} node, only MatMul nodes can be used as starting points for consumption."
+                f"{node.name} is a {node.op_type} node, but the method for parsing MatMul nodes was invoked."
             )
         if len(node.input) != 2:
             raise ValueError(
-                f"{node.name} input has {len(node.input)} dimensions, only nodes with 2 input dimensions can be used as starting points for consumption."
+                f"{node.name} has {len(node.input)} inputs, but the parser requires the starting node to have 2 inputs."
             )
 
         [in_0, in_1] = list(node.input)
@@ -200,7 +202,7 @@ def _consume_dense_nodes(self, node, next_nodes):
             raise TypeError(f"Expected a node next, got a {type_} instead.")
         if node.op_type != "Add":
             raise ValueError(
-                f"The first node to be consumed, {node.name}, is a {node.op_type} node. Only Add nodes are supported."
+                f"The next node to be parsed, {node.name}, is a {node.op_type} node. Only Add nodes are supported."
             )
 
         # extract biases
@@ -255,11 +257,11 @@ def _consume_gemm_dense_nodes(self, node, next_nodes):
         """Starting from a Gemm node, consume nodes to form a dense aAB + bC node."""
         if node.op_type != "Gemm":
             raise ValueError(
-                f"{node.name} is a {node.op_type} node, only Gemm nodes can be used as starting points for consumption."
+                f"{node.name} is a {node.op_type} node, but the method for parsing Gemm nodes was invoked."
             )
         if len(node.input) != 3:
             raise ValueError(
-                f"{node.name} input has {len(node.input)} dimensions, only nodes with 3 input dimensions can be used as starting points for consumption."
+                f"{node.name} has {len(node.input)} inputs, but the parser requires the starting node to have 3 inputs."
             )
 
         attr = _collect_attributes(node)
@@ -310,11 +312,11 @@ def _consume_conv_nodes(self, node, next_nodes):
         """
         if node.op_type != "Conv":
             raise ValueError(
-                f"{node.name} is a {node.op_type} node, only Conv nodes can be used as starting points for consumption."
+                f"{node.name} is a {node.op_type} node, but the method for parsing Conv nodes was invoked."
             )
         if len(node.input) not in [2, 3]:
             raise ValueError(
-                f"{node.name} input has {len(node.input)} dimensions, only nodes with 2 or 3 input dimensions can be used as starting points for consumption."
+                f"{node.name} has {len(node.input)} inputs, but the parser requires the starting node to have 2 or 3 inputs."
             )
 
         if len(node.input) == 2:
@@ -359,25 +361,32 @@ def _consume_conv_nodes(self, node, next_nodes):
                 f"Input/output size ({input_output_size}) first dimension must match input weights channels ({in_channels})."
             )
 
+        # TODO: need to check pads and dilations also have correct dimensions. Also
+        # should add support for autopad.
+        if "pads" in attr:
+            pads = attr["pads"]
+        else:
+            pads = 2 * (len(input_output_size) - 1) * [0]
+
+        if "dilations" in attr:
+            dilations = attr["dilations"]
+        else:
+            dilations = (len(input_output_size) - 1) * [1]
+
-        # Other attributes are not supported
-        if "dilations" in attr and attr["dilations"] != [1, 1]:
-            raise ValueError(
-                f"{node} has non-identity dilations ({attr['dilations']}). This is not supported."
-            )
+        # Grouped convolutions are not supported
         if attr["group"] != 1:
             raise ValueError(
                 f"{node} has multiple groups ({attr['group']}). This is not supported."
             )
-        if "pads" in attr and np.any(attr["pads"]):
-            raise ValueError(
-                f"{node} has non-zero pads ({attr['pads']}). This is not supported."
-            )
 
         # generate new nodes for the node output
-        padding = 0
+        padding = [
+            pads[i] + pads[i + len(input_output_size) - 1]
+            for i in range(len(input_output_size) - 1)
+        ]
         output_size = [out_channels]
-        for w, k, s in zip(input_output_size[1:], kernel_shape, strides):
-            new_w = int((w - k + 2 * padding) / s) + 1
+        for w, k, s, p in zip(input_output_size[1:], kernel_shape, strides, padding):
+            new_w = int((w - k + p) / s) + 1
             output_size.append(new_w)
 
         activation = "linear"
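For reviewers: the shape arithmetic above matches the ONNX rule for one spatial axis, out = floor((w + pad_begin + pad_end - (d*(k - 1) + 1)) / s) + 1, in the dilations-of-1 case, since d*(k - 1) + 1 reduces to k. Once the TODO above addresses dilation-aware shapes, the dilated extent is the term to substitute for k. A minimal sketch (the helper name is illustrative, not part of the diff):

```python
import math

def conv_output_dim(w, k, s, pad_total, d=1):
    """ONNX output size along one spatial axis."""
    effective_k = d * (k - 1) + 1  # extent of the dilated kernel
    return math.floor((w + pad_total - effective_k) / s) + 1

assert conv_output_dim(5, 2, 1, 0) == 4       # agrees with int((w - k + p) / s) + 1
assert conv_output_dim(5, 2, 1, 2, d=2) == 5  # a dilated 2-tap kernel spans 3 cells
```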
@@ -401,6 +410,8 @@ def _consume_conv_nodes(self, node, next_nodes):
             output_size,
             strides,
             weights,
+            pads=pads,
+            dilations=dilations,
             activation=activation,
             input_index_mapper=transformer,
         )
@@ -413,11 +424,11 @@ def _consume_reshape_nodes(self, node, next_nodes):
         """Parse a Reshape node."""
         if node.op_type != "Reshape":
             raise ValueError(
-                f"{node.name} is a {node.op_type} node, only Reshape nodes can be used as starting points for consumption."
+                f"{node.name} is a {node.op_type} node, but the method for parsing Reshape nodes was invoked."
             )
         if len(node.input) != 2:
             raise ValueError(
-                f"{node.name} input has {len(node.input)} dimensions, only nodes with 2 input dimensions can be used as starting points for consumption."
+                f"{node.name} has {len(node.input)} inputs, but the parser requires the starting node to have 2 inputs."
             )
         [in_0, in_1] = list(node.input)
         input_layer = self._node_map[in_0]
@@ -434,7 +445,7 @@ def _consume_pool_nodes(self, node, next_nodes):
         """
         if node.op_type not in _POOLING_OP_TYPES:
             raise ValueError(
-                f"{node.name} is a {node.op_type} node, only MaxPool nodes can be used as starting points for consumption."
+                f"{node.name} is a {node.op_type} node, but the method for parsing MaxPool nodes was invoked."
             )
 
         pool_func_name = "max"
@@ -445,7 +456,7 @@ def _consume_pool_nodes(self, node, next_nodes):
         )
         if len(node.input) != 1:
             raise ValueError(
-                f"{node.name} input has {len(node.input)} dimensions, only nodes with 1 input dimension can be used as starting points for consumption."
+                f"{node.name} has {len(node.input)} inputs, but the parser requires the starting node to have 1 input."
             )
 
         input_layer, transformer = self._node_input_and_transformer(node.input[0])
@@ -464,20 +475,14 @@ def _consume_pool_nodes(self, node, next_nodes):
         in_channels = input_output_size[0]
 
         attr = _collect_attributes(node)
-        kernel_depth = attr["kernel_shape"][0]
+        kernel_depth = in_channels
         kernel_shape = attr["kernel_shape"][1:]
         strides = attr["strides"] if "strides" in attr else [1] * len(kernel_shape)
+        pads = attr["pads"] if "pads" in attr else None
+        dilations = attr["dilations"] if "dilations" in attr else None
 
-        # check only kernel shape, stride, storage order are set
-        # everything else is not supported
-        if "dilations" in attr and attr["dilations"] != [1, 1]:
-            raise ValueError(
-                f"{node.name} has non-identity dilations ({attr['dilations']}). This is not supported."
-            )
-        if "pads" in attr and np.any(attr["pads"]):
-            raise ValueError(
-                f"{node.name} has non-zero pads ({attr['pads']}). This is not supported."
-            )
+        # pads and dilations are now handled; reject any other unsupported attribute
         if ("auto_pad" in attr) and (attr["auto_pad"] != "NOTSET"):
             raise ValueError(
                 f"{node.name} has autopad set ({attr['auto_pad']}). This is not supported."
             )
@@ -519,6 +524,8 @@ def _consume_pool_nodes(self, node, next_nodes):
             pool_func_name,
             tuple(kernel_shape),
             kernel_depth,
+            pads=pads,
+            dilations=dilations,
             activation=activation,
             input_index_mapper=transformer,
         )
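The pads list threaded through both consumers above follows the ONNX layout: begin pads for every spatial axis, then end pads, i.e. [top, left, bottom, right] for a 2D NCHW tensor. A small worked sketch of the per-axis totals computed in _consume_conv_nodes (values are illustrative):

```python
pads = [1, 2, 1, 0]  # top=1, left=2, bottom=1, right=0
n_spatial = 2        # len(input_output_size) - 1 for a 2D layer
padding = [pads[i] + pads[i + n_spatial] for i in range(n_spatial)]
assert padding == [2, 2]  # total row padding, total column padding
```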
diff --git a/src/omlt/neuralnet/layer.py b/src/omlt/neuralnet/layer.py
index 16e068a3..3feb9e59 100644
--- a/src/omlt/neuralnet/layer.py
+++ b/src/omlt/neuralnet/layer.py
@@ -402,6 +402,10 @@ class Layer2D(Layer):
         the size of the output.
     strides : matrix-like
         stride of the kernel.
+    pads : matrix-like
+        padding for the kernel, given as [top, left, bottom, right].
+    dilations : matrix-like
+        dilations of the kernel.
     activation : str or None
         activation function name
     input_index_mapper : IndexMapper or None
@@ -414,6 +418,8 @@ def __init__(
         output_size,
         strides,
         *,
+        pads=None,
+        dilations=None,
         activation=None,
         input_index_mapper=None,
     ):
@@ -424,12 +430,25 @@ def __init__(
             input_index_mapper=input_index_mapper,
         )
         self.__strides = strides
+        self.__pads = [0, 0, 0, 0] if pads is None else pads
+        self.__dilations = [1, 1] if dilations is None else dilations
 
     @property
     def strides(self):
         """Return the stride of the layer"""
         return self.__strides
 
+    @property
+    def pads(self):
+        """Return the padding of the layer"""
+        return self.__pads
+
     @property
     def kernel_shape(self):
         """Return the shape of the kernel"""
@@ -440,6 +459,20 @@ def kernel_depth(self):
         """Return the depth of the kernel"""
         raise NotImplementedError()
 
+    @property
+    def dilations(self):
+        """Return the kernel dilation of the layer"""
+        return self.__dilations
+
+    @property
+    def dilated_kernel_shape(self):
+        """Return the shape of the kernel after dilation"""
+        dilated_dims = [
+            self.dilations[i] * (self.kernel_shape[i] - 1) + 1
+            for i in range(len(self.kernel_shape))
+        ]
+        return tuple(dilated_dims)
+
     def kernel_index_with_input_indexes(self, out_d, out_r, out_c):
         """
         Returns an iterator over the index within the kernel and input index
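A worked example of dilated_kernel_shape: a 2x2 kernel with dilations [2, 2] covers a 3x3 window, which is what test_conv_dilations later in this diff relies on.

```python
kernel_shape = (2, 2)
dilations = [2, 2]
dilated = tuple(d * (k - 1) + 1 for d, k in zip(dilations, kernel_shape))
assert dilated == (3, 3)  # one zero row/column sits between kernel entries
```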
""" kernel_d = self.kernel_depth - [kernel_r, kernel_c] = self.kernel_shape + [kernel_r, kernel_c] = self.dilated_kernel_shape [rows_stride, cols_stride] = self.__strides + [pads_row, pads_col] = self.__pads[:2] start_in_d = 0 - start_in_r = out_r * rows_stride - start_in_c = out_c * cols_stride - mapper = lambda x: x - if self.input_index_mapper is not None: - mapper = self.input_index_mapper + start_in_r = out_r * rows_stride - pads_row + start_in_c = out_c * cols_stride - pads_col + # Defined but never used: + # mapper = lambda x: x + # if self.input_index_mapper is not None: + # mapper = self.input_index_mapper for k_d in range(kernel_d): for k_r in range(kernel_r): @@ -475,7 +510,7 @@ def kernel_index_with_input_indexes(self, out_d, out_r, out_c): # as this could require using a partial kernel # even though we loop over ALL kernel indexes. if not all( - input_index[i] < self.input_size[i] + input_index[i] < self.input_size[i] and input_index[i] >= 0 for i in range(len(input_index)) ): continue @@ -522,6 +557,10 @@ class PoolingLayer2D(Layer2D): the size of the output. strides : matrix-like stride of the kernel. + pads : matrix-like + Padding for the kernel. Given as [left, bottom, right, top] + dilations : matrix-like + Dilations of the kernel pool_func : str name of function used to pool values in a kernel to a single value. transpose : bool @@ -544,6 +583,8 @@ def __init__( kernel_shape, kernel_depth, *, + pads=None, + dilations=None, activation=None, input_index_mapper=None, ): @@ -551,6 +592,8 @@ def __init__( input_size, output_size, strides, + pads=pads, + dilations=dilations, activation=activation, input_index_mapper=input_index_mapper, ) @@ -598,6 +641,10 @@ class ConvLayer2D(Layer2D): stride of the cross-correlation kernel. kernel : matrix-like the cross-correlation kernel. + pads : matrix-like + Padding for the kernel. 
@@ -611,6 +658,8 @@ def __init__(
         strides,
         kernel,
         *,
+        pads=None,
+        dilations=None,
         activation=None,
         input_index_mapper=None,
     ):
@@ -618,10 +667,31 @@ def __init__(
             input_size,
             output_size,
             strides,
+            pads=pads,
+            dilations=dilations,
             activation=activation,
             input_index_mapper=input_index_mapper,
         )
         self.__kernel = kernel
+        if self.dilations != [1, 1]:
+            dilated = np.zeros(
+                (
+                    kernel.shape[0],
+                    kernel.shape[1],
+                    (kernel.shape[2] - 1) * dilations[0] + 1,
+                    (kernel.shape[3] - 1) * dilations[1] + 1,
+                )
+            )
+            for i in range(kernel.shape[0]):
+                for j in range(kernel.shape[1]):
+                    for k in range(kernel.shape[2]):
+                        for l in range(kernel.shape[3]):
+                            dilated[i, j, k * dilations[0], l * dilations[1]] = kernel[
+                                i, j, k, l
+                            ]
+            self.__dilated_kernel = dilated
+        else:
+            self.__dilated_kernel = kernel
 
     def kernel_with_input_indexes(self, out_d, out_r, out_c):
         """
@@ -658,6 +728,11 @@ def kernel(self):
         """Return the cross-correlation kernel"""
         return self.__kernel
 
+    @property
+    def dilated_kernel(self):
+        """Return the dilated cross-correlation kernel"""
+        return self.__dilated_kernel
+
     def __str__(self):
         return f"ConvLayer(input_size={self.input_size}, output_size={self.output_size}, strides={self.strides}, kernel_shape={self.kernel_shape})"
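The nested loops in ConvLayer2D.__init__ scatter kernel entries onto a zero grid. An equivalent vectorized formulation, shown here only as a cross-check and not proposed for the diff, is strided slice assignment:

```python
import numpy as np

kernel = np.arange(2 * 3 * 2 * 2, dtype=float).reshape(2, 3, 2, 2)
d0, d1 = 2, 2
dilated = np.zeros(
    (
        kernel.shape[0],
        kernel.shape[1],
        (kernel.shape[2] - 1) * d0 + 1,
        (kernel.shape[3] - 1) * d1 + 1,
    )
)
dilated[:, :, ::d0, ::d1] = kernel  # every other row/column stays zero
assert dilated.shape == (2, 3, 3, 3)
```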
diff --git a/tests/io/test_onnx_parser.py b/tests/io/test_onnx_parser.py
index 763b282c..a0f67448 100644
--- a/tests/io/test_onnx_parser.py
+++ b/tests/io/test_onnx_parser.py
@@ -5,6 +5,8 @@
 if onnx_available:
     from omlt.io.onnx import load_onnx_neural_network
     from omlt.io.onnx_parser import NetworkParser
+    from onnx import numpy_helper
+    from numpy import array
 
 
 @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
@@ -76,6 +78,8 @@ def test_gemm_transB(datadir):
 @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
 def test_conv(datadir):
     model = onnx.load(datadir.file("convx1_gemmx1.onnx"))
+    del model.graph.node[0].attribute[0]
+    del model.graph.node[0].attribute[2]
     net = load_onnx_neural_network(model)
     layers = list(net.layers)
     assert len(layers) == 4
@@ -84,6 +88,37 @@ def test_conv(datadir):
     assert layers[3].activation == "relu"
     assert layers[1].strides == [1, 1]
     assert layers[1].kernel_shape == (2, 2)
+    assert layers[1].dilations == [1, 1]
+    assert layers[1].pads == [0, 0, 0, 0]
+
+
+@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
+def test_conv_dilations(datadir):
+    model = onnx.load(datadir.file("convx1_gemmx1.onnx"))
+    for attr in model.graph.node[0].attribute:
+        if attr.name == "dilations":
+            del attr.ints[:]
+            attr.ints.extend([2, 2])
+        if attr.name == "pads":
+            del attr.ints[:]
+            attr.ints.extend([1, 2, 1, 0])
+    model.graph.node[1].attribute[0].t.raw_data = numpy_helper.from_array(
+        array([-1, 128])
+    ).raw_data
+    net = load_onnx_neural_network(model)
+    layers = list(net.layers)
+    assert layers[1].dilations == [2, 2]
+    assert (
+        layers[1].dilated_kernel[0][0].round(8)
+        == array(
+            [[-0.00886667, 0, 0.18750042], [0, 0, 0], [-0.11404419, 0, -0.02588665]]
+        )
+    ).all()
+    assert (
+        layers[1].dilated_kernel[1][0].round(8)
+        == array([[-0.07554907, 0, -0.05939162], [0, 0, 0], [0.2217437, 0, 0.14637864]])
+    ).all()
+    assert layers[1].pads == [1, 2, 1, 0]
 
 
 @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
@@ -153,7 +188,7 @@ def test_consume_wrong_node_type(datadir):
             parser._nodes["StatefulPartitionedCall/keras_linear_131/dense/BiasAdd"][1],
             parser._nodes["StatefulPartitionedCall/keras_linear_131/dense/BiasAdd"][2],
         )
-    expected_msg_dense = "StatefulPartitionedCall/keras_linear_131/dense/BiasAdd is a Add node, only MatMul nodes can be used as starting points for consumption."
+    expected_msg_dense = "StatefulPartitionedCall/keras_linear_131/dense/BiasAdd is a Add node, but the method for parsing MatMul nodes was invoked."
     assert str(excinfo.value) == expected_msg_dense
 
     with pytest.raises(ValueError) as excinfo:
@@ -161,7 +196,7 @@ def test_consume_wrong_node_type(datadir):
             parser._nodes["StatefulPartitionedCall/keras_linear_131/dense/BiasAdd"][1],
             parser._nodes["StatefulPartitionedCall/keras_linear_131/dense/BiasAdd"][2],
         )
-    expected_msg_gemm = "StatefulPartitionedCall/keras_linear_131/dense/BiasAdd is a Add node, only Gemm nodes can be used as starting points for consumption."
+    expected_msg_gemm = "StatefulPartitionedCall/keras_linear_131/dense/BiasAdd is a Add node, but the method for parsing Gemm nodes was invoked."
     assert str(excinfo.value) == expected_msg_gemm
 
     with pytest.raises(ValueError) as excinfo:
@@ -169,7 +204,7 @@ def test_consume_wrong_node_type(datadir):
             parser._nodes["StatefulPartitionedCall/keras_linear_131/dense/BiasAdd"][1],
             parser._nodes["StatefulPartitionedCall/keras_linear_131/dense/BiasAdd"][2],
         )
-    expected_msg_conv = "StatefulPartitionedCall/keras_linear_131/dense/BiasAdd is a Add node, only Conv nodes can be used as starting points for consumption."
+    expected_msg_conv = "StatefulPartitionedCall/keras_linear_131/dense/BiasAdd is a Add node, but the method for parsing Conv nodes was invoked."
     assert str(excinfo.value) == expected_msg_conv
 
     with pytest.raises(ValueError) as excinfo:
@@ -177,7 +212,7 @@ def test_consume_wrong_node_type(datadir):
             parser._nodes["StatefulPartitionedCall/keras_linear_131/dense/BiasAdd"][1],
             parser._nodes["StatefulPartitionedCall/keras_linear_131/dense/BiasAdd"][2],
        )
-    expected_msg_reshape = "StatefulPartitionedCall/keras_linear_131/dense/BiasAdd is a Add node, only Reshape nodes can be used as starting points for consumption."
+    expected_msg_reshape = "StatefulPartitionedCall/keras_linear_131/dense/BiasAdd is a Add node, but the method for parsing Reshape nodes was invoked."
     assert str(excinfo.value) == expected_msg_reshape
 
     with pytest.raises(ValueError) as excinfo:
@@ -185,7 +220,7 @@ def test_consume_wrong_node_type(datadir):
             parser._nodes["StatefulPartitionedCall/keras_linear_131/dense/BiasAdd"][1],
             parser._nodes["StatefulPartitionedCall/keras_linear_131/dense/BiasAdd"][2],
         )
-    expected_msg_pool = """StatefulPartitionedCall/keras_linear_131/dense/BiasAdd is a Add node, only MaxPool nodes can be used as starting points for consumption."""
+    expected_msg_pool = """StatefulPartitionedCall/keras_linear_131/dense/BiasAdd is a Add node, but the method for parsing MaxPool nodes was invoked."""
     assert str(excinfo.value) == expected_msg_pool
 
 
@@ -203,7 +238,7 @@ def test_consume_dense_wrong_dims(datadir):
         parser._consume_dense_nodes(
             parser._nodes["StatefulPartitionedCall/keras_linear_131/dense/MatMul"][1],
             parser._nodes["StatefulPartitionedCall/keras_linear_131/dense/MatMul"][2],
         )
-    expected_msg_dense = "StatefulPartitionedCall/keras_linear_131/dense/MatMul input has 3 dimensions, only nodes with 2 input dimensions can be used as starting points for consumption."
+    expected_msg_dense = "StatefulPartitionedCall/keras_linear_131/dense/MatMul has 3 inputs, but the parser requires the starting node to have 2 inputs."
     assert str(excinfo.value) == expected_msg_dense
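Beyond the exact-value assertions above, a quick interactive check of the new accessors might look like this (a sketch: it assumes the convx1_gemmx1.onnx test asset and the layer ordering used in test_conv):

```python
import onnx
from omlt.io.onnx import load_onnx_neural_network

model = onnx.load("convx1_gemmx1.onnx")  # path is illustrative
net = load_onnx_neural_network(model)
conv = list(net.layers)[1]  # the ConvLayer2D, as in test_conv
print(conv.pads, conv.dilations)                     # e.g. [0, 0, 0, 0] [1, 1]
print(conv.kernel_shape, conv.dilated_kernel_shape)  # (2, 2) (2, 2) when undilated
```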
@@ -217,7 +252,7 @@ def test_consume_gemm_wrong_dims(datadir):
         parser._consume_gemm_dense_nodes(
             parser._nodes["Gemm_0"][1], parser._nodes["Gemm_0"][2]
         )
-    expected_msg_gemm = "Gemm_0 input has 4 dimensions, only nodes with 3 input dimensions can be used as starting points for consumption."
+    expected_msg_gemm = "Gemm_0 has 4 inputs, but the parser requires the starting node to have 3 inputs."
     assert str(excinfo.value) == expected_msg_gemm
@@ -231,7 +266,7 @@ def test_consume_conv_wrong_dims(datadir):
         parser._consume_conv_nodes(
             parser._nodes["Conv_0"][1], parser._nodes["Conv_0"][2]
         )
-    expected_msg_conv = "Conv_0 input has 4 dimensions, only nodes with 2 or 3 input dimensions can be used as starting points for consumption."
+    expected_msg_conv = "Conv_0 has 4 inputs, but the parser requires the starting node to have 2 or 3 inputs."
     assert str(excinfo.value) == expected_msg_conv
@@ -245,7 +280,7 @@ def test_consume_reshape_wrong_dims(datadir):
         parser._consume_reshape_nodes(
             parser._nodes["Reshape_2"][1], parser._nodes["Reshape_2"][2]
         )
-    expected_msg_reshape = """Reshape_2 input has 3 dimensions, only nodes with 2 input dimensions can be used as starting points for consumption."""
+    expected_msg_reshape = """Reshape_2 has 3 inputs, but the parser requires the starting node to have 2 inputs."""
     assert str(excinfo.value) == expected_msg_reshape
@@ -257,5 +292,5 @@ def test_consume_maxpool_wrong_dims(datadir):
     parser._nodes["node1"][1].input.append("abcd")
     with pytest.raises(ValueError) as excinfo:
         parser._consume_pool_nodes(parser._nodes["node1"][1], parser._nodes["node1"][2])
-    expected_msg_maxpool = """node1 input has 2 dimensions, only nodes with 1 input dimension can be used as starting points for consumption."""
+    expected_msg_maxpool = """node1 has 2 inputs, but the parser requires the starting node to have 1 input."""
     assert str(excinfo.value) == expected_msg_maxpool
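As a closing note, the new geometry handling can be cross-checked against a brute-force NumPy cross-correlation. This is a sketch under the same (C, H, W) layout and [top, left, bottom, right] pad order as Layer2D, not part of the change:

```python
import numpy as np

def conv2d_reference(x, kernel, strides, pads, dilations):
    """Brute-force cross-correlation: x is (C_in, H, W), kernel is (C_out, C_in, KH, KW)."""
    top, left, bottom, right = pads
    xp = np.pad(x, ((0, 0), (top, bottom), (left, right)))
    kh = dilations[0] * (kernel.shape[2] - 1) + 1  # dilated kernel extents
    kw = dilations[1] * (kernel.shape[3] - 1) + 1
    out_h = (xp.shape[1] - kh) // strides[0] + 1
    out_w = (xp.shape[2] - kw) // strides[1] + 1
    out = np.zeros((kernel.shape[0], out_h, out_w))
    for o in range(kernel.shape[0]):
        for r in range(out_h):
            for c in range(out_w):
                # strided slice picks the dilated taps from the padded input
                window = xp[
                    :,
                    r * strides[0] : r * strides[0] + kh : dilations[0],
                    c * strides[1] : c * strides[1] + kw : dilations[1],
                ]
                out[o, r, c] = np.sum(window * kernel[o])
    return out
```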