[ENH] Add and Validate n_layers, n_units, activation & dropout_rate kwargs to MLPNetwork #2338
@@ -3,6 +3,8 @@
__maintainer__ = ["hadifawaz1999"]

from typing import Union

from aeon.networks.base import BaseDeepLearningNetwork
@@ -11,6 +13,17 @@ class MLPNetwork(BaseDeepLearningNetwork):

    Adapted from the implementation used in [1]_

    Parameters
    ----------
    n_layers : int, optional (default=3)
        The number of dense layers in the MLP.
    n_units : Union[int, list[int]], optional (default=200)
        Number of units in each dense layer.
    activation : Union[str, list[str]], optional (default='relu')
        Activation function(s) for each dense layer.
    dropout_rate : Union[int, float, list[Union[int, float]]], optional (default=None)
        Dropout rate(s) for each dense layer. If None, a default rate of 0.2 is used.

    Notes
    -----
    Adapted from the implementation from source code
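For orientation, a minimal usage sketch of the constructor signature proposed in this diff. The import path follows aeon's public networks module, the input shape is purely illustrative, and the defaults may still change per the review comments below:

```python
from aeon.networks import MLPNetwork

# Three dense layers of 500 units each with per-layer dropout rates,
# mirroring the hard-coded architecture on main.
network = MLPNetwork(
    n_layers=3,
    n_units=500,
    activation="relu",
    dropout_rate=[0.1, 0.2, 0.2],
)

# Building the Keras graph requires tensorflow; the shape here is illustrative.
input_layer, output_layer = network.build_network(input_shape=(1, 152))
```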
@@ -24,9 +37,56 @@ class MLPNetwork(BaseDeepLearningNetwork):

    def __init__(
        self,
        n_layers: int = 3,
        n_units: Union[int, list[int]] = 200,
        activation: Union[str, list[str]] = "relu",
        dropout_rate: Union[int, list[int]] = None,
    ):
        super().__init__()

        self._n_layers = n_layers
Review comment: no need to define a private version with an underscore for this one, as it won't change.
        if isinstance(activation, str):
            self._activation = [activation] * self._n_layers

Review comment: better to declare all of them as self before defining the internal versions, so self.activation = activation and then define self._activation.

        elif isinstance(activation, list):
            assert (
                len(activation) == self._n_layers
            ), "There should be an `activation` function associated with each layer."
            assert all(
                isinstance(a, str) for a in activation
            ), "Activation must be a list of strings."
            assert (
                len(activation) == n_layers
            ), "Activation list length must match number of layers."

Review comment: asserting the len() twice.

            self._activation = activation
        if dropout_rate is None:
            self._dropout_rate = [0.2] * self._n_layers

Review comment: the default value of dropout is not 0.2 in the original implementation (check the main branch), it's 0.1, then 0.2, then 0.2.

        elif isinstance(dropout_rate, (int, float)):
            self._dropout_rate = [float(dropout_rate)] * self._n_layers
        elif isinstance(dropout_rate, list):
            assert (
                len(dropout_rate) == self._n_layers
            ), "There should be a `dropout_rate` associated with each layer."
            assert all(
                isinstance(d, (int, float)) for d in dropout_rate
            ), "Dropout rates must be int or float."
            assert (
                len(dropout_rate) == n_layers
            ), "Dropout list length must match number of layers."

Review comment: asserting on len twice.

            self._dropout_rate = [float(d) for d in dropout_rate]
        if isinstance(n_units, int):
            self._n_units = [n_units] * self._n_layers
        elif isinstance(n_units, list):
            assert all(
                isinstance(u, int) for u in n_units
            ), "`n_units` must be int for all layers."
            assert (
                len(n_units) == n_layers
            ), "`n_units` length must match number of layers."
            self._n_units = n_units
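Taken together, the review comments above suggest normalising each kwarg to a per-layer list exactly once, with a single length check, and storing the public kwargs (self.n_layers, self.activation, ...) before deriving the private lists. A standalone sketch of that idea; the helper name `_broadcast_per_layer` is hypothetical and not part of this PR or of aeon:

```python
def _broadcast_per_layer(value, n_layers, name):
    """Hypothetical helper: expand a scalar kwarg to one value per layer,
    or check that a list already holds exactly n_layers entries."""
    if not isinstance(value, list):
        return [value] * n_layers
    assert (
        len(value) == n_layers
    ), f"`{name}` must have one entry per layer ({n_layers}), got {len(value)}."
    return list(value)


# Defaults matching the main branch, per the review comment above:
# dropout is 0.1 for the first layer and 0.2 for the rest.
n_layers = 3
n_units = _broadcast_per_layer(500, n_layers, "n_units")           # [500, 500, 500]
activation = _broadcast_per_layer("relu", n_layers, "activation")  # ['relu', 'relu', 'relu']
dropout_rate = _broadcast_per_layer([0.1] + [0.2] * (n_layers - 1), n_layers, "dropout_rate")
```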
    def build_network(self, input_shape, **kwargs):
        """Construct a network and return its input and output layers.
@@ -42,19 +102,18 @@ def build_network(self, input_shape, **kwargs):
        """
        from tensorflow import keras

        # flattened because multivariate should be on same axis
        input_layer = keras.layers.Input(input_shape)
        input_layer_flattened = keras.layers.Flatten()(input_layer)

-        layer_1 = keras.layers.Dropout(0.1)(input_layer_flattened)
-        layer_1 = keras.layers.Dense(500, activation="relu")(layer_1)
-
-        layer_2 = keras.layers.Dropout(0.2)(layer_1)
-        layer_2 = keras.layers.Dense(500, activation="relu")(layer_2)
+        x = keras.layers.Dropout(self._dropout_rate[0])(input_layer_flattened)
+        x = keras.layers.Dense(self._n_units[0], activation=self._activation[0])(x)

Review comment: better to include this block in the loop.

-        layer_3 = keras.layers.Dropout(0.2)(layer_2)
-        layer_3 = keras.layers.Dense(500, activation="relu")(layer_3)
+        for idx in range(1, self._n_layers):
+            x = keras.layers.Dropout(self._dropout_rate[idx])(x)
+            x = keras.layers.Dense(
+                self._n_units[idx], activation=self._activation[idx]
+            )(x)

-        output_layer = keras.layers.Dropout(0.3)(layer_3)
+        output_layer = keras.layers.Dropout(0.3)(x)

Review comment: I would define a parameter called dropout_last, defaulting to 0.3, and document in detail that this dropout layer is applied at the end.

        return input_layer, output_layer
Review comment: I would obligate float always for dropout_rate and say that it should be between 0 and 1.
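Combining the last suggestions (fold the first dense block into the loop, add a dropout_last kwarg defaulting to 0.3, and require float dropout rates in [0, 1]), the body of build_network might end up roughly like the sketch below. The free function build_mlp and the kwarg name dropout_last are illustrative only; they reflect the reviewer's proposal, not existing aeon API:

```python
from tensorflow import keras


def build_mlp(input_shape, n_units, activation, dropout_rate, dropout_last=0.3):
    """Sketch of the reviewed structure: one loop over all dense blocks,
    plus a final dropout controlled by `dropout_last`."""
    assert all(isinstance(r, float) and 0.0 <= r <= 1.0 for r in dropout_rate)

    input_layer = keras.layers.Input(input_shape)
    x = keras.layers.Flatten()(input_layer)  # multivariate input on one axis

    # The first block is no longer special-cased; it sits in the same loop.
    for units, act, rate in zip(n_units, activation, dropout_rate):
        x = keras.layers.Dropout(rate)(x)
        x = keras.layers.Dense(units, activation=act)(x)

    # Final dropout applied once, after the dense stack.
    output_layer = keras.layers.Dropout(dropout_last)(x)
    return input_layer, output_layer


# Main-branch defaults expressed through the new kwargs.
input_layer, output_layer = build_mlp(
    input_shape=(1, 152),
    n_units=[500, 500, 500],
    activation=["relu", "relu", "relu"],
    dropout_rate=[0.1, 0.2, 0.2],
)
```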