google-research / rigl

End-to-end training of sparse deep neural networks with little-to-no performance loss.

License: Apache License 2.0

Python 91.40% Shell 0.33% Jupyter Notebook 8.27%
machine-learning computer-vision neural-networks sparse-training

rigl's Introduction

Rigging the Lottery: Making All Tickets Winners

80% Sparse Resnet-50

Paper: https://arxiv.org/abs/1911.11134

15min Presentation [pml4dc] [icml]

ML Reproducibility Challenge 2020 report

Colabs for Calculating FLOPs of Sparse Models

MobileNet-v1

ResNet-50

Best Sparse Models

Parameters are stored as floats, so each parameter takes 4 bytes. The uniform sparsity distribution keeps the first layer dense, so those models have a slightly larger size and parameter count. ERK applies to all layers except in the 99% sparse model, where we keep the first layer dense because we otherwise observe much worse performance.

Extended Training Results

The performance of RigL increases significantly with extended training. In this section we extend the training of sparse models by 5x. Note that sparse models require far fewer FLOPs per training iteration, so most of the extended runs still cost fewer FLOPs than baseline dense training.

Having observed improving performance, we wanted to understand where the performance of sparse networks saturates. The longest run we tried used 100x the training length of the original 100-epoch ImageNet training. This run costs 5.8x the FLOPs of the original dense training, and the resulting 99% sparse ResNet-50 achieves an impressive 68.15% test accuracy (vs. 61.86% with 5x training).
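The 5.8x figure can be sanity-checked from the 5x-training row of the same model (0.29x training FLOPs). The snippet below is only a back-of-the-envelope sketch: it assumes training cost scales linearly with training length, and `training_cost` is a hypothetical helper, not part of this repository.

```python
def training_cost(cost_at_5x, multiplier):
    """Training FLOPs relative to dense, assuming cost grows linearly with length."""
    cost_per_1x = cost_at_5x / 5
    return cost_per_1x * multiplier

# 99% sparse ERK ResNet-50: 0.29x at 5x training length.
print(round(training_cost(0.29, 100), 2))  # 5.8
```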

S. Distribution  Sparsity  Training FLOPs  Inference FLOPs  Model Size (Bytes)  Top-1 Acc  Ckpt
- (DENSE)        0         3.2e18          8.2e9            102.122             76.8       -
ERK              0.8       2.09x           0.42x            23.683              77.17      link
Uniform          0.8       1.14x           0.23x            23.685              76.71      link
ERK              0.9       1.23x           0.24x            13.499              76.42      link
Uniform          0.9       0.66x           0.13x            13.532              75.73      link
ERK              0.95      0.63x           0.12x            8.399               74.63      link
Uniform          0.95      0.42x           0.08x            8.433               73.22      link
ERK              0.965     0.45x           0.09x            6.904               72.77      link
Uniform          0.965     0.34x           0.07x            6.904               71.31      link
ERK              0.99      0.29x           0.05x            4.354               61.86      link
ERK              0.99      0.58x           0.05x            4.354               63.89      link
ERK              0.99      2.32x           0.05x            4.354               66.94      link
ERK              0.99      5.8x            0.05x            4.354               68.15      link

We also ran extended training runs with MobileNet-v1. Again training 100x longer, we were not able to saturate the performance. Training longer consistently achieved better results.

S. Distribution  Sparsity  Training FLOPs  Inference FLOPs  Model Size (Bytes)  Top-1 Acc  Ckpt
- (DENSE)        0         4.5e17          1.14e9           16.864              72.1       -
ERK              0.89      1.39x           0.21x            2.392               69.31      link
ERK              0.89      2.79x           0.21x            2.392               70.63      link
Uniform          0.89      1.25x           0.09x            2.392               69.28      link
Uniform          0.89      6.25x           0.09x            2.392               70.25      link
Uniform          0.89      12.5x           0.09x            2.392               70.59      link

1x Training Results

S. Distribution  Sparsity  Training FLOPs  Inference FLOPs  Model Size (Bytes)  Top-1 Acc  Ckpt
ERK              0.8       0.42x           0.42x            23.683              75.12      link
Uniform          0.8       0.23x           0.23x            23.685              74.60      link
ERK              0.9       0.24x           0.24x            13.499              73.07      link
Uniform          0.9       0.13x           0.13x            13.532              72.02      link

Results w/o label smoothing

S. Distribution  Sparsity  Training FLOPs  Inference FLOPs  Model Size (Bytes)  Top-1 Acc  Ckpt
ERK              0.8       0.42x           0.42x            23.683              75.02      link
ERK              0.8       2.09x           0.42x            23.683              76.17      link
ERK              0.9       0.24x           0.24x            13.499              73.4       link
ERK              0.9       1.23x           0.24x            13.499              75.9       link
ERK              0.95      0.13x           0.12x            8.399               70.39      link
ERK              0.95      0.63x           0.12x            8.399               74.36      link

Evaluating checkpoints

Download the checkpoints and run the evaluation on ERK checkpoints with the following:

python imagenet_train_eval.py --mode=eval_once --output_dir=path/to/ckpt/folder \
    --eval_once_ckpt_prefix=model.ckpt-3200000 --use_folder_stub=False \
    --training_method=rigl --mask_init_method=erdos_renyi_kernel \
    --first_layer_sparsity=-1

When evaluating checkpoints with the uniform sparsity distribution, use --mask_init_method=random and --first_layer_sparsity=0. Set --model_architecture=mobilenet_v1 when evaluating MobileNet checkpoints.
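The flag combinations above can be collected in a small helper when scripting evaluations over many checkpoints. This is only a sketch: `eval_flags` is a hypothetical function, not part of this repository, and it covers only the flags named in this section.

```python
def eval_flags(distribution, architecture="resnet50"):
    """Return the mask-related flags for imagenet_train_eval.py as a list."""
    if distribution == "erk":
        flags = ["--mask_init_method=erdos_renyi_kernel",
                 "--first_layer_sparsity=-1"]
    elif distribution == "uniform":
        flags = ["--mask_init_method=random",
                 "--first_layer_sparsity=0"]
    else:
        raise ValueError(f"unknown sparsity distribution: {distribution!r}")
    # Only MobileNet checkpoints need the architecture flag, per the text above.
    if architecture == "mobilenet_v1":
        flags.append("--model_architecture=mobilenet_v1")
    return flags

print(" ".join(eval_flags("uniform", "mobilenet_v1")))
```

The returned list can be appended to the rest of the command line, for example when calling the script through `subprocess.run`.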

Sparse Training Algorithms

In this repository we implement the following dynamic sparsity strategies:

  1. SET: Implements Sparse Evolutionary Training (SET), which replaces low-magnitude connections with new, randomly chosen ones.

  2. SNFS: Implements momentum-based training without sparsity re-distribution.

  3. RigL: Our method, RigL, removes a fraction of connections based on weight magnitudes and activates new ones using instantaneous gradient information.
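The drop/grow update described for RigL can be sketched for a single layer in NumPy. This is a minimal illustration, not the code in this repository: `rigl_mask_update` is a hypothetical name, `grads` is assumed to be the dense gradient (defined for inactive connections too), connections dropped in this step are assumed to be eligible for regrowth, and newly activated connections are assumed to start at zero.

```python
import numpy as np

def rigl_mask_update(weights, grads, mask, drop_fraction):
    """One drop/grow step for a single layer; returns (new_mask, new_weights)."""
    active = mask.astype(bool).ravel()
    n_update = int(drop_fraction * active.sum())

    # Drop: deactivate the n_update active connections with the smallest |weight|.
    drop_scores = np.where(active, np.abs(weights).ravel(), np.inf)
    drop_idx = np.argsort(drop_scores, kind="stable")[:n_update]
    after_drop = active.copy()
    after_drop[drop_idx] = False

    # Grow: activate the n_update inactive connections with the largest |gradient|.
    # Assumption: connections dropped just above are eligible again here.
    grow_scores = np.where(after_drop, -np.inf, np.abs(grads).ravel())
    grow_idx = np.argsort(-grow_scores, kind="stable")[:n_update]
    new_active = after_drop.copy()
    new_active[grow_idx] = True

    # Assumption: only connections active before and after keep their value;
    # newly grown connections start at zero.
    new_weights = np.where(new_active & active, weights.ravel(), 0.0)
    return (new_active.reshape(mask.shape).astype(mask.dtype),
            new_weights.reshape(weights.shape))
```

Because the same number of connections is dropped and grown, the layer's sparsity is unchanged by the update. Replacing the gradient-based grow scores with random scores would give a SET-style update instead.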

And the following one-shot pruning algorithm:

  1. SNIP: Single-shot Network Pruning based on connection sensitivity prunes the least salient connections before training.
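For comparison, connection-sensitivity pruning can be sketched as below. This is a hedged illustration rather than the repository's implementation: `snip_mask` is a hypothetical helper, saliency is taken as |weight × gradient| computed on a batch at initialization, and the ranking is applied to a single tensor for brevity.

```python
import numpy as np

def snip_mask(weights, grads, sparsity):
    """Keep the (1 - sparsity) fraction of connections with the largest saliency."""
    saliency = np.abs(weights * grads).ravel()
    n_keep = int(round((1.0 - sparsity) * saliency.size))
    # Indices of the n_keep most salient connections.
    keep_idx = np.argsort(-saliency, kind="stable")[:n_keep]
    mask = np.zeros(saliency.size, dtype=bool)
    mask[keep_idx] = True
    return mask.reshape(weights.shape)
```

Unlike the dynamic methods above, the mask is computed once before training and then stays fixed.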

We have code for the following settings:

  • Imagenet2012: TPU compatible code with Resnet-50 and MobileNet-v1/v2.
  • CIFAR-10 with WideResNets.
  • MNIST with 2 layer fully connected network.

Setup

First clone this repo.

git clone https://github.com/google-research/rigl.git
cd rigl

We use the NeurIPS 2019 MicroNet Challenge code for counting the operations and size of our networks. Let's clone the google_research repo and add the current folder to the Python path.

git clone https://github.com/google-research/google-research.git
mv google-research/ google_research/
export PYTHONPATH=$PYTHONPATH:$PWD

Now we can run some tests. The following script creates a virtual environment and installs the necessary libraries. Finally, it runs a few tests.

bash run.sh

We need to activate the virtual environment before running an experiment. With that, we are ready to run some trivial MNIST experiments.

source env/bin/activate

python rigl/mnist/mnist_train_eval.py

You can load and verify the performance of the ResNet-50 checkpoints as follows.

python rigl/imagenet_resnet/imagenet_train_eval.py --mode=eval_once \
    --training_method=baseline --eval_batch_size=100 \
    --output_dir=/path/to/folder \
    --eval_once_ckpt_prefix=s80_model.ckpt-1280000 --use_folder_stub=False

We use the Official TPU Code for loading ImageNet data. First clone the tensorflow/tpu repo and then add the models/ folder to the Python path.

git clone https://github.com/tensorflow/tpu.git
export PYTHONPATH=$PYTHONPATH:$PWD/tpu/models/

Other Implementations

Citation

@incollection{rigl,
 author = {Evci, Utku and Gale, Trevor and Menick, Jacob and Castro, Pablo Samuel and Elsen, Erich},
 booktitle = {Proceedings of Machine Learning and Systems 2020},
 pages = {471--481},
 title = {Rigging the Lottery: Making All Tickets Winners},
 year = {2020}
}

Disclaimer

This is not an official Google product.

rigl's People

Contributors

conchylicultor, evcu, hawkinsp, marcvanzee, psc-g, pwohlhart, qlzh727, rchen152, sun51, yilei

rigl's Issues

Grow & Drop Ambiguity

The paper states that when selecting the weights to grow with the ArgTopK function over the gradients of the sparse values, you may NOT select indices that remain after the drop phase.

After the drop phase, the topology of the network is technically smaller for a brief moment before entering the grow phase, so I have a couple of questions:

  1. In the grow phase, can elements that were dropped in the drop phase be de-selected for drop candidacy (assuming their gradient is large enough to be selected by the ArgTopK)?

  2. If the answer to #1 is "yes", then are these values re-initialized to 0, or are they unaltered?

EDIT: It seems like here covers this. It looks to me like they are by default NOT re-initialized, but rather kept as-is.

I have had a hard time parsing the code here, though it seems to me that the answer to 1 should be "yes" and the answer to 2 should be "they are unaltered".

I am reimplementing the paper in PyTorch and am having a hard time reproducing your results. I ran my previous simulations under the assumption that the answer to 1 is "no"; I am now re-running them with the "yes" and "unaltered" answers.

Thank you!

Specify TF and TF.data versions?

Hi,

I tried running the cifar10 example with:

tensorflow-datasets      1.3.0
tensorflow-estimator     1.15.1
tensorflow-gpu           1.15.4

This fails with the error:

  File "/srv/home/varunsundar/rigl/rigl/cifar_resnet/data_helper.py", line 105, in input_fn
    images_batch, labels_batch = tf.compat.v1.data.make_one_shot_iterator(
  File "/srv/home/varunsundar/.conda/envs/tf37/lib/python3.7/site-packages/tensorflow_core/python/util/module_wrapper.py", line 193, in __getattr__
    attr = getattr(self._tfmw_wrapped_module, name)
AttributeError: module 'tensorflow._api.v1.compat.v1.compat' has no attribute 'v1'

It seems the TensorFlow version is too high (too new). I'm guessing tf.compat.v1.data became tf.data?

Could you specify the exact tf and tf-datasets versions to be used for reproduction?

MetaInit w. VGG

I want to try MetaInit with VGG. Can you share some instructions for it? Thank you so much

No module named 'officialresnet'

Hi,
I am trying to run imagenet_train_eval.py but I get the following error:
Traceback (most recent call last):
  File "rigl/imagenet_resnet/imagenet_train_eval.py", line 37, in <module>
    from officialresnet import imagenet_input
ModuleNotFoundError: No module named 'officialresnet'
Should I install officialresnet library separately?
Thanks

Using Rigl to train with structured sparsity

Hi,
Thank you for sharing your work.

I have been using this codebase by modifying the initial sparsification method to give some structured sparsity to the network. Now, what I would like is to make sure that when back-propagation takes place and weights are updated, they only update the values in the initial non-zero positions and do not update the positions which are already zero.

How can I achieve this with the codebase here?

TF2 Grow Scores calculation

The TF1 code calculates grow scores from the gradients of the masked variables (after the masks and variables are multiplied).
Refer to:

masked_grads_vars = self._optimizer.compute_gradients(

where masked_weights are fetched from https://github.com/tensorflow/tensorflow/blob/r1.8/tensorflow/contrib/model_pruning/python/pruning.py#L258

In the TF2 code there is no such step, which I think affects the performance of RigL in general. Is there any way to apply a fix for this?
Also, is this because in the TF2 code the updates are made using part of the validation dataset? If I were to change it back to how TF1 does the updates, based on the specific training batch, I would have to make the above-mentioned change, right?
[We want the gradient calculation for the mask update to happen after the mask and weights are multiplied, not on the original weights before parameterization, as shown in the attached figure.]
image
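The distinction raised here can be illustrated with a small NumPy example. It is a sketch for a least-squares loss, not the repository's TF1 or TF2 code, and `grads` is a hypothetical helper. With loss L = 0.5 * ||x @ (w * mask) - y||^2, the gradient with respect to the masked weights is dense, while the chain rule makes the gradient with respect to the raw weights zero wherever the mask is zero, so it cannot rank inactive connections for growth.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3))
y = rng.normal(size=(8, 2))
w = rng.normal(size=(3, 2))
mask = np.array([[1, 0], [0, 1], [1, 1]], dtype=float)

def grads(x, y, w, mask):
    """Gradients of 0.5 * ||x @ (w * mask) - y||^2."""
    residual = x @ (w * mask) - y
    grad_masked = x.T @ residual    # d loss / d (w * mask): dense
    grad_raw = grad_masked * mask   # d loss / d w: zero where mask == 0
    return grad_masked, grad_raw

grad_masked, grad_raw = grads(x, y, w, mask)
print(grad_raw[mask == 0])     # all zeros: no signal for growing
print(grad_masked[mask == 0])  # generally nonzero: usable as grow scores
```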

RigL TF2 on Resnet50 + Imagenet

I wanted to use the RigL TF2 code to train a sparsified ResNet-50 architecture and see how that goes.
I loaded a ResNet-50:
import tensorflow as tf

# ResNet50 is the model constructor; tf.keras.applications.resnet50 is a module.
model = tf.keras.applications.ResNet50(
    include_top=True,
    weights="imagenet",
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000,
    classifier_activation="softmax",
)

And then sparsified it.
The test accuracy seems to be stuck at 0.1%, even after the first 2000 steps.
Is this a common problem?

RigL TF2 - Initial Sparsification

There is an ongoing effort to open-source a TF2 version with the tf-model_optimization toolkit. It should be out in a few months. The PyTorch implementation is on hold (I plan to do it when I have more time), but I am happy to help if you are interested in adding that to the repo.

Originally posted by @evcu in #2 (comment)

Hello @evcu ,
I am going through the TF2 implementation of RigL and I'm trying to understand how the initial sparsification of the model is controlled. I don't see an explicit argument like mask_init_method in the TF1 implementation, which I could set to ERK or uniform. Can you point me to how I can control the initial sparsification of the model? Is it through pruning_params in the rigl.gin config file?

Can you post the CIFAR-10 results of the paper?

Hi,

I am trying to reproduce your paper's results on CIFAR-10 in PyTorch. Figure 4 in the paper shows the results, but there are no exact numbers.
Can you please post the exact results you got from your experiments on CIFAR?

Thanks

How to load pretrained weights?

Hi!
How can I load the pretrained weights listed in the README?

I want to extract those weights for research purposes.

Can you tell me how I can load them, or give me brief code to build the model and load the weights?

How to train own convolutional network

Hi, good work!
How can I train my CNN using RigL?

In issue #2, you recommend modifying the model and wrapping the optimizer, but no matter how I try to do it, it doesn't work for my CNN.

Can you tell me exactly how this is done for CNN?

Thank you!
