From 53481203c6eeeb08922224f317b1e103da30e95b Mon Sep 17 00:00:00 2001 From: divijghose Date: Wed, 24 Apr 2024 19:47:04 +0530 Subject: [PATCH] Added Hardconstraint Models Added a model object for hardconstrain of boundary conditions A edge case is added, if the hardconstraint function is not passed, then a lambda function, which returns the output as it is is added An example file on how to run the hard constrained problem for poisson problem is also described --- .../poisson_2d/main_poisson2d_hard.py | 20 +++++++++- fastvpinns/model/model_hard.py | 37 +++++++------------ fastvpinns/model/model_inverse.py | 2 +- 3 files changed, 32 insertions(+), 27 deletions(-) diff --git a/examples/forward_problems_2d/hard_boundary_constraints/poisson_2d/main_poisson2d_hard.py b/examples/forward_problems_2d/hard_boundary_constraints/poisson_2d/main_poisson2d_hard.py index e5c229e..4800062 100644 --- a/examples/forward_problems_2d/hard_boundary_constraints/poisson_2d/main_poisson2d_hard.py +++ b/examples/forward_problems_2d/hard_boundary_constraints/poisson_2d/main_poisson2d_hard.py @@ -17,7 +17,7 @@ from fastvpinns.Geometry.geometry_2d import Geometry_2D from fastvpinns.FE_2D.fespace2d import Fespace2D from fastvpinns.data.datahandler2d import DataHandler2D -from fastvpinns.model.model_hard import DenseModel +from fastvpinns.model.model_hard import DenseModel_Hard from fastvpinns.physics.poisson2d import pde_loss_poisson from fastvpinns.utils.plot_utils import plot_contour, plot_loss_function, plot_test_loss_function from fastvpinns.utils.compute_utils import compute_errors_combined @@ -152,7 +152,22 @@ # and convert them into tensors of desired dtype bilinear_params_dict = datahandler.get_bilinear_params_dict_as_tensors(get_bilinear_params_dict) - model = DenseModel( + # Setup the hard constraints + @tf.function + def apply_hard_boundary_constraints(inputs, x): + """This method applies hard boundary constraints to the model. + :param inputs: Input tensor + :type inputs: tf.Tensor + :param x: Output tensor from the model + :type x: tf.Tensor + :return: Output tensor with hard boundary constraints + :rtype: tf.Tensor + """ + ansatz = tf.tanh(4.0*np.pi*inputs[:,0:1])*tf.tanh(4.0*np.pi*inputs[:,1:2])*tf.tanh(4.0*np.pi*(inputs[:,0:1]-1.0))*tf.tanh(4.0*np.pi*(inputs[:,1:2]-1.0)) + ansatz = tf.cast(ansatz, i_dtype) + return ansatz*x + + model = DenseModel_Hard( layer_dims=[2, 30, 30, 30, 1], learning_rate_dict=i_learning_rate_dict, params_dict=params_dict, @@ -168,6 +183,7 @@ use_attention=i_use_attention, activation=i_activation, hessian=False, + hard_constraint_function = apply_hard_boundary_constraints, ) test_points = domain.get_test_points() diff --git a/fastvpinns/model/model_hard.py b/fastvpinns/model/model_hard.py index c401a60..c3302e1 100644 --- a/fastvpinns/model/model_hard.py +++ b/fastvpinns/model/model_hard.py @@ -28,8 +28,8 @@ def custom_loss2(y_true2, y_pred2): # Custom Model -class DenseModel(tf.keras.Model): - """ The DenseModel class is a custom model class that hosts the neural network model. +class DenseModel_Hard(tf.keras.Model): + """The DenseModel_Hard class is a custom model class that hosts the neural network model. The class inherits from the tf.keras.Model class and is used to define the neural network model architecture and the training loop for FastVPINNs. @@ -64,9 +64,8 @@ class DenseModel(tf.keras.Model): This method is used to get the configuration of the model. train_step(beta=10, bilinear_params_dict=None) This method is used to define the training step of the model. - - """ + """ def __init__( self, @@ -81,14 +80,19 @@ def __init__( use_attention=False, activation="tanh", hessian=False, + hard_constraint_function=None, ): - super(DenseModel, self).__init__() + super(DenseModel_Hard, self).__init__() self.layer_dims = layer_dims self.use_attention = use_attention self.activation = activation self.layer_list = [] self.loss_function = loss_function self.hessian = hessian + if hard_constraint_function is None: + self.hard_constraint_function = lambda x, y: y + else: + self.hard_constraint_function = hard_constraint_function self.tensor_dtype = tensor_dtype @@ -208,21 +212,8 @@ def __init__( # def build(self, input_shape): # super(DenseModel, self).build(input_shape) - def apply_hard_boundary_constraints(self, inputs, x): - """ This method applies hard boundary constraints to the model. - :param inputs: Input tensor - :type inputs: tf.Tensor - :param x: Output tensor from the model - :type x: tf.Tensor - :return: Output tensor with hard boundary constraints - :rtype: tf.Tensor - """ - ansatz = tf.tanh(4.0*np.pi*inputs[:,0:1])*tf.tanh(4.0*np.pi*inputs[:,1:2])*tf.tanh(4.0*np.pi*(inputs[:,0:1]-1.0))*tf.tanh(4.0*np.pi*(inputs[:,1:2]-1.0)) - ansatz = tf.cast(ansatz, self.tensor_dtype) - return ansatz*x - def call(self, inputs): - """ This method is used to define the forward pass of the model. + """This method is used to define the forward pass of the model. :param inputs: Input tensor :type inputs: tf.Tensor :return: Output tensor from the model @@ -238,12 +229,12 @@ def call(self, inputs): for layer in self.layer_list: x = layer(x) - x = self.apply_hard_boundary_constraints(inputs, x) + x = self.hard_constraint_function(inputs, x) return x def get_config(self): - """ This method is used to get the configuration of the model. + """This method is used to get the configuration of the model. :return: Configuration of the model :rtype: dict """ @@ -271,7 +262,7 @@ def get_config(self): @tf.function def train_step(self, beta=10, bilinear_params_dict=None): - """ This method is used to define the training step of the mode. + """This method is used to define the training step of the mode. :param bilinear_params_dict: Dictionary containing the bilinear parameters :type bilinear_params_dict: dict :return: Dictionary containing the loss values @@ -328,8 +319,6 @@ def train_step(self, beta=10, bilinear_params_dict=None): # convert predicted_values_dirichlet to tf.float64 # predicted_values_dirichlet = tf.cast(predicted_values_dirichlet, tf.float64) - - # tf.print("Boundary Loss : ", boundary_loss) # tf.print("Boundary Loss Shape : ", boundary_loss.shape) # tf.print("Total PDE Loss : ", total_pde_loss) diff --git a/fastvpinns/model/model_inverse.py b/fastvpinns/model/model_inverse.py index e000286..d5cec5c 100644 --- a/fastvpinns/model/model_inverse.py +++ b/fastvpinns/model/model_inverse.py @@ -332,4 +332,4 @@ def train_step(self, beta=10, bilinear_params_dict=None): "loss": total_loss, "inverse_params": self.inverse_params_dict, "sensor_loss": sensor_loss, - } \ No newline at end of file + }