sahagobinda / gpm

Official [ICLR] Code Repository for "Gradient Projection Memory for Continual Learning"

License: MIT License

Python 99.81% Shell 0.19%
computer-vision continual-learning deep-learning optimization pytorch

gpm's People

Contributors

sahagobinda

gpm's Issues

How to understand the difference between GPM and the full projection matrix

From my understanding, if X = U\Sigma V^\top, then P = X(X^\top X)^{-1}X^\top = UU^\top, so the main difference is that GPM uses the truncated (U)_k rather than the full U, and that GPM does not compute P but directly stores (U)_k in memory?
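A quick numerical check of that equivalence (a minimal sketch, not from the repository; it assumes X has full column rank, and the variable names are illustrative):

import torch

torch.manual_seed(0)
X = torch.randn(64, 10)  # 64-dimensional activations from 10 samples (full column rank)

# Projection onto the column space of X, computed two ways
P_direct = X @ torch.linalg.inv(X.T @ X) @ X.T        # X (X^T X)^{-1} X^T
U, S, Vh = torch.linalg.svd(X, full_matrices=False)
P_svd = U @ U.T                                        # U U^T over the non-zero singular vectors
print(torch.allclose(P_direct, P_svd, atol=1e-5))      # True: both give the same projector

# GPM keeps only the top-k left singular vectors (U)_k chosen by an energy
# threshold, and stores (U)_k itself rather than the d x d projector P.
k = 5
U_k = U[:, :k]
P_k = U_k @ U_k.T                                      # approximate projector actually used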

Dataloader for miniImagenet

Hi, thanks for sharing the code for your project. Could you please also share the dataloader/preprocessing files for the miniImageNet dataset?

Could you please provide the code for miniimagenet?

Hi, thanks for your good work.
When I try to reproduce the miniImageNet results by adapting the five-datasets code, I cannot match the numbers reported in your paper.
Could you please provide the code for miniImageNet?

About the code

Hi! I have a question about the function get_representation_matrix. Why was the batch_list set to [2 * 12, 100, 100, 125, 125]? And can I set it to another batch_list? Thanks!

About batch normalization

Hi, thank you for your interesting and promising work. I have a question about the implementation of batch normalization. I wonder why in the code of all models, "track_running_stats" is set as "False" for all BN modules. This means that this module will always use the mean and variance of the current batch data during feedforward inference (both for training and testing). Then the testing of the model will also be influenced by the input batch size and the sequence of testing batches, which may lead to performance fluctuation under different test settings (for example, batch statistics will be inaccurate when batch size is 1 and BN is notorious for performance drop with small batch size). The common practice for BN is to track running mean and variance of the training data and BN is set to the "eval" mode during testing so that the statistics will use the tracked mean and variance, and the testing results will be consistent.

In the paper, it is said that "batch normalization parameters are learned for the first task and shared with all the other tasks (following Mallya & Lazebnik (2018))". However, I checked the code of PackNet (https://github.com/arunmallya/packnet), they do not set "track_running_stats" as "False", but they track the statistics of the first task and set BN to the "eval" mode for the remaining tasks so that statistics are fixed. So they follow the common practice of BN and the testing results will be consistent. I wonder if there is any additional consideration for setting "track_running_stats=False" in this implementation, and how would the model with this setting be influenced by different testing settings, e.g. with different test batch sizes?
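For reference, a minimal sketch (not the repository's code) contrasting the two BN behaviours discussed above; the shapes and names are illustrative:

import torch
import torch.nn as nn

x_train = torch.randn(128, 64, 8, 8)
x_test  = torch.randn(1, 64, 8, 8)   # a single test image

# Setting used in this repo: no running statistics are tracked, so even after
# .eval() BN normalizes every input with the statistics of the current batch.
bn_batch = nn.BatchNorm2d(64, track_running_stats=False)
bn_batch(x_train)
bn_batch.eval()
y1 = bn_batch(x_test)                # depends on the test batch's own statistics

# Common practice (and effectively what PackNet does after the first task):
# track running mean/var during training, then freeze them with .eval().
bn_running = nn.BatchNorm2d(64, track_running_stats=True)
bn_running.train()
bn_running(x_train)                  # updates running_mean / running_var
bn_running.eval()
y2 = bn_running(x_test)              # uses the frozen running statistics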

What is the meaning of 15 in cifar100?

What does the condition k < 15 mean, and why are some parameters' gradients filled with zeros?

GPM/main_cifar100.py

Lines 150 to 157 in 1a238ec

for k, (m, params) in enumerate(model.named_parameters()):
    if k < 15 and len(params.size()) != 1:
        # Weight tensors of the shared layers: project the gradient away from the
        # subspace stored in GPM (feature_mat[kk] is the projector for this layer).
        sz = params.grad.data.size(0)
        params.grad.data = params.grad.data - torch.mm(
            params.grad.data.view(sz, -1), feature_mat[kk]).view(params.size())
        kk += 1
    elif (k < 15 and len(params.size()) == 1) and task_id != 0:
        # 1-D parameters (the batch-norm scale/shift terms): frozen after the first task.
        params.grad.data.fill_(0)
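For reference, a hedged sketch of what the projection step above does for a single layer. It assumes feature_mat[kk] equals M M^T, the projector built from the basis M stored in the Gradient Projection Memory for that layer (the helper name and shapes below are illustrative, not the repository's):

import torch

def project_gradient(grad, basis):
    # Remove the component of `grad` that lies in the subspace spanned by the
    # columns of `basis` (the stored GPM directions for this layer).
    sz = grad.size(0)
    feature_mat = basis @ basis.T                  # M M^T, projector onto the memory
    g = grad.view(sz, -1)
    return (g - g @ feature_mat).view(grad.size())

# Illustrative use: a conv weight gradient of shape (64, 32, 3, 3) and a stored
# basis of 20 directions in the 32*3*3 = 288-dimensional input space.
grad = torch.randn(64, 32, 3, 3)
Q, _ = torch.linalg.qr(torch.randn(288, 20))       # orthonormal columns as a stand-in basis
projected = project_gradient(grad, Q)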

Downloads the wrong dataset for FashionMNIST

Hi, thanks for your good work. I found that the code in five_datasets.py downloads the wrong dataset: it actually downloads MNIST instead of FashionMNIST. With the correct FashionMNIST data, the total accuracy over the 5 datasets drops by 2-3% in my experiments. Please check the code and look into this issue.
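If it helps, the fix would typically be switching the dataset class in the loader (a hedged sketch; the actual arguments used in five_datasets.py may differ):

from torchvision import datasets, transforms

# Wrong: this downloads MNIST even though the split is meant to be Fashion-MNIST
# train_set = datasets.MNIST('./data/fashion_mnist', train=True, download=True,
#                            transform=transforms.ToTensor())

# Correct: use the FashionMNIST dataset class from torchvision
train_set = datasets.FashionMNIST('./data/fashion_mnist', train=True, download=True,
                                  transform=transforms.ToTensor())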
