Testing new data on a trained TensorFlow model


There are two ways to use a trained TensorFlow model on new data. In the first, training and using the model happen in the same .py file, which is relatively simple. In the second, training and using the model are split across two separate .py files. This article covers the second approach.

Saving the model

TensorFlow provides an interface for saving a trained model (tf.train.Saver), and it is not difficult to use. Here is the code:

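import tensorflow as tf  # TensorFlow 1.x API

# in_units, h1_units, rate, train_x and train_y are assumed to be defined elsewhere.
# h1_units must be 10 so that the softmax output matches the 10 label columns of y_.

# Inputs, named so they can be looked up after the model is imported
x = tf.placeholder(tf.float32, [None, in_units], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')

# Network structure
w1 = tf.Variable(tf.truncated_normal([in_units, h1_units], stddev=0.1))
b1 = tf.Variable(tf.zeros([h1_units]))
y = tf.nn.softmax(tf.matmul(x, w1) + b1)
tf.add_to_collection('network-output', y)

# Loss function and optimization function
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(rate).minimize(cross_entropy)

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # One training step; in practice this runs in a loop over batches
    train_step.run({x: train_x, y_: train_y})
    # Writes model.ckpt.meta (graph and collections) plus the variable data files
    saver.save(sess, './model.ckpt')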
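For reference, here is a pure-Python sketch of what the forward pass and loss above compute, using nested lists in place of tensors. The helper names softmax, forward and cross_entropy are hypothetical illustrations, not TensorFlow APIs, and this is not how TensorFlow implements them internally. It mirrors y = softmax(x·w1 + b1) and the cross-entropy line, which sums each row (reduction_indices=[1]) and then takes the mean over the batch.

```python
import math


def softmax(row):
    # Subtract the row maximum for numerical stability; the result is unchanged.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def forward(x, w1, b1):
    # y = softmax(x . w1 + b1): x is [batch, in_units], w1 is [in_units, n_out].
    n_out = len(b1)
    logits = [
        [sum(xi * w1[i][j] for i, xi in enumerate(row)) + b1[j] for j in range(n_out)]
        for row in x
    ]
    return [softmax(r) for r in logits]


def cross_entropy(y, y_):
    # Sum -y_ * log(y) across each row, then average over the batch.
    per_row = [
        -sum(t * math.log(p) for t, p in zip(t_row, p_row))
        for p_row, t_row in zip(y, y_)
    ]
    return sum(per_row) / len(per_row)


x = [[1.0, 2.0]]                 # one example with in_units = 2
w1 = [[1.0, 0.0], [0.0, 1.0]]    # identity weights, so the logits equal x
b1 = [0.0, 0.0]
y_ = [[0.0, 1.0]]                # one-hot label for the second class

y = forward(x, w1, b1)
loss = cross_entropy(y, y_)
print(y)     # approximately [[0.2689, 0.7311]]
print(loss)  # approximately 0.3133
```

With logits [1, 2] and the label on the second class, the loss is -ln(e / (1 + e)) = ln(1 + e^-1), about 0.313. The sketch also shows why the product is x·w1 rather than w1·x: x has shape [batch, in_units] and w1 has shape [in_units, n_out].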

The code above trains the network and saves the model. One line deserves particular attention:

tf.add_to_collection('network-output', y)

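This line stores the network's output tensor y in a collection named 'network-output'. When Saver.save() writes the .meta file, it includes the graph's collections, so after the model is imported in another file the output tensor can be retrieved by that name. This is what makes the imported model usable.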
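To see that collections travel with the graph, here is a minimal sketch using the tf.compat.v1 API, so it also runs under TensorFlow 2.x. It skips variables and checkpoints entirely: tf.train.export_meta_graph() writes only the graph and its collections, and a brand-new graph then imports the file and finds the output tensor by collection name. The file name demo.meta and the tiny double-and-sum graph are hypothetical.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF1-style graphs, placeholders and sessions

meta_path = os.path.join(tempfile.mkdtemp(), 'demo.meta')

# "Training" side: build a tiny graph and register its output in a collection.
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, 2], name='x')
    out = tf.reduce_sum(x * 2.0, axis=1)
    tf.add_to_collection('network-output', out)
    # Writes the graph definition together with all of its collections.
    tf.train.export_meta_graph(filename=meta_path)

# "Testing" side: a fresh graph that knows nothing about how out was built.
with tf.Graph().as_default() as g:
    tf.train.import_meta_graph(meta_path)
    restored = tf.get_collection('network-output')[0]
    with tf.Session(graph=g) as sess:
        # Feed keys may also be tensor names such as 'x:0'.
        result = sess.run(restored, feed_dict={'x:0': [[1.0, 2.0], [3.0, 4.0]]})

print(result)  # [ 6. 14.]
```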

Importing the model

Once the model has been trained and saved, it can be imported to evaluate its performance on the test set. Many online tutorials only demonstrate this with a toy graph of simple arithmetic operations, which is hard to map onto a real network. Again, code first:

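import tensorflow as tf  # TensorFlow 1.x API

# test_x and test_y are assumed to be loaded the same way as the training data
with tf.Session() as sess:
    # Rebuild the graph structure from the .meta file
    saver = tf.train.import_meta_graph('./model.ckpt.meta')
    # Load the trained variable values from the checkpoint data files
    saver.restore(sess, './model.ckpt')
    # The output tensor registered in the collection during training
    pred = tf.get_collection('network-output')[0]

    graph = tf.get_default_graph()
    x = graph.get_operation_by_name('x').outputs[0]
    y_ = graph.get_operation_by_name('y_').outputs[0]

    y = sess.run(pred, feed_dict={x: test_x, y_: test_y})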
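The two halves can be checked end to end in a single script by giving each its own fresh tf.Graph, which stands in for the two separate .py files. This is a sketch using tf.compat.v1; the layer sizes, the random toy data, the 20 Adam steps and the temporary checkpoint path are all hypothetical. The restoring function never rebuilds the network by hand, yet it reproduces the predictions of the trained session.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs, sessions and placeholders

ckpt_path = os.path.join(tempfile.mkdtemp(), 'model.ckpt')
in_units, n_classes = 4, 3  # hypothetical sizes for this demo
rng = np.random.RandomState(0)
data_x = rng.rand(8, in_units).astype(np.float32)
data_y = np.eye(n_classes, dtype=np.float32)[rng.randint(0, n_classes, 8)]


def train_and_save():
    """Plays the role of the training script."""
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None, in_units], name='x')
        y_ = tf.placeholder(tf.float32, [None, n_classes], name='y_')
        w1 = tf.Variable(tf.truncated_normal([in_units, n_classes], stddev=0.1))
        b1 = tf.Variable(tf.zeros([n_classes]))
        y = tf.nn.softmax(tf.matmul(x, w1) + b1)
        tf.add_to_collection('network-output', y)
        loss = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), axis=1))
        train_step = tf.train.AdamOptimizer(0.01).minimize(loss)
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for _ in range(20):
                sess.run(train_step, feed_dict={x: data_x, y_: data_y})
            saver.save(sess, ckpt_path)
            return sess.run(y, feed_dict={x: data_x})


def restore_and_predict():
    """Plays the role of the separate testing script."""
    with tf.Graph().as_default() as graph:
        with tf.Session(graph=graph) as sess:
            saver = tf.train.import_meta_graph(ckpt_path + '.meta')
            saver.restore(sess, ckpt_path)
            pred = tf.get_collection('network-output')[0]
            x = graph.get_operation_by_name('x').outputs[0]
            # Only x is fed: pred does not depend on the label placeholder y_.
            return sess.run(pred, feed_dict={x: data_x})


before = train_and_save()
after = restore_and_predict()
print(np.allclose(before, after))  # True
```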

Now for the key lines. First, pred = tf.get_collection('network-output')[0] retrieves the output "interface" that was registered during training. tf.get_collection() returns a list of everything stored under the given name, so [0] picks out the network's output tensor y. The network structure itself was rebuilt by tf.train.import_meta_graph(), and saver.restore() loaded its trained weights. With the output tensor in hand, we need to feed it the corresponding data: y = sess.run(pred, feed_dict={x: test_x, y_: test_y}). During training, the inputs were:

x = tf.placeholder(tf.float32, [None, in_units], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')

The inputs used after importing the model must therefore correspond to these placeholders. Because they were given the names 'x' and 'y_', they can be looked up in the imported graph:

  x = graph.get_operation_by_name('x').outputs[0]
  y_ = graph.get_operation_by_name('y_').outputs[0]
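Why .outputs[0]? A name such as 'x' identifies an operation, while the tensor that operation produces is named 'x:0' (operation name plus output index). A small tf.compat.v1 sketch showing that both lookups reach the same tensor; the input size 784 is an arbitrary example:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

with tf.Graph().as_default() as graph:
    tf.placeholder(tf.float32, [None, 784], name='x')
    # The placeholder operation 'x' owns a single output tensor, 'x:0'.
    via_op = graph.get_operation_by_name('x').outputs[0]
    via_name = graph.get_tensor_by_name('x:0')

print(via_op.name)    # x:0
print(via_name.name)  # x:0
```

Either form works in the import code; graph.get_tensor_by_name('x:0') is simply a shorter way to write the same lookup.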

The final step is to feed in the test set and evaluate it with the trained network:

  sess.run(pred, feed_dict={x: test_x, y_: test_y})

To understand this line, look at the signature of sess.run():

run(fetches, feed_dict=None, options=None, run_metadata=None)

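TensorFlow evaluates the tensors or operations listed in fetches, substituting the values in feed_dict for the corresponding placeholders. So after the model is imported, running pred computes the network's output for the test inputs using the trained weights. Only the placeholders that the fetched tensors depend on need to be fed; since pred does not depend on y_, feeding test_y is optional here, and the labels only matter when fetching something like a loss or an accuracy.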
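A toy sketch of these semantics with tf.compat.v1 (the placeholders, the doubling "model" and the values are hypothetical): fetching pred only needs x, and passing a list as fetches returns a list of results. Fetching error without feeding y_ would raise an InvalidArgumentError, because that placeholder would have no value.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None], name='x')
    y_ = tf.placeholder(tf.float32, [None], name='y_')
    pred = x * 2.0                              # depends only on x
    error = tf.reduce_mean(tf.abs(pred - y_))   # depends on x and y_
    with tf.Session() as sess:
        # A single fetch: y_ can be left out of feed_dict.
        p = sess.run(pred, feed_dict={x: [1.0, 2.0]})
        # A list of fetches returns a list of results in the same order.
        p2, e = sess.run([pred, error], feed_dict={x: [1.0, 2.0], y_: [2.0, 5.0]})

print(p)  # [2. 4.]
print(e)  # 0.5
```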

That is all for testing new data on a trained TensorFlow model. I hope it serves as a useful reference, and I hope you will continue to support developepaer.