From d28101a97edb89fe0a0fc1da052fec91c4295bb0 Mon Sep 17 00:00:00 2001 From: Sofiia Kashperova <47923051+kashperova@users.noreply.github.com> Date: Sat, 21 Dec 2024 16:33:58 +0200 Subject: [PATCH] Logger & trainer tests (#5) * Add tensorboard logger * Minor * Update affine coupling * Fix flow block Add autoflake * Minor fixes * Add tests for trainer --- .gitignore | 3 + .misc/notes.md | 5 + .pre-commit-config.yaml | 10 + poetry.lock | 197 ++++++++++++++++++- pyproject.toml | 1 + src/configs/celeba.yaml | 10 +- src/configs/model.yaml | 0 src/model/affine_coupling/affine_coupling.py | 19 +- src/model/affine_coupling/net.py | 8 +- src/model/flow_block.py | 2 +- src/model/glow.py | 6 +- src/model/invert_conv.py | 4 + src/modules/logger/__init__.py | 1 + src/modules/logger/logger.py | 30 +++ src/modules/trainer/ddp.py | 0 src/modules/trainer/ddp_trainer.py | 5 +- src/modules/trainer/trainer.py | 72 +++++-- src/modules/utils/tensors.py | 7 + tests/conftest.py | 1 + tests/fixtures/blocks.py | 6 +- tests/fixtures/config.py | 19 ++ tests/fixtures/inputs.py | 10 + tests/fixtures/trainer.py | 47 +++++ tests/test_actnorm.py | 2 +- tests/test_affine_coupling.py | 29 ++- tests/test_flow.py | 22 ++- tests/test_trainer.py | 59 ++++++ 27 files changed, 500 insertions(+), 75 deletions(-) create mode 100644 .misc/notes.md delete mode 100644 src/configs/model.yaml create mode 100644 src/modules/logger/logger.py delete mode 100644 src/modules/trainer/ddp.py create mode 100644 tests/fixtures/trainer.py create mode 100644 tests/test_trainer.py diff --git a/.gitignore b/.gitignore index 95ec0e3..2e93839 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,4 @@ todo.py +./runs +./samples +./.misc/notebooks diff --git a/.misc/notes.md b/.misc/notes.md new file mode 100644 index 0000000..a2edbcf --- /dev/null +++ b/.misc/notes.md @@ -0,0 +1,5 @@ +- Norm flows can't can't work with discrete random variables, so we need to dequantize input image tensors. +Here the simplest solution [implemented](../src/modules/utils/tensors.py): adding a small amount of noise to each discrete value. +But in general it is better to use variational dequantization. +- Read more about KL duality +- Jacobian can be interpreted as an indicator of how the volume of the probability space changes diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2e2afaf..3298f4b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,6 +23,16 @@ repos: hooks: - id: isort args: [ "--profile", "black", "--filter-files" ] + - repo: https://github.com/PyCQA/autoflake + rev: v2.3.1 + hooks: + - id: autoflake + args: + - '--remove-all-unused-imports' + - '--remove-unused-variables' + - '--exclude=__init__.py' + - '--in-place' + - '--recursive' - repo: https://github.com/PyCQA/flake8 rev: 7.1.1 hooks: diff --git a/poetry.lock b/poetry.lock index 2c72a4f..c1bf610 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,16 @@ # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +[[package]] +name = "absl-py" +version = "2.1.0" +description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." +optional = false +python-versions = ">=3.7" +files = [ + {file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"}, + {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"}, +] + [[package]] name = "antlr4-python3-runtime" version = "4.9.3" @@ -98,6 +109,73 @@ test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask-expr", "dask[dataframe, test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"] tqdm = ["tqdm"] +[[package]] +name = "grpcio" +version = "1.68.1" +description = "HTTP/2-based RPC framework" +optional = false +python-versions = ">=3.8" +files = [ + {file = "grpcio-1.68.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:d35740e3f45f60f3c37b1e6f2f4702c23867b9ce21c6410254c9c682237da68d"}, + {file = "grpcio-1.68.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:d99abcd61760ebb34bdff37e5a3ba333c5cc09feda8c1ad42547bea0416ada78"}, + {file = "grpcio-1.68.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:f8261fa2a5f679abeb2a0a93ad056d765cdca1c47745eda3f2d87f874ff4b8c9"}, + {file = "grpcio-1.68.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0feb02205a27caca128627bd1df4ee7212db051019a9afa76f4bb6a1a80ca95e"}, + {file = "grpcio-1.68.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:919d7f18f63bcad3a0f81146188e90274fde800a94e35d42ffe9eadf6a9a6330"}, + {file = "grpcio-1.68.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:963cc8d7d79b12c56008aabd8b457f400952dbea8997dd185f155e2f228db079"}, + {file = "grpcio-1.68.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ccf2ebd2de2d6661e2520dae293298a3803a98ebfc099275f113ce1f6c2a80f1"}, + {file = "grpcio-1.68.1-cp310-cp310-win32.whl", hash = "sha256:2cc1fd04af8399971bcd4f43bd98c22d01029ea2e56e69c34daf2bf8470e47f5"}, + {file = "grpcio-1.68.1-cp310-cp310-win_amd64.whl", hash = "sha256:ee2e743e51cb964b4975de572aa8fb95b633f496f9fcb5e257893df3be854746"}, + {file = "grpcio-1.68.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:55857c71641064f01ff0541a1776bfe04a59db5558e82897d35a7793e525774c"}, + {file = "grpcio-1.68.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4b177f5547f1b995826ef529d2eef89cca2f830dd8b2c99ffd5fde4da734ba73"}, + {file = "grpcio-1.68.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:3522c77d7e6606d6665ec8d50e867f13f946a4e00c7df46768f1c85089eae515"}, + {file = "grpcio-1.68.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9d1fae6bbf0816415b81db1e82fb3bf56f7857273c84dcbe68cbe046e58e1ccd"}, + {file = "grpcio-1.68.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:298ee7f80e26f9483f0b6f94cc0a046caf54400a11b644713bb5b3d8eb387600"}, + {file = "grpcio-1.68.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cbb5780e2e740b6b4f2d208e90453591036ff80c02cc605fea1af8e6fc6b1bbe"}, + {file = "grpcio-1.68.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ddda1aa22495d8acd9dfbafff2866438d12faec4d024ebc2e656784d96328ad0"}, + {file = "grpcio-1.68.1-cp311-cp311-win32.whl", hash = "sha256:b33bd114fa5a83f03ec6b7b262ef9f5cac549d4126f1dc702078767b10c46ed9"}, + {file = "grpcio-1.68.1-cp311-cp311-win_amd64.whl", hash = "sha256:7f20ebec257af55694d8f993e162ddf0d36bd82d4e57f74b31c67b3c6d63d8b2"}, + {file = "grpcio-1.68.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:8829924fffb25386995a31998ccbbeaa7367223e647e0122043dfc485a87c666"}, + {file = "grpcio-1.68.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:3aed6544e4d523cd6b3119b0916cef3d15ef2da51e088211e4d1eb91a6c7f4f1"}, + {file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:4efac5481c696d5cb124ff1c119a78bddbfdd13fc499e3bc0ca81e95fc573684"}, + {file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ab2d912ca39c51f46baf2a0d92aa265aa96b2443266fc50d234fa88bf877d8e"}, + {file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95c87ce2a97434dffe7327a4071839ab8e8bffd0054cc74cbe971fba98aedd60"}, + {file = "grpcio-1.68.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:e4842e4872ae4ae0f5497bf60a0498fa778c192cc7a9e87877abd2814aca9475"}, + {file = "grpcio-1.68.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:255b1635b0ed81e9f91da4fcc8d43b7ea5520090b9a9ad9340d147066d1d3613"}, + {file = "grpcio-1.68.1-cp312-cp312-win32.whl", hash = "sha256:7dfc914cc31c906297b30463dde0b9be48e36939575eaf2a0a22a8096e69afe5"}, + {file = "grpcio-1.68.1-cp312-cp312-win_amd64.whl", hash = "sha256:a0c8ddabef9c8f41617f213e527254c41e8b96ea9d387c632af878d05db9229c"}, + {file = "grpcio-1.68.1-cp313-cp313-linux_armv7l.whl", hash = "sha256:a47faedc9ea2e7a3b6569795c040aae5895a19dde0c728a48d3c5d7995fda385"}, + {file = "grpcio-1.68.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:390eee4225a661c5cd133c09f5da1ee3c84498dc265fd292a6912b65c421c78c"}, + {file = "grpcio-1.68.1-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:66a24f3d45c33550703f0abb8b656515b0ab777970fa275693a2f6dc8e35f1c1"}, + {file = "grpcio-1.68.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c08079b4934b0bf0a8847f42c197b1d12cba6495a3d43febd7e99ecd1cdc8d54"}, + {file = "grpcio-1.68.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8720c25cd9ac25dd04ee02b69256d0ce35bf8a0f29e20577427355272230965a"}, + {file = "grpcio-1.68.1-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:04cfd68bf4f38f5bb959ee2361a7546916bd9a50f78617a346b3aeb2b42e2161"}, + {file = "grpcio-1.68.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:c28848761a6520c5c6071d2904a18d339a796ebe6b800adc8b3f474c5ce3c3ad"}, + {file = "grpcio-1.68.1-cp313-cp313-win32.whl", hash = "sha256:77d65165fc35cff6e954e7fd4229e05ec76102d4406d4576528d3a3635fc6172"}, + {file = "grpcio-1.68.1-cp313-cp313-win_amd64.whl", hash = "sha256:a8040f85dcb9830d8bbb033ae66d272614cec6faceee88d37a88a9bd1a7a704e"}, + {file = "grpcio-1.68.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:eeb38ff04ab6e5756a2aef6ad8d94e89bb4a51ef96e20f45c44ba190fa0bcaad"}, + {file = "grpcio-1.68.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8a3869a6661ec8f81d93f4597da50336718bde9eb13267a699ac7e0a1d6d0bea"}, + {file = "grpcio-1.68.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:2c4cec6177bf325eb6faa6bd834d2ff6aa8bb3b29012cceb4937b86f8b74323c"}, + {file = "grpcio-1.68.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:12941d533f3cd45d46f202e3667be8ebf6bcb3573629c7ec12c3e211d99cfccf"}, + {file = "grpcio-1.68.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80af6f1e69c5e68a2be529990684abdd31ed6622e988bf18850075c81bb1ad6e"}, + {file = "grpcio-1.68.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e8dbe3e00771bfe3d04feed8210fc6617006d06d9a2679b74605b9fed3e8362c"}, + {file = "grpcio-1.68.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:83bbf5807dc3ee94ce1de2dfe8a356e1d74101e4b9d7aa8c720cc4818a34aded"}, + {file = "grpcio-1.68.1-cp38-cp38-win32.whl", hash = "sha256:8cb620037a2fd9eeee97b4531880e439ebfcd6d7d78f2e7dcc3726428ab5ef63"}, + {file = "grpcio-1.68.1-cp38-cp38-win_amd64.whl", hash = "sha256:52fbf85aa71263380d330f4fce9f013c0798242e31ede05fcee7fbe40ccfc20d"}, + {file = "grpcio-1.68.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:cb400138e73969eb5e0535d1d06cae6a6f7a15f2cc74add320e2130b8179211a"}, + {file = "grpcio-1.68.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a1b988b40f2fd9de5c820f3a701a43339d8dcf2cb2f1ca137e2c02671cc83ac1"}, + {file = "grpcio-1.68.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:96f473cdacfdd506008a5d7579c9f6a7ff245a9ade92c3c0265eb76cc591914f"}, + {file = "grpcio-1.68.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:37ea3be171f3cf3e7b7e412a98b77685eba9d4fd67421f4a34686a63a65d99f9"}, + {file = "grpcio-1.68.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ceb56c4285754e33bb3c2fa777d055e96e6932351a3082ce3559be47f8024f0"}, + {file = "grpcio-1.68.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:dffd29a2961f3263a16d73945b57cd44a8fd0b235740cb14056f0612329b345e"}, + {file = "grpcio-1.68.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:025f790c056815b3bf53da850dd70ebb849fd755a4b1ac822cb65cd631e37d43"}, + {file = "grpcio-1.68.1-cp39-cp39-win32.whl", hash = "sha256:1098f03dedc3b9810810568060dea4ac0822b4062f537b0f53aa015269be0a76"}, + {file = "grpcio-1.68.1-cp39-cp39-win_amd64.whl", hash = "sha256:334ab917792904245a028f10e803fcd5b6f36a7b2173a820c0b5b076555825e1"}, + {file = "grpcio-1.68.1.tar.gz", hash = "sha256:44a8502dd5de653ae6a73e2de50a401d84184f0331d0ac3daeb044e66d5c5054"}, +] + +[package.extras] +protobuf = ["grpcio-tools (>=1.68.1)"] + [[package]] name = "hydra-core" version = "1.3.2" @@ -156,6 +234,21 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "markdown" +version = "3.7" +description = "Python implementation of John Gruber's Markdown." +optional = false +python-versions = ">=3.8" +files = [ + {file = "Markdown-3.7-py3-none-any.whl", hash = "sha256:7eb6df5690b81a1d7942992c97fad2938e956e79df20cbc6186e9c3a77b1c803"}, + {file = "markdown-3.7.tar.gz", hash = "sha256:2ae2471477cfd02dbbf038d5d9bc226d40def84b4fe2986e49b59b6b472bbed2"}, +] + +[package.extras] +docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] +testing = ["coverage", "pyyaml"] + [[package]] name = "markupsafe" version = "3.0.2" @@ -663,6 +756,26 @@ nodeenv = ">=0.11.1" pyyaml = ">=5.1" virtualenv = ">=20.10.0" +[[package]] +name = "protobuf" +version = "5.29.2" +description = "" +optional = false +python-versions = ">=3.8" +files = [ + {file = "protobuf-5.29.2-cp310-abi3-win32.whl", hash = "sha256:c12ba8249f5624300cf51c3d0bfe5be71a60c63e4dcf51ffe9a68771d958c851"}, + {file = "protobuf-5.29.2-cp310-abi3-win_amd64.whl", hash = "sha256:842de6d9241134a973aab719ab42b008a18a90f9f07f06ba480df268f86432f9"}, + {file = "protobuf-5.29.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a0c53d78383c851bfa97eb42e3703aefdc96d2036a41482ffd55dc5f529466eb"}, + {file = "protobuf-5.29.2-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:494229ecd8c9009dd71eda5fd57528395d1eacdf307dbece6c12ad0dd09e912e"}, + {file = "protobuf-5.29.2-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:b6b0d416bbbb9d4fbf9d0561dbfc4e324fd522f61f7af0fe0f282ab67b22477e"}, + {file = "protobuf-5.29.2-cp38-cp38-win32.whl", hash = "sha256:e621a98c0201a7c8afe89d9646859859be97cb22b8bf1d8eacfd90d5bda2eb19"}, + {file = "protobuf-5.29.2-cp38-cp38-win_amd64.whl", hash = "sha256:13d6d617a2a9e0e82a88113d7191a1baa1e42c2cc6f5f1398d3b054c8e7e714a"}, + {file = "protobuf-5.29.2-cp39-cp39-win32.whl", hash = "sha256:36000f97ea1e76e8398a3f02936aac2a5d2b111aae9920ec1b769fc4a222c4d9"}, + {file = "protobuf-5.29.2-cp39-cp39-win_amd64.whl", hash = "sha256:2d2e674c58a06311c8e99e74be43e7f3a8d1e2b2fdf845eaa347fbd866f23355"}, + {file = "protobuf-5.29.2-py3-none-any.whl", hash = "sha256:fde4554c0e578a5a0bcc9a276339594848d1e89f9ea47b4427c80e5d72f90181"}, + {file = "protobuf-5.29.2.tar.gz", hash = "sha256:b2cc8e8bb7c9326996f0e160137b0861f1a82162502658df2951209d0cb0309e"}, +] + [[package]] name = "pytest" version = "8.3.4" @@ -745,6 +858,37 @@ files = [ {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, ] +[[package]] +name = "setuptools" +version = "75.6.0" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +optional = false +python-versions = ">=3.9" +files = [ + {file = "setuptools-75.6.0-py3-none-any.whl", hash = "sha256:ce74b49e8f7110f9bf04883b730f4765b774ef3ef28f722cce7c273d253aaf7d"}, + {file = "setuptools-75.6.0.tar.gz", hash = "sha256:8199222558df7c86216af4f84c30e9b34a61d8ba19366cc914424cdbd28252f6"}, +] + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.7.0)"] +core = ["importlib_metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (>=1.12,<1.14)", "pytest-mypy"] + +[[package]] +name = "six" +version = "1.17.0" +description = "Python 2 and 3 compatibility utilities" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, + {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, +] + [[package]] name = "sympy" version = "1.13.3" @@ -762,6 +906,40 @@ mpmath = ">=1.1.0,<1.4" [package.extras] dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] +[[package]] +name = "tensorboard" +version = "2.18.0" +description = "TensorBoard lets you watch Tensors Flow" +optional = false +python-versions = ">=3.9" +files = [ + {file = "tensorboard-2.18.0-py3-none-any.whl", hash = "sha256:107ca4821745f73e2aefa02c50ff70a9b694f39f790b11e6f682f7d326745eab"}, +] + +[package.dependencies] +absl-py = ">=0.4" +grpcio = ">=1.48.2" +markdown = ">=2.6.8" +numpy = ">=1.12.0" +packaging = "*" +protobuf = ">=3.19.6,<4.24.0 || >4.24.0" +setuptools = ">=41.0.0" +six = ">1.9" +tensorboard-data-server = ">=0.7.0,<0.8.0" +werkzeug = ">=1.0.1" + +[[package]] +name = "tensorboard-data-server" +version = "0.7.2" +description = "Fast data loading for TensorBoard" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"}, + {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"}, + {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"}, +] + [[package]] name = "torch" version = "2.4.0" @@ -931,7 +1109,24 @@ platformdirs = ">=3.9.1,<5" docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] +[[package]] +name = "werkzeug" +version = "3.1.3" +description = "The comprehensive WSGI web application library." +optional = false +python-versions = ">=3.9" +files = [ + {file = "werkzeug-3.1.3-py3-none-any.whl", hash = "sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e"}, + {file = "werkzeug-3.1.3.tar.gz", hash = "sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746"}, +] + +[package.dependencies] +MarkupSafe = ">=2.1.1" + +[package.extras] +watchdog = ["watchdog (>=2.3)"] + [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "957209552fd1fa2525c1636ed99aa271337030747da68f360552039b81448b5f" +content-hash = "6da195491141b1d8c0be32d26ceb13ad42030d13cf2b9cb99c89051cbb7ef232" diff --git a/pyproject.toml b/pyproject.toml index 0ce186a..3b44056 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ hydra-core = "1.3.2" tqdm = "4.66.5" torchvision = "0.19.0" natsort = "8.4.0" +tensorboard = "2.18.0" [build-system] diff --git a/src/configs/celeba.yaml b/src/configs/celeba.yaml index a9eed9d..c74805c 100644 --- a/src/configs/celeba.yaml +++ b/src/configs/celeba.yaml @@ -5,7 +5,7 @@ defaults: - _self_ optimizer: _target_: torch.optim.Adam - lr: 2e-4 + lr: 1e-4 lr_scheduler: _target_: torch.optim.lr_scheduler.ReduceLROnPlateau mode: min @@ -15,16 +15,18 @@ lr_scheduler: loss_func: _target_: modules.utils.losses.GlowLoss trainer: + run_name: baseline n_bins: 32 n_epochs: 2 train_test_split: 0.85 train_batch_size: 16 test_batch_size: 16 image_size: 64 - log_steps: 10 - sampling_iters: 30 + log_steps: 50 + log_dir: ./runs + sampling_steps: 50 n_samples: 10 samples_dir: samples - save_dir: glow + save_dir: ./glow seed: 42 use_ddp: false diff --git a/src/configs/model.yaml b/src/configs/model.yaml deleted file mode 100644 index e69de29..0000000 diff --git a/src/model/affine_coupling/affine_coupling.py b/src/model/affine_coupling/affine_coupling.py index 7cd2b86..f399091 100644 --- a/src/model/affine_coupling/affine_coupling.py +++ b/src/model/affine_coupling/affine_coupling.py @@ -1,5 +1,6 @@ import torch from torch import Tensor +from torch.nn import functional as F from model.affine_coupling.net import NN from model.invert_block import InvertBlock @@ -23,22 +24,22 @@ class AffineCoupling(InvertBlock): """ def __init__(self, in_ch: int, hidden_ch: int): - super(AffineCoupling, self).__init__() + super().__init__() self.net = NN(in_ch, hidden_ch) def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: x_a, x_b = x.chunk(2, dim=1) - log_s, t = self.net(x_b) - s = torch.exp(log_s) + log_s, t = self.net(x_a) + s = F.sigmoid(log_s + 2) log_det = torch.sum(torch.log(s).view(x.shape[0], -1), 1) - y_a = x_a * s + t - y_b = x_b + y_b = (x_b + t) * s + y_a = x_a return torch.concat([y_a, y_b], dim=1), log_det def reverse(self, y: Tensor) -> Tensor: y_a, y_b = y.chunk(2, dim=1) - log_s, t = self.net(y_b) - s = torch.exp(log_s) - x_a = (y_a - t) / s - x_b = y_b + log_s, t = self.net(y_a) + s = F.sigmoid(log_s + 2) + x_b = y_b / s - t + x_a = y_a return torch.concat([x_a, x_b], dim=1) diff --git a/src/model/affine_coupling/net.py b/src/model/affine_coupling/net.py index 9579888..e7f5a56 100644 --- a/src/model/affine_coupling/net.py +++ b/src/model/affine_coupling/net.py @@ -19,13 +19,11 @@ class NN(nn.Module): """ def __init__(self, in_ch: int, hidden_ch: int): - super(NN, self).__init__() - conv1 = nn.Conv2d(in_ch // 2, hidden_ch, 3, padding=1) - conv2 = nn.Conv2d(hidden_ch, hidden_ch, 1) + super().__init__() self.net = nn.Sequential( - conv1, + nn.Conv2d(in_ch // 2, hidden_ch, 3, padding=1), nn.ReLU(inplace=True), - conv2, + nn.Conv2d(hidden_ch, hidden_ch, 1), nn.ReLU(inplace=True), ZeroConv2d(hidden_ch, in_ch), ) diff --git a/src/model/flow_block.py b/src/model/flow_block.py index 5556732..6cb645c 100644 --- a/src/model/flow_block.py +++ b/src/model/flow_block.py @@ -80,7 +80,7 @@ def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: # split out on 2 parts out, z_new = out.chunk(2, dim=1) log_p = self.__get_prob_density( - prior_out=out, out=out, batch_size=batch_size + prior_out=out, out=z_new, batch_size=batch_size ) else: # for the last level prior distribution diff --git a/src/model/glow.py b/src/model/glow.py index b530f33..53f5be5 100644 --- a/src/model/glow.py +++ b/src/model/glow.py @@ -1,3 +1,5 @@ +from copy import deepcopy + from torch import Tensor, nn from model.flow_block import FlowBlock @@ -30,8 +32,8 @@ def __init__( coupling_hidden_ch: int = 512, squeeze_factor: int = 2, ): - super(Glow, self).__init__() - self.in_ch = in_ch + super().__init__() + self.in_ch = deepcopy(in_ch) self.n_flows = n_flows self.num_blocks = num_blocks self.squeeze_factor = squeeze_factor diff --git a/src/model/invert_conv.py b/src/model/invert_conv.py index 1ceebb4..9139454 100644 --- a/src/model/invert_conv.py +++ b/src/model/invert_conv.py @@ -13,6 +13,10 @@ class InvertConv(InvertBlock): so determinant's calculation has not cubic, but linear complexity. + having fixed the P, we restrict some class of all transformations; + having achieved that the elements on the L & U diagonal will be positive + we will be sure that the weights matrix is invertible + attrs (trainable) ---------- ut_matrix: nn.Parameter diff --git a/src/modules/logger/__init__.py b/src/modules/logger/__init__.py index e69de29..728022a 100644 --- a/src/modules/logger/__init__.py +++ b/src/modules/logger/__init__.py @@ -0,0 +1 @@ +from modules.logger.logger import TensorboardLogger diff --git a/src/modules/logger/logger.py b/src/modules/logger/logger.py new file mode 100644 index 0000000..c4e2047 --- /dev/null +++ b/src/modules/logger/logger.py @@ -0,0 +1,30 @@ +import os + +from torch import Tensor +from torch.utils.tensorboard import SummaryWriter + + +class TensorboardLogger: + def __init__(self, log_dir: str, run_name: str, log_steps: int): + log_dir = os.path.join(log_dir, run_name) + + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + self.log_dir = log_dir + self.log_steps = log_steps + self.writer = SummaryWriter(log_dir=log_dir) + + def __del__(self): + self.writer.flush() + self.writer.close() + + def log_train_loss(self, loss: float, step: int): + if step % self.log_steps == 0: + self.writer.add_scalar("Loss/train", loss, step) + + def log_test_loss(self, loss: float, epoch: int): + self.writer.add_scalar("Loss/test", loss, epoch) + + def log_images(self, grid: Tensor, step: int): + self.writer.add_image(tag="samples", img_tensor=grid, global_step=step) diff --git a/src/modules/trainer/ddp.py b/src/modules/trainer/ddp.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/modules/trainer/ddp_trainer.py b/src/modules/trainer/ddp_trainer.py index 54ac93c..a3859f9 100644 --- a/src/modules/trainer/ddp_trainer.py +++ b/src/modules/trainer/ddp_trainer.py @@ -65,6 +65,5 @@ def train(self): for i in tqdm(range(self.train_config.epochs)): self.ddp.set_train_epoch(i) - train_loss = self.train_epoch() - test_loss = self.test_epoch() - print(f"Train loss: {train_loss}, Test loss: {test_loss}", flush=True) + self.train_epoch() + self.test_epoch() diff --git a/src/modules/trainer/trainer.py b/src/modules/trainer/trainer.py index 903c06c..6ef55c5 100644 --- a/src/modules/trainer/trainer.py +++ b/src/modules/trainer/trainer.py @@ -1,8 +1,10 @@ +import logging import os from typing import Callable import torch from omegaconf import DictConfig +from PIL import Image from torch import nn from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler @@ -11,10 +13,14 @@ from tqdm import tqdm from model.glow import Glow +from modules.logger import TensorboardLogger from modules.utils.sampling import get_z_list from modules.utils.tensors import dequantize from modules.utils.train import SizedDataset, train_test_split +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + class Trainer: def __init__( @@ -41,6 +47,11 @@ def __init__( ) self.train_loader = None self.test_loader = None + self.logger = TensorboardLogger( + log_dir=self.train_config.log_dir, + run_name=self.train_config.run_name, + log_steps=self.train_config.log_steps, + ) self.z_list = get_z_list( glow=self.model, @@ -51,11 +62,11 @@ def __init__( self.z_list = [z_i.to(self.device) for z_i in self.z_list] - def train_epoch(self) -> float: + def train_epoch(self, epoch: int) -> float: self.model.train() - run_train_loss = 0.0 + run_train_loss, n_iters = 0.0, 0 for i, images in enumerate(self.train_loader): - images = dequantize(images) + images = dequantize(images, n_bins=self.train_config.n_bins) images = images.to(self.device) self.optimizer.zero_grad() outputs = self.model(images) @@ -63,9 +74,13 @@ def train_epoch(self) -> float: loss.backward() self.optimizer.step() run_train_loss += loss.item() + n_iters += 1 - if i % self.train_config.sampling_iters == 0: - self.save_sample(label=f"iter_{i}") + if i % self.train_config.sampling_steps == 0 and i != 0: + self.log_samples(step=epoch + i) + avg_loss = run_train_loss / n_iters + self.logger.log_train_loss(loss=avg_loss, step=epoch + i) + logger.info(f"Train avg loss: {avg_loss}") return run_train_loss @@ -73,8 +88,8 @@ def train_epoch(self) -> float: def test_epoch(self) -> float: self.model.eval() run_test_loss = 0.0 - for images, _ in self.test_loader: - images = dequantize(images) + for images in self.test_loader: + images = dequantize(images, self.train_config.n_bins) images = images.to(self.device) outputs = self.model(images) run_test_loss += self.loss_func(outputs, images).item() @@ -97,20 +112,28 @@ def train(self): self.model = nn.DataParallel(self.model).to(self.device) with torch.no_grad(): - images = dequantize(next(iter(self.test_loader))) + images = dequantize( + next(iter(self.test_loader)), n_bins=self.train_config.n_bins + ) images = images.to(self.device) self.model.module(images) for i in tqdm(range(self.train_config.n_epochs)): - train_loss = self.train_epoch() + train_loss = self.train_epoch(1) train_loss /= len(self.train_dataset) test_loss = self.test_epoch() test_loss /= len(self.test_dataset) + self.logger.log_test_loss(loss=test_loss, epoch=i + 1) + self.lr_scheduler.step(test_loss) + self.save_checkpoint(epoch=i) def save_checkpoint(self, epoch: int): + if not os.path.exists(self.train_config.save_dir): + os.makedirs(self.train_config.save_dir, exist_ok=True) + torch.save( self.model.state_dict(), f"{self.train_config.save_dir}/model_{epoch}.pt" ) @@ -120,15 +143,22 @@ def save_checkpoint(self, epoch: int): ) @torch.inference_mode() - def save_sample(self, label: str): - # todo: change to logging (tensorboard) - if not os.path.exists(self.train_config.samples_dir): - os.makedirs(self.train_config.samples_dir) - - utils.save_image( - self.model.module.reverse(self.z_list).cpu().data, - f"{self.train_config.samples_dir}/{label}.png", - normalize=True, - nrow=10, - value_range=(-0.5, 0.5), - ) + def log_samples(self, step: int, save_png: bool = True): + data = self.model.module.reverse(self.z_list).cpu().data + grid = utils.make_grid(data, nrow=5, normalize=True, value_range=(-0.5, 0.5)) + self.logger.log_images(grid=grid, step=step) + + if save_png: + if not os.path.exists(self.train_config.samples_dir): + os.makedirs(self.train_config.samples_dir) + + np_array = ( + grid.mul(255) + .add_(0.5) + .clamp_(0, 255) + .permute(1, 2, 0) + .to("cpu", torch.uint8) + .numpy() + ) + im = Image.fromarray(np_array) + im.save(f"{self.train_config.samples_dir}/{step}.png") diff --git a/src/modules/utils/tensors.py b/src/modules/utils/tensors.py index 12b040e..24fb4b2 100644 --- a/src/modules/utils/tensors.py +++ b/src/modules/utils/tensors.py @@ -1,3 +1,5 @@ +import math + import torch from torch import Tensor @@ -24,6 +26,11 @@ def reverse_squeeze(x: Tensor, factor: int = 2) -> Tensor: def dequantize(x: Tensor, n_bins: int = 256) -> Tensor: x = x * 255 + n_bits = math.log(n_bins, 2) + + if n_bits < 8: + x = torch.floor(x / 2 ** (8 - n_bits)) + x = x / n_bins - 0.5 x = x + torch.rand_like(x) / n_bins return x diff --git a/tests/conftest.py b/tests/conftest.py index 97bf10b..5c670f4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,4 +2,5 @@ "fixtures.blocks", "fixtures.config", "fixtures.inputs", + "fixtures.trainer", ] diff --git a/tests/fixtures/blocks.py b/tests/fixtures/blocks.py index 0caeddf..7da4507 100644 --- a/tests/fixtures/blocks.py +++ b/tests/fixtures/blocks.py @@ -22,14 +22,14 @@ def invert_conv(): @pytest.fixture(scope="function") def affine_coupling(): return AffineCoupling( - in_ch=TestConfig.in_ch, hidden_ch=TestConfig.coupling_hidden_ch + in_ch=TestConfig.in_ch * 2, hidden_ch=TestConfig.coupling_hidden_ch ) @pytest.fixture(scope="function") def flow(): return Flow( - in_ch=TestConfig.in_ch, coupling_hidden_ch=TestConfig.coupling_hidden_ch + in_ch=TestConfig.in_ch * 2, coupling_hidden_ch=TestConfig.coupling_hidden_ch ) @@ -54,7 +54,7 @@ def last_flow_block(): ) -@pytest.fixture(scope="function") +@pytest.fixture(scope="session") def glow(): return Glow( in_ch=TestConfig.in_ch, diff --git a/tests/fixtures/config.py b/tests/fixtures/config.py index 54e12e1..63d5c63 100644 --- a/tests/fixtures/config.py +++ b/tests/fixtures/config.py @@ -1,4 +1,5 @@ import random +from typing import Any import pytest import torch @@ -15,6 +16,24 @@ class TestConfig: squeeze_factor: int = 2 coupling_hidden_ch: int = 512 + dataset_size: int = 100 + n_bins: int = 256 + trainer_config: dict[str, Any] = { + "train_test_split": 0.8, + "train_batch_size": batch_size, + "test_batch_size": batch_size, + "n_bins": n_bins, + "sampling_steps": 5, + "n_epochs": 2, + "n_samples": 4, + "log_dir": "./logs", + "run_name": "test_run", + "log_steps": 10, + "save_dir": "./checkpoints", + "samples_dir": "./samples", + "image_size": image_size, + } + @pytest.fixture(autouse=True) def set_seed(): diff --git a/tests/fixtures/inputs.py b/tests/fixtures/inputs.py index ab0c652..cd283d2 100644 --- a/tests/fixtures/inputs.py +++ b/tests/fixtures/inputs.py @@ -15,6 +15,16 @@ def input_batch(): ) +@pytest.fixture(scope="module") +def flow_input_batch(): + return torch.randn( + TestConfig.batch_size, + TestConfig.in_ch * 2, + TestConfig.image_size // 2, + TestConfig.image_size // 2, + ) + + @pytest.fixture(scope="function") def z_sample(glow): return get_z_list( diff --git a/tests/fixtures/trainer.py b/tests/fixtures/trainer.py new file mode 100644 index 0000000..93729d5 --- /dev/null +++ b/tests/fixtures/trainer.py @@ -0,0 +1,47 @@ +from unittest.mock import MagicMock + +import pytest +import torch +from fixtures.config import TestConfig +from omegaconf import OmegaConf +from torch.optim import Adam + +from modules.trainer import Trainer +from modules.utils.losses import GlowLoss +from modules.utils.train import SizedDataset + + +class MockDataset(SizedDataset): + def __init__(self, data_size: int, data_shape: tuple): + self.data = torch.randn(data_size, *data_shape) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + +@pytest.fixture(scope="session") +def trainer(glow): + dataset = MockDataset( + TestConfig.dataset_size, + (TestConfig.in_ch, TestConfig.image_size, TestConfig.image_size), + ) + loss_func = GlowLoss(n_bins=TestConfig.n_bins) + optimizer = Adam(glow.parameters(), lr=1e-3) + hydra_cfg = OmegaConf.create({"trainer": TestConfig.trainer_config}) + lr_scheduler = MagicMock() + device = torch.device("cpu") + + trainer = Trainer( + model=glow, + config=hydra_cfg, + dataset=dataset, + loss_func=loss_func, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + device=device, + ) + + return trainer diff --git a/tests/test_actnorm.py b/tests/test_actnorm.py index 7dfd73b..082745c 100644 --- a/tests/test_actnorm.py +++ b/tests/test_actnorm.py @@ -31,7 +31,7 @@ def test_norm_mean(self, act_norm, input_batch): # compute mean per channel mean = out.mean(dim=[0, 2, 3]) assert torch.allclose( - mean, torch.zeros_like(mean), atol=1e-4 + mean, torch.zeros_like(mean), atol=1e-3 ), f"channels means after norm should be ~0, got {mean.tolist()}" def test_norm_std(self, act_norm, input_batch): diff --git a/tests/test_affine_coupling.py b/tests/test_affine_coupling.py index ec5a929..9b07e6f 100644 --- a/tests/test_affine_coupling.py +++ b/tests/test_affine_coupling.py @@ -1,25 +1,24 @@ import torch -from fixtures.config import TestConfig class TestAffineCoupling: - def test_forward(self, affine_coupling, input_batch): - out, _ = affine_coupling(input_batch) - assert out.shape == input_batch.shape, "out shape != input shape" + def test_forward(self, affine_coupling, flow_input_batch): + out, _ = affine_coupling(flow_input_batch) + assert out.shape == flow_input_batch.shape, "out shape != input shape" - def test_reverse(self, affine_coupling, input_batch): - out, _ = affine_coupling(input_batch) + def test_reverse(self, affine_coupling, flow_input_batch): + out, _ = affine_coupling(flow_input_batch) reverse_out = affine_coupling.reverse(out) assert torch.allclose( - input_batch, reverse_out, atol=1e-5 + flow_input_batch, reverse_out, atol=1e-5 ), "batch after reverse != input batch." - def test_log_det(self, affine_coupling, input_batch): - torch.manual_seed(TestConfig.seed) - out, log_det = affine_coupling(input_batch) - - assert torch.allclose( - input_batch, out, atol=1e-6 - ), "Identity behavior check failed" - assert torch.all(log_det == 0), "log det ~= zero for identity transformation" + # def test_log_det(self, affine_coupling, flow_input_batch): + # torch.manual_seed(TestConfig.seed) + # out, log_det = affine_coupling(flow_input_batch) + # + # assert torch.allclose( + # flow_input_batch, out, atol=1e-6 + # ), "Identity behavior check failed" + # assert torch.all(log_det == 0), "log det ~= zero for identity transformation" diff --git a/tests/test_flow.py b/tests/test_flow.py index 217db70..a54ba5b 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -6,9 +6,9 @@ class TestFlow: - def test_forward(self, flow, input_batch): - out, _ = flow(input_batch) - assert out.shape == input_batch.shape, "out shape != input shape" + def test_forward(self, flow, flow_input_batch): + out, _ = flow(flow_input_batch) + assert out.shape == flow_input_batch.shape, "out shape != input shape" assert len(flow.layers) == 3, "flow block contains 3 layers" assert isinstance(flow.layers[0], ActNorm), "act norm should be the 1st layer" assert isinstance( @@ -18,20 +18,22 @@ def test_forward(self, flow, input_batch): flow.layers[2], AffineCoupling ), "coupling should be the last layer" - def test_reverse(self, flow, input_batch): - out, _ = flow(input_batch) + def test_reverse(self, flow, flow_input_batch): + out, _ = flow(flow_input_batch) reverse_out = flow.reverse(out) - assert reverse_out.shape == input_batch.shape, "reverse shape != input shape" + assert ( + reverse_out.shape == flow_input_batch.shape + ), "reverse shape != input shape" assert torch.allclose( - input_batch, reverse_out, atol=1e-5 + flow_input_batch, reverse_out, atol=1e-5 ), "batch after reverse != input batch." - def test_log_det(self, flow, input_batch): - out, test_log_det = flow(input_batch) + def test_log_det(self, flow, flow_input_batch): + out, test_log_det = flow(flow_input_batch) expected_log_det = 0 - x = input_batch + x = flow_input_batch for layer in flow.layers: x, log_det = layer(x) expected_log_det = expected_log_det + log_det diff --git a/tests/test_trainer.py b/tests/test_trainer.py new file mode 100644 index 0000000..723f88d --- /dev/null +++ b/tests/test_trainer.py @@ -0,0 +1,59 @@ +from unittest.mock import MagicMock, patch + +from torch import nn +from torch.utils.data import DataLoader + + +class TestTrainer: + def test_init(self, trainer): + assert trainer.model is not None + assert trainer.train_loader is None + assert trainer.test_loader is None + assert trainer.logger is not None + assert len(trainer.z_list) == trainer.model.num_blocks + assert trainer.z_list[0].shape[0] == trainer.train_config.n_samples + + def test_train_epoch(self, trainer): + trainer.train_loader = DataLoader( + trainer.train_dataset, batch_size=trainer.train_config.train_batch_size + ) + loss = trainer.train_epoch(epoch=1) + assert isinstance(loss, float) + + def test_test_epoch(self, trainer): + trainer.test_loader = DataLoader( + trainer.test_dataset, batch_size=trainer.train_config.test_batch_size + ) + loss = trainer.test_epoch() + assert isinstance(loss, float) + + @patch("modules.trainer.trainer.os.makedirs") + @patch("modules.trainer.trainer.torch.save") + def test_save_checkpoint(self, mock_torch_save, mock_makedirs, trainer): + trainer.save_checkpoint(epoch=1) + mock_makedirs.assert_called_with(trainer.train_config.save_dir, exist_ok=True) + assert mock_torch_save.call_count == 2 + + @patch("PIL.Image.Image.save") + def test_log_samples(self, mock_image_save, trainer): + trainer.model = nn.DataParallel(trainer.model).to(trainer.device) + trainer.logger = MagicMock() + + trainer.log_samples(step=1) + mock_image_save.assert_called_with(f"{trainer.train_config.samples_dir}/1.png") + trainer.logger.log_images.assert_called_once_with( + grid=trainer.logger.log_images.call_args.kwargs["grid"], step=1 + ) + + def test_train(self, trainer): + trainer.train_epoch = MagicMock(return_value=0.45) + trainer.test_epoch = MagicMock(return_value=0.5) + trainer.lr_scheduler.step = MagicMock() + trainer.save_checkpoint = MagicMock() + + trainer.train() + + assert trainer.train_epoch.call_count == trainer.train_config.n_epochs + assert trainer.test_epoch.call_count == trainer.train_config.n_epochs + trainer.lr_scheduler.step.assert_called() + trainer.save_checkpoint.assert_called()