diff --git a/src/core/EspressoSystemStandAlone.cpp b/src/core/EspressoSystemStandAlone.cpp index 90788c3c023..900fc5cdf1d 100644 --- a/src/core/EspressoSystemStandAlone.cpp +++ b/src/core/EspressoSystemStandAlone.cpp @@ -29,17 +29,18 @@ #include #include +#include #include EspressoSystemStandAlone::EspressoSystemStandAlone(int argc, char **argv) { - auto mpi_env = mpi_init(argc, argv); + m_mpi_env = mpi_init(argc, argv); boost::mpi::communicator world; head_node = world.rank() == 0; // initialize the MpiCallbacks framework - Communication::init(mpi_env); + Communication::init(m_mpi_env); // default-construct global state of the system #ifdef VIRTUAL_SITES @@ -50,6 +51,10 @@ EspressoSystemStandAlone::EspressoSystemStandAlone(int argc, char **argv) { mpi_loop(); } +EspressoSystemStandAlone::~EspressoSystemStandAlone() { + Communication::deinit(); +} + void EspressoSystemStandAlone::set_box_l(Utils::Vector3d const &box_l) const { if (!head_node) return; diff --git a/src/core/EspressoSystemStandAlone.hpp b/src/core/EspressoSystemStandAlone.hpp index d1183835659..198b235c255 100644 --- a/src/core/EspressoSystemStandAlone.hpp +++ b/src/core/EspressoSystemStandAlone.hpp @@ -21,12 +21,21 @@ #include +#include + +namespace boost { +namespace mpi { +class environment; +} +} // namespace boost + /** Manager for a stand-alone ESPResSo system. * The system is default-initialized, MPI-ready and has no script interface. */ class EspressoSystemStandAlone { public: EspressoSystemStandAlone(int argc, char **argv); + ~EspressoSystemStandAlone(); void set_box_l(Utils::Vector3d const &box_l) const; void set_node_grid(Utils::Vector3i const &node_grid) const; void set_time_step(double time_step) const; @@ -34,6 +43,7 @@ class EspressoSystemStandAlone { private: bool head_node; + std::shared_ptr m_mpi_env; }; #endif diff --git a/src/core/MpiCallbacks.hpp b/src/core/MpiCallbacks.hpp index 2f58f460d38..3c5e8f297c9 100644 --- a/src/core/MpiCallbacks.hpp +++ b/src/core/MpiCallbacks.hpp @@ -48,6 +48,7 @@ #include #include #include +#include #include #include @@ -364,8 +365,8 @@ class MpiCallbacks { template ::argument_types, std::tuple>::value>> - CallbackHandle(MpiCallbacks *cb, F &&f) - : m_id(cb->add(std::forward(f))), m_cb(cb) {} + CallbackHandle(std::shared_ptr cb, F &&f) + : m_id(cb->add(std::forward(f))), m_cb(std::move(cb)) {} CallbackHandle(CallbackHandle const &) = delete; CallbackHandle(CallbackHandle &&rhs) noexcept = default; @@ -374,7 +375,7 @@ class MpiCallbacks { private: int m_id; - MpiCallbacks *m_cb; + std::shared_ptr m_cb; public: /** @@ -400,7 +401,6 @@ class MpiCallbacks { m_cb->remove(m_id); } - MpiCallbacks *cb() const { return m_cb; } int id() const { return m_id; } }; @@ -419,8 +419,10 @@ class MpiCallbacks { public: explicit MpiCallbacks(boost::mpi::communicator comm, + std::shared_ptr mpi_env, bool abort_on_exit = true) - : m_abort_on_exit(abort_on_exit), m_comm(std::move(comm)) { + : m_abort_on_exit(abort_on_exit), m_comm(std::move(comm)), + m_mpi_env(std::move(mpi_env)) { /* Add a dummy at id 0 for loop abort. */ m_callback_map.add(nullptr); @@ -721,6 +723,10 @@ class MpiCallbacks { */ boost::mpi::communicator const &comm() const { return m_comm; } + std::shared_ptr share_mpi_env() const { + return m_mpi_env; + } + private: /** * @brief Id for the @ref abort_loop. Has to be 0. @@ -738,6 +744,11 @@ class MpiCallbacks { */ boost::mpi::communicator m_comm; + /** + * The MPI environment used for the callbacks. + */ + std::shared_ptr m_mpi_env; + /** * Internal storage for the callback functions. */ diff --git a/src/core/communication.cpp b/src/core/communication.cpp index 20b7d052637..2268e88f4db 100644 --- a/src/core/communication.cpp +++ b/src/core/communication.cpp @@ -28,6 +28,8 @@ #include #include +#include +#include #include #ifdef OPEN_MPI @@ -39,15 +41,10 @@ #include #include -namespace Communication { -auto const &mpi_datatype_cache = boost::mpi::detail::mpi_datatype_cache(); -std::shared_ptr mpi_env; -} // namespace Communication - boost::mpi::communicator comm_cart; namespace Communication { -std::unique_ptr m_callbacks; +static std::shared_ptr m_callbacks; /* We use a singleton callback class for now. */ MpiCallbacks &mpiCallbacks() { @@ -55,6 +52,12 @@ MpiCallbacks &mpiCallbacks() { return *m_callbacks; } + +std::shared_ptr mpiCallbacksHandle() { + assert(m_callbacks && "Mpi not initialized!"); + + return m_callbacks; +} } // namespace Communication using Communication::mpiCallbacks; @@ -120,8 +123,6 @@ void openmpi_global_namespace() { namespace Communication { void init(std::shared_ptr mpi_env) { - Communication::mpi_env = std::move(mpi_env); - MPI_Comm_size(MPI_COMM_WORLD, &n_nodes); node_grid = Utils::Mpi::dims_create<3>(n_nodes); @@ -131,12 +132,14 @@ void init(std::shared_ptr mpi_env) { this_node = comm_cart.rank(); Communication::m_callbacks = - std::make_unique(comm_cart); + std::make_shared(comm_cart, mpi_env); - ErrorHandling::init_error_handling(mpiCallbacks()); + ErrorHandling::init_error_handling(Communication::m_callbacks); on_program_start(); } + +void deinit() { Communication::m_callbacks.reset(); } } // namespace Communication std::shared_ptr mpi_init(int argc, char **argv) { diff --git a/src/core/communication.hpp b/src/core/communication.hpp index cab2d8507a2..4bc11ae8ac7 100644 --- a/src/core/communication.hpp +++ b/src/core/communication.hpp @@ -63,6 +63,7 @@ namespace Communication { * @brief Returns a reference to the global callback class instance. */ MpiCallbacks &mpiCallbacks(); +std::shared_ptr mpiCallbacksHandle(); } // namespace Communication /************************************************** @@ -136,12 +137,9 @@ namespace Communication { /** * @brief Init globals for communication. * - * and calls @ref on_program_start. Keeps a copy of - * the pointer to the mpi environment to keep it alive - * while the program is loaded. - * * @param mpi_env MPI environment that should be used */ void init(std::shared_ptr mpi_env); +void deinit(); } // namespace Communication #endif diff --git a/src/core/errorhandling.cpp b/src/core/errorhandling.cpp index 6a50495282f..be74ef73cd1 100644 --- a/src/core/errorhandling.cpp +++ b/src/core/errorhandling.cpp @@ -34,6 +34,7 @@ #include #include #include +#include #include namespace ErrorHandling { @@ -45,14 +46,14 @@ namespace { std::unique_ptr runtimeErrorCollector; /** The callback loop we are on. */ -Communication::MpiCallbacks *m_callbacks = nullptr; +std::weak_ptr m_callbacks; } // namespace -void init_error_handling(Communication::MpiCallbacks &cb) { - m_callbacks = &cb; +void init_error_handling(std::weak_ptr callbacks) { + m_callbacks = std::move(callbacks); runtimeErrorCollector = - std::make_unique(m_callbacks->comm()); + std::make_unique(m_callbacks.lock()->comm()); } RuntimeErrorStream _runtimeMessageStream(RuntimeError::ErrorLevel level, @@ -69,7 +70,7 @@ static void mpi_gather_runtime_errors_local() { REGISTER_CALLBACK(mpi_gather_runtime_errors_local) std::vector mpi_gather_runtime_errors() { - m_callbacks->call(mpi_gather_runtime_errors_local); + m_callbacks.lock()->call(mpi_gather_runtime_errors_local); return runtimeErrorCollector->gather(); } @@ -83,7 +84,7 @@ std::vector mpi_gather_runtime_errors_all(bool is_head_node) { } // namespace ErrorHandling void errexit() { - ErrorHandling::m_callbacks->comm().abort(1); + ErrorHandling::m_callbacks.lock()->comm().abort(1); std::abort(); } diff --git a/src/core/errorhandling.hpp b/src/core/errorhandling.hpp index 1d79a1a94e6..66a4af536c0 100644 --- a/src/core/errorhandling.hpp +++ b/src/core/errorhandling.hpp @@ -31,6 +31,7 @@ #include "error_handling/RuntimeError.hpp" #include "error_handling/RuntimeErrorStream.hpp" +#include #include #include @@ -85,7 +86,7 @@ namespace ErrorHandling { * * @param callbacks Callbacks system the error handler should be on. */ -void init_error_handling(Communication::MpiCallbacks &callbacks); +void init_error_handling(std::weak_ptr callbacks); RuntimeErrorStream _runtimeMessageStream(RuntimeError::ErrorLevel level, const std::string &file, int line, diff --git a/src/core/particle_node.cpp b/src/core/particle_node.cpp index 9cf4270d210..bd1aebea533 100644 --- a/src/core/particle_node.cpp +++ b/src/core/particle_node.cpp @@ -407,7 +407,7 @@ static int mpi_place_new_particle(int p_id, const Utils::Vector3d &pos) { void mpi_place_particle_local(int pnode, int p_id) { if (pnode == this_node) { - Utils::Vector3d pos; + Utils::Vector3d pos{}; comm_cart.recv(0, some_tag, pos); local_move_particle(p_id, pos); } diff --git a/src/core/unit_tests/EspressoSystemStandAlone_test.cpp b/src/core/unit_tests/EspressoSystemStandAlone_test.cpp index aec1c71d5c8..d1fbe0ccab4 100644 --- a/src/core/unit_tests/EspressoSystemStandAlone_test.cpp +++ b/src/core/unit_tests/EspressoSystemStandAlone_test.cpp @@ -357,5 +357,7 @@ BOOST_FIXTURE_TEST_CASE(espresso_system_stand_alone, ParticleFactory, int main(int argc, char **argv) { espresso::system = std::make_unique(argc, argv); - return boost::unit_test::unit_test_main(init_unit_test, argc, argv); + int retval = boost::unit_test::unit_test_main(init_unit_test, argc, argv); + espresso::system.reset(); + return retval; } diff --git a/src/core/unit_tests/MpiCallbacks_test.cpp b/src/core/unit_tests/MpiCallbacks_test.cpp index 3afbe3f2a63..517981d47ff 100644 --- a/src/core/unit_tests/MpiCallbacks_test.cpp +++ b/src/core/unit_tests/MpiCallbacks_test.cpp @@ -29,6 +29,7 @@ #include "MpiCallbacks.hpp" #include +#include #include #include @@ -36,6 +37,7 @@ #include #include +static std::weak_ptr mpi_env; static bool called = false; BOOST_AUTO_TEST_CASE(invoke_test) { @@ -111,7 +113,7 @@ BOOST_AUTO_TEST_CASE(callback_model_t) { BOOST_AUTO_TEST_CASE(adding_function_ptr_cb) { boost::mpi::communicator world; - Communication::MpiCallbacks cb(world); + Communication::MpiCallbacks cb(world, ::mpi_env.lock()); void (*fp)(int, const std::string &) = [](int i, const std::string &s) { BOOST_CHECK_EQUAL(537, i); @@ -143,7 +145,7 @@ BOOST_AUTO_TEST_CASE(RegisterCallback) { Communication::RegisterCallback{fp}; boost::mpi::communicator world; - Communication::MpiCallbacks cb(world); + Communication::MpiCallbacks cb(world, ::mpi_env.lock()); called = false; @@ -157,11 +159,12 @@ BOOST_AUTO_TEST_CASE(RegisterCallback) { BOOST_AUTO_TEST_CASE(CallbackHandle) { boost::mpi::communicator world; - Communication::MpiCallbacks cbs(world); + auto const cbs = + std::make_shared(world, ::mpi_env.lock()); bool m_called = false; Communication::CallbackHandle cb( - &cbs, [&m_called](std::string s) { + cbs, [&m_called](std::string s) { BOOST_CHECK_EQUAL("CallbackHandle", s); m_called = true; @@ -170,7 +173,7 @@ BOOST_AUTO_TEST_CASE(CallbackHandle) { if (0 == world.rank()) { cb(std::string("CallbackHandle")); } else { - cbs.loop(); + cbs->loop(); BOOST_CHECK(called); } } @@ -182,7 +185,7 @@ BOOST_AUTO_TEST_CASE(reduce_callback) { std::plus()); boost::mpi::communicator world; - Communication::MpiCallbacks cbs(world); + Communication::MpiCallbacks cbs(world, ::mpi_env.lock()); if (0 == world.rank()) { auto const ret = cbs.call(Communication::Result::reduction, @@ -203,7 +206,7 @@ BOOST_AUTO_TEST_CASE(ignore_callback) { Communication::MpiCallbacks::add_static(Communication::Result::ignore, fp); boost::mpi::communicator world; - Communication::MpiCallbacks cbs(world); + Communication::MpiCallbacks cbs(world, ::mpi_env.lock()); if (0 == world.rank()) { cbs.call(Communication::Result::ignore, fp); @@ -229,7 +232,7 @@ BOOST_AUTO_TEST_CASE(one_rank_callback) { Communication::MpiCallbacks::add_static(Communication::Result::one_rank, fp); boost::mpi::communicator world; - Communication::MpiCallbacks cbs(world); + Communication::MpiCallbacks cbs(world, ::mpi_env.lock()); if (0 == world.rank()) { BOOST_CHECK_EQUAL(cbs.call(Communication::Result::one_rank, fp), @@ -254,7 +257,7 @@ BOOST_AUTO_TEST_CASE(main_rank_callback) { Communication::MpiCallbacks::add_static(Communication::Result::main_rank, fp); boost::mpi::communicator world; - Communication::MpiCallbacks cbs(world); + Communication::MpiCallbacks cbs(world, ::mpi_env.lock()); if (0 == world.rank()) { BOOST_CHECK_EQUAL(cbs.call(Communication::Result::main_rank, fp), @@ -273,7 +276,7 @@ BOOST_AUTO_TEST_CASE(call_all) { Communication::MpiCallbacks::add_static(fp); boost::mpi::communicator world; - Communication::MpiCallbacks cbs(world); + Communication::MpiCallbacks cbs(world, ::mpi_env.lock()); if (0 == world.rank()) { cbs.call_all(fp); @@ -294,7 +297,7 @@ BOOST_AUTO_TEST_CASE(check_exceptions) { Communication::MpiCallbacks::add_static(fp1); boost::mpi::communicator world; - Communication::MpiCallbacks cbs(world); + Communication::MpiCallbacks cbs(world, ::mpi_env.lock()); if (0 == world.rank()) { // can't call an unregistered callback @@ -302,11 +305,13 @@ BOOST_AUTO_TEST_CASE(check_exceptions) { } else { // can't call a callback from worker nodes BOOST_CHECK_THROW(cbs.call(fp1), std::logic_error); + cbs.loop(); } } int main(int argc, char **argv) { - boost::mpi::environment mpi_env(argc, argv); + auto const mpi_env = std::make_shared(argc, argv); + ::mpi_env = mpi_env; return boost::unit_test::unit_test_main(init_unit_test, argc, argv); } diff --git a/src/core/unit_tests/Verlet_list_test.cpp b/src/core/unit_tests/Verlet_list_test.cpp index 43368396528..43947c50cd8 100644 --- a/src/core/unit_tests/Verlet_list_test.cpp +++ b/src/core/unit_tests/Verlet_list_test.cpp @@ -241,6 +241,7 @@ int main(int argc, char **argv) { if (world.size() == 4) { error_code = boost::unit_test::unit_test_main(init_unit_test, argc, argv); } + espresso::system.reset(); return error_code; } #else // ifdef LENNARD_JONES diff --git a/src/python/espressomd/_init.pyx b/src/python/espressomd/_init.pyx index 0ef6ee52f99..0b96999140d 100644 --- a/src/python/espressomd/_init.pyx +++ b/src/python/espressomd/_init.pyx @@ -17,6 +17,7 @@ # along with this program. If not, see . # import sys +import atexit from . cimport script_interface from . cimport communication from libcpp.memory cimport shared_ptr @@ -28,7 +29,17 @@ communication.init(mpi_env) # Initialize script interface # Has to be _after_ mpi_init -script_interface.init(communication.mpiCallbacks()) +script_interface.init(communication.mpiCallbacksHandle()) + + +def session_shutdown(): + mpi_env.reset() + communication.deinit() + script_interface.deinit() + + +atexit.register(session_shutdown) + # Block the worker nodes in the callback loop. # The head node is just returning to the user script. diff --git a/src/python/espressomd/communication.pxd b/src/python/espressomd/communication.pxd index aeb4f43adf1..7065f96a9d7 100644 --- a/src/python/espressomd/communication.pxd +++ b/src/python/espressomd/communication.pxd @@ -30,4 +30,6 @@ cdef extern from "communication.hpp": cdef extern from "communication.hpp" namespace "Communication": MpiCallbacks & mpiCallbacks() + shared_ptr[MpiCallbacks] mpiCallbacksHandle() void init(shared_ptr[environment]) + void deinit() diff --git a/src/python/espressomd/script_interface.pxd b/src/python/espressomd/script_interface.pxd index bd2b0c2860b..5e9bd38ddf2 100644 --- a/src/python/espressomd/script_interface.pxd +++ b/src/python/espressomd/script_interface.pxd @@ -58,7 +58,7 @@ cdef extern from "script_interface/ContextManager.hpp" namespace "ScriptInterfac cdef extern from "script_interface/ContextManager.hpp" namespace "ScriptInterface": cdef cppclass ContextManager: - ContextManager(MpiCallbacks & , const Factory[ObjectHandle] & ) + ContextManager(const shared_ptr[MpiCallbacks] & , const Factory[ObjectHandle] & ) shared_ptr[ObjectHandle] make_shared(CreationPolicy, const string &, const VariantMap) except + shared_ptr[ObjectHandle] deserialize(const string &) except + string serialize(const ObjectHandle *) except + @@ -69,4 +69,5 @@ cdef extern from "script_interface/initialize.hpp" namespace "ScriptInterface": cdef extern from "script_interface/get_value.hpp" namespace "ScriptInterface": T get_value[T](const Variant T) -cdef void init(MpiCallbacks &) +cdef void init(const shared_ptr[MpiCallbacks] &) +cdef void deinit() diff --git a/src/python/espressomd/script_interface.pyx b/src/python/espressomd/script_interface.pyx index d6fff6879d1..aed9f781d35 100644 --- a/src/python/espressomd/script_interface.pyx +++ b/src/python/espressomd/script_interface.pyx @@ -484,10 +484,15 @@ def script_interface_register(c): return c -cdef void init(MpiCallbacks & cb): +cdef void init(const shared_ptr[MpiCallbacks] & cb): cdef Factory[ObjectHandle] f initialize(& f) global _om _om = make_shared[ContextManager](cb, f) + + +cdef void deinit(): + global _om + _om.reset() diff --git a/src/python/espressomd/system.pyx b/src/python/espressomd/system.pyx index 3bb08f381a5..3003f66c913 100644 --- a/src/python/espressomd/system.pyx +++ b/src/python/espressomd/system.pyx @@ -202,10 +202,20 @@ cdef class System: self._active_virtual_sites_handle = virtual_sites.ActiveVirtualSitesHandle( implementation=virtual_sites.VirtualSitesOff()) + self._setup_atexit() + # lock class global _system_created _system_created = True + def _setup_atexit(self): + import atexit + + def session_shutdown(): + self.actors.clear() + + atexit.register(session_shutdown) + # __getstate__ and __setstate__ define the pickle interaction def __getstate__(self): checkpointable_properties = ["_box_geo", "integrator"] @@ -230,6 +240,7 @@ cdef class System: def __setstate__(self, params): for property_name in params.keys(): System.__setattr__(self, property_name, params[property_name]) + self._setup_atexit() property box_l: """ diff --git a/src/script_interface/ContextManager.cpp b/src/script_interface/ContextManager.cpp index 7bb3dfd37c1..e7f20c3475c 100644 --- a/src/script_interface/ContextManager.cpp +++ b/src/script_interface/ContextManager.cpp @@ -53,15 +53,16 @@ std::string ContextManager::serialize(const ObjectHandle *o) const { return Utils::pack(std::make_pair(policy(ctx), o->serialize())); } -ContextManager::ContextManager(Communication::MpiCallbacks &callbacks, - const Utils::Factory &factory) { +ContextManager::ContextManager( + std::shared_ptr const &callbacks, + const Utils::Factory &factory) { auto local_context = - std::make_shared(factory, callbacks.comm()); + std::make_shared(factory, callbacks->comm()); /* If there is only one node, we can treat all objects as local, and thus * never invoke any callback. */ m_global_context = - (callbacks.comm().size() > 1) + (callbacks->comm().size() > 1) ? std::make_shared(callbacks, local_context) : std::static_pointer_cast(local_context); diff --git a/src/script_interface/ContextManager.hpp b/src/script_interface/ContextManager.hpp index 7f5a098a590..47fa10a73b2 100644 --- a/src/script_interface/ContextManager.hpp +++ b/src/script_interface/ContextManager.hpp @@ -67,7 +67,7 @@ class ContextManager { GLOBAL }; - ContextManager(Communication::MpiCallbacks &callbacks, + ContextManager(std::shared_ptr const &callbacks, const Utils::Factory &factory); /** diff --git a/src/script_interface/GlobalContext.hpp b/src/script_interface/GlobalContext.hpp index d4de6ebbc02..ede91335991 100644 --- a/src/script_interface/GlobalContext.hpp +++ b/src/script_interface/GlobalContext.hpp @@ -86,28 +86,28 @@ class GlobalContext : public Context { Communication::CallbackHandle cb_delete_handle; public: - GlobalContext(Communication::MpiCallbacks &callbacks, + GlobalContext(std::shared_ptr const &callbacks, std::shared_ptr node_local_context) : m_local_objects(), m_node_local_context(std::move(node_local_context)), - m_is_head_node(callbacks.comm().rank() == 0), + m_is_head_node(callbacks->comm().rank() == 0), // NOLINTNEXTLINE(bugprone-throw-keyword-missing) - m_parallel_exception_handler(callbacks.comm()), - cb_make_handle(&callbacks, + m_parallel_exception_handler(callbacks->comm()), + cb_make_handle(callbacks, [this](ObjectId id, const std::string &name, const PackedMap ¶meters) { make_handle(id, name, parameters); }), - cb_set_parameter(&callbacks, + cb_set_parameter(callbacks, [this](ObjectId id, std::string const &name, PackedVariant const &value) { set_parameter(id, name, value); }), - cb_call_method(&callbacks, + cb_call_method(callbacks, [this](ObjectId id, std::string const &name, PackedMap const &arguments) { call_method(id, name, arguments); }), - cb_delete_handle(&callbacks, + cb_delete_handle(callbacks, [this](ObjectId id) { delete_handle(id); }) {} private: diff --git a/src/script_interface/electrostatics/CoulombScafacos.hpp b/src/script_interface/electrostatics/CoulombScafacos.hpp index 57c7ed059f3..94b1089726a 100644 --- a/src/script_interface/electrostatics/CoulombScafacos.hpp +++ b/src/script_interface/electrostatics/CoulombScafacos.hpp @@ -26,12 +26,16 @@ #include "Actor.hpp" +#include "core/MpiCallbacks.hpp" +#include "core/communication.hpp" #include "core/electrostatics/scafacos.hpp" #include "core/scafacos/ScafacosContextBase.hpp" #include "script_interface/get_value.hpp" #include "script_interface/scafacos/scafacos.hpp" +#include +#include #include #include #include @@ -42,6 +46,8 @@ namespace ScriptInterface { namespace Coulomb { class CoulombScafacos : public Actor { + std::shared_ptr m_mpi_env_lock; + public: CoulombScafacos() { add_parameters({ @@ -77,6 +83,11 @@ class CoulombScafacos : public Actor { }); } + ~CoulombScafacos() override { + m_actor.reset(); + m_mpi_env_lock.reset(); + } + void do_construct(VariantMap const ¶ms) override { auto const method_name = get_value(params, "method_name"); auto const param_list = params.at("method_params"); @@ -89,6 +100,8 @@ class CoulombScafacos : public Actor { actor()->set_prefactor(prefactor); }); set_charge_neutrality_tolerance(params); + // MPI communicator is needed to destroy the FFT plans + m_mpi_env_lock = ::Communication::mpiCallbacksHandle()->share_mpi_env(); } Variant do_call_method(std::string const &name, diff --git a/src/script_interface/h5md/h5md.hpp b/src/script_interface/h5md/h5md.hpp index 21e16bba28b..82a0f2ef65e 100644 --- a/src/script_interface/h5md/h5md.hpp +++ b/src/script_interface/h5md/h5md.hpp @@ -26,6 +26,8 @@ #ifdef H5MD +#include "core/MpiCallbacks.hpp" +#include "core/communication.hpp" #include "io/writer/h5md_core.hpp" #include "script_interface/ScriptInterface.hpp" @@ -66,8 +68,16 @@ class H5md : public AutoParameters { params, "file_path", "script_path", "fields", "mass_unit", "length_unit", "time_unit", "force_unit", "velocity_unit", "charge_unit"); + // MPI communicator is needed to close parallel file handles + m_mpi_env_lock = ::Communication::mpiCallbacksHandle()->share_mpi_env(); } + ~H5md() override { + m_h5md.reset(); + m_mpi_env_lock.reset(); + } + + std::shared_ptr m_mpi_env_lock; std::shared_ptr<::Writer::H5md::File> m_h5md; std::vector m_output_fields; }; diff --git a/src/script_interface/magnetostatics/DipolarScafacos.hpp b/src/script_interface/magnetostatics/DipolarScafacos.hpp index ef53205ef87..61c99cd5590 100644 --- a/src/script_interface/magnetostatics/DipolarScafacos.hpp +++ b/src/script_interface/magnetostatics/DipolarScafacos.hpp @@ -26,11 +26,15 @@ #include "Actor.hpp" +#include "core/MpiCallbacks.hpp" +#include "core/communication.hpp" #include "core/magnetostatics/scafacos.hpp" #include "core/scafacos/ScafacosContextBase.hpp" #include "script_interface/scafacos/scafacos.hpp" +#include +#include #include #include @@ -38,6 +42,7 @@ namespace ScriptInterface { namespace Dipoles { class DipolarScafacos : public Actor { + std::shared_ptr m_mpi_env_lock; public: DipolarScafacos() { @@ -51,6 +56,11 @@ class DipolarScafacos : public Actor { }); } + ~DipolarScafacos() override { + m_actor.reset(); + m_mpi_env_lock.reset(); + } + void do_construct(VariantMap const ¶ms) override { auto const method_name = get_value(params, "method_name"); auto const param_list = params.at("method_params"); @@ -65,6 +75,8 @@ class DipolarScafacos : public Actor { m_actor = make_dipolar_scafacos(method_name, method_params); actor()->prefactor = prefactor; }); + // MPI communicator is needed to destroy the FFT plans + m_mpi_env_lock = ::Communication::mpiCallbacksHandle()->share_mpi_env(); } Variant do_call_method(std::string const &name, diff --git a/src/script_interface/tests/GlobalContext_test.cpp b/src/script_interface/tests/GlobalContext_test.cpp index b5556b2d18e..da0401ad413 100644 --- a/src/script_interface/tests/GlobalContext_test.cpp +++ b/src/script_interface/tests/GlobalContext_test.cpp @@ -26,12 +26,15 @@ #include #include +#include #include #include #include #include +static std::weak_ptr mpi_env; + namespace si = ScriptInterface; struct Dummy : si::ObjectHandle { @@ -54,18 +57,18 @@ struct Dummy : si::ObjectHandle { } }; -auto make_global_context(Communication::MpiCallbacks &cb) { +auto make_global_context(std::shared_ptr &cb) { Utils::Factory factory; factory.register_new("Dummy"); - boost::mpi::communicator comm; return std::make_shared( - cb, std::make_shared(factory, comm)); + cb, std::make_shared(factory, cb->comm())); } BOOST_AUTO_TEST_CASE(GlobalContext_make_shared) { boost::mpi::communicator world; - Communication::MpiCallbacks cb{world}; + auto cb = + std::make_shared(world, ::mpi_env.lock()); auto ctx = make_global_context(cb); if (world.rank() == 0) { @@ -74,13 +77,14 @@ BOOST_AUTO_TEST_CASE(GlobalContext_make_shared) { BOOST_CHECK_EQUAL(res->context(), ctx.get()); BOOST_CHECK_EQUAL(ctx->name(res.get()), "Dummy"); } else { - cb.loop(); + cb->loop(); } } BOOST_AUTO_TEST_CASE(GlobalContext_serialization) { boost::mpi::communicator world; - Communication::MpiCallbacks cb{world}; + auto cb = + std::make_shared(world, ::mpi_env.lock()); auto ctx = make_global_context(cb); if (world.rank() == 0) { @@ -108,12 +112,13 @@ BOOST_AUTO_TEST_CASE(GlobalContext_serialization) { BOOST_REQUIRE(d3); BOOST_CHECK_EQUAL(boost::get(d3->get_parameter("id")), 3); } else { - cb.loop(); + cb->loop(); } } int main(int argc, char **argv) { - boost::mpi::environment mpi_env(argc, argv); + auto const mpi_env = std::make_shared(argc, argv); + ::mpi_env = mpi_env; return boost::unit_test::unit_test_main(init_unit_test, argc, argv); } diff --git a/src/script_interface/tests/ParallelExceptionHandler_test.cpp b/src/script_interface/tests/ParallelExceptionHandler_test.cpp index c78c11c0cf0..8a475c6a3c1 100644 --- a/src/script_interface/tests/ParallelExceptionHandler_test.cpp +++ b/src/script_interface/tests/ParallelExceptionHandler_test.cpp @@ -45,6 +45,7 @@ #include #include +#include #include #include @@ -52,6 +53,8 @@ namespace utf = boost::unit_test; +static std::weak_ptr mpi_env; + namespace Testing { struct Error : public std::exception {}; struct Warning : public std::exception {}; @@ -68,7 +71,8 @@ struct if_parallel_test { BOOST_TEST_DECORATOR(*utf::precondition(if_parallel_test())) BOOST_AUTO_TEST_CASE(parallel_exceptions) { boost::mpi::communicator world; - Communication::MpiCallbacks callbacks{world}; + auto callbacks = + std::make_shared(world, ::mpi_env.lock()); ErrorHandling::init_error_handling(callbacks); auto handler = ScriptInterface::ParallelExceptionHandler{world}; @@ -130,10 +134,14 @@ BOOST_AUTO_TEST_CASE(parallel_exceptions) { // runtime error messages are printed to stderr and cleared BOOST_CHECK_EQUAL(check_runtime_errors_local(), 0); } + if (world.rank() != 0) { + callbacks->loop(); + } } int main(int argc, char **argv) { - boost::mpi::environment mpi_env(argc, argv); + auto const mpi_env = std::make_shared(argc, argv); + ::mpi_env = mpi_env; return boost::unit_test::unit_test_main(init_unit_test, argc, argv); }