TensorFlow: switching between training and validation data in the same graph

Time: 2020-05-30

When training a model with TensorFlow, we sometimes need a validation set to judge whether the model has overfitted and to decide whether to stop training.

1. The first idea is to use tf.placeholder() to feed different data into the same graph, for example:


import tensorflow as tf

def inference(input_):
  """
  This is where you build your graph.
  The following is just an example.
  """
  conv1 = tf.layers.conv2d(input_, filters=32, kernel_size=3)
  conv2 = tf.layers.conv2d(conv1, filters=64, kernel_size=3)
  return conv2


input_ = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
output = inference(input_)
...
loss = ...
train_op = ...
...

with tf.Session() as sess:
  # Feed the training data for one training step.
  sess.run([loss, train_op], feed_dict={input_: train_data})

  # Feed the validation data through the same graph.
  if validation:
    sess.run([loss], feed_dict={input_: validate_data})

It’s simple and straightforward.

2. However, when the amount of data is large, feeding it through tf.placeholder() seriously slows down training, so data is often read from a tfrecords file instead.
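For reference, here is a hedged sketch of what a decode_train() reader might look like in the usual TF1 queue style; the file name, feature keys, image shape, and batch size below are all assumptions, not from the original article:

import tensorflow as tf

def decode_train():
  # Hypothetical TFRecord reader; adjust the file name, feature keys
  # and shapes to match your own data.
  filename_queue = tf.train.string_input_producer(['train.tfrecords'])
  reader = tf.TFRecordReader()
  _, serialized = reader.read(filename_queue)
  features = tf.parse_single_example(serialized, features={
      'image': tf.FixedLenFeature([], tf.string),
      'label': tf.FixedLenFeature([], tf.int64),
  })
  image = tf.decode_raw(features['image'], tf.uint8)
  image = tf.reshape(image, [28, 28, 1])
  image = tf.cast(image, tf.float32) / 255.0
  label = tf.cast(features['label'], tf.int32)
  # Batch the decoded examples with a shuffling queue.
  return tf.train.shuffle_batch([image, label], batch_size=32,
                                capacity=1000, min_after_dequeue=500)

decode_validation() would look the same but point at the validation tfrecords file, typically without shuffling.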

At this point, it's easy to think of simply passing the different batch tensors into the inference() function for the forward computation:


train_batch, label_batch = decode_train()
val_batch, val_label_batch = decode_validation()


train_result = inference(train_batch)
...
loss = ...
train_op = ...
...

if validation:
  # Second call to inference(): this silently creates a new set of layers.
  val_result = inference(val_batch)
  val_loss = ...


with tf.Session() as sess:
  sess.run([loss, train_op])

  if validation:
    sess.run([val_result, val_loss])

This method looks as if it lets you call inference() again to run forward propagation on the validation data. In reality, however, the second call adds a whole new set of nodes to the graph, and the parameters of those nodes are separate variables that must be initialized on their own. In other words, the validation pass does not use the trained weights at all.
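You can see the duplication directly by calling inference() twice and listing the graph's variables; a minimal sketch (the variable names are whatever tf.layers auto-generates):

import tensorflow as tf

def inference(input_):
  return tf.layers.conv2d(input_, filters=8, kernel_size=3)

x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
inference(x)  # creates conv2d/kernel and conv2d/bias
inference(x)  # creates a second, untrained set: conv2d_1/kernel, conv2d_1/bias

for v in tf.global_variables():
  print(v.name)
# conv2d/kernel:0  conv2d/bias:0  conv2d_1/kernel:0  conv2d_1/bias:0

Wrapping the layers in a tf.variable_scope with reuse=tf.AUTO_REUSE would make the two calls share weights, but the approach below avoids the second call altogether.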

3. Use a single tf.placeholder to control whether to train or to validate.


def inference(input_):
  ...
  ...
  ...

  return inference_result


train_batch, train_label = decode_train()
val_batch, val_label = decode_validation()

is_training = tf.placeholder(tf.bool, shape=())

# Select the input pipeline according to the placeholder.
x = tf.cond(is_training, lambda: train_batch, lambda: val_batch)
y = tf.cond(is_training, lambda: train_label, lambda: val_label)

logits = inference(x)
loss = cal_loss(logits, y)
train_op = optimize(loss)

with tf.Session() as sess:

  # Keep the fetched value separate; do not rebind the `loss` tensor.
  loss_val, _ = sess.run([loss, train_op], feed_dict={is_training: True})

  if validation:
    val_loss = sess.run(loss, feed_dict={is_training: False})

In this way, you create a branch condition inside one large graph and control whether to run training or validation simply by the value fed to the placeholder; both branches go through the same inference() call, so they share the trained weights.
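Putting it together, here is a hedged sketch of the surrounding training loop; num_steps is a hypothetical parameter, and decode_train()/decode_validation() are assumed to be queue-based readers, so the queue runners must be started:

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # Start the input queue threads for the TFRecord readers.
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)

  for step in range(num_steps):
    # One training step on a training batch.
    loss_val, _ = sess.run([loss, train_op], feed_dict={is_training: True})

    if step % 100 == 0:
      # Periodically evaluate on a validation batch with the same weights.
      val_loss = sess.run(loss, feed_dict={is_training: False})
      print('step %d: train loss %.4f, val loss %.4f' % (step, loss_val, val_loss))

  coord.request_stop()
  coord.join(threads)

One caveat the tf.cond documentation points out: because train_batch and val_batch are created outside the lambdas, both pipelines are dequeued on every sess.run, so each training step also consumes a validation batch. The branches still share the same trained weights, which is the point of this approach.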
