diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py
index 31d31433b..a1380da4e 100644
--- a/src/brevitas/graph/gptq.py
+++ b/src/brevitas/graph/gptq.py
@@ -132,7 +132,12 @@ def __init__(
         # Initialize Hessian matrix and counter. We need it in float32 to compute the inverse
         self.H = torch.zeros((self.groups, self.columns, self.columns),
                              device='cpu',
-                             dtype=torch.float32)
+                             dtype=torch.float32,
+                             pin_memory=torch.cuda.is_available())
+        self.B = torch.zeros((self.groups, self.columns, self.columns),
+                             device='cpu',
+                             dtype=torch.float32,
+                             pin_memory=torch.cuda.is_available())
         self.nsamples = 0
 
         assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher"
@@ -184,7 +189,9 @@ def update_batch(self, module, input, current_layer):
         self.H *= self.nsamples / (self.nsamples + batch_size)
         self.nsamples += batch_size
         inp_processed = math.sqrt(2 / self.nsamples) * inp_processed.to(torch.float32)
-        self.H += (inp_processed.bmm(inp_processed.transpose(2, 1))).to(self.H.device)
+        # Speed up the device-to-host transfer with an in-place copy into pinned memory
+        self.B.copy_(inp_processed.bmm(inp_processed.transpose(2, 1)))
+        self.H += self.B
         # If we are executing GPTQ with group of parallel layers, we keep track of how many forward
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
@@ -255,7 +262,7 @@ def single_layer_update(self, percdamp=.01):
                 f'Increasing the number of samples might fix this issue')
             return
         finally:
-            del self.H
+            del self.H, self.B
 
         for i1 in range(0, self.columns, self.blocksize):
             i2 = min(i1 + self.blocksize, self.columns)
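
The pattern this patch introduces can be reduced to a minimal standalone sketch. The names (`groups`, `columns`, `n_samples`, `hessian`, `staging`) and shapes below are illustrative assumptions, not Brevitas identifiers; only the pinned-memory `copy_` idiom mirrors the change:

```python
import torch

# Placeholder shapes; the patch uses (self.groups, self.columns, self.columns).
groups, columns, n_samples = 4, 512, 8
pin = torch.cuda.is_available()

# CPU accumulator in float32, as with self.H in the patch. pin_memory=True
# page-locks the allocation so CUDA can DMA directly into and out of it.
hessian = torch.zeros((groups, columns, columns),
                      device='cpu', dtype=torch.float32, pin_memory=pin)

# Pinned staging buffer, as with self.B: reusing one page-locked buffer avoids
# allocating a fresh pageable CPU tensor on every batch.
staging = torch.zeros((groups, columns, columns),
                      device='cpu', dtype=torch.float32, pin_memory=pin)

device = 'cuda' if pin else 'cpu'
inp_processed = torch.randn((groups, columns, n_samples),
                            device=device, dtype=torch.float32)

# Batched outer product on the device, then an in-place copy into the pinned
# buffer (a fast device-to-host transfer) and a CPU-side accumulation.
staging.copy_(inp_processed.bmm(inp_processed.transpose(2, 1)))
hessian += staging
```

Note that `Tensor.copy_` is synchronous here, so `hessian += staging` can read the buffer immediately; an asynchronous variant (`copy_(..., non_blocking=True)`, which pinned memory enables) would need explicit synchronization before the accumulation.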