
Script modifications

final
mckay 1 month ago
parent commit 3eca153112
9 changed files (changed line counts in parentheses):

  1. .gitignore (1)
  2. brep2sdf/IsoSurfacing.py (6)
  3. brep2sdf/batch_train.py (152)
  4. brep2sdf/config/default_config.py (14)
  5. brep2sdf/data/utils.py (31)
  6. brep2sdf/eval_pos.py (228)
  7. brep2sdf/scripts/npz2points.py (4)
  8. brep2sdf/train.py (2)
  9. data/name_list copy.txt (24)

.gitignore (1 line changed)

@@ -169,6 +169,7 @@ cython_debug/
 *.step
 test_data/
 logs/
+nohup.out
 wandb/
 *.pth
 *.pt

brep2sdf/IsoSurfacing.py (6 lines changed)

@@ -38,7 +38,9 @@ def predict_sdf(model, points, device):
     points_t = torch.from_numpy(points).float().to(device)
     with torch.no_grad():
-        sdf = model.forward_background(points_t).cpu().numpy().flatten()
+        sdf = model(points_t).cpu().numpy().flatten()
+    # replace inf values with 2
+    #sdf[np.isinf(sdf)] = 2
     return sdf

 def extract_surface(sdf, xx, yy, zz, method='MC', feature_angle=30.0, voxel_size=0.01):
@@ -129,7 +131,7 @@ def main():
     parser = argparse.ArgumentParser(description='IsoSurface Generator')
     parser.add_argument('-i', '--input', type=str, required=True, help='Input model file (.pt)')
     parser.add_argument('-o', '--output', type=str, required=True, help='Output mesh file (.ply)')
-    parser.add_argument('--depth', type=int, default=7, help='grid depth (resolution)')
+    parser.add_argument('--depth', type=int, default=5, help='grid depth (resolution)')
     parser.add_argument('--box_size', type=float, default=1.0,  # changed from 1.0 to 2.0
                         help='bounding box size (2.0 recommended to get the [-1,1] range)')
     parser.add_argument('--method', type=str, default='MC',
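
Review note: the --depth default drops from 7 to 5. A minimal sketch (not part of the diff) of the usual depth-to-grid mapping, assuming 2**depth samples per axis over a cube of edge box_size centered at the origin; the actual grid construction in IsoSurfacing.py sits outside this hunk:

    import numpy as np

    depth, box_size = 5, 1.0   # the new defaults in this commit
    n = 2 ** depth             # assumed resolution: 32 samples per axis at depth 5
    axis = np.linspace(-box_size / 2, box_size / 2, n)
    xx, yy, zz = np.meshgrid(axis, axis, axis, indexing='ij')
    points = np.stack([xx, yy, zz], axis=-1).reshape(-1, 3)
    # depth 7 gives 128**3 = 2,097,152 queries; depth 5 gives 32**3 = 32,768, 64x fewer

predict_sdf above is then evaluated on these query points.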

brep2sdf/batch_train.py (152 lines changed)

@@ -37,18 +37,21 @@ def run_training_process(input_step: str, train_script: str, common_args: list)
     Returns:
         Tuple: (input file path, success flag, stdout, stderr)
     """
+    name_id = input_step.split("/")[-2]
     command = [
         "python", train_script,
         *common_args,
         "-i", input_step,
+        "--resume-checkpoint-path", f"/home/wch/brep2sdf/checkpoints/{name_id}/epoch_11000.pth"
     ]
     try:
+        logger.info(f"Command to execute: {' '.join(command)}")
         result = subprocess.run(
             command,
             capture_output=True,
             text=True,
             check=True,    # raise CalledProcessError on a non-zero exit code
-            timeout=600    # add a timeout (e.g. 10 minutes)
+            timeout=14400
         )
         return input_step, True, result.stdout, result.stderr
     except subprocess.CalledProcessError as e:
@@ -61,6 +64,78 @@ def run_training_process(input_step: str, train_script: str, common_args: list)
         logger.error(f"Unexpected error while processing '{input_step}': {e}")
         return input_step, False, "", str(e)
+
+def batch_train_max_workers_1(args):
+    # read the name list
+    names = get_namelist(args.name_list_path)
+    # prepare the common arguments for train.py
+    # note: getting these from command-line arguments or other configuration is usually better
+    common_train_args = [
+        "--use-normal",
+        "--only-zero-surface",
+        "--octree-cuda",
+        #"--force-reprocess",
+        # more common arguments can be added here
+    ]
+    if args.train_args:
+        common_train_args.extend(args.train_args)
+    tasks = []
+    skipped_count = 0
+    # prepare all tasks
+    for name in names:
+        step_dir = os.path.join(args.step_root_dir, name)
+        if not os.path.isdir(step_dir):
+            logger.warning(f"Directory '{step_dir}' does not exist. Skipping '{name}'")
+            skipped_count += 1
+            continue
+        step_files = []
+        try:
+            step_files = [
+                os.path.join(step_dir, f)
+                for f in os.listdir(step_dir)
+                if f.lower().endswith((".step", ".stp")) and f.startswith(name)  # support .stp and ignore case
+            ]
+        except OSError as e:
+            logger.warning(f"Cannot access directory '{step_dir}': {e}. Skipping '{name}'")
+            skipped_count += 1
+            continue
+        if len(step_files) == 0:
+            logger.warning(f"No matching STEP file found in directory '{step_dir}'. Skipping '{name}'")
+            skipped_count += 1
+        elif len(step_files) > 1:
+            logger.warning(f"Multiple matching STEP files found in directory '{step_dir}'; using the first one: {step_files[0]}")
+            tasks.append(step_files[0])
+        else:
+            tasks.append(step_files[0])
+    if not tasks:
+        logger.info("No valid STEP files found to process.")
+        return
+    logger.info(f"About to process {len(tasks)} STEP files; skipped {skipped_count} names.")
+    success_count = 0
+    failure_count = 0
+    # run the tasks sequentially in a for loop
+    for task_path in tqdm(tasks, desc="Running training"):
+        input_step, success, stdout, stderr = run_training_process(task_path, args.train_script, common_train_args)
+        if success:
+            success_count += 1
+            # stdout/stderr of successful runs could be logged too, but logging failures is usually more useful
+            # logger.debug(f"Successfully processed '{input_step}'. STDOUT:\n{stdout}\nSTDERR:\n{stderr}")
+        else:
+            failure_count += 1
+            logger.error(f"Failed to process '{input_step}'. STDOUT:\n{stdout}\nSTDERR:\n{stderr}")
+    logger.info(f"Done. Succeeded: {success_count}, failed: {failure_count}, skipped: {skipped_count}")
+
 def batch_train(args):
     # read the name list
     names = get_namelist(args.name_list_path)
@@ -71,7 +146,7 @@ def batch_train(args):
         "--use-normal",
         "--only-zero-surface",
         "--octree-cuda",
-        "--force-reprocess",
+        #"--force-reprocess",
         # more common arguments can be added here
     ]
     if args.train_args:
@@ -145,7 +220,7 @@ def batch_train(args):
 #===========================
 # Iso
-def run_isosurfacing_process(input_path, output_dir, use_gpu=True):
+def run_isosurfacing_process(input_path, output_dir, use_gpu=True,if_nh=False):
     """
     Run the IsoSurfacing.py script to generate an iso-surface

@@ -163,8 +238,8 @@ def run_isosurfacing_process(input_path, output_dir, use_gpu=True):
     try:
         # build the output file name
         base_name = os.path.splitext(os.path.basename(input_path))[0]
-        output_path = os.path.join(output_dir, f"{base_name}.ply")
+        output_path = os.path.join(output_dir, f"{base_name}_nh.ply") if if_nh else os.path.join(output_dir, f"{base_name}.ply")
         #print(output_path)
         # build the command
         command = [
             "python", "IsoSurfacing.py",
@@ -249,9 +324,68 @@ def batch_Iso(args):
             logger.error(f"Failed to get the result for task '{input_path}': {e}")
     logger.info(f"Done. Succeeded: {success_count}, failed: {failure_count}")
+
+def batch_nh_Iso(args):
+    # python IsoSurfacing.py -i /home/wch/NH-Rep/data/output_data/00000003_0_50k_model_h.pt -o /home/wch/NH-Rep/data/output_data/00000003.ply --use-gpu
+    """
+    Batch-process .pt files to generate iso-surfaces and save them as .ply files
+    """
+    names = get_namelist(args.name_list_path)
+    # check that the input directory exists
+    if not os.path.exists(args.pt_dir):
+        logger.error(f"Error: input directory '{args.pt_dir}' does not exist.")
+        return
+    # gather all the .pt files
+    try:
+        pt_files = [
+            os.path.join(args.pt_dir, f"{name}_0_50k_model_h.pt")
+            for name in names
+        ]
+        pt_files = [f for f in pt_files if os.path.exists(f)]
+    except OSError as e:
+        logger.error(f"Cannot access input directory '{args.pt_dir}': {e}")
+        return
+    if not pt_files:
+        logger.info(f"No .pt files found in input directory '{args.pt_dir}'.")
+        return
+    logger.info(f"Found {len(pt_files)} .pt files in the input directory.")
+    success_count = 0
+    failure_count = 0
+    # process in parallel with a ProcessPoolExecutor
+    with ProcessPoolExecutor(max_workers=args.workers) as executor:
+        # submit all the tasks
+        futures = {
+            executor.submit(run_isosurfacing_process, pt_path, args.pt_dir, True, True): pt_path
+            for pt_path in pt_files
+        }
+        # show progress with tqdm and handle the results
+        for future in tqdm(as_completed(futures), total=len(pt_files), desc="Generating iso-surfaces"):
+            input_path = futures[future]
+            try:
+                input_pt, success, stdout, stderr = future.result()
+                if success:
+                    success_count += 1
+                    # stdout/stderr of successful runs could be logged too, but logging failures is usually more useful
+                    # logger.debug(f"Successfully processed '{input_pt}'. STDOUT:\n{stdout}\nSTDERR:\n{stderr}")
+                else:
+                    failure_count += 1
+                    logger.error(f"Failed to process '{input_pt}'. STDOUT:\n{stdout}\nSTDERR:\n{stderr}")
+            except Exception as e:
+                failure_count += 1
+                logger.error(f"Failed to get the result for task '{input_path}': {e}")
+    logger.info(f"Done. Succeeded: {success_count}, failed: {failure_count}")
+
 def main(args):
-    batch_train(args)
+    batch_train_max_workers_1(args)
+    #batch_train(args)
     #batch_Iso(args)
+    #batch_nh_Iso(args)

 if __name__ == '__main__':
@@ -264,10 +398,12 @@ if __name__ == '__main__':
                         help="Path to the file containing the list of names to process.")
     parser.add_argument('--train-script', type=str, default="train.py",
                         help="Path of the training script to run.")
-    parser.add_argument('--workers', type=int, default=os.cpu_count(),
+    parser.add_argument('--workers', type=int, default=1,
                         help="Number of worker processes for parallel processing.")
     parser.add_argument('--train-args', nargs='*',
                         help="Extra arguments forwarded to train.py (e.g. --epochs 10 --batch-size 32).")
     args = parser.parse_args()
-    main(args)
+    main(args)
+    # python batch_train.py --pt-dir /home/wch/NH-Rep/data_backup/output_data/extracted/output_data

brep2sdf/config/default_config.py (14 lines changed)

@@ -49,16 +49,18 @@ class TrainConfig:
     # basic training parameters
     batch_size: int = 8
     num_workers: int = 4
-    num_epochs: int = 50
-    learning_rate: float = 0.005
+    num_epochs1: int = 10000
+    num_epochs2: int = 1000
+    num_epochs3: int = 1000
+    learning_rate: float = 0.1
     learning_rate_schedule: List = field(default_factory=lambda: [{
         "Type": "Step",  # schedule type; "Step" multiplies the learning rate by Factor after every Interval iterations
-        "Initial": 0.005,
+        "Initial": 0.01,
         "Interval": 2000,
-        "Factor": 0.5
+        "Factor": 0.3
     }])
     min_lr: float = 1e-5
-    weight_decay: float = 0.01
+    weight_decay: float = 0.0001

     # gradient and loss related
     max_grad_norm: float = 1.0
@@ -71,7 +73,7 @@ class TrainConfig:
     warmup_epochs: int = 5

     # saving and validation
-    save_freq: int = 10  # save a checkpoint every this many epochs
+    save_freq: int = 1000  # save a checkpoint every this many epochs
     val_freq: int = 1  # validate every this many epochs

     # save paths
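
Review note: per the inline comment, a "Step" schedule multiplies the learning rate by Factor after every Interval iterations. A minimal sketch (not part of the diff) of that rule with the new values, assuming min_lr acts as a floor; the repo's actual scheduler lives elsewhere:

    def step_lr(iteration, initial=0.01, interval=2000, factor=0.3, min_lr=1e-5):
        # learning rate after `iteration` steps under the "Step" schedule
        return max(initial * factor ** (iteration // interval), min_lr)

    # step_lr(0) == 0.01, step_lr(2000) == 0.003, step_lr(4000) == 0.0009

Also worth noting: learning_rate is now 0.1 while the schedule's Initial is 0.01; which value takes effect depends on the trainer code.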

brep2sdf/data/utils.py (31 lines changed)

@@ -246,4 +246,33 @@ def batch_compute_normals(mesh, surf_wcs, normal_type='vertex', k_neighbors=3):
         normals_output[i] = nearest_normals[start:end]
         start = end
-    return normals_output
+    return normals_output
+
+def points_in_box(points: torch.Tensor, bbox: torch.Tensor) -> torch.Tensor:
+    """
+    Return the points that fall inside the AABB, keeping all 7 channels.
+
+    Args:
+        points: tensor of shape (N, 7); the first 3 dims are the coordinates (x, y, z), the rest are attributes such as normals or colors
+        bbox: tensor of shape (6,) holding the AABB coordinates [x_min, y_min, z_min, x_max, y_max, z_max]
+
+    Returns:
+        torch.Tensor: shape (K, 7), where K is the number of points inside the box
+    """
+    assert points.shape[1] == 7, f"points must have 7 channels, got {points.shape[1]}"
+    assert bbox.shape == (6,), f"bbox must be a 1-D tensor of length 6, got {bbox.shape}"
+    min_coords = bbox[:3]
+    max_coords = bbox[3:]
+    # per-axis check whether each point's xyz lies inside the box, shape (N, 3)
+    within_box = (points[:, :3] >= min_coords) & (points[:, :3] <= max_coords)
+    # points satisfying the condition on all axes, shape (N,)
+    inside_mask = within_box.all(dim=1)
+    # extract the full rows (all 7 dims) of the matching points
+    points_inside = points[inside_mask]
+    return points_inside.detach().clone()
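
A quick usage sketch for the new helper (not part of the diff), assuming the 7 channels are packed as xyz, then the normal, then one extra attribute:

    import torch

    pts = torch.randn(1000, 7)  # [x, y, z, nx, ny, nz, extra]
    bbox = torch.tensor([-0.5, -0.5, -0.5, 0.5, 0.5, 0.5])
    inside = points_in_box(pts, bbox)  # (K, 7): full rows whose xyz fall inside the box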

brep2sdf/eval_pos.py (228 lines changed)

@@ -1,5 +1,6 @@
 import trimesh
 import numpy as np
+from brep2sdf.data.data import prepare_sdf_data,load_brep_file
 from brep2sdf.data.sampler import sample_zero_surface_points_and_normals
 from brep2sdf.utils.load import get_namelist, get_step_paths
 from brep2sdf.networks.network import gradient
@@ -7,6 +8,21 @@ import torch
 import os
 from brep2sdf.utils.logger import logger
+
+# global variable used to cache the sampled points
+GLOBAL_SAMPLED_POINTS = None
+
+def sample_grid_points(mesh, nh_mesh, our_mesh, num_samples_iou):
+    global GLOBAL_SAMPLED_POINTS
+    if GLOBAL_SAMPLED_POINTS is not None:
+        return GLOBAL_SAMPLED_POINTS
+    # sample points from a larger spatial extent
+    bounds = np.vstack([mesh.bounds, nh_mesh.bounds, our_mesh.bounds])
+    min_bound = np.min(bounds, axis=0)
+    max_bound = np.max(bounds, axis=0)
+    points = np.random.uniform(min_bound, max_bound, (num_samples_iou, 3))
+    GLOBAL_SAMPLED_POINTS = points
+    return points
+
 def position_loss(pred_sdfs: torch.Tensor) -> torch.Tensor:
     """Position loss function"""
     # keep the gradient flow
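
Review note (not part of the diff): sample_grid_points caches its first result in the module-level GLOBAL_SAMPLED_POINTS and returns that same array on every later call, whatever meshes are passed in. Within one shape this keeps the NH and Mine IoU on identical query points; across the loop in main(), though, every subsequent shape reuses the points sampled from the first shape's bounds, which is worth double-checking.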
@@ -25,8 +41,54 @@ def average_normal_error(normals1: torch.Tensor, normals2: torch.Tensor) -> torch.Tensor:
     angle_errors = 1 - absolute_dot_products
     return torch.mean(angle_errors)
+
+def chamfer_distance(points_A: torch.Tensor, points_B: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the Chamfer Distance (CD) between two point sets
+    :param points_A: point set A of shape (N, 3)
+    :param points_B: point set B of shape (M, 3)
+    :return: CD
+    """
+    N = points_A.shape[0]
+    M = points_B.shape[0]
+    points_A_expanded = points_A.unsqueeze(1).expand(N, M, 3)
+    points_B_expanded = points_B.unsqueeze(0).expand(N, M, 3)
+    distances = torch.sum((points_A_expanded - points_B_expanded) ** 2, dim=-1)  # (N, M)
+    dist_A_to_B = torch.min(distances, dim=1)[0]  # (N,)
+    dist_B_to_A = torch.min(distances, dim=0)[0]  # (M,)
+    return (torch.mean(dist_A_to_B) + torch.mean(dist_B_to_A)) / 2
+
+def hausdorff_distance(points_A: torch.Tensor, points_B: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the two-sided Hausdorff Distance (HD) between two point sets
+    :param points_A: point set A of shape (N, 3)
+    :param points_B: point set B of shape (M, 3)
+    :return: HD
+    """
+    N = points_A.shape[0]
+    M = points_B.shape[0]
+    points_A_expanded = points_A.unsqueeze(1).expand(N, M, 3)
+    points_B_expanded = points_B.unsqueeze(0).expand(N, M, 3)
+    distances = torch.sum((points_A_expanded - points_B_expanded) ** 2, dim=-1)  # (N, M)
+    dist_A_to_B = torch.min(distances, dim=1)[0]  # (N,)
+    dist_B_to_A = torch.min(distances, dim=0)[0]  # (M,)
+    return torch.max(torch.max(dist_A_to_B), torch.max(dist_B_to_A))
+
+def compute_iou(sdf1, sdf2, threshold=0.0):
+    """
+    Compute the IoU between two SDFs
+    :param sdf1: first SDF array
+    :param sdf2: second SDF array
+    :param threshold: threshold deciding whether a point counts as inside the surface
+    :return: IoU
+    """
+    inside1 = sdf1 <= threshold
+    inside2 = sdf2 <= threshold
+    intersection = np.logical_and(inside1, inside2).sum()
+    union = np.logical_or(inside1, inside2).sum()
+    iou = intersection / union if union > 0 else 0.0
+    return iou
+
+# load
 def load_model(model_path):
     """Generic model-loading helper"""
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
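
Review note: the three new metrics are brute-force implementations. chamfer_distance and hausdorff_distance both materialize the full (N, M) matrix of squared pairwise distances; CD averages each side's nearest-neighbour distance, HD takes the worst case, and compute_iou thresholds two SDF arrays and compares the inside sets. A small self-contained check (not part of the diff; torch.cdist is used only to show an equivalent pairwise-distance computation):

    import torch
    import numpy as np

    A, B = torch.rand(128, 3), torch.rand(256, 3)
    cd = chamfer_distance(A, B)    # symmetric mean of nearest squared distances
    hd = hausdorff_distance(A, B)  # worst-case nearest squared distance

    d2 = torch.cdist(A, B) ** 2    # the same (N, M) squared-distance matrix
    print(torch.allclose(cd, (d2.min(dim=1).values.mean() + d2.min(dim=0).values.mean()) / 2, atol=1e-5))

    sdf1, sdf2 = np.random.randn(1000), np.random.randn(1000)
    iou = compute_iou(sdf1, sdf2)  # intersection-over-union of the inside sets (sdf <= 0)

Memory grows as N*M, which is fine at the 4096 samples used below but would need batching or a KD-tree at larger scales.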
@@ -38,6 +100,9 @@ def load_model(model_path):
         logger.error(f"Error while loading model {model_path}: {e}")
         return None
+
+# model
 def nh(model_path, points):
     model = load_model(model_path)
     if model is None:
@@ -53,16 +118,22 @@ def mine(model_path, points):
     if model is None:
         return None
     try:
-        return model.forward_background(points)
+        return model(points)
     except Exception as e:
         logger.error(f"Error while calling the mine model: {e}")
         return None

 def run(name):
     # replace with the actual obj file path
     obj_file_path = f"/home/wch/brep2sdf/data/gt_mesh/{name}.obj"
     model_path = f"/home/wch/brep2sdf/data/output_data/{name}.pt"
     nh_model = f"/home/wch/NH-Rep/data_backup/output_data/extracted/output_data/{name}_0_50k_model_h.pt"
+    ply_nh = f"/home/wch/NH-Rep/data_backup/output_data/extracted/output_data/{name}_0_50k_model_h_nh.ply"
+    ply_our = f"/home/wch/brep2sdf/data/output_data/{name}.ply"
+    npz_path = f"/home/wch/brep2sdf/data/output_data/{name}.xyz"
+    num_samples=4096

     # check that the file exists
     if not os.path.isfile(obj_file_path):
@@ -79,7 +150,7 @@ def run(name):
     try:
         # call the sampling function
-        result1 = sample_zero_surface_points_and_normals(mesh, num_samples=4096)
+        result1 = sample_zero_surface_points_and_normals(mesh, num_samples)
         if result1 is None:
             logger.error("Sampling failed, returned None")
             return
@@ -98,29 +169,115 @@ def main():
             loss2["de"] = position_loss(sdf2).item()
             logger.info(f"NH model position loss: {loss1}")
             logger.info(f"Mine model position loss: {loss2}")
             # convert gt_normal to a torch.Tensor and move it to the device
             gt_normal = torch.from_numpy(result1[:, 3:6]).float().to(device)
             # assumes the gradient function has been imported correctly
             normal1 = gradient(coordinates_tensor, sdf1)
             normal2 = gradient(coordinates_tensor, sdf2)
             loss1["nae"] = average_normal_error(gt_normal, normal1).item()
             loss2["nae"] = average_normal_error(gt_normal, normal2).item()
             print("NH model average normal error (NAE):", loss1["nae"])
             print("Mine model average normal error (NAE):", loss2["nae"])
-            return loss1, loss2
         else:
             logger.error("Cannot compute the losses, the SDF results are None")
+
+        # read the ply files
+        try:
+            nh_mesh = trimesh.load_mesh(ply_nh)
+            our_mesh = trimesh.load_mesh(ply_our)
+            logger.info(f"Successfully read the PLY files: {ply_nh} and {ply_our}")
+        except Exception as e:
+            logger.error(f"Error while reading the PLY files: {e}")
+            return
+        # sample points from the meshes
+        nh_points = torch.from_numpy(nh_mesh.sample(num_samples)).float().to(device)
+        our_points = torch.from_numpy(our_mesh.sample(num_samples)).float().to(device)
+        # make sure coordinates is a torch.Tensor
+        loss1["cd"] = chamfer_distance(coordinates_tensor, nh_points).item()
+        loss2["cd"] = chamfer_distance(coordinates_tensor, our_points).item()
+        loss1["hd"] = hausdorff_distance(coordinates_tensor, nh_points).item()
+        loss2["hd"] = hausdorff_distance(coordinates_tensor, our_points).item()
+
+        # fea
+        data = load_brep_file(npz_path)
+        sampled_pnts=prepare_sdf_data(data["surf_ncs"],normals=data["surf_pnt_normals"],max_points=num_samples)
+        # flatten
+        flattened_pnts = sampled_pnts.flatten()
+        # changed here to use clone().detach()
+        if isinstance(flattened_pnts[0:3], torch.Tensor):
+            f_pnts = flattened_pnts[0:3].clone().detach().to(device).view(-1, 3)
+        else:
+            f_pnts = torch.from_numpy(flattened_pnts[0:3]).clone().detach().to(device).view(-1, 3)
+        if isinstance(flattened_pnts[3:6], torch.Tensor):
+            f_normals = flattened_pnts[3:6].clone().detach().to(device).view(-1, 3)
+        else:
+            f_normals = torch.from_numpy(flattened_pnts[3:6]).clone().detach().to(device).view(-1, 3)
+        # check the shapes of f_pnts and f_normals
+        if f_pnts.shape[-1] != 3 or f_normals.shape[-1] != 3:
+            logger.error(f"f_pnts shape: {f_pnts.shape}, f_normals shape: {f_normals.shape}; expected the last dimension to be 3")
+            return
+        loss1["fcd"] = chamfer_distance(f_pnts, nh_points).item()
+        loss2["fcd"] = chamfer_distance(f_pnts, our_points).item()
+        loss1["fae"] = hausdorff_distance(f_normals, nh_points).item()
+        loss2["fae"] = hausdorff_distance(f_normals, our_points).item()
+
+        # compute the IoU from the obj file
+        # ... existing code ...
+        # compute the IoU with the sampled-point method
+        try:
+            num_samples_iou = 10000  # number of sample points, adjust as needed
+            # call the wrapped sampling helper
+            points = sample_grid_points(mesh, nh_mesh, our_mesh, num_samples_iou)
+            # test whether each point lies inside each mesh
+            inside_mesh = mesh.contains(points)
+            inside_nh = nh_mesh.contains(points)
+            inside_our = our_mesh.contains(points)
+            # intersection and union of nh_mesh with mesh
+            intersection_nh = np.logical_and(inside_mesh, inside_nh).sum()
+            union_nh = np.logical_or(inside_mesh, inside_nh).sum()
+            # intersection and union of our_mesh with mesh
+            intersection_our = np.logical_and(inside_mesh, inside_our).sum()
+            union_our = np.logical_or(inside_mesh, inside_our).sum()
+            # compute the IoU
+            iou_nh = intersection_nh / union_nh if union_nh > 0 else 0.0
+            iou_our = intersection_our / union_our if union_our > 0 else 0.0
+            loss1["iou"] = iou_nh
+            loss2["iou"] = iou_our
+        except Exception as e:
+            print(f"Error while computing the IoU with sampled points: {e}")
+
+        return loss1, loss2
     except Exception as e:
         logger.error(f"Error during processing: {e}")

 def main():
     names = get_namelist("/home/wch/brep2sdf/data/name_list.txt")
     tl1_de, tl1_nae, tl2_de, tl2_nae = 0.0, 0.0, 0.0, 0.0
+    tl1_cd, tl1_hd, tl2_cd, tl2_hd = 0.0, 0.0, 0.0, 0.0
+    tl1_fcd, tl1_fae, tl2_fcd, tl2_fae = 0.0, 0.0, 0.0, 0.0
+    # new accumulators for IoU
+    tl1_iou, tl2_iou = 0.0, 0.0
     valid_count = 0
     for name in names:
         result = run(name)
@@ -130,14 +287,61 @@ def main():
             tl1_nae += l1["nae"]
             tl2_de += l2["de"]
             tl2_nae += l2["nae"]
+            tl1_cd += l1["cd"]
+            tl1_hd += l1["hd"]
+            tl2_cd += l2["cd"]
+            tl2_hd += l2["hd"]
+            tl1_fcd += l1["fcd"]
+            tl1_fae += l1["fae"]
+            tl2_fcd += l2["fcd"]
+            tl2_fae += l2["fae"]
+            # accumulate the IoU values
+            tl1_iou += l1["iou"]
+            tl2_iou += l2["iou"]
             valid_count += 1
     if valid_count > 0:
-        print(f"NH model average position loss (de): {tl1_de/valid_count}")
-        print(f"NH model average normal error (nae): {tl1_nae/valid_count}")
-        print(f"Mine model average position loss (de): {tl2_de/valid_count}")
-        print(f"Mine model average normal error (nae): {tl2_nae/valid_count}")
+        avg_l1_de = tl1_de / valid_count
+        avg_l1_nae = tl1_nae / valid_count
+        avg_l2_de = tl2_de / valid_count
+        avg_l2_nae = tl2_nae / valid_count
+        avg_l1_cd = tl1_cd / valid_count
+        avg_l1_hd = tl1_hd / valid_count
+        avg_l2_cd = tl2_cd / valid_count
+        avg_l2_hd = tl2_hd / valid_count
+        avg_l1_fcd = tl1_fcd / valid_count
+        avg_l1_fae = tl1_fae / valid_count
+        avg_l2_fcd = tl2_fcd / valid_count
+        avg_l2_fae = tl2_fae / valid_count
+        # average IoU
+        avg_l1_iou = tl1_iou / valid_count
+        avg_l2_iou = tl2_iou / valid_count
+        # print the table header
+        print("| Model | Chamfer Distance (CD) | Hausdorff Distance (HD) | Average Normal Error (NAE) | Feature Chamfer Distance (FCD) | Feature Angle Error (FAE) | Position Loss (DE) | IoU |")
+        print("|------|-----------------------|-------------------------|----------------------|--------------------------------|---------------------------|--------------|-----|")
+        # print the NH model row
+        print(f"| NH model | {avg_l1_cd} | {avg_l1_hd} | {avg_l1_nae} | {avg_l1_fcd} | {avg_l1_fae} | {avg_l1_de} | {avg_l1_iou} |")
+        # print the Mine model row
+        print(f"| Mine model | {avg_l2_cd} | {avg_l2_hd} | {avg_l2_nae} | {avg_l2_fcd} | {avg_l2_fae} | {avg_l2_de} | {avg_l2_iou} |")
     else:
         print("No valid results; cannot compute averages.")
+
+def test(name_id):
+    result = run(name_id)  # fixed incorrect argument usage
+    if result is not None:
+        l1, l2 = result
+        # print the table header
+        print("| Model | Chamfer Distance (CD) | Hausdorff Distance (HD) | Average Normal Error (NAE) | Feature Chamfer Distance (FCD) | Feature Angle Error (FAE) | Position Loss (DE) | IoU |")
+        print("|------|-----------------------|-------------------------|----------------------|--------------------------------|---------------------------|--------------|-----|")
+        # assumes the IoU values are present in l1 and l2; if not, they can be skipped or given defaults
+        # print the NH model row
+        print(f"| NH model | {l1['cd']} | {l1['hd']} | {l1['nae']} | {l1['fcd']} | {l1['fae']} | {l1['de']} | {l1['iou']} |")
+        # print the Mine model row
+        print(f"| Mine model | {l2['cd']} | {l2['hd']} | {l2['nae']} | {l2['fcd']} | {l2['fae']} | {l2['de']} | {l2['iou']} |")
+    else:
+        print("No valid results.")

 if __name__ == "__main__":
-    main()
+    #main()
+    test("00000031")

brep2sdf/scripts/npz2points.py (4 lines changed)

@@ -11,10 +11,10 @@ def load_brep_file(brep_path):

 if __name__ == "__main__":
-    data=load_brep_file("/home/wch/brep2sdf/data/output_data/00000031.xyz")
+    data=load_brep_file("/home/wch/brep2sdf/data/output_data/00000003.xyz")
     surfs =data["train_surf_ncs"]
     print(surfs)
-    with open("0031_t.xyz","w") as f:
+    with open("0003_t.xyz","w") as f:
         for point in surfs:
             #f.write(f"{point[0]} {point[1]} {point[2]}\n")
             f.write(f"{point[0]} {point[1]} {point[2]} {point[3]} {point[4]} {point[5]}\n")

brep2sdf/train.py (2 lines changed)

@@ -610,7 +610,7 @@ class Trainer:
         self.model.train()
         total_loss = 0.0
         step = 0  # if training is batched, this should be the batch index
-        batch_size = 8192 * 2  # set an appropriate batch size
+        batch_size = 4096  # set an appropriate batch size

         # data processing
         # manfld

data/name_list copy.txt (24 lines changed)

@@ -0,0 +1,24 @@
+00000003
+00000008
+00000009
+00000029
+00000031
+00000032
+00000047
+00000049
+00000057
+00000058
+00000060
+00000061
+00000065
+00000066
+00000067
+00000068
+00000070
+00000072
+00000076
+00000077
+00000078
+00000079
+00000088
+00000093