Skip to content

Commit

Permalink
Refined review
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Jul 10, 2023
1 parent f741d07 commit 44aa3ab
Showing 1 changed file with 11 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,21 @@ def main():

def ptq_torchvision_models(df, args):
# Generate all possible combinations, including invalid ones
combinations = list(product(*OPTIONS.values()))
# Split stats and mse due to the act_quant_percentile value
percentile_options = OPTIONS.copy()
percentile_options['act_param_method'] = ['stats']
mse_options = OPTIONS.copy()
mse_options['act_param_method'] = ['mse']
mse_options['act_quant_percentile'] = [None]
# Combine the two sets of combinations
combinations = list(product(*percentile_options.values())) + list(
product(*mse_options.values()))
# Generate Namespace for each configuration
configs = [
SimpleNamespace(**{k: v
for k, v in zip(OPTIONS.keys(), combination)})
for combination in combinations]
# Define which configs are not valid
# Define which configurations are not valid
configs = list(map(validate_config, configs))
# Drop invalid configurations
configs = list(config for config in configs if config.is_valid)
Expand Down Expand Up @@ -280,11 +289,6 @@ def validate_config(config_namespace):
is_valid = False
if config_namespace.act_bit_width < config_namespace.weight_bit_width:
is_valid = False
if config_namespace.act_param_method == 'mse':
if config_namespace.act_quant_percentile == 99.999:
config_namespace.act_quant_percentile = None
else:
is_valid = False

config_namespace.is_valid = is_valid
return config_namespace
Expand Down

0 comments on commit 44aa3ab

Please sign in to comment.