Skip to content

Commit

Permalink
Add parse tests for multi-head attention
Browse files Browse the repository at this point in the history
  • Loading branch information
marko-fabo-htec committed Dec 3, 2024
1 parent 0d75fde commit becc9ab
Show file tree
Hide file tree
Showing 11 changed files with 462 additions and 0 deletions.
65 changes: 65 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7848,6 +7848,71 @@ def mha_qkv_packed_test():
return ([node], [qkv], [out])


@onnx_test()
def mha_invalid_input_test():
    # MultiHeadAttention node with no inputs at all; the parser is
    # expected to reject this model.
    attention_node = helper.make_node('MultiHeadAttention',
                                      inputs=[],
                                      outputs=[],
                                      num_heads=1,
                                      domain='com.microsoft')
    return ([attention_node], [], [])


@onnx_test()
def mha_invalid_query_test():
    # Query with an invalid (rank-2) shape; the parser is expected to
    # reject this model.
    q_info = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 1])
    attention_node = helper.make_node('MultiHeadAttention',
                                      inputs=['q'],
                                      outputs=[],
                                      num_heads=1,
                                      domain='com.microsoft')
    return ([attention_node], [q_info], [])


@onnx_test()
def mha_invalid_qkv_test():
    # Single packed QKV input with an invalid (rank-5) shape; the parser
    # is expected to reject this model.
    qkv_info = helper.make_tensor_value_info("qkv", TensorProto.FLOAT,
                                             [1, 1, 1, 1, 1])
    attention_node = helper.make_node('MultiHeadAttention',
                                      inputs=['qkv'],
                                      outputs=[],
                                      num_heads=1,
                                      domain='com.microsoft')
    return ([attention_node], [qkv_info], [])


@onnx_test()
def mha_invalid_key_missing_test():
    # Rank-3 (unpacked) query but no key input at all; the parser is
    # expected to reject this model.
    q_info = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 1, 1])
    attention_node = helper.make_node('MultiHeadAttention',
                                      inputs=['q'],
                                      outputs=[],
                                      num_heads=1,
                                      domain='com.microsoft')
    return ([attention_node], [q_info], [])


@onnx_test()
def mha_invalid_key_test():
    # Valid rank-3 query paired with an invalid (rank-2) key; the parser
    # is expected to reject this model.
    q_info = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 1, 1])
    k_info = helper.make_tensor_value_info("k", TensorProto.FLOAT, [1, 1])
    attention_node = helper.make_node('MultiHeadAttention',
                                      inputs=['q', 'k'],
                                      outputs=[],
                                      num_heads=1,
                                      domain='com.microsoft')
    return ([attention_node], [q_info, k_info], [])


@onnx_test()
def multinomial_test():
sample_size = 13
Expand Down
3 changes: 3 additions & 0 deletions test/onnx/mha_invalid_input_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
 mha_invalid_input_test:O
5"MultiHeadAttention*
num_heads�:com.microsoftmha_invalid_input_testB
Expand Down
9 changes: 9 additions & 0 deletions test/onnx/mha_invalid_key_missing_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
 mha_invalid_key_missing_test:q
8
q"MultiHeadAttention*
num_heads�:com.microsoftmha_invalid_key_missing_testZ
q



B
Expand Down
14 changes: 14 additions & 0 deletions test/onnx/mha_invalid_key_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
 mha_invalid_key_test:�
;
q
k"MultiHeadAttention*
num_heads�:com.microsoftmha_invalid_key_testZ
q



Z
k


B
Expand Down
11 changes: 11 additions & 0 deletions test/onnx/mha_invalid_qkv_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
 mha_invalid_qkv_test:u
:
qkv"MultiHeadAttention*
num_heads�:com.microsoftmha_invalid_qkv_testZ!
qkv





B
Expand Down
8 changes: 8 additions & 0 deletions test/onnx/mha_invalid_query_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
 mha_invalid_query_test:g
8
q"MultiHeadAttention*
num_heads�:com.microsoftmha_invalid_query_testZ
q


B
Expand Down
69 changes: 69 additions & 0 deletions test/onnx/parse/mha_cross_attention_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>

TEST_CASE(multi_head_attention_cross_attention_test)
{
    // Reference decomposition for com.microsoft MultiHeadAttention with
    // separate query/key/value inputs: softmax(Q x K^T * scale) x V, where Q
    // is split into heads up front and the output is merged back into
    // [batch, seq, hidden]. NOTE: instruction insertion order is kept exactly
    // as the parser emits it, since program equality depends on it.
    const int64_t batch    = 1;
    const int64_t seq_len  = 2;
    const int64_t hidden   = 4;
    const int64_t heads    = 2;
    const int64_t head_dim = 2;

    const std::vector<std::size_t> query_dims{batch, seq_len, hidden};
    const std::vector<std::size_t> kv_dims{batch, heads, seq_len, head_dim};
    const std::vector<std::size_t> split_dims{batch, seq_len, heads, head_dim};
    const std::vector<int64_t> to_bnsh{0, 2, 1, 3};

    migraphx::program p;
    auto* mm = p.get_main_module();

    auto query = mm->add_parameter("q", {migraphx::shape::float_type, query_dims});
    auto key   = mm->add_parameter("k", {migraphx::shape::float_type, kv_dims});
    auto value = mm->add_parameter("v", {migraphx::shape::float_type, kv_dims});

    // Split query into heads: [b, s, h] -> [b, s, n, d] -> [b, n, s, d].
    query = mm->add_instruction(migraphx::make_op("reshape", {{"dims", split_dims}}), query);
    query = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", to_bnsh}}), query);

    const float scale = 1 / std::sqrt(head_dim);
    auto scale_lit =
        mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {scale}});

    // Attention scores: (Q x K^T) * scale, then softmax over the last axis.
    auto key_t =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), key);
    auto out  = mm->add_instruction(migraphx::make_op("dot"), query, key_t);
    scale_lit = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", out->get_shape().lens()}}), scale_lit);
    out = mm->add_instruction(migraphx::make_op("mul"), out, scale_lit);
    out = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), out);
    out = mm->add_instruction(migraphx::make_op("dot"), out, value);

    // Merge heads back: [b, n, s, d] -> [b, s, n, d] -> [b, s, h].
    out = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", to_bnsh}}), out);
    out = mm->add_instruction(migraphx::make_op("reshape", {{"dims", query_dims}}), out);

    auto prog = optimize_onnx("mha_cross_attention_test.onnx");

    EXPECT(p == prog);
}
50 changes: 50 additions & 0 deletions test/onnx/parse/mha_invalid_tests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>

// MultiHeadAttention with no inputs at all must fail to parse.
TEST_CASE(multi_head_attention_invalid_input_test)
{
    EXPECT(test::throws([] { read_onnx("mha_invalid_input_test.onnx"); }));
}

// A rank-2 query shape must fail to parse.
TEST_CASE(multi_head_attention_invalid_query_test)
{
    EXPECT(test::throws([] { read_onnx("mha_invalid_query_test.onnx"); }));
}

// A rank-5 packed QKV input must fail to parse.
TEST_CASE(multi_head_attention_invalid_qkv_test)
{
    EXPECT(test::throws([] { read_onnx("mha_invalid_qkv_test.onnx"); }));
}

// An unpacked query without a key input must fail to parse.
TEST_CASE(multi_head_attention_invalid_key_missing_test)
{
    EXPECT(test::throws([] { read_onnx("mha_invalid_key_missing_test.onnx"); }));
}

// A rank-2 key shape must fail to parse.
TEST_CASE(multi_head_attention_invalid_key_test)
{
    EXPECT(test::throws([] { read_onnx("mha_invalid_key_test.onnx"); }));
}
80 changes: 80 additions & 0 deletions test/onnx/parse/mha_kv_packed_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>

TEST_CASE(multi_head_attention_kv_packed_test)
{
    // Reference decomposition for com.microsoft MultiHeadAttention when key
    // and value arrive packed in one [b, s, n, 2, d] tensor: unpack K and V
    // by transpose + slice + squeeze, then run scaled dot-product attention.
    // NOTE: instruction insertion order is kept exactly as the parser emits
    // it, since program equality depends on it.
    const int64_t batch    = 1;
    const int64_t seq_len  = 2;
    const int64_t hidden   = 4;
    const int64_t heads    = 2;
    const int64_t head_dim = 2;

    const std::vector<std::size_t> query_dims{batch, seq_len, hidden};
    const std::vector<std::size_t> packed_kv_dims{batch, seq_len, heads, 2, head_dim};
    const std::vector<std::size_t> split_dims{batch, seq_len, heads, head_dim};
    const std::vector<int64_t> to_bnsh{0, 2, 1, 3};

    migraphx::program p;
    auto* mm = p.get_main_module();

    auto query  = mm->add_parameter("q", {migraphx::shape::float_type, query_dims});
    auto packed = mm->add_parameter("kv", {migraphx::shape::float_type, packed_kv_dims});

    query = mm->add_instruction(migraphx::make_op("reshape", {{"dims", split_dims}}), query);

    // Move the K/V axis to the front, then slice + squeeze out each half.
    packed = mm->add_instruction(
        migraphx::make_op("transpose", {{"permutation", {3, 0, 1, 2, 4}}}), packed);
    auto key = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), packed);
    key = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), key);
    auto value = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), packed);
    value = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), value);

    // Bring Q, K, V into [b, n, s, d] layout.
    query = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", to_bnsh}}), query);
    key   = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", to_bnsh}}), key);
    value = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", to_bnsh}}), value);

    const float scale = 1 / std::sqrt(head_dim);
    auto scale_lit =
        mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {scale}});

    // Attention: softmax((Q x K^T) * scale) x V.
    auto key_t =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), key);
    auto out  = mm->add_instruction(migraphx::make_op("dot"), query, key_t);
    scale_lit = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", out->get_shape().lens()}}), scale_lit);
    out = mm->add_instruction(migraphx::make_op("mul"), out, scale_lit);
    out = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), out);
    out = mm->add_instruction(migraphx::make_op("dot"), out, value);

    // Merge heads back into [b, s, h].
    out = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", to_bnsh}}), out);
    out = mm->add_instruction(migraphx::make_op("reshape", {{"dims", query_dims}}), out);

    auto prog = optimize_onnx("mha_kv_packed_test.onnx");

    EXPECT(p == prog);
}
Loading

0 comments on commit becc9ab

Please sign in to comment.