@@ -75,6 +75,8 @@ class Trainer:

self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal, sample_sdf_points=not args.only_zero_surface)
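# process_single_step is assumed (from its use below) to return a dict with at
# least "surf_ncs" (a list of per-surface point arrays) and "train_surf_ncs"
# (an (N, 7) array of [x, y, z, nx, ny, nz, sdf] rows, sliced that way later).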

logger.gpu_memory_stats("after data preprocessing")

self.train_surf_ncs = torch.tensor(self.data["train_surf_ncs"], dtype=torch.float32, device=self.device)

# Convert the list of surface point clouds into an (N*M, 4) array
surfs = self.data["surf_ncs"]
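# A minimal sketch of the flattening (hypothetical; assumes each entry of
# surfs is an (M, 3) point array and the 4th column stores the surface index):
#   surfs_flat = np.concatenate(
#       [np.hstack([s, np.full((len(s), 1), i)]) for i, s in enumerate(surfs)]
#   )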

@@ -98,6 +100,7 @@ class Trainer:

else:
    self.sdf_data = surface_sdf_data

print_data_distribution(self.sdf_data)
logger.debug(self.sdf_data.shape)
logger.gpu_memory_stats("after SDF data preparation")

# Initialize the dataset
# self.brep_data = load_brep_file(self.config.data.pkl_path)

@@ -200,7 +203,7 @@ class Trainer:

# # Return the merged bounding box
# return torch.cat([global_min, global_max])
# return [-0.5,]  # this is wrong

def train_epoch_stage1(self, epoch: int):
def train_epoch_stage1_(self, epoch: int):
    total_loss = 0.0
    total_loss_details = {
        "manifold": 0.0,

@@ -293,8 +296,303 @@ class Trainer:

    return total_loss  # return the average loss rather than the accumulated value

def train_epoch_stage1(self, epoch: int) -> float:
    # --- 1. Validate input data ---
    # Note: assumes self.train_surf_ncs contains [x, y, z, nx, ny, nz, sdf] (7 columns)
    # or [x, y, z, sdf] (4 columns), with the SDF value always in the last column.
    if self.train_surf_ncs is None:
        logger.error(f"Epoch {epoch}: self.train_surf_ncs is None. Cannot train.")
        return float('inf')

    self.model.train()
    total_loss = 0.0
    step = 0  # if training is split into batches, this should be the batch index
    batch_size = 8192  # choose a suitable batch size

    # Data preparation
    # manifold
    _mnfld_pnts = self.train_surf_ncs[:, 0:3].clone().detach()  # first 3 columns: points
    _normals = self.train_surf_ncs[:, 3:6].clone().detach()  # middle 3 columns: normals
    _gt_sdf = self.train_surf_ncs[:, -1].clone().detach()  # last column: ground-truth SDF

    # Check whether the cache needs to be recomputed
    if epoch % 10 == 1 or self.cached_train_data is None:
        # Generate non-manifold points
        _nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals)
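        # Sketch of the assumed sampler behavior (hypothetical, for illustration):
        # offset each surface point along its normal by a random distance eps,
        # returning eps as the pseudo-SDF label of the offset point, roughly:
        #   eps = torch.randn(len(pts), 1, device=pts.device) * sigma  # sigma: assumed scale
        #   nonmnfld = pts + eps * normals
        #   psdf = eps.squeeze(-1)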

        # Update the cache
        self.cached_train_data = {
            "nonmnfld_pnts": _nonmnfld_pnts,
            "psdf": _psdf,
        }
    else:
        # Read the data back from the cache
        _nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"]
        _psdf = self.cached_train_data["psdf"]

    logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf))

    # Split the data into batches
    num_points = self.train_surf_ncs.shape[0]
    num_batches = (num_points + batch_size - 1) // batch_size
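    # Ceiling division: e.g. 20000 points with batch_size 8192 gives
    # (20000 + 8191) // 8192 = 3 batches, the last one partially filled.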

    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min((batch_idx + 1) * batch_size, num_points)

        # Slice out the current batch
        mnfld_pnts = _mnfld_pnts[start_idx:end_idx]  # manifold points
        gt_sdf = _gt_sdf[start_idx:end_idx]  # ground-truth SDF
        normals = _normals[start_idx:end_idx] if args.use_normal else None  # normals

        # Non-manifold points come from the cached data
        nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx]
        psdf = _psdf[start_idx:end_idx]

        # --- Prepare model inputs and enable gradients ---
        mnfld_pnts.requires_grad_(True)  # enable gradients after the input checks
        nonmnfld_pnts.requires_grad_(True)  # enable gradients after the input checks

        # --- Forward pass ---
        mnfld_pred = self.model.forward_background(
            mnfld_pnts
        )
        nonmnfld_pred = self.model.forward_background(
            nonmnfld_pnts
        )

        # --- Compute the loss ---
        loss = torch.tensor(float('nan'), device=self.device)  # initialize to NaN in case the computation fails
        loss_details = {}
        try:
            # --- 3. Check the inputs before computing the loss ---
            # (the points already require grad, no need to recheck for inf/nan;
            # check pred_sdf and gt_sdf instead)
            #if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss")
            #if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss")
            if args.use_normal:
                # Check the normals and the gradient-enabled points
                #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
                #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
                logger.gpu_memory_stats("before loss computation")

                loss, loss_details = self.loss_manager.compute_loss(
                    mnfld_pnts,
                    nonmnfld_pnts,
                    normals,  # pass the checked normals
                    gt_sdf,
                    mnfld_pred,
                    nonmnfld_pred,
                    psdf
                )
            else:
                loss = torch.nn.functional.mse_loss(mnfld_pred.squeeze(-1), gt_sdf)

            # --- 4. Validate the computed loss ---
            if self.debug_mode:
                logger.print_tensor_stats("psdf", psdf)
                logger.print_tensor_stats("nonmnfld_pred", nonmnfld_pred)

                if check_tensor(loss, "Calculated Loss", epoch, step):
                    logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
                    if loss_details: logger.error(f"Loss Details: {loss_details}")
                    return float('inf')  # abort this epoch if the loss is invalid

        except Exception as loss_e:
            logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
            return float('inf')  # abort this epoch if the loss computation fails
        logger.gpu_memory_stats("after loss computation")

        # --- Backward pass and optimization ---
        try:
            self.scheduler.optimizer.zero_grad()  # clear gradients
            loss.backward()  # backpropagate
            self.scheduler.optimizer.step()  # update parameters
            self.scheduler.step(loss, epoch)
        except Exception as backward_e:
            logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
            # To see which operation caused it, enable anomaly detection
            # before training starts: torch.autograd.set_detect_anomaly(True)
            return float('inf')  # abort this epoch if backward or the optimizer step fails

        # --- Log and accumulate the loss ---
        current_loss = loss.item()
        if not np.isfinite(current_loss):  # double-check the loss is a finite number
            logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
            return float('inf')

        total_loss += current_loss
        del loss
        torch.cuda.empty_cache()
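        # empty_cache() only returns unused blocks held by the caching
        # allocator to the driver; tensors that are still referenced
        # (such as the cached training data) stay allocated.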

    if epoch % 100 == 0:
        # Log training progress (valid losses only)
        logger.info(f'Train Epoch: {epoch:4d}\t'
                    f'Loss: {current_loss:.6f}')
        if loss_details: logger.info(f"Loss Details: {loss_details}")

    return total_loss  # with a single batch this is just the current loss

def train_epoch_stage2(self, epoch: int) -> float:
    # --- 1. Validate input data ---
    # Note: assumes self.sdf_data contains [x, y, z, nx, ny, nz, sdf] (7 columns)
    # or [x, y, z, sdf] (4 columns), with the SDF value always in the last column.
    if self.sdf_data is None:
        logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.")
        return float('inf')

    self.model.train()
    total_loss = 0.0
    step = 0  # if training is split into batches, this should be the batch index
    batch_size = 8192 * 2  # choose a suitable batch size

    # Data preparation
    # manifold
    _mnfld_pnts = self.sdf_data[:, 0:3].clone().detach()  # first 3 columns: points
    _normals = self.sdf_data[:, 3:6].clone().detach()  # middle 3 columns: normals
    _gt_sdf = self.sdf_data[:, -1].clone().detach()  # last column: ground-truth SDF

    # Check whether the cache needs to be recomputed
    if epoch % 10 == 1 or self.cached_train_data is None:
        # Compute face-index masks and operators for the manifold points
        _, _mnfld_face_indices_mask, _mnfld_operator = self.root.forward(_mnfld_pnts)

        # Generate non-manifold points
        _nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals)
        _, _nonmnfld_face_indices_mask, _nonmnfld_operator = self.root.forward(_nonmnfld_pnts)
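        # self.root is assumed to be the octree root; its forward() is taken
        # to return, per query point, a face-index mask and an interpolation
        # operator aligned row-for-row with the input points, which is why
        # both can be sliced with the same start_idx:end_idx as the points.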

        # Update the cache
        self.cached_train_data = {
            "mnfld_face_indices_mask": _mnfld_face_indices_mask,
            "mnfld_operator": _mnfld_operator,
            "nonmnfld_pnts": _nonmnfld_pnts,
            "psdf": _psdf,
            "nonmnfld_face_indices_mask": _nonmnfld_face_indices_mask,
            "nonmnfld_operator": _nonmnfld_operator
        }
    else:
        # Read the data back from the cache
        _mnfld_face_indices_mask = self.cached_train_data["mnfld_face_indices_mask"]
        _mnfld_operator = self.cached_train_data["mnfld_operator"]
        _nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"]
        _psdf = self.cached_train_data["psdf"]
        _nonmnfld_face_indices_mask = self.cached_train_data["nonmnfld_face_indices_mask"]
        _nonmnfld_operator = self.cached_train_data["nonmnfld_operator"]

    logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf))

    # Split the data into batches
    num_points = self.sdf_data.shape[0]
    num_batches = (num_points + batch_size - 1) // batch_size

    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min((batch_idx + 1) * batch_size, num_points)

        # Slice out the current batch
        mnfld_pnts = _mnfld_pnts[start_idx:end_idx]  # manifold points
        gt_sdf = _gt_sdf[start_idx:end_idx]  # ground-truth SDF
        normals = _normals[start_idx:end_idx] if args.use_normal else None  # normals

        # Non-manifold points come from the cached data
        nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx]
        psdf = _psdf[start_idx:end_idx]

        # --- Prepare model inputs and enable gradients ---
        mnfld_pnts.requires_grad_(True)  # enable gradients after the input checks
        nonmnfld_pnts.requires_grad_(True)  # enable gradients after the input checks

        # --- Forward pass ---
        mnfld_pred = self.model.forward_without_octree(
            mnfld_pnts,
            _mnfld_face_indices_mask[start_idx:end_idx],
            _mnfld_operator[start_idx:end_idx]
        )
        nonmnfld_pred = self.model.forward_without_octree(
            nonmnfld_pnts,
            _nonmnfld_face_indices_mask[start_idx:end_idx],
            _nonmnfld_operator[start_idx:end_idx]
        )

        #logger.print_tensor_stats("psdf", psdf)
        #logger.print_tensor_stats("nonmnfld_pred", nonmnfld_pred)

        # --- Compute the loss ---
        loss = torch.tensor(float('nan'), device=self.device)  # initialize to NaN in case the computation fails
        loss_details = {}
        try:
            # --- 3. Check the inputs before computing the loss ---
            # (the points already require grad, no need to recheck for inf/nan;
            # check pred_sdf and gt_sdf instead)
            #if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss")
            #if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss")
            if args.use_normal:
                # Check the normals and the gradient-enabled points
                #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
                #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
                logger.gpu_memory_stats("before loss computation")

                loss, loss_details = self.loss_manager.compute_loss(
                    mnfld_pnts,
                    nonmnfld_pnts,
                    normals,  # pass the checked normals
                    gt_sdf,
                    mnfld_pred,
                    nonmnfld_pred,
                    psdf
                )
            else:
                loss = torch.nn.functional.mse_loss(mnfld_pred.squeeze(-1), gt_sdf)

            # --- 4. Validate the computed loss ---
            if self.debug_mode:
                if check_tensor(loss, "Calculated Loss", epoch, step):
                    logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
                    if loss_details: logger.error(f"Loss Details: {loss_details}")
                    return float('inf')  # abort this epoch if the loss is invalid

        except Exception as loss_e:
            logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
            return float('inf')  # abort this epoch if the loss computation fails
        logger.gpu_memory_stats("after loss computation")

        # --- Backward pass and optimization ---
        try:
            self.scheduler.optimizer.zero_grad()  # clear gradients
            loss.backward()  # backpropagate
            self.scheduler.optimizer.step()  # update parameters
            self.scheduler.step(loss, epoch)
        except Exception as backward_e:
            logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
            # To see which operation caused it, enable anomaly detection
            # before training starts: torch.autograd.set_detect_anomaly(True)
            return float('inf')  # abort this epoch if backward or the optimizer step fails

        # --- Log and accumulate the loss ---
        current_loss = loss.item()
        if not np.isfinite(current_loss):  # double-check the loss is a finite number
            logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
            return float('inf')

        total_loss += current_loss
        del loss
        torch.cuda.empty_cache()

    # Log training progress (valid losses only)
    logger.info(f'Train Epoch: {epoch:4d}\t'
                f'Loss: {current_loss:.6f}')
    if loss_details: logger.info(f"Loss Details: {loss_details}")

    return total_loss  # with a single batch this is just the current loss

def train_epoch(self, epoch: int, resample: bool = True) -> float:
    # --- 1. Validate input data ---
    # Note: assumes self.sdf_data contains [x, y, z, nx, ny, nz, sdf] (7 columns)
    # or [x, y, z, sdf] (4 columns), with the SDF value always in the last column.

@@ -447,16 +745,19 @@ class Trainer:

    best_val_loss = float('inf')
    logger.info("Starting training...")
    start_time = time.time()
    self.cached_train_data = None

    start_epoch = 1
    if args.resume_checkpoint_path:
        start_epoch = self._load_checkpoint(args.resume_checkpoint_path)
        logger.info(f"Loaded model from {args.resume_checkpoint_path}")

    self.model.encoder.freeze_stage1()
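    # freeze_stage1 is assumed to set requires_grad=False on the stage-1
    # encoder parameters, e.g. roughly (hypothetical helper):
    #   for p in self.encoder.stage1_parameters():
    #       p.requires_grad_(False)
    # so only the remaining modules are updated during this phase.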
    for epoch in range(start_epoch, self.config.train.num_epochs + 1):
        # Train one epoch
        #train_loss = self.train_epoch(epoch)  # redundant: its result was overwritten below
        train_loss = self.train_epoch_stage1(epoch)
        #train_loss = self.train_epoch_stage2(epoch)
        #train_loss = self.train_epoch(epoch)

        # Validation
        '''

@@ -476,7 +777,7 @@ class Trainer:

        if epoch % self.config.train.save_freq == 0:
            self._save_checkpoint(epoch, train_loss)
            logger.info(f'Checkpoint saved at epoch {epoch}')

    self.model.encoder.unfreeze()
    # Training finished
    total_time = time.time() - start_time

@@ -555,8 +856,6 @@ class Trainer:

        logger.error(f"Failed to load checkpoint: {str(e)}")
        raise

# ... existing code ...

def _save_octree(self):
    """
    Save the octree to a file.

@@ -566,6 +865,7 @@ class Trainer:

        self.config.train.checkpoint_dir,
        self.model_name
    )
    os.makedirs(checkpoint_dir, exist_ok=True)
    octree_path = os.path.join(checkpoint_dir, "octree.pth")

    try: