Skip to content

Commit

Permalink
some renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Nov 3, 2024
1 parent d25c32b commit 65dbfe1
Show file tree
Hide file tree
Showing 10 changed files with 104 additions and 48 deletions.
14 changes: 10 additions & 4 deletions src/indrnn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,17 @@ end
Flux.@layer IndRNNCell

"""
IndRNNCell((in, out)::Pair, σ=relu; init = glorot_uniform, bias = true)
function IndRNNCell((in, out)::Pair, σ=relu;
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)
"""
function IndRNNCell((in, out)::Pair, σ=relu; init = glorot_uniform, bias = true)
Wi = init(out, in)
u = init(out)
function IndRNNCell((in, out)::Pair, σ=relu;
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)
Wi = kernel_init(out, in)
u = recrrent_kernel_init(out)
b = create_bias(Wi, bias, size(Wi, 1))
return IndRNNCell(σ, Wi, u, b)
end
Expand Down
14 changes: 10 additions & 4 deletions src/lightru_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@ end
Flux.@layer LightRUCell

"""
LightRUCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true)
LightRUCell((in, out)::Pair, σ=tanh;
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)
"""
function LightRUCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true)
Wi = init(2 * out, in)
Wh = init(out, out)
function LightRUCell((in, out)::Pair, σ=tanh;
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)
Wi = kernel_init(2 * out, in)
Wh = recurrent_kernel_init(out, out)
b = create_bias(Wi, bias, size(Wh, 1))

return LightRUCell(Wi, Wh, b)
Expand Down
12 changes: 8 additions & 4 deletions src/ligru_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@ end
Flux.@layer LiGRUCell

"""
LiGRUCell((in, out)::Pair; init = glorot_uniform, bias = true)
LiGRUCell((in, out)::Pair;
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)
"""
function LiGRUCell((in, out)::Pair;
init = glorot_uniform,
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)

Wi = init(out * 2, in)
Wh = init(out * 2, out)
Wi = kernel_init(out * 2, in)
Wh = recurrent_kernel_init(out * 2, out)
b = create_bias(Wi, bias, size(Wi, 1))

return LiGRUCell(Wi, Wh, b)
Expand Down
12 changes: 8 additions & 4 deletions src/mgu_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@ end
Flux.@layer MGUCell

"""
MGUCell((in, out)::Pair; init = glorot_uniform, bias = true)
MGUCell((in, out)::Pair;
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)
"""
function MGUCell((in, out)::Pair;
init = glorot_uniform,
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)

Wi = init(out * 2, in)
Wh = init(out * 2, out)
Wi = kernel_init(out * 2, in)
Wh = recurrent_kernel_init(out * 2, out)
b = create_bias(Wi, bias, size(Wi, 1))

return MGUCell(Wi, Wh, b)
Expand Down
36 changes: 24 additions & 12 deletions src/mut_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@ end
Flux.@layer MUT1Cell

"""
MUT1Cell((in, out)::Pair; init = glorot_uniform, bias = true)
MUT1Cell((in, out)::Pair;
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)
"""
function MUT1Cell((in, out)::Pair;
init = glorot_uniform,
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)

Wi = init(out * 3, in)
Wh = init(out * 2, out)
Wi = kernel_init(out * 3, in)
Wh = recurrent_kernel_init(out * 2, out)
b = create_bias(Wi, bias, 3 * out)

return MUT1Cell(Wi, Wh, b)
Expand Down Expand Up @@ -55,14 +59,18 @@ end
Flux.@layer MUT2Cell

"""
MUT2Cell((in, out)::Pair; init = glorot_uniform, bias = true)
MUT2Cell((in, out)::Pair;
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)
"""
function MUT2Cell((in, out)::Pair;
init = glorot_uniform,
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)

Wi = init(out * 3, in)
Wh = init(out * 3, out)
Wi = kernel_init(out * 3, in)
Wh = recurrent_kernel_init(out * 3, out)
b = create_bias(Wi, bias, 3 * out)

return MUT2Cell(Wi, Wh, b)
Expand Down Expand Up @@ -102,14 +110,18 @@ end
Flux.@layer MUT3Cell

"""
MUT3Cell((in, out)::Pair; init = glorot_uniform, bias = true)
MUT3Cell((in, out)::Pair;
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)
"""
function MUT3Cell((in, out)::Pair;
init = glorot_uniform,
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)

Wi = init(out * 3, in)
Wh = init(out * 3, out)
Wi = kernel_init(out * 3, in)
Wh = recurrent_kernel_init(out * 3, out)
b = create_bias(Wi, bias, 3 * out)

return MUT3Cell(Wi, Wh, b)
Expand Down
14 changes: 10 additions & 4 deletions src/nas_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,17 @@ end
Flux.@layer NASCell

"""
NASCell((in, out)::Pair; init = glorot_uniform, bias = true)
NASCell((in, out)::Pair;
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)
"""
function NASCell((in, out)::Pair; init = glorot_uniform, bias = true)
Wi = init(8 * out, in)
Wh = init(8 * out, out)
function NASCell((in, out)::Pair;
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)
Wi = kernel_init(8 * out, in)
Wh = recurrent_kernel_init(8 * out, out)
b = create_bias(Wi, bias, size(Wh, 1))
return NASCell(Wi, Wh, b)
end
Expand Down
14 changes: 10 additions & 4 deletions src/ran_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ Flux.@layer RANCell


"""
RANCell(in => out; init = glorot_uniform, bias = true)
RANCell((in, out)::Pair;
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)
The `RANCell`, introduced in [this paper](https://arxiv.org/pdf/1705.07393),
is a recurrent cell layer which provides additional memory through the
Expand Down Expand Up @@ -51,9 +54,12 @@ result = rancell(inp)
result_state = rancell(inp, (state, c_state))
```
"""
function RANCell((in, out)::Pair; init = glorot_uniform, bias = true)
Wi = init(3 * out, in)
Wh = init(2 * out, out)
function RANCell((in, out)::Pair;
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)
Wi = kernel_init(3 * out, in)
Wh = recurrent_kernel_init(2 * out, out)
b = create_bias(Wi, bias, size(Wh, 1))
return RANCell(Wi, Wh, b)
end
Expand Down
10 changes: 7 additions & 3 deletions src/rhn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@ end
Flux.@layer RHNCellUnit

"""
RHNCellUnit((in, out)::Pair; init = glorot_uniform, bias = true)
RHNCellUnit((in, out)::Pair;
kernel_init = glorot_uniform,
bias = true)
"""
function RHNCellUnit((in, out)::Pair; init = glorot_uniform, bias = true)
weight = init(3 * out, in)
function RHNCellUnit((in, out)::Pair;
kernel_init = glorot_uniform,
bias = true)
weight = kernel_init(3 * out, in)
b = create_bias(weight, bias, size(weight, 1))
return RHNCellUnit(weight, b)
end
Expand Down
15 changes: 10 additions & 5 deletions src/scrn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,21 @@ Flux.@layer SCRNCell


"""
SCRNCell(in => out; init = glorot_uniform, bias = true)
function SCRNCell((in, out)::Pair;
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true,
alpha = 0.0)
"""
function SCRNCell((in, out)::Pair;
init = glorot_uniform,
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true,
alpha = 0.0)

Wi = init(2 * out, in)
Wh = init(2 * out, out)
Wc = init(2 * out, out)
Wi = kernel_init(2 * out, in)
Wh = recurrent_kernel_init(2 * out, out)
Wc = recurrent_kernel_init(2 * out, out)
b = create_bias(Wi, bias, size(Wh, 1))
return SCRNCell(Wi, Wh, Wc, b, alpha)
end
Expand Down
11 changes: 7 additions & 4 deletions src/sru_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ end

Flux.@layer SRUCell

function SRUCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true)
Wi = init(2 * out, in)
Wh = init(2 * out, out)
v = init(2 * out)
function SRUCell((in, out)::Pair, σ=tanh;
kernel_init = glorot_uniform,
recurrent_kernel_init = glorot_uniform,
bias = true)
Wi = kernel_init(2 * out, in)
Wh = recurrent_kernel_init(2 * out, out)
v = kernel_init(2 * out)
b = create_bias(Wi, bias, size(Wh, 1))

return SRUCell(Wi, Wh, v, b)
Expand Down

0 comments on commit 65dbfe1

Please sign in to comment.