We often set batch_size when training a network. What is batch_size actually for? If the data set contains 10,000 images, how large should it be, and what is the difference between setting it to 1, 10, 100 or 10,000?
# Training call for a handwritten digit recognition network
network.fit(train_images, train_labels, epochs=5, batch_size=128)
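To make the question concrete, here is a small sketch of how many parameter updates one pass over 10,000 samples takes for different batch sizes. The helper name updates_per_epoch is hypothetical, and the sketch assumes the final, smaller batch is still used rather than dropped.

```python
import math


def updates_per_epoch(num_samples: int, batch_size: int) -> int:
    """Number of parameter updates in one pass over the data.

    Assumption: a final partial batch still counts as one update.
    """
    return math.ceil(num_samples / batch_size)


for batch_size in (1, 10, 100, 10000):
    # Prints 10000, 1000, 100 and 1 updates respectively.
    print(batch_size, updates_per_epoch(10000, batch_size))
```

Under the same assumption, batch_size=128 gives 79 updates per epoch: 78 full batches plus one batch of 16 samples.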
Batch gradient descent (BGD)
The gradient descent algorithm is generally used to minimize a loss function. The raw data is fed to the network, the network performs its calculations, and a loss function measures the gap between the network's output and the actual values. Gradient descent adjusts the network parameters step by step in the direction that reduces this loss, so that the network's output fits the actual values better; that is what gradient descent means.
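The update rule can be sketched on a loss with a single parameter. This is an illustration only: the loss (w - 3)^2, the learning rate and the step count are assumptions chosen for the example, not anything taken from a real network.

```python
def loss(w: float) -> float:
    """Toy loss with its minimum at w = 3 (an assumption for illustration)."""
    return (w - 3.0) ** 2


def grad(w: float) -> float:
    """Derivative of the toy loss with respect to w."""
    return 2.0 * (w - 3.0)


def gradient_descent(w0: float = 0.0, learning_rate: float = 0.1, steps: int = 50) -> float:
    """Repeatedly move the parameter against the gradient."""
    w = w0
    for _ in range(steps):
        w -= learning_rate * grad(w)
    return w


w_final = gradient_descent()
print(w_final, loss(w_final))  # w_final ends up close to 3, loss close to 0
```

Each step shrinks the distance to the minimum by a constant factor here, which is why a fixed number of steps is enough on this toy problem.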
Batch gradient descent is the most basic form of gradient descent. Its idea is to use all of the training data for every parameter update. Gradient descent needs the derivative of the loss function, so if the training set is large, all of the data has to be read in, passed through the network and summed together. That means operating on a huge matrix, and the amount of computation is very large. This is also its advantage: because the whole training set is taken into account, each update moves the network in the direction that reduces the loss over the entire training set, toward an extremum.
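A minimal sketch of this idea, assuming a tiny linear model y = w*x + b with a mean squared error loss instead of a real network. The function name, the toy data and the hyperparameters are all hypothetical.

```python
def batch_gradient_descent(xs, ys, learning_rate=0.1, epochs=500):
    """Fit y ~ w*x + b; every update uses the gradient over ALL samples."""
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(epochs):
        # Gradient of the mean squared error, averaged over the whole data set.
        grad_w = sum(2.0 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
        grad_b = sum(2.0 * (w * x + b - y) for x, y in zip(xs, ys)) / n
        w -= learning_rate * grad_w  # exactly one update per pass over the data
        b -= learning_rate * grad_b
    return w, b


# Toy data (an assumption): 21 noise-free points on the line y = 2x + 1.
xs = [i / 10.0 for i in range(-10, 11)]
ys = [2.0 * x + 1.0 for x in xs]
w, b = batch_gradient_descent(xs, ys)
print(w, b)  # close to 2 and 1
```

Note that one epoch produces exactly one update, and that update has to touch every sample first.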
Stochastic gradient descent (SGD)
Unlike batch gradient descent, stochastic gradient descent takes one sample from the training set at a time and trains iteratively. Take one training sample, update the network parameters to fit it, then take the next sample and update the already modified network again, and so on until every sample has been fed to the network. Then start over, and repeat until the parameters are relatively stable. The advantage is that each update uses only one training sample, so each update is very fast. The disadvantage comes from the same fact: because each update considers only one sample, the update direction is not necessarily the best direction for the training set as a whole, so the loss often jitters, or training converges to a local optimum.
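The one-sample-at-a-time loop might look like the sketch below, again assuming a tiny linear model y = w*x + b in place of a network. The names, the toy data, the shuffling of the sample order and the hyperparameters are assumptions. The toy data is noise-free, so the jitter described above does not show up here.

```python
import random


def stochastic_gradient_descent(xs, ys, learning_rate=0.05, epochs=200, seed=0):
    """Fit y ~ w*x + b; every update uses the gradient of ONE sample."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    indices = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(indices)  # visit the samples in a new random order each epoch
        for i in indices:
            error = w * xs[i] + b - ys[i]
            # The update depends on the parameters left by the previous sample,
            # so these steps have to run one after another.
            w -= learning_rate * 2.0 * error * xs[i]
            b -= learning_rate * 2.0 * error
    return w, b


# Toy data (an assumption): 21 noise-free points on the line y = 2x + 1.
xs = [i / 10.0 for i in range(-10, 11)]
ys = [2.0 * x + 1.0 for x in xs]
w, b = stochastic_gradient_descent(xs, ys)
print(w, b)  # close to 2 and 1
```

With 21 samples, one epoch performs 21 updates, each of them cheap.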
Mini-batch gradient descent (MBGD)
Mini-batch gradient descent takes the kind of compromise that is so common in computing. Each update uses neither the whole training set nor a single sample, but a part of the training set, such as 20 samples at a time. This avoids both problems: there is not so much data that the calculation becomes slow, and the noise in a single training sample does not cause severe jitter or pull the network away from the optimum.
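A sketch of the compromise, using the batch size of 20 mentioned above and assuming the same kind of tiny linear model y = w*x + b. The function name, the toy data and the other hyperparameters are hypothetical.

```python
import random


def minibatch_gradient_descent(xs, ys, batch_size=20, learning_rate=0.1,
                               epochs=200, seed=0):
    """Fit y ~ w*x + b; every update averages the gradient over one mini-batch."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    indices = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]  # last batch may be smaller
            grad_w = sum(2.0 * (w * xs[i] + b - ys[i]) * xs[i] for i in batch) / len(batch)
            grad_b = sum(2.0 * (w * xs[i] + b - ys[i]) for i in batch) / len(batch)
            w -= learning_rate * grad_w
            b -= learning_rate * grad_b
    return w, b


# Toy data (an assumption): 101 noise-free points on the line y = 2x + 1.
xs = [i / 50.0 for i in range(-50, 51)]
ys = [2.0 * x + 1.0 for x in xs]
w, b = minibatch_gradient_descent(xs, ys)
print(w, b)  # close to 2 and 1
```

Setting batch_size to 1 in this sketch gives the one-sample behaviour of stochastic gradient descent, and setting it to len(xs) gives batch gradient descent.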
Compare how these three gradient descent algorithms compute. Batch gradient descent operates on one large matrix. Optimized matrix routines can run that calculation in parallel, but it places high demands on hardware such as memory. In stochastic gradient descent each iteration depends on the result of the previous one, so the iterations cannot run in parallel, but the hardware requirements are low. In mini-batch gradient descent each iteration works on a small matrix, so the hardware requirements are modest; the matrix operation within an iteration can be parallelized while the iterations themselves run serially, which saves time overall.
The following figure shows how each of the three gradient descent algorithms iterates toward the optimum, and gives a more intuitive impression.
As for choosing among the gradient descent algorithms: if the training set is very small, use batch gradient descent directly; if only one training sample can be obtained at a time, or the training data arrives online in real time, use stochastic gradient descent; in other cases, or in general, mini-batch gradient descent is the better choice.
- This article was first published by: rais