Commit
Improve typing for logging (Lightning-AI#10748)
Co-authored-by: Justus Schock <[email protected]>
carmocca and justusschock authored Nov 26, 2021
1 parent 31bb6e6 commit 78face6
Showing 4 changed files with 28 additions and 16 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -120,7 +120,6 @@ module = [
     "pytorch_lightning.trainer.connectors.callback_connector",
     "pytorch_lightning.trainer.connectors.checkpoint_connector",
     "pytorch_lightning.trainer.connectors.data_connector",
-    "pytorch_lightning.trainer.connectors.logger_connector.result",
     "pytorch_lightning.trainer.data_loading",
     "pytorch_lightning.trainer.optimizers",
     "pytorch_lightning.trainer.supporters",
24 changes: 17 additions & 7 deletions pytorch_lightning/core/mixins/device_dtype_mixin.py
@@ -17,6 +17,16 @@
 import torch
 from torch.nn import Module

+try:
+    from typing_extensions import Self
+except ImportError:
+    # workaround for Python 3.6 and 3.7.
+    # see https://www.python.org/dev/peps/pep-0673/
+    from typing import TypeVar
+
+    Self = TypeVar("TDeviceDtypeModuleMixin", bound="DeviceDtypeModuleMixin")
+
+
 import pytorch_lightning as pl


@@ -47,7 +57,7 @@ def device(self) -> Union[str, torch.device]:

         return device

-    def to(self, *args: Any, **kwargs: Any) -> "DeviceDtypeModuleMixin":
+    def to(self, *args: Any, **kwargs: Any) -> Self:
         """Moves and/or casts the parameters and buffers.

         This can be called as
@@ -110,7 +120,7 @@ def to(self, *args: Any, **kwargs: Any) -> "DeviceDtypeModuleMixin":
         self.__update_properties(device=out[0], dtype=out[1])
         return super().to(*args, **kwargs)

-    def cuda(self, device: Optional[Union[torch.device, int]] = None) -> "DeviceDtypeModuleMixin":
+    def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
         """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers
         different objects. So it should be called before constructing optimizer if the module will live on GPU
         while being optimized.
@@ -127,7 +137,7 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> "DeviceDtypeModuleMixin":
         self.__update_properties(device=device)
         return super().cuda(device=device)

-    def cpu(self) -> "DeviceDtypeModuleMixin":
+    def cpu(self) -> Self:
         """Moves all model parameters and buffers to the CPU.

         Returns:
@@ -136,7 +146,7 @@ def cpu(self) -> "DeviceDtypeModuleMixin":
         self.__update_properties(device=torch.device("cpu"))
         return super().cpu()

-    def type(self, dst_type: Union[str, torch.dtype]) -> "DeviceDtypeModuleMixin":
+    def type(self, dst_type: Union[str, torch.dtype]) -> Self:
         """Casts all parameters and buffers to :attr:`dst_type`.

         Arguments:
@@ -148,7 +158,7 @@ def type(self, dst_type: Union[str, torch.dtype]) -> "DeviceDtypeModuleMixin":
         self.__update_properties(dtype=dst_type)
         return super().type(dst_type=dst_type)

-    def float(self) -> "DeviceDtypeModuleMixin":
+    def float(self) -> Self:
         """Casts all floating point parameters and buffers to ``float`` datatype.

         Returns:
@@ -157,7 +167,7 @@ def float(self) -> "DeviceDtypeModuleMixin":
         self.__update_properties(dtype=torch.float)
         return super().float()

-    def double(self) -> "DeviceDtypeModuleMixin":
+    def double(self) -> Self:
         """Casts all floating point parameters and buffers to ``double`` datatype.

         Returns:
@@ -166,7 +176,7 @@ def double(self) -> "DeviceDtypeModuleMixin":
         self.__update_properties(dtype=torch.double)
         return super().double()

-    def half(self) -> "DeviceDtypeModuleMixin":
+    def half(self) -> Self:
         """Casts all floating point parameters and buffers to ``half`` datatype.

         Returns:
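Why `Self` instead of the literal class name: with a `Self` return type, mypy resolves chained device/dtype calls to the caller's own subclass rather than to the mixin, so subclass-specific attributes stay visible after `.to()`, `.cpu()`, `.float()`, and friends. A minimal sketch of the effect, using a hypothetical `MyModel` subclass that is not part of this diff:

from pytorch_lightning.core.mixins.device_dtype_mixin import DeviceDtypeModuleMixin


class MyModel(DeviceDtypeModuleMixin):
    # hypothetical subclass, for illustration only
    def my_method(self) -> None:
        print("still typed as MyModel")


# Chained device/dtype calls keep the subclass type under the new annotations.
model = MyModel().cpu().float()
# With the old `-> "DeviceDtypeModuleMixin"` annotations, mypy types `model` as the
# mixin and flags this call; with `Self`, it is inferred as `MyModel`.
model.my_method()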
17 changes: 10 additions & 7 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -211,8 +211,10 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
             self.add_state("value", torch.tensor(0.0), dist_reduce_fx=torch.sum)
             if self.meta.is_mean_reduction:
                 self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum)
+        # this is defined here only because upstream is missing the type annotation
+        self._forward_cache: Optional[Any] = None

-    def update(self, value: _IN_METRIC, batch_size: int) -> None:
+    def update(self, value: _IN_METRIC, batch_size: int) -> None:  # type: ignore[override]
         if self.is_tensor:
             if not torch.is_floating_point(value):
                 dtype = torch.get_default_dtype()
@@ -225,16 +227,17 @@ def update(self, value: _IN_METRIC, batch_size: int) -> None:

         if self.meta.on_step:
             self._forward_cache = self.meta.sync(value.clone())  # `clone` because `sync` is in-place
-
-        # performance: no need to accumulate on values only logged on_step
-        if not self.meta.on_epoch:
-            self.value = self._forward_cache
-            return
+            # performance: no need to accumulate on values only logged on_step
+            if not self.meta.on_epoch:
+                self.value = self._forward_cache
+                return

         # perform accumulation with reduction
         if self.meta.is_mean_reduction:
             self.value += value.mean() * batch_size
-            self.cumulated_batch_size += batch_size
+            # `Metric.add_state` does not work well with mypy, mypy doesn't know this is a `Tensor`
+            # we could add an assertion, but this is a hot code path
+            self.cumulated_batch_size += batch_size  # type: ignore[operator]
         elif self.meta.is_max_reduction or self.meta.is_min_reduction:
             self.value = self.meta.reduce_fx(self.value, value.mean())
         elif self.meta.is_sum_reduction:
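Some context for the two new `type: ignore` comments (a hedged reading, not text from the commit): the base `torchmetrics.Metric.update` is declared with a different signature, so redefining it with the narrower `(value, batch_size)` parameters trips mypy's override check, and states registered via `Metric.add_state` are created dynamically, so mypy cannot prove `cumulated_batch_size` is a `Tensor` that supports `+=`. The alternative the comment mentions, an `isinstance` assertion, is sketched below with a hypothetical stand-alone metric, `SumWithCount`:

# A minimal sketch (an illustration, not code from the commit) of the trade-off described
# above: assertions narrow the dynamically created states to `Tensor` for mypy, at a small
# runtime cost that the hot path in `update` avoids by using `# type: ignore[operator]`.
import torch
from torchmetrics import Metric


class SumWithCount(Metric):
    """Hypothetical metric computing a batch-size-weighted mean, for illustration only."""

    def __init__(self) -> None:
        super().__init__()
        self.add_state("value", torch.tensor(0.0), dist_reduce_fx=torch.sum)
        self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum)

    def update(self, value: torch.Tensor, batch_size: int) -> None:  # type: ignore[override]
        # Narrow the `add_state` attributes so mypy accepts the in-place arithmetic.
        assert isinstance(self.value, torch.Tensor)
        assert isinstance(self.cumulated_batch_size, torch.Tensor)
        self.value += value.mean() * batch_size
        self.cumulated_batch_size += batch_size

    def compute(self) -> torch.Tensor:
        assert isinstance(self.value, torch.Tensor)
        assert isinstance(self.cumulated_batch_size, torch.Tensor)
        return self.value / self.cumulated_batch_size


# Usage example:
metric = SumWithCount()
metric.update(torch.rand(8), batch_size=8)
print(metric.compute())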
2 changes: 1 addition & 1 deletion requirements.txt
@@ -10,4 +10,4 @@ tensorboard>=2.2.0
 torchmetrics>=0.4.1
 pyDeprecate==0.3.1
 packaging>=17.0
-typing-extensions
+typing-extensions>=4.0.0
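The new floor on typing-extensions lines up with the import added in device_dtype_mixin.py: `Self` (PEP 673) first shipped in typing-extensions 4.0.0, so older installs take the `ImportError` fallback shown above. A quick, hedged way to check which path an environment will use:

# Hedged check: which of the two import paths from the diff an environment will take.
try:
    from typing_extensions import Self  # provided by typing-extensions >= 4.0.0 (PEP 673)
except ImportError:
    print("typing-extensions < 4.0.0: the TypeVar fallback in device_dtype_mixin.py is used")
else:
    print("typing-extensions >= 4.0.0: `Self` is imported directly")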
