Browse Source

修改预处理

final
mckay 2 months ago
parent
commit
32d4e45196
  1. 8
      brep2sdf/config/default_config.py
  2. 70
      brep2sdf/data/pre_process.py

8
brep2sdf/config/default_config.py

@ -27,8 +27,8 @@ class ModelConfig:
@dataclass @dataclass
class DataConfig: class DataConfig:
"""数据相关配置""" """数据相关配置"""
max_face: int = 16 max_face: int = 80
max_edge: int = 64 max_edge: int = 16
num_query_points: int = 32*32*32 # 限制查询点数量,sdf 采样点数 本来是 128*128*128 ,在data load时随机采样 num_query_points: int = 32*32*32 # 限制查询点数量,sdf 采样点数 本来是 128*128*128 ,在data load时随机采样
bbox_scaled: float = 1.0 bbox_scaled: float = 1.0
@ -47,7 +47,7 @@ class TrainConfig:
# 基本训练参数 # 基本训练参数
batch_size: int = 8 batch_size: int = 8
num_workers: int = 4 num_workers: int = 4
num_epochs: int = 500 num_epochs: int = 20
learning_rate: float = 0.01 learning_rate: float = 0.01
min_lr: float = 1e-5 min_lr: float = 1e-5
weight_decay: float = 0.01 weight_decay: float = 0.01
@ -62,7 +62,7 @@ class TrainConfig:
warmup_epochs: int = 5 warmup_epochs: int = 5
# 保存和验证 # 保存和验证
save_freq: int = 10 # 每多少个epoch保存一次 save_freq: int = 20 # 每多少个epoch保存一次
val_freq: int = 1 # 每多少个epoch验证一次 val_freq: int = 1 # 每多少个epoch验证一次
# 保存路径 # 保存路径

70
brep2sdf/data/pre_process.py

@ -93,49 +93,25 @@ def normalize(surfs, edges, corners):
scale scale
) )
def get_adjacency_info(shape): def get_adjacency_info(shape, faces, edges, vertices):
""" """
获取CAD模型中面顶点之间的邻接关系 优化后的邻接关系计算函数直接使用已收集的几何元素
参数: 参数新增:
shape: CAD模型的形状对象 faces: 已收集的面列表
edges: 已收集的边列表
返回: vertices: 已收集的顶点列表
edgeFace_adj: -面邻接矩阵 (num_edges × num_faces)
faceEdge_adj: -边邻接矩阵 (num_faces × num_edges)
edgeCorner_adj: -顶点邻接矩阵 (num_edges × 2)
""" """
logger.debug("Get adjacency infos...")
# 创建边-面映射关系 # 创建边-面映射关系
edge_face_map = TopTools_IndexedDataMapOfShapeListOfShape() edge_face_map = TopTools_IndexedDataMapOfShapeListOfShape()
topexp.MapShapesAndAncestors(shape, TopAbs_EDGE, TopAbs_FACE, edge_face_map) topexp.MapShapesAndAncestors(shape, TopAbs_EDGE, TopAbs_FACE, edge_face_map)
# 获取所有几何元素 # 直接使用传入的几何元素列表
faces = [] # 存储所有面
edges = [] # 存储所有边
vertices = [] # 存储所有顶点
# 创建拓扑结构探索器
face_explorer = TopExp_Explorer(shape, TopAbs_FACE)
edge_explorer = TopExp_Explorer(shape, TopAbs_EDGE)
vertex_explorer = TopExp_Explorer(shape, TopAbs_VERTEX)
# 收集所有几何元素
while face_explorer.More():
faces.append(topods.Face(face_explorer.Current()))
face_explorer.Next()
while edge_explorer.More():
edges.append(topods.Edge(edge_explorer.Current()))
edge_explorer.Next()
while vertex_explorer.More():
vertices.append(topods.Vertex(vertex_explorer.Current()))
vertex_explorer.Next()
# 创建邻接矩阵
num_faces = len(faces) num_faces = len(faces)
num_edges = len(edges) num_edges = len(edges)
num_vertices = len(vertices) num_vertices = len(vertices)
logger.debug(f"num_faces: {num_faces}, num_edges: {num_edges}, num_vertices: {num_vertices}")
edgeFace_adj = np.zeros((num_edges, num_faces), dtype=np.int32) edgeFace_adj = np.zeros((num_edges, num_faces), dtype=np.int32)
faceEdge_adj = np.zeros((num_faces, num_edges), dtype=np.int32) faceEdge_adj = np.zeros((num_faces, num_edges), dtype=np.int32)
@ -244,9 +220,13 @@ def parse_solid(step_path):
surf_bbox_wcs = [] surf_bbox_wcs = []
edge_bbox_wcs = [] edge_bbox_wcs = []
faces, edges, vertices = [], [], []
# Extract face points # Extract face points
logger.debug("Extract face points...")
while face_explorer.More(): while face_explorer.More():
face = topods.Face(face_explorer.Current()) face = topods.Face(face_explorer.Current())
faces.append(face)
loc = TopLoc_Location() loc = TopLoc_Location()
triangulation = BRep_Tool.Triangulation(face, loc) triangulation = BRep_Tool.Triangulation(face, loc)
@ -267,11 +247,18 @@ def parse_solid(step_path):
surf_bbox_wcs.append(get_bbox(shape, face)) surf_bbox_wcs.append(get_bbox(shape, face))
face_explorer.Next() face_explorer.Next()
face_count = len(faces)
if face_count > MAX_FACE:
logger.error(f"step has {face_count} faces, which exceeds MAX_FACE {MAX_FACE}")
return None
# Extract edge points # Extract edge points
logger.debug("Extract edge points...")
num_samples = config.model.num_edge_points # 使用配置中的边采样点数 num_samples = config.model.num_edge_points # 使用配置中的边采样点数
while edge_explorer.More(): while edge_explorer.More():
edge = topods.Edge(edge_explorer.Current()) edge = topods.Edge(edge_explorer.Current())
edges.append(edge)
logger.debug(len(edges))
curve_info = BRep_Tool.Curve(edge) curve_info = BRep_Tool.Curve(edge)
if curve_info is None: if curve_info is None:
continue # 跳过无效边 continue # 跳过无效边
@ -280,15 +267,12 @@ def parse_solid(step_path):
if len(curve_info) == 3: if len(curve_info) == 3:
curve, first, last = curve_info curve, first, last = curve_info
elif len(curve_info) == 2: elif len(curve_info) == 2:
continue curve = None # 跳过判断
curve, location = curve_info
logger.info(curve)
first, last = BRep_Tool.Range(edge) # 显式获取参数范围
else: else:
raise ValueError(f"Unexpected curve info: {curve_info}") raise ValueError(f"Unexpected curve info: {curve_info}")
except Exception as e: except Exception as e:
logger.error(f"Failed to process edge {edge}: {str(e)}") logger.error(f"Failed to process edge {edge}: {str(e)}")
continue curve = None
if curve is not None: if curve is not None:
points = [] points = []
@ -306,14 +290,22 @@ def parse_solid(step_path):
edge_explorer.Next() edge_explorer.Next()
# Extract vertex points # Extract vertex points
logger.debug("Extract vertex points...")
while vertex_explorer.More(): while vertex_explorer.More():
vertex = topods.Vertex(vertex_explorer.Current()) vertex = topods.Vertex(vertex_explorer.Current())
vertices.append(vertex)
pnt = BRep_Tool.Pnt(vertex) pnt = BRep_Tool.Pnt(vertex)
corner_pnts.append([pnt.X(), pnt.Y(), pnt.Z()]) corner_pnts.append([pnt.X(), pnt.Y(), pnt.Z()])
vertex_explorer.Next() vertex_explorer.Next()
# 获取邻接信息 # 获取邻接信息
edgeFace_adj, faceEdge_adj, edgeCorner_adj = get_adjacency_info(shape) edgeFace_adj, faceEdge_adj, edgeCorner_adj = get_adjacency_info(
shape,
faces=faces, # 传入已收集的面列表
edges=edges, # 传入已收集的边列表
vertices=vertices # 传入已收集的顶点列表
)
logger.debug("complete.")
# 转换为numpy数组时确保类型正确 # 转换为numpy数组时确保类型正确
face_pnts = [np.array(points, dtype=np.float32) for points in face_pnts] face_pnts = [np.array(points, dtype=np.float32) for points in face_pnts]

Loading…
Cancel
Save