Add typing and type annotations for icbc #1453

Open · wants to merge 12 commits into master
108 changes: 61 additions & 47 deletions deepxde/backend/backend.py
@@ -1,3 +1,9 @@
from __future__ import annotations

from numbers import Number
from typing import Sequence, overload
from numpy.typing import NDArray, ArrayLike
from ..types import Tensor, dtype, SparseTensor, TensorOrTensors
"""This file defines the unified tensor framework interface required by DeepXDE.

The principles of this interface:
@@ -26,7 +32,7 @@
# Tensor, data type and context interfaces


def data_type_dict():
def data_type_dict() -> dict[str, object]:
"""Returns a dictionary from data type string to the data type.

The dictionary should include at least:
@@ -58,19 +64,19 @@ def data_type_dict():
"""


def is_gpu_available():
def is_gpu_available() -> bool:
"""Returns a bool indicating if GPU is currently available.

Returns:
True if a GPU device is available.
"""


def is_tensor(obj):
def is_tensor(obj: object) -> bool:
"""Returns True if `obj` is a backend-native type tensor."""


def shape(input_tensor):
def shape(input_tensor: Tensor) -> Sequence[int]:
"""Return the shape of the tensor.

Args:
@@ -81,7 +87,7 @@ def shape(input_tensor):
"""


def size(input_tensor):
def size(input_tensor: Tensor) -> int:
"""Return the total number of elements in the input tensor.

Args:
@@ -92,7 +98,7 @@ def size(input_tensor):
"""


def ndim(input_tensor):
def ndim(input_tensor: Tensor) -> int:
"""Returns the number of dimensions of the tensor.

Args:
Expand All @@ -103,7 +109,7 @@ def ndim(input_tensor):
"""


def transpose(tensor, axes=None):
def transpose(tensor: Tensor, axes: Sequence[int] | int | None = None) -> Tensor:
"""Reverse or permute the axes of a tensor; returns the modified array.

For a tensor with two axes, transpose gives the matrix transpose.
@@ -117,7 +123,7 @@ def transpose(tensor, axes=None):
"""


def reshape(tensor, shape):
def reshape(tensor: Tensor, shape: Sequence[int]) -> Tensor:
"""Gives a new shape to a tensor without changing its data.

Args:
@@ -130,7 +136,7 @@ def reshape(tensor, shape):
"""


def Variable(initial_value, dtype=None):
def Variable(initial_value: Number, dtype: dtype | None = None) -> Tensor:
"""Return a trainable variable.

Args:
@@ -140,7 +146,7 @@ def Variable(initial_value, dtype=None):
"""


def as_tensor(data, dtype=None):
def as_tensor(data: ArrayLike, dtype: dtype | None = None) -> Tensor:
"""Convert the data to a Tensor.

If the data is already a tensor and has the same dtype, directly return.
@@ -155,7 +161,7 @@ def as_tensor(data, dtype=None):
"""


def sparse_tensor(indices, values, shape):
def sparse_tensor(indices: Sequence[tuple[int, int]], values: Tensor, shape: Sequence[int]) -> SparseTensor:
"""Construct a sparse tensor based on given indices, values and shape.

Args:
@@ -170,7 +176,7 @@ def sparse_tensor(indices, values, shape):
"""


def from_numpy(np_array):
def from_numpy(np_array: NDArray) -> Tensor:
"""Create a tensor that shares the underlying numpy array memory, if possible.

Args:
@@ -181,7 +187,7 @@ def from_numpy(np_array):
"""


def to_numpy(input_tensor):
def to_numpy(input_tensor: Tensor) -> NDArray:
"""Create a numpy ndarray that shares the same underlying storage, if possible.

Args:
@@ -192,7 +198,7 @@ def to_numpy(input_tensor):
"""


def concat(values, axis):
def concat(values: TensorOrTensors, axis: int) -> Tensor:
"""Returns the concatenation of the input tensors along the given dim.

Args:
@@ -204,7 +210,7 @@ def concat(values, axis):
"""


def stack(values, axis):
def stack(values: TensorOrTensors, axis: int) -> Tensor:
"""Returns the stack of the input tensors along the given dim.

Args:
@@ -216,7 +222,7 @@ def stack(values, axis):
"""


def expand_dims(tensor, axis):
def expand_dims(tensor: Tensor, axis: int) -> Tensor:
"""Expand dim for tensor along given axis.

Args:
@@ -228,7 +234,7 @@ def expand_dims(tensor, axis):
"""


def reverse(tensor, axis):
def reverse(tensor: Tensor, axis: int) -> Tensor:
"""Reverse the order of elements along the given axis.

Args:
@@ -240,7 +246,7 @@ def reverse(tensor, axis):
"""


def roll(tensor, shift, axis):
def roll(tensor: Tensor, shift: int | Sequence[int], axis: int | Sequence[int]) -> Tensor:
"""Roll the tensor along the given axis (axes).

Args:
@@ -261,67 +267,67 @@ def roll(tensor, shift, axis):
# implementation in each framework.


def lgamma(x):
def lgamma(x: Tensor) -> Tensor:
"""Computes the natural logarithm of the absolute value of the gamma function of x
element-wise.
"""


def elu(x):
def elu(x: Tensor) -> Tensor:
"""Computes the exponential linear function."""


def relu(x):
def relu(x: Tensor) -> Tensor:
"""Applies the rectified linear unit activation function."""


def gelu(x):
def gelu(x: Tensor) -> Tensor:
"""Computes Gaussian Error Linear Unit function."""


def selu(x):
def selu(x: Tensor) -> Tensor:
"""Computes scaled exponential linear."""


def sigmoid(x):
def sigmoid(x: Tensor) -> Tensor:
"""Computes sigmoid of x element-wise."""


def silu(x):
def silu(x: Tensor) -> Tensor:
"""Sigmoid Linear Unit (SiLU) function, also known as the swish function.
silu(x) = x * sigmoid(x).
"""


def sin(x):
def sin(x: Tensor) -> Tensor:
"""Computes sine of x element-wise."""


def cos(x):
def cos(x: Tensor) -> Tensor:
"""Computes cosine of x element-wise."""


def exp(x):
def exp(x: Tensor) -> Tensor:
"""Computes exponential of x element-wise."""


def square(x):
def square(x: Tensor) -> Tensor:
"""Returns the square of the elements of input."""


def abs(x):
def abs(x: Tensor) -> Tensor:
"""Computes the absolute value element-wise."""


def minimum(x, y):
def minimum(x: Tensor, y: Tensor) -> Tensor:
"""Returns the minimum of x and y (i.e. x < y ? x : y) element-wise."""


def tanh(x):
def tanh(x: Tensor) -> Tensor:
"""Computes hyperbolic tangent of x element-wise."""


def pow(x, y):
def pow(x: Tensor, y: Number | Tensor) -> Tensor:
"""Computes the power of one value to another: x ^ y."""


@@ -332,15 +338,15 @@ def pow(x, y):
# implementation in each framework.


def mean(input_tensor, dim, keepdims=False):
def mean(input_tensor: Tensor, dim: int | Sequence[int], keepdims: bool = False) -> Tensor:
"""Returns the mean value of the input tensor in the given dimension dim."""


def reduce_mean(input_tensor):
def reduce_mean(input_tensor: Tensor) -> Tensor:
"""Returns the mean value of all elements in the input tensor."""


def sum(input_tensor, dim, keepdims=False):
def sum(input_tensor: Tensor, dim: int | Sequence[int], keepdims: bool = False) -> Tensor:
"""Returns the sum of the input tensor along the given dim.

Args:
@@ -353,7 +359,7 @@ def sum(input_tensor, dim, keepdims=False):
"""


def reduce_sum(input_tensor):
def reduce_sum(input_tensor: Tensor) -> Tensor:
"""Returns the sum of all elements in the input tensor.

Args:
@@ -364,7 +370,7 @@ def reduce_sum(input_tensor):
"""


def prod(input_tensor, dim, keepdims=False):
def prod(input_tensor: Tensor, dim: int | Sequence[int], keepdims: bool = False) -> Tensor:
"""Returns the product of the input tensor along the given dim.

Args:
@@ -377,7 +383,7 @@ def prod(input_tensor, dim, keepdims=False):
"""


def reduce_prod(input_tensor):
def reduce_prod(input_tensor: Tensor) -> Tensor:
"""Returns the product of all elements in the input tensor.

Args:
@@ -388,7 +394,7 @@ def reduce_prod(input_tensor):
"""


def min(input_tensor, dim, keepdims=False):
def min(input_tensor: Tensor, dim: int | Sequence[int], keepdims: bool = False) -> Tensor:
"""Returns the minimum of the input tensor along the given dim.

Args:
@@ -401,7 +407,7 @@ def min(input_tensor, dim, keepdims=False):
"""


def reduce_min(input_tensor):
def reduce_min(input_tensor: Tensor) -> Tensor:
"""Returns the minimum of all elements in the input tensor.

Args:
@@ -412,7 +418,7 @@ def reduce_min(input_tensor):
"""


def max(input_tensor, dim, keepdims=False):
def max(input_tensor: Tensor, dim: int | Sequence[int], keepdims: bool = False) -> Tensor:
"""Returns the maximum of the input tensor along the given dim.

Args:
@@ -425,7 +431,7 @@ def max(input_tensor, dim, keepdims=False):
"""


def reduce_max(input_tensor):
def reduce_max(input_tensor: Tensor) -> Tensor:
"""Returns the maximum of all elements in the input tensor.

Args:
@@ -436,7 +442,7 @@ def reduce_max(input_tensor):
"""


def norm(tensor, ord=None, axis=None, keepdims=False):
def norm(tensor: Tensor, ord: Number | None = None, axis: int | None = None, keepdims: bool = False) -> Tensor:
"""Computes a vector norm.

Due to the incompatibility of different backends, only some vector norms are
@@ -457,7 +463,7 @@ def norm(tensor, ord=None, axis=None, keepdims=False):
"""


def zeros(shape, dtype):
def zeros(shape: Sequence[int], dtype: dtype) -> Tensor:
"""Creates a tensor with all elements set to zero.

Args:
@@ -469,7 +475,7 @@ def zeros(shape, dtype):
"""


def zeros_like(input_tensor):
def zeros_like(input_tensor: Tensor) -> Tensor:
"""Create a zero tensor with the same shape, dtype and context of the given tensor.

Args:
@@ -480,7 +486,7 @@ def zeros_like(input_tensor):
"""


def matmul(x, y):
def matmul(x: Tensor, y: Tensor) -> Tensor:
"""Compute matrix multiplication for two matrices x and y.

Args:
@@ -492,6 +498,14 @@ def matmul(x, y):
"""


@overload
def sparse_dense_matmul(x: SparseTensor, y: Tensor) -> Tensor: ...


@overload
def sparse_dense_matmul(x: SparseTensor, y: SparseTensor) -> SparseTensor: ...


def sparse_dense_matmul(x, y):
"""Compute matrix multiplication of a sparse matrix x and a sparse/dense matrix y.

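For readers skimming the diff: the two `@overload` stubs above declare that multiplying a sparse matrix by a dense tensor yields a dense tensor, while multiplying by another sparse matrix yields a sparse one, and a static checker uses them to narrow the return type at each call site. Below is a minimal, self-contained sketch of that pattern; the `Tensor`/`SparseTensor` classes and the trivial body are illustrative placeholders, not DeepXDE's actual types or implementation.

```python
# Minimal sketch of the @overload pattern used for sparse_dense_matmul above.
# Tensor and SparseTensor here are placeholder classes, not DeepXDE's types.
from __future__ import annotations

from typing import overload


class Tensor:
    """Stand-in for a backend-native dense tensor (placeholder)."""


class SparseTensor:
    """Stand-in for a backend-native sparse tensor (placeholder)."""


@overload
def sparse_dense_matmul(x: SparseTensor, y: Tensor) -> Tensor: ...
@overload
def sparse_dense_matmul(x: SparseTensor, y: SparseTensor) -> SparseTensor: ...
def sparse_dense_matmul(x: SparseTensor, y: Tensor | SparseTensor) -> Tensor | SparseTensor:
    # Trivial body so the sketch runs; a real backend would perform the product.
    return y


dense = sparse_dense_matmul(SparseTensor(), Tensor())         # checker infers Tensor
sparse = sparse_dense_matmul(SparseTensor(), SparseTensor())  # checker infers SparseTensor
```

A checker such as mypy infers `Tensor` for `dense` and `SparseTensor` for `sparse`, which is the contract these overloads are meant to document for the backend implementations.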