Skip to content

Commit

Permalink
missing upcasts
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Jun 20, 2024
1 parent bb19288 commit 382d338
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,7 +990,8 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k

self.batch_dim_act_map[name] = batch_dim

input_scales = self.scale_fn(x, dim=batch_dim)
dtype = x.dtype
input_scales = self.scale_fn(x.to(torch.float32), dim=batch_dim).to(dtype)
if name not in self.float_act_map:
self.float_act_map[name] = input_scales
else:
Expand Down

0 comments on commit 382d338

Please sign in to comment.