import os
import time
import glob
import trimesh
import numpy as np
from tqdm import tqdm
import mesh2sdf
import skimage.measure
import matplotlib.pyplot as plt
from concurrent.futures import ProcessPoolExecutor, as_completed
import logging
from typing import Tuple, Optional
from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
from OCC.Core.StlAPI import StlAPI_Writer
from OCC.Core.IFSelect import IFSelect_RetDone


# Configure logging
def setup_logger():
    logger = logging.getLogger('furniture_processing')
    logger.setLevel(logging.INFO)
    # File handler
    fh = logging.FileHandler('furniture_processing.log')
    fh.setLevel(logging.INFO)
    # Console handler
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    # Create the formatter
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    fh.setFormatter(formatter)
    ch.setFormatter(formatter)
    # Attach the handlers
    logger.addHandler(fh)
    logger.addHandler(ch)
    return logger


logger = setup_logger()


def brep_to_mesh(
    brep_file: str,
    stl_output_file: str,
    obj_output_file: str
) -> Optional[trimesh.Trimesh]:
    """Convert a BREP/STEP file to a Trimesh object and save STL and OBJ files."""
    try:
        step_reader = STEPControl_Reader()
        status = step_reader.ReadFile(brep_file)
        if status != IFSelect_RetDone:
            logger.error(f"Failed to read BREP file: {brep_file}")
            return None
        step_reader.TransferRoots()
        shape = step_reader.OneShape()
        # Tessellate the B-rep shape with a 0.1 linear deflection
        mesher = BRepMesh_IncrementalMesh(shape, 0.1)
        mesher.Perform()
        base_name = os.path.splitext(os.path.basename(brep_file))[0]
        # Save the STL
        stl_writer = StlAPI_Writer()
        stl_writer.Write(shape, stl_output_file)
        # Save the OBJ
        mesh = trimesh.load(stl_output_file)
        mesh.export(obj_output_file)
        logger.debug(f"Successfully converted and saved mesh: {base_name}")
        return mesh
    except Exception as e:
        logger.error(f"Failed to convert BREP file {brep_file}: {str(e)}")
        return None


# Attach the x, y, z coordinates to each SDF value and split the samples into pos and neg groups
def sdf2xyzsdf(sdf, size=128):
    # Generate coordinate grids
    x, y, z = np.meshgrid(range(size), range(size), range(size), indexing='ij')
    # Stack coordinates along a new axis -> (size, size, size, 3)
    coords = np.stack([x, y, z], axis=-1)
    # Append the SDF value to each coordinate -> (size, size, size, 4)
    xyzsdf = np.concatenate([coords, sdf[..., np.newaxis]], axis=-1)
    return xyzsdf


def split_sdf(xyzsdf):
    # Separate non-negative and negative SDF samples
    pos_mask = xyzsdf[..., 3] >= 0
    neg_mask = xyzsdf[..., 3] < 0
    pos_xyzsdf = xyzsdf[pos_mask]
    neg_xyzsdf = xyzsdf[neg_mask]
    return pos_xyzsdf, neg_xyzsdf
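

# For reference: with size=128, sdf2xyzsdf returns a (128, 128, 128, 4) array whose last
# axis is (x_index, y_index, z_index, sdf_value); split_sdf flattens it into two (N, 4)
# point lists covering the non-negative and negative samples respectively.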


def mesh_to_sdf(
    mesh: trimesh.Trimesh,
    filename: str,
    sdf_output_file: str,
    mesh_scale: float = 0.8,
    size: int = 128,
    level: Optional[float] = None,
) -> bool:
    """Convert a mesh to an SDF grid and save the result."""
    try:
        if level is None:
            level = 2 / size
        # Normalize the mesh into [-mesh_scale, mesh_scale]
        vertices = mesh.vertices
        bbmin = vertices.min(0)
        bbmax = vertices.max(0)
        center = (bbmin + bbmax) * 0.5
        scale = 2.0 * mesh_scale / (bbmax - bbmin).max()
        vertices = (vertices - center) * scale
        # Compute the SDF
        sdf, mesh = mesh2sdf.compute(
            vertices,
            mesh.faces,
            size,
            fix=True,
            level=level,
            return_mesh=True
        )
        # Get the dimensions of the SDF grid
        size = sdf.shape[0]
        # Convert the SDF to XYZ+SDF format
        xyzsdf = sdf2xyzsdf(sdf, size)
        # Split into positive and negative SDF values
        pos_xyzsdf, neg_xyzsdf = split_sdf(xyzsdf)
        # Save the results as a .npz file
        np.savez_compressed(sdf_output_file, pos=pos_xyzsdf, neg=neg_xyzsdf)
        logger.debug(f"Successfully generated and saved SDF: {os.path.basename(sdf_output_file)}")
        return True
    except Exception as e:
        logger.error(f"Failed to process file {filename}: {str(e)}")
        return False
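

# Sketch only (not called by this pipeline): one way a downstream consumer could rebuild
# the dense SDF grid from the saved .npz. load_dense_sdf is a hypothetical helper and
# assumes the default grid size of 128 used above.
def load_dense_sdf(sdf_output_file: str, size: int = 128) -> np.ndarray:
    data = np.load(sdf_output_file)
    grid = np.zeros((size, size, size), dtype=np.float64)
    # pos and neg together cover every voxel, so the grid is fully overwritten
    for arr in (data['pos'], data['neg']):
        idx = arr[:, :3].astype(int)                        # voxel indices from the first three columns
        grid[idx[:, 0], idx[:, 1], idx[:, 2]] = arr[:, 3]   # SDF value from the fourth column
    return grid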


def check_conversion(
    step_file: str,
    obj_path: str,
    sdf_path: str
) -> bool:
    """Check the quality of the conversion results."""
    try:
        base_name = os.path.splitext(os.path.basename(step_file))[0]
        # Check that the output files exist
        if not (os.path.exists(obj_path) and os.path.exists(sdf_path)):
            logger.warning(f"Output files are incomplete: {base_name}")
            return False
        # Load and check the OBJ
        obj_mesh = trimesh.load(obj_path)
        if not obj_mesh.is_watertight:
            logger.warning(f"Mesh quality check failed (not watertight): {base_name}")
            return False
        # Check the SDF
        sdf_data = np.load(sdf_path)
        # Check that the pos and neg arrays exist
        if 'pos' not in sdf_data or 'neg' not in sdf_data:
            logger.warning(f"Unexpected SDF data format: {base_name}")
            return False
        # Check the data content
        if np.isnan(sdf_data['pos']).any() or np.isnan(sdf_data['neg']).any():
            logger.warning(f"SDF data contains NaN values: {base_name}")
            return False
        logger.debug(f"Quality check passed: {base_name}")
        return True
    except Exception as e:
        logger.error(f"Check failed for {step_file}: {str(e)}")
        return False


def process(step_file: str, set_name: str) -> bool:
    """Run the full pipeline for a single STEP file."""
    try:
        base_name = os.path.splitext(os.path.basename(step_file))[0]
        logger.info(f"Start processing: {base_name}")
        # Prepare the output paths
        stl_output_file = os.path.join('test_data/stl', set_name, f"{base_name}.stl")
        obj_output_file = os.path.join('test_data/obj', set_name, f"{base_name}.obj")
        sdf_output_file = os.path.join('test_data/sdf', set_name, f"{base_name}.npz")
        # Create a progress bar over the processing steps
        steps = ['STEP->Mesh', 'Mesh->SDF', 'Quality check']
        pbar = tqdm(steps, desc=f"Processing {base_name}", leave=False, position=1)
        # 1. STEP to mesh conversion
        pbar.set_description(f"Processing {base_name} [STEP->Mesh]")
        mesh = brep_to_mesh(step_file, stl_output_file, obj_output_file)
        if mesh is None:
            pbar.close()
            return False
        pbar.update(1)
        # 2. Mesh to SDF conversion
        pbar.set_description(f"Processing {base_name} [Mesh->SDF]")
        if not mesh_to_sdf(mesh, step_file, sdf_output_file, mesh_scale=0.8, size=128):
            pbar.close()
            return False
        pbar.update(1)
        # 3. Check the conversion results
        pbar.set_description(f"Processing {base_name} [Quality check]")
        if not check_conversion(step_file, obj_output_file, sdf_output_file):
            pbar.close()
            return False
        pbar.update(1)
        pbar.close()
        logger.info(f"Successfully processed: {base_name}")
        return True
    except Exception as e:
        logger.error(f"Processing failed for {step_file}: {str(e)}")
        return False
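
# process() can also be called standalone on a single file, e.g. (hypothetical path):
# process('/path/to/chair_0001.step', 'val')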


def main():
    """Main entry point: process all STEP files in parallel."""
    INPUT = '/mnt/disk2/dataset/furniture/step/furniture_dataset_step'
    if not os.path.exists(INPUT):
        logger.error(f"Input directory does not exist: {INPUT}")
        return
    # Create the base output directories
    os.makedirs('test_data/stl', exist_ok=True)
    os.makedirs('test_data/obj', exist_ok=True)
    os.makedirs('test_data/sdf', exist_ok=True)
    logger.info("Starting data processing...")
    # Process each split
    # for set_name in ['train', 'val', 'test']:
    for set_name in ['val', 'test']:
        input_dir = os.path.join(INPUT, set_name)
        if not os.path.exists(input_dir):
            logger.warning(f"Directory does not exist: {input_dir}")
            continue
        os.makedirs(f'test_data/stl/{set_name}', exist_ok=True)
        os.makedirs(f'test_data/obj/{set_name}', exist_ok=True)
        os.makedirs(f'test_data/sdf/{set_name}', exist_ok=True)
        step_files = glob.glob(os.path.join(input_dir, '*.step'))
        total_files = len(step_files)
        if total_files == 0:
            logger.error(f"No STEP files found in {input_dir}")
            continue
        logger.info(f"Found {total_files} STEP files in the {set_name} split")
        valid_conversions = 0
        # Create the overall progress bar for this split
        with tqdm(total=total_files, desc=f"{set_name} progress", position=0) as main_pbar:
            with ProcessPoolExecutor(max_workers=min(max(1, os.cpu_count() // 2), 8)) as executor:
                futures = {executor.submit(process, f, set_name): f for f in step_files}
                for future in as_completed(futures):
                    try:
                        if future.result(timeout=300):
                            valid_conversions += 1
                        main_pbar.update(1)
                        main_pbar.set_postfix({
                            'success_rate': f"{(valid_conversions / total_files) * 100:.1f}%"
                        })
                    except Exception as e:
                        logger.error(f"Processing failed for {futures[future]}: {str(e)}")
                        main_pbar.update(1)
        success_rate = (valid_conversions / total_files) * 100
        logger.info(f"Finished the {set_name} split, success rate: {success_rate:.2f}% = {valid_conversions}/{total_files}")


if __name__ == "__main__":
    main()