You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
169 lines
6.4 KiB
169 lines
6.4 KiB
#!/usr/bin/env python
|
|
from SAG_network import *
|
|
from timer import Timer
|
|
from utils import *
|
|
import os
|
|
import time
|
|
import numpy as np
|
|
|
|
# Global training configuration, loaded once at import time from config.yml.
c_dict = load_from_yml("config.yml")

# Restrict TensorFlow to the GPUs listed in the config
# (comma-separated device ids, e.g. "0,1").
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, c_dict['TRAIN']['GPU_ID']))
|
|
|
|
def snapshot(sess, saver, output_dir, cur_iter, cur_shape, cur_time):
    """Save a checkpoint of the network's trainable parameters.

    The checkpoint is written into a per-run sub-directory of
    ``output_dir`` named ``"<cur_shape>_<cur_time>/"``, with filename
    ``"<cur_shape>_<cur_iter>.ckpt"``.

    Args:
        sess: active TF session whose variables are saved.
        saver: ``tf.train.Saver`` (or compatible object with ``save(sess, path)``).
        output_dir: root directory for checkpoints; created if missing.
        cur_iter: current training iteration, embedded in the filename.
        cur_shape: shape-category name (e.g. ``'chair'``).
        cur_time: run timestamp string used in the sub-directory name.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # NOTE: `"%s_%s" %` always produces a str, so the original
    # `if cur_dir_name is not None` check was always true and its
    # `else` branch was dead code; it has been removed.
    cur_dir_name = "%s_%s" % (cur_shape, cur_time)

    # String concatenation (not os.path.join) is kept deliberately to
    # preserve the original path layout when output_dir already ends
    # with a separator.
    current_output_dir = output_dir + cur_dir_name + "/"
    if not os.path.exists(current_output_dir):
        os.makedirs(current_output_dir)
    filename = os.path.join(current_output_dir, '%s_%i.ckpt' % (cur_shape, cur_iter))

    saver.save(sess, filename)
    # print(...) with one argument behaves identically on Python 2 and 3.
    print('Wrote snapshot to: {:s}'.format(filename))
|
|
|
|
|
|
def train_net(sess, config_dict):
    """Build SAGNet and run the full training loop.

    Sets up the data runner and network, optionally restores a pretrained
    checkpoint, trains for ``config_dict['TRAIN']['ITER_NUM']`` iterations
    (writing summaries and periodic snapshots), and finally dumps the
    generated voxels/bounding boxes of the last mini-batch to disk.

    Args:
        sess: active ``tf.Session`` used for all graph execution.
        config_dict: nested config loaded from config.yml; mutated in place
            (``TRAIN.NUM_GPUS`` and ``MAX_PART_SIZE`` are filled in here).
    """
    config_dict['TRAIN']['NUM_GPUS'] = len(config_dict['TRAIN']['GPU_ID'])

    # Maximum number of semantic parts per shape category.
    # (Dict lookup replaces the original if/elif ladder; unknown names
    # keep the original default of 0.)
    max_part_sizes = {
        'motorbike': 5,
        'chair': 5,
        'airplane': 6,
        'guitar': 3,
        'lamp': 4,
        'toy_examples': 2,
    }
    shape_name = config_dict['TRAIN']['SHAPE_NAME']
    config_dict['MAX_PART_SIZE'] = max_part_sizes.get(shape_name, 0)

    # Construct and initialize a data_runner; it owns the input pipeline.
    data_helper = data_runner(config_dict, for_training=True)
    inputs = data_helper.get_placeholders()

    # Seed both TF and numpy so runs are reproducible.
    random_seed = config_dict['TRAIN']['RANDOM_SEED']
    tf.set_random_seed(random_seed)
    np.random.seed(random_seed)

    inputs['config_dict'] = config_dict

    total_timer = Timer()
    total_timer.tic()

    # Per-run directory name: "<shape>_<timestamp>".
    cur_time = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
    cur_dir_name = "%s_%s" % (shape_name, cur_time)

    # NOTE: renamed from `train_net` to `net` — the original local variable
    # shadowed this function's own name.
    net = SAGNet(inputs, data_helper=data_helper, for_training=True)
    print("setup GCDNet.....................................................")
    net.setup()
    total_time_diff = total_timer.toc()

    print("------------------------------------------------------------------------------")
    for variable in tf.trainable_variables():
        shape = variable.get_shape()
        print("shape:" + str(shape))
        print("name:" + variable.name)

    print("Initial time: step 1: build the network: %g" % (total_time_diff))

    results_dir = config_dict['TRAIN']['RESULTS_DIRECTORY']
    results_dir = os.path.join(results_dir, cur_dir_name)

    model_dir = config_dict['TRAIN']['MODEL_DIRECTORY']
    log_dir = config_dict['TRAIN']['LOG_DIRECTORY']
    checkpoint_path = config_dict['TRAIN']['PRETRAINED_MODEL_PATH']

    snapshot_freq = config_dict['TRAIN']['SNAPSHOT_FREQ']
    summary_freq = config_dict['TRAIN']['SUMMARY_FREQ']

    current_log_dir = os.path.join(log_dir, cur_dir_name)
    train_writer = tf.summary.FileWriter(current_log_dir, sess.graph)

    # Write the model info to file so the run can be reproduced later.
    model_info_dir = os.path.join(model_dir, cur_dir_name)
    data_helper.write_model_info_to_file(model_info_dir)

    iter_timer = Timer()

    saver = tf.train.Saver(tf.trainable_variables())

    sess.run(tf.global_variables_initializer())

    # Optionally resume from a pretrained checkpoint; restore AFTER the
    # global initializer so restored weights are not clobbered.
    if checkpoint_path != '' and checkpoint_path.endswith('.ckpt'):
        print("Restore model from checkpoints...")
        saver.restore(sess, checkpoint_path)
        print("Restore done.")

        # Checkpoint files are named "<shape>_<iter>.ckpt"; recover the
        # starting iteration from the trailing "_<iter>" token.
        restore_epoch = os.path.splitext(os.path.basename(checkpoint_path))[0]
        restore_epoch = int(restore_epoch.split('_')[-1])

        print("Restore done. Current starting iteration: %d" % restore_epoch)
    else:
        restore_epoch = 0

    current_iter = -1
    last_snapshot_iter = -1
    n_epochs = config_dict['TRAIN']['ITER_NUM']

    for cur_iter in range(restore_epoch, n_epochs):
        current_iter = cur_iter
        iter_timer.tic()

        f_dict = data_helper.get_next_minibatch(cur_iter=cur_iter)  # Load the input data for next mini-batch

        if cur_iter == 0:
            print("train neural network---------------------------------------------------")

        # Write TensorBoard summaries every `summary_freq` steps, and on the
        # first 20 iterations so early training is visible.
        if (cur_iter + 1) % summary_freq == 0 or cur_iter < 20:
            _, summaries = net.train(sess, f_dict, cur_iter, is_summary=True)
            train_writer.add_summary(summaries, cur_iter)
        else:
            _ = net.train(sess, f_dict, cur_iter, is_summary=False)

        iter_diff = iter_timer.toc()
        print("Current iter: %d, time diff: %g" % (cur_iter, iter_diff))

        if (cur_iter + 1) % snapshot_freq == 0:
            last_snapshot_iter = cur_iter
            snapshot(sess, saver, model_dir, cur_iter, cur_shape=shape_name, cur_time=cur_time)

            # Dump the generated voxels / bounding boxes for inspection.
            voxels_list, bboxs_list, part_visible_masks = net.get_batch_info(sess, f_dict)
            data_helper.write_output_to_file(voxels_list=voxels_list, bboxs_list=bboxs_list,
                                             part_visible_masks=part_visible_masks,
                                             input_info_dict=f_dict, output_dir=results_dir,
                                             iter_n=current_iter)

    # Final snapshot/output if the loop did not end exactly on a snapshot step.
    if last_snapshot_iter != current_iter:
        snapshot(sess, saver, model_dir, current_iter, cur_shape=shape_name, cur_time=cur_time)

        f_dict = data_helper.get_next_minibatch(cur_iter=current_iter)

        voxels_list, bboxs_list, part_visible_masks = net.get_batch_info(sess, f_dict)
        data_helper.write_output_to_file(voxels_list=voxels_list, bboxs_list=bboxs_list,
                                         part_visible_masks=part_visible_masks,
                                         input_info_dict=f_dict, output_dir=results_dir, iter_n=current_iter)
|
|
|
|
|
|
if __name__ == "__main__":
    # allow_soft_placement lets TF fall back to a compatible device when an
    # op has no kernel for the requested one (e.g. a CPU-only op under GPU).
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        train_net(sess=sess, config_dict=c_dict)
|
|
|