Little Bear Paddle Exercise Book 01: Handwritten Digit Recognition

Date: 2022-05-26

Brief Introduction

This project was developed and tested on Ubuntu 20.04.
For the latest code, see the project home page: Little Bear Paddle Exercise Book
Baidu PaddlePaddle AI Studio home page: Little Bear Paddle Exercise Book 01: Handwritten Digit Recognition
For installing CUDA on Ubuntu, see: Installation of Ubuntu Baidu PaddlePaddle and CUDA

File Descriptions

File            Description
train.py        Training script
test.py         Testing script
report.py       Report script
get-data.sh     Fetches the data into the dataset directory
check-data.sh   Checks whether the data exists in the dataset directory
mod/lenet.py    LeNet network model
mod/dataset.py  MNIST handwritten dataset parsing
mod/utils.py    Utility functions
mod/config.py   Configuration
mod/report.py   Results report
dataset         Dataset directory
params          Directory where model parameters are saved
log             Directory where VisualDL logs are saved

Dataset

The dataset comes from Baidu PaddlePaddle's public datasets: the classic MNIST dataset.

Getting the Data

If running on a local machine, download the data, place the files in the dataset directory under the project root, and then run the following script from the project directory.
If running in the Baidu AI Studio environment, check whether the data directory contains the data, then run the following script from the project directory.

bash get-data.sh

Checking the Data

After obtaining the data, run the following script from the project directory to check whether the data exists in the dataset directory.

bash check-data.sh

Network Model

The network uses the LeNet model, which comes from the Baidu PaddlePaddle tutorials and the Internet.
LeNet reference: Baidu PaddlePaddle tutorial

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


# LeNet network model
class LeNet(nn.Layer):
    def __init__(self, num_classes=10):
        super(LeNet, self).__init__()
        if num_classes < 1:
            raise Exception(
                "the number of classes num_classes must be greater than 0: {}".format(num_classes))
        self.num_classes = num_classes
        self.conv1 = nn.Conv2D(
            in_channels=1, out_channels=6, kernel_size=5, stride=1)
        self.max_pool1 = nn.MaxPool2D(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2D(
            in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.max_pool2 = nn.MaxPool2D(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2D(
            in_channels=16, out_channels=120, kernel_size=4, stride=1)
        self.fc1 = nn.Linear(in_features=120, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.max_pool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.max_pool2(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = paddle.flatten(x, start_axis=1, stop_axis=-1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x
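
As a quick sanity check, the model can be instantiated and inspected with paddle.summary; a minimal sketch (the import path assumes the class above lives in mod/lenet.py, as listed in the file table):

import paddle

from mod.lenet import LeNet  # assumes the class above lives in mod/lenet.py

model = LeNet(num_classes=10)
# Print a layer-by-layer summary for a batch of one 1x28x28 grayscale image
paddle.summary(model, (1, 1, 28, 28))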

Dataset Parsing

The dataset parsing approach comes from the Baidu PaddlePaddle tutorials and the Internet; it differs slightly from Baidu PaddlePaddle's built-in MNIST dataset.

import paddle
import os
import struct
import numpy as np


class MNIST(paddle.io.Dataset):
    """
    MNIST handwritten data set analysis, inheriting pad io. Dataset class
    """

    def __init__(self,
                 images_path: str,
                 labels_path: str,
                 transform=None,
                 ):
        """
        Constructor that defines the size of the dataset

        Args:
            images_ Path (STR): image set path
            labels_ Path (STR): label set path
            Transform (compose, optional): the operation combination of converting data. The default is none
        """
        super(MNIST, self).__init__()
        self.images_path = images_path
        self.labels_path = labels_path
        self._ check_ Path (images_path, "data path error")
        self._ check_ Path (labels_path, "label path error")
        self.transform = transform
        self.images, self.labels = self.parse_dataset(images_path, labels_path)

    def __getitem__(self, idx):
        """
        Get a single image and its label

        Args:
            idx (Any): index

        Returns:
            image (float32): image
            label (int64): label
        """
        image, label = self.images[idx], self.labels[idx]
        # Reshape to two dimensions: [28, 28]
        image = np.reshape(image, [28, 28])
        if self.transform is not None:
            image = self.transform(image)
        # Integer labels must be cast to int64
        return image.astype('float32'), label.astype('int64')

    def __len__(self):
        """
        Number of samples

        Returns:
            int: number of samples
        """
        return len(self.labels)

    def _check_path(self, path: str, msg: str):
        """
        Check whether the path exists

        Args:
            path (str): path
            msg (str): exception message

        Raises:
            Exception: raised when the path does not exist
        """
        if not os.path.exists(path):
            raise Exception("{}: {}".format(msg, path))

    @staticmethod
    def parse_dataset(images_path: str, labels_path: str):
        """
        Parse the dataset files

        Args:
            images_path (str): path to the image set
            labels_path (str): path to the label set

        Returns:
            images: image set
            labels: label set
        """
        with open(images_path, 'rb') as imgpath:
            # Parse the image set
            magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
            # Reshape each image to one dimension: [784]
            images = np.fromfile(
                imgpath, dtype=np.uint8).reshape(num, rows * cols)
        with open(labels_path, 'rb') as lbpath:
            # Parse the label set
            magic, n = struct.unpack('>II', lbpath.read(8))
            labels = np.fromfile(lbpath, dtype=np.uint8)
        return images, labels
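
Here is a minimal usage sketch of the class above; the file names follow the standard MNIST archive layout and are assumptions about what get-data.sh places in the dataset directory:

import paddle

from mod.dataset import MNIST  # the class defined above

# File names assume the standard MNIST archive layout inside the dataset directory
train_dataset = MNIST(
    images_path='dataset/train-images-idx3-ubyte',
    labels_path='dataset/train-labels-idx1-ubyte')

# Batch and shuffle the samples for training
train_loader = paddle.io.DataLoader(train_dataset, batch_size=128, shuffle=True)

for images, labels in train_loader():
    # Without a transform the images come back as [batch, 28, 28] float32
    # and the labels as [batch] int64; normalization would be done by a transform
    print(images.shape, labels.shape)
    break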

Configuration Module

You can view and modify the mod/config.py file, which contains detailed descriptions of each option.
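
The actual contents are in the file itself; as a purely hypothetical sketch, a configuration module for this project might centralize defaults like these (values taken from the command-line defaults documented below):

# Hypothetical sketch of the kind of defaults mod/config.py might centralize;
# see the actual file for the real values and descriptions.
LEARNING_RATE = 0.001    # default learning rate
EPOCHS = 2               # default number of training epochs
BATCH_SIZE = 128         # default batch size
NUM_WORKERS = 2          # default number of DataLoader worker threads
PARAMS_DIR = 'params'    # where model parameters are saved
LOG_DIR = 'log'          # where VisualDL logs are written
DATASET_DIR = 'dataset'  # where the MNIST files live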

Start Training

Run the train.py file; add -h to view the command-line parameters.

python3 train.py
  --cpu            compute on the CPU; CUDA is used by default
  --learning-rate  learning rate, default 0.001
  --epochs         number of training epochs, default 2
  --batch-size     batch size, default 128
  --num-workers    number of worker threads, default 2
  --no-save        do not save the model parameters; they are saved by default
  --load-dir       load model parameters from a subfolder under the params directory; nothing is loaded by default
  --log            write VisualDL logs; none are written by default
  --summary        print the network model information and exit without training
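
The actual implementation lives in train.py; as a hedged illustration of how such a command line could be wired up with argparse (flag spellings here are assumptions based on the help text above):

import argparse

# Hypothetical sketch of the argument parsing train.py might use;
# the flag spellings are assumptions based on the help text above.
def parse_args():
    parser = argparse.ArgumentParser(description='Train LeNet on MNIST')
    parser.add_argument('--cpu', action='store_true',
                        help='compute on the CPU; CUDA is used by default')
    parser.add_argument('--learning-rate', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--num-workers', type=int, default=2)
    parser.add_argument('--no-save', action='store_true',
                        help='do not save model parameters')
    parser.add_argument('--load-dir', type=str, default=None,
                        help='subfolder under params to load parameters from')
    parser.add_argument('--log', action='store_true',
                        help='write VisualDL logs')
    parser.add_argument('--summary', action='store_true',
                        help='print model information and exit')
    return parser.parse_args()


if __name__ == '__main__':
    print(parse_args())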

Test the Model

Run the test.py file; add -h to view the command-line parameters.

python3 test.py
  --cpu          compute on the CPU; CUDA is used by default
  --batch-size   batch size, default 128
  --num-workers  number of worker threads, default 2
  --load-dir     load model parameters from a subfolder under the params directory; defaults to the best subdirectory
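
For illustration, a minimal sketch of what a test pass over saved parameters could look like (the parameter file name params/best/model.pdparams and the t10k MNIST file names are assumptions; the real test.py may organize this differently):

import paddle

from mod.dataset import MNIST
from mod.lenet import LeNet

# File names assume the standard MNIST test archives in the dataset directory
test_dataset = MNIST(images_path='dataset/t10k-images-idx3-ubyte',
                     labels_path='dataset/t10k-labels-idx1-ubyte')
test_loader = paddle.io.DataLoader(test_dataset, batch_size=128)

# Assumed parameter file name under the best subdirectory
model = LeNet(num_classes=10)
model.set_state_dict(paddle.load('params/best/model.pdparams'))
model.eval()

metric = paddle.metric.Accuracy()
with paddle.no_grad():
    for images, labels in test_loader():
        logits = model(paddle.unsqueeze(images, axis=1))  # add the channel dimension
        metric.update(metric.compute(logits, paddle.unsqueeze(labels, axis=-1)))
print('acc: {:.4f}'.format(metric.accumulate()))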

View the Results Report

Run the report.py file to display the report.json of every subdirectory under the params directory.
Add the --best parameter to save the model parameters with the smallest loss into the best subdirectory.

python3 report.py

report.json Description

Key            Description
id             string ID generated from the time
loss           loss value of this training run
acc            accuracy value of this training run
epochs         epochs value of this training run
batch_size     batch_size value of this training run
learning_rate  learning_rate value of this training run
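
For illustration, a hedged sketch of how these reports could be collected and the smallest-loss run picked out (the real report.py may differ):

import json
import os

# Walk the params directory, print each run's report.json, then pick the
# run with the smallest loss (the run that --best would promote).
reports = []
for name in sorted(os.listdir('params')):
    path = os.path.join('params', name, 'report.json')
    if os.path.isfile(path):
        with open(path) as f:
            report = json.load(f)
        reports.append((name, report))
        print(name, report)

if reports:
    best_name, best = min(reports, key=lambda item: item[1]['loss'])
    print('best:', best_name, 'loss:', best['loss'], 'acc:', best['acc'])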

VisualDL Visual Analysis Tool

  • For installation and usage instructions, see: VisualDL
  • Add the --log parameter during training
  • If training in the AI Studio environment, download the log directory, unzip it, and place it in the log directory under the local project directory
  • Run the following command in the project directory, then open the prompted URL in a browser
visualdl --logdir ./log
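
For reference, VisualDL logs like these are written during training with the LogWriter API; a minimal sketch (the tag name and the stand-in loss values are assumptions, not necessarily what train.py uses):

from visualdl import LogWriter

# Minimal sketch of writing scalars that visualdl --logdir ./log can display;
# the tag name is an assumption, not necessarily what train.py uses.
with LogWriter(logdir='./log') as writer:
    for step in range(100):
        loss = 1.0 / (step + 1)  # stand-in value for demonstration
        writer.add_scalar(tag='train/loss', step=step, value=loss)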