From 39f56470ef09a8808a1eb228475651432d896944 Mon Sep 17 00:00:00 2001 From: mckay Date: Sun, 16 Feb 2025 16:20:16 +0800 Subject: [PATCH] refactor: Replace print statements with logger in run.py and network.py - Replaced print statements with logger.info() in ReconstructionRunner class - Added logging for input and output tensor shapes in NHRepNet forward method - Improved logging consistency and added docstring for network forward method --- code/conversion/run.py | 6 +++--- code/model/network.py | 16 +++++++++++++++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/code/conversion/run.py b/code/conversion/run.py index 1cb145e..6713b9d 100644 --- a/code/conversion/run.py +++ b/code/conversion/run.py @@ -242,12 +242,12 @@ class ReconstructionRunner: if epoch % self.conf.get_int('train.status_frequency') == 0: - print('Train Epoch: [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}\t Manifold loss: {:.6f}' + logger.info('Train Epoch: [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}\t Manifold loss: {:.6f}' '\tManifold patch loss: {:.6f}\t grad loss h: {:.6f}\t normals loss all: {:.6f}\t normals loss h: {:.6f}\t Manifold consistency loss: {:.6f}\tCorrection loss: {:.6f}\t Offsurface loss: {:.6f}'.format( epoch, self.nepochs, 100. 
* epoch / self.nepochs, loss.item(), mnfld_loss.item(), mnfld_loss_patch.item(), grad_loss_h.item(), normals_loss.item(), normals_loss_h.item(), mnfld_consistency_loss.item(), correction_loss.item(), offsurface_loss.item())) if args.feature_sample: - print('feature mnfld loss: {} patch loss: {} cons loss: {}'.format(feature_mnfld_loss.item(), feature_loss_patch.item(), feature_loss_cons.item())) + logger.info('feature mnfld loss: {} patch loss: {} cons loss: {}'.format(feature_mnfld_loss.item(), feature_loss_patch.item(), feature_loss_cons.item())) self.tracing() @@ -278,7 +278,7 @@ class ReconstructionRunner: example = torch.rand(224,3).to(device) traced_script_module = torch.jit.trace(network, example) traced_script_module.save(save_prefix + self.foldername + "_model_h.pt") - print('converting to pt finished') + logger.info('converting to pt finished') def plot_shapes(self, epoch, path=None, with_cuts=False, file_suffix="all"): # plot network validation shapes diff --git a/code/model/network.py b/code/model/network.py index e8232b1..4869f85 100644 --- a/code/model/network.py +++ b/code/model/network.py @@ -3,6 +3,7 @@ import torch.nn as nn import torch from torch.autograd import grad +from utils.logger import logger #borrowed from siren paper class Sine(nn.Module): def __init(self): @@ -143,10 +144,18 @@ class NHRepNet(nn.Module): def forward(self, input): + ''' + :param input: 2D array of floats, each row represents a point in 3D space + e.g. input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + :return: 2D array of floats, each row contains the processed outputs for the corresponding input point + e.g. output = a torch.tensor of shape (N, 5) + :note: The input tensor should have a shape of (N, 3), where N is the number of points. usually N = 100000 + The output tensor will have a shape of (N, 5), where each row corresponds to the processed output for the respective input point. 
+ ''' + # logger.info(f"input shape: {input.shape}") x = input for layer in range(0, self.sdf_layers - 1): lin = getattr(self, "sdf_" + str(layer)) - if layer in self.skip_in: x = torch.cat([x, input], -1) / np.sqrt(2) @@ -160,8 +169,13 @@ class NHRepNet(nn.Module): # h = self.nested_cvx_output_soft_blend(output_value, self.csg_tree, self.flag_convex) if self.flag_output == 0: + # logger.info(f"h shape: {h.shape}") + # logger.info(f"output_value shape: {output_value.shape}") return torch.cat((h, output_value), 1) # return all elif self.flag_output == 1: + # logger.info(f"h shape: {h.shape}") return h #return h else: + # logger.info(f"output_value shape: {output_value.shape}") + # logger.info(f"flag_output: {self.flag_output}") return output_value[:, self.flag_output - 2] #return f_i \ No newline at end of file