TensorFlow 2.0 MNIST handwritten digit recognition example

Time:2021-1-22


     

When you read, you don’t realize that spring is deep, and every inch of time is golden.

 

Introduction: A convolutional neural network (CNN) is trained on the MNIST dataset and then used to recognize handwritten digit images; the digits 0, 1, 2 and 4 are tested.

                   

1、 MNIST data set preparation

Although the code can download the data set automatically, downloading MNIST from within China is not stable and often stalls at the message "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz". The code first looks for the MNIST data set in the directory data_set_tf2 and downloads it automatically only if it is not found there. Because that download takes a relatively long time, it is better to prepare the data in advance.
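The "look locally first, download only if missing" behaviour can be checked before training starts. The following is a minimal sketch, not part of the tutorial's code: the helper name find_local_mnist is hypothetical, and the directory and file names are assumed to match the path used in train.py.

```python
import os


def find_local_mnist(base_dir, dir_name="data_set_tf2", file_name="mnist.npz"):
    """Return (path, exists) for the expected local copy of the MNIST archive.

    Hypothetical helper: it only inspects the file system and downloads nothing.
    """
    path = os.path.abspath(os.path.join(base_dir, dir_name, file_name))
    return path, os.path.isfile(path)


# Check relative to the current working directory (an assumption for this demo)
path, exists = find_local_mnist(os.getcwd())
if exists:
    print("Using local copy:", path)
else:
    print("No local copy at", path, "- loading the data would trigger a download")
```

If the second message appears, placing mnist.npz at the printed path beforehand should avoid the slow download.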

Downloading mnist data from https

Download address of MNIST dataset

The official website of the MNIST data set is given above. Only the following four files need to be downloaded: the image and label files of the training set and of the test set, as shown in the picture.

Training set images: train-images-idx3-ubyte.gz (9.9 MB, 47 MB after decompression, 60000 samples)

Training set labels: train-labels-idx1-ubyte.gz (29 KB, 60 KB after decompression, 60000 labels)

Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 7.8 MB after decompression, 10000 samples)

Test set labels: t10k-labels-idx1-ubyte.gz (5 KB, 10 KB after decompression, 10000 labels)

The MNIST data set comes from the National Institute of Standards and Technology (NIST) in the USA. The training set consists of digits handwritten by 250 different people, 50% of them high school students and 50% Census Bureau staff; the test set contains handwritten digits in the same proportion. Create a new folder named MNIST, download the data set into it and decompress the files.
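A quick way to confirm that a downloaded file is intact is to read its IDX header. The sketch below is an illustration under assumptions: the function name read_idx_header is hypothetical, and it expects the file to still be gzip-compressed. In the IDX format the first four bytes are a big-endian magic number whose last byte gives the number of dimensions, followed by one 4-byte big-endian size per dimension.

```python
import gzip
import struct


def read_idx_header(path):
    """Read the header of a gzip-compressed IDX file.

    Returns (magic, dims), where dims is a tuple with the size of each dimension.
    """
    with gzip.open(path, "rb") as f:
        # Bytes 0-1 are zero, byte 2 encodes the data type,
        # byte 3 is the number of dimensions.
        magic = struct.unpack(">I", f.read(4))[0]
        ndim = magic & 0xFF
        dims = struct.unpack(">" + "I" * ndim, f.read(4 * ndim))
    return magic, dims
```

Called on MNIST/train-images-idx3-ubyte.gz (assuming the folder created above), it should report the magic number 2051 and the dimensions (60000, 28, 28); label files use the magic number 2049 and have a single dimension.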

MNIST data set integration

3、 Picture training

The training code, train.py, is as follows:


import os
import tensorflow as tf
from tensorflow.keras import datasets, layers, models

'''
python 3.7, 3.9
tensorflow 2.0.0b0
'''

# The first half of the model definition mainly uses the Conv2D (convolution)
# and MaxPooling2D (pooling) layers provided by keras.layers.
# The input of the CNN is a tensor of shape (image_height, image_width, color_channels).
# MNIST images are black and white, so there is only one color channel.
# Color images generally have three channels (R, G, B), or four (R, G, B, A),
# where A stands for transparency.
# For the MNIST data set the input tensor shape is (28, 28, 1), which is passed
# to the first layer of the network through the input_shape parameter.
# CNN model definition:
class CNN(object):
    def __init__(self):
        model = models.Sequential()
        # First convolution layer: 32 kernels of size 3 * 3;
        # 28 * 28 is the size of the images to be trained
        model.add(layers.Conv2D(
            32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
        model.add(layers.MaxPooling2D((2, 2)))
        # Second convolution layer: 64 kernels of size 3 * 3
        model.add(layers.Conv2D(64, (3, 3), activation='relu'))
        model.add(layers.MaxPooling2D((2, 2)))
        # Third convolution layer: 64 kernels of size 3 * 3
        model.add(layers.Conv2D(64, (3, 3), activation='relu'))

        model.add(layers.Flatten())
        model.add(layers.Dense(64, activation='relu'))
        model.add(layers.Dense(10, activation='softmax'))
        # Flatten "flattens" the multi-dimensional input into one dimension and is
        # commonly used between the convolution layers and the fully connected
        # layers. It does not affect the batch size.
        # Dense adds a fully connected layer.
        # softmax is used for multi-class classification: it maps the outputs of
        # several neurons into the interval (0, 1), so they can be read as probabilities.
        model.summary()  # print the parameters of each layer of the model

        self.model = model


# Preprocessing of the MNIST data set
class DataSource(object):
    def __init__(self):
        # Storage location of the MNIST data set; if the file does not exist,
        # it is downloaded automatically
        data_path = os.path.abspath(os.path.dirname(
            __file__)) + '/../data_set_tf2/mnist.npz'
        (train_images, train_labels), (test_images,
                                       test_labels) = datasets.mnist.load_data(path=data_path)
        # 60000 training pictures and 10000 test pictures
        train_images = train_images.reshape((60000, 28, 28, 1))
        test_images = test_images.reshape((10000, 28, 28, 1))
        # Map the pixel values into the range 0 to 1
        train_images, test_images = train_images / 255.0, test_images / 255.0

        self.train_images, self.train_labels = train_images, train_labels
        self.test_images, self.test_labels = test_images, test_labels


# Start training and save the training results
class Train:
    def __init__(self):
        self.cnn = CNN()
        self.data = DataSource()

    def train(self):
        check_path = './ckpt/cp-{epoch:04d}.ckpt'
        # period=5: save a checkpoint every 5 epochs
        save_model_cb = tf.keras.callbacks.ModelCheckpoint(
            check_path, save_weights_only=True, verbose=1, period=5)

        self.cnn.model.compile(optimizer='adam',
                               loss='sparse_categorical_crossentropy',
                               metrics=['accuracy'])
        self.cnn.model.fit(self.data.train_images, self.data.train_labels,
                           epochs=5, callbacks=[save_model_cb])

        test_loss, test_acc = self.cnn.model.evaluate(
            self.data.test_images, self.data.test_labels)
        print("accuracy: %.4f, a total of %d pictures tested" %
              (test_acc, len(self.data.test_labels)))


if __name__ == "__main__":
    app = Train()
    app.train()


Training the MNIST handwritten digit recognizer takes about four minutes and reaches an accuracy of 0.9902. The following video shows only the first ten seconds of training.

MNIST handwritten digit recognition training video

model.summary() prints the defined model structure:

Model structure defined by CNN


Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              (None, 26, 26, 32)        320
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 32)        0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 11, 11, 64)        18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64)          0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 3, 3, 64)          36928
_________________________________________________________________
flatten (Flatten)            (None, 576)               0
_________________________________________________________________
dense (Dense)                (None, 64)                36928
_________________________________________________________________
dense_1 (Dense)              (None, 10)                650
=================================================================
Total params: 93,322
Trainable params: 93,322
Non-trainable params: 0
_________________________________________________________________


We can see that the output of each Conv2D and MaxPooling2D layer is a three-dimensional tensor (height, width, channels), and that the height and width gradually decrease. The number of output channels is controlled by the first argument of the layer (for example, 32 or 64). As the height and width shrink, the number of channels can be increased without making the computation too expensive.
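The shrinking height and width can be reproduced with simple arithmetic. The sketch below assumes the Keras defaults used by this model: convolutions with stride 1 and no padding, and 2 * 2 max pooling that drops any remainder. The function and variable names are illustrative only.

```python
def conv_out(size, kernel=3):
    """Output size of a stride-1 convolution without padding."""
    return size - kernel + 1


def pool_out(size, pool=2):
    """Output size of non-overlapping max pooling (the remainder is dropped)."""
    return size // pool


# Layer sequence of the first half of the model: (kind, number of filters)
LAYERS = [("conv", 32), ("pool", None), ("conv", 64), ("pool", None), ("conv", 64)]


def trace_shapes(size=28, channels=1):
    """Return the (height, width, channels) shape after each layer."""
    shapes = []
    for kind, filters in LAYERS:
        if kind == "conv":
            size, channels = conv_out(size), filters
        else:
            size = pool_out(size)
        shapes.append((size, size, channels))
    return shapes


for shape in trace_shapes():
    print(shape)
```

The printed shapes, (26, 26, 32), (13, 13, 32), (11, 11, 64), (5, 5, 64) and (3, 3, 64), match the Output Shape column of the model summary.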

The second half of the model defines the output tensor. layers.Flatten transforms the three-dimensional tensor into a one-dimensional vector: before flattening, the tensor has shape (3, 3, 64), and afterwards it is a one-dimensional vector of length 576 (3 * 3 * 64). The layers.Dense layers then reduce the length of the vector step by step, from 576 to 64 and then to 10.
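The layer sizes also explain the Param # column of the model summary. As a worked check, assuming every layer has one bias per output unit or filter (the Keras default), the counts can be recomputed by hand; the helper names below are illustrative.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    """Weights plus one bias per filter."""
    return (kernel_h * kernel_w * in_channels + 1) * filters


def dense_params(inputs, units):
    """Weights plus one bias per unit."""
    return (inputs + 1) * units


counts = {
    "conv2d": conv2d_params(3, 3, 1, 32),
    "conv2d_1": conv2d_params(3, 3, 32, 64),
    "conv2d_2": conv2d_params(3, 3, 64, 64),
    "dense": dense_params(3 * 3 * 64, 64),   # 576 inputs after Flatten
    "dense_1": dense_params(64, 10),
}

for name, value in counts.items():
    print(name, value)
print("total", sum(counts.values()))
```

The results are 320, 18496, 36928, 36928 and 650, for a total of 93322, which agrees with the summary. Pooling and Flatten layers have no parameters.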

In the second part, an ordinary neural network is constructed with a 576-unit input layer, a 64-unit hidden layer and a 10-unit output layer. The activation function of the last layer is softmax, and its 10 outputs correspond exactly to the 10 digits 0-9. The index of the maximum value is the predicted digit, so the prediction is easy to obtain with numpy's argmax() function.
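How softmax and argmax work together can be shown on a small example. This is a sketch with made-up input values, not output from the trained model; the softmax function here is a plain numpy implementation written for illustration.

```python
import numpy as np


def softmax(logits):
    """Map raw scores to probabilities that sum to 1."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max()  # subtract the maximum for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()


# Made-up scores for the digits 0-9; the score for digit 4 is the largest
scores = [0.1, 1.2, 0.3, 0.0, 6.5, 0.2, 0.1, 0.9, 0.4, 0.3]
probs = softmax(scores)

# The index of the largest probability is the predicted digit
digit = int(np.argmax(probs))
print(digit)
```

Because softmax preserves the order of the scores, the index returned by argmax is the same before and after applying it; softmax only makes the values readable as probabilities.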

train.py Running results

It can be seen that after the first round of training the recognition accuracy reaches 0.9536, and after five rounds of training the accuracy reaches 0.9902. In the fifth round, the model parameters are successfully saved to ./ckpt/cp-0005.ckpt; the accuracy reported during that round is above 0.9940, higher than the final 0.9902, so longer training is not necessarily better. The saved model parameters can be loaded to restore the whole convolutional neural network and to predict real images.
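The checkpoint file name follows from the path template in train.py. The sketch below shows how the template expands; the helper names are hypothetical, and it assumes that a checkpoint is written every period epochs with epochs counted from 1, which is consistent with the cp-0005.ckpt file reported here.

```python
check_path = "./ckpt/cp-{epoch:04d}.ckpt"


def checkpoint_name(epoch, template=check_path):
    """Expand the path template for a given (1-based) epoch number."""
    return template.format(epoch=epoch)


def saved_epochs(total_epochs, period):
    """Epochs at which a checkpoint would be written (assumed rule, see above)."""
    return [e for e in range(1, total_epochs + 1) if e % period == 0]


for epoch in saved_epochs(total_epochs=5, period=5):
    print(checkpoint_name(epoch))
```

With 5 epochs and a period of 5 only one checkpoint is produced, so training for fewer than 5 epochs with these settings would leave nothing to restore.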

Save training model parameters

4、 Picture prediction

The code of predict.py is as follows:


import tensorflow as tf
from PIL import Image
import numpy as np

from mnist.v4_cnn.train import CNN

'''
python 3.7, 3.9
tensorflow 2.0.0b0
pillow(PIL) 4.3.0
'''


class Predict(object):
    def __init__(self):
        latest = tf.train.latest_checkpoint('./ckpt')
        self.cnn = CNN()
        # Restore the network weights
        self.cnn.model.load_weights(latest)

    def predict(self, image_path):
        # Read the picture as a grayscale image
        img = Image.open(image_path).convert('L')
        img = np.reshape(img, (28, 28, 1)) / 255.
        # Invert the gray levels and add the batch dimension
        x = np.array([1 - img])

        # API refer: https://keras.io/models/model/
        y = self.cnn.model.predict(x)

        # Because x contains only one picture, take y[0];
        # np.argmax() returns the index of the maximum value, which is the predicted digit
        print(image_path)
        print(y[0])
        print('        -> Predict picture number is: ', np.argmax(y[0]))


if __name__ == "__main__":
    app = Predict()
    app.predict('../test_images/0.png')
    app.predict('../test_images/1.png')
    app.predict('../test_images/4.png')
    app.predict('../test_images/2.png')
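The reshape in predict.py only succeeds when the picture is exactly 28 * 28 pixels. The following sketch isolates the preprocessing step and adds an explicit size check; the function name to_model_input is hypothetical, and it takes an array of gray levels rather than a file so that it can be tried without Pillow or TensorFlow.

```python
import numpy as np


def to_model_input(gray):
    """Convert a 28x28 array of 0-255 gray levels into a (1, 28, 28, 1) batch.

    The values are scaled to [0, 1] and inverted, mirroring `1 - img` in predict.py.
    """
    gray = np.asarray(gray, dtype=np.float32)
    if gray.shape != (28, 28):
        raise ValueError("expected a 28x28 image, got shape %r" % (gray.shape,))
    scaled = gray.reshape(28, 28, 1) / 255.0
    return (1.0 - scaled)[np.newaxis, ...]


# A completely white picture becomes an all-zero input after inversion
white = np.full((28, 28), 255, dtype=np.uint8)
x = to_model_input(white)
print(x.shape, float(x.max()))
```

The inversion matters because MNIST digits are stored as bright strokes on a dark background, so a dark digit drawn on white paper has to be flipped before it resembles the training data.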


Forecast results

The prediction results are as follows


../test_images/0.png
[9.9999774e-01 2.6819215e-08 1.2541744e-07 8.7437911e-08 1.0661940e-09
 3.3693670e-08 4.6488995e-07 3.5915035e-09 9.8040758e-08 1.4385278e-06]
        -> Predict picture number is:  0
../test_images/1.png
[7.75440956e-09 9.99991298e-01 1.41642090e-07 1.09819875e-10
 6.76554646e-06 7.63710162e-09 2.37024622e-08 1.58189516e-06
 2.49125264e-07 4.92376007e-09]
        -> Predict picture number is:  1
../test_images/4.png
[7.03467840e-10 8.20740708e-04 1.11648405e-04 3.93262711e-09
 9.99048650e-01 1.08713095e-07 4.24647197e-08 1.85665340e-05
 5.03181887e-08 1.86591734e-07]
        -> Predict picture number is:  4
../test_images/2.png
[1.5828672e-08 1.9245699e-07 9.9999440e-01 5.3448480e-06 1.7397912e-10
 8.6148493e-13 2.5441890e-10 5.3953073e-08 3.5735226e-08 8.9734775e-11]
        -> Predict picture number is:  2


As shown above, after training, the CNN restored from the saved model parameters correctly predicts the true values of the handwritten images of 0, 1, 2 and 4.

                 

    
