Feat (gptq): optimizing CPU to GPU memory transfer (#1009)
i-colbert authored Sep 12, 2024
1 parent 9932b92 commit 10dcee3
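
This commit speeds up the per-batch movement of GPTQ Hessian updates between the GPU and the host-resident accumulator: `self.H` is now allocated in pinned (page-locked) memory, and a second pinned staging buffer `self.B` receives each batch's `bmm` result through an in-place `copy_` instead of a fresh `.to('cpu')` allocation on every batch.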
Showing 1 changed file: src/brevitas/graph/gptq.py (10 additions, 3 deletions).
```diff
@@ -132,7 +132,12 @@ def __init__(
         # Initialize Hessian matrix and counter. We need it in float32 to compute the inverse
         self.H = torch.zeros((self.groups, self.columns, self.columns),
                              device='cpu',
-                             dtype=torch.float32)
+                             dtype=torch.float32,
+                             pin_memory=torch.cuda.is_available())
+        self.B = torch.zeros((self.groups, self.columns, self.columns),
+                             device='cpu',
+                             dtype=torch.float32,
+                             pin_memory=torch.cuda.is_available())
         self.nsamples = 0

         assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher"
```
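
Allocating the accumulators with `pin_memory=torch.cuda.is_available()` places them in page-locked host RAM, which the CUDA driver can DMA to and from directly instead of staging through an internal pageable bounce buffer; the guard keeps CPU-only runs on ordinary pageable memory, where pinning is unavailable. A minimal sketch of the effect (not Brevitas code; tensor sizes are hypothetical and purely illustrative):

```python
import time

import torch

# Minimal sketch: time a device-to-host copy into a pageable CPU tensor
# vs. a pinned one. The pinned destination lets the CUDA driver DMA
# directly, skipping the internal staging copy through pageable memory.
if torch.cuda.is_available():
    src = torch.randn(8, 1024, 1024, device='cuda')
    dsts = {
        'pageable': torch.empty(src.shape, dtype=src.dtype, device='cpu'),
        'pinned': torch.empty(src.shape, dtype=src.dtype, device='cpu', pin_memory=True),
    }
    for name, dst in dsts.items():
        torch.cuda.synchronize()
        start = time.perf_counter()
        dst.copy_(src)
        torch.cuda.synchronize()
        print(f'{name}: {time.perf_counter() - start:.4f} s')
```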
```diff
@@ -184,7 +189,9 @@ def update_batch(self, module, input, current_layer):
         self.H *= self.nsamples / (self.nsamples + batch_size)
         self.nsamples += batch_size
         inp_processed = math.sqrt(2 / self.nsamples) * inp_processed.to(torch.float32)
-        self.H += (inp_processed.bmm(inp_processed.transpose(2, 1))).to(self.H.device)
+        # optimizing CPU to GPU transfer using in-place copy to pinned memory
+        self.B.copy_(inp_processed.bmm(inp_processed.transpose(2, 1)))
+        self.H += self.B
         # If we are executing GPTQ with group of parallel layers, we keep track of how many forward
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
```
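
The removed line materialized the `bmm` result on the device and moved it with `.to(self.H.device)`, which allocates a fresh pageable CPU tensor on every batch. The new code instead reuses the pre-allocated pinned buffer `self.B` as the transfer target, so each batch performs a fast DMA into the same page-locked memory and only the in-place add touches `self.H`. A condensed sketch of the pattern, with hypothetical shapes and `H`/`B` standing in for `self.H`/`self.B`:

```python
import torch

# Condensed sketch of the accumulation step; groups, columns, and
# n_samples are made-up values for illustration only.
groups, columns, n_samples = 2, 64, 16
pin = torch.cuda.is_available()
H = torch.zeros((groups, columns, columns), device='cpu', dtype=torch.float32, pin_memory=pin)
B = torch.zeros((groups, columns, columns), device='cpu', dtype=torch.float32, pin_memory=pin)

device = 'cuda' if pin else 'cpu'
inp_processed = torch.randn(groups, columns, n_samples, device=device)

# Old: H += inp_processed.bmm(inp_processed.transpose(2, 1)).to(H.device)
#      -> allocates a new pageable CPU tensor for every batch.
# New: copy into the reusable pinned buffer, then accumulate in place.
B.copy_(inp_processed.bmm(inp_processed.transpose(2, 1)))
H += B
```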
```diff
@@ -255,7 +262,7 @@ def single_layer_update(self, percdamp=.01):
                     f'Increasing the number of samples might fix this issue')
                 return
             finally:
-                del self.H
+                del self.H, self.B
```
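
A note on the cleanup: pinned pages are locked in physical RAM and cannot be swapped out, so the `finally` block now drops both `self.H` and `self.B`, releasing the page-locked allocations as soon as the Hessian work completes (or aborts).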
