pruning_before_training

Gradient Coupled Flow: Performance Boosting On Network Pruning By Utilizing Implicit Loss Decrease

Introduction

This project is a PyTorch implementation of the GCS paper. It also includes several of the pruning-before-training methods evaluated in the paper.

@author: Kang Xiatao ([email protected])

Structure

  • models: three models (lenet, resnet, vgg) and the model class that holds the pruning mask
  • pruner: pruning algorithms
  • runs: output folder for weights, evaluation results, etc.
  • utils: custom helper functions
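The models package pairs each network with a pruning mask. As a rough illustration only, the sketch below shows the usual idea behind such a class: a binary mask with the same shape as the weight matrix is multiplied element-wise into the weights before the forward pass, so pruned connections contribute nothing. `MaskedLinear` and its methods are hypothetical names, not the classes in `models`, and plain Python lists stand in for PyTorch tensors.

```python
from typing import List

Matrix = List[List[float]]


class MaskedLinear:
    """Hypothetical masked fully connected layer: y = (W * M) x + b."""

    def __init__(self, weight: Matrix, bias: List[float]) -> None:
        self.weight = weight
        self.bias = bias
        # Start dense: every connection is kept (mask value 1).
        self.mask = [[1.0] * len(row) for row in weight]

    def set_mask(self, mask: Matrix) -> None:
        # The mask must have exactly the same shape as the weight matrix.
        assert len(mask) == len(self.weight)
        assert all(len(m) == len(w) for m, w in zip(mask, self.weight))
        self.mask = mask

    def effective_weight(self) -> Matrix:
        # Element-wise product: entries with mask 0 become exactly 0.
        return [[w * m for w, m in zip(w_row, m_row)]
                for w_row, m_row in zip(self.weight, self.mask)]

    def forward(self, x: List[float]) -> List[float]:
        w_eff = self.effective_weight()
        return [sum(w * xi for w, xi in zip(row, x)) + b
                for row, b in zip(w_eff, self.bias)]


layer = MaskedLinear(weight=[[1.0, 2.0], [3.0, 4.0]], bias=[0.0, 0.5])
print(layer.forward([1.0, 1.0]))  # dense: [3.0, 7.5]
layer.set_mask([[1.0, 0.0], [0.0, 1.0]])
print(layer.forward([1.0, 1.0]))  # masked: [1.0, 4.5]
```

In PyTorch the same idea is usually written as `weight * mask`, with the mask registered via `register_buffer` so it is saved with the model but not updated by the optimizer.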

Environment

The code has been tested on Python 3.

Tested package versions:

  • torch == 1.7.1
  • numpy == 1.18.5
  • tensorboardX == 2.4

Run

  • Example: cifar10/vgg19, pruning ratio 90%
# GCS
python main.py --config 'cifar10/vgg19/90' --run 'test' --rank_algo 'gcs' --prune_mode 'rank'
# GCS-Group
python main.py --config 'cifar10/vgg19/90' --run 'test' --rank_algo 'gcs-group' --prune_mode 'rank'
  • Example: mnist/lenet5, pruning ratio 99.9%
python main.py --config 'mnist/lenet5/99.9' --run 'test' --rank_algo 'gcs' --prune_mode 'rank'
  • Available models: lenet, vgg, resnet

  • Available datasets: fashionmnist, mnist, cifar10, cifar100 (other datasets must be downloaded manually to a local directory)

  • All parameters (the defaults are defined in configs.py):

    Command-line parameter = default # Description
    config = '' # Select the dataset, model, and pruning ratio
    pretrained = '' # Path to a pretrained model to load
    run = 'test' # Experiment notes
    rank_algo = 'gcs' # Pruning algorithm (snip, grasp, synflow, gcs, gcs-group)
    prune_mode = 'rank' # Pruning mode (dense, rank, rank/random, rank/iterative)
    dp = '../Data' # Path to the dataset directory
    storage_mask = 0 # Whether to store the resulting mask
    --- Parameters for debugging
    debug = 0 # Enable debugging
    epoch # Number of training epochs
    batch_size # Training batch size
    l2 # L2 regularization coefficient
    lr_mode = 'cosine' # Learning-rate decay method (cosine or preset)
    optim_mode = 'SGD' # Optimizer (SGD or Adam)
    train_mode = 1 # Whether to use train mode when computing weight sensitivity
    dynamic = 1 # Whether to use dynamic iteration
    num_iters_prune # Number of rounds for iterative pruning (default: 100)
    data_mode # Data sampling mode (see pruning.py for details)
    grad_mode # Gradient computation mode (see pruning.py for details)
    score_mode # Mode for computing the ranking score (see pruning.py for details)
    num_group # Number of gradient groups (GCS-Group)
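The config key combines dataset, model, and pruning ratio, and `prune_mode = 'rank'` keeps the highest-scoring weights. The sketch below shows one common way to turn a ratio such as 90 into a global mask, using the SNIP saliency |g · w| (gradient times weight) as the score. It assumes the last component of the config key is the percentage of weights removed, as the examples above suggest. `parse_config_key`, `snip_scores`, and `global_rank_mask` are hypothetical helpers, the weights and gradients are toy values, and the GCS scores computed in `pruner` are not reproduced here.

```python
from typing import Dict, List, Tuple


def parse_config_key(key: str) -> Tuple[str, str, float]:
    """Assumed layout: 'dataset/model/percent_pruned', e.g. 'cifar10/vgg19/90'."""
    dataset, model, ratio = key.split("/")
    return dataset, model, float(ratio) / 100.0


def snip_scores(weights: List[float], grads: List[float]) -> List[float]:
    # SNIP connection sensitivity: |dL/dw * w| for each weight.
    return [abs(g * w) for w, g in zip(weights, grads)]


def global_rank_mask(scores: Dict[str, List[float]],
                     prune_ratio: float) -> Dict[str, List[int]]:
    """Keep the top (1 - prune_ratio) fraction of scores across all layers."""
    flat = [s for layer in scores.values() for s in layer]
    num_keep = max(1, round(len(flat) * (1.0 - prune_ratio)))
    # Rank all weights globally; the stable sort breaks ties by position,
    # so exactly num_keep weights survive.
    order = sorted(range(len(flat)), key=lambda i: flat[i], reverse=True)
    keep = set(order[:num_keep])
    masks, offset = {}, 0
    for name, layer in scores.items():
        masks[name] = [1 if offset + i in keep else 0 for i in range(len(layer))]
        offset += len(layer)
    return masks


dataset, model, prune_ratio = parse_config_key("cifar10/vgg19/90")
scores = {
    "conv1": snip_scores([1.0, -2.0, 0.5, 4.0, -3.0], [0.5, 1.0, -0.5, 2.0, 0.25]),
    "fc": snip_scores([2.0, -1.0, 3.0, 0.5, -4.0], [0.5, 0.0625, -1.0, 0.75, 1.5]),
}
masks = global_rank_mask(scores, prune_ratio)
print(dataset, model, prune_ratio)  # cifar10 vgg19 0.9
print(masks)  # {'conv1': [0, 0, 0, 1, 0], 'fc': [0, 0, 0, 0, 0]}
```

Ranking globally rather than per layer lets the scores decide how sparse each layer ends up; SNIP, GraSP, and SynFlow are commonly applied with a single global threshold in this way.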
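With `prune_mode = 'rank/iterative'`, the mask is recomputed over `num_iters_prune` rounds instead of in a single shot. How this repository spaces those rounds is defined in the pruner code and is not described here; the sketch below assumes the exponential schedule popularized by SynFlow, where after round k (1-based) of n the kept fraction is d^(k/n) for a final density d = 1 - pruning ratio. `iterative_keep_schedule` is a hypothetical helper.

```python
from typing import List


def iterative_keep_schedule(prune_ratio: float, num_iters: int) -> List[float]:
    """Fraction of weights kept after each round of an exponential schedule.

    Assumption: round k (1-based) keeps density ** (k / num_iters) of all
    weights, so the final round reaches exactly 1 - prune_ratio.
    """
    if not 0.0 <= prune_ratio < 1.0:
        raise ValueError("prune_ratio must be in [0, 1)")
    if num_iters < 1:
        raise ValueError("num_iters must be at least 1")
    density = 1.0 - prune_ratio
    return [density ** (k / num_iters) for k in range(1, num_iters + 1)]


schedule = iterative_keep_schedule(prune_ratio=0.999, num_iters=3)
print([round(x, 4) for x in schedule])  # [0.1, 0.01, 0.001]
```

In each round, the network would be re-scored with the current mask applied and only the top `schedule[k]` fraction of all weights kept, so early rounds prune gently and the target sparsity is reached only at the end.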
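`lr_mode = 'cosine'` selects cosine learning-rate decay. Assuming the standard cosine annealing form without restarts, the learning rate at epoch t of T is lr_min + 0.5 (lr_max - lr_min)(1 + cos(π t / T)). The helper `cosine_lr`, the base rate of 0.1, and the 160-epoch budget below are illustrative assumptions, not values taken from configs.py.

```python
import math


def cosine_lr(epoch: int, total_epochs: int, lr_max: float, lr_min: float = 0.0) -> float:
    """Cosine-annealed learning rate for a 0-based epoch index."""
    progress = epoch / total_epochs
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


total = 160  # hypothetical number of training epochs
for epoch in (0, 40, 80, 120, 160):
    # Starts at lr_max, passes lr_max / 2 halfway, and ends at lr_min.
    print(epoch, round(cosine_lr(epoch, total, lr_max=0.1), 4))
```

In PyTorch itself the same curve is usually obtained with `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)` stepped once per epoch; whether this repository uses that class or its own formula is not stated here.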

...

To be added ...
