mattdl / clsurvey

Continual Hyperparameter Selection Framework. Compares 11 state-of-the-art Lifelong Learning methods and 4 baselines. Official Codebase of "A continual learning survey: Defying forgetting in classification tasks." in IEEE TPAMI.

Home Page: https://ieeexplore.ieee.org/abstract/document/9349197

License: Attribution-NonCommercial 4.0 International (see LICENSE)

Topics: continual-learning, tpami, defy-forgetting, classification-tasks, deep-learning, neural-networks, framework, hyperparameter-tuning, inaturalist, tinyimagenet

clsurvey's Introduction

A continual learning survey: Defying forgetting in classification tasks

This is the original source code for the continual learning survey paper "A continual learning survey: Defying forgetting in classification tasks", published in IEEE TPAMI [TPAMI paper] [Open-Access paper].

This work allows comparing the state of the art in a fair fashion using the Continual Hyperparameter Framework, which sets the hyperparameters dynamically based on the stability-plasticity dilemma. This addresses the longstanding problem in the literature of setting hyperparameters for different methods in a fair fashion, using ONLY the current task data (hence without iid validation data, which is not available in continual learning).
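A minimal sketch of the idea, under illustrative assumptions: per task, a maximal-plasticity phase first establishes a finetuning accuracy reference, and a stability-decay phase then starts from a strong method-specific strength and relaxes it until new-task accuracy comes close enough to that reference. The decay factor, threshold, and the train_acc_fn callback below are hypothetical, not the repo's actual API.

    def select_strength(finetune_acc, train_acc_fn, lam=1.0,
                        decay=0.5, threshold=0.2, max_steps=10):
        """Stability decay: shrink the regularization strength `lam` until
        new-task accuracy is within `threshold` of the finetuning reference
        obtained in the maximal-plasticity phase."""
        for _ in range(max_steps):
            acc = train_acc_fn(lam)                    # train with the CL method
            if acc >= (1.0 - threshold) * finetune_acc:
                return lam, acc                        # plastic enough: stop
            lam *= decay                               # relax stability, retry
        return lam, acc

    # Toy usage: accuracy rises as the regularization strength shrinks.
    lam, acc = select_strength(finetune_acc=0.8,
                               train_acc_fn=lambda l: 0.8 / (1.0 + l))
    print(lam, acc)  # -> 0.25 0.64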

The code contains a generalizing framework for 11 SOTA methods and 4 baselines in PyTorch.
The implemented task-incremental methods are

SI | EWC | MAS | mean/mode-IMM | LWF | EBLL | PackNet | HAT | GEM | iCaRL

These are compared with 4 baselines:

Joint | Finetuning | Finetuning-FM | Finetuning-PM

  • Joint: learn from all task data at once with a single head (multi-task learning baseline).
  • Finetuning: standard SGD, without any forgetting-avoidance mechanism.
  • Finetuning-FM (Full Memory replay): allocate memory dynamically to the incoming tasks.
  • Finetuning-PM (Partial Memory replay): divide memory a priori over all tasks (both memory policies are sketched below).
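The two replay baselines differ only in how a fixed exemplar budget is divided; a tiny sketch (the capacity numbers are made up):

    def partial_memory(capacity, num_tasks):
        # Divide memory a priori: every task gets a fixed, equal share.
        return [capacity // num_tasks] * num_tasks

    def full_memory(capacity, num_seen_tasks):
        # Allocate dynamically: only the tasks seen so far share the full
        # capacity, so each share shrinks as new tasks arrive.
        return [capacity // num_seen_tasks] * num_seen_tasks

    print(partial_memory(1000, 10))  # [100, ..., 100] fixed from the start
    print(full_memory(1000, 2))      # [500, 500] after two tasks
    print(full_memory(1000, 5))      # [200, ..., 200] after five tasks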

This source code is released under an Attribution-NonCommercial 4.0 International license; find out more about it in the LICENSE file.

Pipeline

Reproducibility: results from the paper can be obtained with src/main_<dataset>.sh. See src/main_tinyimagenet.sh for a full pipeline example.

Pipeline: Constructing a custom pipeline typically requires the following steps.

  1. Project Setup
    1. For all requirements see requirements.txt. The main packages can be installed as follows:
      conda create --name <ENV-NAME> python=3.7
      conda activate <ENV-NAME>
      
      # Main packages
      conda install -c conda-forge matplotlib tqdm
      conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
      
      # For GEM QP
      conda install -c omnia quadprog
      
      # For PackNet: torchnet 
      pip install git+https://github.com/pytorch/tnt.git@master
      
    2. Set the paths in 'config.init' (or keep the defaults); a minimal read sketch follows this list.
      1. '{tr,test}_results_root_path': where to save training/testing results.
      2. 'models_root_path': where to store the initial models (to ensure every method starts from the same initial model).
      3. 'ds_root_path': root path of your datasets.
    3. Prepare the dataset: see src/data/<dataset>_dataprep.py (e.g. src/data/tinyimgnet_dataprep.py).
  2. Train any of the 11 SOTA methods or 4 baselines
    1. Regularization-based/replay methods: we first run a model dump of the first task with Synaptic Intelligence (SI), as SI acquires importance weights during training; all other methods then start from this same initial model.
    2. Baselines/parameter-isolation methods: start the training sequence from scratch.
  3. Evaluate performance; the test sequence for each task is saved in dictionary format under the test_results_root_path defined in config.init.
  4. Plot the evaluation results using one of the configuration files in utilities/plot_configs.
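For step 1.2, here is a minimal sketch of reading those paths with Python's configparser. The key names come from the list above, but the [DEFAULT] section header and exact file syntax are assumptions about config.init, not guaranteed to match the repo:

    import configparser

    # Assumed layout of config.init:
    #   [DEFAULT]
    #   tr_results_root_path = /path/to/results/train
    #   test_results_root_path = /path/to/results/test
    #   models_root_path = /path/to/models
    #   ds_root_path = /path/to/datasets
    config = configparser.ConfigParser()
    config.read("config.init")
    paths = {key: config["DEFAULT"].get(key) for key in (
        "tr_results_root_path", "test_results_root_path",
        "models_root_path", "ds_root_path")}
    print(paths)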

Implement Your Method

  1. Find class "YourMethod" in methods/method.py. Implement the framework phases (documented in code).
  2. Implement your task-based training script in methods: methods/"YourMethodDir". The class "YourMethod" will call this code for training/eval/processing of a single task.
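A hypothetical skeleton of such a method class; the names and signatures below are illustrative only, and the authoritative interface is the one documented in methods/method.py:

    class YourMethod:
        """Illustrative wrapper for one continual learning method."""

        def __init__(self, hyperparams):
            self.hyperparams = hyperparams  # e.g. regularization strength

        def train_task(self, model, train_loader):
            # Train `model` on the current task only, applying the
            # method's forgetting-avoidance mechanism.
            raise NotImplementedError

        def eval_task(self, model, test_loader):
            # Evaluate on one task (e.g. with a task-specific head).
            raise NotImplementedError

        def process_task(self, model, train_loader):
            # Post-task processing, e.g. estimating parameter importance.
            raise NotImplementedError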

Project structure

  • src/data: datasets and automated preparation scripts for Tiny ImageNet and iNaturalist.
  • src/framework: the task-incremental continual learning framework; main.py starts the training pipeline, and the --test argument runs evaluation with eval.py.
  • src/methods: source code of all methods and the method.py wrapper.
  • src/models: net.py, with all model preprocessing.
  • src/utilities: utils used across all modules, and plotting.
  • config.init: path configuration for results, initial models, and datasets.

Credits

Support

  • If you run into trouble, please open a GitHub issue.
  • Have you defined your method in the framework and want to share it with the community? Send a pull request!


clsurvey's Issues

Bug about finetuning

When I run the finetuning experiment, the code hits bugs. Can you publish the main_tinyimagenet.sh file for finetuning?
Thank you.

Hello, I want to know some details about the code

Excuse me, I'd like to ask some questions.

  1. Is it true that all datasets are kept at the same size in the experiment?
  2. Which labels are adopted in the VOC dataset?
  3. Does the experiment provide a multi-head setting for each method?

If you have time, can you help me answer these? Thank you very much.

Multihead classification

Hi, thank you for the great baseline repo.
I am trying to set up each dataset (each with a different number of classes) as a task and perform continual learning. However, I am a little lost as to how the models handle multiple heads with potentially different output features. Do you have any suggestions on how this might be addressed?

Currently I am using something like:

# Imports and helpers assumed by this snippet (task_list and
# classes_per_task are defined elsewhere in my script):
import torch.nn as nn
import torchvision
from torchvision.models.resnet import ResNet, BasicBlock

def kaiming_normal_init(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)

class MultiTaskModel(ResNet):
    def __init__(self):
        # NOTE: [2, 2, 2, 2] is the resnet18 layout; the layers it builds
        # are unused, since self.shared below is rebuilt from resnet34.
        super(MultiTaskModel, self).__init__(BasicBlock, [2, 2, 2, 2])
        resnet = torchvision.models.resnet34(pretrained=True)
        self.in_feature = resnet.fc.in_features
        self.tasks = []
        self.fc = None

        # Add all layers that are not fc or classifier to the model.
        self.shared = nn.Sequential()
        for name, module in resnet.named_children():
            if name not in ('fc', 'classifiers'):
                self.shared.add_module(name, module)

    def set_task(self, task):
        print("Setting task to", task)
        self.tasks.append(task)
        print(f"tasks are {self.tasks}")
        print(f"task index is {task_list.index(task)}")
        # Replace the head with a new fc layer for the new task.
        self.fc = nn.Linear(self.in_feature,
                            classes_per_task[task_list.index(task)])
        self.fc.apply(kaiming_normal_init)
        print(f"fc is {self.fc}")

    def forward(self, x):
        x = self.shared(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

But I am unable to replicate the results for some methods, such as EWC. If I use the default lambda = 400 as in the repo, the loss becomes NaN.
Currently, I am maintaining the task sequence of the recogseq dataset but observing catastrophic forgetting for EWC and LWF. To be precise, I calculated the performance of each task after training on the last task.

This is my training loop:

# dataset_names, task_list, batch_size, num_epochs, device, and the
# helpers get_dataloaders and fine_tune_EWC_acuumelation are defined
# elsewhere in my script.
import time
import torch
import tqdm

all_train_loaders, all_val_loaders = {}, {}
for task in dataset_names:
    train_loader, val_loader, _, _, _ = get_dataloaders(task, 0.8, batch_size)
    all_train_loaders[task] = train_loader
    all_val_loaders[task] = val_loader

for idx, task in tqdm.tqdm(enumerate(task_list)):
    current_train_loader = all_train_loaders[task]
    current_val_loader = all_val_loaders[task]
    model = MultiTaskModel().to(device)
    if idx > 0:
        # Warm-start from the checkpoint saved after the previous task.
        print(f"Previous task: {task_list[idx-1]}")
        ckpt = torch.load(f"epoch_{task_list[idx-1]}.pth.tar")
        model.load_state_dict(ckpt['state_dict'])
        model = model.to(device)

    start_time = time.time()
    model, acc = fine_tune_EWC_acuumelation(
        current_train_loader, current_val_loader, model, reg_lambda=1,
        num_epochs=num_epochs, lr=0.008, batch_size=batch_size,
        weight_decay=0, current_task=task)

Here's a sample output:
Model loaded for task svhn

Performance of previous task: flowers
fc set to Linear(in_features=512, out_features=103, bias=True)
Accuracy of the network on the 1311 test images: 3.75

Performance of previous task: scenes
fc set to Linear(in_features=512, out_features=67, bias=True)
Accuracy of the network on the 3123 test images: 5.013020833333333

Performance of previous task: birds
fc set to Linear(in_features=512, out_features=201, bias=True)
Accuracy of the network on the 2358 test images: 0.9982638888888888

Performance of previous task: cars
fc set to Linear(in_features=512, out_features=196, bias=True)
Accuracy of the network on the 1621 test images: 0.6875

Performance of previous task: aircraft
fc set to Linear(in_features=512, out_features=56, bias=True)
Accuracy of the network on the 2000 test images: 15.574596774193548

Performance of previous task: chars
fc set to Linear(in_features=512, out_features=63, bias=True)
Accuracy of the network on the 12599 test images: 43.247767857142854

Performance of previous task: svhn
fc set to Linear(in_features=512, out_features=10, bias=True)
Accuracy of the network on the 26032 test images: 96.13223522167488

Running accuracy for task svhn is [3.75, 5.013020833333333, 0.9982638888888888, 0.6875, 15.574596774193548, 43.247767857142854, 96.13223522167488]
Mean accuracy for task svhn is 23.629054939319072

How to run the entire pipeline?

Hi~
Thanks for your great work. This code is a great help to the community and to people who are just getting started.
I have checked main_tinyimagenet.sh and only found the runs for SI and EBLL. Forgive me for being a newcomer; I am lost in the code.
I want to know how to run the whole pipeline with the 11 methods reported in your paper.
Thanks!

Request for a template for evaluating the other SOTAs

Thanks for sharing the CL study framework! However, I wonder how to use iCaRL with the framework. I have modified the main_tinyimagenet.sh file by adding a static_hyperparams term to the arguments, but still got several errors. Could you please provide a shell file for running iCaRL? I'd appreciate it.
