Browse Source

基本可以跑通,除了八叉树部分保存还有问题

final
mckay 2 months ago
parent
commit
2ac55ea3df
  1. 5
      .gitignore
  2. 207
      brep2sdf/IsoSurfacing.py
  3. 67
      brep2sdf/batch_train.py
  4. 4
      brep2sdf/config/default_config.py
  5. 280
      brep2sdf/evaluation.py
  6. 4
      brep2sdf/train.py

5
.gitignore

@ -171,6 +171,8 @@ test_data/
logs/
wandb/
*.pth
*.pt
*.csv
checkpoints/
data/gt_mesh
@ -178,4 +180,5 @@ data/gt_point
data/step
data/input_data
data/output_data
data/name_list.txt
data/name_list.txt
data/scripts/IsoSurfacing

207
brep2sdf/IsoSurfacing.py

@ -1,66 +1,161 @@
import os
import subprocess
from tqdm import tqdm
import numpy as np
import torch
import argparse
from skimage import measure
import time
import trimesh
# 使用一个 c++ 程序处理,这里只是调用,注意要在docker里面运行。宿主机编译失败
def create_grid(depth, box_size):
"""
创建三维网格点
:param depth: 网格深度决定分辨率
:param box_size: 边界框大小边长
:return: 网格点数组和坐标网格
"""
grid_size = 2**depth + 1
start = -box_size / 2
end = box_size / 2
x = np.linspace(start, end, grid_size)
y = np.linspace(start, end, grid_size)
z = np.linspace(start, end, grid_size)
xx, yy, zz = np.meshgrid(x, y, z, indexing='ij')
points = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1)
return points, xx, yy, zz
def predict_sdf(model, points, device):
"""
使用模型预测SDF值
:param model: PyTorch模型
:param points: 输入点坐标 (N, 3)
:param device: 设备CPU/GPU
:return: SDF值数组 (N,)
"""
points_t = torch.from_numpy(points).float().to(device)
with torch.no_grad():
sdf = model(points_t).cpu().numpy().flatten()
return sdf
def extract_surface(sdf, xx, yy, zz, method='MC'):
"""
提取零表面
:param sdf: SDF值三维数组
:param xx/yy/zz: 网格坐标
:param method: 提取方法MC: Marching Cubes
:return: 顶点和面片
"""
if method == 'MC':
verts, faces, _, _ = measure.marching_cubes(sdf, level=0)
else:
raise NotImplementedError("仅支持Marching Cubes方法")
return verts, faces
def save_ply(vertices, faces, filename):
"""
保存顶点和面片为PLY文件
:param vertices: 顶点数组 (N, 3)
:param faces: 面片数组 (M, 3)
:param filename: 输出文件名
"""
with open(filename, 'w') as f:
f.write("ply\n")
f.write("format ascii 1.0\n")
f.write(f"element vertex {len(vertices)}\n")
f.write("property float x\n")
f.write("property float y\n")
f.write("property float z\n")
f.write(f"element face {len(faces)}\n")
f.write("property list uchar int vertex_indices\n")
f.write("end_header\n")
for v in vertices:
f.write(f"{v[0]} {v[1]} {v[2]}\n")
for face in faces:
f.write(f"3 {face[0]} {face[1]} {face[2]}\n")
def compute_sdf_error(model, gt_mesh, res, device):
"""
计算预测SDF与GT网格的误差
:param model: PyTorch模型
:param gt_mesh: GT网格Trimesh格式
:param res: 误差计算分辨率
:param device: 设备
:return: 平均误差和最大误差
"""
# 生成均匀采样点
box_size = max(gt_mesh.extents)
start = -box_size / 2
end = box_size / 2
x = np.linspace(start, end, res)
y = np.linspace(start, end, res)
z = np.linspace(start, end, res)
points = np.array(np.meshgrid(x, y, z)).T.reshape(-1, 3)
# 预测SDF
pred_sdf = predict_sdf(model, points, device)
# 计算GT距离
distances = gt_mesh.nearest.on_surface(points)[1]
gt_sdf = np.abs(distances)
# 计算误差
abs_error = np.abs(pred_sdf - gt_sdf)
rel_error = abs_error / (np.abs(gt_sdf) + 1e-9)
avg_abs = np.mean(abs_error)
avg_rel = np.mean(rel_error)
max_abs = np.max(abs_error)
max_rel = np.max(rel_error)
return avg_abs, avg_rel, max_abs, max_rel
def main():
# 定义 STEP 文件目录和名称列表文件路径
output_data_root_dir = "/workspace/home/wch/brep2sdf/data/output_data"
name_list_path = "/workspace/home/wch/brep2sdf/data/name_list.txt"
parser = argparse.ArgumentParser(description='IsoSurface Generator')
parser.add_argument('-i', '--input', type=str, required=True, help='Input model file (.pt)')
parser.add_argument('-o', '--output', type=str, required=True, help='Output mesh file (.ply)')
parser.add_argument('--depth', type=int, default=7, help='网格深度(分辨率)')
parser.add_argument('--box_size', type=float, default=2.0, help='边界框大小')
parser.add_argument('--method', type=str, default='MC', choices=['MC'], help='表面提取方法')
parser.add_argument('--use-gpu', action='store_true', help='使用GPU')
parser.add_argument('--compare', type=str, help='GT网格文件(.ply)')
parser.add_argument('--compres', type=int, default=32, help='误差计算分辨率')
args = parser.parse_args()
# 设置设备
device = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# 读取名称列表
try:
with open(name_list_path, 'r') as f:
names = [line.strip() for line in f if line.strip()] # 去除空行
except FileNotFoundError:
print(f"Error: File '{name_list_path}' not found.")
return
except Exception as e:
print(f"Error reading file '{name_list_path}': {e}")
return
model = torch.jit.load(args.input).to(device)
#model = torch.load(args.input).to(device)
model.eval()
# 遍历名称列表并处理每个 STEP 文件
for name in tqdm(names, desc="ISOsurfing pt files"):
pt_file = os.path.join(output_data_root_dir, f"{name}.pt")
if not pt_file:
print(f"Warning: No pt files found in directory '{output_data_root_dir}'. Skipping...")
continue
# 创建网格并预测SDF
points, xx, yy, zz = create_grid(args.depth, args.box_size)
sdf = predict_sdf(model, points, device)
print(points.shape)
print(sdf.shape)
print(sdf)
sdf_grid = sdf.reshape(xx.shape)
# ./ISG_console_pytorch -i ./test/teaser.pt -o outputmesh.ply -v -0.01 -d 8
# 构造子进程命令
command = [
"python", "/workspace/home/wch/brep2sdf/data/scripts/IsoSurfacing/build/App/console_pytorch/ISG_console_pytorch",
"-i", pt_file, # 使用当前遍历的pt文件
"-o", os.path.join(output_data_root_dir, f"{name}_outputmesh.ply"), # 动态生成输出文件路径
"-v", "-0.01", "-d", "8"
]
# 提取表面
print("Extracting surface...")
start_time = time.time()
verts, faces = extract_surface(sdf_grid, xx, yy, zz, args.method)
print(f"Surface extraction took {time.time() - start_time:.2f} seconds")
# 调用子进程运行命令
try:
result = subprocess.run(
command,
capture_output=True,
text=True,
check=True # 如果返回非零退出码,则抛出 CalledProcessError
)
print(f"Successfully processed '{name}'")
print("STDOUT:", result.stdout)
print("STDERR:", result.stderr)
except subprocess.CalledProcessError as e:
print(f"Error processing '{name}': Command failed with return code {e.returncode}")
print(f"Command: {e.cmd}")
print(f"Error type: {type(e).__name__}")
print("STDOUT:", e.stdout)
print("STDERR:", e.stderr)
print("Traceback:", e.__traceback__)
except Exception as e:
print(f"Unexpected error processing '{name}': {str(e)}")
print(f"Command: {command}")
print("Traceback:", traceback.format_exc())
# 保存网格
save_ply(verts, faces, args.output)
print(f"Mesh saved to {args.output}")
# 误差评估(可选)
if args.compare:
print("Computing SDF error...")
gt_mesh = trimesh.load(args.compare)
avg_abs, avg_rel, max_abs, max_rel = compute_sdf_error(
model, gt_mesh, args.compres, device
)
print(f"Average Absolute Error: {avg_abs:.4f}")
print(f"Average Relative Error: {avg_rel:.4f}")
print(f"Max Absolute Error: {max_abs:.4f}")
print(f"Max Relative Error: {max_rel:.4f}")
if __name__ == '__main__':
if __name__ == "__main__":
main()

67
brep2sdf/batch_train.py

@ -0,0 +1,67 @@
import os
import subprocess
from tqdm import tqdm
def main():
# 定义 STEP 文件目录和名称列表文件路径
step_root_dir = "/home/wch/brep2sdf/data/step"
name_list_path = "/home/wch/brep2sdf/data/name_list.txt"
# 读取名称列表
try:
with open(name_list_path, 'r') as f:
names = [line.strip() for line in f if line.strip()] # 去除空行
except FileNotFoundError:
print(f"Error: File '{name_list_path}' not found.")
return
except Exception as e:
print(f"Error reading file '{name_list_path}': {e}")
return
# 遍历名称列表并处理每个 STEP 文件
for name in tqdm(names, desc="Processing STEP files"):
step_dir = os.path.join(step_root_dir, name)
# 动态生成 STEP 文件路径(假设只有一个文件)
step_files = [
os.path.join(step_dir, f)
for f in os.listdir(step_dir)
if f.endswith(".step") and f.startswith(name)
]
if not step_files:
print(f"Warning: No STEP files found in directory '{step_dir}'. Skipping...")
continue
# 假设我们只处理第一个匹配的文件
input_step = step_files[0]
# 构造子进程命令
command = [
"python", "train.py",
"--use-normal",
"-i", input_step, # 输入文件路径
]
# 调用子进程运行命令
try:
result = subprocess.run(
command,
capture_output=True,
text=True,
check=True # 如果返回非零退出码,则抛出 CalledProcessError
)
print(f"Processed '{input_step}' successfully.")
print("STDOUT:", result.stdout)
print("STDERR:", result.stderr)
except subprocess.CalledProcessError as e:
print(f"Error processing '{input_step}': {e}")
print("STDOUT:", e.stdout)
print("STDERR:", e.stderr)
except Exception as e:
print(f"Unexpected error processing '{input_step}': {e}")
if __name__ == '__main__':
main()

4
brep2sdf/config/default_config.py

@ -47,7 +47,7 @@ class TrainConfig:
# 基本训练参数
batch_size: int = 8
num_workers: int = 4
num_epochs: int = 20
num_epochs: int = 200
learning_rate: float = 0.01
min_lr: float = 1e-5
weight_decay: float = 0.01
@ -89,7 +89,7 @@ class LogConfig:
# 本地日志
log_dir: str = '/home/wch/brep2sdf/logs' # 日志保存目录
log_level: str = 'INFO' # 日志级别
console_level: str = 'INFO' # 控制台日志级别
console_level: str = 'DEBUG' # 控制台日志级别
file_level: str = 'DEBUG' # 文件日志级别
@dataclass

280
brep2sdf/evaluation.py

@ -0,0 +1,280 @@
import os
import sys
from brep2sdf.utils.logger import logger
# 导入日志系统
from brep2sdf.utils.logger import logger
import numpy as np
from scipy.spatial import cKDTree
from scipy.spatial.distance import directed_hausdorff
import trimesh
import pandas as pd
import csv
import math
import pickle
import argparse
project_dir = "/home/wch/brep2sdf"
# parse args first and set gpu id
parser = argparse.ArgumentParser()
parser.add_argument('--gt_path', type=str,
default=os.path.join(project_dir, 'data/gt_point'),
help='ground truth data path')
parser.add_argument('--pred_path', type=str,
default=os.path.join(project_dir, 'data/output_data'),
help='converted data path')
parser.add_argument('--name_list', type=str, default='name_list.txt', help='names of models to be evaluated, if you want to evaluate the whole dataset, please set it as all_names.txt')
parser.add_argument('--nsample', type=int, default=50000, help='point batch size')
parser.add_argument('--regen', default = False, action="store_true", help = 'regenerate feature curves')
parser.add_argument('--csv_name', type=str, default='eval_results.csv', help='csv file name')
args = parser.parse_args()
def distance_p2p(points_src, normals_src, points_tgt, normals_tgt):
''' Computes minimal distances of each point in points_src to points_tgt.
Args:
points_src (numpy array [N, 3]): source points
normals_src (numpy array [N, 3]): source normals
points_tgt (numpy array [M, 3]): target points
normals_tgt (numpy array [M, 3]): target
Returns:
dist (numpy array [N]): minimal distances of each point in points_src to points_tgt
normals_dot_product (numpy array [N]): dot product of normals of points_src and points_tgt
'''
kdtree = cKDTree(points_tgt)
dist, idx = kdtree.query(points_src)
if normals_src is not None and normals_tgt is not None:
normals_src = \
normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True)
normals_tgt = \
normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True)
normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1)
# Handle normals that point into wrong direction gracefully
# (mostly due to mehtod not caring about this in generation)
normals_dot_product = np.abs(normals_dot_product)
return dist, normals_dot_product
def distance_feature2mesh(points, mesh):
prox = trimesh.proximity.ProximityQuery(mesh)
signed_distance = prox.signed_distance(points)
return np.abs(signed_distance)
def distance_p2mesh(points_src, normals_src, mesh):
points_tgt, idx = mesh.sample(args.nsample, return_index=True)
points_tgt = points_tgt.astype(np.float32)
normals_tgt = mesh.face_normals[idx]
cd1, nc1 = distance_p2p(points_src, normals_src, points_tgt, normals_tgt) #pred2gt
hd1 = cd1.max()
cd1 = cd1.mean()
nc1 = np.clip(nc1, -1.0, 1.0)
angles1 = np.arccos(nc1) / math.pi * 180.0
angles1_mean = angles1.mean()
angles1_std = np.std(angles1)
cd2, nc2 = distance_p2p(points_tgt, normals_tgt, points_src, normals_src) #gt2pred
hd2 = cd2.max()
cd2 = cd2.mean()
nc2 = np.clip(nc2, -1.0, 1.0)
angles2 = np.arccos(nc2)/ math.pi * 180.0
angles2_mean = angles2.mean()
angles2_std = np.std(angles2)
cd = 0.5 * (cd1 + cd2)
hd = max(hd1, hd2)
angles_mean = 0.5 * (angles1_mean + angles2_mean)
angles_std = 0.5 * (angles1_std + angles2_std)
return cd, hd, angles_mean, angles_std, hd1, hd2
def distance_fea(gt_pa, pred_pa):
"""计算特征点之间的距离和角度差异
Args:
gt_pa: 真实特征点和角度 [N, 4]
pred_pa: 预测特征点和角度 [N, 4]
Returns:
dfg2p: 真实到预测的距离
dfp2g: 预测到真实的距离
fag2p: 真实到预测的角度差
fap2g: 预测到真实的角度差
"""
gt_points = gt_pa[:,:3]
pred_points = pred_pa[:,:3]
gt_angle = gt_pa[:,3]
pred_angle = pred_pa[:,3]
dfg2p = 0.0
dfp2g = 0.0
fag2p = 0.0
fap2g = 0.0
pred_kdtree = cKDTree(pred_points)
dist1, idx1 = pred_kdtree.query(gt_points)
dfg2p = dist1.mean()
assert(idx1.shape[0] == gt_points.shape[0])
fag2p = np.abs(gt_angle - pred_angle[idx1])
gt_kdtree = cKDTree(gt_points)
dist2, idx2 = gt_kdtree.query(pred_points)
dfp2g = dist2.mean()
fap2g = np.abs(pred_angle - gt_angle[idx2])
fag2p = fag2p.mean()
fap2g = fap2g.mean()
return dfg2p, dfp2g, fag2p, fap2g
def load_and_process_single_model(line, gt_path, pred_mesh_path, args):
"""处理单个模型的评估
Args:
line (str): 模型名称
gt_path (str): 真值路径
pred_mesh_path (str): 预测网格路径
args: 参数配置
Returns:
dict: 包含该模型所有评估指标的字典
"""
try:
#line = line.strip()[:-4] # 不用去 _50k
result = {'name': line}
# 加载点云数据
test_xyz = os.path.join(gt_path, line+'_50k.xyz')
try:
ptnormal = np.loadtxt(test_xyz)
except FileNotFoundError:
logger.error(f"XYZ file not found: {test_xyz}")
return None
except IOError as e:
logger.error(f"Error reading XYZ file {test_xyz}: {str(e)}")
return None
except ValueError as e:
logger.error(f"Invalid data format in XYZ file {test_xyz}: {str(e)}")
return None
except Exception as e:
logger.error(f"Unexpected error loading {test_xyz}: {str(e)}")
return None
logger.debug("successfully load gt points.")
# 加载预测网格
meshfile = os.path.join(pred_mesh_path, '{}.ply'.format(line))
if not os.path.exists(meshfile):
logger.warning(f'File not exists: {meshfile}, try to generate it...')
pt_file = os.path.join(pred_mesh_path, '{}.pt'.format(line))
try:
# 记录开始执行命令
logger.debug(f"Executing IsoSurfacing: python ./IsoSurfacing.py -i {pt_file} -o {meshfile} --use-gpu")
# 执行命令并检查返回值
ret = os.system(f"python ./IsoSurfacing.py -i {pt_file} -o {meshfile} --use-gpu")
if ret != 0:
raise RuntimeError(f"IsoSurfacing failed with return code {ret}")
# 检查输出文件是否生成
if not os.path.exists(meshfile):
raise FileNotFoundError(f"Output mesh file not created: {meshfile}")
logger.debug("IsoSurfacing completed successfully")
except FileNotFoundError as e:
logger.error(f"IsoSurfacing input file not found: {str(e)}")
return None
except RuntimeError as e:
logger.error(f"IsoSurfacing execution failed: {str(e)}")
return None
except Exception as e:
logger.error(f"Unexpected error in IsoSurfacing: {str(e)}")
return None
# 检查缓存
stat_file = meshfile + "_stat"
if not args.regen and os.path.exists(stat_file) and os.path.getsize(stat_file) > 0:
with open(stat_file, 'rb') as f:
return pickle.load(f)
# 计算网格距离指标
mesh = trimesh.load(meshfile)
logger.debug("successfully load pred mesh.")
cd, hd, adm, ads, hd_pred2gt, hd_gt2pred = distance_p2mesh(
ptnormal[:,:3], ptnormal[:,3:6], mesh)
result.update({
'CD': cd, 'HD': hd, 'HDpred2gt': hd_pred2gt,
'HDgt2pred': hd_gt2pred, 'AngleDiffMean': adm,
'AngleDiffStd': ads
})
# 计算特征点指标
gt_ptangle = np.loadtxt(os.path.join(gt_path, line + '.ptangle'))
pred_ptangle_path = meshfile[:-4]+'.ptangle'
if not os.path.exists(pred_ptangle_path) or args.regen:
os.system('/home/wch/brep2sdf/data/scripts/MeshFeatureSample/build/SimpleSample -i {} -o {} -s 4e-3'.format(meshfile, pred_ptangle_path))
pred_ptangle = np.loadtxt(pred_ptangle_path).reshape(-1,4)
# 处理特征点结果
if len(gt_ptangle) == 0 or len(pred_ptangle) == 0:
result.update({
'FeaDfgt2pred': 0.0, 'FeaDfpred2gt': 0.0,
'FeaAnglegt2pred': 0.0, 'FeaAnglepred2gt': 0.0,
'FeaDf': 0.0, 'FeaAngle': 0.0
})
else:
dfg2p, dfp2g, fag2p, fap2g = distance_fea(gt_ptangle, pred_ptangle)
result.update({
'FeaDfgt2pred': dfg2p, 'FeaDfpred2gt': dfp2g,
'FeaAnglegt2pred': fag2p, 'FeaAnglepred2gt': fap2g,
'FeaDf': (dfg2p + dfp2g) / 2.0,
'FeaAngle': (fag2p + fap2g) / 2.0
})
# 保存缓存
with open(stat_file, "wb") as f:
pickle.dump(result, f)
return result
except Exception as e:
logger.error(f"Error processing {line}: {str(e)}")
return None
def compute_all():
"""计算所有模型的评估指标"""
try:
# 初始化结果字典
results = []
# 读取模型列表
with open(os.path.join(project_dir, 'data', args.name_list), 'r') as f:
lines = f.readlines()
print(lines)
# 处理每个模型
for line in lines:
result = load_and_process_single_model(line, args.gt_path, args.pred_path, args)
if result:
results.append(result)
logger.info(result)
# 计算平均值
mean_result = {'name': 'mean'}
for key in results[0].keys():
if key != 'name':
mean_result[key] = sum(r[key] for r in results) / len(results)
results.append(mean_result)
# 保存结果
df = pd.DataFrame(results)
df.to_csv(args.csv_name, index=False)
logger.info(f"Evaluation completed. Results saved to {os.path.abspath(args.csv_name)}")
except Exception as e:
logger.error(f"Error in compute_all: {str(e)}")
raise
if __name__ == '__main__':
compute_all()

4
brep2sdf/train.py

@ -80,6 +80,8 @@ class Trainer:
data_path = os.path.join("/home/wch/brep2sdf/data/output_data",self.base_name)
if os.path.exists(data_path) and not args.force_reprocess:
self.data = load_brep_file(data_path)
if args.use_normal and self.data.get("surf_pnt_normals", None) is None:
self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal)
else:
self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal)
@ -276,7 +278,7 @@ class Trainer:
# 3. 在no_grad上下文中执行追踪
with torch.no_grad():
traced_model = torch.jit.trace(self.model, example_input)
torch.jit.save(traced_model, f"{self.model_name}.pt")
torch.jit.save(traced_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt")
def _load_checkpoint(self, checkpoint_path):
"""从检查点恢复训练状态"""

Loading…
Cancel
Save