diff --git a/docs/src/api/Lux/layers.md b/docs/src/api/Lux/layers.md index 6041984940..b591844aab 100644 --- a/docs/src/api/Lux/layers.md +++ b/docs/src/api/Lux/layers.md @@ -35,10 +35,13 @@ VariationalHiddenDropout ## Pooling Layers ```@docs +AdaptiveLPPool AdaptiveMaxPool AdaptiveMeanPool +GlobalLPPool GlobalMaxPool GlobalMeanPool +LPPool MaxPool MeanPool ``` diff --git a/ext/LuxSimpleChainsExt.jl b/ext/LuxSimpleChainsExt.jl index 1d10fc106a..1d1dd2651c 100644 --- a/ext/LuxSimpleChainsExt.jl +++ b/ext/LuxSimpleChainsExt.jl @@ -61,8 +61,8 @@ function Lux.make_simplechain_network(layer::FlattenLayer) end function Lux.make_simplechain_network(layer::MaxPool) - if layer.stride == layer.k && (!(layer.pad isa SamePad) && all(==(0), layer.pad)) - return SimpleChains.MaxPool(layer.k) + if layer.layer.mode.stride == layer.layer.mode.k && all(==(0), layer.layer.mode.pad) + return SimpleChains.MaxPool(layer.layer.mode.k) end throw(SimpleChainsModelConversionException("MaxPool with non-standard parameters not \ supported.")) diff --git a/src/Lux.jl b/src/Lux.jl index a0650c673f..37506abcaa 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -58,6 +58,7 @@ include("layers/basic.jl") include("layers/containers.jl") include("layers/normalize.jl") include("layers/conv.jl") +include("layers/pooling.jl") include("layers/dropout.jl") include("layers/recurrent.jl") include("layers/extension.jl") @@ -87,8 +88,9 @@ include("distributed/public_api.jl") # Layers export Chain, Parallel, SkipConnection, PairwiseFusion, BranchLayer, Maxout, RepeatedLayer export Bilinear, Dense, Embedding, Scale -export Conv, ConvTranspose, MaxPool, MeanPool, GlobalMaxPool, GlobalMeanPool, - AdaptiveMaxPool, AdaptiveMeanPool, Upsample, PixelShuffle +export Conv, ConvTranspose, Upsample, PixelShuffle +export MaxPool, MeanPool, LPPool, GlobalMaxPool, GlobalMeanPool, GlobalLPPool, + AdaptiveMaxPool, AdaptiveMeanPool, AdaptiveLPPool export AlphaDropout, Dropout, VariationalHiddenDropout export BatchNorm, GroupNorm, InstanceNorm, LayerNorm export WeightNorm diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 66c5e61787..b026384069 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -42,15 +42,6 @@ CRC.@non_differentiable conv_transpose_dims(::Any...) conv_transpose(x, weight, cdims) = LuxLib.Impl.∇conv_data(x, weight, cdims) -function compute_adaptive_pooling_dims(x::AbstractArray, outsize) - insize = size(x)[1:(end - 2)] - stride = insize .÷ outsize - k = insize .- (outsize .- 1) .* stride - return PoolDims(x, k; padding=0, stride=stride) -end - -CRC.@non_differentiable compute_adaptive_pooling_dims(::Any, ::Any) - function init_conv_weight( rng::AbstractRNG, init_weight::F, filter::NTuple{N, <:IntegerType}, in_chs::IntegerType, out_chs::IntegerType, groups, σ::A) where {F, N, A} @@ -508,255 +499,3 @@ end function PixelShuffle(r::IntegerType) return PixelShuffle(WrappedFunction(Base.Fix2(pixel_shuffle, r))) end - -@doc doc""" - MaxPool(window::NTuple; pad=0, stride=window) - -Max pooling layer, which replaces all pixels in a block of size `window` with the maximum -value. - -# Arguments - - - `window`: Tuple of integers specifying the size of the window. Eg, for 2D pooling - `length(window) == 2` - -## Keyword Arguments - - - `stride`: Should each be either single integer, or a tuple with `N` integers - - - `pad`: Specifies the number of elements added to the borders of the data array. It can - be - - + a single integer for equal padding all around, - + a tuple of `N` integers, to apply the same padding at begin/end of each spatial - dimension, - + a tuple of `2*N` integers, for asymmetric padding, or - + the singleton `SamePad()`, to calculate padding such that - `size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial - dimension. - -# Extended Help - -## Inputs - - - `x`: Data satisfying `ndims(x) == N + 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` - -## Returns - - - Output of the pooling `y` of size `(O_N, ..., O_1, C, N)` where - -```math - O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s_i} + 1\right\rfloor -``` - - - Empty `NamedTuple()` - -See also [`Conv`](@ref), [`MeanPool`](@ref), [`GlobalMaxPool`](@ref), -[`AdaptiveMaxPool`](@ref) -""" -@concrete struct MaxPool <: AbstractLuxLayer - k <: Tuple{Vararg{IntegerType}} - pad <: Tuple{Vararg{IntegerType}} - stride <: Tuple{Vararg{IntegerType}} -end - -function MaxPool(k::Tuple{Vararg{IntegerType}}; pad=0, stride=k) - stride = Utils.expand(Val(length(k)), stride) - pad = calc_padding(pad, k, 1, stride) - @argcheck allequal(length, (stride, k)) - - return MaxPool(k, pad, stride) -end - -function (m::MaxPool)(x, _, st::NamedTuple) - return maxpool(x, PoolDims(x, m.k; padding=m.pad, m.stride)), st -end - -function Base.show(io::IO, m::MaxPool) - print(io, "MaxPool(", m.k) - all(==(0), m.pad) || print(io, ", pad=", PrettyPrinting.tuple_string(m.pad)) - m.stride == m.k || print(io, ", stride=", PrettyPrinting.tuple_string(m.stride)) - print(io, ")") -end - -@doc doc""" - MeanPool(window::NTuple; pad=0, stride=window) - -Mean pooling layer, which replaces all pixels in a block of size `window` with the mean -value. - -# Arguments - - - `window`: Tuple of integers specifying the size of the window. Eg, for 2D pooling - `length(window) == 2` - -## Keyword Arguments - - - `stride`: Should each be either single integer, or a tuple with `N` integers - - - `pad`: Specifies the number of elements added to the borders of the data array. It can - be - - + a single integer for equal padding all around, - + a tuple of `N` integers, to apply the same padding at begin/end of each spatial - dimension, - + a tuple of `2*N` integers, for asymmetric padding, or - + the singleton `SamePad()`, to calculate padding such that - `size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial - dimension. - -# Extended Help - -## Inputs - - - `x`: Data satisfying `ndims(x) == N + 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` - -## Returns - - - Output of the pooling `y` of size `(O_N, ..., O_1, C, N)` where - -```math - O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s_i} + 1\right\rfloor -``` - - - Empty `NamedTuple()` - -See also [`Conv`](@ref), [`MaxPool`](@ref), [`GlobalMeanPool`](@ref), -[`AdaptiveMeanPool`](@ref) -""" -@concrete struct MeanPool <: AbstractLuxLayer - k <: Tuple{Vararg{IntegerType}} - pad <: Tuple{Vararg{IntegerType}} - stride <: Tuple{Vararg{IntegerType}} -end - -function MeanPool(k::Tuple{Vararg{IntegerType}}; pad=0, stride=k) - stride = Utils.expand(Val(length(k)), stride) - pad = calc_padding(pad, k, 1, stride) - @argcheck allequal(length, (stride, k)) - - return MeanPool(k, pad, stride) -end - -function (m::MeanPool)(x, _, st::NamedTuple) - return meanpool(x, PoolDims(x, m.k; padding=m.pad, m.stride)), st -end - -function Base.show(io::IO, m::MeanPool) - print(io, "MeanPool(", m.k) - all(==(0), m.pad) || print(io, ", pad=", PrettyPrinting.tuple_string(m.pad)) - m.stride == m.k || print(io, ", stride=", PrettyPrinting.tuple_string(m.stride)) - print(io, ")") -end - -""" - GlobalMaxPool() - -Global Max Pooling layer. Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output, -by performing max pooling on the complete (w,h)-shaped feature maps. - -## Inputs - - - `x`: Data satisfying `ndims(x) > 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` - -## Returns - - - Output of the pooling `y` of size `(1, ..., 1, C, N)` - - Empty `NamedTuple()` - -See also [`MaxPool`](@ref), [`AdaptiveMaxPool`](@ref), [`GlobalMeanPool`](@ref) -""" -struct GlobalMaxPool <: AbstractLuxLayer end - -function (g::GlobalMaxPool)(x, _, st::NamedTuple) - return maxpool(x, PoolDims(x, size(x)[1:(end - 2)])), st -end - -""" - GlobalMeanPool() - -Global Mean Pooling layer. Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output, -by performing mean pooling on the complete (w,h)-shaped feature maps. - -## Inputs - - - `x`: Data satisfying `ndims(x) > 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` - -## Returns - - - Output of the pooling `y` of size `(1, ..., 1, C, N)` - - Empty `NamedTuple()` - -See also [`MeanPool`](@ref), [`AdaptiveMeanPool`](@ref), [`GlobalMaxPool`](@ref) -""" -struct GlobalMeanPool <: AbstractLuxLayer end - -function (g::GlobalMeanPool)(x, _, st::NamedTuple) - return meanpool(x, PoolDims(x, size(x)[1:(end - 2)])), st -end - -""" - AdaptiveMaxPool(out::NTuple) - -Adaptive Max Pooling layer. Calculates the necessary window size such that its output has -`size(y)[1:N] == out`. - -## Arguments - - - `out`: Size of the first `N` dimensions for the output - -## Inputs - - - `x`: Expects as input an array with `ndims(x) == N+2`, i.e. channel and batch - dimensions, after the `N` feature dimensions, where `N = length(out)`. - -## Returns - - - Output of size `(out..., C, N)` - - Empty `NamedTuple()` - -See also [`MaxPool`](@ref), [`AdaptiveMeanPool`](@ref). -""" -struct AdaptiveMaxPool{S, O <: Tuple{Vararg{IntegerType}}} <: AbstractLuxLayer - out::O - AdaptiveMaxPool(out) = new{length(out) + 2, typeof(out)}(out) -end - -function (a::AdaptiveMaxPool{S})(x::AbstractArray{T, S}, _, st::NamedTuple) where {S, T} - return maxpool(x, compute_adaptive_pooling_dims(x, a.out)), st -end - -Base.show(io::IO, a::AdaptiveMaxPool) = print(io, "AdaptiveMaxPool(", a.out, ")") - -""" - AdaptiveMeanPool(out::NTuple) - -Adaptive Mean Pooling layer. Calculates the necessary window size such that its output has -`size(y)[1:N] == out`. - -## Arguments - - - `out`: Size of the first `N` dimensions for the output - -## Inputs - - - `x`: Expects as input an array with `ndims(x) == N+2`, i.e. channel and batch - dimensions, after the `N` feature dimensions, where `N = length(out)`. - -## Returns - - - Output of size `(out..., C, N)` - - Empty `NamedTuple()` - -See also [`MeanPool`](@ref), [`AdaptiveMaxPool`](@ref). -""" -struct AdaptiveMeanPool{S, O <: Tuple{Vararg{IntegerType}}} <: AbstractLuxLayer - out::O - AdaptiveMeanPool(out) = new{length(out) + 2, typeof(out)}(out) -end - -function (a::AdaptiveMeanPool{S})(x::AbstractArray{T, S}, _, st::NamedTuple) where {S, T} - return meanpool(x, compute_adaptive_pooling_dims(x, a.out)), st -end - -Base.show(io::IO, a::AdaptiveMeanPool) = print(io, "AdaptiveMeanPool(", a.out, ")") diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl index 1912922863..a355b03a1c 100644 --- a/src/layers/normalize.jl +++ b/src/layers/normalize.jl @@ -323,8 +323,8 @@ Use `Lux.testmode` during inference. ## Example ```jldoctest -julia> Chain(Dense(784 => 64), InstanceNorm(64, relu), Dense(64 => 10), - InstanceNorm(10, relu)) +julia> Chain(Dense(784 => 64), InstanceNorm(64, relu; affine=true), Dense(64 => 10), + InstanceNorm(10, relu; affine=true)) Chain( layer_1 = Dense(784 => 64), # 50_240 parameters layer_2 = InstanceNorm(64, relu, affine=true, track_stats=false), # 128 parameters, plus 1 diff --git a/src/layers/pooling.jl b/src/layers/pooling.jl new file mode 100644 index 0000000000..815fb0e4c4 --- /dev/null +++ b/src/layers/pooling.jl @@ -0,0 +1,240 @@ +abstract type AbstractPoolMode end + +CRC.@non_differentiable (::AbstractPoolMode)(::Any...) + +@concrete struct GenericPoolMode <: AbstractPoolMode + kernel_size <: Tuple{Vararg{IntegerType}} + stride <: Tuple{Vararg{IntegerType}} + pad <: Tuple{Vararg{IntegerType}} + dilation <: Tuple{Vararg{IntegerType}} +end + +(m::GenericPoolMode)(x) = PoolDims(x, m.kernel_size; padding=m.pad, m.stride, m.dilation) + +struct GlobalPoolMode <: AbstractPoolMode end + +(::GlobalPoolMode)(x) = PoolDims(x, size(x)[1:(end - 2)]) + +@concrete struct AdaptivePoolMode <: AbstractPoolMode + out_size <: Tuple{Vararg{IntegerType}} +end + +function (m::AdaptivePoolMode)(x) + in_size = size(x)[1:(end - 2)] + stride = in_size .÷ m.out_size + kernel_size = in_size .- (m.out_size .- 1) .* stride + return PoolDims(x, kernel_size; padding=0, stride, dilation=1) +end + +symbol_to_pool_mode(::StaticSymbol{:generic}) = GenericPoolMode +symbol_to_pool_mode(::StaticSymbol{:global}) = GlobalPoolMode +symbol_to_pool_mode(::StaticSymbol{:adaptive}) = AdaptivePoolMode + +abstract type AbstractPoolOp end + +struct MaxPoolOp <: AbstractPoolOp end +(m::MaxPoolOp)(x, pdims) = maxpool(x, pdims) + +struct MeanPoolOp <: AbstractPoolOp end +(m::MeanPoolOp)(x, pdims) = meanpool(x, pdims) + +@concrete struct LpPoolOp <: AbstractPoolOp + p +end +(m::LpPoolOp)(x, pdims) = lpnormpool(x, pdims; m.p) + +symbol_to_pool_op(::StaticSymbol{:max}, _) = MaxPoolOp() +symbol_to_pool_op(::StaticSymbol{:mean}, _) = MeanPoolOp() +symbol_to_pool_op(::StaticSymbol{:lp}, p) = LpPoolOp(p) + +@concrete struct PoolingLayer <: AbstractLuxLayer + mode <: AbstractPoolMode + op <: AbstractPoolOp +end + +function PoolingLayer(mode::SymbolType, op::SymbolType, + arg::Union{Nothing, Tuple{Vararg{IntegerType}}}=nothing; + stride=arg, pad=0, dilation=1, p=2) + return PoolingLayer(symbol_to_pool_mode(static(mode)), + symbol_to_pool_op(static(op), p), arg; stride, pad, dilation) +end + +function PoolingLayer(::Type{GenericPoolMode}, op::AbstractPoolOp, + kernel_size::Tuple{Vararg{IntegerType}}; stride=kernel_size, pad=0, dilation=1) + stride = Utils.expand(Val(length(kernel_size)), stride) + pad = calc_padding(pad, kernel_size, dilation, stride) + dilation = Utils.expand(Val(length(kernel_size)), dilation) + @argcheck allequal(length, (stride, kernel_size, dilation)) + + return PoolingLayer(GenericPoolMode(kernel_size, stride, pad, dilation), op) +end + +function PoolingLayer(::Type{AdaptivePoolMode}, op::AbstractPoolOp, + out_size::Tuple{Vararg{IntegerType}}; kwargs...) + return PoolingLayer(AdaptivePoolMode(out_size), op) +end + +function PoolingLayer(::Type{GlobalPoolMode}, op::AbstractPoolOp, ::Nothing; kwargs...) + return PoolingLayer(GlobalPoolMode(), op) +end + +(m::PoolingLayer)(x, _, st::NamedTuple) = m.op(x, m.mode(x)), st + +for layer_op in (:Max, :Mean, :LP) + op = Symbol(lowercase(string(layer_op))) + + layer_name = Symbol(layer_op, :Pool) + extra_kwargs = layer_op == :LP ? ", p=2" : "" + layer_docstring = """ + $(layer_name)(window; stride=window, pad=0, dilation=1$(extra_kwargs)) + + $(layer_op) Pooling layer, which replaces all pixels in a block of size `window` with + the reduction operation: $(op). + + ## Arguments + + - `window`: Tuple of integers specifying the size of the window. Eg, for 2D pooling + `length(window) == 2` + + ## Keyword Arguments + + - `stride`: Should each be either single integer, or a tuple with `N` integers + - `dilation`: Should each be either single integer, or a tuple with `N` integers + + - `pad`: Specifies the number of elements added to the borders of the data array. It can + be + + + a single integer for equal padding all around, + + a tuple of `N` integers, to apply the same padding at begin/end of each spatial + dimension, + + a tuple of `2*N` integers, for asymmetric padding, or + + the singleton `SamePad()`, to calculate padding such that + `size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial + dimension. + + # Extended Help + + ## Inputs + + - `x`: Data satisfying `ndims(x) == N + 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` + + ## Returns + + - Output of the pooling `y` of size `(O_N, ..., O_1, C, N)` where + + ```math + O_i = \\left\\lfloor\\frac{I_i + p_i + p_{(i + N) \\% |p|} - d_i \\times (k_i - 1)}{s_i} + 1\\right\\rfloor + ``` + + - Empty `NamedTuple()` + """ + + global_layer_name = Symbol(:Global, layer_name) + extra_kwargs = layer_op == :LP ? "; p=2" : "" + global_pooling_docstring = """ + $(global_layer_name)($(extra_kwargs)) + + Global $(layer_op) Pooling layer. Transforms `(w, h, c, b)`-shaped input into + `(1, 1, c, b)`-shaped output, by performing mean pooling on the complete `(w, h)`-shaped + feature maps. + + ## Inputs + + - `x`: Data satisfying `ndims(x) > 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` + + ## Returns + + - Output of the pooling `y` of size `(1, ..., 1, C, N)` + - Empty `NamedTuple()` + """ + + adaptive_layer_name = Symbol(:Adaptive, layer_name) + adaptive_pooling_docstring = """ + $(adaptive_layer_name)(output_size$(extra_kwargs)) + + Adaptive $(layer_op) Pooling layer. Calculates the necessary window size such that + its output has `size(y)[1:N] == output_size`. + + ## Arguments + + - `output_size`: Size of the first `N` dimensions for the output + + ## Inputs + + - `x`: Expects as input an array with `ndims(x) == N + 2`, i.e. channel and batch + dimensions, after the `N` feature dimensions, where `N = length(output_size)`. + + ## Returns + + - Output of size `(out..., C, N)` + - Empty `NamedTuple()` + """ + + @eval begin + # Generic Pooling Layer + @doc $(layer_docstring) @concrete struct $(layer_name) <: + AbstractLuxWrapperLayer{:layer} + layer <: PoolingLayer + end + + function $(layer_name)( + window::Tuple{Vararg{IntegerType}}; stride=window, pad=0, dilation=1, p=2) + return $(layer_name)(PoolingLayer(static(:generic), static($(Meta.quot(op))), + window; stride, pad, dilation, p)) + end + + function Base.show(io::IO, ::MIME"text/plain", m::$(layer_name)) + kernel_size = m.layer.mode.kernel_size + print(io, string($(Meta.quot(layer_name))), "($(kernel_size)") + pad = m.layer.mode.pad + all(==(0), pad) || print(io, ", pad=", PrettyPrinting.tuple_string(pad)) + stride = m.layer.mode.stride + stride == kernel_size || + print(io, ", stride=", PrettyPrinting.tuple_string(stride)) + dilation = m.layer.mode.dilation + all(==(1), dilation) || + print(io, ", dilation=", PrettyPrinting.tuple_string(dilation)) + if $(Meta.quot(op)) == :lp + a.layer.op.p == 2 || print(io, ", p=", a.layer.op.p) + end + print(io, ")") + end + + # Global Pooling Layer + @doc $(global_pooling_docstring) @concrete struct $(global_layer_name) <: + AbstractLuxWrapperLayer{:layer} + layer <: PoolingLayer + end + + function $(global_layer_name)(; p=2) + return $(global_layer_name)(PoolingLayer(static(:global), $(Meta.quot(op)); p)) + end + + function Base.show(io::IO, ::MIME"text/plain", g::$(global_layer_name)) + print(io, string($(Meta.quot(global_layer_name))), "(") + if $(Meta.quot(op)) == :lp + a.layer.op.p == 2 || print(io, ", p=", a.layer.op.p) + end + print(io, ")") + end + + # Adaptive Pooling Layer + @doc $(adaptive_pooling_docstring) @concrete struct $(adaptive_layer_name) <: + AbstractLuxWrapperLayer{:layer} + layer <: PoolingLayer + end + + function $(adaptive_layer_name)(out_size::Tuple{Vararg{IntegerType}}; p=2) + return $(adaptive_layer_name)(PoolingLayer( + static(:adaptive), $(Meta.quot(op)), out_size; p)) + end + + function Base.show(io::IO, ::MIME"text/plain", a::$(adaptive_layer_name)) + print(io, string($(Meta.quot(adaptive_layer_name))), "(", a.layer.mode.out_size) + if $(Meta.quot(op)) == :lp + a.layer.op.p == 2 || print(io, ", p=", a.layer.op.p) + end + print(io, ")") + end + end +end