
Scaffold-Maximizing Training (SMaT)


This is the official implementation for the paper Learning to Scaffold: Optimizing Model Explanations for Teaching.


Abstract: Modern machine learning models are opaque, and as a result there is a burgeoning academic subfield on methods that explain these models’ behavior. However, what is the precise goal of providing such explanations, and how can we demonstrate that explanations achieve this goal? Some research argues that explanations should help teach a student (either human or machine) to simulate the model being explained, and that the quality of explanations can be measured by the simulation accuracy of students on unexplained examples. In this work, leveraging meta-learning techniques, we extend this idea to improve the quality of the explanations themselves by optimizing them to improve the training of student models to simulate the original model. We train models on three natural language processing and computer vision tasks, and find that students trained with explanations extracted with our framework are able to simulate the teacher significantly more effectively than ones produced with previous methods. Through human annotations and a user study, we further find that these learned explanations more closely align with how humans would explain the required decisions in these tasks.


Requirements

The code is based on JAX. Please refer to the project page to see how to install the correct version for your system. We used both jax==0.2.24 / jaxlib==0.1.72 and jax==0.3.1 / jaxlib==0.3.0+cuda11.cudnn82.
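For example, the second pairing with CUDA support can typically be installed as in the sketch below. This is only an illustration: the +cuda11.cudnn82 wheel tag assumes CUDA 11 with cuDNN 8.2, and the wheel index URL may change over time, so check the JAX installation guide for your system.

# Illustrative only: install a pinned CUDA build of jaxlib plus a matching jax.
# Adjust the +cuda11.cudnn82 tag to your CUDA/cuDNN versions.
pip install jax==0.3.1 jaxlib==0.3.0+cuda11.cudnn82 \
      -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html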

It also depends on two custom forks of Flax and Transformers. The forks are required because neither library exposes unnormalized attention weights.

All requirements except jax can be installed by running

pip install -r requirements.txt

Quickly train explainers for your model

The smat package contains a wrapper function that allows you to quickly train explainers for your model. All you need to do is wrap your model in a special class and define some parameters for smat:

import jax
import flax
import smat

# wrap your model in a registered smat model class
@smat.models.register_model('my_model')
class MyModel(smat.models.WrappedModel):
    ...

# get data and a trained teacher model
train_data, valid_data, dataloader = get_data()
model, params = get_trained_model()

# train an explainer for the teacher, meta-optimized so that a student
# trained with its explanations simulates the teacher well
explainer, expl_params = smat.compact.train_explainer(
    task_type="classification",
    teacher_model=model,
    teacher_params=params,
    dataloader=dataloader,
    train_dataset=train_data,
    valid_dataset=valid_data,
    num_examples=0.1,
    student_model="my_model",
)

See the example for a more concrete case of applying SMaT to explain BERT predictions on SST-2 (not in the paper!).

Please report any bugs you find by opening an issue.

Train models and explainers

To train a teacher model run

python smat/train.py \
      --setup no_teacher \
      --task $task \
      --arch $arch \
      --model-dir $teacher_dir \
      --do-save

To train a student model that learns from this teacher with $num_examples training examples, run

python smat/train.py \
      --setup static_teacher \
      --task $task \
      --arch $arch \
      --num-examples $num_examples \
      --teacher-dir $teacher_dir \
      --do-save

Finally, to train both a student model and an explainer for the teacher, run

python smat/train.py \
      --setup learnable_teacher \
      --num-examples $num_examples \
      --teacher-dir $teacher_dir \
      --teacher-explainer-dir $teacher_explainer_dir \
      --do-save
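Putting the three stages together, an end-to-end run might look like the following sketch. The variable values here are hypothetical placeholders, not taken from the repository; substitute your own task, architecture, and paths.

# Hypothetical values; adjust to your setup.
task=imdb
arch=electra
num_examples=1000
teacher_dir=checkpoints/teacher
teacher_explainer_dir=checkpoints/teacher-explainer

# 1) train the teacher
python smat/train.py --setup no_teacher --task $task --arch $arch \
      --model-dir $teacher_dir --do-save

# 2) (optional) baseline student distilled from the static teacher
python smat/train.py --setup static_teacher --task $task --arch $arch \
      --num-examples $num_examples --teacher-dir $teacher_dir --do-save

# 3) jointly train a student and a learnable explainer for the teacher
python smat/train.py --setup learnable_teacher --num-examples $num_examples \
      --teacher-dir $teacher_dir --teacher-explainer-dir $teacher_explainer_dir \
      --do-save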

Workflows

Experiments are run using the workflow manager ducttape. For parallel jobs, also place the relevant helper scripts somewhere in your $PATH.

The experiments are organized into two files:

  • tapes/main.tape: This contains the task definitions. It's where you should add new tasks and functionality, or edit previously defined ones.
  • tapes/EXPERIMENT_NAME.tconf: This is where you define the variables for experiments, as well as which tasks to run.

To start off, we recommend creating your own copy of tapes/imdb.tconf. This file is organized into two parts: (1) the variable definitions in the global block, and (2) the plan definition.

First, edit the variables to point to paths on your file system, for example the $repo variable and the ducttape output folder.
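As a rough illustration, a tconf's global block and plan might look like the sketch below. The variable and task names here are made up for illustration and do not reflect the actual contents of tapes/imdb.tconf.

global {
  # illustrative variables only; the real tconf defines its own set
  repo=/path/to/learning-scaffold
  ducttape_output=/path/to/experiment/outputs
}

plan PaperResults {
  # run a (hypothetical) training task across several student data sizes
  reach TrainExplainer via (NumExamples: 100 500 1000)
}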

Then try running one of the existing plans by executing

ducttape tapes/main.tape -C $my_tconf -p PaperResults -j $num_jobs

Annotation Tool for Visual Explanations

We also make the source code for the annotation tool available. See the annotation-tool sub-repo for more details.

