You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

317 lines
11 KiB

import torch
from torch.serialization import add_safe_globals
import torch.optim as optim
import time
import os
import numpy as np
import argparse
from brep2sdf.config.default_config import get_default_config
from brep2sdf.data.data import load_brep_file,load_sdf_file
from brep2sdf.data.pre_process_by_mesh import process_single_step
from brep2sdf.networks.network import Net
from brep2sdf.networks.octree import OctreeNode
from brep2sdf.networks.loss import LossManager
from brep2sdf.utils.logger import logger
# Command-line options. They are parsed at import time because Trainer reads
# the module-level ``args`` object directly.
parser = argparse.ArgumentParser(description='STEP文件批量处理工具')
parser.add_argument('-i', '--input', help='待处理 brep (.step) 路径', required=True)
# Both flags below use store_true: False unless given on the command line.
parser.add_argument('--use-normal', help='强制采样点有法向量', action='store_true')
parser.add_argument(
    '--force-reprocess',
    help='强制重新进行数据预处理,忽略缓存或已有结果',
    action='store_true',
)
args = parser.parse_args()
def prepare_sdf_data(surf_data, normals=None, max_points=100000, device='cuda'):
    """Pack per-surface sample points (and optional normals) into one SDF training tensor.

    Every sample lies on the B-rep surface, so its target signed distance is 0;
    the last column of the result is therefore all zeros.

    Args:
        surf_data: sequence of per-surface point arrays. Each entry is treated
            as ``(num_points, 3)``; any extra leading dimensions are flattened.
        normals: ``None``, or a sequence matching ``surf_data`` entry-for-entry
            with exactly one 3-vector normal per point.
        max_points: when more samples than this are available, a uniform random
            subset of this size is kept (without replacement, via ``np.random``).
        device: torch device for the returned tensor.

    Returns:
        float32 tensor of shape ``(N, 4)`` laid out ``[x, y, z, sdf]`` when
        ``normals`` is None, otherwise ``(N, 7)`` laid out
        ``[x, y, z, nx, ny, nz, sdf]``, with ``N = min(total_points, max_points)``.

    Raises:
        ValueError: if ``normals`` does not supply one normal per point.
    """

    def _stack(arrays):
        # Flatten a list of per-surface arrays into a single (total, 3) array.
        parts = [np.asarray(a, dtype=np.float32).reshape(-1, 3) for a in arrays]
        if not parts:
            return np.zeros((0, 3), dtype=np.float32)
        return np.concatenate(parts, axis=0)

    points = _stack(surf_data)
    columns = [points]
    if normals is not None:
        normal_array = _stack(normals)
        if normal_array.shape != points.shape:
            raise ValueError(
                f"normals shape {normal_array.shape} does not match points shape {points.shape}"
            )
        columns.append(normal_array)
    # Surface samples have zero signed distance.
    columns.append(np.zeros((len(points), 1), dtype=np.float32))
    sdf_array = np.concatenate(columns, axis=1)

    # Random downsampling keeps each point's row (coords + normal) together.
    if len(sdf_array) > max_points:
        keep = np.random.permutation(len(sdf_array))[:max_points]
        sdf_array = sdf_array[keep]

    return torch.tensor(sdf_array, dtype=torch.float32, device=device)
class Trainer:
    """Fit an octree-based SDF network (``Net``) to points sampled on one B-rep model.

    The constructor loads (or regenerates) the preprocessed B-rep data for
    ``input_step``, turns the surface samples into a training tensor, builds an
    octree over the per-surface bounding boxes and sets up the model, optimizer
    and loss manager. Behaviour is additionally driven by the module-level
    command-line ``args`` (``--use-normal``, ``--force-reprocess``).
    """

    # Directory holding the preprocessed ``<model>.xyz`` data and the exported
    # TorchScript model. NOTE(review): machine-specific path; consider config.
    OUTPUT_DIR = "/home/wch/brep2sdf/data/output_data"

    def __init__(self, config, input_step):
        """
        Args:
            config: config object; ``config.train`` must provide
                ``learning_rate``, ``weight_decay``, ``num_epochs``,
                ``save_freq`` and ``checkpoint_dir``.
            input_step: path to the input STEP file. The model name is the
                file name prefix before the first underscore.

        Raises:
            ValueError: if ``--use-normal`` was given but no surface normals
                are available after (re)processing.
        """
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_name = os.path.basename(input_step).split('_')[0]
        self.base_name = self.model_name + ".xyz"
        data_path = os.path.join(self.OUTPUT_DIR, self.base_name)

        if os.path.exists(data_path) and not args.force_reprocess:
            try:
                self.data = load_brep_file(data_path)
            except Exception as e:
                logger.error(f"fail to load {data_path}, {str(e)}")
                raise  # bare raise keeps the original traceback
            # Cached data lacks normals but they were requested: regenerate.
            if args.use_normal and self.data.get("surf_pnt_normals", None) is None:
                self.data = process_single_step(
                    step_path=input_step,
                    output_path=data_path,
                    sample_normal_vector=args.use_normal,
                )
        else:
            self.data = process_single_step(
                step_path=input_step,
                output_path=data_path,
                sample_normal_vector=args.use_normal,
            )

        # Only pass normals when training with them: the column layout produced
        # by prepare_sdf_data (4 vs 7 columns) must match what train_epoch
        # slices, which is decided by args.use_normal.
        normals = self.data.get("surf_pnt_normals") if args.use_normal else None
        if args.use_normal and normals is None:
            raise ValueError(
                f"--use-normal was given but no surface normals are available for {data_path}"
            )
        self.sdf_data = prepare_sdf_data(
            self.data["surf_ncs"],
            normals=normals,
            max_points=4096,
            device=self.device,
        )

        # Network: octree over per-surface bounding boxes, then the SDF model.
        surf_bbox = torch.tensor(
            self.data['surf_bbox_ncs'],
            dtype=torch.float32,
            device=self.device,
        )
        self.build_tree(surf_bbox=surf_bbox, max_depth=4)
        self.model = Net(
            octree=self.root,
            feature_dim=64,
        ).to(self.device)

        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.train.learning_rate,
            weight_decay=config.train.weight_decay,
        )
        self.loss_manager = LossManager(ablation="none")

    def build_tree(self, surf_bbox, max_depth=6):
        """Build ``self.root``, an octree over all faces' bounding boxes.

        Args:
            surf_bbox: ``(num_faces, 6)`` tensor of per-face boxes
                ``[xmin, ymin, zmin, xmax, ymax, zmax]``.
            max_depth: maximum octree depth passed to ``OctreeNode``.
        """
        num_faces = surf_bbox.shape[0]
        bbox = self._calculate_global_bbox(surf_bbox)
        self.root = OctreeNode(
            bbox=bbox,
            face_indices=np.arange(num_faces),  # the root initially holds every face
            max_depth=max_depth,
            surf_bbox=surf_bbox,
        )
        logger.info("starting octree conduction")
        self.root.conduct_tree()
        logger.info("complete octree conduction")

    def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor:
        """Return the axis-aligned box enclosing every per-surface bounding box.

        Args:
            surf_bbox: ``(num_surfaces, 6)`` tensor, one row per surface, laid
                out ``[xmin, ymin, zmin, xmax, ymax, zmax]``.

        Returns:
            ``(6,)`` tensor ``[x_min, y_min, z_min, x_max, y_max, z_max]``.

        Raises:
            TypeError: if ``surf_bbox`` is not a ``torch.Tensor``.
            ValueError: if ``surf_bbox`` is not of shape ``(N, 6)``.
        """
        if not isinstance(surf_bbox, torch.Tensor):
            raise TypeError(f"surf_bbox 必须是 torch.Tensor,但得到 {type(surf_bbox)}")
        if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6:
            raise ValueError(f"surf_bbox 形状应为 (num_edges, 6),但得到 {surf_bbox.shape}")
        global_min = surf_bbox[:, :3].min(dim=0).values
        global_max = surf_bbox[:, 3:].max(dim=0).values
        return torch.cat([global_min, global_max])

    def train_epoch(self, epoch: int) -> float:
        """Run one full-batch optimisation step over all training points.

        Returns:
            The scalar loss of that step.
        """
        self.model.train()

        # Fresh leaf tensor requiring grad, so losses may differentiate w.r.t.
        # input coordinates (presumably LossManager's normal terms — confirm).
        points = self.sdf_data[:, 0:3].detach().requires_grad_(True)
        if args.use_normal:
            normals = self.sdf_data[:, 3:6]
            gt_sdf = self.sdf_data[:, 6]
        else:
            gt_sdf = self.sdf_data[:, 3]

        self.optimizer.zero_grad()
        pred_sdf = self.model(points)

        if args.use_normal:
            loss, _ = self.loss_manager.compute_loss(
                points,
                normals,
                gt_sdf,
                pred_sdf,
            )
        else:
            # NOTE(review): assumes Net returns shape (N,) like gt_sdf; a
            # (N, 1) output would broadcast here — confirm.
            loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)

        loss.backward()
        self.optimizer.step()

        loss_value = loss.item()
        logger.info(f'Train Epoch: {epoch:4d}\t'
                    f'Loss: {loss_value:.6f}')
        return loss_value

    def validate(self, epoch: int) -> float:
        """Return the average MSE over ``self.val_loader``.

        NOTE: ``__init__`` does not build ``self.val_loader``; it must be set
        before calling this. An empty loader yields ``inf``.
        """
        self.model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for batch in self.val_loader:
                points = batch['points'].to(self.device)
                gt_sdf = batch['sdf'].to(self.device)
                pred_sdf = self.model(points)
                loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
                total_loss += loss.item()
        num_batches = len(self.val_loader)
        avg_loss = total_loss / num_batches if num_batches else float('inf')
        logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}')
        return avg_loss

    def train(self):
        """Train for ``config.train.num_epochs`` epochs, checkpoint, then export TorchScript."""
        best_train_loss = float('inf')
        logger.info("Starting training...")
        start_time = time.time()

        for epoch in range(1, self.config.train.num_epochs + 1):
            train_loss = self.train_epoch(epoch)
            best_train_loss = min(best_train_loss, train_loss)
            # Validation is disabled: no validation loader is built in __init__.
            if epoch % self.config.train.save_freq == 0:
                self._save_checkpoint(epoch, train_loss)
                logger.info(f'Checkpoint saved at epoch {epoch}')

        total_time = time.time() - start_time
        logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
        logger.info(f'Best training loss: {best_train_loss:.6f}')
        self._tracing_model()

    def test_load(self):
        """Debug helper: load a fixed checkpoint and run it on random points."""
        model = self._load_checkpoint("/home/wch/brep2sdf/brep2sdf/00000054.pt")
        model.eval()
        logger.debug(model)
        example_input = torch.rand(10, 3, device=self.device)
        logger.debug(f"points: {example_input}")
        sdfs = model(example_input)
        logger.debug(f"sdfs:{sdfs}")

    def _tracing_model(self):
        """Compile the model with ``torch.jit.script`` and save it as ``<OUTPUT_DIR>/<model_name>.pt``."""
        self.model.eval()
        # Every code path in the model must be TorchScript-compatible.
        scripted_model = torch.jit.script(self.model)
        os.makedirs(self.OUTPUT_DIR, exist_ok=True)
        torch.jit.save(scripted_model, os.path.join(self.OUTPUT_DIR, f"{self.model_name}.pt"))

    def _load_checkpoint(self, checkpoint_path):
        """Load a whole pickled model as written by ``_save_checkpoint``.

        Only the model object is restored (no optimizer or epoch state).
        ``weights_only=False`` is needed because the file pickles the full
        module; PyTorch 2.6+ defaults to ``weights_only=True`` and rejects it.
        SECURITY: this unpickles arbitrary objects — only load trusted files.
        """
        return torch.load(checkpoint_path, map_location=self.device, weights_only=False)

    def _save_checkpoint(self, epoch: int, train_loss: float):
        """Pickle the whole model to ``<checkpoint_dir>/<model_name>/epoch_<NNN>.pth``.

        ``train_loss`` is accepted for interface compatibility but not stored.
        """
        checkpoint_dir = os.path.join(
            self.config.train.checkpoint_dir,
            self.model_name,
        )
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth")
        torch.save(self.model, checkpoint_path)
def main():
    """Train on the STEP file given by ``-i/--input`` using the default configuration."""
    Trainer(get_default_config(), input_step=args.input).train()


if __name__ == '__main__':
    main()