feat(solvers): add Universal Solvers
sathvikbhagavan committed Jun 23, 2023
1 parent 85e8641 commit 9732955
Showing 2 changed files with 312 additions and 5 deletions.
5 changes: 5 additions & 0 deletions neurodiffeq/conditions.py
@@ -262,6 +262,11 @@ def parameterize(self, output_tensor, t):
        :rtype: `torch.Tensor`
        """
        if self.u_0_prime is None:
            # New branch: a list-valued u_0 carries one initial value per output
            # component; each column is re-parameterized with its own u_0[i].
            if isinstance(self.u_0, list):
                parameterized = torch.zeros_like(output_tensor)
                for i in range(len(self.u_0)):
                    parameterized[:, i] = (self.u_0[i] + (1 - torch.exp(-t + self.t_0)) * output_tensor[:, i].view(-1, 1))[:, 0]
                return parameterized
            return self.u_0 + (1 - torch.exp(-t + self.t_0)) * output_tensor
        else:
            return self.u_0 + (t - self.t_0) * self.u_0_prime + ((1 - torch.exp(-t + self.t_0)) ** 2) * output_tensor
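With this change, `IVP.parameterize` also accepts a list-valued `u_0`, re-parameterizing component i as u_i(t) = u_0[i] + (1 - exp(-(t - t_0))) * ANN_i(t), so each output component satisfies its own initial value exactly at t = t_0. A minimal sketch of the new branch (this assumes the `IVP` constructor stores a list-valued `u_0` unchanged, which this hunk does not show):

    import torch
    from neurodiffeq.conditions import IVP

    cond = IVP(t_0=0.0, u_0=[1.0, -1.0])            # one initial value per output component
    t = torch.linspace(0.0, 1.0, 5).reshape(-1, 1)
    raw = torch.zeros(5, 2)                         # stand-in for a network's raw output
    u = cond.parameterize(raw, t)                   # zero net output -> u == u_0 everywhere;
                                                    # in general only u(t_0) == u_0 is guaranteed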
312 changes: 307 additions & 5 deletions neurodiffeq/solvers.py
@@ -20,7 +20,7 @@
from .generators import Generator2D
from .generators import GeneratorND
from .function_basis import RealSphericalHarmonics
from .conditions import BaseCondition
from .conditions import BaseCondition, NoCondition
from .neurodiffeq import safe_diff as diff
from .losses import _losses

@@ -113,7 +113,7 @@ class BaseSolver(ABC, PretrainedSolver):
    def __init__(self, diff_eqs, conditions,
                 nets=None, train_generator=None, valid_generator=None, analytic_solutions=None,
                 optimizer=None, loss_fn=None, n_batches_train=1, n_batches_valid=4,
                 metrics=None, n_input_units=None, n_output_units=None,
                 metrics=None, n_input_units=None, n_output_units=None, system_parameters=None,
                 # deprecated arguments are listed below
                 shuffle=None, batch_size=None):
        # deprecate argument `shuffle`
@@ -130,6 +130,9 @@ def __init__(self, diff_eqs, conditions,
        )

        self.diff_eqs = diff_eqs
        self.system_parameters = {}
        if system_parameters is not None:
            self.system_parameters = system_parameters
        self.conditions = conditions
        self.n_funcs = len(conditions)
        if nets is None:
@@ -376,7 +379,7 @@ def closure(zero_grad=True):
            for name in self.metrics_fn:
                value = self.metrics_fn[name](*funcs, *batch).item()
                metric_values[name] += value
            residuals = self.diff_eqs(*funcs, *batch)
            residuals = self.diff_eqs(*funcs, *batch, **self.system_parameters)
            residuals = torch.cat(residuals, dim=1)
            try:
                loss = self.loss_fn(residuals, funcs, batch) + self.additional_loss(residuals, funcs, batch)
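The new `system_parameters` dict is splatted into the residual callable on every batch, so equation parameters can be supplied once at solver construction instead of being hard-coded into `diff_eqs`. A minimal sketch (the equation and the parameter name `alpha` are illustrative, not part of this commit):

    from neurodiffeq import diff
    from neurodiffeq.conditions import IVP
    from neurodiffeq.solvers import Solver1D

    def decay(u, t, alpha=1.0):
        # du/dt + alpha * u = 0; `alpha` arrives via **self.system_parameters
        return [diff(u, t) + alpha * u]

    solver = Solver1D(decay, [IVP(t_0=0.0, u_0=1.0)], t_min=0.0, t_max=2.0,
                      system_parameters={'alpha': 0.5})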
@@ -1105,7 +1108,7 @@ class Solver1D(BaseSolver):

def __init__(self, ode_system, conditions, t_min=None, t_max=None,
nets=None, train_generator=None, valid_generator=None, analytic_solutions=None, optimizer=None,
loss_fn=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1,
loss_fn=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1, system_parameters=None,
# deprecated arguments are listed below
batch_size=None, shuffle=None):

@@ -1136,6 +1139,7 @@ def __init__(self, ode_system, conditions, t_min=None, t_max=None,
            metrics=metrics,
            n_input_units=1,
            n_output_units=n_output_units,
            system_parameters=system_parameters,
            shuffle=shuffle,
            batch_size=batch_size,
        )
@@ -1164,11 +1168,12 @@ def get_solution(self, copy=True, best=True):
        :rtype: BaseSolution
        """
        nets = self.best_nets if best else self.nets
        conditions = self.conditions
        if copy:
            nets = deepcopy(nets)
            conditions = deepcopy(conditions)

        return Solution1D(nets, conditions)

    def _get_internal_variables(self):
@@ -1590,3 +1595,300 @@ def _get_internal_variables(self):
            'xy_max': self.xy_max,
        })
        return available_variables

class _SingleSolver1D(GenericSolver):

    class Head(nn.Module):
        """A trainable linear head stacked on a (shared, possibly frozen) base network."""

        def __init__(self, u_0, base, n_input, n_output=1):
            super().__init__()
            self.u_0 = u_0  # initial condition, enforced softly via `additional_loss`
            self.base = base
            self.last_layer = nn.Linear(n_input, n_output)

        def forward(self, x):
            x = self.base(x)
            x = self.last_layer(x)
            return x

    def __init__(self, bases, HeadClass, initial_conditions, n_last_layer_head, diff_eqs,
                 system_parameters=None,
                 optimizer=torch.optim.Adam, optimizer_args=None, optimizer_kwargs={"lr": 1e-3},
                 train_generator=None, valid_generator=None, n_batches_train=1, n_batches_valid=4,
                 loss_fn=None, metrics=None, is_system=False):

        if train_generator is None or valid_generator is None:
            raise ValueError("train_generator and valid_generator cannot be None")

        self.num = len(initial_conditions)
        self.bases = bases
        head_cls = self.Head if HeadClass is None else HeadClass
        if is_system:
            # one head (with its own base) per equation in the system
            self.head = [head_cls(initial_conditions[i], self.bases[i], n_last_layer_head) for i in range(self.num)]
        else:
            # a single head mapping the shared base to all output components
            self.head = [head_cls(torch.Tensor(initial_conditions).view(1, -1), self.bases, n_last_layer_head, len(initial_conditions))]

        self.optimizer_args = optimizer_args or ()
        self.optimizer_kwargs = optimizer_kwargs or {}

        if isinstance(optimizer, torch.optim.Optimizer):
            self.optimizer = optimizer
        elif isinstance(optimizer, type) and issubclass(optimizer, torch.optim.Optimizer):
            # gather parameters of all heads (each head also contains its base as a submodule)
            params = chain.from_iterable(n.parameters() for n in self.head)
            self.optimizer = optimizer(params, *self.optimizer_args, **self.optimizer_kwargs)
        else:
            raise TypeError(f"Unknown optimizer instance/type {optimizer}")

        super().__init__(
            diff_eqs=diff_eqs,
            conditions=[NoCondition()] * self.num,
            train_generator=train_generator,
            valid_generator=valid_generator,
            nets=self.head,
            system_parameters=system_parameters,
            optimizer=self.optimizer,
            n_batches_train=n_batches_train,
            n_batches_valid=n_batches_valid,
            loss_fn=loss_fn,
            metrics=metrics,
        )

    def additional_loss(self, residuals, funcs, coords):
        # Since the heads use `NoCondition`, initial values are enforced softly:
        # penalize the squared gap between each head's output at t = 0 and its u_0
        # (t_0 = 0 is implicitly assumed here).
        loss = 0
        for net in self.nets:
            out = net(torch.zeros((1, 1)))
            loss += ((net.u_0 - out) ** 2).mean()
        return loss


class UniversalSolver1D(ABC):
    r"""A solver class for solving a family of ODEs, i.e. the same system under
    different initial conditions and different equation parameters.

    :param diff_eqs:
        The ODE system to solve, which maps a torch.Tensor to a tuple of ODE residuals;
        both the input and each output must have shape (n_samples, 1).
    :type diff_eqs: callable
    :param is_system:
        Whether ``diff_eqs`` describes a system of equations (one base network per equation).
        Defaults to True.
    :type is_system: bool
    """

    class Base(nn.Module):
        """Default base (source) network: a small fully-connected net with tanh activations."""

        def __init__(self):
            super().__init__()
            self.linear_1 = nn.Linear(1, 10)
            self.linear_2 = nn.Linear(10, 10)
            self.linear_3 = nn.Linear(10, 10)

        def forward(self, x):
            x = self.linear_1(x)
            x = torch.tanh(x)
            x = self.linear_2(x)
            x = torch.tanh(x)
            x = self.linear_3(x)
            x = torch.tanh(x)
            return x

    def __init__(self, diff_eqs, is_system=True):
        self.diff_eqs = diff_eqs
        self.is_system = is_system

        self.t_min = None
        self.t_max = None
        self.train_generator = None
        self.valid_generator = None

    def build(self, u_0s=None,
              system_parameters=[{}],
              BaseClass=Base,
              HeadClass=None,
              n_last_layer_head=10,
              build_source=False,
              optimizer=torch.optim.Adam,
              optimizer_args=None, optimizer_kwargs={"lr": 1e-3},
              t_min=None,
              t_max=None,
              train_generator=None,
              valid_generator=None,
              n_batches_train=1,
              n_batches_valid=4,
              loss_fn=None,
              metrics=None):

r"""
:param system_parameters:
List of dictionaries of parameters for which the solver will be trained
:type system_parameters: list[dict]
:param BaseClass:
Neural network class for base networks
:type nets: torch.nn.Module
:param n_last_layer_head:
Number of neurons in the last layer for each network
:type n_last_layer_head: int
:param build_source:
Boolean value for training the base networks or freezing their weights
:type build_source: bool
:param optimizer:
Optimizer to be used for training.
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
:type optimizer: ``torch.nn.optim.Optimizer``, optional
:param t_min:
Lower bound of input (start time).
Ignored if ``train_generator`` and ``valid_generator`` are both set.
:type t_min: float, optional
:param t_max:
Upper bound of input (start time).
Ignored if ``train_generator`` and ``valid_generator`` are both set.
:type t_max: float, optional
:param train_generator:
Generator for sampling training points,
which must provide a ``.get_examples()`` method and a ``.size`` field.
``train_generator`` must be specified if ``t_min`` and ``t_max`` are not set.
:type train_generator: `neurodiffeq.generators.BaseGenerator`, optional
:param valid_generator:
Generator for sampling validation points,
which must provide a ``.get_examples()`` method and a ``.size`` field.
``valid_generator`` must be specified if ``t_min`` and ``t_max`` are not set.
:type valid_generator: `neurodiffeq.generators.BaseGenerator`, optional
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Defaults to 1.
:type n_batches_train: int, optional
:param n_batches_valid:
Number of batches to validate in every epoch, where batch-size equals ``valid_generator.size``.
Defaults to 4.
:type n_batches_valid: int, optional
:param loss_fn:
The loss function used for training.
- If a str, must be present in the keys of `neurodiffeq.losses._losses`.
- If a `torch.nn.modules.loss._Loss` instance, just pass the instance.
- If any other callable, it must map
A) a residual tensor (shape `(n_points, n_equations)`),
B) a function values tuple (length `n_funcs`, each element a tensor of shape `(n_points, 1)`), and
C) a coordinate values tuple (length `n_coords`, each element a tensor of shape `(n_coords, 1)`
to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
so that backpropagation can be performed.
:type loss_fn:
str or `torch.nn.moduesl.loss._Loss` or callable
:param metrics:
Additional metrics to be logged (besides loss). ``metrics`` should be a dict where
- Keys are metric names (e.g. 'analytic_mse');
- Values are functions (callables) that computes the metric value.
These functions must accept the same input as the differential equation ``ode_system``.
:type metrics: dict[str, callable], optional
"""

        self.u_0s = u_0s
        self.system_parameters = system_parameters
        self.n_last_layer_head = n_last_layer_head

        if t_min is not None:
            self.t_min = t_min
        if t_max is not None:
            self.t_max = t_max

        if self.t_min is not None and self.t_max is not None:
            self.train_generator = Generator1D(32, t_min=self.t_min, t_max=self.t_max, method='equally-spaced-noisy')
            self.valid_generator = Generator1D(32, t_min=self.t_min, t_max=self.t_max, method='equally-spaced-noisy')

        # Explicitly passed generators take precedence over the defaults built above.
        if train_generator is not None:
            self.train_generator = train_generator
        if valid_generator is not None:
            self.valid_generator = valid_generator

        if self.u_0s is None:
            raise ValueError("Initial conditions (`u_0s`) must be specified")
        if self.train_generator is None or self.valid_generator is None:
            raise ValueError("Train and valid generators cannot be None. "
                             "Either provide `t_min` and `t_max` or pass the generators explicitly.")

        self.optimizer = optimizer
        self.optimizer_args = optimizer_args or ()
        self.optimizer_kwargs = optimizer_kwargs or {}

        if build_source:
            if self.is_system:
                # one base network per equation in the system
                self.bases = [BaseClass() for _ in range(len(u_0s[0]))]
            else:
                self.bases = BaseClass()

            self.solvers_base = [_SingleSolver1D(
                bases=self.bases,
                HeadClass=HeadClass,
                initial_conditions=self.u_0s[i],
                n_last_layer_head=n_last_layer_head,
                diff_eqs=self.diff_eqs,
                train_generator=self.train_generator,
                valid_generator=self.valid_generator,
                system_parameters=self.system_parameters[p],
                optimizer=optimizer, optimizer_args=optimizer_args, optimizer_kwargs=optimizer_kwargs,
                n_batches_train=n_batches_train,
                n_batches_valid=n_batches_valid,
                loss_fn=loss_fn,
                metrics=metrics,
                is_system=self.is_system,
            ) for i in range(len(u_0s)) for p in range(len(self.system_parameters))]  # one solver per (IC, parameter) pair
        else:
            self.solvers_head = [_SingleSolver1D(
                bases=self.bases,
                HeadClass=HeadClass,
                initial_conditions=self.u_0s[i],
                n_last_layer_head=self.n_last_layer_head,
                diff_eqs=self.diff_eqs,
                train_generator=self.train_generator,
                valid_generator=self.valid_generator,
                system_parameters=self.system_parameters[p],
                optimizer=optimizer, optimizer_args=optimizer_args, optimizer_kwargs=optimizer_kwargs,
                n_batches_train=n_batches_train,
                n_batches_valid=n_batches_valid,
                loss_fn=loss_fn,
                metrics=metrics,
                is_system=self.is_system,
            ) for i in range(len(self.u_0s)) for p in range(len(self.system_parameters))]


    def fit(self, epochs=10, freeze_source=True):
        r"""
        :param epochs:
            Number of epochs to train for.
        :type epochs: int
        :param freeze_source:
            Whether to freeze the base networks and train only the heads;
            if False, the base solvers (bases together with heads) are trained end to end.
        :type freeze_source: bool
        """
        if not freeze_source:
            for solver in self.solvers_base:
                solver.fit(max_epochs=epochs)
        else:
            # Freeze every base parameter so that only the linear heads are updated.
            if self.is_system:
                for net in self.bases:
                    for param in net.parameters():
                        param.requires_grad = False
            else:
                for param in self.bases.parameters():
                    param.requires_grad = False
            for solver in self.solvers_head:
                solver.fit(max_epochs=epochs)


    def get_solution(self, base=False):
        r"""
        :param base:
            If True, return the solutions for the conditions on which the base networks
            were trained; otherwise return the solutions for which only the heads were trained.
        :type base: bool
        :rtype: list[BaseSolution]
        """
        solvers = self.solvers_base if base else self.solvers_head
        return [solver.get_solution() for solver in solvers]
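Taken together, the intended workflow appears to be: pre-train shared base networks on a few problems, then reuse them frozen and fit only the lightweight linear heads for new initial conditions. A hypothetical end-to-end sketch (the equation, initial values, and the parameter `alpha` are illustrative, not part of this commit):

    from neurodiffeq import diff
    from neurodiffeq.solvers import UniversalSolver1D

    def exponential(u, t, alpha=1.0):
        return [diff(u, t) + alpha * u]

    solver = UniversalSolver1D(exponential, is_system=False)

    # Phase 1: build the base network and train it together with the heads.
    solver.build(u_0s=[[1.0], [2.0]], system_parameters=[{'alpha': 0.5}],
                 t_min=0.0, t_max=2.0, build_source=True)
    solver.fit(epochs=100, freeze_source=False)

    # Phase 2: freeze the base and fit only the linear heads for a new condition.
    solver.build(u_0s=[[1.5]], system_parameters=[{'alpha': 0.5}],
                 t_min=0.0, t_max=2.0, build_source=False)
    solver.fit(epochs=50, freeze_source=True)
    solutions = solver.get_solution()   # one callable solution per (IC, parameter) pair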
