hwalsuklee / tensorflow-mnist-cnn

MNIST classification using a Convolutional Neural Network. Various techniques such as data augmentation, dropout, and batch normalization are implemented.

Python 100.00%
tensorflow mnist-classification mnist cnn ensemble-prediction data-augmentation dropout batch-normalization

tensorflow-mnist-cnn's Issues

Tensorflow variables already exist

If I run your code exactly as is, TensorFlow complains that variables are already defined, e.g. with the following error:

ValueError: Variable conv1/weights already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?

Do variables have to be reused or not? How can you set this globally?
[I'm using TensorFlow 1.3 with Python 3.6 on Linux.]
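
One possible trigger for this error, assuming TensorFlow 1.x graph mode, is that the model-building code runs twice in the same process (for example in an interactive session), so the second pass tries to create conv1/weights again. The toy registry below is a pure-Python sketch of that behaviour, not TensorFlow's implementation: VariableStore, get_variable and build_model are hypothetical names, and the reuse flag here simply means "create or reuse". In TensorFlow 1.x itself, calling tf.reset_default_graph() before rebuilding plays roughly the role of the fresh store at the end.

```python
class VariableStore:
    """Toy stand-in for a graph-level variable registry (hypothetical)."""

    def __init__(self):
        self._vars = {}

    def get_variable(self, name, initial, reuse=False):
        # A second creation under the same name is an error unless reuse is on.
        if name in self._vars:
            if not reuse:
                raise ValueError(f"Variable {name} already exists, disallowed.")
            return self._vars[name]
        self._vars[name] = [initial]
        return self._vars[name]


def build_model(store, reuse=False):
    """Hypothetical model builder that needs a single weight variable."""
    return store.get_variable("conv1/weights", 0.0, reuse=reuse)


store = VariableStore()
weights = build_model(store)          # first build: variable is created

try:
    build_model(store)                # second build in the same store fails
except ValueError as err:
    print(err)

shared = build_model(store, reuse=True)   # reuse returns the same variable
fresh = build_model(VariableStore())      # a fresh store starts from scratch
print(shared is weights, fresh is weights)
```

Whether reuse or a fresh graph is the right choice depends on whether the second build is meant to share weights with the first.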

NotFoundError (see above for traceback): Key fc3/BatchNorm/moving_mean not found in checkpoint

Thanks for your code! I have some questions and would be glad of your help.
With Ubuntu 18.04 and TensorFlow 0.12.0, when I run
python mnist_cnn_test.py --model-dir model/model01_99.61 --batch-size 5000 --use-ensemble False
it doesn't work and returns NotFoundError (see above for traceback): Key fc3/BatchNorm/moving_mean not found in checkpoint [[Node: save/RestoreV2_9 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/RestoreV2_9/tensor_names, save/RestoreV2_9/shape_and_slices)]]

And when I run 'python mnist_cnn_test.py --model-dir model/model01_99.61 --batch-size 5000 --use-ensemble True',
the accuracy is only 0.092. All models behave like this.
Looking forward to your reply.

If I change the test batch size, the result is different

I use this code on my own data. I restore the model (saver.restore(sess, myModelPath)) and then test.
If I set the test batch size to 1, the result is [ 0. 0. 0. 0. 0.].
If I set the test batch size to 5, the result is [ 1. 0. 0. 0. 1.].
Why is it different?

Possible bug with is_training parameter

Hello,

Going through the code of your project, I think the parameter is_training is not taken into account for the CNN model in file mnist_cnn_train.py.

I've seen that the cnn_model.CNN function takes an is_training argument that defaults to True, which prevents the code from crashing.

In mnist_cnn_train, you define the is_training placeholder but you don't pass it when calling the cnn_model.CNN function. You do feed it in the training and testing loops of the same file, so I assume this is not intended behavior.

I've not tested it yet, but I think the is_training entry of the feed_dict is simply ignored, and this causes dropout to be applied during the testing loop (the same goes for batch normalization). This bug could be the cause of issue #1.
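
To illustrate why batch normalization stuck in training mode matters at test time, here is a minimal pure-Python sketch, assuming a textbook batch norm without learned scale and shift. The function batch_norm and the constants MOVING_MEAN and MOVING_VAR are hypothetical and are not taken from this repository. In training mode each sample is normalized with the statistics of its own batch, so its output changes with the batch size, and a batch of one always normalizes to zero. This is one possible reading of batch-size-dependent test results, not a confirmed diagnosis.

```python
import math


def batch_norm(batch, moving_mean, moving_var, is_training, eps=1e-5):
    """Normalize a batch of scalar activations (no learned scale or shift)."""
    if is_training:
        # Training mode: use the statistics of the current batch.
        mean = sum(batch) / len(batch)
        var = sum((v - mean) ** 2 for v in batch) / len(batch)
    else:
        # Inference mode: use the stored moving averages.
        mean, var = moving_mean, moving_var
    return [(v - mean) / math.sqrt(var + eps) for v in batch]


# Hypothetical stored statistics for the layer.
MOVING_MEAN, MOVING_VAR = 1.0, 4.0

# Training mode, batch of one: the sample equals its own mean, so it maps to 0.
stuck_single = batch_norm([3.0], MOVING_MEAN, MOVING_VAR, is_training=True)
# Training mode, batch of two: the same sample now depends on its neighbour.
stuck_pair = batch_norm([3.0, 7.0], MOVING_MEAN, MOVING_VAR, is_training=True)
# Inference mode: the sample is normalized identically in any batch.
fixed_single = batch_norm([3.0], MOVING_MEAN, MOVING_VAR, is_training=False)
fixed_pair = batch_norm([3.0, 7.0], MOVING_MEAN, MOVING_VAR, is_training=False)

print(stuck_single[0], stuck_pair[0])
print(fixed_single[0] == fixed_pair[0])
```

With is_training actually wired through and set to False at test time, the output for a given sample no longer depends on which other samples share its batch.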

mnist_cnn_train.py test accuracy values off

When the model is created with "y = cnn_model.CNN(x)", the is_training variable is not passed. Thus, in the testing section, when performing y_final = sess.run(y, feed_dict={x: batch_xs, y_: batch_ys, is_training: False}), the is_training: False entry has no effect. This will impact your accuracy.

If you use the mnist_cnn_train.py test function, the model is initialized with the is_training parameter and gives a result approximately 0.5% higher.

I changed the call to y = cnn_model.CNN(x, is_training=is_training) and now the accuracy percentages match for both modules.
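
The effect of that one-line change can be sketched with a toy graph in pure Python. Everything here is hypothetical and only mimics the pattern: Placeholder stands in for a fed graph input, resolve looks a value up in the feed dict, and cnn is a stand-in model whose "dropout" is replaced by a deterministic halving so the output is reproducible. When the builder is called without is_training, the default constant True is baked in and the fed value is never read.

```python
class Placeholder:
    """Hypothetical stand-in for a graph input supplied through a feed dict."""

    def __init__(self, name):
        self.name = name


def resolve(value, feed):
    """Return the fed value for a Placeholder, or the constant itself."""
    if isinstance(value, Placeholder):
        return feed[value]
    return value


def cnn(x, is_training=True):
    """Toy model builder; returns a function that evaluates a feed dict."""
    def run(feed):
        inputs = resolve(x, feed)
        training = resolve(is_training, feed)
        # Deterministic stand-in for dropout: halve activations when training.
        scale = 0.5 if training else 1.0
        return [v * scale for v in inputs]
    return run


x = Placeholder("x")
is_training = Placeholder("is_training")

buggy = cnn(x)                           # default True is baked into the model
fixed = cnn(x, is_training=is_training)  # mode is read from the feed dict

feed = {x: [1.0, 2.0], is_training: False}
print(buggy(feed))   # still behaves as if training
print(fixed(feed))   # honours is_training: False
```

In the buggy variant the is_training entry in the feed dict is accepted but never consulted, which matches the symptom described above.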

Just as a side note: tf.scalar_summary is deprecated in TensorFlow 1.4.

No model.ckpt was saved under directory model/

I ran python mnist_cnn_train.py in a terminal and it printed

Optimization Finished!
test accuracy for the stored model: 0.9932

Training logs are saved in "logs/train". However, no trained model is saved as "model/model01_99.61/model.ckpt" or in any other directory. When I run python mnist_cnn_test.py --model-dir model/model01_99.61 --batch-size 5000 --use-ensemble False, it returns an error message:

NotFoundError (see above for traceback): Key fc3/BatchNorm/beta not found in checkpoint
	 [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

Where is the problem?
