Skip to content

Commit

Permalink
[Relax][VM] Re-implementation of callback functions (#16573)
Browse files Browse the repository at this point in the history
A follow-up PR from #16542, after
some post-merge discussion regarding the implementation.
  • Loading branch information
Lunderberg authored Feb 15, 2024
1 parent 67bd739 commit daa37e7
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 88 deletions.
28 changes: 3 additions & 25 deletions include/tvm/runtime/relax_vm/bytecode.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ enum class Opcode {
Ret = 2U,
Goto = 3U,
If = 4U,
CallFromRegister = 5U,
};

/*! \brief A single virtual machine instruction.
Expand Down Expand Up @@ -184,15 +183,10 @@ struct Instruction {
/*! \brief The instruction opcode. */
Opcode op;
union {
struct /* Call, CallFromRegister */ {
struct /* Call */ {
/*! \brief The destination register. */
RegName dst;
/*! \brief The index of the function.
*
* For `OpCode::Call`, this is an index into the table of static
* functions. For `OpCode::CallFromRegister`, this is an index
* of a register.
*/
/*! \brief The index into the packed function table. */
Index func_idx;
/*! \brief The number of arguments to the packed function. */
Index num_args;
Expand All @@ -214,43 +208,27 @@ struct Instruction {
Index false_offset;
};
};

/*!
* \brief Construct a Call instruction.
* \param func_idx The index of the function to call within the
* static function table
* \param func_idx The index of the function to call.
* \param num_args The number of arguments.
* \param args The input arguments.
* \param dst The destination register.
* \return The call instruction.
*/
static Instruction Call(Index func_idx, Index num_args, Arg* args, RegName dst);

/*!
* \brief Construct a Call instruction.
* \param func_idx The index of the function to call within the
* current stack frame's registers.
* \param num_args The number of arguments.
* \param args The input arguments.
* \param dst The destination register.
* \return The call instruction.
*/
static Instruction CallFromRegister(Index func_idx, Index num_args, Arg* args, RegName dst);

/*!
* \brief Construct a return instruction.
* \param result The register containing the return value.
* \return The return instruction.
*/
static Instruction Ret(RegName result);

/*!
* \brief Construct a goto instruction.
* \param pc_offset The register containing the jump offset.
* \return The goto instruction.
*/
static Instruction Goto(RegName pc_offset);

/*!
* \brief Construct an If instruction.
* \param cond The register containing the cond value.
Expand Down
10 changes: 9 additions & 1 deletion src/relax/backend/vm/codegen_vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,15 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
void EmitNormalCall(const Call& call_node, RegName dst_reg) {
Instruction::Arg func = VisitExpr(call_node->op);
std::vector<Instruction::Arg> args = VisitArray(call_node->args);
builder_->EmitCall(func, args, dst_reg);

if (func.kind() == vm::Instruction::ArgKind::kFuncIdx) {
builder_->EmitCall(func, args, dst_reg);
} else {
std::vector<Instruction::Arg> closure_args = {
Instruction::Arg::Register(Instruction::kVMRegister), func};
std::copy(args.begin(), args.end(), std::back_inserter(closure_args));
builder_->EmitCall("vm.builtin.invoke_closure", closure_args, dst_reg);
}
}

// Emits call to packed function `name` with arguments copied over from `call_node` args
Expand Down
20 changes: 4 additions & 16 deletions src/relax/backend/vm/exec_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,20 +138,10 @@ void ExecBuilderNode::EndFunction(const std::string& func_name) {

void ExecBuilderNode::EmitCall(vm::Instruction::Arg func, std::vector<vm::Instruction::Arg> args,
vm::RegName dst) {
Opcode op_code;
if (func.kind() == vm::Instruction::ArgKind::kFuncIdx) {
op_code = Opcode::Call;
} else if (func.kind() == vm::Instruction::ArgKind::kRegister) {
op_code = Opcode::CallFromRegister;
} else {
LOG(FATAL) << "VM instruction for a function must be either "
<< "kFuncIdx (static function ) "
<< "or kRegister (function passed as parameter), "
<< "but instead found " << func.kind();
}
ICHECK(func.kind() == vm::Instruction::ArgKind::kFuncIdx);
// store instruction
exec_->instr_offset.push_back(exec_->instr_data.size());
exec_->instr_data.push_back(static_cast<ExecWord>(op_code));
exec_->instr_data.push_back(static_cast<ExecWord>(Opcode::Call));
exec_->instr_data.push_back(dst);
exec_->instr_data.push_back(func.value());
exec_->instr_data.push_back(args.size());
Expand Down Expand Up @@ -238,8 +228,7 @@ void ExecBuilderNode::CheckExecutable() {
for (size_t idx = start_instr; idx < end_instr; ++idx) {
Instruction instr = exec_->GetInstruction(idx);
switch (instr.op) {
case Opcode::Call:
case Opcode::CallFromRegister: {
case Opcode::Call: {
check_func_defined(Instruction::Arg::FuncIdx(instr.func_idx));
for (int i = 0; i < instr.num_args; ++i) {
check_reg_defined(instr.args[i]);
Expand Down Expand Up @@ -291,8 +280,7 @@ void ExecBuilderNode::Formalize() {
for (size_t idx = start_instr; idx < end_instr; ++idx) {
Instruction instr = this->exec_->GetInstruction(idx);
switch (instr.op) {
case Opcode::Call:
case Opcode::CallFromRegister: {
case Opcode::Call: {
// rewrite args
for (int i = 0; i < instr.num_args; ++i) {
if (instr.args[i].kind() == Instruction::ArgKind::kRegister &&
Expand Down
11 changes: 0 additions & 11 deletions src/runtime/relax_vm/bytecode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,6 @@ Instruction Instruction::Call(Index func_idx, Index num_args, Instruction::Arg*
return instr;
}

Instruction Instruction::CallFromRegister(Index func_idx, Index num_args, Instruction::Arg* args,
RegName dst) {
Instruction instr;
instr.op = Opcode::CallFromRegister;
instr.dst = dst;
instr.func_idx = func_idx;
instr.num_args = num_args;
instr.args = args;
return instr;
}

Instruction Instruction::Ret(RegName result) {
Instruction instr;
instr.op = Opcode::Ret;
Expand Down
8 changes: 0 additions & 8 deletions src/runtime/relax_vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,6 @@ Instruction Executable::GetInstruction(Index i) const {
ExecWord* args = const_cast<ExecWord*>(&instr_data[offset + 4]);
return Instruction::Call(func_idx, num_args, reinterpret_cast<Instruction::Arg*>(args), dst);
}
case Opcode::CallFromRegister: {
RegName dst = instr_data[offset + 1];
Index func_idx = instr_data[offset + 2];
Index num_args = instr_data[offset + 3];
ExecWord* args = const_cast<ExecWord*>(&instr_data[offset + 4]);
return Instruction::CallFromRegister(func_idx, num_args,
reinterpret_cast<Instruction::Arg*>(args), dst);
}
case Opcode::Ret: {
RegName result = instr_data[offset + 1];
return Instruction::Ret(result);
Expand Down
39 changes: 12 additions & 27 deletions src/runtime/relax_vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,9 @@ class VirtualMachineImpl : public VirtualMachine {
/*!
* \brief Run call instruction.
* \param curr_frame The current frame.
* \param callable The callable object, either PackedFunc or closure
* \param inst The call instruction.
*/
virtual void RunInstrCall(VMFrame* curr_frame, const ObjectRef& callable, Instruction inst);
virtual void RunInstrCall(VMFrame* curr_frame, Instruction inst);

/*! \brief Run VM dispatch loop. */
void RunLoop();
Expand Down Expand Up @@ -507,18 +506,14 @@ void VirtualMachineImpl::SetInput(std::string func_name, bool with_param_module,
//------------------------------------------
void VirtualMachineImpl::InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs args,
TVMRetValue* rv) {
ICHECK(closure_or_packedfunc.defined())
<< "InvokeClosurePacked requires the callable object to be defined";

// run packed call if it is a packed func.
if (auto* packed = closure_or_packedfunc.as<PackedFunc::ContainerType>()) {
packed->CallPacked(args, rv);
return;
}
// run closure call.
auto* clo = closure_or_packedfunc.as<VMClosureObj>();
ICHECK(clo != nullptr) << "Function expects a closure or PackedFunc, "
<< "but received " << closure_or_packedfunc->GetTypeKey();
ICHECK(clo != nullptr) << "Function expects a closure or PackedFunc ";

std::vector<TVMValue> values(args.size() + 1);
std::vector<int> tcodes(args.size() + 1);
Expand Down Expand Up @@ -600,8 +595,6 @@ Optional<VMClosure> VirtualMachineImpl::GetClosureInternal(const String& func_na
auto impl = PackedFunc([gf_idx](TVMArgs args, TVMRetValue* rv) {
// Per convention, ctx ptr is a VirtualMachine*
VirtualMachine* ctx_ptr = static_cast<VirtualMachine*>(args[0].operator void*());
ICHECK(ctx_ptr) << "Context pointer for relax VM closure should be a VirtualMachine*, "
<< "but was NULL";

std::vector<RegType> inputs(args.size() - 1);
for (size_t i = 0; i < inputs.size(); ++i) {
Expand Down Expand Up @@ -651,7 +644,7 @@ RegType VirtualMachineImpl::InvokeBytecode(Index gf_idx, const std::vector<RegTy
auto guard = PushFrame(this->pc_, gfunc);
// Get new frame and set the caller info.
VMFrame* curr_frame = frames_.back().get();
if (curr_instr.op == Opcode::Call || curr_instr.op == Opcode::CallFromRegister) {
if (curr_instr.op == Opcode::Call) {
curr_frame->caller_return_register = curr_instr.dst;
}

Expand Down Expand Up @@ -695,12 +688,8 @@ void VirtualMachineImpl::InitFuncPool() {
}
}

void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, const ObjectRef& callable,
Instruction instr) {
ICHECK(callable.defined()) << "RunInstrCall requires the callable object to be defined";
auto func_name = instr.op == Opcode::Call ? GetFuncName(instr.func_idx) : "<dynamic>";

DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << func_name;
void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) {
DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << GetFuncName(instr.func_idx);
int args_begin_offset = instrument_ != nullptr ? 4 : 0;
// Use the call arg stack from the current frame to increase reuse
// and avoid re-allocation
Expand Down Expand Up @@ -746,11 +735,11 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, const ObjectRef& call
ICHECK_LT(static_cast<size_t>(instr.func_idx), this->func_pool_.size());

if (instrument_ == nullptr) {
this->InvokeClosurePacked(callable, args, &ret);
this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret);
} else {
// insert light-weight instrument callback
setter(0, callable);
setter(1, func_name);
setter(0, func_pool_[instr.func_idx]);
setter(1, GetFuncName(instr.func_idx));
setter(2, true);
setter(3, nullptr);
TVMRetValue rv;
Expand All @@ -769,7 +758,7 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, const ObjectRef& call
ret_kind = rv;
}
if (ret_kind != static_cast<int>(VMInstrumentReturnKind::kSkipRun)) {
this->InvokeClosurePacked(callable, args, &ret);
this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret);
setter(2, false);
setter(3, ret);
instrument_.CallPacked(TVMArgs(values.data(), tcodes.data(), values.size()), &rv);
Expand All @@ -793,11 +782,7 @@ void VirtualMachineImpl::RunLoop() {
Instruction instr = exec_->GetInstruction(pc_);
switch (instr.op) {
case Opcode::Call: {
this->RunInstrCall(curr_frame, func_pool_[instr.func_idx], instr);
break;
}
case Opcode::CallFromRegister: {
this->RunInstrCall(curr_frame, ReadRegister(curr_frame, instr.func_idx), instr);
this->RunInstrCall(curr_frame, instr);
break;
}
case Opcode::Ret: {
Expand Down Expand Up @@ -1015,7 +1000,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl {
}

protected:
void RunInstrCall(VMFrame* curr_frame, const ObjectRef& callable, Instruction inst) override {
void RunInstrCall(VMFrame* curr_frame, Instruction inst) override {
bool profiling = false;
if (prof_ && prof_->IsRunning()) {
auto f_name = GetFuncName(inst.func_idx);
Expand Down Expand Up @@ -1051,7 +1036,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl {
}
}

VirtualMachineImpl::RunInstrCall(curr_frame, callable, inst);
VirtualMachineImpl::RunInstrCall(curr_frame, inst);

if (profiling) {
prof_->StopCall();
Expand Down

0 comments on commit daa37e7

Please sign in to comment.