
gmvandeven / continual-learning

1.5K stars · 29 watchers · 304 forks · 3.21 MB

PyTorch implementation of various methods for continual learning (XdG, EWC, SI, LwF, FROMP, DGR, BI-R, ER, A-GEM, iCaRL, Generative Classifier) in three different scenarios.

License: MIT License

Python 36.13% Shell 0.57% Jupyter Notebook 63.30%
deep-learning artificial-neural-networks continual-learning lifelong-learning incremental-learning replay distillation generative-models variational-autoencoder elastic-weight-consolidation


continual-learning's Issues

Task-IL evaluation

I recently read your paper 'Three types of incremental learning', and I'm very grateful for the meticulous breakdown. In Tables 2 and 3, you present evaluations of the same method under both the task-incremental and class-incremental learning scenarios. Is the only difference between the two scenarios whether 'context' information is given during training and evaluation?

The class-IL evaluations are relatively easy to understand, since there are many papers on class-IL methods; I assume accuracy is computed across all labels. For the task-IL evaluations, I assume accuracy is computed by restricting the prediction space to the labels of the given context, and then averaging over all contexts. Is this how you did it, or did you use a different approach?

If there are any references or other papers you used for the task-IL accuracy measurement, I would appreciate it if you could share them with me.
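For reference, a minimal sketch of the masking-based task-IL evaluation described above (generic PyTorch with illustrative helper names, not this repository's code):

import torch

def task_il_accuracy(model, data_loader, active_classes, device="cpu"):
    # Restrict predictions to the classes of the evaluated context,
    # e.g. active_classes=[0, 1] for the first split-MNIST context.
    model.eval()
    correct, total = 0, 0
    classes = torch.tensor(active_classes, device=device)
    with torch.no_grad():
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            scores = model(x)[:, classes]          # keep only this context's output units
            pred = classes[scores.argmax(dim=1)]   # map back to global class labels
            correct += (pred == y).sum().item()
            total += y.size(0)
    return correct / total

Averaging this quantity over all contexts then gives a task-IL score of the kind described above.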

Grad in SI

Hi,

I have recently been reading your excellent continual-learning implementation, in particular the SI part. In the following line of code, you use p.grad, which is the gradient of the regularized loss. However, based on my understanding of SI, the gradient should be computed on the data loss alone, so that it measures how much each weight contributes to the fitting error of the current task. Am I wrong about this, or did I miss something important in your implementation? Thanks in advance for your clarification.

W[n].add_(-p.grad*(p.detach()-p_old[n]))
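For illustration, a minimal sketch of how the path integral could be accumulated using only the gradient of the unregularized data loss (the names W, p_old and the stand-in regularizer are made up for this example, not the repository's exact variables):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
W = {n: torch.zeros_like(p) for n, p in model.named_parameters()}      # running path integral
p_old = {n: p.detach().clone() for n, p in model.named_parameters()}   # parameters before the step

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
data_loss = nn.functional.cross_entropy(model(x), y)
reg_loss = 0.01 * sum(p.pow(2).sum() for p in model.parameters())      # stand-in surrogate/regularization loss

optimizer.zero_grad()
data_loss.backward()
data_grads = {n: p.grad.detach().clone() for n, p in model.named_parameters()}
reg_loss.backward()              # .grad now holds data + regularizer gradients for the update
optimizer.step()

for n, p in model.named_parameters():
    W[n].add_(-data_grads[n] * (p.detach() - p_old[n]))   # path integral with the data-loss gradient only
    p_old[n] = p.detach().clone()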

Thank you for your wonderful package!

Hi Gido M. van de Ven,

Thank you for your well-written package. It is a great contribution to the community. Thanks!

Can you kindly confirm whether you were able to match the baseline performances reported by the authors in the corresponding papers?

Thanks,
Joseph

Question about Online EWC

Hi,
and thank you for your work.

I have a doubt about the implementation of Online EWC. Specifically, I refer to the following line of code.

fisher = self.gamma*fisher if self.online else fisher

To the best of my understanding, gamma is a decay applied to the prior Fisher matrices when the running estimate is updated, but it shouldn't affect how the regularization loss itself is computed (Eq. (7) in Sec. A.2.2).
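For context, the usual online-EWC consolidation and penalty could be sketched roughly as follows (generic code with illustrative names, not necessarily how this repository organizes it): gamma decays the previous running estimate before the new context's estimate is added.

def consolidate_fisher(running_fisher, new_fisher, gamma):
    # Online EWC: decay the previous running estimate by gamma, then add the
    # Fisher estimated on the context that was just finished.
    consolidated = {}
    for n, f_new in new_fisher.items():
        f_prev = running_fisher.get(n)
        consolidated[n] = f_new if f_prev is None else gamma * f_prev + f_new
    return consolidated

def ewc_penalty(model, running_fisher, star_params, lamda):
    # Quadratic penalty using the consolidated Fisher estimate.
    loss = 0.0
    for n, p in model.named_parameters():
        if n in running_fisher:
            loss = loss + (running_fisher[n] * (p - star_params[n]) ** 2).sum()
    return 0.5 * lamda * loss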

Could you please provide intuition on this matter?
Best,
D

Suspicious Precision

Why is precision 0 for tasks 1-4?

Precision on test-set:

  • Task 1: 0.0000
  • Task 2: 0.0000
  • Task 3: 0.0000
  • Task 4: 0.0000
  • Task 5: 0.9945
    => Average precision over all 5 tasks: 0.1989

Final results of class incremental learning

Hello! Sorry, but I am confused that the code reports the test result on the final task as the final result of class-incremental learning. In my opinion, the class-incremental result after each task should be the average over all previous tasks and the current task. I would appreciate any suggestions, thank you!
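For clarity, this is the metric the question describes (with made-up numbers, purely illustrative):

# Hypothetical per-context test accuracies after finishing training on context 3:
per_context_acc = [0.97, 0.94, 0.91]
class_il_result = sum(per_context_acc) / len(per_context_acc)
print(f"Average over contexts seen so far: {class_il_result:.4f}")   # 0.9400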

iCaRL's accuracy different from 'Three scenarios for continual learning'

I ran the code in the repo with just one change: in _compare.py, I changed 'scenario' to 'class'.

The results are as follows:

############################################################
SUMMARY RESULTS: splitMNIST  -  class-incremental learning
------------------------------------------------------------
None         19.94
EWC          19.93
o-EWC        19.91
SI           19.95
LwF          24.00
DGR          91.69
DGR+distil   90.42
iCaRL        72.83
Offline      97.59
############################################################

I tried two seeds (1 and 99); both gave ~72 for iCaRL. The accuracy reported in the paper is ~94.

Am I missing something that is leading to this discrepancy?

Wrong dataset?

train_datasets[task_id], batch_size_to_use, cuda=cuda, drop_last=True

Hello! My team and I are currently running experiments with your BI-R repo. We are trying to find the limit between ER and GR approaches and see how much we can degrade exemplars (stored or generated) while still achieving similar accuracies. As such, we ported your implementation of ER to the other repo. We noticed that on line 123 of train.py, the dataset train_datasets is used rather than previous_datasets. Is train_datasets not the full-scale dataset? We would love to have your input on this potential issue, and if you have ablations or directions for experiments to run, please do tell!
Btw, your code is absolutely beautiful and so well documented!
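For concreteness, a rough sketch of the distinction being raised (generic PyTorch, not code from either repository): an experience-replay batch would normally be drawn from the previously seen contexts rather than from the current context's training set.

from torch.utils.data import ConcatDataset, DataLoader

def make_replay_loader(previous_datasets, batch_size, cuda=False):
    # Pool the data of contexts 1 .. t-1 and sample replay batches from that pool.
    replay_pool = ConcatDataset(previous_datasets)
    return DataLoader(replay_pool, batch_size=batch_size, shuffle=True,
                      drop_last=True, pin_memory=cuda)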

Empirical Fisher Estimation

It seems convenient to average the gradients over samples by calling F.nll_loss before squaring them, as we then only need one backward pass. However, I feel that the diagonal of the empirical Fisher information matrix should be calculated by squaring the gradients before taking their average (as done in this TensorFlow implementation). Can you please confirm that the order doesn't matter here?

My understanding is that the expected value of the gradient is 0 (see this Wiki), so if you average first, the Fisher values end up very close to 0, which seems incorrect. Am I missing something here? Please let me know what you think. Thank you.

# calculate negative log-likelihood
negloglikelihood = F.nll_loss(F.log_softmax(output, dim=1), label)
# Calculate gradient of negative loglikelihood
self.zero_grad()
negloglikelihood.backward()
# Square gradients and keep running sum
for n, p in self.named_parameters():
    if p.requires_grad:
        n = n.replace('.', '__')
        if p.grad is not None:
            est_fisher_info[n] += p.grad.detach() ** 2
# Normalize by sample size used for estimation
est_fisher_info = {n: p/index for n, p in est_fisher_info.items()}
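For comparison, a sketch of the "square per sample, then average" order the question refers to (generic code, not this repository's implementation):

import torch
import torch.nn.functional as F

def empirical_fisher_diag(model, data_loader, device="cpu"):
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
    n_samples = 0
    for x, y in data_loader:                 # with batch_size=1, gradients stay per-sample
        x, y = x.to(device), y.to(device)
        model.zero_grad()
        F.nll_loss(F.log_softmax(model(x), dim=1), y).backward()
        for n, p in model.named_parameters():
            if p.requires_grad and p.grad is not None:
                fisher[n] += p.grad.detach() ** 2    # square each per-sample gradient first
        n_samples += x.size(0)
    return {n: f / n_samples for n, f in fisher.items()}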

0 accuracy values for task-free setting

Hello,

I tried the compare_task_free.py and main_task_free.py scripts for a setting where task boundaries are not available, with --iters=1 and --budget=0, but such a setting either throws errors or gives 0 accuracy for all tasks and 1.0 for the last class of CIFAR10. I set --contexts 10 for this experiment. I would highly appreciate your help with this.

Thank you!

Lower/Upper Bound Experiments

How do I run the "None – lower bound" and "Offline – upper bound" experiments? I am not able to find any flags for them.

Must context identity be inferred in the domain-incremental case?

In your paper "Three types of incremental learning", domain-incremental learning does not require context identity to be inferred. But the paper you cite, "[17] An Efficient Domain-Incremental Learning Approach to Drive in All Weather Conditions", mentions that "A requirement of DISC and the proposed domain-IL scenario is to have access to the task-ID during inference, as it is done in task-incremental learning approaches". Is it because the task ID can be obtained from the car's sensors, so there is no need to infer the context identity?

About the KFAC Fisher information matrix

Hi, thanks for your great work! I get the following error when I set fisher_kfac to True:

raise Exception(f"Layer {label} does not have phantom parameters")
Exception: Layer fcLayer1 does not have phantom parameters

Could you please help me figure out what is wrong with this exception?
Thanks!

Knowledge Distillation Loss

Hey,
In order to compute the cross-entropy between the "soft" targets and the predictions, you do the following:
KD_loss_unnorm = (-targets_norm * log_scores_norm).mean() #--> average over batch

Wouldn't the correct cross-entropy, with the mean taken over the batch, be:

KD_loss_unnorm = (-targets_norm * log_scores_norm).sum(dim=1).mean()
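For reference, a self-contained version of the distillation cross-entropy being discussed, summing over classes before averaging over the batch (a sketch, not the repository's exact loss_fn_kd):

import torch.nn.functional as F

def soft_cross_entropy(scores, target_scores, T=2.0):
    log_p = F.log_softmax(scores / T, dim=1)     # student log-probabilities
    q = F.softmax(target_scores / T, dim=1)      # "soft" targets from the previous model
    return (-q * log_p).sum(dim=1).mean()        # sum over classes, mean over batch

Taking .mean() over all elements instead additionally divides by the number of classes, so the two versions differ only by a constant factor.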

Reproducing BI+SI method

Hi! Firstly, I wanted to say thank you for your great work. It was really interesting to read your paper, and your code is really well thought through.
I wanted to ask how I can reproduce the results for brain-inspired replay combined with the synaptic intelligence method.
I tried to run the following command, which gave me 18% accuracy:
python main.py --experiment=CIFAR100 --scenario=class --brain-inspired --SI --seed=0 --pre-convE --freeze-convE --seed-to-ltag --time
I also tried:
python main.py --experiment=CIFAR100 --scenario=class --brain-inspired --SI --seed=0 --pre-convE --seed-to-ltag --time --reg-strength=100000000 --dg-prop=0.6
(i.e. a regularization strength of 10^8) as suggested in the generative-classifiers paper, if I understood correctly. However, even the training accuracy dropped to 0 for context 2 during training.
Can you suggest the correct arguments for the main.py script to reproduce the results from "Brain-inspired replay for continual learning with artificial neural networks"?

How to create a ResNet-34

I am trying to create a ResNet-34 using the code you provided in order to run the continual-learning experiments, but I have not succeeded. Could you tell me how to create a ResNet-34 using your code?
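In case it is useful as a starting point, a generic way to instantiate a ResNet-34 via torchvision (this sidesteps the repository's own model-building code, so it would still need to be wired into its training loop):

import torch
from torchvision.models import resnet34

model = resnet34(num_classes=100)          # e.g. 100 output units for CIFAR-100-style experiments
x = torch.randn(2, 3, 224, 224)
print(model(x).shape)                      # torch.Size([2, 100])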

Link error

The link for "van de Ven et al. (2022, Nature Machine Intelligence)" in "NeurIPS_tutorial.md" is broken.

permutedMNIST accs

Hey - thank you for the good implementation of all these methods. Very helpful.
To start a permutedMNIST run, I executed

python main.py --experiment 'permMNIST' --scenario 'task' --tasks 10 --replay=generative --distill --feedback --iters 5000

--iters needs to be 5000 to get the results reported in the paper, correct?

Why does batch_size have to be 1 when updating the Fisher?

Hi,

Thanks for the great repo. I have a quick question about the computation in the Fisher Information Matrix update: does the batch_size have to be 1 for the dataloader here?

data_loader = utils.get_data_loader(dataset, batch_size=1, cuda=self._is_on_cuda(), collate_fn=collate_fn)
My main concern is about speed. Would it be equivalent if I used a larger batch size?
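One reason batch_size=1 matters (a toy illustration, not the repository's code): with a single backward pass over a larger batch, the gradients are averaged before squaring, which is not the same as averaging the squared per-sample gradients.

import torch

g = torch.tensor([1.0, -1.0, 2.0, -2.0])   # per-sample gradients of one parameter
print(g.mean() ** 2)                        # tensor(0.)     -> square of the averaged gradient
print((g ** 2).mean())                      # tensor(2.5000) -> average of the squared gradients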

Thank you so much in advance! :D

Performance

hey again!

when I execute
./main.py --ewc --online --lambda=5000 --gamma=1 --scenario task

this should give close to 99% accuracy, no?

For EWC and SI I get much worse performance with the default values.
What am I doing wrong? Thank you!

Just a request

Hey, I haven't found any contact information on your page, so I'm asking here.
Note: this is not an issue, but a conceptual question.

I have a text data stream arriving daily, covering different domains like agriculture, space, biology, etc.
I train a model to detect the domain, but new domains keep arriving, so how can I train on the new data while preserving what was learned from the previous domains?

You have demonstrated various forgetting-prevention techniques on MNIST; which do you think would work on text data, using, say, BERT as an encoder?

Waiting for your reply! Thanks!

One small confusion about the loss_fn_kd function

Many thanks for your impressive project. I am a bit confused about the .detach() in the code below,

targets_norm = torch.cat([targets_norm.detach(), zeros_to_add], dim=1)

which is defined in

def loss_fn_kd(scores, target_scores, T=2.):

According to the blog post on PyTorch's .detach() method, .detach() treats targets_norm as a fixed tensor in the KD loss, so backpropagation will not update the parameters along the branch that produced targets_norm.

However, in your other project, brain-inspired-replay, the same loss function, loss_fn_kd, uses

 targets_norm = torch.cat([targets_norm, zeros_to_add], dim=1)

as shown on line 29, where no .detach() is applied.

Although both versions gave the same results in my tests, I am still confused about how the second version works.
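A tiny illustration of what .detach() changes (not code from either repository): gradients do not flow back through a detached tensor. If the targets are produced by a frozen model or inside torch.no_grad(), there is no gradient path to block in the first place, which would explain why both versions give identical results.

import torch

a = torch.tensor([1.0, 2.0], requires_grad=True)
t = 3 * a                                   # stand-in "targets" that still depend on `a`

(t * a).sum().backward()
print(a.grad)                               # tensor([6., 12.]) -> gradient flows through both factors
a.grad = None
(t.detach() * a).sum().backward()
print(a.grad)                               # tensor([3., 6.])  -> the target branch is cut off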

Average EWC pickle

Hey again, it seems that incorrect (average) accuracies are saved in the pickle. The first list in "all_tasks" should contain the accuracies measured once all tasks have been trained, correct? I did not check the other methods, but for online EWC they seem wrong, whereas for SI they seem correct. Can you check that? Thanks!

Datasets more complicated than MNIST

First of all, thank you for releasing this code.

Do you know of any results for generative replay (i.e. where images from across tasks are generated) on datasets more complicated than MNIST, for example CIFAR or ImageNet?

It seems your feedback-connections paper, the three-scenarios-for-continual-learning paper, and the original deep generative replay paper only test on MNIST. Did you try any other datasets? Do you think there is something about the combination of more complex natural images and the continual training of the generator that makes it difficult? Surely someone has tried.
