PyTorch implementation of various knowledge distillation methods.
This repository is intended as a simple reference; therefore, many tricks such as step-by-step training, iterative training, and ensembles of teachers are not considered.
Filename | Method | Link |
---|---|---|
train_baseline.py | basic CNN with softmax loss | N/A |
train_logits.py | mimic learning via regressing logits (logits) | paper |
train_st.py | soft targets (st) | paper |
train_fitnet.py | hints for thin deep nets (fitnet) | paper |
train_at.py | attention transfer (at) | paper |
train_fsp.py | flow of solution procedure (fsp) | paper |
train_nst.py | neural selective transfer (nst) | paper |
train_pkt.py | probabilistic knowledge transfer (pkt) | paper |
train_ft.py | factor transfer (ft) | paper |
train_dml.py | deep mutual learning (dml) | paper |
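
For reference, the two simplest objectives in the table can be sketched as follows. This is a minimal sketch under stated assumptions, not the code in `train_logits.py` or `train_st.py`: `logits` is taken as a mean-squared error between student and teacher logits, and `st` as the Hinton-style KL divergence between temperature-softened softmax outputs, scaled by T squared. The function names and the temperature `T=4.0` are hypothetical.

```python
import torch
import torch.nn.functional as F


def logits_loss(s_logits, t_logits):
    """Mimic learning: regress the student's logits onto the teacher's (MSE)."""
    return F.mse_loss(s_logits, t_logits)


def soft_target_loss(s_logits, t_logits, T=4.0):
    """Soft targets: KL(teacher || student) on temperature-softened outputs."""
    log_p_s = F.log_softmax(s_logits / T, dim=1)
    p_t = F.softmax(t_logits / T, dim=1)
    # Scaling by T*T keeps gradient magnitudes comparable across temperatures.
    return F.kl_div(log_p_s, p_t, reduction='batchmean') * T * T


if __name__ == '__main__':
    torch.manual_seed(0)
    s, t = torch.randn(8, 10), torch.randn(8, 10)  # (batch, num_class) logits
    print(logits_loss(s, t).item(), soft_target_loss(s, t).item())
```

Both losses are zero when the student reproduces the teacher's logits exactly; in training they are added to the usual cross-entropy loss.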
- Note: there are some differences between this repository and the original papers:
  - For `fitnet`: the training procedure is a single stage, without a hint layer.
  - For `at`: I use the sum of absolute values raised to the power p=2 as the attention map.
  - For `nst`: I use squared MMD matching.
  - For `dml`: only two nets are employed.
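
The `at` and `nst` variants described above can be illustrated on a pair of intermediate feature maps of shape (batch, channels, height, width). In this sketch (an assumption, not this repository's code) the attention map is the channel-wise sum of |A|^p with p=2, flattened and L2-normalized, and the AT loss is the mean squared difference between the normalized maps. The NST loss is the squared MMD between the channel-wise feature distributions, assuming the polynomial kernel k(x, y) = (x^T y)^2. The function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def attention_map(fm, p=2):
    """Spatial attention: sum over channels of |A|^p, flattened and L2-normalized."""
    am = fm.abs().pow(p).sum(dim=1)                        # (B, H, W)
    return F.normalize(am.reshape(am.size(0), -1), dim=1)  # (B, H*W)


def at_loss(fm_s, fm_t, p=2):
    """Mean squared difference between normalized student/teacher attention maps."""
    return (attention_map(fm_s, p) - attention_map(fm_t, p)).pow(2).mean()


def nst_loss(fm_s, fm_t):
    """Squared MMD between channel-wise features with kernel k(x, y) = (x^T y)^2."""
    # Each channel becomes one sample: a unit-norm vector over the H*W positions.
    s = F.normalize(fm_s.reshape(fm_s.size(0), fm_s.size(1), -1), dim=2)
    t = F.normalize(fm_t.reshape(fm_t.size(0), fm_t.size(1), -1), dim=2)
    k_ss = torch.bmm(s, s.transpose(1, 2)).pow(2).mean()  # student-student
    k_tt = torch.bmm(t, t.transpose(1, 2)).pow(2).mean()  # teacher-teacher
    k_st = torch.bmm(s, t.transpose(1, 2)).pow(2).mean()  # student-teacher
    return k_ss + k_tt - 2 * k_st
```

Both losses only require the student and teacher maps to share the same spatial size; the channel counts may differ, because AT sums over channels and NST treats channels as samples.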
Datasets:
- CIFAR10
- CIFAR100

Networks:
- ResNet-20
- ResNet-110

The networks are the same as those in Table 6 of the paper.
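
Assuming "the paper" is the original ResNet paper (He et al.), its CIFAR networks have depth 6n+2: a 3x3 stem convolution, three stages of n basic blocks with 16, 32 and 64 filters, global average pooling, and a fully connected layer, so n=3 gives ResNet-20 and n=18 gives ResNet-110. The sketch below builds such a network. It uses 1x1 projection shortcuts when the shape changes, whereas the paper's CIFAR models use zero-padded identity shortcuts, so treat it as an approximation rather than the repository's `resnet20`/`resnet110` definition; the class names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            # Assumed projection shortcut (the paper uses zero-padded identity).
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, 1, stride, bias=False),
                nn.BatchNorm2d(planes))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + self.shortcut(x))


class CifarResNet(nn.Module):
    def __init__(self, depth, num_class=10):
        super(CifarResNet, self).__init__()
        assert (depth - 2) % 6 == 0, 'depth must be 6n+2'
        n = (depth - 2) // 6                      # blocks per stage
        self.conv = nn.Conv2d(3, 16, 3, 1, 1, bias=False)
        self.bn = nn.BatchNorm2d(16)
        layers, in_planes = [], 16
        for planes, stride in [(16, 1), (32, 2), (64, 2)]:
            for i in range(n):
                # Only the first block of a stage downsamples.
                layers.append(BasicBlock(in_planes, planes, stride if i == 0 else 1))
                in_planes = planes
        self.layers = nn.Sequential(*layers)
        self.fc = nn.Linear(64, num_class)

    def forward(self, x):
        out = F.relu(self.bn(self.conv(x)))
        out = self.layers(out)
        out = F.adaptive_avg_pool2d(out, 1).view(x.size(0), -1)
        return self.fc(out)
```

Counting weight layers confirms the naming: 1 stem convolution + 2 x 3n block convolutions + 1 fully connected layer = 6n+2.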
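Training:
- Create a `./dataset` directory and download CIFAR10/CIFAR100 into it.
- Run the training script, specifying the parameters listed in `train_xxx.py` either as command-line flags or by editing them directly.
- The parameters I used can be found in the training logs.
- In the commands below, a value such as `cifar10/cifar100` means choosing one of the two options.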
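
Passing parameters as flags could look like the hypothetical parser below. Only the flag names `--data_name`, `--net_name` and `--num_class` come from the commands in this README; the defaults, the `choices`, and the consistency check between dataset and class count are assumptions for illustration, and the actual parser in `train_baseline.py` may differ.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical flag parser for a baseline run (argv=None reads sys.argv)."""
    parser = argparse.ArgumentParser(description='baseline training (sketch)')
    parser.add_argument('--data_name', default='cifar10',
                        choices=['cifar10', 'cifar100'])
    parser.add_argument('--net_name', default='resnet20',
                        choices=['resnet20', 'resnet110'])
    parser.add_argument('--num_class', type=int, default=10)
    args = parser.parse_args(argv)

    # Assumed sanity check: the class count must match the chosen dataset.
    expected = {'cifar10': 10, 'cifar100': 100}[args.data_name]
    if args.num_class != expected:
        parser.error('--num_class=%d does not match --data_name=%s'
                     % (args.num_class, args.data_name))
    return args
```

With this sketch, `python train_baseline.py --data_name=cifar100 --net_name=resnet110 --num_class=100` would parse cleanly, while `--data_name=cifar100 --num_class=10` would exit with an error message.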
- For `baseline`:

```bash
python train_baseline.py \
    --data_name=cifar10/cifar100 \
    --net_name=resnet20/resnet110 \
    --num_class=10/100
```
- For `logits`, `st`, `fitnet`, `at`, `fsp`, `nst`, `pkt` and `ft` (replace `xxx` with the method name):

```bash
python train_xxx.py \
    --s_init=/path/to/your/student_initial_model \
    --t_model=/path/to/your/teacher_model \
    --data_name=cifar10/cifar100 \
    --t_name=resnet20/resnet110 \
    --s_name=resnet20/resnet110 \
    --num_class=10/100
```
- For `dml`:

```bash
python train_dml.py \
    --net1_init=/path/to/your/net1_initial_model \
    --net2_init=/path/to/your/net2_initial_model \
    --data_name=cifar10/cifar100 \
    --net1_name=resnet20/resnet110 \
    --net2_name=resnet20/resnet110 \
    --num_class=10/100
```
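
With two nets, the mutual-learning objective used by `train_dml.py` can be sketched as follows: each net minimizes cross-entropy on the labels plus the KL divergence from the other net's (detached) prediction to its own, i.e. KL(p2 || p1) for net1 and KL(p1 || p2) for net2. This sketch updates both nets simultaneously from one forward pass; the DML paper describes alternating updates, so treat this as a simplification. The function names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def dml_losses(logits1, logits2, target):
    """Return (loss for net1, loss for net2) for two-net deep mutual learning."""
    # F.kl_div(log q, p) computes KL(p || q); the peer's prediction p is detached
    # so that each loss only produces gradients for its own net.
    kl1 = F.kl_div(F.log_softmax(logits1, dim=1),
                   F.softmax(logits2.detach(), dim=1), reduction='batchmean')
    kl2 = F.kl_div(F.log_softmax(logits2, dim=1),
                   F.softmax(logits1.detach(), dim=1), reduction='batchmean')
    loss1 = F.cross_entropy(logits1, target) + kl1
    loss2 = F.cross_entropy(logits2, target) + kl2
    return loss1, loss2


def dml_step(net1, net2, opt1, opt2, x, y):
    """Simultaneous update of both nets from one forward pass (a simplification)."""
    loss1, loss2 = dml_losses(net1(x), net2(x), y)
    opt1.zero_grad()
    opt2.zero_grad()
    # Because of detach(), loss1 only reaches net1 and loss2 only reaches net2,
    # so one backward pass on the sum equals two separate backward passes.
    (loss1 + loss2).backward()
    opt1.step()
    opt2.step()
    return loss1.item(), loss2.item()
```

When both nets predict identically, the KL terms vanish and each loss reduces to plain cross-entropy.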
- The trained baseline models are used as teachers. For a fair comparison, all student nets use the same initialization as the baseline models.
- The initial models, trained models and training logs are uploaded here.
- The loss trade-off parameters `--lambda_xxx` were not tuned carefully, so the following results do not indicate which method is better than the others.
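
One distillation step combining these points can be sketched as below: the teacher (a trained baseline) is kept frozen in eval mode, the student starts from the same initial weights as the baseline, and the task loss is combined with a distillation term through a trade-off weight. The weight `lambda_kd` stands in for a `--lambda_xxx` flag; its value, the temperature, and the use of the soft-target term as the distillation loss are assumptions, not the settings from the training logs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def kd_step(student, teacher, optimizer, x, y, lambda_kd=0.1, T=4.0):
    """One student update: cross-entropy + lambda_kd * soft-target KD term."""
    # The teacher is a trained baseline: evaluate it without tracking gradients.
    teacher.eval()
    with torch.no_grad():
        t_logits = teacher(x)

    student.train()
    s_logits = student(x)
    ce = F.cross_entropy(s_logits, y)
    kd = F.kl_div(F.log_softmax(s_logits / T, dim=1),
                  F.softmax(t_logits / T, dim=1),
                  reduction='batchmean') * T * T
    loss = ce + lambda_kd * kd  # lambda_kd plays the role of a --lambda_xxx flag

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


if __name__ == '__main__':
    torch.manual_seed(0)
    init = nn.Linear(8, 4)      # stands in for the saved initial model (--s_init)
    teacher = nn.Linear(8, 4)   # in practice: the trained baseline (--t_model)
    teacher.load_state_dict(init.state_dict())  # just a copy so the demo runs
    student = nn.Linear(8, 4)
    student.load_state_dict(init.state_dict())  # same initialization as the baseline
    opt = torch.optim.SGD(student.parameters(), lr=0.1)
    x, y = torch.randn(16, 8), torch.randint(0, 4, (16,))
    print(kd_step(student, teacher, opt, x, y))
```

Only the student's optimizer is stepped, so the teacher's weights never change during distillation.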

Teacher | Student | Method | CIFAR10 | CIFAR100 |
---|---|---|---|---|
None | resnet-20 | baseline | 92.18% | 68.33% |
resnet-20 | resnet-20 | logits | 93.01% | 69.87% |
resnet-20 | resnet-20 | st | 92.54% | 69.92% |
resnet-20 | resnet-20 | fitnet | 92.48% | 69.05% |
resnet-20 | resnet-20 | at | 92.58% | 68.56% |
resnet-20 | resnet-20 | fsp | 92.57% | 69.10% |
resnet-20 | resnet-20 | nst | 92.35% | 68.35% |
resnet-20 | resnet-20 | pkt | 92.83% | 68.83% |
resnet-20 | resnet-20 | ft | 92.92% | 68.86% |

Teacher | Student | Method | CIFAR10 | CIFAR100 |
---|---|---|---|---|
None | resnet-20 | baseline | 92.18% | 68.33% |
None | resnet-110 | baseline | 94.04% | 72.65% |
resnet-110 | resnet-20 | logits | 93.33% | 69.94% |
resnet-110 | resnet-20 | st | 92.82% | 69.45% |
resnet-110 | resnet-20 | fitnet | 92.55% | 69.68% |
resnet-110 | resnet-20 | at | 92.84% | 69.05% |
resnet-110 | resnet-20 | fsp | 92.83% | 69.38% |
resnet-110 | resnet-20 | nst | 92.51% | 68.41% |
resnet-110 | resnet-20 | pkt | 92.95% | 69.04% |
resnet-110 | resnet-20 | ft | 93.20% | 69.45% |

Teacher | Student | Method | CIFAR10 | CIFAR100 |
---|---|---|---|---|
None | resnet-110 | baseline | 94.04% | 72.65% |
resnet-110 | resnet-110 | logits | 94.48% | 74.72% |
resnet-110 | resnet-110 | st | 94.30% | 74.29% |
resnet-110 | resnet-110 | fitnet | 94.58% | 73.21% |
resnet-110 | resnet-110 | at | 94.34% | 73.81% |
resnet-110 | resnet-110 | fsp | 94.29% | 73.71% |
resnet-110 | resnet-110 | nst | 94.27% | 72.84% |
resnet-110 | resnet-110 | pkt | 94.76% | 73.73% |
resnet-110 | resnet-110 | ft | 94.46% | 73.41% |

Net1 | Net2 | Method | CIFAR10 (Net1/Net2) | CIFAR100 (Net1/Net2) |
---|---|---|---|---|
None | resnet-20 | baseline | 92.18% | 68.33% |
None | resnet-110 | baseline | 94.04% | 72.65% |
resnet-20 | resnet-20 | dml | 92.99%/92.81% | 70.30%/70.19% |
resnet-110 | resnet-20 | dml | 94.52%/92.72% | 75.25%/70.26% |
resnet-110 | resnet-110 | dml | 94.92%/94.46% | 74.70%/74.91% |

Requirements:
- Python 2.7
- PyTorch 1.0.0
- torchvision 0.2.1