Skip to content

Commit

Permalink
BaseTask: fix load_from_checkpoint, ignore 'ignore' (#2317)
Browse files Browse the repository at this point in the history
* BaseTask: fix load_from_checkpoint, ignore 'ignore'

* Fix ObjectDetectionTask
  • Loading branch information
adamjstewart authored Sep 28, 2024
1 parent 1a98078 commit 69f91a2
Show file tree
Hide file tree
Showing 8 changed files with 14 additions and 8 deletions.
7 changes: 5 additions & 2 deletions torchgeo/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ class BaseTask(LightningModule, ABC):
.. versionadded:: 0.5
"""

#: Parameters to ignore when saving hyperparameters.
ignore: Sequence[str] | str | None = 'weights'

#: Model to train.
model: Any

Expand All @@ -28,14 +31,14 @@ class BaseTask(LightningModule, ABC):
#: Whether the goal is to minimize or maximize the performance metric to monitor.
mode = 'min'

def __init__(self, ignore: Sequence[str] | str | None = None) -> None:
def __init__(self) -> None:
"""Initialize a new BaseTask instance.
Args:
ignore: Arguments to skip when saving hyperparameters.
"""
super().__init__()
self.save_hyperparameters(ignore=ignore)
self.save_hyperparameters(ignore=self.ignore)
self.configure_models()
self.configure_losses()
self.configure_metrics()
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def __init__(
renamed to *model*, *lr*, and *patience*.
"""
self.weights = weights
super().__init__(ignore='weights')
super().__init__()

def configure_models(self) -> None:
"""Initialize the model."""
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class and used with 'ce' loss.
*lr* and *patience*.
"""
self.weights = weights
super().__init__(ignore='weights')
super().__init__()

def configure_models(self) -> None:
"""Initialize the model."""
Expand Down
1 change: 1 addition & 0 deletions torchgeo/trainers/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class ObjectDetectionTask(BaseTask):
.. versionadded:: 0.4
"""

ignore = None
monitor = 'val_map'
mode = 'max'

Expand Down
3 changes: 2 additions & 1 deletion torchgeo/trainers/moco.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ class MoCoTask(BaseTask):
.. versionadded:: 0.5
"""

ignore = ('weights', 'augmentation1', 'augmentation2')
monitor = 'train_loss'

def __init__(
Expand Down Expand Up @@ -219,7 +220,7 @@ def __init__(
warnings.warn('MoCo v3 does not use a memory bank')

self.weights = weights
super().__init__(ignore=['weights', 'augmentation1', 'augmentation2'])
super().__init__()

grayscale_weights = grayscale_weights or torch.ones(in_channels)
aug1, aug2 = moco_augmentations(version, size, grayscale_weights)
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
*lr* and *patience*.
"""
self.weights = weights
super().__init__(ignore='weights')
super().__init__()

def configure_models(self) -> None:
"""Initialize the model."""
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class and used with 'ce' loss.
The *ignore_index* parameter now works for jaccard loss.
"""
self.weights = weights
super().__init__(ignore='weights')
super().__init__()

def configure_models(self) -> None:
"""Initialize the model.
Expand Down
3 changes: 2 additions & 1 deletion torchgeo/trainers/simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class SimCLRTask(BaseTask):
.. versionadded:: 0.5
"""

ignore = ('weights', 'augmentations')
monitor = 'train_loss'

def __init__(
Expand Down Expand Up @@ -140,7 +141,7 @@ def __init__(
warnings.warn('SimCLR v2 uses a memory bank')

self.weights = weights
super().__init__(ignore=['weights', 'augmentations'])
super().__init__()

grayscale_weights = grayscale_weights or torch.ones(in_channels)
self.augmentations = augmentations or simclr_augmentations(
Expand Down

0 comments on commit 69f91a2

Please sign in to comment.