
DRNET

PyTorch implementation of the NIPS 2017 paper "Unsupervised Learning of Disentangled Representations from Video" [Link]

Introduction

The authors present a new model, DRNET, that learns disentangled image representations from video. It uses a novel adversarial loss to learn a representation that factorizes each frame into a stationary (content) part and a temporally varying (pose) component. They evaluate the approach on a range of synthetic (MNIST, SUNCG) and real (KTH Actions) videos, demonstrating the ability to coherently generate hundreds of steps into the future.
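The core idea can be sketched as follows. This is a hypothetical, heavily simplified illustration (the class names, linear layers, and image sizes are placeholders, not the repository's actual convolutional architecture): a content encoder and a pose encoder factorize each frame, and a decoder reconstructs frame t+k from the content of frame t plus the pose of frame t+k.

```python
import torch
import torch.nn as nn

# Illustrative sketch of the DRNET factorization; names and layer choices
# are placeholders, not the repository's actual networks.
class ContentEncoder(nn.Module):
    def __init__(self, dim=128):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(64 * 64, dim))
    def forward(self, x):
        return self.net(x)

class PoseEncoder(nn.Module):
    def __init__(self, dim=5):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(64 * 64, dim))
    def forward(self, x):
        return self.net(x)

class Decoder(nn.Module):
    def __init__(self, content_dim=128, pose_dim=5):
        super().__init__()
        self.net = nn.Linear(content_dim + pose_dim, 64 * 64)
    def forward(self, content, pose):
        out = self.net(torch.cat([content, pose], dim=1))
        return out.view(-1, 1, 64, 64)

# Reconstruct frame t+k from content(frame t) and pose(frame t+k).
Ec, Ep, D = ContentEncoder(), PoseEncoder(), Decoder()
x_t = torch.randn(4, 1, 64, 64)
x_tk = torch.randn(4, 1, 64, 64)
x_hat = D(Ec(x_t), Ep(x_tk))
rec_loss = nn.MSELoss()(x_hat, x_tk)
```

Because the pose pathway is trained adversarially to carry no identity information, swapping in pose vectors from another clip re-renders that motion with this clip's content.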

Setup

This repository is compatible with Python 2.

  • Follow instructions outlined on PyTorch Homepage for installing PyTorch (Python2).
  • Follow instructions outlined on Lua Download page for installing Lua. This is required since the script for converting the KTH Actions dataset is provided in Lua.

Downloading and Preprocessing data

Detailed instructions for downloading and preprocessing data are provided by edenton/drnet.

KTH Actions dataset

Download the KTH action recognition dataset by running:

sh datasets/download_kth.sh /my/kth/data/path/

where /my/kth/data/path/ is the directory the data will be downloaded into. Next, convert the downloaded .avi files into .png files for the data loader. This requires ffmpeg to be installed. Then run:

th datasets/convert_kth.lua --dataRoot /my/kth/data/path/ --imageSize 128

The --imageSize flag specifies the image resolution. The models implemented in this repository are designed for an image size of 128 or greater. However, they can also be used at lower resolutions by decreasing the number of convolution blocks in the network architecture.
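The conversion step can also be sketched in Python. This is a hedged equivalent of what convert_kth.lua does with ffmpeg (the frame-naming pattern and directory layout here are illustrative, not the repository's actual conventions):

```python
import subprocess
from pathlib import Path

# Illustrative Python stand-in for convert_kth.lua: extract resized .png
# frames from a downloaded .avi with ffmpeg. The output file pattern is
# an assumption, not the repo's actual layout.
def ffmpeg_cmd(avi_path, out_dir, image_size=128):
    out_pattern = Path(out_dir) / "frame_%05d.png"
    return ["ffmpeg", "-i", str(avi_path),
            "-vf", f"scale={image_size}:{image_size}",
            str(out_pattern)]

def convert_video(avi_path, out_dir, image_size=128):
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    subprocess.run(ffmpeg_cmd(avi_path, out_dir, image_size), check=True)
```

Separating command construction from execution makes the resizing logic easy to check without actually invoking ffmpeg.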

The file utils.py contains the dataloader for processing KTH Actions data further.
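A video-frame dataloader of this kind typically samples a fixed-length window of consecutive frames per video. The following is an illustrative in-memory sketch, not the actual loader in utils.py:

```python
import torch
from torch.utils.data import Dataset

# Illustrative sketch (not utils.py's actual loader): sample a length-T
# window of consecutive frames from each video tensor.
class FrameSequenceDataset(Dataset):
    def __init__(self, videos, seq_len=12):
        # videos: list of tensors, each shaped (num_frames, C, H, W)
        self.videos = videos
        self.seq_len = seq_len

    def __len__(self):
        return len(self.videos)

    def __getitem__(self, idx):
        video = self.videos[idx]
        # Pick a random valid start so the window stays inside the video.
        max_start = video.shape[0] - self.seq_len
        start = torch.randint(0, max_start + 1, (1,)).item()
        return video[start:start + self.seq_len]

videos = [torch.randn(30, 1, 128, 128) for _ in range(3)]
ds = FrameSequenceDataset(videos, seq_len=12)
clip = ds[0]  # tensor of shape (12, 1, 128, 128)
```

Wrapping such a dataset in a torch.utils.data.DataLoader then yields batches of shape (batch, seq_len, C, H, W).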

MNIST, SUNCG

The file utils.py contains the functionality for downloading and processing the MNIST and SUNCG datasets while running the model.
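For the moving-MNIST-style data, the generation amounts to translating a digit crop across a larger canvas frame by frame. A simplified sketch (the canvas size, step size, and trajectory are illustrative, not what utils.py actually does):

```python
import torch

# Illustrative moving-MNIST-style generator: slide a 28x28 digit
# diagonally inside a 64x64 canvas. A simplification of utils.py.
def make_sequence(digit, seq_len=12, canvas=64, step=2):
    x, y = 0, 0
    frames = []
    for _ in range(seq_len):
        frame = torch.zeros(1, canvas, canvas)
        frame[0, y:y + 28, x:x + 28] = digit  # paste the digit crop
        frames.append(frame)
        # Move, clamping so the digit never leaves the canvas.
        x = min(x + step, canvas - 28)
        y = min(y + step, canvas - 28)
    return torch.stack(frames)  # (seq_len, 1, canvas, canvas)

seq = make_sequence(torch.rand(28, 28))
```

In such sequences the digit identity is the stationary content while its position is the temporally varying pose, which is exactly the factorization DRNET is meant to recover.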

Train the model

Different architectures are used for training the model on different datasets; select one by specifying the --dataset parameter when calling main.py. The paper uses several networks: the base model, the base model with skip connections, and an LSTM for sequential predictions. These models can be trained by running the following commands (other parameters are described in main.py; refer to it for details):

  • Base model: python main.py
  • Base model with skip connections: python main.py --use_skip True
  • LSTM model for sequential predictions: python main.py --use_lstm True
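The losses these commands optimize can be sketched as follows. This is a hedged, simplified illustration of the three objectives (reconstruction, similarity, and adversarial); the function name drnet_losses and the tiny stand-in networks are hypothetical, and main.py's actual implementation differs:

```python
import torch
import torch.nn as nn

mse = nn.MSELoss()
bce = nn.BCELoss()

# Illustrative sketch of the DRNET training losses; see main.py for the
# real objectives. Ec/Ep/D/C are content encoder, pose encoder, decoder,
# and scene discriminator.
def drnet_losses(Ec, Ep, D, C, x_t, x_tk):
    h_c1, h_c2 = Ec(x_t), Ec(x_tk)
    h_p = Ep(x_tk)
    # Reconstruction: decode frame t+k from content(t) and pose(t+k).
    rec_loss = mse(D(h_c1, h_p), x_tk)
    # Similarity: content vectors of nearby frames should agree.
    sim_loss = mse(h_c1, h_c2.detach())
    # Adversarial: push the discriminator's output on pose pairs toward
    # 0.5 ("can't tell same-video from different-video"), so pose
    # carries no content information.
    pred = C(Ep(x_t), h_p)
    adv_loss = bce(pred, torch.full_like(pred, 0.5))
    return rec_loss, sim_loss, adv_loss

# Tiny stand-in networks, just to exercise the function:
x_t, x_tk = torch.rand(4, 1, 8, 8), torch.rand(4, 1, 8, 8)
Ec = lambda x: x.mean(dim=(2, 3))                 # (B, 1) "content"
Ep = lambda x: x.flatten(1)[:, :2]                # (B, 2) "pose"
D = lambda c, p: torch.zeros(4, 1, 8, 8)          # dummy decoder
C = lambda p1, p2: torch.sigmoid((p1 * p2).sum(1, keepdim=True))
rec, sim, adv = drnet_losses(Ec, Ep, D, C, x_t, x_tk)
```

The LSTM variant then trains a separate predictor over pose vectors while holding content fixed, which is what enables multi-step future generation.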

Training loss curves

Training loss curves for (left) reconstruction loss, (center) similarity loss (with the base model) and (right) mse loss (with lstm model) vs the number of iterations are shown here. These results are on the MNIST dataset.

Also, running the MNIST training code provided by [edenton/drnet-py] for 10 epochs yields the loss values shown here. Each epoch consists of 600 iterations.

The similarity loss quickly goes to zero in both cases, and good convergence can be observed for the reconstruction loss and for the LSTM model's mse loss.

Alternate Implementations

Alternate implementations of this paper are also available. Refer to them for a better understanding.

Citation

If you find this code useful, please consider citing the original work by authors:

@inproceedings{Denton2017NeurIPS,
title = {Unsupervised Learning of Disentangled Representations from Video},
author = {Denton, Emily L and Birodkar, Vighnesh},
booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
year = {2017}
}


drnet's Issues

Why is the rec_loss returned from train_main_network() divided by args.batch_size in main.py?

  1. Why is the rec_loss returned from train_main_network() divided by args.batch_size in main.py? As we know, nn.MSELoss() already returns the mean loss by default.
    line 183: return sim_loss.data/args.batch_size, rec_loss.data/args.batch_size

  2. Are the default parameters suitable for the MNIST dataset? I ran python main.py for 4500 iterations with the default parameters, and the reconstruction loss converged to 0.000126, but the visualization of the reconstructed image is almost entirely black. I'm not sure where I made a mistake. Could you give me any pointers?
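For reference on the first question, nn.MSELoss with its default reduction='mean' already averages over every element, independently of batch size, so an extra division by args.batch_size only rescales the reported value:

```python
import torch
import torch.nn as nn

# nn.MSELoss defaults to reduction='mean': the result is the mean squared
# error over all elements, not a per-batch sum.
pred = torch.zeros(8, 3)
target = torch.ones(8, 3)
loss = nn.MSELoss()(pred, target)
print(loss.item())  # 1.0 regardless of the batch size of 8
```

Dividing such a mean by the batch size again is harmless for optimization (it is a constant rescaling) but makes the logged numbers smaller than the true per-element MSE.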

Thank you very much!
Best,
Dong
