A detailed look at the difference between saving a PyTorch model for testing and saving it for continued training

Time:2021-3-4

Saving the model

If the model is saved only for testing, it is enough to save the state dict:


torch.save(model.state_dict(), path)

where path is the path to save to. Note the parentheses: state_dict() must be called; passing the bound method model.state_dict would save the method object rather than the parameters.
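As a minimal runnable sketch (the tiny nn.Linear model and the file name are hypothetical stand-ins for a real network and path):

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical tiny model, a stand-in for a real network
model = nn.Linear(4, 2)

# Note the parentheses: state_dict() is called, not passed as a method
path = os.path.join(tempfile.gettempdir(), "model_for_testing.pth")
torch.save(model.state_dict(), path)

# Later, restore into a freshly constructed model of the same architecture
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(path))
restored.eval()  # disable dropout / use running batch-norm stats for testing
```

Calling eval() before testing matters for layers such as dropout and batch norm, whose behavior differs between training and inference.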

However, sometimes the model and data are too large to finish training in a single run, and when an optimizer such as Adam is used, the optimizer state and the current epoch must be saved along with the model:


state = { 'model': model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch': epoch }  
torch.save(state, path)

This matters because of the learning-rate schedule:


def adjust_learning_rate(optimizer, epoch):
    # lr is the initial learning rate; decay it by a factor of 0.3 every 2 epochs
    lr_t = lr * (0.3 ** (epoch // 2))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_t

The learning rate here depends on the epoch. If the epoch is not saved, every restart begins again from epoch 0, and the learning rate is reset to its initial value!
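The effect of the schedule can be checked in isolation (0.01 is an assumed initial learning rate):

```python
# The schedule above: decay by a factor of 0.3 every 2 epochs.
# lr is the initial learning rate (0.01 is an assumed value here).
lr = 0.01

def lr_for_epoch(epoch):
    return lr * (0.3 ** (epoch // 2))

# Epochs 0-1 use 0.01, epochs 2-3 use 0.003, epochs 4-5 use 0.0009, ...
```

A restart that forgets the epoch would call this with epoch 0 again and jump the learning rate back up to 0.01, undoing the decay.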

Restoring the model

If the model is restored only for testing:


model.load_state_dict(torch.load(path))

where path is the path of the previously saved model.

But to continue training:


checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1

This restores the model parameters, the optimizer state, and the epoch in turn.
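Putting both halves together, a save-then-resume round trip can be sketched as follows (the tiny nn.Linear model, the Adam settings, and epoch 5 are all assumed for illustration):

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical tiny model and Adam optimizer, stand-ins for a real setup
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Suppose training stopped after epoch 5: save everything needed to resume
path = os.path.join(tempfile.gettempdir(), "checkpoint.pth")
state = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epoch': 5,
}
torch.save(state, path)

# In a later run: rebuild the model and optimizer, then restore their states
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=0.01)
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1  # continue from the next epoch
```

The model and optimizer must be constructed before their state dicts are loaded; load_state_dict fills in the states of objects that already exist, it does not create them.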

That covers the difference between saving a PyTorch model for testing and saving it for continued training. I hope it serves as a useful reference.
