Skip to content

Commit

Permalink
Merge memory efficient prefetch pipeline into current prefetch pipeli…
Browse files Browse the repository at this point in the history
…ne (pytorch#1972)

Summary:
Pull Request resolved: pytorch#1972

use one less batch to save memory.

Memory usage is less, and QPS is on par.

Relanding of D54322913

Differential Revision: D57139157
  • Loading branch information
henrylhtsang authored and facebook-github-bot committed May 9, 2024
1 parent 5eeefd9 commit 14ffef5
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions torchrec/distributed/train_pipeline/train_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,10 +715,6 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
# batch 2
self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter)
self._start_sparse_data_dist(self._batch_ip1)
self._wait_sparse_data_dist()

# batch 3
self._batch_ip2 = self._copy_batch_to_gpu(dataloader_iter)

def progress(self, dataloader_iter: Iterator[In]) -> Out:
self._fill_pipeline(dataloader_iter)
Expand All @@ -730,18 +726,15 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
with record_function("## wait_for_batch ##"):
_wait_for_batch(cast(In, self._batch_i), self._prefetch_stream)

self._start_sparse_data_dist(self._batch_ip2)

self._batch_ip3 = self._copy_batch_to_gpu(dataloader_iter)
self._batch_ip2 = self._copy_batch_to_gpu(dataloader_iter)

self._wait_sparse_data_dist()
# forward
with record_function("## forward ##"):
losses, output = cast(Tuple[torch.Tensor, Out], self._model(self._batch_i))

self._prefetch(self._batch_ip1)

self._wait_sparse_data_dist()

if self._model.training:
# backward
with record_function("## backward ##"):
Expand All @@ -751,9 +744,10 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
with record_function("## optimizer ##"):
self._optimizer.step()

self._start_sparse_data_dist(self._batch_ip2)

self._batch_i = self._batch_ip1
self._batch_ip1 = self._batch_ip2
self._batch_ip2 = self._batch_ip3

return output

Expand Down

0 comments on commit 14ffef5

Please sign in to comment.