MNIST Neural Network
Book: Neural Networks and Deep Learning by Michael Nielsen (http://neuralnetworksanddeeplearning.com/index.html).
Source code: https://zoo.cs.yale.edu/classes/cs370/aima/neural-networks-and-deep-learning/src/
Original repository: https://github.com/mnielsen/neural-networks-and-deep-learning.git
MNIST data: https://zoo.cs.yale.edu/classes/cs370/aima/neural-networks-and-deep-learning/data/ (compressed pickle format)
Import MNIST data
import mnist_loader
training_data, validation_data, test_data = mnist_loader.load_data_wrapper()
len(training_data)
50000
len(test_data)
10000
len(validation_data)
10000
training_data[0]
(array([[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.01171875],
[0.0703125 ],
[0.0703125 ],
[0.0703125 ],
[0.4921875 ],
[0.53125 ],
[0.68359375],
[0.1015625 ],
[0.6484375 ],
[0.99609375],
[0.96484375],
[0.49609375],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.1171875 ],
[0.140625 ],
[0.3671875 ],
[0.6015625 ],
[0.6640625 ],
[0.98828125],
[0.98828125],
[0.98828125],
[0.98828125],
[0.98828125],
[0.87890625],
[0.671875 ],
[0.98828125],
[0.9453125 ],
[0.76171875],
[0.25 ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.19140625],
[0.9296875 ],
[0.98828125],
[0.98828125],
[0.98828125],
[0.98828125],
[0.98828125],
[0.98828125],
[0.98828125],
[0.98828125],
[0.98046875],
[0.36328125],
[0.3203125 ],
[0.3203125 ],
[0.21875 ],
[0.15234375],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.0703125 ],
[0.85546875],
[0.98828125],
[0.98828125],
[0.98828125],
[0.98828125],
[0.98828125],
[0.7734375 ],
[0.7109375 ],
[0.96484375],
[0.94140625],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.3125 ],
[0.609375 ],
[0.41796875],
[0.98828125],
[0.98828125],
[0.80078125],
[0.04296875],
[0. ],
[0.16796875],
[0.6015625 ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.0546875 ],
[0.00390625],
[0.6015625 ],
[0.98828125],
[0.3515625 ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.54296875],
[0.98828125],
[0.7421875 ],
[0.0078125 ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.04296875],
[0.7421875 ],
[0.98828125],
[0.2734375 ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.13671875],
[0.94140625],
[0.87890625],
[0.625 ],
[0.421875 ],
[0.00390625],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.31640625],
[0.9375 ],
[0.98828125],
[0.98828125],
[0.46484375],
[0.09765625],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.17578125],
[0.7265625 ],
[0.98828125],
[0.98828125],
[0.5859375 ],
[0.10546875],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.0625 ],
[0.36328125],
[0.984375 ],
[0.98828125],
[0.73046875],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.97265625],
[0.98828125],
[0.97265625],
[0.25 ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.1796875 ],
[0.5078125 ],
[0.71484375],
[0.98828125],
[0.98828125],
[0.80859375],
[0.0078125 ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.15234375],
[0.578125 ],
[0.89453125],
[0.98828125],
[0.98828125],
[0.98828125],
[0.9765625 ],
[0.7109375 ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.09375 ],
[0.4453125 ],
[0.86328125],
[0.98828125],
[0.98828125],
[0.98828125],
[0.98828125],
[0.78515625],
[0.3046875 ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.08984375],
[0.2578125 ],
[0.83203125],
[0.98828125],
[0.98828125],
[0.98828125],
[0.98828125],
[0.7734375 ],
[0.31640625],
[0.0078125 ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.0703125 ],
[0.66796875],
[0.85546875],
[0.98828125],
[0.98828125],
[0.98828125],
[0.98828125],
[0.76171875],
[0.3125 ],
[0.03515625],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.21484375],
[0.671875 ],
[0.8828125 ],
[0.98828125],
[0.98828125],
[0.98828125],
[0.98828125],
[0.953125 ],
[0.51953125],
[0.04296875],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0.53125 ],
[0.98828125],
[0.98828125],
[0.98828125],
[0.828125 ],
[0.52734375],
[0.515625 ],
[0.0625 ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ],
[0. ]], dtype=float32),
array([[0.],
[0.],
[0.],
[0.],
[0.],
[1.],
[0.],
[0.],
[0.],
[0.]]))
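Each training example is a tuple (x, y). x is a 784 x 1 column vector holding the flattened 28 x 28 image, with pixel intensities between 0 and 1. y is a 10 x 1 one-hot vector marking the digit. In training_data[0] the 1 sits at index 5, so the image is a numeral 5.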
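To make that layout concrete, here is a minimal sketch that builds a stand-in example with the same shapes as training_data[0], reshapes the column back into a 28 x 28 grid, and reads the digit off the one-hot label. The pixel values are synthetic, not MNIST data, and the names x, y, image and digit are illustrative only.

```python
import numpy as np

# Stand-in for one training example: same shapes as training_data[0],
# but the pixel values are synthetic, not real MNIST data.
x = np.zeros((784, 1), dtype=np.float32)
x[152:164] = 0.5                      # a short run of arbitrary "ink" pixels
y = np.zeros((10, 1))
y[5] = 1.0                            # one-hot label for the digit 5

image = x.reshape(28, 28)             # row-major: 28 rows of 28 pixels
digit = int(np.argmax(y))             # position of the 1 in the one-hot vector

print(image.shape)                    # (28, 28)
print(digit)                          # 5
```

With the real data, the same two lines (reshape and argmax) could be applied to the elements of training_data[0].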
We are going to train a neural network on these data. It has an input layer of 784 neurons (one per pixel) and a hidden layer of 30 neurons. (The picture shows 15 neurons in the hidden layer.) The output layer has 10 neurons, corresponding to the digits 0 - 9. The network's answer is the digit whose output neuron has the highest activation; ideally only that neuron is strongly activated.
import network
net = network.Network([784,30,10])
dir(net)
['SGD', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'backprop', 'biases', 'cost_derivative', 'evaluate', 'feedforward', 'num_layers', 'sizes', 'update_mini_batch', 'weights']
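The weights, biases and feedforward attributes listed above suggest how a [784, 30, 10] network turns an image into ten output activations. The following is a minimal sketch of that idea, assuming sigmoid neurons and randomly initialised parameters; it is written for illustration and is not the code in network.py. The input is random noise standing in for an image.

```python
import numpy as np

def sigmoid(z):
    """Logistic function, applied elementwise."""
    return 1.0 / (1.0 + np.exp(-z))

sizes = [784, 30, 10]
rng = np.random.default_rng(0)        # fixed seed so the run is repeatable

# One bias vector and one weight matrix per non-input layer.
biases = [rng.standard_normal((n, 1)) for n in sizes[1:]]
weights = [rng.standard_normal((n_out, n_in))
           for n_in, n_out in zip(sizes[:-1], sizes[1:])]

def feedforward(a):
    """Propagate a 784 x 1 input column through every layer."""
    for w, b in zip(weights, biases):
        a = sigmoid(w @ a + b)
    return a

x = rng.random((784, 1))              # random stand-in for an image
output = feedforward(x)
print([w.shape for w in weights])     # [(30, 784), (10, 30)]
print(output.shape)                   # (10, 1)
print(int(np.argmax(output)))         # index of the most active output neuron
```

An untrained network like this one gives an essentially arbitrary answer; training adjusts the weights and biases so that the most active output neuron matches the label.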
The SGD method implements stochastic gradient descent (see https://en.wikipedia.org/wiki/Stochastic_gradient_descent). It iteratively adjusts the weights and biases to minimize an objective function, here the network's cost on the training data.
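In its mini-batch form, the update can be written as below. This is the standard textbook rule, given as a reminder rather than quoted from network.py. Here $w$ and $b$ stand for any single weight and bias, $C_x$ is the cost on one training example $x$, $m$ is the number of examples in the mini-batch, and $\eta$ is the learning rate.

```latex
w \leftarrow w - \frac{\eta}{m} \sum_{x \in \text{batch}} \frac{\partial C_x}{\partial w},
\qquad
b \leftarrow b - \frac{\eta}{m} \sum_{x \in \text{batch}} \frac{\partial C_x}{\partial b}
```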
We invoke SGD on the training data. Its signature and docstring are:
def SGD(self, training_data, epochs, mini_batch_size, eta,
        test_data=None):
    """Train the neural network using mini-batch stochastic
    gradient descent.  The ``training_data`` is a list of tuples
    ``(x, y)`` representing the training inputs and the desired
    outputs.  The other non-optional parameters are
    self-explanatory.  If ``test_data`` is provided then the
    network will be evaluated against the test data after each
    epoch, and partial progress printed out.  This is useful for
    tracking progress, but slows things down substantially."""
mini_batch_size is the number of training examples used for each gradient update. Below we process 10 examples at a time.
eta (the Greek letter η) is the learning rate, which scales the size of each update. Below the learning rate is 3.0.
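The roles of epochs, mini_batch_size and eta are easier to see on a toy problem. The sketch below runs mini-batch SGD on a one-parameter model y = w * x with a squared-error cost. Everything in it (the data, the model, the function name sgd, and the learning rate of 1.0) is an assumption chosen for illustration; it is not the network's own training code.

```python
import random

random.seed(0)

# Toy data for the 1-D model y = w * x, generated with a true weight of 2.0.
data = [(k / 100.0, 2.0 * k / 100.0) for k in range(1, 101)]

def sgd(data, epochs, mini_batch_size, eta):
    """Mini-batch SGD on the squared error 0.5 * (w*x - y)**2."""
    data = list(data)                 # work on a copy; shuffling is in place
    w = 0.0
    for _ in range(epochs):
        random.shuffle(data)
        # Cut the shuffled data into consecutive mini-batches.
        batches = [data[k:k + mini_batch_size]
                   for k in range(0, len(data), mini_batch_size)]
        for batch in batches:
            # Gradient of the cost with respect to w, averaged over the batch.
            grad = sum((w * x - y) * x for x, y in batch) / len(batch)
            w -= eta * grad           # one update per mini-batch
    return w

w = sgd(data, epochs=30, mini_batch_size=10, eta=1.0)
print(round(w, 3))                    # close to the true weight 2.0
```

With 100 examples and a mini-batch size of 10, each epoch performs 10 updates. By the same arithmetic, 50000 training examples in batches of 10 give 5000 updates per epoch.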
Timing is done through calls to time.time() before and after each epoch. Below, each epoch takes between 2.3 and 2.8 seconds.
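The timing pattern itself is short. The following sketch wraps an arbitrary call between two time.time() readings; the helper name timed and the stand-in workload are hypothetical, and how network.py places its own calls is not shown here.

```python
import time

def timed(fn, *args, **kwargs):
    """Run fn and return (result, elapsed wall-clock seconds)."""
    start = time.time()
    result = fn(*args, **kwargs)
    elapsed = time.time() - start
    return result, elapsed

# Stand-in workload; in the notebook this would be one epoch of training.
total, seconds = timed(sum, range(1_000_000))
print("took {:.2f} seconds".format(seconds))
```

time.time() reports wall-clock time, so the figures include anything else the machine was doing. time.perf_counter() from the same module is a monotonic alternative for short measurements.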
net.SGD(training_data, 30, 10, 3.0, test_data=test_data)
Epoch 0: 9061 / 10000, took 2.40 seconds
Epoch 1: 9242 / 10000, took 2.73 seconds
Epoch 2: 9272 / 10000, took 2.44 seconds
Epoch 3: 9317 / 10000, took 2.37 seconds
Epoch 4: 9391 / 10000, took 2.44 seconds
Epoch 5: 9373 / 10000, took 2.39 seconds
Epoch 6: 9386 / 10000, took 2.39 seconds
Epoch 7: 9417 / 10000, took 2.37 seconds
Epoch 8: 9446 / 10000, took 2.44 seconds
Epoch 9: 9442 / 10000, took 2.65 seconds
Epoch 10: 9412 / 10000, took 2.43 seconds
Epoch 11: 9409 / 10000, took 2.39 seconds
Epoch 12: 9450 / 10000, took 2.57 seconds
Epoch 13: 9471 / 10000, took 2.56 seconds
Epoch 14: 9472 / 10000, took 2.56 seconds
Epoch 15: 9477 / 10000, took 2.37 seconds
Epoch 16: 9466 / 10000, took 2.40 seconds
Epoch 17: 9480 / 10000, took 2.37 seconds
Epoch 18: 9498 / 10000, took 2.51 seconds
Epoch 19: 9488 / 10000, took 2.43 seconds
Epoch 20: 9496 / 10000, took 2.43 seconds
Epoch 21: 9480 / 10000, took 2.38 seconds
Epoch 22: 9503 / 10000, took 2.45 seconds
Epoch 23: 9513 / 10000, took 2.41 seconds
Epoch 24: 9500 / 10000, took 2.43 seconds
Epoch 25: 9494 / 10000, took 2.38 seconds
Epoch 26: 9518 / 10000, took 2.40 seconds
Epoch 27: 9506 / 10000, took 2.41 seconds
Epoch 28: 9504 / 10000, took 2.42 seconds
Epoch 29: 9509 / 10000, took 2.44 seconds
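Each line of this output gives the number of test images classified correctly out of 10000. As a small sketch of how to turn such lines into percentages, the code below parses three of them copied from the run above (the first epoch, the best epoch, and the last). The regular expression and variable names are illustrative choices.

```python
import re

# Three lines copied from the run above.
log = """Epoch 0: 9061 / 10000, took 2.40 seconds
Epoch 26: 9518 / 10000, took 2.40 seconds
Epoch 29: 9509 / 10000, took 2.44 seconds"""

pattern = re.compile(r"Epoch (\d+): (\d+) / (\d+), took ([\d.]+) seconds")

rows = []
for epoch, correct, total, secs in pattern.findall(log):
    accuracy = 100.0 * int(correct) / int(total)
    rows.append((int(epoch), accuracy, float(secs)))
    print("epoch {:2d}: {:.2f}% correct".format(int(epoch), accuracy))

# Pick the row with the highest accuracy.
best_epoch, best_accuracy, _ = max(rows, key=lambda r: r[1])
```

For these three lines the accuracies are 90.61%, 95.18% and 95.09%, so the network already exceeds 90% after one epoch and gains about 4.5 percentage points over the remaining 29.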
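Now we change the number of hidden neurons to 100, keeping the number of epochs and the other parameters the same. The time per epoch rises to about 7.6 seconds. In this run the test accuracy after 30 epochs is also lower than with 30 hidden neurons (8835 / 10000 versus 9509 / 10000).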
net = network.Network([784, 100, 10])
net.SGD(training_data, 30, 10, 3.0, test_data=test_data)
Epoch 0: 6718 / 10000, took 7.61 seconds
Epoch 1: 6775 / 10000, took 7.61 seconds
Epoch 2: 6927 / 10000, took 7.60 seconds
Epoch 3: 7795 / 10000, took 7.67 seconds
Epoch 4: 8645 / 10000, took 7.57 seconds
Epoch 5: 8692 / 10000, took 7.62 seconds
Epoch 6: 8707 / 10000, took 7.59 seconds
Epoch 7: 8711 / 10000, took 7.59 seconds
Epoch 8: 8748 / 10000, took 7.54 seconds
Epoch 9: 8710 / 10000, took 7.70 seconds
Epoch 10: 8754 / 10000, took 7.55 seconds
Epoch 11: 8762 / 10000, took 7.63 seconds
Epoch 12: 8771 / 10000, took 7.62 seconds
Epoch 13: 8773 / 10000, took 7.66 seconds
Epoch 14: 8785 / 10000, took 7.68 seconds
Epoch 15: 8794 / 10000, took 7.66 seconds
Epoch 16: 8789 / 10000, took 7.60 seconds
Epoch 17: 8797 / 10000, took 7.61 seconds
Epoch 18: 8788 / 10000, took 7.65 seconds
Epoch 19: 8809 / 10000, took 7.70 seconds
Epoch 20: 8793 / 10000, took 7.59 seconds
Epoch 21: 8817 / 10000, took 7.69 seconds
Epoch 22: 8813 / 10000, took 7.63 seconds
Epoch 23: 8809 / 10000, took 7.60 seconds
Epoch 24: 8811 / 10000, took 7.65 seconds
Epoch 25: 8832 / 10000, took 7.65 seconds
Epoch 26: 8816 / 10000, took 7.95 seconds
Epoch 27: 8806 / 10000, took 7.69 seconds
Epoch 28: 8826 / 10000, took 7.70 seconds
Epoch 29: 8835 / 10000, took 7.62 seconds
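One plausible reason for the longer epochs is simply that the larger network has more parameters to update. This is an inference from the layer sizes, not something measured in this notebook. The sketch below counts weights and biases for a fully connected network; the helper name count_parameters is hypothetical.

```python
def count_parameters(sizes):
    """Weights plus biases for a fully connected network with these layer sizes."""
    weights = sum(n_in * n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))
    biases = sum(sizes[1:])
    return weights + biases

small = count_parameters([784, 30, 10])
large = count_parameters([784, 100, 10])
print(small, large)                   # 23860 79510
print(round(large / small, 2))        # 3.33
```

The parameter count grows by a factor of about 3.3, which is in the same range as the observed slowdown from roughly 2.4 to roughly 7.6 seconds per epoch.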