Browse Source

背景 normal 有提升

final
mckay 3 weeks ago
parent
commit
c9aadb2d0a
  1. 39
      brep2sdf/IsoSurfacing.py
  2. 26
      brep2sdf/batch_train.py
  3. 6
      brep2sdf/config/default_config.py
  4. 7
      brep2sdf/data/pre_process_by_mesh.py
  5. 33
      brep2sdf/data/sampler.py
  6. 136
      brep2sdf/eval_pos.py
  7. 2
      brep2sdf/networks/decoder.py
  8. 49
      brep2sdf/networks/encoder.py
  9. 2
      brep2sdf/networks/feature_volume.py
  10. 5
      brep2sdf/networks/loss.py
  11. 8
      brep2sdf/networks/network.py
  12. 321
      brep2sdf/test.py
  13. 25
      brep2sdf/train.py

39
brep2sdf/IsoSurfacing.py

@ -22,12 +22,10 @@ def create_grid(depth, box_size):
xx, yy, zz = np.meshgrid(x, y, z, indexing='ij') xx, yy, zz = np.meshgrid(x, y, z, indexing='ij')
points = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1) points = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1)
# 新增归一化处理
max_coord = np.max(np.abs(points))
points = points / max_coord # 归一化到[-1,1]
return points, xx, yy, zz return points, xx, yy, zz
def predict_sdf(model, points, device): def predict_sdf(model, points, device, use_bk=False):
""" """
使用模型预测SDF值 使用模型预测SDF值
:param model: PyTorch模型 :param model: PyTorch模型
@ -36,14 +34,27 @@ def predict_sdf(model, points, device):
:return: SDF值数组 (N,) :return: SDF值数组 (N,)
""" """
points_t = torch.from_numpy(points).float().to(device) points_t = torch.from_numpy(points).float().to(device)
logger.print_tensor_stats("input poitns", points_t)
with torch.no_grad(): with torch.no_grad():
sdf = model(points_t).cpu().numpy().flatten() if use_bk:
# 替换 inf 值为 2 print("only background")
#sdf[np.isinf(sdf)] = 2 sdf = model.forward_background(points_t)
else:
batch_size = 8192*4 # 定义批量大小
sdf_list = [] # 用于存储批量预测结果
for i in range(0, len(points), batch_size):
batch_points = points[i:i + batch_size]
points_t = torch.from_numpy(batch_points).float().to(device)
logger.print_tensor_stats("input points", points_t)
batch_sdf = model(points_t)
sdf_list.append(batch_sdf.cpu())
sdf = torch.cat(sdf_list) # 合并所有批量结果
logger.print_tensor_stats("sdf", sdf)
sdf = sdf.cpu().numpy().flatten()
return sdf return sdf
def extract_surface(sdf, xx, yy, zz, method='MC', feature_angle=30.0, voxel_size=0.01): def extract_surface(sdf, xx, yy, zz, method='MC', bbox_size=1.0,feature_angle=30.0, voxel_size=0.01):
""" """
提取零表面 提取零表面
:param sdf: SDF值三维数组 :param sdf: SDF值三维数组
@ -67,7 +78,7 @@ def extract_surface(sdf, xx, yy, zz, method='MC', feature_angle=30.0, voxel_size
raise ValueError(f"不支持的算法: {method}") raise ValueError(f"不支持的算法: {method}")
# 新增顶点后处理 # 新增顶点后处理
verts = (verts - sdf.shape[0]//2) / (sdf.shape[0]//2) # 归一化到[-1,1] verts = (verts - sdf.shape[0]//2) / (sdf.shape[0]//2) / 2 * bbox_size # 归一化到[-1,1]
return verts, faces return verts, faces
def save_ply(vertices, faces, filename): def save_ply(vertices, faces, filename):
@ -141,6 +152,7 @@ def main():
help='特征角度阈值(EMC算法专用)') help='特征角度阈值(EMC算法专用)')
parser.add_argument('--voxel_size', type=float, default=0.01, parser.add_argument('--voxel_size', type=float, default=0.01,
help='体素尺寸(DC算法专用)') help='体素尺寸(DC算法专用)')
parser.add_argument('--only-background', '-b', action='store_true', help='仅使用背景场')
parser.add_argument('--use-gpu', action='store_true', help='使用GPU') parser.add_argument('--use-gpu', action='store_true', help='使用GPU')
parser.add_argument('--compare', type=str, help='GT网格文件(.ply)') parser.add_argument('--compare', type=str, help='GT网格文件(.ply)')
parser.add_argument('--compres', type=int, default=32, help='误差计算分辨率') parser.add_argument('--compres', type=int, default=32, help='误差计算分辨率')
@ -157,7 +169,7 @@ def main():
# 创建网格并预测SDF # 创建网格并预测SDF
points, xx, yy, zz = create_grid(args.depth, args.box_size) points, xx, yy, zz = create_grid(args.depth, args.box_size)
print(points.shape) print(points.shape)
sdf = predict_sdf(model, points, device) sdf = predict_sdf(model, points, device,args.only_background)
print(points.shape) print(points.shape)
print(sdf.shape) print(sdf.shape)
print(sdf) print(sdf)
@ -167,14 +179,11 @@ def main():
# 提取表面 # 提取表面
print("Extracting surface...") print("Extracting surface...")
start_time = time.time() start_time = time.time()
verts, faces = extract_surface(sdf_grid, xx, yy, zz, args.method) verts, faces = extract_surface(sdf_grid, xx, yy, zz, args.method, args.box_size)
# 新增顶点归一化校验
max_val = np.max(np.abs(verts))
if max_val > 1.0 + 1e-6: # 允许微小误差
verts = verts / max_val
print(f"Surface extraction took {time.time() - start_time:.2f} seconds") print(f"Surface extraction took {time.time() - start_time:.2f} seconds")
verts = verts * 2
# 保存网格 # 保存网格
save_ply(verts, faces, args.output) save_ply(verts, faces, args.output)
print(f"Mesh saved to {args.output}") print(f"Mesh saved to {args.output}")

26
brep2sdf/batch_train.py

@ -42,7 +42,7 @@ def run_training_process(input_step: str, train_script: str, common_args: list)
"python", train_script, "python", train_script,
*common_args, *common_args,
"-i", input_step, "-i", input_step,
"--resume-checkpoint-path", f"/home/wch/brep2sdf/checkpoints/{name_id}/epoch_11000.pth"
] ]
try: try:
logger.info(f"即将执行的命令: {' '.join(command)}") logger.info(f"即将执行的命令: {' '.join(command)}")
@ -244,11 +244,15 @@ def run_isosurfacing_process(input_path, output_dir, use_gpu=True,if_nh=False):
command = [ command = [
"python", "IsoSurfacing.py", "python", "IsoSurfacing.py",
"-i", input_path, "-i", input_path,
"-o", output_path "-o", output_path,
"-b",
"--depth", "7"
] ]
if use_gpu: if use_gpu:
command.append("--use-gpu") command.append("--use-gpu")
if if_nh:
command.append("--box_size")
command.append("2.0")
# 执行命令 # 执行命令
result = subprocess.run( result = subprocess.run(
command, command,
@ -382,10 +386,16 @@ def batch_nh_Iso(args):
logger.info(f"处理完成。成功: {success_count}, 失败: {failure_count}") logger.info(f"处理完成。成功: {success_count}, 失败: {failure_count}")
def main(args): def main(args):
batch_train_max_workers_1(args) if args.task==1:
#batch_train(args) batch_train_max_workers_1(args)
#batch_Iso(args) elif args.task==2:
#batch_nh_Iso(args) batch_train(args)
elif args.task==3:
batch_Iso(args)
elif args.task==4:
batch_nh_Iso(args)
else:
logger.info(f"task 只允许1-4")
if __name__ == '__main__': if __name__ == '__main__':
@ -400,6 +410,8 @@ if __name__ == '__main__':
help="要执行的训练脚本路径。") help="要执行的训练脚本路径。")
parser.add_argument('--workers', type=int, default=1, parser.add_argument('--workers', type=int, default=1,
help="用于并行处理的工作进程数。") help="用于并行处理的工作进程数。")
parser.add_argument('--task','-t', type=int, default=1,
help="需要执行的任务")
parser.add_argument('--train-args', nargs='*', parser.add_argument('--train-args', nargs='*',
help="传递给 train.py 的额外参数 (例如 --epochs 10 --batch-size 32)。") help="传递给 train.py 的额外参数 (例如 --epochs 10 --batch-size 32)。")

6
brep2sdf/config/default_config.py

@ -50,14 +50,14 @@ class TrainConfig:
batch_size: int = 8 batch_size: int = 8
num_workers: int = 4 num_workers: int = 4
num_epochs1: int = 10000 num_epochs1: int = 10000
num_epochs2: int = 0 num_epochs2: int = 0000
num_epochs3: int = 0 num_epochs3: int = 0000
learning_rate: float = 0.1 learning_rate: float = 0.1
learning_rate_schedule: List = field(default_factory=lambda: [{ learning_rate_schedule: List = field(default_factory=lambda: [{
"Type": "Step", # 学习率调度类型。"Step"表示在指定迭代次数后将学习率乘以因子 "Type": "Step", # 学习率调度类型。"Step"表示在指定迭代次数后将学习率乘以因子
"Initial": 0.01, "Initial": 0.01,
"Interval": 2000, "Interval": 2000,
"Factor": 0.7 "Factor": 0.5
}]) }])
min_lr: float = 1e-5 min_lr: float = 1e-5
weight_decay: float = 0.0001 weight_decay: float = 0.0001

7
brep2sdf/data/pre_process_by_mesh.py

@ -25,7 +25,7 @@ from OCC.Core.StlAPI import StlAPI_Writer
from OCC.Core.BRepAdaptor import BRepAdaptor_Surface from OCC.Core.BRepAdaptor import BRepAdaptor_Surface
from OCC.Core.gp import gp_Pnt, gp_Vec from OCC.Core.gp import gp_Pnt, gp_Vec
from brep2sdf.data.sampler import sample_sdf_points_and_normals, sample_face_points_brep, sample_edge_points_brep,sample_zero_surface_points_and_normals from brep2sdf.data.sampler import sample_sdf_points_and_normals, sample_face_points_brep, sample_edge_points_brep,sample_zero_surface_points_and_normals,sample_grid
from brep2sdf.data.data import check_data_format from brep2sdf.data.data import check_data_format
from brep2sdf.data.utils import get_bbox, normalize, get_adjacency_info,batch_compute_normals from brep2sdf.data.utils import get_bbox, normalize, get_adjacency_info,batch_compute_normals
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
@ -430,12 +430,17 @@ def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False):
if trimesh_mesh_ncs is not None: if trimesh_mesh_ncs is not None:
# 调用封装的函数,传递固定数量参数 # 调用封装的函数,传递固定数量参数
logger.debug("采样 SDF 点和法线...") logger.debug("采样 SDF 点和法线...")
'''
data['sampled_points_normals_sdf'] = sample_sdf_points_and_normals( data['sampled_points_normals_sdf'] = sample_sdf_points_and_normals(
trimesh_mesh_ncs=trimesh_mesh_ncs, trimesh_mesh_ncs=trimesh_mesh_ncs,
surf_bbox_ncs=data['surf_bbox_ncs'], surf_bbox_ncs=data['surf_bbox_ncs'],
num_sdf_samples=50000, # <-- 传递固定数量 num_sdf_samples=50000, # <-- 传递固定数量
sdf_sampling_std_dev=0.0001 sdf_sampling_std_dev=0.0001
) )
'''
data['sampled_points_normals_sdf'] = sample_grid(
trimesh_mesh_ncs=trimesh_mesh_ncs,
)
else: else:
logger.warning("请求了 SDF 点采样,但 Trimesh 加载失败。") logger.warning("请求了 SDF 点采样,但 Trimesh 加载失败。")
return data return data

33
brep2sdf/data/sampler.py

@ -9,7 +9,8 @@ from typing import Optional
import numpy as np import numpy as np
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
import trimesh import trimesh
from trimesh.proximity import ProximityQuery from trimesh.proximity import ProximityQuery, signed_distance, closest_point
# 导入OpenCASCADE相关库 # 导入OpenCASCADE相关库
@ -430,4 +431,34 @@ def sample_sdf_points_and_normals(
return None return None
except Exception as e: except Exception as e:
logger.error(f"计算 SDF 或法线时失败: {str(e)}") logger.error(f"计算 SDF 或法线时失败: {str(e)}")
return None
def sample_grid(trimesh_mesh_ncs: trimesh.Trimesh):
"""
np.ndarray | None: 形状为 (N, 7) 的数组 [x, y, z, nx, ny, nz, sdf]
如果采样或计算失败则返回 None
"""
grid_size = 2**5 + 1
start = -0.5
end = 0.5
x = np.linspace(start, end, grid_size)
y = np.linspace(start, end, grid_size)
z = np.linspace(start, end, grid_size)
xx, yy, zz = np.meshgrid(x, y, z, indexing='ij')
points = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1)
try:
# Step 1: Compute signed distance
sdf = signed_distance(trimesh_mesh_ncs, points)
# Step 2: Compute normals via closest face
_, _, face_indices = closest_point(trimesh_mesh_ncs, points)
normals = trimesh_mesh_ncs.face_normals[face_indices]
# Step 3: Concatenate into (N, 7)
data = np.hstack([points, normals, sdf[:, None]])
return data
except Exception as e:
print(f"Error computing SDF or normals: {e}")
return None return None

136
brep2sdf/eval_pos.py

@ -11,16 +11,19 @@ from brep2sdf.utils.logger import logger
# 全局变量用于保存采样点 # 全局变量用于保存采样点
GLOBAL_SAMPLED_POINTS = None GLOBAL_SAMPLED_POINTS = None
def sample_grid_points(mesh, nh_mesh, our_mesh, num_samples_iou): def sample_grid_points():
global GLOBAL_SAMPLED_POINTS global GLOBAL_SAMPLED_POINTS
if GLOBAL_SAMPLED_POINTS is not None: if GLOBAL_SAMPLED_POINTS is not None:
return GLOBAL_SAMPLED_POINTS return GLOBAL_SAMPLED_POINTS
# 从一个较大的空间范围采样点 # 从一个较大的空间范围采样点
bounds = np.vstack([mesh.bounds, nh_mesh.bounds, our_mesh.bounds]) grid_size = 2**4 + 1
min_bound = np.min(bounds, axis=0) start = -1
max_bound = np.max(bounds, axis=0) end = 1
points = np.random.uniform(min_bound, max_bound, (num_samples_iou, 3)) x = np.linspace(start, end, grid_size)
GLOBAL_SAMPLED_POINTS = points y = np.linspace(start, end, grid_size)
z = np.linspace(start, end, grid_size)
xx, yy, zz = np.meshgrid(x, y, z, indexing='ij')
points = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1)
return points return points
def position_loss(pred_sdfs: torch.Tensor) -> torch.Tensor: def position_loss(pred_sdfs: torch.Tensor) -> torch.Tensor:
@ -113,12 +116,17 @@ def nh(model_path, points):
logger.error(f"调用 NH 模型时出错: {e}") logger.error(f"调用 NH 模型时出错: {e}")
return None return None
def mine(model_path, points): def mine(model_path, points, only_bk=True):
# points 是 【-1,1】
points = points / 2
model = load_model(model_path) model = load_model(model_path)
if model is None: if model is None:
return None return None
try: try:
return model(points) if only_bk:
return model.forward_background(points) * 2
else:
return model(points) * 2
except Exception as e: except Exception as e:
logger.error(f"调用 mine 模型时出错: {e}") logger.error(f"调用 mine 模型时出错: {e}")
return None return None
@ -129,8 +137,10 @@ def run(name):
# 替换为实际的 obj 文件路径 # 替换为实际的 obj 文件路径
obj_file_path = f"/home/wch/brep2sdf/data/gt_mesh/{name}.obj" obj_file_path = f"/home/wch/brep2sdf/data/gt_mesh/{name}.obj"
model_path = f"/home/wch/brep2sdf/data/output_data/{name}.pt" model_path = f"/home/wch/brep2sdf/data/output_data/{name}.pt"
nh_model = f"/home/wch/NH-Rep/data_backup/output_data/extracted/output_data/{name}_0_50k_model_h.pt" #nh_model = f"/home/wch/NH-Rep/data_backup/output_data/extracted/output_data/{name}_0_50k_model_h.pt"
ply_nh = f"/home/wch/NH-Rep/data_backup/output_data/extracted/output_data/{name}_0_50k_model_h_nh.ply" #ply_nh = f"/home/wch/NH-Rep/data_backup/output_data/extracted/output_data/{name}_0_50k_model_h_nh.ply"
nh_model= f"/home/wch/NH-Rep/data/output_data/{name}_0_50k_model_h.pt"
ply_nh = f"/home/wch/NH-Rep/data/output_data/{name}_0_50k_model_h_nh.ply"
ply_our = f"/home/wch/brep2sdf/data/output_data/{name}.ply" ply_our = f"/home/wch/brep2sdf/data/output_data/{name}.ply"
npz_path = f"/home/wch/brep2sdf/data/output_data/{name}.xyz" npz_path = f"/home/wch/brep2sdf/data/output_data/{name}.xyz"
num_samples=4096 num_samples=4096
@ -160,7 +170,7 @@ def run(name):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
coordinates_tensor = torch.from_numpy(coordinates).float().to(device).requires_grad_(True) coordinates_tensor = torch.from_numpy(coordinates).float().to(device).requires_grad_(True)
sdf1 = nh(nh_model, coordinates_tensor) / 2 sdf1 = nh(nh_model, coordinates_tensor)
sdf2 = mine(model_path, coordinates_tensor) sdf2 = mine(model_path, coordinates_tensor)
loss1, loss2 = {}, {} loss1, loss2 = {}, {}
@ -198,6 +208,7 @@ def run(name):
# 从网格中采样点 # 从网格中采样点
nh_points = torch.from_numpy(nh_mesh.sample(num_samples)).float().to(device) nh_points = torch.from_numpy(nh_mesh.sample(num_samples)).float().to(device)
our_points = torch.from_numpy(our_mesh.sample(num_samples)).float().to(device) our_points = torch.from_numpy(our_mesh.sample(num_samples)).float().to(device)
print("成功读取ply采样点")
# 确保 coordinates 是 torch.Tensor 类型 # 确保 coordinates 是 torch.Tensor 类型
loss1["cd"] = chamfer_distance(coordinates_tensor, nh_points).item() loss1["cd"] = chamfer_distance(coordinates_tensor, nh_points).item()
@ -205,66 +216,81 @@ def run(name):
loss1["hd"] = hausdorff_distance(coordinates_tensor, nh_points).item() loss1["hd"] = hausdorff_distance(coordinates_tensor, nh_points).item()
loss2["hd"] = hausdorff_distance(coordinates_tensor, our_points).item() loss2["hd"] = hausdorff_distance(coordinates_tensor, our_points).item()
print("成功计算cd hd")
# fea # fea
data = load_brep_file(npz_path) data = load_brep_file(npz_path)
sampled_pnts=prepare_sdf_data(data["surf_ncs"],normals=data["surf_pnt_normals"],max_points=num_samples) sampled_pnts=prepare_sdf_data(data["surf_ncs"],normals=data["surf_pnt_normals"],max_points=num_samples)
# 展平处理
flattened_pnts = sampled_pnts.flatten()
# 修改此处,使用 clone().detach() # 修改此处,使用 clone().detach()
if isinstance(flattened_pnts[0:3], torch.Tensor): if isinstance(sampled_pnts[:,0:3], torch.Tensor):
f_pnts = flattened_pnts[0:3].clone().detach().to(device).view(-1, 3) f_pnts = sampled_pnts[:,0:3].clone().detach().to(device)
else: else:
f_pnts = torch.from_numpy(flattened_pnts[0:3]).clone().detach().to(device).view(-1, 3) f_pnts = torch.from_numpy(sampled_pnts[:,0:3]).clone().detach().to(device)
if isinstance(flattened_pnts[3:6], torch.Tensor): if isinstance(sampled_pnts[:,3:6], torch.Tensor):
f_normals = flattened_pnts[3:6].clone().detach().to(device).view(-1, 3) f_normals = sampled_pnts[:,3:6].clone().detach().to(device)
else: else:
f_normals = torch.from_numpy(flattened_pnts[3:6]).clone().detach().to(device).view(-1, 3) f_normals = torch.from_numpy(sampled_pnts[:,3:6]).clone().detach().to(device)
#logger.info(f"normals norm:{f_normals.norm(2, dim=-1)}")
#logger.info(f"normals norm:{f_normals}")
# 检查 f_pnts 和 f_normals 的形状 # 检查 f_pnts 和 f_normals 的形状
if f_pnts.shape[-1] != 3 or f_normals.shape[-1] != 3: if f_pnts.shape[-1] != 3 or f_normals.shape[-1] != 3:
logger.error(f"f_pnts 形状: {f_pnts.shape}, f_normals 形状: {f_normals.shape},期望最后一维尺寸为 3") logger.error(f"f_pnts 形状: {f_pnts.shape}, f_normals 形状: {f_normals.shape},期望最后一维尺寸为 3")
return return
# 【-1,1】 归一化
f_pnts = f_pnts * 2
loss1["fcd"] = chamfer_distance(f_pnts, nh_points).item() loss1["fcd"] = chamfer_distance(f_pnts, nh_points).item()
loss2["fcd"] = chamfer_distance(f_pnts, our_points).item() loss2["fcd"] = chamfer_distance(f_pnts, our_points).item()
loss1["fae"] = hausdorff_distance(f_normals, nh_points).item()
loss2["fae"] = hausdorff_distance(f_normals, our_points).item()
f_pnts.to(device).requires_grad_(True)
sdf1 = nh(nh_model, f_pnts)
sdf2 = mine(model_path, f_pnts)
normal1 = gradient(f_pnts, sdf1)
normal2 = gradient(f_pnts, sdf2)
#logger.info(f"normal1s norm:{normal1.norm(2, dim=-1)}")
#logger.info(f"normal2s norm:{normal2.norm(2, dim=-1)}")
loss1["fae"] = average_normal_error(f_normals, normal1).item()
loss2["fae"] = average_normal_error(f_normals, normal2).item()
# 计算 IoU,从obj文件计算 # 计算 IoU,从obj文件计算
# ... existing code ... # ... existing code ...
# 计算 IoU,使用采样点方法 escape_iou = True
try: if not escape_iou:
num_samples_iou = 10000 # 采样点数量,可以根据需要调整 # 计算 IoU,使用采样点方法
# 调用封装的采样函数 try:
points = sample_grid_points(mesh, nh_mesh, our_mesh, num_samples_iou) print("计算iou")
num_samples_iou = 1000 # 采样点数量,可以根据需要调整
# 判断点是否在各个网格内部 # 调用封装的采样函数
inside_mesh = mesh.contains(points) points = sample_grid_points()
inside_nh = nh_mesh.contains(points)
inside_our = our_mesh.contains(points) # 判断点是否在各个网格内部
inside_mesh = mesh.contains(points)
# 计算 nh_mesh 与 mesh 的交集和并集 inside_nh = nh_mesh.contains(points)
intersection_nh = np.logical_and(inside_mesh, inside_nh).sum() inside_our = our_mesh.contains(points)
union_nh = np.logical_or(inside_mesh, inside_nh).sum()
# 计算 nh_mesh 与 mesh 的交集和并集
# 计算 our_mesh 与 mesh 的交集和并集 intersection_nh = np.logical_and(inside_mesh, inside_nh).sum()
intersection_our = np.logical_and(inside_mesh, inside_our).sum() union_nh = np.logical_or(inside_mesh, inside_nh).sum()
union_our = np.logical_or(inside_mesh, inside_our).sum()
# 计算 our_mesh 与 mesh 的交集和并集
# 计算 IoU intersection_our = np.logical_and(inside_mesh, inside_our).sum()
iou_nh = intersection_nh / union_nh if union_nh > 0 else 0.0 union_our = np.logical_or(inside_mesh, inside_our).sum()
iou_our = intersection_our / union_our if union_our > 0 else 0.0
# 计算 IoU
loss1["iou"] = iou_nh iou_nh = intersection_nh / union_nh if union_nh > 0 else 0.0
loss2["iou"] = iou_our iou_our = intersection_our / union_our if union_our > 0 else 0.0
except Exception as e: logger.debug(f"成功计算 IoU,NH 模型 IoU: {intersection_our}, Mine 模型 IoU: {union_our}")
print(f"使用采样点计算 IoU 时出错: {e}")
loss1["iou"] = iou_nh
loss2["iou"] = iou_our
print(f"IoU:{iou_our}")
except Exception as e:
logger.debug(f"使用采样点计算 IoU 时出错: {e}")
else:
loss1["iou"] = 0.0
loss2["iou"] = 0.0
return loss1, loss2 return loss1, loss2
@ -299,6 +325,8 @@ def main():
tl1_iou += l1["iou"] tl1_iou += l1["iou"]
tl2_iou += l2["iou"] tl2_iou += l2["iou"]
valid_count += 1 valid_count += 1
logger.debug(f"| {name}-NH 模型 | {l1['cd']} | {l1['hd']} | {l1['nae']} | {l1['fcd']} | {l1['fae']} | {l1['de']} | {l1['iou']} |")
logger.debug(f"| {name}-Mine 模型 | {l2['cd']} | {l2['hd']} | {l2['nae']} | {l2['fcd']} | {l2['fae']} | {l2['de']} | {l2['iou']} |")
if valid_count > 0: if valid_count > 0:
avg_l1_de = tl1_de / valid_count avg_l1_de = tl1_de / valid_count
avg_l1_nae = tl1_nae / valid_count avg_l1_nae = tl1_nae / valid_count

2
brep2sdf/networks/decoder.py

@ -65,7 +65,7 @@ class Decoder(nn.Module):
self.activation = nn.ReLU() self.activation = nn.ReLU()
else: else:
#siren #siren
self.activation = nn.ReLU() self.activation = Sine()
self.final_activation = nn.Tanh() self.final_activation = nn.Tanh()
def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor: def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor:

49
brep2sdf/networks/encoder.py

@ -1,5 +1,6 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np
from .octree import OctreeNode from .octree import OctreeNode
from .feature_volume import PatchFeatureVolume,SimpleFeatureEncoder from .feature_volume import PatchFeatureVolume,SimpleFeatureEncoder
@ -31,7 +32,7 @@ class Encoder(nn.Module):
resolutions = self._batch_calculate_resolution(volume_bboxs) resolutions = self._batch_calculate_resolution(volume_bboxs)
# 初始化多个特征体积 # 初始化多个特征体积
'''
self.feature_volumes = nn.ModuleList([ self.feature_volumes = nn.ModuleList([
PatchFeatureVolume( PatchFeatureVolume(
bbox=bbox, bbox=bbox,
@ -45,14 +46,14 @@ class Encoder(nn.Module):
input_dim=3, feature_dim=feature_dim input_dim=3, feature_dim=feature_dim
) for i, bbox in enumerate(volume_bboxs) ) for i, bbox in enumerate(volume_bboxs)
]) ])
'''
self.background = self.simple_encoder = nn.Sequential( self.background = self.simple_encoder = nn.Sequential(
nn.Linear(3, 256), nn.Linear(3, feature_dim)
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Linear(256, feature_dim)
) )
torch.nn.init.constant_(self.background[0].bias, 0.0)
torch.nn.init.normal_(self.background[0].weight, 0.0, np.sqrt(2) / np.sqrt(feature_dim))
# 添加几何初始化
print(f"Initialized {len(self.feature_volumes)} feature volumes with resolutions: {resolutions.tolist()}") print(f"Initialized {len(self.feature_volumes)} feature volumes with resolutions: {resolutions.tolist()}")
print(f"Model parameters memory usage: {sum(p.numel() * p.element_size() for p in self.parameters()) / 1024**2:.2f} MB") print(f"Model parameters memory usage: {sum(p.numel() * p.element_size() for p in self.parameters()) / 1024**2:.2f} MB")
@ -79,35 +80,34 @@ class Encoder(nn.Module):
return resolutions return resolutions
def forward(self, query_points: torch.Tensor, volume_indices: torch.Tensor) -> torch.Tensor: def forward(self, query_points: torch.Tensor, volume_indices_mask: torch.Tensor) -> torch.Tensor:
""" """
修改后的前向传播返回所有关联volume的特征矩阵 修改后的前向传播返回所有关联volume的特征矩阵
参数: 参数:
query_points: 查询点坐标 (B, 3) query_points: 查询点坐标 (B, 3)
volume_indices: 关联的volume索引矩阵 (B, P) volume_indices_mask: 关联的volume索引矩阵 (B, P)
返回: 返回:
特征张量 (B, P, D) 特征张量 (B, P, D)
""" """
batch_size, num_volumes = volume_indices.shape batch_size, num_volumes = volume_indices_mask.shape
all_features = torch.zeros(batch_size, num_volumes, self.feature_dim, all_features = torch.zeros(batch_size, num_volumes, self.feature_dim,
device=query_points.device) device=query_points.device)
background_features = self.background.forward(query_points) # (B, D) background_features = self.background.forward(query_points) # (B, D)
# 遍历每个volume索引 # 遍历每个volume索引
for vol_id, volume in enumerate(self.feature_volumes): for vol_id, volume in enumerate(self.feature_volumes):
current_indices = volume_indices[:, vol_id] mask = volume_indices_mask[:, vol_id].squeeze()
# 创建掩码 (B,) #logger.debug(f"mask:{mask},shape:{mask.shape},mask.any():{mask.any()}")
mask = (current_indices == vol_id)
if mask.any(): if mask.any():
# 获取对应volume的特征 (M, D) # 获取对应volume的特征 (M, D)
features = volume.forward(query_points[mask]) features = volume.forward(query_points[mask])
all_features[mask, vol_id] = 0.9 * background_features[mask] + 0.1 * features all_features[mask, vol_id] = 0.9 * background_features[mask] + 0.1 * features
#all_features[:, :] = background_features.unsqueeze(1)
return all_features return all_features
def forward_background(self, query_points: torch.Tensor) -> torch.Tensor: def forward_background(self, query_points: torch.Tensor) -> torch.Tensor:
""" """
修改后的前向传播返回所有关联volume的特征矩阵 修改后的前向传播返回所有关联volume的特征矩阵
@ -132,27 +132,10 @@ class Encoder(nn.Module):
""" """
# 获取 patch 特征 # 获取 patch 特征
patch_features = self.feature_volumes[patch_id].forward(surf_points) patch_features = self.feature_volumes[patch_id].forward(surf_points)
background_features = self.background.forward(surf_points) # (B, D)
#dot = make_dot(patch_features, params=dict(self.feature_volumes.named_parameters())) #dot = make_dot(patch_features, params=dict(self.feature_volumes.named_parameters()))
#dot.render("feature_extraction", format="png") # 将计算图保存为 PDF 文件 #dot.render("feature_extraction", format="png") # 将计算图保存为 PDF 文件
return patch_features return 0.9 * background_features + 0.1 * patch_features
def _optimized_trilinear(self, points, bboxes, features)-> torch.Tensor:
"""优化后的向量化三线性插值"""
# 添加显式类型转换确保计算稳定性
min_coords = bboxes[..., :3].to(torch.float32)
max_coords = bboxes[..., 3:].to(torch.float32)
normalized = (points - min_coords) / (max_coords - min_coords + 1e-8)
# 使用爱因斯坦求和代替分步计算
wx = torch.stack([1 - normalized[...,0], normalized[...,0]], -1) # (B,2)
wy = torch.stack([1 - normalized[...,1], normalized[...,1]], -1)
wz = torch.stack([1 - normalized[...,2], normalized[...,2]], -1)
# 合并所有计算步骤 (B,8,D) * (B,8,1) -> (B,D)
return torch.einsum('bcd,bc->bd', features,
torch.einsum('bi,bj,bk->bijk', wx, wy, wz).view(-1,8))
# 原batch_trilinear_interpolation方法可以删除
def to(self, device): def to(self, device):
super().to(device) super().to(device)

2
brep2sdf/networks/feature_volume.py

@ -21,7 +21,7 @@ class PatchFeatureVolume(nn.Module):
# 初始化特征向量为很小的值,使用较小的标准差 # 初始化特征向量为很小的值,使用较小的标准差
self.feature_volume = nn.Parameter(torch.empty(resolution, resolution, resolution, feature_dim)) self.feature_volume = nn.Parameter(torch.empty(resolution, resolution, resolution, feature_dim))
torch.nn.init.normal_(self.feature_volume, mean=0.0, std=0.01) # 标准差设置为 0.01,可根据需要调整 torch.nn.init.normal_(self.feature_volume, mean=0.0, std=torch.sqrt(torch.tensor(2.0)) / torch.sqrt(torch.tensor(float(feature_dim)))) # 标准差设置为 0.01,可根据需要调整
def _expand_bbox(self, min_coords, max_coords, ratio): def _expand_bbox(self, min_coords, max_coords, ratio):
# 扩展包围盒范围 # 扩展包围盒范围

5
brep2sdf/networks/loss.py

@ -78,6 +78,7 @@ class LossManager:
# 计算法线损失 # 计算法线损失
normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean() # 计算法线损失 normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean() # 计算法线损失
#logger.info(f"normals norm:{branch_grad.norm(2, dim=-1)}")
return normals_loss # 返回加权后的法线损失 return normals_loss # 返回加权后的法线损失
def eikonal_loss(self, nonmnfld_pnts, nonmnfld_pred): def eikonal_loss(self, nonmnfld_pnts, nonmnfld_pred):
@ -210,7 +211,7 @@ class LossManager:
manifold_loss = self.position_loss(mnfld_pred, gt_sdfs) manifold_loss = self.position_loss(mnfld_pred, gt_sdfs)
# 计算法线损失 # 计算法线损失
#normals_loss = self.normals_loss(normals, mnfld_pnts, mnfld_pred) normals_loss = self.normals_loss(normals, mnfld_pnts, mnfld_pred)
#logger.gpu_memory_stats("计算法线损失后") #logger.gpu_memory_stats("计算法线损失后")
@ -224,7 +225,7 @@ class LossManager:
# 汇总损失 # 汇总损失
loss_details = { loss_details = {
"manifold": self.weights["manifold"] * manifold_loss, "manifold": self.weights["manifold"] * manifold_loss,
#"normals": self.weights["normals"] * normals_loss "normals": self.weights["normals"] * normals_loss
} }
# 计算总损失 # 计算总损失

8
brep2sdf/networks/network.py

@ -56,7 +56,7 @@ class Net(nn.Module):
def __init__(self, def __init__(self,
octree, octree,
volume_bboxs, volume_bboxs,
feature_dim=8, feature_dim,
decoder_output_dim=1, decoder_output_dim=1,
decoder_hidden_dim=512, decoder_hidden_dim=512,
decoder_num_layers=6, decoder_num_layers=6,
@ -78,7 +78,7 @@ class Net(nn.Module):
d_in=feature_dim, d_in=feature_dim,
dims_sdf=[decoder_hidden_dim] * decoder_num_layers, dims_sdf=[decoder_hidden_dim] * decoder_num_layers,
#skip_in=(3,), #skip_in=(3,),
geometric_init=False, geometric_init=True,
beta=5 beta=5
) )
@ -135,6 +135,7 @@ class Net(nn.Module):
#logger.gpu_memory_stats("decoder farward后") #logger.gpu_memory_stats("decoder farward后")
#logger.debug("step combine") #logger.debug("step combine")
return f_i[:,0]
return self.process_sdf(f_i, face_indices_mask, operator) return self.process_sdf(f_i, face_indices_mask, operator)
@torch.jit.export @torch.jit.export
@ -167,13 +168,14 @@ class Net(nn.Module):
# 批量查询所有点的索引和bbox # 批量查询所有点的索引和bbox
#logger.debug("step encode") #logger.debug("step encode")
# 编码 # 编码
feature_vectors = self.encoder.forward(query_points,face_indices_mask) feature_vectors = self.encoder(query_points,face_indices_mask)
#print("feature_vector:", feature_vectors.shape) #print("feature_vector:", feature_vectors.shape)
# 解码 # 解码
f_i = self.decoder(feature_vectors) # (B, P) f_i = self.decoder(feature_vectors) # (B, P)
#logger.gpu_memory_stats("decoder farward后") #logger.gpu_memory_stats("decoder farward后")
#logger.debug("step combine") #logger.debug("step combine")
return f_i[:,0]
return self.process_sdf(f_i, face_indices_mask, operator) return self.process_sdf(f_i, face_indices_mask, operator)
@torch.jit.ignore @torch.jit.ignore

321
brep2sdf/test.py

@ -1,112 +1,247 @@
import trimesh
import numpy as np import numpy as np
from brep2sdf.data.sampler import sample_zero_surface_points_and_normals
from brep2sdf.networks.network import gradient
import torch import torch
import logging import argparse
from skimage import measure
# 配置日志记录 import time
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') import trimesh
import pickle
def position_loss(pred_sdfs: torch.Tensor) -> torch.Tensor: from trimesh.proximity import ProximityQuery, signed_distance, closest_point
"""位置损失函数""" from brep2sdf.utils.logger import logger
# 保持梯度流
squared_diff = torch.pow(pred_sdfs, 2)
return torch.mean(squared_diff)
def average_normal_error(normals1: torch.Tensor, normals2: torch.Tensor) -> torch.Tensor: def sample_from_obj(obj_file, num_samples):
""" """
计算平均法向量误差 (NAE) 从OBJ文件中采样点
:param normals1: 形状为 (B, 3) 的法向量张量 :param obj_file: OBJ文件路径
:param normals2: 形状为 (B, 3) 的法向量张 :param num_samples: 采样点数
:return: NAE :return: 采样点数组 (N, 3)
""" """
dot_products = torch.sum(normals1 * normals2, dim=-1) mesh = trimesh.load(obj_file)
absolute_dot_products = torch.abs(dot_products) points, _ = mesh.sample(num_samples, return_index=True)
angle_errors = 1 - absolute_dot_products return points
return torch.mean(angle_errors)
def def obj_to_sdf(obj_file, depth, box_size):
"""
将OBJ文件转换为SDF网格
:param obj_file: OBJ文件路径
:param depth: 网格深度
:param box_size: 边界框大小
:return: SDF值数组和网格坐标
"""
# 创建网格
points, xx, yy, zz = create_grid(depth, box_size)
print(1)
# 计算SDF
mesh = trimesh.load(obj_file)
print(2)
sdf = signed_distance(mesh, points)
print(3)
sdf_grid = sdf.reshape(xx.shape)
print(4)
return sdf_grid, xx, yy, zz
def extract_surface(sdf, xx, yy, zz, method='MC', bbox_size=1.0, feature_angle=30.0, voxel_size=0.01):
"""
提取零表面
:param sdf: SDF值三维数组
:param xx/yy/zz: 网格坐标
:param method: 提取方法MC: Marching Cubes
:return: 顶点和面片
"""
if method == 'MC':
verts, faces, _, _ = measure.marching_cubes(sdf, level=0)
elif method == 'EMC':
from iso_algorithms import enhanced_marching_cubes
verts, faces = enhanced_marching_cubes(
sdf,
feature_angle=feature_angle,
gradient_direction='descent'
)
elif method == 'DC':
from iso_algorithms import dual_contouring
verts, faces = dual_contouring(sdf, voxel_size=voxel_size)
else:
raise ValueError(f"不支持的算法: {method}")
# 新增顶点后处理
verts = (verts - sdf.shape[0]//2) / (sdf.shape[0]//2) / 2 * bbox_size # 归一化到[-1,1]
return verts, faces
def save_ply(vertices, faces, filename):
"""
保存顶点和面片为PLY文件
:param vertices: 顶点数组 (N, 3)
:param faces: 面片数组 (M, 3)
:param filename: 输出文件名
"""
with open(filename, 'w') as f:
f.write("ply\n")
f.write("format ascii 1.0\n")
f.write(f"element vertex {len(vertices)}\n")
f.write("property float x\n")
f.write("property float y\n")
f.write("property float z\n")
f.write(f"element face {len(faces)}\n")
f.write("property list uchar int vertex_indices\n")
f.write("end_header\n")
for v in vertices:
f.write(f"{v[0]} {v[1]} {v[2]}\n")
for face in faces:
f.write(f"3 {face[0]} {face[1]} {face[2]}\n")
def compute_iou(sdf1, sdf2, threshold=0.0):
"""
计算两个 SDF 之间的 IoU
:param sdf1: 第一个 SDF 数组
:param sdf2: 第二个 SDF 数组
:param threshold: 阈值用于判断点是否在表面内
:return: IoU
"""
if sdf1.shape != sdf2.shape:
raise ValueError("sdf1 和 sdf2 的形状必须一致")
inside1 = sdf1 <= threshold
inside2 = sdf2 <= threshold
intersection = np.logical_and(inside1, inside2).sum()
union = np.logical_or(inside1, inside2).sum()
if union == 0:
logger.warning("union 为 0,无法计算 IoU")
return 0.0
iou = intersection / union
return iou
def create_grid(depth, box_size):
"""
创建规则的三维网格
:param depth: 网格深度分辨率
:param box_size: 边界框大小
:return: 网格点坐标和网格坐标数组
"""
grid_size = 2**depth + 1
x = np.linspace(-box_size/2, box_size/2, grid_size)
y = np.linspace(-box_size/2, box_size/2, grid_size)
z = np.linspace(-box_size/2, box_size/2, grid_size)
xx, yy, zz = np.meshgrid(x, y, z, indexing='ij')
points = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1)
return points, xx, yy, zz
def load_brep_file(brep_path):
with open(brep_path, 'rb') as f:
brep_raw = pickle.load(f)
return brep_raw
def test():
data = load_brep_file("/home/wch/brep2sdf/data/output_data/00000003.xyz")
surfs = data["sampled_points_normals_sdf"] # 获取采样点数据
points = surfs[:, :3] # 提取点坐标
normals = surfs[:, 3:6] # 提取法向量
sdf_values = surfs[:, 6] # 提取SDF值
print(len(points))
print(f"points max: {np.max(points)}") # 打印points的最大值
print(f"sdf_values max: {np.max(sdf_values)}") # 打印sdf_values的最大值
# 将点坐标和SDF值转换为网格格式
grid_size = int(np.cbrt(len(points))) # 假设采样点是立方体网格
sdf_grid = sdf_values.reshape((grid_size, grid_size, grid_size))
# 使用Marching Cubes提取零表面
verts, faces, _, _ = measure.marching_cubes(sdf_grid, level=0)
# 打印verts和faces的最大值
print(f"verts max: {np.max(verts)}")
print(f"faces max: {np.max(faces)}")
# 保存提取的网格
save_ply(verts, faces, "output_mc.ply")
print("Marching Cubes 提取的网格已保存为 output_mc.ply")
def sample_grid(trimesh_mesh_ncs: trimesh.Trimesh):
"""
np.ndarray | None: 形状为 (N, 7) 的数组 [x, y, z, nx, ny, nz, sdf]
如果采样或计算失败则返回 None
"""
grid_size = 2**5 + 1
start = -1
end = 1
x = np.linspace(start, end, grid_size)
y = np.linspace(start, end, grid_size)
z = np.linspace(start, end, grid_size)
xx, yy, zz = np.meshgrid(x, y, z, indexing='ij')
points = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1)
# ==========
def load_model(model_path):
"""加载模型的通用函数"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try: try:
model = torch.jit.load(model_path).to(device) # Step 1: Compute signed distance
logging.info(f"成功加载模型: {model_path}") sdf = signed_distance(trimesh_mesh_ncs, points)
return model print(sdf)
except Exception as e:
logging.error(f"加载模型 {model_path} 时出错: {e}")
return None
# Step 2: Compute normals via closest face
_, _, face_indices = closest_point(trimesh_mesh_ncs, points)
normals = trimesh_mesh_ncs.face_normals[face_indices]
#========== # Step 3: Concatenate into (N, 7)
def nh(model_path, points): data = np.hstack([points, normals, sdf[:, None]])
model = load_model(model_path) return data
if model is None:
return None
try:
return model(points)
except Exception as e:
logging.error(f"调用 NH 模型时出错: {e}")
return None
def mine(model_path, points):
model = load_model(model_path)
if model is None:
return None
try:
return model.forward_background(points)
except Exception as e: except Exception as e:
logging.error(f"调用 mine 模型时出错: {e}") print(f"Error computing SDF or normals: {e}")
return None return None
def main(): def test2(obj_file):
# 替换为实际的 obj 文件路径 mesh = trimesh.load(obj_file)
obj_file_path = "/home/wch/brep2sdf/data/gt_mesh/00000031.obj" surfs = sample_grid(mesh)
model_path = "/home/wch/brep2sdf/data/output_data/00000031.pt" points = surfs[:, :3] # 提取点坐标
nh_model = "/home/wch/NH-Rep/data/output_data/00000031_0_50k_model_h.pt" normals = surfs[:, 3:6] # 提取法向量
sdf_values = surfs[:, 6] # 提取SDF值
try: # 将点坐标和SDF值转换为网格格式
# 读取 obj 文件 grid_size = int(np.cbrt(len(points))) # 假设采样点是立方体网格
mesh = trimesh.load_mesh(obj_file_path) sdf_grid = sdf_values.reshape((grid_size, grid_size, grid_size))
logging.info(f"成功读取 OBJ 文件: {obj_file_path}")
except Exception as e:
logging.error(f"读取 OBJ 文件 {obj_file_path} 时出错: {e}")
return
try: # 使用Marching Cubes提取零表面
# 调用采样函数 verts, faces, _, _ = measure.marching_cubes(sdf_grid, level=0)
result1 = sample_zero_surface_points_and_normals(mesh, num_samples=4096)
if result1 is None: # 打印verts和faces的最大值
logging.error("采样失败,返回 None") print(f"verts max: {np.max(verts)}")
return print(f"faces max: {np.max(faces)}")
# 提取前 3 列作为坐标点
coordinates = result1[:, :3]
# 将 ndarray 转换为 Tensor 并移动到设备上
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
coordinates_tensor = torch.from_numpy(coordinates).float().to(device)
sdf1 = nh(nh_model, coordinates_tensor) / 2
sdf2 = mine(model_path, coordinates_tensor)
if sdf1 is not None and sdf2 is not None:
loss1_ = position_loss(sdf1)
loss2 = position_loss(sdf2)
logging.info(f"NH 模型位置损失: {loss1.item()}")
logging.info(f"Mine 模型位置损失: {loss2.item()}")
gt_normal = result1[:, 3:6]
normal1 = gradient(coordinates, sdf1)
normal2 = gradient(coordinates, sdf2)
nae1=average_normal_error(gt_normal, normal1)
nae2=average_normal_error(gt_normal, normal2)
else:
logging.error("无法计算损失,SDF 结果为 None")
except Exception as e: # 保存提取的网格
logging.error(f"处理过程中出现错误: {e}") save_ply(verts, faces, "output.ply")
print("Marching Cubes 提取的网格已保存为 output.ply")
def main():
parser = argparse.ArgumentParser(description='IsoSurface Generator')
parser.add_argument('-i', '--input', type=str, required=True, help='Input OBJ file')
parser.add_argument('-o', '--output', type=str, required=True, help='Output mesh file (.ply)')
parser.add_argument('--depth', type=int, default=5, help='网格深度(分辨率)')
parser.add_argument('--box_size', type=float, default=2.0, help='边界框大小')
parser.add_argument('--method', type=str, default='MC', choices=['MC', 'EMC', 'DC'], help='表面提取方法')
args = parser.parse_args()
# 将OBJ转换为SDF网格
sdf_grid, xx, yy, zz = obj_to_sdf(args.input, args.depth, args.box_size)
# 提取表面
print("Extracting surface...")
start_time = time.time()
verts, faces = extract_surface(sdf_grid, xx, yy, zz, args.method, args.box_size)
print(f"Surface extraction took {time.time() - start_time:.2f} seconds")
# 保存网格
save_ply(verts, faces, args.output)
print(f"Mesh saved to {args.output}")
# 计算 IoU
sdf1 = sdf_grid # 第一个 SDF 网格
sdf2 = sdf_grid # 第二个 SDF 网格(这里假设使用相同的网格进行计算)
iou = compute_iou(sdf1, sdf2)
print(f"IoU: {iou}")
if __name__ == "__main__": if __name__ == "__main__":
main() #main()
#test()
test2("/home/wch/brep2sdf/data/gt_mesh/00000003.obj")
# python test.py -i /home/wch/brep2sdf/data/gt_mesh/00000003.obj -o output.ply --depth 6 --box_size 2.0 --method MC

25
brep2sdf/train.py

@ -138,7 +138,7 @@ class Trainer:
self.model = Net( self.model = Net(
octree=self.root, octree=self.root,
volume_bboxs=surf_bbox, volume_bboxs=surf_bbox,
feature_dim=64 feature_dim=3
).to(self.device) ).to(self.device)
logger.gpu_memory_stats("模型初始化后") logger.gpu_memory_stats("模型初始化后")
@ -594,7 +594,7 @@ class Trainer:
self.model.train() self.model.train()
total_loss = 0.0 total_loss = 0.0
step = 0 # 如果你的训练是分批次的,这里应该用批次索引 step = 0 # 如果你的训练是分批次的,这里应该用批次索引
batch_size = 4096 # 设置合适的batch大小 batch_size = 4096*5 # 设置合适的batch大小
# 数据处理 # 数据处理
# manfld # manfld
@ -608,7 +608,7 @@ class Trainer:
_, _mnfld_face_indices_mask, _mnfld_operator = self.root.forward(_mnfld_pnts) _, _mnfld_face_indices_mask, _mnfld_operator = self.root.forward(_mnfld_pnts)
# 生成非流形点 # 生成非流形点
_nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals, 0.1) _nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals, 0.01)
_, _nonmnfld_face_indices_mask, _nonmnfld_operator = self.root.forward(_nonmnfld_pnts) _, _nonmnfld_face_indices_mask, _nonmnfld_operator = self.root.forward(_nonmnfld_pnts)
# 更新缓存 # 更新缓存
@ -620,6 +620,7 @@ class Trainer:
"nonmnfld_face_indices_mask": _nonmnfld_face_indices_mask, "nonmnfld_face_indices_mask": _nonmnfld_face_indices_mask,
"nonmnfld_operator": _nonmnfld_operator "nonmnfld_operator": _nonmnfld_operator
} }
logger.gpu_memory_stats("缓存后")
else: else:
# 从缓存中读取数据 # 从缓存中读取数据
_mnfld_face_indices_mask = self.cached_train_data["mnfld_face_indices_mask"] _mnfld_face_indices_mask = self.cached_train_data["mnfld_face_indices_mask"]
@ -684,7 +685,7 @@ class Trainer:
# 检查法线和带梯度的点 # 检查法线和带梯度的点
#if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
#if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
logger.gpu_memory_stats("计算损失前") #logger.gpu_memory_stats("计算损失前")
loss, loss_details = self.loss_manager.compute_loss( loss, loss_details = self.loss_manager.compute_loss(
mnfld_pnts, mnfld_pnts,
nonmnfld_pnts, nonmnfld_pnts,
@ -694,6 +695,7 @@ class Trainer:
nonmnfld_pred, nonmnfld_pred,
psdf psdf
) )
#logger.gpu_memory_stats("计算损失后")
else: else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
@ -707,7 +709,7 @@ class Trainer:
except Exception as loss_e: except Exception as loss_e:
logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True) logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
return float('inf') # 如果计算出错,停止这个epoch return float('inf') # 如果计算出错,停止这个epoch
logger.gpu_memory_stats("损失计算后") #logger.gpu_memory_stats("损失计算后")
# --- 反向传播和优化 --- # --- 反向传播和优化 ---
try: try:
@ -738,7 +740,7 @@ class Trainer:
logger.info(f'Train Epoch: {epoch:4d}]\t' logger.info(f'Train Epoch: {epoch:4d}]\t'
f'Loss: {total_loss:.6f}') f'Loss: {total_loss:.6f}')
if loss_details: logger.info(f"Loss Details: {loss_details}") if loss_details: logger.info(f"Loss Details: {loss_details}")
self.validate(epoch,total_loss) #self.validate(epoch,total_loss)
return total_loss # 对于单批次训练,直接返回当前损失 return total_loss # 对于单批次训练,直接返回当前损失
@ -952,8 +954,17 @@ class Trainer:
sdfs= model(example_input) sdfs= model(example_input)
logger.debug(f"sdfs:{sdfs}") logger.debug(f"sdfs:{sdfs}")
def _tracing_model_by_script(self): def _tracing_model_by_script(self,if_best=True):
"""保存模型""" """保存模型"""
if if_best and self.best_loss < float('inf'):
checkpoint_dir = os.path.join(
self.config.train.checkpoint_dir,
self.model_name
)
checkpoint_path = os.path.join(checkpoint_dir, f"best.pth")
start_epoch = self._load_checkpoint(checkpoint_path)
logger.info(f"Loaded best model from {checkpoint_path}, epoch: {start_epoch}")
self.model.eval() self.model.eval()
# 确保模型中的所有逻辑都兼容 TorchScript # 确保模型中的所有逻辑都兼容 TorchScript
scripted_model = torch.jit.script(self.model) scripted_model = torch.jit.script(self.model)

Loading…
Cancel
Save