This project is a fork of ist-daslab/cap.

License: Apache License 2.0

Official repository of CAP


This repository contains the code for the NeurIPS 2023 paper CAP: Correlation-Aware Pruning for Highly-Accurate Sparse Vision Models.

To facilitate reproducibility of our experiments, we integrate our pruners with the popular open-source library SparseML and build on top of rwightman's train.py script from https://github.com/rwightman/pytorch-image-models.

Structure of the repository


The modified source code from SparseML is located in src/ and its subdirectories. The main pruning algorithms are implemented in the src/sparseml/pytorch/sparsification/pruning directory as SparseML PruningModifiers. Notably:

  • CAPruningModifier: modifier_pruning_cap.py (our CAP pruner)
  • FastCAPruningModifier: modifier_pruning_fast_cap.py (our FastCAP pruner)
  • GlobalMagnitudePruningModifier: modifier_pruning_magnitude.py (implemented by NeuralMagic)
  • OBSPruningModifier: modifier_pruning_obs.py (implemented by authors of the oBERT paper)

The code to launch experiments is located inside the research/ directory.

  • research/ — root directory for experiments
    ├── sparse_training.py — main script for gradual pruning (based on train.py from timm)
    ├── one_shot_pruning.py — script for running one-shot pruning experiments
    ├── run_gradual_pruning.sh — script to launch sparse_training.py
    ├── run_one_shot_pruning.sh — script to launch one_shot_pruning.py
    ├── utils/ — additional utils used in training scripts
    ├── configs/ — .yaml configs with training hyperparameters
    └── recipes/ — SparseML recipes for pruning

Usage


Installation

The recommended way to run CAP is in a conda environment.

Configure the environment

To run the code, you need PyTorch with GPU support and the timm library. Follow the steps below to set up a conda environment:

conda create --name CAP python==3.9
conda activate CAP
conda install scipy numpy scikit-learn pytorch=1.13.1 torchvision==0.14.1 torchaudio==0.13.1 cudatoolkit=11.3 -c pytorch 
pip install -r requirements.txt

To install SparseML, run (in the root directory of the project):

python setup.py install
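After installation, a quick sanity check (a sketch, not part of the repository; it assumes the CAP conda environment is active) confirms that the core packages import:

```shell
# Report which of the core dependencies are importable in the current env.
for mod in torch torchvision timm sparseml; do
    if python -c "import $mod" 2>/dev/null; then
        echo "$mod: OK"
    else
        echo "$mod: MISSING"
    fi
done
```

Any line reporting MISSING points at a step above that needs to be repeated.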

(Optional) We use W&B for logging. Install it via pip if you want to log data there:

pip install wandb

If logging to W&B, define the W&B environment variables before launching the script:

export WANDB_ENTITY=<your_entity>
export WANDB_PROJECT=<project_name>
export WANDB_NAME=<run_name>
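A forgotten variable is easy to catch before a long run. A minimal pre-launch check (hypothetical helper, not part of the repository) could look like:

```shell
# Report whether all three W&B variables are set before launching a run.
missing=""
for var in WANDB_ENTITY WANDB_PROJECT WANDB_NAME; do
    eval "val=\${$var:-}"
    [ -n "$val" ] || missing="$missing $var"
done
if [ -n "$missing" ]; then
    echo "W&B logging not fully configured; unset:$missing"
else
    echo "W&B logging configured: ${WANDB_ENTITY}/${WANDB_PROJECT} (${WANDB_NAME})"
fi
```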

Workflow

  • Select a config with training hyperparameters (research/configs)
  • Select a SparseML recipe (research/recipes)
  • Define other hyperparams in the launch script (research/run_gradual_pruning.sh or research/run_one_shot_pruning.sh)
  • Enjoy!

Example usage

Recipes used in the paper are located in the research/recipes directory. Choose a recipe from the one_shot subdirectory for one-shot pruning, from one_shot+finetune for one-shot pruning followed by finetuning, and from gradual_pruning for experiments with a gradual increase of the sparsity level.

One-shot pruning

python one_shot_pruning.py \
    --data-dir <data_dir> \
    --sparseml-recipe <path_to_recipe> \
    --model <model_name> \
    --experiment <experiment_name> \
    -gb <gs_loader_batch_size> \
    -vb <validation_batch_size> \
    --sparsities <list_of_sparsities>

One-shot+finetune/gradual pruning

python -m torch.distributed.launch \
    --nproc_per_node=<num_proc> \
    --master_port=<master_port> \
    sparse_training.py \
    --data-dir <data_dir> \
    --sparseml-recipe <path_to_recipe> \
    --model <model_name> \
    --experiment <experiment_name> \
    -gb <gs_loader_batch_size> \
    -vb <validation_batch_size> \
    --sparsities <list_of_sparsities>

Tweaking CAP hyperparameters

There are several hyperparameters in the CAP method that can be adjusted for better performance and tuned for each model/dataset. We provide defaults that should work well across many different models, as demonstrated in the paper.

    :param mask_type: String defining the type of sparsity to apply. 'unstructured',
        'block4', and 'N:M' are supported. Default is 'unstructured'. For N:M, provide
        two integers that will be parsed, e.g. '2:4'
    :param num_grads: number of gradients used to calculate the Fisher approximation
    :param damp: dampening factor, default is 1e-7
    :param fisher_block_size: size of blocks along the main diagonal of the Fisher
        approximation, default is 50
    :param grad_sampler_kwargs: kwargs to override the default train dataloader config
        for the pruner's gradient sampling
    :param num_recomputations: number of EmpiricalFisher matrix recomputations
    :param blocks_in_parallel: number of rows traversed simultaneously by the OBSX pruning modifier
    :param fisher_inv_device: specific device on which to store the Fisher inverses
    :param traces_backup_dir: str. To store pruning traces on disk, specify a
        temporary directory for storage.
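These hyperparameters are set on the pruning modifier inside a SparseML recipe. The fragment below is a hypothetical sketch, not one of the paper's recipes: the CAP-specific fields come from the parameter list above, while the schedule fields (start_epoch, final_sparsity, etc.) follow the common SparseML PruningModifier schema; consult the recipes in research/recipes for the exact settings.

```yaml
# Hypothetical recipe fragment -- values are illustrative placeholders.
pruning_modifiers:
  - !CAPruningModifier
    start_epoch: 0.0
    end_epoch: 10.0
    update_frequency: 1.0
    init_sparsity: 0.0
    final_sparsity: 0.75
    params: ["re:.*weight"]
    mask_type: unstructured   # 'unstructured', 'block4', or 'N:M' (e.g. '2:4')
    num_grads: 1024           # gradients used for the Fisher approximation
    damp: 1.0e-7              # dampening factor (default)
    fisher_block_size: 50     # block size along the Fisher diagonal (default)
```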

Contributors

  • godofnothing