From 4c7b91ffdc526caba61a852099be01b1774dc127 Mon Sep 17 00:00:00 2001 From: mckay Date: Sat, 10 May 2025 13:53:13 +0800 Subject: [PATCH] dot --- brep2sdf/train.py | 50 +++++++++++++++++++++-------------------------- 1 file changed, 22 insertions(+), 28 deletions(-) diff --git a/brep2sdf/train.py b/brep2sdf/train.py index e600da5..4f396ec 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -147,6 +147,7 @@ class Trainer: self.scheduler = LearningRateScheduler(config.train.learning_rate_schedule, config.train.weight_decay, self.model.parameters()) self.loss_manager = LossManager(ablation="none") + self.best_loss = float('inf') logger.gpu_memory_stats("训练器初始化后") self.sampler = NormalPerPoint( @@ -324,8 +325,8 @@ class Trainer: logger.info(f'Train Epoch: {epoch:4d}]\t' f'Loss: {current_loss:.6f}') if loss_details: logger.info(f"Loss Details: {loss_details}") - dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)])) - dot.render("forward_graph1", format="png") # 这会保存计算图为png格式 + #dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)])) + #dot.render("forward_graph1", format="png") # 这会保存计算图为png格式 return total_loss # 对于单批次训练,直接返回当前损失 @@ -358,7 +359,7 @@ class Trainer: logger.warning(f"Patch {patch_id} has no valid points.") continue - nonmnfld_pnts, psdf = self.sampler.get_norm_points(points[:,0:3], points[:,3:6]) # 生成非流形点 + nonmnfld_pnts, psdf = self.sampler.get_norm_points(points[:,0:3], points[:,3:6],0.1) # 生成非流形点 all_points.append(points) valid_patch_ids.append(patch_id) nonmnfld_pnts_list.append(nonmnfld_pnts) @@ -480,8 +481,8 @@ class Trainer: subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"] logger.info(" ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)])) - dot = make_dot((mnfld_pred, nonmnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts), ('nonmnfld_pnts', nonmnfld_pnts)])) - dot.render("forward_graph2", format="png") # 这会保存计算图为png格式 + #dot = make_dot((mnfld_pred, nonmnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts), ('nonmnfld_pnts', nonmnfld_pnts)])) + #dot.render("forward_graph2", format="png") # 这会保存计算图为png格式 avg_loss = sum(losses) / len(losses) @@ -607,7 +608,7 @@ class Trainer: _, _mnfld_face_indices_mask, _mnfld_operator = self.root.forward(_mnfld_pnts) # 生成非流形点 - _nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals) + _nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals, 0.1) _, _nonmnfld_face_indices_mask, _nonmnfld_operator = self.root.forward(_nonmnfld_pnts) # 更新缓存 @@ -664,8 +665,8 @@ class Trainer: _nonmnfld_face_indices_mask[start_idx:end_idx], _nonmnfld_operator[start_idx:end_idx] ) - dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)])) - dot.render("forward_graph3", format="png") # 这会保存计算图为png格式 + #dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)])) + #dot.render("forward_graph3", format="png") # 这会保存计算图为png格式 #logger.print_tensor_stats("psdf",psdf) #logger.print_tensor_stats("nonmnfld_pnts",nonmnfld_pnts) @@ -735,9 +736,9 @@ class Trainer: # 记录训练进度 (只记录有效的损失) logger.info(f'Train Epoch: {epoch:4d}]\t' - f'Loss: {current_loss:.6f}') + f'Loss: {total_loss:.6f}') if loss_details: logger.info(f"Loss Details: {loss_details}") - + self.validate(epoch,total_loss) return total_loss # 对于单批次训练,直接返回当前损失 @@ -873,25 +874,15 @@ class Trainer: return total_loss # 对于单批次训练,直接返回当前损失 - def validate(self, epoch: int) -> float: - self.model.eval() - total_loss = 0.0 + def validate(self, epoch, loss): + if loss < self.best_loss: + self.best_loss = loss + self._save_checkpoint(-1, loss) # 存 best + logger.info(f'Best Epoch: {epoch}\tAverage Loss: {loss:.6f}') + return - with torch.no_grad(): - for batch in self.val_loader: - points = batch['points'].to(self.device) - gt_sdf = batch['sdf'].to(self.device) - - pred_sdf = self.model(points) - loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) - total_loss += loss.item() - - avg_loss = total_loss / len(self.val_loader) - logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}') - return avg_loss def train(self): - best_val_loss = float('inf') logger.info("Starting training...") start_time = time.time() self.cached_train_data=None @@ -948,7 +939,7 @@ class Trainer: self._tracing_model_by_script() #self._tracing_model() logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s') - logger.info(f'Best validation loss: {best_val_loss:.6f}') + logger.info(f'Best validation loss: {self.best_loss:.6f}') #self.test_load() def test_load(self): @@ -999,7 +990,10 @@ class Trainer: self.model_name ) os.makedirs(checkpoint_dir, exist_ok=True) - checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth") + if epoch >= 0: + checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth") + else: + checkpoint_path = os.path.join(checkpoint_dir, f"best.pth") # 只保存状态字典 torch.save({