[SLM] Allow modules to define pre-processing of weights (#16757)
* [SLM] Allow TensorStructInfo to specify parameter in export

Prior to this commit, the parameter specification for SLM tensor
needed to be passed as a `nn.spec.Tensor`.  As this object is only
used to construct a `relax.TensorStructInfo`, and has the same fields
as a `relax.TensorStructInfo`, this commit allows the parameter
specification to be passed as a `relax.TensorStructInfo`.
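
  As a rough illustration of that correspondence (the shapes and dtype below are invented, and this only shows the field-for-field equivalence the message relies on, not a particular exporter entry point):

  ```python
  import tvm
  from tvm import relax
  from tvm.relax.frontend import nn

  # The nn.spec form, with a symbolic sequence length.
  spec_param = nn.spec.Tensor([1, "seq_len", 4096], "float16")

  # The relax.TensorStructInfo it is used to construct: the same shape
  # fields (the string becomes a tir.Var) and the same dtype.
  seq_len = tvm.tir.Var("seq_len", "int64")
  sinfo = relax.TensorStructInfo([1, seq_len, 4096], "float16")
  ```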

* Resolve breakage in unit tests

* [SLM] Use `CopyWithNewVars` to de-duplicate symbolic variables

Prior to this commit, a `nn.spec.Tensor`'s shape had special handling
to ensure that symbolic variables were not reused across multiple
functions.  This commit updates this to instead be performed using the
`CopyWithNewVars` function.
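
  A minimal sketch of what the `CopyWithNewVars` step does, using its Python binding `relax.utils.copy_with_new_vars`; the toy function below is invented for illustration:

  ```python
  from tvm import relax
  from tvm.script import relax as R

  @R.function
  def identity(x: R.Tensor(("n",), "float32")) -> R.Tensor(("n",), "float32"):
      return x

  # Each call returns a copy whose relax vars and symbolic shape vars are
  # freshly created, so two exported methods no longer share the tir.Var "n".
  f1 = relax.utils.copy_with_new_vars(identity)
  f2 = relax.utils.copy_with_new_vars(identity)
  ```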

* [SLM] Allow modules to define pre-processing of weights

Prior to this commit, the weights used by `nn.Module` instances were
required to be `nn.Parameter` instances.  This commit allows the
weights to instead be `nn.Tensor` instances, defined in terms of other
`nn.Parameter` weights.  This allows a model to define both the
original weights that would be present in an external
checkpoint (e.g. a PyTorch or Safetensors file), and the
pre-processing that should be performed on those weights.
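
  A hypothetical module sketching the pattern this enables; the class, attribute names, and ops below are illustrative rather than taken from this diff:

  ```python
  from tvm.relax.frontend import nn


  class TransposedLinear(nn.Module):
      """Hypothetical layer whose checkpoint stores its weight as [out, in]."""

      def __init__(self, in_features: int, out_features: int):
          # The original checkpoint parameter, exactly as it exists in the
          # external PyTorch/Safetensors file.
          self.weight = nn.Parameter((out_features, in_features), dtype="float32")
          # A derived weight: an nn.Tensor defined in terms of the parameter.
          # The pre-processing (here a simple transpose) becomes part of the
          # exported module instead of an out-of-band conversion script.
          self.weight_t = nn.op.permute_dims(self.weight)

      def forward(self, x: nn.Tensor) -> nn.Tensor:
          return nn.op.matmul(x, self.weight_t)
  ```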

* Undo portions that would introduce R.Tensor to nn.Module

* Remove unit tests that were related to TensorStructInfo
Lunderberg authored Mar 22, 2024
1 parent 31803e6 commit 1cccc3b
Showing 8 changed files with 498 additions and 58 deletions.
17 changes: 16 additions & 1 deletion python/tvm/relax/frontend/nn/core.py
@@ -591,7 +591,22 @@ def wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Sequence[Tensor]]:
The computed result.
"""
if not isinstance(expr, rx.DataflowVar):
expr = BlockBuilder.current().emit(expr, name)
block_builder = BlockBuilder.current()
if block_builder is None:
# Normalize to make sure we have valid StructInfo, but
# wait until we are actually building the function to
# flatten nested expressions.
#
# TODO(Lunderberg): Make this easier to call.  Inferring
# struct info for a nested expression should be doable in
# a free function, without requiring an active
# BlockBuilder and an active FunctionFrame.
builder = BlockBuilder()
with builder.function("dummy_scope", params=[]):
expr = builder.normalize(expr)
builder.emit_func_output([])
else:
expr = BlockBuilder.current().emit(expr, name)
if isinstance(expr.struct_info_, TensorStructInfo):
return Tensor(_expr=expr)
if isinstance(expr.struct_info_, TupleStructInfo):
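
A hypothetical check of the fallback added above: with no exporter running (and hence no active `BlockBuilder`), struct info is still inferred, so the result can be wrapped as a `Tensor`. The names and expected output are assumptions, not taken from the test suite:

```python
from tvm.relax.frontend import nn

# No BlockBuilder is active here; wrap_nested falls back to a throwaway
# builder just to normalize the call and populate its StructInfo.
a = nn.Tensor.placeholder([4, 8], "float32", "a")
b = nn.op.silu(a)
print(b.shape, b.dtype)  # expected: [4, 8] float32
```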
40 changes: 19 additions & 21 deletions python/tvm/relax/frontend/nn/exporter.py
@@ -111,7 +111,8 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
return result

# pylint: enable=protected-access
params = None

params = _params()
effects = _effects()
ext_mods = self.extern_mods
with self:
@@ -121,7 +122,6 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
outputs = _emit_effect_init(self.builder, effects)
self.builder.emit_func_output(outputs, params=[])
for method_name, method_spec in zip(spec.method_names, spec.method_specs):
params = _params() # Re-initialize so symbolic shapes not shared across methods
len_args = len(method_spec.arg_specs)
len_effects = {
"packed": 1,
@@ -135,9 +135,18 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
with self.builder.dataflow():
outputs, inputs = _emit_method(self.builder, method_spec, params, effects)
self.builder.emit_func_output(outputs, inputs)

# TODO(Lunderberg): Make a `ir.transform.ConvertSSA`,
# similar to the existing `tir.transform.ConvertSSA`,
# that converts an entire module to SSA, including TIR
# variable definitions used in either TIR or Relax.
mod = self.builder.get()
mod[method_name] = rx.utils.copy_with_new_vars(mod[method_name])

mod = self.builder.finalize()
assert rx.analysis.well_formed(mod)

mod = rx.transform.CanonicalizeBindings()(mod)
return mod, params, ext_mods


@@ -161,8 +170,6 @@ def _emit_method( # pylint: disable=too-many-locals,too-many-branches,too-many-
effects: typing.Optional[typing.List[typing.Tuple[str, core.Effect]]],
):
# pylint: disable=protected-access
# symbolic shape's name mapping to its tir.Var for reuse
str2var_params: typing.Dict[str, tir.Var] = {}

def _unwrap_ret(expr: typing.Any) -> typing.Any:
if isinstance(expr, (core.Tensor, core.Object)):
@@ -176,35 +183,26 @@ def _unwrap_ret(expr: typing.Any) -> typing.Any:
def _convert_input(arg):
if isinstance(arg, tir.Var):
return rx.Var(arg.name, struct_info=ShapeStructInfo(values=[arg]))
if isinstance(arg, (core.Tensor, core.Object)):
elif isinstance(arg, (core.Tensor, core.Object)):
return arg._expr # pylint: disable=protected-access
if isinstance(arg, _spec.Tuple):
elif isinstance(arg, _spec.Tuple):
return rx.Var(
arg.name,
struct_info=TupleStructInfo(
[_convert_input(arg_i).struct_info for arg_i in arg.elements]
),
)
raise TypeError(f"Unsupported input type: {type(arg)}")
elif isinstance(arg, rx.Expr):
return arg
else:
raise TypeError(f"Unsupported input type: {type(arg)}")

def _params(mode: str) -> typing.List[rx.Var]:
inputs: typing.List[rx.Var] = []

def _get_var(shape_var: tir.Var) -> tir.Var:
name = shape_var.name
if name in str2var_params:
return str2var_params[name]
var = tir.Var(name, "int64")
str2var_params[name] = var
return var

for name, param in params:
# Make sure that a symbolic shape is not re-registered (same as _method_spec_to_inputs)
# e.g. we do not see `vocab_size` for `lm_head` and `vocab_size_1` for `embed_tokens`
new_shape = [_get_var(x) if isinstance(x, tir.Var) else x for x in param.shape]
var = core.Tensor.placeholder(new_shape, param.dtype, name)._expr
inputs.append(var)
param._expr = var
inputs.append(param._expr)

if mode == "none":
return []
if mode == "plain":
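
Tying the exporter changes together, a hypothetical end-to-end export of the `TransposedLinear` sketch from the commit message above; the spec follows the existing `nn.spec` interface:

```python
from tvm.relax.frontend import nn

model = TransposedLinear(in_features=16, out_features=32)
mod, params = model.export_tvm(
    spec={"forward": {"x": nn.spec.Tensor((1, 16), "float32")}}
)

# `params` still lists only the original checkpoint parameters (here just
# "weight"); the transpose is emitted inside the exported Relax function,
# and each method gets its own copy of any symbolic shape variables.
for name, param in params:
    print(name, param.shape, param.dtype)
```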