Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add typing and type annotations for icbc #1453

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
181 changes: 153 additions & 28 deletions deepxde/icbc/boundary_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,37 @@
import numbers
from abc import ABC, abstractmethod
from functools import wraps
from typing import Any, Callable, List, Optional, overload, Tuple, Union
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that DeepXDE supports Python 3.8+, you should use -

from __future__ import annotations
from typing import Any, Callable, Optional, overload


import numpy as np
from numpy.typing import NDArray, ArrayLike

from .. import backend as bkd
from .. import config
from .. import data
from .. import gradients as grad
from .. import utils
from ..backend import backend_name
from ..geometry import Geometry
from ..types import Tensor, TensorOrTensors


class BC(ABC):
"""Boundary condition base class.

Args:
geom: A ``deepxde.geometry.Geometry`` instance.
on_boundary: A function: (x, Geometry.on_boundary(x)) -> True/False.
component: The output component satisfying this BC.
on_boundary: A function: `(x, Geometry.on_boundary(x))` -> True/False.
component: The output component satisfying this BC, should be provided
if ``BC.error`` involves derivatives and the output has multiple components.
"""

def __init__(self, geom, on_boundary, component):
def __init__(
self,
geom: Geometry,
on_boundary: Callable[[NDArray[Any], NDArray[Any]], NDArray[np.bool_]],
component: Union[List[int], int],
):
self.geom = geom
self.on_boundary = lambda x, on: np.array(
[on_boundary(x[i], on[i]) for i in range(len(x))]
Expand All @@ -45,28 +55,57 @@ def __init__(self, geom, on_boundary, component):
utils.return_tensor(self.geom.boundary_normal)
)

def filter(self, X):
def filter(self, X: NDArray[Any]) -> NDArray[np.bool_]:
return X[self.on_boundary(X, self.geom.on_boundary(X))]

def collocation_points(self, X):
def collocation_points(self, X: NDArray[Any]) -> NDArray[Any]:
return self.filter(X)

def normal_derivative(self, X, inputs, outputs, beg, end):
def normal_derivative(
self,
X: NDArray[Any],
inputs: TensorOrTensors,
outputs: Tensor,
beg: int,
end: int,
) -> Tensor:
dydx = grad.jacobian(outputs, inputs, i=self.component, j=None)[beg:end]
n = self.boundary_normal(X, beg, end, None)
return bkd.sum(dydx * n, 1, keepdims=True)

@abstractmethod
def error(self, X, inputs, outputs, beg, end, aux_var=None):
def error(
self,
X: NDArray[Any],
inputs: TensorOrTensors,
outputs: Tensor,
beg: int,
end: int,
aux_var: Union[NDArray[np.float_], None] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should work once the suggestion above is applied.

Suggested change
aux_var: Union[NDArray[np.float_], None] = None,
aux_var: NDArray[np.float_] | None = None,

) -> Tensor:
"""Returns the loss."""
# aux_var is used in PI-DeepONet, where aux_var is the input function evaluated
# at x.


class DirichletBC(BC):
"""Dirichlet boundary conditions: y(x) = func(x)."""
"""Dirichlet boundary conditions: `y(x) = func(x)`.

Args:
geom: A ``deepxde.geometry.Geometry`` instance.
func: A function: `x` -> `y`.
on_boundary: A function: `(x, Geometry.on_boundary(x))` -> True/False.
component: The output component satisfying this BC, should be provided
if ``BC.error`` involves derivatives and the output has multiple components.
"""

def __init__(self, geom, func, on_boundary, component=0):
def __init__(
self,
geom: Geometry,
func: Callable[[NDArray[np.float_]], NDArray[np.float_]],
on_boundary: Callable[[NDArray[Any], NDArray[Any]], NDArray[np.bool_]],
component: Union[List[int], int] = 0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
component: Union[List[int], int] = 0,
component: list[int] | int = 0,

):
super().__init__(geom, on_boundary, component)
self.func = npfunc_range_autocache(utils.return_tensor(func))

Expand All @@ -81,9 +120,23 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None):


class NeumannBC(BC):
"""Neumann boundary conditions: dy/dn(x) = func(x)."""
"""Neumann boundary conditions: `dy/dn(x) = func(x)`.

def __init__(self, geom, func, on_boundary, component=0):
Args:
geom: A ``deepxde.geometry.Geometry`` instance.
func: A function: `x` -> `dy/dn`.
on_boundary: A function: `(x, Geometry.on_boundary(x))` -> True/False.
component: The output component satisfying this BC, should be provided
if ``BC.error`` involves derivatives and the output has multiple components.
"""

def __init__(
self,
geom: Geometry,
func: Callable[[NDArray[np.float_]], NDArray[np.float_]],
on_boundary: Callable[[NDArray[Any], NDArray[Any]], NDArray[np.bool_]],
component: Union[List[int], int] = 0,
):
super().__init__(geom, on_boundary, component)
self.func = npfunc_range_autocache(utils.return_tensor(func))

Expand All @@ -93,9 +146,23 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None):


class RobinBC(BC):
"""Robin boundary conditions: dy/dn(x) = func(x, y)."""
"""Robin boundary conditions: `dy/dn(x) = func(x, y)`.

Args:
geom: A ``deepxde.geometry.Geometry`` instance.
func: A function: `(x, y)` -> `dy/dn`.
on_boundary: A function: `(x, Geometry.on_boundary(x))` -> True/False.
component: The output component satisfying this BC, should be provided
if ``BC.error`` involves derivatives and the output has multiple components.
"""

def __init__(self, geom, func, on_boundary, component=0):
def __init__(
self,
geom: Geometry,
func: Callable[[NDArray[np.float_]], NDArray[np.float_]],
on_boundary: Callable[[NDArray[Any], NDArray[Any]], NDArray[np.bool_]],
component: Union[List[int], int] = 0,
):
super().__init__(geom, on_boundary, component)
self.func = func

Expand All @@ -106,9 +173,25 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None):


class PeriodicBC(BC):
"""Periodic boundary conditions on component_x."""
"""Periodic boundary conditions on component_x.

Args:
geom: A ``deepxde.geometry.Geometry`` instance.
component_x: The component of the input satisfying this BC.
on_boundary: A function: `(x, Geometry.on_boundary(x))` -> True/False.
derivative_order: The derivative order of the output satisfying this BC.
component: The output component satisfying this BC, should be provided
if ``BC.error`` involves derivatives and the output has multiple components.
"""

def __init__(self, geom, component_x, on_boundary, derivative_order=0, component=0):
def __init__(
self,
geom: Geometry,
component_x: int,
on_boundary: Callable[[NDArray[Any], NDArray[Any]], NDArray[np.bool_]],
derivative_order: int = 0,
component: Union[List[int], int] = 0,
):
super().__init__(geom, on_boundary, component)
self.component_x = component_x
self.derivative_order = derivative_order
Expand All @@ -135,11 +218,11 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None):


class OperatorBC(BC):
"""General operator boundary conditions: func(inputs, outputs, X) = 0.
"""General operator boundary conditions: `func(inputs, outputs, X) = 0`.

Args:
geom: ``Geometry``.
func: A function takes arguments (`inputs`, `outputs`, `X`)
geom: A ``deepxde.geometry.Geometry`` instance.
func: A function takes arguments `(inputs, outputs, X)`
and outputs a tensor of size `N x 1`, where `N` is the length of `inputs`.
`inputs` and `outputs` are the network input and output tensors,
respectively; `X` are the NumPy array of the `inputs`.
Expand All @@ -153,15 +236,20 @@ class OperatorBC(BC):
which cannot be fixed in an easy way for all backends.
"""

def __init__(self, geom, func, on_boundary):
def __init__(
self,
geom: Geometry,
func: Callable[[TensorOrTensors, Tensor, NDArray[np.float_]], Tensor],
on_boundary: Callable[[NDArray[Any], NDArray[Any]], NDArray[np.bool_]],
):
super().__init__(geom, on_boundary, 0)
self.func = func

def error(self, X, inputs, outputs, beg, end, aux_var=None):
return self.func(inputs, outputs, X)[beg:end]


class PointSetBC:
class PointSetBC(BC):
"""Dirichlet boundary condition for a set of points.

Compare the output (that associates with `points`) with `values` (target data).
Expand All @@ -172,7 +260,7 @@ class PointSetBC:
points: An array of points where the corresponding target values are known and
used for training.
values: A scalar or a 2D-array of values that gives the exact solution of the problem.
component: Integer or a list of integers. The output components satisfying this BC.
omponent: Integer or a list of integers. The output components satisfying this BC.
List of integers only supported for the backend PyTorch.
batch_size: The number of points per minibatch, or `None` to return all points.
This is only supported for the backend PyTorch and PaddlePaddle.
Expand All @@ -181,7 +269,14 @@ class PointSetBC:
shuffle: Randomize the order on each pass through the data when batching.
"""

def __init__(self, points, values, component=0, batch_size=None, shuffle=True):
def __init__(
self,
points: ArrayLike,
values: ArrayLike,
component: Union[List[int], int] = 0,
batch_size: Union[int, None] = None,
shuffle: bool = True,
):
self.points = np.array(points, dtype=config.real(np))
self.values = bkd.as_tensor(values, dtype=config.real(bkd.lib))
self.component = component
Expand Down Expand Up @@ -233,7 +328,7 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None):
return outputs[beg:end, self.component] - self.values


class PointSetOperatorBC:
class PointSetOperatorBC(BC):
"""General operator boundary conditions for a set of points.

Compare the function output, func, (that associates with `points`)
Expand All @@ -249,7 +344,12 @@ class PointSetOperatorBC:
tensors, respectively; `X` are the NumPy array of the `inputs`.
"""

def __init__(self, points, values, func):
def __init__(
self,
points: ArrayLike,
values: ArrayLike,
func: Callable[[TensorOrTensors, Tensor, NDArray[np.float_]], Tensor],
):
self.points = np.array(points, dtype=config.real(np))
if not isinstance(values, numbers.Number) and values.shape[1] != 1:
raise RuntimeError("PointSetOperatorBC should output 1D values")
Expand All @@ -263,6 +363,22 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None):
return self.func(inputs, outputs, X)[beg:end] - self.values


@overload
def npfunc_range_autocache(
func: Callable[[NDArray[np.float_]], NDArray[np.float_]]
) -> NDArray[np.float_]:
...


@overload
def npfunc_range_autocache(
func: Callable[
[NDArray[np.float_], NDArray[np.float_]], Optional[NDArray[np.float_]]
]
) -> NDArray[np.float_]:
...


def npfunc_range_autocache(func):
"""Call a NumPy function on a range of the input ndarray.

Expand Down Expand Up @@ -291,22 +407,31 @@ def npfunc_range_autocache(func):
cache = {}

@wraps(func)
def wrapper_nocache(X, beg, end, _):
def wrapper_nocache(
X: NDArray[np.float_], beg: int, end: int, _
) -> NDArray[np.float_]:
return func(X[beg:end])

@wraps(func)
def wrapper_nocache_auxiliary(X, beg, end, aux_var):
def wrapper_nocache_auxiliary(
X: NDArray[np.float_], beg: int, end: int, aux_var: NDArray[np.float_]
) -> NDArray[np.float_]:
aux_var: callable
return func(X[beg:end], aux_var[beg:end])

@wraps(func)
def wrapper_cache(X, beg, end, _):
def wrapper_cache(
X: NDArray[np.float_], beg: int, end: int, _
) -> NDArray[np.float_]:
key = (id(X), beg, end)
if key not in cache:
cache[key] = func(X[beg:end])
return cache[key]

@wraps(func)
def wrapper_cache_auxiliary(X, beg, end, aux_var):
def wrapper_cache_auxiliary(
X: NDArray[np.float_], beg: int, end: int, aux_var: NDArray[np.float_]
) -> NDArray[np.float_]:
# Even if X is the same one, aux_var could be different
key = (id(X), beg, end)
if key not in cache:
Expand Down
27 changes: 23 additions & 4 deletions deepxde/icbc/initial_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,50 @@

__all__ = ["IC"]

from typing import Any, Callable, List, Optional, overload, Union

import numpy as np
from numpy.typing import NDArray, ArrayLike

from .boundary_conditions import npfunc_range_autocache
from .. import backend as bkd
from .. import utils
from ..geometry import Geometry
from ..types import Tensor, TensorOrTensors


class IC:
"""Initial conditions: y([x, t0]) = func([x, t0])."""

def __init__(self, geom, func, on_initial, component=0):
def __init__(
self,
geom: Geometry,
func: Callable[[NDArray[np.float_]], NDArray[np.float_]],
on_initial: Callable[[NDArray[Any], NDArray[Any]], NDArray[np.bool_]],
component: Union[List[int], int] = 0,
):
self.geom = geom
self.func = npfunc_range_autocache(utils.return_tensor(func))
self.on_initial = lambda x, on: np.array(
[on_initial(x[i], on[i]) for i in range(len(x))]
)
self.component = component

def filter(self, X):
def filter(self, X: NDArray[np.float_]) -> NDArray[np.bool_]:
return X[self.on_initial(X, self.geom.on_initial(X))]

def collocation_points(self, X):
def collocation_points(self, X: NDArray[np.float_]) -> NDArray[np.float_]:
return self.filter(X)

def error(self, X, inputs, outputs, beg, end, aux_var=None):
def error(
self,
X: NDArray[np.float_],
inputs: TensorOrTensors,
outputs: Tensor,
beg: int,
end: int,
aux_var: Union[NDArray[np.float_], None] = None,
) -> Tensor:
values = self.func(X, beg, end, aux_var)
if bkd.ndim(values) == 2 and bkd.shape(values)[1] != 1:
raise RuntimeError(
Expand Down
7 changes: 7 additions & 0 deletions deepxde/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import numpy as np
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union, TypeVar
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union, TypeVar
from __future__ import annotations
from typing import Sequence, TypeVar

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your suggestions, I will change them later 👍

from numpy.typing import NDArray, ArrayLike

# Tensor from any backend
Tensor = TypeVar("Tensor")
TensorOrTensors = Union[Tensor, Sequence[Tensor]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
TensorOrTensors = Union[Tensor, Sequence[Tensor]]
TensorOrTensors = Tensor | Sequence[Tensor]