import torch
from torch.serialization import add_safe_globals
import torch.optim as optim
import time
import os
import numpy as np
import argparse

from brep2sdf.config.default_config import get_default_config
from brep2sdf.data.data import load_brep_file, load_sdf_file
from brep2sdf.data.pre_process_by_mesh import process_single_step
from brep2sdf.networks.network import Net
from brep2sdf.networks.octree import OctreeNode
from brep2sdf.networks.loss import LossManager
from brep2sdf.utils.logger import logger

# Configure command-line arguments
parser = argparse.ArgumentParser(description='Batch-processing tool for STEP files')
parser.add_argument('-i', '--input', required=True,
                    help='path to the B-rep (.step) file to process')
parser.add_argument(
    '--use-normal',
    action='store_true',  # defaults to False; True when the flag is given
    help='force sampled points to carry normal vectors'
)
parser.add_argument(
    '--force-reprocess',
    action='store_true',  # defaults to False; True when the flag is given
    help='force data preprocessing to rerun, ignoring caches and existing results'
)
args = parser.parse_args()
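
# Usage sketch (script name and paths are hypothetical, for illustration only):
#   python train.py -i /path/to/00000054_model.step --use-normal
#   python train.py -i /path/to/00000054_model.step --force-reprocess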


def prepare_sdf_data(surf_data, normals=None, max_points=100000, device='cuda'):
    total_points = sum(len(s) for s in surf_data)

    # Downsample when the surfaces carry more points than max_points
    if total_points > max_points:
        # Build a flat list of (surface index, point index) pairs
        indices = []
        for i, points in enumerate(surf_data):
            indices.extend([(i, j) for j in range(len(points))])

        # Randomly shuffle and keep the first max_points indices
        np.random.shuffle(indices)
        selected_indices = indices[:max_points]

        if normals is not None:
            # 7 columns: xyz, normal, sdf (sdf stays 0 for surface samples)
            sdf_array = np.zeros((max_points, 7), dtype=np.float32)
            for idx, (i, j) in enumerate(selected_indices):
                sdf_array[idx, :3] = surf_data[i][j]
                sdf_array[idx, 3:6] = normals[i][j]
        else:
            # 4 columns: xyz, sdf
            sdf_array = np.zeros((max_points, 4), dtype=np.float32)
            for idx, (i, j) in enumerate(selected_indices):
                sdf_array[idx, :3] = surf_data[i][j]
    else:
        if normals is not None:
            sdf_array = np.zeros((total_points, 7), dtype=np.float32)
            sdf_array[:, :3] = np.concatenate(surf_data)
            sdf_array[:, 3:6] = np.concatenate(normals)
        else:
            sdf_array = np.zeros((total_points, 4), dtype=np.float32)
            sdf_array[:, :3] = np.concatenate(surf_data)

    return torch.tensor(sdf_array, dtype=torch.float32, device=device)
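
# Minimal usage sketch (synthetic shapes, illustrative only):
#   surfs = [np.random.rand(50, 3).astype(np.float32) for _ in range(3)]
#   norms = [np.random.rand(50, 3).astype(np.float32) for _ in range(3)]
#   data = prepare_sdf_data(surfs, normals=norms, max_points=64, device='cpu')
#   # data.shape == (64, 7): columns are xyz (0:3), normal (3:6), sdf (6)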


class Trainer:
    def __init__(self, config, input_step):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model_name = os.path.basename(input_step).split('_')[0]
        self.base_name = self.model_name + ".xyz"
        data_path = os.path.join("/home/wch/brep2sdf/data/output_data", self.base_name)

        if os.path.exists(data_path) and not args.force_reprocess:
            try:
                self.data = load_brep_file(data_path)
            except Exception as e:
                logger.error(f"fail to load {data_path}, {str(e)}")
                raise e
            # Reprocess if normals were requested but the cached file lacks them
            if args.use_normal and self.data.get("surf_pnt_normals", None) is None:
                self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal)
        else:
            self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal)

        # Flatten the per-surface point clouds into one (N, 4) or (N, 7) array
        surfs = self.data["surf_ncs"]
        self.sdf_data = prepare_sdf_data(
            surfs,
            normals=self.data["surf_pnt_normals"],
            max_points=4096,
            device=self.device
        )
        # Dataset initialization (file-based loading kept for reference)
        #self.brep_data = load_brep_file(self.config.data.pkl_path)
        #logger.info( self.brep_data )
        #self.sdf_data = load_sdf_file(sdf_path=self.config.data.sdf_path, num_query_points=self.config.data.num_query_points).to(self.device)

        # Initialize the network
        surf_bbox = torch.tensor(
            self.data['surf_bbox_ncs'],
            dtype=torch.float32,
            device=self.device
        )

        self.build_tree(surf_bbox=surf_bbox, max_depth=4)

        self.model = Net(
            octree=self.root,
            feature_dim=64
        ).to(self.device)

        # Initialize the optimizer
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.train.learning_rate,
            weight_decay=config.train.weight_decay
        )

        self.loss_manager = LossManager(ablation="none")

    def build_tree(self, surf_bbox, max_depth=6):
        num_faces = surf_bbox.shape[0]
        bbox = self._calculate_global_bbox(surf_bbox)
        self.root = OctreeNode(
            bbox=bbox,
            face_indices=np.arange(num_faces),  # initially contains all faces
            max_depth=max_depth,
            surf_bbox=surf_bbox
        )
        #print(surf_bbox)
        logger.info("starting octree construction")
        self.root.conduct_tree()
        logger.info("complete octree construction")
        #self.root.print_tree(0)

    def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor:
        """
        Compute the global bounding box of the whole dataset from the per-surface boxes.

        Args:
            surf_bbox: Tensor of shape (num_faces, 6); each row is one box as
                [xmin, ymin, zmin, xmax, ymax, zmax]

        Returns:
            Tensor of shape (6,) in the format [x_min, y_min, z_min, x_max, y_max, z_max]
        """
        # Validate the input
        if not isinstance(surf_bbox, torch.Tensor):
            raise TypeError(f"surf_bbox must be a torch.Tensor, got {type(surf_bbox)}")
        if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6:
            raise ValueError(f"surf_bbox should have shape (num_faces, 6), got {surf_bbox.shape}")

        # Global extent over all per-surface boxes
        global_min = surf_bbox[:, :3].min(dim=0).values
        global_max = surf_bbox[:, 3:].max(dim=0).values

        # Return the merged bounding box
        return torch.cat([global_min, global_max])
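
    # Tiny worked example (illustrative values): merging takes the per-axis
    # min of the mins and max of the maxes.
    #   boxes = torch.tensor([[ 0., 0., 0., 1.0, 1., 1.],
    #                         [-1., 0., 0., 0.5, 2., 1.]])
    #   -> global bbox: tensor([-1., 0., 0., 1., 2., 1.])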

    def train_epoch(self, epoch: int) -> float:
        self.model.train()
        total_loss = 0.0

        # Slice the data (already resident on the device)
        points = self.sdf_data[:, 0:3]
        points.requires_grad_(True)
        if args.use_normal:
            normals = self.sdf_data[:, 3:6]
            gt_sdf = self.sdf_data[:, 6]
        else:
            gt_sdf = self.sdf_data[:, 3]

        # Forward pass
        self.optimizer.zero_grad()
        pred_sdf = self.model(points)

        # Compute the loss
        if args.use_normal:
            loss, loss_details = self.loss_manager.compute_loss(
                points,
                normals,
                gt_sdf,
                pred_sdf
            )
        else:
            loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)

        # Backward pass and optimizer step
        loss.backward()
        self.optimizer.step()

        total_loss += loss.item()

        # Log training progress
        logger.info(f'Train Epoch: {epoch:4d}\t'
                    f'Loss: {loss.item():.6f}')

        return total_loss

    def validate(self, epoch: int) -> float:
        # NOTE: relies on self.val_loader, which is never created in __init__;
        # this method is currently only referenced from the commented-out block in train().
        self.model.eval()
        total_loss = 0.0

        with torch.no_grad():
            for batch in self.val_loader:
                points = batch['points'].to(self.device)
                gt_sdf = batch['sdf'].to(self.device)

                pred_sdf = self.model(points)
                loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
                total_loss += loss.item()

        avg_loss = total_loss / len(self.val_loader)
        logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}')
        return avg_loss

    def train(self):
        best_val_loss = float('inf')
        logger.info("Starting training...")
        start_time = time.time()

        for epoch in range(1, self.config.train.num_epochs + 1):
            # Train for one epoch
            train_loss = self.train_epoch(epoch)

            # Validation (disabled; kept for reference)
            '''
            if epoch % self.config.train.val_freq == 0:
                val_loss = self.validate(epoch)
                logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}')

                # Save the best model
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    self._save_model(epoch, val_loss)
                    logger.info(f'New best model saved at epoch {epoch} with val loss {val_loss:.6f}')
            else:
                logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}')
            '''
            # Save a checkpoint
            if epoch % self.config.train.save_freq == 0:
                self._save_checkpoint(epoch, train_loss)
                logger.info(f'Checkpoint saved at epoch {epoch}')

        # Training finished
        total_time = time.time() - start_time
        logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
        logger.info(f'Best validation loss: {best_val_loss:.6f}')
        self._tracing_model()

        #self.test_load()

    def test_load(self):
        model = self._load_checkpoint("/home/wch/brep2sdf/brep2sdf/00000054.pt")
        model.eval()
        logger.debug(model)
        example_input = torch.rand(10, 3, device=self.device)
        #logger.debug(model.encoder.octree.bbox)
        logger.debug(f"points: {example_input}")
        sdfs = model(example_input)
        logger.debug(f"sdfs: {sdfs}")

    def _tracing_model(self):
        """Export the model as a TorchScript module."""
        self.model.eval()
        # All logic inside the model must be TorchScript-compatible
        scripted_model = torch.jit.script(self.model)
        torch.jit.save(scripted_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt")
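
    # Reloading the exported module later (sketch; the path mirrors the one above):
    #   model = torch.jit.load(f"/home/wch/brep2sdf/data/output_data/{name}.pt")
    #   sdf = model(torch.rand(10, 3))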

    def _load_checkpoint(self, checkpoint_path):
        """Restore a model from a checkpoint."""
        # Checkpoints are saved as whole models (torch.save(self.model, ...)),
        # so torch.load returns a Net instance. Note: on PyTorch 2.6+ this may
        # need weights_only=False or allow-listing via add_safe_globals.
        model = torch.load(checkpoint_path)
        return model

    def _save_checkpoint(self, epoch: int, train_loss: float):
        """Save a training checkpoint."""
        checkpoint_dir = os.path.join(
            self.config.train.checkpoint_dir,
            self.model_name
        )
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth")
        '''
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': train_loss,
            'config': self.config
        }, checkpoint_path)
        '''
        torch.save(self.model, checkpoint_path)


def main():
    # Initialize the configuration
    config = get_default_config()
    # Create the trainer and start training
    trainer = Trainer(config, input_step=args.input)
    trainer.train()


if __name__ == '__main__':
    main()