achen46/Neighborhood-Attention-Transformer

Neighborhood Attention Transformer

Preprint Link: Neighborhood Attention Transformer

By Ali Hassani[1, 2], Steven Walton[1, 2], Jiachen Li[1, 2], Shen Li[3], and Humphrey Shi[1, 2]

In association with SHI Lab @ University of Oregon & UIUC[1], Picsart AI Research (PAIR)[2], and Meta/Facebook AI[3]


Abstract

We present Neighborhood Attention Transformer (NAT), an efficient, accurate, and scalable hierarchical transformer that works well on both image classification and downstream vision tasks. It is built upon Neighborhood Attention (NA), a simple and flexible attention mechanism that localizes the receptive field of each query to its nearest neighboring pixels. NA is a localization of self-attention and approaches it as the receptive field size increases. Given the same receptive field size, it is also equivalent in FLOPs and memory usage to Swin Transformer's shifted window attention, while being less constrained. Furthermore, NA includes local inductive biases, which eliminate the need for extra operations such as pixel shifts. Experimental results on NAT are competitive: NAT-Tiny reaches 83.2% top-1 accuracy on ImageNet with only 4.3 GFLOPs and 28M parameters, 51.4% box mAP on MS-COCO (with Cascade Mask R-CNN), and 48.4% multi-scale mIoU on ADE20K (with UPerNet).
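To make the FLOPs comparison concrete, here is a back-of-the-envelope sketch. It assumes a single attention head of dimension `d`, counts only the multiply-accumulates of the two attention products (QK^T and the weighted sum over V), and ignores projections, softmax, and Swin's shift/masking bookkeeping. The function names are hypothetical and not part of this repository.

```python
def na_macs(h: int, w: int, k: int, d: int) -> int:
    """Multiply-accumulates for QK^T and AV in neighborhood attention:
    each of the h*w queries is compared with, and sums over, k*k keys."""
    return 2 * (h * w) * (k * k) * d


def window_macs(h: int, w: int, k: int, d: int) -> int:
    """The same count for non-overlapping k x k window attention (Swin-style):
    (h*w)/(k*k) windows, each with a (k*k) x (k*k) attention map."""
    if h % k or w % k:
        raise ValueError("window attention assumes h and w are divisible by k")
    num_windows = (h // k) * (w // k)
    tokens_per_window = k * k
    return 2 * num_windows * tokens_per_window * tokens_per_window * d


def global_macs(h: int, w: int, d: int) -> int:
    """The same count for global self-attention over all h*w tokens."""
    n = h * w
    return 2 * n * n * d


if __name__ == "__main__":
    h = w = 56
    k, d = 7, 32
    print(na_macs(h, w, k, d), window_macs(h, w, k, d), global_macs(h, w, d))
```

For a 56×56 feature map with k = 7 and d = 32, `na_macs` and `window_macs` both come to 9,834,496, while global self-attention needs 64 times as many (the ratio is HW / k² = 3136 / 49). Both local schemes also store the same number of attention weights, HW·k², which is the sense in which their memory usage matches.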


How it works

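Neighborhood Attention localizes each query's (red) receptive field to its nearest neighborhood (green). When the neighborhood size equals the image size, this is equivalent to dot-product self-attention. Pixels near the image borders are special cases: their neighborhood cannot be centered on them, so it is shifted inward to stay within the image, and every query still attends to the same number of keys.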

[Animation: Neighborhood Attention, with the query in red and its neighborhood in green]

Implementation

We wrote a PyTorch CUDA extension to parallelize NA. It's relatively fast, very memory-efficient, and supports half precision. There's still a lot of room for improvement, so feel free to open PRs and contribute!
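The memory side can be sketched the same way. The snippet below is not a measurement of this extension; it only counts tensor elements under stated assumptions (single head, one image, head dimension `dim`, and 2 bytes per element for half precision). It compares the neighborhood attention map with a global one, and also counts what a naive implementation would allocate if it explicitly gathered a private kernel×kernel copy of keys and values for every query, which is the kind of overhead a dedicated kernel can avoid by reading neighbors in place. All function names are hypothetical.

```python
from typing import Optional


def attention_map_elements(h: int, w: int, kernel: Optional[int] = None) -> int:
    """Attention weights stored per head: h*w*kernel**2 for neighborhood
    attention, or (h*w)**2 for global self-attention when kernel is None."""
    n = h * w
    return n * n if kernel is None else n * kernel * kernel


def kv_elements(h: int, w: int, dim: int) -> int:
    """Elements in the original (un-gathered) K and V tensors."""
    return 2 * (h * w) * dim


def gathered_kv_elements(h: int, w: int, kernel: int, dim: int) -> int:
    """Elements in explicitly gathered keys and values, where every query
    gets its own kernel x kernel copy of K and of V."""
    return 2 * (h * w) * (kernel * kernel) * dim


def mib(elements: int, bytes_per_element: int = 2) -> float:
    """Convert an element count to MiB; 2 bytes/element assumes fp16."""
    return elements * bytes_per_element / 2**20


if __name__ == "__main__":
    h = w = 56
    k, dim = 7, 32
    print(f"NA attention map:     {mib(attention_map_elements(h, w, k)):.2f} MiB")
    print(f"global attention map: {mib(attention_map_elements(h, w)):.2f} MiB")
    print(f"K and V as stored:    {mib(kv_elements(h, w, dim)):.2f} MiB")
    print(f"K and V gathered:     {mib(gathered_kv_elements(h, w, k, dim)):.2f} MiB")
```

At 56×56 with k = 7, the neighborhood attention map holds 153,664 weights against 9,834,496 for global attention, while gathering keys and values per query would inflate them by a factor of k² = 49.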

Results and checkpoints

Classification

| Model | # of Params | FLOPs | Top-1 |
|---|---|---|---|
| NAT-Mini | 20M | 2.7G | 81.8% |
| NAT-Tiny | 28M | 4.3G | 83.2% |
| NAT-Small | 51M | 7.8G | 83.7% |
| NAT-Base | 90M | 13.7G | 84.3% |

Details on training and validation are provided in the classification folder.

Object Detection

| Backbone | Network | # of Params | FLOPs | Box mAP | Mask mAP | Checkpoint |
|---|---|---|---|---|---|---|
| NAT-Mini | Mask R-CNN | 40M | 225G | 46.5 | 41.7 | Download |
| NAT-Tiny | Mask R-CNN | 48M | 258G | 47.7 | 42.6 | Download |
| NAT-Small | Mask R-CNN | 70M | 330G | 48.4 | 43.2 | Download |
| NAT-Mini | Cascade Mask R-CNN | 77M | 704G | 50.3 | 43.6 | Download |
| NAT-Tiny | Cascade Mask R-CNN | 85M | 737G | 51.4 | 44.5 | Download |
| NAT-Small | Cascade Mask R-CNN | 108M | 809G | 52.0 | 44.9 | Download |
| NAT-Base | Cascade Mask R-CNN | 147M | 931G | 52.3 | 45.1 | Download |

Details on training and validation are provided in the detection folder.

Semantic Segmentation

| Backbone | Network | # of Params | FLOPs | mIoU | mIoU (multi-scale) | Checkpoint |
|---|---|---|---|---|---|---|
| NAT-Mini | UPerNet | 50M | 900G | 45.1 | 46.4 | Download |
| NAT-Tiny | UPerNet | 58M | 934G | 47.1 | 48.4 | Download |
| NAT-Small | UPerNet | 82M | 1010G | 48.0 | 49.5 | Download |
| NAT-Base | UPerNet | 123M | 1137G | 48.5 | 49.7 | Download |

Details on training and validation are provided in the segmentation folder.

Saliency maps

[Figure: saliency maps for four sample images (img0 to img3), showing each original image alongside the ViT, Swin, and NAT saliency maps]

Citation

@article{hassani2022neighborhood,
	title        = {Neighborhood Attention Transformer},
	author       = {Ali Hassani and Steven Walton and Jiachen Li and Shen Li and Humphrey Shi},
	year         = 2022,
	url          = {https://arxiv.org/abs/2204.07143},
	eprint       = {2204.07143},
	archiveprefix = {arXiv},
	primaryclass = {cs.CV}
}
