PyTorch implementation for EagleEye: Fast Sub-net Evaluation for Efficient Neural Network Pruning
Bailin Li, Bowen Wu, Jiang Su, Guangrun Wang, Liang Lin
Presented at ECCV 2020 (Oral)
If you use EagleEye in your research, please consider citing:
@misc{li2020eagleeye,
    title={EagleEye: Fast Sub-net Evaluation for Efficient Neural Network Pruning},
    author={Bailin Li and Bowen Wu and Jiang Su and Guangrun Wang and Liang Lin},
    year={2020},
    eprint={2007.02491},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
- Inference Code
- Pruning Strategy Generation
- Adaptive-BN-based Candidate Evaluation of Pruning Strategy
- Finetuning of Pruned Model
To ease your own implementation, we present here the key code of the proposed Adaptive-BN-based candidate evaluation. The official implementation will be released soon.
def eval_pruning_strategy(model, pruning_strategy, dataloader_train):
    # Apply filter pruning to the trained model
    pruned_model = prune(model, pruning_strategy)

    # Adaptive-BN: re-estimate BN statistics on a few training batches.
    # train() lets the BN running mean/variance update, while no_grad()
    # keeps the weights untouched.
    pruned_model.train()
    max_iter = 50
    with torch.no_grad():
        for iter_in_epoch, sample in enumerate(dataloader_train):
            pruned_model.forward(sample)
            if iter_in_epoch > max_iter:
                break

    # Evaluate top-1 accuracy of the pruned model on the validation sub-set
    acc = pruned_model.get_val_acc()
    return acc
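As a rough illustration of how this evaluation fits into the search step, the sketch below samples candidate pruning strategies and scores each one with `eval_pruning_strategy`. The helpers `sample_strategy` and `evaluate_candidates` are hypothetical names used only for this illustration, not the repository's actual API; in the real search, candidates that violate the FLOPs target are discarded before evaluation.

```python
import random

def sample_strategy(num_prunable_layers, max_rate=0.7):
    # Hypothetical helper: draw a random pruning ratio in [0, max_rate]
    # for every prunable layer (this is the role of the --max_rate flag).
    return [random.uniform(0.0, max_rate) for _ in range(num_prunable_layers)]

def evaluate_candidates(model, candidate_strategies, dataloader_train):
    # Score every candidate with the Adaptive-BN evaluation above and
    # collect (strategy, accuracy) pairs for later selection.
    results = []
    for strategy in candidate_strategies:
        acc = eval_pruning_strategy(model, strategy, dataloader_train)
        results.append((strategy, acc))
    return results
```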
- Prepare Data

  Download the ILSVRC2012 dataset from http://image-net.org/challenges/LSVRC/2012/index#introduction

- Download Pretrained Models

  We provide the reported pruned models in Dropbox. Please put the downloaded models in the dir of `models/ckpt/`.

- Prepare Runtime Environment

  pip install -r requirements.txt
usage: main.py [-h] [--model_name MODEL_NAME] [--num_classes NUM_CLASSES]
[--checkpoint CHECKPOINT] [--gpu_ids GPU_IDS [GPU_IDS ...]]
[--batch_size BATCH_SIZE] [--dataset_path DATASET_PATH]
[--dataset_name DATASET_NAME] [--num_workers NUM_WORKERS]
[--lr LR] [--weight_decay WEIGHT_DECAY] [--momentum MOMENTUM]
[--max_rate MAX_RATE] [--affine AFFINE]
[--compress_schedule_path COMPRESS_SCHEDULE_PATH]
[--flops_target FLOPS_TARGET] [--output_file OUTPUT_FILE]
optional arguments:
-h, --help show this help message and exit
--model_name MODEL_NAME
what kind of model you are using. Only supports `resnet50`, `mobilenetv1` and `mobilenetv1_imagenet`
--num_classes NUM_CLASSES
number of class labels
--checkpoint CHECKPOINT
path to model state dict
--gpu_ids GPU_IDS [GPU_IDS ...]
GPU ids.
--batch_size BATCH_SIZE
batch size while fine-tuning
--dataset_path DATASET_PATH
path to dataset
--dataset_name DATASET_NAME
filename of the file that contains your own
`get_dataloaders` function
--num_workers NUM_WORKERS
Number of workers used in dataloading
--lr LR learning rate while fine-tuning
--weight_decay WEIGHT_DECAY
weight decay while fine-tuning
--momentum MOMENTUM momentum while fine-tuning
--max_rate MAX_RATE define search space
--affine AFFINE define search space
--compress_schedule_path COMPRESS_SCHEDULE_PATH
path to compression schedule
--flops_target FLOPS_TARGET
flops constraints for pruning
--output_file OUTPUT_FILE
path to the output file for logging search results
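For the `--dataset_name` option above, you point the scripts at a file that provides your own `get_dataloaders` function. The sketch below is only an assumption of what such a module could look like; the argument list of `get_dataloaders` here is hypothetical, so check the repository's data-loading code (e.g. the provided `imagenet_train_val_split` module) for the exact signature it expects.

```python
# my_dataset.py -- hypothetical module, used as: --dataset_name my_dataset
import torch
from torchvision import datasets, transforms

def get_dataloaders(batch_size, num_workers, dataset_path):
    # NOTE: this signature is an assumption for illustration only.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_set = datasets.ImageFolder(dataset_path + '/train', transform=transform)
    val_set = datasets.ImageFolder(dataset_path + '/val', transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, val_loader
```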
python3 search.py \
--model_name mobilenetv1 \
--num_classes 1000 \
--checkpoint models/ckpt/imagenet_mobilenet_726.pth \
--gpu_ids 5 \
--batch_size 128 \
--dataset_path /data/imagenet \
--dataset_name imagenet_train_val_split \
--num_workers 4 \
--flops_target 0.5 \
--max_rate 0.7 \
--affine 0 \
--output_file log.txt \
--compress_schedule_path {compress_config/mbv1_imagenet.yaml|compress_config/res50_imagenet.yaml}
python choose_strategy.py log.txt
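Conceptually, this selection step ranks the candidates recorded in `log.txt` by their Adaptive-BN accuracy and keeps the best ones for fine-tuning. A minimal sketch of that selection rule, operating on in-memory results rather than the actual log format (which this sketch does not assume):

```python
def choose_top_strategies(results, k=3):
    # results: list of (pruning_strategy, adaptive_bn_accuracy) pairs,
    # e.g. as returned by evaluate_candidates() above.
    return sorted(results, key=lambda item: item[1], reverse=True)[:k]
```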
Finetuning of the searched pruned model: coming soon...

To test the provided pruned models, run main.py with the corresponding checkpoint.
For ResNet50:
python3 main.py \
--model_name resnet50 \
--num_classes 1000 \
--checkpoint models/ckpt/{resnet50_25flops.pth|resnet50_50flops.pth|resnet50_72flops.pth} \
--gpu_ids 4 \
--batch_size 512 \
--dataset_path {PATH_TO_IMAGENET} \
--dataset_name imagenet \
--num_workers 20
For MobileNetV1:
python3 main.py \
--model_name mobilenetv1 \
--num_classes 1000 \
--checkpoint models/ckpt/mobilenetv1_50flops.pth \
--gpu_ids 4 \
--batch_size 512 \
--dataset_path {PATH_TO_IMAGENET} \
--dataset_name imagenet \
--num_workers 20
After running the above program, the output looks like this:
######### Report #########
Model:resnet50
Checkpoint:models/ckpt/resnet50_50flops_7637.pth
FLOPs of Original Model:4.089G;Params of Original Model:25.50M
FLOPs of Pruned Model:2.057G;Params of Pruned Model:14.37M
Top-1 Acc of Pruned Model on imagenet:0.76366
##########################
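The report lists FLOPs, parameter counts, and top-1 accuracy. For reference, the sketch below shows one common way to obtain such numbers in PyTorch; it is not necessarily how main.py computes them, and the `thop` package used for FLOPs counting is a third-party tool assumed to be installed.

```python
import torch
from thop import profile  # third-party counter: pip install thop

def report_model_cost(model, input_size=(1, 3, 224, 224)):
    # Count multiply-accumulates (reported as "FLOPs" above) and parameters
    # by tracing the model with a dummy input.
    macs, params = profile(model, inputs=(torch.randn(*input_size),))
    print('FLOPs: {:.3f}G; Params: {:.2f}M'.format(macs / 1e9, params / 1e6))

@torch.no_grad()
def top1_accuracy(model, val_loader, device='cuda'):
    # Standard top-1 evaluation loop over the validation dataloader.
    model.eval().to(device)
    correct, total = 0, 0
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return correct / total
```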
Figure: correlation between evaluation and fine-tuning accuracy under different pruning ratios (MobileNetV1 on ImageNet classification, Top-1 results).
Model | FLOPs | Top-1 Acc | Top-5 Acc | Checkpoint |
---|---|---|---|---|
ResNet-50 | 3G | 77.1% | 93.37% | resnet50_75flops.pth |
ResNet-50 | 2G | 76.4% | 92.89% | resnet50_50flops.pth |
ResNet-50 | 1G | 74.2% | 91.77% | resnet50_25flops.pth |
MobileNetV1 | 284M | 70.9% | 89.62% | mobilenetv1_50flops.pth |
Model | FLOPs | Top-1 Acc |
---|---|---|
ResNet-50 | 62.23M | 94.66% |
MobileNetV1 | 26.5M | 91.89% |
MobileNetV1 | 12.1M | 91.44% |
MobileNetV1 | 3.3M | 88.01% |