# NOTE: the original capture of this file began with web-UI scrape residue
# ("You can not select more than 25 topics…", "456 lines", "19 KiB");
# that text was repository-viewer chrome, not part of the source.
import torch
|
|
import torch.optim as optim
|
|
import time
|
|
import os
|
|
import numpy as np
|
|
import argparse
|
|
|
|
from brep2sdf.config.default_config import get_default_config
|
|
from brep2sdf.data.data import load_brep_file,prepare_sdf_data, print_data_distribution, check_tensor
|
|
from brep2sdf.data.pre_process_by_mesh import process_single_step
|
|
from brep2sdf.networks.network import Net
|
|
from brep2sdf.networks.octree import OctreeNode
|
|
from brep2sdf.networks.loss import LossManager
|
|
from brep2sdf.networks.patch_graph import PatchGraph
|
|
from brep2sdf.utils.logger import logger
|
|
|
|
|
|
# Command-line argument configuration
parser = argparse.ArgumentParser(description='STEP文件批量处理工具')
parser.add_argument('-i', '--input', required=True,
                    help='待处理 brep (.step) 路径')
parser.add_argument(
    '--use-normal',
    action='store_true',  # defaults to False; becomes True only when the flag is given
    help='强制采样点有法向量'
)
parser.add_argument(
    '--only-zero-surface',
    action='store_true',  # defaults to False; becomes True only when the flag is given
    help='只采样零表面点 SDF 训练'
)
parser.add_argument(
    '--force-reprocess','-f',
    action='store_true',  # defaults to False; becomes True only when the flag is given
    help='强制重新进行数据预处理,忽略缓存或已有结果'
)
parser.add_argument(
    '--resume-checkpoint-path',
    type=str,
    default=None,
    help='从指定的checkpoint文件继续训练'
)

parser.add_argument(
    '--octree-cuda',
    action='store_true',  # defaults to False; becomes True only when the flag is given
    help='使用CUDA加速Octree构建'
)

# NOTE: parsed at import time — this module is intended to be run as a script,
# and `args` is read as a module-level global throughout the Trainer class below.
args = parser.parse_args()
|
|
|
|
|
|
|
|
class Trainer:
    """Trains an octree-based SDF network for a single B-rep (STEP) model.

    Pipeline: load (or preprocess) the B-rep sample data, build the patch
    adjacency graph and the static octree, construct the Net model, then run
    batched SDF regression and finally export the model via TorchScript.

    Reads the module-level ``args`` (command-line flags) as a global.
    """

    def __init__(self, config, input_step):
        """Prepare data, octree, model and optimizer for ``input_step``.

        Args:
            config: project config object (uses ``config.train.*`` fields).
            input_step: path to the input ``.step`` file; the model name is
                taken as the basename prefix before the first underscore.
        """
        logger.gpu_memory_stats("初始化开始")
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.debug_mode = config.train.debug_mode

        self.model_name = os.path.basename(input_step).split('_')[0]
        self.base_name = self.model_name + ".xyz"
        # NOTE(review): hard-coded output directory — consider moving into config.
        data_path = os.path.join("/home/wch/brep2sdf/data/output_data", self.base_name)

        if os.path.exists(data_path) and not args.force_reprocess:
            try:
                self.data = load_brep_file(data_path)
            except Exception as e:
                logger.error(f"fail to load {data_path}, {str(e)}")
                raise e
            # A cached file without normals is unusable when --use-normal is
            # set, so fall through to reprocessing in that case.
            if args.use_normal and self.data.get("surf_pnt_normals", None) is None:
                self.data = process_single_step(
                    step_path=input_step, output_path=data_path,
                    sample_normal_vector=args.use_normal,
                    sample_sdf_points=not args.only_zero_surface)
        else:
            # No cache (or forced): run the full preprocessing step.
            self.data = process_single_step(
                step_path=input_step, output_path=data_path,
                sample_normal_vector=args.use_normal,
                sample_sdf_points=not args.only_zero_surface)

        logger.gpu_memory_stats("数据预处理后")

        # Surface point clouds in normalized coordinate space.
        surfs = self.data["surf_ncs"]

        # SDF training rows for on-surface points (sdf == 0).
        surface_sdf_data = prepare_sdf_data(
            surfs,
            normals=self.data["surf_pnt_normals"],
            max_points=50000,
            device=self.device
        )

        # Unless restricted to the zero surface, append the off-surface
        # sampled points (with their SDF values) to the training set.
        if not args.only_zero_surface:
            sampled_sdf_data = torch.tensor(
                self.data['sampled_points_normals_sdf'],
                dtype=torch.float32,
                device=self.device
            )
            self.sdf_data = torch.cat([surface_sdf_data, sampled_sdf_data], dim=0)
        else:
            self.sdf_data = surface_sdf_data
        print_data_distribution(self.sdf_data)
        logger.gpu_memory_stats("SDF数据准备后")

        # Build the face-patch adjacency graph used by the octree.
        graph = PatchGraph.from_preprocessed_data(
            surf_wcs=self.data['surf_wcs'],
            edgeFace_adj=self.data['edgeFace_adj'],
            edge_types=self.data['edge_types'],
            device='cuda' if args.octree_cuda else 'cpu'
        )

        # Per-surface bounding boxes in normalized coordinates.
        surf_bbox = torch.tensor(
            self.data['surf_bbox_ncs'],
            dtype=torch.float32,
            device=self.device
        )

        self.build_tree(surf_bbox=surf_bbox, graph=graph, max_depth=6)
        logger.gpu_memory_stats("数初始化后")

        self.model = Net(
            octree=self.root,
            volume_bboxs=surf_bbox,
            feature_dim=64
        ).to(self.device)
        logger.gpu_memory_stats("模型初始化后")

        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.train.learning_rate,
            weight_decay=config.train.weight_decay
        )

        self.loss_manager = LossManager(ablation="none")
        logger.gpu_memory_stats("训练器初始化后")

        logger.info(f"初始化完成,正在处理模型 {self.model_name}")

    def build_tree(self, surf_bbox, graph, max_depth=9):
        """Build the static octree over all faces and store it as ``self.root``.

        Args:
            surf_bbox: (num_faces, 6) per-surface bounding boxes.
            graph: PatchGraph of face adjacency.
            max_depth: maximum octree subdivision depth.
        """
        num_faces = surf_bbox.shape[0]
        bbox = self._calculate_global_bbox(surf_bbox)
        self.root = OctreeNode(
            bbox=bbox,
            face_indices=np.arange(num_faces),  # start with every face in the root
            patch_graph=graph,
            max_depth=max_depth,
            surf_bbox=surf_bbox,
            surf_ncs=self.data['surf_ncs']
        )
        logger.info("starting octree conduction")
        self.root.build_static_tree()
        logger.info("complete octree conduction")
        self.root.print_tree()

    def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor:
        """Return a fixed global bounding box (the unit cube).

        Args:
            surf_bbox: unused in this implementation (data is already
                normalized into the unit cube).

        Returns:
            Tensor of shape (6,) on ``self.device``:
            [-0.5, -0.5, -0.5, 0.5, 0.5, 0.5].
        """
        # Bug fix: the original comment said the tensor must live on the
        # target device, but it was created on the CPU; now created on
        # self.device explicitly.
        fixed_bbox = torch.tensor([-0.5, -0.5, -0.5, 0.5, 0.5, 0.5],
                                  dtype=torch.float32, device=self.device)
        logger.debug(f"使用固定的全局边界框: {fixed_bbox.cpu().numpy()}")
        return fixed_bbox

    def train_epoch(self, epoch: int) -> float:
        """Run one training epoch over ``self.sdf_data`` in mini-batches.

        Assumes rows of ``self.sdf_data`` are [x, y, z, sdf] (4 columns) or
        [x, y, z, nx, ny, nz, sdf] (7 columns), with the SDF value always in
        the last column.

        Returns:
            The summed loss over all batches, or ``float('inf')`` as soon as
            any numerical/runtime failure is detected so the caller can tell
            the epoch went bad.
        """
        if self.sdf_data is None:
            logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.")
            return float('inf')

        self.model.train()
        total_loss = 0.0
        # Bug fix: these were referenced after the loop and would raise
        # NameError if there were zero batches; define them up front.
        current_loss = 0.0
        loss_details = {}
        step = 0  # single pass over the data; only used in log messages
        batch_size = 10240

        num_points = self.sdf_data.shape[0]
        num_batches = (num_points + batch_size - 1) // batch_size

        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = min((batch_idx + 1) * batch_size, num_points)
            points = self.sdf_data[start_idx:end_idx, 0:3].clone().detach()   # xyz
            gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach()    # SDF target
            normals = None
            if args.use_normal:
                if self.sdf_data.shape[1] < 7:  # need 3 extra columns for normals
                    logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.")
                    return float('inf')
                normals = self.sdf_data[start_idx:end_idx, 3:6].clone().detach()

            # Input sanity checks (debug mode only).
            if self.debug_mode:
                if check_tensor(points, "Input Points", epoch, step): return float('inf')
                if check_tensor(gt_sdf, "Input GT SDF", epoch, step): return float('inf')
                if args.use_normal:
                    if check_tensor(normals, "Input Normals", epoch, step): return float('inf')

            # Enable gradients only after the checks (eikonal-style losses
            # differentiate w.r.t. the input points).
            points.requires_grad_(True)

            # --- forward pass ---
            self.optimizer.zero_grad()
            pred_sdf = self.model(points)

            if self.debug_mode:
                logger.gpu_memory_stats("前向传播后")

            # --- loss ---
            loss = torch.tensor(float('nan'), device=self.device)  # NaN until computed
            loss_details = {}
            try:
                if args.use_normal:
                    logger.gpu_memory_stats("计算损失前")
                    loss, loss_details = self.loss_manager.compute_loss(
                        points,
                        normals,
                        gt_sdf,
                        pred_sdf
                    )
                else:
                    # NOTE(review): pred_sdf and gt_sdf shapes may differ
                    # ((B, 1) vs (B,)); mse_loss would then broadcast —
                    # confirm the model's output shape upstream.
                    loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)

                if self.debug_mode:
                    if check_tensor(loss, "Calculated Loss", epoch, step):
                        logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
                        if loss_details: logger.error(f"Loss Details: {loss_details}")
                        return float('inf')

            except Exception as loss_e:
                logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
                return float('inf')
            logger.gpu_memory_stats("损失计算后")

            # --- backward pass and optimization ---
            try:
                loss.backward()
                # Norm clipping guards against exploding gradients, a common
                # source of inf/nan losses.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()
            except Exception as backward_e:
                logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
                # Enable torch.autograd.set_detect_anomaly(True) before
                # training to pinpoint the offending op if this triggers.
                return float('inf')

            current_loss = loss.item()
            if not np.isfinite(current_loss):  # double-check the scalar value
                logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
                return float('inf')

            total_loss += current_loss
            del loss
            torch.cuda.empty_cache()

        # Log progress using the last batch's loss.
        # Bug fix: removed a stray ']' from the log format string.
        logger.info(f'Train Epoch: {epoch:4d}\t'
                    f'Loss: {current_loss:.6f}')
        if loss_details: logger.info(f"Loss Details: {loss_details}")

        return total_loss

    def validate(self, epoch: int) -> float:
        """Average MSE over the validation loader (currently unused).

        NOTE(review): ``self.val_loader`` is never initialized anywhere in
        this file; calling this method will raise AttributeError until a
        validation DataLoader is wired up. ``train()`` keeps it disabled.
        """
        self.model.eval()
        total_loss = 0.0

        with torch.no_grad():
            for batch in self.val_loader:
                points = batch['points'].to(self.device)
                gt_sdf = batch['sdf'].to(self.device)

                pred_sdf = self.model(points)
                loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
                total_loss += loss.item()

        avg_loss = total_loss / len(self.val_loader)
        logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}')
        return avg_loss

    def train(self):
        """Full training loop: epochs, periodic checkpoints, final export."""
        best_val_loss = float('inf')
        logger.info("Starting training...")
        start_time = time.time()

        start_epoch = 1
        if args.resume_checkpoint_path:
            start_epoch = self._load_checkpoint(args.resume_checkpoint_path)
            logger.info(f"Loaded model from {args.resume_checkpoint_path}")

        for epoch in range(start_epoch, self.config.train.num_epochs + 1):
            train_loss = self.train_epoch(epoch)

            # Per-epoch validation / best-model tracking is intentionally
            # disabled (validate() has no data loader yet); see validate().

            # Periodic checkpointing.
            if epoch % self.config.train.save_freq == 0:
                self._save_checkpoint(epoch, train_loss)
                logger.info(f'Checkpoint saved at epoch {epoch}')

        total_time = time.time() - start_time

        # Export the trained model as TorchScript.
        self._tracing_model_by_script()
        logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
        logger.info(f'Best validation loss: {best_val_loss:.6f}')

    def test_load(self):
        """Smoke-test: restore a checkpoint into the model and run a forward pass.

        Bug fix: ``_load_checkpoint`` returns the next epoch number (an int)
        and restores weights into ``self.model`` in place — the original code
        called ``.eval()`` on that int, which raised AttributeError.
        """
        self._load_checkpoint("/home/wch/brep2sdf/brep2sdf/00000054.pt")
        model = self.model
        model.eval()
        logger.debug(model)
        example_input = torch.rand(10, 3, device=self.device)
        logger.debug(f"points: {example_input}")
        sdfs = model(example_input)
        logger.debug(f"sdfs:{sdfs}")

    def _tracing_model_by_script(self):
        """Export the model with torch.jit.script (handles data-dependent control flow)."""
        self.model.eval()
        # All model logic must be TorchScript-compatible for this to succeed.
        scripted_model = torch.jit.script(self.model)
        torch.jit.save(scripted_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt")

    def _tracing_model(self):
        """Export the model with torch.jit.trace and verify the saved artifact."""
        self.model.eval()

        # Example input drives the trace (shape: one 3-D point).
        example_input = torch.rand(1, 3, device=self.device)

        traced_model = torch.jit.trace(self.model, example_input)

        save_path = f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt"
        torch.jit.save(traced_model, save_path)

        # Round-trip check: reload and run a forward pass.
        try:
            loaded_model = torch.jit.load(save_path)
            test_input = torch.rand(1, 3, device=self.device)
            _ = loaded_model(test_input)
            logger.info(f"模型已保存并验证成功:{save_path}")
        except Exception as e:
            logger.error(f"模型验证失败:{e}")

    def _save_checkpoint(self, epoch: int, train_loss: float):
        """Save a training checkpoint (state dicts only) under the model's dir."""
        checkpoint_dir = os.path.join(
            self.config.train.checkpoint_dir,
            self.model_name
        )
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth")

        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': train_loss,
        }, checkpoint_path)

    def _load_checkpoint(self, checkpoint_path):
        """Restore model/optimizer state from a checkpoint.

        Returns:
            The epoch to resume from (saved epoch + 1).

        Raises:
            Re-raises any loading error after logging it.
        """
        try:
            # map_location lets CUDA-saved checkpoints load on CPU-only hosts.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            return checkpoint['epoch'] + 1
        except Exception as e:
            logger.error(f"加载checkpoint失败: {str(e)}")
            raise
|
|
def main():
    """Script entry point: build the default config and train on ``args.input``.

    Relies on the module-level ``args`` parsed at import time.
    """
    # Configuration must be initialized here
    config = get_default_config()
    # Create the trainer and run the full training loop
    trainer = Trainer(config, input_step=args.input)
    trainer.train()


if __name__ == '__main__':
    main()
|