| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -4,7 +4,7 @@ import time | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import os | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import numpy as np | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import argparse | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from torchviz import make_dot | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from brep2sdf.config.default_config import get_default_config | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from brep2sdf.data.data import load_brep_file,prepare_sdf_data, print_data_distribution, check_tensor | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -324,6 +324,8 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.info(f'Train Epoch: {epoch:4d}]\t' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        f'Loss: {current_loss:.6f}') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if loss_details: logger.info(f"Loss Details: {loss_details}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)])) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            dot.render("forward_graph1", format="png")  # 这会保存计算图为png格式 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            return total_loss # 对于单批次训练,直接返回当前损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -478,6 +480,9 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info("  ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)])) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    dot = make_dot((mnfld_pred, nonmnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts), ('nonmnfld_pnts', nonmnfld_pnts)])) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    dot.render("forward_graph2", format="png")  # 这会保存计算图为png格式 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        avg_loss = sum(losses) / len(losses) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info(f"Total Loss: {total_loss:.6f} | Avg Loss: {avg_loss:.6f}") | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -659,6 +664,8 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                _nonmnfld_face_indices_mask[start_idx:end_idx], | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                _nonmnfld_operator[start_idx:end_idx] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)])) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            dot.render("forward_graph3", format="png")  # 这会保存计算图为png格式 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            #logger.print_tensor_stats("psdf",psdf) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            #logger.print_tensor_stats("nonmnfld_pnts",nonmnfld_pnts) | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -731,6 +738,7 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    f'Loss: {current_loss:.6f}') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if loss_details: logger.info(f"Loss Details: {loss_details}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return total_loss # 对于单批次训练,直接返回当前损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def train_epoch(self, epoch: int,resample:bool=True) -> float: | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |