Skip to content

Commit

Permalink
Merge memory efficient prefetch pipeline into current prefetch pipeline
Browse files (browse the repository at this point in the history)
Summary: context: to be added

Differential Revision: D57139157
  • Loading branch information
henrylhtsang authored and facebook-github-bot committed May 9, 2024
1 parent 7e47032 commit 841d019
Showing 1 changed file with 5 additions and 16 deletions.
21 changes: 5 additions & 16 deletions torchrec/distributed/train_pipeline/train_pipelines.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -702,23 +702,14 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
if self._batch_i is None:
raise StopIteration

self._init_pipelined_modules(
self._batch_i,
self._context,
# pyre-ignore
PrefetchPipelinedForward,
)
self._init_pipelined_modules(self._batch_i)
self._start_sparse_data_dist(self._batch_i)
self._wait_sparse_data_dist()
self._prefetch(self._batch_i)

# batch 2
self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter)
self._start_sparse_data_dist(self._batch_ip1)
self._wait_sparse_data_dist()

# batch 3
self._batch_ip2 = self._copy_batch_to_gpu(dataloader_iter)

def progress(self, dataloader_iter: Iterator[In]) -> Out:
self._fill_pipeline(dataloader_iter)
Expand All @@ -730,18 +721,15 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
with record_function("## wait_for_batch ##"):
_wait_for_batch(cast(In, self._batch_i), self._prefetch_stream)

self._start_sparse_data_dist(self._batch_ip2)

self._batch_ip3 = self._copy_batch_to_gpu(dataloader_iter)
self._batch_ip2 = self._copy_batch_to_gpu(dataloader_iter)

self._wait_sparse_data_dist()
# forward
with record_function("## forward ##"):
losses, output = cast(Tuple[torch.Tensor, Out], self._model(self._batch_i))

self._prefetch(self._batch_ip1)

self._wait_sparse_data_dist()

if self._model.training:
# backward
with record_function("## backward ##"):
Expand All @@ -751,9 +739,10 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
with record_function("## optimizer ##"):
self._optimizer.step()

self._start_sparse_data_dist(self._batch_ip2)

self._batch_i = self._batch_ip1
self._batch_ip1 = self._batch_ip2
self._batch_ip2 = self._batch_ip3

return output

Expand Down

0 comments on commit 841d019

Please sign in to comment.