You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

159 lines
5.9 KiB

#!/usr/bin/env python
from SAG_network import *
from utils import *
import os
import time
import math
# Load the global configuration (test paths, GPU ids, batch size) once at import time.
c_dict = load_from_yml("config.yml")
# Restrict TensorFlow to the GPUs listed under TEST.GPU_ID, e.g. [0, 1] -> "0,1".
# Must be set before any TF session is created for it to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, c_dict['TEST']['GPU_ID']))
def _parse_model_info_file(info_path):
    """Parse one ``model_info.txt`` file.

    The file is a sequence of blank-line-separated sections, each opened by a
    header line containing one of ``model_class``, ``bbox_mean``,
    ``bbox_variance`` or ``bbox_benchmark`` and followed by its data lines.

    Returns:
        (model_class, bbox_dict): ``model_class`` is a str (or None when the
        section is absent); ``bbox_dict`` maps 'bbox_mean' / 'bbox_variance' /
        'bbox_benchmark' to 2-D numpy arrays, one row per data line.
    """
    # Header priority mirrors the original if-chain order.
    headers = ('model_class', 'bbox_mean', 'bbox_variance', 'bbox_benchmark')
    model_class = None
    rows = {'bbox_mean': [], 'bbox_variance': [], 'bbox_benchmark': []}
    current = None  # name of the section we are inside, or None
    # 'with' closes the handle (the original opened the file and leaked it).
    with open(info_path, 'r') as info_file:
        for line in info_file:
            header = next((h for h in headers if line.find(h) != -1), None)
            if header is not None:
                current = header
                continue
            if line == "\n":
                # A blank line terminates the current section.
                current = None
                continue
            if current == 'model_class':
                # Drop only the trailing newline, as the original did.
                model_class = str(line[:-1])
            elif current in rows:
                rows[current].append(np.array([float(t) for t in line.split()]))
    bbox_dict = dict((key, np.array(value)) for key, value in rows.items())
    return model_class, bbox_dict


def load_model_info(model_dir):
    """Collect pretrained-model metadata from *model_dir*.

    Scans the directory for a ``*model_info.txt`` file (shape class and bbox
    normalization statistics) and a ``*.mat`` file holding the part masks
    under the 'masks' key.

    Returns:
        dict with keys 'model_class', 'bbox_info' (dict of numpy arrays:
        'bbox_mean', 'bbox_variance', 'bbox_benchmark') and 'part_masks' as
        soon as both files have been read, or None when either is missing.
    """
    info_dict = {}
    for fname in os.listdir(model_dir):
        fpath = os.path.join(model_dir, fname)
        if fname.find("model_info.txt") != -1:
            model_class, bbox_dict = _parse_model_info_file(fpath)
            if model_class is not None:
                info_dict['model_class'] = model_class
            info_dict['bbox_info'] = bbox_dict
        if fname.find(".mat") != -1:
            info_dict['part_masks'] = io.loadmat(fpath)['masks']
        # Use .get(): the original indexed info_dict directly here and raised
        # KeyError whenever os.listdir returned the .mat file before the txt.
        if info_dict.get('bbox_info') is not None and info_dict.get('part_masks') is not None:
            return info_dict
    return None
def test_net(sess, config_dict):
    """Synthesize 3D shapes with a pretrained SAGNet and write them to disk.

    Args:
        sess: an open tf.Session used to restore the checkpoint and run the net.
        config_dict: parsed config.yml; its TEST section supplies the results
            directory, checkpoint path and sample count, TRAIN the batch size.

    Raises:
        KeyError: when the pretrained-model directory lacks the model-info
            text file or the part-mask .mat file.
    """
    output_dir = config_dict['TEST']['RESULTS_DIRECTORY']
    checkpoint_path = config_dict['TEST']['PRETRAINED_MODEL_PATH']
    checkpoint_dir = os.path.abspath(os.path.dirname(checkpoint_path))
    info_dict = load_model_info(checkpoint_dir)
    if info_dict is None:
        raise KeyError("Can not find model info or part mask files in pretrained model directory!!!")
    shape_name = info_dict['model_class']
    config_dict['model_info'] = info_dict
    # Maximum number of parts per shape category; unknown categories fall back
    # to 0, matching the original if/elif chain that left max_part_size at 0.
    part_sizes = {
        'motorbike': 5,
        'chair': 5,
        'airplane': 6,
        'guitar': 3,
        'lamp': 4,
        'toy_examples': 2,
    }
    config_dict['MAX_PART_SIZE'] = part_sizes.get(shape_name, 0)
    # Construct and initialize a data_runner.
    data_helper = data_runner(config_dict, for_training=False)
    inputs = data_helper.get_placeholders()
    inputs['config_dict'] = config_dict
    # NOTE: the original bound this to a local named `test_net`, shadowing
    # this very function; renamed to avoid the confusion.
    net = SAGNet(inputs, data_helper=data_helper, for_training=False)
    print("setup SAGNet.....................................................")
    net.setup()
    cur_time = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
    cur_dir_name = "%s_%s" % (shape_name, cur_time)
    # Presumably created by write_output_to_file below -- TODO confirm it
    # handles a non-existent results directory.
    results_directory = os.path.join(output_dir, cur_dir_name)
    saver = tf.train.Saver(tf.trainable_variables())
    sess.run(tf.global_variables_initializer())
    # Restore the trained model.
    if checkpoint_path != '' and checkpoint_path.endswith('.ckpt'):
        print("Restore model from checkpoints...")
        saver.restore(sess, checkpoint_path)
        print("Restore done.")
    n_samples = config_dict['TEST']['SAMPLE_SIZE']
    batch_size = config_dict['TRAIN']['BATCH_SIZE']
    # Ceiling division. The original computed math.ceil(n / batch), which
    # under Python 2 integer division floors first and so under-generated
    # samples whenever SAMPLE_SIZE was not a multiple of BATCH_SIZE.
    n_epochs = int(math.ceil(float(n_samples) / batch_size))
    print("Starting to synthesize 3D shapes.")
    for cur_iter in range(n_epochs):
        f_dict = data_helper.get_inputs_for_testing(cur_iter=cur_iter)
        voxels_list, bboxs_list, part_visible_masks = net.test(sess, f_dict)
        data_helper.write_output_to_file(voxels_list=voxels_list, bboxs_list=bboxs_list,
                                         part_visible_masks=part_visible_masks, input_info_dict=f_dict,
                                         output_dir=results_directory, iter_n=cur_iter)
        print("[%6d/%6d] %s Generation Done" % (int(cur_iter), int(n_epochs), shape_name))
if __name__ == "__main__":
# with tf.Session() as sess:
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
test_net(sess=sess, config_dict=c_dict)