PyTorch implementation of the Transformer


This article is reprinted,

This article mainly introduces how to use pytorch to reproduce transformer to realize simple machine translation tasks. For a detailed introduction to transformer, please refer to this articleTransformer details

Transformer structure

Pytoch implementation of transformer

Data preprocessing

I did not use any large data set here, but manually input two pairs of German → English sentences, and I manually hard coded the index of each word, mainly to reduce the difficulty of code reading. I hope readers can pay more attention to the part of model implementation

import math
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data  # was truncated to "import as Data" in the source

# Run on GPU when available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# S: Symbol that marks the start of the decoder input
# E: Symbol that marks the end of the decoder output
# P: Symbol that pads a sequence when it is shorter than the fixed number of time steps
sentences = [
    # enc_input           dec_input         dec_output
    ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
    ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
]  # closing bracket was missing in the source

# Padding index must be zero so it can be ignored/masked everywhere.
src_vocab = {'P': 0, 'ich': 1, 'mochte': 2, 'ein': 3, 'bier': 4, 'cola': 5}
src_vocab_size = len(src_vocab)

tgt_vocab = {'P': 0, 'i': 1, 'want': 2, 'a': 3, 'beer': 4, 'coke': 5, 'S': 6, 'E': 7, '.': 8}
# Reverse mapping used to print predicted indices as words.
idx2word = {i: w for i, w in enumerate(tgt_vocab)}
tgt_vocab_size = len(tgt_vocab)

src_len = 5  # enc_input max sequence length
tgt_len = 6  # dec_input (= dec_output) max sequence length

def make_data(sentences):
    """Convert the word sentences into index tensors.

    :param sentences: list of [enc_input, dec_input, dec_output] string triples
    :return: (enc_inputs, dec_inputs, dec_outputs) LongTensors,
             shapes [n, src_len], [n, tgt_len], [n, tgt_len]
    """
    enc_inputs, dec_inputs, dec_outputs = [], [], []
    for i in range(len(sentences)):
        enc_input = [src_vocab[n] for n in sentences[i][0].split()]   # e.g. [1, 2, 3, 4, 0]
        dec_input = [tgt_vocab[n] for n in sentences[i][1].split()]   # e.g. [6, 1, 2, 3, 4, 8]
        dec_output = [tgt_vocab[n] for n in sentences[i][2].split()]  # e.g. [1, 2, 3, 4, 8, 7]
        # The original loop never collected the rows, returning empty tensors.
        enc_inputs.append(enc_input)
        dec_inputs.append(dec_input)
        dec_outputs.append(dec_output)
    return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)

# Build the index tensors once for the whole (tiny) dataset.
enc_inputs, dec_inputs, dec_outputs = make_data(sentences)

class MyDataSet(Data.Dataset):
    """Dataset yielding (enc_input, dec_input, dec_output) triples for the DataLoader."""

    def __init__(self, enc_inputs, dec_inputs, dec_outputs):
        super(MyDataSet, self).__init__()
        self.enc_inputs = enc_inputs
        self.dec_inputs = dec_inputs
        self.dec_outputs = dec_outputs

    def __len__(self):
        # One sample per source sentence.
        return len(self.enc_inputs)

    def __getitem__(self, idx):
        # Return the idx-th (encoder input, decoder input, decoder target) triple.
        return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]

# Batch size 2 (the whole toy dataset), reshuffled every epoch.
loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), 2, True)

model parameter

The meanings of the following variables are

  • Word embedding & the dimension of position embedding. These two values are the same, so just use a variable
  • Number of hidden neurons in feedforward layer
  • Q. The dimensions of K and V vectors, where Q and K must be equal, and the dimension of V is not limited, but I set it to 64 for convenience
  • Number of encoders and decoders
  • Number of heads in multi-head attention

# Transformer Parameters
d_model = 512  # Embedding size (word embedding and positional embedding share this dimension)
d_ff = 2048  # FeedForward hidden-layer dimension
d_k = d_v = 64  # dimension of K (= Q) and of V, per attention head
n_layers = 6  # number of Encoder layers and of Decoder layers
n_heads = 8  # number of heads in Multi-Head Attention

The above is relatively simple, and the model involved below is more complex. Therefore, I will split the model into the following parts for explanation

  • Positional Encoding
  • Pad mask (pad is added because the sentence is not long enough, so pad needs to be masked)
  • Subsequence mask (decoder input cannot see the word information of the future time, so mask is required)
  • Scaleddotproductattention (calculate context vector)
  • Multi-Head Attention
  • FeedForward Layer
  • Encoder Layer
  • Encoder
  • Decoder Layer
  • Decoder
  • Transformer
    For the comments in the code: when I am sure a length is `src_len` or `tgt_len` I write it explicitly; but some functions and classes can be called from either the encoder or the decoder, so when I cannot tell whether the length is `src_len` or `tgt_len` I write it as `seq_len`.

Positional Encoding

According to the formula:

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to the token embeddings.

    PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)  # [max_len, d_model]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1], the pos column
        # div_term: [d_model/2]; equals 1 / 10000^(2i/d_model), computed in log space for stability
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions  [max_len, d_model/2]
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions   [max_len, d_model/2]
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]
        # Buffer, not Parameter: saved with the model but excluded from gradients.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        :param x: [seq_len, batch_size, d_model]
        :return: x plus the positional encoding for its first seq_len positions
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


def get_attn_pad_mask(seq_q, seq_k):
    """Build a mask that hides attention scores aimed at padding (index 0) tokens.

    :param seq_q: [batch_size, seq_len]
    :param seq_k: [batch_size, seq_len]
    seq_len could be src_len or tgt_len; seq_q and seq_k lengths may differ.
    :return: [batch_size, len_q, len_k] boolean tensor, True where seq_k is PAD.
    """
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(0) marks PAD positions: e.g. seq_k = [1, 2, 3, 4, 0] -> [F, F, F, F, T]
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k], True is masked
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]

Mask operations are required in both the encoder and the decoder, so this function cannot know the value of seq_len in advance.
If it is invoked in the Encoder, seq_len equals src_len;
if it is invoked in the Decoder, seq_len could equal either src_len
or tgt_len (because the decoder applies two different masks).

The core of this function returns a tensor with the same size as seq_k, except that its values are only True and False. If a position in seq_k equals 0, the corresponding position is True, otherwise False. For example, the input seq_data = [1, 2, 3, 4, 0] returns [False, False, False, False, True]. True means that position must be masked.

The remaining code is mainly to extend the dimension, which is strongly recommendedReaders print it out and see what the final returned data looks like

Subsequence Mask

def get_attn_subsequence_mask(seq):
    """Upper-triangular mask hiding future positions from the decoder.

    :param seq: [batch_size, tgt_len]
    :return: [batch_size, tgt_len, tgt_len] byte tensor, 1 strictly above the diagonal.
    """
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    # k=1 keeps the diagonal visible: each position may attend to itself and the past.
    subsequence_mask = np.triu(np.ones(attn_shape), k=1)
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()
    return subsequence_mask  # [batch_size, tgt_len, tgt_len]

Subsequence mask onlyDecoderWill be used. Its main function is to shield the information of words in the future.
First passnp.ones()Generate a square matrix of all 1, and thennp.triu()Generate an upper triangular matrix, K represents moving up a diagonal. The following figure shows the usage of NP. Triu()



class ScaledDotProductAttention(nn.Module):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V."""

    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        """
        :param Q: [batch_size, n_heads, len_q, d_k]
        :param K: [batch_size, n_heads, len_k, d_k]
        :param V: [batch_size, n_heads, len_v(=len_k), d_v]
        :param attn_mask: [batch_size, n_heads, seq_len, seq_len], True is masked
        :return: (context [batch_size, n_heads, len_q, d_v], attn weights)
        """
        # Scale by sqrt of the head dimension, read from Q itself rather than
        # the module-level constant, so the layer works for any head size.
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(Q.size(-1))  # [batch_size, n_heads, len_q, len_k]
        # Large negative fill makes masked positions get ~0 weight after softmax.
        scores.masked_fill_(attn_mask, -1e9)
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)  # [batch_size, n_heads, len_q, d_v]
        return context, attn

What we have to do here is, throughQandKCalculate the scores, and thenscoresandVMultiply to get the context vector of each word

The first step is to multiply the transpose of Q and K. there’s nothing to say. After multiplying, we get itscoresSoftmax cannot be performed immediately, andattn_maskAdd up and block out some information that needs to be blocked,attn_maskIt is a tensor only composed of true and false, and it will be guaranteedattn_maskandscoresThe four values of the dimension are the same (otherwise, the corresponding position cannot be added)

After the mask is finished, you canscoresSoftmax. Then withVMultiply and getcontext


class MultiHeadAttention(nn.Module):
    """Multi-head attention with a residual connection and layer normalization."""

    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
        # Created once here (the original built a fresh nn.LayerNorm on every
        # forward call, so its affine parameters were never trained).
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        """
        :param input_Q: [batch_size, len_q, d_model]
        :param input_K: [batch_size, len_k, d_model]
        :param input_V: [batch_size, len_v(=len_k), d_model]
        :param attn_mask: [batch_size, seq_len, seq_len]
        :return: (output [batch_size, len_q, d_model], attn weights)
        """
        residual, batch_size = input_Q, input_Q.size(0)
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # [batch_size, n_heads, len_q, d_k]
        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # [batch_size, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # [batch_size, n_heads, len_v(=len_k), d_v]

        # Replicate the mask for every head.
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)  # [batch_size, n_heads, seq_len, seq_len]

        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)
        # Concatenate the heads back together: [batch_size, len_q, n_heads * d_v]
        context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v)
        output = self.fc(context)  # [batch_size, len_q, d_model]

        # Residual connection followed by layer normalization.
        return self.layer_norm(output + residual), attn

There are three calls to MultiHeadAttention() in the complete code. The encoder layer calls it once, passing enc_inputs as input_Q, input_K and input_V. The decoder layer calls it twice: the first time all three inputs are dec_inputs; the second time they are dec_outputs, enc_outputs and enc_outputs respectively.

FeedForward Layer

class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward network: two linear layers with a ReLU in
    between, followed by a residual connection and layer normalization."""

    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        # The source was truncated to a single Linear(d_ff, d_model); the
        # standard block is d_model -> d_ff -> ReLU -> d_model.
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False)
        )
        # Registered here (not rebuilt per forward) so its parameters are trained.
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):
        """
        :param inputs: [batch_size, seq_len, d_model]
        :return: [batch_size, seq_len, d_model]
        """
        residual = inputs
        output = self.fc(inputs)
        return self.layer_norm(output + residual)

This code is very simple, that is, do two linear transformations, connect the residuals and then follow a layer norm

Encoder Layer

class EncoderLayer(nn.Module):
    """One encoder block: self-attention followed by the position-wise FFN."""

    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        """
        :param enc_inputs: [batch_size, src_len, d_model]
        :param enc_self_attn_mask: [batch_size, src_len, src_len]
        :return: (enc_outputs [batch_size, src_len, d_model], attn weights)
        """
        # Self-attention: enc_inputs serves as Q, K and V.
        # enc_outputs: [batch_size, src_len, d_model], attn: [batch_size, n_heads, src_len, src_len]
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)  # [batch_size, src_len, d_model]
        return enc_outputs, attn

Put the above components together to form a complete encoder layer


usenn.ModuleList(), the parameters in the list are lists, which are stored in the listn_layersMultiple encoder layers
Since we control the input and output dimensions of the encoder layer to be the same,
Therefore, you can directly use a for loop to nest the output of the last encoder layer as the input of the next encoder layer

class Encoder(nn.Module):
    """Stack of n_layers encoder layers on top of token + positional embeddings."""

    def __init__(self):
        super(Encoder, self).__init__()
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])

    def forward(self, enc_inputs):
        """
        :param enc_inputs: [batch_size, src_len]
        :return: (enc_outputs [batch_size, src_len, d_model], list of per-layer attn)
        """
        enc_outputs = self.src_emb(enc_inputs)  # [batch_size, src_len, d_model]
        # PositionalEncoding expects [seq_len, batch_size, d_model], hence the transposes.
        enc_outputs = self.pos_emb(enc_outputs.transpose(0, 1)).transpose(0, 1)  # [batch_size, src_len, d_model]
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # [batch_size, src_len, src_len]
        enc_self_attns = []
        for layer in self.layers:
            # enc_outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len]
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            # The original loop dropped these; collect one attention map per layer.
            enc_self_attns.append(enc_self_attn)
        return enc_outputs, enc_self_attns

Decoder Layer

class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder-decoder attention, FFN."""

    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        """
        :param dec_inputs: [batch_size, tgt_len, d_model]
        :param enc_outputs: [batch_size, src_len, d_model]
        :param dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        :param dec_enc_attn_mask: [batch_size, tgt_len, src_len]
        :return: (dec_outputs, dec_self_attn, dec_enc_attn)
        """
        # Masked self-attention over the decoder input.
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        # Encoder-decoder attention: Q from the decoder, K and V from the encoder.
        # dec_outputs: [batch_size, tgt_len, d_model], dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs)  # [batch_size, tgt_len, d_model]
        return dec_outputs, dec_self_attn, dec_enc_attn

MultiHeadAttention is called twice in the decoder layer. The first call computes self-attention over the decoder input, producing dec_outputs. Then dec_outputs is used to generate Q while enc_outputs generates K and V, and MultiHeadAttention is called again to obtain the context vector between the encoder and the decoder layer. Finally dec_outputs goes through the feed-forward network and is returned.


class Decoder(nn.Module):
    """Stack of n_layers decoder layers with target embeddings and both masks."""

    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        """
        :param dec_inputs: [batch_size, tgt_len]
        :param enc_inputs: [batch_size, src_len]
        :param enc_outputs: [batch_size, src_len, d_model]
        :return: (dec_outputs, per-layer self-attns, per-layer enc-dec attns)
        """
        dec_outputs = self.tgt_emb(dec_inputs)  # [batch_size, tgt_len, d_model]
        dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).transpose(0, 1).to(device)  # [batch_size, tgt_len, d_model]

        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).to(device)  # [batch_size, tgt_len, tgt_len]
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).to(device)  # [batch_size, tgt_len, tgt_len]
        # torch.gt(a, 0): a position is masked if EITHER the pad mask OR the
        # subsequence mask flags it (their sum is > 0).
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequence_mask), 0).to(device)  # [batch_size, tgt_len, tgt_len]

        # Encoder-decoder attention only needs to hide padding in the source.
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)  # [batch_size, tgt_len, src_len]

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            # dec_outputs: [batch_size, tgt_len, d_model]
            # dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
            # dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
            # The original loop dropped these; collect one map per layer.
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns

In the decoder, not only must the pad positions be masked, but the information from future time steps must also be hidden; hence the following three lines of code. torch.gt(a, value) compares each element of a with value: if it is greater than value, that position becomes 1, otherwise 0.

dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)  # [batch_size, tgt_len, tgt_len]
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs)  # [batch_size, tgt_len, tgt_len]
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequence_mask), 0)  # [batch_size, tgt_len, tgt_len]


class Transformer(nn.Module):
    """Full encoder-decoder Transformer with a final projection to the target vocabulary."""

    def __init__(self):
        # Was missing in the source: nn.Module must be initialized before
        # submodules can be assigned as attributes.
        super(Transformer, self).__init__()
        self.encoder = Encoder().to(device)
        self.decoder = Decoder().to(device)
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False).to(device)

    def forward(self, enc_inputs, dec_inputs):
        """
        :param enc_inputs: [batch_size, src_len]
        :param dec_inputs: [batch_size, tgt_len]
        :return: (logits [batch_size * tgt_len, tgt_vocab_size],
                  enc_self_attns, dec_self_attns, dec_enc_attns)
        """
        # enc_outputs: [batch_size, src_len, d_model]
        # enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)

        # dec_outputs: [batch_size, tgt_len, d_model]
        # dec_self_attns: [n_layers, batch_size, n_heads, tgt_len, tgt_len]
        # dec_enc_attns: [n_layers, batch_size, n_heads, tgt_len, src_len]
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)

        dec_logits = self.projection(dec_outputs)  # [batch_size, tgt_len, tgt_vocab_size]
        # Flatten so CrossEntropyLoss can compare against [batch_size * tgt_len] targets.
        return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns

The Transformer mainly calls the encoder and the decoder. Finally it returns dec_logits with dimension [batch_size * tgt_len, tgt_vocab_size], which can be understood as batch_size * tgt_len words in total, each with tgt_vocab_size candidate scores; we take the one with the greatest probability.

Model & Loss Function & optimizer

model = Transformer().to(device)
# ignore_index=0: the PAD token (index 0) contributes nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(),lr=1e-3,momentum=0.99)

I set the parameter ignore_index=0 in the loss function here. Because the index of the pad token 'P' is 0, the loss at pad positions is simply not computed (pad carries no meaning, so it needs no loss).


for epoch in range(30):
    for enc_inputs, dec_inputs, dec_outputs in loader:
        # enc_inputs: [batch_size, src_len]
        # dec_inputs: [batch_size, tgt_len]
        # dec_outputs: [batch_size, tgt_len]
        enc_inputs, dec_inputs, dec_outputs = enc_inputs.to(device), dec_inputs.to(device), dec_outputs.to(device)
        # outputs: [batch_size * tgt_len, tgt_vocab_size]
        outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
        loss = criterion(outputs, dec_outputs.view(-1))
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
        # Standard backprop step — missing from the truncated source, without
        # which the model never learns.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()



def greedy_decoder(model, enc_input, start_symbol):
    """Greedy decoding — equivalent to beam search with K=1.

    At inference time the target sequence is unknown, so we generate it one
    word at a time, feeding each prediction back into the decoder, until the
    '.' token is produced.

    :param model: Transformer model
    :param enc_input: The encoder input, [1, src_len]
    :param start_symbol: Index of the start token ('S' in this vocabulary)
    :return: The generated decoder input, [1, generated_len]
    """
    enc_outputs, enc_self_attns = model.encoder(enc_input)
    # Start with an empty (length-0) decoder input of the same dtype as enc_input.
    dec_input = torch.zeros(1, 0).type_as(enc_input.data)
    terminal = False
    next_symbol = start_symbol
    while not terminal:
        # Append the latest predicted symbol to the decoder input.
        dec_input = torch.cat([dec_input.detach(), torch.tensor([[next_symbol]], dtype=enc_input.dtype, device=device)], -1)
        dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
        projected = model.projection(dec_outputs)
        # Most likely token at every position; keep only the last one.
        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
        next_word = prob.data[-1]
        next_symbol = next_word
        if next_symbol == tgt_vocab["."]:
            terminal = True
    return dec_input

# Test: greedily decode each source sentence and print the predicted words.
enc_inputs, _, _ = next(iter(loader))
enc_inputs = enc_inputs.to(device)
for i in range(len(enc_inputs)):
    greedy_dec_input = greedy_decoder(model, enc_inputs[i].view(1, -1), start_symbol=tgt_vocab["S"])
    predict, _, _, _ = model(enc_inputs[i].view(1, -1), greedy_dec_input)
    # Index of the highest-scoring word at each target position.
    predict = predict.data.max(1, keepdim=True)[1]
    print(enc_inputs[i], '->', [idx2word[n.item()] for n in predict.squeeze()])