[Relax][Transform] Provide callback versions of LazyTransformParams (#16798)

* [TIR][Analysis] Implemented tir.analysis.is_pure_function

This commit introduces two related utilities,
`tir.analysis.is_pure_function` and `tir.analysis.assert_pure_function`.
In contrast to the existing `tvm::tir::SideEffect`, which checks for
side effects of a single `PrimExpr`, `is_pure_function` checks for side
effects of the function as a whole.
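
As a rough usage sketch (hedged: the TVMScript function below is an
assumption for illustration, not taken from this commit):

```python
# A minimal sketch, assuming the new tir.analysis entry points.
import tvm
from tvm.script import tir as T


@T.prim_func
def copy(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    for i in range(16):
        B[i] = A[i]


# is_pure_function returns a bool, while assert_pure_function raises
# an error instead of returning False.
print(tvm.tir.analysis.is_pure_function(copy))
```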

* [Transform] Implement relax.transform.ComputePrimValue

Prior to this commit, expressions of type `DataType::Int(64)` could be
computed by `relax.transform.VMShapeLower`, but expressions of any
other type could not.  This commit introduces
`relax.transform.ComputePrimValue`, which produces `PrimFunc`
subroutines to compute `PrimExpr` values of any dtype.

This functionality will allow boolean values to be computed based on
the symbolic values known at runtime.
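
For example (a hypothetical sketch; the exact TVMScript spelling of the
input is an assumption):

```python
# A minimal sketch: a boolean R.prim_value computed from a symbolic
# shape variable.  After ComputePrimValue, the expression `n % 4 == 0`
# would be computed by a generated PrimFunc subroutine.
import tvm
from tvm import relax
from tvm.script import ir as I, relax as R, tir as T


@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor(["n"], "float32")) -> R.Prim("bool"):
        n = T.int64()
        is_aligned: R.Prim("bool") = R.prim_value(n % 4 == 0)
        return is_aligned


mod = relax.transform.ComputePrimValue()(Module)
```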

* [Relax] Allow R.Prim('bool') in relax::If and assert_op

Prior to this commit, the condition used for the `relax::If` node and
the `"relax.assert_op"` operator was required to be a scalar tensor.
This made it difficult to alter behavior based on a runtime shape
parameter; for example, delegating to a vectorized implementation based
on whether a tensor shape is divisible by the vector size.

This commit adds support for expressions of type `R.Prim('bool')` as
the conditional for `relax::If` and `"relax.assert_op"`, to allow
these use cases.
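
For example, a function along these lines (a hypothetical sketch; the
two branch bodies stand in for vectorized and fallback implementations):

```python
# A minimal sketch, assuming TVMScript accepts an R.Prim("bool")
# condition for `if`.
from tvm.script import relax as R, tir as T


@R.function
def dispatch(x: R.Tensor(["n"], "float32")) -> R.Tensor(["n"], "float32"):
    n = T.int64()
    cond: R.Prim("bool") = R.prim_value(n % 4 == 0)
    if cond:
        out = R.multiply(x, x)  # stands in for a vectorized implementation
    else:
        out = R.add(x, x)  # stands in for a fallback implementation
    return out
```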

* [Relax][Transform] Provide callback versions of LazyTransformParams

Prior to this commit, the `LazyTransformParams` transform could be used
to load model parameters on demand.  However, the functions used to
load or store parameters needed to be registered in the global
`PackedFunc` registry.  This PR provides the `LazyGetInput` and
`LazySetOutput` transforms, which perform the lazy loading through an
`R.Callable` callback argument rather than through a
globally-registered `PackedFunc`.
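
A sketch of how the passes and callbacks fit together (hedged: `mod` is
an assumed IRModule, and `load_weight`/`save_output` are hypothetical
helpers):

```python
from tvm import relax

mod = relax.transform.LazyGetInput()(mod)   # weights come from a callback
mod = relax.transform.LazySetOutput()(mod)  # outputs go to a callback


def fget_param(index, name):
    # Called once per model weight, when the weight is first needed.
    return load_weight(index, name)


def fset_output(index, value):
    # Called once per output, as soon as the output is available.
    save_output(index, value)
```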

* Reverse the order of parameters in fget_param

If `fget_param` accepts the parameter index first and the parameter
name second, then an implementation with the signature and default
values `def fget_param(index: int, name: Optional[str] = None)` could
be used either as the callback of `LazyGetInput` or as the
globally-registered `"get_item"` of the existing `LazyTransformParams`,
which should make it easier to transition between the two.
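
In other words, a single implementation along these lines (a sketch;
`param_store` is a hypothetical pre-loaded parameter list) could serve
both entry points:

```python
from typing import Optional

import tvm


def fget_param(index: int, name: Optional[str] = None):
    # `param_store` is a hypothetical list of NDArrays.
    return param_store[index]


# Usable directly as the callback introduced by LazyGetInput, or
# registered globally for the existing LazyTransformParams:
tvm.register_func("get_item", fget_param, override=True)
```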

* lint fix

* Updates based on review comments
Lunderberg authored Apr 3, 2024
1 parent 545e097 commit 61249b4
Showing 4 changed files with 676 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/tvm/relax/transform/__init__.py
@@ -50,6 +50,8 @@
    InlinePrivateFunctions,
    KillAfterLastUse,
    LambdaLift,
    LazyGetInput,
    LazySetOutput,
    LegalizeOps,
    LiftTransformParams,
    LowerAllocTensor,
80 changes: 80 additions & 0 deletions python/tvm/relax/transform/transform.py
@@ -303,6 +303,86 @@ def LambdaLift() -> tvm.ir.transform.Pass:
    return _ffi_api.LambdaLift()


def LazyGetInput() -> tvm.ir.transform.Pass:
    """A pass that requests inputs lazily.

    In many cases, the size of the model weights exceeds the available
    memory on a GPU.  In these cases, a function that accepts all model
    weights as arguments could not be called.  Instead, parameters must
    be loaded as they are required by the function, and unloaded once
    they are no longer needed.

    This pass mutates a function such that all model weights (arguments
    after the first `func.attrs["num_input"]` arguments) are loaded on
    demand.  Rather than accepting the weights as function arguments,
    the function accepts a callback argument, which can load each
    parameter as needed.  The callback accepts two arguments: first,
    the index of the model weight, and second, the name of the
    parameter.  The callback should return the requested parameter.

    .. code-block:: python

        @R.function
        def before(A: R.Tensor([16, 32], "float32")):
            ...

        @R.function
        def after(fget_param: R.Callable([R.Prim('int64'), R.Object], R.Object)):
            A_untyped = fget_param(0, R.str('A'))
            A = R.match_cast(A_untyped, R.Tensor([16, 32], "float32"))
            ...

    Returns
    -------
    ret : tvm.ir.transform.Pass
    """
    return _ffi_api.LazyGetInput()


def LazySetOutput() -> tvm.ir.transform.Pass:
    """A pass that sets function outputs when available.

    In many cases, the size of the model weights exceeds the available
    memory on a GPU.  In these cases, a function that produces all
    model weights as a single return value could not be called.
    Instead, outputs must be returned as they are produced, and
    unloaded from the GPU (or saved to disk) before additional outputs
    are produced.

    This pass mutates a function such that each output is provided to
    a callback as soon as it is available.  The function accepts an
    additional callback argument, which is called with each output of
    the function.  The callback accepts two arguments: first, the
    index of the output within the output tuple (or zero if the output
    is not a tuple), and second, the value itself.

    .. code-block:: python

        @R.function
        def before(args):
            ...
            return (A, B)

        @R.function
        def after(args, fset_output: R.Callable([R.Prim('int64'), R.Object])):
            ...
            fset_output(0, A)
            ...
            fset_output(1, B)
            ...
            return ()

    Returns
    -------
    ret : tvm.ir.transform.Pass
    """
    return _ffi_api.LazySetOutput()


def ConvertToDataflow(min_size: int = 2) -> tvm.ir.transform.Pass:
    """A pass that converts consecutive dataflow operations
    inside binding blocks into dataflow blocks.
266 changes: 266 additions & 0 deletions src/relax/transform/lazy_transform_params.cc
@@ -0,0 +1,266 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*! \file src/relax/transform/lazy_transform_params.cc */

#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>

#include <optional>
#include <unordered_map>

#include "utils.h"

namespace tvm {
namespace relax {

namespace {
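// Returns the function's attr::kNumInput annotation, if present, after
// validating that it lies within [0, func->params.size()].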
std::optional<int64_t> GetNumInputParams(const FunctionNode* func) {
  if (auto opt_int_imm = func->GetAttr<IntImm>(attr::kNumInput)) {
    int64_t num_input_params = opt_int_imm.value()->value;
    CHECK_GE(num_input_params, 0) << "ValueError: "
                                  << "Annotation for attr::kNumInput (\"" << attr::kNumInput
                                  << "\") must be non-negative, but was " << num_input_params;
    CHECK_LE(static_cast<size_t>(num_input_params), func->params.size())
        << "ValueError: "
        << "Annotation for attr::kNumInput (\"" << attr::kNumInput << "\") specifies "
        << num_input_params << " parameters to be provided at runtime, "
        << "but the function only accepts " << func->params.size() << " parameters in total";
    return num_input_params;
  } else {
    return std::nullopt;
  }
}

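// Rewrites a function so that model weights (parameters after the first
// attr::kNumInput arguments) are fetched on demand through a trailing
// fget_param callback parameter.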
class LazyInputMutator : public ExprMutator {
 public:
  Expr VisitExpr_(const FunctionNode* func) override {
    if (plan_.has_value()) {
      return ExprMutator::VisitExpr_(func);
    }

    int64_t num_input_params = GetNumInputParams(func).value_or(0);

    std::unordered_map<Var, size_t, ObjectPtrHash, ObjectPtrEqual> param_lookup;
    for (size_t i = num_input_params; i < func->params.size(); i++) {
      param_lookup.insert({func->params[i], i - num_input_params});
    }

    Var fget_param("fget_param",
                   FuncStructInfo({PrimStructInfo(DataType::Int(64)), ObjectStructInfo()},
                                  ObjectStructInfo()));

    Array<Var> new_params(func->params.begin(), func->params.begin() + num_input_params);
    new_params.push_back(fget_param);

    auto node = GetRef<Function>(func);
    node.CopyOnWrite()->params = new_params;
    node = WithAttr(node, attr::kNumInput, Integer(num_input_params + 1));

    plan_ = FunctionPlan{std::move(param_lookup), fget_param};
    auto output = Downcast<Function>(ExprMutator::VisitExpr_(node.get()));
    plan_.reset();
    return output;
  }

  Expr VisitExpr_(const VarNode* op) override {
    if (plan_) {
      Var var = GetRef<Var>(op);
      if (auto it = plan_->param_lookup.find(var); it != plan_->param_lookup.end()) {
        auto untyped =
            builder_->Emit(relax::Call(plan_->fget_param,
                                       {
                                           PrimValue(IntImm(DataType::Int(64), it->second)),
                                           StringImm(var->name_hint()),
                                       }),
                           var->name_hint() + "_untyped");
        return builder_->EmitMatchCast(untyped, GetStructInfo(var), var->name_hint());
      }
    }

    return ExprMutator::VisitExpr_(op);
  }

 private:
  struct FunctionPlan {
    std::unordered_map<Var, size_t, ObjectPtrHash, ObjectPtrEqual> param_lookup;
    Expr fget_param;
  };
  std::optional<FunctionPlan> plan_;
};

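// Rewrites a function so that each output is delivered through an
// fset_output callback parameter as soon as it is available, and the
// function itself returns an empty tuple.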
class LazyOutputMutator : public ExprMutator {
 public:
  Expr VisitExpr_(const FunctionNode* func) override {
    if (plan_.has_value()) {
      return ExprMutator::VisitExpr_(func);
    }

    std::unordered_map<Var, std::vector<size_t>, ObjectPtrHash, ObjectPtrEqual> output_lookup;
    std::vector<std::tuple<size_t, Expr>> inline_outputs;
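    // Outputs that are plain variables can be emitted as soon as the
    // variable is bound; other outputs (e.g. inline constants) are
    // emitted at the end of the function.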
    auto define_lookup = [&](size_t output_index, Expr output_value) {
      if (auto var = output_value.as<Var>()) {
        output_lookup[var.value()].push_back(output_index);
      } else {
        inline_outputs.push_back({output_index, output_value});
      }
    };

    auto func_body = Downcast<SeqExpr>(func->body);
    if (auto tuple_output = func_body->body.as<TupleNode>()) {
      for (size_t i = 0; i < tuple_output->fields.size(); i++) {
        define_lookup(i, tuple_output->fields[i]);
      }
    } else {
      define_lookup(0, func_body->body);
    }

    Var fset_output("fset_output",
                    FuncStructInfo({PrimStructInfo(DataType::Int(64)), ObjectStructInfo()},
                                   TupleStructInfo(Array<StructInfo>{})));
    plan_ = FunctionPlan{std::move(output_lookup), fset_output};

    std::optional<int64_t> num_input_params = GetNumInputParams(func);

    auto new_params = func->params;
    new_params.insert(new_params.begin() + num_input_params.value_or(func->params.size()),
                      fset_output);

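    // Outputs that are simply forwarded function parameters are already
    // available on entry, so their fset_output calls are placed at the
    // start of the function body.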
    BindingBlock start_of_func = [&]() {
      Array<Binding> propagated_params;
      for (auto param : func->params) {
        GenerateSetOutputCalls(param, [&](const auto& fset_output_call) {
          Var void_output("_void", TupleStructInfo(Array<StructInfo>{}));
          propagated_params.push_back(VarBinding(void_output, fset_output_call));
        });
      }
      return BindingBlock(propagated_params);
    }();
    BindingBlock end_of_func = [&]() {
      Array<Binding> propagated_params;
      for (const auto& [output_index, expr] : inline_outputs) {
        Call fset_output_call(fset_output,
                              {PrimValue(IntImm(DataType::Int(64), output_index)), expr});
        Var void_output("_void", TupleStructInfo(Array<StructInfo>{}));
        propagated_params.push_back(VarBinding(void_output, fset_output_call));
      }
      return BindingBlock(propagated_params);
    }();

    Array<BindingBlock> new_blocks = func_body->blocks;
    new_blocks.insert(new_blocks.begin(), start_of_func);
    new_blocks.push_back(end_of_func);
    Expr new_body = SeqExpr(new_blocks, Tuple(Array<Expr>{}));

    auto node = GetRef<Function>(func);
    {
      auto write_ptr = node.CopyOnWrite();
      write_ptr->params = new_params;
      write_ptr->body = new_body;
    }
    if (num_input_params.has_value()) {
      node = WithAttr(node, attr::kNumInput, Integer(num_input_params.value() + 1));
    }

    auto output = Downcast<Function>(ExprMutator::VisitExpr_(node.get()));
    plan_.reset();
    return output;
  }

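  // After emitting each binding, immediately pass any outputs defined by
  // that binding to the fset_output callback.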
  void VisitBinding(const Binding& binding) override {
    ExprMutator::VisitBinding(binding);
    GenerateSetOutputCalls(binding->var, [this](const auto& fset_output_call) {
      builder_->Emit(fset_output_call, "_void");
    });
  }

 private:
  template <typename Callback>
  void GenerateSetOutputCalls(const Var& var, Callback callback) {
    if (plan_.has_value()) {
      if (auto it = plan_->output_lookup.find(var); it != plan_->output_lookup.end()) {
        for (auto output_index : it->second) {
          callback(
              Call(plan_->fset_output, {PrimValue(IntImm(DataType::Int(64), output_index)), var}));
        }
      }
    }
  }

  struct FunctionPlan {
    std::unordered_map<Var, std::vector<size_t>, ObjectPtrHash, ObjectPtrEqual> output_lookup;
    Expr fset_output;
  };
  std::optional<FunctionPlan> plan_;
};
} // namespace

Function WithLazyInputs(Function func) {
  LazyInputMutator mutator;

  func = Downcast<Function>(mutator.VisitExpr(func));
  func = Downcast<Function>(EliminateCommonSubexpr(func));
  func = Downcast<Function>(RemoveAllUnused(func));
  return func;
}

Function WithLazyOutputs(Function func) {
  LazyOutputMutator mutator;

  func = Downcast<Function>(mutator.VisitExpr(func));
  return func;
}

namespace transform {

Pass LazyGetInput() {
  auto pass_func = [](Function func, IRModule, PassContext) -> Function {
    if (!func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined()) {
      return func;
    }
    return WithLazyInputs(func);
  };
  return CreateFunctionPass(/*pass_function=*/pass_func,
                            /*opt_level=*/0,
                            /*pass_name=*/"LazyGetInput",
                            /*required=*/{});
}

TVM_REGISTER_GLOBAL("relax.transform.LazyGetInput").set_body_typed(LazyGetInput);

Pass LazySetOutput() {
  auto pass_func = [](Function func, IRModule, PassContext) -> Function {
    if (!func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined()) {
      return func;
    }
    return WithLazyOutputs(func);
  };
  return CreateFunctionPass(/*pass_function=*/pass_func,
                            /*opt_level=*/0,
                            /*pass_name=*/"LazySetOutput",
                            /*required=*/{});
}

TVM_REGISTER_GLOBAL("relax.transform.LazySetOutput").set_body_typed(LazySetOutput);

} // namespace transform
} // namespace relax
} // namespace tvm