Skip to content

Commit

Permalink
Merge pull request CEED#1435 from CEED/zach/gpu-assemble-diag-basis-none
Browse files: browse the repository at this point in the history
  • Loading branch information
zatkins-dev authored Dec 22, 2023
2 parents (de27559 + 91db28b); commit 9ee4c00
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 37 deletions.
39 changes: 21 additions & 18 deletions backends/cuda-ref/ceed-cuda-ref-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -809,25 +809,28 @@ static inline int CeedOperatorAssembleDiagonalCore_Cuda(CeedOperator op, CeedVec
}
CeedCallBackend(CeedVectorSetValue(elem_diag, 0.0));

// Assemble element operator diagonals
CeedCallBackend(CeedVectorGetArray(elem_diag, CEED_MEM_DEVICE, &elem_diag_array));
CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array));
CeedCallBackend(CeedElemRestrictionGetNumElements(diag_rstr, &num_elem));

// Compute the diagonal of B^T D B
int elem_per_block = 1;
int grid = num_elem / elem_per_block + ((num_elem / elem_per_block * elem_per_block < num_elem) ? 1 : 0);
void *args[] = {(void *)&num_elem, &diag->d_identity, &diag->d_interp_in, &diag->d_grad_in, &diag->d_interp_out,
&diag->d_grad_out, &diag->d_e_mode_in, &diag->d_e_mode_out, &assembled_qf_array, &elem_diag_array};
if (is_point_block) {
CeedCallBackend(CeedRunKernelDim_Cuda(ceed, diag->linearPointBlock, grid, diag->num_nodes, 1, elem_per_block, args));
} else {
CeedCallBackend(CeedRunKernelDim_Cuda(ceed, diag->linearDiagonal, grid, diag->num_nodes, 1, elem_per_block, args));
}
// Only assemble diagonal if the basis has nodes, otherwise inputs are null pointers
if (diag->num_nodes > 0) {
// Assemble element operator diagonals
CeedCallBackend(CeedVectorGetArray(elem_diag, CEED_MEM_DEVICE, &elem_diag_array));
CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array));
CeedCallBackend(CeedElemRestrictionGetNumElements(diag_rstr, &num_elem));

// Compute the diagonal of B^T D B
int elem_per_block = 1;
int grid = num_elem / elem_per_block + ((num_elem / elem_per_block * elem_per_block < num_elem) ? 1 : 0);
void *args[] = {(void *)&num_elem, &diag->d_identity, &diag->d_interp_in, &diag->d_grad_in, &diag->d_interp_out,
&diag->d_grad_out, &diag->d_e_mode_in, &diag->d_e_mode_out, &assembled_qf_array, &elem_diag_array};
if (is_point_block) {
CeedCallBackend(CeedRunKernelDim_Cuda(ceed, diag->linearPointBlock, grid, diag->num_nodes, 1, elem_per_block, args));
} else {
CeedCallBackend(CeedRunKernelDim_Cuda(ceed, diag->linearDiagonal, grid, diag->num_nodes, 1, elem_per_block, args));
}

// Restore arrays
CeedCallBackend(CeedVectorRestoreArray(elem_diag, &elem_diag_array));
CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array));
// Restore arrays
CeedCallBackend(CeedVectorRestoreArray(elem_diag, &elem_diag_array));
CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array));
}

// Assemble local operator diagonal
CeedCallBackend(CeedElemRestrictionApply(diag_rstr, CEED_TRANSPOSE, elem_diag, assembled, request));
Expand Down
41 changes: 22 additions & 19 deletions backends/hip-ref/ceed-hip-ref-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -812,26 +812,29 @@ static inline int CeedOperatorAssembleDiagonalCore_Hip(CeedOperator op, CeedVect
}
CeedCallBackend(CeedVectorSetValue(elem_diag, 0.0));

// Assemble element operator diagonals
CeedCallBackend(CeedVectorGetArray(elem_diag, CEED_MEM_DEVICE, &elem_diag_array));
CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array));
CeedCallBackend(CeedElemRestrictionGetNumElements(diag_rstr, &num_elem));

// Compute the diagonal of B^T D B
int elem_per_block = 1;
int grid = num_elem / elem_per_block + ((num_elem / elem_per_block * elem_per_block < num_elem) ? 1 : 0);
void *args[] = {(void *)&num_elem, &diag->d_identity, &diag->d_interp_in, &diag->d_grad_in, &diag->d_interp_out,
&diag->d_grad_out, &diag->d_e_mode_in, &diag->d_e_mode_out, &assembled_qf_array, &elem_diag_array};

if (is_point_block) {
CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->linearPointBlock, grid, diag->num_modes, 1, elem_per_block, args));
} else {
CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->linearDiagonal, grid, diag->num_modes, 1, elem_per_block, args));
}
// Only assemble diagonal if the basis has nodes, otherwise inputs are null pointers
if (diag->num_modes > 0) {
// Assemble element operator diagonals
CeedCallBackend(CeedVectorGetArray(elem_diag, CEED_MEM_DEVICE, &elem_diag_array));
CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array));
CeedCallBackend(CeedElemRestrictionGetNumElements(diag_rstr, &num_elem));

// Compute the diagonal of B^T D B
int elem_per_block = 1;
int grid = num_elem / elem_per_block + ((num_elem / elem_per_block * elem_per_block < num_elem) ? 1 : 0);
void *args[] = {(void *)&num_elem, &diag->d_identity, &diag->d_interp_in, &diag->d_grad_in, &diag->d_interp_out,
&diag->d_grad_out, &diag->d_e_mode_in, &diag->d_e_mode_out, &assembled_qf_array, &elem_diag_array};

if (is_point_block) {
CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->linearPointBlock, grid, diag->num_modes, 1, elem_per_block, args));
} else {
CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->linearDiagonal, grid, diag->num_modes, 1, elem_per_block, args));
}

// Restore arrays
CeedCallBackend(CeedVectorRestoreArray(elem_diag, &elem_diag_array));
CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array));
// Restore arrays
CeedCallBackend(CeedVectorRestoreArray(elem_diag, &elem_diag_array));
CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array));
}

// Assemble local operator diagonal
CeedCallBackend(CeedElemRestrictionApply(diag_rstr, CEED_TRANSPOSE, elem_diag, assembled, request));
Expand Down

0 comments on commit 9ee4c00

Please sign in to comment.