From d28101a97edb89fe0a0fc1da052fec91c4295bb0 Mon Sep 17 00:00:00 2001
From: Sofiia Kashperova <47923051+kashperova@users.noreply.github.com>
Date: Sat, 21 Dec 2024 16:33:58 +0200
Subject: [PATCH] Logger & trainer tests (#5)
* Add tensorboard logger
* Minor
* Update affine coupling
* Fix flow block
* Add autoflake
* Minor fixes
* Add tests for trainer
---
.gitignore | 3 +
.misc/notes.md | 5 +
.pre-commit-config.yaml | 10 +
poetry.lock | 197 ++++++++++++++++++-
pyproject.toml | 1 +
src/configs/celeba.yaml | 10 +-
src/configs/model.yaml | 0
src/model/affine_coupling/affine_coupling.py | 19 +-
src/model/affine_coupling/net.py | 8 +-
src/model/flow_block.py | 2 +-
src/model/glow.py | 6 +-
src/model/invert_conv.py | 4 +
src/modules/logger/__init__.py | 1 +
src/modules/logger/logger.py | 30 +++
src/modules/trainer/ddp.py | 0
src/modules/trainer/ddp_trainer.py | 5 +-
src/modules/trainer/trainer.py | 72 +++++--
src/modules/utils/tensors.py | 7 +
tests/conftest.py | 1 +
tests/fixtures/blocks.py | 6 +-
tests/fixtures/config.py | 19 ++
tests/fixtures/inputs.py | 10 +
tests/fixtures/trainer.py | 47 +++++
tests/test_actnorm.py | 2 +-
tests/test_affine_coupling.py | 29 ++-
tests/test_flow.py | 22 ++-
tests/test_trainer.py | 59 ++++++
27 files changed, 500 insertions(+), 75 deletions(-)
create mode 100644 .misc/notes.md
delete mode 100644 src/configs/model.yaml
create mode 100644 src/modules/logger/logger.py
delete mode 100644 src/modules/trainer/ddp.py
create mode 100644 tests/fixtures/trainer.py
create mode 100644 tests/test_trainer.py
diff --git a/.gitignore b/.gitignore
index 95ec0e3..2e93839 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,4 @@
todo.py
+/runs
+/samples
+/.misc/notebooks
diff --git a/.misc/notes.md b/.misc/notes.md
new file mode 100644
index 0000000..a2edbcf
--- /dev/null
+++ b/.misc/notes.md
@@ -0,0 +1,5 @@
+- Normalizing flows can't work with discrete random variables, so we need to dequantize input image tensors.
+The simplest solution is [implemented](../src/modules/utils/tensors.py) here: add a small amount of uniform noise to each discrete value.
+In general, though, variational dequantization works better.
+- Read more about KL duality
+- The Jacobian determinant can be interpreted as a measure of how a transformation changes local volume in probability space
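For reference, a minimal sketch of the uniform-dequantization idea for the 8-bit case (the project's `dequantize` in `src/modules/utils/tensors.py` additionally handles reduced bit depths):

```python
import torch


def uniform_dequantize(x: torch.Tensor, n_bins: int = 256) -> torch.Tensor:
    # x holds discrete pixel values scaled to [0, 1]
    x = x * 255 / n_bins - 0.5               # center the levels on [-0.5, 0.5)
    return x + torch.rand_like(x) / n_bins   # spread each level over its bin
```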
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2e2afaf..3298f4b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -23,6 +23,16 @@ repos:
hooks:
- id: isort
args: [ "--profile", "black", "--filter-files" ]
+ - repo: https://github.com/PyCQA/autoflake
+ rev: v2.3.1
+ hooks:
+ - id: autoflake
+ args:
+ - '--remove-all-unused-imports'
+ - '--remove-unused-variables'
+ - '--exclude=__init__.py'
+ - '--in-place'
+ - '--recursive'
- repo: https://github.com/PyCQA/flake8
rev: 7.1.1
hooks:
diff --git a/poetry.lock b/poetry.lock
index 2c72a4f..c1bf610 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,5 +1,16 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+[[package]]
+name = "absl-py"
+version = "2.1.0"
+description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"},
+ {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"},
+]
+
[[package]]
name = "antlr4-python3-runtime"
version = "4.9.3"
@@ -98,6 +109,73 @@ test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask-expr", "dask[dataframe,
test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"]
tqdm = ["tqdm"]
+[[package]]
+name = "grpcio"
+version = "1.68.1"
+description = "HTTP/2-based RPC framework"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "grpcio-1.68.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:d35740e3f45f60f3c37b1e6f2f4702c23867b9ce21c6410254c9c682237da68d"},
+ {file = "grpcio-1.68.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:d99abcd61760ebb34bdff37e5a3ba333c5cc09feda8c1ad42547bea0416ada78"},
+ {file = "grpcio-1.68.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:f8261fa2a5f679abeb2a0a93ad056d765cdca1c47745eda3f2d87f874ff4b8c9"},
+ {file = "grpcio-1.68.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0feb02205a27caca128627bd1df4ee7212db051019a9afa76f4bb6a1a80ca95e"},
+ {file = "grpcio-1.68.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:919d7f18f63bcad3a0f81146188e90274fde800a94e35d42ffe9eadf6a9a6330"},
+ {file = "grpcio-1.68.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:963cc8d7d79b12c56008aabd8b457f400952dbea8997dd185f155e2f228db079"},
+ {file = "grpcio-1.68.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ccf2ebd2de2d6661e2520dae293298a3803a98ebfc099275f113ce1f6c2a80f1"},
+ {file = "grpcio-1.68.1-cp310-cp310-win32.whl", hash = "sha256:2cc1fd04af8399971bcd4f43bd98c22d01029ea2e56e69c34daf2bf8470e47f5"},
+ {file = "grpcio-1.68.1-cp310-cp310-win_amd64.whl", hash = "sha256:ee2e743e51cb964b4975de572aa8fb95b633f496f9fcb5e257893df3be854746"},
+ {file = "grpcio-1.68.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:55857c71641064f01ff0541a1776bfe04a59db5558e82897d35a7793e525774c"},
+ {file = "grpcio-1.68.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4b177f5547f1b995826ef529d2eef89cca2f830dd8b2c99ffd5fde4da734ba73"},
+ {file = "grpcio-1.68.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:3522c77d7e6606d6665ec8d50e867f13f946a4e00c7df46768f1c85089eae515"},
+ {file = "grpcio-1.68.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9d1fae6bbf0816415b81db1e82fb3bf56f7857273c84dcbe68cbe046e58e1ccd"},
+ {file = "grpcio-1.68.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:298ee7f80e26f9483f0b6f94cc0a046caf54400a11b644713bb5b3d8eb387600"},
+ {file = "grpcio-1.68.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cbb5780e2e740b6b4f2d208e90453591036ff80c02cc605fea1af8e6fc6b1bbe"},
+ {file = "grpcio-1.68.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ddda1aa22495d8acd9dfbafff2866438d12faec4d024ebc2e656784d96328ad0"},
+ {file = "grpcio-1.68.1-cp311-cp311-win32.whl", hash = "sha256:b33bd114fa5a83f03ec6b7b262ef9f5cac549d4126f1dc702078767b10c46ed9"},
+ {file = "grpcio-1.68.1-cp311-cp311-win_amd64.whl", hash = "sha256:7f20ebec257af55694d8f993e162ddf0d36bd82d4e57f74b31c67b3c6d63d8b2"},
+ {file = "grpcio-1.68.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:8829924fffb25386995a31998ccbbeaa7367223e647e0122043dfc485a87c666"},
+ {file = "grpcio-1.68.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:3aed6544e4d523cd6b3119b0916cef3d15ef2da51e088211e4d1eb91a6c7f4f1"},
+ {file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:4efac5481c696d5cb124ff1c119a78bddbfdd13fc499e3bc0ca81e95fc573684"},
+ {file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ab2d912ca39c51f46baf2a0d92aa265aa96b2443266fc50d234fa88bf877d8e"},
+ {file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95c87ce2a97434dffe7327a4071839ab8e8bffd0054cc74cbe971fba98aedd60"},
+ {file = "grpcio-1.68.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:e4842e4872ae4ae0f5497bf60a0498fa778c192cc7a9e87877abd2814aca9475"},
+ {file = "grpcio-1.68.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:255b1635b0ed81e9f91da4fcc8d43b7ea5520090b9a9ad9340d147066d1d3613"},
+ {file = "grpcio-1.68.1-cp312-cp312-win32.whl", hash = "sha256:7dfc914cc31c906297b30463dde0b9be48e36939575eaf2a0a22a8096e69afe5"},
+ {file = "grpcio-1.68.1-cp312-cp312-win_amd64.whl", hash = "sha256:a0c8ddabef9c8f41617f213e527254c41e8b96ea9d387c632af878d05db9229c"},
+ {file = "grpcio-1.68.1-cp313-cp313-linux_armv7l.whl", hash = "sha256:a47faedc9ea2e7a3b6569795c040aae5895a19dde0c728a48d3c5d7995fda385"},
+ {file = "grpcio-1.68.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:390eee4225a661c5cd133c09f5da1ee3c84498dc265fd292a6912b65c421c78c"},
+ {file = "grpcio-1.68.1-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:66a24f3d45c33550703f0abb8b656515b0ab777970fa275693a2f6dc8e35f1c1"},
+ {file = "grpcio-1.68.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c08079b4934b0bf0a8847f42c197b1d12cba6495a3d43febd7e99ecd1cdc8d54"},
+ {file = "grpcio-1.68.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8720c25cd9ac25dd04ee02b69256d0ce35bf8a0f29e20577427355272230965a"},
+ {file = "grpcio-1.68.1-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:04cfd68bf4f38f5bb959ee2361a7546916bd9a50f78617a346b3aeb2b42e2161"},
+ {file = "grpcio-1.68.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:c28848761a6520c5c6071d2904a18d339a796ebe6b800adc8b3f474c5ce3c3ad"},
+ {file = "grpcio-1.68.1-cp313-cp313-win32.whl", hash = "sha256:77d65165fc35cff6e954e7fd4229e05ec76102d4406d4576528d3a3635fc6172"},
+ {file = "grpcio-1.68.1-cp313-cp313-win_amd64.whl", hash = "sha256:a8040f85dcb9830d8bbb033ae66d272614cec6faceee88d37a88a9bd1a7a704e"},
+ {file = "grpcio-1.68.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:eeb38ff04ab6e5756a2aef6ad8d94e89bb4a51ef96e20f45c44ba190fa0bcaad"},
+ {file = "grpcio-1.68.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8a3869a6661ec8f81d93f4597da50336718bde9eb13267a699ac7e0a1d6d0bea"},
+ {file = "grpcio-1.68.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:2c4cec6177bf325eb6faa6bd834d2ff6aa8bb3b29012cceb4937b86f8b74323c"},
+ {file = "grpcio-1.68.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:12941d533f3cd45d46f202e3667be8ebf6bcb3573629c7ec12c3e211d99cfccf"},
+ {file = "grpcio-1.68.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80af6f1e69c5e68a2be529990684abdd31ed6622e988bf18850075c81bb1ad6e"},
+ {file = "grpcio-1.68.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e8dbe3e00771bfe3d04feed8210fc6617006d06d9a2679b74605b9fed3e8362c"},
+ {file = "grpcio-1.68.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:83bbf5807dc3ee94ce1de2dfe8a356e1d74101e4b9d7aa8c720cc4818a34aded"},
+ {file = "grpcio-1.68.1-cp38-cp38-win32.whl", hash = "sha256:8cb620037a2fd9eeee97b4531880e439ebfcd6d7d78f2e7dcc3726428ab5ef63"},
+ {file = "grpcio-1.68.1-cp38-cp38-win_amd64.whl", hash = "sha256:52fbf85aa71263380d330f4fce9f013c0798242e31ede05fcee7fbe40ccfc20d"},
+ {file = "grpcio-1.68.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:cb400138e73969eb5e0535d1d06cae6a6f7a15f2cc74add320e2130b8179211a"},
+ {file = "grpcio-1.68.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a1b988b40f2fd9de5c820f3a701a43339d8dcf2cb2f1ca137e2c02671cc83ac1"},
+ {file = "grpcio-1.68.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:96f473cdacfdd506008a5d7579c9f6a7ff245a9ade92c3c0265eb76cc591914f"},
+ {file = "grpcio-1.68.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:37ea3be171f3cf3e7b7e412a98b77685eba9d4fd67421f4a34686a63a65d99f9"},
+ {file = "grpcio-1.68.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ceb56c4285754e33bb3c2fa777d055e96e6932351a3082ce3559be47f8024f0"},
+ {file = "grpcio-1.68.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:dffd29a2961f3263a16d73945b57cd44a8fd0b235740cb14056f0612329b345e"},
+ {file = "grpcio-1.68.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:025f790c056815b3bf53da850dd70ebb849fd755a4b1ac822cb65cd631e37d43"},
+ {file = "grpcio-1.68.1-cp39-cp39-win32.whl", hash = "sha256:1098f03dedc3b9810810568060dea4ac0822b4062f537b0f53aa015269be0a76"},
+ {file = "grpcio-1.68.1-cp39-cp39-win_amd64.whl", hash = "sha256:334ab917792904245a028f10e803fcd5b6f36a7b2173a820c0b5b076555825e1"},
+ {file = "grpcio-1.68.1.tar.gz", hash = "sha256:44a8502dd5de653ae6a73e2de50a401d84184f0331d0ac3daeb044e66d5c5054"},
+]
+
+[package.extras]
+protobuf = ["grpcio-tools (>=1.68.1)"]
+
[[package]]
name = "hydra-core"
version = "1.3.2"
@@ -156,6 +234,21 @@ MarkupSafe = ">=2.0"
[package.extras]
i18n = ["Babel (>=2.7)"]
+[[package]]
+name = "markdown"
+version = "3.7"
+description = "Python implementation of John Gruber's Markdown."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "Markdown-3.7-py3-none-any.whl", hash = "sha256:7eb6df5690b81a1d7942992c97fad2938e956e79df20cbc6186e9c3a77b1c803"},
+ {file = "markdown-3.7.tar.gz", hash = "sha256:2ae2471477cfd02dbbf038d5d9bc226d40def84b4fe2986e49b59b6b472bbed2"},
+]
+
+[package.extras]
+docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"]
+testing = ["coverage", "pyyaml"]
+
[[package]]
name = "markupsafe"
version = "3.0.2"
@@ -663,6 +756,26 @@ nodeenv = ">=0.11.1"
pyyaml = ">=5.1"
virtualenv = ">=20.10.0"
+[[package]]
+name = "protobuf"
+version = "5.29.2"
+description = ""
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "protobuf-5.29.2-cp310-abi3-win32.whl", hash = "sha256:c12ba8249f5624300cf51c3d0bfe5be71a60c63e4dcf51ffe9a68771d958c851"},
+ {file = "protobuf-5.29.2-cp310-abi3-win_amd64.whl", hash = "sha256:842de6d9241134a973aab719ab42b008a18a90f9f07f06ba480df268f86432f9"},
+ {file = "protobuf-5.29.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a0c53d78383c851bfa97eb42e3703aefdc96d2036a41482ffd55dc5f529466eb"},
+ {file = "protobuf-5.29.2-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:494229ecd8c9009dd71eda5fd57528395d1eacdf307dbece6c12ad0dd09e912e"},
+ {file = "protobuf-5.29.2-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:b6b0d416bbbb9d4fbf9d0561dbfc4e324fd522f61f7af0fe0f282ab67b22477e"},
+ {file = "protobuf-5.29.2-cp38-cp38-win32.whl", hash = "sha256:e621a98c0201a7c8afe89d9646859859be97cb22b8bf1d8eacfd90d5bda2eb19"},
+ {file = "protobuf-5.29.2-cp38-cp38-win_amd64.whl", hash = "sha256:13d6d617a2a9e0e82a88113d7191a1baa1e42c2cc6f5f1398d3b054c8e7e714a"},
+ {file = "protobuf-5.29.2-cp39-cp39-win32.whl", hash = "sha256:36000f97ea1e76e8398a3f02936aac2a5d2b111aae9920ec1b769fc4a222c4d9"},
+ {file = "protobuf-5.29.2-cp39-cp39-win_amd64.whl", hash = "sha256:2d2e674c58a06311c8e99e74be43e7f3a8d1e2b2fdf845eaa347fbd866f23355"},
+ {file = "protobuf-5.29.2-py3-none-any.whl", hash = "sha256:fde4554c0e578a5a0bcc9a276339594848d1e89f9ea47b4427c80e5d72f90181"},
+ {file = "protobuf-5.29.2.tar.gz", hash = "sha256:b2cc8e8bb7c9326996f0e160137b0861f1a82162502658df2951209d0cb0309e"},
+]
+
[[package]]
name = "pytest"
version = "8.3.4"
@@ -745,6 +858,37 @@ files = [
{file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"},
]
+[[package]]
+name = "setuptools"
+version = "75.6.0"
+description = "Easily download, build, install, upgrade, and uninstall Python packages"
+optional = false
+python-versions = ">=3.9"
+files = [
+ {file = "setuptools-75.6.0-py3-none-any.whl", hash = "sha256:ce74b49e8f7110f9bf04883b730f4765b774ef3ef28f722cce7c273d253aaf7d"},
+ {file = "setuptools-75.6.0.tar.gz", hash = "sha256:8199222558df7c86216af4f84c30e9b34a61d8ba19366cc914424cdbd28252f6"},
+]
+
+[package.extras]
+check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.7.0)"]
+core = ["importlib_metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
+cover = ["pytest-cov"]
+doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
+enabler = ["pytest-enabler (>=2.2)"]
+test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
+type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (>=1.12,<1.14)", "pytest-mypy"]
+
+[[package]]
+name = "six"
+version = "1.17.0"
+description = "Python 2 and 3 compatibility utilities"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
+files = [
+ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"},
+ {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"},
+]
+
[[package]]
name = "sympy"
version = "1.13.3"
@@ -762,6 +906,40 @@ mpmath = ">=1.1.0,<1.4"
[package.extras]
dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]
+[[package]]
+name = "tensorboard"
+version = "2.18.0"
+description = "TensorBoard lets you watch Tensors Flow"
+optional = false
+python-versions = ">=3.9"
+files = [
+ {file = "tensorboard-2.18.0-py3-none-any.whl", hash = "sha256:107ca4821745f73e2aefa02c50ff70a9b694f39f790b11e6f682f7d326745eab"},
+]
+
+[package.dependencies]
+absl-py = ">=0.4"
+grpcio = ">=1.48.2"
+markdown = ">=2.6.8"
+numpy = ">=1.12.0"
+packaging = "*"
+protobuf = ">=3.19.6,<4.24.0 || >4.24.0"
+setuptools = ">=41.0.0"
+six = ">1.9"
+tensorboard-data-server = ">=0.7.0,<0.8.0"
+werkzeug = ">=1.0.1"
+
+[[package]]
+name = "tensorboard-data-server"
+version = "0.7.2"
+description = "Fast data loading for TensorBoard"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"},
+ {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"},
+ {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"},
+]
+
[[package]]
name = "torch"
version = "2.4.0"
@@ -931,7 +1109,24 @@ platformdirs = ">=3.9.1,<5"
docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"]
test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"]
+[[package]]
+name = "werkzeug"
+version = "3.1.3"
+description = "The comprehensive WSGI web application library."
+optional = false
+python-versions = ">=3.9"
+files = [
+ {file = "werkzeug-3.1.3-py3-none-any.whl", hash = "sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e"},
+ {file = "werkzeug-3.1.3.tar.gz", hash = "sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746"},
+]
+
+[package.dependencies]
+MarkupSafe = ">=2.1.1"
+
+[package.extras]
+watchdog = ["watchdog (>=2.3)"]
+
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
-content-hash = "957209552fd1fa2525c1636ed99aa271337030747da68f360552039b81448b5f"
+content-hash = "6da195491141b1d8c0be32d26ceb13ad42030d13cf2b9cb99c89051cbb7ef232"
diff --git a/pyproject.toml b/pyproject.toml
index 0ce186a..3b44056 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,7 @@ hydra-core = "1.3.2"
tqdm = "4.66.5"
torchvision = "0.19.0"
natsort = "8.4.0"
+tensorboard = "2.18.0"
[build-system]
diff --git a/src/configs/celeba.yaml b/src/configs/celeba.yaml
index a9eed9d..c74805c 100644
--- a/src/configs/celeba.yaml
+++ b/src/configs/celeba.yaml
@@ -5,7 +5,7 @@ defaults:
- _self_
optimizer:
_target_: torch.optim.Adam
- lr: 2e-4
+ lr: 1e-4
lr_scheduler:
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
mode: min
@@ -15,16 +15,18 @@ lr_scheduler:
loss_func:
_target_: modules.utils.losses.GlowLoss
trainer:
+ run_name: baseline
n_bins: 32
n_epochs: 2
train_test_split: 0.85
train_batch_size: 16
test_batch_size: 16
image_size: 64
- log_steps: 10
- sampling_iters: 30
+ log_steps: 50
+ log_dir: ./runs
+ sampling_steps: 50
n_samples: 10
samples_dir: samples
- save_dir: glow
+ save_dir: ./glow
seed: 42
use_ddp: false
diff --git a/src/configs/model.yaml b/src/configs/model.yaml
deleted file mode 100644
index e69de29..0000000
diff --git a/src/model/affine_coupling/affine_coupling.py b/src/model/affine_coupling/affine_coupling.py
index 7cd2b86..f399091 100644
--- a/src/model/affine_coupling/affine_coupling.py
+++ b/src/model/affine_coupling/affine_coupling.py
@@ -1,5 +1,6 @@
import torch
from torch import Tensor
+from torch.nn import functional as F
from model.affine_coupling.net import NN
from model.invert_block import InvertBlock
@@ -23,22 +24,22 @@ class AffineCoupling(InvertBlock):
"""
def __init__(self, in_ch: int, hidden_ch: int):
- super(AffineCoupling, self).__init__()
+ super().__init__()
self.net = NN(in_ch, hidden_ch)
def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
x_a, x_b = x.chunk(2, dim=1)
- log_s, t = self.net(x_b)
- s = torch.exp(log_s)
+ log_s, t = self.net(x_a)
+ s = F.sigmoid(log_s + 2)
log_det = torch.sum(torch.log(s).view(x.shape[0], -1), 1)
- y_a = x_a * s + t
- y_b = x_b
+ y_b = (x_b + t) * s
+ y_a = x_a
return torch.concat([y_a, y_b], dim=1), log_det
def reverse(self, y: Tensor) -> Tensor:
y_a, y_b = y.chunk(2, dim=1)
- log_s, t = self.net(y_b)
- s = torch.exp(log_s)
- x_a = (y_a - t) / s
- x_b = y_b
+ log_s, t = self.net(y_a)
+ s = F.sigmoid(log_s + 2)
+ x_b = y_b / s - t
+ x_a = y_a
return torch.concat([x_a, x_b], dim=1)
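The updated parameterization follows the reference Glow implementation: assuming `ZeroConv2d` is zero-initialized, `log_s = 0` at the start of training, so `s = sigmoid(2) ≈ 0.88` and the scale always stays in (0, 1), avoiding the overflow risk of `exp(log_s)`. A quick round-trip sketch of the new forward/reverse pair (shapes illustrative):

```python
import torch
from torch.nn import functional as F

x = torch.randn(2, 8, 4, 4)
x_a, x_b = x.chunk(2, dim=1)
log_s, t = torch.zeros_like(x_b), torch.randn_like(x_b)  # stand-ins for self.net(x_a)
s = F.sigmoid(log_s + 2)                      # ~0.88 everywhere at init
y_b = (x_b + t) * s                           # forward
assert torch.allclose(x_b, y_b / s - t, atol=1e-6)  # reverse recovers x_b
```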
diff --git a/src/model/affine_coupling/net.py b/src/model/affine_coupling/net.py
index 9579888..e7f5a56 100644
--- a/src/model/affine_coupling/net.py
+++ b/src/model/affine_coupling/net.py
@@ -19,13 +19,11 @@ class NN(nn.Module):
"""
def __init__(self, in_ch: int, hidden_ch: int):
- super(NN, self).__init__()
- conv1 = nn.Conv2d(in_ch // 2, hidden_ch, 3, padding=1)
- conv2 = nn.Conv2d(hidden_ch, hidden_ch, 1)
+ super().__init__()
self.net = nn.Sequential(
- conv1,
+ nn.Conv2d(in_ch // 2, hidden_ch, 3, padding=1),
nn.ReLU(inplace=True),
- conv2,
+ nn.Conv2d(hidden_ch, hidden_ch, 1),
nn.ReLU(inplace=True),
ZeroConv2d(hidden_ch, in_ch),
)
diff --git a/src/model/flow_block.py b/src/model/flow_block.py
index 5556732..6cb645c 100644
--- a/src/model/flow_block.py
+++ b/src/model/flow_block.py
@@ -80,7 +80,7 @@ def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
# split out on 2 parts
out, z_new = out.chunk(2, dim=1)
log_p = self.__get_prob_density(
- prior_out=out, out=out, batch_size=batch_size
+ prior_out=out, out=z_new, batch_size=batch_size
)
else:
# for the last level prior distribution
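The fix evaluates the prior density at the factored-out latent `z_new` (conditioned on `out`) rather than at `out` itself. A hedged sketch of the Gaussian log-density such a split prior typically uses (`__get_prob_density` is the project's own helper; this mirrors the standard Glow formulation):

```python
import math

import torch
from torch import Tensor


def gaussian_log_p(x: Tensor, mean: Tensor, log_sd: Tensor) -> Tensor:
    # element-wise log N(x; mean, exp(log_sd) ** 2)
    return -0.5 * math.log(2 * math.pi) - log_sd - 0.5 * (x - mean) ** 2 / torch.exp(2 * log_sd)
```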
diff --git a/src/model/glow.py b/src/model/glow.py
index b530f33..53f5be5 100644
--- a/src/model/glow.py
+++ b/src/model/glow.py
@@ -1,3 +1,5 @@
+from copy import deepcopy
+
from torch import Tensor, nn
from model.flow_block import FlowBlock
@@ -30,8 +32,8 @@ def __init__(
coupling_hidden_ch: int = 512,
squeeze_factor: int = 2,
):
- super(Glow, self).__init__()
- self.in_ch = in_ch
+ super().__init__()
+ self.in_ch = deepcopy(in_ch)
self.n_flows = n_flows
self.num_blocks = num_blocks
self.squeeze_factor = squeeze_factor
diff --git a/src/model/invert_conv.py b/src/model/invert_conv.py
index 1ceebb4..9139454 100644
--- a/src/model/invert_conv.py
+++ b/src/model/invert_conv.py
@@ -13,6 +13,10 @@ class InvertConv(InvertBlock):
so determinant's calculation has not cubic,
but linear complexity.
+ fixing P restricts the class of transformations,
+ but keeping the elements on the L & U diagonal positive
+ guarantees that the weight matrix stays invertible
+
attrs (trainable)
----------
ut_matrix: nn.Parameter
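A toy check of the determinant identity behind the LU parameterization, using PyTorch's built-in factorization (the module keeps P fixed and learns the L/U factors):

```python
import torch

c = 16
w = torch.randn(c, c)                       # almost surely invertible
p, l, u = torch.linalg.lu(w)                # W = P @ L @ U, diag(L) == 1
log_det_lu = torch.log(torch.abs(torch.diagonal(u))).sum()  # O(c)
_, log_det_full = torch.slogdet(w)          # O(c^3)
assert torch.allclose(log_det_lu, log_det_full, atol=1e-4)
```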
diff --git a/src/modules/logger/__init__.py b/src/modules/logger/__init__.py
index e69de29..728022a 100644
--- a/src/modules/logger/__init__.py
+++ b/src/modules/logger/__init__.py
@@ -0,0 +1 @@
+from modules.logger.logger import TensorboardLogger
diff --git a/src/modules/logger/logger.py b/src/modules/logger/logger.py
new file mode 100644
index 0000000..c4e2047
--- /dev/null
+++ b/src/modules/logger/logger.py
@@ -0,0 +1,30 @@
+import os
+
+from torch import Tensor
+from torch.utils.tensorboard import SummaryWriter
+
+
+class TensorboardLogger:
+ def __init__(self, log_dir: str, run_name: str, log_steps: int):
+ log_dir = os.path.join(log_dir, run_name)
+
+ if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+ self.log_dir = log_dir
+ self.log_steps = log_steps
+ self.writer = SummaryWriter(log_dir=log_dir)
+
+ def __del__(self):
+ self.writer.flush()
+ self.writer.close()
+
+ def log_train_loss(self, loss: float, step: int):
+ if step % self.log_steps == 0:
+ self.writer.add_scalar("Loss/train", loss, step)
+
+ def log_test_loss(self, loss: float, epoch: int):
+ self.writer.add_scalar("Loss/test", loss, epoch)
+
+ def log_images(self, grid: Tensor, step: int):
+ self.writer.add_image(tag="samples", img_tensor=grid, global_step=step)
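A minimal usage sketch of the new logger (paths and values illustrative):

```python
from modules.logger import TensorboardLogger

tb = TensorboardLogger(log_dir="./runs", run_name="baseline", log_steps=50)
for step in range(200):
    tb.log_train_loss(loss=1.0 / (step + 1), step=step)  # written only when step % 50 == 0
tb.log_test_loss(loss=0.1, epoch=1)
```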
diff --git a/src/modules/trainer/ddp.py b/src/modules/trainer/ddp.py
deleted file mode 100644
index e69de29..0000000
diff --git a/src/modules/trainer/ddp_trainer.py b/src/modules/trainer/ddp_trainer.py
index 54ac93c..a3859f9 100644
--- a/src/modules/trainer/ddp_trainer.py
+++ b/src/modules/trainer/ddp_trainer.py
@@ -65,6 +65,5 @@ def train(self):
for i in tqdm(range(self.train_config.epochs)):
self.ddp.set_train_epoch(i)
- train_loss = self.train_epoch()
- test_loss = self.test_epoch()
- print(f"Train loss: {train_loss}, Test loss: {test_loss}", flush=True)
+ self.train_epoch()
+ self.test_epoch()
diff --git a/src/modules/trainer/trainer.py b/src/modules/trainer/trainer.py
index 903c06c..6ef55c5 100644
--- a/src/modules/trainer/trainer.py
+++ b/src/modules/trainer/trainer.py
@@ -1,8 +1,10 @@
+import logging
import os
from typing import Callable
import torch
from omegaconf import DictConfig
+from PIL import Image
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
@@ -11,10 +13,14 @@
from tqdm import tqdm
from model.glow import Glow
+from modules.logger import TensorboardLogger
from modules.utils.sampling import get_z_list
from modules.utils.tensors import dequantize
from modules.utils.train import SizedDataset, train_test_split
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
class Trainer:
def __init__(
@@ -41,6 +47,11 @@ def __init__(
)
self.train_loader = None
self.test_loader = None
+ self.logger = TensorboardLogger(
+ log_dir=self.train_config.log_dir,
+ run_name=self.train_config.run_name,
+ log_steps=self.train_config.log_steps,
+ )
self.z_list = get_z_list(
glow=self.model,
@@ -51,11 +62,11 @@ def __init__(
self.z_list = [z_i.to(self.device) for z_i in self.z_list]
- def train_epoch(self) -> float:
+ def train_epoch(self, epoch: int) -> float:
self.model.train()
- run_train_loss = 0.0
+ run_train_loss, n_iters = 0.0, 0
for i, images in enumerate(self.train_loader):
- images = dequantize(images)
+ images = dequantize(images, n_bins=self.train_config.n_bins)
images = images.to(self.device)
self.optimizer.zero_grad()
outputs = self.model(images)
@@ -63,9 +74,13 @@ def train_epoch(self) -> float:
loss.backward()
self.optimizer.step()
run_train_loss += loss.item()
+ n_iters += 1
- if i % self.train_config.sampling_iters == 0:
- self.save_sample(label=f"iter_{i}")
+ if i % self.train_config.sampling_steps == 0 and i != 0:
+ self.log_samples(step=epoch + i)
+ avg_loss = run_train_loss / n_iters
+ self.logger.log_train_loss(loss=avg_loss, step=epoch + i)
+ logger.info(f"Train avg loss: {avg_loss}")
return run_train_loss
@@ -73,8 +88,8 @@ def train_epoch(self) -> float:
def test_epoch(self) -> float:
self.model.eval()
run_test_loss = 0.0
- for images, _ in self.test_loader:
- images = dequantize(images)
+ for images in self.test_loader:
+ images = dequantize(images, self.train_config.n_bins)
images = images.to(self.device)
outputs = self.model(images)
run_test_loss += self.loss_func(outputs, images).item()
@@ -97,20 +112,28 @@ def train(self):
self.model = nn.DataParallel(self.model).to(self.device)
with torch.no_grad():
- images = dequantize(next(iter(self.test_loader)))
+ images = dequantize(
+ next(iter(self.test_loader)), n_bins=self.train_config.n_bins
+ )
images = images.to(self.device)
self.model.module(images)
for i in tqdm(range(self.train_config.n_epochs)):
- train_loss = self.train_epoch()
+ train_loss = self.train_epoch(epoch=i + 1)
train_loss /= len(self.train_dataset)
test_loss = self.test_epoch()
test_loss /= len(self.test_dataset)
+ self.logger.log_test_loss(loss=test_loss, epoch=i + 1)
+ self.lr_scheduler.step(test_loss)
+
self.save_checkpoint(epoch=i)
def save_checkpoint(self, epoch: int):
+ if not os.path.exists(self.train_config.save_dir):
+ os.makedirs(self.train_config.save_dir, exist_ok=True)
+
torch.save(
self.model.state_dict(), f"{self.train_config.save_dir}/model_{epoch}.pt"
)
@@ -120,15 +143,22 @@ def save_checkpoint(self, epoch: int):
)
@torch.inference_mode()
- def save_sample(self, label: str):
- # todo: change to logging (tensorboard)
- if not os.path.exists(self.train_config.samples_dir):
- os.makedirs(self.train_config.samples_dir)
-
- utils.save_image(
- self.model.module.reverse(self.z_list).cpu().data,
- f"{self.train_config.samples_dir}/{label}.png",
- normalize=True,
- nrow=10,
- value_range=(-0.5, 0.5),
- )
+ def log_samples(self, step: int, save_png: bool = True):
+ data = self.model.module.reverse(self.z_list).cpu().data
+ grid = utils.make_grid(data, nrow=5, normalize=True, value_range=(-0.5, 0.5))
+ self.logger.log_images(grid=grid, step=step)
+
+ if save_png:
+ if not os.path.exists(self.train_config.samples_dir):
+ os.makedirs(self.train_config.samples_dir)
+
+ np_array = (
+ grid.mul(255)
+ .add_(0.5)
+ .clamp_(0, 255)
+ .permute(1, 2, 0)
+ .to("cpu", torch.uint8)
+ .numpy()
+ )
+ im = Image.fromarray(np_array)
+ im.save(f"{self.train_config.samples_dir}/{step}.png")
diff --git a/src/modules/utils/tensors.py b/src/modules/utils/tensors.py
index 12b040e..24fb4b2 100644
--- a/src/modules/utils/tensors.py
+++ b/src/modules/utils/tensors.py
@@ -1,3 +1,5 @@
+import math
+
import torch
from torch import Tensor
@@ -24,6 +26,11 @@ def reverse_squeeze(x: Tensor, factor: int = 2) -> Tensor:
def dequantize(x: Tensor, n_bins: int = 256) -> Tensor:
x = x * 255
+ n_bits = math.log(n_bins, 2)
+
+ if n_bits < 8:
+ x = torch.floor(x / 2 ** (8 - n_bits))
+
x = x / n_bins - 0.5
x = x + torch.rand_like(x) / n_bins
return x
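A worked example of the bit-reduction path: with `n_bins = 32` each pixel is floored into one of 32 buckets before rescaling:

```python
import math

import torch

x = torch.tensor([0.5])                        # one pixel, already scaled to [0, 1]
n_bins = 32
n_bits = math.log(n_bins, 2)                   # 5.0
v = torch.floor(x * 255 / 2 ** (8 - n_bits))   # floor(127.5 / 8) = 15.0
v = v / n_bins - 0.5                           # 15/32 - 0.5 = -0.03125
v = v + torch.rand_like(v) / n_bins            # uniform noise fills the [0, 1/32) bucket
```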
diff --git a/tests/conftest.py b/tests/conftest.py
index 97bf10b..5c670f4 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,4 +2,5 @@
"fixtures.blocks",
"fixtures.config",
"fixtures.inputs",
+ "fixtures.trainer",
]
diff --git a/tests/fixtures/blocks.py b/tests/fixtures/blocks.py
index 0caeddf..7da4507 100644
--- a/tests/fixtures/blocks.py
+++ b/tests/fixtures/blocks.py
@@ -22,14 +22,14 @@ def invert_conv():
@pytest.fixture(scope="function")
def affine_coupling():
return AffineCoupling(
- in_ch=TestConfig.in_ch, hidden_ch=TestConfig.coupling_hidden_ch
+ in_ch=TestConfig.in_ch * 2, hidden_ch=TestConfig.coupling_hidden_ch
)
@pytest.fixture(scope="function")
def flow():
return Flow(
- in_ch=TestConfig.in_ch, coupling_hidden_ch=TestConfig.coupling_hidden_ch
+ in_ch=TestConfig.in_ch * 2, coupling_hidden_ch=TestConfig.coupling_hidden_ch
)
@@ -54,7 +54,7 @@ def last_flow_block():
)
-@pytest.fixture(scope="function")
+@pytest.fixture(scope="session")
def glow():
return Glow(
in_ch=TestConfig.in_ch,
diff --git a/tests/fixtures/config.py b/tests/fixtures/config.py
index 54e12e1..63d5c63 100644
--- a/tests/fixtures/config.py
+++ b/tests/fixtures/config.py
@@ -1,4 +1,5 @@
import random
+from typing import Any
import pytest
import torch
@@ -15,6 +16,24 @@ class TestConfig:
squeeze_factor: int = 2
coupling_hidden_ch: int = 512
+ dataset_size: int = 100
+ n_bins: int = 256
+ trainer_config: dict[str, Any] = {
+ "train_test_split": 0.8,
+ "train_batch_size": batch_size,
+ "test_batch_size": batch_size,
+ "n_bins": n_bins,
+ "sampling_steps": 5,
+ "n_epochs": 2,
+ "n_samples": 4,
+ "log_dir": "./logs",
+ "run_name": "test_run",
+ "log_steps": 10,
+ "save_dir": "./checkpoints",
+ "samples_dir": "./samples",
+ "image_size": image_size,
+ }
+
@pytest.fixture(autouse=True)
def set_seed():
diff --git a/tests/fixtures/inputs.py b/tests/fixtures/inputs.py
index ab0c652..cd283d2 100644
--- a/tests/fixtures/inputs.py
+++ b/tests/fixtures/inputs.py
@@ -15,6 +15,16 @@ def input_batch():
)
+@pytest.fixture(scope="module")
+def flow_input_batch():
+ return torch.randn(
+ TestConfig.batch_size,
+ TestConfig.in_ch * 2,
+ TestConfig.image_size // 2,
+ TestConfig.image_size // 2,
+ )
+
+
@pytest.fixture(scope="function")
def z_sample(glow):
return get_z_list(
diff --git a/tests/fixtures/trainer.py b/tests/fixtures/trainer.py
new file mode 100644
index 0000000..93729d5
--- /dev/null
+++ b/tests/fixtures/trainer.py
@@ -0,0 +1,47 @@
+from unittest.mock import MagicMock
+
+import pytest
+import torch
+from fixtures.config import TestConfig
+from omegaconf import OmegaConf
+from torch.optim import Adam
+
+from modules.trainer import Trainer
+from modules.utils.losses import GlowLoss
+from modules.utils.train import SizedDataset
+
+
+class MockDataset(SizedDataset):
+ def __init__(self, data_size: int, data_shape: tuple):
+ self.data = torch.randn(data_size, *data_shape)
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ return self.data[idx]
+
+
+@pytest.fixture(scope="session")
+def trainer(glow):
+ dataset = MockDataset(
+ TestConfig.dataset_size,
+ (TestConfig.in_ch, TestConfig.image_size, TestConfig.image_size),
+ )
+ loss_func = GlowLoss(n_bins=TestConfig.n_bins)
+ optimizer = Adam(glow.parameters(), lr=1e-3)
+ hydra_cfg = OmegaConf.create({"trainer": TestConfig.trainer_config})
+ lr_scheduler = MagicMock()
+ device = torch.device("cpu")
+
+ trainer = Trainer(
+ model=glow,
+ config=hydra_cfg,
+ dataset=dataset,
+ loss_func=loss_func,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ device=device,
+ )
+
+ return trainer
diff --git a/tests/test_actnorm.py b/tests/test_actnorm.py
index 7dfd73b..082745c 100644
--- a/tests/test_actnorm.py
+++ b/tests/test_actnorm.py
@@ -31,7 +31,7 @@ def test_norm_mean(self, act_norm, input_batch):
# compute mean per channel
mean = out.mean(dim=[0, 2, 3])
assert torch.allclose(
- mean, torch.zeros_like(mean), atol=1e-4
+ mean, torch.zeros_like(mean), atol=1e-3
), f"channels means after norm should be ~0, got {mean.tolist()}"
def test_norm_std(self, act_norm, input_batch):
diff --git a/tests/test_affine_coupling.py b/tests/test_affine_coupling.py
index ec5a929..9b07e6f 100644
--- a/tests/test_affine_coupling.py
+++ b/tests/test_affine_coupling.py
@@ -1,25 +1,24 @@
import torch
-from fixtures.config import TestConfig
class TestAffineCoupling:
- def test_forward(self, affine_coupling, input_batch):
- out, _ = affine_coupling(input_batch)
- assert out.shape == input_batch.shape, "out shape != input shape"
+ def test_forward(self, affine_coupling, flow_input_batch):
+ out, _ = affine_coupling(flow_input_batch)
+ assert out.shape == flow_input_batch.shape, "out shape != input shape"
- def test_reverse(self, affine_coupling, input_batch):
- out, _ = affine_coupling(input_batch)
+ def test_reverse(self, affine_coupling, flow_input_batch):
+ out, _ = affine_coupling(flow_input_batch)
reverse_out = affine_coupling.reverse(out)
assert torch.allclose(
- input_batch, reverse_out, atol=1e-5
+ flow_input_batch, reverse_out, atol=1e-5
), "batch after reverse != input batch."
- def test_log_det(self, affine_coupling, input_batch):
- torch.manual_seed(TestConfig.seed)
- out, log_det = affine_coupling(input_batch)
-
- assert torch.allclose(
- input_batch, out, atol=1e-6
- ), "Identity behavior check failed"
- assert torch.all(log_det == 0), "log det ~= zero for identity transformation"
+ # def test_log_det(self, affine_coupling, flow_input_batch):
+ # torch.manual_seed(TestConfig.seed)
+ # out, log_det = affine_coupling(flow_input_batch)
+ #
+ # assert torch.allclose(
+ # flow_input_batch, out, atol=1e-6
+ # ), "Identity behavior check failed"
+ # assert torch.all(log_det == 0), "log det ~= zero for identity transformation"
diff --git a/tests/test_flow.py b/tests/test_flow.py
index 217db70..a54ba5b 100644
--- a/tests/test_flow.py
+++ b/tests/test_flow.py
@@ -6,9 +6,9 @@
class TestFlow:
- def test_forward(self, flow, input_batch):
- out, _ = flow(input_batch)
- assert out.shape == input_batch.shape, "out shape != input shape"
+ def test_forward(self, flow, flow_input_batch):
+ out, _ = flow(flow_input_batch)
+ assert out.shape == flow_input_batch.shape, "out shape != input shape"
assert len(flow.layers) == 3, "flow block contains 3 layers"
assert isinstance(flow.layers[0], ActNorm), "act norm should be the 1st layer"
assert isinstance(
@@ -18,20 +18,22 @@ def test_forward(self, flow, input_batch):
flow.layers[2], AffineCoupling
), "coupling should be the last layer"
- def test_reverse(self, flow, input_batch):
- out, _ = flow(input_batch)
+ def test_reverse(self, flow, flow_input_batch):
+ out, _ = flow(flow_input_batch)
reverse_out = flow.reverse(out)
- assert reverse_out.shape == input_batch.shape, "reverse shape != input shape"
+ assert (
+ reverse_out.shape == flow_input_batch.shape
+ ), "reverse shape != input shape"
assert torch.allclose(
- input_batch, reverse_out, atol=1e-5
+ flow_input_batch, reverse_out, atol=1e-5
), "batch after reverse != input batch."
- def test_log_det(self, flow, input_batch):
- out, test_log_det = flow(input_batch)
+ def test_log_det(self, flow, flow_input_batch):
+ out, test_log_det = flow(flow_input_batch)
expected_log_det = 0
- x = input_batch
+ x = flow_input_batch
for layer in flow.layers:
x, log_det = layer(x)
expected_log_det = expected_log_det + log_det
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
new file mode 100644
index 0000000..723f88d
--- /dev/null
+++ b/tests/test_trainer.py
@@ -0,0 +1,59 @@
+from unittest.mock import MagicMock, patch
+
+from torch import nn
+from torch.utils.data import DataLoader
+
+
+class TestTrainer:
+ def test_init(self, trainer):
+ assert trainer.model is not None
+ assert trainer.train_loader is None
+ assert trainer.test_loader is None
+ assert trainer.logger is not None
+ assert len(trainer.z_list) == trainer.model.num_blocks
+ assert trainer.z_list[0].shape[0] == trainer.train_config.n_samples
+
+ def test_train_epoch(self, trainer):
+ trainer.train_loader = DataLoader(
+ trainer.train_dataset, batch_size=trainer.train_config.train_batch_size
+ )
+ loss = trainer.train_epoch(epoch=1)
+ assert isinstance(loss, float)
+
+ def test_test_epoch(self, trainer):
+ trainer.test_loader = DataLoader(
+ trainer.test_dataset, batch_size=trainer.train_config.test_batch_size
+ )
+ loss = trainer.test_epoch()
+ assert isinstance(loss, float)
+
+ @patch("modules.trainer.trainer.os.makedirs")
+ @patch("modules.trainer.trainer.torch.save")
+ def test_save_checkpoint(self, mock_torch_save, mock_makedirs, trainer):
+ trainer.save_checkpoint(epoch=1)
+ mock_makedirs.assert_called_with(trainer.train_config.save_dir, exist_ok=True)
+ assert mock_torch_save.call_count == 2
+
+ @patch("PIL.Image.Image.save")
+ def test_log_samples(self, mock_image_save, trainer):
+ trainer.model = nn.DataParallel(trainer.model).to(trainer.device)
+ trainer.logger = MagicMock()
+
+ trainer.log_samples(step=1)
+ mock_image_save.assert_called_with(f"{trainer.train_config.samples_dir}/1.png")
+ trainer.logger.log_images.assert_called_once_with(
+ grid=trainer.logger.log_images.call_args.kwargs["grid"], step=1
+ )
+
+ def test_train(self, trainer):
+ trainer.train_epoch = MagicMock(return_value=0.45)
+ trainer.test_epoch = MagicMock(return_value=0.5)
+ trainer.lr_scheduler.step = MagicMock()
+ trainer.save_checkpoint = MagicMock()
+
+ trainer.train()
+
+ assert trainer.train_epoch.call_count == trainer.train_config.n_epochs
+ assert trainer.test_epoch.call_count == trainer.train_config.n_epochs
+ trainer.lr_scheduler.step.assert_called()
+ trainer.save_checkpoint.assert_called()