
guanhuawang / sensai


sensAI: ConvNets Decomposition via Class Parallelism for Fast Inference on Live Data

Home Page: https://rise.cs.berkeley.edu/projects/sensai/

License: Apache License 2.0

Python 97.60% Shell 2.40%
machine-learning distributed-systems cifar10 cifar100 imagenet1k cnn-classification sysml deep-neural-networks deep-learning distributed-deep-learning

sensai's Introduction

Environment

Linux, Python 3.6+

Setup

pip install -r requirements.txt

Instructions

Supported CNN architectures and datasets:

Dataset        Architectures (ARCH)
CIFAR-10       vgg19_bn, resnet110, resnet164, mobilenetv2, shufflenetv2
CIFAR-100      vgg19_bn, resnet110, resnet164
ImageNet-1K    vgg19_bn, resnet50

1. Generate class groups

For CIFAR-10/CIFAR-100:

python3 group_selection.py \
        --arch $ARCH \
        --resume $pretrained_model \
        --dataset $DATASET \
        --ngroups $number_of_groups \
        --gpu_num $number_of_gpu 

For ImageNet-1K:

python3 group_selection.py \
        --arch $ARCH \
        --dataset imagenet \
        --ngroups $number_of_groups \
        --gpu_num $number_of_gpu \
        --data /{path_to_imagenet_dataset}/

Pruning candidates are now stored in ./prune_candidate_logs/

2. Prune models

For CIFAR-10/CIFAR-100:

python3 prune_and_get_model.py \
        -a $ARCH \
        --dataset $DATASET \
        --resume $pretrained_model \
        -c ./prune_candidate_logs/ \
        -s ./{TO_SAVE_PRUNED_MODEL_DIR}/

For ImageNet-1K:

python3 prune_and_get_model.py \
        -a $ARCH \
        --dataset imagenet \
        -c ./prune_candidate_logs/ \
        -s ./{TO_SAVE_PRUNED_MODEL_DIR}/ \
        --pretrained

Pruned models are now saved in ./{TO_SAVE_PRUNED_MODEL_DIR}/$ARCH/

3. Retrain pruned models

For CIFAR-10/CIFAR-100:

python3 retrain_grouped_model.py \
        -a $ARCH \
        --dataset $DATASET \
        --resume ./{TO_SAVE_PRUNED_MODEL_DIR}/ \
        --train_batch $batch_size \
        --epochs $number_of_epochs \
        --num_gpus $number_of_gpus

For ImageNet-1K:

python3 retrain_grouped_model.py \
        -a $ARCH \
        --dataset imagenet \
        --resume ./{TO_SAVE_PRUNED_MODEL_DIR}/ \
        --epochs $number_of_epochs \
        --num_gpus $number_of_gpus \
        --train_batch $batch_size \
        --data /{path_to_imagenet_dataset}/

Retrained models are now saved in ./{TO_SAVE_PRUNED_MODEL_DIR}_retrained/$ARCH/

4. Evaluate

For CIFAR-10/CIFAR-100:

python3 evaluate.py \
        -a $ARCH \
        --dataset=$DATASET \
        --retrained_dir ./{TO_SAVE_PRUNED_MODEL_DIR}_retrained/ \
        --test-batch $batch_size

For ImageNet-1K:

python3 evaluate.py \
        -d imagenet \
        -a $ARCH \
        --retrained_dir ./{TO_SAVE_PRUNED_MODEL_DIR}_retrained/ \
        --data /{path_to_imagenet_dataset}/
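The four steps above can be chained from Python. The sketch below only assembles the CIFAR-10/CIFAR-100 commands with the same scripts and flags the README documents, and prints them; build_cifar_pipeline and run_pipeline are hypothetical helper names, the script is assumed to run from the repository root, and the argument values in the example call are placeholders rather than recommended settings. As in the README, step 3 receives the pruned-model directory and step 4 receives the matching *_retrained directory.

```python
import shlex
import subprocess


def build_cifar_pipeline(arch, dataset, pretrained, save_dir,
                         ngroups, num_gpus, batch_size, epochs):
    """Return the README's four CIFAR-10/CIFAR-100 commands as argument lists."""
    retrained_dir = save_dir.rstrip("/") + "_retrained/"
    return [
        # 1. Generate class groups (pruning candidates go to ./prune_candidate_logs/).
        ["python3", "group_selection.py", "--arch", arch, "--resume", pretrained,
         "--dataset", dataset, "--ngroups", str(ngroups), "--gpu_num", str(num_gpus)],
        # 2. Prune the pretrained model according to the stored candidates.
        ["python3", "prune_and_get_model.py", "-a", arch, "--dataset", dataset,
         "--resume", pretrained, "-c", "./prune_candidate_logs/", "-s", save_dir],
        # 3. Retrain the pruned models.
        ["python3", "retrain_grouped_model.py", "-a", arch, "--dataset", dataset,
         "--resume", save_dir, "--train_batch", str(batch_size),
         "--epochs", str(epochs), "--num_gpus", str(num_gpus)],
        # 4. Evaluate the retrained models.
        ["python3", "evaluate.py", "-a", arch, "--dataset=" + dataset,
         "--retrained_dir", retrained_dir, "--test-batch", str(batch_size)],
    ]


def run_pipeline(commands, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in commands:
        print(" ".join(shlex.quote(part) for part in cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing step


# Placeholder values for illustration only.
commands = build_cifar_pipeline("resnet110", "cifar10", "resnet110-cifar10.pth.tar",
                                "./pruned_models/", ngroups=10, num_gpus=2,
                                batch_size=64, epochs=1)
run_pipeline(commands)  # dry run: prints the four commands
```

Passing dry_run=False runs the steps in order; because of check=True, subprocess.run raises CalledProcessError as soon as a step exits with a non-zero status, so later steps never run on incomplete output.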

Contributors

Thanks to all the main contributors to this repository, and to many others: Zihao Fan, Hank O'Brien, Yaoqing Yang, Adarsh Karnati, Jichan Chung, Yingxin Kang, Balaji Veeramani, Sahil Rao.

Citation

@inproceedings{wang2021sensAI,
 author = {Guanhua Wang and Zhuang Liu and Brandon Hsieh and Siyuan Zhuang and Joseph Gonzalez and Trevor Darrell and Ion Stoica},
 title = {{sensAI: ConvNets Decomposition via Class Parallelism for Fast Inference on Live Data}},
 booktitle = {Proceedings of the Fourth Conference on Machine Learning and Systems (MLSys'21)},
 year = {2021}
} 

sensai's People

Contributors

guanhuawang, hjobrien, hsiehbrandon, jason-khan, kenan-jiang, suquark, zihao-fan

sensai's Issues

Missing Packages in requirements.txt

🚀 Feature

I propose we add numpy, scikit-learn, and matplotlib to requirements.txt.

Motivation

Several Python modules in the repository depend on scikit-learn, but scikit-learn is not listed in requirements.txt. As a result, these modules raise errors if pip3 install -r requirements.txt is run but pip3 install scikit-learn is not. Moreover, some modules only work with scikit-learn versions strictly below 0.22 (the most recent version at the time of writing), so it would be prudent to pin the scikit-learn version as well to get consistent results.

matplotlib is used in a few of the utility modules, which will similarly fail to run if matplotlib is not installed alongside the packages listed in requirements.txt.

numpy is installed as part of torch; however, relying on numpy to be installed with torch is sub-optimal for a couple of reasons: one, we lose the flexibility to specify a specific version of numpy; two, we break an abstraction barrier by assuming the implementation of torch.

Pitch

Add the following three lines to requirements.txt:

matplotlib
numpy
scikit-learn==0.21

Alternatives

numpy can be excluded from the requirements.txt and the code should execute normally. Because matplotlib is only used in a couple of utility modules, matplotlib could be an optional dependency.

Additional context

scikit-learn must be version 0.21 or earlier because even_k_means calls _check_sample_weight, whose function signature changed between versions 0.21 and 0.22.
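Besides pinning the version in requirements.txt, a module could fail fast with a readable message when an incompatible scikit-learn is installed. The sketch below is a hypothetical guard (parse_version and check_sklearn_version are illustrative names, not part of the repository). It compares numeric tuples rather than raw strings, because as strings "0.3" would sort after "0.22".

```python
def parse_version(version):
    """Return the leading numeric release parts, e.g. '0.21.3' -> (0, 21, 3)."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) < len(piece):
            break  # stop at suffixes such as "rc1" or "dev0"
    return tuple(parts)


def check_sklearn_version(version, first_bad=(0, 22)):
    """Raise if scikit-learn is at or above the release where the signature changed."""
    if parse_version(version) >= first_bad:
        raise RuntimeError(
            "scikit-learn %s is installed, but even_k_means needs a version below %s"
            % (version, ".".join(str(p) for p in first_bad)))


check_sklearn_version("0.21.3")  # passes silently
```

In the repository it would be called as check_sklearn_version(sklearn.__version__) before even_k_means is used.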

Averaging by batch not applied in get_prune_candidates.py

In get_prune_candidates.py:

with ActivationRecord(model) as recorder:
        # collect pruning data
        bar = tqdm(total=len(pruning_loader))
        for batch_idx, (inputs, _) in enumerate(pruning_loader):
            bar.update(1)
            if use_cuda:
                inputs = inputs.cuda()
            recorder.record_batch(inputs)
        candidates_by_layer = recorder.generate_pruned_candidates()
        return candidates_by_layer

When recorder.generate_pruned_candidates() runs, apoz_scores are expected to lie in the range [0, 100]. Because the call happens inside the with block, however, ActivationRecord.__exit__() has not run yet, so the per-batch averaging is never applied: the apoz_scores exceed 100, and the APoZ percentage thresholds no longer behave as intended.

I suggest dedenting the last two lines so that they run after the with block exits:

with ActivationRecord(model) as recorder:
        # collect pruning data
        bar = tqdm(total=len(pruning_loader))
        for batch_idx, (inputs, _) in enumerate(pruning_loader):
            bar.update(1)
            if use_cuda:
                inputs = inputs.cuda()
            recorder.record_batch(inputs)
candidates_by_layer = recorder.generate_pruned_candidates()
return candidates_by_layer
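The effect of the indentation can be reproduced with a toy recorder. ActivationRecord's internals are not shown in the issue, so ToyActivationRecord below is a hypothetical stand-in that assumes, as the issue describes, that per-batch APoZ percentages are summed in record_batch and divided by the batch count in __exit__(). It also assumes that channels whose APoZ exceeds the threshold become pruning candidates.

```python
class ToyActivationRecord:
    """Hypothetical stand-in for ActivationRecord that tracks per-channel APoZ."""

    def __init__(self, num_channels):
        self.apoz_scores = [0.0] * num_channels  # running sum until __exit__
        self.num_batches = 0

    def __enter__(self):
        return self

    def record_batch(self, activations):
        # activations[c] holds the post-ReLU outputs of channel c for one batch.
        for c, values in enumerate(activations):
            zeros = sum(1 for v in values if v == 0)
            self.apoz_scores[c] += 100.0 * zeros / len(values)
        self.num_batches += 1

    def __exit__(self, exc_type, exc_value, traceback):
        # Average over batches so that every score lies in [0, 100].
        if self.num_batches:
            self.apoz_scores = [s / self.num_batches for s in self.apoz_scores]
        return False  # do not swallow exceptions

    def generate_pruned_candidates(self, threshold):
        # Channels whose APoZ exceeds the threshold are pruning candidates.
        return [c for c, s in enumerate(self.apoz_scores) if s > threshold]


# Two channels, three identical batches: channel 0 is 50% zeros, channel 1 is 25%.
batches = [[[0, 0, 1, 2], [0, 3, 4, 5]]] * 3

# Original placement: candidates are computed before __exit__ averages the sums.
with ToyActivationRecord(2) as recorder:
    for batch in batches:
        recorder.record_batch(batch)
    buggy = recorder.generate_pruned_candidates(threshold=40.0)

# Suggested placement: candidates are computed after the with block exits.
with ToyActivationRecord(2) as recorder:
    for batch in batches:
        recorder.record_batch(batch)
fixed = recorder.generate_pruned_candidates(threshold=40.0)

print(buggy, fixed)  # [0, 1] [0]
```

Inside the with block the summed scores are 150 and 75, so both channels clear the 40% threshold; after __exit__() they are 50 and 25, and only channel 0 does.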

Cannot generate all 10 pruned models for cifar10

When running:
python3 group_selection.py --arch=resnet110 --resume=resnet110-bottleneck-cifar10.pth.tar --dataset=cifar10 --ngroups=10 --gpu_num=10
we get this error:
RuntimeError: cuda runtime error (711) : peer mapping resources exhausted at /opt/conda/conda-bld/pytorch_1573049304260/work/aten/src/THC/THCGeneral.cpp:157
The error seems to occur for any model on the CIFAR-10 dataset whenever 10 groups and 10 or more GPUs are used.
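CUDA runtime error 711 appears to be cudaErrorTooManyPeers, which is raised when the hardware resources for peer access are exhausted; the CUDA programming guide states that, on systems without NVSwitch, each device supports at most eight peer connections. If the code enables peer access between every pair of GPUs, 10 GPUs would need nine peers per device, which would be consistent with the failure starting at 10. One possible workaround, not something the repository provides, is to keep each process under that limit by restricting its visible devices. The helpers below (visible_device_chunks, env_for_chunk) are hypothetical, the cap of 8 GPUs per process is a conservative assumption, and whether group_selection.py can be split across processes this way is not documented.

```python
import os


def visible_device_chunks(gpu_ids, max_gpus=8):
    """Split GPU ids into chunks of at most max_gpus (a conservative assumed cap)."""
    return [gpu_ids[i:i + max_gpus] for i in range(0, len(gpu_ids), max_gpus)]


def env_for_chunk(chunk, base_env=None):
    """Return a copy of the environment that exposes only the GPUs in `chunk`."""
    env = dict(os.environ if base_env is None else base_env)
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in chunk)
    return env


chunks = visible_device_chunks(list(range(10)))
print(chunks)  # [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9]]
print(env_for_chunk(chunks[1], base_env={})["CUDA_VISIBLE_DEVICES"])  # 8,9
```

Each returned environment could be passed as the env argument of subprocess.run so that a child process only sees its own chunk of GPUs.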

README Enhancements

🚀 Feature

Update the supported networks listed in the README. Introduce a checkpoint variable and use it in the README code snippets. Possibly change the default architecture for group_selection.py.

Motivation

The README states that the only supported architectures are vgg19_bn, resnet110, and resnet164. However, the help message for group_selection states that alexnet, densenet, preresnet, resnet110, resnet164, resnext, vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn, and wrn are all supported. The README should be updated to list the architectures that are actually supported.

The code snippet for the first step of the Experiment Instructions contains --resume=vgg19bn-cifar100.pth.tar, while the snippet for the sixth step contains --resume ./checkpoint_bearclaw.pth.tar. It is unclear whether vgg19bn-cifar100.pth.tar is supposed to be the same checkpoint as checkpoint_bearclaw.pth.tar. It would be clearer to define a variable CHECKPOINT (in the same way ARCH and DATASET are defined) and use $CHECKPOINT throughout the README.

Additionally, the default architecture for group_selection.py is resnet20, but the help message states that the default is resnet18. Moreover, neither resnet20 nor resnet18 is in the list of supported models returned by load_model.model_arches('cifar').

Pitch

Change line 15 of the README to:

Supported neural network architectures: ARCH = {alexnet, densenet, preresnet, resnet110, resnet164, resnext, vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn, wrn}

Change lines 22-24 of the README to:

python3 group_selection.py --arch=$ARCH --resume=$CHECKPOINT --dataset=$DATASET --ngroups=10

Change lines 40-42 of the README to:

python3 prune_and_get_model.py -a $ARCH --dataset $DATASET --resume $CHECKPOINT  -c ./prune_candidate_logs/ -s ./TO_SAVE_MODEL_BASE_DIR
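The default/help mismatch can be avoided by letting argparse generate those parts of the help text. The sketch below assumes a hypothetical three-item architecture list (the real one would come from load_model.model_arches('cifar')) and a hypothetical build_parser helper. It uses the standard %(choices)s and %(default)s help placeholders, and choices= makes argparse reject unsupported names such as resnet20.

```python
import argparse

# Hypothetical subset; the real list would come from load_model.model_arches('cifar').
CIFAR_ARCHES = ["resnet110", "resnet164", "vgg19_bn"]


def build_parser(arches=CIFAR_ARCHES):
    parser = argparse.ArgumentParser(description="class group selection (sketch)")
    parser.add_argument(
        "--arch", default=arches[0], choices=arches,
        # argparse fills in %(choices)s and %(default)s from the arguments above,
        # so the help text cannot drift from the real default.
        help="model architecture: %(choices)s (default: %(default)s)")
    return parser


parser = build_parser()
print(parser.parse_args([]).arch)  # resnet110
```

Note that argparse does not check that default itself appears in choices, so taking the default from the same list, as above, keeps the two consistent.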

Discrepancy in Help Message

🚀 Feature

I propose that we require the --resume argument in group_selection.py.

Motivation

Currently, group_selection.py does not require the --resume command-line argument, and provides an empty string as the default argument. The group_selection.py module calls the load_pretrained_models function in load_model.py, in which the statement assert os.path.isfile(resume_checkpoint), 'Error: no checkpoint found!' is executed. In brief, the program will exit early if a checkpoint is not provided.

Because the program will exit early if a checkpoint is not provided, I think that the --resume argument should be required.

Pitch

Replace lines 21-22 of group_selection.py with:

parser.add_argument('--resume', required=True, type=str, metavar='PATH',
                    help='path to latest checkpoint')
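With required=True, argparse itself rejects a missing --resume and exits with status 2 before any model code runs, instead of the program failing later inside load_pretrained_models. The sketch below reuses the proposed add_argument call inside a hypothetical build_parser and exit_code_for pair to show both outcomes.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser(prog="group_selection.py")
    # With required=True, argparse rejects a missing --resume up front.
    parser.add_argument('--resume', required=True, type=str, metavar='PATH',
                        help='path to latest checkpoint')
    return parser


def exit_code_for(argv):
    """Return argparse's exit status for argv, or None if parsing succeeds."""
    try:
        build_parser().parse_args(argv)
    except SystemExit as exc:
        return exc.code
    return None


print(exit_code_for([]))                            # 2 (usage error printed to stderr)
print(exit_code_for(["--resume", "ckpt.pth.tar"]))  # None
```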

Use of Assertions over Exceptions

🚀 Feature

Replace assertions with exceptions in load_pretrain_model and get_prune_candidates.

Motivation

A widespread convention is to use assertions to check internal correctness and exceptions to validate arguments. In line 42 of get_prune_candidates and line 19 of load_model, assert is used to validate an argument where an exception would be more appropriate.

From the Google Style guide:

Do not use assert statements for validating argument values of a public API. assert is used to ensure internal correctness, not to enforce correct usage nor to indicate that some unexpected event occurred. If an exception is desired in the latter cases, use a raise statement.

Pitch

Replace line 19 of load_model.py with:

if not os.path.isfile(resume_checkpoint):
    raise ValueError("Error: no checkpoint found!")

Replace line 42 of get_prune_candidates with:

if not args.grouped:
    raise ValueError("grouped must be non-empty.")
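A concrete reason behind the style-guide rule is that Python removes assert statements entirely when run with -O, so an assert-based argument check can silently disappear. The sketch below keeps the ValueError proposed above inside a hypothetical check_checkpoint function, and uses a child interpreter to show that `assert False` only fails when -O is not given.

```python
import os
import subprocess
import sys


def check_checkpoint(resume_checkpoint):
    """Validate a user-supplied checkpoint path with an exception, not an assert."""
    if not os.path.isfile(resume_checkpoint):
        raise ValueError("Error: no checkpoint found! (%r)" % resume_checkpoint)
    return resume_checkpoint


def assert_exit_status(optimize):
    """Run `assert False` in a child interpreter, optionally with -O."""
    flags = ["-O"] if optimize else []
    proc = subprocess.run([sys.executable] + flags + ["-c", "assert False"],
                          stderr=subprocess.DEVNULL)
    return proc.returncode


print(assert_exit_status(optimize=False))  # 1: AssertionError is raised
print(assert_exit_status(optimize=True))   # 0: -O strips the assert entirely
```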
