Skip to content

Commit

Permalink
tentative cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 1, 2024
1 parent 609a164 commit b92d5dd
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions src/brevitas/core/stats/stats_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,17 @@ def __init__(
super(_ParameterListStats, self).__init__()

self.stats_input_concat_dim = stats_input_concat_dim
if len(tracked_parameter_list) >= 1:
self.first_tracked_param = _ViewParameterWrapper(
tracked_parameter_list[0], stats_input_view_shape_impl)
else:
self.first_tracked_param = _ViewParameter(stats_input_view_shape_impl)

if len(tracked_parameter_list) > 1:
self.first_tracked_param = _ViewParameterWrapper(
tracked_parameter_list[0], stats_input_view_shape_impl)
extra_list = [
_ViewCatParameterWrapper(
param, stats_input_view_shape_impl, stats_input_concat_dim)
for param in tracked_parameter_list[1:]]
self.extra_tracked_params_list = torch.nn.ModuleList(extra_list)
else:
self.first_tracked_param = _ViewParameter(stats_input_view_shape_impl)
self.extra_tracked_params_list = None
self.stats = _Stats(stats_impl, stats_output_shape)

Expand Down

0 comments on commit b92d5dd

Please sign in to comment.