MNIST Neural Network

Neural Network Example


Opening Comments

This article showcases my handwritten digit recognition AI, built on the MNIST dataset. I am not an expert in machine learning, but I will try my best to explain the basics, which means I won’t go into the really complicated parts. I highly recommend using all the sources linked below. I built this project without any machine learning libraries, in pure C#, and displayed it in the Unity game engine. The code will be posted in a GitHub repo on my account.

Introduction

Imagine a 28 by 28 grid of grayscale pixels, 784 in total. Now imagine that the grid is a drawing of a digit between 0 and 9, for instance a 3.

MNIST Example

Then imagine that there are thousands of these pixel grids, each of which has to be classified as the correct digit (MNIST itself contains 60,000 training images and 10,000 test images). How would you write a program that recognizes one drawing correctly while also being accurate on all the others? This is the typical “Hello World” of neural network problems.

Basic Summary

The procedure is as follows. First, you parse the MNIST dataset to get the input data: each image is a grayscale picture, with shades running from black to white. The 784 pixel values become 784 input neurons, stored in an array called a layer. Each neuron has an activation, a value between 0 and 1 (black to white).

Gradient: 0 -> 1
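The parsing step above can be sketched as follows. This is a minimal Python sketch, not the author's C# code, assuming the standard IDX format that MNIST is distributed in: a header of big-endian 32-bit integers followed by one unsigned byte (0 to 255) per pixel or label. The function and constant names are hypothetical.

```python
import struct

IMAGE_MAGIC = 2051  # 0x00000803: IDX file of unsigned bytes with 3 dimensions
LABEL_MAGIC = 2049  # 0x00000801: IDX file of unsigned bytes with 1 dimension

def parse_idx_images(data: bytes) -> list[list[float]]:
    """Parse an IDX image file into one flat list of rows * cols activations
    per image (784 for MNIST), scaled from 0..255 to 0.0..1.0."""
    # Header: magic number, image count, rows, columns (big-endian 32-bit ints).
    magic, count, rows, cols = struct.unpack(">IIII", data[:16])
    if magic != IMAGE_MAGIC:
        raise ValueError(f"not an IDX image file (magic number {magic})")
    size = rows * cols  # 28 * 28 = 784 for MNIST
    images = []
    for i in range(count):
        start = 16 + i * size
        # Pixels are stored row by row, one unsigned byte each.
        images.append([p / 255 for p in data[start:start + size]])
    return images

def parse_idx_labels(data: bytes) -> list[int]:
    """Parse an IDX label file into a list of digits 0..9."""
    # Header: magic number, label count (big-endian 32-bit ints).
    magic, count = struct.unpack(">II", data[:8])
    if magic != LABEL_MAGIC:
        raise ValueError(f"not an IDX label file (magic number {magic})")
    return list(data[8:8 + count])
```

Because the pixels are stored row by row, the pixel at row `r` and column `c` of a 28 by 28 image ends up at index `r * 28 + c` of its 784-element input layer. Loading the decompressed training set would then look like `parse_idx_images(open("train-images-idx3-ubyte", "rb").read())`.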

These values determine the values of the neurons in the following layers. The value of a neuron outside the input layer is computed from the activations of the layer before it. Each connection between a neuron in the previous layer and a neuron in the current layer has a float value called a weight. You multiply each previous-layer activation by its weight, sum all the weighted products, and add a float value called the bias. This sum is then passed through a function that squashes it into an activation. I used a sigmoid, which maps any input to a value between 0 and 1. The function is below:

\[ \text{Sigmoid: } \sigma(x) = \frac{1}{1 + e^{-x}} \]

The sigmoid is S-shaped: it approaches 0 for large negative inputs and 1 for large positive inputs, but never actually reaches either value.
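Putting the weighted sum, bias, and sigmoid together, the activation of neuron \(j\) in layer \(l\) is

\[ a_j^{(l)} = \sigma\left( \sum_k w_{jk}^{(l)} \, a_k^{(l-1)} + b_j^{(l)} \right) \]

where \(w_{jk}^{(l)}\) is the weight from neuron \(k\) of the previous layer and \(b_j^{(l)}\) is the bias. Below is a minimal Python sketch of that forward pass (the original project is C#). It uses plain lists, and the function names are hypothetical rather than taken from the author's code.

```python
import math

def sigmoid(x: float) -> float:
    """1 / (1 + e^-x), arranged so that math.exp never overflows."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)  # x < 0, so e is below 1 and cannot overflow
    return e / (1.0 + e)

def layer_forward(prev_activations, weights, biases):
    """Activations of one layer.

    weights[j][k] connects neuron k of the previous layer to neuron j of
    this layer; biases[j] is the bias of neuron j."""
    activations = []
    for row, bias in zip(weights, biases):
        # Weighted sum of the previous layer's activations, plus the bias.
        z = sum(w * a for w, a in zip(row, prev_activations)) + bias
        activations.append(sigmoid(z))
    return activations

def network_forward(inputs, layers):
    """Feed the input activations through each (weights, biases) layer in order."""
    activations = inputs
    for weights, biases in layers:
        activations = layer_forward(activations, weights, biases)
    return activations
```

The `x >= 0` branch is a common numerical trick: computing `math.exp(-x)` directly for a very negative `x` raises `OverflowError`, so the negative case uses the equivalent form \(e^x / (1 + e^x)\).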

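When it is first initialized with random weights and biases, the network performs terribly: its accuracy would probably be around 10%, no better than guessing one of the ten digits at random. To fix this, you make the network “learn”. Learning starts with a cost, a number measuring how bad the network is, followed by adjusting the weights and biases to reduce that cost so the network performs better on its next iteration. Backpropagation is the process of working out how much each weight and bias contributes to the cost, so that each one can be nudged in the direction that lowers it.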
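The learning step described above can be sketched as backpropagation followed by a gradient descent update. This is a minimal Python sketch under stated assumptions: it uses a squared-error cost (the sum of squared differences between outputs and desired outputs, as in the 3Blue1Brown series linked below) and trains on one example at a time. The article does not say which cost function or batch size the project uses, and all names here are hypothetical.

```python
import math

def sigmoid(x):
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def forward(inputs, layers):
    """Return the activations of every layer, starting with the inputs."""
    activations = [inputs]
    for weights, biases in layers:
        prev = activations[-1]
        activations.append([
            sigmoid(sum(w * a for w, a in zip(row, prev)) + b)
            for row, b in zip(weights, biases)
        ])
    return activations

def cost(outputs, targets):
    """Squared-error cost (an assumed choice): sum of (output - target)^2."""
    return sum((a - y) ** 2 for a, y in zip(outputs, targets))

def backprop_step(inputs, targets, layers, learning_rate):
    """One gradient descent step on a single example.

    Updates the weights and biases in place and returns the cost
    measured before the update."""
    activations = forward(inputs, layers)
    outputs = activations[-1]
    # Output layer: dC/dz_j = 2(a_j - y_j) * sigmoid'(z_j), and sigmoid'(z) = a(1 - a).
    deltas = [2.0 * (a - y) * a * (1.0 - a) for a, y in zip(outputs, targets)]
    for l in range(len(layers) - 1, -1, -1):
        weights, biases = layers[l]
        prev = activations[l]
        # Chain rule back to the previous layer, using this layer's weights
        # *before* they are changed (unused once we reach the input layer).
        prev_deltas = [
            sum(weights[j][k] * deltas[j] for j in range(len(weights))) * a * (1.0 - a)
            for k, a in enumerate(prev)
        ]
        for j, delta in enumerate(deltas):
            for k, a in enumerate(prev):
                weights[j][k] -= learning_rate * delta * a  # dC/dw_jk = delta_j * a_k
            biases[j] -= learning_rate * delta              # dC/db_j = delta_j
        deltas = prev_deltas
    return cost(outputs, targets)
```

For MNIST, `targets` would be a list of ten values with 1.0 at the index of the correct digit and 0.0 everywhere else. Repeating this step over many examples is what gradually pushes the accuracy up from the roughly 10% of a random network.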

My Project

My network has an accuracy of 95.85%, which is good for a hobby project (state-of-the-art models exceed 99%). For example, the network can identify hard examples like this one:

Hard Number to Categorize

with reasonable accuracy. In case it isn't obvious, this is supposed to be a 4. Without hard examples like this one in the training data, the network would struggle to identify badly written digits; they make the network more robust.
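An accuracy figure like the one above is usually computed by taking the most active output neuron as the network's answer and counting how often it matches the label. The sketch below assumes that convention (the article does not describe how the 95.85% was measured), and the function names are hypothetical.

```python
def predict(outputs):
    """The predicted digit is the index of the most active output neuron."""
    return max(range(len(outputs)), key=lambda i: outputs[i])

def accuracy(all_outputs, labels):
    """Fraction of examples whose predicted digit matches the label."""
    correct = sum(1 for out, label in zip(all_outputs, labels) if predict(out) == label)
    return correct / len(labels)
```

If the figure was measured on MNIST's 10,000-image test set, 95.85% would mean 9,585 images classified correctly.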

Conclusion

The process of making a computer “learn” is fascinating. It is amazing how a set of simple rules can take a mathematical function from random guessing to near-perfect accuracy. I am proud of building this network without any libraries, but I still have a long way to go before becoming an expert. I plan to work on more projects like this in the future.

About Using AI

I did use AI for this project: ChatGPT helped me with most of the code. At the time, I did not know the mathematics the project requires, such as calculus and some linear algebra topics I hadn't learned yet. I really wanted to build a neural network, however, so I used AI to help me (kind of funny, I know). Whenever I use AI, I make sure I understand the code it gives me. Still, that is why I can't really call myself an expert. In the future I want to be able to code hard projects like this myself, and I think with effort I will. I will disclose my use of AI in every project where I use it. With that, thank you for reading.

References and Learning Resources

These four videos are an amazing series by 3Blue1Brown.

https://www.youtube.com/watch


https://www.youtube.com/watch


https://www.youtube.com/watch


https://www.youtube.com/watch

For parsing the MNIST dataset.

https://www.youtube.com/watch

The 3Blue1Brown series in article form.

https://www.3blue1brown.com


I would recommend using the Wayback Machine to get the MNIST dataset.
