You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

1093 lines
58 KiB

from network import Network
from data_helper import *
import tensorflow as tf
#config = tf.ConfigProto(allow_soft_placement=True)
# saver = tf.train.saver()
# with tf.Session(config=config) as sess:
# saver.restore(sess, "model-xxxx")
class SAGNet(Network):
def __init__(self, data, data_helper, for_training):
    """Store configuration and input tensors for the SAGNet graph.

    Args:
        data: dict produced by the data pipeline; must contain
            'config_dict' plus the training inputs (voxels, bboxes,
            masks, learning rates, ...) or the testing inputs
            ('latent_codes', 'part_visible_masks').
        data_helper: data-pipeline helper object (kept for later use).
        for_training: bool; True prepares training attributes,
            False prepares sampling/testing attributes.
    """
    self.inputs = []
    self.data = data
    self.data_helper = data_helper
    self.for_training = for_training
    self.rnn_cell_depth = data['config_dict']['RNN_CELL_DEPTH']  # 2D
    self.rnn_state_dim = data['config_dict']['RNN_STATE_DIM']  # 512D
    self.leaky_value = data['config_dict']['LEAK_VALUE']
    self.max_part_size = data['config_dict']['MAX_PART_SIZE']
    # NOTE: the config key really is spelled 'VECOTR' — do not "fix" it here.
    self.embedding_size = data['config_dict']['EMBEDDING_VECOTR_SIZE']  # 320D
    self.bbox_size = data['config_dict']['BOUNDING_BOX_SIZE']  # 6D
    self.vert_rnn_max_time_step = self.max_part_size  # k
    # BUGFIX: use floor division so the edge time-step count stays an int.
    # Under Python 3, '/' returns a float, which breaks the tf.reshape /
    # tf.split calls that consume this value as a dimension.
    self.edge_rnn_max_time_step = self.max_part_size * (self.max_part_size - 1) // 2  # K = (k - 1) x k / 2
    self.layers = {}
    if self.for_training:
        self.gpu_num = data['config_dict']['TRAIN']['NUM_GPUS']
        self.batch_size = data['config_dict']['TRAIN']['BATCH_SIZE']
        self.part_voxels = data['part_voxels']
        self.part_bboxs = data['part_bbox']
        self.part_visible_masks = data['part_visible_masks']
        self.gaussian_noise = data['gaussian_noise']
        # number of refine iterations
        self.n_iter = data['config_dict']['TRAIN']['EXCHANGE_NUM']
        # The same as variables mentioned above, two dimensional vector
        self.edge_pair_mask_inds = data['rel_pair_mask_inds']
        self.vert_lr = data['vert_lr']
        self.edge_lr = data['edge_lr']
        self.graph_gen_lr = data['graph_gen_lr']
        self.recon_gen_loss_ratio = data['recon_gen_loss_ratio']
        self.voxel_bbox_ratio = data['voxel_bbox_ratio']
        self.g_rec_kl_loss_ratio = data['g_rec_kl_loss_ratio']
        self.max_gradient_norm = data['max_gradient_norm']
        self.voxel_loss_weights = data['part_voxel_loss_weights']
        self.part_bbox_loss_masks = data['part_bbox_loss_masks']
        self.optimizer = data['config_dict']['TRAIN']['OPTIMIZER_TYPE']  # Optimizer type
        self.keep_prob = tf.placeholder(tf.float32)
    else:
        self.gpu_num = 1
        # Testing reuses the training batch size from the config.
        self.batch_size = data['config_dict']['TRAIN']['BATCH_SIZE']
        self.latent_z = data['latent_codes']
        self.part_visible_masks = data['part_visible_masks']
##############################################################################
# Functions to setup network
##############################################################################
def setup(self):
    """Build the graph appropriate for the current mode (train vs. test)."""
    builder = self.setup_for_training if self.for_training else self.setup_for_testing
    builder()
def setup_for_training(self):
    """Build the multi-GPU training graph.

    Per tower: encode voxels and bboxes, run message passing, learn the
    2-way VAE representation, decode outputs, and compute KL,
    reconstruction and total losses plus clipped gradients. After the
    tower loop: average losses/gradients across towers, add summaries,
    and group the three optimizers' apply ops into self.train_op.
    """
    self._setup_rnns_for_training()
    self._setup_optimizer()
    # NOTE(review): these 'global' declarations export graph tensors as
    # module-level names; every name is assigned below before use, so the
    # declarations look like leftovers — confirm nothing else in the file
    # reads them at module scope before removing.
    global voxel_output, bbox_output, g_loss
    global graph_gen_grads_and_vars, g_part_mse_loss
    # Per-tower accumulators (one entry per GPU).
    vert_grad_var_list = []
    edge_grad_var_list = []
    graph_gen_grad_var_list = []
    vert_loss_list = []
    edge_loss_list = []
    graph_rec_loss_list = []
    graph_kl_loss_list = []
    total_loss_list = []
    for gpu_id in range(int(self.gpu_num)):
        cur_dev_str = '/gpu:%d' % gpu_id
        # Reset the message-passing RNN states for every tower.
        self.vert_multi_cell_state = self.vert_multi_cell.zero_state(self.batch_size, tf.float32)
        self.edge_multi_cell_state = self.edge_multi_cell.zero_state(self.batch_size, tf.float32)
        with tf.device(cur_dev_str):
            # Towers after the first reuse the variables created by tower 0.
            with tf.variable_scope(tf.get_variable_scope(), reuse=(gpu_id > 0)):
                voxels_vector = self._voxel_encoder(self.part_voxels[gpu_id])  # -1*2*2*2*512
                bboxs_vector = self._bbox_encoder(self.part_bboxs[gpu_id],
                                                  edge_pair_mask=self.edge_pair_mask_inds[gpu_id])  # -1* 400
                # Iterative message passing between vert and edge features.
                vert_factor, edge_factor = self._iterate(voxels_vector, bboxs_vector,
                                                         edge_pair_mask=self.edge_pair_mask_inds[gpu_id])
                vert_out, edge_out, mu, log_sigma, g_part_mse_loss = self._learn_representation_for_graph(
                    vert_factor, edge_factor, part_visible_masks=self.part_visible_masks[gpu_id],
                    gaussian_noise=self.gaussian_noise[gpu_id], reuse=(gpu_id > 0))
                voxel_output, bbox_output = self._pred_output(vert_out, edge_out)
                # KL loss
                g_kl_loss = self._final_graph_kl_loss(mu=mu, log_sigma=log_sigma)
                g_kl_loss = tf.expand_dims(g_kl_loss, 0)
                graph_kl_loss_list.append(g_kl_loss)
                # The reconstruction loss for the graph representation
                g_mse_loss = self._final_graph_reconstruction_loss(g_vert_in=vert_factor,
                                                                   g_edge_in=edge_factor,
                                                                   g_vert_out=vert_out,
                                                                   g_edge_out=edge_out)
                g_mse_loss = g_mse_loss + g_part_mse_loss
                g_mse_loss = tf.expand_dims(g_mse_loss, 0)
                graph_rec_loss_list.append(g_mse_loss)
                # Total loss for graph representation
                g_loss = [g_mse_loss * self.g_rec_kl_loss_ratio,
                          g_kl_loss * (1 - self.g_rec_kl_loss_ratio)]
                g_loss = tf.add_n(g_loss, name='graph_loss')
                # Reconstruction loss for the voxel maps and bounding boxes
                cur_voxel_loss = self._final_voxels_loss(self.part_voxels[gpu_id], voxel_output,
                                                         voxel_loss_weight=self.voxel_loss_weights,
                                                         voxel_loss_mask=self.part_visible_masks[gpu_id])
                # self.bbox_input is the paired-box tensor cached by _bbox_encoder.
                cur_bbox_loss = self._final_bboxs_loss(self.bbox_input, bbox_output,
                                                       bbox_loss_mask=self.part_bbox_loss_masks[gpu_id])
                cur_voxel_loss = tf.expand_dims(cur_voxel_loss, 0)
                cur_bbox_loss = tf.expand_dims(cur_bbox_loss, 0)
                vert_loss_list.append(cur_voxel_loss)
                edge_loss_list.append(cur_bbox_loss)
                rec_losses = [cur_voxel_loss * self.voxel_bbox_ratio, cur_bbox_loss * (1 - self.voxel_bbox_ratio)]
                rec_losses = tf.add_n(rec_losses, name='reconstruction_loss')
                # Total losses for the whole framework
                cur_total_loss = rec_losses * self.recon_gen_loss_ratio + \
                                 g_loss * (1 - self.recon_gen_loss_ratio)
                # Compute the gradients for the three modules in our framework
                vert_var_list, edge_var_list, graph_gen_list = self.merge_variable_list()
                grads = tf.gradients(cur_total_loss, vert_var_list + edge_var_list + graph_gen_list)
                grads, _ = tf.clip_by_global_norm(grads, self.max_gradient_norm)
                # Split the flat gradient list back into the three modules.
                vert_grad = grads[:len(vert_var_list)]
                edge_grad = grads[len(vert_var_list): len(vert_var_list) + len(edge_var_list)]
                graph_gen_grad = grads[len(vert_var_list) + len(edge_var_list):]
                graph_gen_grad_var_list.append((graph_gen_grad, graph_gen_list))
                vert_grad_var_list.append((vert_grad, vert_var_list))
                edge_grad_var_list.append((edge_grad, edge_var_list))
                total_loss_list.append(cur_total_loss)
    # Average the per-tower losses for reporting.
    vert_loss = tf.concat(values=vert_loss_list, axis=0)
    self.voxel_loss = tf.reduce_mean(vert_loss, name='voxel_loss')
    edge_loss = tf.concat(values=edge_loss_list, axis=0)
    self.bbox_loss = tf.reduce_mean(edge_loss, name='bbox_loss')
    graph_rec_loss = tf.concat(values=graph_rec_loss_list, axis=0)
    self.g_mse_loss = tf.reduce_mean(graph_rec_loss, name='graph_mse_loss')
    graph_kl_loss = tf.concat(values=graph_kl_loss_list, axis=0)
    self.g_kl_loss = tf.reduce_mean(graph_kl_loss, name='graph_kl_loss')
    # Using three different optimizers to process three modules in our framework
    graph_gen_grads_and_vars = self._average_gradients(graph_gen_grad_var_list)
    tf.contrib.training.add_gradients_summaries(graph_gen_grads_and_vars)
    tf.summary.scalar('graph_kl_loss', self.g_kl_loss)
    tf.summary.scalar('graph_rec_loss', self.g_mse_loss)
    total_loss = tf.concat(values=total_loss_list, axis=0)
    self.total_losses = tf.reduce_mean(total_loss, name='total_loss')
    # Average the gradients for the geometry and structure information
    vert_grads_and_vars = self._average_gradients(vert_grad_var_list)
    edge_grads_and_vars = self._average_gradients(edge_grad_var_list)
    tf.summary.scalar('bbox_loss_', self.bbox_loss)
    tf.summary.scalar('voxel_loss_', self.voxel_loss)
    tf.contrib.training.add_gradients_summaries(vert_grads_and_vars)
    tf.contrib.training.add_gradients_summaries(edge_grads_and_vars)
    # Apply the gradients for the three optimizers
    vert_op = self.vert_opt.apply_gradients(vert_grads_and_vars)
    edge_op = self.edge_opt.apply_gradients(edge_grads_and_vars)
    graph_gen_op = self.graph_gen_opt.apply_gradients(graph_gen_grads_and_vars)
    self.train_op = tf.group(vert_op, edge_op, graph_gen_op, name='train_op')
    self.summary_op = tf.summary.merge_all()
def setup_for_testing(self):
    """Build the sampling graph: decode latent codes into parts and outputs."""
    self._setup_rnns_for_testing()
    with tf.device('/gpu:0'):
        with tf.variable_scope(tf.get_variable_scope()):
            with tf.variable_scope('graph_embedding_layer') as scope:
                # Condition the latent code on which parts are present.
                visible = tf.cast(self.part_visible_masks[self.gpu_num - 1], tf.float32)
                model_embedding = tf.concat(values=[self.latent_z, visible], axis=1)
                # Decode the object embedding into (vert, edge) part streams.
                part_streams = self._obj_gen_decoder_rnn_forward(model_embedding)
                decoded_verts = self._vert_gen_decoder_rnn_forward(part_streams[0])
                decoded_edges = self._edge_gen_decoder_rnn_forward(part_streams[1])
                # Registers voxel/bbox predictions as graph outputs.
                _ = self._pred_output(decoded_verts, decoded_edges)
def merge_variable_list(self):
    """Partition all trainable variables into the three framework modules.

    Returns:
        (vert_list, edge_list, graph_gen_list): variables for the geometry
        (voxel/vert) module, the structure (bbox/edge) module, and the
        2-way VAE (graph) module. Matching is by name substring; the graph
        patterns take precedence, and variables matching no pattern are
        ignored.
    """
    vert_list = []
    edge_list = []
    graph_gen_list = []
    for cur_var in tf.trainable_variables():
        var_name = cur_var.name
        # IDIOM: substring membership replaces str.find(...) != -1
        # (identical semantics, clearer intent).
        if 'Graph' in var_name or 'graph' in var_name:
            graph_gen_list.append(cur_var)
        elif 'voxel' in var_name or 'vert' in var_name or 'Voxel' in var_name:
            vert_list.append(cur_var)
        elif 'bbox' in var_name or 'edge' in var_name or 'BBox' in var_name:
            edge_list.append(cur_var)
    return vert_list, edge_list, graph_gen_list
def _setup_optimizer(self):
    """Setup the optimizers for different modules in our framework.

    The vert_opt is for the geometry information, the edge_opt is for the
    structure information, and the graph_gen_opt is for the 2-way VAE.
    All three use the same optimizer type (config OPTIMIZER_TYPE) with
    their own learning rates; unknown types fall back to plain SGD.
    """
    momentum_value = self.data['config_dict']['TRAIN']['MOMENTUM_VALUE']
    # Hoist the repeated .lower() call and replace the 8-branch if/elif
    # chain with a dispatch table mapping type -> (lr -> optimizer).
    opt_type = self.optimizer.lower()
    simple_factories = {
        'adadelta': tf.train.AdadeltaOptimizer,
        'adam': tf.train.AdamOptimizer,
        'rmsprop': tf.train.RMSPropOptimizer,
        'adagrad': tf.train.AdagradOptimizer,
        'adagradda': tf.train.AdagradDAOptimizer,
    }
    if opt_type == 'momentum':
        def make_opt(lr):
            return tf.train.MomentumOptimizer(learning_rate=lr, momentum=momentum_value, use_nesterov=False)
    elif opt_type == 'nesterov':
        def make_opt(lr):
            return tf.train.MomentumOptimizer(learning_rate=lr, momentum=momentum_value, use_nesterov=True)
    else:
        # Default mirrors the original trailing else: GradientDescent.
        opt_cls = simple_factories.get(opt_type, tf.train.GradientDescentOptimizer)

        def make_opt(lr):
            return opt_cls(learning_rate=lr)
    self.vert_opt = make_opt(self.vert_lr)
    self.edge_opt = make_opt(self.edge_lr)
    self.graph_gen_opt = make_opt(self.graph_gen_lr)
    tf.summary.scalar('vert_learning_rate_', self.vert_lr)
    tf.summary.scalar('edge_learning_rate_', self.edge_lr)
    tf.summary.scalar('graph_gen_learning_rate', self.graph_gen_lr)
def _setup_rnns_for_testing(self):
    """Create the decoder GRU stacks (object, vert, edge) used at sampling time."""
    # build rnn for decode the whole object representation
    obj_decode_gru_cell = tf.contrib.rnn.GRUCell(self.rnn_state_dim, activation=tf.tanh)
    # NOTE(review): [cell] * depth puts the SAME cell object in every layer
    # of MultiRNNCell, so the stacked layers share one set of weights in
    # TF 1.x — confirm this sharing is intended; changing it would alter
    # checkpoint variable names.
    self.obj_decode_multi_cell = tf.contrib.rnn.MultiRNNCell([obj_decode_gru_cell] * self.rnn_cell_depth)
    self.obj_decode_multi_cell_state = self.obj_decode_multi_cell.zero_state(self.batch_size, tf.float32)
    # build the rnn for decode the vert and edge information
    vert_decode_gru_cell = tf.contrib.rnn.GRUCell(self.rnn_state_dim, activation=tf.tanh)
    edge_decode_gru_cell = tf.contrib.rnn.GRUCell(self.rnn_state_dim, activation=tf.tanh)
    self.vert_decode_multi_cell = tf.contrib.rnn.MultiRNNCell([vert_decode_gru_cell] * self.rnn_cell_depth)
    self.edge_decode_multi_cell = tf.contrib.rnn.MultiRNNCell([edge_decode_gru_cell] * self.rnn_cell_depth)
    self.vert_decode_multi_cell_state = self.vert_decode_multi_cell.zero_state(self.batch_size, tf.float32)
    self.edge_decode_multi_cell_state = self.edge_decode_multi_cell.zero_state(self.batch_size, tf.float32)
def _setup_rnns_for_training(self):
    """Construct RNN cells and states. And build and initialize RNNs for message passing.

    Creates four GRU stacks: message-passing (vert/edge), VAE encoders
    (vert/edge/object) and VAE decoders (vert/edge/object), each with its
    zero initial state for the configured batch size.
    """
    # NOTE(review): throughout this method, [cell] * depth stacks the SAME
    # cell object, so layers of each MultiRNNCell share weights in TF 1.x —
    # confirm intended (fixing it changes checkpoint variable names).
    vert_gru_cell = tf.contrib.rnn.GRUCell(self.rnn_state_dim, activation=tf.tanh)
    edge_gru_cell = tf.contrib.rnn.GRUCell(self.rnn_state_dim, activation=tf.tanh)
    self.vert_multi_cell = tf.contrib.rnn.MultiRNNCell([vert_gru_cell] * self.rnn_cell_depth)
    self.edge_multi_cell = tf.contrib.rnn.MultiRNNCell([edge_gru_cell] * self.rnn_cell_depth)
    self.vert_multi_cell_state = self.vert_multi_cell.zero_state(self.batch_size, tf.float32)
    self.edge_multi_cell_state = self.edge_multi_cell.zero_state(self.batch_size, tf.float32)
    # build the rnns for encode the vert and edge information
    vert_encode_gru_cell = tf.contrib.rnn.GRUCell(self.rnn_state_dim, activation=tf.tanh)
    edge_encode_gru_cell = tf.contrib.rnn.GRUCell(self.rnn_state_dim, activation=tf.tanh)
    self.vert_encode_multi_cell = tf.contrib.rnn.MultiRNNCell([vert_encode_gru_cell] * self.rnn_cell_depth)
    self.edge_encode_multi_cell = tf.contrib.rnn.MultiRNNCell([edge_encode_gru_cell] * self.rnn_cell_depth)
    self.vert_encode_multi_cell_state = self.vert_encode_multi_cell.zero_state(self.batch_size, tf.float32)
    self.edge_encode_multi_cell_state = self.edge_encode_multi_cell.zero_state(self.batch_size, tf.float32)
    # build the rnn for decode the vert and edge information
    vert_decode_gru_cell = tf.contrib.rnn.GRUCell(self.rnn_state_dim, activation=tf.tanh)
    edge_decode_gru_cell = tf.contrib.rnn.GRUCell(self.rnn_state_dim, activation=tf.tanh)
    self.vert_decode_multi_cell = tf.contrib.rnn.MultiRNNCell([vert_decode_gru_cell] * self.rnn_cell_depth)
    self.edge_decode_multi_cell = tf.contrib.rnn.MultiRNNCell([edge_decode_gru_cell] * self.rnn_cell_depth)
    self.vert_decode_multi_cell_state = self.vert_decode_multi_cell.zero_state(self.batch_size, tf.float32)
    self.edge_decode_multi_cell_state = self.edge_decode_multi_cell.zero_state(self.batch_size, tf.float32)
    # build rnn for encode/decode the whole object representation
    obj_encode_gru_cell = tf.contrib.rnn.GRUCell(self.rnn_state_dim, activation=tf.tanh)
    obj_decode_gru_cell = tf.contrib.rnn.GRUCell(self.rnn_state_dim, activation=tf.tanh)
    self.obj_encode_multi_cell = tf.contrib.rnn.MultiRNNCell([obj_encode_gru_cell] * self.rnn_cell_depth)
    self.obj_decode_multi_cell = tf.contrib.rnn.MultiRNNCell([obj_decode_gru_cell] * self.rnn_cell_depth)
    self.obj_encode_multi_cell_state = self.obj_encode_multi_cell.zero_state(self.batch_size, tf.float32)
    self.obj_decode_multi_cell_state = self.obj_decode_multi_cell.zero_state(self.batch_size, tf.float32)
def _average_gradients(self, tower_grads):
    """Average the gradients for multi-GPU training.

    Args:
        tower_grads: list with one (grads_list, vars_list) pair per tower;
            every tower's lists are aligned with tower 0's variable list.

    Returns:
        List of (averaged_gradient, variable) pairs; variables for which
        any tower produced a None gradient are skipped.
    """
    # IDIOM: zip(*...) transposes tower-major gradient lists into
    # variable-major tuples, replacing the manual two-loop transpose.
    per_var_grads = zip(*(grads for grads, _ in tower_grads))
    grad_and_var = []
    # 'var' (not the builtin-shadowing 'vars') pairs with tower 0's vars.
    for tower_grad_group, var in zip(per_var_grads, tower_grads[0][1]):
        tower_grad_group = list(tower_grad_group)
        if any(g is None for g in tower_grad_group):
            continue
        # Stacking the per-tower tensors and reducing over axis 0 averages
        # the same-shaped gradients across towers.
        avg_grad = tf.reduce_mean(tower_grad_group, axis=0)
        grad_and_var.append((avg_grad, var))
    return grad_and_var
##############################################################################
# Functions in the encoder module
##############################################################################
def _iterate(self, voxel_vector, bbox_vector, edge_pair_mask=None):
    """
    iterate for passing global relationship between parts and merge the
    geometry & structure features iteratively.

    Projects both feature kinds to the RNN state size (unary terms), then
    alternates attention-based context pooling and GRU updates for
    self.n_iter rounds. Returns (vert_factor, edge_factor) and caches
    them for the graph reconstruction loss.
    """
    with tf.variable_scope('voxel_unary') as scope:
        (self.feed(voxel_vector)
         .fc(self.rnn_state_dim, leaky_value=self.leaky_value, relu=False, name='vert_unary_fc')
         .batch_norm(name='vert_unary', relu=False))
    with tf.variable_scope('bbox_unary') as scope:
        (self.feed(bbox_vector)
         .fc(self.rnn_state_dim, leaky_value=self.leaky_value, relu=False, name='edge_unary_fc')
         .batch_norm(name='edge_unary', relu=False))
    vert_unary = self.get_output('vert_unary')  # -1*512
    edge_unary = self.get_output('edge_unary')
    # we obtain the new states of the gru cells
    # (the former 'global vert_factor, edge_factor' declaration was
    # removed: both names are assigned right here before any use)
    vert_factor = self._vert_rnn_forward(vert_unary, reuse=False)  # 10*5*512
    edge_factor = self._edge_rnn_forward(edge_unary, reuse=False)  # 10*10*512
    # BUGFIX: xrange is Python-2 only; range works on both 2 and 3.
    for i in range(self.n_iter):
        reuse = i > 0
        # compute vert states
        vert_ctx = self._compute_vert_context(edge_factor, vert_factor, reuse=reuse, edge_pair_mask=edge_pair_mask)
        vert_ctx = tf.reshape(vert_ctx, [-1, self.vert_rnn_max_time_step, self.rnn_state_dim])
        vert_factor = self._vert_rnn_forward(vert_ctx, reuse=True)
        # compute edge states
        edge_ctx = self._compute_edge_context(vert_factor, edge_factor, reuse=reuse, edge_pair_mask=edge_pair_mask)  # 100*512
        edge_ctx = tf.reshape(edge_ctx, [-1, self.edge_rnn_max_time_step, self.rnn_state_dim])  # 10*10*512
        edge_factor = self._edge_rnn_forward(edge_ctx, reuse=True)  # 10*10*512
    # These two features are used to compute the reconstruction loss
    self.vert_gen_encoder_input = vert_factor
    self.edge_gen_encoder_input = edge_factor
    return vert_factor, edge_factor
def _voxel_encoder(self, input_voxels, reuse=False):
    """Encoder for voxel maps.

    ResNet-style 3-D CNN: a stem conv, three residual stages
    (64 -> 128 -> 256 channels), and a final conv to 512 channels.
    Returns the last conv feature map.
    """
    layer_name = 'voxel_encoder'
    voxel_size = self.data['config_dict']['CUBE_LEN']
    with tf.variable_scope(layer_name) as scope:
        if reuse: tf.get_variable_scope().reuse_variables()
        # Single-channel 3-D grid: (-1, D, H, W, 1).
        input_voxels = tf.reshape(input_voxels, (-1, voxel_size, voxel_size, voxel_size, 1))
        init_weights = tf.contrib.layers.xavier_initializer()
        init_biases = tf.zeros_initializer()
        valid_strides = [1, 2, 2, 2, 1]  # stride 2 in each spatial dim
        same_strides = [1, 1, 1, 1, 1]
        ve_conv_1_out = self.conv3d(input_voxels, 3, 64, 've_conv_1_1', init_weights, strides=valid_strides, init_biases=init_biases,
                                    leaky_value=self.leaky_value, relu=True, batch_norm=True, padding='SAME')
        # first resnet blocks
        ve_conv_block_2_1_out = self.conv_residual_block(ve_conv_1_out, 3, 64, init_weights,
                                                         've_conv_block_2_1', leaky_value=self.leaky_value, padding='SAME', bottle_neck=True)
        ve_conv_block_2_2_out = self.conv_residual_block(ve_conv_block_2_1_out, 3, 64, init_weights,
                                                         've_conv_block_2_2', leaky_value=self.leaky_value, padding='SAME', bottle_neck=False)
        # second resnet blocks
        ve_conv_block_3_1_out = self.conv_residual_block(ve_conv_block_2_2_out, 3, 128, init_weights,
                                                         've_conv_block_3_1', self.leaky_value, padding='SAME', bottle_neck=True)
        ve_conv_block_3_2_out = self.conv_residual_block(ve_conv_block_3_1_out, 3, 128, init_weights,
                                                         've_conv_block_3_2', self.leaky_value, padding='SAME', bottle_neck=False)
        # third resnet blocks
        ve_conv_block_4_1_out = self.conv_residual_block(ve_conv_block_3_2_out, 3, 256, init_weights,
                                                         've_conv_block_4_1', self.leaky_value, padding='SAME', bottle_neck=True)
        ve_conv_block_4_2_out = self.conv_residual_block(ve_conv_block_4_1_out, 3, 256, init_weights,
                                                         've_conv_block_4_2', self.leaky_value, padding='SAME', bottle_neck=False)
        ve_conv_block_4_3_out = self.conv_residual_block(ve_conv_block_4_2_out, 3, 256, init_weights,
                                                         've_conv_block_4_3', self.leaky_value, padding='SAME', bottle_neck=False)
        # Final 2x2x2 conv to 512 channels; comments at the call sites
        # suggest the result is -1*2*2*2*512 — TODO confirm spatial size.
        ve_conv_out = self.conv3d(ve_conv_block_4_3_out, 2, 512, 've_conv_out', init_weights, strides=same_strides, init_biases=init_biases,
                                  leaky_value=self.leaky_value, relu=True, batch_norm=True, padding='SAME')
        return ve_conv_out
def _voxel_decoder(self, input_voxel_features, reuse=False):
    """decode and predict the value of each voxel.

    Mirrors the encoder with deconv residual blocks, upsampling
    1 -> 2 -> 4 -> 8 -> 16 -> 32 per side (see output_shape args), and
    ends with a sigmoid so each voxel is an occupancy in [0, 1].
    """
    layer_name = 'voxel_decoder'
    with tf.variable_scope(layer_name) as scope:
        if reuse: tf.get_variable_scope().reuse_variables()
        init_weights = tf.contrib.layers.xavier_initializer()
        init_biases = tf.zeros_initializer()
        valid_strides = [1, 2, 2, 2, 1]
        same_strides = [1, 1, 1, 1, 1]
        # Dynamic batch size so the deconv output shapes stay valid.
        batch_size = tf.shape(input_voxel_features)[0]
        z = tf.reshape(input_voxel_features, (-1, 1, 1, 1, self.embedding_size))
        vd_deconv_1_1_out = self.deconv3d(z, 2, 512, 'vd_deconv_1_1', (batch_size, 2, 2, 2, 512), init_weights,
                                          same_strides, init_biases, leaky_value=self.leaky_value, relu=True,
                                          batch_norm=True, padding='VALID')
        # first deconv block
        vd_deconv_2_1_out = self.deconv_residual_block(vd_deconv_1_1_out, 3, 256, (batch_size, 4, 4, 4, 256),
                                                       init_weights, 'vd_deconv_2_1_block', self.leaky_value,
                                                       padding='SAME', bottle_neck=True)
        vd_deconv_2_2_out = self.deconv_residual_block(vd_deconv_2_1_out, 3, 256, (batch_size, 4, 4, 4, 256),
                                                       init_weights, 'vd_deconv_2_2_block', self.leaky_value,
                                                       padding='SAME', bottle_neck=False)
        vd_deconv_2_3_out = self.deconv_residual_block(vd_deconv_2_2_out, 3, 256, (batch_size, 4, 4, 4, 256),
                                                       init_weights, 'vd_deconv_2_3_block', self.leaky_value,
                                                       padding='SAME', bottle_neck=False)
        # second deconv block
        vd_deconv_3_1_out = self.deconv_residual_block(vd_deconv_2_3_out, 3, 128, (batch_size, 8, 8, 8, 128),
                                                       init_weights, 'vd_deconv_3_1_block', self.leaky_value,
                                                       padding='SAME', bottle_neck=True)
        vd_deconv_3_2_out = self.deconv_residual_block(vd_deconv_3_1_out, 3, 128, (batch_size, 8, 8, 8, 128),
                                                       init_weights, 'vd_deconv_3_2_block', self.leaky_value,
                                                       padding='SAME', bottle_neck=False)
        # third deconv block
        vd_deconv_4_1_out = self.deconv_residual_block(vd_deconv_3_2_out, 3, 64, (batch_size, 16, 16, 16, 64),
                                                       init_weights, 'vd_deconv_4_1_block', self.leaky_value,
                                                       padding='SAME', bottle_neck=True)
        vd_deconv_4_2_out = self.deconv_residual_block(vd_deconv_4_1_out, 3, 64, (batch_size, 16, 16, 16, 64),
                                                       init_weights, 'vd_deconv_4_2_block', self.leaky_value,
                                                       padding='SAME', bottle_neck=False)
        # Last deconv to a single channel; no batch-norm/relu before the sigmoid.
        vd_deconv_out = self.deconv3d(vd_deconv_4_2_out, 3, 1, 'vd_deconv_5_1', (batch_size, 32, 32, 32, 1),
                                      init_weights, valid_strides, init_biases, leaky_value=self.leaky_value, relu=False,
                                      batch_norm=False, padding='SAME')
        deconv_feature = tf.nn.sigmoid(vd_deconv_out, name='voxels_out')
        # Register the output under 'voxel_decoder' in self.layers.
        self.wrap(deconv_feature, layer_name)
        return deconv_feature
def _bbox_encoder(self, input_bb_feature, reuse=False, edge_pair_mask=None):
    """Encoder for bounding boxes.

    For every edge in edge_pair_mask the two endpoint boxes are gathered,
    concatenated, and projected by fc + batch-norm + leaky-relu. The
    paired tensor is cached on self for the bbox reconstruction loss.
    """
    with tf.variable_scope('bbox_encoder') as scope:
        flat_boxes = tf.reshape(input_bb_feature, [-1, self.bbox_size])
        # we get the context relationship of pairwise vertex of the
        # bounding box of a graph
        head_boxes = tf.gather(flat_boxes, edge_pair_mask[:, 0])
        tail_boxes = tf.gather(flat_boxes, edge_pair_mask[:, 1])
        paired_boxes = tf.concat(values=[head_boxes, tail_boxes], axis=1, name="bbox_concat_input")
        self.bbox_input = paired_boxes
        self.bbox_origin_input = flat_boxes
        # Same fc -> bn -> lrelu chain, written as sequential calls.
        net = self.feed(paired_boxes)
        net = net.fc(self.embedding_size * 2, leaky_value=self.leaky_value,
                     relu=False, name='bbox_encoder_layer_fc', reuse=reuse)
        net = net.batch_norm(name='bbox_encoder_layer_bn', relu=False)
        net.lrelu(leaky_value=self.leaky_value, name='bbox_encoder_layer_out')
        return self.get_output('bbox_encoder_layer_out')
def _decode_bbox(self, input_bbox_feature, reuse=False):
    """Decode edge features into bounding-box predictions (2 boxes per edge)."""
    prediction_layer = 'bbox_pred'
    with tf.variable_scope('bbox_decoder') as scope:
        # Two fc stages; the final batch-norm layer carries the prediction name.
        net = self.feed(input_bbox_feature)
        net = net.fc(self.embedding_size * 2, leaky_value=self.leaky_value,
                     relu=False, name='bbox_decoder_layer_1_fc', reuse=reuse)
        net = net.batch_norm(name='bbox_decoder_layer_1_bn', relu=False, reuse=reuse)
        net = net.lrelu(leaky_value=self.leaky_value, name='bbox_decoder_layer_1_out')
        net = net.fc(self.bbox_size * 2, leaky_value=self.leaky_value,
                     relu=False, name='bbox_decoder_layer_2_fc', reuse=reuse)
        net.batch_norm(name=prediction_layer, relu=False, reuse=reuse)
        self.bbox_pred = self.get_output(prediction_layer)
        return self.bbox_pred
def _decode_part_voxels(self, input_layer, reuse=False):
    """Project vert features to the embedding size, then run the 3-D voxel decoder."""
    layer_name = 'part_voxel_output'
    print(layer_name)
    # Transform the dimension through a fully-connected layer
    with tf.variable_scope('part_voxel_decoder') as scope:
        net = self.feed(input_layer)
        net = net.fc(self.embedding_size, relu=False, leaky_value=self.leaky_value,
                     name='part_voxel_decoder_fc', reuse=reuse)
        net = net.batch_norm(name='part_voxel_decoder_bn', relu=False, reuse=reuse)
        net.lrelu(leaky_value=self.leaky_value, name=layer_name)
        part_embedding = self.get_output(layer_name)
        self.voxel_pred = self._voxel_decoder(part_embedding, reuse=reuse)
        return self.voxel_pred
def _pred_output(self, vert_factor, edge_factor, reuse=False):
    """Predict the outputs for geometry(voxel maps) and structure(bounding boxes)."""
    return (self._decode_part_voxels(vert_factor, reuse=reuse),
            self._decode_bbox(edge_factor, reuse=reuse))
##############################################################################
# Functions to compute context and learn the latent representation through RNNs
##############################################################################
def _compute_edge_context(self, vert_factor, edge_factor, reuse=False, edge_pair_mask=None):
    """
    attention-based edge message pooling

    Each edge receives a weighted sum of its two endpoint vert features,
    with weights computed from [vert, edge] concatenations via a shared
    sigmoid-scored fc layer.
    """
    # Flatten batch/time so rows align with edge_pair_mask indices.
    vert_factor = tf.reshape(vert_factor, [-1, self.rnn_state_dim])  # 50*512
    edge_factor = tf.reshape(edge_factor, [-1, self.rnn_state_dim])  # 100*512
    # Endpoint vert features of every edge.
    vert_in_factor = tf.gather(vert_factor, edge_pair_mask[:, 0])
    vert_out_factor = tf.gather(vert_factor, edge_pair_mask[:, 1])
    vert_w_input_first = tf.concat(values=[vert_in_factor, edge_factor], axis=1)  # 100*1024
    vert_w_input_second = tf.concat(values=[vert_out_factor, edge_factor], axis=1)
    # compute compatibility scores
    (self.feed(vert_w_input_first)
     .fc(1, relu=False, leaky_value=self.leaky_value, reuse=reuse, name='vert_first_w_fc')
     .sigmoid(name='edge_vert_first_score'))
    # reuse=True: the second score deliberately shares the SAME fc weights
    # ('vert_first_w_fc') as the first score.
    (self.feed(vert_w_input_second)
     .fc(1, relu=False, leaky_value=self.leaky_value, reuse=True, name='vert_first_w_fc')
     .sigmoid(name='edge_vert_second_score'))
    vert_first_w = self.get_output('edge_vert_first_score')  # 100*1
    vert_second_w = self.get_output('edge_vert_second_score')
    # Attention-weighted sum of the two endpoint features per edge.
    weighted_first_vert = tf.multiply(vert_in_factor, vert_first_w)
    weighted_second_vert = tf.multiply(vert_out_factor, vert_second_w)
    return weighted_first_vert + weighted_second_vert
def _compute_vert_context(self, edge_factor, vert_factor, reuse=False, edge_pair_mask=None):
    """
    attention-based vertex(node) message pooling
    """
    """the edge_pair_mask_inds[:, 0] store the index of in-bound vertex of an edge
    and the edge_pair_mask_inds[:, 1] store the index of out-bound vertex of an edge"""
    # Flatten batch/time so rows align with edge_pair_mask indices.
    edge_factor = tf.reshape(edge_factor, [-1, self.rnn_state_dim])  # 100*512
    vert_factor = tf.reshape(vert_factor, [-1, self.rnn_state_dim])  # 50*512
    vert_in_factor = tf.gather(vert_factor, edge_pair_mask[:, 0])
    vert_out_factor = tf.gather(vert_factor, edge_pair_mask[:, 1])
    # concat outgoing edges and ingoing edges with gathered vert_factors
    in_edge_w_input = tf.concat(values=[vert_in_factor, edge_factor], axis=1)  # 100*1024
    out_edge_w_input = tf.concat(values=[vert_out_factor, edge_factor], axis=1)
    # compute compatibility scores
    (self.feed(out_edge_w_input)
     .fc(1, relu=False, leaky_value=self.leaky_value, reuse=reuse, name='edge_w_fc')
     .sigmoid(name='out_edge_score'))
    # reuse=True: both scores share the SAME fc weights ('edge_w_fc').
    (self.feed(in_edge_w_input)
     .fc(1, relu=False, leaky_value=self.leaky_value, reuse=True, name='edge_w_fc')
     .sigmoid(name='in_edge_score'))
    out_edge_w = self.get_output('out_edge_score')  # 100*1
    in_edge_w = self.get_output('in_edge_score')
    # weigh the edge factors with computed weigths
    out_edge_weighted = tf.multiply(edge_factor, out_edge_w)  # 100*512
    in_edge_weighted = tf.multiply(edge_factor, in_edge_w)
    out_edge_weighted = tf.reshape(out_edge_weighted, [-1, self.edge_rnn_max_time_step, self.rnn_state_dim])  # 10*10*512
    in_edge_weighted = tf.reshape(in_edge_weighted, [-1, self.edge_rnn_max_time_step, self.rnn_state_dim])
    # One tensor per edge slot of the upper-triangular (i < j) edge list.
    out_edge_weighted_list = tf.split(out_edge_weighted, num_or_size_splits=self.edge_rnn_max_time_step, axis=1)  # 10*1*512
    in_edge_weighted_list = tf.split(in_edge_weighted, num_or_size_splits=self.edge_rnn_max_time_step, axis=1)
    first_index_list = []
    second_index_list = []
    first_tens_list = []
    second_tens_list = []
    # Seed each vertex's in-message accumulator with the FIRST edge whose
    # in-bound endpoint is that vertex; in the row-major upper-triangular
    # layout those seeds sit at offsets 0, k-1, (k-1)+(k-2), ...
    cur_index = 0
    for ind in range(self.max_part_size - 1):
        first_index_list.append(cur_index)
        first_tens_list.append(tf.identity(in_edge_weighted_list[cur_index]))
        cur_index = cur_index + self.max_part_size - 1 - ind
    # Seed each vertex's out-message accumulator with the first row's
    # edges (0,1), (0,2), ..., whose out-bound endpoints are 1..k-1.
    for ind in range(self.max_part_size - 1):
        second_index_list.append(ind)
        second_tens_list.append(tf.identity(out_edge_weighted_list[ind]))
    # Walk every (f_offset, s_offset) edge and add its message to the
    # endpoint accumulators, skipping the indices already used as seeds.
    cur_part_index = 0
    for f_offset in range(self.max_part_size):
        for s_offset in range(f_offset + 1, self.max_part_size):
            if not (cur_part_index in first_index_list):
                first_tens_list[f_offset] = tf.add(first_tens_list[f_offset], in_edge_weighted_list[cur_part_index])
            if not (cur_part_index in second_index_list):
                second_tens_list[s_offset - 1] = tf.add(second_tens_list[s_offset - 1], out_edge_weighted_list[cur_part_index])
            cur_part_index = cur_part_index + 1
    # Kept on self for inspection/debugging.
    self.first_tens_list = first_tens_list
    self.second_tens_list = second_tens_list
    self.out_edge_weighted_list = out_edge_weighted_list
    self.in_edge_weighted_list = in_edge_weighted_list
    # Combine in/out accumulators per vertex; vertex 0 has only in-messages
    # and the last vertex only out-messages.
    final_list = []
    for ind in range(self.max_part_size):
        if ind == 0:
            final_list.append(first_tens_list[ind])
        elif ind == self.max_part_size - 1:
            final_list.append(second_tens_list[ind - 1])
        else:
            final_list.append(first_tens_list[ind] + second_tens_list[ind - 1])
    vert_ctx = tf.concat(values=final_list, axis=0)
    vert_ctx = tf.reshape(vert_ctx, [-1, self.rnn_state_dim])
    return vert_ctx
def _vert_rnn_forward(self, vert_in, reuse=False):
    """Run one vert-GRU pass over the per-part sequence, updating the cached state."""
    with tf.variable_scope('vert_rnn'):
        if reuse:
            tf.get_variable_scope().reuse_variables()
        # Shape into (batch, k, state_dim) for batch-major dynamic_rnn.
        sequence = tf.reshape(vert_in, [-1, self.vert_rnn_max_time_step, self.rnn_state_dim])  # -1 *5 *512
        vert_out, self.vert_multi_cell_state = tf.nn.dynamic_rnn(
            self.vert_multi_cell, sequence,
            initial_state=self.vert_multi_cell_state, time_major=False)  # 10*5*512
        return vert_out
def _edge_rnn_forward(self, edge_in, reuse=False):
    """Run one edge-GRU pass over the per-edge sequence, updating the cached state."""
    with tf.variable_scope('edge_rnn'):
        if reuse:
            tf.get_variable_scope().reuse_variables()
        # Shape into (batch, K, state_dim) for batch-major dynamic_rnn.
        sequence = tf.reshape(edge_in, [-1, self.edge_rnn_max_time_step, self.rnn_state_dim])  # 10*10*512
        edge_out, self.edge_multi_cell_state = tf.nn.dynamic_rnn(
            self.edge_multi_cell, inputs=sequence,
            initial_state=self.edge_multi_cell_state, time_major=False)
        return edge_out
def _vert_gen_encoder_rnn_forward(self, vert_in, reuse=False):
    """Encode the vert feature sequence; return the final top-layer GRU state."""
    with tf.variable_scope('vert_gen_encoder_rnn'):
        if reuse:
            tf.get_variable_scope().reuse_variables()
        sequence = tf.reshape(vert_in, [-1, self.vert_rnn_max_time_step, self.rnn_state_dim])
        # Encoding always starts from a fresh zero state.
        start_state = self.vert_encode_multi_cell.zero_state(self.batch_size, tf.float32)
        _, self.vert_encode_multi_cell_state = tf.nn.dynamic_rnn(
            self.vert_encode_multi_cell, sequence,
            initial_state=start_state, time_major=False)
        # The deepest layer's final state summarizes the sequence.
        return self.vert_encode_multi_cell_state[-1]
def _edge_gen_encoder_rnn_forward(self, edge_in, reuse=False):
    """Encode the edge feature sequence; return the final top-layer GRU state."""
    with tf.variable_scope('edge_gen_encoder_rnn'):
        if reuse:
            tf.get_variable_scope().reuse_variables()
        sequence = tf.reshape(edge_in, [-1, self.edge_rnn_max_time_step, self.rnn_state_dim])
        # Encoding always starts from a fresh zero state.
        start_state = self.edge_encode_multi_cell.zero_state(self.batch_size, tf.float32)
        _, self.edge_encode_multi_cell_state = tf.nn.dynamic_rnn(
            self.edge_encode_multi_cell, sequence,
            initial_state=start_state, time_major=False)
        # The deepest layer's final state summarizes the sequence.
        return self.edge_encode_multi_cell_state[-1]
def _obj_gen_encoder_rnn_forward(self, obj_vert_in, obj_edge_in, reuse=False):
    """Fuse the vertex and edge encodings into one object embedding.

    Concatenates the two branch states, views the result as a sequence of
    state_dim-sized chunks, runs the object encoder RNN from a zero state,
    and returns the top cell's final state.
    """
    with tf.variable_scope('obj_gen_encoder_rnn'):
        if reuse:
            tf.get_variable_scope().reuse_variables()
        fused = tf.concat(values=[obj_vert_in, obj_edge_in], axis=1)
        # re-slice the fused vector into state_dim-wide time steps
        fused = tf.reshape(fused, [self.batch_size, -1, self.rnn_state_dim])
        zero_state = self.obj_encode_multi_cell.zero_state(self.batch_size, tf.float32)
        _, final_state = tf.nn.dynamic_rnn(self.obj_encode_multi_cell, fused,
                                           initial_state=zero_state, time_major=False)
        self.obj_encode_multi_cell_state = final_state
        return final_state[-1]
def _obj_gen_decoder_rnn_forward(self, latent_input, reuse=False):
    """Decode the object latent code into two branch initial states.

    Runs the object decoder cell for exactly two steps and collects the
    top cell's state after each step; the caller uses element 0 as the
    vertex decoder's initial state and element 1 as the edge decoder's.

    Args:
        latent_input: latent code (already concatenated with the part mask).
        reuse: reuse the variables in this scope when True.

    Returns:
        A list of two state tensors, one per decode step.
    """
    with tf.variable_scope('obj_gen_decoder_rnn'):
        if reuse: tf.get_variable_scope().reuse_variables()
        part_out_list = []
        # transform the dimension for initial state
        (self.feed(latent_input)
            .fc(self.rnn_state_dim, relu=False, name='obj_initial_state_dense', reuse=reuse)
            .batch_norm(name='obj_initial_state_bn', relu=False)
            .lrelu(leaky_value=self.leaky_value, name='obj_initial_state_dense_out'))
        obj_decoder_initial_state = self.get_output('obj_initial_state_dense_out')
        # first input: zeros concatenated with the projected initial state
        obj_zero_input = tf.zeros_like(obj_decoder_initial_state)
        obj_initial_input = tf.concat(values=[obj_zero_input, obj_decoder_initial_state], axis=1)
        # transform the dimension for initial input
        (self.feed(obj_initial_input)
            .fc(self.rnn_state_dim, relu=False, name='obj_decoder_input_embedding', reuse=reuse)
            .batch_norm(name='obj_decoder_input_bn', relu=False)
            .lrelu(leaky_value=self.leaky_value, name='obj_decoder_input_out'))
        obj_input = self.get_output('obj_decoder_input_out')
        # replicate the initial state once per stacked cell layer
        obj_state = [obj_decoder_initial_state] * self.rnn_cell_depth
        (part_out, obj_state) = self.obj_decode_multi_cell(obj_input, obj_state)
        # NOTE: what is collected is the top cell's state, not `part_out`
        part_out_list.append(obj_state[-1])
        # second step: previous output + initial state through the same
        # embedding layers (reuse=True shares their weights)
        part_out = tf.concat(values=[part_out, obj_decoder_initial_state], axis=1)
        (self.feed(part_out)
            .fc(self.rnn_state_dim, relu=False, name='obj_decoder_input_embedding', reuse=True)
            .batch_norm(name='obj_decoder_input_bn', relu=False, reuse=True)
            .lrelu(leaky_value=self.leaky_value, name='obj_decoder_input_out'))
        part_out = self.get_output('obj_decoder_input_out')
        (part_out, obj_state) = self.obj_decode_multi_cell(part_out, obj_state)
        part_out_list.append(obj_state[-1])
        return part_out_list
def _vert_gen_decoder_rnn_forward(self, vert_initial_state, reuse=False):
    """Decode a per-part vertex feature sequence from the latent state.

    Runs the vertex decoder cell for `max_part_size` steps; at each step
    the previous output is concatenated with the initial state, pushed
    through a shared embedding (fc + batch-norm + leaky ReLU), and fed
    back into the cell.

    Returns a [batch, k, state_dim] tensor of decoded vertex features.
    """
    with tf.variable_scope('vert_gen_decoder_rnn'):
        if reuse:
            tf.get_variable_scope().reuse_variables()
        vert_out_list = []
        # step 0 input: zeros concatenated with the initial state
        vert_zero_input = tf.zeros_like(vert_initial_state)
        vert_initial_input = tf.concat(values=[vert_zero_input, vert_initial_state], axis=1)
        (self.feed(vert_initial_input)
            .fc(self.rnn_state_dim, relu=False, name='vert_input_embedding', reuse=reuse)
            .batch_norm(name='vert_decoder_input_bn', relu=False)
            .lrelu(leaky_value=self.leaky_value, name='vert_decoder_input_out'))
        vert_input = self.get_output('vert_decoder_input_out')
        # replicate the initial state once per stacked cell layer
        vert_state = [vert_initial_state] * self.rnn_cell_depth
        (vert_out, vert_state) = self.vert_decode_multi_cell(vert_input, vert_state)
        vert_out_list.append(vert_out)
        # Steps 1..k-1 always reuse the embedding layers created above, so
        # the original per-iteration `if ind > 0` flag was redundant.
        for _ in range(1, self.max_part_size):
            vert_out = tf.concat(values=[vert_out, vert_initial_state], axis=1)
            (self.feed(vert_out)
                .fc(self.rnn_state_dim, relu=False, name='vert_input_embedding', reuse=True)
                .batch_norm(name='vert_decoder_input_bn', relu=False, reuse=True)
                .lrelu(leaky_value=self.leaky_value, name='vert_decoder_input_out'))
            vert_input = self.get_output('vert_decoder_input_out')
            (vert_out, vert_state) = self.vert_decode_multi_cell(vert_input, vert_state)
            vert_out_list.append(vert_out)
        # stack per-step outputs along the time axis
        vert_decoder_out = tf.stack(vert_out_list, axis=1)
        return vert_decoder_out
def _edge_gen_decoder_rnn_forward(self, edge_initial_state, reuse=False):
    """Decode a per-edge feature sequence from the latent state.

    Runs the edge decoder cell for K = k*(k-1)/2 steps; at each step the
    previous output is concatenated with the initial state, pushed through
    a shared embedding (fc + batch-norm + leaky ReLU), and fed back into
    the cell.

    Returns a [batch, K, state_dim] tensor of decoded edge features.
    """
    with tf.variable_scope('edge_gen_decoder_rnn'):
        if reuse:
            tf.get_variable_scope().reuse_variables()
        edge_out_list = []
        # step 0 input: zeros concatenated with the initial state
        edge_zero_input = tf.zeros_like(edge_initial_state)
        edge_initial_input = tf.concat(values=[edge_zero_input, edge_initial_state], axis=1)
        (self.feed(edge_initial_input)
            .fc(self.rnn_state_dim, relu=False, name='edge_input_embedding', reuse=reuse)
            .batch_norm(name='edge_decoder_input_bn', relu=False)
            .lrelu(leaky_value=self.leaky_value, name='edge_decoder_input_out'))
        edge_input = self.get_output('edge_decoder_input_out')
        # replicate the initial state once per stacked cell layer
        edge_state = [edge_initial_state] * self.rnn_cell_depth
        (edge_out, edge_state) = self.edge_decode_multi_cell(edge_input, edge_state)
        edge_out_list.append(edge_out)
        # K comes from true division in __init__ and may be a float on
        # Python 3 — cast to int so range() accepts it. Later steps always
        # reuse the embedding layers, so the old `if ind > 0` flag is gone.
        for _ in range(1, int(self.edge_rnn_max_time_step)):
            edge_out = tf.concat(values=[edge_out, edge_initial_state], axis=1)
            (self.feed(edge_out)
                .fc(self.rnn_state_dim, relu=False, name='edge_input_embedding', reuse=True)
                .batch_norm(name='edge_decoder_input_bn', relu=False, reuse=True)
                .lrelu(leaky_value=self.leaky_value, name='edge_decoder_input_out'))
            edge_input = self.get_output('edge_decoder_input_out')
            (edge_out, edge_state) = self.edge_decode_multi_cell(edge_input, edge_state)
            edge_out_list.append(edge_out)
        # stack per-step outputs along the time axis
        edge_decoder_out = tf.stack(edge_out_list, axis=1)
        return edge_decoder_out
def _learn_representation_for_graph(self, input_vert_feature, input_edge_feature, part_visible_masks, gaussian_noise,
                                    layer_suffix='', phase_train=True, reuse=False):
    """Learn a latent space in the 2-way VAE.

    Encodes the vertex/edge message-passing features (each conditioned on
    the part visibility mask) through vertex, edge, and object encoder
    RNNs into a Gaussian latent code, reparameterizes it with
    `gaussian_noise`, and decodes the code back into per-part vertex and
    edge feature sequences.

    Returns:
        (vert_decoder_out, edge_decoder_out, mu, log_sigma, graph_part_loss)

    Side effects: sets self.latent_z, self.vert_decoder_out and
    self.edge_decoder_out.
    """
    layer_name = 'graph_embedding_layer_' + layer_suffix if layer_suffix != '' else 'graph_embedding_layer'
    with tf.variable_scope(layer_name) as scope:
        if reuse: tf.get_variable_scope().reuse_variables()
        p_masks = tf.cast(part_visible_masks, tf.float32)
        # Expand the part masks and then merge them with vert features
        mask_dims = tf.expand_dims(p_masks, axis=1)
        vert_expanded_mask = tf.tile(mask_dims, [1, self.max_part_size, 1])
        vert_expanded_mask = tf.reshape(vert_expanded_mask, [-1, self.max_part_size])
        input_vert_feature = tf.reshape(input_vert_feature, [-1, self.rnn_state_dim])
        input_edge_feature = tf.reshape(input_edge_feature, [-1, self.rnn_state_dim])
        vert_feature = tf.concat(values=[input_vert_feature, vert_expanded_mask], axis=1)
        (self.feed(vert_feature)
            .fc(self.rnn_state_dim, relu=False, name='input_embedding_feature', reuse=reuse, trainable=phase_train)
            .batch_norm(is_training=phase_train, name='vert_feature_batch_norm', relu=False)
            .lrelu(leaky_value=self.leaky_value, name='vert_feature_out'))
        vert_feature_out = self.get_output('vert_feature_out')
        # Expand the part masks and then merge them with edge features.
        # NOTE(review): edge_rnn_max_time_step comes from a true division in
        # __init__ and may be a float on Python 3 — confirm tf.tile accepts
        # it, or cast to int upstream.
        edge_expanded_mask = tf.tile(mask_dims, [1, self.edge_rnn_max_time_step, 1])
        edge_expanded_mask = tf.reshape(edge_expanded_mask, [-1, self.max_part_size])
        edge_feature = tf.concat(values=[input_edge_feature, edge_expanded_mask], axis=1)
        # reuse=True: edges go through the same embedding weights as verts
        (self.feed(edge_feature)
            .fc(self.rnn_state_dim, relu=False, name='input_embedding_feature', reuse=True, trainable=phase_train)
            .batch_norm(is_training=phase_train, name='edge_feature_batch_norm', relu=False)
            .lrelu(leaky_value=self.leaky_value, name='edge_feature_out'))
        edge_feature_out = self.get_output('edge_feature_out')
        # Encode both branches, then fuse them into one object embedding.
        parts_vert_out = self._vert_gen_encoder_rnn_forward(vert_feature_out, reuse=reuse)
        parts_edge_out = self._edge_gen_encoder_rnn_forward(edge_feature_out, reuse=reuse)
        obj_embedding = self._obj_gen_encoder_rnn_forward(parts_vert_out, parts_edge_out, reuse=reuse)
        # obj_embedding = tf.concat(values=[obj_embedding, p_masks], axis=1)
        (self.feed(obj_embedding)
            .fc(256, relu=False, name='graph_vector_fc', reuse=reuse, trainable=phase_train)
            .batch_norm(is_training=phase_train, name='graph_vector_batch_norm', relu=False)
            .lrelu(leaky_value=self.leaky_value, name='graph_vector_out'))
        graph_embedding_vector = self.get_output('graph_vector_out')
        # Two linear heads project the graph vector to the Gaussian params.
        (self.feed(graph_embedding_vector)
            .fc(self.embedding_size, relu=False, name='graph_mu', reuse=reuse, trainable=phase_train)
            .batch_norm(is_training=phase_train, name='graph_mu_out', relu=False))
        (self.feed(graph_embedding_vector)
            .fc(self.embedding_size, relu=False, name='graph_sigma', reuse=reuse, trainable=phase_train)
            .batch_norm(is_training=phase_train, name='graph_sigma_out', relu=False))
        mu = self.get_output('graph_mu_out')
        log_sigma = self.get_output('graph_sigma_out')
        # we make the sigma positive in this way by an exponential operation
        sigma = tf.exp(0.5 * log_sigma)
        # reparameterization trick: z = mu + sigma * eps
        self.latent_z = mu + tf.multiply(sigma, gaussian_noise)
        # Condition the decoder on the latent code plus the visibility mask.
        obj_model_embedding = tf.concat(values=[self.latent_z, p_masks], axis=1)
        part_representations = self._obj_gen_decoder_rnn_forward(obj_model_embedding, reuse=reuse)
        # element 0 seeds the vertex decoder, element 1 the edge decoder
        vert_decoder_out = self._vert_gen_decoder_rnn_forward(part_representations[0], reuse=reuse)
        edge_decoder_out = self._edge_gen_decoder_rnn_forward(part_representations[1], reuse=reuse)
        self.vert_decoder_out = vert_decoder_out
        self.edge_decoder_out = edge_decoder_out
        # MSE between encoder branch outputs and decoded branch states
        graph_part_loss = self._final_graph_reconstruction_loss(g_vert_in=parts_vert_out,
                                                                g_edge_in=parts_edge_out,
                                                                g_vert_out=part_representations[0],
                                                                g_edge_out=part_representations[1])
        return vert_decoder_out, edge_decoder_out, mu, log_sigma, graph_part_loss
##############################################################################
# Functions to compute losses
##############################################################################
def _final_bboxs_loss(self, bbox_input, bbox_output, bbox_loss_mask=None):
    """Masked L1 reconstruction loss over part bounding boxes.

    Each part contributes the mean absolute error over its 2 * bbox_size
    values; an optional per-part mask zeroes out entries before the final
    batch mean.
    """
    gt_flat = tf.reshape(bbox_input, [-1])
    pred_flat = tf.reshape(bbox_output, [-1])
    abs_err = tf.abs(tf.subtract(pred_flat, gt_flat))
    # mean absolute error per part
    per_part = tf.reduce_mean(tf.reshape(abs_err, [-1, 2 * self.bbox_size]), axis=1)
    if bbox_loss_mask is not None:
        per_part = tf.multiply(tf.reshape(bbox_loss_mask, [-1]), per_part)
    return tf.reduce_mean(per_part)
def _final_voxels_loss(self, voxel_input, voxel_output, voxel_loss_weight=None, voxel_loss_mask=None):
    """Weighted, masked MSE between ground-truth and predicted part voxels.

    Each part contributes the mean squared error over its cube_len^3
    voxels; optional per-part weights and a visibility mask scale the
    per-part terms before the final batch mean.
    """
    cube_len = self.data['config_dict']['CUBE_LEN']
    gt_flat = tf.reshape(voxel_input, [-1])
    pred_flat = tf.reshape(voxel_output, [-1])
    sq_err = tf.pow(gt_flat - pred_flat, 2)
    # mean squared error per part voxel cube
    per_part = tf.reduce_mean(tf.reshape(sq_err, [-1, cube_len * cube_len * cube_len]), axis=1)
    if voxel_loss_weight is not None:
        per_part = tf.multiply(per_part, tf.reshape(voxel_loss_weight, [-1]))
    if voxel_loss_mask is not None:
        mask = tf.cast(tf.reshape(voxel_loss_mask, [-1]), tf.float32)
        per_part = tf.multiply(mask, per_part)
    return tf.reduce_mean(per_part)
def _final_graph_reconstruction_loss(self, g_vert_in, g_edge_in, g_vert_out, g_edge_out,
                                     vert_loss_weight=1.0, edge_loss_weight=1.0):
    """Weighted MSE between encoder graph features and their decoded
    reconstructions, summed over the vertex and edge branches."""
    def _mse(target, recon):
        # mean squared error over all flattened elements
        diff = tf.reshape(target, [-1]) - tf.reshape(recon, [-1])
        return tf.reduce_mean(tf.pow(diff, 2))
    vert_term = _mse(g_vert_in, g_vert_out) * vert_loss_weight
    edge_term = _mse(g_edge_in, g_edge_out) * edge_loss_weight
    return vert_term + edge_term
def _final_graph_kl_loss(self, mu, log_sigma):
    """KL divergence of N(mu, exp(log_sigma)) from the standard normal.

    Computes -0.5 * sum(1 + log_sigma - mu^2 - exp(log_sigma)) per sample
    and averages over the batch (standard VAE KL term).
    """
    # `axis` replaces the long-deprecated `reduction_indices` argument.
    kl_per_sample = -0.5 * tf.reduce_sum(1 + log_sigma - tf.pow(mu, 2) - tf.exp(log_sigma), axis=1)
    return tf.reduce_mean(kl_per_sample)
##############################################################################
# Functions to output geometry and structure features
##############################################################################
def pred_voxel_and_bbox(self):
    """Return the (voxel map, bounding box) prediction tensors."""
    return self.get_output('voxel_decoder'), self.get_output('bbox_pred')
##############################################################################
# Functions for checking input data
##############################################################################
def check_feeds(self, inputs_data):
    """Validate a batch of inputs and build the placeholder feed dict.

    Args:
        inputs_data: dict of per-batch arrays and scalar hyper-parameters.

    Returns:
        A feed dict mapping the model's placeholders to values.

    Raises:
        KeyError: if the voxel/bbox arrays disagree in size, or the number
            of visibility masks does not match voxels / max_part_size.
    """
    part_voxels = inputs_data['part_voxels']
    part_bbox = inputs_data['part_bbox']
    rel_pair_mask_inds = inputs_data['rel_pair_mask_inds']
    part_visible_masks = inputs_data['part_visible_masks']
    if part_voxels[0].shape[0] != part_bbox[0].shape[0]:
        raise KeyError("voxel array and bbox array must have the same size")
    # one mask row per object, one voxel row per part (k parts per object)
    if part_visible_masks[0].shape[0] != part_voxels[0].shape[0] / self.max_part_size:
        # typo fixed: "mush" -> "must"
        raise KeyError("part visible masks and voxel array must have the same size")
    input_feed = {
        self.part_voxels: part_voxels,
        self.part_bboxs: part_bbox,
        self.gaussian_noise: inputs_data['gaussian_noise'],
        self.edge_pair_mask_inds: rel_pair_mask_inds,
        self.part_visible_masks: part_visible_masks,
        self.vert_lr: float(inputs_data['vert_lr']),
        self.edge_lr: float(inputs_data['edge_lr']),
        self.graph_gen_lr: float(inputs_data['graph_gen_lr']),
        self.recon_gen_loss_ratio: float(inputs_data['recon_gen_loss_ratio']),
        self.voxel_bbox_ratio: float(inputs_data['voxel_bbox_ratio']),
        self.g_rec_kl_loss_ratio: float(inputs_data['g_rec_kl_loss_ratio']),
        self.max_gradient_norm: float(inputs_data['max_gradient_norm']),
        self.voxel_loss_weights: inputs_data['part_voxel_loss_weights'],
        self.part_bbox_loss_masks: inputs_data['part_bbox_loss_masks'],
    }
    return input_feed
##############################################################################
# Additional function to compute the final output voxel maps and bounding boxes
##############################################################################
def get_batch_info(self, sess, inputs_data):
    """Run the network on one batch and post-process the predictions.

    Fetches the visibility masks, voxel predictions, and bbox predictions
    in a single sess.run so all three come from the same graph execution,
    instead of re-running the graph three times on the same feed.

    Returns (voxels_list, bboxs_list, part_visible_masks).
    """
    input_feed = self.check_feeds(inputs_data)
    part_visible_masks, d_voxels, d_bboxs = sess.run(
        [self.part_visible_masks[self.gpu_num - 1], self.voxel_pred, self.bbox_pred],
        feed_dict=input_feed)
    voxels_list = self.data_helper.process_voxel_data(d_voxels, part_visible_masks)
    bboxs_list = self.data_helper.process_bbox_data(d_bboxs, part_visible_masks)
    return voxels_list, bboxs_list, part_visible_masks
##############################################################################
# Functions to train and test
##############################################################################
def train(self, sess, inputs_data, iter_n, is_summary=False):
    """Run one optimization step and print per-loss progress.

    Returns the total loss, plus the merged summary when `is_summary`.
    """
    input_feed = self.check_feeds(inputs_data)
    input_feed[self.keep_prob] = self.data['config_dict']['TRAIN']['DROPOUT_KEEP_PROB']
    total_iter_n = int(self.data['config_dict']['TRAIN']['ITER_NUM'])
    fetches = [self.train_op, self.total_losses, self.voxel_loss, self.bbox_loss]
    if is_summary:
        fetches.append(self.summary_op)
    results = sess.run(fetches, input_feed)
    _, total_loss, voxel_loss, bbox_loss = results[:4]
    print ("[%6d/%6d], total loss: %.8f, voxel loss: %.8f, bbox loss: %.8f" %
           (int(iter_n), total_iter_n, total_loss, voxel_loss, bbox_loss))
    if is_summary:
        return total_loss, results[4]
    return total_loss
def test(self, sess, inputs_data):
    """Generate voxels and bounding boxes from noise at test time.

    Feeds the latent code directly with the supplied Gaussian noise and
    the chosen visibility mask, then fetches voxel and bbox predictions
    in a single sess.run so both come from the same forward pass.

    Returns (voxels_list, bboxs_list, part_visible_masks).
    """
    input_feed = {
        self.latent_z: inputs_data['gaussian_noise'][self.gpu_num - 1],
        self.part_visible_masks: inputs_data['visible_part_index'],
    }
    d_voxels, d_bboxs = sess.run([self.voxel_pred, self.bbox_pred], feed_dict=input_feed)
    part_visible_masks = inputs_data['visible_part_index'][self.gpu_num - 1]
    voxels_list = self.data_helper.process_voxel_data(d_voxels, part_visible_masks)
    bboxs_list = self.data_helper.process_bbox_data(d_bboxs, part_visible_masks)
    return voxels_list, bboxs_list, part_visible_masks