From 9f0d026de201daa07566ed949a05be1b2d4fcb87 Mon Sep 17 00:00:00 2001
From: Marko Fabo
Date: Wed, 13 Nov 2024 08:38:55 +0000
Subject: [PATCH] Add verify tests for multi head attention

---
 test/onnx/gen_onnx.py                         |  95 +++++++++++++++++
 test/onnx/mha_attention_bias_test.onnx        | Bin 0 -> 268 bytes
 test/onnx/mha_bias_test.onnx                  |  31 ++++++
 test/onnx/mha_cross_attention_test.onnx       |  28 +++++
 test/onnx/mha_kv_packed_test.onnx             |  22 ++++
 test/onnx/mha_qkv_packed_test.onnx            |  16 +++
 test/onnx/mha_test.onnx                       |  26 +++++
 test/onnx/verify/mha_cross_attention_test.cpp |  50 +++++++++
 test/onnx/verify/mha_kv_packed_test.cpp       |  49 +++++++++
 test/onnx/verify/mha_qkv_packed_test.cpp      |  47 +++++++++
 test/onnx/verify/mha_test.cpp                 |  98 ++++++++++++++++++
 11 files changed, 462 insertions(+)
 create mode 100644 test/onnx/mha_attention_bias_test.onnx
 create mode 100644 test/onnx/mha_bias_test.onnx
 create mode 100644 test/onnx/mha_cross_attention_test.onnx
 create mode 100644 test/onnx/mha_kv_packed_test.onnx
 create mode 100644 test/onnx/mha_qkv_packed_test.onnx
 create mode 100644 test/onnx/mha_test.onnx
 create mode 100644 test/onnx/verify/mha_cross_attention_test.cpp
 create mode 100644 test/onnx/verify/mha_kv_packed_test.cpp
 create mode 100644 test/onnx/verify/mha_qkv_packed_test.cpp
 create mode 100644 test/onnx/verify/mha_test.cpp

diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py
index 05713398294..317eda65645 100644
--- a/test/onnx/gen_onnx.py
+++ b/test/onnx/gen_onnx.py
@@ -7785,6 +7785,101 @@ def mod_test_fmod_different_dtypes():
     return ([node], [a, b], [y])
 
 
+@onnx_test()
+def mha_test():
+    query = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 2, 4])
+    key = helper.make_tensor_value_info("k", TensorProto.FLOAT, [1, 2, 4])
+    value = helper.make_tensor_value_info("v", TensorProto.FLOAT, [1, 2, 4])
+    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 2, 4])
+
+    node = helper.make_node('MultiHeadAttention',
+                            inputs=['q', 'k', 'v'],
+                            outputs=['out'],
+                            num_heads=2,
+                            domain='com.microsoft')
+
+    return ([node], [query, key, value], [out])
+
+
+@onnx_test()
+def mha_cross_attention_test():
+    query = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 2, 4])
+    key = helper.make_tensor_value_info("k", TensorProto.FLOAT, [1, 2, 2, 2])
+    value = helper.make_tensor_value_info("v", TensorProto.FLOAT, [1, 2, 2, 2])
+    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 2, 4])
+
+    node = helper.make_node('MultiHeadAttention',
+                            inputs=['q', 'k', 'v'],
+                            outputs=['out'],
+                            num_heads=2,
+                            domain='com.microsoft')
+
+    return ([node], [query, key, value], [out])
+
+
+@onnx_test()
+def mha_kv_packed_test():
+    query = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 2, 4])
+    kv = helper.make_tensor_value_info("kv", TensorProto.FLOAT, [1, 2, 2, 2, 2])
+    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 2, 4])
+
+    node = helper.make_node('MultiHeadAttention',
+                            inputs=['q', 'kv'],
+                            outputs=['out'],
+                            num_heads=2,
+                            domain='com.microsoft')
+
+    return ([node], [query, kv], [out])
+
+
+@onnx_test()
+def mha_qkv_packed_test():
+    qkv = helper.make_tensor_value_info("qkv", TensorProto.FLOAT, [1, 2, 2, 3, 2])
+    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 2, 4])
+
+    node = helper.make_node('MultiHeadAttention',
+                            inputs=['qkv'],
+                            outputs=['out'],
+                            num_heads=2,
+                            domain='com.microsoft')
+
+    return ([node], [qkv], [out])
+
+
+@onnx_test()
+def mha_bias_test():
+    query = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 2, 4])
+    key = helper.make_tensor_value_info("k", TensorProto.FLOAT, [1, 2, 4])
+    value = helper.make_tensor_value_info("v", TensorProto.FLOAT, [1, 2, 4])
+    bias = helper.make_tensor_value_info("b", TensorProto.FLOAT, [12])
+    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 2, 4])
+
+    node = helper.make_node('MultiHeadAttention',
+                            inputs=['q', 'k', 'v', 'b'],
+                            outputs=['out'],
+                            num_heads=2,
+                            domain='com.microsoft')
+
+    return ([node], [query, key, value, bias], [out])
+
+
+@onnx_test()
+def mha_attention_bias_test():
+    query = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 2, 4])
+    key = helper.make_tensor_value_info("k", TensorProto.FLOAT, [1, 2, 4])
+    value = helper.make_tensor_value_info("v", TensorProto.FLOAT, [1, 2, 4])
+    attention_bias = helper.make_tensor_value_info("ab", TensorProto.FLOAT, [1, 2, 2, 2])
+    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 2, 4])
+
+    node = helper.make_node('MultiHeadAttention',
+                            inputs=['q', 'k', 'v', '', '', 'ab'],
+                            outputs=['out'],
+                            num_heads=2,
+                            domain='com.microsoft')
+
+    return ([node], [query, key, value, attention_bias], [out])
+
+
 @onnx_test()
 def multinomial_test():
     sample_size = 13
diff --git a/test/onnx/mha_attention_bias_test.onnx b/test/onnx/mha_attention_bias_test.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..898f656ed4cab234a7c555f2ce9096600613060a
GIT binary patch
literal 268
[base85-encoded ONNX protobuf payload for graph "mha_attention_bias_test"; binary content not reproduced]

literal 0
HcmV?d00001

diff --git a/test/onnx/mha_bias_test.onnx b/test/onnx/mha_bias_test.onnx
new file mode 100644
index 00000000000..b89175f5adb
--- /dev/null
+++ b/test/onnx/mha_bias_test.onnx
@@ -0,0 +1,31 @@
+[serialized ONNX protobuf for graph "mha_bias_test": MultiHeadAttention (com.microsoft), num_heads=2, inputs q/k/v/b, output out; binary content not reproduced]
\ No newline at end of file
diff --git a/test/onnx/mha_cross_attention_test.onnx b/test/onnx/mha_cross_attention_test.onnx
new file mode 100644
index 00000000000..e68c9dcab5b
--- /dev/null
+++ b/test/onnx/mha_cross_attention_test.onnx
@@ -0,0 +1,28 @@
+[serialized ONNX protobuf for graph "mha_cross_attention_test": MultiHeadAttention (com.microsoft), num_heads=2, inputs q/k/v, output out; binary content not reproduced]
\ No newline at end of file
diff --git a/test/onnx/mha_kv_packed_test.onnx b/test/onnx/mha_kv_packed_test.onnx
new file mode 100644
index 00000000000..bc0bf12f6e0
--- /dev/null
+++ b/test/onnx/mha_kv_packed_test.onnx
@@ -0,0 +1,22 @@
+[serialized ONNX protobuf for graph "mha_kv_packed_test": MultiHeadAttention (com.microsoft), num_heads=2, inputs q/kv, output out; binary content not reproduced]
\ No newline at end of file
diff --git a/test/onnx/mha_qkv_packed_test.onnx b/test/onnx/mha_qkv_packed_test.onnx
new file mode 100644
index 00000000000..cbc7ff18cc9
--- /dev/null
+++ b/test/onnx/mha_qkv_packed_test.onnx
@@ -0,0 +1,16 @@
+[serialized ONNX protobuf for graph "mha_qkv_packed_test": MultiHeadAttention (com.microsoft), num_heads=2, input qkv, output out; binary content not reproduced]
\ No newline at end of file
diff --git a/test/onnx/mha_test.onnx b/test/onnx/mha_test.onnx
new file mode 100644
index 00000000000..9a3133a2993
--- /dev/null
+++ b/test/onnx/mha_test.onnx
@@ -0,0 +1,26 @@
+[serialized ONNX protobuf for graph "mha_test": MultiHeadAttention (com.microsoft), num_heads=2, inputs q/k/v, output out; binary content not reproduced]
\ No newline at end of file
diff --git a/test/onnx/verify/mha_cross_attention_test.cpp b/test/onnx/verify/mha_cross_attention_test.cpp
new file mode 100644
index 00000000000..318eb0a16cf
--- /dev/null
+++ b/test/onnx/verify/mha_cross_attention_test.cpp
@@ -0,0 +1,50 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include <migraphx/register_target.hpp>
+#include <migraphx/verify.hpp>
+#include <onnx_test.hpp>
+
+TEST_CASE(multi_head_attention_cross_attention_test)
+{
+    migraphx::program p = read_onnx("mha_cross_attention_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::shape q_s{migraphx::shape::float_type, {1, 2, 4}};
+    migraphx::shape kv_s{migraphx::shape::float_type, {1, 2, 2, 2}};
+    std::vector<float> query = {1, 3, 5, 7, 2, 4, 6, 8};
+    std::vector<float> key_value = {1, 3, 2, 4, 5, 7, 6, 8};
+
+    migraphx::parameter_map pp;
+    pp["q"] = migraphx::argument(q_s, query.data());
+    pp["k"] = migraphx::argument(kv_s, key_value.data());
+    pp["v"] = migraphx::argument(kv_s, key_value.data());
+
+    auto result = p.eval(pp).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold = {
+        1.9441926, 3.9441926, 5.9997935, 7.999794, 1.9858339, 3.9858341, 5.99995, 7.9999495};
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
diff --git a/test/onnx/verify/mha_kv_packed_test.cpp b/test/onnx/verify/mha_kv_packed_test.cpp
new file mode 100644
index 00000000000..3594ea29c3d
--- /dev/null
+++ b/test/onnx/verify/mha_kv_packed_test.cpp
@@ -0,0 +1,49 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include <migraphx/register_target.hpp>
+#include <migraphx/verify.hpp>
+#include <onnx_test.hpp>
+
+TEST_CASE(multi_head_attention_kv_packed_test)
+{
+    migraphx::program p = read_onnx("mha_kv_packed_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::shape q_s{migraphx::shape::float_type, {1, 2, 4}};
+    migraphx::shape kv_s{migraphx::shape::float_type, {1, 2, 2, 2, 2}};
+    std::vector<float> query = {1, 3, 5, 7, 2, 4, 6, 8};
+    std::vector<float> key_value = {1, 3, 1, 3, 5, 7, 5, 7, 2, 4, 2, 4, 6, 8, 6, 8};
+
+    migraphx::parameter_map pp;
+    pp["q"] = migraphx::argument(q_s, query.data());
+    pp["kv"] = migraphx::argument(kv_s, key_value.data());
+
+    auto result = p.eval(pp).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold = {
+        1.9441926, 3.9441926, 5.9997935, 7.999794, 1.9858339, 3.9858341, 5.99995, 7.9999495};
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
diff --git a/test/onnx/verify/mha_qkv_packed_test.cpp b/test/onnx/verify/mha_qkv_packed_test.cpp
new file mode 100644
index 00000000000..accff9279ca
--- /dev/null
+++ b/test/onnx/verify/mha_qkv_packed_test.cpp
@@ -0,0 +1,47 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include <migraphx/register_target.hpp>
+#include <migraphx/verify.hpp>
+#include <onnx_test.hpp>
+
+TEST_CASE(multi_head_attention_qkv_packed_test)
+{
+    migraphx::program p = read_onnx("mha_qkv_packed_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::shape s{migraphx::shape::float_type, {1, 2, 2, 3, 2}};
+    std::vector<float> data = {1, 3, 1, 3, 1, 3, 5, 7, 5, 7, 5, 7,
+                               2, 4, 2, 4, 2, 4, 6, 8, 6, 8, 6, 8};
+
+    migraphx::parameter_map pp;
+    pp["qkv"] = migraphx::argument(s, data.data());
+
+    auto result = p.eval(pp).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold = {
+        1.9441926, 3.9441926, 5.9997935, 7.999794, 1.9858339, 3.9858341, 5.99995, 7.9999495};
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
diff --git a/test/onnx/verify/mha_test.cpp b/test/onnx/verify/mha_test.cpp
new file mode 100644
index 00000000000..9fb77f4eb21
--- /dev/null
+++ b/test/onnx/verify/mha_test.cpp
@@ -0,0 +1,98 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include <migraphx/register_target.hpp>
+#include <migraphx/verify.hpp>
+#include <onnx_test.hpp>
+
+TEST_CASE(multi_head_attention_test)
+{
+    migraphx::program p = read_onnx("mha_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::shape s{migraphx::shape::float_type, {1, 2, 4}};
+    std::vector<float> data = {1, 3, 5, 7, 2, 4, 6, 8};
+
+    migraphx::parameter_map pp;
+    pp["q"] = migraphx::argument(s, data.data());
+    pp["k"] = migraphx::argument(s, data.data());
+    pp["v"] = migraphx::argument(s, data.data());
+
+    auto result = p.eval(pp).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold = {
+        1.9441926, 3.9441926, 5.9997935, 7.999794, 1.9858339, 3.9858341, 5.99995, 7.9999495};
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(multi_head_attention_bias_test)
+{
+    migraphx::program p = read_onnx("mha_bias_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::shape s{migraphx::shape::float_type, {1, 2, 4}};
+    std::vector<float> data = {1, 3, 5, 7, 2, 4, 6, 8};
+
+    migraphx::shape bs{migraphx::shape::float_type, {12}};
+    std::vector<float> bias_data(12, 1);
+
+    migraphx::parameter_map pp;
+    pp["q"] = migraphx::argument(s, data.data());
+    pp["k"] = migraphx::argument(s, data.data());
+    pp["v"] = migraphx::argument(s, data.data());
+    pp["b"] = migraphx::argument(bs, bias_data.data());
+
+    auto result = p.eval(pp).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold = {
+        2.985834, 4.985834, 6.99995, 8.999949, 2.9965186, 4.9965186, 6.9999886, 8.999988};
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(multi_head_attention_attention_bias_test)
+{
+    migraphx::program p = read_onnx("mha_attention_bias_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::shape s{migraphx::shape::float_type, {1, 2, 4}};
+    std::vector<float> data = {1, 3, 5, 7, 2, 4, 6, 8};
+
+    migraphx::shape abs{migraphx::shape::float_type, {1, 2, 2, 2}};
+    std::vector<float> att_bias_data = {1, 2, 3, 4, 5, 6, 7, 8};
+
+    migraphx::parameter_map pp;
+    pp["q"] = migraphx::argument(s, data.data());
+    pp["k"] = migraphx::argument(s, data.data());
+    pp["v"] = migraphx::argument(s, data.data());
+    pp["ab"] = migraphx::argument(abs, att_bias_data.data());
+
+    auto result = p.eval(pp).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold = {
+        1.9787189, 3.978719, 5.999924, 7.999924, 1.9947413, 3.9947412, 5.9999814, 7.9999814};
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
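
Reference check (not part of the patch): the gold values in these tests appear to follow plain
scaled dot-product attention over the heads, with no projection weights, since the com.microsoft
MultiHeadAttention inputs here are already the Q/K/V tensors. A minimal NumPy sketch, assuming
num_heads=2 (so head_size=2) as in mha_test, that reproduces the mha_test gold output:

import numpy as np

def mha_reference(q, k, v, num_heads):
    # q, k, v: (batch, seq, hidden); hidden = num_heads * head_size
    batch, seq, hidden = q.shape
    head_size = hidden // num_heads

    def split_heads(x):
        # (batch, seq, hidden) -> (batch, num_heads, seq, head_size)
        return x.reshape(batch, seq, num_heads, head_size).transpose(0, 2, 1, 3)

    qh, kh, vh = split_heads(q), split_heads(k), split_heads(v)
    scores = qh @ kh.transpose(0, 1, 3, 2) / np.sqrt(head_size)
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    # merge heads back to (batch, seq, hidden)
    return (probs @ vh).transpose(0, 2, 1, 3).reshape(batch, seq, hidden)

q = np.array([1, 3, 5, 7, 2, 4, 6, 8], dtype=np.float32).reshape(1, 2, 4)
print(mha_reference(q, q, q, num_heads=2))
# ~[[1.9441926 3.9441926 5.9997935 7.999794 ]
#   [1.9858339 3.9858341 5.99995   7.9999495]]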