diff --git a/src/mpid/ch4/netmod/ofi/ofi_impl.h b/src/mpid/ch4/netmod/ofi/ofi_impl.h
index 9917c2e9ea8..af9dd9e2d05 100644
--- a/src/mpid/ch4/netmod/ofi/ofi_impl.h
+++ b/src/mpid/ch4/netmod/ofi/ofi_impl.h
@@ -11,6 +11,9 @@
 #include "ofi_types.h"
 #include "mpidch4r.h"
 #include "ch4_impl.h"
+#ifdef MPL_HAVE_CUDA
+#include <cuda.h> /* for cuDeviceGet */
+#endif
 
 extern unsigned long long PVAR_COUNTER_nic_sent_bytes_count[MPIDI_OFI_MAX_NICS] ATTRIBUTE((unused));
 extern unsigned long long PVAR_COUNTER_nic_recvd_bytes_count[MPIDI_OFI_MAX_NICS]
@@ -707,8 +710,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_register_memory(char *send_buf, size_t da
     mr_attr.context = NULL;
     if (MPL_gpu_attr_is_strict_dev(attr)) {
 #ifdef MPL_HAVE_CUDA
+        CUdevice device;
+        int dev_id;
+
+        /* libfabric says to get the device handle from cuDeviceGet */
+        dev_id = MPL_gpu_get_dev_id_from_attr(attr);
+        cuDeviceGet(&device, dev_id);
+
         mr_attr.iface = FI_HMEM_CUDA;
-        mr_attr.device.cuda = MPL_gpu_get_dev_id_from_attr(attr);
+        mr_attr.device.cuda = device;
 #elif defined MPL_HAVE_ZE
         /* OFI does not support tiles yet, need to pass the root device. */
         mr_attr.iface = FI_HMEM_ZE;