class_specific_cnn

Usage

  • For usage, see main.py.

Fine-tuning after decomposition on pretrained ResNet20 and ResNet50

The ResNet20 weights are included in the repo. The command below first evaluates the model and prints its accuracy before fine-tuning, then decomposes the model and fine-tunes it to show that the original accuracy can be restored. Recovery is fairly quick for ResNet20 on CIFAR-10; ResNet50 on ImageNet takes a few epochs.

python main.py finetune --model resnet20 --weights resnet20-12fca82f.th --finetune-method decomposed -k 5 -b 256

Command-line example for ResNet50 + ImageNet:

python main.py finetune --model resnet50 --finetune-method decomposed -k 64 -b 64

You don't need to specify a weights file for ResNet50, as it uses the PyTorch pretrained model.

Retrieving influential indices

If no indices file is given, the code automatically finds the selective indices from the final layer. To get the indices for any other layer, pass the name of the corresponding PyTorch submodule. To list the submodule names, load the model at a Python/IPython prompt and run [x[0] for x in model.named_modules()].
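The output of named_modules() can be long. A small stdlib helper for filtering those names with a glob pattern is sketched below; find_layer_names is a hypothetical name, not part of main.py, and the example list is an illustrative subset rather than the actual output for this repo's ResNet20.

```python
from fnmatch import fnmatchcase
from typing import Iterable, List


def find_layer_names(names: Iterable[str], pattern: str) -> List[str]:
    """Return submodule names that match a glob pattern such as "*relu2".

    `names` is intended to be [x[0] for x in model.named_modules()].
    PyTorch names the root module "", so that entry is skipped.
    """
    return [name for name in names if name and fnmatchcase(name, pattern)]


# Illustrative subset only; the real names depend on the model definition.
example_names = [
    "",
    "conv1",
    "layer2.1",
    "layer2.1.conv2",
    "layer2.1.relu2",
    "layer3.2",
    "layer3.2.relu2",
]

print(find_layer_names(example_names, "*relu2"))
# ['layer2.1.relu2', 'layer3.2.relu2']
```

With a loaded model you would call it as find_layer_names([x[0] for x in model.named_modules()], "layer3.*") and pass one of the results to --layer-name.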

Example: retrieve the indices at layer2.1.relu2, that is, after the ReLU that follows the second convolution (conv2) in block layer2.1 of "layer" 2:

python main.py get_inds --model resnet20 --weights resnet20-12fca82f.th --layer-name layer2.1.relu2 -k 5 -b 256 -g 0

This stores the indices in a file named indices_layer2.1.relu2.pkl.
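Given the .pkl extension, the file is presumably written with Python's pickle module. Assuming that, and assuming the file lands in the working directory, it can be loaded back with a sketch like the one below. indices_path and load_indices are hypothetical helper names, not functions from main.py, and since the internal structure of the stored object is not documented here, it is returned as-is.

```python
import os
import pickle
from typing import Any


def indices_path(layer_name: str, directory: str = ".") -> str:
    """File name used for a layer's indices, e.g. indices_layer2.1.relu2.pkl."""
    return os.path.join(directory, f"indices_{layer_name}.pkl")


def load_indices(layer_name: str, directory: str = ".") -> Any:
    """Unpickle the stored indices for `layer_name`.

    The object is returned unchanged because its structure is an unknown
    here. Only unpickle files you trust: pickle can execute arbitrary code.
    """
    with open(indices_path(layer_name, directory), "rb") as f:
        return pickle.load(f)


# Usage, after running get_inds in the current directory:
# inds = load_indices("layer2.1.relu2")
```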

If you instead want the top fraction $k \in (0, 1)$ of indices for each class, pass a float:

python main.py get_inds --model resnet20 --weights resnet20-12fca82f.th --layer-name layer2.1.relu2 -k .5 -b 256 -g 0

This gives you the most influential half of the indices for each class.
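How the per-class influence scores are computed is not covered here, but the selection step implied by an int or fractional k can be sketched as follows. This is our own illustration, not the code in main.py: top_k_indices and top_k_per_class are hypothetical names, the scores are dummy values, and rounding a fraction up with math.ceil is an assumption.

```python
import math
from typing import Dict, Hashable, List, Sequence, Union


def top_k_indices(scores: Sequence[float], k: Union[int, float]) -> List[int]:
    """Indices of the k highest-scoring units, most influential first.

    An int k is an absolute count; a float k in (0, 1) is a fraction of
    all units. Assumption: the fraction is rounded up with math.ceil.
    """
    n = len(scores)
    if isinstance(k, float):
        if not 0.0 < k < 1.0:
            raise ValueError("a fractional k must lie in (0, 1)")
        count = math.ceil(k * n)
    else:
        count = min(k, n)
    # Sort unit indices by score, highest first, and keep the first `count`.
    ranked = sorted(range(n), key=lambda i: scores[i], reverse=True)
    return ranked[:count]


def top_k_per_class(
    class_scores: Dict[Hashable, Sequence[float]], k: Union[int, float]
) -> Dict[Hashable, List[int]]:
    """Apply top_k_indices separately to each class's influence scores."""
    return {c: top_k_indices(s, k) for c, s in class_scores.items()}


# Dummy scores for two classes over four units (not real model output).
scores = {"cat": [0.1, 0.9, 0.5, 0.3], "dog": [0.7, 0.2, 0.4, 0.6]}
print(top_k_per_class(scores, 2))    # {'cat': [1, 2], 'dog': [0, 3]}
print(top_k_per_class(scores, 0.5))  # same as k=2 here, since 0.5 * 4 = 2
```

Keeping the int and float branches separate mirrors the two ways -k is used above (-k 5 versus -k .5).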

Calculating Selectivity $\Psi$

$\Psi$ is computed for each layer and class $c$ at some $k$, written $\Psi^c_k$. It is a number in $(0, 1)$ indicating how selective a layer is towards class $c$. In principle $k$ can be a float or an int, but the paper only uses a float $k$, so only that case is implemented. The result is a dictionary of per-class selectivity at the given layer. To compute it:

python main.py calc_psi --model resnet20 --weights resnet20-12fca82f.th --layer-name layer3.2 -k .5 -b 256 -g 0
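Assuming you have that per-class dictionary as a Python dict mapping class to $\Psi$ (how calc_psi exposes it is not shown here), a small hypothetical helper like the one below picks out the most and least selective classes and the mean over classes. The values in the example are dummies, not results from the paper.

```python
from statistics import mean
from typing import Dict, Hashable, Tuple


def summarize_psi(psi: Dict[Hashable, float]) -> Tuple[Hashable, Hashable, float]:
    """Return (most selective class, least selective class, mean Psi over classes)."""
    if not psi:
        raise ValueError("empty Psi dictionary")
    most = max(psi, key=lambda c: psi[c])
    least = min(psi, key=lambda c: psi[c])
    return most, least, mean(list(psi.values()))


# Dummy values for three classes, for illustration only.
psi_layer = {0: 0.2, 1: 0.8, 2: 0.5}
most, least, avg = summarize_psi(psi_layer)
print(most, least, round(avg, 3))  # 1 0 0.5
```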

Calculating Selectivity $\mu$

$\mu$ is the class selectivity measure as defined by Leavitt et al.

The functions for this are in main.py; a usage example will be added later.
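For reference, as we understand Leavitt & Morcos's formulation, the class selectivity index of a single unit is $(\mu_{max} - \mu_{-max}) / (\mu_{max} + \mu_{-max} + \epsilon)$, where $\mu_{max}$ is the largest class-conditional mean activation, $\mu_{-max}$ is the mean of the remaining class-conditional means, and $\epsilon$ is a small constant. The following is our own stdlib sketch of that formula for one unit, not the function in main.py, whose interface and details may differ; the activations are dummy values and EPS = 1e-7 is an assumed choice.

```python
from collections import defaultdict
from typing import Dict, Hashable, Sequence

EPS = 1e-7  # small constant that keeps the denominator non-zero


def class_conditional_means(
    activations: Sequence[float], labels: Sequence[Hashable]
) -> Dict[Hashable, float]:
    """Mean activation of one unit for each class label."""
    sums: Dict[Hashable, float] = defaultdict(float)
    counts: Dict[Hashable, int] = defaultdict(int)
    for a, y in zip(activations, labels):
        sums[y] += a
        counts[y] += 1
    return {y: sums[y] / counts[y] for y in sums}


def selectivity_index(class_means: Dict[Hashable, float], eps: float = EPS) -> float:
    """(mu_max - mu_rest) / (mu_max + mu_rest + eps) for a single unit.

    mu_max is the largest class-conditional mean and mu_rest is the mean of
    the remaining class means. For non-negative (e.g. post-ReLU) activations
    the result lies in [0, 1].
    """
    if len(class_means) < 2:
        raise ValueError("need at least two classes")
    ordered = sorted(class_means.values(), reverse=True)
    mu_max = ordered[0]
    rest = ordered[1:]
    mu_rest = sum(rest) / len(rest)
    return (mu_max - mu_rest) / (mu_max + mu_rest + eps)


# Dummy activations of one unit on six inputs from three classes.
acts = [4.0, 4.0, 1.0, 1.0, 1.0, 1.0]
labels = [0, 0, 1, 1, 2, 2]
print(round(selectivity_index(class_conditional_means(acts, labels)), 4))  # 0.6
```

An index of 0 means the unit responds equally to all classes on average; values near 1 mean it responds almost exclusively to one class.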

TODO

Citation

If you use any part of this code, please consider citing:

@article{badola2021decomposing,
  url={https://doi.org/10.1007/s00521-023-08441-z},
  doi={10.1007/s00521-023-08441-z},
  year={2023},
  month={June},
  volume={35},
  number={18},
  pages={13583--13596},
  journal={Neural Computing and Applications},
  author={Badola, Akshay and Roy, Cherian and Padmanabhan, V. and Lal, R.},
  title={Decomposing the deep: finding class-specific filters in deep {CNNs}},
}