From 251dd13ad1eb305ab4bc63a7343574948a85a8b5 Mon Sep 17 00:00:00 2001 From: Adrian-Diaz Date: Tue, 10 Dec 2024 19:13:59 -0700 Subject: [PATCH] BUG: comm plan = operator --- examples/test_tpetra_mesh.cpp | 36 +++++++++++++++++++----------- src/include/tpetra_wrapper_types.h | 12 ++++++++++ 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/examples/test_tpetra_mesh.cpp b/examples/test_tpetra_mesh.cpp index 0aae46de..0b7d21af 100644 --- a/examples/test_tpetra_mesh.cpp +++ b/examples/test_tpetra_mesh.cpp @@ -45,12 +45,13 @@ using namespace mtr; // matar namespace struct mesh_data { int num_dim = 3; - size_t nlocal_nodes, rnum_elem; - size_t num_nodes, num_elem; + size_t nlocal_nodes, rnum_elem; //local node and element count respectively + size_t num_nodes, num_elem; //global node and element count respectively TpetraDFArray node_coords_distributed; //unique local coords TpetraDFArray ghost_node_coords_distributed; //local data set by other processes TpetraDFArray all_node_coords_distributed; //unique + ghost - TpetraDFArray nodes_in_elem_distributed; + TpetraDFArray nodes_in_elem_distributed; //element node connectivity + TpetraCommunicationPlan ghost_comms; //comms plan to update ghost data }; void setup_maps(mesh_data &mesh); @@ -99,9 +100,10 @@ int main(int argc, char* argv[]) void run_test(mesh_data &mesh) { int num_dim = mesh.num_dim; - TpetraDFArray all_node_coords_distributed = mesh.all_node_coords_distributed; + TpetraDFArray node_coords_distributed = mesh.node_coords_distributed; TpetraDFArray nodes_in_elem_distributed = mesh.nodes_in_elem_distributed; TpetraPartitionMap<> all_node_map = mesh.nodes_in_elem_distributed.pmap; + size_t nlocal_nodes = mesh.nlocal_nodes; int ntimesteps = 1000; real_t constant_velocity = 0.0001; real_t timestep = 0.001; @@ -112,12 +114,16 @@ void run_test(mesh_data &mesh) FOR_ALL(ielem,0,mesh.rnum_elem, { for(int inode=0; inode < 8; inode++){ int local_node_index = nodes_in_elem_distributed(ielem,inode); - for(int idim=0; idim < num_dim; idim++){ - all_node_coords_distributed(local_node_index, idim) += constant_velocity*timestep; + if(local_node_index < nlocal_nodes){ + for(int idim=0; idim < num_dim; idim++){ + node_coords_distributed(local_node_index, idim) += constant_velocity*timestep; + } } } }); } + //update ghosts + mesh.ghost_comms.execute_comms(); } /* ---------------------------------------------------------------------- @@ -264,13 +270,7 @@ void setup_maps(mesh_data &mesh) // create distributed multivector of the ghost node coords as a subview of the all vector mesh.ghost_node_coords_distributed = TpetraDFArray(mesh.all_node_coords_distributed, ghost_node_map, nlocal_nodes); - // create communication object between ghosts and unique local data - TpetraCommunicationPlan ghost_comms(mesh.ghost_node_coords_distributed, mesh.node_coords_distributed); - - // comms to get ghosts coords initialized - ghost_comms.execute_comms(); - - //initialize 0:nlocal-1 data in the all vector since the comms just set nlocal:nall via the subview + //initialize 0:nlocal-1 data in the all vector FOR_ALL(inode,0,nlocal_nodes, { for (int idim=0; idim < num_dim; idim++){ mesh.all_node_coords_distributed(inode,idim) = mesh.node_coords_distributed(inode,idim); @@ -278,6 +278,16 @@ void setup_maps(mesh_data &mesh) }); mesh.all_node_coords_distributed.update_host(); + //set local node array to be a subview of the all node array to avoid carrying duplicate memory + mesh.node_coords_distributed = TpetraDFArray(mesh.all_node_coords_distributed, map, 0); + node_coords_distributed = mesh.node_coords_distributed; //reset local variable + + // create communication object between ghosts and unique local data + mesh.ghost_comms = TpetraCommunicationPlan(mesh.ghost_node_coords_distributed, mesh.node_coords_distributed); + + // comms to get ghosts coords initialized + mesh.ghost_comms.execute_comms(); + //convert nodes in elem to local node ids to avoid excessive map conversion calls FOR_ALL(ielem,0,mesh.rnum_elem, { for(int inode=0; inode < 8; inode++){ diff --git a/src/include/tpetra_wrapper_types.h b/src/include/tpetra_wrapper_types.h index 9671d2c3..b73618dc 100644 --- a/src/include/tpetra_wrapper_types.h +++ b/src/include/tpetra_wrapper_types.h @@ -3717,6 +3717,12 @@ TpetraCommunicationPlan& TpetraCommunicationPla combine_mode_ = temp.combine_mode_; destination_vector_ = temp.destination_vector_; source_vector_ = temp.source_vector_; + if(reverse_comms_flag){ + exporter = temp.exporter; + } + else{ + importer = temp.importer; + } } return *this; @@ -3833,6 +3839,12 @@ TpetraLRCommunicationPlan& TpetraLRCommunicatio combine_mode_ = temp.combine_mode_; destination_vector_ = temp.destination_vector_; source_vector_ = temp.source_vector_; + if(reverse_comms_flag){ + exporter = temp.exporter; + } + else{ + importer = temp.importer; + } } return *this;