Skip to content

Commit

Permalink
Fix errors on appa
Browse files Browse the repository at this point in the history
  • Loading branch information
Christoph Gerum committed Jan 26, 2024
1 parent d12925f commit f23247d
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 40 deletions.
6 changes: 2 additions & 4 deletions hannah/callbacks/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def on_validation_epoch_start(self, trainer, pl_module):
self.prepare(pl_module)

def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx = -1
):
"""
Expand All @@ -129,8 +129,6 @@ def on_validation_batch_end(
outputs:
batch:
batch_idx:
dataloader_idx:
Returns:
"""
Expand Down Expand Up @@ -210,7 +208,7 @@ def quantize(self, pl_module: torch.nn.Module) -> torch.nn.Module:
return pl_module

def on_test_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx = -1
):
"""
Expand Down
7 changes: 0 additions & 7 deletions hannah/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,6 @@ def train(
_convert_="partial",
)

if config.get("input_file", None):
msglogger.info("Loading initial weights from model %s", config.input_file)
lit_module.setup("train")
lit_module.load_from_state_dict(config.input_file, strict=False)

if config["auto_lr"]:
# run lr finder (counts as one epoch)
lr_finder = lit_trainer.lr_find(lit_module)
Expand All @@ -163,8 +158,6 @@ def train(
suggested_lr = lr_finder.suggestion()
config["lr"] = suggested_lr

lit_trainer.tune(lit_module)

logging.info("Starting training")
# PL TRAIN
ckpt_path = None
Expand Down
7 changes: 6 additions & 1 deletion hannah/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,17 @@

from lightning_fabric.loggers.logger import rank_zero_experiment
from lightning_fabric.utilities import rank_zero_only, rank_zero_warn
from lightning_fabric.utilities.cloud_io import _is_dir, get_filesystem
from lightning_fabric.utilities.cloud_io import get_filesystem
from pytorch_lightning.loggers import Logger
from torch import Tensor

import fsspec

log = logging.getLogger(__name__)

def _is_dir(fs, path, strict=False):
return fs.isdir(path) or (not strict and fs.exists(path) and not fs.isfile(path))


def _add_prefix(
metrics: Mapping[str, Union[Tensor, float]], prefix: str, separator: str
Expand Down
50 changes: 23 additions & 27 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,13 @@ timm = {version = "^0.9.12", optional = true}
pycocotools = {version = "^2.0.6", optional = true}
gdown = {version = "^4.5.3", optional = true}
albumentations = {version = "^1.3.0", optional = true}
kornia = {version="^0.6.4", optional=true}
kornia = {version = "^0.7.1", optional = true}
lightning = "^2.1.2"
dvc = {version="^3.33.3", optional=true}
dvclive = {version="^3.4.1", optional=true}
dgl = "1.1.3"
captum = {version = "^0.7.0", optional = true}
pytorch-lightning = "^2.1.3"



Expand Down

0 comments on commit f23247d

Please sign in to comment.