TensorFlow.js

Training on Images: Recognizing Handwritten Digits with a Convolutional Neural Network

In this tutorial, we'll build a TensorFlow.js model to classify handwritten digits with a convolutional neural network. First, we'll train the classifier by having it “look” at thousands of handwritten digit images and their labels. Then we'll evaluate the classifier's accuracy using test data that the model has never seen.

Prerequisites

This tutorial assumes familiarity with the fundamental building blocks of TensorFlow.js (tensors, variables, and ops), as well as the concepts of optimization and loss. For more background on these topics, we recommend completing the following tutorials before this tutorial:

Running the Code

The full code for this tutorial can be found in the tfjs-examples/mnist directory in the TensorFlow.js examples repository.

To run the code locally, you need the following dependencies installed:

These instructions use Yarn, but the npm CLI works as well if you prefer it.

You can run the code for the example by cloning the repo and building the demo:

$ git clone https://github.com/tensorflow/tfjs-examples
$ cd tfjs-examples/mnist
$ yarn
$ yarn watch

The tfjs-examples/mnist directory above is completely standalone, so you can copy it to start your own project.

NOTE: The difference between this tutorial's code and the tfjs-examples/mnist-core example is that here we use TensorFlow.js's higher-level APIs (Model, Layers) to construct the model, whereas mnist-core uses lower-level linear algebra ops to build a neural network.

Data

We will use the MNIST handwriting dataset for this tutorial. The handwritten MNIST digits we will learn to classify look like this:

[Images: sample MNIST digits 4, 3, and 8]

To preprocess our data, we've written data.js, which contains the class MnistData that fetches random batches of MNIST images from a hosted version of the MNIST dataset we provide for convenience.

MnistData splits the entire dataset into training data and test data. When we train the model, the classifier will see only the training set. When we evaluate the model, we'll use only the data in the test set, which the model has not yet seen, to see how well the model's predictions generalize to brand-new data.

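MnistData has two public methods:

nextTrainBatch(batchSize): returns a random batch of images and their labels from the training set.

nextTestBatch(batchSize): returns a batch of images and their labels from the test set.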

NOTE: When training the MNIST classifier, it's important to randomly shuffle the data, so the model's predictions aren't affected by the order in which we feed it images. For example, if we were to feed the model all the 1 digits first, during this phase of training, the model might learn to simply predict 1 (since this minimizes the loss). If we were to then feed the model only 2s, it might simply switch to predicting only 2 and never predict a 1 (since, again, this would minimize loss for the new set of images). The model would never learn to make an accurate prediction over a representative sample of digits.
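As a minimal sketch of the shuffling idea (not the actual code in data.js), the example below uses a Fisher-Yates shuffle to produce a random visiting order for example indices. The helper name shuffledIndices is hypothetical.

```javascript
// Hypothetical helper: returns the indices 0..n-1 in random order
// using the Fisher-Yates shuffle.
function shuffledIndices(n) {
  const indices = Array.from({length: n}, (_, i) => i);
  for (let i = n - 1; i > 0; i--) {
    // Pick a random position j in [0, i] and swap it with position i.
    const j = Math.floor(Math.random() * (i + 1));
    [indices[i], indices[j]] = [indices[j], indices[i]];
  }
  return indices;
}

const order = shuffledIndices(10);
console.log(order);  // a random permutation of 0..9
```

Drawing batches in this order mixes all digit classes together, so no single class dominates a stretch of training.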

Building the Model

In this section, we'll build a convolutional image classifier model. To do so, we'll use a Sequential model (the simplest type of model), in which tensors are consecutively passed from one layer to the next.

First, let's instantiate our Sequential model with tf.sequential:

const model = tf.sequential();

Now that we've created a model, let's add layers to it.

Adding the First Layer

The first layer we’ll add is a two-dimensional convolutional layer. Convolutions slide a filter window over an image to learn transformations that are spatially invariant (that is, patterns or objects in different parts of the image will be treated the same way). For more information about convolutions, see this article.
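To make the sliding-window idea concrete, here is a small sketch of a single-filter convolution in plain JavaScript. It assumes a square single-channel image, a stride of 1, and no padding. The function name conv2dValid and the numbers are made up for illustration, and this is not how TensorFlow.js implements the operation internally.

```javascript
// Sketch of a single-filter 2-D convolution with stride 1 and no padding.
// Assumes a square image and a square kernel.
function conv2dValid(image, kernel) {
  const k = kernel.length;
  const outSize = image.length - k + 1;
  const out = [];
  for (let r = 0; r < outSize; r++) {
    const row = [];
    for (let c = 0; c < outSize; c++) {
      // Multiply the window at (r, c) element-wise with the kernel and sum.
      let sum = 0;
      for (let i = 0; i < k; i++) {
        for (let j = 0; j < k; j++) {
          sum += image[r + i][c + j] * kernel[i][j];
        }
      }
      row.push(sum);
    }
    out.push(row);
  }
  return out;
}

const image = [
  [1, 2, 0],
  [0, 1, 3],
  [2, 1, 1],
];
const kernel = [
  [1, 0],
  [0, -1],
];
const featureMap = conv2dValid(image, kernel);
console.log(featureMap);  // [ [ 0, -1 ], [ -1, 0 ] ]
```

The same kernel values are reused at every position, which is why a pattern is treated the same way wherever it appears in the image. By the same arithmetic, a 28x28 image and a 5x5 kernel would give a 24x24 output.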

We can create our 2-D convolutional layer using tf.layers.conv2d, which accepts a configuration object that defines the layer's structure:

model.add(tf.layers.conv2d({
  inputShape: [28, 28, 1],
  kernelSize: 5,
  filters: 8,
  strides: 1,
  activation: 'relu',
  kernelInitializer: 'VarianceScaling'
}));

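Let’s break down each argument in the configuration object:

inputShape: The shape of the data that will flow into the first layer of the model. Our MNIST examples are 28x28-pixel grayscale images, so the shape is [28, 28, 1]: 28 rows, 28 columns, and a depth of 1 because there is only one color channel.

kernelSize: The size of the sliding convolutional filter window. A kernelSize of 5 specifies a square 5x5 window.

filters: The number of filter windows of size kernelSize to apply to the input data. Here, we apply 8 filters.

strides: The step size of the sliding window, that is, how many pixels the filter shifts each time it moves over the image. With strides of 1, the filter slides in steps of 1 pixel.

activation: The activation function applied to the data after the convolution. Here, we use a Rectified Linear Unit (ReLU), a very common activation function.

kernelInitializer: The method used to randomly initialize the model weights. Here, we use VarianceScaling.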

Adding the Second Layer

Let’s add a second layer to the model: a max pooling layer, which we'll create using tf.layers.maxPooling2d. This layer will downsample the result (also known as the activation) from the convolution by computing the maximum value for each sliding window:

model.add(tf.layers.maxPooling2d({
  poolSize: [2, 2],
  strides: [2, 2]
}));

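Let’s break down the arguments:

poolSize: The size of the sliding pooling window applied to the input. A poolSize of [2, 2] means the pooling layer applies 2x2 windows.

strides: The step size of the sliding pooling window. Strides of [2, 2] mean the window moves 2 pixels at a time in both the horizontal and vertical directions.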

NOTE: Since both poolSize and strides are 2x2, the pooling windows will be completely non-overlapping. This means the pooling layer will halve both the height and the width of the activation from the previous layer.
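The following sketch illustrates non-overlapping 2x2 max pooling on a small single-channel array. The function name maxPool2x2 and the input values are made up for illustration, and this is not the TensorFlow.js implementation.

```javascript
// Sketch of 2x2 max pooling with stride 2 on a single-channel 2-D array.
function maxPool2x2(input) {
  const output = [];
  for (let r = 0; r + 1 < input.length; r += 2) {
    const row = [];
    for (let c = 0; c + 1 < input[r].length; c += 2) {
      // Keep only the largest value in each 2x2 window.
      row.push(Math.max(
          input[r][c], input[r][c + 1],
          input[r + 1][c], input[r + 1][c + 1]));
    }
    output.push(row);
  }
  return output;
}

const activation = [
  [1, 3, 2, 0],
  [4, 2, 1, 1],
  [0, 1, 5, 6],
  [2, 2, 7, 3],
];
const pooled = maxPool2x2(activation);
console.log(pooled);  // [ [ 4, 2 ], [ 2, 7 ] ]
```

The 4x4 input becomes a 2x2 output: each side is halved, so the pooled activation holds a quarter as many values.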

Adding the Remaining Layers

Repeating layer structure is a common pattern in neural networks. Let's add a second convolutional layer to our model, followed by another pooling layer. Note that in our second convolutional layer, we're doubling the number of filters from 8 to 16. Also note that we don't specify an inputShape, as it can be inferred from the shape of the output from the previous layer:

model.add(tf.layers.conv2d({
  kernelSize: 5,
  filters: 16,
  strides: 1,
  activation: 'relu',
  kernelInitializer: 'VarianceScaling'
}));

model.add(tf.layers.maxPooling2d({
  poolSize: [2, 2],
  strides: [2, 2]
}));

Next, let's add a flatten layer to flatten the output of the previous layer to a vector:

model.add(tf.layers.flatten());
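As a rough check of the shapes flowing through the layers so far, the sketch below traces [height, width, channels] through the stack. It assumes the convolutions use no padding (which we believe is the tf.layers.conv2d default); the helpers conv and pool are hypothetical and only do the shape arithmetic.

```javascript
// Hypothetical helpers that compute output shapes only.
// Convolution with stride 1 and no padding.
const conv = ([h, w], kernel, filters) =>
    [h - kernel + 1, w - kernel + 1, filters];
// Max pooling with a 2x2 window and stride 2.
const pool = ([h, w, c]) => [Math.floor(h / 2), Math.floor(w / 2), c];

let shape = [28, 28, 1];
shape = conv(shape, 5, 8);   // [24, 24, 8]
shape = pool(shape);         // [12, 12, 8]
shape = conv(shape, 5, 16);  // [8, 8, 16]
shape = pool(shape);         // [4, 4, 16]

// Flattening turns the 3-D activation into a single vector.
const flattenedLength = shape[0] * shape[1] * shape[2];
console.log(shape, flattenedLength);  // [ 4, 4, 16 ] 256
```

Under these assumptions, the flatten layer hands a vector of 256 values per image to whatever layer comes next.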

Lastly, let's add a dense layer (also known as a fully connected layer), which will perform the final classification. Flattening the output of a convolution+pooling layer pair before a dense layer is another common pattern in neural networks:

model.add(tf.layers.dense({
  units: 10,
  kernelInitializer: 'VarianceScaling',
  activation: 'softmax'
}));

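Let’s break down the arguments passed to the dense layer:

units: The size of the output activation. Since this is the final layer and we're classifying digits into 10 classes (0 through 9), we use 10 units.

kernelInitializer: The same VarianceScaling initialization strategy we used for the convolutional layers.

activation: The activation function of the last layer for a classification task is usually softmax, which normalizes the 10 output values into a probability distribution over the 10 classes.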
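To illustrate what the softmax activation does to the 10 raw outputs of the dense layer, here is a minimal plain-JavaScript sketch on three made-up scores. The function is written for illustration and is not the TensorFlow.js implementation.

```javascript
// Sketch of softmax: turns raw scores into probabilities that sum to 1.
function softmax(scores) {
  // Subtracting the max does not change the result but avoids overflow.
  const maxScore = Math.max(...scores);
  const exps = scores.map((s) => Math.exp(s - maxScore));
  const total = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / total);
}

const probs = softmax([2.0, 1.0, 0.1]);
console.log(probs.map((p) => p.toFixed(3)));  // [ '0.659', '0.242', '0.099' ]
```

The largest score receives the largest probability, and the outputs can be read as the model's confidence in each class.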

Training the Model

To actually drive training of the model, we'll need to construct an optimizer and define a loss function. We'll also define an evaluation metric to measure how well our model performs on the data.

NOTE: For a deeper dive into optimizers and loss functions in TensorFlow.js, see the tutorial Training First Steps.

Defining the Optimizer

For our convolutional neural network model, we'll use a stochastic gradient descent (SGD) optimizer with a learning rate of 0.15:

const LEARNING_RATE = 0.15;
const optimizer = tf.train.sgd(LEARNING_RATE);
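For intuition about what the learning rate controls, the sketch below applies one plain SGD update by hand. The weights and gradients are made-up values, and sgdStep is a hypothetical helper rather than part of TensorFlow.js.

```javascript
// Sketch of one SGD update: move each weight against its gradient,
// scaled by the learning rate.
const LEARNING_RATE = 0.15;

function sgdStep(weights, gradients, learningRate) {
  return weights.map((w, i) => w - learningRate * gradients[i]);
}

const weights = [0.5, -0.2];
const gradients = [1.0, -2.0];  // made-up gradient values for illustration
const updated = sgdStep(weights, gradients, LEARNING_RATE);
console.log(updated);  // approximately [0.35, 0.1]
```

A larger learning rate takes bigger steps per update, which can speed up training but can also overshoot a minimum.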

Defining Loss

For our loss function, we'll use cross-entropy (categoricalCrossentropy), which is commonly used to optimize classification tasks. categoricalCrossentropy measures the error between the probability distribution generated by the last layer of our model and the probability distribution given by our label, which will be a distribution with 1 (100%) in the correct class label. For example, given the following label and prediction values for an example of the digit 7:

class 0 1 2 3 4 5 6 7 8 9
label 0 0 0 0 0 0 0 1 0 0
prediction .1 .01 .01 .01 .20 .01 .01 .60 .03 .02

categoricalCrossentropy gives a lower loss value if the prediction is a high probability that the digit is 7, and a higher loss value if the prediction is a low probability of 7. During training, the model will update its internal parameters to minimize categoricalCrossentropy over the whole dataset.
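As a worked example, the sketch below computes the cross-entropy for the label and prediction in the table above by hand. It is a simplified version for a single example; the library's own implementation may differ in details such as clipping and averaging over a batch.

```javascript
// Sketch of categorical cross-entropy for one example:
// loss = -sum(label[i] * log(prediction[i])).
function categoricalCrossentropy(label, prediction) {
  let loss = 0;
  for (let i = 0; i < label.length; i++) {
    // Only classes with a nonzero label contribute to the sum.
    if (label[i] > 0) loss -= label[i] * Math.log(prediction[i]);
  }
  return loss;
}

const label = [0, 0, 0, 0, 0, 0, 0, 1, 0, 0];
const prediction = [.1, .01, .01, .01, .20, .01, .01, .60, .03, .02];
const loss = categoricalCrossentropy(label, prediction);
console.log(loss.toFixed(3));  // 0.511
```

Because the label is 1 only for class 7, the loss reduces to -log(0.60). A more confident correct prediction, such as 0.95 for class 7, would give a loss of about 0.051.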

Defining the Evaluation Metric

For our evaluation metric, we'll use accuracy, which measures the percentage of correct predictions out of all predictions.
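A minimal sketch of the accuracy computation, assuming one-hot labels and per-class predicted probabilities: a prediction counts as correct when its most probable class matches the labeled class. The helper names and the 3-class data are made up for illustration.

```javascript
// Index of the largest value, i.e. the predicted (or labeled) class.
function argMax(values) {
  return values.indexOf(Math.max(...values));
}

// Fraction of examples whose predicted class matches the labeled class.
function accuracy(labels, predictions) {
  let correct = 0;
  for (let i = 0; i < labels.length; i++) {
    if (argMax(labels[i]) === argMax(predictions[i])) correct++;
  }
  return correct / labels.length;
}

// Made-up one-hot labels and predicted probabilities for a 3-class problem.
const labels = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]];
const predictions = [
  [0.7, 0.2, 0.1],  // correct
  [0.1, 0.8, 0.1],  // correct
  [0.3, 0.5, 0.2],  // wrong: predicts class 1, label is class 2
  [0.2, 0.6, 0.2],  // correct
];
const acc = accuracy(labels, predictions);
console.log(acc);  // 0.75
```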

Compiling the Model

To compile the model, we pass it a configuration object with our optimizer, loss function, and a list of evaluation metrics (here, just 'accuracy'):

model.compile({
  optimizer: optimizer,
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy'],
});

Configuring Batch Size

Before we begin training, we need to define a few more parameters related to batch size:

// How many examples the model should "see" before making a parameter update.
const BATCH_SIZE = 64;
// How many batches to train the model for.
const TRAIN_BATCHES = 100;

// Every TEST_ITERATION_FREQUENCY batches, test accuracy over TEST_BATCH_SIZE examples.
// Ideally, we'd compute accuracy over the whole test set, but for performance
// reasons we'll use a subset.
const TEST_BATCH_SIZE = 1000;
const TEST_ITERATION_FREQUENCY = 5;
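To see what these constants imply for a training run, the following sketch works out how many training examples are drawn in total and how many times test accuracy is evaluated, assuming the loop counts batches from 0 and evaluates whenever the batch index is divisible by TEST_ITERATION_FREQUENCY.

```javascript
const BATCH_SIZE = 64;
const TRAIN_BATCHES = 100;
const TEST_ITERATION_FREQUENCY = 5;

// Total training examples drawn across all batches.
const examplesSeen = BATCH_SIZE * TRAIN_BATCHES;

// Count the batch indices i in [0, TRAIN_BATCHES) where an evaluation runs.
let evaluations = 0;
for (let i = 0; i < TRAIN_BATCHES; i++) {
  if (i % TEST_ITERATION_FREQUENCY === 0) evaluations++;
}

console.log(examplesSeen, evaluations);  // 6400 20
```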

More About Batching and Batch Size

To take full advantage of the GPU's ability to parallelize computation, we want to batch multiple inputs together and feed them through the network using a single feed-forward call.

Another reason we batch our computation is that during optimization, we update internal parameters (taking a step) only after averaging gradients from several examples. This helps us avoid taking a step in the wrong direction because of an outlier example (e.g., a mislabeled digit).
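A small numeric sketch of why averaging helps, using made-up per-example gradients for a single weight where the last value plays the role of an outlier:

```javascript
// Made-up per-example gradients for one weight; the last one is an outlier
// (for example, from a mislabeled digit).
const perExampleGradients = [1, 1, 1, 1, 1, 1, 1, -5];

const mean = (values) => values.reduce((a, b) => a + b, 0) / values.length;

// The batch gradient is the average over all examples in the batch.
const batchGradient = mean(perExampleGradients);
console.log(batchGradient);  // 0.25
```

Stepping on the outlier alone would move the weight in the opposite direction from the other seven examples, whereas the averaged gradient keeps the sign that the majority agrees on.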

When batching input data, we introduce a tensor of rank D+1, where D is the dimensionality of a single input.

As discussed earlier, the shape of a single image in our MNIST dataset is [28, 28, 1]. When we set a BATCH_SIZE of 64, we're batching 64 images at a time, which means the actual shape of our data is [64, 28, 28, 1] (the batch is always the outermost dimension).

NOTE: Recall that the inputShape in the config of our first conv2d did not specify the batch size (64). Configs are written to be batch-size-agnostic, so that they are able to accept batches of arbitrary size.
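The shape arithmetic can be sketched in plain JavaScript: prepending the batch size to the shape of one image gives a tensor shape of rank D+1.

```javascript
const imageShape = [28, 28, 1];  // shape of a single image, rank 3
const BATCH_SIZE = 64;

// The batch dimension is always outermost, so it goes first.
const batchShape = [BATCH_SIZE, ...imageShape];

// Total number of values in one batch.
const numValues = batchShape.reduce((a, b) => a * b, 1);

console.log(batchShape, numValues);  // [ 64, 28, 28, 1 ] 50176
```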

Coding the Training Loop

Here is the code for the training loop:

for (let i = 0; i < TRAIN_BATCHES; i++) {
  const batch = data.nextTrainBatch(BATCH_SIZE);
 
  let testBatch;
  let validationData;
  // Every few batches test the accuracy of the model.
  if (i % TEST_ITERATION_FREQUENCY === 0) {
    testBatch = data.nextTestBatch(TEST_BATCH_SIZE);
    validationData = [
      testBatch.xs.reshape([TEST_BATCH_SIZE, 28, 28, 1]), testBatch.labels
    ];
  }
 
  // The entire dataset doesn't fit into memory so we call fit repeatedly
  // with batches.
  const history = await model.fit(
      batch.xs.reshape([BATCH_SIZE, 28, 28, 1]),
      batch.labels,
      {
        batchSize: BATCH_SIZE,
        validationData,
        epochs: 1
      });

  const loss = history.history.loss[0];
  const accuracy = history.history.acc[0];

  // ... plotting code ...
}

Let's break the code down. First, we fetch a batch of training examples. Recall from above that we batch examples to take advantage of GPU parallelization and to average evidence from many examples before making a parameter update:

const batch = data.nextTrainBatch(BATCH_SIZE);

Every 5 steps (our TEST_ITERATION_FREQUENCY), we construct validationData, an array of two elements containing a batch of MNIST images from the test set and their corresponding labels. We'll use this data to evaluate the accuracy of the model:

if (i % TEST_ITERATION_FREQUENCY === 0) {
  testBatch = data.nextTestBatch(TEST_BATCH_SIZE);
  validationData = [
    testBatch.xs.reshape([TEST_BATCH_SIZE, 28, 28, 1]),
    testBatch.labels
  ];
}

model.fit is where the model is trained and parameters actually get updated.

NOTE: Calling model.fit() once on the whole dataset will result in uploading the whole dataset to the GPU, which could freeze the application. To avoid uploading too much data to the GPU, we recommend calling model.fit() inside a for loop, passing a single batch of data at a time, as shown below:

// The entire dataset doesn't fit into memory so we call fit repeatedly
// with batches.
const history = await model.fit(
    batch.xs.reshape([BATCH_SIZE, 28, 28, 1]), batch.labels,
    {batchSize: BATCH_SIZE, validationData: validationData, epochs: 1});

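Let's break down the arguments again:

x: Our input image data, reshaped to [BATCH_SIZE, 28, 28, 1], the shape the model expects for a batch of images.

y: Our labels, the correct digit classification for each image in the batch.

batchSize: How many images to include in each training batch. We set BATCH_SIZE to 64 earlier.

validationData: The validation set we construct every TEST_ITERATION_FREQUENCY (here, 5) batches. On the other iterations it is undefined, and no validation is run.

epochs: The number of training passes to perform on a batch. Since we are feeding batches to fit one at a time, we want it to train on each batch only once.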

Each time we call fit, it returns a rich object that contains logs of our metrics, which we store in history. We extract our loss and accuracy for each training iteration, so we can plot them on a graph:

const loss = history.history.loss[0];
const accuracy = history.history.acc[0];

See the Results!

If you run the full code, you should see output like this:

Two plots. The first plot shows loss vs. training batch, with loss trending downward from batch 0 to batch 100. The second plot shows accuracy vs. test batch, with accuracy trending upward from batch 0 to batch 100

It looks like the model is predicting the right digit for most of the images. Great work!

Additional Resources