diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 601081676..31f295f5c 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -990,7 +990,8 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k self.batch_dim_act_map[name] = batch_dim - input_scales = self.scale_fn(x, dim=batch_dim) + dtype = x.dtype + input_scales = self.scale_fn(x.to(torch.float32), dim=batch_dim).to(dtype) if name not in self.float_act_map: self.float_act_map[name] = input_scales else: