Browse Source

增加log,修改corner

main
mckay 7 months ago
parent
commit
a1da6a8084
  1. 267
      brep2sdf/data/data.py
  2. 40
      brep2sdf/data/utils.py
  3. 43
      brep2sdf/scripts/process_brep.py
  4. 3
      brep2sdf/scripts/read_pkl.py
  5. 86
      brep2sdf/utils/logger.py

267
brep2sdf/data/data.py

@ -3,10 +3,8 @@ import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
import numpy as np import numpy as np
import pickle import pickle
from brep2sdf.utils.logger import setup_logger from brep2sdf.utils.logger import logger
from .utils import process_brep_data from brep2sdf.data.utils import process_brep_data
# 设置日志记录器
logger = setup_logger('dataset')
@ -22,10 +20,16 @@ class BRepSDFDataset(Dataset):
sdf_dir: npz文件目录 sdf_dir: npz文件目录
split: 数据集分割('train', 'val', 'test') split: 数据集分割('train', 'val', 'test')
""" """
super().__init__()
self.brep_dir = os.path.join(brep_dir, split) self.brep_dir = os.path.join(brep_dir, split)
self.sdf_dir = os.path.join(sdf_dir, split) self.sdf_dir = os.path.join(sdf_dir, split)
self.split = split self.split = split
# 设置固定参数
self.max_face = 70
self.max_edge = 70
self.bbox_scaled = 1.0
# 检查目录是否存在 # 检查目录是否存在
if not os.path.exists(self.brep_dir): if not os.path.exists(self.brep_dir):
raise ValueError(f"B-rep directory not found: {self.brep_dir}") raise ValueError(f"B-rep directory not found: {self.brep_dir}")
@ -59,112 +63,187 @@ class BRepSDFDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
"""获取单个数据样本""" """获取单个数据样本"""
try:
brep_path = self.brep_data_list[idx] brep_path = self.brep_data_list[idx]
sdf_path = self.sdf_data_list[idx] sdf_path = self.sdf_data_list[idx]
try:
# 获取文件名(不含扩展名)作为sample name
name = os.path.splitext(os.path.basename(brep_path))[0] name = os.path.splitext(os.path.basename(brep_path))[0]
# 加载B-rep和SDF数据 # 加载B-rep和SDF数据
brep_data = self._load_brep_file(brep_path) with open(brep_path, 'rb') as f:
brep_raw = pickle.load(f)
sdf_data = self._load_sdf_file(sdf_path) sdf_data = self._load_sdf_file(sdf_path)
# 修改返回格式,将sdf_data作为一个键值对添加 try:
# 处理B-rep数据
brep_features = process_brep_data(
data=brep_raw,
max_face=self.max_face,
max_edge=self.max_edge,
bbox_scaled=self.bbox_scaled
)
# 检查返回值的类型和数量
if not isinstance(brep_features, tuple):
logger.error(f"process_brep_data returned {type(brep_features)}, expected tuple")
raise ValueError("Invalid return type from process_brep_data")
if len(brep_features) != 6:
logger.error(f"Expected 6 features, got {len(brep_features)}")
logger.error("Features returned:")
for i, feat in enumerate(brep_features):
if isinstance(feat, torch.Tensor):
logger.error(f" {i}: Tensor of shape {feat.shape}")
else:
logger.error(f" {i}: {type(feat)}")
raise ValueError(f"Incorrect number of features: {len(brep_features)}")
# 解包处理后的特征
edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = brep_features
# 构建返回字典
return { return {
'name': name, 'name': name,
**brep_data, # 解包B-rep特征 'edge_ncs': edge_ncs, # [max_face, max_edge, 10, 3]
'sdf': sdf_data # 添加SDF数据作为一个键 'edge_pos': edge_pos, # [max_face, max_edge, 6]
'edge_mask': edge_mask, # [max_face, max_edge]
'surf_ncs': surf_ncs, # [max_face, 100, 3]
'surf_pos': surf_pos, # [max_face, 6]
'vertex_pos': vertex_pos, # [max_face, max_edge, 6]
'sdf': sdf_data # [N, 4]
} }
except Exception as e:
logger.error(f"\nError processing B-rep data for file: {brep_path}")
logger.error(f"Error type: {type(e).__name__}")
logger.error(f"Error message: {str(e)}")
# 打印原始数据的结构
logger.error("\nRaw data structure:")
for key, value in brep_raw.items():
if isinstance(value, list):
logger.error(f" {key}: list of length {len(value)}")
if value:
logger.error(f" First element type: {type(value[0])}")
if hasattr(value[0], 'shape'):
logger.error(f" First element shape: {value[0].shape}")
elif hasattr(value, 'shape'):
logger.error(f" {key}: shape {value.shape}")
else:
logger.error(f" {key}: {type(value)}")
raise
except Exception as e: except Exception as e:
logger.error(f"Error loading sample from {brep_path}: {str(e)}") logger.error(f"Error loading sample from {brep_path}: {str(e)}")
logger.error("Data structure:") logger.error("Data structure:")
if 'brep_data' in locals():
for key, value in brep_data.items():
if isinstance(value, np.ndarray):
logger.error(f" {key}: type={type(value)}, dtype={value.dtype}, shape={value.shape}")
raise raise
def _load_brep_file(self, brep_path): def _load_brep_file(self, brep_path):
"""加载B-rep特征文件""" """加载B-rep特征文件"""
try: try:
# 1. 加载原始数据
with open(brep_path, 'rb') as f: with open(brep_path, 'rb') as f:
brep_data = pickle.load(f) raw_data = pickle.load(f)
brep_data = {} brep_data = {}
# 1. 处理几何数据(不等长序列) # 2. 处理几何数据(不等长序列)
for key in ['surf_wcs', 'surf_ncs', 'edge_wcs', 'edge_ncs']: geom_keys = ['surf_wcs', 'surf_ncs', 'edge_wcs', 'edge_ncs']
if key in brep_data: for key in geom_keys:
if key in raw_data:
try:
# 确保数据是列表
if not isinstance(raw_data[key], list):
raise ValueError(f"{key} is not a list")
# 转换每个元素为张量
tensors = []
for i, x in enumerate(raw_data[key]):
try: try:
brep_data[key] = [ # 先转换为numpy数组
torch.from_numpy(np.array(x, dtype=np.float32)) arr = np.array(x, dtype=np.float32)
for x in brep_data[key] # 再转换为张量
] tensor = torch.from_numpy(arr)
tensors.append(tensor)
except Exception as e: except Exception as e:
logger.error(f"Error converting {key}:") logger.error(f"Error converting {key}[{i}]:")
logger.error(f" Type: {type(brep_data[key])}") logger.error(f" Data type: {type(x)}")
if isinstance(brep_data[key], list): if isinstance(x, np.ndarray):
logger.error(f" List length: {len(brep_data[key])}") logger.error(f" Shape: {x.shape}")
if len(brep_data[key]) > 0: logger.error(f" dtype: {x.dtype}")
logger.error(f" First element type: {type(brep_data[key][0])}") raise ValueError(f"Failed to convert {key}[{i}]: {str(e)}")
logger.error(f" First element shape: {brep_data[key][0].shape}")
logger.error(f" First element dtype: {brep_data[key][0].dtype}")
raise ValueError(f"Failed to convert {key}: {str(e)}")
# 2. 处理固定形状的数据 brep_data[key] = tensors
for key in ['corner_wcs', 'corner_unique', 'surf_bbox_wcs', 'edge_bbox_wcs']:
if key in brep_data: except Exception as e:
logger.error(f"Error processing {key}:")
logger.error(f" Raw data type: {type(raw_data[key])}")
raise ValueError(f"Failed to process {key}: {str(e)}")
# 3. 处理固定形状的数据
fixed_keys = ['corner_wcs', 'corner_unique', 'surf_bbox_wcs', 'edge_bbox_wcs']
for key in fixed_keys:
if key in raw_data:
try: try:
data = np.array(brep_data[key], dtype=np.float32) # 直接从原始数据转换
brep_data[key] = torch.from_numpy(data) arr = np.array(raw_data[key], dtype=np.float32)
brep_data[key] = torch.from_numpy(arr)
except Exception as e: except Exception as e:
logger.error(f"Error converting {key}:") logger.error(f"Error converting fixed shape data {key}:")
logger.error(f" Type: {type(brep_data[key])}") logger.error(f" Raw data type: {type(raw_data[key])}")
if isinstance(brep_data[key], np.ndarray): if isinstance(raw_data[key], np.ndarray):
logger.error(f" Shape: {brep_data[key].shape}") logger.error(f" Shape: {raw_data[key].shape}")
logger.error(f" dtype: {brep_data[key].dtype}") logger.error(f" dtype: {raw_data[key].dtype}")
raise ValueError(f"Failed to convert {key}: {str(e)}") raise ValueError(f"Failed to convert {key}: {str(e)}")
# 3. 处理邻接矩阵 # 4. 处理邻接矩阵
for key in ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj']: adj_keys = ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj']
if key in brep_data: for key in adj_keys:
if key in raw_data:
try: try:
data = np.array(brep_data[key], dtype=np.int32) # 转换为整型数组
brep_data[key] = torch.from_numpy(data) arr = np.array(raw_data[key], dtype=np.int32)
brep_data[key] = torch.from_numpy(arr)
except Exception as e: except Exception as e:
logger.error(f"Error converting {key}:") logger.error(f"Error converting adjacency matrix {key}:")
logger.error(f" Type: {type(brep_data[key])}") logger.error(f" Raw data type: {type(raw_data[key])}")
if isinstance(brep_data[key], np.ndarray): if isinstance(raw_data[key], np.ndarray):
logger.error(f" Shape: {brep_data[key].shape}") logger.error(f" Shape: {raw_data[key].shape}")
logger.error(f" dtype: {brep_data[key].dtype}") logger.error(f" dtype: {raw_data[key].dtype}")
raise ValueError(f"Failed to convert {key}: {str(e)}") raise ValueError(f"Failed to convert {key}: {str(e)}")
feature_embedder = process_brep_data(brep_data, 70,70,1,)
return feature_embedder # 5. 验证必要的键是否存在
required_keys = {'surf_wcs', 'edge_wcs', 'corner_wcs'}
missing_keys = required_keys - set(brep_data.keys())
if missing_keys:
raise ValueError(f"Missing required keys: {missing_keys}")
# 6. 使用process_brep_data处理数据
try:
features = process_brep_data(
data=brep_data,
max_face=self.max_face,
max_edge=self.max_edge,
bbox_scaled=self.bbox_scaled
)
return features
except Exception as e: except Exception as e:
logger.error(f"\nError loading B-rep file: {brep_path}") logger.error("Error in process_brep_data:")
logger.error(f" Error message: {str(e)}") logger.error(f" Error message: {str(e)}")
# 打印数据形状信息
# 打印完整的数据结构信息 logger.error("\nInput data shapes:")
if 'brep_data' in locals():
logger.error("\nComplete data structure:")
for key, value in brep_data.items(): for key, value in brep_data.items():
logger.error(f"\n{key}:") if isinstance(value, list):
logger.error(f" Type: {type(value)}") shapes = [t.shape for t in value]
if isinstance(value, np.ndarray): logger.error(f" {key}: list of tensors with shapes {shapes}")
logger.error(f" Shape: {value.shape}") elif isinstance(value, torch.Tensor):
logger.error(f" dtype: {value.dtype}") logger.error(f" {key}: tensor of shape {value.shape}")
elif isinstance(value, list): raise
logger.error(f" List length: {len(value)}")
if len(value) > 0: except Exception as e:
logger.error(f" First element type: {type(value[0])}") logger.error(f"\nError loading B-rep file: {brep_path}")
if isinstance(value[0], np.ndarray): logger.error(f"Error message: {str(e)}")
logger.error(f" First element shape: {value[0].shape}")
logger.error(f" First element dtype: {value[0].dtype}")
raise raise
def _load_sdf_file(self, sdf_path): def _load_sdf_file(self, sdf_path):
@ -190,6 +269,52 @@ class BRepSDFDataset(Dataset):
logger.error(f"Error type: {type(e).__name__}") logger.error(f"Error type: {type(e).__name__}")
logger.error(f"Error message: {str(e)}") logger.error(f"Error message: {str(e)}")
raise raise
@staticmethod
def collate_fn(batch):
"""自定义批处理函数"""
# 收集所有样本的名称
names = [item['name'] for item in batch]
# 处理固定大小的张量数据
tensor_keys = ['edge_ncs', 'edge_pos', 'edge_mask',
'surf_ncs', 'surf_pos', 'vertex_pos']
tensors = {
key: torch.stack([item[key] for item in batch])
for key in tensor_keys
}
# 处理变长的SDF数据
sdf_data = [item['sdf'] for item in batch]
max_sdf_len = max(sdf.size(0) for sdf in sdf_data)
# 填充SDF数据
padded_sdfs = []
sdf_masks = []
for sdf in sdf_data:
pad_len = max_sdf_len - sdf.size(0)
if pad_len > 0:
padding = torch.zeros(pad_len, sdf.size(1),
dtype=sdf.dtype, device=sdf.device)
padded_sdf = torch.cat([sdf, padding], dim=0)
mask = torch.cat([
torch.ones(sdf.size(0), dtype=torch.bool),
torch.zeros(pad_len, dtype=torch.bool)
])
else:
padded_sdf = sdf
mask = torch.ones(sdf.size(0), dtype=torch.bool)
padded_sdfs.append(padded_sdf)
sdf_masks.append(mask)
# 合并所有数据
batch_data = {
'name': names,
'sdf': torch.stack(padded_sdfs),
'sdf_mask': torch.stack(sdf_masks),
**tensors
}
return batch_data
def test_dataset(): def test_dataset():
"""测试数据集功能""" """测试数据集功能"""

40
brep2sdf/data/utils.py

@ -8,6 +8,8 @@ import argparse
from chamferdist import ChamferDistance from chamferdist import ChamferDistance
from mpl_toolkits.mplot3d.art3d import Poly3DCollection from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from brep2sdf.utils.logger import logger
from OCC.Core.gp import gp_Pnt, gp_Pnt from OCC.Core.gp import gp_Pnt, gp_Pnt
from OCC.Core.TColgp import TColgp_Array2OfPnt from OCC.Core.TColgp import TColgp_Array2OfPnt
@ -986,7 +988,15 @@ def process_brep_data(
- data_class: (可选) 类别标签 [1] - data_class: (可选) 类别标签 [1]
""" """
# 解包数据 # 解包数据
_, _, surf_ncs, edge_ncs, corner_wcs, _, _, faceEdge_adj, surf_pos, edge_pos, _, _ = data.values() #_, _, surf_ncs, edge_ncs, corner_wcs, _, _, faceEdge_adj, surf_pos, edge_pos, _ = data.values()
# 直接获取需要的数据
surf_ncs = data['surf_ncs'] # (num_faces,) -> 每个元素形状 (N, 3)
edge_ncs = data['edge_ncs'] # (num_edges, 100, 3)
corner_wcs = data['corner_wcs'] # (num_corners, 3)
faceEdge_adj = data['faceEdge_adj'] # (num_faces, num_edges)
edgeCorner_adj = data['edgeCorner_adj'] # (num_edges, 2) 每条边连接2个顶点
surf_pos = data['surf_bbox_wcs'] # (num_faces, 6)
edge_pos = data['edge_bbox_wcs'] # (num_edges, 6)
# 数据增强 # 数据增强
random_num = np.random.rand() random_num = np.random.rand()
@ -1002,9 +1012,13 @@ def process_brep_data(
surfpos_corners = rotate_axis(surfpos_corners, angle, axis, normalized=True) surfpos_corners = rotate_axis(surfpos_corners, angle, axis, normalized=True)
edgepos_corners = rotate_axis(edgepos_corners, angle, axis, normalized=True) edgepos_corners = rotate_axis(edgepos_corners, angle, axis, normalized=True)
corner_wcs = rotate_axis(corner_wcs, angle, axis, normalized=True) corner_wcs = rotate_axis(corner_wcs, angle, axis, normalized=True)
surf_ncs = rotate_axis(surf_ncs, angle, axis, normalized=False)
# 对每个面的点云进行旋转
for i in range(len(surf_ncs)):
surf_ncs[i] = rotate_axis(surf_ncs[i], angle, axis, normalized=False)
edge_ncs = rotate_axis(edge_ncs, angle, axis, normalized=False) edge_ncs = rotate_axis(edge_ncs, angle, axis, normalized=False)
# 重新计算边界框 # 重新计算边界框
surf_pos = get_bbox(surfpos_corners) # [num_faces, 2, 3] surf_pos = get_bbox(surfpos_corners) # [num_faces, 2, 3]
surf_pos = surf_pos.reshape(len(surf_pos), 6) # [num_faces, 6] surf_pos = surf_pos.reshape(len(surf_pos), 6) # [num_faces, 6]
@ -1017,14 +1031,19 @@ def process_brep_data(
corner_wcs = corner_wcs * bbox_scaled # [num_edges, 2, 3] corner_wcs = corner_wcs * bbox_scaled # [num_edges, 2, 3]
# 特征复制 # 特征复制
edge_pos_duplicated: List[np.ndarray] = [] # 每个元素形状: [num_edges_per_face, 6] edge_pos_duplicated: List[np.ndarray] = [] # 每个元素形状: [num_edges, 6]
vertex_pos_duplicated: List[np.ndarray] = [] # 每个元素形状: [num_edges_per_face, 6] vertex_pos_duplicated: List[np.ndarray] = [] # 每个元素形状: [num_edges, 6]
edge_ncs_duplicated: List[np.ndarray] = [] # 每个元素形状: [num_edges_per_face, 10, 3] edge_ncs_duplicated: List[np.ndarray] = [] # 每个元素形状: [num_edges, 10, 3]
for adj in faceEdge_adj: # [num_faces, num_edges] for adj in faceEdge_adj: # [num_faces, num_edges]
edge_ncs_duplicated.append(edge_ncs[adj]) # [num_edges_per_face, 10, 3] edge_ncs_duplicated.append(edge_ncs[adj]) # [num_edges, 10, 3]
edge_pos_duplicated.append(edge_pos[adj]) # [num_edges_per_face, 6] edge_pos_duplicated.append(edge_pos[adj]) # [num_edges, 6]
corners = corner_wcs[adj] # [num_edges_per_face, 2, 3] #corners = corner_wcs[adj] # [num_vertax, 3] FIXME
edge_indices = np.where(adj)[0] # 获取当前面的边索引
corner_indices = edgeCorner_adj[edge_indices] # 获取这些边对应的顶点索引
corners = corner_wcs[corner_indices] # 获取顶点坐标
logger.debug(corners)
corners_sorted = [] corners_sorted = []
for corner in corners: # [2, 3] for corner in corners: # [2, 3]
sorted_indices = np.lexsort((corner[:, 2], corner[:, 1], corner[:, 0])) sorted_indices = np.lexsort((corner[:, 2], corner[:, 1], corner[:, 0]))
@ -1062,7 +1081,10 @@ def process_brep_data(
random_indices = np.random.permutation(surf_pos.shape[0]) random_indices = np.random.permutation(surf_pos.shape[0])
surf_pos = surf_pos[random_indices] # [num_faces, 6] surf_pos = surf_pos[random_indices] # [num_faces, 6]
edge_pos = edge_pos[random_indices] # [num_faces, max_edge, 6] edge_pos = edge_pos[random_indices] # [num_faces, max_edge, 6]
surf_ncs = surf_ncs[random_indices] # [num_faces, 100, 3] # 修改这里:surf_ncs是对象数组,需要特殊处理
surf_ncs_new = np.array([surf_ncs[i] for i in random_indices]) # 重新排列对象数组
surf_ncs = surf_ncs_new
#surf_ncs = surf_ncs[random_indices] # [num_faces, 100, 3]
edge_ncs = edge_ncs[random_indices] # [num_faces, max_edge, 10, 3] edge_ncs = edge_ncs[random_indices] # [num_faces, max_edge, 10, 3]
edge_mask = edge_mask[random_indices] # [num_faces, max_edge] edge_mask = edge_mask[random_indices] # [num_faces, max_edge]
vertex_pos = vertex_pos[random_indices] # [num_faces, max_edge, 6] vertex_pos = vertex_pos[random_indices] # [num_faces, max_edge, 6]

43
brep2sdf/scripts/process_brep.py

@ -184,32 +184,27 @@ def get_bbox(shape, subshape):
def parse_solid(step_path): def parse_solid(step_path):
"""Parse the surface, curve, face, edge, vertex in a CAD solid using OCC."""
""" """
返回值一个dict包括 解析STEP文件中的CAD模型数据
几何关系
# 面片数据
'surf_wcs': surfs_wcs, # list of np.array(N, 3), 每个面片的点云坐标
'surf_ncs': surfs_ncs, # list of np.array(N, 3), 每个面片的法向量
# 边数据 (num_samples=100)
'edge_wcs': edges_wcs, # list of np.array(100, 3), 每条边的采样点坐标
'edge_ncs': edges_ncs, # list of np.array(100, 3), 每条边的法向量
# 顶点数据
'corner_wcs': corner_wcs.astype(np.float32), # np.array(N, 3), 所有顶点坐标
'corner_unique': np.unique(corner_wcs, axis=0).astype(np.float32), # np.array(M, 3), 去重后的顶点坐标
拓扑关系
# 邻接矩阵,都是int32类型
'edgeFace_adj': edgeFace_adj, # np.array(num_edges, num_faces), 边-面邻接关系
'edgeCorner_adj': edgeCorner_adj, # np.array(num_edges, 2), 边-顶点邻接关系
'faceEdge_adj': faceEdge_adj, # np.array(num_faces, num_edges), 面-边邻接关系
包围盒数据
# 包围盒坐标,float32类型
'surf_bbox_wcs': surf_bbox_wcs, # np.array(num_faces, 6), 每个面的包围盒 [xmin,ymin,zmin,xmax,ymax,zmax]
'edge_bbox_wcs': edge_bbox_wcs, # np.array(num_edges, 6), 每条边的包围盒 [xmin,ymin,zmin,xmax,ymax,zmax]
返回:
dict: 包含以下键值对的字典:
# 几何数据
'surf_wcs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示面的点云坐标
'edge_wcs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(100, 3)的float32数组,表示边的采样点坐标
'surf_ncs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示归一化后的面点云
'edge_ncs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(100, 3)的float32数组,表示归一化后的边采样点
'corner_wcs': np.ndarray(dtype=float32) # 形状为(K, 3)的数组,表示所有顶点的坐标
'corner_unique': np.ndarray(dtype=float32) # 形状为(L, 3)的数组,表示去重后的顶点坐标
# 拓扑关系
'edgeFace_adj': np.ndarray(dtype=int32) # 形状为(num_edges, num_faces)的数组,表示边-面邻接关系
'edgeCorner_adj': np.ndarray(dtype=int32) # 形状为(num_edges, 2)的数组,表示边-顶点邻接关系
'faceEdge_adj': np.ndarray(dtype=int32) # 形状为(num_faces, num_edges)的数组,表示面-边邻接关系
# 包围盒数据
'surf_bbox_wcs': np.ndarray(dtype=float32) # 形状为(num_faces, 6)的数组,表示每个面的包围盒[xmin,ymin,zmin,xmax,ymax,zmax]
'edge_bbox_wcs': np.ndarray(dtype=float32) # 形状为(num_edges, 6)的数组,表示每条边的包围盒[xmin,ymin,zmin,xmax,ymax,zmax]
""" """
# Load STEP file # Load STEP file
reader = STEPControl_Reader() reader = STEPControl_Reader()

3
brep2sdf/scripts/read_pkl.py

@ -8,7 +8,6 @@ def inspect_data(pkl_file):
"""检查并显示pickle文件中的数据结构""" """检查并显示pickle文件中的数据结构"""
with open(pkl_file, 'rb') as f: with open(pkl_file, 'rb') as f:
data = pickle.load(f) data = pickle.load(f)
print("数据结构概览:") print("数据结构概览:")
print("=" * 50) print("=" * 50)
@ -36,5 +35,5 @@ def inspect_data(pkl_file):
print(f"值: {value}") print(f"值: {value}")
if __name__ == "__main__": if __name__ == "__main__":
pkl_file = "./bathtub_0011.pkl" # 替换为你的文件路径 pkl_file = "/home/wch/brep2sdf/test_data/pkl/train/bathtub_0004.pkl" # 替换为你的文件路径
inspect_data(pkl_file) inspect_data(pkl_file)

86
brep2sdf/utils/logger.py

@ -1,46 +1,82 @@
import os import os
import sys
import logging import logging
import traceback
from datetime import datetime from datetime import datetime
def setup_logger(name, log_dir='logs'): class BRepLogger:
""" _instance = None
设置日志记录器
参数: def __new__(cls):
name: 日志记录器名称 if cls._instance is None:
log_dir: 日志文件存储目录 cls._instance = super().__new__(cls)
cls._instance._initialize_logger()
return cls._instance
返回: def _initialize_logger(self):
logger: 配置好的日志记录器 """初始化日志记录器"""
"""
# 创建logs目录 # 创建logs目录
log_dir = 'logs'
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
# 创建日志记录器 # 创建logger
logger = logging.getLogger(name) self.logger = logging.getLogger('BRepLogger')
logger.setLevel(logging.INFO) self.logger.setLevel(logging.DEBUG)
# 如果logger已经有处理器,则不添加 # 如果logger已经有处理器,则返回
if logger.handlers: if self.logger.handlers:
return logger return
# 创建格式化器 # 创建格式化器
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') formatter = logging.Formatter(
'%(asctime)s - %(levelname)s - '
'%(filename)s:%(lineno)d in %(funcName)s - %(message)s'
)
# 创建文件处理器 # 创建文件处理器
current_time = datetime.now().strftime('%Y%m%d_%H%M%S') current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = os.path.join(log_dir, f'{name}_{current_time}.log') log_file = os.path.join(log_dir, f'brep2sdf_{current_time}.log')
file_handler = logging.FileHandler(log_file, encoding='utf-8') file_handler = logging.FileHandler(log_file, encoding='utf-8')
file_handler.setLevel(logging.INFO) file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
# 添加文件处理器到日志记录器 # 创建控制台处理器
logger.addHandler(file_handler) console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(formatter)
# 添加处理器
self.logger.addHandler(file_handler)
self.logger.addHandler(console_handler)
# 记录初始信息 # 记录初始信息
logger.info("="*50) self.logger.info("="*50)
logger.info(f"Logger initialized: {name}") self.logger.info("BRep Logger initialized")
logger.info(f"Log file: {log_file}") self.logger.info(f"Log file: {log_file}")
logger.info("="*50) self.logger.info("="*50)
def debug(self, msg):
"""带调用位置的调试信息"""
caller = traceback.extract_stack()[-2]
filename = os.path.basename(caller.filename)
self.logger.debug(f"{msg} (in {filename})")
def info(self, msg):
self.logger.info(msg)
def warning(self, msg):
self.logger.warning(msg)
def error(self, msg, include_trace=True):
"""错误信息,可选是否包含调用栈"""
if include_trace:
self.logger.error(msg, exc_info=True, stack_info=True)
else:
self.logger.error(msg)
def exception(self, msg):
"""异常信息,总是包含异常堆栈"""
self.logger.exception(msg)
return logger # 创建全局logger实例
logger = BRepLogger()
Loading…
Cancel
Save