PyTorch text classification based on torchtext


By Dr. Vaibhav Kumar
Compiled by | VK
Source | Analytics India Magazine

Text classification is one of the important applications of natural language processing. There are many methods for classifying text in machine learning; however, most of them require a lot of preprocessing and computing resources. In this article, we use PyTorch to classify text into multiple classes, because it offers the following advantages:

  • PyTorch provides a powerful way to implement complex model architectures and algorithms with relatively little preprocessing and modest computing resources (including execution time).
  • The basic unit of PyTorch is the tensor. PyTorch also builds its computation graph at runtime, so the architecture can change during execution, and it supports distributing training across GPUs.
  • PyTorch has a companion library called torchtext, which contains utilities for preprocessing text and loaders for some popular NLP datasets.

In this article, we will demonstrate multi-class text classification using torchtext, a powerful natural language processing library for PyTorch.

For this classification, we use a model consisting of an embedding bag layer (nn.EmbeddingBag) followed by a linear layer. nn.EmbeddingBag handles variable-length text entries by computing the mean of the embeddings of each entry's tokens.
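To make the model's computation concrete, here is a minimal pure-Python sketch (no PyTorch needed) of what an embedding bag in mean mode followed by a linear layer computes for a batch. The function names embedding_bag_mean and linear, the toy embedding table, and the weights are illustrative assumptions, not part of torchtext or the original code; in the real model these are learned tensors.

```python
# Pure-Python sketch of EmbeddingBag (mean mode) followed by a linear layer.
# Toy sizes and weights are made up for illustration only.

def embedding_bag_mean(table, flat_ids, offsets):
    """Average the embedding rows of each bag.

    flat_ids holds all token ids of the batch concatenated; offsets[i] is the
    position in flat_ids where bag i starts (the convention EmbeddingBag uses).
    """
    dim = len(table[0])
    ends = list(offsets[1:]) + [len(flat_ids)]
    bags = []
    for start, end in zip(offsets, ends):
        rows = [table[i] for i in flat_ids[start:end]]
        if not rows:
            # An empty bag yields a zero vector in this sketch
            bags.append([0.0] * dim)
            continue
        bags.append([sum(r[d] for r in rows) / len(rows) for d in range(dim)])
    return bags


def linear(x, weight, bias):
    """Compute W v + b for each pooled vector v (weight has one row per class)."""
    return [[sum(w * v for w, v in zip(row, vec)) + b
             for row, b in zip(weight, bias)] for vec in x]


# Toy vocabulary of 4 tokens with 2-dimensional embeddings
table = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [4.0, 0.0]]
flat_ids = [0, 1, 2, 3]  # bag 0 = tokens 0, 1, 2; bag 1 = token 3
offsets = [0, 3]

pooled = embedding_bag_mean(table, flat_ids, offsets)
print(pooled)  # [[1.0, 1.0], [4.0, 0.0]]

# Two toy classes: one score per class for each text
scores = linear(pooled, weight=[[1.0, 0.0], [0.0, 1.0]], bias=[0.0, 0.5])
print(scores)  # [[1.0, 1.5], [4.0, 0.5]]
```

Averaging makes each text a fixed-size vector regardless of its length, so one linear layer can score every text in the batch.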

This model will be trained on the DBpedia dataset, where the text belongs to 14 classes. After successful training, the model will predict the class label of the input text.

DBpedia dataset

DBpedia is a popular benchmark dataset in natural language processing. It contains text from 14 categories, such as companies, educational institutions, artists, and films.

DBpedia itself is structured content extracted from the information created in the Wikipedia project. The DBpedia dataset provided by torchtext has 630,000 text instances belonging to 14 classes: 560,000 training examples and 70,000 test examples.

Text classification with torchtext

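First, we install torchtext version 0.4. The text_classification module used below belongs to this older torchtext API and is not available in current releases.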

!pip install torchtext==0.4

After that, we’ll import all the required libraries.

import torch
import torchtext
from torchtext.datasets import text_classification
import os
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import time
from torch.utils.data.dataset import random_split
import re
from torchtext.data.utils import ngrams_iterator
from torchtext.data.utils import get_tokenizer

In the next step, we define the n-gram size and the batch size. N-gram features capture important information about local word order.

We use bigrams, so each text sample in the dataset becomes a list of single words plus bigram strings.
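Here is a minimal pure-Python sketch of the unigram-plus-bigram expansion described above. add_ngrams is a hypothetical helper written for illustration; it is intended to produce the same tokens as torchtext's ngrams_iterator (words first, then n-grams joined by a space), but it is not torchtext's implementation.

```python
# Hypothetical helper: expand a token list with all 2..n-grams.

def add_ngrams(tokens, n):
    """Return the tokens followed by all 2..n-grams, each joined with a space."""
    out = list(tokens)
    for size in range(2, n + 1):
        # Slide a window of `size` tokens over the list
        for i in range(len(tokens) - size + 1):
            out.append(" ".join(tokens[i:i + size]))
    return out


print(add_ngrams(["the", "white", "wooden", "church"], 2))
# ['the', 'white', 'wooden', 'church', 'the white', 'white wooden', 'wooden church']
```

With n = 2, a text of k words yields k unigrams plus k - 1 bigrams, so phrases such as "wooden church" get their own vocabulary entries.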


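Now, we set these values and read the DBpedia dataset provided by torchtext.

NGRAMS = 2        # bigrams, as described above
BATCH_SIZE = 16   # assumed value; the original does not state the batch size

if not os.path.isdir('./.data'):
    os.mkdir('./.data')
train_dataset, test_dataset = text_classification.DATASETS['DBpedia'](
    root='./.data', ngrams=NGRAMS, vocab=None)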

After downloading the dataset, we verify the sizes of the training and test sets and the number of class labels.

print(len(train_dataset), len(test_dataset))
print(len(train_dataset.get_labels()))

We use a CUDA GPU, if one is available, to speed up training and evaluation; otherwise, the code falls back to the CPU.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In the next step, we will define the model for the classification.

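class TextSentiment(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # EmbeddingBag averages the embeddings of each text (mode="mean" by default)
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        # Initialize weights uniformly in [-0.5, 0.5] and the bias to zero
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        # text: 1-D tensor of all token ids in the batch; offsets: start of each text
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)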


Now, we initialize the model with its hyperparameters and define a function that generates training batches.

VOCAB_SIZE = len(train_dataset.get_vocab())
EMBED_DIM = 32  # assumed value; the original does not state the embedding size
NUM_CLASS = len(train_dataset.get_labels())
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)

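def generate_batch(batch):
    # Each entry is a (label, token-id tensor) pair
    label = torch.tensor([entry[0] for entry in batch])
    text = [entry[1] for entry in batch]
    # Offsets mark where each text starts in the concatenated token tensor
    offsets = [0] + [len(entry) for entry in text]
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    # Concatenate all texts into one 1-D tensor, as EmbeddingBag expects
    text = torch.cat(text)
    return text, offsets, label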
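The offsets logic in generate_batch can be checked without PyTorch. The sketch below uses a hypothetical helper, flatten_with_offsets, that mirrors the same steps on plain Python lists: prepend 0 to the list of lengths, drop the last element, take a running sum, and concatenate the entries.

```python
from itertools import accumulate


def flatten_with_offsets(entries):
    """Hypothetical helper mirroring generate_batch on plain lists.

    Returns the concatenated token ids and the start position of each entry.
    """
    lengths = [len(entry) for entry in entries]
    # Same steps as generate_batch: [0] + lengths, drop the last value, running sum
    offsets = list(accumulate(([0] + lengths)[:-1]))
    flat = [token for entry in entries for token in entry]
    return flat, offsets


batch_tokens = [[5, 8, 2], [7], [3, 3]]
flat, offsets = flatten_with_offsets(batch_tokens)
print(flat)     # [5, 8, 2, 7, 3, 3]
print(offsets)  # [0, 3, 4]
```

Entry i then occupies flat[offsets[i]:offsets[i + 1]], with the last entry running to the end, which is how EmbeddingBag splits its input into bags.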

In the next step, we will define the functions for training and testing the model.

def train_func(sub_train_):

    # Train the model
    train_loss = 0
    train_acc = 0
    data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
                      collate_fn=generate_batch)
    for i, (text, offsets, cls) in enumerate(data):
        optimizer.zero_grad()
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        output = model(text, offsets)
        loss = criterion(output, cls)
        train_loss += loss.item()
        # Backpropagate and update the parameters
        loss.backward()
        optimizer.step()
        train_acc += (output.argmax(1) == cls).sum().item()

    # Adjust the learning rate
    scheduler.step()

    return train_loss / len(sub_train_), train_acc / len(sub_train_)

def test(data_):
    test_loss = 0
    acc = 0
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    for text, offsets, cls in data:
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        with torch.no_grad():
            output = model(text, offsets)
            loss = criterion(output, cls)
            # Accumulate the batch loss into a separate running total
            test_loss += loss.item()
            acc += (output.argmax(1) == cls).sum().item()

    return test_loss / len(data_), acc / len(data_)
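The accuracy counter in train_func and test compares the arg-max of each row of scores with the true label. Here is a minimal pure-Python sketch of that bookkeeping, using a hypothetical count_correct helper and made-up scores for a batch of three texts and three classes.

```python
# Hypothetical helper: count how many arg-max predictions match the labels.

def count_correct(scores, labels):
    # Index of the largest score in each row (the predicted class)
    preds = [max(range(len(row)), key=row.__getitem__) for row in scores]
    return sum(p == y for p, y in zip(preds, labels))


batch_scores = [[0.1, 2.0, -1.0], [3.0, 0.0, 0.5], [0.2, 0.1, 0.9]]
batch_labels = [1, 2, 2]
correct = count_correct(batch_scores, batch_labels)
print(correct)                      # 2
print(correct / len(batch_labels))  # accuracy on this toy batch
```

In the functions above, the per-batch counts are summed and divided by the number of examples in the split, which gives the accuracy over the whole split.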

We will train the model for five epochs, holding out 5% of the training data for validation.

N_EPOCHS = 5
min_valid_loss = float('inf')

criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
# Multiply the learning rate by 0.9 each time scheduler.step() is called
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

# Use 95% of the training data for training and 5% for validation
train_len = int(len(train_dataset) * 0.95)
sub_train_, sub_valid_ = \
    random_split(train_dataset, [train_len, len(train_dataset) - train_len])

for epoch in range(N_EPOCHS):

    start_time = time.time()
    train_loss, train_acc = train_func(sub_train_)
    valid_loss, valid_acc = test(sub_valid_)

    secs = int(time.time() - start_time)
    mins = secs // 60
    secs = secs % 60

    print('Epoch: %d' %(epoch + 1), " | time in %d minutes, %d seconds" %(mins, secs))
    print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
    print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')
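Two quantities in the training loop can be worked out by hand. Assuming the full 560,000-example DBpedia training set, the 95/5 split gives the sizes below. StepLR with step_size=1 and gamma=0.9 multiplies the learning rate by 0.9 each time scheduler.step() is called, which is intended to happen once per epoch at the "Adjust learning rate" step of train_func. This is plain-Python arithmetic, not a call into PyTorch.

```python
# Assumed size of the torchtext DBpedia training set
TRAIN_SIZE = 560_000
train_len = int(TRAIN_SIZE * 0.95)
valid_len = TRAIN_SIZE - train_len
print(train_len, valid_len)  # 532000 28000

# Learning rate in effect during each of the 5 epochs:
# starts at 4.0 and is multiplied by 0.9 after every epoch
lr0, gamma = 4.0, 0.9
lrs = [lr0 * gamma ** epoch for epoch in range(5)]
print([round(lr, 4) for lr in lrs])  # [4.0, 3.6, 3.24, 2.916, 2.6244]
```

The large starting rate of 4.0 is workable here because the model is shallow, and the geometric decay shrinks the step size as training converges.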

Next, we will test our model on the test dataset and check the accuracy of the model.

print('Checking the results of test dataset...')
test_loss, test_acc = test(test_dataset)
print(f'\tLoss: {test_loss:.4f}(test)\t|\tAcc: {test_acc * 100:.1f}%(test)')

Now we’ll test our model on a single text string and predict its class label.

DBpedia_label = {0: 'Company',
                1: 'EducationalInstitution',
                2: 'Artist',
                3: 'Athlete',
                4: 'OfficeHolder',
                5: 'MeanOfTransportation',
                6: 'Building',
                7: 'NaturalPlace',
                8: 'Village',
                9: 'Animal',
                10: 'Plant',
                11: 'Album',
                12: 'Film',
                13: 'WrittenWork'}

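def predict(text, model, vocab, ngrams):
    tokenizer = get_tokenizer("basic_english")
    with torch.no_grad():
        # Tokenize, add n-grams, and map each token to its vocabulary index
        text = torch.tensor([vocab[token]
                            for token in ngrams_iterator(tokenizer(text), ngrams)])
        # A single text, so there is one bag starting at offset 0
        output = model(text, torch.tensor([0]))
        # Class indices are 0-based, matching the keys of DBpedia_label
        return output.argmax(1).item()

vocab = train_dataset.get_vocab()
model = model.to("cpu")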
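Because the training labels are 0-based (CrossEntropyLoss expects class indices from 0 to NUM_CLASS - 1), the arg-max index from the model can be used directly as a key into DBpedia_label. Here is a pure-Python sketch of that lookup with made-up scores; label_from_scores and DBPEDIA_CLASSES are illustrative names, not part of the original code.

```python
# Illustrative sketch: map one row of 14 class scores to a DBpedia label name.
DBPEDIA_CLASSES = ["Company", "EducationalInstitution", "Artist", "Athlete",
                   "OfficeHolder", "MeanOfTransportation", "Building",
                   "NaturalPlace", "Village", "Animal", "Plant", "Album",
                   "Film", "WrittenWork"]
DBpedia_label = dict(enumerate(DBPEDIA_CLASSES))  # keys 0..13


def label_from_scores(scores):
    # Arg-max index in 0..13; adding 1 here would shift every prediction by
    # one class and raise KeyError for class 13
    best = max(range(len(scores)), key=scores.__getitem__)
    return DBpedia_label[best]


toy_scores = [0.0] * 14
toy_scores[6] = 5.0  # pretend the model is most confident in class 6
print(label_from_scores(toy_scores))  # Building
```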

Now we’ll pick a few texts from the test data at random and examine the predicted class labels.

The first prediction:

ex_text_str = "Brekke Church (Norwegian: Brekke kyrkje) is a parish church in Gulen Municipality in Sogn og Fjordane county, Norway. It is located in the village of Brekke. The church is part of the Brekke parish in the Nordhordland deanery in the Diocese of Bjørgvin. The white, wooden church, which has 390 seats, was consecrated on 19 November 1862 by the local Dean Thomas Erichsen. The architect Christian Henrik Grosch made the designs for the church, which is the third church on the site."

print("This text belongs to %s class" %DBpedia_label[predict(ex_text_str, model, vocab, NGRAMS)])

The second prediction:

ex_text_str2 = "Cerithiella superba is a species of very small sea snail, a marine gastropod mollusk in the family Newtoniellidae. This species is known from European waters. It was described by Thiele, 1912."

print("This text belongs to %s class" %DBpedia_label[predict(ex_text_str2, model, vocab, NGRAMS)])

The third prediction:

ex_text_str3 = "  Nithari is a village in the western part of the state of Uttar Pradesh India bordering on New Delhi. Nithari forms part of the New Okhla Industrial Development Authority's planned industrial city Noida falling in Sector 31. Nithari made international news headlines in December 2006 when the skeletons of a number of apparently murdered women and children were unearthed in the village."

print("This text belongs to %s class" %DBpedia_label[predict(ex_text_str3, model, vocab, NGRAMS)])

In this way, we used torchtext to implement multi-class text classification.

This is a simple and easy method of text classification: with this PyTorch library, only a small amount of preprocessing is needed. On the 560,000 training examples, it took less than 5 minutes to train the model.

Try rerunning the code with NGRAMS changed from 2 to 3 to see whether the results improve. The same approach can be applied to the other datasets provided by torchtext.


