This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

[WIP] Add initializers from ReservoirComputing #23

Draft · wants to merge 5 commits into base: main
Changes from 4 commits
6 changes: 5 additions & 1 deletion Project.toml
@@ -5,6 +5,7 @@ version = "0.1.7"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Member: This is quite heavy, we should make it into an extension.

Member (Author): I don't think we even need that. I was using it for Uniform, and now it's only needed for Bernoulli, so I'm sure we can avoid it altogether. I just needed to get the tests running to catch the various CUDA errors.
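For context, making Distributions a package extension (the same mechanism already used for CUDA below) would mean declaring it under `[weakdeps]` and pairing it with an extension module. A minimal sketch of the Project.toml side, with the extension module name being a hypothetical choice:

```toml
[weakdeps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

[extensions]
# Hypothetical module that would hold the Bernoulli/Uniform-dependent code.
WeightInitializersDistributionsExt = "Distributions"
```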

LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
@@ -22,10 +23,12 @@ WeightInitializersCUDAExt = "CUDA"
Aqua = "0.8"
CUDA = "5"
ChainRulesCore = "1.21"
Distributions = "0.25"
LinearAlgebra = "1.9"
PartialFunctions = "1.2"
PrecompileTools = "1.2"
Random = "1.9"
SafeTestsets = "0.1"
SpecialFunctions = "2"
StableRNGs = "1"
Statistics = "1.9"
@@ -36,9 +39,10 @@ julia = "1.9"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "StableRNGs", "Random", "Statistics", "CUDA"]
test = ["Aqua", "Test", "StableRNGs", "Random", "Statistics", "CUDA", "SafeTestsets"]
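SafeTestsets is added here as a test-only dependency; its `@safetestset` macro runs each test file inside a fresh module so state does not leak between test sets. A minimal sketch of how runtests.jl might use it, with a hypothetical test file name:

```julia
using SafeTestsets

@safetestset "ReservoirComputing initializers" begin
    include("test_rc_initializers.jl")  # hypothetical file for the new initializers
end
```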
36 changes: 34 additions & 2 deletions ext/WeightInitializersCUDAExt.jl
@@ -3,7 +3,7 @@ module WeightInitializersCUDAExt
using WeightInitializers, CUDA
using Random
import WeightInitializers: __partial_apply, NUM_TO_FPOINT, identity_init, sparse_init,
orthogonal
orthogonal, delay_line, delay_line_backward

const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG}

@@ -62,7 +62,39 @@ function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...;
end
end

for initializer in (:sparse_init, :identity_init)
# rc initializers

function delay_line(rng::AbstractCuRNG,
        ::Type{T},
        dims::Integer...;
        weight = T(0.1)) where {T <: Number}
    @assert length(dims) == 2 && dims[1] == dims[2] "The dimensions must define a square matrix (e.g., (100, 100))"
    reservoir_matrix = CUDA.zeros(T, dims...)

    # Fill the subdiagonal. Note: this is scalar indexing on a CuArray and
    # errors by default (or warns in interactive use) unless wrapped in
    # CUDA.@allowscalar.
    for i in 1:(dims[1] - 1)
        reservoir_matrix[i + 1, i] = T(weight)
    end

    return reservoir_matrix
end

function delay_line_backward(rng::AbstractCuRNG,
        ::Type{T},
        dims::Integer...;
        weight = T(0.1),
        fb_weight = T(0.2)) where {T <: Number}
    res_size = first(dims)
    reservoir_matrix = CUDA.zeros(T, dims...)

    # Forward weights on the subdiagonal, feedback weights on the
    # superdiagonal. Scalar indexing again: needs CUDA.@allowscalar on GPU.
    for i in 1:(res_size - 1)
        reservoir_matrix[i + 1, i] = T(weight)
        reservoir_matrix[i, i + 1] = T(fb_weight)
    end

    return reservoir_matrix
end
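The scalar-indexing loops above are likely what triggers the CUDA errors mentioned in the review thread. One way to avoid them is to fill each diagonal with a single vectorized assignment; a sketch of the idea (untested against this PR, and assuming `LinearAlgebra.diagind` ranges broadcast over CuArrays as they do over dense arrays):

```julia
using CUDA, LinearAlgebra

function delay_line_vectorized(::Type{T}, dims::Integer...;
        weight = T(0.1)) where {T <: Number}
    @assert length(dims) == 2 && dims[1] == dims[2]
    W = CUDA.zeros(T, dims...)
    # One broadcast kernel over the subdiagonal instead of a scalar loop.
    W[diagind(W, -1)] .= T(weight)
    return W
end
```

With the Float32 fallback loop below, callers can then write `delay_line(rng, 100, 100)` as shorthand for `delay_line(rng, Float32, 100, 100)`.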

for initializer in (:sparse_init, :identity_init, :delay_line, :delay_line_backward)
@eval function ($initializer)(rng::AbstractCuRNG, dims::Integer...; kwargs...)
return $initializer(rng, Float32, dims...; kwargs...)
end
42 changes: 40 additions & 2 deletions src/WeightInitializers.jl
@@ -4,11 +4,12 @@ import PrecompileTools: @recompile_invalidations

@recompile_invalidations begin
using ChainRulesCore, PartialFunctions, Random, SpecialFunctions, Statistics,
LinearAlgebra
LinearAlgebra, Distributions
end

include("utils.jl")
include("initializers.jl")
include("rc_initializers.jl")

# Mark the functions as non-differentiable
for f in [
@@ -43,11 +44,46 @@ for f in [
:truncated_normal,
:orthogonal,
:sparse_init,
:identity_init
:identity_init,
:rand_sparse,
:delay_line,
:delay_line_backward,
:cycle_jumps,
:simple_cycle,
:pseudo_svd,
:scaled_rand,
:weighted_init,
:informed_init,
:minimal_init
]
@eval @non_differentiable $(f)(::Any...)
end
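Marking the initializers `@non_differentiable` tells ChainRulesCore-aware AD backends to treat their outputs as constants, so no gradient flows into the RNG or the keyword arguments. A quick sanity check of the intent, assuming Zygote (not a dependency of this package) picks up the ChainRulesCore rules:

```julia
using Zygote, WeightInitializers

# The initializer call is opaque to AD: the gradient w.r.t. `w` is `nothing`.
g = Zygote.gradient(w -> sum(delay_line(4, 4; weight = w)), 0.1f0)
@assert g == (nothing,)
```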

for initializer in (:glorot_uniform, :glorot_normal, :kaiming_uniform, :kaiming_normal,
:truncated_normal, :orthogonal, :sparse_init, :identity_init, :rand_sparse, :delay_line,
:delay_line_backward, :cycle_jumps, :simple_cycle, :pseudo_svd, :scaled_rand,
:weighted_init, :informed_init, :minimal_init)
NType = ifelse(initializer === :truncated_normal, Real, Number)
@eval function ($initializer)(dims::Integer...; kwargs...)
return $initializer(_default_rng(), Float32, dims...; kwargs...)
end
@eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...)
return $initializer(rng, Float32, dims...; kwargs...)
end
@eval function ($initializer)(::Type{T},
dims::Integer...; kwargs...) where {T <: $NType}
return $initializer(_default_rng(), T, dims...; kwargs...)
end
@eval function ($initializer)(rng::AbstractRNG; kwargs...)
return __partial_apply($initializer, (rng, (; kwargs...)))
end
@eval function ($initializer)(rng::AbstractRNG,
::Type{T}; kwargs...) where {T <: $NType}
return __partial_apply($initializer, ((rng, T), (; kwargs...)))
end
@eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...))
end
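The two `__partial_apply` overloads at the end of this loop make every initializer curryable: supplying only an RNG (and optionally an element type or keywords) returns a closure that takes the dimensions later, which is the form layer constructors usually consume. A small usage sketch, assuming the CPU methods for the new initializers mirror the existing ones:

```julia
using WeightInitializers, Random

rng = Random.default_rng()

# Direct call: a 4×4 Float32 delay-line reservoir matrix.
W = delay_line(rng, 4, 4; weight = 0.2f0)

# Curried form: fix the RNG and keyword now, pass dims later.
init = delay_line(rng; weight = 0.2f0)
W2 = init(4, 4)  # equivalent to the direct call above
```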

export zeros64, ones64, rand64, randn64, zeros32, ones32, rand32, randn32, zeros16, ones16,
rand16, randn16
export zerosC64, onesC64, randC64, randnC64, zerosC32, onesC32, randC32, randnC32, zerosC16,
@@ -58,5 +94,7 @@ export truncated_normal
export orthogonal
export sparse_init
export identity_init
export scaled_rand, weighted_init, informed_init, minimal_init
export rand_sparse, delay_line, delay_line_backward, cycle_jumps, simple_cycle, pseudo_svd

end
23 changes: 0 additions & 23 deletions src/initializers.jl
@@ -330,29 +330,6 @@ function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
end

# Default Fallbacks for all functions
for initializer in (:glorot_uniform, :glorot_normal, :kaiming_uniform, :kaiming_normal,
:truncated_normal, :orthogonal, :sparse_init, :identity_init)
NType = ifelse(initializer === :truncated_normal, Real, Number)
@eval function ($initializer)(dims::Integer...; kwargs...)
return $initializer(_default_rng(), Float32, dims...; kwargs...)
end
@eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...)
return $initializer(rng, Float32, dims...; kwargs...)
end
@eval function ($initializer)(::Type{T},
dims::Integer...; kwargs...) where {T <: $NType}
return $initializer(_default_rng(), T, dims...; kwargs...)
end
@eval function ($initializer)(rng::AbstractRNG; kwargs...)
return __partial_apply($initializer, (rng, (; kwargs...)))
end
@eval function ($initializer)(rng::AbstractRNG,
::Type{T}; kwargs...) where {T <: $NType}
return __partial_apply($initializer, ((rng, T), (; kwargs...)))
end
@eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...))
end

for tp in ("16", "32", "64", "C16", "C32", "C64"), func in (:zeros, :ones, :randn, :rand)
initializer = Symbol(func, tp)
@eval function ($initializer)(dims::Integer...; kwargs...)