From 30f0765f8b0e6b1b74fea263eb282b9955db0ae5 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Mon, 4 Jan 2021 17:27:41 +0800 Subject: [PATCH 01/18] feat(solve): add support for batch numbers in `solve_*` functions --- neurodiffeq/pde.py | 215 ++++++++++++++++++++++++--------------------- 1 file changed, 116 insertions(+), 99 deletions(-) diff --git a/neurodiffeq/pde.py b/neurodiffeq/pde.py index c7c370c..6e6f313 100644 --- a/neurodiffeq/pde.py +++ b/neurodiffeq/pde.py @@ -54,9 +54,8 @@ def _trial_solution_2input(single_net, nets, xs, ys, conditions): def solve2D( pde, condition, xy_min=None, xy_max=None, net=None, train_generator=None, shuffle=True, valid_generator=None, optimizer=None, - criterion=None, additional_loss_term=None, metrics=None, - batch_size=16, - max_epochs=1000, + criterion=None, additional_loss_term=None, n_batches_train=1, n_batches_valid=4, + metrics=None, batch_size=16, max_epochs=1000, monitor=None, return_internal=False, return_best=False ): r"""Train a neural network to solve a PDE with 2 independent variables. @@ -110,6 +109,14 @@ def solve2D( Extra terms to add to the loss function besides the part specified by `criterion`. The input of `additional_loss_term` should be the same as `pde_system`. :type additional_loss_term: callable + :param n_batches_train: + Number of batches to train in every epoch, where batch-size equals ``train_generator.size``. + Defaults to 1. + :type n_batches_train: int, optional + :param n_batches_valid: + Number of batches to validate in every epoch, where batch-size equals ``valid_generator.size``. + Defaults to 4. + :type n_batches_valid: int, optional :param metrics: Metrics to keep track of during training. 
The metrics should be passed as a dictionary where the keys are the names of the metrics, @@ -148,8 +155,8 @@ def solve2D( pde_system=lambda u, x, y: [pde(u, x, y)], conditions=[condition], xy_min=xy_min, xy_max=xy_max, nets=nets, train_generator=train_generator, shuffle=shuffle, valid_generator=valid_generator, - optimizer=optimizer, criterion=criterion, additional_loss_term=additional_loss_term, metrics=metrics, - batch_size=batch_size, + optimizer=optimizer, criterion=criterion, n_batches_train=n_batches_train, n_batches_valid=n_batches_valid, + additional_loss_term=additional_loss_term, metrics=metrics, batch_size=batch_size, max_epochs=max_epochs, monitor=monitor, return_internal=return_internal, return_best=return_best ) @@ -157,104 +164,112 @@ def solve2D( def solve2D_system( pde_system, conditions, xy_min=None, xy_max=None, single_net=None, nets=None, train_generator=None, shuffle=True, valid_generator=None, - optimizer=None, criterion=None, additional_loss_term=None, metrics=None, batch_size=16, - max_epochs=1000, + optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4, + additional_loss_term=None, metrics=None, batch_size=16, max_epochs=1000, monitor=None, return_internal=False, return_best=False ): r"""Train a neural network to solve a PDE with 2 independent variables. - :param pde_system: - The PDE system to solve. - If the PDE is :math:`F_i(u_1, u_2, ..., u_n, x, y) = 0` - where :math:`u_i` is the i-th dependent variable and :math:`x` and :math:`y` are the independent variables, - then `pde_system` should be a function that maps :math:`(u_1, u_2, ..., u_n, x, y)` - to a list where the i-th entry is :math:`F_i(u_1, u_2, ..., u_n, x, y)`. - :type pde_system: callable - :param conditions: - The initial/boundary conditions. - The ith entry of the conditions is the condition that :math:`x_i` should satisfy. - :type conditions: list[`neurodiffeq.conditions.BaseCondition`] - :param xy_min: - The lower bound of 2 dimensions. 
- If we only care about :math:`x \geq x_0` and :math:`y \geq y_0`, - then `xy_min` is `(x_0, y_0)`. - Only needed when train_generator or valid_generator are not specified. - Defaults to None - :type xy_min: tuple[float, float], optional - :param xy_max: - The upper bound of 2 dimensions. - If we only care about :math:`x \leq x_1` and :math:`y \leq y_1`, then `xy_min` is `(x_1, y_1)`. - Only needed when train_generator or valid_generator are not specified. - Defaults to None - :type xy_max: tuple[float, float], optional - :param single_net: - The single neural network used to approximate the solution. - Only one of `single_net` and `nets` should be specified. - Defaults to None - :param single_net: `torch.nn.Module`, optional - :param nets: - The neural networks used to approximate the solution. - Defaults to None. - :type nets: list[`torch.nn.Module`], optional - :param train_generator: - The example generator to generate 1-D training points. - Default to None. - :type train_generator: `neurodiffeq.generators.Generator2D`, optional - :param shuffle: - Whether to shuffle the training examples every epoch. - Defaults to True. - :type shuffle: bool, optional - :param valid_generator: - The example generator to generate 1-D validation points. - Default to None. - :type valid_generator: `neurodiffeq.generators.Generator2D`, optional - :param optimizer: - The optimization method to use for training. - Defaults to None. - :type optimizer: `torch.optim.Optimizer`, optional - :param criterion: - The loss function to use for training. - Defaults to None. - :type criterion: `torch.nn.modules.loss._Loss`, optional - :param additional_loss_term: - Extra terms to add to the loss function besides the part specified by `criterion`. - The input of `additional_loss_term` should be the same as `pde_system`. - :type additional_loss_term: callable - :param metrics: - Metrics to keep track of during training. 
- The metrics should be passed as a dictionary where the keys are the names of the metrics, - and the values are the corresponding function. - The input functions should be the same as `pde_system` and the output should be a numeric value. - The metrics are evaluated on both the training set and validation set. - :type metrics: dict[string, callable] - :param batch_size: - The size of the mini-batch to use. - Defaults to 16. - :type batch_size: int, optional - :param max_epochs: - The maximum number of epochs to train. - Defaults to 1000. - :type max_epochs: int, optional - :param monitor: - The monitor to check the status of nerual network during training. - Defaults to None. - :type monitor: `neurodiffeq.pde.Monitor2D`, optional - :param return_internal: - Whether to return the nets, conditions, training generator, - validation generator, optimizer and loss function. - Defaults to False. - :type return_internal: bool, optional - :param return_best: - Whether to return the nets that achieved the lowest validation loss. - Defaults to False. - :type return_best: bool, optional - :return: - The solution of the PDE. - The history of training loss and validation loss. - Optionally, the nets, conditions, training generator, validation generator, optimizer and loss function. - The solution is a function that has the signature `solution(xs, ys, as_type)`. - :rtype: tuple[`neurodiffeq.pde.Solution`, dict] or tuple[`neurodiffeq.pde.Solution`, dict, dict] - """ + :param pde_system: + The PDE system to solve. + If the PDE is :math:`F_i(u_1, u_2, ..., u_n, x, y) = 0` + where :math:`u_i` is the i-th dependent variable and :math:`x` and :math:`y` are the independent variables, + then `pde_system` should be a function that maps :math:`(u_1, u_2, ..., u_n, x, y)` + to a list where the i-th entry is :math:`F_i(u_1, u_2, ..., u_n, x, y)`. + :type pde_system: callable + :param conditions: + The initial/boundary conditions. 
+ The ith entry of the conditions is the condition that :math:`x_i` should satisfy. + :type conditions: list[`neurodiffeq.conditions.BaseCondition`] + :param xy_min: + The lower bound of 2 dimensions. + If we only care about :math:`x \geq x_0` and :math:`y \geq y_0`, + then `xy_min` is `(x_0, y_0)`. + Only needed when train_generator or valid_generator are not specified. + Defaults to None + :type xy_min: tuple[float, float], optional + :param xy_max: + The upper bound of 2 dimensions. + If we only care about :math:`x \leq x_1` and :math:`y \leq y_1`, then `xy_min` is `(x_1, y_1)`. + Only needed when train_generator or valid_generator are not specified. + Defaults to None + :type xy_max: tuple[float, float], optional + :param single_net: + The single neural network used to approximate the solution. + Only one of `single_net` and `nets` should be specified. + Defaults to None + :param single_net: `torch.nn.Module`, optional + :param nets: + The neural networks used to approximate the solution. + Defaults to None. + :type nets: list[`torch.nn.Module`], optional + :param train_generator: + The example generator to generate 1-D training points. + Default to None. + :type train_generator: `neurodiffeq.generators.Generator2D`, optional + :param shuffle: + Whether to shuffle the training examples every epoch. + Defaults to True. + :type shuffle: bool, optional + :param valid_generator: + The example generator to generate 1-D validation points. + Default to None. + :type valid_generator: `neurodiffeq.generators.Generator2D`, optional + :param optimizer: + The optimization method to use for training. + Defaults to None. + :type optimizer: `torch.optim.Optimizer`, optional + :param criterion: + The loss function to use for training. + Defaults to None. + :type criterion: `torch.nn.modules.loss._Loss`, optional + :param n_batches_train: + Number of batches to train in every epoch, where batch-size equals ``train_generator.size``. + Defaults to 1. 
+ :type n_batches_train: int, optional + :param n_batches_valid: + Number of batches to validate in every epoch, where batch-size equals ``valid_generator.size``. + Defaults to 4. + :type n_batches_valid: int, optional + :param additional_loss_term: + Extra terms to add to the loss function besides the part specified by `criterion`. + The input of `additional_loss_term` should be the same as `pde_system`. + :type additional_loss_term: callable + :param metrics: + Metrics to keep track of during training. + The metrics should be passed as a dictionary where the keys are the names of the metrics, + and the values are the corresponding function. + The input functions should be the same as `pde_system` and the output should be a numeric value. + The metrics are evaluated on both the training set and validation set. + :type metrics: dict[string, callable] + :param batch_size: + The size of the mini-batch to use. + Defaults to 16. + :type batch_size: int, optional + :param max_epochs: + The maximum number of epochs to train. + Defaults to 1000. + :type max_epochs: int, optional + :param monitor: + The monitor to check the status of nerual network during training. + Defaults to None. + :type monitor: `neurodiffeq.pde.Monitor2D`, optional + :param return_internal: + Whether to return the nets, conditions, training generator, + validation generator, optimizer and loss function. + Defaults to False. + :type return_internal: bool, optional + :param return_best: + Whether to return the nets that achieved the lowest validation loss. + Defaults to False. + :type return_best: bool, optional + :return: + The solution of the PDE. + The history of training loss and validation loss. + Optionally, the nets, conditions, training generator, validation generator, optimizer and loss function. + The solution is a function that has the signature `solution(xs, ys, as_type)`. 
+ :rtype: tuple[`neurodiffeq.pde.Solution`, dict] or tuple[`neurodiffeq.pde.Solution`, dict, dict] + """ warnings.warn( "The `solve2D_system` function is deprecated, use a `neurodiffeq.solvers.Solver2D` instance instead", @@ -296,6 +311,8 @@ class CustomSolver2D(Solver2D): valid_generator=valid_generator, optimizer=optimizer, criterion=criterion, + n_batches_train=n_batches_train, + n_batches_valid=n_batches_valid, metrics=metrics, batch_size=batch_size, shuffle=shuffle, From 387992e85d87371d5ac68dee69fd37fa92959804 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Mon, 4 Jan 2021 17:29:32 +0800 Subject: [PATCH 02/18] api(solve): deprecate `batch_size` and `shuffle` in `solve_*` functions --- neurodiffeq/pde.py | 55 +++++++++++++++++++++++++--------------------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/neurodiffeq/pde.py b/neurodiffeq/pde.py index 6e6f313..59979ac 100644 --- a/neurodiffeq/pde.py +++ b/neurodiffeq/pde.py @@ -35,6 +35,7 @@ def _network_output_2input(net, xs, ys, ith_unit): else: return nn_output + # Adjust the output of the neural network with trial solutions # coded into `conditions`. def _trial_solution_2input(single_net, nets, xs, ys, conditions): @@ -53,10 +54,10 @@ def _trial_solution_2input(single_net, nets, xs, ys, conditions): def solve2D( pde, condition, xy_min=None, xy_max=None, - net=None, train_generator=None, shuffle=True, valid_generator=None, optimizer=None, - criterion=None, additional_loss_term=None, n_batches_train=1, n_batches_valid=4, - metrics=None, batch_size=16, max_epochs=1000, - monitor=None, return_internal=False, return_best=False + net=None, train_generator=None, valid_generator=None, optimizer=None, + criterion=None, n_batches_train=1, n_batches_valid=4, + additional_loss_term=None, metrics=None, max_epochs=1000, + monitor=None, return_internal=False, return_best=False, batch_size=None, shuffle=True ): r"""Train a neural network to solve a PDE with 2 independent variables. 
@@ -89,10 +90,6 @@ def solve2D( The example generator to generate 1-D training points. Default to None. :type train_generator: `neurodiffeq.generators.Generator2D`, optional - :param shuffle: - Whether to shuffle the training examples every epoch. - Defaults to True. - :type shuffle: bool, optional :param valid_generator: The example generator to generate 1-D validation points. Default to None. @@ -124,10 +121,6 @@ def solve2D( The input functions should be the same as `pde` and the output should be a numeric value. The metrics are evaluated on both the training set and validation set. :type metrics: dict[string, callable] - :param batch_size: - The size of the mini-batch to use. - Defaults to 16. - :type batch_size: int, optional :param max_epochs: The maximum number of epochs to train. Defaults to 1000. @@ -144,6 +137,15 @@ def solve2D( Whether to return the nets that achieved the lowest validation loss. Defaults to False. :type return_best: bool, optional + :param batch_size: + **[DEPRECATED and IGNORED]** + Each batch will use all samples generated. + Please specify n_batches_train and n_batches_valid instead. + :type batch_size: int + :param shuffle: + **[DEPRECATED and IGNORED]** + Shuffling should be performed by generators. + :type shuffle: bool :return: The solution of the PDE. The history of training loss and validation loss. Optionally, the nets, conditions, training generator, validation generator, optimizer and loss function. 
@@ -163,10 +165,10 @@ def solve2D( def solve2D_system( pde_system, conditions, xy_min=None, xy_max=None, - single_net=None, nets=None, train_generator=None, shuffle=True, valid_generator=None, + single_net=None, nets=None, train_generator=None, valid_generator=None, optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4, - additional_loss_term=None, metrics=None, batch_size=16, max_epochs=1000, - monitor=None, return_internal=False, return_best=False + additional_loss_term=None, metrics=None, max_epochs=1000, + monitor=None, return_internal=False, return_best=False, batch_size=None, shuffle=True, ): r"""Train a neural network to solve a PDE with 2 independent variables. @@ -207,10 +209,6 @@ def solve2D_system( The example generator to generate 1-D training points. Default to None. :type train_generator: `neurodiffeq.generators.Generator2D`, optional - :param shuffle: - Whether to shuffle the training examples every epoch. - Defaults to True. - :type shuffle: bool, optional :param valid_generator: The example generator to generate 1-D validation points. Default to None. @@ -242,10 +240,6 @@ def solve2D_system( The input functions should be the same as `pde_system` and the output should be a numeric value. The metrics are evaluated on both the training set and validation set. :type metrics: dict[string, callable] - :param batch_size: - The size of the mini-batch to use. - Defaults to 16. - :type batch_size: int, optional :param max_epochs: The maximum number of epochs to train. Defaults to 1000. @@ -263,6 +257,15 @@ def solve2D_system( Whether to return the nets that achieved the lowest validation loss. Defaults to False. :type return_best: bool, optional + :param batch_size: + **[DEPRECATED and IGNORED]** + Each batch will use all samples generated. + Please specify n_batches_train and n_batches_valid instead. + :type batch_size: int + :param shuffle: + **[DEPRECATED and IGNORED]** + Shuffling should be performed by generators. 
+ :type shuffle: bool :return: The solution of the PDE. The history of training loss and validation loss. @@ -344,7 +347,7 @@ def make_animation(solution, xs, ts): sol_net = solution(xx, tt, as_type='np') def u_gen(): - for i in range( len(sol_net) ): + for i in range(len(sol_net)): yield sol_net[i] fig, ax = plt.subplots() @@ -352,8 +355,9 @@ def u_gen(): umin, umax = sol_net.min(), sol_net.max() scale = umax - umin - ax.set_ylim(umin-scale*0.1, umax+scale*0.1) + ax.set_ylim(umin - scale * 0.1, umax + scale * 0.1) ax.set_xlim(xs.min(), xs.max()) + def run(data): line.set_data(xs, data) return line, @@ -362,6 +366,7 @@ def run(data): fig, run, u_gen, blit=True, interval=50, repeat=False ) + ############################# arbitraty boundary conditions ############################# # CONSTANTS From 46f170d952a5f8eab369b18bfa078445f2a34b6f Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Mon, 4 Jan 2021 17:34:02 +0800 Subject: [PATCH 03/18] feat(solve): support specification of batch numbers in ode --- neurodiffeq/ode.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/neurodiffeq/ode.py b/neurodiffeq/ode.py index 73327fc..7d8a26a 100644 --- a/neurodiffeq/ode.py +++ b/neurodiffeq/ode.py @@ -36,8 +36,8 @@ def _trial_solution(single_net, nets, ts, conditions): def solve( ode, condition, t_min=None, t_max=None, net=None, train_generator=None, shuffle=True, valid_generator=None, - optimizer=None, criterion=None, additional_loss_term=None, metrics=None, batch_size=16, - max_epochs=1000, + optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4, + additional_loss_term=None, metrics=None, batch_size=16, max_epochs=1000, monitor=None, return_internal=False, return_best=False ): @@ -86,6 +86,14 @@ def solve( The loss function to use for training. Defaults to None. 
:type criterion: `torch.nn.modules.loss._Loss`, optional + :param n_batches_train: + Number of batches to train in every epoch, where batch-size equals ``train_generator.size``. + Defaults to 1. + :type n_batches_train: int, optional + :param n_batches_valid: + Number of batches to validate in every epoch, where batch-size equals ``valid_generator.size``. + Defaults to 4. + :type n_batches_valid: int, optional :param additional_loss_term: Extra terms to add to the loss function besides the part specified by `criterion`. The input of `additional_loss_term` should be the same as `ode`. @@ -128,7 +136,8 @@ def solve( ode_system=lambda x, t: [ode(x, t)], conditions=[condition], t_min=t_min, t_max=t_max, nets=nets, train_generator=train_generator, shuffle=shuffle, valid_generator=valid_generator, - optimizer=optimizer, criterion=criterion, additional_loss_term=additional_loss_term, metrics=metrics, + optimizer=optimizer, criterion=criterion, n_batches_train=n_batches_train, n_batches_valid=n_batches_valid, + additional_loss_term=additional_loss_term, metrics=metrics, batch_size=batch_size, max_epochs=max_epochs, monitor=monitor, return_internal=return_internal, return_best=return_best ) @@ -137,8 +146,9 @@ def solve( def solve_system( ode_system, conditions, t_min, t_max, single_net=None, nets=None, train_generator=None, shuffle=True, valid_generator=None, - optimizer=None, criterion=None, additional_loss_term=None, metrics=None, batch_size=16, - max_epochs=1000, monitor=None, return_internal=False, + optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4, + additional_loss_term=None, metrics=None, batch_size=16, max_epochs=1000, monitor=None, + return_internal=False, return_best=False, ): r"""Train a neural network to solve an ODE. @@ -189,6 +199,14 @@ def solve_system( The loss function to use for training. Defaults to None and sum of square of the output of `ode_system` will be used. 
:type criterion: `torch.nn.modules.loss._Loss`, optional + :param n_batches_train: + Number of batches to train in every epoch, where batch-size equals ``train_generator.size``. + Defaults to 1. + :type n_batches_train: int, optional + :param n_batches_valid: + Number of batches to validate in every epoch, where batch-size equals ``valid_generator.size``. + Defaults to 4. + :type n_batches_valid: int, optional :param additional_loss_term: Extra terms to add to the loss function besides the part specified by `criterion`. The input of `additional_loss_term` should be the same as `ode_system`. @@ -272,6 +290,8 @@ class CustomSolver1D(Solver1D): valid_generator=valid_generator, optimizer=optimizer, criterion=criterion, + n_batches_train=n_batches_train, + n_batches_valid=n_batches_valid, metrics=metrics, batch_size=batch_size, shuffle=shuffle, From 2e3cd3399d8ef7a720b3fed018b3da517c138e02 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Mon, 4 Jan 2021 17:39:18 +0800 Subject: [PATCH 04/18] api(solve): deprecate `shuffle` & `batch_size` for solve* functions --- neurodiffeq/ode.py | 31 +++++++++++------------ neurodiffeq/pde.py | 4 +-- neurodiffeq/pde_spherical.py | 48 +++++++++++++++++++----------------- 3 files changed, 42 insertions(+), 41 deletions(-) diff --git a/neurodiffeq/ode.py b/neurodiffeq/ode.py index 7d8a26a..2f144ae 100644 --- a/neurodiffeq/ode.py +++ b/neurodiffeq/ode.py @@ -35,11 +35,10 @@ def _trial_solution(single_net, nets, ts, conditions): def solve( ode, condition, t_min=None, t_max=None, - net=None, train_generator=None, shuffle=True, valid_generator=None, + net=None, train_generator=None, valid_generator=None, optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4, - additional_loss_term=None, metrics=None, batch_size=16, max_epochs=1000, - monitor=None, return_internal=False, - return_best=False + additional_loss_term=None, metrics=None, max_epochs=1000, + monitor=None, return_internal=False, return_best=False, batch_size=None, 
shuffle=None, ): r"""Train a neural network to solve an ODE. @@ -70,10 +69,6 @@ def solve( The example generator to generate 1-D training points. Default to None. :type train_generator: `neurodiffeq.generators.Generator1D`, optional - :param shuffle: - Whether to shuffle the training examples every epoch. - Defaults to True. - :type shuffle: bool, optional :param valid_generator: The example generator to generate 1-D validation points. Default to None. @@ -105,10 +100,6 @@ def solve( The input functions should be the same as `ode` and the output should be a numeric value. The metrics are evaluated on both the training set and validation set. :type metrics: dict[string, callable] - :param batch_size: - The size of the mini-batch to use. - Defaults to 16. - :type batch_size: int, optional :param max_epochs: The maximum number of epochs to train. Defaults to 1000. @@ -125,6 +116,15 @@ def solve( Whether to return the nets that achieved the lowest validation loss. Defaults to False. :type return_best: bool, optional + :param batch_size: + **[DEPRECATED and IGNORED]** + Each batch will use all samples generated. + Please specify n_batches_train and n_batches_valid instead. + :type batch_size: int + :param shuffle: + **[DEPRECATED and IGNORED]** + Shuffling should be performed by generators. + :type shuffle: bool :return: The solution of the ODE. The history of training loss and validation loss. 
@@ -145,11 +145,10 @@ def solve( def solve_system( ode_system, conditions, t_min, t_max, - single_net=None, nets=None, train_generator=None, shuffle=True, valid_generator=None, + single_net=None, nets=None, train_generator=None, valid_generator=None, optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4, - additional_loss_term=None, metrics=None, batch_size=16, max_epochs=1000, monitor=None, - return_internal=False, - return_best=False, + additional_loss_term=None, metrics=None, max_epochs=1000, monitor=None, + return_internal=False, return_best=False, batch_size=None, shuffle=None, ): r"""Train a neural network to solve an ODE. diff --git a/neurodiffeq/pde.py b/neurodiffeq/pde.py index 59979ac..be8ad11 100644 --- a/neurodiffeq/pde.py +++ b/neurodiffeq/pde.py @@ -57,7 +57,7 @@ def solve2D( net=None, train_generator=None, valid_generator=None, optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4, additional_loss_term=None, metrics=None, max_epochs=1000, - monitor=None, return_internal=False, return_best=False, batch_size=None, shuffle=True + monitor=None, return_internal=False, return_best=False, batch_size=None, shuffle=None, ): r"""Train a neural network to solve a PDE with 2 independent variables. @@ -168,7 +168,7 @@ def solve2D_system( single_net=None, nets=None, train_generator=None, valid_generator=None, optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4, additional_loss_term=None, metrics=None, max_epochs=1000, - monitor=None, return_internal=False, return_best=False, batch_size=None, shuffle=True, + monitor=None, return_internal=False, return_best=False, batch_size=None, shuffle=None, ): r"""Train a neural network to solve a PDE with 2 independent variables. 
diff --git a/neurodiffeq/pde_spherical.py b/neurodiffeq/pde_spherical.py index 5803c05..b5140d7 100644 --- a/neurodiffeq/pde_spherical.py +++ b/neurodiffeq/pde_spherical.py @@ -43,9 +43,9 @@ def solve_spherical( pde, condition, r_min=None, r_max=None, - net=None, train_generator=None, shuffle=True, valid_generator=None, analytic_solution=None, - optimizer=None, criterion=None, batch_size=16, max_epochs=1000, - monitor=None, return_internal=False, return_best=False, harmonics_fn=None, + net=None, train_generator=None, valid_generator=None, analytic_solution=None, + optimizer=None, criterion=None, max_epochs=1000, + monitor=None, return_internal=False, return_best=False, harmonics_fn=None, batch_size=None, shuffle=None, ): r"""[**DEPRECATED**, use SphericalSolver class instead] Train a neural network to solve one PDE with spherical inputs in 3D space. @@ -77,10 +77,6 @@ def solve_spherical( The example generator to generate 3-D validation points. Default to None. :type valid_generator: `neurodiffeq.generators.BaseGenerator`, optional - :param shuffle: - Whether to shuffle the training examples every epoch. - Defaults to True. - :type shuffle: bool, optional :param analytic_solution: Analytic solution to the pde system, used for testing purposes. It should map (``rs``, ``thetas``, ``phis``) to u. @@ -93,10 +89,6 @@ def solve_spherical( The loss function to use for training. Defaults to None. :type criterion: `torch.nn.modules.loss._Loss`, optional - :param batch_size: - The size of the mini-batch to use. - Defaults to 16. - :type batch_size: int, optional :param max_epochs: The maximum number of epochs to train. Defaults to 1000. @@ -125,7 +117,15 @@ def solve_spherical( :rtype: tuple[`neurodiffeq.pde_spherical.SolutionSpherical`, dict] or tuple[`neurodiffeq.pde_spherical.SolutionSpherical`, dict, dict] - + :param batch_size: + **[DEPRECATED and IGNORED]** + Each batch will use all samples generated. 
+ Please specify ``n_batches_train`` and ``n_batches_valid`` instead. + :type batch_size: int + :param shuffle: + **[DEPRECATED and IGNORED]** + Shuffling should be performed by generators. + :type shuffle: bool .. note:: This function is deprecated, use a `SphericalSolver` instead @@ -151,9 +151,9 @@ def solve_spherical( def solve_spherical_system( pde_system, conditions, r_min=None, r_max=None, - nets=None, train_generator=None, shuffle=True, valid_generator=None, analytic_solutions=None, - optimizer=None, criterion=None, batch_size=None, - max_epochs=1000, monitor=None, return_internal=False, return_best=False, harmonics_fn=None + nets=None, train_generator=None, valid_generator=None, analytic_solutions=None, + optimizer=None, criterion=None, max_epochs=1000, monitor=None, return_internal=False, + return_best=False, harmonics_fn=None, batch_size=None, shuffle=None, ): r"""[**DEPRECATED**, use SphericalSolver class instead] Train a neural network to solve a PDE system with spherical inputs in 3D space @@ -190,10 +190,6 @@ def solve_spherical_system( The example generator to generate 3-D validation points. Default to None. :type valid_generator: `neurodiffeq.generators.BaseGenerator`, optional - :param shuffle: - **[DEPRECATED and IGNORED]** Don't use this. - Shuffling should be performed by generators. - :type shuffle: bool, optional :param analytic_solutions: Analytic solution to the pde system, used for testing purposes. It should map (rs, thetas, phis) to a list of [u_1, u_2, ..., u_n]. @@ -206,10 +202,6 @@ def solve_spherical_system( The loss function to use for training. Defaults to None. :type criterion: `torch.nn.modules.loss._Loss`, optional - :param batch_size: - The size of the mini-batch to use. - Defaults to 16. - :type batch_size: int, optional :param max_epochs: The maximum number of epochs to train. Defaults to 1000. 
@@ -238,6 +230,16 @@ def solve_spherical_system( :rtype: tuple[`neurodiffeq.pde_spherical.SolutionSpherical`, dict] or tuple[`neurodiffeq.pde_spherical.SolutionSpherical`, dict, dict] + :param batch_size: + **[DEPRECATED and IGNORED]** + Each batch will use all samples generated. + Please specify n_batches_train and n_batches_valid instead. + :type batch_size: int + :param shuffle: + **[DEPRECATED and IGNORED]** + Shuffling should be performed by generators. + :type shuffle: bool + .. note:: This function is deprecated, use a `SphericalSolver` instead From 8679712598e3a695627a440fb2425b97b61c3f9a Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Mon, 4 Jan 2021 18:00:32 +0800 Subject: [PATCH 05/18] docs(solve): fix docstring typos and styles --- neurodiffeq/ode.py | 12 ++++++++++-- neurodiffeq/pde.py | 12 ++++++++++-- neurodiffeq/pde_spherical.py | 4 ++-- neurodiffeq/solvers.py | 8 ++++---- 4 files changed, 26 insertions(+), 10 deletions(-) diff --git a/neurodiffeq/ode.py b/neurodiffeq/ode.py index 2f144ae..b30630e 100644 --- a/neurodiffeq/ode.py +++ b/neurodiffeq/ode.py @@ -119,7 +119,7 @@ def solve( :param batch_size: **[DEPRECATED and IGNORED]** Each batch will use all samples generated. - Please specify n_batches_train and n_batches_valid instead. + Please specify ``n_batches_train`` and ``n_batches_valid`` instead. :type batch_size: int :param shuffle: **[DEPRECATED and IGNORED]** @@ -130,6 +130,10 @@ def solve( The history of training loss and validation loss. Optionally, the nets, conditions, training generator, validation generator, optimizer and loss function. :rtype: tuple[`neurodiffeq.ode.Solution`, dict] or tuple[`neurodiffeq.ode.Solution`, dict, dict] + + + .. note:: + This function is deprecated, use a ``neurodiffeq.solvers.Solver1D`` instead. """ nets = None if not net else [net] return solve_system( @@ -237,7 +241,7 @@ def solve_system( :param batch_size: **[DEPRECATED and IGNORED]** Each batch will use all samples generated. 
- Please specify n_batches_train and n_batches_valid instead. + Please specify ``n_batches_train`` and ``n_batches_valid`` instead. :type batch_size: int :param shuffle: **[DEPRECATED and IGNORED]** @@ -247,6 +251,10 @@ def solve_system( The solution of the ODE. The history of training loss and validation loss. Optionally, the nets, conditions, training generator, validation generator, optimizer and loss function. :rtype: tuple[`neurodiffeq.ode.Solution`, dict] or tuple[`neurodiffeq.ode.Solution`, dict, dict] + + + .. note:: + This function is deprecated, use a ``neurodiffeq.solvers.Solver1D`` instead. """ warnings.warn( diff --git a/neurodiffeq/pde.py b/neurodiffeq/pde.py index be8ad11..23e696e 100644 --- a/neurodiffeq/pde.py +++ b/neurodiffeq/pde.py @@ -140,7 +140,7 @@ def solve2D( :param batch_size: **[DEPRECATED and IGNORED]** Each batch will use all samples generated. - Please specify n_batches_train and n_batches_valid instead. + Please specify ``n_batches_train`` and ``n_batches_valid`` instead. :type batch_size: int :param shuffle: **[DEPRECATED and IGNORED]** @@ -151,6 +151,10 @@ def solve2D( Optionally, the nets, conditions, training generator, validation generator, optimizer and loss function. The solution is a function that has the signature `solution(xs, ys, as_type)`. :rtype: tuple[`neurodiffeq.pde.Solution`, dict] or tuple[`neurodiffeq.pde.Solution`, dict, dict] + + + .. note:: + This function is deprecated, use a ``neurodiffeq.solvers.Solver2D`` instead. """ nets = None if not net else [net] return solve2D_system( @@ -260,7 +264,7 @@ def solve2D_system( :param batch_size: **[DEPRECATED and IGNORED]** Each batch will use all samples generated. - Please specify n_batches_train and n_batches_valid instead. + Please specify ``n_batches_train`` and ``n_batches_valid`` instead. 
:type batch_size: int :param shuffle: **[DEPRECATED and IGNORED]** @@ -272,6 +276,10 @@ def solve2D_system( Optionally, the nets, conditions, training generator, validation generator, optimizer and loss function. The solution is a function that has the signature `solution(xs, ys, as_type)`. :rtype: tuple[`neurodiffeq.pde.Solution`, dict] or tuple[`neurodiffeq.pde.Solution`, dict, dict] + + + .. note:: + This function is deprecated, use a ``neurodiffeq.solvers.Solver2D`` instead. """ warnings.warn( diff --git a/neurodiffeq/pde_spherical.py b/neurodiffeq/pde_spherical.py index b5140d7..2c264ff 100644 --- a/neurodiffeq/pde_spherical.py +++ b/neurodiffeq/pde_spherical.py @@ -128,7 +128,7 @@ def solve_spherical( :type shuffle: bool .. note:: - This function is deprecated, use a `SphericalSolver` instead + This function is deprecated, use a ``neurodiffeq.solvers.SolverSpherical`` instead """ warnings.warn("solve_spherical is deprecated, consider using SphericalSolver instead") @@ -242,7 +242,7 @@ def solve_spherical_system( .. note:: - This function is deprecated, use a `SphericalSolver` instead + This function is deprecated, use a ``neurodiffeq.solvers.SolverSpherical`` instead """ warnings.warn("solve_spherical_system is deprecated, consider using SphericalSolver instead") diff --git a/neurodiffeq/solvers.py b/neurodiffeq/solvers.py index 363e8e4..c3028eb 100644 --- a/neurodiffeq/solvers.py +++ b/neurodiffeq/solvers.py @@ -74,7 +74,7 @@ class BaseSolver(ABC): :param batch_size: **[DEPRECATED and IGNORED]** Each batch will use all samples generated. - Please specify n_batches_train and n_batches_valid instead.
+ Please specify ``n_batches_train`` and ``n_batches_valid`` instead. :type batch_size: int :param shuffle: **[DEPRECATED and IGNORED]** @@ -845,7 +845,7 @@ class Solver1D(BaseSolver): :param batch_size: **[DEPRECATED and IGNORED]** Each batch will use all samples generated. - Please specify n_batches_train and n_batches_valid instead. + Please specify ``n_batches_train`` and ``n_batches_valid`` instead. :type batch_size: int :param shuffle: **[DEPRECATED and IGNORED]** @@ -1008,7 +1008,7 @@ class Solver2D(BaseSolver): :param batch_size: **[DEPRECATED and IGNORED]** Each batch will use all samples generated. - Please specify n_batches_train and n_batches_valid instead. + Please specify ``n_batches_train`` and ``n_batches_valid`` instead. :type batch_size: int :param shuffle: **[DEPRECATED and IGNORED]** From 1dd4c8a46f5496885cd870c0bbd9627426752040 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Tue, 5 Jan 2021 00:39:15 +0800 Subject: [PATCH 06/18] refactor(shuffle): set default `shuffle` to None --- neurodiffeq/solvers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/neurodiffeq/solvers.py b/neurodiffeq/solvers.py index c3028eb..7848c5b 100644 --- a/neurodiffeq/solvers.py +++ b/neurodiffeq/solvers.py @@ -87,7 +87,7 @@ def __init__(self, diff_eqs, conditions, optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_input_units=None, n_output_units=None, # deprecated arguments are listed below - shuffle=False, batch_size=None): + shuffle=None, batch_size=None): # deprecate argument `shuffle` if shuffle: warnings.warn( @@ -624,7 +624,7 @@ def __init__(self, pde_system, conditions, r_min=None, r_max=None, optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4, enforcer=None, n_output_units=1, # deprecated arguments are listed below - shuffle=False, batch_size=None): + shuffle=None, batch_size=None): if train_generator is None or valid_generator is None: if r_min is None or r_max is None: @@ -857,7 
+857,7 @@ def __init__(self, ode_system, conditions, t_min, t_max, nets=None, train_generator=None, valid_generator=None, analytic_solutions=None, optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1, # deprecated arguments are listed below - batch_size=None, shuffle=True): + batch_size=None, shuffle=None): if train_generator is None or valid_generator is None: if t_min is None or t_max is None: @@ -1020,7 +1020,7 @@ def __init__(self, pde_system, conditions, xy_min, xy_max, nets=None, train_generator=None, valid_generator=None, analytic_solutions=None, optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1, # deprecated arguments are listed below - batch_size=None, shuffle=True): + batch_size=None, shuffle=None): if train_generator is None or valid_generator is None: if xy_min is None or xy_max is None: From b1ec9d566880a8f2645f0a91cc90564b504a7eeb Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Tue, 5 Jan 2021 00:41:15 +0800 Subject: [PATCH 07/18] refactor(solver): use `diff_eqs` instead of `pdes` in `get_internals()` --- neurodiffeq/solvers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neurodiffeq/solvers.py b/neurodiffeq/solvers.py index 7848c5b..5dd0df2 100644 --- a/neurodiffeq/solvers.py +++ b/neurodiffeq/solvers.py @@ -430,7 +430,7 @@ def _get_internal_variables(self): "n_funcs": self.n_funcs, "nets": self.nets, "optimizer": self.optimizer, - "pdes": self.diff_eqs, + "diff_eqs": self.diff_eqs, "generator": self.generator, } From c456fed040356329d4af6d8b5bcb8da09cc37d1d Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Tue, 5 Jan 2021 00:42:35 +0800 Subject: [PATCH 08/18] refactor(animation): update calling of `BaseSolution` --- neurodiffeq/pde.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neurodiffeq/pde.py b/neurodiffeq/pde.py index 23e696e..0ea72de 100644 --- a/neurodiffeq/pde.py +++ b/neurodiffeq/pde.py @@ -352,7 
+352,7 @@ def make_animation(solution, xs, ts): """ xx, tt = np.meshgrid(xs, ts) - sol_net = solution(xx, tt, as_type='np') + sol_net = solution(xx, tt, to_numpy=True) def u_gen(): for i in range(len(sol_net)): From 7d8c76779b82d6574eec12a578084972452543e1 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Tue, 5 Jan 2021 00:43:44 +0800 Subject: [PATCH 09/18] refactor(callback): move MonitorCallback to a standalone `callbacks.py` --- neurodiffeq/callbacks.py | 53 ++++++++++++++++++++++++++++++++++++ neurodiffeq/pde_spherical.py | 51 ---------------------------------- 2 files changed, 53 insertions(+), 51 deletions(-) create mode 100644 neurodiffeq/callbacks.py diff --git a/neurodiffeq/callbacks.py b/neurodiffeq/callbacks.py new file mode 100644 index 0000000..87a9d02 --- /dev/null +++ b/neurodiffeq/callbacks.py @@ -0,0 +1,53 @@ +import os + + +class MonitorCallback: + """A callback for updating the monitor plots (and optionally saving the fig to disk). + + :param monitor: The underlying monitor responsible for plotting solutions. + :type monitor: `neurodiffeq.monitors.BaseMonitor` + :param fig_dir: Directory for saving monitor figs; if not specified, figs will not be saved. + :type fig_dir: str + :param check_against: Which epoch count to check against; either 'local' (default) or 'global'. + :type check_against: str + :param repaint_last: Whether to update the plot on the last local epoch, defaults to True. 
+ :type repaint_last: bool + """ + + def __init__(self, monitor, fig_dir=None, check_against='local', repaint_last=True): + self.monitor = monitor + self.fig_dir = fig_dir + self.repaint_last = repaint_last + if check_against not in ['local', 'global']: + raise ValueError(f'unknown check_against type = {check_against}') + self.check_against = check_against + + def to_repaint(self, solver): + if self.check_against == 'local': + epoch_now = solver.local_epoch + 1 + elif self.check_against == 'global': + epoch_now = solver.global_epoch + 1 + else: + raise ValueError(f'unknown check_against type = {self.check_against}') + + if epoch_now % self.monitor.check_every == 0: + return True + if self.repaint_last and solver.local_epoch == solver._max_local_epoch - 1: + return True + + return False + + def __call__(self, solver): + if not self.to_repaint(solver): + return + + self.monitor.check( + solver.nets, + solver.conditions, + history=solver.metrics_history, + ) + if self.fig_dir: + pic_path = os.path.join(self.fig_dir, f"epoch-{solver.global_epoch}.png") + self.monitor.fig.savefig(pic_path) + + diff --git a/neurodiffeq/pde_spherical.py b/neurodiffeq/pde_spherical.py index 2c264ff..4643d83 100644 --- a/neurodiffeq/pde_spherical.py +++ b/neurodiffeq/pde_spherical.py @@ -284,57 +284,6 @@ def enforcer(net, cond, points): # callbacks to be passed to SphericalSolver.fit() -class MonitorCallback: - """A callback for updating the monitor plots (and optionally saving the fig to disk). - - :param monitor: The underlying monitor responsible for plotting solutions. - :type monitor: MonitorSpherical - :param fig_dir: Directory for saving monitor figs; if not specified, figs will not be saved. - :type fig_dir: str - :param check_against: Which epoch count to check against; either 'local' (default) or 'global'. - :type check_against: str - :param repaint_last: Whether to update the plot on the last local epoch, defaults to True. 
- :type repaint_last: bool - """ - - def __init__(self, monitor, fig_dir=None, check_against='local', repaint_last=True): - self.monitor = monitor - self.fig_dir = fig_dir - self.repaint_last = repaint_last - if check_against not in ['local', 'global']: - raise ValueError(f'unknown check_against type = {check_against}') - self.check_against = check_against - - def to_repaint(self, solver): - if self.check_against == 'local': - epoch_now = solver.local_epoch + 1 - elif self.check_against == 'global': - epoch_now = solver.global_epoch + 1 - else: - raise ValueError(f'unknown check_against type = {self.check_against}') - - if epoch_now % self.monitor.check_every == 0: - return True - if self.repaint_last and solver.local_epoch == solver._max_local_epoch - 1: - return True - - return False - - def __call__(self, solver): - if not self.to_repaint(solver): - return - - self.monitor.check( - solver.nets, - solver.conditions, - history=solver.loss, - analytic_mse_history=solver.analytic_solutions - ) - if self.fig_dir: - pic_path = os.path.join(self.fig_dir, f"epoch-{solver.global_epoch}.png") - self.monitor.fig.savefig(pic_path) - - class CheckpointCallback: def __init__(self, ckpt_dir): self.ckpt_dir = ckpt_dir From 5f353a357ffb4bb244f3afcaab699520b363d041 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Tue, 5 Jan 2021 01:15:41 +0800 Subject: [PATCH 10/18] docs(get_start); remove deprecated & outdated information --- docs/getstart.ipynb | 1744 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 1467 insertions(+), 277 deletions(-) diff --git a/docs/getstart.ipynb b/docs/getstart.ipynb index 04e45b9..e645c33 100644 --- a/docs/getstart.ipynb +++ b/docs/getstart.ipynb @@ -15,9 +15,7 @@ "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", - "%matplotlib notebook\n", - "import warnings\n", - "warnings.filterwarnings('ignore')" + "%matplotlib notebook" ] }, { @@ -26,7 +24,12 @@ "source": [ "## Solving ODEs\n", "\n", - "ODEs can be solved by 
`neurodiffeq.ode.solve`. \n", + "There are two ways to solve an ODE (or ODE system).\n", + "1. As a legacy option, ODEs can be solved by `neurodiffeq.ode.solve`. \n", + "2. For users who want fine-grained control over the training process, please consider using a `neurodiffeq.solvers.Solver1D`.\n", + "\n", + "- The first option is easier to use but has been **deprecated** and might be removed in a future version. \n", + "- The second option is **recommended** for most users and supports advanced features like custom callbacks, checkpointing, early stopping, gradient clipping, learning rate scheduling, curriculum learning, etc.\n", "\n", "Just for the sake of notation in the following examples, here we see differentiation as an operation, then an ODE can be rewritten as \n", "\n", @@ -37,23 +40,25 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### ODE Example 1: Exponential Decay\n", + "### ODE Example 1: Exponential Decay (using the legacy `solve` function)\n", + "\n", + "__To show how simple__ `neurodiffeq` __can be, we'll first introduce the legacy option with the__ `solve` __function from__ `neurodiffeq.ode`__.__\n", "\n", "Start by solving \n", "\n", - "$$\\frac{dx}{dt} = -x.$$ \n", + "$$\\frac{du}{dt} = -u.$$ \n", "\n", - "for $x(t)$ with $x(0) = 1.0$. The analytical solution is \n", + "for $u(t)$ with $u(0) = 1.0$. The analytical solution is \n", "\n", "$$\n", - "x = e^{-t}.\n", + "u = e^{-t}.\n", "$$\n", "\n", "For `neurodiffeq.ode.solve` to solve this ODE, the following parameters needs to be specified:\n", "\n", - "* `ode`: a function representing the ODE to be solved. It should be a function that maps $(x, t)$ to $F(x, t)$. Here we are solving $$F(x, t)=\\dfrac{dx}{dt} + x=0,$$ then `ode` should be `lambda x, t: diff(x, t) - x`. `diff(x, t)` is the first order derivative of x with respect to t.\n", + "* `ode`: a function representing the ODE to be solved. It should be a function that maps $(u, t)$ to $F(u, t)$. 
Here we are solving $$F(u, t)=\\dfrac{du}{dt} + u=0,$$ then `ode` should be `lambda u, t: diff(u, t) + u`, where `diff(u, t)` is the first order derivative of `u` with respect to `t`.\n", "\n", - "* `condition`: a `neurodiffeq.ode.Condition` instance representing the initial condition / boundary condition of the ODE. Here we use `IVP(t_0=0.0, x_0=1.0)` to ensure $x(0) = 1.0$.\n", + "* `condition`: a `neurodiffeq.conditions.BaseCondition` instance representing the initial condition / boundary condition of the ODE. Here we use `neurodiffeq.conditions.IVP(t_0=0.0, u_0=1.0)` to ensure $u(0) = 1.0$.\n", "\n", "* `t_min` and `t_max`: the domain of $t$ to solve the ODE on." ] @@ -75,10 +80,19 @@ "metadata": { "scrolled": false }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/liushuheng/Documents/GitHub/neurodiffeq/neurodiffeq/ode.py:260: FutureWarning: The `solve_system` function is deprecated, use a `neurodiffeq.solvers.Solver1D` instance instead\n", + " warnings.warn(\n" + ] + } + ], "source": [ - "exponential = lambda x, t: diff(x, t) + x # specify the ODE\n", - "init_val_ex = IVP(t_0=0.0, x_0=1.0) # specify the initial conditon\n", + "exponential = lambda u, t: diff(u, t) + u # specify the ODE\n", + "init_val_ex = IVP(t_0=0.0, u_0=1.0) # specify the initial condition\n", "\n", "# solve the ODE\n", "solution_ex, loss_ex = solve(\n", @@ -90,7 +104,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "`solve` returns a tuple, where the first entry is the solution as a function and the second entry is the loss history. The solution is a function that maps $t$ to $x$. It accepts `numpy.array` as input as well. The default return type of the solution is `torch.tensor`. If we wanted to return `numpy.array`, we can specify `as_type='np'`. the loss history is a dictionary, where the 'train_loss' entry is the training loss and the 'valid_loss' entry is the validation loss. 
Here we compare the ANN-based solution with the analytical solution:" + "(Oops, we have a warning. As we have explained, the `solve` function still works but is deprecated. Hence we have the warning message, which we'll ignore for now.)\n", + "\n", + "`solve` returns a tuple, where the first entry is the solution (as a function) and the second entry is the history (of loss and other metrics) of training and validation. The solution is a function that maps $t$ to $u$. It accepts `numpy.array` or `torch.Tensor` as its input. The default return type of the solution is `torch.tensor`. If we wanted to return `numpy.array`, we can specify `to_numpy=True`. The history is a dictionary, where the 'train_loss' entry is the training loss and the 'valid_loss' entry is the validation loss. Here we compare the ANN-based solution with the analytical solution:" ] }, { @@ -1032,7 +1048,7 @@ { "data": { "text/html": [ - "" + "
" ], "text/plain": [ "" @@ -1044,13 +1060,13 @@ ], "source": [ "ts = np.linspace(0, 2.0, 100)\n", - "x_net = solution_ex(ts, as_type='np')\n", - "x_ana = np.exp(-ts)\n", + "u_net = solution_ex(ts, to_numpy=True)\n", + "u_ana = np.exp(-ts)\n", "\n", "plt.figure()\n", - "plt.plot(ts, x_net, label='ANN-based solution')\n", - "plt.plot(ts, x_ana, label='analytical solution')\n", - "plt.ylabel('x')\n", + "plt.plot(ts, u_net, label='ANN-based solution')\n", + "plt.plot(ts, u_ana, label='analytical solution')\n", + "plt.ylabel('u')\n", "plt.xlabel('t')\n", "plt.title('comparing solutions')\n", "plt.legend()\n", @@ -1996,7 +2012,7 @@ { "data": { "text/html": [ - "" + "
" ], "text/plain": [ "" @@ -2020,12 +2036,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We may want to see the check the solution and the loss function during solving the problem (training the network). To do this, we need to pass a `neurodiffeq.ode.Monitor` object to `solve`. A `Monitor` has the following parameters:\n", + "We may want to see the check the solution and the loss function during solving the problem (training the network). To do this, we need to pass a `neurodiffeq.monitors.Monitor1D` object to `solve`. A `Monitor1D` has the following parameters:\n", "\n", "* `t_min` and `t_max`: the region of $t$ we want to monitor\n", "* `check_every`: the frequency of visualization. If `check_every=100`, then the monitor will visualize the solution every 100 epochs.\n", "\n", - "`%matplotlib notebook` should be executed to allow `Monitor` to work. Here we solve the above ODE again." + "`%matplotlib notebook` should be executed to allow `Monitor1D` to work. Here we solve the above ODE again." 
] }, { @@ -2034,7 +2050,7 @@ "metadata": {}, "outputs": [], "source": [ - "from neurodiffeq.ode import Monitor" + "from neurodiffeq.monitors import Monitor1D" ] }, { @@ -2976,7 +2992,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -2984,16 +3000,41 @@ }, "metadata": {}, "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/liushuheng/Documents/GitHub/neurodiffeq/neurodiffeq/ode.py:260: FutureWarning: The `solve_system` function is deprecated, use a `neurodiffeq.solvers.Solver1D` instance instead\n", + " warnings.warn(\n", + "/Users/liushuheng/Documents/GitHub/neurodiffeq/neurodiffeq/solvers.py:359: UserWarning: Passing `monitor` is deprecated, use a MonitorCallback and pass a list of callbacks instead\n", + " warnings.warn(\"Passing `monitor` is deprecated, \"\n" + ] } ], "source": [ - "%matplotlib notebook\n", + "# This must be executed for Jupyter Notebook environments\n", + "# If you are using Jupyter Lab, try `%matplotlib widget`\n", + "# Don't use `%matplotlib inline`!\n", + "\n", + "%matplotlib notebook \n", + "\n", "solution_ex, _ = solve(\n", " ode=exponential, condition=init_val_ex, t_min=0.0, t_max=2.0, \n", - " monitor=Monitor(t_min=0.0, t_max=2.0, check_every=100)\n", + " monitor=Monitor1D(t_min=0.0, t_max=2.0, check_every=100)\n", ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we have two warnings. But **don't worry**, the training process is not affected.\n", + "\n", + "- The first one warns that we should use a `neurodiffeq.solvers.Solver1D` instance, which we have discussed before.\n", + "- The second warning is slightly different. It says we should use a callback instead of using a `monitor`. Remember we said using a `neurodiffeq.solvers.Solver1D` allows flexible callbacks? This warning is also caused by using the deprecated `solve` function." 
+ ] + }, { "cell_type": "markdown", "metadata": {}, @@ -3003,23 +3044,38 @@ "Here we solve a damped harmonic oscillator: \n", "\n", "$$\n", - "F(x, t) = \\frac{d^2x}{dt^2} + x = 0\n", + "F(u, t) = \\frac{d^2u}{dt^2} + u = 0\n", "$$\n", "\n", "for\n", "\n", "$$\n", - "x(0) = 0.0, \\frac{dx}{dt}|_{t=0} = 1.0\n", + "u(0) = 0.0, \\frac{du}{dt}|_{t=0} = 1.0\n", "$$\n", "\n", "The analytical solution is \n", "\n", - "$$x = \\sin(t)$$\n", + "$$u = \\sin(t)$$\n", "\n", "We can include higher order derivatives in our ODE with the `order` keyword of `diff`, which is defaulted to 1.\n", "\n", - "Initial condition on $\\dfrac{dx}{dt}$ can be specified with the `x_0_prime` keyword of `IVP`. \n", - "\n", + "Initial condition on $\\dfrac{du}{dt}$ can be specified with the `u_0_prime` keyword of `IVP`. " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "harmonic_oscillator = lambda u, t: diff(u, t, order=2) + u\n", + "init_val_ho = IVP(t_0=0.0, u_0=0.0, u_0_prime=1.0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ "Here we will use another keyword for `solve`:\n", "\n", "* `max_epochs`: the number of epochs to run" @@ -3027,8 +3083,10 @@ }, { "cell_type": "code", - "execution_count": 8, - "metadata": {}, + "execution_count": 9, + "metadata": { + "scrolled": false + }, "outputs": [ { "data": { @@ -3964,7 +4022,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -3972,22 +4030,36 @@ }, "metadata": {}, "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/liushuheng/Documents/GitHub/neurodiffeq/neurodiffeq/ode.py:260: FutureWarning: The `solve_system` function is deprecated, use a `neurodiffeq.solvers.Solver1D` instance instead\n", + " warnings.warn(\n", + "/Users/liushuheng/Documents/GitHub/neurodiffeq/neurodiffeq/solvers.py:359: UserWarning: Passing `monitor` is deprecated, use a MonitorCallback and pass a list of callbacks 
instead\n", + " warnings.warn(\"Passing `monitor` is deprecated, \"\n" + ] } ], "source": [ - "harmonic_oscillator = lambda x, t: diff(x, t, order=2) + x\n", - "init_val_ho = IVP(t_0=0.0, x_0=0.0, x_0_prime=1.0)\n", - "\n", "solution_ho, _ = solve(\n", " ode=harmonic_oscillator, condition=init_val_ho, t_min=0.0, t_max=2*np.pi, \n", " max_epochs=3000,\n", - " monitor=Monitor(t_min=0.0, t_max=2*np.pi, check_every=100)\n", + " monitor=Monitor1D(t_min=0.0, t_max=2*np.pi, check_every=100)\n", ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is the third time we see these warnings. I promise we'll learn to get rid of them by the end of this chapter :)" + ] + }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -4924,7 +4996,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -4935,87 +5007,1078 @@ } ], "source": [ - "ts = np.linspace(0, 2*np.pi, 100)\n", - "x_net = solution_ho(ts, as_type='np')\n", - "x_ana = np.sin(ts)\n", + "ts = np.linspace(0, 2*np.pi, 100)\n", + "u_net = solution_ho(ts, to_numpy=True)\n", + "u_ana = np.sin(ts)\n", + "\n", + "plt.figure()\n", + "plt.plot(ts, u_net, label='ANN-based solution')\n", + "plt.plot(ts, u_ana, label='analytical solution')\n", + "plt.ylabel('u')\n", + "plt.xlabel('t')\n", + "plt.title('comparing solutions')\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Solving Systems of ODEs\n", + "\n", + "Systems of ODEs can be solved by `neurodiffeq.ode.solve_system`. \n", + "\n", + "Again, just for the sake of notation in the following examples, here we see differentiation as an operation, and see each element $u_i$ of $\\vec{u}$ as different dependent variables, then an ODE system above can be rewritten as\n", + "\n", + "$$\n", + "\\begin{pmatrix} \n", + "F_0(u_0, u_1, \\ldots, u_{m-1}, t) \\\\\n", + "F_1(u_0, u_1, \\ldots, u_{m-1}, t) \\\\\n", + "\\vdots \\\\\n", + "F_{m-1}(u_0, u_1, \\ldots, u_{m-1}, t)\n", + "\\end{pmatrix}\n", + "= \n", + "\\begin{pmatrix} \n", + "0 \\\\\n", + "0 \\\\\n", + "\\vdots \\\\\n", + "0\n", + "\\end{pmatrix}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Systems of ODE Example 1: Harmonic Oscillator\n", + "\n", + "For the harmonic oscillator example above, if we let $u_1 = u$ and $u_2 = \\dfrac{du}{dt}$. We can rewrite this ODE into a system of ODE:\n", + "\n", + "$$\\begin{align}\n", + "u_1^{'} - u_2 &= 0, \\\\\n", + "u_2^{'} + u_1 &= 0, \\\\\n", + "u_1(0) &= 0, \\\\\n", + "u_2(0) &= 1.\n", + "\\end{align}$$\n", + "\n", + "Here the analytical solution is \n", + "$$\\begin{align}\n", + "u_1 &= \\sin(t), \\\\\n", + "u_2 &= \\cos(t).\n", + "\\end{align}$$\n", + "\n", + "The `solve_system` function is for solving ODE systems. 
The signature is almost the same as `solve` except that we specify an `ode_system` and a set of `conditions`. \n", + "\n", + "* `ode_system`: a function representing the system of ODEs to be solved. If our system of ODEs is $f_i(u_0, u_1, ..., u_{m-1}, t) = 0$ for $i = 0, 1, ..., n-1$ where $u_0, u_1, ..., u_{m-1}$ are dependent variables and $t$ is the independent variable, then `ode_system` should map $(u_0, u_1, ..., u_{m-1}, t)$ to a $n$-element list where the $i^{th}$ element is the value of $f_i(u_0, u_1, ..., u_{m-1}, t)$.\n", + "\n", + "* `conditions`: the initial value/boundary conditions as a list of Condition instances. They should be in an order such that the first condition constrains the first variable in $f_i$'s (see above) signature ($u_0$). The second condition constrains the second ($u_1$), and so on." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from neurodiffeq.ode import solve_system" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "application/javascript": [ + "/* Put everything inside the global mpl namespace */\n", + "/* global mpl */\n", + "window.mpl = {};\n", + "\n", + "mpl.get_websocket_type = function () {\n", + " if (typeof WebSocket !== 'undefined') {\n", + " return WebSocket;\n", + " } else if (typeof MozWebSocket !== 'undefined') {\n", + " return MozWebSocket;\n", + " } else {\n", + " alert(\n", + " 'Your browser does not have WebSocket support. ' +\n", + " 'Please try Chrome, Safari or Firefox ≥ 6. 
' +\n", + " 'Firefox 4 and 5 are also supported but you ' +\n", + " 'have to enable WebSockets in about:config.'\n", + " );\n", + " }\n", + "};\n", + "\n", + "mpl.figure = function (figure_id, websocket, ondownload, parent_element) {\n", + " this.id = figure_id;\n", + "\n", + " this.ws = websocket;\n", + "\n", + " this.supports_binary = this.ws.binaryType !== undefined;\n", + "\n", + " if (!this.supports_binary) {\n", + " var warnings = document.getElementById('mpl-warnings');\n", + " if (warnings) {\n", + " warnings.style.display = 'block';\n", + " warnings.textContent =\n", + " 'This browser does not support binary websocket messages. ' +\n", + " 'Performance may be slow.';\n", + " }\n", + " }\n", + "\n", + " this.imageObj = new Image();\n", + "\n", + " this.context = undefined;\n", + " this.message = undefined;\n", + " this.canvas = undefined;\n", + " this.rubberband_canvas = undefined;\n", + " this.rubberband_context = undefined;\n", + " this.format_dropdown = undefined;\n", + "\n", + " this.image_mode = 'full';\n", + "\n", + " this.root = document.createElement('div');\n", + " this.root.setAttribute('style', 'display: inline-block');\n", + " this._root_extra_style(this.root);\n", + "\n", + " parent_element.appendChild(this.root);\n", + "\n", + " this._init_header(this);\n", + " this._init_canvas(this);\n", + " this._init_toolbar(this);\n", + "\n", + " var fig = this;\n", + "\n", + " this.waiting = false;\n", + "\n", + " this.ws.onopen = function () {\n", + " fig.send_message('supports_binary', { value: fig.supports_binary });\n", + " fig.send_message('send_image_mode', {});\n", + " if (mpl.ratio !== 1) {\n", + " fig.send_message('set_dpi_ratio', { dpi_ratio: mpl.ratio });\n", + " }\n", + " fig.send_message('refresh', {});\n", + " };\n", + "\n", + " this.imageObj.onload = function () {\n", + " if (fig.image_mode === 'full') {\n", + " // Full images could contain transparency (where diff images\n", + " // almost always do), so we need to clear the canvas so 
that\n", + " // there is no ghosting.\n", + " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", + " }\n", + " fig.context.drawImage(fig.imageObj, 0, 0);\n", + " };\n", + "\n", + " this.imageObj.onunload = function () {\n", + " fig.ws.close();\n", + " };\n", + "\n", + " this.ws.onmessage = this._make_on_message_function(this);\n", + "\n", + " this.ondownload = ondownload;\n", + "};\n", + "\n", + "mpl.figure.prototype._init_header = function () {\n", + " var titlebar = document.createElement('div');\n", + " titlebar.classList =\n", + " 'ui-dialog-titlebar ui-widget-header ui-corner-all ui-helper-clearfix';\n", + " var titletext = document.createElement('div');\n", + " titletext.classList = 'ui-dialog-title';\n", + " titletext.setAttribute(\n", + " 'style',\n", + " 'width: 100%; text-align: center; padding: 3px;'\n", + " );\n", + " titlebar.appendChild(titletext);\n", + " this.root.appendChild(titlebar);\n", + " this.header = titletext;\n", + "};\n", + "\n", + "mpl.figure.prototype._canvas_extra_style = function (_canvas_div) {};\n", + "\n", + "mpl.figure.prototype._root_extra_style = function (_canvas_div) {};\n", + "\n", + "mpl.figure.prototype._init_canvas = function () {\n", + " var fig = this;\n", + "\n", + " var canvas_div = (this.canvas_div = document.createElement('div'));\n", + " canvas_div.setAttribute(\n", + " 'style',\n", + " 'border: 1px solid #ddd;' +\n", + " 'box-sizing: content-box;' +\n", + " 'clear: both;' +\n", + " 'min-height: 1px;' +\n", + " 'min-width: 1px;' +\n", + " 'outline: 0;' +\n", + " 'overflow: hidden;' +\n", + " 'position: relative;' +\n", + " 'resize: both;'\n", + " );\n", + "\n", + " function on_keyboard_event_closure(name) {\n", + " return function (event) {\n", + " return fig.key_event(event, name);\n", + " };\n", + " }\n", + "\n", + " canvas_div.addEventListener(\n", + " 'keydown',\n", + " on_keyboard_event_closure('key_press')\n", + " );\n", + " canvas_div.addEventListener(\n", + " 'keyup',\n", + " 
on_keyboard_event_closure('key_release')\n", + " );\n", + "\n", + " this._canvas_extra_style(canvas_div);\n", + " this.root.appendChild(canvas_div);\n", + "\n", + " var canvas = (this.canvas = document.createElement('canvas'));\n", + " canvas.classList.add('mpl-canvas');\n", + " canvas.setAttribute('style', 'box-sizing: content-box;');\n", + "\n", + " this.context = canvas.getContext('2d');\n", + "\n", + " var backingStore =\n", + " this.context.backingStorePixelRatio ||\n", + " this.context.webkitBackingStorePixelRatio ||\n", + " this.context.mozBackingStorePixelRatio ||\n", + " this.context.msBackingStorePixelRatio ||\n", + " this.context.oBackingStorePixelRatio ||\n", + " this.context.backingStorePixelRatio ||\n", + " 1;\n", + "\n", + " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", + "\n", + " var rubberband_canvas = (this.rubberband_canvas = document.createElement(\n", + " 'canvas'\n", + " ));\n", + " rubberband_canvas.setAttribute(\n", + " 'style',\n", + " 'box-sizing: content-box; position: absolute; left: 0; top: 0; z-index: 1;'\n", + " );\n", + "\n", + " var resizeObserver = new ResizeObserver(function (entries) {\n", + " var nentries = entries.length;\n", + " for (var i = 0; i < nentries; i++) {\n", + " var entry = entries[i];\n", + " var width, height;\n", + " if (entry.contentBoxSize) {\n", + " if (entry.contentBoxSize instanceof Array) {\n", + " // Chrome 84 implements new version of spec.\n", + " width = entry.contentBoxSize[0].inlineSize;\n", + " height = entry.contentBoxSize[0].blockSize;\n", + " } else {\n", + " // Firefox implements old version of spec.\n", + " width = entry.contentBoxSize.inlineSize;\n", + " height = entry.contentBoxSize.blockSize;\n", + " }\n", + " } else {\n", + " // Chrome <84 implements even older version of spec.\n", + " width = entry.contentRect.width;\n", + " height = entry.contentRect.height;\n", + " }\n", + "\n", + " // Keep the size of the canvas and rubber band canvas in sync with\n", + " // the canvas 
container.\n", + " if (entry.devicePixelContentBoxSize) {\n", + " // Chrome 84 implements new version of spec.\n", + " canvas.setAttribute(\n", + " 'width',\n", + " entry.devicePixelContentBoxSize[0].inlineSize\n", + " );\n", + " canvas.setAttribute(\n", + " 'height',\n", + " entry.devicePixelContentBoxSize[0].blockSize\n", + " );\n", + " } else {\n", + " canvas.setAttribute('width', width * mpl.ratio);\n", + " canvas.setAttribute('height', height * mpl.ratio);\n", + " }\n", + " canvas.setAttribute(\n", + " 'style',\n", + " 'width: ' + width + 'px; height: ' + height + 'px;'\n", + " );\n", + "\n", + " rubberband_canvas.setAttribute('width', width);\n", + " rubberband_canvas.setAttribute('height', height);\n", + "\n", + " // And update the size in Python. We ignore the initial 0/0 size\n", + " // that occurs as the element is placed into the DOM, which should\n", + " // otherwise not happen due to the minimum size styling.\n", + " if (width != 0 && height != 0) {\n", + " fig.request_resize(width, height);\n", + " }\n", + " }\n", + " });\n", + " resizeObserver.observe(canvas_div);\n", + "\n", + " function on_mouse_event_closure(name) {\n", + " return function (event) {\n", + " return fig.mouse_event(event, name);\n", + " };\n", + " }\n", + "\n", + " rubberband_canvas.addEventListener(\n", + " 'mousedown',\n", + " on_mouse_event_closure('button_press')\n", + " );\n", + " rubberband_canvas.addEventListener(\n", + " 'mouseup',\n", + " on_mouse_event_closure('button_release')\n", + " );\n", + " // Throttle sequential mouse events to 1 every 20ms.\n", + " rubberband_canvas.addEventListener(\n", + " 'mousemove',\n", + " on_mouse_event_closure('motion_notify')\n", + " );\n", + "\n", + " rubberband_canvas.addEventListener(\n", + " 'mouseenter',\n", + " on_mouse_event_closure('figure_enter')\n", + " );\n", + " rubberband_canvas.addEventListener(\n", + " 'mouseleave',\n", + " on_mouse_event_closure('figure_leave')\n", + " );\n", + "\n", + " canvas_div.addEventListener('wheel', 
function (event) {\n", + " if (event.deltaY < 0) {\n", + " event.step = 1;\n", + " } else {\n", + " event.step = -1;\n", + " }\n", + " on_mouse_event_closure('scroll')(event);\n", + " });\n", + "\n", + " canvas_div.appendChild(canvas);\n", + " canvas_div.appendChild(rubberband_canvas);\n", + "\n", + " this.rubberband_context = rubberband_canvas.getContext('2d');\n", + " this.rubberband_context.strokeStyle = '#000000';\n", + "\n", + " this._resize_canvas = function (width, height, forward) {\n", + " if (forward) {\n", + " canvas_div.style.width = width + 'px';\n", + " canvas_div.style.height = height + 'px';\n", + " }\n", + " };\n", + "\n", + " // Disable right mouse context menu.\n", + " this.rubberband_canvas.addEventListener('contextmenu', function (_e) {\n", + " event.preventDefault();\n", + " return false;\n", + " });\n", + "\n", + " function set_focus() {\n", + " canvas.focus();\n", + " canvas_div.focus();\n", + " }\n", + "\n", + " window.setTimeout(set_focus, 100);\n", + "};\n", + "\n", + "mpl.figure.prototype._init_toolbar = function () {\n", + " var fig = this;\n", + "\n", + " var toolbar = document.createElement('div');\n", + " toolbar.classList = 'mpl-toolbar';\n", + " this.root.appendChild(toolbar);\n", + "\n", + " function on_click_closure(name) {\n", + " return function (_event) {\n", + " return fig.toolbar_button_onclick(name);\n", + " };\n", + " }\n", + "\n", + " function on_mouseover_closure(tooltip) {\n", + " return function (event) {\n", + " if (!event.currentTarget.disabled) {\n", + " return fig.toolbar_button_onmouseover(tooltip);\n", + " }\n", + " };\n", + " }\n", + "\n", + " fig.buttons = {};\n", + " var buttonGroup = document.createElement('div');\n", + " buttonGroup.classList = 'mpl-button-group';\n", + " for (var toolbar_ind in mpl.toolbar_items) {\n", + " var name = mpl.toolbar_items[toolbar_ind][0];\n", + " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", + " var image = mpl.toolbar_items[toolbar_ind][2];\n", + " var method_name = 
mpl.toolbar_items[toolbar_ind][3];\n", + "\n", + " if (!name) {\n", + " /* Instead of a spacer, we start a new button group. */\n", + " if (buttonGroup.hasChildNodes()) {\n", + " toolbar.appendChild(buttonGroup);\n", + " }\n", + " buttonGroup = document.createElement('div');\n", + " buttonGroup.classList = 'mpl-button-group';\n", + " continue;\n", + " }\n", + "\n", + " var button = (fig.buttons[name] = document.createElement('button'));\n", + " button.classList = 'mpl-widget';\n", + " button.setAttribute('role', 'button');\n", + " button.setAttribute('aria-disabled', 'false');\n", + " button.addEventListener('click', on_click_closure(method_name));\n", + " button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n", + "\n", + " var icon_img = document.createElement('img');\n", + " icon_img.src = '_images/' + image + '.png';\n", + " icon_img.srcset = '_images/' + image + '_large.png 2x';\n", + " icon_img.alt = tooltip;\n", + " button.appendChild(icon_img);\n", + "\n", + " buttonGroup.appendChild(button);\n", + " }\n", + "\n", + " if (buttonGroup.hasChildNodes()) {\n", + " toolbar.appendChild(buttonGroup);\n", + " }\n", + "\n", + " var fmt_picker = document.createElement('select');\n", + " fmt_picker.classList = 'mpl-widget';\n", + " toolbar.appendChild(fmt_picker);\n", + " this.format_dropdown = fmt_picker;\n", + "\n", + " for (var ind in mpl.extensions) {\n", + " var fmt = mpl.extensions[ind];\n", + " var option = document.createElement('option');\n", + " option.selected = fmt === mpl.default_extension;\n", + " option.innerHTML = fmt;\n", + " fmt_picker.appendChild(option);\n", + " }\n", + "\n", + " var status_bar = document.createElement('span');\n", + " status_bar.classList = 'mpl-message';\n", + " toolbar.appendChild(status_bar);\n", + " this.message = status_bar;\n", + "};\n", + "\n", + "mpl.figure.prototype.request_resize = function (x_pixels, y_pixels) {\n", + " // Request matplotlib to resize the figure. 
Matplotlib will then trigger a resize in the client,\n", + " // which will in turn request a refresh of the image.\n", + " this.send_message('resize', { width: x_pixels, height: y_pixels });\n", + "};\n", + "\n", + "mpl.figure.prototype.send_message = function (type, properties) {\n", + " properties['type'] = type;\n", + " properties['figure_id'] = this.id;\n", + " this.ws.send(JSON.stringify(properties));\n", + "};\n", + "\n", + "mpl.figure.prototype.send_draw_message = function () {\n", + " if (!this.waiting) {\n", + " this.waiting = true;\n", + " this.ws.send(JSON.stringify({ type: 'draw', figure_id: this.id }));\n", + " }\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_save = function (fig, _msg) {\n", + " var format_dropdown = fig.format_dropdown;\n", + " var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n", + " fig.ondownload(fig, format);\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_resize = function (fig, msg) {\n", + " var size = msg['size'];\n", + " if (size[0] !== fig.canvas.width || size[1] !== fig.canvas.height) {\n", + " fig._resize_canvas(size[0], size[1], msg['forward']);\n", + " fig.send_message('refresh', {});\n", + " }\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_rubberband = function (fig, msg) {\n", + " var x0 = msg['x0'] / mpl.ratio;\n", + " var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n", + " var x1 = msg['x1'] / mpl.ratio;\n", + " var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n", + " x0 = Math.floor(x0) + 0.5;\n", + " y0 = Math.floor(y0) + 0.5;\n", + " x1 = Math.floor(x1) + 0.5;\n", + " y1 = Math.floor(y1) + 0.5;\n", + " var min_x = Math.min(x0, x1);\n", + " var min_y = Math.min(y0, y1);\n", + " var width = Math.abs(x1 - x0);\n", + " var height = Math.abs(y1 - y0);\n", + "\n", + " fig.rubberband_context.clearRect(\n", + " 0,\n", + " 0,\n", + " fig.canvas.width / mpl.ratio,\n", + " fig.canvas.height / mpl.ratio\n", + " );\n", + "\n", + " fig.rubberband_context.strokeRect(min_x, 
min_y, width, height);\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_figure_label = function (fig, msg) {\n", + " // Updates the figure title.\n", + " fig.header.textContent = msg['label'];\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_cursor = function (fig, msg) {\n", + " var cursor = msg['cursor'];\n", + " switch (cursor) {\n", + " case 0:\n", + " cursor = 'pointer';\n", + " break;\n", + " case 1:\n", + " cursor = 'default';\n", + " break;\n", + " case 2:\n", + " cursor = 'crosshair';\n", + " break;\n", + " case 3:\n", + " cursor = 'move';\n", + " break;\n", + " }\n", + " fig.rubberband_canvas.style.cursor = cursor;\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_message = function (fig, msg) {\n", + " fig.message.textContent = msg['message'];\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_draw = function (fig, _msg) {\n", + " // Request the server to send over a new figure.\n", + " fig.send_draw_message();\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_image_mode = function (fig, msg) {\n", + " fig.image_mode = msg['mode'];\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_history_buttons = function (fig, msg) {\n", + " for (var key in msg) {\n", + " if (!(key in fig.buttons)) {\n", + " continue;\n", + " }\n", + " fig.buttons[key].disabled = !msg[key];\n", + " fig.buttons[key].setAttribute('aria-disabled', !msg[key]);\n", + " }\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_navigate_mode = function (fig, msg) {\n", + " if (msg['mode'] === 'PAN') {\n", + " fig.buttons['Pan'].classList.add('active');\n", + " fig.buttons['Zoom'].classList.remove('active');\n", + " } else if (msg['mode'] === 'ZOOM') {\n", + " fig.buttons['Pan'].classList.remove('active');\n", + " fig.buttons['Zoom'].classList.add('active');\n", + " } else {\n", + " fig.buttons['Pan'].classList.remove('active');\n", + " fig.buttons['Zoom'].classList.remove('active');\n", + " }\n", + "};\n", + "\n", + "mpl.figure.prototype.updated_canvas_event = function () 
{\n", + " // Called whenever the canvas gets updated.\n", + " this.send_message('ack', {});\n", + "};\n", + "\n", + "// A function to construct a web socket function for onmessage handling.\n", + "// Called in the figure constructor.\n", + "mpl.figure.prototype._make_on_message_function = function (fig) {\n", + " return function socket_on_message(evt) {\n", + " if (evt.data instanceof Blob) {\n", + " /* FIXME: We get \"Resource interpreted as Image but\n", + " * transferred with MIME type text/plain:\" errors on\n", + " * Chrome. But how to set the MIME type? It doesn't seem\n", + " * to be part of the websocket stream */\n", + " evt.data.type = 'image/png';\n", + "\n", + " /* Free the memory for the previous frames */\n", + " if (fig.imageObj.src) {\n", + " (window.URL || window.webkitURL).revokeObjectURL(\n", + " fig.imageObj.src\n", + " );\n", + " }\n", + "\n", + " fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n", + " evt.data\n", + " );\n", + " fig.updated_canvas_event();\n", + " fig.waiting = false;\n", + " return;\n", + " } else if (\n", + " typeof evt.data === 'string' &&\n", + " evt.data.slice(0, 21) === 'data:image/png;base64'\n", + " ) {\n", + " fig.imageObj.src = evt.data;\n", + " fig.updated_canvas_event();\n", + " fig.waiting = false;\n", + " return;\n", + " }\n", + "\n", + " var msg = JSON.parse(evt.data);\n", + " var msg_type = msg['type'];\n", + "\n", + " // Call the \"handle_{type}\" callback, which takes\n", + " // the figure and JSON message as its only arguments.\n", + " try {\n", + " var callback = fig['handle_' + msg_type];\n", + " } catch (e) {\n", + " console.log(\n", + " \"No handler for the '\" + msg_type + \"' message type: \",\n", + " msg\n", + " );\n", + " return;\n", + " }\n", + "\n", + " if (callback) {\n", + " try {\n", + " // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n", + " callback(fig, msg);\n", + " } catch (e) {\n", + " console.log(\n", + " \"Exception inside the 'handler_\" + msg_type 
+ \"' callback:\",\n", + " e,\n", + " e.stack,\n", + " msg\n", + " );\n", + " }\n", + " }\n", + " };\n", + "};\n", + "\n", + "// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n", + "mpl.findpos = function (e) {\n", + " //this section is from http://www.quirksmode.org/js/events_properties.html\n", + " var targ;\n", + " if (!e) {\n", + " e = window.event;\n", + " }\n", + " if (e.target) {\n", + " targ = e.target;\n", + " } else if (e.srcElement) {\n", + " targ = e.srcElement;\n", + " }\n", + " if (targ.nodeType === 3) {\n", + " // defeat Safari bug\n", + " targ = targ.parentNode;\n", + " }\n", + "\n", + " // pageX,Y are the mouse positions relative to the document\n", + " var boundingRect = targ.getBoundingClientRect();\n", + " var x = e.pageX - (boundingRect.left + document.body.scrollLeft);\n", + " var y = e.pageY - (boundingRect.top + document.body.scrollTop);\n", + "\n", + " return { x: x, y: y };\n", + "};\n", + "\n", + "/*\n", + " * return a copy of an object with only non-object keys\n", + " * we need this to avoid circular references\n", + " * http://stackoverflow.com/a/24161582/3208463\n", + " */\n", + "function simpleKeys(original) {\n", + " return Object.keys(original).reduce(function (obj, key) {\n", + " if (typeof original[key] !== 'object') {\n", + " obj[key] = original[key];\n", + " }\n", + " return obj;\n", + " }, {});\n", + "}\n", + "\n", + "mpl.figure.prototype.mouse_event = function (event, name) {\n", + " var canvas_pos = mpl.findpos(event);\n", + "\n", + " if (name === 'button_press') {\n", + " this.canvas.focus();\n", + " this.canvas_div.focus();\n", + " }\n", + "\n", + " var x = canvas_pos.x * mpl.ratio;\n", + " var y = canvas_pos.y * mpl.ratio;\n", + "\n", + " this.send_message(name, {\n", + " x: x,\n", + " y: y,\n", + " button: event.button,\n", + " step: event.step,\n", + " guiEvent: simpleKeys(event),\n", + " });\n", + "\n", + " /* This prevents the web browser from automatically changing to\n", + " * the 
text insertion cursor when the button is pressed. We want\n", + " * to control all of the cursor setting manually through the\n", + " * 'cursor' event from matplotlib */\n", + " event.preventDefault();\n", + " return false;\n", + "};\n", + "\n", + "mpl.figure.prototype._key_event_extra = function (_event, _name) {\n", + " // Handle any extra behaviour associated with a key event\n", + "};\n", + "\n", + "mpl.figure.prototype.key_event = function (event, name) {\n", + " // Prevent repeat events\n", + " if (name === 'key_press') {\n", + " if (event.which === this._key) {\n", + " return;\n", + " } else {\n", + " this._key = event.which;\n", + " }\n", + " }\n", + " if (name === 'key_release') {\n", + " this._key = null;\n", + " }\n", + "\n", + " var value = '';\n", + " if (event.ctrlKey && event.which !== 17) {\n", + " value += 'ctrl+';\n", + " }\n", + " if (event.altKey && event.which !== 18) {\n", + " value += 'alt+';\n", + " }\n", + " if (event.shiftKey && event.which !== 16) {\n", + " value += 'shift+';\n", + " }\n", + "\n", + " value += 'k';\n", + " value += event.which.toString();\n", + "\n", + " this._key_event_extra(event, name);\n", + "\n", + " this.send_message(name, { key: value, guiEvent: simpleKeys(event) });\n", + " return false;\n", + "};\n", + "\n", + "mpl.figure.prototype.toolbar_button_onclick = function (name) {\n", + " if (name === 'download') {\n", + " this.handle_save(this, null);\n", + " } else {\n", + " this.send_message('toolbar_button', { name: name });\n", + " }\n", + "};\n", + "\n", + "mpl.figure.prototype.toolbar_button_onmouseover = function (tooltip) {\n", + " this.message.textContent = tooltip;\n", + "};\n", + "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Left button pans, 
Right button zooms\\nx/y fixes axis, CTRL fixes aspect\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\\nx/y fixes axis, CTRL fixes aspect\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n", + "\n", + "mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n", + "\n", + "mpl.default_extension = \"png\";/* global mpl */\n", + "\n", + "var comm_websocket_adapter = function (comm) {\n", + " // Create a \"websocket\"-like object which calls the given IPython comm\n", + " // object with the appropriate methods. Currently this is a non binary\n", + " // socket, so there is still some room for performance tuning.\n", + " var ws = {};\n", + "\n", + " ws.close = function () {\n", + " comm.close();\n", + " };\n", + " ws.send = function (m) {\n", + " //console.log('sending', m);\n", + " comm.send(m);\n", + " };\n", + " // Register the callback with on_msg.\n", + " comm.on_msg(function (msg) {\n", + " //console.log('receiving', msg['content']['data'], msg);\n", + " // Pass the mpl event to the overridden (by mpl) onmessage function.\n", + " ws.onmessage(msg['content']['data']);\n", + " });\n", + " return ws;\n", + "};\n", + "\n", + "mpl.mpl_figure_comm = function (comm, msg) {\n", + " // This is the function which gets called when the mpl process\n", + " // starts-up an IPython Comm through the \"matplotlib\" channel.\n", + "\n", + " var id = msg.content.data.id;\n", + " // Get hold of the div created by the display call when the Comm\n", + " // socket was opened in Python.\n", + " var element = document.getElementById(id);\n", + " var ws_proxy = comm_websocket_adapter(comm);\n", + "\n", + " function ondownload(figure, _format) {\n", + " window.open(figure.canvas.toDataURL());\n", + " }\n", + "\n", + " var fig = new mpl.figure(id, ws_proxy, ondownload, element);\n", + "\n", + " // Call onopen now - mpl needs 
it, as it is assuming we've passed it a real\n", + " // web socket which is closed, not our websocket->open comm proxy.\n", + " ws_proxy.onopen();\n", + "\n", + " fig.parent_element = element;\n", + " fig.cell_info = mpl.find_output_cell(\"
\");\n", + " if (!fig.cell_info) {\n", + " console.error('Failed to find cell for figure', id, fig);\n", + " return;\n", + " }\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_close = function (fig, msg) {\n", + " var width = fig.canvas.width / mpl.ratio;\n", + " fig.root.removeEventListener('remove', this._remove_fig_handler);\n", + "\n", + " // Update the output cell to use the data from the current canvas.\n", + " fig.push_to_output();\n", + " var dataURL = fig.canvas.toDataURL();\n", + " // Re-enable the keyboard manager in IPython - without this line, in FF,\n", + " // the notebook keyboard shortcuts fail.\n", + " IPython.keyboard_manager.enable();\n", + " fig.parent_element.innerHTML =\n", + " '';\n", + " fig.close_ws(fig, msg);\n", + "};\n", + "\n", + "mpl.figure.prototype.close_ws = function (fig, msg) {\n", + " fig.send_message('closing', msg);\n", + " // fig.ws.close()\n", + "};\n", + "\n", + "mpl.figure.prototype.push_to_output = function (_remove_interactive) {\n", + " // Turn the data on the canvas into data in the output cell.\n", + " var width = this.canvas.width / mpl.ratio;\n", + " var dataURL = this.canvas.toDataURL();\n", + " this.cell_info[1]['text/html'] =\n", + " '';\n", + "};\n", + "\n", + "mpl.figure.prototype.updated_canvas_event = function () {\n", + " // Tell IPython that the notebook contents must change.\n", + " IPython.notebook.set_dirty(true);\n", + " this.send_message('ack', {});\n", + " var fig = this;\n", + " // Wait a second, then push the new image to the DOM so\n", + " // that it is saved nicely (might be nice to debounce this).\n", + " setTimeout(function () {\n", + " fig.push_to_output();\n", + " }, 1000);\n", + "};\n", + "\n", + "mpl.figure.prototype._init_toolbar = function () {\n", + " var fig = this;\n", + "\n", + " var toolbar = document.createElement('div');\n", + " toolbar.classList = 'btn-toolbar';\n", + " this.root.appendChild(toolbar);\n", + "\n", + " function on_click_closure(name) {\n", + " return function 
(_event) {\n", + " return fig.toolbar_button_onclick(name);\n", + " };\n", + " }\n", + "\n", + " function on_mouseover_closure(tooltip) {\n", + " return function (event) {\n", + " if (!event.currentTarget.disabled) {\n", + " return fig.toolbar_button_onmouseover(tooltip);\n", + " }\n", + " };\n", + " }\n", + "\n", + " fig.buttons = {};\n", + " var buttonGroup = document.createElement('div');\n", + " buttonGroup.classList = 'btn-group';\n", + " var button;\n", + " for (var toolbar_ind in mpl.toolbar_items) {\n", + " var name = mpl.toolbar_items[toolbar_ind][0];\n", + " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", + " var image = mpl.toolbar_items[toolbar_ind][2];\n", + " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", + "\n", + " if (!name) {\n", + " /* Instead of a spacer, we start a new button group. */\n", + " if (buttonGroup.hasChildNodes()) {\n", + " toolbar.appendChild(buttonGroup);\n", + " }\n", + " buttonGroup = document.createElement('div');\n", + " buttonGroup.classList = 'btn-group';\n", + " continue;\n", + " }\n", + "\n", + " button = fig.buttons[name] = document.createElement('button');\n", + " button.classList = 'btn btn-default';\n", + " button.href = '#';\n", + " button.title = name;\n", + " button.innerHTML = '';\n", + " button.addEventListener('click', on_click_closure(method_name));\n", + " button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n", + " buttonGroup.appendChild(button);\n", + " }\n", + "\n", + " if (buttonGroup.hasChildNodes()) {\n", + " toolbar.appendChild(buttonGroup);\n", + " }\n", + "\n", + " // Add the status bar.\n", + " var status_bar = document.createElement('span');\n", + " status_bar.classList = 'mpl-message pull-right';\n", + " toolbar.appendChild(status_bar);\n", + " this.message = status_bar;\n", + "\n", + " // Add the close button to the window.\n", + " var buttongrp = document.createElement('div');\n", + " buttongrp.classList = 'btn-group inline pull-right';\n", + " button = 
document.createElement('button');\n", + " button.classList = 'btn btn-mini btn-primary';\n", + " button.href = '#';\n", + " button.title = 'Stop Interaction';\n", + " button.innerHTML = '';\n", + " button.addEventListener('click', function (_evt) {\n", + " fig.handle_close(fig, {});\n", + " });\n", + " button.addEventListener(\n", + " 'mouseover',\n", + " on_mouseover_closure('Stop Interaction')\n", + " );\n", + " buttongrp.appendChild(button);\n", + " var titlebar = this.root.querySelector('.ui-dialog-titlebar');\n", + " titlebar.insertBefore(buttongrp, titlebar.firstChild);\n", + "};\n", + "\n", + "mpl.figure.prototype._remove_fig_handler = function () {\n", + " this.close_ws(this, {});\n", + "};\n", + "\n", + "mpl.figure.prototype._root_extra_style = function (el) {\n", + " el.style.boxSizing = 'content-box'; // override notebook setting of border-box.\n", + " el.addEventListener('remove', this._remove_fig_handler);\n", + "};\n", + "\n", + "mpl.figure.prototype._canvas_extra_style = function (el) {\n", + " // this is important to make the div 'focusable\n", + " el.setAttribute('tabindex', 0);\n", + " // reach out to IPython and tell the keyboard manager to turn it's self\n", + " // off when our div gets focus\n", + "\n", + " // location in version 3\n", + " if (IPython.notebook.keyboard_manager) {\n", + " IPython.notebook.keyboard_manager.register_events(el);\n", + " } else {\n", + " // location in version 2\n", + " IPython.keyboard_manager.register_events(el);\n", + " }\n", + "};\n", + "\n", + "mpl.figure.prototype._key_event_extra = function (event, _name) {\n", + " var manager = IPython.notebook.keyboard_manager;\n", + " if (!manager) {\n", + " manager = IPython.keyboard_manager;\n", + " }\n", + "\n", + " // Check for shift+enter\n", + " if (event.shiftKey && event.which === 13) {\n", + " this.canvas_div.blur();\n", + " // select the cell after this one\n", + " var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n", + " 
IPython.notebook.select(index + 1);\n", + " }\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_save = function (fig, _msg) {\n", + " fig.ondownload(fig, null);\n", + "};\n", + "\n", + "mpl.find_output_cell = function (html_output) {\n", + " // Return the cell and output element which can be found *uniquely* in the notebook.\n", + " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", + " // IPython event is triggered only after the cells have been serialised, which for\n", + " // our purposes (turning an active figure into a static one), is too late.\n", + " var cells = IPython.notebook.get_cells();\n", + " var ncells = cells.length;\n", + " for (var i = 0; i < ncells; i++) {\n", + " var cell = cells[i];\n", + " if (cell.cell_type === 'code') {\n", + " for (var j = 0; j < cell.output_area.outputs.length; j++) {\n", + " var data = cell.output_area.outputs[j];\n", + " if (data.data) {\n", + " // IPython >= 3 moved mimebundle to data attribute of output\n", + " data = data.data;\n", + " }\n", + " if (data['text/html'] === html_output) {\n", + " return [cell, data, j];\n", + " }\n", + " }\n", + " }\n", + " }\n", + "};\n", + "\n", + "// Register the function which deals with the matplotlib target/channel.\n", + "// The kernel may be null if the page has been refreshed.\n", + "if (IPython.notebook.kernel !== null) {\n", + " IPython.notebook.kernel.comm_manager.register_target(\n", + " 'matplotlib',\n", + " mpl.mpl_figure_comm\n", + " );\n", + "}\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/liushuheng/Documents/GitHub/neurodiffeq/neurodiffeq/ode.py:260: FutureWarning: The `solve_system` function is deprecated, use a `neurodiffeq.solvers.Solver1D` instance instead\n", + " 
warnings.warn(\n", + "/Users/liushuheng/Documents/GitHub/neurodiffeq/neurodiffeq/solvers.py:359: UserWarning: Passing `monitor` is deprecated, use a MonitorCallback and pass a list of callbacks instead\n", + " warnings.warn(\"Passing `monitor` is deprecated, \"\n" + ] + } + ], + "source": [ + "# specify the ODE system\n", + "parametric_circle = lambda u1, u2, t : [diff(u1, t) - u2, \n", + " diff(u2, t) + u1]\n", + "# specify the initial conditions\n", + "init_vals_pc = [\n", + " IVP(t_0=0.0, u_0=0.0),\n", + " IVP(t_0=0.0, u_0=1.0)\n", + "]\n", "\n", - "plt.figure()\n", - "plt.plot(ts, x_net, label='ANN-based solution')\n", - "plt.plot(ts, x_ana, label='analytical solution')\n", - "plt.ylabel('x')\n", - "plt.xlabel('t')\n", - "plt.title('comparing solutions')\n", - "plt.legend()\n", - "plt.show()" + "# solve the ODE system\n", + "solution_pc, _ = solve_system(\n", + " ode_system=parametric_circle, conditions=init_vals_pc, t_min=0.0, t_max=2*np.pi, \n", + " max_epochs=5000,\n", + " monitor=Monitor1D(t_min=0.0, t_max=2*np.pi, check_every=100)\n", + ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Solving Systems of ODEs\n", - "\n", - "Systems of ODEs can be solved by `neurodiffeq.ode.solve_system`. \n", - "\n", - "Again, just for the sake of notation in the following examples, here we see differentiation as an operation, and see each element $x_i$of $\\vec{x}$ as different dependent vairables, then a ODE system above can be rewritten as\n", - "\n", - "$$\n", - "\\begin{pmatrix} \n", - "F_0(x_0, x_1, \\ldots, x_{m-1}, t) \\\\\n", - "F_1(x_0, x_1, \\ldots, x_{m-1}, t) \\\\\n", - "\\vdots \\\\\n", - "F_{m-1}(x_0, x_1, \\ldots, x_{m-1}, t)\n", - "\\end{pmatrix}\n", - "= \n", - "\\begin{pmatrix} \n", - "0 \\\\\n", - "0 \\\\\n", - "\\vdots \\\\\n", - "0\n", - "\\end{pmatrix}\n", - "$$" + "`solve_system` returns a tuple, where the first entry is the solution as a function and the second entry is the loss history as a list. 
The solution is a function that maps $t$ to $[u_0, u_1, ..., u_{m-1}]$. It accepts `numpy.array` or `torch.Tensor` as its input. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Systems of ODE Example 1: Harmonic Oscilator\n", - "\n", - "For the harmonic oscillator example above, if we let $x_1 = x$ and $x_2 = \\dfrac{dx}{dt}$. We can rewrite this ODE into a system of ODE:\n", - "\n", - "$$\\begin{align}\n", - "x_1^{'} - x_2 &= 0, \\\\\n", - "x_2^{'} + x_1 &= 0, \\\\\n", - "x_1(0) &= 0, \\\\\n", - "x_2(0) &= 1.\n", - "\\end{align}$$\n", - "\n", - "Here the analytical solution is \n", - "$$\\begin{align}\n", - "x_1 &= \\sin(t), \\\\\n", - "x_2 &= \\cos(t).\n", - "\\end{align}$$\n", - "\n", - "The `solve_system` function is for solving ODE systems. The signature is almost the same as `solve` except that we specify an `ode_system` and a set of `conditions`. \n", - "\n", - "* `ode_system`: a function representing the system of ODEs to be solved. If the our system of ODEs is $f_i(x_0, x_1, ..., x_{m-1}, t) = 0$ for $i = 0, 1, ..., n-1$ where $x_0, x_1, ..., x_{m-1}$ are dependent variables and $t$ is the independent variable, then `ode_system` should map $(x_0, x_1, ..., x_{m-1}, t)$ to a $n$-element list where the $i^{th}$ element is the value of $f_i(x_0, x_1, ..., x_{m-1}, t)$.\n", - "\n", - "* `conditions`: the initial value/boundary conditions as a list of Condition instance. They should be in an order such that the first condition constraints the first variable in $f_i$'s (see above) signature ($x_0$). The second condition constraints the second ($x_1$), and so on." 
- ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "from neurodiffeq.ode import solve_system" + "Here we compare the ANN-based solution with the analytical solution:" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -5952,7 +7015,7 @@ { "data": { "text/html": [ - "" + "
" ], "text/plain": [ "" @@ -5963,33 +7026,65 @@ } ], "source": [ - "# specify the ODE system\n", - "parametric_circle = lambda x1, x2, t : [diff(x1, t) - x2, \n", - " diff(x2, t) + x1]\n", - "# specify the initial conditions\n", - "init_vals_pc = [\n", - " IVP(t_0=0.0, x_0=0.0),\n", - " IVP(t_0=0.0, x_0=1.0)\n", - "]\n", + "ts = np.linspace(0, 2*np.pi, 100)\n", + "u1_net, u2_net = solution_pc(ts, to_numpy=True)\n", + "u1_ana, u2_ana = np.sin(ts), np.cos(ts)\n", "\n", - "# solve the ODE system\n", - "solution_pc, _ = solve_system(\n", - " ode_system=parametric_circle, conditions=init_vals_pc, t_min=0.0, t_max=2*np.pi, \n", - " max_epochs=5000,\n", - " monitor=Monitor(t_min=0.0, t_max=2*np.pi, check_every=100)\n", - ")" + "plt.figure()\n", + "plt.plot(ts, u1_net, label='ANN-based solution of $u_1$')\n", + "plt.plot(ts, u1_ana, label='Analytical solution of $u_1$')\n", + "plt.plot(ts, u2_net, label='ANN-based solution of $u_2$')\n", + "plt.plot(ts, u2_ana, label='Analytical solution of $u_2$')\n", + "plt.ylabel('u')\n", + "plt.xlabel('t')\n", + "plt.title('comparing solutions')\n", + "plt.legend()\n", + "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "`solve_system` returns a tuple, where the first entry is the solution as a function and the second entry is the loss history as a list. The solution is a function that maps $t$ to $[x_0, x_1, ..., x_{m-1}]$. It accepts `numpy.array` as input as well. Here we compare the ANN-based solution with the analytical solution:" + "### Systems of ODE Example 2: Lotka–Volterra equations\n", + "\n", + "The previous examples are rather simple because they are both linear ODE systems. We have numerous existing numerical methods that solve these linear ODEs very well. 
To show the capability of neurodiffeq, let's see this example of *nonlinear* ODEs.\n", + "\n", + "Lotka–Volterra equations are a pair of nonlinear ODEs frequently used to describe the dynamics of biological systems in which two species interact, one as a predator and the other as prey:\n", + "\n", + "$$\\begin{align}\n", + "\\frac{du}{dt} = \\alpha u - \\beta uv \\\\\n", + "\\frac{dv}{dt} = \\delta uv - \\gamma v\n", + "\\end{align}$$\n", + "\n", + "Let $\\alpha = \\beta = \\delta = \\gamma = 1$. Here we solve this pair of ODEs when $u(0) = 1.5$ and $v(0) = 1.0$.\n", + "\n", + "If not specified otherwise, `solve` and `solve_system` will use a fully-connected network with 1 hidden layer with 32 hidden units (tanh activation) to approximate each dependent variable. In some situations, we may want to use our own neural network. For example, the default neural net is not good at solving a problem where the solution oscillates. However, if we know in advance that the solution oscillates, we can use sin as the activation function, which results in much faster convergence.\n", + "\n", + "`neurodiffeq.FCNN` is a fully connected neural network. It is initialized with the following parameters:\n", + "\n", + "* `hidden_units`: number of units in each hidden layer. If you have 3 hidden layers with 32, 64, and 16 neurons respectively, `hidden_units` should be a tuple `(32, 64, 16)`.\n", + "\n", + "* `actv`: a `torch.nn.Module` *class*. e.g. 
`nn.Tanh`, `nn.Sigmoid`.\n", + "\n", + "Here we will use another keyword for `solve_system`:\n", + "\n", + "* `nets`: a list of networks to be used to approximate each dependent variable" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from neurodiffeq.networks import FCNN # fully-connect neural network\n", + "from neurodiffeq.networks import SinActv # sin activation" + ] + }, + { + "cell_type": "code", + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -6926,7 +8021,7 @@ { "data": { "text/html": [ - "
" + "" ], "text/plain": [ "" @@ -6934,67 +8029,65 @@ }, "metadata": {}, "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/liushuheng/Documents/GitHub/neurodiffeq/neurodiffeq/ode.py:260: FutureWarning: The `solve_system` function is deprecated, use a `neurodiffeq.solvers.Solver1D` instance instead\n", + " warnings.warn(\n", + "/Users/liushuheng/Documents/GitHub/neurodiffeq/neurodiffeq/solvers.py:359: UserWarning: Passing `monitor` is deprecated, use a MonitorCallback and pass a list of callbacks instead\n", + " warnings.warn(\"Passing `monitor` is deprecated, \"\n" + ] } ], "source": [ - "ts = np.linspace(0, 2*np.pi, 100)\n", - "x1_net, x2_net = solution_pc(ts, as_type='np')\n", - "x1_ana, x2_ana = np.sin(ts), np.cos(ts)\n", + "# specify the ODE system and its parameters\n", + "alpha, beta, delta, gamma = 1, 1, 1, 1\n", + "lotka_volterra = lambda u, v, t : [ diff(u, t) - (alpha*u - beta*u*v), \n", + " diff(v, t) - (delta*u*v - gamma*v), ]\n", + "# specify the initial conditions\n", + "init_vals_lv = [\n", + " IVP(t_0=0.0, u_0=1.5), # 1.5 is the value of u at t_0 = 0.0\n", + " IVP(t_0=0.0, u_0=1.0), # 1.0 is the value of v at t_0 = 0.0\n", + "]\n", "\n", - "plt.figure()\n", - "plt.plot(ts, x1_net, label='ANN-based solution of x1')\n", - "plt.plot(ts, x1_ana, label='analytical solution of x1')\n", - "plt.plot(ts, x2_net, label='ANN-based solution of x2')\n", - "plt.plot(ts, x2_ana, label='analytical solution of x2')\n", - "plt.ylabel('x')\n", - "plt.xlabel('t')\n", - "plt.title('comparing solutions')\n", - "plt.legend()\n", - "plt.show()" + "# specify the network to be used to approximate each dependent variable\n", + "# the input units and output units default to 1 for FCNN\n", + "nets_lv = [\n", + " FCNN(n_input_units=1, n_output_units=1, hidden_units=(32, 32), actv=SinActv),\n", + " FCNN(n_input_units=1, n_output_units=1, hidden_units=(32, 32), actv=SinActv)\n", + "]\n", + "\n", + "# solve the ODE system\n", 
+ "solution_lv, _ = solve_system(\n", + " ode_system=lotka_volterra, conditions=init_vals_lv, t_min=0.0, t_max=12, \n", + " nets=nets_lv, max_epochs=3000,\n", + " monitor=Monitor1D(t_min=0.0, t_max=12, check_every=100)\n", + ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Systems of ODE Example 2: Lotka–Volterra equations\n", - "\n", - "Lotka–Volterra equations are a pair of nonlinear ODE frequently used to describe the dynamics of biological systems in which two species interact, one as a predator and the other as prey:\n", - "\n", - "$$\\begin{align}\n", - "\\frac{dx}{dt} = \\alpha x - \\beta xy \\\\\n", - "\\frac{dy}{dt} = \\delta xy - \\gamma y\n", - "\\end{align}$$\n", - "\n", - "Let $\\alpha = \\beta = \\delta = \\gamma = 1$. Here we solve this pair of ODE when $x(0) = 1.5$ and $y(0) = 1.0$.\n", - "\n", - "If not specified otherwise, `solve` and `solve_system` will use a fully-connected network with 1 hidden layer with 32 hidden units (tanh activation) to approximate each dependent variables. In some situations, we may want to use our own neural network. For example, the default neural net is not good at solving a problem where the solution oscillates. However, if we know in advance that the solution oscillates, we can use sin as activation function, which resulted in much faster convergence.\n", - "\n", - "`neurodiffeq.FCNN` is a fully connected neural network. It is initiated by the following parameters:\n", - "\n", - "* `hidden_units`: number of units in each hidden layer. If you have 3 hidden layers with 32, 64, and 16 neurons respectively, `hidden_units` should be a tuple `(32, 64, 16)`.\n", + "### Tired of the annoying warning messages? Let's get rid of them by using a 'Solver'\n", "\n", - "* `actv`: a `torch.nn.Module` *class*. e.g. `nn.Tanh`, `nn.Sigmoid`.\n", + "Now that you are familiar with the usage of `solve`, let's try the second way of solving ODEs -- using a `neurodiffeq.solvers.Solver1D` instance. 
If you are familiar with `sklearn` or `keras`, the workflow with a *Solver* is quite similar.\n", "\n", - "Here we will use another keyword for `solve_system`:\n", + "1. Instantiate a solver. (Specify the ODE/PDE system, initial/boundary conditions, problem domain, etc.)\n", + "2. Fit the solver (Specify number of epochs to train, callbacks in each epoch, monitor, etc.)\n", + "3. Get the solutions and other internal variables.\n", "\n", - "* `nets`: a list of networks to be used to approximate each dependent variable" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "from neurodiffeq.networks import FCNN # fully-connect neural network\n", - "from neurodiffeq.networks import SinActv # sin activation" + "**This is the recommended way of solving ODEs (and PDEs later). Once you learn to use a Solver, you should stick to this way instead of using a** `solve` **function.**" ] }, { "cell_type": "code", - "execution_count": 14, - "metadata": {}, + "execution_count": 16, + "metadata": { + "scrolled": false + }, "outputs": [ { "data": { @@ -7930,7 +9023,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -7941,33 +9034,33 @@ } ], "source": [ - "# specify the ODE system and its parameters\n", - "alpha, beta, delta, gamma = 1, 1, 1, 1\n", - "lotka_volterra = lambda x, y, t : [diff(x, t) - (alpha*x - beta*x*y), \n", - " diff(y, t) - (delta*x*y - gamma*y)]\n", - "# specify the initial conditions\n", - "init_vals_lv = [\n", - " IVP(t_0=0.0, x_0=1.5),\n", - " IVP(t_0=0.0, x_0=1.0)\n", - "]\n", + "from neurodiffeq.solvers import Solver1D\n", + "from neurodiffeq.callbacks import MonitorCallback\n", "\n", - "# specify the network to be used to approximate each dependent variable\n", - "nets_lv = [\n", - " FCNN(hidden_units=(32, 32), actv=SinActv),\n", - " FCNN(hidden_units=(32, 32), actv=SinActv)\n", - "]\n", + "# Let's create a monitor first\n", + "monitor = Monitor1D(t_min=0.0, t_max=12.0, 
check_every=100)\n", + "# ... and turn it into a Callback instance\n", + "monitor_callback = MonitorCallback(monitor)\n", "\n", - "# solve the ODE system\n", - "solution_lv, _ = solve_system(\n", - " ode_system=lotka_volterra, conditions=init_vals_lv, t_min=0.0, t_max=12, \n", - " nets=nets_lv, max_epochs=12000,\n", - " monitor=Monitor(t_min=0.0, t_max=12, check_every=100)\n", - ")" + "# Instantiate a solver instance\n", + "solver = Solver1D(\n", + " ode_system=lotka_volterra,\n", + " conditions=init_vals_lv,\n", + " t_min=0.1,\n", + " t_max=12.0,\n", + " nets=nets_lv,\n", + ")\n", + "\n", + "# Fit the solver (i.e., train the neural networks)\n", + "solver.fit(max_epochs=3000, callbacks=[monitor_callback])\n", + "\n", + "# Get the solution\n", + "solution_lv = solver.get_solution()" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "metadata": { "scrolled": false }, @@ -8906,7 +9999,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -8914,13 +10007,23 @@ }, "metadata": {}, "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "ts = np.linspace(0, 12, 100)\n", "\n", "# ANN-based solution\n", - "prey_net, pred_net = solution_lv(ts, as_type='np')\n", + "prey_net, pred_net = solution_lv(ts, to_numpy=True)\n", "\n", "# numerical solution\n", "from scipy.integrate import odeint\n", @@ -8931,16 +10034,23 @@ "prey_num = Ps[:,0]\n", "pred_num = Ps[:,1]\n", "\n", - "plt.figure()\n", - "plt.plot(ts, prey_net, label='ANN-based solution of prey')\n", - "plt.plot(ts, prey_num, label='numerical solution of prey')\n", - "plt.plot(ts, pred_net, label='ANN-based solution of predator')\n", - "plt.plot(ts, pred_num, label='numerical solution of predator')\n", - "plt.ylabel('population')\n", - "plt.xlabel('t')\n", - "plt.title('comparing solutions')\n", - "plt.legend()\n", - "plt.show()" + "fig = plt.figure(figsize=(12, 5))\n", + "ax1, ax2 = fig.subplots(1, 2)\n", + "ax1.plot(ts, prey_net, label='ANN-based solution of prey')\n", + "ax1.plot(ts, prey_num, label='numerical solution of prey')\n", + "ax1.plot(ts, pred_net, label='ANN-based solution of predator')\n", + "ax1.plot(ts, pred_num, label='numerical solution of predator')\n", + "ax1.set_ylabel('population')\n", + "ax1.set_xlabel('t')\n", + "ax1.set_title('Comparing solutions')\n", + "ax1.legend()\n", + "\n", + "ax2.set_title('Error of ANN solution from numerical solution')\n", + "ax2.plot(ts, prey_net-prey_num, label='error in prey number')\n", + "ax2.plot(ts, pred_net-pred_num, label='error in predator number')\n", + "ax2.set_ylabel('populator')\n", + "ax2.set_xlabel('t')\n", + "ax2.legend()" ] }, { @@ -8949,7 +10059,7 @@ "source": [ "## Solving PDEs\n", "\n", - "PDEs can be solved by `neurodiffeq.pde.solve2D`. 
Currently `neurodiffeq` only support solving PDEs with 2 independent variables.\n", + "Two-dimensional PDEs can be solved by the legacy `neurodiffeq.pde.solve2D` or the more flexible `neurodiffeq.solvers.Solver2D`.\n", "\n", "Aain, just for the sake of notation in the following examples, here we see differentiation as an operation, then an PDE of $u(x, y)$ can be rewritten as: \n", "\n", @@ -8987,24 +10097,32 @@ "\n", "Here we have a Dirichlet boundary condition on both 4 edges of the orthogonal box. We will be using `DirichletBVP2D` for this boundary condition. The arguments `x_min_val`, `x_max_val`, `y_min_val` and `y_max_val` correspond to $u(x, y)\\bigg|_{x=0}$, $u(x, y)\\bigg|_{x=1}$, $u(x, y)\\bigg|_{y=0}$ and $u(x, y)\\bigg|_{y=1}$. Note that they should all be functions of $x$ or $y$. These functions are expected to take in a `torch.tensor` and return a `torch.tensor`, so if the function involves some elementary functions like $\\sin$, we should use `torch.sin` instead of `numpy.sin`.\n", "\n", - "The `solve2D` function is almost the same as `solve` and `solve_system` in the `ode` module. The difference is that we indicate the domain of our problem with `xy_min` and `xy_max`, they are tuples representing the 'lower left' point and 'upper right' point of our domain. Also, we need to use `ExampleGenerator2D` and `Monitor2D` from the `pde` module." + "Like in the ODE case, we have two ways to solve 2-D PDEs.\n", + "\n", + "1. The `neurodiffeq.pde.solve2D` function is almost the same as `solve` and `solve_system` in the `neurodiffeq.ode` module. Again, **this way is deprecated and won't be covered here**.\n", + "2. The `neurodiffeq.solvers.Solver2D` class is almost the same as `neurodiffeq.solvers.Solver1D`.\n", + "\n", + "The difference is that we indicate the domain of our problem with `xy_min` and `xy_max`, they are tuples representing the 'lower left' point and 'upper right' point of our domain. 
\n", + "\n", + "Also, we need to use `neurodiffeq.generators.Generator2D` and `neurodiffeq.monitors.Monitor2D`." ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "from neurodiffeq.conditions import DirichletBVP2D\n", - "from neurodiffeq.pde import solve2D, Monitor2D\n", + "from neurodiffeq.solvers import Solver2D\n", + "from neurodiffeq.monitors import Monitor2D\n", "from neurodiffeq.generators import Generator2D\n", "import torch" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -9941,7 +11059,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -9952,22 +11070,49 @@ } ], "source": [ - "laplace = lambda u, x, y: diff(u, x, order=2) + diff(u, y, order=2)\n", - "bc = DirichletBVP2D(\n", - " x_min=0, x_min_val=lambda y: torch.sin(np.pi*y), \n", - " x_max=1, x_max_val=lambda y: 0, \n", - " y_min=0, y_min_val=lambda x: 0, \n", - " y_max=1, y_max_val=lambda x: 0\n", + "# Define the PDE system\n", + "# There's only one (Laplace) equation in the system, so the function maps (u, x, y) to a single entry\n", + "laplace = lambda u, x, y: [\n", + " diff(u, x, order=2) + diff(u, y, order=2)\n", + "]\n", + "\n", + "# Define the boundary conditions\n", + "# There's only one function to be solved for, so we only have a single condition\n", + "conditions = [\n", + " DirichletBVP2D(\n", + " x_min=0, x_min_val=lambda y: torch.sin(np.pi*y), \n", + " x_max=1, x_max_val=lambda y: 0, \n", + " y_min=0, y_min_val=lambda x: 0, \n", + " y_max=1, y_max_val=lambda x: 0,\n", + " )\n", + "]\n", + "\n", + "# Define the neural network to be used\n", + "# Again, there's only one function to be solved for, so we only have a single network\n", + "nets = [\n", + " FCNN(n_input_units=2, n_output_units=1, hidden_units=[512])\n", + "]\n", + "\n", + "# Define the monitor callback\n", + "monitor=Monitor2D(check_every=10, xy_min=(0, 0), xy_max=(1, 1))\n", + 
"monitor_callback = MonitorCallback(monitor)\n", + "\n", + "# Instantiate the solver \n", + "solver = Solver2D(\n", + " pde_system=laplace,\n", + " conditions=conditions,\n", + " xy_min=(0, 0), # We can omit xy_min when both train_generator and valid_generator are specified\n", + " xy_max=(1, 1), # We can omit xy_max when both train_generator and valid_generator are specified\n", + " nets=nets,\n", + " train_generator=Generator2D((32, 32), (0, 0), (1, 1), method='equally-spaced-noisy'),\n", + " valid_generator=Generator2D((32, 32), (0, 0), (1, 1), method='equally-spaced'),\n", ")\n", - "net = FCNN(n_input_units=2, hidden_units=(32, 32))\n", "\n", - "solution_neural_net_laplace, _ = solve2D(\n", - " pde=laplace, condition=bc, xy_min=(0, 0), xy_max=(1, 1),\n", - " net=net, max_epochs=200, train_generator=Generator2D(\n", - " (32, 32), (0, 0), (1, 1), method='equally-spaced-noisy'\n", - " ),\n", - " monitor=Monitor2D(check_every=10, xy_min=(0, 0), xy_max=(1, 1))\n", - ")" + "# Fit the neural network\n", + "solver.fit(max_epochs=200, callbacks=[monitor_callback])\n", + "\n", + "# Obtain the solution\n", + "solution_neural_net_laplace = solver.get_solution()" ] }, { @@ -9979,7 +11124,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -9991,15 +11136,17 @@ " ax.set_xlabel(x_label)\n", " ax.set_ylabel(y_label)\n", " ax.set_zlabel(z_label)\n", - " ax.set_title(title)\n", + " fig.suptitle(title)\n", " ax.set_proj_type('ortho')\n", " plt.show()" ] }, { "cell_type": "code", - "execution_count": 19, - "metadata": {}, + "execution_count": 21, + "metadata": { + "scrolled": false + }, "outputs": [ { "data": { @@ -10935,7 +12082,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -10948,13 +12095,13 @@ "source": [ "xs, ys = np.linspace(0, 1, 101), np.linspace(0, 1, 101)\n", "xx, yy = np.meshgrid(xs, ys)\n", - "sol_net = solution_neural_net_laplace(xx, yy, as_type='np')\n", - "plt_surf(xx, yy, sol_net, title='u(x, y) as solved by neural network')" + "sol_net = solution_neural_net_laplace(xx, yy, to_numpy=True)\n", + "plt_surf(xx, yy, sol_net, title='$u(x, y)$ as solved by neural network')" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 22, "metadata": { "scrolled": false }, @@ -11893,7 +13040,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -11940,7 +13087,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -11950,7 +13097,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -12887,7 +14034,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -12899,22 +14046,49 @@ ], "source": [ "k, L, T = 0.3, 2, 3\n", - "heat = lambda u, x, t: diff(u, t) - k * diff(u, x, order=2)\n", + "# Define the PDE system\n", + "# There's only one (heat) equation in the system, so the function maps (u, x, y) to a single entry\n", + "heat = lambda u, x, t: [\n", + " diff(u, t) - k * diff(u, x, order=2)\n", + "]\n", + "\n", + "# Define the initial and boundary conditions\n", + "# There's only one function to be solved for, so we only have a single condition object\n", + "conditions = [\n", + " IBVP1D(\n", + " t_min=0, t_min_val=lambda x: torch.sin(np.pi * x / L),\n", + " x_min=0, x_min_prime=lambda t: np.pi/L * torch.exp(-k*np.pi**2*t/L**2),\n", + " x_max=L, x_max_prime=lambda t: -np.pi/L * torch.exp(-k*np.pi**2*t/L**2)\n", + " )\n", + "]\n", + "\n", + "# Define the neural network to be used\n", + "# Again, there's only one function to be solved for, so we only have a single network\n", + "nets = [\n", + " FCNN(n_input_units=2, hidden_units=(32, 32))\n", + "]\n", + "\n", "\n", - "ibvp = IBVP1D(\n", - " t_min=0, t_min_val=lambda x: torch.sin(np.pi * x / L),\n", - " x_min=0, x_min_prime=lambda t: np.pi/L * torch.exp(-k*np.pi**2*t/L**2),\n", - " x_max=L, x_max_prime=lambda t: -np.pi/L * torch.exp(-k*np.pi**2*t/L**2)\n", + "# Define the monitor callback\n", + "monitor=Monitor2D(check_every=10, xy_min=(0, 0), xy_max=(L, T))\n", + "monitor_callback = MonitorCallback(monitor)\n", + "\n", + "# Instantiate the solver \n", + "solver = Solver2D(\n", + " pde_system=heat,\n", + " conditions=conditions,\n", + " xy_min=(0, 0), # We can omit xy_min when both 
train_generator and valid_generator are specified\n", + " xy_max=(L, T), # We can omit xy_max when both train_generator and valid_generator are specified\n", + " nets=nets,\n", + " train_generator=Generator2D((32, 32), (0, 0), (L, T), method='equally-spaced-noisy'),\n", + " valid_generator=Generator2D((32, 32), (0, 0), (L, T), method='equally-spaced'),\n", ")\n", - "net = FCNN(n_input_units=2, hidden_units=(32, 32))\n", "\n", - "solution_neural_net_heat, _ = solve2D(\n", - " pde=heat, condition=ibvp, xy_min=(0, 0), xy_max=(L, T),\n", - " net=net, max_epochs=200, train_generator=Generator2D(\n", - " (32, 32), (0, 0), (L, T), method='equally-spaced-noisy'\n", - " ),\n", - " monitor=Monitor2D(check_every=10, xy_min=(0, 0), xy_max=(L, T))\n", - ")" + "# Fit the neural network\n", + "solver.fit(max_epochs=200, callbacks=[monitor_callback])\n", + "\n", + "# Obtain the solution\n", + "solution_neural_net_heat = solver.get_solution()" ] }, { @@ -12926,8 +14100,10 @@ }, { "cell_type": "code", - "execution_count": 23, - "metadata": {}, + "execution_count": 25, + "metadata": { + "scrolled": false + }, "outputs": [ { "data": { @@ -13863,7 +15039,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -13875,10 +15051,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 23, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } @@ -13892,8 +15068,10 @@ }, { "cell_type": "code", - "execution_count": 24, - "metadata": {}, + "execution_count": 26, + "metadata": { + "scrolled": false + }, "outputs": [ { "data": { @@ -14829,7 +16007,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -14842,7 +16020,7 @@ "source": [ "solution_analytical_heat = lambda x, t: np.sin(np.pi*x/L) * np.exp(-k * np.pi**2 * t / L**2)\n", "sol_ana = solution_analytical_heat(xx, tt)\n", - "sol_net = solution_neural_net_heat(xx, tt, as_type='np')\n", + "sol_net = solution_neural_net_heat(xx, tt, to_numpy=True)\n", "plt_surf(xx, tt, sol_net-sol_ana, y_label='t', z_label='residual of the neural network solution')" ] }, @@ -14855,7 +16033,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -14881,7 +16059,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ @@ -14914,7 +16092,7 @@ "\n", "def compare_contour(sol_net, sol_ana, eval_on_xs, eval_on_ys, cdbc=None):\n", " eval_on_xs, eval_on_ys = eval_on_xs.flatten(), eval_on_ys.flatten()\n", - " s_net = sol_net(eval_on_xs, eval_on_ys, as_type='np')\n", + " s_net = sol_net(eval_on_xs, eval_on_ys, to_numpy=True)\n", " s_ana = sol_ana(eval_on_xs, eval_on_ys)\n", " \n", " fig = plt.figure(figsize=(18, 4))\n", @@ -14954,7 +16132,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 29, "metadata": {}, "outputs": [], "source": [ @@ -14971,7 +16149,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ @@ -15012,7 +16190,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 31, "metadata": {}, "outputs": [ { @@ -15949,7 +17127,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -16002,7 +17180,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -16939,7 +18117,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -16955,34 +18133,46 @@ " true_u = torch.log(1+x**2+y**2)\n", " return torch.mean( (u - true_u)**2 )\n", "\n", + "# Define the differential equation\n", "def de_star(u, x, y):\n", - " return diff(u, x, order=2) + diff(u, y, order=2) + torch.exp(u) - 1.0 - x**2 - y**2 - 4.0/(1.0+x**2+y**2)**2\n", + " return [diff(u, x, order=2) + diff(u, y, order=2) + torch.exp(u) - 1.0 - x**2 - y**2 - 4.0/(1.0+x**2+y**2)**2]\n", "\n", "# fully connected network with one hidden layer (40 hidden units with Sigmoid activation)\n", "net = FCNN(n_input_units=2, hidden_units=(40, 40), actv=nn.ELU)\n", "adam = optim.Adam(params=net.parameters(), lr=0.01)\n", "\n", - "# train on 32 X 32 grid\n", - "solution_neural_net_star, history_star = solve2D(\n", - " pde=de_star, condition=cdbc_star,\n", - " xy_min=(-1, -1), xy_max=(1, 1),\n", - " train_generator=train_gen, valid_generator=valid_gen,\n", - " net=net, max_epochs=100, batch_size=torch.sum(is_in_domain_train).item(), optimizer=adam,\n", - " monitor=Monitor2D(check_every=1, xy_min=(-1, -1), xy_max=(1, 1)),\n", - " metrics={'mse': mse}\n", - ")" + "# Define the monitor callback\n", + "monitor = Monitor2D(check_every=10, xy_min=(-1, -1), xy_max=(1, 1))\n", + "monitor_callback = MonitorCallback(monitor)\n", + "\n", + "# Instantiate the solver \n", + "solver = Solver2D(\n", + " pde_system=de_star,\n", + " conditions=[cdbc_star],\n", + " xy_min=(-1, -1), # We can omit xy_min when both train_generator and valid_generator are specified\n", + " xy_max=(1, 1), # We can omit xy_max when both train_generator and valid_generator are specified\n", + " nets=nets,\n", + " train_generator=train_gen,\n", + " valid_generator=valid_gen,\n", + ")\n", + "\n", + "# Fit the neural network, train on 32 x 32 grids\n", + 
"solver.fit(max_epochs=100, callbacks=[monitor_callback])\n", + "\n", + "# Obtain the solution\n", + "solution_neural_net_star = solver.get_solution()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We can compare our approximation with the analytical solution:" + "We can compare the neural network solution with the analytical solution:" ] }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 34, "metadata": {}, "outputs": [ { @@ -17919,7 +19109,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" From 59ee9243cb12d21db82c7fe9c8c5ca7ecff30f55 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Tue, 5 Jan 2021 19:05:27 +0800 Subject: [PATCH 11/18] fix(solver): fix a compatibility issue with torch losses --- neurodiffeq/solvers.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/neurodiffeq/solvers.py b/neurodiffeq/solvers.py index 5dd0df2..8adb7c5 100644 --- a/neurodiffeq/solvers.py +++ b/neurodiffeq/solvers.py @@ -151,7 +151,13 @@ def analytic_mse(*args): self.metrics_history.update({'valid__' + name: [] for name in self.metrics_fn}) self.optimizer = optimizer if optimizer else Adam(chain.from_iterable(n.parameters() for n in self.nets)) - self.criterion = criterion if criterion else lambda r: (r ** 2).mean() + + if criterion is None: + self.criterion = lambda r: (r ** 2).mean() + elif isinstance(criterion, nn.modules.loss._Loss): + self.criterion = lambda r: criterion(r, torch.zeros_like(r)) + else: + self.criterion = criterion def make_pair_dict(train=None, valid=None): return {'train': train, 'valid': valid} From cf0f9f9ee503ef93d83a66272ba23e3fc22bdac6 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Tue, 5 Jan 2021 19:06:23 +0800 Subject: [PATCH 12/18] fix(solver): fix an issue for legacy `return_internal` option --- neurodiffeq/solvers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/neurodiffeq/solvers.py b/neurodiffeq/solvers.py index 8adb7c5..8664432 100644 --- a/neurodiffeq/solvers.py 
+++ b/neurodiffeq/solvers.py @@ -438,6 +438,8 @@ def _get_internal_variables(self): "optimizer": self.optimizer, "diff_eqs": self.diff_eqs, "generator": self.generator, + "train_generator": self.generator['train'], + "valid_generator": self.generator['valid'], } @deprecated_alias(param_names='var_names') From 4831b1dbe8234196da1d6a82f0818995720f5855 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Tue, 5 Jan 2021 19:09:04 +0800 Subject: [PATCH 13/18] feat(fix): allow implicit return of all internal variables --- neurodiffeq/solvers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/neurodiffeq/solvers.py b/neurodiffeq/solvers.py index 8664432..e0c53bf 100644 --- a/neurodiffeq/solvers.py +++ b/neurodiffeq/solvers.py @@ -443,7 +443,7 @@ def _get_internal_variables(self): } @deprecated_alias(param_names='var_names') - def get_internals(self, var_names, return_type='list'): + def get_internals(self, var_names=None, return_type='list'): r"""Return internal variable(s) of the solver - If var_names == 'all', return all internal variables as a dict. 
@@ -461,7 +461,7 @@ def get_internals(self, var_names, return_type='list'): available_variables = self._get_internal_variables() - if var_names == "all": + if var_names == "all" or var_names is None: return available_variables if isinstance(var_names, str): From a8cfaa0a08aee2f8d340d9fa22edefce40eaa4d6 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Tue, 5 Jan 2021 19:09:38 +0800 Subject: [PATCH 14/18] refactor(callbacks): move all callbacks to `callbacks.py` --- neurodiffeq/callbacks.py | 30 ++++++++++++++++++++++++++++++ neurodiffeq/pde_spherical.py | 30 ------------------------------ 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/neurodiffeq/callbacks.py b/neurodiffeq/callbacks.py index 87a9d02..8361721 100644 --- a/neurodiffeq/callbacks.py +++ b/neurodiffeq/callbacks.py @@ -1,4 +1,7 @@ import os +import dill +from datetime import datetime +import logging class MonitorCallback: @@ -51,3 +54,30 @@ def __call__(self, solver): self.monitor.fig.savefig(pic_path) +class CheckpointCallback: + def __init__(self, ckpt_dir): + self.ckpt_dir = ckpt_dir + + def __call__(self, solver): + if solver.local_epoch == solver._max_local_epoch - 1: + now = datetime.now() + timestr = now.strftime("%Y-%m-%d_%H-%M-%S") + fname = os.path.join(self.ckpt_dir, timestr + ".internals") + with open(fname, 'wb') as f: + dill.dump(solver.get_internals("all"), f) + logging.info(f"Saved checkpoint to {fname} at local epoch = {solver.local_epoch} " + f"(global epoch = {solver.global_epoch})") + + +class ReportOnFitCallback: + def __call__(self, solver): + if solver.local_epoch == 0: + logging.info( + f"Starting from global epoch {solver.global_epoch - 1}, training on {(solver.r_min, solver.r_max)}") + tb = solver.generator['train'].size + ntb = solver.n_batches['train'] + t = tb * ntb + vb = solver.generator['valid'].size + nvb = solver.n_batches['valid'] + v = vb * nvb + logging.info(f"train size = {tb} x {ntb} = {t}, valid_size = {vb} x {nvb} = {v}") diff --git 
a/neurodiffeq/pde_spherical.py b/neurodiffeq/pde_spherical.py index 4643d83..a6356bd 100644 --- a/neurodiffeq/pde_spherical.py +++ b/neurodiffeq/pde_spherical.py @@ -281,33 +281,3 @@ def enforcer(net, cond, points): ret = ret + (internals,) return ret -# callbacks to be passed to SphericalSolver.fit() - - -class CheckpointCallback: - def __init__(self, ckpt_dir): - self.ckpt_dir = ckpt_dir - - def __call__(self, solver): - if solver.local_epoch == solver._max_local_epoch - 1: - now = datetime.now() - timestr = now.strftime("%Y-%m-%d_%H-%M-%S") - fname = os.path.join(self.ckpt_dir, timestr + ".internals") - with open(fname, 'wb') as f: - dill.dump(solver.get_internals("all"), f) - logging.info(f"Saved checkpoint to {fname} at local epoch = {solver.local_epoch} " - f"(global epoch = {solver.global_epoch})") - - -class ReportOnFitCallback: - def __call__(self, solver): - if solver.local_epoch == 0: - logging.info( - f"Starting from global epoch {solver.global_epoch - 1}, training on {(solver.r_min, solver.r_max)}") - tb = solver.generator['train'].size - ntb = solver.n_batches['train'] - t = tb * ntb - vb = solver.generator['valid'].size - nvb = solver.n_batches['valid'] - v = vb * nvb - logging.info(f"train size = {tb} x {ntb} = {t}, valid_size = {vb} x {nvb} = {v}") From 7b1560bbd1bc910c35b19a59688a5457c6fb6df0 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Tue, 5 Jan 2021 19:10:40 +0800 Subject: [PATCH 15/18] test(solver): add test case for `BaseSolver.get_internals` --- tests/test_ode.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/tests/test_ode.py b/tests/test_ode.py index 7ce2cde..edc5be5 100644 --- a/tests/test_ode.py +++ b/tests/test_ode.py @@ -11,8 +11,8 @@ from neurodiffeq.ode import IVP, DirichletBVP from neurodiffeq.ode import solve, solve_system, Monitor from neurodiffeq.monitors import Monitor1D -from neurodiffeq.solvers import Solution1D -from neurodiffeq.generators import Generator1D +from 
neurodiffeq.solvers import Solution1D, Solver1D +from neurodiffeq.generators import Generator1D, BaseGenerator import torch @@ -190,3 +190,33 @@ def check_output(us, shape, type, msg=""): check_output(us, shape=(N_SAMPLES, 1), type=torch.Tensor, msg=f"[use_single={use_single}]") us = solution(ts, as_type='np') check_output(us, shape=(N_SAMPLES, 1), type=np.ndarray, msg=f"[use_single={use_single}]") + + +def test_get_internals(): + parametric_circle = lambda x1, x2, t: [diff(x1, t) - x2, diff(x2, t) + x1] + + init_vals_pc = [ + IVP(t_0=0.0, u_0=0.0), + IVP(t_0=0.0, u_0=1.0), + ] + + solver = Solver1D( + ode_system=parametric_circle, + conditions=init_vals_pc, + t_min=0.0, + t_max=2*np.pi, + ) + + solver.fit(max_epochs=1) + internals = solver.get_internals() + assert isinstance(internals, dict) + internals = solver.get_internals(return_type='list') + assert isinstance(internals, dict) + internals = solver.get_internals(return_type='dict') + assert isinstance(internals, dict) + internals = solver.get_internals(['generator', 'n_batches'], return_type='dict') + assert isinstance(internals, dict) + internals = solver.get_internals(['generator', 'n_batches'], return_type='list') + assert isinstance(internals, list) + internals = solver.get_internals('train_generator') + assert isinstance(internals, BaseGenerator) From 6399622e91697abf7a13e14e6a6bd465869df532 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Tue, 5 Jan 2021 19:11:14 +0800 Subject: [PATCH 16/18] api(neurodiffeq): export all modules to top level --- neurodiffeq/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/neurodiffeq/__init__.py b/neurodiffeq/__init__.py index 0902078..07e1646 100644 --- a/neurodiffeq/__init__.py +++ b/neurodiffeq/__init__.py @@ -13,6 +13,10 @@ from . import ode from . import pde_spherical from . import temporal +from . import solvers +from . import callbacks +from . import monitors +from . 
import utils # Set default float type to 64 bits _set_tensor_type(float_bits=64) From edebc953275dbd596ac4b59859039be76dce995f Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Tue, 5 Jan 2021 19:12:13 +0800 Subject: [PATCH 17/18] docs(advanced): update documentation for tuning solvers --- docs/advanced.ipynb | 362 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 286 insertions(+), 76 deletions(-) diff --git a/docs/advanced.ipynb b/docs/advanced.ipynb index b26a9b2..e52bc14 100644 --- a/docs/advanced.ipynb +++ b/docs/advanced.ipynb @@ -11,20 +11,14 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Setting device and dtype to `torch.FloatTensor`\n" - ] - } - ], + "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "from neurodiffeq.neurodiffeq import safe_diff as diff\n", - "from neurodiffeq.ode import solve, solve_system, Monitor\n", + "from neurodiffeq.ode import solve, solve_system\n", + "from neurodiffeq.solvers import Solver1D\n", + "from neurodiffeq.monitors import Monitor1D\n", "from neurodiffeq.conditions import IVP\n", "\n", "import matplotlib.pyplot as plt\n", @@ -37,13 +31,28 @@ "source": [ "## Tuning the Solver\n", "\n", - "The `ode.solve`, `ode.solve_system` and `pde.solve2D` choose some hyperparameters by default. For example, in `solve`, by default:\n", - "* the solution is approximated by a fully connected network with 1 hidden layer with 32 units (tanh activation),\n", + "The `solve*` functions (in `neurodiffeq.ode`, `neurodiffeq.pde`, `neurodiffeq.pde_spherical`) and `Solver*` classes (in `neurodiffeq.solvers`) choose some hyperparameters by default. 
For example, in `neurodiffeq.solvers.Solver1D`, by default:\n", + "* the solution is approximated by a fully connected network of 2 hidden layers with 32 units each (tanh activation),\n", "* for each epoch we train on 16 different points that are generated by adding a Gaussian noise on the 32 equally spaced points on the $t$ domain,\n", "* an Adam optimizer with learning rate 0.001 is used\n", - "Sometimes we may want to choose these hyperparameters ourselves. We will be using the harmonic oscillator problem from above to demonstrate how to do that.\n", "\n", - "In the following, we demo how to change these default settings using the harmonic oscillator as an example." + "Sometimes we may want to choose these hyperparameters ourselves. We will be using the harmonic oscillator problem from above to demonstrate how to do that." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Simple Harmonic Oscillator Example\n", + "In the following, we demonstrate how to change these default settings using the harmonic oscillator as an example.\n", + "\n", + "The differential equation and the initial conditions are:\n", + "$$\n", + "\\frac{\\partial^2 u}{\\partial t^2} + u = 0\\\\\n", + "u\\bigg{|}_{t=0} = 0 \\quad\n", + "\\frac{\\partial u}{\\partial t}\\bigg{|}_{t=0} = 1\n", + "$$\n", + "\n" ] }, { @@ -52,8 +61,9 @@ "metadata": {}, "outputs": [], "source": [ - "harmonic_oscillator = lambda x, t: diff(x, t, order=2) + x\n", - "init_val_ho = IVP(t_0=0.0, x_0=0.0, x_0_prime=1.0)" + "# Note that the function maps (u, t) to a single-entry list\n", + "harmonic_oscillator = lambda u, t: [ diff(u, t, order=2) + u ]\n", + "init_val_ho = IVP(t_0=0.0, u_0=0.0, u_0_prime=1.0)" ] }, { @@ -70,28 +80,44 @@ "outputs": [], "source": [ "from neurodiffeq.networks import FCNN # fully-connect neural network\n", - "from torch import nn # PyTorch neural network module" + "import torch.nn as nn # PyTorch neural network module" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ - "We can pass a `torch.nn.Module` object as the `net` argument to `solve`. This specifies the network architecture we will use to approximate $x$. `neurodiffeq.networks.FCNN` is a fully connected network that has the same hidden units for each hidden layer. It can be initiated with the following arguments:\n", + "Whether you are using a `neurodiffeq.solvers.Solver1D` instance or the legacy functions `neurodiffeq.ode.solve` and `neurodiffeq.ode.solve_system` to solve differential equations, you can specify the network architecture you want to use.\n", "\n", - "* `hidden_units`: number of units for each hidden layer. If you have 3 hidden layers with 32, 64, and 16 neurons respectively, then `hidden_units` should be a tuple `(32, 64, 16)`.\n", + "The architecture must be defined as a subclass of `torch.nn.Module`. If you are familiar with PyTorch, this process couldn't be simpler. If you don't know PyTorch at all, we have defined a `neurodiffeq.networks.FCNN` for you. `FCNN` stands for Fully-Connected Neural Network. You can tweak it however you want by specifying \n", "\n", - "* `actv`: a `torch.nn.Module` *class*. e.g. `nn.Tanh`, `nn.Sigmoid`.\n", + "1. `hidden_units`: number of units for each hidden layer. If you have 3 hidden layers with 32, 64, and 16 neurons respectively, then `hidden_units` should be a tuple `(32, 64, 16)`.\n", "\n", - "Here we create a fully connected network with 2 hidden layers, each with 16 units and tanh activation. We then use it to fit our ODE solution." + "2. `actv`: a `torch.nn.Module` *class*, e.g. `nn.Tanh`, `nn.Sigmoid`. Empirically, [Swish](https://arxiv.org/abs/1710.05941) works better in many situations. We have implemented a `Swish` activation in `neurodiffeq.networks` for you to try out.\n", + "\n", + "3. `n_input_units` and `n_output_units`: number of input/output units of the network. This is largely dependent on your problem. In most cases, `n_output_units` should be 1. 
And `n_input_units` should be the number of independent variables. In the case of an ODE, this is 1, since we only have a single independent variable $t$.\n", + "\n", + "If you want more flexibility than only using fully connected networks, check out [PyTorch's tutorials](https://pytorch.org/tutorials/beginner/examples_nn/two_layer_net_module.html) on defining your custom `torch.nn.Module`. **Pro tip: it's simpler than you think :)**\n", + "\n", + "Once you figure out how to define your own network (as an instance of `torch.nn.Module`), you can pass it to \n", + "\n", + "1. `neurodiffeq.solvers.Solver1D` and other `Solver`s in this module by specifying `nets=[your_net1, your_net2, ...]`; or\n", + "2. `neurodiffeq.ode.solve`, `neurodiffeq.pde.solve2D`, `neurodiffeq.pde_spherical.solve_spherical`, etc., by specifying `net=your_net`; or\n", + "3. `neurodiffeq.ode.solve_system`, `neurodiffeq.pde.solve2D_system`, `neurodiffeq.pde_spherical.solve_spherical_system`, etc., by specifying `nets=[your_net1, your_net2, ...]`.\n", + "\n", + "Notes:\n", + "* Only the 1st way (using a `Solver`) is recommended; the 2nd and 3rd ways (using a `solve*` function) are deprecated and will some day be removed;\n", + "* In the 2nd case, these functions assume you are only solving a single equation for a single function, so you pass in **a single network** — `net=...`;\n", + "* In the 1st and 3rd cases, they assume you are solving arbitrarily many equations for arbitrarily many functions, so you pass in **a list of networks** — `nets=[...]`.\n", + "\n", + "\n", + "Here we create a fully connected network with 3 hidden layers, each with 16 units and tanh activation. We then use it to fit our ODE solution." 
] }, { "cell_type": "code", "execution_count": 4, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [ { "data": { @@ -1027,7 +1053,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -1039,17 +1065,30 @@ ], "source": [ "%matplotlib notebook\n", - "# specify the network architecture\n", + "# Specify the network architecture\n", "net_ho = FCNN(\n", " hidden_units=(16, 16, 16), actv=nn.Tanh\n", ")\n", "\n", - "# solve the ODE\n", - "solution_ho, _ = solve(\n", - " ode=harmonic_oscillator, condition=init_val_ho, t_min=0.0, t_max=2*np.pi,\n", - " net=net_ho,\n", - " monitor=Monitor(t_min=0.0, t_max=2*np.pi, check_every=100)\n", - ")" + "# Create a monitor callback\n", + "from neurodiffeq.callbacks import MonitorCallback\n", + "from neurodiffeq.monitors import Monitor1D\n", + "monitor_callback = MonitorCallback(Monitor1D(t_min=0.0, t_max=2*np.pi, check_every=100))\n", + "\n", + "# Create a solver\n", + "solver = Solver1D(\n", + " ode_system=harmonic_oscillator, # Note that `harmonic_oscillator` returns a single-entry list\n", + " conditions=[init_val_ho], # Again, `conditions` is a single-entry list\n", + " t_min=0.0,\n", + " t_max=2*np.pi,\n", + " nets=[net_ho], # Again, `nets` is a single-entry list\n", + ")\n", + "\n", + "# Fit the solver\n", + "solver.fit(max_epochs=1000, callbacks=[monitor_callback])\n", + "\n", + "# Obtain the solution\n", + "solution_ho = solver.get_solution()" ] }, { @@ -1058,7 +1097,7 @@ "source": [ "### Specifying the Training Set and Validation Set\n", "\n", - "`solve` and `solve_system` will train the neural network on a new set of examples. These examples are $t$s drawn from the domain of $t$. The way these $t$s are generated can be specified by passing a `neurodiffeq.ode.Generator` object as the `train_generator` argument (and `valid_generator` argument) to `solve` or `solve_system`. 
An `Generator` can be initiated by the following arguments:\n", + "Both `Solver*` classes and `solve*` functions train the neural network on a new set of points, randomly sampled every time. These points are $t$s drawn from the domain of $t$. The way these $t$s are generated can be specified by passing a `neurodiffeq.generators.BaseGenerator` object as the `train_generator` argument (and `valid_generator` argument) to `Solver*` classes or `solve*` functions. A `Generator` can be initialized by the following arguments:\n", "\n", "* `size`: the number of $t$s generated for each epoch\n", "\n", @@ -2017,7 +2056,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -2034,11 +2073,20 @@ "valid_gen = Generator1D(size=128, t_min=0.0, t_max=2*np.pi, method='equally-spaced')\n", "\n", "# solve the ODE\n", - "solution_ho, _ = solve(\n", - " ode=harmonic_oscillator, condition=init_val_ho, t_min=0.0, t_max=2*np.pi,\n", - " train_generator=train_gen, valid_generator=valid_gen,\n", - " monitor=Monitor(t_min=0.0, t_max=2*np.pi, check_every=100)\n", - ")" + "solver = Solver1D(\n", + " ode_system=harmonic_oscillator, \n", + " conditions=[init_val_ho], \n", + " t_min=0.0, \n", + " t_max=2*np.pi,\n", + " train_generator=train_gen, \n", + " valid_generator=valid_gen,\n", + ")\n", + "solver.fit(\n", + " max_epochs=1000, \n", + " callbacks=[MonitorCallback(Monitor1D(t_min=0.0, t_max=2*np.pi, check_every=100))]\n", + ")\n", + "\n", + "solution_ho = solver.get_solution()" ] }, { @@ -2047,7 +2095,9 @@ "source": [ "### Specifying the Optimizer\n", "\n", - "We can change the optimization algorithms by passing a `torch.optim.Optimizer` object to `solve` and `solve_system` as the `optimizer` argument. The ugly thing here is that, to initiate an `Optimizer`, we need to tell it the parameters to optimize. In other words, if we want to use a different optimizer from the default one, we also need to create our own networks. 
\n", + "We can change the optimization algorithms by passing a `torch.optim.Optimizer` object to `Solver*` classes and `solve*` functions as the `optimizer` argument. \n", + "\n", + "If you are familiar with PyTorch, you know that to initialize an `Optimizer`, we need to tell it the parameters to optimize. In other words, if we want to use a different optimizer from the default one, we also need to create our own networks. \n", "\n", "Here we create a fully connected network and an `SGD` optimizer to optimize its weights. Then we use them to solve the ODE." ] @@ -2064,7 +2114,9 @@ { "cell_type": "code", "execution_count": 8, - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [ { "data": { @@ -3000,7 +3052,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -3014,19 +3066,39 @@ "%matplotlib notebook\n", "# specify the network architecture\n", "net_ho = FCNN(\n", - " n_hidden\n", - " n_hidden_layers=2, n_hidden_units=16, actv=nn.Tanh\n", + " n_input_units=1,\n", + " n_output_units=1,\n", + " hidden_units=(16, 16, 16), \n", + " actv=nn.Tanh,\n", ")\n", "\n", + "nets = [net_ho]\n", + "\n", "# specify the optimizer\n", - "sgd_ho = SGD(net_ho.parameters(), lr=0.001, momentum=0.99)\n", + "from itertools import chain\n", + "\n", + "sgd_ho = SGD(\n", + " chain.from_iterable(n.parameters() for n in nets), # this gives all parameters in `nets`\n", + " lr=0.001, # learning rate\n", + " momentum=0.99, # momentum of SGD\n", + ")\n", "\n", "# solve the ODE\n", - "solution_ho, _ = solve(\n", - " ode=harmonic_oscillator, condition=init_val_ho, t_min=0.0, t_max=2*np.pi,\n", - " net=net_ho, optimizer=sgd_ho,\n", - " monitor=Monitor(t_min=0.0, t_max=2*np.pi, check_every=100)\n", - ")" + "solver = Solver1D(\n", + " ode_system=harmonic_oscillator,\n", + " conditions=[init_val_ho],\n", + " t_min=0.0, \n", + " t_max=2*np.pi,\n", + " nets=nets, \n", + " optimizer=sgd_ho,\n", + ")\n", + "\n", + "solver.fit(\n", + " max_epochs=1000, \n", + " 
callbacks=[MonitorCallback(Monitor1D(t_min=0.0, t_max=2*np.pi, check_every=100))]\n", + ")\n", + "\n", + "solution_ho = solver.get_solution()" ] }, { @@ -3052,7 +3124,9 @@ { "cell_type": "code", "execution_count": 10, - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [ { "data": { @@ -3988,7 +4062,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -4001,11 +4075,19 @@ "source": [ "%matplotlib notebook\n", "# solve the ODE\n", - "solution_ho, _ = solve(\n", - " ode=harmonic_oscillator, condition=init_val_ho, t_min=0.0, t_max=2*np.pi,\n", + "solver = Solver1D(\n", + " ode_system=harmonic_oscillator, \n", + " conditions=[init_val_ho], \n", + " t_min=0.0, \n", + " t_max=2*np.pi,\n", " criterion=L1Loss(),\n", - " monitor=Monitor(t_min=0.0, t_max=2*np.pi, check_every=100)\n", - ")" + ")\n", + "solver.fit(\n", + " max_epochs=1000, \n", + " callbacks=[MonitorCallback(Monitor1D(t_min=0.0, t_max=2*np.pi, check_every=100))]\n", + ")\n", + "\n", + "solution_ho = solver.get_solution()" ] }, { @@ -4014,22 +4096,31 @@ "source": [ "## Access the Internals\n", "\n", - "When the network, example generator, optimizer and loss function are specified outside `solve` and `solve_system` function, users will naturally have access to these objects. We may still want to access these objects when we are using default network architecture, example generator, optimizer and loss function. We can get these internal objects by setting the `return_internal` keyword to `True`. This will add a third element in the returned tuple, which is a dictionary containing the reference to the network, example generator, optimizer and loss function." + "When the network, generator, optimizer and loss function are specified outside `solve` and `solve_system` function, users will naturally have access to these objects. We may still want to access these objects when we are using default network architecture, generator, optimizer and loss function. 
We can get these internal objects by setting the `return_internal` keyword to `True`. This will add a third element in the returned tuple, which is a dictionary containing the reference to the network, example generator, optimizer and loss function." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/liushuheng/Documents/GitHub/neurodiffeq/neurodiffeq/ode.py:260: FutureWarning: The `solve_system` function is deprecated, use a `neurodiffeq.solvers.Solver1D` instance instead\n", + " warnings.warn(\n" + ] + } + ], "source": [ "# specify the ODE system\n", "parametric_circle = lambda x1, x2, t : [diff(x1, t) - x2, \n", " diff(x2, t) + x1]\n", "# specify the initial conditions\n", "init_vals_pc = [\n", - " IVP(t_0=0.0, x_0=0.0),\n", - " IVP(t_0=0.0, x_0=1.0)\n", + " IVP(t_0=0.0, u_0=0.0),\n", + " IVP(t_0=0.0, u_0=1.0),\n", "]\n", "\n", "# solve the ODE system\n", @@ -4044,25 +4135,35 @@ { "cell_type": "code", "execution_count": 12, - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [ { "data": { "text/plain": [ - "{'single_net': FCNN(\n", - " (NN): Sequential(\n", - " (0): Linear(in_features=1, out_features=32, bias=True)\n", - " (1): Tanh()\n", - " (2): Linear(in_features=32, out_features=32, bias=True)\n", - " (3): Tanh()\n", - " (4): Linear(in_features=32, out_features=2, bias=True)\n", - " )\n", - " ),\n", - " 'nets': None,\n", - " 'conditions': [,\n", - " ],\n", - " 'train_generator': ,\n", - " 'valid_generator': ,\n", + "{'nets': [FCNN(\n", + " (NN): Sequential(\n", + " (0): Linear(in_features=1, out_features=32, bias=True)\n", + " (1): Tanh()\n", + " (2): Linear(in_features=32, out_features=32, bias=True)\n", + " (3): Tanh()\n", + " (4): Linear(in_features=32, out_features=2, bias=True)\n", + " )\n", + " ),\n", + " FCNN(\n", + " (NN): Sequential(\n", + " (0): Linear(in_features=1, out_features=32, bias=True)\n", + " (1): 
Tanh()\n", + " (2): Linear(in_features=32, out_features=32, bias=True)\n", + " (3): Tanh()\n", + " (4): Linear(in_features=32, out_features=2, bias=True)\n", + " )\n", + " )],\n", + " 'conditions': [,\n", + " ],\n", + " 'train_generator': ,\n", + " 'valid_generator': ,\n", " 'optimizer': Adam (\n", " Parameter Group 0\n", " amsgrad: False\n", @@ -4071,7 +4172,7 @@ " lr: 0.001\n", " weight_decay: 0\n", " ),\n", - " 'criterion': MSELoss()}" + " 'criterion': .(r)>}" ] }, "execution_count": 12, @@ -4083,6 +4184,115 @@ "internal" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You get more internal objects when using the `Solver`s. The process is demonstrated as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "parametric_circle = lambda x1, x2, t: [diff(x1, t) - x2, diff(x2, t) + x1]\n", + "\n", + "init_vals_pc = [\n", + " IVP(t_0=0.0, u_0=0.0),\n", + " IVP(t_0=0.0, u_0=1.0),\n", + "]\n", + "\n", + "solver = Solver1D(\n", + " ode_system=parametric_circle,\n", + " conditions=init_vals_pc,\n", + " t_min=0.0,\n", + " t_max=2*np.pi,\n", + ")\n", + "\n", + "solver.fit(max_epochs=100)\n", + "internals = solver.get_internals()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'metrics': {},\n", + " 'n_batches': {'train': 1, 'valid': 4},\n", + " 'best_nets': [FCNN(\n", + " (NN): Sequential(\n", + " (0): Linear(in_features=1, out_features=32, bias=True)\n", + " (1): Tanh()\n", + " (2): Linear(in_features=32, out_features=32, bias=True)\n", + " (3): Tanh()\n", + " (4): Linear(in_features=32, out_features=1, bias=True)\n", + " )\n", + " ),\n", + " FCNN(\n", + " (NN): Sequential(\n", + " (0): Linear(in_features=1, out_features=32, bias=True)\n", + " (1): Tanh()\n", + " (2): Linear(in_features=32, out_features=32, bias=True)\n", + " (3): Tanh()\n", + " (4): Linear(in_features=32, out_features=1, 
bias=True)\n", + " )\n", + " )],\n", + " 'criterion': .(r)>,\n", + " 'conditions': [,\n", + " ],\n", + " 'global_epoch': 100,\n", + " 'lowest_loss': 0.026307572100156426,\n", + " 'n_funcs': 2,\n", + " 'nets': [FCNN(\n", + " (NN): Sequential(\n", + " (0): Linear(in_features=1, out_features=32, bias=True)\n", + " (1): Tanh()\n", + " (2): Linear(in_features=32, out_features=32, bias=True)\n", + " (3): Tanh()\n", + " (4): Linear(in_features=32, out_features=1, bias=True)\n", + " )\n", + " ),\n", + " FCNN(\n", + " (NN): Sequential(\n", + " (0): Linear(in_features=1, out_features=32, bias=True)\n", + " (1): Tanh()\n", + " (2): Linear(in_features=32, out_features=32, bias=True)\n", + " (3): Tanh()\n", + " (4): Linear(in_features=32, out_features=1, bias=True)\n", + " )\n", + " )],\n", + " 'optimizer': Adam (\n", + " Parameter Group 0\n", + " amsgrad: False\n", + " betas: (0.9, 0.999)\n", + " eps: 1e-08\n", + " lr: 0.001\n", + " weight_decay: 0\n", + " ),\n", + " 'diff_eqs': (x1, x2, t)>,\n", + " 'generator': {'train': ,\n", + " 'valid': },\n", + " 'train_generator': ,\n", + " 'valid_generator': ,\n", + " 't_min': 0.0,\n", + " 't_max': 6.283185307179586}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "internals" + ] + }, { "cell_type": "code", "execution_count": null, From 6adc48d8bb26442cf00633aa3acb020189ed3983 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Tue, 5 Jan 2021 19:14:44 +0800 Subject: [PATCH 18/18] chore(pypi): prepare for release of v0.3.0 --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 7864e28..ea408fe 100644 --- a/setup.py +++ b/setup.py @@ -8,14 +8,14 @@ setuptools.setup( name="neurodiffeq", - version="0.2.2", + version="0.3.0", author="odegym", author_email="wish1104@icloud.com", description="A light-weight & flexible library for solving differential equations using neural networks based on PyTorch. 
", long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/odegym/neurodiffeq", - download_url="https://github.com/odegym/neurodiffeq/archive/v0.2.2.tar.gz", + download_url="https://github.com/odegym/neurodiffeq/archive/v0.3.0.tar.gz", keywords=[ "neural network", "deep learning",