ashleyredy / attention-boosted-deep-networks-for-video-classification

This project forked from junyongyou/attention-boosted-deep-networks-for-video-classification

License: MIT License

Python 100.00%

Attention-boosted-deep-networks-for-video-classification

This is an implementation that integrates a simple but efficient attention block into a CNN + bidirectional LSTM network for video classification.
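The README does not spell out how the attention block is built. Purely as an illustration of what attention over per-frame features can look like, the sketch below implements one common form (a tanh scoring layer followed by a softmax over time) in NumPy. The function name, parameter names, and shapes are all assumptions made for this example, and the block used in the repository may differ.

```python
import numpy as np

def temporal_attention(features, w, b, v):
    """Soft attention pooling over time (illustrative, not the repo's code).

    features: (T, D) array, one feature vector per frame
    w: (D, A) projection, b: (A,) bias, v: (A,) scoring vector
    Returns the (D,) attended vector and the (T,) attention weights.
    """
    scores = np.tanh(features @ w + b) @ v   # one score per time step, shape (T,)
    scores = scores - scores.max()           # shift for numerical stability
    exp_scores = np.exp(scores)
    weights = exp_scores / exp_scores.sum()  # softmax over the time axis
    context = weights @ features             # weighted sum of frames, shape (D,)
    return context, weights

# Hypothetical sizes: 16 frames, 32-dim features, 8-dim attention layer
rng = np.random.default_rng(0)
T, D, A = 16, 32, 8
features = rng.normal(size=(T, D))
context, weights = temporal_attention(
    features, rng.normal(size=(D, A)), np.zeros(A), rng.normal(size=A)
)
print(context.shape, weights.shape)
```

In a layout where attention follows the LSTM, `features` would stand for the LSTM outputs at each time step; where it precedes the LSTM, it would stand for the CNN features of each frame.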

Requirements

Please install the packages listed in requirements.txt. Anaconda + PyCharm are recommended.

Train the model

Run python train.py cnn_model input_path output_path [--attention_mode ATTENTION_MODE] [--dataset_name DATASET_NAME] [--feature_extraction FEATURE_EXTRACTION]

The arguments in square brackets are optional.

When training a model for the first time, it is recommended to use the --feature_extraction argument, which extracts image features with the CNN and stores them in npy files.
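The idea of extracting features once and reusing them from npy files can be sketched as follows. This is a minimal illustration under stated assumptions, not the code in train.py: `extract_features` is a hypothetical stand-in for a real CNN forward pass, and `load_or_extract` and the file name are invented for the example.

```python
import os
import tempfile

import numpy as np

def extract_features(frames):
    """Hypothetical stand-in for a CNN: one feature vector per frame.

    frames has shape (num_frames, height, width, channels); averaging over
    height and width gives an array of shape (num_frames, channels).
    """
    return frames.mean(axis=(1, 2))

def load_or_extract(frames, cache_path):
    """Return cached features if the npy file exists, else compute and save them."""
    if os.path.exists(cache_path):
        return np.load(cache_path)
    features = extract_features(frames)
    np.save(cache_path, features)
    return features

# Usage with made-up data: the second call reads the file instead of recomputing
frames = np.ones((4, 8, 8, 3), dtype=np.float32)
with tempfile.TemporaryDirectory() as tmp_dir:
    cache_path = os.path.join(tmp_dir, "video_0001.npy")
    first = load_or_extract(frames, cache_path)
    second = load_or_extract(frames, cache_path)
    print(first.shape, np.array_equal(first, second))
```

Caching this way means the expensive CNN pass runs once per video, and later training runs only need to read the stored arrays.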

Please see train.py for details about the training arguments.

import argparse

def parse_args(args):
    # Positional arguments are required; arguments starting with '--' are optional.
    parser = argparse.ArgumentParser(description='Simple script for attention integrated CNN + LSTM video classification training')
    parser.add_argument('cnn_model', help='Specify which CNN model is used (VGG16/VGG19/InceptionV3/Resnet50/Xception)')
    parser.add_argument('--attention_mode', help='Specify how to add the attention block (after LSTM: cnn-lstm-attention, before LSTM: cnn-attention-lstm; no attention: cnn-lstm)', default='cnn-lstm-attention')
    parser.add_argument('input_path', help='Specify the input data folder path')
    parser.add_argument('--dataset_name', help='Specify the dataset name (UCF-101/Sports-1M)', default='UCF-101')
    parser.add_argument('output_path', help='Specify the output path')
    parser.add_argument('--feature_extraction', help='Specify whether or not to do feature extraction first', default=False)
    # Parse the list passed in rather than silently falling back to sys.argv.
    # check_args is not shown here; see train.py.
    return check_args(parser.parse_args(args))
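As a rough illustration of how these arguments are read from a command line, the sketch below rebuilds the same argument list with the standard argparse module and parses a sample invocation. The paths are made-up placeholders, `build_train_parser` is a name invented for this example, and `check_args` is left out because its behaviour is not shown here.

```python
import argparse

def build_train_parser():
    """Illustrative copy of the training arguments (check_args is omitted)."""
    parser = argparse.ArgumentParser(description='Illustrative copy of the training parser')
    parser.add_argument('cnn_model')
    parser.add_argument('--attention_mode', default='cnn-lstm-attention')
    parser.add_argument('input_path')
    parser.add_argument('--dataset_name', default='UCF-101')
    parser.add_argument('output_path')
    parser.add_argument('--feature_extraction', default=False)
    return parser

# Placeholder paths, chosen only for this example
argv = ['VGG16', '/data/ucf101', '/tmp/out',
        '--attention_mode', 'cnn-attention-lstm',
        '--feature_extraction', 'True']
args = build_train_parser().parse_args(argv)

print(args.cnn_model, args.input_path, args.output_path)
print(args.attention_mode, args.dataset_name)
print(repr(args.feature_extraction))  # the string 'True', not the bool True
```

Because --feature_extraction is declared without a type, any value given on the command line arrives as a string, and every non-empty string (including 'False') is truthy in Python. If the flag is meant as a plain on/off switch, `action='store_true'` would be one way to express that, though whether train.py relies on the current behaviour is not something this README states.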

After training is complete, the configuration information and the trained model are stored in the output_path folder. Both can then be used to predict the class of a new video.

Predict video class

Run python predict.py cnn_model model_path video_path config_path

import argparse

def parse_args(args):
    # All four arguments are positional and therefore required.
    parser = argparse.ArgumentParser(description='Simple script for attention integrated CNN + LSTM video classification')
    parser.add_argument('cnn_model', help='Specify which CNN model is used (VGG16/VGG19/InceptionV3/Resnet50/Xception)')
    parser.add_argument('model_path', help='Specify the model path')
    parser.add_argument('video_path', help='Specify the input video path')
    parser.add_argument('config_path', help='Specify the config file path')
    return parser.parse_args(args)

Datasets

Information about the two datasets (UCF-101 and Sports-1M) is provided in the utils module, including the 99 video classes from Sports-1M together with their download links.
