You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

456 lines
19 KiB

import torch
import torch.optim as optim
import time
import os
import numpy as np
import argparse
from brep2sdf.config.default_config import get_default_config
from brep2sdf.data.data import load_brep_file,prepare_sdf_data, print_data_distribution, check_tensor
1 year ago
from brep2sdf.data.pre_process_by_mesh import process_single_step
from brep2sdf.networks.network import Net
from brep2sdf.networks.octree import OctreeNode
1 year ago
from brep2sdf.networks.loss import LossManager
from brep2sdf.networks.patch_graph import PatchGraph
from brep2sdf.utils.logger import logger
1 year ago
# Command-line configuration.
# NOTE: the arguments are parsed at import time and the resulting `args`
# namespace is read as a module-level global by the rest of this file.
parser = argparse.ArgumentParser(description='STEP文件批量处理工具')
parser.add_argument(
    '-i', '--input',
    required=True,
    help='待处理 brep (.step) 路径'
)
parser.add_argument(
    '--use-normal',
    action='store_true',  # False unless the flag is given
    help='强制采样点有法向量'
)
parser.add_argument(
    '--only-zero-surface',
    action='store_true',  # False unless the flag is given
    help='只采样零表面点 SDF 训练'
)
parser.add_argument(
    '--force-reprocess', '-f',
    action='store_true',  # False unless the flag is given
    help='强制重新进行数据预处理,忽略缓存或已有结果'
)
parser.add_argument(
    '--resume-checkpoint-path',
    type=str,
    default=None,
    help='从指定的checkpoint文件继续训练'
)
parser.add_argument(
    '--octree-cuda',
    action='store_true',  # False unless the flag is given
    help='使用CUDA加速Octree构建'
)
args = parser.parse_args()
class Trainer:
    """Fit an octree-backed SDF network to a single B-rep (STEP) model.

    Construction loads (or regenerates) the pre-processed data for the input
    STEP file, assembles the SDF training tensor, builds the patch graph and
    the octree, and creates the network, optimizer and loss manager.
    ``train()`` then runs the epoch loop, writes periodic checkpoints and
    finally exports the model as TorchScript.

    The class reads the module-level ``args`` namespace (command-line flags)
    in addition to the ``config`` object passed to ``__init__``.
    """

    # Directory that holds cached pre-processed data and exported models.
    OUTPUT_DIR = "/home/wch/brep2sdf/data/output_data"

    def __init__(self, config, input_step):
        """Prepare data, octree, model and optimizer for one STEP file.

        Args:
            config: configuration object; ``config.train.debug_mode``,
                ``learning_rate`` and ``weight_decay`` are read here.
            input_step: path of the STEP file. The model name is the part of
                its base name before the first underscore.
        """
        logger.gpu_memory_stats("初始化开始")

        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.debug_mode = config.train.debug_mode
        self.model_name = os.path.basename(input_step).split('_')[0]
        self.base_name = self.model_name + ".xyz"
        data_path = os.path.join(self.OUTPUT_DIR, self.base_name)

        self.data = self._load_or_preprocess(input_step, data_path)
        logger.gpu_memory_stats("数据预处理后")

        # Per-surface point clouds, used for the on-surface SDF samples.
        surfs = self.data["surf_ncs"]
        surface_sdf_data = prepare_sdf_data(
            surfs,
            # .get() mirrors the cache check in _load_or_preprocess, so a
            # cache without normals no longer raises KeyError here.
            # NOTE(review): assumes prepare_sdf_data accepts normals=None - confirm.
            normals=self.data.get("surf_pnt_normals"),
            max_points=50000,
            device=self.device
        )

        if args.only_zero_surface:
            self.sdf_data = surface_sdf_data
        else:
            # Append the off-surface samples produced by pre-processing.
            sampled_sdf_data = torch.tensor(
                self.data['sampled_points_normals_sdf'],
                dtype=torch.float32,
                device=self.device
            )
            self.sdf_data = torch.cat([surface_sdf_data, sampled_sdf_data], dim=0)

        print_data_distribution(self.sdf_data)
        logger.gpu_memory_stats("SDF数据准备后")

        # Patch adjacency graph.
        graph = PatchGraph.from_preprocessed_data(
            surf_wcs=self.data['surf_wcs'],
            edgeFace_adj=self.data['edgeFace_adj'],
            edge_types=self.data['edge_types'],
            device='cuda' if args.octree_cuda else 'cpu'
        )

        # Per-surface bounding boxes; shared by the octree and the network.
        surf_bbox = torch.tensor(
            self.data['surf_bbox_ncs'],
            dtype=torch.float32,
            device=self.device
        )

        self.build_tree(surf_bbox=surf_bbox, graph=graph, max_depth=6)
        logger.gpu_memory_stats("数初始化后")

        self.model = Net(
            octree=self.root,
            volume_bboxs=surf_bbox,
            feature_dim=64
        ).to(self.device)
        logger.gpu_memory_stats("模型初始化后")

        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.train.learning_rate,
            weight_decay=config.train.weight_decay
        )

        self.loss_manager = LossManager(ablation="none")

        logger.gpu_memory_stats("训练器初始化后")
        logger.info(f"初始化完成,正在处理模型 {self.model_name}")

    def _load_or_preprocess(self, input_step, data_path):
        """Return cached pre-processed data, regenerating it when required.

        The STEP file is re-processed when the cache file is missing, when
        ``--force-reprocess`` is set, or when ``--use-normal`` is set but the
        cached data carries no ``surf_pnt_normals``.

        Raises:
            Exception: whatever ``load_brep_file`` raised, after logging it.
        """
        def preprocess():
            return process_single_step(
                step_path=input_step,
                output_path=data_path,
                sample_normal_vector=args.use_normal,
                sample_sdf_points=not args.only_zero_surface
            )

        if args.force_reprocess or not os.path.exists(data_path):
            return preprocess()

        try:
            data = load_brep_file(data_path)
        except Exception as e:
            logger.error(f"fail to load {data_path}, {str(e)}")
            raise

        if args.use_normal and data.get("surf_pnt_normals", None) is None:
            return preprocess()
        return data

    def build_tree(self, surf_bbox, graph, max_depth=9):
        """Create ``self.root`` and build the static octree over all faces.

        Args:
            surf_bbox: per-face bounding boxes; its first dimension is taken
                as the number of faces.
            graph: patch graph handed to the octree root.
            max_depth: maximum octree depth passed to ``OctreeNode``.
        """
        num_faces = surf_bbox.shape[0]
        bbox = self._calculate_global_bbox(surf_bbox)
        self.root = OctreeNode(
            bbox=bbox,
            face_indices=np.arange(num_faces),  # the root starts with every face
            patch_graph=graph,
            max_depth=max_depth,
            surf_bbox=surf_bbox,
            surf_ncs=self.data['surf_ncs']
        )
        logger.info("starting octree conduction")
        self.root.build_static_tree()
        logger.info("complete octree conduction")
        self.root.print_tree()

    def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor:
        """Return the fixed global bounding box used as the octree root.

        Args:
            surf_bbox: unused; kept so the call site does not change.

        Returns:
            A float32 tensor of shape (6,):
            ``[-0.5, -0.5, -0.5, 0.5, 0.5, 0.5]``.
        """
        # NOTE(review): the tensor is created on the default (CPU) device even
        # though surf_bbox lives on self.device - confirm this is intended.
        fixed_bbox = torch.tensor([-0.5, -0.5, -0.5, 0.5, 0.5, 0.5],
                                  dtype=torch.float32)
        logger.debug(f"使用固定的全局边界框: {fixed_bbox.cpu().numpy()}")
        return fixed_bbox

    def train_epoch(self, epoch: int) -> float:
        """Run one pass over ``self.sdf_data`` in fixed-size batches.

        Columns 0-2 of ``self.sdf_data`` are read as the query points and the
        last column as the ground-truth SDF; with ``--use-normal`` columns
        3-5 are read as normals (at least 7 columns are then required).

        Args:
            epoch: epoch number, used for logging only.

        Returns:
            The sum of the per-batch losses, or ``float('inf')`` as soon as
            any check, loss computation or optimisation step fails.
        """
        if self.sdf_data is None:
            logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.")
            return float('inf')

        self.model.train()

        batch_size = 10240
        num_points = self.sdf_data.shape[0]
        num_batches = (num_points + batch_size - 1) // batch_size

        # The column count does not change between batches: check it once.
        if args.use_normal and num_batches > 0 and self.sdf_data.shape[1] < 7:
            logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.")
            return float('inf')

        total_loss = 0.0
        current_loss = None
        loss_details = {}

        for batch_idx in range(num_batches):
            step = batch_idx  # batch index reported in logs and tensor checks
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, num_points)

            points = self.sdf_data[start_idx:end_idx, 0:3].clone().detach()
            gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach()
            normals = None
            if args.use_normal:
                normals = self.sdf_data[start_idx:end_idx, 3:6].clone().detach()

            if self.debug_mode:
                if check_tensor(points, "Input Points", epoch, step): return float('inf')
                if check_tensor(gt_sdf, "Input GT SDF", epoch, step): return float('inf')
                if args.use_normal:
                    if check_tensor(normals, "Input Normals", epoch, step): return float('inf')

            # Gradients w.r.t. the inputs are enabled only after the checks.
            points.requires_grad_(True)

            self.optimizer.zero_grad()
            pred_sdf = self.model(points)

            if self.debug_mode:
                logger.gpu_memory_stats("前向传播后")

            loss_details = {}
            try:
                if args.use_normal:
                    logger.gpu_memory_stats("计算损失前")
                    loss, loss_details = self.loss_manager.compute_loss(
                        points,
                        normals,
                        gt_sdf,
                        pred_sdf
                    )
                else:
                    loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)

                if self.debug_mode:
                    if check_tensor(loss, "Calculated Loss", epoch, step):
                        logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
                        if loss_details: logger.error(f"Loss Details: {loss_details}")
                        return float('inf')
            except Exception as loss_e:
                logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
                return float('inf')

            logger.gpu_memory_stats("损失计算后")

            try:
                loss.backward()
                # Norm clipping guards against exploding gradients.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()
            except Exception as backward_e:
                logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
                return float('inf')

            current_loss = loss.item()
            if not np.isfinite(current_loss):
                logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
                return float('inf')

            total_loss += current_loss
            del loss
            torch.cuda.empty_cache()

        # NOTE(review): indentation was lost in the reviewed copy, so whether
        # this progress log belongs inside or after the batch loop could not
        # be determined; it is emitted once per epoch with the last batch's
        # loss. Skipped when there was no batch at all.
        if current_loss is not None:
            logger.info(f'Train Epoch: {epoch:4d}]\t'
                        f'Loss: {current_loss:.6f}')
            if loss_details: logger.info(f"Loss Details: {loss_details}")

        return total_loss

    def validate(self, epoch: int) -> float:
        """Return the mean MSE over ``self.val_loader``.

        Each batch is read as a mapping with ``'points'`` and ``'sdf'``
        entries. ``__init__`` does not create a ``val_loader``; when it is
        missing or empty, ``float('inf')`` is returned instead of raising.

        Args:
            epoch: epoch number, used for logging only.
        """
        val_loader = getattr(self, 'val_loader', None)
        if val_loader is None or len(val_loader) == 0:
            logger.error(f"Epoch {epoch}: no validation data available, skipping validation.")
            return float('inf')

        self.model.eval()
        total_loss = 0.0

        with torch.no_grad():
            for batch in val_loader:
                points = batch['points'].to(self.device)
                gt_sdf = batch['sdf'].to(self.device)

                pred_sdf = self.model(points)
                loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
                total_loss += loss.item()

        avg_loss = total_loss / len(val_loader)
        logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}')
        return avg_loss

    def train(self):
        """Run the full training loop, checkpointing and exporting the model.

        Resumes from ``--resume-checkpoint-path`` when given, saves a
        checkpoint every ``config.train.save_freq`` epochs and exports the
        model with ``torch.jit.script`` once the loop ends.
        """
        best_val_loss = float('inf')
        logger.info("Starting training...")
        start_time = time.time()

        start_epoch = 1
        if args.resume_checkpoint_path:
            start_epoch = self._load_checkpoint(args.resume_checkpoint_path)
            logger.info(f"Loaded model from {args.resume_checkpoint_path}")

        for epoch in range(start_epoch, self.config.train.num_epochs + 1):
            train_loss = self.train_epoch(epoch)

            # NOTE: periodic validation (every config.train.val_freq epochs,
            # keeping the best model) is disabled; no val_loader is created,
            # so best_val_loss stays at its initial value.

            if epoch % self.config.train.save_freq == 0:
                self._save_checkpoint(epoch, train_loss)
                logger.info(f'Checkpoint saved at epoch {epoch}')

        total_time = time.time() - start_time

        self._tracing_model_by_script()
        logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
        logger.info(f'Best validation loss: {best_val_loss:.6f}')

    def test_load(self):
        """Smoke-test a saved checkpoint by evaluating 10 random points."""
        # _load_checkpoint restores self.model in place and returns the next
        # epoch number, so the model itself must be taken from self.
        self._load_checkpoint("/home/wch/brep2sdf/brep2sdf/00000054.pt")
        model = self.model
        model.eval()
        logger.debug(model)
        example_input = torch.rand(10, 3, device=self.device)
        logger.debug(f"points: {example_input}")
        sdfs = model(example_input)
        logger.debug(f"sdfs:{sdfs}")

    def _export_path(self) -> str:
        """Return the path the exported TorchScript model is written to."""
        return os.path.join(self.OUTPUT_DIR, f"{self.model_name}.pt")

    def _tracing_model_by_script(self):
        """Export the model with ``torch.jit.script`` and save it."""
        self.model.eval()
        # Every code path of the model has to be TorchScript-compatible.
        scripted_model = torch.jit.script(self.model)
        torch.jit.save(scripted_model, self._export_path())

    def _tracing_model(self):
        """Export the model with ``torch.jit.trace``, then reload it once.

        A reload or inference failure is logged, not raised.
        """
        self.model.eval()

        example_input = torch.rand(1, 3, device=self.device)
        traced_model = torch.jit.trace(self.model, example_input)

        save_path = self._export_path()
        torch.jit.save(traced_model, save_path)

        # Verify that the saved file loads and runs.
        try:
            loaded_model = torch.jit.load(save_path)
            test_input = torch.rand(1, 3, device=self.device)
            _ = loaded_model(test_input)
            logger.info(f"模型已保存并验证成功:{save_path}")
        except Exception as e:
            logger.error(f"模型验证失败:{e}")

    def _save_checkpoint(self, epoch: int, train_loss: float):
        """Write model and optimizer state for ``epoch`` to disk.

        The file is ``<config.train.checkpoint_dir>/<model_name>/epoch_NNN.pth``;
        the directory is created when missing. Only state dicts, the epoch
        number and the loss are stored.
        """
        checkpoint_dir = os.path.join(
            self.config.train.checkpoint_dir,
            self.model_name
        )
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth")

        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': train_loss,
        }, checkpoint_path)

    def _load_checkpoint(self, checkpoint_path):
        """Restore model and optimizer state from a checkpoint file.

        Args:
            checkpoint_path: file written by ``_save_checkpoint``.

        Returns:
            The epoch to resume from (stored epoch + 1).

        Raises:
            Exception: any load or state-dict error, re-raised after logging.
        """
        try:
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            return checkpoint['epoch'] + 1
        except Exception as e:
            logger.error(f"加载checkpoint失败: {str(e)}")
            raise
def main():
    """Build a Trainer from the default config and run training."""
    # The STEP file to train on comes from the parsed command-line arguments.
    Trainer(get_default_config(), input_step=args.input).train()


if __name__ == '__main__':
    main()