Fix (GPFQ): use max/mean to avoid running out of memory
fabianandresgrob committed Mar 1, 2024
1 parent 1b8ff28 commit 8853566
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/brevitas/graph/gpfq.py
@@ -214,12 +214,18 @@ def update_batch(self, module, input, current_layer):
             if self.float_input is None:
                 self.float_input = inp_processed
             else:
-                self.float_input = torch.cat([self.float_input, inp_processed], dim=1)
+                self.float_input = torch.max(self.float_input, inp_processed)
+                # self.float_input = torch.stack([self.float_input, inp_processed])
+                # self.float_input = self.float_input.mean(dim=0)
+                # self.float_input = torch.cat([self.float_input, inp_processed], dim=1)
         else:
             if self.quantized_input is None:
                 self.quantized_input = inp_processed
             else:
-                self.quantized_input = torch.cat([self.quantized_input, inp_processed], dim=1)
+                self.quantized_input = torch.max(self.quantized_input, inp_processed)
+                # self.quantized_input = torch.stack([self.quantized_input, inp_processed])
+                # self.quantized_input = self.quantized_input.mean(dim=0)
+                # self.quantized_input = torch.cat([self.quantized_input, inp_processed], dim=1)
         # If we are executing GPFQ with group of parallel layers, we keep track of how many forward
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
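Note (illustrative sketch, not part of the commit): concatenating every processed batch along dim=1 makes the cached float/quantized inputs grow linearly with the number of calibration batches, which is what caused the out-of-memory failures, whereas an elementwise running max (or the commented-out stack/mean variant) keeps the buffer at the size of a single batch. The snippet below assumes hypothetical inp_processed tensors of identical shape on every call.

import torch

running = None        # stands in for self.float_input / self.quantized_input
concatenated = None   # old behaviour, kept here only for comparison

for _ in range(4):    # pretend update_batch sees four calibration batches
    inp_processed = torch.randn(2, 8, 16)   # hypothetical processed input

    # Old behaviour: the buffer grows by one batch worth of data on every call.
    if concatenated is None:
        concatenated = inp_processed
    else:
        concatenated = torch.cat([concatenated, inp_processed], dim=1)

    # New behaviour: elementwise max over tensors of identical shape,
    # so the buffer never grows beyond a single batch.
    if running is None:
        running = inp_processed
    else:
        running = torch.max(running, inp_processed)
        # Commented-out alternative from the commit: average the current
        # buffer with the new batch instead of taking the elementwise max.
        # running = torch.stack([running, inp_processed]).mean(dim=0)

print(concatenated.shape)   # torch.Size([2, 32, 16]) -- keeps growing
print(running.shape)        # torch.Size([2, 8, 16])  -- constant size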
