TensorFlow: how to save/restore a model?

Asked on November 3, 2018 in Python.

  • 3 Answer(s)

    Here is the solution to save or restore a model:

    import tensorflow as tf
    from tensorflow.python.saved_model import tag_constants
     
    with tf.Graph().as_default():
        with tf.Session() as sess:
            ...
            # Saving
            inputs = {
                "batch_size_placeholder": batch_size_placeholder,
                "features_placeholder": features_placeholder,
                "labels_placeholder": labels_placeholder,
            }
            outputs = {"prediction": model_output}
            tf.saved_model.simple_save(
                sess, 'path/to/your/location/', inputs, outputs
            )
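
    The `...` above stands for the code that builds the graph and defines the placeholders and `model_output` referenced in `inputs`/`outputs`. As a purely illustrative sketch (the shapes and the single dense layer are assumptions, chosen so that the output tensor is named 'dense/BiasAdd:0' as used in the restore snippet below), it might look like this:

    # Hypothetical graph construction for the elided `...` above.
    # Shapes, dtypes and the single dense layer are assumptions for illustration only.
    batch_size_placeholder = tf.placeholder(tf.int64, shape=[], name='batch_size_placeholder')
    features_placeholder = tf.placeholder(tf.float32, shape=[None, 10], name='features_placeholder')
    labels_placeholder = tf.placeholder(tf.int64, shape=[None], name='labels_placeholder')
    model_output = tf.layers.dense(features_placeholder, 1)  # first dense layer -> 'dense/BiasAdd:0'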
    

    Restoring:

    import tensorflow as tf
    from tensorflow.python.saved_model import tag_constants
     
    graph = tf.Graph()
    with graph.as_default():
        with tf.Session() as sess:
            tf.saved_model.loader.load(
                sess,
                [tag_constants.SERVING],
                'path/to/your/location/',
            )
            batch_size_placeholder = graph.get_tensor_by_name('batch_size_placeholder:0')
            features_placeholder = graph.get_tensor_by_name('features_placeholder:0')
            labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0')
            prediction = graph.get_tensor_by_name('dense/BiasAdd:0')
     
            sess.run(prediction, feed_dict={
                batch_size_placeholder: some_value,
                features_placeholder: some_other_value,
                labels_placeholder: another_value
            })
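
    If you are unsure of the output tensor name (the 'dense/BiasAdd:0' above depends on how the original graph was built), one way to find it is to list the operations in the restored graph. A minimal sketch, run after loading the SavedModel:

    # Print every operation name in the restored graph to locate the tensors you need.
    for op in graph.get_operations():
        print(op.name)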
    
    Answered on November 3, 2018.

    This is an improved answer.

    In (and after) TensorFlow version 0.11:
    Save the model:

    import tensorflow as tf
     
    #Prepare to feed input, i.e. feed_dict and placeholders
    w1 = tf.placeholder("float", name="w1")
    w2 = tf.placeholder("float", name="w2")
    b1= tf.Variable(2.0,name="bias")
    feed_dict ={w1:4,w2:8}
     
    #Define a test operation that we will restore
    w3 = tf.add(w1,w2)
    w4 = tf.multiply(w3,b1,name="op_to_restore")
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
     
    #Create a saver object which will save all the variables
    saver = tf.train.Saver()
     
    #Run the operation by feeding input
    print(sess.run(w4, feed_dict))
    # Prints 24.0, which is (w1 + w2) * b1 = (4 + 8) * 2
     
    #Now, save the graph
    saver.save(sess, 'my_test_model',global_step=1000)
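
    Because `global_step=1000` is passed, the checkpoint files are suffixed with the step number (e.g. my_test_model-1000.meta). A common pattern, sketched here as an assumption on top of the same session and saver, is to checkpoint periodically inside a training loop; `write_meta_graph=False` avoids rewriting the unchanged graph definition on every save:

    # Hypothetical periodic checkpointing; sess.run(w4, feed_dict) stands in for a real training op.
    for step in range(1, 5001):
        sess.run(w4, feed_dict)
        if step % 1000 == 0:
            saver.save(sess, 'my_test_model', global_step=step,
                       write_meta_graph=False)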
    

    Restore the model:

    import tensorflow as tf
     
    sess=tf.Session()
    #First let's load meta graph and restore weights
    saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./'))
     
    # Access saved Variables directly
    print(sess.run('bias:0'))
    # This will print 2.0, which is the value of the bias that we saved
     
    # Now, let's access and create placeholders variables and
    # create feed-dict to feed new data
     
    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name("w1:0")
    w2 = graph.get_tensor_by_name("w2:0")
    feed_dict ={w1:13.0,w2:17.0}
     
    #Now, access the op that you want to run.
    op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
     
    print(sess.run(op_to_restore, feed_dict))
    # This will print 60.0, which is (13 + 17) * 2
    

    This and some more advanced use cases, such as adding new operations on top of a restored graph for fine-tuning, are explained well in TensorFlow community tutorials; one such use case is sketched below.
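
    As one example of such an advanced use case, here is a hedged sketch (reusing the my_test_model-1000 checkpoint saved above) that restores the graph and adds a new operation on top of the restored one, as you would when fine-tuning:

    import tensorflow as tf
     
    sess = tf.Session()
    saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
     
    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name("w1:0")
    w2 = graph.get_tensor_by_name("w2:0")
    op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
     
    # Add a new operation on top of the restored one (illustrative only)
    add_on_op = tf.multiply(op_to_restore, 2.0, name="add_on_op")
    print(sess.run(add_on_op, {w1: 4, w2: 8}))
    # Prints 48.0, i.e. ((4 + 8) * 2) * 2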

    Answered on November 3, 2018.

    From TensorFlow version 0.11.0RC1 onward, you can save and restore your model directly by calling tf.train.export_meta_graph and tf.train.import_meta_graph.

    Save the model

    w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
    w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
    tf.add_to_collection('vars', w1)
    tf.add_to_collection('vars', w2)
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'my-model')
    # The `save` method calls `export_meta_graph` implicitly,
    # so you will get the saved graph file: my-model.meta
    

    Restore the model

    sess = tf.Session()
    new_saver = tf.train.import_meta_graph('my-model.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))
    all_vars = tf.get_collection('vars')
    for v in all_vars:
        v_ = sess.run(v)
        print(v_)
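
    If you only need one of the restored variables, you can also look it up by its graph name instead of iterating over the collection; a minimal sketch, assuming the same session and 'my-model' checkpoint:

    # Fetch a single restored variable by name rather than via the 'vars' collection.
    graph = tf.get_default_graph()
    w1_restored = graph.get_tensor_by_name('w1:0')
    print(sess.run(w1_restored))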
    
    Answered on November 3, 2018.

