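TensorFlow: saving parameters separately for the training set and the test set in one session fails, and only one set of parameters ends up saved. How can I fix this?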

程宏  

2017-12-01

Hi everyone, I have a question.

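When training with TensorFlow, I want to run predictions on the training set and the test set at fixed step intervals, print the results to the console, and also write the accuracy to TensorBoard. My experiment keeps three step/accuracy/weight-parameter records: 1) every 500 steps, save the training-set result; 2) every 2000 steps, evaluate and save the training-set result; 3) every 5000 steps, evaluate on the test dataset and save the result.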
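The three-interval schedule described here can be sketched in plain Python. This is only an illustration of the intent: the function name `actions_for_step`, the action labels, and the default intervals (taken from the description, not from the code further down) are assumptions.

```python
def actions_for_step(step, max_step,
                     train_every=500, train_eval_every=2000, test_every=5000):
    """Return the logging/saving actions that are due at `step`.

    Hypothetical helper: the labels are illustrative, and the intervals
    follow the prose description (500 / 2000 / 5000 steps).
    """
    actions = []
    if step % train_every == 0:
        actions.append('save_train')
    if step % train_eval_every == 0:
        actions.append('eval_train')
    # Also evaluate on the test set at the very last step.
    if step % test_every == 0 or step + 1 == max_step:
        actions.append('eval_test')
    return actions


if __name__ == '__main__':
    for s in (0, 500, 2000, 5000, 10000):
        print(s, actions_for_step(s, max_step=20000))
```

Keeping the schedule in one small function makes it easy to check which records should exist at a given step before looking at what was actually written to disk.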

But when I save, in the end only one of the save operations appears to have kept running.

 

Here you can see that saving parameters under val has already stopped.

Here is the TensorBoard loss plot as well.


How can I solve this problem?

The relevant code is attached below.

TensorFlow version 1.2

Anaconda Python 3.6, 64-bit

Dual GTX 1080 GPUs

Ubuntu 14.04 LTS, 64-bit

       
        # Assumes: import os, import numpy as np, import tensorflow as tf (1.x)

        # one saver and three writers
        saver = tf.train.Saver()

        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        summary_op = tf.summary.merge_all()

        # add the writers (this creates the log directories)
        train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
        val_writer = tf.summary.FileWriter(logs_val_dir, sess.graph)
        train_writer2 = tf.summary.FileWriter(logs_train_dir_2, sess.graph)

        try:
            for step in np.arange(MAX_STEP):
                if coord.should_stop():
                    break

                # early stopping
                if es_number - decrease_loss >= stop_point:
                    break

                train_xqs, train_yqs = sess.run([train_xq_batch, train_yq_batch])
                tra_images, tra_labels = sess.run([train_batch, train_label_batch])

                # one training step
                tra_logitsOut, tra_loss, _ = sess.run(
                    [logits, loss, train_op],
                    feed_dict={x: tra_images, y_: tra_labels})

                # new best training loss: log and checkpoint into logs_train_dir_2
                if minLoss_train2 > tra_loss:
                    minLoss_train2 = tra_loss
                    summary_str_train2 = sess.run(summary_op)
                    train_writer2.add_summary(summary_str_train2, step)
                    checkpoint_path_train2 = os.path.join(logs_train_dir_2, 'model.ckpt')
                    saver.save(sess, checkpoint_path_train2, global_step=step)

                # every 2000 steps: log and checkpoint into logs_train_dir
                if step % 2000 == 0:
                    print('Step %d, train loss = %.6f ' % (step, tra_loss))
                    summary_str = sess.run(summary_op)
                    train_writer.add_summary(summary_str, step)

                    checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)

                # every 5000 steps and at the last step: evaluate on the validation batch
                if step % 5000 == 0 or (step + 1) == MAX_STEP:
                    val_xqs, val_yqs = sess.run([val_xq_batch, val_yq_batch])
                    val_images, val_labels = sess.run([val_batch, val_label_batch])
                    val_logitsOut, val_loss = sess.run(
                        [logits, loss],
                        feed_dict={x: val_images, y_: val_labels})
                    print('**  Step %d, val loss = %.6f **' % (step, val_loss))
                    val_loss = sess.run(loss, feed_dict={x: val_images, y_: val_labels})

                    # new best validation loss: log and checkpoint into logs_val_dir
                    if minLoss_val > val_loss:
                        minLoss_val = val_loss
                        decrease_loss = decrease_loss + 1
                        es_number = 0
                        summary_str_val = sess.run(summary_op)
                        val_writer.add_summary(summary_str_val, step)
                        checkpoint_path_val = os.path.join(logs_val_dir, 'model.ckpt')
                        saver.save(sess, checkpoint_path_val, global_step=step)
                    else:
                        es_number = es_number + 1
                        decrease_loss = 0
        finally:
            # stop and join the input queue threads
            coord.request_stop()
            coord.join(threads)
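
One possible explanation, offered as an assumption rather than a confirmed diagnosis: in TensorFlow 1.x a `tf.train.Saver` keeps only its most recent `max_to_keep` checkpoints (5 by default) and deletes older ones, and it tracks the files it wrote rather than the directories. With a single saver shared across three directories, the frequent saves into one directory could push the checkpoints in the other directories out of that window. The toy model below illustrates this bookkeeping in plain Python; `ToySaver`, `run`, and the save intervals are hypothetical and are not TensorFlow APIs.

```python
from collections import deque


class ToySaver:
    """Toy stand-in for a checkpoint saver with a `max_to_keep` limit.

    Not TensorFlow code: it only models the bookkeeping, using a set of
    path strings in place of files on disk.
    """

    def __init__(self, disk, max_to_keep=5):
        self.disk = disk              # shared set of "files"
        self.max_to_keep = max_to_keep
        self.recent = deque()         # paths this saver wrote, oldest first

    def save(self, directory, step):
        path = '%s/model.ckpt-%d' % (directory, step)
        if path in self.recent:       # re-saving a path moves it to the end
            self.recent.remove(path)
        self.disk.add(path)
        self.recent.append(path)
        # Evict the oldest checkpoints, whatever directory they live in.
        while len(self.recent) > self.max_to_keep:
            self.disk.discard(self.recent.popleft())
        return path


def run(shared_saver, steps=20):
    """Simulate the three save schedules; return surviving 'val' checkpoints."""
    disk = set()
    names = ('train', 'val', 'train2')
    if shared_saver:
        one = ToySaver(disk)
        savers = {name: one for name in names}
    else:
        savers = {name: ToySaver(disk) for name in names}
    for step in range(steps):
        savers['train2'].save('train2', step)   # frequent "best loss" saves
        if step % 5 == 0:
            savers['train'].save('train', step)
        if step % 10 == 0:
            savers['val'].save('val', step)
    return sorted(p for p in disk if p.startswith('val/'))


if __name__ == '__main__':
    print('shared saver   :', run(shared_saver=True))
    print('separate savers:', run(shared_saver=False))
```

In this toy run the shared saver leaves no checkpoints under `val`, while separate savers keep both. If the same mechanism is at work in the real code, creating one `tf.train.Saver()` per log directory, or passing a larger `max_to_keep`, would be the corresponding change to try.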

 
