diff --git a/brep2sdf/train.py b/brep2sdf/train.py index e753e5f..eedccbf 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -13,6 +13,7 @@ from brep2sdf.networks.octree import OctreeNode from brep2sdf.networks.loss import LossManager from brep2sdf.networks.patch_graph import PatchGraph from brep2sdf.networks.sample import NormalPerPoint +from brep2sdf.networks.learning_rate import LearningRateScheduler from brep2sdf.utils.logger import logger @@ -140,6 +141,8 @@ class Trainer: weight_decay=config.train.weight_decay ) + #self.scheduler = LearningRateScheduler(self.conf.get_list('train.learning_rate_schedule'), self.conf.get_float('train.weight_decay'), self.model.parameters()) + self.loss_manager = LossManager(ablation="none") logger.gpu_memory_stats("训练器初始化后") @@ -222,6 +225,7 @@ class Trainer: normals = None if args.use_normal: normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device) + logger.debug(normals) # --- 准备模型输入,启用梯度 --- mnfld_points.requires_grad_(True) # 在检查之后启用梯度 @@ -232,6 +236,9 @@ class Trainer: mnfld_pred = self.model.forward_training_volumes(mnfld_points, step) nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step) + logger.print_tensor_stats("mnfld_pred",mnfld_pred) + logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred) + if self.debug_mode: # --- 检查前向传播的输出 --- logger.gpu_memory_stats("前向传播后") @@ -302,7 +309,7 @@ class Trainer: self.model.train() total_loss = 0.0 step = 0 # 如果你的训练是分批次的,这里应该用批次索引 - batch_size = 10240 # 设置合适的batch大小 + batch_size = 8192 # 设置合适的batch大小 # 将数据分成多个batch num_points = self.sdf_data.shape[0] @@ -311,7 +318,8 @@ class Trainer: for batch_idx in range(num_batches): start_idx = batch_idx * batch_size end_idx = min((batch_idx + 1) * batch_size, num_points) - points = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # 取前3列作为点 + mnfld_pnts = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # 取前3列作为点 + nonmnfld_pnts = 
self.sampler.get_points(mnfld_pnts.unsqueeze(0)).squeeze(0) # 生成非流形点 gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach() # 取最后一列作为SDF真值 normals = None if args.use_normal: @@ -322,19 +330,28 @@ class Trainer: # 执行检查 if self.debug_mode: - if check_tensor(points, "Input Points", epoch, step): return float('inf') + if check_tensor(mnfld_pnts, "Input Points", epoch, step): return float('inf') if check_tensor(gt_sdf, "Input GT SDF", epoch, step): return float('inf') if args.use_normal: # 只有在请求法线时才检查 normals if check_tensor(normals, "Input Normals", epoch, step): return float('inf') + logger.debug(normals) + logger.print_tensor_stats("normals-x",normals[:, 0]) + logger.print_tensor_stats("normals-y",normals[:, 1]) + logger.print_tensor_stats("normals-z",normals[:, 2]) # --- 准备模型输入,启用梯度 --- - points.requires_grad_(True) # 在检查之后启用梯度 + mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 + nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 # --- 前向传播 --- self.optimizer.zero_grad() - pred_sdf = self.model(points) + mnfld_pred = self.model(mnfld_pnts) + nonmnfld_pred = self.model(nonmnfld_pnts) + + logger.print_tensor_stats("mnfld_pred",mnfld_pred) + logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred) if self.debug_mode: # --- 检查前向传播的输出 --- @@ -356,10 +373,12 @@ class Trainer: #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") logger.gpu_memory_stats("计算损失前") loss, loss_details = self.loss_manager.compute_loss( - points, + mnfld_pnts, + nonmnfld_pnts, normals, # 传递检查过的 normals gt_sdf, - pred_sdf + mnfld_pred, + nonmnfld_pred ) else: - loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) + loss = torch.nn.functional.mse_loss(mnfld_pred, gt_sdf) @@ -448,8 +467,8 @@ class Trainer: for epoch in range(start_epoch, self.config.train.num_epochs + 1): # 训练一个epoch - train_loss = self.train_epoch_stage1(epoch) - #train_loss = self.train_epoch(epoch) + #train_loss = self.train_epoch_stage1(epoch) + train_loss = self.train_epoch(epoch) # 验证 '''