Training PyTorch models on input data with more than three channels


Case background: video recognition

Assume each input is an 8 s grayscale video at 25 FPS, so each video is a sequence of 200 frames (8 × 25). Each frame is a single-channel grayscale image, so np.stack (stacking along the depth axis) can combine the 200 frames into 200-channel depth data, which is then fed to the network for training.
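A minimal sketch of the depth stacking, assuming the frames have already been decoded into 2-D NumPy arrays. Random arrays stand in for real frames, and the 112 × 112 frame size is an arbitrary assumption. Stacking along the last axis gives the height × width × channel layout that the data-loading code expects.

```python
import numpy as np

# Hypothetical stand-in for one decoded video: 200 grayscale frames
# (8 s x 25 FPS), each a 2-D array of shape (height, width).
height, width = 112, 112
frames = [np.random.randint(0, 256, (height, width), dtype=np.uint8)
          for _ in range(8 * 25)]

# Stack along a new last axis: frame k becomes channel k.
depth = np.stack(frames, axis=-1)
print(depth.shape)  # (112, 112, 200)
```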

If 200 channels are too many, you can sample frames from the video instead: for example, extract 40 frames at equal intervals (the right number depends on the specific scenario). The input video is then equivalent to 40-channel image data.
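One way to pick frames at equal intervals is sketched below; `sample_frames` is a hypothetical helper, not the author's code. np.linspace spreads 40 indices evenly from the first to the last frame, which also works when the frame count is not an exact multiple of 40.

```python
import numpy as np

def sample_frames(frames, num_samples=40):
    """Return num_samples frames taken at (approximately) equal intervals."""
    # Evenly spaced positions from the first to the last frame, rounded
    # to valid integer indices.
    idx = np.linspace(0, len(frames) - 1, num_samples).round().astype(int)
    return [frames[i] for i in idx]

# Usage: 200 placeholder frames -> 40-channel depth data.
frames = [np.zeros((112, 112), dtype=np.uint8) for _ in range(200)]
depth = np.stack(sample_frames(frames), axis=-1)
print(depth.shape)  # (112, 112, 40)
```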

Loading data with more than three channels in PyTorch:

For each video, read the sampled frames and convert them to NumPy arrays; stacking them yields a 40-channel depth array, which is saved with pickle.

Repeat this for every video, saving all of the arrays to the same pickle file.
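The saving code itself is not shown. A minimal sketch, assuming one pickle.dump call per video into a single file, which is consistent with the Dataset calling pickle.load once per sample; the helper names `save_depth_arrays` and `load_depth_arrays` are hypothetical.

```python
import os
import pickle
import tempfile

import numpy as np

def save_depth_arrays(arrays, pkl_path):
    """Write each (height, width, 40) array as its own pickle record.

    Repeated pickle.dump calls on one open file append records back to
    back; repeated pickle.load calls read them back in the same order.
    """
    with open(pkl_path, 'wb') as f:
        for arr in arrays:
            pickle.dump(arr, f)

def load_depth_arrays(pkl_path):
    """Read every record back until the end of the file."""
    arrays = []
    with open(pkl_path, 'rb') as f:
        while True:
            try:
                arrays.append(pickle.load(f))
            except EOFError:
                break
    return arrays

# Usage with three placeholder "videos".
videos = [np.full((112, 112, 40), i, dtype=np.uint8) for i in range(3)]
path = os.path.join(tempfile.mkdtemp(), 'fire_train.pkl')
save_depth_arrays(videos, path)
restored = load_depth_arrays(path)
print(len(restored), restored[2].shape)  # 3 (112, 112, 40)
```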

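I saved the depth data of the fire videos in one .pkl file: there are 2504 fire videos, i.e. 2504 fire depth arrays.

The depth data of the non-fire videos is saved in another .pkl file: there are 3985 non-fire videos, i.e. 3985 non-fire depth arrays. Together that makes 2504 + 3985 = 6489 samples, which is where the index boundaries in the Dataset come from.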

Data loading

import torch
from torch.utils import data
import pickle
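class Fire_Unfire(data.Dataset):
  # Indices 0-2503 are fire samples (label 1), 2504-6488 non-fire samples (label 0).
  # Each pickle.load returns the next saved height*width*channel array, so records
  # are read in file order whatever the (shuffled) index is; since every index is
  # visited once per epoch, all 2504 fire and 3985 non-fire records are read exactly
  # once. The files are only rewound by re-creating the dataset, and num_workers
  # should stay 0 (the default) so that a single process reads the open files.
  def __init__(self,fire_path,unfire_path):
    self.pickle_fire = open(fire_path,'rb')
    self.pickle_unfire = open(unfire_path,'rb')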
  def __getitem__(self,index):
    if index <2504:
      fire = pickle.load(self.pickle_fire)  # height * width * channel
      fire = fire.transpose(2,0,1)  # channel * height * width
      data = torch.from_numpy(fire)
      label = 1
      return data,label
    elif index>=2504 and index<6489:
      unfire = pickle.load(self.pickle_unfire)
      unfire = unfire.transpose(2,0,1)
      data = torch.from_numpy(unfire)
      label = 0
      return data,label
  def __len__(self):
    return 6489
root_path = './datasets/train'
dataset = Fire_Unfire(root_path +'/fire_train.pkl',root_path +'/unfire_train.pkl')
# Batch the samples into the PyTorch network input format (batch size, number of channels, height, width)
from torch.utils.data import DataLoader
fire_dataloader = DataLoader(dataset,batch_size=4,shuffle=True,drop_last = True)
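As an alternative to reading records sequentially, the sketch below loads every record into memory once, so `__getitem__` returns the sample at `index` directly, the dataset does not need to be re-created each epoch, and the length comes from the data rather than the hard-coded 2504 and 6489. `IndexedFireDataset` is a hypothetical variant, not the author's class, and it assumes both pickle files fit in RAM.

```python
import os
import pickle
import tempfile

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

def _load_all(pkl_path):
    # Read every pickled record in the file (one record per video).
    records = []
    with open(pkl_path, 'rb') as f:
        while True:
            try:
                records.append(pickle.load(f))
            except EOFError:
                break
    return records

class IndexedFireDataset(Dataset):
    """Hypothetical in-memory variant of Fire_Unfire that honours `index`."""
    def __init__(self, fire_path, unfire_path):
        fire = _load_all(fire_path)
        unfire = _load_all(unfire_path)
        self.samples = fire + unfire
        self.labels = [1] * len(fire) + [0] * len(unfire)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # height * width * channel -> channel * height * width
        arr = self.samples[index].transpose(2, 0, 1)
        return torch.from_numpy(np.ascontiguousarray(arr)), self.labels[index]

# Usage with small synthetic files: 3 "fire" and 5 "non-fire" records.
tmp = tempfile.mkdtemp()
for name, n in [('fire_train.pkl', 3), ('unfire_train.pkl', 5)]:
    with open(os.path.join(tmp, name), 'wb') as f:
        for _ in range(n):
            pickle.dump(np.zeros((32, 32, 40), dtype=np.uint8), f)

ds = IndexedFireDataset(os.path.join(tmp, 'fire_train.pkl'),
                        os.path.join(tmp, 'unfire_train.pkl'))
loader = DataLoader(ds, batch_size=4, shuffle=True, drop_last=True)
x, y = next(iter(loader))
print(len(ds), tuple(x.shape), y.dtype)  # 8 (4, 40, 32, 32) torch.int64
```

The design trades memory for simplicity; if the data do not fit in RAM, the sequential approach above (or one file per sample) remains the practical option.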

Model training
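The nets.mobilenet module is not shown. Whatever network is used, its first convolution has to be built with in_channels equal to the number of stacked frames (40 here) rather than 3. `TinyDepthNet` below is a hypothetical minimal stand-in, not the author's MobileNet, that illustrates this on a batch shaped like the DataLoader output.

```python
import torch
import torch.nn as nn

class TinyDepthNet(nn.Module):
    """Hypothetical small CNN for 40-channel input (not the author's mobilenet)."""
    def __init__(self, in_channels=40, num_classes=2):
        super().__init__()
        self.features = nn.Sequential(
            # First layer takes the 40 stacked frames as input channels.
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.features(x)
        return self.classifier(torch.flatten(x, 1))

model = TinyDepthNet()
scores = model(torch.randn(4, 40, 112, 112))  # (batch, channels, height, width)
print(scores.shape)  # torch.Size([4, 2])
```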

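import torch
from torch.utils import data
from nets.mobilenet import mobilenet
from config.config import default_config
# Assumption: the Fire_Unfire class above is saved as dataset.py (hypothetical module name)
from dataset import Fire_Unfire
opt = default_config()
def train():
  #Model definition
  model = mobilenet().cuda()
  if opt.pretrain_model:
    # Assumption: opt.pretrain_model is the path to a saved state_dict
    model.load_state_dict(torch.load(opt.pretrain_model))
  #Loss function
  criterion = torch.nn.CrossEntropyLoss().cuda()
  #Learning rate (assumption: the config stores it as opt.lr)
  lr = opt.lr
  optimizer = torch.optim.SGD(model.parameters(),lr = lr,weight_decay=opt.weight_decay)
  pre_loss = 0.0
  for epoch in range(opt.max_epoch):
    # Re-create the dataset every epoch so both pickle files are read again from the start
    train_data = Fire_Unfire(opt.root_path +'/fire_train.pkl',opt.root_path +'/unfire_train.pkl')
    train_dataloader = data.DataLoader(train_data,batch_size=opt.batch_size,shuffle=True,drop_last = True)
    model.train()
    loss_sum = 0.0
    for i,(datas,labels) in enumerate(train_dataloader):
      #Gradient clearing
      optimizer.zero_grad()
      # Inputs: float32 (batch, 40, height, width); targets: Long class indices
      inputs = datas.cuda().float()
      target = labels.cuda().long()
      score = model(inputs)
      loss = criterion(score,target)
      # .item() returns a Python float, so the autograd graph is not kept alive
      loss_sum += loss.item()
      #Back propagation
      loss.backward()
      #Gradient update
      optimizer.step()
    print('epoch %d, mean loss %.4f' % (epoch, loss_sum / len(train_dataloader)))
if __name__ == '__main__':
  train()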

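CrossEntropyLoss expects the targets to be class indices of type Long. If the labels reach the loss as floats, training fails with:

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target'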

Solution: convert the targets before computing the loss with target = target.long().
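A minimal illustration of the dtype requirement with random logits: CrossEntropyLoss takes logits of shape (batch, num_classes) and class-index targets of shape (batch,) with dtype torch.long (int64). Integer labels such as label = 1 are collated into an int64 tensor by the default DataLoader collate function, so the error typically appears only when the targets were turned into floats somewhere along the way.

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
scores = torch.randn(4, 2)           # (batch, num_classes) logits
labels = torch.tensor([1, 0, 0, 1])  # Python ints -> int64 tensor
print(labels.dtype)                  # torch.int64

float_labels = labels.float()        # e.g. after an accidental .float()
# Convert back to Long class indices before computing the loss.
loss = criterion(scores, float_labels.long())
print(loss.dim())                    # 0 (a scalar mean loss)
```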

That is all for this PyTorch implementation of training on input data with more than three channels. I hope it serves as a useful reference, and thank you for supporting developpaer.
