Implementing a multi-output, parameter-sharing model with keras_bert

Time:2021-4-14

Background

In the field of NLP, the pre-trained model BERT is very popular.

However, most of the frameworks one can find are written in Python, and most of them are single-output models.

This article therefore uses Keras to design a multi-output, parameter-sharing model, motivated by multi-level label classification in which the label levels are interrelated.

Basic usage of keras_bert

def batch_iter(data_path, cat_to_id, tokenizer, batch_size=64, shuffle=True):
    'generate batch data'
    keras_bert_iter = get_keras_bert_iterator(data_path, cat_to_id, tokenizer)
    while True:
        data_list = []
        for _ in range(batch_size):
            data = next(keras_bert_iter)
            data_list.append(data)
        if shuffle:
            random.shuffle(data_list)
        
        indices_list = []
        segments_list = []
        label_index_list = []
        for data in data_list:
            indices, segments, label_index = data
            indices_list.append(indices)
            segments_list.append(segments)
            label_index_list.append(label_index)

        yield [np.array(indices_list), np.array(segments_list)], np.array(label_index_list)

def get_model(label_list):
    K.clear_session()
    
    bert_model = load_trained_model_from_checkpoint(bert_paths.config, bert_paths.checkpoint, seq_len=text_max_length)  # load the pre-trained model
 
    for l in bert_model.layers:
        l.trainable = True
 
    input_indices = Input(shape=(None,))
    input_segments = Input(shape=(None,))
 
    bert_output = bert_model([input_indices, input_segments])
    bert_cls = Lambda(lambda x: x[:, 0])(bert_output)  # take the vector corresponding to [CLS] for classification
    output = Dense(len(label_list), activation='softmax')(bert_cls)
 
    model = Model([input_indices, input_segments], output)
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=Adam(1e-5),  # use a sufficiently small learning rate
                  metrics=['accuracy'])
    print(model.summary())
    return model

early_stopping = EarlyStopping(monitor='val_acc', patience=3)  # early stopping to prevent overfitting
plateau = ReduceLROnPlateau(monitor="val_acc")  # reduce the learning rate when the monitored metric stops improving
checkpoint = ModelCheckpoint('trained_model/keras_bert_THUCNews.hdf5', monitor='val_acc', verbose=2, save_best_only=True, mode='max', save_weights_only=True)  # save the best model
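
Because save_weights_only=True stores only the weights, the architecture has to be rebuilt before the best checkpoint can be restored. A minimal sketch, reusing get_model and categories from this article:

# Rebuild the model and restore the best weights saved by ModelCheckpoint (a sketch).
best_model = get_model(categories)
best_model.load_weights('trained_model/keras_bert_THUCNews.hdf5')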

def get_step(sample_count, batch_size):
    step = sample_count // batch_size
    if sample_count % batch_size != 0:
        step += 1
    return step

batch_size = 4
train_step = get_step(train_sample_count, batch_size)
dev_step = get_step(dev_sample_count, batch_size)

train_dataset_iterator = batch_iter(r"data/keras_bert_train.txt", cat_to_id, tokenizer, batch_size)
dev_dataset_iterator = batch_iter(r"data/keras_bert_dev.txt", cat_to_id, tokenizer, batch_size)

model = get_model(categories)

#Model training
model.fit(
    train_dataset_iterator,
    steps_per_epoch=train_step,
    epochs=10,
    validation_data=dev_dataset_iterator,
    validation_steps=dev_step,
    callbacks=[early_stopping, plateau, checkpoint],
    verbose=1
)
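
After training, the same batch_iter can feed the held-out split for evaluation. A minimal sketch, assuming the test file and test_sample_count produced in the appendix:

# Evaluate on the test split (a sketch; names follow the appendix code).
test_step = get_step(test_sample_count, batch_size)
test_dataset_iterator = batch_iter(r"data/keras_bert_test.txt", cat_to_id, tokenizer, batch_size)
test_loss, test_acc = model.evaluate(test_dataset_iterator, steps=test_step)
print(test_loss, test_acc)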

Designing the multi-output, parameter-sharing model

def batch_iter(data_path, cat_to_id, tokenizer, second_label_list, batch_size=64, shuffle=True):
    'generate batch data'
    keras_bert_iter = get_keras_bert_iterator(data_path, cat_to_id, tokenizer, second_label_list)
    while True:
        data_list = []
        for _ in range(batch_size):
            data = next(keras_bert_iter)
            data_list.append(data)
        if shuffle:
            random.shuffle(data_list)
        
        indices_list = []
        segments_list = []
        label_index_list = []
        second_label_index_list = []  # renamed so it does not shadow the second_label_list argument
        for data in data_list:
            indices, segments, label_index, second_label = data
            indices_list.append(indices)
            segments_list.append(segments)
            label_index_list.append(label_index)
            second_label_index_list.append(second_label)

        yield [np.array(indices_list), np.array(segments_list)], [np.array(label_index_list), np.array(second_label_index_list)]

def get_model(label_list, second_label_list):
    K.clear_session()
    
    bert_model = load_trained_model_from_checkpoint(bert_paths.config, bert_paths.checkpoint, seq_len=text_max_length)  # load the pre-trained model
 
    for l in bert_model.layers:
        l.trainable = True
 
    input_indices = Input(shape=(None,))
    input_segments = Input(shape=(None,))
 
    bert_output = bert_model([input_indices, input_segments])
    bert_cls = Lambda(lambda x: x[:, 0])(bert_output)  # take the vector corresponding to [CLS] for classification
    output = Dense(len(label_list), activation='softmax')(bert_cls)
    output_second = Dense(len(second_label_list), activation='softmax')(bert_cls)
 
    model = Model([input_indices, input_segments], [output, output_second])
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=Adam(1e-5),  # use a sufficiently small learning rate
                  metrics=['accuracy'])
    print(model.summary())
    return model
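
Keras also accepts one loss per output head together with loss_weights, which makes it possible to balance the two tasks. The compile call inside get_model could instead read as follows (the 1.0/0.5 weighting is purely illustrative):

    # Alternative compilation: one loss per output head, with hypothetical loss weights.
    model.compile(loss=['sparse_categorical_crossentropy', 'sparse_categorical_crossentropy'],
                  loss_weights=[1.0, 0.5],
                  optimizer=Adam(1e-5),
                  metrics=['accuracy'])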

batch_size = 4
train_step = get_step(train_sample_count, batch_size)
dev_step = get_step(dev_sample_count, batch_size)

train_dataset_iterator = batch_iter(r"data/keras_bert_train.txt", cat_to_id, tokenizer, second_label_list, batch_size)
dev_dataset_iterator = batch_iter(r"data/keras_bert_dev.txt", cat_to_id, tokenizer, second_label_list, batch_size)

model = get_model(categories, second_label_list)

#Model training
model.fit(
    train_dataset_iterator,
    steps_per_epoch=train_step,
    epochs=10,
    validation_data=dev_dataset_iterator,
    validation_steps=dev_step,
    callbacks=[early_stopping, plateau, checkpoint],
    verbose=1
)
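
At inference time the model returns one softmax array per head. A minimal sketch of decoding a single batch, using the dev_dataset_iterator defined above:

# Predict on one batch; the model yields a list with one array per output head (a sketch).
x_batch, _ = next(dev_dataset_iterator)
pred_main, pred_second = model.predict(x_batch)
main_ids = np.argmax(pred_main, axis=-1)      # indices into categories
second_ids = np.argmax(pred_second, axis=-1)  # indices into second_label_list
print(main_ids, second_ids)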

Appendix

All source code

import os
import sys
import re
from collections import Counter
import random

from tqdm import tqdm
import numpy as np
import tensorflow.keras as keras
from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer, get_checkpoint_paths
from keras_bert.layers import MaskedGlobalMaxPool1D
from keras.metrics import top_k_categorical_accuracy
from keras.layers import *
from keras.callbacks import *
from keras.models import Model
import keras.backend as K
from keras.optimizers import Adam
from keras.utils import to_categorical
data_path = "000_text_classifier_tensorflow_textcnn/THUCNews/"
text_max_length = 512
bert_paths = get_checkpoint_paths(r"chinese_L-12_H-768_A-12")
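
get_checkpoint_paths only assembles the paths to the three files inside the downloaded checkpoint directory; the attributes used below are .config, .checkpoint and .vocab. A quick sanity check (file names as in Google's Chinese BERT release):

print(bert_paths.config)      # .../bert_config.json
print(bert_paths.checkpoint)  # .../bert_model.ckpt
print(bert_paths.vocab)       # .../vocab.txt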

Text data iterator

def _read_file(filename):
    "Read a file and convert to one line"
    with open(filename, 'r', encoding='utf-8') as f:
        s = f.read().strip().replace('\n', '。').replace('\t', '').replace('\u3000', '')
        return re.sub(r'。+', '。', s)
def get_data_iterator(data_path):
    for category in os.listdir(data_path):
        category_path = os.path.join(data_path, category)
        for file_name in os.listdir(category_path):
            yield _read_file(os.path.join(category_path, file_name)), category
it = get_data_iterator(data_path)
next(it)
('lottery analysis: Japan and the United States compete for the title, and the two Pakistan will have life and death when they meet. On Sunday, the women's World Cup final and the quarter finals of the Americas Cup are undoubtedly the focus of fans and lottery fans all over the world. Can Japan, the biggest dark horse in the women's football World Cup, make a miracle in Asia? Can the U.S. team, the dominant team in women's football, win the third championship? Brazil and Paraguay are enemies, who can laugh last? The answers will be revealed in the early hours of Monday. Japan and the United States are struggling for the title. This women's football World Cup is a struggle between subversion and anti subversion. The favorite host team, Germany, was beaten by Japan in extra time in the quarter final while the other favorite team, Sweden, was completely defeated by Japan 3-1 in the semi-final. In the quarter final, they fought hard with the Brazilian women's football team until the penalty shoot out, and finally eliminated the fast-growing dark horse team 5-3. In the semi-final, they won the European dark horse France 3-1. The process of the world cup between the U.S. and Japan is surprisingly similar. They won the first two rounds of the group match and lost in the last round. The quarter finals were also tied with their opponents in the 90 minute civil war. In the semi-finals, they also won 3-1. This decisive battle, whether Japan or the United States win the championship, will create a new history of women's football World Cup. There is life and death when two Pakistan meet. There are too many surprises in this year's Copa America. Brazil and Paraguay seem to be more legendary because of their narrow relationship. The two teams were divided into group B at the same time in the group stage. Originally, two of them were very popular, but they both drew in the first two rounds of the group stage. The direct confrontation between the two teams was a 2-2 draw. As a result, they both faced the danger of being out. In the last round, Brazil finally got the upper hand in the second half with a 4-2 victory over Ecuador, and came out first in the group. In the final game, Paraguay drew 3-3 with Venezuela to get third in the group. Fortunately, Costa Rica, who was third in group A, was overtaken by the advantage of goal difference and won a place in the last eight. In the group stage, Brazil drew with Paraguay at the last minute. Will their good luck show itself again in the knockout? It seems that Paraguay lacked luck in the previous three rounds of group matches. Will it be compensated by the God of luck?. In the other quarter finals of the Copa America, Chile won 2-1 in group C and advanced to the top eight of the group. Venezuela was the least favored team in group B, but even in the same group as Brazil and Paraguay, they established the group qualification in the first two rounds. They won 1-2 in group 3 and remained unbeaten. The number of goals they scored was the same as Chile, but the number of goals they lost was higher than Chile One more. But it's not surprising that they've managed to keep their goals in the face of a strong Brazil. ',
 'lottery')

Build label table

def read_category(data_path):
    "Read catalog, fix"
    categories = os.listdir(data_path)

    cat_to_id = dict(zip(categories, range(len(categories))))

    return categories, cat_to_id
categories, cat_to_id = read_category(data_path)
cat_to_id
{'lottery': 0,
 'home': 1,
 'game': 2,
 'stock': 3,
 'technology': 4,
 'society': 5,
 'finance and economics': 6,
 'fashion': 7,
 'constellation': 8,
 'sports': 9,
 'real estate': 10,
 'entertainment': 11,
 'current affairs': 12,
 'education': 13}
categories
['lottery',
 'home',
 'game',
 'stock',
 'technology',
 'society',
 'finance and economics',
 'fashion',
 'constellation',
 'sports',
 'real estate',
 'entertainment',
 'current affairs',
 'education']

Build the training, validation and test sets

def build_dataset(data_path, train_path, dev_path, test_path):
    data_iter = get_data_iterator(data_path)
    with open(train_path, 'w', encoding='utf-8') as train_file, \
         open(dev_path, 'w', encoding='utf-8') as dev_file, \
         open(test_path, 'w', encoding='utf-8') as test_file:
        
        for text, label in tqdm(data_iter):
            ratio = random.random()
            if ratio < 0.8:
                train_file.write(text + "\t" + label + "\n")
            elif ratio < 0.9:
                dev_file.write(text + "\t" + label + "\n")
            else:
                test_file.write(text + "\t" + label + "\n")
# build_dataset(data_path, r"data/keras_bert_train.txt", r"data/keras_bert_dev.txt", r"data/keras_bert_test.txt")

Get the number of samples in each data set

def get_sample_num(data_path):
    count = 0
    with open(data_path, 'r', encoding='utf-8') as data_file:
        for line in tqdm(data_file):
            count += 1
    return count
train_sample_count = get_sample_num(r"data/keras_bert_train.txt")
668858it [00:09, 67648.27it/s]


dev_sample_count = get_sample_num(r"data/keras_bert_dev.txt")
83721it [00:01, 61733.96it/s]


test_sample_count = get_sample_num(r"data/keras_bert_test.txt")
83496it [00:01, 72322.53it/s]


train_sample_count, dev_sample_count, test_sample_count
(668858, 83721, 83496)



Building data iterators

def get_text_iterator(data_path):
    with open(data_path, 'r', encoding='utf-8') as data_file:
        for line in data_file:
            data_split = line.strip().split('\t')
            if len(data_split) != 2:
                print(line)
                continue
            yield data_split[0], data_split[1]
it = get_text_iterator(r"data/keras_bert_train.txt")
next(it)
('lottery analysis: Japan and the United States compete for the title, and the two Pakistan will have life and death when they meet. On Sunday, the women's World Cup final and the quarter finals of the Americas Cup are undoubtedly the focus of fans and lottery fans all over the world. Can Japan, the biggest dark horse in the women's football World Cup, make a miracle in Asia? Can the U.S. team, the dominant team in women's football, win the third championship? Brazil and Paraguay are enemies, who can laugh last? The answers will be revealed in the early hours of Monday. Japan and the United States are struggling for the title. This women's football World Cup is a struggle between subversion and anti subversion. The favorite host team, Germany, was beaten by Japan in extra time in the quarter final while the other favorite team, Sweden, was completely defeated by Japan 3-1 in the semi-final. In the quarter final, they fought hard with the Brazilian women's football team until the penalty shoot out, and finally eliminated the fast-growing dark horse team 5-3. In the semi-final, they won the European dark horse France 3-1. The process of the world cup between the U.S. and Japan is surprisingly similar. They won the first two rounds of the group match and lost in the last round. The quarter finals were also tied with their opponents in the 90 minute civil war. In the semi-finals, they also won 3-1. This decisive battle, whether Japan or the United States win the championship, will create a new history of women's football World Cup. There is life and death when two Pakistan meet. There are too many surprises in this year's Copa America. Brazil and Paraguay seem to be more legendary because of their narrow relationship. The two teams were divided into group B at the same time in the group stage. Originally, two of them were very popular, but they both drew in the first two rounds of the group stage. The direct confrontation between the two teams was a 2-2 draw. As a result, they both faced the danger of being out. In the last round, Brazil finally got the upper hand in the second half with a 4-2 victory over Ecuador, and came out first in the group. In the final game, Paraguay drew 3-3 with Venezuela to get third in the group. Fortunately, Costa Rica, who was third in group A, was overtaken by the advantage of goal difference and won a place in the last eight. In the group stage, Brazil drew with Paraguay at the last minute. Will their good luck show itself again in the knockout? It seems that Paraguay lacked luck in the previous three rounds of group matches. Will it be compensated by the God of luck?. In the other quarter finals of the Copa America, Chile won 2-1 in group C and advanced to the top eight of the group. Venezuela was the least favored team in group B, but even in the same group as Brazil and Paraguay, they established the group qualification in the first two rounds. They won 1-2 in group 3 and remained unbeaten. The number of goals they scored was the same as Chile, but the number of goals they lost was higher than Chile One more. But it's not surprising that they've managed to keep their goals in the face of a strong Brazil. ',
 'lottery')
token_dict = load_vocabulary(bert_paths.vocab)
tokenizer = Tokenizer(token_dict)
def get_keras_bert_iterator(data_path, cat_to_id, tokenizer):
    while True:
        data_iter = get_text_iterator(data_path)
        for text, category in data_iter:
            indices, segments = tokenizer.encode(first=text, max_len=text_max_length)
            yield indices, segments, cat_to_id[category]
it = get_keras_bert_iterator(r"data/keras_bert_train.txt", cat_to_id, tokenizer)
# next(it)
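
tokenizer.encode turns a text into token indices plus the matching segment ids, both padded or truncated to max_len. A quick check of the lengths:

# Both sequences come back with length text_max_length (512).
demo_indices, demo_segments = tokenizer.encode(first="这是一个测试文本", max_len=text_max_length)
print(len(demo_indices), len(demo_segments))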

Building batch data iterators

def batch_iter(data_path, cat_to_id, tokenizer, batch_size=64, shuffle=True):
    'generate batch data'
    keras_bert_iter = get_keras_bert_iterator(data_path, cat_to_id, tokenizer)
    while True:
        data_list = []
        for _ in range(batch_size):
            data = next(keras_bert_iter)
            data_list.append(data)
        if shuffle:
            random.shuffle(data_list)
        
        indices_list = []
        segments_list = []
        label_index_list = []
        for data in data_list:
            indices, segments, label_index = data
            indices_list.append(indices)
            segments_list.append(segments)
            label_index_list.append(label_index)

        yield [np.array(indices_list), np.array(segments_list)], np.array(label_index_list)
it = batch_iter(r"data/keras_bert_train.txt", cat_to_id, tokenizer, batch_size=1)
# next(it)
it = batch_iter(r"data/keras_bert_train.txt", cat_to_id, tokenizer, batch_size=2)
next(it)
([array([[ 101, 4993, 2506, ...,  131,  123,  102],
         [ 101, 2506, 3696, ..., 1139,  125,  102]]),
  array([[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]])],
 array([0, 0]))



Define the base model

def get_model(label_list):
    K.clear_session()
    
    bert_model = load_trained_model_from_checkpoint(bert_paths.config, bert_paths.checkpoint, seq_len=text_max_length)  # load the pre-trained model
 
    for l in bert_model.layers:
        l.trainable = True
 
    input_indices = Input(shape=(None,))
    input_segments = Input(shape=(None,))
 
    bert_output = bert_model([input_indices, input_segments])
    bert_cls = Lambda(lambda x: x[:, 0])(bert_output)  # take the vector corresponding to [CLS] for classification
    output = Dense(len(label_list), activation='softmax')(bert_cls)
 
    model = Model([input_indices, input_segments], output)
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=Adam(1e-5),  # use a sufficiently small learning rate
                  metrics=['accuracy'])
    print(model.summary())
    return model
early_stopping = EarlyStopping(monitor='val_acc', patience=3)  # early stopping to prevent overfitting
plateau = ReduceLROnPlateau(monitor="val_acc")  # reduce the learning rate when the monitored metric stops improving
checkpoint = ModelCheckpoint('trained_model/keras_bert_THUCNews.hdf5', monitor='val_acc', verbose=2, save_best_only=True, mode='max', save_weights_only=True)  # save the best model

Model training

def get_step(sample_count, batch_size):
    step = sample_count // batch_size
    if sample_count % batch_size != 0:
        step += 1
    return step
batch_size = 4
train_step = get_step(train_sample_count, batch_size)
dev_step = get_step(dev_sample_count, batch_size)

train_dataset_iterator = batch_iter(r"data/keras_bert_train.txt", cat_to_id, tokenizer, batch_size)
dev_dataset_iterator = batch_iter(r"data/keras_bert_dev.txt", cat_to_id, tokenizer, batch_size)

model = get_model(categories)

#Model training
model.fit(
    train_dataset_iterator,
    steps_per_epoch=train_step,
    epochs=10,
    validation_data=dev_dataset_iterator,
    validation_steps=dev_step,
    callbacks=[early_stopping, plateau, checkpoint],
    verbose=1
)
Model: "functional_5"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 512)]        0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 512)]        0                                            
__________________________________________________________________________________________________
functional_3 (Functional)       (None, 512, 768)     101677056   input_1[0][0]                    
                                                                 input_2[0][0]                    
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 768)          0           functional_3[0][0]               
__________________________________________________________________________________________________
dense (Dense)                   (None, 14)           10766       lambda[0][0]                     
==================================================================================================
Total params: 101,687,822
Trainable params: 101,687,822
Non-trainable params: 0
__________________________________________________________________________________________________
None
Epoch 1/10
     5/167215 [..............................] - ETA: 775:02:36 - loss: 0.4064 - accuracy: 0.9000





Multi-output model

Building data iterators

second_label_list = [0, 1, 2]  # placeholder second-level label set, used only to demonstrate the pipeline
def get_keras_bert_iterator(data_path, cat_to_id, tokenizer, second_label_list):
    while True:
        data_iter = get_text_iterator(data_path)
        for text, category in data_iter:
            indices, segments = tokenizer.encode(first=text, max_len=text_max_length)
            yield indices, segments, cat_to_id[category], random.choice(second_label_list)
it = get_keras_bert_iterator(r"data/keras_bert_train.txt", cat_to_id, tokenizer, second_label_list)
# next(it)
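
Here random.choice only fabricates a second-level label so the multi-output pipeline can be demonstrated; with real data the label would be read from the file. A hedged sketch, assuming each line carried a third tab-separated field and a second_cat_to_id mapping existed:

# Hypothetical variant that reads a real second-level label from a third column.
def get_keras_bert_iterator_real(data_path, cat_to_id, second_cat_to_id, tokenizer):
    while True:
        with open(data_path, 'r', encoding='utf-8') as data_file:
            for line in data_file:
                parts = line.strip().split('\t')
                if len(parts) != 3:
                    continue
                text, category, second_category = parts
                indices, segments = tokenizer.encode(first=text, max_len=text_max_length)
                yield indices, segments, cat_to_id[category], second_cat_to_id[second_category]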
def batch_iter(data_path, cat_to_id, tokenizer, second_label_list, batch_size=64, shuffle=True):
    'generate batch data'
    keras_bert_iter = get_keras_bert_iterator(data_path, cat_to_id, tokenizer, second_label_list)
    while True:
        data_list = []
        for _ in range(batch_size):
            data = next(keras_bert_iter)
            data_list.append(data)
        if shuffle:
            random.shuffle(data_list)
        
        indices_list = []
        segments_list = []
        label_index_list = []
        second_label_index_list = []  # renamed so it does not shadow the second_label_list argument
        for data in data_list:
            indices, segments, label_index, second_label = data
            indices_list.append(indices)
            segments_list.append(segments)
            label_index_list.append(label_index)
            second_label_index_list.append(second_label)

        yield [np.array(indices_list), np.array(segments_list)], [np.array(label_index_list), np.array(second_label_index_list)]
it = batch_iter(r"data/keras_bert_train.txt", cat_to_id, tokenizer, second_label_list, batch_size=2)
next(it)
([array([[ 101, 4993, 2506, ...,  131,  123,  102],
         [ 101, 2506, 3696, ..., 1139,  125,  102]]),
  array([[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]])],
 [array([0, 0]), array([0, 0])])


Define the multi-output model

def get_model(label_list, second_label_list):
    K.clear_session()
    
    bert_model = load_trained_model_from_checkpoint(bert_paths.config, bert_paths.checkpoint, seq_len=text_max_length)  # load the pre-trained model
 
    for l in bert_model.layers:
        l.trainable = True
 
    input_indices = Input(shape=(None,))
    input_segments = Input(shape=(None,))
 
    bert_output = bert_model([input_indices, input_segments])
    bert_cls = Lambda(lambda x: x[:, 0])(bert_output)  # take the vector corresponding to [CLS] for classification
    output = Dense(len(label_list), activation='softmax')(bert_cls)
    output_second = Dense(len(second_label_list), activation='softmax')(bert_cls)
 
    model = Model([input_indices, input_segments], [output, output_second])
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=Adam(1e-5),  # use a sufficiently small learning rate
                  metrics=['accuracy'])
    print(model.summary())
    return model
early_stopping = EarlyStopping(monitor='val_acc', patience=3)  # early stopping to prevent overfitting
plateau = ReduceLROnPlateau(monitor="val_acc")  # reduce the learning rate when the monitored metric stops improving
checkpoint = ModelCheckpoint('trained_model/muilt_keras_bert_THUCNews.hdf5', monitor='val_acc', verbose=2, save_best_only=True, mode='max', save_weights_only=True)  # save the best model
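
Note that with two output heads Keras prefixes each metric with the layer name (the training log below shows dense_accuracy and dense_1_accuracy), so a plain 'val_acc' may never appear and these callbacks would then have nothing to monitor. A hedged adjustment is to track one head's validation metric explicitly:

# Monitor the first head's validation accuracy (a sketch; the exact name depends on
# the output layer names printed in the training log).
early_stopping = EarlyStopping(monitor='val_dense_accuracy', patience=3)
plateau = ReduceLROnPlateau(monitor='val_dense_accuracy')
checkpoint = ModelCheckpoint('trained_model/muilt_keras_bert_THUCNews.hdf5', monitor='val_dense_accuracy', verbose=2, save_best_only=True, mode='max', save_weights_only=True)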

Model training

batch_size = 4
train_step = get_step(train_sample_count, batch_size)
dev_step = get_step(dev_sample_count, batch_size)

train_dataset_iterator = batch_iter(r"data/keras_bert_train.txt", cat_to_id, tokenizer, second_label_list, batch_size)
dev_dataset_iterator = batch_iter(r"data/keras_bert_dev.txt", cat_to_id, tokenizer, second_label_list, batch_size)

model = get_model(categories, second_label_list)

#Model training
model.fit(
    train_dataset_iterator,
    steps_per_epoch=train_step,
    epochs=10,
    validation_data=dev_dataset_iterator,
    validation_steps=dev_step,
    callbacks=[early_stopping, plateau, checkpoint],
    verbose=1
)
Model: "functional_5"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 512)]        0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 512)]        0                                            
__________________________________________________________________________________________________
functional_3 (Functional)       (None, 512, 768)     101677056   input_1[0][0]                    
                                                                 input_2[0][0]                    
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 768)          0           functional_3[0][0]               
__________________________________________________________________________________________________
dense (Dense)                   (None, 14)           10766       lambda[0][0]                     
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 3)            2307        lambda[0][0]                     
==================================================================================================
Total params: 101,690,129
Trainable params: 101,690,129
Non-trainable params: 0
__________________________________________________________________________________________________
None
Epoch 1/10
     7/167215 [..............................] - ETA: 1829:52:33 - loss: 3.1260 - dense_loss: 1.4949 - dense_1_loss: 1.6311 - dense_accuracy: 0.6429 - dense_1_accuracy: 0.3571 
