
landscapes

Exploring loss landscapes

Installation

Anaconda simplifies dependency management. To set up the environment, run:

conda env create -f torch-land.yml
conda activate torch-land
export PYTHONPATH=.

Scripts

All scripts are run from the project's root directory and print their usage when called with the -h flag.

Training

To train a model whose loss landscapes we want to investigate later:

python src/scripts/train.py resnet fashion-mnist

Computing loss landscapes

After training a model, compute its loss landscapes (losses over a 2-dimensional parameter subspace) with the grid.py script, e.g.:

python src/scripts/grid.py grid9 resnet fashion-mnist --grid_width=9
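Conceptually, each landscape holds f(α, β) = L(θ + α·δ + β·η), where θ are the trained parameters, δ and η are the two perturbation directions, and (α, β) range over a grid_width × grid_width grid. The following is a minimal NumPy sketch of that loop on a toy quadratic loss; the helper `loss_grid`, the coordinate range [-1, 1] and the toy loss are assumptions for illustration, not the implementation in grid.py.

```python
import numpy as np


def loss_grid(loss_fn, theta, delta, eta, grid_width=9, span=1.0):
    """Evaluate loss_fn(theta + a*delta + b*eta) on a grid_width x grid_width grid.

    Hypothetical helper: the coordinate range [-span, span] is an assumption.
    """
    coords = np.linspace(-span, span, grid_width)
    losses = np.empty((grid_width, grid_width))
    for i, a in enumerate(coords):
        for j, b in enumerate(coords):
            # Move within the 2-D plane spanned by delta and eta, centered at theta.
            losses[i, j] = loss_fn(theta + a * delta + b * eta)
    return coords, losses


# Toy example: a quadratic loss whose minimum sits at the "trained" parameters.
rng = np.random.default_rng(0)
theta = rng.normal(size=10)
delta, eta = rng.normal(size=10), rng.normal(size=10)


def toy_loss(params):
    return float(np.sum((params - theta) ** 2))


coords, losses = loss_grid(toy_loss, theta, delta, eta, grid_width=9)
print(losses.shape)  # (9, 9)
```

In the real setting, loss_fn would instead run the network with the perturbed parameters on a mini-batch; with an odd grid_width the center cell corresponds to the unperturbed model.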

Visualizing the results

Once the loss values have been computed, visualize the landscapes as 2D heatmaps by calling the visualize.py script with the same arguments:

python src/scripts/visualize.py grid9 resnet fashion-mnist --grid_width=9

Experiments

The commands to run the experiments are documented in the files experiments_run.sh and experiments_visualize.sh.

The landscapes are computed using a pair of random, filter-normalized vectors that perturb the model's parameters. Each loss corresponds to a single training step, i.e. it is computed on a single mini-batch.

We use three pairs of perturbation vectors and the first three mini-batches of the training set, each containing 256 images.
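Filter normalization rescales each filter of a random direction to the norm of the corresponding filter of the trained weights, so the perturbation has a comparable scale in every layer. The following is a minimal NumPy sketch of the idea, assuming parameters are given as a list of arrays with filters (or output neurons, for fully connected layers) along axis 0. The function name and the choice to give 1-D parameters such as biases a zero direction are assumptions, not necessarily what this repository does.

```python
import numpy as np


def filter_normalized_direction(params, rng):
    """Draw a random direction shaped like `params`, normalized filter by filter.

    params: list of NumPy arrays, one per parameter tensor; axis 0 is assumed to
    index filters (conv layers) or output neurons (fully connected layers).
    Each filter of the random direction is rescaled to the norm of the
    corresponding weight filter.
    Assumption: 1-D parameters (biases, normalization layers) get a zero direction.
    """
    direction = []
    for w in params:
        if w.ndim <= 1:
            direction.append(np.zeros_like(w))
            continue
        d = rng.normal(size=w.shape)
        # Frobenius norm of each filter: reduce over every axis except axis 0.
        axes = tuple(range(1, w.ndim))
        w_norm = np.sqrt(np.sum(w ** 2, axis=axes, keepdims=True))
        d_norm = np.sqrt(np.sum(d ** 2, axis=axes, keepdims=True))
        direction.append(d / (d_norm + 1e-10) * w_norm)
    return direction


rng = np.random.default_rng(0)
conv_w = rng.normal(size=(16, 3, 3, 3))  # 16 filters of a 3x3 conv layer on RGB input
fc_w = rng.normal(size=(10, 64))         # fully connected layer with 10 outputs
bias = rng.normal(size=10)
params = [conv_w, fc_w, bias]

# Two independent directions span the 2-D plane in which the loss is evaluated.
delta = filter_normalized_direction(params, rng)
eta = filter_normalized_direction(params, rng)
```

Drawing a fresh pair (delta, eta) per experiment would correspond to the three pairs of perturbation vectors mentioned above.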

For visualization, we can use either heatmaps or contour plots (the latter via the --contour flag of visualize.py).

Some results

Training progress: ResNet on CIFAR-10 before training and after the first and ninth epochs

Perturbing only a single layer or even a single conv filter

First (convolutional) layer:

Last (fully connected) layer:

First filter in first layer:
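Perturbing only part of the network amounts to zeroing the perturbation directions everywhere except in the selected layer or filter, so moving across the grid changes only those parameters. The following is a minimal NumPy sketch, assuming each direction is stored as a list of arrays (one per parameter tensor, in network order) with filters along axis 0. `restrict_direction` is a hypothetical helper, not part of this repository.

```python
import numpy as np


def restrict_direction(direction, layer, filter_index=None):
    """Return a copy of `direction` that is zero outside the chosen layer/filter.

    direction: list of NumPy arrays, one per parameter tensor (assumed layout).
    layer: index of the parameter tensor to keep (e.g. 0 = first, -1 = last).
    filter_index: if given, keep only this filter (slice along axis 0).
    """
    restricted = [np.zeros_like(d) for d in direction]
    if filter_index is None:
        restricted[layer] = direction[layer].copy()
    else:
        restricted[layer][filter_index] = direction[layer][filter_index]
    return restricted


rng = np.random.default_rng(0)
# Toy direction: a conv layer with 16 filters followed by a fully connected layer.
direction = [rng.normal(size=(16, 3, 3, 3)), rng.normal(size=(10, 64))]

first_layer_only = restrict_direction(direction, layer=0)
last_layer_only = restrict_direction(direction, layer=-1)
first_filter_only = restrict_direction(direction, layer=0, filter_index=0)
```

Both directions of a pair would be restricted the same way before evaluating the loss grid.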

Different activation functions before and after overfitting: ReLU, sigmoid and tanh

ReLU

sigmoid

tanh

Training progress on Fashion-MNIST

Network architecture: ResNet14 vs. VGG11

before overfitting (after 1 epoch)

with overfitting (after 9 epochs)

Zoomed in 100x (the VGG landscape shows only noise)

Contributors

l-berg
