diff --git a/test/gtest/common/mem_buffer.cc b/test/gtest/common/mem_buffer.cc
index 5b5222b768c..7bc8afb713a 100644
--- a/test/gtest/common/mem_buffer.cc
+++ b/test/gtest/common/mem_buffer.cc
@@ -169,7 +169,7 @@ bool mem_buffer::is_mem_type_supported(ucs_memory_type_t mem_type)
            mem_types.end();
 }
 
-void mem_buffer::set_device_context()
+void mem_buffer::set_device_context(int device)
 {
     static __thread bool device_set = false;
 
@@ -179,7 +179,7 @@ void mem_buffer::set_device_context()
 
 #if HAVE_CUDA
     if (is_cuda_supported()) {
-        cudaSetDevice(0);
+        cudaSetDevice(device);
         /* need to call free as context maybe lazily initialized when calling
          * cudaSetDevice(0) but calling cudaFree(0) should guarantee context
          * creation upon return */
@@ -189,7 +189,7 @@ void mem_buffer::set_device_context()
 
 #if HAVE_ROCM
     if (is_rocm_supported()) {
-        hipSetDevice(0);
+        hipSetDevice(device);
     }
 #endif
 
diff --git a/test/gtest/common/mem_buffer.h b/test/gtest/common/mem_buffer.h
index 4b1c285b2b8..9c45c8466e2 100644
--- a/test/gtest/common/mem_buffer.h
+++ b/test/gtest/common/mem_buffer.h
@@ -86,7 +86,7 @@ class mem_buffer {
     static bool is_gpu_supported();
 
     /* set device context if compiled with GPU support */
-    static void set_device_context();
+    static void set_device_context(int device = 0);
 
     /* returns whether ROCM device supports managed memory */
     static bool is_rocm_managed_supported();
diff --git a/test/gtest/ucp/test_ucp_mmap.cc b/test/gtest/ucp/test_ucp_mmap.cc
index 21b6cc3d2da..f9027258713 100644
--- a/test/gtest/ucp/test_ucp_mmap.cc
+++ b/test/gtest/ucp/test_ucp_mmap.cc
@@ -17,6 +17,10 @@
 extern "C" {
 #include <ucp/core/ucp_mm.h>
 }
 
+#if HAVE_CUDA
+#include <cuda_runtime.h>
+#endif
+
 #include <common/mem_buffer.h>
 #include <common/test_helpers.h>
@@ -1248,3 +1252,51 @@ UCS_TEST_P(test_ucp_mmap_export, export_import) {
 }
 
 UCP_INSTANTIATE_TEST_CASE_GPU_AWARE(test_ucp_mmap_export)
+
+#if HAVE_CUDA
+class test_ucp_mmap_mgpu : public ucs::test {
+};
+
+UCS_TEST_F(test_ucp_mmap_mgpu, switch_gpu) {
+    if (!mem_buffer::is_mem_type_supported(UCS_MEMORY_TYPE_CUDA)) {
+        UCS_TEST_SKIP_R("cuda is not supported");
+    }
+
+    int num_devices;
+    ASSERT_EQ(cudaGetDeviceCount(&num_devices), cudaSuccess);
+
+    if (num_devices < 2) {
+        UCS_TEST_SKIP_R("less than two cuda devices available");
+    }
+
+    ucs::handle<ucp_config_t*> config;
+    UCS_TEST_CREATE_HANDLE(ucp_config_t*, config, ucp_config_release,
+                           ucp_config_read, NULL, NULL);
+
+    ucs::handle<ucp_context_h> context;
+    ucp_params_t params;
+    params.field_mask = UCP_PARAM_FIELD_FEATURES;
+    params.features   = UCP_FEATURE_TAG;
+    UCS_TEST_CREATE_HANDLE(ucp_context_h, context, ucp_cleanup, ucp_init,
+                           &params, config.get());
+
+    int device;
+    ASSERT_EQ(cudaGetDevice(&device), cudaSuccess);
+    ASSERT_EQ(cudaSetDevice((device + 1) % num_devices), cudaSuccess);
+
+    const size_t size = 16;
+    mem_buffer buffer(size, UCS_MEMORY_TYPE_CUDA);
+
+    ASSERT_EQ(cudaSetDevice(device), cudaSuccess);
+
+    ucp_mem_map_params_t mem_map_params;
+    mem_map_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
+                                UCP_MEM_MAP_PARAM_FIELD_LENGTH;
+    mem_map_params.address    = buffer.ptr();
+    mem_map_params.length     = size;
+
+    ucp_mem_h ucp_mem;
+    ASSERT_EQ(ucp_mem_map(context.get(), &mem_map_params, &ucp_mem), UCS_OK);
+    EXPECT_EQ(ucp_mem_unmap(context.get(), ucp_mem), UCS_OK);
+}
+#endif