MNIST Neural Network

Book: (http://neuralnetworksanddeeplearning.com/index.html) Neural Networks and Deep Learning by Michael Nielsen.

Source code: https://zoo.cs.yale.edu/classes/cs370/aima/neural-networks-and-deep-learning/src/

From (https://github.com/mnielsen/neural-networks-and-deep-learning.git)

MNIST data: https://zoo.cs.yale.edu/classes/cs370/aima/neural-networks-and-deep-learning/data/ (compressed pickle format)

Import MNIST data

In [1]:
import mnist_loader
training_data, validation_data, test_data = mnist_loader.load_data_wrapper()
In [2]:
len(training_data)
Out[2]:
50000
In [3]:
len(test_data)
Out[3]:
10000
In [4]:
len(validation_data)
Out[4]:
10000
In [5]:
training_data[0]
Out[5]:
(array([[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.01171875],
        [0.0703125 ],
        [0.0703125 ],
        [0.0703125 ],
        [0.4921875 ],
        [0.53125   ],
        [0.68359375],
        [0.1015625 ],
        [0.6484375 ],
        [0.99609375],
        [0.96484375],
        [0.49609375],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.1171875 ],
        [0.140625  ],
        [0.3671875 ],
        [0.6015625 ],
        [0.6640625 ],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.87890625],
        [0.671875  ],
        [0.98828125],
        [0.9453125 ],
        [0.76171875],
        [0.25      ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.19140625],
        [0.9296875 ],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.98046875],
        [0.36328125],
        [0.3203125 ],
        [0.3203125 ],
        [0.21875   ],
        [0.15234375],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.0703125 ],
        [0.85546875],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.7734375 ],
        [0.7109375 ],
        [0.96484375],
        [0.94140625],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.3125    ],
        [0.609375  ],
        [0.41796875],
        [0.98828125],
        [0.98828125],
        [0.80078125],
        [0.04296875],
        [0.        ],
        [0.16796875],
        [0.6015625 ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.0546875 ],
        [0.00390625],
        [0.6015625 ],
        [0.98828125],
        [0.3515625 ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.54296875],
        [0.98828125],
        [0.7421875 ],
        [0.0078125 ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.04296875],
        [0.7421875 ],
        [0.98828125],
        [0.2734375 ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.13671875],
        [0.94140625],
        [0.87890625],
        [0.625     ],
        [0.421875  ],
        [0.00390625],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.31640625],
        [0.9375    ],
        [0.98828125],
        [0.98828125],
        [0.46484375],
        [0.09765625],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.17578125],
        [0.7265625 ],
        [0.98828125],
        [0.98828125],
        [0.5859375 ],
        [0.10546875],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.0625    ],
        [0.36328125],
        [0.984375  ],
        [0.98828125],
        [0.73046875],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.97265625],
        [0.98828125],
        [0.97265625],
        [0.25      ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.1796875 ],
        [0.5078125 ],
        [0.71484375],
        [0.98828125],
        [0.98828125],
        [0.80859375],
        [0.0078125 ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.15234375],
        [0.578125  ],
        [0.89453125],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.9765625 ],
        [0.7109375 ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.09375   ],
        [0.4453125 ],
        [0.86328125],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.78515625],
        [0.3046875 ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.08984375],
        [0.2578125 ],
        [0.83203125],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.7734375 ],
        [0.31640625],
        [0.0078125 ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.0703125 ],
        [0.66796875],
        [0.85546875],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.76171875],
        [0.3125    ],
        [0.03515625],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.21484375],
        [0.671875  ],
        [0.8828125 ],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.953125  ],
        [0.51953125],
        [0.04296875],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.53125   ],
        [0.98828125],
        [0.98828125],
        [0.98828125],
        [0.828125  ],
        [0.52734375],
        [0.515625  ],
        [0.0625    ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]], dtype=float32),
 array([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.]]))

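Each training example is a pair. The first element is a 784 x 1 column vector holding the pixels of a 28 x 28 image, flattened. The second is a 10 x 1 one-hot vector encoding the label. Here the 1 sits at index 5, so training_data[0] is a numeral 5.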
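
As a rough illustration of how such a pair can be read, the sketch below decodes a one-hot label and regroups a flat pixel column into 28 rows. Plain Python lists stand in for the NumPy arrays returned by the loader, and the helper names decode_one_hot and to_rows are hypothetical, not part of mnist_loader.

```python
def decode_one_hot(label):
    """Return the index of the largest entry in a (10, 1) column vector."""
    flat = [row[0] for row in label]
    return flat.index(max(flat))


def to_rows(pixels, width=28):
    """Regroup a flat (784, 1) pixel column into rows of `width` values."""
    flat = [row[0] for row in pixels]
    return [flat[i:i + width] for i in range(0, len(flat), width)]


# Label copied from the output above: the 1 is at index 5.
label = [[0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]]
digit = decode_one_hot(label)
print(digit)  # 5

# A hypothetical all-zero image, used only to show the regrouping.
blank = [[0.0] for _ in range(784)]
rows = to_rows(blank)
print(len(rows), len(rows[0]))  # 28 28
```

With the real data, the same idea would apply to training_data[0][1] and training_data[0][0].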

We are going to train a neural network on these data. It will have an input layer of 784 neurons, one per pixel, and a hidden layer of 30 neurons. (The picture shows 15 neurons in the hidden layer.) The output layer has 10 neurons, corresponding to the digits 0 - 9. Ideally only one output neuron is strongly activated for a given image; the network's prediction is the digit whose output neuron has the highest activation.
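
To make the layer structure concrete, here is a minimal pure-Python sketch of a forward pass through a 784-30-10 network. It assumes sigmoid activations and Gaussian random initial weights and biases; the helpers make_layer and feedforward are illustrative stand-ins, not the code in network.py.

```python
import math
import random


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def make_layer(n_in, n_out, rng):
    """Random weights (n_out rows of n_in values) and biases for one layer."""
    weights = [[rng.gauss(0, 1) for _ in range(n_in)] for _ in range(n_out)]
    biases = [rng.gauss(0, 1) for _ in range(n_out)]
    return weights, biases


def feedforward(layers, a):
    """Apply a = sigmoid(W a + b) layer by layer."""
    for weights, biases in layers:
        a = [sigmoid(sum(w * x for w, x in zip(row, a)) + b)
             for row, b in zip(weights, biases)]
    return a


rng = random.Random(0)
sizes = [784, 30, 10]
layers = [make_layer(n_in, n_out, rng)
          for n_in, n_out in zip(sizes[:-1], sizes[1:])]

# A hypothetical blank image as input; an untrained network gives an arbitrary answer.
output = feedforward(layers, [0.0] * 784)
prediction = output.index(max(output))
print(len(output), prediction)
```

The predicted digit is simply the index of the largest of the 10 output activations.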

[Figure: diagram of the network's input, hidden, and output layers]
In [6]:
import network
In [7]:
net = network.Network([784,30,10])
In [8]:
dir(net)
Out[8]:
['SGD',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'backprop',
 'biases',
 'cost_derivative',
 'evaluate',
 'feedforward',
 'num_layers',
 'sizes',
 'update_mini_batch',
 'weights']

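The SGD method is Stochastic Gradient Descent. See (https://en.wikipedia.org/wiki/Stochastic_gradient_descent). It minimizes an objective function, here the network's cost on the training data, by repeatedly adjusting the weights and biases using gradients estimated from small random batches of examples.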
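
The idea can be seen on a toy problem far smaller than MNIST. The sketch below fits a single slope w in y = w * x by mini-batch SGD on a squared-error cost. The function sgd_fit_slope and its data are hypothetical; only the parameter names mirror those of the SGD method.

```python
import random


def sgd_fit_slope(data, epochs, mini_batch_size, eta, seed=0):
    """Fit y = w * x by mini-batch SGD on the cost 0.5 * (w*x - y)**2."""
    rng = random.Random(seed)
    data = list(data)
    w = 0.0
    for _ in range(epochs):
        rng.shuffle(data)  # a fresh random order each epoch
        for k in range(0, len(data), mini_batch_size):
            batch = data[k:k + mini_batch_size]
            # Average gradient of the cost over the mini-batch.
            grad = sum((w * x - y) * x for x, y in batch) / len(batch)
            w -= eta * grad  # step against the gradient
    return w


# Ten points on the line y = 2x, with x from 0.1 to 1.0.
points = [(x / 10.0, 2.0 * x / 10.0) for x in range(1, 11)]
w = sgd_fit_slope(points, epochs=200, mini_batch_size=2, eta=0.5)
print(round(w, 6))  # 2.0
```

Each update uses only two points, yet the estimate still settles on the true slope because the batch gradients point the right way on average.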

We invoke SGD on the training data. Its signature and docstring are:

def SGD(self, training_data, epochs, mini_batch_size, eta,
        test_data=None):
    """Train the neural network using mini-batch stochastic
    gradient descent.  The ``training_data`` is a list of tuples
    ``(x, y)`` representing the training inputs and the desired
    outputs.  The other non-optional parameters are
    self-explanatory.  If ``test_data`` is provided then the
    network will be evaluated against the test data after each
    epoch, and partial progress printed out.  This is useful for
    tracking progress, but slows things down substantially."""

mini_batch_size is the number of training examples used for each gradient update. Below we process 10 examples at a time, so each epoch over the 50,000 training examples makes 5,000 updates.
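
A minimal sketch of how one epoch's data might be cut into mini-batches, assuming the data are shuffled and then sliced in order. The helper make_mini_batches is hypothetical, and integers stand in for the (x, y) pairs.

```python
import random


def make_mini_batches(training_data, mini_batch_size, rng):
    """Shuffle a copy of the data and cut it into consecutive slices."""
    shuffled = list(training_data)
    rng.shuffle(shuffled)
    return [shuffled[k:k + mini_batch_size]
            for k in range(0, len(shuffled), mini_batch_size)]


# Integers stand in for the 50,000 (x, y) training pairs.
examples = range(50000)
batches = make_mini_batches(examples, 10, random.Random(0))
print(len(batches))     # 5000
print(len(batches[0]))  # 10
```

Every example lands in exactly one batch, so an epoch still sees the whole training set once.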

eta stands for the Greek letter η, the learning rate, which scales the size of each update step. Below the learning rate is 3.0.
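
In the book's notation, the mini-batch update rule can be written as follows, where m is the mini-batch size and C_x is the cost on a single training example x. This is a sketch of the standard rule rather than a transcription of update_mini_batch.

```latex
w \;\rightarrow\; w - \frac{\eta}{m} \sum_{x \in \text{batch}} \frac{\partial C_x}{\partial w},
\qquad
b \;\rightarrow\; b - \frac{\eta}{m} \sum_{x \in \text{batch}} \frac{\partial C_x}{\partial b}
```

With η = 3.0 and m = 10, each step moves the weights and biases by 0.3 times the summed gradient over the batch.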

Timing is done through calls to time.time() before and after each epoch. Below, each epoch takes less than 3 seconds.
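
The timing pattern looks roughly like the sketch below. The wrapper timed and the workload fake_epoch are hypothetical stand-ins; a real epoch would run the mini-batch updates instead.

```python
import time


def timed(fn, *args):
    """Call fn(*args) and return its result with the elapsed wall-clock time."""
    start = time.time()
    result = fn(*args)
    elapsed = time.time() - start
    return result, elapsed


def fake_epoch(n):
    """Stand-in workload for one epoch of training."""
    return sum(i * i for i in range(n))


result, elapsed = timed(fake_epoch, 100000)
print("Epoch 0: took {:.2f} seconds".format(elapsed))
```

time.time() measures wall-clock time, so the figures also include whatever else the machine was doing during the run.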

In [9]:
net.SGD(training_data, 30, 10, 3.0, test_data=test_data)
Epoch 0: 9061 / 10000, took 2.40 seconds
Epoch 1: 9242 / 10000, took 2.73 seconds
Epoch 2: 9272 / 10000, took 2.44 seconds
Epoch 3: 9317 / 10000, took 2.37 seconds
Epoch 4: 9391 / 10000, took 2.44 seconds
Epoch 5: 9373 / 10000, took 2.39 seconds
Epoch 6: 9386 / 10000, took 2.39 seconds
Epoch 7: 9417 / 10000, took 2.37 seconds
Epoch 8: 9446 / 10000, took 2.44 seconds
Epoch 9: 9442 / 10000, took 2.65 seconds
Epoch 10: 9412 / 10000, took 2.43 seconds
Epoch 11: 9409 / 10000, took 2.39 seconds
Epoch 12: 9450 / 10000, took 2.57 seconds
Epoch 13: 9471 / 10000, took 2.56 seconds
Epoch 14: 9472 / 10000, took 2.56 seconds
Epoch 15: 9477 / 10000, took 2.37 seconds
Epoch 16: 9466 / 10000, took 2.40 seconds
Epoch 17: 9480 / 10000, took 2.37 seconds
Epoch 18: 9498 / 10000, took 2.51 seconds
Epoch 19: 9488 / 10000, took 2.43 seconds
Epoch 20: 9496 / 10000, took 2.43 seconds
Epoch 21: 9480 / 10000, took 2.38 seconds
Epoch 22: 9503 / 10000, took 2.45 seconds
Epoch 23: 9513 / 10000, took 2.41 seconds
Epoch 24: 9500 / 10000, took 2.43 seconds
Epoch 25: 9494 / 10000, took 2.38 seconds
Epoch 26: 9518 / 10000, took 2.40 seconds
Epoch 27: 9506 / 10000, took 2.41 seconds
Epoch 28: 9504 / 10000, took 2.42 seconds
Epoch 29: 9509 / 10000, took 2.44 seconds

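Now we change the number of hidden neurons to 100, keeping the number of epochs and the other parameters the same. The time per epoch rises to about 7.6 seconds, and in this run the test accuracy ends lower: 8835 / 10000 after the last epoch, against 9509 / 10000 with 30 hidden neurons.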
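
One plausible reason for the longer epochs is the larger number of weights and biases to update. The sketch below counts them for both layouts, assuming fully connected layers with one bias per non-input neuron; count_parameters is a hypothetical helper.

```python
def count_parameters(sizes):
    """Count weights and biases in a fully connected network."""
    weights = sum(n_in * n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))
    biases = sum(sizes[1:])  # the input layer has no biases
    return weights + biases


small = count_parameters([784, 30, 10])
large = count_parameters([784, 100, 10])
print(small)                    # 23860
print(large)                    # 79510
print(round(large / small, 2))  # 3.33
```

The parameter count grows by a factor of about 3.3, which is in the same range as the increase in time per epoch.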

In [10]:
net = network.Network([784, 100, 10])
In [11]:
net.SGD(training_data, 30, 10, 3.0, test_data=test_data)
Epoch 0: 6718 / 10000, took 7.61 seconds
Epoch 1: 6775 / 10000, took 7.61 seconds
Epoch 2: 6927 / 10000, took 7.60 seconds
Epoch 3: 7795 / 10000, took 7.67 seconds
Epoch 4: 8645 / 10000, took 7.57 seconds
Epoch 5: 8692 / 10000, took 7.62 seconds
Epoch 6: 8707 / 10000, took 7.59 seconds
Epoch 7: 8711 / 10000, took 7.59 seconds
Epoch 8: 8748 / 10000, took 7.54 seconds
Epoch 9: 8710 / 10000, took 7.70 seconds
Epoch 10: 8754 / 10000, took 7.55 seconds
Epoch 11: 8762 / 10000, took 7.63 seconds
Epoch 12: 8771 / 10000, took 7.62 seconds
Epoch 13: 8773 / 10000, took 7.66 seconds
Epoch 14: 8785 / 10000, took 7.68 seconds
Epoch 15: 8794 / 10000, took 7.66 seconds
Epoch 16: 8789 / 10000, took 7.60 seconds
Epoch 17: 8797 / 10000, took 7.61 seconds
Epoch 18: 8788 / 10000, took 7.65 seconds
Epoch 19: 8809 / 10000, took 7.70 seconds
Epoch 20: 8793 / 10000, took 7.59 seconds
Epoch 21: 8817 / 10000, took 7.69 seconds
Epoch 22: 8813 / 10000, took 7.63 seconds
Epoch 23: 8809 / 10000, took 7.60 seconds
Epoch 24: 8811 / 10000, took 7.65 seconds
Epoch 25: 8832 / 10000, took 7.65 seconds
Epoch 26: 8816 / 10000, took 7.95 seconds
Epoch 27: 8806 / 10000, took 7.69 seconds
Epoch 28: 8826 / 10000, took 7.70 seconds
Epoch 29: 8835 / 10000, took 7.62 seconds