Giter Site home page Giter Site logo

bmhrl's Introduction

BMHRL

The BiModal Hierarchical Reinforcement Learning Agent is a hierarchical captioning model using audiovisual data to generate captions for videos. It has been trained on the ActivityNetCaptions dataset, achieving a Meteor Score of 10.80. We utilise the bimodal attention mechanism and dataloder provided by Iashin and Rahtu

โ€ข ArXiv

Model

Getting Started


1. Clone the repository

git clone --recursive https://github.com/d-rothen/bmhrl

2. Download features & checkpoints

Download features (I3D and VGGish) and word embeddings (GloVe). The script will download them (~10 GB) and unpack into ./data and ./.vector_cache folders. Make sure to run it while being in BMHRL folder

bash ./download_data.sh

Set up a conda environment

conda env create -f ./req.yml
conda activate bmhrl
python -m spacy download en

Dowload the critic checkpoint https://drive.google.com/file/d/1fidxz-WocOTsXN0gWnWpHl3VWEiKz0NX/view?usp=sharing

(optionally) Dowload a baseline to warmstart the model: https://drive.google.com/drive/folders/13I6BW4SreXEQgmFLgyCok9K-qOnOyYkY?usp=sharing

Download the best performing model: https://drive.google.com/drive/folders/1zULCCntv8ZdQ3-n-EhmlZLLdJJsHIyDP?usp=sharing

3. Run the model

Important Parameters:

  • procedure : [train_rl_cap]
  • mode : [BMHRL, eval] (use eval to forgo training and evaluate on a pretrained model)
  • video_features_path: ( path of downloaded video features )
  • audio_features_path: ( path of downloaded audio features )
  • rl_critic_path: ( path of the downloaded critic checkpoint file )

Optional:

  • rl_warmstart_epochs: ( # of epochs the model trains with standard KL Divergence )
  • rl_pretrained_model_dir: ( directory of the checkpoints the model is initialised with, required for eval mode)
  • rl_train_worker: [True, False] ( model trains the worker before training the manager )
  • B: ( batchsize )

Example:

python main.py --procedure train_rl_cap --mode BMHRL --rl_warmstart_epochs 2 --rl_pretrained_model_dir /home/xxxx/BMHRL/log/train_rl_cap/baseline/checkpoints/E_3 --rl_train_worker True --B 16 --rl_critic_path /home/xxxx/BMHRL/data/critic.cp  --video_features_path /nas/BMHRL/data/i3d_25fps_stack64step64_2stream_npy/ --audio_features_path /nas/BMHRL/data/vggish_npy/

The best performing model checkpoints as well as tensorboard logs will be saved in the ./log directory.

View results via

tensorboard --logdir ./log/[procedue]

bmhrl's People

Contributors

v-iashin avatar d-rothen avatar xanthan011 avatar

Stargazers

Sun Jiahui avatar Yunchuan Ma avatar  avatar  avatar

Watchers

Kostas Georgiou avatar  avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.