You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
159 lines
5.9 KiB
159 lines
5.9 KiB
#!/usr/bin/env python
|
|
from SAG_network import *
|
|
from utils import *
|
|
import os
|
|
import time
|
|
import math
|
|
|
|
# Parsed configuration, loaded once at import time from "config.yml".
c_dict = load_from_yml("config.yml")

# Expose only the GPUs listed under TEST.GPU_ID to CUDA. GPU_ID may be a list
# of ids, or a single id (e.g. `GPU_ID: 0` or `GPU_ID: "0,1"` in YAML), which
# is used verbatim instead of being iterated character by character.
_gpu_ids = c_dict['TEST']['GPU_ID']
if not isinstance(_gpu_ids, (list, tuple)):
    _gpu_ids = [_gpu_ids]
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, _gpu_ids))
def _parse_float_row(line):
    """Parse one whitespace-separated line of numbers into a 1-D float array."""
    # A list (not a lazy `map`) so np.array builds a numeric array on Python 3.
    return np.array([float(token) for token in line.split()])


def _parse_model_info_file(info_path):
    """Parse a ``model_info.txt`` file into its model class and bbox statistics.

    The file is a sequence of sections. A line containing ``model_class``,
    ``bbox_mean``, ``bbox_variance`` or ``bbox_benchmark`` starts that section
    (keywords are tested in that order; the first substring match wins). Each
    following non-blank line is a data line of the current section, until a
    blank line ends the section or another header starts a new one. Data
    lines outside any section are ignored.

    Args:
        info_path: path to the model info text file.

    Returns:
        ``(model_class, bbox_dict)``: ``model_class`` is the stripped text of
        the last data line in the ``model_class`` section, or ``None`` if the
        section is absent. ``bbox_dict`` maps ``'bbox_mean'``,
        ``'bbox_variance'`` and ``'bbox_benchmark'`` to ``np.array`` of the
        parsed rows, one row per data line.
    """
    headers = ('model_class', 'bbox_mean', 'bbox_variance', 'bbox_benchmark')
    rows = {'bbox_mean': [], 'bbox_variance': [], 'bbox_benchmark': []}
    model_class = None
    section = None

    with open(info_path, 'r') as info_file:
        for raw_line in info_file:
            header = next((h for h in headers if h in raw_line), None)
            if header is not None:
                section = header
                continue
            line = raw_line.strip()
            if not line:
                # Blank line (also whitespace-only / CRLF) closes the section.
                section = None
                continue
            if section == 'model_class':
                # strip() rather than [:-1]: the last line may lack "\n".
                model_class = line
            elif section is not None:
                rows[section].append(_parse_float_row(line))

    bbox_dict = {key: np.array(value) for key, value in rows.items()}
    return model_class, bbox_dict


def load_model_info(model_dir):
    """Collect model metadata stored next to a pretrained checkpoint.

    Scans ``model_dir`` for a file whose name contains ``model_info.txt``
    (parsed by ``_parse_model_info_file``) and a ``.mat`` file whose
    ``'masks'`` variable holds the part masks. If several files match, the
    last one listed by ``os.listdir`` wins.

    Args:
        model_dir: directory to scan.

    Returns:
        A dict with keys ``'bbox_info'`` (dict of bbox arrays) and
        ``'part_masks'``, plus ``'model_class'`` when the info file has that
        section; or ``None`` if either the info file or the mask file is
        missing.
    """
    info_dict = {}

    for file_name in os.listdir(model_dir):
        file_path = os.path.join(model_dir, file_name)

        if 'model_info.txt' in file_name:
            model_class, bbox_dict = _parse_model_info_file(file_path)
            if model_class is not None:
                info_dict['model_class'] = model_class
            info_dict['bbox_info'] = bbox_dict

        if file_name.endswith('.mat'):
            info_dict['part_masks'] = io.loadmat(file_path)['masks']

    # .get(): a missing file must yield None, not a KeyError.
    if info_dict.get('bbox_info') is not None and info_dict.get('part_masks') is not None:
        return info_dict

    return None
def _max_part_size(shape_name):
    """Return the maximum part count for a known shape class, or 0 if unknown."""
    max_part_sizes = {
        'motorbike': 5,
        'chair': 5,
        'airplane': 6,
        'guitar': 3,
        'lamp': 4,
        'toy_examples': 2,
    }
    return max_part_sizes.get(shape_name, 0)


def _num_test_iterations(sample_size, batch_size):
    """Return the number of batches needed to produce at least sample_size shapes."""
    # float() forces true division: Python 2 integer division would floor
    # before ceil() runs, dropping the final partial batch.
    return int(math.ceil(float(sample_size) / batch_size))


def test_net(sess, config_dict):
    """Restore a pretrained SAGNet and write synthesized 3D shapes to disk.

    Args:
        sess: TensorFlow session used for variable initialization, checkpoint
            restore and inference.
        config_dict: parsed configuration. Reads ``TEST.RESULTS_DIRECTORY``,
            ``TEST.PRETRAINED_MODEL_PATH``, ``TEST.SAMPLE_SIZE`` and
            ``TRAIN.BATCH_SIZE``. Mutated in place: ``'model_info'`` and
            ``'MAX_PART_SIZE'`` are added before the network is built.

    Raises:
        KeyError: if the checkpoint directory lacks the model info file or the
            part-mask ``.mat`` file.
    """
    output_dir = config_dict['TEST']['RESULTS_DIRECTORY']
    checkpoint_path = config_dict['TEST']['PRETRAINED_MODEL_PATH']

    # Model info and part masks are expected next to the checkpoint file.
    checkpoint_dir = os.path.abspath(os.path.dirname(checkpoint_path))
    info_dict = load_model_info(checkpoint_dir)
    if info_dict is None:
        raise KeyError("Can not find model info or part mask files in pretrained model directory!!!")
    shape_name = info_dict['model_class']
    config_dict['model_info'] = info_dict
    # Unknown shape classes fall back to 0.
    config_dict['MAX_PART_SIZE'] = _max_part_size(shape_name)

    # Construct and initialize a data_runner.
    data_helper = data_runner(config_dict, for_training=False)
    inputs = data_helper.get_placeholders()
    inputs['config_dict'] = config_dict

    # Named sag_net so it does not shadow this function's own name.
    sag_net = SAGNet(inputs, data_helper=data_helper, for_training=False)
    print("setup SAGNet.....................................................")
    sag_net.setup()

    # Each run writes into its own timestamped sub-directory.
    cur_time = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
    cur_dir_name = "%s_%s" % (shape_name, cur_time)
    results_directory = os.path.join(output_dir, cur_dir_name)

    saver = tf.train.Saver(tf.trainable_variables())
    sess.run(tf.global_variables_initializer())

    # Restore the trained model over the freshly initialized variables.
    if checkpoint_path != '' and checkpoint_path.endswith('.ckpt'):
        print("Restore model from checkpoints...")
        saver.restore(sess, checkpoint_path)
        print("Restore done.")
    else:
        print("WARNING: checkpoint path %r does not end with '.ckpt'; "
              "generating with freshly initialized (untrained) weights." % checkpoint_path)

    n_epochs = _num_test_iterations(config_dict['TEST']['SAMPLE_SIZE'],
                                    config_dict['TRAIN']['BATCH_SIZE'])

    print("Starting to synthesize 3D shapes.")

    for cur_iter in range(n_epochs):
        f_dict = data_helper.get_inputs_for_testing(cur_iter=cur_iter)
        voxels_list, bboxs_list, part_visible_masks = sag_net.test(sess, f_dict)
        data_helper.write_output_to_file(voxels_list=voxels_list, bboxs_list=bboxs_list,
                                         part_visible_masks=part_visible_masks, input_info_dict=f_dict,
                                         output_dir=results_directory, iter_n=cur_iter)
        print("[%6d/%6d] %s Generation Done" % (int(cur_iter), int(n_epochs), shape_name))
if __name__ == "__main__":
    # allow_soft_placement lets TensorFlow place an op on another available
    # device when the requested one cannot run it.
    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as sess:
        test_net(sess=sess, config_dict=c_dict)