Skip to content

Commit

Permalink
Bug fix for sycl-shared-basis
Browse files Browse the repository at this point in the history
  • Loading branch information
uumesh committed Aug 8, 2023
1 parent ca6301b commit b6c8abe
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions include/ceed/jit-source/sycl/sycl-shared-basis-tensor-templates.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,21 +284,21 @@ inline void ContractX3d(const CeedInt P_1D, const CeedInt Q_1D, private const Ce
private CeedScalar *restrict V, local CeedScalar *restrict scratch) {
const CeedInt item_id_x = get_local_id(0);
const CeedInt item_id_y = get_local_id(1) % T_1D;
const CeedInt item_id_z = get_local_id(2) / T_1D;
const CeedInt item_id_z = get_local_id(1) / T_1D;

// CeedScalar r_B[T_1D];
// for (CeedInt i = 0; i < P_1D; i++) {
// r_B[i] = B[i + item_id_x * P_1D];
// }

// for (CeedInt k = 0; k < P_1D; k++) {
scratch[item_id_x + T_1D * (item_id_y + T_1D * item_id_z)] = U[0];
scratch[item_id_x + T_1D * (item_id_y + T_1D * item_id_z)] = *U;
work_group_barrier(CLK_LOCAL_MEM_FENCE);

V[0] = 0.0;
*V = 0.0;
if (item_id_x < Q_1D && item_id_y < P_1D && item_id_z < P_1D) {
for (CeedInt i = 0; i < P_1D; i++) {
V[0] += B[i + item_id_x * P_1D] * scratch[i + T_1D * (item_id_y + T_1D * item_id_z)]; // Contract x direction
*V += B[i + item_id_x * P_1D] * scratch[i + T_1D * (item_id_y + T_1D * item_id_z)]; // Contract x direction
}
}
work_group_barrier(CLK_LOCAL_MEM_FENCE);
Expand All @@ -320,13 +320,13 @@ inline void ContractY3d(const CeedInt P_1D, const CeedInt Q_1D, private const Ce
// }

// for (CeedInt k = 0; k < P_1D; k++) {
scratch[item_id_x + T_1D * (item_id_y + T_1D * item_id_z)] = U[0];
scratch[item_id_x + T_1D * (item_id_y + T_1D * item_id_z)] = *U;
work_group_barrier(CLK_LOCAL_MEM_FENCE);

V[0] = 0.0;
*V = 0.0;
if (item_id_x < Q_1D && item_id_y < Q_1D && item_id_z < P_1D) {
for (CeedInt i = 0; i < P_1D; i++) {
V[0] += B[i + item_id_y * P_1D] * scratch[item_id_x + T_1D * (i + T_1D * item_id_z)]; // Contract y direction
*V += B[i + item_id_y * P_1D] * scratch[item_id_x + T_1D * (i + T_1D * item_id_z)]; // Contract y direction
}
}
work_group_barrier(CLK_LOCAL_MEM_FENCE);
Expand All @@ -342,14 +342,14 @@ inline void ContractZ3d(const CeedInt P_1D, const CeedInt Q_1D, private const Ce
const CeedInt item_id_y = get_local_id(1) % T_1D;
const CeedInt item_id_z = get_local_id(1) / T_1D;

scratch[item_id_x + T_1D * (item_id_y + T_1D * item_id_z)] = U[0];
scratch[item_id_x + T_1D * (item_id_y + T_1D * item_id_z)] = *U;
work_group_barrier(CLK_LOCAL_MEM_FENCE);

// for (CeedInt k = 0; k < Q_1D; k++) {
V[0] = 0.0;
*V = 0.0;
if (item_id_x < Q_1D && item_id_y < Q_1D && item_id_z < Q_1D) {
for (CeedInt i = 0; i < P_1D; i++) {
V[0] += B[i + item_id_z * P_1D] * scratch[item_id_x + T_1D * (item_id_y + T_1D * i)]; // Contract z direction
*V += B[i + item_id_z * P_1D] * scratch[item_id_x + T_1D * (item_id_y + T_1D * i)]; // Contract z direction
}
}
// }
Expand All @@ -365,14 +365,14 @@ inline void ContractTransposeZ3d(const CeedInt P_1D, const CeedInt Q_1D, private
const CeedInt item_id_y = get_local_id(1) % T_1D;
const CeedInt item_id_z = get_local_id(1) / T_1D;

scratch[item_id_x + T_1D * (item_id_y + T_1D * item_id_z)] = U[0];
scratch[item_id_x + T_1D * (item_id_y + T_1D * item_id_z)] = *U;
work_group_barrier(CLK_LOCAL_MEM_FENCE);

// for (CeedInt k = 0; k < P_1D; k++) {
V[0] = 0.0;
*V = 0.0;
if (item_id_x < Q_1D && item_id_y < Q_1D && item_id_z < P_1D) {
for (CeedInt i = 0; i < Q_1D; i++) {
V[0] += B[item_id_z + i * P_1D] * scratch[item_id_x + T_1D * (item_id_y + T_1D * i)]; // Contract z direction
*V += B[item_id_z + i * P_1D] * scratch[item_id_x + T_1D * (item_id_y + T_1D * i)]; // Contract z direction
}
}
// }
Expand All @@ -394,13 +394,13 @@ inline void ContractTransposeY3d(const CeedInt P_1D, const CeedInt Q_1D, private
// }

// for (CeedInt k = 0; k < P_1D; k++) {
scratch[item_id_x + T_1D * (item_id_y + T_1D * item_id_z)] = U[0];
scratch[item_id_x + T_1D * (item_id_y + T_1D * item_id_z)] = *U;
work_group_barrier(CLK_LOCAL_MEM_FENCE);

V[0] = 0.0;
*V = 0.0;
if (item_id_x < Q_1D && item_id_y < P_1D && item_id_z < P_1D) {
for (CeedInt i = 0; i < Q_1D; i++) {
V[0] += B[item_id_y + i * P_1D] * scratch[item_id_x + T_1D * (i + T_1D * item_id_z)]; // Contract y direction
*V += B[item_id_y + i * P_1D] * scratch[item_id_x + T_1D * (i + T_1D * item_id_z)]; // Contract y direction
}
}
work_group_barrier(CLK_LOCAL_MEM_FENCE);
Expand All @@ -422,12 +422,12 @@ inline void ContractTransposeAddY3d(const CeedInt P_1D, const CeedInt Q_1D, priv
// }

// for (CeedInt k = 0; k < P_1D; k++) {
scratch[item_id_x + T_1D * (item_id_y + T_1D * item_id_z)] = U[0];
scratch[item_id_x + T_1D * (item_id_y + T_1D * item_id_z)] = *U;
work_group_barrier(CLK_LOCAL_MEM_FENCE);

if (item_id_x < Q_1D && item_id_y < P_1D && item_id_z < P_1D) {
for (CeedInt i = 0; i < Q_1D; i++) {
V[0] += B[item_id_y + i * P_1D] * scratch[item_id_x + T_1D * (i + T_1D * item_id_z)]; // Contract y direction
*V += B[item_id_y + i * P_1D] * scratch[item_id_x + T_1D * (i + T_1D * item_id_z)]; // Contract y direction
}
}
work_group_barrier(CLK_LOCAL_MEM_FENCE);
Expand All @@ -449,13 +449,13 @@ inline void ContractTransposeX3d(const CeedInt P_1D, const CeedInt Q_1D, private
// }

// for (CeedInt k = 0; k < P_1D; k++) {
scratch[item_id_x + T_1D * (item_id_y + T_1D * item_id_z)] = U[0];
scratch[item_id_x + T_1D * (item_id_y + T_1D * item_id_z)] = *U;
work_group_barrier(CLK_LOCAL_MEM_FENCE);

V[0] = 0.0;
*V = 0.0;
if (item_id_x < P_1D && item_id_y < P_1D && item_id_z < P_1D) {
for (CeedInt i = 0; i < Q_1D; i++) {
V[0] += B[item_id_x + i * P_1D] * scratch[i + T_1D * (item_id_y + T_1D * item_id_z)]; // Contract x direction
*V += B[item_id_x + i * P_1D] * scratch[i + T_1D * (item_id_y + T_1D * item_id_z)]; // Contract x direction
}
}
work_group_barrier(CLK_LOCAL_MEM_FENCE);
Expand All @@ -477,12 +477,12 @@ inline void ContractTransposeAddX3d(const CeedInt P_1D, const CeedInt Q_1D, priv
// }

// for (CeedInt k = 0; k < P_1D; k++) {
scratch[item_id_x + T_1D * (item_id_y + T_1D * item_id_z)] = U[0];
scratch[item_id_x + T_1D * (item_id_y + T_1D * item_id_z)] = *U;
work_group_barrier(CLK_LOCAL_MEM_FENCE);

if (item_id_x < P_1D && item_id_y < P_1D && item_id_z < P_1D) {
for (CeedInt i = 0; i < Q_1D; i++) {
V[0] += B[item_id_x + i * P_1D] * scratch[i + T_1D * (item_id_y + T_1D * item_id_z)]; // Contract x direction
*V += B[item_id_x + i * P_1D] * scratch[i + T_1D * (item_id_y + T_1D * item_id_z)]; // Contract x direction
}
}
work_group_barrier(CLK_LOCAL_MEM_FENCE);
Expand Down

0 comments on commit b6c8abe

Please sign in to comment.