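"""Training script for brep2sdf.

Preprocesses a B-rep (.step) model into SDF samples, builds an octree over the
surface patches, and trains a Net to predict signed distance values; the
trained model can be exported with TorchScript.
"""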
import torch
import torch.optim as optim
import time
import os
import numpy as np
import argparse

from brep2sdf.config.default_config import get_default_config
from brep2sdf.data.data import load_brep_file, prepare_sdf_data, print_data_distribution, check_tensor
from brep2sdf.data.pre_process_by_mesh import process_single_step
from brep2sdf.networks.network import Net
from brep2sdf.networks.octree import OctreeNode
from brep2sdf.networks.loss import LossManager
from brep2sdf.networks.patch_graph import PatchGraph
from brep2sdf.networks.sample import NormalPerPoint
from brep2sdf.networks.learning_rate import LearningRateScheduler
from brep2sdf.utils.logger import logger

# Command-line arguments
parser = argparse.ArgumentParser(description='Batch processing tool for STEP files')
parser.add_argument('-i', '--input', required=True,
                    help='path to the B-rep (.step) file to process')
parser.add_argument(
    '--use-normal',
    action='store_true',  # defaults to False; True when the flag is given
    help='require sampled points to carry normal vectors'
)
parser.add_argument(
    '--only-zero-surface',
    action='store_true',  # defaults to False; True when the flag is given
    help='train the SDF on zero-surface samples only'
)
parser.add_argument(
    '--force-reprocess', '-f',
    action='store_true',  # defaults to False; True when the flag is given
    help='force data preprocessing to rerun, ignoring caches and existing results'
)
parser.add_argument(
    '--resume-checkpoint-path',
    type=str,
    default=None,
    help='resume training from the given checkpoint file'
)

parser.add_argument(
    '--octree-cuda',
    action='store_true',  # defaults to False; True when the flag is given
    help='use CUDA to accelerate octree construction'
)

args = parser.parse_args()
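
# Example invocation (hypothetical script name and input path):
#   python train.py -i data/step/00000054_xxx.step --use-normal --force-reprocess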


class Trainer:
    def __init__(self, config, input_step):
        logger.gpu_memory_stats("initialization start")
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.debug_mode = config.train.debug_mode

        self.model_name = os.path.basename(input_step).split('_')[0]
        self.base_name = self.model_name + ".xyz"
        data_path = os.path.join("/home/wch/brep2sdf/data/output_data", self.base_name)
        if os.path.exists(data_path) and not args.force_reprocess:
            try:
                self.data = load_brep_file(data_path)
            except Exception as e:
                logger.error(f"fail to load {data_path}, {str(e)}")
                raise e
            # Reprocess if normals were requested but are missing from the cached data.
            if args.use_normal and self.data.get("surf_pnt_normals", None) is None:
                self.data = process_single_step(step_path=input_step, output_path=data_path,
                                                sample_normal_vector=args.use_normal,
                                                sample_sdf_points=not args.only_zero_surface)
        else:
            self.data = process_single_step(step_path=input_step, output_path=data_path,
                                            sample_normal_vector=args.use_normal,
                                            sample_sdf_points=not args.only_zero_surface)

        logger.gpu_memory_stats("after data preprocessing")

        self.train_surf_ncs = torch.tensor(self.data["train_surf_ncs"], dtype=torch.float32, device=self.device)
        # List of surface point clouds, to be flattened into an (N*M, 4) array.
        surfs = self.data["surf_ncs"]

        # Prepare SDF data for the surface points.
        surface_sdf_data = prepare_sdf_data(
            surfs,
            normals=self.data["surf_pnt_normals"],
            max_points=50000,
            device=self.device
        )
        # Unless only zero-surface points are used, merge in the sampled SDF points.
        if not args.only_zero_surface:
            # Load the sampled point data.
            sampled_sdf_data = torch.tensor(
                self.data['sampled_points_normals_sdf'],
                dtype=torch.float32,
                device=self.device
            )
            # Concatenate surface point data and sampled point data.
            self.sdf_data = torch.cat([surface_sdf_data, sampled_sdf_data], dim=0)
        else:
            self.sdf_data = surface_sdf_data
        print_data_distribution(self.sdf_data)
        logger.debug(self.sdf_data.shape)
        logger.gpu_memory_stats("after SDF data preparation")
        # Dataset initialization (legacy paths kept for reference):
        #self.brep_data = load_brep_file(self.config.data.pkl_path)
        #logger.info( self.brep_data )
        #self.sdf_data = load_sdf_file(sdf_path=self.config.data.sdf_path, num_query_points=self.config.data.num_query_points).to(self.device)

        # Build the patch adjacency graph.
        graph = PatchGraph.from_preprocessed_data(
            surf_ncs=self.data['surf_ncs'],
            edgeFace_adj=self.data['edgeFace_adj'],
            edge_types=self.data['edge_types'],
            device='cuda' if args.octree_cuda else 'cpu'
        )
        # Initialize the network.
        surf_bbox = torch.tensor(
            self.data['surf_bbox_ncs'],
            dtype=torch.float32,
            device=self.device
        )
        max_depth = config.model.octree_max_depth
        if not args.force_reprocess:
            if not self._load_octree():
                self.build_tree(surf_bbox=surf_bbox, graph=graph, max_depth=max_depth)
            elif self.root.max_depth != max_depth:
                self.build_tree(surf_bbox=surf_bbox, graph=graph, max_depth=max_depth)
        else:
            self.build_tree(surf_bbox=surf_bbox, graph=graph, max_depth=max_depth)
        logger.gpu_memory_stats("after octree initialization")

        self.model = Net(
            octree=self.root,
            volume_bboxs=surf_bbox,
            feature_dim=64
        ).to(self.device)
        logger.gpu_memory_stats("after model initialization")

        self.scheduler = LearningRateScheduler(config.train.learning_rate_schedule,
                                               config.train.weight_decay,
                                               self.model.parameters())

        self.loss_manager = LossManager(ablation="none")
        logger.gpu_memory_stats("after trainer initialization")

        self.sampler = NormalPerPoint(
            global_sigma=0.1,  # standard deviation for global sampling
            local_sigma=0.01   # standard deviation for local sampling
        )

        logger.info(f"Initialization complete; processing model {self.model_name}")

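    # Octree construction: build_tree() places the root node over a fixed
    # global bounding box, recursively builds the static tree
    # (build_static_tree), and caches the result via _save_octree() so that
    # later runs can reload it through _load_octree().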
    def build_tree(self, surf_bbox, graph, max_depth=9):
        logger.info("Building octree...")
        num_faces = surf_bbox.shape[0]
        bbox = self._calculate_global_bbox(surf_bbox)
        self.root = OctreeNode(
            bbox=bbox,
            face_indices=np.arange(num_faces),  # initially contains all faces
            patch_graph=graph,
            max_depth=max_depth,
            surf_bbox=surf_bbox,
            surf_ncs=self.data['surf_ncs']
        )
        #print(surf_bbox)
        logger.info("starting octree construction")
        self.root.build_static_tree()
        logger.info("completed octree construction")
        self.root.print_tree()
        self._save_octree()

    def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor:
        """
        Return a fixed global bounding box (the unit cube).

        Args:
            surf_bbox: unused in this implementation

        Returns:
            Tensor of shape (6,) holding the fixed bounding box
            [-0.5, -0.5, -0.5, 0.5, 0.5, 0.5]
        """
        # Directly define the fixed unit-cube bounding box.
        # Note: make sure the tensor is created on the correct device
        # (self.device is assumed to hold the target device).
        fixed_bbox = torch.tensor([-0.5, -0.5, -0.5, 0.5, 0.5, 0.5],
                                  dtype=torch.float32)

        logger.debug(f"Using fixed global bounding box: {fixed_bbox.cpu().numpy()}")
        return fixed_bbox

        # --- Old computation logic (kept for reference) ---
        # # Validate input
        # if not isinstance(surf_bbox, torch.Tensor):
        #     raise TypeError(f"surf_bbox must be a torch.Tensor, got {type(surf_bbox)}")
        # if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6:
        #     raise ValueError(f"surf_bbox should have shape (num_edges, 6), got {surf_bbox.shape}")
        #
        # # Compute the global extent of the surface bounding boxes
        # global_min = surf_bbox[:, :3].min(dim=0).values
        # global_max = surf_bbox[:, 3:].max(dim=0).values
        #
        # # Return the merged bounding box
        # return torch.cat([global_min, global_max])

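    # Gradient-accumulation sketch used below (assumes
    # config.train.accumulation_steps is defined): per-patch losses are summed
    # as loss / accumulation_steps, and one optimizer step is taken every
    # accumulation_steps patches, so the effective batch spans that many
    # surface patches.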
    def train_epoch_stage1_(self, epoch: int):
        total_loss = 0.0
        total_loss_details = {
            "manifold": 0.0,
            "normals": 0.0,
            "eikonal": 0.0,
            "offsurface": 0.0
        }
        accumulated_loss = 0.0  # accumulates the loss over several steps

        # Clear gradients at the start of each epoch.
        self.scheduler.optimizer.zero_grad()

        for step, surf_points in enumerate(self.data['surf_ncs']):
            mnfld_points = torch.tensor(surf_points, device=self.device)
            nonmnfld_pnts = self.sampler.get_points(mnfld_points.unsqueeze(0)).squeeze(0)  # sample off-manifold points
            gt_sdf = torch.zeros(mnfld_points.shape[0], device=self.device)
            normals = None
            if args.use_normal:
                normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device)
                logger.debug(normals)

            # --- Prepare model inputs with gradients enabled ---
            mnfld_points.requires_grad_(True)    # enable gradients after the checks
            nonmnfld_pnts.requires_grad_(True)   # enable gradients after the checks

            # --- Forward pass ---
            mnfld_pred = self.model.forward_training_volumes(mnfld_points, step)
            nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step)

            if self.debug_mode:
                # --- Inspect the forward pass outputs ---
                logger.print_tensor_stats("mnfld_pred", mnfld_pred)
                logger.print_tensor_stats("nonmnfld_pred", nonmnfld_pred)
                logger.gpu_memory_stats("after forward pass")

            # --- Compute loss ---
            try:
                if args.use_normal:
                    loss, loss_details = self.loss_manager.compute_loss(
                        mnfld_points,
                        nonmnfld_pnts,
                        normals,
                        gt_sdf,
                        mnfld_pred,
                        nonmnfld_pred
                    )
                else:
                    loss = torch.nn.functional.mse_loss(mnfld_pred, gt_sdf)
                    loss_details = {}  # keep the variable defined

                # Accumulate the loss instead of backpropagating immediately
                # (assumes config.train.accumulation_steps exists).
                accumulated_loss = accumulated_loss + loss / self.config.train.accumulation_steps
                current_loss = loss.item()
                total_loss += current_loss
                for key in total_loss_details:
                    if key in loss_details:
                        total_loss_details[key] += loss_details[key].item()

                # Backpropagate once the accumulation step count is reached.
                if (step + 1) % self.config.train.accumulation_steps == 0:
                    self.scheduler.optimizer.zero_grad()  # clear gradients
                    accumulated_loss.backward()           # backpropagate the accumulated loss
                    self.scheduler.optimizer.step()       # update parameters
                    self.scheduler.step(accumulated_loss, epoch)
                    accumulated_loss = 0.0                # reset the accumulator

            except Exception as loss_e:
                logger.error(f"Error in step {step}: {loss_e}")
                continue

            # --- Memory management ---
            del loss
            torch.cuda.empty_cache()

        # Flush any remaining loss that did not fill a full accumulation window.
        if torch.is_tensor(accumulated_loss):
            self.scheduler.optimizer.zero_grad()  # clear gradients
            accumulated_loss.backward()           # backpropagate the remainder
            self.scheduler.optimizer.step()       # update parameters
            self.scheduler.step(accumulated_loss, epoch)

        # Log the epoch loss.
        logger.info(f'Train Epoch: {epoch:4d}\t'
                    f'Loss: {total_loss:.6f}')
        logger.info(f"Loss Details: {total_loss_details}")
        return total_loss

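    # Data layout assumed by the epoch methods below: rows of
    # self.train_surf_ncs / self.sdf_data are either
    #   [x, y, z, nx, ny, nz, sdf]  (7 columns, with normals) or
    #   [x, y, z, sdf]              (4 columns, without normals),
    # and the SDF value is always in the last column, so
    #   pnts = data[:, 0:3]; normals = data[:, 3:6]; sdf = data[:, -1].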
    def train_epoch_stage1(self, epoch: int) -> float:
        # --- 1. Check input data ---
        if self.train_surf_ncs is None:
            logger.error(f"Epoch {epoch}: self.train_surf_ncs is None. Cannot train.")
            return float('inf')

        self.model.train()
        total_loss = 0.0
        step = 0  # if training were split across iterations, this would be the batch index
        batch_size = 8192  # choose an appropriate batch size

        # Data preparation (manifold points)
        _mnfld_pnts = self.train_surf_ncs[:, 0:3].clone().detach()  # first 3 columns: points
        _normals = self.train_surf_ncs[:, 3:6].clone().detach()     # middle 3 columns: normals
        _gt_sdf = self.train_surf_ncs[:, -1].clone().detach()       # last column: ground-truth SDF

        # Decide whether the cache needs recomputing.
        if epoch % 10 == 1 or self.cached_train_data is None:
            # Sample off-manifold points.
            _psdf_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals, local_sigma=0.001)
            _nonmnfld_pnts = self.sampler.get_points(_mnfld_pnts, local_sigma=0.01)

            # Refresh the cache.
            self.cached_train_data = {
                "nonmnfld_pnts": _nonmnfld_pnts,
                "psdf_pnts": _psdf_pnts,
                "psdf": _psdf,
            }
        else:
            # Read from the cache.
            _nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"]
            _psdf_pnts = self.cached_train_data["psdf_pnts"]
            _psdf = self.cached_train_data["psdf"]

        logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf))

        # Split the data into batches.
        num_points = self.train_surf_ncs.shape[0]
        num_batches = (num_points + batch_size - 1) // batch_size

        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = min((batch_idx + 1) * batch_size, num_points)

            # Slice out the current batch.
            mnfld_pnts = _mnfld_pnts[start_idx:end_idx]  # manifold points
            gt_sdf = _gt_sdf[start_idx:end_idx]          # ground-truth SDF
            normals = _normals[start_idx:end_idx] if args.use_normal else None  # normals

            # Off-manifold points come from the cache (shared across epochs).
            nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx]
            psdf_pnts = _psdf_pnts[start_idx:end_idx]
            psdf = _psdf[start_idx:end_idx]

            # --- Prepare model inputs with gradients enabled ---
            mnfld_pnts.requires_grad_(True)     # enable gradients after the checks
            nonmnfld_pnts.requires_grad_(True)  # enable gradients after the checks
            psdf_pnts.requires_grad_(True)      # enable gradients after the checks

            # --- Forward pass ---
            mnfld_pred = self.model.forward_background(
                mnfld_pnts
            )
            nonmnfld_pred = self.model.forward_background(
                nonmnfld_pnts
            )
            psdf_pred = self.model.forward_background(
                psdf_pnts
            )

            # --- Compute loss ---
            loss = torch.tensor(float('nan'), device=self.device)  # initialize to NaN in case the computation fails
            loss_details = {}
            try:
                if args.use_normal:
                    logger.gpu_memory_stats("before loss computation")
                    loss, loss_details = self.loss_manager.compute_loss(
                        mnfld_pnts,
                        nonmnfld_pnts,
                        psdf_pnts,
                        normals,
                        gt_sdf,
                        mnfld_pred,
                        nonmnfld_pred,
                        psdf
                    )
                else:
                    loss = torch.nn.functional.mse_loss(mnfld_pred, gt_sdf)

                # --- Check the computed loss ---
                if self.debug_mode:
                    logger.print_tensor_stats("psdf", psdf)
                    logger.print_tensor_stats("nonmnfld_pred", nonmnfld_pred)
                    if check_tensor(loss, "Calculated Loss", epoch, step):
                        logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
                        if loss_details: logger.error(f"Loss Details: {loss_details}")
                        return float('inf')  # stop this epoch on an invalid loss

            except Exception as loss_e:
                logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
                return float('inf')  # stop this epoch on a loss computation error
            logger.gpu_memory_stats("after loss computation")

            # --- Backward pass and optimization ---
            try:
                self.scheduler.optimizer.zero_grad()  # clear gradients
                loss.backward()                       # backpropagate
                self.scheduler.optimizer.step()       # update parameters
                self.scheduler.step(loss, epoch)
            except Exception as backward_e:
                logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
                # To locate the offending op, enable anomaly detection before training:
                # torch.autograd.set_detect_anomaly(True)
                return float('inf')  # stop this epoch on a backward/optimizer error

            # --- Record and accumulate the loss ---
            current_loss = loss.item()
            if not np.isfinite(current_loss):  # double-check the loss is a finite number
                logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
                return float('inf')

            total_loss += current_loss
            del loss
            torch.cuda.empty_cache()

        if epoch % 100 == 0:
            # Log training progress (valid losses only).
            logger.info(f'Train Epoch: {epoch:4d}\t'
                        f'Loss: {current_loss:.6f}')
            if loss_details: logger.info(f"Loss Details: {loss_details}")

        return total_loss

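    # Stage-2 cache policy: the octree forward pass (face-index masks and blend
    # operators from self.root.forward) is recomputed only when
    # epoch % 10 == 1 or no cache exists; all other epochs reuse
    # self.cached_train_data and skip the octree traversal.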
    def train_epoch_stage2(self, epoch: int) -> float:
        # --- 1. Check input data ---
        if self.sdf_data is None:
            logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.")
            return float('inf')

        self.model.train()
        total_loss = 0.0
        step = 0  # if training were split across iterations, this would be the batch index
        batch_size = 8192 * 2  # choose an appropriate batch size

        # Data preparation (manifold points)
        _mnfld_pnts = self.sdf_data[:, 0:3].clone().detach()  # first 3 columns: points
        _normals = self.sdf_data[:, 3:6].clone().detach()     # middle 3 columns: normals
        _gt_sdf = self.sdf_data[:, -1].clone().detach()       # last column: ground-truth SDF

        # Decide whether the cache needs recomputing.
        if epoch % 10 == 1 or self.cached_train_data is None:
            # Compute face-index masks and operators for the manifold points.
            _, _mnfld_face_indices_mask, _mnfld_operator = self.root.forward(_mnfld_pnts)

            # Sample off-manifold points.
            _nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals)
            _, _nonmnfld_face_indices_mask, _nonmnfld_operator = self.root.forward(_nonmnfld_pnts)

            # Refresh the cache.
            self.cached_train_data = {
                "mnfld_face_indices_mask": _mnfld_face_indices_mask,
                "mnfld_operator": _mnfld_operator,
                "nonmnfld_pnts": _nonmnfld_pnts,
                "psdf": _psdf,
                "nonmnfld_face_indices_mask": _nonmnfld_face_indices_mask,
                "nonmnfld_operator": _nonmnfld_operator
            }
        else:
            # Read from the cache.
            _mnfld_face_indices_mask = self.cached_train_data["mnfld_face_indices_mask"]
            _mnfld_operator = self.cached_train_data["mnfld_operator"]
            _nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"]
            _psdf = self.cached_train_data["psdf"]
            _nonmnfld_face_indices_mask = self.cached_train_data["nonmnfld_face_indices_mask"]
            _nonmnfld_operator = self.cached_train_data["nonmnfld_operator"]

        logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf))

        # Split the data into batches.
        num_points = self.sdf_data.shape[0]
        num_batches = (num_points + batch_size - 1) // batch_size

        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = min((batch_idx + 1) * batch_size, num_points)

            # Slice out the current batch.
            mnfld_pnts = _mnfld_pnts[start_idx:end_idx]  # manifold points
            gt_sdf = _gt_sdf[start_idx:end_idx]          # ground-truth SDF
            normals = _normals[start_idx:end_idx] if args.use_normal else None  # normals

            # Off-manifold points come from the cache (shared across epochs).
            nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx]
            psdf = _psdf[start_idx:end_idx]

            # --- Prepare model inputs with gradients enabled ---
            mnfld_pnts.requires_grad_(True)     # enable gradients after the checks
            nonmnfld_pnts.requires_grad_(True)  # enable gradients after the checks

            # --- Forward pass ---
            mnfld_pred = self.model.forward_without_octree(
                mnfld_pnts,
                _mnfld_face_indices_mask[start_idx:end_idx],
                _mnfld_operator[start_idx:end_idx]
            )
            nonmnfld_pred = self.model.forward_without_octree(
                nonmnfld_pnts,
                _nonmnfld_face_indices_mask[start_idx:end_idx],
                _nonmnfld_operator[start_idx:end_idx]
            )

            #logger.print_tensor_stats("psdf", psdf)
            #logger.print_tensor_stats("nonmnfld_pred", nonmnfld_pred)

            # --- Compute loss ---
            loss = torch.tensor(float('nan'), device=self.device)  # initialize to NaN in case the computation fails
            loss_details = {}
            try:
                if args.use_normal:
                    logger.gpu_memory_stats("before loss computation")
                    loss, loss_details = self.loss_manager.compute_loss(
                        mnfld_pnts,
                        nonmnfld_pnts,
                        normals,
                        gt_sdf,
                        mnfld_pred,
                        nonmnfld_pred,
                        psdf
                    )
                else:
                    loss = torch.nn.functional.mse_loss(mnfld_pred, gt_sdf)

                # --- Check the computed loss ---
                if self.debug_mode:
                    if check_tensor(loss, "Calculated Loss", epoch, step):
                        logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
                        if loss_details: logger.error(f"Loss Details: {loss_details}")
                        return float('inf')  # stop this epoch on an invalid loss

            except Exception as loss_e:
                logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
                return float('inf')  # stop this epoch on a loss computation error
            logger.gpu_memory_stats("after loss computation")

            # --- Backward pass and optimization ---
            try:
                self.scheduler.optimizer.zero_grad()  # clear gradients
                loss.backward()                       # backpropagate
                self.scheduler.optimizer.step()       # update parameters
                self.scheduler.step(loss, epoch)
            except Exception as backward_e:
                logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
                # To locate the offending op, enable anomaly detection before training:
                # torch.autograd.set_detect_anomaly(True)
                return float('inf')  # stop this epoch on a backward/optimizer error

            # --- Record and accumulate the loss ---
            current_loss = loss.item()
            if not np.isfinite(current_loss):  # double-check the loss is a finite number
                logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
                return float('inf')

            total_loss += current_loss
            del loss
            torch.cuda.empty_cache()

        # Log training progress (valid losses only).
        logger.info(f'Train Epoch: {epoch:4d}\t'
                    f'Loss: {current_loss:.6f}')
        if loss_details: logger.info(f"Loss Details: {loss_details}")

        return total_loss

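    # train_epoch is the single-stage variant: unlike stage 1
    # (forward_background) and stage 2 (forward_without_octree with cached
    # masks), it runs the full model forward, octree lookup included, and
    # resamples off-manifold points in every batch.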
    def train_epoch(self, epoch: int, resample: bool = True) -> float:
        # --- 1. Check input data ---
        if self.sdf_data is None:
            logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.")
            return float('inf')

        self.model.train()
        total_loss = 0.0
        step = 0  # if training were split across iterations, this would be the batch index
        batch_size = 8192  # choose an appropriate batch size

        # Split the data into batches.
        num_points = self.sdf_data.shape[0]
        num_batches = (num_points + batch_size - 1) // batch_size

        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = min((batch_idx + 1) * batch_size, num_points)
            mnfld_pnts = self.sdf_data[start_idx:end_idx, 0:3].clone().detach()  # first 3 columns: points

            gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach()  # last column: ground-truth SDF
            normals = None
            psdf = None
            if args.use_normal:
                if self.sdf_data.shape[1] < 7:  # ensure there are enough columns for normals
                    logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.")
                    return float('inf')
                normals = self.sdf_data[start_idx:end_idx, 3:6].clone().detach()  # middle 3 columns: normals
                nonmnfld_pnts, psdf = self.sampler.get_norm_points(mnfld_pnts, normals)  # sample off-manifold points
                logger.debug((mnfld_pnts, nonmnfld_pnts, psdf))

            else:
                nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0)).squeeze(0)  # sample off-manifold points
            # Run input checks.
            if self.debug_mode:
                if check_tensor(mnfld_pnts, "Input Points", epoch, step): return float('inf')
                if check_tensor(gt_sdf, "Input GT SDF", epoch, step): return float('inf')
                if args.use_normal:
                    # Only check normals when they were requested.
                    if check_tensor(normals, "Input Normals", epoch, step): return float('inf')
                    logger.debug(normals)
                    logger.print_tensor_stats("normals-x", normals[:, 0])
                    logger.print_tensor_stats("normals-y", normals[:, 1])
                    logger.print_tensor_stats("normals-z", normals[:, 2])

            # --- Prepare model inputs with gradients enabled ---
            mnfld_pnts.requires_grad_(True)     # enable gradients after the checks
            nonmnfld_pnts.requires_grad_(True)  # enable gradients after the checks

            # --- Forward pass ---
            mnfld_pred = self.model(mnfld_pnts)
            nonmnfld_pred = self.model(nonmnfld_pnts)

            if self.debug_mode:
                # --- Inspect the forward pass outputs ---
                logger.print_tensor_stats("mnfld_pred", mnfld_pred)
                logger.print_tensor_stats("nonmnfld_pred", nonmnfld_pred)
                logger.gpu_memory_stats("after forward pass")

            # --- Compute loss ---
            loss = torch.tensor(float('nan'), device=self.device)  # initialize to NaN in case the computation fails
            loss_details = {}
            try:
                if args.use_normal:
                    logger.gpu_memory_stats("before loss computation")
                    loss, loss_details = self.loss_manager.compute_loss(
                        mnfld_pnts,
                        nonmnfld_pnts,
                        normals,
                        gt_sdf,
                        mnfld_pred,
                        nonmnfld_pred,
                        psdf
                    )
                else:
                    loss = torch.nn.functional.mse_loss(mnfld_pred, gt_sdf)

                # --- Check the computed loss ---
                if self.debug_mode:
                    if check_tensor(loss, "Calculated Loss", epoch, step):
                        logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
                        if loss_details: logger.error(f"Loss Details: {loss_details}")
                        return float('inf')  # stop this epoch on an invalid loss

            except Exception as loss_e:
                logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
                return float('inf')  # stop this epoch on a loss computation error
            logger.gpu_memory_stats("after loss computation")

            # --- Backward pass and optimization ---
            try:
                self.scheduler.optimizer.zero_grad()  # clear gradients
                loss.backward()                       # backpropagate
                self.scheduler.optimizer.step()       # update parameters
                self.scheduler.step(loss, epoch)
            except Exception as backward_e:
                logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
                # To locate the offending op, enable anomaly detection before training:
                # torch.autograd.set_detect_anomaly(True)
                return float('inf')  # stop this epoch on a backward/optimizer error

            # --- Record and accumulate the loss ---
            current_loss = loss.item()
            if not np.isfinite(current_loss):  # double-check the loss is a finite number
                logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
                return float('inf')

            total_loss += current_loss
            del loss
            torch.cuda.empty_cache()

        # Log training progress (valid losses only).
        logger.info(f'Train Epoch: {epoch:4d}\t'
                    f'Loss: {current_loss:.6f}')
        if loss_details: logger.info(f"Loss Details: {loss_details}")

        return total_loss

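    # Note: validate() relies on self.val_loader, which __init__ never creates;
    # the corresponding validation block in train() is commented out, so this
    # method is currently unused.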
    def validate(self, epoch: int) -> float:
        self.model.eval()
        total_loss = 0.0

        with torch.no_grad():
            for batch in self.val_loader:
                points = batch['points'].to(self.device)
                gt_sdf = batch['sdf'].to(self.device)

                pred_sdf = self.model(points)
                loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
                total_loss += loss.item()

        avg_loss = total_loss / len(self.val_loader)
        logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}')
        return avg_loss

    def train(self):
        best_val_loss = float('inf')
        logger.info("Starting training...")
        start_time = time.time()
        self.cached_train_data = None

        start_epoch = 1
        if args.resume_checkpoint_path:
            start_epoch = self._load_checkpoint(args.resume_checkpoint_path)
            logger.info(f"Loaded model from {args.resume_checkpoint_path}")

        self.model.encoder.freeze_stage1()
        for epoch in range(start_epoch, self.config.train.num_epochs + 1):
            # Train one epoch.
            train_loss = self.train_epoch_stage1(epoch)
            #train_loss = self.train_epoch_stage2(epoch)
            #train_loss = self.train_epoch(epoch)

            # Validation (currently disabled; see validate() above).
            '''
            if epoch % self.config.train.val_freq == 0:
                val_loss = self.validate(epoch)
                logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}')

                # Save the best model.
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    self._save_model(epoch, val_loss)
                    logger.info(f'New best model saved at epoch {epoch} with val loss {val_loss:.6f}')
            else:
                logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}')
            '''
            # Save a checkpoint.
            if epoch % self.config.train.save_freq == 0:
                self._save_checkpoint(epoch, train_loss)
                logger.info(f'Checkpoint saved at epoch {epoch}')
        self.model.encoder.unfreeze()
        # Training finished.
        total_time = time.time() - start_time

        self._tracing_model_by_script()
        #self._tracing_model()
        logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
        logger.info(f'Best validation loss: {best_val_loss:.6f}')
        #self.test_load()

    def test_load(self):
        # Smoke-test the exported TorchScript module
        # (note: _load_checkpoint returns an epoch number, not a model).
        model = torch.jit.load("/home/wch/brep2sdf/brep2sdf/00000054.pt", map_location=self.device)
        model.eval()
        logger.debug(model)
        example_input = torch.rand(10, 3, device=self.device)
        #logger.debug(model.encoder.octree.bbox)
        logger.debug(f"points: {example_input}")
        sdfs = model(example_input)
        logger.debug(f"sdfs: {sdfs}")

    def _tracing_model_by_script(self):
        """Export the model with TorchScript scripting."""
        self.model.eval()
        # All logic in the model must be TorchScript-compatible.
        scripted_model = torch.jit.script(self.model)
        #optimized_model = optimize_for_mobile(scripted_model)
        torch.jit.save(scripted_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt")

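    # Export note: torch.jit.script compiles the model including
    # data-dependent control flow (e.g. octree traversal), while
    # torch.jit.trace below only records the ops executed for one example
    # input, so traced control flow is frozen to that input's path.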
    def _tracing_model(self):
        """Export the model with TorchScript tracing."""
        self.model.eval()

        # Create an example input.
        example_input = torch.rand(1, 3, device=self.device)

        # Export the model via tracing.
        traced_model = torch.jit.trace(self.model, example_input)

        # Save the model.
        save_path = f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt"
        torch.jit.save(traced_model, save_path)

        # Verify the saved model.
        try:
            loaded_model = torch.jit.load(save_path)
            test_input = torch.rand(1, 3, device=self.device)
            _ = loaded_model(test_input)
            logger.info(f"Model saved and verified successfully: {save_path}")
        except Exception as e:
            logger.error(f"Model verification failed: {e}")

    def _save_checkpoint(self, epoch: int, train_loss: float):
        """Save a training checkpoint."""
        checkpoint_dir = os.path.join(
            self.config.train.checkpoint_dir,
            self.model_name
        )
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth")

        # Save state dicts only.
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.scheduler.optimizer.state_dict(),
            'loss': train_loss,
        }, checkpoint_path)

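    # Resume sketch: _load_checkpoint restores the model and optimizer state
    # saved by _save_checkpoint and returns the next epoch index, e.g.
    #   start_epoch = self._load_checkpoint("checkpoints/00000054/epoch_100.pth")  # hypothetical path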
    def _load_checkpoint(self, checkpoint_path):
        """Restore training state from a checkpoint."""
        try:
            checkpoint = torch.load(checkpoint_path)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.scheduler.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            return checkpoint['epoch'] + 1
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {str(e)}")
            raise

    def _save_octree(self):
        """
        Save the octree to a file.

        The save path is derived from the model name and the checkpoint
        directory in the config.
        """
        checkpoint_dir = os.path.join(
            self.config.train.checkpoint_dir,
            self.model_name
        )
        os.makedirs(checkpoint_dir, exist_ok=True)
        octree_path = os.path.join(checkpoint_dir, "octree.pth")

        try:
            # Save the root node of the octree.
            torch.save(self.root, octree_path)
            logger.info(f"Octree saved to {octree_path}")
        except Exception as e:
            logger.error(f"Failed to save octree: {str(e)}")

    def _load_octree(self) -> bool:
        """
        Load the octree from a file.

        Tries the path derived from the model name and the checkpoint
        directory in the config; returns True on success.
        """
        checkpoint_dir = os.path.join(
            self.config.train.checkpoint_dir,
            self.model_name
        )
        octree_path = os.path.join(checkpoint_dir, "octree.pth")

        try:
            if os.path.exists(octree_path):
                # Load the root node of the octree.
                self.root = torch.load(octree_path, weights_only=False)
                logger.info(f"Octree loaded from {octree_path}")
                return True
            else:
                logger.warning(f"Octree file {octree_path} does not exist; nothing to load.")
        except Exception as e:
            logger.error(f"Failed to load octree: {str(e)}")
        return False


def main():
    # Initialize the config.
    config = get_default_config()
    # Create the trainer and start training.
    trainer = Trainer(config, input_step=args.input)
    trainer.train()


if __name__ == '__main__':
    main()