diff --git a/src/brevitas/core/stats/stats_wrapper.py b/src/brevitas/core/stats/stats_wrapper.py index 49bf62a82..8e148b8c1 100644 --- a/src/brevitas/core/stats/stats_wrapper.py +++ b/src/brevitas/core/stats/stats_wrapper.py @@ -97,19 +97,17 @@ def __init__( super(_ParameterListStats, self).__init__() self.stats_input_concat_dim = stats_input_concat_dim - if len(tracked_parameter_list) >= 1: - self.first_tracked_param = _ViewParameterWrapper( - tracked_parameter_list[0], stats_input_view_shape_impl) - else: - self.first_tracked_param = _ViewParameter(stats_input_view_shape_impl) if len(tracked_parameter_list) > 1: + self.first_tracked_param = _ViewParameterWrapper( + tracked_parameter_list[0], stats_input_view_shape_impl) extra_list = [ _ViewCatParameterWrapper( param, stats_input_view_shape_impl, stats_input_concat_dim) for param in tracked_parameter_list[1:]] self.extra_tracked_params_list = torch.nn.ModuleList(extra_list) else: + self.first_tracked_param = _ViewParameter(stats_input_view_shape_impl) self.extra_tracked_params_list = None self.stats = _Stats(stats_impl, stats_output_shape)