From 8853566296c941938425108b671f069383ccd1e5 Mon Sep 17 00:00:00 2001
From: Fabian Grob
Date: Fri, 1 Mar 2024 05:19:53 -0800
Subject: [PATCH] Fix (GPFQ): use max/mean to avoid running out of memory

---
 src/brevitas/graph/gpfq.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index e255660a0..567021b35 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -214,12 +214,18 @@ def update_batch(self, module, input, current_layer):
             if self.float_input is None:
                 self.float_input = inp_processed
             else:
-                self.float_input = torch.cat([self.float_input, inp_processed], dim=1)
+                self.float_input = torch.max(self.float_input, inp_processed)
+                # self.float_input = torch.stack([self.float_input, inp_processed])
+                # self.float_input = self.float_input.mean(dim=0)
+                # self.float_input = torch.cat([self.float_input, inp_processed], dim=1)
         else:
             if self.quantized_input is None:
                 self.quantized_input = inp_processed
             else:
-                self.quantized_input = torch.cat([self.quantized_input, inp_processed], dim=1)
+                self.quantized_input = torch.max(self.quantized_input, inp_processed)
+                # self.quantized_input = torch.stack([self.quantized_input, inp_processed])
+                # self.quantized_input = self.quantized_input.mean(dim=0)
+                # self.quantized_input = torch.cat([self.quantized_input, inp_processed], dim=1)
         # If we are executing GPFQ with group of parallel layers, we keep track of how many forward
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
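
As a quick illustration of why this change avoids the out-of-memory failure (not part of the patch itself), the sketch below contrasts the three accumulation strategies touched by the diff: concatenation along dim=1, elementwise max, and a running mean. The tensor shapes and the `batches` loop are hypothetical stand-ins for the `inp_processed` tensors that `update_batch` receives across forward passes.

```python
import torch

# Hypothetical stand-in for the processed inputs seen across forward passes.
batches = [torch.randn(2, 128, 512) for _ in range(10)]

# torch.cat along dim=1: the accumulator grows with every batch.
cat_acc = None
for inp in batches:
    cat_acc = inp if cat_acc is None else torch.cat([cat_acc, inp], dim=1)

# Elementwise max (what the patch switches to): the accumulator stays the size of one batch.
max_acc = None
for inp in batches:
    max_acc = inp if max_acc is None else torch.max(max_acc, inp)

# Running mean: also constant memory; the explicit count weights every batch equally.
mean_acc, count = None, 0
for inp in batches:
    count += 1
    mean_acc = inp.clone() if mean_acc is None else mean_acc + (inp - mean_acc) / count

print(cat_acc.shape)   # torch.Size([2, 1280, 512]) -- grows linearly with the number of batches
print(max_acc.shape)   # torch.Size([2, 128, 512])
print(mean_acc.shape)  # torch.Size([2, 128, 512])
```

For comparison, the commented-out stack-and-mean variant in the patch averages the current accumulator with the newest batch at each step, which weights recent batches more heavily than the count-based running mean shown above.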