A simple implementation of prototypical networks for few-shot learning.
Installation
Create a conda or virtualenv environment with all the necessary packages.

With conda:
conda create --name fs-learn
conda activate fs-learn
conda install pytorch torchvision torchaudio -c pytorch
conda install --file requirements.txt
Or with virtualenv:
python3 -m pip install virtualenv
virtualenv venv-fs-learn
source venv-fs-learn/bin/activate
python3 -m pip install torch torchvision
python3 -m pip install -r requirements.txt
Datasets
We used 4 main classification datasets:
- mini_imagenet: a collection of 100 real-world object classes as RGB images.
- omniglot: a collection of 1623 classes of handwritten characters. Each image is additionally rotated three times by 90 degrees (a sketch of this augmentation follows the list).
- flowers102: a collection of 102 real-world flower classes as RGB images.
- stanford_cars: a collection of 192 real-world car classes as RGB images.
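As an illustration of the omniglot rotation augmentation mentioned above, here is a minimal sketch of turning one character class into four rotated classes. The helper name, the .png extension and the output folder layout are assumptions for illustration only, not this repo's actual loader code:

```python
from pathlib import Path

from PIL import Image


def expand_class_with_rotations(class_dir: str, out_root: str) -> None:
    """Create four classes from one by rotating every image 0/90/180/270 degrees.

    Hypothetical helper used only to illustrate the augmentation; the dataset
    code in this repo may implement it differently.
    """
    src = Path(class_dir)
    for angle in (0, 90, 180, 270):
        out_dir = Path(out_root) / f"{src.name}_rot{angle}"
        out_dir.mkdir(parents=True, exist_ok=True)
        for img_path in src.glob("*.png"):
            Image.open(img_path).rotate(angle).save(out_dir / img_path.name)
```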
Usage
The entry-point script is meta_train.py, which exposes all the parameters needed to meta-train and meta-test on a dataset.
To replicate the results, launch this training (it writes to runs/train_X); a minimal sketch of the episode logic behind it is shown after the command:
python meta_train.py --data mini_imagenet \
--episodes 200 \
--device cuda \
--num-way 30 \
--query 15 \
--shot 5 \
--val-num-way 5 \
--iterations 100 \
--adam-lr 0.001 \
--adam-step 20 \
--adam-gamma 0.5 \
--metric "euclidean" \
--save-period 5 \
--patience 10 \
--patience-delta 0.01
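For reference, the core computation that each training episode performs can be summarized as follows. This is a minimal sketch of the standard prototypical-networks formulation, with assumed tensor names and shapes; it is not the exact code in meta_train.py:

```python
import torch
import torch.nn.functional as F


def prototypical_loss(support_emb, support_labels, query_emb, query_labels, n_way):
    """One n-way episode: average the support embeddings into per-class prototypes,
    then classify queries by negative squared euclidean distance to each prototype.
    Illustrative sketch only; names and shapes are assumptions."""
    # prototypes: (n_way, emb_dim), one centroid per episode class (labels 0..n_way-1)
    prototypes = torch.stack(
        [support_emb[support_labels == c].mean(dim=0) for c in range(n_way)]
    )
    # squared euclidean distances between queries and prototypes: (n_query, n_way)
    dists = torch.cdist(query_emb, prototypes).pow(2)
    log_p = F.log_softmax(-dists, dim=1)  # softmax over negative distances
    loss = F.nll_loss(log_p, query_labels)
    acc = (log_p.argmax(dim=1) == query_labels).float().mean()
    return loss, acc
```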
Implemented datasets are [omniglot, mini_imagenet, flowers102, stanford_cars].
To train on your own custom dataset, set --data to your dataset folder (a quick loading check is sketched after the directory tree below).
Remember, your custom dataset should have this format:
├── train
│ ├── class1
│ │ ├── img1.jpg
│ │ ├── ...
│ ├── class2
│ │ ├── ...
│ ├── ...
├── val
│ ├── class3
│ │ ├── img57.jpg
│ │ ├── ...
│ ├── class4
│ │ ├── ...
│ ├── ...
├── test
│ ├── class5
│ │ ├── img182.jpg
│ │ ├── ...
│ ├── class6
│ │ ├── ...
│ ├── ...
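Since the layout above follows the usual class-per-folder convention, one quick way to verify a custom dataset before training is to load each split with torchvision's ImageFolder. This is only a verification sketch; "my_dataset" and the image size are placeholder assumptions, and the repo's own loader may differ:

```python
from torchvision import datasets, transforms

# Minimal check that a custom dataset folder follows the expected layout.
tfm = transforms.Compose([transforms.Resize((84, 84)), transforms.ToTensor()])

for split in ("train", "val", "test"):
    ds = datasets.ImageFolder(f"my_dataset/{split}", transform=tfm)
    print(split, "->", len(ds.classes), "classes,", len(ds), "images")
```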
To meta-test, use the meta_test.py script:
python meta_test.py --model "your_model_or_pretrained.py" \
--data mini_imagenet \
--iterations 100 \
--device cuda \
--val-num-way 15 \
--query 15 \
--shot 5 \
--metric "euclidean"
To learn centroids for new data, use the learn_centroids.py script (it writes to runs/centroids_Y):
python learn_centroids.py --model "your_model_or_pretrained.py" \
--data your_folder_with_classes_of_images \
--imgsz 64 \
--channels 3 \
--device cuda
This will take all the classes inside the your_folder_with_classes_of_images directory and compute one centroid per class for the classification task.
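Conceptually, this amounts to averaging the encoder's embeddings per class folder. The sketch below illustrates that idea; the encoder, the dataloader and the function name are placeholder assumptions, not the internals of learn_centroids.py:

```python
import torch


@torch.no_grad()
def compute_centroids(encoder, dataloader, device="cuda"):
    """Average the encoder's embeddings per class to obtain one centroid each.
    `encoder` and `dataloader` (yielding image batches and integer class labels)
    are placeholders for whatever the repo's scripts construct."""
    encoder.eval().to(device)
    sums, counts = {}, {}
    for images, labels in dataloader:
        emb = encoder(images.to(device))  # (batch, emb_dim)
        for e, y in zip(emb, labels.tolist()):
            sums[y] = sums.get(y, 0) + e
            counts[y] = counts.get(y, 0) + 1
    return {y: sums[y] / counts[y] for y in sums}  # class id -> centroid tensor
```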
To use the centroids to classify new images, use the predict.py script (it outputs the results):
python predict.py --model "your_model_or_pretrained.py" \
--centroids runs/centroids_0 \
--data a_path_with_new_images \
--imgsz 64 \
--device cuda
This will perform predictions by printing the predicted class for every image in a_path_with_new_images.
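The prediction step boils down to a nearest-centroid lookup in embedding space. The following sketch shows the idea with placeholder names (it reuses the centroid dictionary from the sketch above); it is not the actual code of predict.py:

```python
import torch


@torch.no_grad()
def predict_class(encoder, image_tensor, centroids, device="cuda"):
    """Return the class whose centroid is closest (euclidean) to the image embedding.
    `centroids` is a dict {class_name: 1-D tensor}; purely illustrative."""
    emb = encoder(image_tensor.unsqueeze(0).to(device)).squeeze(0)
    names = list(centroids.keys())
    stack = torch.stack([centroids[n].to(device) for n in names])  # (n_classes, emb_dim)
    dists = torch.norm(stack - emb, dim=1)
    return names[int(dists.argmin())]
```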
Experiments
| Dataset | Images (shape) | Embeddings (shape) | Duration (Colab T4) |
|---|---|---|---|
| mini_imagenet | (84, 84, 3) | (batch, 1600) | gpu / 1h43m |
| omniglot | (28, 28, 1) | (batch, 60) | gpu / 2h32m |
| flowers102 | (74, 74, 3) | (batch, 1024) | gpu / 58m |
| stanford_cars | (90, 90, 3) | (batch, 1024) | gpu / 1h52m |
Many experiments replicate the training setup of the original paper. All of them use --num-way 30, --episodes 200 and --iterations 100 for training. Evaluation is then performed with different n-way and k-shot configurations.
| Dataset | Paper 5-way 5-shot (Acc) | Ours 5-way 5-shot (Acc) | Paper 5-way 1-shot (Acc) | Ours 5-way 1-shot (Acc) |
|---|---|---|---|---|
| mini_imagenet | 68.20 | 63.62 | 49.42 | 46.13 |
| omniglot | 98.80 | 97.77 | 98.8 | 91.93 |
| flowers102 | / | 84.48 | / | 56.08 |
| stanford_cars | / | 51.87 | / | / |
Cosine-distance experiments were run on the 5-way 5-shot configuration; similar results were observed for the corresponding 1-shot and 20-way trainings. A short sketch of both distance functions follows the table.
| Dataset | Cosine (Acc) | Euclidean (Acc) |
|---|---|---|
| mini_imagenet | 22.36 | 63.62 |
| omniglot | 23.48 | 97.77 |
| flowers102 | 82.89 | 84.48 |
| stanford_cars | / | 51.87 |
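For context on the comparison above, the two metrics differ only in how query-to-prototype scores are computed. A minimal sketch of both, assuming query embeddings of shape (n_query, dim) and prototypes of shape (n_classes, dim); names are illustrative:

```python
import torch
import torch.nn.functional as F


def euclidean_logits(queries, prototypes):
    # Negative squared euclidean distance: larger means closer.
    return -torch.cdist(queries, prototypes).pow(2)


def cosine_logits(queries, prototypes):
    # Cosine similarity between L2-normalised embeddings: larger means more similar.
    return F.normalize(queries, dim=1) @ F.normalize(prototypes, dim=1).T
```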