NumPy MNIST

Module 3: NumPy

Exercise: Identify Hand-Drawn Numbers

In this exercise, we are going to use the MNIST dataset, a large database of handwritten digits commonly used for training image processing systems. Our goal is to create a “toy AI model” that can recognize the digit drawn by the user.

To complete this exercise, we need to install and import some additional libraries. Don’t worry if you don’t understand all these imports just yet. You will cover them in detail in future subjects.

# You DO NOT need to learn this cell

# Install necessary libraries (not included in Google Colab by default)
!pip install ipycanvas
# !pip install pillow  # Uncomment if required

# Import libraries
import matplotlib.pyplot as plt  # To draw graphs
import numpy as np
from tensorflow.keras.datasets import mnist  # To load the dataset
from PIL import Image  # To manage images
from ipycanvas import Canvas, hold_canvas  # Allows the user to draw
import ipywidgets as widgets  # To interact with the notebook
from IPython.display import display  # Update plots in real time

# Ask Google to let us draw on the notebook
from google.colab import output
output.enable_custom_widget_manager()  # Allows the user to draw online

Data Processing

# Load the MNIST dataset
dataset = mnist.load_data()

The data is organized as a nested tuple: (train_subset, test_subset), where each subset is itself a pair: (x_train, y_train) and (x_test, y_test). Here, x holds the data we have (the drawings) and y holds the value to predict (the digit that each drawing represents).

How can we extract these four arrays? We aim to obtain x_train, y_train, x_test, and y_test.

# Assuming your data is structured as (train_subset, test_subset)
train_subset, test_subset = dataset

# Unpacking train_subset and test_subset
x_train, y_train = train_subset
x_test, y_test = test_subset

# Now you have x_train, y_train, x_test, and y_test as separate variables

# Alternatively, load and unpack the MNIST dataset in one go
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# See the shape - do we understand these dimensions?
print(f"x_train shape: {x_train.shape}")
print(f"y_train shape: {y_train.shape}")

# Let's see the first data point
x = x_train[0]
y = y_train[0]

print(f"Number is {y}")
print(x)
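
To make these dimensions concrete, here is a small sketch that uses a synthetic stand-in for the dataset (random pixels, not real MNIST; the names `fake_x` and `fake_y` are made up for illustration). MNIST images are 28x28 grayscale arrays of `uint8` values between 0 and 255, so `x_train` is a 3D array (images, rows, columns) and `y_train` is a 1D array with one label per image.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for (x_train, y_train): 5 random "images" and 5 labels.
fake_x = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)
fake_y = rng.integers(0, 10, size=5, dtype=np.uint8)

print(fake_x.shape)     # (5, 28, 28): images, rows, columns
print(fake_y.shape)     # (5,): one label per image
print(fake_x[0].shape)  # (28, 28): indexing the first axis gives one image
print(fake_x.dtype)     # uint8: integers from 0 to 255
```

With the real data, the first number in `x_train.shape` is the number of training images instead of 5.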

Before we start building our “AI model”, let’s create a function to visualize the images from our dataset. This will help us better understand the data we are working with.

You do not need to understand the details of this function just yet. The function show_image will display an image and, if provided, the predicted digit.

# First, create a function to plot the images
# You DO NOT need to understand this cell

def show_image(x: np.ndarray, y: int=None) -> None:
  """
  Display an image with an optional title.

  Parameters
  ----------
  x : np.ndarray
    A 2D array representing the image. The image must be square (same number of rows and columns).
  y : int, optional
    The predicted digit to display as the title of the image. Default is None.

  Raises
  ------
  AssertionError
    If the image does not have 2 dimensions or is not square.
  """
  # Assertions are our way to stop the function when we detect an error.
  # If the condition of an assertion is False, the assert raises
  # an AssertionError (stopping the function).
  assert x.ndim == 2, "The image must have 2 dimensions"
  assert x.shape[0] == x.shape[1], "The image must be a square"
  # Display the image
  plt.imshow(x, cmap="gray")
  # Only add a title when a digit was provided
  if y is not None:
    plt.title(f"Predicted: {y}")
  plt.show()
  plt.close()

Now, let’s test the show_image function with the first element of the training subset. This will help us verify that our function works correctly and allows us to visualize the images and their corresponding labels.

# Try out the function with the first element of the train subset
idx = 0
x = x_train[idx]
y = y_train[idx]

show_image(x, y)

Build our AI Model

Time to start building our model!

First, we will compute the mean image for each digit in the training dataset. This involves calculating the average pixel values for all images corresponding to each digit. The resulting mean images will help us understand the general appearance of each digit.

# Compute the mean image for each digit
dict_mean = {}
for num in np.unique(y_train):
  dict_mean[num] = np.mean(x_train[y_train == num], axis=0)
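
The loop above relies on two NumPy features: boolean masking (`x_train[y_train == num]` keeps only the images with that label) and a mean along axis 0 (averaging across images, pixel by pixel). The following sketch shows both on four made-up 2x2 "images"; the names `images`, `labels`, and `mean_one` are illustrative only.

```python
import numpy as np

# Four tiny 2x2 "images" with labels 0, 1, 0, 1 (made-up values).
images = np.array([
    [[0, 0], [0, 0]],
    [[10, 10], [10, 10]],
    [[4, 4], [4, 4]],
    [[30, 30], [30, 30]],
], dtype=np.uint8)
labels = np.array([0, 1, 0, 1])

mask = labels == 1            # array([False,  True, False,  True])
ones = images[mask]           # shape (2, 2, 2): only the images labelled 1
mean_one = ones.mean(axis=0)  # average across images, pixel by pixel

print(ones.shape)  # (2, 2, 2)
print(mean_one)    # every pixel is (10 + 30) / 2 = 20.0
```

Note that the mean of `uint8` images is a floating-point array, so the mean images in `dict_mean` hold decimal values rather than integers.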

Next, let’s plot the mean image for one of the digits. This will help us visualize the average appearance of a specific digit based on the training data. In this example, we will plot the mean image for the digit 4.

# Plot one of these means
y = 4
x = dict_mean[y]

show_image(x, y)

We now have the “average look” of each digit. For instance, any number 5 should resemble the average 5 more closely than the average 2 or 6. To compute this similarity, we use the pairwise distance.

The pairwise distance \(d\) between two matrices \(A\) and \(B\) of the same shape is calculated using the following formula:

\[d = \sqrt{ \sum_{i} \sum_{j} {(A_{ij} - B_{ij})^{2}}}\]

This is the Euclidean distance between the two matrices when each one is treated as a long vector of pixel values: the smaller \(d\) is, the more alike the two images are.

def pairwise_distance(A: np.ndarray, B: np.ndarray) -> float:
  """
  Compute the pairwise Euclidean distance between two matrices.

  Parameters
  ----------
  A : np.ndarray
    The first matrix.
  B : np.ndarray
    The second matrix.

  Returns
  -------
  float
    The Euclidean distance between the two matrices.
  """
  # Ensure the matrices have the same shape
  assert A.shape == B.shape, "The input matrices must have the same shape"

  # Compute all pairwise (Aij - Bij)^2
  # Cast to float first: subtracting two uint8 images would wrap around
  d = np.power(A.astype(float) - B.astype(float), 2)

  # Add all of them together
  d = np.sum(d)

  # Compute the square root of that sum
  d = np.sqrt(d)

  return d
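
As a quick check of the formula, the sketch below computes the distance for two small made-up matrices by hand and compares it with `np.linalg.norm`, which returns the same quantity (the Frobenius norm) for a 2D difference. It also illustrates a caveat that applies when both inputs are `uint8` images: integer subtraction wraps around instead of going negative. In the exercise the mean images are floats, so the subtraction there is already done in floating point.

```python
import numpy as np

A = np.array([[1.0, 2.0], [3.0, 4.0]])
B = np.array([[1.0, 2.0], [0.0, 0.0]])

d_manual = np.sqrt(np.sum((A - B) ** 2))  # sqrt(0 + 0 + 9 + 16) = 5.0
d_norm = np.linalg.norm(A - B)            # Frobenius norm of the difference

print(d_manual, d_norm)  # 5.0 5.0

# Caveat: uint8 arithmetic wraps around modulo 256.
a = np.array([10], dtype=np.uint8)
b = np.array([20], dtype=np.uint8)
print(a - b)                              # [246]
print(a.astype(float) - b.astype(float))  # [-10.]
```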

Let’s now select a digit from the training set and compute the pairwise distance between this digit and the average of each digit.

idx = 0
x = x_train[idx]
y = y_train[idx]

for y_mean, x_mean in dict_mean.items():
    dist = pairwise_distance(x, x_mean)
    print(f"Target: {y} - Compared: {y_mean} - distance = {dist:.2f}")

We can build a function to compute the pairwise distance between the new image and the average look of each digit. Our function then identifies the digit with the closest average look and returns it as the predicted classification.

# Function to classify a new image
def classify_image(x: np.ndarray) -> int:
    # Find the closest mean image
    min_dist = np.inf
    for y_mean, x_mean in dict_mean.items():
        dist = pairwise_distance(x, x_mean)
        if dist < min_dist:
            min_dist = dist
            y_pred = y_mean
    return y_pred
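
The same nearest-mean rule can also be written without an explicit loop, by stacking the ten mean images into one array and letting NumPy broadcasting compute all distances at once. This is only a sketch of an alternative, not a replacement for the function above: `fake_means` is a hypothetical stand-in for `dict_mean` filled with random values, and `classify_vectorized` is a made-up name.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for dict_mean: one random 28x28 "mean image" per digit.
fake_means = {digit: rng.random((28, 28)) * 255 for digit in range(10)}

digits = np.array(list(fake_means.keys()))   # shape (10,)
means = np.stack(list(fake_means.values()))  # shape (10, 28, 28)

def classify_vectorized(x: np.ndarray) -> int:
    """Return the digit whose mean image is closest to x (Euclidean distance)."""
    diff = means - x.astype(float)                   # x is broadcast over the 10 means
    dists = np.sqrt(np.sum(diff ** 2, axis=(1, 2)))  # one distance per digit
    return int(digits[np.argmin(dists)])

# An image identical to the stand-in mean of digit 7 has distance 0 to it.
print(classify_vectorized(fake_means[7]))  # 7
```

`np.argmin` returns the position of the smallest distance, which plays the role of the `if dist < min_dist` bookkeeping in the loop version.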

Congratulations! You’ve just built your first AI model that uses real data to classify new information. 🎉

Now, let’s take a digit from the test subset and see if our model can predict it correctly.

idx = 3
x = x_test[idx]

y_pred = classify_image(x)

show_image(x, y_pred)

Score our Model

We can estimate how good our model is by computing its accuracy. Accuracy is a measure of how often the model correctly predicts the labels of the test data.

An accuracy of 1.00 (or 100%) means the model predicted every test sample correctly. An accuracy of 0.00 (or 0%) means the model did not predict any test sample correctly. Generally, a higher accuracy indicates a better-performing model (*).

To compute the accuracy of our model, we use the following formula:

\[ \text{Accuracy} = \frac{\text{Number of Correct Predictions}}{\text{Total Number of Predictions}} \]

_(*) But it is important to consider other metrics and the context of the problem for a comprehensive evaluation._

# Compute our model's accuracy
y_pred = []
for idx in range(len(y_test)):
    x = x_test[idx]
    y_pred.append(classify_image(x))

# Turn y_pred list into an array
y_pred = np.array(y_pred)
correct = y_pred == y_test
accuracy = np.mean(correct)
print(f"Accuracy: {accuracy:.2f}")

Our model has an accuracy of 82%. Looks very promising!
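
To see how the accuracy computation works, here is a minimal worked example with six made-up labels and predictions (`y_true` and `y_hat` are illustrative names, not part of the exercise). Comparing two arrays gives an array of booleans, and because `True` counts as 1 and `False` as 0, its mean is the fraction of correct predictions.

```python
import numpy as np

y_true = np.array([7, 2, 1, 0, 4, 1])  # made-up labels
y_hat = np.array([7, 2, 1, 0, 9, 7])   # made-up predictions

correct = y_hat == y_true  # [True, True, True, True, False, False]
accuracy = correct.mean()  # True counts as 1, False as 0

print(f"{correct.sum()} of {len(correct)} correct")  # 4 of 6 correct
print(f"Accuracy: {accuracy:.2f}")                   # Accuracy: 0.67

# Accuracy restricted to one class: how often is a true 1 recognized?
is_one = y_true == 1
print(f"Accuracy on digit 1: {correct[is_one].mean():.2f}")  # Accuracy on digit 1: 0.50
```

Looking at the accuracy per digit in this way is one simple example of the additional context mentioned in the note above.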


Test the Model Yourself!

Execute the following cell to generate a canvas where you can draw a digit. Once you are done, click on “Predict” to see what our model thinks your digit is!

You do not need to understand the cell, but as always I provide comments for you to follow the code’s logic.

# Define the color of your background
background = "black"
color = "white" if background == "black" else "black"

# Create an interactive canvas
canvas = Canvas(width=280, height=280, sync_image_data=True)

# Variable to track if the mouse button is pressed
is_drawing = False

# Function to handle mouse down event
def handle_mouse_down(x, y):
  """
  Handles the mouse down event by setting the is_drawing flag to True
  and initiating the drawing process.

  Parameters:
  x (int): The x-coordinate of the mouse event.
  y (int): The y-coordinate of the mouse event.
  """
  global is_drawing
  is_drawing = True
  handle_draw(x, y)

# Function to handle mouse up event
def handle_mouse_up(x, y):
  """
  Handles the mouse up event by setting the is_drawing flag to False.

  Parameters:
  x (int): The x-coordinate of the mouse event.
  y (int): The y-coordinate of the mouse event.
  """
  global is_drawing
  is_drawing = False

# Function to handle drawing on the canvas
def handle_draw(x, y):
  """
  Draws a circle on the canvas at the specified coordinates if the
  is_drawing flag is True.

  Parameters:
  x (int): The x-coordinate where the circle will be drawn.
  y (int): The y-coordinate where the circle will be drawn.
  """
  if is_drawing:
    with hold_canvas(canvas):
      canvas.fill_style = color
      canvas.fill_circle(x, y, 20)

# Bind the mouse events to their respective handlers
canvas.on_mouse_down(handle_mouse_down)
canvas.on_mouse_up(handle_mouse_up)
canvas.on_mouse_move(handle_draw)

def click_predict(change):
  """
  Handles the click event of the predict button. Captures the image
  from the canvas, processes it, and uses a model to predict the digit.
  Displays the predicted digit and clears the canvas.

  Parameters:
  change (widgets.Button): The button that was clicked (passed in by on_click, not used here).
  """
  # Get the image from the canvas
  image_data = canvas.get_image_data()
  image = Image.fromarray(image_data)

  # Convert to grayscale and resize to 28x28
  image = image.convert("L").resize((28, 28))

  # Convert to a numpy array
  image = np.array(image)

  # Predict the digit
  predicted_digit = classify_image(image)

  # Display the result
  plt.imshow(image, cmap="gray")
  plt.title(f"Predicted: {predicted_digit}")
  plt.show()
  plt.close()

  # Clear the canvas
  canvas.clear()

# Add a button to trigger the prediction
button = widgets.Button(description="Predict")
button.on_click(click_predict)

# Display the canvas and button
display(canvas)
display(button)
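
The preprocessing inside `click_predict` can be tried on its own, without a canvas. The sketch below assumes that `canvas.get_image_data()` returns an RGBA array of shape (height, width, 4), as the `Image.fromarray` call above suggests; `fake_canvas` is a synthetic stand-in with a single white stroke, not real canvas output.

```python
import numpy as np
from PIL import Image

# Synthetic stand-in for canvas.get_image_data(): 280x280 RGBA, all zeros.
fake_canvas = np.zeros((280, 280, 4), dtype=np.uint8)
# "Draw" an opaque white vertical stroke in the middle.
fake_canvas[40:240, 130:150] = [255, 255, 255, 255]

image = Image.fromarray(fake_canvas)         # mode "RGBA"
small = image.convert("L").resize((28, 28))  # grayscale, then downscale
pixels = np.array(small)

print(image.mode)    # RGBA
print(pixels.shape)  # (28, 28)
print(pixels.dtype)  # uint8
```

The result has the same shape and data type as one MNIST image, which is what `classify_image` expects.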

Improving the Model

How can we improve the model? A common first step to improve an AI model is to give it more data.

And we have the data! So far, our model has learnt only from the train subset. We can add the test subset to its learning and see if that improves its predictions.

# Concatenate train and test data
x_both = np.concatenate((x_train, x_test), axis=0)
y_both = np.concatenate((y_train, y_test), axis=0)

print(f"x_both shape: {x_both.shape}")
print(f"y_both shape: {y_both.shape}")

# Generate a new dictionary of means
dict_mean_new = {}
for num in np.unique(y_both):
  dict_mean_new[num] = np.mean(x_both[y_both == num], axis=0)

# Build a new prediction function
def classify_image_new(x: np.ndarray) -> int:
  # Find the closest mean image
  min_dist = np.inf
  for y_mean, x_mean in dict_mean_new.items():
    dist = pairwise_distance(x, x_mean)
    if dist < min_dist:
      min_dist = dist
      y_pred = y_mean
  return y_pred

# Compute our model's accuracy
y_pred = []
for idx in range(len(y_test)):
    x = x_test[idx]
    y_pred.append(classify_image_new(x))

# Turn y_pred list into an array
y_pred = np.array(y_pred)
correct = y_pred == y_test
accuracy = np.mean(correct)
print(f"Accuracy: {accuracy:.2f}")

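The model doesn’t seem to improve. This is because it has a very basic structure: it only compares each image to the mean image of every digit, and adding more samples changes those means only slightly. Note also that the test images are now part of the data the model learnt from, so this accuracy no longer measures performance on unseen data. In future courses, you will learn how to build much better models.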

Still, it was worth a try!
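
One way to see why extra data helps so little here is to measure how much a mean image moves when more samples are added. The sketch below uses random pixels as a synthetic stand-in (roughly 6,000 "training" and 1,000 "test" images, similar to the number of MNIST images per digit), so the exact numbers will differ from real digits; all variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in: 6000 "training" and 1000 "test" images of one digit.
train_imgs = rng.integers(0, 256, size=(6000, 28, 28), dtype=np.uint8)
test_imgs = rng.integers(0, 256, size=(1000, 28, 28), dtype=np.uint8)

mean_train = train_imgs.mean(axis=0)
mean_both = np.concatenate((train_imgs, test_imgs), axis=0).mean(axis=0)

# Largest change in any pixel of the mean image (on a 0-255 scale).
max_shift = np.abs(mean_both - mean_train).max()
print(f"Largest pixel shift: {max_shift:.2f}")
```

On the real data, the same comparison can be made between `dict_mean[num]` and `dict_mean_new[num]` for any digit `num`.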
