The current `check_two_equal` may not be able to detect the `NaN` case. We can improve it by using `std::isnan`.

When the incorrect code below is run in an M1 environment, the test still passes. The cause is that `diff` becomes `NaN`, and the final check does not detect the `NaN` case.

Test output produced by the incorrect code:

```
-------- Sanity check of simd_programming implementation: Passed! --------
Section, Total time(ms), Average time(ms), Count, GOPs
simd_programming, 1082.963989, 108.295998, 10, 2.420616
```

The incorrect code that passes the current test:

```c
for (int q = 0; q < num_block; q++) {
    // load 32x4bit (16 bytes) weight
    const uint8x16_t w0 = vld1q_u8(w_start);
    w_start += 16;

    /*
       We will accelerate the program using ARM Intrinsics.
       You can check the documentation of operations at:
       https://developer.arm.com/architectures/instruction-sets/intrinsics
    */
    // TODO: decode the lower and upper half of the weights as int8x16_t
    // Hint:
    // (1) use `vandq_u8` with the mask_low4bit to get the lower half
    // (2) use `vshrq_n_u8` to right shift 4 bits and get the upper half
    // (3) use `vreinterpretq_s8_u8` to interpret the vector as int8
    // lowbit mask
    const uint8x16_t mask_low4bit = vdupq_n_u8(0xf);
    int8x16_t w0_low = vreinterpretq_s8_u8(vandq_u8(w0, mask_low4bit));
    int8x16_t w0_high = vreinterpretq_s8_u8(vshrq_n_u8(w0, 4));

    // TODO: apply zero_point to weights and convert the range from (0, 15) to (-8, 7)
    // Hint: using `vsubq_s8` to the lower-half and upper-half vectors of weights
    const int8x16_t offsets = vdupq_n_s8(8);
    w0_low = vsubq_s8(w0_low, offsets);
    w0_high = vsubq_s8(w0_high, offsets);

    // load 32 8-bit activation
    const int8x16_t a0 = vld1q_s8(a_start);
    const int8x16_t a1 = vld1q_s8(a_start + 16);
    a_start += 32;

    // TODO: perform dot product and store the result into the intermediate sum, int_sum0
    // Hint: use `vdotq_s32` to compute sumv0 = a0 * lower-half weights + a1 * upper-half weights
    // int32x4 vector to store intermediate sum
    int32x4_t int_sum0;
    int_sum0 = vdupq_n_s32(0);
    sumv0 = vdotq_s32(int_sum0, w0_low, a0);
    sumv0 = vdotq_s32(int_sum0, w0_high, a1);

    float s_0 = *s_a++ * *s_w++;
    sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum0), s_0);
}
C->data_ptr[row * n + col] = vaddvq_f32(sumv0);
```
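To illustrate why a threshold-only check lets `NaN` through and how `std::isnan` closes the gap, here is a minimal sketch. The function names, signatures, and the `max_mean_sq_error` tolerance are assumptions made for illustration; they are not taken from the repository's `check_two_equal`, whose exact body is not shown in this commit message.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical threshold-only check (an assumption, not the repository's code).
// If any diff is NaN, sq_diff becomes NaN, the `>` comparison is false,
// and the function reports success.
bool check_two_equal_naive(const float* expected, const float* actual,
                           int size, float max_mean_sq_error) {
    assert(size > 0);
    float sq_diff = 0.0f;
    for (int i = 0; i < size; i++) {
        const float diff = expected[i] - actual[i];
        sq_diff += diff * diff;
    }
    if ((sq_diff / size) > max_mean_sq_error) return false;
    return true;
}

// Hypothetical NaN-aware variant: same mean-squared-error test,
// plus an explicit std::isnan rejection.
bool check_two_equal_nan_aware(const float* expected, const float* actual,
                               int size, float max_mean_sq_error) {
    assert(size > 0);
    float sq_diff = 0.0f;
    for (int i = 0; i < size; i++) {
        const float diff = expected[i] - actual[i];
        // Ordered comparisons with NaN are always false, so a threshold
        // test alone cannot flag it; reject NaN explicitly.
        if (std::isnan(diff)) return false;
        sq_diff += diff * diff;
    }
    return (sq_diff / size) <= max_mean_sq_error;
}
```

Calling both functions on a reference array and an output array that contains a `NaN` should make the naive version return `true` and the NaN-aware version return `false`. One caveat: compiler flags such as `-ffast-math` may allow the compiler to assume `NaN` never occurs, in which case the `std::isnan` check might be optimized away.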