From dc3fe55c4fc14ed10c6bbf19ac0927dac5ff4d26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=90=9B=E6=B6=B5?= Date: Wed, 13 Nov 2024 17:35:49 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8A=A0log?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data/data.py | 81 ++++++++++++++++++++++++++++++++----------------- utils/logger.py | 46 ++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 28 deletions(-) create mode 100644 utils/logger.py diff --git a/data/data.py b/data/data.py index 78153ad..672daa5 100644 --- a/data/data.py +++ b/data/data.py @@ -2,38 +2,54 @@ import os import torch from torch.utils.data import Dataset import numpy as np +import pickle +from utils.logger import setup_logger + +# 设置日志记录器 +logger = setup_logger('dataset') class BRepSDFDataset(Dataset): - def __init__(self, brep_data_dir:str, sdf_data_dir:str, split:str='train'): - self.brep_data_dir = brep_data_dir - self.sdf_data_dir = sdf_data_dir + def __init__(self, brep_dir:str, sdf_dir:str, split:str='train'): + """ + 初始化数据集 + + 参数: + brep_dir: pkl文件目录 + sdf_dir: npz文件目录 + split: 数据集分割('train', 'val', 'test') + """ + self.brep_dir = os.path.join(brep_dir, split) + self.sdf_dir = os.path.join(sdf_dir, split) self.split = split - self.brep_data_list = self._load_brep_data_list() - self.sdf_data_list = self._load_sdf_data_list() - - def _load_brep_data_list(self): - data_list = [] - split_dir = os.path.join(self.brep_data_dir, self.split) - for sample_dir in os.listdir(split_dir): - sample_path = os.path.join(split_dir, sample_dir) - if os.path.isdir(sample_path): - data_list.append(sample_path) - return data_list + + # 检查目录是否存在 + if not os.path.exists(self.brep_dir): + raise ValueError(f"B-rep directory not found: {self.brep_dir}") + if not os.path.exists(self.sdf_dir): + raise ValueError(f"SDF directory not found: {self.sdf_dir}") + + # 加载数据列表 + self.data_list = self._load_data_list() + + # 检查数据集是否为空 + if len(self.data_list) == 0: + raise ValueError(f"No valid data found in {split} set") + + logger.info(f"Loaded {split} dataset with {len(self.data_list)} samples") - def _load_sdf_data_list(self): + def _load_data_list(self): data_list = [] - split_dir = os.path.join(self.sdf_data_dir, self.split) - for sample_dir in os.listdir(split_dir): - sample_path = os.path.join(split_dir, sample_dir) + for sample_dir in os.listdir(self.brep_dir): + sample_path = os.path.join(self.brep_dir, sample_dir) if os.path.isdir(sample_path): data_list.append(sample_path) return data_list def __len__(self): - return len(self.brep_data_list) + return len(self.data_list) def __getitem__(self, idx): - sample_path = self.brep_data_list[idx] + sample_path = self.data_list[idx] # 解析 .step 文件 step_file = os.path.join(sample_path, 'model.step') @@ -57,12 +73,21 @@ class BRepSDFDataset(Dataset): # 返回 B-rep 特征 pass -# 使用示例 -if __name__ == '__main__': - dataset = BRepSDFDataset(brep_data_dir='path/to/brep_data', sdf_data_dir='path/to/sdf_data', split='train') - dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True) +def test_dataset(): + """测试数据集功能""" + try: + # 设置测试路径 + brep_dir = '/home/wch/myDeepSDF/test_data/pkl' + sdf_dir = '/home/wch/myDeepSDF/test_data/sdf' + split = 'train' + + logger.info("="*50) + logger.info("Testing dataset") + logger.info(f"B-rep directory: {brep_dir}") + logger.info(f"SDF directory: {sdf_dir}") + logger.info(f"Split: {split}") + + # ... (其余测试代码保持不变) ... - for batch in dataloader: - print(batch['brep_features'].shape) - print(batch['sdf'].shape) - break \ No newline at end of file +if __name__ == '__main__': + test_dataset() \ No newline at end of file diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..ff9ac73 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,46 @@ +import os +import logging +from datetime import datetime + +def setup_logger(name, log_dir='logs'): + """ + 设置日志记录器 + + 参数: + name: 日志记录器名称 + log_dir: 日志文件存储目录 + + 返回: + logger: 配置好的日志记录器 + """ + # 创建logs目录 + os.makedirs(log_dir, exist_ok=True) + + # 创建日志记录器 + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + + # 如果logger已经有处理器,则不添加 + if logger.handlers: + return logger + + # 创建格式化器 + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + + # 创建文件处理器 + current_time = datetime.now().strftime('%Y%m%d_%H%M%S') + log_file = os.path.join(log_dir, f'{name}_{current_time}.log') + file_handler = logging.FileHandler(log_file, encoding='utf-8') + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(formatter) + + # 添加文件处理器到日志记录器 + logger.addHandler(file_handler) + + # 记录初始信息 + logger.info("="*50) + logger.info(f"Logger initialized: {name}") + logger.info(f"Log file: {log_file}") + logger.info("="*50) + + return logger \ No newline at end of file