TF 2.0 XLA acceleration test

Time: 2020-02-24

Officially, XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that can optimize TensorFlow computations: it aims to improve execution speed on both server and mobile platforms, as well as memory usage and portability. The XLA framework is experimental and still under active development.
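As a quick illustration of what "compiling" means here (this sketch is mine, not from the original post), XLA can also be enabled per function instead of globally. The keyword argument is jit_compile=True on recent TensorFlow releases (experimental_compile=True on older 2.x versions); the first call compiles the function and is therefore much slower than the calls that follow:

import time
import tensorflow as tf

# Compile this function with XLA. On older TF 2.x releases the argument
# is `experimental_compile=True` instead of `jit_compile=True`.
@tf.function(jit_compile=True)
def dense_relu(x, w, b):
    return tf.nn.relu(tf.matmul(x, w) + b)

x = tf.random.normal([8, 128])
w = tf.random.normal([128, 64])
b = tf.zeros([64])

start = time.time()
dense_relu(x, w, b)   # first call: traces and XLA-compiles the computation
print("first call:", time.time() - start)

start = time.time()
dense_relu(x, w, b)   # later calls reuse the compiled computation
print("second call:", time.time() - start)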

So I wanted to see how much XLA speeds up a BERT model. I chose the Chinese BERT model and tested it on a sentiment classification task.

import tensorflow as tf
from transformers import *

from band.dataset import ChnSentiCorp
from band.progress import classification_convert_examples_to_features

USE_XLA = False
USE_AMP = False

EPOCHS = 5
BATCH_SIZE = 16
EVAL_BATCH_SIZE = 16
TEST_BATCH_SIZE = 1
MAX_SEQ_LEN = 128
LEARNING_RATE = 3e-5


# Toggle XLA JIT compilation and automatic mixed precision globally.
tf.config.optimizer.set_jit(USE_XLA)
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})

dataset = ChnSentiCorp(save_path="/tmp/band")
data, label = dataset.data, dataset.label
dataset.dataset_information()

train_number, eval_number, test_number = dataset.train_examples_num, dataset.eval_examples_num, dataset.test_examples_num

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

train_dataset = classification_convert_examples_to_features(data['train'], tokenizer, max_length=MAX_SEQ_LEN,
                                                            label_list=label,
                                                            output_mode="classification")
valid_dataset = classification_convert_examples_to_features(data['validation'], tokenizer, max_length=MAX_SEQ_LEN,
                                                            label_list=label,
                                                            output_mode="classification")

train_dataset = train_dataset.shuffle(100).batch(BATCH_SIZE, drop_remainder=True).repeat(EPOCHS)
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)
valid_dataset = valid_dataset.prefetch(tf.data.experimental.AUTOTUNE)

config = BertConfig.from_pretrained("bert-base-chinese", num_labels=dataset.num_labels)
model = TFBertForSequenceClassification.from_pretrained('bert-base-chinese', config=config)
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE, epsilon=1e-08)
if USE_AMP:
    # Wrap the optimizer with dynamic loss scaling for mixed-precision training.
    optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, 'dynamic')
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
model.compile(optimizer=optimizer, loss=loss, metrics=[metric])

history = model.fit(train_dataset, epochs=EPOCHS,
                    steps_per_epoch=train_number // BATCH_SIZE,
                    validation_data=valid_dataset,
                    validation_steps=eval_number // EVAL_BATCH_SIZE)
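
For completeness, the held-out test split could be evaluated the same way. This is only a sketch continuing from the code above, assuming the band helpers accept the test split exactly like the train/validation splits (the original post does not show this step):

# Sketch only: assumes data['test'] exists and the same conversion helper applies.
test_dataset = classification_convert_examples_to_features(data['test'], tokenizer, max_length=MAX_SEQ_LEN,
                                                           label_list=label,
                                                           output_mode="classification")
test_dataset = test_dataset.batch(TEST_BATCH_SIZE)

loss, accuracy = model.evaluate(test_dataset, steps=test_number // TEST_BATCH_SIZE)
print("test loss:", loss, "test accuracy:", accuracy)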

Here, band is a BERT library that I wrote myself and which is still under development. Enabling XLA requires no code changes beyond setting USE_XLA to True. The experimental results are as follows:

  • Do not use XLA

    Epoch 1/5
     600/600 [==============================] - 355s 592ms/step - loss: 0.2685 - accuracy: 0.8976 - val_loss: 0.2427 - val_accuracy: 0.9142
     Epoch 2/5
     600/600 [==============================] - 332s 554ms/step - loss: 0.1707 - accuracy: 0.9420 - val_loss: 0.1824 - val_accuracy: 0.9258
     Epoch 3/5
     600/600 [==============================] - 332s 554ms/step - loss: 0.0934 - accuracy: 0.9686 - val_loss: 0.1995 - val_accuracy: 0.9383
     Epoch 4/5
     600/600 [==============================] - 333s 554ms/step - loss: 0.0768 - accuracy: 0.9747 - val_loss: 0.2288 - val_accuracy: 0.9442
     Epoch 5/5
     600/600 [==============================] - 333s 555ms/step - loss: 0.0564 - accuracy: 0.9807 - val_loss: 0.2247 - val_accuracy: 0.9408
  • Using XLA

    Epoch 1/5
    600/600 [==============================] - 573s 955ms/step - loss: 0.2824 - accuracy: 0.8940 - val_loss: 0.2162 - val_accuracy: 0.9192
    Epoch 2/5
    600/600 [==============================] - 309s 515ms/step - loss: 0.1577 - accuracy: 0.9444 - val_loss: 0.2361 - val_accuracy: 0.9233
    Epoch 3/5
    600/600 [==============================] - 309s 514ms/step - loss: 0.0993 - accuracy: 0.9678 - val_loss: 0.2270 - val_accuracy: 0.9333
    Epoch 4/5
    600/600 [==============================] - 307s 512ms/step - loss: 0.0702 - accuracy: 0.9780 - val_loss: 0.2492 - val_accuracy: 0.9300
    Epoch 5/5
    600/600 [==============================] - 310s 516ms/step - loss: 0.0572 - accuracy: 0.9815 - val_loss: 0.2675 - val_accuracy: 0.9300

The per-epoch timings are summarized in the table below:
| Configuration | Epoch 1 | Epochs 2–5 (per epoch) |
| :-----------: | :-----: | :--------------------: |
| Without XLA   | 355 s   | ~332 s                 |
| With XLA      | 573 s   | ~309 s                 |

  • The first question: why is the first epoch noticeably slower than the rest even without XLA?

    The explanation is that the GPU has to perform some one-time initialization (think of it as a warm-up) during the first epoch; from the second epoch onward the timing reflects normal operation.

  • The second question: why does the first epoch take so much longer when XLA is enabled?

    XLA is a compiler, so the first epoch also includes compiling the computation graph, which makes it slower.

  • The third question: why does the XLA run appear to have slightly lower accuracy?

    I did not fix a random seed for these runs, and XLA only compiles the computation, so it should have no effect on the results; the difference is run-to-run noise. A seed-fixing sketch follows this list.
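
A minimal sketch of what fixing the seeds would look like (my addition, not part of the original experiment), which would make the with/without-XLA runs more directly comparable:

import random
import numpy as np
import tensorflow as tf

SEED = 42

# Seed Python, NumPy and TensorFlow before building the model so that weight
# initialization and dataset shuffling are repeatable across runs.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

Even with fixed seeds, some GPU kernels (e.g. cuDNN) are nondeterministic, so small run-to-run differences can remain.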

To sum up: with XLA the first epoch is spent partly on compilation, so it takes noticeably longer; after that, the per-epoch time is stable and faster than the normal run. In this experiment the steady-state speedup is roughly 7% (332 s vs. 309 s per epoch).

The official documentation also says XLA helps reduce resource usage. I did not set up a proper comparison for that here, so I will take the claim at face value for now.
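
If one wanted to check the memory claim, a rough approach (my own sketch, not part of the original test) is to poll nvidia-smi during training, for example from a Keras callback:

import subprocess

def gpu_memory_used_mib():
    # Ask the NVIDIA driver for the memory currently in use, one value per GPU (in MiB).
    out = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"])
    return [int(v) for v in out.decode().strip().splitlines()]

print(gpu_memory_used_mib())

Note that TensorFlow reserves most GPU memory up front by default, so tf.config.experimental.set_memory_growth would need to be enabled for this number to reflect actual usage.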
