
Visual Dialog Challenge Starter Code

PyTorch starter code for the Visual Dialog Challenge.

Setup and Dependencies

Our code is implemented in PyTorch (v0.3.0 with CUDA). To set it up, do the following:

If you do not have an Anaconda or Miniconda distribution, head over to their downloads page before proceeding further.

Clone the repository and create an environment.

git clone https://www.github.com/batra-mlp-lab/visdial-challenge-starter-pytorch
conda env create -f env.yml

This creates an environment named visdial-chal with all the dependencies installed.
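Activate the environment before running any of the scripts below (on older conda versions, use source activate instead):

conda activate visdial-chal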

If you wish to extract your own image features, you will also need a Torch (Lua) distribution. Skip the rest of this subsection if you will not extract your own features.

git clone https://github.com/torch/distro.git ~/torch --recursive
cd ~/torch; bash install-deps;
TORCH_LUA_VERSION=LUA51 ./install.sh

Additionally, the image feature extraction code uses torch-hdf5, torch/image and torch/loadcaffe. After Torch is installed, these can be installed or updated using:

luarocks install image
luarocks install loadcaffe

Installation instructions for torch-hdf5 are given here. These additional packages are required for GPU acceleration:

luarocks install cutorch
luarocks install cudnn
luarocks install cunn

Note: Since Torch is now in maintenance mode, it requires cuDNN v5.1 or lower. Install it separately and set the $CUDNN_PATH environment variable to the shared object (binary) file.
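For example (the exact path depends on where cuDNN v5.1 is installed on your machine):

export CUDNN_PATH=/usr/local/cuda/lib64/libcudnn.so.5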

Download Preprocessed Data

We provide preprocessed files for VisDial v1.0 (tokenized captions, questions, answers, image indices, vocabulary mappings and image features extracted by a pretrained CNN). If you wish to preprocess the data or extract features yourself, skip this step.

Extracted features for v1.0 train, val and test are available for download here.

  • visdial_data_train.h5: Tokenized captions, questions, answers, image indices, for training on train
  • visdial_params_train.json: Vocabulary mappings and COCO image ids for training on train
  • data_img_vgg16_relu7_train.h5: VGG16 relu7 image features for training on train
  • data_img_vgg16_pool5_train.h5: VGG16 pool5 image features for training on train
  • visdial_data_trainval.h5: Tokenized captions, questions, answers, image indices, for training on train+val
  • visdial_params_trainval.json: Vocabulary mappings and COCO image ids for training on train+val
  • data_img_vgg16_relu7_trainval.h5: VGG16 relu7 image features for training on train+val
  • data_img_vgg16_pool5_trainval.h5: VGG16 pool5 image features for training on train+val

Download these files to the data directory. If you downloaded just one file each for visdial_data*.h5, visdial_params*.json and data_img*.h5, it is convenient to rename them by removing the part represented by the asterisk; these shorter names are used in the default arguments of the train and evaluate scripts.
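For example, if you downloaded the train-only files with relu7 features (adjust the source names if you picked the trainval or pool5 variants):

cd data
mv visdial_data_train.h5 visdial_data.h5
mv visdial_params_train.json visdial_params.json
mv data_img_vgg16_relu7_train.h5 data_img.h5
cd ..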

Preprocessing VisDial

Download all the images required for VisDial v1.0. Create an empty directory anywhere and place the downloaded images in four subdirectories, named:

  • train2014 and val2014 from the COCO dataset, used by the train split.
  • VisualDialog_val2018 and VisualDialog_test2018 - can be downloaded from here.

This directory will be referred to as the image root directory.

cd data
python prepro.py -download -image_root /path/to/images
cd ..

This script will generate the files data/visdial_data.h5 (contains tokenized captions, questions, answers, image indices) and data/visdial_params.json (contains vocabulary mappings and COCO image ids).
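To sanity-check the generated files, they can be opened with h5py and json (a minimal sketch; the dataset names inside the HDF5 file are whatever prepro.py wrote, so inspect rather than assume them):

import json
import h5py

# List the datasets stored in the preprocessed HDF5 file and their shapes.
with h5py.File('data/visdial_data.h5', 'r') as f:
    for name, item in f.items():
        print(name, getattr(item, 'shape', ''))

# Inspect the vocabulary mappings in the params file.
with open('data/visdial_params.json') as f:
    params = json.load(f)
print(sorted(params.keys()))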

Extracting Image Features

Since we don't fine-tune the CNN, training is significantly faster if image features are pre-extracted. Currently this repository supports extraction with VGG-16 and ResNets. We use image features from VGG-16.

To extract image features using VGG-16, run the following:

sh data/download_model.sh vgg 16
cd data

th prepro_img_vgg16.lua -imageRoot /path/to/images -gpuid 0

Similarly, to extract features using ResNet, run:

sh data/download_model.sh resnet 200
cd data
th prepro_img_resnet.lua -imageRoot /path/to/images -cnnModel /path/to/t7/model -gpuid 0

Running either of these should generate data/data_img.h5 containing features for train, val and test splits corresponding to VisDial v1.0.

Training

This codebase supports discriminative decoding only; read more here. For reference, the Late Fusion Encoder from the Visual Dialog paper is included.

Training with default arguments works as follows:

python train.py -encoder lf-ques-im-hist -decoder disc -gpuid 0  # other args

The script has sensible defaults for every argument, so it also works without specifying any. Run it with -h to see the list of available arguments (learning rate, number of epochs, batch size, etc.) that can be changed as needed.

To extend this starter code, add your own encoder/decoder modules into their respective directories and include their names as choices in command line arguments of train.py.
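As an illustration only (the class name, constructor arguments and forward signature below are hypothetical; match them to the interface the existing encoders in this repository use), a new encoder is typically a small nn.Module:

import torch
import torch.nn as nn

class MyEncoder(nn.Module):
    """Hypothetical encoder: embeds the question and fuses it with image features."""

    def __init__(self, vocab_size, embed_dim=300, hidden_dim=512, img_feat_dim=4096):
        super().__init__()
        self.word_embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.ques_rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fusion = nn.Linear(hidden_dim + img_feat_dim, hidden_dim)

    def forward(self, image_features, question_tokens):
        # Encode the question with an LSTM and keep the final hidden state.
        embedded = self.word_embed(question_tokens)
        _, (hidden, _) = self.ques_rnn(embedded)
        ques_repr = hidden[-1]
        # Concatenate with image features and project to a joint representation.
        fused = torch.cat([ques_repr, image_features], dim=1)
        return torch.tanh(self.fusion(fused))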

We have an -overfit flag, which can be useful for rapid debugging. It takes a batch of 5 examples and overfits the model on them.
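For example, reusing the default encoder and decoder from above:

python train.py -encoder lf-ques-im-hist -decoder disc -gpuid 0 -overfit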

Evaluation

Evaluation of a trained model checkpoint can be done as follows:

python evaluate.py -split val -load_path /path/to/pth/checkpoint -use_gt

To evaluate on metrics from the Visual Dialog paper (Mean reciprocal rank, R@{1, 5, 10}, Mean rank), use the -use_gt flag. Since the test split has no ground truth, -split test won't work here.
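For reference, these metrics are simple functions of the (1-based) rank the model assigns to the ground-truth answer in each round; a minimal sketch, independent of this codebase:

import numpy as np

def retrieval_metrics(gt_ranks):
    # gt_ranks: 1-based ranks of the ground-truth answer, one per question round.
    ranks = np.asarray(gt_ranks, dtype=np.float64)
    return {
        'MRR': float(np.mean(1.0 / ranks)),
        'R@1': float(np.mean(ranks <= 1)),
        'R@5': float(np.mean(ranks <= 5)),
        'R@10': float(np.mean(ranks <= 10)),
        'MeanR': float(np.mean(ranks)),
    }

print(retrieval_metrics([1, 3, 20, 7]))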

Note: The metrics reported here are the same as those reported through EvalAI when making a submission in the val phase.

Generate Submission

To save predictions in a format submittable to the evaluation server on EvalAI, run the evaluation script (without using the -use_gt flag).

To generate a submission file for val phase:

python evaluate.py -split val -load_path /path/to/pth/checkpoint -save_ranks -save_path /path/to/submission/json

To generate a submission file for test-std or test-challenge phase, replace -split val with -split test.
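For example:

python evaluate.py -split test -load_path /path/to/pth/checkpoint -save_ranks -save_path /path/to/submission/json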

Pretrained Checkpoint

Pretrained checkpoint of Late Fusion Encoder - Discriminative Decoder model is available here.

Performance on v1.0 val (trained on v1.0 train):

R@1      R@5      R@10     MeanR    MRR
0.4194   0.7345   0.8387   5.9876   0.5650

Acknowledgements

  • This starter code began as a fork of batra-mlp-lab/visdial-rl. We thank the developers for doing most of the heavy lifting.
  • The Lua-torch codebase of Visual Dialog, at batra-mlp-lab/visdial, served as an important reference while developing this codebase.

License

BSD 3-Clause "New" or "Revised" License
