Added Hardconstraint Models
Added a model object for hard constraints on boundary conditions.
An edge case is handled: if no hard constraint function is passed, a lambda function that returns the output unchanged is used instead.
An example file showing how to solve the Poisson problem with hard constraints is also included.
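
For context, the hard constraint in the Poisson example works by multiplying the raw network output N(x, y) by an ansatz that vanishes on the boundary of the unit square used in the example:

u(x, y) = tanh(4πx) · tanh(4πy) · tanh(4π(x − 1)) · tanh(4π(y − 1)) · N(x, y)

Each tanh factor is zero on one edge of the square, so u vanishes on the whole boundary regardless of the network output, and the homogeneous Dirichlet condition is enforced by construction rather than through a boundary loss term.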
divijghose committed Apr 24, 2024
1 parent 86488b9 commit 5348120
Showing 3 changed files with 32 additions and 27 deletions.
Example script for the hard-constrained Poisson problem:
@@ -17,7 +17,7 @@
from fastvpinns.Geometry.geometry_2d import Geometry_2D
from fastvpinns.FE_2D.fespace2d import Fespace2D
from fastvpinns.data.datahandler2d import DataHandler2D
from fastvpinns.model.model_hard import DenseModel
from fastvpinns.model.model_hard import DenseModel_Hard
from fastvpinns.physics.poisson2d import pde_loss_poisson
from fastvpinns.utils.plot_utils import plot_contour, plot_loss_function, plot_test_loss_function
from fastvpinns.utils.compute_utils import compute_errors_combined
@@ -152,7 +152,22 @@
# and convert them into tensors of desired dtype
bilinear_params_dict = datahandler.get_bilinear_params_dict_as_tensors(get_bilinear_params_dict)

model = DenseModel(
# Setup the hard constraints
@tf.function
def apply_hard_boundary_constraints(inputs, x):
    """This method applies hard boundary constraints to the model.
    :param inputs: Input tensor
    :type inputs: tf.Tensor
    :param x: Output tensor from the model
    :type x: tf.Tensor
    :return: Output tensor with hard boundary constraints
    :rtype: tf.Tensor
    """
    ansatz = (
        tf.tanh(4.0 * np.pi * inputs[:, 0:1])
        * tf.tanh(4.0 * np.pi * inputs[:, 1:2])
        * tf.tanh(4.0 * np.pi * (inputs[:, 0:1] - 1.0))
        * tf.tanh(4.0 * np.pi * (inputs[:, 1:2] - 1.0))
    )
    ansatz = tf.cast(ansatz, i_dtype)
    return ansatz * x

model = DenseModel_Hard(
layer_dims=[2, 30, 30, 30, 1],
learning_rate_dict=i_learning_rate_dict,
params_dict=params_dict,
@@ -168,6 +183,7 @@
use_attention=i_use_attention,
activation=i_activation,
hessian=False,
hard_constraint_function=apply_hard_boundary_constraints,
)

test_points = domain.get_test_points()
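As a quick sanity check (not part of the commit; boundary_ansatz is a hypothetical stand-alone copy of the ansatz above, and i_dtype is assumed to match the dtype configured in the example script), the ansatz can be evaluated directly to confirm it is exactly zero on the boundary and nonzero in the interior:

import numpy as np
import tensorflow as tf

i_dtype = tf.float32  # assumption: matches the dtype used in the example script

def boundary_ansatz(inputs):
    # Product of four tanh factors, one vanishing on each edge of [0, 1] x [0, 1]
    a = (
        tf.tanh(4.0 * np.pi * inputs[:, 0:1])
        * tf.tanh(4.0 * np.pi * inputs[:, 1:2])
        * tf.tanh(4.0 * np.pi * (inputs[:, 0:1] - 1.0))
        * tf.tanh(4.0 * np.pi * (inputs[:, 1:2] - 1.0))
    )
    return tf.cast(a, i_dtype)

# Points on the four edges of the unit square: the ansatz is exactly zero there,
# so the constrained output ansatz * x satisfies u = 0 on the boundary.
edges = tf.constant([[0.0, 0.3], [1.0, 0.7], [0.4, 0.0], [0.9, 1.0]], dtype=i_dtype)
print(boundary_ansatz(edges).numpy())  # all exactly zero

# An interior point gives a nonzero multiplier, so the network output is free there.
print(boundary_ansatz(tf.constant([[0.5, 0.5]], dtype=i_dtype)).numpy())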
37 changes: 13 additions & 24 deletions fastvpinns/model/model_hard.py
@@ -28,8 +28,8 @@ def custom_loss2(y_true2, y_pred2):


# Custom Model
class DenseModel(tf.keras.Model):
""" The DenseModel class is a custom model class that hosts the neural network model.
class DenseModel_Hard(tf.keras.Model):
"""The DenseModel_Hard class is a custom model class that hosts the neural network model.
The class inherits from the tf.keras.Model class and is used
to define the neural network model architecture and the training loop for FastVPINNs.
@@ -64,9 +64,8 @@ class DenseModel(tf.keras.Model):
This method is used to get the configuration of the model.
train_step(beta=10, bilinear_params_dict=None)
This method is used to define the training step of the model.
"""
"""

def __init__(
self,
@@ -81,14 +80,19 @@ def __init__(
use_attention=False,
activation="tanh",
hessian=False,
hard_constraint_function=None,
):
super(DenseModel, self).__init__()
super(DenseModel_Hard, self).__init__()
self.layer_dims = layer_dims
self.use_attention = use_attention
self.activation = activation
self.layer_list = []
self.loss_function = loss_function
self.hessian = hessian
if hard_constraint_function is None:
    self.hard_constraint_function = lambda x, y: y
else:
    self.hard_constraint_function = hard_constraint_function

self.tensor_dtype = tensor_dtype

@@ -208,21 +212,8 @@ def __init__(
# def build(self, input_shape):
# super(DenseModel, self).build(input_shape)

def apply_hard_boundary_constraints(self, inputs, x):
    """This method applies hard boundary constraints to the model.
    :param inputs: Input tensor
    :type inputs: tf.Tensor
    :param x: Output tensor from the model
    :type x: tf.Tensor
    :return: Output tensor with hard boundary constraints
    :rtype: tf.Tensor
    """
    ansatz = (
        tf.tanh(4.0 * np.pi * inputs[:, 0:1])
        * tf.tanh(4.0 * np.pi * inputs[:, 1:2])
        * tf.tanh(4.0 * np.pi * (inputs[:, 0:1] - 1.0))
        * tf.tanh(4.0 * np.pi * (inputs[:, 1:2] - 1.0))
    )
    ansatz = tf.cast(ansatz, self.tensor_dtype)
    return ansatz * x

def call(self, inputs):
""" This method is used to define the forward pass of the model.
"""This method is used to define the forward pass of the model.
:param inputs: Input tensor
:type inputs: tf.Tensor
:return: Output tensor from the model
@@ -238,12 +229,12 @@ def call(self, inputs):
for layer in self.layer_list:
x = layer(x)

x = self.apply_hard_boundary_constraints(inputs, x)
x = self.hard_constraint_function(inputs, x)

return x

def get_config(self):
""" This method is used to get the configuration of the model.
"""This method is used to get the configuration of the model.
:return: Configuration of the model
:rtype: dict
"""
@@ -271,7 +262,7 @@ def get_config(self):

@tf.function
def train_step(self, beta=10, bilinear_params_dict=None):
""" This method is used to define the training step of the mode.
"""This method is used to define the training step of the mode.
:param bilinear_params_dict: Dictionary containing the bilinear parameters
:type bilinear_params_dict: dict
:return: Dictionary containing the loss values
@@ -328,8 +319,6 @@ def train_step(self, beta=10, bilinear_params_dict=None):
# convert predicted_values_dirichlet to tf.float64
# predicted_values_dirichlet = tf.cast(predicted_values_dirichlet, tf.float64)



# tf.print("Boundary Loss : ", boundary_loss)
# tf.print("Boundary Loss Shape : ", boundary_loss.shape)
# tf.print("Total PDE Loss : ", total_pde_loss)
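The None fallback added to __init__ above means that a DenseModel_Hard constructed without a hard_constraint_function simply returns the raw network output. A minimal, self-contained sketch of this call() pattern (TinyHardModel is hypothetical and only illustrates the hook, not the FastVPINNs implementation):

import tensorflow as tf

class TinyHardModel(tf.keras.Model):
    def __init__(self, hard_constraint_function=None):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)
        # Edge case from the commit: with no constraint given, fall back to an
        # identity lambda so the output is returned as it is.
        if hard_constraint_function is None:
            self.hard_constraint_function = lambda inputs, x: x
        else:
            self.hard_constraint_function = hard_constraint_function

    def call(self, inputs):
        x = self.dense(inputs)
        # The (possibly identity) constraint is applied to the final output.
        return self.hard_constraint_function(inputs, x)

points = tf.random.uniform((4, 2))
soft = TinyHardModel()                           # identity fallback, output unchanged
hard = TinyHardModel(lambda inputs, x: 0.0 * x)  # toy constraint that zeroes the output
print(soft(points).shape)    # (4, 1)
print(hard(points).numpy())  # all zeros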
2 changes: 1 addition & 1 deletion fastvpinns/model/model_inverse.py
@@ -332,4 +332,4 @@ def train_step(self, beta=10, bilinear_params_dict=None):
"loss": total_loss,
"inverse_params": self.inverse_params_dict,
"sensor_loss": sensor_loss,
}
}
