Skip to content

Commit

Permalink
added accuracy2 to mmnb ut
Browse files Browse the repository at this point in the history
  • Loading branch information
fajin-corp committed Nov 14, 2024
1 parent 6f1a063 commit 051daf2
Showing 1 changed file with 27 additions and 27 deletions.
54 changes: 27 additions & 27 deletions onnxruntime/test/contrib_ops/matmul_4bits_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -395,44 +395,44 @@ TEST(MatMulNBits, Float32_Accuracy4) {
#if !defined(USE_DML)
// Actual and expected difference is over 0.01 with DmlExecutionProvider.
// Skip the tests instead of raising the tolerance to make is pass.
TEST(MatMulNBits, Float16_Accuracy2) {
TestMatMulNBitsTyped<MLFloat16, 1, 1, 16, 16, 2>();
TestMatMulNBitsTyped<MLFloat16, 1, 2, 16, 16, 2>();
TestMatMulNBitsTyped<MLFloat16, 1, 32, 16, 16, 2>();
TestMatMulNBitsTyped<MLFloat16, 1, 32, 32, 16, 2>();
TestMatMulNBitsTyped<MLFloat16, 1, 32, 16, 128, 2>();
TestMatMulNBitsTyped<MLFloat16, 1, 288, 16, 16, 2>();
TestMatMulNBitsTyped<MLFloat16, 1, 288, 1024, 16, 2>();
TestMatMulNBitsTyped<MLFloat16, 1, 288, 1024, 128, 2>();
TestMatMulNBitsTyped<MLFloat16, 1, 288, 93, 32, 2>();
TestMatMulNBitsTyped<MLFloat16, 1, 288, 93, 128, 2>();
TestMatMulNBitsTyped<MLFloat16, 1, 288, 1234, 16, 2>();
TestMatMulNBitsTyped<MLFloat16, 2, 1, 16, 16, 2>();
TestMatMulNBitsTyped<MLFloat16, 2, 2, 16, 16, 2>();
TestMatMulNBitsTyped<MLFloat16, 100, 1, 16, 16, 2>();
TestMatMulNBitsTyped<MLFloat16, 100, 2, 16, 16, 2>();
TestMatMulNBitsTyped<MLFloat16, 100, 32, 16, 16, 2>();
TestMatMulNBitsTyped<MLFloat16, 100, 32, 32, 16, 2>();
TestMatMulNBitsTyped<MLFloat16, 100, 32, 16, 128, 2>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 16, 16, 2>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 1024, 16, 2>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 1024, 128, 2>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 93, 32, 2>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 93, 128, 2>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 1234, 16, 2>();
}

TEST(MatMulNBits, Float16_Accuracy0) {
TestMatMulNBitsTyped<MLFloat16, 1, 1, 16, 16, 0>();
TestMatMulNBitsTyped<MLFloat16, 1, 2, 16, 16, 0>();
TestMatMulNBitsTyped<MLFloat16, 1, 32, 16, 16, 0>();
TestMatMulNBitsTyped<MLFloat16, 1, 32, 32, 16, 0>();
TestMatMulNBitsTyped<MLFloat16, 1, 32, 16, 128, 0>();
TestMatMulNBitsTyped<MLFloat16, 1, 288, 16, 16, 0>();
TestMatMulNBitsTyped<MLFloat16, 1, 288, 1024, 16, 0>();
TestMatMulNBitsTyped<MLFloat16, 1, 288, 1024, 128, 0>();
TestMatMulNBitsTyped<MLFloat16, 1, 288, 93, 32, 0>();
TestMatMulNBitsTyped<MLFloat16, 1, 288, 93, 128, 0>();
TestMatMulNBitsTyped<MLFloat16, 1, 288, 1234, 16, 0>();
TestMatMulNBitsTyped<MLFloat16, 2, 1, 16, 16, 0>();
TestMatMulNBitsTyped<MLFloat16, 2, 2, 16, 16, 0>();
TestMatMulNBitsTyped<MLFloat16, 100, 1, 16, 16, 0>();
TestMatMulNBitsTyped<MLFloat16, 100, 2, 16, 16, 0>();
TestMatMulNBitsTyped<MLFloat16, 100, 32, 16, 16, 0>();
TestMatMulNBitsTyped<MLFloat16, 100, 32, 32, 16, 0>();
TestMatMulNBitsTyped<MLFloat16, 100, 32, 16, 128, 0>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 16, 16, 0>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 1024, 16, 0>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 1024, 128, 0>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 93, 32, 0>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 93, 128, 0>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 1234, 16, 0>();
}

TEST(MatMulNBits, Float16_Accuracy1) {
TestMatMulNBitsTyped<MLFloat16, 1, 1, 16, 16, 1>();
TestMatMulNBitsTyped<MLFloat16, 1, 288, 93, 32, 1>();
TestMatMulNBitsTyped<MLFloat16, 1, 288, 1234, 16, 1>();
TestMatMulNBitsTyped<MLFloat16, 2, 1, 16, 16, 1>();
TestMatMulNBitsTyped<MLFloat16, 100, 2, 16, 16, 1>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 1024, 128, 1>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 93, 32, 1>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 1234, 16, 1>();
}

TEST(MatMulNBits, Float16_Accuracy4) {
TestMatMulNBitsTyped<MLFloat16, 1, 1, 16, 16, 4>();
TestMatMulNBitsTyped<MLFloat16, 1, 2, 16, 16, 4>();
Expand Down

0 comments on commit 051daf2

Please sign in to comment.