Skip to content

Commit

Permalink
accumulate_s
Browse files Browse the repository at this point in the history
  • Loading branch information
brucefan1983 committed Sep 21, 2024
1 parent 5bafbe8 commit c4f0b75
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 31 deletions.
8 changes: 4 additions & 4 deletions src/force/nep3.cu
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ static __global__ void find_descriptor(
weight_left +
g_gn_angular[(index_right * paramb.num_types_sq + t12) * (paramb.n_max_angular + 1) + n] *
weight_right;
accumulate_s(d12, x12, y12, z12, gn12, s);
accumulate_s(paramb.L_max, d12, x12, y12, z12, gn12, s);
#else
float fc12;
int t2 = g_type[n2];
Expand All @@ -674,7 +674,7 @@ static __global__ void find_descriptor(
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
gn12 += fn12[k] * annmb.c[c_index];
}
accumulate_s(d12, x12, y12, z12, gn12, s);
accumulate_s(paramb.L_max, d12, x12, y12, z12, gn12, s);
#endif
}
find_q(paramb.L_max, paramb.num_L, paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
Expand Down Expand Up @@ -1624,7 +1624,7 @@ static __global__ void find_descriptor(
weight_left +
g_gn_angular[(index_right * paramb.num_types_sq + t12) * (paramb.n_max_angular + 1) + n] *
weight_right;
accumulate_s(d12, x12, y12, z12, gn12, s);
accumulate_s(paramb.L_max, d12, x12, y12, z12, gn12, s);
#else
float fc12;
int t2 = g_type[n2];
Expand All @@ -1646,7 +1646,7 @@ static __global__ void find_descriptor(
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
gn12 += fn12[k] * annmb.c[c_index];
}
accumulate_s(d12, x12, y12, z12, gn12, s);
accumulate_s(paramb.L_max, d12, x12, y12, z12, gn12, s);
#endif
}
find_q(paramb.L_max, paramb.num_L, paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
Expand Down
8 changes: 4 additions & 4 deletions src/force/nep3_multigpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ static __global__ void find_descriptor(
weight_left +
g_gn_angular[(index_right * paramb.num_types_sq + t12) * (paramb.n_max_angular + 1) + n] *
weight_right;
accumulate_s(d12, x12, y12, z12, gn12, s);
accumulate_s(paramb.L_max, d12, x12, y12, z12, gn12, s);
#else
float fc12;
int t2 = g_type[n2];
Expand All @@ -946,7 +946,7 @@ static __global__ void find_descriptor(
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
gn12 += fn12[k] * annmb.c[c_index];
}
accumulate_s(d12, x12, y12, z12, gn12, s);
accumulate_s(paramb.L_max, d12, x12, y12, z12, gn12, s);
#endif
}
find_q(paramb.L_max, paramb.num_L, paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
Expand Down Expand Up @@ -2026,7 +2026,7 @@ static __global__ void find_descriptor(
weight_left +
g_gn_angular[(index_right * paramb.num_types_sq + t12) * (paramb.n_max_angular + 1) + n] *
weight_right;
accumulate_s(d12, x12, y12, z12, gn12, s);
accumulate_s(paramb.L_max, d12, x12, y12, z12, gn12, s);
#else
float fc12;
int t2 = g_type[n2];
Expand All @@ -2048,7 +2048,7 @@ static __global__ void find_descriptor(
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
gn12 += fn12[k] * annmb.c[c_index];
}
accumulate_s(d12, x12, y12, z12, gn12, s);
accumulate_s(paramb.L_max, d12, x12, y12, z12, gn12, s);
#endif
}
find_q(paramb.L_max, paramb.num_L, paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
Expand Down
8 changes: 4 additions & 4 deletions src/force/nep3_small_box.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ static __global__ void find_descriptor_small_box(
weight_left +
g_gn_angular[(index_right * paramb.num_types_sq + t12) * (paramb.n_max_angular + 1) + n] *
weight_right;
accumulate_s(d12, r12[0], r12[1], r12[2], gn12, s);
accumulate_s(paramb.L_max, d12, r12[0], r12[1], r12[2], gn12, s);
#else
float fc12;
int t2 = g_type[n2];
Expand All @@ -281,7 +281,7 @@ static __global__ void find_descriptor_small_box(
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
gn12 += fn12[k] * annmb.c[c_index];
}
accumulate_s(d12, r12[0], r12[1], r12[2], gn12, s);
accumulate_s(paramb.L_max, d12, r12[0], r12[1], r12[2], gn12, s);
#endif
}
find_q(paramb.L_max, paramb.num_L, paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
Expand Down Expand Up @@ -427,7 +427,7 @@ static __global__ void find_descriptor_small_box(
weight_left +
g_gn_angular[(index_right * paramb.num_types_sq + t12) * (paramb.n_max_angular + 1) + n] *
weight_right;
accumulate_s(d12, r12[0], r12[1], r12[2], gn12, s);
accumulate_s(paramb.L_max, d12, r12[0], r12[1], r12[2], gn12, s);
#else
float fc12;
int t2 = g_type[n2];
Expand All @@ -449,7 +449,7 @@ static __global__ void find_descriptor_small_box(
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
gn12 += fn12[k] * annmb.c[c_index];
}
accumulate_s(d12, r12[0], r12[1], r12[2], gn12, s);
accumulate_s(paramb.L_max, d12, r12[0], r12[1], r12[2], gn12, s);
#endif
}
find_q(paramb.L_max, paramb.num_L, paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
Expand Down
2 changes: 1 addition & 1 deletion src/main_nep/nep3.cu
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ static __global__ void find_descriptors_angular(
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
gn12 += fn12[k] * annmb.c[c_index];
}
accumulate_s(d12, x12, y12, z12, gn12, s);
accumulate_s(paramb.L_max, d12, x12, y12, z12, gn12, s);
}
find_q(paramb.L_max, paramb.num_L, paramb.n_max_angular + 1, n, s, q);
for (int abc = 0; abc < NUM_OF_ABC; ++abc) {
Expand Down
2 changes: 1 addition & 1 deletion src/makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ CUDA_ARCH=-arch=sm_60
ifdef OS # For Windows with the cl.exe compiler
CFLAGS = -O3 $(CUDA_ARCH)
else # For linux
CFLAGS = -std=c++14 -O3 $(CUDA_ARCH)
CFLAGS = -std=c++14 -O3 $(CUDA_ARCH) -DDEBUG -DUSE_TABLE
endif
INC = -I./
LDFLAGS =
Expand Down
2 changes: 1 addition & 1 deletion src/mc/nep_energy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ static __global__ void find_energy_nep(
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
gn12 += fn12[k] * annmb.c[c_index];
}
accumulate_s(d12, r12[0], r12[1], r12[2], gn12, s);
accumulate_s(paramb.L_max, d12, r12[0], r12[1], r12[2], gn12, s);
}
find_q(paramb.L_max, paramb.num_L, paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
}
Expand Down
119 changes: 103 additions & 16 deletions src/utilities/nep_utilities.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -719,31 +719,63 @@ static __device__ __forceinline__ void accumulate_f12(
}

static __device__ __forceinline__ void
accumulate_s(const float d12, float x12, float y12, float z12, const float fn, float* s)
accumulate_s1(const float x12, const float y12, const float z12, const float fn, float* s)
{
float d12inv = 1.0f / d12;
x12 *= d12inv;
y12 *= d12inv;
z12 *= d12inv;
float x12sq = x12 * x12;
float y12sq = y12 * y12;
float z12sq = z12 * z12;
float x12sq_minus_y12sq = x12sq - y12sq;
s[0] += z12 * fn; // Y10
s[1] += x12 * fn; // Y11_real
s[2] += y12 * fn; // Y11_imag
}

static __device__ __forceinline__ void
accumulate_s2(
const float x12,
const float y12,
const float z12,
const float z12sq,
const float x12sq_minus_y12sq,
const float fn,
float* s)
{
s[3] += (3.0f * z12sq - 1.0f) * fn; // Y20
s[4] += x12 * z12 * fn; // Y21_real
s[5] += y12 * z12 * fn; // Y21_imag
s[6] += x12sq_minus_y12sq * fn; // Y22_real
s[7] += 2.0f * x12 * y12 * fn; // Y22_imag
s[8] += (5.0f * z12sq - 3.0f) * z12 * fn; // Y30
s[9] += (5.0f * z12sq - 1.0f) * x12 * fn; // Y31_real
s[10] += (5.0f * z12sq - 1.0f) * y12 * fn; // Y31_imag
s[11] += x12sq_minus_y12sq * z12 * fn; // Y32_real
s[12] += 2.0f * x12 * y12 * z12 * fn; // Y32_imag
s[13] += (x12 * x12 - 3.0f * y12 * y12) * x12 * fn; // Y33_real
s[14] += (3.0f * x12 * x12 - y12 * y12) * y12 * fn; // Y33_imag
}

// Accumulate the seven L = 3 angular terms (Y30 ... Y33) of one neighbor into
// s[8] .. s[14]; every term is weighted by fn.
//
// x12, y12, z12      : vector components (accumulate_s passes them already
//                      scaled by 1/d12)
// x12sq, y12sq, z12sq: squares of those components, supplied by the caller
// x12sq_minus_y12sq  : x12sq - y12sq, supplied by the caller
// fn                 : weight applied to every term
// s                  : accumulator array; only indices 8..14 are touched
static __device__ __forceinline__ void accumulate_s3(
  const float x12,
  const float y12,
  const float z12,
  const float x12sq,
  const float y12sq,
  const float z12sq,
  const float x12sq_minus_y12sq,
  const float fn,
  float* s)
{
  // z-dependent polynomial factors for m = 0 and m = 1
  const float z_poly_m0 = 5.0f * z12sq - 3.0f;
  const float z_poly_m1 = 5.0f * z12sq - 1.0f;
  // real and imaginary parts of the m = 3 cubic in x and y
  const float cubic_real = x12sq - 3.0f * y12sq;
  const float cubic_imag = 3.0f * x12sq - y12sq;

  s[8] += z_poly_m0 * z12 * fn;          // Y30
  s[9] += z_poly_m1 * x12 * fn;          // Y31_real
  s[10] += z_poly_m1 * y12 * fn;         // Y31_imag
  s[11] += x12sq_minus_y12sq * z12 * fn; // Y32_real
  s[12] += 2.0f * x12 * y12 * z12 * fn;  // Y32_imag
  s[13] += cubic_real * x12 * fn;        // Y33_real
  s[14] += cubic_imag * y12 * fn;        // Y33_imag
}

static __device__ __forceinline__ void
accumulate_s4(
const float x12,
const float y12,
const float z12,
const float x12sq,
const float y12sq,
const float z12sq,
const float x12sq_minus_y12sq,
const float fn,
float* s)
{
s[15] += ((35.0f * z12sq - 30.0f) * z12sq + 3.0f) * fn; // Y40
s[16] += (7.0f * z12sq - 3.0f) * x12 * z12 * fn; // Y41_real
s[17] += (7.0f * z12sq - 3.0f) * y12 * z12 * fn; // Y41_iamg
Expand All @@ -755,6 +787,61 @@ accumulate_s(const float d12, float x12, float y12, float z12, const float fn, f
s[23] += (4.0f * x12 * y12 * x12sq_minus_y12sq) * fn; // Y44_imag
}

// Accumulate the L = 5 angular terms of one neighbor into s[24] .. s[32];
// every term is weighted by fn.
//
// x12, y12, z12      : vector components (accumulate_s passes them already
//                      scaled by 1/d12)
// x12sq, y12sq, z12sq: squares of those components, supplied by the caller
// x12sq_minus_y12sq  : x12sq - y12sq, supplied by the caller
// fn                 : weight applied to every term
// s                  : accumulator array; indices 24..32 are updated
//
// Each term follows the same unnormalized pattern as the lower-L helpers:
// a polynomial in z12sq (times z12 where needed) multiplied by the real or
// imaginary part of (x12 + i * y12)^m.
//
// NOTE(review): L = 5 also has an m = 5 pair (Y55_real / Y55_imag), which would
// occupy s[33] and s[34]. It is not accumulated here because the size of the
// caller's s buffer is not visible from this function -- confirm the buffer
// length before adding those two writes.
static __device__ __forceinline__ void accumulate_s5(
  const float x12,
  const float y12,
  const float z12,
  const float x12sq,
  const float y12sq,
  const float z12sq,
  const float x12sq_minus_y12sq,
  const float fn,
  float* s)
{
  const float z12sqsq = z12sq * z12sq;
  const float temp1 = 21.0f * z12sqsq - 14.0f * z12sq + 1.0f; // z-polynomial for m = 1
  const float temp2 = 3.0f * z12sq - 1.0f;                    // z-polynomial for m = 2
  const float temp3 = 9.0f * z12sq - 1.0f;                    // z-polynomial for m = 3
  s[24] += (63.0f * z12sqsq - 70.0f * z12sq + 15.0f) * z12 * fn;                       // Y50
  s[25] += temp1 * x12 * fn;                                                           // Y51_real
  s[26] += temp1 * y12 * fn;                                                           // Y51_imag
  s[27] += x12sq_minus_y12sq * temp2 * z12 * fn;                                       // Y52_real
  s[28] += 2.0f * x12 * y12 * temp2 * z12 * fn;                                        // Y52_imag
  // Previously placeholders that added a constant 1 regardless of fn.
  s[29] += (x12sq - 3.0f * y12sq) * x12 * temp3 * fn;                                  // Y53_real
  s[30] += (3.0f * x12sq - y12sq) * y12 * temp3 * fn;                                  // Y53_imag
  s[31] += (x12sq_minus_y12sq * x12sq_minus_y12sq - 4.0f * x12sq * y12sq) * z12 * fn; // Y54_real
  s[32] += (4.0f * x12 * y12 * x12sq_minus_y12sq) * z12 * fn;                          // Y54_imag
}

// Add one neighbor's contribution to the angular accumulators in s.
//
// The vector (x12, y12, z12) is scaled by 1 / d12, its component squares are
// formed once, and the per-L helpers accumulate_s1 ... accumulate_s5 are then
// invoked for every L from 1 up to L_max. L_max below 1 leaves s untouched;
// L_max above 5 behaves like 5. No guard against d12 == 0 is applied here.
//
// L_max : highest L for which terms are accumulated
// d12   : length used to normalize the vector
// fn    : weight forwarded to every helper
// s     : accumulator array updated in place by the helpers
static __device__ __forceinline__ void accumulate_s(
  const int L_max,
  const float d12,
  float x12,
  float y12,
  float z12,
  const float fn,
  float* s)
{
  const float inv_d12 = 1.0f / d12;
  const float ux = x12 * inv_d12;
  const float uy = y12 * inv_d12;
  const float uz = z12 * inv_d12;

  const float ux_sq = ux * ux;
  const float uy_sq = uy * uy;
  const float uz_sq = uz * uz;
  const float ux_sq_minus_uy_sq = ux_sq - uy_sq;

  // Each level is cumulative: stop as soon as the requested L_max is exceeded.
  if (L_max < 1) {
    return;
  }
  accumulate_s1(ux, uy, uz, fn, s);

  if (L_max < 2) {
    return;
  }
  accumulate_s2(ux, uy, uz, uz_sq, ux_sq_minus_uy_sq, fn, s);

  if (L_max < 3) {
    return;
  }
  accumulate_s3(ux, uy, uz, ux_sq, uy_sq, uz_sq, ux_sq_minus_uy_sq, fn, s);

  if (L_max < 4) {
    return;
  }
  accumulate_s4(ux, uy, uz, ux_sq, uy_sq, uz_sq, ux_sq_minus_uy_sq, fn, s);

  if (L_max < 5) {
    return;
  }
  accumulate_s5(ux, uy, uz, ux_sq, uy_sq, uz_sq, ux_sq_minus_uy_sq, fn, s);
}

static __device__ __forceinline__ void
find_q(
const int L_max,
Expand Down

0 comments on commit c4f0b75

Please sign in to comment.