Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Test] Convert test_rnn_vanilla , test_gru, test_rnn_extra and test_gru_extra gTests #2550

Merged
merged 99 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from 67 commits
Commits
Show all changes
99 commits
Select commit Hold shift + click to select a range
9a41dde
WIP: updatable env vars, first draft
amberhassaan Nov 7, 2023
1e63ea6
update env var declarations
cderb Nov 7, 2023
0222ee6
draw out env string definitions
cderb Nov 7, 2023
cd035e4
compilation fixes
cderb Nov 8, 2023
de4766e
Merge remote-tracking branch 'origin/develop' into cderb/env-var-update
cderb Nov 8, 2023
79446a4
fix
cderb Nov 8, 2023
9d4d4f5
fix
cderb Nov 8, 2023
541eea5
update driver envs
cderb Nov 8, 2023
5e59daf
env for test folder
cderb Nov 8, 2023
438aa34
compilation fixes
cderb Nov 8, 2023
fb292b3
format
cderb Nov 8, 2023
44a174f
tidy
cderb Nov 8, 2023
473b17c
quality improvements for env.h
cderb Nov 8, 2023
d33b9a9
convert to const string reference
cderb Nov 8, 2023
694a079
fix
cderb Nov 8, 2023
9c12c2f
fix string definition
cderb Nov 9, 2023
a3605ab
formatting
cderb Nov 9, 2023
688919c
fix
cderb Nov 9, 2023
5ea88a6
use std::string::compare
cderb Nov 9, 2023
434f154
rename IsDefault to IsUnset, check unset in disabled/enabled check
cderb Nov 9, 2023
708b472
Merge branch 'cderb/env-var-update' of https://github.com/ROCmSoftwar…
cderb Nov 9, 2023
55dfe4d
macro for bool + string envs
cderb Nov 9, 2023
c1a257a
fix macros
cderb Nov 9, 2023
e21d7da
relabel uint64 envs
cderb Nov 9, 2023
f037fe9
typo fix + format
cderb Nov 9, 2023
051538c
tidy
cderb Nov 10, 2023
fbc8558
Update src/include/miopen/env.hpp
cderb Nov 10, 2023
9074f83
update tests with env syntax
cderb Nov 10, 2023
dd38c5f
Merge branch 'cderb/env-var-update' of https://github.com/ROCmSoftwar…
cderb Nov 10, 2023
42893a9
Merge branch 'develop' into cderb/env-var-update
cderb Nov 10, 2023
57a80f5
format
cderb Nov 10, 2023
86c61c7
string references
cderb Nov 13, 2023
2d50b1d
Merge branch 'develop' into cderb/env-var-update
cderb Nov 13, 2023
4da821c
Merge branch 'develop' into cderb/env-var-update
cderb Nov 13, 2023
312a321
move generic_search env defaults into header
cderb Nov 13, 2023
34a21ca
Merge branch 'develop' into cderb/env-var-update
xinlipn Nov 14, 2023
9f85fae
Merge branch 'develop' into cderb/env-var-update
xinlipn Nov 15, 2023
cae82dc
mergefix
cderb Nov 15, 2023
2692bb8
mergefix
cderb Nov 16, 2023
3cfee15
missing include
cderb Nov 16, 2023
65f0f89
string fix
cderb Nov 16, 2023
61de002
Merge remote-tracking branch 'origin/develop' into cderb/env-var-update
cderb Nov 17, 2023
479d97e
mergefix
cderb Nov 17, 2023
f775ff1
Merge branch 'develop' into cderb/env-var-update
xinlipn Nov 20, 2023
6665114
address comments
cderb Nov 20, 2023
6be69d4
small performance remedy for find controls, remove log pollutants
cderb Nov 20, 2023
18cc17a
clang format
cderb Nov 20, 2023
a7b4344
revert find_controls changes
cderb Nov 20, 2023
a06e464
env namespace
cderb Nov 21, 2023
88bd707
fix
cderb Nov 21, 2023
eb46c35
fix
cderb Nov 21, 2023
a44b0d5
Merge remote-tracking branch 'origin/develop' into cderb/env-var-update
cderb Nov 21, 2023
263917e
mergefix
cderb Nov 22, 2023
2be776c
add Unset method for env, update assert string
cderb Nov 22, 2023
b597b81
fix
cderb Nov 22, 2023
707ce2d
Merge branch 'develop' into cderb/env-var-update
cderb Nov 22, 2023
c1c8336
Merge branch 'develop' into cderb/env-var-update
cderb Nov 22, 2023
4dc46ef
Merge branch 'develop' into cderb/env-var-update
xinlipn Nov 26, 2023
29a8b21
Branched off PR2514 and merge with my changes
xinlipn Nov 27, 2023
64c53e1
Merge branch 'develop' into cderb/env-var-update
cderb Nov 27, 2023
1aea239
Merge remote-tracking branch 'origin/develop' into cderb/env-var-update
cderb Nov 28, 2023
11c11b2
Merge branch 'develop' into cderb/env-var-update
cderb Nov 28, 2023
589b8d1
Merge branch 'develop' into cderb/env-var-update
xinlipn Nov 30, 2023
753ec83
Fix build errors
xinlipn Nov 30, 2023
dd11b5c
Merge remote-tracking branch 'origin/cderb/env-var-update' into sl/gt…
cderb Dec 1, 2023
652d1db
Merge remote-tracking branch 'origin/develop' into sl/gtest_rnn_gru_v…
cderb Dec 1, 2023
2952552
Added two tests, removed miopenInt8x4 from output message
xinlipn Dec 2, 2023
6c4c494
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
cderb Dec 4, 2023
b858476
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
xinlipn Dec 7, 2023
d5ef774
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
cderb Dec 11, 2023
ccc787f
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
cderb Dec 15, 2023
65c6f54
Merge branch 'develop' of github.com:ROCm/MIOpen into sl/gtest_rnn_gr…
bghimireamd Dec 18, 2023
5d4a5fa
sl/gtest_rnn_gru_vanilla_env_var_update : add namespace around test
bghimireamd Dec 18, 2023
e3f6ea3
sl/gtest_rnn_gru_vanilla_env_var_update: add unique names in TEST_P
bghimireamd Dec 19, 2023
ab03f7f
Merge branch 'develop' of github.com:ROCm/MIOpen into sl/gtest_rnn_gr…
bghimireamd Dec 19, 2023
490af72
Merge branch 'develop' of github.com:ROCm/MIOpen into sl/gtest_rnn_gr…
bghimireamd Dec 20, 2023
6e231bd
sl/gtest_rnn_gru_vanilla_env_var_update: fix gtest names
bghimireamd Dec 20, 2023
64a0666
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
junliume Dec 21, 2023
e9a179e
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
xinlipn Jan 8, 2024
5c509aa
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
xinlipn Jan 11, 2024
6414c61
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
xinlipn Jan 12, 2024
2f470aa
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
xinlipn Jan 14, 2024
f67df43
Update SkipTest logic, remove smoke tests
xinlipn Jan 17, 2024
d5cc4df
Allow deepbench tests to run in standalone mode
xinlipn Jan 18, 2024
5350d90
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
xinlipn Jan 19, 2024
e21e099
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
xinlipn Jan 22, 2024
faee7fd
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
xinlipn Jan 25, 2024
6bf3658
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
xinlipn Jan 29, 2024
e059be2
Remove unneeded logic and clean up code
xinlipn Jan 29, 2024
fdce760
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
xinlipn Jan 29, 2024
5e5f592
Remove unused function
xinlipn Jan 30, 2024
1c13064
Fix build
xinlipn Jan 30, 2024
69a255d
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
xinlipn Jan 30, 2024
7b55b6c
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
xinlipn Jan 30, 2024
0e38d7d
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
xinlipn Feb 5, 2024
bb41758
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
xinlipn Feb 5, 2024
39a49b4
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
xinlipn Feb 6, 2024
250bc71
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
xinlipn Feb 7, 2024
7bf75c0
Merge branch 'develop' into sl/gtest_rnn_gru_vanilla_env_var_update
xinlipn Feb 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 0 additions & 80 deletions test/CMakeLists.txt

Large diffs are not rendered by default.

40 changes: 1 addition & 39 deletions test/gru.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,45 +24,7 @@
*
*******************************************************************************/

#include "gru_common.hpp"

template <class T>
struct gru_driver : gru_basic_driver<T>
{
gru_driver() : gru_basic_driver<T>()
{
std::vector<int> modes(2, 0);
modes[1] = 1;
std::vector<int> defaultBS(1);

this->add(this->batchSize, "batch-size", this->generate_data(get_gru_batchSize(), {17}));
this->add(this->seqLength, "seq-len", this->generate_data(get_gru_seq_len(), {2}));
this->add(this->inVecLen, "vector-len", this->generate_data(get_gru_vector_len()));
this->add(this->hiddenSize, "hidden-size", this->generate_data(get_gru_hidden_size()));
this->add(this->numLayers, "num-layers", this->generate_data(get_gru_num_layers()));
this->add(this->nohx, "no-hx", this->flag());
this->add(this->nodhy, "no-dhy", this->flag());
this->add(this->nohy, "no-hy", this->flag());
this->add(this->nodhx, "no-dhx", this->flag());
this->add(this->flatBatchFill, "flat-batch-fill", this->flag());
this->add(this->useDropout, "use-dropout", this->generate_data({0}));

#if(MIO_GRU_TEST_DEBUG == 3)
this->biasMode = 0;
this->dirMode = 1;
this->inputMode = 0;
#else
this->add(this->inputMode, "in-mode", this->generate_data(modes));
this->add(this->biasMode, "bias-mode", this->generate_data(modes));
this->add(this->dirMode, "dir-mode", this->generate_data(modes));
#endif
this->add(
this->batchSeq,
"batch-seq",
this->lazy_generate_data(
[=] { return generate_batchSeq(this->batchSize, this->seqLength); }, defaultBS));
}
};
#include "gru.hpp"

int main(int argc, const char* argv[])
{
Expand Down
66 changes: 66 additions & 0 deletions test/gru.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2017 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#pragma once

#include "gru_common.hpp"

template <class T>
struct gru_driver : gru_basic_driver<T>
{
gru_driver() : gru_basic_driver<T>()
{
std::vector<int> modes(2, 0);
modes[1] = 1;
std::vector<int> defaultBS(1);

this->add(this->batchSize, "batch-size", this->generate_data(get_gru_batchSize(), {17}));
this->add(this->seqLength, "seq-len", this->generate_data(get_gru_seq_len(), {2}));
this->add(this->inVecLen, "vector-len", this->generate_data(get_gru_vector_len()));
this->add(this->hiddenSize, "hidden-size", this->generate_data(get_gru_hidden_size()));
this->add(this->numLayers, "num-layers", this->generate_data(get_gru_num_layers()));
this->add(this->nohx, "no-hx", this->flag());
this->add(this->nodhy, "no-dhy", this->flag());
this->add(this->nohy, "no-hy", this->flag());
this->add(this->nodhx, "no-dhx", this->flag());
this->add(this->flatBatchFill, "flat-batch-fill", this->flag());
this->add(this->useDropout, "use-dropout", this->generate_data({0}));

#if(MIO_GRU_TEST_DEBUG == 3)
this->biasMode = 0;
this->dirMode = 1;
this->inputMode = 0;
#else
this->add(this->inputMode, "in-mode", this->generate_data(modes));
this->add(this->biasMode, "bias-mode", this->generate_data(modes));
this->add(this->dirMode, "dir-mode", this->generate_data(modes));
#endif
this->add(
this->batchSeq,
"batch-seq",
this->lazy_generate_data(
[=] { return generate_batchSeq(this->batchSize, this->seqLength); }, defaultBS));
}
};
144 changes: 144 additions & 0 deletions test/gtest/deepbench_gru.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2023 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include <miopen/miopen.h>
#include <gtest/gtest.h>
#include <miopen/env.hpp>
#include "../gru.hpp"
#include "get_handle.hpp"

MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_TEST_DEEPBENCH)

static bool SkipTest(void) { return miopen::IsDisabled(ENV(MIOPEN_TEST_DEEPBENCH)); }

void GetArgs(const std::string& param, std::vector<std::string>& tokens)
{
std::stringstream ss(param);
std::istream_iterator<std::string> begin(ss);
std::istream_iterator<std::string> end;
while(begin != end)
tokens.push_back(*begin++);
}

class ConfigWithFloat : public testing::TestWithParam<std::vector<std::string>>
{
};

void Run2dDriver(miopenDataType_t prec)
{

std::vector<std::string> params;
switch(prec)
{
case miopenFloat: params = ConfigWithFloat::GetParam(); break;
case miopenHalf:
case miopenFloat8:
case miopenBFloat8:
case miopenInt8:
case miopenBFloat16:
case miopenInt32:
case miopenDouble:
FAIL() << "miopenHalf, miopenInt8, miopenBFloat16, miopenInt32, miopenDouble "
"data type not supported by "
"rnn_vanilla test";

default: params = ConfigWithFloat::GetParam();
}

for(const auto& test_value : params)
{
std::vector<std::string> tokens;
GetArgs(test_value, tokens);
std::vector<const char*> ptrs;

std::transform(tokens.begin(), tokens.end(), std::back_inserter(ptrs), [](const auto& str) {
return str.data();
});

testing::internal::CaptureStderr();
test_drive<gru_driver>(ptrs.size(), ptrs.data());
auto capture = testing::internal::GetCapturedStderr();
std::cout << capture;
}
};

bool IsTestSupportedForDevice(const miopen::Handle& handle)
{
std::string devName = handle.GetDeviceName();
if(devName == "gfx900" || devName == "gfx906" || devName == "gfx908" || devName == "gfx90a" ||
miopen::StartsWith(devName, "gfx94") || miopen::StartsWith(devName, "gfx103") ||
miopen::StartsWith(devName, "gfx110"))
return true;
else
return false;
junliume marked this conversation as resolved.
Show resolved Hide resolved
}

TEST_P(ConfigWithFloat, FloatTest)
{
const auto& handle = get_handle();
if(IsTestSupportedForDevice(handle) && !SkipTest())
{
Run2dDriver(miopenFloat);
}
else
{
GTEST_SKIP();
}
};

std::vector<std::string> GetTestCases(void)
{
std::string flags = " --verbose";
std::string commonFlags =
" --num-layers 1 --in-mode 1 --bias-mode 0 -dir-mode 0 --rnn-mode 0 --flat-batch-fill";

const std::vector<std::string> test_cases = {
// clang-format off
{flags + " --batch-size 32 --seq-len 1500 --vector-len 2816 --hidden-size 2816" + commonFlags},
{flags + " --batch-size 32 --seq-len 750 --vector-len 2816 --hidden-size 2816" + commonFlags},
{flags + " --batch-size 32 --seq-len 375 --vector-len 2816 --hidden-size 2816" + commonFlags},
{flags + " --batch-size 32 --seq-len 187 --vector-len 2816 --hidden-size 2816" + commonFlags},
{flags + " --batch-size 32 --seq-len 1500 --vector-len 2048 --hidden-size 2048" + commonFlags},
{flags + " --batch-size 32 --seq-len 750 --vector-len 2048 --hidden-size 2048" + commonFlags},
{flags + " --batch-size 32 --seq-len 375 --vector-len 2048 --hidden-size 2048" + commonFlags},
{flags + " --batch-size 32 --seq-len 187 --vector-len 2048 --hidden-size 2048" + commonFlags},
{flags + " --batch-size 32 --seq-len 1500 --vector-len 1536 --hidden-size 1536" + commonFlags},
{flags + " --batch-size 32 --seq-len 750 --vector-len 1536 --hidden-size 1536" + commonFlags},
{flags + " --batch-size 32 --seq-len 375 --vector-len 1536 --hidden-size 1536" + commonFlags},
{flags + " --batch-size 32 --seq-len 187 --vector-len 1536 --hidden-size 1536" + commonFlags},
{flags + " --batch-size 32 --seq-len 1500 --vector-len 2560 --hidden-size 2560" + commonFlags},
{flags + " --batch-size 32 --seq-len 750 --vector-len 2560 --hidden-size 2560" + commonFlags},
{flags + " --batch-size 32 --seq-len 375 --vector-len 2560 --hidden-size 2560" + commonFlags},
{flags + " --batch-size 32 --seq-len 187 --vector-len 2560 --hidden-size 2560" + commonFlags},
{flags + " --batch-size 32 --seq-len 1 --vector-len 512 --hidden-size 512" + commonFlags},
{flags + " --batch-size 32 --seq-len 1500 --vector-len 1024 --hidden-size 1024" + commonFlags},
{flags + " --batch-size 64 --seq-len 1500 --vector-len 1024 --hidden-size 1024" + commonFlags}
// clang-format on
};

return test_cases;
}

INSTANTIATE_TEST_SUITE_P(ConvTrans, ConfigWithFloat, testing::Values(GetTestCases()));
138 changes: 138 additions & 0 deletions test/gtest/deepbench_rnn.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2023 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include <miopen/miopen.h>
#include <gtest/gtest.h>
#include <miopen/env.hpp>
#include "../rnn_vanilla.hpp"
#include "get_handle.hpp"

MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_TEST_DEEPBENCH)

static bool SkipTest(void) { return miopen::IsDisabled(ENV(MIOPEN_TEST_DEEPBENCH)); }
junliume marked this conversation as resolved.
Show resolved Hide resolved

void GetArgs(const std::string& param, std::vector<std::string>& tokens)
{
std::stringstream ss(param);
std::istream_iterator<std::string> begin(ss);
std::istream_iterator<std::string> end;
while(begin != end)
tokens.push_back(*begin++);
}

class ConfigWithFloat : public testing::TestWithParam<std::vector<std::string>>
{
};

void Run2dDriver(miopenDataType_t prec)
{

std::vector<std::string> params;
switch(prec)
{
case miopenFloat: params = ConfigWithFloat::GetParam(); break;
case miopenHalf:
case miopenFloat8:
case miopenBFloat8:
case miopenInt8:
case miopenBFloat16:
case miopenInt32:
case miopenDouble:
FAIL() << "miopenHalf, miopenInt8, miopenBFloat16, miopenInt32, miopenDouble "
"data type not supported by "
"rnn_vanilla test";

default: params = ConfigWithFloat::GetParam();
}

for(const auto& test_value : params)
{
std::vector<std::string> tokens;
GetArgs(test_value, tokens);
std::vector<const char*> ptrs;

std::transform(tokens.begin(), tokens.end(), std::back_inserter(ptrs), [](const auto& str) {
return str.data();
});

testing::internal::CaptureStderr();
test_drive<rnn_vanilla_driver>(ptrs.size(), ptrs.data());
auto capture = testing::internal::GetCapturedStderr();
std::cout << capture;
}
};

bool IsTestSupportedForDevice(const miopen::Handle& handle)
{
std::string devName = handle.GetDeviceName();
if(devName == "gfx900" || devName == "gfx906" || devName == "gfx908" || devName == "gfx90a" ||
miopen::StartsWith(devName, "gfx94") || miopen::StartsWith(devName, "gfx103") ||
miopen::StartsWith(devName, "gfx110"))
return true;
else
return false;
junliume marked this conversation as resolved.
Show resolved Hide resolved
}

TEST_P(ConfigWithFloat, FloatTest)
{
const auto& handle = get_handle();
if(IsTestSupportedForDevice(handle) && !SkipTest())
{
Run2dDriver(miopenFloat);
}
else
{
GTEST_SKIP();
}
};

std::vector<std::string> GetTestCases(void)
{
std::string flags = " --verbose";

std::string postFlags =
"--num-layers 1 --in-mode 1 --bias-mode 0 -dir-mode 0 --rnn-mode 0 --flat-batch-fill";

const std::vector<std::string> test_cases = {
// clang-format off
{flags + " --batch-size 16 --seq-len 50 --vector-len 1760 --hidden-size 1760 " + postFlags},
{flags + " --batch-size 32 --seq-len 50 --vector-len 1760 --hidden-size 1760 " + postFlags},
{flags + " --batch-size 64 --seq-len 50 --vector-len 1760 --hidden-size 1760 " + postFlags},
{flags + " --batch-size 128 --seq-len 50 --vector-len 1760 --hidden-size 1760 " + postFlags},
{flags + " --batch-size 16 --seq-len 50 --vector-len 2048 --hidden-size 2048 " + postFlags},
{flags + " --batch-size 32 --seq-len 50 --vector-len 2048 --hidden-size 2048 " + postFlags},
{flags + " --batch-size 64 --seq-len 50 --vector-len 2048 --hidden-size 2048 " + postFlags},
{flags + " --batch-size 128 --seq-len 50 --vector-len 2048 --hidden-size 2048 " + postFlags},
{flags + " --batch-size 16 --seq-len 50 --vector-len 2560 --hidden-size 2560 " + postFlags},
{flags + " --batch-size 32 --seq-len 50 --vector-len 2560 --hidden-size 2560 " + postFlags},
{flags + " --batch-size 64 --seq-len 50 --vector-len 2560 --hidden-size 2560 " + postFlags},
{flags + " --batch-size 128 --seq-len 50 --vector-len 2560 --hidden-size 2560 " + postFlags}
// clang-format on
};

return test_cases;
}

INSTANTIATE_TEST_SUITE_P(ConvTrans, ConfigWithFloat, testing::Values(GetTestCases()));
Loading