diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc
index 2c431cdb643c..a0c732a9c845 100644
--- a/src/runtime/rpc/rpc_endpoint.cc
+++ b/src/runtime/rpc/rpc_endpoint.cc
@@ -258,8 +258,12 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
     if (type_index == kRuntimeRPCObjectRefTypeIndex) {
       uint64_t handle;
       this->template Read<uint64_t>(&handle);
-      tcode[0] = kTVMObjectHandle;
-      value[0].v_handle = reinterpret_cast<void*>(handle);
+      // Always wrap things back in RPCObjectRef
+      // this is because we want to enable multi-hop RPC
+      // and the next hop would also need to check the object index
+      RPCObjectRef rpc_obj(make_object<RPCObjectRefObj>(reinterpret_cast<void*>(handle), nullptr));
+      TVMArgsSetter(value, tcode)(0, rpc_obj);
+      object_arena_.push_back(rpc_obj);
     } else {
       LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: "
                  << Object::TypeIndex2Key(type_index) << " (type_index = " << type_index << ")";
@@ -276,6 +280,12 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
     return arena_.template allocate_<T>(count);
   }
 
+  /*! \brief Recycle all the memory used in the arena */
+  void RecycleAll() {
+    this->object_arena_.clear();
+    this->arena_.RecycleAll();
+  }
+
  protected:
   enum State {
     kInitHeader,
@@ -296,6 +306,8 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
   bool async_server_mode_{false};
   // Internal arena
   support::Arena arena_;
+  // internal arena for temp objects
+  std::vector<RPCObjectRef> object_arena_;
 
   // State switcher
   void SwitchToState(State state) {
@@ -313,7 +325,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
     if (state == kRecvPacketNumBytes) {
       this->RequestBytes(sizeof(uint64_t));
       // recycle arena for the next session.
-      arena_.RecycleAll();
+      this->RecycleAll();
     }
   }
 
diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h
index b09900d0abaa..f01b571b2599 100644
--- a/src/runtime/rpc/rpc_session.h
+++ b/src/runtime/rpc/rpc_session.h
@@ -295,13 +295,16 @@ class RPCObjectRefObj : public Object {
   /*!
   * \brief constructor
   * \param object_handle handle that points to the remote object
-  * \param sess The remote session
+  *
+  * \param sess The remote session; when the session is nullptr,
+  *        it indicates the object is a temp object during rpc transmission
+  *        and we don't have to free it
   */
  RPCObjectRefObj(void* object_handle, std::shared_ptr<RPCSession> sess)
      : object_handle_(object_handle), sess_(sess) {}
 
  ~RPCObjectRefObj() {
-   if (object_handle_ != nullptr) {
+   if (object_handle_ != nullptr && sess_ != nullptr) {
     try {
       sess_->FreeHandle(object_handle_, kTVMObjectHandle);
     } catch (const Error& e) {
diff --git a/tests/python/runtime/test_runtime_rpc.py b/tests/python/runtime/test_runtime_rpc.py
index fff203df0051..2cdbb248cfd9 100644
--- a/tests/python/runtime/test_runtime_rpc.py
+++ b/tests/python/runtime/test_runtime_rpc.py
@@ -449,10 +449,15 @@ def check(client, is_local):
         assert get_size(shape) == 2
 
     # start server
-    server = rpc.Server(key="x1")
-    client = rpc.connect("127.0.0.1", server.port, key="x1")
+    check(rpc.LocalSession(), True)
 
-    check(client, False)
+
+    def check_remote():
+        server = rpc.Server(key="x1")
+        client = rpc.connect("127.0.0.1", server.port, key="x1")
+        check(client, False)
+
+    check_remote()
     def check_minrpc():
         if tvm.get_global_func("rpc.CreatePipeClient", allow_missing=True) is None:
             return
@@ -462,6 +467,14 @@ def check_minrpc():
         minrpc_exec = temp.relpath("minrpc")
         tvm.rpc.with_minrpc(cc.create_executable)(minrpc_exec, [])
         check(rpc.PopenSession(minrpc_exec), False)
+        # minrpc on the remote
+        server = rpc.Server()
+        client = rpc.connect(
+            "127.0.0.1",
+            server.port,
+            session_constructor_args=["rpc.PopenSession", open(minrpc_exec, "rb").read()],
+        )
+        check(client, False)
 
     check_minrpc()
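For context, here is a minimal sketch of how the re-wrapping in rpc_endpoint.cc could be exercised end to end over two RPC hops. It is not part of the patch: the two-server forwarding pattern mirrors the existing session_constructor_args/"rpc.Connect" usage elsewhere in test_runtime_rpc.py, and the global function names "runtime.ShapeTuple" and "runtime.GetShapeTupleSize", the server keys, and check_two_hop itself are assumptions for illustration only.

    # Hypothetical two-hop check (sketch, not part of the patch):
    # client -> outer server -> inner server; the RPCObjectRef returned from the
    # inner hop is re-wrapped at the outer hop by the rpc_endpoint.cc change above.
    import tvm
    from tvm import rpc

    def check_two_hop():
        inner = rpc.Server(key="inner")   # final hop
        outer = rpc.Server(key="outer")   # intermediate hop
        client = rpc.connect(
            "127.0.0.1",
            outer.port,
            key="outer",
            # ask the outer server to open a session to the inner server,
            # so every call below traverses two RPC links
            session_constructor_args=["rpc.Connect", "127.0.0.1", inner.port, "inner", False],
        )
        make_shape = client.get_function("runtime.ShapeTuple")       # assumed remote constructor
        get_size = client.get_function("runtime.GetShapeTupleSize")  # assumed remote accessor
        shape = make_shape(2, 3)  # held locally as an RPCObjectRef
        # passing the ref back through both hops still resolves to the same remote object
        assert get_size(shape) == 2

    check_two_hop()

The key point the sketch illustrates is that the intermediate hop no longer hands out a bare kTVMObjectHandle but a temp RPCObjectRef (with a nullptr session, so its destructor does not try to free the handle), which is kept alive in object_arena_ until RecycleAll() runs for the next packet.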