
CS439 project - Gradient Compression over SGD and Adam: A Survey

Code for the experimental parts of the course project in CS-439: Optimization for Machine Learning.

The implementation is based on this repository's code and uses PyTorch.

Requirements

The following packages were used for the experiments. Newer versions are also likely to work.

  • torchvision==0.2.1
  • numpy==1.15.4
  • torch==0.4.1
  • pandas==0.23.4
  • scikit_learn==0.20.3

To install them automatically: pip install -r requirements.txt

Organization

  • optimizers/ contains the custom optimizers, namely CompSGD, ErrorFeedbackSGD and OneBitAdam.
  • models/ contains the deep net architectures. Only ResNets were used in the experiments.
  • results/ contains the results of the experiments in pickle files.
  • utils/ contains utility functions for saving/loading objects, convex optimization, the progress bar, etc.
  • checkpoints/ contains the saved model checkpoints with all the network parameters. The folder is empty here as those files are very large.

Notations

We clarify the notations here. In particular,

  • ssgd: SGD with sign gradient compression.
  • sgd_topk: SGD with top-k gradient compression.
  • sgd_pcak: SGD with k-PCA gradient compression.
  • sssgd: SGD with scaled sign gradient compression.
  • ussgd: Unscaled SignSGD (MEM-SGD), i.e., SGD with sign gradient compression and error feedback.
  • ssgdf: Error-feedback SignSGD, i.e., SGD with scaled sign gradient compression and error feedback.
  • onebit_adam_unscaled: the original version of one-bit Adam.
  • onebit_adam_scaled: the scaled version of one-bit Adam.
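To make the compression schemes above concrete, here is a minimal sketch (not the repository's actual implementation, which lives in optimizers/) of sign, scaled-sign, and top-k compression, plus one SGD step with error feedback in the style of ussgd/ssgdf. The function names are illustrative, not identifiers from the code base.

```python
import torch

def sign_compress(g, scaled=True):
    # Sign compression: keep only the sign of each coordinate (1 bit each).
    # The scaled variant rescales by the mean absolute value, so the
    # compressed vector has the same l1 norm as the original gradient.
    sign = torch.sign(g)
    return g.abs().mean() * sign if scaled else sign

def topk_compress(g, k):
    # Top-k compression: keep the k largest-magnitude coordinates and
    # zero out the rest (assumes a flattened gradient vector).
    out = torch.zeros_like(g)
    idx = g.abs().topk(k).indices
    out[idx] = g[idx]
    return out

def ef_sgd_step(param, grad, memory, lr):
    # One SGD step with compression and error feedback: the part of the
    # update lost to compression is stored in `memory` and re-injected
    # at the next step instead of being discarded.
    corrected = lr * grad + memory
    update = sign_compress(corrected)
    new_memory = corrected - update
    return param - update, new_memory
```

The error-feedback invariant is that the applied update plus the new memory always equals the uncompressed corrected step, so no gradient information is permanently lost.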

Usage

  • run.ipynb has three parts: tuning the learning rates, running the experiments, and plotting the figures that appear in the report.
  • main.py can be called from the command line to run a single network training and testing. It takes a variety of optional arguments; type python main.py --help for further details.
  • utils/hyperparameters.py facilitates the definition of all the hyper-parameters of the experiments.
  • tune_lr.py tunes the learning rate for a given network architecture/data set/optimizer configuration.
  • main_experiments.py contains the experiments in the report.
  • plot_graph.py contains the code for plotting the results.
  • print_stats.py contains the code to list the best performance of each experiment done by tune_lr.py.

cs439_project's People

Contributors

robertflame, doub7e, yljblues
