Skip to content

Commit

Permalink
Merge pull request #150 from NeuroDiffGym/v0.5.0
Browse files Browse the repository at this point in the history
fix bug in H1 and H1-semi loss
  • Loading branch information
shuheng-liu authored Sep 29, 2021
2 parents 7bb802a + 9eae9ec commit ab670a1
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 14 deletions.
4 changes: 2 additions & 2 deletions neurodiffeq/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ def _infinity_norm(residual, funcs, coords):

def _h1_norm(residual, funcs, coords):
g = grad(residual, *coords)
rg = torch.cat([residual, *g])
rg = torch.cat([residual, *g], dim=1)
return (rg ** 2).mean()


def _h1_semi_norm(residual, funcs, coords):
g = grad(residual, *coords)
g = torch.cat(g)
g = torch.cat(g, dim=1)
return (g ** 2).mean()


Expand Down
75 changes: 65 additions & 10 deletions neurodiffeq/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,19 @@ class BaseSolver(ABC, PretrainedSolver):
The optimizer to be used for training.
:type optimizer: `torch.nn.optim.Optimizer`, optional
:param criterion:
A function that maps a PDE residual vector (torch tensor with shape (-1, 1)) to a scalar loss.
:type criterion: callable, optional
The loss function used for training.
- If a str, must be present in the keys of `neurodiffeq.losses._losses`.
- If a `torch.nn.modules.loss._Loss` instance, just pass the instance.
- If any other callable, it must map
A) a residual tensor (shape `(n_points, n_equations)`),
B) a function values tuple (length `n_funcs`, each element a tensor of shape `(n_points, 1)`), and
C) a coordinate values tuple (length `n_coords`, each element a tensor of shape `(n_coords, 1)`
to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
so that backpropagation can be performed.
:type criterion:
str or `torch.nn.moduesl.loss._Loss` or callable
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Defaults to 1.
Expand Down Expand Up @@ -687,8 +698,19 @@ class SolverSpherical(BaseSolver):
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
:type optimizer: ``torch.nn.optim.Optimizer``, optional
:param criterion:
Function that maps a PDE residual tensor (of shape (-1, 1)) to a scalar loss.
:type criterion: callable, optional
The loss function used for training.
- If a str, must be present in the keys of `neurodiffeq.losses._losses`.
- If a `torch.nn.modules.loss._Loss` instance, just pass the instance.
- If any other callable, it must map
A) a residual tensor (shape `(n_points, n_equations)`),
B) a function values tuple (length `n_funcs`, each element a tensor of shape `(n_points, 1)`), and
C) a coordinate values tuple (length `n_coords`, each element a tensor of shape `(n_coords, 1)`
to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
so that backpropagation can be performed.
:type criterion:
str or `torch.nn.moduesl.loss._Loss` or callable
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Defaults to 1.
Expand Down Expand Up @@ -935,8 +957,19 @@ class Solver1D(BaseSolver):
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
:type optimizer: ``torch.nn.optim.Optimizer``, optional
:param criterion:
Function that maps a ODE residual tensor (of shape (-1, 1)) to a scalar loss.
:type criterion: callable, optional
The loss function used for training.
- If a str, must be present in the keys of `neurodiffeq.losses._losses`.
- If a `torch.nn.modules.loss._Loss` instance, just pass the instance.
- If any other callable, it must map
A) a residual tensor (shape `(n_points, n_equations)`),
B) a function values tuple (length `n_funcs`, each element a tensor of shape `(n_points, 1)`), and
C) a coordinate values tuple (length `n_coords`, each element a tensor of shape `(n_coords, 1)`
to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
so that backpropagation can be performed.
:type criterion:
str or `torch.nn.moduesl.loss._Loss` or callable
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Defaults to 1.
Expand Down Expand Up @@ -1108,8 +1141,19 @@ class BundleSolver1D(BaseSolver):
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
:type optimizer: ``torch.nn.optim.Optimizer``, optional
:param criterion:
Function that maps a ODE residual tensor (of shape (-1, 1)) to a scalar loss.
:type criterion: callable, optional
The loss function used for training.
- If a str, must be present in the keys of `neurodiffeq.losses._losses`.
- If a `torch.nn.modules.loss._Loss` instance, just pass the instance.
- If any other callable, it must map
A) a residual tensor (shape `(n_points, n_equations)`),
B) a function values tuple (length `n_funcs`, each element a tensor of shape `(n_points, 1)`), and
C) a coordinate values tuple (length `n_coords`, each element a tensor of shape `(n_coords, 1)`
to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
so that backpropagation can be performed.
:type criterion:
str or `torch.nn.moduesl.loss._Loss` or callable
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Defaults to 1.
Expand Down Expand Up @@ -1308,8 +1352,19 @@ class Solver2D(BaseSolver):
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
:type optimizer: ``torch.nn.optim.Optimizer``, optional
:param criterion:
Function that maps a PDE residual tensor (of shape (-1, 1)) to a scalar loss.
:type criterion: callable, optional
The loss function used for training.
- If a str, must be present in the keys of `neurodiffeq.losses._losses`.
- If a `torch.nn.modules.loss._Loss` instance, just pass the instance.
- If any other callable, it must map
A) a residual tensor (shape `(n_points, n_equations)`),
B) a function values tuple (length `n_funcs`, each element a tensor of shape `(n_points, 1)`), and
C) a coordinate values tuple (length `n_coords`, each element a tensor of shape `(n_coords, 1)`
to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
so that backpropagation can be performed.
:type criterion:
str or `torch.nn.moduesl.loss._Loss` or callable
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Defaults to 1.
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ def func(m):

setuptools.setup(
name="neurodiffeq",
version="0.4.0",
version="0.5.0",
author="neurodiffgym",
author_email="[email protected]",
description="A light-weight & flexible library for solving differential equations using neural networks based on PyTorch. ",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/NeuroDiffGym/neurodiffeq",
download_url="https://github.com/NeuroDiffGym/neurodiffeq/archive/v0.4.0.tar.gz",
download_url="https://github.com/NeuroDiffGym/neurodiffeq/archive/v0.5.0.tar.gz",
keywords=[
"neural network",
"deep learning",
Expand Down
37 changes: 37 additions & 0 deletions tests/test_losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest
import torch
from neurodiffeq import diff
from neurodiffeq.losses import _losses as losses

N = 100


def pde_system(u, v, w, x, y):
return [
diff(u, x, order=2) + diff(u, y, order=2),
diff(v, x, order=2) + diff(v, y, order=2),
diff(w, x, order=2) + diff(w, y, order=2),
]


def get_rfx(n_input, n_output, n_equation):
coords = [torch.rand((N, 1), requires_grad=True) for _ in range(n_input)]
coords_tensor = torch.cat(coords, dim=1)
funcs = [torch.sigmoid(torch.sum(coords_tensor, dim=1, keepdim=True)) for _ in range(n_output)]
residual = [diff(funcs[0], coords[0]) + funcs[0] for _ in range(n_equation)]
residual = torch.cat(residual, dim=1)
return residual, funcs, coords


@pytest.mark.parametrize(argnames='n_input', argvalues=[1, 3])
@pytest.mark.parametrize(argnames='n_output', argvalues=[1, 3])
@pytest.mark.parametrize(argnames='n_equation', argvalues=[1, 3])
@pytest.mark.parametrize(
argnames=('loss_name', 'loss_fn'),
argvalues=losses.items(),
)
def test_losses(n_input, n_output, n_equation, loss_name, loss_fn):
r, f, x = get_rfx(n_input, n_output, n_equation)
loss = loss_fn(r, f, x)
assert loss.shape == (), f"{loss_name} doesn't output scalar"
assert loss.requires_grad, f"{loss_name} doesn't require gradient"

0 comments on commit ab670a1

Please sign in to comment.