csinva / imodelsx

Scikit-learn friendly library to interpret and prompt-engineer text datasets using large language models.

Home Page: https://csinva.io/imodelsX/

License: MIT License

Languages: Python 79.57%, Jupyter Notebook 20.43%

Topics: ai, deep-learning, explainability, huggingface, interpretability, language-model, machine-learning, ml, natural-language-processing, natural-language-understanding

imodelsx's Introduction

Scikit-learn friendly library to explain, predict, and steer text models/data.
Also includes a collection of utilities for getting started with text data.

📖 Demo notebooks

Explainable modeling/steering

Model        Reference        Output                   Description
Tree-Prompt  🗂️, 🔗, 📄, 📖  Explanation + Steering   Generates a tree of prompts to steer an LLM (Official)
iPrompt      🗂️, 🔗, 📄, 📖  Explanation + Steering   Generates a prompt that explains patterns in data (Official)
AutoPrompt   🗂️, 🔗, 📄      Explanation + Steering   Finds a natural-language prompt using input-gradients (⌛ In progress)
D3           🗂️, 🔗, 📄, 📖  Explanation              Explains the difference between two distributions
SASC         🗂️, 🔗, 📄      Explanation              Explains a black-box text module using an LLM (Official)
Aug-Linear   🗂️, 🔗, 📄, 📖  Linear model             Fits a better linear model by using an LLM to extract embeddings (Official)
Aug-Tree     🗂️, 🔗, 📄, 📖  Decision tree            Fits a better decision tree by using an LLM to expand features (Official)

📖 Demo notebooks   🗂️ Doc   🔗 Reference code   📄 Research paper
⌛ We plan to support other interpretable algorithms like RLPrompt, CBMs, and NBDT. If you want to contribute an algorithm, feel free to open a PR 😄

General utilities

Model            Reference  Description
LLM wrapper      🗂️         Easily call different LLMs
Dataset wrapper  🗂️         Download minimally processed huggingface datasets
Bag of Ngrams    🗂️         Learn a linear model of ngrams
Linear Finetune  🗂️         Finetune a single linear layer on top of LLM embeddings

Quickstart

Installation: pip install imodelsx (or, for more control, clone and install from source)

Demos: see the demo notebooks

Natural-language explanations

Tree-Prompt

from imodelsx import TreePromptClassifier
import datasets
import numpy as np
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

# set up data
rng = np.random.default_rng(seed=42)
dset_train = datasets.load_dataset('rotten_tomatoes')['train']
dset_train = dset_train.select(rng.choice(
    len(dset_train), size=100, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(rng.choice(
    len(dset_val), size=100, replace=False))

# set up arguments
prompts = [
    "This movie is",
    " Positive or Negative? The movie was",
    " The sentiment of the movie was",
    " The plot of the movie was really",
    " The acting in the movie was",
]
verbalizer = {0: " Negative.", 1: " Positive."}
checkpoint = "gpt2"

# fit model
m = TreePromptClassifier(
    checkpoint=checkpoint,
    prompts=prompts,
    verbalizer=verbalizer,
    cache_prompt_features_dir=None,  # e.g. 'cache_prompt_features_dir/gpt2'
)
m.fit(dset_train["text"], dset_train["label"])


# compute accuracy
preds = m.predict(dset_val['text'])
print('\nTree-Prompt acc (val) ->',
      np.mean(preds == dset_val['label']))  # -> 0.7

# compare to accuracy for individual prompts
for i, prompt in enumerate(prompts):
    print(i, prompt, '->', m.prompt_accs_[i])  # -> 0.65, 0.5, 0.5, 0.56, 0.51

# visualize decision tree
plot_tree(
    m.clf_,
    fontsize=10,
    feature_names=m.feature_names_,
    class_names=list(verbalizer.values()),
    filled=True,
)
plt.show()

iPrompt

from imodelsx import explain_dataset_iprompt, get_add_two_numbers_dataset

# get a simple dataset of adding two numbers
input_strings, output_strings = get_add_two_numbers_dataset(num_examples=100)
for i in range(5):
    print(repr(input_strings[i]), repr(output_strings[i]))

# explain the relationship between the inputs and outputs
# with a natural-language prompt string
prompts, metadata = explain_dataset_iprompt(
    input_strings=input_strings,
    output_strings=output_strings,
    checkpoint='EleutherAI/gpt-j-6B', # which language model to use
    num_learned_tokens=3, # how long of a prompt to learn
    n_shots=3, # number of examples shown in each candidate prompt
    n_epochs=15, # how many epochs to search
    verbose=0, # how much to print
    llm_float16=True, # whether to load the model in float16
)
# prompts is a list of the found natural-language prompt strings

D3 (DescribeDistributionalDifferences)

from imodelsx import explain_dataset_d3
hypotheses, hypothesis_scores = explain_dataset_d3(
    pos=positive_samples, # List[str] of positive examples
    neg=negative_samples, # another List[str]
    num_steps=100,
    num_folds=2,
    batch_size=64,
)
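
Here, positive_samples and negative_samples are placeholders for your own data; any two lists of strings work. A made-up toy setup for the call above:

positive_samples = ["loved every minute of it", "a triumph of filmmaking", "wonderful cast"]
negative_samples = ["a tedious slog", "the plot makes no sense", "badly shot"]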

SASC

Here, we explain a module rather than a dataset.

from imodelsx import explain_module_sasc
import numpy as np

# a toy module that responds to the length of a string
mod = lambda str_list: np.array([len(s) for s in str_list])

# a toy dataset where the longest strings are animals
text_str_list = ["red", "blue", "x", "1", "2", "hippopotamus", "elephant", "rhinoceros"]
explanation_dict = explain_module_sasc(
    text_str_list,
    mod,
    ngrams=1,
)

Aug-imodels

Use these just like a scikit-learn model. During training, they fit better features via LLMs, but at test time they are extremely fast and completely transparent.

from imodelsx import AugLinearClassifier, AugTreeClassifier, AugLinearRegressor, AugTreeRegressor
import datasets
import numpy as np

# set up data
dset = datasets.load_dataset('rotten_tomatoes')['train']
dset = dset.select(np.random.choice(len(dset), size=300, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))

# fit model
m = AugLinearClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
    ngrams=2, # use bigrams
)
m.fit(dset['text'], dset['label'])

# predict
preds = m.predict(dset_val['text'])
print('acc_val', np.mean(preds == dset_val['label']))

# interpret
print('Total ngram coefficients: ', len(m.coefs_dict_))
print('Most positive ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1], reverse=True)[:8]:
    print('\t', k, round(v, 2))
print('Most negative ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1])[:8]:
    print('\t', k, round(v, 2))
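
AugTreeClassifier (imported above) follows the same fit/predict API. A minimal sketch; the constructor argument shown is an assumption, so check the Aug-Tree docs for the exact options:

# fit an LLM-augmented decision tree (constructor argument name assumed)
m_tree = AugTreeClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
)
m_tree.fit(dset['text'], dset['label'])
preds_tree = m_tree.predict(dset_val['text'])
print('acc_val (tree)', np.mean(preds_tree == dset_val['label']))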

General utilities

Easy baselines

Easy-to-fit baselines that follow the sklearn API.

from imodelsx import LinearFinetuneClassifier, LinearNgramClassifier
# fit a simple one-layer finetune on top of LLM embeddings
m = LinearFinetuneClassifier(
    checkpoint='distilbert-base-uncased',
)
m.fit(dset['text'], dset['label'])
preds = m.predict(dset_val['text'])
acc = (preds == dset_val['label']).mean()
print('validation acc', acc)
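
LinearNgramClassifier (imported above) follows the same pattern. A minimal sketch, assuming the default constructor works without arguments; check the Bag of Ngrams docs for the exact options:

# fit a bag-of-ngrams linear baseline (default constructor assumed)
m_ngram = LinearNgramClassifier()
m_ngram.fit(dset['text'], dset['label'])
preds_ngram = m_ngram.predict(dset_val['text'])
print('validation acc (ngram)', (preds_ngram == dset_val['label']).mean())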

LLM wrapper

Easy API for calling different language models with caching (much more lightweight than langchain).

import imodelsx.llm
# supports any huggingface checkpoint or openai checkpoint (including chat models)
llm = imodelsx.llm.get_llm(
    checkpoint="gpt2-xl",  # text-davinci-003, gpt-3.5-turbo, ...
    CACHE_DIR=".cache",
)
out = llm("May the Force be")
llm("May the Force be") # when computing the same string again, uses the cache

Data wrapper

API for loading huggingface datasets with basic preprocessing.

import imodelsx.data
dset, dataset_key_text = imodelsx.data.load_huggingface_dataset('ag_news')
# Ensures that dset has splits named 'train' and 'validation', and that
# each split's input text is in the column given by dataset_key_text
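
For example, with ag_news (which stores targets in a 'label' column) you can index a split with the returned text key:

print(dset['train'][dataset_key_text][:2])  # first two input strings
print(dset['train']['label'][:2])           # their labels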

Related work

  • imodels package (JOSS 2021 github) - interpretable ML package for concise, transparent, and accurate predictive modeling (sklearn-compatible).
  • Adaptive wavelet distillation (NeurIPS 2021 pdf, github) - distilling a neural network into a concise wavelet model
  • Transformation importance (ICLR 2020 workshop pdf, github) - using simple reparameterizations, allows for calculating disentangled importances to transformations of the input (e.g. assigning importances to different frequencies)
  • Hierarchical interpretations (ICLR 2019 pdf, github) - extends CD to CNNs / arbitrary DNNs, and aggregates explanations into a hierarchy
  • Interpretation regularization (ICML 2020 pdf, github) - penalizes CD / ACD scores during training to make models generalize better
  • PDR interpretability framework (PNAS 2019 pdf) - an overarching framework for guiding and framing interpretable machine learning

imodelsx's People

Contributors

arminaskari, csinva, divyanshuaggarwal, jxmorris12


imodelsx's Issues

name 'openai' is not defined

Hello, I am using the code:

import numpy as np
from imodelsx import explain_module_sasc

# a toy module that responds to the length of a string
mod = lambda str_list: np.array([len(s) for s in str_list])

# a toy dataset where the longest strings are animals
text_str_list = ["red", "blue", "x", "1", "2", "hippopotamus", "elephant", "rhinoceros"]
explanation_dict = explain_module_sasc(
    text_str_list,
    mod,
    ngrams=1,
)
print(explanation_dict)

and the following error is encountered:

name 'openai' is not defined

How can I solve this problem? Thank you.
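
A likely cause (an assumption; the thread does not confirm it): explain_module_sasc summarizes candidate explanations with a helper LLM, and the optional openai dependency is referenced without being installed. Installing it with pip install openai and supplying an API key before the call may resolve the error:

import os
os.environ["OPENAI_API_KEY"] = "sk-..."  # hypothetical placeholder; use a real key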
