From f23247df24cb0bef300ea9383a013f25f7b0d1ba Mon Sep 17 00:00:00 2001 From: Christoph Gerum Date: Fri, 26 Jan 2024 15:19:42 +0100 Subject: [PATCH] Fix errors on appa --- hannah/callbacks/backends.py | 6 ++--- hannah/train.py | 7 ----- hannah/utils/logger.py | 7 ++++- poetry.lock | 50 +++++++++++++++++------------------- pyproject.toml | 3 ++- 5 files changed, 33 insertions(+), 40 deletions(-) diff --git a/hannah/callbacks/backends.py b/hannah/callbacks/backends.py index 91f5864a..68a14bac 100644 --- a/hannah/callbacks/backends.py +++ b/hannah/callbacks/backends.py @@ -119,7 +119,7 @@ def on_validation_epoch_start(self, trainer, pl_module): self.prepare(pl_module) def on_validation_batch_end( - self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx + self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx = -1 ): """ @@ -129,8 +129,6 @@ def on_validation_batch_end( outputs: batch: batch_idx: - dataloader_idx: - Returns: """ @@ -210,7 +208,7 @@ def quantize(self, pl_module: torch.nn.Module) -> torch.nn.Module: return pl_module def on_test_batch_end( - self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx + self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx = -1 ): """ diff --git a/hannah/train.py b/hannah/train.py index 0a8684f1..da67c720 100644 --- a/hannah/train.py +++ b/hannah/train.py @@ -146,11 +146,6 @@ def train( _convert_="partial", ) - if config.get("input_file", None): - msglogger.info("Loading initial weights from model %s", config.input_file) - lit_module.setup("train") - lit_module.load_from_state_dict(config.input_file, strict=False) - if config["auto_lr"]: # run lr finder (counts as one epoch) lr_finder = lit_trainer.lr_find(lit_module) @@ -163,8 +158,6 @@ def train( suggested_lr = lr_finder.suggestion() config["lr"] = suggested_lr - lit_trainer.tune(lit_module) - logging.info("Starting training") # PL TRAIN ckpt_path = None diff --git a/hannah/utils/logger.py b/hannah/utils/logger.py index 7c192e9c..b85048c1 100644 --- a/hannah/utils/logger.py +++ b/hannah/utils/logger.py @@ -26,12 +26,17 @@ from lightning_fabric.loggers.logger import rank_zero_experiment from lightning_fabric.utilities import rank_zero_only, rank_zero_warn -from lightning_fabric.utilities.cloud_io import _is_dir, get_filesystem +from lightning_fabric.utilities.cloud_io import get_filesystem from pytorch_lightning.loggers import Logger from torch import Tensor +import fsspec + log = logging.getLogger(__name__) +def _is_dir(fs, path, strict=False): + return fs.isdir(path) or (not strict and fs.exists(path) and not fs.isfile(path)) + def _add_prefix( metrics: Mapping[str, Union[Tensor, float]], prefix: str, separator: str diff --git a/poetry.lock b/poetry.lock index a6426d5c..d095cf42 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2941,13 +2941,13 @@ zookeeper = ["kazoo (>=2.8.0)"] [[package]] name = "kornia" -version = "0.6.12" +version = "0.7.1" description = "Open Source Differentiable Computer Vision Library for PyTorch" optional = true -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "kornia-0.6.12-py2.py3-none-any.whl", hash = "sha256:659f0f0948e127b69ed437592f49531ccf4fc83d672474e1d89ed30d087e39e1"}, - {file = "kornia-0.6.12.tar.gz", hash = "sha256:e30bd3d830226f7a159dff1f7757c6200e8f27d1333f06e9d2f98bdb33ce18d3"}, + {file = "kornia-0.7.1-py2.py3-none-any.whl", hash = "sha256:bd1cbe99373beafe6e59423be2374afbc2086a9ba57a8c66b94db6622b86f091"}, + {file = "kornia-0.7.1.tar.gz", hash = "sha256:65b54a50f70c1f88240b557fda3fdcc1ab866982a5d062e52213130f5a48465c"}, ] [package.dependencies] @@ -2955,9 +2955,9 @@ packaging = "*" torch = ">=1.9.1" [package.extras] -dev = ["isort", "kornia-rs (==0.0.8)", "mypy[reports]", "numpy", "opencv-python", "pre-commit (>=2)", "pydocstyle", "pytest (==7.3.1)", "pytest-cov (==4)", "scipy"] -docs = ["PyYAML (>=5.1)", "furo", "kornia-moons", "matplotlib", "opencv-python", "sphinx (>=4)", "sphinx-autodoc-defaultargs", "sphinx-autodoc-typehints", "sphinx-copybutton (>=0.3)", "sphinx-design", "sphinxcontrib-bibtex", "sphinxcontrib-gtagjs", "sphinxcontrib-youtube", "torchvision"] -x = ["accelerate (==0.18.0)"] +dev = ["coverage", "kornia-rs (==0.0.8)", "mypy[reports]", "numpy", "onnx", "pre-commit (>=2)", "pydocstyle", "pytest (==7.4.3)", "pytest-timeout"] +docs = ["PyYAML (>=5.1)", "furo", "kornia-moons", "matplotlib", "opencv-python", "sphinx", "sphinx-autodoc-defaultargs", "sphinx-autodoc-typehints", "sphinx-copybutton (>=0.3)", "sphinx-design", "sphinxcontrib-bibtex", "sphinxcontrib-gtagjs", "sphinxcontrib-youtube"] +x = ["accelerate (==0.25.0)", "onnxruntime-gpu (>=1.16)"] [[package]] name = "lazy-loader" @@ -5050,38 +5050,34 @@ validation = ["pydantic (>=1.7.4)"] [[package]] name = "pytorch-lightning" -version = "1.9.5" +version = "2.1.3" description = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "pytorch-lightning-1.9.5.tar.gz", hash = "sha256:925fe7b80ddf04859fa385aa493b260be4000b11a2f22447afb4a932d1f07d26"}, - {file = "pytorch_lightning-1.9.5-py3-none-any.whl", hash = "sha256:06821558158623c5d2ecf5d3d0374dc8bd661e0acd3acf54a6d6f71737c156c5"}, + {file = "pytorch-lightning-2.1.3.tar.gz", hash = "sha256:2500b002fa09cb37b0e12f879876bf30a2d260b0f04783d33264dab175f0c966"}, + {file = "pytorch_lightning-2.1.3-py3-none-any.whl", hash = "sha256:03ed186035a230b161130e0d8ecf1dd6657ff7e3f1520e9257b0db7650f9aeea"}, ] [package.dependencies] -fsspec = {version = ">2021.06.0", extras = ["http"]} -lightning-utilities = ">=0.6.0.post0" +fsspec = {version = ">=2022.5.0", extras = ["http"]} +lightning-utilities = ">=0.8.0" numpy = ">=1.17.2" -packaging = ">=17.1" +packaging = ">=20.0" PyYAML = ">=5.4" -torch = ">=1.10.0" +torch = ">=1.12.0" torchmetrics = ">=0.7.0" tqdm = ">=4.57.0" typing-extensions = ">=4.0.0" [package.extras] -all = ["colossalai (>=0.2.0)", "deepspeed (>=0.6.0)", "fairscale (>=0.4.5)", "gym[classic-control] (>=0.17.0)", "hivemind (==1.1.5)", "horovod (>=0.21.2,!=0.24.0)", "hydra-core (>=1.0.5)", "ipython[all] (<8.7.1)", "jsonargparse[signatures] (>=4.18.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "rich (>=10.14.0,!=10.15.0.a)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.11.1)"] -colossalai = ["colossalai (>=0.2.0)"] -deepspeed = ["deepspeed (>=0.6.0)"] -dev = ["cloudpickle (>=1.3)", "codecov (==2.1.12)", "colossalai (>=0.2.0)", "coverage (==6.5.0)", "deepspeed (>=0.6.0)", "fairscale (>=0.4.5)", "fastapi (<0.87.0)", "gym[classic-control] (>=0.17.0)", "hivemind (==1.1.5)", "horovod (>=0.21.2,!=0.24.0)", "hydra-core (>=1.0.5)", "ipython[all] (<8.7.1)", "jsonargparse[signatures] (>=4.18.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "onnx (<1.14.0)", "onnxruntime (<1.14.0)", "pandas (>1.0)", "pre-commit (==2.20.0)", "protobuf (<=3.20.1)", "psutil (<5.9.5)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-forked (==1.4.0)", "pytest-rerunfailures (==10.3)", "rich (>=10.14.0,!=10.15.0.a)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.11.1)", "uvicorn (<0.19.1)"] -examples = ["gym[classic-control] (>=0.17.0)", "ipython[all] (<8.7.1)", "torchmetrics (>=0.10.0)", "torchvision (>=0.11.1)"] -extra = ["hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.18.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "rich (>=10.14.0,!=10.15.0.a)", "tensorboardX (>=2.2)"] -fairscale = ["fairscale (>=0.4.5)"] -hivemind = ["hivemind (==1.1.5)"] -horovod = ["horovod (>=0.21.2,!=0.24.0)"] -strategies = ["colossalai (>=0.2.0)", "deepspeed (>=0.6.0)", "fairscale (>=0.4.5)", "hivemind (==1.1.5)", "horovod (>=0.21.2,!=0.24.0)"] -test = ["cloudpickle (>=1.3)", "codecov (==2.1.12)", "coverage (==6.5.0)", "fastapi (<0.87.0)", "onnx (<1.14.0)", "onnxruntime (<1.14.0)", "pandas (>1.0)", "pre-commit (==2.20.0)", "protobuf (<=3.20.1)", "psutil (<5.9.5)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-forked (==1.4.0)", "pytest-rerunfailures (==10.3)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "uvicorn (<0.19.1)"] +all = ["bitsandbytes (<=0.41.1)", "deepspeed (>=0.8.2,<=0.9.3)", "gym[classic-control] (>=0.17.0)", "hydra-core (>=1.0.5)", "ipython[all] (<8.15.0)", "jsonargparse[signatures] (>=4.26.1)", "lightning-utilities (>=0.8.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "rich (>=12.3.0)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.13.0)"] +deepspeed = ["deepspeed (>=0.8.2,<=0.9.3)"] +dev = ["bitsandbytes (<=0.41.1)", "cloudpickle (>=1.3)", "coverage (==7.3.1)", "deepspeed (>=0.8.2,<=0.9.3)", "fastapi", "gym[classic-control] (>=0.17.0)", "hydra-core (>=1.0.5)", "ipython[all] (<8.15.0)", "jsonargparse[signatures] (>=4.26.1)", "lightning-utilities (>=0.8.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "onnx (>=0.14.0)", "onnxruntime (>=0.15.0)", "pandas (>1.0)", "psutil (<5.9.6)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "rich (>=12.3.0)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.13.0)", "uvicorn"] +examples = ["gym[classic-control] (>=0.17.0)", "ipython[all] (<8.15.0)", "lightning-utilities (>=0.8.0)", "torchmetrics (>=0.10.0)", "torchvision (>=0.13.0)"] +extra = ["bitsandbytes (<=0.41.1)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.26.1)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "rich (>=12.3.0)", "tensorboardX (>=2.2)"] +strategies = ["deepspeed (>=0.8.2,<=0.9.3)"] +test = ["cloudpickle (>=1.3)", "coverage (==7.3.1)", "fastapi", "onnx (>=0.14.0)", "onnxruntime (>=0.15.0)", "pandas (>1.0)", "psutil (<5.9.6)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "uvicorn"] [[package]] name = "pytz" @@ -7170,4 +7166,4 @@ vision = ["albumentations", "gdown", "imagecorruptions", "kornia", "pycocotools" [metadata] lock-version = "2.0" python-versions = ">3.8 <3.12" -content-hash = "3e8a18b1bd8aa11b328d6ac9f0b1f61c85aa24ef9d1ea47a9b015b5fb322b3b5" +content-hash = "735bec380b7040d57a2e260bcfb530321ce7cd13d42b2338abf8798848f0e9ac" diff --git a/pyproject.toml b/pyproject.toml index 01600b36..0d8a3d76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,12 +61,13 @@ timm = {version = "^0.9.12", optional = true} pycocotools = {version = "^2.0.6", optional = true} gdown = {version = "^4.5.3", optional = true} albumentations = {version = "^1.3.0", optional = true} -kornia = {version="^0.6.4", optional=true} +kornia = {version = "^0.7.1", optional = true} lightning = "^2.1.2" dvc = {version="^3.33.3", optional=true} dvclive = {version="^3.4.1", optional=true} dgl = "1.1.3" captum = {version = "^0.7.0", optional = true} +pytorch-lightning = "^2.1.3"