Source code interpretation: cssrnn

Time: 2021-04-29

Source code: CSSRNN GitHub repository; paper: IJCAI 2017

Model interpretation

  • current embedding: shape [state_size, emb_dim], where state_size is the number of links (or grids) and emb_dim is the embedding dimension.
  • destination embedding: shape [state_size, emb_dim], same as above; the destination uses its own separate embedding table.
  • neighbor "embedding": actually a set of linear-transformation coefficients rather than a true embedding; the embedding-lookup mechanism is only used to speed up the computation. W has shape [hid_dim, state_size] and b has shape [state_size], where hid_dim is the output dimension of the LSTM. The last layer of CSSRNN is essentially a softmax layer constrained by the adjacency table. LPIRNN adds one more dimension, so its coefficient tensor has shape [hid_dim, state_size, adj_size]. In other words, CSSRNN assigns coefficients per node, while LPIRNN assigns coefficients per edge (see the parameter sketch after this list).
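
A minimal sketch of these parameter shapes (TF 1.x style to match the snippets below; the sizes and variable names are illustrative, not the repository's):

import tensorflow as tf

state_size, emb_dim, hid_dim, adj_size = 10000, 256, 256, 6  # hypothetical sizes

current_emb_ = tf.get_variable("current_emb", [state_size, emb_dim], dtype=tf.float64)  # input link embedding
dest_emb_ = tf.get_variable("dest_emb", [state_size, emb_dim], dtype=tf.float64)        # separate destination embedding
# CSSRNN output layer: one coefficient vector per link (node)
w_ = tf.get_variable("w", [hid_dim, state_size], dtype=tf.float64)
b_ = tf.get_variable("b", [state_size], dtype=tf.float64)
# LPIRNN variant: one coefficient vector per (link, neighbor) position, i.e. per edge
w_lpi_ = tf.get_variable("w_lpi", [hid_dim, state_size, adj_size], dtype=tf.float64)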

Code interpretation

ID Embedding
First create the edge embedding with shape [state_size, emb_dim]; an embedding is essentially the weight matrix W of a fully connected layer applied to a one-hot input.

# with a pretrained embedding
emb_ = tf.get_variable("embedding", dtype=tf.float64, initializer=pretrained_emb_)
# without a pretrained embedding
emb_ = tf.get_variable("embedding", [state_size, emb_dim], dtype=tf.float64)

The input is encoded as one-hot vectors, and feeding a one-hot vector through a fully connected layer is equivalent to selecting from the embedding matrix the row indexed by the ID (the row where the 1 sits). This is what tf.gather() does, and TensorFlow also provides tf.nn.embedding_lookup(), which performs the table lookup in parallel and maps the input ID tensor (shape = [batch_size, time_steps]) to the embedded tensor (shape = [batch_size, time_steps, emb_dim]).

emb_inputs_ = tf.nn.embedding_lookup(emb_, input_label, name="emb_inputs") # [batch, time, emb_dim]
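
As a quick sanity check (a sketch with made-up sizes), the lookup gives the same result as multiplying the one-hot encoding by the embedding matrix:

import tensorflow as tf

emb = tf.random.normal([10, 4])           # toy embedding table: 10 ids, emb_dim = 4
ids = tf.constant([3, 7, 1])              # three link ids
looked_up = tf.nn.embedding_lookup(emb, ids)             # [3, 4]
one_hot = tf.one_hot(ids, depth=10, dtype=emb.dtype)     # [3, 10]
via_matmul = tf.matmul(one_hot, emb)                     # [3, 4], same values as looked_up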

To take the destination into account, the destination is embedded in the same way and then spliced onto the ID embedding with tf.concat:

# Note: the destination has its own embedding table, distinct from emb_ above
dest_emb_ = tf.get_variable("dest_emb", [state_size, emb_dim], dtype=tf.float64)
dest_inputs_ = tf.tile(tf.expand_dims(tf.nn.embedding_lookup(dest_emb_, dest_label_), 1), [1, max_t_, 1])  # [batch, time, dest_dim]
inputs_ = tf.concat([emb_inputs_, dest_inputs_], axis=2, name="input_with_dest")  # [batch, time, emb_dim + dest_dim]

RNN layer:

cell = tf.keras.layers.LSTMCell(hidden_dim)
layer = tf.keras.layers.RNN(cell, return_sequences=True)  # return_sequences is a constructor argument
rnn_outputs = layer(inputs_)  # [batch, time, hid_dim]
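
A quick shape check on dummy data (hypothetical sizes, not from the repository):

import numpy as np
import tensorflow as tf

emb_dim, dest_dim, hidden_dim = 100, 100, 256
cell = tf.keras.layers.LSTMCell(hidden_dim)
layer = tf.keras.layers.RNN(cell, return_sequences=True)
dummy = np.random.rand(2, 5, emb_dim + dest_dim).astype(np.float32)  # [batch=2, time=5, emb_dim + dest_dim]
print(layer(dummy).shape)  # (2, 5, 256)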

Softmax layer:
The loss is computed from the RNN output, and the adjacency-table constraint has to be taken into account: only the links adjacent to the current link compete in the softmax.
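
In formula form (a sketch that matches the code below), the per-step loss is the negative log of a softmax restricted to the neighbor set N(x_t) of the current link x_t, where h_t is the RNN output, y_t the target (next) link, and w_j, b_j the output coefficients of link j:

$$\mathcal{L}_t = -\log \frac{\exp(w_{y_t}^{\top} h_t + b_{y_t})}{\sum_{j \in N(x_t)} \exp(w_j^{\top} h_t + b_j)}$$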

outputs_ = tf.reshape(rnn_outputs, ...) # [batch*time, hid_dim]

# parameters of the output layer
wp_ = tf.get_variable("wp", [int(outputs_.get_shape()[1]), config.state_size],
                          dtype=config.float_type)  # [hid_dim, state_size]
bp_ = tf.get_variable("bp", [config.state_size], dtype=config.float_type)  # [state_size]
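
The lookups below index w_t_ and b_; assuming w_t_ is simply the transpose of wp_ (so each row holds the output coefficients of one link) and b_ is bp_:

w_t_ = tf.transpose(wp_)  # [state_size, hid_dim]
b_ = bp_                  # [state_size]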


adj_mat_ = ... # [n_edge, max_adj_num], each element is the id of a neighboring edge
adj_mask_ = ... # [n_edge, max_adj_num], 1 marks a real neighbor in adj_mat_, 0 marks padding
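# For illustration only (made-up toy values, not from the repository): with 3 edges and
# max_adj_num = 2, where edge 0 connects to edges 1 and 2, edge 1 only to edge 2, and
# edge 2 only to edge 0, the padded table and its mask would be
#   adj_mat_  = [[1, 2],       adj_mask_ = [[1., 1.],
#                [2, 0],                    [1., 0.],
#                [0, 0]]                    [1., 0.]]
# The mask is needed because the padding value (0 here) is also a valid edge id.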

input_flat_ = tf.reshape(input_, [-1])  # [batch*t]
target_flat_ = tf.reshape(target_, [-1, 1])  # [batch*t, 1]
sub_adj_mat_ = tf.nn.embedding_lookup(adj_mat_, input_flat_) # [batch*t, max_adj_num]
sub_adj_mask_ = tf.nn.embedding_lookup(adj_mask_, input_flat_)  # [batch*t, max_adj_num]
# first column is target_
target_and_sub_adj_mat_ = tf.concat([target_flat_, sub_adj_mat_], axis=1)  # [batch*t, max_adj_num+1]

outputs_3d_ = tf.expand_dims(outputs_, 1)  # [batch*max_seq_len, hid_dim] -> [batch*max_seq_len, 1, hid_dim]

sub_w_ = tf.nn.embedding_lookup(w_t_, target_and_sub_adj_mat_)  # [batch*max_seq_len, max_adj_num+1, hid_dim]
sub_b_ = tf.nn.embedding_lookup(b_, target_and_sub_adj_mat_)  # [batch*max_seq_len, max_adj_num+1] 
sub_w_flat_ = tf.reshape(sub_w_, [-1, int(sub_w_.get_shape()[2])])  # [batch*max_seq_len*(max_adj_num+1), hid_dim]
sub_b_flat_ = tf.reshape(sub_b_, [-1]) # [batch*max_seq_len*(max_adj_num+1)]

outputs_tiled_ = tf.tile(outputs_3d_, [1, tf.shape(adj_mat_)[1] + 1, 1])  # [batch*max_seq_len, max_adj_num+1, hid_dim]
outputs_tiled_ = tf.reshape(outputs_tiled_, [-1, int(outputs_tiled_.get_shape()[2])])  # [batch*max_seq_len*(max_adj_num+1), hid_dim]
target_logit_and_sub_logits_ = tf.reshape(tf.reduce_sum(tf.multiply(sub_w_flat_, outputs_tiled_), 1) + sub_b_flat_,
                                                          [-1, tf.shape(adj_mat_)[1] + 1])  # [batch*max_seq_len, max_adj_num+1]
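
The tile / multiply / reduce_sum pattern above is just a batched dot product between each candidate's coefficient vector and the RNN output; with a more recent TensorFlow the same logits can be written more directly (an equivalent sketch, not the repository's code):

target_logit_and_sub_logits_ = tf.einsum('bah,bh->ba', sub_w_, outputs_) + sub_b_  # [batch*max_seq_len, max_adj_num+1]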

# for numerical stability
scales_ = tf.reduce_max(target_logit_and_sub_logits_, 1)  # [batch*max_seq_len]
scaled_target_logit_and_sub_logits_ = tf.transpose(tf.subtract(tf.transpose(target_logit_and_sub_logits_), scales_))  # transpose for broadcasting [batch*max_seq_len, max_adj_num+1]

scaled_sub_logits_ = scaled_target_logit_and_sub_logits_[:, 1:]  # [batch*max_seq_len, max_adj_num]
exp_scaled_sub_logits_ = tf.exp(scaled_sub_logits_)  # [batch*max_seq_len, max_adj_num]
deno_ = tf.reduce_sum(tf.multiply(exp_scaled_sub_logits_, sub_adj_mask_), 1)  # [batch*max_seq_len]
#log_deno_ = tf.log(deno_)  # [batch*max_seq_len]
log_deno_ = tf.log(tf.clip_by_value(deno_, 1e-8, tf.reduce_max(deno_)))  # clip to avoid log(0)
log_nume_ = tf.reshape(scaled_target_logit_and_sub_logits_[:, 0:1], [-1])  # [batch*max_seq_len]
loss_ = tf.subtract(log_deno_, log_nume_)  # [batch*t]; per-step loss is -log(softmax) = log(deno) - log(nume)
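
Subtracting the row-wise maximum does not change the loss, because the common factor cancels between numerator and denominator. With z_j the unscaled logits, m_j the adjacency mask, y the target column, and c = max_j z_j:

$$-\log \frac{\exp(z_y)}{\sum_{j} m_j \exp(z_j)} \;=\; \log \sum_{j} m_j \exp(z_j - c) \;-\; (z_y - c)$$

which is exactly log_deno_ - log_nume_ above.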

max_prediction_ = tf.one_hot(tf.argmax(exp_scaled_sub_logits_ * sub_adj_mask_, 1),
                           int(adj_mat_.get_shape()[1]),
                           dtype=tf.float32)  # [batch*max_seq_len, max_adj_num]
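
Finally, loss_ is still a per-step vector of length [batch*max_seq_len]; it has to be reduced to a scalar before training. A minimal sketch, assuming the padded time steps are marked by a 0/1 tensor mask_flat_ of the same length (a hypothetical name, not defined in the snippets above):

masked_loss_ = loss_ * mask_flat_                                     # zero out padded steps
mean_loss_ = tf.reduce_sum(masked_loss_) / tf.reduce_sum(mask_flat_)  # average over valid steps
train_op_ = tf.train.AdamOptimizer(1e-3).minimize(mean_loss_)         # any optimizer would do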