|
|
@ -4,7 +4,7 @@ import time |
|
|
|
import os |
|
|
|
import numpy as np |
|
|
|
import argparse |
|
|
|
|
|
|
|
from torchviz import make_dot |
|
|
|
|
|
|
|
from brep2sdf.config.default_config import get_default_config |
|
|
|
from brep2sdf.data.data import load_brep_file,prepare_sdf_data, print_data_distribution, check_tensor |
|
|
@ -324,6 +324,8 @@ class Trainer: |
|
|
|
logger.info(f'Train Epoch: {epoch:4d}]\t' |
|
|
|
f'Loss: {current_loss:.6f}') |
|
|
|
if loss_details: logger.info(f"Loss Details: {loss_details}") |
|
|
|
dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)])) |
|
|
|
dot.render("forward_graph1", format="png") # 这会保存计算图为png格式 |
|
|
|
|
|
|
|
return total_loss # 对于单批次训练,直接返回当前损失 |
|
|
|
|
|
|
@ -478,6 +480,9 @@ class Trainer: |
|
|
|
subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"] |
|
|
|
logger.info(" ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)])) |
|
|
|
|
|
|
|
dot = make_dot((mnfld_pred, nonmnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts), ('nonmnfld_pnts', nonmnfld_pnts)])) |
|
|
|
dot.render("forward_graph2", format="png") # 这会保存计算图为png格式 |
|
|
|
|
|
|
|
|
|
|
|
avg_loss = sum(losses) / len(losses) |
|
|
|
logger.info(f"Total Loss: {total_loss:.6f} | Avg Loss: {avg_loss:.6f}") |
|
|
@ -659,6 +664,8 @@ class Trainer: |
|
|
|
_nonmnfld_face_indices_mask[start_idx:end_idx], |
|
|
|
_nonmnfld_operator[start_idx:end_idx] |
|
|
|
) |
|
|
|
dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)])) |
|
|
|
dot.render("forward_graph3", format="png") # 这会保存计算图为png格式 |
|
|
|
|
|
|
|
#logger.print_tensor_stats("psdf",psdf) |
|
|
|
#logger.print_tensor_stats("nonmnfld_pnts",nonmnfld_pnts) |
|
|
@ -731,6 +738,7 @@ class Trainer: |
|
|
|
f'Loss: {current_loss:.6f}') |
|
|
|
if loss_details: logger.info(f"Loss Details: {loss_details}") |
|
|
|
|
|
|
|
|
|
|
|
return total_loss # 对于单批次训练,直接返回当前损失 |
|
|
|
|
|
|
|
def train_epoch(self, epoch: int,resample:bool=True) -> float: |
|
|
|