| 
						
						
							
								
							
						
						
					 | 
					@ -6,50 +6,25 @@ from torch.utils.data import DataLoader | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from brep2sdf.data.data import BRepSDFDataset | 
					 | 
					 | 
					from brep2sdf.data.data import BRepSDFDataset | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from brep2sdf.networks.encoder import BRepToSDF, sdf_loss | 
					 | 
					 | 
					from brep2sdf.networks.encoder import BRepToSDF, sdf_loss | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from brep2sdf.utils.logger import logger | 
					 | 
					 | 
					from brep2sdf.utils.logger import logger | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					from brep2sdf.config.default_config import get_default_config, load_config | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					import wandb | 
					 | 
					 | 
					import wandb | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					def main(): | 
					 | 
					 | 
					def main(): | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    # 使用字典存储配置参数 | 
					 | 
					 | 
					    # 获取配置 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    config = { | 
					 | 
					 | 
					    config = get_default_config() | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					        # 数据路径 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'brep_dir': '/home/wch/brep2sdf/test_data/pkl', | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'sdf_dir': '/home/wch/brep2sdf/test_data/sdf', | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'valid_data_dir': '/home/wch/brep2sdf/test_data/result/pkl', | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'save_dir': 'checkpoints', | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 训练参数 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'batch_size': 32, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'num_workers': 4, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'num_epochs': 100, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'learning_rate': 1e-4, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'min_lr': 1e-6, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'weight_decay': 0.01, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'grad_weight': 0.1, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'max_grad_norm': 1.0, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 模型参数 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'brep_feature_dim': 48, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'use_cf': True, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'embed_dim': 768, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'latent_dim': 256, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # wandb参数 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'use_wandb': True, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'project_name': 'brep2sdf', | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'run_name': 'training_run', | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        'log_interval': 10 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    } | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    # 创建保存目录 | 
					 | 
					 | 
					    # 创建所有保存目录 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    os.makedirs(config['save_dir'], exist_ok=True) | 
					 | 
					 | 
					    os.makedirs(config.data.model_save_dir, exist_ok=True) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    os.makedirs(config.data.log_save_dir, exist_ok=True) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    os.makedirs(config.data.result_save_dir, exist_ok=True) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    # 初始化wandb (添加超时设置和离线模式) | 
					 | 
					 | 
					    # 初始化wandb | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    if config['use_wandb']: | 
					 | 
					 | 
					    if config.log.use_wandb: | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        try: | 
					 | 
					 | 
					        try: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            wandb.init( | 
					 | 
					 | 
					            wandb.init( | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					                project=config['project_name'], | 
					 | 
					 | 
					                project=config.log.project_name, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                name=config['run_name'], | 
					 | 
					 | 
					                name=config.log.run_name, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                config=config, | 
					 | 
					 | 
					                config=config.__dict__, | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					                settings=wandb.Settings( | 
					 | 
					 | 
					                settings=wandb.Settings( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    init_timeout=180,  # 增加超时时间 | 
					 | 
					 | 
					                    init_timeout=180,  # 增加超时时间 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    _disable_stats=True,  # 禁用统计 | 
					 | 
					 | 
					                    _disable_stats=True,  # 禁用统计 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -60,7 +35,7 @@ def main(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            logger.info("Wandb initialized in offline mode") | 
					 | 
					 | 
					            logger.info("Wandb initialized in offline mode") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        except Exception as e: | 
					 | 
					 | 
					        except Exception as e: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            logger.warning(f"Failed to initialize wandb: {str(e)}") | 
					 | 
					 | 
					            logger.warning(f"Failed to initialize wandb: {str(e)}") | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            config['use_wandb'] = False  # 禁用wandb | 
					 | 
					 | 
					            config.log.use_wandb = False | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					            logger.warning("Continuing without wandb logging") | 
					 | 
					 | 
					            logger.warning("Continuing without wandb logging") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    # 初始化训练器并开始训练 | 
					 | 
					 | 
					    # 初始化训练器并开始训练 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -75,15 +50,15 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 初始化数据集 | 
					 | 
					 | 
					        # 初始化数据集 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.train_dataset = BRepSDFDataset( | 
					 | 
					 | 
					        self.train_dataset = BRepSDFDataset( | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            brep_dir=config['brep_dir'], | 
					 | 
					 | 
					            brep_dir=config.data.brep_dir, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            sdf_dir=config['sdf_dir'], | 
					 | 
					 | 
					            sdf_dir=config.data.sdf_dir, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            valid_data_dir=config['valid_data_dir'], | 
					 | 
					 | 
					            valid_data_dir=config.data.valid_data_dir, | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					            split='train' | 
					 | 
					 | 
					            split='train' | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        ) | 
					 | 
					 | 
					        ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.val_dataset = BRepSDFDataset( | 
					 | 
					 | 
					        self.val_dataset = BRepSDFDataset( | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            brep_dir=config['brep_dir'], | 
					 | 
					 | 
					            brep_dir=config.data.brep_dir, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            sdf_dir=config['sdf_dir'], | 
					 | 
					 | 
					            sdf_dir=config.data.sdf_dir, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            valid_data_dir=config['valid_data_dir'], | 
					 | 
					 | 
					            valid_data_dir=config.data.valid_data_dir, | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					            split='val' | 
					 | 
					 | 
					            split='val' | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        ) | 
					 | 
					 | 
					        ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -93,37 +68,37 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 初始化数据加载器 | 
					 | 
					 | 
					        # 初始化数据加载器 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.train_loader = DataLoader( | 
					 | 
					 | 
					        self.train_loader = DataLoader( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            self.train_dataset, | 
					 | 
					 | 
					            self.train_dataset, | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            batch_size=config['batch_size'], | 
					 | 
					 | 
					            batch_size=config.train.batch_size, | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					            shuffle=True, | 
					 | 
					 | 
					            shuffle=True, | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            num_workers=config['num_workers'] | 
					 | 
					 | 
					            num_workers=config.train.num_workers | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        ) | 
					 | 
					 | 
					        ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.val_loader = DataLoader( | 
					 | 
					 | 
					        self.val_loader = DataLoader( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            self.val_dataset, | 
					 | 
					 | 
					            self.val_dataset, | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            batch_size=config['batch_size'], | 
					 | 
					 | 
					            batch_size=config.train.batch_size, | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					            shuffle=False, | 
					 | 
					 | 
					            shuffle=False, | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            num_workers=config['num_workers'] | 
					 | 
					 | 
					            num_workers=config.train.num_workers | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        ) | 
					 | 
					 | 
					        ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 初始化模型 | 
					 | 
					 | 
					        # 初始化模型 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.model = BRepToSDF( | 
					 | 
					 | 
					        self.model = BRepToSDF( | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            brep_feature_dim=config['brep_feature_dim'], | 
					 | 
					 | 
					            brep_feature_dim=config.model.brep_feature_dim, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            use_cf=config['use_cf'], | 
					 | 
					 | 
					            use_cf=config.model.use_cf, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            embed_dim=config['embed_dim'], | 
					 | 
					 | 
					            embed_dim=config.model.embed_dim, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            latent_dim=config['latent_dim'] | 
					 | 
					 | 
					            latent_dim=config.model.latent_dim | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        ).to(self.device) | 
					 | 
					 | 
					        ).to(self.device) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 初始化优化器 | 
					 | 
					 | 
					        # 初始化优化器 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.optimizer = optim.AdamW( | 
					 | 
					 | 
					        self.optimizer = optim.AdamW( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            self.model.parameters(), | 
					 | 
					 | 
					            self.model.parameters(), | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            lr=config['learning_rate'], | 
					 | 
					 | 
					            lr=config.train.learning_rate, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            weight_decay=config['weight_decay'] | 
					 | 
					 | 
					            weight_decay=config.train.weight_decay | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        ) | 
					 | 
					 | 
					        ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 学习率调度器 | 
					 | 
					 | 
					        # 学习率调度器 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.scheduler = optim.lr_scheduler.CosineAnnealingLR( | 
					 | 
					 | 
					        self.scheduler = optim.lr_scheduler.CosineAnnealingLR( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            self.optimizer, | 
					 | 
					 | 
					            self.optimizer, | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            T_max=config['num_epochs'], | 
					 | 
					 | 
					            T_max=config.train.num_epochs, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            eta_min=config['min_lr'] | 
					 | 
					 | 
					            eta_min=config.train.min_lr | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        ) | 
					 | 
					 | 
					        ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def train_epoch(self, epoch): | 
					 | 
					 | 
					    def train_epoch(self, epoch): | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -148,7 +123,7 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                pred_sdf, | 
					 | 
					 | 
					                pred_sdf, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                gt_sdf, | 
					 | 
					 | 
					                gt_sdf, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                query_points, | 
					 | 
					 | 
					                query_points, | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					                grad_weight=self.config['grad_weight'] | 
					 | 
					 | 
					                grad_weight=self.config.train.grad_weight | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					            ) | 
					 | 
					 | 
					            ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					             | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # 反向传播 | 
					 | 
					 | 
					            # 反向传播 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -158,19 +133,19 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # 梯度裁剪 | 
					 | 
					 | 
					            # 梯度裁剪 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            torch.nn.utils.clip_grad_norm_( | 
					 | 
					 | 
					            torch.nn.utils.clip_grad_norm_( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                self.model.parameters(),  | 
					 | 
					 | 
					                self.model.parameters(),  | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					                self.config['max_grad_norm'] | 
					 | 
					 | 
					                self.config.train.max_grad_norm | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					            ) | 
					 | 
					 | 
					            ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					             | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            self.optimizer.step() | 
					 | 
					 | 
					            self.optimizer.step() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            total_loss += loss.item() | 
					 | 
					 | 
					            total_loss += loss.item() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					             | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # 打印训练进度 | 
					 | 
					 | 
					            # 打印训练进度 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            if (batch_idx + 1) % self.config['log_interval'] == 0: | 
					 | 
					 | 
					            if (batch_idx + 1) % self.config.log.log_interval == 0: | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					                logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t' | 
					 | 
					 | 
					                logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t' | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                          f'Loss: {loss.item():.6f}') | 
					 | 
					 | 
					                          f'Loss: {loss.item():.6f}') | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					             | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # 记录到wandb | 
					 | 
					 | 
					            # 记录到wandb | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            if self.config['use_wandb'] and (batch_idx + 1) % self.config['log_interval'] == 0: | 
					 | 
					 | 
					            if self.config.log.use_wandb and (batch_idx + 1) % self.config.log.log_interval == 0: | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					                wandb.log({ | 
					 | 
					 | 
					                wandb.log({ | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    'batch_loss': loss.item(), | 
					 | 
					 | 
					                    'batch_loss': loss.item(), | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    'batch': batch_idx, | 
					 | 
					 | 
					                    'batch': batch_idx, | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -203,7 +178,7 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    pred_sdf, | 
					 | 
					 | 
					                    pred_sdf, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    gt_sdf, | 
					 | 
					 | 
					                    gt_sdf, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    query_points, | 
					 | 
					 | 
					                    query_points, | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    grad_weight=self.config['grad_weight'] | 
					 | 
					 | 
					                    grad_weight=self.config.train.grad_weight | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					                ) | 
					 | 
					 | 
					                ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                 | 
					 | 
					 | 
					                 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                total_loss += loss.item() | 
					 | 
					 | 
					                total_loss += loss.item() | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -211,7 +186,7 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        avg_loss = total_loss / len(self.val_loader) | 
					 | 
					 | 
					        avg_loss = total_loss / len(self.val_loader) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}') | 
					 | 
					 | 
					        logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}') | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        if self.config['use_wandb']: | 
					 | 
					 | 
					        if self.config.log.use_wandb: | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					            wandb.log({ | 
					 | 
					 | 
					            wandb.log({ | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                'val_loss': avg_loss, | 
					 | 
					 | 
					                'val_loss': avg_loss, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                'epoch': epoch | 
					 | 
					 | 
					                'epoch': epoch | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -222,35 +197,68 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        best_val_loss = float('inf') | 
					 | 
					 | 
					        best_val_loss = float('inf') | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        logger.info("Starting training...") | 
					 | 
					 | 
					        logger.info("Starting training...") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        for epoch in range(1, self.config['num_epochs'] + 1): | 
					 | 
					 | 
					        for epoch in range(1, self.config.train.num_epochs + 1): | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					            train_loss = self.train_epoch(epoch) | 
					 | 
					 | 
					            train_loss = self.train_epoch(epoch) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            val_loss = self.validate(epoch) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            self.scheduler.step() | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					             | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            # 保存最佳模型 | 
					 | 
					 | 
					            # 定期验证 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            if val_loss < best_val_loss: | 
					 | 
					 | 
					            if epoch % self.config.train.val_freq == 0: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                best_val_loss = val_loss | 
					 | 
					 | 
					                val_loss = self.validate(epoch) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                model_path = os.path.join(self.config['save_dir'], 'best_model.pth') | 
					 | 
					 | 
					                 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                # 保存最佳模型 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                if val_loss < best_val_loss: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    best_val_loss = val_loss | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    best_model_path = os.path.join( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        self.config.data.model_save_dir, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        self.config.data.best_model_name.format( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                            model_name=self.config.data.model_name | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    torch.save({ | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        'epoch': epoch, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        'model_state_dict': self.model.state_dict(), | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        'optimizer_state_dict': self.optimizer.state_dict(), | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        'val_loss': val_loss, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    }, best_model_path) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    logger.info(f'Saved best model with val_loss: {val_loss:.6f}') | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # 定期保存检查点 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            if epoch % self.config.train.save_freq == 0: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                checkpoint_path = os.path.join( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    self.config.data.model_save_dir, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    self.config.data.checkpoint_format.format( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        model_name=self.config.data.model_name, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        epoch=epoch | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                torch.save({ | 
					 | 
					 | 
					                torch.save({ | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    'epoch': epoch, | 
					 | 
					 | 
					                    'epoch': epoch, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    'model_state_dict': self.model.state_dict(), | 
					 | 
					 | 
					                    'model_state_dict': self.model.state_dict(), | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    'optimizer_state_dict': self.optimizer.state_dict(), | 
					 | 
					 | 
					                    'optimizer_state_dict': self.optimizer.state_dict(), | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    'val_loss': val_loss, | 
					 | 
					 | 
					                    'scheduler_state_dict': self.scheduler.state_dict(), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                }, model_path) | 
					 | 
					 | 
					                    'train_loss': train_loss, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                logger.info(f'Saved best model with val_loss: {val_loss:.6f}') | 
					 | 
					 | 
					                    'val_loss': val_loss if epoch % self.config.train.val_freq == 0 else None, | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                }, checkpoint_path) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                logger.info(f'Saved checkpoint at epoch {epoch}') | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            self.scheduler.step() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					             | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # 记录训练信息 | 
					 | 
					 | 
					            # 记录训练信息 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            logger.info(f'Epoch: {epoch}\tTrain Loss: {train_loss:.6f}\t' | 
					 | 
					 | 
					            log_info = f'Epoch: {epoch}\tTrain Loss: {train_loss:.6f}\t' | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                       f'Val Loss: {val_loss:.6f}\tLR: {self.scheduler.get_last_lr()[0]:.6f}') | 
					 | 
					 | 
					            if epoch % self.config.train.val_freq == 0: | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                log_info += f'Val Loss: {val_loss:.6f}\t' | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            log_info += f'LR: {self.scheduler.get_last_lr()[0]:.6f}' | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.info(log_info) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					             | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # 记录到wandb | 
					 | 
					 | 
					            # 记录到wandb | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            if self.config['use_wandb']: | 
					 | 
					 | 
					            if self.config.log.use_wandb: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                wandb.log({ | 
					 | 
					 | 
					                log_dict = { | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					                    'train_loss': train_loss, | 
					 | 
					 | 
					                    'train_loss': train_loss, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    'val_loss': val_loss, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    'learning_rate': self.scheduler.get_last_lr()[0], | 
					 | 
					 | 
					                    'learning_rate': self.scheduler.get_last_lr()[0], | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    'epoch': epoch | 
					 | 
					 | 
					                    'epoch': epoch | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					                }) | 
					 | 
					 | 
					                } | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                if epoch % self.config.train.val_freq == 0: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    log_dict['val_loss'] = val_loss | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                wandb.log(log_dict) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					if __name__ == '__main__': | 
					 | 
					 | 
					if __name__ == '__main__': | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    main() | 
					 | 
					 | 
					    main() |