mckay 4 weeks ago
parent
commit
4c7b91ffdc
  1. 50
      brep2sdf/train.py

50
brep2sdf/train.py

@ -147,6 +147,7 @@ class Trainer:
self.scheduler = LearningRateScheduler(config.train.learning_rate_schedule, config.train.weight_decay, self.model.parameters())
self.loss_manager = LossManager(ablation="none")
self.best_loss = float('inf')
logger.gpu_memory_stats("训练器初始化后")
self.sampler = NormalPerPoint(
@ -324,8 +325,8 @@ class Trainer:
logger.info(f'Train Epoch: {epoch:4d}]\t'
f'Loss: {current_loss:.6f}')
if loss_details: logger.info(f"Loss Details: {loss_details}")
dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)]))
dot.render("forward_graph1", format="png") # 这会保存计算图为png格式
#dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)]))
#dot.render("forward_graph1", format="png") # 这会保存计算图为png格式
return total_loss # 对于单批次训练,直接返回当前损失
@ -358,7 +359,7 @@ class Trainer:
logger.warning(f"Patch {patch_id} has no valid points.")
continue
nonmnfld_pnts, psdf = self.sampler.get_norm_points(points[:,0:3], points[:,3:6]) # 生成非流形点
nonmnfld_pnts, psdf = self.sampler.get_norm_points(points[:,0:3], points[:,3:6],0.1) # 生成非流形点
all_points.append(points)
valid_patch_ids.append(patch_id)
nonmnfld_pnts_list.append(nonmnfld_pnts)
@ -480,8 +481,8 @@ class Trainer:
subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"]
logger.info(" ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)]))
dot = make_dot((mnfld_pred, nonmnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts), ('nonmnfld_pnts', nonmnfld_pnts)]))
dot.render("forward_graph2", format="png") # 这会保存计算图为png格式
#dot = make_dot((mnfld_pred, nonmnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts), ('nonmnfld_pnts', nonmnfld_pnts)]))
#dot.render("forward_graph2", format="png") # 这会保存计算图为png格式
avg_loss = sum(losses) / len(losses)
@ -607,7 +608,7 @@ class Trainer:
_, _mnfld_face_indices_mask, _mnfld_operator = self.root.forward(_mnfld_pnts)
# 生成非流形点
_nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals)
_nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals, 0.1)
_, _nonmnfld_face_indices_mask, _nonmnfld_operator = self.root.forward(_nonmnfld_pnts)
# 更新缓存
@ -664,8 +665,8 @@ class Trainer:
_nonmnfld_face_indices_mask[start_idx:end_idx],
_nonmnfld_operator[start_idx:end_idx]
)
dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)]))
dot.render("forward_graph3", format="png") # 这会保存计算图为png格式
#dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)]))
#dot.render("forward_graph3", format="png") # 这会保存计算图为png格式
#logger.print_tensor_stats("psdf",psdf)
#logger.print_tensor_stats("nonmnfld_pnts",nonmnfld_pnts)
@ -735,9 +736,9 @@ class Trainer:
# 记录训练进度 (只记录有效的损失)
logger.info(f'Train Epoch: {epoch:4d}]\t'
f'Loss: {current_loss:.6f}')
f'Loss: {total_loss:.6f}')
if loss_details: logger.info(f"Loss Details: {loss_details}")
self.validate(epoch,total_loss)
return total_loss # 对于单批次训练,直接返回当前损失
@ -873,25 +874,15 @@ class Trainer:
return total_loss # 对于单批次训练,直接返回当前损失
def validate(self, epoch: int) -> float:
self.model.eval()
total_loss = 0.0
def validate(self, epoch, loss):
    """Record a new best loss and checkpoint the model when it improves.

    Compares ``loss`` with ``self.best_loss``. When it is strictly lower,
    the value becomes the new best, a checkpoint is written through
    ``self._save_checkpoint`` and the epoch is logged. Otherwise nothing
    happens; a NaN loss never compares lower, so it is never recorded.

    Args:
        epoch: Epoch number; only used in the log message.
        loss: Loss for this epoch. Must be convertible with ``float()``
            (for example a Python number or a one-element tensor).

    Returns:
        None.
    """
    # Compare and store a plain float so ``self.best_loss`` never keeps a
    # tensor (and whatever autograd history hangs off it) alive.
    loss_value = float(loss)
    if loss_value < self.best_loss:
        self.best_loss = loss_value
        # A negative epoch is the sentinel ``_save_checkpoint`` turns into
        # the fixed "best.pth" file name instead of "epoch_NNN.pth".
        # The caller's original ``loss`` object is forwarded unchanged so
        # the checkpoint contents stay as before.
        self._save_checkpoint(-1, loss)
        logger.info(f'Best Epoch: {epoch}\tAverage Loss: {loss:.6f}')
with torch.no_grad():
for batch in self.val_loader:
points = batch['points'].to(self.device)
gt_sdf = batch['sdf'].to(self.device)
pred_sdf = self.model(points)
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
total_loss += loss.item()
avg_loss = total_loss / len(self.val_loader)
logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}')
return avg_loss
def train(self):
best_val_loss = float('inf')
logger.info("Starting training...")
start_time = time.time()
self.cached_train_data=None
@ -948,7 +939,7 @@ class Trainer:
self._tracing_model_by_script()
#self._tracing_model()
logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
logger.info(f'Best validation loss: {best_val_loss:.6f}')
logger.info(f'Best validation loss: {self.best_loss:.6f}')
#self.test_load()
def test_load(self):
@ -999,7 +990,10 @@ class Trainer:
self.model_name
)
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth")
if epoch >= 0:
checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth")
else:
checkpoint_path = os.path.join(checkpoint_dir, f"best.pth")
# 只保存状态字典
torch.save({

Loading…
Cancel
Save