From daa37e7e954f77006da4f2994bf544a1e9c62fce Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 15 Feb 2024 13:10:10 -0600 Subject: [PATCH] [Relax][VM] Re-implementation of callback functions (#16573) A follow-up PR from https://github.com/apache/tvm/pull/16542, after some post-merge discussion regarding the implementation. --- include/tvm/runtime/relax_vm/bytecode.h | 28 ++---------------- src/relax/backend/vm/codegen_vm.cc | 10 ++++++- src/relax/backend/vm/exec_builder.cc | 20 +++---------- src/runtime/relax_vm/bytecode.cc | 11 ------- src/runtime/relax_vm/executable.cc | 8 ----- src/runtime/relax_vm/vm.cc | 39 ++++++++----------------- 6 files changed, 28 insertions(+), 88 deletions(-) diff --git a/include/tvm/runtime/relax_vm/bytecode.h b/include/tvm/runtime/relax_vm/bytecode.h index 0db610ff42b2..4526c6fffa1d 100644 --- a/include/tvm/runtime/relax_vm/bytecode.h +++ b/include/tvm/runtime/relax_vm/bytecode.h @@ -58,7 +58,6 @@ enum class Opcode { Ret = 2U, Goto = 3U, If = 4U, - CallFromRegister = 5U, }; /*! \brief A single virtual machine instruction. @@ -184,15 +183,10 @@ struct Instruction { /*! \brief The instruction opcode. */ Opcode op; union { - struct /* Call, CallFromRegister */ { + struct /* Call */ { /*! \brief The destination register. */ RegName dst; - /*! \brief The index of the function. - * - * For `OpCode::Call`, this is an index into the table of static - * functions. For `OpCode::CallFromRegister`, this is an index - * of a register. - */ + /*! \brief The index into the packed function table. */ Index func_idx; /*! \brief The number of arguments to the packed function. */ Index num_args; @@ -214,43 +208,27 @@ struct Instruction { Index false_offset; }; }; - /*! * \brief Construct a Call instruction. - * \param func_idx The index of the function to call within the - * static function table + * \param func_idx The index of the function to call. * \param num_args The number of arguments. * \param args The input arguments. * \param dst The destination register. * \return The call instruction. */ static Instruction Call(Index func_idx, Index num_args, Arg* args, RegName dst); - - /*! - * \brief Construct a Call instruction. - * \param func_idx The index of the function to call within the - * current stack frame's registers. - * \param num_args The number of arguments. - * \param args The input arguments. - * \param dst The destination register. - * \return The call instruction. - */ - static Instruction CallFromRegister(Index func_idx, Index num_args, Arg* args, RegName dst); - /*! * \brief Construct a return instruction. * \param result The register containing the return value. * \return The return instruction. */ static Instruction Ret(RegName result); - /*! * \brief Construct a goto instruction. * \param pc_offset The register containing the jump offset. * \return The goto instruction. */ static Instruction Goto(RegName pc_offset); - /*! * \brief Construct an If instruction. * \param cond The register containing the cond value. diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 218fe6b1202c..329da67e84ec 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -391,7 +391,15 @@ class CodeGenVM : public ExprFunctor { void EmitNormalCall(const Call& call_node, RegName dst_reg) { Instruction::Arg func = VisitExpr(call_node->op); std::vector args = VisitArray(call_node->args); - builder_->EmitCall(func, args, dst_reg); + + if (func.kind() == vm::Instruction::ArgKind::kFuncIdx) { + builder_->EmitCall(func, args, dst_reg); + } else { + std::vector closure_args = { + Instruction::Arg::Register(Instruction::kVMRegister), func}; + std::copy(args.begin(), args.end(), std::back_inserter(closure_args)); + builder_->EmitCall("vm.builtin.invoke_closure", closure_args, dst_reg); + } } // Emits call to packed function `name` with arguments copied over from `call_node` args diff --git a/src/relax/backend/vm/exec_builder.cc b/src/relax/backend/vm/exec_builder.cc index aa478122353d..b5d932137be0 100644 --- a/src/relax/backend/vm/exec_builder.cc +++ b/src/relax/backend/vm/exec_builder.cc @@ -138,20 +138,10 @@ void ExecBuilderNode::EndFunction(const std::string& func_name) { void ExecBuilderNode::EmitCall(vm::Instruction::Arg func, std::vector args, vm::RegName dst) { - Opcode op_code; - if (func.kind() == vm::Instruction::ArgKind::kFuncIdx) { - op_code = Opcode::Call; - } else if (func.kind() == vm::Instruction::ArgKind::kRegister) { - op_code = Opcode::CallFromRegister; - } else { - LOG(FATAL) << "VM instruction for a function must be either " - << "kFuncIdx (static function ) " - << "or kRegister (function passed as parameter), " - << "but instead found " << func.kind(); - } + ICHECK(func.kind() == vm::Instruction::ArgKind::kFuncIdx); // store instruction exec_->instr_offset.push_back(exec_->instr_data.size()); - exec_->instr_data.push_back(static_cast(op_code)); + exec_->instr_data.push_back(static_cast(Opcode::Call)); exec_->instr_data.push_back(dst); exec_->instr_data.push_back(func.value()); exec_->instr_data.push_back(args.size()); @@ -238,8 +228,7 @@ void ExecBuilderNode::CheckExecutable() { for (size_t idx = start_instr; idx < end_instr; ++idx) { Instruction instr = exec_->GetInstruction(idx); switch (instr.op) { - case Opcode::Call: - case Opcode::CallFromRegister: { + case Opcode::Call: { check_func_defined(Instruction::Arg::FuncIdx(instr.func_idx)); for (int i = 0; i < instr.num_args; ++i) { check_reg_defined(instr.args[i]); @@ -291,8 +280,7 @@ void ExecBuilderNode::Formalize() { for (size_t idx = start_instr; idx < end_instr; ++idx) { Instruction instr = this->exec_->GetInstruction(idx); switch (instr.op) { - case Opcode::Call: - case Opcode::CallFromRegister: { + case Opcode::Call: { // rewrite args for (int i = 0; i < instr.num_args; ++i) { if (instr.args[i].kind() == Instruction::ArgKind::kRegister && diff --git a/src/runtime/relax_vm/bytecode.cc b/src/runtime/relax_vm/bytecode.cc index 30d3bebd5f33..9084207848b5 100644 --- a/src/runtime/relax_vm/bytecode.cc +++ b/src/runtime/relax_vm/bytecode.cc @@ -42,17 +42,6 @@ Instruction Instruction::Call(Index func_idx, Index num_args, Instruction::Arg* return instr; } -Instruction Instruction::CallFromRegister(Index func_idx, Index num_args, Instruction::Arg* args, - RegName dst) { - Instruction instr; - instr.op = Opcode::CallFromRegister; - instr.dst = dst; - instr.func_idx = func_idx; - instr.num_args = num_args; - instr.args = args; - return instr; -} - Instruction Instruction::Ret(RegName result) { Instruction instr; instr.op = Opcode::Ret; diff --git a/src/runtime/relax_vm/executable.cc b/src/runtime/relax_vm/executable.cc index 9de708f49a9a..f45786c3da32 100644 --- a/src/runtime/relax_vm/executable.cc +++ b/src/runtime/relax_vm/executable.cc @@ -134,14 +134,6 @@ Instruction Executable::GetInstruction(Index i) const { ExecWord* args = const_cast(&instr_data[offset + 4]); return Instruction::Call(func_idx, num_args, reinterpret_cast(args), dst); } - case Opcode::CallFromRegister: { - RegName dst = instr_data[offset + 1]; - Index func_idx = instr_data[offset + 2]; - Index num_args = instr_data[offset + 3]; - ExecWord* args = const_cast(&instr_data[offset + 4]); - return Instruction::CallFromRegister(func_idx, num_args, - reinterpret_cast(args), dst); - } case Opcode::Ret: { RegName result = instr_data[offset + 1]; return Instruction::Ret(result); diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc index 14a42df5f1e4..d7f943d5f40f 100644 --- a/src/runtime/relax_vm/vm.cc +++ b/src/runtime/relax_vm/vm.cc @@ -372,10 +372,9 @@ class VirtualMachineImpl : public VirtualMachine { /*! * \brief Run call instruction. * \param curr_frame The current frame. - * \param callable The callable object, either PackedFunc or closure * \param inst The call instruction. */ - virtual void RunInstrCall(VMFrame* curr_frame, const ObjectRef& callable, Instruction inst); + virtual void RunInstrCall(VMFrame* curr_frame, Instruction inst); /*! \brief Run VM dispatch loop. */ void RunLoop(); @@ -507,9 +506,6 @@ void VirtualMachineImpl::SetInput(std::string func_name, bool with_param_module, //------------------------------------------ void VirtualMachineImpl::InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs args, TVMRetValue* rv) { - ICHECK(closure_or_packedfunc.defined()) - << "InvokeClosurePacked requires the callable object to be defined"; - // run packed call if it is a packed func. if (auto* packed = closure_or_packedfunc.as()) { packed->CallPacked(args, rv); @@ -517,8 +513,7 @@ void VirtualMachineImpl::InvokeClosurePacked(const ObjectRef& closure_or_packedf } // run closure call. auto* clo = closure_or_packedfunc.as(); - ICHECK(clo != nullptr) << "Function expects a closure or PackedFunc, " - << "but received " << closure_or_packedfunc->GetTypeKey(); + ICHECK(clo != nullptr) << "Function expects a closure or PackedFunc "; std::vector values(args.size() + 1); std::vector tcodes(args.size() + 1); @@ -600,8 +595,6 @@ Optional VirtualMachineImpl::GetClosureInternal(const String& func_na auto impl = PackedFunc([gf_idx](TVMArgs args, TVMRetValue* rv) { // Per convention, ctx ptr is a VirtualMachine* VirtualMachine* ctx_ptr = static_cast(args[0].operator void*()); - ICHECK(ctx_ptr) << "Context pointer for relax VM closure should be a VirtualMachine*, " - << "but was NULL"; std::vector inputs(args.size() - 1); for (size_t i = 0; i < inputs.size(); ++i) { @@ -651,7 +644,7 @@ RegType VirtualMachineImpl::InvokeBytecode(Index gf_idx, const std::vectorpc_, gfunc); // Get new frame and set the caller info. VMFrame* curr_frame = frames_.back().get(); - if (curr_instr.op == Opcode::Call || curr_instr.op == Opcode::CallFromRegister) { + if (curr_instr.op == Opcode::Call) { curr_frame->caller_return_register = curr_instr.dst; } @@ -695,12 +688,8 @@ void VirtualMachineImpl::InitFuncPool() { } } -void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, const ObjectRef& callable, - Instruction instr) { - ICHECK(callable.defined()) << "RunInstrCall requires the callable object to be defined"; - auto func_name = instr.op == Opcode::Call ? GetFuncName(instr.func_idx) : ""; - - DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << func_name; +void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) { + DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << GetFuncName(instr.func_idx); int args_begin_offset = instrument_ != nullptr ? 4 : 0; // Use the call arg stack from the current frame to increase reuse // and avoid re-allocation @@ -746,11 +735,11 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, const ObjectRef& call ICHECK_LT(static_cast(instr.func_idx), this->func_pool_.size()); if (instrument_ == nullptr) { - this->InvokeClosurePacked(callable, args, &ret); + this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret); } else { // insert light-weight instrument callback - setter(0, callable); - setter(1, func_name); + setter(0, func_pool_[instr.func_idx]); + setter(1, GetFuncName(instr.func_idx)); setter(2, true); setter(3, nullptr); TVMRetValue rv; @@ -769,7 +758,7 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, const ObjectRef& call ret_kind = rv; } if (ret_kind != static_cast(VMInstrumentReturnKind::kSkipRun)) { - this->InvokeClosurePacked(callable, args, &ret); + this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret); setter(2, false); setter(3, ret); instrument_.CallPacked(TVMArgs(values.data(), tcodes.data(), values.size()), &rv); @@ -793,11 +782,7 @@ void VirtualMachineImpl::RunLoop() { Instruction instr = exec_->GetInstruction(pc_); switch (instr.op) { case Opcode::Call: { - this->RunInstrCall(curr_frame, func_pool_[instr.func_idx], instr); - break; - } - case Opcode::CallFromRegister: { - this->RunInstrCall(curr_frame, ReadRegister(curr_frame, instr.func_idx), instr); + this->RunInstrCall(curr_frame, instr); break; } case Opcode::Ret: { @@ -1015,7 +1000,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl { } protected: - void RunInstrCall(VMFrame* curr_frame, const ObjectRef& callable, Instruction inst) override { + void RunInstrCall(VMFrame* curr_frame, Instruction inst) override { bool profiling = false; if (prof_ && prof_->IsRunning()) { auto f_name = GetFuncName(inst.func_idx); @@ -1051,7 +1036,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl { } } - VirtualMachineImpl::RunInstrCall(curr_frame, callable, inst); + VirtualMachineImpl::RunInstrCall(curr_frame, inst); if (profiling) { prof_->StopCall();