Skip to content

Commit

Permalink
Merge pull request #78 from odegym/v0.3.0
Browse files Browse the repository at this point in the history
Releasing V0.3.0
  • Loading branch information
shuheng-liu authored Jan 5, 2021
2 parents 718f226 + 6adc48d commit ceb547c
Show file tree
Hide file tree
Showing 10 changed files with 2,111 additions and 608 deletions.
362 changes: 286 additions & 76 deletions docs/advanced.ipynb

Large diffs are not rendered by default.

1,744 changes: 1,467 additions & 277 deletions docs/getstart.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions neurodiffeq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
from . import ode
from . import pde_spherical
from . import temporal
from . import solvers
from . import callbacks
from . import monitors
from . import utils

# Set default float type to 64 bits
_set_tensor_type(float_bits=64)
Expand Down
83 changes: 83 additions & 0 deletions neurodiffeq/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import os
import dill
from datetime import datetime
import logging


class MonitorCallback:
"""A callback for updating the monitor plots (and optionally saving the fig to disk).
:param monitor: The underlying monitor responsible for plotting solutions.
:type monitor: `neurodiffeq.monitors.BaseMonitor`
:param fig_dir: Directory for saving monitor figs; if not specified, figs will not be saved.
:type fig_dir: str
:param check_against: Which epoch count to check against; either 'local' (default) or 'global'.
:type check_against: str
:param repaint_last: Whether to update the plot on the last local epoch, defaults to True.
:type repaint_last: bool
"""

def __init__(self, monitor, fig_dir=None, check_against='local', repaint_last=True):
self.monitor = monitor
self.fig_dir = fig_dir
self.repaint_last = repaint_last
if check_against not in ['local', 'global']:
raise ValueError(f'unknown check_against type = {check_against}')
self.check_against = check_against

def to_repaint(self, solver):
if self.check_against == 'local':
epoch_now = solver.local_epoch + 1
elif self.check_against == 'global':
epoch_now = solver.global_epoch + 1
else:
raise ValueError(f'unknown check_against type = {self.check_against}')

if epoch_now % self.monitor.check_every == 0:
return True
if self.repaint_last and solver.local_epoch == solver._max_local_epoch - 1:
return True

return False

def __call__(self, solver):
if not self.to_repaint(solver):
return

self.monitor.check(
solver.nets,
solver.conditions,
history=solver.metrics_history,
)
if self.fig_dir:
pic_path = os.path.join(self.fig_dir, f"epoch-{solver.global_epoch}.png")
self.monitor.fig.savefig(pic_path)


class CheckpointCallback:
def __init__(self, ckpt_dir):
self.ckpt_dir = ckpt_dir

def __call__(self, solver):
if solver.local_epoch == solver._max_local_epoch - 1:
now = datetime.now()
timestr = now.strftime("%Y-%m-%d_%H-%M-%S")
fname = os.path.join(self.ckpt_dir, timestr + ".internals")
with open(fname, 'wb') as f:
dill.dump(solver.get_internals("all"), f)
logging.info(f"Saved checkpoint to {fname} at local epoch = {solver.local_epoch} "
f"(global epoch = {solver.global_epoch})")


class ReportOnFitCallback:
def __call__(self, solver):
if solver.local_epoch == 0:
logging.info(
f"Starting from global epoch {solver.global_epoch - 1}, training on {(solver.r_min, solver.r_max)}")
tb = solver.generator['train'].size
ntb = solver.n_batches['train']
t = tb * ntb
vb = solver.generator['valid'].size
nvb = solver.n_batches['valid']
v = vb * nvb
logging.info(f"train size = {tb} x {ntb} = {t}, valid_size = {vb} x {nvb} = {v}")
65 changes: 46 additions & 19 deletions neurodiffeq/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,10 @@ def _trial_solution(single_net, nets, ts, conditions):

def solve(
ode, condition, t_min=None, t_max=None,
net=None, train_generator=None, shuffle=True, valid_generator=None,
optimizer=None, criterion=None, additional_loss_term=None, metrics=None, batch_size=16,
max_epochs=1000,
monitor=None, return_internal=False,
return_best=False
net=None, train_generator=None, valid_generator=None,
optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4,
additional_loss_term=None, metrics=None, max_epochs=1000,
monitor=None, return_internal=False, return_best=False, batch_size=None, shuffle=None,
):
r"""Train a neural network to solve an ODE.
Expand Down Expand Up @@ -70,10 +69,6 @@ def solve(
The example generator to generate 1-D training points.
Default to None.
:type train_generator: `neurodiffeq.generators.Generator1D`, optional
:param shuffle:
Whether to shuffle the training examples every epoch.
Defaults to True.
:type shuffle: bool, optional
:param valid_generator:
The example generator to generate 1-D validation points.
Default to None.
Expand All @@ -86,6 +81,14 @@ def solve(
The loss function to use for training.
Defaults to None.
:type criterion: `torch.nn.modules.loss._Loss`, optional
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Defaults to 1.
:type n_batches_train: int, optional
:param n_batches_valid:
Number of batches to validate in every epoch, where batch-size equals ``valid_generator.size``.
Defaults to 4.
:type n_batches_valid: int, optional
:param additional_loss_term:
Extra terms to add to the loss function besides the part specified by `criterion`.
The input of `additional_loss_term` should be the same as `ode`.
Expand All @@ -97,10 +100,6 @@ def solve(
The input functions should be the same as `ode` and the output should be a numeric value.
The metrics are evaluated on both the training set and validation set.
:type metrics: dict[string, callable]
:param batch_size:
The size of the mini-batch to use.
Defaults to 16.
:type batch_size: int, optional
:param max_epochs:
The maximum number of epochs to train.
Defaults to 1000.
Expand All @@ -117,29 +116,43 @@ def solve(
Whether to return the nets that achieved the lowest validation loss.
Defaults to False.
:type return_best: bool, optional
:param batch_size:
**[DEPRECATED and IGNORED]**
Each batch will use all samples generated.
Please specify ``n_batches_train`` and ``n_batches_valid`` instead.
:type batch_size: int
:param shuffle:
**[DEPRECATED and IGNORED]**
Shuffling should be performed by generators.
:type shuffle: bool
:return:
The solution of the ODE.
The history of training loss and validation loss.
Optionally, the nets, conditions, training generator, validation generator, optimizer and loss function.
:rtype: tuple[`neurodiffeq.ode.Solution`, dict] or tuple[`neurodiffeq.ode.Solution`, dict, dict]
.. note::
This function is deprecated, use a ``neurodiffeq.solvers.Solver1D`` instead.
"""
nets = None if not net else [net]
return solve_system(
ode_system=lambda x, t: [ode(x, t)], conditions=[condition],
t_min=t_min, t_max=t_max, nets=nets,
train_generator=train_generator, shuffle=shuffle, valid_generator=valid_generator,
optimizer=optimizer, criterion=criterion, additional_loss_term=additional_loss_term, metrics=metrics,
optimizer=optimizer, criterion=criterion, n_batches_train=n_batches_train, n_batches_valid=n_batches_valid,
additional_loss_term=additional_loss_term, metrics=metrics,
batch_size=batch_size, max_epochs=max_epochs, monitor=monitor, return_internal=return_internal,
return_best=return_best
)


def solve_system(
ode_system, conditions, t_min, t_max,
single_net=None, nets=None, train_generator=None, shuffle=True, valid_generator=None,
optimizer=None, criterion=None, additional_loss_term=None, metrics=None, batch_size=16,
max_epochs=1000, monitor=None, return_internal=False,
return_best=False,
single_net=None, nets=None, train_generator=None, valid_generator=None,
optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4,
additional_loss_term=None, metrics=None, max_epochs=1000, monitor=None,
return_internal=False, return_best=False, batch_size=None, shuffle=None,
):
r"""Train a neural network to solve an ODE.
Expand Down Expand Up @@ -189,6 +202,14 @@ def solve_system(
The loss function to use for training.
Defaults to None and sum of square of the output of `ode_system` will be used.
:type criterion: `torch.nn.modules.loss._Loss`, optional
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Defaults to 1.
:type n_batches_train: int, optional
:param n_batches_valid:
Number of batches to validate in every epoch, where batch-size equals ``valid_generator.size``.
Defaults to 4.
:type n_batches_valid: int, optional
:param additional_loss_term:
Extra terms to add to the loss function besides the part specified by `criterion`.
The input of `additional_loss_term` should be the same as `ode_system`.
Expand Down Expand Up @@ -220,7 +241,7 @@ def solve_system(
:param batch_size:
**[DEPRECATED and IGNORED]**
Each batch will use all samples generated.
Please specify n_batches_train and n_batches_valid instead.
Please specify ``n_batches_train`` and ``n_batches_valid`` instead.
:type batch_size: int
:param shuffle:
**[DEPRECATED and IGNORED]**
Expand All @@ -230,6 +251,10 @@ def solve_system(
The solution of the ODE. The history of training loss and validation loss.
Optionally, the nets, conditions, training generator, validation generator, optimizer and loss function.
:rtype: tuple[`neurodiffeq.ode.Solution`, dict] or tuple[`neurodiffeq.ode.Solution`, dict, dict]
.. note::
This function is deprecated, use a ``neurodiffeq.solvers.Solver1D`` instead.
"""

warnings.warn(
Expand Down Expand Up @@ -272,6 +297,8 @@ class CustomSolver1D(Solver1D):
valid_generator=valid_generator,
optimizer=optimizer,
criterion=criterion,
n_batches_train=n_batches_train,
n_batches_valid=n_batches_valid,
metrics=metrics,
batch_size=batch_size,
shuffle=shuffle,
Expand Down
Loading

0 comments on commit ceb547c

Please sign in to comment.