9 changed files with 723 additions and 68 deletions
@ -0,0 +1,204 @@ |
|||||
|
import numpy as np |
||||
|
import torch |
||||
|
import argparse |
||||
|
from skimage import measure |
||||
|
import time |
||||
|
import trimesh |
||||
|
from brep2sdf.utils.logger import logger |
||||
|
from brep2sdf.networks.octree import OctreeNode |
||||
|
|
||||
|
def create_grid_with_octree(octree, model, device): |
||||
|
""" |
||||
|
使用八叉树创建三维网格点 |
||||
|
:param octree: 八叉树对象 |
||||
|
:param model: PyTorch模型 |
||||
|
:param device: 设备(CPU/GPU) |
||||
|
:return: 网格点数组和SDF值数组 |
||||
|
""" |
||||
|
leaf_indices = (octree.is_leaf_mask & octree.is_valid_leaf_mask).nonzero().flatten() |
||||
|
print(leaf_indices.shape) |
||||
|
points = [] |
||||
|
for idx in leaf_indices: |
||||
|
bbox = octree.node_bboxes[idx] |
||||
|
min_coords = bbox[:3].cpu().numpy() |
||||
|
max_coords = bbox[3:].cpu().numpy() |
||||
|
# 在叶子节点的边界框内采样 |
||||
|
num_samples = 1 # 可根据需要调整采样点数 |
||||
|
x = np.linspace(min_coords[0], max_coords[0], num_samples) |
||||
|
y = np.linspace(min_coords[1], max_coords[1], num_samples) |
||||
|
z = np.linspace(min_coords[2], max_coords[2], num_samples) |
||||
|
xx, yy, zz = np.meshgrid(x, y, z, indexing='ij') |
||||
|
node_points = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1) |
||||
|
points.append(node_points) |
||||
|
|
||||
|
points = np.vstack(points) |
||||
|
sdf = predict_sdf(model, points, device) |
||||
|
return points, sdf |
||||
|
|
||||
|
def predict_sdf(model, points, device): |
||||
|
""" |
||||
|
使用模型预测SDF值 |
||||
|
:param model: PyTorch模型 |
||||
|
:param points: 输入点坐标 (N, 3) |
||||
|
:param device: 设备(CPU/GPU) |
||||
|
:return: SDF值数组 (N,) |
||||
|
""" |
||||
|
points_t = torch.from_numpy(points).float().to(device) |
||||
|
|
||||
|
with torch.no_grad(): |
||||
|
sdf = model.forward_background(points_t).cpu().numpy().flatten() |
||||
|
return sdf |
||||
|
|
||||
|
def extract_surface(sdf, xx, yy, zz, method='MC', feature_angle=30.0, voxel_size=0.01): |
||||
|
""" |
||||
|
提取零表面 |
||||
|
:param sdf: SDF值三维数组 |
||||
|
:param xx/yy/zz: 网格坐标 |
||||
|
:param method: 提取方法(MC: Marching Cubes) |
||||
|
:return: 顶点和面片 |
||||
|
""" |
||||
|
if method == 'MC': |
||||
|
verts, faces, _, _ = measure.marching_cubes(sdf, level=0) |
||||
|
elif method == 'EMC': |
||||
|
from iso_algorithms import enhanced_marching_cubes |
||||
|
verts, faces = enhanced_marching_cubes( |
||||
|
sdf, |
||||
|
feature_angle=feature_angle, |
||||
|
gradient_direction='descent' |
||||
|
) |
||||
|
elif method == 'DC': |
||||
|
from iso_algorithms import dual_contouring |
||||
|
verts, faces = dual_contouring(sdf, voxel_size=voxel_size) |
||||
|
else: |
||||
|
raise ValueError(f"不支持的算法: {method}") |
||||
|
|
||||
|
# 新增顶点后处理 |
||||
|
verts = (verts - sdf.shape[0]//2) / (sdf.shape[0]//2) # 归一化到[-1,1] |
||||
|
return verts, faces |
||||
|
|
||||
|
def save_ply(vertices, faces, filename): |
||||
|
""" |
||||
|
保存顶点和面片为PLY文件 |
||||
|
:param vertices: 顶点数组 (N, 3) |
||||
|
:param faces: 面片数组 (M, 3) |
||||
|
:param filename: 输出文件名 |
||||
|
""" |
||||
|
with open(filename, 'w') as f: |
||||
|
f.write("ply\n") |
||||
|
f.write("format ascii 1.0\n") |
||||
|
f.write(f"element vertex {len(vertices)}\n") |
||||
|
f.write("property float x\n") |
||||
|
f.write("property float y\n") |
||||
|
f.write("property float z\n") |
||||
|
f.write(f"element face {len(faces)}\n") |
||||
|
f.write("property list uchar int vertex_indices\n") |
||||
|
f.write("end_header\n") |
||||
|
for v in vertices: |
||||
|
f.write(f"{v[0]} {v[1]} {v[2]}\n") |
||||
|
for face in faces: |
||||
|
f.write(f"3 {face[0]} {face[1]} {face[2]}\n") |
||||
|
|
||||
|
def compute_sdf_error(model, gt_mesh, res, device): |
||||
|
""" |
||||
|
计算预测SDF与GT网格的误差 |
||||
|
:param model: PyTorch模型 |
||||
|
:param gt_mesh: GT网格(Trimesh格式) |
||||
|
:param res: 误差计算分辨率 |
||||
|
:param device: 设备 |
||||
|
:return: 平均误差和最大误差 |
||||
|
""" |
||||
|
# 生成均匀采样点 |
||||
|
box_size = max(gt_mesh.extents) |
||||
|
start = -box_size / 2 |
||||
|
end = box_size / 2 |
||||
|
x = np.linspace(start, end, res) |
||||
|
y = np.linspace(start, end, res) |
||||
|
z = np.linspace(start, end, res) |
||||
|
points = np.array(np.meshgrid(x, y, z)).T.reshape(-1, 3) |
||||
|
|
||||
|
# 预测SDF |
||||
|
pred_sdf = predict_sdf(model, points, device) |
||||
|
|
||||
|
# 计算GT距离 |
||||
|
distances = gt_mesh.nearest.on_surface(points)[1] |
||||
|
gt_sdf = np.abs(distances) |
||||
|
|
||||
|
# 计算误差 |
||||
|
abs_error = np.abs(pred_sdf - gt_sdf) |
||||
|
rel_error = abs_error / (np.abs(gt_sdf) + 1e-9) |
||||
|
avg_abs = np.mean(abs_error) |
||||
|
avg_rel = np.mean(rel_error) |
||||
|
max_abs = np.max(abs_error) |
||||
|
max_rel = np.max(rel_error) |
||||
|
|
||||
|
return avg_abs, avg_rel, max_abs, max_rel |
||||
|
|
||||
|
def main(): |
||||
|
parser = argparse.ArgumentParser(description='IsoSurface Generator') |
||||
|
parser.add_argument('-i', '--input', type=str, required=True, help='Input model file (.pt)') |
||||
|
parser.add_argument('-o', '--output', type=str, required=True, help='Output mesh file (.ply)') |
||||
|
parser.add_argument('--depth', type=int, default=7, help='网格深度(分辨率)') |
||||
|
parser.add_argument('--box_size', type=float, default=1.0, # 从1.0改为2.0 |
||||
|
help='边界框大小(建议设为2.0以得到[-1,1]范围)') |
||||
|
parser.add_argument('--method', type=str, default='MC', |
||||
|
choices=['MC', 'EMC', 'DC'], # 新增算法选项 |
||||
|
help='表面提取方法: MC-MarchingCubes, EMC-EnhancedMC, DC-DualContouring') |
||||
|
parser.add_argument('--feature_angle', type=float, default=30.0, |
||||
|
help='特征角度阈值(EMC算法专用)') |
||||
|
parser.add_argument('--voxel_size', type=float, default=0.01, |
||||
|
help='体素尺寸(DC算法专用)') |
||||
|
parser.add_argument('--use-gpu', action='store_true', help='使用GPU') |
||||
|
parser.add_argument('--compare', type=str, help='GT网格文件(.ply)') |
||||
|
parser.add_argument('--compres', type=int, default=32, help='误差计算分辨率') |
||||
|
args = parser.parse_args() |
||||
|
|
||||
|
# 设置设备 |
||||
|
device = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu") |
||||
|
print(f"Using device: {device}") |
||||
|
|
||||
|
model = torch.jit.load(args.input).to(device) |
||||
|
#model = torch.load(args.input).to(device) |
||||
|
model.eval() |
||||
|
|
||||
|
octree = model.octree_module |
||||
|
|
||||
|
# 使用八叉树创建网格并预测SDF |
||||
|
points, sdf = create_grid_with_octree(octree, model, device) |
||||
|
print(1) |
||||
|
# 这里需要根据实际情况将points转换为网格坐标xx, yy, zz |
||||
|
# 简单示例:假设points是均匀采样的 |
||||
|
grid_size = int(np.ceil(len(points) ** (1/3))) |
||||
|
xx = points[:, 0].reshape(grid_size, grid_size, grid_size) |
||||
|
yy = points[:, 1].reshape(grid_size, grid_size, grid_size) |
||||
|
zz = points[:, 2].reshape(grid_size, grid_size, grid_size) |
||||
|
sdf_grid = sdf.reshape(xx.shape) |
||||
|
|
||||
|
# 提取表面 |
||||
|
print("Extracting surface...") |
||||
|
start_time = time.time() |
||||
|
verts, faces = extract_surface(sdf_grid, xx, yy, zz, args.method) |
||||
|
|
||||
|
# 新增顶点归一化校验 |
||||
|
max_val = np.max(np.abs(verts)) |
||||
|
if max_val > 1.0 + 1e-6: # 允许微小误差 |
||||
|
verts = verts / max_val |
||||
|
print(f"Surface extraction took {time.time() - start_time:.2f} seconds") |
||||
|
|
||||
|
# 保存网格 |
||||
|
save_ply(verts, faces, args.output) |
||||
|
print(f"Mesh saved to {args.output}") |
||||
|
|
||||
|
# 误差评估(可选) |
||||
|
if args.compare: |
||||
|
print("Computing SDF error...") |
||||
|
gt_mesh = trimesh.load(args.compare) |
||||
|
avg_abs, avg_rel, max_abs, max_rel = compute_sdf_error( |
||||
|
model, gt_mesh, args.compres, device |
||||
|
) |
||||
|
print(f"Average Absolute Error: {avg_abs:.4f}") |
||||
|
print(f"Average Relative Error: {avg_rel:.4f}") |
||||
|
print(f"Max Absolute Error: {max_abs:.4f}") |
||||
|
print(f"Max Relative Error: {max_rel:.4f}") |
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
main() |
@ -0,0 +1,143 @@ |
|||||
|
import trimesh |
||||
|
import numpy as np |
||||
|
from brep2sdf.data.sampler import sample_zero_surface_points_and_normals |
||||
|
from brep2sdf.utils.load import get_namelist, get_step_paths |
||||
|
from brep2sdf.networks.network import gradient |
||||
|
import torch |
||||
|
import os |
||||
|
from brep2sdf.utils.logger import logger |
||||
|
|
||||
|
def position_loss(pred_sdfs: torch.Tensor) -> torch.Tensor: |
||||
|
"""位置损失函数""" |
||||
|
# 保持梯度流 |
||||
|
squared_diff = torch.pow(pred_sdfs, 2) |
||||
|
return torch.mean(squared_diff) |
||||
|
|
||||
|
def average_normal_error(normals1: torch.Tensor, normals2: torch.Tensor) -> torch.Tensor: |
||||
|
""" |
||||
|
计算平均法向量误差 (NAE) |
||||
|
:param normals1: 形状为 (B, 3) 的法向量张量 |
||||
|
:param normals2: 形状为 (B, 3) 的法向量张量 |
||||
|
:return: NAE 值 |
||||
|
""" |
||||
|
dot_products = torch.sum(normals1 * normals2, dim=-1) |
||||
|
absolute_dot_products = torch.abs(dot_products) |
||||
|
angle_errors = 1 - absolute_dot_products |
||||
|
return torch.mean(angle_errors) |
||||
|
|
||||
|
|
||||
|
|
||||
|
def load_model(model_path): |
||||
|
"""加载模型的通用函数""" |
||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
||||
|
try: |
||||
|
model = torch.jit.load(model_path).to(device) |
||||
|
logger.info(f"成功加载模型: {model_path}") |
||||
|
return model |
||||
|
except Exception as e: |
||||
|
logger.error(f"加载模型 {model_path} 时出错: {e}") |
||||
|
return None |
||||
|
|
||||
|
def nh(model_path, points): |
||||
|
model = load_model(model_path) |
||||
|
if model is None: |
||||
|
return None |
||||
|
try: |
||||
|
return model(points) |
||||
|
except Exception as e: |
||||
|
logger.error(f"调用 NH 模型时出错: {e}") |
||||
|
return None |
||||
|
|
||||
|
def mine(model_path, points): |
||||
|
model = load_model(model_path) |
||||
|
if model is None: |
||||
|
return None |
||||
|
try: |
||||
|
return model.forward_background(points) |
||||
|
except Exception as e: |
||||
|
logger.error(f"调用 mine 模型时出错: {e}") |
||||
|
return None |
||||
|
|
||||
|
def run(name): |
||||
|
# 替换为实际的 obj 文件路径 |
||||
|
obj_file_path = f"/home/wch/brep2sdf/data/gt_mesh/{name}.obj" |
||||
|
model_path = f"/home/wch/brep2sdf/data/output_data/{name}.pt" |
||||
|
nh_model = f"/home/wch/NH-Rep/data_backup/output_data/extracted/output_data/{name}_0_50k_model_h.pt" |
||||
|
|
||||
|
# 检查文件是否存在 |
||||
|
if not os.path.isfile(obj_file_path): |
||||
|
logger.error(f"OBJ 文件 {obj_file_path} 不存在。") |
||||
|
return |
||||
|
|
||||
|
try: |
||||
|
# 读取 obj 文件 |
||||
|
mesh = trimesh.load_mesh(obj_file_path) |
||||
|
logger.info(f"成功读取 OBJ 文件: {obj_file_path}") |
||||
|
except Exception as e: |
||||
|
logger.error(f"读取 OBJ 文件 {obj_file_path} 时出错: {e}") |
||||
|
return |
||||
|
|
||||
|
try: |
||||
|
# 调用采样函数 |
||||
|
result1 = sample_zero_surface_points_and_normals(mesh, num_samples=4096) |
||||
|
if result1 is None: |
||||
|
logger.error("采样失败,返回 None") |
||||
|
return |
||||
|
# 提取前 3 列作为坐标点 |
||||
|
coordinates = result1[:, :3] |
||||
|
# 将 ndarray 转换为 Tensor 并移动到设备上 |
||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
||||
|
coordinates_tensor = torch.from_numpy(coordinates).float().to(device).requires_grad_(True) |
||||
|
|
||||
|
sdf1 = nh(nh_model, coordinates_tensor) / 2 |
||||
|
sdf2 = mine(model_path, coordinates_tensor) |
||||
|
|
||||
|
loss1, loss2 = {}, {} |
||||
|
if sdf1 is not None and sdf2 is not None: |
||||
|
loss1["de"] = position_loss(sdf1).item() |
||||
|
loss2["de"] = position_loss(sdf2).item() |
||||
|
logger.info(f"NH 模型位置损失: {loss1}") |
||||
|
logger.info(f"Mine 模型位置损失: {loss2}") |
||||
|
|
||||
|
# 将 gt_normal 转换为 torch.Tensor 并移动到设备上 |
||||
|
gt_normal = torch.from_numpy(result1[:, 3:6]).float().to(device) |
||||
|
# 假设 gradient 函数已正确导入 |
||||
|
normal1 = gradient(coordinates_tensor, sdf1) |
||||
|
normal2 = gradient(coordinates_tensor, sdf2) |
||||
|
|
||||
|
loss1["nae"] = average_normal_error(gt_normal, normal1).item() |
||||
|
loss2["nae"] = average_normal_error(gt_normal, normal2).item() |
||||
|
|
||||
|
print("NH 模型的平均法向量误差 (NAE):", loss1["nae"]) |
||||
|
print("Mine 模型的平均法向量误差 (NAE):", loss2["nae"]) |
||||
|
|
||||
|
return loss1, loss2 |
||||
|
else: |
||||
|
logger.error("无法计算损失,SDF 结果为 None") |
||||
|
|
||||
|
except Exception as e: |
||||
|
logger.error(f"处理过程中出现错误: {e}") |
||||
|
|
||||
|
def main(): |
||||
|
names = get_namelist("/home/wch/brep2sdf/data/name_list.txt") |
||||
|
tl1_de, tl1_nae, tl2_de, tl2_nae = 0.0, 0.0, 0.0, 0.0 |
||||
|
valid_count = 0 |
||||
|
for name in names: |
||||
|
result = run(name) |
||||
|
if result is not None: |
||||
|
l1, l2 = result |
||||
|
tl1_de += l1["de"] |
||||
|
tl1_nae += l1["nae"] |
||||
|
tl2_de += l2["de"] |
||||
|
tl2_nae += l2["nae"] |
||||
|
valid_count += 1 |
||||
|
if valid_count > 0: |
||||
|
print(f"NH 模型平均位置损失 (de): {tl1_de/valid_count}") |
||||
|
print(f"NH 模型平均法向量误差 (nae): {tl1_nae/valid_count}") |
||||
|
print(f"Mine 模型平均位置损失 (de): {tl2_de/valid_count}") |
||||
|
print(f"Mine 模型平均法向量误差 (nae): {tl2_nae/valid_count}") |
||||
|
else: |
||||
|
print("没有有效的结果,无法计算平均值。") |
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
main() |
@ -0,0 +1,124 @@ |
|||||
|
import os |
||||
|
import glob |
||||
|
import argparse |
||||
|
from OCC.Core.STEPControl import STEPControl_Reader |
||||
|
from OCC.Core.TopExp import TopExp_Explorer |
||||
|
from OCC.Core.TopAbs import TopAbs_FACE |
||||
|
import matplotlib.pyplot as plt |
||||
|
import csv |
||||
|
from tqdm import tqdm # 导入 tqdm 库 |
||||
|
import concurrent.futures |
||||
|
from multiprocessing import Process, Queue |
||||
|
|
||||
|
# 定义一个新的函数,用于在子进程中执行计数操作 |
||||
|
def count_faces_task(file_path, result_queue): |
||||
|
try: |
||||
|
# 创建 STEP 读取器 |
||||
|
reader = STEPControl_Reader() |
||||
|
# 读取 STEP 文件 |
||||
|
status = reader.ReadFile(file_path) |
||||
|
if status == 1: |
||||
|
reader.TransferRoots() |
||||
|
shape = reader.OneShape() |
||||
|
# 遍历所有面 |
||||
|
explorer = TopExp_Explorer(shape, TopAbs_FACE) |
||||
|
face_count = 0 |
||||
|
while explorer.More(): |
||||
|
face_count += 1 |
||||
|
explorer.Next() |
||||
|
#print(face_count) |
||||
|
result_queue.put(face_count) |
||||
|
else: |
||||
|
print(f"无法读取文件 {file_path}") |
||||
|
result_queue.put(None) |
||||
|
except Exception as e: |
||||
|
print(f"处理文件 {file_path} 时出错: {e}") |
||||
|
result_queue.put(None) |
||||
|
|
||||
|
def count_faces_in_step_file(file_path, timeout=30): |
||||
|
result_queue = Queue() |
||||
|
p = Process(target=count_faces_task, args=(file_path, result_queue)) |
||||
|
p.start() |
||||
|
p.join(timeout) |
||||
|
|
||||
|
if p.is_alive(): |
||||
|
print(f"处理文件 {file_path} 超时,已终止") |
||||
|
p.terminate() |
||||
|
p.join() |
||||
|
return None |
||||
|
|
||||
|
result = result_queue.get() |
||||
|
return result |
||||
|
|
||||
|
def main(): |
||||
|
parser = argparse.ArgumentParser(description='统计 ABC 数据集模型面的数量并可视化') |
||||
|
parser.add_argument('-i','--input_dir', type=str, required=True, help='包含 STEP 文件的输入目录') |
||||
|
parser.add_argument('-o', '--output_file', type=str, default='face_counts.csv', help='保存面数量数据的 CSV 文件路径') |
||||
|
# 新增参数,用于指定进程数 |
||||
|
parser.add_argument('--processes', type=int, default=os.cpu_count()-1, help='并行处理的进程数,默认为 CPU 核心数') |
||||
|
args = parser.parse_args() |
||||
|
|
||||
|
# 读取已处理的文件名 |
||||
|
processed_files = set() |
||||
|
if os.path.exists(args.output_file): |
||||
|
with open(args.output_file, 'r', newline='') as csvfile: |
||||
|
reader = csv.reader(csvfile) |
||||
|
try: |
||||
|
next(reader) # 尝试跳过表头 |
||||
|
except StopIteration: |
||||
|
# 如果文件为空,直接跳过 |
||||
|
pass |
||||
|
for row in reader: |
||||
|
processed_files.add(row[0]) |
||||
|
|
||||
|
# 获取所有 STEP 文件并过滤掉已处理的文件 |
||||
|
step_files = glob.glob(os.path.join(args.input_dir, "**/*.step"), recursive=True) |
||||
|
step_files = [file for file in step_files if os.path.basename(file) not in processed_files] |
||||
|
|
||||
|
# 划分批次 |
||||
|
num_processes = args.processes |
||||
|
batch_size = len(step_files) // num_processes + 1 |
||||
|
batches = [step_files[i:i + batch_size] for i in range(0, len(step_files), batch_size)] |
||||
|
|
||||
|
face_counts = [] |
||||
|
# 打开 CSV 文件,准备逐批次写入 |
||||
|
with open(args.output_file, 'a+', newline='') as csvfile: |
||||
|
writer = csv.writer(csvfile) |
||||
|
# 写入表头 |
||||
|
writer.writerow(['文件名', '面的数量']) |
||||
|
|
||||
|
for batch in tqdm(batches, desc="处理批次进度"): |
||||
|
batch_results = [] |
||||
|
with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor: |
||||
|
future_to_file = {executor.submit(count_faces_in_step_file, file_path): file_path for file_path in batch} |
||||
|
for future in concurrent.futures.as_completed(future_to_file): |
||||
|
file_path = future_to_file[future] |
||||
|
try: |
||||
|
# 指定超时时间 |
||||
|
result = future.result(timeout=30) |
||||
|
except concurrent.futures.TimeoutError: |
||||
|
print(f'{file_path} 处理超时,已终止') |
||||
|
continue |
||||
|
except Exception as exc: |
||||
|
print(f'{file_path} 产生了异常: {exc}') |
||||
|
else: |
||||
|
if result is not None: |
||||
|
batch_results.append((file_path, result)) |
||||
|
# 逐批次写入 CSV |
||||
|
writer.writerow([os.path.basename(file_path), result]) |
||||
|
|
||||
|
face_counts.extend(batch_results) |
||||
|
|
||||
|
if face_counts: |
||||
|
# 绘制直方图,需要提取面数 |
||||
|
face_counts_only = [count for _, count in face_counts] |
||||
|
plt.hist(face_counts_only, bins=50, edgecolor='black') |
||||
|
plt.title('ABC 数据集模型面数量直方图') |
||||
|
plt.xlabel('面的数量') |
||||
|
plt.ylabel('模型数量') |
||||
|
plt.show() |
||||
|
else: |
||||
|
print("未找到有效的 STEP 文件或处理过程中出现错误。") |
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
main() |
@ -0,0 +1,25 @@ |
|||||
|
import torch |
||||
|
import numpy as np |
||||
|
import pickle |
||||
|
|
||||
|
|
||||
|
|
||||
|
def load_brep_file(brep_path): |
||||
|
with open(brep_path, 'rb') as f: |
||||
|
brep_raw = pickle.load(f) |
||||
|
return brep_raw |
||||
|
|
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
data=load_brep_file("/home/wch/brep2sdf/data/output_data/00000031.xyz") |
||||
|
surfs =data["train_surf_ncs"] |
||||
|
print(surfs) |
||||
|
with open("0031_t.xyz","w") as f: |
||||
|
for point in surfs: |
||||
|
#f.write(f"{point[0]} {point[1]} {point[2]}\n") |
||||
|
f.write(f"{point[0]} {point[1]} {point[2]} {point[3]} {point[4]} {point[5]}\n") |
||||
|
''' |
||||
|
for surf in surfs: |
||||
|
for point in surf: |
||||
|
f.write(f"{point[0]} {point[1]} {point[2]}\n") |
||||
|
''' |
@ -0,0 +1,44 @@ |
|||||
|
import torch |
||||
|
import torch.nn as nn |
||||
|
from torchviz import make_dot |
||||
|
|
||||
|
class SimpleEncoder(nn.Module): |
||||
|
def __init__(self, feature_dim): |
||||
|
super(SimpleEncoder, self).__init__() |
||||
|
self.simple_encoder = nn.Sequential( |
||||
|
nn.Linear(3, 256), |
||||
|
nn.BatchNorm1d(256), |
||||
|
nn.ReLU(), |
||||
|
nn.Linear(256, 512), |
||||
|
nn.BatchNorm1d(512), |
||||
|
nn.ReLU(), |
||||
|
nn.Linear(512, 256), |
||||
|
nn.BatchNorm1d(256), |
||||
|
nn.ReLU(), |
||||
|
nn.Linear(256, feature_dim) |
||||
|
) |
||||
|
|
||||
|
def forward(self, x): |
||||
|
return self.simple_encoder(x) |
||||
|
|
||||
|
# 创建模型实例 |
||||
|
feature_dim = 8 # 根据你的需求设定 |
||||
|
model = SimpleEncoder(feature_dim) |
||||
|
|
||||
|
# 方法一:将模型设置为评估模式 |
||||
|
model.eval() |
||||
|
|
||||
|
# 方法二:增加输入数据的批次大小 |
||||
|
# x = torch.randn(2, 3) # 将批次大小从 1 改为 2 |
||||
|
|
||||
|
# 创建随机输入张量(根据实际情况调整大小) |
||||
|
x = torch.randn(1, 3) |
||||
|
|
||||
|
# 获取模型输出 |
||||
|
output = model(x) |
||||
|
|
||||
|
# 使用torchviz生成模型图 |
||||
|
dot = make_dot(output, params=dict(list(model.named_parameters()))) |
||||
|
|
||||
|
# 保存图像文件 |
||||
|
dot.render("simple_encoder", format="png") |
@ -0,0 +1,69 @@ |
|||||
|
import numpy as np |
||||
|
import torch |
||||
|
import argparse |
||||
|
import time |
||||
|
from brep2sdf.utils.logger import logger |
||||
|
|
||||
|
def create_grid(depth, box_size): |
||||
|
""" |
||||
|
创建三维网格点 |
||||
|
:param depth: 网格深度(决定分辨率) |
||||
|
:param box_size: 边界框大小(边长) |
||||
|
:return: 网格点数组和坐标网格 |
||||
|
""" |
||||
|
grid_size = 2**depth + 1 |
||||
|
start = -box_size / 2 |
||||
|
end = box_size / 2 |
||||
|
x = np.linspace(start, end, grid_size) |
||||
|
y = np.linspace(start, end, grid_size) |
||||
|
z = np.linspace(start, end, grid_size) |
||||
|
xx, yy, zz = np.meshgrid(x, y, z, indexing='ij') |
||||
|
points = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1) |
||||
|
|
||||
|
# 新增归一化处理 |
||||
|
max_coord = np.max(np.abs(points)) |
||||
|
points = points / max_coord # 归一化到[-1,1] |
||||
|
return points, xx, yy, zz |
||||
|
|
||||
|
def predict_sdf(model, points, device): |
||||
|
""" |
||||
|
使用模型预测SDF值 |
||||
|
:param model: PyTorch模型 |
||||
|
:param points: 输入点坐标 (N, 3) |
||||
|
:param device: 设备(CPU/GPU) |
||||
|
:return: SDF值数组 (N,) |
||||
|
""" |
||||
|
points_t = torch.from_numpy(points).float().to(device) |
||||
|
|
||||
|
with torch.no_grad(): |
||||
|
sdf = model.forward_background(points_t).cpu().numpy().flatten() |
||||
|
return sdf |
||||
|
|
||||
|
def main(): |
||||
|
parser = argparse.ArgumentParser(description='SDF Visualization') |
||||
|
parser.add_argument('-i', '--input', type=str, required=True, help='Input model file (.pt)') |
||||
|
parser.add_argument('--depth', type=int, default=7, help='网格深度(分辨率)') |
||||
|
parser.add_argument('--box_size', type=float, default=2.0, |
||||
|
help='边界框大小(建议设为2.0以得到[-1,1]范围)') |
||||
|
parser.add_argument('--use-gpu', action='store_true', help='使用GPU') |
||||
|
parser.add_argument('-o', '--output', type=str, default='sdf_data.npz', help='输出SDF数据文件(.npz格式)') |
||||
|
args = parser.parse_args() |
||||
|
|
||||
|
# 设置设备 |
||||
|
device = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu") |
||||
|
print(f"Using device: {device}") |
||||
|
|
||||
|
model = torch.jit.load(args.input).to(device) |
||||
|
model.eval() |
||||
|
|
||||
|
# 创建网格并预测SDF |
||||
|
points, xx, yy, zz = create_grid(args.depth, args.box_size) |
||||
|
sdf = predict_sdf(model, points, device) |
||||
|
sdf_grid = sdf.reshape(xx.shape) |
||||
|
|
||||
|
# 保存SDF数据到文件 |
||||
|
np.savez(args.output, xx=xx, yy=yy, zz=zz, sdf_grid=sdf_grid) |
||||
|
print(f"SDF数据已保存到 {args.output}") |
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
main() |
@ -1,69 +1,112 @@ |
|||||
|
import trimesh |
||||
|
import numpy as np |
||||
|
from brep2sdf.data.sampler import sample_zero_surface_points_and_normals |
||||
|
from brep2sdf.networks.network import gradient |
||||
import torch |
import torch |
||||
|
import logging |
||||
|
|
||||
|
# 配置日志记录 |
||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
||||
|
|
||||
|
def position_loss(pred_sdfs: torch.Tensor) -> torch.Tensor: |
||||
|
"""位置损失函数""" |
||||
|
# 保持梯度流 |
||||
|
squared_diff = torch.pow(pred_sdfs, 2) |
||||
|
return torch.mean(squared_diff) |
||||
|
|
||||
|
def average_normal_error(normals1: torch.Tensor, normals2: torch.Tensor) -> torch.Tensor: |
||||
|
""" |
||||
|
计算平均法向量误差 (NAE) |
||||
|
:param normals1: 形状为 (B, 3) 的法向量张量 |
||||
|
:param normals2: 形状为 (B, 3) 的法向量张量 |
||||
|
:return: NAE 值 |
||||
|
""" |
||||
|
dot_products = torch.sum(normals1 * normals2, dim=-1) |
||||
|
absolute_dot_products = torch.abs(dot_products) |
||||
|
angle_errors = 1 - absolute_dot_products |
||||
|
return torch.mean(angle_errors) |
||||
|
|
||||
|
def |
||||
|
|
||||
|
# ========== |
||||
|
def load_model(model_path): |
||||
|
"""加载模型的通用函数""" |
||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
||||
|
try: |
||||
|
model = torch.jit.load(model_path).to(device) |
||||
|
logging.info(f"成功加载模型: {model_path}") |
||||
|
return model |
||||
|
except Exception as e: |
||||
|
logging.error(f"加载模型 {model_path} 时出错: {e}") |
||||
|
return None |
||||
|
|
||||
|
|
||||
|
#========== |
||||
|
def nh(model_path, points): |
||||
|
model = load_model(model_path) |
||||
|
if model is None: |
||||
|
return None |
||||
|
try: |
||||
|
return model(points) |
||||
|
except Exception as e: |
||||
|
logging.error(f"调用 NH 模型时出错: {e}") |
||||
|
return None |
||||
|
|
||||
|
def mine(model_path, points): |
||||
|
model = load_model(model_path) |
||||
|
if model is None: |
||||
|
return None |
||||
|
try: |
||||
|
return model.forward_background(points) |
||||
|
except Exception as e: |
||||
|
logging.error(f"调用 mine 模型时出错: {e}") |
||||
|
return None |
||||
|
|
||||
|
def main(): |
||||
|
# 替换为实际的 obj 文件路径 |
||||
|
obj_file_path = "/home/wch/brep2sdf/data/gt_mesh/00000031.obj" |
||||
|
model_path = "/home/wch/brep2sdf/data/output_data/00000031.pt" |
||||
|
nh_model = "/home/wch/NH-Rep/data/output_data/00000031_0_50k_model_h.pt" |
||||
|
|
||||
|
try: |
||||
|
# 读取 obj 文件 |
||||
|
mesh = trimesh.load_mesh(obj_file_path) |
||||
|
logging.info(f"成功读取 OBJ 文件: {obj_file_path}") |
||||
|
except Exception as e: |
||||
|
logging.error(f"读取 OBJ 文件 {obj_file_path} 时出错: {e}") |
||||
|
return |
||||
|
|
||||
|
try: |
||||
|
# 调用采样函数 |
||||
|
result1 = sample_zero_surface_points_and_normals(mesh, num_samples=4096) |
||||
|
if result1 is None: |
||||
|
logging.error("采样失败,返回 None") |
||||
|
return |
||||
|
# 提取前 3 列作为坐标点 |
||||
|
coordinates = result1[:, :3] |
||||
|
# 将 ndarray 转换为 Tensor 并移动到设备上 |
||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
||||
|
coordinates_tensor = torch.from_numpy(coordinates).float().to(device) |
||||
|
|
||||
|
sdf1 = nh(nh_model, coordinates_tensor) / 2 |
||||
|
sdf2 = mine(model_path, coordinates_tensor) |
||||
|
|
||||
|
if sdf1 is not None and sdf2 is not None: |
||||
|
loss1_ = position_loss(sdf1) |
||||
|
loss2 = position_loss(sdf2) |
||||
|
logging.info(f"NH 模型位置损失: {loss1.item()}") |
||||
|
logging.info(f"Mine 模型位置损失: {loss2.item()}") |
||||
|
|
||||
|
gt_normal = result1[:, 3:6] |
||||
|
normal1 = gradient(coordinates, sdf1) |
||||
|
normal2 = gradient(coordinates, sdf2) |
||||
|
nae1=average_normal_error(gt_normal, normal1) |
||||
|
nae2=average_normal_error(gt_normal, normal2) |
||||
|
else: |
||||
|
logging.error("无法计算损失,SDF 结果为 None") |
||||
|
|
||||
|
except Exception as e: |
||||
|
logging.error(f"处理过程中出现错误: {e}") |
||||
|
|
||||
from typing import List, Tuple |
|
||||
|
|
||||
def bbox_intersect(surf_bboxes: torch.Tensor, indices: torch.Tensor, child_bboxes: torch.Tensor) -> torch.Tensor: |
|
||||
''' |
|
||||
args: |
|
||||
surf_bboxes: [B, 6] - 表示多个包围盒的张量,每个包围盒由其最小和最大坐标定义。 |
|
||||
indices: 选择bbox, [N], N <= B - 用于选择特定包围盒的索引张量。 |
|
||||
child_bboxes: [8, 6] - 一个包围盒被分成八个子包围盒后的结果。 |
|
||||
return: |
|
||||
intersect_mask: [8, N] - 布尔掩码,表示每个子包围盒与选择的包围盒是否相交。 |
|
||||
''' |
|
||||
# 提取选中的边界框 |
|
||||
selected_bboxes = surf_bboxes[indices] # 形状为 [N, 6] |
|
||||
min1, max1 = selected_bboxes[:, :3], selected_bboxes[:, 3:] # 形状为 [N, 3] |
|
||||
min2, max2 = child_bboxes[:, :3], child_bboxes[:, 3:] # 形状为 [8, 3] |
|
||||
|
|
||||
# 确保广播机制正常工作 |
|
||||
intersect_mask = torch.all( |
|
||||
(max1.unsqueeze(0) >= min2.unsqueeze(1)) & # 形状为 [8, N, 3] |
|
||||
(max2.unsqueeze(1) >= min1.unsqueeze(0)), # 形状为 [8, N, 3] |
|
||||
dim=-1 |
|
||||
) # 最终形状为 [8, N] |
|
||||
|
|
||||
return intersect_mask |
|
||||
|
|
||||
# 测试程序 |
|
||||
if __name__ == "__main__": |
if __name__ == "__main__": |
||||
# 构造输入数据 |
main() |
||||
surf_bboxes = torch.tensor([ |
|
||||
[0, 0, 0, 1, 1, 1], # 立方体 1 |
|
||||
[0.5, 0.5, 0.5, 1.5, 1.5, 1.5], # 立方体 2 |
|
||||
[2, 2, 2, 3, 3, 3] # 立方体 3 |
|
||||
]) # [B=3, 6] |
|
||||
|
|
||||
indices = torch.tensor([0, 1]) # 选择前两个立方体 |
|
||||
|
|
||||
# 假设父边界框为 [0, 0, 0, 2, 2, 2],生成其八个子边界框 |
|
||||
parent_bbox = torch.tensor([0, 0, 0, 2, 2, 2]) |
|
||||
center = (parent_bbox[:3] + parent_bbox[3:]) / 2 |
|
||||
child_bboxes = torch.tensor([ |
|
||||
[parent_bbox[0], parent_bbox[1], parent_bbox[2], center[0], center[1], center[2]], # 左下前 |
|
||||
[center[0], parent_bbox[1], parent_bbox[2], parent_bbox[3], center[1], center[2]], # 右下前 |
|
||||
[parent_bbox[0], center[1], parent_bbox[2], center[0], parent_bbox[4], center[2]], # 左上前 |
|
||||
[center[0], center[1], parent_bbox[2], parent_bbox[3], parent_bbox[4], center[2]], # 右上前 |
|
||||
[parent_bbox[0], parent_bbox[1], center[2], center[0], center[1], parent_bbox[5]], # 左下后 |
|
||||
[center[0], parent_bbox[1], center[2], parent_bbox[3], center[1], parent_bbox[5]], # 右下后 |
|
||||
[parent_bbox[0], center[1], center[2], center[0], parent_bbox[4], parent_bbox[5]], # 左上后 |
|
||||
[center[0], center[1], center[2], parent_bbox[3], parent_bbox[4], parent_bbox[5]] # 右上后 |
|
||||
]) # [8, 6] |
|
||||
|
|
||||
# 调用函数 |
|
||||
intersect_mask = bbox_intersect(surf_bboxes, indices, child_bboxes) |
|
||||
|
|
||||
# 输出结果 |
|
||||
print("Intersect Mask:") |
|
||||
print(intersect_mask) |
|
||||
|
|
||||
# 将布尔掩码转换为索引列表 |
|
||||
child_indices = [] |
|
||||
for i in range(8): # 遍历每个子节点 |
|
||||
intersecting_faces = indices[intersect_mask[i]] # 获取当前子节点的相交面片索引 |
|
||||
child_indices.append(intersecting_faces) |
|
||||
|
|
||||
# 打印每个子节点对应的相交索引 |
|
||||
print("\nChild Indices:") |
|
||||
for i, indices in enumerate(child_indices): |
|
||||
print(f"Child {i}: {indices}") |
|
Loading…
Reference in new issue