
VD-PCR

PyTorch implementation for the paper:

VD-PCR: Improving visual dialog with pronoun coreference resolution [arxiv]

Xintong Yu, Hongming Zhang, Ruixin Hong, Yangqiu Song, Changshui Zhang

The visual dialog task requires an AI agent to interact with humans in multi-round dialogs based on a visual environment. As a common linguistic phenomenon, pronouns are often used in dialogs to improve communication efficiency. As a result, resolving pronouns (i.e., grounding pronouns to the noun phrases they refer to) is an essential step towards understanding dialogs. In this paper, we propose VD-PCR, a novel framework to improve Visual Dialog understanding with Pronoun Coreference Resolution in both implicit and explicit ways. First, to implicitly help models understand pronouns, we design novel methods to jointly train the pronoun coreference resolution and visual dialog tasks. Second, after observing that the coreference relationship between pronouns and their referents indicates the relevance between dialog rounds, we propose to explicitly prune the irrelevant history rounds from visual dialog models' input. With pruned input, the models can focus on the relevant dialog history and ignore distractions from the irrelevant rounds. With the proposed implicit and explicit methods, VD-PCR achieves state-of-the-art results on the VisDial dataset.

You are welcome to star/fork this repository and use it to train your own models, reproduce our experiments, and follow our future work. Please kindly cite our paper:

@article{DBLP:journals/pr/YuZHSZ22,
  author    = {Xintong Yu and
               Hongming Zhang and
               Ruixin Hong and
               Yangqiu Song and
               Changshui Zhang},
  title     = {{VD-PCR:} Improving visual dialog with pronoun coreference resolution},
  journal   = {Pattern Recognit.},
  volume    = {125},
  pages     = {108540},
  year      = {2022},
  url       = {https://doi.org/10.1016/j.patcog.2022.108540},
  doi       = {10.1016/j.patcog.2022.108540},
  biburl    = {https://dblp.org/rec/journals/pr/YuZHSZ22.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}

Table of Contents

  • Setup and Dependencies
  • Usage
  • Download preprocessed data
  • Download pre-trained checkpoints
  • Direct Inference
  • Training
  • Logging
  • Acknowledgements
  • Others

Setup and Dependencies

Our code is implemented in PyTorch (v1.7). To set it up, do the following:

  1. Install Python 3.7+
  2. Get the source:
git clone https://github.com/HKUST-KnowComp/VD-PCR VD-PCR
  3. Install requirements into the vd-pcr virtual environment, using Anaconda:
conda env create -f environment.yml
  4. Install apex (optional, for acceleration in Phase 1); a minimal install sketch follows this list.
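
If you choose to install apex, a minimal sketch of the standard NVIDIA apex setup (this is the generic Python-only install, not specific to this repo; see the apex README for builds with CUDA extensions):

git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir ./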

Usage

The basic usage is to run main.py with --model [MODEL_TYPE]/[MODEL_NAME] --mode [MODE], where [MODEL_TYPE] is one of (conly, vonly, joint), [MODEL_NAME] is defined in the corresponding config/[MODEL_TYPE].conf, and [MODE] is one of (train, eval, predict, debug).

Specify the GPUs to use with the GPU environment variable, which serves as CUDA_VISIBLE_DEVICES.
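
For example, to evaluate the MB-JC joint model (defined in config/joint.conf) on seven GPUs:

GPU=0,1,2,3,4,5,6 python main.py --model joint/MB-JC --mode eval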

Experiment setups are stored in config/*.conf. The hyperparameters are explained in config.

All hyperparameters are tuned, and running times estimated, on 7 GeForce RTX 2080 Ti GPUs with NVIDIA driver 430.64 and CUDA 10.1. If you run the code on a different GPU setup, try adjusting batch_size, sequences_per_image, eval_batch_size, and eval_line_batch_size in the configs.
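
As a hypothetical illustration, the keys to adjust look like the following (the key names come from the paragraph above; the values are placeholders, not the repository's defaults, and the exact .conf layout may differ):

batch_size = 8
sequences_per_image = 2
eval_batch_size = 4
eval_line_batch_size = 200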


Download preprocessed data

Download the preprocessed VisDial dataset with VisPro annotations and put it under the data/all folder.

Preprocess the data to fit the ViLBERT and VisPro formats. Correct the dense annotations by setting the ground-truth answer score to 1, following Agarwal et al.:

python preprocessing/pre_process_visdial.py
python preprocessing/get_vispro_only.py --processed
python preprocessing/correct_dense.py

Download the extracted features of ViLBERT:

mkdir -p data/visdial/visdial_img_feat.lmdb
wget https://s3.amazonaws.com/visdial-bert/data/visdial_image_feats.lmdb/data.mdb -O data/visdial/visdial_img_feat.lmdb/data.mdb
wget https://s3.amazonaws.com/visdial-bert/data/visdial_image_feats.lmdb/lock.mdb -O data/visdial/visdial_img_feat.lmdb/lock.mdb

Download pre-trained checkpoints

Download the corresponding checkpoints if you need to run inference or training of the models; a sketch of the expected folder layout follows the list below.

  • For direct inference of Phase 1 model and direct training of Phase 2 model: download the checkpoint epoch_best.ckpt and put it under the folder logs/joint/MB-JC.

  • For direct inference of Phase 2 model: download the checkpoint epoch_2.ckpt and put it under the folder logs/vonly/MB-JC-HP-crf_cap-trainval.

  • For training of Phase 0 and Phase 1 model: download pre-trained checkpoints of ViLBERT.

mkdir checkpoints-release
wget https://s3.amazonaws.com/visdial-bert/checkpoints/bestmodel_no_dense_finetuning -O checkpoints-release/basemodel
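
For the checkpoints named in the first two bullets, a minimal sketch of the expected folder layout (the folder names are taken from those bullets):

mkdir -p logs/joint/MB-JC logs/vonly/MB-JC-HP-crf_cap-trainval
# put epoch_best.ckpt under logs/joint/MB-JC/
# put epoch_2.ckpt under logs/vonly/MB-JC-HP-crf_cap-trainval/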

Direct Inference

Phase 1: Joint training for VD-PCR

Evaluating on VisDial val set with the MB-JC model

GPU=0,1,2,3,4,5,6 python main.py --model joint/MB-JC --mode eval

The expected output is:

r@1: 0.5691261887550354
r@5: 0.8593204021453857
r@10: 0.93708735704422
mean: 3.1439805030822754
mrr: 0.6984539031982422
ndcg: 0.4472862184047699

Running inference on VisDial test set with the MB-JC model

GPU=0,1,2,3,4,5,6 python main.py --model vonly/MB-JC_predict --mode predict
python postprocessing/merge_predictions.py --mode phase1 --model MB-JC_predict

The prediction file can be submitted to the VisDial test server.

Evaluating on VisPro test set with the MB-JC model

GPU=0 python main.py --model conly/MB-JC_eval --mode eval

The expected output is:

Pronoun_Coref_average_precision:        0.9217
Pronoun_Coref_average_recall:   0.8376
Pronoun_Coref_average_f1:       0.8776

Phase 2: History pruning for VD-PCR

Prune history inputs with the crf_cap rule:

python preprocessing/extract_relevant_history.py --include_cap --save_name crf_cap

Running inference on VisDial test set with the MB-JC-HP model

GPU=0,1,2,3,4,5,6 python main.py --model vonly/MB-JC-HP-crf_cap-test --mode predict
python ensemble.py --exp convert --mode predict

The prediction file can be submitted to the VisDial test server.


Training

For any phase, you can set --mode debug to debug the code.

During training, a checkpoint is saved after each epoch finishes. If training aborts, simply rerun the original command to resume from the last saved checkpoint.

Phase 0: PCR model on VisDial

We merge the coreference predictions of the PCR models into the preprocessed VisDial data. You can also train the MB+pseudo model from scratch with

GPU=0 python main.py --model conly/MB-pseudo --mode train

Evaluate the model with

GPU=0 python main.py --model conly/MB-pseudo --mode eval

Phase 1: Joint training for VD-PCR

Select the coreference-related heads from the base model with

GPU=0 python postprocessing/find_coref_head.py

Train the MB-JC model

GPU=0,1,2,3,4,5,6 python main.py --model joint/MB-JC --mode train

The training takes 33h.

Evaluate the MB-JC model

GPU=0,1,2,3,4,5,6 python main.py --model joint/MB-JC --mode eval

Running inference on VisDial test set with the MB-JC model

GPU=0,1,2,3,4,5,6 python main.py --model vonly/MB-JC_predict --mode predict
python postprocessing/merge_predictions.py --mode phase1 --model MB-JC_predict

The prediction file can be submitted to the VisDial test server.

Evaluating on VisPro test set with the MB-JC model

GPU=0 python main.py --model conly/MB-JC_eval --mode eval

Phase 2: History pruning for VD-PCR

Prune history inputs with various rules:

python preprocessing/extract_relevant_history.py --include_cap --save_name crf_cap
python preprocessing/extract_relevant_history.py --save_name crf
python preprocessing/extract_relevant_history.py --include_cap --q_only --save_name cap

Train the MB-JC-HP model with different pruning rules by setting [RULE] to be one of (crf_cap, crf, cap, all):

GPU=0,1,2,3,4,5,6 python main.py --model vonly/MB-JC-HP-[RULE] --mode train
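
For example, with [RULE] set to crf_cap (the model name simply follows the substitution pattern above):

GPU=0,1,2,3,4,5,6 python main.py --model vonly/MB-JC-HP-crf_cap --mode train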

The training takes 1.5h.

Evaluate the MB-JC-HP model:

GPU=0,1,2,3,4,5,6 python main.py --model vonly/MB-JC-HP-[RULE] --mode eval

Ensemble models of various rules:

python postprocessing/merge_predictions.py --mode phase2 --model MB-JC-HP-[RULE] --split val
python ensemble.py --exp val --mode eval
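
The merge step is run once per rule before ensembling; a sketch, assuming all four rules have been trained (trim the list to the rules you actually have):

for RULE in crf_cap crf cap all; do
    python postprocessing/merge_predictions.py --mode phase2 --model MB-JC-HP-$RULE --split val
done
python ensemble.py --exp val --mode eval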

For final inference on the test set, use both the train and val sets as training data to train the MB-JC-HP model:

GPU=0,1,2,3,4,5,6 python main.py --model vonly/MB-JC-HP-[RULE]-trainval --mode train

The training takes 3.5h.

Running inference on VisDial test set with the MB-JC-HP model

GPU=0,1,2,3,4,5,6 python main.py --model vonly/MB-JC-HP-[RULE]-test --mode predict

Ensemble predictions of various rules:

python ensemble.py --exp test --mode predict

The prediction file can be submitted to the VisDial test server.


Logging

All logs are under the logs folder. We use TensorBoard to visualize the results:

tensorboard --logdir=./logs

Acknowledgements

Our code builds on vmurahari3's visdial-bert and YangXuanyue's pytorch-e2e-coref.

Others

If you have questions about the data or the code, you are welcome to open an issue or send me an email; I will respond as soon as possible.


