

Identifying Important Group of Pixels using Interactions [CVPR'24]

Kosuke Sumiyasu, Kazuhiko Kawamoto, Hiroshi Kera [arxiv]

Overview

MoXI (Model eXplanation by Interactions) is a black-box, game-theoretic explanation method for image classifiers. Unlike other popular methods such as GradCAM and AttentionRollout, it takes into account the cooperative contributions of pairs of pixels and accurately identifies a group of pixels that has a high impact on prediction confidence.
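In cooperative game theory, the joint contribution of two players i and j is commonly measured by the pairwise interaction I(i, j | S) = v(S ∪ {i, j}) − v(S ∪ {i}) − v(S ∪ {j}) + v(S), where v(S) would be the model's confidence when only the pixels (or patches) in S are present. The sketch below computes that textbook quantity on a hand-made value function. It is not MoXI's estimator: the names `pairwise_interaction` and `toy_confidence` are hypothetical, and MoXI's own formulation (choice of context set, grouping of patches) may differ.

```python
from itertools import combinations


def pairwise_interaction(value, context, i, j):
    """Pairwise interaction of players i and j given a context set S.

    I(i, j | S) = v(S + {i, j}) - v(S + {i}) - v(S + {j}) + v(S)
    Positive: i and j raise the value more together than separately.
    """
    s = frozenset(context)
    return (value(s | {i, j}) - value(s | {i})
            - value(s | {j}) + value(s))


def toy_confidence(patches):
    """Hypothetical confidence: patches 0 and 1 only help together, 2 helps alone."""
    score = 0.1
    if {0, 1} <= patches:
        score += 0.6
    if 2 in patches:
        score += 0.2
    return score


for i, j in combinations(range(3), 2):
    print((i, j), round(pairwise_interaction(toy_confidence, set(), i, j), 3))
```

Patches 0 and 1 get an interaction of about 0.6, while the pairs involving patch 2 get about 0 because patch 2 contributes additively. Scoring each patch alone against the empty set (v({0}) − v(∅) = 0) would miss that 0 and 1 matter only as a group.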

Installation

Clone this repo:

$ git clone https://github.com/KosukeSumiyasu/MoXI

Demo

You can walk through some examples as follows.

Step 0: Download the ImageNet dataset and specify its path in the config_file_path.yaml file.

Step 1: Run the following.

$ cd MoXI
$ pip install -r requirements.txt
$ ./online_identify.sh
$ ./evaluate_curve.sh

Step 2: Open the Jupyter notebooks in notebook/:

  • 00_plot_insertion_deletion_curve.ipynb: quantitative evaluation by insertion and deletion curves.
  • 01_visualize_heatmap.ipynb: qualitative evaluation by heatmaps.
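The insertion and deletion curves mentioned above are usually computed as follows: starting from a blank input, reveal patches from most to least important and record the confidence (insertion); starting from the full input, remove them in the same order (deletion). This is a minimal sketch assuming a patch ranking and a scoring function `score_fn` that returns the target-class confidence for a partially visible input; all names are hypothetical, and the notebook's baseline, step size, and x-axis may differ.

```python
import numpy as np


def insertion_deletion_curves(score_fn, ranking, n_patches):
    """Return (insertion, deletion) confidence curves of length n_patches + 1.

    score_fn(kept): confidence when only patches with kept[k] == True are visible.
    ranking: patch indices ordered from most to least important.
    """
    kept_ins = np.zeros(n_patches, dtype=bool)  # insertion starts from a blank input
    kept_del = np.ones(n_patches, dtype=bool)   # deletion starts from the full input
    insertion = [score_fn(kept_ins.copy())]
    deletion = [score_fn(kept_del.copy())]
    for idx in ranking:
        kept_ins[idx] = True    # reveal the next most important patch
        kept_del[idx] = False   # remove the next most important patch
        insertion.append(score_fn(kept_ins.copy()))
        deletion.append(score_fn(kept_del.copy()))
    return np.array(insertion), np.array(deletion)


def auc(curve):
    """Trapezoidal area under a curve sampled evenly on [0, 1]."""
    curve = np.asarray(curve, dtype=float)
    return float(np.sum((curve[1:] + curve[:-1]) / 2) / (len(curve) - 1))


# Hypothetical model: confidence = total "importance" of the visible patches.
weights = np.array([0.5, 0.3, 0.1, 0.1])
score_fn = lambda kept: float(weights[kept].sum())

ins_curve, del_curve = insertion_deletion_curves(score_fn, ranking=[0, 1, 2, 3], n_patches=4)
print(auc(ins_curve), auc(del_curve))
```

A good ranking yields a high insertion AUC (confidence recovers quickly) and a low deletion AUC (confidence drops quickly); reversing the ranking in this toy flips both.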

Try out MoXI on your own model

We offer two implementations of MoXI.

Implementation 1 (Model-agnostic implementation).

If your model is a CNN, use this implementation.

from src.util.load_parser import load_parser
args = load_parser() # set args.interaction_method = 'pixel_zero_values'
...
model = load_your_model(...)
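As the name `pixel_zero_values` suggests, the model-agnostic implementation removes pixels by setting their values to zero. The sketch below shows such a masking step, assuming a 16×16 patch grid and an (H, W, C) NumPy image. The helper `zero_out_patches` is hypothetical; the repository may work on normalized tensors, use a different patch size, or zero values in a different space.

```python
import numpy as np


def zero_out_patches(image, keep, patch_size=16):
    """Copy of `image` (H, W, C) with every patch whose keep flag is False set to 0.

    keep: boolean array of shape (H // patch_size, W // patch_size); True = keep.
    """
    h, w = image.shape[:2]
    if h % patch_size or w % patch_size:
        raise ValueError("image size must be a multiple of patch_size")
    if keep.shape != (h // patch_size, w // patch_size):
        raise ValueError("keep must have one entry per patch")
    # Upsample the patch-level mask to pixel resolution.
    pixel_keep = np.repeat(np.repeat(keep, patch_size, axis=0), patch_size, axis=1)
    out = image.copy()
    out[~pixel_keep] = 0  # zero all channels of the removed pixels
    return out


image = np.ones((224, 224, 3), dtype=np.float32)
keep = np.zeros((14, 14), dtype=bool)
keep[0, 0] = True  # keep only the top-left patch
masked = zero_out_patches(image, keep)
print(masked.sum())  # only one 16x16x3 patch survives
```

Because the model still receives a full-size image, this works with any classifier, which is what makes the approach model-agnostic.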

Implementation 2 (ViT-aware implementation).

If you use a Vision Transformer model, we highly recommend this implementation. If your model is based on the ViTForImageClassification class from Hugging Face Transformers, it is very simple:

from src.util.load_parser import load_parser
args = load_parser() # set args.interaction_method = 'vit_embedding'
...
model = replace_vit_embedding_mask(args, model)

For an example, see the model in the "Visualize the heatmap" section of https://github.com/KosukeSumiyasu/MoXI/blob/main/notebook/01_visualize_heatmap.ipynb.

Otherwise, your model needs two slight modifications:

  • allow the forward() functions to receive an embedding_mask keyword argument;
  • call select_batch_removing() in the input embedding module.

No worries: after these modifications, you can still load your pre-trained weights.
from .mask_vit_embedding import select_batch_removing

class YourViTClassifier(...):
  def __init__(self, ...):
    self.ViTModel = YourViTModel(...)
    ...

  def forward(self, x, ..., embedding_mask=None):  # MODIFICATION: new keyword argument embedding_mask
    output = self.ViTModel(x, embedding_mask=embedding_mask)  # pass the mask down
    ...

class YourViTModel(...):
  def __init__(self, ...):
    self.ViTEmbedding = YourViTEmbedding(...)
    ...

  def forward(self, x, ..., embedding_mask=None):  # MODIFICATION: new keyword argument embedding_mask
    embedding = self.ViTEmbedding(x, embedding_mask=embedding_mask)  # pass the mask down
    ...

class YourViTEmbedding(...):
  def __init__(self, ...):
    ...

  def forward(self, x, ..., embedding_mask=None):  # MODIFICATION: new keyword argument embedding_mask
    ...
    embeddings = self.patch_embeddings(x, ...)
    # position embeddings of the patch tokens (index 0 belongs to the CLS token)
    embeddings = embeddings + self.position_embeddings[:, 1:, :]

    # MODIFICATION: two lines added; drop the masked patch tokens.
    if embedding_mask is not None:
      embeddings = select_batch_removing(embeddings, embedding_mask)
    ...
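To see how embedding_mask travels from the classifier down to the embedding module, here is a self-contained NumPy toy that mirrors the three classes above. Everything in it is hypothetical: `drop_patches` is only a stand-in for select_batch_removing (whose real signature and mask convention may differ; here True means keep), the input is assumed to be already patch-embedded, and the "encoder" is just mean pooling.

```python
import numpy as np


def drop_patches(embeddings, keep):
    """Hypothetical stand-in for select_batch_removing.

    embeddings: (B, N, D) patch-token embeddings; keep: (B, N) bool, True = keep.
    Every sample must keep the same number of patches so the batch stays rectangular.
    """
    counts = keep.sum(axis=1)
    if not np.all(counts == counts[0]):
        raise ValueError("all samples must keep the same number of patches")
    b, _, d = embeddings.shape
    return embeddings[keep].reshape(b, counts[0], d)


class ToyEmbedding:
    def __init__(self, n_patches, dim, rng):
        self.position_embeddings = rng.normal(size=(1, n_patches, dim))

    def forward(self, patch_embeddings, embedding_mask=None):
        embeddings = patch_embeddings + self.position_embeddings
        if embedding_mask is not None:  # the two added lines
            embeddings = drop_patches(embeddings, embedding_mask)
        return embeddings


class ToyViTModel:
    def __init__(self, n_patches, dim, rng):
        self.embedding = ToyEmbedding(n_patches, dim, rng)

    def forward(self, x, embedding_mask=None):
        tokens = self.embedding.forward(x, embedding_mask=embedding_mask)
        return tokens.mean(axis=1)  # stand-in for the transformer encoder + pooling


class ToyClassifier:
    def __init__(self, n_patches=4, dim=8, n_classes=3, seed=0):
        rng = np.random.default_rng(seed)
        self.vit = ToyViTModel(n_patches, dim, rng)
        self.head = rng.normal(size=(dim, n_classes))

    def forward(self, x, embedding_mask=None):
        return self.vit.forward(x, embedding_mask=embedding_mask) @ self.head


clf = ToyClassifier()
x = np.random.default_rng(1).normal(size=(2, 4, 8))  # (batch, patches, dim), already embedded
keep = np.array([[True, False, True, True],
                 [False, True, True, True]])
print(clf.forward(x).shape, clf.forward(x, embedding_mask=keep).shape)
```

Position embeddings are added before tokens are dropped, as in the template, so each remaining token keeps its original position. Dropping tokens rather than zeroing pixels means a removed patch produces no token at all, so it cannot influence attention.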


Citation

If you find this useful, please cite:

@inproceedings{kosuke2024identifying,
  author    = {Kosuke Sumiyasu and Kazuhiko Kawamoto and Hiroshi Kera},
  title     = {Identifying Important Group of Pixels using Interactions},
  booktitle = {Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2024}
}
