Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow scaling of init functions #2375

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions docs/src/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ julia> Dense(4 => 5, tanh; init=Flux.randn32(MersenneTwister(1)))
Dense(4 => 5, tanh) # 25 parameters
```

All of the initialisation functions may be multiplied by a number
to scale their output:

```jldoctest; setup = :(using Flux, Random)
julia> lay = Dense(3 => 1, relu; init=42*Flux.ones32)
Dense(3 => 1, relu) # 4 parameters

julia> lay.weight
1×3 Matrix{Float32}:
42.0 42.0 42.0

julia> lay.bias
1-element Vector{Float32}:
0.0
```

## Initialisation functions

```@docs
Expand Down
71 changes: 67 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,57 @@ rng_from_array(::AbstractArray) = Random.default_rng()
@non_differentiable rng_from_array(::Any)


"""
    FixRNG(init, [rng], scale = 1f0; kw...)

This is a bit like `Base.Fix1` in that `FixRNG(randn, rng)` makes a function,
but also allows for scaling by a factor.

It exists to allow modifying initialisation functions:
```
julia> 2 * randn32
Flux.FixRNG(randn32, 2.0)

julia> Dense(3 => 1, init=pi*ones32)
Dense(3 => 1) # 4 parameters

julia> ans.weight
1×3 Matrix{Float32}:
3.14159 3.14159 3.14159
```

The struct itself is not part of Flux's API, so using it directly is not recommended.
"""
# Internal callable wrapper: pins an optional RNG and construction-time keywords
# to an initialisation function, and records a multiplicative `scale` applied to
# its output. Subtypes `Function` so instances can be passed anywhere an `init`
# function is expected.
struct FixRNG{F<:Function, R<:Tuple, K<:NamedTuple} <: Function
    fun::F          # wrapped initialisation function, e.g. `randn32` or `glorot_uniform`
    args::R         # `()` or `(rng,)` — splatted before the size arguments on call
    kwargs::K       # keywords captured at construction; NOTE(review): the call
                    # overload does not forward these — looks like a bug, verify
    scale::Float32  # multiplicative factor; `1f0` means "no scaling"
end
# Without an RNG: capture only the scale (default 1) and any keywords.
FixRNG(f::Function, scale::Real=1f0; kw...) = FixRNG(f, (), NamedTuple(kw), scale)
# With an RNG: store it in `args` so it is passed first when the wrapper is called.
FixRNG(f::Function, rng::AbstractRNG, scale::Real=1f0; kw...) = FixRNG(f, (rng,), NamedTuple(kw), scale)

"""
    (init::FixRNG)(size...; kw...)

Call the wrapped initialisation function with the stored RNG (if any), the
keywords captured at construction time, and any call-site keywords, then
multiply the result by `init.scale`.

Bug fix: the construction-time keywords (`init.kwargs`) were previously
dropped, so e.g. `glorot_uniform(rng; gain=2)` silently ignored `gain`.
Call-site keywords come last, so they win on a name clash — matching the
old closure form `(dims...; kwargs...) -> f(rng, dims...; init_kwargs..., kwargs...)`.
"""
function (init::FixRNG)(args...; kw...)
    raw = init.fun(init.args..., args...; init.kwargs..., kw...)
    if isone(init.scale)
        return raw
    elseif raw isa Array{<:AbstractFloat}
        # `raw` was freshly allocated above, so scaling in place is safe
        # and saves one allocation.
        return lmul!(init.scale, raw)
    else
        # Generic fallback (e.g. non-Array or non-float outputs): scale
        # element-wise, promoting to the float type of `raw`.
        return @. oftype(float(raw), init.scale * raw)
    end
end

Base.:*(λ::Real, init::FixRNG) = FixRNG(init.fun, init.args, init.kwargs, Float32(λ * init.scale))

# Compact printed form, e.g. `Flux.FixRNG(randn32, 2.0)`; the RNG (if stored)
# is printed between the function and the scale, and a scale of 1 is omitted.
function Base.show(io::IO, init::FixRNG)
    print(io, "Flux.FixRNG(", init.fun)
    for extra in init.args
        print(io, ", ", extra)
    end
    isone(init.scale) || print(io, ", ", init.scale)
    print(io, ")")
end
# `FixRNG <: Function`, so without this override the REPL would use the
# verbose multi-line `Function` display; reuse the compact 2-arg form.
Base.show(io::IO, ::MIME"text/plain", init::FixRNG) = Base.show(io, init)

"""
glorot_uniform([rng], size...; gain = 1) -> Array
glorot_uniform([rng]; kw...) -> Function
Expand Down Expand Up @@ -87,7 +138,7 @@ function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1)
(rand(rng, Float32, dims...) .- 0.5f0) .* scale
end
# Size-only form: use the default RNG.
glorot_uniform(dims::Integer...; kw...) = glorot_uniform(default_rng(), dims...; kw...)
# NOTE(review): this closure method has the same signature as the FixRNG method
# on the next line, so the later definition overwrites it (old/new diff lines).
glorot_uniform(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...)
# Size-free form: return a callable initialiser that remembers `rng` and the
# keywords; wrapping in `FixRNG` also makes it scalable via `λ * init`.
glorot_uniform(rng::AbstractRNG=default_rng(); init_kwargs...) = FixRNG(glorot_uniform, rng; init_kwargs...)

ChainRulesCore.@non_differentiable glorot_uniform(::Any...)

Expand Down Expand Up @@ -130,7 +181,7 @@ function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1)
randn(rng, Float32, dims...) .* std
end
# Size-only form: use the default RNG.
glorot_normal(dims::Integer...; kwargs...) = glorot_normal(default_rng(), dims...; kwargs...)
# NOTE(review): this closure method has the same signature as the FixRNG method
# on the next line, so the later definition overwrites it (old/new diff lines).
glorot_normal(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...)
# Size-free form: return a callable initialiser that remembers `rng` and the
# keywords; wrapping in `FixRNG` also makes it scalable via `λ * init`.
glorot_normal(rng::AbstractRNG=default_rng(); init_kwargs...) = FixRNG(glorot_normal, rng; init_kwargs...)

ChainRulesCore.@non_differentiable glorot_normal(::Any...)

Expand Down Expand Up @@ -455,6 +506,9 @@ ChainRulesCore.@non_differentiable identity_init(::Any...)
ones32(size...) = ones(Float32, size...)

Return an `Array{Float32}` of the given `size` filled with 1s.

Multiplying by a number scales the output. Thus `init = 10 * ones32` is a function
which makes an array with all values `10f0`.
"""
ones32(dims...) = Base.ones(Float32, dims...)

Expand All @@ -473,17 +527,26 @@ When the size is not provided, `rand32(rng::AbstractRNG)` returns a function.
"""
# Concrete-size methods: Array{Float32} of uniform random numbers in [0, 1).
rand32(dims::Integer...) = Base.rand(Float32, dims...)
rand32(rng::AbstractRNG, dims::Integer...) = Base.rand(rng, Float32, dims...)
# NOTE(review): this closure method has the same signature as the FixRNG method
# on the next line, so the later definition overwrites it (old/new diff lines).
rand32(rng::AbstractRNG) = (dims...,) -> Base.rand(rng, Float32, dims...)
# Size-free form: a callable initialiser that remembers `rng`; `FixRNG` also
# makes it scalable via `λ * init`.
rand32(rng::AbstractRNG) = FixRNG(rand32, rng)

"""
randn32([rng], size...)

Return an `Array{Float32}` of the given `size`, filled like `randn`.
When the size is not provided, `randn32(rng::AbstractRNG)` returns a function.

Multiplying by a number scales the output. Thus `init = 10 * randn32` is a function
which makes an array of mean zero and standard deviation 10.
"""
# Concrete-size methods: Array{Float32} filled from the standard normal distribution.
randn32(dims::Integer...) = Base.randn(Float32, dims...)
randn32(rng::AbstractRNG, dims::Integer...) = Base.randn(rng, Float32, dims...)
# NOTE(review): this closure method has the same signature as the FixRNG method
# on the next line, so the later definition overwrites it (old/new diff lines).
randn32(rng::AbstractRNG) = (dims...,) -> Base.randn(rng, Float32, dims...)
# Size-free form: a callable initialiser that remembers `rng`; `FixRNG` also
# makes it scalable via `λ * init`.
randn32(rng::AbstractRNG) = FixRNG(randn32, rng)

# Make the size-only initialisers scalable: `λ * ones32`, `ones32 * λ`, and
# `ones32 / λ` each return a `FixRNG` wrapper whose output is multiplied by
# the factor. (`glorot_*` gain this via the `FixRNG` they already return.)
# A tuple (not a Vector) avoids an allocation at load time, and the unused
# second argument is left unnamed — the original name `fun` shadowed the
# loop variable.
for fun in (ones32, zeros32, rand32, randn32)
    @eval Base.:*(λ::Real, ::typeof($fun)) = λ * FixRNG($fun)
    @eval Base.:*(::typeof($fun), λ::Real) = λ * FixRNG($fun)
    @eval Base.:/(::typeof($fun), λ::Real) = (1/λ) * FixRNG($fun)
end

"""
create_bias(weights, bias, size...)
Expand Down
Loading