Skip to content

Commit

Permalink
BUG: comm plan = operator
Browse files Browse the repository at this point in the history
  • Loading branch information
Adrian-Diaz committed Dec 11, 2024
1 parent 81c6c90 commit 251dd13
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 13 deletions.
36 changes: 23 additions & 13 deletions examples/test_tpetra_mesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ using namespace mtr; // matar namespace

struct mesh_data {
int num_dim = 3;
size_t nlocal_nodes, rnum_elem;
size_t num_nodes, num_elem;
size_t nlocal_nodes, rnum_elem; //local node and element count respectively
size_t num_nodes, num_elem; //global node and element count respectively
TpetraDFArray<double> node_coords_distributed; //unique local coords
TpetraDFArray<double> ghost_node_coords_distributed; //local data set by other processes
TpetraDFArray<double> all_node_coords_distributed; //unique + ghost
TpetraDFArray<long long int> nodes_in_elem_distributed;
TpetraDFArray<long long int> nodes_in_elem_distributed; //element node connectivity
TpetraCommunicationPlan<real_t> ghost_comms; //comms plan to update ghost data
};

void setup_maps(mesh_data &mesh);
Expand Down Expand Up @@ -99,9 +100,10 @@ int main(int argc, char* argv[])
void run_test(mesh_data &mesh)
{
int num_dim = mesh.num_dim;
TpetraDFArray<double> all_node_coords_distributed = mesh.all_node_coords_distributed;
TpetraDFArray<double> node_coords_distributed = mesh.node_coords_distributed;
TpetraDFArray<long long int> nodes_in_elem_distributed = mesh.nodes_in_elem_distributed;
TpetraPartitionMap<> all_node_map = mesh.nodes_in_elem_distributed.pmap;
size_t nlocal_nodes = mesh.nlocal_nodes;
int ntimesteps = 1000;
real_t constant_velocity = 0.0001;
real_t timestep = 0.001;
Expand All @@ -112,12 +114,16 @@ void run_test(mesh_data &mesh)
FOR_ALL(ielem,0,mesh.rnum_elem, {
for(int inode=0; inode < 8; inode++){
int local_node_index = nodes_in_elem_distributed(ielem,inode);
for(int idim=0; idim < num_dim; idim++){
all_node_coords_distributed(local_node_index, idim) += constant_velocity*timestep;
if(local_node_index < nlocal_nodes){
for(int idim=0; idim < num_dim; idim++){
node_coords_distributed(local_node_index, idim) += constant_velocity*timestep;
}
}
}
});
}
//update ghosts
mesh.ghost_comms.execute_comms();
}

/* ----------------------------------------------------------------------
Expand Down Expand Up @@ -264,20 +270,24 @@ void setup_maps(mesh_data &mesh)
// create distributed multivector of the ghost node coords as a subview of the all vector
mesh.ghost_node_coords_distributed = TpetraDFArray<double>(mesh.all_node_coords_distributed, ghost_node_map, nlocal_nodes);

// create communication object between ghosts and unique local data
TpetraCommunicationPlan<real_t> ghost_comms(mesh.ghost_node_coords_distributed, mesh.node_coords_distributed);

// comms to get ghosts coords initialized
ghost_comms.execute_comms();

//initialize 0:nlocal-1 data in the all vector since the comms just set nlocal:nall via the subview
//initialize 0:nlocal-1 data in the all vector
FOR_ALL(inode,0,nlocal_nodes, {
for (int idim=0; idim < num_dim; idim++){
mesh.all_node_coords_distributed(inode,idim) = mesh.node_coords_distributed(inode,idim);
}
});
mesh.all_node_coords_distributed.update_host();

//set local node array to be a subview of the all node array to avoid carrying duplicate memory
mesh.node_coords_distributed = TpetraDFArray<double>(mesh.all_node_coords_distributed, map, 0);
node_coords_distributed = mesh.node_coords_distributed; //reset local variable

// create communication object between ghosts and unique local data
mesh.ghost_comms = TpetraCommunicationPlan<real_t>(mesh.ghost_node_coords_distributed, mesh.node_coords_distributed);

// comms to get ghosts coords initialized
mesh.ghost_comms.execute_comms();

//convert nodes in elem to local node ids to avoid excessive map conversion calls
FOR_ALL(ielem,0,mesh.rnum_elem, {
for(int inode=0; inode < 8; inode++){
Expand Down
12 changes: 12 additions & 0 deletions src/include/tpetra_wrapper_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -3717,6 +3717,12 @@ TpetraCommunicationPlan<T,Layout,ExecSpace,MemoryTraits>& TpetraCommunicationPla
combine_mode_ = temp.combine_mode_;
destination_vector_ = temp.destination_vector_;
source_vector_ = temp.source_vector_;
if(reverse_comms_flag){
exporter = temp.exporter;
}
else{
importer = temp.importer;
}
}

return *this;
Expand Down Expand Up @@ -3833,6 +3839,12 @@ TpetraLRCommunicationPlan<T,Layout,ExecSpace,MemoryTraits>& TpetraLRCommunicatio
combine_mode_ = temp.combine_mode_;
destination_vector_ = temp.destination_vector_;
source_vector_ = temp.source_vector_;
if(reverse_comms_flag){
exporter = temp.exporter;
}
else{
importer = temp.importer;
}
}

return *this;
Expand Down

0 comments on commit 251dd13

Please sign in to comment.