From f1635fabb6d8319847dbb3089961a0b19836c638 Mon Sep 17 00:00:00 2001 From: coincheung <867153576@qq.com> Date: Fri, 12 Feb 2021 07:56:19 +0000 Subject: [PATCH] upload --- .github/CODE_OF_CONDUCT.md | 76 +++ .github/CONTRIBUTING.md | 30 ++ .gitignore | 109 ++++ LICENSE | 399 ++++++++++++++ README.md | 125 +++++ detection/README.md | 61 +++ detection/configs/Base-RCNN-C4-BN.yaml | 17 + detection/configs/coco_R_50_C4_2x.yaml | 13 + detection/configs/coco_R_50_C4_2x_moco.yaml | 9 + detection/configs/pascal_voc_R_50_C4_24k.yaml | 16 + .../configs/pascal_voc_R_50_C4_24k_moco.yaml | 9 + detection/convert-pretrain-to-detectron2.py | 34 ++ detection/train_net.py | 75 +++ dist_train.sh | 11 + main_densecl.py | 417 +++++++++++++++ main_lincls.py | 502 ++++++++++++++++++ main_moco.py | 419 +++++++++++++++ moco/__init__.py | 1 + moco/builder.py | 205 +++++++ moco/loader.py | 27 + resnet.py | 360 +++++++++++++ 21 files changed, 2915 insertions(+) create mode 100644 .github/CODE_OF_CONDUCT.md create mode 100644 .github/CONTRIBUTING.md create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 detection/README.md create mode 100644 detection/configs/Base-RCNN-C4-BN.yaml create mode 100644 detection/configs/coco_R_50_C4_2x.yaml create mode 100644 detection/configs/coco_R_50_C4_2x_moco.yaml create mode 100644 detection/configs/pascal_voc_R_50_C4_24k.yaml create mode 100644 detection/configs/pascal_voc_R_50_C4_24k_moco.yaml create mode 100755 detection/convert-pretrain-to-detectron2.py create mode 100755 detection/train_net.py create mode 100644 dist_train.sh create mode 100755 main_densecl.py create mode 100755 main_lincls.py create mode 100755 main_moco.py create mode 100644 moco/__init__.py create mode 100644 moco/builder.py create mode 100644 moco/loader.py create mode 100644 resnet.py diff --git a/.github/CODE_OF_CONDUCT.md b/.github/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..4bd525a --- /dev/null +++ b/.github/CODE_OF_CONDUCT.md @@ -0,0 +1,76 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. 
+ +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md new file mode 100644 index 0000000..8a422bd --- /dev/null +++ b/.github/CONTRIBUTING.md @@ -0,0 +1,30 @@ +# Contributing to moco +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. +But note that this is a research project for reproducing results in a paper, +we may not accept PRs adding new features if they do not align with the goal of reproducing +results in the paper. + +If you haven't already, complete the Contributor License Agreement ("CLA") before sending a pull +request. 
+ +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to moco, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..33846b4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,109 @@ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + + +## Coin: +res/ +*pth diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..6b28d56 --- /dev/null +++ b/LICENSE @@ -0,0 +1,399 @@ +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. 
+ + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. 
For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. 
The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. 
For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. 
+ + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. diff --git a/README.md b/README.md new file mode 100644 index 0000000..3fd229d --- /dev/null +++ b/README.md @@ -0,0 +1,125 @@ +## DenseCL: Dense Contrastive Learning for Self-Supervised Visual Pre-Training + + +
+
+This is an unofficial PyTorch implementation of the [DenseCL paper](https://arxiv.org/abs/2011.09157).
+
+
+### Preparation
+
+Install PyTorch and the ImageNet dataset following the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet).
+
+This repo aims to add only minimal modifications to that code. Check the modifications by:
+```
+diff main_densecl.py <(curl https://raw.githubusercontent.com/pytorch/examples/master/imagenet/main.py)
+diff main_lincls.py <(curl https://raw.githubusercontent.com/pytorch/examples/master/imagenet/main.py)
+```
+
+
+### Unsupervised Training
+
+This implementation only supports **multi-gpu**, **DistributedDataParallel** training, which is faster and simpler; single-gpu or DataParallel training is not supported.
+
+To do unsupervised pre-training of a ResNet-50 model on ImageNet in an 8-gpu machine, run:
+```
+sh dist_train.sh [your imagenet-folder with train and val folders]
+```
+Since the paper says it uses the default MoCo v2 hyper-parameters, the above script uses the same hyper-parameters as MoCo v2.
+
+***Note***: for 4-gpu training, we recommend following the [linear lr scaling recipe](https://arxiv.org/abs/1706.02677): `--lr 0.015 --batch-size 128` with 4 gpus. We got similar results using this setting.
+
+
+### Linear Classification
+
+With a pre-trained model, to train a supervised linear classifier on frozen features/weights in an 8-gpu machine, run:
+```
+python main_lincls.py \
+  -a resnet50 \
+  --lr 30.0 \
+  --batch-size 256 \
+  --pretrained [your checkpoint path]/checkpoint_0199.pth.tar \
+  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
+  [your imagenet-folder with train and val folders]
+```
+
+Linear classification results on ImageNet using this repo with 8 NVIDIA V100 GPUs:
+
+| model     | pre-train epochs | pre-train time | MoCo v1 top-1 acc. | MoCo v2 top-1 acc. |
+|-----------|------------------|----------------|--------------------|--------------------|
+| ResNet-50 | 200              | 53 hours       | 60.8±0.2           | 67.5±0.1           |
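+
+The mean±std entries above aggregate 5 independent trials (the raw numbers are listed in the note below). As a quick sanity check, here is a minimal sketch of that aggregation, assuming the reported std is the population standard deviation:
+```
+import math
+
+# Per-trial linear-classification top-1 accuracies (from the note below).
+moco_v1 = [60.6, 60.6, 60.7, 60.9, 61.1]
+moco_v2 = [67.7, 67.6, 67.4, 67.6, 67.3]
+
+def mean_std(xs):
+    # mean and population standard deviation
+    m = sum(xs) / len(xs)
+    var = sum((x - m) ** 2 for x in xs) / len(xs)
+    return m, math.sqrt(var)
+
+for name, xs in [("MoCo v1", moco_v1), ("MoCo v2", moco_v2)]:
+    m, s = mean_std(xs)
+    print("{}: {:.1f}±{:.1f}".format(name, m, s))  # -> 60.8±0.2 and 67.5±0.1
+```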
+
+Here we run 5 trials (of pre-training and linear classification) and report mean±std: the 5 results of MoCo v1 are {60.6, 60.6, 60.7, 60.9, 61.1}, and those of MoCo v2 are {67.7, 67.6, 67.4, 67.6, 67.3}.
+
+
+### Models
+
+Our pre-trained ResNet-50 models can be downloaded as follows:
+
+|         | epochs | mlp | aug+ | cos | top-1 acc. | model    | md5      |
+|---------|--------|-----|------|-----|------------|----------|----------|
+| MoCo v1 | 200    |     |      |     | 60.6       | download | b251726a |
+| MoCo v2 | 200    | ✓   | ✓    | ✓   | 67.7       | download | 59fd9945 |
+| MoCo v2 | 800    | ✓   | ✓    | ✓   | 71.1       | download | a04e12f8 |
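+
+Before fine-tuning or linear evaluation, it can be useful to check a downloaded checkpoint against the md5 prefix above and to inspect its keys. A minimal sketch (the file name below is a placeholder for whichever checkpoint you downloaded):
+```
+import hashlib
+import torch
+
+path = "moco_v2_200ep_pretrain.pth.tar"  # placeholder path
+
+# The md5 of the file should start with the prefix listed in the table.
+with open(path, "rb") as f:
+    print(hashlib.md5(f.read()).hexdigest()[:8])
+
+# The checkpoint stores the full MoCo model; the backbone weights live under
+# the "module.encoder_q." prefix (main_lincls.py strips it the same way).
+state = torch.load(path, map_location="cpu")["state_dict"]
+backbone = {k[len("module.encoder_q."):]: v for k, v in state.items()
+            if k.startswith("module.encoder_q.")}
+print(len(backbone), "encoder_q tensors")
+```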
+
+
+### Transferring to Object Detection
+
+See [./detection](detection).
+
+
+### License
+
+This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details.
+
+### See Also
+* [moco.tensorflow](https://github.com/ppwwyyxx/moco.tensorflow): A TensorFlow re-implementation.
+* [Colab notebook](https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb): CIFAR demo on Colab GPU.
diff --git a/detection/README.md b/detection/README.md
new file mode 100644
index 0000000..d5bd80c
--- /dev/null
+++ b/detection/README.md
@@ -0,0 +1,61 @@
+
+## MoCo: Transferring to Detection
+
+The `train_net.py` script reproduces the object detection experiments on Pascal VOC and COCO.
+
+### Instruction
+
+1. Install [detectron2](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md).
+
+1. Convert a pre-trained MoCo model to detectron2's format:
+   ```
+   python3 convert-pretrain-to-detectron2.py input.pth.tar output.pkl
+   ```
+
+1. Put the dataset under the "./datasets" directory,
+   following the [directory structure](https://github.com/facebookresearch/detectron2/tree/master/datasets)
+   required by detectron2.
+
+1. Run training:
+   ```
+   python train_net.py --config-file configs/pascal_voc_R_50_C4_24k_moco.yaml \
+      --num-gpus 8 MODEL.WEIGHTS ./output.pkl
+   ```
+
+### Results
+
+Below are the results on Pascal VOC 2007 test, fine-tuned on 2007+2012 trainval for 24k iterations using Faster R-CNN with an R50-C4 backbone:
+
+| pretrain                    | AP50 | AP   | AP75 |
+|-----------------------------|------|------|------|
+| ImageNet-1M, supervised     | 81.3 | 53.5 | 58.8 |
+| ImageNet-1M, MoCo v1, 200ep | 81.5 | 55.9 | 62.6 |
+| ImageNet-1M, MoCo v2, 200ep | 82.4 | 57.0 | 63.6 |
+| ImageNet-1M, MoCo v2, 800ep | 82.5 | 57.4 | 64.0 |
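+
+After training, the same script can re-run evaluation from a saved checkpoint via detectron2's standard `--eval-only` flag; a sketch, assuming the final checkpoint landed at the default `./output/model_final.pth`:
+```
+python train_net.py --config-file configs/pascal_voc_R_50_C4_24k_moco.yaml \
+  --num-gpus 8 --eval-only MODEL.WEIGHTS ./output/model_final.pth
+```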
+ +***Note:*** These results are means of 5 trials. Variation on Pascal VOC is large: the std of AP50, AP, AP75 is expected to be 0.2, 0.2, 0.4 in most cases. We recommend to run 5 trials and compute means. diff --git a/detection/configs/Base-RCNN-C4-BN.yaml b/detection/configs/Base-RCNN-C4-BN.yaml new file mode 100644 index 0000000..5104c6a --- /dev/null +++ b/detection/configs/Base-RCNN-C4-BN.yaml @@ -0,0 +1,17 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + RPN: + PRE_NMS_TOPK_TEST: 6000 + POST_NMS_TOPK_TEST: 1000 + ROI_HEADS: + NAME: "Res5ROIHeadsExtraNorm" + BACKBONE: + FREEZE_AT: 0 + RESNETS: + NORM: "SyncBN" +TEST: + PRECISE_BN: + ENABLED: True +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 diff --git a/detection/configs/coco_R_50_C4_2x.yaml b/detection/configs/coco_R_50_C4_2x.yaml new file mode 100644 index 0000000..5b7e424 --- /dev/null +++ b/detection/configs/coco_R_50_C4_2x.yaml @@ -0,0 +1,13 @@ +_BASE_: "Base-RCNN-C4-BN.yaml" +MODEL: + MASK_ON: True + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) + MIN_SIZE_TEST: 800 +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + STEPS: (120000, 160000) + MAX_ITER: 180000 diff --git a/detection/configs/coco_R_50_C4_2x_moco.yaml b/detection/configs/coco_R_50_C4_2x_moco.yaml new file mode 100644 index 0000000..73ef270 --- /dev/null +++ b/detection/configs/coco_R_50_C4_2x_moco.yaml @@ -0,0 +1,9 @@ +_BASE_: "coco_R_50_C4_2x.yaml" +MODEL: + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + WEIGHTS: "See Instructions" + RESNETS: + STRIDE_IN_1X1: False +INPUT: + FORMAT: "RGB" diff --git a/detection/configs/pascal_voc_R_50_C4_24k.yaml b/detection/configs/pascal_voc_R_50_C4_24k.yaml new file mode 100644 index 0000000..a05eb5e --- /dev/null +++ b/detection/configs/pascal_voc_R_50_C4_24k.yaml @@ -0,0 +1,16 @@ +_BASE_: "Base-RCNN-C4-BN.yaml" +MODEL: + MASK_ON: False + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + ROI_HEADS: + NUM_CLASSES: 20 +INPUT: + MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) + MIN_SIZE_TEST: 800 +DATASETS: + TRAIN: ('voc_2007_trainval', 'voc_2012_trainval') + TEST: ('voc_2007_test',) +SOLVER: + STEPS: (18000, 22000) + MAX_ITER: 24000 + WARMUP_ITERS: 100 diff --git a/detection/configs/pascal_voc_R_50_C4_24k_moco.yaml b/detection/configs/pascal_voc_R_50_C4_24k_moco.yaml new file mode 100644 index 0000000..eebe690 --- /dev/null +++ b/detection/configs/pascal_voc_R_50_C4_24k_moco.yaml @@ -0,0 +1,9 @@ +_BASE_: "pascal_voc_R_50_C4_24k.yaml" +MODEL: + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + WEIGHTS: "See Instructions" + RESNETS: + STRIDE_IN_1X1: False +INPUT: + FORMAT: "RGB" diff --git a/detection/convert-pretrain-to-detectron2.py b/detection/convert-pretrain-to-detectron2.py new file mode 100755 index 0000000..b96ed91 --- /dev/null +++ b/detection/convert-pretrain-to-detectron2.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import pickle as pkl +import sys +import torch + +if __name__ == "__main__": + input = sys.argv[1] + + obj = torch.load(input, map_location="cpu") + obj = obj["state_dict"] + + newmodel = {} + for k, v in obj.items(): + if not k.startswith("module.encoder_q."): + continue + old_k = k + k = k.replace("module.encoder_q.", "") + if "layer" not in k: + k = "stem." 
+ k + for t in [1, 2, 3, 4]: + k = k.replace("layer{}".format(t), "res{}".format(t + 1)) + for t in [1, 2, 3]: + k = k.replace("bn{}".format(t), "conv{}.norm".format(t)) + k = k.replace("downsample.0", "shortcut") + k = k.replace("downsample.1", "shortcut.norm") + print(old_k, "->", k) + newmodel[k] = v.numpy() + + res = {"model": newmodel, "__author__": "MOCO", "matching_heuristics": True} + + with open(sys.argv[2], "wb") as f: + pkl.dump(res, f) diff --git a/detection/train_net.py b/detection/train_net.py new file mode 100755 index 0000000..39e844c --- /dev/null +++ b/detection/train_net.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import os + +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import get_cfg +from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch +from detectron2.evaluation import COCOEvaluator, PascalVOCDetectionEvaluator +from detectron2.layers import get_norm +from detectron2.modeling.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads + + +@ROI_HEADS_REGISTRY.register() +class Res5ROIHeadsExtraNorm(Res5ROIHeads): + """ + As described in the MOCO paper, there is an extra BN layer + following the res5 stage. + """ + def _build_res5_block(self, cfg): + seq, out_channels = super()._build_res5_block(cfg) + norm = cfg.MODEL.RESNETS.NORM + norm = get_norm(norm, out_channels) + seq.add_module("norm", norm) + return seq, out_channels + + +class Trainer(DefaultTrainer): + @classmethod + def build_evaluator(cls, cfg, dataset_name, output_folder=None): + if output_folder is None: + output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") + if "coco" in dataset_name: + return COCOEvaluator(dataset_name, cfg, True, output_folder) + else: + assert "voc" in dataset_name + return PascalVOCDetectionEvaluator(dataset_name) + + +def setup(args): + cfg = get_cfg() + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.freeze() + default_setup(cfg, args) + return cfg + + +def main(args): + cfg = setup(args) + + if args.eval_only: + model = Trainer.build_model(cfg) + DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( + cfg.MODEL.WEIGHTS, resume=args.resume + ) + res = Trainer.test(cfg, model) + return res + + trainer = Trainer(cfg) + trainer.resume_or_load(resume=args.resume) + return trainer.train() + + +if __name__ == "__main__": + args = default_argument_parser().parse_args() + print("Command Line Args:", args) + launch( + main, + args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + args=(args,), + ) diff --git a/dist_train.sh b/dist_train.sh new file mode 100644 index 0000000..a2afc23 --- /dev/null +++ b/dist_train.sh @@ -0,0 +1,11 @@ + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + + +# for pretrain +# EPOCHS=200 +# python main_moco.py -a resnet50 --lr 0.03 --batch-size 256 --epochs $EPOCHS --world-size 1 --rank 0 --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --use-mixed-precision --mlp --moco-t 0.2 --aug-plus --cos $1 + +PRETRAINED=res/r50_org_v2/checkpoint_0199.pth.tar +python main_lincls.py -a resnet50 --lr 30.0 --batch-size 256 --pretrained $PRETRAINED --dist-url 'tcp://127.0.0.1:10002' --multiprocessing-distributed --world-size 1 --rank 0 /data/zzy/.datasets/imagenet/ + diff --git a/main_densecl.py b/main_densecl.py new file mode 100755 index 0000000..80b2922 --- /dev/null +++ b/main_densecl.py @@ -0,0 +1,417 @@ 
+#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import argparse +import builtins +import math +import os +import random +import shutil +import time +import warnings + +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.optim +import torch.multiprocessing as mp +import torch.utils.data +import torch.utils.data.distributed +import torch.cuda.amp as amp +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import torchvision.models as models + +import moco.loader +import moco.builder + +model_names = sorted(name for name in models.__dict__ + if name.islower() and not name.startswith("__") + and callable(models.__dict__[name])) + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', + choices=model_names, + help='model architecture: ' + + ' | '.join(model_names) + + ' (default: resnet50)') +parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', + help='number of data loading workers (default: 32)') +parser.add_argument('--epochs', default=200, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', + help='mini-batch size (default: 256), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel') +parser.add_argument('--lr', '--learning-rate', default=0.03, type=float, + metavar='LR', help='initial learning rate', dest='lr') +parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int, + help='learning rate schedule (when to drop lr by 10x)') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum of SGD solver') +parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') +parser.add_argument('-p', '--print-freq', default=100, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('--world-size', default=-1, type=int, + help='number of nodes for distributed training') +parser.add_argument('--rank', default=-1, type=int, + help='node rank for distributed training') +parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, + help='url used to set up distributed training') +parser.add_argument('--dist-backend', default='nccl', type=str, + help='distributed backend') +parser.add_argument('--seed', default=None, type=int, + help='seed for initializing training. ') +parser.add_argument('--gpu', default=None, type=int, + help='GPU id to use.') +parser.add_argument('--multiprocessing-distributed', action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. 
This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') +parser.add_argument('--use-mixed-precision', action='store_true', + help='use mlp head') + +# moco specific configs: +parser.add_argument('--moco-dim', default=128, type=int, + help='feature dimension (default: 128)') +parser.add_argument('--moco-k', default=65536, type=int, + help='queue size; number of negative keys (default: 65536)') +parser.add_argument('--moco-m', default=0.999, type=float, + help='moco momentum of updating key encoder (default: 0.999)') +parser.add_argument('--moco-t', default=0.07, type=float, + help='softmax temperature (default: 0.07)') + +# options for moco v2 +parser.add_argument('--mlp', action='store_true', + help='use mlp head') +parser.add_argument('--aug-plus', action='store_true', + help='use moco v2 data augmentation') +parser.add_argument('--cos', action='store_true', + help='use cosine lr schedule') + + +def main(): + args = parser.parse_args() + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + cudnn.deterministic = True + warnings.warn('You have chosen to seed training. ' + 'This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! ' + 'You may see unexpected behavior when restarting ' + 'from checkpoints.') + + if args.gpu is not None: + warnings.warn('You have chosen a specific GPU. This will completely ' + 'disable data parallelism.') + + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + + args.distributed = args.world_size > 1 or args.multiprocessing_distributed + + ngpus_per_node = torch.cuda.device_count() + if args.multiprocessing_distributed: + # Since we have ngpus_per_node processes per node, the total world_size + # needs to be adjusted accordingly + args.world_size = ngpus_per_node * args.world_size + # Use torch.multiprocessing.spawn to launch distributed processes: the + # main_worker process function + mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) + else: + # Simply call main_worker function + main_worker(args.gpu, ngpus_per_node, args) + + +def main_worker(gpu, ngpus_per_node, args): + args.gpu = gpu + + # suppress printing if not master + if args.multiprocessing_distributed and args.gpu != 0: + def print_pass(*args): + pass + builtins.print = print_pass + + if args.gpu is not None: + print("Use GPU: {} for training".format(args.gpu)) + + if args.distributed: + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + if args.multiprocessing_distributed: + # For multiprocessing distributed training, rank needs to be the + # global rank among all the processes + args.rank = args.rank * ngpus_per_node + gpu + dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + # create model + print("=> creating model '{}'".format(args.arch)) + model = moco.builder.MoCo( + models.__dict__[args.arch], + args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.mlp) + print(model) + + if args.distributed: + # For multiprocessing distributed, DistributedDataParallel constructor + # should always set the single device scope, otherwise, + # DistributedDataParallel will use all available devices. 
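+        # With one GPU per process (the setup used here), the global batch and
+        # the data-loading workers are split across ranks below: with the
+        # defaults on an 8-gpu node, 256 / 8 = 32 images and
+        # (32 + 7) // 8 = 4 workers per process.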
+ if args.gpu is not None: + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have + args.batch_size = int(args.batch_size / ngpus_per_node) + args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + else: + model.cuda() + # DistributedDataParallel will divide and allocate batch_size to all + # available GPUs if device_ids are not set + model = torch.nn.parallel.DistributedDataParallel(model) + elif args.gpu is not None: + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + # comment out the following line for debugging + raise NotImplementedError("Only DistributedDataParallel is supported.") + else: + # AllGather implementation (batch shuffle, queue update, etc.) in + # this code only supports DistributedDataParallel. + raise NotImplementedError("Only DistributedDataParallel is supported.") + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().cuda(args.gpu) + + optimizer = torch.optim.SGD(model.parameters(), args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + if args.gpu is None: + checkpoint = torch.load(args.resume) + else: + # Map model to be loaded to specified single gpu. + loc = 'cuda:{}'.format(args.gpu) + checkpoint = torch.load(args.resume, map_location=loc) + args.start_epoch = checkpoint['epoch'] + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + + # Data loading code + traindir = os.path.join(args.data, 'train') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + if args.aug_plus: + # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709 + augmentation = [ + transforms.RandomResizedCrop(224, scale=(0.2, 1.)), + transforms.RandomApply([ + transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened + ], p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=0.5), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize + ] + else: + # MoCo v1's aug: the same as InstDisc https://arxiv.org/abs/1805.01978 + augmentation = [ + transforms.RandomResizedCrop(224, scale=(0.2, 1.)), + transforms.RandomGrayscale(p=0.2), + transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize + ] + + train_dataset = datasets.ImageFolder( + traindir, + moco.loader.TwoCropsTransform(transforms.Compose(augmentation))) + + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + train_sampler = None + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) + + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, 
epoch, args) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, args) + + n_ckpt_period = 20 + if not args.multiprocessing_distributed or (args.multiprocessing_distributed + and args.rank % ngpus_per_node == 0) and ((epoch + 1) % n_ckpt_period == 0): + if epoch > args.epochs - 10: n_ckpt_period = 1 + save_checkpoint({ + 'epoch': epoch + 1, + 'arch': args.arch, + 'state_dict': model.state_dict(), + 'optimizer' : optimizer.state_dict(), + }, is_best=False, filename='checkpoint_{:04d}.pth.tar'.format(epoch)) + + +def train(train_loader, model, criterion, optimizer, epoch, args): + batch_time = AverageMeter('Time', ':6.3f') + data_time = AverageMeter('Data', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter( + len(train_loader), + [batch_time, data_time, losses, top1, top5], + prefix="Epoch: [{}]".format(epoch)) + + # amp scalar + scaler = amp.GradScaler() + + # switch to train mode + model.train() + + end = time.time() + for i, (images, _) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + if args.gpu is not None: + images[0] = images[0].cuda(args.gpu, non_blocking=True) + images[1] = images[1].cuda(args.gpu, non_blocking=True) + + # compute output + with amp.autocast(enabled=args.use_mixed_precision): + output, target, output_dense, target_dense = model( + im_q=images[0], im_k=images[1]) + loss = criterion(output, target) + loss_dense = criterion(output_dense, target_dense) + loss = 0.5 * (loss + loss_dense) + + # acc1/acc5 are (K+1)-way contrast classifier accuracy + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images[0].size(0)) + top1.update(acc1[0], images[0].size(0)) + top5.update(acc5[0], images[0].size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + # loss.backward() + # optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best.pth.tar') + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print('\t'.join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = '{:' + str(num_digits) + 'd}' + return '[' + fmt + '/' + fmt.format(num_batches) + ']' + + +def adjust_learning_rate(optimizer, epoch, args): + """Decay the learning rate based on schedule""" + 
lr = args.lr + if args.cos: # cosine lr schedule + lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) + else: # stepwise lr schedule + for milestone in args.schedule: + lr *= 0.1 if epoch >= milestone else 1. + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == '__main__': + main() diff --git a/main_lincls.py b/main_lincls.py new file mode 100755 index 0000000..cb9d516 --- /dev/null +++ b/main_lincls.py @@ -0,0 +1,502 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import argparse +import builtins +import os +import random +import shutil +import time +import warnings + +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.optim +import torch.multiprocessing as mp +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import torchvision.models as models + +model_names = sorted(name for name in models.__dict__ + if name.islower() and not name.startswith("__") + and callable(models.__dict__[name])) + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', + choices=model_names, + help='model architecture: ' + + ' | '.join(model_names) + + ' (default: resnet50)') +parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', + help='number of data loading workers (default: 32)') +parser.add_argument('--epochs', default=100, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', + help='mini-batch size (default: 256), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel') +parser.add_argument('--lr', '--learning-rate', default=30., type=float, + metavar='LR', help='initial learning rate', dest='lr') +parser.add_argument('--schedule', default=[60, 80], nargs='*', type=int, + help='learning rate schedule (when to drop lr by a ratio)') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') +parser.add_argument('--wd', '--weight-decay', default=0., type=float, + metavar='W', help='weight decay (default: 0.)', + dest='weight_decay') +parser.add_argument('-p', '--print-freq', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', + help='evaluate model on validation set') +parser.add_argument('--world-size', default=-1, type=int, + help='number 
of nodes for distributed training') +parser.add_argument('--rank', default=-1, type=int, + help='node rank for distributed training') +parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, + help='url used to set up distributed training') +parser.add_argument('--dist-backend', default='nccl', type=str, + help='distributed backend') +parser.add_argument('--seed', default=None, type=int, + help='seed for initializing training. ') +parser.add_argument('--gpu', default=None, type=int, + help='GPU id to use.') +parser.add_argument('--multiprocessing-distributed', action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') + +parser.add_argument('--pretrained', default='', type=str, + help='path to moco pretrained checkpoint') + +best_acc1 = 0 + + +def main(): + args = parser.parse_args() + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + cudnn.deterministic = True + warnings.warn('You have chosen to seed training. ' + 'This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! ' + 'You may see unexpected behavior when restarting ' + 'from checkpoints.') + + if args.gpu is not None: + warnings.warn('You have chosen a specific GPU. This will completely ' + 'disable data parallelism.') + + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + + args.distributed = args.world_size > 1 or args.multiprocessing_distributed + + ngpus_per_node = torch.cuda.device_count() + if args.multiprocessing_distributed: + # Since we have ngpus_per_node processes per node, the total world_size + # needs to be adjusted accordingly + args.world_size = ngpus_per_node * args.world_size + # Use torch.multiprocessing.spawn to launch distributed processes: the + # main_worker process function + mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) + else: + # Simply call main_worker function + main_worker(args.gpu, ngpus_per_node, args) + + +def main_worker(gpu, ngpus_per_node, args): + global best_acc1 + args.gpu = gpu + + # suppress printing if not master + if args.multiprocessing_distributed and args.gpu != 0: + def print_pass(*args): + pass + builtins.print = print_pass + + if args.gpu is not None: + print("Use GPU: {} for training".format(args.gpu)) + + if args.distributed: + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + if args.multiprocessing_distributed: + # For multiprocessing distributed training, rank needs to be the + # global rank among all the processes + args.rank = args.rank * ngpus_per_node + gpu + dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + # create model + print("=> creating model '{}'".format(args.arch)) + model = models.__dict__[args.arch]() + + # freeze all layers but the last fc + for name, param in model.named_parameters(): + if name not in ['fc.weight', 'fc.bias']: + param.requires_grad = False + # init the fc layer + model.fc.weight.data.normal_(mean=0.0, std=0.01) + model.fc.bias.data.zero_() + + # load from pre-trained, before DistributedDataParallel constructor + if args.pretrained: + if os.path.isfile(args.pretrained): + print("=> loading checkpoint '{}'".format(args.pretrained)) + checkpoint = 
torch.load(args.pretrained, map_location="cpu") + + # rename moco pre-trained keys + state_dict = checkpoint['state_dict'] + for k in list(state_dict.keys()): + # retain only encoder_q up to before the embedding layer + if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): + # remove prefix + state_dict[k[len("module.encoder_q."):]] = state_dict[k] + # delete renamed or unused k + del state_dict[k] + + args.start_epoch = 0 + msg = model.load_state_dict(state_dict, strict=False) + assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} + + print("=> loaded pre-trained model '{}'".format(args.pretrained)) + else: + print("=> no checkpoint found at '{}'".format(args.pretrained)) + + if args.distributed: + # For multiprocessing distributed, DistributedDataParallel constructor + # should always set the single device scope, otherwise, + # DistributedDataParallel will use all available devices. + if args.gpu is not None: + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have + args.batch_size = int(args.batch_size / ngpus_per_node) + args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + else: + model.cuda() + # DistributedDataParallel will divide and allocate batch_size to all + # available GPUs if device_ids are not set + model = torch.nn.parallel.DistributedDataParallel(model) + elif args.gpu is not None: + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + else: + # DataParallel will divide and allocate batch_size to all available GPUs + if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): + model.features = torch.nn.DataParallel(model.features) + model.cuda() + else: + model = torch.nn.DataParallel(model).cuda() + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().cuda(args.gpu) + + # optimize only the linear classifier + parameters = list(filter(lambda p: p.requires_grad, model.parameters())) + assert len(parameters) == 2 # fc.weight, fc.bias + optimizer = torch.optim.SGD(parameters, args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + if args.gpu is None: + checkpoint = torch.load(args.resume) + else: + # Map model to be loaded to specified single gpu. 
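+                # torch.load restores tensors to the device they were saved
+                # from by default; mapping to this process's own GPU avoids
+                # allocating memory on a device owned by another worker.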
+ loc = 'cuda:{}'.format(args.gpu) + checkpoint = torch.load(args.resume, map_location=loc) + args.start_epoch = checkpoint['epoch'] + best_acc1 = checkpoint['best_acc1'] + if args.gpu is not None: + # best_acc1 may be from a checkpoint from a different GPU + best_acc1 = best_acc1.to(args.gpu) + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + + # Data loading code + traindir = os.path.join(args.data, 'train') + valdir = os.path.join(args.data, 'val') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + train_dataset = datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + train_sampler = None + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=train_sampler) + + val_loader = torch.utils.data.DataLoader( + datasets.ImageFolder(valdir, transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])), + batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True) + + if args.evaluate: + validate(val_loader, model, criterion, args) + return + + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, epoch, args) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, args) + + # evaluate on validation set + acc1 = validate(val_loader, model, criterion, args) + + # remember best acc@1 and save checkpoint + is_best = acc1 > best_acc1 + best_acc1 = max(acc1, best_acc1) + + if not args.multiprocessing_distributed or (args.multiprocessing_distributed + and args.rank % ngpus_per_node == 0): + save_checkpoint({ + 'epoch': epoch + 1, + 'arch': args.arch, + 'state_dict': model.state_dict(), + 'best_acc1': best_acc1, + 'optimizer' : optimizer.state_dict(), + }, is_best) + if epoch == args.start_epoch: + sanity_check(model.state_dict(), args.pretrained) + + +def train(train_loader, model, criterion, optimizer, epoch, args): + batch_time = AverageMeter('Time', ':6.3f') + data_time = AverageMeter('Data', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter( + len(train_loader), + [batch_time, data_time, losses, top1, top5], + prefix="Epoch: [{}]".format(epoch)) + + """ + Switch to eval mode: + Under the protocol of linear classification on frozen features/models, + it is not legitimate to change any part of the pre-trained model. + BatchNorm in train mode may revise running mean/std (even if it receives + no gradient), which are part of the model parameters too. 
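+    Note that everything except the fc layer was frozen above with
+    requires_grad=False, so eval mode here mainly keeps the BatchNorm
+    running statistics from being updated by the training batches.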
+ """ + model.eval() + + end = time.time() + for i, (images, target) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + target = target.cuda(args.gpu, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + +def validate(val_loader, model, criterion, args): + batch_time = AverageMeter('Time', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter( + len(val_loader), + [batch_time, losses, top1, top5], + prefix='Test: ') + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + end = time.time() + for i, (images, target) in enumerate(val_loader): + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + target = target.cuda(args.gpu, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + # TODO: this should also be done with the ProgressMeter + print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' + .format(top1=top1, top5=top5)) + + return top1.avg + + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best.pth.tar') + + +def sanity_check(state_dict, pretrained_weights): + """ + Linear classifier should not change any weights other than the linear layer. + This sanity check asserts nothing wrong happens (e.g., BN stats updated). + """ + print("=> loading '{}' for sanity check".format(pretrained_weights)) + checkpoint = torch.load(pretrained_weights, map_location="cpu") + state_dict_pre = checkpoint['state_dict'] + + for k in list(state_dict.keys()): + # only ignore fc layer + if 'fc.weight' in k or 'fc.bias' in k: + continue + + # name in pretrained model + k_pre = 'module.encoder_q.' + k[len('module.'):] \ + if k.startswith('module.') else 'module.encoder_q.' 
+ k
+
+        assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \
+            '{} is changed in linear classifier training.'.format(k)
+
+    print("=> sanity check passed.")
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self, name, fmt=':f'):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
+
+class ProgressMeter(object):
+    def __init__(self, num_batches, meters, prefix=""):
+        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+        self.meters = meters
+        self.prefix = prefix
+
+    def display(self, batch):
+        entries = [self.prefix + self.batch_fmtstr.format(batch)]
+        entries += [str(meter) for meter in self.meters]
+        print('\t'.join(entries))
+
+    def _get_batch_fmtstr(self, num_batches):
+        num_digits = len(str(num_batches // 1))
+        fmt = '{:' + str(num_digits) + 'd}'
+        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+
+def adjust_learning_rate(optimizer, epoch, args):
+    """Decay the learning rate based on schedule"""
+    lr = args.lr
+    for milestone in args.schedule:
+        lr *= 0.1 if epoch >= milestone else 1.
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k"""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+        res = []
+        for k in topk:
+            # correct is non-contiguous after t(); view(-1) needs contiguous memory
+            correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
+            res.append(correct_k.mul_(100.0 / batch_size))
+        return res
+
+
+if __name__ == '__main__':
+    main()
diff --git a/main_moco.py b/main_moco.py
new file mode 100755
index 0000000..f9c3d82
--- /dev/null
+++ b/main_moco.py
@@ -0,0 +1,419 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import argparse
+import builtins
+import math
+import os
+import random
+import shutil
+import time
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import torch.optim
+import torch.multiprocessing as mp
+import torch.utils.data
+import torch.utils.data.distributed
+import torch.cuda.amp as amp
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+import torchvision.models as models
+
+import moco.loader
+import moco.builder
+
+from resnet import ResNet
+
+model_names = sorted(name for name in models.__dict__
+    if name.islower() and not name.startswith("__")
+    and callable(models.__dict__[name]))
+
+parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
+parser.add_argument('data', metavar='DIR',
+                    help='path to dataset')
+parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
+                    choices=model_names,
+                    help='model architecture: ' +
+                        ' | '.join(model_names) +
+                        ' (default: resnet50)')
+parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
+                    help='number of data loading workers (default: 32)')
+parser.add_argument('--epochs', default=200, type=int, metavar='N',
+                    help='number of total epochs to run')
+parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+                    help='manual epoch number (useful on restarts)')
+parser.add_argument('-b', '--batch-size', default=256, type=int,
+                    metavar='N',
+                    help='mini-batch size (default: 256), this is the total '
+                         'batch size of all GPUs on the current node when '
+                         'using Data Parallel or Distributed Data Parallel')
+parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
+                    metavar='LR', help='initial learning rate', dest='lr')
+parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int,
+                    help='learning rate schedule (when to drop lr by 10x)')
+parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+                    help='momentum of SGD solver')
+parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+                    metavar='W', help='weight decay (default: 1e-4)',
+                    dest='weight_decay')
+parser.add_argument('-p', '--print-freq', default=100, type=int,
+                    metavar='N', help='print frequency (default: 100)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='path to latest checkpoint (default: none)')
+parser.add_argument('--world-size', default=-1, type=int,
+                    help='number of nodes for distributed training')
+parser.add_argument('--rank', default=-1, type=int,
+                    help='node rank for distributed training')
+parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
+                    help='url used to set up distributed training')
+parser.add_argument('--dist-backend', default='nccl', type=str,
+                    help='distributed backend')
+parser.add_argument('--seed', default=None, type=int,
+                    help='seed for initializing training. ')
+parser.add_argument('--gpu', default=None, type=int,
+                    help='GPU id to use.')
+parser.add_argument('--multiprocessing-distributed', action='store_true',
+                    help='Use multi-processing distributed training to launch '
+                         'N processes per node, which has N GPUs. This is the '
+                         'fastest way to use PyTorch for either single node or '
+                         'multi node data parallel training')
+parser.add_argument('--use-mixed-precision', action='store_true',
+                    help='use torch.cuda.amp mixed precision training')
+
+# moco specific configs:
+parser.add_argument('--moco-dim', default=128, type=int,
+                    help='feature dimension (default: 128)')
+parser.add_argument('--moco-k', default=65536, type=int,
+                    help='queue size; number of negative keys (default: 65536)')
+parser.add_argument('--moco-m', default=0.999, type=float,
+                    help='moco momentum of updating key encoder (default: 0.999)')
+parser.add_argument('--moco-t', default=0.07, type=float,
+                    help='softmax temperature (default: 0.07)')
+
+# options for moco v2
+parser.add_argument('--mlp', action='store_true',
+                    help='use mlp head')
+parser.add_argument('--aug-plus', action='store_true',
+                    help='use moco v2 data augmentation')
+parser.add_argument('--cos', action='store_true',
+                    help='use cosine lr schedule')
+
+
+def main():
+    args = parser.parse_args()
+
+    if args.seed is not None:
+        random.seed(args.seed)
+        torch.manual_seed(args.seed)
+        cudnn.deterministic = True
+        warnings.warn('You have chosen to seed training. '
+                      'This will turn on the CUDNN deterministic setting, '
+                      'which can slow down your training considerably! '
+                      'You may see unexpected behavior when restarting '
+                      'from checkpoints.')
+
+    if args.gpu is not None:
+        warnings.warn('You have chosen a specific GPU. This will completely '
+                      'disable data parallelism.')
+
+    if args.dist_url == "env://" and args.world_size == -1:
+        args.world_size = int(os.environ["WORLD_SIZE"])
+
+    args.distributed = args.world_size > 1 or args.multiprocessing_distributed
+
+    ngpus_per_node = torch.cuda.device_count()
+    if args.multiprocessing_distributed:
+        # Since we have ngpus_per_node processes per node, the total world_size
+        # needs to be adjusted accordingly
+        args.world_size = ngpus_per_node * args.world_size
+        # Use torch.multiprocessing.spawn to launch distributed processes: the
+        # main_worker process function
+        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
+    else:
+        # Simply call main_worker function
+        main_worker(args.gpu, ngpus_per_node, args)
+
+
+def main_worker(gpu, ngpus_per_node, args):
+    args.gpu = gpu
+
+    # suppress printing if not master
+    if args.multiprocessing_distributed and args.gpu != 0:
+        def print_pass(*args):
+            pass
+        builtins.print = print_pass
+
+    if args.gpu is not None:
+        print("Use GPU: {} for training".format(args.gpu))
+
+    if args.distributed:
+        if args.dist_url == "env://" and args.rank == -1:
+            args.rank = int(os.environ["RANK"])
+        if args.multiprocessing_distributed:
+            # For multiprocessing distributed training, rank needs to be the
+            # global rank among all the processes
+            args.rank = args.rank * ngpus_per_node + gpu
+        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                world_size=args.world_size, rank=args.rank)
+    # create model
+    print("=> creating model '{}'".format(args.arch))
+    model = moco.builder.MoCo(
+        models.__dict__[args.arch],
+        args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.mlp)
+    print(model)
+
+    if args.distributed:
+        # For multiprocessing distributed, DistributedDataParallel constructor
+        # should always set the single device scope, otherwise,
+        # DistributedDataParallel will use all available devices.
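+        # The command-line batch size and worker count are per-node totals;
+        # with one process per GPU they are divided by ngpus_per_node below.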
+ if args.gpu is not None: + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have + args.batch_size = int(args.batch_size / ngpus_per_node) + args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + else: + model.cuda() + # DistributedDataParallel will divide and allocate batch_size to all + # available GPUs if device_ids are not set + model = torch.nn.parallel.DistributedDataParallel(model) + elif args.gpu is not None: + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + # comment out the following line for debugging + raise NotImplementedError("Only DistributedDataParallel is supported.") + else: + # AllGather implementation (batch shuffle, queue update, etc.) in + # this code only supports DistributedDataParallel. + raise NotImplementedError("Only DistributedDataParallel is supported.") + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().cuda(args.gpu) + + optimizer = torch.optim.SGD(model.parameters(), args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + if args.gpu is None: + checkpoint = torch.load(args.resume) + else: + # Map model to be loaded to specified single gpu. + loc = 'cuda:{}'.format(args.gpu) + checkpoint = torch.load(args.resume, map_location=loc) + args.start_epoch = checkpoint['epoch'] + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + + # Data loading code + traindir = os.path.join(args.data, 'train') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + if args.aug_plus: + # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709 + augmentation = [ + transforms.RandomResizedCrop(224, scale=(0.2, 1.)), + transforms.RandomApply([ + transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened + ], p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=0.5), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize + ] + else: + # MoCo v1's aug: the same as InstDisc https://arxiv.org/abs/1805.01978 + augmentation = [ + transforms.RandomResizedCrop(224, scale=(0.2, 1.)), + transforms.RandomGrayscale(p=0.2), + transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize + ] + + train_dataset = datasets.ImageFolder( + traindir, + moco.loader.TwoCropsTransform(transforms.Compose(augmentation))) + + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + train_sampler = None + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) + + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, 
epoch, args)
+
+        # train for one epoch
+        train(train_loader, model, criterion, optimizer, epoch, args)
+
+        # save from the first rank of each node only: every 20 epochs for
+        # most of training, then every epoch over the final stretch
+        n_ckpt_period = 1 if epoch > args.epochs - 10 else 20
+        is_save_rank = not args.multiprocessing_distributed \
+            or args.rank % ngpus_per_node == 0
+        if is_save_rank and (epoch + 1) % n_ckpt_period == 0:
+            save_checkpoint({
+                'epoch': epoch + 1,
+                'arch': args.arch,
+                'state_dict': model.state_dict(),
+                'optimizer' : optimizer.state_dict(),
+            }, is_best=False, filename='checkpoint_{:04d}.pth.tar'.format(epoch))
+
+
+def train(train_loader, model, criterion, optimizer, epoch, args):
+    batch_time = AverageMeter('Time', ':6.3f')
+    data_time = AverageMeter('Data', ':6.3f')
+    losses = AverageMeter('Loss', ':.4e')
+    top1 = AverageMeter('Acc@1', ':6.2f')
+    top5 = AverageMeter('Acc@5', ':6.2f')
+    progress = ProgressMeter(
+        len(train_loader),
+        [batch_time, data_time, losses, top1, top5],
+        prefix="Epoch: [{}]".format(epoch))
+
+    # amp gradient scaler; scaling is a no-op when mixed precision is disabled.
+    # note it is re-created each epoch, so the dynamic loss scale also restarts
+    scaler = amp.GradScaler(enabled=args.use_mixed_precision)
+
+    # switch to train mode
+    model.train()
+
+    end = time.time()
+    for i, (images, _) in enumerate(train_loader):
+        # measure data loading time
+        data_time.update(time.time() - end)
+
+        if args.gpu is not None:
+            images[0] = images[0].cuda(args.gpu, non_blocking=True)
+            images[1] = images[1].cuda(args.gpu, non_blocking=True)
+
+        # compute output
+        with amp.autocast(enabled=args.use_mixed_precision):
+            output, target, output_dense, target_dense = model(
+                im_q=images[0], im_k=images[1])
+            loss = criterion(output, target)
+            loss_dense = criterion(output_dense, target_dense)
+            loss = 0.5 * (loss + loss_dense)
+
+        # acc1/acc5 are (K+1)-way contrast classifier accuracy
+        # measure accuracy and record loss
+        acc1, acc5 = accuracy(output, target, topk=(1, 5))
+        losses.update(loss.item(), images[0].size(0))
+        top1.update(acc1[0], images[0].size(0))
+        top5.update(acc5[0], images[0].size(0))
+
+        # compute gradient and do SGD step; the scaler wraps the usual
+        # loss.backward() / optimizer.step() pair
+        optimizer.zero_grad()
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if i % args.print_freq == 0:
+            progress.display(i)
+
+
+def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, 'model_best.pth.tar')
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self, name, fmt=':f'):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
+
+class ProgressMeter(object):
+    def __init__(self, num_batches, meters, prefix=""):
+        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+        self.meters = meters
+        self.prefix = prefix
+
+    def display(self, batch):
+        entries = [self.prefix + self.batch_fmtstr.format(batch)]
+        entries += [str(meter) for meter in self.meters]
+        print('\t'.join(entries))
+
+    def _get_batch_fmtstr(self, num_batches):
+        num_digits = len(str(num_batches // 1))
+        fmt = '{:' + str(num_digits) + 'd}'
+        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+
+def adjust_learning_rate(optimizer, epoch, args):
+    """Decay the learning rate based on schedule"""
+    lr = args.lr
+    if args.cos:  # cosine lr schedule: lr * 0.5 * (1 + cos(pi * epoch / epochs))
+        lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
+    else:  # stepwise lr schedule: multiply lr by 0.1 at each milestone
+        for milestone in args.schedule:
+            lr *= 0.1 if epoch >= milestone else 1.
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k"""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+        res = []
+        for k in topk:
+            # correct is non-contiguous after t(); view(-1) needs contiguous memory
+            correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
+            res.append(correct_k.mul_(100.0 / batch_size))
+        return res
+
+
+if __name__ == '__main__':
+    main()
diff --git a/moco/__init__.py b/moco/__init__.py
new file mode 100644
index 0000000..168f997
--- /dev/null
+++ b/moco/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
diff --git a/moco/builder.py b/moco/builder.py
new file mode 100644
index 0000000..fe137a8
--- /dev/null
+++ b/moco/builder.py
@@ -0,0 +1,205 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import torch
+import torch.nn as nn
+
+from resnet import resnet50
+
+
+class MoCo(nn.Module):
+    """
+    Build a MoCo model with: a query encoder, a key encoder, and a queue
+    https://arxiv.org/abs/1911.05722
+    """
+    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
+        """
+        dim: feature dimension (default: 128)
+        K: queue size; number of negative keys (default: 65536)
+        m: moco momentum of updating key encoder (default: 0.999)
+        T: softmax temperature (default: 0.07)
+        """
+        super(MoCo, self).__init__()
+
+        self.K = K
+        self.m = m
+        self.T = T
+
+        # create the encoders
+        # num_classes is the output fc dimension
+        # note: base_encoder is kept for API compatibility with MoCo, but the
+        # local resnet50 is used here because its forward pass also returns
+        # the dense feature maps needed for the DenseCL loss
+        self.encoder_q = resnet50(num_classes=dim)
+        self.encoder_k = resnet50(num_classes=dim)
+
+        if mlp:  # hack: brute-force replacement
+            dim_mlp = self.encoder_q.fc.weight.shape[1]
+            self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
+            self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)
+
+        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
+            param_k.data.copy_(param_q.data)  # initialize
+            param_k.requires_grad = False  # not update by gradient
+
+        # create the queue
+        self.register_buffer("queue", torch.randn(dim, K))
+        self.register_buffer("queue_dense", torch.randn(dim, K))
+        self.queue = nn.functional.normalize(self.queue, dim=0)
+        self.queue_dense = nn.functional.normalize(self.queue_dense, dim=0)
+
+        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
+
+    @torch.no_grad()
+    def _momentum_update_key_encoder(self):
+        """
+        Momentum update of the key encoder
+        """
+        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
+            param_k.data = param_k.data * self.m + param_q.data * (1. 
- self.m) + + @torch.no_grad() + def _dequeue_and_enqueue(self, keys, dense_keys): + # gather keys before updating queue + keys = concat_all_gather(keys) + dense_keys = nn.functional.normalize(dense_keys.mean(dim=2), dim=1) + dense_keys = concat_all_gather(dense_keys) + + batch_size = keys.shape[0] + + ptr = int(self.queue_ptr) + assert self.K % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.queue[:, ptr:ptr + batch_size] = keys.T + self.queue_dense[:, ptr:ptr + batch_size] = dense_keys.T + ptr = (ptr + batch_size) % self.K # move pointer + + self.queue_ptr[0] = ptr + + @torch.no_grad() + def _batch_shuffle_ddp(self, x): + """ + Batch shuffle, for making use of BatchNorm. + *** Only support DistributedDataParallel (DDP) model. *** + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = concat_all_gather(x) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # random shuffle index + idx_shuffle = torch.randperm(batch_size_all).cuda() + + # broadcast to all gpus + torch.distributed.broadcast(idx_shuffle, src=0) + + # index for restoring + idx_unshuffle = torch.argsort(idx_shuffle) + + # shuffled index for this gpu + gpu_idx = torch.distributed.get_rank() + idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this], idx_unshuffle + + @torch.no_grad() + def _batch_unshuffle_ddp(self, x, idx_unshuffle): + """ + Undo batch shuffle. + *** Only support DistributedDataParallel (DDP) model. *** + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = concat_all_gather(x) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # restored index for this gpu + gpu_idx = torch.distributed.get_rank() + idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this] + + def forward(self, im_q, im_k): + """ + Input: + im_q: a batch of query images + im_k: a batch of key images + Output: + logits, targets + """ + + # compute query features + q, dense_q, feat_q = self.encoder_q(im_q) # queries: NxC + q = nn.functional.normalize(q, dim=1) + n, c, h, w = feat_q.size() + dim_dense = dense_q.size(1) + dense_q, feat_q = dense_q.view(n, dim_dense, -1), feat_q.view(n, c, -1) + dense_q = nn.functional.normalize(dense_q, dim=1) + + # compute key features + with torch.no_grad(): # no gradient to keys + self._momentum_update_key_encoder() # update the key encoder + + # shuffle for making use of BN + im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k) + + k, dense_k, feat_k = self.encoder_k(im_k) # keys: NxC + k = nn.functional.normalize(k, dim=1) + dense_k, feat_k = dense_k.view(n, dim_dense, -1), feat_k.view(n, c, -1) + dense_k_norm = nn.functional.normalize(dense_k, dim=1) + + # undo shuffle + k = self._batch_unshuffle_ddp(k, idx_unshuffle) + + ## match + feat_q_norm = nn.functional.normalize(feat_q, dim=1) + feat_k_norm = nn.functional.normalize(feat_k, dim=1) + cosine = torch.einsum('nca,ncb->nab', feat_q_norm, feat_k_norm) + pos_idx = cosine.argmax(dim=-1) + dense_k_norm = dense_k_norm.gather(2, pos_idx.unsqueeze(1).expand(-1, dim_dense, -1)) + + # compute logits + # Einstein sum is more intuitive + # positive logits: Nx1 + l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) + # negative logits: NxK + l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) + + # logits: Nx(1+K) + logits = torch.cat([l_pos, l_neg], dim=1) + + # apply temperature + logits /= self.T + + # labels: positive key indicators + 
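+        # the positive logit is column 0 of the concatenated logits, so the
+        # CrossEntropyLoss target is class index 0 for every query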
labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() + + + ## densecl logits + d_pos = torch.einsum('ncm,ncm->nm', dense_q, dense_k_norm).unsqueeze(1) + d_neg = torch.einsum('ncm,ck->nkm', dense_q, self.queue_dense.clone().detach()) + logits_dense = torch.cat([d_pos, d_neg], dim=1) + logits_dense = logits_dense / self.T + labels_dense = torch.zeros((n, h*w), dtype=torch.long).cuda() + + # dequeue and enqueue + self._dequeue_and_enqueue(k, dense_k) + + return logits, labels, logits_dense, labels_dense + + +# utils +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + tensors_gather = [torch.ones_like(tensor) + for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output diff --git a/moco/loader.py b/moco/loader.py new file mode 100644 index 0000000..655aea5 --- /dev/null +++ b/moco/loader.py @@ -0,0 +1,27 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from PIL import ImageFilter +import random + + +class TwoCropsTransform: + """Take two random crops of one image as the query and key.""" + + def __init__(self, base_transform): + self.base_transform = base_transform + + def __call__(self, x): + q = self.base_transform(x) + k = self.base_transform(x) + return [q, k] + + +class GaussianBlur(object): + """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" + + def __init__(self, sigma=[.1, 2.]): + self.sigma = sigma + + def __call__(self, x): + sigma = random.uniform(self.sigma[0], self.sigma[1]) + x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) + return x diff --git a/resnet.py b/resnet.py new file mode 100644 index 0000000..a6e89c1 --- /dev/null +++ b/resnet.py @@ -0,0 +1,360 @@ +import torch +import torch.nn as nn +# from .utils import load_state_dict_from_url + + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, 
self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, 
layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + self.dense_head = nn.Sequential( + nn.Conv2d(512 * block.expansion, 512 * block.expansion, 1, 1, 0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(512 * block.expansion, num_classes, 1, 1, 0, bias=True) + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + feat = self.layer4(x) + + x = self.avgpool(feat) + x = torch.flatten(x, 1) + logits = self.fc(x) + dense = self.dense_head(feat) + + return logits, dense, feat + # return x + + def forward(self, x): + return self._forward_impl(x) + + +def _resnet(arch, block, layers, pretrained, progress, **kwargs): + model = ResNet(block, layers, **kwargs) + # if pretrained: + # state_dict = load_state_dict_from_url(model_urls[arch], + # progress=progress) + # model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained=False, progress=True, **kwargs): + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + +def resnet34(pretrained=False, progress=True, **kwargs): + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + 
""" + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained=False, progress=True, **kwargs): + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet101(pretrained=False, progress=True, **kwargs): + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + +def resnet152(pretrained=False, progress=True, **kwargs): + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +def resnext50_32x4d(pretrained=False, progress=True, **kwargs): + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def resnext101_32x8d(pretrained=False, progress=True, **kwargs): + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_2(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_ + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_ + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs)
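+
+
+if __name__ == '__main__':
+    # Minimal smoke test, assuming the DenseCL-style resnet50 defined above
+    # (illustrative only; none of the training scripts use this entry point).
+    # forward() returns the fc logits, the dense_head projection map, and the
+    # raw layer4 feature map.
+    net = resnet50(num_classes=128)
+    logits, dense, feat = net(torch.randn(2, 3, 224, 224))
+    print(logits.shape)  # torch.Size([2, 128])
+    print(dense.shape)   # torch.Size([2, 128, 7, 7])
+    print(feat.shape)    # torch.Size([2, 2048, 7, 7])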