11 changed files with 299 additions and 102 deletions
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,118 +1,307 @@ |
|||
import os |
|||
import time |
|||
import glob |
|||
import trimesh |
|||
import numpy as np |
|||
from scipy.spatial import KDTree |
|||
from tqdm import tqdm |
|||
import mesh2sdf |
|||
import skimage.measure |
|||
import matplotlib.pyplot as plt |
|||
from concurrent.futures import ProcessPoolExecutor, as_completed |
|||
import logging |
|||
from typing import Tuple, Optional |
|||
|
|||
# 加载BREP模型并转换为三角网格 |
|||
def load_brep_to_mesh(file_path: str) -> trimesh.Trimesh: |
|||
mesh = trimesh.load_mesh(file_path) |
|||
return mesh |
|||
from OCC.Core.STEPControl import STEPControl_Reader |
|||
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh |
|||
from OCC.Core.StlAPI import StlAPI_Writer |
|||
from OCC.Core.IFSelect import IFSelect_RetDone |
|||
|
|||
def process_mesh_to_sdf( |
|||
# 配置日志记录 |
|||
def setup_logger(): |
|||
logger = logging.getLogger('furniture_processing') |
|||
logger.setLevel(logging.INFO) |
|||
|
|||
# 文件处理器 |
|||
fh = logging.FileHandler('furniture_processing.log') |
|||
fh.setLevel(logging.INFO) |
|||
|
|||
# 控制台处理器 |
|||
ch = logging.StreamHandler() |
|||
ch.setLevel(logging.INFO) |
|||
|
|||
# 创建格式器 |
|||
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') |
|||
fh.setFormatter(formatter) |
|||
ch.setFormatter(formatter) |
|||
|
|||
# 添加处理器 |
|||
logger.addHandler(fh) |
|||
logger.addHandler(ch) |
|||
|
|||
return logger |
|||
|
|||
logger = setup_logger() |
|||
|
|||
def brep_to_mesh( |
|||
brep_file: str, |
|||
stl_output_file: str, |
|||
obj_output_file: str |
|||
) -> Optional[trimesh.Trimesh]: |
|||
"""将BREP/STEP文件转换为Trimesh对象并保存STL和OBJ文件""" |
|||
try: |
|||
step_reader = STEPControl_Reader() |
|||
status = step_reader.ReadFile(brep_file) |
|||
|
|||
if status != IFSelect_RetDone: |
|||
logger.error(f"无法读取BREP文件: {brep_file}") |
|||
return None |
|||
|
|||
step_reader.TransferRoots() |
|||
shape = step_reader.OneShape() |
|||
|
|||
mesh = BRepMesh_IncrementalMesh(shape, 0.1) |
|||
mesh.Perform() |
|||
|
|||
base_name = os.path.splitext(os.path.basename(brep_file))[0] |
|||
|
|||
# 保存STL |
|||
stl_writer = StlAPI_Writer() |
|||
stl_writer.Write(shape, stl_output_file) |
|||
|
|||
# 保存OBJ |
|||
mesh = trimesh.load(stl_output_file) |
|||
mesh.export(obj_output_file) |
|||
|
|||
logger.debug(f"成功转换并保存mesh: {base_name}") |
|||
return mesh |
|||
|
|||
except Exception as e: |
|||
logger.error(f"转换BREP文件失败 {brep_file}: {str(e)}") |
|||
return None |
|||
|
|||
# 现在加上x,y,z坐标值,并分成pos和neg两个组 |
|||
def sdf2xyzsdf(sdf, size=128): |
|||
# Generate coordinate grids |
|||
x, y, z = np.meshgrid(range(size), range(size), range(size), indexing='ij') |
|||
|
|||
# Stack coordinates along a new axis |
|||
coords = np.stack([x, y, z], axis=-1) |
|||
|
|||
# Reshape to match the SDF shape |
|||
coords = coords.reshape(size, size, size, 3) |
|||
|
|||
# Stack coordinates and SDF values |
|||
xyzsdf = np.concatenate([coords, sdf[..., np.newaxis]], axis=-1) |
|||
|
|||
return xyzsdf |
|||
|
|||
def split_sdf(xyzsdf): |
|||
# Separate positive and negative SDF values |
|||
pos_mask = xyzsdf[..., 3] >= 0 |
|||
neg_mask = xyzsdf[..., 3] < 0 |
|||
|
|||
pos_xyzsdf = xyzsdf[pos_mask] |
|||
neg_xyzsdf = xyzsdf[neg_mask] |
|||
|
|||
return pos_xyzsdf, neg_xyzsdf |
|||
|
|||
|
|||
def mesh_to_sdf( |
|||
mesh: trimesh.Trimesh, |
|||
filename:str, |
|||
save_dir:str, |
|||
idx: str, |
|||
mesh_scale:int = 0.8, |
|||
size:int= 128, |
|||
level: int = None, |
|||
) -> str: |
|||
if not level: |
|||
level = 2 / size |
|||
|
|||
# mesh = trimesh.load(filename, force='mesh') |
|||
|
|||
# normalize mesh |
|||
vertices = mesh.vertices |
|||
bbmin = vertices.min(0) |
|||
bbmax = vertices.max(0) |
|||
center = (bbmin + bbmax) * 0.5 |
|||
scale = 2.0 * mesh_scale / (bbmax - bbmin).max() |
|||
vertices = (vertices - center) * scale |
|||
|
|||
# fix mesh |
|||
t0 = time.time() |
|||
sdf, mesh = mesh2sdf.compute( |
|||
vertices, mesh.faces, size, fix=True, level=level, return_mesh=True) |
|||
t1 = time.time() |
|||
|
|||
# sdf to x,y,z,sdf |
|||
xyzsdf = np.zeros((size, size, size, 4)) |
|||
for x in range(size): |
|||
for y in range(size): |
|||
for z in range(size): |
|||
xyzsdf[x, y, z] = [x, y, z, sdf[x, y, z]] |
|||
# output |
|||
mesh.vertices = mesh.vertices / scale + center |
|||
name = filename.split('/')[-1][:-4] |
|||
#mesh.export(os.path.join(save_dir, filename[:-4] + '.fixed.obj')) |
|||
np.save(os.path.join(save_dir, filename[:-4] + '.xyzsdf.npy'), xyzsdf) |
|||
#mesh.export(filename[:-4] + '.fixed.obj') |
|||
#np.save(filename[:-4] + '.npy', sdf) |
|||
#print('It takes %.4f seconds to process %s' % (t1-t0, filename)) |
|||
return idx |
|||
|
|||
|
|||
''' |
|||
# 定义网格 |
|||
def create_grid(x_range, y_range, z_range, resolution): |
|||
x = np.linspace(x_range[0], x_range[1], resolution) |
|||
y = np.linspace(y_range[0], y_range[1], resolution) |
|||
z = np.linspace(z_range[0], z_range[1], resolution) |
|||
grid = np.meshgrid(x, y, z) |
|||
points = np.vstack(list(map(np.ravel, grid))).T |
|||
return points |
|||
''' |
|||
|
|||
|
|||
|
|||
def test(): |
|||
# 参数设置 |
|||
input_file = "/mnt/disk2/dataset/furniture/step/furniture_dataset_step/train/bathtub_0004.step" |
|||
output_file = "tmp_data" |
|||
|
|||
# 加载BREP模型并转换为三角网格 |
|||
mesh = load_brep_to_mesh(input_file) |
|||
|
|||
process_mesh_to_sdf( |
|||
mesh=mesh, |
|||
filename=input_file, |
|||
save_dir=output_file, |
|||
idx="0000000", |
|||
) |
|||
|
|||
|
|||
|
|||
|
|||
|
|||
|
|||
# 真的执行转换 |
|||
filename: str, |
|||
sdf_output_file: str, |
|||
mesh_scale: float = 0.8, |
|||
size: int = 128, |
|||
level: float = None, |
|||
) -> bool: |
|||
""" |
|||
将mesh转换为SDF并保存结果 |
|||
""" |
|||
try: |
|||
if not level: |
|||
level = 2 / size |
|||
|
|||
# 标准化网格 |
|||
vertices = mesh.vertices |
|||
bbmin = vertices.min(0) |
|||
bbmax = vertices.max(0) |
|||
center = (bbmin + bbmax) * 0.5 |
|||
scale = 2.0 * mesh_scale / (bbmax - bbmin).max() |
|||
vertices = (vertices - center) * scale |
|||
|
|||
# 计算SDF |
|||
sdf, mesh = mesh2sdf.compute( |
|||
vertices, |
|||
mesh.faces, |
|||
size, |
|||
fix=True, |
|||
level=level, |
|||
return_mesh=True |
|||
) |
|||
|
|||
# Get the dimensions of the SDF grid |
|||
size = sdf.shape[0] |
|||
|
|||
# Convert SDF to XYZ+SDF format |
|||
xyzsdf = sdf2xyzsdf(sdf, size) |
|||
|
|||
# Split into positive and negative SDF values |
|||
pos_xyzsdf, neg_xyzsdf = split_sdf(xyzsdf) |
|||
|
|||
# Save the results as a .npz file |
|||
np.savez_compressed(sdf_output_file, pos=pos_xyzsdf, neg=neg_xyzsdf) |
|||
|
|||
logger.debug(f"成功生成并保存SDF: {os.path.basename(sdf_output_file)}") |
|||
return True |
|||
|
|||
except Exception as e: |
|||
logger.error(f"处理文件失败 {filename}: {str(e)}") |
|||
return False |
|||
|
|||
def check_conversion( |
|||
step_file: str, |
|||
obj_path: str, |
|||
sdf_path: str |
|||
) -> bool: |
|||
"""检查转换结果的质量""" |
|||
try: |
|||
base_name = os.path.splitext(os.path.basename(step_file))[0] |
|||
|
|||
# 检查文件是否存在 |
|||
if not (os.path.exists(obj_path) and os.path.exists(sdf_path)): |
|||
logger.warning(f"文件不完整: {base_name}") |
|||
return False |
|||
|
|||
# 加载并检查OBJ |
|||
obj_mesh = trimesh.load(obj_path) |
|||
if not (obj_mesh.is_watertight ): |
|||
logger.warning(f"网格质量不合格: {base_name}") |
|||
return False |
|||
|
|||
# 检查SDF |
|||
sdf_data = np.load(sdf_path) |
|||
# 检查pos和neg数组是否存在 |
|||
if 'pos' not in sdf_data or 'neg' not in sdf_data: |
|||
logger.warning(f"SDF数据格式不正确: {base_name}") |
|||
return False |
|||
# 检查数据维度和内容 |
|||
if np.isnan(sdf_data['pos']).any() or np.isnan(sdf_data['neg']).any(): |
|||
logger.warning(f"SDF数据包含NaN值: {base_name}") |
|||
return False |
|||
|
|||
logger.debug(f"质量检查通过: {base_name}") |
|||
return True |
|||
|
|||
except Exception as e: |
|||
logger.error(f"检查失败 {step_file}: {str(e)}") |
|||
return False |
|||
|
|||
def process(step_file: str, set_name:str) -> bool: |
|||
"""处理单个STEP文件的完整流程""" |
|||
try: |
|||
base_name = os.path.splitext(os.path.basename(step_file))[0] |
|||
logger.info(f"开始处理: {base_name}") |
|||
|
|||
# 准备输出路径 |
|||
stl_output_file = os.path.join('test_data/stl', f"{set_name}/{base_name}.stl") |
|||
obj_output_file = os.path.join('test_data/obj', f"{set_name}/{base_name}.obj") |
|||
sdf_output_file = os.path.join('test_data/sdf', f"{set_name}/{base_name}.npz") |
|||
|
|||
|
|||
# 使用tqdm创建处理步骤的进度条 |
|||
steps = ['STEP->Mesh', 'Mesh->SDF', '质量检查'] |
|||
pbar = tqdm(steps, desc=f"处理 {base_name}", leave=False, position=1) |
|||
|
|||
# 1. STEP到Mesh的转换 |
|||
pbar.set_description(f"处理 {base_name} [STEP->Mesh]") |
|||
mesh = brep_to_mesh(step_file, stl_output_file, obj_output_file) |
|||
if mesh is None: |
|||
pbar.close() |
|||
return False |
|||
pbar.update(1) |
|||
|
|||
# 2. Mesh到SDF的转换 |
|||
pbar.set_description(f"处理 {base_name} [Mesh->SDF]") |
|||
if not mesh_to_sdf(mesh, step_file, sdf_output_file, mesh_scale=0.8, size=128): |
|||
pbar.close() |
|||
return False |
|||
pbar.update(1) |
|||
|
|||
# 3. 检查转换结果 |
|||
pbar.set_description(f"处理 {base_name} [质量检查]") |
|||
if not check_conversion(step_file, obj_output_file, sdf_output_file): |
|||
pbar.close() |
|||
return False |
|||
pbar.update(1) |
|||
|
|||
pbar.close() |
|||
logger.info(f"成功处理: {base_name}") |
|||
return True |
|||
|
|||
except Exception as e: |
|||
logger.error(f"处理失败 {step_file}: {str(e)}") |
|||
return False |
|||
|
|||
def main(): |
|||
"""主函数:并行处理所有STEP文件""" |
|||
INPUT = '/mnt/disk2/dataset/furniture/step/furniture_dataset_step' |
|||
if not os.path.exists(INPUT): |
|||
logger.error(f"输入目录不存在: {INPUT}") |
|||
return |
|||
|
|||
# 创建基础输出目录 |
|||
os.makedirs('test_data/stl', exist_ok=True) |
|||
os.makedirs('test_data/obj', exist_ok=True) |
|||
os.makedirs('test_data/sdf', exist_ok=True) |
|||
|
|||
|
|||
|
|||
logger.info("开始数据处理...") |
|||
|
|||
# 创建总体进度条 |
|||
#for set_name in ['train', 'val', 'test']: |
|||
for set_name in ['val', 'test']: |
|||
|
|||
input_dir = os.path.join(INPUT, set_name) |
|||
if not os.path.exists(input_dir): |
|||
logger.warning(f"目录不存在: {input_dir}") |
|||
continue |
|||
|
|||
os.makedirs(f'test_data/stl/{set_name}', exist_ok=True) |
|||
os.makedirs(f'test_data/obj/{set_name}', exist_ok=True) |
|||
os.makedirs(f'test_data/sdf/{set_name}', exist_ok=True) |
|||
|
|||
step_files = glob.glob(os.path.join(input_dir, '*.step')) |
|||
total_files = len(step_files) |
|||
if total_files == 0: |
|||
logger.error("没有找到任何STEP文件") |
|||
logger.info(f"找到 {set_name} 集合文件: {total_files}个") |
|||
|
|||
INPUT = '/mnt/disk2/dataset/furniture/step/furniture_dataset_step' # 下面train,val和test |
|||
OUTPUT = '/app/data/furniture_sdf' |
|||
valid = 0 |
|||
for set_ in ['train', 'val', 'test']: |
|||
with ProcessPoolExecutor(max_workers=os.cpu_count() // 2) as executor: |
|||
futures = {} |
|||
for step_folder in glob.glob(os.path.join(input_folder, set_, '*.step')): |
|||
future = executor.submit(process, step_folder, timeout=300) |
|||
futures[future] = step_folder |
|||
|
|||
for future in tqdm(as_completed(futures), total=len(step_dirs)): |
|||
try: |
|||
status = future.result(timeout=300) |
|||
valid += status |
|||
except TimeoutError: |
|||
print(f"Timeout occurred while processing {futures[future]}") |
|||
except Exception as e: |
|||
print(f"An error occurred while processing {futures[future]}: {e}") |
|||
valid_conversions = 0 |
|||
# 创建进度条 |
|||
with tqdm(total=total_files, desc=f"{set_name}进度", position=0) as main_pbar: |
|||
with ProcessPoolExecutor(max_workers=min(max(1, os.cpu_count() // 2), 8)) as executor: |
|||
futures = {executor.submit(process, f, set_name): f for f in step_files} |
|||
|
|||
for future in as_completed(futures): |
|||
try: |
|||
if future.result(timeout=300): |
|||
valid_conversions += 1 |
|||
main_pbar.update(1) |
|||
main_pbar.set_postfix({ |
|||
'success_rate': f"{(valid_conversions/total_files)*100:.1f}%" |
|||
}) |
|||
except Exception as e: |
|||
logger.error(f"处理失败 {futures[future]}: {str(e)}") |
|||
main_pbar.update(1) |
|||
success_rate = (valid_conversions / total_files) * 100 # 这个变量在日志中被使用但未定义 |
|||
logger.info(f"处理完成: {set_name} 集合, 成功率: {success_rate:.2f}% = {valid_conversions}/{total_files}个") |
|||
|
|||
print(f'Done... Data Converted Ratio {100.0*valid/len(step_dirs)}%') |
|||
|
|||
|
|||
|
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
test() |
|||
main() |
|||
|
Loading…
Reference in new issue