|
|
|
import os
|
|
|
|
import time
|
|
|
|
import glob
|
|
|
|
import trimesh
|
|
|
|
import numpy as np
|
|
|
|
from tqdm import tqdm
|
|
|
|
import mesh2sdf
|
|
|
|
import skimage.measure
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
|
|
import logging
|
|
|
|
from typing import Tuple, Optional
|
|
|
|
|
|
|
|
from OCC.Core.STEPControl import STEPControl_Reader
|
|
|
|
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
|
|
|
|
from OCC.Core.StlAPI import StlAPI_Writer
|
|
|
|
from OCC.Core.IFSelect import IFSelect_RetDone
|
|
|
|
|
|
|
|
# Logging configuration.
def setup_logger():
    """Return the 'furniture_processing' logger, configuring it on first use.

    On the first call this attaches two handlers, both at INFO level with a
    ``asctime - levelname - message`` format:
    a file handler writing ``furniture_processing.log`` (UTF-8) and a console
    handler. Later calls return the same logger without adding handlers again.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger('furniture_processing')
    logger.setLevel(logging.INFO)

    # getLogger returns a process-wide singleton; without this guard every
    # extra call would attach another pair of handlers and each record would
    # be emitted multiple times.
    if logger.handlers:
        return logger

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # File handler. The log messages in this pipeline are Chinese text, so the
    # encoding is pinned instead of relying on the platform default.
    file_handler = logging.FileHandler('furniture_processing.log', encoding='utf-8')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)

    # Console handler.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    return logger


logger = setup_logger()
|
|
|
|
|
|
|
|
def brep_to_mesh(
    brep_file: str,
    stl_output_file: str,
    obj_output_file: str
) -> Optional[trimesh.Trimesh]:
    """Convert a STEP file to a triangle mesh and save it as STL and OBJ.

    The file is read with ``STEPControl_Reader`` and tessellated with
    ``BRepMesh_IncrementalMesh`` (linear deflection 0.1). It is written to
    ``stl_output_file`` with ``StlAPI_Writer``. The STL is then loaded back
    with trimesh and exported to ``obj_output_file``.

    Args:
        brep_file: Path of the input STEP file.
        stl_output_file: Destination path of the STL file.
        obj_output_file: Destination path of the OBJ file.

    Returns:
        The mesh loaded from the written STL. Returns ``None`` if any stage
        fails; the reason is logged.
    """
    try:
        step_reader = STEPControl_Reader()
        status = step_reader.ReadFile(brep_file)
        if status != IFSelect_RetDone:
            logger.error(f"无法读取BREP文件: {brep_file}")
            return None

        # TransferRoots returns how many roots were converted into shapes;
        # zero means the file contained nothing usable.
        if step_reader.TransferRoots() == 0:
            logger.error(f"STEP根节点转换失败: {brep_file}")
            return None
        shape = step_reader.OneShape()
        if shape.IsNull():
            logger.error(f"STEP文件中没有有效形状: {brep_file}")
            return None

        # Tessellate the B-rep so that StlAPI_Writer has triangles to write.
        mesher = BRepMesh_IncrementalMesh(shape, 0.1)
        mesher.Perform()
        if not mesher.IsDone():
            logger.error(f"网格剖分失败: {brep_file}")
            return None

        base_name = os.path.splitext(os.path.basename(brep_file))[0]

        # Save STL. Recent OCCT versions return a bool here, while older bindings
        # return None. Only an explicit False is treated as a failure.
        if StlAPI_Writer().Write(shape, stl_output_file) is False:
            logger.error(f"写入STL失败: {stl_output_file}")
            return None

        # Save OBJ. force='mesh' guarantees a single Trimesh, even if the
        # loader produces a Scene.
        mesh = trimesh.load(stl_output_file, force='mesh')
        if mesh.is_empty:
            logger.error(f"STL网格为空: {stl_output_file}")
            return None
        mesh.export(obj_output_file)

        logger.debug(f"成功转换并保存mesh: {base_name}")
        return mesh

    except Exception as e:
        logger.error(f"转换BREP文件失败 {brep_file}: {str(e)}")
        return None
|
|
|
|
|
|
|
|
# Attach x, y, z voxel coordinates to every SDF value so that the samples can
# later be split into positive and negative groups.
def sdf2xyzsdf(sdf, size=128):
    """Append integer voxel indices to a cubic SDF grid.

    Args:
        sdf: Grid of shape ``(size, size, size)``.
        size: Edge length of the grid.

    Returns:
        Array of shape ``(size, size, size, 4)``. Entry ``[i, j, k]`` is
        ``(i, j, k, sdf[i, j, k])``.
    """
    # np.indices yields shape (3, size, size, size); move the component axis last.
    ijk = np.moveaxis(np.indices((size, size, size)), 0, -1)
    return np.concatenate((ijk, sdf[..., None]), axis=-1)
|
|
|
|
|
|
|
|
def split_sdf(xyzsdf):
    """Partition (x, y, z, sdf) samples by the sign of their SDF value.

    Args:
        xyzsdf: Array whose last axis holds ``(x, y, z, sdf)``; index 3 is the
            SDF value.

    Returns:
        ``(pos, neg)``. ``pos`` holds the rows with ``sdf >= 0`` and ``neg``
        holds the rows with ``sdf < 0``, each flattened to ``(N, 4)``. A NaN
        SDF satisfies neither comparison, so such rows appear in neither output.
    """
    distances = xyzsdf[..., 3]
    non_negative = xyzsdf[distances >= 0]
    negative = xyzsdf[distances < 0]
    return non_negative, negative
|
|
|
|
|
|
|
|
|
|
|
|
def mesh_to_sdf(
    mesh: trimesh.Trimesh,
    filename: str,
    sdf_output_file: str,
    mesh_scale: float = 0.8,
    size: int = 128,
    level: Optional[float] = None,
) -> bool:
    """Normalize a mesh, compute its SDF grid and save it as a .npz file.

    The mesh is centered on its bounding-box center and scaled so that its
    longest bounding-box edge spans ``2 * mesh_scale``. ``mesh2sdf.compute``
    is called with ``fix=True``. The resulting grid is turned into
    (x, y, z, sdf) rows via ``sdf2xyzsdf``, split by sign with ``split_sdf``
    and saved to ``sdf_output_file`` under the keys ``pos`` and ``neg``.

    Args:
        mesh: Input triangle mesh.
        filename: Name used only in log messages.
        sdf_output_file: Destination .npz path.
        mesh_scale: Half of the target length of the longest bounding-box edge.
        size: Grid resolution forwarded to ``mesh2sdf.compute``.
        level: ``level`` argument forwarded to ``mesh2sdf.compute``. Defaults
            to ``2 / size`` when ``None``.

    Returns:
        True on success, or False on any failure (logged).
    """
    try:
        # Use an identity check so that an explicit 0.0 is not silently replaced.
        if level is None:
            level = 2 / size

        # Normalize the mesh.
        vertices = mesh.vertices
        if len(vertices) == 0:
            logger.error(f"网格为空: {filename}")
            return False
        bbmin = vertices.min(0)
        bbmax = vertices.max(0)
        extent = (bbmax - bbmin).max()
        # A zero (or NaN) extent would make the scale inf/NaN and feed NaN
        # vertices to mesh2sdf without raising.
        if not extent > 0:
            logger.error(f"网格包围盒退化: {filename}")
            return False
        center = (bbmin + bbmax) * 0.5
        scale = 2.0 * mesh_scale / extent
        vertices = (vertices - center) * scale

        # Compute the SDF. The repaired mesh returned alongside it is not used here.
        sdf, _ = mesh2sdf.compute(
            vertices,
            mesh.faces,
            size,
            fix=True,
            level=level,
            return_mesh=True
        )

        # split_sdf drops NaN samples (they fail both sign tests), so
        # non-finite values must be rejected here or they would go unnoticed.
        if not np.isfinite(sdf).all():
            logger.error(f"SDF包含非有限值: {filename}")
            return False

        grid_size = sdf.shape[0]
        xyzsdf = sdf2xyzsdf(sdf, grid_size)
        pos_xyzsdf, neg_xyzsdf = split_sdf(xyzsdf)

        np.savez_compressed(sdf_output_file, pos=pos_xyzsdf, neg=neg_xyzsdf)

        logger.debug(f"成功生成并保存SDF: {os.path.basename(sdf_output_file)}")
        return True

    except Exception as e:
        logger.error(f"处理文件失败 {filename}: {str(e)}")
        return False
|
|
|
|
|
|
|
|
def check_conversion(
    step_file: str,
    obj_path: str,
    sdf_path: str
) -> bool:
    """Check the quality of one file's conversion outputs.

    The check passes only if all of the following hold:

    - both output files exist;
    - the OBJ mesh is watertight;
    - the .npz contains ``pos`` and ``neg`` arrays of shape ``(N, 4)``;
    - neither array contains NaN.

    Args:
        step_file: Source STEP path, used only to derive the name in log messages.
        obj_path: Path of the exported OBJ mesh.
        sdf_path: Path of the saved .npz SDF file.

    Returns:
        True if every check passes. Otherwise False, with the reason logged.
    """
    try:
        base_name = os.path.splitext(os.path.basename(step_file))[0]

        # Check that both files exist.
        if not (os.path.exists(obj_path) and os.path.exists(sdf_path)):
            logger.warning(f"文件不完整: {base_name}")
            return False

        # Load and check the OBJ. force='mesh' guarantees a Trimesh, since a
        # Scene has no is_watertight attribute.
        obj_mesh = trimesh.load(obj_path, force='mesh')
        if not obj_mesh.is_watertight:
            logger.warning(f"网格质量不合格: {base_name}")
            return False

        # Check the SDF. A context manager closes the NpzFile handle, which
        # otherwise stays open until garbage collection.
        with np.load(sdf_path) as sdf_data:
            if 'pos' not in sdf_data or 'neg' not in sdf_data:
                logger.warning(f"SDF数据格式不正确: {base_name}")
                return False
            # Each access decompresses the array, so read each one once.
            pos = sdf_data['pos']
            neg = sdf_data['neg']

        # Check dimensions: rows are (x, y, z, sdf) samples.
        for samples in (pos, neg):
            if samples.ndim != 2 or samples.shape[1] != 4:
                logger.warning(f"SDF数据格式不正确: {base_name}")
                return False

        # Check content.
        if np.isnan(pos).any() or np.isnan(neg).any():
            logger.warning(f"SDF数据包含NaN值: {base_name}")
            return False

        logger.debug(f"质量检查通过: {base_name}")
        return True

    except Exception as e:
        logger.error(f"检查失败 {step_file}: {str(e)}")
        return False
|
|
|
|
|
|
|
|
def process(step_file: str, set_name: str) -> bool:
    """Run the full pipeline for one STEP file.

    The stages are STEP -> mesh (STL + OBJ), mesh -> SDF (.npz) and a quality
    check. Outputs go to ``test_data/{stl,obj,sdf}/<set_name>/<base_name>.*``.

    Args:
        step_file: Path of the input STEP file.
        set_name: Dataset split name, used as the output sub-directory.

    Returns:
        True if every stage succeeded, otherwise False (logged).
    """
    try:
        base_name = os.path.splitext(os.path.basename(step_file))[0]
        logger.info(f"开始处理: {base_name}")

        # Prepare the output paths.
        stl_output_file = os.path.join('test_data', 'stl', set_name, f"{base_name}.stl")
        obj_output_file = os.path.join('test_data', 'obj', set_name, f"{base_name}.obj")
        sdf_output_file = os.path.join('test_data', 'sdf', set_name, f"{base_name}.npz")
        # Create the directories here too, so that process() also works when
        # called outside main(). Otherwise the writers would fail.
        for output_path in (stl_output_file, obj_output_file, sdf_output_file):
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

        # Per-file progress bar over the processing steps. The context manager
        # closes it on every exit path, including exceptions.
        steps = ['STEP->Mesh', 'Mesh->SDF', '质量检查']
        with tqdm(steps, desc=f"处理 {base_name}", leave=False, position=1) as pbar:
            # 1. STEP -> mesh conversion.
            pbar.set_description(f"处理 {base_name} [STEP->Mesh]")
            mesh = brep_to_mesh(step_file, stl_output_file, obj_output_file)
            if mesh is None:
                return False
            pbar.update(1)

            # 2. Mesh -> SDF conversion.
            pbar.set_description(f"处理 {base_name} [Mesh->SDF]")
            if not mesh_to_sdf(mesh, step_file, sdf_output_file, mesh_scale=0.8, size=128):
                return False
            pbar.update(1)

            # 3. Check the conversion results.
            pbar.set_description(f"处理 {base_name} [质量检查]")
            if not check_conversion(step_file, obj_output_file, sdf_output_file):
                return False
            pbar.update(1)

        logger.info(f"成功处理: {base_name}")
        return True

    except Exception as e:
        logger.error(f"处理失败 {step_file}: {str(e)}")
        return False
|
|
|
|
|
|
|
|
def main(
    input_root: str = '/mnt/disk2/dataset/furniture/step/furniture_dataset_step',
    set_names: Tuple[str, ...] = ('val', 'test'),
) -> None:
    """Process every STEP file of the given dataset splits in parallel.

    For each split, ``<input_root>/<split>/*.step`` is fanned out to a process
    pool running ``process``. The success rate is then logged.

    Args:
        input_root: Directory containing one sub-directory per split.
        set_names: Splits to process. Pass ``('train', 'val', 'test')`` to
            include the training split.
    """
    if not os.path.exists(input_root):
        logger.error(f"输入目录不存在: {input_root}")
        return

    # Create the base output directories.
    for kind in ('stl', 'obj', 'sdf'):
        os.makedirs(os.path.join('test_data', kind), exist_ok=True)

    logger.info("开始数据处理...")

    # os.cpu_count() may return None when the count cannot be determined.
    max_workers = min(max(1, (os.cpu_count() or 1) // 2), 8)

    for set_name in set_names:
        input_dir = os.path.join(input_root, set_name)
        if not os.path.exists(input_dir):
            logger.warning(f"目录不存在: {input_dir}")
            continue

        for kind in ('stl', 'obj', 'sdf'):
            os.makedirs(os.path.join('test_data', kind, set_name), exist_ok=True)

        step_files = glob.glob(os.path.join(input_dir, '*.step'))
        total_files = len(step_files)
        if total_files == 0:
            logger.error("没有找到任何STEP文件")
            # Skip the split. Otherwise the final success-rate computation
            # divides by zero and aborts every remaining split.
            continue
        logger.info(f"找到 {set_name} 集合文件: {total_files}个")

        valid_conversions = 0
        processed = 0
        with tqdm(total=total_files, desc=f"{set_name}进度", position=0) as main_pbar:
            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                futures = {executor.submit(process, f, set_name): f for f in step_files}

                for future in as_completed(futures):
                    try:
                        # as_completed only yields finished futures, so a
                        # timeout on result() would never apply.
                        if future.result():
                            valid_conversions += 1
                    except Exception as e:
                        logger.error(f"处理失败 {futures[future]}: {str(e)}")
                    processed += 1
                    main_pbar.update(1)
                    # Rate over the files finished so far, not the full total.
                    main_pbar.set_postfix({
                        'success_rate': f"{(valid_conversions / processed) * 100:.1f}%"
                    })

        success_rate = (valid_conversions / total_files) * 100
        logger.info(f"处理完成: {set_name} 集合, 成功率: {success_rate:.2f}% = {valid_conversions}/{total_files}个")


if __name__ == "__main__":
    main()
|