Avoid optional instances in Loops (Lightning-AI#10735)
* Avoid optional instances in Loops

* More cleanup
carmocca authored Nov 26, 2021
1 parent ae53562 commit 31bb6e6
Showing 6 changed files with 8 additions and 20 deletions.
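
The theme of the commit: child loops are now default-constructed in __init__ rather than declared as Optional[...] = None, so code that uses them no longer needs "assert ... is not None" guards for mypy or for runtime safety. A minimal before/after sketch of the pattern, using made-up classes rather than Lightning's real loops:

from typing import Optional


class ChildLoop:
    def run(self) -> None:
        print("child running")


# Before: the attribute starts out as None, so every caller must narrow it first.
class ParentLoopBefore:
    def __init__(self) -> None:
        self.child: Optional[ChildLoop] = None

    def advance(self) -> None:
        assert self.child is not None  # guard needed to satisfy mypy and avoid AttributeError
        self.child.run()


# After: a default instance is always present; connect() only swaps in a replacement.
class ParentLoopAfter:
    def __init__(self) -> None:
        self.child = ChildLoop()

    def connect(self, child: Optional[ChildLoop] = None) -> None:
        if child is not None:
            self.child = child

    def advance(self) -> None:
        self.child.run()  # no guard needed; the attribute is never None
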
pytorch_lightning/loops/batch/training_batch_loop.py (1 addition, 1 deletion)
@@ -48,7 +48,7 @@ def done(self) -> bool:
         return len(self._remaining_splits) == 0

     def connect(
-        self, optimizer_loop: Optional["Loop"] = None, manual_loop: Optional[ManualOptimization] = None
+        self, optimizer_loop: Optional[OptimizerLoop] = None, manual_loop: Optional[ManualOptimization] = None
     ) -> None:
         if optimizer_loop is not None:
             self.optimizer_loop = optimizer_loop
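
The only change in this file narrows the connect() annotation from the stringified base class Optional["Loop"] to the concrete Optional[OptimizerLoop]. A hedged illustration of what the narrower hint buys; the import paths below are assumed from the 1.5-era loops package and may differ:

from pytorch_lightning.loops.batch import TrainingBatchLoop
from pytorch_lightning.loops.optimization import ManualOptimization, OptimizerLoop

batch_loop = TrainingBatchLoop()
batch_loop.connect(optimizer_loop=OptimizerLoop())     # fine under either annotation
batch_loop.connect(manual_loop=ManualOptimization())   # the dedicated keyword for manual optimization
# batch_loop.connect(optimizer_loop=ManualOptimization())  # type-checked before; rejected by mypy now
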
pytorch_lightning/loops/epoch/evaluation_epoch_loop.py (2 additions, 5 deletions)
@@ -49,16 +49,13 @@ def __init__(self) -> None:
         self._num_dataloaders: Optional[int] = None
         self._dataloader_iter: Optional[Iterator] = None
         self._data_fetcher: Optional[DataFetcher] = None
-        self._dataloader_state_dict: Dict[str, Any] = None
+        self._dataloader_state_dict: Dict[str, Any] = {}

     @property
     def done(self) -> bool:
         """Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
         return self.batch_progress.current.completed >= self._dl_max_batches

-    def connect(self, **kwargs: "Loop") -> None:
-        raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")
-
     def reset(self) -> None:
         """Resets the loop's internal state."""
         self._dl_max_batches = None
@@ -192,7 +189,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None:
     def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher):
         if not self.trainer.sanity_checking and self._dataloader_state_dict:
             _reload_dataloader_state_dict(data_fetcher.dataloader, self._dataloader_state_dict)
-            self._dataloader_state_dict = None
+            self._dataloader_state_dict = {}

     def _num_completed_batches_reached(self) -> bool:
         epoch_finished_on_completed = self.batch_progress.current.completed == self._dl_max_batches
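
The same idea applied to plain state: an empty dict, not None, is the "no saved state" value, which keeps the Dict[str, Any] annotation honest while the existing truthiness check behaves exactly as before. A small sketch of the idiom with illustrative names, not the loop's real API:

from typing import Any, Dict


class DataloaderStateHolder:
    def __init__(self) -> None:
        # {} instead of None: callers never see a None where a dict is promised.
        self._dataloader_state_dict: Dict[str, Any] = {}

    def on_load_checkpoint(self, state: Dict[str, Any]) -> None:
        self._dataloader_state_dict = state

    def reload_if_needed(self) -> None:
        if self._dataloader_state_dict:  # an empty dict is falsy, same outcome as the old None check
            ...  # restore the dataloader from the saved state here
            self._dataloader_state_dict = {}  # reset to "no state" once consumed
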
pytorch_lightning/loops/epoch/training_epoch_loop.py (3 additions, 5 deletions)
@@ -62,8 +62,8 @@ def __init__(self, min_steps: Optional[int] = 0, max_steps: int = -1) -> None:
         self.batch_progress = BatchProgress()
         self.scheduler_progress = SchedulerProgress()

-        self.batch_loop: Optional[TrainingBatchLoop] = None
-        self.val_loop: Optional["loops.EvaluationLoop"] = None
+        self.batch_loop = TrainingBatchLoop()
+        self.val_loop = loops.EvaluationLoop()

         self._results = ResultCollection(training=True)
         self._outputs: _OUTPUTS_TYPE = []
@@ -107,7 +107,7 @@ def done(self) -> bool:

     def connect(
         self,
-        batch_loop: TrainingBatchLoop = None,
+        batch_loop: Optional[TrainingBatchLoop] = None,
         val_loop: Optional["loops.EvaluationLoop"] = None,
     ) -> None:
         """Optionally connect a custom batch or validation loop to this training epoch loop."""
@@ -118,8 +118,6 @@

     def reset(self) -> None:
         """Resets the internal state of the loop for a new run."""
-        assert self.batch_loop is not None
-        assert self.batch_loop.optimizer_loop is not None
         if self.restarting:
             self.batch_progress.reset_on_restart()
             self.scheduler_progress.reset_on_restart()
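
Since TrainingEpochLoop now builds its batch and validation loops eagerly, connect() is purely an override hook: anything not passed keeps its default instance. A hedged usage sketch; the subclass is a placeholder and the constructor arguments mirror the signature shown above:

from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop


class MyBatchLoop(TrainingBatchLoop):
    """Hypothetical custom batch loop, shown only to illustrate connect()."""


epoch_loop = TrainingEpochLoop(min_steps=0, max_steps=1000)
# batch_loop and val_loop already exist as default instances at this point.
epoch_loop.connect(batch_loop=MyBatchLoop())  # replaces only the batch loop; val_loop keeps its default
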
pytorch_lightning/loops/fit_loop.py (1 addition, 5 deletions)
@@ -48,7 +48,7 @@ def __init__(

         self.max_epochs = max_epochs
         self.min_epochs = min_epochs
-        self.epoch_loop: Optional[TrainingEpochLoop] = None
+        self.epoch_loop = TrainingEpochLoop()
         self.epoch_progress = Progress()
         self._is_fresh_start_epoch: bool = True

@@ -128,15 +128,11 @@ def running_loss(self) -> TensorRunningAccum:
     @property
     def _skip_backward(self) -> bool:
         """Determines whether the loop will skip backward during automatic optimization."""
-        assert self.epoch_loop.batch_loop is not None
-        assert self.epoch_loop.batch_loop.optimizer_loop is not None
         return self.epoch_loop.batch_loop.optimizer_loop._skip_backward

     @_skip_backward.setter
     def _skip_backward(self, value: bool) -> None:
         """Determines whether the loop will skip backward during automatic optimization."""
-        assert self.epoch_loop.batch_loop is not None
-        assert self.epoch_loop.batch_loop.optimizer_loop is not None
         self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value

     @property
pytorch_lightning/loops/optimization/optimizer_loop.py (0 additions, 3 deletions)
@@ -268,9 +268,6 @@ def _run_optimization(
             # if no result, user decided to skip optimization
             # otherwise update running loss + reset accumulated loss
             # TODO: find proper way to handle updating running loss
-            assert self.trainer.fit_loop is not None
-            assert self.trainer.fit_loop.epoch_loop is not None
-            assert self.trainer.fit_loop.epoch_loop.batch_loop is not None
             self.trainer.fit_loop.epoch_loop.batch_loop._update_running_loss(result.loss)

             # untoggle model params
tests/loops/test_loops.py (1 addition, 1 deletion)
@@ -61,7 +61,7 @@ def test_connect_loops_direct(loop_name):

     trainer = Trainer()

-    # trainer.loop = loop
+    # trainer.loop_name = loop
     setattr(trainer, loop_name, loop)
     assert loop.trainer is trainer
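
The fixed comment spells out what the dynamic setattr call does for a given loop_name. For context, a sketch of the direct-attachment idiom this test exercises, with an illustrative FitLoop subclass standing in for the test's own loop object:

from pytorch_lightning import Trainer
from pytorch_lightning.loops import FitLoop


class MyFitLoop(FitLoop):
    """Placeholder subclass used only for illustration."""


trainer = Trainer()
loop = MyFitLoop()
trainer.fit_loop = loop           # what setattr(trainer, "fit_loop", loop) does for loop_name == "fit_loop"
assert loop.trainer is trainer    # attaching the loop back-references the trainer, as the test asserts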

