
This project is a fork of sarahwolf32/conditional-dcgan-for-mnist.


A conditional DCGAN, in TensorFlow, for generating hand-written digits from the MNIST dataset.


Conditional DCGAN for MNIST

This is a generative model for the hand-written digits of the MNIST dataset. It combines the DCGAN architecture recommended in Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks (Radford et al.) with the label-conditioning approach of Conditional Generative Adversarial Nets (Mirza and Osindero).

Why a Conditional GAN?

In my last project, I used a DCGAN to generate MNIST digits in an unsupervised fashion: although MNIST is a labeled dataset, I discarded the labels at the start and never used them. This worked, but those labels held a great deal of useful information. It would have been useful to let the GAN benefit from that additional input, and to be able to specify which digit the trained generator should create.

Conditional GANs tackle these shortcomings by feeding the labels into both the Generator and Discriminator.

This has a couple of effects. In the unsupervised DCGAN, the random input vector z controlled everything about the resulting digit, including which digit it was. Since the labels take over that role in a conditional GAN, z here encodes only the remaining features (rotation, style, and so on).
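The conditioning itself is mechanically simple. As a rough sketch (in NumPy, with an assumed 100-dim z and 10 classes, not necessarily the exact dimensions used in trainer/architecture.py), feeding in a label amounts to concatenating a one-hot vector onto z:

```python
import numpy as np

def conditioned_input(z, label, num_classes=10):
    """Concatenate a noise vector z with a one-hot encoding of `label`.

    The result is the kind of vector a conditional Generator consumes:
    z controls style, rotation, etc., while the one-hot part selects
    the digit class. (Dimensions here are illustrative assumptions.)
    """
    one_hot = np.zeros(num_classes)
    one_hot[label] = 1.0
    return np.concatenate([z, one_hot])

z = np.random.normal(size=100)           # 100-dim noise vector
g_input = conditioned_input(z, label=7)  # 110-dim conditioned input
```

The Discriminator receives the label in an analogous way, typically by tiling the one-hot vector across the spatial dimensions and concatenating it with the image channels.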

Feeding in the labels also affected training. The architecture that had worked in my last project quickly suffered from mode collapse when I reused it here. The labels apparently made the Discriminator's job easier, allowing it to "win" the minimax game prematurely: the Generator lost the gradients it needed to learn and started outputting identical black images.

Using fewer layers and larger filters stabilized training. See trainer/architecture.py for details.

Results

Once I used a suitable architecture, the cDCGAN converged relatively quickly. Below are four randomly sampled digits from each class (0–9), generated by the finished model:

Trained Model

To use:

  1. Download the trained model here.

  2. Unzip it and move it into the project directory.

  3. Navigate into the project directory and run `python -m trainer.task --sample [NUM_SAMPLES_PER_CLASS]`. The results will be saved to the `samples/all_samples` folder by default.

If you want to store the trained model somewhere else, include `--checkpoint-dir [YOUR_PATH]` in the command.

If you want to output the samples to another location, include `--sample-dir [YOUR_PATH]` in the command.

Train Your Own (MNIST)

If you want to tweak this code and train your own version from scratch, you can find the main code in trainer/task.py. To train, you will need to:

  1. Download the MNIST data here.
  2. `cd` into the project directory.
  3. Run `python -m trainer.task --data-dir [YOUR_PATH_TO_MNIST_DATA]` to start training.

Train Your Own (Other Dataset)

If you have a dataset of low-resolution, categorically labeled images and want to generate new ones with this code, you should only have to:

  1. Edit the trainer/architecture.py file for your desired input image size, number of label categories, and architecture. DCGANs are very sensitive to architecture, so you may need to try multiple configurations.

  2. Edit the `_load_data` method in the trainer/dataset_loader.py file to unwrap your dataset and shape it into the expected format.

  3. Edit trainer/train_config.py to set your preferred training configuration (batch size, number of epochs, output filepaths, etc.). Because I tend to train in the cloud, there is a separate set of filepath defaults for local and remote training; use the `TrainConfig.is_local = True/False` property to toggle between the two modes.

I hope this is helpful!

To start training, run `python -m trainer.task` from the project directory.

