Skip to content

Commit

Permalink
Add verify tests for multi head attention
Browse files Browse the repository at this point in the history
  • Loading branch information
marko-fabo-htec committed Nov 27, 2024
1 parent 8f85c7d commit 9f0d026
Show file tree
Hide file tree
Showing 11 changed files with 462 additions and 0 deletions.
95 changes: 95 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7785,6 +7785,101 @@ def mod_test_fmod_different_dtypes():
return ([node], [a, b], [y])


@onnx_test()
def mha_test():
    # Basic com.microsoft MultiHeadAttention model: separate "q", "k" and "v"
    # float inputs, each declared as [1, 2, 4], split across two heads.
    tensor_dims = [1, 2, 4]
    graph_inputs = [
        helper.make_tensor_value_info(name, TensorProto.FLOAT, tensor_dims)
        for name in ("q", "k", "v")
    ]
    graph_output = helper.make_tensor_value_info("out", TensorProto.FLOAT,
                                                 tensor_dims)

    mha_node = helper.make_node('MultiHeadAttention',
                                inputs=['q', 'k', 'v'],
                                outputs=['out'],
                                num_heads=2,
                                domain='com.microsoft')

    return ([mha_node], graph_inputs, [graph_output])


@onnx_test()
def mha_cross_attention_test():
    # MultiHeadAttention model where "k" and "v" are supplied as 4-D
    # [1, 2, 2, 2] tensors while "q" and "out" stay 3-D [1, 2, 4].
    q_dims = [1, 2, 4]
    kv_dims = [1, 2, 2, 2]

    q_info = helper.make_tensor_value_info("q", TensorProto.FLOAT, q_dims)
    k_info = helper.make_tensor_value_info("k", TensorProto.FLOAT, kv_dims)
    v_info = helper.make_tensor_value_info("v", TensorProto.FLOAT, kv_dims)
    out_info = helper.make_tensor_value_info("out", TensorProto.FLOAT, q_dims)

    mha_node = helper.make_node('MultiHeadAttention',
                                inputs=['q', 'k', 'v'],
                                outputs=['out'],
                                num_heads=2,
                                domain='com.microsoft')

    graph_inputs = [q_info, k_info, v_info]
    return ([mha_node], graph_inputs, [out_info])


@onnx_test()
def mha_kv_packed_test():
    # MultiHeadAttention model with only two inputs: "q" ([1, 2, 4]) and a
    # single 5-D "kv" tensor ([1, 2, 2, 2, 2]) in place of separate k and v.
    q_info = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 2, 4])
    kv_info = helper.make_tensor_value_info("kv", TensorProto.FLOAT,
                                            [1, 2, 2, 2, 2])
    out_info = helper.make_tensor_value_info("out", TensorProto.FLOAT,
                                             [1, 2, 4])

    mha_node = helper.make_node('MultiHeadAttention',
                                inputs=['q', 'kv'],
                                outputs=['out'],
                                num_heads=2,
                                domain='com.microsoft')

    graph_inputs = [q_info, kv_info]
    return ([mha_node], graph_inputs, [out_info])


@onnx_test()
def mha_qkv_packed_test():
    # MultiHeadAttention model fed by one 5-D "qkv" tensor ([1, 2, 2, 3, 2]);
    # the output is declared as [1, 2, 4].
    packed_info = helper.make_tensor_value_info("qkv", TensorProto.FLOAT,
                                                [1, 2, 2, 3, 2])
    out_info = helper.make_tensor_value_info("out", TensorProto.FLOAT,
                                             [1, 2, 4])

    mha_node = helper.make_node('MultiHeadAttention',
                                inputs=['qkv'],
                                outputs=['out'],
                                num_heads=2,
                                domain='com.microsoft')

    return ([mha_node], [packed_info], [out_info])


@onnx_test()
def mha_bias_test():
    # MultiHeadAttention model with separate "q", "k", "v" inputs ([1, 2, 4]
    # each) plus a 1-D "b" input of 12 floats passed as the fourth input.
    qkv_dims = [1, 2, 4]
    qkv_infos = [
        helper.make_tensor_value_info(name, TensorProto.FLOAT, qkv_dims)
        for name in ("q", "k", "v")
    ]
    bias_info = helper.make_tensor_value_info("b", TensorProto.FLOAT, [12])
    out_info = helper.make_tensor_value_info("out", TensorProto.FLOAT,
                                             qkv_dims)

    mha_node = helper.make_node('MultiHeadAttention',
                                inputs=['q', 'k', 'v', 'b'],
                                outputs=['out'],
                                num_heads=2,
                                domain='com.microsoft')

    return ([mha_node], qkv_infos + [bias_info], [out_info])


@onnx_test()
def mha_attention_bias_test():
    # MultiHeadAttention model with "q", "k", "v" ([1, 2, 4] each) and an
    # extra "ab" input of shape [1, 2, 2, 2].
    qkv_dims = [1, 2, 4]
    qkv_infos = [
        helper.make_tensor_value_info(name, TensorProto.FLOAT, qkv_dims)
        for name in ("q", "k", "v")
    ]
    ab_info = helper.make_tensor_value_info("ab", TensorProto.FLOAT,
                                            [1, 2, 2, 2])
    out_info = helper.make_tensor_value_info("out", TensorProto.FLOAT,
                                             qkv_dims)

    # The two empty names leave the node's 4th and 5th inputs unset so that
    # "ab" is bound to the 6th input position.
    mha_node = helper.make_node('MultiHeadAttention',
                                inputs=['q', 'k', 'v', '', '', 'ab'],
                                outputs=['out'],
                                num_heads=2,
                                domain='com.microsoft')

    return ([mha_node], qkv_infos + [ab_info], [out_info])


@onnx_test()
def multinomial_test():
sample_size = 13
Expand Down
Binary file added test/onnx/mha_attention_bias_test.onnx
Binary file not shown.
31 changes: 31 additions & 0 deletions test/onnx/mha_bias_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
 mha_bias_test:�
F
q
k
v
bout"MultiHeadAttention*
num_heads�:com.microsoftmha_bias_testZ
q



Z
k



Z
v



Z
b


 b
out



B
Expand Down
28 changes: 28 additions & 0 deletions test/onnx/mha_cross_attention_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
 mha_cross_attention_test:�
C
q
k
vout"MultiHeadAttention*
num_heads�:com.microsoftmha_cross_attention_testZ
q



Z
k




Z
v




b
out



B
Expand Down
22 changes: 22 additions & 0 deletions test/onnx/mha_kv_packed_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
 mha_kv_packed_test:�
A
q
kvout"MultiHeadAttention*
num_heads�:com.microsoftmha_kv_packed_testZ
q



Z
kv





b
out



B
Expand Down
16 changes: 16 additions & 0 deletions test/onnx/mha_qkv_packed_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
 mha_qkv_packed_test:�
?
qkvout"MultiHeadAttention*
num_heads�:com.microsoftmha_qkv_packed_testZ!
qkv





b
out



B
Expand Down
26 changes: 26 additions & 0 deletions test/onnx/mha_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
 mha_test:�
C
q
k
vout"MultiHeadAttention*
num_heads�:com.microsoftmha_testZ
q



Z
k



Z
v



b
out



B
Expand Down
50 changes: 50 additions & 0 deletions test/onnx/verify/mha_cross_attention_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <onnx_test.hpp>

// Loads mha_cross_attention_test.onnx, runs it on the "ref" target and compares
// the last output against precomputed reference values.
TEST_CASE(multi_head_attention_cross_attention_test)
{
    migraphx::program prog = read_onnx("mha_cross_attention_test.onnx");
    prog.compile(migraphx::make_target("ref"));

    // Host buffers backing the inputs; they must outlive the eval call below
    // because the arguments only wrap the raw pointers.
    std::vector<float> q_data  = {1, 3, 5, 7, 2, 4, 6, 8};
    std::vector<float> kv_data = {1, 3, 2, 4, 5, 7, 6, 8};

    migraphx::shape query_shape{migraphx::shape::float_type, {1, 2, 4}};
    migraphx::shape key_value_shape{migraphx::shape::float_type, {1, 2, 2, 2}};

    // The same buffer is bound to both "k" and "v".
    migraphx::parameter_map params;
    params["q"] = migraphx::argument(query_shape, q_data.data());
    params["k"] = migraphx::argument(key_value_shape, kv_data.data());
    params["v"] = migraphx::argument(key_value_shape, kv_data.data());

    std::vector<float> gold = {
        1.9441926, 3.9441926, 5.9997935, 7.999794, 1.9858339, 3.9858341, 5.99995, 7.9999495};

    auto last_output = prog.eval(params).back();
    std::vector<float> result_vector;
    last_output.visit(
        [&](auto view) { result_vector.assign(view.begin(), view.end()); });

    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
49 changes: 49 additions & 0 deletions test/onnx/verify/mha_kv_packed_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <onnx_test.hpp>

// Loads mha_kv_packed_test.onnx, runs it on the "ref" target and compares the
// last output against precomputed reference values.
TEST_CASE(multi_head_attention_kv_packed_test)
{
    migraphx::program prog = read_onnx("mha_kv_packed_test.onnx");
    prog.compile(migraphx::make_target("ref"));

    // Host buffers backing the inputs; they must outlive the eval call below
    // because the arguments only wrap the raw pointers.
    std::vector<float> q_data         = {1, 3, 5, 7, 2, 4, 6, 8};
    std::vector<float> packed_kv_data = {1, 3, 1, 3, 5, 7, 5, 7, 2, 4, 2, 4, 6, 8, 6, 8};

    migraphx::shape query_shape{migraphx::shape::float_type, {1, 2, 4}};
    migraphx::shape packed_kv_shape{migraphx::shape::float_type, {1, 2, 2, 2, 2}};

    migraphx::parameter_map params;
    params["q"]  = migraphx::argument(query_shape, q_data.data());
    params["kv"] = migraphx::argument(packed_kv_shape, packed_kv_data.data());

    std::vector<float> gold = {
        1.9441926, 3.9441926, 5.9997935, 7.999794, 1.9858339, 3.9858341, 5.99995, 7.9999495};

    auto last_output = prog.eval(params).back();
    std::vector<float> result_vector;
    last_output.visit(
        [&](auto view) { result_vector.assign(view.begin(), view.end()); });

    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
47 changes: 47 additions & 0 deletions test/onnx/verify/mha_qkv_packed_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <onnx_test.hpp>

// Loads mha_qkv_packed_test.onnx, runs it on the "ref" target and compares the
// last output against precomputed reference values.
TEST_CASE(multi_head_attention_qkv_packed_test)
{
    migraphx::program prog = read_onnx("mha_qkv_packed_test.onnx");
    prog.compile(migraphx::make_target("ref"));

    // Host buffer backing the single "qkv" input; it must outlive the eval call
    // below because the argument only wraps the raw pointer.
    std::vector<float> packed_qkv_data = {1, 3, 1, 3, 1, 3, 5, 7, 5, 7, 5, 7,
                                          2, 4, 2, 4, 2, 4, 6, 8, 6, 8, 6, 8};
    migraphx::shape packed_qkv_shape{migraphx::shape::float_type, {1, 2, 2, 3, 2}};

    migraphx::parameter_map params;
    params["qkv"] = migraphx::argument(packed_qkv_shape, packed_qkv_data.data());

    std::vector<float> gold = {
        1.9441926, 3.9441926, 5.9997935, 7.999794, 1.9858339, 3.9858341, 5.99995, 7.9999495};

    auto last_output = prog.eval(params).back();
    std::vector<float> result_vector;
    last_output.visit(
        [&](auto view) { result_vector.assign(view.begin(), view.end()); });

    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
Loading

0 comments on commit 9f0d026

Please sign in to comment.