@@ -197,8 +197,17 @@ class Trainer:

        # return torch.cat([global_min, global_max])
        # return [-0.5,]  # this is wrong

    def train_epoch_stage1(self, epoch: int):
        total_loss = 0.0  # running sum of per-step losses
        total_loss_details = {
            "manifold": 0.0,
            "normals": 0.0,
        }
        accumulated_loss = 0.0  # new: accumulates the loss over several steps

        # new: zero the gradients once at the start of the epoch
        self.optimizer.zero_grad()

        for step, surf_points in enumerate(self.data['surf_ncs']):
            points = torch.tensor(surf_points, device=self.device)
            gt_sdf = torch.zeros(points.shape[0], device=self.device)
            normals = None

@@ -211,14 +220,13 @@ class Trainer:

            # --- forward pass ---
            pred_sdf = self.model.forward_training_volumes(points, step)
            logger.debug(f"pred_sdf: {pred_sdf}")

            if self.debug_mode:
                # --- inspect the forward-pass output ---
                logger.gpu_memory_stats("after forward pass")

            # --- loss computation ---
            loss = torch.tensor(float('nan'), device=self.device)  # start as NaN in case the computation fails
            loss_details = {}
            try:
                if args.use_normal:
                    loss, loss_details = self.loss_manager.compute_loss(

@@ -229,45 +237,45 @@ class Trainer:
                    )
                else:
                    loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
                    loss_details = {}  # keep the variable defined on this path

                if self.debug_mode:
                    if check_tensor(loss, "Calculated Loss", epoch, step):
                        logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
                        if loss_details:
                            logger.error(f"Loss Details: {loss_details}")
                        return float('inf')  # abort the epoch if the loss is invalid
            except Exception as loss_e:
                logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
                return float('inf')  # abort the epoch if the loss computation fails

            # --- backward pass and optimization ---
            try:
                # changed: accumulate the loss instead of calling backward() on
                # every step; dividing by accumulation_steps keeps the gradient
                # scale comparable to one large batch (assumes the config
                # defines accumulation_steps)
                accumulated_loss += loss / self.config.train.accumulation_steps

                current_loss = loss.item()
                if not np.isfinite(current_loss):  # make sure the loss is a valid number
                    logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
                    return float('inf')
                total_loss += current_loss
                for key in total_loss_details:
                    if key in loss_details:
                        total_loss_details[key] += loss_details[key].item()

                # new: once the accumulation window is full, run the backward pass
                if (step + 1) % self.config.train.accumulation_steps == 0:
                    accumulated_loss.backward()
                    # (recommended) gradient clipping guards against exploding
                    # gradients, one possible source of inf/nan values
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    accumulated_loss = 0.0  # reset the accumulator
            except Exception as backward_e:
                logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
                return float('inf')  # abort the epoch if backward or the optimizer step fails

            # --- memory management ---
            del loss
            torch.cuda.empty_cache()

            # log training progress (valid losses only)
            logger.info(f'Train Epoch: {epoch:4d}\t'
                        f'Loss: {current_loss:.6f}')
            if loss_details:
                logger.info(f"Loss Details: {loss_details}")

        # new: flush any remaining accumulated loss from an incomplete final window
        if torch.is_tensor(accumulated_loss):
            accumulated_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            self.optimizer.zero_grad()

        # compute and log the epoch loss
        num_steps = len(self.data['surf_ncs'])
        avg_loss = total_loss / max(num_steps, 1)
        logger.info(f'Train Epoch: {epoch:4d}\t'
                    f'Loss: {avg_loss:.6f}')
        logger.info(f"Loss Details: {total_loss_details}")
        return avg_loss  # return the average loss rather than the running total
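
For reference, here is a minimal, self-contained sketch of the gradient-accumulation pattern the loop above relies on. The toy model, data, and the accumulation_steps value are placeholders, not part of the Trainer; the real loop reads the step count from self.config.train.accumulation_steps. This variant calls backward() on each scaled step loss immediately, which yields the same summed gradient as accumulating the loss tensor does.

import torch

model = torch.nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
accumulation_steps = 4  # placeholder for config.train.accumulation_steps
num_steps = 10

optimizer.zero_grad()
for step in range(num_steps):
    x = torch.randn(8, 3)
    pred = model(x).squeeze(-1)
    loss = torch.nn.functional.mse_loss(pred, torch.zeros(8))
    # scale each step so the gradients summed over a window
    # match those of one large batch
    (loss / accumulation_steps).backward()
    if (step + 1) % accumulation_steps == 0:
        # clip once per window, on the summed gradient
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

# flush gradients left over from an incomplete final window
if num_steps % accumulation_steps != 0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()

Note that train_epoch_stage1 instead sums the scaled loss tensors and runs a single backward() per window, which keeps every step's computation graph alive until the window closes; the per-step-backward form above is the more memory-friendly equivalent.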
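
The check_tensor helper used in the debug branch is defined outside this hunk. Judging from the call site, a truthy return signals a non-finite tensor; a minimal sketch consistent with that contract (an assumption, not the project's actual code) could be:

import logging
import torch

log = logging.getLogger(__name__)

def check_tensor(t: torch.Tensor, name: str, epoch: int, step: int) -> bool:
    # Hypothetical sketch; the project's real check_tensor lives elsewhere.
    # Returns True ("problem found") when the tensor contains inf or NaN.
    if not torch.isfinite(t).all():
        log.error(f"Epoch {epoch} Step {step}: {name} contains inf/nan values.")
        return True
    return False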
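
The loop also assumes the config exposes accumulation_steps under config.train, as the inline comment notes. A minimal sketch of that assumed structure (field names beyond accumulation_steps are illustrative):

from dataclasses import dataclass, field

@dataclass
class TrainConfig:
    accumulation_steps: int = 4  # steps accumulated per optimizer update (assumed field)

@dataclass
class Config:
    train: TrainConfig = field(default_factory=TrainConfig)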

    def train_epoch(self, epoch: int) -> float: