-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add mlonmcu as performance estimator
- Loading branch information
Showing
5 changed files
with
181 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
## | ||
## Copyright (c) 2024 University of Tübingen. | ||
## | ||
## This file is part of hannah. | ||
## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. | ||
## | ||
## Licensed under the Apache License, Version 2.0 (the "License"); | ||
## you may not use this file except in compliance with the License. | ||
## You may obtain a copy of the License at | ||
## | ||
## http://www.apache.org/licenses/LICENSE-2.0 | ||
## | ||
## Unless required by applicable law or agreed to in writing, software | ||
## distributed under the License is distributed on an "AS IS" BASIS, | ||
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
## See the License for the specific language governing permissions and | ||
## limitations under the License. | ||
## | ||
defaults: | ||
- base_config | ||
- override nas: aging_evolution_nas | ||
- override model: embedded_vision_net | ||
- override dataset: cifar10 # Dataset configuration name | ||
- override features: identity # Feature extractor configuration name (use identity for vision datasets) | ||
- override scheduler: 1cycle # learning rate scheduler config name | ||
- override optimizer: adamw # Optimizer config name | ||
- override normalizer: null # Feature normalizer (used for quantized neural networks) | ||
- override module: image_classifier # Lightning module config for the training loop (image classifier for image classification tasks) | ||
- override /nas/constraint_model: random_walk | ||
- override /nas/mac_predictor: mlonmcu # MLonMCU predictor | ||
- override augmentation: null | ||
- _self_ | ||
|
||
|
||
model: | ||
num_classes: 10 | ||
|
||
module: | ||
batch_size: 128 | ||
num_workers: 4 | ||
|
||
nas: | ||
budget: 100 | ||
n_jobs: 4 | ||
bounds: | ||
val_error: 0.03 | ||
total_weights: 100000 | ||
Cycles: 700000000 # measured by MLonMCU | ||
Total ROM: 1500000 # measured by MLonMCU | ||
|
||
trainer: | ||
max_epochs: 10 | ||
|
||
scheduler: | ||
max_lr: 0.001 | ||
|
||
fx_mac_summary: True | ||
|
||
seed: [1234] | ||
|
||
experiment_id: "nas_mlonmcu" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
## | ||
## Copyright (c) 2022 University of Tübingen. | ||
## | ||
## This file is part of hannah. | ||
## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. | ||
## | ||
## Licensed under the Apache License, Version 2.0 (the "License"); | ||
## you may not use this file except in compliance with the License. | ||
## You may obtain a copy of the License at | ||
## | ||
## http://www.apache.org/licenses/LICENSE-2.0 | ||
## | ||
## Unless required by applicable law or agreed to in writing, software | ||
## distributed under the License is distributed on an "AS IS" BASIS, | ||
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
## See the License for the specific language governing permissions and | ||
## limitations under the License. | ||
## | ||
|
||
_target_: hannah.nas.performance_prediction.mlonmcu.predictor.MLonMCUPredictor | ||
model_name: ${model.name} |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import os | ||
import logging | ||
from pathlib import Path | ||
|
||
import torch | ||
import mlonmcu.context | ||
from mlonmcu.session.run import RunStage | ||
|
||
|
||
msglogger = logging.getLogger(__name__) | ||
|
||
|
||
class MLonMCUPredictor(): | ||
def __init__(self, | ||
model_name, | ||
metrics=["Cycles", "Total ROM", "Total RAM"], | ||
platform="mlif", | ||
backend="tvmaot", | ||
target="etiss_pulpino", | ||
frontend="onnx", | ||
postprocess=None, | ||
feature=None, | ||
configs=None, | ||
parallel=None, | ||
progress=False, | ||
verbose=False,): | ||
self.model_name = model_name | ||
|
||
self.metrics = metrics | ||
self.platform = platform | ||
self.backend = backend | ||
self.target = target | ||
self.frontend = frontend | ||
self.postprocess = postprocess | ||
self.feature = feature | ||
self.configs = configs | ||
self.parallel = parallel | ||
self.progress = progress | ||
self.verbose = verbose | ||
|
||
def predict(self, model, input): | ||
if hasattr(model, 'model'): | ||
model = model.model # FIXME: Decide when to use pl_module and when to use model | ||
|
||
# Convert PyTorch model to ONNX | ||
ckpt_path = Path("mlonmcu") | ||
if not os.path.exists(ckpt_path): | ||
os.mkdir(ckpt_path) | ||
model_path = os.path.join(ckpt_path, f"{self.model_name}.onnx") | ||
convert_to_onnx(model, input, model_path) | ||
|
||
# Run MLonMCU | ||
print("MLonMCU evaluating {}".format(self.model_name)) | ||
if not os.path.isdir(ckpt_path): | ||
raise Exception("INVALID MODEL PATH: ", ckpt_path) | ||
with mlonmcu.context.MlonMcuContext() as context: | ||
session = context.create_session() | ||
run = session.create_run(features=[], config={}) | ||
run.add_frontend_by_name("onnx", context=context) | ||
run.add_model_by_name(model_path, context=context) | ||
run.add_backend_by_name(self.backend, context=context) | ||
run.add_platform_by_name(self.platform, context=context) | ||
run.add_target_by_name(self.target, context=context) | ||
# run.add_feature_by_name("vext", context=context) | ||
session.process_runs(until=RunStage.RUN, context=context) | ||
report = session.get_reports() | ||
print(report) | ||
|
||
# Return a dict of metric values | ||
mlonmcu_metrics = report.df | ||
result = {} | ||
for metric in self.metrics: | ||
if metric in mlonmcu_metrics: | ||
result[metric] = float(mlonmcu_metrics[metric]) | ||
else: | ||
# raise Exception("Metric is not supported by MLonMCU: ", metric) | ||
msglogger.info(f"WARNING: Metric {metric} is not supported by MLonMCU ") | ||
|
||
return result | ||
|
||
def load(self, result_folder): | ||
pass | ||
|
||
def update(self, new_data, input): | ||
pass | ||
|
||
|
||
def convert_to_onnx(pytorch_model, sample_input, onnx_model_path): | ||
# Export pytorch model to onnx | ||
torch.onnx.export( | ||
model=pytorch_model, | ||
args=sample_input, | ||
f=onnx_model_path, # save path | ||
verbose=False, | ||
) | ||
|
||
return onnx_model_path |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters