Skip to content

Commit

Permalink
[Cherry-pick][MSC][M4.1] Add plugin && plugin_builder, enable build a…
Browse files Browse the repository at this point in the history
…nd test in different frameworks (#16397) (#16460)

[Unity][MSC][M4.1] Add plugin && plugin_builder, enable build and test in different frameworks (#16397)

* add plugin building

* minor fix

Co-authored-by: Archermmt <[email protected]>
  • Loading branch information
Hzfengsy and Archermmt authored Jan 24, 2024
1 parent 20b08a5 commit 0e8e421
Show file tree
Hide file tree
Showing 25 changed files with 6,438 additions and 35 deletions.
33 changes: 15 additions & 18 deletions python/tvm/contrib/msc/core/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,21 +169,18 @@ def relay_to_relax(
]

# pylint: disable=unused-argument
def _bind_weights(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModule:
return BindParams("main", weights)(mod)

mod = codegen.load(inputs, post_load=_bind_weights)

mod = tvm.ir.transform.Sequential(
[
# The canonicalization of relax variable bindings is not required
# for correctness. It does, however, remove trivial `x = y`
# bindings, preventing test cases from depending on their
# presence.
tvm.relax.transform.CanonicalizeBindings(),
tvm.relax.transform.ConvertToDataflow(min_size=1),
],
name="tvm.contrib.msc.core.codegen.relay_to_relax_postproc",
)(mod)

return mod
def _post_proc(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModule:
mod = BindParams("main", weights)(mod)
return tvm.ir.transform.Sequential(
[
# The canonicalization of relax variable bindings is not required
# for correctness. It does, however, remove trivial `x = y`
# bindings, preventing test cases from depending on their
# presence.
tvm.relax.transform.CanonicalizeBindings(),
tvm.relax.transform.ConvertToDataflow(min_size=1),
],
name="tvm.contrib.msc.core.codegen.relay_to_relax_postproc",
)(mod)

return codegen.load(inputs, post_load=_post_proc)
30 changes: 13 additions & 17 deletions python/tvm/contrib/msc/framework/tvm/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,24 +65,20 @@ def _save_weights(folder: msc_utils.MSCDirectory):
f_params.write(tvm.runtime.save_param_dict(weights))

# pylint: disable=unused-argument
def _bind_weights(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModule:
def _post_proc(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModule:
if weights:
mod = BindParams("main", weights)(mod)
return mod
return tvm.ir.transform.Sequential(
[
# The canonicalization of relax variable bindings is not required
# for correctness. It does, however, remove trivial `x = y`
# bindings, preventing test cases from depending on their
# presence.
tvm.relax.transform.CanonicalizeBindings(),
tvm.relax.transform.ConvertToDataflow(min_size=1),
],
name="tvm.contrib.msc.framework.tvm.codegen.to_relax_postproc",
)(mod)

codegen = CodeGen(graph, _ffi_api.GetRelaxSources, codegen_config, print_config, build_folder)
mod = codegen.load(inputs, pre_load=_save_weights, post_load=_bind_weights)

mod = tvm.ir.transform.Sequential(
[
# The canonicalization of relax variable bindings is not required
# for correctness. It does, however, remove trivial `x = y`
# bindings, preventing test cases from depending on their
# presence.
tvm.relax.transform.CanonicalizeBindings(),
tvm.relax.transform.ConvertToDataflow(min_size=1),
],
name="tvm.contrib.msc.framework.tvm.codegen.to_relax_postproc",
)(mod)

return mod
return codegen.load(inputs, pre_load=_save_weights, post_load=_post_proc)
19 changes: 19 additions & 0 deletions python/tvm/contrib/msc/plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""tvm.contrib.msc.plugin"""

from .build import *
21 changes: 21 additions & 0 deletions python/tvm/contrib/msc/plugin/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""tvm.contrib.msc.plugin._ffi_api"""

import tvm._ffi

tvm._ffi._init_api("msc.plugin", __name__)
Loading

0 comments on commit 0e8e421

Please sign in to comment.