Skip to content

Commit

Permalink
Improve checking to detect NaN case
Browse files Browse the repository at this point in the history
The current `check_two_equal` may not be able to detect the `NaN` case. We can improve it by using `std::isnan`.

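When the following incorrect code is run in an M1 environment, the test still passes.
This happens because `diff` becomes `NaN`, so the accumulated `sq_diff` is `NaN` as well, and the final check does not detect it: any ordered comparison with `NaN` evaluates to false, so `(sq_diff / size) > error` is never true.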

Below are the test output and the incorrect code, which nevertheless passes the current test.

```
-------- Sanity check of simd_programming implementation: Passed! --------
Section, Total time(ms), Average time(ms), Count, GOPs
simd_programming, 1082.963989, 108.295998, 10, 2.420616
```

```c

            for (int q = 0; q < num_block; q++) {
                // load 32x4bit (16 bytes) weight
                const uint8x16_t w0 = vld1q_u8(w_start);
                w_start += 16;

                /*
                   We will accelerate the program using ARM Intrinsics. You can check the documentation of operations
                   at: https://developer.arm.com/architectures/instruction-sets/intrinsics
                */
                // TODO: decode the lower and upper half of the weights as int8x16_t
                // Hint:
                // (1) use `vandq_u8` with the mask_low4bit to get the lower half
                // (2) use `vshrq_n_u8` to right shift 4 bits and get the upper half
                // (3) use `vreinterpretq_s8_u8` to interpret the  vector as int8
                // lowbit mask
                const uint8x16_t mask_low4bit = vdupq_n_u8(0xf);

                int8x16_t w0_low = vreinterpretq_s8_u8(vandq_u8(w0, mask_low4bit));
                int8x16_t w0_high = vreinterpretq_s8_u8(vshrq_n_u8(w0, 4));

                // TODO: apply zero_point to weights and convert the range from (0, 15) to (-8, 7)
                // Hint: using `vsubq_s8` to the lower-half and upper-half vectors of weights
                const int8x16_t offsets = vdupq_n_s8(8);

                w0_low = vsubq_s8(w0_low, offsets);
                w0_high = vsubq_s8(w0_high, offsets);

                // load 32 8-bit activation
                const int8x16_t a0 = vld1q_s8(a_start);
                const int8x16_t a1 = vld1q_s8(a_start + 16);
                a_start += 32;

                // TODO: perform dot product and store the result into the intermediate sum, int_sum0
                // Hint: use `vdotq_s32` to compute sumv0 = a0 * lower-half weights + a1 * upper-half weights
                // int32x4 vector to store intermediate sum
                int32x4_t int_sum0;

                int_sum0 = vdupq_n_s32(0);

                sumv0 = vdotq_s32(int_sum0, w0_low, a0);
                sumv0 = vdotq_s32(int_sum0, w0_high, a1);

                float s_0 = *s_a++ * *s_w++;
                sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum0), s_0);
            }
            C->data_ptr[row * n + col] = vaddvq_f32(sumv0);

```
  • Loading branch information
insop committed Jan 1, 2024
1 parent 025f96e commit e4effcb
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions transformer/src/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ bool check_two_equal(float* array, float* array2, int size, float error) {
error_info.a2 = array2[i];
}
}
if ((sq_diff / size) > error) {

if (((sq_diff / size) > error) || (std::isnan(sq_diff))) {
return false;
}
return true;
Expand All @@ -64,7 +65,7 @@ bool check_two_equal(T* array, T* array2, int size) {
return false;
}
}
if ((sq_diff / size) > ERROR_MAX) {
if (((sq_diff / size) > ERROR_MAX) || (std::isnan(sq_diff))) {
std::cout << "MSE:" << sq_diff / size << ", MAX SQ diff:" << max_sqdiff << std::endl;
return false;
}
Expand All @@ -80,7 +81,7 @@ bool check_two_equal<int8_t>(int8_t* array, int8_t* array2, int size) {
sq_diff += diff * diff;
if (diff * diff > max_sqdiff) max_sqdiff = diff * diff;
}
if ((sq_diff / size) > INT_ERROR_MAX) {
if (((sq_diff / size) > INT_ERROR_MAX) || (std::isnan(sq_diff))) {
std::cout << "MSE:" << sq_diff / size << ", MAX SQ diff:" << max_sqdiff << std::endl;
return false;
}
Expand Down

0 comments on commit e4effcb

Please sign in to comment.