diff --git a/transformer/src/utils.cc b/transformer/src/utils.cc index d496fd2..18d440a 100644 --- a/transformer/src/utils.cc +++ b/transformer/src/utils.cc @@ -44,7 +44,8 @@ bool check_two_equal(float* array, float* array2, int size, float error) { error_info.a2 = array2[i]; } } - if ((sq_diff / size) > error) { + + if (((sq_diff / size) > error) || (std::isnan(sq_diff))) { return false; } return true; @@ -64,7 +65,7 @@ bool check_two_equal(T* array, T* array2, int size) { return false; } } - if ((sq_diff / size) > ERROR_MAX) { + if (((sq_diff / size) > ERROR_MAX) || (std::isnan(sq_diff))) { std::cout << "MSE:" << sq_diff / size << ", MAX SQ diff:" << max_sqdiff << std::endl; return false; } @@ -80,7 +81,7 @@ bool check_two_equal(int8_t* array, int8_t* array2, int size) { sq_diff += diff * diff; if (diff * diff > max_sqdiff) max_sqdiff = diff * diff; } - if ((sq_diff / size) > INT_ERROR_MAX) { + if (((sq_diff / size) > INT_ERROR_MAX) || (std::isnan(sq_diff))) { std::cout << "MSE:" << sq_diff / size << ", MAX SQ diff:" << max_sqdiff << std::endl; return false; }