@@ -173,19 +173,24 @@ def train_step(model, batch, optimizer, criterion):
     # Ensure the model is in training mode
     model.train()
 
-    # Check and set requires_grad for all parameters
+    # 1. Deep-copy the input data so every step works on fresh tensors
+    batch = {
+        k: v.clone().detach().requires_grad_(True) if isinstance(v, torch.Tensor) and v.dtype != torch.bool
+        else v
+        for k, v in batch.items()
+    }
+
+    # 2. Make sure all model parameters require gradients
     for name, param in model.named_parameters():
         if not param.requires_grad:
             logger.warning(f"Parameter {name} has requires_grad=False; setting it to True")
             param.requires_grad = True
 
-    # Set requires_grad=True on all inputs
-    batch['query_points'].requires_grad_(True)
-
-    # Zero the gradients
-    optimizer.zero_grad()
+    # 3. Zero the gradients (clear both the model's and the optimizer's)
+    model.zero_grad(set_to_none=True)  # set_to_none=True clears gradients more thoroughly
+    optimizer.zero_grad(set_to_none=True)
 
-    # Forward pass
+    # 4. Forward pass
     pred_sdf = model(
         edge_ncs=batch['edge_ncs'],
         edge_pos=batch['edge_pos'],
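Note on step 1 above: only floating-point (and complex) tensors can have requires_grad enabled, which is why the comprehension passes boolean tensors through untouched. Below is a minimal, self-contained sketch of the same clone/detach pattern on a toy batch; the keys are made up for illustration and are not the project's real inputs.

import torch

# Toy batch: a float tensor plus a boolean mask (illustrative keys only).
batch = {
    'query_points': torch.randn(4, 3),
    'mask': torch.ones(4, dtype=torch.bool),
}

# Same pattern as the diff: re-create float tensors as fresh leaves with
# gradients enabled; pass bool tensors (and non-tensors) through unchanged.
batch = {
    k: v.clone().detach().requires_grad_(True)
    if isinstance(v, torch.Tensor) and v.dtype != torch.bool
    else v
    for k, v in batch.items()
}

print(batch['query_points'].requires_grad)  # True
print(batch['mask'].requires_grad)          # False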
@@ -196,33 +201,43 @@ def train_step(model, batch, optimizer, criterion):
         query_points=batch['query_points']
     )
 
-    # Compute the loss
+    # 5. Check whether the prediction requires gradients
+    if not pred_sdf.requires_grad:
+        logger.warning("Prediction output does not require gradients!")
+        pred_sdf.requires_grad_(True)
+
+    # 6. Compute the loss
     loss = criterion(pred_sdf, batch['gt_sdf'])
 
-    # Check that the loss is valid
+    # 7. Check the loss
+    logger.info(f"Loss value: {loss.item()}")
+    logger.info(f"Loss requires grad: {loss.requires_grad}")
+
     if not torch.isfinite(loss):
         logger.error(f"Invalid loss value: {loss.item()}")
         raise ValueError("Invalid loss value")
 
-    # Backward pass
-    loss.backward()
+    # 8. Backward pass
+    loss.backward(retain_graph=False)  # retain_graph=False ensures the graph is freed
 
-    # Check gradients
-    total_norm = 0
+    # 9. Detailed gradient check
+    logger.info("\nGradient info:")
     for name, param in model.named_parameters():
         if param.grad is not None:
-            param_norm = param.grad.data.norm(2)
-            total_norm += param_norm.item() ** 2
-            logger.info(f"  {name}: grad_norm = {param_norm.item()}")
+            grad_norm = param.grad.norm().item()
+            grad_mean = param.grad.mean().item()
+            grad_std = param.grad.std().item()
+            logger.info(f"  {name}:")
+            logger.info(f"    norm: {grad_norm:.6f}")
+            logger.info(f"    mean: {grad_mean:.6f}")
+            logger.info(f"    std: {grad_std:.6f}")
         else:
-            logger.warning(f"  {name}: No gradient!")
-    total_norm = total_norm ** 0.5
-    logger.info(f"Total gradient norm: {total_norm}")
+            logger.warning(f"  {name}: no gradient!")
 
-    # Gradient clipping (optional)
-    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+    # 10. Gradient clipping
+    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
 
-    # Update parameters
+    # 11. Update parameters
     optimizer.step()
 
     return loss.item()
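Note on step 9: the new per-parameter logging drops the old aggregate norm. If the total L2 norm is still wanted, it can be recomputed from the same loop, or read from the value that torch.nn.utils.clip_grad_norm_ returns (it reports the total norm before clipping). A rough sketch against a throwaway linear model, purely for illustration:

import torch

# Throwaway model and one backward pass, just to populate .grad.
model = torch.nn.Linear(3, 2)
loss = model(torch.randn(8, 3)).pow(2).mean()
loss.backward()

# Per-parameter stats as in step 9, plus the aggregate norm the old code logged.
total_sq = 0.0
for name, param in model.named_parameters():
    if param.grad is not None:
        g = param.grad
        print(f"{name}: norm={g.norm().item():.6f} mean={g.mean().item():.6f} std={g.std().item():.6f}")
        total_sq += g.norm().item() ** 2
print(f"total grad norm: {total_sq ** 0.5:.6f}")

# Same total norm, as returned by clip_grad_norm_ (clipping at 10.0 like step 10).
total = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
print(f"clip_grad_norm_ reported: {total.item():.6f}")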
@@ -260,7 +275,7 @@ def train(model, config, num_epochs=10):
     }
 
     # Training loop
-    for epoch in range(2):
+    for epoch in range(num_epochs):
         try:
             loss = train_step(model, batch, optimizer, criterion)
             logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss:.6f}")