mnistCRNN

Testing TimeDistributedConvolution layers with GRU layers

Requirements

This code is built on Keras and demonstrates how to use the TimeDistributed wrapper, new in 1.0.2, to build convolutional-recurrent neural networks. It was previously compatible with 0.3.2; for backwards compatibility, check out earlier commits or change the Reshape layer. Earlier versions were built on keras-extra, but Keras has since merged TimeDistributed as a wrapper for arbitrary layers.

Task

The addMNISTrnn.py script downloads the MNIST dataset and builds training sequences that each contain a varying number of digit images. It then trains a CRNN on these sequences to predict the sum of the digits shown.
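The data preparation described above can be sketched as follows. This is a minimal illustration, not the script's actual code: the function name `make_sequences`, the zero-padding of unused sequence slots, and the random stand-in arrays (used instead of the real MNIST download so the example runs without network access) are all assumptions.

```python
import numpy as np


def make_sequences(images, labels, n_examples, max_to_add=8, seed=0):
    """Build zero-padded image sequences and the sum of their digit labels.

    Hypothetical helper: each example holds between 1 and max_to_add images;
    unused slots stay all-zero (an assumed padding scheme).
    """
    rng = np.random.default_rng(seed)
    size = images.shape[1]
    X = np.zeros((n_examples, max_to_add, size, size, 1), dtype=np.float32)
    y = np.zeros(n_examples, dtype=np.float32)
    for i in range(n_examples):
        n = rng.integers(1, max_to_add + 1)         # how many images in this sequence
        idx = rng.integers(0, len(images), size=n)  # which images to use
        X[i, :n, :, :, 0] = images[idx]             # fill the first n time steps
        y[i] = labels[idx].sum()                    # regression target: sum of digits
    return X, y


# Stand-in data with MNIST's image shape (28x28) and label range (0-9).
rng = np.random.default_rng(1)
fake_images = rng.random((100, 28, 28)).astype(np.float32)
fake_labels = rng.integers(0, 10, size=100)

X, y = make_sequences(fake_images, fake_labels, n_examples=32)
print(X.shape, y.shape)
```

The resulting array has the layout (examples, time steps, height, width, channels), which is the five-dimensional input a TimeDistributed convolution expects.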

Validation

The model achieves an RMSE of 1.10 on the task of predicting the sum of 1 to 8 digits, compared with an RMSE of 11.81 for the baseline of always guessing the mean of the target distribution.
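To make the comparison concrete, here is a small sketch of how RMSE and the mean-guessing baseline can be computed. The helper names `rmse` and `mean_baseline_rmse` are hypothetical, and the simulated targets assume uniformly distributed sequence lengths and digits, which only approximates the real MNIST label distribution.

```python
import numpy as np


def rmse(y_true, y_pred):
    """Root-mean-square error between targets and predictions."""
    diff = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean(diff ** 2)))


def mean_baseline_rmse(y_true):
    """RMSE of a predictor that always outputs the mean of the targets."""
    y_true = np.asarray(y_true, dtype=float)
    return rmse(y_true, np.full_like(y_true, y_true.mean()))


# Simulated targets: sums of 1 to 8 digits (assumed uniform lengths and digits).
rng = np.random.default_rng(0)
counts = rng.integers(1, 9, size=10000)
targets = np.array([rng.integers(0, 10, size=n).sum() for n in counts])

baseline = mean_baseline_rmse(targets)
print(f"mean-guess baseline RMSE: {baseline:.2f}")
```

Guessing the mean gives an RMSE equal to the standard deviation of the targets, so the baseline measures how spread out the sums are before any learning happens.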

Notes

A simpler model can do just as well on this task, but this one has multiple conv layers and multiple GRU layers in order to demonstrate how they interact.

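# Imports needed by this snippet
from keras.layers import Input, TimeDistributed, Conv2D, Flatten, GRU, Dropout, Dense
from keras.models import Model
from keras.optimizers import RMSprop

# Assumed values: sequences of up to 8 MNIST digits, each 28x28 pixels
maxToAdd = 8
size = 28

# Define our time-distributed setup: each conv layer runs on every image in the sequence
inp = Input(shape=(maxToAdd, size, size, 1))
x = TimeDistributed(Conv2D(8, (4, 4), padding='valid', activation='relu'))(inp)
x = TimeDistributed(Conv2D(16, (4, 4), padding='valid', activation='relu'))(x)
# Flatten each time step's feature maps into one vector for the recurrent layers
x = TimeDistributed(Flatten())(x)
x = GRU(units=100, return_sequences=True)(x)
x = GRU(units=50, return_sequences=False)(x)
x = Dropout(.2)(x)
# Single linear output: the predicted sum of the digits
x = Dense(1)(x)
model = Model(inp, x)
rmsprop = RMSprop()
model.compile(loss='mean_squared_error', optimizer=rmsprop)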
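To see how the convolutional and recurrent layers fit together, the tensor shapes can be traced by hand. The sketch below does the arithmetic in plain Python without Keras; it assumes 28x28 input images, sequences of 8 time steps, and the default convolution stride of 1, and the helper `conv_valid_out` is a hypothetical name.

```python
def conv_valid_out(size, kernel):
    """Output width of a stride-1 convolution with 'valid' (no) padding."""
    return size - kernel + 1


size, max_to_add = 28, 8                       # assumed image size and sequence length

after_conv1 = conv_valid_out(size, 4)          # first 4x4 conv, 8 filters
after_conv2 = conv_valid_out(after_conv1, 4)   # second 4x4 conv, 16 filters
features = after_conv2 * after_conv2 * 16      # flattened vector per time step

# Shapes exclude the batch dimension.
shapes = [
    ("input", (max_to_add, size, size, 1)),
    ("conv 1", (max_to_add, after_conv1, after_conv1, 8)),
    ("conv 2", (max_to_add, after_conv2, after_conv2, 16)),
    ("flatten", (max_to_add, features)),
    ("GRU 100, return_sequences=True", (max_to_add, 100)),
    ("GRU 50, return_sequences=False", (50,)),
    ("dense", (1,)),
]
for name, shape in shapes:
    print(f"{name:32s} {shape}")
```

The time axis is preserved through every TimeDistributed layer and the first GRU; only the second GRU collapses the sequence into a single vector, which the dense layer maps to the predicted sum.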

Example Output

[Figure: example input sequence of eight images]

The model's output for this input is 38.82.
