
Prediction-Constrained Topic Models

Public repo containing code to train, visualize, and evaluate semi-supervised topic models. Also includes code for baseline classifiers/regressors to perform supervised prediction on bag-of-words datasets.

Overview

This repo is based on the following academic publication:

"Prediction-constrained semi-supervised topic models" M. C. Hughes, L. Weiner, G. Hope, T. H. McCoy, R. H. Perlis, E. B. Sudderth, and F. Doshi-Velez Artificial Intelligence & Statistics (AISTATS), 2018.

Contents

  • datasets/

    • Provided datasets (e.g. the toy_bars_3x3 dataset used in the examples below).
  • pc_toolbox/

    • Main python package, with code for training PC topic models and some baseline classifiers/regressors.
  • scripts/

    • Bash scripts to run experiments, with support for SLURM/LSF/SGE clusters.

Examples

Python script to train a binary classifier from bag-of-words data

The primary script is train_and_eval_sklearn_binary_classifier.py. You might call it as follows:

# --seed sets the random seed (for reproducibility)
python train_and_eval_sklearn_binary_classifier.py \
  --dataset_path $PC_REPO_DIR/datasets/toy_bars_3x3/ \
  --output_path /tmp/demo_results/ \
  --seed 8675309 \
  --feature_arr_names X \
  --target_arr_name Y \
  --classifier_name extra_trees
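
For intuition, the extra_trees option presumably corresponds to scikit-learn's ExtraTreesClassifier. Below is a minimal standalone sketch (not the repo's script) of that supervised step, using hypothetical toy arrays X and y:

import numpy as np
from sklearn.ensemble import ExtraTreesClassifier

# Toy bag-of-words counts: 100 documents over a 9-word vocabulary (made-up data).
rng = np.random.RandomState(8675309)
X = rng.poisson(1.0, size=(100, 9))
y = (X[:, 0] > 0).astype(int)   # toy binary labels

clf = ExtraTreesClassifier(n_estimators=100, random_state=8675309)
clf.fit(X, y)
print(clf.predict_proba(X[:5]))  # predicted class probabilities for 5 documents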

Python script to train topic models with the PC objective

The primary script is train_slda_model.py. For a quick example, you might call this Python script as follows:

# --seed     : random seed (for reproducibility)
# --alpha    : scalar hyperparameter for Dirichlet prior over doc-topic probas
# --tau      : scalar hyperparameter for Dirichlet prior over topic-word probas
# --weight_y : aka "lambda" in the AISTATS paper, the key hyperparameter to emphasize y|x
# --n_laps   : number of laps (aka epochs); 10 laps completes 10 full passes thru the training dataset
python train_slda_model.py \
  --dataset_path $PC_REPO_DIR/datasets/toy_bars_3x3/ \
  --output_path /tmp/demo_results/ \
  --seed 8675309 \
  --alpha 1.1 \
  --tau 1.1 \
  --weight_y 5.0 \
  --n_laps 10 \
  --n_batches 1 \
  --alg_name grad_descent_minimizer
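
For intuition about --weight_y (the "lambda" of the AISTATS paper): it scales how strongly the label likelihood y|x counts relative to the usual topic-model likelihood of the words x. A schematic sketch of that weighting, assuming per-batch log-likelihoods are already computed (not the repo's actual implementation):

def pc_training_loss(log_lik_x, log_lik_y_given_x, weight_y=5.0):
    """Schematic prediction-constrained loss for one batch: negate the word
    log-likelihood plus the weight_y-scaled label log-likelihood, so larger
    weight_y emphasizes predicting y from x."""
    return -(log_lik_x + weight_y * log_lik_y_given_x)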

In practice, we mostly use wrapper bash scripts that call this script with many different hyperparameters (model, algorithm, initialization, etc.).

Quicktest: Train PC sLDA topic models on toy bars with autograd

Test script to train with autograd (ag) as source of automatic gradients:

cd $PC_REPO_DIR/scripts/toy_bars_3x3/quicktest_topic_models/
export XHOST_RESULTS_DIR=/tmp/
XHOST=local bash pcslda_ag_adam_fromscratch.sh 

This should finish in under 1 minute; it just demonstrates that training runs without errors.
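
If you are unfamiliar with autograd, here is a tiny illustration (unrelated to the repo's models) of how it supplies gradients of ordinary numpy code:

import autograd.numpy as np
from autograd import grad

loss = lambda w: np.sum(w ** 2)   # simple quadratic loss
grad_loss = grad(loss)            # gradient function built by autograd
print(grad_loss(np.ones(3)))      # -> [2. 2. 2.]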

Quicktest: Train PC sLDA topic models on toy bars with tensorflow

Test script to train with tensorflow (tf) as source of automatic gradients:

cd $PC_REPO_DIR/scripts/toy_bars_3x3/quicktest_topic_models/
export XHOST_RESULTS_DIR=/tmp/
XHOST=local bash pcslda_tf_adam_fromscratch.sh 

This should finish in under 1 minute; it just demonstrates that training runs without errors.

Train PC sLDA topic models extensively on movie_reviews dataset

Script to train with tensorflow (tf) as source of automatic gradients:

cd $PC_REPO_DIR/scripts/movie_reviews_pang_lee/train_topic_models/
export XHOST_RESULTS_DIR=/tmp/
XHOST={local|grid} bash pcslda_tf_adam_fromscratch.sh 

Use XHOST=local to run on local computer. Use XHOST=grid to launch jobs on a cluster (Sun Grid Engine, SLURM, IBM's LSF, etc).

This should finish in a few hours.

Installation

  • Step 1: Clone this repo

git clone https://github.com/dtak/prediction-constrained-topic-models/

  • Step 2: Set up a fresh conda environment with all required Python packages

bash $PC_REPO_DIR/scripts/install/create_conda_env.sh

  • Step 3: Compile Cython code for per-document inference (makes things very fast)

cd $PC_REPO_DIR/

python setup.py build_ext --inplace

  • Step 4: (Optional) Install tensorflow

bash $PC_REPO_DIR/scripts/install/install_tensorflow_linux.sh

Configuration

Set up your environment variables!

First, make a shortcut variable to the location of this repo, so you can easily reference datasets, etc.

$ export PC_REPO_DIR=/path/to/prediction_constrained_topic_models/

Second, add this repo to your Python path, so you can do "import pc_toolbox".

$ export PYTHONPATH=$PC_REPO_DIR:$PYTHONPATH
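
A quick way to check that both variables took effect (assuming the conda environment from the install steps is active):

$ python -c "import pc_toolbox; print(pc_toolbox.__file__)"   # should print a path inside $PC_REPO_DIR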
