

Patching Weak Convolutional Neural Network Models through Modularization and Composition

Abstract

This repository includes the code and experimental data in our paper entitled "Patching Weak Convolutional Neural Network Models through Modularization and Composition".

In this paper, we propose a structured modularization approach, CNNSplitter, which decomposes a strong CNN model for $N$-class classification into $N$ CNN modules. Each module is a sub-model containing a subset of the strong model's convolution kernels. To patch a weak CNN model that performs poorly on a target class, we compose the weak model with the corresponding module obtained from a strong CNN model. Patching in this way can improve the weak model's ability to recognize the target class.
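The composition idea in the paragraph above can be illustrated with a toy decision rule. This is a minimal sketch under our own assumptions, not CNNSplitter's actual composition method: it assumes the weak model outputs N class probabilities and the module outputs a single probability that the input belongs to the target class. The function name `patched_predict` and the `threshold` parameter are hypothetical.

```python
from typing import Sequence


def patched_predict(weak_probs: Sequence[float], module_prob: float,
                    target_class: int, threshold: float = 0.5) -> int:
    """Hypothetical toy composition rule (not CNNSplitter's own method).

    Trust the target-class module when it is confident; otherwise fall back
    to the weak model's best prediction among the non-target classes.
    """
    if not 0 <= target_class < len(weak_probs):
        raise ValueError("target_class out of range")
    # The module is assumed to act as a binary recognizer for target_class.
    if module_prob >= threshold:
        return target_class
    # The module rejects the target class, so exclude it from the weak model's vote.
    candidates = [c for c in range(len(weak_probs)) if c != target_class]
    return max(candidates, key=lambda c: weak_probs[c])


# The weak model wrongly prefers class 3, but the module for class 0 is confident.
print(patched_predict([0.30, 0.05, 0.10, 0.55], module_prob=0.9, target_class=0))
# The module rejects class 0, so the best non-target class (3) is returned.
print(patched_predict([0.60, 0.05, 0.10, 0.25], module_prob=0.2, target_class=0))
```

Excluding the target class when the module rejects it is a design choice of this sketch; a real composition would also have to deal with calibration between the two models' outputs.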

Requirements

  • python 3.8.10
  • pytorch 1.8.1
  • numpy 1.19.2
  • tqdm 4.61.0
  • matplotlib 3.4.2
  • seaborn 0.11.1
  • A GPU with CUDA support is also required

How to install

Install the dependent packages via pip:

$ pip install numpy==1.19.2 tqdm==4.61.0 matplotlib==3.4.2 seaborn==0.11.1

Install PyTorch 1.8.1 according to your environment (operating system, package manager, and CUDA version); see https://pytorch.org/.
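After installation, a quick check that the pinned versions are in place and that PyTorch can see a GPU may save some debugging time. The script below is a sketch of our own, not part of the repository; it assumes the PyPI distribution names `torch`, `numpy`, `tqdm`, `matplotlib`, and `seaborn`, and the helper names are hypothetical.

```python
import importlib.metadata as metadata  # stdlib since Python 3.8
import platform
from typing import Callable, Dict, List, Optional

# Pins from the Requirements list; "torch" is PyTorch's distribution name on PyPI.
REQUIRED: Dict[str, str] = {
    "torch": "1.8.1",
    "numpy": "1.19.2",
    "tqdm": "4.61.0",
    "matplotlib": "3.4.2",
    "seaborn": "0.11.1",
}


def installed_version(name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def check_requirements(required: Dict[str, str],
                       lookup: Callable[[str], Optional[str]] = installed_version) -> List[str]:
    """Hypothetical helper: list mismatches; an empty list means every pin matches."""
    problems = []
    for name, wanted in required.items():
        found = lookup(name)
        if found is None:
            problems.append(f"{name}: not installed (want {wanted})")
        elif found.split("+")[0] != wanted:  # ignore local tags such as "+cu111"
            problems.append(f"{name}: found {found}, want {wanted}")
    return problems


def cuda_available() -> bool:
    """True only if PyTorch is importable and reports a usable CUDA device."""
    try:
        import torch
    except ImportError:
        return False
    return torch.cuda.is_available()


if __name__ == "__main__":
    print("Python", platform.python_version())
    for line in check_requirements(REQUIRED) or ["all pinned packages match"]:
        print(line)
    print("CUDA available:", cuda_available())
```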

How to modularize a trained CNN model

  1. Modify global_configure.py to set root_dir.
  2. Run python train.py --model simcnn --dataset cifar10 to obtain the pre-trained model SimCNN-CIFAR.
  3. Run python kernel_importance_analyzer.py --model simcnn --dataset cifar10 in directory preprocess/ to compute the importance of each kernel in SimCNN-CIFAR.
  4. Run python run_layer_sensitivity_analyzer.py --model simcnn --dataset cifar10 in directory scripts/ to analyze the layer sensitivity of SimCNN-CIFAR.
  5. Modify configures/simcnn_cifar10.py to configure the genetic algorithm (GA) search.
  6. Run python module_recorder.py --model simcnn --dataset cifar10.
  7. Run python module_explorer.py --model simcnn --dataset cifar10 --target_class 0, launching 10 instances in parallel (--target_class from 0 to 9), each of which searches for the module of one class.
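The steps above can be scripted. The runner below is a hypothetical sketch, not part of the repository: it only uses the scripts and flags listed above, assumes steps 1 and 5 (the manual configuration edits) are already done, and assumes that steps with no stated directory run from a source root passed in as `src_root`. Running 10 explorer processes at once may need more GPU memory than one card provides.

```python
import subprocess
import sys
from pathlib import Path
from typing import List, Tuple

# (working directory relative to src_root, script followed by its arguments)
Step = Tuple[str, List[str]]


def modularization_steps(model: str = "simcnn", dataset: str = "cifar10") -> List[Step]:
    """Steps 2, 3, 4 and 6, in order; "." is an assumed source root."""
    common = ["--model", model, "--dataset", dataset]
    return [
        (".", ["train.py", *common]),
        ("preprocess", ["kernel_importance_analyzer.py", *common]),
        ("scripts", ["run_layer_sensitivity_analyzer.py", *common]),
        (".", ["module_recorder.py", *common]),
    ]


def explorer_steps(model: str = "simcnn", dataset: str = "cifar10",
                   num_classes: int = 10) -> List[Step]:
    """Step 7: one module_explorer.py invocation per target class."""
    return [
        (".", ["module_explorer.py", "--model", model, "--dataset", dataset,
               "--target_class", str(c)])
        for c in range(num_classes)
    ]


def run(src_root: Path, dry_run: bool = True) -> None:
    """Print every command; when dry_run is False, also execute it."""
    for cwd, argv in modularization_steps():
        print(f"[{cwd}] python {' '.join(argv)}")
        if not dry_run:
            # check=True stops the pipeline as soon as one step fails.
            subprocess.run([sys.executable, *argv], cwd=src_root / cwd, check=True)

    # Step 7: start all searches at once, then wait for every one to finish.
    procs = []
    for cwd, argv in explorer_steps():
        print(f"[{cwd}] python {' '.join(argv)}  (background)")
        if not dry_run:
            procs.append(subprocess.Popen([sys.executable, *argv], cwd=src_root / cwd))
    failed = [p.args for p in procs if p.wait() != 0]
    if failed:
        raise RuntimeError(f"{len(failed)} module_explorer.py run(s) failed: {failed}")


if __name__ == "__main__":
    run(Path("."), dry_run=True)  # set dry_run=False to actually launch the jobs
```

Keeping the command lists separate from `run` makes it easy to inspect the exact invocations with a dry run before committing GPU time.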

We provide the four trained CNN models and the corresponding modules, as well as the weak models.
One can download data/ from here and reuse a module to patch a weak CNN model following the description below.

How to patch a weak CNN model

Preparing

  1. Run python module_output_collector.py --model simcnn --dataset cifar10 in directory preprocess/ to collect the outputs of the 10 modules.

Patching a simple model

  1. Run python train.py --model simcnn --dataset cifar10 in directory experiments/patch/patch_for_weak_model to train an overly simple SimCNN-CIFAR.
  2. Run python apply_patch.py --model simcnn --dataset cifar10 --exp_type weak --target_class 0 --target_epoch 99 in directory experiments/patch to patch the simple SimCNN-CIFAR.
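To judge whether a patch helped, one can compare the weak and patched models on the target class while also watching the other classes, since a module that fires too often could take predictions away from them. The helpers below are a sketch of our own (the function names are hypothetical) and assume that integer labels and predictions have already been collected for both models; apply_patch.py may report its own metrics.

```python
from typing import Dict, Sequence


def per_class_recall(labels: Sequence[int], preds: Sequence[int],
                     num_classes: int) -> Dict[int, float]:
    """Fraction of each class's true samples predicted correctly (empty classes skipped)."""
    if len(labels) != len(preds):
        raise ValueError("labels and preds must have the same length")
    hits = [0] * num_classes
    totals = [0] * num_classes
    for y, p in zip(labels, preds):
        totals[y] += 1
        hits[y] += int(y == p)
    return {c: hits[c] / totals[c] for c in range(num_classes) if totals[c]}


def class_precision(labels: Sequence[int], preds: Sequence[int], cls: int) -> float:
    """Fraction of predictions of `cls` that are correct (0.0 if `cls` is never predicted)."""
    true_of_predicted = [y for y, p in zip(labels, preds) if p == cls]
    if not true_of_predicted:
        return 0.0
    return sum(y == cls for y in true_of_predicted) / len(true_of_predicted)


# Made-up toy data with target class 0.
labels = [0, 0, 0, 1, 1, 2]
weak = [1, 0, 2, 1, 1, 2]
patched = [0, 0, 2, 0, 1, 2]
print(per_class_recall(labels, weak, 3))
print(per_class_recall(labels, patched, 3))
print(class_precision(labels, patched, 0))
```

In this toy data the patch raises class-0 recall from 1/3 to 2/3 but lowers class-1 recall from 1.0 to 0.5, which is the kind of trade-off worth checking on real results.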

Patching an overfitting/underfitting model

  1. Run python train.py --model simcnn --dataset cifar10 in directory experiments/patch/patch_for_poor_model to train an overfitting/underfitting SimCNN-CIFAR.
  2. Run python apply_patch.py --model simcnn --dataset cifar10 --exp_type poor_fit --target_class 0 --target_epoch 169 in directory experiments/patch to patch the overfitting SimCNN-CIFAR.
  3. Run python apply_patch.py --model simcnn --dataset cifar10 --exp_type poor_fit --target_class 0 --target_epoch 84 in directory experiments/patch to patch the underfitting SimCNN-CIFAR.
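The steps above use epoch 169 as the overfitting checkpoint and epoch 84 as the underfitting one, without stating how those epochs were chosen. For trying the experiment on another model or dataset, one rough heuristic, which is our assumption and not the authors' criterion, is to take the epoch with the largest train/validation accuracy gap as "overfitting" and the last epoch whose training accuracy is still clearly below its peak as "underfitting". The function name and `underfit_ratio` are hypothetical, and whether epochs are counted from 0 should be checked against how train.py numbers its checkpoints.

```python
from typing import Sequence, Tuple


def pick_poor_fit_epochs(train_acc: Sequence[float], val_acc: Sequence[float],
                         underfit_ratio: float = 0.9) -> Tuple[int, int]:
    """Hypothetical heuristic returning (overfitting_epoch, underfitting_epoch), 0-based.

    overfitting: epoch with the largest train-minus-validation accuracy gap.
    underfitting: last epoch whose training accuracy is below
    underfit_ratio * (best training accuracy).
    """
    if not train_acc or len(train_acc) != len(val_acc):
        raise ValueError("need two non-empty accuracy curves of equal length")
    gaps = [t - v for t, v in zip(train_acc, val_acc)]
    overfit = max(range(len(gaps)), key=gaps.__getitem__)
    cutoff = underfit_ratio * max(train_acc)
    below = [e for e, t in enumerate(train_acc) if t < cutoff]
    underfit = below[-1] if below else 0
    return overfit, underfit


# Made-up accuracy curves for illustration only.
train = [0.30, 0.50, 0.70, 0.85, 0.95, 0.99]
val = [0.30, 0.48, 0.65, 0.72, 0.70, 0.66]
print(pick_poor_fit_epochs(train, val))
```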

Contributors

qibinhang