Skip to content

Commit

Permalink
feat: generalize pooling implementation and add LP versions
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 6, 2024
1 parent e47f063 commit fd66780
Show file tree
Hide file tree
Showing 6 changed files with 252 additions and 267 deletions.
3 changes: 3 additions & 0 deletions docs/src/api/Lux/layers.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,13 @@ VariationalHiddenDropout
## Pooling Layers

```@docs
AdaptiveLPPool
AdaptiveMaxPool
AdaptiveMeanPool
GlobalLPPool
GlobalMaxPool
GlobalMeanPool
LPPool
MaxPool
MeanPool
```
Expand Down
5 changes: 3 additions & 2 deletions ext/LuxSimpleChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ function Lux.make_simplechain_network(layer::FlattenLayer)
end

function Lux.make_simplechain_network(layer::MaxPool)
if layer.stride == layer.k && (!(layer.pad isa SamePad) && all(==(0), layer.pad))
return SimpleChains.MaxPool(layer.k)
if layer.layer.mode.stride == layer.layer.mode.kernel_size &&
all(==(0), layer.layer.mode.pad)
return SimpleChains.MaxPool(layer.layer.mode.kernel_size)
end
throw(SimpleChainsModelConversionException("MaxPool with non-standard parameters not \
supported."))
Expand Down
6 changes: 4 additions & 2 deletions src/Lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ include("layers/basic.jl")
include("layers/containers.jl")
include("layers/normalize.jl")
include("layers/conv.jl")
include("layers/pooling.jl")
include("layers/dropout.jl")
include("layers/recurrent.jl")
include("layers/extension.jl")
Expand Down Expand Up @@ -87,8 +88,9 @@ include("distributed/public_api.jl")
# Layers
export Chain, Parallel, SkipConnection, PairwiseFusion, BranchLayer, Maxout, RepeatedLayer
export Bilinear, Dense, Embedding, Scale
export Conv, ConvTranspose, MaxPool, MeanPool, GlobalMaxPool, GlobalMeanPool,
AdaptiveMaxPool, AdaptiveMeanPool, Upsample, PixelShuffle
export Conv, ConvTranspose, Upsample, PixelShuffle
export MaxPool, MeanPool, LPPool, GlobalMaxPool, GlobalMeanPool, GlobalLPPool,
AdaptiveMaxPool, AdaptiveMeanPool, AdaptiveLPPool
export AlphaDropout, Dropout, VariationalHiddenDropout
export BatchNorm, GroupNorm, InstanceNorm, LayerNorm
export WeightNorm
Expand Down
261 changes: 0 additions & 261 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,6 @@ CRC.@non_differentiable conv_transpose_dims(::Any...)

conv_transpose(x, weight, cdims) = LuxLib.Impl.∇conv_data(x, weight, cdims)

function compute_adaptive_pooling_dims(x::AbstractArray, outsize)
insize = size(x)[1:(end - 2)]
stride = insize outsize
k = insize .- (outsize .- 1) .* stride
return PoolDims(x, k; padding=0, stride=stride)
end

CRC.@non_differentiable compute_adaptive_pooling_dims(::Any, ::Any)

function init_conv_weight(
rng::AbstractRNG, init_weight::F, filter::NTuple{N, <:IntegerType},
in_chs::IntegerType, out_chs::IntegerType, groups, σ::A) where {F, N, A}
Expand Down Expand Up @@ -508,255 +499,3 @@ end
function PixelShuffle(r::IntegerType)
return PixelShuffle(WrappedFunction(Base.Fix2(pixel_shuffle, r)))
end

@doc doc"""
MaxPool(window::NTuple; pad=0, stride=window)
Max pooling layer, which replaces all pixels in a block of size `window` with the maximum
value.
# Arguments
- `window`: Tuple of integers specifying the size of the window. Eg, for 2D pooling
`length(window) == 2`
## Keyword Arguments
- `stride`: Should each be either single integer, or a tuple with `N` integers
- `pad`: Specifies the number of elements added to the borders of the data array. It can
be
+ a single integer for equal padding all around,
+ a tuple of `N` integers, to apply the same padding at begin/end of each spatial
dimension,
+ a tuple of `2*N` integers, for asymmetric padding, or
+ the singleton `SamePad()`, to calculate padding such that
`size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial
dimension.
# Extended Help
## Inputs
- `x`: Data satisfying `ndims(x) == N + 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)`
## Returns
- Output of the pooling `y` of size `(O_N, ..., O_1, C, N)` where
```math
O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s_i} + 1\right\rfloor
```
- Empty `NamedTuple()`
See also [`Conv`](@ref), [`MeanPool`](@ref), [`GlobalMaxPool`](@ref),
[`AdaptiveMaxPool`](@ref)
"""
@concrete struct MaxPool <: AbstractLuxLayer
k <: Tuple{Vararg{IntegerType}}
pad <: Tuple{Vararg{IntegerType}}
stride <: Tuple{Vararg{IntegerType}}
end

function MaxPool(k::Tuple{Vararg{IntegerType}}; pad=0, stride=k)
stride = Utils.expand(Val(length(k)), stride)
pad = calc_padding(pad, k, 1, stride)
@argcheck allequal(length, (stride, k))

return MaxPool(k, pad, stride)
end

function (m::MaxPool)(x, _, st::NamedTuple)
return maxpool(x, PoolDims(x, m.k; padding=m.pad, m.stride)), st
end

function Base.show(io::IO, m::MaxPool)
print(io, "MaxPool(", m.k)
all(==(0), m.pad) || print(io, ", pad=", PrettyPrinting.tuple_string(m.pad))
m.stride == m.k || print(io, ", stride=", PrettyPrinting.tuple_string(m.stride))
print(io, ")")
end

@doc doc"""
MeanPool(window::NTuple; pad=0, stride=window)
Mean pooling layer, which replaces all pixels in a block of size `window` with the mean
value.
# Arguments
- `window`: Tuple of integers specifying the size of the window. Eg, for 2D pooling
`length(window) == 2`
## Keyword Arguments
- `stride`: Should each be either single integer, or a tuple with `N` integers
- `pad`: Specifies the number of elements added to the borders of the data array. It can
be
+ a single integer for equal padding all around,
+ a tuple of `N` integers, to apply the same padding at begin/end of each spatial
dimension,
+ a tuple of `2*N` integers, for asymmetric padding, or
+ the singleton `SamePad()`, to calculate padding such that
`size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial
dimension.
# Extended Help
## Inputs
- `x`: Data satisfying `ndims(x) == N + 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)`
## Returns
- Output of the pooling `y` of size `(O_N, ..., O_1, C, N)` where
```math
O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s_i} + 1\right\rfloor
```
- Empty `NamedTuple()`
See also [`Conv`](@ref), [`MaxPool`](@ref), [`GlobalMeanPool`](@ref),
[`AdaptiveMeanPool`](@ref)
"""
@concrete struct MeanPool <: AbstractLuxLayer
k <: Tuple{Vararg{IntegerType}}
pad <: Tuple{Vararg{IntegerType}}
stride <: Tuple{Vararg{IntegerType}}
end

function MeanPool(k::Tuple{Vararg{IntegerType}}; pad=0, stride=k)
stride = Utils.expand(Val(length(k)), stride)
pad = calc_padding(pad, k, 1, stride)
@argcheck allequal(length, (stride, k))

return MeanPool(k, pad, stride)
end

function (m::MeanPool)(x, _, st::NamedTuple)
return meanpool(x, PoolDims(x, m.k; padding=m.pad, m.stride)), st
end

function Base.show(io::IO, m::MeanPool)
print(io, "MeanPool(", m.k)
all(==(0), m.pad) || print(io, ", pad=", PrettyPrinting.tuple_string(m.pad))
m.stride == m.k || print(io, ", stride=", PrettyPrinting.tuple_string(m.stride))
print(io, ")")
end

"""
GlobalMaxPool()
Global Max Pooling layer. Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output,
by performing max pooling on the complete (w,h)-shaped feature maps.
## Inputs
- `x`: Data satisfying `ndims(x) > 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)`
## Returns
- Output of the pooling `y` of size `(1, ..., 1, C, N)`
- Empty `NamedTuple()`
See also [`MaxPool`](@ref), [`AdaptiveMaxPool`](@ref), [`GlobalMeanPool`](@ref)
"""
struct GlobalMaxPool <: AbstractLuxLayer end

function (g::GlobalMaxPool)(x, _, st::NamedTuple)
return maxpool(x, PoolDims(x, size(x)[1:(end - 2)])), st
end

"""
GlobalMeanPool()
Global Mean Pooling layer. Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output,
by performing mean pooling on the complete (w,h)-shaped feature maps.
## Inputs
- `x`: Data satisfying `ndims(x) > 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)`
## Returns
- Output of the pooling `y` of size `(1, ..., 1, C, N)`
- Empty `NamedTuple()`
See also [`MeanPool`](@ref), [`AdaptiveMeanPool`](@ref), [`GlobalMaxPool`](@ref)
"""
struct GlobalMeanPool <: AbstractLuxLayer end

function (g::GlobalMeanPool)(x, _, st::NamedTuple)
return meanpool(x, PoolDims(x, size(x)[1:(end - 2)])), st
end

"""
AdaptiveMaxPool(out::NTuple)
Adaptive Max Pooling layer. Calculates the necessary window size such that its output has
`size(y)[1:N] == out`.
## Arguments
- `out`: Size of the first `N` dimensions for the output
## Inputs
- `x`: Expects as input an array with `ndims(x) == N+2`, i.e. channel and batch
dimensions, after the `N` feature dimensions, where `N = length(out)`.
## Returns
- Output of size `(out..., C, N)`
- Empty `NamedTuple()`
See also [`MaxPool`](@ref), [`AdaptiveMeanPool`](@ref).
"""
struct AdaptiveMaxPool{S, O <: Tuple{Vararg{IntegerType}}} <: AbstractLuxLayer
out::O
AdaptiveMaxPool(out) = new{length(out) + 2, typeof(out)}(out)
end

function (a::AdaptiveMaxPool{S})(x::AbstractArray{T, S}, _, st::NamedTuple) where {S, T}
return maxpool(x, compute_adaptive_pooling_dims(x, a.out)), st
end

Base.show(io::IO, a::AdaptiveMaxPool) = print(io, "AdaptiveMaxPool(", a.out, ")")

"""
AdaptiveMeanPool(out::NTuple)
Adaptive Mean Pooling layer. Calculates the necessary window size such that its output has
`size(y)[1:N] == out`.
## Arguments
- `out`: Size of the first `N` dimensions for the output
## Inputs
- `x`: Expects as input an array with `ndims(x) == N+2`, i.e. channel and batch
dimensions, after the `N` feature dimensions, where `N = length(out)`.
## Returns
- Output of size `(out..., C, N)`
- Empty `NamedTuple()`
See also [`MeanPool`](@ref), [`AdaptiveMaxPool`](@ref).
"""
struct AdaptiveMeanPool{S, O <: Tuple{Vararg{IntegerType}}} <: AbstractLuxLayer
out::O
AdaptiveMeanPool(out) = new{length(out) + 2, typeof(out)}(out)
end

function (a::AdaptiveMeanPool{S})(x::AbstractArray{T, S}, _, st::NamedTuple) where {S, T}
return meanpool(x, compute_adaptive_pooling_dims(x, a.out)), st
end

Base.show(io::IO, a::AdaptiveMeanPool) = print(io, "AdaptiveMeanPool(", a.out, ")")
4 changes: 2 additions & 2 deletions src/layers/normalize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,8 @@ Use `Lux.testmode` during inference.
## Example
```jldoctest
julia> Chain(Dense(784 => 64), InstanceNorm(64, relu), Dense(64 => 10),
InstanceNorm(10, relu))
julia> Chain(Dense(784 => 64), InstanceNorm(64, relu; affine=true), Dense(64 => 10),
InstanceNorm(10, relu; affine=true))
Chain(
layer_1 = Dense(784 => 64), # 50_240 parameters
layer_2 = InstanceNorm(64, relu, affine=true, track_stats=false), # 128 parameters, plus 1
Expand Down
Loading

0 comments on commit fd66780

Please sign in to comment.