|
|
@ -147,6 +147,7 @@ class Trainer: |
|
|
|
self.scheduler = LearningRateScheduler(config.train.learning_rate_schedule, config.train.weight_decay, self.model.parameters()) |
|
|
|
|
|
|
|
self.loss_manager = LossManager(ablation="none") |
|
|
|
self.best_loss = float('inf') |
|
|
|
logger.gpu_memory_stats("训练器初始化后") |
|
|
|
|
|
|
|
self.sampler = NormalPerPoint( |
|
|
@ -324,8 +325,8 @@ class Trainer: |
|
|
|
logger.info(f'Train Epoch: {epoch:4d}]\t' |
|
|
|
f'Loss: {current_loss:.6f}') |
|
|
|
if loss_details: logger.info(f"Loss Details: {loss_details}") |
|
|
|
dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)])) |
|
|
|
dot.render("forward_graph1", format="png") # 这会保存计算图为png格式 |
|
|
|
#dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)])) |
|
|
|
#dot.render("forward_graph1", format="png") # 这会保存计算图为png格式 |
|
|
|
|
|
|
|
return total_loss # 对于单批次训练,直接返回当前损失 |
|
|
|
|
|
|
@ -358,7 +359,7 @@ class Trainer: |
|
|
|
logger.warning(f"Patch {patch_id} has no valid points.") |
|
|
|
continue |
|
|
|
|
|
|
|
nonmnfld_pnts, psdf = self.sampler.get_norm_points(points[:,0:3], points[:,3:6]) # 生成非流形点 |
|
|
|
nonmnfld_pnts, psdf = self.sampler.get_norm_points(points[:,0:3], points[:,3:6],0.1) # 生成非流形点 |
|
|
|
all_points.append(points) |
|
|
|
valid_patch_ids.append(patch_id) |
|
|
|
nonmnfld_pnts_list.append(nonmnfld_pnts) |
|
|
@ -480,8 +481,8 @@ class Trainer: |
|
|
|
subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"] |
|
|
|
logger.info(" ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)])) |
|
|
|
|
|
|
|
dot = make_dot((mnfld_pred, nonmnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts), ('nonmnfld_pnts', nonmnfld_pnts)])) |
|
|
|
dot.render("forward_graph2", format="png") # 这会保存计算图为png格式 |
|
|
|
#dot = make_dot((mnfld_pred, nonmnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts), ('nonmnfld_pnts', nonmnfld_pnts)])) |
|
|
|
#dot.render("forward_graph2", format="png") # 这会保存计算图为png格式 |
|
|
|
|
|
|
|
|
|
|
|
avg_loss = sum(losses) / len(losses) |
|
|
@ -607,7 +608,7 @@ class Trainer: |
|
|
|
_, _mnfld_face_indices_mask, _mnfld_operator = self.root.forward(_mnfld_pnts) |
|
|
|
|
|
|
|
# 生成非流形点 |
|
|
|
_nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals) |
|
|
|
_nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals, 0.1) |
|
|
|
_, _nonmnfld_face_indices_mask, _nonmnfld_operator = self.root.forward(_nonmnfld_pnts) |
|
|
|
|
|
|
|
# 更新缓存 |
|
|
@ -664,8 +665,8 @@ class Trainer: |
|
|
|
_nonmnfld_face_indices_mask[start_idx:end_idx], |
|
|
|
_nonmnfld_operator[start_idx:end_idx] |
|
|
|
) |
|
|
|
dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)])) |
|
|
|
dot.render("forward_graph3", format="png") # 这会保存计算图为png格式 |
|
|
|
#dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)])) |
|
|
|
#dot.render("forward_graph3", format="png") # 这会保存计算图为png格式 |
|
|
|
|
|
|
|
#logger.print_tensor_stats("psdf",psdf) |
|
|
|
#logger.print_tensor_stats("nonmnfld_pnts",nonmnfld_pnts) |
|
|
@ -735,9 +736,9 @@ class Trainer: |
|
|
|
|
|
|
|
# 记录训练进度 (只记录有效的损失) |
|
|
|
logger.info(f'Train Epoch: {epoch:4d}]\t' |
|
|
|
f'Loss: {current_loss:.6f}') |
|
|
|
f'Loss: {total_loss:.6f}') |
|
|
|
if loss_details: logger.info(f"Loss Details: {loss_details}") |
|
|
|
|
|
|
|
self.validate(epoch,total_loss) |
|
|
|
|
|
|
|
return total_loss # 对于单批次训练,直接返回当前损失 |
|
|
|
|
|
|
@ -873,25 +874,15 @@ class Trainer: |
|
|
|
|
|
|
|
return total_loss # 对于单批次训练,直接返回当前损失 |
|
|
|
|
|
|
|
def validate(self, epoch: int) -> float: |
|
|
|
self.model.eval() |
|
|
|
total_loss = 0.0 |
|
|
|
def validate(self, epoch, loss):
    """Record a new best loss and checkpoint it when ``loss`` improves.

    Args:
        epoch: Epoch index; used only in the log message.
        loss: Loss value for this epoch, compared against ``self.best_loss``.

    Returns:
        None. When ``loss`` is lower than ``self.best_loss`` the new value is
        stored, ``self._save_checkpoint(-1, loss)`` is called and a
        "Best Epoch" line is logged; otherwise nothing is changed.
    """
    improved = loss < self.best_loss
    if not improved:
        return

    self.best_loss = loss
    # Epoch -1 marks this as the best-so-far checkpoint rather than a per-epoch one.
    self._save_checkpoint(-1, loss)
    logger.info(f'Best Epoch: {epoch}\tAverage Loss: {loss:.6f}')
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
|
for batch in self.val_loader: |
|
|
|
points = batch['points'].to(self.device) |
|
|
|
gt_sdf = batch['sdf'].to(self.device) |
|
|
|
|
|
|
|
pred_sdf = self.model(points) |
|
|
|
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) |
|
|
|
total_loss += loss.item() |
|
|
|
|
|
|
|
avg_loss = total_loss / len(self.val_loader) |
|
|
|
logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}') |
|
|
|
return avg_loss |
|
|
|
|
|
|
|
def train(self): |
|
|
|
best_val_loss = float('inf') |
|
|
|
logger.info("Starting training...") |
|
|
|
start_time = time.time() |
|
|
|
self.cached_train_data=None |
|
|
@ -948,7 +939,7 @@ class Trainer: |
|
|
|
self._tracing_model_by_script() |
|
|
|
#self._tracing_model() |
|
|
|
logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s') |
|
|
|
logger.info(f'Best validation loss: {best_val_loss:.6f}') |
|
|
|
logger.info(f'Best validation loss: {self.best_loss:.6f}') |
|
|
|
#self.test_load() |
|
|
|
|
|
|
|
def test_load(self): |
|
|
@ -999,7 +990,10 @@ class Trainer: |
|
|
|
self.model_name |
|
|
|
) |
|
|
|
os.makedirs(checkpoint_dir, exist_ok=True) |
|
|
|
checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth") |
|
|
|
if epoch >= 0: |
|
|
|
checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth") |
|
|
|
else: |
|
|
|
checkpoint_path = os.path.join(checkpoint_dir, f"best.pth") |
|
|
|
|
|
|
|
# 只保存状态字典 |
|
|
|
torch.save({ |
|
|
|