Browse Source
			
			
			
			
				
		- Added `data_loader.py` with `NHREP_Dataset` class for loading point cloud, feature mask, and CSG tree data - Implemented `CustomDataLoader` for flexible data loading with configurable parameters - Refactored `train.py` to create a structured training pipeline for NHRepNet - Added support for feature sampling, device selection, and TensorBoard logging - Introduced modular training methods with error handling and logging
		NH-Rep
				 2 changed files with 264 additions and 28 deletions
			
			
		@ -0,0 +1,192 @@ | 
				
			|||
import os | 
				
			|||
import sys | 
				
			|||
import time | 
				
			|||
 | 
				
			|||
project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) | 
				
			|||
sys.path.append(project_dir) | 
				
			|||
os.chdir(project_dir) | 
				
			|||
 | 
				
			|||
import numpy as np | 
				
			|||
import torch | 
				
			|||
from torch.utils.data import Dataset, DataLoader | 
				
			|||
from torchvision import transforms | 
				
			|||
from pyhocon import ConfigFactory | 
				
			|||
from typing import List, Tuple | 
				
			|||
from utils.logger import logger | 
				
			|||
from utils.general import load_point_cloud_by_file_extension, load_feature_mask | 
				
			|||
''' | 
				
			|||
Each model corresponds to three files: | 
				
			|||
*_50k.xyz: 50,000 sampled points of the input B-Rep, can be visualized with MeshLab. | 
				
			|||
    e.g. x,y,z,nx,ny,nz | 
				
			|||
*_50k_mask.txt: (patch_id + 1) of sampled points. | 
				
			|||
    e.g. 1 or 0 each line | 
				
			|||
*_50k_csg.conf: Boolean tree built on the patches, stored in nested lists. 'flag_convex' indicates the convexity of the root node.  | 
				
			|||
    e.g. | 
				
			|||
        csg{ | 
				
			|||
            list = [0,1,[2,3,4,],], | 
				
			|||
            flag_convex = 1, | 
				
			|||
        } | 
				
			|||
''' | 
				
			|||
 | 
				
			|||
 | 
				
			|||
class NHREP_Dataset(Dataset):
    """In-memory container for the input files of a single model.

    For a given ``name_prefix`` inside ``data_dir`` the following files are read:

    * ``<name_prefix>.xyz``       -- point cloud, via ``load_point_cloud_by_file_extension``
    * ``<name_prefix>_mask.txt``  -- feature mask, via ``load_feature_mask``
    * ``<name_prefix>_csg.conf``  -- CSG tree (skipped when ``if_baseline`` is true)
    * ``<name_prefix>_feature.xyz`` / ``<name_prefix>_feature_mask.txt``
      -- feature samples (only when ``if_feature_sample`` is true)

    The loaded values are exposed through the ``get_*`` accessors.
    """

    def __init__(self, data_dir, name_prefix: str, if_baseline: bool = False,
                 if_feature_sample: bool = False, device=None):
        """Load every file belonging to ``name_prefix``.

        :param data_dir: directory containing the data files
        :param name_prefix: model name, i.e. the common file-name prefix
        :param if_baseline: use the trivial CSG tree ``[0]`` instead of reading the ``.conf`` file
        :param if_feature_sample: also load the feature-sample files
        :param device: device for the feature-sample tensors; ``None`` selects
            CUDA when it is available and the CPU otherwise
        :raises FileNotFoundError: when a required data file is missing
        """
        super().__init__()
        self.data_dir = os.path.abspath(data_dir)  # always work with an absolute path
        self.if_baseline = if_baseline
        self.if_feature_sample = if_feature_sample
        self.device = self._resolve_device(device)
        self._load_single_data(self.data_dir, name_prefix, if_baseline, if_feature_sample)

    @staticmethod
    def _resolve_device(device=None) -> torch.device:
        """Return ``device`` as a ``torch.device``, defaulting to CUDA only if usable."""
        if device is None:
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return torch.device(device)

    def _check_data_file_exists(self, file_name: str):
        """Log and raise ``FileNotFoundError`` when ``file_name`` does not exist."""
        if not os.path.exists(file_name):
            message = f"Data file not found: {file_name}, absolute path: {os.path.abspath(file_name)}"
            logger.error(message)
            # FileNotFoundError is an Exception subclass, so existing
            # ``except Exception`` handlers keep working.
            raise FileNotFoundError(message)

    def _load_feature_samples(self, data_dir: str, file_prefix: str,
                              device=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load the feature samples and their mask as tensors.

        :param data_dir: directory containing the data files
        :param file_prefix: model name, i.e. the common file-name prefix
        :param device: target device; ``None`` falls back to the device chosen
            at construction time (or auto-detection)
        :return: ``(feature_data, feature_data_mask_pair)`` as float32 / int64 tensors
        :raises FileNotFoundError: when either file is missing
        """
        if device is None:
            device = getattr(self, 'device', None)
        target_device = self._resolve_device(device)
        try:
            logger.info(f"Loading feature samples for {file_prefix}")

            # feature points
            input_fs_file = os.path.join(data_dir, file_prefix + '_feature.xyz')
            self._check_data_file_exists(input_fs_file)
            feature_data = torch.tensor(
                np.loadtxt(input_fs_file),
                dtype=torch.float32,
                device=target_device,
            )

            # mask belonging to the feature points
            fs_mask_file = os.path.join(data_dir, file_prefix + '_feature_mask.txt')
            self._check_data_file_exists(fs_mask_file)
            feature_data_mask_pair = torch.tensor(
                np.loadtxt(fs_mask_file),
                dtype=torch.int64,
                device=target_device,
            )

            return feature_data, feature_data_mask_pair
        except Exception as e:
            logger.error(f"Error loading feature samples: {str(e)}")
            raise

    def _load_csg_tree(self, data_dir: str, name_prefix: str, if_baseline: bool):
        """Set ``self.csg_tree`` and ``self.csg_flag_convex``.

        In baseline mode the tree is ``[0]`` and the flag is ``True``; otherwise
        both are read from ``<name_prefix>_csg.conf`` (``csg.list`` and
        ``csg.flag_convex``; the latter is stored as the integer from the file).
        """
        try:
            if if_baseline:
                self.csg_tree = [0]
                self.csg_flag_convex = True
                return
            csg_conf_file = os.path.join(data_dir, name_prefix + '_csg.conf')
            self._check_data_file_exists(csg_conf_file)
            csg_config = ConfigFactory.parse_file(csg_conf_file)
            self.csg_tree = csg_config.get_list('csg.list')
            self.csg_flag_convex = csg_config.get_int('csg.flag_convex')
        except Exception as e:
            logger.error(f"Error in CSG tree setup: {str(e)}")
            raise

    def _load_single_data(self, data_dir: str, name_prefix: str, if_baseline: bool, if_feature_sample: bool):
        """Load all files of one model into attributes of this instance.

        :param data_dir: directory containing the data files
        :param name_prefix: model name, i.e. the common file-name prefix
        :param if_baseline: whether to use the baseline CSG tree
        :param if_feature_sample: whether to load the feature samples
        :raises FileNotFoundError: when a required data file is missing
        """
        try:
            logger.info(f"Loading data for {name_prefix}")

            # point cloud -> self.data
            xyz_file = os.path.join(data_dir, name_prefix + '.xyz')
            self._check_data_file_exists(xyz_file)
            self.data = load_point_cloud_by_file_extension(xyz_file)

            # per-point mask -> self.feature_mask
            mask_file = os.path.join(data_dir, name_prefix + '_mask.txt')
            self._check_data_file_exists(mask_file)
            self.feature_mask = load_feature_mask(mask_file)

            # CSG tree -> self.csg_tree / self.csg_flag_convex
            self._load_csg_tree(data_dir, name_prefix, if_baseline)

            # optional feature samples -> self.feature_data / self.feature_data_mask_pair
            if if_feature_sample:
                self.feature_data, self.feature_data_mask_pair = self._load_feature_samples(data_dir, name_prefix)
        except Exception as e:
            logger.error(f"Error loading data from list: {str(e)}")
            raise

    def get_data(self):
        """Return the loaded point cloud."""
        return self.data

    def get_feature_mask(self):
        """Return the loaded feature mask."""
        return self.feature_mask

    def get_csg_tree(self):
        """Return ``(csg_tree, csg_flag_convex)``."""
        return self.csg_tree, self.csg_flag_convex

    def get_feature_data(self):
        """Return ``(feature_data, feature_data_mask_pair)``, or ``(None, None)``
        when feature samples were not requested."""
        if self.if_feature_sample:
            return self.feature_data, self.feature_data_mask_pair
        else:
            return None, None
				
			|||
 | 
				
			|||
class CustomDataLoader:
    """Thin wrapper that builds a ``torch.utils.data.DataLoader`` over a dataset."""

    def __init__(self, data_dir, batch_size=32, shuffle=True, num_workers=4, transform=None, dataset=None):
        """Create the dataset (unless one is supplied) and its ``DataLoader``.

        :param data_dir: data directory, forwarded to the default dataset
        :param batch_size: batch size
        :param shuffle: whether to shuffle the data
        :param num_workers: number of worker sub-processes
        :param transform: augmentation / transform, forwarded to the default dataset
        :param dataset: ready-made dataset to wrap; when given, ``data_dir`` and
            ``transform`` are not used
        """
        if dataset is None:
            # NOTE(review): ``CustomDataset`` is neither defined nor imported in
            # this module, so this default path raises NameError (as it did
            # before the ``dataset`` parameter existed). Pass ``dataset=``
            # explicitly until a ``CustomDataset`` is provided.
            dataset = CustomDataset(data_dir, transform)
        self.dataset = dataset
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    def get_loader(self):
        """Return the wrapped ``DataLoader``."""
        return self.dataloader
				
			|||
 | 
				
			|||
# 示例用法 | 
				
			|||
# Example usage
if __name__ == "__main__":
    # Location of the input files and the model's file-name prefix.
    data_dir = '../data/input_data'
    name_prefix = 'broken_bullet_50k'

    # Sample transform pipeline (normalisation only).
    # NOTE(review): it is built here but not passed to anything below.
    transform = transforms.Compose([transforms.Normalize(mean=[0.5], std=[0.5])])

    # Build the dataset, which loads the model's files immediately.
    dataset = NHREP_Dataset(data_dir, name_prefix)
				
			|||
 | 
				
			|||
 | 
				
			|||
@ -1,41 +1,85 @@ | 
				
			|||
import os | 
				
			|||
import sys | 
				
			|||
import time | 
				
			|||
 | 
				
			|||
project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) | 
				
			|||
sys.path.append(project_dir) | 
				
			|||
os.chdir(project_dir) | 
				
			|||
 | 
				
			|||
import torch | 
				
			|||
import numpy as np | 
				
			|||
import os | 
				
			|||
from torch.utils.tensorboard import SummaryWriter | 
				
			|||
from torch.optim.lr_scheduler import StepLR | 
				
			|||
from tqdm import tqdm | 
				
			|||
from utils.logger import logger | 
				
			|||
from utils.general import gradient | 
				
			|||
 | 
				
			|||
from data_loader import NHREP_Dataset | 
				
			|||
from model.network import NHRepNet  # 导入 NHRepNet | 
				
			|||
 | 
				
			|||
class NHREPNet_Training:
    """Training driver that fits an ``NHRepNet`` to one ``NHREP_Dataset``."""

    def __init__(self, data_dir, name_prefix: str, if_baseline: bool = False, if_feature_sample: bool = False):
        """Load the dataset and build model, optimizer, scheduler and logger.

        :param data_dir: directory containing the data files
        :param name_prefix: model name, i.e. the common file-name prefix
        :param if_baseline: forwarded to ``NHREP_Dataset``
        :param if_feature_sample: forwarded to ``NHREP_Dataset``
        """
        self.dataset = NHREP_Dataset(data_dir, name_prefix, if_baseline, if_feature_sample)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Model setup.
        # NOTE(review): presumably 6 = position + normal per point; confirm
        # against what NHRepNet expects as input.
        d_in = 6  # input dimension
        dims_sdf = [64, 64, 64]  # hidden-layer widths
        csg_tree, _ = self.dataset.get_csg_tree()
        self.model = NHRepNet(d_in, dims_sdf, csg_tree).to(self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        self.scheduler = StepLR(self.optimizer, step_size=1000, gamma=0.1)
        self.nepochs = 15000  # number of training epochs
        self.writer = SummaryWriter()  # TensorBoard writer

    def run_nhrepnet_training(self):
        """Run ``self.nepochs`` epochs; stop at the first epoch that raises.

        A failing epoch is logged and ends the loop instead of propagating.
        """
        logger.info("开始训练")
        self.model.train()  # switch the model to training mode

        for epoch in range(self.nepochs):
            try:
                self.train_one_epoch(epoch)
                self.scheduler.step()  # advance the learning-rate schedule
            except Exception as e:
                logger.error(f"训练过程中发生错误: {str(e)}")
                break

        # Make sure the scalars recorded so far reach the event file.
        self.writer.flush()

    def train_one_epoch(self, epoch):
        """Perform one full-batch optimisation step and log its loss.

        :param epoch: zero-based epoch index, used for logging only
        """
        logger.info(f"Epoch {epoch}/{self.nepochs} 开始")
        total_loss = 0.0

        # Fetch the whole point cloud and move it to the training device.
        input_data = self.dataset.get_data().to(self.device)

        # Forward pass.
        outputs = self.model(input_data)

        # Loss.
        loss = self.compute_loss(outputs)
        total_loss += loss.item()

        # Backward pass and parameter update.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Only one step is taken per epoch, so the average equals the total.
        avg_loss = total_loss
        logger.info(f'Epoch [{epoch}/{self.nepochs}], Average Loss: {avg_loss:.4f}')
        self.writer.add_scalar('Loss/train', avg_loss, epoch)

    def compute_loss(self, outputs):
        """Compute the manifold loss.

        :param outputs: model outputs
        :return: mean of the absolute values of ``outputs``
        """
        manifold_loss = (outputs.abs()).mean()
        return manifold_loss
				
			|||
if __name__ == "__main__":
    # Train on the 'broken_bullet_50k' model in baseline mode, without
    # feature samples.
    data_dir = '../data/input_data'
    name_prefix = 'broken_bullet_50k'
    trainer = NHREPNet_Training(
        data_dir,
        name_prefix,
        if_baseline=True,
        if_feature_sample=False,
    )
    trainer.run_nhrepnet_training()
				
			|||
					Loading…
					
					
				
		Reference in new issue