From 775211ad4e5b1464b1e2bcdad6f0e7b489ef6201 Mon Sep 17 00:00:00 2001 From: danibene <34680344+danibene@users.noreply.github.com> Date: Sat, 24 Feb 2024 17:12:10 -0500 Subject: [PATCH 1/2] add dev tools without black and isort in precommit --- .coveragerc | 28 +++ .github/workflows/ci.yml | 107 +++++++++++ .gitignore | 10 +- .pre-commit-config.yaml | 74 ++++++++ .readthedocs.yml | 27 +++ AUTHORS.md | 3 + CHANGELOG.md | 3 + CONTRIBUTING.md | 371 +++++++++++++++++++++++++++++++++++++++ README.md | 31 +++- conda_env.yaml | 1 - docs/Makefile | 29 +++ docs/_static/.gitignore | 1 + docs/authors.md | 4 + docs/changelog.md | 4 + docs/conf.py | 304 ++++++++++++++++++++++++++++++++ docs/contributing.md | 4 + docs/index.md | 39 ++++ docs/license.md | 5 + docs/readme.md | 4 + docs/requirements.txt | 6 + pyproject.toml | 12 ++ setup.cfg | 116 ++++++++++++ setup.py | 46 +++-- tests/conftest.py | 10 ++ tests/test_utils.py | 12 ++ tox.ini | 93 ++++++++++ 26 files changed, 1307 insertions(+), 37 deletions(-) create mode 100644 .coveragerc create mode 100644 .github/workflows/ci.yml create mode 100644 .pre-commit-config.yaml create mode 100644 .readthedocs.yml create mode 100644 AUTHORS.md create mode 100644 CHANGELOG.md create mode 100644 CONTRIBUTING.md create mode 100644 docs/Makefile create mode 100644 docs/_static/.gitignore create mode 100644 docs/authors.md create mode 100644 docs/changelog.md create mode 100644 docs/conf.py create mode 100644 docs/contributing.md create mode 100644 docs/index.md create mode 100644 docs/license.md create mode 100644 docs/readme.md create mode 100644 docs/requirements.txt create mode 100644 pyproject.toml create mode 100644 setup.cfg create mode 100644 tests/conftest.py create mode 100644 tests/test_utils.py create mode 100644 tox.ini diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..103964b --- /dev/null +++ b/.coveragerc @@ -0,0 +1,28 @@ +# .coveragerc to control coverage.py +[run] +branch = True +source = equiadapt +# omit = bad_file.py + +[paths] +source = + equiadapt/ + */site-packages/ + +[report] +# Regexes for lines to exclude from consideration +exclude_lines = + # Have to re-enable the standard pragma + pragma: no cover + + # Don't complain about missing debug-only code: + def __repr__ + if self\.debug + + # Don't complain if tests don't hit defensive assertion code: + raise AssertionError + raise NotImplementedError + + # Don't complain if non-runnable code isn't run: + if 0: + if __name__ == .__main__.: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..b95f5ab --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,107 @@ +# GitHub Actions configuration **EXAMPLE**, +# MODIFY IT ACCORDING TO YOUR NEEDS! +# Reference: https://docs.github.com/en/actions + +name: tests + +on: + push: + # Avoid using all the resources/limits available by checking only + # relevant branches and tags. Other branches can be checked via PRs. 
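+      # For example, pushing the tag `v1.2.3` (matches `v[0-9]*`) or `1.2.3`
+      # (matches `[0-9]+.[0-9]+*`) triggers a run; pushes to other branches
+      # are only checked once a pull request is opened.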
+ branches: [main] + tags: ['v[0-9]*', '[0-9]+.[0-9]+*'] # Match tags that resemble a version + pull_request: # Run in every PR + workflow_dispatch: # Allow manually triggering the workflow + schedule: + # Run roughly every 15 days at 00:00 UTC + # (useful to check if updates on dependencies break the package) + - cron: '0 0 1,16 * *' + +permissions: + contents: read + +concurrency: + group: >- + ${{ github.workflow }}-${{ github.ref_type }}- + ${{ github.event.pull_request.number || github.sha }} + cancel-in-progress: true + +jobs: + prepare: + runs-on: ubuntu-latest + outputs: + wheel-distribution: ${{ steps.wheel-distribution.outputs.path }} + steps: + - uses: actions/checkout@v3 + with: {fetch-depth: 0} # deep clone for setuptools-scm + - uses: actions/setup-python@v4 + id: setup-python + with: {python-version: "3.10"} + - name: Run static analysis and format checkers + run: pipx run pre-commit run --all-files --show-diff-on-failure + - name: Build package distribution files + run: >- + pipx run --python '${{ steps.setup-python.outputs.python-path }}' + tox -e clean,build + - name: Record the path of wheel distribution + id: wheel-distribution + run: echo "path=$(ls dist/*.whl)" >> $GITHUB_OUTPUT + - name: Store the distribution files for use in other stages + # `tests` and `publish` will use the same pre-built distributions, + # so we make sure to release the exact same package that was tested + uses: actions/upload-artifact@v3 + with: + name: python-distribution-files + path: dist/ + retention-days: 1 + + test: + needs: prepare + strategy: + matrix: + python: + - "3.7" # oldest Python supported by PSF + - "3.10" # newest Python that is stable + platform: + - ubuntu-latest + # - macos-latest + # - windows-latest + runs-on: ${{ matrix.platform }} + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + id: setup-python + with: + python-version: ${{ matrix.python }} + - name: Retrieve pre-built distribution files + uses: actions/download-artifact@v3 + with: {name: python-distribution-files, path: dist/} + - name: Run tests + run: >- + pipx run --python '${{ steps.setup-python.outputs.python-path }}' + tox --installpkg '${{ needs.prepare.outputs.wheel-distribution }}' + -- -rFEx --durations 10 --color yes # pytest args + - name: Generate coverage report + run: pipx run coverage lcov -o coverage.lcov + + publish: + if: ${{ github.event_name == 'push' && contains(github.ref, 'refs/tags/') }} + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: {python-version: "3.10"} + - name: Retrieve pre-built distribution files + uses: actions/download-artifact@v3 + with: {name: python-distribution-files, path: dist/} + - name: Publish Package + env: + # TODO: Set your PYPI_TOKEN as a secret using GitHub UI + # - https://pypi.org/help/#apitoken + # - https://docs.github.com/en/actions/security-guides/encrypted-secrets + TWINE_REPOSITORY: pypi + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} + run: pipx run tox -e publish diff --git a/.gitignore b/.gitignore index b538eae..f202880 100644 --- a/.gitignore +++ b/.gitignore @@ -104,11 +104,15 @@ dmypy.json # Ignore .vscode in all folders **/.vscode -# Ignore scripts to run experiments in mila +# Ignore scripts to run experiments in mila mila_scripts/ escnn *__pycache__/ -*_output/ +rotmnist_sweep_output/ +cifar10_sweep_output/ wandb/ -*.png \ No newline at end of file + +# Docs +docs/api/ +docs/_build/ diff --git 
a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..512c6cf --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,74 @@ +exclude: '^docs/conf.py' + +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: check-added-large-files + - id: check-ast +# - id: check-json + - id: check-merge-conflict + - id: check-xml + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: requirements-txt-fixer + - id: mixed-line-ending + args: ['--fix=auto'] # replace 'auto' with 'lf' to enforce Linux/Mac line endings or 'crlf' for Windows + +## If you want to automatically "modernize" your Python code: +# - repo: https://github.com/asottile/pyupgrade +# rev: v3.7.0 +# hooks: +# - id: pyupgrade +# args: ['--py37-plus'] + +## If you want to avoid flake8 errors due to unused vars or imports: +# - repo: https://github.com/PyCQA/autoflake +# rev: v2.1.1 +# hooks: +# - id: autoflake +# args: [ +# --in-place, +# --remove-all-unused-imports, +# --remove-unused-variables, +# ] + +# - repo: https://github.com/PyCQA/isort +# rev: 5.13.2 +# hooks: +# - id: isort + +# - repo: https://github.com/psf/black +# rev: 24.2.0 +# hooks: +# - id: black +# language_version: python3 + +## If like to embrace black styles even in the docs: +# - repo: https://github.com/asottile/blacken-docs +# rev: v1.13.0 +# hooks: +# - id: blacken-docs +# additional_dependencies: [black] + +# - repo: https://github.com/PyCQA/flake8 +# rev: 7.0.0 +# hooks: +# - id: flake8 + ## You can add flake8 plugins via `additional_dependencies`: + # additional_dependencies: [flake8-bugbear] + +## Check for misspells in documentation files: +# - repo: https://github.com/codespell-project/codespell +# rev: v2.2.5 +# hooks: +# - id: codespell + +## Check for type errors with mypy: +# - repo: https://github.com/pre-commit/mirrors-mypy +# rev: 'v1.8.0' +# hooks: +# - id: mypy +# args: [--disallow-untyped-defs, --ignore-missing-imports] diff --git a/.readthedocs.yml b/.readthedocs.yml new file mode 100644 index 0000000..a2bcab3 --- /dev/null +++ b/.readthedocs.yml @@ -0,0 +1,27 @@ +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs/conf.py + +# Build documentation with MkDocs +#mkdocs: +# configuration: mkdocs.yml + +# Optionally build your docs in additional formats such as PDF +formats: + - pdf + +build: + os: ubuntu-22.04 + tools: + python: "3.11" + +python: + install: + - requirements: docs/requirements.txt + - {path: ., method: pip} diff --git a/AUTHORS.md b/AUTHORS.md new file mode 100644 index 0000000..17eddad --- /dev/null +++ b/AUTHORS.md @@ -0,0 +1,3 @@ +# Contributors + +* Arnab Mondal [arnab.mondal@mila.quebec](mailto:arnab.mondal@mila.quebec)s diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..dd0325b --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,3 @@ +# Changelog + +## Version 0.1 (development) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..93ebb5e --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,371 @@ +```{todo} THIS IS SUPPOSED TO BE AN EXAMPLE. MODIFY IT ACCORDING TO YOUR NEEDS! + + The document assumes you are using a source repository service that promotes a + contribution model similar to [GitHub's fork and pull request workflow]. 
+ While this is true for the majority of services (like GitHub, GitLab, + BitBucket), it might not be the case for private repositories (e.g., when + using Gerrit). + + Also notice that the code examples might refer to GitHub URLs or the text + might use GitHub specific terminology (e.g., *Pull Request* instead of *Merge + Request*). + + Please make sure to check the document having these assumptions in mind + and update things accordingly. +``` + +```{todo} Provide the correct links/replacements at the bottom of the document. +``` + +```{todo} You might want to have a look on [PyScaffold's contributor's guide], + + especially if your project is open source. The text should be very similar to + this template, but there are a few extra contents that you might decide to + also include, like mentioning labels of your issue tracker or automated + releases. +``` + +# Contributing + +Welcome to `equiadapt` contributor's guide. + +This document focuses on getting any potential contributor familiarized with +the development processes, but [other kinds of contributions] are also appreciated. + +If you are new to using [git] or have never collaborated in a project previously, +please have a look at [contribution-guide.org]. Other resources are also +listed in the excellent [guide created by FreeCodeCamp] [^contrib1]. + +Please notice, all users and contributors are expected to be **open, +considerate, reasonable, and respectful**. When in doubt, +[Python Software Foundation's Code of Conduct] is a good reference in terms of +behavior guidelines. + +## Issue Reports + +If you experience bugs or general issues with `equiadapt`, please have a look +on the [issue tracker]. +If you don't see anything useful there, please feel free to fire an issue report. + +:::{tip} +Please don't forget to include the closed issues in your search. +Sometimes a solution was already reported, and the problem is considered +**solved**. +::: + +New issue reports should include information about your programming environment +(e.g., operating system, Python version) and steps to reproduce the problem. +Please try also to simplify the reproduction steps to a very minimal example +that still illustrates the problem you are facing. By removing other factors, +you help us to identify the root cause of the issue. + +## Documentation Improvements + +You can help improve `equiadapt` docs by making them more readable and coherent, or +by adding missing information and correcting mistakes. + +`equiadapt` documentation uses [Sphinx] as its main documentation compiler. +This means that the docs are kept in the same repository as the project code, and +that any documentation update is done in the same way was a code contribution. + +```{todo} Don't forget to mention which markup language you are using. + + e.g., [reStructuredText] or [CommonMark] with [MyST] extensions. +``` + +```{todo} If your project is hosted on GitHub, you can also mention the following tip: + + :::{tip} + Please notice that the [GitHub web interface] provides a quick way of + propose changes in `equiadapt`'s files. While this mechanism can + be tricky for normal code contributions, it works perfectly fine for + contributing to the docs, and can be quite handy. + + If you are interested in trying this method out, please navigate to + the `docs` folder in the source [repository], find which file you + would like to propose changes and click in the little pencil icon at the + top, to open [GitHub's code editor]. 
Once you finish editing the file, + please write a message in the form at the bottom of the page describing + which changes have you made and what are the motivations behind them and + submit your proposal. + ::: +``` + +When working on documentation changes in your local machine, you can +compile them using [tox] : + +``` +tox -e docs +``` + +and use Python's built-in web server for a preview in your web browser +(`http://localhost:8000`): + +``` +python3 -m http.server --directory 'docs/_build/html' +``` + +## Code Contributions + +```{todo} Please include a reference or explanation about the internals of the project. + + An architecture description, design principles or at least a summary of the + main concepts will make it easy for potential contributors to get started + quickly. +``` + +### Submit an issue + +Before you work on any non-trivial code contribution it's best to first create +a report in the [issue tracker] to start a discussion on the subject. +This often provides additional considerations and avoids unnecessary work. + +### Create an environment + +Before you start coding, we recommend creating an isolated [virtual environment] +to avoid any problems with your installed Python packages. +This can easily be done via either [virtualenv]: + +``` +virtualenv +source /bin/activate +``` + +or [Miniconda]: + +``` +conda create -n equiadapt python=3 six virtualenv pytest pytest-cov +conda activate equiadapt +``` + +### Clone the repository + +1. Create an user account on GitHub if you do not already have one. + +2. Fork the project [repository]: click on the *Fork* button near the top of the + page. This creates a copy of the code under your account on GitHub. + +3. Clone this copy to your local disk: + + ``` + git clone git@github.com:YourLogin/equiadapt.git + cd equiadapt + ``` + +4. You should run: + + ``` + pip install -U pip setuptools -e . + ``` + + to be able to import the package under development in the Python REPL. + + ```{todo} if you are not using pre-commit, please remove the following item: + ``` + +5. Install [pre-commit]: + + ``` + pip install pre-commit + pre-commit install + ``` + + `equiadapt` comes with a lot of hooks configured to automatically help the + developer to check the code being written. + +### Implement your changes + +1. Create a branch to hold your changes: + + ``` + git checkout -b my-feature + ``` + + and start making changes. Never work on the main branch! + +2. Start your work on this branch. Don't forget to add [docstrings] to new + functions, modules and classes, especially if they are part of public APIs. + +3. Add yourself to the list of contributors in `AUTHORS.rst`. + +4. When you’re done editing, do: + + ``` + git add + git commit + ``` + + to record your changes in [git]. + + ```{todo} if you are not using pre-commit, please remove the following item: + ``` + + Please make sure to see the validation messages from [pre-commit] and fix + any eventual issues. + This should automatically use [flake8]/[black] to check/fix the code style + in a way that is compatible with the project. + + :::{important} + Don't forget to add unit tests and documentation in case your + contribution adds an additional feature and is not just a bugfix. + + Moreover, writing a [descriptive commit message] is highly recommended. + In case of doubt, you can check the commit history with: + + ``` + git log --graph --decorate --pretty=oneline --abbrev-commit --all + ``` + + to look for recurring communication patterns. + ::: + +5. 
Please check that your changes don't break any unit tests with: + + ``` + tox + ``` + + (after having installed [tox] with `pip install tox` or `pipx`). + + You can also use [tox] to run several other pre-configured tasks in the + repository. Try `tox -av` to see a list of the available checks. + +### Submit your contribution + +1. If everything works fine, push your local branch to the remote server with: + + ``` + git push -u origin my-feature + ``` + +2. Go to the web page of your fork and click "Create pull request" + to send your changes for review. + + ```{todo} if you are using GitHub, you can uncomment the following paragraph + + Find more detailed information in [creating a PR]. You might also want to open + the PR as a draft first and mark it as ready for review after the feedbacks + from the continuous integration (CI) system or any required fixes. + + ``` + +### Troubleshooting + +The following tips can be used when facing problems to build or test the +package: + +1. Make sure to fetch all the tags from the upstream [repository]. + The command `git describe --abbrev=0 --tags` should return the version you + are expecting. If you are trying to run CI scripts in a fork repository, + make sure to push all the tags. + You can also try to remove all the egg files or the complete egg folder, i.e., + `.eggs`, as well as the `*.egg-info` folders in the `src` folder or + potentially in the root of your project. + +2. Sometimes [tox] misses out when new dependencies are added, especially to + `setup.cfg` and `docs/requirements.txt`. If you find any problems with + missing dependencies when running a command with [tox], try to recreate the + `tox` environment using the `-r` flag. For example, instead of: + + ``` + tox -e docs + ``` + + Try running: + + ``` + tox -r -e docs + ``` + +3. Make sure to have a reliable [tox] installation that uses the correct + Python version (e.g., 3.7+). When in doubt you can run: + + ``` + tox --version + # OR + which tox + ``` + + If you have trouble and are seeing weird errors upon running [tox], you can + also try to create a dedicated [virtual environment] with a [tox] binary + freshly installed. For example: + + ``` + virtualenv .venv + source .venv/bin/activate + .venv/bin/pip install tox + .venv/bin/tox -e all + ``` + +4. [Pytest can drop you] in an interactive session in the case an error occurs. + In order to do that you need to pass a `--pdb` option (for example by + running `tox -- -k --pdb`). + You can also setup breakpoints manually instead of using the `--pdb` option. + +## Maintainer tasks + +### Releases + +```{todo} This section assumes you are using PyPI to publicly release your package. + + If instead you are using a different/private package index, please update + the instructions accordingly. +``` + +If you are part of the group of maintainers and have correct user permissions +on [PyPI], the following steps can be used to release a new version for +`equiadapt`: + +1. Make sure all unit tests are successful. +2. Tag the current commit on the main branch with a release tag, e.g., `v1.2.3`. +3. Push the new tag to the upstream [repository], + e.g., `git push upstream v1.2.3` +4. Clean up the `dist` and `build` folders with `tox -e clean` + (or `rm -rf dist build`) + to avoid confusion with old builds and Sphinx docs. +5. Run `tox -e build` and check that the files in `dist` have + the correct version (no `.dirty` or [git] hash) according to the [git] tag. 
+ Also check the sizes of the distributions, if they are too big (e.g., > + 500KB), unwanted clutter may have been accidentally included. +6. Run `tox -e publish -- --repository pypi` and check that everything was + uploaded to [PyPI] correctly. + +[^contrib1]: Even though, these resources focus on open source projects and + communities, the general ideas behind collaborating with other developers + to collectively create software are general and can be applied to all sorts + of environments, including private companies and proprietary code bases. + + +[black]: https://pypi.org/project/black/ +[commonmark]: https://commonmark.org/ +[contribution-guide.org]: http://www.contribution-guide.org/ +[creating a pr]: https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request +[descriptive commit message]: https://chris.beams.io/posts/git-commit +[docstrings]: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html +[first-contributions tutorial]: https://github.com/firstcontributions/first-contributions +[flake8]: https://flake8.pycqa.org/en/stable/ +[git]: https://git-scm.com +[github web interface]: https://docs.github.com/en/github/managing-files-in-a-repository/managing-files-on-github/editing-files-in-your-repository +[github's code editor]: https://docs.github.com/en/github/managing-files-in-a-repository/managing-files-on-github/editing-files-in-your-repository +[github's fork and pull request workflow]: https://guides.github.com/activities/forking/ +[guide created by freecodecamp]: https://github.com/freecodecamp/how-to-contribute-to-open-source +[miniconda]: https://docs.conda.io/en/latest/miniconda.html +[myst]: https://myst-parser.readthedocs.io/en/latest/syntax/syntax.html +[other kinds of contributions]: https://opensource.guide/how-to-contribute +[pre-commit]: https://pre-commit.com/ +[pypi]: https://pypi.org/ +[pyscaffold's contributor's guide]: https://pyscaffold.org/en/stable/contributing.html +[pytest can drop you]: https://docs.pytest.org/en/stable/usage.html#dropping-to-pdb-python-debugger-at-the-start-of-a-test +[python software foundation's code of conduct]: https://www.python.org/psf/conduct/ +[restructuredtext]: https://www.sphinx-doc.org/en/master/usage/restructuredtext/ +[sphinx]: https://www.sphinx-doc.org/en/master/ +[tox]: https://tox.readthedocs.io/en/stable/ +[virtual environment]: https://realpython.com/python-virtual-environments-a-primer/ +[virtualenv]: https://virtualenv.pypa.io/en/stable/ + + +```{todo} Please review and change the following definitions: +``` + +[repository]: https://github.com//equiadapt +[issue tracker]: https://github.com//equiadapt/issues diff --git a/README.md b/README.md index fad71b8..a9fe130 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Library to make any existing neural network architecture equivariant # Setup instructions -### Setup Conda environment +### Setup Conda environment To create a conda environment with the necessary packages: @@ -14,7 +14,7 @@ pip install -e . #### For Python 3.10 -Currently, everything works in Python 3.8. +Currently, everything works in Python 3.8. But to use Python 3.10, you need to remove `py3nj` from the `escnn` package requirements and install `escnn` from GitHub manually. ``` @@ -23,7 +23,7 @@ cd escnn (and go to setup.py and remove py3nj from the requirements) pip install -e . 
``` -### Setup Hydra +### Setup Hydra - Create a `.env` file in the root of the project with the following content: ``` export HYDRA_JOBS="/path/to/your/hydra/jobs/directory" @@ -31,11 +31,11 @@ pip install -e . export WANDB_CACHE_DIR="/path/to/your/wandb/cache/directory" export DATA_PATH="/path/to/your/data/directory" export CHECKPOINT_PATH="/path/to/your/checkpoint/directory" - ``` + ``` # Running Instructions -For image classification: [here](/examples/images/classification/README.md) +For image classification: [here](/examples/images/classification/README.md) For image segmentation: [here](/examples/images/segmentation/README.md) @@ -43,7 +43,7 @@ For image segmentation: [here](/examples/images/segmentation/README.md) For more insights on this library refer to our original paper on the idea: [Equivariance with Learned Canonicalization Function (ICML 2023)](https://proceedings.mlr.press/v202/kaba23a.html) and how to extend it to make any existing large pre-trained model equivariant: [Equivariant Adaptation of Large Pretrained Models (NeurIPS 2023)](https://proceedings.neurips.cc/paper_files/paper/2023/hash/9d5856318032ef3630cb580f4e24f823-Abstract-Conference.html). -To learn more about this from a blog, check out: [How to make your foundation model equivariant](https://mila.quebec/en/article/how-to-make-your-foundation-model-equivariant/) +To learn more about this from a blog, check out: [How to make your foundation model equivariant](https://mila.quebec/en/article/how-to-make-your-foundation-model-equivariant/) # Citation If you find this library or the associated papers useful, please cite: @@ -70,7 +70,22 @@ If you find this library or the associated papers useful, please cite: # Contact -For questions related to this code, you can mail us at: +For questions related to this code, you can mail us at: ```arnab.mondal@mila.quebec``` ```siba-smarak.panigrahi@mila.quebec``` -```kabaseko@mila.quebec``` \ No newline at end of file +```kabaseko@mila.quebec``` + +# Contributing + +You can check out the [contributor's guide](CONTRIBUTING.md). + +This project uses `pre-commit`_, you can install it before making any +changes:: + + pip install pre-commit + cd equiadapt + pre-commit install + +It is a good idea to update the hooks to the latest version:: + + pre-commit autoupdate diff --git a/conda_env.yaml b/conda_env.yaml index 176e904..4dcc926 100644 --- a/conda_env.yaml +++ b/conda_env.yaml @@ -155,4 +155,3 @@ dependencies: - urllib3==1.26.18 - wheel==0.41.2 - yarl==1.9.4 - diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..31655dd --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,29 @@ +# Makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build +AUTODOCDIR = api + +# User-friendly check for sphinx-build +ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $?), 1) +$(error "The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from https://sphinx-doc.org/") +endif + +.PHONY: help clean Makefile + +# Put it first so that "make" without argument is like "make help". 
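+# Any other target (e.g. `make html` or `make linkcheck`) is forwarded to
+# Sphinx via the catch-all rule at the bottom of this file.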
+help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +clean: + rm -rf $(BUILDDIR)/* $(AUTODOCDIR) + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_static/.gitignore b/docs/_static/.gitignore new file mode 100644 index 0000000..3c96363 --- /dev/null +++ b/docs/_static/.gitignore @@ -0,0 +1 @@ +# Empty directory diff --git a/docs/authors.md b/docs/authors.md new file mode 100644 index 0000000..ced47d0 --- /dev/null +++ b/docs/authors.md @@ -0,0 +1,4 @@ +```{include} ../AUTHORS.md +:relative-docs: docs/ +:relative-images: +``` diff --git a/docs/changelog.md b/docs/changelog.md new file mode 100644 index 0000000..6e2f0fb --- /dev/null +++ b/docs/changelog.md @@ -0,0 +1,4 @@ +```{include} ../CHANGELOG.md +:relative-docs: docs/ +:relative-images: +``` diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..7249aa9 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,304 @@ +# This file is execfile()d with the current directory set to its containing dir. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +import os +import sys +import shutil + +# -- Path setup -------------------------------------------------------------- + +__location__ = os.path.dirname(__file__) + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +sys.path.insert(0, os.path.join(__location__, "../src")) + +# -- Run sphinx-apidoc ------------------------------------------------------- +# This hack is necessary since RTD does not issue `sphinx-apidoc` before running +# `sphinx-build -b html . _build/html`. See Issue: +# https://github.com/readthedocs/readthedocs.org/issues/1139 +# DON'T FORGET: Check the box "Install your project inside a virtualenv using +# setup.py install" in the RTD Advanced Settings. +# Additionally it helps us to avoid running apidoc manually + +try: # for Sphinx >= 1.7 + from sphinx.ext import apidoc +except ImportError: + from sphinx import apidoc + +output_dir = os.path.join(__location__, "api") +module_dir = os.path.join(__location__, "../equiadapt") +try: + shutil.rmtree(output_dir) +except FileNotFoundError: + pass + +try: + import sphinx + + cmd_line = f"sphinx-apidoc --implicit-namespaces -f -o {output_dir} {module_dir}" + + args = cmd_line.split(" ") + if tuple(sphinx.__version__.split(".")) >= ("1", "7"): + # This is a rudimentary parse_version to avoid external dependencies + args = args[1:] + + apidoc.main(args) +except Exception as e: + print("Running `sphinx-apidoc` failed!\n{}".format(e)) + +# -- General configuration --------------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be extensions +# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 
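+# All of the extensions listed here ship with Sphinx itself; Markdown support
+# (`myst_parser`) is appended further below and is installed from
+# docs/requirements.txt.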
+extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.autosummary", + "sphinx.ext.viewcode", + "sphinx.ext.coverage", + "sphinx.ext.doctest", + "sphinx.ext.ifconfig", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + + +# Enable markdown +extensions.append("myst_parser") + +# Configure MyST-Parser +myst_enable_extensions = [ + "amsmath", + "colon_fence", + "deflist", + "dollarmath", + "html_image", + "linkify", + "replacements", + "smartquotes", + "substitution", + "tasklist", +] + +# The suffix of source filenames. +source_suffix = [".rst", ".md"] + +# The encoding of source files. +# source_encoding = 'utf-8-sig' + +# The master toctree document. +master_doc = "index" + +# General information about the project. +project = "equiadapt" +copyright = "2024, Danielle Benesch" + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# version: The short X.Y version. +# release: The full version, including alpha/beta/rc tags. +# If you don’t need the separation provided between version and release, +# just set them both to the same value. +try: + from equiadapt import __version__ as version +except ImportError: + version = "" + +if not version or version.lower() == "unknown": + version = os.getenv("READTHEDOCS_VERSION", "unknown") # automatically set by RTD + +release = version + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# language = None + +# There are two options for replacing |today|: either, you set today to some +# non-false value, then it is used: +# today = '' +# Else, today_fmt is used as the format for a strftime call. +# today_fmt = '%B %d, %Y' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", ".venv"] + +# The reST default role (used for this markup: `text`) to use for all documents. +# default_role = None + +# If true, '()' will be appended to :func: etc. cross-reference text. +# add_function_parentheses = True + +# If true, the current module name will be prepended to all description +# unit titles (such as .. function::). +# add_module_names = True + +# If true, sectionauthor and moduleauthor directives will be shown in the +# output. They are ignored by default. +# show_authors = False + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = "sphinx" + +# A list of ignored prefixes for module index sorting. +# modindex_common_prefix = [] + +# If true, keep warnings as "system message" paragraphs in the built documents. +# keep_warnings = False + +# If this is True, todo emits a warning for each TODO entries. The default is False. +todo_emit_warnings = True + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +html_theme = "alabaster" + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. 
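+# `sidebar_width` and `page_width` below are options of the `alabaster` theme
+# selected above; other themes expose different option names.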
+html_theme_options = { + "sidebar_width": "300px", + "page_width": "1200px" +} + +# Add any paths that contain custom themes here, relative to this directory. +# html_theme_path = [] + +# The name for this set of Sphinx documents. If None, it defaults to +# " v documentation". +# html_title = None + +# A shorter title for the navigation bar. Default is the same as html_title. +# html_short_title = None + +# The name of an image file (relative to this directory) to place at the top +# of the sidebar. +# html_logo = "" + +# The name of an image file (within the static path) to use as favicon of the +# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +# html_favicon = None + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] + +# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, +# using the given strftime format. +# html_last_updated_fmt = '%b %d, %Y' + +# If true, SmartyPants will be used to convert quotes and dashes to +# typographically correct entities. +# html_use_smartypants = True + +# Custom sidebar templates, maps document names to template names. +# html_sidebars = {} + +# Additional templates that should be rendered to pages, maps page names to +# template names. +# html_additional_pages = {} + +# If false, no module index is generated. +# html_domain_indices = True + +# If false, no index is generated. +# html_use_index = True + +# If true, the index is split into individual pages for each letter. +# html_split_index = False + +# If true, links to the reST sources are added to the pages. +# html_show_sourcelink = True + +# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. +# html_show_sphinx = True + +# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. +# html_show_copyright = True + +# If true, an OpenSearch description file will be output, and all pages will +# contain a tag referring to it. The value of this option must be the +# base URL from which the finished HTML is served. +# html_use_opensearch = '' + +# This is the file name suffix for HTML files (e.g. ".xhtml"). +# html_file_suffix = None + +# Output file base name for HTML help builder. +htmlhelp_basename = "equiadapt-doc" + + +# -- Options for LaTeX output ------------------------------------------------ + +latex_elements = { + # The paper size ("letterpaper" or "a4paper"). + # "papersize": "letterpaper", + # The font size ("10pt", "11pt" or "12pt"). + # "pointsize": "10pt", + # Additional stuff for the LaTeX preamble. + # "preamble": "", +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, author, documentclass [howto/manual]). +latex_documents = [ + ("index", "user_guide.tex", "equiadapt Documentation", "Danielle Benesch", "manual") +] + +# The name of an image file (relative to this directory) to place at the top of +# the title page. +# latex_logo = "" + +# For "manual" documents, if this is true, then toplevel headings are parts, +# not chapters. +# latex_use_parts = False + +# If true, show page references after internal links. +# latex_show_pagerefs = False + +# If true, show URL addresses after external links. +# latex_show_urls = False + +# Documents to append as an appendix to all manuals. 
+# latex_appendices = [] + +# If false, no module index is generated. +# latex_domain_indices = True + +# -- External mapping -------------------------------------------------------- +python_version = ".".join(map(str, sys.version_info[0:2])) +intersphinx_mapping = { + "sphinx": ("https://www.sphinx-doc.org/en/master", None), + "python": ("https://docs.python.org/" + python_version, None), + "matplotlib": ("https://matplotlib.org", None), + "numpy": ("https://numpy.org/doc/stable", None), + "sklearn": ("https://scikit-learn.org/stable", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), + "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), + "setuptools": ("https://setuptools.pypa.io/en/stable/", None), + "pyscaffold": ("https://pyscaffold.org/en/stable", None), +} + +print(f"loading configurations for {project} {version} ...", file=sys.stderr) \ No newline at end of file diff --git a/docs/contributing.md b/docs/contributing.md new file mode 100644 index 0000000..fc1b213 --- /dev/null +++ b/docs/contributing.md @@ -0,0 +1,4 @@ +```{include} ../CONTRIBUTING.md +:relative-docs: docs/ +:relative-images: +``` diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..e03ae46 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,39 @@ +# equiadapt + +Library that provides metrics to asses representation quality + + +## Note + +> This is the main page of your project's [Sphinx] documentation. It is +> formatted in [Markdown]. Add additional pages by creating md-files in +> `docs` or rst-files (formatted in [reStructuredText]) and adding links to +> them in the `Contents` section below. +> +> Please check [Sphinx] and [MyST] for more information +> about how to document your project and how to configure your preferences. + + +## Contents + +```{toctree} +:maxdepth: 2 + +Overview +Contributions & Help +License +Authors +Changelog +Module Reference +``` + +## Indices and tables + +* {ref}`genindex` +* {ref}`modindex` +* {ref}`search` + +[Sphinx]: http://www.sphinx-doc.org/ +[Markdown]: https://daringfireball.net/projects/markdown/ +[reStructuredText]: http://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html +[MyST]: https://myst-parser.readthedocs.io/en/latest/ diff --git a/docs/license.md b/docs/license.md new file mode 100644 index 0000000..22567b6 --- /dev/null +++ b/docs/license.md @@ -0,0 +1,5 @@ +# License + +```{literalinclude} ../LICENSE +:language: text +``` diff --git a/docs/readme.md b/docs/readme.md new file mode 100644 index 0000000..2cb706b --- /dev/null +++ b/docs/readme.md @@ -0,0 +1,4 @@ +```{include} ../README.md +:relative-docs: docs/ +:relative-images: +``` diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..0990c2a --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,6 @@ +# Requirements file for ReadTheDocs, check .readthedocs.yml. +# To build the module reference correctly, make sure every external package +# under `install_requires` in `setup.cfg` is also listed here! +# sphinx_rtd_theme +myst-parser[linkify] +sphinx>=3.2.1 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..9f81bba --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,12 @@ +[build-system] +# AVOID CHANGING REQUIRES: IT WILL BE UPDATED BY PYSCAFFOLD! 
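+# setuptools_scm (configured below) derives the package version from git tags
+# at build time, so no hard-coded version string is needed here.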
+requires = ["setuptools>=46.1.0", "setuptools_scm[toml]>=5"] +build-backend = "setuptools.build_meta" + +[tool.setuptools_scm] +# For smarter version schemes and other configuration options, +# check out https://github.com/pypa/setuptools_scm +version_scheme = "no-guess-dev" + +[tool.mypy] +exclude = ['docs'] diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..94d11bd --- /dev/null +++ b/setup.cfg @@ -0,0 +1,116 @@ +# This file is used to configure your project. +# Read more about the various options under: +# https://setuptools.pypa.io/en/latest/userguide/declarative_config.html +# https://setuptools.pypa.io/en/latest/references/keywords.html + +[metadata] +name = equiadapt +description = Library that provides metrics to assess representation quality +author = Arnab Mondal +author_email = arnab.mondal@mila.quebec +license = MIT +license_files = LICENSE +long_description = file: README.md +long_description_content_type = text/markdown; charset=UTF-8; variant=GFM +url = https://github.com/arnab39/EquivariantAdaptation/ +# Add here related links, for example: +project_urls = + Tracker = https://github.com/arnab39/EquivariantAdaptation/issues + Source = https://github.com/arnab39/EquivariantAdaptation/ + +# Change if running only on Windows, Mac or Linux (comma-separated) +platforms = Linux + +# Add here all kinds of additional classifiers as defined under +# https://pypi.org/classifiers/ +classifiers = + Programming Language :: Python :: 3 + License :: OSI Approved :: MIT License + Operating System :: Linux + +[options] +zip_safe = False +packages = find_namespace: +include_package_data = True + +# Require a min/specific Python version (comma-separated conditions) +python_requires = >=3.7 + +# Add here dependencies of your project (line-separated), e.g. requests>=2.2,<3.0. +# Version specifiers like >=2.2,<3.0 avoid problems due to API changes in +# new major versions. This works if the required packages follow Semantic Versioning. +# For more information, check out https://semver.org/. +install_requires = + torch + numpy + torchvision + kornia + escnn @ git+https://github.com/danibene/escnn.git@remove/py3nj_dep + +[options.packages.find] +exclude = + tests + +[options.extras_require] +# Add here additional requirements for extra features, to install with: +# `pip install equiadapt[PDF]` like: +# PDF = ReportLab; RXP + +# Add here test requirements (semicolon/line-separated) +testing = + setuptools + pytest + pytest-cov + +[options.entry_points] +# Add here console scripts like: +# console_scripts = +# script_name = equiadapt.module:function + +[tool:pytest] +# Specify command line options as you would do when invoking pytest directly. +# e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml +# in order to write a coverage file that can be read by Jenkins. +# CAUTION: --cov flags may prohibit setting breakpoints while debugging. +# Comment those flags to avoid this pytest issue. 
+addopts = + --cov equiadapt --cov-report term-missing + --verbose +norecursedirs = + dist + build + .tox +testpaths = tests +# Use pytest markers to select/deselect specific tests +# markers = +# slow: mark tests as slow (deselect with '-m "not slow"') +# system: mark end-to-end system tests + +[devpi:upload] +# Options for the devpi: PyPI server and packaging tool +# VCS export must be deactivated since we are using setuptools-scm +no_vcs = 1 +formats = bdist_wheel + +[flake8] +# Some sane defaults for the code style checker flake8 +max_line_length = 88 +extend_ignore = E203, W503 +# ^ Black-compatible +# E203 and W503 have edge cases handled by black +exclude = + .tox + build + dist + .eggs + docs/conf.py + +[pyscaffold] +# PyScaffold's parameters when the project was created. +# This will be used when updating. Do not change! +version = 4.5 +package = equiadapt +extensions = + github_actions + markdown + pre_commit diff --git a/setup.py b/setup.py index 8c21c04..477bbff 100644 --- a/setup.py +++ b/setup.py @@ -1,26 +1,22 @@ -from setuptools import setup, find_packages +""" + Setup file for equiadapt. + Use setup.cfg to configure your project. -setup( - name='equiadapt', # Replace with your package's name - version='0.1.0', # Package version - author='Arnab Mondal', # Replace with your name - author_email='arnab.mondal@mila.quebec', # Replace with your email - description='Library to make any existing neural network architecture equivariant', # Package summary - long_description=open('README.md').read(), - long_description_content_type='text/markdown', - url='https://github.com/arnab39/EquivariantAdaptation', # Replace with your repository URL - packages=find_packages(), - install_requires=[ - 'torch', # Specify your project's dependencies here - 'numpy', - 'torchvision', - 'kornia', - 'escnn @ git+https://github.com/danibene/escnn.git@remove/py3nj_dep' - ], - classifiers=[ - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: MIT License', - 'Operating System :: OS Independent', - ], - python_requires='>=3.7', # Minimum version requirement of Python -) + This file was generated with PyScaffold 4.5. + PyScaffold helps you to put up the scaffold of your new Python project. + Learn more under: https://pyscaffold.org/ +""" + +from setuptools import setup + +if __name__ == "__main__": + try: + setup(use_scm_version={"version_scheme": "no-guess-dev"}) + except: # noqa + print( + "\n\nAn error occurred while building the project, " + "please ensure you have the most updated version of setuptools, " + "setuptools_scm and wheel with:\n" + " pip install -U setuptools setuptools_scm wheel\n\n" + ) + raise diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..08c21bc --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,10 @@ +""" + Dummy conftest.py for equiadapt. + + If you don't know what this is for, just leave it empty. 
+ Read more about conftest.py under: + - https://docs.pytest.org/en/stable/fixture.html + - https://docs.pytest.org/en/stable/writing_plugins.html +""" + +# import pytest diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..4913c70 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,12 @@ +import torch + +from equiadapt.common.utils import gram_schmidt + + +def test_gram_schmidt() -> None: + torch.manual_seed(0) + vectors = torch.randn(1, 3, 3) # batch of 1, 3 vectors of dimension 3 + + output = gram_schmidt(vectors) + + assert torch.allclose(output[0][0][0], torch.tensor(0.5740), atol=1e-4) diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..69f8159 --- /dev/null +++ b/tox.ini @@ -0,0 +1,93 @@ +# Tox configuration file +# Read more under https://tox.wiki/ +# THIS SCRIPT IS SUPPOSED TO BE AN EXAMPLE. MODIFY IT ACCORDING TO YOUR NEEDS! + +[tox] +minversion = 3.24 +envlist = default +isolated_build = True + + +[testenv] +description = Invoke pytest to run automated tests +setenv = + TOXINIDIR = {toxinidir} +passenv = + HOME + SETUPTOOLS_* +extras = + testing +commands = + pytest {posargs} + + +# # To run `tox -e lint` you need to make sure you have a +# # `.pre-commit-config.yaml` file. See https://pre-commit.com +# [testenv:lint] +# description = Perform static analysis and style checks +# skip_install = True +# deps = pre-commit +# passenv = +# HOMEPATH +# PROGRAMDATA +# SETUPTOOLS_* +# commands = +# pre-commit run --all-files {posargs:--show-diff-on-failure} + + +[testenv:{build,clean}] +description = + build: Build the package in isolation according to PEP517, see https://github.com/pypa/build + clean: Remove old distribution files and temporary build artifacts (./build and ./dist) +# https://setuptools.pypa.io/en/stable/build_meta.html#how-to-use-it +skip_install = True +changedir = {toxinidir} +deps = + build: build[virtualenv] +passenv = + SETUPTOOLS_* +commands = + clean: python -c 'import shutil; [shutil.rmtree(p, True) for p in ("build", "dist", "docs/_build")]' + clean: python -c 'import pathlib, shutil; [shutil.rmtree(p, True) for p in pathlib.Path("src").glob("*.egg-info")]' + build: python -m build {posargs} +# By default, both `sdist` and `wheel` are built. If your sdist is too big or you don't want +# to make it available, consider running: `tox -e build -- --wheel` + + +[testenv:{docs,doctests,linkcheck}] +description = + docs: Invoke sphinx-build to build the docs + doctests: Invoke sphinx-build to run doctests + linkcheck: Check for broken links in the documentation +passenv = + SETUPTOOLS_* +setenv = + DOCSDIR = {toxinidir}/docs + BUILDDIR = {toxinidir}/docs/_build + docs: BUILD = html + doctests: BUILD = doctest + linkcheck: BUILD = linkcheck +deps = + -r {toxinidir}/docs/requirements.txt + # ^ requirements.txt shared with Read The Docs +commands = + sphinx-build --color -b {env:BUILD} -d "{env:BUILDDIR}/doctrees" "{env:DOCSDIR}" "{env:BUILDDIR}/{env:BUILD}" {posargs} + + +[testenv:publish] +description = + Publish the package you have been developing to a package index server. + By default, it uses testpypi. If you really want to publish your package + to be publicly accessible in PyPI, use the `-- --repository pypi` option. 
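+# Typical usage (credentials are taken from the TWINE_* variables below):
+#   tox -e publish                           # upload to TestPyPI
+#   tox -e publish -- --repository pypi      # upload to the real PyPI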
+skip_install = True +changedir = {toxinidir} +passenv = + # See: https://twine.readthedocs.io/en/latest/ + TWINE_USERNAME + TWINE_PASSWORD + TWINE_REPOSITORY + TWINE_REPOSITORY_URL +deps = twine +commands = + python -m twine check dist/* + python -m twine upload {posargs:--repository {env:TWINE_REPOSITORY:testpypi}} dist/* From 529b63ee818c66b1d5158e78a32bca8da09a5c0e Mon Sep 17 00:00:00 2001 From: danibene <34680344+danibene@users.noreply.github.com> Date: Sat, 24 Feb 2024 17:13:51 -0500 Subject: [PATCH 2/2] apply precommit --- equiadapt/common/basecanonicalization.py | 100 ++++++----- equiadapt/common/utils.py | 14 +- .../canonicalization/continuous_group.py | 161 +++++++++-------- .../images/canonicalization/discrete_group.py | 164 +++++++++--------- .../canonicalization_networks/__init__.py | 2 +- .../custom_equivariant_networks.py | 18 +- .../custom_nonequivariant_networks.py | 12 +- .../escnn_networks.py | 82 ++++----- equiadapt/images/utils.py | 18 +- examples/images/classification/README.md | 6 +- .../canonicalization/group_equivariant.yaml | 2 +- .../configs/canonicalization/identity.yaml | 2 +- .../opt_group_equivariant.yaml | 2 +- .../canonicalization/opt_steerable.yaml | 2 +- .../configs/canonicalization/steerable.yaml | 2 +- .../configs/checkpoint/default.yaml | 2 +- .../configs/dataset/default.yaml | 2 +- .../configs/experiment/default.yaml | 4 +- .../group_equivariant/cifar10.yaml | 6 +- .../group_equivariant/rotmnist.yaml | 6 +- .../opt_equivariant/cifar10.yaml | 6 +- .../opt_equivariant/rotmnist.yaml | 6 +- .../original_configs/steerable/cifar10.yaml | 4 +- .../classification/configs/wandb_sweep.yaml | 2 +- .../images/classification/inference_utils.py | 55 +++--- examples/images/classification/model.py | 82 ++++----- examples/images/classification/model_utils.py | 16 +- .../images/classification/prepare/__init__.py | 2 +- .../classification/prepare/celeba_data.py | 4 +- .../classification/prepare/cifar_data.py | 8 +- .../classification/prepare/flowers102_data.py | 2 +- .../classification/prepare/imagenet_data.py | 4 +- .../prepare/rotated_mnist_data.py | 2 +- .../classification/prepare/stl10_data.py | 6 +- examples/images/classification/train.py | 12 +- examples/images/classification/train_utils.py | 60 +++---- examples/images/common/utils.py | 26 +-- examples/images/segmentation/README.md | 8 +- .../canonicalization/group_equivariant.yaml | 2 +- .../configs/canonicalization/identity.yaml | 2 +- .../opt_group_equivariant.yaml | 2 +- .../canonicalization/opt_steerable.yaml | 2 +- .../configs/canonicalization/steerable.yaml | 2 +- .../configs/checkpoint/default.yaml | 2 +- .../segmentation/configs/dataset/default.yaml | 4 +- .../configs/experiment/default.yaml | 4 +- .../group_equivariant/cifar10.yaml | 6 +- .../group_equivariant/rotmnist.yaml | 6 +- .../opt_equivariant/cifar10.yaml | 6 +- .../opt_equivariant/rotmnist.yaml | 6 +- .../original_configs/steerable/cifar10.yaml | 4 +- .../configs/prediction/default.yaml | 1 - .../segmentation/configs/wandb_sweep.yaml | 2 +- .../images/segmentation/inference_utils.py | 69 ++++---- examples/images/segmentation/model.py | 104 +++++------ examples/images/segmentation/model_utils.py | 20 +-- .../images/segmentation/prepare/__init__.py | 2 +- .../images/segmentation/prepare/coco_data.py | 8 +- .../segmentation/prepare/vision_transforms.py | 2 +- examples/images/segmentation/train.py | 12 +- examples/images/segmentation/train_utils.py | 62 +++---- 61 files changed, 617 insertions(+), 625 deletions(-) diff --git 
a/equiadapt/common/basecanonicalization.py b/equiadapt/common/basecanonicalization.py index 3ca61cb..e57c8b9 100644 --- a/equiadapt/common/basecanonicalization.py +++ b/equiadapt/common/basecanonicalization.py @@ -3,7 +3,7 @@ # Base skeleton for the canonicalization class -# DiscreteGroupCanonicalization and ContinuousGroupCanonicalization +# DiscreteGroupCanonicalization and ContinuousGroupCanonicalization # will inherit from this class class BaseCanonicalization(torch.nn.Module): @@ -11,76 +11,76 @@ def __init__(self, canonicalization_network: torch.nn.Module): super().__init__() self.canonicalization_network = canonicalization_network self.canonicalization_info_dict = {} - + def forward(self, x: torch.Tensor, targets: torch.Tensor=None, **kwargs): """ Forward method for the canonicalization which takes the input data and returns the canonicalized version of the data - + Args: x: input data **kwargs: additional arguments - + Returns: canonicalized_x: canonicalized version of the input data """ - + return self.canonicalize(x, targets, **kwargs) - + def canonicalize(self, x: torch.Tensor, targets: torch.Tensor=None, **kwargs): """ - This method takes an input data and + This method takes an input data and returns its canonicalized version and a dictionary containing the information about the canonicalization """ raise NotImplementedError() - + def invert_canonicalization(self, x: torch.Tensor, **kwargs): """ - This method takes the output of the canonicalized data + This method takes the output of the canonicalized data and returns the output for the original data orientation """ raise NotImplementedError() - + class IdentityCanonicalization(BaseCanonicalization): def __init__(self, canonicalization_network: torch.nn.Module = torch.nn.Identity()): super().__init__(canonicalization_network) - + def canonicalize(self, x: torch.Tensor, targets: torch.Tensor=None, **kwargs): if targets: return x, targets return x - + def invert_canonicalization(self, x: torch.Tensor, **kwargs): return x - + def get_prior_regularization_loss(self): return torch.tensor(0.0) - + def get_identity_metric(self): return torch.tensor(1.0) - + class DiscreteGroupCanonicalization(BaseCanonicalization): - def __init__(self, - canonicalization_network: torch.nn.Module, + def __init__(self, + canonicalization_network: torch.nn.Module, beta: float = 1.0, gradient_trick: str = 'straight_through'): super().__init__(canonicalization_network) self.beta = beta self.gradient_trick = gradient_trick - + def groupactivations_to_groupelementonehot(self, group_activations: torch.Tensor): """ This method takes the activations for each group element as input and returns the group element in a differentiable manner - + Args: group_activations: activations for each group element - + Returns: group_element_onehot: one hot encoding of the group element """ @@ -88,103 +88,103 @@ def groupactivations_to_groupelementonehot(self, group_activations: torch.Tensor torch.argmax(group_activations, dim=-1), self.num_group).float() group_activations_soft = torch.nn.functional.softmax(self.beta * group_activations, dim=-1) if self.gradient_trick == 'straight_through': - if self.training: - group_element_onehot = (group_activations_one_hot + group_activations_soft - group_activations_soft.detach()) + if self.training: + group_element_onehot = (group_activations_one_hot + group_activations_soft - group_activations_soft.detach()) else: group_element_onehot = group_activations_one_hot elif self.gradient_trick == 'gumbel_softmax': group_element_onehot = 
torch.nn.functional.gumbel_softmax(group_activations, tau=1, hard=True) else: - raise ValueError(f'Gradient trick {self.gradient_trick} not implemented') - + raise ValueError(f'Gradient trick {self.gradient_trick} not implemented') + # return the group element one hot encoding return group_element_onehot - + def canonicalize(self, x: torch.Tensor, targets: torch.Tensor=None, **kwargs): """ - This method takes an input data and + This method takes an input data and returns its canonicalized version and a dictionary containing the information about the canonicalization """ raise NotImplementedError() - + def invert_canonicalization(self, x: torch.Tensor, **kwargs): """ - This method takes the output of the canonicalized data + This method takes the output of the canonicalized data and returns the output for the original data orientation """ raise NotImplementedError() - - + + def get_prior_regularization_loss(self): group_activations = self.canonicalization_info_dict['group_activations'] dataset_prior = torch.zeros((group_activations.shape[0],), dtype=torch.long).to(self.device) return torch.nn.CrossEntropyLoss()(group_activations, dataset_prior) - - + + def get_identity_metric(self): group_activations = self.canonicalization_info_dict['group_activations'] return (group_activations.argmax(dim=-1) == 0).float().mean() - + class ContinuousGroupCanonicalization(BaseCanonicalization): - def __init__(self, - canonicalization_network: torch.nn.Module, + def __init__(self, + canonicalization_network: torch.nn.Module, beta: float = 1.0, gradient_trick: str = 'straight_through'): super().__init__(canonicalization_network) self.beta = beta self.gradient_trick = gradient_trick - + def canonicalizationnetworkout_to_groupelement(self, group_activations: torch.Tensor): """ This method takes the as input and returns the group element in a differentiable manner - + Args: group_activations: activations for each group element - + Returns: group_element: group element """ raise NotImplementedError() - + def canonicalize(self, x: torch.Tensor, targets: torch.Tensor=None, **kwargs): """ - This method takes an input data and + This method takes an input data and returns its canonicalized version and a dictionary containing the information about the canonicalization """ raise NotImplementedError() - + def invert_canonicalization(self, x: torch.Tensor, **kwargs): """ - This method takes the output of the canonicalized data + This method takes the output of the canonicalized data and returns the output for the original data orientation """ raise NotImplementedError() - - + + def get_prior_regularization_loss(self): group_elements_rep = self.canonicalization_info_dict['group_element_matrix_representation'] # shape: (batch_size, group_rep_dim, group_rep_dim) # Set the dataset prior to identity matrix of size group_rep_dim and repeat it for batch_size dataset_prior = torch.eye(group_elements_rep.shape[-1]).repeat( group_elements_rep.shape[0], 1, 1).to(self.device) return torch.nn.MSELoss()(group_elements_rep, dataset_prior) - + def get_identity_metric(self): group_elements_rep = self.canonicalization_info_dict['group_element_matrix_representation'] identity_element = torch.eye(group_elements_rep.shape[-1]).repeat( group_elements_rep.shape[0], 1, 1).to(self.device) return 1.0 - torch.nn.functional.mse_loss(group_elements_rep, identity_element).mean() - - - + + + # Idea for the user interface: @@ -193,7 +193,7 @@ def get_identity_metric(self): # example: canonicalization_network = ESCNNEquivariantNetwork(in_shape, 
out_channels, kernel_size, group_type='rotation', num_rotations=4, num_layers=3) # canonicalizer = GroupEquivariantImageCanonicalization(canonicalization_network, beta=1.0) # -# +# # 2. The user uses this wrapper with their code to canonicalize the input data # example: model = ResNet18() # x_canonized = canonicalizer(x) @@ -206,5 +206,3 @@ def get_identity_metric(self): # loss = criterion(model_out, y) # loss = canonicalizer.add_prior_regularizer(loss) # loss.backward() - - \ No newline at end of file diff --git a/equiadapt/common/utils.py b/equiadapt/common/utils.py index e22d3d0..9556568 100644 --- a/equiadapt/common/utils.py +++ b/equiadapt/common/utils.py @@ -69,7 +69,7 @@ def get_son_rep(self, params: torch.Tensor): son_bases = self.get_son_bases().to(params.device) A = torch.einsum('bs,sij->bij', params, son_bases) return torch.matrix_exp(A) - + def get_on_rep(self, params: torch.Tensor, reflect_indicators: torch.Tensor): """ Computes the representation for O(n) group, optionally including reflections. @@ -82,7 +82,7 @@ def get_on_rep(self, params: torch.Tensor, reflect_indicators: torch.Tensor): torch.Tensor: The representation of shape (batch_size, rep_dim, rep_dim). """ son_rep = self.get_son_rep(params) - + # This is a simplified and conceptual approach; actual reflection handling # would need to determine how to reflect (e.g., across which axis or plane) # and this might not directly apply as-is. @@ -90,7 +90,7 @@ def get_on_rep(self, params: torch.Tensor, reflect_indicators: torch.Tensor): reflection_matrix = torch.diag_embed(torch.tensor([1] * (self.group_dim - 1) + [-1])) on_rep = torch.matmul(son_rep, reflect_indicators * reflection_matrix + (1 - reflect_indicators) * identity_matrix) return on_rep - + def get_sen_rep(self, params: torch.Tensor): """Computes the representation for SEn group. @@ -101,14 +101,14 @@ def get_sen_rep(self, params: torch.Tensor): torch.Tensor: The representation of shape (batch_size, rep_dim, rep_dim). """ son_param_dim = self.group_dim * (self.group_dim - 1) // 2 - rho = torch.zeros(params.shape[0], self.group_dim + 1, + rho = torch.zeros(params.shape[0], self.group_dim + 1, self.group_dim + 1, device=params.device) rho[:, :self.group_dim, :self.group_dim] = self.get_son_rep( params[:, :son_param_dim].unsqueeze(0)).squeeze(0) rho[:, :self.group_dim, self.group_dim] = params[:, son_param_dim:] rho[:, self.group_dim, self.group_dim] = 1 return rho - + def get_en_rep(self, params: torch.Tensor, reflect_indicators: torch.Tensor): """Computes the representation for E(n) group. @@ -146,7 +146,7 @@ def get_en_rep(self, params: torch.Tensor, reflect_indicators: torch.Tensor): en_rep[:, self.group_dim, self.group_dim] = 1 return en_rep - + def get_group_rep(self, params): """Computes the representation for the specified Lie group. 
@@ -167,5 +167,3 @@ def get_group_rep(self, params): return self.get_en_rep(params) else: raise ValueError(f"Unsupported group type: {self.group_type}") - - diff --git a/equiadapt/images/canonicalization/continuous_group.py b/equiadapt/images/canonicalization/continuous_group.py index f1c0a49..bf25bbe 100644 --- a/equiadapt/images/canonicalization/continuous_group.py +++ b/equiadapt/images/canonicalization/continuous_group.py @@ -8,15 +8,15 @@ from torch.nn import functional as F class ContinuousGroupImageCanonicalization(ContinuousGroupCanonicalization): - def __init__(self, - canonicalization_network: torch.nn.Module, + def __init__(self, + canonicalization_network: torch.nn.Module, canonicalization_hyperparams: dict, in_shape: tuple ): super().__init__(canonicalization_network) - + assert len(in_shape) == 3, 'Input shape should be in the format (channels, height, width)' - + # pad and crop the input image if it is not rotated MNIST is_grayscale = (in_shape[0] == 1) self.pad = torch.nn.Identity() if is_grayscale else transforms.Pad( @@ -24,137 +24,137 @@ def __init__(self, ) self.crop = torch.nn.Identity() if is_grayscale else transforms.CenterCrop((in_shape[-2], in_shape[-1])) self.crop_canonization = torch.nn.Identity() if is_grayscale else transforms.CenterCrop(( - math.ceil(in_shape[-2] * canonicalization_hyperparams.input_crop_ratio), + math.ceil(in_shape[-2] * canonicalization_hyperparams.input_crop_ratio), math.ceil(in_shape[-1] * canonicalization_hyperparams.input_crop_ratio) )) self.resize_canonization = torch.nn.Identity() if is_grayscale else transforms.Resize(size=canonicalization_hyperparams.resize_shape) self.group_info_dict = {} - + def get_groupelement(self, x: torch.Tensor): """ This method takes the input image and maps it to the group element - + Args: x: input image - + Returns: group_element: group element """ raise NotImplementedError('get_groupelement method is not implemented') - + def transformations_before_canonicalization_network_forward(self, x: torch.Tensor): """ - This method takes an image as input and - returns the pre-canonicalized image + This method takes an image as input and + returns the pre-canonicalized image """ x = self.crop_canonization(x) x = self.resize_canonization(x) return x - + def get_group_from_out_vectors(self, out_vectors: torch.Tensor): """ This method takes the output of the canonicalization network and returns the group element - + Args: out_vectors: output of the canonicalization network - + Returns: group_element_dict: group element group_element_representation: group element representation """ group_element_dict = {} - + if self.group_type == 'roto-reflection': # Apply Gram-Schmidt to get the rotation matrices/orthogonal frame from # a batch of two 2D vectors rotoreflection_matrices = gram_schmidt(out_vectors) # (batch_size, 2, 2) - + # Calculate the determinant to check for reflection determinant = rotoreflection_matrices[:, 0, 0] * rotoreflection_matrices[:, 1, 1] - \ rotoreflection_matrices[:, 0, 1] * rotoreflection_matrices[:, 1, 0] - + reflect_indicator = (1 - determinant[:, None, None, None]) / 2 group_element_dict['reflection'] = reflect_indicator - + # Identify matrices with a reflection (negative determinant) reflection_indices = determinant < 0 # For matrices with a reflection, adjust to remove the reflection component # This example assumes flipping the sign of the second column as one way to adjust # Note: This method of adjustment is context-dependent and may vary based on your specific requirements - 
rotation_matrices = rotoreflection_matrices - rotation_matrices[reflection_indices, :, 1] *= -1 + rotation_matrices = rotoreflection_matrices + rotation_matrices[reflection_indices, :, 1] *= -1 else: # Pass the first vector to get the rotation matrix rotation_matrices = self.get_rotation_matrix_from_vector(out_vectors[:, 0]) - + group_element_dict['rotation'] = rotation_matrices - + return group_element_dict, rotoreflection_matrices if self.group_type == 'roto-reflection' else rotation_matrices - - + + def canonicalize(self, x: torch.Tensor): """ - This method takes an image as input and - returns the canonicalized image - + This method takes an image as input and + returns the canonicalized image + Args: x: input image - + Returns: x_canonicalized: canonicalized image """ self.device = x.device - + # get the group element dictionary with keys as 'rotation' and 'reflection' - group_element_dict = self.get_groupelement(x) - + group_element_dict = self.get_groupelement(x) + rotation_matrices = group_element_dict['rotation'] rotation_matrices[:, [0, 1], [1, 0]] *= -1 - + if 'reflection' in group_element_dict: reflect_indicator = group_element_dict['reflection'] # Reflect the image conditionally x = (1 - reflect_indicator) * x + reflect_indicator * K.geometry.hflip(x) - - + + # Apply padding before canonicalization x = self.pad(x) - + # Compute affine part for warp affine alpha, beta = rotation_matrices[:, 0, 0], rotation_matrices[:, 0, 1] cx, cy = x.shape[-2] // 2, x.shape[-1] // 2 affine_part = torch.stack([(1 - alpha) * cx - beta * cy, beta * cx + (1 - alpha) * cy], dim=1) - + # Prepare affine matrices for warp affine, adjusting rotation matrix for Kornia compatibility affine_matrices = torch.cat([rotation_matrices, affine_part.unsqueeze(-1)], dim=-1) - - # Apply warp affine, and then crop + + # Apply warp affine, and then crop x = K.geometry.warp_affine(x, affine_matrices, dsize=(x.shape[-2], x.shape[-1])) x = self.crop(x) return x - + def invert_canonicalization(self, x_canonicalized_out: torch.Tensor, induced_rep_type: str = 'vector'): """ This method takes the output of canonicalized image as input and returns output of the original image - + """ return get_action_on_image_features(feature_map = x_canonicalized_out, group_info_dict = self.group_info_dict, group_element_dict = self.canonicalization_info_dict['group_element'], induced_rep_type = induced_rep_type) - + class SteerableImageCanonicalization(ContinuousGroupImageCanonicalization): - def __init__(self, - canonicalization_network: torch.nn.Module, + def __init__(self, + canonicalization_network: torch.nn.Module, canonicalization_hyperparams: dict, in_shape: tuple ): @@ -162,14 +162,14 @@ def __init__(self, canonicalization_hyperparams, in_shape) self.group_type = canonicalization_network.group_type - + def get_rotation_matrix_from_vector(self, vectors: torch.Tensor): ''' This method takes the input vector and returns the rotation matrix - + Args: vectors: input vector - + Returns: rotation_matrices: rotation matrices ''' @@ -177,27 +177,27 @@ def get_rotation_matrix_from_vector(self, vectors: torch.Tensor): v2 = torch.stack([-v1[:, 1], v1[:, 0]], dim=1) rotation_matrices = torch.stack([v1, v2], dim=1) return rotation_matrices - + def get_groupelement(self, x: torch.Tensor): """ This method takes the input image and maps it to the group element - + Args: x: input image - + Returns: group_element: group element """ - + group_element_dict = {} - + x = self.transformations_before_canonicalization_network_forward(x) - + # convert the 
group activations to one hot encoding of group element # this conversion is differentiable and will be used to select the group element out_vectors = self.canonicalization_network(x) - + # Check whether canonicalization_info_dict is already defined if not hasattr(self, 'canonicalization_info_dict'): self.canonicalization_info_dict = {} @@ -206,13 +206,13 @@ def get_groupelement(self, x: torch.Tensor): self.canonicalization_info_dict['group_element_matrix_representation'] = group_element_representation self.canonicalization_info_dict['group_element'] = group_element_dict - + return group_element_dict - + class OptimizedSteerableImageCanonicalization(ContinuousGroupImageCanonicalization): - def __init__(self, - canonicalization_network: torch.nn.Module, + def __init__(self, + canonicalization_network: torch.nn.Module, canonicalization_hyperparams: dict, in_shape: tuple ): @@ -220,14 +220,14 @@ def __init__(self, canonicalization_hyperparams, in_shape) self.group_type = canonicalization_hyperparams.group_type - + def get_rotation_matrix_from_vector(self, vectors: torch.Tensor): ''' This method takes the input vector and returns the rotation matrix - + Args: vectors: input vector - + Returns: rotation_matrices: rotation matrices ''' @@ -235,7 +235,7 @@ def get_rotation_matrix_from_vector(self, vectors: torch.Tensor): v2 = torch.stack([-v1[:, 1], v1[:, 0]], dim=1) rotation_matrices = torch.stack([v1, v2], dim=1) return rotation_matrices - + def group_augment(self, x): """ Augmentation of the input images by applying random rotations and, @@ -274,58 +274,58 @@ def group_augment(self, x): # Return augmented images and the transformation matrices used return augmented_images, rotation_matrices[:, :, :2] - + def get_groupelement(self, x: torch.Tensor): """ This method takes the input image and maps it to the group element - + Args: x: input image - + Returns: group_element: group element """ - + group_element_dict = {} - + batch_size = x.shape[0] - + # randomly sample generate some agmentations of the input image using rotation and reflection - + x_augmented, group_element_representations_augmented_gt = self.group_augment(x) # size (batch_size * group_size, in_channels, height, width) - + x_all = torch.cat([x, x_augmented], dim=0) # size (batch_size * 2, in_channels, height, width) - + x_all = self.transformations_before_canonicalization_network_forward(x_all) - + out_vectors_all = self.canonicalization_network(x_all) # size (batch_size * 2, out_vector_size) - + out_vectors_all = out_vectors_all.reshape(2 * batch_size, -1, 2) # size (batch_size * 2, num_vectors, 2) - + out_vectors, out_vectors_augmented = out_vectors_all.chunk(2, dim=0) - + # Check whether canonicalization_info_dict is already defined if not hasattr(self, 'canonicalization_info_dict'): - self.canonicalization_info_dict = {} - + self.canonicalization_info_dict = {} + group_element_dict, group_element_representations = self.get_group_from_out_vectors(out_vectors) # Store the matrix representation of the group element for regularization and identity metric self.canonicalization_info_dict['group_element_matrix_representation'] = group_element_representations self.canonicalization_info_dict['group_element'] = group_element_dict - + _, group_element_representations_augmented = self.get_group_from_out_vectors(out_vectors_augmented) self.canonicalization_info_dict['group_element_matrix_representation_augmented'] = \ group_element_representations_augmented self.canonicalization_info_dict['group_element_matrix_representation_augmented_gt'] = 
\ group_element_representations_augmented_gt - + return group_element_dict - + def get_optimization_specific_loss(self): """ This method returns the optimization specific loss - + Returns: loss: optimization specific loss """ @@ -333,4 +333,3 @@ def get_optimization_specific_loss(self): self.canonicalization_info_dict['group_element_matrix_representation_augmented'], \ self.canonicalization_info_dict['group_element_matrix_representation_augmented_gt'] return F.mse_loss(group_element_representations_augmented, group_element_representations_augmented_gt) - \ No newline at end of file diff --git a/equiadapt/images/canonicalization/discrete_group.py b/equiadapt/images/canonicalization/discrete_group.py index 9a5699c..d84b78f 100644 --- a/equiadapt/images/canonicalization/discrete_group.py +++ b/equiadapt/images/canonicalization/discrete_group.py @@ -7,142 +7,142 @@ from torch.nn import functional as F class DiscreteGroupImageCanonicalization(DiscreteGroupCanonicalization): - def __init__(self, - canonicalization_network: torch.nn.Module, + def __init__(self, + canonicalization_network: torch.nn.Module, canonicalization_hyperparams: dict, in_shape: tuple ): super().__init__(canonicalization_network) - + self.beta = canonicalization_hyperparams.beta - + assert len(in_shape) == 3, "Input shape should be in the format (channels, height, width)" - + # DEfine all the image transformations here which are used during canonicalization # pad and crop the input image if it is not rotated MNIST is_grayscale = (in_shape[0] == 1) - + self.pad = torch.nn.Identity() if is_grayscale else transforms.Pad(math.ceil(in_shape[-1] * 0.5), padding_mode='edge') self.crop = torch.nn.Identity() if is_grayscale else transforms.CenterCrop((in_shape[-2], in_shape[-1])) - + self.crop_canonization = torch.nn.Identity() if is_grayscale else transforms.CenterCrop(( - math.ceil(in_shape[-2] * canonicalization_hyperparams.input_crop_ratio), + math.ceil(in_shape[-2] * canonicalization_hyperparams.input_crop_ratio), math.ceil(in_shape[-1] * canonicalization_hyperparams.input_crop_ratio) )) - + self.resize_canonization = torch.nn.Identity() if is_grayscale else transforms.Resize(size=canonicalization_hyperparams.resize_shape) - + def groupactivations_to_groupelement(self, group_activations: torch.Tensor): """ This method takes the activations for each group element as input and returns the group element - + Args: group_activations: activations for each group element - + Returns: group_element: group element """ - + # convert the group activations to one hot encoding of group element # this conversion is differentiable and will be used to select the group element group_elements_one_hot = self.groupactivations_to_groupelementonehot(group_activations) - + angles = torch.linspace(0., 360., self.num_rotations+1)[:self.num_rotations].to(self.device) group_elements_rot_comp = torch.cat([angles, angles], dim=0) if self.group_type == "roto-reflection" else angles - + group_element_dict = {} - + group_element_rot_comp = torch.sum(group_elements_one_hot * group_elements_rot_comp, dim=-1) group_element_dict["rotation"] = group_element_rot_comp if self.group_type == "roto-reflection": - reflect_identifier_vector = torch.cat([torch.zeros(self.num_rotations), + reflect_identifier_vector = torch.cat([torch.zeros(self.num_rotations), torch.ones(self.num_rotations)], dim=0).to(self.device) group_element_reflect_comp = torch.sum(group_elements_one_hot * reflect_identifier_vector, dim=-1) group_element_dict["reflection"] = group_element_reflect_comp - 
+ return group_element_dict - + def get_group_activations(self, x: torch.Tensor): """ - This method takes an image as input and + This method takes an image as input and returns the group activations """ - raise NotImplementedError("get_group_activations is not implemented for" + raise NotImplementedError("get_group_activations is not implemented for" "the DiscreteGroupImageCanonicalization class") - - + + def get_groupelement(self, x: torch.Tensor): """ This method takes the input image and maps it to the group element - + Args: x: input image - + Returns: group_element: group element """ group_activations = self.get_group_activations(x) group_element_dict = self.groupactivations_to_groupelement(group_activations) - + # Check whether canonicalization_info_dict is already defined if not hasattr(self, "canonicalization_info_dict"): self.canonicalization_info_dict = {} self.canonicalization_info_dict["group_element"] = group_element_dict self.canonicalization_info_dict["group_activations"] = group_activations - + return group_element_dict - + def transformations_before_canonicalization_network_forward(self, x: torch.Tensor): """ - This method takes an image as input and - returns the pre-canonicalized image + This method takes an image as input and + returns the pre-canonicalized image """ x = self.crop_canonization(x) x = self.resize_canonization(x) return x - - + + def canonicalize(self, x: torch.Tensor, targets: torch.Tensor = None): """ - This method takes an image as input and - returns the canonicalized image + This method takes an image as input and + returns the canonicalized image """ self.device = x.device group_element_dict = self.get_groupelement(x) - + x = self.pad(x) - + if "reflection" in group_element_dict.keys(): reflect_indicator = group_element_dict["reflection"][:,None,None,None] x = (1 - reflect_indicator) * x + reflect_indicator * K.geometry.hflip(x) x = K.geometry.rotate(x, -group_element_dict["rotation"]) - + x = self.crop(x) - + if targets: # canonicalize the targets (for instance segmentation, masks and boxes) image_width = x.shape[-1] - + if 'reflection' in group_element_dict.keys(): # flip masks and boxes for t in range(len(targets)): targets[t]["boxes"] = flip_boxes(targets[t]["boxes"], image_width) targets[t]["masks"] = flip_masks(targets[t]["masks"]) - + # rotate masks and boxes for t in range(len(targets)): targets[t]["boxes"] = rotate_boxes(targets[t]["boxes"], group_element_dict["rotation"][t], image_width) targets[t]["masks"] = rotate_masks(targets[t]["masks"], -group_element_dict["rotation"][t].item()) - + return x, targets - + return x - + def invert_canonicalization(self, x_canonicalized_out: torch.Tensor, induced_rep_type: str = "regular"): """ This method takes the output of canonicalized image as input and @@ -152,13 +152,13 @@ def invert_canonicalization(self, x_canonicalized_out: torch.Tensor, induced_rep group_info_dict = self.group_info_dict, group_element_dict = self.canonicalization_info_dict["group_element"], induced_rep_type = induced_rep_type) - - - + + + class GroupEquivariantImageCanonicalization(DiscreteGroupImageCanonicalization): - def __init__(self, - canonicalization_network: torch.nn.Module, + def __init__(self, + canonicalization_network: torch.nn.Module, canonicalization_hyperparams: dict, in_shape: tuple ): @@ -170,21 +170,21 @@ def __init__(self, self.num_group = self.num_rotations if self.group_type == "rotation" else 2 * self.num_rotations self.group_info_dict = {"num_rotations": self.num_rotations, "num_group": self.num_group} 
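The hunks above cover DiscreteGroupImageCanonicalization.canonicalize, which pads the input, optionally flips it, rotates it by the predicted group element, and crops back to the original size (rotating boxes and masks as well when segmentation targets are passed). A minimal usage sketch of the discrete canonicalizer, assuming escnn is installed and reusing the beta, input_crop_ratio, and resize_shape values from the example group_equivariant.yaml; the SimpleNamespace stand-in for the Hydra config object and the batch size are illustrative assumptions, not part of the library:

from types import SimpleNamespace

import torch

from equiadapt.images.canonicalization.discrete_group import GroupEquivariantImageCanonicalization
from equiadapt.images.canonicalization_networks import ESCNNEquivariantNetwork

in_shape = (3, 32, 32)  # (channels, height, width), as asserted in __init__
canonicalization_network = ESCNNEquivariantNetwork(
    in_shape, out_channels=16, kernel_size=3,
    group_type='rotation', num_rotations=4, num_layers=3,
)
# hypothetical config object; the real examples pass a Hydra DictConfig with these fields
hyperparams = SimpleNamespace(beta=1.0, input_crop_ratio=0.8, resize_shape=32)

canonicalizer = GroupEquivariantImageCanonicalization(
    canonicalization_network, hyperparams, in_shape,
)

x = torch.randn(8, *in_shape)        # dummy batch
x_canonicalized = canonicalizer(x)   # canonicalized batch, same shape as x

The same wrapper then exposes invert_canonicalization and get_prior_regularization_loss, matching the training-loop sketch in the basecanonicalization.py comments earlier in this patch.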
- + def get_group_activations(self, x: torch.Tensor): """ - This method takes an image as input and + This method takes an image as input and returns the group activations """ x = self.transformations_before_canonicalization_network_forward(x) group_activations = self.canonicalization_network(x) return group_activations - - - + + + class OptimizedGroupEquivariantImageCanonicalization(DiscreteGroupImageCanonicalization): - def __init__(self, - canonicalization_network: torch.nn.Module, + def __init__(self, + canonicalization_network: torch.nn.Module, canonicalization_hyperparams: dict, in_shape: tuple ): @@ -196,20 +196,20 @@ def __init__(self, self.artifact_err_wt = canonicalization_hyperparams.artifact_err_wt self.num_group = self.num_rotations if self.group_type == "rotation" else 2 * self.num_rotations self.out_vector_size = canonicalization_network.out_vector_size - + # group optimization specific cropping and padding (required for group_augment()) group_augment_in_shape = canonicalization_hyperparams.resize_shape self.crop_group_augment = torch.nn.Identity() if in_shape[0] == 1 else transforms.CenterCrop(group_augment_in_shape) self.pad_group_augment = torch.nn.Identity() if in_shape[0] == 1 else transforms.Pad(math.ceil(group_augment_in_shape * 0.5), padding_mode='edge') - - + + self.reference_vector = torch.nn.Parameter( - torch.randn(1, self.out_vector_size), + torch.randn(1, self.out_vector_size), requires_grad=canonicalization_hyperparams.learn_ref_vec ) self.group_info_dict = {"num_rotations": self.num_rotations, "num_group": self.num_group} - + def rotate_and_maybe_reflect(self, x: torch.Tensor, degrees: torch.Tensor, reflect: bool = False): x_augmented_list = [] for degree in degrees: @@ -220,68 +220,68 @@ def rotate_and_maybe_reflect(self, x: torch.Tensor, degrees: torch.Tensor, refle x_rot = self.crop_group_augment(x_rot) x_augmented_list.append(x_rot) return x_augmented_list - - + + def group_augment(self, x : torch.Tensor): - + degrees = torch.linspace(0, 360, self.num_rotations + 1)[:-1].to(self.device) x_augmented_list = self.rotate_and_maybe_reflect(x, degrees) - + if self.group_type == "roto-reflection": x_augmented_list += self.rotate_and_maybe_reflect(x, degrees, reflect=True) - + return torch.cat(x_augmented_list, dim=0) - + def get_group_activations(self, x: torch.Tensor): """ - This method takes an image as input and + This method takes an image as input and returns the group activations """ - - x = self.transformations_before_canonicalization_network_forward(x) + + x = self.transformations_before_canonicalization_network_forward(x) x_augmented = self.group_augment(x) # size (batch_size * group_size, in_channels, height, width) vector_out = self.canonicalization_network(x_augmented) # size (batch_size * group_size, reference_vector_size) self.canonicalization_info_dict = {"vector_out": vector_out} - + if self.artifact_err_wt: # select a random rotation for each image in the batch rotation_indices = torch.randint(0, self.num_rotations, (x_augmented.shape[0],)).to(self.device) - + # apply the rotation degree to the images x_dummy = self.pad_group_augment(x_augmented) x_dummy = K.geometry.rotate(x_dummy, -rotation_indices * 360 / self.num_rotations) x_dummy = self.crop_group_augment(x_dummy) - + # invert the image back to the original orientation x_dummy = self.pad_group_augment(x_dummy) x_dummy = K.geometry.rotate(x_dummy, rotation_indices * 360 / self.num_rotations) x_dummy = self.crop_group_augment(x_dummy) - + vector_out_dummy = 
self.canonicalization_network(x_dummy) # size (batch_size * group_size, reference_vector_size) self.canonicalization_info_dict.update({"vector_out_dummy": vector_out_dummy}) - + scalar_out = F.cosine_similarity( - self.reference_vector.repeat(vector_out.shape[0], 1), + self.reference_vector.repeat(vector_out.shape[0], 1), vector_out ) # size (batch_size * group_size, 1) group_activations = scalar_out.reshape(self.num_group, -1).T # size (batch_size, group_size) return group_activations - - + + def get_optimization_specific_loss(self): vectors = self.canonicalization_info_dict["vector_out"] - + # compute error to reduce rotation artifacts rotation_artifact_error = 0 if self.artifact_err_wt: vectors_dummy = self.canonicalization_info_dict["vector_out_dummy"] rotation_artifact_error = torch.nn.functional.mse_loss(vectors_dummy, vectors) - + # error to ensure that the vectors are (as much as possible) orthogonal vectors = vectors.reshape(self.num_group, -1, self.out_vector_size).permute((1, 0, 2)) # (batch_size, group_size, vector_out_size) distances = vectors @ vectors.permute((0, 2, 1)) mask = 1.0 - torch.eye(self.num_group).to(self.device) # (group_size, group_size) - - - return torch.abs(distances * mask).mean() + self.artifact_err_wt * rotation_artifact_error \ No newline at end of file + + + return torch.abs(distances * mask).mean() + self.artifact_err_wt * rotation_artifact_error diff --git a/equiadapt/images/canonicalization_networks/__init__.py b/equiadapt/images/canonicalization_networks/__init__.py index ee3760c..7d5dec1 100644 --- a/equiadapt/images/canonicalization_networks/__init__.py +++ b/equiadapt/images/canonicalization_networks/__init__.py @@ -1,3 +1,3 @@ from .escnn_networks import ESCNNEquivariantNetwork, ESCNNSteerableNetwork, ESCNNWRNEquivariantNetwork from .custom_nonequivariant_networks import ConvNetwork -from .custom_equivariant_networks import CustomEquivariantNetwork \ No newline at end of file +from .custom_equivariant_networks import CustomEquivariantNetwork diff --git a/equiadapt/images/canonicalization_networks/custom_equivariant_networks.py b/equiadapt/images/canonicalization_networks/custom_equivariant_networks.py index 4fbdaa8..b3d6d21 100644 --- a/equiadapt/images/canonicalization_networks/custom_equivariant_networks.py +++ b/equiadapt/images/canonicalization_networks/custom_equivariant_networks.py @@ -223,18 +223,18 @@ def forward(self, x): if self.bias is not None: x = x + self.bias[None, :, None, None, None] return x - + class CustomEquivariantNetwork(nn.Module): - def __init__(self, - in_shape, - out_channels, - kernel_size, - group_type='rotation', - num_rotations=4, + def __init__(self, + in_shape, + out_channels, + kernel_size, + group_type='rotation', + num_rotations=4, num_layers=1, device='cuda' if torch.cuda.is_available() else 'cpu'): super().__init__() - + if group_type == 'rotation': layer_list = [RotationEquivariantConvLift(in_shape[0], out_channels, kernel_size, num_rotations, device=device)] for i in range(num_layers - 1): @@ -257,5 +257,5 @@ def forward(self, x): """ feature_map = self.eqv_network(x) group_activatiobs = torch.mean(feature_map, dim=(1, 3, 4)) - + return group_activatiobs diff --git a/equiadapt/images/canonicalization_networks/custom_nonequivariant_networks.py b/equiadapt/images/canonicalization_networks/custom_nonequivariant_networks.py index 5d47147..f690c9a 100644 --- a/equiadapt/images/canonicalization_networks/custom_nonequivariant_networks.py +++ 
b/equiadapt/images/canonicalization_networks/custom_nonequivariant_networks.py @@ -2,11 +2,11 @@ from torch import nn class ConvNetwork(nn.Module): - def __init__(self, - in_shape, - out_channels, - kernel_size, - num_layers=2, + def __init__(self, + in_shape, + out_channels, + kernel_size, + num_layers=2, out_vector_size=128): super().__init__() @@ -44,4 +44,4 @@ def forward(self, x): batch_size = x.shape[0] out = self.enc_network(x) out = out.reshape(batch_size, -1) - return self.final_fc(out) \ No newline at end of file + return self.final_fc(out) diff --git a/equiadapt/images/canonicalization_networks/escnn_networks.py b/equiadapt/images/canonicalization_networks/escnn_networks.py index 36e6b0a..7114917 100644 --- a/equiadapt/images/canonicalization_networks/escnn_networks.py +++ b/equiadapt/images/canonicalization_networks/escnn_networks.py @@ -3,12 +3,12 @@ from escnn import gspaces class ESCNNEquivariantNetwork(torch.nn.Module): - def __init__(self, - in_shape, - out_channels, - kernel_size, - group_type='rotation', - num_rotations=4, + def __init__(self, + in_shape, + out_channels, + kernel_size, + group_type='rotation', + num_rotations=4, num_layers=1): super().__init__() @@ -24,13 +24,13 @@ def __init__(self, self.gspace = gspaces.flipRot2dOnR2(num_rotations) else: raise ValueError('group_type must be rotation or roto-reflection for now.') - + # If the group is roto-reflection, then the number of group elements is twice the number of rotations self.num_group_elements = num_rotations if group_type == 'rotation' else 2 * num_rotations r1 = escnn.nn.FieldType(self.gspace, [self.gspace.trivial_repr] * self.in_channels) r2 = escnn.nn.FieldType(self.gspace, [self.gspace.regular_repr] * out_channels) - + self.in_type = r1 self.out_type = r2 @@ -45,14 +45,14 @@ def __init__(self, self.eqv_network.append(escnn.nn.InnerBatchNorm(self.out_type, momentum=0.9),) self.eqv_network.append(escnn.nn.ReLU(self.out_type, inplace=True),) self.eqv_network.append(escnn.nn.PointwiseDropout(self.out_type, p=0.5),) - + self.eqv_network.append(escnn.nn.R2Conv(self.out_type, self.out_type, kernel_size),) - + def forward(self, x): """ - The forward takes an image as input and returns the activations of + The forward takes an image as input and returns the activations of each group element as output. - + x shape: (batch_size, in_channels, height, width) :return: (batch_size, group_size) """ @@ -61,24 +61,24 @@ def forward(self, x): feature_map = out.tensor feature_map = feature_map.reshape( - feature_map.shape[0], self.out_channels, self.num_group_elements, + feature_map.shape[0], self.out_channels, self.num_group_elements, feature_map.shape[-2], feature_map.shape[-1] ) - + group_activations = torch.mean(feature_map, dim=(1, 3, 4)) return group_activations - + class ESCNNSteerableNetwork(torch.nn.Module): - def __init__(self, - in_shape: tuple, - out_channels: int, - kernel_size: int = 9, - group_type: str = 'rotation', + def __init__(self, + in_shape: tuple, + out_channels: int, + kernel_size: int = 9, + group_type: str = 'rotation', num_layers: int = 1): super().__init__() - + self.group_type = group_type assert group_type == 'rotation', 'group_type must be rotation for now.' 
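The ESCNNEquivariantNetwork.forward docstring just above fixes the contract that the discrete canonicalizers rely on: a batch of images maps to one activation per group element, i.e. a tensor of shape (batch_size, group_size). A small shape-check sketch under that contract, assuming escnn is installed; the batch size and the 1-channel 28x28 rotated-MNIST-like input are illustrative assumptions:

import torch

from equiadapt.images.canonicalization_networks import ESCNNEquivariantNetwork

net = ESCNNEquivariantNetwork(
    in_shape=(1, 28, 28), out_channels=16, kernel_size=3,
    group_type='roto-reflection', num_rotations=4, num_layers=2,
)

x = torch.randn(4, 1, 28, 28)
group_activations = net(x)
# expected: torch.Size([4, 8]), since 4 rotations times 2 (with reflections) gives 8 group elements
print(group_activations.shape)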
# TODO: Add support for roto-reflection group @@ -88,7 +88,7 @@ def __init__(self, # The input image is a scalar field, corresponding to the trivial representation in_type = escnn.nn.FieldType(self.gspace, in_shape[0] * [self.gspace.trivial_repr]) - + # Store the input type for wrapping the images into a geometric tensor during the forward pass self.input_type = in_type @@ -117,7 +117,7 @@ def forward(self, x : torch.Tensor): x = torch.mean(x, dim=(-1, -2)) # Average over spatial dimensions x = x.reshape(x.shape[0], 2, 2) # Reshape to get vector/vectors of dimension 2 return x - + # wide resnet equivariant network and utilities class ESCNNWideBottleneck(torch.nn.Module): @@ -140,7 +140,7 @@ def __init__( escnn.nn.InnerBatchNorm(self.middle_type, momentum=0.9), escnn.nn.ReLU(self.middle_type, inplace=True), escnn.nn.R2Conv(self.middle_type, self.out_type, kernel_size, padding=kernel_size//2), - + escnn.nn.InnerBatchNorm(self.out_type, momentum=0.9), escnn.nn.ReLU(self.out_type, inplace=True), @@ -151,8 +151,8 @@ def forward(self, x): out = self.conv_network(x) out += x return out - - + + class ESCNNWideBasic(torch.nn.Module): def __init__( self, @@ -186,19 +186,19 @@ def forward(self, x): shortcut = self.shortcut(x) if self.shortcut is not None else x out += shortcut return out - + class ESCNNWRNEquivariantNetwork(torch.nn.Module): - def __init__(self, - in_shape: tuple, + def __init__(self, + in_shape: tuple, out_channels: int = 64, kernel_size: int = 9, - group_type: str = 'rotation', + group_type: str = 'rotation', num_layers: int = 12, num_rotations: int = 4): super().__init__() - + self.group_type = group_type - + # The model is equivariant under discrete rotations if group_type == 'rotation': self.gspace = gspaces.rot2dOnR2(num_rotations) @@ -206,19 +206,19 @@ def __init__(self, self.gspace = gspaces.flipRot2dOnR2(num_rotations) else: raise ValueError('group_type must be rotation or roto-reflection for now.') - + # The input image is a scalar field, corresponding to the trivial representation in_type = escnn.nn.FieldType(self.gspace, in_shape[0] * [self.gspace.trivial_repr]) - + # Store the input type for wrapping the images into a geometric tensor during the forward pass self.input_type = in_type - + # other initialization widen_factor = 2 self.kernel_size = kernel_size self.group_type = group_type self.out_channels = out_channels * widen_factor - + self.num_rotations = num_rotations self.num_group_elements = num_rotations if group_type == 'rotation' else 2 * num_rotations @@ -228,7 +228,7 @@ def __init__(self, r3 = escnn.nn.FieldType(self.gspace, [self.gspace.regular_repr] * nstages[1]) r4 = escnn.nn.FieldType(self.gspace, [self.gspace.regular_repr] * nstages[2]) r5 = escnn.nn.FieldType(self.gspace, [self.gspace.regular_repr] * nstages[3]) - + self.in_type = r1 self.out_type = r5 @@ -250,7 +250,7 @@ def __init__(self, self.eqv_network.append(escnn.nn.ReLU(rs[ridx+1], inplace=True),) self.eqv_network.append(escnn.nn.R2Conv(r4, r5, kernel_size),) - + def forward(self, x): """ x shape: (batch_size, in_channels, height, width) @@ -261,9 +261,9 @@ def forward(self, x): out = self.eqv_network(x) feature_map = out.tensor - feature_map = feature_map.reshape(feature_map.shape[0], - feature_map.shape[1] // self.num_group_elements, self.num_group_elements, + feature_map = feature_map.reshape(feature_map.shape[0], + feature_map.shape[1] // self.num_group_elements, self.num_group_elements, feature_map.shape[-2], feature_map.shape[-1]) feature_fibres = torch.mean(feature_map, dim=(1, 3, 4)) - 
return feature_fibres \ No newline at end of file + return feature_fibres diff --git a/equiadapt/images/utils.py b/equiadapt/images/utils.py index 423812f..c633762 100644 --- a/equiadapt/images/utils.py +++ b/equiadapt/images/utils.py @@ -12,7 +12,7 @@ def roll_by_gather(feature_map: torch.Tensor, shifts: torch.Tensor): def get_action_on_image_features(feature_map: torch.Tensor, group_info_dict: dict, - group_element_dict: dict, + group_element_dict: dict, induced_rep_type: str ='regular'): """ This function takes the feature map and the action and returns the feature map @@ -26,18 +26,18 @@ def get_action_on_image_features(feature_map: torch.Tensor, assert feature_map.shape[1] % num_group == 0 angles = group_element_dict['group']['rotation'] x_out = K.geometry.rotate(feature_map, angles) - + if 'reflection' in group_element_dict['group']: - reflect_indicator = group_element_dict['group']['reflection'] + reflect_indicator = group_element_dict['group']['reflection'] x_out_reflected = K.geometry.hflip(x_out) x_out = x_out * reflect_indicator[:,None,None,None] + \ x_out_reflected * (1 - reflect_indicator[:,None,None,None]) - + x_out = x_out.reshape(batch_size, C // num_group, num_group, H, W) shift = angles / 360. * num_rotations if 'reflection' in group_element_dict['group']: x_out = torch.cat([ - roll_by_gather(x_out[:,:,:num_rotations], shift), + roll_by_gather(x_out[:,:,:num_rotations], shift), roll_by_gather(x_out[:,:,num_rotations:], -shift) ], dim=2) else: @@ -48,7 +48,7 @@ def get_action_on_image_features(feature_map: torch.Tensor, angles = group_element_dict['group'][0] x_out = K.geometry.rotate(feature_map, angles) if 'reflection' in group_element_dict['group']: - reflect_indicator = group_element_dict['group']['reflection'] + reflect_indicator = group_element_dict['group']['reflection'] x_out_reflected = K.geometry.hflip(x_out) x_out = x_out * reflect_indicator[:,None,None,None] + \ x_out_reflected * (1 - reflect_indicator[:,None,None,None]) @@ -61,11 +61,11 @@ def get_action_on_image_features(feature_map: torch.Tensor, def flip_boxes(boxes, width): boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - return boxes + return boxes def flip_masks(masks): return masks.flip(-1) - + def rotate_masks(masks, angle): return transforms.functional.rotate(masks, angle) @@ -88,4 +88,4 @@ def rotate_boxes(boxes, angle, width): y_min_rot, y_max_rot = torch.min(y_min_rot, y_max_rot), torch.max(y_min_rot, y_max_rot) rotated_boxes = torch.stack([x_min_rot, y_min_rot, x_max_rot, y_max_rot], dim=-1) - return rotated_boxes \ No newline at end of file + return rotated_boxes diff --git a/examples/images/classification/README.md b/examples/images/classification/README.md index fdd92b1..054a480 100644 --- a/examples/images/classification/README.md +++ b/examples/images/classification/README.md @@ -6,11 +6,11 @@ python train.py canonicalization=group_equivariant experiment.training.loss.prior_weight=0 ``` ### For image classification (with prior regularization) -``` -python train.py canonicalization=group_equivariant +``` +python train.py canonicalization=group_equivariant ``` -**Note**: You can also run the `train.py` as follows from root directory of the project: +**Note**: You can also run the `train.py` as follows from root directory of the project: ``` python examples/images/classification/train.py canonicalization=group_equivariant ``` diff --git a/examples/images/classification/configs/canonicalization/group_equivariant.yaml b/examples/images/classification/configs/canonicalization/group_equivariant.yaml 
index 712a1f3..8a188a1 100644 --- a/examples/images/classification/configs/canonicalization/group_equivariant.yaml +++ b/examples/images/classification/configs/canonicalization/group_equivariant.yaml @@ -8,4 +8,4 @@ network_hyperparams: num_rotations: 4 # Number of rotations for the canonization network beta: 1.0 # Beta parameter for the canonization network input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization -resize_shape: 32 # Resize shape for the input \ No newline at end of file +resize_shape: 32 # Resize shape for the input diff --git a/examples/images/classification/configs/canonicalization/identity.yaml b/examples/images/classification/configs/canonicalization/identity.yaml index 1598d17..513e776 100644 --- a/examples/images/classification/configs/canonicalization/identity.yaml +++ b/examples/images/classification/configs/canonicalization/identity.yaml @@ -1 +1 @@ -canonicalization_type: identity \ No newline at end of file +canonicalization_type: identity diff --git a/examples/images/classification/configs/canonicalization/opt_group_equivariant.yaml b/examples/images/classification/configs/canonicalization/opt_group_equivariant.yaml index a41bb04..986eae3 100644 --- a/examples/images/classification/configs/canonicalization/opt_group_equivariant.yaml +++ b/examples/images/classification/configs/canonicalization/opt_group_equivariant.yaml @@ -11,4 +11,4 @@ beta: 1.0 # Beta parameter for the canonization network input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization resize_shape: 96 # Resize shape for the input learn_ref_vec: False # Whether to learn the reference vector -artifact_err_wt: 0 # Weight for rotation artifact error (specific to image data, for non C4 rotation, for non-equivariant canonicalization networks) \ No newline at end of file +artifact_err_wt: 0 # Weight for rotation artifact error (specific to image data, for non C4 rotation, for non-equivariant canonicalization networks) diff --git a/examples/images/classification/configs/canonicalization/opt_steerable.yaml b/examples/images/classification/configs/canonicalization/opt_steerable.yaml index e492781..76ca2d9 100644 --- a/examples/images/classification/configs/canonicalization/opt_steerable.yaml +++ b/examples/images/classification/configs/canonicalization/opt_steerable.yaml @@ -6,4 +6,4 @@ network_hyperparams: num_layers: 3 # Number of layers in the canonization network out_vector_size: 4 # Dimension of the output vector group_type: rotation # Type of group for the canonization network -input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization \ No newline at end of file +input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization diff --git a/examples/images/classification/configs/canonicalization/steerable.yaml b/examples/images/classification/configs/canonicalization/steerable.yaml index d3a63bc..e6c0755 100644 --- a/examples/images/classification/configs/canonicalization/steerable.yaml +++ b/examples/images/classification/configs/canonicalization/steerable.yaml @@ -5,4 +5,4 @@ network_hyperparams: out_channels: 16 # Number of output channels for the canonization network num_layers: 3 # Number of layers in the canonization network group_type: rotation # Type of group for the canonization network -input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization \ No newline at end of file +input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization diff --git 
a/examples/images/classification/configs/checkpoint/default.yaml b/examples/images/classification/configs/checkpoint/default.yaml index 419f669..7398463 100644 --- a/examples/images/classification/configs/checkpoint/default.yaml +++ b/examples/images/classification/configs/checkpoint/default.yaml @@ -1,3 +1,3 @@ checkpoint_path: ${oc.env:CHECKPOINT_PATH} # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later -save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file +save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/classification/configs/dataset/default.yaml b/examples/images/classification/configs/dataset/default.yaml index 1b43b55..7600170 100644 --- a/examples/images/classification/configs/dataset/default.yaml +++ b/examples/images/classification/configs/dataset/default.yaml @@ -2,4 +2,4 @@ dataset_name: cifar10 # Name of the dataset to use data_path: ${oc.env:DATA_PATH} # Path to the dataset augment: 1 # Whether to use data augmentation (1) or not (0) num_workers: 4 # Number of workers for data loading -batch_size: 128 # Number of samples per batch \ No newline at end of file +batch_size: 128 # Number of samples per batch diff --git a/examples/images/classification/configs/experiment/default.yaml b/examples/images/classification/configs/experiment/default.yaml index a7211d2..67ef281 100644 --- a/examples/images/classification/configs/experiment/default.yaml +++ b/examples/images/classification/configs/experiment/default.yaml @@ -1,5 +1,5 @@ run_mode: train # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune -seed: 0 # Seed for random number generation +seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) device: cuda # Device, can be cuda or cpu num_nodes: 1 @@ -17,4 +17,4 @@ training: inference: method: group # Type of inference options 1) vanilla 2) group group_type: rotation # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference diff --git a/examples/images/classification/configs/original_configs/group_equivariant/cifar10.yaml b/examples/images/classification/configs/original_configs/group_equivariant/cifar10.yaml index d94f0bf..a209727 100644 --- a/examples/images/classification/configs/original_configs/group_equivariant/cifar10.yaml +++ b/examples/images/classification/configs/original_configs/group_equivariant/cifar10.yaml @@ -9,7 +9,7 @@ dataset: experiment: run_mode: "train" # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune - seed: 0 # Seed for random number generation + seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) device: "cuda" # Device, can be cuda or cpu num_nodes: 1 @@ -25,7 +25,7 @@ experiment: inference: method: "group" # Type of inference options 1) vanilla 2) group group_type: "rotation" # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference prediction: prediction_network_architecture: "resnet50" # Architecture of the prediction network @@ -53,4 +53,4 
@@ wandb: checkpoint: checkpoint_path: "/home/mila/s/siba-smarak.panigrahi/scratch/equiadapt/image/checkpoints" # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/classification/configs/original_configs/group_equivariant/rotmnist.yaml b/examples/images/classification/configs/original_configs/group_equivariant/rotmnist.yaml index 26b0414..afac33a 100644 --- a/examples/images/classification/configs/original_configs/group_equivariant/rotmnist.yaml +++ b/examples/images/classification/configs/original_configs/group_equivariant/rotmnist.yaml @@ -9,7 +9,7 @@ dataset: experiment: run_mode: "train" # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune - seed: 0 # Seed for random number generation + seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) device: "cuda" # Device, can be cuda or cpu num_nodes: 1 @@ -25,7 +25,7 @@ experiment: inference: method: "group" # Type of inference options 1) vanilla 2) group group_type: "rotation" # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference prediction: prediction_network_architecture: "resnet50" # Architecture of the prediction network @@ -53,4 +53,4 @@ wandb: checkpoint: checkpoint_path: "/home/mila/a/arnab.mondal/scratch/equiadapt/image/checkpoints" # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/classification/configs/original_configs/opt_equivariant/cifar10.yaml b/examples/images/classification/configs/original_configs/opt_equivariant/cifar10.yaml index c611b84..9c7afe6 100644 --- a/examples/images/classification/configs/original_configs/opt_equivariant/cifar10.yaml +++ b/examples/images/classification/configs/original_configs/opt_equivariant/cifar10.yaml @@ -9,7 +9,7 @@ dataset: experiment: run_mode: "train" # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune - seed: 0 # Seed for random number generation + seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) num_nodes: 1 num_gpus: 1 @@ -25,7 +25,7 @@ experiment: inference: method: "group" # Type of inference options 1) vanilla 2) group group_type: "rotation" # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference prediction: prediction_network_architecture: "resnet50" # Architecture of the prediction network @@ -54,4 +54,4 @@ wandb: checkpoint: checkpoint_path: "/home/mila/a/arnab.mondal/scratch/equiadapt/image/checkpoints" # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later - save_canonized_images: 0 # Whether to save 
canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/classification/configs/original_configs/opt_equivariant/rotmnist.yaml b/examples/images/classification/configs/original_configs/opt_equivariant/rotmnist.yaml index 5168bd1..1dc06c3 100644 --- a/examples/images/classification/configs/original_configs/opt_equivariant/rotmnist.yaml +++ b/examples/images/classification/configs/original_configs/opt_equivariant/rotmnist.yaml @@ -9,7 +9,7 @@ dataset: experiment: run_mode: "train" # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune - seed: 0 # Seed for random number generation + seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) num_nodes: 1 num_gpus: 1 @@ -25,7 +25,7 @@ experiment: inference: method: "group" # Type of inference options 1) vanilla 2) group group_type: "rotation" # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference prediction: prediction_network_architecture: "resnet50" # Architecture of the prediction network @@ -54,4 +54,4 @@ wandb: checkpoint: checkpoint_path: "/home/mila/a/arnab.mondal/scratch/equiadapt/image/checkpoints" # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/classification/configs/original_configs/steerable/cifar10.yaml b/examples/images/classification/configs/original_configs/steerable/cifar10.yaml index 36db32a..be68cf1 100644 --- a/examples/images/classification/configs/original_configs/steerable/cifar10.yaml +++ b/examples/images/classification/configs/original_configs/steerable/cifar10.yaml @@ -19,7 +19,7 @@ prediction: freeze_pretrained_encoder: 0 # Whether to freeze the pretrained encoder (1) or not (0) canonicalization: - network_type: 'escnn' # Options o canonization method 1) escnn + network_type: 'escnn' # Options o canonization method 1) escnn network_hyperparams: kernel_size: 3 # Kernel size for the canonization network out_channels: 16 # Number of output channels for the canonization network @@ -42,4 +42,4 @@ wandb: checkpoint: checkpoint_path: "/home/mila/a/arnab.mondal/scratch/equiadapt/image/checkpoints" # Path to save checkpoints deterministic: false # Whether to set deterministic mode (true) or not (false) - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/classification/configs/wandb_sweep.yaml b/examples/images/classification/configs/wandb_sweep.yaml index 995e1f7..63056b3 100644 --- a/examples/images/classification/configs/wandb_sweep.yaml +++ b/examples/images/classification/configs/wandb_sweep.yaml @@ -41,4 +41,4 @@ command: - ${env} - python3 - ${program} - - ${args_no_hyphens} \ No newline at end of file + - ${args_no_hyphens} diff --git a/examples/images/classification/inference_utils.py b/examples/images/classification/inference_utils.py index 8556e5a..3f7eb8a 100644 --- 
a/examples/images/classification/inference_utils.py +++ b/examples/images/classification/inference_utils.py @@ -5,42 +5,42 @@ from torchvision import transforms -def get_inference_method(canonicalizer: torch.nn.Module, - prediction_network: torch.nn.Module, - num_classes: int, - inference_hyperparams: Union[Dict, wandb.Config], +def get_inference_method(canonicalizer: torch.nn.Module, + prediction_network: torch.nn.Module, + num_classes: int, + inference_hyperparams: Union[Dict, wandb.Config], in_shape: tuple = (3, 32, 32)): if inference_hyperparams.method == 'vanilla': return VanillaInference(canonicalizer, prediction_network, num_classes) elif inference_hyperparams.method == 'group': return GroupInference( - canonicalizer, prediction_network, num_classes, + canonicalizer, prediction_network, num_classes, inference_hyperparams, in_shape ) else: raise ValueError(f'{inference_hyperparams.method} is not implemented for now.') class VanillaInference: - def __init__(self, - canonicalizer: torch.nn.Module, - prediction_network: torch.nn.Module, + def __init__(self, + canonicalizer: torch.nn.Module, + prediction_network: torch.nn.Module, num_classes: int) -> None: self.canonicalizer = canonicalizer self.prediction_network = prediction_network self.num_classes = num_classes - + def forward(self, x): # canonicalize the input data # For the vanilla model, the canonicalization is the identity transformation x_canonicalized = self.canonicalizer(x) - + # Forward pass through the prediction network as you'll normally do logits = self.prediction_network(x_canonicalized) return logits - + def get_inference_metrics(self, x: torch.Tensor, y: torch.Tensor): # Forward pass through the prediction network - logits = self.forward(x) + logits = self.forward(x) preds = logits.argmax(dim=-1) # Calculate the accuracy @@ -49,7 +49,7 @@ def get_inference_metrics(self, x: torch.Tensor, y: torch.Tensor): # Calculate accuracy per class acc_per_class = [(preds[y == i] == y[y == i]).float().mean() for i in range(self.num_classes)] - + # check if the accuracy per class is nan acc_per_class = [0.0 if math.isnan(acc) else acc for acc in acc_per_class] @@ -60,19 +60,19 @@ def get_inference_metrics(self, x: torch.Tensor, y: torch.Tensor): class GroupInference(VanillaInference): - def __init__(self, - canonicalizer: torch.nn.Module, - prediction_network: torch.nn.Module, + def __init__(self, + canonicalizer: torch.nn.Module, + prediction_network: torch.nn.Module, num_classes: int, - inference_hyperparams: Union[Dict, wandb.Config], + inference_hyperparams: Union[Dict, wandb.Config], in_shape: tuple = (3, 32, 32)): - + super().__init__(canonicalizer, prediction_network, num_classes) self.group_type = inference_hyperparams.group_type self.num_rotations = inference_hyperparams.num_rotations self.num_group_elements = self.num_rotations if self.group_type == 'rotation' else 2 * self.num_rotations self.pad = transforms.Pad( - math.ceil(in_shape[-2] * 0.4), + math.ceil(in_shape[-2] * 0.4), padding_mode='edge' ) self.crop = transforms.CenterCrop((in_shape[-2], in_shape[-1])) @@ -81,13 +81,13 @@ def get_group_element_wise_logits(self, x: torch.Tensor): logits_dict = {} degrees = torch.linspace(0, 360, self.num_rotations + 1)[:-1] for rot, degree in enumerate(degrees): - + x_pad = self.pad(x) x_rot = transforms.functional.rotate(x_pad, degree.item()) x_rot = self.crop(x_rot) - + logits_dict[rot] = self.forward(x_rot) - + if self.group_type == 'roto-reflection': # Rotate the reflected images and get the logits for rot, degree in 
enumerate(degrees): @@ -100,17 +100,17 @@ def get_group_element_wise_logits(self, x: torch.Tensor): logits_dict[rot + len(degrees)] = self.forward(x_rotoreflect) return logits_dict - + def get_inference_metrics(self, x: torch.Tensor, y: torch.Tensor): - + logits_dict = self.get_group_element_wise_logits(x) - + # Use list comprehension to calculate accuracy for each group element acc_per_group_element = torch.tensor([(logits.argmax(dim=-1) == y).float().mean() for logits in logits_dict.values()]) metrics = {"test/group_acc": torch.mean(acc_per_group_element)} metrics.update({f'test/acc_group_element_{i}': max(acc_per_group_element[i], 0.0) for i in range(self.num_group_elements)}) - + preds = logits_dict[0].argmax(dim=-1) # Calculate the accuracy @@ -119,7 +119,7 @@ def get_inference_metrics(self, x: torch.Tensor, y: torch.Tensor): # Calculate accuracy per class acc_per_class = [(preds[y == i] == y[y == i]).float().mean() for i in range(self.num_classes)] - + # check if the accuracy per class is nan acc_per_class = [0.0 if math.isnan(acc) else acc for acc in acc_per_class] @@ -127,4 +127,3 @@ def get_inference_metrics(self, x: torch.Tensor, y: torch.Tensor): metrics.update({f'test/acc_class_{i}': max(acc, 0.0) for i, acc in enumerate(acc_per_class)}) return metrics - \ No newline at end of file diff --git a/examples/images/classification/model.py b/examples/images/classification/model.py index 37ad200..b5a787e 100644 --- a/examples/images/classification/model.py +++ b/examples/images/classification/model.py @@ -12,7 +12,7 @@ class ImageClassifierPipeline(pl.LightningModule): def __init__(self, hyperparams: DictConfig): super().__init__() - + self.loss, self.image_shape, self.num_classes = get_dataset_specific_info(hyperparams.dataset.dataset_name) self.prediction_network = get_prediction_network( @@ -25,20 +25,20 @@ def __init__(self, hyperparams: DictConfig): ) canonicalization_network = get_canonicalization_network( - hyperparams.canonicalization_type, + hyperparams.canonicalization_type, hyperparams.canonicalization, self.image_shape, ) - + self.canonicalizer = get_canonicalizer( hyperparams.canonicalization_type, canonicalization_network, hyperparams.canonicalization, self.image_shape - ) - + ) + self.hyperparams = hyperparams - + self.inference_method = get_inference_method( self.canonicalizer, self.prediction_network, @@ -46,105 +46,105 @@ def __init__(self, hyperparams: DictConfig): hyperparams.experiment.inference, self.image_shape ) - + self.max_epochs = hyperparams.experiment.training.num_epochs - + self.save_hyperparameters() def training_step(self, batch: torch.Tensor): x, y = batch batch_size, num_channels, height, width = x.shape - + # assert that the input is in the right shape assert (num_channels, height, width) == self.image_shape training_metrics = {} loss = 0.0 - + # canonicalize the input data # For the vanilla model, the canonicalization is the identity transformation x_canonicalized = self.canonicalizer(x) - + # add group contrast loss while using optmization based canonicalization method if 'opt' in self.hyperparams.canonicalization_type: group_contrast_loss = self.canonicalizer.get_optimization_specific_loss() loss += group_contrast_loss * self.hyperparams.experiment.training.loss.group_contrast_weight training_metrics.update({"train/optimization_specific_loss": group_contrast_loss}) - - + + # calculate the task loss which is the cross-entropy loss for classification if self.hyperparams.experiment.training.loss.task_weight: # Forward pass through the prediction 
network as you'll normally do logits = self.prediction_network(x_canonicalized) - + task_loss = self.loss(logits, y) loss += self.hyperparams.experiment.training.loss.task_weight * task_loss - + # Get the predictions and calculate the accuracy preds = logits.argmax(dim=-1) acc = (preds == y).float().mean() - + training_metrics.update({ "train/task_loss": task_loss, "train/acc": acc }) - + # Add prior regularization loss if the prior weight is non-zero if self.hyperparams.experiment.training.loss.prior_weight: prior_loss = self.canonicalizer.get_prior_regularization_loss() loss += prior_loss * self.hyperparams.experiment.training.loss.prior_weight metric_identity = self.canonicalizer.get_identity_metric() training_metrics.update({ - "train/prior_loss": prior_loss, + "train/prior_loss": prior_loss, "train/identity_metric": metric_identity }) - + training_metrics.update({ "train/loss": loss, }) - + # Log the training metrics self.log_dict(training_metrics, prog_bar=True) assert not torch.isnan(loss), "Loss is NaN" - + return {'loss': loss, 'acc': acc} - + def validation_step(self, batch: torch.Tensor): x, y = batch - + batch_size, num_channels, height, width = x.shape - + # assert that the input is in the right shape assert (num_channels, height, width) == self.image_shape - + validation_metrics = {} - + # canonicalize the input data # For the vanilla model, the canonicalization is the identity transformation x_canonicalized = self.canonicalizer(x) - + # Forward pass through the prediction network as you'll normally do logits = self.prediction_network(x_canonicalized) - # Get the predictions and calculate the accuracy + # Get the predictions and calculate the accuracy preds = logits.argmax(dim=-1) acc = (preds == y).float().mean() - + # Log the identity metric if the prior weight is non-zero if self.hyperparams.experiment.training.loss.prior_weight: metric_identity = self.canonicalizer.get_identity_metric() validation_metrics.update({ "train/identity_metric": metric_identity }) - - + + # Logging to TensorBoard by default validation_metrics.update({ "val/acc": acc }) - + self.log_dict(validation_metrics, prog_bar=True) return {'acc': acc} @@ -153,17 +153,17 @@ def validation_step(self, batch: torch.Tensor): def test_step(self, batch: torch.Tensor): x, y = batch batch_size, num_channels, height, width = x.shape - + # assert that the input is in the right shape assert (num_channels, height, width) == self.image_shape test_metrics = self.inference_method.get_inference_metrics(x, y) - + # Log the test metrics self.log_dict(test_metrics, prog_bar=True) - - return test_metrics - + + return test_metrics + def configure_optimizers(self): if 'resnet' in self.hyperparams.prediction.prediction_network_architecture and 'mnist' not in self.hyperparams.dataset.dataset_name: @@ -172,16 +172,16 @@ def configure_optimizers(self): [ {'params': self.prediction_network.parameters(), 'lr': self.hyperparams.experiment.training.prediction_lr}, {'params': self.canonicalizer.parameters(), 'lr': self.hyperparams.experiment.training.canonicalization_lr}, - ], + ], momentum=0.9, weight_decay=5e-4, ) - + if self.max_epochs > 100: milestones = [self.trainer.max_epochs // 6, self.trainer.max_epochs // 3, self.trainer.max_epochs // 2] else: milestones = [self.trainer.max_epochs // 3, self.trainer.max_epochs // 2] # for small training epochs - + scheduler_dict = { "scheduler": MultiStepLR( optimizer, @@ -197,4 +197,4 @@ def configure_optimizers(self): {'params': self.prediction_network.parameters(), 'lr': 
self.hyperparams.experiment.training.prediction_lr}, {'params': self.canonicalizer.parameters(), 'lr': self.hyperparams.experiment.training.canonicalization_lr}, ]) - return optimizer \ No newline at end of file + return optimizer diff --git a/examples/images/classification/model_utils.py b/examples/images/classification/model_utils.py index 5442d4d..befb091 100644 --- a/examples/images/classification/model_utils.py +++ b/examples/images/classification/model_utils.py @@ -19,7 +19,7 @@ def forward(self, x): reps = self.encoder(x) reps = reps.view(x.shape[0], -1) return self.predictor(reps) - + def get_dataset_specific_info(dataset_name): dataset_info = { 'rotated_mnist': (nn.CrossEntropyLoss(), (1, 28, 28), 10), @@ -35,10 +35,10 @@ def get_dataset_specific_info(dataset_name): raise ValueError('Dataset not implemented for now.') return dataset_info[dataset_name] - + def get_prediction_network( - architecture: str = 'resnet50', + architecture: str = 'resnet50', dataset_name: str = 'cifar10', use_pretrained: bool = False, freeze_encoder: bool = False, @@ -60,18 +60,18 @@ def get_prediction_network( if input_shape[-2:] == [32, 32] or dataset_name == 'rotated_mnist': encoder.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False) encoder.maxpool = nn.Identity() - + if freeze_encoder: for param in encoder.parameters(): param.requires_grad = False if dataset_name != 'ImageNet': feature_dim = encoder.fc.in_features - encoder.fc = nn.Identity() + encoder.fc = nn.Identity() prediction_network = PredictionNetwork(encoder, feature_dim, num_classes) else: prediction_network = encoder - - - return prediction_network \ No newline at end of file + + + return prediction_network diff --git a/examples/images/classification/prepare/__init__.py b/examples/images/classification/prepare/__init__.py index 67886d5..29c24ec 100644 --- a/examples/images/classification/prepare/__init__.py +++ b/examples/images/classification/prepare/__init__.py @@ -3,4 +3,4 @@ from .stl10_data import STL10DataModule from .celeba_data import CelebADataModule from .flowers102_data import Flowers102DataModule -from .imagenet_data import ImageNetDataModule \ No newline at end of file +from .imagenet_data import ImageNetDataModule diff --git a/examples/images/classification/prepare/celeba_data.py b/examples/images/classification/prepare/celeba_data.py index c3c064d..bdaffa6 100644 --- a/examples/images/classification/prepare/celeba_data.py +++ b/examples/images/classification/prepare/celeba_data.py @@ -12,7 +12,7 @@ def __init__(self, hyperparams, download=False): super().__init__() self.data_path = hyperparams.data_path self.hyperparams = hyperparams - + if hyperparams.augment == 1: self.train_transform = transforms.Compose( [ @@ -97,4 +97,4 @@ def test_dataloader(self): shuffle=False, num_workers=self.hyperparams.num_workers, ) - return test_loader \ No newline at end of file + return test_loader diff --git a/examples/images/classification/prepare/cifar_data.py b/examples/images/classification/prepare/cifar_data.py index 595ce48..224c4ff 100644 --- a/examples/images/classification/prepare/cifar_data.py +++ b/examples/images/classification/prepare/cifar_data.py @@ -63,7 +63,7 @@ def __init__(self, hyperparams, download=False): self.train_transform = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.Resize(224), - + transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)), ]) @@ -152,7 +152,7 @@ def __init__(self, hyperparams, download=False): 
transforms.RandomHorizontalFlip(), transforms.AutoAugment(policy=transforms.autoaugment.AutoAugmentPolicy.CIFAR10), - + transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)), ]) @@ -161,7 +161,7 @@ def __init__(self, hyperparams, download=False): self.train_transform = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.Resize(224), - + transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)), ]) @@ -211,4 +211,4 @@ def test_dataloader(self): shuffle=False, num_workers=self.hyperparams.num_workers, ) - return test_loader \ No newline at end of file + return test_loader diff --git a/examples/images/classification/prepare/flowers102_data.py b/examples/images/classification/prepare/flowers102_data.py index 05e2c8e..59e3aa8 100644 --- a/examples/images/classification/prepare/flowers102_data.py +++ b/examples/images/classification/prepare/flowers102_data.py @@ -64,4 +64,4 @@ def test_dataloader(self): shuffle=False, num_workers=self.hyperparams.num_workers, ) - return test_loader \ No newline at end of file + return test_loader diff --git a/examples/images/classification/prepare/imagenet_data.py b/examples/images/classification/prepare/imagenet_data.py index 3dc3572..ee0e767 100644 --- a/examples/images/classification/prepare/imagenet_data.py +++ b/examples/images/classification/prepare/imagenet_data.py @@ -66,7 +66,7 @@ def __init__(self, mode='train'): transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) - + def __call__(self, x): return self.transform(x) @@ -95,7 +95,7 @@ def val_dataloader(self): def test_dataloader(self): return self.loaders['val'] - + def get_imagenet_pytorch_dataloaders(self, data_dir=None, batch_size=None, num_workers=None): paths = { 'train': data_dir + '/train', diff --git a/examples/images/classification/prepare/rotated_mnist_data.py b/examples/images/classification/prepare/rotated_mnist_data.py index a2e092a..45bcb1f 100644 --- a/examples/images/classification/prepare/rotated_mnist_data.py +++ b/examples/images/classification/prepare/rotated_mnist_data.py @@ -147,4 +147,4 @@ def test_dataloader(self): # help='Path to the dataset.' 
# ) # args = parser.parse_args() -# obtain(args.data_path) \ No newline at end of file +# obtain(args.data_path) diff --git a/examples/images/classification/prepare/stl10_data.py b/examples/images/classification/prepare/stl10_data.py index a5bfdf3..189c644 100644 --- a/examples/images/classification/prepare/stl10_data.py +++ b/examples/images/classification/prepare/stl10_data.py @@ -16,7 +16,7 @@ def __init__(self, angles): def __call__(self, x): angle = random.choice(self.angles) return transforms.functional.rotate(x, angle) - + class STL10DataModule(pl.LightningDataModule): def __init__(self, hyperparams, download=False): super().__init__() @@ -43,7 +43,7 @@ def __init__(self, hyperparams, download=False): transforms.Pad(4), transforms.RandomCrop(96), transforms.Resize(224), - + CustomRotationTransform([0, 45, 90, 135, 180, 225, 270, 315]), # transforms.RandomRotation(180), transforms.RandomHorizontalFlip(), @@ -106,4 +106,4 @@ def test_dataloader(self): shuffle=False, num_workers=self.hyperparams.num_workers, ) - return test_loader \ No newline at end of file + return test_loader diff --git a/examples/images/classification/train.py b/examples/images/classification/train.py index 8834e1e..29ccdfc 100644 --- a/examples/images/classification/train.py +++ b/examples/images/classification/train.py @@ -27,18 +27,18 @@ def train_images(hyperparams: DictConfig): print("Wandb disabled for logging...") os.environ["WANDB_MODE"] = "disabled" os.environ["WANDB_DIR"] = hyperparams['wandb']['wandb_dir'] - os.environ["WANDB_CACHE_DIR"] = hyperparams['wandb']['wandb_cache_dir'] - + os.environ["WANDB_CACHE_DIR"] = hyperparams['wandb']['wandb_cache_dir'] + # initialize wandb wandb.init(config=OmegaConf.to_container(hyperparams, resolve=True), entity=hyperparams['wandb']['wandb_entity'], project=hyperparams['wandb']['wandb_project'], dir=hyperparams['wandb']['wandb_dir']) wandb_logger = WandbLogger(project=hyperparams['wandb']['wandb_project'], log_model="all") # set seed pl.seed_everything(hyperparams.experiment.seed) - + # get model, callbacks, and image data model, image_data, callbacks = get_model_data_and_callbacks(hyperparams) - + if hyperparams.canonicalization_type in ("group_equivariant", "opt_equivariant", "steerable"): wandb.watch(model.canonicalizer.canonicalization_network, log='all') @@ -47,7 +47,7 @@ def train_images(hyperparams: DictConfig): if hyperparams.experiment.run_mode == "train": trainer.fit(model, datamodule=image_data) - + elif hyperparams.experiment.run_mode == "auto_tune": trainer.tune(model, datamodule=image_data) @@ -61,4 +61,4 @@ def main(cfg: omegaconf.DictConfig): train_images(cfg) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/images/classification/train_utils.py b/examples/images/classification/train_utils.py index 3a39758..a3638c7 100644 --- a/examples/images/classification/train_utils.py +++ b/examples/images/classification/train_utils.py @@ -7,21 +7,21 @@ from model import ImageClassifierPipeline from prepare import RotatedMNISTDataModule, CIFAR10DataModule, CIFAR100DataModule, STL10DataModule, Flowers102DataModule, CelebADataModule, ImageNetDataModule - + def get_model_data_and_callbacks(hyperparams : DictConfig): - + # get image data image_data = get_image_data(hyperparams.dataset) - + # checkpoint name hyperparams.checkpoint.checkpoint_name = get_checkpoint_name(hyperparams) - + # checkpoint callbacks callbacks = get_callbacks(hyperparams) - # get model pipeline + # get model pipeline model = 
get_model_pipeline(hyperparams) - + return model, image_data, callbacks def get_model_pipeline(hyperparams: DictConfig): @@ -36,7 +36,7 @@ def get_model_pipeline(hyperparams: DictConfig): model.eval() else: model = ImageClassifierPipeline(hyperparams) - + return model def get_trainer( @@ -46,32 +46,32 @@ def get_trainer( ): if hyperparams.experiment.run_mode == "auto_tune": trainer = pl.Trainer( - max_epochs=hyperparams.experiment.num_epochs, accelerator="auto", - auto_scale_batch_size=True, auto_lr_find=True, logger=wandb_logger, + max_epochs=hyperparams.experiment.num_epochs, accelerator="auto", + auto_scale_batch_size=True, auto_lr_find=True, logger=wandb_logger, callbacks=callbacks, deterministic=hyperparams.experiment.deterministic, - num_nodes=hyperparams.experiment.num_nodes, devices=hyperparams.experiment.num_gpus, + num_nodes=hyperparams.experiment.num_nodes, devices=hyperparams.experiment.num_gpus, strategy='ddp' ) - + elif hyperparams.experiment.run_mode == "dryrun": trainer = pl.Trainer( - fast_dev_run=5, max_epochs=hyperparams.experiment.training.num_epochs, accelerator="auto", - limit_train_batches=5, limit_val_batches=5, logger=wandb_logger, + fast_dev_run=5, max_epochs=hyperparams.experiment.training.num_epochs, accelerator="auto", + limit_train_batches=5, limit_val_batches=5, logger=wandb_logger, callbacks=callbacks, deterministic=hyperparams.experiment.deterministic ) else: trainer = pl.Trainer( - max_epochs=hyperparams.experiment.training.num_epochs, accelerator="auto", + max_epochs=hyperparams.experiment.training.num_epochs, accelerator="auto", logger=wandb_logger, callbacks=callbacks, deterministic=hyperparams.experiment.deterministic, - num_nodes=hyperparams.experiment.num_nodes, devices=hyperparams.experiment.num_gpus, + num_nodes=hyperparams.experiment.num_nodes, devices=hyperparams.experiment.num_gpus, strategy='ddp' ) return trainer - - + + def get_callbacks(hyperparams: DictConfig): - + checkpoint_callback = ModelCheckpoint( dirpath=hyperparams.checkpoint.checkpoint_path, filename=hyperparams.checkpoint.checkpoint_name, @@ -79,12 +79,12 @@ def get_callbacks(hyperparams: DictConfig): mode="max", save_on_train_epoch_end=False, ) - early_stop_metric_callback = EarlyStopping(monitor="val/acc", - min_delta=hyperparams.experiment.training.min_delta, - patience=hyperparams.experiment.training.patience, - verbose=True, + early_stop_metric_callback = EarlyStopping(monitor="val/acc", + min_delta=hyperparams.experiment.training.min_delta, + patience=hyperparams.experiment.training.patience, + verbose=True, mode="max") - + return [checkpoint_callback, early_stop_metric_callback] def get_recursive_hyperparams_identifier(hyperparams: Dict): @@ -94,7 +94,7 @@ def get_recursive_hyperparams_identifier(hyperparams: Dict): for key, value in hyperparams.items(): if isinstance(value, DictConfig): identifier += f"{get_recursive_hyperparams_identifier(value)}" - # special manipulation for the keys (to avoid exceeding OS limit for file names) + # special manipulation for the keys (to avoid exceeding OS limit for file names) elif key not in ["canonicalization_type", "beta", "input_crop_ratio"]: if key == "network_type": identifier += f"_net_type_{value}_" @@ -107,16 +107,16 @@ def get_recursive_hyperparams_identifier(hyperparams: Dict): else: identifier += f"_{key}_{value}_" return identifier - + def get_checkpoint_name(hyperparams : DictConfig): return f"{get_recursive_hyperparams_identifier(hyperparams.canonicalization)}".lstrip("_") + \ 
f"_loss_wts_{int(hyperparams.experiment.training.loss.task_weight)}_{int(hyperparams.experiment.training.loss.prior_weight)}_{int(hyperparams.experiment.training.loss.group_contrast_weight)}" + \ f"_lrs_{hyperparams.experiment.training.prediction_lr}_{hyperparams.experiment.training.canonicalization_lr}" + \ f"_seed_{hyperparams.experiment.seed}" - + def get_image_data(dataset_hyperparams: DictConfig): - + dataset_classes = { "rotated_mnist": RotatedMNISTDataModule, "cifar10": CIFAR10DataModule, @@ -126,10 +126,10 @@ def get_image_data(dataset_hyperparams: DictConfig): "flowers102": Flowers102DataModule, "imagenet": ImageNetDataModule } - + if dataset_hyperparams.dataset_name not in dataset_classes: raise ValueError(f"{dataset_hyperparams.dataset_name} not implemented") - + return dataset_classes[dataset_hyperparams.dataset_name](dataset_hyperparams) def load_envs(env_file: Optional[str] = None) -> None: @@ -142,4 +142,4 @@ def load_envs(env_file: Optional[str] = None) -> None: :param env_file: the file that defines the environment variables to use. If None it searches for a `.env` file in the project. """ - dotenv.load_dotenv(dotenv_path=env_file, override=True) \ No newline at end of file + dotenv.load_dotenv(dotenv_path=env_file, override=True) diff --git a/examples/images/common/utils.py b/examples/images/common/utils.py index de5be1b..e8e8059 100644 --- a/examples/images/common/utils.py +++ b/examples/images/common/utils.py @@ -1,4 +1,4 @@ -import torch +import torch from omegaconf import DictConfig from equiadapt.common.basecanonicalization import IdentityCanonicalization @@ -20,7 +20,7 @@ def get_canonicalization_network( """ if canonicalization_type == 'identity': return torch.nn.Identity() - + canonicalization_network_dict = { 'group_equivariant': { 'escnn': ESCNNEquivariantNetwork, @@ -37,21 +37,21 @@ def get_canonicalization_network( 'cnn': ConvNetwork, } } - + if canonicalization_type not in canonicalization_network_dict: - raise ValueError(f'{canonicalization_type} is not implemented') + raise ValueError(f'{canonicalization_type} is not implemented') if canonicalization_hyperparams.network_type not in canonicalization_network_dict[canonicalization_type]: raise ValueError(f'{canonicalization_hyperparams.network_type} is not implemented for {canonicalization_type}') - + canonicalization_network = \ canonicalization_network_dict[canonicalization_type][ canonicalization_hyperparams.network_type ]( - in_shape = (in_shape[0], canonicalization_hyperparams.resize_shape, - canonicalization_hyperparams.resize_shape), + in_shape = (in_shape[0], canonicalization_hyperparams.resize_shape, + canonicalization_hyperparams.resize_shape), **canonicalization_hyperparams.network_hyperparams ) - + return canonicalization_network @@ -70,21 +70,21 @@ def get_canonicalizer( """ if canonicalization_type == 'identity': return IdentityCanonicalization(canonicalization_network) - + canonicalizer_dict = { 'group_equivariant': GroupEquivariantImageCanonicalization, 'steerable': SteerableImageCanonicalization, 'opt_group_equivariant': OptimizedGroupEquivariantImageCanonicalization, 'opt_steerable': OptimizedSteerableImageCanonicalization } - + if canonicalization_type not in canonicalizer_dict: raise ValueError(f'{canonicalization_type} needs a canonicalization network implementation.') - + canonicalizer = canonicalizer_dict[canonicalization_type]( canonicalization_network=canonicalization_network, canonicalization_hyperparams=canonicalization_hyperparams, in_shape=in_shape ) - - return canonicalizer \ No 
newline at end of file + + return canonicalizer diff --git a/examples/images/segmentation/README.md b/examples/images/segmentation/README.md index c49cd9c..ee45f86 100644 --- a/examples/images/segmentation/README.md +++ b/examples/images/segmentation/README.md @@ -3,14 +3,14 @@ ## For COCO ### For instance segmentation (without prior regularization) ``` -python train.py canonicalization=group_equivariant experiment.training.loss.prior_weight=0 +python train.py canonicalization=group_equivariant experiment.training.loss.prior_weight=0 ``` ### For instance segmentation (with prior regularization) -``` -python train.py canonicalization=group_equivariant +``` +python train.py canonicalization=group_equivariant ``` -**Note**: You can also run the `train.py` as follows from root directory of the project: +**Note**: You can also run the `train.py` as follows from root directory of the project: ``` python examples/images/segmentation/train.py canonicalization=group_equivariant ``` diff --git a/examples/images/segmentation/configs/canonicalization/group_equivariant.yaml b/examples/images/segmentation/configs/canonicalization/group_equivariant.yaml index c70e07c..8f1fe6a 100644 --- a/examples/images/segmentation/configs/canonicalization/group_equivariant.yaml +++ b/examples/images/segmentation/configs/canonicalization/group_equivariant.yaml @@ -8,4 +8,4 @@ network_hyperparams: num_rotations: 4 # Number of rotations for the canonization network beta: 1.0 # Beta parameter for the canonization network input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization -resize_shape: 128 # Resize shape for the input \ No newline at end of file +resize_shape: 128 # Resize shape for the input diff --git a/examples/images/segmentation/configs/canonicalization/identity.yaml b/examples/images/segmentation/configs/canonicalization/identity.yaml index 1598d17..513e776 100644 --- a/examples/images/segmentation/configs/canonicalization/identity.yaml +++ b/examples/images/segmentation/configs/canonicalization/identity.yaml @@ -1 +1 @@ -canonicalization_type: identity \ No newline at end of file +canonicalization_type: identity diff --git a/examples/images/segmentation/configs/canonicalization/opt_group_equivariant.yaml b/examples/images/segmentation/configs/canonicalization/opt_group_equivariant.yaml index a41bb04..986eae3 100644 --- a/examples/images/segmentation/configs/canonicalization/opt_group_equivariant.yaml +++ b/examples/images/segmentation/configs/canonicalization/opt_group_equivariant.yaml @@ -11,4 +11,4 @@ beta: 1.0 # Beta parameter for the canonization network input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization resize_shape: 96 # Resize shape for the input learn_ref_vec: False # Whether to learn the reference vector -artifact_err_wt: 0 # Weight for rotation artifact error (specific to image data, for non C4 rotation, for non-equivariant canonicalization networks) \ No newline at end of file +artifact_err_wt: 0 # Weight for rotation artifact error (specific to image data, for non C4 rotation, for non-equivariant canonicalization networks) diff --git a/examples/images/segmentation/configs/canonicalization/opt_steerable.yaml b/examples/images/segmentation/configs/canonicalization/opt_steerable.yaml index e492781..76ca2d9 100644 --- a/examples/images/segmentation/configs/canonicalization/opt_steerable.yaml +++ b/examples/images/segmentation/configs/canonicalization/opt_steerable.yaml @@ -6,4 +6,4 @@ network_hyperparams: num_layers: 3 # Number of layers in the 
canonization network out_vector_size: 4 # Dimension of the output vector group_type: rotation # Type of group for the canonization network -input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization \ No newline at end of file +input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization diff --git a/examples/images/segmentation/configs/canonicalization/steerable.yaml b/examples/images/segmentation/configs/canonicalization/steerable.yaml index d3a63bc..e6c0755 100644 --- a/examples/images/segmentation/configs/canonicalization/steerable.yaml +++ b/examples/images/segmentation/configs/canonicalization/steerable.yaml @@ -5,4 +5,4 @@ network_hyperparams: out_channels: 16 # Number of output channels for the canonization network num_layers: 3 # Number of layers in the canonization network group_type: rotation # Type of group for the canonization network -input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization \ No newline at end of file +input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization diff --git a/examples/images/segmentation/configs/checkpoint/default.yaml b/examples/images/segmentation/configs/checkpoint/default.yaml index 419f669..7398463 100644 --- a/examples/images/segmentation/configs/checkpoint/default.yaml +++ b/examples/images/segmentation/configs/checkpoint/default.yaml @@ -1,3 +1,3 @@ checkpoint_path: ${oc.env:CHECKPOINT_PATH} # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later -save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file +save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/segmentation/configs/dataset/default.yaml b/examples/images/segmentation/configs/dataset/default.yaml index bd7dc1e..403e50c 100644 --- a/examples/images/segmentation/configs/dataset/default.yaml +++ b/examples/images/segmentation/configs/dataset/default.yaml @@ -1,8 +1,8 @@ dataset_name: coco # Name of the dataset to use root_dir: ${oc.env:DATA_PATH} # Root directory of the dataset ann_dir: "" # Directory containing annotations -augment: flip # Whether to train with flip augmentation +augment: flip # Whether to train with flip augmentation img_size: 1024 # Size of the input images num_workers: 0 # Number of workers for data loading batch_size: 128 # Number of samples per batch -val_batch_size: 12 # Number of samples per batch for validation \ No newline at end of file +val_batch_size: 12 # Number of samples per batch for validation diff --git a/examples/images/segmentation/configs/experiment/default.yaml b/examples/images/segmentation/configs/experiment/default.yaml index bdd939d..bee0e34 100644 --- a/examples/images/segmentation/configs/experiment/default.yaml +++ b/examples/images/segmentation/configs/experiment/default.yaml @@ -1,5 +1,5 @@ run_mode: train # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune -seed: 0 # Seed for random number generation +seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) device: cuda # Device, can be cuda or cpu num_nodes: 1 @@ -18,4 +18,4 @@ training: inference: method: group # Type of inference options 1) vanilla 2) group group_type: rotation # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + 
num_rotations: 4 # Number of rotations to check robustness during inference diff --git a/examples/images/segmentation/configs/original_configs/group_equivariant/cifar10.yaml b/examples/images/segmentation/configs/original_configs/group_equivariant/cifar10.yaml index d94f0bf..a209727 100644 --- a/examples/images/segmentation/configs/original_configs/group_equivariant/cifar10.yaml +++ b/examples/images/segmentation/configs/original_configs/group_equivariant/cifar10.yaml @@ -9,7 +9,7 @@ dataset: experiment: run_mode: "train" # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune - seed: 0 # Seed for random number generation + seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) device: "cuda" # Device, can be cuda or cpu num_nodes: 1 @@ -25,7 +25,7 @@ experiment: inference: method: "group" # Type of inference options 1) vanilla 2) group group_type: "rotation" # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference prediction: prediction_network_architecture: "resnet50" # Architecture of the prediction network @@ -53,4 +53,4 @@ wandb: checkpoint: checkpoint_path: "/home/mila/s/siba-smarak.panigrahi/scratch/equiadapt/image/checkpoints" # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/segmentation/configs/original_configs/group_equivariant/rotmnist.yaml b/examples/images/segmentation/configs/original_configs/group_equivariant/rotmnist.yaml index 26b0414..afac33a 100644 --- a/examples/images/segmentation/configs/original_configs/group_equivariant/rotmnist.yaml +++ b/examples/images/segmentation/configs/original_configs/group_equivariant/rotmnist.yaml @@ -9,7 +9,7 @@ dataset: experiment: run_mode: "train" # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune - seed: 0 # Seed for random number generation + seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) device: "cuda" # Device, can be cuda or cpu num_nodes: 1 @@ -25,7 +25,7 @@ experiment: inference: method: "group" # Type of inference options 1) vanilla 2) group group_type: "rotation" # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference prediction: prediction_network_architecture: "resnet50" # Architecture of the prediction network @@ -53,4 +53,4 @@ wandb: checkpoint: checkpoint_path: "/home/mila/a/arnab.mondal/scratch/equiadapt/image/checkpoints" # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/segmentation/configs/original_configs/opt_equivariant/cifar10.yaml b/examples/images/segmentation/configs/original_configs/opt_equivariant/cifar10.yaml 
index c611b84..9c7afe6 100644 --- a/examples/images/segmentation/configs/original_configs/opt_equivariant/cifar10.yaml +++ b/examples/images/segmentation/configs/original_configs/opt_equivariant/cifar10.yaml @@ -9,7 +9,7 @@ dataset: experiment: run_mode: "train" # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune - seed: 0 # Seed for random number generation + seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) num_nodes: 1 num_gpus: 1 @@ -25,7 +25,7 @@ experiment: inference: method: "group" # Type of inference options 1) vanilla 2) group group_type: "rotation" # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference prediction: prediction_network_architecture: "resnet50" # Architecture of the prediction network @@ -54,4 +54,4 @@ wandb: checkpoint: checkpoint_path: "/home/mila/a/arnab.mondal/scratch/equiadapt/image/checkpoints" # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/segmentation/configs/original_configs/opt_equivariant/rotmnist.yaml b/examples/images/segmentation/configs/original_configs/opt_equivariant/rotmnist.yaml index 5168bd1..1dc06c3 100644 --- a/examples/images/segmentation/configs/original_configs/opt_equivariant/rotmnist.yaml +++ b/examples/images/segmentation/configs/original_configs/opt_equivariant/rotmnist.yaml @@ -9,7 +9,7 @@ dataset: experiment: run_mode: "train" # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune - seed: 0 # Seed for random number generation + seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) num_nodes: 1 num_gpus: 1 @@ -25,7 +25,7 @@ experiment: inference: method: "group" # Type of inference options 1) vanilla 2) group group_type: "rotation" # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference prediction: prediction_network_architecture: "resnet50" # Architecture of the prediction network @@ -54,4 +54,4 @@ wandb: checkpoint: checkpoint_path: "/home/mila/a/arnab.mondal/scratch/equiadapt/image/checkpoints" # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/segmentation/configs/original_configs/steerable/cifar10.yaml b/examples/images/segmentation/configs/original_configs/steerable/cifar10.yaml index 36db32a..be68cf1 100644 --- a/examples/images/segmentation/configs/original_configs/steerable/cifar10.yaml +++ b/examples/images/segmentation/configs/original_configs/steerable/cifar10.yaml @@ -19,7 +19,7 @@ prediction: freeze_pretrained_encoder: 0 # Whether to freeze the pretrained encoder (1) or not (0) canonicalization: - network_type: 'escnn' # 
Options o canonization method 1) escnn + network_type: 'escnn' # Options o canonization method 1) escnn network_hyperparams: kernel_size: 3 # Kernel size for the canonization network out_channels: 16 # Number of output channels for the canonization network @@ -42,4 +42,4 @@ wandb: checkpoint: checkpoint_path: "/home/mila/a/arnab.mondal/scratch/equiadapt/image/checkpoints" # Path to save checkpoints deterministic: false # Whether to set deterministic mode (true) or not (false) - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/segmentation/configs/prediction/default.yaml b/examples/images/segmentation/configs/prediction/default.yaml index b53c191..1717eae 100644 --- a/examples/images/segmentation/configs/prediction/default.yaml +++ b/examples/images/segmentation/configs/prediction/default.yaml @@ -3,4 +3,3 @@ prediction_network_architecture_type: vit_h # Class of the prediction network use_pretrained: 1 # Whether to use pretrained weights (1) or not (0) freeze_encoder: 1 # Whether to freeze encoder (1) or not (0) pretrained_ckpt_path: "/home/mila/s/siba-smarak.panigrahi/scratch/sam_checkpoints/sam_vit_h_4b8939.pth" # must be set for Segment-Anything model - diff --git a/examples/images/segmentation/configs/wandb_sweep.yaml b/examples/images/segmentation/configs/wandb_sweep.yaml index 94e7a9e..78889c2 100644 --- a/examples/images/segmentation/configs/wandb_sweep.yaml +++ b/examples/images/segmentation/configs/wandb_sweep.yaml @@ -27,4 +27,4 @@ command: - ${env} - python3 - ${program} - - ${args_no_hyphens} \ No newline at end of file + - ${args_no_hyphens} diff --git a/examples/images/segmentation/inference_utils.py b/examples/images/segmentation/inference_utils.py index 25930a8..bab89e8 100644 --- a/examples/images/segmentation/inference_utils.py +++ b/examples/images/segmentation/inference_utils.py @@ -9,65 +9,65 @@ from equiadapt.images.utils import flip_boxes, flip_masks, rotate_boxes, rotate_masks -def get_inference_method(canonicalizer: torch.nn.Module, - prediction_network: torch.nn.Module, - inference_hyperparams: Union[Dict, wandb.Config], +def get_inference_method(canonicalizer: torch.nn.Module, + prediction_network: torch.nn.Module, + inference_hyperparams: Union[Dict, wandb.Config], in_shape: tuple = (3, 1024, 1024)): if inference_hyperparams.method == 'vanilla': return VanillaInference(canonicalizer, prediction_network) elif inference_hyperparams.method == 'group': return GroupInference( - canonicalizer, prediction_network, + canonicalizer, prediction_network, inference_hyperparams, in_shape ) else: raise ValueError(f'{inference_hyperparams.method} is not implemented for now.') class VanillaInference: - def __init__(self, - canonicalizer: torch.nn.Module, + def __init__(self, + canonicalizer: torch.nn.Module, prediction_network: torch.nn.Module) -> None: self.canonicalizer = canonicalizer self.prediction_network = prediction_network - + def forward(self, x, targets): # canonicalize the input data # For the vanilla model, the canonicalization is the identity transformation x_canonicalized, targets_canonicalized = self.canonicalizer(x, targets) - + # Forward pass through the prediction network as you'll normally do # Finetuning maskrcnn model will return the losses which can be used to fine tune the model - # Meanwhile, Segment-Anything (SAM) can return boxes, ious, masks predictions + # Meanwhile, Segment-Anything (SAM) 
can return boxes, ious, masks predictions # For uniformity, we will ensure the prediction network returns both losses and predictions irrespective of the model return self.prediction_network(x_canonicalized, targets_canonicalized) - + def get_inference_metrics(self, x: torch.Tensor, targets: torch.Tensor): # Forward pass through the prediction network - _, _, _, outputs = self.forward(x) - + _, _, _, outputs = self.forward(x) + _map = MeanAveragePrecision(iou_type='segm') targets = [dict(boxes=target['boxes'], labels=target['labels'], masks=target['masks']) for target in targets] outputs = [dict(boxes=output['boxes'], labels=output['labels'], scores=output['scores'], masks=output['masks']) for output in outputs] _map.update(outputs, targets) _map_dict = _map.compute() - + metrics = {'test/map': _map_dict['map']} - + return metrics - + class GroupInference(VanillaInference): - def __init__(self, - canonicalizer: torch.nn.Module, + def __init__(self, + canonicalizer: torch.nn.Module, prediction_network: torch.nn.Module, - inference_hyperparams: Union[Dict, wandb.Config], + inference_hyperparams: Union[Dict, wandb.Config], in_shape: tuple = (3, 32, 32)): - + super().__init__(canonicalizer, prediction_network) self.group_type = inference_hyperparams.group_type self.num_rotations = inference_hyperparams.num_rotations self.num_group_elements = self.num_rotations if self.group_type == 'rotation' else 2 * self.num_rotations self.pad = transforms.Pad( - math.ceil(in_shape[-2] * 0.4), + math.ceil(in_shape[-2] * 0.4), padding_mode='edge' ) self.crop = transforms.CenterCrop((in_shape[-2], in_shape[-1])) @@ -75,17 +75,17 @@ def __init__(self, def get_group_element_wise_maps(self, images: torch.Tensor, targets: torch.Tensor): map_dict = dict() image_width = images[0].shape[1] - + degrees = torch.linspace(0, 360, self.num_rotations + 1)[:-1] for rot, degree in enumerate(degrees): - + targets_transformed = copy.deepcopy(targets) - + # apply group element on images images_pad = self.pad(images) images_rot = transforms.functional.rotate(images_pad, degree.item()) images_rot = self.crop(images_rot) - + # apply group element on bounding boxes and masks for t in range(len(targets_transformed)): targets_transformed[t]["boxes"] = rotate_boxes(targets_transformed[t]["boxes"], -degree, image_width) @@ -93,12 +93,12 @@ def get_group_element_wise_maps(self, images: torch.Tensor, targets: torch.Tenso # get predictions for the transformed images _, _, _, outputs = self.forward(images_rot, targets_transformed) - + Map = MeanAveragePrecision(iou_type='segm') targets = [dict(boxes=target['boxes'], labels=target['labels'], masks=target['masks']) for target in targets] outputs = [dict(boxes=output['boxes'], labels=output['labels'], scores=output['scores'], masks=output['masks']) for output in outputs] Map.update(outputs, targets) - + map_dict[rot] = Map.compute() if self.group_type == 'roto-reflection': @@ -114,27 +114,27 @@ def get_group_element_wise_maps(self, images: torch.Tensor, targets: torch.Tenso for t in range(len(targets_transformed)): targets_transformed[t]["boxes"] = rotate_boxes(targets_transformed[t]["boxes"], -degree, image_width) targets_transformed[t]["boxes"] = flip_boxes(targets_transformed[t]["boxes"], image_width) - + targets_transformed[t]["masks"] = rotate_masks(targets_transformed[t]["masks"], degree) targets_transformed[t]["masks"] = flip_masks(targets_transformed[t]["masks"]) # get predictions for the transformed images _, _, _, outputs = self.forward(images_rotoreflect, targets_transformed) - + 
Map = MeanAveragePrecision(iou_type='segm') targets = [dict(boxes=target['boxes'], labels=target['labels'], masks=target['masks']) for target in targets] outputs = [dict(boxes=output['boxes'], labels=output['labels'], scores=output['scores'], masks=output['masks']) for output in outputs] Map.update(outputs, targets) - + map_dict[rot + len(degrees)] = Map.compute() - + return map_dict - + def get_inference_metrics(self, images: torch.Tensor, targets: torch.Tensor): metrics = {} - + map_dict = self.get_group_element_wise_maps(images, targets) - + # Use list comprehension to calculate accuracy for each group element for i in range(self.num_group_elements): metrics.update({ @@ -151,7 +151,7 @@ def get_inference_metrics(self, images: torch.Tensor, targets: torch.Tensor): f'test/mar_medium_group_element_{i}': max(map_dict[i]['mar_medium'], 0.0), f'test/mar_large_group_element_{i}': max(map_dict[i]['mar_large'], 0.0), }) - + map_per_group_element = torch.tensor([map_dict[i]['map'] for i in range(self.num_group_elements)]) metrics.update({"test/group_map": torch.mean(map_per_group_element)}) @@ -161,4 +161,3 @@ def get_inference_metrics(self, images: torch.Tensor, targets: torch.Tensor): metrics.update({"test/map": max(map_dict[0]['map'], 0.0)}) return metrics - \ No newline at end of file diff --git a/examples/images/segmentation/model.py b/examples/images/segmentation/model.py index 5c77199..bd2d57f 100644 --- a/examples/images/segmentation/model.py +++ b/examples/images/segmentation/model.py @@ -14,7 +14,7 @@ class ImageSegmentationPipeline(pl.LightningModule): def __init__(self, hyperparams: DictConfig): super().__init__() - + self.loss, self.image_shape, self.num_classes = get_dataset_specific_info(hyperparams.dataset.dataset_name, hyperparams.prediction.prediction_network_architecture) self.prediction_network = get_prediction_network( @@ -28,61 +28,61 @@ def __init__(self, hyperparams: DictConfig): ) canonicalization_network = get_canonicalization_network( - hyperparams.canonicalization_type, + hyperparams.canonicalization_type, hyperparams.canonicalization, self.image_shape, ) - + self.canonicalizer = get_canonicalizer( hyperparams.canonicalization_type, canonicalization_network, hyperparams.canonicalization, self.image_shape - ) - + ) + self.hyperparams = hyperparams - + self.inference_method = get_inference_method( self.canonicalizer, self.prediction_network, hyperparams.experiment.inference, self.image_shape ) - + self.max_epochs = hyperparams.experiment.training.num_epochs - + self.save_hyperparameters() - + def apply_loss(self, loss_dict: dict, pred_masks: torch.Tensor, targets_canonicalized: dict, iou_predictions: torch.Tensor = None): assert self.loss or loss_dict, "Either pass a loss function or a dictionary of pre-computed losses for segmentation task loss" - + if loss_dict: # for maskrcnn model, the loss_dict will contain the losses - return sum(loss_dict.values()) - + return sum(loss_dict.values()) + num_masks = sum(len(pred_mask) for pred_mask in pred_masks) - + loss_focal = torch.tensor(0., device=self.hyperparams.device) loss_dice = torch.tensor(0., device=self.hyperparams.device) loss_iou = torch.tensor(0., device=self.hyperparams.device) - + for pred_mask, target, iou_prediction in zip(pred_masks, targets_canonicalized, iou_predictions): - + # if gt_masks is larger then select the first len(pred_masks) masks - gt_mask = target['masks'][:len(pred_mask), :, :] - + gt_mask = target['masks'][:len(pred_mask), :, :] + for loss_func in self.loss: assert hasattr(loss_func, 
'forward'), "The loss function must have a forward method" if loss_func.name == 'focal_loss': loss_focal += loss_func(pred_mask, gt_mask.float(), num_masks) elif loss_func.name == 'dice_loss': loss_dice += loss_func(pred_mask, gt_mask, num_masks) else: raise ValueError(f"Loss function {loss_func.name} is not supported") - - + + if iou_predictions: batch_iou = calc_iou(pred_mask, gt_mask) loss_iou += torch.nn.functional.mse_loss(iou_prediction, batch_iou, reduction='sum') / num_masks - + return 20. * loss_focal + loss_dice + loss_iou @@ -90,89 +90,89 @@ def training_step(self, batch: torch.Tensor): x, targets = batch x = torch.stack(x) batch_size, num_channels, height, width = x.shape - + # assert that the input is in the right shape assert (num_channels, height, width) == self.image_shape training_metrics = {} loss = 0.0 - + # canonicalize the input data # For the vanilla model, the canonicalization is the identity transformation x_canonicalized, targets_canonicalized = self.canonicalizer(x, targets) - + # add group contrast loss while using optmization based canonicalization method if 'opt' in self.hyperparams.canonicalization_type: group_contrast_loss = self.canonicalizer.get_optimization_specific_loss() loss += group_contrast_loss * self.hyperparams.experiment.training.loss.group_contrast_weight training_metrics.update({"train/optimization_specific_loss": group_contrast_loss}) - + # calculate the task loss - # if finetuning is not required, set the weight for task loss to 0 + # if finetuning is not required, set the weight for task loss to 0 # it will avoid unnecessary forward pass through the prediction network - if self.hyperparams.experiment.training.loss.task_weight: - + if self.hyperparams.experiment.training.loss.task_weight: + # Forward pass through the prediction network as you'll normally do # Finetuning maskrcnn model will return the losses which can be used to fine tune the model - # Meanwhile, Segment-Anything (SAM) can return boxes, ious, masks predictions + # Meanwhile, Segment-Anything (SAM) can return boxes, ious, masks predictions # For uniformity, we will ensure the prediction network returns both losses and predictions irrespective of the model loss_dict, pred_masks, iou_predictions, _ = self.prediction_network(x_canonicalized, targets_canonicalized) - + # no requirement to invert canonicalization for the loss calculation # since we will compute the loss w.r.t canonicalized targets (to align with the loss computation in maskrcnn) task_loss = self.apply_loss(loss_dict, pred_masks, targets_canonicalized, iou_predictions) loss += self.hyperparams.experiment.training.loss.task_weight * task_loss - + training_metrics.update({ "train/task_loss": task_loss, }) - + # Add prior regularization loss if the prior weight is non-zero if self.hyperparams.experiment.training.loss.prior_weight: prior_loss = self.canonicalizer.get_prior_regularization_loss() loss += prior_loss * self.hyperparams.experiment.training.loss.prior_weight metric_identity = self.canonicalizer.get_identity_metric() training_metrics.update({ - "train/prior_loss": prior_loss, + "train/prior_loss": prior_loss, "train/identity_metric": metric_identity }) - + training_metrics.update({ "train/loss": loss, }) - + # Log the training metrics self.log_dict(training_metrics, prog_bar=True) - + assert not torch.isnan(loss), "Loss is NaN" return {'loss': loss} - + def validation_step(self, batch: torch.Tensor): x, targets = batch x = torch.stack(x) batch_size, num_channels, height, width = x.shape - + # assert that the 
input is in the right shape
         assert (num_channels, height, width) == self.image_shape
 
         validation_metrics = {}
-        
+
         # canonicalize the input data
         # For the vanilla model, the canonicalization is the identity transformation
         x_canonicalized, targets_canonicalized = self.canonicalizer(x, targets)
-        
+
         # Forward pass through the prediction network as you'll normally do
         # Finetuning maskrcnn model will return the losses which can be used to fine tune the model
-        # Meanwhile, Segment-Anything (SAM) can return boxes, ious, masks predictions 
+        # Meanwhile, Segment-Anything (SAM) can return boxes, ious, masks predictions
         # For uniformity, we will ensure the prediction network returns both losses and predictions irrespective of the model
         _, _, _, outputs = self.prediction_network(x_canonicalized, targets_canonicalized)
-        
+
         _map = MeanAveragePrecision(iou_type='segm')
         targets = [dict(boxes=target['boxes'], labels=target['labels'], masks=target['masks']) for target in targets]
         outputs = [dict(boxes=output['boxes'], labels=output['labels'], scores=output['scores'], masks=output['masks']) for output in outputs]
         _map.update(outputs, targets)
         _map_dict = _map.compute()
-        
+
         validation_metrics.update({
             'val/map': _map_dict['map'],
             'val/map_small': _map_dict['map_small'],
@@ -188,10 +188,10 @@ def validation_step(self, batch: torch.Tensor):
             'val/mar_large': _map_dict['mar_large'],
         })
-        
+
         # Log the validation metrics
         self.log_dict(validation_metrics, prog_bar=True)
-        
+
         # Log the identity metric if the prior weight is non-zero
         if self.hyperparams.experiment.training.loss.prior_weight:
             metric_identity = self.canonicalizer.get_identity_metric()
@@ -208,28 +208,28 @@ def test_step(self, batch: torch.Tensor):
         images, targets = batch
         images = torch.stack(images)
         batch_size, num_channels, height, width = images.shape
-        
+
         # assert that the input is in the right shape
         assert (num_channels, height, width) == self.image_shape
 
         test_metrics = self.inference_method.get_inference_metrics(images, targets)
-        
+
         # Log the test metrics
         self.log_dict(test_metrics, prog_bar=True)
-        
-        return test_metrics 
-    
+
+        return test_metrics
+
     def configure_optimizers(self):
         # using SGD optimizer and MultiStepLR scheduler
         optimizer = torch.optim.SGD(
             [
                 {'params': self.prediction_network.parameters(), 'lr': self.hyperparams.experiment.training.prediction_lr},
                 {'params': self.canonicalizer.parameters(), 'lr': self.hyperparams.experiment.training.canonicalization_lr},
-            ], 
+            ],
             momentum=0.9,
             weight_decay=5e-4,
         )
-        
+
         scheduler_dict = {
             "scheduler": MultiStepLR(
                 optimizer,
@@ -238,4 +238,4 @@ def configure_optimizers(self):
             ),
             "interval": "epoch",
         }
-        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
\ No newline at end of file
+        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
diff --git a/examples/images/segmentation/model_utils.py b/examples/images/segmentation/model_utils.py
index dc3d6dc..03efcf1 100644
--- a/examples/images/segmentation/model_utils.py
+++ b/examples/images/segmentation/model_utils.py
@@ -11,17 +11,17 @@
 GAMMA = 2
 
 
 class MaskRCNNModel(nn.Module):
-    def __init__(self, 
-                 architecture_type: str, 
+    def __init__(self,
+                 architecture_type: str,
                  pretrained_ckpt_path:str = None,
                  num_classes: int = 91,
                  weights:str ='DEFAULT'):
         super().__init__()
-        
+
         assert architecture_type in ['resnet50_fpn_v2'], NotImplementedError('Only `maskrcnn_resnet50_fpn_v2` is supported for now.')
         if architecture_type == 'resnet50_fpn_v2':
             self.model = maskrcnn_resnet50_fpn_v2(weights=weights)
-        
+
         if num_classes != 91:
             in_features = self.model.roi_heads.box_predictor.cls_score.in_features
             # replace the pre-trained head with a new one
@@ -63,8 +63,8 @@ def forward(self, images, targets):
 
 
 class SAMModel(nn.Module):
-    def __init__(self, 
-                 architecture_type: str, 
+    def __init__(self,
+                 architecture_type: str,
                  pretrained_ckpt_path:str =None,
                  num_classes: int = 91,
                  weights:str ='DEFAULT'):
@@ -115,7 +115,7 @@ def forward(self, images, targets):
             outputs.append(output)
 
         return None, pred_masks, ious, outputs
-    
+
 
 class FocalLoss(nn.Module):
     def __init__(self, weight=None, size_average=True):
@@ -153,7 +153,7 @@ def forward(self, inputs, targets, smooth=1):
         dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
         return 1 - dice
-    
+
 
 def get_dataset_specific_info(dataset_name, prediction_architecture_name):
     dataset_info = {
         'coco': {
@@ -188,7 +188,7 @@ def get_prediction_network(
         raise ValueError(f'{architecture} is not implemented as prediction network for now.')
 
     prediction_network = model_dict[architecture](architecture_type, pretrained_ckpt_path, num_classes, weights)
-    
+
     if freeze_encoder:
         for param in prediction_network.parameters():
             param.requires_grad = False
@@ -206,4 +206,4 @@ def calc_iou(pred_mask, gt_mask):
     union = torch.sum(pred_mask, dim=(1, 2)) + torch.sum(gt_mask, dim=(1, 2)) - intersection
     batch_iou = intersection / union
     batch_iou = batch_iou.unsqueeze(1)
-    return batch_iou
\ No newline at end of file
+    return batch_iou
diff --git a/examples/images/segmentation/prepare/__init__.py b/examples/images/segmentation/prepare/__init__.py
index dc33087..ab3e630 100644
--- a/examples/images/segmentation/prepare/__init__.py
+++ b/examples/images/segmentation/prepare/__init__.py
@@ -1 +1 @@
-from .coco_data import COCODataModule
\ No newline at end of file
+from .coco_data import COCODataModule
diff --git a/examples/images/segmentation/prepare/coco_data.py b/examples/images/segmentation/prepare/coco_data.py
index a153529..ec56cf9 100644
--- a/examples/images/segmentation/prepare/coco_data.py
+++ b/examples/images/segmentation/prepare/coco_data.py
@@ -51,7 +51,7 @@ class COCODataModule(pl.LightningDataModule):
     def __init__(self, hyperparams):
         super().__init__()
         self.hyperparams = hyperparams
-        
+
     def get_transform(self, train=True):
         tr = []
         tr.append(T.PILToTensor())
@@ -60,7 +60,7 @@ def get_transform(self, train=True):
         if train and self.hyperparams.augment == 'flip':
             tr.append(T.RandomHorizontalFlip(0.5))
         return T.Compose(tr)
-    
+
     def collate_fn(self, batch):
         images = [x[0] for x in batch]
         targets = [x[1] for x in batch]
@@ -85,7 +85,7 @@ def setup(self, stage=None):
             transform=self.get_transform(train=False)
         )
         print('Test dataset size: ', len(self.test_dataset))
-    
+
     def train_dataloader(self):
         train_loader = DataLoader(
             self.train_dataset,
@@ -169,4 +169,4 @@ def __getitem__(self, idx):
         if self.transform is not None:
             image, target = self.transform(image, target)
 
-        return image, target
\ No newline at end of file
+        return image, target
diff --git a/examples/images/segmentation/prepare/vision_transforms.py b/examples/images/segmentation/prepare/vision_transforms.py
index 7e993e0..b39a016 100644
--- a/examples/images/segmentation/prepare/vision_transforms.py
+++ b/examples/images/segmentation/prepare/vision_transforms.py
@@ -61,4 +61,4 @@ def forward(
         self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
     ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
         image = F.convert_image_dtype(image, self.dtype)
-        return image, target
\ No newline at end of file
+        return image, target
diff --git a/examples/images/segmentation/train.py b/examples/images/segmentation/train.py
index 3423600..0ee8b79 100644
--- a/examples/images/segmentation/train.py
+++ b/examples/images/segmentation/train.py
@@ -28,18 +28,18 @@ def train_images(hyperparams: DictConfig):
         print("Wandb disabled for logging...")
         os.environ["WANDB_MODE"] = "disabled"
     os.environ["WANDB_DIR"] = hyperparams['wandb']['wandb_dir']
-    os.environ["WANDB_CACHE_DIR"] = hyperparams['wandb']['wandb_cache_dir'] 
-    
+    os.environ["WANDB_CACHE_DIR"] = hyperparams['wandb']['wandb_cache_dir']
+
     # initialize wandb
     wandb.init(config=OmegaConf.to_container(hyperparams, resolve=True), entity=hyperparams['wandb']['wandb_entity'], project=hyperparams['wandb']['wandb_project'], dir=hyperparams['wandb']['wandb_dir'])
     wandb_logger = WandbLogger(project=hyperparams['wandb']['wandb_project'], log_model="all")
 
     # set seed
     pl.seed_everything(hyperparams.experiment.seed)
-    
+
     # get model, callbacks, and image data
     model, image_data, callbacks = get_model_data_and_callbacks(hyperparams)
-    
+
     if hyperparams.canonicalization_type in ("group_equivariant", "opt_equivariant", "steerable"):
         wandb.watch(model.canonicalizer.canonicalization_network, log='all')
 
@@ -48,7 +48,7 @@ def train_images(hyperparams: DictConfig):
 
     if hyperparams.experiment.run_mode == "train":
         trainer.fit(model, datamodule=image_data)
-    
+
     elif hyperparams.experiment.run_mode == "auto_tune":
         trainer.tune(model, datamodule=image_data)
 
@@ -62,4 +62,4 @@ def main(cfg: omegaconf.DictConfig):
     train_images(cfg)
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/examples/images/segmentation/train_utils.py b/examples/images/segmentation/train_utils.py
index 69653af..bf06c04 100644
--- a/examples/images/segmentation/train_utils.py
+++ b/examples/images/segmentation/train_utils.py
@@ -7,21 +7,21 @@
 from model import ImageSegmentationPipeline
 from prepare import COCODataModule
-    
+
 def get_model_data_and_callbacks(hyperparams : DictConfig):
-    
+
     # get image data
     image_data = get_image_data(hyperparams.dataset)
-    
+
     # checkpoint name
     hyperparams.checkpoint.checkpoint_name = get_checkpoint_name(hyperparams)
-    
+
     # checkpoint callbacks
     callbacks = get_callbacks(hyperparams)
 
-    # get model pipeline 
+    # get model pipeline
     model = get_model_pipeline(hyperparams)
-    
+
     return model, image_data, callbacks
 
 def get_model_pipeline(hyperparams: DictConfig):
@@ -36,7 +36,7 @@ def get_model_pipeline(hyperparams: DictConfig):
         model.eval()
     else:
         model = ImageSegmentationPipeline(hyperparams)
-    
+
     return model
 
 def get_trainer(
@@ -46,34 +46,34 @@ def get_trainer(
 ):
     if hyperparams.experiment.run_mode == "auto_tune":
         trainer = pl.Trainer(
-            max_epochs=hyperparams.experiment.num_epochs, accelerator="auto", 
-            auto_scale_batch_size=True, auto_lr_find=True, logger=wandb_logger, 
+            max_epochs=hyperparams.experiment.num_epochs, accelerator="auto",
+            auto_scale_batch_size=True, auto_lr_find=True, logger=wandb_logger,
             callbacks=callbacks, deterministic=hyperparams.experiment.deterministic,
-            num_nodes=hyperparams.experiment.num_nodes, devices=hyperparams.experiment.num_gpus, 
+            num_nodes=hyperparams.experiment.num_nodes, devices=hyperparams.experiment.num_gpus,
             strategy='ddp' if not hyperparams.experiment.training.loss.task_weight else 'ddp_find_unused_parameters_true'
         )
-        
+
     elif hyperparams.experiment.run_mode == "dryrun":
         trainer = pl.Trainer(
-            fast_dev_run=5, max_epochs=hyperparams.experiment.training.num_epochs, accelerator="auto", 
-            limit_train_batches=5, limit_val_batches=5, logger=wandb_logger, 
+            fast_dev_run=5, max_epochs=hyperparams.experiment.training.num_epochs, accelerator="auto",
+            limit_train_batches=5, limit_val_batches=5, logger=wandb_logger,
             callbacks=callbacks, deterministic=hyperparams.experiment.deterministic
         )
 
     else:
         trainer = pl.Trainer(
-            max_epochs=hyperparams.experiment.training.num_epochs, accelerator="auto", 
+            max_epochs=hyperparams.experiment.training.num_epochs, accelerator="auto",
             logger=wandb_logger, callbacks=callbacks, deterministic=hyperparams.experiment.deterministic,
-            num_nodes=hyperparams.experiment.num_nodes, devices=hyperparams.experiment.num_gpus, 
-            strategy='ddp' if not hyperparams.experiment.training.loss.task_weight else 'ddp_find_unused_parameters_true' 
+            num_nodes=hyperparams.experiment.num_nodes, devices=hyperparams.experiment.num_gpus,
+            strategy='ddp' if not hyperparams.experiment.training.loss.task_weight else 'ddp_find_unused_parameters_true'
             # since when you do a forward pass through the (large) prediction network (such as Segment-Anything Model)
             # there might be some unused parameters in the prediction network, so we need to set the strategy to ddp_find_unused_parameters_true
         )
     return trainer
-    
-    
+
+
 def get_callbacks(hyperparams: DictConfig):
-    
+
     checkpoint_callback = ModelCheckpoint(
         dirpath=hyperparams.checkpoint.checkpoint_path,
         filename=hyperparams.checkpoint.checkpoint_name,
@@ -81,12 +81,12 @@ def get_callbacks(hyperparams: DictConfig):
         mode="max",
         save_on_train_epoch_end=False,
     )
-    early_stop_metric_callback = EarlyStopping(monitor="val/map", 
-                                               min_delta=hyperparams.experiment.training.min_delta, 
-                                               patience=hyperparams.experiment.training.patience, 
-                                               verbose=True, 
+    early_stop_metric_callback = EarlyStopping(monitor="val/map",
+                                               min_delta=hyperparams.experiment.training.min_delta,
+                                               patience=hyperparams.experiment.training.patience,
+                                               verbose=True,
                                                mode="max")
-    
+
     return [checkpoint_callback, early_stop_metric_callback]
 
 def get_recursive_hyperparams_identifier(hyperparams: Dict):
@@ -99,22 +99,22 @@ def get_recursive_hyperparams_identifier(hyperparams: Dict):
         else:
             identifier += f"_{key}_{value}_"
     return identifier
-    
+
 def get_checkpoint_name(hyperparams : DictConfig):
-    
+
    return f"{get_recursive_hyperparams_identifier(hyperparams.canonicalization)}".lstrip("_") + \
                         f"__epochs_{hyperparams.experiment.training.num_epochs}_" + f"__seed_{hyperparams.experiment.seed}"
-    
+
 def get_image_data(dataset_hyperparams: DictConfig):
-    
+
     dataset_classes = {
         "coco": COCODataModule
     }
-    
+
     if dataset_hyperparams.dataset_name not in dataset_classes:
         raise ValueError(f"{dataset_hyperparams.dataset_name} not implemented")
-    
+
     return dataset_classes[dataset_hyperparams.dataset_name](dataset_hyperparams)
 
 def load_envs(env_file: Optional[str] = None) -> None:
@@ -127,4 +127,4 @@ def load_envs(env_file: Optional[str] = None) -> None:
     :param env_file: the file that defines the environment variables to use. If None
                      it searches for a `.env` file in the project.
     """
-    dotenv.load_dotenv(dotenv_path=env_file, override=True)
\ No newline at end of file
+    dotenv.load_dotenv(dotenv_path=env_file, override=True)