PyTorch: a simple GAN example (MNIST dataset)

Time: 2021-4-6

Without further ado, let's go straight to the code!

# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
 
import torch
from torch import nn
from torch.autograd import Variable
 
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
 
import numpy as np
 
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
 
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set figure size
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
 
def show_images(images): # helper that plots a grid of images
  images = np.reshape(images, [images.shape[0], -1])
  sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
  sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
 
  fig = plt.figure(figsize=(sqrtn, sqrtn))
  gs = gridspec.GridSpec(sqrtn, sqrtn)
  gs.update(wspace=0.05, hspace=0.05)
 
  for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))
  return 
  
def preprocess_img(x):
  x = tfs.ToTensor()(x)
  return (x - 0.5) / 0.5
 
def deprocess_img(x):
  return (x + 1.0) / 2.0
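# Quick sanity check (my own illustrative addition, not in the original post): the two
# helpers above are exact inverses. preprocess_img maps pixel values from [0, 1] to
# [-1, 1] to match the generator's Tanh output, and deprocess_img maps them back.
from PIL import Image  # torchvision already depends on PIL, so this import is available

demo = Image.fromarray(np.zeros((28, 28), dtype=np.uint8))  # an all-black 28x28 image
x_demo = preprocess_img(demo)
print(x_demo.min().item(), x_demo.max().item())  # -1.0 -1.0: black pixels map to -1
print(deprocess_img(x_demo).min().item())        # 0.0: back in the [0, 1] display range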
 
class ChunkSampler(sampler.Sampler): # sampler that walks a sequential chunk of the dataset
  """Samples elements sequentially from some offset. 
  Arguments:
    num_samples: # of desired datapoints
    start: offset where we should start selecting from
  """
  def __init__(self, num_samples, start=0):
    self.num_samples = num_samples
    self.start = start
 
  def __iter__(self):
    return iter(range(self.start, self.start + self.num_samples))
 
  def __len__(self):
    return self.num_samples
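# For example (illustrative only, not in the original code), ChunkSampler(5, start=10)
# yields indices 10 through 14; below, two such samplers carve the 60,000 MNIST training
# images into a 50,000-image training slice and a 5,000-image validation slice.
print(list(ChunkSampler(5, start=10)))  # [10, 11, 12, 13, 14]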
    
NUM_TRAIN = 50000
NUM_VAL = 5000
 
NOISE_DIM = 96
batch_size = 128
 
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
 
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
 
imgs = deprocess_img(next(iter(train_data))[0].view(batch_size, 784)).numpy().squeeze() # visualize a batch of real images
show_images(imgs)
 
# Discriminator network
def discriminator():
  net = nn.Sequential(    
      nn.Linear(784, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 1)
    )
  return net
  
# Generator network
def generator(noise_dim=NOISE_DIM):  
  net = nn.Sequential(
    nn.Linear(noise_dim, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 784),
    nn.Tanh()
  )
  return net
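# Quick shape check (my own addition, not in the original script): the generator maps a
# batch of 96-dimensional noise vectors to flattened 28x28 images, and the discriminator
# maps those 784-dimensional images to a single real/fake logit per example.
G_check = generator()   # fresh, untrained networks are enough for a shape check
D_check = discriminator()
z_check = torch.randn(4, NOISE_DIM)
print(G_check(z_check).shape)           # torch.Size([4, 784])
print(D_check(G_check(z_check)).shape)  # torch.Size([4, 1])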
  
# The discriminator's loss pushes its scores toward 1 on real data and toward 0 on fake data; the generator's loss pushes the discriminator's scores on fake data toward 1.
 
bce_loss = nn.BCEWithLogitsLoss() # binary cross-entropy loss on raw logits
 
def discriminator_loss(logits_real, logits_fake): # discriminator loss
  size = logits_real.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  false_labels = Variable(torch.zeros(size, 1)).float()
  loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
  return loss
  
def generator_loss(logits_fake): # generator loss
  size = logits_fake.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  loss = bce_loss(logits_fake, true_labels)
  return loss
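# Toy check of the two losses (illustrative numbers of my own, not from the original
# post): a confident, correct discriminator (large positive logits on real data, large
# negative logits on fake data) gets a near-zero discriminator_loss, while generator_loss
# is small only when the discriminator is fooled into scoring fake data highly.
good_real = torch.full((4, 1), 10.0)   # discriminator is confident these are real
good_fake = torch.full((4, 1), -10.0)  # ...and confident these are fake
print(discriminator_loss(good_real, good_fake).item())  # ~0: discriminator is winning
print(generator_loss(good_fake).item())                 # ~10: generator is losing badly
print(generator_loss(good_real).item())                 # ~0: generator has fooled D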
  
# Train both networks with Adam: learning rate 3e-4, beta1 = 0.5, beta2 = 0.999
def get_optimizer(net):
  optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
  return optimizer
  
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, 
        noise_size=96, num_epochs=10):
  iter_count = 0
  for epoch in range(num_epochs):
    for x, _ in train_data:
      bs = x.shape[0]
      # Train the discriminator
      real_data = Variable(x).view(bs, -1) # real data
      logits_real = D_net(real_data) # discriminator scores on real data
      
      sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # uniform noise in (-1, 1)
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # generated fake data
      logits_fake = D_net(fake_images) # discriminator scores on fake data
 
      d_total_error = discriminator_loss(logits_real, logits_fake) # discriminator loss
      D_optimizer.zero_grad()
      d_total_error.backward()
      D_optimizer.step() # update the discriminator
      
      # Train the generator
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # generated fake data
 
      gen_logits_fake = D_net(fake_images)
      g_error = generator_loss(gen_logits_fake) # generator loss
      G_optimizer.zero_grad()
      g_error.backward()
      G_optimizer.step() # update the generator
 
      if (iter_count % show_every == 0):
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
        imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
        show_images(imgs_numpy[0:16])
        plt.show()
        print()
      iter_count += 1
 
D = discriminator()
G = generator()
 
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
 
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)
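
# After training, new digits can be sampled by pushing fresh noise through the generator.
# This snippet is a usage sketch of my own and is not part of the original script.
with torch.no_grad():
  z = (torch.rand(16, NOISE_DIM) - 0.5) / 0.5 # uniform noise in (-1, 1), as in training
  samples = G(z)
show_images(deprocess_img(samples).numpy())
plt.show()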

That's the complete implementation of a simple GAN example (MNIST dataset) in PyTorch. I hope it gives you a useful reference.