Introduction and implementation of the ESMM multi-task learning model

Time: 2022-06-10

Introduction: This article introduces the paper "Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate", published by an Alibaba team at SIGIR 2018. Based on the idea of multi-task learning (MTL), the paper proposes a CVR prediction model called ESMM, which effectively addresses the two key problems of real-world CVR prediction: data sparsity and sample selection bias. In later articles we will introduce other multi-task learning models such as MMoE, PLE and DBMTL.

Multi-task learning background

Industrial recommendation algorithms are no longer limited to a single-target (CTR) task; they also need to pay attention to downstream conversion behaviors, such as commenting, favoriting, adding to cart, purchasing, watch time and so on.


Paper introduction

CVR estimation faces two key problems:

1. Sample Selection Bias (SSB)

Conversion is an action that can only happen after a click. A traditional CVR model is usually trained on click data, where clicked-but-not-converted samples are negatives and clicked-and-converted samples are positives. At serving time, however, the model scores samples from the entire exposure space, not just the clicked ones. In other words, the training data and the data to be predicted come from different distributions. This bias severely challenges the model's generalization ability and often leads to mediocre online performance after launch.


2. Data Sparsity (DS)

The training data for the CVR task (i.e. click samples) is far smaller than the exposure data used to train a CTR model. Training on such a small number of samples makes it hard to fit a deep model.

Some strategies can alleviate these two problems, such as sampling negatives from the exposure set to mitigate SSB, or oversampling converted samples to mitigate DS. But none of these methods solves the problems at their root.

Since click → conversion is a chain of two strongly correlated, consecutive behaviors, the authors want to model this "behavior chain" explicitly in the network structure, so that both training and prediction happen over the entire exposure space. This involves two tasks, CTR and CVR, so multi-task learning (MTL) is a natural choice. The key contribution of the paper is how this MTL model is constructed.

First, it is necessary to distinguish between the CVR estimation task and the CTCVR estimation task.

CVR = number of conversions / number of clicks. It predicts the probability that "an item, once clicked, will be converted". The CVR task has no fixed relationship with CTR: an item with a high CTR does not necessarily have a high CVR; clickbait articles, for example, often have short reading times. This is also why all exposure samples cannot be used directly to train a CVR model: for exposed-but-not-clicked samples it is impossible to know whether they would have converted had they been clicked, and assigning them label 0 would badly mislead the CVR model.
CTCVR = number of conversions / number of exposures. It predicts the probability that an item is clicked and then converted.
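The relationship between the three quantities can be checked with a toy example (the counts below are made up for illustration):

```python
# Hypothetical counts: 10,000 exposures, 400 clicks, 20 conversions.
exposures, clicks, conversions = 10_000, 400, 20

ctr = clicks / exposures          # p(click | exposure)
cvr = conversions / clicks        # p(conversion | click)
ctcvr = conversions / exposures   # p(click & conversion | exposure)

# CTCVR factorizes as CTR * CVR -- the identity ESMM builds on.
assert abs(ctcvr - ctr * cvr) < 1e-12
print(ctr, cvr, ctcvr)  # 0.04 0.05 0.002
```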

pCTCVR = p(y=1, z=1 | x) = p(y=1 | x) × p(z=1 | y=1, x) = pCTR × pCVR

where x, y and z represent exposure, click and conversion respectively. Note that over the entire sample space, the label for CTR is click, while the label for CTCVR is click & conversion, so all exposure samples can be used for both tasks. ESMM therefore learns the CTR and CTCVR tasks, and implicitly learns the CVR task via the formula above. The specific structure is as follows:

[Figure: the ESMM network structure, quoted from the paper]

There are two points worth emphasizing about the network structure:

Shared embedding. The CVR task and the CTR task use the same features and the same feature embeddings; each tower then learns its own exclusive parameters on top of the concatenated shared embedding;

Implicit learning of pCVR. Here pCVR is just an intermediate variable in the network, with no explicit supervision signal of its own.
Specifically, this is reflected in the objective function:

L(θcvr, θctr) = Σᵢ l(yᵢ, f(xᵢ; θctr)) + Σᵢ l(yᵢ & zᵢ, f(xᵢ; θctr) × f(xᵢ; θcvr))

where l(·) is the cross-entropy loss and both sums run over all exposure samples.
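The two-term objective can be sketched in NumPy; the labels and probabilities below are made up, and `bce` is a hypothetical helper, not part of any framework:

```python
import numpy as np

def bce(labels, probs, eps=1e-7):
    """Binary cross-entropy, summed over samples."""
    probs = np.clip(probs, eps, 1 - eps)
    return -np.sum(labels * np.log(probs) + (1 - labels) * np.log(1 - probs))

# Hypothetical tower outputs over 4 exposure samples.
p_ctr = np.array([0.9, 0.2, 0.8, 0.1])  # CTR tower
p_cvr = np.array([0.5, 0.3, 0.1, 0.4])  # CVR tower (implicit, never supervised alone)
y = np.array([1, 0, 1, 0])              # click labels
z = np.array([1, 0, 0, 0])              # conversion labels

# Both loss terms are defined over the full exposure space:
# CTR against click, and CTR * CVR against click & conversion.
loss = bce(y, p_ctr) + bce(y * z, p_ctr * p_cvr)
```

Because pCVR only ever appears inside the product pCTR × pCVR, gradients reach the CVR tower through the CTCVR term alone.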

Code implementation

We have implemented the ESMM algorithm based on the EasyRec recommendation framework. For the full implementation, see GitHub: EasyRec ESMM.

About EasyRec: EasyRec is an open-source, large-scale distributed recommendation algorithm framework from the Machine Learning PAI team of the Alibaba Cloud computing platform. As its name suggests, EasyRec is simple and easy to use; it integrates many excellent, cutting-edge recommendation-system ideas and includes feature-engineering methods that have proven effective in real industrial deployments. It unifies training, evaluation and deployment, and connects seamlessly with Alibaba Cloud products, so a state-of-the-art recommendation system can be built in a short time. As a flagship Alibaba Cloud product, it has been stably serving hundreds of enterprise customers.

Model feedforward network:

def build_predict_graph(self):
   """Forward function.

   Returns:
     self._prediction_dict: Prediction result of two tasks.
   """
   # The logic that generates the concatenated feature tensor (all_fea) is omitted here.

   cvr_tower_name = self._cvr_tower_cfg.tower_name
   dnn_model = dnn.DNN(
       self._cvr_tower_cfg.dnn,
       self._l2_reg,
       name=cvr_tower_name,
       is_training=self._is_training)
   cvr_tower_output = dnn_model(all_fea)
   cvr_tower_output = tf.layers.dense(
       inputs=cvr_tower_output,
       units=1,
       kernel_regularizer=self._l2_reg,
       name='%s/dnn_output' % cvr_tower_name)

   ctr_tower_name = self._ctr_tower_cfg.tower_name
   dnn_model = dnn.DNN(
       self._ctr_tower_cfg.dnn,
       self._l2_reg,
       name=ctr_tower_name,
       is_training=self._is_training)
   ctr_tower_output = dnn_model(all_fea)
   ctr_tower_output = tf.layers.dense(
       inputs=ctr_tower_output,
       units=1,
       kernel_regularizer=self._l2_reg,
       name='%s/dnn_output' % ctr_tower_name)

   tower_outputs = {
       cvr_tower_name: cvr_tower_output,
       ctr_tower_name: ctr_tower_output
   }
   self._add_to_prediction_dict(tower_outputs)
   return self._prediction_dict
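`_add_to_prediction_dict` is not shown above; conceptually, ESMM turns the two tower logits into probabilities and multiplies them to obtain pCTCVR. A minimal NumPy sketch of that step (array values made up, variable names hypothetical):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Hypothetical tower logits for 3 samples.
ctr_logits = np.array([2.0, -1.0, 0.5])
cvr_logits = np.array([0.0, 1.0, -0.5])

probs_ctr = sigmoid(ctr_logits)
probs_cvr = sigmoid(cvr_logits)      # implicit pCVR, never supervised directly
probs_ctcvr = probs_ctr * probs_cvr  # supervised against click & conversion

# pCTCVR can never exceed pCTR, which keeps the implicit pCVR within [0, 1].
assert np.all(probs_ctcvr <= probs_ctr)
```

Modeling pCVR as a standalone tower output and multiplying, rather than dividing pCTCVR by pCTR, avoids numerical instability when pCTR is tiny.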

Loss calculation:

Note: the CVR tower's loss is computed against the CTCVR label over the entire exposure space, so no click mask is needed here.

def build_loss_graph(self):
   """Build loss graph.

   Returns:
     self._loss_dict: Weighted loss of ctr and cvr.
   """
   cvr_tower_name = self._cvr_tower_cfg.tower_name
   ctr_tower_name = self._ctr_tower_cfg.tower_name
   cvr_label_name = self._label_name_dict[cvr_tower_name]
   ctr_label_name = self._label_name_dict[ctr_tower_name]

   ctcvr_label = tf.cast(
       self._labels[cvr_label_name] * self._labels[ctr_label_name], 
       tf.float32)
   cvr_loss = tf.keras.backend.binary_crossentropy(
       ctcvr_label, self._prediction_dict['probs_ctcvr'])
   cvr_loss = tf.reduce_sum(cvr_loss, name="ctcvr_loss")

   # The weight defaults to 1.
   self._loss_dict['weighted_cross_entropy_loss_%s' %
                     cvr_tower_name] = self._cvr_tower_cfg.weight * cvr_loss

   ctr_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
       labels=tf.cast(self._labels[ctr_label_name], tf.float32),
       logits=self._prediction_dict['logits_%s' % ctr_tower_name]
       ), name="ctr_loss")

   self._loss_dict['weighted_cross_entropy_loss_%s' %
                   ctr_tower_name] = self._ctr_tower_cfg.weight * ctr_loss
   return self._loss_dict

Note: the final loss is weighted_cross_entropy_loss_ctr + weighted_cross_entropy_loss_cvr; the EasyRec framework automatically sums the contents of _loss_dict.

Metric calculation:

Note: exposure-only (unclicked) data must be masked off when calculating CVR metrics.

def build_metric_graph(self, eval_config):
  """Build metric graph.

  Args:
    eval_config: Evaluation configuration.

  Returns:
    metric_dict: Calculate AUC of ctr, cvr and ctcvr.
  """
  metric_dict = {}

  cvr_tower_name = self._cvr_tower_cfg.tower_name
  ctr_tower_name = self._ctr_tower_cfg.tower_name
  cvr_label_name = self._label_name_dict[cvr_tower_name]
  ctr_label_name = self._label_name_dict[ctr_tower_name]
  for metric in self._cvr_tower_cfg.metrics_set:
    # CTCVR metric
    ctcvr_label_name = cvr_label_name + '_ctcvr'
    cvr_dtype = self._labels[cvr_label_name].dtype
    self._labels[ctcvr_label_name] = self._labels[cvr_label_name] * tf.cast(
        self._labels[ctr_label_name], cvr_dtype)
    metric_dict.update(
        self._build_metric_impl(
            metric,
            loss_type=self._cvr_tower_cfg.loss_type,
            label_name=ctcvr_label_name,
            num_class=self._cvr_tower_cfg.num_class,
            suffix='_ctcvr'))

    # CVR metric
    cvr_label_masked_name = cvr_label_name + '_masked'
    ctr_mask = self._labels[ctr_label_name] > 0
    self._labels[cvr_label_masked_name] = tf.boolean_mask(
        self._labels[cvr_label_name], ctr_mask)
    pred_prefix = 'probs' if self._cvr_tower_cfg.loss_type == LossType.CLASSIFICATION else 'y'
    pred_name = '%s_%s' % (pred_prefix, cvr_tower_name)
    self._prediction_dict[pred_name + '_masked'] = tf.boolean_mask(
        self._prediction_dict[pred_name], ctr_mask)
    metric_dict.update(
        self._build_metric_impl(
            metric,
            loss_type=self._cvr_tower_cfg.loss_type,
            label_name=cvr_label_masked_name,
            num_class=self._cvr_tower_cfg.num_class,
            suffix='_%s_masked' % cvr_tower_name))

  for metric in self._ctr_tower_cfg.metrics_set:
    # CTR metric
    metric_dict.update(
        self._build_metric_impl(
            metric,
            loss_type=self._ctr_tower_cfg.loss_type,
            label_name=ctr_label_name,
            num_class=self._ctr_tower_cfg.num_class,
            suffix='_%s' % ctr_tower_name))
  return metric_dict
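The boolean-mask step above can be illustrated outside TensorFlow with NumPy (labels and predictions below are made up):

```python
import numpy as np

# Hypothetical labels and predictions over 5 exposure samples.
ctr_labels = np.array([1, 0, 1, 1, 0])   # click
cvr_labels = np.array([1, 0, 0, 1, 0])   # conversion
cvr_preds = np.array([0.7, 0.2, 0.3, 0.6, 0.1])

mask = ctr_labels > 0                    # keep clicked samples only
cvr_labels_masked = cvr_labels[mask]
cvr_preds_masked = cvr_preds[mask]

print(cvr_labels_masked)  # [1 0 1]
print(cvr_preds_masked)   # [0.7 0.3 0.6]
```

The masked CVR AUC is then computed on this clicked subset, while the CTCVR metric uses all exposure samples.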

Experiments and shortcomings

We have run extensive experiments on the open-source Ali-CCP dataset; the results will appear in the next article. We found that the seesaw phenomenon is pronounced for ESMM: it is difficult to improve the CTR and CVR tasks at the same time.

References

Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate (SIGIR 2018)
ESMM: Alibaba's CVR prediction model
EasyRec ESMM: introduction and implementation of the multi-task learning model
Note: the figures and formulas in this article are quoted from the paper Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate.

This article is original content of Alibaba Cloud and may not be reproduced without permission.