mattkleinsmith / pbt


Population Based Training (in PyTorch with sqlite3). Status: Unsupported

License: MIT License

Language: Python (100.00%)

Topics: pbt, hyperparameters, population-based-training, deep-learning, hyperparameter-optimization, hyperparameter-tuning, hyperparameter-search, deepmind

pbt's Introduction

PBT: Population Based Training

Population Based Training of Neural Networks, Jaderberg et al. @ DeepMind

What this code is

A PyTorch implementation of PBT.

What this code is for

Finding a good hyperparameter schedule.

How to use this code

Warning: This implementation isn't user-friendly yet. If you have any questions, create a GitHub issue and I'll try to help you.

Steps:

  1. Wrestle with dependencies.
  2. Edit config.py to set your options.
  3. Store your data as bcolz carrays. See datasets.py for an example, and the sketch below these steps.
  4. In a terminal, enter: python main.py --exploiter
  5. If you want to use a second GPU, then in a second terminal, enter: python main.py --gpu 1 --population_id -1, where "1" refers to your GPU's ID in nvidia-smi, and "-1" means to work on the most recently created population.
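
A minimal sketch of step 3, with made-up paths and shapes; the save_carray helper is hypothetical, and datasets.py remains the authoritative example of the layout this code expects:

    import bcolz
    import numpy as np

    def save_carray(array, path):
        # Write a numpy array to disk as a bcolz carray, then flush it.
        carray = bcolz.carray(array, rootdir=path, mode='w')
        carray.flush()

    # Stand-in data with MNIST-like shapes:
    x_train = np.random.rand(60000, 1, 28, 28).astype(np.float32)
    y_train = np.random.randint(0, 10, size=60000)

    save_carray(x_train, 'data/x_train.bc')
    save_carray(y_train, 'data/y_train.bc')

    # Carrays open lazily, without loading everything into RAM:
    x_train = bcolz.open('data/x_train.bc', mode='r')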

When finished, the process will print the path to the weights of the best performing model.

Figures for intuition

[Figures: validation-accuracy curves and hyperparameter scatter plots, comparing PBT with random search]

These figures are for building intuition about PBT; they aren't the results of a rigorous experiment. In the accuracy plots, the best model is shown in purple. In the hyperparameter scatter plots, the dots grow as the models train, and the hyperparameter configurations of the best model from each population are marked with purple stars.

Notice how the hyperparameter configurations evolve in PBT, but stay the same in random search.

How does PBT work?

PBT trains each model partially and assesses it on the validation set. It then transfers the parameters and hyperparameters from the top-performing models to the bottom-performing models (exploitation). After transferring the hyperparameters, PBT perturbs them (exploration). Each model is then trained some more, and the process repeats. This allows PBT to learn a hyperparameter schedule instead of only a fixed hyperparameter configuration. PBT can be used with different selection methods, i.e. different ways of defining "top" and "bottom" (e.g. the top 5 models, the top 5%, etc.).
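
In code terms, one round of PBT looks roughly like the sketch below. This is schematic, not this repo's implementation: train_briefly, evaluate, select_pairs, and perturb are hypothetical helpers named only for illustration (the last two are sketched in later sections).

    import copy

    def pbt_round(population):
        for member in population:
            train_briefly(member)            # partial training
            member.score = evaluate(member)  # assess on the validation set

        # Exploitation: worse models copy the parameters and hyperparameters
        # of better models, as chosen by the selection method.
        for worse, better in select_pairs(population):
            worse.weights = copy.deepcopy(better.weights)
            worse.hyperparams = copy.deepcopy(better.hyperparams)
            # Exploration: perturb the copied hyperparameters.
            worse.hyperparams = perturb(worse.hyperparams)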

For more information, see the paper or blog post.

Selection method in this code

Truncation selection: for each model in the bottom 20% by performance, sample a model from the top 20% and transfer its parameters and hyperparameters to the worse model; the top 80% are left unchanged. One can think of the models in the bottom 20% as being truncated at each exploitation step. This selection method was used in the paper.
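
As a sketch, truncation selection could serve as the select_pairs step in the earlier sketch. This is illustrative, not the repo's code; it assumes each population member exposes a validation score:

    import random

    def truncation_select(population, fraction=0.2):
        # Pair each bottom-20% model with a model sampled from the
        # top 20%; the remaining 80% are left unchanged.
        ranked = sorted(population, key=lambda m: m.score, reverse=True)
        cutoff = max(1, int(len(ranked) * fraction))
        top, bottom = ranked[:cutoff], ranked[-cutoff:]
        return [(worse, random.choice(top)) for worse in bottom]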

About the figures

The figures above were produced with a naive PBT selection method: always select the single best model. With the selection method used in the paper, truncation selection, accuracy improves to around 99.35%. Results vary with the random seed. The model was a simple conv net, and the dataset was MNIST.

I produced these figures using an old and very different version of this repo. I haven't yet re-added logging and plotting.

The essence of this code

Managing tasks via a sqlite database table.

Each task corresponds to training a model for half an epoch. These tasks can be done in parallel. Once in a while the exploiter process truncates the worst-performing models, which blocks other processes from training models for a bit. That makes it 99% parallel instead of 100% parallel like random search.
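
Concretely, the idea is a tasks table that worker processes claim rows from inside transactions. A minimal sketch with a hypothetical schema (the repo's actual table layout differs; see the code):

    import sqlite3

    conn = sqlite3.connect('pbt.db')
    conn.execute("""
        CREATE TABLE IF NOT EXISTS tasks (
            task_id       INTEGER PRIMARY KEY,
            population_id INTEGER,
            model_id      INTEGER,
            score         REAL,
            status        TEXT DEFAULT 'ready'  -- ready | running | done
        )""")

    def claim_task(conn):
        # Claim one ready task inside a transaction, so parallel workers
        # (e.g. one per GPU) never grab the same task.
        with conn:
            row = conn.execute(
                "SELECT task_id FROM tasks WHERE status = 'ready' LIMIT 1"
            ).fetchone()
            if row is None:
                return None
            conn.execute("UPDATE tasks SET status = 'running' WHERE task_id = ?", row)
            return row[0]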

Since this code is mostly about task management, it isn't strongly tied to a particular deep learning framework. With a little work, one could replace the PyTorch ties with TensorFlow. However, this assumes you have your hyperparameter space defined in your framework of choice, which you need for any hyperparameter optimization algorithm, including random search. As of this writing, the hyperparameter space in this code has only two dimensions: learning rate and momentum coefficient.
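
For example, with that two-dimensional space, the exploration step could perturb each hyperparameter by a random factor of 0.8 or 1.2, as in the paper. A hedged sketch (a real implementation might also clamp momentum below 1):

    import random

    def perturb(hyperparams):
        # Multiply each hyperparameter by 0.8 or 1.2, chosen at random.
        return {name: value * random.choice([0.8, 1.2])
                for name, value in hyperparams.items()}

    perturb({'lr': 0.1, 'momentum': 0.9})
    # e.g. -> {'lr': 0.12, 'momentum': 0.72}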

Acknowledgements

This repo is inspired by bkj's pbt repo, where they replicated figure 2 of the paper.


pbt's Issues

Help: ModuleNotFoundError: No module named 'mkl'

Hi, I am a student. When I run your pbt, I run into this problem:

    Traceback (most recent call last):
      File "main.py", line 2, in <module>
        import mkl
    ModuleNotFoundError: No module named 'mkl'

But when I use the command conda list | grep mkl, it shows that the conda environment does contain the 'mkl' package. Why does this happen? Thanks.

Error in line 86 of Trainer.py

Executing this statement at line 86 of Trainer.py:

    with torch.no_grad():

raises this error:

    AttributeError: 'module' object has no attribute 'no_grad'

How can I solve this? Thank you.

Network architecture hyperparameters

Hello,

The technique of PBT seems very interesting to me. However, I was wondering whether it is possible to vary hyperparameters such as the number of layers or the number of hidden units. Those would change the model's dimensions, and I am not sure how the copying of parameters would work then.

Thank you in advance!

Help: out of memory

Hi, I am a student. When I run your pbt, I run into an out-of-memory problem:

[screenshot of the out-of-memory error]

I changed the batch_size from 64 to 2 but hit the problem again, so do you have another way to solve it? Please tell me; I really want to get pbt working. Thanks.

Dependencies

FWIW, in order to get the code to run in a newly created conda env, I needed to install these dependencies:

  • torchvision
  • mkl-service
  • psycopg2
  • tqdm

and, in order to create the datasets:

  • bcolz

Might be worth listing these in the README.
