This quick starter provides a simple example of using the MNIST dataset with TensorFlow and Keras. It defines a Convolutional Neural Network (CNN) for digit classification, trains the model, and evaluates its performance.
This project currently uses Python 3.11. Ensure you have the required dependencies installed by running:
pip install -r requirements.txt
- Run the main.py file to train the model and make predictions on the test dataset:
python main.py
The code in main.py contains a class, MNISTClassifier, that encapsulates the MNIST model and its related functions. It provides the following methods:
- train_model(epochs): Trains the model on the MNIST training dataset for the specified number of epochs.
- evaluate_model(): Evaluates the trained model on the MNIST test dataset and prints the test accuracy.
- predict_samples(num_samples): Displays predictions for the specified number of samples from the test dataset.
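As a rough illustration of what evaluate_model() and predict_samples() compute, the pure-Python sketch below turns per-image probability vectors into predicted digits and measures accuracy against the true labels. It assumes the network ends in a 10-unit softmax layer (one probability per digit); the helper names predicted_digits and accuracy and the sample probabilities are hypothetical, and the real main.py presumably relies on Keras (for example model.predict and model.evaluate) rather than code like this.

```python
def predicted_digits(probabilities):
    """Return the most likely digit for each sample.

    Assumes each row holds 10 softmax probabilities, one per digit 0-9;
    the predicted digit is the index of the largest probability.
    """
    return [max(range(len(row)), key=row.__getitem__) for row in probabilities]


def accuracy(predicted, labels):
    """Fraction of predictions that match the true labels."""
    if len(predicted) != len(labels):
        raise ValueError("predicted and labels must have the same length")
    correct = sum(p == y for p, y in zip(predicted, labels))
    return correct / len(labels)


# Hypothetical softmax outputs for three test images (10 probabilities each).
probs = [
    [0.01, 0.02, 0.90, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01],  # most likely a 2
    [0.85, 0.01, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.01, 0.01],  # most likely a 0
    [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.60, 0.03, 0.02],  # most likely a 7
]
labels = [2, 0, 9]  # the third image is actually a 9, so that prediction is wrong

preds = predicted_digits(probs)
print(preds)                    # [2, 0, 7]
print(accuracy(preds, labels))  # 0.6666666666666666
```

Taking the argmax of the softmax output is the standard way to read a single class out of a multi-class classifier; accuracy is then simply the share of argmax predictions that equal the labels.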
Additionally, the code includes a test class, TestMNISTClassifier, with a test method test_model_evaluation. This method checks that the model is initialized, trains it, evaluates its accuracy, and predicts samples, asserting that the accuracy is above a specified threshold.
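The test flow described above can be sketched with Python's standard unittest module. To keep the sketch runnable without TensorFlow, it uses a hypothetical StubClassifier that only mimics the method names of MNISTClassifier and performs no training; the model attribute and the assumption that evaluate_model() returns the accuracy as a float are illustrative guesses, not details confirmed by this README.

```python
import unittest


class StubClassifier:
    """Hypothetical stand-in with the same method names as MNISTClassifier.

    It does no real training; it only lets this test skeleton run without TensorFlow.
    """

    def __init__(self):
        self.model = object()  # placeholder for the compiled Keras model (assumed attribute)
        self.trained = False

    def train_model(self, epochs):
        self.trained = True

    def evaluate_model(self):
        # Assumption: evaluate_model returns the test accuracy as a float.
        # 0.98 is a placeholder value, not a measured result.
        return 0.98 if self.trained else 0.0

    def predict_samples(self, num_samples):
        # The real method displays predictions; the stub just returns dummy digits.
        return list(range(num_samples))


class TestMNISTClassifier(unittest.TestCase):
    ACCURACY_THRESHOLD = 0.95

    def setUp(self):
        # In the real project this would construct MNISTClassifier instead.
        self.classifier = StubClassifier()

    def test_model_evaluation(self):
        self.assertIsNotNone(self.classifier.model)     # model is initialized
        self.classifier.train_model(epochs=1)            # train
        accuracy = self.classifier.evaluate_model()      # evaluate
        self.classifier.predict_samples(num_samples=5)   # predict samples
        self.assertGreater(accuracy, self.ACCURACY_THRESHOLD)
```

Keeping the threshold in a class attribute makes the pass criterion easy to find and adjust without touching the test logic.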
The MNISTClassifier class uses TensorFlow and Keras to define, compile, and train a Convolutional Neural Network for digit classification on the MNIST dataset. The architecture consists of a convolutional layer, a max-pooling layer, a flattening layer, and dense layers.
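To make that layer stack concrete, the pure-Python sketch below traces tensor shapes and trainable-parameter counts through it. The hyperparameters are assumptions for illustration only, since the README does not state them: 32 filters with a 3x3 kernel, stride 1 and "valid" padding, 2x2 max-pooling, a 128-unit hidden dense layer, and a 10-unit output layer. The input is a 28x28 grayscale MNIST image.

```python
def conv2d(shape, filters, kernel):
    """Output shape and parameter count of a stride-1, 'valid'-padding Conv2D."""
    h, w, c = shape
    params = (kernel * kernel * c + 1) * filters  # kernel weights + one bias per filter
    return (h - kernel + 1, w - kernel + 1, filters), params


def max_pool(shape, pool):
    """Non-overlapping pooling (stride equals pool size); no trainable parameters."""
    h, w, c = shape
    return (h // pool, w // pool, c), 0


def flatten(shape):
    """Collapse (height, width, channels) into one vector; no trainable parameters."""
    h, w, c = shape
    return (h * w * c,), 0


def dense(shape, units):
    """Fully connected layer: weight matrix plus one bias per unit."""
    (n,) = shape
    return (units,), n * units + units


def trace(input_shape, layers):
    """Apply each layer's shape rule in order, collecting shapes and parameter counts."""
    shape, total, rows = input_shape, 0, []
    for name, layer in layers:
        shape, params = layer(shape)
        total += params
        rows.append((name, shape, params))
    return rows, total


# Assumed architecture (hypothetical hyperparameters, see lead-in).
LAYERS = [
    ("conv2d", lambda s: conv2d(s, filters=32, kernel=3)),
    ("max_pool", lambda s: max_pool(s, pool=2)),
    ("flatten", flatten),
    ("dense", lambda s: dense(s, units=128)),
    ("dense_out", lambda s: dense(s, units=10)),
]

rows, total = trace((28, 28, 1), LAYERS)
for name, shape, params in rows:
    print(f"{name:<10} {str(shape):<14} params={params}")
print("total trainable params:", total)  # 693962
```

Under these assumptions, almost all of the roughly 694k parameters sit in the first dense layer (5408 x 128 weights plus biases), which is typical for a single-convolution CNN that flattens early. For the real model, Keras's model.summary() reports the same kind of table.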
- The test accuracy is checked to ensure it is greater than 95%.
- Predictions for sample images are displayed using Matplotlib.