diff --git a/common/include/common/float_comparison.hpp b/common/include/common/float_comparison.hpp index 1222ccc41..a8465a1d0 100644 --- a/common/include/common/float_comparison.hpp +++ b/common/include/common/float_comparison.hpp @@ -120,17 +120,27 @@ scalar_t clamp_to_limits(scalar_t v) { * Indicates the tolerated margin for relative differences */ template -inline scalar_t getRelativeErrorMargin() { +inline scalar_t getRelativeErrorMargin(const bool is_trsm) { /* Measured empirically with gemm. The dimensions of the matrices (even k) * don't seem to have an impact on the observed relative differences * In the cases where the relative error is relevant (non close to zero), * relative differences of up to 0.002 were observed for float */ - return static_cast(0.005); + scalar_t margin = 0.005; + if (is_trsm) { + const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX"); + if (en_joint_matrix != NULL && std::is_same::value && + *en_joint_matrix == '1') { + // increase error margin for mixed precision calculation + // for trsm operator. + margin = 0.009f; + } + } + return margin; } template <> -inline double getRelativeErrorMargin() { +inline double getRelativeErrorMargin(const bool) { /* Measured empirically with gemm. The dimensions of the matrices (even k) * don't seem to have an impact on the observed relative differences * In the cases where the relative error is relevant (non close to zero), @@ -142,7 +152,7 @@ inline double getRelativeErrorMargin() { #ifdef BLAS_DATA_TYPE_HALF template <> -inline cl::sycl::half getRelativeErrorMargin() { +inline cl::sycl::half getRelativeErrorMargin(const bool) { // Measured empirically with gemm return 0.05f; } @@ -152,16 +162,27 @@ inline cl::sycl::half getRelativeErrorMargin() { * scalars are close to 0) */ template -inline scalar_t getAbsoluteErrorMargin() { +inline scalar_t getAbsoluteErrorMargin(const bool is_trsm) { /* Measured empirically with gemm. * In the cases where the relative error is irrelevant (close to zero), * absolute differences of up to 0.0006 were observed for float */ - return 0.001f; + scalar_t margin = 0.001f; + if (is_trsm) { + const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX"); + if (en_joint_matrix != NULL && std::is_same::value && + *en_joint_matrix == '1') { + // increase error margin for mixed precision calculation + // for trsm operator. + margin = 0.009f; + } + } + + return margin; } template <> -inline double getAbsoluteErrorMargin() { +inline double getAbsoluteErrorMargin(const bool) { /* Measured empirically with gemm. * In the cases where the relative error is irrelevant (close to zero), * absolute differences of up to 10^-12 were observed for double @@ -171,7 +192,7 @@ inline double getAbsoluteErrorMargin() { #ifdef BLAS_DATA_TYPE_HALF template <> -inline cl::sycl::half getAbsoluteErrorMargin() { +inline cl::sycl::half getAbsoluteErrorMargin(const bool) { // Measured empirically with gemm. return 1.0f; } @@ -181,7 +202,8 @@ inline cl::sycl::half getAbsoluteErrorMargin() { * Compare two scalars and returns false if the difference is not acceptable. */ template -inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) { +inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2, + const bool is_trsm = false) { // Shortcut, also handles case where both are zero if (scalar1 == scalar2) { return true; @@ -196,12 +218,13 @@ inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) { // Close to zero, the relative error doesn't work, use absolute error if (scalar1 == scalar_t{0} || scalar2 == scalar_t{0} || - absolute_diff < getAbsoluteErrorMargin()) { - return (absolute_diff < getAbsoluteErrorMargin()); + absolute_diff < getAbsoluteErrorMargin(is_trsm)) { + return (absolute_diff < getAbsoluteErrorMargin(is_trsm)); } // Use relative error const auto absolute_sum = utils::abs(scalar1) + utils::abs(scalar2); - return (absolute_diff / absolute_sum) < getRelativeErrorMargin(); + return (absolute_diff / absolute_sum) < + getRelativeErrorMargin(is_trsm); } /** @@ -214,6 +237,7 @@ inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) { template inline bool compare_vectors(std::vector const& vec, std::vector const& ref, + const bool is_trsm = false, std::ostream& err_stream = std::cerr, std::string end_line = "\n") { if (vec.size() != ref.size()) { @@ -223,7 +247,7 @@ inline bool compare_vectors(std::vector const& vec, } for (int i = 0; i < vec.size(); ++i) { - if (!almost_equal(vec[i], ref[i])) { + if (!almost_equal(vec[i], ref[i], is_trsm)) { err_stream << "Value mismatch at index " << i << ": " << vec[i] << "; expected " << ref[i] << end_line; return false; @@ -243,6 +267,7 @@ inline bool compare_vectors(std::vector const& vec, template inline bool compare_vectors(std::vector> const& vec, std::vector> const& ref, + bool is_trsm = false, std::ostream& err_stream = std::cerr, std::string end_line = "\n") { if (vec.size() != ref.size()) { diff --git a/test/unittest/blas3/blas3_trsm_test.cpp b/test/unittest/blas3/blas3_trsm_test.cpp index 76c87e2db..b4d916326 100644 --- a/test/unittest/blas3/blas3_trsm_test.cpp +++ b/test/unittest/blas3/blas3_trsm_test.cpp @@ -92,7 +92,7 @@ void run_test(const combination_t combi) { blas::helper::copy_to_host(q, b_gpu, B.data(), B.size()); sb_handle.wait(event); - bool isAlmostEqual = utils::compare_vectors(cpu_B, B); + bool isAlmostEqual = utils::compare_vectors(cpu_B, B, true); ASSERT_TRUE(isAlmostEqual);