Skip to content

Commit

Permalink
Merge branch 'cherry-pick-84476780' into 'main'
Browse files Browse the repository at this point in the history
Remove legacy ofa code

See merge request es/ai/hannah/hannah!361
  • Loading branch information
cgerum committed Nov 27, 2023
2 parents 8da61ee + 85dc6f2 commit af3ca10
Show file tree
Hide file tree
Showing 33 changed files with 25 additions and 9,870 deletions.
69 changes: 25 additions & 44 deletions hannah/callbacks/summaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,24 @@
# limitations under the License.
#
import logging
from collections import OrderedDict
import sys
import traceback
from collections import OrderedDict

import pandas as pd
import torch
import torch.fx as fx
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from tabulate import tabulate
from torch.fx.graph_module import GraphModule

from hannah.models.ofa.submodules.elasticBase import ElasticBase1d
from hannah.nas.functional_operators.operators import add, conv2d, linear
from hannah.nas.graph_conversion import GraphConversionTracer

from ..models.factory import qat
from ..models.ofa import OFAModel
from ..models.ofa.submodules.elastickernelconv import ConvBn1d, ConvBnReLu1d, ConvRelu1d
from ..models.ofa.type_utils import elastic_conv_type, elastic_Linear_type
from ..models.sinc import SincNet

import torch.fx as fx
from hannah.nas.graph_conversion import GraphConversionTracer
from hannah.nas.functional_operators.operators import conv2d, linear, add

msglogger = logging.getLogger(__name__)


Expand Down Expand Up @@ -142,11 +137,6 @@ def get_extra(module, volume_ofm, output):
"""
classes = {
elastic_conv_type: get_elastic_conv,
elastic_Linear_type: get_elastic_linear,
ConvBn1d: get_conv,
ConvRelu1d: get_conv,
ConvBnReLu1d: get_conv,
torch.nn.Conv1d: get_conv,
torch.nn.Conv2d: get_conv,
qat.Conv1d: get_conv,
Expand Down Expand Up @@ -330,16 +320,9 @@ def _do_summary(self, pl_module, input=None, print_log=True):
total_weights = 0.0
estimated_acts = 0.0
model = pl_module.model
ofamodel = isinstance(model, OFAModel)
if ofamodel:
if model.validation_model is None:
model.build_validation_model()
model = model.validation_model

try:
df = walk_model(model, dummy_input)
if ofamodel:
pl_module.model.reset_validation_model()
t = tabulate(df, headers="keys", tablefmt="psql", floatfmt=".5f")
total_macs = df["MACs"].sum()
total_acts = df["IFM volume"][0] + df["OFM volume"].sum()
Expand All @@ -354,8 +337,6 @@ def _do_summary(self, pl_module, input=None, print_log=True):
"Estimated Activations: " + "{:,}".format(estimated_acts)
)
except RuntimeError as e:
if ofamodel:
pl_module.model.reset_validation_model()
msglogger.warning("Could not create performance summary: %s", str(e))
return OrderedDict()

Expand Down Expand Up @@ -477,15 +458,13 @@ def get_conv(node, output, args, kwargs):
out_channels = weight.shape[0]
in_channels = weight.shape[1]
kernel_size = weight.shape[2]
num_weights = out_channels * in_channels / kwargs['groups'] * kernel_size**2
macs = volume_ofm * in_channels / kwargs['groups'] * kernel_size
num_weights = out_channels * in_channels / kwargs["groups"] * kernel_size**2
macs = volume_ofm * in_channels / kwargs["groups"] * kernel_size
attrs = "k=" + "(%d, %d)" % (kernel_size, kernel_size)
attrs += ", s=" + "(%d, %d)" % (kwargs['stride'], kwargs['stride'])
attrs += ", g=(%d)" % kwargs['groups']
attrs += ", dsc=(%s)" % str(
in_channels == out_channels == kwargs['groups']
)
attrs += ", d=" + "(%d, %d)" % (kwargs['dilation'], kwargs['dilation'])
attrs += ", s=" + "(%d, %d)" % (kwargs["stride"], kwargs["stride"])
attrs += ", g=(%d)" % kwargs["groups"]
attrs += ", dsc=(%s)" % str(in_channels == out_channels == kwargs["groups"])
attrs += ", d=" + "(%d, %d)" % (kwargs["dilation"], kwargs["dilation"])
return num_weights, macs, attrs


Expand All @@ -500,7 +479,7 @@ def get_linear(node, output, args, kwargs):

def get_type(node):
try:
return node.name.split('_')[-2]
return node.name.split("_")[-2]
except Exception as e:
pass
return node.name
Expand Down Expand Up @@ -531,24 +510,26 @@ def __init__(self, module: torch.nn.Module):
"MACs": [],
}

def run_node(self, n : torch.fx.Node):
def run_node(self, n: torch.fx.Node):
try:
out = super().run_node(n)
except Exception as e:
print(str(e))
if n.op == 'call_function':
if n.op == "call_function":
try:
args, kwargs = self.fetch_args_kwargs_from_env(n)
num_weights, macs, attrs = self.count_function.get(n.target, get_zero_op)(n, out, args, kwargs)
self.data['Name'] += [n.name]
self.data['Type'] += [get_type(n)]
self.data['Attrs'] += [attrs]
self.data['IFM'] += [tuple(args[0].shape)]
self.data['IFM volume'] += [prod(args[0].shape)]
self.data['OFM'] += [tuple(out.shape)]
self.data['OFM volume'] += [prod(out.shape)]
self.data['Weights volume'] += [int(num_weights)]
self.data['MACs'] += [int(macs)]
num_weights, macs, attrs = self.count_function.get(
n.target, get_zero_op
)(n, out, args, kwargs)
self.data["Name"] += [n.name]
self.data["Type"] += [get_type(n)]
self.data["Attrs"] += [attrs]
self.data["IFM"] += [tuple(args[0].shape)]
self.data["IFM volume"] += [prod(args[0].shape)]
self.data["OFM"] += [tuple(out.shape)]
self.data["OFM volume"] += [prod(out.shape)]
self.data["Weights volume"] += [int(num_weights)]
self.data["MACs"] += [int(macs)]
except Exception as e:
msglogger.warning("Summary of node %s failed: %s", n.name, str(e))
return out
Expand Down
24 changes: 0 additions & 24 deletions hannah/conf/config_ofa.yaml

This file was deleted.

191 changes: 0 additions & 191 deletions hannah/conf/model/ofa.yaml

This file was deleted.

Loading

0 comments on commit af3ca10

Please sign in to comment.