
hummat / saliency


PyTorch implementation of 'Vanilla' Gradient, Grad-CAM, Guided backprop, Integrated Gradients and their SmoothGrad variants.

Home Page: https://hummat.github.io/saliency

License: GNU General Public License v3.0

Python 0.55% Jupyter Notebook 99.45%
deep-learning deep-neural-networks grad-cam smoothgrad saliency integrated-gradients guided-backpropagation xrai pytorch machine-learning machine-learning-algorithms deep-learning-algorithms


Saliency Methods

Introduction

This repository contains code for the following saliency techniques:

  • XRAI
  • SmoothGrad
  • Vanilla Gradients
  • Guided Backpropagation
  • Integrated Gradients
  • (Guided) Grad-CAM
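To make the gradient-based methods above more concrete, here is a minimal NumPy sketch of Vanilla Gradients, SmoothGrad, Integrated Gradients and the Grad-CAM weighting step. It is not the repository's code: it assumes a toy quadratic "model" `f` whose gradient `grad_f` is written out by hand (the repository obtains gradients from PyTorch autograd), and all function names are hypothetical. XRAI is not covered.

```python
import numpy as np

# Toy "model" standing in for a network: f(x) = sum(W * x**2).
# Its gradient 2 * W * x is written by hand; a real implementation uses autograd.
W = np.array([0.5, -1.0, 2.0])


def f(x):
    return float(np.sum(W * x ** 2))


def grad_f(x):
    return 2.0 * W * x


def vanilla_gradient(x, grad_fn):
    """'Vanilla' gradient saliency: d(output)/d(input) evaluated at x."""
    return grad_fn(x)


def smoothgrad(x, grad_fn, n_samples=50, noise_level=0.15, seed=0):
    """SmoothGrad: average grad_fn over Gaussian-noised copies of x.

    noise_level is the noise standard deviation relative to the value range of x.
    grad_fn may be any saliency function, which is how SmoothGrad variants arise.
    """
    rng = np.random.default_rng(seed)
    sigma = noise_level * (x.max() - x.min())
    total = np.zeros_like(x)
    for _ in range(n_samples):
        total += grad_fn(x + rng.normal(0.0, sigma, size=x.shape))
    return total / n_samples


def integrated_gradients(x, grad_fn, baseline=None, steps=50):
    """Integrated Gradients via a midpoint Riemann sum along the
    straight-line path from baseline to x."""
    baseline = np.zeros_like(x) if baseline is None else baseline
    total = np.zeros_like(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        total += grad_fn(baseline + alpha * (x - baseline))
    return (x - baseline) * total / steps


def grad_cam(activations, gradients):
    """Grad-CAM for one layer, given (K, H, W) activations and their gradients.

    Channel k is weighted by its spatially averaged gradient alpha_k and the
    weighted sum is passed through a ReLU (upsampling to input size is omitted).
    """
    alphas = gradients.mean(axis=(1, 2))
    cam = np.tensordot(alphas, activations, axes=1)  # sum_k alpha_k * A^k -> (H, W)
    return np.maximum(cam, 0.0)


x = np.array([1.0, 2.0, 3.0])
ig = integrated_gradients(x, grad_f)
# Completeness check: the IG attributions sum to f(x) - f(baseline).
print(ig.sum(), f(x) - f(np.zeros_like(x)))
```

The completeness property printed at the end is a quick sanity check for any Integrated Gradients implementation: with enough steps, the attributions should sum to the difference between the model output at the input and at the baseline.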

Remarks

The methods should work with all models from the torchvision package. Tested models so far are:

  • VGG variants
  • ResNet variants
  • DenseNet variants
  • Inception/GoogLeNet*

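*In order for Guided Backpropagation and Grad-CAM to work properly with the Inception and GoogLeNet models, the models need to be modified slightly so that all ReLUs are modules of the model rather than function calls. PyTorch hooks can only be attached to nn.Module instances, so a functional F.relu call cannot be intercepted.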

import torch.nn as nn

# This class can be found at the very end of inception.py and googlenet.py respectively.
class BasicConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = nn.ReLU(inplace=True)  # Add this line

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return self.relu(x)  # Replaces F.relu(x, inplace=True)
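For reference, the modified backward rule that Guided Backpropagation applies at each ReLU can be sketched in NumPy as follows. This only illustrates the rule itself under the usual definition of the method; it is not the repository's hook code, and the function names are hypothetical.

```python
import numpy as np


def relu_backward(x, grad_out):
    """Standard ReLU backward: pass the gradient where the forward input was positive."""
    return grad_out * (x > 0)


def guided_relu_backward(x, grad_out):
    """Guided Backpropagation: additionally drop negative incoming gradients,
    so only positive signals flow back toward the input."""
    return grad_out * (x > 0) * (grad_out > 0)


x = np.array([-1.0, 2.0, 3.0, 0.5])         # ReLU inputs from the forward pass
grad_out = np.array([0.7, -0.4, 1.2, 0.3])  # gradient arriving from the layer above
print(relu_backward(x, grad_out))
print(guided_relu_backward(x, grad_out))
```

Because this rule must replace the backward pass of every ReLU in the network, each ReLU has to exist as a module the saliency code can find and hook.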

Examples

For a brief overview of how to use the package, please have a look at this short tutorial notebook. The bare minimum is summarized below.

# Standard imports 
import torchvision

# Import desired utils and methods
from ml_utils import load_image, show_mask
from guided_backprop import GuidedBackprop

# Load model and image
model = torchvision.models.resnet50(pretrained=True)
model.eval()  # inference mode: BatchNorm uses running statistics, dropout is disabled
doberman = load_image('images/doberman.png', size=224)

# Construct a saliency object and compute the saliency mask.
guided_backprop = GuidedBackprop(model)
rgb_mask = guided_backprop.get_mask(image_tensor=doberman)

# Visualize the result
show_mask(rgb_mask, title='Guided Backprop')
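The README does not show how show_mask renders a mask. A common approach, similar in spirit to the grayscale visualization in the TensorFlow saliency repository, is to sum absolute values over the color channels and clip at a high percentile. The helper below, `mask_to_grayscale`, is hypothetical and assumes a NumPy array in (C, H, W) layout; a PyTorch tensor can be converted with `.detach().cpu().numpy()` first.

```python
import numpy as np


def mask_to_grayscale(mask, channel_axis=0, percentile=99):
    """Collapse a (C, H, W) saliency mask into an (H, W) image in [0, 1].

    Hypothetical helper: sums absolute values over the channel axis, then
    rescales so the minimum maps to 0 and the given percentile maps to 1
    (values above that percentile are clipped).
    """
    mask_2d = np.sum(np.abs(mask), axis=channel_axis)
    vmax = np.percentile(mask_2d, percentile)
    vmin = np.min(mask_2d)
    if vmax <= vmin:  # constant mask: avoid division by zero
        return np.zeros_like(mask_2d, dtype=float)
    return np.clip((mask_2d - vmin) / (vmax - vmin), 0.0, 1.0)


rng = np.random.default_rng(0)
fake_mask = rng.normal(size=(3, 8, 8))  # stand-in for a real (C, H, W) mask
gray = mask_to_grayscale(fake_mask)
print(gray.shape, gray.min(), gray.max())
```

Clipping at a percentile rather than the maximum keeps a few extreme gradient values from washing out the rest of the map.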

Credits

The implementation closely follows that of the corresponding TensorFlow saliency repository, reusing its code where applicable (mostly for the XRAI method).

Further inspiration has been taken from this repository.
