Browse Source

refactor: Replace print statements with logger in run.py and network.py

- Replaced print statements with logger.info() in ReconstructionRunner class
- Added commented-out debug logging for input and output tensor shapes in the NHRepNet forward method
- Improved logging consistency and added a docstring to the NHRepNet forward method
NH-Rep
mckay 4 weeks ago
parent
commit
39f56470ef
  1. 6
      code/conversion/run.py
  2. 16
      code/model/network.py

6
code/conversion/run.py

@ -242,12 +242,12 @@ class ReconstructionRunner:
if epoch % self.conf.get_int('train.status_frequency') == 0: if epoch % self.conf.get_int('train.status_frequency') == 0:
print('Train Epoch: [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}\t Manifold loss: {:.6f}' logger.info('Train Epoch: [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}\t Manifold loss: {:.6f}'
'\tManifold patch loss: {:.6f}\t grad loss h: {:.6f}\t normals loss all: {:.6f}\t normals loss h: {:.6f}\t Manifold consistency loss: {:.6f}\tCorrection loss: {:.6f}\t Offsurface loss: {:.6f}'.format( '\tManifold patch loss: {:.6f}\t grad loss h: {:.6f}\t normals loss all: {:.6f}\t normals loss h: {:.6f}\t Manifold consistency loss: {:.6f}\tCorrection loss: {:.6f}\t Offsurface loss: {:.6f}'.format(
epoch, self.nepochs, 100. * epoch / self.nepochs, epoch, self.nepochs, 100. * epoch / self.nepochs,
loss.item(), mnfld_loss.item(), mnfld_loss_patch.item(), grad_loss_h.item(), normals_loss.item(), normals_loss_h.item(), mnfld_consistency_loss.item(), correction_loss.item(), offsurface_loss.item())) loss.item(), mnfld_loss.item(), mnfld_loss_patch.item(), grad_loss_h.item(), normals_loss.item(), normals_loss_h.item(), mnfld_consistency_loss.item(), correction_loss.item(), offsurface_loss.item()))
if args.feature_sample: if args.feature_sample:
print('feature mnfld loss: {} patch loss: {} cons loss: {}'.format(feature_mnfld_loss.item(), feature_loss_patch.item(), feature_loss_cons.item())) logger.info('feature mnfld loss: {} patch loss: {} cons loss: {}'.format(feature_mnfld_loss.item(), feature_loss_patch.item(), feature_loss_cons.item()))
self.tracing() self.tracing()
@ -278,7 +278,7 @@ class ReconstructionRunner:
example = torch.rand(224,3).to(device) example = torch.rand(224,3).to(device)
traced_script_module = torch.jit.trace(network, example) traced_script_module = torch.jit.trace(network, example)
traced_script_module.save(save_prefix + self.foldername + "_model_h.pt") traced_script_module.save(save_prefix + self.foldername + "_model_h.pt")
print('converting to pt finished') logger.info('converting to pt finished')
def plot_shapes(self, epoch, path=None, with_cuts=False, file_suffix="all"): def plot_shapes(self, epoch, path=None, with_cuts=False, file_suffix="all"):
# plot network validation shapes # plot network validation shapes

16
code/model/network.py

@ -3,6 +3,7 @@ import torch.nn as nn
import torch import torch
from torch.autograd import grad from torch.autograd import grad
from utils.logger import logger
#borrowed from siren paper #borrowed from siren paper
class Sine(nn.Module): class Sine(nn.Module):
def __init(self): def __init(self):
@ -143,10 +144,18 @@ class NHRepNet(nn.Module):
def forward(self, input): def forward(self, input):
'''
:param input: 2D array of floats, each row represents a point in 3D space
e.g. input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
:return: 2D array of floats, each row represents a point in 3D space
e.g. output = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
:note: The input tensor should have a shape of (N, 3), where N is the number of points. usually N = 100000
The output tensor will have a shape of (N, 5), where each row corresponds to the processed output for the respective input point.
'''
# logger.info(f"input shape: {input.shape}")
x = input x = input
for layer in range(0, self.sdf_layers - 1): for layer in range(0, self.sdf_layers - 1):
lin = getattr(self, "sdf_" + str(layer)) lin = getattr(self, "sdf_" + str(layer))
if layer in self.skip_in: if layer in self.skip_in:
x = torch.cat([x, input], -1) / np.sqrt(2) x = torch.cat([x, input], -1) / np.sqrt(2)
@ -160,8 +169,13 @@ class NHRepNet(nn.Module):
# h = self.nested_cvx_output_soft_blend(output_value, self.csg_tree, self.flag_convex) # h = self.nested_cvx_output_soft_blend(output_value, self.csg_tree, self.flag_convex)
if self.flag_output == 0: if self.flag_output == 0:
# logger.info(f"h shape: {h.shape}")
# logger.info(f"output_value shape: {output_value.shape}")
return torch.cat((h, output_value), 1) # return all return torch.cat((h, output_value), 1) # return all
elif self.flag_output == 1: elif self.flag_output == 1:
# logger.info(f"h shape: {h.shape}")
return h #return h return h #return h
else: else:
# logger.info(f"output_value shape: {output_value.shape}")
# logger.info(f"flag_output: {self.flag_output}")
return output_value[:, self.flag_output - 2] #return f_i return output_value[:, self.flag_output - 2] #return f_i
Loading…
Cancel
Save