diff --git a/source/op/pt/comm.cc b/source/op/pt/comm.cc index 98141cdde8..16e15ef281 100644 --- a/source/op/pt/comm.cc +++ b/source/op/pt/comm.cc @@ -221,7 +221,8 @@ class Border : public torch::autograd::Function { static void unpack_communicator(const torch::Tensor& communicator_tensor, MPI_Comm& mpi_comm) { long int* communicator = communicator_tensor.data_ptr(); - mpi_comm = reinterpret_cast(*communicator); + int* int_ptr = reinterpret_cast(communicator);//in order to solve mpich type, may cause error + mpi_comm = reinterpret_cast(*int_ptr); } #endif };