diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index e255660a0..567021b35 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -214,12 +214,18 @@ def update_batch(self, module, input, current_layer):
             if self.float_input is None:
                 self.float_input = inp_processed
             else:
-                self.float_input = torch.cat([self.float_input, inp_processed], dim=1)
+                self.float_input = torch.max(self.float_input, inp_processed)
+                # self.float_input = torch.stack([self.float_input, inp_processed])
+                # self.float_input = self.float_input.mean(dim=0)
+                # self.float_input = torch.cat([self.float_input, inp_processed], dim=1)
         else:
             if self.quantized_input is None:
                 self.quantized_input = inp_processed
             else:
-                self.quantized_input = torch.cat([self.quantized_input, inp_processed], dim=1)
+                self.quantized_input = torch.max(self.quantized_input, inp_processed)
+                # self.quantized_input = torch.stack([self.quantized_input, inp_processed])
+                # self.quantized_input = self.quantized_input.mean(dim=0)
+                # self.quantized_input = torch.cat([self.quantized_input, inp_processed], dim=1)
         # If we are executing GPFQ with group of parallel layers, we keep track of how many forward
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
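
For context, here is a minimal sketch (not part of the patch) contrasting the three accumulation strategies this hunk touches: the original `torch.cat` along the batch dimension, the running elementwise `torch.max` introduced by the change, and the commented-out stack-and-mean alternative. The tensor shapes below are illustrative only, chosen to mimic `inp_processed` as a `(groups, batch, features)` tensor.

```python
import torch

torch.manual_seed(0)
batches = [torch.randn(1, 4, 8) for _ in range(3)]  # three simulated forward passes

# Original behaviour: concatenate along the batch dim; memory grows with every batch.
acc_cat = None
for b in batches:
    acc_cat = b if acc_cat is None else torch.cat([acc_cat, b], dim=1)
print(acc_cat.shape)  # torch.Size([1, 12, 8])

# Patched behaviour: running elementwise max; constant memory,
# but every batch must have the same shape.
acc_max = None
for b in batches:
    acc_max = b if acc_max is None else torch.max(acc_max, b)
print(acc_max.shape)  # torch.Size([1, 4, 8])

# Commented-out alternative in the diff: stack the running value with the
# new batch and average them elementwise.
acc_mean = None
for b in batches:
    acc_mean = b if acc_mean is None else torch.stack([acc_mean, b]).mean(dim=0)
print(acc_mean.shape)  # torch.Size([1, 4, 8])
```

The trade-off, as far as the code itself shows: `torch.cat(..., dim=1)` retains every calibration sample so memory grows with the number of batches, while the running max (and the mean variant) keep a fixed-size accumulator at the cost of collapsing per-sample information into a single summary tensor.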