Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Jan 8, 2025
1 parent e5eeb7e commit 136addf
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 14 deletions.
6 changes: 2 additions & 4 deletions src/esn/esn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ julia> train_data = rand(Float32, 10, 100) # 10 features, 100 time steps
julia> esn = ESN(train_data, 10, 300; washout=10)
ESN(10 => 300)
```
"""
function ESN(train_data,
Expand Down Expand Up @@ -95,8 +94,9 @@ function (esn::AbstractEchoStateNetwork)(prediction::AbstractPrediction,
kwargs...)
end

Base.show(io::IO, esn::ESN) =
function Base.show(io::IO, esn::ESN)
print(io, "ESN(", size(esn.train_data, 1), " => ", size(esn.reservoir_matrix, 1), ")")
end

#training dispatch on esn
"""
Expand All @@ -110,7 +110,6 @@ Trains an Echo State Network (ESN) using the provided target data and a specifie
- `target_data`: Supervised training data for the ESN.
- `training_method`: The method for training the ESN (default: `StandardRidge(0.0)`).
# Example
```jldoctest
Expand All @@ -132,7 +131,6 @@ ESN(10 => 300)
julia> output_layer = train(esn, rand(Float32, 3, 90))
OutputLayer successfully trained with output size: 3
```
"""
function train(esn::AbstractEchoStateNetwork,
Expand Down
8 changes: 1 addition & 7 deletions src/esn/esn_reservoir_drivers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ specified reservoir driver.
and reservoir nodes.
- `bias_vector`: The bias vector to be added at each time step during the reservoir
update.
"""
function create_states(reservoir_driver::AbstractReservoirDriver,
train_data,
Expand Down Expand Up @@ -108,8 +107,6 @@ echo state networks (`ESN`).
Defaults to `tanh_fast`.
- `leaky_coefficient`: The leaky coefficient used in the RNN.
Defaults to 1.0.
"""
function RNN(; activation_function=NNlib.fast_act(tanh), leaky_coefficient=1.0)
RNN(activation_function, leaky_coefficient)
Expand Down Expand Up @@ -185,7 +182,6 @@ This function creates an MRNN object with the specified activation functions,
leaky coefficient, and scaling factors, which can be used as a reservoir driver
in the ESN.
[^Lun2015]: Lun, Shu-Xian, et al.
"_A novel model of leaky integrator echo state network for
time-series prediction._" Neurocomputing 159 (2015): 58-66.
Expand Down Expand Up @@ -234,10 +230,9 @@ end
Returns a Fully Gated Recurrent Unit (FullyGated) initializer
for the Echo State Network (ESN).
Returns the standard gated recurrent unit [^Cho2014] as a driver for the
Returns the standard gated recurrent unit [^Cho2014] as a driver for the
echo state network (`ESN`).
[^Cho2014]: Cho, Kyunghyun, et al.
"_Learning phrase representations using RNN encoder-decoder
for statistical machine translation._"
Expand Down Expand Up @@ -281,7 +276,6 @@ This driver is based on the GRU architecture [^Cho2014].
- `variant`: The GRU variant to use.
By default, it uses the "FullyGated" variant.
[^Cho2014]: Cho, Kyunghyun, et al.
"_Learning phrase representations using RNN encoder-decoder for statistical machine translation._"
arXiv preprint arXiv:1406.1078 (2014).
Expand Down
6 changes: 3 additions & 3 deletions src/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ struct OutputLayer{T, I, S, L} <: AbstractOutputLayer
last_value::L
end

Base.show(io::IO, ol::OutputLayer) =
function Base.show(io::IO, ol::OutputLayer)
print(io, "OutputLayer successfully trained with output size: ", ol.out_size)
end

#prediction types
"""
Expand Down Expand Up @@ -58,14 +59,13 @@ of input features (`prediction_data`).
The `Predictive` prediction method uses the provided input data
(`prediction_data`) to produce corresponding labels or outputs based
on the learned relationships in the model.
on the learned relationships in the model.
"""
function Predictive(prediction_data)
prediction_len = size(prediction_data, 2)
Predictive(prediction_data, prediction_len)
end


function obtain_prediction(rc::AbstractReservoirComputer,
prediction::Generative,
x,
Expand Down

0 comments on commit 136addf

Please sign in to comment.