
anml's Introduction

ANML: Learning to Continually Learn (ECAI 2020)

arXiv Link

Continual lifelong learning requires an agent or model to learn many sequentially ordered tasks, building on previous knowledge without catastrophically forgetting it. Much work has gone towards preventing the default tendency of machine learning models to catastrophically forget, yet virtually all such work involves manually designed solutions to the problem. We instead advocate meta-learning a solution to catastrophic forgetting, allowing AI to learn to continually learn. Inspired by neuromodulatory processes in the brain, we propose A Neuromodulated Meta-Learning Algorithm (ANML). It differentiates through a sequential learning process to meta-learn an activation-gating function that enables context-dependent selective activation within a deep neural network. Specifically, a neuromodulatory (NM) neural network gates the forward pass of another (otherwise normal) neural network called the prediction learning network (PLN). The NM network thus also indirectly controls selective plasticity (i.e., the backward pass) of the PLN. ANML enables continual learning without catastrophic forgetting at scale: it produces state-of-the-art continual learning performance, sequentially learning as many as 600 classes (over 9,000 SGD updates).
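The gating mechanism described above can be sketched in a few lines of PyTorch. This is an illustrative toy, not the repo's actual architecture: the class name, layer sizes, and the choice to gate the penultimate activations are all assumptions.

```python
import torch
import torch.nn as nn

class GatedPLN(nn.Module):
    """Illustrative sketch of ANML-style neuromodulation: an NM network
    produces a per-feature gate in (0, 1) that multiplies the PLN's
    penultimate activations. Names and sizes are hypothetical."""

    def __init__(self, in_dim=784, hidden=256, n_classes=10):
        super().__init__()
        # Prediction learning network (PLN): an ordinary feed-forward net.
        self.pln_body = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.pln_head = nn.Linear(hidden, n_classes)
        # Neuromodulatory (NM) network: sees the same input, emits a gate.
        self.nm = nn.Sequential(nn.Linear(in_dim, hidden), nn.Sigmoid())

    def forward(self, x):
        gate = self.nm(x)            # context-dependent gate in (0, 1)
        h = self.pln_body(x) * gate  # gates the forward pass...
        return self.pln_head(h)      # ...and hence gradients into pln_body

model = GatedPLN()
out = model(torch.randn(4, 784))
print(out.shape)  # torch.Size([4, 10])
```

Because the gate multiplies the PLN's activations, it scales the gradients flowing back through them as well, which is how the NM network indirectly controls which parts of the PLN are plastic.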

How to Run

First, install Anaconda for Python 3 on your machine.

Next, install PyTorch and TensorboardX:

pip install torch
pip install tensorboardX

Then clone the repository:

git clone https://github.com/shawnbeaulieu/ANML.git

Meta-train your network(s). To modify the network architecture, see modelfactory.py in the model folder. Depending on the architecture you choose, you may have to change how the data is loaded and/or preprocessed. See omniglot.py and task_sampler.py in the datasets folder.

python mrcl_classification.py --rln 7 --meta_lr 0.001 --update_lr 0.1 --name mrcl_omniglot --steps 20000 --seed 9 --model_name "Neuromodulation_Model.net"

Evaluate your trained model. The --rln flag specifies which layers to fix (freeze) during the meta-test training phase. For example, to leave no layers fixed, run:

python evaluate_classification.py --rln 0  --model Neuromodulation_Model.net --name Omni_test_traj --runs 10

Prerequisites

  • Python 3
  • PyTorch 1.4.0
  • Tensorboard

anml's People

Contributors

ncheney, shawnbeaulieu


anml's Issues

First-order gradients are used for outer-loop optimization

Hi, there was a bug in the original OML code where the computation graph was not created for the inner-loop updates. Looking at the ANML code, it seems to have the same issue; specifically, here it is currently

grad = torch.autograd.grad(loss, fast_weights, allow_unused=False)

but to correctly backpropagate through the inner optimization it should be

grad = torch.autograd.grad(loss, fast_weights, allow_unused=False, create_graph=True)
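The difference can be shown with a toy second-order example (this is illustrative, not the repo's code): with create_graph=True the inner-loop gradient is itself part of the graph, so the outer loss can backpropagate through the inner SGD step.

```python
import torch

# Toy inner/outer loop: one differentiable SGD step on a scalar weight.
w = torch.tensor(2.0, requires_grad=True)
x = torch.tensor(3.0)

inner_loss = (w * x) ** 2
# create_graph=True keeps g differentiable w.r.t. w (g = 2*w*x^2 = 36 here).
(g,) = torch.autograd.grad(inner_loss, w, create_graph=True)
fast_w = w - 0.1 * g                 # differentiable "fast weights" update
outer_loss = (fast_w * x) ** 2
outer_loss.backward()                # flows through g thanks to create_graph

print(w.grad.item())  # 23.04; first-order (detached g) would give -28.8
```

Without create_graph=True, g is treated as a constant and the second-order term (the derivative of g with respect to w) is silently dropped, which is exactly the first-order approximation described in this issue.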

I was wondering which version was used for the results in the paper; OML's author said fixing this bug improved performance and reduced training time.
Thanks

P.S.: congrats on the work, it's really cool.

hyperparameters of OML results

Hello :)
First of all, thanks for your hard work.
It has given me new insight into continual learning problems in meta-learning.

By the way, can you explain the hyperparameters used for training OML in your code? I am trying to reproduce OML in your repository, but I am having a hard time finding the parameters that match the results in your paper. It would be really helpful if you could give the exact script, like the one you've written for training the ANML model.

Thanks again 😀

ANML in image segmentation task

Can the ANML training regime be used for an image segmentation task? Or is there another meta-learning method like ANML suited to image segmentation?

Output layer size in the meta-training and meta-testing phases

My understanding is that you use a single fully connected layer on top of the neuromodulated representations.

  1. Does this output layer have 963 nodes during meta-training, since you are performing 963-class classification during meta-training?
  2. If yes, how do you use the meta-learned output layer to learn/perform 600-class classification at meta-test time? Or do you randomly initialize a new output layer with 600 nodes?
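The second option in the question can be sketched as follows. This is only an illustration of the "fresh head" pattern common in meta-learning, not a claim about what the paper actually did; all names and sizes here are hypothetical.

```python
import torch
import torch.nn as nn

feature_dim, n_meta_train, n_meta_test = 256, 963, 600

# Stand-in for the meta-learned body whose weights are kept at meta-test time.
body = nn.Linear(784, feature_dim)
train_head = nn.Linear(feature_dim, n_meta_train)  # 963-way meta-training head

# At meta-test time, discard the old head and attach a freshly initialized
# one sized to the meta-test class count; only the head changes.
test_head = nn.Linear(feature_dim, n_meta_test)
logits = test_head(torch.relu(body(torch.randn(2, 784))))
print(logits.shape)  # torch.Size([2, 600])
```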
