You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

169 lines
6.4 KiB

#!/usr/bin/env python
from SAG_network import *
from timer import Timer
from utils import *
import os
import time
import numpy as np
# Load the experiment configuration once at import time; train_net and
# the __main__ guard below both read from this module-level dict.
c_dict = load_from_yml("config.yml")
# Restrict TensorFlow to the GPUs listed in the config.
# NOTE(review): set before any TF session is created, which presumably is
# why it happens at import time rather than inside train_net -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, c_dict['TRAIN']['GPU_ID']))
def snapshot(sess, saver, output_dir, cur_iter, cur_shape, cur_time):
    """Save a checkpoint of the trainable parameters.

    Checkpoints for one training run are grouped under
    ``<output_dir>/<cur_shape>_<cur_time>/<cur_shape>_<cur_iter>.ckpt``.

    Args:
        sess: active TF session whose variables are saved.
        saver: tf.train.Saver (or compatible object) used to write the checkpoint.
        output_dir: root directory for all checkpoints.
        cur_iter: current training iteration, embedded in the file name.
        cur_shape: shape-category name, e.g. 'chair'.
        cur_time: timestamp string identifying this training run.
    """
    # The original code guarded on `cur_dir_name is not None`, but the name is
    # always a freshly formatted string, so the else-branch was dead code.
    # Paths are now built with os.path.join instead of raw string concatenation,
    # so output_dir no longer needs a trailing separator.
    current_output_dir = os.path.join(output_dir, '%s_%s' % (cur_shape, cur_time))
    if not os.path.exists(current_output_dir):
        # makedirs creates intermediate directories, so output_dir itself
        # does not need a separate existence check.
        os.makedirs(current_output_dir)
    filename = os.path.join(current_output_dir, '%s_%i.ckpt' % (cur_shape, cur_iter))
    saver.save(sess, filename)
    # Single-argument print() is valid in both Python 2 and 3.
    print('Wrote snapshot to: {:s}'.format(filename))
def train_net(sess, config_dict):
    """Train the SAGNet model.

    Builds the network, optionally restores a pretrained checkpoint, then
    runs the training loop, periodically writing TensorBoard summaries,
    checkpoint snapshots, and sample voxel/bounding-box outputs.

    Args:
        sess: active TF session used for all graph execution.
        config_dict: configuration loaded from config.yml. Its 'TRAIN'
            section supplies GPU ids, directories, and frequencies.
            Mutated in place: 'TRAIN''NUM_GPUS' and 'MAX_PART_SIZE' are added.

    NOTE(review): the source this was recovered from had its indentation
    stripped; loop/branch nesting below was reconstructed from the control
    flow and should be confirmed against the original repository.
    """
    config_dict['TRAIN']['NUM_GPUS'] = len(config_dict['TRAIN']['GPU_ID'])
    shape_name = config_dict['TRAIN']['SHAPE_NAME']

    # Maximum number of semantic parts per shape category.  Unknown
    # categories fall back to 0, matching the original if/elif chain.
    part_size_by_shape = {
        'motorbike': 5,
        'chair': 5,
        'airplane': 6,
        'guitar': 3,
        'lamp': 4,
        'toy_examples': 2,
    }
    config_dict['MAX_PART_SIZE'] = part_size_by_shape.get(shape_name, 0)

    # Construct and initialize a data_runner.
    data_helper = data_runner(config_dict, for_training=True)
    inputs = data_helper.get_placeholders()

    # Seed both TF and NumPy for reproducibility.
    random_seed = config_dict['TRAIN']['RANDOM_SEED']
    tf.set_random_seed(random_seed)
    np.random.seed(random_seed)

    inputs['config_dict'] = config_dict

    total_timer = Timer()
    total_timer.tic()
    cur_time = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
    cur_dir_name = "%s_%s" % (shape_name, cur_time)

    # Renamed from `train_net`: the original shadowed this function's own
    # name with the network instance.
    net = SAGNet(inputs, data_helper=data_helper, for_training=True)
    print("setup GCDNet.....................................................")
    net.setup()
    total_time_diff = total_timer.toc()
    print("------------------------------------------------------------------------------")
    for variable in tf.trainable_variables():
        shape = variable.get_shape()
        print("shape:" + str(shape))
        print("name:" + variable.name)
    print("Initial time: step 1: build the network: %g" % total_time_diff)

    results_dir = config_dict['TRAIN']['RESULTS_DIRECTORY']
    results_dir = os.path.join(results_dir, cur_dir_name)
    model_dir = config_dict['TRAIN']['MODEL_DIRECTORY']
    log_dir = config_dict['TRAIN']['LOG_DIRECTORY']
    checkpoint_path = config_dict['TRAIN']['PRETRAINED_MODEL_PATH']
    snapshot_freq = config_dict['TRAIN']['SNAPSHOT_FREQ']
    summary_freq = config_dict['TRAIN']['SUMMARY_FREQ']

    current_log_dir = os.path.join(log_dir, cur_dir_name)
    train_writer = tf.summary.FileWriter(current_log_dir, sess.graph)

    # Write the model info to file.
    model_info_dir = os.path.join(model_dir, cur_dir_name)
    data_helper.write_model_info_to_file(model_info_dir)

    iter_timer = Timer()
    saver = tf.train.Saver(tf.trainable_variables())
    sess.run(tf.global_variables_initializer())

    # Optionally resume from a pretrained checkpoint.  The starting
    # iteration is parsed from the file name ("<shape>_<iter>.ckpt").
    if checkpoint_path != '' and checkpoint_path.endswith('.ckpt'):
        print("Restore model from checkpoints...")
        saver.restore(sess, checkpoint_path)
        restore_epoch = os.path.splitext(os.path.basename(checkpoint_path))[0]
        restore_epoch = int(restore_epoch.split('_')[-1])
        print("Restore done. Current starting iteration: %d" % restore_epoch)
    else:
        restore_epoch = 0

    current_iter = -1
    last_snapshot_iter = -1
    n_epochs = config_dict['TRAIN']['ITER_NUM']
    for cur_iter in range(restore_epoch, n_epochs):
        current_iter = cur_iter
        iter_timer.tic()
        # Load the input data for the next mini-batch.
        f_dict = data_helper.get_next_minibatch(cur_iter=cur_iter)
        if cur_iter == 0:
            print("train neural network---------------------------------------------------")
        # Write summaries every `summary_freq` iterations, and on each of
        # the first 20 so early training is visible in TensorBoard.
        if (cur_iter + 1) % summary_freq == 0 or cur_iter < 20:
            _, summaries = net.train(sess, f_dict, cur_iter, is_summary=True)
            train_writer.add_summary(summaries, cur_iter)
        else:
            _ = net.train(sess, f_dict, cur_iter, is_summary=False)
        iter_diff = iter_timer.toc()
        print("Current iter: %d, time diff: %g" % (cur_iter, iter_diff))
        if (cur_iter + 1) % snapshot_freq == 0:
            last_snapshot_iter = cur_iter
            snapshot(sess, saver, model_dir, cur_iter, cur_shape=shape_name, cur_time=cur_time)
            # Dump reconstructed voxels / bounding boxes for inspection.
            voxels_list, bboxs_list, part_visible_masks = net.get_batch_info(sess, f_dict)
            data_helper.write_output_to_file(voxels_list=voxels_list, bboxs_list=bboxs_list,
                                             part_visible_masks=part_visible_masks,
                                             input_info_dict=f_dict, output_dir=results_dir,
                                             iter_n=current_iter)

    # Final snapshot + output dump if the last iteration was not already
    # covered by the periodic snapshot above.
    if last_snapshot_iter != current_iter:
        snapshot(sess, saver, model_dir, current_iter, cur_shape=shape_name, cur_time=cur_time)
        f_dict = data_helper.get_next_minibatch(cur_iter=current_iter)
        voxels_list, bboxs_list, part_visible_masks = net.get_batch_info(sess, f_dict)
        data_helper.write_output_to_file(voxels_list=voxels_list, bboxs_list=bboxs_list,
                                         part_visible_masks=part_visible_masks,
                                         input_info_dict=f_dict, output_dir=results_dir,
                                         iter_n=current_iter)
if __name__ == "__main__":
    # Soft placement lets TF fall back to CPU for ops without a GPU kernel.
    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as sess:
        train_net(sess=sess, config_dict=c_dict)