| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -350,106 +350,135 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            return total_loss # 对于单批次训练,直接返回当前损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def train_stage2(self, num_epoch): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if not args.use_normal: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.warning(f"need args.use_normal, skip stage2") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            return float('inf') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.model.freeze_stage2() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.cached_train_data = None | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        num_volumes = self.data['surf_bbox_ncs'].shape[0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        surf_bbox=torch.tensor( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        surf_bbox = torch.tensor( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.data['surf_bbox_ncs'],  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            dtype=torch.float32, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            device=self.device | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info(f"Start Stage 2 Training: {num_epoch} epochs") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        total_loss = 0.0  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 收集所有有效的点云数据和对应的 patch_ids | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        all_points = [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        valid_patch_ids = [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        nonmnfld_pnts_list, psdf_list = [], [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for patch_id in range(num_volumes): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            points = points_in_box(self.train_surf_ncs, surf_bbox[patch_id]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            loss = self.train_stage2_by_volume(num_epoch, patch_id, points) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.debug(f"Patch [{patch_id:2d}] | Loss: {loss:.6f}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            total_loss += loss | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            points = points.to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if points.shape[0] == 0: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.warning(f"Patch {patch_id} has no valid points.") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                continue | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return total_loss | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nonmnfld_pnts, psdf = self.sampler.get_norm_points(points[:,0:3], points[:,3:6])  # 生成非流形点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            all_points.append(points) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            valid_patch_ids.append(patch_id) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nonmnfld_pnts_list.append(nonmnfld_pnts) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            psdf_list.append(psdf) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if not all_points: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.warning("No valid patches found.") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            return 0.0 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def train_stage2_by_volume(self, num_epoch, patch_id, points): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.debug(f"Patch [{patch_id:2d}] | train pnt number {points.shape[0]}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        points.to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        mnfld_pnts = points[:,0:3] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.debug(mnfld_pnts) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        gt_sdf = torch.zeros(mnfld_pnts.shape[0], device=self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if not args.use_normal: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.warning(f"need args.use_normal,skip stage2") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            return float('inf') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        normals = points[:,3:6] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.debug(normals) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        nonmnfld_pnts, psdf = self.sampler.get_norm_points(mnfld_pnts, normals)  # 生成非流形点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        weights = torch.tensor([points.shape[0] for points in all_points], device=self.device).float() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        weights /= weights.sum() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # --- 准备模型输入,启用梯度 --- | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 清空梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.scheduler.optimizer.zero_grad() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 对每个 patch 进行前向传播并计算损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for epoch in range(num_epoch): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # --- 前向传播 --- | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            mnfld_pred = self.model.forward_training_volumes(mnfld_pnts, patch_id) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, patch_id) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            losses = [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            loss_detailss = [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for patch_id, points,nonmnfld_pnts, psdf in zip(valid_patch_ids, all_points, nonmnfld_pnts_list,psdf_list): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.debug(f"Patch [{patch_id:2d}] | train pnt number {points.shape[0]}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                mnfld_pnts = points[:, 0:3] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                gt_sdf = torch.zeros(mnfld_pnts.shape[0], device=self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                normals = points[:, 3:6] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # --- 计算损失 --- | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            loss = torch.tensor(float('nan'), device=self.device) # 初始化为 NaN 以防计算失败 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            loss_details = {} | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.gpu_memory_stats("计算损失前") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                loss, loss_details = self.loss_manager.compute_loss( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    mnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    nonmnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    normals, # 传递检查过的 normals | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    gt_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    mnfld_pred, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    nonmnfld_pred, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    psdf | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # --- 准备模型输入,启用梯度 --- | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                mnfld_pnts.requires_grad_(True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                nonmnfld_pnts.requires_grad_(True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # --- 4. 检查损失计算结果 --- | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if self.debug_mode: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.print_tensor_stats("psdf",psdf) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if check_tensor(loss, "Calculated Loss", epoch, step): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        if loss_details: logger.error(f"Loss Details: {loss_details}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        return float('inf') # 如果损失无效,停止这个epoch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # --- 前向传播 --- | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                mnfld_pred = self.model.forward_training_volumes(mnfld_pnts, patch_id) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, patch_id) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            except Exception as loss_e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                return float('inf') # 如果计算出错,停止这个epoch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.gpu_memory_stats("损失计算后") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # --- 计算损失 --- | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                loss = torch.tensor(float('nan'), device=self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                loss_details = {} | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    loss, loss_details = self.loss_manager.compute_loss_volume( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        mnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        nonmnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        normals, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        gt_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        mnfld_pred, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        nonmnfld_pred, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        psdf | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # --- 反向传播和优化 --- | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 反向传播 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                self.scheduler.optimizer.zero_grad()  # 清空梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                loss.backward()  # 反向传播 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                self.scheduler.optimizer.step()  # 更新参数 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                self.scheduler.step(loss,epoch) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            except Exception as backward_e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 如果你想看是哪个操作导致的,可以启用 anomaly detection | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # torch.autograd.set_detect_anomaly(True) # 放在训练开始前 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                return float('inf') # 如果反向传播或优化出错,停止这个epoch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    # 检查损失计算结果 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if self.debug_mode: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        if check_tensor(loss, "Calculated Loss", epoch): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            logger.error(f"Epoch {epoch}: Loss calculation resulted in inf/nan.") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            if loss_details: logger.error(f"Loss Details: {loss_details}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            return float('inf') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                except Exception as loss_e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.error(f"Epoch {epoch}: Error during loss calculation: {loss_e}", exc_info=True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    return float('inf') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 累积损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                losses.append(loss) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if epoch % 1 == 0: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    loss_detailss.append(loss_details) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            torch.cuda.empty_cache() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if epoch % 100 == 0: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 记录训练进度 (只记录有效的损失) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.info(f'Train Epoch: {epoch:4d}]\t' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            f'Loss: {loss:.6f}') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if loss_details: logger.info(f"Loss Details: {loss_details}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return loss # last loss | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 多个损失平均后反向传播 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            loss_tensor = torch.stack(losses) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            mean_loss = (loss_tensor * weights).sum() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            mean_loss.backward() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 更新参数 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.scheduler.optimizer.step() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.scheduler.step(mean_loss, epoch) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 清空梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.scheduler.optimizer.zero_grad() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 清理缓存 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            torch.cuda.empty_cache() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 如果你想查看详细的损失信息,可以在这里添加日志记录 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if epoch % 1 == 0: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.info(f'Train [Stage 2] Epoch: {epoch:4d}\t' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            f'Loss: {loss:.6f}') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                loss_details_tensor = torch.stack(loss_detailss)  # shape: [num_patches, 5] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 对每个子项取加权平均(如果需要 weights) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                weighted_avg = (loss_details_tensor * weights.view(-1, 1)).sum(dim=0) / weights.sum() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.info("  ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)])) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        avg_loss = sum(losses) / len(losses) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info(f"Total Loss: {total_loss:.6f} | Avg Loss: {avg_loss:.6f}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return avg_loss | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def train_epoch_stage2_(self, epoch: int): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        total_loss = 0.0 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -874,14 +903,31 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.info(f'Checkpoint saved at epoch {epoch}') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # stage2 freeze_stage2 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.train_stage2(self.config.train.num_epochs2) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        epoch = self.config.train.num_epochs1+self.config.train.num_epochs2 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info(f'Checkpoint saved at epoch {epoch}') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self._save_checkpoint(epoch, 0.0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        max_stage2_epoch = self.config.train.num_epochs1+self.config.train.num_epochs2 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if start_epoch < max_stage2_epoch: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.scheduler.reset() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.train_stage2(self.config.train.num_epochs2) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            cur_epoch = max_stage2_epoch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.info(f'Checkpoint saved at epoch {cur_epoch}') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self._save_checkpoint(cur_epoch, 0.0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.info(f"start_epoch:{start_epoch} > {max_stage2_epoch}, skip stage 2 training.") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            cur_epoch = start_epoch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #stage 3 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.model.encoder.unfreeze() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.scheduler.reset() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for epoch in range(cur_epoch, max_stage2_epoch + self.config.train.num_epochs3 + 1): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 训练一个epoch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            train_loss = self.train_epoch_stage3(epoch) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            #train_loss = self.train_epoch_stage2(epoch) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            #train_loss = self.train_epoch(epoch) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 保存检查点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if epoch % self.config.train.save_freq == 0: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                self._save_checkpoint(epoch, train_loss) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.info(f'Checkpoint saved at epoch {epoch}') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 训练完成 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        total_time = time.time() - start_time | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |