naseemakhtar994 / tflite-mnist-android

This project forked from nex3z/tflite-mnist-android

MNIST with TensorFlow Lite on Android

This project demonstrates how to use TensorFlow Lite on Android for handwritten digit classification on MNIST.

A prebuilt APK can be downloaded from here.

How to build from scratch

Requirement

  • TensorFlow 1.6.0
  • Python 3.6, NumPy 1.14.1
  • Android Studio 3.0, Gradle 4.1
  • Linux or macOS if you want to convert the model to .tflite as described in Step 2 below

Step 1. Training

The model is defined in /model/mnist.py. Run the following command to train it.

python train.py --model_dir ./saved_model --iterations 10000

After training, a collection of checkpoint files and a frozen model mnist.pb will be generated in ./saved_model.

You can evaluate the model on the test set using the command below.

python test.py --model_dir ./saved_model
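The internals of test.py aren't shown here. As an assumption, a test-set evaluation for a 10-class digit classifier typically reports accuracy: the fraction of images whose highest-scoring class matches the label. The stdlib-only sketch below illustrates that metric; `argmax` and `accuracy` are hypothetical helpers, not functions from this project.

```python
from typing import Sequence


def argmax(scores: Sequence[float]) -> int:
    """Return the index of the largest score (the first one wins on ties)."""
    best = 0
    for i in range(1, len(scores)):
        if scores[i] > scores[best]:
            best = i
    return best


def accuracy(predictions: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the true label."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        raise ValueError("need at least one sample")
    correct = sum(1 for scores, label in zip(predictions, labels) if argmax(scores) == label)
    return correct / len(labels)


if __name__ == "__main__":
    # Three toy samples with 10 class scores each (digits 0-9).
    preds = [
        [0.0] * 3 + [0.9] + [0.0] * 6,  # predicts 3
        [0.8] + [0.0] * 9,              # predicts 0
        [0.0] * 9 + [0.7],              # predicts 9
    ]
    print(accuracy(preds, [3, 0, 4]))  # prints 0.6666666666666666
```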

A pre-trained model can be downloaded from here.

Step 2. Model conversion

The standard TensorFlow model obtained in Step 1 cannot be used directly by TensorFlow Lite. We need to freeze the graph and convert the model to the FlatBuffer format (.tflite). There are two ways to convert the model: the TOCO command-line tool, or the Python API (which uses TOCO under the hood).

You will need Linux or macOS for this step, as TOCO is not available for, and cannot be built on, Windows at the moment.
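Whichever option you use, a quick sanity check is that the output really is a TensorFlow Lite FlatBuffer rather than, say, a copied .pb file. A FlatBuffer starts with a 4-byte offset to its root table, and when the schema declares a file identifier, those 4 bytes are followed by the identifier; the TensorFlow Lite schema declares "TFL3". The sketch below assumes the converter writes that identifier; `looks_like_tflite` and `check_file` are hypothetical helpers.

```python
import sys

# File identifier declared in the TensorFlow Lite FlatBuffer schema (assumed present).
TFLITE_FILE_IDENTIFIER = b"TFL3"


def looks_like_tflite(data: bytes) -> bool:
    """Return True if `data` carries the TensorFlow Lite FlatBuffer identifier.

    Bytes 0-3 hold the root-table offset; bytes 4-7 hold the file identifier.
    """
    return len(data) >= 8 and data[4:8] == TFLITE_FILE_IDENTIFIER


def check_file(path: str) -> bool:
    """Read just the 8-byte header of `path` and check the identifier."""
    with open(path, "rb") as f:
        header = f.read(8)
    return looks_like_tflite(header)


if __name__ == "__main__":
    # Usage: python check_tflite.py mnist.tflite [more files...]
    for p in sys.argv[1:]:
        print(p, "OK" if check_file(p) else "not a TensorFlow Lite FlatBuffer")
```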

Option 1. Use TOCO command-line

TOCO (the TensorFlow Lite Optimizing Converter) converts a TensorFlow GraphDef into a TensorFlow Lite FlatBuffer. We need to build TOCO with Bazel from the TensorFlow repository and use it to convert mnist.pb to mnist.tflite.

  1. Install Bazel

Install Bazel by following the instructions.

  2. Clone the TensorFlow repository.

git clone https://github.com/tensorflow/tensorflow

  3. Build TOCO

Navigate to the TensorFlow repository directory and run the following command to build TOCO.

bazel build tensorflow/contrib/lite/toco:toco

  4. Convert the model

From the TensorFlow repository directory, run the following command to convert the model.

bazel-bin/tensorflow/contrib/lite/toco/toco \
  --input_file=model_path/mnist.pb \
  --input_format=TENSORFLOW_GRAPHDEF  --output_format=TFLITE \
  --output_file=output_path/mnist.tflite --inference_type=FLOAT \
  --input_type=FLOAT --input_arrays=x \
  --output_arrays=output --input_shapes=1,28,28,1

Replace model_path/mnist.pb with the path to the TensorFlow model trained in Step 1, and replace output_path/mnist.tflite with the path where the converted model should be saved.
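If you convert repeatedly, it can help to assemble the TOCO invocation in a script instead of editing the command by hand. The sketch below uses a hypothetical helper, `toco_argv`, that builds exactly the flags from the command above as an argument list suitable for `subprocess.run`. `--input_shapes=1,28,28,1` describes a batch of one 28x28 image with a single channel, and `x` and `output` are the input and output tensor names used in the command.

```python
import shlex
from pathlib import Path
from typing import List

# Binary produced by `bazel build tensorflow/contrib/lite/toco:toco`,
# relative to the TensorFlow repository root.
TOCO_BIN = Path("bazel-bin/tensorflow/contrib/lite/toco/toco")


def toco_argv(model_path: str, output_path: str) -> List[str]:
    """Build the TOCO command line from Option 1 for the MNIST model."""
    return [
        str(TOCO_BIN),
        "--input_file=" + model_path,
        "--input_format=TENSORFLOW_GRAPHDEF",
        "--output_format=TFLITE",
        "--output_file=" + output_path,
        "--inference_type=FLOAT",
        "--input_type=FLOAT",
        "--input_arrays=x",          # input tensor name in mnist.pb
        "--output_arrays=output",    # output tensor name in mnist.pb
        "--input_shapes=1,28,28,1",  # batch, height, width, channels
    ]


if __name__ == "__main__":
    argv = toco_argv("model_path/mnist.pb", "output_path/mnist.tflite")
    # Print a shell-safe version; to run it from the TensorFlow repo root,
    # use subprocess.run(argv, check=True).
    print(" ".join(shlex.quote(a) for a in argv))
```

Passing a list rather than a single shell string avoids quoting problems when the model paths contain spaces.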

Notice that the mnist.pb generated by mnist.py is already frozen, so we can skip the "freeze the graph" step and use it directly for the conversion.

More examples can be found here.

A converted TensorFlow Lite flatbuffer file can be downloaded from here.

Option 2. Use Python API

Instead of using the TOCO command line, we can also convert the model with the Python API.

python convert.py --model_dir ./saved_model --output_file ./mnist.tflite

convert.py restores the latest checkpoint from Step 1, freezes the graph and invokes tf.contrib.lite.toco_convert to convert the model.
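The source of convert.py isn't reproduced here. In TensorFlow 1.x, the usual way to locate the latest checkpoint is `tf.train.latest_checkpoint(model_dir)`, which reads the small text state file named `checkpoint` that the saver writes next to the checkpoint files. The stdlib-only sketch below only illustrates what that lookup involves. It assumes the default file name and the usual `model_checkpoint_path: "..."` line, and `latest_checkpoint_prefix` is a hypothetical name.

```python
import os
import re
from typing import Optional

# Matches e.g.:  model_checkpoint_path: "model.ckpt-10000"
_LATEST_RE = re.compile(r'^model_checkpoint_path:\s*"(?P<path>[^"]*)"\s*$', re.MULTILINE)


def latest_checkpoint_prefix(model_dir: str) -> Optional[str]:
    """Return the checkpoint prefix recorded in <model_dir>/checkpoint, or None."""
    state_file = os.path.join(model_dir, "checkpoint")
    if not os.path.isfile(state_file):
        return None
    with open(state_file, encoding="utf-8") as f:
        match = _LATEST_RE.search(f.read())
    if match is None:
        return None
    path = match.group("path")
    # Relative entries are resolved against the model directory.
    return path if os.path.isabs(path) else os.path.join(model_dir, path)
```

The returned value is a prefix (for example `./saved_model/model.ckpt-10000`), not a single file; the actual checkpoint is split across files that share this prefix.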

Step 3. Build Android app

Copy the mnist.tflite file from Step 2 to /android/app/src/main/assets, then build and run the app. A prebuilt APK can be downloaded from here.

The Classifier reads mnist.tflite from the assets directory and loads it into an Interpreter for inference. The Interpreter provides the interface between the TensorFlow Lite model and Java code, and is included in the following library.

implementation 'org.tensorflow:tensorflow-lite:0.1.1'
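Given the conversion flags (`--input_shapes=1,28,28,1`, `--inference_type=FLOAT`), the Interpreter expects a float input of shape [1][28][28][1] and, for MNIST, produces one score per digit. The Python sketch below mirrors the kind of preprocessing a Classifier might do before calling `Interpreter.run(input, output)` in Java. It is not the app's actual code: `argb_to_input` and `predicted_digit` are hypothetical, the pixel format assumes packed ARGB ints as returned by Android's `Bitmap.getPixels`, and whether colours need inverting depends on how the drawing canvas renders strokes.

```python
from typing import List, Sequence

IMG_SIZE = 28  # matches --input_shapes=1,28,28,1


def argb_to_input(pixels: Sequence[int], invert: bool = True) -> List[List[List[List[float]]]]:
    """Turn 28*28 packed ARGB ints into a [1][28][28][1] float tensor in [0, 1].

    MNIST digits are light strokes on a dark background, so invert=True is
    meant for a canvas that draws dark strokes on a light background.
    """
    if len(pixels) != IMG_SIZE * IMG_SIZE:
        raise ValueError("expected %d pixels, got %d" % (IMG_SIZE * IMG_SIZE, len(pixels)))
    image = []
    for row in range(IMG_SIZE):
        image_row = []
        for col in range(IMG_SIZE):
            p = pixels[row * IMG_SIZE + col]
            # Masking with 0xFF also works for Java's signed ints (e.g. -1 for white).
            r, g, b = (p >> 16) & 0xFF, (p >> 8) & 0xFF, p & 0xFF
            gray = (r + g + b) / 3.0 / 255.0
            image_row.append([1.0 - gray if invert else gray])
        image.append(image_row)
    return [image]  # leading batch dimension of 1


def predicted_digit(scores: Sequence[float]) -> int:
    """Index of the highest score in the model's 10-way output."""
    return max(range(len(scores)), key=lambda i: scores[i])
```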

If you are building your own app, remember to add the following to the android block of build.gradle so that model files are not compressed.

android {
    aaptOptions {
        noCompress "tflite"
        noCompress "lite"
    }
}
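To confirm the setting took effect, you can inspect the built APK, which is an ordinary ZIP archive. The reason it matters is an assumption about how the model is loaded: model files are commonly memory-mapped through `AssetManager.openFd()`, which fails for compressed assets. The sketch below uses a hypothetical helper, `compressed_models`, to list .tflite and .lite entries that were stored compressed.

```python
import sys
import zipfile
from typing import List

MODEL_SUFFIXES = (".tflite", ".lite")  # the extensions listed under noCompress


def compressed_models(apk_path: str) -> List[str]:
    """Return model entries in the APK that are not stored uncompressed."""
    with zipfile.ZipFile(apk_path) as apk:
        return [
            info.filename
            for info in apk.infolist()
            if info.filename.endswith(MODEL_SUFFIXES)
            and info.compress_type != zipfile.ZIP_STORED
        ]


if __name__ == "__main__":
    # Usage: python check_apk.py app-debug.apk
    for apk_path in sys.argv[1:]:
        for name in compressed_models(apk_path):
            print("%s: %s is compressed; check aaptOptions.noCompress" % (apk_path, name))
```

An empty result means every model file in the APK can be opened directly as an asset file descriptor.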

Credits
