Presenting a simple neural network example

Thu, Jan 28, 2021 5-minute read

I’m a bit of a Python tragic. In an organisation full of R users, this means I’m often tasked with hosting Python introductions and training. These sessions outline what Python is, give examples of what it is used for, and provide the audience with internal and external resources.

For Python examples, I like to show both cutting-edge and more practical use cases. For the former, I’ve recently showcased OpenAI’s new DALL-E image generator and the very fun (and free!) text-to-voice generator from 15.ai. For the latter, I work through a short data science exercise to illustrate how Python is a simple and powerful tool for such tasks. What does this exercise look like?

The audience are generally junior economists: unfamiliar with Python but technically savvy, aware of data science concepts like neural networks, and familiar with code workflows. So, we run a worked example of training a neural network to do a simple – and, importantly, visual – classification task. The MNIST dataset is perfect.1

MNIST Example

import tensorflow as tf  # deep learning framework used to build and train the network
import numpy as np       # array handling for the image data

While MNIST is famous enough to have its own Wikipedia page, in short, it’s a collection of processed images of hand-written digits from 0 to 9, each tagged with the integer the writer intended to convey. To familiarise the group, I show a few example images (like those below) and ask the audience to ‘classify’ them by telling me the integer. When they succeed, I congratulate them on being a highly advanced neural network.

The MNIST Example Images
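For readers who want to follow along, the digits ship with Keras, so loading them and plotting a few examples takes only a handful of lines. Here’s a minimal sketch – the use of matplotlib and the choice of the first five training images are my own assumptions, not necessarily what the presentation notebook does:

# Load the MNIST digits: 60,000 training and 10,000 test images, each 28x28 pixels
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Plot the first five training digits with their true labels
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 5, figsize=(10, 2))
for ax, image, label in zip(axes, x_train[:5], y_train[:5]):
    ax.imshow(image, cmap='gray')
    ax.set_title(str(label))
    ax.axis('off')
plt.show()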

I then motivate our simple example. What if we were building some handwriting recognition software, and needed a small neural network to identify digits and return an integer?2 I explain that in just a few minutes, we can build, train, and test a neural network in Python that will classify these digits to a high accuracy. At this point I run a quick Q&A, asking the audience to tell me what they know about neural networks and how they work, and check whether they have any questions.

Build and train a neural network

We then begin working through a Jupyter Notebook, based loosely on the Google TensorFlow ‘Get Started’ example (my notebook is available on GitHub here). We do some simple pre-processing, and then set up a model. Once it is initialised – and I’ve reminded the audience that initialised models have random parameters – we ask it to classify the 5 example digits above. As expected, the untrained model puts out gibberish, as it’s no better at classifying than chance.
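The pre-processing and model definition follow the quickstart closely; a sketch along these lines (the exact layer sizes are the quickstart’s defaults and may differ from my notebook) looks like:

# Scale pixel values from 0–255 down to 0–1
x_train, x_test = x_train / 255.0, x_test / 255.0

# A small feed-forward network: flatten each 28x28 image, one hidden layer, ten output logits
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
])

# The freshly initialised model has random parameters, so its guesses are no better than chance
untrained_probs = tf.nn.softmax(model(x_train[:5])).numpy()
print(np.argmax(untrained_probs, axis=1))  # gibberish classifications for the five example digits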

Following that, we compile and train the model. I use these steps to discuss neural networks, drawing an analogy to OLS regression (which economists are very familiar with). We talk about defining a loss function, compared with minimising the sum of squared residuals in OLS. We talk about optimisers, and how numerical approaches are required for neural networks whereas analytical solutions are available for OLS regressions. These discussions also help us pass the time while the model trains (although for such a simple neural network, this takes only a minute or two even on old hardware).
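The compile and fit steps are just a couple of lines; something like the following, where the choice of the Adam optimiser and five epochs mirrors the quickstart rather than anything specific to this exercise:

# Cross-entropy loss plays the role the sum of squared residuals plays in OLS
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])

# Train for a few passes over the data – quick even on old hardware
model.fit(x_train, y_train, epochs=5)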

Once training is complete, we check the accuracy of the model (usually around 98 per cent), and also see if the model has correctly predicted the example images above, which it almost always has.
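Checking this takes one call to evaluate and one more pass over the example digits; roughly:

# Accuracy on the held-out test images – usually around 98 per cent
model.evaluate(x_test, y_test, verbose=2)

# Re-classify the five example digits from earlier; after training they should come back correct
trained_probs = tf.nn.softmax(model(x_train[:5])).numpy()
print(np.argmax(trained_probs, axis=1), y_train[:5])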

What does the model get wrong?

Although correct predictions are fun for everyone, I also find it enlightening to talk about the handwritten digits the model failed to predict correctly. I pull up a list of digits where the prediction failed. In the figure below, I show these with the correct classification in the bottom left and the model’s incorrect classification in the bottom right. This usually facilitates an interesting discussion about why the model got these digits wrong (see if you can identify which features the model mistook), and about how much better we are at understanding and interpreting handwriting than the model is.
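Pulling up the failures is a simple comparison of predicted and true labels – again a sketch rather than the exact code in my notebook:

# Indices of the test digits the trained model still gets wrong
test_preds = np.argmax(model.predict(x_test), axis=1)
misclassified = np.where(test_preds != y_test)[0]
print(f"{len(misclassified)} of {len(y_test)} test digits misclassified")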

Incorrect Predictions (bottom LHS is actual, bottom RHS is model prediction)

A useful extension to this exercise might be to look through the most commonly misclassified MNIST digits to see some that are truly ambiguous due to very poor or unorthodox handwriting (although it looks like someone on GitHub is ahead of me). This might lead to discussions about shortcomings of the dataset, which raises a broader question.

Is this example worthwhile?

Attentive audience members might ask why the MNIST dataset is comparatively easy to classify. At the end of the exercise, I emphasise that this data is available for free, is pre-cleaned, has no noise, and so on. It also has a very well-behaved dimensional structure (that is, the different digit classes occupy quite distinct regions of the feature space – see this excellent blog post). Overall, this makes MNIST quite an unrealistic example for data science amateurs.

A few people online have complained about the MNIST example for this reason – some more dramatically than others. Although the dataset isn’t perfect, a visual, simple and computationally light data science example can be an engaging and thought-provoking exercise to demonstrate to an unfamiliar audience. In future, I might consider moving to the alternative Fashion-MNIST dataset, or amending the presentation to focus more on MNIST’s shortcomings.


  1. If you haven’t already, it’s important to check if any of your audience members have a visual impairment before including an exercise along these lines.↩︎

  2. I also mention that this is not that far from our actual work – text and pattern recognition are increasingly being used in economic applications, like digitising old datasets and archived material.↩︎