Add transposed convolution layer (keras-team#131)
* add transposed convolution layer

* fix comments
chenmoneygithub authored May 10, 2023
1 parent eb09131 commit 5c0f9b5
Showing 11 changed files with 1,107 additions and 16 deletions.
19 changes: 10 additions & 9 deletions keras_core/backend/common/backend_utils.py
@@ -33,16 +33,17 @@ def _compute_conv_transpose_output_length(
 
 
 def compute_conv_transpose_output_shape(
-    inputs,
-    kernel,
+    input_shape,
+    kernel_size,
+    filters,
     strides,
     padding,
     output_padding=None,
     data_format="channels_last",
     dilation_rate=1,
 ):
-    num_spatial_dims = len(inputs.shape) - 2
-    kernel_spatial_shape = kernel.shape[:-2]
+    num_spatial_dims = len(input_shape) - 2
+    kernel_spatial_shape = kernel_size
 
     if isinstance(output_padding, int):
         output_padding = (output_padding,) * len(kernel_spatial_shape)
@@ -52,9 +53,9 @@ def compute_conv_transpose_output_shape(
         dilation_rate = (dilation_rate,) * num_spatial_dims
 
     if data_format == "channels_last":
-        inputs_spatial_shape = inputs.shape[1:-1]
+        input_spatial_shape = input_shape[1:-1]
     else:
-        inputs_spatial_shape = inputs.shape[2:]
+        input_spatial_shape = input_shape[2:]
 
     output_shape = []
     for i in range(num_spatial_dims):
@@ -63,7 +64,7 @@ def compute_conv_transpose_output_shape(
         )
         output_shape.append(
             _compute_conv_transpose_output_length(
-                inputs_spatial_shape[i],
+                input_spatial_shape[i],
                 kernel_spatial_shape[i],
                 padding=padding,
                 output_padding=current_output_padding,
@@ -73,7 +74,7 @@ def compute_conv_transpose_output_shape(
         )
 
     if data_format == "channels_last":
-        output_shape = [inputs.shape[0]] + output_shape + [kernel.shape[-2]]
+        output_shape = [input_shape[0]] + output_shape + [filters]
     else:
-        output_shape = [inputs.shape[0], kernel.shape[-1]] + output_shape
+        output_shape = [input_shape[0], filters] + output_shape
     return output_shape
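
For reference, a minimal standalone sketch of the output-length rule this helper relies on (mirroring the usual Keras transposed-convolution arithmetic; the function name and edge-case handling here are illustrative, not a copy of `_compute_conv_transpose_output_length`):

def conv_transpose_length(input_length, kernel_size, padding, stride=1, dilation=1):
    # Effective kernel size once dilation is taken into account.
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    if padding == "valid":
        # Every input step advances the output by `stride`; the kernel
        # overhang adds the remainder.
        return input_length * stride + max(kernel_size - stride, 0)
    elif padding == "same":
        # A "same"-padded transposed convolution is pure upsampling by `stride`.
        return input_length * stride
    raise ValueError(f"Unknown padding: {padding}")

# A length-4 "same" input upsampled with stride 2 doubles to 8;
# "valid" adds the kernel overhang: 4 * 2 + max(3 - 2, 0) = 9.
assert conv_transpose_length(4, 3, "same", stride=2) == 8
assert conv_transpose_length(4, 3, "valid", stride=2) == 9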
7 changes: 5 additions & 2 deletions keras_core/backend/tensorflow/nn.py
@@ -360,9 +360,12 @@ def conv_transpose(
     dilation_rate=1,
 ):
     tf_data_format = _convert_data_format(data_format, len(inputs.shape))
+    kernel_size = kernel.shape[:-2]
+    filters = kernel.shape[-2]
     output_shape = compute_conv_transpose_output_shape(
-        inputs,
-        kernel,
+        inputs.shape,
+        kernel_size,
+        filters,
         strides,
         padding,
         output_padding,
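Here the TF backend recovers the layer hyperparameters from the kernel's static shape rather than threading them through; a quick illustration of that convention (hypothetical numbers, kernel layout `kernel_size + (filters, in_channels)` as built by `BaseConvTranspose.build` below):

# Hypothetical 2D transposed-conv kernel: 16 input channels -> 32 filters.
kernel_shape = (3, 3, 32, 16)    # spatial dims + (filters, in_channels)
kernel_size = kernel_shape[:-2]  # (3, 3)
filters = kernel_shape[-2]       # 32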
3 changes: 3 additions & 0 deletions keras_core/layers/__init__.py
@@ -1,8 +1,11 @@
 from keras_core.layers.activations.activation import Activation
 from keras_core.layers.attention.attention import Attention
 from keras_core.layers.convolutional.conv1d import Conv1D
+from keras_core.layers.convolutional.conv1d_transpose import Conv1DTranspose
 from keras_core.layers.convolutional.conv2d import Conv2D
+from keras_core.layers.convolutional.conv2d_transpose import Conv2DTranspose
 from keras_core.layers.convolutional.conv3d import Conv3D
+from keras_core.layers.convolutional.conv3d_transpose import Conv3DTranspose
 from keras_core.layers.core.dense import Dense
 from keras_core.layers.core.einsum_dense import EinsumDense
 from keras_core.layers.core.embedding import Embedding
255 changes: 255 additions & 0 deletions keras_core/layers/convolutional/base_conv_transpose.py
@@ -0,0 +1,255 @@
"""Keras base class for transpose convolution layers."""

from keras_core import activations
from keras_core import constraints
from keras_core import initializers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core.backend import standardize_data_format
from keras_core.backend.common.backend_utils import (
    compute_conv_transpose_output_shape,
)
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer


class BaseConvTranspose(Layer):
"""Abstract N-D transpose convolution layer.
The need for transposed convolutions generally arises
from the desire to use a transformation going in the opposite direction
of a normal convolution, i.e., from something that has the shape of the
output of some convolution to something that has the shape of its input
while maintaining a connectivity pattern that is compatible with
said convolution.
Args:
rank: int, the rank of the transposed convolution, e.g. 2 for 2D
transposed convolution.
filters: int, the dimension of the output space (the number of filters
in the transposed convolution).
kernel_size: int or tuple/list of N integers (N=`rank`), specifying the
size of the transposed convolution window.
strides: int or tuple/list of N integers, specifying the stride length
of the transposed convolution. If only one int is specified, the
same stride size will be used for all dimensions.
`stride value != 1` is incompatible with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
height/width dimension as the input.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, steps, features)`
while `"channels_first"` corresponds to inputs with shape
`(batch, features, steps)`. It defaults to the `image_data_format`
value found in your Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be `"channels_last"`.
dilation_rate: int or tuple/list of N integers, specifying the dilation
rate to use for dilated convolution. If only one int is specified,
the same dilation rate will be used for all dimensions.
activation: Activation function. If `None`, no activation is applied.
use_bias: bool, if `True`, bias will be added to the output.
kernel_initializer: Initializer for the convolution kernel. If `None`,
the default initializer (`"glorot_uniform"`) will be used.
bias_initializer: Initializer for the bias vector. If `None`, the
default initializer (`"zeros"`) will be used.
kernel_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Optional regularizer function for the output.
kernel_constraint: Optional projection function to be applied to the
kernel after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The
function must take as input the unprojected variable and must return
the projected variable (which must have the same shape). Constraints
are not safe to use when doing asynchronous distributed training.
bias_constraint: Optional projection function to be applied to the
bias after being updated by an `Optimizer`.
"""

    def __init__(
        self,
        rank,
        filters,
        kernel_size,
        strides=1,
        padding="valid",
        data_format=None,
        dilation_rate=1,
        activation=None,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        trainable=True,
        name=None,
        **kwargs,
    ):
        super().__init__(
            trainable=trainable,
            name=name,
            activity_regularizer=activity_regularizer,
            **kwargs,
        )
        self.rank = rank
        self.filters = filters

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * self.rank
        self.kernel_size = kernel_size

        if isinstance(strides, int):
            strides = (strides,) * self.rank
        self.strides = strides

        if isinstance(dilation_rate, int):
            dilation_rate = (dilation_rate,) * self.rank
        self.dilation_rate = dilation_rate

        self.padding = padding
        self.data_format = standardize_data_format(data_format)
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(min_ndim=self.rank + 2)

        if self.filters is not None and self.filters <= 0:
            raise ValueError(
                "Invalid value for argument `filters`. Expected a strictly "
                f"positive value. Received filters={self.filters}."
            )

        if not all(self.kernel_size):
            raise ValueError(
                "The argument `kernel_size` cannot contain 0. Received: "
                f"{self.kernel_size}"
            )

        if not all(self.strides):
            raise ValueError(
                "The argument `strides` cannot contain 0. Received: "
                f"{self.strides}"
            )

        if max(self.strides) > 1 and max(self.dilation_rate) > 1:
            raise ValueError(
                "`strides > 1` not supported in conjunction with "
                f"`dilation_rate > 1`. Received: strides={self.strides} and "
                f"dilation_rate={self.dilation_rate}"
            )

    def build(self, input_shape):
        if self.data_format == "channels_last":
            channel_axis = -1
            input_channel = input_shape[-1]
        else:
            channel_axis = 1
            input_channel = input_shape[1]
        self.input_spec = InputSpec(
            min_ndim=self.rank + 2, axes={channel_axis: input_channel}
        )
        # Unlike a forward convolution, the transposed kernel stores the
        # output channels (`filters`) before the input channels.
        kernel_shape = self.kernel_size + (
            self.filters,
            input_channel,
        )

        self.kernel = self.add_weight(
            name="kernel",
            shape=kernel_shape,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            trainable=True,
            dtype=self.dtype,
        )
        if self.use_bias:
            self.bias = self.add_weight(
                name="bias",
                shape=(self.filters,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                trainable=True,
                dtype=self.dtype,
            )
        else:
            self.bias = None
        self.built = True

    def call(self, inputs):
        outputs = ops.conv_transpose(
            inputs,
            self.kernel,
            strides=list(self.strides),
            padding=self.padding,
            dilation_rate=self.dilation_rate,
            data_format=self.data_format,
        )

        if self.use_bias:
            # Reshape the bias so it broadcasts over the batch and spatial
            # dimensions, whatever the data format.
            if self.data_format == "channels_last":
                bias_shape = (1,) * (self.rank + 1) + (self.filters,)
            else:
                bias_shape = (1, self.filters) + (1,) * self.rank
            bias = ops.reshape(self.bias, bias_shape)
            outputs += bias

        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def compute_output_shape(self, input_shape):
        return compute_conv_transpose_output_shape(
            input_shape,
            self.kernel_size,
            self.filters,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "kernel_size": self.kernel_size,
                "strides": self.strides,
                "padding": self.padding,
                "data_format": self.data_format,
                "dilation_rate": self.dilation_rate,
                "activation": activations.serialize(self.activation),
                "use_bias": self.use_bias,
                "kernel_initializer": initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": initializers.serialize(
                    self.bias_initializer
                ),
                "kernel_regularizer": regularizers.serialize(
                    self.kernel_regularizer
                ),
                "bias_regularizer": regularizers.serialize(
                    self.bias_regularizer
                ),
                "activity_regularizer": regularizers.serialize(
                    self.activity_regularizer
                ),
                "kernel_constraint": constraints.serialize(
                    self.kernel_constraint
                ),
                "bias_constraint": constraints.serialize(self.bias_constraint),
            }
        )
        return config
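
Not part of the diff, but a quick sketch of how the exported layers behave end to end, assuming the standard Keras call convention (the shapes follow from the `compute_conv_transpose_output_shape` logic above):

import numpy as np
import keras_core

# A stride-2, "same"-padded transposed convolution doubles each spatial
# dimension: (4, 8, 8, 128) -> (4, 16, 16, 32).
x = np.random.rand(4, 8, 8, 128)
y = keras_core.layers.Conv2DTranspose(32, 3, strides=2, padding="same")(x)
print(y.shape)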
3 changes: 1 addition & 2 deletions keras_core/layers/convolutional/conv1d.py
@@ -80,8 +80,7 @@ class Conv1D(BaseConv):
     >>> # The inputs are 128-length vectors with 10 timesteps, and the
     >>> # batch size is 4.
     >>> input_shape = (4, 10, 128)
-    >>> x = np.random.normal(4, 10, 128)
+    >>> x = np.random.rand(4, 10, 128)
     >>> y = keras_core.layers.Conv1D(32, 3, activation='relu')(x)
     >>> print(y.shape)
     (4, 8, 32)
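
(The docstring fix matters because `np.random.normal(4, 10, 128)` draws 128 scalars from a normal distribution with mean 4 and standard deviation 10, not a `(4, 10, 128)` array; `np.random.rand(4, 10, 128)` yields the intended input shape.)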