You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

918 lines
40 KiB

import torch
import torch.optim as optim
import time
import os
import numpy as np
import argparse
from brep2sdf.config.default_config import get_default_config
from brep2sdf.data.data import load_brep_file,prepare_sdf_data, print_data_distribution, check_tensor
from brep2sdf.data.pre_process_by_mesh import process_single_step
from brep2sdf.networks.network import Net
from brep2sdf.networks.octree import OctreeNode
from brep2sdf.networks.loss import LossManager
from brep2sdf.networks.patch_graph import PatchGraph
from brep2sdf.networks.sample import NormalPerPoint
from brep2sdf.networks.learning_rate import LearningRateScheduler
from brep2sdf.utils.logger import logger
# 配置命令行参数
parser = argparse.ArgumentParser(description='STEP文件批量处理工具')
parser.add_argument('-i', '--input', required=True,
help='待处理 brep (.step) 路径')
parser.add_argument(
'--use-normal',
action='store_true', # 默认为 False,如果用户指定该参数,则为 True
help='强制采样点有法向量'
)
parser.add_argument(
'--only-zero-surface',
action='store_true', # 默认为 False,如果用户指定该参数,则为 True
help='只采样零表面点 SDF 训练'
)
parser.add_argument(
'--force-reprocess','-f',
action='store_true', # 默认为 False,如果用户指定该参数,则为 True
help='强制重新进行数据预处理,忽略缓存或已有结果'
)
parser.add_argument(
'--resume-checkpoint-path',
type=str,
default=None,
help='从指定的checkpoint文件继续训练'
)
parser.add_argument(
'--octree-cuda',
action='store_true', # 默认为 False,如果用户指定该参数,则为 True
help='使用CUDA加速Octree构建'
)
args = parser.parse_args()
class Trainer:
def __init__(self, config, input_step):
logger.gpu_memory_stats("初始化开始")
self.config = config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.debug_mode = config.train.debug_mode
self.model_name = os.path.basename(input_step).split('_')[0]
self.base_name = self.model_name + ".xyz"
data_path = os.path.join("/home/wch/brep2sdf/data/output_data",self.base_name)
if os.path.exists(data_path) and not args.force_reprocess:
try:
self.data = load_brep_file(data_path)
except Exception as e:
logger.error(f"fail to load {data_path}, {str(e)}")
raise e
if args.use_normal and self.data.get("surf_pnt_normals", None) is None:
self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal,sample_sdf_points=not args.only_zero_surface)
else:
self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal,sample_sdf_points=not args.only_zero_surface)
logger.gpu_memory_stats("数据预处理后")
self.train_surf_ncs = torch.tensor(self.data["train_surf_ncs"],dtype=torch.float32,device=self.device) #
# 将曲面点云列表转换为 (N*M, 4) 数组
surfs = self.data["surf_ncs"]
# 准备表面点的SDF数据
surface_sdf_data = prepare_sdf_data(
surfs,
normals=self.data["surf_pnt_normals"],
max_points=50000,
device=self.device
)
# 如果不是仅使用零表面,则合并采样点数据
if not args.only_zero_surface:
# 加载采样点数据
sampled_sdf_data = torch.tensor(
self.data['sampled_points_normals_sdf'],
dtype=torch.float32,
device=self.device
)
# 合并表面点数据和采样点数据
self.sdf_data = torch.cat([surface_sdf_data, sampled_sdf_data], dim=0)
else:
self.sdf_data = surface_sdf_data
print_data_distribution(self.sdf_data)
logger.debug(self.sdf_data.shape)
logger.gpu_memory_stats("SDF数据准备后")
# 初始化数据集
#self.brep_data = load_brep_file(self.config.data.pkl_path)
#logger.info( self.brep_data )
#self.sdf_data = load_sdf_file(sdf_path=self.config.data.sdf_path, num_query_points=self.config.data.num_query_points).to(self.device)
# 构建面片邻接图
graph = PatchGraph.from_preprocessed_data(
surf_ncs=self.data['surf_ncs'],
edgeFace_adj=self.data['edgeFace_adj'],
edge_types=self.data['edge_types'],
device='cuda' if args.octree_cuda else 'cpu'
)
# 初始化网络
surf_bbox=torch.tensor(
self.data['surf_bbox_ncs'],
dtype=torch.float32,
device=self.device
)
max_depth = config.model.octree_max_depth
if not args.force_reprocess:
if not self._load_octree():
self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=max_depth)
elif self.root.max_depth != max_depth:
self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=max_depth)
else:
self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=max_depth)
logger.gpu_memory_stats("树初始化后")
self.model = Net(
octree=self.root,
volume_bboxs=surf_bbox,
feature_dim=64
).to(self.device)
logger.gpu_memory_stats("模型初始化后")
self.scheduler = LearningRateScheduler(config.train.learning_rate_schedule, config.train.weight_decay, self.model.parameters())
self.loss_manager = LossManager(ablation="none")
logger.gpu_memory_stats("训练器初始化后")
self.sampler = NormalPerPoint(
global_sigma=0.1, # 全局采样标准差
local_sigma=0.01 # 局部采样标准差
)
logger.info(f"初始化完成,正在处理模型 {self.model_name}")
def build_tree(self,surf_bbox, graph, max_depth=9):
logger.info("开始构造八叉树...")
num_faces = surf_bbox.shape[0]
bbox = self._calculate_global_bbox(surf_bbox)
self.root = OctreeNode(
bbox=bbox,
face_indices=np.arange(num_faces), # 初始包含所有面
patch_graph=graph,
max_depth=max_depth,
surf_bbox=surf_bbox,
surf_ncs=self.data['surf_ncs']
)
#print(surf_bbox)
logger.info("starting octree conduction")
self.root.build_static_tree()
logger.info("complete octree conduction")
self.root.print_tree()
self._save_octree()
def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor:
"""
返回一个固定的全局边界框(单位立方体)。
参数:
surf_bbox: (此参数在此实现中未使用)
返回:
形状为 (6,) 的Tensor,表示固定的边界框 [-0.5, -0.5, -0.5, 0.5, 0.5, 0.5]
"""
# 直接定义固定的单位立方体边界框
# 注意:确保张量在正确的设备上创建
fixed_bbox = torch.tensor([-0.5, -0.5, -0.5, 0.5, 0.5, 0.5],
dtype=torch.float32) # 假设 self.device 存储了目标设备
logger.debug(f"使用固定的全局边界框: {fixed_bbox.cpu().numpy()}")
return fixed_bbox
# --- 旧的计算逻辑 (注释掉或删除) ---
# # 验证输入
# if not isinstance(surf_bbox, torch.Tensor):
# raise TypeError(f"surf_bbox 必须是 torch.Tensor,但得到 {type(surf_bbox)}")
# if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6:
# raise ValueError(f"surf_bbox 形状应为 (num_edges, 6),但得到 {surf_bbox.shape}")
#
# # 计算表面包围盒的全局范围
# global_min = surf_bbox[:, :3].min(dim=0).values
# global_max = surf_bbox[:, 3:].max(dim=0).values
#
# # 返回合并后的边界框
# return torch.cat([global_min, global_max])
# return [-0.5,] # 这个是错误的
def train_epoch_stage1_(self, epoch: int):
total_loss = 0.0
total_loss_details = {
"manifold": 0.0,
"normals": 0.0,
"eikonal": 0.0,
"offsurface": 0.0
}
accumulated_loss = 0.0 # 新增:用于累积多个step的loss
# 新增:在每个epoch开始时清零梯度
self.optimizer.zero_grad()
for step, surf_points in enumerate(self.data['surf_ncs']):
mnfld_points = torch.tensor(surf_points, device=self.device)
nonmnfld_pnts = self.sampler.get_points(mnfld_points.unsqueeze(0)).squeeze(0) # 生成非流形点
gt_sdf = torch.zeros(mnfld_points.shape[0], device=self.device)
normals = None
if args.use_normal:
normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device)
logger.debug(normals)
# --- 准备模型输入,启用梯度 ---
mnfld_points.requires_grad_(True) # 在检查之后启用梯度
nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度
# --- 前向传播 ---
self.optimizer.zero_grad()
mnfld_pred = self.model.forward_training_volumes(mnfld_points, step)
nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step)
if self.debug_mode:
# --- 检查前向传播的输出 ---
logger.print_tensor_stats("mnfld_pred",mnfld_pred)
logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
logger.gpu_memory_stats("前向传播后")
# --- 计算损失 ---
try:
if args.use_normal:
loss, loss_details = self.loss_manager.compute_loss(
mnfld_points,
nonmnfld_pnts,
normals,
gt_sdf,
mnfld_pred,
nonmnfld_pred
)
else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
loss_details = {} # 确保变量初始化
# 修改:累积loss而不是立即backward
accumulated_loss += loss / self.config.train.accumulation_steps # 假设配置中有accumulation_steps
current_loss = loss.item()
total_loss += current_loss
for key in total_loss_details:
if key in loss_details:
total_loss_details[key] += loss_details[key].item()
# 新增:达到累积步数时执行反向传播
if (step + 1) % self.config.train.accumulation_steps == 0:
# 反向传播
self.scheduler.optimizer.zero_grad() # 清空梯度
loss.backward() # 反向传播
self.scheduler.optimizer.step() # 更新参数
self.scheduler.step(accumulated_loss,epoch)
# 记录日志保持不变 ...
except Exception as loss_e:
logger.error(f"Error in step {step}: {loss_e}")
continue
# --- 内存管理 ---
del loss
torch.cuda.empty_cache()
# 新增:处理最后未达到累积步数的剩余loss
if accumulated_loss != 0:
# 反向传播
self.scheduler.optimizer.zero_grad() # 清空梯度
loss.backward() # 反向传播
self.scheduler.optimizer.step() # 更新参数
self.scheduler.step(accumulated_loss,epoch)
# 计算并记录epoch损失
logger.info(f'Train Epoch: {epoch:4d}]\t'
f'Loss: {total_loss:.6f}')
logger.info(f"Loss Details: {total_loss_details}")
return total_loss # 返回平均损失而非累计值
def train_epoch_stage1(self, epoch: int) -> float:
# --- 1. 检查输入数据 ---
# 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列)
# 并且 SDF 值总是在最后一列
if self.train_surf_ncs is None:
logger.error(f"Epoch {epoch}: self.train_surf_ncs is None. Cannot train.")
return float('inf')
self.model.train()
total_loss = 0.0
step = 0 # 如果你的训练是分批次的,这里应该用批次索引
batch_size = 8192 # 设置合适的batch大小
# 数据处理
# manfld
_mnfld_pnts = self.train_surf_ncs[:, 0:3].clone().detach() # 取前3列作为点
_normals = self.train_surf_ncs[:, 3:6].clone().detach() # 取中间3列作为法线
_gt_sdf = self.train_surf_ncs[:, -1].clone().detach() # 取最后一列作为SDF真值
# 检查是否需要重新计算缓存
if epoch % 10 == 1 or self.cached_train_data is None:
# 计算流形点的掩码和操作符
# 生成非流形点
_psdf_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals,local_sigma=0.001)
_nonmnfld_pnts = self.sampler.get_points(_mnfld_pnts, local_sigma=0.01):
# 更新缓存
self.cached_train_data = {
"nonmnfld_pnts": _nonmnfld_pnts,
"psdf_pnts": _psdf_pnts,
"psdf": _psdf,
}
else:
# 从缓存中读取数据
_nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"]
_psdf_pnts = self.cached_train_data["psdf_pnts"]
_psdf = self.cached_train_data["psdf"]
logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf))
# 将数据分成多个batch
num_points = self.train_surf_ncs.shape[0]
num_batches = (num_points + batch_size - 1) // batch_size
for batch_idx in range(num_batches):
start_idx = batch_idx * batch_size
end_idx = min((batch_idx + 1) * batch_size, num_points)
# 获取当前batch的数据
mnfld_pnts = _mnfld_pnts[start_idx:end_idx] # 流形点
gt_sdf = _gt_sdf[start_idx:end_idx] # SDF真值
normals = _normals[start_idx:end_idx] if args.use_normal else None # 法线
# 非流形点使用缓存数据(整个batch共享)
nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx]
psdf_pnts = _psdf_pnts[start_idx:end_idx]
psdf = _psdf[start_idx:end_idx]
# --- 准备模型输入,启用梯度 ---
mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度
nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度
psdf_pnts.requires_grad_(True) # 在检查之后启用梯度
# --- 前向传播 ---
mnfld_pred = self.model.forward_background(
mnfld_pnts
)
nonmnfld_pred = self.model.forward_background(
nonmnfld_pnts
)
psdf_pred = self.model.forward_background(
psdf_pnts
)
# --- 计算损失 ---
loss = torch.tensor(float('nan'), device=self.device) # 初始化为 NaN 以防计算失败
loss_details = {}
try:
# --- 3. 检查损失计算前的输入 ---
# (points 已经启用梯度,不需要再次检查inf/nan,但检查pred_sdf和gt_sdf)
#if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss")
#if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss")
if args.use_normal:
# 检查法线和带梯度的点
#if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
#if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
logger.gpu_memory_stats("计算损失前")
loss, loss_details = self.loss_manager.compute_loss(
mnfld_pnts,
nonmnfld_pnts,
psdf_pnts,
normals, # 传递检查过的 normals
gt_sdf,
mnfld_pred,
nonmnfld_pred,
psdf
)
else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
# --- 4. 检查损失计算结果 ---
if self.debug_mode:
logger.print_tensor_stats("psdf",psdf)
logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
if check_tensor(loss, "Calculated Loss", epoch, step):
logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
if loss_details: logger.error(f"Loss Details: {loss_details}")
return float('inf') # 如果损失无效,停止这个epoch
except Exception as loss_e:
logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
return float('inf') # 如果计算出错,停止这个epoch
logger.gpu_memory_stats("损失计算后")
# --- 反向传播和优化 ---
try:
# 反向传播
self.scheduler.optimizer.zero_grad() # 清空梯度
loss.backward() # 反向传播
self.scheduler.optimizer.step() # 更新参数
self.scheduler.step(loss,epoch)
except Exception as backward_e:
logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
# 如果你想看是哪个操作导致的,可以启用 anomaly detection
# torch.autograd.set_detect_anomaly(True) # 放在训练开始前
return float('inf') # 如果反向传播或优化出错,停止这个epoch
# --- 记录和累加损失 ---
current_loss = loss.item()
if not np.isfinite(current_loss): # 再次确认损失是有效的数值
logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
return float('inf')
total_loss += current_loss
del loss
torch.cuda.empty_cache()
if epoch % 100 == 0:
# 记录训练进度 (只记录有效的损失)
logger.info(f'Train Epoch: {epoch:4d}]\t'
f'Loss: {current_loss:.6f}')
if loss_details: logger.info(f"Loss Details: {loss_details}")
return total_loss # 对于单批次训练,直接返回当前损失
def train_epoch_stage2(self, epoch: int) -> float:
# --- 1. 检查输入数据 ---
# 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列)
# 并且 SDF 值总是在最后一列
if self.sdf_data is None:
logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.")
return float('inf')
self.model.train()
total_loss = 0.0
step = 0 # 如果你的训练是分批次的,这里应该用批次索引
batch_size = 8192 * 2 # 设置合适的batch大小
# 数据处理
# manfld
_mnfld_pnts = self.sdf_data[:, 0:3].clone().detach() # 取前3列作为点
_normals = self.sdf_data[:, 3:6].clone().detach() # 取中间3列作为法线
_gt_sdf = self.sdf_data[:, -1].clone().detach() # 取最后一列作为SDF真值
# 检查是否需要重新计算缓存
if epoch % 10 == 1 or self.cached_train_data is None:
# 计算流形点的掩码和操作符
_, _mnfld_face_indices_mask, _mnfld_operator = self.root.forward(_mnfld_pnts)
# 生成非流形点
_nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals)
_, _nonmnfld_face_indices_mask, _nonmnfld_operator = self.root.forward(_nonmnfld_pnts)
# 更新缓存
self.cached_train_data = {
"mnfld_face_indices_mask": _mnfld_face_indices_mask,
"mnfld_operator": _mnfld_operator,
"nonmnfld_pnts": _nonmnfld_pnts,
"psdf": _psdf,
"nonmnfld_face_indices_mask": _nonmnfld_face_indices_mask,
"nonmnfld_operator": _nonmnfld_operator
}
else:
# 从缓存中读取数据
_mnfld_face_indices_mask = self.cached_train_data["mnfld_face_indices_mask"]
_mnfld_operator = self.cached_train_data["mnfld_operator"]
_nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"]
_psdf = self.cached_train_data["psdf"]
_nonmnfld_face_indices_mask = self.cached_train_data["nonmnfld_face_indices_mask"]
_nonmnfld_operator = self.cached_train_data["nonmnfld_operator"]
logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf))
# 将数据分成多个batch
num_points = self.sdf_data.shape[0]
num_batches = (num_points + batch_size - 1) // batch_size
for batch_idx in range(num_batches):
start_idx = batch_idx * batch_size
end_idx = min((batch_idx + 1) * batch_size, num_points)
# 获取当前batch的数据
mnfld_pnts = _mnfld_pnts[start_idx:end_idx] # 流形点
gt_sdf = _gt_sdf[start_idx:end_idx] # SDF真值
normals = _normals[start_idx:end_idx] if args.use_normal else None # 法线
# 非流形点使用缓存数据(整个batch共享)
nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx]
psdf = _psdf[start_idx:end_idx]
# --- 准备模型输入,启用梯度 ---
mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度
nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度
# --- 前向传播 ---
mnfld_pred = self.model.forward_without_octree(
mnfld_pnts,
_mnfld_face_indices_mask[start_idx:end_idx],
_mnfld_operator[start_idx:end_idx]
)
nonmnfld_pred = self.model.forward_without_octree(
nonmnfld_pnts,
_nonmnfld_face_indices_mask[start_idx:end_idx],
_nonmnfld_operator[start_idx:end_idx]
)
#logger.print_tensor_stats("psdf",psdf)
#logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
# --- 计算损失 ---
loss = torch.tensor(float('nan'), device=self.device) # 初始化为 NaN 以防计算失败
loss_details = {}
try:
# --- 3. 检查损失计算前的输入 ---
# (points 已经启用梯度,不需要再次检查inf/nan,但检查pred_sdf和gt_sdf)
#if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss")
#if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss")
if args.use_normal:
# 检查法线和带梯度的点
#if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
#if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
logger.gpu_memory_stats("计算损失前")
loss, loss_details = self.loss_manager.compute_loss(
mnfld_pnts,
nonmnfld_pnts,
normals, # 传递检查过的 normals
gt_sdf,
mnfld_pred,
nonmnfld_pred,
psdf
)
else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
# --- 4. 检查损失计算结果 ---
if self.debug_mode:
if check_tensor(loss, "Calculated Loss", epoch, step):
logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
if loss_details: logger.error(f"Loss Details: {loss_details}")
return float('inf') # 如果损失无效,停止这个epoch
except Exception as loss_e:
logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
return float('inf') # 如果计算出错,停止这个epoch
logger.gpu_memory_stats("损失计算后")
# --- 反向传播和优化 ---
try:
# 反向传播
self.scheduler.optimizer.zero_grad() # 清空梯度
loss.backward() # 反向传播
self.scheduler.optimizer.step() # 更新参数
self.scheduler.step(loss,epoch)
except Exception as backward_e:
logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
# 如果你想看是哪个操作导致的,可以启用 anomaly detection
# torch.autograd.set_detect_anomaly(True) # 放在训练开始前
return float('inf') # 如果反向传播或优化出错,停止这个epoch
# --- 记录和累加损失 ---
current_loss = loss.item()
if not np.isfinite(current_loss): # 再次确认损失是有效的数值
logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
return float('inf')
total_loss += current_loss
del loss
torch.cuda.empty_cache()
# 记录训练进度 (只记录有效的损失)
logger.info(f'Train Epoch: {epoch:4d}]\t'
f'Loss: {current_loss:.6f}')
if loss_details: logger.info(f"Loss Details: {loss_details}")
return total_loss # 对于单批次训练,直接返回当前损失
def train_epoch(self, epoch: int,resample:bool=True) -> float:
# --- 1. 检查输入数据 ---
# 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列)
# 并且 SDF 值总是在最后一列
if self.sdf_data is None:
logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.")
return float('inf')
self.model.train()
total_loss = 0.0
step = 0 # 如果你的训练是分批次的,这里应该用批次索引
batch_size = 8192 # 设置合适的batch大小
# 将数据分成多个batch
num_points = self.sdf_data.shape[0]
num_batches = (num_points + batch_size - 1) // batch_size
for batch_idx in range(num_batches):
start_idx = batch_idx * batch_size
end_idx = min((batch_idx + 1) * batch_size, num_points)
mnfld_pnts = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # 取前3列作为点
gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach() # 取最后一列作为SDF真值
normals = None
if args.use_normal:
if self.sdf_data.shape[1] < 7: # 检查是否有足够的列给法线
logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.")
return float('inf')
normals = self.sdf_data[start_idx:end_idx, 3:6].clone().detach() # 取中间3列作为法线
nonmnfld_pnts,psdf = self.sampler.get_norm_points(mnfld_pnts,normals) # 生成非流形点
logger.debug((mnfld_pnts,nonmnfld_pnts,psdf))
else:
nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0)).squeeze(0) # 生成非流形点
# 执行检查
if self.debug_mode:
if check_tensor(mnfld_pnts, "Input Points", epoch, step): return float('inf')
if check_tensor(gt_sdf, "Input GT SDF", epoch, step): return float('inf')
if args.use_normal:
# 只有在请求法线时才检查 normals
if check_tensor(normals, "Input Normals", epoch, step): return float('inf')
logger.debug(normals)
logger.print_tensor_stats("normals-x",normals[0])
logger.print_tensor_stats("normals-y",normals[1])
logger.print_tensor_stats("normals-z",normals[2])
# --- 准备模型输入,启用梯度 ---
mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度
nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度
# --- 前向传播 ---
mnfld_pred = self.model(mnfld_pnts)
nonmnfld_pred = self.model(nonmnfld_pnts)
if self.debug_mode:
# --- 检查前向传播的输出 ---
logger.print_tensor_stats("mnfld_pred",mnfld_pred)
logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
logger.gpu_memory_stats("前向传播后")
# --- 2. 检查模型输出 ---
#if check_tensor(pred_sdf, "Predicted SDF (Model Output)", epoch, step): return float('inf')
# --- 计算损失 ---
loss = torch.tensor(float('nan'), device=self.device) # 初始化为 NaN 以防计算失败
loss_details = {}
try:
# --- 3. 检查损失计算前的输入 ---
# (points 已经启用梯度,不需要再次检查inf/nan,但检查pred_sdf和gt_sdf)
#if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss")
#if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss")
if args.use_normal:
# 检查法线和带梯度的点
#if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
#if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
logger.gpu_memory_stats("计算损失前")
loss, loss_details = self.loss_manager.compute_loss(
mnfld_pnts,
nonmnfld_pnts,
normals, # 传递检查过的 normals
gt_sdf,
mnfld_pred,
nonmnfld_pred,
psdf
)
else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
# --- 4. 检查损失计算结果 ---
if self.debug_mode:
if check_tensor(loss, "Calculated Loss", epoch, step):
logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
if loss_details: logger.error(f"Loss Details: {loss_details}")
return float('inf') # 如果损失无效,停止这个epoch
except Exception as loss_e:
logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
return float('inf') # 如果计算出错,停止这个epoch
logger.gpu_memory_stats("损失计算后")
# --- 反向传播和优化 ---
try:
# 反向传播
self.scheduler.optimizer.zero_grad() # 清空梯度
loss.backward() # 反向传播
self.scheduler.optimizer.step() # 更新参数
self.scheduler.step(loss,epoch)
except Exception as backward_e:
logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
# 如果你想看是哪个操作导致的,可以启用 anomaly detection
# torch.autograd.set_detect_anomaly(True) # 放在训练开始前
return float('inf') # 如果反向传播或优化出错,停止这个epoch
# --- 记录和累加损失 ---
current_loss = loss.item()
if not np.isfinite(current_loss): # 再次确认损失是有效的数值
logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
return float('inf')
total_loss += current_loss
del loss
torch.cuda.empty_cache()
# 记录训练进度 (只记录有效的损失)
logger.info(f'Train Epoch: {epoch:4d}]\t'
f'Loss: {current_loss:.6f}')
if loss_details: logger.info(f"Loss Details: {loss_details}")
return total_loss # 对于单批次训练,直接返回当前损失
def validate(self, epoch: int) -> float:
self.model.eval()
total_loss = 0.0
with torch.no_grad():
for batch in self.val_loader:
points = batch['points'].to(self.device)
gt_sdf = batch['sdf'].to(self.device)
pred_sdf = self.model(points)
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
total_loss += loss.item()
avg_loss = total_loss / len(self.val_loader)
logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}')
return avg_loss
def train(self):
best_val_loss = float('inf')
logger.info("Starting training...")
start_time = time.time()
self.cached_train_data=None
start_epoch = 1
if args.resume_checkpoint_path:
start_epoch = self._load_checkpoint(args.resume_checkpoint_path)
logger.info(f"Loaded model from {args.resume_checkpoint_path}")
self.model.encoder.freeze_stage1()
for epoch in range(start_epoch, self.config.train.num_epochs + 1):
# 训练一个epoch
train_loss = self.train_epoch_stage1(epoch)
#train_loss = self.train_epoch_stage2(epoch)
#train_loss = self.train_epoch(epoch)
# 验证
'''
if epoch % self.config.train.val_freq == 0:
val_loss = self.validate(epoch)
logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}')
# 保存最佳模型
if val_loss < best_val_loss:
best_val_loss = val_loss
self._save_model(epoch, val_loss)
logger.info(f'New best model saved at epoch {epoch} with val loss {val_loss:.6f}')
else:
logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}')
'''
# 保存检查点
if epoch % self.config.train.save_freq == 0:
self._save_checkpoint(epoch, train_loss)
logger.info(f'Checkpoint saved at epoch {epoch}')
self.model.encoder.unfreeze()
# 训练完成
total_time = time.time() - start_time
self._tracing_model_by_script()
#self._tracing_model()
logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
logger.info(f'Best validation loss: {best_val_loss:.6f}')
#self.test_load()
def test_load(self):
model = self._load_checkpoint("/home/wch/brep2sdf/brep2sdf/00000054.pt")
model.eval()
logger.debug(model)
example_input = torch.rand(10, 3, device=self.device)
#logger.debug(model.encoder.octree.bbox)
logger.debug(f"points: {example_input}")
sdfs= model(example_input)
logger.debug(f"sdfs:{sdfs}")
def _tracing_model_by_script(self):
"""保存模型"""
self.model.eval()
# 确保模型中的所有逻辑都兼容 TorchScript
scripted_model = torch.jit.script(self.model)
#optimized_model = optimize_for_mobile(scripted_model)
torch.jit.save(scripted_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt")
def _tracing_model(self):
"""保存模型"""
self.model.eval()
# 创建示例输入
example_input = torch.rand(1, 3, device=self.device)
# 使用 trace 方式导出模型
traced_model = torch.jit.trace(self.model, example_input)
# 保存模型
save_path = f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt"
torch.jit.save(traced_model, save_path)
# 验证保存的模型
try:
loaded_model = torch.jit.load(save_path)
test_input = torch.rand(1, 3, device=self.device)
_ = loaded_model(test_input)
logger.info(f"模型已保存并验证成功:{save_path}")
except Exception as e:
logger.error(f"模型验证失败:{e}")
def _save_checkpoint(self, epoch: int, train_loss: float):
"""保存训练检查点"""
checkpoint_dir = os.path.join(
self.config.train.checkpoint_dir,
self.model_name
)
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth")
# 只保存状态字典
torch.save({
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.scheduler.optimizer.state_dict(),
'loss': train_loss,
}, checkpoint_path)
def _load_checkpoint(self, checkpoint_path):
"""从检查点恢复训练状态"""
try:
checkpoint = torch.load(checkpoint_path)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.scheduler.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return checkpoint['epoch'] + 1
except Exception as e:
logger.error(f"加载checkpoint失败: {str(e)}")
raise
def _save_octree(self):
"""
保存八叉树到文件。
八叉树保存路径基于模型名称和配置中的检查点目录。
"""
checkpoint_dir = os.path.join(
self.config.train.checkpoint_dir,
self.model_name
)
os.makedirs(checkpoint_dir, exist_ok=True)
octree_path = os.path.join(checkpoint_dir, "octree.pth")
try:
# 保存八叉树的根节点
torch.save(self.root, octree_path)
logger.info(f"八叉树已保存到 {octree_path}")
except Exception as e:
logger.error(f"保存八叉树失败: {str(e)}")
def _load_octree(self)->bool:
"""
从文件加载八叉树。
尝试从基于模型名称和配置检查点目录的路径加载八叉树。
"""
checkpoint_dir = os.path.join(
self.config.train.checkpoint_dir,
self.model_name
)
octree_path = os.path.join(checkpoint_dir, "octree.pth")
try:
if os.path.exists(octree_path):
# 加载八叉树的根节点
self.root = torch.load(octree_path, weights_only=False)
logger.info(f"八叉树已从 {octree_path} 加载")
return True
else:
logger.warning(f"八叉树文件 {octree_path} 不存在,无法加载。")
except Exception as e:
logger.error(f"加载八叉树失败: {str(e)}")
return False
def main():
# 这里需要初始化配置
config = get_default_config()
# 初始化训练器并开始训练
trainer = Trainer(config, input_step=args.input)
trainer.train()
if __name__ == '__main__':
main()