diff --git a/src/register_target.cpp b/src/register_target.cpp index 77b2c00601f..946857d754e 100644 --- a/src/register_target.cpp +++ b/src/register_target.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -21,6 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ +#include #include #include #include @@ -33,46 +34,92 @@ inline namespace MIGRAPHX_INLINE_NS { void store_target_lib(const dynamic_loader& lib) { static std::vector target_loader; + static std::mutex mutex; + std::unique_lock lock(mutex); + target_loader.emplace_back(lib); } +/** + * Returns a singleton map of targets and names. + */ std::unordered_map& target_map() { static std::unordered_map m; // NOLINT return m; } +/** + * Returns a singleton mutex used by the various register_target methods. + */ +std::mutex& target_mutex() +{ + static std::mutex m; // NOLINT + return m; +} + void register_target_init() { (void)target_map(); } void unregister_target(const std::string& name) { + std::unique_lock lock(target_mutex()); assert(target_map().count(name)); target_map().erase(name); } -void register_target(const target& t) { target_map()[t.name()] = t; } +/** + * Insert a target name in the target_map; thread safe. + */ +void register_target(const target& t) +{ + std::unique_lock lock(target_mutex()); + target_map()[t.name()] = t; +} + +/** + * Search for a target by name in the target_map; thread-safe. + */ +migraphx::optional find_target(const std::string& name) +{ + // search for match or return none + std::unique_lock lock(target_mutex()); + const auto it = target_map().find(name); + + if(it == target_map().end()) + return nullopt; + return it->second; +} +/** + * Get a target by name. Load target library and register target if needed. + * Thread safe. + */ target make_target(const std::string& name) { - if(not contains(target_map(), name)) + // no lock required here + auto t = find_target(name); + if(t == nullopt) { #ifdef _WIN32 std::string target_name = "migraphx_" + name + ".dll"; #else std::string target_name = "libmigraphx_" + name + ".so"; #endif + // register_target is called by this store_target_lib(dynamic_loader(target_name)); + t = find_target(name); } - const auto it = target_map().find(name); - if(it == target_map().end()) - { - MIGRAPHX_THROW("Requested target '" + name + "' is not loaded or not supported"); - } - return it->second; + // at this point we should always have a target + + return *t; } +/** + * Get list of names of registered targets. + */ std::vector get_targets() { + std::unique_lock lock(target_mutex()); std::vector result; std::transform(target_map().begin(), target_map().end(), diff --git a/test/targets.cpp b/test/targets.cpp index 97e2d8d5f2f..b6e828837b0 100644 --- a/test/targets.cpp +++ b/test/targets.cpp @@ -23,6 +23,7 @@ */ #include #include +#include #include "test.hpp" TEST_CASE(make_target) @@ -46,4 +47,42 @@ TEST_CASE(targets) EXPECT(ts.size() >= 1); } +TEST_CASE(concurrent_targets) +{ + std::vector threads; +#ifdef HAVE_GPU + std::string target_name = "gpu"; +#elif defined(HAVE_CPU) + std::string target_name = "cpu"; +#elif defined(HAVE_FPGA) + std::string target_name = "fpga"; +#else + std::string target_name = "ref"; +#endif + + auto n_threads = std::thread::hardware_concurrency() * 4; + + for(auto i = 0u; i < n_threads; i++) + { + auto thread_body = [&target_name]() { + // TODO: remove all existing targets, if any. + // The existing code cannot pass a test in which different threads + // register and unregister the same targets; not known if this is + // needed in any deployed product. + // std::vector target_list = migraphx::get_targets(); + // for(const auto& tt : target_list) + // migraphx::unregister_target(tt); + + auto ref_target = migraphx::make_target(target_name); + migraphx::register_target(ref_target); + EXPECT(test::throws([&] { ref_target = migraphx::make_target("xyz"); })); + + migraphx::get_targets(); + }; + + threads.emplace_back(thread_body); + } + // joinable_thread don't need to have join() called. +} + int main(int argc, const char* argv[]) { test::run(argc, argv); }