| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -307,7 +307,7 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.model.train() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        total_loss = 0.0 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        step = 0 # 如果你的训练是分批次的,这里应该用批次索引 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        batch_size = 8192  # 设置合适的batch大小 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        batch_size = 8192*16  # 设置合适的batch大小 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 数据处理 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # manfld | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -319,19 +319,16 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if epoch % 10 == 1 or self.cached_train_data is None: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 计算流形点的掩码和操作符 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 生成非流形点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            _psdf_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals,local_sigma=0.001) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            _nonmnfld_pnts = self.sampler.get_points(_mnfld_pnts, local_sigma=0.01): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            _nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 更新缓存 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.cached_train_data = { | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                "nonmnfld_pnts": _nonmnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                "psdf_pnts": _psdf_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                "psdf": _psdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            } | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 从缓存中读取数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            _nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            _psdf_pnts =  self.cached_train_data["psdf_pnts"] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            _psdf = self.cached_train_data["psdf"] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf)) | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -353,13 +350,11 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 非流形点使用缓存数据(整个batch共享) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            psdf_pnts = _psdf_pnts[start_idx:end_idx] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            psdf = _psdf[start_idx:end_idx] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # --- 准备模型输入,启用梯度 --- | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            psdf_pnts.requires_grad_(True) # 在检查之后启用梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # --- 前向传播 --- | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            mnfld_pred = self.model.forward_background( | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -368,9 +363,6 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nonmnfld_pred = self.model.forward_background( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                nonmnfld_pnts | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            psdf_pred = self.model.forward_background( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                psdf_pnts | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -390,7 +382,6 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    loss, loss_details = self.loss_manager.compute_loss( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        mnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        nonmnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        psdf_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        normals, # 传递检查过的 normals | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        gt_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        mnfld_pred, | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |