From 14ffef59d43aaf43aef2e1f95b2348879f428d76 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Thu, 9 May 2024 11:48:14 -0700 Subject: [PATCH] Merge memory efficient prefetch pipeline into current prefetch pipeline (#1972) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/1972 Use one fewer batch to save memory. Memory usage is lower, and QPS is on par. Relanding of D54322913 Differential Revision: D57139157 --- .../distributed/train_pipeline/train_pipelines.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index f8685ab1b..f05143420 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -715,10 +715,6 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None: # batch 2 self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter) self._start_sparse_data_dist(self._batch_ip1) - self._wait_sparse_data_dist() - - # batch 3 - self._batch_ip2 = self._copy_batch_to_gpu(dataloader_iter) def progress(self, dataloader_iter: Iterator[In]) -> Out: self._fill_pipeline(dataloader_iter) @@ -730,18 +726,15 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: with record_function("## wait_for_batch ##"): _wait_for_batch(cast(In, self._batch_i), self._prefetch_stream) - self._start_sparse_data_dist(self._batch_ip2) - - self._batch_ip3 = self._copy_batch_to_gpu(dataloader_iter) + self._batch_ip2 = self._copy_batch_to_gpu(dataloader_iter) + self._wait_sparse_data_dist() # forward with record_function("## forward ##"): losses, output = cast(Tuple[torch.Tensor, Out], self._model(self._batch_i)) self._prefetch(self._batch_ip1) - self._wait_sparse_data_dist() - if self._model.training: # backward with record_function("## backward ##"): @@ -751,9 +744,10 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: with 
record_function("## optimizer ##"): self._optimizer.step() + self._start_sparse_data_dist(self._batch_ip2) + self._batch_i = self._batch_ip1 self._batch_ip1 = self._batch_ip2 - self._batch_ip2 = self._batch_ip3 return output