Training a CNN on the CIFAR-10 dataset

Time:2021-3-25

Download dataset

The CIFAR-10 dataset contains 60,000 color images of 32 × 32 pixels, each with a class label, covering ten categories: airplane, automobile, bird, cat, deer, dog, frog, horse, ship and truck.

Of these, 50,000 images are used for training and 10,000 for testing.

 

import tensorflow as tf
from tensorflow import keras
from matplotlib import pyplot as plt
import numpy as np
from tensorflow.keras.layers import Conv2D, MaxPool2D, Flatten, Dense,Dropout

# Load the CIFAR-10 train/test split through the Keras dataset helper.
cifar10 = keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Scale pixel values to [0, 1]. Casting to float32 first keeps the images in
# float32; dividing the integer image arrays by 255.0 directly would promote
# them to float64 and double the memory they occupy. The division is done in
# place so no second full-size temporary is allocated.
x_train = x_train.astype("float32")
x_train /= 255.0
x_test = x_test.astype("float32")
x_test /= 255.0

 

Build network structure

# Convolutional stages, each built as Conv2D -> BatchNormalization -> MaxPool2D,
# optionally followed by Dropout. Entries are (filter count, dropout rate);
# a rate of None means the stage has no trailing Dropout layer.
conv_stages = (
    (128, 0.3),
    (256, 0.3),
    (512, None),
)

layer_stack = []
for n_filters, drop_rate in conv_stages:
    layer_stack.append(Conv2D(n_filters, (3, 3), activation='relu', padding='same'))
    layer_stack.append(keras.layers.BatchNormalization())
    layer_stack.append(MaxPool2D((2, 2)))
    if drop_rate is not None:
        layer_stack.append(Dropout(drop_rate))

# Classifier head: flatten the feature maps, then a regularised 512-unit hidden
# layer with dropout on both sides, ending in a 10-way softmax output.
layer_stack += [
    Flatten(),
    Dropout(0.5),
    Dense(512, activation='relu', kernel_regularizer=keras.regularizers.l2(0.1)),
    Dropout(0.5),
    Dense(10, activation='softmax'),
]

model = keras.models.Sequential(layer_stack)

 

Compile the model

# Configure training: Adam with a small step size, sparse categorical
# cross-entropy, and accuracy as the reported metric.
model.compile(
    # `learning_rate` is the supported argument name; the old `lr` alias is
    # deprecated and rejected by current Keras releases.
    optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    # from_logits=False: the loss is given probabilities rather than raw logits.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)

 

Train the model

# Train the model and keep the returned History object, whose per-epoch
# metrics are used for plotting further down.
history = model.fit(
    x_train,
    y_train,
    epochs=100,
    batch_size=16,
    verbose=1,
    # The test split is used as validation data and evaluated every epoch.
    validation_data=(x_test, y_test),
    validation_freq=1,
)

 

Visualize the accuracy and loss curves

# Show the accuracy and loss curves of the training set and the test set.
# The font setting is kept from the original script; SimHei is a CJK font and
# matplotlib falls back to its default font if it is not installed.
plt.rcParams['font.sans-serif'] = ['SimHei']

# Per-epoch metrics recorded by model.fit().
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

# Left panel: accuracy. The variable names must match the assignments above
# exactly (`acc`, `val_acc`); the earlier `ACC` / `val_ ACC` spellings were
# an undefined name and a syntax error respectively.
plt.subplot(1, 2, 1)
plt.plot(acc, label='training ACC ')
plt.plot(val_acc, label='test ACC ')
plt.title('acc curve ')
plt.legend()

# Right panel: loss.
plt.subplot(1, 2, 2)
plt.plot(loss, label='training loss')
plt.plot(val_loss, label='test loss')
plt.title('loss curve ')
plt.legend()
plt.show()